{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Comparing different hierarchical linkage methods on toy datasets\n\nThis example shows characteristics of different linkage\nmethods for hierarchical clustering on datasets that are\n\"interesting\" but still in 2D.\n\nThe main observations to make are:\n\n- single linkage is fast, and can perform well on\n  non-globular data, but it performs poorly in the\n  presence of noise.\n- average and complete linkage perform well on\n  cleanly separated globular clusters, but have mixed\n  results otherwise.\n- Ward is the most effective method for noisy data.\n\nWhile these examples give some intuition about the\nalgorithms, this intuition might not apply to very high\ndimensional data.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nimport time\nimport warnings\nfrom itertools import cycle, islice\n\nimport matplotlib.pyplot as plt\nimport numpy as np\n\nfrom sklearn import cluster, datasets\nfrom sklearn.preprocessing import StandardScaler"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Generate datasets. We choose a size big enough to see the scalability\nof the algorithms, but not so big that the running times become too long.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "n_samples = 1500\n# Single seed shared by every generator below so the figure is reproducible\nseed = 170\n\nnoisy_circles = datasets.make_circles(\n    n_samples=n_samples, factor=0.5, noise=0.05, random_state=seed\n)\nnoisy_moons = datasets.make_moons(n_samples=n_samples, noise=0.05, random_state=seed)\nblobs = datasets.make_blobs(n_samples=n_samples, random_state=seed)\nrng = np.random.RandomState(seed)\n# Uniformly random points; the second tuple item is None because there are\n# no labels for this dataset\nno_structure = rng.rand(n_samples, 2), None\n\n# Anisotropically distributed data: the points of `blobs` after a linear\n# transformation. np.dot returns a new array, so `blobs` itself is unchanged.\nX, y = blobs\ntransformation = [[0.6, -0.6], [-0.4, 0.8]]\nX_aniso = np.dot(X, transformation)\naniso = (X_aniso, y)\n\n# blobs with varied variances\nvaried = datasets.make_blobs(\n    n_samples=n_samples, cluster_std=[1.0, 2.5, 0.5], random_state=seed\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Run the clustering and plot\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Set up the figure: one row per dataset, one column per linkage method\nplt.figure(figsize=(9 * 1.3 + 2, 14.5))\nplt.subplots_adjust(\n    left=0.02, right=0.98, bottom=0.001, top=0.96, wspace=0.05, hspace=0.01\n)\n\nplot_num = 1\n\n# Defaults merged with the per-dataset overrides below; only `n_clusters`\n# is read by the estimators built in this cell.\ndefault_base = {\"n_neighbors\": 10, \"n_clusters\": 3}\n\n# Deliberately not named `datasets`: that name is the `sklearn.datasets`\n# module used by the data-generation cell, and rebinding it to a list would\n# make that cell fail when it is re-run.\ndataset_configs = [\n    (noisy_circles, {\"n_clusters\": 2}),\n    (noisy_moons, {\"n_clusters\": 2}),\n    (varied, {\"n_neighbors\": 2}),\n    (aniso, {\"n_neighbors\": 2}),\n    (blobs, {}),\n    (no_structure, {}),\n]\n\n# Colours indexed by cluster label (cycled if there are more labels than\n# colours); defined once instead of being rebuilt for every subplot.\npalette = [\n    \"#377eb8\",\n    \"#ff7f00\",\n    \"#4daf4a\",\n    \"#f781bf\",\n    \"#a65628\",\n    \"#984ea3\",\n    \"#999999\",\n    \"#e41a1c\",\n    \"#dede00\",\n]\n\nfor i_dataset, (dataset, algo_params) in enumerate(dataset_configs):\n    # update parameters with dataset-specific values\n    params = default_base.copy()\n    params.update(algo_params)\n\n    X, y = dataset\n\n    # normalize dataset for easier parameter selection\n    X = StandardScaler().fit_transform(X)\n\n    # ============\n    # Create cluster objects\n    # ============\n    ward = cluster.AgglomerativeClustering(\n        n_clusters=params[\"n_clusters\"], linkage=\"ward\"\n    )\n    complete = cluster.AgglomerativeClustering(\n        n_clusters=params[\"n_clusters\"], linkage=\"complete\"\n    )\n    average = cluster.AgglomerativeClustering(\n        n_clusters=params[\"n_clusters\"], linkage=\"average\"\n    )\n    single = cluster.AgglomerativeClustering(\n        n_clusters=params[\"n_clusters\"], linkage=\"single\"\n    )\n\n    clustering_algorithms = (\n        (\"Single Linkage\", single),\n        (\"Average Linkage\", average),\n        (\"Complete Linkage\", complete),\n        (\"Ward Linkage\", ward),\n    )\n\n    for name, algorithm in clustering_algorithms:\n        # perf_counter is a high-resolution clock intended for measuring\n        # short durations\n        t0 = time.perf_counter()\n\n        # ignore the connectivity-matrix warning if it is raised during fit\n        with warnings.catch_warnings():\n            warnings.filterwarnings(\n                \"ignore\",\n                message=\"the number of connected components of the \"\n                \"connectivity matrix is [0-9]{1,2}\"\n                \" > 1. Completing it to avoid stopping the tree early.\",\n                category=UserWarning,\n            )\n            algorithm.fit(X)\n\n        t1 = time.perf_counter()\n        if hasattr(algorithm, \"labels_\"):\n            y_pred = algorithm.labels_.astype(int)\n        else:\n            y_pred = algorithm.predict(X)\n\n        plt.subplot(len(dataset_configs), len(clustering_algorithms), plot_num)\n        if i_dataset == 0:\n            # column headers only on the first row\n            plt.title(name, size=18)\n\n        # one colour per label value from 0 to max(y_pred)\n        colors = np.array(list(islice(cycle(palette), int(max(y_pred) + 1))))\n        plt.scatter(X[:, 0], X[:, 1], s=10, color=colors[y_pred])\n\n        plt.xlim(-2.5, 2.5)\n        plt.ylim(-2.5, 2.5)\n        plt.xticks(())\n        plt.yticks(())\n        # elapsed fit time in the bottom-right corner, leading zero stripped\n        plt.text(\n            0.99,\n            0.01,\n            (\"%.2fs\" % (t1 - t0)).lstrip(\"0\"),\n            transform=plt.gca().transAxes,\n            size=15,\n            horizontalalignment=\"right\",\n        )\n        plot_num += 1\n\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.14"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}