{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Online learning of a dictionary of parts of faces\n\nThis example uses a large dataset of faces to learn a set of 20 x 20\nimages patches that constitute faces.\n\nFrom the programming standpoint, it is interesting because it shows how\nto use the online API of the scikit-learn to process a very large\ndataset by chunks. The way we proceed is that we load an image at a time\nand extract randomly 50 patches from this image. Once we have accumulated\n500 of these patches (using 10 images), we run the\n:func:`~sklearn.cluster.MiniBatchKMeans.partial_fit` method\nof the online KMeans object, MiniBatchKMeans.\n\nThe verbose setting on the MiniBatchKMeans enables us to see that some\nclusters are reassigned during the successive calls to\npartial-fit. This is because the number of patches that they represent\nhas become too low, and it is better to choose a random new\ncluster.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Load the data\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from sklearn import datasets\n\nfaces = datasets.fetch_olivetti_faces()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Learn the dictionary of images\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import time\n\nimport numpy as np\n\nfrom sklearn.cluster import MiniBatchKMeans\nfrom sklearn.feature_extraction.image import extract_patches_2d\n\nprint(\"Learning the dictionary... \")\nrng = np.random.RandomState(0)\nkmeans = MiniBatchKMeans(n_clusters=81, random_state=rng, verbose=True, n_init=3)\npatch_size = (20, 20)\n\nbuffer = []\nt0 = time.time()\n\n# The online learning part: cycle over the whole dataset 6 times\nindex = 0\nfor _ in range(6):\n    for img in faces.images:\n        data = extract_patches_2d(img, patch_size, max_patches=50, random_state=rng)\n        data = np.reshape(data, (len(data), -1))\n        buffer.append(data)\n        index += 1\n        if index % 10 == 0:\n            data = np.concatenate(buffer, axis=0)\n            data -= np.mean(data, axis=0)\n            data /= np.std(data, axis=0)\n            kmeans.partial_fit(data)\n            buffer = []\n        if index % 100 == 0:\n            print(\"Partial fit of %4i out of %i\" % (index, 6 * len(faces.images)))\n\ndt = time.time() - t0\nprint(\"done in %.2fs.\" % dt)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Plot the results\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n\nplt.figure(figsize=(4.2, 4))\nfor i, patch in enumerate(kmeans.cluster_centers_):\n    plt.subplot(9, 9, i + 1)\n    plt.imshow(patch.reshape(patch_size), cmap=plt.cm.gray, interpolation=\"nearest\")\n    plt.xticks(())\n    plt.yticks(())\n\n\nplt.suptitle(\n    \"Patches of faces\\nTrain time %.1fs on %d patches\" % (dt, 8 * len(faces.images)),\n    fontsize=16,\n)\nplt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)\n\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.14"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}