{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Neighborhood Components Analysis Illustration\n\nThis example illustrates a learned distance metric that maximizes\nthe nearest neighbors classification accuracy. It provides a visual\nrepresentation of this metric compared to the original point\nspace. Please refer to the `User Guide <nca>` for more information.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom matplotlib import cm\nfrom scipy.special import logsumexp\n\nfrom sklearn.datasets import make_classification\nfrom sklearn.neighbors import NeighborhoodComponentsAnalysis"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Original points\nFirst we create a data set of 9 samples from 3 classes, and plot the points\nin the original space. For this example, we focus on the classification of\npoint no. 3. The thickness of a link between point no. 3 and another point\nis proportional to their distance.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "X, y = make_classification(\n    n_samples=9,\n    n_features=2,\n    n_informative=2,\n    n_redundant=0,\n    n_classes=3,\n    n_clusters_per_class=1,\n    class_sep=1.0,\n    random_state=0,\n)\n\nplt.figure(1)\nax = plt.gca()\nfor i in range(X.shape[0]):\n    ax.text(X[i, 0], X[i, 1], str(i), va=\"center\", ha=\"center\")\n    ax.scatter(X[i, 0], X[i, 1], s=300, c=cm.Set1(y[[i]]), alpha=0.4)\n\nax.set_title(\"Original points\")\nax.axes.get_xaxis().set_visible(False)\nax.axes.get_yaxis().set_visible(False)\nax.axis(\"equal\")  # so that boundaries are displayed correctly as circles\n\n\ndef link_thickness_i(X, i):\n    diff_embedded = X[i] - X\n    dist_embedded = np.einsum(\"ij,ij->i\", diff_embedded, diff_embedded)\n    dist_embedded[i] = np.inf\n\n    # compute exponentiated distances (use the log-sum-exp trick to\n    # avoid numerical instabilities\n    exp_dist_embedded = np.exp(-dist_embedded - logsumexp(-dist_embedded))\n    return exp_dist_embedded\n\n\ndef relate_point(X, i, ax):\n    pt_i = X[i]\n    for j, pt_j in enumerate(X):\n        thickness = link_thickness_i(X, i)\n        if i != j:\n            line = ([pt_i[0], pt_j[0]], [pt_i[1], pt_j[1]])\n            ax.plot(*line, c=cm.Set1(y[j]), linewidth=5 * thickness[j])\n\n\ni = 3\nrelate_point(X, i, ax)\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Learning an embedding\nWe use :class:`~sklearn.neighbors.NeighborhoodComponentsAnalysis` to learn an\nembedding and plot the points after the transformation. We then take the\nembedding and find the nearest neighbors.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "nca = NeighborhoodComponentsAnalysis(max_iter=30, random_state=0)\nnca = nca.fit(X, y)\n\nplt.figure(2)\nax2 = plt.gca()\nX_embedded = nca.transform(X)\nrelate_point(X_embedded, i, ax2)\n\nfor i in range(len(X)):\n    ax2.text(X_embedded[i, 0], X_embedded[i, 1], str(i), va=\"center\", ha=\"center\")\n    ax2.scatter(X_embedded[i, 0], X_embedded[i, 1], s=300, c=cm.Set1(y[[i]]), alpha=0.4)\n\nax2.set_title(\"NCA embedding\")\nax2.axes.get_xaxis().set_visible(False)\nax2.axes.get_yaxis().set_visible(False)\nax2.axis(\"equal\")\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.14"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}