{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# SGD: Weighted samples\n\nPlot the decision function of a weighted dataset, where the size of each point\nis proportional to its weight.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause\n\nimport matplotlib.pyplot as plt\nimport numpy as np\n\nfrom sklearn import linear_model\n\n\ndef plot_decision_boundary(ax, X, y, xx, yy, linestyle, sample_weight=None):\n    \"\"\"Fit an SGD classifier and draw its decision boundary on ``ax``.\n\n    Parameters\n    ----------\n    ax : matplotlib Axes\n        Axes on which the zero-level contour is drawn.\n    X, y : array-like\n        Training samples and their labels.\n    xx, yy : ndarray\n        Coordinate matrices (as returned by ``np.meshgrid``) on which the\n        decision function is evaluated.\n    linestyle : str\n        Line style used for the contour.\n    sample_weight : array-like, default=None\n        Per-sample weights forwarded to ``fit``; ``None`` leaves the samples\n        unweighted.\n\n    Returns\n    -------\n    The contour set returned by ``ax.contour``.\n    \"\"\"\n    clf = linear_model.SGDClassifier(alpha=0.01, max_iter=100)\n    clf.fit(X, y, sample_weight=sample_weight)\n    Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])\n    # the decision boundary is the zero level set of the decision function\n    return ax.contour(\n        xx, yy, Z.reshape(xx.shape), levels=[0], linestyles=[linestyle]\n    )\n\n\n# we create 20 points: the first 10 (label 1) are shifted by [1, 1],\n# the last 10 (label -1) are left unshifted\nnp.random.seed(0)\nX = np.r_[np.random.randn(10, 2) + [1, 1], np.random.randn(10, 2)]\ny = [1] * 10 + [-1] * 10\nsample_weight = 100 * np.abs(np.random.randn(20))\n# and assign a bigger weight to the first 10 samples\nsample_weight[:10] *= 10\n\n# plot the weighted data points; the marker size is given by the sample weight\nxx, yy = np.meshgrid(np.linspace(-4, 5, 500), np.linspace(-4, 5, 500))\nfig, ax = plt.subplots()\nax.scatter(\n    X[:, 0],\n    X[:, 1],\n    c=y,\n    s=sample_weight,\n    alpha=0.9,\n    cmap=plt.cm.bone,\n    edgecolor=\"black\",\n)\n\n# fit the unweighted model (solid line)\nno_weights = plot_decision_boundary(ax, X, y, xx, yy, \"solid\")\n\n# fit the weighted model (dashed line)\nsamples_weights = plot_decision_boundary(\n    ax, X, y, xx, yy, \"dashed\", sample_weight=sample_weight\n)\n\nno_weights_handles, _ = no_weights.legend_elements()\nweights_handles, _ = samples_weights.legend_elements()\nax.legend(\n    [no_weights_handles[0], weights_handles[0]],\n    [\"no weights\", \"with weights\"],\n    loc=\"lower left\",\n)\n\nax.set(xticks=(), yticks=())\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.14"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}