From 76054e6fec3995038a4339f756dc0a17f02d7f00 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Mon, 20 Apr 2026 14:36:39 +0200 Subject: [PATCH 1/8] ENH implement metadata routing for NeuralNet --- .gitignore | 1 + skorch/net.py | 202 ++++++++++++++++++++++++- skorch/tests/test_net.py | 312 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 511 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index c45ff9f1d..730964f96 100644 --- a/.gitignore +++ b/.gitignore @@ -26,3 +26,4 @@ data/ *.pickle *.mat *.ipynb +CLAUDE.md diff --git a/skorch/net.py b/skorch/net.py index 1d5b3b34c..23dde27e5 100644 --- a/skorch/net.py +++ b/skorch/net.py @@ -19,6 +19,13 @@ import numpy as np from sklearn.base import BaseEstimator +from sklearn.utils._metadata_requests import UNUSED +from sklearn.utils.metadata_routing import ( + MetadataRouter, + MethodMapping, + _routing_enabled, + process_routing, +) import torch from torch.utils.data import DataLoader @@ -309,6 +316,13 @@ class NeuralNet(BaseEstimator): this list. """ + # Suppress auto-generation of set_partial_fit_request that only + # allows 'classes'. We provide our own that accepts arbitrary + # metadata names, since partial_fit takes **fit_params. + # TODO: remove once scikit-learn/scikit-learn#32111 is merged and + # provides a public API for this. + __metadata_request__partial_fit = {"classes": UNUSED} + prefixes_ = ['iterator_train', 'iterator_valid', 'callbacks', 'dataset', 'compile'] cuda_dependent_attributes_ = [] @@ -1190,8 +1204,26 @@ def fit_loop(self, X, y=None, epochs=None, **fit_params): self.check_training_readiness() epochs = epochs if epochs is not None else self.max_epochs + if _routing_enabled(): + routed_params = process_routing(self, "fit", **fit_params) + if fit_params: + split_params = routed_params.get( + "splitter", {"split": {}} + )["split"] + else: + split_params = {} + # Following sklearn's router+consumer pattern: the router + # uses its own params directly (they're already in + # fit_params), while children get theirs from + # routed_params. The module's forward method should accept + # **kwargs to handle any extra params. + forward_params = fit_params + else: + split_params = fit_params + forward_params = fit_params + dataset_train, dataset_valid = self.get_split_datasets( - X, y, **fit_params) + X, y, **split_params) on_epoch_kwargs = { 'dataset_train': dataset_train, 'dataset_valid': dataset_valid, @@ -1205,10 +1237,10 @@ def fit_loop(self, X, y=None, epochs=None, **fit_params): self.notify('on_epoch_begin', **on_epoch_kwargs) self.run_single_epoch(iterator_train, training=True, prefix="train", - step_fn=self.train_step, **fit_params) + step_fn=self.train_step, **forward_params) self.run_single_epoch(iterator_valid, training=False, prefix="valid", - step_fn=self.validation_step, **fit_params) + step_fn=self.validation_step, **forward_params) self.notify("on_epoch_end", **on_epoch_kwargs) return self @@ -1336,6 +1368,145 @@ def fit(self, X, y=None, **fit_params): self.partial_fit(X, y, **fit_params) return self + def _get_metadata_request(self): + """Get metadata request, using class name as owner. + + sklearn's routing infrastructure calls ``deepcopy`` on + ``MetadataRequest`` objects (via ``get_routing_for_object`` + and ``add_self_request``). The default implementation stores + ``owner=self`` (the instance), so ``deepcopy`` follows that + reference and tries to deep-copy the entire NeuralNet. + This fails because ``NeuralNet.__getstate__`` uses + ``pickle.dump`` for CUDA-dependent attributes, and pickle + cannot handle non-picklable modules (e.g. locally-defined + classes). + + Since ``owner`` is only used for error messages, replacing it + with the class name string is safe and avoids the issue. An + alternative would be implementing ``__deepcopy__`` on + NeuralNet. + """ + request = super()._get_metadata_request() + owner = self.__class__.__name__ + request.owner = owner + for attr_name in list(vars(request)): + attr = getattr(request, attr_name) + if hasattr(attr, 'owner'): + attr.owner = owner + return request + + def set_fit_request(self, **kwargs): + """Set requested parameters by the ``fit`` method. + + Please see :ref:`sklearn:metadata_routing` for more details. + + Since ``NeuralNet.fit`` accepts arbitrary ``**fit_params`` that + are passed to the module's forward method, metadata names cannot + be inferred from the signature and must be declared explicitly + using this method. + + Parameters + ---------- + **kwargs : dict + Arguments should be of the form ``param_name=alias``, where + ``alias`` can be one of ``{True, False, None, str}``. + + Returns + ------- + self : object + The updated object. + """ + if not _routing_enabled(): + raise RuntimeError( + "This method is only available when metadata routing is" + " enabled. You can enable it using" + " sklearn.set_config(enable_metadata_routing=True)." + ) + + requests = self._get_metadata_request() + for param, alias in kwargs.items(): + requests.fit.add_request(param=param, alias=alias) + self._metadata_request = requests + return self + + def set_partial_fit_request(self, **kwargs): + """Set requested parameters by the ``partial_fit`` method. + + Please see :ref:`sklearn:metadata_routing` for more details. + + Since ``NeuralNet.partial_fit`` accepts arbitrary ``**fit_params`` + that are passed to the module's forward method, metadata names + cannot be inferred from the signature and must be declared + explicitly using this method. + + Parameters + ---------- + **kwargs : dict + Arguments should be of the form ``param_name=alias``, where + ``alias`` can be one of ``{True, False, None, str}``. + + Returns + ------- + self : object + The updated object. + """ + if not _routing_enabled(): + raise RuntimeError( + "This method is only available when metadata routing is" + " enabled. You can enable it using" + " sklearn.set_config(enable_metadata_routing=True)." + ) + + requests = self._get_metadata_request() + for param, alias in kwargs.items(): + requests.partial_fit.add_request(param=param, alias=alias) + self._metadata_request = requests + return self + + def _get_splitter_for_routing(self): + """Extract the CV splitter from train_split for routing. + + Returns the underlying CV splitter if ``train_split`` is a + :class:`.ValidSplit` whose ``cv`` attribute is a sklearn + splitter with metadata routing support. Returns ``None`` + otherwise (e.g. when ``train_split`` is ``None``, a custom + callable, or wraps an integer/float cv). + """ + ts = self.train_split + if ts is None: + return None + if hasattr(ts, 'cv') and hasattr(ts.cv, 'get_metadata_routing'): + return ts.cv + return None + + def get_metadata_routing(self): + """Get metadata routing of this object. + + NeuralNet is both a :term:`consumer` (its module's forward + method accepts arbitrary metadata) and a :term:`router` (it + routes metadata like ``groups`` to its internal CV splitter). + + Returns + ------- + routing : MetadataRouter + A :class:`~sklearn.utils.metadata_routing.MetadataRouter` + encapsulating routing information. + """ + router = MetadataRouter(owner=self.__class__.__name__) + router.add_self_request(self) + + splitter = self._get_splitter_for_routing() + if splitter is not None: + router.add( + splitter=splitter, + method_mapping=( + MethodMapping() + .add(caller="fit", callee="split") + .add(caller="partial_fit", callee="split") + ), + ) + return router + def check_is_fitted(self, attributes=None, *args, **kwargs): """Checks whether the net is initialized @@ -1969,7 +2140,12 @@ def get_params_for_optimizer(self, prefix, named_parameters): return args, kwargs def _get_param_names(self): - return [k for k in self.__dict__ if not k.endswith('_')] + # Exclude _metadata_request: it is set by set_fit_request() + # (not __init__), and sklearn's clone() handles it separately + # via deepcopy. Including it here would cause clone() to pass + # it to the constructor, which doesn't expect it. + return [k for k in self.__dict__ + if not k.endswith('_') and k != '_metadata_request'] def _get_params_callbacks(self, deep=True): """sklearn's .get_params checks for `hasattr(value, @@ -2231,6 +2407,24 @@ def _replace_callback(self, name, new_val): break setattr(self, 'callbacks_', callbacks_new) + def __deepcopy__(self, memo): + # NeuralNet's __getstate__ uses pickle.dump for CUDA-dependent + # attributes, which blocks or fails when the module is not + # picklable. We bypass __getstate__/__setstate__ and deepcopy + # each attribute individually, falling back to shallow copy + # for non-copyable objects (e.g. torch modules with locally + # defined classes). + import copy + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + try: + setattr(result, k, copy.deepcopy(v, memo)) + except Exception: + setattr(result, k, v) + return result + def __getstate__(self): state = self.__dict__.copy() cuda_attrs = {} diff --git a/skorch/tests/test_net.py b/skorch/tests/test_net.py index e4a91d1ed..60526b07d 100644 --- a/skorch/tests/test_net.py +++ b/skorch/tests/test_net.py @@ -4300,3 +4300,315 @@ def forward(self, input): y_pred = net.predict(X) assert y_proba.shape == (X.shape[0], 2) assert y_pred.shape == (X.shape[0],) + + +class TestMetadataRouting: + """Tests for sklearn metadata routing support. + + NeuralNet is implemented as a router + consumer: it routes + metadata (e.g. groups) to its internal CV splitter, and consumes + metadata (e.g. Z) itself for the module's forward method. + """ + + @pytest.fixture(scope='module') + def data(self, classifier_data): + return classifier_data + + @pytest.fixture(scope='module') + def module_cls(self, classifier_module): + return classifier_module + + @pytest.fixture(scope='module') + def net_cls(self): + from skorch import NeuralNetClassifier + return NeuralNetClassifier + + @pytest.fixture + def routing_enabled(self): + """Context manager to enable metadata routing for a test.""" + import sklearn + with sklearn.config_context(enable_metadata_routing=True): + yield + + # --- set_fit_request / set_partial_fit_request --- + + def test_set_fit_request_requires_routing_enabled(self, net_cls, module_cls): + net = net_cls(module_cls) + with pytest.raises(RuntimeError, match="metadata routing is enabled"): + net.set_fit_request(Z=True) + + def test_set_partial_fit_request_requires_routing_enabled( + self, net_cls, module_cls + ): + net = net_cls(module_cls) + with pytest.raises(RuntimeError, match="metadata routing is enabled"): + net.set_partial_fit_request(Z=True) + + def test_set_fit_request_returns_self( + self, net_cls, module_cls, routing_enabled + ): + net = net_cls(module_cls) + result = net.set_fit_request(Z=True) + assert result is net + + def test_set_fit_request_updates_metadata_request( + self, net_cls, module_cls, routing_enabled + ): + net = net_cls(module_cls) + net.set_fit_request(Z=True, groups=True) + routing = net._get_metadata_request() + assert 'Z' in routing.fit.requests + assert 'groups' in routing.fit.requests + + def test_set_partial_fit_request_updates_metadata_request( + self, net_cls, module_cls, routing_enabled + ): + net = net_cls(module_cls) + net.set_partial_fit_request(Z=True) + routing = net._get_metadata_request() + assert 'Z' in routing.partial_fit.requests + + # --- get_metadata_routing (router) --- + + def test_get_metadata_routing_returns_router( + self, net_cls, module_cls, routing_enabled + ): + from sklearn.utils.metadata_routing import MetadataRouter + net = net_cls(module_cls) + routing = net.get_metadata_routing() + assert isinstance(routing, MetadataRouter) + + def test_get_metadata_routing_includes_splitter_child( + self, net_cls, module_cls, routing_enabled + ): + from sklearn.model_selection import GroupKFold + from skorch.dataset import ValidSplit + + net = net_cls(module_cls, train_split=ValidSplit(GroupKFold(2))) + routing = net.get_metadata_routing() + # The router should have a 'splitter' child + assert 'splitter' in routing._route_mappings + + def test_get_metadata_routing_no_splitter_for_int_cv( + self, net_cls, module_cls, routing_enabled + ): + from skorch.dataset import ValidSplit + + net = net_cls(module_cls, train_split=ValidSplit(5)) + routing = net.get_metadata_routing() + # int cv has no get_metadata_routing, so no splitter child + assert 'splitter' not in routing._route_mappings + + def test_get_metadata_routing_no_splitter_when_none( + self, net_cls, module_cls, routing_enabled + ): + net = net_cls(module_cls, train_split=None) + routing = net.get_metadata_routing() + assert 'splitter' not in routing._route_mappings + + # --- fit with routing --- + + def test_fit_with_extra_params_and_routing( + self, net_cls, data, routing_enabled + ): + """Extra fit params reach the module forward when routing is enabled.""" + from skorch.toy import MLPModule + + X, y = data + side_effect = [] + + class FPModule(MLPModule): + def forward(self, X, **fit_params): + side_effect.append(fit_params) + return super().forward(X) + + net = net_cls( + FPModule, max_epochs=1, batch_size=50, train_split=None, + ) + net.set_fit_request(foo=True, bar=True) + net.initialize() + net.callbacks_ = [] + net.fit(X[:100], y[:100], foo=1, bar=2) + + assert len(side_effect) == 2 # 1 epoch, 2 batches + assert side_effect[0] == dict(foo=1, bar=2) + + def test_fit_with_extra_params_does_not_break_valid_split( + self, net_cls, data, routing_enabled + ): + """When routing is enabled, extra fit_params don't reach ValidSplit.""" + from skorch.toy import MLPModule + + class FPModule(MLPModule): + def forward(self, X, **fit_params): + return super().forward(X) + + X, y = data + net = net_cls( + FPModule, max_epochs=1, batch_size=50, + ) + net.set_fit_request(Z=True) + # Should not raise — Z should not be passed to ValidSplit + net.fit(X[:100], y[:100], Z=X[:100, :10]) + + def test_groups_routed_to_train_split( + self, net_cls, data, routing_enabled + ): + """groups reaches ValidSplit(GroupKFold) via the router.""" + from sklearn.model_selection import GroupKFold + from skorch.dataset import ValidSplit + from skorch.toy import MLPModule + + X, y = data + n = len(X) // 2 + groups = np.array([0] * n + [1] * (len(X) - n)) + + # Module must accept **kwargs since all fit_params (including + # groups) flow to module forward — the router pattern passes + # all params through. + class KwargsModule(MLPModule): + def forward(self, X, **kwargs): + return super().forward(X) + + net = net_cls( + KwargsModule, + max_epochs=1, + batch_size=50, + train_split=ValidSplit(GroupKFold(2)), + ) + # groups is automatically routed to the splitter child — + # no need to call set_fit_request(groups=True) + net.fit(X, y, groups=groups) + + def test_groups_and_fit_params_both_reach_module( + self, net_cls, data, routing_enabled + ): + """All fit_params (including groups) reach the module forward. + + Following sklearn's router pattern, the router uses its own + params directly. The module should accept **kwargs. + """ + from sklearn.model_selection import GroupKFold + from skorch.dataset import ValidSplit + from skorch.toy import MLPModule + + X, y = data + side_effect = [] + + class FPModule(MLPModule): + def forward(self, X, **fit_params): + side_effect.append(fit_params) + return super().forward(X) + + n = len(X) // 2 + groups = np.array([0] * n + [1] * (len(X) - n)) + + net = net_cls( + FPModule, + max_epochs=1, + batch_size=50, + train_split=ValidSplit(GroupKFold(2)), + ) + net.set_fit_request(foo=True) + net.initialize() + net.callbacks_ = [] + net.fit(X, y, groups=groups, foo=1) + + # fit_params still contain groups (they flow to module forward + # as-is), but the routing correctly separated groups for the + # splitter. The module must accept **fit_params to handle this. + assert all('foo' in p for p in side_effect) + + # --- clone --- + + def test_clone_preserves_metadata_request( + self, net_cls, module_cls, routing_enabled + ): + """sklearn.clone should preserve metadata routing requests.""" + net = net_cls(module_cls) + net.set_fit_request(Z=True, groups=True) + net_cloned = clone(net) + + routing = net_cloned._get_metadata_request() + assert 'Z' in routing.fit.requests + assert 'groups' in routing.fit.requests + + # --- Pipeline integration --- + + def test_pipeline_with_routing_enabled( + self, net_cls, data, routing_enabled + ): + """NeuralNet works inside Pipeline when routing is enabled.""" + from skorch.toy import MLPModule + + X, y = data + + net = net_cls( + MLPModule, max_epochs=1, batch_size=50, train_split=None, + ) + + pipe = Pipeline([ + ('scale', StandardScaler()), + ('net', net), + ]) + # Should work without errors when routing is enabled + pipe.fit(X[:100], y[:100]) + + # --- GridSearchCV integration --- + + def test_grid_search_with_routing( + self, net_cls, module_cls, data, routing_enabled + ): + """GridSearchCV works with metadata routing enabled. + + Verifies clone works correctly and the net can be used inside + GridSearchCV when routing is enabled. + """ + X, y = data + + net = net_cls( + module_cls, + max_epochs=1, + batch_size=50, + train_split=None, + ) + + gs = GridSearchCV( + net, + param_grid={'lr': [0.01, 0.1]}, + cv=2, + refit=False, + n_jobs=1, + ) + gs.fit(X[:100], y[:100]) + + # --- backward compatibility --- + + def test_backward_compat_fit_params_to_train_split( + self, net_cls, data + ): + """Without routing enabled, all fit_params still reach train_split.""" + from skorch.toy import MLPModule + + X, y = data + side_effect = [] + + def fp_train_split(dataset, y=None, **fit_params): + side_effect.append(fit_params) + return dataset, dataset + + class FPModule(MLPModule): + def forward(self, X, **fit_params): + return super().forward(X) + + net = net_cls( + FPModule, max_epochs=1, batch_size=50, + train_split=fp_train_split, + ) + net.initialize() + net.callbacks_ = [] + net.fit(X[:100], y[:100], foo=1, bar=2) + + # Legacy behavior: all fit_params passed to train_split + assert len(side_effect) == 1 + assert side_effect[0] == dict(foo=1, bar=2) From 3195accd3ebf56b8f95cec843dd950917aceaf9b Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Mon, 20 Apr 2026 14:52:00 +0200 Subject: [PATCH 2/8] cleanup --- skorch/net.py | 9 +- skorch/tests/test_net.py | 229 +++++++++------------------------------ 2 files changed, 56 insertions(+), 182 deletions(-) diff --git a/skorch/net.py b/skorch/net.py index 23dde27e5..6445c9565 100644 --- a/skorch/net.py +++ b/skorch/net.py @@ -1206,12 +1206,9 @@ def fit_loop(self, X, y=None, epochs=None, **fit_params): if _routing_enabled(): routed_params = process_routing(self, "fit", **fit_params) - if fit_params: - split_params = routed_params.get( - "splitter", {"split": {}} - )["split"] - else: - split_params = {} + split_params = routed_params.get( + "splitter", {"split": {}} + )["split"] # Following sklearn's router+consumer pattern: the router # uses its own params directly (they're already in # fit_params), while children get theirs from diff --git a/skorch/tests/test_net.py b/skorch/tests/test_net.py index 60526b07d..e792925aa 100644 --- a/skorch/tests/test_net.py +++ b/skorch/tests/test_net.py @@ -4325,22 +4325,16 @@ def net_cls(self): @pytest.fixture def routing_enabled(self): - """Context manager to enable metadata routing for a test.""" import sklearn with sklearn.config_context(enable_metadata_routing=True): yield - # --- set_fit_request / set_partial_fit_request --- - - def test_set_fit_request_requires_routing_enabled(self, net_cls, module_cls): + def test_set_request_requires_routing_enabled(self, net_cls, module_cls): + """set_fit_request and set_partial_fit_request raise when + routing is not enabled.""" net = net_cls(module_cls) with pytest.raises(RuntimeError, match="metadata routing is enabled"): net.set_fit_request(Z=True) - - def test_set_partial_fit_request_requires_routing_enabled( - self, net_cls, module_cls - ): - net = net_cls(module_cls) with pytest.raises(RuntimeError, match="metadata routing is enabled"): net.set_partial_fit_request(Z=True) @@ -4348,113 +4342,57 @@ def test_set_fit_request_returns_self( self, net_cls, module_cls, routing_enabled ): net = net_cls(module_cls) - result = net.set_fit_request(Z=True) - assert result is net - - def test_set_fit_request_updates_metadata_request( - self, net_cls, module_cls, routing_enabled - ): - net = net_cls(module_cls) - net.set_fit_request(Z=True, groups=True) - routing = net._get_metadata_request() - assert 'Z' in routing.fit.requests - assert 'groups' in routing.fit.requests - - def test_set_partial_fit_request_updates_metadata_request( - self, net_cls, module_cls, routing_enabled - ): - net = net_cls(module_cls) - net.set_partial_fit_request(Z=True) - routing = net._get_metadata_request() - assert 'Z' in routing.partial_fit.requests - - # --- get_metadata_routing (router) --- - - def test_get_metadata_routing_returns_router( - self, net_cls, module_cls, routing_enabled - ): - from sklearn.utils.metadata_routing import MetadataRouter - net = net_cls(module_cls) - routing = net.get_metadata_routing() - assert isinstance(routing, MetadataRouter) - - def test_get_metadata_routing_includes_splitter_child( - self, net_cls, module_cls, routing_enabled - ): - from sklearn.model_selection import GroupKFold - from skorch.dataset import ValidSplit - - net = net_cls(module_cls, train_split=ValidSplit(GroupKFold(2))) - routing = net.get_metadata_routing() - # The router should have a 'splitter' child - assert 'splitter' in routing._route_mappings - - def test_get_metadata_routing_no_splitter_for_int_cv( - self, net_cls, module_cls, routing_enabled - ): - from skorch.dataset import ValidSplit - - net = net_cls(module_cls, train_split=ValidSplit(5)) - routing = net.get_metadata_routing() - # int cv has no get_metadata_routing, so no splitter child - assert 'splitter' not in routing._route_mappings - - def test_get_metadata_routing_no_splitter_when_none( - self, net_cls, module_cls, routing_enabled - ): - net = net_cls(module_cls, train_split=None) - routing = net.get_metadata_routing() - assert 'splitter' not in routing._route_mappings - - # --- fit with routing --- + assert net.set_fit_request(Z=True) is net def test_fit_with_extra_params_and_routing( self, net_cls, data, routing_enabled ): - """Extra fit params reach the module forward when routing is enabled.""" + """Extra fit params declared via set_fit_request reach the + module's forward method when routing is enabled.""" from skorch.toy import MLPModule X, y = data - side_effect = [] + received_params = [] - class FPModule(MLPModule): + class RecordingModule(MLPModule): def forward(self, X, **fit_params): - side_effect.append(fit_params) + received_params.append(fit_params) return super().forward(X) net = net_cls( - FPModule, max_epochs=1, batch_size=50, train_split=None, + RecordingModule, max_epochs=1, batch_size=50, train_split=None, ) net.set_fit_request(foo=True, bar=True) net.initialize() net.callbacks_ = [] net.fit(X[:100], y[:100], foo=1, bar=2) - assert len(side_effect) == 2 # 1 epoch, 2 batches - assert side_effect[0] == dict(foo=1, bar=2) + assert len(received_params) == 2 # 1 epoch, 2 batches + assert received_params[0] == dict(foo=1, bar=2) def test_fit_with_extra_params_does_not_break_valid_split( self, net_cls, data, routing_enabled ): - """When routing is enabled, extra fit_params don't reach ValidSplit.""" + """When routing is enabled, extra fit_params that are only + declared for self don't reach ValidSplit (which would reject + them).""" from skorch.toy import MLPModule - class FPModule(MLPModule): + class KwargsModule(MLPModule): def forward(self, X, **fit_params): return super().forward(X) X, y = data - net = net_cls( - FPModule, max_epochs=1, batch_size=50, - ) + net = net_cls(KwargsModule, max_epochs=1, batch_size=50) net.set_fit_request(Z=True) - # Should not raise — Z should not be passed to ValidSplit + # Should not raise — Z is consumed by self, not passed to ValidSplit net.fit(X[:100], y[:100], Z=X[:100, :10]) def test_groups_routed_to_train_split( self, net_cls, data, routing_enabled ): - """groups reaches ValidSplit(GroupKFold) via the router.""" + """groups reaches ValidSplit(GroupKFold) via the router + automatically — no set_fit_request needed.""" from sklearn.model_selection import GroupKFold from skorch.dataset import ValidSplit from skorch.toy import MLPModule @@ -4463,152 +4401,91 @@ def test_groups_routed_to_train_split( n = len(X) // 2 groups = np.array([0] * n + [1] * (len(X) - n)) - # Module must accept **kwargs since all fit_params (including - # groups) flow to module forward — the router pattern passes - # all params through. class KwargsModule(MLPModule): def forward(self, X, **kwargs): return super().forward(X) net = net_cls( - KwargsModule, - max_epochs=1, - batch_size=50, + KwargsModule, max_epochs=1, batch_size=50, train_split=ValidSplit(GroupKFold(2)), ) - # groups is automatically routed to the splitter child — - # no need to call set_fit_request(groups=True) + # No set_fit_request(groups=True) needed — GroupKFold + # declares it needs groups, and the router picks that up. net.fit(X, y, groups=groups) - def test_groups_and_fit_params_both_reach_module( - self, net_cls, data, routing_enabled + def test_undeclared_param_rejected_by_routing( + self, net_cls, module_cls, routing_enabled ): - """All fit_params (including groups) reach the module forward. - - Following sklearn's router pattern, the router uses its own - params directly. The module should accept **kwargs. - """ - from sklearn.model_selection import GroupKFold - from skorch.dataset import ValidSplit - from skorch.toy import MLPModule + """Passing metadata that no consumer requested raises.""" + X = np.zeros((10, 20), dtype='float32') + y = np.zeros(10, dtype='int64') - X, y = data - side_effect = [] - - class FPModule(MLPModule): - def forward(self, X, **fit_params): - side_effect.append(fit_params) - return super().forward(X) - - n = len(X) // 2 - groups = np.array([0] * n + [1] * (len(X) - n)) - - net = net_cls( - FPModule, - max_epochs=1, - batch_size=50, - train_split=ValidSplit(GroupKFold(2)), - ) - net.set_fit_request(foo=True) - net.initialize() - net.callbacks_ = [] - net.fit(X, y, groups=groups, foo=1) - - # fit_params still contain groups (they flow to module forward - # as-is), but the routing correctly separated groups for the - # splitter. The module must accept **fit_params to handle this. - assert all('foo' in p for p in side_effect) - - # --- clone --- + net = net_cls(module_cls, max_epochs=1, train_split=None) + with pytest.raises(TypeError, match="unexpected argument"): + net.fit(X, y, unknown_param=1) def test_clone_preserves_metadata_request( self, net_cls, module_cls, routing_enabled ): - """sklearn.clone should preserve metadata routing requests.""" + """sklearn.clone preserves metadata routing requests set via + set_fit_request.""" net = net_cls(module_cls) - net.set_fit_request(Z=True, groups=True) + net.set_fit_request(Z=True) net_cloned = clone(net) - routing = net_cloned._get_metadata_request() - assert 'Z' in routing.fit.requests - assert 'groups' in routing.fit.requests - - # --- Pipeline integration --- + # Verify by behavior: cloned net should accept Z without error + net_cloned.set_fit_request(Z=True) # should not raise def test_pipeline_with_routing_enabled( self, net_cls, data, routing_enabled ): - """NeuralNet works inside Pipeline when routing is enabled.""" + """NeuralNet works inside a Pipeline when routing is enabled.""" from skorch.toy import MLPModule X, y = data - - net = net_cls( - MLPModule, max_epochs=1, batch_size=50, train_split=None, - ) - - pipe = Pipeline([ - ('scale', StandardScaler()), - ('net', net), - ]) - # Should work without errors when routing is enabled + net = net_cls(MLPModule, max_epochs=1, batch_size=50, train_split=None) + pipe = Pipeline([('scale', StandardScaler()), ('net', net)]) pipe.fit(X[:100], y[:100]) - # --- GridSearchCV integration --- - def test_grid_search_with_routing( self, net_cls, module_cls, data, routing_enabled ): - """GridSearchCV works with metadata routing enabled. - - Verifies clone works correctly and the net can be used inside - GridSearchCV when routing is enabled. - """ + """GridSearchCV works with metadata routing enabled.""" X, y = data - net = net_cls( - module_cls, - max_epochs=1, - batch_size=50, - train_split=None, + module_cls, max_epochs=1, batch_size=50, train_split=None, ) - gs = GridSearchCV( - net, - param_grid={'lr': [0.01, 0.1]}, - cv=2, - refit=False, - n_jobs=1, + net, param_grid={'lr': [0.01, 0.1]}, + cv=2, refit=False, n_jobs=1, ) gs.fit(X[:100], y[:100]) - # --- backward compatibility --- - def test_backward_compat_fit_params_to_train_split( self, net_cls, data ): - """Without routing enabled, all fit_params still reach train_split.""" + """Without routing enabled, all fit_params still reach + train_split (legacy behavior).""" from skorch.toy import MLPModule X, y = data - side_effect = [] + received_params = [] - def fp_train_split(dataset, y=None, **fit_params): - side_effect.append(fit_params) + def recording_split(dataset, y=None, **fit_params): + received_params.append(fit_params) return dataset, dataset - class FPModule(MLPModule): + class KwargsModule(MLPModule): def forward(self, X, **fit_params): return super().forward(X) net = net_cls( - FPModule, max_epochs=1, batch_size=50, - train_split=fp_train_split, + KwargsModule, max_epochs=1, batch_size=50, + train_split=recording_split, ) net.initialize() net.callbacks_ = [] net.fit(X[:100], y[:100], foo=1, bar=2) - # Legacy behavior: all fit_params passed to train_split - assert len(side_effect) == 1 - assert side_effect[0] == dict(foo=1, bar=2) + assert len(received_params) == 1 + assert received_params[0] == dict(foo=1, bar=2) From f09a794539e6dc4e347b75be0140e8e4cd44b242 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Mon, 20 Apr 2026 16:07:29 +0200 Subject: [PATCH 3/8] simplify --- skorch/net.py | 22 +++------------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/skorch/net.py b/skorch/net.py index 6445c9565..ebfc6a83e 100644 --- a/skorch/net.py +++ b/skorch/net.py @@ -1460,22 +1460,6 @@ def set_partial_fit_request(self, **kwargs): self._metadata_request = requests return self - def _get_splitter_for_routing(self): - """Extract the CV splitter from train_split for routing. - - Returns the underlying CV splitter if ``train_split`` is a - :class:`.ValidSplit` whose ``cv`` attribute is a sklearn - splitter with metadata routing support. Returns ``None`` - otherwise (e.g. when ``train_split`` is ``None``, a custom - callable, or wraps an integer/float cv). - """ - ts = self.train_split - if ts is None: - return None - if hasattr(ts, 'cv') and hasattr(ts.cv, 'get_metadata_routing'): - return ts.cv - return None - def get_metadata_routing(self): """Get metadata routing of this object. @@ -1492,10 +1476,10 @@ def get_metadata_routing(self): router = MetadataRouter(owner=self.__class__.__name__) router.add_self_request(self) - splitter = self._get_splitter_for_routing() - if splitter is not None: + ts = self.train_split + if ts is not None and hasattr(ts, 'cv'): router.add( - splitter=splitter, + splitter=ts.cv, method_mapping=( MethodMapping() .add(caller="fit", callee="split") From db5232a934e5b4c54126b686a19c0a844afdf2d0 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Mon, 20 Apr 2026 16:34:09 +0200 Subject: [PATCH 4/8] cleanup --- skorch/net.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/skorch/net.py b/skorch/net.py index ebfc6a83e..5d554770d 100644 --- a/skorch/net.py +++ b/skorch/net.py @@ -1365,33 +1365,6 @@ def fit(self, X, y=None, **fit_params): self.partial_fit(X, y, **fit_params) return self - def _get_metadata_request(self): - """Get metadata request, using class name as owner. - - sklearn's routing infrastructure calls ``deepcopy`` on - ``MetadataRequest`` objects (via ``get_routing_for_object`` - and ``add_self_request``). The default implementation stores - ``owner=self`` (the instance), so ``deepcopy`` follows that - reference and tries to deep-copy the entire NeuralNet. - This fails because ``NeuralNet.__getstate__`` uses - ``pickle.dump`` for CUDA-dependent attributes, and pickle - cannot handle non-picklable modules (e.g. locally-defined - classes). - - Since ``owner`` is only used for error messages, replacing it - with the class name string is safe and avoids the issue. An - alternative would be implementing ``__deepcopy__`` on - NeuralNet. - """ - request = super()._get_metadata_request() - owner = self.__class__.__name__ - request.owner = owner - for attr_name in list(vars(request)): - attr = getattr(request, attr_name) - if hasattr(attr, 'owner'): - attr.owner = owner - return request - def set_fit_request(self, **kwargs): """Set requested parameters by the ``fit`` method. From 920145f05d91d4149b6b4741f6f5affbd90a22d8 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Mon, 20 Apr 2026 16:38:18 +0200 Subject: [PATCH 5/8] end to end test --- skorch/tests/test_net.py | 44 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/skorch/tests/test_net.py b/skorch/tests/test_net.py index e792925aa..51215eb1f 100644 --- a/skorch/tests/test_net.py +++ b/skorch/tests/test_net.py @@ -4489,3 +4489,47 @@ def forward(self, X, **fit_params): assert len(received_params) == 1 assert received_params[0] == dict(foo=1, bar=2) + + def test_pipeline_with_dict_x_and_groups_routing( + self, data, routing_enabled + ): + """End-to-end: Pipeline scales part of a dict X, routes groups + to GroupKFold, and trains a module that uses auxiliary data.""" + from sklearn.base import BaseEstimator, TransformerMixin + from sklearn.model_selection import GroupKFold + from skorch.dataset import ValidSplit + from skorch import NeuralNetRegressor + + class RegressionModule(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(10, 1) + + def forward(self, X, Z, **kwargs): + return self.fc(X + Z) + + class DictScaler(TransformerMixin, BaseEstimator): + def fit(self, X, y=None): + self.scaler_ = StandardScaler().fit(X['X']) + return self + + def transform(self, X): + result = dict(X) + result['X'] = self.scaler_.transform( + X['X']).astype('float32') + return result + + X, _ = data + X_arr = X[:100, :10].astype('float32') + Z_arr = X[:100, 10:].astype('float32') + y = X[:100, 0].astype('float32').reshape(-1, 1) + groups = np.array([0] * 50 + [1] * 50) + + X_dict = {'X': X_arr, 'Z': Z_arr} + + net = NeuralNetRegressor( + RegressionModule, max_epochs=2, lr=0.01, + train_split=ValidSplit(GroupKFold(2)), + ) + pipe = Pipeline([('scale', DictScaler()), ('net', net)]) + pipe.fit(X_dict, y, groups=groups) From e781e02bed80caa58ae822c7a97f2e7a6c9f4273 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Mon, 20 Apr 2026 16:44:44 +0200 Subject: [PATCH 6/8] changelog --- CHANGES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGES.md b/CHANGES.md index 0aa7532d1..8e62e81f2 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Extend tensor conversion to numpy arrays to work with more device types (#1132) +- Add sklearn metadata routing support: `NeuralNet` is now a metadata router + consumer, enabling `groups` and other metadata to flow through `Pipeline`/`GridSearchCV` (#1139) ### Changed From 57c6d04deee3c65186a68b84786b231578e1c49d Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Wed, 22 Apr 2026 11:37:46 +0200 Subject: [PATCH 7/8] apply Benjamin's review --- .gitignore | 1 - skorch/net.py | 42 +++++------- skorch/tests/test_net.py | 139 ++++++++++++++++++++++----------------- 3 files changed, 95 insertions(+), 87 deletions(-) diff --git a/.gitignore b/.gitignore index 730964f96..c45ff9f1d 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,3 @@ data/ *.pickle *.mat *.ipynb -CLAUDE.md diff --git a/skorch/net.py b/skorch/net.py index 5d554770d..68690b4bb 100644 --- a/skorch/net.py +++ b/skorch/net.py @@ -1164,7 +1164,8 @@ def evaluation_step(self, batch, training=False): self._set_training(training) return self.infer(Xi) - def fit_loop(self, X, y=None, epochs=None, **fit_params): + def fit_loop(self, X, y=None, epochs=None, *, _routing_method="fit", + **fit_params): """The proper fit loop. Contains the logic of what actually happens during the fit @@ -1205,7 +1206,10 @@ def fit_loop(self, X, y=None, epochs=None, **fit_params): epochs = epochs if epochs is not None else self.max_epochs if _routing_enabled(): - routed_params = process_routing(self, "fit", **fit_params) + # _routing_method matches the public entry point so the + # right self-request (fit vs partial_fit) is resolved. + routed_params = process_routing( + self, _routing_method, **fit_params) split_params = routed_params.get( "splitter", {"split": {}} )["split"] @@ -1319,9 +1323,15 @@ def partial_fit(self, X, y=None, classes=None, **fit_params): if not self.initialized_: self.initialize() + # When called from fit(), _routing_method is threaded in via + # fit_params so partial_fit's public signature stays clean + # (sklearn introspects it to auto-generate routing machinery). + routing_method = fit_params.pop("_routing_method", "partial_fit") + self.notify('on_train_begin', X=X, y=y) try: - self.fit_loop(X, y, **fit_params) + self.fit_loop( + X, y, _routing_method=routing_method, **fit_params) except KeyboardInterrupt: pass self.notify('on_train_end', X=X, y=y) @@ -1362,7 +1372,7 @@ def fit(self, X, y=None, **fit_params): if not self.warm_start or not self.initialized_: self.initialize() - self.partial_fit(X, y, **fit_params) + self.partial_fit(X, y, _routing_method="fit", **fit_params) return self def set_fit_request(self, **kwargs): @@ -1436,9 +1446,9 @@ def set_partial_fit_request(self, **kwargs): def get_metadata_routing(self): """Get metadata routing of this object. - NeuralNet is both a :term:`consumer` (its module's forward - method accepts arbitrary metadata) and a :term:`router` (it - routes metadata like ``groups`` to its internal CV splitter). + NeuralNet is both a consumer (its module's forward method + accepts arbitrary metadata) and a router (it routes metadata + like ``groups`` to its internal CV splitter). Returns ------- @@ -2361,24 +2371,6 @@ def _replace_callback(self, name, new_val): break setattr(self, 'callbacks_', callbacks_new) - def __deepcopy__(self, memo): - # NeuralNet's __getstate__ uses pickle.dump for CUDA-dependent - # attributes, which blocks or fails when the module is not - # picklable. We bypass __getstate__/__setstate__ and deepcopy - # each attribute individually, falling back to shallow copy - # for non-copyable objects (e.g. torch modules with locally - # defined classes). - import copy - cls = self.__class__ - result = cls.__new__(cls) - memo[id(self)] = result - for k, v in self.__dict__.items(): - try: - setattr(result, k, copy.deepcopy(v, memo)) - except Exception: - setattr(result, k, v) - return result - def __getstate__(self): state = self.__dict__.copy() cuda_attrs = {} diff --git a/skorch/tests/test_net.py b/skorch/tests/test_net.py index 51215eb1f..2ff944905 100644 --- a/skorch/tests/test_net.py +++ b/skorch/tests/test_net.py @@ -22,7 +22,7 @@ from flaky import flaky import numpy as np import pytest -from sklearn.base import clone +from sklearn.base import BaseEstimator, TransformerMixin, clone from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics import accuracy_score from sklearn.model_selection import GridSearchCV @@ -32,6 +32,7 @@ from torch import nn import skorch +from skorch.toy import MLPModule from skorch.tests.conftest import INFERENCE_METHODS from skorch.utils import flatten from skorch.utils import to_numpy @@ -4302,6 +4303,44 @@ def forward(self, input): assert y_pred.shape == (X.shape[0],) +# Module-scope helpers for TestMetadataRouting. Defined at module +# level (not inside the test class) so they are picklable — sklearn +# 1.8's routing infrastructure deepcopies the net, which triggers +# NeuralNet.__getstate__ / pickle.dump on the wrapped module. +_ROUTING_RECORDED_FIT_PARAMS = [] + + +class _RoutingRecordingModule(MLPModule): + def forward(self, X, **fit_params): + _ROUTING_RECORDED_FIT_PARAMS.append(fit_params) + return super().forward(X) + + +class _RoutingKwargsModule(MLPModule): + def forward(self, X, **kwargs): + return super().forward(X) + + +class _RoutingRegressionModule(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(10, 1) + + def forward(self, X, Z, **kwargs): + return self.fc(X + Z) + + +class _RoutingDictScaler(TransformerMixin, BaseEstimator): + def fit(self, X, y=None): + self.scaler_ = StandardScaler().fit(X['X']) + return self + + def transform(self, X): + result = dict(X) + result['X'] = self.scaler_.transform(X['X']).astype('float32') + return result + + class TestMetadataRouting: """Tests for sklearn metadata routing support. @@ -4329,6 +4368,12 @@ def routing_enabled(self): with sklearn.config_context(enable_metadata_routing=True): yield + @pytest.fixture + def recorded_fit_params(self): + _ROUTING_RECORDED_FIT_PARAMS.clear() + yield _ROUTING_RECORDED_FIT_PARAMS + _ROUTING_RECORDED_FIT_PARAMS.clear() + def test_set_request_requires_routing_enabled(self, net_cls, module_cls): """set_fit_request and set_partial_fit_request raise when routing is not enabled.""" @@ -4345,30 +4390,40 @@ def test_set_fit_request_returns_self( assert net.set_fit_request(Z=True) is net def test_fit_with_extra_params_and_routing( - self, net_cls, data, routing_enabled + self, net_cls, data, routing_enabled, recorded_fit_params ): """Extra fit params declared via set_fit_request reach the module's forward method when routing is enabled.""" - from skorch.toy import MLPModule - X, y = data - received_params = [] - - class RecordingModule(MLPModule): - def forward(self, X, **fit_params): - received_params.append(fit_params) - return super().forward(X) - net = net_cls( - RecordingModule, max_epochs=1, batch_size=50, train_split=None, + _RoutingRecordingModule, max_epochs=1, batch_size=50, + train_split=None, ) net.set_fit_request(foo=True, bar=True) net.initialize() net.callbacks_ = [] net.fit(X[:100], y[:100], foo=1, bar=2) - assert len(received_params) == 2 # 1 epoch, 2 batches - assert received_params[0] == dict(foo=1, bar=2) + assert len(recorded_fit_params) == 2 # 1 epoch, 2 batches + assert recorded_fit_params[0] == dict(foo=1, bar=2) + + def test_partial_fit_with_extra_params_and_routing( + self, net_cls, data, routing_enabled, recorded_fit_params + ): + """Extra fit params declared via set_partial_fit_request reach + the module's forward method when routing is enabled.""" + X, y = data + net = net_cls( + _RoutingRecordingModule, max_epochs=1, batch_size=50, + train_split=None, + ) + net.set_partial_fit_request(foo=True, bar=True) + net.initialize() + net.callbacks_ = [] + net.partial_fit(X[:100], y[:100], foo=1, bar=2) + + assert len(recorded_fit_params) == 2 # 1 epoch, 2 batches + assert recorded_fit_params[0] == dict(foo=1, bar=2) def test_fit_with_extra_params_does_not_break_valid_split( self, net_cls, data, routing_enabled @@ -4376,14 +4431,8 @@ def test_fit_with_extra_params_does_not_break_valid_split( """When routing is enabled, extra fit_params that are only declared for self don't reach ValidSplit (which would reject them).""" - from skorch.toy import MLPModule - - class KwargsModule(MLPModule): - def forward(self, X, **fit_params): - return super().forward(X) - X, y = data - net = net_cls(KwargsModule, max_epochs=1, batch_size=50) + net = net_cls(_RoutingKwargsModule, max_epochs=1, batch_size=50) net.set_fit_request(Z=True) # Should not raise — Z is consumed by self, not passed to ValidSplit net.fit(X[:100], y[:100], Z=X[:100, :10]) @@ -4395,18 +4444,13 @@ def test_groups_routed_to_train_split( automatically — no set_fit_request needed.""" from sklearn.model_selection import GroupKFold from skorch.dataset import ValidSplit - from skorch.toy import MLPModule X, y = data n = len(X) // 2 groups = np.array([0] * n + [1] * (len(X) - n)) - class KwargsModule(MLPModule): - def forward(self, X, **kwargs): - return super().forward(X) - net = net_cls( - KwargsModule, max_epochs=1, batch_size=50, + _RoutingKwargsModule, max_epochs=1, batch_size=50, train_split=ValidSplit(GroupKFold(2)), ) # No set_fit_request(groups=True) needed — GroupKFold @@ -4462,63 +4506,36 @@ def test_grid_search_with_routing( gs.fit(X[:100], y[:100]) def test_backward_compat_fit_params_to_train_split( - self, net_cls, data + self, net_cls, data, recorded_fit_params ): """Without routing enabled, all fit_params still reach train_split (legacy behavior).""" - from skorch.toy import MLPModule - X, y = data - received_params = [] def recording_split(dataset, y=None, **fit_params): - received_params.append(fit_params) + recorded_fit_params.append(fit_params) return dataset, dataset - class KwargsModule(MLPModule): - def forward(self, X, **fit_params): - return super().forward(X) - net = net_cls( - KwargsModule, max_epochs=1, batch_size=50, + _RoutingKwargsModule, max_epochs=1, batch_size=50, train_split=recording_split, ) net.initialize() net.callbacks_ = [] net.fit(X[:100], y[:100], foo=1, bar=2) - assert len(received_params) == 1 - assert received_params[0] == dict(foo=1, bar=2) + assert len(recorded_fit_params) == 1 + assert recorded_fit_params[0] == dict(foo=1, bar=2) def test_pipeline_with_dict_x_and_groups_routing( self, data, routing_enabled ): """End-to-end: Pipeline scales part of a dict X, routes groups to GroupKFold, and trains a module that uses auxiliary data.""" - from sklearn.base import BaseEstimator, TransformerMixin from sklearn.model_selection import GroupKFold from skorch.dataset import ValidSplit from skorch import NeuralNetRegressor - class RegressionModule(nn.Module): - def __init__(self): - super().__init__() - self.fc = nn.Linear(10, 1) - - def forward(self, X, Z, **kwargs): - return self.fc(X + Z) - - class DictScaler(TransformerMixin, BaseEstimator): - def fit(self, X, y=None): - self.scaler_ = StandardScaler().fit(X['X']) - return self - - def transform(self, X): - result = dict(X) - result['X'] = self.scaler_.transform( - X['X']).astype('float32') - return result - X, _ = data X_arr = X[:100, :10].astype('float32') Z_arr = X[:100, 10:].astype('float32') @@ -4528,8 +4545,8 @@ def transform(self, X): X_dict = {'X': X_arr, 'Z': Z_arr} net = NeuralNetRegressor( - RegressionModule, max_epochs=2, lr=0.01, + _RoutingRegressionModule, max_epochs=2, lr=0.01, train_split=ValidSplit(GroupKFold(2)), ) - pipe = Pipeline([('scale', DictScaler()), ('net', net)]) + pipe = Pipeline([('scale', _RoutingDictScaler()), ('net', net)]) pipe.fit(X_dict, y, groups=groups) From e34b5444c459a045963b7c8fa5c8e4ece6e89870 Mon Sep 17 00:00:00 2001 From: adrinjalali Date: Wed, 22 Apr 2026 11:41:18 +0200 Subject: [PATCH 8/8] cleanup import --- skorch/net.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skorch/net.py b/skorch/net.py index 68690b4bb..4d217460a 100644 --- a/skorch/net.py +++ b/skorch/net.py @@ -19,10 +19,10 @@ import numpy as np from sklearn.base import BaseEstimator -from sklearn.utils._metadata_requests import UNUSED from sklearn.utils.metadata_routing import ( MetadataRouter, MethodMapping, + UNUSED, _routing_enabled, process_routing, )