{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Compare Stochastic learning strategies for MLPClassifier\n\nThis example visualizes some training loss curves for different stochastic\nlearning strategies, including SGD and Adam. Because of time-constraints, we\nuse several small datasets, for which L-BFGS might be more suitable. The\ngeneral trend shown in these examples seems to carry over to larger datasets,\nhowever.\n\nNote that those results can be highly dependent on the value of\n``learning_rate_init``.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nimport warnings\n\nimport matplotlib.pyplot as plt\n\nfrom sklearn import datasets\nfrom sklearn.exceptions import ConvergenceWarning\nfrom sklearn.neural_network import MLPClassifier\nfrom sklearn.preprocessing import MinMaxScaler\n\n# different learning rate schedules and momentum parameters\nparams = [\n    {\n        \"solver\": \"sgd\",\n        \"learning_rate\": \"constant\",\n        \"momentum\": 0,\n        \"learning_rate_init\": 0.2,\n    },\n    {\n        \"solver\": \"sgd\",\n        \"learning_rate\": \"constant\",\n        \"momentum\": 0.9,\n        \"nesterovs_momentum\": False,\n        \"learning_rate_init\": 0.2,\n    },\n    {\n        \"solver\": \"sgd\",\n        \"learning_rate\": \"constant\",\n        \"momentum\": 0.9,\n        \"nesterovs_momentum\": True,\n        \"learning_rate_init\": 0.2,\n    },\n    {\n        \"solver\": \"sgd\",\n        \"learning_rate\": \"invscaling\",\n        \"momentum\": 0,\n        \"learning_rate_init\": 0.2,\n    },\n    {\n        \"solver\": \"sgd\",\n        \"learning_rate\": \"invscaling\",\n        \"momentum\": 0.9,\n        \"nesterovs_momentum\": False,\n        \"learning_rate_init\": 0.2,\n    },\n    {\n        \"solver\": \"sgd\",\n        \"learning_rate\": \"invscaling\",\n        \"momentum\": 0.9,\n        \"nesterovs_momentum\": True,\n        \"learning_rate_init\": 0.2,\n    },\n    {\"solver\": \"adam\", \"learning_rate_init\": 0.01},\n]\n\nlabels = [\n    \"constant learning-rate\",\n    \"constant with momentum\",\n    \"constant with Nesterov's momentum\",\n    \"inv-scaling learning-rate\",\n    \"inv-scaling with momentum\",\n    \"inv-scaling with Nesterov's momentum\",\n    \"adam\",\n]\n\nplot_args = [\n    {\"c\": \"red\", \"linestyle\": \"-\"},\n    {\"c\": \"green\", \"linestyle\": \"-\"},\n    {\"c\": \"blue\", \"linestyle\": \"-\"},\n    {\"c\": \"red\", \"linestyle\": \"--\"},\n    {\"c\": \"green\", \"linestyle\": \"--\"},\n    {\"c\": \"blue\", \"linestyle\": \"--\"},\n    {\"c\": \"black\", \"linestyle\": \"-\"},\n]\n\n\ndef plot_on_dataset(X, y, ax, name):\n    # for each dataset, plot learning for each learning strategy\n    print(\"\\nlearning on dataset %s\" % name)\n    ax.set_title(name)\n\n    X = MinMaxScaler().fit_transform(X)\n    mlps = []\n    if name == \"digits\":\n        # digits is larger but converges fairly quickly\n        max_iter = 15\n    else:\n        max_iter = 400\n\n    for label, param in zip(labels, params):\n        print(\"training: %s\" % label)\n        mlp = MLPClassifier(random_state=0, max_iter=max_iter, **param)\n\n        # some parameter combinations will not converge as can be seen on the\n        # plots so they are ignored here\n        with warnings.catch_warnings():\n            warnings.filterwarnings(\n                \"ignore\", category=ConvergenceWarning, module=\"sklearn\"\n            )\n            mlp.fit(X, y)\n\n        mlps.append(mlp)\n        print(\"Training set score: %f\" % mlp.score(X, y))\n        print(\"Training set loss: %f\" % mlp.loss_)\n    for mlp, label, args in zip(mlps, labels, plot_args):\n        ax.plot(mlp.loss_curve_, label=label, **args)\n\n\nfig, axes = plt.subplots(2, 2, figsize=(15, 10))\n# load / generate some toy datasets\niris = datasets.load_iris()\nX_digits, y_digits = datasets.load_digits(return_X_y=True)\ndata_sets = [\n    (iris.data, iris.target),\n    (X_digits, y_digits),\n    datasets.make_circles(noise=0.2, factor=0.5, random_state=1),\n    datasets.make_moons(noise=0.3, random_state=0),\n]\n\nfor ax, data, name in zip(\n    axes.ravel(), data_sets, [\"iris\", \"digits\", \"circles\", \"moons\"]\n):\n    plot_on_dataset(*data, ax=ax, name=name)\n\nfig.legend(ax.get_lines(), labels, ncol=3, loc=\"upper center\")\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.14"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}