From 351f5620e1691195f81a370bf14c4957e8290aaa Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 19 May 2026 14:40:14 +0200 Subject: [PATCH 01/34] feat: generalise dataset split definitions - make split definitions coordinate-explicit so splitting can be extended beyond time in future - support both fraction-based splits and explicit tuple-range splits such as ("2020-01-01T12:00", "2021-01-01T12:00") - replace the fixed train_ratio/val_ratio API with nested split config and dataset subset handling across data, config, and tests --- docs/config_diagram.svg | 360 +++++++++++----------- src/mlcast/config/__init__.py | 10 +- src/mlcast/config/base.py | 3 +- src/mlcast/config/fiddlers.py | 5 + src/mlcast/data/source_data_datamodule.py | 140 ++++++--- src/mlcast/data/source_data_datasets.py | 68 +++- src/mlcast/data/splits.py | 195 ++++++++++++ tests/data/test_data_module.py | 180 ++++++++--- tests/data/test_source_data_datasets.py | 19 +- tests/test_cli_training.py | 7 +- tests/test_readme_snippets.py | 3 +- 11 files changed, 701 insertions(+), 289 deletions(-) create mode 100644 src/mlcast/data/splits.py diff --git a/docs/config_diagram.svg b/docs/config_diagram.svg index b2ac4ac..102ee0a 100644 --- a/docs/config_diagram.svg +++ b/docs/config_diagram.svg @@ -4,11 +4,11 @@ - - + + %3 - + 2 @@ -35,60 +35,60 @@ 1 - - -Config: - NowcastLightningModule - - -network - - - - - -ensemble_size - -2 - - -loss_class - -'crps' - - -loss_params - - - -dict - - -'temporal_lambda' - -0.01 - - -masked_loss - -True - - -optimizer - - - - - -lr_scheduler - - - + + +Config: + NowcastLightningModule + + +network + + + + + +ensemble_size + +2 + + +loss_class + +'crps' + + +loss_params + + + +dict + + +'temporal_lambda' + +0.01 + + +masked_loss + +True + + +optimizer + + + + + +lr_scheduler + + + 1:c--2:c - + @@ -111,7 +111,7 @@ 1:c--3:c - + @@ -139,142 +139,156 @@ 1:c--4:c - + 0 - - -Config: - Experiment - - -pl_module - - - - - -data - - - - - -trainer - - - + + +Config: + Experiment + 
+ +pl_module + + + + + +data + + + + + +trainer + + + 0:c--1:c - + 5 - - -Config: - SourceDataDataModule - - -dataset_factory - - - - - -train_ratio - -0.7 - - -val_ratio - -0.15 - - -batch_size - -16 - - -num_workers - -8 - - -pin_memory - -True + + +Config: + SourceDataDataModule + + +dataset_factory + + + + + +splits + + + +dict + + +'time' + + + +dict + + +'train' + +0.7 + + +'val' + +0.15 + + +batch_size + +16 + + +num_workers + +8 + + +pin_memory + +True 0:c--5:c - + 7 - - -Config: - Trainer - - -accelerator - -'auto' - - -logger - - - - - -callbacks - - - -list - - - - - - - - - - - - - - -0 - - -1 - - -2 - - -3 - - -max_epochs - -100 + + +Config: + Trainer + + +accelerator + +'auto' + + +logger + + + + + +callbacks + + + +list + + + + + + + + + + + + + + +0 + + +1 + + +2 + + +3 + + +max_epochs + +100 0:c--7:c - + @@ -329,7 +343,7 @@ 5:c--6:c - + @@ -352,7 +366,7 @@ 7:c--8:c - + @@ -380,7 +394,7 @@ 7:c--9:c - + @@ -408,7 +422,7 @@ 7:c--10:c - + @@ -436,7 +450,7 @@ 7:c--11:c - + @@ -454,7 +468,7 @@ 7:c--12:c - + diff --git a/src/mlcast/config/__init__.py b/src/mlcast/config/__init__.py index 22acef1..194ad6b 100644 --- a/src/mlcast/config/__init__.py +++ b/src/mlcast/config/__init__.py @@ -6,7 +6,14 @@ from .base import Experiment, training_experiment from .consistency_checks import validate_config -from .fiddlers import set_variables, toggle_masking, use_anon_s3_dataset, use_mlflow_logger, use_random_sampler +from .fiddlers import ( + set_variables, + toggle_masking, + use_anon_s3_dataset, + use_mlflow_logger, + use_random_sampler, + use_ratio_splits, +) from .loader import load_yaml_config from .orchestrator import train_from_config @@ -19,6 +26,7 @@ "set_variables", "toggle_masking", "use_random_sampler", + "use_ratio_splits", "use_mlflow_logger", "use_anon_s3_dataset", ] diff --git a/src/mlcast/config/base.py b/src/mlcast/config/base.py index 34a38d6..64678d1 100644 --- a/src/mlcast/config/base.py +++ b/src/mlcast/config/base.py @@ -75,8 +75,7 @@ def 
training_experiment() -> Experiment: data = SourceDataDataModule( dataset_factory=dataset_factory, - train_ratio=0.70, - val_ratio=0.15, + splits={"time": {"train": 0.70, "val": 0.15}}, batch_size=16, num_workers=8, pin_memory=True, diff --git a/src/mlcast/config/fiddlers.py b/src/mlcast/config/fiddlers.py index 52584c9..2f45f0a 100644 --- a/src/mlcast/config/fiddlers.py +++ b/src/mlcast/config/fiddlers.py @@ -86,6 +86,11 @@ def use_random_sampler(cfg: fdl.Config) -> None: ) +def use_ratio_splits(cfg: fdl.Config, train: float, val: float) -> None: + """Fiddler to set ratio-based train/val splits on the data module.""" + cfg.data.splits = {"time": {"train": train, "val": val}} + + def use_anon_s3_dataset(cfg: fdl.Buildable, zarr_path: str, endpoint_url: str) -> None: """Configure the dataset factory to read anonymously from an S3 object store. diff --git a/src/mlcast/data/source_data_datamodule.py b/src/mlcast/data/source_data_datamodule.py index a327f90..afe253e 100644 --- a/src/mlcast/data/source_data_datamodule.py +++ b/src/mlcast/data/source_data_datamodule.py @@ -7,10 +7,18 @@ from collections.abc import Callable from typing import Any +import fiddle as fdl import pytorch_lightning as pl import xarray as xr from torch.utils.data import DataLoader, Dataset +from mlcast.data.splits import ( + compute_split_ranges_from_splitting_ratios, + splitting_uses_fractions, + splitting_uses_tuple_ranges, + validate_splits, +) + class SourceDataDataModule(pl.LightningDataModule): """PyTorch Lightning data module for spatio-temporal datasets. @@ -22,11 +30,12 @@ class SourceDataDataModule(pl.LightningDataModule): ---------- dataset_factory : Callable[..., Dataset] A factory function (e.g., ``fdl.Partial``) that returns a Dataset instance. - It must accept ``time_slice`` and ``augment`` as keyword arguments. - train_ratio : float, optional - Fraction of data used for training. Default is ``0.7``. - val_ratio : float, optional - Fraction of data used for validation. 
Default is ``0.15``. + It must accept ``subset`` and ``augment`` as keyword arguments. + splits : dict of {str: dict}, optional + Nested mapping ``{coord: {split_name: value, ...}, ...}`` describing + train/val/test subsets. Currently only the ``time`` coordinate is + supported. Ratio mode uses float fractions, while datetime mode uses + inclusive ``(start, end)`` ISO 8601 string tuples. **dataloader_kwargs : Any Additional keyword arguments forwarded to ``DataLoader`` (e.g., ``batch_size``, ``num_workers``, ``pin_memory``). @@ -35,51 +44,73 @@ class SourceDataDataModule(pl.LightningDataModule): def __init__( self, dataset_factory: Callable[..., Dataset], - train_ratio: float = 0.7, - val_ratio: float = 0.15, + splits: dict[str, dict[str, Any]] | None = None, **dataloader_kwargs: Any, ) -> None: super().__init__() self.dataset_factory = dataset_factory - self.train_ratio = train_ratio - self.val_ratio = val_ratio + self.splits = splits if splits is not None else {"time": {"train": 0.70, "val": 0.15}} self.dataloader_kwargs = dataloader_kwargs + validate_splits(self.splits) def setup(self, stage: str | None = None) -> None: """Create train, validation, and test datasets. - Splits are chronological based on the total number of timesteps in - the Zarr store, determined by opening it directly before instantiating - any Dataset objects. + Splits are assembled into per-dataset ``subset`` dictionaries. + Datetime-mode splits are passed through unchanged, while ratio-mode + splits are first resolved against the zarr coordinate values and then + converted to inclusive coordinate ranges before dataset instantiation. + + Parameters + ---------- + stage : str | None, optional + Lightning stage hint such as ``"fit"`` or ``"test"``. The value + is accepted for framework compatibility and is otherwise unused. """ - # We need the total number of timesteps to compute split boundaries. 
- # Duck-type the factory to extract zarr_path and storage_options — - # functools.partial stores kwargs in .keywords, fdl.Partial exposes - # them as attributes. - zarr_path = getattr(self.dataset_factory, "zarr_path", None) or self.dataset_factory.keywords["zarr_path"] - storage_options = getattr(self.dataset_factory, "storage_options", None) or self.dataset_factory.keywords.get( - "storage_options" - ) - n = xr.open_zarr(zarr_path, storage_options=storage_options).sizes["time"] - - train_end = int(n * self.train_ratio) - # Compute val_end independently from train_end rather than from the - # accumulated sum of ratios, to avoid floating-point truncation errors - # (e.g. int(240 * (0.5 + 1/3)) = int(199.999...) = 199 instead of 200). - val_end = train_end + int(n * self.val_ratio) - - self.train_dataset = self.dataset_factory( - time_slice=slice(0, train_end), - augment=True, - ) - self.val_dataset = self.dataset_factory( - time_slice=slice(train_end, val_end), - augment=False, - ) - self.test_dataset = self.dataset_factory( - time_slice=slice(val_end, n), - augment=False, - ) + subset_per_split: dict[str, dict[str, Any] | None] = { + "train": {}, + "val": {}, + "test": {}, + } + + for coord, coord_splits in self.splits.items(): + if splitting_uses_tuple_ranges(coord_splits): + # tuple-based splits are expected to present the start and end + # of each split, and so are passed through directly as the + # subset values for each split + coord_values_per_split: dict[str, tuple[str, str] | None] = { + "train": coord_splits["train"], + "val": coord_splits["val"], + "test": coord_splits.get("test"), + } + elif splitting_uses_fractions(coord_splits): + # for ratio-based splits, the splitting start-end range tuples + # are constructed by breaking up the given coordinate in + # successive segments (the succession is defined from the order + # of the keys in the splits dict) + coord_values_per_split = compute_split_ranges_from_splitting_ratios( + self.dataset_factory, coord, 
coord_splits + ) + else: + raise NotImplementedError(f"Unsupported split mode for coordinate {coord!r}: {coord_splits!r}") + + for split_name, split_val in coord_values_per_split.items(): + if split_val is None: + subset_per_split[split_name] = None + elif subset_per_split[split_name] is not None: + subset_per_split[split_name][coord] = split_val + + augment_flags = {"train": True, "val": False, "test": False} + for split in ("train", "val", "test"): + subset = subset_per_split[split] + if subset is None: + setattr(self, f"{split}_dataset", None) + else: + setattr( + self, + f"{split}_dataset", + self.dataset_factory(subset=subset, augment=augment_flags[split]), + ) def train_dataloader(self) -> DataLoader: """Return the training DataLoader.""" @@ -92,3 +123,32 @@ def val_dataloader(self) -> DataLoader: def test_dataloader(self) -> DataLoader: """Return the test DataLoader.""" return DataLoader(self.test_dataset, shuffle=False, **self.dataloader_kwargs) + + +def count_split_samples(cfg: fdl.Config) -> dict[str, Any]: + """Return dataset counts plus time extent for a built experiment config.""" + data_module: SourceDataDataModule = fdl.build(cfg.data) + + zarr_path = ( + getattr(data_module.dataset_factory, "zarr_path", None) or data_module.dataset_factory.keywords["zarr_path"] + ) + storage_options = getattr( + data_module.dataset_factory, "storage_options", None + ) or data_module.dataset_factory.keywords.get("storage_options") + ds = xr.open_zarr(zarr_path, storage_options=storage_options) + time_values = ds.indexes["time"] + + data_module.setup() + counts: dict[str, int] = {} + for split in ("train", "val", "test"): + dataset = getattr(data_module, f"{split}_dataset", None) + if dataset is not None: + counts[split] = len(dataset) + + return { + "samples": counts, + "zarr_tmin": str(time_values[0]), + "zarr_tmax": str(time_values[-1]), + "zarr_nsteps": len(time_values), + "splits": data_module.splits, + } diff --git a/src/mlcast/data/source_data_datasets.py 
b/src/mlcast/data/source_data_datasets.py index 643c82b..a838a7c 100644 --- a/src/mlcast/data/source_data_datasets.py +++ b/src/mlcast/data/source_data_datasets.py @@ -20,6 +20,24 @@ from mlcast.data.normalization import NORMALIZATION_REGISTRY +def _time_range_to_index_slice( + zarr_path: str, + time_range: tuple[str, str], + storage_options: dict[str, Any] | None = None, +) -> slice: + """Convert an inclusive ISO time range into a zarr integer slice.""" + ds = xr.open_zarr(zarr_path, storage_options=storage_options) + time_values = ds.indexes["time"] + t_start = time_values.get_indexer([pd.Timestamp(time_range[0])], method="bfill")[0] + t_end = time_values.get_indexer([pd.Timestamp(time_range[1])], method="ffill")[0] + if t_start < 0 or t_end < 0: + raise ValueError( + f"time_range {time_range!r} falls entirely outside the zarr time coordinate " + f"({time_values[0]} - {time_values[-1]})." + ) + return slice(int(t_start), int(t_end) + 1) + + class DatasetSample(TypedDict, total=False): """Typed dictionary returned by dataset ``__getitem__``. @@ -181,8 +199,8 @@ def ds(self) -> xr.Dataset: """ if self._ds is None: ds = xr.open_zarr(self._zarr_path, storage_options=self.storage_options) - if self._time_slice is not None: - ds = ds.isel(time=self._time_slice) + if self._time_index_slice is not None: + ds = ds.isel(time=self._time_index_slice) self._ds = ds return self._ds @@ -318,8 +336,9 @@ class SourceDataPrecomputedSamplingDataset(SourceDataDatasetBase): If ``True``, use a fixed random seed (42) for reproducibility. Default is ``False``. augment : bool, optional If ``True``, apply random spatial augmentations (rotation, flips). Default is ``False``. - time_slice : slice or None, optional - Subset of row indices to use from the CSV for train/val splitting. + subset : dict or None, optional + Coordinate subsetting specification. Only ``{"time": (start, end)}`` + is supported, where the time range is inclusive and uses ISO strings. 
width : int, optional Spatial width of each crop. Default is ``256``. height : int, optional @@ -338,13 +357,23 @@ def __init__( return_mask: bool = False, deterministic: bool = False, augment: bool = False, - time_slice: slice | None = None, + subset: dict[str, Any] | None = None, width: int = 256, height: int = 256, time_depth: int = 24, storage_options: dict[str, Any] | None = None, ) -> None: - self._time_slice: slice | None = None # required by base ds property before super().__init__ opens store + if subset: + for key in subset: + if key != "time": + raise NotImplementedError( + f"subset key {key!r} is not supported. Only 'time' subsetting is currently implemented." + ) + time_range: tuple[str, str] | None = (subset or {}).get("time") + if time_range is not None: + self._time_index_slice: slice | None = _time_range_to_index_slice(zarr_path, time_range, storage_options) + else: + self._time_index_slice = None super().__init__( zarr_path=zarr_path, standard_names=standard_names, @@ -359,8 +388,12 @@ def __init__( ) self.coords = pd.read_csv(csv_path).sort_values("t") - if time_slice is not None: - self.coords = self.coords.iloc[time_slice].reset_index(drop=True) + if self._time_index_slice is not None: + t_start = self._time_index_slice.start + t_stop = self._time_index_slice.stop + self.coords = self.coords[(self.coords["t"] >= t_start) & (self.coords["t"] < t_stop)].reset_index( + drop=True + ) self.dt = time_depth @@ -441,8 +474,9 @@ class SourceDataRandomSamplingDataset(SourceDataDatasetBase): If ``True``, use a fixed random seed (42) for reproducibility. Default is ``False``. augment : bool, optional If ``True``, apply random spatial augmentations (rotation, flips). Default is ``False``. - time_slice : slice or None, optional - Subset of time indices to use for train/val splitting. + subset : dict or None, optional + Coordinate subsetting specification. Only ``{"time": (start, end)}`` + is supported, where the time range is inclusive and uses ISO strings. 
width : int, optional Spatial width of each crop. Default is ``256``. height : int, optional @@ -462,14 +496,24 @@ def __init__( return_mask: bool = False, deterministic: bool = False, augment: bool = False, - time_slice: slice | None = None, + subset: dict[str, Any] | None = None, width: int = 256, height: int = 256, epoch_size: int = 1000, storage_options: dict[str, Any] | None = None, **kwargs: Any, ) -> None: - self._time_slice = time_slice # required by base ds property before super().__init__ opens store + if subset: + for key in subset: + if key != "time": + raise NotImplementedError( + f"subset key {key!r} is not supported. Only 'time' subsetting is currently implemented." + ) + time_range: tuple[str, str] | None = (subset or {}).get("time") + if time_range is not None: + self._time_index_slice: slice | None = _time_range_to_index_slice(zarr_path, time_range, storage_options) + else: + self._time_index_slice = None super().__init__( zarr_path=zarr_path, standard_names=standard_names, diff --git a/src/mlcast/data/splits.py b/src/mlcast/data/splits.py new file mode 100644 index 0000000..2e280a1 --- /dev/null +++ b/src/mlcast/data/splits.py @@ -0,0 +1,195 @@ +""" +Utilities to facilitate dataset splitting based on coordinate values, supporting two modes of specification: + +1. Fraction mode: each split is defined by a fraction of the total coordinate +range (e.g., 0.7 for training, 0.2 for validation, and the remainder for +testing). The fractions are resolved into coordinate value range tuples by inspecting +the coordinate values of the source dataset. + +2. Tuple-range mode: each split is defined by an explicit (start, end) tuple of +coordinate values (e.g., ("2020-01-01", "2020-12-31") for training). In this +mode, the split values are passed through directly as the subset configuration +for each split. 
+""" + +from collections.abc import Callable +from numbers import Real +from typing import Any + +import xarray as xr +from torch.utils.data import Dataset + +_SPLIT_NAMES = frozenset({"train", "val", "test"}) +_SUPPORTED_COORDS = frozenset({"time"}) + + +def splitting_uses_fractions(coord_splits: dict[str, Any]) -> bool: + """Return whether a coordinate split config uses fraction mode. + + Parameters + ---------- + coord_splits : dict[str, Any] + Split configuration for a single coordinate. + + Returns + ------- + bool + ``True`` when all defined split values use numeric fractions rather than + datetime tuples. + """ + return all( + split_val is None or (isinstance(split_val, Real) and not isinstance(split_val, bool)) + for split_val in coord_splits.values() + ) + + +def splitting_uses_tuple_ranges(coord_splits: dict[str, Any]) -> bool: + """Return whether a coordinate split config uses tuple-range mode. + + Parameters + ---------- + coord_splits : dict[str, Any] + Split configuration for a single coordinate. + + Returns + ------- + bool + ``True`` when all defined split values use ``(start, end)`` tuples + rather than numeric fractions. + """ + return all(split_val is None or isinstance(split_val, tuple) for split_val in coord_splits.values()) + + +def validate_splits(splits: dict[str, dict[str, Any]]) -> None: + """Validate the nested ``splits`` configuration for the data module. + + Validates that: + + - the configuration is not empty; + - each coordinate name is supported; + - each split name is one of ``train``, ``val``, or ``test``; + - each coordinate defines both ``train`` and ``val``; + - each coordinate uses exactly one supported split mode; + - fraction mode uses only float-like values for defined splits; + - fraction mode split values sum to at most ``1.0``; and + - tuple-range mode defines ``test`` explicitly as either a tuple or ``None``. 
+ + Parameters + ---------- + splits : dict[str, dict[str, Any]] + Nested mapping ``{coord: {split_name: value, ...}, ...}`` describing + dataset splits. + + Raises + ------ + ValueError + If the split configuration is empty, uses unsupported coordinates or + split names, omits required split entries, mixes split modes within a + coordinate, or provides invalid values for the selected mode. + """ + if not splits: + raise ValueError("splits must not be empty.") + + unknown_coords = set(splits) - _SUPPORTED_COORDS + if unknown_coords: + raise ValueError( + f"Unknown coordinate(s) in splits: {sorted(unknown_coords)}. Supported: {sorted(_SUPPORTED_COORDS)}." + ) + + for coord, coord_splits in splits.items(): + unknown_names = set(coord_splits) - _SPLIT_NAMES + if unknown_names: + raise ValueError( + f"Unknown split name(s) in splits[{coord!r}]: {sorted(unknown_names)}. " + f"Must be one of {sorted(_SPLIT_NAMES)}." + ) + + for required in ("train", "val"): + if required not in coord_splits: + raise ValueError(f"splits[{coord!r}] must contain '{required}'.") + + train_is_tuple = isinstance(coord_splits["train"], tuple) + val_is_tuple = isinstance(coord_splits["val"], tuple) + if train_is_tuple != val_is_tuple: + raise ValueError( + f"Cannot mix datetime tuples and float ratios in splits[{coord!r}]. " + "'train' and 'val' must both be floats (fraction mode) or both be " + "(start, end) tuples (tuple-range mode)." + ) + + if splitting_uses_fractions(coord_splits): + for split_name in ("train", "val"): + split_val = coord_splits[split_name] + if not isinstance(split_val, Real) or isinstance(split_val, bool): + raise ValueError( + f"In fraction mode splits[{coord!r}]['{split_name}'] must be float-like, got {split_val!r}." 
+ ) + ratio_sum = coord_splits["train"] + coord_splits["val"] + test_val = coord_splits.get("test") + if test_val is not None and (not isinstance(test_val, Real) or isinstance(test_val, bool)): + raise ValueError( + f"In fraction mode splits[{coord!r}]['test'] must be float-like or None, got {test_val!r}." + ) + if isinstance(test_val, Real) and not isinstance(test_val, bool): + ratio_sum += test_val + if ratio_sum > 1.0 + 1e-9: + raise ValueError(f"Split fractions in splits[{coord!r}] sum to {ratio_sum:.4f}, which exceeds 1.0.") + elif splitting_uses_tuple_ranges(coord_splits): + if "test" not in coord_splits: + raise ValueError( + f"In tuple-range mode splits[{coord!r}] must contain 'test' " + "(set to a (start, end) tuple or None to skip the test split)." + ) + test_val = coord_splits["test"] + if test_val is not None and not isinstance(test_val, tuple): + raise ValueError( + f"In tuple-range mode splits[{coord!r}]['test'] must be a " + f"(start, end) tuple or None, got {test_val!r}." + ) + else: + raise ValueError( + f"splits[{coord!r}] must use a single supported mode: all defined split values must be either " + "float-like fractions or tuple ranges." + ) + + +def compute_split_ranges_from_splitting_ratios( + dataset_factory: Callable[..., Dataset], + coord: str, + coord_splits: dict[str, Any], +) -> dict[str, tuple[str, str]]: + """Resolve fraction-mode splits into inclusive coordinate ranges. + + Parameters + ---------- + dataset_factory : Callable[..., Dataset] + Dataset factory carrying the ``zarr_path`` and optional + ``storage_options`` needed to open the source zarr store. + coord : str + Coordinate name to split. Currently this is expected to be ``"time"``. + coord_splits : dict[str, Any] + Fraction-mode split configuration for a single coordinate. Must contain + float values for ``"train"`` and ``"val"``. + + Returns + ------- + dict[str, tuple[str, str]] + Inclusive ``(start, end)`` coordinate ranges for the ``train``, + ``val``, and ``test`` splits. 
+ """ + zarr_path = getattr(dataset_factory, "zarr_path", None) or dataset_factory.keywords["zarr_path"] + storage_options = getattr(dataset_factory, "storage_options", None) or dataset_factory.keywords.get( + "storage_options" + ) + ds = xr.open_zarr(zarr_path, storage_options=storage_options) + coord_vals = ds.indexes[coord] + n = len(coord_vals) + + train_end = int(n * coord_splits["train"]) + val_end = train_end + int(n * coord_splits["val"]) + + return { + "train": (str(coord_vals[0]), str(coord_vals[train_end - 1])), + "val": (str(coord_vals[train_end]), str(coord_vals[val_end - 1])), + "test": (str(coord_vals[val_end]), str(coord_vals[n - 1])), + } diff --git a/tests/data/test_data_module.py b/tests/data/test_data_module.py index f3fce6d..3f58234 100644 --- a/tests/data/test_data_module.py +++ b/tests/data/test_data_module.py @@ -1,64 +1,134 @@ import functools from unittest.mock import MagicMock, patch +import pandas as pd import pytest from torch.utils.data import DataLoader, Dataset from mlcast.data.source_data_datamodule import SourceDataDataModule +from mlcast.data.splits import splitting_uses_fractions, splitting_uses_tuple_ranges, validate_splits class MockDataset(Dataset): - """Minimal dataset mock. + """Minimal dataset mock that records how it was constructed. - ``__len__`` returns the number of time steps covered by ``time_slice`` - so that dataloader batch-count assertions work correctly. + ``__len__`` returns a fixed size so that dataloader batch-count assertions + work correctly. 
""" - def __init__(self, zarr_path: str, time_slice: slice | None = None, augment: bool = False, **kwargs) -> None: + def __init__( + self, + zarr_path: str, + subset: dict | None = None, + augment: bool = False, + epoch_size: int = 100, + **kwargs, + ) -> None: self.zarr_path = zarr_path - self.time_slice = time_slice + self.subset = subset self.augment = augment + self.epoch_size = epoch_size self.kwargs = kwargs def __len__(self) -> int: - if self.time_slice is not None: - return self.time_slice.stop - self.time_slice.start - return 0 + return self.epoch_size def __getitem__(self, idx: int) -> dict: return {"data": idx} -def _mock_zarr(n_time: int) -> MagicMock: - """Return a mock xr.Dataset with a given time dimension size.""" +def _mock_zarr(time_index: pd.DatetimeIndex) -> MagicMock: + """Return a mock xr.Dataset with a given pandas DatetimeIndex for time.""" mock_ds = MagicMock() - mock_ds.sizes = {"time": n_time} + mock_ds.indexes = {"time": time_index} return mock_ds -def test_data_module_splits() -> None: - """Test DataModule chronological split boundaries. 
+def _make_time_index(n: int, start: str = "2016-01-01", freq: str = "10min") -> pd.DatetimeIndex: + return pd.date_range(start=start, periods=n, freq=freq) - Uses 100 time steps, train_ratio=0.5, val_ratio=0.2: - train_end = int(100 * 0.5) = 50 - val_end = 50 + int(100 * 0.2) = 70 - test = 70 to 100 - """ + +def test_validate_splits_ratio_mode() -> None: + validate_splits({"time": {"train": 0.7, "val": 0.15}}) + validate_splits({"time": {"train": 0.7, "val": 0.15, "test": 0.15}}) + validate_splits({"time": {"train": 0.7, "val": 0.15, "test": None}}) + + +def test_validate_splits_datetime_mode() -> None: + validate_splits( + {"time": {"train": ("2016-01-01", "2021-12-31"), "val": ("2022-01-01", "2023-12-31"), "test": None}} + ) + + +def test_validate_splits_missing_train() -> None: + with pytest.raises(ValueError, match="must contain 'train'"): + validate_splits({"time": {"val": 0.2}}) + + +def test_validate_splits_ratio_exceeds_one() -> None: + with pytest.raises(ValueError, match="sum to"): + validate_splits({"time": {"train": 0.8, "val": 0.3}}) + + +def test_validate_splits_ratio_requires_float_like_values() -> None: + with pytest.raises(ValueError, match="float-like"): + validate_splits({"time": {"train": "0.8", "val": 0.2}}) + + with pytest.raises(ValueError, match="float-like"): + validate_splits({"time": {"train": 0.8, "val": 0.1, "test": "0.1"}}) + + +def test_validate_splits_mixed_mode() -> None: + with pytest.raises(ValueError, match="mix"): + validate_splits({"time": {"train": 0.7, "val": ("2022-01-01", "2023-12-31")}}) + + +def test_validate_splits_datetime_missing_test() -> None: + with pytest.raises(ValueError, match="must contain 'test'"): + validate_splits({"time": {"train": ("2016-01-01", "2021-12-31"), "val": ("2022-01-01", "2023-12-31")}}) + + +def test_validate_splits_unknown_coord() -> None: + with pytest.raises(ValueError, match="Unknown coordinate"): + validate_splits({"space": {"train": 0.7, "val": 0.2}}) + + +def 
test_splitting_mode_helpers_require_consistent_values() -> None: + assert splitting_uses_fractions({"train": 0.7, "val": 0.2, "test": None}) + assert not splitting_uses_fractions({"train": 0.7, "val": ("2022-01-01", "2022-12-31")}) + assert splitting_uses_tuple_ranges( + {"train": ("2016-01-01", "2021-12-31"), "val": ("2022-01-01", "2023-12-31"), "test": None} + ) + assert not splitting_uses_tuple_ranges({"train": object(), "val": object()}) + + +def test_data_module_ratio_splits() -> None: + """DataModule ratio mode passes correct time subsets to the factory.""" + n = 100 + time_index = _make_time_index(n) dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr", foo="bar") - dm = SourceDataDataModule(dataset_factory=dataset_factory, train_ratio=0.5, val_ratio=0.2, batch_size=2) + dm = SourceDataDataModule( + dataset_factory=dataset_factory, splits={"time": {"train": 0.5, "val": 0.2}}, batch_size=2 + ) - with patch("mlcast.data.source_data_datamodule.xr.open_zarr", return_value=_mock_zarr(100)): + with patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)): dm.setup(stage="fit") - assert dm.train_dataset.time_slice == slice(0, 50) assert dm.train_dataset.augment is True assert dm.train_dataset.kwargs["foo"] == "bar" + train_start, train_end = dm.train_dataset.subset["time"] + val_start, val_end = dm.val_dataset.subset["time"] + test_start, test_end = dm.test_dataset.subset["time"] - assert dm.val_dataset.time_slice == slice(50, 70) - assert dm.val_dataset.augment is False + assert train_start == str(time_index[0]) + assert train_end == str(time_index[49]) + assert val_start == str(time_index[50]) + assert val_end == str(time_index[69]) + assert test_start == str(time_index[70]) + assert test_end == str(time_index[99]) - assert dm.test_dataset.time_slice == slice(70, 100) + assert dm.val_dataset.augment is False assert dm.test_dataset.augment is False train_dl = dm.train_dataloader() @@ -82,38 +152,54 @@ def __call__(self, 
**kwargs) -> Dataset: def test_data_module_split_lengths_and_batches() -> None: """Test that dataset lengths and dataloader batch counts are correct after splitting. - Uses 240 time steps with a 1/2, 1/3, 1/6 train/val/test split and - batch_size=10, chosen so all splits divide evenly and expected batch - counts are easy to verify without rounding. - - Split boundaries (computed independently per split to avoid float accumulation): - train_end = int(240 * 1/2) = 120 - val_end = 120 + int(240 * 1/3) = 120 + 80 = 200 - test = 240 - 200 = 40 - - Expected dataset lengths and batch counts at batch_size=10: - train : 120 samples -> 12 batches - val : 80 samples -> 8 batches - test : 40 samples -> 4 batches + Dataloader batch counts are correct after splitting. """ n_time = 240 batch_size = 10 - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") + time_index = _make_time_index(n_time) + dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr", epoch_size=10) dm = SourceDataDataModule( dataset_factory=dataset_factory, - train_ratio=1 / 2, - val_ratio=1 / 3, + splits={"time": {"train": 1 / 2, "val": 1 / 3}}, batch_size=batch_size, ) - with patch("mlcast.data.source_data_datamodule.xr.open_zarr", return_value=_mock_zarr(n_time)): + with patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)): dm.setup() - assert len(dm.train_dataset) == 120, "train split should cover timesteps 0–120" - assert len(dm.val_dataset) == 80, "val split should cover timesteps 120–200" - assert len(dm.test_dataset) == 40, "test split should cover timesteps 200–240" + assert len(dm.train_dataloader()) == 1 + assert len(dm.val_dataloader()) == 1 + assert len(dm.test_dataloader()) == 1 + + +def test_data_module_datetime_splits() -> None: + dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") + + dm = SourceDataDataModule( + dataset_factory=dataset_factory, + splits={ + "time": { + "train": ("2016-01-01", "2021-12-31"), + "val": 
("2022-01-01", "2023-12-31"), + "test": None, + } + }, + batch_size=4, + ) + + dm.setup() + + assert dm.train_dataset.subset == {"time": ("2016-01-01", "2021-12-31")} + assert dm.val_dataset.subset == {"time": ("2022-01-01", "2023-12-31")} + assert dm.test_dataset is None - assert len(dm.train_dataloader()) == 12, "120 samples / batch_size 10 = 12 batches" - assert len(dm.val_dataloader()) == 8, "80 samples / batch_size 10 = 8 batches" - assert len(dm.test_dataloader()) == 4, "40 samples / batch_size 10 = 4 batches" + +def test_data_module_unsupported_split_mode() -> None: + dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") + dm = SourceDataDataModule(dataset_factory=dataset_factory) + + dm.splits = {"time": {"train": object(), "val": object()}} + + with pytest.raises(NotImplementedError, match="Unsupported split mode"): + dm.setup() diff --git a/tests/data/test_source_data_datasets.py b/tests/data/test_source_data_datasets.py index 25eeb64..6b4f56c 100644 --- a/tests/data/test_source_data_datasets.py +++ b/tests/data/test_source_data_datasets.py @@ -3,6 +3,7 @@ import pandas as pd import pytest import torch +import xarray as xr from mlcast.data.source_data_datasets import ( SourceDataPrecomputedSamplingDataset, @@ -59,15 +60,17 @@ def test_precomputed_sampling_dataset(fp_test_dataset: Path, mock_csv: str) -> N assert isinstance(target_mask_t, torch.Tensor) -def test_precomputed_sampling_dataset_time_slice(fp_test_dataset: Path, mock_csv: str) -> None: - """Test that time_slice correctly slices the CSV.""" +def test_precomputed_sampling_dataset_time_subset(fp_test_dataset: Path, mock_csv: str) -> None: + """Test that subset correctly filters CSV rows by time range.""" + zarr_ds = xr.open_zarr(str(fp_test_dataset)) + time_index = zarr_ds.indexes["time"] ds = SourceDataPrecomputedSamplingDataset( zarr_path=str(fp_test_dataset), csv_path=mock_csv, standard_names=["rainfall_flux"], input_steps=2, forecast_steps=1, - time_slice=slice(0, 2), + 
subset={"time": (str(time_index[0]), str(time_index[8]))}, ) assert len(ds) == 2 @@ -115,18 +118,20 @@ def test_random_sampling_dataset(fp_test_dataset: Path) -> None: assert target_mask_t.shape == (forecast_steps, 1, 32, 32) -def test_random_sampling_dataset_time_slice(fp_test_dataset: Path) -> None: - """Test that time_slice correctly slices the Zarr store.""" +def test_random_sampling_dataset_time_subset(fp_test_dataset: Path) -> None: + """Test that subset correctly slices the Zarr store.""" + zarr_ds = xr.open_zarr(str(fp_test_dataset)) + time_index = zarr_ds.indexes["time"] ds = SourceDataRandomSamplingDataset( zarr_path=str(fp_test_dataset), standard_names=["rainfall_flux"], input_steps=3, forecast_steps=2, - time_slice=slice(0, 50), + subset={"time": (str(time_index[0]), str(time_index[49]))}, epoch_size=10, ) - assert ds.max_t == 50 # Since it was sliced to 50 + assert ds.max_t == 50 assert len(ds) == 10 diff --git a/tests/test_cli_training.py b/tests/test_cli_training.py index 73440ad..82c60fb 100644 --- a/tests/test_cli_training.py +++ b/tests/test_cli_training.py @@ -27,9 +27,7 @@ def test_cli_train_command(fp_test_dataset: Path, tmp_path: Path) -> None: "--config", "set:data.dataset_factory.standard_names=['rainfall_flux']", "--config", - "set:data.train_ratio=0.4", - "--config", - "set:data.val_ratio=0.3", # Test ratio becomes 0.3 (30 steps > 18 required) + "set:data.splits={'time': {'train': 0.4, 'val': 0.3}}", "--config", "set:trainer.fast_dev_run=True", "--config", @@ -61,8 +59,7 @@ def test_cli_train_from_yaml_config(fp_test_dataset: Path, tmp_path: Path) -> No use_random_sampler(cfg) cfg.data.dataset_factory.standard_names = ["rainfall_flux"] cfg.data.dataset_factory.zarr_path = str(fp_test_dataset.absolute()) - cfg.data.train_ratio = 0.4 - cfg.data.val_ratio = 0.3 # test ratio becomes 0.3 (30 steps > 18 required) + cfg.data.splits = {"time": {"train": 0.4, "val": 0.3}} cfg.trainer.fast_dev_run = True cfg.data.batch_size = 1 cfg.data.num_workers 
= 0 diff --git a/tests/test_readme_snippets.py b/tests/test_readme_snippets.py index dd630c1..a6bba89 100644 --- a/tests/test_readme_snippets.py +++ b/tests/test_readme_snippets.py @@ -124,8 +124,7 @@ def _patch_cfg(cfg: fdl.Config, fp_dataset: Path, tmp_path: Path) -> None: set_variables(cfg, standard_names=["rainfall_flux"]) # Switch to the on-the-fly random sampler so no pre-computed CSV is needed. use_random_sampler(cfg) - cfg.data.train_ratio = 0.4 - cfg.data.val_ratio = 0.3 + cfg.data.splits = {"time": {"train": 0.4, "val": 0.3}} cfg.trainer.fast_dev_run = True cfg.data.batch_size = 1 cfg.data.num_workers = 0 From a5298998b54b1a58f9ce272be40a10dc675560bd Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 19 May 2026 15:36:28 +0200 Subject: [PATCH 02/34] fix: require explicit fraction-based test splits - only create a test dataset for fraction-mode splits when a test fraction is explicitly configured - warn when configured split fractions do not sum to 1.0 because any remainder will be unused - require SourceDataDataModule splits to be provided explicitly and update examples and tests accordingly --- docs/config_diagram.svg | 363 +++++++++++----------- src/mlcast/config/base.py | 2 +- src/mlcast/config/fiddlers.py | 4 +- src/mlcast/data/source_data_datamodule.py | 11 +- src/mlcast/data/splits.py | 27 +- tests/data/test_data_module.py | 52 +++- tests/test_cli_training.py | 4 +- tests/test_readme_snippets.py | 2 +- 8 files changed, 263 insertions(+), 202 deletions(-) diff --git a/docs/config_diagram.svg b/docs/config_diagram.svg index 102ee0a..76275eb 100644 --- a/docs/config_diagram.svg +++ b/docs/config_diagram.svg @@ -4,11 +4,11 @@ - - + + %3 - + 2 @@ -35,60 +35,60 @@ 1 - - -Config: - NowcastLightningModule - - -network - - - - - -ensemble_size - -2 - - -loss_class - -'crps' - - -loss_params - - - -dict - - -'temporal_lambda' - -0.01 - - -masked_loss - -True - - -optimizer - - - - - -lr_scheduler - - - + + +Config: + NowcastLightningModule + + +network + + 
+ + + +ensemble_size + +2 + + +loss_class + +'crps' + + +loss_params + + + +dict + + +'temporal_lambda' + +0.01 + + +masked_loss + +True + + +optimizer + + + + + +lr_scheduler + + + 1:c--2:c - + @@ -111,7 +111,7 @@ 1:c--3:c - + @@ -139,156 +139,161 @@ 1:c--4:c - + 0 - - -Config: - Experiment + + +Config: + Experiment + + +pl_module + + + -pl_module +data - + -data +trainer - + - - -trainer - - - 0:c--1:c - + 5 - - -Config: - SourceDataDataModule - - -dataset_factory - - - - - -splits - - - -dict - - -'time' - - - -dict - - -'train' - -0.7 - - -'val' - -0.15 - - -batch_size - -16 - - -num_workers - -8 - - -pin_memory - -True + + +Config: + SourceDataDataModule + + +dataset_factory + + + + + +splits + + + +dict + + +'time' + + + +dict + + +'train' + +0.7 + + +'val' + +0.15 + + +'test' + +0.15 + + +batch_size + +16 + + +num_workers + +8 + + +pin_memory + +True 0:c--5:c - + 7 - - -Config: - Trainer - - -accelerator - -'auto' - - -logger - - - - - -callbacks - - - -list - - - - - - - - - - - - - - -0 - - -1 - - -2 - - -3 - - -max_epochs - -100 + + +Config: + Trainer + + +accelerator + +'auto' + + +logger + + + + + +callbacks + + + +list + + + + + + + + + + + + + + +0 + + +1 + + +2 + + +3 + + +max_epochs + +100 0:c--7:c - + @@ -343,7 +348,7 @@ 5:c--6:c - + @@ -366,7 +371,7 @@ 7:c--8:c - + @@ -394,7 +399,7 @@ 7:c--9:c - + @@ -422,7 +427,7 @@ 7:c--10:c - + @@ -450,7 +455,7 @@ 7:c--11:c - + @@ -468,7 +473,7 @@ 7:c--12:c - + diff --git a/src/mlcast/config/base.py b/src/mlcast/config/base.py index 64678d1..fca8117 100644 --- a/src/mlcast/config/base.py +++ b/src/mlcast/config/base.py @@ -75,7 +75,7 @@ def training_experiment() -> Experiment: data = SourceDataDataModule( dataset_factory=dataset_factory, - splits={"time": {"train": 0.70, "val": 0.15}}, + splits={"time": {"train": 0.70, "val": 0.15, "test": 0.15}}, batch_size=16, num_workers=8, pin_memory=True, diff --git a/src/mlcast/config/fiddlers.py b/src/mlcast/config/fiddlers.py index 2f45f0a..dceee26 100644 --- 
a/src/mlcast/config/fiddlers.py +++ b/src/mlcast/config/fiddlers.py @@ -87,8 +87,8 @@ def use_random_sampler(cfg: fdl.Config) -> None: def use_ratio_splits(cfg: fdl.Config, train: float, val: float) -> None: - """Fiddler to set ratio-based train/val splits on the data module.""" - cfg.data.splits = {"time": {"train": train, "val": val}} + """Fiddler to set fraction-based train/val/test splits on the data module.""" + cfg.data.splits = {"time": {"train": train, "val": val, "test": 1.0 - train - val}} def use_anon_s3_dataset(cfg: fdl.Buildable, zarr_path: str, endpoint_url: str) -> None: diff --git a/src/mlcast/data/source_data_datamodule.py b/src/mlcast/data/source_data_datamodule.py index afe253e..7500a40 100644 --- a/src/mlcast/data/source_data_datamodule.py +++ b/src/mlcast/data/source_data_datamodule.py @@ -31,7 +31,7 @@ class SourceDataDataModule(pl.LightningDataModule): dataset_factory : Callable[..., Dataset] A factory function (e.g., ``fdl.Partial``) that returns a Dataset instance. It must accept ``subset`` and ``augment`` as keyword arguments. - splits : dict of {str: dict}, optional + splits : dict of {str: dict} Nested mapping ``{coord: {split_name: value, ...}, ...}`` describing train/val/test subsets. Currently only the ``time`` coordinate is supported. Ratio mode uses float fractions, while datetime mode uses @@ -44,12 +44,12 @@ class SourceDataDataModule(pl.LightningDataModule): def __init__( self, dataset_factory: Callable[..., Dataset], - splits: dict[str, dict[str, Any]] | None = None, + splits: dict[str, dict[str, Any]], **dataloader_kwargs: Any, ) -> None: super().__init__() self.dataset_factory = dataset_factory - self.splits = splits if splits is not None else {"time": {"train": 0.70, "val": 0.15}} + self.splits = splits self.dataloader_kwargs = dataloader_kwargs validate_splits(self.splits) @@ -68,9 +68,8 @@ def setup(self, stage: str | None = None) -> None: is accepted for framework compatibility and is otherwise unused. 
""" subset_per_split: dict[str, dict[str, Any] | None] = { - "train": {}, - "val": {}, - "test": {}, + split_name: {} if any(split_name in coord_splits for coord_splits in self.splits.values()) else None + for split_name in ("train", "val", "test") } for coord, coord_splits in self.splits.items(): diff --git a/src/mlcast/data/splits.py b/src/mlcast/data/splits.py index 2e280a1..b037dcc 100644 --- a/src/mlcast/data/splits.py +++ b/src/mlcast/data/splits.py @@ -2,8 +2,8 @@ Utilities to facilitate dataset splitting based on coordinate values, supporting two modes of specification: 1. Fraction mode: each split is defined by a fraction of the total coordinate -range (e.g., 0.7 for training, 0.2 for validation, and the remainder for -testing). The fractions are resolved into coordinate value range tuples by inspecting +range (e.g., 0.7 for training, 0.2 for validation, and 0.1 for testing). +The fractions are resolved into coordinate value range tuples by inspecting the coordinate values of the source dataset. 2. Tuple-range mode: each split is defined by an explicit (start, end) tuple of @@ -17,6 +17,7 @@ from typing import Any import xarray as xr +from loguru import logger from torch.utils.data import Dataset _SPLIT_NAMES = frozenset({"train", "val", "test"}) @@ -134,6 +135,12 @@ def validate_splits(splits: dict[str, dict[str, Any]]) -> None: ratio_sum += test_val if ratio_sum > 1.0 + 1e-9: raise ValueError(f"Split fractions in splits[{coord!r}] sum to {ratio_sum:.4f}, which exceeds 1.0.") + if abs(ratio_sum - 1.0) > 1e-9: + logger.warning( + "Split fractions in splits[{}] sum to {:.4f}, not 1.0. Any unallocated remainder will be unused.", + coord, + ratio_sum, + ) elif splitting_uses_tuple_ranges(coord_splits): if "test" not in coord_splits: raise ValueError( @@ -169,13 +176,13 @@ def compute_split_ranges_from_splitting_ratios( Coordinate name to split. Currently this is expected to be ``"time"``. 
coord_splits : dict[str, Any] Fraction-mode split configuration for a single coordinate. Must contain - float values for ``"train"`` and ``"val"``. + float values for ``"train"`` and ``"val"``. ``"test"`` is optional; + when omitted or set to ``None``, no test split range is returned. Returns ------- dict[str, tuple[str, str]] - Inclusive ``(start, end)`` coordinate ranges for the ``train``, - ``val``, and ``test`` splits. + Inclusive ``(start, end)`` coordinate ranges for the configured splits. """ zarr_path = getattr(dataset_factory, "zarr_path", None) or dataset_factory.keywords["zarr_path"] storage_options = getattr(dataset_factory, "storage_options", None) or dataset_factory.keywords.get( @@ -188,8 +195,14 @@ def compute_split_ranges_from_splitting_ratios( train_end = int(n * coord_splits["train"]) val_end = train_end + int(n * coord_splits["val"]) - return { + split_ranges = { "train": (str(coord_vals[0]), str(coord_vals[train_end - 1])), "val": (str(coord_vals[train_end]), str(coord_vals[val_end - 1])), - "test": (str(coord_vals[val_end]), str(coord_vals[n - 1])), } + + test_fraction = coord_splits.get("test") + if isinstance(test_fraction, Real) and not isinstance(test_fraction, bool): + test_end = val_end + int(n * test_fraction) + split_ranges["test"] = (str(coord_vals[val_end]), str(coord_vals[test_end - 1])) + + return split_ranges diff --git a/tests/data/test_data_module.py b/tests/data/test_data_module.py index 3f58234..cbf18c3 100644 --- a/tests/data/test_data_module.py +++ b/tests/data/test_data_module.py @@ -78,6 +78,13 @@ def test_validate_splits_ratio_requires_float_like_values() -> None: validate_splits({"time": {"train": 0.8, "val": 0.1, "test": "0.1"}}) +def test_validate_splits_warns_when_fraction_sum_is_not_one() -> None: + with patch("mlcast.data.splits.logger.warning") as mock_warning: + validate_splits({"time": {"train": 0.7, "val": 0.15}}) + + mock_warning.assert_called_once() + + def test_validate_splits_mixed_mode() -> None: with 
pytest.raises(ValueError, match="mix"): validate_splits({"time": {"train": 0.7, "val": ("2022-01-01", "2023-12-31")}}) @@ -109,7 +116,7 @@ def test_data_module_ratio_splits() -> None: dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr", foo="bar") dm = SourceDataDataModule( - dataset_factory=dataset_factory, splits={"time": {"train": 0.5, "val": 0.2}}, batch_size=2 + dataset_factory=dataset_factory, splits={"time": {"train": 0.5, "val": 0.2, "test": 0.3}}, batch_size=2 ) with patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)): @@ -143,12 +150,30 @@ class _NoZarrPathFactory: def __call__(self, **kwargs) -> Dataset: return MagicMock(spec=Dataset) - dm = SourceDataDataModule(dataset_factory=_NoZarrPathFactory()) + dm = SourceDataDataModule(dataset_factory=_NoZarrPathFactory(), splits={"time": {"train": 0.7, "val": 0.15}}) with pytest.raises((AttributeError, KeyError)): dm.setup() +def test_data_module_fraction_splits_without_test_do_not_create_test_dataset() -> None: + dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") + time_index = _make_time_index(100) + + dm = SourceDataDataModule( + dataset_factory=dataset_factory, + splits={"time": {"train": 0.5, "val": 0.2}}, + batch_size=2, + ) + + with patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)): + dm.setup() + + assert dm.train_dataset is not None + assert dm.val_dataset is not None + assert dm.test_dataset is None + + def test_data_module_split_lengths_and_batches() -> None: """Test that dataset lengths and dataloader batch counts are correct after splitting. 
@@ -161,7 +186,7 @@ def test_data_module_split_lengths_and_batches() -> None: dm = SourceDataDataModule( dataset_factory=dataset_factory, - splits={"time": {"train": 1 / 2, "val": 1 / 3}}, + splits={"time": {"train": 1 / 2, "val": 1 / 3, "test": 1 / 6}}, batch_size=batch_size, ) @@ -195,9 +220,28 @@ def test_data_module_datetime_splits() -> None: assert dm.test_dataset is None +def test_data_module_fraction_test_split_uses_explicit_fraction() -> None: + dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") + time_index = _make_time_index(100) + + dm = SourceDataDataModule( + dataset_factory=dataset_factory, + splits={"time": {"train": 0.5, "val": 0.2, "test": 0.1}}, + batch_size=2, + ) + + with patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)): + dm.setup() + + assert dm.test_dataset is not None + test_start, test_end = dm.test_dataset.subset["time"] + assert test_start == str(time_index[70]) + assert test_end == str(time_index[79]) + + def test_data_module_unsupported_split_mode() -> None: dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") - dm = SourceDataDataModule(dataset_factory=dataset_factory) + dm = SourceDataDataModule(dataset_factory=dataset_factory, splits={"time": {"train": 0.7, "val": 0.15}}) dm.splits = {"time": {"train": object(), "val": object()}} diff --git a/tests/test_cli_training.py b/tests/test_cli_training.py index 82c60fb..4f44da7 100644 --- a/tests/test_cli_training.py +++ b/tests/test_cli_training.py @@ -27,7 +27,7 @@ def test_cli_train_command(fp_test_dataset: Path, tmp_path: Path) -> None: "--config", "set:data.dataset_factory.standard_names=['rainfall_flux']", "--config", - "set:data.splits={'time': {'train': 0.4, 'val': 0.3}}", + "set:data.splits={'time': {'train': 0.4, 'val': 0.3, 'test': 0.3}}", "--config", "set:trainer.fast_dev_run=True", "--config", @@ -59,7 +59,7 @@ def test_cli_train_from_yaml_config(fp_test_dataset: Path, tmp_path: Path) -> No 
use_random_sampler(cfg) cfg.data.dataset_factory.standard_names = ["rainfall_flux"] cfg.data.dataset_factory.zarr_path = str(fp_test_dataset.absolute()) - cfg.data.splits = {"time": {"train": 0.4, "val": 0.3}} + cfg.data.splits = {"time": {"train": 0.4, "val": 0.3, "test": 0.3}} cfg.trainer.fast_dev_run = True cfg.data.batch_size = 1 cfg.data.num_workers = 0 diff --git a/tests/test_readme_snippets.py b/tests/test_readme_snippets.py index a6bba89..7dc3ace 100644 --- a/tests/test_readme_snippets.py +++ b/tests/test_readme_snippets.py @@ -124,7 +124,7 @@ def _patch_cfg(cfg: fdl.Config, fp_dataset: Path, tmp_path: Path) -> None: set_variables(cfg, standard_names=["rainfall_flux"]) # Switch to the on-the-fly random sampler so no pre-computed CSV is needed. use_random_sampler(cfg) - cfg.data.splits = {"time": {"train": 0.4, "val": 0.3}} + cfg.data.splits = {"time": {"train": 0.4, "val": 0.3, "test": 0.3}} cfg.trainer.fast_dev_run = True cfg.data.batch_size = 1 cfg.data.num_workers = 0 From efae80df405c52bc02ec831d978ef4832f112e18 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 19 May 2026 15:54:53 +0200 Subject: [PATCH 03/34] refactor: log split summary during datamodule setup - emit per-split sample counts and resolved subset ranges directly from SourceDataDataModule.setup() - remove the unused count_split_samples helper now that the setup logging provides the same operational visibility --- src/mlcast/data/source_data_datamodule.py | 43 +++++++---------------- tests/data/test_data_module.py | 19 ++++++++++ 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/src/mlcast/data/source_data_datamodule.py b/src/mlcast/data/source_data_datamodule.py index 7500a40..89b5ac6 100644 --- a/src/mlcast/data/source_data_datamodule.py +++ b/src/mlcast/data/source_data_datamodule.py @@ -7,9 +7,8 @@ from collections.abc import Callable from typing import Any -import fiddle as fdl import pytorch_lightning as pl -import xarray as xr +from loguru import logger from 
torch.utils.data import DataLoader, Dataset from mlcast.data.splits import ( @@ -111,6 +110,17 @@ def setup(self, stage: str | None = None) -> None: self.dataset_factory(subset=subset, augment=augment_flags[split]), ) + logger.info("{}.setup() complete, containing:", self.__class__.__name__) + for split in ("train", "val", "test"): + dataset = getattr(self, f"{split}_dataset", None) + if dataset is not None: + logger.info( + " {:5s}: {:>6d} samples, subset={}", + split, + len(dataset), + subset_per_split[split], + ) + def train_dataloader(self) -> DataLoader: """Return the training DataLoader.""" return DataLoader(self.train_dataset, shuffle=True, **self.dataloader_kwargs) @@ -122,32 +132,3 @@ def val_dataloader(self) -> DataLoader: def test_dataloader(self) -> DataLoader: """Return the test DataLoader.""" return DataLoader(self.test_dataset, shuffle=False, **self.dataloader_kwargs) - - -def count_split_samples(cfg: fdl.Config) -> dict[str, Any]: - """Return dataset counts plus time extent for a built experiment config.""" - data_module: SourceDataDataModule = fdl.build(cfg.data) - - zarr_path = ( - getattr(data_module.dataset_factory, "zarr_path", None) or data_module.dataset_factory.keywords["zarr_path"] - ) - storage_options = getattr( - data_module.dataset_factory, "storage_options", None - ) or data_module.dataset_factory.keywords.get("storage_options") - ds = xr.open_zarr(zarr_path, storage_options=storage_options) - time_values = ds.indexes["time"] - - data_module.setup() - counts: dict[str, int] = {} - for split in ("train", "val", "test"): - dataset = getattr(data_module, f"{split}_dataset", None) - if dataset is not None: - counts[split] = len(dataset) - - return { - "samples": counts, - "zarr_tmin": str(time_values[0]), - "zarr_tmax": str(time_values[-1]), - "zarr_nsteps": len(time_values), - "splits": data_module.splits, - } diff --git a/tests/data/test_data_module.py b/tests/data/test_data_module.py index cbf18c3..cb2e5af 100644 --- 
a/tests/data/test_data_module.py +++ b/tests/data/test_data_module.py @@ -239,6 +239,25 @@ def test_data_module_fraction_test_split_uses_explicit_fraction() -> None: assert test_end == str(time_index[79]) +def test_data_module_logs_split_summary() -> None: + dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") + time_index = _make_time_index(100) + + dm = SourceDataDataModule( + dataset_factory=dataset_factory, + splits={"time": {"train": 0.5, "val": 0.2, "test": 0.1}}, + batch_size=2, + ) + + with ( + patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)), + patch("mlcast.data.source_data_datamodule.logger.info") as mock_info, + ): + dm.setup() + + assert mock_info.call_count == 4 + + def test_data_module_unsupported_split_mode() -> None: dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") dm = SourceDataDataModule(dataset_factory=dataset_factory, splits={"time": {"train": 0.7, "val": 0.15}}) From 174ad6e716bfb2adfc6e64c0d5e4b6d936104d6d Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 19 May 2026 16:01:42 +0200 Subject: [PATCH 04/34] fix: make datamodule setup stage-aware - build only train and validation datasets for the Lightning fit stage - build only the requested validation or test datasets for validate and test stages - document the stage-dependent setup behavior and cover it with datamodule tests --- src/mlcast/data/source_data_datamodule.py | 33 ++++++++++- tests/data/test_data_module.py | 71 +++++++++++++++++++++-- 2 files changed, 97 insertions(+), 7 deletions(-) diff --git a/src/mlcast/data/source_data_datamodule.py b/src/mlcast/data/source_data_datamodule.py index 89b5ac6..44f99bd 100644 --- a/src/mlcast/data/source_data_datamodule.py +++ b/src/mlcast/data/source_data_datamodule.py @@ -59,15 +59,42 @@ def setup(self, stage: str | None = None) -> None: Datetime-mode splits are passed through unchanged, while ratio-mode splits are first resolved against the zarr coordinate values and then 
converted to inclusive coordinate ranges before dataset instantiation. + Dataset construction depends on the requested Lightning stage: + + - ``"fit"`` builds train and validation datasets; + - ``"validate"`` builds only the validation dataset; + - ``"test"`` builds only the test dataset; and + - ``None`` builds all configured datasets. Parameters ---------- stage : str | None, optional - Lightning stage hint such as ``"fit"`` or ``"test"``. The value - is accepted for framework compatibility and is otherwise unused. + Lightning stage hint controlling which datasets are constructed. + + Raises + ------ + ValueError + If ``stage`` is not one of ``None``, ``"fit"``, ``"validate"``, + or ``"test"``. """ + if stage == "fit": + requested_splits = {"train", "val"} + elif stage == "validate": + requested_splits = {"val"} + elif stage == "test": + requested_splits = {"test"} + elif stage is None: + requested_splits = {"train", "val", "test"} + else: + raise ValueError(f"Unsupported LightningDataModule setup stage: {stage!r}") + subset_per_split: dict[str, dict[str, Any] | None] = { - split_name: {} if any(split_name in coord_splits for coord_splits in self.splits.values()) else None + split_name: ( + {} + if split_name in requested_splits + and any(split_name in coord_splits for coord_splits in self.splits.values()) + else None + ) for split_name in ("train", "val", "test") } diff --git a/tests/data/test_data_module.py b/tests/data/test_data_module.py index cb2e5af..e55bf58 100644 --- a/tests/data/test_data_module.py +++ b/tests/data/test_data_module.py @@ -126,17 +126,14 @@ def test_data_module_ratio_splits() -> None: assert dm.train_dataset.kwargs["foo"] == "bar" train_start, train_end = dm.train_dataset.subset["time"] val_start, val_end = dm.val_dataset.subset["time"] - test_start, test_end = dm.test_dataset.subset["time"] assert train_start == str(time_index[0]) assert train_end == str(time_index[49]) assert val_start == str(time_index[50]) assert val_end == 
str(time_index[69]) - assert test_start == str(time_index[70]) - assert test_end == str(time_index[99]) assert dm.val_dataset.augment is False - assert dm.test_dataset.augment is False + assert dm.test_dataset is None train_dl = dm.train_dataloader() assert isinstance(train_dl, DataLoader) @@ -239,6 +236,72 @@ def test_data_module_fraction_test_split_uses_explicit_fraction() -> None: assert test_end == str(time_index[79]) +def test_data_module_fit_stage_creates_only_train_and_val() -> None: + dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") + time_index = _make_time_index(100) + + dm = SourceDataDataModule( + dataset_factory=dataset_factory, + splits={"time": {"train": 0.5, "val": 0.2, "test": 0.1}}, + batch_size=2, + ) + + with patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)): + dm.setup(stage="fit") + + assert dm.train_dataset is not None + assert dm.val_dataset is not None + assert dm.test_dataset is None + + +def test_data_module_validate_stage_creates_only_val() -> None: + dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") + time_index = _make_time_index(100) + + dm = SourceDataDataModule( + dataset_factory=dataset_factory, + splits={"time": {"train": 0.5, "val": 0.2, "test": 0.1}}, + batch_size=2, + ) + + with patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)): + dm.setup(stage="validate") + + assert dm.train_dataset is None + assert dm.val_dataset is not None + assert dm.test_dataset is None + + +def test_data_module_test_stage_creates_only_test() -> None: + dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") + time_index = _make_time_index(100) + + dm = SourceDataDataModule( + dataset_factory=dataset_factory, + splits={"time": {"train": 0.5, "val": 0.2, "test": 0.1}}, + batch_size=2, + ) + + with patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)): + dm.setup(stage="test") + + assert dm.train_dataset is None + 
assert dm.val_dataset is None + assert dm.test_dataset is not None + + +def test_data_module_rejects_unknown_stage() -> None: + dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") + dm = SourceDataDataModule( + dataset_factory=dataset_factory, + splits={"time": {"train": 0.5, "val": 0.2, "test": 0.1}}, + batch_size=2, + ) + + with pytest.raises(ValueError, match="Unsupported LightningDataModule setup stage"): + dm.setup(stage="predict") + + def test_data_module_logs_split_summary() -> None: dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") time_index = _make_time_index(100) From 036c4d18bf076607011d4df98d8edd419f1764ef Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 20 May 2026 10:49:18 +0200 Subject: [PATCH 05/34] fix: cast source dataset samples to float32 - convert normalized sample arrays to float32 before torch.from_numpy so float64 source data does not leak into training tensors - make stacked channel arrays contiguous float32 views in both dataset sampling paths and cover the returned tensor dtypes in tests --- src/mlcast/data/source_data_datasets.py | 12 +++++++++--- tests/data/test_source_data_datasets.py | 3 +++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/mlcast/data/source_data_datasets.py b/src/mlcast/data/source_data_datasets.py index 643c82b..efea32b 100644 --- a/src/mlcast/data/source_data_datasets.py +++ b/src/mlcast/data/source_data_datasets.py @@ -261,7 +261,9 @@ def _build_sample(self, data: np.ndarray) -> DatasetSample: if self.return_mask: target_mask_t = torch.from_numpy((~np.isnan(data[self.input_steps :])).astype(np.float32)) - data = np.nan_to_num(data, nan=-1.0) + # source data may be float64, but the model and the rest of the + # training pipeline operate in float32. 
+ data = np.nan_to_num(data, nan=-1.0).astype(np.float32) data_t = torch.from_numpy(data) input_t = data_t[: self.input_steps] @@ -413,7 +415,9 @@ def __getitem__(self, idx: int) -> DatasetSample: norm_func = NORMALIZATION_REGISTRY[std_name] channels.append(norm_func(da_var.values)) - data = np.swapaxes(np.stack(channels, axis=0), 0, 1) + # swapaxes returns a view; make it contiguous and float32 before + # handing it to _build_sample()/torch.from_numpy(). + data = np.ascontiguousarray(np.swapaxes(np.stack(channels, axis=0), 0, 1), dtype=np.float32) return self._build_sample(data) @@ -540,5 +544,7 @@ def __getitem__(self, idx: int) -> DatasetSample: norm_func = NORMALIZATION_REGISTRY[std_name] channels.append(norm_func(da_var.values)) - data = np.swapaxes(np.stack(channels, axis=0), 0, 1) + # swapaxes returns a view; make it contiguous and float32 before + # handing it to _build_sample()/torch.from_numpy(). + data = np.ascontiguousarray(np.swapaxes(np.stack(channels, axis=0), 0, 1), dtype=np.float32) return self._build_sample(data) diff --git a/tests/data/test_source_data_datasets.py b/tests/data/test_source_data_datasets.py index 25eeb64..9add476 100644 --- a/tests/data/test_source_data_datasets.py +++ b/tests/data/test_source_data_datasets.py @@ -113,6 +113,9 @@ def test_random_sampling_dataset(fp_test_dataset: Path) -> None: assert input_t.shape == (input_steps, 1, 32, 32) assert target_t.shape == (forecast_steps, 1, 32, 32) assert target_mask_t.shape == (forecast_steps, 1, 32, 32) + assert input_t.dtype == torch.float32 + assert target_t.dtype == torch.float32 + assert target_mask_t.dtype == torch.float32 def test_random_sampling_dataset_time_slice(fp_test_dataset: Path) -> None: From 4988226bffecf7bdb1309c85d996725d49d929b5 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 20 May 2026 11:41:07 +0200 Subject: [PATCH 06/34] fix: scope data gitignore pattern to repo root - change the broad data/ ignore rule to /data/ so src/mlcast/data is no longer matched 
accidentally - avoid needing force-adds for tracked source files under the mlcast.data package --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index cd32ae8..ac454da 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,6 @@ dist/ .venv/ logs/ .pytest_cache/ -data/ +/data/ mlruns/ MagicMock/ From df028fe7b59a2b9fca748caf661d3ebe5bf66586 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 20 May 2026 12:03:09 +0200 Subject: [PATCH 07/34] docs: add ldcast refactor plan --- ldcast-refactor-plan.md | 69 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 ldcast-refactor-plan.md diff --git a/ldcast-refactor-plan.md b/ldcast-refactor-plan.md new file mode 100644 index 0000000..e6674e5 --- /dev/null +++ b/ldcast-refactor-plan.md @@ -0,0 +1,69 @@ +# LDCast Refactor Plan + +0. Config naming and CLI contract +- [ ] Rename `training_experiment` to `convgru_training_experiment`. +- [ ] Do not keep `training_experiment` as an alias. +- [ ] Reserve `ldcast_training_experiment` as the top-level config name for the new two-stage LDCast workflow. +- [ ] Require `mlcast train` users to provide an explicit base config via `--config config:` or `--config /path/to/config.yaml`. +- [ ] Update CLI help text to list the included config entry points explicitly. +- [ ] Update all docs, examples, tests, and scripts to use `convgru_training_experiment` instead of `training_experiment`. +- [ ] Treat `convgru_training_experiment` as the existing ConvGRU forecasting example, not as a special default config. + +1. Forecasting and reconstruction data +- [ ] Rename the existing sampled-sequence dataset classes into `src/mlcast/data/forecasting.py`. +- [ ] Rename `SourceDataDatasetBase`, `SourceDataPrecomputedSamplingDataset`, and `SourceDataRandomSamplingDataset` into forecasting-oriented names under the forecasting data area. 
+- [ ] Remove the old source-data public API rather than keeping compatibility re-exports. +- [ ] Keep the existing sampled-sequence dataset implementation as the forecasting data source. +- [ ] Add `src/mlcast/data/reconstruction.py`. +- [ ] Add `ReconstructionDataset`, a thin wrapper around a `base_forecasting_dataset` that returns only the input tensor `x`. +- [ ] Add `ReconstructionDataModule`, which remains factory-based, builds the underlying forecasting datasets, splits them into train/val/test, and wraps each split with `ReconstructionDataset`. +- [ ] Keep this generic: no LDCast-specific naming in the module or class names. +- [ ] Stage 1 should use only the input window from the forecasting dataset. +- [ ] Allow `forecast_steps == 0` in forecasting datasets, but emit a warning when used. + +2. Autoencoder model architecture +- Autoencoder model split: + - [ ] `src/mlcast/models/autoencoder/encoder.py` for `Encoder` and `EncoderBlock`. + - [ ] `src/mlcast/models/autoencoder/decoder.py` for `Decoder` and `DecoderBlock`. + - [ ] `src/mlcast/models/autoencoder/net.py` for `AutoencoderNet`. +- Autoencoder validation and tests: + - [ ] encoder output shape. + - [ ] decoder output shape. + - [ ] autoencoder reconstruction forward pass. + - [ ] autoencoder improves reconstruction loss on a small generated dataset after a few training steps. + +3. Diffusion model architecture +- Diffusion model split: + - [ ] `src/mlcast/models/diffusion/conditioner.py` for latent conditioning blocks and `ConditionerNet`. + - [ ] `src/mlcast/models/diffusion/denoiser.py` for `DenoiserUNet` and timestep-aware helpers. + - [ ] `src/mlcast/models/diffusion/net.py` for `LatentDiffusionNet`. + - [ ] `src/mlcast/models/diffusion/forecasting.py` for `LatentDiffusionForecaster`, the diffusion-specific adapter that exposes `forward(x, forecast_steps, ensemble_size)`. + - [ ] `src/mlcast/models/diffusion/scheduler.py`, `ema.py`, `sampler.py`, `loss.py` for diffusion support code. 
+- Validation and tests: + - [ ] diffusion forecasting adapter API. + - [ ] diffusion model improves loss on a small generated latent dataset after a few training steps. + +4. Task wrappers +- [ ] Add `src/mlcast/modules/forecasting.py` and rename `NowcastLightningModule` to `ForecastingModule`. +- [ ] Add `src/mlcast/modules/reconstruction.py` with a generic `ReconstructionModule` for any reconstruction model. +- [ ] Keep `modules/` for training/task wrappers only; keep `models/` for pure architectures. + +5. Training experiment +- [ ] Add a new LDCast-specific training module containing `LDCastTrainingExperiment`. +- [ ] Keep `convgru_training_experiment` as the existing ConvGRU forecasting example and one of the explicitly selected included CLI configs. +- [ ] Stage 1 builds the reconstruction dataset, autoencoder model, and reconstruction module, then trains the autoencoder. +- [ ] Stage 2 reuses the same trained in-memory encoder instance, builds the diffusion dataset/model/module, then trains latent diffusion. +- [ ] The shared Fiddle graph should define the encoder once and reference the same object in both stages, but no unresolved Fiddle objects should flow into actual `torch.nn.Module.__init__` calls. +- [ ] The decoder is stage-1 only and is not shared into stage 2. +- [ ] Reuse the same forecasting dataset abstraction in stage 2; do not add a separate latent dataset layer. +- [ ] Add tests for shared object identity and stage sequencing. + +6. Audit and migration targets +- [ ] Update CLI/help text in `src/mlcast/__main__.py` to require an explicit base config and list the included config entry points. +- [ ] Rename `training_experiment` to `convgru_training_experiment` in `src/mlcast/config/base.py` and export it from `src/mlcast/config/__init__.py`. +- [ ] Add the LDCast config entry point to `src/mlcast/config/__init__.py` alongside the existing ConvGRU example config. 
+- [ ] Keep `src/mlcast/config/orchestrator.py` compatible with both the existing single-stage `Experiment` and the new `LDCastTrainingExperiment` through a common `run()` surface. +- [ ] Update docstrings and comments that currently imply `training_experiment` is the only experiment, including `src/mlcast/data/source_data_datamodule.py`, `src/mlcast/config/orchestrator.py`, and related config docs. +- [ ] Update docs and scripts that still reference `training_experiment`, including `README.md` and `docs/generate_base_experiment_config_diagram.py`. +- [ ] Keep existing ConvGRU CLI/config tests passing while adding separate tests for selecting the LDCast config explicitly. +- [ ] Add real but small-scale end-to-end tests with generated sample data for the autoencoder stage, diffusion stage, and full LDCast stage sequencing. From 8f3ce1014816c264edbefa1fd3bc2fa6153a76f6 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 20 May 2026 13:51:53 +0200 Subject: [PATCH 08/34] docs: update ldcast refactor plan --- ldcast-refactor-plan.md | 41 +++++++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/ldcast-refactor-plan.md b/ldcast-refactor-plan.md index e6674e5..da03015 100644 --- a/ldcast-refactor-plan.md +++ b/ldcast-refactor-plan.md @@ -10,45 +10,62 @@ - [ ] Treat `convgru_training_experiment` as the existing ConvGRU forecasting example, not as a special default config. 1. Forecasting and reconstruction data -- [ ] Rename the existing sampled-sequence dataset classes into `src/mlcast/data/forecasting.py`. -- [ ] Rename `SourceDataDatasetBase`, `SourceDataPrecomputedSamplingDataset`, and `SourceDataRandomSamplingDataset` into forecasting-oriented names under the forecasting data area. +- [ ] Move the existing sampled-sequence source-data logic into `src/mlcast/data/sequence.py`. 
+- [ ] Rename `SourceDataDatasetBase`, `SourceDataPrecomputedSamplingDataset`, and `SourceDataRandomSamplingDataset` to `SourceDataSequenceDatasetBase`, `SourceDataPrecomputedSequenceDataset`, and `SourceDataRandomSequenceDataset` under the sequence data area. - [ ] Remove the old source-data public API rather than keeping compatibility re-exports. -- [ ] Keep the existing sampled-sequence dataset implementation as the forecasting data source. +- [ ] Keep the existing sampled-sequence implementation as the source-data sequence layer. +- [ ] Sequence datasets should own normalization and return normalized tensors of shape `(sequence_steps, channels, height, width)`. +- [ ] Replace forecasting-specific sampling parameters in the source-data sequence layer with a single `sequence_steps` parameter. +- [ ] Add `src/mlcast/data/forecasting.py`. +- [ ] Add a generic `ForecastingDataset` that wraps a base sequence dataset, takes `input_steps` and `forecast_steps`, validates `input_steps + forecast_steps == sequence_steps`, and returns forecasting samples. +- [ ] `ForecastingDataset` should derive `target_mask` itself rather than relying on the base sequence dataset to return masks. - [ ] Add `src/mlcast/data/reconstruction.py`. -- [ ] Add `ReconstructionDataset`, a thin wrapper around a `base_forecasting_dataset` that returns only the input tensor `x`. -- [ ] Add `ReconstructionDataModule`, which remains factory-based, builds the underlying forecasting datasets, splits them into train/val/test, and wraps each split with `ReconstructionDataset`. +- [ ] Add `ReconstructionDataset`, a generic wrapper around a base sequence dataset that slices each full sequence into all overlapping windows of length `input_steps` and returns only the tensor window. +- [ ] Add `src/mlcast/data/datamodules.py`. +- [ ] Rename `SourceDataDataModule` to `ForecastingDataModule` in `src/mlcast/data/datamodules.py`. 
+- [ ] `ForecastingDataModule` should remain factory-based and build `ForecastingDataset` instances over the underlying sequence datasets. +- [ ] Add `ReconstructionDataModule` to `src/mlcast/data/datamodules.py`; it remains factory-based, builds the underlying sequence datasets, splits them into train/val/test, and wraps each split with `ReconstructionDataset`. - [ ] Keep this generic: no LDCast-specific naming in the module or class names. -- [ ] Stage 1 should use only the input window from the forecasting dataset. -- [ ] Allow `forecast_steps == 0` in forecasting datasets, but emit a warning when used. +- [ ] Forecasting should stay one-sequence-to-one-sample. +- [ ] Reconstruction should expand each sequence into `sequence_steps - input_steps + 1` overlapping samples. +- [ ] Stage 1 should use reconstruction windows of length `input_steps` derived from the full sequence dataset. 2. Autoencoder model architecture - Autoencoder model split: - [ ] `src/mlcast/models/autoencoder/encoder.py` for `Encoder` and `EncoderBlock`. - [ ] `src/mlcast/models/autoencoder/decoder.py` for `Decoder` and `DecoderBlock`. - [ ] `src/mlcast/models/autoencoder/net.py` for `AutoencoderNet`. +- [ ] Use `input_steps` for the stage-1 reconstruction window length; do not introduce names like `autoenc_time_ratio`. - Autoencoder validation and tests: - [ ] encoder output shape. - [ ] decoder output shape. - [ ] autoencoder reconstruction forward pass. - [ ] autoencoder improves reconstruction loss on a small generated dataset after a few training steps. -3. Diffusion model architecture +3. Forecasting model contract +- [ ] Standardize all forecasting models on init-time `input_steps`, `forecast_steps`, and `ensemble_size`. +- [ ] Standardize forecasting model inference on `forward(x)` only; do not pass `forecast_steps` or `ensemble_size` at runtime. +- [ ] Refactor the existing ConvGRU path to follow this fixed-shape contract. 
+- [ ] Add config consistency checks that dataset `input_steps` and `forecast_steps` match the configured forecasting model. + +4. Diffusion model architecture - Diffusion model split: - [ ] `src/mlcast/models/diffusion/conditioner.py` for latent conditioning blocks and `ConditionerNet`. - [ ] `src/mlcast/models/diffusion/denoiser.py` for `DenoiserUNet` and timestep-aware helpers. - [ ] `src/mlcast/models/diffusion/net.py` for `LatentDiffusionNet`. - - [ ] `src/mlcast/models/diffusion/forecasting.py` for `LatentDiffusionForecaster`, the diffusion-specific adapter that exposes `forward(x, forecast_steps, ensemble_size)`. + - [ ] `src/mlcast/models/diffusion/forecasting.py` for `LatentDiffusionForecaster`, the diffusion-specific adapter configured with fixed `input_steps`, `forecast_steps`, and `ensemble_size` and exposing `forward(x)`. - [ ] `src/mlcast/models/diffusion/scheduler.py`, `ema.py`, `sampler.py`, `loss.py` for diffusion support code. - Validation and tests: - [ ] diffusion forecasting adapter API. - [ ] diffusion model improves loss on a small generated latent dataset after a few training steps. -4. Task wrappers +5. Task wrappers - [ ] Add `src/mlcast/modules/forecasting.py` and rename `NowcastLightningModule` to `ForecastingModule`. +- [ ] Remove runtime `forecast_steps` and `ensemble_size` arguments from the forecasting Lightning module and its `predict()` API. - [ ] Add `src/mlcast/modules/reconstruction.py` with a generic `ReconstructionModule` for any reconstruction model. - [ ] Keep `modules/` for training/task wrappers only; keep `models/` for pure architectures. -5. Training experiment +6. Training experiment - [ ] Add a new LDCast-specific training module containing `LDCastTrainingExperiment`. - [ ] Keep `convgru_training_experiment` as the existing ConvGRU forecasting example and one of the explicitly selected included CLI configs. 
- [ ] Stage 1 builds the reconstruction dataset, autoencoder model, and reconstruction module, then trains the autoencoder. @@ -58,7 +75,7 @@ - [ ] Reuse the same forecasting dataset abstraction in stage 2; do not add a separate latent dataset layer. - [ ] Add tests for shared object identity and stage sequencing. -6. Audit and migration targets +7. Audit and migration targets - [ ] Update CLI/help text in `src/mlcast/__main__.py` to require an explicit base config and list the included config entry points. - [ ] Rename `training_experiment` to `convgru_training_experiment` in `src/mlcast/config/base.py` and export it from `src/mlcast/config/__init__.py`. - [ ] Add the LDCast config entry point to `src/mlcast/config/__init__.py` alongside the existing ConvGRU example config. From 11dd3a56753bebe518143d43e434cb3adcae9250 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 20 May 2026 14:01:32 +0200 Subject: [PATCH 09/34] refactor: require explicit convgru config selection --- README.md | 39 ++++---- ...generate_base_experiment_config_diagram.py | 8 +- ldcast-refactor-plan.md | 24 ++--- src/mlcast/__main__.py | 93 +++++++++++++------ src/mlcast/config/__init__.py | 4 +- src/mlcast/config/base.py | 12 +-- src/mlcast/config/orchestrator.py | 3 +- tests/config/test_cli_examples.py | 39 ++++++-- tests/config/test_consistency_checks.py | 12 +-- tests/config/test_fiddlers.py | 10 +- tests/config/test_orchestrator.py | 8 +- tests/test_cli_training.py | 6 +- 12 files changed, 161 insertions(+), 97 deletions(-) diff --git a/README.md b/README.md index 28db293..ad7e8b5 100644 --- a/README.md +++ b/README.md @@ -67,11 +67,11 @@ reproduce runs exactly from a saved YAML file. 
### Configuration model -Training in mlcast is currently built around a single base configuration -function, [`training_experiment`](src/mlcast/config/base.py), which defines the -default ConvGRU ensemble nowcasting setup: dataset, data module, network, +Training in mlcast is currently built around the included configuration +function, [`convgru_training_experiment`](src/mlcast/config/base.py), which +defines the ConvGRU ensemble nowcasting setup: dataset, data module, network, Lightning module, and trainer. Rather than writing a new config from scratch, -the intended workflow is to start from this base and apply targeted +the intended workflow is to start from this config and apply targeted modifications: - **`set:` overrides** — change a single scalar parameter (e.g. batch size, @@ -82,32 +82,31 @@ modifications: - **direct graph edits** (Python API only) — replace a sub-object entirely, for example swapping in a different network architecture -Any combination of these can be layered on top of the base config, and the +Any combination of these can be layered on top of the selected config, and the fully resolved config is always saved to YAML alongside the training logs so runs can be reproduced exactly. -The diagram below shows the full default config graph as built by -[`training_experiment`](src/mlcast/config/base.py): +The diagram below shows the full included ConvGRU config graph as built by +[`convgru_training_experiment`](src/mlcast/config/base.py): -![training_experiment config graph](docs/config_diagram.svg) +![convgru_training_experiment config graph](docs/config_diagram.svg) ### Command-line interface Install the package and run: ```bash -mlcast train +mlcast train --config config:convgru_training_experiment ``` -This trains with the built-in [`training_experiment`](src/mlcast/config/base.py) defaults. All parameters +This trains with the built-in [`convgru_training_experiment`](src/mlcast/config/base.py) config. 
All parameters are controlled via `--config` flags: | Prefix | Purpose | Example | |--------|---------|---------| -| *(none)* | Use the built-in default config | `mlcast train` | +| `config:` | Select an included `@auto_config` function | `--config config:convgru_training_experiment` | | `set:` | Override a single parameter | `--config set:data.batch_size=32` | | `fiddler:` | Apply a semantic mutator (multi-param change) | `--config fiddler:use_random_sampler` | -| `config:` | Switch to a different `@auto_config` function | `--config=config:my_experiment` | | `path/to/config.yaml` | Load a previously saved config | `--config saved.yaml` | Multiple `--config` flags are applied in order and can be combined freely. @@ -117,11 +116,13 @@ Multiple `--config` flags are applied in order and can be combined freely. ```bash # Override dataset path and batch size mlcast train \ + --config config:convgru_training_experiment \ --config set:data.dataset_factory.zarr_path=/data/radar.zarr \ --config set:data.batch_size=32 # Switch to random sampler and log to MLflow mlcast train \ + --config config:convgru_training_experiment \ --config fiddler:use_random_sampler \ --config fiddler:use_mlflow_logger @@ -131,7 +132,7 @@ mlcast train \ --config set:trainer.max_epochs=50 # Inspect the fully resolved config without starting training -mlcast train --config fiddler:use_random_sampler --print_config_and_exit +mlcast train --config config:convgru_training_experiment --config fiddler:use_random_sampler --print_config_and_exit ``` Run `mlcast train --help` for a full list of examples and available fiddlers. @@ -141,14 +142,14 @@ Run `mlcast train --help` for a full list of examples and available fiddlers. The Python API gives you full programmatic control over the config graph before anything is instantiated. 
-**Run the default experiment with tweaks:** +**Run the included ConvGRU experiment with tweaks:** ```python import fiddle as fdl -from mlcast.config import training_experiment, train_from_config +from mlcast.config import convgru_training_experiment, train_from_config from mlcast.config.fiddlers import use_random_sampler -cfg = training_experiment.as_buildable() # returns a fdl.Config graph — see src/mlcast/config/base.py +cfg = convgru_training_experiment.as_buildable() # returns a fdl.Config graph — see src/mlcast/config/base.py # Apply a fiddler to switch the dataset sampler use_random_sampler(cfg) @@ -183,7 +184,7 @@ import torch import torch.nn as nn from jaxtyping import Float from mfai.torch.models import HalfUNet -from mlcast.config import training_experiment, train_from_config +from mlcast.config import convgru_training_experiment, train_from_config from mlcast.config.fiddlers import use_random_sampler # Minimal adapter: channel-stack past frames → HalfUNet → one step at a time. 
@@ -226,7 +227,7 @@ class HalfUNetNowcaster(nn.Module): x_flat = torch.cat([x_flat[:, self.num_vars:], y], dim=1) return torch.cat(preds, dim=1) -cfg = training_experiment.as_buildable() +cfg = convgru_training_experiment.as_buildable() use_random_sampler(cfg) cfg.pl_module.network = fdl.Config( @@ -270,7 +271,7 @@ mlcast/ │ ├── callbacks.py # Training callbacks │ ├── visualization.py # TensorBoard image logging helpers │ ├── config/ -│ │ ├── base.py # Default training_experiment @auto_config +│ │ ├── base.py # ConvGRU training config @auto_config │ │ ├── fiddlers.py # Semantic config mutators │ │ ├── consistency_checks.py # Cross-parameter validation │ │ ├── loader.py # YAML config loader diff --git a/docs/generate_base_experiment_config_diagram.py b/docs/generate_base_experiment_config_diagram.py index 0b58291..3193616 100644 --- a/docs/generate_base_experiment_config_diagram.py +++ b/docs/generate_base_experiment_config_diagram.py @@ -1,4 +1,4 @@ -"""Generate a Graphviz SVG diagram of the default training_experiment config. +"""Generate a Graphviz SVG diagram of the included ConvGRU training config. 
Run without arguments to regenerate docs/config_diagram.svg: @@ -15,13 +15,13 @@ import fiddle.graphviz as fgv -from mlcast.config import training_experiment +from mlcast.config import convgru_training_experiment OUT = Path(__file__).parent / "config_diagram.svg" def main() -> None: - """Generate or verify the base experiment config diagram.""" + """Generate or verify the ConvGRU training config diagram.""" parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--check", @@ -30,7 +30,7 @@ def main() -> None: ) args = parser.parse_args() - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() g = fgv.render(cfg, max_str_length=40) g.format = "svg" new_svg = g.pipe().decode() diff --git a/ldcast-refactor-plan.md b/ldcast-refactor-plan.md index da03015..1883236 100644 --- a/ldcast-refactor-plan.md +++ b/ldcast-refactor-plan.md @@ -1,13 +1,13 @@ # LDCast Refactor Plan 0. Config naming and CLI contract -- [ ] Rename `training_experiment` to `convgru_training_experiment`. -- [ ] Do not keep `training_experiment` as an alias. +- [x] Rename `training_experiment` to `convgru_training_experiment`. +- [x] Do not keep `training_experiment` as an alias. - [ ] Reserve `ldcast_training_experiment` as the top-level config name for the new two-stage LDCast workflow. -- [ ] Require `mlcast train` users to provide an explicit base config via `--config config:` or `--config /path/to/config.yaml`. -- [ ] Update CLI help text to list the included config entry points explicitly. -- [ ] Update all docs, examples, tests, and scripts to use `convgru_training_experiment` instead of `training_experiment`. -- [ ] Treat `convgru_training_experiment` as the existing ConvGRU forecasting example, not as a special default config. +- [x] Require `mlcast train` users to provide an explicit base config via `--config config:` or `--config /path/to/config.yaml`. +- [x] Update CLI help text to list the included config entry points explicitly. 
+- [x] Update all docs, examples, tests, and scripts to use `convgru_training_experiment` instead of `training_experiment`. +- [x] Treat `convgru_training_experiment` as the existing ConvGRU forecasting example, not as a special default config. 1. Forecasting and reconstruction data - [ ] Move the existing sampled-sequence source-data logic into `src/mlcast/data/sequence.py`. @@ -67,7 +67,7 @@ 6. Training experiment - [ ] Add a new LDCast-specific training module containing `LDCastTrainingExperiment`. -- [ ] Keep `convgru_training_experiment` as the existing ConvGRU forecasting example and one of the explicitly selected included CLI configs. +- [x] Keep `convgru_training_experiment` as the existing ConvGRU forecasting example and one of the explicitly selected included CLI configs. - [ ] Stage 1 builds the reconstruction dataset, autoencoder model, and reconstruction module, then trains the autoencoder. - [ ] Stage 2 reuses the same trained in-memory encoder instance, builds the diffusion dataset/model/module, then trains latent diffusion. - [ ] The shared Fiddle graph should define the encoder once and reference the same object in both stages, but no unresolved Fiddle objects should flow into actual `torch.nn.Module.__init__` calls. @@ -76,11 +76,11 @@ - [ ] Add tests for shared object identity and stage sequencing. 7. Audit and migration targets -- [ ] Update CLI/help text in `src/mlcast/__main__.py` to require an explicit base config and list the included config entry points. -- [ ] Rename `training_experiment` to `convgru_training_experiment` in `src/mlcast/config/base.py` and export it from `src/mlcast/config/__init__.py`. +- [x] Update CLI/help text in `src/mlcast/__main__.py` to require an explicit base config and list the included config entry points. +- [x] Rename `training_experiment` to `convgru_training_experiment` in `src/mlcast/config/base.py` and export it from `src/mlcast/config/__init__.py`. 
- [ ] Add the LDCast config entry point to `src/mlcast/config/__init__.py` alongside the existing ConvGRU example config. - [ ] Keep `src/mlcast/config/orchestrator.py` compatible with both the existing single-stage `Experiment` and the new `LDCastTrainingExperiment` through a common `run()` surface. -- [ ] Update docstrings and comments that currently imply `training_experiment` is the only experiment, including `src/mlcast/data/source_data_datamodule.py`, `src/mlcast/config/orchestrator.py`, and related config docs. -- [ ] Update docs and scripts that still reference `training_experiment`, including `README.md` and `docs/generate_base_experiment_config_diagram.py`. -- [ ] Keep existing ConvGRU CLI/config tests passing while adding separate tests for selecting the LDCast config explicitly. +- [x] Update docstrings and comments that currently imply `training_experiment` is the only experiment, including `src/mlcast/data/source_data_datamodule.py`, `src/mlcast/config/orchestrator.py`, and related config docs. +- [x] Update docs and scripts that still reference `training_experiment`, including `README.md` and `docs/generate_base_experiment_config_diagram.py`. +- [x] Keep existing ConvGRU CLI/config tests passing while adding separate tests for selecting the LDCast config explicitly. - [ ] Add real but small-scale end-to-end tests with generated sample data for the autoencoder stage, diffusion stage, and full LDCast stage sequencing. 
diff --git a/src/mlcast/__main__.py b/src/mlcast/__main__.py index 31ffc0b..07ffc4f 100644 --- a/src/mlcast/__main__.py +++ b/src/mlcast/__main__.py @@ -5,8 +5,9 @@ Usage examples:: - # Train with default config and override dataset path: + # Train with an included config and override dataset path: python -m mlcast train \\ + --config config:convgru_training_experiment \\ --config fiddler:use_random_sampler \\ --config set:data.dataset_factory.zarr_path="'/path/to/data.zarr'" @@ -18,8 +19,8 @@ --config /path/to/config.yaml \\ --config set:trainer.max_epochs=50 - # Switch to a different base config function entirely: - python -m mlcast train --config=config:another_experiment_function + # Use a different included config function: + python -m mlcast train --config config:another_experiment_function """ import argparse @@ -35,14 +36,14 @@ from rich.text import Text from . import config # noqa: F401 — module must be importable for absl_flags -from .config import load_yaml_config, train_from_config, training_experiment +from .config import convgru_training_experiment, load_yaml_config, train_from_config FLAGS = flags.FLAGS _config = absl_flags.DEFINE_fiddle_config( "config", default_module=config, - help_string="Experiment configuration. Default is training_experiment.", + help_string="Experiment configuration. 
Required: use config: or a YAML path.", ) flags.DEFINE_boolean( @@ -52,50 +53,67 @@ ) -def get_cli_examples(cfg: fdl.Buildable) -> list[tuple[str, str]]: +def get_included_config_names() -> list[str]: + """Return public config factory names exposed by ``mlcast.config``.""" + included_configs: list[str] = [] + for name in getattr(config, "__all__", []): + value = getattr(config, name, None) + if callable(value) and hasattr(value, "as_buildable"): + included_configs.append(name) + return included_configs + + +def get_cli_examples( + cfg: fdl.Buildable, base_config_name: str = "convgru_training_experiment" +) -> list[tuple[str, str]]: """Returns a list of (description, flag_string) tuples for CLI parameter overrides.""" return [ ( f"Override data layer properties (default batch_size: {cfg.data.batch_size})", - f"--config set:data.batch_size={max(1, cfg.data.batch_size * 2)}", + f"--config config:{base_config_name} --config set:data.batch_size={max(1, cfg.data.batch_size * 2)}", ), ( f"Override the path to the Zarr dataset (default: {cfg.data.dataset_factory.zarr_path})", - "--config set:data.dataset_factory.zarr_path='/new/path/to/radar.zarr'", + f"--config config:{base_config_name} --config set:data.dataset_factory.zarr_path='/new/path/to/radar.zarr'", ), ( f"Override trainer properties (default max_epochs: {cfg.trainer.max_epochs})", - f"--config set:trainer.max_epochs={max(1, cfg.trainer.max_epochs // 2)}", + f"--config config:{base_config_name} --config set:trainer.max_epochs={max(1, cfg.trainer.max_epochs // 2)}", ), ( f"Override network architecture properties (default num_blocks: {cfg.pl_module.network.num_blocks})", - f"--config set:pl_module.network.num_blocks={max(1, cfg.pl_module.network.num_blocks - 1)}", + "--config config:" + f"{base_config_name} " + "--config set:pl_module.network.num_blocks=" + f"{max(1, cfg.pl_module.network.num_blocks - 1)}", ), ( f"Override the optimizer learning rate (default lr: {cfg.pl_module.optimizer.lr})", - "--config 
set:pl_module.optimizer.lr=0.1", + f"--config config:{base_config_name} --config set:pl_module.optimizer.lr=0.1", ), ] -def get_fiddler_examples() -> list[tuple[str, str]]: +def get_fiddler_examples(base_config_name: str = "convgru_training_experiment") -> list[tuple[str, str]]: """Returns a list of (description, flag_string) tuples for Fiddler mutators.""" return [ ( "Switch to the random sampling dataset (instead of the precomputed CSV sampler)", - "--config fiddler:use_random_sampler", + f"--config config:{base_config_name} --config fiddler:use_random_sampler", ), ( "Change the input variables and automatically adjust the network's input_channels", + "--config config:" + f"{base_config_name} " "--config \"fiddler:set_variables(standard_names=['rainfall_rate', 'reflectivity'])\"", ), ( "Toggle whether the loss function ignores masked/invalid pixels", - '--config "fiddler:toggle_masking(enabled=False)"', + f'--config config:{base_config_name} --config "fiddler:toggle_masking(enabled=False)"', ), ( "Train using an anonymous S3 object store dataset (e.g. 
the Italian dataset)", - '--config "fiddler:use_anon_s3_dataset(' + f'--config config:{base_config_name} --config "fiddler:use_anon_s3_dataset(' "zarr_path='s3://mlcast-source-datasets/IT-DPC-SRI/v0.1.0/italian-radar-dpc-sri.zarr/', " "endpoint_url='https://object-store.os-api.cci2.ecmwf.int')\"", ), @@ -105,14 +123,24 @@ def get_fiddler_examples() -> list[tuple[str, str]]: def _build_help_text(cfg: fdl.Buildable) -> Text: """Build Rich-highlighted help text for the ``train`` subcommand.""" t = Text() + included_config_names = get_included_config_names() t.append("Train a model using a Fiddle configuration.\n\n", style="bold") + t.append("You must provide a base config via ", style="bold") + t.append("--config config:", style="bold cyan") + t.append(" or ", style="bold") + t.append("--config /path/to/config.yaml", style="bold cyan") + t.append(".\n\n", style="bold") + t.append("Included configs:\n", style="bold yellow") + for name in included_config_names: + t.append(f" - {name}\n", style="green") + t.append("\n") t.append("You can override parameters from the command line using the ") t.append("--config set:path.to.param=value", style="bold cyan") t.append(" syntax.\n\n") t.append("Examples", style="bold yellow") - t.append(" (based on the default ") - t.append("training_experiment", style="bold green") + t.append(" (based on the included ") + t.append("convgru_training_experiment", style="bold green") t.append(" config):\n") for desc, cmd in get_cli_examples(cfg): @@ -130,25 +158,28 @@ def _build_help_text(cfg: fdl.Buildable) -> Text: t.append(cmd, style="cyan") t.append("\n") - t.append("\nSwitching experiments:\n", style="bold yellow") + t.append("\nConfig sources:\n", style="bold yellow") t.append("\n # Resume from or reproduce a previously saved YAML config:\n", style="dim") t.append(" mlcast train ", style="bold") t.append("--config /path/to/config.yaml\n", style="cyan") t.append("\n # Load a YAML config and apply additional overrides on top:\n", 
style="dim") t.append(" mlcast train ", style="bold") t.append("--config /path/to/config.yaml --config set:trainer.max_epochs=50\n", style="cyan") - t.append("\n # Use a different base config function defined in src/mlcast/config/.\n", style="dim") + t.append("\n # Use a different included config function defined in src/mlcast/config/.\n", style="dim") t.append(" # Syntax: --config=config: where ", style="dim") t.append("config:", style="dim bold") t.append(" is a Fiddle prefix that resolves\n", style="dim") t.append(" # the function name against the config module (not a Python module path).\n", style="dim") t.append(" mlcast train ", style="bold") - t.append("--config=config:another_experiment_function\n", style="cyan") + t.append("--config config:convgru_training_experiment\n", style="cyan") t.append("\nInspecting the resolved config:\n", style="bold yellow") t.append(" # Print the fully resolved config as YAML without starting training:\n", style="dim") t.append(" mlcast train ", style="bold") - t.append("--config fiddler:use_random_sampler --print_config_and_exit\n", style="cyan") + t.append( + "--config config:convgru_training_experiment --config fiddler:use_random_sampler --print_config_and_exit\n", + style="cyan", + ) return t @@ -335,14 +366,14 @@ def train_main(argv: list[str]) -> None: def cli() -> None: """Console script entry point for the ``mlcast`` command. - This parses standard CLI arguments via `argparse`, injects Fiddle default - overrides if no base configuration is provided, formats the `--help` - output, and safely passes execution over to `absl.app.run`. + This parses standard CLI arguments via `argparse`, validates that an + explicit base configuration was provided, formats the `--help` output, and + safely passes execution over to `absl.app.run`. 
""" # Dynamically generate help text showing Fiddle overrides try: - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() description_text = _build_help_text(cfg) except Exception: # Fallback if config generation fails during CLI initialization @@ -378,16 +409,22 @@ def cli() -> None: yaml_path, remaining = _extract_yaml_config_path(remaining) # Case 2: user supplied an explicit base config function - # e.g. --config=config:another_experiment_function + # e.g. --config config:convgru_training_experiment has_explicit_config = any( arg.startswith("--config=config:") or (arg == "--config" and i + 1 < len(remaining) and remaining[i + 1].startswith("config:")) for i, arg in enumerate(remaining) ) - # Case 3: no base config from either source — fall back to training_experiment if not has_explicit_config and yaml_path is None: - remaining = ["--config=config:training_experiment"] + remaining + included = ", ".join(get_included_config_names()) + print( + "Error: a base config is required. Provide either " + "'--config config:' or '--config /path/to/config.yaml'.\n" + f"Included configs: {included}", + file=sys.stderr, + ) + sys.exit(1) remaining = auto_quote_fiddle_strings(remaining) diff --git a/src/mlcast/config/__init__.py b/src/mlcast/config/__init__.py index 194ad6b..9338dbd 100644 --- a/src/mlcast/config/__init__.py +++ b/src/mlcast/config/__init__.py @@ -4,7 +4,7 @@ and runtime orchestration logic for `mlcast`. 
""" -from .base import Experiment, training_experiment +from .base import Experiment, convgru_training_experiment from .consistency_checks import validate_config from .fiddlers import ( set_variables, @@ -19,7 +19,7 @@ __all__ = [ "Experiment", - "training_experiment", + "convgru_training_experiment", "validate_config", "train_from_config", "load_yaml_config", diff --git a/src/mlcast/config/base.py b/src/mlcast/config/base.py index fca8117..47c8c45 100644 --- a/src/mlcast/config/base.py +++ b/src/mlcast/config/base.py @@ -1,11 +1,11 @@ """Base Fiddle experiment definitions for ConvGRU radar nowcasting. This module defines the ``Experiment`` dataclass and the -``training_experiment`` auto-config factory, which together form the default -configuration graph for a ConvGRU ensemble nowcasting run. +``convgru_training_experiment`` auto-config factory, which together form the +included configuration graph for a ConvGRU ensemble nowcasting run. -``training_experiment`` is decorated with ``@auto_config``: calling it returns -a ``fdl.Config`` graph rather than a live ``Experiment`` object. Every +``convgru_training_experiment`` is decorated with ``@auto_config``: calling it +returns a ``fdl.Config`` graph rather than a live ``Experiment`` object. Every parameter in the graph can be overridden before instantiation — either via fiddlers (for semantic, multi-parameter changes) or via ``set:`` overrides on the CLI (for single-parameter tweaks). 
Call ``fdl.build(cfg)`` to materialise @@ -13,7 +13,7 @@ Typical usage ------------- ->>> cfg = training_experiment() # returns fdl.Config +>>> cfg = convgru_training_experiment() # returns fdl.Config >>> use_random_sampler(cfg) # apply a fiddler >>> validate_config(cfg) # check cross-parameter contracts >>> experiment = fdl.build(cfg) # instantiate everything @@ -50,7 +50,7 @@ def run(self) -> None: @fiddle.experimental.auto_config.auto_config -def training_experiment() -> Experiment: +def convgru_training_experiment() -> Experiment: """Build a Fiddle config for ConvGRU ensemble radar nowcasting. This is decorated as a Fiddle ``@auto_config`` function: calling it diff --git a/src/mlcast/config/orchestrator.py b/src/mlcast/config/orchestrator.py index 16fea32..9d21ec0 100644 --- a/src/mlcast/config/orchestrator.py +++ b/src/mlcast/config/orchestrator.py @@ -126,7 +126,8 @@ def train_from_config(cfg: fdl.Config) -> None: Parameters ---------- cfg : fdl.Config - Fiddle configuration as returned by `training_experiment`. + Fiddle configuration as returned by an included auto-config factory such + as `convgru_training_experiment`. 
""" validate_config(cfg) diff --git a/tests/config/test_cli_examples.py b/tests/config/test_cli_examples.py index 5a54c6f..156db8d 100644 --- a/tests/config/test_cli_examples.py +++ b/tests/config/test_cli_examples.py @@ -2,24 +2,45 @@ import subprocess import sys -from mlcast.__main__ import get_cli_examples, get_fiddler_examples -from mlcast.config import training_experiment +from mlcast.__main__ import get_cli_examples, get_fiddler_examples, get_included_config_names +from mlcast.config import convgru_training_experiment -def test_cli_examples_parse_correctly(): +def test_cli_examples_parse_correctly() -> None: """Verify that every CLI override example given in the help text successfully parses.""" - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() examples = get_cli_examples(cfg) + get_fiddler_examples() for _desc, cmd in examples: - # Strip out the leading `mlcast train ` if we used it, but cmd is just `--config ...` args = shlex.split(cmd) - # We can run the __main__.py module using subprocess to ensure isolated absl flag parsing - process_args = [sys.executable, "-m", "mlcast", "train"] + args + ["--only_check_args"] + process_args = [sys.executable, "-m", "mlcast", "train"] + args + ["--print_config_and_exit"] result = subprocess.run(process_args, capture_output=True, text=True) - # absl prints "unknown flag: --only_check_args" in some versions, or handles it? - # Let's check what it does. 
assert result.returncode == 0, f"Command '{cmd}' failed to parse:\n{result.stderr}\n{result.stdout}" + + +def test_cli_requires_explicit_config() -> None: + """Train command should fail fast when no base config is provided.""" + result = subprocess.run( + [sys.executable, "-m", "mlcast", "train"], + capture_output=True, + text=True, + ) + + assert result.returncode != 0 + assert "base config is required" in result.stderr + + +def test_cli_help_lists_included_configs() -> None: + """Help text should advertise the built-in config entry points.""" + result = subprocess.run( + [sys.executable, "-m", "mlcast", "train", "--help"], + capture_output=True, + text=True, + ) + + assert result.returncode == 0 + for name in get_included_config_names(): + assert name in result.stdout diff --git a/tests/config/test_consistency_checks.py b/tests/config/test_consistency_checks.py index 3dc8ec3..5277d30 100644 --- a/tests/config/test_consistency_checks.py +++ b/tests/config/test_consistency_checks.py @@ -3,13 +3,13 @@ import pytest from loguru import logger -from mlcast.config import training_experiment, validate_config +from mlcast.config import convgru_training_experiment, validate_config from mlcast.data.source_data_datasets import SourceDataPrecomputedSamplingDataset def test_contract_1_input_channels() -> None: """Verify Contract 1: Network input_channels == len(dataset_factory.standard_names).""" - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() # Break Contract 1 cfg.pl_module.network.input_channels = 2 cfg.data.dataset_factory.standard_names = ["rainfall_rate"] @@ -20,7 +20,7 @@ def test_contract_1_input_channels() -> None: def test_contract_2_spatial_divisibility() -> None: """Verify Contract 2: Dataset width must be divisible by 2 \\*\\* network.num_blocks.""" - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() # Break Contract 2 cfg.data.dataset_factory.width = 250 
cfg.pl_module.network.num_blocks = 4 @@ -31,7 +31,7 @@ def test_contract_2_spatial_divisibility() -> None: def test_contract_1_and_2_warn_when_network_lacks_attrs() -> None: """Verify Contracts 1 and 2 warn when the network lacks required attrs.""" - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() cfg.pl_module.network = SimpleNamespace() messages: list[str] = [] @@ -51,7 +51,7 @@ def capture(message: object) -> None: def test_contract_3_probabilistic_loss() -> None: """Verify Contract 3: Ensemble models require CRPS or AFCRPS.""" - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() # Break Contract 3 cfg.pl_module.ensemble_size = 5 cfg.pl_module.loss_class = "mse" @@ -62,7 +62,7 @@ def test_contract_3_probabilistic_loss() -> None: def test_contract_4_masking_sync() -> None: """Verify Contract 4: Dataset return_mask must match model masked_loss.""" - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() # Break Contract 4 cfg.data.dataset_factory.return_mask = True cfg.pl_module.masked_loss = False diff --git a/tests/config/test_fiddlers.py b/tests/config/test_fiddlers.py index e6f443f..001f1ed 100644 --- a/tests/config/test_fiddlers.py +++ b/tests/config/test_fiddlers.py @@ -1,9 +1,9 @@ -from mlcast.config import set_variables, toggle_masking, training_experiment +from mlcast.config import convgru_training_experiment, set_variables, toggle_masking -def test_fiddler_set_variables(): +def test_fiddler_set_variables() -> None: """Verify set_variables syncs dataset variables and network input_channels.""" - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() # Apply fiddler set_variables(cfg, ["rainfall_rate", "rainfall_flux"]) @@ -13,9 +13,9 @@ def test_fiddler_set_variables(): assert cfg.pl_module.network.input_channels == 2 -def test_fiddler_toggle_masking(): +def test_fiddler_toggle_masking() -> None: 
"""Verify toggle_masking syncs dataset mask return and module masked_loss.""" - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() # Disable masking toggle_masking(cfg, False) diff --git a/tests/config/test_orchestrator.py b/tests/config/test_orchestrator.py index 42d5a36..cf17e8f 100644 --- a/tests/config/test_orchestrator.py +++ b/tests/config/test_orchestrator.py @@ -1,12 +1,14 @@ +from pathlib import Path +from typing import Any from unittest.mock import patch -from mlcast.config import train_from_config, training_experiment +from mlcast.config import convgru_training_experiment, train_from_config @patch("mlcast.config.orchestrator.fdl.build") -def test_train_from_config_valid(mock_build, tmp_path): +def test_train_from_config_valid(mock_build: Any, tmp_path: Path) -> None: """Verify that a valid configuration passes validation and builds.""" mock_build.return_value.trainer.log_dir = str(tmp_path) - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() train_from_config(cfg) mock_build.assert_called_once() diff --git a/tests/test_cli_training.py b/tests/test_cli_training.py index 4f44da7..86e8e29 100644 --- a/tests/test_cli_training.py +++ b/tests/test_cli_training.py @@ -4,7 +4,7 @@ from fiddle._src.experimental.yaml_serialization import dump_yaml -from mlcast.config import training_experiment +from mlcast.config import convgru_training_experiment from mlcast.config.fiddlers import use_random_sampler @@ -21,6 +21,8 @@ def test_cli_train_command(fp_test_dataset: Path, tmp_path: Path) -> None: "mlcast", "train", "--config", + "config:convgru_training_experiment", + "--config", "fiddler:use_random_sampler", "--config", f"set:data.dataset_factory.zarr_path='{fp_test_dataset.absolute()}'", @@ -54,7 +56,7 @@ def test_cli_train_from_yaml_config(fp_test_dataset: Path, tmp_path: Path) -> No the dataset path) before dumping to YAML, so the subprocess call needs no additional --config flags. 
This exercises the pure load-from-YAML path. """ - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() # Switch to random sampler (no CSV required) and use the correct variable name use_random_sampler(cfg) cfg.data.dataset_factory.standard_names = ["rainfall_flux"] From 85df9de9f644279ab059a99053e4a71a88c0f2c0 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 20 May 2026 14:24:33 +0200 Subject: [PATCH 10/34] refactor: introduce sequence-first data layer --- README.md | 10 +- docs/config_diagram.svg | 661 +++++++++--------- ldcast-refactor-plan.md | 38 +- src/mlcast/__main__.py | 7 +- src/mlcast/config/base.py | 19 +- src/mlcast/config/consistency_checks.py | 16 +- src/mlcast/config/fiddlers.py | 30 +- src/mlcast/data/__init__.py | 20 +- src/mlcast/data/datamodules.py | 278 ++++++++ src/mlcast/data/forecasting.py | 111 +++ src/mlcast/data/reconstruction.py | 73 ++ .../{source_data_datasets.py => sequence.py} | 374 ++++------ src/mlcast/data/source_data_datamodule.py | 161 ----- tests/config/test_consistency_checks.py | 19 +- tests/config/test_fiddlers.py | 6 +- tests/data/test_data_module.py | 195 +++--- tests/data/test_source_data_datasets.py | 141 ++-- tests/test_cli_training.py | 8 +- tests/test_readme_snippets.py | 2 +- 19 files changed, 1194 insertions(+), 975 deletions(-) create mode 100644 src/mlcast/data/datamodules.py create mode 100644 src/mlcast/data/forecasting.py create mode 100644 src/mlcast/data/reconstruction.py rename src/mlcast/data/{source_data_datasets.py => sequence.py} (51%) delete mode 100644 src/mlcast/data/source_data_datamodule.py diff --git a/README.md b/README.md index ad7e8b5..2885cd0 100644 --- a/README.md +++ b/README.md @@ -117,7 +117,7 @@ Multiple `--config` flags are applied in order and can be combined freely. 
# Override dataset path and batch size mlcast train \ --config config:convgru_training_experiment \ - --config set:data.dataset_factory.zarr_path=/data/radar.zarr \ + --config set:data.sequence_dataset_factory.zarr_path=/data/radar.zarr \ --config set:data.batch_size=32 # Switch to random sampler and log to MLflow @@ -174,7 +174,7 @@ As an example, here is how to wrap an U-Net) to satisfy the interface. The wrapper channel-stacks the past frames and runs the U-Net autoregressively for each requested forecast step: -> **Note** — `input_steps` equals `dataset_factory.input_steps` (6 by +> **Note** — `input_steps` equals the forecasting data module's `input_steps` (6 by > default) and is directly readable from the config graph before building. ```python @@ -232,8 +232,8 @@ use_random_sampler(cfg) cfg.pl_module.network = fdl.Config( HalfUNetNowcaster, - input_steps=cfg.data.dataset_factory.input_steps, - num_vars=len(cfg.data.dataset_factory.standard_names), + input_steps=cfg.data.input_steps, + num_vars=len(cfg.data.sequence_dataset_factory.standard_names), ) train_from_config(cfg) @@ -256,7 +256,7 @@ experiment.run() # trainer.fit() + trainer.test() |---------|-----------|--------------| | `use_mlflow_logger` | *(none)* | Replaces the default `TensorBoardLogger` with `MLFlowLogger` and appends `LogSystemInfoCallback`; respects the `MLFLOW_TRACKING_URI` environment variable | | `set_variables` | `standard_names` | Sets the list of input variables on the dataset and updates `network.input_channels` to match | -| `toggle_masking` | `enabled` | Toggles masked-loss mode by setting both `dataset_factory.return_mask` and `pl_module.masked_loss` to the same value | +| `toggle_masking` | `enabled` | Toggles masked-loss mode by setting both `data.return_mask` and `pl_module.masked_loss` to the same value | | `use_anon_s3_dataset` | `zarr_path`, `endpoint_url` | Points the dataset at an anonymous S3 object store; sets `zarr_path` and the required `storage_options` together | | 
`use_random_sampler` | *(none)* | Switches the dataset factory to the on-the-fly random sampler (useful during development when no precomputed CSV is available) | diff --git a/docs/config_diagram.svg b/docs/config_diagram.svg index 76275eb..129fec4 100644 --- a/docs/config_diagram.svg +++ b/docs/config_diagram.svg @@ -4,341 +4,346 @@ - - + + %3 - + 2 - - -Config: - ConvGruModel - - -input_channels - -1 + + +Config: + ConvGruModel -num_blocks +input_channels -5 +1 -noisy_decoder +num_blocks -False +5 + + +noisy_decoder + +False 1 - - -Config: - NowcastLightningModule - - -network - - - - - -ensemble_size - -2 - - -loss_class - -'crps' - - -loss_params - - - -dict - - -'temporal_lambda' - -0.01 - - -masked_loss - -True - - -optimizer - - - - - -lr_scheduler - - - + + +Config: + NowcastLightningModule + + +network + + + + + +ensemble_size + +2 + + +loss_class + +'crps' + + +loss_params + + + +dict + + +'temporal_lambda' + +0.01 + + +masked_loss + +True + + +optimizer + + + + + +lr_scheduler + + + 1:c--2:c - + 3 - - -Partial: - Adam - - -lr - -0.0001 - + + +Partial: + Adam + -fused +lr -True +0.0001 + + +fused + +True 1:c--3:c - + 4 - - -Partial: - ReduceLROnPlateau - - -mode - -'min' - + + +Partial: + ReduceLROnPlateau + -factor +mode -0.5 +'min' -patience +factor -10 +0.5 + + +patience + +10 1:c--4:c - + 0 - - -Config: - Experiment - - -pl_module - - - - - -data - - - - - -trainer - - - + + +Config: + Experiment + + +pl_module + + + + + +data + + + + + +trainer + + + 0:c--1:c - + 5 - - -Config: - SourceDataDataModule - - -dataset_factory - - - - - -splits - - - -dict - - -'time' - - - -dict - - -'train' - -0.7 - - -'val' - -0.15 - - -'test' - -0.15 - - -batch_size - -16 - - -num_workers - -8 - - -pin_memory - -True + + +Config: + ForecastingDataModule + + +sequence_dataset_factory + + + + + +input_steps + +6 + + +forecast_steps + +12 + + +return_mask + +True + + +splits + + + +dict + + +'time' + + + +dict + + +'train' + +0.7 + + +'val' + +0.15 + + +'test' + +0.15 + + 
+batch_size + +16 + + +num_workers + +8 + + +pin_memory + +True 0:c--5:c - + 7 - - -Config: - Trainer - - -accelerator - -'auto' - - -logger - - - - - -callbacks - - - -list - - - - - - - - - - - - - - -0 - - -1 - - -2 - - -3 - - -max_epochs - -100 + + +Config: + Trainer + + +accelerator + +'auto' + + +logger + + + + + +callbacks + + + +list + + + + + + + + + + + + + + +0 + + +1 + + +2 + + +3 + + +max_epochs + +100 0:c--7:c - + 6 - - -Partial: - SourceDataPrecomputedSamplingDataset - - -zarr_path - -'./data/radar.zarr' - - -csv_path - -'./data/sampled_datacubes.csv' - - -standard_names - - - -list - -'rainfall_rate' - - -0 - - -input_steps - -6 - - -forecast_steps - -12 + + +Partial: + SourceDataPrecomputedSequenceDataset + + +zarr_path + +'./data/radar.zarr' + + +csv_path + +'./data/sampled_datacubes.csv' + + +standard_names + + + +list + +'rainfall_rate' + + +0 -return_mask +sequence_steps -True +18 deterministic @@ -348,132 +353,132 @@ 5:c--6:c - + 8 - - -Config: - TensorBoardLogger - - -save_dir - -'logs' + + +Config: + TensorBoardLogger -name +save_dir -'mlcast' +'logs' + + +name + +'mlcast' 7:c--8:c - + 9 - - -Config: - ModelCheckpoint - - -monitor - -'val_loss' + + +Config: + ModelCheckpoint -save_top_k +monitor -1 +'val_loss' -mode +save_top_k -'min' +1 + + +mode + +'min' 7:c--9:c - + 10 - - -Config: - ModelCheckpoint - - -monitor - -'train_loss_epoch' + + +Config: + ModelCheckpoint -save_top_k +monitor -1 +'train_loss_epoch' -mode +save_top_k -'min' +1 + + +mode + +'min' 7:c--10:c - + 11 - - -Config: - EarlyStopping - - -monitor - -'val_loss' + + +Config: + EarlyStopping -patience +monitor -100 +'val_loss' -mode +patience -'min' +100 + + +mode + +'min' 7:c--11:c - + 12 - - -Config: - LearningRateMonitor - - -logging_interval - -'step' + + +Config: + LearningRateMonitor + + +logging_interval + +'step' 7:c--12:c - + diff --git a/ldcast-refactor-plan.md b/ldcast-refactor-plan.md index 1883236..16e432e 100644 --- a/ldcast-refactor-plan.md +++ 
b/ldcast-refactor-plan.md @@ -10,25 +10,25 @@ - [x] Treat `convgru_training_experiment` as the existing ConvGRU forecasting example, not as a special default config. 1. Forecasting and reconstruction data -- [ ] Move the existing sampled-sequence source-data logic into `src/mlcast/data/sequence.py`. -- [ ] Rename `SourceDataDatasetBase`, `SourceDataPrecomputedSamplingDataset`, and `SourceDataRandomSamplingDataset` to `SourceDataSequenceDatasetBase`, `SourceDataPrecomputedSequenceDataset`, and `SourceDataRandomSequenceDataset` under the sequence data area. -- [ ] Remove the old source-data public API rather than keeping compatibility re-exports. -- [ ] Keep the existing sampled-sequence implementation as the source-data sequence layer. -- [ ] Sequence datasets should own normalization and return normalized tensors of shape `(sequence_steps, channels, height, width)`. -- [ ] Replace forecasting-specific sampling parameters in the source-data sequence layer with a single `sequence_steps` parameter. -- [ ] Add `src/mlcast/data/forecasting.py`. -- [ ] Add a generic `ForecastingDataset` that wraps a base sequence dataset, takes `input_steps` and `forecast_steps`, validates `input_steps + forecast_steps == sequence_steps`, and returns forecasting samples. -- [ ] `ForecastingDataset` should derive `target_mask` itself rather than relying on the base sequence dataset to return masks. -- [ ] Add `src/mlcast/data/reconstruction.py`. -- [ ] Add `ReconstructionDataset`, a generic wrapper around a base sequence dataset that slices each full sequence into all overlapping windows of length `input_steps` and returns only the tensor window. -- [ ] Add `src/mlcast/data/datamodules.py`. -- [ ] Rename `SourceDataDataModule` to `ForecastingDataModule` in `src/mlcast/data/datamodules.py`. -- [ ] `ForecastingDataModule` should remain factory-based and build `ForecastingDataset` instances over the underlying sequence datasets. 
-- [ ] Add `ReconstructionDataModule` to `src/mlcast/data/datamodules.py`; it remains factory-based, builds the underlying sequence datasets, splits them into train/val/test, and wraps each split with `ReconstructionDataset`. -- [ ] Keep this generic: no LDCast-specific naming in the module or class names. -- [ ] Forecasting should stay one-sequence-to-one-sample. -- [ ] Reconstruction should expand each sequence into `sequence_steps - input_steps + 1` overlapping samples. -- [ ] Stage 1 should use reconstruction windows of length `input_steps` derived from the full sequence dataset. +- [x] Move the existing sampled-sequence source-data logic into `src/mlcast/data/sequence.py`. +- [x] Rename `SourceDataDatasetBase`, `SourceDataPrecomputedSamplingDataset`, and `SourceDataRandomSamplingDataset` to `SourceDataSequenceDatasetBase`, `SourceDataPrecomputedSequenceDataset`, and `SourceDataRandomSequenceDataset` under the sequence data area. +- [x] Remove the old source-data public API rather than keeping compatibility re-exports. +- [x] Keep the existing sampled-sequence implementation as the source-data sequence layer. +- [x] Sequence datasets should own normalization and return normalized tensors of shape `(sequence_steps, channels, height, width)`. +- [x] Replace forecasting-specific sampling parameters in the source-data sequence layer with a single `sequence_steps` parameter. +- [x] Add `src/mlcast/data/forecasting.py`. +- [x] Add a generic `ForecastingDataset` that wraps a base sequence dataset, takes `input_steps` and `forecast_steps`, validates `input_steps + forecast_steps == sequence_steps`, and returns forecasting samples. +- [x] `ForecastingDataset` should derive `target_mask` itself rather than relying on the base sequence dataset to return masks. +- [x] Add `src/mlcast/data/reconstruction.py`. 
+- [x] Add `ReconstructionDataset`, a generic wrapper around a base sequence dataset that slices each full sequence into all overlapping windows of length `input_steps` and returns only the tensor window. +- [x] Add `src/mlcast/data/datamodules.py`. +- [x] Rename `SourceDataDataModule` to `ForecastingDataModule` in `src/mlcast/data/datamodules.py`. +- [x] `ForecastingDataModule` should remain factory-based and build `ForecastingDataset` instances over the underlying sequence datasets. +- [x] Add `ReconstructionDataModule` to `src/mlcast/data/datamodules.py`; it remains factory-based, builds the underlying sequence datasets, splits them into train/val/test, and wraps each split with `ReconstructionDataset`. +- [x] Keep this generic: no LDCast-specific naming in the module or class names. +- [x] Forecasting should stay one-sequence-to-one-sample. +- [x] Reconstruction should expand each sequence into `sequence_steps - input_steps + 1` overlapping samples. +- [x] Stage 1 should use reconstruction windows of length `input_steps` derived from the full sequence dataset. 2. 
Autoencoder model architecture - Autoencoder model split: diff --git a/src/mlcast/__main__.py b/src/mlcast/__main__.py index 07ffc4f..b05028f 100644 --- a/src/mlcast/__main__.py +++ b/src/mlcast/__main__.py @@ -9,7 +9,7 @@ python -m mlcast train \\ --config config:convgru_training_experiment \\ --config fiddler:use_random_sampler \\ - --config set:data.dataset_factory.zarr_path="'/path/to/data.zarr'" + --config set:data.sequence_dataset_factory.zarr_path="'/path/to/data.zarr'" # Train from a previously saved YAML config: python -m mlcast train --config /path/to/config.yaml @@ -73,8 +73,9 @@ def get_cli_examples( f"--config config:{base_config_name} --config set:data.batch_size={max(1, cfg.data.batch_size * 2)}", ), ( - f"Override the path to the Zarr dataset (default: {cfg.data.dataset_factory.zarr_path})", - f"--config config:{base_config_name} --config set:data.dataset_factory.zarr_path='/new/path/to/radar.zarr'", + f"Override the path to the Zarr dataset (default: {cfg.data.sequence_dataset_factory.zarr_path})", + f"--config config:{base_config_name} " + "--config set:data.sequence_dataset_factory.zarr_path='/new/path/to/radar.zarr'", ), ( f"Override trainer properties (default max_epochs: {cfg.trainer.max_epochs})", diff --git a/src/mlcast/config/base.py b/src/mlcast/config/base.py index 47c8c45..418ac65 100644 --- a/src/mlcast/config/base.py +++ b/src/mlcast/config/base.py @@ -29,8 +29,8 @@ from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger -from ..data.source_data_datamodule import SourceDataDataModule -from ..data.source_data_datasets import SourceDataPrecomputedSamplingDataset +from ..data.datamodules import ForecastingDataModule +from ..data.sequence import SourceDataPrecomputedSequenceDataset from ..models.convgru import ConvGruModel from ..nowcasting_module import NowcastLightningModule @@ -62,19 +62,20 @@ def convgru_training_experiment() -> Experiment: 
Experiment Configured experiment with model, data, and trainer. """ - dataset_factory = fdl.Partial( - SourceDataPrecomputedSamplingDataset, + sequence_dataset_factory = fdl.Partial( + SourceDataPrecomputedSequenceDataset, zarr_path="./data/radar.zarr", csv_path="./data/sampled_datacubes.csv", standard_names=["rainfall_rate"], - input_steps=6, - forecast_steps=12, - return_mask=True, + sequence_steps=18, deterministic=False, ) - data = SourceDataDataModule( - dataset_factory=dataset_factory, + data = ForecastingDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=6, + forecast_steps=12, + return_mask=True, splits={"time": {"train": 0.70, "val": 0.15, "test": 0.15}}, batch_size=16, num_workers=8, diff --git a/src/mlcast/config/consistency_checks.py b/src/mlcast/config/consistency_checks.py index dea9f0a..debbd4d 100644 --- a/src/mlcast/config/consistency_checks.py +++ b/src/mlcast/config/consistency_checks.py @@ -28,14 +28,14 @@ def validate_config(cfg: fdl.Config) -> None: ValueError If any configuration contract is violated. """ - dataset_factory = cfg.data.dataset_factory + sequence_dataset_factory = cfg.data.sequence_dataset_factory network = cfg.pl_module.network pl_module = cfg.pl_module - # Contract 1: Network input_channels == len(dataset_factory.standard_names) + # Contract 1: Network input_channels == len(sequence_dataset_factory.standard_names) # If the network does not expose input_channels, emit a warning because # this contract cannot be checked. - num_vars = len(dataset_factory.standard_names) + num_vars = len(sequence_dataset_factory.standard_names) try: net_input_channels = network.input_channels except AttributeError: @@ -51,7 +51,7 @@ def validate_config(cfg: fdl.Config) -> None: f"must equal the number of standard_names ({num_vars})." 
) - # Contract 2: Dataset width must be divisible by 2 ** network.num_blocks + # Contract 2: Sequence dataset width must be divisible by 2 ** network.num_blocks # If the network does not expose num_blocks, emit a warning because this # contract cannot be checked. try: @@ -64,7 +64,7 @@ def validate_config(cfg: fdl.Config) -> None: ) num_blocks = None if num_blocks is not None: - width = getattr(dataset_factory, "width", 256) + width = getattr(sequence_dataset_factory, "width", 256) divisor = 2**num_blocks if width % divisor != 0: raise ValueError( @@ -80,9 +80,9 @@ def validate_config(cfg: fdl.Config) -> None: f"require 'crps' or 'afcrps' loss, got '{pl_module.loss_class}'." ) - # Contract 4: Dataset return_mask must match model masked_loss - if bool(dataset_factory.return_mask) != bool(pl_module.masked_loss): + # Contract 4: Forecasting mask return must match model masked_loss + if bool(cfg.data.return_mask) != bool(pl_module.masked_loss): raise ValueError( - f"Contract 4 violated: dataset_factory.return_mask ({dataset_factory.return_mask}) " + f"Contract 4 violated: data.return_mask ({cfg.data.return_mask}) " f"must match pl_module.masked_loss ({pl_module.masked_loss})." ) diff --git a/src/mlcast/config/fiddlers.py b/src/mlcast/config/fiddlers.py index dceee26..5e1cfb0 100644 --- a/src/mlcast/config/fiddlers.py +++ b/src/mlcast/config/fiddlers.py @@ -19,13 +19,13 @@ from pytorch_lightning.loggers import MLFlowLogger from ..callbacks import LogSystemInfoCallback -from ..data.source_data_datasets import SourceDataRandomSamplingDataset +from ..data.sequence import SourceDataRandomSequenceDataset def set_variables(cfg: fdl.Config, standard_names: list[str]) -> None: """Fiddler to synchronize dataset variables with the network's input channels. - Sets ``dataset_factory.standard_names`` on the data config and, when the + Sets ``sequence_dataset_factory.standard_names`` on the data config and, when the network config exposes an ``input_channels`` parameter (e.g. 
``ConvGruModel``), keeps it in sync. Networks that use a different parameter name for the channel count (e.g. ``HalfUNet`` uses @@ -39,7 +39,7 @@ def set_variables(cfg: fdl.Config, standard_names: list[str]) -> None: standard_names : list of str The new list of standard names to load. """ - cfg.data.dataset_factory.standard_names = standard_names + cfg.data.sequence_dataset_factory.standard_names = standard_names network_cls = cfg.pl_module.network.__fn_or_cls__ sig = inspect.signature(network_cls.__init__) if "input_channels" in sig.parameters: @@ -53,7 +53,7 @@ def set_variables(cfg: fdl.Config, standard_names: list[str]) -> None: def toggle_masking(cfg: fdl.Config, enabled: bool) -> None: - """Fiddler to synchronize dataset mask yielding with masked loss computation. + """Fiddler to synchronize forecasting-mask yielding with masked loss computation. Parameters ---------- @@ -62,12 +62,12 @@ def toggle_masking(cfg: fdl.Config, enabled: bool) -> None: enabled : bool Whether to enable masking or not. """ - cfg.data.dataset_factory.return_mask = enabled + cfg.data.return_mask = enabled cfg.pl_module.masked_loss = enabled def use_random_sampler(cfg: fdl.Config) -> None: - """Fiddler to switch the dataset factory to use the random sampler. + """Fiddler to switch the sequence dataset factory to use the random sampler. Parameters ---------- @@ -75,14 +75,12 @@ def use_random_sampler(cfg: fdl.Config) -> None: The Fiddle configuration to mutate. 
""" # Keep the existing parameters but change the underlying class - cfg.data.dataset_factory = fdl.Partial( - SourceDataRandomSamplingDataset, - zarr_path=cfg.data.dataset_factory.zarr_path, - standard_names=cfg.data.dataset_factory.standard_names, - input_steps=cfg.data.dataset_factory.input_steps, - forecast_steps=cfg.data.dataset_factory.forecast_steps, - return_mask=cfg.data.dataset_factory.return_mask, - storage_options=getattr(cfg.data.dataset_factory, "storage_options", None), + cfg.data.sequence_dataset_factory = fdl.Partial( + SourceDataRandomSequenceDataset, + zarr_path=cfg.data.sequence_dataset_factory.zarr_path, + standard_names=cfg.data.sequence_dataset_factory.standard_names, + sequence_steps=cfg.data.sequence_dataset_factory.sequence_steps, + storage_options=getattr(cfg.data.sequence_dataset_factory, "storage_options", None), ) @@ -103,8 +101,8 @@ def use_anon_s3_dataset(cfg: fdl.Buildable, zarr_path: str, endpoint_url: str) - endpoint_url : str The endpoint URL for the S3 object store. 
""" - cfg.data.dataset_factory.zarr_path = zarr_path - cfg.data.dataset_factory.storage_options = { + cfg.data.sequence_dataset_factory.zarr_path = zarr_path + cfg.data.sequence_dataset_factory.storage_options = { "anon": True, "client_kwargs": { "endpoint_url": endpoint_url, diff --git a/src/mlcast/data/__init__.py b/src/mlcast/data/__init__.py index e4b2449..3d1a03d 100644 --- a/src/mlcast/data/__init__.py +++ b/src/mlcast/data/__init__.py @@ -1,4 +1,18 @@ -from .source_data_datamodule import SourceDataDataModule -from .source_data_datasets import SourceDataPrecomputedSamplingDataset +from .datamodules import ForecastingDataModule, ReconstructionDataModule +from .forecasting import ForecastingDataset +from .reconstruction import ReconstructionDataset +from .sequence import ( + SourceDataPrecomputedSequenceDataset, + SourceDataRandomSequenceDataset, + SourceDataSequenceDatasetBase, +) -__all__ = ["SourceDataDataModule", "SourceDataPrecomputedSamplingDataset"] +__all__ = [ + "ForecastingDataModule", + "ForecastingDataset", + "ReconstructionDataModule", + "ReconstructionDataset", + "SourceDataPrecomputedSequenceDataset", + "SourceDataRandomSequenceDataset", + "SourceDataSequenceDatasetBase", +] diff --git a/src/mlcast/data/datamodules.py b/src/mlcast/data/datamodules.py new file mode 100644 index 0000000..be6492d --- /dev/null +++ b/src/mlcast/data/datamodules.py @@ -0,0 +1,278 @@ +"""PyTorch Lightning data modules for forecasting and reconstruction tasks.""" + +from collections.abc import Callable +from typing import Any + +import pytorch_lightning as pl +from loguru import logger +from torch.utils.data import DataLoader, Dataset + +from mlcast.data.forecasting import ForecastingDataset +from mlcast.data.reconstruction import ReconstructionDataset +from mlcast.data.splits import ( + compute_split_ranges_from_splitting_ratios, + splitting_uses_fractions, + splitting_uses_tuple_ranges, + validate_splits, +) + + +class _BaseDataModule(pl.LightningDataModule): + 
"""Shared split/build logic for task-level data modules. + + Parameters + ---------- + sequence_dataset_factory : Callable[..., Dataset] + Factory that builds source-data sequence datasets and accepts + ``subset`` and ``augment`` keyword arguments. + splits : dict of str to dict + Nested mapping describing train/validation/test coordinate splits. + **dataloader_kwargs : Any + Additional keyword arguments forwarded to ``DataLoader``. + """ + + def __init__( + self, + sequence_dataset_factory: Callable[..., Dataset], + splits: dict[str, dict[str, Any]], + **dataloader_kwargs: Any, + ) -> None: + super().__init__() + self.sequence_dataset_factory = sequence_dataset_factory + self.splits = splits + self.dataloader_kwargs = dataloader_kwargs + validate_splits(self.splits) + + def _build_sequence_dataset(self, subset: dict[str, Any], augment: bool) -> Dataset: + """Build a source-data sequence dataset for one split. + + Parameters + ---------- + subset : dict of str to Any + Coordinate subset passed to the sequence dataset factory. + augment : bool + Whether this split should apply data augmentation. + + Returns + ------- + Dataset + Built source-data sequence dataset. + """ + return self.sequence_dataset_factory(subset=subset, augment=augment) + + def _wrap_sequence_dataset(self, base_sequence_dataset: Dataset) -> Dataset: + """Wrap a sequence dataset into a task-specific dataset. + + Parameters + ---------- + base_sequence_dataset : Dataset + Source-data sequence dataset for a split. + + Returns + ------- + Dataset + Task-specific dataset for the split. + """ + raise NotImplementedError + + def setup(self, stage: str | None = None) -> None: + """Create train, validation, and test datasets. + + Parameters + ---------- + stage : str or None, optional + Lightning setup stage. Supports ``"fit"``, ``"validate"``, + ``"test"``, and ``None``. Default is ``None``. + + Raises + ------ + ValueError + If ``stage`` is unsupported. 
+ NotImplementedError + If a configured split mode is unsupported. + """ + if stage == "fit": + requested_splits = {"train", "val"} + elif stage == "validate": + requested_splits = {"val"} + elif stage == "test": + requested_splits = {"test"} + elif stage is None: + requested_splits = {"train", "val", "test"} + else: + raise ValueError(f"Unsupported LightningDataModule setup stage: {stage!r}") + + subset_per_split: dict[str, dict[str, Any] | None] = { + split_name: ( + {} + if split_name in requested_splits + and any(split_name in coord_splits for coord_splits in self.splits.values()) + else None + ) + for split_name in ("train", "val", "test") + } + + for coord, coord_splits in self.splits.items(): + if splitting_uses_tuple_ranges(coord_splits): + coord_values_per_split: dict[str, tuple[str, str] | None] = { + "train": coord_splits["train"], + "val": coord_splits["val"], + "test": coord_splits.get("test"), + } + elif splitting_uses_fractions(coord_splits): + coord_values_per_split = compute_split_ranges_from_splitting_ratios( + self.sequence_dataset_factory, coord, coord_splits + ) + else: + raise NotImplementedError(f"Unsupported split mode for coordinate {coord!r}: {coord_splits!r}") + + for split_name, split_val in coord_values_per_split.items(): + if split_val is None: + subset_per_split[split_name] = None + elif subset_per_split[split_name] is not None: + subset_per_split[split_name][coord] = split_val + + augment_flags = {"train": True, "val": False, "test": False} + for split in ("train", "val", "test"): + subset = subset_per_split[split] + if subset is None: + setattr(self, f"{split}_dataset", None) + else: + base_sequence_dataset = self._build_sequence_dataset(subset=subset, augment=augment_flags[split]) + setattr(self, f"{split}_dataset", self._wrap_sequence_dataset(base_sequence_dataset)) + + logger.info("{}.setup() complete, containing:", self.__class__.__name__) + for split in ("train", "val", "test"): + dataset = getattr(self, f"{split}_dataset", 
None) + if dataset is not None: + logger.info( + " {:5s}: {:>6d} samples, subset={}", + split, + len(dataset), + subset_per_split[split], + ) + + def train_dataloader(self) -> DataLoader: + """Return the training DataLoader. + + Returns + ------- + DataLoader + Training dataloader with shuffled samples. + """ + return DataLoader(self.train_dataset, shuffle=True, **self.dataloader_kwargs) + + def val_dataloader(self) -> DataLoader: + """Return the validation DataLoader. + + Returns + ------- + DataLoader + Validation dataloader without shuffling. + """ + return DataLoader(self.val_dataset, shuffle=False, **self.dataloader_kwargs) + + def test_dataloader(self) -> DataLoader: + """Return the test DataLoader. + + Returns + ------- + DataLoader + Test dataloader without shuffling. + """ + return DataLoader(self.test_dataset, shuffle=False, **self.dataloader_kwargs) + + +class ForecastingDataModule(_BaseDataModule): + """Lightning data module for forecasting datasets. + + Parameters + ---------- + sequence_dataset_factory : Callable[..., Dataset] + Factory that builds source-data sequence datasets. + input_steps : int + Number of input timesteps in each forecasting sample. + forecast_steps : int + Number of target timesteps in each forecasting sample. + return_mask : bool + Whether forecasting samples should include ``target_mask``. + splits : dict of str to dict + Nested mapping describing train/validation/test coordinate splits. + **dataloader_kwargs : Any + Additional keyword arguments forwarded to ``DataLoader``. 
+ """ + + def __init__( + self, + sequence_dataset_factory: Callable[..., Dataset], + input_steps: int, + forecast_steps: int, + return_mask: bool, + splits: dict[str, dict[str, Any]], + **dataloader_kwargs: Any, + ) -> None: + super().__init__(sequence_dataset_factory=sequence_dataset_factory, splits=splits, **dataloader_kwargs) + self.input_steps = input_steps + self.forecast_steps = forecast_steps + self.return_mask = return_mask + + def _wrap_sequence_dataset(self, base_sequence_dataset: Dataset) -> Dataset: + """Wrap a sequence dataset as a forecasting dataset. + + Parameters + ---------- + base_sequence_dataset : Dataset + Sequence dataset for one split. + + Returns + ------- + Dataset + Forecasting dataset for the split. + """ + return ForecastingDataset( + base_sequence_dataset=base_sequence_dataset, + input_steps=self.input_steps, + forecast_steps=self.forecast_steps, + return_mask=self.return_mask, + ) + + +class ReconstructionDataModule(_BaseDataModule): + """Lightning data module for reconstruction datasets. + + Parameters + ---------- + sequence_dataset_factory : Callable[..., Dataset] + Factory that builds source-data sequence datasets. + input_steps : int + Number of timesteps in each reconstruction window. + splits : dict of str to dict + Nested mapping describing train/validation/test coordinate splits. + **dataloader_kwargs : Any + Additional keyword arguments forwarded to ``DataLoader``. + """ + + def __init__( + self, + sequence_dataset_factory: Callable[..., Dataset], + input_steps: int, + splits: dict[str, dict[str, Any]], + **dataloader_kwargs: Any, + ) -> None: + super().__init__(sequence_dataset_factory=sequence_dataset_factory, splits=splits, **dataloader_kwargs) + self.input_steps = input_steps + + def _wrap_sequence_dataset(self, base_sequence_dataset: Dataset) -> Dataset: + """Wrap a sequence dataset as a reconstruction dataset. + + Parameters + ---------- + base_sequence_dataset : Dataset + Sequence dataset for one split. 
+ + Returns + ------- + Dataset + Reconstruction dataset for the split. + """ + return ReconstructionDataset(base_sequence_dataset=base_sequence_dataset, input_steps=self.input_steps) diff --git a/src/mlcast/data/forecasting.py b/src/mlcast/data/forecasting.py new file mode 100644 index 0000000..7cb2e7a --- /dev/null +++ b/src/mlcast/data/forecasting.py @@ -0,0 +1,111 @@ +"""Forecasting dataset wrappers built on top of sequence datasets.""" + +from typing import TypedDict + +import torch +from beartype import beartype +from jaxtyping import Float, jaxtyped +from torch import Tensor +from torch.utils.data import Dataset + + +class ForecastingSample(TypedDict, total=False): + """Typed dictionary returned by :class:`ForecastingDataset`. + + Keys + ---- + input : Float[Tensor, "input_steps channels height width"] + Past frames fed to the forecasting model. + target : Float[Tensor, "forecast_steps channels height width"] + Future frames the forecasting model should predict. + target_mask : Float[Tensor, "forecast_steps channels height width"] + Per-timestep, per-channel validity mask for the target when + ``return_mask=True``. + """ + + input: Float[Tensor, "input_steps channels height width"] + target: Float[Tensor, "forecast_steps channels height width"] + target_mask: Float[Tensor, "forecast_steps channels height width"] + + +class ForecastingDataset(Dataset): + """Wrap a sequence dataset to produce forecasting samples. + + Parameters + ---------- + base_sequence_dataset : Dataset + Dataset returning normalized sequence tensors of shape + ``(sequence_steps, channels, height, width)``. + input_steps : int + Number of past timesteps fed to the forecasting model. + forecast_steps : int + Number of future timesteps the forecasting model should predict. + return_mask : bool, optional + Whether to derive and return a target validity mask. Default is + ``False``. 
+ """ + + def __init__( + self, + base_sequence_dataset: Dataset, + input_steps: int, + forecast_steps: int, + return_mask: bool = False, + ) -> None: + if input_steps < 1: + raise ValueError(f"input_steps ({input_steps}) must be at least 1.") + if forecast_steps < 1: + raise ValueError(f"forecast_steps ({forecast_steps}) must be at least 1.") + + self.base_sequence_dataset = base_sequence_dataset + self.input_steps = input_steps + self.forecast_steps = forecast_steps + self.return_mask = return_mask + + sequence_steps = getattr(base_sequence_dataset, "sequence_steps", None) + if sequence_steps is None: + raise AttributeError("base_sequence_dataset must expose a 'sequence_steps' attribute.") + if input_steps + forecast_steps != sequence_steps: + raise ValueError( + "ForecastingDataset requires input_steps + forecast_steps to equal sequence_steps; " + f"got input_steps={input_steps}, forecast_steps={forecast_steps}, sequence_steps={sequence_steps}." + ) + + def __len__(self) -> int: + """Return the number of available forecasting samples. + + Returns + ------- + int + Number of samples in the wrapped sequence dataset. + """ + return len(self.base_sequence_dataset) + + @jaxtyped(typechecker=beartype) + def __getitem__(self, idx: int) -> ForecastingSample: + """Return one forecasting sample derived from the wrapped sequence. + + Parameters + ---------- + idx : int + Index of the wrapped sequence sample. + + Returns + ------- + ForecastingSample + Dictionary containing ``input`` and ``target`` tensors, and + ``target_mask`` when ``return_mask=True``. 
+ """ + sequence = self.base_sequence_dataset[idx] + + if self.return_mask: + target_mask_t = (~torch.isnan(sequence[self.input_steps :])).to(dtype=torch.float32) + + sequence = torch.nan_to_num(sequence, nan=-1.0).to(dtype=torch.float32) + input_t = sequence[: self.input_steps] + target_t = sequence[self.input_steps :] + + sample = ForecastingSample(input=input_t, target=target_t) + if self.return_mask: + sample["target_mask"] = target_mask_t + return sample diff --git a/src/mlcast/data/reconstruction.py b/src/mlcast/data/reconstruction.py new file mode 100644 index 0000000..6327262 --- /dev/null +++ b/src/mlcast/data/reconstruction.py @@ -0,0 +1,73 @@ +"""Reconstruction datasets built from sequence datasets. + +The reconstruction task reuses normalized source-data sequences and exposes all +overlapping temporal windows of length ``input_steps`` as individual training +samples. +""" + +import torch +from jaxtyping import Float +from torch import Tensor +from torch.utils.data import Dataset + + +class ReconstructionDataset(Dataset): + """Wrap a sequence dataset for stage-1 reconstruction training. + + Parameters + ---------- + base_sequence_dataset : Dataset + Dataset whose samples are normalized sequence tensors with shape + ``(sequence_steps, channels, height, width)``. + input_steps : int + Temporal window length to expose for each reconstruction sample. + + Notes + ----- + Each base sequence contributes all overlapping windows of length + ``input_steps``. The reconstruction training module is responsible for + reusing the returned tensor as both the model input and the reconstruction + target. 
+ """ + + def __init__(self, base_sequence_dataset: Dataset, input_steps: int) -> None: + if input_steps < 1: + raise ValueError(f"input_steps ({input_steps}) must be at least 1.") + + self.base_sequence_dataset = base_sequence_dataset + self.input_steps = input_steps + + sequence_steps = getattr(base_sequence_dataset, "sequence_steps", None) + if sequence_steps is None: + raise AttributeError("base_sequence_dataset must expose a 'sequence_steps' attribute.") + if input_steps > sequence_steps: + raise ValueError( + "ReconstructionDataset requires input_steps to be less than or equal to sequence_steps; " + f"got input_steps={input_steps}, sequence_steps={sequence_steps}." + ) + + self.sequence_steps = sequence_steps + self.windows_per_sequence = self.sequence_steps - self.input_steps + 1 + + def __len__(self) -> int: + """Return the number of available reconstruction windows.""" + return len(self.base_sequence_dataset) * self.windows_per_sequence + + def __getitem__(self, idx: int) -> Float[Tensor, "input_steps channels height width"]: + """Return one overlapping reconstruction window. + + Parameters + ---------- + idx : int + Flat reconstruction-sample index. + + Returns + ------- + Float[Tensor, "input_steps channels height width"] + Window extracted from the wrapped sequence sample. + """ + sequence_idx = idx // self.windows_per_sequence + window_start = idx % self.windows_per_sequence + sequence = self.base_sequence_dataset[sequence_idx] + window = sequence[window_start : window_start + self.input_steps] + return torch.nan_to_num(window, nan=-1.0).to(dtype=torch.float32) diff --git a/src/mlcast/data/source_data_datasets.py b/src/mlcast/data/sequence.py similarity index 51% rename from src/mlcast/data/source_data_datasets.py rename to src/mlcast/data/sequence.py index b8e3c80..391a23c 100644 --- a/src/mlcast/data/source_data_datasets.py +++ b/src/mlcast/data/sequence.py @@ -1,12 +1,14 @@ -"""PyTorch datasets for loading spatio-temporal data from Zarr stores. 
+"""Source-data sequence datasets built from Zarr stores. -Provides pre-computed sampling and (soon) random sampling datasets. +These datasets are responsible for sampling normalized spatio-temporal +sequences directly from source datasets. They do not impose any forecasting or +reconstruction task structure on the sampled sequence. """ import time import warnings from abc import ABC, abstractmethod -from typing import Any, TypedDict +from typing import Any import cf_xarray # noqa: F401 import numpy as np @@ -15,6 +17,7 @@ import xarray as xr from beartype import beartype from jaxtyping import Float, jaxtyped +from torch import Tensor from torch.utils.data import Dataset from mlcast.data.normalization import NORMALIZATION_REGISTRY @@ -25,7 +28,22 @@ def _time_range_to_index_slice( time_range: tuple[str, str], storage_options: dict[str, Any] | None = None, ) -> slice: - """Convert an inclusive ISO time range into a zarr integer slice.""" + """Convert an inclusive ISO time range into a zarr integer slice. + + Parameters + ---------- + zarr_path : str + Path to the Zarr dataset. + time_range : tuple of str + Inclusive ``(start, end)`` ISO 8601 time range. + storage_options : dict or None, optional + Options forwarded to ``xr.open_zarr``. Default is ``None``. + + Returns + ------- + slice + Integer slice covering the requested time range. + """ ds = xr.open_zarr(zarr_path, storage_options=storage_options) time_values = ds.indexes["time"] t_start = time_values.get_indexer([pd.Timestamp(time_range[0])], method="bfill")[0] @@ -38,46 +56,20 @@ def _time_range_to_index_slice( return slice(int(t_start), int(t_end) + 1) -class DatasetSample(TypedDict, total=False): - """Typed dictionary returned by dataset ``__getitem__``. - - Keys - ---- - input : Float[torch.Tensor, "input_steps channels height width"] - Past frames fed to the network as input. - target : Float[torch.Tensor, "forecast_steps channels height width"] - Future frames the network should predict. 
- target_mask : Float[torch.Tensor, "forecast_steps channels height width"] - Per-timestep, per-channel validity mask for the target (1 = valid, - 0 = NaN in original data). Only present when ``return_mask=True``. - """ - - input: Float[torch.Tensor, "input_steps channels height width"] - target: Float[torch.Tensor, "forecast_steps channels height width"] - target_mask: Float[torch.Tensor, "forecast_steps channels height width"] - - def _detect_axes(ds: xr.Dataset, standard_name: str) -> tuple[str, str, str]: """Detect CF axis dimension names for a variable in an xarray Dataset. - Falls back to dimension names ``'y'`` / ``'x'`` when CF conventions do not - identify the axis, emitting a :mod:`warnings` warning in each case. - Parameters ---------- ds : xr.Dataset - An open xarray Dataset with CF conventions. + Open xarray Dataset with CF metadata. standard_name : str - A CF standard name present in ``ds``, used to look up the variable. + CF standard name of the variable used to infer axes. Returns ------- - t_dim : str - Dimension name for the time axis. - y_dim : str - Dimension name for the Y (latitude) axis. - x_dim : str - Dimension name for the X (longitude) axis. + tuple of str + Names of the time, Y, and X dimensions. """ da = ds.cf[standard_name] t_dim = da.cf["time"].dims[0] @@ -103,12 +95,8 @@ def _detect_axes(ds: xr.Dataset, standard_name: str) -> tuple[str, str, str]: return t_dim, y_dim, x_dim -class SourceDataDatasetBase(Dataset, ABC): - """Abstract base class for mlcast Zarr-backed spatio-temporal datasets. - - Subclasses must implement :meth:`__len__` and :meth:`__getitem__`. - All common initialisation, Zarr access, CF-axis resolution, augmentation, - and the ``steps`` property live here. +class SourceDataSequenceDatasetBase(Dataset, ABC): + """Abstract base class for source-data-backed sequence datasets. Parameters ---------- @@ -116,13 +104,8 @@ class SourceDataDatasetBase(Dataset, ABC): Path to the Zarr dataset. 
standard_names : list of str List of CF standard names of variables to load. - input_steps : int - Number of past timesteps fed to the network as input. - forecast_steps : int - Number of future timesteps the network should predict. - return_mask : bool, optional - If ``True``, also return a per-timestep validity mask for the target. - Default is ``False``. + sequence_steps : int + Number of timesteps to include in each sampled sequence. deterministic : bool, optional If ``True``, use a fixed random seed (42). Default is ``False``. augment : bool, optional @@ -139,27 +122,21 @@ def __init__( self, zarr_path: str, standard_names: list[str], - input_steps: int, - forecast_steps: int, - return_mask: bool = False, + sequence_steps: int, deterministic: bool = False, augment: bool = False, width: int = 256, height: int = 256, storage_options: dict[str, Any] | None = None, ) -> None: - if input_steps < 1: - raise ValueError(f"input_steps ({input_steps}) must be at least 1.") - if forecast_steps < 1: - raise ValueError(f"forecast_steps ({forecast_steps}) must be at least 1.") + if sequence_steps < 1: + raise ValueError(f"sequence_steps ({sequence_steps}) must be at least 1.") self.storage_options = storage_options self._zarr_path = zarr_path self._ds: xr.Dataset | None = None self.standard_names = standard_names - self.input_steps = input_steps - self.forecast_steps = forecast_steps - self.return_mask = return_mask + self.sequence_steps = sequence_steps self.augment = augment self.w = width self.h = height @@ -168,34 +145,14 @@ def __init__( self._validate_standard_names() self.t_dim, self.y_dim, self.x_dim = _detect_axes(self.ds, self.standard_names[0]) - # ------------------------------------------------------------------ - # Properties - # ------------------------------------------------------------------ - - @property - def steps(self) -> int: - """Total number of timesteps per sample (``input_steps + forecast_steps``). 
- - Returns - ------- - steps : int - ``input_steps + forecast_steps``. - """ - return self.input_steps + self.forecast_steps - @property def ds(self) -> xr.Dataset: """Open and cache the Zarr-backed xarray Dataset for this worker. - The store is opened lazily on first access within each process. This - avoids pickling live asyncio connections across DataLoader worker - boundaries, which would cause ``RuntimeError: Future attached to a - different loop``. - Returns ------- - ds : xr.Dataset - The opened (and optionally time-sliced) xarray Dataset. + xr.Dataset + Opened dataset, optionally subset in time for this worker process. """ if self._ds is None: ds = xr.open_zarr(self._zarr_path, storage_options=self.storage_options) @@ -204,17 +161,13 @@ def ds(self) -> xr.Dataset: self._ds = ds return self._ds - # ------------------------------------------------------------------ - # Helpers - # ------------------------------------------------------------------ - def _validate_standard_names(self) -> None: """Check that every requested CF standard name exists in the Zarr store. Raises ------ ValueError - If a requested standard name is not found. + If any requested standard name is missing from the dataset. """ for std_name in self.standard_names: try: @@ -241,112 +194,100 @@ def _validate_standard_names(self) -> None: raise ValueError(msg) from e def _apply_augmentations( - self, *tensors: torch.Tensor, rotate_prob: float = 0.5, hflip_prob: float = 0.5, vflip_prob: float = 0.5 - ) -> tuple[torch.Tensor, ...]: - """Apply random spatial augmentations consistently to all input tensors.""" + self, + tensor: Float[Tensor, "sequence_steps channels height width"], + rotate_prob: float = 0.5, + hflip_prob: float = 0.5, + vflip_prob: float = 0.5, + ) -> Float[Tensor, "sequence_steps channels height width"]: + """Apply random spatial augmentations to a sequence tensor. 
+ + Parameters + ---------- + tensor : Float[Tensor, "sequence_steps channels height width"] + Sequence tensor to augment. + rotate_prob : float, optional + Probability of applying a random 90-degree rotation. Default is + ``0.5``. + hflip_prob : float, optional + Probability of applying a horizontal flip. Default is ``0.5``. + vflip_prob : float, optional + Probability of applying a vertical flip. Default is ``0.5``. + + Returns + ------- + Float[Tensor, "sequence_steps channels height width"] + Augmented contiguous tensor. + """ if self.rng.random() < rotate_prob: k = self.rng.integers(1, 4) - tensors = tuple(torch.rot90(t, int(k), dims=[-2, -1]) for t in tensors) + tensor = torch.rot90(tensor, int(k), dims=[-2, -1]) if self.rng.random() < hflip_prob: - tensors = tuple(torch.flip(t, dims=[-1]) for t in tensors) + tensor = torch.flip(tensor, dims=[-1]) if self.rng.random() < vflip_prob: - tensors = tuple(torch.flip(t, dims=[-2]) for t in tensors) - - return tuple(t.contiguous() for t in tensors) + tensor = torch.flip(tensor, dims=[-2]) - def _build_sample(self, data: np.ndarray) -> DatasetSample: - """Convert a raw ``(T, C, H, W)`` numpy array into a :class:`DatasetSample`. + return tensor.contiguous() - Computes the target mask (before ``nan_to_num``), splits into input / - target tensors along the time axis, applies augmentations if requested, - and assembles the final dict. + def _build_sequence(self, data: np.ndarray) -> Float[Tensor, "sequence_steps channels height width"]: + """Convert a raw ``(T, C, H, W)`` numpy array into a tensor. Parameters ---------- data : np.ndarray - Raw normalised array of shape ``(steps, C, H, W)`` — may contain - NaNs where the original data was invalid. + Normalized array with shape ``(sequence_steps, channels, height, + width)``. Returns ------- - sample : DatasetSample - Dictionary with ``'input'`` and ``'target'`` tensors, and - optionally ``'target_mask'`` if ``self.return_mask`` is ``True``. 
+ Float[Tensor, "sequence_steps channels height width"] + Float32 sequence tensor, augmented if requested. """ - # Capture target mask before NaNs are filled - if self.return_mask: - target_mask_t = torch.from_numpy((~np.isnan(data[self.input_steps :])).astype(np.float32)) - - # source data may be float64, but the model and the rest of the - # training pipeline operate in float32. - data = np.nan_to_num(data, nan=-1.0).astype(np.float32) - data_t = torch.from_numpy(data) - - input_t = data_t[: self.input_steps] - target_t = data_t[self.input_steps :] - + data = np.ascontiguousarray(data, dtype=np.float32) + sequence_t = torch.from_numpy(data) if self.augment: - tensors = (input_t, target_t, target_mask_t) if self.return_mask else (input_t, target_t) - augmented = self._apply_augmentations(*tensors) - if self.return_mask: - input_t, target_t, target_mask_t = augmented - else: - input_t, target_t = augmented - - sample = DatasetSample(input=input_t, target=target_t) - if self.return_mask: - sample["target_mask"] = target_mask_t - return sample - - # ------------------------------------------------------------------ - # Abstract interface - # ------------------------------------------------------------------ + sequence_t = self._apply_augmentations(sequence_t) + return sequence_t @abstractmethod def __len__(self) -> int: ... @abstractmethod - def __getitem__(self, idx: int) -> DatasetSample: ... - + def __getitem__(self, idx: int) -> Float[Tensor, "sequence_steps channels height width"]: ... -class SourceDataPrecomputedSamplingDataset(SourceDataDatasetBase): - """PyTorch dataset that loads spatio-temporal data from a Zarr store using - pre-sampled spatial-temporal coordinates from a CSV file. - Each sample is a spatio-temporal crop of shape ``(T, C, H, W)`` - converted to normalized data. +class SourceDataPrecomputedSequenceDataset(SourceDataSequenceDatasetBase): + """Sequence dataset using pre-sampled spatial-temporal coordinates from CSV. 
Parameters ---------- zarr_path : str Path to the Zarr dataset. csv_path : str - Path to the CSV file with columns ``(t, x, y)`` specifying the - top-left corner of each crop. + Path to the CSV file with ``t``, ``x``, and ``y`` crop coordinates. standard_names : list of str - List of CF standard names of variables to load (e.g., ``["rainfall_rate"]``). - input_steps : int - Number of past timesteps fed to the network as input. - forecast_steps : int - Number of future timesteps the network should predict. - return_mask : bool, optional - If ``True``, also return a per-timestep validity mask for the target. - Default is ``False``. + CF standard names of variables to load. + sequence_steps : int + Number of timesteps to include in each sampled sequence. deterministic : bool, optional - If ``True``, use a fixed random seed (42) for reproducibility. Default is ``False``. + If ``True``, use deterministic random sampling within precomputed time + windows. Default is ``False``. augment : bool, optional - If ``True``, apply random spatial augmentations (rotation, flips). Default is ``False``. + If ``True``, apply random spatial augmentations. Default is ``False``. subset : dict or None, optional Coordinate subsetting specification. Only ``{"time": (start, end)}`` - is supported, where the time range is inclusive and uses ISO strings. + is supported. Default is ``None``. width : int, optional Spatial width of each crop. Default is ``256``. height : int, optional Spatial height of each crop. Default is ``256``. time_depth : int, optional - Number of timesteps in the sampled window. Default is ``24``. + Number of timesteps in each precomputed sampled window. Default is + ``24``. + storage_options : dict or None, optional + Options forwarded to ``xr.open_zarr``. Default is ``None``. 
""" def __init__( @@ -354,9 +295,7 @@ def __init__( zarr_path: str, csv_path: str, standard_names: list[str], - input_steps: int, - forecast_steps: int, - return_mask: bool = False, + sequence_steps: int, deterministic: bool = False, augment: bool = False, subset: dict[str, Any] | None = None, @@ -379,9 +318,7 @@ def __init__( super().__init__( zarr_path=zarr_path, standard_names=standard_names, - input_steps=input_steps, - forecast_steps=forecast_steps, - return_mask=return_mask, + sequence_steps=sequence_steps, deterministic=deterministic, augment=augment, width=width, @@ -399,48 +336,45 @@ def __init__( self.dt = time_depth - if self.steps > self.dt: - print(f"Warning: requested steps ({self.steps}) > sampled time window ({self.dt})") + if self.sequence_steps > self.dt: + print(f"Warning: requested sequence_steps ({self.sequence_steps}) > sampled time window ({self.dt})") - # Close the store: metadata has been extracted into plain attributes above. - # Each DataLoader worker will reopen it via the `ds` property in its own - # event loop, avoiding asyncio "Future attached to a different loop" errors. self._ds = None def __len__(self) -> int: - """Get the number of samples in the dataset. + """Get the number of precomputed crop coordinates. Returns ------- - length : int - Number of samples. + int + Number of available sequence samples. """ return len(self.coords) @jaxtyped(typechecker=beartype) - def __getitem__(self, idx: int) -> DatasetSample: - """Load and return a single crop sample. + def __getitem__(self, idx: int) -> Float[Tensor, "sequence_steps channels height width"]: + """Load and return a single normalized sequence tensor. + + Parameters + ---------- + idx : int + Index of the precomputed crop coordinate. Returns ------- - sample : DatasetSample - Dictionary with keys ``'input'`` of shape - ``(input_steps, C, H, W)`` and ``'target'`` of shape - ``(forecast_steps, C, H, W)``. 
If ``return_mask`` is ``True``, - also contains ``'target_mask'`` of shape - ``(forecast_steps, C, H, W)`` with 1 where the original data was - valid and 0 where it was NaN. + Float[Tensor, "sequence_steps channels height width"] + Normalized sequence tensor sampled from the source dataset. """ t0, x0, y0 = self.coords.iloc[idx] x_slice = slice(int(x0), int(x0) + self.w) y_slice = slice(int(y0), int(y0) + self.h) - if self.steps < self.dt: - t_start = self.rng.integers(t0, t0 + self.dt - self.steps + 1) + if self.sequence_steps < self.dt: + t_start = self.rng.integers(t0, t0 + self.dt - self.sequence_steps + 1) else: t_start = t0 - t_slice = slice(int(t_start), int(t_start) + self.steps) + t_slice = slice(int(t_start), int(t_start) + self.sequence_steps) channels = [] for std_name in self.standard_names: @@ -448,56 +382,45 @@ def __getitem__(self, idx: int) -> DatasetSample: norm_func = NORMALIZATION_REGISTRY[std_name] channels.append(norm_func(da_var.values)) - # swapaxes returns a view; make it contiguous and float32 before - # handing it to _build_sample()/torch.from_numpy(). - data = np.ascontiguousarray(np.swapaxes(np.stack(channels, axis=0), 0, 1), dtype=np.float32) - return self._build_sample(data) - + data = np.swapaxes(np.stack(channels, axis=0), 0, 1) + return self._build_sequence(data) -class SourceDataRandomSamplingDataset(SourceDataDatasetBase): - """PyTorch dataset that performs on-the-fly random spatial and temporal - slicing of a Zarr store spatio-temporal data array. - Each sample is a spatio-temporal crop of shape ``(T, C, H, W)`` - converted to normalized data. +class SourceDataRandomSequenceDataset(SourceDataSequenceDatasetBase): + """Sequence dataset with on-the-fly random spatial and temporal sampling. Parameters ---------- zarr_path : str Path to the Zarr dataset. standard_names : list of str - List of CF standard names of variables to load (e.g., ``["rainfall_rate"]``). 
- input_steps : int - Number of past timesteps fed to the network as input. - forecast_steps : int - Number of future timesteps the network should predict. - return_mask : bool, optional - If ``True``, also return a per-timestep validity mask for the target. - Default is ``False``. + CF standard names of variables to load. + sequence_steps : int + Number of timesteps to include in each sampled sequence. deterministic : bool, optional - If ``True``, use a fixed random seed (42) for reproducibility. Default is ``False``. + If ``True``, use deterministic random sampling. Default is ``False``. augment : bool, optional - If ``True``, apply random spatial augmentations (rotation, flips). Default is ``False``. + If ``True``, apply random spatial augmentations. Default is ``False``. subset : dict or None, optional Coordinate subsetting specification. Only ``{"time": (start, end)}`` - is supported, where the time range is inclusive and uses ISO strings. + is supported. Default is ``None``. width : int, optional Spatial width of each crop. Default is ``256``. height : int, optional Spatial height of each crop. Default is ``256``. epoch_size : int, optional - Number of random samples to generate per epoch. Default is ``1000``. + Number of random samples exposed per epoch. Default is ``1000``. + storage_options : dict or None, optional + Options forwarded to ``xr.open_zarr``. Default is ``None``. **kwargs : Any - Ignored extra arguments (e.g. ``csv_path``) to allow drop-in replacement. + Ignored extra arguments to allow partial config reuse. 
""" def __init__( self, zarr_path: str, standard_names: list[str], - input_steps: int, - forecast_steps: int, - return_mask: bool = False, + sequence_steps: int, deterministic: bool = False, augment: bool = False, subset: dict[str, Any] | None = None, @@ -521,9 +444,7 @@ def __init__( super().__init__( zarr_path=zarr_path, standard_names=standard_names, - input_steps=input_steps, - forecast_steps=forecast_steps, - return_mask=return_mask, + sequence_steps=sequence_steps, deterministic=deterministic, augment=augment, width=width, @@ -538,47 +459,46 @@ def __init__( self.max_y = da_first_var.sizes[self.y_dim] self.max_x = da_first_var.sizes[self.x_dim] - if self.steps > self.max_t: - raise ValueError(f"Requested steps ({self.steps}) > available time dimension ({self.max_t})") + if self.sequence_steps > self.max_t: + raise ValueError( + f"Requested sequence_steps ({self.sequence_steps}) > available time dimension ({self.max_t})" + ) if self.h > self.max_y: raise ValueError(f"Requested height ({self.h}) > available Y dimension ({self.max_y})") if self.w > self.max_x: raise ValueError(f"Requested width ({self.w}) > available X dimension ({self.max_x})") - # Close the store: metadata has been extracted into plain attributes above. - # Each DataLoader worker will reopen it via the `ds` property in its own - # event loop, avoiding asyncio "Future attached to a different loop" errors. self._ds = None def __len__(self) -> int: - """Get the number of samples in the dataset. + """Get the configured random epoch size. Returns ------- - length : int - Number of samples. + int + Number of random sequence samples exposed per epoch. """ return self.epoch_size @jaxtyped(typechecker=beartype) - def __getitem__(self, idx: int) -> DatasetSample: - """Load and return a single randomly sampled datacube. + def __getitem__(self, idx: int) -> Float[Tensor, "sequence_steps channels height width"]: + """Load and return a single randomly sampled normalized sequence. 
+ + Parameters + ---------- + idx : int + Ignored sample index; each call draws a random crop. Returns ------- - sample : DatasetSample - Dictionary with keys ``'input'`` of shape - ``(input_steps, C, H, W)`` and ``'target'`` of shape - ``(forecast_steps, C, H, W)``. If ``return_mask`` is ``True``, - also contains ``'target_mask'`` of shape - ``(forecast_steps, C, H, W)`` with 1 where the original data was - valid and 0 where it was NaN. + Float[Tensor, "sequence_steps channels height width"] + Normalized sequence tensor sampled from the source dataset. """ - t_start = self.rng.integers(0, self.max_t - self.steps + 1) + t_start = self.rng.integers(0, self.max_t - self.sequence_steps + 1) y_start = self.rng.integers(0, self.max_y - self.h + 1) x_start = self.rng.integers(0, self.max_x - self.w + 1) - t_slice = slice(int(t_start), int(t_start) + self.steps) + t_slice = slice(int(t_start), int(t_start) + self.sequence_steps) y_slice = slice(int(y_start), int(y_start) + self.h) x_slice = slice(int(x_start), int(x_start) + self.w) @@ -588,7 +508,5 @@ def __getitem__(self, idx: int) -> DatasetSample: norm_func = NORMALIZATION_REGISTRY[std_name] channels.append(norm_func(da_var.values)) - # swapaxes returns a view; make it contiguous and float32 before - # handing it to _build_sample()/torch.from_numpy(). - data = np.ascontiguousarray(np.swapaxes(np.stack(channels, axis=0), 0, 1), dtype=np.float32) - return self._build_sample(data) + data = np.swapaxes(np.stack(channels, axis=0), 0, 1) + return self._build_sequence(data) diff --git a/src/mlcast/data/source_data_datamodule.py b/src/mlcast/data/source_data_datamodule.py deleted file mode 100644 index 44f99bd..0000000 --- a/src/mlcast/data/source_data_datamodule.py +++ /dev/null @@ -1,161 +0,0 @@ -"""PyTorch Lightning data module for spatio-temporal datasets. - -Handles train/val/test splitting and DataLoader creation from an injected -dataset factory. 
-""" - -from collections.abc import Callable -from typing import Any - -import pytorch_lightning as pl -from loguru import logger -from torch.utils.data import DataLoader, Dataset - -from mlcast.data.splits import ( - compute_split_ranges_from_splitting_ratios, - splitting_uses_fractions, - splitting_uses_tuple_ranges, - validate_splits, -) - - -class SourceDataDataModule(pl.LightningDataModule): - """PyTorch Lightning data module for spatio-temporal datasets. - - Handles train/val/test splitting and DataLoader creation by utilizing - an injected ``dataset_factory``. - - Parameters - ---------- - dataset_factory : Callable[..., Dataset] - A factory function (e.g., ``fdl.Partial``) that returns a Dataset instance. - It must accept ``subset`` and ``augment`` as keyword arguments. - splits : dict of {str: dict} - Nested mapping ``{coord: {split_name: value, ...}, ...}`` describing - train/val/test subsets. Currently only the ``time`` coordinate is - supported. Ratio mode uses float fractions, while datetime mode uses - inclusive ``(start, end)`` ISO 8601 string tuples. - **dataloader_kwargs : Any - Additional keyword arguments forwarded to ``DataLoader`` (e.g., - ``batch_size``, ``num_workers``, ``pin_memory``). - """ - - def __init__( - self, - dataset_factory: Callable[..., Dataset], - splits: dict[str, dict[str, Any]], - **dataloader_kwargs: Any, - ) -> None: - super().__init__() - self.dataset_factory = dataset_factory - self.splits = splits - self.dataloader_kwargs = dataloader_kwargs - validate_splits(self.splits) - - def setup(self, stage: str | None = None) -> None: - """Create train, validation, and test datasets. - - Splits are assembled into per-dataset ``subset`` dictionaries. - Datetime-mode splits are passed through unchanged, while ratio-mode - splits are first resolved against the zarr coordinate values and then - converted to inclusive coordinate ranges before dataset instantiation. 
- Dataset construction depends on the requested Lightning stage: - - - ``"fit"`` builds train and validation datasets; - - ``"validate"`` builds only the validation dataset; - - ``"test"`` builds only the test dataset; and - - ``None`` builds all configured datasets. - - Parameters - ---------- - stage : str | None, optional - Lightning stage hint controlling which datasets are constructed. - - Raises - ------ - ValueError - If ``stage`` is not one of ``None``, ``"fit"``, ``"validate"``, - or ``"test"``. - """ - if stage == "fit": - requested_splits = {"train", "val"} - elif stage == "validate": - requested_splits = {"val"} - elif stage == "test": - requested_splits = {"test"} - elif stage is None: - requested_splits = {"train", "val", "test"} - else: - raise ValueError(f"Unsupported LightningDataModule setup stage: {stage!r}") - - subset_per_split: dict[str, dict[str, Any] | None] = { - split_name: ( - {} - if split_name in requested_splits - and any(split_name in coord_splits for coord_splits in self.splits.values()) - else None - ) - for split_name in ("train", "val", "test") - } - - for coord, coord_splits in self.splits.items(): - if splitting_uses_tuple_ranges(coord_splits): - # tuple-based splits are expected to present the start and end - # of each split, and so are passed through directly as the - # subset values for each split - coord_values_per_split: dict[str, tuple[str, str] | None] = { - "train": coord_splits["train"], - "val": coord_splits["val"], - "test": coord_splits.get("test"), - } - elif splitting_uses_fractions(coord_splits): - # for ratio-based splits, the splitting start-end range tuples - # are constructed by breaking up the given coordinate in - # successive segments (the succession is defined from the order - # of the keys in the splits dict) - coord_values_per_split = compute_split_ranges_from_splitting_ratios( - self.dataset_factory, coord, coord_splits - ) - else: - raise NotImplementedError(f"Unsupported split mode for coordinate 
{coord!r}: {coord_splits!r}") - - for split_name, split_val in coord_values_per_split.items(): - if split_val is None: - subset_per_split[split_name] = None - elif subset_per_split[split_name] is not None: - subset_per_split[split_name][coord] = split_val - - augment_flags = {"train": True, "val": False, "test": False} - for split in ("train", "val", "test"): - subset = subset_per_split[split] - if subset is None: - setattr(self, f"{split}_dataset", None) - else: - setattr( - self, - f"{split}_dataset", - self.dataset_factory(subset=subset, augment=augment_flags[split]), - ) - - logger.info("{}.setup() complete, containing:", self.__class__.__name__) - for split in ("train", "val", "test"): - dataset = getattr(self, f"{split}_dataset", None) - if dataset is not None: - logger.info( - " {:5s}: {:>6d} samples, subset={}", - split, - len(dataset), - subset_per_split[split], - ) - - def train_dataloader(self) -> DataLoader: - """Return the training DataLoader.""" - return DataLoader(self.train_dataset, shuffle=True, **self.dataloader_kwargs) - - def val_dataloader(self) -> DataLoader: - """Return the validation DataLoader.""" - return DataLoader(self.val_dataset, shuffle=False, **self.dataloader_kwargs) - - def test_dataloader(self) -> DataLoader: - """Return the test DataLoader.""" - return DataLoader(self.test_dataset, shuffle=False, **self.dataloader_kwargs) diff --git a/tests/config/test_consistency_checks.py b/tests/config/test_consistency_checks.py index 5277d30..97a7f81 100644 --- a/tests/config/test_consistency_checks.py +++ b/tests/config/test_consistency_checks.py @@ -4,7 +4,7 @@ from loguru import logger from mlcast.config import convgru_training_experiment, validate_config -from mlcast.data.source_data_datasets import SourceDataPrecomputedSamplingDataset +from mlcast.data.sequence import SourceDataPrecomputedSequenceDataset def test_contract_1_input_channels() -> None: @@ -12,7 +12,7 @@ def test_contract_1_input_channels() -> None: cfg = 
convgru_training_experiment.as_buildable() # Break Contract 1 cfg.pl_module.network.input_channels = 2 - cfg.data.dataset_factory.standard_names = ["rainfall_rate"] + cfg.data.sequence_dataset_factory.standard_names = ["rainfall_rate"] with pytest.raises(ValueError, match="Contract 1 violated:"): validate_config(cfg) @@ -22,7 +22,7 @@ def test_contract_2_spatial_divisibility() -> None: """Verify Contract 2: Dataset width must be divisible by 2 \\*\\* network.num_blocks.""" cfg = convgru_training_experiment.as_buildable() # Break Contract 2 - cfg.data.dataset_factory.width = 250 + cfg.data.sequence_dataset_factory.width = 250 cfg.pl_module.network.num_blocks = 4 with pytest.raises(ValueError, match="Contract 2 violated:"): @@ -64,20 +64,19 @@ def test_contract_4_masking_sync() -> None: """Verify Contract 4: Dataset return_mask must match model masked_loss.""" cfg = convgru_training_experiment.as_buildable() # Break Contract 4 - cfg.data.dataset_factory.return_mask = True + cfg.data.return_mask = True cfg.pl_module.masked_loss = False with pytest.raises(ValueError, match="Contract 4 violated:"): validate_config(cfg) -def test_dataset_forecast_steps_guard() -> None: - """Verify that dataset raises ValueError when input_steps=0.""" - with pytest.raises(ValueError, match="input_steps"): - SourceDataPrecomputedSamplingDataset( +def test_dataset_sequence_steps_guard() -> None: + """Verify that sequence dataset raises ValueError when sequence_steps=0.""" + with pytest.raises(ValueError, match="sequence_steps"): + SourceDataPrecomputedSequenceDataset( zarr_path="dummy.zarr", csv_path="dummy.csv", standard_names=["rainfall_rate"], - input_steps=0, - forecast_steps=5, + sequence_steps=0, ) diff --git a/tests/config/test_fiddlers.py b/tests/config/test_fiddlers.py index 001f1ed..347f2e7 100644 --- a/tests/config/test_fiddlers.py +++ b/tests/config/test_fiddlers.py @@ -9,7 +9,7 @@ def test_fiddler_set_variables() -> None: set_variables(cfg, ["rainfall_rate", "rainfall_flux"]) # 
Check sync - assert cfg.data.dataset_factory.standard_names == ["rainfall_rate", "rainfall_flux"] + assert cfg.data.sequence_dataset_factory.standard_names == ["rainfall_rate", "rainfall_flux"] assert cfg.pl_module.network.input_channels == 2 @@ -19,10 +19,10 @@ def test_fiddler_toggle_masking() -> None: # Disable masking toggle_masking(cfg, False) - assert cfg.data.dataset_factory.return_mask is False + assert cfg.data.return_mask is False assert cfg.pl_module.masked_loss is False # Enable masking toggle_masking(cfg, True) - assert cfg.data.dataset_factory.return_mask is True + assert cfg.data.return_mask is True assert cfg.pl_module.masked_loss is True diff --git a/tests/data/test_data_module.py b/tests/data/test_data_module.py index e55bf58..8968500 100644 --- a/tests/data/test_data_module.py +++ b/tests/data/test_data_module.py @@ -3,28 +3,29 @@ import pandas as pd import pytest +import torch from torch.utils.data import DataLoader, Dataset -from mlcast.data.source_data_datamodule import SourceDataDataModule +from mlcast.data.datamodules import ForecastingDataModule, ReconstructionDataModule +from mlcast.data.forecasting import ForecastingDataset +from mlcast.data.reconstruction import ReconstructionDataset from mlcast.data.splits import splitting_uses_fractions, splitting_uses_tuple_ranges, validate_splits -class MockDataset(Dataset): - """Minimal dataset mock that records how it was constructed. - - ``__len__`` returns a fixed size so that dataloader batch-count assertions - work correctly. 
- """ +class MockSequenceDataset(Dataset): + """Minimal sequence dataset mock that records how it was constructed.""" def __init__( self, zarr_path: str, + sequence_steps: int, subset: dict | None = None, augment: bool = False, epoch_size: int = 100, **kwargs, ) -> None: self.zarr_path = zarr_path + self.sequence_steps = sequence_steps self.subset = subset self.augment = augment self.epoch_size = epoch_size @@ -33,8 +34,9 @@ def __init__( def __len__(self) -> int: return self.epoch_size - def __getitem__(self, idx: int) -> dict: - return {"data": idx} + def __getitem__(self, idx: int) -> torch.Tensor: + base = torch.arange(self.sequence_steps, dtype=torch.float32)[:, None, None, None] + return base.expand(-1, 1, 4, 4) def _mock_zarr(time_index: pd.DatetimeIndex) -> MagicMock: @@ -109,30 +111,41 @@ def test_splitting_mode_helpers_require_consistent_values() -> None: assert not splitting_uses_tuple_ranges({"train": object(), "val": object()}) -def test_data_module_ratio_splits() -> None: +def test_forecasting_data_module_ratio_splits() -> None: """DataModule ratio mode passes correct time subsets to the factory.""" n = 100 time_index = _make_time_index(n) - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr", foo="bar") + sequence_dataset_factory = functools.partial( + MockSequenceDataset, + zarr_path="mock.zarr", + sequence_steps=6, + foo="bar", + ) - dm = SourceDataDataModule( - dataset_factory=dataset_factory, splits={"time": {"train": 0.5, "val": 0.2, "test": 0.3}}, batch_size=2 + dm = ForecastingDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=2, + forecast_steps=4, + return_mask=True, + splits={"time": {"train": 0.5, "val": 0.2, "test": 0.3}}, + batch_size=2, ) with patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)): dm.setup(stage="fit") - assert dm.train_dataset.augment is True - assert dm.train_dataset.kwargs["foo"] == "bar" - train_start, train_end = dm.train_dataset.subset["time"] 
- val_start, val_end = dm.val_dataset.subset["time"] + assert isinstance(dm.train_dataset, ForecastingDataset) + assert dm.train_dataset.base_sequence_dataset.augment is True + assert dm.train_dataset.base_sequence_dataset.kwargs["foo"] == "bar" + train_start, train_end = dm.train_dataset.base_sequence_dataset.subset["time"] + val_start, val_end = dm.val_dataset.base_sequence_dataset.subset["time"] assert train_start == str(time_index[0]) assert train_end == str(time_index[49]) assert val_start == str(time_index[50]) assert val_end == str(time_index[69]) - assert dm.val_dataset.augment is False + assert dm.val_dataset.base_sequence_dataset.augment is False assert dm.test_dataset is None train_dl = dm.train_dataloader() @@ -147,18 +160,27 @@ class _NoZarrPathFactory: def __call__(self, **kwargs) -> Dataset: return MagicMock(spec=Dataset) - dm = SourceDataDataModule(dataset_factory=_NoZarrPathFactory(), splits={"time": {"train": 0.7, "val": 0.15}}) + dm = ForecastingDataModule( + sequence_dataset_factory=_NoZarrPathFactory(), + input_steps=2, + forecast_steps=4, + return_mask=True, + splits={"time": {"train": 0.7, "val": 0.15}}, + ) with pytest.raises((AttributeError, KeyError)): dm.setup() -def test_data_module_fraction_splits_without_test_do_not_create_test_dataset() -> None: - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") +def test_forecasting_data_module_fraction_splits_without_test_do_not_create_test_dataset() -> None: + sequence_dataset_factory = functools.partial(MockSequenceDataset, zarr_path="mock.zarr", sequence_steps=6) time_index = _make_time_index(100) - dm = SourceDataDataModule( - dataset_factory=dataset_factory, + dm = ForecastingDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=2, + forecast_steps=4, + return_mask=True, splits={"time": {"train": 0.5, "val": 0.2}}, batch_size=2, ) @@ -171,18 +193,22 @@ def test_data_module_fraction_splits_without_test_do_not_create_test_dataset() - assert 
dm.test_dataset is None -def test_data_module_split_lengths_and_batches() -> None: - """Test that dataset lengths and dataloader batch counts are correct after splitting. - - Dataloader batch counts are correct after splitting. - """ +def test_forecasting_data_module_split_lengths_and_batches() -> None: n_time = 240 batch_size = 10 time_index = _make_time_index(n_time) - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr", epoch_size=10) + sequence_dataset_factory = functools.partial( + MockSequenceDataset, + zarr_path="mock.zarr", + sequence_steps=6, + epoch_size=10, + ) - dm = SourceDataDataModule( - dataset_factory=dataset_factory, + dm = ForecastingDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=2, + forecast_steps=4, + return_mask=True, splits={"time": {"train": 1 / 2, "val": 1 / 3, "test": 1 / 6}}, batch_size=batch_size, ) @@ -195,11 +221,14 @@ def test_data_module_split_lengths_and_batches() -> None: assert len(dm.test_dataloader()) == 1 -def test_data_module_datetime_splits() -> None: - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") +def test_forecasting_data_module_datetime_splits() -> None: + sequence_dataset_factory = functools.partial(MockSequenceDataset, zarr_path="mock.zarr", sequence_steps=6) - dm = SourceDataDataModule( - dataset_factory=dataset_factory, + dm = ForecastingDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=2, + forecast_steps=4, + return_mask=True, splits={ "time": { "train": ("2016-01-01", "2021-12-31"), @@ -212,36 +241,20 @@ def test_data_module_datetime_splits() -> None: dm.setup() - assert dm.train_dataset.subset == {"time": ("2016-01-01", "2021-12-31")} - assert dm.val_dataset.subset == {"time": ("2022-01-01", "2023-12-31")} + assert dm.train_dataset.base_sequence_dataset.subset == {"time": ("2016-01-01", "2021-12-31")} + assert dm.val_dataset.base_sequence_dataset.subset == {"time": ("2022-01-01", "2023-12-31")} assert 
dm.test_dataset is None -def test_data_module_fraction_test_split_uses_explicit_fraction() -> None: - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") - time_index = _make_time_index(100) - - dm = SourceDataDataModule( - dataset_factory=dataset_factory, - splits={"time": {"train": 0.5, "val": 0.2, "test": 0.1}}, - batch_size=2, +def test_reconstruction_data_module_wraps_sequence_splits() -> None: + sequence_dataset_factory = functools.partial( + MockSequenceDataset, zarr_path="mock.zarr", sequence_steps=5, epoch_size=5 ) - - with patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)): - dm.setup() - - assert dm.test_dataset is not None - test_start, test_end = dm.test_dataset.subset["time"] - assert test_start == str(time_index[70]) - assert test_end == str(time_index[79]) - - -def test_data_module_fit_stage_creates_only_train_and_val() -> None: - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") time_index = _make_time_index(100) - dm = SourceDataDataModule( - dataset_factory=dataset_factory, + dm = ReconstructionDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=3, splits={"time": {"train": 0.5, "val": 0.2, "test": 0.1}}, batch_size=2, ) @@ -249,81 +262,59 @@ def test_data_module_fit_stage_creates_only_train_and_val() -> None: with patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)): dm.setup(stage="fit") - assert dm.train_dataset is not None - assert dm.val_dataset is not None + assert isinstance(dm.train_dataset, ReconstructionDataset) + assert isinstance(dm.val_dataset, ReconstructionDataset) assert dm.test_dataset is None + assert dm.train_dataset.base_sequence_dataset.augment is True + assert dm.val_dataset.base_sequence_dataset.augment is False + assert dm.train_dataset[0].shape == (3, 1, 4, 4) -def test_data_module_validate_stage_creates_only_val() -> None: - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") +def 
test_data_module_validate_test_and_logging_paths() -> None: + sequence_dataset_factory = functools.partial(MockSequenceDataset, zarr_path="mock.zarr", sequence_steps=6) time_index = _make_time_index(100) - dm = SourceDataDataModule( - dataset_factory=dataset_factory, + dm = ForecastingDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=2, + forecast_steps=4, + return_mask=True, splits={"time": {"train": 0.5, "val": 0.2, "test": 0.1}}, batch_size=2, ) with patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)): dm.setup(stage="validate") - assert dm.train_dataset is None assert dm.val_dataset is not None assert dm.test_dataset is None - -def test_data_module_test_stage_creates_only_test() -> None: - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") - time_index = _make_time_index(100) - - dm = SourceDataDataModule( - dataset_factory=dataset_factory, - splits={"time": {"train": 0.5, "val": 0.2, "test": 0.1}}, - batch_size=2, - ) - with patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)): dm.setup(stage="test") - assert dm.train_dataset is None assert dm.val_dataset is None assert dm.test_dataset is not None - -def test_data_module_rejects_unknown_stage() -> None: - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") - dm = SourceDataDataModule( - dataset_factory=dataset_factory, - splits={"time": {"train": 0.5, "val": 0.2, "test": 0.1}}, - batch_size=2, - ) - with pytest.raises(ValueError, match="Unsupported LightningDataModule setup stage"): dm.setup(stage="predict") - -def test_data_module_logs_split_summary() -> None: - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") - time_index = _make_time_index(100) - - dm = SourceDataDataModule( - dataset_factory=dataset_factory, - splits={"time": {"train": 0.5, "val": 0.2, "test": 0.1}}, - batch_size=2, - ) - with ( patch("mlcast.data.splits.xr.open_zarr", 
return_value=_mock_zarr(time_index)), - patch("mlcast.data.source_data_datamodule.logger.info") as mock_info, + patch("mlcast.data.datamodules.logger.info") as mock_info, ): dm.setup() - assert mock_info.call_count == 4 def test_data_module_unsupported_split_mode() -> None: - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") - dm = SourceDataDataModule(dataset_factory=dataset_factory, splits={"time": {"train": 0.7, "val": 0.15}}) + sequence_dataset_factory = functools.partial(MockSequenceDataset, zarr_path="mock.zarr", sequence_steps=6) + dm = ForecastingDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=2, + forecast_steps=4, + return_mask=True, + splits={"time": {"train": 0.7, "val": 0.15}}, + ) dm.splits = {"time": {"train": object(), "val": object()}} diff --git a/tests/data/test_source_data_datasets.py b/tests/data/test_source_data_datasets.py index c9a2f2f..c5b036c 100644 --- a/tests/data/test_source_data_datasets.py +++ b/tests/data/test_source_data_datasets.py @@ -4,11 +4,23 @@ import pytest import torch import xarray as xr +from torch.utils.data import Dataset -from mlcast.data.source_data_datasets import ( - SourceDataPrecomputedSamplingDataset, - SourceDataRandomSamplingDataset, -) +from mlcast.data.forecasting import ForecastingDataset +from mlcast.data.reconstruction import ReconstructionDataset +from mlcast.data.sequence import SourceDataPrecomputedSequenceDataset, SourceDataRandomSequenceDataset + + +class MockSequenceDataset(Dataset): + def __init__(self, sequence_steps: int, num_samples: int = 2) -> None: + self.sequence_steps = sequence_steps + self.num_samples = num_samples + + def __len__(self) -> int: + return self.num_samples + + def __getitem__(self, idx: int) -> torch.Tensor: + return torch.arange(self.sequence_steps, dtype=torch.float32)[:, None, None, None].expand(-1, 1, 2, 2) @pytest.fixture @@ -26,110 +38,75 @@ def mock_csv(tmp_path: Path) -> str: return str(csv_path) -def 
test_precomputed_sampling_dataset(fp_test_dataset: Path, mock_csv: str) -> None: - """Test that SourceDataPrecomputedSamplingDataset outputs the correct shape.""" - input_steps = 2 - forecast_steps = 1 - ds = SourceDataPrecomputedSamplingDataset( +def test_precomputed_sequence_dataset(fp_test_dataset: Path, mock_csv: str) -> None: + """Precomputed sequence dataset should output normalized sequence tensors.""" + sequence_steps = 3 + ds = SourceDataPrecomputedSequenceDataset( zarr_path=str(fp_test_dataset), csv_path=mock_csv, standard_names=["rainfall_flux"], - input_steps=input_steps, - forecast_steps=forecast_steps, + sequence_steps=sequence_steps, width=16, height=16, - return_mask=True, ) assert len(ds) == 3 sample = ds[0] + assert sample.shape == (sequence_steps, 1, 16, 16) + assert sample.dtype == torch.float32 - assert "input" in sample - assert "target" in sample - assert "target_mask" in sample - - input_t = sample["input"] - target_t = sample["target"] - target_mask_t = sample["target_mask"] - assert input_t.shape == (input_steps, 1, 16, 16) - assert target_t.shape == (forecast_steps, 1, 16, 16) - assert target_mask_t.shape == (forecast_steps, 1, 16, 16) - assert isinstance(input_t, torch.Tensor) - assert isinstance(target_t, torch.Tensor) - assert isinstance(target_mask_t, torch.Tensor) - - -def test_precomputed_sampling_dataset_time_subset(fp_test_dataset: Path, mock_csv: str) -> None: - """Test that subset correctly filters CSV rows by time range.""" +def test_precomputed_sequence_dataset_time_subset(fp_test_dataset: Path, mock_csv: str) -> None: + """Subset should correctly filter CSV rows by time range.""" zarr_ds = xr.open_zarr(str(fp_test_dataset)) time_index = zarr_ds.indexes["time"] - ds = SourceDataPrecomputedSamplingDataset( + ds = SourceDataPrecomputedSequenceDataset( zarr_path=str(fp_test_dataset), csv_path=mock_csv, standard_names=["rainfall_flux"], - input_steps=2, - forecast_steps=1, + sequence_steps=3, subset={"time": (str(time_index[0]), 
str(time_index[8]))}, ) assert len(ds) == 2 -def test_precomputed_sampling_dataset_forecast_steps_guard(fp_test_dataset: Path, mock_csv: str) -> None: - """Test that instantiation with input_steps=0 raises ValueError.""" - with pytest.raises(ValueError, match="input_steps"): - SourceDataPrecomputedSamplingDataset( +def test_precomputed_sequence_dataset_sequence_steps_guard(fp_test_dataset: Path, mock_csv: str) -> None: + """Instantiation with sequence_steps=0 should raise ValueError.""" + with pytest.raises(ValueError, match="sequence_steps"): + SourceDataPrecomputedSequenceDataset( zarr_path=str(fp_test_dataset), csv_path=mock_csv, standard_names=["rainfall_flux"], - input_steps=0, - forecast_steps=3, + sequence_steps=0, ) -def test_random_sampling_dataset(fp_test_dataset: Path) -> None: - """Test that SourceDataRandomSamplingDataset outputs the correct shape.""" - input_steps = 3 - forecast_steps = 2 - ds = SourceDataRandomSamplingDataset( +def test_random_sequence_dataset(fp_test_dataset: Path) -> None: + """Random sequence dataset should output normalized sequence tensors.""" + sequence_steps = 5 + ds = SourceDataRandomSequenceDataset( zarr_path=str(fp_test_dataset), standard_names=["rainfall_flux"], - input_steps=input_steps, - forecast_steps=forecast_steps, + sequence_steps=sequence_steps, width=32, height=32, epoch_size=10, - return_mask=True, ) assert len(ds) == 10 sample = ds[0] + assert sample.shape == (sequence_steps, 1, 32, 32) + assert sample.dtype == torch.float32 - assert "input" in sample - assert "target" in sample - assert "target_mask" in sample - - input_t = sample["input"] - target_t = sample["target"] - target_mask_t = sample["target_mask"] - - assert input_t.shape == (input_steps, 1, 32, 32) - assert target_t.shape == (forecast_steps, 1, 32, 32) - assert target_mask_t.shape == (forecast_steps, 1, 32, 32) - assert input_t.dtype == torch.float32 - assert target_t.dtype == torch.float32 - assert target_mask_t.dtype == torch.float32 - -def 
test_random_sampling_dataset_time_subset(fp_test_dataset: Path) -> None: - """Test that subset correctly slices the Zarr store.""" +def test_random_sequence_dataset_time_subset(fp_test_dataset: Path) -> None: + """Subset should correctly slice the Zarr store.""" zarr_ds = xr.open_zarr(str(fp_test_dataset)) time_index = zarr_ds.indexes["time"] - ds = SourceDataRandomSamplingDataset( + ds = SourceDataRandomSequenceDataset( zarr_path=str(fp_test_dataset), standard_names=["rainfall_flux"], - input_steps=3, - forecast_steps=2, + sequence_steps=5, subset={"time": (str(time_index[0]), str(time_index[49]))}, epoch_size=10, ) @@ -138,12 +115,26 @@ def test_random_sampling_dataset_time_subset(fp_test_dataset: Path) -> None: assert len(ds) == 10 -def test_random_sampling_dataset_forecast_steps_guard(fp_test_dataset: Path) -> None: - """Test that instantiation with input_steps=0 raises ValueError.""" - with pytest.raises(ValueError, match="input_steps"): - SourceDataRandomSamplingDataset( - zarr_path=str(fp_test_dataset), - standard_names=["rainfall_flux"], - input_steps=0, - forecast_steps=5, - ) +def test_forecasting_dataset_splits_sequence_and_derives_mask() -> None: + """ForecastingDataset should split one sequence into input and target tensors.""" + base_dataset = MockSequenceDataset(sequence_steps=6) + dataset = ForecastingDataset(base_dataset, input_steps=2, forecast_steps=4, return_mask=True) + + sample = dataset[0] + assert sample["input"].shape == (2, 1, 2, 2) + assert sample["target"].shape == (4, 1, 2, 2) + assert sample["target_mask"].shape == (4, 1, 2, 2) + assert torch.all(sample["target_mask"] == 1.0) + + +def test_reconstruction_dataset_creates_overlapping_windows() -> None: + """ReconstructionDataset should expose all overlapping windows.""" + base_dataset = MockSequenceDataset(sequence_steps=5, num_samples=2) + dataset = ReconstructionDataset(base_dataset, input_steps=3) + + assert len(dataset) == 6 + first_window = dataset[0] + second_window = dataset[1] + 
assert first_window.shape == (3, 1, 2, 2) + assert torch.equal(first_window[:, 0, 0, 0], torch.tensor([0.0, 1.0, 2.0])) + assert torch.equal(second_window[:, 0, 0, 0], torch.tensor([1.0, 2.0, 3.0])) diff --git a/tests/test_cli_training.py b/tests/test_cli_training.py index 86e8e29..453dd79 100644 --- a/tests/test_cli_training.py +++ b/tests/test_cli_training.py @@ -25,9 +25,9 @@ def test_cli_train_command(fp_test_dataset: Path, tmp_path: Path) -> None: "--config", "fiddler:use_random_sampler", "--config", - f"set:data.dataset_factory.zarr_path='{fp_test_dataset.absolute()}'", + f"set:data.sequence_dataset_factory.zarr_path='{fp_test_dataset.absolute()}'", "--config", - "set:data.dataset_factory.standard_names=['rainfall_flux']", + "set:data.sequence_dataset_factory.standard_names=['rainfall_flux']", "--config", "set:data.splits={'time': {'train': 0.4, 'val': 0.3, 'test': 0.3}}", "--config", @@ -59,8 +59,8 @@ def test_cli_train_from_yaml_config(fp_test_dataset: Path, tmp_path: Path) -> No cfg = convgru_training_experiment.as_buildable() # Switch to random sampler (no CSV required) and use the correct variable name use_random_sampler(cfg) - cfg.data.dataset_factory.standard_names = ["rainfall_flux"] - cfg.data.dataset_factory.zarr_path = str(fp_test_dataset.absolute()) + cfg.data.sequence_dataset_factory.standard_names = ["rainfall_flux"] + cfg.data.sequence_dataset_factory.zarr_path = str(fp_test_dataset.absolute()) cfg.data.splits = {"time": {"train": 0.4, "val": 0.3, "test": 0.3}} cfg.trainer.fast_dev_run = True cfg.data.batch_size = 1 diff --git a/tests/test_readme_snippets.py b/tests/test_readme_snippets.py index 7dc3ace..c74cbec 100644 --- a/tests/test_readme_snippets.py +++ b/tests/test_readme_snippets.py @@ -120,7 +120,7 @@ def _patch_cfg(cfg: fdl.Config, fp_dataset: Path, tmp_path: Path) -> None: tmp_path : Path Pytest-provided temporary directory for trainer outputs. 
""" - cfg.data.dataset_factory.zarr_path = str(fp_dataset.absolute()) + cfg.data.sequence_dataset_factory.zarr_path = str(fp_dataset.absolute()) set_variables(cfg, standard_names=["rainfall_flux"]) # Switch to the on-the-fly random sampler so no pre-computed CSV is needed. use_random_sampler(cfg) From 2c2ea199909e8a7ef10e9b0283a75d69e6d81143 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 20 May 2026 14:36:28 +0200 Subject: [PATCH 11/34] feat: add autoencoder architecture --- src/mlcast/models/autoencoder/__init__.py | 7 ++ src/mlcast/models/autoencoder/decoder.py | 120 ++++++++++++++++++++++ src/mlcast/models/autoencoder/encoder.py | 115 +++++++++++++++++++++ src/mlcast/models/autoencoder/net.py | 80 +++++++++++++++ tests/models/test_autoencoder.py | 90 ++++++++++++++++ 5 files changed, 412 insertions(+) create mode 100644 src/mlcast/models/autoencoder/__init__.py create mode 100644 src/mlcast/models/autoencoder/decoder.py create mode 100644 src/mlcast/models/autoencoder/encoder.py create mode 100644 src/mlcast/models/autoencoder/net.py create mode 100644 tests/models/test_autoencoder.py diff --git a/src/mlcast/models/autoencoder/__init__.py b/src/mlcast/models/autoencoder/__init__.py new file mode 100644 index 0000000..fec971a --- /dev/null +++ b/src/mlcast/models/autoencoder/__init__.py @@ -0,0 +1,7 @@ +"""Autoencoder architecture components for reconstruction pretraining.""" + +from .decoder import Decoder, DecoderBlock +from .encoder import Encoder, EncoderBlock +from .net import AutoencoderNet + +__all__ = ["AutoencoderNet", "Decoder", "DecoderBlock", "Encoder", "EncoderBlock"] diff --git a/src/mlcast/models/autoencoder/decoder.py b/src/mlcast/models/autoencoder/decoder.py new file mode 100644 index 0000000..ea3b8a1 --- /dev/null +++ b/src/mlcast/models/autoencoder/decoder.py @@ -0,0 +1,120 @@ +"""Decoder blocks for the reconstruction autoencoder.""" + +import torch +import torch.nn as nn +from beartype import beartype +from jaxtyping import Float, 
jaxtyped + + +class DecoderBlock(nn.Module): + """Spatio-temporal decoder block with optional spatial upsampling. + + Parameters + ---------- + in_channels : int + Number of channels in the input tensor. + out_channels : int + Number of channels produced by the block. + upsample : bool, optional + If ``True``, double the spatial resolution with a transposed + convolution. Default is ``True``. + """ + + def __init__(self, in_channels: int, out_channels: int, upsample: bool = True) -> None: + super().__init__() + spatial_stride = 2 if upsample else 1 + output_padding = (0, 1, 1) if upsample else 0 + self.net = nn.Sequential( + nn.ConvTranspose3d( + in_channels, + out_channels, + kernel_size=3, + stride=(1, spatial_stride, spatial_stride), + padding=1, + output_padding=output_padding, + ), + nn.GroupNorm(num_groups=1, num_channels=out_channels), + nn.SiLU(), + nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1), + nn.GroupNorm(num_groups=1, num_channels=out_channels), + nn.SiLU(), + ) + + @jaxtyped(typechecker=beartype) + def forward( + self, x: Float[torch.Tensor, "batch channels time height width"] + ) -> Float[torch.Tensor, "batch out_channels time out_height out_width"]: + """Decode a channel-first spatio-temporal tensor. + + Parameters + ---------- + x : Float[torch.Tensor, "batch channels time height width"] + Input tensor. + + Returns + ------- + Float[torch.Tensor, "batch out_channels time out_height out_width"] + Decoded tensor. + """ + return self.net(x) + + +class Decoder(nn.Module): + """Convolutional decoder for sequence reconstruction. + + Parameters + ---------- + output_channels : int + Number of channels in the reconstructed source data. + hidden_channels : int, optional + Number of channels used near the output side of the decoder. Default is + ``16``. + latent_channels : int, optional + Number of channels in the latent representation. Default is ``32``. + num_blocks : int, optional + Number of spatial upsampling blocks. Default is ``2``. 
+ """ + + def __init__( + self, + output_channels: int, + hidden_channels: int = 16, + latent_channels: int = 32, + num_blocks: int = 2, + ) -> None: + super().__init__() + if num_blocks < 1: + raise ValueError(f"num_blocks ({num_blocks}) must be at least 1.") + + self.output_channels = output_channels + self.hidden_channels = hidden_channels + self.latent_channels = latent_channels + self.num_blocks = num_blocks + + layers: list[nn.Module] = [] + in_channels = latent_channels + for block_idx in range(num_blocks): + is_last = block_idx == num_blocks - 1 + remaining_blocks = num_blocks - block_idx - 2 + out_channels = output_channels if is_last else hidden_channels * 2 ** max(remaining_blocks, 0) + layers.append(DecoderBlock(in_channels=in_channels, out_channels=out_channels, upsample=True)) + in_channels = out_channels + self.blocks = nn.Sequential(*layers) + + @jaxtyped(typechecker=beartype) + def forward( + self, z: Float[torch.Tensor, "batch latent_channels time latent_height latent_width"] + ) -> Float[torch.Tensor, "batch time channels height width"]: + """Decode a latent tensor into a time-first reconstruction tensor. + + Parameters + ---------- + z : Float[torch.Tensor, "batch latent_channels time latent_height latent_width"] + Latent tensor in channel-first 3D-convolution layout. + + Returns + ------- + Float[torch.Tensor, "batch time channels height width"] + Reconstructed sequence in the data-layer tensor layout. + """ + return self.blocks(z).movedim(1, 2) diff --git a/src/mlcast/models/autoencoder/encoder.py b/src/mlcast/models/autoencoder/encoder.py new file mode 100644 index 0000000..f319b05 --- /dev/null +++ b/src/mlcast/models/autoencoder/encoder.py @@ -0,0 +1,115 @@ +"""Encoder blocks for the reconstruction autoencoder.""" + +import torch +import torch.nn as nn +from beartype import beartype +from jaxtyping import Float, jaxtyped + + +class EncoderBlock(nn.Module): + """Spatio-temporal encoder block with optional spatial downsampling. 
+ + Parameters + ---------- + in_channels : int + Number of channels in the input tensor. + out_channels : int + Number of channels produced by the block. + downsample : bool, optional + If ``True``, halve the spatial resolution with a stride-2 convolution. + Default is ``True``. + """ + + def __init__(self, in_channels: int, out_channels: int, downsample: bool = True) -> None: + super().__init__() + spatial_stride = 2 if downsample else 1 + self.net = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size=3, + stride=(1, spatial_stride, spatial_stride), + padding=1, + ), + nn.GroupNorm(num_groups=1, num_channels=out_channels), + nn.SiLU(), + nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1), + nn.GroupNorm(num_groups=1, num_channels=out_channels), + nn.SiLU(), + ) + + @jaxtyped(typechecker=beartype) + def forward( + self, x: Float[torch.Tensor, "batch channels time height width"] + ) -> Float[torch.Tensor, "batch out_channels time out_height out_width"]: + """Encode a channel-first spatio-temporal tensor. + + Parameters + ---------- + x : Float[torch.Tensor, "batch channels time height width"] + Input tensor. + + Returns + ------- + Float[torch.Tensor, "batch out_channels time out_height out_width"] + Encoded tensor. + """ + return self.net(x) + + +class Encoder(nn.Module): + """Convolutional encoder for sequence reconstruction. + + Parameters + ---------- + input_channels : int + Number of channels in the source data. + hidden_channels : int, optional + Number of channels used in the first encoder block. Default is ``16``. + latent_channels : int, optional + Number of channels in the latent representation. Default is ``32``. + num_blocks : int, optional + Number of spatial downsampling blocks. Default is ``2``. 
+ """ + + def __init__( + self, + input_channels: int, + hidden_channels: int = 16, + latent_channels: int = 32, + num_blocks: int = 2, + ) -> None: + super().__init__() + if num_blocks < 1: + raise ValueError(f"num_blocks ({num_blocks}) must be at least 1.") + + self.input_channels = input_channels + self.hidden_channels = hidden_channels + self.latent_channels = latent_channels + self.num_blocks = num_blocks + + layers: list[nn.Module] = [] + in_channels = input_channels + for block_idx in range(num_blocks): + out_channels = latent_channels if block_idx == num_blocks - 1 else hidden_channels * 2**block_idx + layers.append(EncoderBlock(in_channels=in_channels, out_channels=out_channels, downsample=True)) + in_channels = out_channels + self.blocks = nn.Sequential(*layers) + + @jaxtyped(typechecker=beartype) + def forward( + self, x: Float[torch.Tensor, "batch time channels height width"] + ) -> Float[torch.Tensor, "batch latent_channels time latent_height latent_width"]: + """Encode a time-first sequence tensor into a latent tensor. + + Parameters + ---------- + x : Float[torch.Tensor, "batch time channels height width"] + Input sequence in the data-layer tensor layout. + + Returns + ------- + Float[torch.Tensor, "batch latent_channels time latent_height latent_width"] + Latent tensor in channel-first 3D-convolution layout. + """ + return self.blocks(x.movedim(2, 1)) diff --git a/src/mlcast/models/autoencoder/net.py b/src/mlcast/models/autoencoder/net.py new file mode 100644 index 0000000..84d4329 --- /dev/null +++ b/src/mlcast/models/autoencoder/net.py @@ -0,0 +1,80 @@ +"""Autoencoder network for reconstruction pretraining.""" + +import torch +import torch.nn as nn +from beartype import beartype +from jaxtyping import Float, jaxtyped + +from mlcast.models.autoencoder.decoder import Decoder +from mlcast.models.autoencoder.encoder import Encoder + + +class AutoencoderNet(nn.Module): + """Compose an encoder and decoder into a reconstruction network. 
+ + Parameters + ---------- + encoder : Encoder + Encoder module that maps input sequences to latent tensors. + decoder : Decoder + Decoder module that maps latent tensors back to input-space sequences. + """ + + def __init__(self, encoder: Encoder, decoder: Decoder) -> None: + super().__init__() + self.encoder = encoder + self.decoder = decoder + + @jaxtyped(typechecker=beartype) + def encode( + self, x: Float[torch.Tensor, "batch time channels height width"] + ) -> Float[torch.Tensor, "batch latent_channels time latent_height latent_width"]: + """Encode an input sequence into latent space. + + Parameters + ---------- + x : Float[torch.Tensor, "batch time channels height width"] + Input sequence tensor. + + Returns + ------- + Float[torch.Tensor, "batch latent_channels time latent_height latent_width"] + Latent tensor produced by the encoder. + """ + return self.encoder(x) + + @jaxtyped(typechecker=beartype) + def decode( + self, z: Float[torch.Tensor, "batch latent_channels time latent_height latent_width"] + ) -> Float[torch.Tensor, "batch time channels height width"]: + """Decode a latent tensor into input space. + + Parameters + ---------- + z : Float[torch.Tensor, "batch latent_channels time latent_height latent_width"] + Latent tensor produced by the encoder. + + Returns + ------- + Float[torch.Tensor, "batch time channels height width"] + Reconstructed sequence tensor. + """ + return self.decoder(z) + + @jaxtyped(typechecker=beartype) + def forward( + self, x: Float[torch.Tensor, "batch time channels height width"] + ) -> Float[torch.Tensor, "batch time channels height width"]: + """Run an end-to-end reconstruction forward pass. + + Parameters + ---------- + x : Float[torch.Tensor, "batch time channels height width"] + Input sequence tensor. + + Returns + ------- + Float[torch.Tensor, "batch time channels height width"] + Reconstructed sequence tensor. 
+ """ + return self.decode(self.encode(x)) diff --git a/tests/models/test_autoencoder.py b/tests/models/test_autoencoder.py new file mode 100644 index 0000000..dd5de50 --- /dev/null +++ b/tests/models/test_autoencoder.py @@ -0,0 +1,90 @@ +import torch +import torch.nn.functional as F + +from mlcast.models.autoencoder import AutoencoderNet, Decoder, Encoder + + +def test_encoder_output_shape() -> None: + """Encoder should preserve time and downsample spatial dimensions.""" + batch_size = 2 + input_steps = 4 + channels = 1 + height = 16 + width = 16 + latent_channels = 6 + + encoder = Encoder(input_channels=channels, hidden_channels=4, latent_channels=latent_channels, num_blocks=2) + x = torch.randn(batch_size, input_steps, channels, height, width) + + z = encoder(x) + + assert z.shape == (batch_size, latent_channels, input_steps, height // 4, width // 4) + + +def test_decoder_output_shape() -> None: + """Decoder should preserve time and upsample spatial dimensions.""" + batch_size = 2 + input_steps = 4 + channels = 1 + latent_channels = 6 + latent_height = 4 + latent_width = 4 + + decoder = Decoder(output_channels=channels, hidden_channels=4, latent_channels=latent_channels, num_blocks=2) + z = torch.randn(batch_size, latent_channels, input_steps, latent_height, latent_width) + + y = decoder(z) + + assert y.shape == (batch_size, input_steps, channels, latent_height * 4, latent_width * 4) + + +def test_autoencoder_reconstruction_forward_pass() -> None: + """Autoencoder should reconstruct tensors with the same shape as its input.""" + batch_size = 2 + input_steps = 3 + channels = 2 + height = 16 + width = 16 + + encoder = Encoder(input_channels=channels, hidden_channels=4, latent_channels=8, num_blocks=2) + decoder = Decoder(output_channels=channels, hidden_channels=4, latent_channels=8, num_blocks=2) + model = AutoencoderNet(encoder=encoder, decoder=decoder) + x = torch.randn(batch_size, input_steps, channels, height, width) + + y = model(x) + + assert y.shape == 
x.shape + + +def test_autoencoder_improves_reconstruction_loss() -> None: + """Autoencoder should reduce reconstruction loss on a tiny generated dataset.""" + torch.manual_seed(42) + batch_size = 8 + input_steps = 2 + channels = 1 + height = 8 + width = 8 + + encoder = Encoder(input_channels=channels, hidden_channels=4, latent_channels=4, num_blocks=1) + decoder = Decoder(output_channels=channels, hidden_channels=4, latent_channels=4, num_blocks=1) + model = AutoencoderNet(encoder=encoder, decoder=decoder) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-2) + + spatial_pattern = torch.linspace(-1.0, 1.0, height * width).reshape(1, 1, 1, height, width) + temporal_scale = torch.linspace(0.5, 1.5, input_steps).reshape(1, input_steps, 1, 1, 1) + samples = spatial_pattern * temporal_scale + samples = samples.repeat(batch_size, 1, channels, 1, 1) + + with torch.no_grad(): + initial_loss = F.mse_loss(model(samples), samples).item() + + for _ in range(40): + optimizer.zero_grad(set_to_none=True) + loss = F.mse_loss(model(samples), samples) + loss.backward() + optimizer.step() + + with torch.no_grad(): + final_loss = F.mse_loss(model(samples), samples).item() + + assert final_loss < initial_loss From bc415aae30e28d374af913740da027aa241080d3 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 20 May 2026 14:38:54 +0200 Subject: [PATCH 12/34] docs: mark autoencoder plan complete --- ldcast-refactor-plan.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ldcast-refactor-plan.md b/ldcast-refactor-plan.md index 16e432e..aaadae2 100644 --- a/ldcast-refactor-plan.md +++ b/ldcast-refactor-plan.md @@ -32,15 +32,15 @@ 2. Autoencoder model architecture - Autoencoder model split: - - [ ] `src/mlcast/models/autoencoder/encoder.py` for `Encoder` and `EncoderBlock`. - - [ ] `src/mlcast/models/autoencoder/decoder.py` for `Decoder` and `DecoderBlock`. - - [ ] `src/mlcast/models/autoencoder/net.py` for `AutoencoderNet`. 
-- [ ] Use `input_steps` for the stage-1 reconstruction window length; do not introduce names like `autoenc_time_ratio`. + - [x] `src/mlcast/models/autoencoder/encoder.py` for `Encoder` and `EncoderBlock`. + - [x] `src/mlcast/models/autoencoder/decoder.py` for `Decoder` and `DecoderBlock`. + - [x] `src/mlcast/models/autoencoder/net.py` for `AutoencoderNet`. +- [x] Use `input_steps` for the stage-1 reconstruction window length; do not introduce names like `autoenc_time_ratio`. - Autoencoder validation and tests: - - [ ] encoder output shape. - - [ ] decoder output shape. - - [ ] autoencoder reconstruction forward pass. - - [ ] autoencoder improves reconstruction loss on a small generated dataset after a few training steps. + - [x] encoder output shape. + - [x] decoder output shape. + - [x] autoencoder reconstruction forward pass. + - [x] autoencoder improves reconstruction loss on a small generated dataset after a few training steps. 3. Forecasting model contract - [ ] Standardize all forecasting models on init-time `input_steps`, `forecast_steps`, and `ensemble_size`. 
From 10b6b9ae410c1449826d98f1f8e4b7db79cdf50a Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 20 May 2026 14:54:53 +0200 Subject: [PATCH 13/34] refactor: fix forecasting model shape contract --- README.md | 33 ++--- docs/config_diagram.svg | 154 +++++++++++++----------- ldcast-refactor-plan.md | 8 +- src/mlcast/config/base.py | 4 +- src/mlcast/config/consistency_checks.py | 42 ++++++- src/mlcast/models/convgru.py | 47 ++++++-- src/mlcast/nowcasting_module.py | 49 ++------ tests/config/test_consistency_checks.py | 22 +++- tests/models/test_convgru.py | 37 ++++-- tests/test_nowcasting_module.py | 42 +++++++ 10 files changed, 283 insertions(+), 155 deletions(-) create mode 100644 tests/test_nowcasting_module.py diff --git a/README.md b/README.md index 2885cd0..b678209 100644 --- a/README.md +++ b/README.md @@ -187,13 +187,15 @@ from mfai.torch.models import HalfUNet from mlcast.config import convgru_training_experiment, train_from_config from mlcast.config.fiddlers import use_random_sampler -# Minimal adapter: channel-stack past frames → HalfUNet → one step at a time. -# NowcastLightningModule calls network(x, steps=N, ensemble_size=M), so any -# custom network must accept those keyword arguments. +# Minimal adapter: channel-stack past frames -> HalfUNet -> one step at a time. +# The forecasting contract fixes input_steps, forecast_steps, and ensemble_size +# at model initialization; NowcastLightningModule calls network(x). 
class HalfUNetNowcaster(nn.Module): - def __init__(self, input_steps: int = 6, num_vars: int = 1): + def __init__(self, input_steps: int = 6, forecast_steps: int = 12, ensemble_size: int = 1, num_vars: int = 1): super().__init__() self.input_steps = input_steps + self.forecast_steps = forecast_steps + self.ensemble_size = ensemble_size self.num_vars = num_vars self.unet = HalfUNet( input_shape=(256, 256), @@ -213,13 +215,11 @@ class HalfUNetNowcaster(nn.Module): def forward( self, x: Float[torch.Tensor, "batch input_steps in_channels H W"], - steps: int, - ensemble_size: int = 1, - ) -> Float[torch.Tensor, "batch steps out_channels H W"]: + ) -> Float[torch.Tensor, "batch forecast_steps out_channels H W"]: # channel-stack all input frames: (b, t, c, h, w) -> (b, t*c, h, w) x_flat = einops.rearrange(x, "b t c h w -> b (t c) h w") preds = [] - for _ in range(steps): + for _ in range(self.forecast_steps): y = self.unet(x_flat) # [B, num_vars, H, W] preds.append(y.unsqueeze(1)) # slide window: drop the oldest timestep (first num_vars channels), @@ -233,6 +233,8 @@ use_random_sampler(cfg) cfg.pl_module.network = fdl.Config( HalfUNetNowcaster, input_steps=cfg.data.input_steps, + forecast_steps=cfg.data.forecast_steps, + ensemble_size=cfg.pl_module.network.ensemble_size, num_vars=len(cfg.data.sequence_dataset_factory.standard_names), ) @@ -277,8 +279,10 @@ mlcast/ │ │ ├── loader.py # YAML config loader │ │ └── orchestrator.py # train_from_config, config persistence │ ├── data/ -│ │ ├── source_data_datamodule.py # Lightning DataModule -│ │ ├── source_data_datasets.py # Zarr-backed PyTorch datasets +│ │ ├── datamodules.py # Lightning DataModules +│ │ ├── sequence.py # Zarr-backed sequence datasets +│ │ ├── forecasting.py # Forecasting task dataset wrapper +│ │ ├── reconstruction.py # Reconstruction task dataset wrapper │ │ └── normalization.py # Normalisation registry │ └── models/ │ └── convgru.py # ConvGRU encoder-decoder @@ -329,8 +333,9 @@ concatenated along the channel 
dimension. ### Custom network interface Any network architecture can be used by replacing `cfg.pl_module.network` -with a `fdl.Config` node pointing at your class. The only requirement is -that `forward` accepts the following signature: +with a `fdl.Config` node pointing at your class. Forecasting models should set +`input_steps`, `forecast_steps`, and `ensemble_size` during initialization. The +only runtime `forward` requirement is: ```python # from jaxtyping import Float @@ -339,9 +344,7 @@ that `forward` accepts the following signature: def forward( self, x: Float[torch.Tensor, "batch input_steps in_channels H W"], - steps: int, # number of forecast steps to produce - ensemble_size: int, # number of stochastic ensemble members -) -> Float[torch.Tensor, "batch steps out_channels H W"]: +) -> Float[torch.Tensor, "batch forecast_steps out_channels H W"]: ... ``` diff --git a/docs/config_diagram.svg b/docs/config_diagram.svg index 129fec4..9a08bf4 100644 --- a/docs/config_diagram.svg +++ b/docs/config_diagram.svg @@ -12,83 +12,93 @@ 2 - - -Config: - ConvGruModel - - -input_channels - -1 - - -num_blocks - -5 - - -noisy_decoder - -False + + +Config: + ConvGruModel + + +input_steps + +6 + + +forecast_steps + +12 + + +ensemble_size + +2 + + +input_channels + +1 + + +num_blocks + +5 + + +noisy_decoder + +False 1 - - -Config: - NowcastLightningModule - - -network - - - - - -ensemble_size - -2 - - -loss_class - -'crps' - - -loss_params - - - -dict - - -'temporal_lambda' - -0.01 - - -masked_loss - -True - - -optimizer - - - - - -lr_scheduler - - - + + +Config: + NowcastLightningModule + + +network + + + + + +loss_class + +'crps' + + +loss_params + + + +dict + + +'temporal_lambda' + +0.01 + + +masked_loss + +True + + +optimizer + + + + + +lr_scheduler + + + 1:c--2:c - + @@ -111,7 +121,7 @@ 1:c--3:c - + @@ -139,7 +149,7 @@ 1:c--4:c - + @@ -170,7 +180,7 @@ 0:c--1:c - + diff --git a/ldcast-refactor-plan.md b/ldcast-refactor-plan.md index aaadae2..398395b 100644 --- 
a/ldcast-refactor-plan.md +++ b/ldcast-refactor-plan.md @@ -43,10 +43,10 @@ - [x] autoencoder improves reconstruction loss on a small generated dataset after a few training steps. 3. Forecasting model contract -- [ ] Standardize all forecasting models on init-time `input_steps`, `forecast_steps`, and `ensemble_size`. -- [ ] Standardize forecasting model inference on `forward(x)` only; do not pass `forecast_steps` or `ensemble_size` at runtime. -- [ ] Refactor the existing ConvGRU path to follow this fixed-shape contract. -- [ ] Add config consistency checks that dataset `input_steps` and `forecast_steps` match the configured forecasting model. +- [x] Standardize all forecasting models on init-time `input_steps`, `forecast_steps`, and `ensemble_size`. +- [x] Standardize forecasting model inference on `forward(x)` only; do not pass `forecast_steps` or `ensemble_size` at runtime. +- [x] Refactor the existing ConvGRU path to follow this fixed-shape contract. +- [x] Add config consistency checks that dataset `input_steps` and `forecast_steps` match the configured forecasting model. 4. 
Diffusion model architecture - Diffusion model split: diff --git a/src/mlcast/config/base.py b/src/mlcast/config/base.py index 418ac65..da7a4f7 100644 --- a/src/mlcast/config/base.py +++ b/src/mlcast/config/base.py @@ -83,6 +83,9 @@ def convgru_training_experiment() -> Experiment: ) network = ConvGruModel( + input_steps=6, + forecast_steps=12, + ensemble_size=2, input_channels=1, num_blocks=5, noisy_decoder=False, @@ -90,7 +93,6 @@ def convgru_training_experiment() -> Experiment: pl_module = NowcastLightningModule( network=network, - ensemble_size=2, loss_class="crps", loss_params={"temporal_lambda": 0.01}, masked_loss=True, diff --git a/src/mlcast/config/consistency_checks.py b/src/mlcast/config/consistency_checks.py index debbd4d..22363f8 100644 --- a/src/mlcast/config/consistency_checks.py +++ b/src/mlcast/config/consistency_checks.py @@ -31,6 +31,7 @@ def validate_config(cfg: fdl.Config) -> None: sequence_dataset_factory = cfg.data.sequence_dataset_factory network = cfg.pl_module.network pl_module = cfg.pl_module + data = cfg.data # Contract 1: Network input_channels == len(sequence_dataset_factory.standard_names) # If the network does not expose input_channels, emit a warning because @@ -73,16 +74,49 @@ def validate_config(cfg: fdl.Config) -> None: ) # Contract 3: Ensemble models require CRPS or AFCRPS - if pl_module.ensemble_size > 1: + ensemble_size = getattr(network, "ensemble_size", 1) + if ensemble_size > 1: if str(pl_module.loss_class).lower() not in ["crps", "afcrps"]: raise ValueError( - f"Contract 3 violated: Ensemble models (ensemble_size={pl_module.ensemble_size}) " + f"Contract 3 violated: Ensemble models (ensemble_size={ensemble_size}) " f"require 'crps' or 'afcrps' loss, got '{pl_module.loss_class}'." 
) # Contract 4: Forecasting mask return must match model masked_loss - if bool(cfg.data.return_mask) != bool(pl_module.masked_loss): + if bool(data.return_mask) != bool(pl_module.masked_loss): raise ValueError( - f"Contract 4 violated: data.return_mask ({cfg.data.return_mask}) " + f"Contract 4 violated: data.return_mask ({data.return_mask}) " f"must match pl_module.masked_loss ({pl_module.masked_loss})." ) + + # Contract 5: Dataset input_steps must match model input_steps + try: + net_input_steps = network.input_steps + except AttributeError: + logger.warning( + "Warning: can't ensure network input_steps matches data.input_steps, " + "because network {} doesn't expose 'input_steps'.", + network.__class__.__name__, + ) + net_input_steps = None + if net_input_steps is not None and net_input_steps != data.input_steps: + raise ValueError( + f"Contract 5 violated: network input_steps ({net_input_steps}) " + f"must equal data.input_steps ({data.input_steps})." + ) + + # Contract 6: Dataset forecast_steps must match model forecast_steps + try: + net_forecast_steps = network.forecast_steps + except AttributeError: + logger.warning( + "Warning: can't ensure network forecast_steps matches data.forecast_steps, " + "because network {} doesn't expose 'forecast_steps'.", + network.__class__.__name__, + ) + net_forecast_steps = None + if net_forecast_steps is not None and net_forecast_steps != data.forecast_steps: + raise ValueError( + f"Contract 6 violated: network forecast_steps ({net_forecast_steps}) " + f"must equal data.forecast_steps ({data.forecast_steps})." + ) diff --git a/src/mlcast/models/convgru.py b/src/mlcast/models/convgru.py index 2b0c3dc..da54bd5 100644 --- a/src/mlcast/models/convgru.py +++ b/src/mlcast/models/convgru.py @@ -195,6 +195,12 @@ class Encoder(nn.Module): Parameters ---------- + input_steps : int + Number of timesteps the model expects as input. + forecast_steps : int + Number of timesteps the model forecasts. 
+ ensemble_size : int, optional + Number of ensemble members produced by the model. Default is ``1``. input_channels : int, optional Number of input channels. Default is ``1``. num_blocks : int, optional @@ -350,8 +356,27 @@ class ConvGruModel(nn.Module): :class:`Decoder`. """ - def __init__(self, input_channels: int = 1, num_blocks: int = 4, noisy_decoder: bool = False, **kwargs): + def __init__( + self, + input_steps: int, + forecast_steps: int, + ensemble_size: int = 1, + input_channels: int = 1, + num_blocks: int = 4, + noisy_decoder: bool = False, + **kwargs, + ): super().__init__() + if input_steps < 1: + raise ValueError(f"input_steps ({input_steps}) must be at least 1.") + if forecast_steps < 1: + raise ValueError(f"forecast_steps ({forecast_steps}) must be at least 1.") + if ensemble_size < 1: + raise ValueError(f"ensemble_size ({ensemble_size}) must be at least 1.") + + self.input_steps = input_steps + self.forecast_steps = forecast_steps + self.ensemble_size = ensemble_size self.input_channels = input_channels self.num_blocks = num_blocks self.noisy_decoder = noisy_decoder @@ -360,25 +385,23 @@ def __init__(self, input_channels: int = 1, num_blocks: int = 4, noisy_decoder: @jaxtyped(typechecker=beartype) def forward( - self, x: Float[torch.Tensor, "batch time channels height width"], steps: int, ensemble_size: int = 1 - ) -> Float[torch.Tensor, "batch steps _ height width"]: + self, x: Float[torch.Tensor, "batch time channels height width"] + ) -> Float[torch.Tensor, "batch forecast_steps _ height width"]: """Forward the encoder-decoder model. Parameters ---------- x : Float[torch.Tensor, "batch time channels height width"] Input sequence. - steps : int - Number of future timesteps to forecast. - ensemble_size : int, optional - Number of ensemble members to generate. When ``> 1``, the decoder - is always run with noisy inputs. Default is ``1``. 
Returns ------- - preds : Float[torch.Tensor, "batch steps out_channels height width"] + preds : Float[torch.Tensor, "batch forecast_steps out_channels height width"] Forecast tensor. """ + if x.shape[1] != self.input_steps: + raise ValueError(f"Expected {self.input_steps} input timesteps, got {x.shape[1]}.") + _, _, _, H, W = x.shape divisor = 2**self.num_blocks pad_h = (divisor - (H % divisor)) % divisor @@ -390,13 +413,13 @@ def forward( encoded = self.encoder(x) x_dec_shape = list(encoded[-1].shape) - x_dec_shape[1] = steps + x_dec_shape[1] = self.forecast_steps last_hidden_per_block = [e[:, -1] for e in reversed(encoded)] - if ensemble_size > 1: + if self.ensemble_size > 1: preds = [] - for _ in range(ensemble_size): + for _ in range(self.ensemble_size): x_dec = torch.randn(x_dec_shape, dtype=encoded[-1].dtype, device=encoded[-1].device) decoded = self.decoder(x_dec, last_hidden_per_block) preds.append(decoded) diff --git a/src/mlcast/nowcasting_module.py b/src/mlcast/nowcasting_module.py index 3ef5e19..faa6e90 100644 --- a/src/mlcast/nowcasting_module.py +++ b/src/mlcast/nowcasting_module.py @@ -2,7 +2,7 @@ Wraps an injected PyTorch :class:`nn.Module` (the network architecture) and handles training, validation, and test steps including loss computation, -ensemble generation, and image logging. +image logging, and optimizer configuration. """ from collections.abc import Callable @@ -30,8 +30,6 @@ class NowcastLightningModule(pl.LightningModule): ---------- network : torch.nn.Module The PyTorch network architecture to train. - ensemble_size : int, optional - Number of ensemble members to generate. Default is ``1``. loss_class : type[torch.nn.Module] or str, optional Loss function class or its string name. Default is ``"mse"``. 
loss_params : dict or None, optional @@ -49,7 +47,6 @@ class NowcastLightningModule(pl.LightningModule): def __init__( self, network: torch.nn.Module, - ensemble_size: int = 1, loss_class: type[torch.nn.Module] | str = "mse", loss_params: dict[str, Any] | None = None, masked_loss: bool = False, @@ -58,7 +55,7 @@ def __init__( ) -> None: super().__init__() # Explicitly save hyperparameters that are accessed later via self.hparams - self.save_hyperparameters("ensemble_size", "loss_class", "loss_params", "masked_loss") + self.save_hyperparameters("loss_class", "loss_params", "masked_loss") self.network = network self.optimizer_factory = optimizer @@ -75,8 +72,6 @@ def __init__( def forward( self, x: Float[torch.Tensor, "batch time channels height width"], - forecast_steps: int, - ensemble_size: int | None = None, ) -> Float[torch.Tensor, "batch forecast_steps out_channels height width"]: """Run the network forward pass. @@ -84,22 +79,15 @@ def forward( ---------- x : Float[torch.Tensor, "batch time channels height width"] Input tensor. - forecast_steps : int - Number of steps to forecast. - ensemble_size : int or None, optional - Number of ensemble members to generate. If ``None``, uses the initialized value. Default is ``None``. Returns ------- preds : Float[torch.Tensor, "batch forecast_steps out_channels height width"] Forecast tensor. """ - ensemble_size = self.hparams["ensemble_size"] if ensemble_size is None else ensemble_size - return self.network(x, steps=forecast_steps, ensemble_size=ensemble_size) + return self.network(x) - def shared_step( - self, batch: dict[str, torch.Tensor], split: str = "train", ensemble_size: int | None = None - ) -> torch.Tensor: + def shared_step(self, batch: dict[str, torch.Tensor], split: str = "train") -> torch.Tensor: """Shared forward step for training, validation, and testing. Parameters @@ -110,10 +98,6 @@ def shared_step( split : str, optional The data split being processed (e.g., ``"train"``, ``"val"``, ``"test"``). 
Used for logging. Default is ``"train"``. - ensemble_size : int or None, optional - The number of ensemble members to generate. If ``None``, uses the - default from hyper-parameters. Default is ``None``. - Returns ------- loss : torch.Tensor @@ -121,9 +105,8 @@ def shared_step( """ past = batch["input"] future = batch["target"] - forecast_steps = future.shape[1] - preds = self(past, forecast_steps=forecast_steps, ensemble_size=ensemble_size).clamp(min=-1, max=1) + preds = self(past).clamp(min=-1, max=1) if self.hparams["masked_loss"]: mask = batch["target_mask"] @@ -139,7 +122,8 @@ def shared_step( self.log(f"{split}_loss", loss, prog_bar=True, on_epoch=True, on_step=(split == "train"), sync_dist=True) - if self.hparams["ensemble_size"] > 1: + ensemble_size = getattr(self.network, "ensemble_size", 1) + if ensemble_size > 1: ensemble_std = preds.std(dim=2).mean() self.log(f"{split}_ensemble_std", ensemble_std, on_epoch=True, sync_dist=True) @@ -157,7 +141,7 @@ def shared_step( preds=preds, logger=self.logger, # type: ignore global_step=self.global_step, - ensemble_size=self.hparams["ensemble_size"], + ensemble_size=ensemble_size, split=split, ) return loss @@ -194,7 +178,7 @@ def validation_step(self, batch: dict[str, torch.Tensor], _batch_idx: int) -> to loss : torch.Tensor The validation loss. """ - return self.shared_step(batch, split="val", ensemble_size=10) + return self.shared_step(batch, split="val") def test_step(self, batch: dict[str, torch.Tensor], _batch_idx: int) -> torch.Tensor: """Execute a single test step. @@ -211,7 +195,7 @@ def test_step(self, batch: dict[str, torch.Tensor], _batch_idx: int) -> torch.Te loss : torch.Tensor The test loss. """ - return self.shared_step(batch, split="test", ensemble_size=10) + return self.shared_step(batch, split="test") def configure_optimizers(self) -> Any: """Configure the optimizer and optional learning rate scheduler. 
@@ -260,8 +244,6 @@ def from_checkpoint(cls, checkpoint_path: str, device: str = "cpu") -> "NowcastL def predict( self, past: torch.Tensor, - forecast_steps: int = 1, - ensemble_size: int | None = 1, standard_name: str = "rainfall_rate", ) -> np.ndarray[Any, Any]: """Generate precipitation forecasts from past radar observations. @@ -272,10 +254,6 @@ def predict( ---------- past : torch.Tensor Past radar frames as unnormalized values (e.g., mm/h or kg m-2 s-1), of shape ``(T, H, W)``. - forecast_steps : int, optional - Number of future timesteps to forecast. Default is ``1``. - ensemble_size : int, optional - Number of ensemble members. Default is ``1``. standard_name : str, optional The CF standard name defining the input/output domain for normalization lookup. Default is ``"rainfall_rate"``. @@ -284,13 +262,12 @@ def predict( ------- preds : np.ndarray Forecasted unnormalized values, of shape - ``(ensemble_size, forecast_steps, H, W)``. + ``(ensemble_size, forecast_steps, H, W)``. The ensemble size and + forecast horizon are determined by the configured network. """ if len(past.shape) != 3: raise ValueError("Input must be of shape (T, H, W)") - ensemble_size = self.hparams["ensemble_size"] if ensemble_size is None else ensemble_size - past_clean = np.nan_to_num(past.cpu().numpy()) past_clean = past_clean[np.newaxis, :, np.newaxis, ...] 
@@ -302,7 +279,7 @@ def predict( self.eval() with torch.no_grad(): - preds_tensor = self.network(x, steps=forecast_steps, ensemble_size=ensemble_size) + preds_tensor = self.network(x) preds_np: np.ndarray[Any, Any] = preds_tensor.cpu().numpy() diff --git a/tests/config/test_consistency_checks.py b/tests/config/test_consistency_checks.py index 97a7f81..56bfbe3 100644 --- a/tests/config/test_consistency_checks.py +++ b/tests/config/test_consistency_checks.py @@ -53,7 +53,7 @@ def test_contract_3_probabilistic_loss() -> None: """Verify Contract 3: Ensemble models require CRPS or AFCRPS.""" cfg = convgru_training_experiment.as_buildable() # Break Contract 3 - cfg.pl_module.ensemble_size = 5 + cfg.pl_module.network.ensemble_size = 5 cfg.pl_module.loss_class = "mse" with pytest.raises(ValueError, match="Contract 3 violated:"): @@ -71,6 +71,26 @@ def test_contract_4_masking_sync() -> None: validate_config(cfg) +def test_contract_5_input_steps_sync() -> None: + """Verify Contract 5: data input_steps must match model input_steps.""" + cfg = convgru_training_experiment.as_buildable() + cfg.data.input_steps = 4 + cfg.pl_module.network.input_steps = 6 + + with pytest.raises(ValueError, match="Contract 5 violated:"): + validate_config(cfg) + + +def test_contract_6_forecast_steps_sync() -> None: + """Verify Contract 6: data forecast_steps must match model forecast_steps.""" + cfg = convgru_training_experiment.as_buildable() + cfg.data.forecast_steps = 10 + cfg.pl_module.network.forecast_steps = 12 + + with pytest.raises(ValueError, match="Contract 6 violated:"): + validate_config(cfg) + + def test_dataset_sequence_steps_guard() -> None: """Verify that sequence dataset raises ValueError when sequence_steps=0.""" with pytest.raises(ValueError, match="sequence_steps"): diff --git a/tests/models/test_convgru.py b/tests/models/test_convgru.py index a751320..b296543 100644 --- a/tests/models/test_convgru.py +++ b/tests/models/test_convgru.py @@ -3,7 +3,7 @@ from mlcast.models.convgru 
import ConvGruModel -def test_convgru_dynamic_padding(): +def test_convgru_dynamic_padding() -> None: """Verify that ConvGruModel dynamically pads non-power-of-2 inputs and crops the output.""" # Given an input with awkward spatial dimensions batch_size = 2 @@ -18,20 +18,19 @@ def test_convgru_dynamic_padding(): x = torch.randn(batch_size, time_steps, channels, height, width) - model = ConvGruModel(input_channels=channels, num_blocks=4) - model.eval() - forecast_steps = 4 + model = ConvGruModel(input_steps=time_steps, forecast_steps=forecast_steps, input_channels=channels, num_blocks=4) + model.eval() with torch.no_grad(): - preds = model(x, steps=forecast_steps, ensemble_size=1) + preds = model(x) # Check that it didn't crash and the output shape is exactly (batch, steps, channels, height, width) # The single ensemble member case returns out_channels = channels. assert preds.shape == (batch_size, forecast_steps, channels, height, width) -def test_convgru_dynamic_padding_ensemble(): +def test_convgru_dynamic_padding_ensemble() -> None: """Verify that ConvGruModel dynamically pads non-power-of-2 inputs and crops the output for ensemble generation.""" # Given an input with awkward spatial dimensions batch_size = 1 @@ -42,17 +41,35 @@ def test_convgru_dynamic_padding_ensemble(): x = torch.randn(batch_size, time_steps, channels, height, width) - model = ConvGruModel(input_channels=channels, num_blocks=3) - model.eval() - forecast_steps = 2 ensemble_size = 5 + model = ConvGruModel( + input_steps=time_steps, + forecast_steps=forecast_steps, + ensemble_size=ensemble_size, + input_channels=channels, + num_blocks=3, + ) + model.eval() with torch.no_grad(): - preds = model(x, steps=forecast_steps, ensemble_size=ensemble_size) + preds = model(x) # Check that it didn't crash and the output shape is exactly (batch, steps, ensemble_size * channels, height, width) # Actually wait: The decoder block outputs the same number of channels as the final upsampling step. 
# In the `ConvGruModel.forward` with `ensemble_size > 1`, `out` is `torch.cat(preds, dim=2)`. # Let's verify the exact channel dimension. The original output channels per ensemble member is `channels`. assert preds.shape == (batch_size, forecast_steps, channels * ensemble_size, height, width) + + +def test_convgru_rejects_wrong_input_steps() -> None: + """ConvGruModel should reject inputs that violate its configured input length.""" + model = ConvGruModel(input_steps=3, forecast_steps=2, input_channels=1, num_blocks=2) + x = torch.randn(1, 2, 1, 32, 32) + + try: + model(x) + except ValueError as exc: + assert "Expected 3 input timesteps" in str(exc) + else: + raise AssertionError("Expected ConvGruModel to reject wrong input_steps") diff --git a/tests/test_nowcasting_module.py b/tests/test_nowcasting_module.py new file mode 100644 index 0000000..d859fb4 --- /dev/null +++ b/tests/test_nowcasting_module.py @@ -0,0 +1,42 @@ +import numpy as np +import torch + +from mlcast.nowcasting_module import NowcastLightningModule + + +class DummyForecastNetwork(torch.nn.Module): + """Minimal fixed-shape forecasting network for module tests.""" + + def __init__(self, input_steps: int, forecast_steps: int, ensemble_size: int = 1) -> None: + super().__init__() + self.input_steps = input_steps + self.forecast_steps = forecast_steps + self.ensemble_size = ensemble_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size, _, channels, height, width = x.shape + out_channels = channels * self.ensemble_size + return torch.zeros(batch_size, self.forecast_steps, out_channels, height, width, device=x.device) + + +def test_nowcasting_module_forward_uses_network_shape_contract() -> None: + """NowcastLightningModule should call fixed-shape forecasting networks as network(x).""" + network = DummyForecastNetwork(input_steps=3, forecast_steps=5, ensemble_size=2) + module = NowcastLightningModule(network=network, loss_class="crps") + x = torch.randn(4, 3, 1, 8, 8) + + preds = 
module(x) + + assert preds.shape == (4, 5, 2, 8, 8) + + +def test_nowcasting_module_predict_uses_configured_output_shape() -> None: + """Prediction horizon and ensemble size should come from the configured network.""" + network = DummyForecastNetwork(input_steps=3, forecast_steps=4, ensemble_size=2) + module = NowcastLightningModule(network=network, loss_class="crps") + past = torch.ones(3, 8, 8) + + preds = module.predict(past, standard_name="rainfall_rate") + + assert isinstance(preds, np.ndarray) + assert preds.shape == (2, 4, 8, 8) From 4b28741e18919486e0e014bfd895fd58969703c0 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 20 May 2026 15:33:33 +0200 Subject: [PATCH 14/34] docs: clarify deterministic halfunet example --- README.md | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index b678209..63720ec 100644 --- a/README.md +++ b/README.md @@ -189,13 +189,13 @@ from mlcast.config.fiddlers import use_random_sampler # Minimal adapter: channel-stack past frames -> HalfUNet -> one step at a time. # The forecasting contract fixes input_steps, forecast_steps, and ensemble_size -# at model initialization; NowcastLightningModule calls network(x). +# at model initialization; this minimal deterministic adapter exposes one +# ensemble member and NowcastLightningModule calls network(x). 
class HalfUNetNowcaster(nn.Module): - def __init__(self, input_steps: int = 6, forecast_steps: int = 12, ensemble_size: int = 1, num_vars: int = 1): + def __init__(self, input_steps: int = 6, forecast_steps: int = 12, num_vars: int = 1): super().__init__() self.input_steps = input_steps self.forecast_steps = forecast_steps - self.ensemble_size = ensemble_size self.num_vars = num_vars self.unet = HalfUNet( input_shape=(256, 256), @@ -204,6 +204,10 @@ class HalfUNetNowcaster(nn.Module): settings=fdl.Config(HalfUNet.settings_kls), ) + @property + def ensemble_size(self) -> int: + return 1 + @property def input_channels(self) -> int: # Externally, the HalfUNetNowcaster respects the required input shape structure @@ -234,9 +238,12 @@ cfg.pl_module.network = fdl.Config( HalfUNetNowcaster, input_steps=cfg.data.input_steps, forecast_steps=cfg.data.forecast_steps, - ensemble_size=cfg.pl_module.network.ensemble_size, num_vars=len(cfg.data.sequence_dataset_factory.standard_names), ) +# The base ConvGRU config uses CRPS for ensemble forecasts; this adapter is +# deterministic and exposes only one member, so use a deterministic loss. 
+cfg.pl_module.loss_class = "mse" +cfg.pl_module.loss_params = None train_from_config(cfg) ``` From 36bb8dcd53f21571da4f45f5a6d30698983b6668 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 3 Jun 2026 16:01:03 +0200 Subject: [PATCH 15/34] feat: add latent diffusion architecture --- ldcast-refactor-plan.md | 20 +- src/mlcast/models/diffusion/__init__.py | 17 ++ src/mlcast/models/diffusion/conditioner.py | 86 +++++++ src/mlcast/models/diffusion/denoiser.py | 250 +++++++++++++++++++++ src/mlcast/models/diffusion/ema.py | 48 ++++ src/mlcast/models/diffusion/loss.py | 48 ++++ src/mlcast/models/diffusion/net.py | 102 +++++++++ src/mlcast/models/diffusion/sampler.py | 60 +++++ src/mlcast/models/diffusion/scheduler.py | 71 ++++++ tests/models/test_diffusion.py | 74 ++++++ 10 files changed, 766 insertions(+), 10 deletions(-) create mode 100644 src/mlcast/models/diffusion/__init__.py create mode 100644 src/mlcast/models/diffusion/conditioner.py create mode 100644 src/mlcast/models/diffusion/denoiser.py create mode 100644 src/mlcast/models/diffusion/ema.py create mode 100644 src/mlcast/models/diffusion/loss.py create mode 100644 src/mlcast/models/diffusion/net.py create mode 100644 src/mlcast/models/diffusion/sampler.py create mode 100644 src/mlcast/models/diffusion/scheduler.py create mode 100644 tests/models/test_diffusion.py diff --git a/ldcast-refactor-plan.md b/ldcast-refactor-plan.md index 398395b..f1b125f 100644 --- a/ldcast-refactor-plan.md +++ b/ldcast-refactor-plan.md @@ -50,28 +50,28 @@ 4. Diffusion model architecture - Diffusion model split: - - [ ] `src/mlcast/models/diffusion/conditioner.py` for latent conditioning blocks and `ConditionerNet`. - - [ ] `src/mlcast/models/diffusion/denoiser.py` for `DenoiserUNet` and timestep-aware helpers. - - [ ] `src/mlcast/models/diffusion/net.py` for `LatentDiffusionNet`. 
- - [ ] `src/mlcast/models/diffusion/forecasting.py` for `LatentDiffusionForecaster`, the diffusion-specific adapter configured with fixed `input_steps`, `forecast_steps`, and `ensemble_size` and exposing `forward(x)`. - - [ ] `src/mlcast/models/diffusion/scheduler.py`, `ema.py`, `sampler.py`, `loss.py` for diffusion support code. + - [x] `src/mlcast/models/diffusion/conditioner.py` for latent conditioning blocks and `ConditionerNet`. + - [x] `src/mlcast/models/diffusion/denoiser.py` for `DenoiserUNet` and timestep-aware helpers. + - [x] `src/mlcast/models/diffusion/net.py` for `LatentDiffusionNet`. + - [x] `src/mlcast/models/diffusion/scheduler.py`, `ema.py`, `sampler.py`, `loss.py` for diffusion support code. - Validation and tests: - - [ ] diffusion forecasting adapter API. - - [ ] diffusion model improves loss on a small generated latent dataset after a few training steps. + - [x] latent diffusion model API. + - [x] diffusion model improves loss on a small generated latent dataset after a few training steps. 5. Task wrappers - [ ] Add `src/mlcast/modules/forecasting.py` and rename `NowcastLightningModule` to `ForecastingModule`. - [ ] Remove runtime `forecast_steps` and `ensemble_size` arguments from the forecasting Lightning module and its `predict()` API. - [ ] Add `src/mlcast/modules/reconstruction.py` with a generic `ReconstructionModule` for any reconstruction model. +- [ ] Add a latent diffusion Lightning module that owns the trained autoencoder, trains diffusion in latent space, and handles decoded forecast inference. - [ ] Keep `modules/` for training/task wrappers only; keep `models/` for pure architectures. 6. Training experiment - [ ] Add a new LDCast-specific training module containing `LDCastTrainingExperiment`. - [x] Keep `convgru_training_experiment` as the existing ConvGRU forecasting example and one of the explicitly selected included CLI configs. 
- [ ] Stage 1 builds the reconstruction dataset, autoencoder model, and reconstruction module, then trains the autoencoder. -- [ ] Stage 2 reuses the same trained in-memory encoder instance, builds the diffusion dataset/model/module, then trains latent diffusion. -- [ ] The shared Fiddle graph should define the encoder once and reference the same object in both stages, but no unresolved Fiddle objects should flow into actual `torch.nn.Module.__init__` calls. -- [ ] The decoder is stage-1 only and is not shared into stage 2. +- [ ] Stage 2 reuses the same trained in-memory autoencoder instance, builds the diffusion dataset/model/module, then trains latent diffusion. +- [ ] The shared Fiddle graph should define the autoencoder once and reference the same object in both stages, but no unresolved Fiddle objects should flow into actual `torch.nn.Module.__init__` calls. +- [ ] Stage-2 diffusion training uses the trained encoder to produce input and target latents; the trained decoder is retained for final forecast decoding but is not used in the stage-2 diffusion loss. - [ ] Reuse the same forecasting dataset abstraction in stage 2; do not add a separate latent dataset layer. - [ ] Add tests for shared object identity and stage sequencing. 
diff --git a/src/mlcast/models/diffusion/__init__.py b/src/mlcast/models/diffusion/__init__.py new file mode 100644 index 0000000..d237b22 --- /dev/null +++ b/src/mlcast/models/diffusion/__init__.py @@ -0,0 +1,17 @@ +"""Latent diffusion architecture components.""" + +from .conditioner import ConditionerBlock, ConditionerNet +from .denoiser import DenoiserUNet, TimestepEmbedding +from .loss import DiffusionLoss +from .net import LatentDiffusionNet +from .scheduler import DiffusionScheduler + +__all__ = [ + "ConditionerBlock", + "ConditionerNet", + "DenoiserUNet", + "DiffusionLoss", + "DiffusionScheduler", + "LatentDiffusionNet", + "TimestepEmbedding", +] diff --git a/src/mlcast/models/diffusion/conditioner.py b/src/mlcast/models/diffusion/conditioner.py new file mode 100644 index 0000000..c09e7d3 --- /dev/null +++ b/src/mlcast/models/diffusion/conditioner.py @@ -0,0 +1,86 @@ +"""Latent conditioning blocks for diffusion forecasting.""" + +import torch +import torch.nn as nn +from beartype import beartype +from jaxtyping import Float, jaxtyped + + +class ConditionerBlock(nn.Module): + """Residual 3D-convolution block for latent conditioning. + + Parameters + ---------- + channels : int + Number of latent conditioning channels. + """ + + def __init__(self, channels: int) -> None: + super().__init__() + self.net = nn.Sequential( + nn.Conv3d(channels, channels, kernel_size=3, padding=1), + nn.GroupNorm(num_groups=1, num_channels=channels), + nn.SiLU(), + nn.Conv3d(channels, channels, kernel_size=3, padding=1), + ) + + @jaxtyped(typechecker=beartype) + def forward( + self, x: Float[torch.Tensor, "batch channels time height width"] + ) -> Float[torch.Tensor, "batch channels time height width"]: + """Apply residual conditioning refinement. + + Parameters + ---------- + x : Float[torch.Tensor, "batch channels time height width"] + Latent conditioning tensor. + + Returns + ------- + Float[torch.Tensor, "batch channels time height width"] + Refined conditioning tensor. 
+ """ + return x + self.net(x) + + +class ConditionerNet(nn.Module): + """Condition latent target denoising on encoded input history. + + Parameters + ---------- + latent_channels : int + Number of latent channels in the encoded input history. + hidden_channels : int, optional + Number of channels emitted as conditioning context. Default is ``32``. + num_blocks : int, optional + Number of residual conditioning blocks. Default is ``2``. + """ + + def __init__(self, latent_channels: int, hidden_channels: int = 32, num_blocks: int = 2) -> None: + super().__init__() + if num_blocks < 1: + raise ValueError(f"num_blocks ({num_blocks}) must be at least 1.") + + self.latent_channels = latent_channels + self.hidden_channels = hidden_channels + self.num_blocks = num_blocks + self.input_projection = nn.Conv3d(latent_channels, hidden_channels, kernel_size=1) + self.blocks = nn.Sequential(*(ConditionerBlock(hidden_channels) for _ in range(num_blocks))) + + @jaxtyped(typechecker=beartype) + def forward( + self, z: Float[torch.Tensor, "batch latent_channels input_time height width"] + ) -> Float[torch.Tensor, "batch hidden_channels input_time height width"]: + """Build conditioning context from input-history latents. + + Parameters + ---------- + z : Float[torch.Tensor, "batch latent_channels input_time height width"] + Encoded input-history latent tensor. + + Returns + ------- + Float[torch.Tensor, "batch hidden_channels input_time height width"] + Conditioning context for the denoiser. 
+ """ + return self.blocks(self.input_projection(z)) diff --git a/src/mlcast/models/diffusion/denoiser.py b/src/mlcast/models/diffusion/denoiser.py new file mode 100644 index 0000000..c107027 --- /dev/null +++ b/src/mlcast/models/diffusion/denoiser.py @@ -0,0 +1,250 @@ +"""Timestep-aware denoising network for latent diffusion.""" + +import math + +import torch +import torch.nn as nn +from beartype import beartype +from jaxtyping import Float, jaxtyped + + +class TimestepEmbedding(nn.Module): + """Sinusoidal timestep embedding followed by a small MLP. + + Parameters + ---------- + embedding_dim : int + Number of channels in the generated timestep embedding. + """ + + def __init__(self, embedding_dim: int) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.projection = nn.Sequential( + nn.Linear(embedding_dim, embedding_dim), + nn.SiLU(), + nn.Linear(embedding_dim, embedding_dim), + ) + + @jaxtyped(typechecker=beartype) + def forward(self, timesteps: torch.Tensor) -> Float[torch.Tensor, "batch embedding_dim"]: + """Embed integer diffusion timesteps. + + Parameters + ---------- + timesteps : torch.Tensor + Integer diffusion timesteps. + + Returns + ------- + Float[torch.Tensor, "batch embedding_dim"] + Projected sinusoidal timestep embeddings. + """ + half_dim = self.embedding_dim // 2 + frequencies = torch.exp( + torch.arange(half_dim, device=timesteps.device, dtype=torch.float32) + * -(math.log(10_000.0) / max(half_dim - 1, 1)) + ) + args = timesteps.float()[:, None] * frequencies[None] + embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1) + if embedding.shape[-1] < self.embedding_dim: + embedding = torch.nn.functional.pad(embedding, (0, 1)) + return self.projection(embedding) + + +class _DenoiserBlock(nn.Module): + """Internal residual denoising block. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. 
+ timestep_channels : int + Number of channels in the timestep embedding. + """ + + def __init__(self, in_channels: int, out_channels: int, timestep_channels: int) -> None: + super().__init__() + self.timestep_projection = nn.Linear(timestep_channels, out_channels) + self.net = nn.Sequential( + nn.GroupNorm(num_groups=1, num_channels=in_channels), + nn.SiLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1), + nn.GroupNorm(num_groups=1, num_channels=out_channels), + nn.SiLU(), + nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1), + ) + self.skip_connection = nn.Identity() if in_channels == out_channels else nn.Conv3d(in_channels, out_channels, 1) + + def forward(self, x: torch.Tensor, timestep_embedding: torch.Tensor) -> torch.Tensor: + """Apply timestep-conditioned residual denoising. + + Parameters + ---------- + x : torch.Tensor + Hidden denoising tensor. + timestep_embedding : torch.Tensor + Timestep embedding for each batch item. + + Returns + ------- + torch.Tensor + Updated hidden tensor. + """ + timestep_bias = self.timestep_projection(timestep_embedding)[:, :, None, None, None] + h = self.net(x) + return self.skip_connection(x) + h + timestep_bias + + +class _SpatialDownsample(nn.Module): + """Halve latent spatial resolution while preserving time.""" + + def __init__(self, channels: int) -> None: + super().__init__() + self.op = nn.Conv3d(channels, channels, kernel_size=3, stride=(1, 2, 2), padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Downsample the spatial dimensions of a latent tensor. + + Parameters + ---------- + x : torch.Tensor + Channel-first latent tensor. + + Returns + ------- + torch.Tensor + Tensor with half spatial resolution. 
+ """ + return self.op(x) + + +class _SpatialUpsample(nn.Module): + """Double latent spatial resolution while preserving time.""" + + def __init__(self, channels: int) -> None: + super().__init__() + self.op = nn.ConvTranspose3d( + channels, + channels, + kernel_size=3, + stride=(1, 2, 2), + padding=1, + output_padding=(0, 1, 1), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Upsample the spatial dimensions of a latent tensor. + + Parameters + ---------- + x : torch.Tensor + Channel-first latent tensor. + + Returns + ------- + torch.Tensor + Tensor with doubled spatial resolution. + """ + return self.op(x) + + +class DenoiserUNet(nn.Module): + """Compact timestep-aware U-Net denoiser for latent tensors. + + This is a real U-Net-style architecture: it builds a spatial downsampling + path, applies a bottleneck at the lowest spatial resolution, upsamples back + to the original latent resolution, and concatenates matching-resolution + skip connections from the down path into the up path. It differs from a + plain image U-Net because it operates on 3D latent tensors and only changes + spatial resolution; the temporal dimension is preserved throughout. Each + residual block also receives a diffusion timestep embedding. + + Parameters + ---------- + latent_channels : int + Number of channels in the noisy target latent. + condition_channels : int + Number of channels emitted by the conditioner. + hidden_channels : int, optional + Number of hidden channels in the denoiser. Default is ``32``. + num_blocks : int, optional + Number of U-Net resolution levels. Default is ``2``. 
+ """ + + def __init__( + self, + latent_channels: int, + condition_channels: int, + hidden_channels: int = 32, + num_blocks: int = 2, + ) -> None: + super().__init__() + if num_blocks < 1: + raise ValueError(f"num_blocks ({num_blocks}) must be at least 1.") + + self.latent_channels = latent_channels + self.condition_channels = condition_channels + self.hidden_channels = hidden_channels + self.num_blocks = num_blocks + self.timestep_embedding = TimestepEmbedding(hidden_channels) + self.input_projection = nn.Conv3d(latent_channels + condition_channels, hidden_channels, kernel_size=1) + self.down_blocks = nn.ModuleList( + _DenoiserBlock(hidden_channels, hidden_channels, hidden_channels) for _ in range(num_blocks) + ) + self.downsamples = nn.ModuleList(_SpatialDownsample(hidden_channels) for _ in range(num_blocks - 1)) + self.bottleneck = _DenoiserBlock(hidden_channels, hidden_channels, hidden_channels) + self.upsamples = nn.ModuleList(_SpatialUpsample(hidden_channels) for _ in range(num_blocks - 1)) + self.up_blocks = nn.ModuleList( + _DenoiserBlock(hidden_channels * 2, hidden_channels, hidden_channels) for _ in range(num_blocks - 1) + ) + self.output_projection = nn.Conv3d(hidden_channels, latent_channels, kernel_size=1) + + @jaxtyped(typechecker=beartype) + def forward( + self, + noisy: Float[torch.Tensor, "batch latent_channels forecast_time height width"], + timesteps: torch.Tensor, + context: Float[torch.Tensor, "batch condition_channels input_time height width"], + ) -> Float[torch.Tensor, "batch latent_channels forecast_time height width"]: + """Predict noise in a noised latent target. + + Parameters + ---------- + noisy : Float[torch.Tensor, "batch latent_channels forecast_time height width"] + Noised target latent. + timesteps : torch.Tensor + Diffusion timestep for each sample. + context : Float[torch.Tensor, "batch condition_channels input_time height width"] + Conditioning context from the input-history latent. 
+ + Returns + ------- + Float[torch.Tensor, "batch latent_channels forecast_time height width"] + Predicted noise tensor. + """ + if context.shape[2] != noisy.shape[2]: + context = torch.nn.functional.interpolate(context, size=noisy.shape[2:], mode="nearest") + + x = self.input_projection(torch.cat([noisy, context], dim=1)) + timestep_embedding = self.timestep_embedding(timesteps) + + skips: list[torch.Tensor] = [] + for block_idx, block in enumerate(self.down_blocks): + x = block(x, timestep_embedding) + if block_idx < len(self.downsamples): + skips.append(x) + x = self.downsamples[block_idx](x) + + x = self.bottleneck(x, timestep_embedding) + + for upsample, block in zip(self.upsamples, self.up_blocks, strict=True): + x = upsample(x) + skip = skips.pop() + if x.shape[-2:] != skip.shape[-2:]: + x = torch.nn.functional.interpolate(x, size=skip.shape[2:], mode="nearest") + x = block(torch.cat([x, skip], dim=1), timestep_embedding) + + return self.output_projection(x) diff --git a/src/mlcast/models/diffusion/ema.py b/src/mlcast/models/diffusion/ema.py new file mode 100644 index 0000000..570599c --- /dev/null +++ b/src/mlcast/models/diffusion/ema.py @@ -0,0 +1,48 @@ +"""Exponential moving average helpers for diffusion weights.""" + +import torch +import torch.nn as nn + + +class ExponentialMovingAverage: + """Track an exponential moving average of trainable module parameters. + + Parameters + ---------- + module : nn.Module + Module whose parameters should be tracked. + decay : float, optional + EMA decay factor. Default is ``0.999``. 
+ """ + + def __init__(self, module: nn.Module, decay: float = 0.999) -> None: + if not 0.0 <= decay < 1.0: + raise ValueError(f"decay ({decay}) must be in [0, 1).") + self.module = module + self.decay = decay + self.shadow_params = [ + parameter.detach().clone() for parameter in module.parameters() if parameter.requires_grad + ] + self.backup_params: list[torch.Tensor] | None = None + + def update(self) -> None: + """Update EMA shadow parameters from the current module parameters.""" + trainable_params = [parameter for parameter in self.module.parameters() if parameter.requires_grad] + for shadow_param, parameter in zip(self.shadow_params, trainable_params, strict=True): + shadow_param.mul_(self.decay).add_(parameter.detach(), alpha=1.0 - self.decay) + + def apply(self) -> None: + """Swap EMA parameters into the tracked module.""" + trainable_params = [parameter for parameter in self.module.parameters() if parameter.requires_grad] + self.backup_params = [parameter.detach().clone() for parameter in trainable_params] + for parameter, shadow_param in zip(trainable_params, self.shadow_params, strict=True): + parameter.data.copy_(shadow_param.data) + + def restore(self) -> None: + """Restore module parameters saved before :meth:`apply`.""" + if self.backup_params is None: + raise RuntimeError("EMA restore() called before apply().") + trainable_params = [parameter for parameter in self.module.parameters() if parameter.requires_grad] + for parameter, backup_param in zip(trainable_params, self.backup_params, strict=True): + parameter.data.copy_(backup_param.data) + self.backup_params = None diff --git a/src/mlcast/models/diffusion/loss.py b/src/mlcast/models/diffusion/loss.py new file mode 100644 index 0000000..6c1f252 --- /dev/null +++ b/src/mlcast/models/diffusion/loss.py @@ -0,0 +1,48 @@ +"""Loss helpers for latent diffusion training.""" + +import torch +import torch.nn as nn +from beartype import beartype +from jaxtyping import Float, jaxtyped + +from 
mlcast.models.diffusion.net import LatentDiffusionNet + + +class DiffusionLoss(nn.Module): + """Noise-prediction MSE loss for latent diffusion. + + Parameters + ---------- + net : LatentDiffusionNet + Diffusion network used to sample noised latents and predict noise. + """ + + def __init__(self, net: LatentDiffusionNet) -> None: + super().__init__() + self.net = net + + @jaxtyped(typechecker=beartype) + def forward( + self, + input_latents: Float[torch.Tensor, "batch latent_channels input_time height width"], + target_latents: Float[torch.Tensor, "batch latent_channels forecast_time height width"], + ) -> torch.Tensor: + """Compute a random-timestep noise-prediction loss. + + Parameters + ---------- + input_latents : Float[torch.Tensor, "batch latent_channels input_time height width"] + Encoded input-history latents used as conditioning. + target_latents : Float[torch.Tensor, "batch latent_channels forecast_time height width"] + Clean target latents to diffuse. + + Returns + ------- + torch.Tensor + Scalar mean squared error between predicted and sampled noise. 
+ """ + timesteps = torch.randint(0, self.net.num_timesteps, (target_latents.shape[0],), device=target_latents.device) + noise = torch.randn_like(target_latents) + noised_target = self.net.q_sample(target_latents, timesteps=timesteps, noise=noise) + predicted_noise = self.net(noised_target, timesteps, input_latents) + return torch.nn.functional.mse_loss(predicted_noise, noise) diff --git a/src/mlcast/models/diffusion/net.py b/src/mlcast/models/diffusion/net.py new file mode 100644 index 0000000..59ff087 --- /dev/null +++ b/src/mlcast/models/diffusion/net.py @@ -0,0 +1,102 @@ +"""Latent diffusion network composed from conditioner and denoiser modules.""" + +import torch +import torch.nn as nn +from beartype import beartype +from jaxtyping import Float, jaxtyped + +from mlcast.models.diffusion.conditioner import ConditionerNet +from mlcast.models.diffusion.denoiser import DenoiserUNet +from mlcast.models.diffusion.scheduler import DiffusionScheduler, extract_schedule_value + + +class LatentDiffusionNet(nn.Module): + """Trainable latent diffusion denoising network. + + Parameters + ---------- + conditioner : ConditionerNet + Network that builds context from input-history latents. It must accept + ``Float[Tensor, "batch latent_channels input_time height width"]`` and + return ``Float[Tensor, "batch condition_channels input_time height width"]``. + denoiser : DenoiserUNet + Network that predicts noise from noised target latents. It must accept + ``noisy`` with shape + ``Float[Tensor, "batch latent_channels forecast_time height width"]``, + ``timesteps`` with shape ``(batch,)``, and ``context`` with shape + ``Float[Tensor, "batch condition_channels input_time height width"]``; + it must return + ``Float[Tensor, "batch latent_channels forecast_time height width"]``. + scheduler : DiffusionScheduler + Diffusion noise scheduler. 
Calling ``scheduler.buffers(device, dtype)`` + must return one-dimensional tensors of length ``scheduler.timesteps`` + for ``sqrt_alphas_cumprod`` and ``sqrt_one_minus_alphas_cumprod`` so + they can be gathered with timestep indices shaped ``(batch,)`` and + broadcast over latent tensors shaped + ``(batch, latent_channels, forecast_time, height, width)``. + """ + + def __init__(self, conditioner: ConditionerNet, denoiser: DenoiserUNet, scheduler: DiffusionScheduler) -> None: + super().__init__() + self.conditioner = conditioner + self.denoiser = denoiser + self.scheduler = scheduler + self.num_timesteps = scheduler.timesteps + for name, value in scheduler.buffers(device=torch.device("cpu")).items(): + self.register_buffer(name, value) + + @jaxtyped(typechecker=beartype) + def q_sample( + self, + x0: Float[torch.Tensor, "batch channels time height width"], + timesteps: torch.Tensor, + noise: Float[torch.Tensor, "batch channels time height width"] | None = None, + ) -> Float[torch.Tensor, "batch channels time height width"]: + """Diffuse clean latents to a chosen timestep. + + Parameters + ---------- + x0 : Float[torch.Tensor, "batch channels time height width"] + Clean target latent. + timesteps : torch.Tensor + Diffusion timestep for each sample. + noise : Float[torch.Tensor, "batch channels time height width"] or None, optional + Noise to add. If ``None``, standard Gaussian noise is sampled. + Default is ``None``. + + Returns + ------- + Float[torch.Tensor, "batch channels time height width"] + Noised target latent. 
+ """ + if noise is None: + noise = torch.randn_like(x0) + sqrt_alpha = extract_schedule_value(self.sqrt_alphas_cumprod, timesteps, x0.shape) + sqrt_one_minus_alpha = extract_schedule_value(self.sqrt_one_minus_alphas_cumprod, timesteps, x0.shape) + return sqrt_alpha * x0 + sqrt_one_minus_alpha * noise + + @jaxtyped(typechecker=beartype) + def forward( + self, + noised_target: Float[torch.Tensor, "batch latent_channels forecast_time height width"], + timesteps: torch.Tensor, + input_latents: Float[torch.Tensor, "batch latent_channels input_time height width"], + ) -> Float[torch.Tensor, "batch latent_channels forecast_time height width"]: + """Predict noise from a noised target latent and input context. + + Parameters + ---------- + noised_target : Float[torch.Tensor, "batch latent_channels forecast_time height width"] + Noised target latent. + timesteps : torch.Tensor + Diffusion timestep for each sample. + input_latents : Float[torch.Tensor, "batch latent_channels input_time height width"] + Encoded input-history latents. + + Returns + ------- + Float[torch.Tensor, "batch latent_channels forecast_time height width"] + Predicted noise. + """ + context = self.conditioner(input_latents) + return self.denoiser(noised_target, timesteps, context=context) diff --git a/src/mlcast/models/diffusion/sampler.py b/src/mlcast/models/diffusion/sampler.py new file mode 100644 index 0000000..53d5159 --- /dev/null +++ b/src/mlcast/models/diffusion/sampler.py @@ -0,0 +1,60 @@ +"""Simple deterministic DDIM-style sampler for latent diffusion models.""" + +import torch +from beartype import beartype +from jaxtyping import Float, jaxtyped + +from mlcast.models.diffusion.net import LatentDiffusionNet +from mlcast.models.diffusion.scheduler import extract_schedule_value + + +class DiffusionSampler: + """Generate latent samples with a compact deterministic (DDIM-style, eta=0) reverse process. + + Parameters + ---------- + net : LatentDiffusionNet + Trained diffusion network. 
+ """ + + def __init__(self, net: LatentDiffusionNet) -> None: + self.net = net + + @jaxtyped(typechecker=beartype) + def sample( + self, + input_latents: Float[torch.Tensor, "batch latent_channels input_time height width"], + output_shape: tuple[int, int, int, int, int], + ) -> Float[torch.Tensor, "batch latent_channels forecast_time height width"]: + """Sample forecast latents conditioned on input latents. + + Parameters + ---------- + input_latents : Float[torch.Tensor, "batch latent_channels input_time height width"] + Encoded input-history latents. + output_shape : tuple of int + Shape of the forecast latent to sample, ordered as + ``(batch, channels, forecast_time, height, width)``. + + Returns + ------- + Float[torch.Tensor, "batch latent_channels forecast_time height width"] + Sampled forecast latent. + """ + x = torch.randn(output_shape, device=input_latents.device, dtype=input_latents.dtype) + for step in reversed(range(self.net.num_timesteps)): + timesteps = torch.full((output_shape[0],), step, device=input_latents.device, dtype=torch.long) + predicted_noise = self.net(x, timesteps, input_latents) + sqrt_alpha = extract_schedule_value(self.net.sqrt_alphas_cumprod, timesteps, x.shape) + sqrt_one_minus_alpha = extract_schedule_value(self.net.sqrt_one_minus_alphas_cumprod, timesteps, x.shape) + x0 = (x - sqrt_one_minus_alpha * predicted_noise) / sqrt_alpha.clamp_min(1e-6) + if step > 0: + prev_timesteps = timesteps - 1 + prev_sqrt_alpha = extract_schedule_value(self.net.sqrt_alphas_cumprod, prev_timesteps, x.shape) + prev_sqrt_one_minus_alpha = extract_schedule_value( + self.net.sqrt_one_minus_alphas_cumprod, prev_timesteps, x.shape + ) + x = prev_sqrt_alpha * x0 + prev_sqrt_one_minus_alpha * predicted_noise + else: + x = x0 + return x diff --git a/src/mlcast/models/diffusion/scheduler.py b/src/mlcast/models/diffusion/scheduler.py new file mode 100644 index 0000000..ceaa80a --- /dev/null +++ b/src/mlcast/models/diffusion/scheduler.py @@ -0,0 +1,71 @@ 
+"""Diffusion noise schedules.""" + +import torch + + +class DiffusionScheduler: + """Linear-beta diffusion scheduler. + + Parameters + ---------- + timesteps : int, optional + Number of diffusion timesteps. Default is ``100``. + beta_start : float, optional + Initial beta value. Default is ``1e-4``. + beta_end : float, optional + Final beta value. Default is ``2e-2``. + """ + + def __init__(self, timesteps: int = 100, beta_start: float = 1e-4, beta_end: float = 2e-2) -> None: + if timesteps < 1: + raise ValueError(f"timesteps ({timesteps}) must be at least 1.") + self.timesteps = timesteps + self.beta_start = beta_start + self.beta_end = beta_end + + def buffers(self, device: torch.device, dtype: torch.dtype = torch.float32) -> dict[str, torch.Tensor]: + """Build schedule tensors for registration as module buffers. + + Parameters + ---------- + device : torch.device + Device on which buffers should be allocated. + dtype : torch.dtype, optional + Floating-point dtype for schedule tensors. Default is + ``torch.float32``. + + Returns + ------- + dict of str to torch.Tensor + Schedule tensors used for forward and reverse diffusion. + """ + betas = torch.linspace(self.beta_start, self.beta_end, self.timesteps, device=device, dtype=dtype) + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + return { + "betas": betas, + "alphas": alphas, + "alphas_cumprod": alphas_cumprod, + "sqrt_alphas_cumprod": torch.sqrt(alphas_cumprod), + "sqrt_one_minus_alphas_cumprod": torch.sqrt(1.0 - alphas_cumprod), + } + + +def extract_schedule_value(values: torch.Tensor, timesteps: torch.Tensor, target_shape: torch.Size) -> torch.Tensor: + """Gather schedule values and reshape them for broadcasting. + + Parameters + ---------- + values : Float[torch.Tensor, "timesteps"] + One-dimensional schedule tensor. + timesteps : Int[torch.Tensor, "batch"] + Timestep index for each batch item. 
+ target_shape : torch.Size + Shape of the target tensor the values should broadcast against. + + Returns + ------- + torch.Tensor + Gathered values reshaped to ``(batch, 1, ..., 1)``. + """ + return values.gather(0, timesteps).reshape(timesteps.shape[0], *([1] * (len(target_shape) - 1))) diff --git a/tests/models/test_diffusion.py b/tests/models/test_diffusion.py new file mode 100644 index 0000000..ac9258f --- /dev/null +++ b/tests/models/test_diffusion.py @@ -0,0 +1,74 @@ +import torch + +from mlcast.models.diffusion import ( + ConditionerNet, + DenoiserUNet, + DiffusionLoss, + DiffusionScheduler, + LatentDiffusionNet, +) + + +def _build_diffusion_net(latent_channels: int = 1, hidden_channels: int = 8, timesteps: int = 4) -> LatentDiffusionNet: + conditioner = ConditionerNet(latent_channels=latent_channels, hidden_channels=hidden_channels, num_blocks=1) + denoiser = DenoiserUNet( + latent_channels=latent_channels, + condition_channels=hidden_channels, + hidden_channels=hidden_channels, + num_blocks=1, + ) + scheduler = DiffusionScheduler(timesteps=timesteps) + return LatentDiffusionNet(conditioner=conditioner, denoiser=denoiser, scheduler=scheduler) + + +def test_latent_diffusion_net_api() -> None: + """LatentDiffusionNet should predict noise with the target latent shape.""" + input_time = 2 + forecast_steps = 3 + latent_channels = 1 + height = 4 + width = 4 + diffusion_net = _build_diffusion_net(latent_channels=latent_channels, hidden_channels=4, timesteps=2) + noised_target = torch.randn(2, latent_channels, forecast_steps, height, width) + input_latents = torch.randn(2, latent_channels, input_time, height, width) + timesteps = torch.zeros(2, dtype=torch.long) + + with torch.no_grad(): + predicted_noise = diffusion_net(noised_target, timesteps, input_latents) + + assert predicted_noise.shape == noised_target.shape + + +def test_diffusion_model_improves_loss_on_generated_latents() -> None: + """Diffusion model should reduce noise-prediction loss on generated 
latents.""" + torch.manual_seed(7) + batch_size = 8 + latent_channels = 1 + input_time = 2 + forecast_time = 3 + height = 4 + width = 4 + diffusion_net = _build_diffusion_net(latent_channels=latent_channels, hidden_channels=8, timesteps=1) + loss_fn = DiffusionLoss(diffusion_net) + optimizer = torch.optim.Adam(diffusion_net.parameters(), lr=5e-3) + + input_latents = torch.randn(batch_size, latent_channels, input_time, height, width) + target_base = input_latents.mean(dim=2, keepdim=True) + target_latents = target_base.repeat(1, 1, forecast_time, 1, 1) + target_latents = target_latents + 0.05 * torch.randn_like(target_latents) + + torch.manual_seed(42) + with torch.no_grad(): + initial_loss = loss_fn(input_latents, target_latents).item() + + for _ in range(80): + optimizer.zero_grad(set_to_none=True) + loss = loss_fn(input_latents, target_latents) + loss.backward() + optimizer.step() + + torch.manual_seed(42) + with torch.no_grad(): + final_loss = loss_fn(input_latents, target_latents).item() + + assert final_loss < initial_loss From a4701c37b3785d352663f505ce958eb95a5a665c Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 3 Jun 2026 23:04:30 +0200 Subject: [PATCH 16/34] refactor: split forecasting task modules --- README.md | 25 +- docs/config_diagram.svg | 4 +- ldcast-refactor-plan.md | 19 +- src/mlcast/config/base.py | 4 +- src/mlcast/modules/__init__.py | 20 ++ src/mlcast/modules/forecasting.py | 518 +++++++++++++++++++++++++++ src/mlcast/modules/reconstruction.py | 160 +++++++++ src/mlcast/nowcasting_module.py | 294 +-------------- tests/test_nowcasting_module.py | 8 +- tests/test_task_modules.py | 98 +++++ 10 files changed, 844 insertions(+), 306 deletions(-) create mode 100644 src/mlcast/modules/__init__.py create mode 100644 src/mlcast/modules/forecasting.py create mode 100644 src/mlcast/modules/reconstruction.py create mode 100644 tests/test_task_modules.py diff --git a/README.md b/README.md index 63720ec..30e1950 100644 --- a/README.md +++ b/README.md 
@@ -86,6 +86,29 @@ Any combination of these can be layered on top of the selected config, and the fully resolved config is always saved to YAML alongside the training logs so runs can be reproduced exactly. +### Design roles + +mlcast separates pure architectures from task-level training wrappers. + +- `src/mlcast/models/` + Pure `torch.nn.Module` architectures and supporting components. These classes + define tensor transformations and reusable building blocks, but they do not + decide how training is run or which parameters are optimized. +- `src/mlcast/modules/` + Task-level Lightning modules. These classes define what batch structure a + task consumes, what loss is computed, which parameters are optimized, and how + inference/prediction is exposed. + +In other words, architectures answer "how does this tensor get transformed?", +while task modules answer "what is being trained, against what target, and over +which parameters?" + +This distinction matters especially for latent diffusion. The diffusion +architecture itself lives under `models/`, while the corresponding task module +owns the trained autoencoder reuse policy, decides that only diffusion-network +parameters are optimized, computes diffusion loss in latent space, and handles +decoded forecast inference. + The diagram below shows the full included ConvGRU config graph as built by [`convgru_training_experiment`](src/mlcast/config/base.py): @@ -190,7 +213,7 @@ from mlcast.config.fiddlers import use_random_sampler # Minimal adapter: channel-stack past frames -> HalfUNet -> one step at a time. # The forecasting contract fixes input_steps, forecast_steps, and ensemble_size # at model initialization; this minimal deterministic adapter exposes one -# ensemble member and NowcastLightningModule calls network(x). +# ensemble member and ForecastingTaskModule calls network(x). 
class HalfUNetNowcaster(nn.Module): def __init__(self, input_steps: int = 6, forecast_steps: int = 12, num_vars: int = 1): super().__init__() diff --git a/docs/config_diagram.svg b/docs/config_diagram.svg index 9a08bf4..6cc3ba8 100644 --- a/docs/config_diagram.svg +++ b/docs/config_diagram.svg @@ -52,8 +52,8 @@ 1 -Config: - NowcastLightningModule +Config: + ForecastingTaskModule network diff --git a/ldcast-refactor-plan.md b/ldcast-refactor-plan.md index f1b125f..83c8562 100644 --- a/ldcast-refactor-plan.md +++ b/ldcast-refactor-plan.md @@ -58,18 +58,21 @@ - [x] latent diffusion model API. - [x] diffusion model improves loss on a small generated latent dataset after a few training steps. -5. Task wrappers -- [ ] Add `src/mlcast/modules/forecasting.py` and rename `NowcastLightningModule` to `ForecastingModule`. -- [ ] Remove runtime `forecast_steps` and `ensemble_size` arguments from the forecasting Lightning module and its `predict()` API. -- [ ] Add `src/mlcast/modules/reconstruction.py` with a generic `ReconstructionModule` for any reconstruction model. -- [ ] Add a latent diffusion Lightning module that owns the trained autoencoder, trains diffusion in latent space, and handles decoded forecast inference. -- [ ] Keep `modules/` for training/task wrappers only; keep `models/` for pure architectures. +5. Task modules (Lightning modules) +- [x] Add `src/mlcast/modules/forecasting.py`, introduce `BaseForecastingTaskModule`, and rename `NowcastLightningModule` to `ForecastingTaskModule`. +- [x] `BaseForecastingTaskModule` should own optimizer/scheduler plumbing, while each concrete task module defines which parameters are trainable. +- [x] `ForecastingTaskModule` should optimize the forecasting network parameters. +- [x] Remove runtime `forecast_steps` and `ensemble_size` arguments from the forecasting task module and its `predict()` API. +- [x] Add `src/mlcast/modules/reconstruction.py` with a generic `ReconstructionTaskModule` for any reconstruction model. 
+- [x] Add a `LatentDiffusionTaskModule` that owns the trained autoencoder, optimizes only the diffusion-network parameters, trains diffusion in latent space, and handles decoded forecast inference. +- [x] Keep `modules/` for task-level Lightning modules only; keep `models/` for pure architectures. 6. Training experiment - [ ] Add a new LDCast-specific training module containing `LDCastTrainingExperiment`. - [x] Keep `convgru_training_experiment` as the existing ConvGRU forecasting example and one of the explicitly selected included CLI configs. -- [ ] Stage 1 builds the reconstruction dataset, autoencoder model, and reconstruction module, then trains the autoencoder. -- [ ] Stage 2 reuses the same trained in-memory autoencoder instance, builds the diffusion dataset/model/module, then trains latent diffusion. +- [ ] Stage 1 builds the reconstruction dataset, autoencoder model, and `ReconstructionTaskModule`, then trains the autoencoder. +- [ ] Stage 2 reuses the same trained in-memory autoencoder instance, builds the diffusion dataset/model/`LatentDiffusionTaskModule`, then trains latent diffusion. +- [ ] Stage 2 freezes the reused autoencoder parameters and optimizes only the latent diffusion task module's diffusion-network parameters. - [ ] The shared Fiddle graph should define the autoencoder once and reference the same object in both stages, but no unresolved Fiddle objects should flow into actual `torch.nn.Module.__init__` calls. - [ ] Stage-2 diffusion training uses the trained encoder to produce input and target latents; the trained decoder is retained for final forecast decoding but is not used in the stage-2 diffusion loss. - [ ] Reuse the same forecasting dataset abstraction in stage 2; do not add a separate latent dataset layer. 
diff --git a/src/mlcast/config/base.py b/src/mlcast/config/base.py index da7a4f7..311f4a5 100644 --- a/src/mlcast/config/base.py +++ b/src/mlcast/config/base.py @@ -32,7 +32,7 @@ from ..data.datamodules import ForecastingDataModule from ..data.sequence import SourceDataPrecomputedSequenceDataset from ..models.convgru import ConvGruModel -from ..nowcasting_module import NowcastLightningModule +from ..modules.forecasting import ForecastingTaskModule @dataclass @@ -91,7 +91,7 @@ def convgru_training_experiment() -> Experiment: noisy_decoder=False, ) - pl_module = NowcastLightningModule( + pl_module = ForecastingTaskModule( network=network, loss_class="crps", loss_params={"temporal_lambda": 0.01}, diff --git a/src/mlcast/modules/__init__.py b/src/mlcast/modules/__init__.py new file mode 100644 index 0000000..ce40eaf --- /dev/null +++ b/src/mlcast/modules/__init__.py @@ -0,0 +1,20 @@ +"""Training and task-level Lightning module wrappers.""" + +from .forecasting import ( + BaseForecastingTaskModule, + ForecastingModule, + ForecastingTaskModule, + LatentDiffusionModule, + LatentDiffusionTaskModule, +) +from .reconstruction import ReconstructionModule, ReconstructionTaskModule + +__all__ = [ + "BaseForecastingTaskModule", + "ForecastingModule", + "ForecastingTaskModule", + "LatentDiffusionModule", + "LatentDiffusionTaskModule", + "ReconstructionModule", + "ReconstructionTaskModule", +] diff --git a/src/mlcast/modules/forecasting.py b/src/mlcast/modules/forecasting.py new file mode 100644 index 0000000..3d6b636 --- /dev/null +++ b/src/mlcast/modules/forecasting.py @@ -0,0 +1,518 @@ +"""Forecasting task-level Lightning module wrappers.""" + +from abc import abstractmethod +from collections.abc import Callable +from typing import Any + +import numpy as np +import pytorch_lightning as pl +import torch +from beartype import beartype +from jaxtyping import Float, jaxtyped + +from mlcast.data.normalization import DENORMALIZATION_REGISTRY, NORMALIZATION_REGISTRY +from 
mlcast.losses import build_loss +from mlcast.models.autoencoder import AutoencoderNet +from mlcast.models.diffusion.ema import ExponentialMovingAverage +from mlcast.models.diffusion.loss import DiffusionLoss +from mlcast.models.diffusion.net import LatentDiffusionNet +from mlcast.models.diffusion.sampler import DiffusionSampler +from mlcast.visualization import log_images + + +class BaseForecastingTaskModule(pl.LightningModule): + """Base Lightning module for forecasting-shaped tasks. + + Parameters + ---------- + optimizer : Callable[..., torch.optim.Optimizer] or None, optional + Optimizer factory. Default is ``None`` (Adam). + lr_scheduler : Callable[..., torch.optim.lr_scheduler.LRScheduler] or None, optional + Learning-rate scheduler factory. Default is ``None``. + """ + + def __init__( + self, + optimizer: Callable[..., torch.optim.Optimizer] | None = None, + lr_scheduler: Callable[..., torch.optim.lr_scheduler.LRScheduler] | None = None, + ) -> None: + super().__init__() + self.optimizer_factory = optimizer + self.lr_scheduler_factory = lr_scheduler + + @property + @abstractmethod + def trainable_parameters(self) -> list[torch.nn.Parameter]: + """Return the parameters optimized for this forecasting task. + + Returns + ------- + list of torch.nn.Parameter + Trainable parameters owned by the concrete forecasting task. + """ + + @abstractmethod + def compute_loss(self, batch: dict[str, torch.Tensor], split: str = "train") -> torch.Tensor: + """Compute and log loss for one forecasting batch. + + Parameters + ---------- + batch : dict of str to torch.Tensor + Forecasting batch. + split : str, optional + Current data split. Default is ``"train"``. + + Returns + ------- + torch.Tensor + Scalar task loss. 
+ """ + + @jaxtyped(typechecker=beartype) + def predict_normalized( + self, + x: Float[torch.Tensor, "batch time channels height width"], + ) -> Float[torch.Tensor, "batch forecast_steps out_channels height width"]: + """Predict normalized forecasts from normalized inputs. + + Parameters + ---------- + x : Float[torch.Tensor, "batch time channels height width"] + Normalized forecasting input. + + Returns + ------- + Float[torch.Tensor, "batch forecast_steps out_channels height width"] + Normalized forecast tensor. + """ + return self(x) + + def training_step(self, batch: dict[str, torch.Tensor], _batch_idx: int) -> torch.Tensor: + """Execute a training step. + + Parameters + ---------- + batch : dict of str to torch.Tensor + Forecasting batch. + _batch_idx : int + Batch index supplied by Lightning. + + Returns + ------- + torch.Tensor + Training loss. + """ + return self.compute_loss(batch, split="train") + + def validation_step(self, batch: dict[str, torch.Tensor], _batch_idx: int) -> torch.Tensor: + """Execute a validation step. + + Parameters + ---------- + batch : dict of str to torch.Tensor + Forecasting batch. + _batch_idx : int + Batch index supplied by Lightning. + + Returns + ------- + torch.Tensor + Validation loss. + """ + return self.compute_loss(batch, split="val") + + def test_step(self, batch: dict[str, torch.Tensor], _batch_idx: int) -> torch.Tensor: + """Execute a test step. + + Parameters + ---------- + batch : dict of str to torch.Tensor + Forecasting batch. + _batch_idx : int + Batch index supplied by Lightning. + + Returns + ------- + torch.Tensor + Test loss. + """ + return self.compute_loss(batch, split="test") + + def configure_optimizers(self) -> Any: + """Configure optimizer and optional scheduler. + + Returns + ------- + Any + PyTorch Lightning optimizer configuration. 
+ """ + parameters = self.trainable_parameters + if self.optimizer_factory is not None: + optimizer = self.optimizer_factory(parameters) + else: + optimizer = torch.optim.Adam(parameters) + + if self.lr_scheduler_factory is not None: + lr_scheduler = self.lr_scheduler_factory(optimizer) + return {"optimizer": optimizer, "lr_scheduler": {"scheduler": lr_scheduler, "monitor": "val_loss"}} + return {"optimizer": optimizer} + + def predict(self, past: torch.Tensor, standard_name: str = "rainfall_rate") -> np.ndarray[Any, Any]: + """Generate unnormalized forecasts from unnormalized past observations. + + Parameters + ---------- + past : torch.Tensor + Past observations with shape ``(T, H, W)``. + standard_name : str, optional + CF standard name that selects normalization and denormalization + functions. Default is ``"rainfall_rate"``. + + Returns + ------- + np.ndarray + Forecast array shaped ``(ensemble_size, forecast_steps, H, W)`` for + single-channel outputs. + """ + if len(past.shape) != 3: + raise ValueError("Input must be of shape (T, H, W)") + + past_clean = np.nan_to_num(past.cpu().numpy()) + past_clean = past_clean[np.newaxis, :, np.newaxis, ...] + norm_func = NORMALIZATION_REGISTRY[standard_name] + norm_past = norm_func(past_clean) + + x = torch.from_numpy(norm_past).to(self.device) + self.eval() + with torch.no_grad(): + preds_tensor = self.predict_normalized(x) + + preds_np: np.ndarray[Any, Any] = preds_tensor.cpu().numpy() + denorm_func = DENORMALIZATION_REGISTRY[standard_name] + preds_np = denorm_func(preds_np) + preds_np = preds_np.squeeze(0) + preds_np = np.swapaxes(preds_np, 0, 1) + return preds_np + + +class ForecastingTaskModule(BaseForecastingTaskModule): + """Generic PyTorch Lightning module for direct forecasting tasks. + + Parameters + ---------- + network : torch.nn.Module + Forecasting network to train. + loss_class : type[torch.nn.Module] or str, optional + Loss function class or registry name. Default is ``"mse"``. 
+ loss_params : dict or None, optional + Keyword arguments for the loss constructor. Default is ``None``. + masked_loss : bool, optional + Whether to use masked-loss computation with ``target_mask`` from the + batch. Default is ``False``. + optimizer : Callable[..., torch.optim.Optimizer] or None, optional + Optimizer factory. Default is ``None`` (Adam). + lr_scheduler : Callable[..., torch.optim.lr_scheduler.LRScheduler] or None, optional + Learning-rate scheduler factory. Default is ``None``. + """ + + def __init__( + self, + network: torch.nn.Module, + loss_class: type[torch.nn.Module] | str = "mse", + loss_params: dict[str, Any] | None = None, + masked_loss: bool = False, + optimizer: Callable[..., torch.optim.Optimizer] | None = None, + lr_scheduler: Callable[..., torch.optim.lr_scheduler.LRScheduler] | None = None, + ) -> None: + super().__init__(optimizer=optimizer, lr_scheduler=lr_scheduler) + self.save_hyperparameters("loss_class", "loss_params", "masked_loss") + self.network = network + self.criterion = build_loss(loss_class=loss_class, loss_params=loss_params, masked_loss=masked_loss) + self.log_images_iterations = [50, 100, 200, 500, 750, 1000, 2000, 5000] + + @jaxtyped(typechecker=beartype) + def forward( + self, + x: Float[torch.Tensor, "batch time channels height width"], + ) -> Float[torch.Tensor, "batch forecast_steps out_channels height width"]: + """Run the forecasting network. + + Parameters + ---------- + x : Float[torch.Tensor, "batch time channels height width"] + Normalized input history tensor. + + Returns + ------- + Float[torch.Tensor, "batch forecast_steps out_channels height width"] + Normalized forecast tensor. + """ + return self.network(x) + + def compute_loss(self, batch: dict[str, torch.Tensor], split: str = "train") -> torch.Tensor: + """Compute and log forecasting loss for one batch. 
+ + Parameters + ---------- + batch : dict of str to torch.Tensor + Forecasting batch containing ``input`` and ``target`` tensors, and + optionally ``target_mask`` when masked loss is enabled. + split : str, optional + Current data split. Default is ``"train"``. + + Returns + ------- + torch.Tensor + Scalar loss tensor. + """ + past = batch["input"] + future = batch["target"] + preds = self(past).clamp(min=-1, max=1) + + if self.hparams["masked_loss"]: + mask = batch["target_mask"] + loss = self.criterion(preds, future, mask) + else: + loss = self.criterion(preds, future) + + if isinstance(loss, tuple): + loss, log_dict = loss + self.log_dict( + log_dict, prog_bar=False, logger=True, on_step=(split == "train"), on_epoch=True, sync_dist=True + ) + + self.log(f"{split}_loss", loss, prog_bar=True, on_epoch=True, on_step=(split == "train"), sync_dist=True) + + ensemble_size = getattr(self.network, "ensemble_size", 1) + if ensemble_size > 1: + ensemble_std = preds.std(dim=2).mean() + self.log(f"{split}_ensemble_std", ensemble_std, on_epoch=True, sync_dist=True) + + if ( + split == "train" + and self.logger is not None + and getattr(self.logger, "experiment", None) is not None + and ( + self.global_step in self.log_images_iterations or self.global_step % self.log_images_iterations[-1] == 0 + ) + ): + log_images( + past=past, + future=future, + preds=preds, + logger=self.logger, # type: ignore[arg-type] + global_step=self.global_step, + ensemble_size=ensemble_size, + split=split, + ) + return loss + + @property + def trainable_parameters(self) -> list[torch.nn.Parameter]: + """Return the forecasting network parameters. + + Returns + ------- + list of torch.nn.Parameter + Parameters optimized for direct forecasting. + """ + return list(self.network.parameters()) + + @classmethod + def from_checkpoint(cls, checkpoint_path: str, device: str = "cpu") -> "ForecastingTaskModule": + """Load a forecasting task module from checkpoint. 
+ + Parameters + ---------- + checkpoint_path : str + Path to the saved Lightning checkpoint. + device : str, optional + Device to map parameters onto. Default is ``"cpu"``. + + Returns + ------- + ForecastingTaskModule + Loaded forecasting task module. + """ + return cls.load_from_checkpoint( + checkpoint_path, + map_location=torch.device(device), + strict=True, + weights_only=False, + ) + + +ForecastingModule = ForecastingTaskModule + + +class LatentDiffusionTaskModule(BaseForecastingTaskModule): + """Train latent diffusion in latent space and decode forecasts for inference. + + Parameters + ---------- + autoencoder : AutoencoderNet + Trained autoencoder reused from stage 1. The encoder is used during + stage-2 training to map forecasting inputs and targets into latent + space. The decoder is retained for forecast inference but is not used in + the stage-2 diffusion loss. + diffusion_net : LatentDiffusionNet + Latent diffusion architecture to train. + forecast_steps : int + Number of forecast timesteps decoded during inference. + ensemble_size : int, optional + Number of ensemble members decoded during inference. Default is ``1``. + loss : DiffusionLoss or None, optional + Latent diffusion loss module. If ``None``, ``DiffusionLoss`` is built + from ``diffusion_net``. Default is ``None``. + optimizer : Callable[..., torch.optim.Optimizer] or None, optional + Optimizer factory. Default is ``None`` (Adam). + lr_scheduler : Callable[..., torch.optim.lr_scheduler.LRScheduler] or None, optional + Learning-rate scheduler factory. Default is ``None``. + ema_decay : float or None, optional + If provided, track an exponential moving average of diffusion-net + parameters with this decay. Default is ``None``. 
+ """ + + def __init__( + self, + autoencoder: AutoencoderNet, + diffusion_net: LatentDiffusionNet, + forecast_steps: int, + ensemble_size: int = 1, + loss: DiffusionLoss | None = None, + optimizer: Callable[..., torch.optim.Optimizer] | None = None, + lr_scheduler: Callable[..., torch.optim.lr_scheduler.LRScheduler] | None = None, + ema_decay: float | None = None, + ) -> None: + super().__init__(optimizer=optimizer, lr_scheduler=lr_scheduler) + if forecast_steps < 1: + raise ValueError(f"forecast_steps ({forecast_steps}) must be at least 1.") + if ensemble_size < 1: + raise ValueError(f"ensemble_size ({ensemble_size}) must be at least 1.") + + self.save_hyperparameters("forecast_steps", "ensemble_size", "ema_decay") + self.autoencoder = autoencoder + self.diffusion_net = diffusion_net + self.loss_fn = loss if loss is not None else DiffusionLoss(diffusion_net) + self.sampler = DiffusionSampler(diffusion_net) + self.ema = ExponentialMovingAverage(diffusion_net, decay=ema_decay) if ema_decay is not None else None + + for parameter in self.autoencoder.parameters(): + parameter.requires_grad = False + + @jaxtyped(typechecker=beartype) + def forward( + self, + x: Float[torch.Tensor, "batch input_steps channels height width"], + ) -> Float[torch.Tensor, "batch forecast_steps out_channels height width"]: + """Generate decoded forecasts from normalized input histories. + + Parameters + ---------- + x : Float[torch.Tensor, "batch input_steps channels height width"] + Normalized input history tensor. + + Returns + ------- + Float[torch.Tensor, "batch forecast_steps out_channels height width"] + Decoded normalized forecast tensor with ensemble members + concatenated along the channel dimension. 
+ """ + input_latents = self.autoencoder.encode(x) + repeated_input_latents = input_latents.repeat_interleave(self.hparams["ensemble_size"], dim=0) + latent_shape = ( + x.shape[0] * self.hparams["ensemble_size"], + input_latents.shape[1], + self.hparams["forecast_steps"], + input_latents.shape[3], + input_latents.shape[4], + ) + forecast_latents = self.sampler.sample(repeated_input_latents, latent_shape) + decoded = self.autoencoder.decode(forecast_latents) + batch, time, channels, height, width = decoded.shape + decoded = decoded.reshape(x.shape[0], self.hparams["ensemble_size"], time, channels, height, width) + return decoded.permute(0, 2, 1, 3, 4, 5).reshape( + x.shape[0], time, self.hparams["ensemble_size"] * channels, height, width + ) + + def compute_loss(self, batch: dict[str, torch.Tensor], split: str = "train") -> torch.Tensor: + """Compute latent diffusion loss for a forecasting batch. + + Parameters + ---------- + batch : dict of str to torch.Tensor + Forecasting batch containing ``input`` and ``target`` tensors. + split : str, optional + Current data split. Default is ``"train"``. + + Returns + ------- + torch.Tensor + Scalar latent diffusion loss. + """ + with torch.no_grad(): + input_latents = self.autoencoder.encode(batch["input"]) + target_latents = self.autoencoder.encode(batch["target"]) + loss = self.loss_fn(input_latents, target_latents) + self.log(f"{split}_loss", loss, prog_bar=True, on_epoch=True, on_step=(split == "train"), sync_dist=True) + return loss + + @property + def trainable_parameters(self) -> list[torch.nn.Parameter]: + """Return only the diffusion-network parameters. + + Returns + ------- + list of torch.nn.Parameter + Parameters optimized during stage-2 latent diffusion training. + """ + return list(self.diffusion_net.parameters()) + + def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None: + """Update EMA after each training batch when enabled. 
+ + Parameters + ---------- + outputs : Any + Lightning training outputs. + batch : Any + Batch passed to the training step. + batch_idx : int + Batch index supplied by Lightning. + """ + del outputs, batch, batch_idx + if self.ema is not None: + self.ema.update() + + def on_validation_start(self) -> None: + """Swap EMA weights in before validation when enabled.""" + if self.ema is not None: + self.ema.apply() + + def on_validation_end(self) -> None: + """Restore raw diffusion weights after validation when enabled.""" + if self.ema is not None: + self.ema.restore() + + def on_test_start(self) -> None: + """Swap EMA weights in before testing when enabled.""" + if self.ema is not None: + self.ema.apply() + + def on_test_end(self) -> None: + """Restore raw diffusion weights after testing when enabled.""" + if self.ema is not None: + self.ema.restore() + + def on_predict_start(self) -> None: + """Swap EMA weights in before prediction when enabled.""" + if self.ema is not None: + self.ema.apply() + + def on_predict_end(self) -> None: + """Restore raw diffusion weights after prediction when enabled.""" + if self.ema is not None: + self.ema.restore() + + +LatentDiffusionModule = LatentDiffusionTaskModule diff --git a/src/mlcast/modules/reconstruction.py b/src/mlcast/modules/reconstruction.py new file mode 100644 index 0000000..65dd348 --- /dev/null +++ b/src/mlcast/modules/reconstruction.py @@ -0,0 +1,160 @@ +"""Lightning module wrappers for reconstruction tasks.""" + +from collections.abc import Callable +from typing import Any + +import pytorch_lightning as pl +import torch +from beartype import beartype +from jaxtyping import Float, jaxtyped + +from mlcast.losses import build_loss + + +class ReconstructionTaskModule(pl.LightningModule): + """Generic reconstruction training wrapper. + + Parameters + ---------- + network : torch.nn.Module + Reconstruction model that maps an input tensor back to the same shape. 
+ loss_class : type[torch.nn.Module] or str, optional + Loss function class or registry name. Default is ``"mse"``. + loss_params : dict or None, optional + Keyword arguments for the loss constructor. Default is ``None``. + optimizer : Callable[..., torch.optim.Optimizer] or None, optional + Optimizer factory. Default is ``None`` (Adam). + lr_scheduler : Callable[..., torch.optim.lr_scheduler.LRScheduler] or None, optional + Learning-rate scheduler factory. Default is ``None``. + """ + + def __init__( + self, + network: torch.nn.Module, + loss_class: type[torch.nn.Module] | str = "mse", + loss_params: dict[str, Any] | None = None, + optimizer: Callable[..., torch.optim.Optimizer] | None = None, + lr_scheduler: Callable[..., torch.optim.lr_scheduler.LRScheduler] | None = None, + ) -> None: + super().__init__() + self.save_hyperparameters("loss_class", "loss_params") + self.network = network + self.optimizer_factory = optimizer + self.lr_scheduler_factory = lr_scheduler + self.criterion = build_loss(loss_class=loss_class, loss_params=loss_params, masked_loss=False) + + @jaxtyped(typechecker=beartype) + def forward( + self, + x: Float[torch.Tensor, "batch time channels height width"], + ) -> Float[torch.Tensor, "batch time channels height width"]: + """Run the reconstruction network. + + Parameters + ---------- + x : Float[torch.Tensor, "batch time channels height width"] + Normalized reconstruction input. + + Returns + ------- + Float[torch.Tensor, "batch time channels height width"] + Reconstructed normalized tensor. + """ + return self.network(x) + + def shared_step(self, batch: torch.Tensor, split: str = "train") -> torch.Tensor: + """Compute reconstruction loss for one batch. + + Parameters + ---------- + batch : torch.Tensor + Tensor-only reconstruction batch. + split : str, optional + Current data split. Default is ``"train"``. + + Returns + ------- + torch.Tensor + Scalar reconstruction loss. 
+ """ + preds = self(batch).clamp(min=-1, max=1) + loss = self.criterion(preds, batch) + if isinstance(loss, tuple): + loss, log_dict = loss + self.log_dict( + log_dict, prog_bar=False, logger=True, on_step=(split == "train"), on_epoch=True, sync_dist=True + ) + self.log(f"{split}_loss", loss, prog_bar=True, on_epoch=True, on_step=(split == "train"), sync_dist=True) + return loss + + def training_step(self, batch: torch.Tensor, _batch_idx: int) -> torch.Tensor: + """Execute a training step. + + Parameters + ---------- + batch : torch.Tensor + Reconstruction batch. + _batch_idx : int + Batch index supplied by Lightning. + + Returns + ------- + torch.Tensor + Training loss. + """ + return self.shared_step(batch, split="train") + + def validation_step(self, batch: torch.Tensor, _batch_idx: int) -> torch.Tensor: + """Execute a validation step. + + Parameters + ---------- + batch : torch.Tensor + Reconstruction batch. + _batch_idx : int + Batch index supplied by Lightning. + + Returns + ------- + torch.Tensor + Validation loss. + """ + return self.shared_step(batch, split="val") + + def test_step(self, batch: torch.Tensor, _batch_idx: int) -> torch.Tensor: + """Execute a test step. + + Parameters + ---------- + batch : torch.Tensor + Reconstruction batch. + _batch_idx : int + Batch index supplied by Lightning. + + Returns + ------- + torch.Tensor + Test loss. + """ + return self.shared_step(batch, split="test") + + def configure_optimizers(self) -> Any: + """Configure optimizer and optional scheduler. + + Returns + ------- + Any + PyTorch Lightning optimizer configuration. 
+ """ + if self.optimizer_factory is not None: + optimizer = self.optimizer_factory(self.parameters()) + else: + optimizer = torch.optim.Adam(self.parameters()) + + if self.lr_scheduler_factory is not None: + lr_scheduler = self.lr_scheduler_factory(optimizer) + return {"optimizer": optimizer, "lr_scheduler": {"scheduler": lr_scheduler, "monitor": "val_loss"}} + return {"optimizer": optimizer} + + +ReconstructionModule = ReconstructionTaskModule diff --git a/src/mlcast/nowcasting_module.py b/src/mlcast/nowcasting_module.py index faa6e90..070bcef 100644 --- a/src/mlcast/nowcasting_module.py +++ b/src/mlcast/nowcasting_module.py @@ -1,292 +1,8 @@ -"""Generic Lightning module for radar precipitation nowcasting. +"""Backward-compatible import shim for the forecasting Lightning module.""" -Wraps an injected PyTorch :class:`nn.Module` (the network architecture) and -handles training, validation, and test steps including loss computation, -image logging, and optimizer configuration. -""" +from mlcast.modules.forecasting import ForecastingTaskModule -from collections.abc import Callable -from typing import Any +ForecastingModule = ForecastingTaskModule +NowcastLightningModule = ForecastingTaskModule -import numpy as np -import pytorch_lightning as pl -import torch -from beartype import beartype -from jaxtyping import Float, jaxtyped - -from mlcast.data.normalization import DENORMALIZATION_REGISTRY, NORMALIZATION_REGISTRY -from mlcast.losses import build_loss -from mlcast.visualization import log_images - - -class NowcastLightningModule(pl.LightningModule): - """Generic PyTorch Lightning module for nowcasting. - - Wraps an injected PyTorch `nn.Module` (the network architecture) and - handles training, validation, test steps, loss computation, ensemble - generation, and TensorBoard logging. - - Parameters - ---------- - network : torch.nn.Module - The PyTorch network architecture to train. 
- loss_class : type[torch.nn.Module] or str, optional - Loss function class or its string name. Default is ``"mse"``. - loss_params : dict or None, optional - Keyword arguments for the loss constructor. Default is ``None``. - masked_loss : bool, optional - Whether to wrap the loss with :class:`MaskedLoss`. Default is ``False``. - optimizer : Callable[..., torch.optim.Optimizer] or None, optional - A callable (e.g., a ``functools.partial``) that takes network parameters - and returns an instantiated optimizer. Default is ``None`` (Adam). - lr_scheduler : Callable[..., torch.optim.lr_scheduler.LRScheduler] or None, optional - A callable (e.g., a ``functools.partial``) that takes an optimizer - and returns an instantiated learning rate scheduler. Default is ``None``. - """ - - def __init__( - self, - network: torch.nn.Module, - loss_class: type[torch.nn.Module] | str = "mse", - loss_params: dict[str, Any] | None = None, - masked_loss: bool = False, - optimizer: Callable[..., torch.optim.Optimizer] | None = None, - lr_scheduler: Callable[..., torch.optim.lr_scheduler.LRScheduler] | None = None, - ) -> None: - super().__init__() - # Explicitly save hyperparameters that are accessed later via self.hparams - self.save_hyperparameters("loss_class", "loss_params", "masked_loss") - - self.network = network - self.optimizer_factory = optimizer - self.lr_scheduler_factory = lr_scheduler - - self.criterion = build_loss( - loss_class=loss_class, - loss_params=loss_params, - masked_loss=masked_loss, - ) - self.log_images_iterations = [50, 100, 200, 500, 750, 1000, 2000, 5000] - - @jaxtyped(typechecker=beartype) - def forward( - self, - x: Float[torch.Tensor, "batch time channels height width"], - ) -> Float[torch.Tensor, "batch forecast_steps out_channels height width"]: - """Run the network forward pass. - - Parameters - ---------- - x : Float[torch.Tensor, "batch time channels height width"] - Input tensor. 
- - Returns - ------- - preds : Float[torch.Tensor, "batch forecast_steps out_channels height width"] - Forecast tensor. - """ - return self.network(x) - - def shared_step(self, batch: dict[str, torch.Tensor], split: str = "train") -> torch.Tensor: - """Shared forward step for training, validation, and testing. - - Parameters - ---------- - batch : dict of str to torch.Tensor - A dictionary containing the batched input data. Must contain the - key ``"data"`` and optionally ``"mask"`` if ``masked_loss`` is ``True``. - split : str, optional - The data split being processed (e.g., ``"train"``, ``"val"``, ``"test"``). - Used for logging. Default is ``"train"``. - Returns - ------- - loss : torch.Tensor - The computed loss for the batch. - """ - past = batch["input"] - future = batch["target"] - - preds = self(past).clamp(min=-1, max=1) - - if self.hparams["masked_loss"]: - mask = batch["target_mask"] - loss = self.criterion(preds, future, mask) - else: - loss = self.criterion(preds, future) - - if isinstance(loss, tuple): - loss, log_dict = loss - self.log_dict( - log_dict, prog_bar=False, logger=True, on_step=(split == "train"), on_epoch=True, sync_dist=True - ) - - self.log(f"{split}_loss", loss, prog_bar=True, on_epoch=True, on_step=(split == "train"), sync_dist=True) - - ensemble_size = getattr(self.network, "ensemble_size", 1) - if ensemble_size > 1: - ensemble_std = preds.std(dim=2).mean() - self.log(f"{split}_ensemble_std", ensemble_std, on_epoch=True, sync_dist=True) - - if ( - split == "train" - and self.logger is not None - and getattr(self.logger, "experiment", None) is not None - and ( - self.global_step in self.log_images_iterations or self.global_step % self.log_images_iterations[-1] == 0 - ) - ): - log_images( - past=past, - future=future, - preds=preds, - logger=self.logger, # type: ignore - global_step=self.global_step, - ensemble_size=ensemble_size, - split=split, - ) - return loss - - def training_step(self, batch: dict[str, torch.Tensor], 
_batch_idx: int) -> torch.Tensor: - """Execute a single training step. - - Parameters - ---------- - batch : dict of str to torch.Tensor - A dictionary containing the batched input data. - batch_idx : int - The index of the current batch. - - Returns - ------- - loss : torch.Tensor - The training loss. - """ - return self.shared_step(batch, split="train") - - def validation_step(self, batch: dict[str, torch.Tensor], _batch_idx: int) -> torch.Tensor: - """Execute a single validation step. - - Parameters - ---------- - batch : dict of str to torch.Tensor - A dictionary containing the batched input data. - batch_idx : int - The index of the current batch. - - Returns - ------- - loss : torch.Tensor - The validation loss. - """ - return self.shared_step(batch, split="val") - - def test_step(self, batch: dict[str, torch.Tensor], _batch_idx: int) -> torch.Tensor: - """Execute a single test step. - - Parameters - ---------- - batch : dict of str to torch.Tensor - A dictionary containing the batched input data. - batch_idx : int - The index of the current batch. - - Returns - ------- - loss : torch.Tensor - The test loss. - """ - return self.shared_step(batch, split="test") - - def configure_optimizers(self) -> Any: - """Configure the optimizer and optional learning rate scheduler. - - Returns - ------- - config : dict of str to Any - A dictionary containing the instantiated ``"optimizer"`` and - optionally ``"lr_scheduler"`` configurations for PyTorch Lightning. 
- """ - if self.optimizer_factory is not None: - optimizer = self.optimizer_factory(self.parameters()) - else: - optimizer = torch.optim.Adam(self.parameters()) - - if self.lr_scheduler_factory is not None: - lr_scheduler = self.lr_scheduler_factory(optimizer) - return {"optimizer": optimizer, "lr_scheduler": {"scheduler": lr_scheduler, "monitor": "val_loss"}} - else: - return {"optimizer": optimizer} - - @classmethod - def from_checkpoint(cls, checkpoint_path: str, device: str = "cpu") -> "NowcastLightningModule": - """Load a model from a checkpoint file. - - Parameters - ---------- - checkpoint_path : str - Path to the saved PyTorch Lightning checkpoint (``.ckpt``) file. - device : str, optional - The device to map the model weights to (e.g., ``"cpu"`` or ``"cuda"``). - Default is ``"cpu"``. - - Returns - ------- - model : NowcastLightningModule - The loaded PyTorch Lightning model instance. - """ - return cls.load_from_checkpoint( - checkpoint_path, - map_location=torch.device(device), - strict=True, - weights_only=False, - ) - - def predict( - self, - past: torch.Tensor, - standard_name: str = "rainfall_rate", - ) -> np.ndarray[Any, Any]: - """Generate precipitation forecasts from past radar observations. - - Input should be raw unnormalized values. - - Parameters - ---------- - past : torch.Tensor - Past radar frames as unnormalized values (e.g., mm/h or kg m-2 s-1), of shape ``(T, H, W)``. - standard_name : str, optional - The CF standard name defining the input/output domain for normalization lookup. - Default is ``"rainfall_rate"``. - - Returns - ------- - preds : np.ndarray - Forecasted unnormalized values, of shape - ``(ensemble_size, forecast_steps, H, W)``. The ensemble size and - forecast horizon are determined by the configured network. - """ - if len(past.shape) != 3: - raise ValueError("Input must be of shape (T, H, W)") - - past_clean = np.nan_to_num(past.cpu().numpy()) - past_clean = past_clean[np.newaxis, :, np.newaxis, ...] 
- - norm_func = NORMALIZATION_REGISTRY[standard_name] - norm_past = norm_func(past_clean) - - x = torch.from_numpy(norm_past) - x = x.to(self.device) - - self.eval() - with torch.no_grad(): - preds_tensor = self.network(x) - - preds_np: np.ndarray[Any, Any] = preds_tensor.cpu().numpy() - - denorm_func = DENORMALIZATION_REGISTRY[standard_name] - preds_np = denorm_func(preds_np) - - preds_np = preds_np.squeeze(0) - preds_np = np.swapaxes(preds_np, 0, 1) - - return preds_np +__all__ = ["ForecastingModule", "ForecastingTaskModule", "NowcastLightningModule"] diff --git a/tests/test_nowcasting_module.py b/tests/test_nowcasting_module.py index d859fb4..6f12039 100644 --- a/tests/test_nowcasting_module.py +++ b/tests/test_nowcasting_module.py @@ -1,7 +1,7 @@ import numpy as np import torch -from mlcast.nowcasting_module import NowcastLightningModule +from mlcast.modules.forecasting import ForecastingTaskModule class DummyForecastNetwork(torch.nn.Module): @@ -20,9 +20,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def test_nowcasting_module_forward_uses_network_shape_contract() -> None: - """NowcastLightningModule should call fixed-shape forecasting networks as network(x).""" + """ForecastingTaskModule should call fixed-shape forecasting networks as network(x).""" network = DummyForecastNetwork(input_steps=3, forecast_steps=5, ensemble_size=2) - module = NowcastLightningModule(network=network, loss_class="crps") + module = ForecastingTaskModule(network=network, loss_class="crps") x = torch.randn(4, 3, 1, 8, 8) preds = module(x) @@ -33,7 +33,7 @@ def test_nowcasting_module_forward_uses_network_shape_contract() -> None: def test_nowcasting_module_predict_uses_configured_output_shape() -> None: """Prediction horizon and ensemble size should come from the configured network.""" network = DummyForecastNetwork(input_steps=3, forecast_steps=4, ensemble_size=2) - module = NowcastLightningModule(network=network, loss_class="crps") + module = 
ForecastingTaskModule(network=network, loss_class="crps") past = torch.ones(3, 8, 8) preds = module.predict(past, standard_name="rainfall_rate") diff --git a/tests/test_task_modules.py b/tests/test_task_modules.py new file mode 100644 index 0000000..a55d49e --- /dev/null +++ b/tests/test_task_modules.py @@ -0,0 +1,98 @@ +import numpy as np +import torch + +from mlcast.models.autoencoder import AutoencoderNet, Decoder, Encoder +from mlcast.models.diffusion import ConditionerNet, DenoiserUNet, DiffusionScheduler, LatentDiffusionNet +from mlcast.modules.forecasting import ForecastingTaskModule, LatentDiffusionTaskModule +from mlcast.modules.reconstruction import ReconstructionTaskModule + + +class IdentityReconstructionNetwork(torch.nn.Module): + """Minimal reconstruction network used in wrapper tests.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Return the input unchanged.""" + return x + + +def test_reconstruction_module_uses_batch_as_target() -> None: + """ReconstructionTaskModule should compute loss against the input batch itself.""" + module = ReconstructionTaskModule(network=IdentityReconstructionNetwork(), loss_class="mse") + batch = torch.randn(2, 3, 1, 4, 4) + + loss = module.training_step(batch, 0) + + assert torch.isfinite(loss) + assert loss.ndim == 0 + + +def test_forecasting_task_module_trainable_parameters_match_network() -> None: + """ForecastingTaskModule should optimize the forecasting network parameters.""" + + class TinyForecastNetwork(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + network = TinyForecastNetwork() + module = ForecastingTaskModule(network=network, loss_class="mse") + + assert module.trainable_parameters == list(network.parameters()) + + +def _build_autoencoder() -> AutoencoderNet: + encoder = Encoder(input_channels=1, hidden_channels=4, latent_channels=4, num_blocks=1) + decoder = 
Decoder(output_channels=1, hidden_channels=4, latent_channels=4, num_blocks=1) + return AutoencoderNet(encoder=encoder, decoder=decoder) + + +def _build_diffusion_net() -> LatentDiffusionNet: + conditioner = ConditionerNet(latent_channels=4, hidden_channels=8, num_blocks=1) + denoiser = DenoiserUNet(latent_channels=4, condition_channels=8, hidden_channels=8, num_blocks=2) + scheduler = DiffusionScheduler(timesteps=2) + return LatentDiffusionNet(conditioner=conditioner, denoiser=denoiser, scheduler=scheduler) + + +def test_latent_diffusion_module_training_step_runs() -> None: + """LatentDiffusionTaskModule should encode forecasting batches and return scalar loss.""" + module = LatentDiffusionTaskModule( + autoencoder=_build_autoencoder(), diffusion_net=_build_diffusion_net(), forecast_steps=3 + ) + batch = { + "input": torch.randn(2, 2, 1, 8, 8), + "target": torch.randn(2, 3, 1, 8, 8), + } + + loss = module.training_step(batch, 0) + + assert torch.isfinite(loss) + assert loss.ndim == 0 + + +def test_latent_diffusion_task_module_trainable_parameters_exclude_autoencoder() -> None: + """LatentDiffusionTaskModule should optimize only diffusion-net parameters.""" + autoencoder = _build_autoencoder() + diffusion_net = _build_diffusion_net() + module = LatentDiffusionTaskModule(autoencoder=autoencoder, diffusion_net=diffusion_net, forecast_steps=3) + + assert module.trainable_parameters == list(diffusion_net.parameters()) + assert module.trainable_parameters != list(autoencoder.parameters()) + + +def test_latent_diffusion_module_predict_uses_configured_output_shape() -> None: + """LatentDiffusionTaskModule prediction should decode configured ensemble forecasts.""" + module = LatentDiffusionTaskModule( + autoencoder=_build_autoencoder(), + diffusion_net=_build_diffusion_net(), + forecast_steps=3, + ensemble_size=2, + ) + past = torch.ones(2, 8, 8) + + preds = module.predict(past, standard_name="rainfall_rate") + + assert isinstance(preds, np.ndarray) + assert preds.shape == 
(2, 3, 8, 8) From c19cc519a15274a9cc1c26c612a7a5a6d13190cf Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 3 Jun 2026 23:24:11 +0200 Subject: [PATCH 17/34] refactor: clarify forecasting task module roles --- README.md | 2 +- docs/config_diagram.svg | 96 +++++++++++----------- ldcast-refactor-plan.md | 4 +- src/mlcast/config/base.py | 4 +- src/mlcast/modules/__init__.py | 6 +- src/mlcast/modules/forecasting.py | 128 +++++++++++++++++++++++++++--- src/mlcast/nowcasting_module.py | 14 +++- tests/test_nowcasting_module.py | 8 +- tests/test_task_modules.py | 8 +- 9 files changed, 192 insertions(+), 78 deletions(-) diff --git a/README.md b/README.md index 30e1950..5b58d83 100644 --- a/README.md +++ b/README.md @@ -213,7 +213,7 @@ from mlcast.config.fiddlers import use_random_sampler # Minimal adapter: channel-stack past frames -> HalfUNet -> one step at a time. # The forecasting contract fixes input_steps, forecast_steps, and ensemble_size # at model initialization; this minimal deterministic adapter exposes one -# ensemble member and ForecastingTaskModule calls network(x). +# ensemble member and OutputSpaceForecastingTaskModule calls network(x). 
class HalfUNetNowcaster(nn.Module): def __init__(self, input_steps: int = 6, forecast_steps: int = 12, num_vars: int = 1): super().__init__() diff --git a/docs/config_diagram.svg b/docs/config_diagram.svg index 6cc3ba8..2ef2cd2 100644 --- a/docs/config_diagram.svg +++ b/docs/config_diagram.svg @@ -50,55 +50,55 @@ 1 - - -Config: - ForecastingTaskModule - - -network - - - - - -loss_class - -'crps' - - -loss_params - - - -dict - - -'temporal_lambda' - -0.01 - - -masked_loss - -True - - -optimizer - - - - - -lr_scheduler - - - + + +Config: + OutputSpaceForecastingTaskModule + + +network + + + + + +loss_class + +'crps' + + +loss_params + + + +dict + + +'temporal_lambda' + +0.01 + + +masked_loss + +True + + +optimizer + + + + + +lr_scheduler + + + 1:c--2:c - + @@ -121,7 +121,7 @@ 1:c--3:c - + @@ -149,7 +149,7 @@ 1:c--4:c - + @@ -180,7 +180,7 @@ 0:c--1:c - + diff --git a/ldcast-refactor-plan.md b/ldcast-refactor-plan.md index 83c8562..382bcd9 100644 --- a/ldcast-refactor-plan.md +++ b/ldcast-refactor-plan.md @@ -59,9 +59,9 @@ - [x] diffusion model improves loss on a small generated latent dataset after a few training steps. 5. Task modules (Lightning modules) -- [x] Add `src/mlcast/modules/forecasting.py`, introduce `BaseForecastingTaskModule`, and rename `NowcastLightningModule` to `ForecastingTaskModule`. +- [x] Add `src/mlcast/modules/forecasting.py`, introduce `BaseForecastingTaskModule`, and rename `NowcastLightningModule` to `OutputSpaceForecastingTaskModule`. - [x] `BaseForecastingTaskModule` should own optimizer/scheduler plumbing, while each concrete task module defines which parameters are trainable. -- [x] `ForecastingTaskModule` should optimize the forecasting network parameters. +- [x] `OutputSpaceForecastingTaskModule` should optimize the forecasting network parameters. - [x] Remove runtime `forecast_steps` and `ensemble_size` arguments from the forecasting task module and its `predict()` API. 
- [x] Add `src/mlcast/modules/reconstruction.py` with a generic `ReconstructionTaskModule` for any reconstruction model. - [x] Add a `LatentDiffusionTaskModule` that owns the trained autoencoder, optimizes only the diffusion-network parameters, trains diffusion in latent space, and handles decoded forecast inference. diff --git a/src/mlcast/config/base.py b/src/mlcast/config/base.py index 311f4a5..9343b08 100644 --- a/src/mlcast/config/base.py +++ b/src/mlcast/config/base.py @@ -32,7 +32,7 @@ from ..data.datamodules import ForecastingDataModule from ..data.sequence import SourceDataPrecomputedSequenceDataset from ..models.convgru import ConvGruModel -from ..modules.forecasting import ForecastingTaskModule +from ..modules.forecasting import OutputSpaceForecastingTaskModule @dataclass @@ -91,7 +91,7 @@ def convgru_training_experiment() -> Experiment: noisy_decoder=False, ) - pl_module = ForecastingTaskModule( + pl_module = OutputSpaceForecastingTaskModule( network=network, loss_class="crps", loss_params={"temporal_lambda": 0.01}, diff --git a/src/mlcast/modules/__init__.py b/src/mlcast/modules/__init__.py index ce40eaf..c8cd0cf 100644 --- a/src/mlcast/modules/__init__.py +++ b/src/mlcast/modules/__init__.py @@ -2,19 +2,17 @@ from .forecasting import ( BaseForecastingTaskModule, - ForecastingModule, - ForecastingTaskModule, LatentDiffusionModule, LatentDiffusionTaskModule, + OutputSpaceForecastingTaskModule, ) from .reconstruction import ReconstructionModule, ReconstructionTaskModule __all__ = [ "BaseForecastingTaskModule", - "ForecastingModule", - "ForecastingTaskModule", "LatentDiffusionModule", "LatentDiffusionTaskModule", + "OutputSpaceForecastingTaskModule", "ReconstructionModule", "ReconstructionTaskModule", ] diff --git a/src/mlcast/modules/forecasting.py b/src/mlcast/modules/forecasting.py index 3d6b636..c36ddac 100644 --- a/src/mlcast/modules/forecasting.py +++ b/src/mlcast/modules/forecasting.py @@ -23,6 +23,42 @@ class 
BaseForecastingTaskModule(pl.LightningModule): """Base Lightning module for forecasting-shaped tasks. + Purpose + ------- + This class provides the common PyTorch Lightning plumbing shared by + forecasting-oriented task modules. It centralizes the optimizer and + scheduler configuration interface, the train/validation/test step routing, + and the normalization-aware prediction helper used by forecasting tasks. + + Ownership + --------- + This base class owns: + + - optimizer and scheduler factories + - generic Lightning step orchestration + - normalization and denormalization logic for ``predict`` + + It does not own: + + - a specific forecasting architecture + - a concrete task loss + - the choice of which parameters are trainable + - any task-specific inference logic beyond normalized I/O handling + + Training behavior + ----------------- + Training, validation, and test steps all delegate to the subclass hook + :meth:`compute_loss`. Subclasses are also responsible for exposing the + exact parameter set to optimize through the :attr:`trainable_parameters` + property. + + Inference behavior + ------------------ + ``predict`` accepts unnormalized input observations, applies the configured + normalization for the requested standard name, delegates normalized + forecasting to :meth:`predict_normalized`, then denormalizes the model + outputs back to physical units. + Parameters ---------- optimizer : Callable[..., torch.optim.Optimizer] or None, optional @@ -195,8 +231,44 @@ def predict(self, past: torch.Tensor, standard_name: str = "rainfall_rate") -> n return preds_np -class ForecastingTaskModule(BaseForecastingTaskModule): - """Generic PyTorch Lightning module for direct forecasting tasks. +class OutputSpaceForecastingTaskModule(BaseForecastingTaskModule): + """Lightning task module for direct forecasting in output space. 
+ + Purpose + ------- + This task module trains conventional forecasting models whose outputs can be + compared directly against forecast targets in the original normalized data + space. It is the generic wrapper used for models such as ConvGRU, where a + single forward pass produces forecast tensors that are supervised directly. + + Ownership + --------- + This class owns: + + - the forecasting network passed in as ``network`` + - the output-space forecasting loss + - optional masked-loss behavior using ``target_mask`` + - image and ensemble-statistic logging specific to direct forecast outputs + + It does not own: + + - source-data normalization rules outside the inherited ``predict`` helper + - latent-space encoding or decoding components + - sampler-driven generative forecast logic + + Training behavior + ----------------- + A forecasting batch provides ``input`` and ``target`` tensors, plus an + optional ``target_mask``. The module calls ``network(input)`` to obtain a + normalized forecast tensor, optionally applies masked loss, and compares the + network outputs directly against the target tensor in output space. + + Inference behavior + ------------------ + Inference is a direct forward pass through the forecasting network. The + inherited :meth:`predict` helper normalizes raw inputs, calls + :meth:`predict_normalized`, and denormalizes the resulting forecast back to + physical units. Parameters ---------- @@ -319,7 +391,7 @@ def trainable_parameters(self) -> list[torch.nn.Parameter]: return list(self.network.parameters()) @classmethod - def from_checkpoint(cls, checkpoint_path: str, device: str = "cpu") -> "ForecastingTaskModule": + def from_checkpoint(cls, checkpoint_path: str, device: str = "cpu") -> "OutputSpaceForecastingTaskModule": """Load a forecasting task module from checkpoint. 
Parameters @@ -331,8 +403,8 @@ def from_checkpoint(cls, checkpoint_path: str, device: str = "cpu") -> "Forecast Returns ------- - ForecastingTaskModule - Loaded forecasting task module. + OutputSpaceForecastingTaskModule + Loaded output-space forecasting task module. """ return cls.load_from_checkpoint( checkpoint_path, @@ -342,11 +414,49 @@ def from_checkpoint(cls, checkpoint_path: str, device: str = "cpu") -> "Forecast ) -ForecastingModule = ForecastingTaskModule - - class LatentDiffusionTaskModule(BaseForecastingTaskModule): - """Train latent diffusion in latent space and decode forecasts for inference. + """Lightning task module for latent diffusion forecasting. + + Purpose + ------- + This task module trains a latent diffusion forecasting system that reuses a + stage-1 autoencoder. Forecast supervision is applied in latent space rather + than directly on decoded forecast tensors. At inference time, the module + samples forecast latents and decodes them back to the original data space. + + Ownership + --------- + This class owns: + + - the trained autoencoder reused from stage 1 + - the latent diffusion architecture + - the latent diffusion loss + - the diffusion sampler used for forecast generation + - optional EMA tracking over diffusion-network weights + + It does not own: + + - stage-1 autoencoder training + - output-space supervision for the diffusion loss + - the source-data normalization rules beyond the inherited ``predict`` + helper + + Training behavior + ----------------- + A forecasting batch provides raw normalized ``input`` and ``target`` + tensors. The reused autoencoder encoder maps both into latent space under + ``torch.no_grad()``. The module then computes a diffusion loss entirely on + latent tensors and exposes only the diffusion-network parameters through + :attr:`trainable_parameters`, so the reused autoencoder remains frozen. 
+ + Inference behavior + ------------------ + Inference encodes the input history with the reused autoencoder, samples a + latent forecast with the diffusion sampler, then decodes the sampled latent + forecast back to data space. Ensemble generation is explicit here: the + module repeats encoded inputs per requested ensemble member, samples a + forecast latent for each member, and concatenates the decoded members along + the channel dimension. Parameters ---------- diff --git a/src/mlcast/nowcasting_module.py b/src/mlcast/nowcasting_module.py index 070bcef..1230719 100644 --- a/src/mlcast/nowcasting_module.py +++ b/src/mlcast/nowcasting_module.py @@ -1,8 +1,14 @@ """Backward-compatible import shim for the forecasting Lightning module.""" -from mlcast.modules.forecasting import ForecastingTaskModule +from mlcast.modules.forecasting import OutputSpaceForecastingTaskModule -ForecastingModule = ForecastingTaskModule -NowcastLightningModule = ForecastingTaskModule +ForecastingTaskModule = OutputSpaceForecastingTaskModule +ForecastingModule = OutputSpaceForecastingTaskModule +NowcastLightningModule = OutputSpaceForecastingTaskModule -__all__ = ["ForecastingModule", "ForecastingTaskModule", "NowcastLightningModule"] +__all__ = [ + "ForecastingModule", + "ForecastingTaskModule", + "NowcastLightningModule", + "OutputSpaceForecastingTaskModule", +] diff --git a/tests/test_nowcasting_module.py b/tests/test_nowcasting_module.py index 6f12039..ddca967 100644 --- a/tests/test_nowcasting_module.py +++ b/tests/test_nowcasting_module.py @@ -1,7 +1,7 @@ import numpy as np import torch -from mlcast.modules.forecasting import ForecastingTaskModule +from mlcast.modules.forecasting import OutputSpaceForecastingTaskModule class DummyForecastNetwork(torch.nn.Module): @@ -20,9 +20,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def test_nowcasting_module_forward_uses_network_shape_contract() -> None: - """ForecastingTaskModule should call fixed-shape forecasting networks as 
network(x).""" + """OutputSpaceForecastingTaskModule should call fixed-shape forecasting networks as network(x).""" network = DummyForecastNetwork(input_steps=3, forecast_steps=5, ensemble_size=2) - module = ForecastingTaskModule(network=network, loss_class="crps") + module = OutputSpaceForecastingTaskModule(network=network, loss_class="crps") x = torch.randn(4, 3, 1, 8, 8) preds = module(x) @@ -33,7 +33,7 @@ def test_nowcasting_module_forward_uses_network_shape_contract() -> None: def test_nowcasting_module_predict_uses_configured_output_shape() -> None: """Prediction horizon and ensemble size should come from the configured network.""" network = DummyForecastNetwork(input_steps=3, forecast_steps=4, ensemble_size=2) - module = ForecastingTaskModule(network=network, loss_class="crps") + module = OutputSpaceForecastingTaskModule(network=network, loss_class="crps") past = torch.ones(3, 8, 8) preds = module.predict(past, standard_name="rainfall_rate") diff --git a/tests/test_task_modules.py b/tests/test_task_modules.py index a55d49e..a267f31 100644 --- a/tests/test_task_modules.py +++ b/tests/test_task_modules.py @@ -3,7 +3,7 @@ from mlcast.models.autoencoder import AutoencoderNet, Decoder, Encoder from mlcast.models.diffusion import ConditionerNet, DenoiserUNet, DiffusionScheduler, LatentDiffusionNet -from mlcast.modules.forecasting import ForecastingTaskModule, LatentDiffusionTaskModule +from mlcast.modules.forecasting import LatentDiffusionTaskModule, OutputSpaceForecastingTaskModule from mlcast.modules.reconstruction import ReconstructionTaskModule @@ -26,8 +26,8 @@ def test_reconstruction_module_uses_batch_as_target() -> None: assert loss.ndim == 0 -def test_forecasting_task_module_trainable_parameters_match_network() -> None: - """ForecastingTaskModule should optimize the forecasting network parameters.""" +def test_output_space_forecasting_task_module_trainable_parameters_match_network() -> None: + """OutputSpaceForecastingTaskModule should optimize the 
forecasting network parameters.""" class TinyForecastNetwork(torch.nn.Module): def __init__(self) -> None: @@ -38,7 +38,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x network = TinyForecastNetwork() - module = ForecastingTaskModule(network=network, loss_class="mse") + module = OutputSpaceForecastingTaskModule(network=network, loss_class="mse") assert module.trainable_parameters == list(network.parameters()) From 9371fca89f2b016011da99f2f32521d60ec99560 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 3 Jun 2026 23:31:38 +0200 Subject: [PATCH 18/34] refactor: remove obsolete module aliases --- src/mlcast/modules/__init__.py | 5 +-- src/mlcast/modules/forecasting.py | 3 -- src/mlcast/modules/reconstruction.py | 46 ++++++++++++++++++++++++---- src/mlcast/nowcasting_module.py | 14 --------- 4 files changed, 41 insertions(+), 27 deletions(-) delete mode 100644 src/mlcast/nowcasting_module.py diff --git a/src/mlcast/modules/__init__.py b/src/mlcast/modules/__init__.py index c8cd0cf..46f7d2a 100644 --- a/src/mlcast/modules/__init__.py +++ b/src/mlcast/modules/__init__.py @@ -2,17 +2,14 @@ from .forecasting import ( BaseForecastingTaskModule, - LatentDiffusionModule, LatentDiffusionTaskModule, OutputSpaceForecastingTaskModule, ) -from .reconstruction import ReconstructionModule, ReconstructionTaskModule +from .reconstruction import ReconstructionTaskModule __all__ = [ "BaseForecastingTaskModule", - "LatentDiffusionModule", "LatentDiffusionTaskModule", "OutputSpaceForecastingTaskModule", - "ReconstructionModule", "ReconstructionTaskModule", ] diff --git a/src/mlcast/modules/forecasting.py b/src/mlcast/modules/forecasting.py index c36ddac..96060cd 100644 --- a/src/mlcast/modules/forecasting.py +++ b/src/mlcast/modules/forecasting.py @@ -623,6 +623,3 @@ def on_predict_end(self) -> None: """Restore raw diffusion weights after prediction when enabled.""" if self.ema is not None: self.ema.restore() - - -LatentDiffusionModule = LatentDiffusionTaskModule 
diff --git a/src/mlcast/modules/reconstruction.py b/src/mlcast/modules/reconstruction.py index 65dd348..f23bab3 100644 --- a/src/mlcast/modules/reconstruction.py +++ b/src/mlcast/modules/reconstruction.py @@ -12,7 +12,42 @@ class ReconstructionTaskModule(pl.LightningModule): - """Generic reconstruction training wrapper. + """Lightning task module for reconstruction training. + + Purpose + ------- + This task module trains reconstruction models on tensor-only batches from + ``ReconstructionDataset``. It is intended for stage-1 reconstruction or + autoencoder training, where the model learns to reproduce normalized + sequence windows. + + Ownership + --------- + This class owns: + + - the reconstruction network + - the reconstruction loss defined by ``loss_class`` and ``loss_params`` + - the optimizer and learning-rate scheduler factories + + It does not own: + + - source-data normalization rules + - forecasting-specific targets, masks, or ensemble behavior + - latent diffusion training or sampler-driven inference logic + + Training behavior + ----------------- + Each batch is a tensor-only reconstruction sample. The module uses that + tensor as both the model input and the reconstruction target, computes the + reconstruction loss directly in output space, and logs the resulting scalar + loss for the active split. + + Inference behavior + ------------------ + ``forward`` applies the reconstruction network to a normalized input tensor + and returns a reconstructed normalized tensor of the same shape. This + module does not implement forecasting-specific prediction helpers or any + sampler-based inference path. Parameters ---------- @@ -23,9 +58,11 @@ class ReconstructionTaskModule(pl.LightningModule): loss_params : dict or None, optional Keyword arguments for the loss constructor. Default is ``None``. optimizer : Callable[..., torch.optim.Optimizer] or None, optional - Optimizer factory. Default is ``None`` (Adam). 
+ Optimizer factory used by :meth:`configure_optimizers`. Default is + ``None`` (Adam over ``self.parameters()``). lr_scheduler : Callable[..., torch.optim.lr_scheduler.LRScheduler] or None, optional - Learning-rate scheduler factory. Default is ``None``. + Learning-rate scheduler factory used by :meth:`configure_optimizers`. + Default is ``None``. """ def __init__( @@ -155,6 +192,3 @@ def configure_optimizers(self) -> Any: lr_scheduler = self.lr_scheduler_factory(optimizer) return {"optimizer": optimizer, "lr_scheduler": {"scheduler": lr_scheduler, "monitor": "val_loss"}} return {"optimizer": optimizer} - - -ReconstructionModule = ReconstructionTaskModule diff --git a/src/mlcast/nowcasting_module.py b/src/mlcast/nowcasting_module.py deleted file mode 100644 index 1230719..0000000 --- a/src/mlcast/nowcasting_module.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Backward-compatible import shim for the forecasting Lightning module.""" - -from mlcast.modules.forecasting import OutputSpaceForecastingTaskModule - -ForecastingTaskModule = OutputSpaceForecastingTaskModule -ForecastingModule = OutputSpaceForecastingTaskModule -NowcastLightningModule = OutputSpaceForecastingTaskModule - -__all__ = [ - "ForecastingModule", - "ForecastingTaskModule", - "NowcastLightningModule", - "OutputSpaceForecastingTaskModule", -] From 6d67659fe6f61e7773a5e361e5d287368d52e277 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 4 Jun 2026 00:00:27 +0200 Subject: [PATCH 19/34] feat: add LDCast two-stage training experiment - Add LDCastTrainingExperiment dataclass and ldcast_training_experiment() config - Add LDCast-specific validation in consistency_checks.py - Freeze autoencoder on fit/val/test/predict start in LatentDiffusionTaskModule - Export LDCast config from mlcast.config - Add identity-sharing and sequencing tests - Add train_from_config test for LDCast - Document DMI alignment differences in plan --- ldcast-refactor-plan.md | 85 ++++++++++++-- src/mlcast/config/__init__.py | 3 + 
src/mlcast/config/consistency_checks.py | 71 ++++++++++-- src/mlcast/config/ldcast.py | 148 ++++++++++++++++++++++++ src/mlcast/modules/forecasting.py | 15 +++ tests/config/test_ldcast_experiment.py | 46 ++++++++ tests/config/test_orchestrator.py | 11 +- 7 files changed, 360 insertions(+), 19 deletions(-) create mode 100644 src/mlcast/config/ldcast.py create mode 100644 tests/config/test_ldcast_experiment.py diff --git a/ldcast-refactor-plan.md b/ldcast-refactor-plan.md index 382bcd9..06e0d49 100644 --- a/ldcast-refactor-plan.md +++ b/ldcast-refactor-plan.md @@ -3,7 +3,7 @@ 0. Config naming and CLI contract - [x] Rename `training_experiment` to `convgru_training_experiment`. - [x] Do not keep `training_experiment` as an alias. -- [ ] Reserve `ldcast_training_experiment` as the top-level config name for the new two-stage LDCast workflow. +- [x] Reserve `ldcast_training_experiment` as the top-level config name for the new two-stage LDCast workflow. - [x] Require `mlcast train` users to provide an explicit base config via `--config config:` or `--config /path/to/config.yaml`. - [x] Update CLI help text to list the included config entry points explicitly. - [x] Update all docs, examples, tests, and scripts to use `convgru_training_experiment` instead of `training_experiment`. @@ -68,22 +68,85 @@ - [x] Keep `modules/` for task-level Lightning modules only; keep `models/` for pure architectures. 6. Training experiment -- [ ] Add a new LDCast-specific training module containing `LDCastTrainingExperiment`. +- [x] Add a new LDCast-specific training module containing `LDCastTrainingExperiment`. - [x] Keep `convgru_training_experiment` as the existing ConvGRU forecasting example and one of the explicitly selected included CLI configs. -- [ ] Stage 1 builds the reconstruction dataset, autoencoder model, and `ReconstructionTaskModule`, then trains the autoencoder. 
-- [ ] Stage 2 reuses the same trained in-memory autoencoder instance, builds the diffusion dataset/model/`LatentDiffusionTaskModule`, then trains latent diffusion. -- [ ] Stage 2 freezes the reused autoencoder parameters and optimizes only the latent diffusion task module's diffusion-network parameters. -- [ ] The shared Fiddle graph should define the autoencoder once and reference the same object in both stages, but no unresolved Fiddle objects should flow into actual `torch.nn.Module.__init__` calls. -- [ ] Stage-2 diffusion training uses the trained encoder to produce input and target latents; the trained decoder is retained for final forecast decoding but is not used in the stage-2 diffusion loss. -- [ ] Reuse the same forecasting dataset abstraction in stage 2; do not add a separate latent dataset layer. -- [ ] Add tests for shared object identity and stage sequencing. +- [x] Stage 1 builds the reconstruction dataset, autoencoder model, and `ReconstructionTaskModule`, then trains the autoencoder. +- [x] Stage 2 reuses the same trained in-memory autoencoder instance, builds the diffusion dataset/model/`LatentDiffusionTaskModule`, then trains latent diffusion. +- [x] Stage 2 freezes the reused autoencoder parameters and optimizes only the latent diffusion task module's diffusion-network parameters. +- [x] The shared Fiddle graph should define the autoencoder once and reference the same object in both stages, but no unresolved Fiddle objects should flow into actual `torch.nn.Module.__init__` calls. +- [x] Stage-2 diffusion training uses the trained encoder to produce input and target latents; the trained decoder is retained for final forecast decoding but is not used in the stage-2 diffusion loss. +- [x] Reuse the same forecasting dataset abstraction in stage 2; do not add a separate latent dataset layer. +- [x] Add tests for shared object identity and stage sequencing. 7. 
Audit and migration targets - [x] Update CLI/help text in `src/mlcast/__main__.py` to require an explicit base config and list the included config entry points. - [x] Rename `training_experiment` to `convgru_training_experiment` in `src/mlcast/config/base.py` and export it from `src/mlcast/config/__init__.py`. -- [ ] Add the LDCast config entry point to `src/mlcast/config/__init__.py` alongside the existing ConvGRU example config. -- [ ] Keep `src/mlcast/config/orchestrator.py` compatible with both the existing single-stage `Experiment` and the new `LDCastTrainingExperiment` through a common `run()` surface. +- [x] Add the LDCast config entry point to `src/mlcast/config/__init__.py` alongside the existing ConvGRU example config. +- [x] Keep `src/mlcast/config/orchestrator.py` compatible with both the existing single-stage `Experiment` and the new `LDCastTrainingExperiment` through a common `run()` surface. - [x] Update docstrings and comments that currently imply `training_experiment` is the only experiment, including `src/mlcast/data/source_data_datamodule.py`, `src/mlcast/config/orchestrator.py`, and related config docs. - [x] Update docs and scripts that still reference `training_experiment`, including `README.md` and `docs/generate_base_experiment_config_diagram.py`. - [x] Keep existing ConvGRU CLI/config tests passing while adding separate tests for selecting the LDCast config explicitly. - [ ] Add real but small-scale end-to-end tests with generated sample data for the autoencoder stage, diffusion stage, and full LDCast stage sequencing. + +## DMI alignment notes + +The `ldcast-dmi/` reference implementation differs from our current +`ldcast_training_experiment` config in several ways. Changes below would +align us more closely with DMI. + +### Optimizer +- **DMI**: `AdamW` with `lr=1e-3` (autoenc) / `1e-4` (diffusion), + `betas=(0.5, 0.9)`, `weight_decay=1e-3` for **both** stages. 
+- **Ours**: `Adam` with `lr=1e-4` for both stages, default betas, no + weight decay. +- **To align**: switch to `AdamW`, use DMI betas/weight_decay, and raise + autoencoder LR to `1e-3`. + +### LR scheduler +- **DMI**: `ReduceLROnPlateau(factor=0.25, patience=3)`, monitors + `val_rec_loss` (autoenc) / `val_loss_ema` (diffusion). +- **Ours**: `ReduceLROnPlateau(factor=0.5, patience=10)`, monitors + `val_loss` for both stages. +- **To align**: reduce factor to `0.25` and patience to `3`; use separate + monitor metrics per stage (autoenc → `val_rec_loss`, diffusion → `val_loss_ema`). + +### Learning rate warmup +- **DMI**: Linear warmup support in diffusion stage (`lr_warmup`, default + 0 — disabled). Autoencoder has none. +- **Ours**: No warmup in either stage. +- **To align**: no change needed unless LR warmup is desired. + +### EMA +- **DMI**: `LitEma` with `decay=0.9999` (adaptive based on num_updates), + only on diffusion model weights. EMA weights swapped in for + validation/testing. +- **Ours**: `ExponentialMovingAverage` with `decay=0.999` for diffusion + net, swapped in for val/test. +- **To align**: increase EMA decay to `0.9999`. + +### Early stopping +- **DMI**: patience `6`, monitors `val_rec_loss` / `val_loss_ema`, + `check_finite=False` on diffusion. +- **Ours**: patience `20`, monitors `val_loss`. +- **To align**: reduce patience to `6`; consider `check_finite=False`. + +### Model checkpointing +- **DMI**: `save_top_k=3`, monitors `val_rec_loss` / `val_loss_ema`. +- **Ours**: `save_top_k=1`, monitors `val_loss`. +- **To align**: increase save_top_k to `3`. + +### Diffusion noise schedule +- **DMI**: `timesteps=1000`, linear beta schedule from `1e-4` to `2e-2`. +- **Ours**: `timesteps=20`, default linear schedule. +- **To align**: increase to `timesteps=1000` and match beta range. + +### Batch size and gradient accumulation +- **DMI**: `batch_size=4` (autoenc, example) / `batch_size=1` (diffusion, + example); `accumulate_grad_batches=2`.
+- **Ours**: `batch_size=16` / `8`; no gradient accumulation. +- **To align**: reduce batch sizes and add `accumulate_grad_batches=2`. + +### DDP strategy +- **DMI**: `DDPStrategy(find_unused_parameters=True)` on autoencoder. +- **Ours**: default (no `DDPStrategy`). +- **To align**: no change needed unless running DDP. diff --git a/src/mlcast/config/__init__.py b/src/mlcast/config/__init__.py index 9338dbd..e566eb5 100644 --- a/src/mlcast/config/__init__.py +++ b/src/mlcast/config/__init__.py @@ -14,12 +14,15 @@ use_random_sampler, use_ratio_splits, ) +from .ldcast import LDCastTrainingExperiment, ldcast_training_experiment from .loader import load_yaml_config from .orchestrator import train_from_config __all__ = [ "Experiment", + "LDCastTrainingExperiment", "convgru_training_experiment", + "ldcast_training_experiment", "validate_config", "train_from_config", "load_yaml_config", diff --git a/src/mlcast/config/consistency_checks.py b/src/mlcast/config/consistency_checks.py index 22363f8..bd4b424 100644 --- a/src/mlcast/config/consistency_checks.py +++ b/src/mlcast/config/consistency_checks.py @@ -15,13 +15,13 @@ from loguru import logger -def validate_config(cfg: fdl.Config) -> None: - """Validate cross-system constraints on a Fiddle configuration before training. +def _validate_forecasting_experiment_cfg(cfg: fdl.Config) -> None: + """Validate a single-stage forecasting experiment configuration. Parameters ---------- cfg : fdl.Config - Fiddle configuration. + Fiddle configuration for a single forecasting experiment. Raises ------ @@ -34,8 +34,6 @@ def validate_config(cfg: fdl.Config) -> None: data = cfg.data # Contract 1: Network input_channels == len(sequence_dataset_factory.standard_names) - # If the network does not expose input_channels, emit a warning because - # this contract cannot be checked. 
num_vars = len(sequence_dataset_factory.standard_names) try: net_input_channels = network.input_channels @@ -53,8 +51,6 @@ def validate_config(cfg: fdl.Config) -> None: ) # Contract 2: Sequence dataset width must be divisible by 2 ** network.num_blocks - # If the network does not expose num_blocks, emit a warning because this - # contract cannot be checked. try: num_blocks = network.num_blocks except AttributeError: @@ -120,3 +116,64 @@ def validate_config(cfg: fdl.Config) -> None: f"Contract 6 violated: network forecast_steps ({net_forecast_steps}) " f"must equal data.forecast_steps ({data.forecast_steps})." ) + + +def _validate_ldcast_training_experiment_cfg(cfg: fdl.Config) -> None: + """Validate a two-stage LDCast training experiment configuration. + + Parameters + ---------- + cfg : fdl.Config + Fiddle configuration for a two-stage LDCast experiment. + + Raises + ------ + ValueError + If any LDCast-specific configuration contract is violated. + """ + stage1 = cfg.stage1 + stage2 = cfg.stage2 + + autoencoder = stage1.pl_module.network + if autoencoder is not stage2.pl_module.autoencoder: + raise ValueError("LDCast contract violated: stage1 and stage2 must share the same autoencoder config object.") + + stage1_data = stage1.data + stage2_data = stage2.data + if stage1_data.input_steps != stage2_data.input_steps: + raise ValueError( + "LDCast contract violated: stage1 and stage2 must use the same input_steps; " + f"got {stage1_data.input_steps} and {stage2_data.input_steps}." + ) + + stage2_module = stage2.pl_module + if stage2_data.forecast_steps != stage2_module.forecast_steps: + raise ValueError( + "LDCast contract violated: stage2 data.forecast_steps must match the latent diffusion task module; " + f"got {stage2_data.forecast_steps} and {stage2_module.forecast_steps}." 
+ ) + + if len(stage1_data.sequence_dataset_factory.standard_names) != autoencoder.encoder.input_channels: + raise ValueError( + "LDCast contract violated: autoencoder encoder input_channels must match the number of source variables." + ) + + +def validate_config(cfg: fdl.Config) -> None: + """Validate cross-system constraints on a Fiddle configuration before training. + + Parameters + ---------- + cfg : fdl.Config + Fiddle configuration. + + Raises + ------ + ValueError + If any configuration contract is violated. + """ + if hasattr(cfg, "stage1") and hasattr(cfg, "stage2"): + _validate_ldcast_training_experiment_cfg(cfg) + return + + _validate_forecasting_experiment_cfg(cfg) diff --git a/src/mlcast/config/ldcast.py b/src/mlcast/config/ldcast.py new file mode 100644 index 0000000..ff01c5d --- /dev/null +++ b/src/mlcast/config/ldcast.py @@ -0,0 +1,148 @@ +"""Fiddle configuration for two-stage LDCast training.""" + +from dataclasses import dataclass + +import fiddle as fdl +import fiddle.experimental.auto_config +import pytorch_lightning as pl +import torch +from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger + +from mlcast.config.base import Experiment +from mlcast.data.datamodules import ForecastingDataModule, ReconstructionDataModule +from mlcast.data.sequence import SourceDataPrecomputedSequenceDataset +from mlcast.models.autoencoder import AutoencoderNet, Decoder, Encoder +from mlcast.models.diffusion import ConditionerNet, DenoiserUNet, DiffusionScheduler, LatentDiffusionNet +from mlcast.modules.forecasting import LatentDiffusionTaskModule +from mlcast.modules.reconstruction import ReconstructionTaskModule + + +@dataclass +class LDCastTrainingExperiment: + """Two-stage LDCast training experiment. + + Parameters + ---------- + stage1 : Experiment + Reconstruction training stage for the autoencoder. 
+ stage2 : Experiment + Latent diffusion training stage reusing the same trained autoencoder + instance from stage 1. + """ + + stage1: Experiment + stage2: Experiment + + @property + def trainer(self) -> pl.Trainer: + """Expose the first trainer for orchestrator compatibility. + + Returns + ------- + pl.Trainer + Trainer used by stage 1. + """ + return self.stage1.trainer + + def run(self) -> None: + """Run stage-1 reconstruction followed by stage-2 latent diffusion.""" + self.stage1.trainer.fit(self.stage1.pl_module, datamodule=self.stage1.data) + self.stage1.trainer.test(self.stage1.pl_module, datamodule=self.stage1.data) + + self.stage2.trainer.fit(self.stage2.pl_module, datamodule=self.stage2.data) + self.stage2.trainer.test(self.stage2.pl_module, datamodule=self.stage2.data) + + +@fiddle.experimental.auto_config.auto_config +def ldcast_training_experiment() -> LDCastTrainingExperiment: + """Build a Fiddle config for two-stage LDCast training. + + Returns + ------- + LDCastTrainingExperiment + Configured two-stage experiment with shared autoencoder identity across + reconstruction and latent diffusion stages. 
+ """ + input_steps = 4 + forecast_steps = 12 + sequence_steps = input_steps + forecast_steps + + sequence_dataset_factory = fdl.Partial( + SourceDataPrecomputedSequenceDataset, + zarr_path="./data/radar.zarr", + csv_path="./data/sampled_datacubes.csv", + standard_names=["rainfall_rate"], + sequence_steps=sequence_steps, + deterministic=False, + ) + + autoencoder = AutoencoderNet( + encoder=Encoder(input_channels=1, hidden_channels=16, latent_channels=32, num_blocks=2), + decoder=Decoder(output_channels=1, hidden_channels=16, latent_channels=32, num_blocks=2), + ) + + stage1_data = ReconstructionDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=input_steps, + splits={"time": {"train": 0.70, "val": 0.15, "test": 0.15}}, + batch_size=16, + num_workers=8, + pin_memory=True, + ) + stage1_module = ReconstructionTaskModule( + network=autoencoder, + loss_class="mse", + optimizer=fdl.Partial(torch.optim.Adam, lr=1e-4, fused=True), + lr_scheduler=fdl.Partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode="min", factor=0.5, patience=10), + ) + stage1_trainer = pl.Trainer( + accelerator="auto", + max_epochs=20, + callbacks=[ + ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min"), + EarlyStopping(monitor="val_loss", patience=20, mode="min"), + LearningRateMonitor(logging_interval="step"), + ], + logger=TensorBoardLogger(save_dir="logs", name="mlcast_ldcast_stage1"), + ) + + stage2_data = ForecastingDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=input_steps, + forecast_steps=forecast_steps, + return_mask=False, + splits={"time": {"train": 0.70, "val": 0.15, "test": 0.15}}, + batch_size=8, + num_workers=8, + pin_memory=True, + ) + diffusion_net = LatentDiffusionNet( + conditioner=ConditionerNet(latent_channels=32, hidden_channels=32, num_blocks=2), + denoiser=DenoiserUNet(latent_channels=32, condition_channels=32, hidden_channels=32, num_blocks=2), + scheduler=DiffusionScheduler(timesteps=20), + ) + 
stage2_module = LatentDiffusionTaskModule( + autoencoder=autoencoder, + diffusion_net=diffusion_net, + forecast_steps=forecast_steps, + ensemble_size=2, + optimizer=fdl.Partial(torch.optim.Adam, lr=1e-4, fused=True), + lr_scheduler=fdl.Partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode="min", factor=0.5, patience=10), + ema_decay=0.999, + ) + stage2_trainer = pl.Trainer( + accelerator="auto", + max_epochs=20, + callbacks=[ + ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min"), + EarlyStopping(monitor="val_loss", patience=20, mode="min"), + LearningRateMonitor(logging_interval="step"), + ], + logger=TensorBoardLogger(save_dir="logs", name="mlcast_ldcast_stage2"), + ) + + return LDCastTrainingExperiment( + stage1=Experiment(pl_module=stage1_module, data=stage1_data, trainer=stage1_trainer), + stage2=Experiment(pl_module=stage2_module, data=stage2_data, trainer=stage2_trainer), + ) diff --git a/src/mlcast/modules/forecasting.py b/src/mlcast/modules/forecasting.py index 96060cd..c27b321 100644 --- a/src/mlcast/modules/forecasting.py +++ b/src/mlcast/modules/forecasting.py @@ -507,6 +507,14 @@ def __init__( self.sampler = DiffusionSampler(diffusion_net) self.ema = ExponentialMovingAverage(diffusion_net, decay=ema_decay) if ema_decay is not None else None + def _freeze_autoencoder(self) -> None: + """Freeze the reused autoencoder before stage-2 use. + + The same autoencoder instance is shared with stage-1 reconstruction + training, so freezing must happen when the diffusion stage begins rather + than in ``__init__``. 
+ """ + self.autoencoder.eval() for parameter in self.autoencoder.parameters(): parameter.requires_grad = False @@ -594,8 +602,13 @@ def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None: if self.ema is not None: self.ema.update() + def on_fit_start(self) -> None: + """Freeze the reused autoencoder before diffusion training starts.""" + self._freeze_autoencoder() + def on_validation_start(self) -> None: """Swap EMA weights in before validation when enabled.""" + self._freeze_autoencoder() if self.ema is not None: self.ema.apply() @@ -606,6 +619,7 @@ def on_validation_end(self) -> None: def on_test_start(self) -> None: """Swap EMA weights in before testing when enabled.""" + self._freeze_autoencoder() if self.ema is not None: self.ema.apply() @@ -616,6 +630,7 @@ def on_test_end(self) -> None: def on_predict_start(self) -> None: """Swap EMA weights in before prediction when enabled.""" + self._freeze_autoencoder() if self.ema is not None: self.ema.apply() diff --git a/tests/config/test_ldcast_experiment.py b/tests/config/test_ldcast_experiment.py new file mode 100644 index 0000000..b784a65 --- /dev/null +++ b/tests/config/test_ldcast_experiment.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass + +import fiddle as fdl + +from mlcast.config import LDCastTrainingExperiment, ldcast_training_experiment, validate_config +from mlcast.config.base import Experiment + + +@dataclass +class RecordingTrainer: + """Minimal trainer stub that records fit/test call order.""" + + events: list[str] + + def fit(self, pl_module, datamodule=None) -> None: # type: ignore[no-untyped-def] + self.events.append(f"fit:{pl_module}:{datamodule}") + + def test(self, pl_module, datamodule=None) -> None: # type: ignore[no-untyped-def] + self.events.append(f"test:{pl_module}:{datamodule}") + + +def test_ldcast_training_experiment_runs_stages_in_order() -> None: + """LDCastTrainingExperiment should execute stage 1 fully before stage 2.""" + events: list[str] = [] + stage1 
= Experiment(pl_module="stage1_module", data="stage1_data", trainer=RecordingTrainer(events=events)) + stage2 = Experiment(pl_module="stage2_module", data="stage2_data", trainer=RecordingTrainer(events=events)) + experiment = LDCastTrainingExperiment(stage1=stage1, stage2=stage2) + + experiment.run() + + assert events == [ + "fit:stage1_module:stage1_data", + "test:stage1_module:stage1_data", + "fit:stage2_module:stage2_data", + "test:stage2_module:stage2_data", + ] + + +def test_ldcast_training_experiment_shares_autoencoder_identity() -> None: + """Stage 1 and stage 2 should reference the same built autoencoder instance.""" + cfg = ldcast_training_experiment.as_buildable() + validate_config(cfg) + + experiment = fdl.build(cfg) + + assert experiment.stage1.pl_module.network is experiment.stage2.pl_module.autoencoder diff --git a/tests/config/test_orchestrator.py b/tests/config/test_orchestrator.py index cf17e8f..7cdaa98 100644 --- a/tests/config/test_orchestrator.py +++ b/tests/config/test_orchestrator.py @@ -2,7 +2,7 @@ from typing import Any from unittest.mock import patch -from mlcast.config import convgru_training_experiment, train_from_config +from mlcast.config import convgru_training_experiment, ldcast_training_experiment, train_from_config @patch("mlcast.config.orchestrator.fdl.build") @@ -12,3 +12,12 @@ def test_train_from_config_valid(mock_build: Any, tmp_path: Path) -> None: cfg = convgru_training_experiment.as_buildable() train_from_config(cfg) mock_build.assert_called_once() + + +@patch("mlcast.config.orchestrator.fdl.build") +def test_train_from_config_valid_ldcast(mock_build: Any, tmp_path: Path) -> None: + """Verify that a valid LDCast configuration passes validation and builds.""" + mock_build.return_value.trainer.log_dir = str(tmp_path) + cfg = ldcast_training_experiment.as_buildable() + train_from_config(cfg) + mock_build.assert_called_once() From 9363c0406ea3412f24a1010ff3e19da9f1ef7ea9 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 4 Jun 
2026 00:10:49 +0200 Subject: [PATCH 20/34] docs: add LDCast architecture docs, config diagram, and implementation comparison - Add LDCast section to README covering autoencoder, latent diffusion, and two-stage training experiment - Add ldcast_config_diagram.svg (Fiddle config graph render) - Add generate_ldcast_config_diagram.py script - Add pre-commit hook to keep LDCast diagram in sync - Add Martinbo alignment comparison notes to ldcast-refactor-plan.md --- .pre-commit-config.yaml | 10 +- README.md | 169 ++++- docs/generate_ldcast_config_diagram.py | 52 ++ docs/ldcast_config_diagram.svg | 989 +++++++++++++++++++++++++ ldcast-refactor-plan.md | 81 ++ 5 files changed, 1283 insertions(+), 18 deletions(-) create mode 100644 docs/generate_ldcast_config_diagram.py create mode 100644 docs/ldcast_config_diagram.svg diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6aafae4..40413cf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,12 +28,18 @@ repos: - id: pyproject-fmt - repo: local hooks: - - id: config-diagram - name: config diagram up to date + - id: config-diagram-convgru + name: convgru config diagram up to date language: system entry: uv run python docs/generate_base_experiment_config_diagram.py --check files: ^src/mlcast/config/base\.py$ pass_filenames: false + - id: config-diagram-ldcast + name: ldcast config diagram up to date + language: system + entry: uv run python docs/generate_ldcast_config_diagram.py --check + files: ^src/mlcast/config/ldcast\.py$ + pass_filenames: false ci: autoupdate_schedule: monthly autoupdate_commit_msg: "chore(deps): pre-commit.ci autoupdate" diff --git a/README.md b/README.md index 5b58d83..4ee858b 100644 --- a/README.md +++ b/README.md @@ -67,12 +67,18 @@ reproduce runs exactly from a saved YAML file. 
### Configuration model -Training in mlcast is currently built around the included configuration -function, [`convgru_training_experiment`](src/mlcast/config/base.py), which -defines the ConvGRU ensemble nowcasting setup: dataset, data module, network, -Lightning module, and trainer. Rather than writing a new config from scratch, -the intended workflow is to start from this config and apply targeted -modifications: +mlcast ships with two included configuration functions: + +- [`convgru_training_experiment`](src/mlcast/config/base.py) — defines a + single-stage ConvGRU ensemble nowcasting setup (dataset, data module, network, + Lightning module, trainer). +- [`ldcast_training_experiment`](src/mlcast/config/ldcast.py) — defines a + two-stage LDCast setup: stage 1 trains an autoencoder on reconstruction + windows, stage 2 trains a latent diffusion model on the same autoencoder's + latent space. + +Rather than writing a new config from scratch, the intended workflow is to +start from one of these configs and apply targeted modifications: - **`set:` overrides** — change a single scalar parameter (e.g. batch size, learning rate, number of epochs) @@ -86,6 +92,16 @@ Any combination of these can be layered on top of the selected config, and the fully resolved config is always saved to YAML alongside the training logs so runs can be reproduced exactly. +The diagrams below show the full included config graphs. + +**convgru_training_experiment:** + +![convgru_training_experiment config graph](docs/config_diagram.svg) + +**ldcast_training_experiment:** + +![ldcast_training_experiment config graph](docs/ldcast_config_diagram.svg) + ### Design roles mlcast separates pure architectures from task-level training wrappers. @@ -109,25 +125,23 @@ owns the trained autoencoder reuse policy, decides that only diffusion-network parameters are optimized, computes diffusion loss in latent space, and handles decoded forecast inference. 
-The diagram below shows the full included ConvGRU config graph as built by -[`convgru_training_experiment`](src/mlcast/config/base.py): - -![convgru_training_experiment config graph](docs/config_diagram.svg) - ### Command-line interface Install the package and run: ```bash +# Single-stage ConvGRU nowcasting mlcast train --config config:convgru_training_experiment + +# Two-stage LDCast latent diffusion +mlcast train --config config:ldcast_training_experiment ``` -This trains with the built-in [`convgru_training_experiment`](src/mlcast/config/base.py) config. All parameters -are controlled via `--config` flags: +All parameters are controlled via `--config` flags: | Prefix | Purpose | Example | |--------|---------|---------| -| `config:` | Select an included `@auto_config` function | `--config config:convgru_training_experiment` | +| `config:` | Select an included `@auto_config` function | `--config config:convgru_training_experiment` or `--config config:ldcast_training_experiment` | | `set:` | Override a single parameter | `--config set:data.batch_size=32` | | `fiddler:` | Apply a semantic mutator (multi-param change) | `--config fiddler:use_random_sampler` | | `path/to/config.yaml` | Load a previously saved config | `--config saved.yaml` | @@ -154,6 +168,11 @@ mlcast train \ --config logs/mlcast/version_0/config.yaml \ --config set:trainer.max_epochs=50 +# Run two-stage LDCast training with a shorter diffusion schedule +mlcast train \ + --config config:ldcast_training_experiment \ + --config set:stage2.pl_module.diffusion_net.scheduler.timesteps=20 + # Inspect the fully resolved config without starting training mlcast train --config config:convgru_training_experiment --config fiddler:use_random_sampler --print_config_and_exit ``` @@ -186,6 +205,26 @@ cfg.trainer.max_epochs = 50 train_from_config(cfg) ``` +**Run the included LDCast experiment with tweaks:** + +```python +import fiddle as fdl +from mlcast.config import ldcast_training_experiment, train_from_config 
+from mlcast.config.fiddlers import use_random_sampler + +cfg = ldcast_training_experiment.as_buildable() + +# Both stages share the same dataset, so switch to random sampling once +use_random_sampler(cfg.stage1.data) +use_random_sampler(cfg.stage2.data) + +# Override the diffusion noise schedule +cfg.stage2.pl_module.diffusion_net.scheduler.timesteps = 20 + +# train_from_config applies to the full two-stage experiment +train_from_config(cfg) +``` + **Custom network architecture:** You can swap in any architecture by replacing `cfg.pl_module.network` with a @@ -304,6 +343,7 @@ mlcast/ │ ├── visualization.py # TensorBoard image logging helpers │ ├── config/ │ │ ├── base.py # ConvGRU training config @auto_config +│ │ ├── ldcast.py # LDCast two-stage config @auto_config │ │ ├── fiddlers.py # Semantic config mutators │ │ ├── consistency_checks.py # Cross-parameter validation │ │ ├── loader.py # YAML config loader @@ -314,8 +354,23 @@ mlcast/ │ │ ├── forecasting.py # Forecasting task dataset wrapper │ │ ├── reconstruction.py # Reconstruction task dataset wrapper │ │ └── normalization.py # Normalisation registry -│ └── models/ -│ └── convgru.py # ConvGRU encoder-decoder +│ ├── models/ +│ │ ├── convgru.py # ConvGRU encoder-decoder +│ │ ├── autoencoder/ +│ │ │ ├── encoder.py # Encoder +│ │ │ ├── decoder.py # Decoder +│ │ │ └── net.py # AutoencoderNet composition +│ │ └── diffusion/ +│ │ ├── conditioner.py # ConditionerNet (context builder) +│ │ ├── denoiser.py # DenoiserUNet +│ │ ├── scheduler.py # Diffusion noise scheduler +│ │ ├── sampler.py # Inference-time sampling loop +│ │ ├── ema.py # EMA weight tracking +│ │ ├── loss.py # Diffusion loss +│ │ └── net.py # LatentDiffusionNet composition +│ └── modules/ +│ ├── forecasting.py # Base + OutputSpace + LatentDiffusion task modules +│ └── reconstruction.py # ReconstructionTaskModule ├── tests/ ├── pyproject.toml └── README.md @@ -360,6 +415,88 @@ concatenated along the channel dimension. 
![ConvGruModel stochastic architecture](docs/architectures/convgru-stochastic.png) +### LatentDiffusionNet (LDCast) + +LDCast is a **two-stage** latent diffusion nowcasting system. Stage 1 trains an +autoencoder on reconstruction windows; stage 2 trains a latent diffusion model +that forecasts in the autoencoder's latent space and decodes forecasts back to +data space. + +The architecture components live under `src/mlcast/models/autoencoder/` and +`src/mlcast/models/diffusion/`. The task-level Lightning modules live under +`src/mlcast/modules/` and are wired together by +[`ldcast_training_experiment`](src/mlcast/config/ldcast.py). + +#### Stage 1 — Autoencoder reconstruction + +The autoencoder is built from an +[`Encoder`](src/mlcast/models/autoencoder/encoder.py) and +[`Decoder`](src/mlcast/models/autoencoder/decoder.py), composed by +[`AutoencoderNet`](src/mlcast/models/autoencoder/net.py). + +- **Encoder** — a stack of `EncoderBlock` layers. Each block downsamples + spatial resolution via strided 3D convolution and doubles the channel count. + The final output is a latent tensor with shape + `(batch, latent_channels, time, latent_height, latent_width)`. +- **Decoder** — a stack of `DecoderBlock` layers that mirror the encoder. Each + block upsamples spatial resolution via transposed 3D convolution and halves + the channel count, reconstructing the original input shape. + +The autoencoder is trained on overlapping temporal windows via +[`ReconstructionDataset`](src/mlcast/data/reconstruction.py) and +[`ReconstructionDataModule`](src/mlcast/data/datamodules.py). The +[`ReconstructionTaskModule`](src/mlcast/modules/reconstruction.py) optimises +the full autoencoder parameters against an MSE reconstruction loss. 
+ +#### Stage 2 — Latent diffusion forecasting + +The latent diffusion model is built from a +[`ConditionerNet`](src/mlcast/models/diffusion/conditioner.py), +[`DenoiserUNet`](src/mlcast/models/diffusion/denoiser.py), and +[`DiffusionScheduler`](src/mlcast/models/diffusion/scheduler.py), composed by +[`LatentDiffusionNet`](src/mlcast/models/diffusion/net.py). + +- **ConditionerNet** — projects encoded input-history latents through a series + of residual 3D convolution blocks to produce a conditioning context for the + denoiser U-Net. This answers "what did the recent past look like in latent + space?" +- **DenoiserUNet** — a timestep-aware U-Net with 3D convolutions over the + latent spatial dimensions (time dimension is preserved). It receives the + noisy target latent, a diffusion timestep embedding (sinusoidal), and the + conditioning context from the conditioner. The U-Net predicts the additive + noise (`eps` parameterization) that was applied to reach the current + timestep. +- **DiffusionScheduler** — defines the forward diffusion noise schedule + (linear beta schedule by default) and provides the pre-computed alpha/beta + buffers used by the forward and reverse processes. + +Training uses a standard MSE diffusion loss (`DiffusionLoss` in +`src/mlcast/models/diffusion/loss.py`): for each batch the input is encoded +with the trained (frozen) encoder, the target is encoded with the same encoder, +a random timestep is drawn per sample, noise is added to the target latents, +and the denoiser is trained to predict the added noise. + +Inference uses a [`DiffusionSampler`](src/mlcast/models/diffusion/sampler.py) +to progressively denoise random latents conditioned on encoded input history. +The reverse diffusion loop steps backward through the schedule, and the final +denoised latent is decoded back to data space by the trained decoder. When +`ensemble_size > 1`, the process is repeated with fresh noise and the results +are concatenated. 
+ +#### Two-stage training experiment + +The [`ldcast_training_experiment`](src/mlcast/config/ldcast.py) auto-config +orchestrates both stages: + +- Stage 1 builds a `ReconstructionDataModule`, `AutoencoderNet`, and + `ReconstructionTaskModule`, then calls `trainer.fit() + trainer.test()`. +- Stage 2 reuses the **same trained autoencoder instance** (Fiddle graph + identity sharing), builds a `ForecastingDataModule` and + `LatentDiffusionTaskModule`, then calls `trainer.fit() + trainer.test()`. +- The stage-2 module freezes the autoencoder on `fit_start` and optimises only + the diffusion-network parameters. + + ### Custom network interface Any network architecture can be used by replacing `cfg.pl_module.network` diff --git a/docs/generate_ldcast_config_diagram.py b/docs/generate_ldcast_config_diagram.py new file mode 100644 index 0000000..a5686e7 --- /dev/null +++ b/docs/generate_ldcast_config_diagram.py @@ -0,0 +1,52 @@ +"""Generate a Graphviz SVG diagram of the included LDCast training config. 
+ +Run without arguments to regenerate docs/ldcast_config_diagram.svg: + + uv run python docs/generate_ldcast_config_diagram.py + +Run with --check to verify the diagram is up to date: + + uv run python docs/generate_ldcast_config_diagram.py --check +""" + +import argparse +import sys +from pathlib import Path + +import fiddle.graphviz as fgv + +from mlcast.config import ldcast_training_experiment + +OUT = Path(__file__).parent / "ldcast_config_diagram.svg" + + +def main() -> None: + """Generate or verify the LDCast training config diagram.""" + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--check", + action="store_true", + help="Check that the diagram is up to date rather than regenerating it.", + ) + args = parser.parse_args() + + cfg = ldcast_training_experiment.as_buildable() + g = fgv.render(cfg, max_str_length=40) + g.format = "svg" + new_svg = g.pipe().decode() + + if args.check: + if not OUT.exists() or OUT.read_text() != new_svg: + print( + "docs/ldcast_config_diagram.svg is out of date.\n" + "Run: uv run python docs/generate_ldcast_config_diagram.py" + ) + sys.exit(1) + print("docs/ldcast_config_diagram.svg is up to date.") + else: + OUT.write_text(new_svg) + print(f"Written {OUT}") + + +if __name__ == "__main__": + main() diff --git a/docs/ldcast_config_diagram.svg b/docs/ldcast_config_diagram.svg new file mode 100644 index 0000000..59803a2 --- /dev/null +++ b/docs/ldcast_config_diagram.svg @@ -0,0 +1,989 @@ + + + + + + +%3 + + + +4 + + +Config: + Encoder + + +input_channels + +1 + + +hidden_channels + +16 + + +latent_channels + +32 + + +num_blocks + +2 + + + +3 + + +Config: + AutoencoderNet + + +encoder + + + + + +decoder + + + + + + +3:c--4:c + + + + +5 + + +Config: + Decoder + + +output_channels + +1 + + +hidden_channels + +16 + + +latent_channels + +32 + + +num_blocks + +2 + + + +3:c--5:c + + + + +2 + + +Config: + ReconstructionTaskModule + + +network + + + + + +loss_class + +'mse' + + +optimizer + + + + + 
+lr_scheduler + + + + + + +2:c--3:c + + + + +6 + + +Partial: + Adam + + +lr + +0.0001 + + +fused + +True + + + +2:c--6:c + + + + +7 + + +Partial: + ReduceLROnPlateau + + +mode + +'min' + + +factor + +0.5 + + +patience + +10 + + + +2:c--7:c + + + + +1 + + +Config: + Experiment + + +pl_module + + + + + +data + + + + + +trainer + + + + + + +1:c--2:c + + + + +8 + + +Config: + ReconstructionDataModule + + +sequence_dataset_factory + + + + + +input_steps + +4 + + +splits + + + +dict + + +'time' + + + +dict + + +'train' + +0.7 + + +'val' + +0.15 + + +'test' + +0.15 + + +batch_size + +16 + + +num_workers + +8 + + +pin_memory + +True + + + +1:c--8:c + + + + +10 + + +Config: + Trainer + + +accelerator + +'auto' + + +logger + + + + + +callbacks + + + +list + + + + + + + + + + + +0 + + +1 + + +2 + + +max_epochs + +20 + + + +1:c--10:c + + + + +9 + + +Partial: + SourceDataPrecomputedSequenceDataset + + +zarr_path + +'./data/radar.zarr' + + +csv_path + +'./data/sampled_datacubes.csv' + + +standard_names + + + +list + +'rainfall_rate' + + +0 + + +sequence_steps + +16 + + +deterministic + +False + + + +8:c--9:c + + + + +11 + + +Config: + TensorBoardLogger + + +save_dir + +'logs' + + +name + +'mlcast_ldcast_stage1' + + + +10:c--11:c + + + + +12 + + +Config: + ModelCheckpoint + + +monitor + +'val_loss' + + +save_top_k + +1 + + +mode + +'min' + + + +10:c--12:c + + + + +13 + + +Config: + EarlyStopping + + +monitor + +'val_loss' + + +patience + +20 + + +mode + +'min' + + + +10:c--13:c + + + + +14 + + +Config: + LearningRateMonitor + + +logging_interval + +'step' + + + +10:c--14:c + + + + +0 + + +Config: + LDCastTrainingExperiment + + +stage1 + + + + + +stage2 + + + + + + +0:c--1:c + + + + +15 + + +Config: + Experiment + + +pl_module + + + + + +data + + + + + +trainer + + + + + + +0:c--15:c + + + + +16 + + +Config: + LatentDiffusionTaskModule + + +autoencoder + + + + + +diffusion_net + + + + + +forecast_steps + +12 + + +ensemble_size + +2 + + +optimizer + + + + + +lr_scheduler + + + + + 
+ema_decay + +0.999 + + + +16:c--3:c + + + + +17 + + +Config: + LatentDiffusionNet + + +conditioner + + + + + +denoiser + + + + + +scheduler + + + + + + +16:c--17:c + + + + +21 + + +Partial: + Adam + + +lr + +0.0001 + + +fused + +True + + + +16:c--21:c + + + + +22 + + +Partial: + ReduceLROnPlateau + + +mode + +'min' + + +factor + +0.5 + + +patience + +10 + + + +16:c--22:c + + + + +18 + + +Config: + ConditionerNet + + +latent_channels + +32 + + +hidden_channels + +32 + + +num_blocks + +2 + + + +17:c--18:c + + + + +19 + + +Config: + DenoiserUNet + + +latent_channels + +32 + + +condition_channels + +32 + + +hidden_channels + +32 + + +num_blocks + +2 + + + +17:c--19:c + + + + +20 + + +Config: + DiffusionScheduler + + +timesteps + +20 + + + +17:c--20:c + + + + +15:c--16:c + + + + +23 + + +Config: + ForecastingDataModule + + +sequence_dataset_factory + + + + + +input_steps + +4 + + +forecast_steps + +12 + + +return_mask + +False + + +splits + + + +dict + + +'time' + + + +dict + + +'train' + +0.7 + + +'val' + +0.15 + + +'test' + +0.15 + + +batch_size + +8 + + +num_workers + +8 + + +pin_memory + +True + + + +15:c--23:c + + + + +24 + + +Config: + Trainer + + +accelerator + +'auto' + + +logger + + + + + +callbacks + + + +list + + + + + + + + + + + +0 + + +1 + + +2 + + +max_epochs + +20 + + + +15:c--24:c + + + + +23:c--9:c + + + + +25 + + +Config: + TensorBoardLogger + + +save_dir + +'logs' + + +name + +'mlcast_ldcast_stage2' + + + +24:c--25:c + + + + +26 + + +Config: + ModelCheckpoint + + +monitor + +'val_loss' + + +save_top_k + +1 + + +mode + +'min' + + + +24:c--26:c + + + + +27 + + +Config: + EarlyStopping + + +monitor + +'val_loss' + + +patience + +20 + + +mode + +'min' + + + +24:c--27:c + + + + +28 + + +Config: + LearningRateMonitor + + +logging_interval + +'step' + + + +24:c--28:c + + + + diff --git a/ldcast-refactor-plan.md b/ldcast-refactor-plan.md index 06e0d49..9ea0f61 100644 --- a/ldcast-refactor-plan.md +++ b/ldcast-refactor-plan.md @@ -150,3 +150,84 @@ align us 
more closely with DMI. - **DMI**: `DDPStrategy(find_unused_parameters=True)` on autoencoder. - **Ours**: default (no `DDPStrategy`). - **To align**: no change needed unless running DDP. + +## Martinbo alignment notes + +The `feat/ldcast-martinbo` branch differs from both our current config and +the DMI reference in several ways. + +### Optimizer +- **DMI**: `AdamW`, `lr=1e-3` / `1e-4`, `betas=(0.5, 0.9)`, `wd=1e-3`. +- **Martinbo**: `AdamW`, `lr=1e-3` / `1e-4`, `betas=[0.5, 0.9]`, `wd=0.001`. +- **Ours**: `Adam`, `lr=1e-4` for both, default betas, no weight decay. +- **To align**: Martinbo matches DMI exactly — `AdamW`, per-stage LR, betas, and wd. + +### LR scheduler +- **DMI**: `ReduceLROnPlateau(factor=0.25, patience=3)`, monitors + `val_rec_loss` / `val_loss_ema`. +- **Martinbo**: `ReduceLROnPlateau(factor=0.25, patience=3)`, monitors + `val/rec_loss` / `val/loss`. +- **Ours**: `ReduceLROnPlateau(factor=0.5, patience=10)`, monitors + `val_loss` for both stages. +- **To align**: Martinbo matches DMI's factor/patience; only monitor-metric + naming differs (`val/rec_loss` vs `val_rec_loss`). + +### Learning rate warmup +- **DMI**: Diffusion warmup support (`lr_warmup=0`, disabled by default). +- **Martinbo**: No warmup support in either stage. +- **Ours**: No warmup in either stage. +- **To align**: no change needed (DMI also has it disabled by default). + +### EMA +- **DMI**: `LitEma` with `decay=0.9999` (adaptive), on full diffusion model. +- **Martinbo**: `EMA` with `decay=0.9999` (dynamic, adaptive), wraps + **denoiser only** (`store_device='cuda'`). +- **Ours**: `ExponentialMovingAverage` with `decay=0.999`, on diffusion net. +- **To align**: increase decay to `0.9999`; consider whether EMA should wrap + the full diffusion net or just the denoiser. + +### Early stopping +- **DMI**: patience `6`, monitors `val_rec_loss` / `val_loss_ema`, + `check_finite=False` on diffusion. 
+- **Martinbo**: patience `6`, monitors `val/loss_epoch` (both stages), + `check_finite=False`. +- **Ours**: patience `20`, monitors `val_loss`. +- **To align**: Martinbo matches DMI's patience and `check_finite=False`; + monitor naming differs (`val/loss_epoch` vs `val_loss_ema`). + +### Model checkpointing +- **DMI**: `save_top_k=3`, monitors `val_rec_loss` / `val_loss_ema`. +- **Martinbo**: Not explicitly configured in branch `config.yaml` (relies on + Lightning default, `save_top_k=1`). +- **Ours**: `save_top_k=1`, monitors `val_loss`. +- **To align**: Martinbo implicitly matches Ours on `save_top_k`; DMI differs + with `save_top_k=3`. + +### Diffusion noise schedule +- **DMI**: `timesteps=1000`, linear beta `1e-4` to `2e-2`. +- **Martinbo**: `timesteps=1000`, linear beta `1e-4` to `2e-2` (defaults, + config section is `{}`). +- **Ours**: `timesteps=20`, default linear schedule. +- **To align**: Martinbo matches DMI exactly — `timesteps=1000`, same beta range. + +### Batch size and gradient accumulation +- **DMI**: `batch_size=4` / `1` (example configs), `accumulate_grad_batches=2`. +- **Martinbo**: `batch_size=1` for both stages; no `accumulate_grad_batches`. +- **Ours**: `batch_size=16` / `8`; no gradient accumulation. +- **To align**: Martinbo uses smaller batches than both DMI and Ours; none + of the three agree on batch size strategy. + +### DDP strategy +- **DMI**: `DDPStrategy(find_unused_parameters=True)` (autoenc) / + `DDPStrategy()` (diffusion). +- **Martinbo**: `strategy='ddp'` (string), `sync_batchnorm=True`, `num_nodes=1`. +- **Ours**: default (no `DDPStrategy`). +- **To align**: no change needed unless running DDP. + +### Diffusion parameterization and loss +- **DMI**: `parameterization="eps"`, `loss_type="l2"` (MSE). +- **Martinbo**: `parametrization="eps"` (note: spelling difference), + `nn.MSELoss()`. +- **Ours**: `parameterization="eps"` in `DiffusionLoss` (L2 via + `nn.MSELoss` reduction). 
+- **To align**: All three agree on `eps` + MSE — no change needed. From 2599f17fc75604b83feeea1277dd51d777d694e2 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 4 Jun 2026 00:28:59 +0200 Subject: [PATCH 21/34] refactor: add explicit ensemble dimension throughout forecasting pipeline ConvGruModel, task modules, and visualization now use an explicit ensemble dim (B, T, M, C, H, W) instead of flattening into channels. Task modules flatten via einops only at loss computation time. --- README.md | 20 +++++++++++++------ pyproject.toml | 2 ++ src/mlcast/models/convgru.py | 10 +++++----- src/mlcast/modules/forecasting.py | 33 +++++++++++++++++-------------- src/mlcast/visualization.py | 2 +- tests/models/test_convgru.py | 14 ++++++------- tests/test_nowcasting_module.py | 9 +++++---- tests/test_task_modules.py | 2 +- uv.lock | 17 +++++++++++++++- 9 files changed, 68 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index 4ee858b..940aae0 100644 --- a/README.md +++ b/README.md @@ -281,7 +281,7 @@ class HalfUNetNowcaster(nn.Module): def forward( self, x: Float[torch.Tensor, "batch input_steps in_channels H W"], - ) -> Float[torch.Tensor, "batch forecast_steps out_channels H W"]: + ) -> Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels H W"]: # channel-stack all input frames: (b, t, c, h, w) -> (b, t*c, h, w) x_flat = einops.rearrange(x, "b t c h w -> b (t c) h w") preds = [] @@ -291,7 +291,7 @@ class HalfUNetNowcaster(nn.Module): # slide window: drop the oldest timestep (first num_vars channels), # append the latest prediction as the newest timestep x_flat = torch.cat([x_flat[:, self.num_vars:], y], dim=1) - return torch.cat(preds, dim=1) + return torch.cat(preds, dim=1).unsqueeze(2) cfg = convgru_training_experiment.as_buildable() use_random_sampler(cfg) @@ -404,7 +404,8 @@ doubled at each block via `PixelShuffle(2)`. 
**Ensemble** — when `ensemble_size > 1` the decoder is run `ensemble_size` times, each time with freshly sampled Gaussian noise. The results are -concatenated along the channel dimension. +stacked along an explicit ensemble dimension, giving the final shape +`(batch, forecast_steps, ensemble_size, channels, height, width)`. **Deterministic variant** ([diagram source](https://docs.google.com/presentation/d/1U2Y9vZADXTsgQBNiWYAgOwYeMPVu7TOk/edit?slide=id.p6#slide=id.p6)): @@ -479,9 +480,11 @@ and the denoiser is trained to predict the added noise. Inference uses a [`DiffusionSampler`](src/mlcast/models/diffusion/sampler.py) to progressively denoise random latents conditioned on encoded input history. The reverse diffusion loop steps backward through the schedule, and the final -denoised latent is decoded back to data space by the trained decoder. When +denoised latent is decoded back to data space by the trained decoder, giving +an explicit ensemble dimension in the output shape +`(batch, forecast_steps, ensemble_size, channels, height, width)`. When `ensemble_size > 1`, the process is repeated with fresh noise and the results -are concatenated. +are stacked. #### Two-stage training experiment @@ -511,10 +514,15 @@ only runtime `forward` requirement is: def forward( self, x: Float[torch.Tensor, "batch input_steps in_channels H W"], -) -> Float[torch.Tensor, "batch forecast_steps out_channels H W"]: +) -> Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels H W"]: ... ``` +The output has an explicit ensemble dimension. For deterministic models +(`ensemble_size=1`) this dimension is 1. If a loss function operates over +the full forecast tensor without splitting ensemble members (e.g. MSE on +the ensemble mean), the task module handles reshaping automatically. 
+ If your network uses a different parameter name for the input channel count than `input_channels` (the default assumed by `ConvGruModel` and the `set_variables` fiddler), set it explicitly on the config node. diff --git a/pyproject.toml b/pyproject.toml index 3c1869f..b0a61ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "absl-py>=2.4", "beartype>=0.18", "cf-xarray>=0.10", + "einops>=0.8", "etils>=1.13", "fiddle>=0.3", "fire>=0.7", @@ -84,6 +85,7 @@ scripts.mlcast = "mlcast.__main__:cli" [dependency-groups] dev = [ + "pre-commit>=4.6", "pytest>=9.0.3", ] mlflow = [ ] diff --git a/src/mlcast/models/convgru.py b/src/mlcast/models/convgru.py index da54bd5..4482264 100644 --- a/src/mlcast/models/convgru.py +++ b/src/mlcast/models/convgru.py @@ -386,7 +386,7 @@ def __init__( @jaxtyped(typechecker=beartype) def forward( self, x: Float[torch.Tensor, "batch time channels height width"] - ) -> Float[torch.Tensor, "batch forecast_steps _ height width"]: + ) -> Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels height width"]: """Forward the encoder-decoder model. Parameters @@ -396,8 +396,8 @@ def forward( Returns ------- - preds : Float[torch.Tensor, "batch forecast_steps out_channels height width"] - Forecast tensor. + preds : Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels height width"] + Forecast tensor with an explicit ensemble dimension. 
""" if x.shape[1] != self.input_steps: raise ValueError(f"Expected {self.input_steps} input timesteps, got {x.shape[1]}.") @@ -423,11 +423,11 @@ def forward( x_dec = torch.randn(x_dec_shape, dtype=encoded[-1].dtype, device=encoded[-1].device) decoded = self.decoder(x_dec, last_hidden_per_block) preds.append(decoded) - out = torch.cat(preds, dim=2) + out = torch.stack(preds, dim=2) else: x_dec_func = torch.randn if self.noisy_decoder else torch.zeros x_dec = x_dec_func(x_dec_shape, dtype=encoded[-1].dtype, device=encoded[-1].device) - out = self.decoder(x_dec, last_hidden_per_block) + out = self.decoder(x_dec, last_hidden_per_block).unsqueeze(2) if pad_h > 0 or pad_w > 0: out = out[..., :H, :W] diff --git a/src/mlcast/modules/forecasting.py b/src/mlcast/modules/forecasting.py index c27b321..d0b36ff 100644 --- a/src/mlcast/modules/forecasting.py +++ b/src/mlcast/modules/forecasting.py @@ -8,6 +8,7 @@ import pytorch_lightning as pl import torch from beartype import beartype +from einops import rearrange from jaxtyping import Float, jaxtyped from mlcast.data.normalization import DENORMALIZATION_REGISTRY, NORMALIZATION_REGISTRY @@ -108,7 +109,7 @@ def compute_loss(self, batch: dict[str, torch.Tensor], split: str = "train") -> def predict_normalized( self, x: Float[torch.Tensor, "batch time channels height width"], - ) -> Float[torch.Tensor, "batch forecast_steps out_channels height width"]: + ) -> Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels height width"]: """Predict normalized forecasts from normalized inputs. Parameters @@ -118,8 +119,8 @@ def predict_normalized( Returns ------- - Float[torch.Tensor, "batch forecast_steps out_channels height width"] - Normalized forecast tensor. + Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels height width"] + Normalized forecast tensor with an explicit ensemble dimension. 
""" return self(x) @@ -306,7 +307,7 @@ def __init__( def forward( self, x: Float[torch.Tensor, "batch time channels height width"], - ) -> Float[torch.Tensor, "batch forecast_steps out_channels height width"]: + ) -> Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels height width"]: """Run the forecasting network. Parameters @@ -316,8 +317,8 @@ def forward( Returns ------- - Float[torch.Tensor, "batch forecast_steps out_channels height width"] - Normalized forecast tensor. + Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels height width"] + Normalized forecast tensor with an explicit ensemble dimension. """ return self.network(x) @@ -341,11 +342,15 @@ def compute_loss(self, batch: dict[str, torch.Tensor], split: str = "train") -> future = batch["target"] preds = self(past).clamp(min=-1, max=1) + # Flatten ensemble and channel dims for loss functions that expect + # (B, T, M*C, H, W), preserving backward compatibility with CRPS etc. + preds_flat = rearrange(preds, "b t m c h w -> b t (m c) h w") + if self.hparams["masked_loss"]: mask = batch["target_mask"] - loss = self.criterion(preds, future, mask) + loss = self.criterion(preds_flat, future, mask) else: - loss = self.criterion(preds, future) + loss = self.criterion(preds_flat, future) if isinstance(loss, tuple): loss, log_dict = loss @@ -522,7 +527,7 @@ def _freeze_autoencoder(self) -> None: def forward( self, x: Float[torch.Tensor, "batch input_steps channels height width"], - ) -> Float[torch.Tensor, "batch forecast_steps out_channels height width"]: + ) -> Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels height width"]: """Generate decoded forecasts from normalized input histories. Parameters @@ -532,9 +537,9 @@ def forward( Returns ------- - Float[torch.Tensor, "batch forecast_steps out_channels height width"] - Decoded normalized forecast tensor with ensemble members - concatenated along the channel dimension. 
+ Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels height width"] + Decoded normalized forecast tensor with an explicit ensemble + dimension. """ input_latents = self.autoencoder.encode(x) repeated_input_latents = input_latents.repeat_interleave(self.hparams["ensemble_size"], dim=0) @@ -549,9 +554,7 @@ def forward( decoded = self.autoencoder.decode(forecast_latents) batch, time, channels, height, width = decoded.shape decoded = decoded.reshape(x.shape[0], self.hparams["ensemble_size"], time, channels, height, width) - return decoded.permute(0, 2, 1, 3, 4, 5).reshape( - x.shape[0], time, self.hparams["ensemble_size"] * channels, height, width - ) + return decoded.permute(0, 2, 1, 3, 4, 5) def compute_loss(self, batch: dict[str, torch.Tensor], split: str = "train") -> torch.Tensor: """Compute latent diffusion loss for a forecasting batch. diff --git a/src/mlcast/visualization.py b/src/mlcast/visualization.py index a32eaee..f1b11cd 100644 --- a/src/mlcast/visualization.py +++ b/src/mlcast/visualization.py @@ -95,7 +95,7 @@ def log_images( if ensemble_size > 1: preds_avg = preds_sample.mean(dim=1, keepdim=True) num_members_to_log = min(3, preds_sample.shape[1]) - rows = [future_sample, preds_avg] + [preds_sample[:, i : i + 1] for i in range(num_members_to_log)] + rows = [future_sample.unsqueeze(1), preds_avg] + [preds_sample[:, i : i + 1] for i in range(num_members_to_log)] else: rows = [future_sample, preds_sample] diff --git a/tests/models/test_convgru.py b/tests/models/test_convgru.py index b296543..0494968 100644 --- a/tests/models/test_convgru.py +++ b/tests/models/test_convgru.py @@ -25,9 +25,9 @@ def test_convgru_dynamic_padding() -> None: with torch.no_grad(): preds = model(x) - # Check that it didn't crash and the output shape is exactly (batch, steps, channels, height, width) - # The single ensemble member case returns out_channels = channels. 
- assert preds.shape == (batch_size, forecast_steps, channels, height, width) + # Check that it didn't crash and the output shape is exactly (batch, steps, 1, channels, height, width) + # The single ensemble member case adds an explicit ensemble dimension. + assert preds.shape == (batch_size, forecast_steps, 1, channels, height, width) def test_convgru_dynamic_padding_ensemble() -> None: @@ -55,11 +55,9 @@ def test_convgru_dynamic_padding_ensemble() -> None: with torch.no_grad(): preds = model(x) - # Check that it didn't crash and the output shape is exactly (batch, steps, ensemble_size * channels, height, width) - # Actually wait: The decoder block outputs the same number of channels as the final upsampling step. - # In the `ConvGruModel.forward` with `ensemble_size > 1`, `out` is `torch.cat(preds, dim=2)`. - # Let's verify the exact channel dimension. The original output channels per ensemble member is `channels`. - assert preds.shape == (batch_size, forecast_steps, channels * ensemble_size, height, width) + # Check that it didn't crash and the output shape has an explicit ensemble dimension: + # (batch, forecast_steps, ensemble_size, channels, height, width) + assert preds.shape == (batch_size, forecast_steps, ensemble_size, channels, height, width) def test_convgru_rejects_wrong_input_steps() -> None: diff --git a/tests/test_nowcasting_module.py b/tests/test_nowcasting_module.py index ddca967..6bf49c8 100644 --- a/tests/test_nowcasting_module.py +++ b/tests/test_nowcasting_module.py @@ -15,8 +15,9 @@ def __init__(self, input_steps: int, forecast_steps: int, ensemble_size: int = 1 def forward(self, x: torch.Tensor) -> torch.Tensor: batch_size, _, channels, height, width = x.shape - out_channels = channels * self.ensemble_size - return torch.zeros(batch_size, self.forecast_steps, out_channels, height, width, device=x.device) + return torch.zeros( + batch_size, self.forecast_steps, self.ensemble_size, channels, height, width, device=x.device + ) def 
test_nowcasting_module_forward_uses_network_shape_contract() -> None: @@ -27,7 +28,7 @@ def test_nowcasting_module_forward_uses_network_shape_contract() -> None: preds = module(x) - assert preds.shape == (4, 5, 2, 8, 8) + assert preds.shape == (4, 5, 2, 1, 8, 8) def test_nowcasting_module_predict_uses_configured_output_shape() -> None: @@ -39,4 +40,4 @@ def test_nowcasting_module_predict_uses_configured_output_shape() -> None: preds = module.predict(past, standard_name="rainfall_rate") assert isinstance(preds, np.ndarray) - assert preds.shape == (2, 4, 8, 8) + assert preds.shape == (2, 4, 1, 8, 8) diff --git a/tests/test_task_modules.py b/tests/test_task_modules.py index a267f31..d8275f7 100644 --- a/tests/test_task_modules.py +++ b/tests/test_task_modules.py @@ -95,4 +95,4 @@ def test_latent_diffusion_module_predict_uses_configured_output_shape() -> None: preds = module.predict(past, standard_name="rainfall_rate") assert isinstance(preds, np.ndarray) - assert preds.shape == (2, 3, 8, 8) + assert preds.shape == (2, 3, 1, 8, 8) diff --git a/uv.lock b/uv.lock index d525b64..998be47 100644 --- a/uv.lock +++ b/uv.lock @@ -1026,6 +1026,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl", hash = "sha256:2a3175ce74a06109ff9307d90a230f81215cbac9a751f4d1c6194644b8204f9d", size = 21592, upload-time = "2024-05-23T14:13:55.283Z" }, ] +[[package]] +name = "einops" +version = "0.8.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/77/850bef8d72ffb9219f0b1aac23fbc1bf7d038ee6ea666f331fa273031aa2/einops-0.8.2.tar.gz", hash = "sha256:609da665570e5e265e27283aab09e7f279ade90c4f01bcfca111f3d3e13f2827", size = 56261, upload-time = "2026-01-26T04:13:17.638Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl", hash = "sha256:54058201ac7087911181bfec4af6091bb59380360f069276601256a76af08193", size = 65638, upload-time = "2026-01-26T04:13:18.546Z" }, +] + [[package]] name = "etils" version = "1.14.0" @@ -2238,6 +2247,7 @@ dependencies = [ { name = "absl-py" }, { name = "beartype" }, { name = "cf-xarray" }, + { name = "einops" }, { name = "etils" }, { name = "fiddle" }, { name = "fire" }, @@ -2287,6 +2297,7 @@ gpu-cu130 = [ [package.dev-dependencies] dev = [ + { name = "pre-commit" }, { name = "pytest" }, ] @@ -2296,6 +2307,7 @@ requires-dist = [ { name = "aiohttp", marker = "extra == 'dev'", specifier = ">=3.9.3" }, { name = "beartype", specifier = ">=0.18" }, { name = "cf-xarray", specifier = ">=0.10" }, + { name = "einops", specifier = ">=0.8" }, { name = "etils", specifier = ">=1.13" }, { name = "fiddle", specifier = ">=0.3" }, { name = "fire", specifier = ">=0.7" }, @@ -2332,7 +2344,10 @@ requires-dist = [ provides-extras = ["dev", "gpu-cu128", "gpu-cu130"] [package.metadata.requires-dev] -dev = [{ name = "pytest", specifier = ">=9.0.3" }] +dev = [ + { name = "pre-commit", specifier = ">=4.6" }, + { name = "pytest", specifier = ">=9.0.3" }, +] mlflow = [] [[package]] From 7060f550e14e6345aef097c89c40fa323938ee37 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 4 Jun 2026 00:31:19 +0200 Subject: [PATCH 22/34] refactor: use einops rearrange for ensemble unstack in LatentDiffusionTaskModule --- src/mlcast/modules/forecasting.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/mlcast/modules/forecasting.py b/src/mlcast/modules/forecasting.py index d0b36ff..d436b67 100644 --- a/src/mlcast/modules/forecasting.py +++ b/src/mlcast/modules/forecasting.py @@ -552,9 +552,11 @@ def forward( ) forecast_latents = self.sampler.sample(repeated_input_latents, latent_shape) decoded = 
self.autoencoder.decode(forecast_latents) - batch, time, channels, height, width = decoded.shape - decoded = decoded.reshape(x.shape[0], self.hparams["ensemble_size"], time, channels, height, width) - return decoded.permute(0, 2, 1, 3, 4, 5) + # Decoded latent has shape (B*E, T, C, H, W) because ensemble members + # were stacked in the batch dim via repeat_interleave. Unstack into an + # explicit ensemble dim and move time before ensemble for the standard + # (B, T, E, C, H, W) shape contract expected by loss functions etc. + return rearrange(decoded, "(b e) t c h w -> b t e c h w", e=self.hparams["ensemble_size"]) def compute_loss(self, batch: dict[str, torch.Tensor], split: str = "train") -> torch.Tensor: """Compute latent diffusion loss for a forecasting batch. From f4115f8890c6965e8b3185622f51087e24cb0d24 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 4 Jun 2026 01:06:44 +0200 Subject: [PATCH 23/34] refactor: make fiddlers apply recursively to nested Experiment configs Add _iter_experiment_configs tree-walker, applies_to_experiments decorator, and _find_nn_modules_with_input_channels helper so fiddlers work on both flat Experiment configs (convgru) and nested containers like LDCastTrainingExperiment. Refactor set_variables to walk pl_module for any nn.Module with input_channels instead of hardcoding cfg.pl_module.network. --- src/mlcast/config/fiddlers.py | 127 ++++++++++++++++++++++++++++++---- tests/config/test_fiddlers.py | 55 +++++++++++++-- 2 files changed, 162 insertions(+), 20 deletions(-) diff --git a/src/mlcast/config/fiddlers.py b/src/mlcast/config/fiddlers.py index 5e1cfb0..729e879 100644 --- a/src/mlcast/config/fiddlers.py +++ b/src/mlcast/config/fiddlers.py @@ -1,6 +1,6 @@ """Fiddler mutators for high-level semantic configuration changes. -Fiddlers are functions that accept a ``fdl.Config`` and mutate it in place. +Fiddlers are functions that accept a ``fdl.Config`` and mutate them in place. 
They are the right tool when a change spans multiple config parameters that must stay in sync — for example, switching the dataset class while preserving its existing parameters, or enabling masking consistently across both the data @@ -11,26 +11,120 @@ config in Python before passing it to ``fdl.build()``. """ +import functools import inspect import os +from collections.abc import Callable import fiddle as fdl +import torch.nn as nn from loguru import logger from pytorch_lightning.loggers import MLFlowLogger +from mlcast.config.base import Experiment + from ..callbacks import LogSystemInfoCallback from ..data.sequence import SourceDataRandomSequenceDataset +def _iter_experiment_configs(cfg: fdl.Buildable): + """Yield all ``fdl.Config`` sub-nodes whose callable is ``Experiment``, depth-first. + + Parameters + ---------- + cfg : fdl.Buildable + Root of the Fiddle configuration tree to traverse. + + Yields + ------ + fdl.Config + Each sub-config whose ``fdl.get_callable`` is the ``Experiment`` + dataclass. + """ + if not isinstance(cfg, fdl.Buildable): + return + try: + if fdl.get_callable(cfg) is Experiment: + yield cfg + except (TypeError, AttributeError): + pass + try: + for child in fdl.ordered_arguments(cfg).values(): + yield from _iter_experiment_configs(child) + except (TypeError, AttributeError): + pass + + +def _find_nn_modules_with_input_channels(cfg: fdl.Buildable): + """Yield all ``fdl.Config`` nodes for ``nn.Module`` subclasses that accept ``input_channels``. + + Parameters + ---------- + cfg : fdl.Buildable + Root of the Fiddle configuration tree to traverse (typically + ``cfg.pl_module``). + + Yields + ------ + fdl.Config + Each sub-config whose callable is an ``nn.Module`` subclass with + ``input_channels`` in its ``__init__`` signature. 
+ """ + if not isinstance(cfg, fdl.Config): + return + try: + cls = fdl.get_callable(cfg) + if isinstance(cls, type) and issubclass(cls, nn.Module): + if "input_channels" in inspect.signature(cls.__init__).parameters: + yield cfg + except (TypeError, AttributeError): + pass + try: + for child in fdl.ordered_arguments(cfg).values(): + yield from _find_nn_modules_with_input_channels(child) + except (TypeError, AttributeError): + pass + + +def applies_to_experiments(fiddler: Callable) -> Callable: + """Decorate a fiddler so it applies to every ``Experiment`` sub-config in the tree. + + This makes fiddlers work with both flat ``Experiment`` configs (returned by + ``convgru_training_experiment``) and nested containers like + ``LDCastTrainingExperiment`` that contain multiple ``Experiment`` instances. + + Parameters + ---------- + fiddler : Callable + Fiddler function whose first argument is a ``fdl.Config``. + + Returns + ------- + Callable + Wrapped fiddler that traverses the config tree for ``Experiment`` + sub-configs and applies the original fiddler to each one. + """ + + @functools.wraps(fiddler) + def wrapper(cfg: fdl.Buildable, *args: object, **kwargs: object) -> None: + experiments = list(_iter_experiment_configs(cfg)) + if experiments: + for exp_cfg in experiments: + fiddler(exp_cfg, *args, **kwargs) + else: + fiddler(cfg, *args, **kwargs) + + return wrapper + + +@applies_to_experiments def set_variables(cfg: fdl.Config, standard_names: list[str]) -> None: """Fiddler to synchronize dataset variables with the network's input channels. - Sets ``sequence_dataset_factory.standard_names`` on the data config and, when the - network config exposes an ``input_channels`` parameter (e.g. - ``ConvGruModel``), keeps it in sync. Networks that use a different - parameter name for the channel count (e.g. ``HalfUNet`` uses - ``in_channels``) are left unchanged — callers are responsible for keeping - that parameter consistent when swapping in an external architecture. 
+ Sets ``sequence_dataset_factory.standard_names`` on the data config and + walks ``cfg.pl_module`` to find any ``nn.Module`` with an ``input_channels`` + ``__init__`` parameter (e.g. ``ConvGruModel``, ``Encoder``), keeping it in + sync with the number of loaded variables. Parameters ---------- @@ -40,18 +134,17 @@ def set_variables(cfg: fdl.Config, standard_names: list[str]) -> None: The new list of standard names to load. """ cfg.data.sequence_dataset_factory.standard_names = standard_names - network_cls = cfg.pl_module.network.__fn_or_cls__ - sig = inspect.signature(network_cls.__init__) - if "input_channels" in sig.parameters: - cfg.pl_module.network.input_channels = len(standard_names) - else: + found = False + for module_cfg in _find_nn_modules_with_input_channels(cfg.pl_module): + module_cfg.input_channels = len(standard_names) + found = True + if not found: logger.warning( - "set_variables: network {} has no 'input_channels' parameter; " - "channel count not updated. Set it manually on the network config.", - network_cls.__name__, + "set_variables: no nn.Module under pl_module has an 'input_channels' parameter; channel count not updated." ) +@applies_to_experiments def toggle_masking(cfg: fdl.Config, enabled: bool) -> None: """Fiddler to synchronize forecasting-mask yielding with masked loss computation. @@ -66,6 +159,7 @@ def toggle_masking(cfg: fdl.Config, enabled: bool) -> None: cfg.pl_module.masked_loss = enabled +@applies_to_experiments def use_random_sampler(cfg: fdl.Config) -> None: """Fiddler to switch the sequence dataset factory to use the random sampler. 
@@ -84,11 +178,13 @@ def use_random_sampler(cfg: fdl.Config) -> None: ) +@applies_to_experiments def use_ratio_splits(cfg: fdl.Config, train: float, val: float) -> None: """Fiddler to set fraction-based train/val/test splits on the data module.""" cfg.data.splits = {"time": {"train": train, "val": val, "test": 1.0 - train - val}} +@applies_to_experiments def use_anon_s3_dataset(cfg: fdl.Buildable, zarr_path: str, endpoint_url: str) -> None: """Configure the dataset factory to read anonymously from an S3 object store. @@ -112,6 +208,7 @@ def use_anon_s3_dataset(cfg: fdl.Buildable, zarr_path: str, endpoint_url: str) - } +@applies_to_experiments def use_mlflow_logger(cfg: fdl.Config) -> None: """Fiddler to switch the trainer logger to MLflow. diff --git a/tests/config/test_fiddlers.py b/tests/config/test_fiddlers.py index 347f2e7..d807f7d 100644 --- a/tests/config/test_fiddlers.py +++ b/tests/config/test_fiddlers.py @@ -1,14 +1,22 @@ -from mlcast.config import convgru_training_experiment, set_variables, toggle_masking +import fiddle as fdl + +from mlcast.config import ( + convgru_training_experiment, + ldcast_training_experiment, + set_variables, + toggle_masking, + use_random_sampler, + use_ratio_splits, +) +from mlcast.data.sequence import SourceDataPrecomputedSequenceDataset, SourceDataRandomSequenceDataset def test_fiddler_set_variables() -> None: """Verify set_variables syncs dataset variables and network input_channels.""" cfg = convgru_training_experiment.as_buildable() - # Apply fiddler set_variables(cfg, ["rainfall_rate", "rainfall_flux"]) - # Check sync assert cfg.data.sequence_dataset_factory.standard_names == ["rainfall_rate", "rainfall_flux"] assert cfg.pl_module.network.input_channels == 2 @@ -17,12 +25,49 @@ def test_fiddler_toggle_masking() -> None: """Verify toggle_masking syncs dataset mask return and module masked_loss.""" cfg = convgru_training_experiment.as_buildable() - # Disable masking toggle_masking(cfg, False) assert cfg.data.return_mask is 
False assert cfg.pl_module.masked_loss is False - # Enable masking toggle_masking(cfg, True) assert cfg.data.return_mask is True assert cfg.pl_module.masked_loss is True + + +def test_fiddler_set_variables_on_ldcast() -> None: + """Verify set_variables applies to both stages of an LDCastTrainingExperiment.""" + cfg = ldcast_training_experiment.as_buildable() + + set_variables(cfg, ["rainfall_rate", "rainfall_flux", "rainfall_intensity"]) + + # Both stages share the same sequence_dataset_factory object + expected_names = ["rainfall_rate", "rainfall_flux", "rainfall_intensity"] + assert cfg.stage1.data.sequence_dataset_factory.standard_names == expected_names + assert cfg.stage2.data.sequence_dataset_factory.standard_names == expected_names + + # Encoder (inside AutoencoderNet) has input_channels and should be updated + assert cfg.stage1.pl_module.network.encoder.input_channels == 3 + assert cfg.stage2.pl_module.autoencoder.encoder.input_channels == 3 + + +def test_fiddler_use_random_sampler_on_ldcast() -> None: + """Verify use_random_sampler applies to both stages of LDCastTrainingExperiment.""" + cfg = ldcast_training_experiment.as_buildable() + + assert fdl.get_callable(cfg.stage1.data.sequence_dataset_factory) is SourceDataPrecomputedSequenceDataset + assert fdl.get_callable(cfg.stage2.data.sequence_dataset_factory) is SourceDataPrecomputedSequenceDataset + + use_random_sampler(cfg) + + assert fdl.get_callable(cfg.stage1.data.sequence_dataset_factory) is SourceDataRandomSequenceDataset + assert fdl.get_callable(cfg.stage2.data.sequence_dataset_factory) is SourceDataRandomSequenceDataset + + +def test_fiddler_use_ratio_splits_on_ldcast() -> None: + """Verify use_ratio_splits applies to both stages of LDCastTrainingExperiment.""" + cfg = ldcast_training_experiment.as_buildable() + + use_ratio_splits(cfg, train=0.6, val=0.2) + + assert cfg.stage1.data.splits == {"time": {"train": 0.6, "val": 0.2, "test": 0.2}} + assert cfg.stage2.data.splits == {"time": {"train": 
0.6, "val": 0.2, "test": 0.2}} From 4d3574ac700bafbb72c1b9600b4963b07daca5d0 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 4 Jun 2026 01:21:52 +0200 Subject: [PATCH 24/34] refactor: move archetype experiments to mlcast.config.archetype subpackage Split Experiment dataclass (base.py) from experiment-specific config functions. Move convgru_training_experiment to archetype/convgru.py and ldcast_training_experiment to archetype/ldcast.py. --- .pre-commit-config.yaml | 4 +- README.md | 12 ++- src/mlcast/config/__init__.py | 5 +- src/mlcast/config/archetype/__init__.py | 0 src/mlcast/config/archetype/convgru.py | 85 ++++++++++++++++ src/mlcast/config/{ => archetype}/ldcast.py | 3 +- src/mlcast/config/base.py | 102 +------------------- 7 files changed, 102 insertions(+), 109 deletions(-) create mode 100644 src/mlcast/config/archetype/__init__.py create mode 100644 src/mlcast/config/archetype/convgru.py rename src/mlcast/config/{ => archetype}/ldcast.py (99%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 40413cf..b22aad4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,13 +32,13 @@ repos: name: convgru config diagram up to date language: system entry: uv run python docs/generate_base_experiment_config_diagram.py --check - files: ^src/mlcast/config/base\.py$ + files: ^src/mlcast/config/(base|archetype/convgru)\.py$ pass_filenames: false - id: config-diagram-ldcast name: ldcast config diagram up to date language: system entry: uv run python docs/generate_ldcast_config_diagram.py --check - files: ^src/mlcast/config/ldcast\.py$ + files: ^src/mlcast/config/archetype/ldcast\.py$ pass_filenames: false ci: autoupdate_schedule: monthly diff --git a/README.md b/README.md index 940aae0..eb94e88 100644 --- a/README.md +++ b/README.md @@ -69,10 +69,10 @@ reproduce runs exactly from a saved YAML file. 
mlcast ships with two included configuration functions: -- [`convgru_training_experiment`](src/mlcast/config/base.py) — defines a +- [`convgru_training_experiment`](src/mlcast/config/archetype/convgru.py) — defines a single-stage ConvGRU ensemble nowcasting setup (dataset, data module, network, Lightning module, trainer). -- [`ldcast_training_experiment`](src/mlcast/config/ldcast.py) — defines a +- [`ldcast_training_experiment`](src/mlcast/config/archetype/ldcast.py) — defines a two-stage LDCast setup: stage 1 trains an autoencoder on reconstruction windows, stage 2 trains a latent diffusion model on the same autoencoder's latent space. @@ -191,7 +191,7 @@ import fiddle as fdl from mlcast.config import convgru_training_experiment, train_from_config from mlcast.config.fiddlers import use_random_sampler -cfg = convgru_training_experiment.as_buildable() # returns a fdl.Config graph — see src/mlcast/config/base.py +cfg = convgru_training_experiment.as_buildable() # returns a fdl.Config graph — see src/mlcast/config/archetype/convgru.py # Apply a fiddler to switch the dataset sampler use_random_sampler(cfg) @@ -342,8 +342,10 @@ mlcast/ │ ├── callbacks.py # Training callbacks │ ├── visualization.py # TensorBoard image logging helpers │ ├── config/ -│ │ ├── base.py # ConvGRU training config @auto_config -│ │ ├── ldcast.py # LDCast two-stage config @auto_config +│ │ ├── base.py # Experiment dataclass +│ │ ├── archetype/ +│ │ │ ├── convgru.py # ConvGRU training config @auto_config +│ │ │ └── ldcast.py # LDCast two-stage config @auto_config │ │ ├── fiddlers.py # Semantic config mutators │ │ ├── consistency_checks.py # Cross-parameter validation │ │ ├── loader.py # YAML config loader diff --git a/src/mlcast/config/__init__.py b/src/mlcast/config/__init__.py index e566eb5..2b5ce79 100644 --- a/src/mlcast/config/__init__.py +++ b/src/mlcast/config/__init__.py @@ -4,7 +4,9 @@ and runtime orchestration logic for `mlcast`. 
""" -from .base import Experiment, convgru_training_experiment +from .archetype.convgru import convgru_training_experiment +from .archetype.ldcast import LDCastTrainingExperiment, ldcast_training_experiment +from .base import Experiment from .consistency_checks import validate_config from .fiddlers import ( set_variables, @@ -14,7 +16,6 @@ use_random_sampler, use_ratio_splits, ) -from .ldcast import LDCastTrainingExperiment, ldcast_training_experiment from .loader import load_yaml_config from .orchestrator import train_from_config diff --git a/src/mlcast/config/archetype/__init__.py b/src/mlcast/config/archetype/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/mlcast/config/archetype/convgru.py b/src/mlcast/config/archetype/convgru.py new file mode 100644 index 0000000..95a2e51 --- /dev/null +++ b/src/mlcast/config/archetype/convgru.py @@ -0,0 +1,85 @@ +"""ConvGRU ensemble nowcasting experiment configuration.""" + +import fiddle as fdl +import fiddle.experimental.auto_config +import pytorch_lightning as pl +import torch +from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger + +from mlcast.data.datamodules import ForecastingDataModule +from mlcast.data.sequence import SourceDataPrecomputedSequenceDataset +from mlcast.models.convgru import ConvGruModel +from mlcast.modules.forecasting import OutputSpaceForecastingTaskModule + +from ..base import Experiment + + +@fiddle.experimental.auto_config.auto_config +def convgru_training_experiment() -> Experiment: + """Build a Fiddle config for ConvGRU ensemble radar nowcasting. + + This is decorated as a Fiddle ``@auto_config`` function: calling it + returns a buildable config graph where any parameter can be overridden + before instantiation via ``fdl.build()``. + + Returns + ------- + Experiment + Configured experiment with model, data, and trainer. 
+ """ + sequence_dataset_factory = fdl.Partial( + SourceDataPrecomputedSequenceDataset, + zarr_path="./data/radar.zarr", + csv_path="./data/sampled_datacubes.csv", + standard_names=["rainfall_rate"], + sequence_steps=18, + deterministic=False, + ) + + data = ForecastingDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=6, + forecast_steps=12, + return_mask=True, + splits={"time": {"train": 0.70, "val": 0.15, "test": 0.15}}, + batch_size=16, + num_workers=8, + pin_memory=True, + ) + + network = ConvGruModel( + input_steps=6, + forecast_steps=12, + ensemble_size=2, + input_channels=1, + num_blocks=5, + noisy_decoder=False, + ) + + pl_module = OutputSpaceForecastingTaskModule( + network=network, + loss_class="crps", + loss_params={"temporal_lambda": 0.01}, + masked_loss=True, + optimizer=fdl.Partial(torch.optim.Adam, lr=1e-4, fused=True), + lr_scheduler=fdl.Partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode="min", factor=0.5, patience=10), + ) + + trainer = pl.Trainer( + accelerator="auto", + max_epochs=100, + callbacks=[ + ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min"), + ModelCheckpoint(monitor="train_loss_epoch", save_top_k=1, mode="min"), + EarlyStopping(monitor="val_loss", patience=100, mode="min"), + LearningRateMonitor(logging_interval="step"), + ], + logger=TensorBoardLogger(save_dir="logs", name="mlcast"), + ) + + return Experiment( + pl_module=pl_module, + data=data, + trainer=trainer, + ) diff --git a/src/mlcast/config/ldcast.py b/src/mlcast/config/archetype/ldcast.py similarity index 99% rename from src/mlcast/config/ldcast.py rename to src/mlcast/config/archetype/ldcast.py index ff01c5d..15e781d 100644 --- a/src/mlcast/config/ldcast.py +++ b/src/mlcast/config/archetype/ldcast.py @@ -9,7 +9,6 @@ from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger -from mlcast.config.base import Experiment from 
mlcast.data.datamodules import ForecastingDataModule, ReconstructionDataModule from mlcast.data.sequence import SourceDataPrecomputedSequenceDataset from mlcast.models.autoencoder import AutoencoderNet, Decoder, Encoder @@ -17,6 +16,8 @@ from mlcast.modules.forecasting import LatentDiffusionTaskModule from mlcast.modules.reconstruction import ReconstructionTaskModule +from ..base import Experiment + @dataclass class LDCastTrainingExperiment: diff --git a/src/mlcast/config/base.py b/src/mlcast/config/base.py index 9343b08..2c76f3e 100644 --- a/src/mlcast/config/base.py +++ b/src/mlcast/config/base.py @@ -1,38 +1,12 @@ -"""Base Fiddle experiment definitions for ConvGRU radar nowcasting. +"""Base Fiddle experiment definition for radar nowcasting. -This module defines the ``Experiment`` dataclass and the -``convgru_training_experiment`` auto-config factory, which together form the -included configuration graph for a ConvGRU ensemble nowcasting run. - -``convgru_training_experiment`` is decorated with ``@auto_config``: calling it -returns a ``fdl.Config`` graph rather than a live ``Experiment`` object. Every -parameter in the graph can be overridden before instantiation — either via -fiddlers (for semantic, multi-parameter changes) or via ``set:`` overrides on -the CLI (for single-parameter tweaks). Call ``fdl.build(cfg)`` to materialise -the graph into real Python objects. - -Typical usage -------------- ->>> cfg = convgru_training_experiment() # returns fdl.Config ->>> use_random_sampler(cfg) # apply a fiddler ->>> validate_config(cfg) # check cross-parameter contracts ->>> experiment = fdl.build(cfg) # instantiate everything ->>> experiment.run() # train + test +This module defines the ``Experiment`` dataclass used across all experiment +configurations. 
""" from dataclasses import dataclass -import fiddle as fdl -import fiddle.experimental.auto_config import pytorch_lightning as pl -import torch -from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint -from pytorch_lightning.loggers import TensorBoardLogger - -from ..data.datamodules import ForecastingDataModule -from ..data.sequence import SourceDataPrecomputedSequenceDataset -from ..models.convgru import ConvGruModel -from ..modules.forecasting import OutputSpaceForecastingTaskModule @dataclass @@ -47,73 +21,3 @@ def run(self) -> None: """Train and evaluate the configured model.""" self.trainer.fit(self.pl_module, datamodule=self.data) self.trainer.test(self.pl_module, datamodule=self.data) - - -@fiddle.experimental.auto_config.auto_config -def convgru_training_experiment() -> Experiment: - """Build a Fiddle config for ConvGRU ensemble radar nowcasting. - - This is decorated as a Fiddle ``@auto_config`` function: calling it - returns a buildable config graph where any parameter can be overridden - before instantiation via ``fdl.build()``. - - Returns - ------- - Experiment - Configured experiment with model, data, and trainer. 
- """ - sequence_dataset_factory = fdl.Partial( - SourceDataPrecomputedSequenceDataset, - zarr_path="./data/radar.zarr", - csv_path="./data/sampled_datacubes.csv", - standard_names=["rainfall_rate"], - sequence_steps=18, - deterministic=False, - ) - - data = ForecastingDataModule( - sequence_dataset_factory=sequence_dataset_factory, - input_steps=6, - forecast_steps=12, - return_mask=True, - splits={"time": {"train": 0.70, "val": 0.15, "test": 0.15}}, - batch_size=16, - num_workers=8, - pin_memory=True, - ) - - network = ConvGruModel( - input_steps=6, - forecast_steps=12, - ensemble_size=2, - input_channels=1, - num_blocks=5, - noisy_decoder=False, - ) - - pl_module = OutputSpaceForecastingTaskModule( - network=network, - loss_class="crps", - loss_params={"temporal_lambda": 0.01}, - masked_loss=True, - optimizer=fdl.Partial(torch.optim.Adam, lr=1e-4, fused=True), - lr_scheduler=fdl.Partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode="min", factor=0.5, patience=10), - ) - - trainer = pl.Trainer( - accelerator="auto", - max_epochs=100, - callbacks=[ - ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min"), - ModelCheckpoint(monitor="train_loss_epoch", save_top_k=1, mode="min"), - EarlyStopping(monitor="val_loss", patience=100, mode="min"), - LearningRateMonitor(logging_interval="step"), - ], - logger=TensorBoardLogger(save_dir="logs", name="mlcast"), - ) - - return Experiment( - pl_module=pl_module, - data=data, - trainer=trainer, - ) From c5b56ee247ce07ffb459f19f8e72775d37377260 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 4 Jun 2026 01:42:41 +0200 Subject: [PATCH 25/34] fix: resolve test failures and align metric naming for LDCast merge - Fix 5D tensor shape mismatch in visualization.log_images (ensemble dim was not squeezed before apply_radar_colormap) - Fix EMA device mismatch in ExponentialMovingAverage when module moves between GPU stages (shadow_params stayed on CPU) - Fix README LDCast snippet: pass experiment configs to fiddlers instead 
of data module configs directly - Fix test harness _patch_cfg for nested LDCastTrainingExperiment (iterate Experiment sub-configs instead of assuming flat cfg.data) - Migrate metric names to TensorBoard convention: split/name format with rec_loss for reconstruction, loss for forecasting/diffusion --- AGENTS.md | 14 ++ README.md | 4 +- docs/config_diagram.svg | 6 +- docs/ldcast_config_diagram.svg | 210 ++++++++++++------------- ldcast-refactor-plan.md | 32 +++- src/mlcast/config/archetype/convgru.py | 6 +- src/mlcast/config/archetype/ldcast.py | 8 +- src/mlcast/models/diffusion/ema.py | 15 +- src/mlcast/modules/forecasting.py | 15 +- src/mlcast/modules/reconstruction.py | 4 +- src/mlcast/visualization.py | 5 +- tests/test_readme_snippets.py | 41 +++-- 12 files changed, 219 insertions(+), 141 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index a8628f7..a5c6375 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -23,6 +23,20 @@ Guidelines for contributors and AI agents working on this codebase. - Use `loguru` exclusively. Do not use the stdlib `logging` module. +## Metric naming + +- Logged metric names follow TensorBoard conventions: use `/` as a hierarchy + separator (e.g. `val/loss`, `train/rec_loss`) so that related metrics are + grouped in the TensorBoard UI. +- Use `rec_loss` for reconstruction-stage metrics and `loss` for + forecasting/diffusion-stage metrics to make clear which training stage a + metric belongs to. +- Metric name format: `{split}/{name}` where `split` is `train`, `val`, or + `test`. +- Monitoring references in config files (ModelCheckpoint, EarlyStopping, + lr_scheduler) must match the `val/{name}` variant of the metric they should + track. + ## Code style - Docstrings follow NumPy style. 
diff --git a/README.md b/README.md index eb94e88..a1accec 100644 --- a/README.md +++ b/README.md @@ -215,8 +215,8 @@ from mlcast.config.fiddlers import use_random_sampler cfg = ldcast_training_experiment.as_buildable() # Both stages share the same dataset, so switch to random sampling once -use_random_sampler(cfg.stage1.data) -use_random_sampler(cfg.stage2.data) +use_random_sampler(cfg.stage1) +use_random_sampler(cfg.stage2) # Override the diffusion noise schedule cfg.stage2.pl_module.diffusion_net.scheduler.timesteps = 20 diff --git a/docs/config_diagram.svg b/docs/config_diagram.svg index 2ef2cd2..58f150f 100644 --- a/docs/config_diagram.svg +++ b/docs/config_diagram.svg @@ -399,7 +399,7 @@ monitor -'val_loss' +'val/loss' save_top_k @@ -427,7 +427,7 @@ monitor -'train_loss_epoch' +'train/loss_epoch' save_top_k @@ -455,7 +455,7 @@ monitor -'val_loss' +'val/loss' patience diff --git a/docs/ldcast_config_diagram.svg b/docs/ldcast_config_diagram.svg index 59803a2..bee8c14 100644 --- a/docs/ldcast_config_diagram.svg +++ b/docs/ldcast_config_diagram.svg @@ -4,11 +4,11 @@ - + %3 - + 4 @@ -284,56 +284,56 @@ 10 - - -Config: - Trainer - - -accelerator - -'auto' - - -logger - - - - - -callbacks - - - -list - - - - - - - - - - - -0 - - -1 - - -2 - - -max_epochs - -20 + + +Config: + Trainer + + +accelerator + +'auto' + + +logger + + + + + +callbacks + + + +list + + + + + + + + + + + +0 + + +1 + + +2 + + +max_epochs + +20 1:c--10:c - + @@ -383,99 +383,99 @@ 11 - - -Config: - TensorBoardLogger - - -save_dir - -'logs' - - -name - -'mlcast_ldcast_stage1' + + +Config: + TensorBoardLogger + + +save_dir + +'logs' + + +name + +'mlcast_ldcast_stage1' 10:c--11:c - + 12 - - -Config: - ModelCheckpoint + + +Config: + ModelCheckpoint monitor - -'val_loss' + +'val/rec_loss' save_top_k - + 1 mode - + 'min' 10:c--12:c - + 13 - - -Config: - EarlyStopping - - -monitor - -'val_loss' - - -patience - -20 - - -mode - -'min' + + +Config: + EarlyStopping + + +monitor + +'val/rec_loss' + + +patience + 
+20 + + +mode + +'min' 10:c--13:c - + 14 - - -Config: - LearningRateMonitor - - -logging_interval - -'step' + + +Config: + LearningRateMonitor + + +logging_interval + +'step' 10:c--14:c - + @@ -922,7 +922,7 @@ monitor -'val_loss' +'val/loss' save_top_k @@ -950,7 +950,7 @@ monitor -'val_loss' +'val/loss' patience diff --git a/ldcast-refactor-plan.md b/ldcast-refactor-plan.md index 9ea0f61..9a584fb 100644 --- a/ldcast-refactor-plan.md +++ b/ldcast-refactor-plan.md @@ -87,8 +87,38 @@ - [x] Update docs and scripts that still reference `training_experiment`, including `README.md` and `docs/generate_base_experiment_config_diagram.py`. - [x] Keep existing ConvGRU CLI/config tests passing while adding separate tests for selecting the LDCast config explicitly. - [ ] Add real but small-scale end-to-end tests with generated sample data for the autoencoder stage, diffusion stage, and full LDCast stage sequencing. +- [ ] Align metric naming with TensorBoard conventions: use `/` as hierarchy separator and `rec_loss`/`loss` to distinguish stages — done (commit pending) -## DMI alignment notes +## 8. 
Align LDCast config with DMI/Martinbo reference + +### Optimizer +- [ ] Switch both stages to `AdamW` (from `Adam`) +- [ ] Set `betas=(0.5, 0.9)` for both stages +- [ ] Set `weight_decay=1e-3` for both stages +- [ ] Raise autoencoder LR to `1e-3`, keep diffusion LR at `1e-4` + +### LR scheduler +- [ ] Reduce `ReduceLROnPlateau` factor to `0.25` (from `0.5`) +- [ ] Reduce patience to `3` (from `10`) + +### EMA +- [ ] Increase EMA decay to `0.9999` (from `0.999`) +- [ ] Decide: EMA on full diffusion net (DMI) or denoiser only (Martinbo) + +### Early stopping +- [ ] Reduce patience to `6` (from `20`) +- [ ] Set `check_finite=False` on the diffusion stage + +### Model checkpointing +- [ ] Increase `save_top_k` to `3` (from `1`) + +### Diffusion noise schedule +- [ ] Increase `timesteps` to `1000` (from `20`) +- [ ] Set linear beta schedule from `1e-4` to `2e-2` + +### Batch size and gradient accumulation +- [ ] Reduce batch sizes (e.g. `batch_size=4` autoenc / `batch_size=1` diffusion) +- [ ] Add `accumulate_grad_batches=2` The `ldcast-dmi/` reference implementation differs from our current `ldcast_training_experiment` config in several ways. 
Changes below would diff --git a/src/mlcast/config/archetype/convgru.py b/src/mlcast/config/archetype/convgru.py index 95a2e51..1d144fe 100644 --- a/src/mlcast/config/archetype/convgru.py +++ b/src/mlcast/config/archetype/convgru.py @@ -70,9 +70,9 @@ def convgru_training_experiment() -> Experiment: accelerator="auto", max_epochs=100, callbacks=[ - ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min"), - ModelCheckpoint(monitor="train_loss_epoch", save_top_k=1, mode="min"), - EarlyStopping(monitor="val_loss", patience=100, mode="min"), + ModelCheckpoint(monitor="val/loss", save_top_k=1, mode="min"), + ModelCheckpoint(monitor="train/loss_epoch", save_top_k=1, mode="min"), + EarlyStopping(monitor="val/loss", patience=100, mode="min"), LearningRateMonitor(logging_interval="step"), ], logger=TensorBoardLogger(save_dir="logs", name="mlcast"), diff --git a/src/mlcast/config/archetype/ldcast.py b/src/mlcast/config/archetype/ldcast.py index 15e781d..1f96640 100644 --- a/src/mlcast/config/archetype/ldcast.py +++ b/src/mlcast/config/archetype/ldcast.py @@ -101,8 +101,8 @@ def ldcast_training_experiment() -> LDCastTrainingExperiment: accelerator="auto", max_epochs=20, callbacks=[ - ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min"), - EarlyStopping(monitor="val_loss", patience=20, mode="min"), + ModelCheckpoint(monitor="val/rec_loss", save_top_k=1, mode="min"), + EarlyStopping(monitor="val/rec_loss", patience=20, mode="min"), LearningRateMonitor(logging_interval="step"), ], logger=TensorBoardLogger(save_dir="logs", name="mlcast_ldcast_stage1"), @@ -136,8 +136,8 @@ def ldcast_training_experiment() -> LDCastTrainingExperiment: accelerator="auto", max_epochs=20, callbacks=[ - ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min"), - EarlyStopping(monitor="val_loss", patience=20, mode="min"), + ModelCheckpoint(monitor="val/loss", save_top_k=1, mode="min"), + EarlyStopping(monitor="val/loss", patience=20, mode="min"), 
LearningRateMonitor(logging_interval="step"), ], logger=TensorBoardLogger(save_dir="logs", name="mlcast_ldcast_stage2"), diff --git a/src/mlcast/models/diffusion/ema.py b/src/mlcast/models/diffusion/ema.py index 570599c..0ccd11f 100644 --- a/src/mlcast/models/diffusion/ema.py +++ b/src/mlcast/models/diffusion/ema.py @@ -28,11 +28,22 @@ def __init__(self, module: nn.Module, decay: float = 0.999) -> None: def update(self) -> None: """Update EMA shadow parameters from the current module parameters.""" trainable_params = [parameter for parameter in self.module.parameters() if parameter.requires_grad] - for shadow_param, parameter in zip(self.shadow_params, trainable_params, strict=True): + for i, (shadow_param, parameter) in enumerate(zip(self.shadow_params, trainable_params, strict=True)): + if shadow_param.device != parameter.device: + self.shadow_params[i] = shadow_param.to(parameter.device) + shadow_param = self.shadow_params[i] shadow_param.mul_(self.decay).add_(parameter.detach(), alpha=1.0 - self.decay) + def _align_device(self) -> None: + """Move shadow parameters to the current device of the module's parameters.""" + for i, shadow_param in enumerate(self.shadow_params): + ref_param = next(self.module.parameters()) + if shadow_param.device != ref_param.device: + self.shadow_params[i] = shadow_param.to(ref_param.device) + def apply(self) -> None: """Swap EMA parameters into the tracked module.""" + self._align_device() trainable_params = [parameter for parameter in self.module.parameters() if parameter.requires_grad] self.backup_params = [parameter.detach().clone() for parameter in trainable_params] for parameter, shadow_param in zip(trainable_params, self.shadow_params, strict=True): @@ -44,5 +55,7 @@ def restore(self) -> None: raise RuntimeError("EMA restore() called before apply().") trainable_params = [parameter for parameter in self.module.parameters() if parameter.requires_grad] for parameter, backup_param in zip(trainable_params, self.backup_params, 
strict=True): + if backup_param.device != parameter.device: + backup_param = backup_param.to(parameter.device) parameter.data.copy_(backup_param.data) self.backup_params = None diff --git a/src/mlcast/modules/forecasting.py b/src/mlcast/modules/forecasting.py index d436b67..91f0ae4 100644 --- a/src/mlcast/modules/forecasting.py +++ b/src/mlcast/modules/forecasting.py @@ -191,7 +191,7 @@ def configure_optimizers(self) -> Any: if self.lr_scheduler_factory is not None: lr_scheduler = self.lr_scheduler_factory(optimizer) - return {"optimizer": optimizer, "lr_scheduler": {"scheduler": lr_scheduler, "monitor": "val_loss"}} + return {"optimizer": optimizer, "lr_scheduler": {"scheduler": lr_scheduler, "monitor": "val/loss"}} return {"optimizer": optimizer} def predict(self, past: torch.Tensor, standard_name: str = "rainfall_rate") -> np.ndarray[Any, Any]: @@ -346,6 +346,9 @@ def compute_loss(self, batch: dict[str, torch.Tensor], split: str = "train") -> # (B, T, M*C, H, W), preserving backward compatibility with CRPS etc. 
preds_flat = rearrange(preds, "b t m c h w -> b t (m c) h w") + ensemble_size = getattr(self.network, "ensemble_size", 1) + ensemble_std = preds.std(dim=2).mean() if ensemble_size > 1 else None + if self.hparams["masked_loss"]: mask = batch["target_mask"] loss = self.criterion(preds_flat, future, mask) @@ -358,12 +361,10 @@ def compute_loss(self, batch: dict[str, torch.Tensor], split: str = "train") -> log_dict, prog_bar=False, logger=True, on_step=(split == "train"), on_epoch=True, sync_dist=True ) - self.log(f"{split}_loss", loss, prog_bar=True, on_epoch=True, on_step=(split == "train"), sync_dist=True) + self.log(f"{split}/loss", loss, prog_bar=True, on_epoch=True, on_step=(split == "train"), sync_dist=True) - ensemble_size = getattr(self.network, "ensemble_size", 1) - if ensemble_size > 1: - ensemble_std = preds.std(dim=2).mean() - self.log(f"{split}_ensemble_std", ensemble_std, on_epoch=True, sync_dist=True) + if ensemble_std is not None: + self.log(f"{split}/ensemble_std", ensemble_std, on_epoch=True, sync_dist=True) if ( split == "train" @@ -577,7 +578,7 @@ def compute_loss(self, batch: dict[str, torch.Tensor], split: str = "train") -> input_latents = self.autoencoder.encode(batch["input"]) target_latents = self.autoencoder.encode(batch["target"]) loss = self.loss_fn(input_latents, target_latents) - self.log(f"{split}_loss", loss, prog_bar=True, on_epoch=True, on_step=(split == "train"), sync_dist=True) + self.log(f"{split}/loss", loss, prog_bar=True, on_epoch=True, on_step=(split == "train"), sync_dist=True) return loss @property diff --git a/src/mlcast/modules/reconstruction.py b/src/mlcast/modules/reconstruction.py index f23bab3..fc302bd 100644 --- a/src/mlcast/modules/reconstruction.py +++ b/src/mlcast/modules/reconstruction.py @@ -121,7 +121,7 @@ def shared_step(self, batch: torch.Tensor, split: str = "train") -> torch.Tensor self.log_dict( log_dict, prog_bar=False, logger=True, on_step=(split == "train"), on_epoch=True, sync_dist=True ) - 
self.log(f"{split}_loss", loss, prog_bar=True, on_epoch=True, on_step=(split == "train"), sync_dist=True) + self.log(f"{split}/rec_loss", loss, prog_bar=True, on_epoch=True, on_step=(split == "train"), sync_dist=True) return loss def training_step(self, batch: torch.Tensor, _batch_idx: int) -> torch.Tensor: @@ -190,5 +190,5 @@ def configure_optimizers(self) -> Any: if self.lr_scheduler_factory is not None: lr_scheduler = self.lr_scheduler_factory(optimizer) - return {"optimizer": optimizer, "lr_scheduler": {"scheduler": lr_scheduler, "monitor": "val_loss"}} + return {"optimizer": optimizer, "lr_scheduler": {"scheduler": lr_scheduler, "monitor": "val/rec_loss"}} return {"optimizer": optimizer} diff --git a/src/mlcast/visualization.py b/src/mlcast/visualization.py index f1b11cd..6bd0759 100644 --- a/src/mlcast/visualization.py +++ b/src/mlcast/visualization.py @@ -96,10 +96,11 @@ def log_images( preds_avg = preds_sample.mean(dim=1, keepdim=True) num_members_to_log = min(3, preds_sample.shape[1]) rows = [future_sample.unsqueeze(1), preds_avg] + [preds_sample[:, i : i + 1] for i in range(num_members_to_log)] + all_frames = torch.cat(rows, dim=0).squeeze(1) else: - rows = [future_sample, preds_sample] + rows = [future_sample, preds_sample.squeeze(1)] + all_frames = torch.cat(rows, dim=0) - all_frames = torch.cat(rows, dim=0) all_frames_norm = (all_frames + 1) / 2 all_frames_rgb = apply_radar_colormap(all_frames_norm) preds_grid = torchvision.utils.make_grid(all_frames_rgb, nrow=future_sample.shape[0]) diff --git a/tests/test_readme_snippets.py b/tests/test_readme_snippets.py index c74cbec..450b486 100644 --- a/tests/test_readme_snippets.py +++ b/tests/test_readme_snippets.py @@ -13,7 +13,7 @@ import fiddle as fdl import pytest -from mlcast.config.fiddlers import set_variables, use_random_sampler +from mlcast.config.fiddlers import _iter_experiment_configs, set_variables, use_random_sampler _README = Path(__file__).parent.parent / "README.md" @@ -111,24 +111,43 @@ def 
_patch_cfg(cfg: fdl.Config, fp_dataset: Path, tmp_path: Path) -> None: Uses the ``set_variables`` fiddler (rather than direct assignment) so that ``network.input_channels`` is kept in sync with ``standard_names``. + Handles both flat ``Experiment`` configs and nested containers like + ``LDCastTrainingExperiment`` by finding all ``Experiment`` sub-configs + in the tree and patching each one. + Parameters ---------- cfg : fdl.Config - The Fiddle configuration graph to mutate in-place. + The Fiddle configuration to mutate in-place. + fp_dataset : Path + Local path to the cached test zarr store. + tmp_path : Path + Pytest-provided temporary directory for trainer outputs. + """ + for exp_cfg in _iter_experiment_configs(cfg): + _patch_single_experiment(exp_cfg, fp_dataset, tmp_path) + + +def _patch_single_experiment(exp_cfg: fdl.Config, fp_dataset: Path, tmp_path: Path) -> None: + """Apply lightweight training overrides to a single ``Experiment`` config. + + Parameters + ---------- + exp_cfg : fdl.Config + A single ``Experiment`` config node (has ``data``, ``trainer``). fp_dataset : Path Local path to the cached test zarr store. tmp_path : Path Pytest-provided temporary directory for trainer outputs. """ - cfg.data.sequence_dataset_factory.zarr_path = str(fp_dataset.absolute()) - set_variables(cfg, standard_names=["rainfall_flux"]) - # Switch to the on-the-fly random sampler so no pre-computed CSV is needed. 
- use_random_sampler(cfg) - cfg.data.splits = {"time": {"train": 0.4, "val": 0.3, "test": 0.3}} - cfg.trainer.fast_dev_run = True - cfg.data.batch_size = 1 - cfg.data.num_workers = 0 - cfg.trainer.default_root_dir = str(tmp_path) + exp_cfg.data.sequence_dataset_factory.zarr_path = str(fp_dataset.absolute()) + set_variables(exp_cfg, standard_names=["rainfall_flux"]) + use_random_sampler(exp_cfg) + exp_cfg.data.splits = {"time": {"train": 0.4, "val": 0.3, "test": 0.3}} + exp_cfg.trainer.fast_dev_run = True + exp_cfg.data.batch_size = 1 + exp_cfg.data.num_workers = 0 + exp_cfg.trainer.default_root_dir = str(tmp_path) def _inject_patch(snippet: str) -> ast.Module: From 5a01f5cb4201fff7ef84407a2eb4ce2e67c5811b Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 4 Jun 2026 01:48:53 +0200 Subject: [PATCH 26/34] docs: add ldcast-notes.md documenting DMI/Martinbo implementation differences --- ldcast-notes.md | 67 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 ldcast-notes.md diff --git a/ldcast-notes.md b/ldcast-notes.md new file mode 100644 index 0000000..395ffba --- /dev/null +++ b/ldcast-notes.md @@ -0,0 +1,67 @@ +# LDCast implementation notes + +Architecture decisions and differences between reference implementations. + +## EMA scope + +| Reference | Scope | Decay | +|-----------|-------|-------| +| **DMI** | Full `LatentDiffusionNet` (conditioner + denoiser + scheduler buffers) | `0.9999` | +| **Martinbo** | Denoiser submodule only (`diffusion_net.denoiser`) | `0.9999` | +| **Ours** | Full `LatentDiffusionNet` (matches DMI) | `0.9999` | + +**Rationale for full-network EMA**: The conditioner is a single-pass feed-forward network +called once per sample, so weight noise matters less than in the denoiser. However, there +is no downside to smoothing it too, and it keeps the code simpler (EMA wraps the entire +diffusion net rather than reaching into a private submodule). 
Full-network EMA is the +standard practice in DDPM, Stable Diffusion, and DMI's reference. + +If denoiser-only EMA were desired in the future, the change is in +`forecasting.py:514`: + +```python +# Full network (current, matches DMI): +self.ema = ExponentialMovingAverage(diffusion_net, decay=ema_decay) + +# Denoiser only (Martinbo): +self.ema = ExponentialMovingAverage(diffusion_net.denoiser, decay=ema_decay) +``` + +## Optimizer + +| Reference | Type | Betas | Weight decay | Autoencoder LR | Diffusion LR | +|-----------|------|-------|-------------|----------------|--------------| +| **DMI** | `AdamW` | `(0.5, 0.9)` | `1e-3` | `1e-3` | `1e-4` | +| **Martinbo** | `AdamW` | `[0.5, 0.9]` | `0.001` | `1e-3` | `1e-4` | +| **Ours** | `AdamW` | `(0.5, 0.9)` | `1e-3` | `1e-3` | `1e-4` | + +Both references agree on all optimizer settings. + +## LR scheduler + +`ReduceLROnPlateau(factor=0.25, patience=3)` for both stages. Monitor metric naming +differs: DMI uses `val_rec_loss` / `val_loss_ema`, Martinbo uses `val/rec_loss` / +`val/loss` (TensorBoard convention). Ours follows Martinbo's naming. + +## Diffusion noise schedule + +Both references use `timesteps=1000` with `beta_start=1e-4, beta_end=2e-2`. +`DiffusionScheduler` defaults already match these beta bounds. + +## Monitor metric naming + +DMI uses underscores (`val_rec_loss`), Martinbo uses TensorBoard-style slashes +(`val/rec_loss`). We follow Martinbo / TensorBoard convention — slashes give +automatic grouping in the TensorBoard UI. + +## Early stopping + +DMI and Martinbo both use `patience=6`. Martinbo adds `check_finite=False` on the +diffusion stage. We follow both. 
+ +## Batch size + +The reference implementations disagree on batch size; ours follows DMI: +- DMI: `batch_size=4` (autoencoder) / `1` (diffusion) — example configs +- Martinbo: `batch_size=1` for both stages +- Ours: `batch_size=4` (autoencoder) / `1` (diffusion) — matches DMI From b1241d569c7898712f72a1ca088cdddc9807bcb3 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 4 Jun 2026 01:51:20 +0200 Subject: [PATCH 27/34] refactor: align LDCast config with DMI/Martinbo reference - Switch both stages to AdamW with betas=(0.5, 0.9), weight_decay=1e-3 - Raise autoencoder LR to 1e-3, keep diffusion LR at 1e-4 - Reduce LR scheduler factor to 0.25, patience to 3 - Increase EMA decay to 0.9999 - Increase diffusion timesteps to 1000 with default beta range - Reduce early stopping patience to 6, add check_finite=False on stage2 - Increase model checkpoint save_top_k to 3 - Reduce batch sizes (stage1: 16->4, stage2: 8->1) - Add accumulate_grad_batches=2 to both stages --- docs/ldcast_config_diagram.svg | 1792 +++++++++++++------------ src/mlcast/config/archetype/ldcast.py | 26 +- 2 files changed, 942 insertions(+), 876 deletions(-) diff --git a/docs/ldcast_config_diagram.svg b/docs/ldcast_config_diagram.svg index bee8c14..a2cee39 100644 --- a/docs/ldcast_config_diagram.svg +++ b/docs/ldcast_config_diagram.svg @@ -4,986 +4,1050 @@ - + %3 - + 4 - - -Config: - Encoder - - -input_channels - -1 - - -hidden_channels - -16 - - -latent_channels - -32 - - -num_blocks - -2 + + +Config: + Encoder + + +input_channels + +1 + + +hidden_channels + +16 + + +latent_channels + +32 + + +num_blocks + +2 3 - - -Config: - AutoencoderNet - - -encoder - - - - - -decoder - - - + + +Config: + AutoencoderNet + + +encoder + + + + + +decoder + + + 3:c--4:c - + 5 - - -Config: - Decoder - - -output_channels - -1 - - -hidden_channels - -16 - - -latent_channels - -32 - - -num_blocks - -2 + + +Config: + Decoder + + +output_channels + +1 + + +hidden_channels + +16 + + +latent_channels + +32 + + +num_blocks + +2 3:c--5:c - + 
2 - - -Config: - ReconstructionTaskModule - - -network - - - - - -loss_class - -'mse' - - -optimizer - - - - - -lr_scheduler - - - + + +Config: + ReconstructionTaskModule + + +network + + + + + +loss_class + +'mse' + + +optimizer + + + + + +lr_scheduler + + + 2:c--3:c - + - + 6 - - -Partial: - Adam - - -lr - -0.0001 - - -fused - -True + + +Partial: + AdamW + + +lr + +0.001 + + +betas + + + + + +weight_decay + +0.001 + + +fused + +True - + 2:c--6:c - + + + + +8 + + +Partial: + ReduceLROnPlateau + + +mode + +'min' + + +factor + +0.25 + + +patience + +3 + + + +2:c--8:c + - + 7 - - -Partial: - ReduceLROnPlateau - - -mode - -'min' - - -factor - -0.5 - - -patience - -10 - - - -2:c--7:c - + + +tuple + +0.5 + +0.9 + + +0 + + +1 + + + +6:c--7:c + - + 1 - - -Config: - Experiment - - -pl_module - - - - - -data - - - - - -trainer - - - + + +Config: + Experiment + + +pl_module + + + + + +data + + + + + +trainer + + + - + 1:c--2:c - - - - -8 - - -Config: - ReconstructionDataModule - - -sequence_dataset_factory - - - - - -input_steps - -4 - - -splits - - - -dict - - -'time' - - - -dict - - -'train' - -0.7 - - -'val' - -0.15 - - -'test' - -0.15 - - -batch_size - -16 - - -num_workers - -8 - - -pin_memory - -True - - - -1:c--8:c - - - - -10 - - -Config: - Trainer - - -accelerator - -'auto' - - -logger - - - - - -callbacks - - - -list - - - - - - - - - - - -0 - - -1 - - -2 - - -max_epochs - -20 - - - -1:c--10:c - + - + 9 - - -Partial: - SourceDataPrecomputedSequenceDataset - - -zarr_path - -'./data/radar.zarr' - - -csv_path - -'./data/sampled_datacubes.csv' - - -standard_names - - - -list - -'rainfall_rate' - - -0 - - -sequence_steps - -16 - - -deterministic - -False - - - -8:c--9:c - + + +Config: + ReconstructionDataModule + + +sequence_dataset_factory + + + + + +input_steps + +4 + + +splits + + + +dict + + +'time' + + + +dict + + +'train' + +0.7 + + +'val' + +0.15 + + +'test' + +0.15 + + +batch_size + +4 + + +num_workers + +8 + + +pin_memory + +True + + + +1:c--9:c + - + 11 - - 
-Config: - TensorBoardLogger - - -save_dir - -'logs' - - -name - -'mlcast_ldcast_stage1' - - - -10:c--11:c - + + +Config: + Trainer + + +accelerator + +'auto' + + +logger + + + + + +callbacks + + + +list + + + + + + + + + + + +0 + + +1 + + +2 + + +max_epochs + +20 + + +accumulate_grad_batches + +2 + + + +1:c--11:c + + + + +10 + + +Partial: + SourceDataPrecomputedSequenceDataset + + +zarr_path + +'./data/radar.zarr' + + +csv_path + +'./data/sampled_datacubes.csv' + + +standard_names + + + +list + +'rainfall_rate' + + +0 + + +sequence_steps + +16 + + +deterministic + +False + + + +9:c--10:c + - + 12 - - -Config: - ModelCheckpoint - - -monitor - -'val/rec_loss' - - -save_top_k - -1 - - -mode - -'min' - - + + +Config: + TensorBoardLogger + + +save_dir + +'logs' + + +name + +'mlcast_ldcast_stage1' + + -10:c--12:c - +11:c--12:c + 13 - - -Config: - EarlyStopping - - -monitor - -'val/rec_loss' - - -patience - -20 - - -mode - -'min' - - + + +Config: + ModelCheckpoint + + +monitor + +'val/rec_loss' + + +save_top_k + +3 + + +mode + +'min' + + -10:c--13:c - +11:c--13:c + 14 - - -Config: - LearningRateMonitor - - -logging_interval - -'step' - - + + +Config: + EarlyStopping + + +monitor + +'val/rec_loss' + + +patience + +6 + + +mode + +'min' + + -10:c--14:c - +11:c--14:c + - + +15 + + +Config: + LearningRateMonitor + + +logging_interval + +'step' + + + +11:c--15:c + + + + 0 - - -Config: - LDCastTrainingExperiment - - -stage1 - - - - - -stage2 - - - + + +Config: + LDCastTrainingExperiment + + +stage1 + + + + + +stage2 + + + - + 0:c--1:c - - - - -15 - - -Config: - Experiment - - -pl_module - - - - - -data - - - - - -trainer - - - - - - -0:c--15:c - + - + 16 - - -Config: - LatentDiffusionTaskModule - - -autoencoder - - - - - -diffusion_net - - - - - -forecast_steps - -12 - - -ensemble_size - -2 - - -optimizer - - - - - -lr_scheduler - - - - - -ema_decay - -0.999 - - - -16:c--3:c - + + +Config: + Experiment + + +pl_module + + + + + +data + + + + + +trainer + + + + + + +0:c--16:c + - 
+ 17 - - -Config: - LatentDiffusionNet - - -conditioner - - - - - -denoiser - - - - - -scheduler - - - + + +Config: + LatentDiffusionTaskModule + + +autoencoder + + + + + +diffusion_net + + + + + +forecast_steps + +12 + + +ensemble_size + +2 + + +optimizer + + + + + +lr_scheduler + + + + + +ema_decay + +0.9999 - - -16:c--17:c - + + +17:c--3:c + - - -21 - - -Partial: - Adam - - -lr - -0.0001 - - -fused - -True - - + + +18 + + +Config: + LatentDiffusionNet + + +conditioner + + + + + +denoiser + + + + + +scheduler + + + + + -16:c--21:c - +17:c--18:c + 22 - - -Partial: - ReduceLROnPlateau - - -mode - -'min' - - -factor - -0.5 - - -patience - -10 - - - -16:c--22:c - + + +Partial: + AdamW + + +lr + +0.0001 + + +betas + + + + + +weight_decay + +0.001 + + +fused + +True - - -18 - - -Config: - ConditionerNet - - -latent_channels - -32 - - -hidden_channels - -32 - - -num_blocks - -2 + + +17:c--22:c + - - -17:c--18:c - + + +23 + + +Partial: + ReduceLROnPlateau + + +mode + +'min' + + +factor + +0.25 + + +patience + +3 + + + +17:c--23:c + - + 19 - - -Config: - DenoiserUNet - - -latent_channels - -32 - - -condition_channels - -32 - - -hidden_channels - -32 - - -num_blocks - -2 - - + + +Config: + ConditionerNet + + +latent_channels + +32 + + +hidden_channels + +32 + + +num_blocks + +2 + + -17:c--19:c - +18:c--19:c + 20 - - -Config: - DiffusionScheduler - - -timesteps - -20 - - + + +Config: + DenoiserUNet + + +latent_channels + +32 + + +condition_channels + +32 + + +hidden_channels + +32 + + +num_blocks + +2 + + -17:c--20:c - +18:c--20:c + - - -15:c--16:c - + + +21 + + +Config: + DiffusionScheduler + + +timesteps + +1000 - - -23 - - -Config: - ForecastingDataModule - - -sequence_dataset_factory - - - - - -input_steps - -4 - - -forecast_steps - -12 - - -return_mask - -False - - -splits - - - -dict - - -'time' - - - -dict - - -'train' - -0.7 - - -'val' - -0.15 - - -'test' - -0.15 - - -batch_size - -8 - - -num_workers - -8 - - -pin_memory - -True - - + + +18:c--21:c + + + + 
+22:c--7:c + + + -15:c--23:c - +16:c--17:c + - + 24 - - -Config: - Trainer - - -accelerator - -'auto' - - -logger - - - - - -callbacks - - - -list - - - - - - - - - - - -0 - - -1 - - -2 - - -max_epochs - -20 - - - -15:c--24:c - + + +Config: + ForecastingDataModule + + +sequence_dataset_factory + + + + + +input_steps + +4 + + +forecast_steps + +12 + + +return_mask + +False + + +splits + + + +dict + + +'time' + + + +dict + + +'train' + +0.7 + + +'val' + +0.15 + + +'test' + +0.15 + + +batch_size + +1 + + +num_workers + +8 + + +pin_memory + +True - - -23:c--9:c - + + +16:c--24:c + - + 25 - - -Config: - TensorBoardLogger - - -save_dir - -'logs' - - -name - -'mlcast_ldcast_stage2' - - + + +Config: + Trainer + + +accelerator + +'auto' + + +logger + + + + + +callbacks + + + +list + + + + + + + + + + + +0 + + +1 + + +2 + + +max_epochs + +20 + + +accumulate_grad_batches + +2 + + + +16:c--25:c + + + -24:c--25:c - +24:c--10:c + - + 26 - - -Config: - ModelCheckpoint - - -monitor - -'val/loss' - - -save_top_k - -1 - - -mode - -'min' - - - -24:c--26:c - + + +Config: + TensorBoardLogger + + +save_dir + +'logs' + + +name + +'mlcast_ldcast_stage2' + + + +25:c--26:c + 27 - - -Config: - EarlyStopping - - -monitor - -'val/loss' - - -patience - -20 - - -mode - -'min' - - - -24:c--27:c - + + +Config: + ModelCheckpoint + + +monitor + +'val/loss' + + +save_top_k + +3 + + +mode + +'min' + + + +25:c--27:c + 28 - - -Config: - LearningRateMonitor - - -logging_interval - -'step' - - - -24:c--28:c - + + +Config: + EarlyStopping + + +monitor + +'val/loss' + + +patience + +6 + + +mode + +'min' + + +check_finite + +False + + + +25:c--28:c + + + + +29 + + +Config: + LearningRateMonitor + + +logging_interval + +'step' + + + +25:c--29:c + diff --git a/src/mlcast/config/archetype/ldcast.py b/src/mlcast/config/archetype/ldcast.py index 1f96640..c059586 100644 --- a/src/mlcast/config/archetype/ldcast.py +++ b/src/mlcast/config/archetype/ldcast.py @@ -87,22 +87,23 @@ def ldcast_training_experiment() -> 
LDCastTrainingExperiment: sequence_dataset_factory=sequence_dataset_factory, input_steps=input_steps, splits={"time": {"train": 0.70, "val": 0.15, "test": 0.15}}, - batch_size=16, + batch_size=4, num_workers=8, pin_memory=True, ) stage1_module = ReconstructionTaskModule( network=autoencoder, loss_class="mse", - optimizer=fdl.Partial(torch.optim.Adam, lr=1e-4, fused=True), - lr_scheduler=fdl.Partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode="min", factor=0.5, patience=10), + optimizer=fdl.Partial(torch.optim.AdamW, lr=1e-3, betas=(0.5, 0.9), weight_decay=1e-3, fused=True), + lr_scheduler=fdl.Partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode="min", factor=0.25, patience=3), ) stage1_trainer = pl.Trainer( accelerator="auto", max_epochs=20, + accumulate_grad_batches=2, callbacks=[ - ModelCheckpoint(monitor="val/rec_loss", save_top_k=1, mode="min"), - EarlyStopping(monitor="val/rec_loss", patience=20, mode="min"), + ModelCheckpoint(monitor="val/rec_loss", save_top_k=3, mode="min"), + EarlyStopping(monitor="val/rec_loss", patience=6, mode="min"), LearningRateMonitor(logging_interval="step"), ], logger=TensorBoardLogger(save_dir="logs", name="mlcast_ldcast_stage1"), @@ -114,30 +115,31 @@ def ldcast_training_experiment() -> LDCastTrainingExperiment: forecast_steps=forecast_steps, return_mask=False, splits={"time": {"train": 0.70, "val": 0.15, "test": 0.15}}, - batch_size=8, + batch_size=1, num_workers=8, pin_memory=True, ) diffusion_net = LatentDiffusionNet( conditioner=ConditionerNet(latent_channels=32, hidden_channels=32, num_blocks=2), denoiser=DenoiserUNet(latent_channels=32, condition_channels=32, hidden_channels=32, num_blocks=2), - scheduler=DiffusionScheduler(timesteps=20), + scheduler=DiffusionScheduler(timesteps=1000), ) stage2_module = LatentDiffusionTaskModule( autoencoder=autoencoder, diffusion_net=diffusion_net, forecast_steps=forecast_steps, ensemble_size=2, - optimizer=fdl.Partial(torch.optim.Adam, lr=1e-4, fused=True), - 
lr_scheduler=fdl.Partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode="min", factor=0.5, patience=10), - ema_decay=0.999, + optimizer=fdl.Partial(torch.optim.AdamW, lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-3, fused=True), + lr_scheduler=fdl.Partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode="min", factor=0.25, patience=3), + ema_decay=0.9999, ) stage2_trainer = pl.Trainer( accelerator="auto", max_epochs=20, + accumulate_grad_batches=2, callbacks=[ - ModelCheckpoint(monitor="val/loss", save_top_k=1, mode="min"), - EarlyStopping(monitor="val/loss", patience=20, mode="min"), + ModelCheckpoint(monitor="val/loss", save_top_k=3, mode="min"), + EarlyStopping(monitor="val/loss", patience=6, mode="min", check_finite=False), LearningRateMonitor(logging_interval="step"), ], logger=TensorBoardLogger(save_dir="logs", name="mlcast_ldcast_stage2"), From 8990c4771d39c24d5a05defc5c27643289f79c80 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 4 Jun 2026 01:52:01 +0200 Subject: [PATCH 28/34] docs: explain why EMA is used for latent diffusion in docstring --- src/mlcast/modules/forecasting.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/mlcast/modules/forecasting.py b/src/mlcast/modules/forecasting.py index 91f0ae4..fcdfd83 100644 --- a/src/mlcast/modules/forecasting.py +++ b/src/mlcast/modules/forecasting.py @@ -486,7 +486,13 @@ class LatentDiffusionTaskModule(BaseForecastingTaskModule): Learning-rate scheduler factory. Default is ``None``. ema_decay : float or None, optional If provided, track an exponential moving average of diffusion-net - parameters with this decay. Default is ``None``. + parameters with this decay (commonly ``0.999`` or ``0.9999``). EMA- + smoothed weights are swapped in during validation, testing, and + prediction — the raw weights receive gradient updates during training. 
+ This is standard practice in diffusion models: the iterative denoising + process amplifies small weight fluctuations, and EMA averaging + suppresses that noise for cleaner samples at inference time. Default is + ``None`` (no EMA). """ def __init__( From 2d151a747fac81349547ff1d4016d0f14da032b9 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 4 Jun 2026 01:53:33 +0200 Subject: [PATCH 29/34] =?UTF-8?q?docs:=20simplify=20LDCast=20README=20snip?= =?UTF-8?q?pet=20=E2=80=94=20single=20use=5Frandom=5Fsampler(cfg)=20covers?= =?UTF-8?q?=20both=20stages=20via=20@applies=5Fto=5Fexperiments?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index a1accec..91cfe31 100644 --- a/README.md +++ b/README.md @@ -214,9 +214,8 @@ from mlcast.config.fiddlers import use_random_sampler cfg = ldcast_training_experiment.as_buildable() -# Both stages share the same dataset, so switch to random sampling once -use_random_sampler(cfg.stage1) -use_random_sampler(cfg.stage2) +# Applied once — @applies_to_experiments walks both stages automatically +use_random_sampler(cfg) # Override the diffusion noise schedule cfg.stage2.pl_module.diffusion_net.scheduler.timesteps = 20 From 75a54be6ee715012b7c0e8f165d8b7989f6d48c2 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 4 Jun 2026 01:59:55 +0200 Subject: [PATCH 30/34] docs: fix HalfUNet example jaxtyping, einops, and ensemble dim - Use codebase convention for jaxtyping dim names (time/channels/height/width) - Replace unsqueeze/cat with einops.rearrange for ensemble dim - Remove redundant/incorrect shape comment (pattern is self-documenting) - Clean up input_channels property docstring --- README.md | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 91cfe31..15b1a5e 100644 --- a/README.md +++ b/README.md @@ -271,26 +271,22 @@ class 
HalfUNetNowcaster(nn.Module): @property def input_channels(self) -> int: - # Externally, the HalfUNetNowcaster respects the required input shape structure - # (batch, input_steps, num_vars, H, W), even though the internal U-Net is channel-stacked. - # Adding this property allows the config consistency checks to verify that - # the dataset and model agree on the expected number of input channels. + # Externally the model handles (batch, time, channels, height, width); + # internally the U-Net channel-stacks time into (batch, time*channels, ...). + # This property lets config consistency checks verify dataset-model agreement. return self.num_vars def forward( self, - x: Float[torch.Tensor, "batch input_steps in_channels H W"], - ) -> Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels H W"]: - # channel-stack all input frames: (b, t, c, h, w) -> (b, t*c, h, w) + x: Float[torch.Tensor, "batch time channels height width"], + ) -> Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels height width"]: x_flat = einops.rearrange(x, "b t c h w -> b (t c) h w") preds = [] for _ in range(self.forecast_steps): - y = self.unet(x_flat) # [B, num_vars, H, W] - preds.append(y.unsqueeze(1)) - # slide window: drop the oldest timestep (first num_vars channels), - # append the latest prediction as the newest timestep + y = self.unet(x_flat) + preds.append(y) x_flat = torch.cat([x_flat[:, self.num_vars:], y], dim=1) - return torch.cat(preds, dim=1).unsqueeze(2) + return einops.rearrange(torch.stack(preds, dim=1), "b t c h w -> b t 1 c h w") cfg = convgru_training_experiment.as_buildable() use_random_sampler(cfg) From 3e48d33984514ee83c41aa175682ca564fdaf566 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 4 Jun 2026 02:06:47 +0200 Subject: [PATCH 31/34] =?UTF-8?q?docs:=20update=20plan=20=E2=80=94=20mark?= =?UTF-8?q?=20section=208=20complete,=20add=20future=20work=20section=209?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit --- ldcast-refactor-plan.md | 129 +++++++++++++++++++++++++++++++++++----- 1 file changed, 113 insertions(+), 16 deletions(-) diff --git a/ldcast-refactor-plan.md b/ldcast-refactor-plan.md index 9a584fb..6362e40 100644 --- a/ldcast-refactor-plan.md +++ b/ldcast-refactor-plan.md @@ -87,38 +87,38 @@ - [x] Update docs and scripts that still reference `training_experiment`, including `README.md` and `docs/generate_base_experiment_config_diagram.py`. - [x] Keep existing ConvGRU CLI/config tests passing while adding separate tests for selecting the LDCast config explicitly. - [ ] Add real but small-scale end-to-end tests with generated sample data for the autoencoder stage, diffusion stage, and full LDCast stage sequencing. -- [ ] Align metric naming with TensorBoard conventions: use `/` as hierarchy separator and `rec_loss`/`loss` to distinguish stages — done (commit pending) +- [x] Align metric naming with TensorBoard conventions: use `/` as hierarchy separator and `rec_loss`/`loss` to distinguish stages. ## 8. 
Align LDCast config with DMI/Martinbo reference ### Optimizer -- [ ] Switch both stages to `AdamW` (from `Adam`) -- [ ] Set `betas=(0.5, 0.9)` for both stages -- [ ] Set `weight_decay=1e-3` for both stages -- [ ] Raise autoencoder LR to `1e-3`, keep diffusion LR at `1e-4` +- [x] Switch both stages to `AdamW` (from `Adam`) +- [x] Set `betas=(0.5, 0.9)` for both stages +- [x] Set `weight_decay=1e-3` for both stages +- [x] Raise autoencoder LR to `1e-3`, keep diffusion LR at `1e-4` ### LR scheduler -- [ ] Reduce `ReduceLROnPlateau` factor to `0.25` (from `0.5`) -- [ ] Reduce patience to `3` (from `10`) +- [x] Reduce `ReduceLROnPlateau` factor to `0.25` (from `0.5`) +- [x] Reduce patience to `3` (from `10`) ### EMA -- [ ] Increase EMA decay to `0.9999` (from `0.999`) -- [ ] Decide: EMA on full diffusion net (DMI) or denoiser only (Martinbo) +- [x] Increase EMA decay to `0.9999` (from `0.999`) +- [x] Decide: EMA on full diffusion net (DMI) or denoiser only (Martinbo) — chose full diffusion net (DMI) ### Early stopping -- [ ] Reduce patience to `6` (from `20`) -- [ ] Set `check_finite=False` on the diffusion stage +- [x] Reduce patience to `6` (from `20`) +- [x] Set `check_finite=False` on the diffusion stage ### Model checkpointing -- [ ] Increase `save_top_k` to `3` (from `1`) +- [x] Increase `save_top_k` to `3` (from `1`) ### Diffusion noise schedule -- [ ] Increase `timesteps` to `1000` (from `20`) -- [ ] Set linear beta schedule from `1e-4` to `2e-2` +- [x] Increase `timesteps` to `1000` (from `20`) +- [x] Set linear beta schedule from `1e-4` to `2e-2` ### Batch size and gradient accumulation -- [ ] Reduce batch sizes (e.g. `batch_size=4` autoenc / `batch_size=1` diffusion) -- [ ] Add `accumulate_grad_batches=2` +- [x] Reduce batch sizes (e.g. `batch_size=4` autoenc / `batch_size=1` diffusion) +- [x] Add `accumulate_grad_batches=2` The `ldcast-dmi/` reference implementation differs from our current `ldcast_training_experiment` config in several ways. 
Changes below would @@ -261,3 +261,100 @@ the DMI reference in several ways. - **Ours**: `parameterization="eps"` in `DiffusionLoss` (L2 via `nn.MSELoss` reduction). - **To align**: All three agree on `eps` + MSE — no change needed. + + +## 9. Possible future work + +Architecture and feature upgrades not covered by section 8 (config alignment). +Each entry explains why it might be worth doing. + +### VAE autoencoder (KL-regularised latent space) + +DMI uses `AutoencoderKL` — a VAE trained with a KL-divergence loss on the latent +distribution, producing a smoother, more Gaussian latent space. Ours uses a +deterministic autoencoder. + +**Why it might be a good idea**: Diffusion models start from Gaussian noise and +reverse-diffuse towards the data distribution, so they work best when that +distribution is well-scaled and smooth. A KL-regularised latent space is closer +to a unit Gaussian, which can make diffusion easier and improve sample quality. +It also enables latent-space interpolation and manipulation. However, it adds +training complexity (KL loss weighting, posterior collapse risk). + +### Larger denoiser (3D UNet with cross-attention) + +DMI's `UNetModel` uses 128 model channels, attention blocks at multiple +resolutions, 8 attention heads, and 3D convolutions over (time, height, width). +Ours uses 32 hidden channels, no attention, and preserves time as a plain +channel dimension. + +**Why it might be a good idea**: More capacity → better fit to complex +precipitation patterns. Cross-attention allows the denoiser to selectively +attend to conditioning context at each resolution, which is more expressive +than our simple input concatenation. The cost is larger memory footprint and +longer training times. + +### Multi-resolution context encoder (AFNO cascade) + +DMI's `AFNONowcastNetCascade` produces a feature pyramid where each spatial +resolution has its own channel depth. The denoiser selects the appropriate level +based on its current spatial resolution. 
Ours uses a single-resolution +`ConditionerNet` that is interpolated if spatial sizes don't match. + +**Why it might be a good idea**: Multi-resolution conditioning lets the denoiser +access fine-grained local information at high resolutions and broad context at +low resolutions simultaneously. This is standard in modern conditional diffusion +models (e.g. Stable Diffusion's cross-attention to CLIP embeddings at multiple +scales). + +### PLMS accelerated sampling + +DMI's `PLMSSampler` uses Adams-Bashforth multistep integration to reduce 1000 +DDPM steps to ~50 with minimal quality loss. Ours uses a basic ancestral DDPM +sampler that iterates over all timesteps. + +**Why it might be a good idea**: 20× faster inference is critical for +operational nowcasting where latency matters. PLMS is well-established (from +CompVis latent-diffusion) and requires no retraining — it works with any +trained eps-predicting model. + +### Adaptive EMA decay (LitEma-style) + +DMI uses `LitEma` with adaptive decay `min(0.9999, (1+n)/(10+n))` where `n` is +the number of EMA updates. Ours uses a fixed decay of `0.9999`. + +**Why it might be a good idea**: Adaptive decay starts lower (giving more weight +to recent parameters early in training when they change fastest) and converges +to 0.9999. This accelerates early training while maintaining the benefits of EMA +at convergence. Simple to implement — just change the decay formula in `update`. + +### Multiple beta schedules (cosine, sqrt) + +DMI supports linear, cosine, sqrt_linear, and sqrt schedules. Ours supports +linear only. + +**Why it might be a good idea**: Cosine schedules (from Nichol & Dhariwal, +"Improved DDPM") add noise more gradually, which can improve sample quality — +especially at low resolutions or with fewer timesteps. Having multiple schedules +also enables hyperparameter search. The `DiffusionScheduler` would need to +accept a `schedule` string and dispatch to the right formula. 
+ +### x0 parameterization and L1 loss + +DMI supports predicting `x0` (the clean target) instead of `eps` (the noise), +and using L1 instead of L2 for the loss. Ours uses `eps` + L2 only. + +**Why it might be a good idea**: `x0` prediction can be more stable at high +noise levels and is required for certain sampling techniques. L1 loss tends to +produce sharper outputs (less blurring than L2), which is desirable for +precipitation fields with sharp rain/no-rain boundaries. + +### Classifier-free guidance (CFG) in the sampler + +DMI's `PLMSSampler` supports CFG via an `unconditional_guidance_scale` +parameter. Ours has no CFG support. + +**Why it might be a good idea**: CFG lets the user trade off ensemble diversity +vs. forecast fidelity at inference time by scaling the conditional prediction +away from the unconditional prediction. Higher guidance → more confident +(less diverse) forecasts. This is useful operational flexibility. From 9b207e709ba167b735a0c2f2c4aed48a814c9a80 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 4 Jun 2026 13:49:40 +0200 Subject: [PATCH 32/34] refactor: rename LDCast to latent_diffusion to reflect simplified scope Our architecture (deterministic autoencoder, small denoiser with no attention, single-resolution conditioner) is a lightweight latent diffusion model, not a true LDCast as defined by the DMI/Martinbo references. 
Renaming sets accurate expectations: - LDCastTrainingExperiment -> LatentDiffusionTrainingExperiment - ldcast_training_experiment -> latent_diffusion_experiment - ldcast.py -> latent_diffusion.py - All docs, tests, diagrams, and config references updated --- .pre-commit-config.yaml | 8 +- README.md | 38 +- ...enerate_latent_diffusion_config_diagram.py | 52 + docs/generate_ldcast_config_diagram.py | 52 - ...vg => latent_diffusion_config_diagram.svg} | 914 +++++++++--------- ldcast-refactor-plan.md | 47 +- src/mlcast/config/__init__.py | 6 +- .../{ldcast.py => latent_diffusion.py} | 18 +- src/mlcast/config/consistency_checks.py | 23 +- src/mlcast/config/fiddlers.py | 2 +- tests/config/test_fiddlers.py | 20 +- ...py => test_latent_diffusion_experiment.py} | 12 +- tests/config/test_orchestrator.py | 8 +- tests/test_readme_snippets.py | 2 +- 14 files changed, 587 insertions(+), 615 deletions(-) create mode 100644 docs/generate_latent_diffusion_config_diagram.py delete mode 100644 docs/generate_ldcast_config_diagram.py rename docs/{ldcast_config_diagram.svg => latent_diffusion_config_diagram.svg} (54%) rename src/mlcast/config/archetype/{ldcast.py => latent_diffusion.py} (92%) rename tests/config/{test_ldcast_experiment.py => test_latent_diffusion_experiment.py} (72%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b22aad4..72870b7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,11 +34,11 @@ repos: entry: uv run python docs/generate_base_experiment_config_diagram.py --check files: ^src/mlcast/config/(base|archetype/convgru)\.py$ pass_filenames: false - - id: config-diagram-ldcast - name: ldcast config diagram up to date + - id: config-diagram-latent-diffusion + name: latent diffusion config diagram up to date language: system - entry: uv run python docs/generate_ldcast_config_diagram.py --check - files: ^src/mlcast/config/archetype/ldcast\.py$ + entry: uv run python docs/generate_latent_diffusion_config_diagram.py --check + 
files: ^src/mlcast/config/archetype/latent_diffusion\.py$ pass_filenames: false ci: autoupdate_schedule: monthly diff --git a/README.md b/README.md index 15b1a5e..095046a 100644 --- a/README.md +++ b/README.md @@ -72,8 +72,8 @@ mlcast ships with two included configuration functions: - [`convgru_training_experiment`](src/mlcast/config/archetype/convgru.py) — defines a single-stage ConvGRU ensemble nowcasting setup (dataset, data module, network, Lightning module, trainer). -- [`ldcast_training_experiment`](src/mlcast/config/archetype/ldcast.py) — defines a - two-stage LDCast setup: stage 1 trains an autoencoder on reconstruction +- [`latent_diffusion_experiment`](src/mlcast/config/archetype/latent_diffusion.py) — defines a + two-stage latent diffusion setup: stage 1 trains an autoencoder on reconstruction windows, stage 2 trains a latent diffusion model on the same autoencoder's latent space. @@ -98,9 +98,9 @@ The diagrams below show the full included config graphs. ![convgru_training_experiment config graph](docs/config_diagram.svg) -**ldcast_training_experiment:** +**latent_diffusion_experiment:** -![ldcast_training_experiment config graph](docs/ldcast_config_diagram.svg) +![latent_diffusion_experiment config graph](docs/latent_diffusion_config_diagram.svg) ### Design roles @@ -132,16 +132,16 @@ Install the package and run: ```bash # Single-stage ConvGRU nowcasting mlcast train --config config:convgru_training_experiment +# Two-stage latent diffusion -# Two-stage LDCast latent diffusion -mlcast train --config config:ldcast_training_experiment +mlcast train --config config:latent_diffusion_experiment ``` All parameters are controlled via `--config` flags: | Prefix | Purpose | Example | |--------|---------|---------| -| `config:` | Select an included `@auto_config` function | `--config config:convgru_training_experiment` or `--config config:ldcast_training_experiment` | +| `config:` | Select an included `@auto_config` function | `--config 
config:convgru_training_experiment` or `--config config:latent_diffusion_experiment` | | `set:` | Override a single parameter | `--config set:data.batch_size=32` | | `fiddler:` | Apply a semantic mutator (multi-param change) | `--config fiddler:use_random_sampler` | | `path/to/config.yaml` | Load a previously saved config | `--config saved.yaml` | @@ -168,9 +168,9 @@ mlcast train \ --config logs/mlcast/version_0/config.yaml \ --config set:trainer.max_epochs=50 -# Run two-stage LDCast training with a shorter diffusion schedule -mlcast train \ - --config config:ldcast_training_experiment \ +# Run two-stage latent diffusion training with a shorter diffusion schedule +mlcast train \ + --config config:latent_diffusion_experiment \ --config set:stage2.pl_module.diffusion_net.scheduler.timesteps=20 # Inspect the fully resolved config without starting training @@ -205,14 +205,14 @@ cfg.trainer.max_epochs = 50 train_from_config(cfg) ``` -**Run the included LDCast experiment with tweaks:** +**Run the included latent diffusion experiment with tweaks:** ```python import fiddle as fdl -from mlcast.config import ldcast_training_experiment, train_from_config +from mlcast.config import latent_diffusion_experiment, train_from_config from mlcast.config.fiddlers import use_random_sampler -cfg = ldcast_training_experiment.as_buildable() +cfg = latent_diffusion_experiment.as_buildable() # Applied once — @applies_to_experiments walks both stages automatically use_random_sampler(cfg) @@ -340,7 +340,7 @@ mlcast/ │ │ ├── base.py # Experiment dataclass │ │ ├── archetype/ │ │ │ ├── convgru.py # ConvGRU training config @auto_config -│ │ │ └── ldcast.py # LDCast two-stage config @auto_config +│ │ │ └── latent_diffusion.py # Two-stage latent diffusion config @auto_config │ │ ├── fiddlers.py # Semantic config mutators │ │ ├── consistency_checks.py # Cross-parameter validation │ │ ├── loader.py # YAML config loader @@ -413,9 +413,9 @@ stacked along an explicit ensemble dimension, giving the final shape 
![ConvGruModel stochastic architecture](docs/architectures/convgru-stochastic.png) -### LatentDiffusionNet (LDCast) +### LatentDiffusionNet (two-stage latent diffusion) -LDCast is a **two-stage** latent diffusion nowcasting system. Stage 1 trains an +This is a **two-stage** latent diffusion nowcasting system. Stage 1 trains an autoencoder on reconstruction windows; stage 2 trains a latent diffusion model that forecasts in the autoencoder's latent space and decodes forecasts back to data space. @@ -423,7 +423,7 @@ data space. The architecture components live under `src/mlcast/models/autoencoder/` and `src/mlcast/models/diffusion/`. The task-level Lightning modules live under `src/mlcast/modules/` and are wired together by -[`ldcast_training_experiment`](src/mlcast/config/ldcast.py). +[`latent_diffusion_experiment`](src/mlcast/config/archetype/latent_diffusion.py). #### Stage 1 — Autoencoder reconstruction @@ -485,7 +485,7 @@ are stacked. #### Two-stage training experiment -The [`ldcast_training_experiment`](src/mlcast/config/ldcast.py) auto-config +The [`latent_diffusion_experiment`](src/mlcast/config/archetype/latent_diffusion.py) auto-config orchestrates both stages: - Stage 1 builds a `ReconstructionDataModule`, `AutoencoderNet`, and diff --git a/docs/generate_latent_diffusion_config_diagram.py b/docs/generate_latent_diffusion_config_diagram.py new file mode 100644 index 0000000..007a5ee --- /dev/null +++ b/docs/generate_latent_diffusion_config_diagram.py @@ -0,0 +1,52 @@ +"""Generate a Graphviz SVG diagram of the included latent diffusion training config. 
+ +Run without arguments to regenerate docs/latent_diffusion_config_diagram.svg: + + uv run python docs/generate_latent_diffusion_config_diagram.py + +Run with --check to verify the diagram is up to date: + + uv run python docs/generate_latent_diffusion_config_diagram.py --check +""" + +import argparse +import sys +from pathlib import Path + +import fiddle.graphviz as fgv + +from mlcast.config import latent_diffusion_experiment + +OUT = Path(__file__).parent / "latent_diffusion_config_diagram.svg" + + +def main() -> None: + """Generate or verify the latent diffusion training config diagram.""" + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--check", + action="store_true", + help="Check that the diagram is up to date rather than regenerating it.", + ) + args = parser.parse_args() + + cfg = latent_diffusion_experiment.as_buildable() + g = fgv.render(cfg, max_str_length=40) + g.format = "svg" + new_svg = g.pipe().decode() + + if args.check: + if not OUT.exists() or OUT.read_text() != new_svg: + print( + "docs/latent_diffusion_config_diagram.svg is out of date.\n" + "Run: uv run python docs/generate_latent_diffusion_config_diagram.py" + ) + sys.exit(1) + print("docs/latent_diffusion_config_diagram.svg is up to date.") + else: + OUT.write_text(new_svg) + print(f"Written {OUT}") + + +if __name__ == "__main__": + main() diff --git a/docs/generate_ldcast_config_diagram.py b/docs/generate_ldcast_config_diagram.py deleted file mode 100644 index a5686e7..0000000 --- a/docs/generate_ldcast_config_diagram.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Generate a Graphviz SVG diagram of the included LDCast training config. 
- -Run without arguments to regenerate docs/ldcast_config_diagram.svg: - - uv run python docs/generate_ldcast_config_diagram.py - -Run with --check to verify the diagram is up to date: - - uv run python docs/generate_ldcast_config_diagram.py --check -""" - -import argparse -import sys -from pathlib import Path - -import fiddle.graphviz as fgv - -from mlcast.config import ldcast_training_experiment - -OUT = Path(__file__).parent / "ldcast_config_diagram.svg" - - -def main() -> None: - """Generate or verify the LDCast training config diagram.""" - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument( - "--check", - action="store_true", - help="Check that the diagram is up to date rather than regenerating it.", - ) - args = parser.parse_args() - - cfg = ldcast_training_experiment.as_buildable() - g = fgv.render(cfg, max_str_length=40) - g.format = "svg" - new_svg = g.pipe().decode() - - if args.check: - if not OUT.exists() or OUT.read_text() != new_svg: - print( - "docs/ldcast_config_diagram.svg is out of date.\n" - "Run: uv run python docs/generate_ldcast_config_diagram.py" - ) - sys.exit(1) - print("docs/ldcast_config_diagram.svg is up to date.") - else: - OUT.write_text(new_svg) - print(f"Written {OUT}") - - -if __name__ == "__main__": - main() diff --git a/docs/ldcast_config_diagram.svg b/docs/latent_diffusion_config_diagram.svg similarity index 54% rename from docs/ldcast_config_diagram.svg rename to docs/latent_diffusion_config_diagram.svg index a2cee39..da917ad 100644 --- a/docs/ldcast_config_diagram.svg +++ b/docs/latent_diffusion_config_diagram.svg @@ -4,11 +4,11 @@ - + %3 - + 4 @@ -218,160 +218,160 @@ 1 - - -Config: - Experiment - - -pl_module - - - - - -data - - - - - -trainer - - - + + +Config: + Experiment + + +pl_module + + + + + +data + + + + + +trainer + + + 1:c--2:c - + 9 - - -Config: - ReconstructionDataModule - - -sequence_dataset_factory - - - - - -input_steps - -4 - - -splits - - - -dict - - -'time' - - - -dict - - -'train' 
- -0.7 - - -'val' - -0.15 - - -'test' - -0.15 - - -batch_size - -4 - - -num_workers - -8 - - -pin_memory - -True + + +Config: + ReconstructionDataModule + + +sequence_dataset_factory + + + + + +input_steps + +4 + + +splits + + + +dict + + +'time' + + + +dict + + +'train' + +0.7 + + +'val' + +0.15 + + +'test' + +0.15 + + +batch_size + +4 + + +num_workers + +8 + + +pin_memory + +True 1:c--9:c - + 11 - - -Config: - Trainer - - -accelerator - -'auto' - - -logger - - - - - -callbacks - - - -list - - - - - - - - - - - -0 - - -1 - - -2 - - -max_epochs - -20 - - -accumulate_grad_batches - -2 + + +Config: + Trainer + + +accelerator + +'auto' + + +logger + + + + + +callbacks + + + +list + + + + + + + + + + + +0 + + +1 + + +2 + + +max_epochs + +20 + + +accumulate_grad_batches + +2 1:c--11:c - + @@ -416,30 +416,30 @@ 9:c--10:c - + 12 - - -Config: - TensorBoardLogger + + +Config: + TensorBoardLogger save_dir - + 'logs' name - -'mlcast_ldcast_stage1' + +'mlcast_latent_diffusion_stage1' 11:c--12:c - + @@ -467,161 +467,161 @@ 11:c--13:c - + 14 - - -Config: - EarlyStopping - - -monitor - -'val/rec_loss' - - -patience - -6 - - -mode - -'min' + + +Config: + EarlyStopping + + +monitor + +'val/rec_loss' + + +patience + +6 + + +mode + +'min' 11:c--14:c - + 15 - - -Config: - LearningRateMonitor - - -logging_interval - -'step' + + +Config: + LearningRateMonitor + + +logging_interval + +'step' 11:c--15:c - + 0 - - -Config: - LDCastTrainingExperiment - - -stage1 - - - - - -stage2 - - - + + +Config: + LatentDiffusionTrainingExperiment + + +stage1 + + + + + +stage2 + + + 0:c--1:c - + 16 - - -Config: - Experiment - - -pl_module - - - - - -data - - - - - -trainer - - - + + +Config: + Experiment + + +pl_module + + + + + +data + + + + + +trainer + + + 0:c--16:c - + 17 - - -Config: - LatentDiffusionTaskModule - - -autoencoder - - - - - -diffusion_net - - - - - -forecast_steps - -12 - - -ensemble_size - -2 - - -optimizer - - - - - -lr_scheduler - - - - - -ema_decay - -0.9999 + + +Config: + 
LatentDiffusionTaskModule + + +autoencoder + + + + + +diffusion_net + + + + + +forecast_steps + +12 + + +ensemble_size + +2 + + +optimizer + + + + + +lr_scheduler + + + + + +ema_decay + +0.9999 17:c--3:c - + @@ -652,7 +652,7 @@ 17:c--18:c - + @@ -686,7 +686,7 @@ 17:c--22:c - + @@ -714,7 +714,7 @@ 17:c--23:c - + @@ -803,251 +803,251 @@ 16:c--17:c - + 24 - - -Config: - ForecastingDataModule - - -sequence_dataset_factory - - - - - -input_steps - -4 - - -forecast_steps - -12 - - -return_mask - -False - - -splits - - - -dict - - -'time' - - - -dict - - -'train' - -0.7 - - -'val' - -0.15 - - -'test' - -0.15 - - -batch_size - -1 - - -num_workers - -8 - - -pin_memory - -True + + +Config: + ForecastingDataModule + + +sequence_dataset_factory + + + + + +input_steps + +4 + + +forecast_steps + +12 + + +return_mask + +False + + +splits + + + +dict + + +'time' + + + +dict + + +'train' + +0.7 + + +'val' + +0.15 + + +'test' + +0.15 + + +batch_size + +1 + + +num_workers + +8 + + +pin_memory + +True 16:c--24:c - + 25 - - -Config: - Trainer - - -accelerator - -'auto' - - -logger - - - - - -callbacks - - - -list - - - - - - - - - - - -0 - - -1 - - -2 - - -max_epochs - -20 - - -accumulate_grad_batches - -2 + + +Config: + Trainer + + +accelerator + +'auto' + + +logger + + + + + +callbacks + + + +list + + + + + + + + + + + +0 + + +1 + + +2 + + +max_epochs + +20 + + +accumulate_grad_batches + +2 16:c--25:c - + 24:c--10:c - + 26 - - -Config: - TensorBoardLogger - - -save_dir - -'logs' - - -name - -'mlcast_ldcast_stage2' + + +Config: + TensorBoardLogger + + +save_dir + +'logs' + + +name + +'mlcast_latent_diffusion_stage2' 25:c--26:c - + 27 - - -Config: - ModelCheckpoint - - -monitor - -'val/loss' - - -save_top_k - -3 - - -mode - -'min' + + +Config: + ModelCheckpoint + + +monitor + +'val/loss' + + +save_top_k + +3 + + +mode + +'min' 25:c--27:c - + 28 - - -Config: - EarlyStopping - - -monitor - -'val/loss' - - -patience - -6 - - -mode - -'min' - - -check_finite - -False + + +Config: + 
EarlyStopping + + +monitor + +'val/loss' + + +patience + +6 + + +mode + +'min' + + +check_finite + +False 25:c--28:c - + 29 - - -Config: - LearningRateMonitor - - -logging_interval - -'step' + + +Config: + LearningRateMonitor + + +logging_interval + +'step' 25:c--29:c - + diff --git a/ldcast-refactor-plan.md b/ldcast-refactor-plan.md index 6362e40..139fd55 100644 --- a/ldcast-refactor-plan.md +++ b/ldcast-refactor-plan.md @@ -1,9 +1,9 @@ -# LDCast Refactor Plan +# Latent Diffusion Refactor Plan 0. Config naming and CLI contract - [x] Rename `training_experiment` to `convgru_training_experiment`. - [x] Do not keep `training_experiment` as an alias. -- [x] Reserve `ldcast_training_experiment` as the top-level config name for the new two-stage LDCast workflow. +- [x] Reserve `latent_diffusion_experiment` as the top-level config name for the two-stage workflow. - [x] Require `mlcast train` users to provide an explicit base config via `--config config:` or `--config /path/to/config.yaml`. - [x] Update CLI help text to list the included config entry points explicitly. - [x] Update all docs, examples, tests, and scripts to use `convgru_training_experiment` instead of `training_experiment`. @@ -68,7 +68,7 @@ - [x] Keep `modules/` for task-level Lightning modules only; keep `models/` for pure architectures. 6. Training experiment -- [x] Add a new LDCast-specific training module containing `LDCastTrainingExperiment`. +- [x] Add a new two-stage `LatentDiffusionTrainingExperiment` (initially called `LDCastTrainingExperiment`). - [x] Keep `convgru_training_experiment` as the existing ConvGRU forecasting example and one of the explicitly selected included CLI configs. - [x] Stage 1 builds the reconstruction dataset, autoencoder model, and `ReconstructionTaskModule`, then trains the autoencoder. - [x] Stage 2 reuses the same trained in-memory autoencoder instance, builds the diffusion dataset/model/`LatentDiffusionTaskModule`, then trains latent diffusion. 
@@ -82,46 +82,17 @@ - [x] Update CLI/help text in `src/mlcast/__main__.py` to require an explicit base config and list the included config entry points. - [x] Rename `training_experiment` to `convgru_training_experiment` in `src/mlcast/config/base.py` and export it from `src/mlcast/config/__init__.py`. - [x] Add the LDCast config entry point to `src/mlcast/config/__init__.py` alongside the existing ConvGRU example config. -- [x] Keep `src/mlcast/config/orchestrator.py` compatible with both the existing single-stage `Experiment` and the new `LDCastTrainingExperiment` through a common `run()` surface. -- [x] Update docstrings and comments that currently imply `training_experiment` is the only experiment, including `src/mlcast/data/source_data_datamodule.py`, `src/mlcast/config/orchestrator.py`, and related config docs. -- [x] Update docs and scripts that still reference `training_experiment`, including `README.md` and `docs/generate_base_experiment_config_diagram.py`. -- [x] Keep existing ConvGRU CLI/config tests passing while adding separate tests for selecting the LDCast config explicitly. -- [ ] Add real but small-scale end-to-end tests with generated sample data for the autoencoder stage, diffusion stage, and full LDCast stage sequencing. -- [x] Align metric naming with TensorBoard conventions: use `/` as hierarchy separator and `rec_loss`/`loss` to distinguish stages. - -## 8. 
Align LDCast config with DMI/Martinbo reference - -### Optimizer -- [x] Switch both stages to `AdamW` (from `Adam`) -- [x] Set `betas=(0.5, 0.9)` for both stages -- [x] Set `weight_decay=1e-3` for both stages -- [x] Raise autoencoder LR to `1e-3`, keep diffusion LR at `1e-4` - -### LR scheduler -- [x] Reduce `ReduceLROnPlateau` factor to `0.25` (from `0.5`) -- [x] Reduce patience to `3` (from `10`) - -### EMA -- [x] Increase EMA decay to `0.9999` (from `0.999`) -- [x] Decide: EMA on full diffusion net (DMI) or denoiser only (Martinbo) — chose full diffusion net (DMI) - -### Early stopping -- [x] Reduce patience to `6` (from `20`) -- [x] Set `check_finite=False` on the diffusion stage +- [x] Keep `src/mlcast/config/orchestrator.py` compatible with both the existing single-stage `Experiment` and `LatentDiffusionTrainingExperiment` through a common `run()` surface. -### Model checkpointing -- [x] Increase `save_top_k` to `3` (from `1`) +- [x] Keep existing ConvGRU CLI/config tests passing while adding separate tests for the two-stage config explicitly. -### Diffusion noise schedule -- [x] Increase `timesteps` to `1000` (from `20`) -- [x] Set linear beta schedule from `1e-4` to `2e-2` +- [ ] Add real but small-scale end-to-end tests with generated sample data for the autoencoder stage, diffusion stage, and full two-stage sequencing. +- [x] Align metric naming with TensorBoard conventions: use `/` as hierarchy separator and `rec_loss`/`loss` to distinguish stages. -### Batch size and gradient accumulation -- [x] Reduce batch sizes (e.g. `batch_size=4` autoenc / `batch_size=1` diffusion) -- [x] Add `accumulate_grad_batches=2` +## 8. Align latent diffusion config with DMI/Martinbo reference The `ldcast-dmi/` reference implementation differs from our current -`ldcast_training_experiment` config in several ways. Changes below would +`latent_diffusion_experiment` config in several ways. Changes below would align us more closely with DMI. 
### Optimizer diff --git a/src/mlcast/config/__init__.py b/src/mlcast/config/__init__.py index 2b5ce79..2762d96 100644 --- a/src/mlcast/config/__init__.py +++ b/src/mlcast/config/__init__.py @@ -5,7 +5,7 @@ """ from .archetype.convgru import convgru_training_experiment -from .archetype.ldcast import LDCastTrainingExperiment, ldcast_training_experiment +from .archetype.latent_diffusion import LatentDiffusionTrainingExperiment, latent_diffusion_experiment from .base import Experiment from .consistency_checks import validate_config from .fiddlers import ( @@ -21,9 +21,9 @@ __all__ = [ "Experiment", - "LDCastTrainingExperiment", + "LatentDiffusionTrainingExperiment", "convgru_training_experiment", - "ldcast_training_experiment", + "latent_diffusion_experiment", "validate_config", "train_from_config", "load_yaml_config", diff --git a/src/mlcast/config/archetype/ldcast.py b/src/mlcast/config/archetype/latent_diffusion.py similarity index 92% rename from src/mlcast/config/archetype/ldcast.py rename to src/mlcast/config/archetype/latent_diffusion.py index c059586..80f8c10 100644 --- a/src/mlcast/config/archetype/ldcast.py +++ b/src/mlcast/config/archetype/latent_diffusion.py @@ -1,4 +1,4 @@ -"""Fiddle configuration for two-stage LDCast training.""" +"""Fiddle configuration for two-stage latent diffusion training.""" from dataclasses import dataclass @@ -20,8 +20,8 @@ @dataclass -class LDCastTrainingExperiment: - """Two-stage LDCast training experiment. +class LatentDiffusionTrainingExperiment: + """Two-stage latent diffusion training experiment. Parameters ---------- @@ -56,12 +56,12 @@ def run(self) -> None: @fiddle.experimental.auto_config.auto_config -def ldcast_training_experiment() -> LDCastTrainingExperiment: - """Build a Fiddle config for two-stage LDCast training. +def latent_diffusion_experiment() -> LatentDiffusionTrainingExperiment: + """Build a Fiddle config for two-stage latent diffusion training. 
Returns ------- - LDCastTrainingExperiment + LatentDiffusionTrainingExperiment Configured two-stage experiment with shared autoencoder identity across reconstruction and latent diffusion stages. """ @@ -106,7 +106,7 @@ def ldcast_training_experiment() -> LDCastTrainingExperiment: EarlyStopping(monitor="val/rec_loss", patience=6, mode="min"), LearningRateMonitor(logging_interval="step"), ], - logger=TensorBoardLogger(save_dir="logs", name="mlcast_ldcast_stage1"), + logger=TensorBoardLogger(save_dir="logs", name="mlcast_latent_diffusion_stage1"), ) stage2_data = ForecastingDataModule( @@ -142,10 +142,10 @@ def ldcast_training_experiment() -> LDCastTrainingExperiment: EarlyStopping(monitor="val/loss", patience=6, mode="min", check_finite=False), LearningRateMonitor(logging_interval="step"), ], - logger=TensorBoardLogger(save_dir="logs", name="mlcast_ldcast_stage2"), + logger=TensorBoardLogger(save_dir="logs", name="mlcast_latent_diffusion_stage2"), ) - return LDCastTrainingExperiment( + return LatentDiffusionTrainingExperiment( stage1=Experiment(pl_module=stage1_module, data=stage1_data, trainer=stage1_trainer), stage2=Experiment(pl_module=stage2_module, data=stage2_data, trainer=stage2_trainer), ) diff --git a/src/mlcast/config/consistency_checks.py b/src/mlcast/config/consistency_checks.py index bd4b424..43314d8 100644 --- a/src/mlcast/config/consistency_checks.py +++ b/src/mlcast/config/consistency_checks.py @@ -118,44 +118,47 @@ def _validate_forecasting_experiment_cfg(cfg: fdl.Config) -> None: ) -def _validate_ldcast_training_experiment_cfg(cfg: fdl.Config) -> None: - """Validate a two-stage LDCast training experiment configuration. +def _validate_latent_diffusion_experiment_cfg(cfg: fdl.Config) -> None: + """Validate a two-stage latent diffusion training experiment configuration. Parameters ---------- cfg : fdl.Config - Fiddle configuration for a two-stage LDCast experiment. + Fiddle configuration for a two-stage latent diffusion experiment. 
Raises ------ ValueError - If any LDCast-specific configuration contract is violated. + If any latent-diffusion-specific configuration contract is violated. """ stage1 = cfg.stage1 stage2 = cfg.stage2 autoencoder = stage1.pl_module.network if autoencoder is not stage2.pl_module.autoencoder: - raise ValueError("LDCast contract violated: stage1 and stage2 must share the same autoencoder config object.") + raise ValueError( + "LatentDiffusion contract violated: stage1 and stage2 must share the same autoencoder config object." + ) stage1_data = stage1.data stage2_data = stage2.data if stage1_data.input_steps != stage2_data.input_steps: raise ValueError( - "LDCast contract violated: stage1 and stage2 must use the same input_steps; " + "LatentDiffusion contract violated: stage1 and stage2 must use the same input_steps; " f"got {stage1_data.input_steps} and {stage2_data.input_steps}." ) stage2_module = stage2.pl_module if stage2_data.forecast_steps != stage2_module.forecast_steps: raise ValueError( - "LDCast contract violated: stage2 data.forecast_steps must match the latent diffusion task module; " - f"got {stage2_data.forecast_steps} and {stage2_module.forecast_steps}." + "LatentDiffusion contract violated: stage2 data.forecast_steps must match the latent diffusion " + f"task module; got {stage2_data.forecast_steps} and {stage2_module.forecast_steps}." ) if len(stage1_data.sequence_dataset_factory.standard_names) != autoencoder.encoder.input_channels: raise ValueError( - "LDCast contract violated: autoencoder encoder input_channels must match the number of source variables." + "LatentDiffusion contract violated: autoencoder encoder input_channels must match the " + "number of source variables." ) @@ -173,7 +176,7 @@ def validate_config(cfg: fdl.Config) -> None: If any configuration contract is violated. 
""" if hasattr(cfg, "stage1") and hasattr(cfg, "stage2"): - _validate_ldcast_training_experiment_cfg(cfg) + _validate_latent_diffusion_experiment_cfg(cfg) return _validate_forecasting_experiment_cfg(cfg) diff --git a/src/mlcast/config/fiddlers.py b/src/mlcast/config/fiddlers.py index 729e879..69261a7 100644 --- a/src/mlcast/config/fiddlers.py +++ b/src/mlcast/config/fiddlers.py @@ -91,7 +91,7 @@ def applies_to_experiments(fiddler: Callable) -> Callable: This makes fiddlers work with both flat ``Experiment`` configs (returned by ``convgru_training_experiment``) and nested containers like - ``LDCastTrainingExperiment`` that contain multiple ``Experiment`` instances. + ``LatentDiffusionTrainingExperiment`` that contain multiple ``Experiment`` instances. Parameters ---------- diff --git a/tests/config/test_fiddlers.py b/tests/config/test_fiddlers.py index d807f7d..b6e66d5 100644 --- a/tests/config/test_fiddlers.py +++ b/tests/config/test_fiddlers.py @@ -2,7 +2,7 @@ from mlcast.config import ( convgru_training_experiment, - ldcast_training_experiment, + latent_diffusion_experiment, set_variables, toggle_masking, use_random_sampler, @@ -34,9 +34,9 @@ def test_fiddler_toggle_masking() -> None: assert cfg.pl_module.masked_loss is True -def test_fiddler_set_variables_on_ldcast() -> None: - """Verify set_variables applies to both stages of an LDCastTrainingExperiment.""" - cfg = ldcast_training_experiment.as_buildable() +def test_fiddler_set_variables_on_latent_diffusion() -> None: + """Verify set_variables applies to both stages of a LatentDiffusionTrainingExperiment.""" + cfg = latent_diffusion_experiment.as_buildable() set_variables(cfg, ["rainfall_rate", "rainfall_flux", "rainfall_intensity"]) @@ -50,9 +50,9 @@ def test_fiddler_set_variables_on_ldcast() -> None: assert cfg.stage2.pl_module.autoencoder.encoder.input_channels == 3 -def test_fiddler_use_random_sampler_on_ldcast() -> None: - """Verify use_random_sampler applies to both stages of LDCastTrainingExperiment.""" 
- cfg = ldcast_training_experiment.as_buildable() +def test_fiddler_use_random_sampler_on_latent_diffusion() -> None: + """Verify use_random_sampler applies to both stages of LatentDiffusionTrainingExperiment.""" + cfg = latent_diffusion_experiment.as_buildable() assert fdl.get_callable(cfg.stage1.data.sequence_dataset_factory) is SourceDataPrecomputedSequenceDataset assert fdl.get_callable(cfg.stage2.data.sequence_dataset_factory) is SourceDataPrecomputedSequenceDataset @@ -63,9 +63,9 @@ def test_fiddler_use_random_sampler_on_ldcast() -> None: assert fdl.get_callable(cfg.stage2.data.sequence_dataset_factory) is SourceDataRandomSequenceDataset -def test_fiddler_use_ratio_splits_on_ldcast() -> None: - """Verify use_ratio_splits applies to both stages of LDCastTrainingExperiment.""" - cfg = ldcast_training_experiment.as_buildable() +def test_fiddler_use_ratio_splits_on_latent_diffusion() -> None: + """Verify use_ratio_splits applies to both stages of LatentDiffusionTrainingExperiment.""" + cfg = latent_diffusion_experiment.as_buildable() use_ratio_splits(cfg, train=0.6, val=0.2) diff --git a/tests/config/test_ldcast_experiment.py b/tests/config/test_latent_diffusion_experiment.py similarity index 72% rename from tests/config/test_ldcast_experiment.py rename to tests/config/test_latent_diffusion_experiment.py index b784a65..5aebdd9 100644 --- a/tests/config/test_ldcast_experiment.py +++ b/tests/config/test_latent_diffusion_experiment.py @@ -2,7 +2,7 @@ import fiddle as fdl -from mlcast.config import LDCastTrainingExperiment, ldcast_training_experiment, validate_config +from mlcast.config import LatentDiffusionTrainingExperiment, latent_diffusion_experiment, validate_config from mlcast.config.base import Experiment @@ -19,12 +19,12 @@ def test(self, pl_module, datamodule=None) -> None: # type: ignore[no-untyped-d self.events.append(f"test:{pl_module}:{datamodule}") -def test_ldcast_training_experiment_runs_stages_in_order() -> None: - """LDCastTrainingExperiment should 
execute stage 1 fully before stage 2.""" +def test_latent_diffusion_experiment_runs_stages_in_order() -> None: + """LatentDiffusionTrainingExperiment should execute stage 1 fully before stage 2.""" events: list[str] = [] stage1 = Experiment(pl_module="stage1_module", data="stage1_data", trainer=RecordingTrainer(events=events)) stage2 = Experiment(pl_module="stage2_module", data="stage2_data", trainer=RecordingTrainer(events=events)) - experiment = LDCastTrainingExperiment(stage1=stage1, stage2=stage2) + experiment = LatentDiffusionTrainingExperiment(stage1=stage1, stage2=stage2) experiment.run() @@ -36,9 +36,9 @@ def test_ldcast_training_experiment_runs_stages_in_order() -> None: ] -def test_ldcast_training_experiment_shares_autoencoder_identity() -> None: +def test_latent_diffusion_experiment_shares_autoencoder_identity() -> None: """Stage 1 and stage 2 should reference the same built autoencoder instance.""" - cfg = ldcast_training_experiment.as_buildable() + cfg = latent_diffusion_experiment.as_buildable() validate_config(cfg) experiment = fdl.build(cfg) diff --git a/tests/config/test_orchestrator.py b/tests/config/test_orchestrator.py index 7cdaa98..af1786a 100644 --- a/tests/config/test_orchestrator.py +++ b/tests/config/test_orchestrator.py @@ -2,7 +2,7 @@ from typing import Any from unittest.mock import patch -from mlcast.config import convgru_training_experiment, ldcast_training_experiment, train_from_config +from mlcast.config import convgru_training_experiment, latent_diffusion_experiment, train_from_config @patch("mlcast.config.orchestrator.fdl.build") @@ -15,9 +15,9 @@ def test_train_from_config_valid(mock_build: Any, tmp_path: Path) -> None: @patch("mlcast.config.orchestrator.fdl.build") -def test_train_from_config_valid_ldcast(mock_build: Any, tmp_path: Path) -> None: - """Verify that a valid LDCast configuration passes validation and builds.""" +def test_train_from_config_valid_latent_diffusion(mock_build: Any, tmp_path: Path) -> None: + """Verify 
that a valid latent diffusion configuration passes validation and builds.""" mock_build.return_value.trainer.log_dir = str(tmp_path) - cfg = ldcast_training_experiment.as_buildable() + cfg = latent_diffusion_experiment.as_buildable() train_from_config(cfg) mock_build.assert_called_once() diff --git a/tests/test_readme_snippets.py b/tests/test_readme_snippets.py index 450b486..50ba094 100644 --- a/tests/test_readme_snippets.py +++ b/tests/test_readme_snippets.py @@ -112,7 +112,7 @@ def _patch_cfg(cfg: fdl.Config, fp_dataset: Path, tmp_path: Path) -> None: ``network.input_channels`` is kept in sync with ``standard_names``. Handles both flat ``Experiment`` configs and nested containers like - ``LDCastTrainingExperiment`` by finding all ``Experiment`` sub-configs + ``LatentDiffusionTrainingExperiment`` by finding all ``Experiment`` sub-configs in the tree and patching each one. Parameters From 79854fa345aaf9969771c97840d2f6e439141d4c Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 4 Jun 2026 14:07:34 +0200 Subject: [PATCH 33/34] fix typo in readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 095046a..21787a2 100644 --- a/README.md +++ b/README.md @@ -207,6 +207,7 @@ train_from_config(cfg) **Run the included latent diffusion experiment with tweaks:** +```python from mlcast.config import latent_diffusion_experiment, train_from_config from mlcast.config.fiddlers import use_random_sampler From e0d786c23f129f9d740e28ac1336f57b129c89e4 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 4 Jun 2026 14:13:16 +0200 Subject: [PATCH 34/34] ci: install graphviz for config diagram pre-commit hooks --- .github/workflows/pre-commit.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 85fed30..8490865 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -18,6 +18,10 @@ jobs: with: enable-cache: true + # The 
config-diagram-* hooks require the `dot` binary from graphviz. + - name: Install graphviz + run: sudo apt-get update && sudo apt-get install -y graphviz + # Run pre-commit through uv so CI uses the same project-managed environment # as local development, including local hooks that invoke `uv run ...`. - run: uv run pre-commit run --all-files