-
Notifications
You must be signed in to change notification settings - Fork 409
ENH NeuralNet metadata routing #1139
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
BenjaminBossan
merged 8 commits into
skorch-dev:master
from
adrinjalali:nn-routing-router
Apr 27, 2026
+399
−8
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
76054e6
ENH implement metadata routing for NeuralNet
adrinjalali 3195acc
cleanup
adrinjalali f09a794
simplify
adrinjalali db5232a
cleanup
adrinjalali 920145f
end to end test
adrinjalali e781e02
changelog
adrinjalali 57c6d04
apply Benjamin's review
adrinjalali e34b544
cleanup import
adrinjalali File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,13 @@ | |
|
|
||
| import numpy as np | ||
| from sklearn.base import BaseEstimator | ||
| from sklearn.utils.metadata_routing import ( | ||
| MetadataRouter, | ||
| MethodMapping, | ||
| UNUSED, | ||
| _routing_enabled, | ||
| process_routing, | ||
| ) | ||
| import torch | ||
| from torch.utils.data import DataLoader | ||
|
|
||
|
|
@@ -309,6 +316,13 @@ class NeuralNet(BaseEstimator): | |
| this list. | ||
|
|
||
| """ | ||
| # Suppress auto-generation of set_partial_fit_request that only | ||
| # allows 'classes'. We provide our own that accepts arbitrary | ||
| # metadata names, since partial_fit takes **fit_params. | ||
| # TODO: remove once scikit-learn/scikit-learn#32111 is merged and | ||
| # provides a public API for this. | ||
| __metadata_request__partial_fit = {"classes": UNUSED} | ||
|
|
||
| prefixes_ = ['iterator_train', 'iterator_valid', 'callbacks', 'dataset', 'compile'] | ||
|
|
||
| cuda_dependent_attributes_ = [] | ||
|
|
@@ -1150,7 +1164,8 @@ def evaluation_step(self, batch, training=False): | |
| self._set_training(training) | ||
| return self.infer(Xi) | ||
|
|
||
| def fit_loop(self, X, y=None, epochs=None, **fit_params): | ||
| def fit_loop(self, X, y=None, epochs=None, *, _routing_method="fit", | ||
| **fit_params): | ||
| """The proper fit loop. | ||
|
|
||
| Contains the logic of what actually happens during the fit | ||
|
|
@@ -1190,8 +1205,26 @@ def fit_loop(self, X, y=None, epochs=None, **fit_params): | |
| self.check_training_readiness() | ||
| epochs = epochs if epochs is not None else self.max_epochs | ||
|
|
||
| if _routing_enabled(): | ||
| # _routing_method matches the public entry point so the | ||
| # right self-request (fit vs partial_fit) is resolved. | ||
| routed_params = process_routing( | ||
| self, _routing_method, **fit_params) | ||
| split_params = routed_params.get( | ||
| "splitter", {"split": {}} | ||
| )["split"] | ||
| # Following sklearn's router+consumer pattern: the router | ||
| # uses its own params directly (they're already in | ||
| # fit_params), while children get theirs from | ||
| # routed_params. The module's forward method should accept | ||
| # **kwargs to handle any extra params. | ||
| forward_params = fit_params | ||
| else: | ||
| split_params = fit_params | ||
| forward_params = fit_params | ||
|
|
||
| dataset_train, dataset_valid = self.get_split_datasets( | ||
| X, y, **fit_params) | ||
| X, y, **split_params) | ||
| on_epoch_kwargs = { | ||
| 'dataset_train': dataset_train, | ||
| 'dataset_valid': dataset_valid, | ||
|
|
@@ -1205,10 +1238,10 @@ def fit_loop(self, X, y=None, epochs=None, **fit_params): | |
| self.notify('on_epoch_begin', **on_epoch_kwargs) | ||
|
|
||
| self.run_single_epoch(iterator_train, training=True, prefix="train", | ||
| step_fn=self.train_step, **fit_params) | ||
| step_fn=self.train_step, **forward_params) | ||
|
|
||
| self.run_single_epoch(iterator_valid, training=False, prefix="valid", | ||
| step_fn=self.validation_step, **fit_params) | ||
| step_fn=self.validation_step, **forward_params) | ||
|
|
||
| self.notify("on_epoch_end", **on_epoch_kwargs) | ||
| return self | ||
|
|
@@ -1290,9 +1323,15 @@ def partial_fit(self, X, y=None, classes=None, **fit_params): | |
| if not self.initialized_: | ||
| self.initialize() | ||
|
|
||
| # When called from fit(), _routing_method is threaded in via | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, too bad that this is needed, but it is what it is. |
||
| # fit_params so partial_fit's public signature stays clean | ||
| # (sklearn introspects it to auto-generate routing machinery). | ||
| routing_method = fit_params.pop("_routing_method", "partial_fit") | ||
|
|
||
| self.notify('on_train_begin', X=X, y=y) | ||
| try: | ||
| self.fit_loop(X, y, **fit_params) | ||
| self.fit_loop( | ||
| X, y, _routing_method=routing_method, **fit_params) | ||
| except KeyboardInterrupt: | ||
| pass | ||
| self.notify('on_train_end', X=X, y=y) | ||
|
|
@@ -1333,9 +1372,105 @@ def fit(self, X, y=None, **fit_params): | |
| if not self.warm_start or not self.initialized_: | ||
| self.initialize() | ||
|
|
||
| self.partial_fit(X, y, **fit_params) | ||
| self.partial_fit(X, y, _routing_method="fit", **fit_params) | ||
| return self | ||
|
|
||
| def set_fit_request(self, **kwargs): | ||
| """Set requested parameters by the ``fit`` method. | ||
|
|
||
| Please see :ref:`sklearn:metadata_routing` for more details. | ||
|
|
||
| Since ``NeuralNet.fit`` accepts arbitrary ``**fit_params`` that | ||
| are passed to the module's forward method, metadata names cannot | ||
| be inferred from the signature and must be declared explicitly | ||
| using this method. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| **kwargs : dict | ||
| Arguments should be of the form ``param_name=alias``, where | ||
| ``alias`` can be one of ``{True, False, None, str}``. | ||
|
|
||
| Returns | ||
| ------- | ||
| self : object | ||
| The updated object. | ||
| """ | ||
| if not _routing_enabled(): | ||
| raise RuntimeError( | ||
| "This method is only available when metadata routing is" | ||
| " enabled. You can enable it using" | ||
| " sklearn.set_config(enable_metadata_routing=True)." | ||
| ) | ||
|
|
||
| requests = self._get_metadata_request() | ||
| for param, alias in kwargs.items(): | ||
| requests.fit.add_request(param=param, alias=alias) | ||
| self._metadata_request = requests | ||
| return self | ||
|
|
||
| def set_partial_fit_request(self, **kwargs): | ||
| """Set requested parameters by the ``partial_fit`` method. | ||
|
|
||
| Please see :ref:`sklearn:metadata_routing` for more details. | ||
|
|
||
| Since ``NeuralNet.partial_fit`` accepts arbitrary ``**fit_params`` | ||
| that are passed to the module's forward method, metadata names | ||
| cannot be inferred from the signature and must be declared | ||
| explicitly using this method. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| **kwargs : dict | ||
| Arguments should be of the form ``param_name=alias``, where | ||
| ``alias`` can be one of ``{True, False, None, str}``. | ||
|
|
||
| Returns | ||
| ------- | ||
| self : object | ||
| The updated object. | ||
| """ | ||
| if not _routing_enabled(): | ||
| raise RuntimeError( | ||
| "This method is only available when metadata routing is" | ||
| " enabled. You can enable it using" | ||
| " sklearn.set_config(enable_metadata_routing=True)." | ||
| ) | ||
|
|
||
| requests = self._get_metadata_request() | ||
| for param, alias in kwargs.items(): | ||
| requests.partial_fit.add_request(param=param, alias=alias) | ||
| self._metadata_request = requests | ||
| return self | ||
|
|
||
| def get_metadata_routing(self): | ||
| """Get metadata routing of this object. | ||
|
|
||
| NeuralNet is both a consumer (its module's forward method | ||
| accepts arbitrary metadata) and a router (it routes metadata | ||
| like ``groups`` to its internal CV splitter). | ||
|
|
||
|
BenjaminBossan marked this conversation as resolved.
|
||
| Returns | ||
| ------- | ||
| routing : MetadataRouter | ||
| A :class:`~sklearn.utils.metadata_routing.MetadataRouter` | ||
| encapsulating routing information. | ||
| """ | ||
| router = MetadataRouter(owner=self.__class__.__name__) | ||
| router.add_self_request(self) | ||
|
|
||
| ts = self.train_split | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice |
||
| if ts is not None and hasattr(ts, 'cv'): | ||
| router.add( | ||
| splitter=ts.cv, | ||
| method_mapping=( | ||
| MethodMapping() | ||
| .add(caller="fit", callee="split") | ||
| .add(caller="partial_fit", callee="split") | ||
| ), | ||
| ) | ||
| return router | ||
|
|
||
| def check_is_fitted(self, attributes=None, *args, **kwargs): | ||
| """Checks whether the net is initialized | ||
|
|
||
|
|
@@ -1969,7 +2104,12 @@ def get_params_for_optimizer(self, prefix, named_parameters): | |
| return args, kwargs | ||
|
|
||
| def _get_param_names(self): | ||
| return [k for k in self.__dict__ if not k.endswith('_')] | ||
| # Exclude _metadata_request: it is set by set_fit_request() | ||
| # (not __init__), and sklearn's clone() handles it separately | ||
| # via deepcopy. Including it here would cause clone() to pass | ||
| # it to the constructor, which doesn't expect it. | ||
| return [k for k in self.__dict__ | ||
| if not k.endswith('_') and k != '_metadata_request'] | ||
|
|
||
| def _get_params_callbacks(self, deep=True): | ||
| """sklearn's .get_params checks for `hasattr(value, | ||
|
|
||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wasn't sure how far back these imports reach, but I checked with sklearn 1.4.0 and the tests passed.