Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Extend tensor conversion to numpy arrays to work with more device types (#1132)
- Add sklearn metadata routing support: `NeuralNet` is now a metadata router + consumer, enabling `groups` and other metadata to flow through `Pipeline`/`GridSearchCV` (#1139)

### Changed

Expand Down
154 changes: 147 additions & 7 deletions skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.utils.metadata_routing import (
MetadataRouter,
MethodMapping,
UNUSED,
_routing_enabled,
process_routing,
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure how far back these imports reach, but I checked with sklearn 1.4.0 and the tests passed.

import torch
from torch.utils.data import DataLoader

Expand Down Expand Up @@ -309,6 +316,13 @@ class NeuralNet(BaseEstimator):
this list.

"""
# Suppress auto-generation of set_partial_fit_request that only
# allows 'classes'. We provide our own that accepts arbitrary
# metadata names, since partial_fit takes **fit_params.
# TODO: remove once scikit-learn/scikit-learn#32111 is merged and
# provides a public API for this.
__metadata_request__partial_fit = {"classes": UNUSED}

prefixes_ = ['iterator_train', 'iterator_valid', 'callbacks', 'dataset', 'compile']

cuda_dependent_attributes_ = []
Expand Down Expand Up @@ -1150,7 +1164,8 @@ def evaluation_step(self, batch, training=False):
self._set_training(training)
return self.infer(Xi)

def fit_loop(self, X, y=None, epochs=None, **fit_params):
def fit_loop(self, X, y=None, epochs=None, *, _routing_method="fit",
**fit_params):
"""The proper fit loop.

Contains the logic of what actually happens during the fit
Expand Down Expand Up @@ -1190,8 +1205,26 @@ def fit_loop(self, X, y=None, epochs=None, **fit_params):
self.check_training_readiness()
epochs = epochs if epochs is not None else self.max_epochs

if _routing_enabled():
# _routing_method matches the public entry point so the
# right self-request (fit vs partial_fit) is resolved.
routed_params = process_routing(
self, _routing_method, **fit_params)
split_params = routed_params.get(
"splitter", {"split": {}}
)["split"]
# Following sklearn's router+consumer pattern: the router
# uses its own params directly (they're already in
# fit_params), while children get theirs from
# routed_params. The module's forward method should accept
# **kwargs to handle any extra params.
forward_params = fit_params
else:
split_params = fit_params
forward_params = fit_params

dataset_train, dataset_valid = self.get_split_datasets(
X, y, **fit_params)
X, y, **split_params)
on_epoch_kwargs = {
'dataset_train': dataset_train,
'dataset_valid': dataset_valid,
Expand All @@ -1205,10 +1238,10 @@ def fit_loop(self, X, y=None, epochs=None, **fit_params):
self.notify('on_epoch_begin', **on_epoch_kwargs)

self.run_single_epoch(iterator_train, training=True, prefix="train",
step_fn=self.train_step, **fit_params)
step_fn=self.train_step, **forward_params)

self.run_single_epoch(iterator_valid, training=False, prefix="valid",
step_fn=self.validation_step, **fit_params)
step_fn=self.validation_step, **forward_params)

self.notify("on_epoch_end", **on_epoch_kwargs)
return self
Expand Down Expand Up @@ -1290,9 +1323,15 @@ def partial_fit(self, X, y=None, classes=None, **fit_params):
if not self.initialized_:
self.initialize()

# When called from fit(), _routing_method is threaded in via
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, too bad that this is needed, but it is what it is.

# fit_params so partial_fit's public signature stays clean
# (sklearn introspects it to auto-generate routing machinery).
routing_method = fit_params.pop("_routing_method", "partial_fit")

self.notify('on_train_begin', X=X, y=y)
try:
self.fit_loop(X, y, **fit_params)
self.fit_loop(
X, y, _routing_method=routing_method, **fit_params)
except KeyboardInterrupt:
pass
self.notify('on_train_end', X=X, y=y)
Expand Down Expand Up @@ -1333,9 +1372,105 @@ def fit(self, X, y=None, **fit_params):
if not self.warm_start or not self.initialized_:
self.initialize()

self.partial_fit(X, y, **fit_params)
self.partial_fit(X, y, _routing_method="fit", **fit_params)
return self

def set_fit_request(self, **kwargs):
"""Set requested parameters by the ``fit`` method.

Please see :ref:`sklearn:metadata_routing` for more details.

Since ``NeuralNet.fit`` accepts arbitrary ``**fit_params`` that
are passed to the module's forward method, metadata names cannot
be inferred from the signature and must be declared explicitly
using this method.

Parameters
----------
**kwargs : dict
Arguments should be of the form ``param_name=alias``, where
``alias`` can be one of ``{True, False, None, str}``.

Returns
-------
self : object
The updated object.
"""
if not _routing_enabled():
raise RuntimeError(
"This method is only available when metadata routing is"
" enabled. You can enable it using"
" sklearn.set_config(enable_metadata_routing=True)."
)

requests = self._get_metadata_request()
for param, alias in kwargs.items():
requests.fit.add_request(param=param, alias=alias)
self._metadata_request = requests
return self

def set_partial_fit_request(self, **kwargs):
"""Set requested parameters by the ``partial_fit`` method.

Please see :ref:`sklearn:metadata_routing` for more details.

Since ``NeuralNet.partial_fit`` accepts arbitrary ``**fit_params``
that are passed to the module's forward method, metadata names
cannot be inferred from the signature and must be declared
explicitly using this method.

Parameters
----------
**kwargs : dict
Arguments should be of the form ``param_name=alias``, where
``alias`` can be one of ``{True, False, None, str}``.

Returns
-------
self : object
The updated object.
"""
if not _routing_enabled():
raise RuntimeError(
"This method is only available when metadata routing is"
" enabled. You can enable it using"
" sklearn.set_config(enable_metadata_routing=True)."
)

requests = self._get_metadata_request()
for param, alias in kwargs.items():
requests.partial_fit.add_request(param=param, alias=alias)
self._metadata_request = requests
return self

def get_metadata_routing(self):
"""Get metadata routing of this object.

NeuralNet is both a consumer (its module's forward method
accepts arbitrary metadata) and a router (it routes metadata
like ``groups`` to its internal CV splitter).

Comment thread
BenjaminBossan marked this conversation as resolved.
Returns
-------
routing : MetadataRouter
A :class:`~sklearn.utils.metadata_routing.MetadataRouter`
encapsulating routing information.
"""
router = MetadataRouter(owner=self.__class__.__name__)
router.add_self_request(self)

ts = self.train_split
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

if ts is not None and hasattr(ts, 'cv'):
router.add(
splitter=ts.cv,
method_mapping=(
MethodMapping()
.add(caller="fit", callee="split")
.add(caller="partial_fit", callee="split")
),
)
return router

def check_is_fitted(self, attributes=None, *args, **kwargs):
"""Checks whether the net is initialized

Expand Down Expand Up @@ -1969,7 +2104,12 @@ def get_params_for_optimizer(self, prefix, named_parameters):
return args, kwargs

def _get_param_names(self):
return [k for k in self.__dict__ if not k.endswith('_')]
# Exclude _metadata_request: it is set by set_fit_request()
# (not __init__), and sklearn's clone() handles it separately
# via deepcopy. Including it here would cause clone() to pass
# it to the constructor, which doesn't expect it.
return [k for k in self.__dict__
if not k.endswith('_') and k != '_metadata_request']

def _get_params_callbacks(self, deep=True):
"""sklearn's .get_params checks for `hasattr(value,
Expand Down
Loading
Loading