
ENH NeuralNet metadata routing#1139

Merged
BenjaminBossan merged 8 commits into skorch-dev:master from adrinjalali:nn-routing-router
Apr 27, 2026
Conversation

@adrinjalali
Contributor

Implement sklearn metadata routing (router + consumer)

Closes #1095

Summary

Implements sklearn's metadata routing protocol for NeuralNet. NeuralNet is both a consumer (its module's forward method accepts arbitrary metadata) and a router (it routes metadata like groups to its internal CV splitter).

This follows the same pattern as sklearn's CalibratedClassifierCV: get_metadata_routing() returns a MetadataRouter with add_self_request(self) for self-consumed params plus add(splitter=..., method_mapping=fit→split) for the CV splitter child.

All behavior is gated on sklearn.set_config(enable_metadata_routing=True). When disabled, behavior is identical to before.

What this enables

When enable_metadata_routing=True, metadata flows correctly through sklearn meta-estimators (Pipeline, GridSearchCV, cross_validate) to NeuralNet and its internal components.

Routing groups through a Pipeline to GroupKFold:

import numpy as np
import sklearn
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.datasets import make_regression
from sklearn.model_selection import GroupKFold
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from torch import nn

from skorch.dataset import ValidSplit
from skorch.regressor import NeuralNetRegressor

sklearn.set_config(enable_metadata_routing=True)


class RegressionModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, X, Z, **kwargs):
        return self.fc(X + Z)


class DictScaler(TransformerMixin, BaseEstimator):
    """Scale only the 'X' key of a dict input, pass 'Z' through."""

    def __init__(self):
        self.scaler_ = StandardScaler()

    def fit(self, X, y=None):
        self.scaler_.fit(X["X"])
        return self

    def transform(self, X):
        result = dict(X)
        result["X"] = self.scaler_.transform(X["X"]).astype("float32")
        return result


X, y = make_regression(n_samples=1000, n_features=20, random_state=0)
X, Z = X[:, :10].astype("float32"), X[:, 10:].astype("float32")
y = y.astype("float32").reshape(-1, 1)
groups = np.array([0] * 500 + [1] * 500)

# Per-sample auxiliary data (Z) is packed into a dict so the
# DataLoader batches it alongside X. Metadata routing handles
# getting `groups` through the Pipeline to the GroupKFold splitter.
X_dict = {"X": X, "Z": Z}

net = NeuralNetRegressor(
    RegressionModule,
    max_epochs=3,
    lr=0.01,
    train_split=ValidSplit(GroupKFold(2)),
)

pipe = make_pipeline(DictScaler(), net)
pipe.fit(X_dict, y, groups=groups)

Changes

skorch/net.py:

  • get_metadata_routing() — returns a MetadataRouter (not MetadataRequest). Registers self as consumer via add_self_request and the inner CV splitter (from ValidSplit.cv) as a child. This is what makes NeuralNet a router.
  • set_fit_request() / set_partial_fit_request() — custom implementations, since fit(**fit_params) accepts **kwargs and sklearn cannot auto-generate these methods. Follows the _BaseScorer.set_score_request pattern. Uses the __metadata_request__partial_fit class attribute to suppress conflicting auto-generation (the TODO references scikit-learn/scikit-learn#32111, "FIX (SLEP6) descriptor shouldn't override method").
  • _get_metadata_request() — overrides parent to use class name string as owner instead of the instance. Needed because NeuralNet.__getstate__ uses pickle.dump for CUDA-dependent attributes, which causes deepcopy to block when the module isn't picklable. An alternative would be implementing a full __deepcopy__.
  • __deepcopy__() — bypasses __getstate__/__setstate__ (which use pickle) by deepcopying each attribute individually with a fallback to shallow copy. Needed because sklearn's routing infrastructure calls deepcopy on MetadataRouter objects that hold references to the NeuralNet instance (e.g. via Pipeline's get_metadata_routing).
  • _get_splitter_for_routing() — extracts the CV splitter from ValidSplit.cv when it supports routing (e.g. GroupKFold). Returns None for int/float cv or custom callables.
  • fit_loop() — when routing is enabled, calls process_routing() to validate and route metadata. Extracts split_params for get_split_datasets from routed params. Following sklearn's router pattern, forward_params are the full fit_params (the router uses its own params directly).
  • _get_param_names() — excludes _metadata_request (set by set_fit_request, not __init__; sklearn's clone handles it separately via deepcopy).
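
The attribute-wise deepcopy with fallback described in the __deepcopy__() bullet above can be sketched like this. It is a minimal illustration of the idea, not the actual skorch implementation:

```python
# Sketch: deepcopy each attribute individually, falling back to a shallow
# copy (or the original reference) for objects that cannot be deep-copied,
# e.g. torch modules built from locally defined classes.
import copy


class SafeDeepcopy:
    def __deepcopy__(self, memo):
        cls = self.__class__
        new = cls.__new__(cls)
        memo[id(self)] = new  # register early to handle reference cycles
        for name, value in self.__dict__.items():
            try:
                setattr(new, name, copy.deepcopy(value, memo))
            except Exception:
                try:
                    setattr(new, name, copy.copy(value))
                except Exception:
                    setattr(new, name, value)  # keep the original reference
        return new
```

This bypasses the pickle-based __getstate__/__setstate__ path entirely, which is the point: sklearn's routing infrastructure deepcopies MetadataRouter objects that hold references to the net.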

What this does NOT cover (future PRs)

  • Predict-path metadata (set_predict_request / set_predict_proba_request)
  • Routing metadata to scoring callbacks (e.g. sample_weight to EpochScoring)

Test plan

  • set_fit_request / set_partial_fit_request guards and behavior
  • get_metadata_routing returns MetadataRouter with correct children
  • Splitter child present for GroupKFold, absent for int cv and None
  • Extra fit_params reach module forward with routing enabled
  • Extra fit_params don't break ValidSplit with routing enabled
  • groups routed to ValidSplit(GroupKFold) automatically (no set_fit_request needed)
  • clone() preserves metadata requests
  • Pipeline integration with routing enabled
  • GridSearchCV integration with routing enabled
  • Backward compat: legacy fit_params behavior when routing disabled
  • Full existing test suite passes (283 tests, 0 regressions)

Disclaimer: the code is Claude-generated, but I've reviewed it. This is the second iteration of the solution, which I ended up liking, and it matches what we do in sklearn.

cc @BenjaminBossan @tsbinns @DCoupry

@adrinjalali adrinjalali changed the title Nn routing router ENH NeuralNet metadata routing Apr 20, 2026
Collaborator

@BenjaminBossan BenjaminBossan left a comment


Thanks for adding support for metadata routing to skorch. Overall, this looks good, but I have a few comments; please check.

We probably can add a Mixin in scikit-learn so that you could avoid using the private type of API call.

So I assume this hasn't happened (yet)?

Comment thread .gitignore Outdated
Comment thread skorch/net.py Outdated
break
setattr(self, 'callbacks_', callbacks_new)

def __deepcopy__(self, memo):
Collaborator


So I checked and prior to sklearn 1.8.0, the tests would all pass without this custom __deepcopy__ method. With 1.8.0, we get:

AttributeError: Can't pickle local object 'TestMetadataRouting.test_fit_with_extra_params_and_routing.<locals>.RecordingModule'

This seems to stem from RecordingModule being defined locally in the test. When I move it to the global scope, the test passes even with sklearn 1.8.0 (other tests with local classes will fail but they all pass when making them global). So it appears that the error comes from a recent change in sklearn that makes something in metadata routing incompatible with locally defined classes.

To me, it would be okay to have a __deepcopy__ method here but it seems there is something bigger going on here.
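
For context, the failure mode above is easy to reproduce outside skorch: instances of classes defined inside a function cannot be pickled, so any copy path that falls back to pickle breaks on them. A minimal illustration (`make_module` and `RecordingModule` are stand-in names):

```python
# Pickle stores classes by qualified name; a class defined inside a
# function ('make_module.<locals>.RecordingModule') cannot be looked up
# at load time, so pickling its instances fails.
import pickle


def make_module():
    class RecordingModule:  # defined inside a function
        pass
    return RecordingModule()


obj = make_module()
try:
    pickle.dumps(obj)
    failed = False
except Exception:
    failed = True
print(failed)  # True: the instance cannot be pickled
```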

Contributor Author


fix on sklearn side: scikit-learn/scikit-learn#33827

Comment thread skorch/net.py
Comment thread skorch/net.py
MethodMapping,
_routing_enabled,
process_routing,
)
Collaborator


I wasn't sure how far back these imports reach, but I checked with sklearn 1.4.0 and the tests passed.

Comment thread skorch/net.py Outdated
# each attribute individually, falling back to shallow copy
# for non-copyable objects (e.g. torch modules with locally
# defined classes).
import copy
Collaborator


Can be global import

Comment thread skorch/net.py
router = MetadataRouter(owner=self.__class__.__name__)
router.add_self_request(self)

ts = self.train_split
Collaborator


nice

Comment thread skorch/tests/test_net.py
@adrinjalali
Contributor Author

We probably can add a Mixin in scikit-learn so that you could avoid using the private type of API call.

So I assume this hasn't happened (yet)?

Yeah, I haven't managed to get it in. The metadata routing side of sklearn hasn't been very active, since we've been focusing on other projects, but that also means the issues haven't been too urgent, which is a good thing.

The __deepcopy__ issue, however, is real and I'd like to fix it in sklearn, we shouldn't be deepcopy-ing estimators there.

Collaborator

@BenjaminBossan BenjaminBossan left a comment


Thanks for your updates, Adrin, and for working on the fix in sklearn too. PR LGTM.

As it's not an urgent change, I'll leave it open for a week or so in case Thomas or someone else wants to review as well.

Comment thread skorch/net.py
if not self.initialized_:
self.initialize()

# When called from fit(), _routing_method is threaded in via
Collaborator


Ah, too bad that this is needed, but it is what it is.

@tsbinns

tsbinns commented Apr 22, 2026

As it's not an urgent change, I'll leave it open for a week or so in case Thomas or someone else wants to review as well.

Thanks both for the work!
Sure, could try in the next days to see if the workarounds I'm currently using are redundant now.

@tsbinns

tsbinns commented Apr 24, 2026

Sure, could try in the next days to see if the workarounds I'm currently using are redundant now.

Just had a go and no more workarounds needed. TYSM @adrinjalali & @BenjaminBossan!

@tsbinns

tsbinns commented Apr 24, 2026

Just so I can sort any env changes: is there a release planned soon, or will this live in main for a while?

@BenjaminBossan BenjaminBossan merged commit f5a7928 into skorch-dev:master Apr 27, 2026
16 checks passed
@BenjaminBossan
Collaborator

Just had a go and no more workarounds needed.

Thanks for testing.

Just so I can sort any env changes: is there a release planned soon, or will this live in main for a while?

I think we can do a release soon. It depends a bit on co-maintainer availability, so I can't promise.

