{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Metadata Routing\n\n.. currentmodule:: sklearn\n\nThis document shows how you can use the `metadata routing mechanism\n<metadata_routing>` in scikit-learn to route metadata to the estimators,\nscorers, and CV splitters consuming them.\n\nTo better understand the following document, we need to introduce two concepts:\nrouters and consumers. A router is an object which forwards some given data and\nmetadata to other objects. In most cases, a router is a :term:`meta-estimator`,\ni.e. an estimator which takes another estimator as a parameter. A function such\nas :func:`sklearn.model_selection.cross_validate`, which takes an estimator as a\nparameter and forwards data and metadata, is also a router.\n\nA consumer, on the other hand, is an object which accepts and uses some given\nmetadata. For instance, an estimator taking into account ``sample_weight`` in\nits :term:`fit` method is a consumer of ``sample_weight``.\n\nIt is possible for an object to be both a router and a consumer. For instance,\na meta-estimator may take into account ``sample_weight`` in certain\ncalculations, while also routing it to the underlying estimator.\n\nFirst, a few imports and some random data for the rest of the script.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Authors: The scikit-learn developers\n# SPDX-License-Identifier: BSD-3-Clause"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import warnings\nfrom pprint import pprint\n\nimport numpy as np\n\nfrom sklearn import set_config\nfrom sklearn.base import (\n    BaseEstimator,\n    ClassifierMixin,\n    MetaEstimatorMixin,\n    RegressorMixin,\n    TransformerMixin,\n    clone,\n)\nfrom sklearn.linear_model import LinearRegression\nfrom sklearn.utils import metadata_routing\nfrom sklearn.utils.metadata_routing import (\n    MetadataRouter,\n    MethodMapping,\n    get_routing_for_object,\n    process_routing,\n)\nfrom sklearn.utils.validation import check_is_fitted\n\nn_samples, n_features = 100, 4\nrng = np.random.RandomState(42)\nX = rng.rand(n_samples, n_features)\ny = rng.randint(0, 2, size=n_samples)\nmy_groups = rng.randint(0, 10, size=n_samples)\nmy_weights = rng.rand(n_samples)\nmy_other_weights = rng.rand(n_samples)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Metadata routing is only available if explicitly enabled:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "set_config(enable_metadata_routing=True)"
      ]
    },
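    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a first illustration of the router and consumer vocabulary, here is a\nsketch that uses only built-in scikit-learn components:\n:func:`~model_selection.cross_validate` acts as the router, while\n:class:`~linear_model.LogisticRegression` (in ``fit``), an accuracy scorer\n(in ``score``), and :class:`~model_selection.GroupKFold` (in ``split``) act\nas consumers. It assumes a scikit-learn version in which ``cross_validate``\naccepts a ``params`` argument. The ``*_demo`` variables are illustrative data\nintroduced only for this cell, so the rest of the script is unaffected.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\nfrom sklearn import set_config\nfrom sklearn.linear_model import LogisticRegression\nfrom sklearn.metrics import get_scorer\nfrom sklearn.model_selection import GroupKFold, cross_validate\n\n# Routing must be enabled before any `set_*_request` method can be called.\nset_config(enable_metadata_routing=True)\n\n# Illustrative data, kept separate from the script's `X`, `y`, etc.\nrng_demo = np.random.RandomState(0)\nX_demo = rng_demo.rand(60, 3)\ny_demo = rng_demo.randint(0, 2, size=60)\ngroups_demo = rng_demo.randint(0, 6, size=60)\nweights_demo = rng_demo.rand(60)\n\n# Consumers declare which metadata they want. `GroupKFold` requests `groups`\n# in its `split` method by default, so it needs no explicit request.\nclf_demo = LogisticRegression().set_fit_request(sample_weight=True)\nscorer_demo = get_scorer(\"accuracy\").set_score_request(sample_weight=True)\n\n# The router forwards each metadata to the consumers that requested it.\nresults_demo = cross_validate(\n    clf_demo,\n    X_demo,\n    y_demo,\n    scoring=scorer_demo,\n    cv=GroupKFold(n_splits=3),\n    params={\"sample_weight\": weights_demo, \"groups\": groups_demo},\n)\nprint(results_demo[\"test_score\"])"
      ]
    },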
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This dummy utility function prints whether a given metadata was passed,\nand its length if so:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def check_metadata(obj, **kwargs):\n    for key, value in kwargs.items():\n        if value is not None:\n            print(\n                f\"Received {key} of length = {len(value)} in {obj.__class__.__name__}.\"\n            )\n        else:\n            print(f\"{key} is None in {obj.__class__.__name__}.\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A utility function to nicely print the routing information of an object:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def print_routing(obj):\n    pprint(obj.get_metadata_routing()._serialize())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Consuming Estimator\nHere we demonstrate how an estimator can expose the required API to support\nmetadata routing as a consumer. Imagine a simple classifier accepting\n``sample_weight`` as a metadata in its ``fit`` method and ``groups`` in its\n``predict`` method:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class ExampleClassifier(ClassifierMixin, BaseEstimator):\n    def fit(self, X, y, sample_weight=None):\n        check_metadata(self, sample_weight=sample_weight)\n        # all classifiers need to expose a classes_ attribute once they're fit.\n        self.classes_ = np.array([0, 1])\n        return self\n\n    def predict(self, X, groups=None):\n        check_metadata(self, groups=groups)\n        # return a constant value of 1, not a very smart classifier!\n        return np.ones(len(X))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The above estimator now has all it needs to consume metadata. This is\nhandled by :class:`~base.BaseEstimator`, which inspects the signatures of the\nclass's methods and generates the matching request methods. There are now\nthree methods exposed by the above class: ``set_fit_request``,\n``set_predict_request``, and ``get_metadata_routing``. There is also a\n``set_score_request`` for ``sample_weight``, which is present since\n:class:`~base.ClassifierMixin` implements a ``score`` method accepting\n``sample_weight``. The same applies to regressors which inherit from\n:class:`~base.RegressorMixin`.\n\nBy default, no metadata is requested, as we can see here:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print_routing(ExampleClassifier())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The above output means that ``sample_weight`` and ``groups`` are not\nrequested by `ExampleClassifier`, and if a router is given those metadata, it\nshould raise an error, since the user has not explicitly set whether they are\nrequired or not. The same is true for ``sample_weight`` in the ``score``\nmethod, which is inherited from :class:`~base.ClassifierMixin`. In order to\nexplicitly set request values for those metadata, we can use these methods:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "est = (\n    ExampleClassifier()\n    .set_fit_request(sample_weight=False)\n    .set_predict_request(groups=True)\n    .set_score_request(sample_weight=False)\n)\nprint_routing(est)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        ".. note ::\n    Please note that as long as the above estimator is not used in a\n    meta-estimator, the user does not need to set any requests for the\n    metadata and the set values are ignored, since a consumer does not\n    validate or route given metadata. A simple usage of the above estimator\n    would work as expected.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "est = ExampleClassifier()\nest.fit(X, y, sample_weight=my_weights)\nest.predict(X[:3, :], groups=my_groups)"
      ]
    },
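    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the note above concrete, the sketch below calls\n:class:`~linear_model.LinearRegression` directly, without any router. Even\nthough ``sample_weight`` is marked as not requested, the weights passed to\n``fit`` are still used, since only routers consult the requests. The\n``*_note`` and ``coef_*`` names are illustrative and do not clash with the\nrest of the script.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\nfrom sklearn import set_config\nfrom sklearn.linear_model import LinearRegression\n\nset_config(enable_metadata_routing=True)\n\nrng_note = np.random.RandomState(0)\nX_note = rng_note.rand(50, 3)\ny_note = rng_note.rand(50)\nw_note = rng_note.rand(50)\n\n# The request says that a router should NOT pass `sample_weight` to `fit`...\nest_note = LinearRegression().set_fit_request(sample_weight=False)\n# ...but no router is involved here, so `fit` simply uses the given weights.\ncoef_ignored_request = est_note.fit(X_note, y_note, sample_weight=w_note).coef_\n\ncoef_weighted = LinearRegression().fit(X_note, y_note, sample_weight=w_note).coef_\ncoef_unweighted = LinearRegression().fit(X_note, y_note).coef_\n\n# Compare with an ordinary weighted fit and with an unweighted one.\nprint(np.allclose(coef_ignored_request, coef_weighted))\nprint(np.allclose(coef_ignored_request, coef_unweighted))"
      ]
    },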
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Routing Meta-Estimator\nNow, we show how to design a meta-estimator to be a router. As a simplified\nexample, here is a meta-estimator that does little more than route the\nmetadata.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class MetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):\n    def __init__(self, estimator):\n        self.estimator = estimator\n\n    def get_metadata_routing(self):\n        # This method defines the routing for this meta-estimator.\n        # In order to do so, a `MetadataRouter` instance is created, and the\n        # routing is added to it. More explanations follow below.\n        router = MetadataRouter(owner=self).add(\n            estimator=self.estimator,\n            method_mapping=MethodMapping()\n            .add(caller=\"fit\", callee=\"fit\")\n            .add(caller=\"predict\", callee=\"predict\")\n            .add(caller=\"score\", callee=\"score\"),\n        )\n        return router\n\n    def fit(self, X, y, **fit_params):\n        # `get_routing_for_object` returns a copy of the `MetadataRouter`\n        # constructed by the `get_metadata_routing` method above, which it\n        # calls internally.\n        request_router = get_routing_for_object(self)\n        # Meta-estimators are responsible for validating the given metadata.\n        # `method` refers to the parent's method, i.e. `fit` in this example.\n        request_router.validate_metadata(params=fit_params, method=\"fit\")\n        # `MetadataRouter.route_params` maps the given metadata to the metadata\n        # required by the underlying estimator based on the routing information\n        # defined by the MetadataRouter. The output, of type `Bunch`, has a key\n        # for each consuming object; each of those holds a key for every\n        # consuming method, which in turn holds the metadata routed to it.\n        routed_params = request_router.route_params(params=fit_params, caller=\"fit\")\n\n        # A sub-estimator is fitted and its classes are assigned to the\n        # meta-estimator.\n        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)\n        self.classes_ = self.estimator_.classes_\n        return self\n\n    def predict(self, X, **predict_params):\n        check_is_fitted(self)\n        # As in `fit`, we get a copy of the object's MetadataRouter,\n        request_router = get_routing_for_object(self)\n        # then we validate the given metadata,\n        request_router.validate_metadata(params=predict_params, method=\"predict\")\n        # and then prepare the input to the underlying `predict` method.\n        routed_params = request_router.route_params(\n            params=predict_params, caller=\"predict\"\n        )\n        return self.estimator_.predict(X, **routed_params.estimator.predict)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's break down different parts of the above code.\n\nFirst, :func:`~utils.metadata_routing.get_routing_for_object` takes our\nmeta-estimator (``self``) and returns a\n:class:`~utils.metadata_routing.MetadataRouter`, or a\n:class:`~utils.metadata_routing.MetadataRequest` if the object is a consumer,\nbased on the output of the estimator's ``get_metadata_routing`` method.\n\nThen in each method, we use the ``route_params`` method to construct a\ndictionary-like :class:`~utils.Bunch` of the form ``{\"object_name\":\n{\"method_name\": {\"metadata\": value}}}`` to pass to the underlying estimator's\nmethod. The ``object_name`` (``estimator`` in the above\n``routed_params.estimator.fit`` example) is the same as the one added in\n``get_metadata_routing``. ``validate_metadata`` makes sure all given metadata\nare requested to avoid silent bugs.\n\nNext, we illustrate the different behaviors and notably the type of errors\nraised.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "meta_est = MetaClassifier(\n    estimator=ExampleClassifier().set_fit_request(sample_weight=True)\n)\nmeta_est.fit(X, y, sample_weight=my_weights)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Note that the above example calls our utility function\n`check_metadata()` via the `ExampleClassifier`, which checks that\n``sample_weight`` is correctly passed to it. If it is not, as in the\nfollowing example, it prints that ``sample_weight`` is ``None``:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "meta_est.fit(X, y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If we pass an unknown metadata, an error is raised:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "try:\n    meta_est.fit(X, y, test=my_weights)\nexcept TypeError as e:\n    print(e)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "And if we pass a metadata which is not explicitly requested:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "try:\n    meta_est.fit(X, y, sample_weight=my_weights).predict(X, groups=my_groups)\nexcept ValueError as e:\n    print(e)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Also, if we explicitly set it as not requested, but it is provided:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "meta_est = MetaClassifier(\n    estimator=ExampleClassifier()\n    .set_fit_request(sample_weight=True)\n    .set_predict_request(groups=False)\n)\ntry:\n    meta_est.fit(X, y, sample_weight=my_weights).predict(X[:3, :], groups=my_groups)\nexcept TypeError as e:\n    print(e)"
      ]
    },
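    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Built-in routers perform the same validation. The sketch below, which\nassumes a scikit-learn version in which :func:`~model_selection.cross_validate`\naccepts ``params``, passes ``sample_weight`` to a\n:class:`~linear_model.LogisticRegression` whose request was never set. The\nresulting ``UnsetMetadataPassedError`` is a subclass of ``ValueError``, which\nis what we catch. The ``*_err`` names are illustrative.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\nfrom sklearn import set_config\nfrom sklearn.linear_model import LogisticRegression\nfrom sklearn.model_selection import cross_validate\n\nset_config(enable_metadata_routing=True)\n\nrng_err = np.random.RandomState(0)\nX_err = rng_err.rand(40, 3)\ny_err = rng_err.randint(0, 2, size=40)\nw_err = rng_err.rand(40)\n\n# No `set_fit_request` call: the request for `sample_weight` stays unset.\nclf_err = LogisticRegression()\nmessage_err = None\ntry:\n    cross_validate(clf_err, X_err, y_err, params={\"sample_weight\": w_err})\nexcept ValueError as e:\n    # Raised while routing the metadata, before any model is fitted.\n    message_err = str(e)\nprint(message_err)"
      ]
    },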
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Another concept to introduce is **aliased metadata**. This is when an\nestimator requests a metadata with a different variable name than the default\nvariable name. For instance, in a setting where there are two estimators in a\npipeline, one could request ``sample_weight1`` and the other\n``sample_weight2``. Note that this doesn't change what the estimator expects;\nit only tells the meta-estimator how to map the provided metadata to what is\nrequired. Here's an example in which we pass ``aliased_sample_weight`` to the\nmeta-estimator, but the meta-estimator understands that\n``aliased_sample_weight`` is an alias for ``sample_weight``, and passes it as\n``sample_weight`` to the underlying estimator:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "meta_est = MetaClassifier(\n    estimator=ExampleClassifier().set_fit_request(sample_weight=\"aliased_sample_weight\")\n)\nmeta_est.fit(X, y, aliased_sample_weight=my_weights)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Passing ``sample_weight`` here will fail since it is requested with an\nalias and ``sample_weight`` with that name is not requested:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "try:\n    meta_est.fit(X, y, sample_weight=my_weights)\nexcept TypeError as e:\n    print(e)"
      ]
    },
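    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Aliases are most useful when two consumers inside the same router need\ndifferent values for a metadata with the same name. The sketch below, assuming\na scikit-learn version in which :func:`~model_selection.cross_validate` accepts\n``params``, sends one set of weights to the ``fit`` of\n:class:`~linear_model.LogisticRegression` and another set to an accuracy\nscorer. The alias names ``fit_weights`` and ``score_weights``, as well as the\n``*_alias`` variables, are arbitrary choices for this illustration.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\nfrom sklearn import set_config\nfrom sklearn.linear_model import LogisticRegression\nfrom sklearn.metrics import get_scorer\nfrom sklearn.model_selection import cross_validate\n\nset_config(enable_metadata_routing=True)\n\nrng_alias = np.random.RandomState(0)\nX_alias = rng_alias.rand(60, 3)\ny_alias = rng_alias.randint(0, 2, size=60)\nfit_weights_alias = rng_alias.rand(60)\nscore_weights_alias = rng_alias.rand(60)\n\n# Both consumers expect a parameter named `sample_weight`, but each one asks\n# the router for it under a different (arbitrary) alias.\nclf_alias = LogisticRegression().set_fit_request(sample_weight=\"fit_weights\")\nscorer_alias = get_scorer(\"accuracy\").set_score_request(\n    sample_weight=\"score_weights\"\n)\n\n# The router passes `fit_weights` to `fit` and `score_weights` to the scorer,\n# each time under the name `sample_weight`.\nresults_alias = cross_validate(\n    clf_alias,\n    X_alias,\n    y_alias,\n    scoring=scorer_alias,\n    cv=3,\n    params={\"fit_weights\": fit_weights_alias, \"score_weights\": score_weights_alias},\n)\nprint(results_alias[\"test_score\"])"
      ]
    },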
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This leads us to ``get_metadata_routing``. The way routing works in\nscikit-learn is that consumers request what they need, and routers pass that\nalong. Additionally, a router exposes what it requires itself so that it can\nbe used inside another router, e.g. a pipeline inside a grid search object.\nThe output of ``get_metadata_routing``, a\n:class:`~utils.metadata_routing.MetadataRouter` whose dictionary\nrepresentation is printed below, includes the complete tree of requested\nmetadata by all nested objects and their corresponding method routings, i.e.\nwhich method of a sub-estimator is used in which method of a meta-estimator:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print_routing(meta_est)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As you can see, the only metadata requested for method ``fit`` is\n``\"sample_weight\"`` with ``\"aliased_sample_weight\"`` as the alias. The\n:class:`~utils.metadata_routing.MetadataRouter` class enables us to easily\ncreate the routing object that our ``get_metadata_routing`` needs to return.\n\nIn order to understand how aliases work in meta-estimators, imagine our\nmeta-estimator inside another one:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "meta_meta_est = MetaClassifier(estimator=meta_est).fit(\n    X, y, aliased_sample_weight=my_weights\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In the above example, this is how the ``fit`` method of `meta_meta_est`\ncalls its sub-estimators' ``fit`` methods::\n\n    # user feeds `my_weights` as `aliased_sample_weight` into `meta_meta_est`:\n    meta_meta_est.fit(X, y, aliased_sample_weight=my_weights):\n        ...\n\n        # the first sub-estimator (`meta_est`) expects `aliased_sample_weight`\n        self.estimator_.fit(X, y, aliased_sample_weight=aliased_sample_weight):\n            ...\n\n            # the second sub-estimator (`est`) expects `sample_weight`\n            self.estimator_.fit(X, y, sample_weight=aliased_sample_weight):\n                ...\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Consuming and routing Meta-Estimator\nFor a slightly more complex example, consider a meta-estimator that routes\nmetadata to an underlying estimator as before, but it also uses some metadata\nin its own methods. This meta-estimator is a consumer and a router at the\nsame time. Implementing one is very similar to what we had before, but with a\nfew tweaks.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class RouterConsumerClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):\n    def __init__(self, estimator):\n        self.estimator = estimator\n\n    def get_metadata_routing(self):\n        router = (\n            MetadataRouter(owner=self)\n            # defining metadata routing request values for usage in the meta-estimator\n            .add_self_request(self)\n            # defining metadata routing request values for usage in the sub-estimator\n            .add(\n                estimator=self.estimator,\n                method_mapping=MethodMapping()\n                .add(caller=\"fit\", callee=\"fit\")\n                .add(caller=\"predict\", callee=\"predict\")\n                .add(caller=\"score\", callee=\"score\"),\n            )\n        )\n        return router\n\n    # Since `sample_weight` is used and consumed here, it should be defined as\n    # an explicit argument in the method's signature. All other metadata, which\n    # are only routed, are passed via `**fit_params`:\n    def fit(self, X, y, sample_weight=None, **fit_params):\n        if self.estimator is None:\n            raise ValueError(\"estimator cannot be None!\")\n\n        check_metadata(self, sample_weight=sample_weight)\n\n        # We add `sample_weight` to the `fit_params` dictionary so that it is\n        # also routed to the sub-estimator, according to its own request.\n        if sample_weight is not None:\n            fit_params[\"sample_weight\"] = sample_weight\n\n        request_router = get_routing_for_object(self)\n        request_router.validate_metadata(params=fit_params, method=\"fit\")\n        routed_params = request_router.route_params(params=fit_params, caller=\"fit\")\n        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)\n        self.classes_ = self.estimator_.classes_\n        return self\n\n    def predict(self, X, **predict_params):\n        check_is_fitted(self)\n        # As in `fit`, we get a copy of the object's MetadataRouter,\n        request_router = get_routing_for_object(self)\n        # then we validate the given metadata,\n        request_router.validate_metadata(params=predict_params, method=\"predict\")\n        # and then prepare the input to the underlying `predict` method.\n        routed_params = request_router.route_params(\n            params=predict_params, caller=\"predict\"\n        )\n        return self.estimator_.predict(X, **routed_params.estimator.predict)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The key parts where the above meta-estimator differs from our previous\nmeta-estimator are accepting ``sample_weight`` explicitly in ``fit`` and\nincluding it in ``fit_params``. Since ``sample_weight`` is an explicit\nargument, we can be sure that ``set_fit_request(sample_weight=...)`` is\npresent for this method. The meta-estimator is both a consumer and a\nrouter of ``sample_weight``.\n\nIn ``get_metadata_routing``, we add ``self`` to the routing using\n``add_self_request`` to indicate this estimator is consuming\n``sample_weight`` as well as being a router; this also adds a\n``$self_request`` key to the routing info as illustrated below. Now let's\nlook at some examples:\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "- No metadata requested\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "meta_est = RouterConsumerClassifier(estimator=ExampleClassifier())\nprint_routing(meta_est)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "- ``sample_weight`` requested by sub-estimator\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "meta_est = RouterConsumerClassifier(\n    estimator=ExampleClassifier().set_fit_request(sample_weight=True)\n)\nprint_routing(meta_est)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "- ``sample_weight`` requested by meta-estimator\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "meta_est = RouterConsumerClassifier(estimator=ExampleClassifier()).set_fit_request(\n    sample_weight=True\n)\nprint_routing(meta_est)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Note the difference in the requested metadata representations above.\n\n- We can also alias the metadata to pass different values to the fit methods\n  of the meta- and the sub-estimator:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "meta_est = RouterConsumerClassifier(\n    estimator=ExampleClassifier().set_fit_request(sample_weight=\"clf_sample_weight\"),\n).set_fit_request(sample_weight=\"meta_clf_sample_weight\")\nprint_routing(meta_est)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "However, ``fit`` of the meta-estimator only needs the alias for the\nsub-estimator and addresses its own sample weight as `sample_weight`, since\nit doesn't validate and route its own required metadata:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "meta_est.fit(X, y, sample_weight=my_weights, clf_sample_weight=my_other_weights)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "- Alias only on the sub-estimator:\n\nThis is useful when we don't want the meta-estimator to use the metadata, but\nthe sub-estimator should.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "meta_est = RouterConsumerClassifier(\n    estimator=ExampleClassifier().set_fit_request(sample_weight=\"aliased_sample_weight\")\n)\nprint_routing(meta_est)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The meta-estimator cannot use `aliased_sample_weight`, because it expects\nit passed as `sample_weight`. This would apply even if\n`set_fit_request(sample_weight=True)` were set on it.\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Simple Pipeline\nA slightly more complicated use-case is a meta-estimator resembling a\n:class:`~pipeline.Pipeline`. Here is a meta-estimator that accepts a\ntransformer and a classifier. When calling its `fit` method, it applies the\ntransformer's `fit` and `transform` before running the classifier on the\ntransformed data. Upon `predict`, it applies the transformer's `transform`\nbefore predicting with the classifier's `predict` method on the transformed\nnew data.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class SimplePipeline(ClassifierMixin, BaseEstimator):\n    def __init__(self, transformer, classifier):\n        self.transformer = transformer\n        self.classifier = classifier\n\n    def get_metadata_routing(self):\n        router = (\n            MetadataRouter(owner=self)\n            # We add the routing for the transformer.\n            .add(\n                transformer=self.transformer,\n                method_mapping=MethodMapping()\n                # The metadata is routed such that it retraces how\n                # `SimplePipeline` internally calls the transformer's `fit` and\n                # `transform` methods in its own methods (`fit` and `predict`).\n                .add(caller=\"fit\", callee=\"fit\")\n                .add(caller=\"fit\", callee=\"transform\")\n                .add(caller=\"predict\", callee=\"transform\"),\n            )\n            # We add the routing for the classifier.\n            .add(\n                classifier=self.classifier,\n                method_mapping=MethodMapping()\n                .add(caller=\"fit\", callee=\"fit\")\n                .add(caller=\"predict\", callee=\"predict\"),\n            )\n        )\n        return router\n\n    def fit(self, X, y, **fit_params):\n        routed_params = process_routing(self, \"fit\", **fit_params)\n\n        self.transformer_ = clone(self.transformer).fit(\n            X, y, **routed_params.transformer.fit\n        )\n        X_transformed = self.transformer_.transform(\n            X, **routed_params.transformer.transform\n        )\n\n        self.classifier_ = clone(self.classifier).fit(\n            X_transformed, y, **routed_params.classifier.fit\n        )\n        return self\n\n    def predict(self, X, **predict_params):\n        routed_params = process_routing(self, \"predict\", **predict_params)\n\n        X_transformed = self.transformer_.transform(\n            X, **routed_params.transformer.transform\n        )\n        return self.classifier_.predict(\n            X_transformed, **routed_params.classifier.predict\n        )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Note the usage of :class:`~utils.metadata_routing.MethodMapping` to\ndeclare which methods of the child estimator (callee) are used in which\nmethods of the meta-estimator (caller). As you can see, `SimplePipeline` uses\nthe transformer's ``transform`` and ``fit`` methods in ``fit``, and its\n``transform`` method in ``predict``, and that's what you see implemented in\nthe routing structure of the pipeline class.\n\nAnother difference from the previous examples is the usage of\n:func:`~utils.metadata_routing.process_routing`, which processes the input\nparameters, does the required validation, and returns the `routed_params`\nwhich we had created in previous examples. This reduces the boilerplate code\na developer needs to write in each meta-estimator's method. Developers are\nstrongly encouraged to use this function unless there is a good reason\nagainst it.\n\nIn order to test the above pipeline, let's add an example transformer.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class ExampleTransformer(TransformerMixin, BaseEstimator):\n    def fit(self, X, y, sample_weight=None):\n        check_metadata(self, sample_weight=sample_weight)\n        return self\n\n    def transform(self, X, groups=None):\n        check_metadata(self, groups=groups)\n        return X\n\n    def fit_transform(self, X, y, sample_weight=None, groups=None):\n        return self.fit(X, y, sample_weight).transform(X, groups)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Note that in the above example, we have implemented ``fit_transform`` which\ncalls ``fit`` and ``transform`` with the appropriate metadata. This is only\nrequired if ``transform`` accepts metadata, since the default ``fit_transform``\nimplementation in :class:`~base.TransformerMixin` doesn't pass metadata to\n``transform``.\n\nNow we can test our pipeline, and see if metadata is correctly passed around.\nThis example uses our `SimplePipeline`, our `ExampleTransformer`, and our\n`RouterConsumerClassifier` which uses our `ExampleClassifier`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "pipe = SimplePipeline(\n    transformer=ExampleTransformer()\n    # we set transformer's fit to receive sample_weight\n    .set_fit_request(sample_weight=True)\n    # we set transformer's transform to receive groups\n    .set_transform_request(groups=True),\n    classifier=RouterConsumerClassifier(\n        estimator=ExampleClassifier()\n        # we want this sub-estimator to receive sample_weight in fit\n        .set_fit_request(sample_weight=True)\n        # but not groups in predict\n        .set_predict_request(groups=False),\n    )\n    # and we want the meta-estimator to receive sample_weight as well\n    .set_fit_request(sample_weight=True),\n)\npipe.fit(X, y, sample_weight=my_weights, groups=my_groups).predict(\n    X[:3], groups=my_groups\n)"
      ]
    },
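    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For comparison, scikit-learn's own :class:`~pipeline.Pipeline` relies on the\nsame request-and-route protocol. In the sketch below, which assumes a\nscikit-learn version whose ``Pipeline`` supports metadata routing, both\n:class:`~preprocessing.StandardScaler` and\n:class:`~linear_model.LogisticRegression` request ``sample_weight`` in\n``fit``. If the weights reach the scaler, its ``mean_`` should equal the\nweighted column mean. The ``*_pipe`` names are illustrative.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n\nfrom sklearn import set_config\nfrom sklearn.linear_model import LogisticRegression\nfrom sklearn.pipeline import make_pipeline\nfrom sklearn.preprocessing import StandardScaler\n\nset_config(enable_metadata_routing=True)\n\nrng_pipe = np.random.RandomState(0)\nX_pipe = rng_pipe.rand(50, 3)\ny_pipe = rng_pipe.randint(0, 2, size=50)\nw_pipe = rng_pipe.rand(50)\n\n# Both steps request `sample_weight` in `fit`; the pipeline routes it to each.\nreal_pipe = make_pipeline(\n    StandardScaler().set_fit_request(sample_weight=True),\n    LogisticRegression().set_fit_request(sample_weight=True),\n)\nreal_pipe.fit(X_pipe, y_pipe, sample_weight=w_pipe)\n\n# The scaler received the weights, so its `mean_` is the weighted column mean.\nweighted_mean_pipe = np.average(X_pipe, axis=0, weights=w_pipe)\nprint(np.allclose(real_pipe[0].mean_, weighted_mean_pipe))"
      ]
    },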
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Deprecation / Default Value Change\nIn this section we show how to handle the case where a router also becomes a\nconsumer, especially when it consumes the same metadata as its sub-estimator,\nor where a consumer starts consuming metadata that it did not consume in an\nolder release. In this case, a warning should be raised for a while, to let\nusers know that the behavior has changed from previous versions.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):\n    def __init__(self, estimator):\n        self.estimator = estimator\n\n    def fit(self, X, y, **fit_params):\n        routed_params = process_routing(self, \"fit\", **fit_params)\n        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)\n        return self\n\n    def get_metadata_routing(self):\n        router = MetadataRouter(owner=self).add(\n            estimator=self.estimator,\n            method_mapping=MethodMapping().add(caller=\"fit\", callee=\"fit\"),\n        )\n        return router"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As explained above, this is a valid usage if `my_weights` aren't supposed\nto be passed as `sample_weight` to `MetaRegressor`:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "reg = MetaRegressor(estimator=LinearRegression().set_fit_request(sample_weight=True))\nreg.fit(X, y, sample_weight=my_weights)"
      ]
    },
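    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough mental model of what happens in the cell above, consider the hypothetical helper ``route_to_children`` below. It is a simplified sketch, not scikit-learn's routing code: each passed metadata goes only to the children that requested it, and in this toy, metadata requested by nobody raises an error. ``MetaRegressor`` itself requests nothing, so ``my_weights`` reaches only the wrapped estimator.\n\n```python\ndef route_to_children(params, requests):\n    # requests maps each child name to the set of metadata names it requested.\n    routed = {child: {} for child in requests}\n    for key, value in params.items():\n        targets = [child for child, wanted in requests.items() if key in wanted]\n        if not targets:\n            # Simplification: metadata that no child requested is rejected.\n            raise TypeError(f'{key} is not requested by any child')\n        for child in targets:\n            routed[child][key] = value\n    return routed\n\n\n# The router does not consume sample_weight; only its sub-estimator does.\nprint(route_to_children({'sample_weight': [0.5, 1.0]}, {'estimator': {'sample_weight'}}))\n# {'estimator': {'sample_weight': [0.5, 1.0]}}\n```\n"
      ]
    },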
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now imagine we further develop ``MetaRegressor`` and it now also *consumes*\n``sample_weight``:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class WeightedMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):\n    # show warning to remind user to explicitly set the value with\n    # `.set_{method}_request(sample_weight={boolean})`\n    __metadata_request__fit = {\"sample_weight\": metadata_routing.WARN}\n\n    def __init__(self, estimator):\n        self.estimator = estimator\n\n    def fit(self, X, y, sample_weight=None, **fit_params):\n        routed_params = process_routing(\n            self, \"fit\", sample_weight=sample_weight, **fit_params\n        )\n        check_metadata(self, sample_weight=sample_weight)\n        self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit)\n        return self\n\n    def get_metadata_routing(self):\n        router = (\n            MetadataRouter(owner=self)\n            .add_self_request(self)\n            .add(\n                estimator=self.estimator,\n                method_mapping=MethodMapping().add(caller=\"fit\", callee=\"fit\"),\n            )\n        )\n        return router"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The above implementation is almost the same as ``MetaRegressor``, except that\nit now also consumes ``sample_weight`` itself. Because of the default request\nvalue defined in ``__metadata_request__fit``, a warning is raised when it is\nfitted with ``sample_weight``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "with warnings.catch_warnings(record=True) as record:\n    WeightedMetaRegressor(\n        estimator=LinearRegression().set_fit_request(sample_weight=False)\n    ).fit(X, y, sample_weight=my_weights)\nfor w in record:\n    print(w.message)"
      ]
    },
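    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The decision behind such a warning can be sketched in plain Python. This is a simplified model written for this example, not scikit-learn's implementation: ``WARN`` is a hypothetical stand-in for ``metadata_routing.WARN``, string aliases are not modelled, and we assume the transitional behavior is to warn and leave the metadata out, so that results match the previous release until the user sets the request explicitly.\n\n```python\nimport warnings\n\nWARN = '$WARN'  # hypothetical sentinel standing in for metadata_routing.WARN\n\n\ndef resolve_request(name, request, params):\n    # Return the metadata a consumer receives for one request value.\n    if name not in params:\n        return {}\n    if request is True:\n        return {name: params[name]}\n    if request is False:\n        return {}\n    if request == WARN:\n        # Transitional default: ignore the value but tell the user.\n        warnings.warn(\n            f'{name} is now supported here but ignored for now; '\n            'set the request to True or False explicitly.',\n            UserWarning,\n        )\n        return {}\n    # An unset request (None) with metadata passed is treated as an error.\n    raise ValueError(f'{name} was passed but its request is not set')\n\n\nwith warnings.catch_warnings(record=True) as record:\n    warnings.simplefilter('always')\n    routed = resolve_request('sample_weight', WARN, {'sample_weight': [1, 2]})\nprint(routed, len(record))  # {} 1\n```\n"
      ]
    },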
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "When an estimator starts consuming metadata that it did not consume before,\nthe following pattern can be used to warn users about it.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class ExampleRegressor(RegressorMixin, BaseEstimator):\n    __metadata_request__fit = {\"sample_weight\": metadata_routing.WARN}\n\n    def fit(self, X, y, sample_weight=None):\n        check_metadata(self, sample_weight=sample_weight)\n        return self\n\n    def predict(self, X):\n        return np.zeros(shape=(len(X)))\n\n\nwith warnings.catch_warnings(record=True) as record:\n    MetaRegressor(estimator=ExampleRegressor()).fit(X, y, sample_weight=my_weights)\nfor w in record:\n    print(w.message)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Finally, we disable the configuration flag for metadata routing:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "set_config(enable_metadata_routing=False)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Third Party Development and scikit-learn Dependency\n\nAs seen above, information is communicated between classes using\n:class:`~utils.metadata_routing.MetadataRequest` and\n:class:`~utils.metadata_routing.MetadataRouter`. Although it is strongly\ndiscouraged, it is possible to vendor the metadata-routing tools if you\nstrictly want a scikit-learn compatible estimator without depending on the\nscikit-learn package. If all of the following conditions are met, you do NOT\nneed to modify your code at all:\n\n- your estimator inherits from :class:`~base.BaseEstimator`\n- the parameters consumed by your estimator's methods, e.g. ``fit``, are\n  explicitly defined in the method's signature, as opposed to being\n  ``*args`` or ``**kwargs``.\n- your estimator does not route any metadata to the underlying objects, i.e.\n  it's not a *router*.\n\n"
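        ,"The condition on explicit signatures matters because, as far as we understand, scikit-learn derives the ``set_{method}_request`` methods by inspecting each method's signature. The sketch below only mimics that idea with ``inspect.signature``; the helper ``explicit_metadata_params`` and both classes are hypothetical. Parameters hidden behind ``*args`` or ``**kwargs`` are invisible to this kind of introspection.\n\n```python\nimport inspect\n\n\ndef explicit_metadata_params(method):\n    # List explicitly declared parameters, skipping self, X, y and\n    # any *args / **kwargs catch-alls.\n    names = []\n    for name, param in inspect.signature(method).parameters.items():\n        if name in ('self', 'X', 'y'):\n            continue\n        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):\n            continue\n        names.append(name)\n    return names\n\n\nclass ExplicitEstimator:\n    def fit(self, X, y, sample_weight=None):\n        return self\n\n\nclass CatchAllEstimator:\n    def fit(self, X, y, **kwargs):\n        return self\n\n\nprint(explicit_metadata_params(ExplicitEstimator.fit))  # ['sample_weight']\nprint(explicit_metadata_params(CatchAllEstimator.fit))  # []\n```\n"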
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.14"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}