diff --git a/CHANGES.md b/CHANGES.md
index 0aa7532d..8e62e81f 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 
 - Extend tensor conversion to numpy arrays to work with more device types (#1132)
+- Add sklearn metadata routing support: `NeuralNet` is now a metadata router + consumer, enabling `groups` and other metadata to flow through `Pipeline`/`GridSearchCV` (#1139)
 
 ### Changed
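Before the implementation diff, a minimal, runnable sketch of the consumer side of this change (illustrative only, not part of the patch): `set_fit_request` declares which extra fit params the net itself may consume, since they cannot be inferred from a `**fit_params` signature.

    import numpy as np
    import sklearn
    from skorch import NeuralNetClassifier
    from skorch.toy import MLPModule

    class KwargsModule(MLPModule):
        # **kwargs lets the module swallow routed metadata it ignores
        def forward(self, X, **kwargs):
            return super().forward(X)

    X = np.random.randn(100, 20).astype('float32')
    y = np.tile([0, 1], 50).astype('int64')

    with sklearn.config_context(enable_metadata_routing=True):
        net = NeuralNetClassifier(KwargsModule, max_epochs=1, train_split=None)
        # Without this declaration, routing rejects Z as unrequested.
        net.set_fit_request(Z=True)
        # Z reaches KwargsModule.forward unbatched, once per batch.
        net.fit(X, y, Z=np.ones(3))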
diff --git a/skorch/net.py b/skorch/net.py
index 1d5b3b34..4d217460 100644
--- a/skorch/net.py
+++ b/skorch/net.py
@@ -19,6 +19,13 @@ import numpy as np
 
 from sklearn.base import BaseEstimator
+from sklearn.utils.metadata_routing import (
+    MetadataRouter,
+    MethodMapping,
+    UNUSED,
+    _routing_enabled,
+    process_routing,
+)
 import torch
 from torch.utils.data import DataLoader
 
@@ -309,6 +316,13 @@ class NeuralNet(BaseEstimator):
       this list.
 
     """
+    # Suppress auto-generation of set_partial_fit_request that only
+    # allows 'classes'. We provide our own that accepts arbitrary
+    # metadata names, since partial_fit takes **fit_params.
+    # TODO: remove once scikit-learn/scikit-learn#32111 is merged and
+    # provides a public API for this.
+    __metadata_request__partial_fit = {"classes": UNUSED}
+
     prefixes_ = ['iterator_train', 'iterator_valid', 'callbacks',
                  'dataset', 'compile']
 
     cuda_dependent_attributes_ = []
 
@@ -1150,7 +1164,8 @@ def evaluation_step(self, batch, training=False):
         self._set_training(training)
         return self.infer(Xi)
 
-    def fit_loop(self, X, y=None, epochs=None, **fit_params):
+    def fit_loop(self, X, y=None, epochs=None, *, _routing_method="fit",
+                 **fit_params):
         """The proper fit loop.
 
         Contains the logic of what actually happens during the fit
@@ -1190,8 +1205,26 @@ def fit_loop(self, X, y=None, epochs=None, **fit_params):
         self.check_training_readiness()
         epochs = epochs if epochs is not None else self.max_epochs
 
+        if _routing_enabled():
+            # _routing_method matches the public entry point so the
+            # right self-request (fit vs partial_fit) is resolved.
+            routed_params = process_routing(
+                self, _routing_method, **fit_params)
+            split_params = routed_params.get(
+                "splitter", {"split": {}}
+            )["split"]
+            # Following sklearn's router+consumer pattern: the router
+            # uses its own params directly (they're already in
+            # fit_params), while children get theirs from
+            # routed_params. The module's forward method should accept
+            # **kwargs to handle any extra params.
+            forward_params = fit_params
+        else:
+            split_params = fit_params
+            forward_params = fit_params
+
         dataset_train, dataset_valid = self.get_split_datasets(
-            X, y, **fit_params)
+            X, y, **split_params)
         on_epoch_kwargs = {
             'dataset_train': dataset_train,
             'dataset_valid': dataset_valid,
@@ -1205,10 +1238,10 @@
             self.notify('on_epoch_begin', **on_epoch_kwargs)
 
             self.run_single_epoch(iterator_train, training=True, prefix="train",
-                                  step_fn=self.train_step, **fit_params)
+                                  step_fn=self.train_step, **forward_params)
 
             self.run_single_epoch(iterator_valid, training=False, prefix="valid",
-                                  step_fn=self.validation_step, **fit_params)
+                                  step_fn=self.validation_step, **forward_params)
 
             self.notify("on_epoch_end", **on_epoch_kwargs)
 
         return self
 
@@ -1290,9 +1323,15 @@ def partial_fit(self, X, y=None, classes=None, **fit_params):
         if not self.initialized_:
             self.initialize()
 
+        # When called from fit(), _routing_method is threaded in via
+        # fit_params so partial_fit's public signature stays clean
+        # (sklearn introspects it to auto-generate routing machinery).
+        routing_method = fit_params.pop("_routing_method", "partial_fit")
+
         self.notify('on_train_begin', X=X, y=y)
         try:
-            self.fit_loop(X, y, **fit_params)
+            self.fit_loop(
+                X, y, _routing_method=routing_method, **fit_params)
         except KeyboardInterrupt:
             pass
         self.notify('on_train_end', X=X, y=y)
@@ -1333,9 +1372,105 @@ def fit(self, X, y=None, **fit_params):
         if not self.warm_start or not self.initialized_:
             self.initialize()
 
-        self.partial_fit(X, y, **fit_params)
+        self.partial_fit(X, y, _routing_method="fit", **fit_params)
         return self
 
+    def set_fit_request(self, **kwargs):
+        """Set the parameters requested by the ``fit`` method.
+
+        Please see :ref:`sklearn:metadata_routing` for more details.
+
+        Since ``NeuralNet.fit`` accepts arbitrary ``**fit_params`` that
+        are passed to the module's forward method, metadata names cannot
+        be inferred from the signature and must be declared explicitly
+        using this method.
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Arguments should be of the form ``param_name=alias``, where
+            ``alias`` can be one of ``{True, False, None, str}``.
+
+        Returns
+        -------
+        self : object
+            The updated object.
+
+        """
+        if not _routing_enabled():
+            raise RuntimeError(
+                "This method is only available when metadata routing is"
+                " enabled. You can enable it using"
+                " sklearn.set_config(enable_metadata_routing=True)."
+            )
+
+        requests = self._get_metadata_request()
+        for param, alias in kwargs.items():
+            requests.fit.add_request(param=param, alias=alias)
+        self._metadata_request = requests
+        return self
+
+    def set_partial_fit_request(self, **kwargs):
+        """Set the parameters requested by the ``partial_fit`` method.
+
+        Please see :ref:`sklearn:metadata_routing` for more details.
+
+        Since ``NeuralNet.partial_fit`` accepts arbitrary ``**fit_params``
+        that are passed to the module's forward method, metadata names
+        cannot be inferred from the signature and must be declared
+        explicitly using this method.
+
+        Parameters
+        ----------
+        **kwargs : dict
+            Arguments should be of the form ``param_name=alias``, where
+            ``alias`` can be one of ``{True, False, None, str}``.
+
+        Returns
+        -------
+        self : object
+            The updated object.
+
+        """
+        if not _routing_enabled():
+            raise RuntimeError(
+                "This method is only available when metadata routing is"
+                " enabled. You can enable it using"
+                " sklearn.set_config(enable_metadata_routing=True)."
+            )
+
+        requests = self._get_metadata_request()
+        for param, alias in kwargs.items():
+            requests.partial_fit.add_request(param=param, alias=alias)
+        self._metadata_request = requests
+        return self
+
+    def get_metadata_routing(self):
+        """Get metadata routing of this object.
+
+        NeuralNet is both a consumer (its module's forward method
+        accepts arbitrary metadata) and a router (it routes metadata
+        like ``groups`` to its internal CV splitter).
+
+        Returns
+        -------
+        routing : MetadataRouter
+            A :class:`~sklearn.utils.metadata_routing.MetadataRouter`
+            encapsulating routing information.
+
+        """
+        router = MetadataRouter(owner=self.__class__.__name__)
+        router.add_self_request(self)
+
+        ts = self.train_split
+        if ts is not None and hasattr(ts, 'cv'):
+            router.add(
+                splitter=ts.cv,
+                method_mapping=(
+                    MethodMapping()
+                    .add(caller="fit", callee="split")
+                    .add(caller="partial_fit", callee="split")
+                ),
+            )
+        return router
+
     def check_is_fitted(self, attributes=None, *args, **kwargs):
         """Checks whether the net is initialized
 
@@ -1969,7 +2104,12 @@ def get_params_for_optimizer(self, prefix, named_parameters):
         return args, kwargs
 
     def _get_param_names(self):
-        return [k for k in self.__dict__ if not k.endswith('_')]
+        # Exclude _metadata_request: it is set by set_fit_request()
+        # (not __init__), and sklearn's clone() handles it separately
+        # via deepcopy. Including it here would cause clone() to pass
+        # it to the constructor, which doesn't expect it.
+        return [k for k in self.__dict__
+                if not k.endswith('_') and k != '_metadata_request']
 
     def _get_params_callbacks(self, deep=True):
         """sklearn's .get_params checks for `hasattr(value,
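The router side, sketched the same way (illustrative, not part of the patch): because `get_metadata_routing` advertises the splitter's `split` parameters, `groups` flows to `GroupKFold` without any `set_fit_request` call. The module still needs `**kwargs` in `forward`, since the net passes all fit params on to it.

    import numpy as np
    import sklearn
    from sklearn.model_selection import GroupKFold
    from skorch import NeuralNetClassifier
    from skorch.dataset import ValidSplit
    from skorch.toy import MLPModule

    class KwargsModule(MLPModule):
        def forward(self, X, **kwargs):  # swallows routed `groups`
            return super().forward(X)

    X = np.random.randn(100, 20).astype('float32')
    y = np.tile([0, 1], 50).astype('int64')
    groups = np.repeat([0, 1], 50)

    with sklearn.config_context(enable_metadata_routing=True):
        net = NeuralNetClassifier(
            KwargsModule, max_epochs=1,
            train_split=ValidSplit(GroupKFold(2)),
        )
        # The router reports `groups` as consumed by the splitter.
        print(net.get_metadata_routing())
        net.fit(X, y, groups=groups)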
diff --git a/skorch/tests/test_net.py b/skorch/tests/test_net.py
index e4a91d1e..2ff94490 100644
--- a/skorch/tests/test_net.py
+++ b/skorch/tests/test_net.py
@@ -22,7 +22,7 @@ from flaky import flaky
 import numpy as np
 import pytest
-from sklearn.base import clone
+from sklearn.base import BaseEstimator, TransformerMixin, clone
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics import accuracy_score
 from sklearn.model_selection import GridSearchCV
@@ -32,6 +32,7 @@ from torch import nn
 
 import skorch
+from skorch.toy import MLPModule
 from skorch.tests.conftest import INFERENCE_METHODS
 from skorch.utils import flatten
 from skorch.utils import to_numpy
@@ -4300,3 +4301,252 @@ def forward(self, input):
         y_pred = net.predict(X)
         assert y_proba.shape == (X.shape[0], 2)
         assert y_pred.shape == (X.shape[0],)
+
+
+# Module-scope helpers for TestMetadataRouting. Defined at module
+# level (not inside the test class) so they are picklable: sklearn
+# 1.8's routing infrastructure deepcopies the net, which triggers
+# NeuralNet.__getstate__ / pickle.dump on the wrapped module.
+_ROUTING_RECORDED_FIT_PARAMS = []
+
+
+class _RoutingRecordingModule(MLPModule):
+    def forward(self, X, **fit_params):
+        _ROUTING_RECORDED_FIT_PARAMS.append(fit_params)
+        return super().forward(X)
+
+
+class _RoutingKwargsModule(MLPModule):
+    def forward(self, X, **kwargs):
+        return super().forward(X)
+
+
+class _RoutingRegressionModule(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc = nn.Linear(10, 1)
+
+    def forward(self, X, Z, **kwargs):
+        return self.fc(X + Z)
+
+
+class _RoutingDictScaler(TransformerMixin, BaseEstimator):
+    def fit(self, X, y=None):
+        self.scaler_ = StandardScaler().fit(X['X'])
+        return self
+
+    def transform(self, X):
+        result = dict(X)
+        result['X'] = self.scaler_.transform(X['X']).astype('float32')
+        return result
+
+
+class TestMetadataRouting:
+    """Tests for sklearn metadata routing support.
+
+    NeuralNet is implemented as a router + consumer: it routes
+    metadata (e.g. groups) to its internal CV splitter, and consumes
+    metadata (e.g. Z) itself for the module's forward method.
+
+    """
+    @pytest.fixture(scope='module')
+    def data(self, classifier_data):
+        return classifier_data
+
+    @pytest.fixture(scope='module')
+    def module_cls(self, classifier_module):
+        return classifier_module
+
+    @pytest.fixture(scope='module')
+    def net_cls(self):
+        from skorch import NeuralNetClassifier
+        return NeuralNetClassifier
+
+    @pytest.fixture
+    def routing_enabled(self):
+        import sklearn
+        with sklearn.config_context(enable_metadata_routing=True):
+            yield
+
+    @pytest.fixture
+    def recorded_fit_params(self):
+        _ROUTING_RECORDED_FIT_PARAMS.clear()
+        yield _ROUTING_RECORDED_FIT_PARAMS
+        _ROUTING_RECORDED_FIT_PARAMS.clear()
+
+    def test_set_request_requires_routing_enabled(self, net_cls, module_cls):
+        """set_fit_request and set_partial_fit_request raise when
+        routing is not enabled."""
+        net = net_cls(module_cls)
+        with pytest.raises(RuntimeError, match="metadata routing is enabled"):
+            net.set_fit_request(Z=True)
+        with pytest.raises(RuntimeError, match="metadata routing is enabled"):
+            net.set_partial_fit_request(Z=True)
+
+    def test_set_fit_request_returns_self(
+            self, net_cls, module_cls, routing_enabled
+    ):
+        net = net_cls(module_cls)
+        assert net.set_fit_request(Z=True) is net
+
+    def test_fit_with_extra_params_and_routing(
+            self, net_cls, data, routing_enabled, recorded_fit_params
+    ):
+        """Extra fit params declared via set_fit_request reach the
+        module's forward method when routing is enabled."""
+        X, y = data
+        net = net_cls(
+            _RoutingRecordingModule, max_epochs=1, batch_size=50,
+            train_split=None,
+        )
+        net.set_fit_request(foo=True, bar=True)
+        net.initialize()
+        net.callbacks_ = []
+        net.fit(X[:100], y[:100], foo=1, bar=2)
+
+        assert len(recorded_fit_params) == 2  # 1 epoch, 2 batches
+        assert recorded_fit_params[0] == dict(foo=1, bar=2)
+
+    def test_partial_fit_with_extra_params_and_routing(
+            self, net_cls, data, routing_enabled, recorded_fit_params
+    ):
+        """Extra fit params declared via set_partial_fit_request reach
+        the module's forward method when routing is enabled."""
+        X, y = data
+        net = net_cls(
+            _RoutingRecordingModule, max_epochs=1, batch_size=50,
+            train_split=None,
+        )
+        net.set_partial_fit_request(foo=True, bar=True)
+        net.initialize()
+        net.callbacks_ = []
+        net.partial_fit(X[:100], y[:100], foo=1, bar=2)
+
+        assert len(recorded_fit_params) == 2  # 1 epoch, 2 batches
+        assert recorded_fit_params[0] == dict(foo=1, bar=2)
"""When routing is enabled, extra fit_params that are only + declared for self don't reach ValidSplit (which would reject + them).""" + X, y = data + net = net_cls(_RoutingKwargsModule, max_epochs=1, batch_size=50) + net.set_fit_request(Z=True) + # Should not raise — Z is consumed by self, not passed to ValidSplit + net.fit(X[:100], y[:100], Z=X[:100, :10]) + + def test_groups_routed_to_train_split( + self, net_cls, data, routing_enabled + ): + """groups reaches ValidSplit(GroupKFold) via the router + automatically — no set_fit_request needed.""" + from sklearn.model_selection import GroupKFold + from skorch.dataset import ValidSplit + + X, y = data + n = len(X) // 2 + groups = np.array([0] * n + [1] * (len(X) - n)) + + net = net_cls( + _RoutingKwargsModule, max_epochs=1, batch_size=50, + train_split=ValidSplit(GroupKFold(2)), + ) + # No set_fit_request(groups=True) needed — GroupKFold + # declares it needs groups, and the router picks that up. + net.fit(X, y, groups=groups) + + def test_undeclared_param_rejected_by_routing( + self, net_cls, module_cls, routing_enabled + ): + """Passing metadata that no consumer requested raises.""" + X = np.zeros((10, 20), dtype='float32') + y = np.zeros(10, dtype='int64') + + net = net_cls(module_cls, max_epochs=1, train_split=None) + with pytest.raises(TypeError, match="unexpected argument"): + net.fit(X, y, unknown_param=1) + + def test_clone_preserves_metadata_request( + self, net_cls, module_cls, routing_enabled + ): + """sklearn.clone preserves metadata routing requests set via + set_fit_request.""" + net = net_cls(module_cls) + net.set_fit_request(Z=True) + net_cloned = clone(net) + + # Verify by behavior: cloned net should accept Z without error + net_cloned.set_fit_request(Z=True) # should not raise + + def test_pipeline_with_routing_enabled( + self, net_cls, data, routing_enabled + ): + """NeuralNet works inside a Pipeline when routing is enabled.""" + from skorch.toy import MLPModule + + X, y = data + net = net_cls(MLPModule, max_epochs=1, batch_size=50, train_split=None) + pipe = Pipeline([('scale', StandardScaler()), ('net', net)]) + pipe.fit(X[:100], y[:100]) + + def test_grid_search_with_routing( + self, net_cls, module_cls, data, routing_enabled + ): + """GridSearchCV works with metadata routing enabled.""" + X, y = data + net = net_cls( + module_cls, max_epochs=1, batch_size=50, train_split=None, + ) + gs = GridSearchCV( + net, param_grid={'lr': [0.01, 0.1]}, + cv=2, refit=False, n_jobs=1, + ) + gs.fit(X[:100], y[:100]) + + def test_backward_compat_fit_params_to_train_split( + self, net_cls, data, recorded_fit_params + ): + """Without routing enabled, all fit_params still reach + train_split (legacy behavior).""" + X, y = data + + def recording_split(dataset, y=None, **fit_params): + recorded_fit_params.append(fit_params) + return dataset, dataset + + net = net_cls( + _RoutingKwargsModule, max_epochs=1, batch_size=50, + train_split=recording_split, + ) + net.initialize() + net.callbacks_ = [] + net.fit(X[:100], y[:100], foo=1, bar=2) + + assert len(recorded_fit_params) == 1 + assert recorded_fit_params[0] == dict(foo=1, bar=2) + + def test_pipeline_with_dict_x_and_groups_routing( + self, data, routing_enabled + ): + """End-to-end: Pipeline scales part of a dict X, routes groups + to GroupKFold, and trains a module that uses auxiliary data.""" + from sklearn.model_selection import GroupKFold + from skorch.dataset import ValidSplit + from skorch import NeuralNetRegressor + + X, _ = data + X_arr = X[:100, :10].astype('float32') + Z_arr = 
+        Z_arr = X[:100, 10:].astype('float32')
+        y = X[:100, 0].astype('float32').reshape(-1, 1)
+        groups = np.array([0] * 50 + [1] * 50)
+
+        X_dict = {'X': X_arr, 'Z': Z_arr}
+
+        net = NeuralNetRegressor(
+            _RoutingRegressionModule, max_epochs=2, lr=0.01,
+            train_split=ValidSplit(GroupKFold(2)),
+        )
+        pipe = Pipeline([('scale', _RoutingDictScaler()), ('net', net)])
+        pipe.fit(X_dict, y, groups=groups)
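Finally, a condensed, runnable sketch of the end-to-end flow the Pipeline tests above exercise (again illustrative, not part of the patch): with routing enabled, `groups` passes through a `Pipeline` to the net's internal splitter.

    import numpy as np
    import sklearn
    from sklearn.model_selection import GroupKFold
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from skorch import NeuralNetClassifier
    from skorch.dataset import ValidSplit
    from skorch.toy import MLPModule

    class KwargsModule(MLPModule):
        def forward(self, X, **kwargs):
            return super().forward(X)

    X = np.random.randn(100, 20).astype('float32')
    y = np.tile([0, 1], 50).astype('int64')
    groups = np.repeat([0, 1], 50)

    with sklearn.config_context(enable_metadata_routing=True):
        net = NeuralNetClassifier(
            KwargsModule, max_epochs=1,
            train_split=ValidSplit(GroupKFold(2)),
        )
        pipe = Pipeline([('scale', StandardScaler()), ('net', net)])
        # `groups` flows Pipeline -> net -> GroupKFold; StandardScaler
        # doesn't request it, so routing never sends it there.
        pipe.fit(X, y, groups=groups)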