{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Comparison between grid search and successive halving\n\nThis example compares the parameter search performed by\n:class:`~sklearn.model_selection.HalvingGridSearchCV` and\n:class:`~sklearn.model_selection.GridSearchCV`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nfrom time import time\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\n\nfrom sklearn import datasets\nfrom sklearn.experimental import enable_halving_search_cv  # noqa: F401\nfrom sklearn.model_selection import GridSearchCV, HalvingGridSearchCV\nfrom sklearn.svm import SVC"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We first define the parameter space for an :class:`~sklearn.svm.SVC`\nestimator, and measure the time required to fit a\n:class:`~sklearn.model_selection.HalvingGridSearchCV` instance, as well as a\n:class:`~sklearn.model_selection.GridSearchCV` instance.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "rng = np.random.RandomState(0)\nX, y = datasets.make_classification(n_samples=1000, random_state=rng)\n\ngammas = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7]\nCs = [1, 10, 100, 1e3, 1e4, 1e5]\nparam_grid = {\"gamma\": gammas, \"C\": Cs}\n\nclf = SVC(random_state=rng)\n\ntic = time()\ngsh = HalvingGridSearchCV(\n    estimator=clf, param_grid=param_grid, factor=2, random_state=rng\n)\ngsh.fit(X, y)\ngsh_time = time() - tic\n\ntic = time()\ngs = GridSearchCV(estimator=clf, param_grid=param_grid)\ngs.fit(X, y)\ngs_time = time() - tic"
      ]
    },
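    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The grid above contains 42 candidates (7 values of ``gamma`` times 6 values\nof ``C``). Successive halving first evaluates all of them on a small subset\nof the samples, keeps roughly the best ``1 / factor`` of them, and repeats\nwith ``factor`` times more samples. The cell below is a simplified sketch of\nthat schedule, not scikit-learn's own implementation: it assumes the default\n``min_resources=\"exhaust\"`` and ``aggressive_elimination=False``, the\nhelpers ``floor_log`` and ``halving_schedule`` are hypothetical, and the\nlower bound of 20 samples passed to it is an assumed value. Under these\nassumptions it predicts six iterations, numbered ``0`` to ``5``; the actual\nvalues used by the fitted search are stored in ``gsh.n_resources_`` and\n``gsh.n_candidates_``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from math import ceil\n\n\ndef floor_log(n, base):\n    \"\"\"Return the largest integer k such that base**k <= n (assumes n >= 1).\"\"\"\n    k = 0\n    while n >= base:\n        n //= base\n        k += 1\n    return k\n\n\ndef halving_schedule(n_candidates, max_resources, min_resources, factor):\n    \"\"\"Return one (n_resources, n_candidates) pair per halving iteration.\n\n    Hypothetical sketch assuming min_resources=\"exhaust\" and\n    aggressive_elimination=False; not scikit-learn's implementation.\n    \"\"\"\n    # Iterations required by the number of candidates: 1 + floor(log_factor(n)).\n    n_required = 1 + floor_log(n_candidates, factor)\n    # \"exhaust\": start with as many samples as possible so that the last\n    # required iteration uses (almost) all of max_resources.\n    min_resources = max(min_resources, max_resources // factor ** (n_required - 1))\n    # Iterations that fit between min_resources and max_resources.\n    n_possible = 1 + floor_log(max_resources // min_resources, factor)\n\n    schedule = []\n    for itr in range(min(n_required, n_possible)):\n        n_resources = min(min_resources * factor**itr, max_resources)\n        schedule.append((n_resources, n_candidates))\n        # Keep the best 1 / factor of the candidates, rounded up.\n        n_candidates = ceil(n_candidates / factor)\n    return schedule\n\n\n# 42 candidates, 1000 samples and factor=2, as in the search above.\n# The lower bound of 20 samples is an assumed value.\nfor itr, (n_res, n_cand) in enumerate(halving_schedule(42, 1000, 20, 2)):\n    print(f\"iter {itr}: {n_cand:2d} candidates evaluated with {n_res:4d} samples\")"
      ]
    },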
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We now plot heatmaps for both search estimators.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def make_heatmap(ax, gs, is_sh=False, make_cbar=False):\n    \"\"\"Helper to make a heatmap.\"\"\"\n    results = pd.DataFrame(gs.cv_results_)\n    results[[\"param_C\", \"param_gamma\"]] = results[[\"param_C\", \"param_gamma\"]].astype(\n        np.float64\n    )\n    if is_sh:\n        # SH dataframe: get mean_test_score values for the highest iter\n        scores_matrix = results.sort_values(\"iter\").pivot_table(\n            index=\"param_gamma\",\n            columns=\"param_C\",\n            values=\"mean_test_score\",\n            aggfunc=\"last\",\n        )\n    else:\n        scores_matrix = results.pivot(\n            index=\"param_gamma\", columns=\"param_C\", values=\"mean_test_score\"\n        )\n\n    im = ax.imshow(scores_matrix)\n\n    # Label the ticks from the pivot table itself: pivoting sorts gamma and C\n    # in ascending order, which differs from the order of the `gammas` list.\n    ax.set_xticks(np.arange(len(scores_matrix.columns)))\n    ax.set_xticklabels([\"{:.0E}\".format(x) for x in scores_matrix.columns])\n    ax.set_xlabel(\"C\", fontsize=15)\n\n    ax.set_yticks(np.arange(len(scores_matrix.index)))\n    ax.set_yticklabels([\"{:.0E}\".format(x) for x in scores_matrix.index])\n    ax.set_ylabel(\"gamma\", fontsize=15)\n\n    # Rotate the tick labels and set their alignment.\n    plt.setp(ax.get_xticklabels(), rotation=45, ha=\"right\", rotation_mode=\"anchor\")\n\n    if is_sh:\n        iterations = results.pivot_table(\n            index=\"param_gamma\", columns=\"param_C\", values=\"iter\", aggfunc=\"max\"\n        ).values\n        for i in range(len(gammas)):\n            for j in range(len(Cs)):\n                ax.text(\n                    j,\n                    i,\n                    iterations[i, j],\n                    ha=\"center\",\n                    va=\"center\",\n                    color=\"w\",\n                    fontsize=20,\n                )\n\n    if make_cbar:\n        fig.subplots_adjust(right=0.8)\n        cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])\n        fig.colorbar(im, cax=cbar_ax)\n        cbar_ax.set_ylabel(\"mean_test_score\", rotation=-90, va=\"bottom\", fontsize=15)\n\n\nfig, axes = plt.subplots(ncols=2, sharey=True)\nax1, ax2 = axes\n\nmake_heatmap(ax1, gsh, is_sh=True)\nmake_heatmap(ax2, gs, make_cbar=True)\n\nax1.set_title(\"Successive Halving\\ntime = {:.3f}s\".format(gsh_time), fontsize=15)\nax2.set_title(\"GridSearch\\ntime = {:.3f}s\".format(gs_time), fontsize=15)\n\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The heatmaps show the mean test score of the parameter combinations for an\n:class:`~sklearn.svm.SVC` instance. The\n:class:`~sklearn.model_selection.HalvingGridSearchCV` heatmap also shows the\niteration at which each combination was last evaluated. The combinations\nmarked ``0`` were only evaluated at the first iteration, while those marked\n``5`` survived every round of elimination and are the ones considered the\nbest.\n\nWe can see that the :class:`~sklearn.model_selection.HalvingGridSearchCV`\nclass is able to find parameter combinations that are just as accurate as\n:class:`~sklearn.model_selection.GridSearchCV`, in much less time (compare the\ntimes in the two plot titles).\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.14"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}