ENH NeuralNet metadata routing #1139
Conversation
BenjaminBossan
left a comment
Thanks for adding support for metadata routing to skorch. Overall, this looks good, but I have a few comments, please check.
We probably can add a Mixin in scikit-learn so that you could avoid using the private type of API call.
So I assume this hasn't happened (yet)?
```python
break
setattr(self, 'callbacks_', callbacks_new)

def __deepcopy__(self, memo):
```
So I checked, and prior to sklearn 1.8.0, the tests would all pass without this custom `__deepcopy__` method. With 1.8.0, we get:

```
AttributeError: Can't pickle local object 'TestMetadataRouting.test_fit_with_extra_params_and_routing..RecordingModule'
```

This seems to stem from `RecordingModule` being defined locally in the test. When I move it to the global scope, the test passes even with sklearn 1.8.0 (other tests with local classes also fail, but they all pass when made global). So it appears that the error comes from a recent change in sklearn that makes something in metadata routing incompatible with locally defined classes.

To me, it would be okay to have a `__deepcopy__` method here, but it seems there is something bigger going on.
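For anyone following along, the asymmetry described here can be reproduced with a few lines of stdlib Python (no skorch or torch needed): `deepcopy` handles instances of locally defined classes, while `pickle` cannot locate the class by qualified name. The `make_instance`/`RecordingModule` names below are illustrative, mirroring the test setup.

```python
import copy
import pickle

def make_instance():
    # Class defined locally inside a function, like RecordingModule
    # being defined inside the test method
    class RecordingModule:
        def __init__(self):
            self.calls = []
    return RecordingModule()

obj = make_instance()

# deepcopy does not need to import the class by name, so it works
copied = copy.deepcopy(obj)
assert copied is not obj

# pickle, however, must locate the class by its qualified name,
# which fails for locally defined classes
try:
    pickle.dumps(obj)
    failed = False
except (AttributeError, pickle.PicklingError):
    failed = True
print(failed)  # True
```

This is why moving the class to module scope makes the tests pass: the failure only appears once something in the routing code path pickles (rather than deep-copies) the object.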
fix on sklearn side: scikit-learn/scikit-learn#33827
```python
    MethodMapping,
    _routing_enabled,
    process_routing,
)
```
I wasn't sure how far back these imports reach, but I checked with sklearn 1.4.0 and the tests passed.
```python
# each attribute individually, falling back to shallow copy
# for non-copyable objects (e.g. torch modules with locally
# defined classes).
import copy
```
Can be global import
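For context, the attribute-by-attribute deepcopy with shallow-copy fallback that this hunk refers to might be sketched as follows; `FallbackDeepCopy` and its attributes are illustrative, not the PR's actual code:

```python
import copy

class FallbackDeepCopy:
    """Illustrative: deepcopy each attribute individually, falling
    back to a shallow copy for attributes that cannot be deep-copied."""

    def __init__(self, **attrs):
        self.__dict__.update(attrs)

    def __deepcopy__(self, memo):
        cls = self.__class__
        new = cls.__new__(cls)
        memo[id(self)] = new
        for key, value in self.__dict__.items():
            try:
                new.__dict__[key] = copy.deepcopy(value, memo)
            except Exception:
                # e.g. torch modules holding locally defined classes;
                # keep a shared reference instead of failing outright
                new.__dict__[key] = value
        return new

obj = FallbackDeepCopy(numbers=[1, 2, 3], name="net")
clone = copy.deepcopy(obj)
clone.numbers.append(4)
print(obj.numbers, clone.numbers)  # [1, 2, 3] [1, 2, 3, 4]
```

Registering `new` in `memo` before copying the attributes is what keeps cyclic references (such as a router holding a reference back to the estimator) from recursing forever.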
```python
router = MetadataRouter(owner=self.__class__.__name__)
router.add_self_request(self)

ts = self.train_split
```
Yeah, haven't managed to get it in. It's been not very busy on the metadata routing side of sklearn, since we've been focusing on other projects, but it also means the issues haven't been too urgent, which is a good thing.
BenjaminBossan
left a comment
Thanks for your updates, Adrin, and for working on the fix in sklearn too. PR LGTM.
As it's not an urgent change, I'll leave it open for a week or so in case Thomas or someone else wants to review as well.
```python
if not self.initialized_:
    self.initialize()

# When called from fit(), _routing_method is threaded in via
```
Ah, too bad that this is needed, but it is what it is.
Thanks both for the work!
Just had a go and no more workarounds needed. TYSM @adrinjalali & @BenjaminBossan!
Just so I can plan for any env changes: is there a release planned soon, or will this live in main for a while?
Thanks for testing.
I think we can do a release soon. It depends a bit on co-maintainer availability, so I can't promise.
Implement sklearn metadata routing (router + consumer)
Closes #1095
Summary
Implements sklearn's metadata routing protocol for `NeuralNet`. `NeuralNet` is both a consumer (its module's `forward` method accepts arbitrary metadata) and a router (it routes metadata like `groups` to its internal CV splitter). This follows the same pattern as sklearn's `CalibratedClassifierCV`: `get_metadata_routing()` returns a `MetadataRouter` with `add_self_request(self)` for self-consumed params plus `add(splitter=..., method_mapping=fit→split)` for the CV splitter child.

All behavior is gated on `sklearn.set_config(enable_metadata_routing=True)`. When disabled, behavior is identical to before.

What this enables
When `enable_metadata_routing=True`, metadata flows correctly through sklearn meta-estimators (`Pipeline`, `GridSearchCV`, `cross_validate`) to `NeuralNet` and its internal components, e.g. routing `groups` through a `Pipeline` to `GroupKFold`.

Changes
In `skorch/net.py`:

- `get_metadata_routing()` — returns a `MetadataRouter` (not a `MetadataRequest`). Registers self as a consumer via `add_self_request` and the inner CV splitter (from `ValidSplit.cv`) as a child. This is what makes `NeuralNet` a router.
- `set_fit_request()` / `set_partial_fit_request()` — custom implementations, since `fit(**fit_params)` uses `**kwargs` and sklearn can't auto-generate these. Follows the `_BaseScorer.set_score_request` pattern. Uses the `__metadata_request__partial_fit` class attribute to suppress conflicting auto-generation (TODO references scikit-learn/scikit-learn#32111, "FIX (SLEP6) descriptor shouldn't override method").
- `_get_metadata_request()` — overrides the parent to use the class name string as `owner` instead of the instance. Needed because `NeuralNet.__getstate__` uses `pickle.dump` for CUDA-dependent attributes, which causes `deepcopy` to block when the module isn't picklable. An alternative would be implementing a full `__deepcopy__`.
- `__deepcopy__()` — bypasses `__getstate__`/`__setstate__` (which use pickle) by deep-copying each attribute individually, with a fallback to shallow copy. Needed because sklearn's routing infrastructure calls `deepcopy` on `MetadataRouter` objects that hold references to the `NeuralNet` instance (e.g. via `Pipeline`'s `get_metadata_routing`).
- `_get_splitter_for_routing()` — extracts the CV splitter from `ValidSplit.cv` when it supports routing (e.g. `GroupKFold`). Returns `None` for int/float cv or custom callables.
- `fit_loop()` — when routing is enabled, calls `process_routing()` to validate and route metadata. Extracts `split_params` for `get_split_datasets` from the routed params. Following sklearn's router pattern, `forward_params` are the full `fit_params` (the router uses its own params directly).
- `_get_param_names()` — excludes `_metadata_request` (set by `set_fit_request`, not `__init__`; sklearn's `clone` handles it separately via `deepcopy`).

What this does NOT cover (future PRs)
- `predict`/`predict_proba` routing (`set_predict_request` / `set_predict_proba_request`)
- callback routing (`sample_weight` → `EpochScoring`)

Test plan
- `set_fit_request` / `set_partial_fit_request` guards and behavior
- `get_metadata_routing` returns a `MetadataRouter` with the correct children
- splitter child present for `GroupKFold`, absent for int cv and `None`
- `ValidSplit` with routing enabled
- `groups` routed to `ValidSplit(GroupKFold)` automatically (no `set_fit_request` needed)
- `clone()` preserves metadata requests

Disclaimer: The code is `claude` generated, but I've reviewed the code, and this is the second iteration of the solution, which I ended up liking, and it matches what we do in sklearn.

cc @BenjaminBossan @tsbinns @DCoupry