diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 85fed30..8490865 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -18,6 +18,10 @@ jobs: with: enable-cache: true + # The config-diagram-* hooks require the `dot` binary from graphviz. + - name: Install graphviz + run: sudo apt-get update && sudo apt-get install -y graphviz + # Run pre-commit through uv so CI uses the same project-managed environment # as local development, including local hooks that invoke `uv run ...`. - run: uv run pre-commit run --all-files diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6aafae4..72870b7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,11 +28,17 @@ repos: - id: pyproject-fmt - repo: local hooks: - - id: config-diagram - name: config diagram up to date + - id: config-diagram-convgru + name: convgru config diagram up to date language: system entry: uv run python docs/generate_base_experiment_config_diagram.py --check - files: ^src/mlcast/config/base\.py$ + files: ^src/mlcast/config/(base|archetype/convgru)\.py$ + pass_filenames: false + - id: config-diagram-latent-diffusion + name: latent diffusion config diagram up to date + language: system + entry: uv run python docs/generate_latent_diffusion_config_diagram.py --check + files: ^src/mlcast/config/archetype/latent_diffusion\.py$ pass_filenames: false ci: autoupdate_schedule: monthly diff --git a/AGENTS.md b/AGENTS.md index a8628f7..a5c6375 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -23,6 +23,20 @@ Guidelines for contributors and AI agents working on this codebase. - Use `loguru` exclusively. Do not use the stdlib `logging` module. +## Metric naming + +- Logged metric names follow TensorBoard conventions: use `/` as a hierarchy + separator (e.g. `val/loss`, `train/rec_loss`) so that related metrics are + grouped in the TensorBoard UI. +- Use `rec_loss` for reconstruction-stage metrics and `loss` for + forecasting/diffusion-stage metrics to make clear which training stage a + metric belongs to. +- Metric name format: `{split}/{name}` where `split` is `train`, `val`, or + `test`. +- Monitoring references in config files (ModelCheckpoint, EarlyStopping, + lr_scheduler) must match the `val/{name}` variant of the metric they should + track. + ## Code style - Docstrings follow NumPy style. diff --git a/README.md b/README.md index 28db293..21787a2 100644 --- a/README.md +++ b/README.md @@ -67,12 +67,18 @@ reproduce runs exactly from a saved YAML file. ### Configuration model -Training in mlcast is currently built around a single base configuration -function, [`training_experiment`](src/mlcast/config/base.py), which defines the -default ConvGRU ensemble nowcasting setup: dataset, data module, network, -Lightning module, and trainer. Rather than writing a new config from scratch, -the intended workflow is to start from this base and apply targeted -modifications: +mlcast ships with two included configuration functions: + +- [`convgru_training_experiment`](src/mlcast/config/archetype/convgru.py) — defines a + single-stage ConvGRU ensemble nowcasting setup (dataset, data module, network, + Lightning module, trainer). +- [`latent_diffusion_experiment`](src/mlcast/config/archetype/latent_diffusion.py) — defines a + two-stage latent diffusion setup: stage 1 trains an autoencoder on reconstruction + windows, stage 2 trains a latent diffusion model on the same autoencoder's + latent space. + +Rather than writing a new config from scratch, the intended workflow is to +start from one of these configs and apply targeted modifications: - **`set:` overrides** — change a single scalar parameter (e.g. batch size, learning rate, number of epochs) @@ -82,32 +88,62 @@ modifications: - **direct graph edits** (Python API only) — replace a sub-object entirely, for example swapping in a different network architecture -Any combination of these can be layered on top of the base config, and the +Any combination of these can be layered on top of the selected config, and the fully resolved config is always saved to YAML alongside the training logs so runs can be reproduced exactly. -The diagram below shows the full default config graph as built by -[`training_experiment`](src/mlcast/config/base.py): +The diagrams below show the full included config graphs. + +**convgru_training_experiment:** + +![convgru_training_experiment config graph](docs/config_diagram.svg) + +**latent_diffusion_experiment:** + +![latent_diffusion_experiment config graph](docs/latent_diffusion_config_diagram.svg) + +### Design roles + +mlcast separates pure architectures from task-level training wrappers. -![training_experiment config graph](docs/config_diagram.svg) +- `src/mlcast/models/` + Pure `torch.nn.Module` architectures and supporting components. These classes + define tensor transformations and reusable building blocks, but they do not + decide how training is run or which parameters are optimized. +- `src/mlcast/modules/` + Task-level Lightning modules. These classes define what batch structure a + task consumes, what loss is computed, which parameters are optimized, and how + inference/prediction is exposed. + +In other words, architectures answer "how does this tensor get transformed?", +while task modules answer "what is being trained, against what target, and over +which parameters?" + +This distinction matters especially for latent diffusion. The diffusion +architecture itself lives under `models/`, while the corresponding task module +owns the trained autoencoder reuse policy, decides that only diffusion-network +parameters are optimized, computes diffusion loss in latent space, and handles +decoded forecast inference. ### Command-line interface Install the package and run: ```bash -mlcast train +# Single-stage ConvGRU nowcasting +mlcast train --config config:convgru_training_experiment +# Two-stage latent diffusion + +mlcast train --config config:latent_diffusion_experiment ``` -This trains with the built-in [`training_experiment`](src/mlcast/config/base.py) defaults. All parameters -are controlled via `--config` flags: +All parameters are controlled via `--config` flags: | Prefix | Purpose | Example | |--------|---------|---------| -| *(none)* | Use the built-in default config | `mlcast train` | +| `config:` | Select an included `@auto_config` function | `--config config:convgru_training_experiment` or `--config config:latent_diffusion_experiment` | | `set:` | Override a single parameter | `--config set:data.batch_size=32` | | `fiddler:` | Apply a semantic mutator (multi-param change) | `--config fiddler:use_random_sampler` | -| `config:` | Switch to a different `@auto_config` function | `--config=config:my_experiment` | | `path/to/config.yaml` | Load a previously saved config | `--config saved.yaml` | Multiple `--config` flags are applied in order and can be combined freely. @@ -117,11 +153,13 @@ Multiple `--config` flags are applied in order and can be combined freely. ```bash # Override dataset path and batch size mlcast train \ - --config set:data.dataset_factory.zarr_path=/data/radar.zarr \ + --config config:convgru_training_experiment \ + --config set:data.sequence_dataset_factory.zarr_path=/data/radar.zarr \ --config set:data.batch_size=32 # Switch to random sampler and log to MLflow mlcast train \ + --config config:convgru_training_experiment \ --config fiddler:use_random_sampler \ --config fiddler:use_mlflow_logger @@ -130,8 +168,13 @@ mlcast train \ --config logs/mlcast/version_0/config.yaml \ --config set:trainer.max_epochs=50 +# Run two-stage latent diffusion training with a shorter diffusion schedule + + --config config:latent_diffusion_experiment \ + --config set:stage2.pl_module.diffusion_net.scheduler.timesteps=20 + # Inspect the fully resolved config without starting training -mlcast train --config fiddler:use_random_sampler --print_config_and_exit +mlcast train --config config:convgru_training_experiment --config fiddler:use_random_sampler --print_config_and_exit ``` Run `mlcast train --help` for a full list of examples and available fiddlers. @@ -141,14 +184,14 @@ Run `mlcast train --help` for a full list of examples and available fiddlers. The Python API gives you full programmatic control over the config graph before anything is instantiated. -**Run the default experiment with tweaks:** +**Run the included ConvGRU experiment with tweaks:** ```python import fiddle as fdl -from mlcast.config import training_experiment, train_from_config +from mlcast.config import convgru_training_experiment, train_from_config from mlcast.config.fiddlers import use_random_sampler -cfg = training_experiment.as_buildable() # returns a fdl.Config graph — see src/mlcast/config/base.py +cfg = convgru_training_experiment.as_buildable() # returns a fdl.Config graph — see src/mlcast/config/archetype/convgru.py # Apply a fiddler to switch the dataset sampler use_random_sampler(cfg) @@ -162,6 +205,24 @@ cfg.trainer.max_epochs = 50 train_from_config(cfg) ``` +**Run the included latent diffusion experiment with tweaks:** + +```python +from mlcast.config import latent_diffusion_experiment, train_from_config +from mlcast.config.fiddlers import use_random_sampler + +cfg = latent_diffusion_experiment.as_buildable() + +# Applied once — @applies_to_experiments walks both stages automatically +use_random_sampler(cfg) + +# Override the diffusion noise schedule +cfg.stage2.pl_module.diffusion_net.scheduler.timesteps = 20 + +# train_from_config applies to the full two-stage experiment +train_from_config(cfg) +``` + **Custom network architecture:** You can swap in any architecture by replacing `cfg.pl_module.network` with a @@ -173,7 +234,7 @@ As an example, here is how to wrap an U-Net) to satisfy the interface. The wrapper channel-stacks the past frames and runs the U-Net autoregressively for each requested forecast step: -> **Note** — `input_steps` equals `dataset_factory.input_steps` (6 by +> **Note** — `input_steps` equals the forecasting data module's `input_steps` (6 by > default) and is directly readable from the config graph before building. ```python @@ -183,16 +244,18 @@ import torch import torch.nn as nn from jaxtyping import Float from mfai.torch.models import HalfUNet -from mlcast.config import training_experiment, train_from_config +from mlcast.config import convgru_training_experiment, train_from_config from mlcast.config.fiddlers import use_random_sampler -# Minimal adapter: channel-stack past frames → HalfUNet → one step at a time. -# NowcastLightningModule calls network(x, steps=N, ensemble_size=M), so any -# custom network must accept those keyword arguments. +# Minimal adapter: channel-stack past frames -> HalfUNet -> one step at a time. +# The forecasting contract fixes input_steps, forecast_steps, and ensemble_size +# at model initialization; this minimal deterministic adapter exposes one +# ensemble member and OutputSpaceForecastingTaskModule calls network(x). class HalfUNetNowcaster(nn.Module): - def __init__(self, input_steps: int = 6, num_vars: int = 1): + def __init__(self, input_steps: int = 6, forecast_steps: int = 12, num_vars: int = 1): super().__init__() self.input_steps = input_steps + self.forecast_steps = forecast_steps self.num_vars = num_vars self.unet = HalfUNet( input_shape=(256, 256), @@ -201,39 +264,42 @@ class HalfUNetNowcaster(nn.Module): settings=fdl.Config(HalfUNet.settings_kls), ) + @property + def ensemble_size(self) -> int: + return 1 + @property def input_channels(self) -> int: - # Externally, the HalfUNetNowcaster respects the required input shape structure - # (batch, input_steps, num_vars, H, W), even though the internal U-Net is channel-stacked. - # Adding this property allows the config consistency checks to verify that - # the dataset and model agree on the expected number of input channels. + # Externally the model handles (batch, time, channels, height, width); + # internally the U-Net channel-stacks time into (batch, time*channels, ...). + # This property lets config consistency checks verify dataset-model agreement. return self.num_vars def forward( self, - x: Float[torch.Tensor, "batch input_steps in_channels H W"], - steps: int, - ensemble_size: int = 1, - ) -> Float[torch.Tensor, "batch steps out_channels H W"]: - # channel-stack all input frames: (b, t, c, h, w) -> (b, t*c, h, w) + x: Float[torch.Tensor, "batch time channels height width"], + ) -> Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels height width"]: x_flat = einops.rearrange(x, "b t c h w -> b (t c) h w") preds = [] - for _ in range(steps): - y = self.unet(x_flat) # [B, num_vars, H, W] - preds.append(y.unsqueeze(1)) - # slide window: drop the oldest timestep (first num_vars channels), - # append the latest prediction as the newest timestep + for _ in range(self.forecast_steps): + y = self.unet(x_flat) + preds.append(y) x_flat = torch.cat([x_flat[:, self.num_vars:], y], dim=1) - return torch.cat(preds, dim=1) + return einops.rearrange(torch.stack(preds, dim=1), "b t c h w -> b t 1 c h w") -cfg = training_experiment.as_buildable() +cfg = convgru_training_experiment.as_buildable() use_random_sampler(cfg) cfg.pl_module.network = fdl.Config( HalfUNetNowcaster, - input_steps=cfg.data.dataset_factory.input_steps, - num_vars=len(cfg.data.dataset_factory.standard_names), + input_steps=cfg.data.input_steps, + forecast_steps=cfg.data.forecast_steps, + num_vars=len(cfg.data.sequence_dataset_factory.standard_names), ) +# The base ConvGRU config uses CRPS for ensemble forecasts; this adapter is +# deterministic and exposes only one member, so use a deterministic loss. +cfg.pl_module.loss_class = "mse" +cfg.pl_module.loss_params = None train_from_config(cfg) ``` @@ -255,7 +321,7 @@ experiment.run() # trainer.fit() + trainer.test() |---------|-----------|--------------| | `use_mlflow_logger` | *(none)* | Replaces the default `TensorBoardLogger` with `MLFlowLogger` and appends `LogSystemInfoCallback`; respects the `MLFLOW_TRACKING_URI` environment variable | | `set_variables` | `standard_names` | Sets the list of input variables on the dataset and updates `network.input_channels` to match | -| `toggle_masking` | `enabled` | Toggles masked-loss mode by setting both `dataset_factory.return_mask` and `pl_module.masked_loss` to the same value | +| `toggle_masking` | `enabled` | Toggles masked-loss mode by setting both `data.return_mask` and `pl_module.masked_loss` to the same value | | `use_anon_s3_dataset` | `zarr_path`, `endpoint_url` | Points the dataset at an anonymous S3 object store; sets `zarr_path` and the required `storage_options` together | | `use_random_sampler` | *(none)* | Switches the dataset factory to the on-the-fly random sampler (useful during development when no precomputed CSV is available) | @@ -270,17 +336,37 @@ mlcast/ │ ├── callbacks.py # Training callbacks │ ├── visualization.py # TensorBoard image logging helpers │ ├── config/ -│ │ ├── base.py # Default training_experiment @auto_config +│ │ ├── base.py # Experiment dataclass +│ │ ├── archetype/ +│ │ │ ├── convgru.py # ConvGRU training config @auto_config +│ │ │ └── latent_diffusion.py # Two-stage latent diffusion config @auto_config │ │ ├── fiddlers.py # Semantic config mutators │ │ ├── consistency_checks.py # Cross-parameter validation │ │ ├── loader.py # YAML config loader │ │ └── orchestrator.py # train_from_config, config persistence │ ├── data/ -│ │ ├── source_data_datamodule.py # Lightning DataModule -│ │ ├── source_data_datasets.py # Zarr-backed PyTorch datasets +│ │ ├── datamodules.py # Lightning DataModules +│ │ ├── sequence.py # Zarr-backed sequence datasets +│ │ ├── forecasting.py # Forecasting task dataset wrapper +│ │ ├── reconstruction.py # Reconstruction task dataset wrapper │ │ └── normalization.py # Normalisation registry -│ └── models/ -│ └── convgru.py # ConvGRU encoder-decoder +│ ├── models/ +│ │ ├── convgru.py # ConvGRU encoder-decoder +│ │ ├── autoencoder/ +│ │ │ ├── encoder.py # Encoder +│ │ │ ├── decoder.py # Decoder +│ │ │ └── net.py # AutoencoderNet composition +│ │ └── diffusion/ +│ │ ├── conditioner.py # ConditionerNet (context builder) +│ │ ├── denoiser.py # DenoiserUNet +│ │ ├── scheduler.py # Diffusion noise scheduler +│ │ ├── sampler.py # Inference-time sampling loop +│ │ ├── ema.py # EMA weight tracking +│ │ ├── loss.py # Diffusion loss +│ │ └── net.py # LatentDiffusionNet composition +│ └── modules/ +│ ├── forecasting.py # Base + OutputSpace + LatentDiffusion task modules +│ └── reconstruction.py # ReconstructionTaskModule ├── tests/ ├── pyproject.toml └── README.md @@ -314,7 +400,8 @@ doubled at each block via `PixelShuffle(2)`. **Ensemble** — when `ensemble_size > 1` the decoder is run `ensemble_size` times, each time with freshly sampled Gaussian noise. The results are -concatenated along the channel dimension. +stacked along an explicit ensemble dimension, giving the final shape +`(batch, forecast_steps, ensemble_size, channels, height, width)`. **Deterministic variant** ([diagram source](https://docs.google.com/presentation/d/1U2Y9vZADXTsgQBNiWYAgOwYeMPVu7TOk/edit?slide=id.p6#slide=id.p6)): @@ -325,11 +412,96 @@ concatenated along the channel dimension. ![ConvGruModel stochastic architecture](docs/architectures/convgru-stochastic.png) +### LatentDiffusionNet (two-stage latent diffusion) + +This is a **two-stage** latent diffusion nowcasting system. Stage 1 trains an +autoencoder on reconstruction windows; stage 2 trains a latent diffusion model +that forecasts in the autoencoder's latent space and decodes forecasts back to +data space. + +The architecture components live under `src/mlcast/models/autoencoder/` and +`src/mlcast/models/diffusion/`. The task-level Lightning modules live under +`src/mlcast/modules/` and are wired together by +[`latent_diffusion_experiment`](src/mlcast/config/archetype/latent_diffusion.py). + +#### Stage 1 — Autoencoder reconstruction + +The autoencoder is built from an +[`Encoder`](src/mlcast/models/autoencoder/encoder.py) and +[`Decoder`](src/mlcast/models/autoencoder/decoder.py), composed by +[`AutoencoderNet`](src/mlcast/models/autoencoder/net.py). + +- **Encoder** — a stack of `EncoderBlock` layers. Each block downsamples + spatial resolution via strided 3D convolution and doubles the channel count. + The final output is a latent tensor with shape + `(batch, latent_channels, time, latent_height, latent_width)`. +- **Decoder** — a stack of `DecoderBlock` layers that mirror the encoder. Each + block upsamples spatial resolution via transposed 3D convolution and halves + the channel count, reconstructing the original input shape. + +The autoencoder is trained on overlapping temporal windows via +[`ReconstructionDataset`](src/mlcast/data/reconstruction.py) and +[`ReconstructionDataModule`](src/mlcast/data/datamodules.py). The +[`ReconstructionTaskModule`](src/mlcast/modules/reconstruction.py) optimises +the full autoencoder parameters against an MSE reconstruction loss. + +#### Stage 2 — Latent diffusion forecasting + +The latent diffusion model is built from a +[`ConditionerNet`](src/mlcast/models/diffusion/conditioner.py), +[`DenoiserUNet`](src/mlcast/models/diffusion/denoiser.py), and +[`DiffusionScheduler`](src/mlcast/models/diffusion/scheduler.py), composed by +[`LatentDiffusionNet`](src/mlcast/models/diffusion/net.py). + +- **ConditionerNet** — projects encoded input-history latents through a series + of residual 3D convolution blocks to produce a conditioning context for the + denoiser U-Net. This answers "what did the recent past look like in latent + space?" +- **DenoiserUNet** — a timestep-aware U-Net with 3D convolutions over the + latent spatial dimensions (time dimension is preserved). It receives the + noisy target latent, a diffusion timestep embedding (sinusoidal), and the + conditioning context from the conditioner. The U-Net predicts the additive + noise (`eps` parameterization) that was applied to reach the current + timestep. +- **DiffusionScheduler** — defines the forward diffusion noise schedule + (linear beta schedule by default) and provides the pre-computed alpha/beta + buffers used by the forward and reverse processes. + +Training uses a standard MSE diffusion loss (`DiffusionLoss` in +`src/mlcast/models/diffusion/loss.py`): for each batch the input is encoded +with the trained (frozen) encoder, the target is encoded with the same encoder, +a random timestep is drawn per sample, noise is added to the target latents, +and the denoiser is trained to predict the added noise. + +Inference uses a [`DiffusionSampler`](src/mlcast/models/diffusion/sampler.py) +to progressively denoise random latents conditioned on encoded input history. +The reverse diffusion loop steps backward through the schedule, and the final +denoised latent is decoded back to data space by the trained decoder, giving +an explicit ensemble dimension in the output shape +`(batch, forecast_steps, ensemble_size, channels, height, width)`. When +`ensemble_size > 1`, the process is repeated with fresh noise and the results +are stacked. + +#### Two-stage training experiment + +The [`latent_diffusion_experiment`](src/mlcast/config/archetype/latent_diffusion.py) auto-config +orchestrates both stages: + +- Stage 1 builds a `ReconstructionDataModule`, `AutoencoderNet`, and + `ReconstructionTaskModule`, then calls `trainer.fit() + trainer.test()`. +- Stage 2 reuses the **same trained autoencoder instance** (Fiddle graph + identity sharing), builds a `ForecastingDataModule` and + `LatentDiffusionTaskModule`, then calls `trainer.fit() + trainer.test()`. +- The stage-2 module freezes the autoencoder on `fit_start` and optimises only + the diffusion-network parameters. + + ### Custom network interface Any network architecture can be used by replacing `cfg.pl_module.network` -with a `fdl.Config` node pointing at your class. The only requirement is -that `forward` accepts the following signature: +with a `fdl.Config` node pointing at your class. Forecasting models should set +`input_steps`, `forecast_steps`, and `ensemble_size` during initialization. The +only runtime `forward` requirement is: ```python # from jaxtyping import Float @@ -338,12 +510,15 @@ that `forward` accepts the following signature: def forward( self, x: Float[torch.Tensor, "batch input_steps in_channels H W"], - steps: int, # number of forecast steps to produce - ensemble_size: int, # number of stochastic ensemble members -) -> Float[torch.Tensor, "batch steps out_channels H W"]: +) -> Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels H W"]: ... ``` +The output has an explicit ensemble dimension. For deterministic models +(`ensemble_size=1`) this dimension is 1. If a loss function operates over +the full forecast tensor without splitting ensemble members (e.g. MSE on +the ensemble mean), the task module handles reshaping automatically. + If your network uses a different parameter name for the input channel count than `input_channels` (the default assumed by `ConvGruModel` and the `set_variables` fiddler), set it explicitly on the config node. diff --git a/docs/config_diagram.svg b/docs/config_diagram.svg index 76275eb..58f150f 100644 --- a/docs/config_diagram.svg +++ b/docs/config_diagram.svg @@ -4,341 +4,356 @@ - - + + %3 - + 2 - - -Config: - ConvGruModel - - -input_channels - -1 - - -num_blocks - -5 - - -noisy_decoder - -False + + +Config: + ConvGruModel + + +input_steps + +6 + + +forecast_steps + +12 + + +ensemble_size + +2 + + +input_channels + +1 + + +num_blocks + +5 + + +noisy_decoder + +False 1 - - -Config: - NowcastLightningModule - - -network - - - - - -ensemble_size - -2 - - -loss_class - -'crps' - - -loss_params - - - -dict - - -'temporal_lambda' - -0.01 - - -masked_loss - -True - - -optimizer - - - - - -lr_scheduler - - - + + +Config: + OutputSpaceForecastingTaskModule + + +network + + + + + +loss_class + +'crps' + + +loss_params + + + +dict + + +'temporal_lambda' + +0.01 + + +masked_loss + +True + + +optimizer + + + + + +lr_scheduler + + + 1:c--2:c - + 3 - - -Partial: - Adam - - -lr - -0.0001 - + + +Partial: + Adam + -fused +lr -True +0.0001 + + +fused + +True 1:c--3:c - + 4 - - -Partial: - ReduceLROnPlateau - - -mode - -'min' - + + +Partial: + ReduceLROnPlateau + -factor +mode -0.5 +'min' -patience +factor -10 +0.5 + + +patience + +10 1:c--4:c - + 0 - - -Config: - Experiment - - -pl_module - - - - - -data - - - - - -trainer - - - + + +Config: + Experiment + + +pl_module + + + + + +data + + + + + +trainer + + + 0:c--1:c - + 5 - - -Config: - SourceDataDataModule - - -dataset_factory - - - - - -splits - - - -dict - - -'time' - - - -dict - - -'train' - -0.7 - - -'val' - -0.15 - - -'test' - -0.15 - - -batch_size - -16 - - -num_workers - -8 - - -pin_memory - -True + + +Config: + ForecastingDataModule + + +sequence_dataset_factory + + + + + +input_steps + +6 + + +forecast_steps + +12 + + +return_mask + +True + + +splits + + + +dict + + +'time' + + + +dict + + +'train' + +0.7 + + +'val' + +0.15 + + +'test' + +0.15 + + +batch_size + +16 + + +num_workers + +8 + + +pin_memory + +True 0:c--5:c - + 7 - - -Config: - Trainer - - -accelerator - -'auto' - - -logger - - - - - -callbacks - - - -list - - - - - - - - - - - - - - -0 - - -1 - - -2 - - -3 - - -max_epochs - -100 + + +Config: + Trainer + + +accelerator + +'auto' + + +logger + + + + + +callbacks + + + +list + + + + + + + + + + + + + + +0 + + +1 + + +2 + + +3 + + +max_epochs + +100 0:c--7:c - + 6 - - -Partial: - SourceDataPrecomputedSamplingDataset - - -zarr_path - -'./data/radar.zarr' - - -csv_path - -'./data/sampled_datacubes.csv' - - -standard_names - - - -list - -'rainfall_rate' - - -0 - - -input_steps - -6 - - -forecast_steps - -12 + + +Partial: + SourceDataPrecomputedSequenceDataset + + +zarr_path + +'./data/radar.zarr' + + +csv_path + +'./data/sampled_datacubes.csv' + + +standard_names + + + +list + +'rainfall_rate' + + +0 -return_mask +sequence_steps -True +18 deterministic @@ -348,132 +363,132 @@ 5:c--6:c - + 8 - - -Config: - TensorBoardLogger - - -save_dir - -'logs' + + +Config: + TensorBoardLogger -name +save_dir -'mlcast' +'logs' + + +name + +'mlcast' 7:c--8:c - + 9 - - -Config: - ModelCheckpoint - - -monitor - -'val_loss' + + +Config: + ModelCheckpoint -save_top_k +monitor -1 +'val/loss' -mode +save_top_k -'min' +1 + + +mode + +'min' 7:c--9:c - + 10 - - -Config: - ModelCheckpoint - - -monitor - -'train_loss_epoch' + + +Config: + ModelCheckpoint -save_top_k +monitor -1 +'train/loss_epoch' -mode +save_top_k -'min' +1 + + +mode + +'min' 7:c--10:c - + 11 - - -Config: - EarlyStopping - - -monitor - -'val_loss' + + +Config: + EarlyStopping -patience +monitor -100 +'val/loss' -mode +patience -'min' +100 + + +mode + +'min' 7:c--11:c - + 12 - - -Config: - LearningRateMonitor - - -logging_interval - -'step' + + +Config: + LearningRateMonitor + + +logging_interval + +'step' 7:c--12:c - + diff --git a/docs/generate_base_experiment_config_diagram.py b/docs/generate_base_experiment_config_diagram.py index 0b58291..3193616 100644 --- a/docs/generate_base_experiment_config_diagram.py +++ b/docs/generate_base_experiment_config_diagram.py @@ -1,4 +1,4 @@ -"""Generate a Graphviz SVG diagram of the default training_experiment config. +"""Generate a Graphviz SVG diagram of the included ConvGRU training config. Run without arguments to regenerate docs/config_diagram.svg: @@ -15,13 +15,13 @@ import fiddle.graphviz as fgv -from mlcast.config import training_experiment +from mlcast.config import convgru_training_experiment OUT = Path(__file__).parent / "config_diagram.svg" def main() -> None: - """Generate or verify the base experiment config diagram.""" + """Generate or verify the ConvGRU training config diagram.""" parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( "--check", @@ -30,7 +30,7 @@ def main() -> None: ) args = parser.parse_args() - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() g = fgv.render(cfg, max_str_length=40) g.format = "svg" new_svg = g.pipe().decode() diff --git a/docs/generate_latent_diffusion_config_diagram.py b/docs/generate_latent_diffusion_config_diagram.py new file mode 100644 index 0000000..007a5ee --- /dev/null +++ b/docs/generate_latent_diffusion_config_diagram.py @@ -0,0 +1,52 @@ +"""Generate a Graphviz SVG diagram of the included latent diffusion training config. + +Run without arguments to regenerate docs/latent_diffusion_config_diagram.svg: + + uv run python docs/generate_latent_diffusion_config_diagram.py + +Run with --check to verify the diagram is up to date: + + uv run python docs/generate_latent_diffusion_config_diagram.py --check +""" + +import argparse +import sys +from pathlib import Path + +import fiddle.graphviz as fgv + +from mlcast.config import latent_diffusion_experiment + +OUT = Path(__file__).parent / "latent_diffusion_config_diagram.svg" + + +def main() -> None: + """Generate or verify the latent diffusion training config diagram.""" + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--check", + action="store_true", + help="Check that the diagram is up to date rather than regenerating it.", + ) + args = parser.parse_args() + + cfg = latent_diffusion_experiment.as_buildable() + g = fgv.render(cfg, max_str_length=40) + g.format = "svg" + new_svg = g.pipe().decode() + + if args.check: + if not OUT.exists() or OUT.read_text() != new_svg: + print( + "docs/latent_diffusion_config_diagram.svg is out of date.\n" + "Run: uv run python docs/generate_latent_diffusion_config_diagram.py" + ) + sys.exit(1) + print("docs/latent_diffusion_config_diagram.svg is up to date.") + else: + OUT.write_text(new_svg) + print(f"Written {OUT}") + + +if __name__ == "__main__": + main() diff --git a/docs/latent_diffusion_config_diagram.svg b/docs/latent_diffusion_config_diagram.svg new file mode 100644 index 0000000..da917ad --- /dev/null +++ b/docs/latent_diffusion_config_diagram.svg @@ -0,0 +1,1053 @@ + + + + + + +%3 + + + +4 + + +Config: + Encoder + + +input_channels + +1 + + +hidden_channels + +16 + + +latent_channels + +32 + + +num_blocks + +2 + + + +3 + + +Config: + AutoencoderNet + + +encoder + + + + + +decoder + + + + + + +3:c--4:c + + + + +5 + + +Config: + Decoder + + +output_channels + +1 + + +hidden_channels + +16 + + +latent_channels + +32 + + +num_blocks + +2 + + + +3:c--5:c + + + + +2 + + +Config: + ReconstructionTaskModule + + +network + + + + + +loss_class + +'mse' + + +optimizer + + + + + +lr_scheduler + + + + + + +2:c--3:c + + + + +6 + + +Partial: + AdamW + + +lr + +0.001 + + +betas + + + + + +weight_decay + +0.001 + + +fused + +True + + + +2:c--6:c + + + + +8 + + +Partial: + ReduceLROnPlateau + + +mode + +'min' + + +factor + +0.25 + + +patience + +3 + + + +2:c--8:c + + + + +7 + + +tuple + +0.5 + +0.9 + + +0 + + +1 + + + +6:c--7:c + + + + +1 + + +Config: + Experiment + + +pl_module + + + + + +data + + + + + +trainer + + + + + + +1:c--2:c + + + + +9 + + +Config: + ReconstructionDataModule + + +sequence_dataset_factory + + + + + +input_steps + +4 + + +splits + + + +dict + + +'time' + + + +dict + + +'train' + +0.7 + + +'val' + +0.15 + + +'test' + +0.15 + + +batch_size + +4 + + +num_workers + +8 + + +pin_memory + +True + + + +1:c--9:c + + + + +11 + + +Config: + Trainer + + +accelerator + +'auto' + + +logger + + + + + +callbacks + + + +list + + + + + + + + + + + +0 + + +1 + + +2 + + +max_epochs + +20 + + +accumulate_grad_batches + +2 + + + +1:c--11:c + + + + +10 + + +Partial: + SourceDataPrecomputedSequenceDataset + + +zarr_path + +'./data/radar.zarr' + + +csv_path + +'./data/sampled_datacubes.csv' + + +standard_names + + + +list + +'rainfall_rate' + + +0 + + +sequence_steps + +16 + + +deterministic + +False + + + +9:c--10:c + + + + +12 + + +Config: + TensorBoardLogger + + +save_dir + +'logs' + + +name + +'mlcast_latent_diffusion_stage1' + + + +11:c--12:c + + + + +13 + + +Config: + ModelCheckpoint + + +monitor + +'val/rec_loss' + + +save_top_k + +3 + + +mode + +'min' + + + +11:c--13:c + + + + +14 + + +Config: + EarlyStopping + + +monitor + +'val/rec_loss' + + +patience + +6 + + +mode + +'min' + + + +11:c--14:c + + + + +15 + + +Config: + LearningRateMonitor + + +logging_interval + +'step' + + + +11:c--15:c + + + + +0 + + +Config: + LatentDiffusionTrainingExperiment + + +stage1 + + + + + +stage2 + + + + + + +0:c--1:c + + + + +16 + + +Config: + Experiment + + +pl_module + + + + + +data + + + + + +trainer + + + + + + +0:c--16:c + + + + +17 + + +Config: + LatentDiffusionTaskModule + + +autoencoder + + + + + +diffusion_net + + + + + +forecast_steps + +12 + + +ensemble_size + +2 + + +optimizer + + + + + +lr_scheduler + + + + + +ema_decay + +0.9999 + + + +17:c--3:c + + + + +18 + + +Config: + LatentDiffusionNet + + +conditioner + + + + + +denoiser + + + + + +scheduler + + + + + + +17:c--18:c + + + + +22 + + +Partial: + AdamW + + +lr + +0.0001 + + +betas + + + + + +weight_decay + +0.001 + + +fused + +True + + + +17:c--22:c + + + + +23 + + +Partial: + ReduceLROnPlateau + + +mode + +'min' + + +factor + +0.25 + + +patience + +3 + + + +17:c--23:c + + + + +19 + + +Config: + ConditionerNet + + +latent_channels + +32 + + +hidden_channels + +32 + + +num_blocks + +2 + + + +18:c--19:c + + + + +20 + + +Config: + DenoiserUNet + + +latent_channels + +32 + + +condition_channels + +32 + + +hidden_channels + +32 + + +num_blocks + +2 + + + +18:c--20:c + + + + +21 + + +Config: + DiffusionScheduler + + +timesteps + +1000 + + + +18:c--21:c + + + + +22:c--7:c + + + + +16:c--17:c + + + + +24 + + +Config: + ForecastingDataModule + + +sequence_dataset_factory + + + + + +input_steps + +4 + + +forecast_steps + +12 + + +return_mask + +False + + +splits + + + +dict + + +'time' + + + +dict + + +'train' + +0.7 + + +'val' + +0.15 + + +'test' + +0.15 + + +batch_size + +1 + + +num_workers + +8 + + +pin_memory + +True + + + +16:c--24:c + + + + +25 + + +Config: + Trainer + + +accelerator + +'auto' + + +logger + + + + + +callbacks + + + +list + + + + + + + + + + + +0 + + +1 + + +2 + + +max_epochs + +20 + + +accumulate_grad_batches + +2 + + + +16:c--25:c + + + + +24:c--10:c + + + + +26 + + +Config: + TensorBoardLogger + + +save_dir + +'logs' + + +name + +'mlcast_latent_diffusion_stage2' + + + +25:c--26:c + + + + +27 + + +Config: + ModelCheckpoint + + +monitor + +'val/loss' + + +save_top_k + +3 + + +mode + +'min' + + + +25:c--27:c + + + + +28 + + +Config: + EarlyStopping + + +monitor + +'val/loss' + + +patience + +6 + + +mode + +'min' + + +check_finite + +False + + + +25:c--28:c + + + + +29 + + +Config: + LearningRateMonitor + + +logging_interval + +'step' + + + +25:c--29:c + + + + diff --git a/ldcast-notes.md b/ldcast-notes.md new file mode 100644 index 0000000..395ffba --- /dev/null +++ b/ldcast-notes.md @@ -0,0 +1,67 @@ +# LDCast implementation notes + +Architecture decisions and differences between reference implementations. + +## EMA scope + +| Reference | Scope | Decay | +|-----------|-------|-------| +| **DMI** | Full `LatentDiffusionNet` (conditioner + denoiser + scheduler buffers) | `0.9999` | +| **Martinbo** | Denoiser submodule only (`diffusion_net.denoiser`) | `0.9999` | +| **Ours** | Full `LatentDiffusionNet` (matches DMI) | `0.9999` | + +**Rationale for full-network EMA**: The conditioner is a single-pass feed-forward network +called once per sample, so weight noise matters less than in the denoiser. However, there +is no downside to smoothing it too, and it keeps the code simpler (EMA wraps the entire +diffusion net rather than reaching into a private submodule). Full-network EMA is the +standard practice in DDPM, Stable Diffusion, and DMI's reference. + +If denoiser-only EMA were desired in the future, the change is in +`forecasting.py:514`: + +```python +# Full network (current, matches DMI): +self.ema = ExponentialMovingAverage(diffusion_net, decay=ema_decay) + +# Denoiser only (Martinbo): +self.ema = ExponentialMovingAverage(diffusion_net.denoiser, decay=ema_decay) +``` + +## Optimizer + +| Reference | Type | Betas | Weight decay | Autoencoder LR | Diffusion LR | +|-----------|------|-------|-------------|----------------|--------------| +| **DMI** | `AdamW` | `(0.5, 0.9)` | `1e-3` | `1e-3` | `1e-4` | +| **Martinbo** | `AdamW` | `[0.5, 0.9]` | `0.001` | `1e-3` | `1e-4` | +| **Ours** | `AdamW` | `(0.5, 0.9)` | `1e-3` | `1e-3` | `1e-4` | + +Both references agree on all optimizer settings. + +## LR scheduler + +`ReduceLROnPlateau(factor=0.25, patience=3)` for both stages. Monitor metric naming +differs: DMI uses `val_rec_loss` / `val_loss_ema`, Martinbo uses `val/rec_loss` / +`val/loss` (TensorBoard convention). Ours follows Martinbo's naming. + +## Diffusion noise schedule + +Both references use `timesteps=1000` with `beta_start=1e-4, beta_end=2e-2`. +`DiffusionScheduler` defaults already match these beta bounds. + +## Monitor metric naming + +DMI uses underscores (`val_rec_loss`), Martinbo uses TensorBoard-style slashes +(`val/rec_loss`). We follow Martinbo / TensorBoard convention — slashes give +automatic grouping in the TensorBoard UI. + +## Early stopping + +DMI and Martinbo both use `patience=6`. Martinbo adds `check_finite=False` on the +diffusion stage. We follow both. + +## Batch size + +None of the three implementations agree on batch size: +- DMI: `batch_size=4` (autoencoder) / `1` (diffusion) — example configs +- Martinbo: `batch_size=1` for both stages +- Ours: `batch_size=4` (autoencoder) / `1` (diffusion) — matches DMI diff --git a/ldcast-refactor-plan.md b/ldcast-refactor-plan.md new file mode 100644 index 0000000..139fd55 --- /dev/null +++ b/ldcast-refactor-plan.md @@ -0,0 +1,331 @@ +# Latent Diffusion Refactor Plan + +0. Config naming and CLI contract +- [x] Rename `training_experiment` to `convgru_training_experiment`. +- [x] Do not keep `training_experiment` as an alias. +- [x] Reserve `latent_diffusion_experiment` as the top-level config name for the two-stage workflow. +- [x] Require `mlcast train` users to provide an explicit base config via `--config config:` or `--config /path/to/config.yaml`. +- [x] Update CLI help text to list the included config entry points explicitly. +- [x] Update all docs, examples, tests, and scripts to use `convgru_training_experiment` instead of `training_experiment`. +- [x] Treat `convgru_training_experiment` as the existing ConvGRU forecasting example, not as a special default config. + +1. Forecasting and reconstruction data +- [x] Move the existing sampled-sequence source-data logic into `src/mlcast/data/sequence.py`. +- [x] Rename `SourceDataDatasetBase`, `SourceDataPrecomputedSamplingDataset`, and `SourceDataRandomSamplingDataset` to `SourceDataSequenceDatasetBase`, `SourceDataPrecomputedSequenceDataset`, and `SourceDataRandomSequenceDataset` under the sequence data area. +- [x] Remove the old source-data public API rather than keeping compatibility re-exports. +- [x] Keep the existing sampled-sequence implementation as the source-data sequence layer. +- [x] Sequence datasets should own normalization and return normalized tensors of shape `(sequence_steps, channels, height, width)`. +- [x] Replace forecasting-specific sampling parameters in the source-data sequence layer with a single `sequence_steps` parameter. +- [x] Add `src/mlcast/data/forecasting.py`. +- [x] Add a generic `ForecastingDataset` that wraps a base sequence dataset, takes `input_steps` and `forecast_steps`, validates `input_steps + forecast_steps == sequence_steps`, and returns forecasting samples. +- [x] `ForecastingDataset` should derive `target_mask` itself rather than relying on the base sequence dataset to return masks. +- [x] Add `src/mlcast/data/reconstruction.py`. +- [x] Add `ReconstructionDataset`, a generic wrapper around a base sequence dataset that slices each full sequence into all overlapping windows of length `input_steps` and returns only the tensor window. +- [x] Add `src/mlcast/data/datamodules.py`. +- [x] Rename `SourceDataDataModule` to `ForecastingDataModule` in `src/mlcast/data/datamodules.py`. +- [x] `ForecastingDataModule` should remain factory-based and build `ForecastingDataset` instances over the underlying sequence datasets. +- [x] Add `ReconstructionDataModule` to `src/mlcast/data/datamodules.py`; it remains factory-based, builds the underlying sequence datasets, splits them into train/val/test, and wraps each split with `ReconstructionDataset`. +- [x] Keep this generic: no LDCast-specific naming in the module or class names. +- [x] Forecasting should stay one-sequence-to-one-sample. +- [x] Reconstruction should expand each sequence into `sequence_steps - input_steps + 1` overlapping samples. +- [x] Stage 1 should use reconstruction windows of length `input_steps` derived from the full sequence dataset. + +2. Autoencoder model architecture +- Autoencoder model split: + - [x] `src/mlcast/models/autoencoder/encoder.py` for `Encoder` and `EncoderBlock`. + - [x] `src/mlcast/models/autoencoder/decoder.py` for `Decoder` and `DecoderBlock`. + - [x] `src/mlcast/models/autoencoder/net.py` for `AutoencoderNet`. +- [x] Use `input_steps` for the stage-1 reconstruction window length; do not introduce names like `autoenc_time_ratio`. +- Autoencoder validation and tests: + - [x] encoder output shape. + - [x] decoder output shape. + - [x] autoencoder reconstruction forward pass. + - [x] autoencoder improves reconstruction loss on a small generated dataset after a few training steps. + +3. Forecasting model contract +- [x] Standardize all forecasting models on init-time `input_steps`, `forecast_steps`, and `ensemble_size`. +- [x] Standardize forecasting model inference on `forward(x)` only; do not pass `forecast_steps` or `ensemble_size` at runtime. +- [x] Refactor the existing ConvGRU path to follow this fixed-shape contract. +- [x] Add config consistency checks that dataset `input_steps` and `forecast_steps` match the configured forecasting model. + +4. Diffusion model architecture +- Diffusion model split: + - [x] `src/mlcast/models/diffusion/conditioner.py` for latent conditioning blocks and `ConditionerNet`. + - [x] `src/mlcast/models/diffusion/denoiser.py` for `DenoiserUNet` and timestep-aware helpers. + - [x] `src/mlcast/models/diffusion/net.py` for `LatentDiffusionNet`. + - [x] `src/mlcast/models/diffusion/scheduler.py`, `ema.py`, `sampler.py`, `loss.py` for diffusion support code. +- Validation and tests: + - [x] latent diffusion model API. + - [x] diffusion model improves loss on a small generated latent dataset after a few training steps. + +5. Task modules (Lightning modules) +- [x] Add `src/mlcast/modules/forecasting.py`, introduce `BaseForecastingTaskModule`, and rename `NowcastLightningModule` to `OutputSpaceForecastingTaskModule`. +- [x] `BaseForecastingTaskModule` should own optimizer/scheduler plumbing, while each concrete task module defines which parameters are trainable. +- [x] `OutputSpaceForecastingTaskModule` should optimize the forecasting network parameters. +- [x] Remove runtime `forecast_steps` and `ensemble_size` arguments from the forecasting task module and its `predict()` API. +- [x] Add `src/mlcast/modules/reconstruction.py` with a generic `ReconstructionTaskModule` for any reconstruction model. +- [x] Add a `LatentDiffusionTaskModule` that owns the trained autoencoder, optimizes only the diffusion-network parameters, trains diffusion in latent space, and handles decoded forecast inference. +- [x] Keep `modules/` for task-level Lightning modules only; keep `models/` for pure architectures. + +6. Training experiment +- [x] Add a new two-stage `LatentDiffusionTrainingExperiment` (initially called `LDCastTrainingExperiment`). +- [x] Keep `convgru_training_experiment` as the existing ConvGRU forecasting example and one of the explicitly selected included CLI configs. +- [x] Stage 1 builds the reconstruction dataset, autoencoder model, and `ReconstructionTaskModule`, then trains the autoencoder. +- [x] Stage 2 reuses the same trained in-memory autoencoder instance, builds the diffusion dataset/model/`LatentDiffusionTaskModule`, then trains latent diffusion. +- [x] Stage 2 freezes the reused autoencoder parameters and optimizes only the latent diffusion task module's diffusion-network parameters. +- [x] The shared Fiddle graph should define the autoencoder once and reference the same object in both stages, but no unresolved Fiddle objects should flow into actual `torch.nn.Module.__init__` calls. +- [x] Stage-2 diffusion training uses the trained encoder to produce input and target latents; the trained decoder is retained for final forecast decoding but is not used in the stage-2 diffusion loss. +- [x] Reuse the same forecasting dataset abstraction in stage 2; do not add a separate latent dataset layer. +- [x] Add tests for shared object identity and stage sequencing. + +7. Audit and migration targets +- [x] Update CLI/help text in `src/mlcast/__main__.py` to require an explicit base config and list the included config entry points. +- [x] Rename `training_experiment` to `convgru_training_experiment` in `src/mlcast/config/base.py` and export it from `src/mlcast/config/__init__.py`. +- [x] Add the LDCast config entry point to `src/mlcast/config/__init__.py` alongside the existing ConvGRU example config. +- [x] Keep `src/mlcast/config/orchestrator.py` compatible with both the existing single-stage `Experiment` and `LatentDiffusionTrainingExperiment` through a common `run()` surface. + +- [x] Keep existing ConvGRU CLI/config tests passing while adding separate tests for the two-stage config explicitly. + +- [ ] Add real but small-scale end-to-end tests with generated sample data for the autoencoder stage, diffusion stage, and full two-stage sequencing. +- [x] Align metric naming with TensorBoard conventions: use `/` as hierarchy separator and `rec_loss`/`loss` to distinguish stages. + +## 8. Align latent diffusion config with DMI/Martinbo reference + +The `ldcast-dmi/` reference implementation differs from our current +`latent_diffusion_experiment` config in several ways. Changes below would +align us more closely with DMI. + +### Optimizer +- **DMI**: `AdamW` with `lr=1e-3` (autoenc) / `1e-4` (diffusion), + `betas=(0.5, 0.9)`, `weight_decay=1e-3` for **both** stages. +- **Ours**: `Adam` with `lr=1e-4` for both stages, default betas, no + weight decay. +- **To align**: switch to `AdamW`, use DMI betas/weight_decay, and raise + autoencoder LR to `1e-3`. + +### LR scheduler +- **DMI**: `ReduceLROnPlateau(factor=0.25, patience=3)`, monitors + `val_rec_loss` (autoenc) / `val_loss_ema` (diffusion). +- **Ours**: `ReduceLROnPlateau(factor=0.5, patience=10)`, monitors + `val_loss` for both stages. +- **To align**: reduce factor to `0.25` and patience to `3`; use separate + monitor metrics per stage (autoenc → `val_loss`, diffusion → `val_loss`). + +### Learning rate warmup +- **DMI**: Linear warmup support in diffusion stage (`lr_warmup`, default + 0 — disabled). Autoencoder has none. +- **Ours**: No warmup in either stage. +- **To align**: no change needed unless LR warmup is desired. + +### EMA +- **DMI**: `LitEma` with `decay=0.9999` (adaptive based on num_updates), + only on diffusion model weights. EMA weights swapped in for + validation/testing. +- **Ours**: `ExponentialMovingAverage` with `decay=0.999` for diffusion + net, swapped in for val/test. +- **To align**: increase EMA decay to `0.9999`. + +### Early stopping +- **DMI**: patience `6`, monitors `val_rec_loss` / `val_loss_ema`, + `check_finite=False` on diffusion. +- **Ours**: patience `20`, monitors `val_loss`. +- **To align**: reduce patience to `6`; consider `check_finite=False`. + +### Model checkpointing +- **DMI**: `save_top_k=3`, monitors `val_rec_loss` / `val_loss_ema`. +- **Ours**: `save_top_k=1`, monitors `val_loss`. +- **To align**: increase save_top_k to `3`. + +### Diffusion noise schedule +- **DMI**: `timesteps=1000`, linear beta schedule from `1e-4` to `2e-2`. +- **Ours**: `timesteps=20`, default linear schedule. +- **To align**: increase to `timesteps=1000` and match beta range. + +### Batch size and gradient accumulation +- **DMI**: `batch_size=4` (autoenc, example) / `batch_size=1` (diffusion, + example); `accumulate_grad_batches=2`. +- **Ours**: `batch_size=16` / `8`; no gradient accumulation. +- **To align**: reduce batch sizes and add `accumulate_grad_batches=2`. + +### DDP strategy +- **DMI**: `DDPStrategy(find_unused_parameters=True)` on autoencoder. +- **Ours**: default (no `DDPStrategy`). +- **To align**: no change needed unless running DDP. + +## Martinbo alignment notes + +The `feat/ldcast-martinbo` branch differs from both our current config and +the DMI reference in several ways. + +### Optimizer +- **DMI**: `AdamW`, `lr=1e-3` / `1e-4`, `betas=(0.5, 0.9)`, `wd=1e-3`. +- **Martinbo**: `AdamW`, `lr=1e-3` / `1e-4`, `betas=[0.5, 0.9]`, `wd=0.001`. +- **Ours**: `Adam`, `lr=1e-4` for both, default betas, no weight decay. +- **To align**: Martinbo matches DMI exactly — `AdamW`, per-stage LR, betas, and wd. + +### LR scheduler +- **DMI**: `ReduceLROnPlateau(factor=0.25, patience=3)`, monitors + `val_rec_loss` / `val_loss_ema`. +- **Martinbo**: `ReduceLROnPlateau(factor=0.25, patience=3)`, monitors + `val/rec_loss` / `val/loss`. +- **Ours**: `ReduceLROnPlateau(factor=0.5, patience=10)`, monitors + `val_loss` for both stages. +- **To align**: Martinbo matches DMI's factor/patience; only monitor-metric + naming differs (`val/rec_loss` vs `val_rec_loss`). + +### Learning rate warmup +- **DMI**: Diffusion warmup support (`lr_warmup=0`, disabled by default). +- **Martinbo**: No warmup support in either stage. +- **Ours**: No warmup in either stage. +- **To align**: no change needed (DMI also has it disabled by default). + +### EMA +- **DMI**: `LitEma` with `decay=0.9999` (adaptive), on full diffusion model. +- **Martinbo**: `EMA` with `decay=0.9999` (dynamic, adaptive), wraps + **denoiser only** (`store_device='cuda'`). +- **Ours**: `ExponentialMovingAverage` with `decay=0.999`, on diffusion net. +- **To align**: increase decay to `0.9999`; consider whether EMA should wrap + the full diffusion net or just the denoiser. + +### Early stopping +- **DMI**: patience `6`, monitors `val_rec_loss` / `val_loss_ema`, + `check_finite=False` on diffusion. +- **Martinbo**: patience `6`, monitors `val/loss_epoch` (both stages), + `check_finite=False`. +- **Ours**: patience `20`, monitors `val_loss`. +- **To align**: Martinbo matches DMI's patience and `check_finite=False`; + monitor naming differs (`val/loss_epoch` vs `val_loss_ema`). + +### Model checkpointing +- **DMI**: `save_top_k=3`, monitors `val_rec_loss` / `val_loss_ema`. +- **Martinbo**: Not explicitly configured in branch `config.yaml` (relies on + Lightning default, `save_top_k=1`). +- **Ours**: `save_top_k=1`, monitors `val_loss`. +- **To align**: Martinbo implicitly matches Ours on `save_top_k`; DMI differs + with `save_top_k=3`. + +### Diffusion noise schedule +- **DMI**: `timesteps=1000`, linear beta `1e-4` to `2e-2`. +- **Martinbo**: `timesteps=1000`, linear beta `1e-4` to `2e-2` (defaults, + config section is `{}`). +- **Ours**: `timesteps=20`, default linear schedule. +- **To align**: Martinbo matches DMI exactly — `timesteps=1000`, same beta range. + +### Batch size and gradient accumulation +- **DMI**: `batch_size=4` / `1` (example configs), `accumulate_grad_batches=2`. +- **Martinbo**: `batch_size=1` for both stages; no `accumulate_grad_batches`. +- **Ours**: `batch_size=16` / `8`; no gradient accumulation. +- **To align**: Martinbo uses smaller batches than both DMI and Ours; none + of the three agree on batch size strategy. + +### DDP strategy +- **DMI**: `DDPStrategy(find_unused_parameters=True)` (autoenc) / + `DDPStrategy()` (diffusion). +- **Martinbo**: `strategy='ddp'` (string), `sync_batchnorm=True`, `num_nodes=1`. +- **Ours**: default (no `DDPStrategy`). +- **To align**: no change needed unless running DDP. + +### Diffusion parameterization and loss +- **DMI**: `parameterization="eps"`, `loss_type="l2"` (MSE). +- **Martinbo**: `parametrization="eps"` (note: spelling difference), + `nn.MSELoss()`. +- **Ours**: `parameterization="eps"` in `DiffusionLoss` (L2 via + `nn.MSELoss` reduction). +- **To align**: All three agree on `eps` + MSE — no change needed. + + +## 9. Possible future work + +Architecture and feature upgrades not covered by section 8 (config alignment). +Each entry explains why it might be worth doing. + +### VAE autoencoder (KL-regularised latent space) + +DMI uses `AutoencoderKL` — a VAE trained with a KL-divergence loss on the latent +distribution, producing a smoother, more Gaussian latent space. Ours uses a +deterministic autoencoder. + +**Why it might be a good idea**: Diffusion models assume the target distribution +is Gaussian (they start from Gaussian noise and reverse-diffuse). A KL- +regularised latent space is closer to Gaussian, which can make diffusion easier +and improve sample quality. It also enables latent-space interpolation and +manipulation. However, it adds training complexity (KL loss weighting, posterior +collapse risk). + +### Larger denoiser (3D UNet with cross-attention) + +DMI's `UNetModel` uses 128 model channels, attention blocks at multiple +resolutions, 8 attention heads, and 3D convolutions over (time, height, width). +Ours uses 32 hidden channels, no attention, and preserves time as a plain +channel dimension. + +**Why it might be a good idea**: More capacity → better fit to complex +precipitation patterns. Cross-attention allows the denoiser to selectively +attend to conditioning context at each resolution, which is more expressive +than our simple input concatenation. The cost is larger memory footprint and +longer training times. + +### Multi-resolution context encoder (AFNO cascade) + +DMI's `AFNONowcastNetCascade` produces a feature pyramid where each spatial +resolution has its own channel depth. The denoiser selects the appropriate level +based on its current spatial resolution. Ours uses a single-resolution +`ConditionerNet` that is interpolated if spatial sizes don't match. + +**Why it might be a good idea**: Multi-resolution conditioning lets the denoiser +access fine-grained local information at high resolutions and broad context at +low resolutions simultaneously. This is standard in modern conditional diffusion +models (e.g. Stable Diffusion's cross-attention to CLIP embeddings at multiple +scales). + +### PLMS accelerated sampling + +DMI's `PLMSSampler` uses Adams-Bashforth multistep integration to reduce 1000 +DDPM steps to ~50 with minimal quality loss. Ours uses a basic ancestral DDPM +sampler that iterates over all timesteps. + +**Why it might be a good idea**: 20× faster inference is critical for +operational nowcasting where latency matters. PLMS is well-established (from +CompVis latent-diffusion) and requires no retraining — it works with any +trained eps-predicting model. + +### Adaptive EMA decay (LitEma-style) + +DMI uses `LitEma` with adaptive decay `min(0.9999, (1+n)/(10+n))` where `n` is +the number of EMA updates. Ours uses a fixed decay of `0.9999`. + +**Why it might be a good idea**: Adaptive decay starts lower (giving more weight +to recent parameters early in training when they change fastest) and converges +to 0.9999. This accelerates early training while maintaining the benefits of EMA +at convergence. Simple to implement — just change the decay formula in `update`. + +### Multiple beta schedules (cosine, sqrt) + +DMI supports linear, cosine, sqrt_linear, and sqrt schedules. Ours supports +linear only. + +**Why it might be a good idea**: Cosine schedules (from Nichol & Dhariwal, +"Improved DDPM") add noise more gradually, which can improve sample quality — +especially at low resolutions or with fewer timesteps. Having multiple schedules +also enables hyperparameter search. The `DiffusionScheduler` would need to +accept a `schedule` string and dispatch to the right formula. + +### x0 parameterization and L1 loss + +DMI supports predicting `x0` (the clean target) instead of `eps` (the noise), +and using L1 instead of L2 for the loss. Ours uses `eps` + L2 only. + +**Why it might be a good idea**: `x0` prediction can be more stable at high +noise levels and is required for certain sampling techniques. L1 loss tends to +produce sharper outputs (less blurring than L2), which is desirable for +precipitation fields with sharp rain/no-rain boundaries. + +### Classifier-free guidance (CFG) in the sampler + +DMI's `PLMSSampler` supports CFG via an `unconditional_guidance_scale` +parameter. Ours has no CFG support. + +**Why it might be a good idea**: CFG lets the user trade off ensemble diversity +vs. forecast fidelity at inference time by scaling the conditional prediction +away from the unconditional prediction. Higher guidance → more confident +(less diverse) forecasts. This is useful operational flexibility. diff --git a/pyproject.toml b/pyproject.toml index 9817fe8..84843e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "absl-py>=2.4", "beartype>=0.18", "cf-xarray>=0.10", + "einops>=0.8", "etils>=1.13", "fiddle>=0.3", "fire>=0.7", @@ -85,6 +86,7 @@ scripts.mlcast = "mlcast.__main__:cli" [dependency-groups] dev = [ + "pre-commit>=4.6", "pytest>=9.0.3", ] mlflow = [ ] diff --git a/src/mlcast/__main__.py b/src/mlcast/__main__.py index 31ffc0b..b05028f 100644 --- a/src/mlcast/__main__.py +++ b/src/mlcast/__main__.py @@ -5,10 +5,11 @@ Usage examples:: - # Train with default config and override dataset path: + # Train with an included config and override dataset path: python -m mlcast train \\ + --config config:convgru_training_experiment \\ --config fiddler:use_random_sampler \\ - --config set:data.dataset_factory.zarr_path="'/path/to/data.zarr'" + --config set:data.sequence_dataset_factory.zarr_path="'/path/to/data.zarr'" # Train from a previously saved YAML config: python -m mlcast train --config /path/to/config.yaml @@ -18,8 +19,8 @@ --config /path/to/config.yaml \\ --config set:trainer.max_epochs=50 - # Switch to a different base config function entirely: - python -m mlcast train --config=config:another_experiment_function + # Use a different included config function: + python -m mlcast train --config config:another_experiment_function """ import argparse @@ -35,14 +36,14 @@ from rich.text import Text from . import config # noqa: F401 — module must be importable for absl_flags -from .config import load_yaml_config, train_from_config, training_experiment +from .config import convgru_training_experiment, load_yaml_config, train_from_config FLAGS = flags.FLAGS _config = absl_flags.DEFINE_fiddle_config( "config", default_module=config, - help_string="Experiment configuration. Default is training_experiment.", + help_string="Experiment configuration. Required: use config: or a YAML path.", ) flags.DEFINE_boolean( @@ -52,50 +53,68 @@ ) -def get_cli_examples(cfg: fdl.Buildable) -> list[tuple[str, str]]: +def get_included_config_names() -> list[str]: + """Return public config factory names exposed by ``mlcast.config``.""" + included_configs: list[str] = [] + for name in getattr(config, "__all__", []): + value = getattr(config, name, None) + if callable(value) and hasattr(value, "as_buildable"): + included_configs.append(name) + return included_configs + + +def get_cli_examples( + cfg: fdl.Buildable, base_config_name: str = "convgru_training_experiment" +) -> list[tuple[str, str]]: """Returns a list of (description, flag_string) tuples for CLI parameter overrides.""" return [ ( f"Override data layer properties (default batch_size: {cfg.data.batch_size})", - f"--config set:data.batch_size={max(1, cfg.data.batch_size * 2)}", + f"--config config:{base_config_name} --config set:data.batch_size={max(1, cfg.data.batch_size * 2)}", ), ( - f"Override the path to the Zarr dataset (default: {cfg.data.dataset_factory.zarr_path})", - "--config set:data.dataset_factory.zarr_path='/new/path/to/radar.zarr'", + f"Override the path to the Zarr dataset (default: {cfg.data.sequence_dataset_factory.zarr_path})", + f"--config config:{base_config_name} " + "--config set:data.sequence_dataset_factory.zarr_path='/new/path/to/radar.zarr'", ), ( f"Override trainer properties (default max_epochs: {cfg.trainer.max_epochs})", - f"--config set:trainer.max_epochs={max(1, cfg.trainer.max_epochs // 2)}", + f"--config config:{base_config_name} --config set:trainer.max_epochs={max(1, cfg.trainer.max_epochs // 2)}", ), ( f"Override network architecture properties (default num_blocks: {cfg.pl_module.network.num_blocks})", - f"--config set:pl_module.network.num_blocks={max(1, cfg.pl_module.network.num_blocks - 1)}", + "--config config:" + f"{base_config_name} " + "--config set:pl_module.network.num_blocks=" + f"{max(1, cfg.pl_module.network.num_blocks - 1)}", ), ( f"Override the optimizer learning rate (default lr: {cfg.pl_module.optimizer.lr})", - "--config set:pl_module.optimizer.lr=0.1", + f"--config config:{base_config_name} --config set:pl_module.optimizer.lr=0.1", ), ] -def get_fiddler_examples() -> list[tuple[str, str]]: +def get_fiddler_examples(base_config_name: str = "convgru_training_experiment") -> list[tuple[str, str]]: """Returns a list of (description, flag_string) tuples for Fiddler mutators.""" return [ ( "Switch to the random sampling dataset (instead of the precomputed CSV sampler)", - "--config fiddler:use_random_sampler", + f"--config config:{base_config_name} --config fiddler:use_random_sampler", ), ( "Change the input variables and automatically adjust the network's input_channels", + "--config config:" + f"{base_config_name} " "--config \"fiddler:set_variables(standard_names=['rainfall_rate', 'reflectivity'])\"", ), ( "Toggle whether the loss function ignores masked/invalid pixels", - '--config "fiddler:toggle_masking(enabled=False)"', + f'--config config:{base_config_name} --config "fiddler:toggle_masking(enabled=False)"', ), ( "Train using an anonymous S3 object store dataset (e.g. the Italian dataset)", - '--config "fiddler:use_anon_s3_dataset(' + f'--config config:{base_config_name} --config "fiddler:use_anon_s3_dataset(' "zarr_path='s3://mlcast-source-datasets/IT-DPC-SRI/v0.1.0/italian-radar-dpc-sri.zarr/', " "endpoint_url='https://object-store.os-api.cci2.ecmwf.int')\"", ), @@ -105,14 +124,24 @@ def get_fiddler_examples() -> list[tuple[str, str]]: def _build_help_text(cfg: fdl.Buildable) -> Text: """Build Rich-highlighted help text for the ``train`` subcommand.""" t = Text() + included_config_names = get_included_config_names() t.append("Train a model using a Fiddle configuration.\n\n", style="bold") + t.append("You must provide a base config via ", style="bold") + t.append("--config config:", style="bold cyan") + t.append(" or ", style="bold") + t.append("--config /path/to/config.yaml", style="bold cyan") + t.append(".\n\n", style="bold") + t.append("Included configs:\n", style="bold yellow") + for name in included_config_names: + t.append(f" - {name}\n", style="green") + t.append("\n") t.append("You can override parameters from the command line using the ") t.append("--config set:path.to.param=value", style="bold cyan") t.append(" syntax.\n\n") t.append("Examples", style="bold yellow") - t.append(" (based on the default ") - t.append("training_experiment", style="bold green") + t.append(" (based on the included ") + t.append("convgru_training_experiment", style="bold green") t.append(" config):\n") for desc, cmd in get_cli_examples(cfg): @@ -130,25 +159,28 @@ def _build_help_text(cfg: fdl.Buildable) -> Text: t.append(cmd, style="cyan") t.append("\n") - t.append("\nSwitching experiments:\n", style="bold yellow") + t.append("\nConfig sources:\n", style="bold yellow") t.append("\n # Resume from or reproduce a previously saved YAML config:\n", style="dim") t.append(" mlcast train ", style="bold") t.append("--config /path/to/config.yaml\n", style="cyan") t.append("\n # Load a YAML config and apply additional overrides on top:\n", style="dim") t.append(" mlcast train ", style="bold") t.append("--config /path/to/config.yaml --config set:trainer.max_epochs=50\n", style="cyan") - t.append("\n # Use a different base config function defined in src/mlcast/config/.\n", style="dim") + t.append("\n # Use a different included config function defined in src/mlcast/config/.\n", style="dim") t.append(" # Syntax: --config=config: where ", style="dim") t.append("config:", style="dim bold") t.append(" is a Fiddle prefix that resolves\n", style="dim") t.append(" # the function name against the config module (not a Python module path).\n", style="dim") t.append(" mlcast train ", style="bold") - t.append("--config=config:another_experiment_function\n", style="cyan") + t.append("--config config:convgru_training_experiment\n", style="cyan") t.append("\nInspecting the resolved config:\n", style="bold yellow") t.append(" # Print the fully resolved config as YAML without starting training:\n", style="dim") t.append(" mlcast train ", style="bold") - t.append("--config fiddler:use_random_sampler --print_config_and_exit\n", style="cyan") + t.append( + "--config config:convgru_training_experiment --config fiddler:use_random_sampler --print_config_and_exit\n", + style="cyan", + ) return t @@ -335,14 +367,14 @@ def train_main(argv: list[str]) -> None: def cli() -> None: """Console script entry point for the ``mlcast`` command. - This parses standard CLI arguments via `argparse`, injects Fiddle default - overrides if no base configuration is provided, formats the `--help` - output, and safely passes execution over to `absl.app.run`. + This parses standard CLI arguments via `argparse`, validates that an + explicit base configuration was provided, formats the `--help` output, and + safely passes execution over to `absl.app.run`. """ # Dynamically generate help text showing Fiddle overrides try: - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() description_text = _build_help_text(cfg) except Exception: # Fallback if config generation fails during CLI initialization @@ -378,16 +410,22 @@ def cli() -> None: yaml_path, remaining = _extract_yaml_config_path(remaining) # Case 2: user supplied an explicit base config function - # e.g. --config=config:another_experiment_function + # e.g. --config config:convgru_training_experiment has_explicit_config = any( arg.startswith("--config=config:") or (arg == "--config" and i + 1 < len(remaining) and remaining[i + 1].startswith("config:")) for i, arg in enumerate(remaining) ) - # Case 3: no base config from either source — fall back to training_experiment if not has_explicit_config and yaml_path is None: - remaining = ["--config=config:training_experiment"] + remaining + included = ", ".join(get_included_config_names()) + print( + "Error: a base config is required. Provide either " + "'--config config:' or '--config /path/to/config.yaml'.\n" + f"Included configs: {included}", + file=sys.stderr, + ) + sys.exit(1) remaining = auto_quote_fiddle_strings(remaining) diff --git a/src/mlcast/config/__init__.py b/src/mlcast/config/__init__.py index 194ad6b..2762d96 100644 --- a/src/mlcast/config/__init__.py +++ b/src/mlcast/config/__init__.py @@ -4,7 +4,9 @@ and runtime orchestration logic for `mlcast`. """ -from .base import Experiment, training_experiment +from .archetype.convgru import convgru_training_experiment +from .archetype.latent_diffusion import LatentDiffusionTrainingExperiment, latent_diffusion_experiment +from .base import Experiment from .consistency_checks import validate_config from .fiddlers import ( set_variables, @@ -19,7 +21,9 @@ __all__ = [ "Experiment", - "training_experiment", + "LatentDiffusionTrainingExperiment", + "convgru_training_experiment", + "latent_diffusion_experiment", "validate_config", "train_from_config", "load_yaml_config", diff --git a/src/mlcast/config/archetype/__init__.py b/src/mlcast/config/archetype/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/mlcast/config/archetype/convgru.py b/src/mlcast/config/archetype/convgru.py new file mode 100644 index 0000000..1d144fe --- /dev/null +++ b/src/mlcast/config/archetype/convgru.py @@ -0,0 +1,85 @@ +"""ConvGRU ensemble nowcasting experiment configuration.""" + +import fiddle as fdl +import fiddle.experimental.auto_config +import pytorch_lightning as pl +import torch +from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger + +from mlcast.data.datamodules import ForecastingDataModule +from mlcast.data.sequence import SourceDataPrecomputedSequenceDataset +from mlcast.models.convgru import ConvGruModel +from mlcast.modules.forecasting import OutputSpaceForecastingTaskModule + +from ..base import Experiment + + +@fiddle.experimental.auto_config.auto_config +def convgru_training_experiment() -> Experiment: + """Build a Fiddle config for ConvGRU ensemble radar nowcasting. + + This is decorated as a Fiddle ``@auto_config`` function: calling it + returns a buildable config graph where any parameter can be overridden + before instantiation via ``fdl.build()``. + + Returns + ------- + Experiment + Configured experiment with model, data, and trainer. + """ + sequence_dataset_factory = fdl.Partial( + SourceDataPrecomputedSequenceDataset, + zarr_path="./data/radar.zarr", + csv_path="./data/sampled_datacubes.csv", + standard_names=["rainfall_rate"], + sequence_steps=18, + deterministic=False, + ) + + data = ForecastingDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=6, + forecast_steps=12, + return_mask=True, + splits={"time": {"train": 0.70, "val": 0.15, "test": 0.15}}, + batch_size=16, + num_workers=8, + pin_memory=True, + ) + + network = ConvGruModel( + input_steps=6, + forecast_steps=12, + ensemble_size=2, + input_channels=1, + num_blocks=5, + noisy_decoder=False, + ) + + pl_module = OutputSpaceForecastingTaskModule( + network=network, + loss_class="crps", + loss_params={"temporal_lambda": 0.01}, + masked_loss=True, + optimizer=fdl.Partial(torch.optim.Adam, lr=1e-4, fused=True), + lr_scheduler=fdl.Partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode="min", factor=0.5, patience=10), + ) + + trainer = pl.Trainer( + accelerator="auto", + max_epochs=100, + callbacks=[ + ModelCheckpoint(monitor="val/loss", save_top_k=1, mode="min"), + ModelCheckpoint(monitor="train/loss_epoch", save_top_k=1, mode="min"), + EarlyStopping(monitor="val/loss", patience=100, mode="min"), + LearningRateMonitor(logging_interval="step"), + ], + logger=TensorBoardLogger(save_dir="logs", name="mlcast"), + ) + + return Experiment( + pl_module=pl_module, + data=data, + trainer=trainer, + ) diff --git a/src/mlcast/config/archetype/latent_diffusion.py b/src/mlcast/config/archetype/latent_diffusion.py new file mode 100644 index 0000000..80f8c10 --- /dev/null +++ b/src/mlcast/config/archetype/latent_diffusion.py @@ -0,0 +1,151 @@ +"""Fiddle configuration for two-stage latent diffusion training.""" + +from dataclasses import dataclass + +import fiddle as fdl +import fiddle.experimental.auto_config +import pytorch_lightning as pl +import torch +from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger + +from mlcast.data.datamodules import ForecastingDataModule, ReconstructionDataModule +from mlcast.data.sequence import SourceDataPrecomputedSequenceDataset +from mlcast.models.autoencoder import AutoencoderNet, Decoder, Encoder +from mlcast.models.diffusion import ConditionerNet, DenoiserUNet, DiffusionScheduler, LatentDiffusionNet +from mlcast.modules.forecasting import LatentDiffusionTaskModule +from mlcast.modules.reconstruction import ReconstructionTaskModule + +from ..base import Experiment + + +@dataclass +class LatentDiffusionTrainingExperiment: + """Two-stage latent diffusion training experiment. + + Parameters + ---------- + stage1 : Experiment + Reconstruction training stage for the autoencoder. + stage2 : Experiment + Latent diffusion training stage reusing the same trained autoencoder + instance from stage 1. + """ + + stage1: Experiment + stage2: Experiment + + @property + def trainer(self) -> pl.Trainer: + """Expose the first trainer for orchestrator compatibility. + + Returns + ------- + pl.Trainer + Trainer used by stage 1. + """ + return self.stage1.trainer + + def run(self) -> None: + """Run stage-1 reconstruction followed by stage-2 latent diffusion.""" + self.stage1.trainer.fit(self.stage1.pl_module, datamodule=self.stage1.data) + self.stage1.trainer.test(self.stage1.pl_module, datamodule=self.stage1.data) + + self.stage2.trainer.fit(self.stage2.pl_module, datamodule=self.stage2.data) + self.stage2.trainer.test(self.stage2.pl_module, datamodule=self.stage2.data) + + +@fiddle.experimental.auto_config.auto_config +def latent_diffusion_experiment() -> LatentDiffusionTrainingExperiment: + """Build a Fiddle config for two-stage latent diffusion training. + + Returns + ------- + LatentDiffusionTrainingExperiment + Configured two-stage experiment with shared autoencoder identity across + reconstruction and latent diffusion stages. + """ + input_steps = 4 + forecast_steps = 12 + sequence_steps = input_steps + forecast_steps + + sequence_dataset_factory = fdl.Partial( + SourceDataPrecomputedSequenceDataset, + zarr_path="./data/radar.zarr", + csv_path="./data/sampled_datacubes.csv", + standard_names=["rainfall_rate"], + sequence_steps=sequence_steps, + deterministic=False, + ) + + autoencoder = AutoencoderNet( + encoder=Encoder(input_channels=1, hidden_channels=16, latent_channels=32, num_blocks=2), + decoder=Decoder(output_channels=1, hidden_channels=16, latent_channels=32, num_blocks=2), + ) + + stage1_data = ReconstructionDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=input_steps, + splits={"time": {"train": 0.70, "val": 0.15, "test": 0.15}}, + batch_size=4, + num_workers=8, + pin_memory=True, + ) + stage1_module = ReconstructionTaskModule( + network=autoencoder, + loss_class="mse", + optimizer=fdl.Partial(torch.optim.AdamW, lr=1e-3, betas=(0.5, 0.9), weight_decay=1e-3, fused=True), + lr_scheduler=fdl.Partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode="min", factor=0.25, patience=3), + ) + stage1_trainer = pl.Trainer( + accelerator="auto", + max_epochs=20, + accumulate_grad_batches=2, + callbacks=[ + ModelCheckpoint(monitor="val/rec_loss", save_top_k=3, mode="min"), + EarlyStopping(monitor="val/rec_loss", patience=6, mode="min"), + LearningRateMonitor(logging_interval="step"), + ], + logger=TensorBoardLogger(save_dir="logs", name="mlcast_latent_diffusion_stage1"), + ) + + stage2_data = ForecastingDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=input_steps, + forecast_steps=forecast_steps, + return_mask=False, + splits={"time": {"train": 0.70, "val": 0.15, "test": 0.15}}, + batch_size=1, + num_workers=8, + pin_memory=True, + ) + diffusion_net = LatentDiffusionNet( + conditioner=ConditionerNet(latent_channels=32, hidden_channels=32, num_blocks=2), + denoiser=DenoiserUNet(latent_channels=32, condition_channels=32, hidden_channels=32, num_blocks=2), + scheduler=DiffusionScheduler(timesteps=1000), + ) + stage2_module = LatentDiffusionTaskModule( + autoencoder=autoencoder, + diffusion_net=diffusion_net, + forecast_steps=forecast_steps, + ensemble_size=2, + optimizer=fdl.Partial(torch.optim.AdamW, lr=1e-4, betas=(0.5, 0.9), weight_decay=1e-3, fused=True), + lr_scheduler=fdl.Partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode="min", factor=0.25, patience=3), + ema_decay=0.9999, + ) + stage2_trainer = pl.Trainer( + accelerator="auto", + max_epochs=20, + accumulate_grad_batches=2, + callbacks=[ + ModelCheckpoint(monitor="val/loss", save_top_k=3, mode="min"), + EarlyStopping(monitor="val/loss", patience=6, mode="min", check_finite=False), + LearningRateMonitor(logging_interval="step"), + ], + logger=TensorBoardLogger(save_dir="logs", name="mlcast_latent_diffusion_stage2"), + ) + + return LatentDiffusionTrainingExperiment( + stage1=Experiment(pl_module=stage1_module, data=stage1_data, trainer=stage1_trainer), + stage2=Experiment(pl_module=stage2_module, data=stage2_data, trainer=stage2_trainer), + ) diff --git a/src/mlcast/config/base.py b/src/mlcast/config/base.py index fca8117..2c76f3e 100644 --- a/src/mlcast/config/base.py +++ b/src/mlcast/config/base.py @@ -1,38 +1,12 @@ -"""Base Fiddle experiment definitions for ConvGRU radar nowcasting. +"""Base Fiddle experiment definition for radar nowcasting. -This module defines the ``Experiment`` dataclass and the -``training_experiment`` auto-config factory, which together form the default -configuration graph for a ConvGRU ensemble nowcasting run. - -``training_experiment`` is decorated with ``@auto_config``: calling it returns -a ``fdl.Config`` graph rather than a live ``Experiment`` object. Every -parameter in the graph can be overridden before instantiation — either via -fiddlers (for semantic, multi-parameter changes) or via ``set:`` overrides on -the CLI (for single-parameter tweaks). Call ``fdl.build(cfg)`` to materialise -the graph into real Python objects. - -Typical usage -------------- ->>> cfg = training_experiment() # returns fdl.Config ->>> use_random_sampler(cfg) # apply a fiddler ->>> validate_config(cfg) # check cross-parameter contracts ->>> experiment = fdl.build(cfg) # instantiate everything ->>> experiment.run() # train + test +This module defines the ``Experiment`` dataclass used across all experiment +configurations. """ from dataclasses import dataclass -import fiddle as fdl -import fiddle.experimental.auto_config import pytorch_lightning as pl -import torch -from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint -from pytorch_lightning.loggers import TensorBoardLogger - -from ..data.source_data_datamodule import SourceDataDataModule -from ..data.source_data_datasets import SourceDataPrecomputedSamplingDataset -from ..models.convgru import ConvGruModel -from ..nowcasting_module import NowcastLightningModule @dataclass @@ -47,70 +21,3 @@ def run(self) -> None: """Train and evaluate the configured model.""" self.trainer.fit(self.pl_module, datamodule=self.data) self.trainer.test(self.pl_module, datamodule=self.data) - - -@fiddle.experimental.auto_config.auto_config -def training_experiment() -> Experiment: - """Build a Fiddle config for ConvGRU ensemble radar nowcasting. - - This is decorated as a Fiddle ``@auto_config`` function: calling it - returns a buildable config graph where any parameter can be overridden - before instantiation via ``fdl.build()``. - - Returns - ------- - Experiment - Configured experiment with model, data, and trainer. - """ - dataset_factory = fdl.Partial( - SourceDataPrecomputedSamplingDataset, - zarr_path="./data/radar.zarr", - csv_path="./data/sampled_datacubes.csv", - standard_names=["rainfall_rate"], - input_steps=6, - forecast_steps=12, - return_mask=True, - deterministic=False, - ) - - data = SourceDataDataModule( - dataset_factory=dataset_factory, - splits={"time": {"train": 0.70, "val": 0.15, "test": 0.15}}, - batch_size=16, - num_workers=8, - pin_memory=True, - ) - - network = ConvGruModel( - input_channels=1, - num_blocks=5, - noisy_decoder=False, - ) - - pl_module = NowcastLightningModule( - network=network, - ensemble_size=2, - loss_class="crps", - loss_params={"temporal_lambda": 0.01}, - masked_loss=True, - optimizer=fdl.Partial(torch.optim.Adam, lr=1e-4, fused=True), - lr_scheduler=fdl.Partial(torch.optim.lr_scheduler.ReduceLROnPlateau, mode="min", factor=0.5, patience=10), - ) - - trainer = pl.Trainer( - accelerator="auto", - max_epochs=100, - callbacks=[ - ModelCheckpoint(monitor="val_loss", save_top_k=1, mode="min"), - ModelCheckpoint(monitor="train_loss_epoch", save_top_k=1, mode="min"), - EarlyStopping(monitor="val_loss", patience=100, mode="min"), - LearningRateMonitor(logging_interval="step"), - ], - logger=TensorBoardLogger(save_dir="logs", name="mlcast"), - ) - - return Experiment( - pl_module=pl_module, - data=data, - trainer=trainer, - ) diff --git a/src/mlcast/config/consistency_checks.py b/src/mlcast/config/consistency_checks.py index dea9f0a..43314d8 100644 --- a/src/mlcast/config/consistency_checks.py +++ b/src/mlcast/config/consistency_checks.py @@ -15,27 +15,26 @@ from loguru import logger -def validate_config(cfg: fdl.Config) -> None: - """Validate cross-system constraints on a Fiddle configuration before training. +def _validate_forecasting_experiment_cfg(cfg: fdl.Config) -> None: + """Validate a single-stage forecasting experiment configuration. Parameters ---------- cfg : fdl.Config - Fiddle configuration. + Fiddle configuration for a single forecasting experiment. Raises ------ ValueError If any configuration contract is violated. """ - dataset_factory = cfg.data.dataset_factory + sequence_dataset_factory = cfg.data.sequence_dataset_factory network = cfg.pl_module.network pl_module = cfg.pl_module + data = cfg.data - # Contract 1: Network input_channels == len(dataset_factory.standard_names) - # If the network does not expose input_channels, emit a warning because - # this contract cannot be checked. - num_vars = len(dataset_factory.standard_names) + # Contract 1: Network input_channels == len(sequence_dataset_factory.standard_names) + num_vars = len(sequence_dataset_factory.standard_names) try: net_input_channels = network.input_channels except AttributeError: @@ -51,9 +50,7 @@ def validate_config(cfg: fdl.Config) -> None: f"must equal the number of standard_names ({num_vars})." ) - # Contract 2: Dataset width must be divisible by 2 ** network.num_blocks - # If the network does not expose num_blocks, emit a warning because this - # contract cannot be checked. + # Contract 2: Sequence dataset width must be divisible by 2 ** network.num_blocks try: num_blocks = network.num_blocks except AttributeError: @@ -64,7 +61,7 @@ def validate_config(cfg: fdl.Config) -> None: ) num_blocks = None if num_blocks is not None: - width = getattr(dataset_factory, "width", 256) + width = getattr(sequence_dataset_factory, "width", 256) divisor = 2**num_blocks if width % divisor != 0: raise ValueError( @@ -73,16 +70,113 @@ def validate_config(cfg: fdl.Config) -> None: ) # Contract 3: Ensemble models require CRPS or AFCRPS - if pl_module.ensemble_size > 1: + ensemble_size = getattr(network, "ensemble_size", 1) + if ensemble_size > 1: if str(pl_module.loss_class).lower() not in ["crps", "afcrps"]: raise ValueError( - f"Contract 3 violated: Ensemble models (ensemble_size={pl_module.ensemble_size}) " + f"Contract 3 violated: Ensemble models (ensemble_size={ensemble_size}) " f"require 'crps' or 'afcrps' loss, got '{pl_module.loss_class}'." ) - # Contract 4: Dataset return_mask must match model masked_loss - if bool(dataset_factory.return_mask) != bool(pl_module.masked_loss): + # Contract 4: Forecasting mask return must match model masked_loss + if bool(data.return_mask) != bool(pl_module.masked_loss): raise ValueError( - f"Contract 4 violated: dataset_factory.return_mask ({dataset_factory.return_mask}) " + f"Contract 4 violated: data.return_mask ({data.return_mask}) " f"must match pl_module.masked_loss ({pl_module.masked_loss})." ) + + # Contract 5: Dataset input_steps must match model input_steps + try: + net_input_steps = network.input_steps + except AttributeError: + logger.warning( + "Warning: can't ensure network input_steps matches data.input_steps, " + "because network {} doesn't expose 'input_steps'.", + network.__class__.__name__, + ) + net_input_steps = None + if net_input_steps is not None and net_input_steps != data.input_steps: + raise ValueError( + f"Contract 5 violated: network input_steps ({net_input_steps}) " + f"must equal data.input_steps ({data.input_steps})." + ) + + # Contract 6: Dataset forecast_steps must match model forecast_steps + try: + net_forecast_steps = network.forecast_steps + except AttributeError: + logger.warning( + "Warning: can't ensure network forecast_steps matches data.forecast_steps, " + "because network {} doesn't expose 'forecast_steps'.", + network.__class__.__name__, + ) + net_forecast_steps = None + if net_forecast_steps is not None and net_forecast_steps != data.forecast_steps: + raise ValueError( + f"Contract 6 violated: network forecast_steps ({net_forecast_steps}) " + f"must equal data.forecast_steps ({data.forecast_steps})." + ) + + +def _validate_latent_diffusion_experiment_cfg(cfg: fdl.Config) -> None: + """Validate a two-stage latent diffusion training experiment configuration. + + Parameters + ---------- + cfg : fdl.Config + Fiddle configuration for a two-stage latent diffusion experiment. + + Raises + ------ + ValueError + If any latent-diffusion-specific configuration contract is violated. + """ + stage1 = cfg.stage1 + stage2 = cfg.stage2 + + autoencoder = stage1.pl_module.network + if autoencoder is not stage2.pl_module.autoencoder: + raise ValueError( + "LatentDiffusion contract violated: stage1 and stage2 must share the same autoencoder config object." + ) + + stage1_data = stage1.data + stage2_data = stage2.data + if stage1_data.input_steps != stage2_data.input_steps: + raise ValueError( + "LatentDiffusion contract violated: stage1 and stage2 must use the same input_steps; " + f"got {stage1_data.input_steps} and {stage2_data.input_steps}." + ) + + stage2_module = stage2.pl_module + if stage2_data.forecast_steps != stage2_module.forecast_steps: + raise ValueError( + "LatentDiffusion contract violated: stage2 data.forecast_steps must match the latent diffusion " + f"task module; got {stage2_data.forecast_steps} and {stage2_module.forecast_steps}." + ) + + if len(stage1_data.sequence_dataset_factory.standard_names) != autoencoder.encoder.input_channels: + raise ValueError( + "LatentDiffusion contract violated: autoencoder encoder input_channels must match the " + "number of source variables." + ) + + +def validate_config(cfg: fdl.Config) -> None: + """Validate cross-system constraints on a Fiddle configuration before training. + + Parameters + ---------- + cfg : fdl.Config + Fiddle configuration. + + Raises + ------ + ValueError + If any configuration contract is violated. + """ + if hasattr(cfg, "stage1") and hasattr(cfg, "stage2"): + _validate_latent_diffusion_experiment_cfg(cfg) + return + + _validate_forecasting_experiment_cfg(cfg) diff --git a/src/mlcast/config/fiddlers.py b/src/mlcast/config/fiddlers.py index dceee26..69261a7 100644 --- a/src/mlcast/config/fiddlers.py +++ b/src/mlcast/config/fiddlers.py @@ -1,6 +1,6 @@ """Fiddler mutators for high-level semantic configuration changes. -Fiddlers are functions that accept a ``fdl.Config`` and mutate it in place. +Fiddlers are functions that accept a ``fdl.Config`` and mutate them in place. They are the right tool when a change spans multiple config parameters that must stay in sync — for example, switching the dataset class while preserving its existing parameters, or enabling masking consistently across both the data @@ -11,26 +11,120 @@ config in Python before passing it to ``fdl.build()``. """ +import functools import inspect import os +from collections.abc import Callable import fiddle as fdl +import torch.nn as nn from loguru import logger from pytorch_lightning.loggers import MLFlowLogger +from mlcast.config.base import Experiment + from ..callbacks import LogSystemInfoCallback -from ..data.source_data_datasets import SourceDataRandomSamplingDataset +from ..data.sequence import SourceDataRandomSequenceDataset + + +def _iter_experiment_configs(cfg: fdl.Buildable): + """Yield all ``fdl.Config`` sub-nodes whose callable is ``Experiment``, depth-first. + + Parameters + ---------- + cfg : fdl.Buildable + Root of the Fiddle configuration tree to traverse. + + Yields + ------ + fdl.Config + Each sub-config whose ``fdl.get_callable`` is the ``Experiment`` + dataclass. + """ + if not isinstance(cfg, fdl.Buildable): + return + try: + if fdl.get_callable(cfg) is Experiment: + yield cfg + except (TypeError, AttributeError): + pass + try: + for child in fdl.ordered_arguments(cfg).values(): + yield from _iter_experiment_configs(child) + except (TypeError, AttributeError): + pass + + +def _find_nn_modules_with_input_channels(cfg: fdl.Buildable): + """Yield all ``fdl.Config`` nodes for ``nn.Module`` subclasses that accept ``input_channels``. + + Parameters + ---------- + cfg : fdl.Buildable + Root of the Fiddle configuration tree to traverse (typically + ``cfg.pl_module``). + + Yields + ------ + fdl.Config + Each sub-config whose callable is an ``nn.Module`` subclass with + ``input_channels`` in its ``__init__`` signature. + """ + if not isinstance(cfg, fdl.Config): + return + try: + cls = fdl.get_callable(cfg) + if isinstance(cls, type) and issubclass(cls, nn.Module): + if "input_channels" in inspect.signature(cls.__init__).parameters: + yield cfg + except (TypeError, AttributeError): + pass + try: + for child in fdl.ordered_arguments(cfg).values(): + yield from _find_nn_modules_with_input_channels(child) + except (TypeError, AttributeError): + pass + + +def applies_to_experiments(fiddler: Callable) -> Callable: + """Decorate a fiddler so it applies to every ``Experiment`` sub-config in the tree. + + This makes fiddlers work with both flat ``Experiment`` configs (returned by + ``convgru_training_experiment``) and nested containers like + ``LatentDiffusionTrainingExperiment`` that contain multiple ``Experiment`` instances. + + Parameters + ---------- + fiddler : Callable + Fiddler function whose first argument is a ``fdl.Config``. + + Returns + ------- + Callable + Wrapped fiddler that traverses the config tree for ``Experiment`` + sub-configs and applies the original fiddler to each one. + """ + + @functools.wraps(fiddler) + def wrapper(cfg: fdl.Buildable, *args: object, **kwargs: object) -> None: + experiments = list(_iter_experiment_configs(cfg)) + if experiments: + for exp_cfg in experiments: + fiddler(exp_cfg, *args, **kwargs) + else: + fiddler(cfg, *args, **kwargs) + + return wrapper +@applies_to_experiments def set_variables(cfg: fdl.Config, standard_names: list[str]) -> None: """Fiddler to synchronize dataset variables with the network's input channels. - Sets ``dataset_factory.standard_names`` on the data config and, when the - network config exposes an ``input_channels`` parameter (e.g. - ``ConvGruModel``), keeps it in sync. Networks that use a different - parameter name for the channel count (e.g. ``HalfUNet`` uses - ``in_channels``) are left unchanged — callers are responsible for keeping - that parameter consistent when swapping in an external architecture. + Sets ``sequence_dataset_factory.standard_names`` on the data config and + walks ``cfg.pl_module`` to find any ``nn.Module`` with an ``input_channels`` + ``__init__`` parameter (e.g. ``ConvGruModel``, ``Encoder``), keeping it in + sync with the number of loaded variables. Parameters ---------- @@ -39,21 +133,20 @@ def set_variables(cfg: fdl.Config, standard_names: list[str]) -> None: standard_names : list of str The new list of standard names to load. """ - cfg.data.dataset_factory.standard_names = standard_names - network_cls = cfg.pl_module.network.__fn_or_cls__ - sig = inspect.signature(network_cls.__init__) - if "input_channels" in sig.parameters: - cfg.pl_module.network.input_channels = len(standard_names) - else: + cfg.data.sequence_dataset_factory.standard_names = standard_names + found = False + for module_cfg in _find_nn_modules_with_input_channels(cfg.pl_module): + module_cfg.input_channels = len(standard_names) + found = True + if not found: logger.warning( - "set_variables: network {} has no 'input_channels' parameter; " - "channel count not updated. Set it manually on the network config.", - network_cls.__name__, + "set_variables: no nn.Module under pl_module has an 'input_channels' parameter; channel count not updated." ) +@applies_to_experiments def toggle_masking(cfg: fdl.Config, enabled: bool) -> None: - """Fiddler to synchronize dataset mask yielding with masked loss computation. + """Fiddler to synchronize forecasting-mask yielding with masked loss computation. Parameters ---------- @@ -62,12 +155,13 @@ def toggle_masking(cfg: fdl.Config, enabled: bool) -> None: enabled : bool Whether to enable masking or not. """ - cfg.data.dataset_factory.return_mask = enabled + cfg.data.return_mask = enabled cfg.pl_module.masked_loss = enabled +@applies_to_experiments def use_random_sampler(cfg: fdl.Config) -> None: - """Fiddler to switch the dataset factory to use the random sampler. + """Fiddler to switch the sequence dataset factory to use the random sampler. Parameters ---------- @@ -75,22 +169,22 @@ def use_random_sampler(cfg: fdl.Config) -> None: The Fiddle configuration to mutate. """ # Keep the existing parameters but change the underlying class - cfg.data.dataset_factory = fdl.Partial( - SourceDataRandomSamplingDataset, - zarr_path=cfg.data.dataset_factory.zarr_path, - standard_names=cfg.data.dataset_factory.standard_names, - input_steps=cfg.data.dataset_factory.input_steps, - forecast_steps=cfg.data.dataset_factory.forecast_steps, - return_mask=cfg.data.dataset_factory.return_mask, - storage_options=getattr(cfg.data.dataset_factory, "storage_options", None), + cfg.data.sequence_dataset_factory = fdl.Partial( + SourceDataRandomSequenceDataset, + zarr_path=cfg.data.sequence_dataset_factory.zarr_path, + standard_names=cfg.data.sequence_dataset_factory.standard_names, + sequence_steps=cfg.data.sequence_dataset_factory.sequence_steps, + storage_options=getattr(cfg.data.sequence_dataset_factory, "storage_options", None), ) +@applies_to_experiments def use_ratio_splits(cfg: fdl.Config, train: float, val: float) -> None: """Fiddler to set fraction-based train/val/test splits on the data module.""" cfg.data.splits = {"time": {"train": train, "val": val, "test": 1.0 - train - val}} +@applies_to_experiments def use_anon_s3_dataset(cfg: fdl.Buildable, zarr_path: str, endpoint_url: str) -> None: """Configure the dataset factory to read anonymously from an S3 object store. @@ -103,8 +197,8 @@ def use_anon_s3_dataset(cfg: fdl.Buildable, zarr_path: str, endpoint_url: str) - endpoint_url : str The endpoint URL for the S3 object store. """ - cfg.data.dataset_factory.zarr_path = zarr_path - cfg.data.dataset_factory.storage_options = { + cfg.data.sequence_dataset_factory.zarr_path = zarr_path + cfg.data.sequence_dataset_factory.storage_options = { "anon": True, "client_kwargs": { "endpoint_url": endpoint_url, @@ -114,6 +208,7 @@ def use_anon_s3_dataset(cfg: fdl.Buildable, zarr_path: str, endpoint_url: str) - } +@applies_to_experiments def use_mlflow_logger(cfg: fdl.Config) -> None: """Fiddler to switch the trainer logger to MLflow. diff --git a/src/mlcast/config/orchestrator.py b/src/mlcast/config/orchestrator.py index 16fea32..9d21ec0 100644 --- a/src/mlcast/config/orchestrator.py +++ b/src/mlcast/config/orchestrator.py @@ -126,7 +126,8 @@ def train_from_config(cfg: fdl.Config) -> None: Parameters ---------- cfg : fdl.Config - Fiddle configuration as returned by `training_experiment`. + Fiddle configuration as returned by an included auto-config factory such + as `convgru_training_experiment`. """ validate_config(cfg) diff --git a/src/mlcast/data/__init__.py b/src/mlcast/data/__init__.py index e4b2449..3d1a03d 100644 --- a/src/mlcast/data/__init__.py +++ b/src/mlcast/data/__init__.py @@ -1,4 +1,18 @@ -from .source_data_datamodule import SourceDataDataModule -from .source_data_datasets import SourceDataPrecomputedSamplingDataset +from .datamodules import ForecastingDataModule, ReconstructionDataModule +from .forecasting import ForecastingDataset +from .reconstruction import ReconstructionDataset +from .sequence import ( + SourceDataPrecomputedSequenceDataset, + SourceDataRandomSequenceDataset, + SourceDataSequenceDatasetBase, +) -__all__ = ["SourceDataDataModule", "SourceDataPrecomputedSamplingDataset"] +__all__ = [ + "ForecastingDataModule", + "ForecastingDataset", + "ReconstructionDataModule", + "ReconstructionDataset", + "SourceDataPrecomputedSequenceDataset", + "SourceDataRandomSequenceDataset", + "SourceDataSequenceDatasetBase", +] diff --git a/src/mlcast/data/datamodules.py b/src/mlcast/data/datamodules.py new file mode 100644 index 0000000..be6492d --- /dev/null +++ b/src/mlcast/data/datamodules.py @@ -0,0 +1,278 @@ +"""PyTorch Lightning data modules for forecasting and reconstruction tasks.""" + +from collections.abc import Callable +from typing import Any + +import pytorch_lightning as pl +from loguru import logger +from torch.utils.data import DataLoader, Dataset + +from mlcast.data.forecasting import ForecastingDataset +from mlcast.data.reconstruction import ReconstructionDataset +from mlcast.data.splits import ( + compute_split_ranges_from_splitting_ratios, + splitting_uses_fractions, + splitting_uses_tuple_ranges, + validate_splits, +) + + +class _BaseDataModule(pl.LightningDataModule): + """Shared split/build logic for task-level data modules. + + Parameters + ---------- + sequence_dataset_factory : Callable[..., Dataset] + Factory that builds source-data sequence datasets and accepts + ``subset`` and ``augment`` keyword arguments. + splits : dict of str to dict + Nested mapping describing train/validation/test coordinate splits. + **dataloader_kwargs : Any + Additional keyword arguments forwarded to ``DataLoader``. + """ + + def __init__( + self, + sequence_dataset_factory: Callable[..., Dataset], + splits: dict[str, dict[str, Any]], + **dataloader_kwargs: Any, + ) -> None: + super().__init__() + self.sequence_dataset_factory = sequence_dataset_factory + self.splits = splits + self.dataloader_kwargs = dataloader_kwargs + validate_splits(self.splits) + + def _build_sequence_dataset(self, subset: dict[str, Any], augment: bool) -> Dataset: + """Build a source-data sequence dataset for one split. + + Parameters + ---------- + subset : dict of str to Any + Coordinate subset passed to the sequence dataset factory. + augment : bool + Whether this split should apply data augmentation. + + Returns + ------- + Dataset + Built source-data sequence dataset. + """ + return self.sequence_dataset_factory(subset=subset, augment=augment) + + def _wrap_sequence_dataset(self, base_sequence_dataset: Dataset) -> Dataset: + """Wrap a sequence dataset into a task-specific dataset. + + Parameters + ---------- + base_sequence_dataset : Dataset + Source-data sequence dataset for a split. + + Returns + ------- + Dataset + Task-specific dataset for the split. + """ + raise NotImplementedError + + def setup(self, stage: str | None = None) -> None: + """Create train, validation, and test datasets. + + Parameters + ---------- + stage : str or None, optional + Lightning setup stage. Supports ``"fit"``, ``"validate"``, + ``"test"``, and ``None``. Default is ``None``. + + Raises + ------ + ValueError + If ``stage`` is unsupported. + NotImplementedError + If a configured split mode is unsupported. + """ + if stage == "fit": + requested_splits = {"train", "val"} + elif stage == "validate": + requested_splits = {"val"} + elif stage == "test": + requested_splits = {"test"} + elif stage is None: + requested_splits = {"train", "val", "test"} + else: + raise ValueError(f"Unsupported LightningDataModule setup stage: {stage!r}") + + subset_per_split: dict[str, dict[str, Any] | None] = { + split_name: ( + {} + if split_name in requested_splits + and any(split_name in coord_splits for coord_splits in self.splits.values()) + else None + ) + for split_name in ("train", "val", "test") + } + + for coord, coord_splits in self.splits.items(): + if splitting_uses_tuple_ranges(coord_splits): + coord_values_per_split: dict[str, tuple[str, str] | None] = { + "train": coord_splits["train"], + "val": coord_splits["val"], + "test": coord_splits.get("test"), + } + elif splitting_uses_fractions(coord_splits): + coord_values_per_split = compute_split_ranges_from_splitting_ratios( + self.sequence_dataset_factory, coord, coord_splits + ) + else: + raise NotImplementedError(f"Unsupported split mode for coordinate {coord!r}: {coord_splits!r}") + + for split_name, split_val in coord_values_per_split.items(): + if split_val is None: + subset_per_split[split_name] = None + elif subset_per_split[split_name] is not None: + subset_per_split[split_name][coord] = split_val + + augment_flags = {"train": True, "val": False, "test": False} + for split in ("train", "val", "test"): + subset = subset_per_split[split] + if subset is None: + setattr(self, f"{split}_dataset", None) + else: + base_sequence_dataset = self._build_sequence_dataset(subset=subset, augment=augment_flags[split]) + setattr(self, f"{split}_dataset", self._wrap_sequence_dataset(base_sequence_dataset)) + + logger.info("{}.setup() complete, containing:", self.__class__.__name__) + for split in ("train", "val", "test"): + dataset = getattr(self, f"{split}_dataset", None) + if dataset is not None: + logger.info( + " {:5s}: {:>6d} samples, subset={}", + split, + len(dataset), + subset_per_split[split], + ) + + def train_dataloader(self) -> DataLoader: + """Return the training DataLoader. + + Returns + ------- + DataLoader + Training dataloader with shuffled samples. + """ + return DataLoader(self.train_dataset, shuffle=True, **self.dataloader_kwargs) + + def val_dataloader(self) -> DataLoader: + """Return the validation DataLoader. + + Returns + ------- + DataLoader + Validation dataloader without shuffling. + """ + return DataLoader(self.val_dataset, shuffle=False, **self.dataloader_kwargs) + + def test_dataloader(self) -> DataLoader: + """Return the test DataLoader. + + Returns + ------- + DataLoader + Test dataloader without shuffling. + """ + return DataLoader(self.test_dataset, shuffle=False, **self.dataloader_kwargs) + + +class ForecastingDataModule(_BaseDataModule): + """Lightning data module for forecasting datasets. + + Parameters + ---------- + sequence_dataset_factory : Callable[..., Dataset] + Factory that builds source-data sequence datasets. + input_steps : int + Number of input timesteps in each forecasting sample. + forecast_steps : int + Number of target timesteps in each forecasting sample. + return_mask : bool + Whether forecasting samples should include ``target_mask``. + splits : dict of str to dict + Nested mapping describing train/validation/test coordinate splits. + **dataloader_kwargs : Any + Additional keyword arguments forwarded to ``DataLoader``. + """ + + def __init__( + self, + sequence_dataset_factory: Callable[..., Dataset], + input_steps: int, + forecast_steps: int, + return_mask: bool, + splits: dict[str, dict[str, Any]], + **dataloader_kwargs: Any, + ) -> None: + super().__init__(sequence_dataset_factory=sequence_dataset_factory, splits=splits, **dataloader_kwargs) + self.input_steps = input_steps + self.forecast_steps = forecast_steps + self.return_mask = return_mask + + def _wrap_sequence_dataset(self, base_sequence_dataset: Dataset) -> Dataset: + """Wrap a sequence dataset as a forecasting dataset. + + Parameters + ---------- + base_sequence_dataset : Dataset + Sequence dataset for one split. + + Returns + ------- + Dataset + Forecasting dataset for the split. + """ + return ForecastingDataset( + base_sequence_dataset=base_sequence_dataset, + input_steps=self.input_steps, + forecast_steps=self.forecast_steps, + return_mask=self.return_mask, + ) + + +class ReconstructionDataModule(_BaseDataModule): + """Lightning data module for reconstruction datasets. + + Parameters + ---------- + sequence_dataset_factory : Callable[..., Dataset] + Factory that builds source-data sequence datasets. + input_steps : int + Number of timesteps in each reconstruction window. + splits : dict of str to dict + Nested mapping describing train/validation/test coordinate splits. + **dataloader_kwargs : Any + Additional keyword arguments forwarded to ``DataLoader``. + """ + + def __init__( + self, + sequence_dataset_factory: Callable[..., Dataset], + input_steps: int, + splits: dict[str, dict[str, Any]], + **dataloader_kwargs: Any, + ) -> None: + super().__init__(sequence_dataset_factory=sequence_dataset_factory, splits=splits, **dataloader_kwargs) + self.input_steps = input_steps + + def _wrap_sequence_dataset(self, base_sequence_dataset: Dataset) -> Dataset: + """Wrap a sequence dataset as a reconstruction dataset. + + Parameters + ---------- + base_sequence_dataset : Dataset + Sequence dataset for one split. + + Returns + ------- + Dataset + Reconstruction dataset for the split. + """ + return ReconstructionDataset(base_sequence_dataset=base_sequence_dataset, input_steps=self.input_steps) diff --git a/src/mlcast/data/forecasting.py b/src/mlcast/data/forecasting.py new file mode 100644 index 0000000..7cb2e7a --- /dev/null +++ b/src/mlcast/data/forecasting.py @@ -0,0 +1,111 @@ +"""Forecasting dataset wrappers built on top of sequence datasets.""" + +from typing import TypedDict + +import torch +from beartype import beartype +from jaxtyping import Float, jaxtyped +from torch import Tensor +from torch.utils.data import Dataset + + +class ForecastingSample(TypedDict, total=False): + """Typed dictionary returned by :class:`ForecastingDataset`. + + Keys + ---- + input : Float[Tensor, "input_steps channels height width"] + Past frames fed to the forecasting model. + target : Float[Tensor, "forecast_steps channels height width"] + Future frames the forecasting model should predict. + target_mask : Float[Tensor, "forecast_steps channels height width"] + Per-timestep, per-channel validity mask for the target when + ``return_mask=True``. + """ + + input: Float[Tensor, "input_steps channels height width"] + target: Float[Tensor, "forecast_steps channels height width"] + target_mask: Float[Tensor, "forecast_steps channels height width"] + + +class ForecastingDataset(Dataset): + """Wrap a sequence dataset to produce forecasting samples. + + Parameters + ---------- + base_sequence_dataset : Dataset + Dataset returning normalized sequence tensors of shape + ``(sequence_steps, channels, height, width)``. + input_steps : int + Number of past timesteps fed to the forecasting model. + forecast_steps : int + Number of future timesteps the forecasting model should predict. + return_mask : bool, optional + Whether to derive and return a target validity mask. Default is + ``False``. + """ + + def __init__( + self, + base_sequence_dataset: Dataset, + input_steps: int, + forecast_steps: int, + return_mask: bool = False, + ) -> None: + if input_steps < 1: + raise ValueError(f"input_steps ({input_steps}) must be at least 1.") + if forecast_steps < 1: + raise ValueError(f"forecast_steps ({forecast_steps}) must be at least 1.") + + self.base_sequence_dataset = base_sequence_dataset + self.input_steps = input_steps + self.forecast_steps = forecast_steps + self.return_mask = return_mask + + sequence_steps = getattr(base_sequence_dataset, "sequence_steps", None) + if sequence_steps is None: + raise AttributeError("base_sequence_dataset must expose a 'sequence_steps' attribute.") + if input_steps + forecast_steps != sequence_steps: + raise ValueError( + "ForecastingDataset requires input_steps + forecast_steps to equal sequence_steps; " + f"got input_steps={input_steps}, forecast_steps={forecast_steps}, sequence_steps={sequence_steps}." + ) + + def __len__(self) -> int: + """Return the number of available forecasting samples. + + Returns + ------- + int + Number of samples in the wrapped sequence dataset. + """ + return len(self.base_sequence_dataset) + + @jaxtyped(typechecker=beartype) + def __getitem__(self, idx: int) -> ForecastingSample: + """Return one forecasting sample derived from the wrapped sequence. + + Parameters + ---------- + idx : int + Index of the wrapped sequence sample. + + Returns + ------- + ForecastingSample + Dictionary containing ``input`` and ``target`` tensors, and + ``target_mask`` when ``return_mask=True``. + """ + sequence = self.base_sequence_dataset[idx] + + if self.return_mask: + target_mask_t = (~torch.isnan(sequence[self.input_steps :])).to(dtype=torch.float32) + + sequence = torch.nan_to_num(sequence, nan=-1.0).to(dtype=torch.float32) + input_t = sequence[: self.input_steps] + target_t = sequence[self.input_steps :] + + sample = ForecastingSample(input=input_t, target=target_t) + if self.return_mask: + sample["target_mask"] = target_mask_t + return sample diff --git a/src/mlcast/data/reconstruction.py b/src/mlcast/data/reconstruction.py new file mode 100644 index 0000000..6327262 --- /dev/null +++ b/src/mlcast/data/reconstruction.py @@ -0,0 +1,73 @@ +"""Reconstruction datasets built from sequence datasets. + +The reconstruction task reuses normalized source-data sequences and exposes all +overlapping temporal windows of length ``input_steps`` as individual training +samples. +""" + +import torch +from jaxtyping import Float +from torch import Tensor +from torch.utils.data import Dataset + + +class ReconstructionDataset(Dataset): + """Wrap a sequence dataset for stage-1 reconstruction training. + + Parameters + ---------- + base_sequence_dataset : Dataset + Dataset whose samples are normalized sequence tensors with shape + ``(sequence_steps, channels, height, width)``. + input_steps : int + Temporal window length to expose for each reconstruction sample. + + Notes + ----- + Each base sequence contributes all overlapping windows of length + ``input_steps``. The reconstruction training module is responsible for + reusing the returned tensor as both the model input and the reconstruction + target. + """ + + def __init__(self, base_sequence_dataset: Dataset, input_steps: int) -> None: + if input_steps < 1: + raise ValueError(f"input_steps ({input_steps}) must be at least 1.") + + self.base_sequence_dataset = base_sequence_dataset + self.input_steps = input_steps + + sequence_steps = getattr(base_sequence_dataset, "sequence_steps", None) + if sequence_steps is None: + raise AttributeError("base_sequence_dataset must expose a 'sequence_steps' attribute.") + if input_steps > sequence_steps: + raise ValueError( + "ReconstructionDataset requires input_steps to be less than or equal to sequence_steps; " + f"got input_steps={input_steps}, sequence_steps={sequence_steps}." + ) + + self.sequence_steps = sequence_steps + self.windows_per_sequence = self.sequence_steps - self.input_steps + 1 + + def __len__(self) -> int: + """Return the number of available reconstruction windows.""" + return len(self.base_sequence_dataset) * self.windows_per_sequence + + def __getitem__(self, idx: int) -> Float[Tensor, "input_steps channels height width"]: + """Return one overlapping reconstruction window. + + Parameters + ---------- + idx : int + Flat reconstruction-sample index. + + Returns + ------- + Float[Tensor, "input_steps channels height width"] + Window extracted from the wrapped sequence sample. + """ + sequence_idx = idx // self.windows_per_sequence + window_start = idx % self.windows_per_sequence + sequence = self.base_sequence_dataset[sequence_idx] + window = sequence[window_start : window_start + self.input_steps] + return torch.nan_to_num(window, nan=-1.0).to(dtype=torch.float32) diff --git a/src/mlcast/data/source_data_datasets.py b/src/mlcast/data/sequence.py similarity index 51% rename from src/mlcast/data/source_data_datasets.py rename to src/mlcast/data/sequence.py index b8e3c80..391a23c 100644 --- a/src/mlcast/data/source_data_datasets.py +++ b/src/mlcast/data/sequence.py @@ -1,12 +1,14 @@ -"""PyTorch datasets for loading spatio-temporal data from Zarr stores. +"""Source-data sequence datasets built from Zarr stores. -Provides pre-computed sampling and (soon) random sampling datasets. +These datasets are responsible for sampling normalized spatio-temporal +sequences directly from source datasets. They do not impose any forecasting or +reconstruction task structure on the sampled sequence. """ import time import warnings from abc import ABC, abstractmethod -from typing import Any, TypedDict +from typing import Any import cf_xarray # noqa: F401 import numpy as np @@ -15,6 +17,7 @@ import xarray as xr from beartype import beartype from jaxtyping import Float, jaxtyped +from torch import Tensor from torch.utils.data import Dataset from mlcast.data.normalization import NORMALIZATION_REGISTRY @@ -25,7 +28,22 @@ def _time_range_to_index_slice( time_range: tuple[str, str], storage_options: dict[str, Any] | None = None, ) -> slice: - """Convert an inclusive ISO time range into a zarr integer slice.""" + """Convert an inclusive ISO time range into a zarr integer slice. + + Parameters + ---------- + zarr_path : str + Path to the Zarr dataset. + time_range : tuple of str + Inclusive ``(start, end)`` ISO 8601 time range. + storage_options : dict or None, optional + Options forwarded to ``xr.open_zarr``. Default is ``None``. + + Returns + ------- + slice + Integer slice covering the requested time range. + """ ds = xr.open_zarr(zarr_path, storage_options=storage_options) time_values = ds.indexes["time"] t_start = time_values.get_indexer([pd.Timestamp(time_range[0])], method="bfill")[0] @@ -38,46 +56,20 @@ def _time_range_to_index_slice( return slice(int(t_start), int(t_end) + 1) -class DatasetSample(TypedDict, total=False): - """Typed dictionary returned by dataset ``__getitem__``. - - Keys - ---- - input : Float[torch.Tensor, "input_steps channels height width"] - Past frames fed to the network as input. - target : Float[torch.Tensor, "forecast_steps channels height width"] - Future frames the network should predict. - target_mask : Float[torch.Tensor, "forecast_steps channels height width"] - Per-timestep, per-channel validity mask for the target (1 = valid, - 0 = NaN in original data). Only present when ``return_mask=True``. - """ - - input: Float[torch.Tensor, "input_steps channels height width"] - target: Float[torch.Tensor, "forecast_steps channels height width"] - target_mask: Float[torch.Tensor, "forecast_steps channels height width"] - - def _detect_axes(ds: xr.Dataset, standard_name: str) -> tuple[str, str, str]: """Detect CF axis dimension names for a variable in an xarray Dataset. - Falls back to dimension names ``'y'`` / ``'x'`` when CF conventions do not - identify the axis, emitting a :mod:`warnings` warning in each case. - Parameters ---------- ds : xr.Dataset - An open xarray Dataset with CF conventions. + Open xarray Dataset with CF metadata. standard_name : str - A CF standard name present in ``ds``, used to look up the variable. + CF standard name of the variable used to infer axes. Returns ------- - t_dim : str - Dimension name for the time axis. - y_dim : str - Dimension name for the Y (latitude) axis. - x_dim : str - Dimension name for the X (longitude) axis. + tuple of str + Names of the time, Y, and X dimensions. """ da = ds.cf[standard_name] t_dim = da.cf["time"].dims[0] @@ -103,12 +95,8 @@ def _detect_axes(ds: xr.Dataset, standard_name: str) -> tuple[str, str, str]: return t_dim, y_dim, x_dim -class SourceDataDatasetBase(Dataset, ABC): - """Abstract base class for mlcast Zarr-backed spatio-temporal datasets. - - Subclasses must implement :meth:`__len__` and :meth:`__getitem__`. - All common initialisation, Zarr access, CF-axis resolution, augmentation, - and the ``steps`` property live here. +class SourceDataSequenceDatasetBase(Dataset, ABC): + """Abstract base class for source-data-backed sequence datasets. Parameters ---------- @@ -116,13 +104,8 @@ class SourceDataDatasetBase(Dataset, ABC): Path to the Zarr dataset. standard_names : list of str List of CF standard names of variables to load. - input_steps : int - Number of past timesteps fed to the network as input. - forecast_steps : int - Number of future timesteps the network should predict. - return_mask : bool, optional - If ``True``, also return a per-timestep validity mask for the target. - Default is ``False``. + sequence_steps : int + Number of timesteps to include in each sampled sequence. deterministic : bool, optional If ``True``, use a fixed random seed (42). Default is ``False``. augment : bool, optional @@ -139,27 +122,21 @@ def __init__( self, zarr_path: str, standard_names: list[str], - input_steps: int, - forecast_steps: int, - return_mask: bool = False, + sequence_steps: int, deterministic: bool = False, augment: bool = False, width: int = 256, height: int = 256, storage_options: dict[str, Any] | None = None, ) -> None: - if input_steps < 1: - raise ValueError(f"input_steps ({input_steps}) must be at least 1.") - if forecast_steps < 1: - raise ValueError(f"forecast_steps ({forecast_steps}) must be at least 1.") + if sequence_steps < 1: + raise ValueError(f"sequence_steps ({sequence_steps}) must be at least 1.") self.storage_options = storage_options self._zarr_path = zarr_path self._ds: xr.Dataset | None = None self.standard_names = standard_names - self.input_steps = input_steps - self.forecast_steps = forecast_steps - self.return_mask = return_mask + self.sequence_steps = sequence_steps self.augment = augment self.w = width self.h = height @@ -168,34 +145,14 @@ def __init__( self._validate_standard_names() self.t_dim, self.y_dim, self.x_dim = _detect_axes(self.ds, self.standard_names[0]) - # ------------------------------------------------------------------ - # Properties - # ------------------------------------------------------------------ - - @property - def steps(self) -> int: - """Total number of timesteps per sample (``input_steps + forecast_steps``). - - Returns - ------- - steps : int - ``input_steps + forecast_steps``. - """ - return self.input_steps + self.forecast_steps - @property def ds(self) -> xr.Dataset: """Open and cache the Zarr-backed xarray Dataset for this worker. - The store is opened lazily on first access within each process. This - avoids pickling live asyncio connections across DataLoader worker - boundaries, which would cause ``RuntimeError: Future attached to a - different loop``. - Returns ------- - ds : xr.Dataset - The opened (and optionally time-sliced) xarray Dataset. + xr.Dataset + Opened dataset, optionally subset in time for this worker process. """ if self._ds is None: ds = xr.open_zarr(self._zarr_path, storage_options=self.storage_options) @@ -204,17 +161,13 @@ def ds(self) -> xr.Dataset: self._ds = ds return self._ds - # ------------------------------------------------------------------ - # Helpers - # ------------------------------------------------------------------ - def _validate_standard_names(self) -> None: """Check that every requested CF standard name exists in the Zarr store. Raises ------ ValueError - If a requested standard name is not found. + If any requested standard name is missing from the dataset. """ for std_name in self.standard_names: try: @@ -241,112 +194,100 @@ def _validate_standard_names(self) -> None: raise ValueError(msg) from e def _apply_augmentations( - self, *tensors: torch.Tensor, rotate_prob: float = 0.5, hflip_prob: float = 0.5, vflip_prob: float = 0.5 - ) -> tuple[torch.Tensor, ...]: - """Apply random spatial augmentations consistently to all input tensors.""" + self, + tensor: Float[Tensor, "sequence_steps channels height width"], + rotate_prob: float = 0.5, + hflip_prob: float = 0.5, + vflip_prob: float = 0.5, + ) -> Float[Tensor, "sequence_steps channels height width"]: + """Apply random spatial augmentations to a sequence tensor. + + Parameters + ---------- + tensor : Float[Tensor, "sequence_steps channels height width"] + Sequence tensor to augment. + rotate_prob : float, optional + Probability of applying a random 90-degree rotation. Default is + ``0.5``. + hflip_prob : float, optional + Probability of applying a horizontal flip. Default is ``0.5``. + vflip_prob : float, optional + Probability of applying a vertical flip. Default is ``0.5``. + + Returns + ------- + Float[Tensor, "sequence_steps channels height width"] + Augmented contiguous tensor. + """ if self.rng.random() < rotate_prob: k = self.rng.integers(1, 4) - tensors = tuple(torch.rot90(t, int(k), dims=[-2, -1]) for t in tensors) + tensor = torch.rot90(tensor, int(k), dims=[-2, -1]) if self.rng.random() < hflip_prob: - tensors = tuple(torch.flip(t, dims=[-1]) for t in tensors) + tensor = torch.flip(tensor, dims=[-1]) if self.rng.random() < vflip_prob: - tensors = tuple(torch.flip(t, dims=[-2]) for t in tensors) - - return tuple(t.contiguous() for t in tensors) + tensor = torch.flip(tensor, dims=[-2]) - def _build_sample(self, data: np.ndarray) -> DatasetSample: - """Convert a raw ``(T, C, H, W)`` numpy array into a :class:`DatasetSample`. + return tensor.contiguous() - Computes the target mask (before ``nan_to_num``), splits into input / - target tensors along the time axis, applies augmentations if requested, - and assembles the final dict. + def _build_sequence(self, data: np.ndarray) -> Float[Tensor, "sequence_steps channels height width"]: + """Convert a raw ``(T, C, H, W)`` numpy array into a tensor. Parameters ---------- data : np.ndarray - Raw normalised array of shape ``(steps, C, H, W)`` — may contain - NaNs where the original data was invalid. + Normalized array with shape ``(sequence_steps, channels, height, + width)``. Returns ------- - sample : DatasetSample - Dictionary with ``'input'`` and ``'target'`` tensors, and - optionally ``'target_mask'`` if ``self.return_mask`` is ``True``. + Float[Tensor, "sequence_steps channels height width"] + Float32 sequence tensor, augmented if requested. """ - # Capture target mask before NaNs are filled - if self.return_mask: - target_mask_t = torch.from_numpy((~np.isnan(data[self.input_steps :])).astype(np.float32)) - - # source data may be float64, but the model and the rest of the - # training pipeline operate in float32. - data = np.nan_to_num(data, nan=-1.0).astype(np.float32) - data_t = torch.from_numpy(data) - - input_t = data_t[: self.input_steps] - target_t = data_t[self.input_steps :] - + data = np.ascontiguousarray(data, dtype=np.float32) + sequence_t = torch.from_numpy(data) if self.augment: - tensors = (input_t, target_t, target_mask_t) if self.return_mask else (input_t, target_t) - augmented = self._apply_augmentations(*tensors) - if self.return_mask: - input_t, target_t, target_mask_t = augmented - else: - input_t, target_t = augmented - - sample = DatasetSample(input=input_t, target=target_t) - if self.return_mask: - sample["target_mask"] = target_mask_t - return sample - - # ------------------------------------------------------------------ - # Abstract interface - # ------------------------------------------------------------------ + sequence_t = self._apply_augmentations(sequence_t) + return sequence_t @abstractmethod def __len__(self) -> int: ... @abstractmethod - def __getitem__(self, idx: int) -> DatasetSample: ... - + def __getitem__(self, idx: int) -> Float[Tensor, "sequence_steps channels height width"]: ... -class SourceDataPrecomputedSamplingDataset(SourceDataDatasetBase): - """PyTorch dataset that loads spatio-temporal data from a Zarr store using - pre-sampled spatial-temporal coordinates from a CSV file. - Each sample is a spatio-temporal crop of shape ``(T, C, H, W)`` - converted to normalized data. +class SourceDataPrecomputedSequenceDataset(SourceDataSequenceDatasetBase): + """Sequence dataset using pre-sampled spatial-temporal coordinates from CSV. Parameters ---------- zarr_path : str Path to the Zarr dataset. csv_path : str - Path to the CSV file with columns ``(t, x, y)`` specifying the - top-left corner of each crop. + Path to the CSV file with ``t``, ``x``, and ``y`` crop coordinates. standard_names : list of str - List of CF standard names of variables to load (e.g., ``["rainfall_rate"]``). - input_steps : int - Number of past timesteps fed to the network as input. - forecast_steps : int - Number of future timesteps the network should predict. - return_mask : bool, optional - If ``True``, also return a per-timestep validity mask for the target. - Default is ``False``. + CF standard names of variables to load. + sequence_steps : int + Number of timesteps to include in each sampled sequence. deterministic : bool, optional - If ``True``, use a fixed random seed (42) for reproducibility. Default is ``False``. + If ``True``, use deterministic random sampling within precomputed time + windows. Default is ``False``. augment : bool, optional - If ``True``, apply random spatial augmentations (rotation, flips). Default is ``False``. + If ``True``, apply random spatial augmentations. Default is ``False``. subset : dict or None, optional Coordinate subsetting specification. Only ``{"time": (start, end)}`` - is supported, where the time range is inclusive and uses ISO strings. + is supported. Default is ``None``. width : int, optional Spatial width of each crop. Default is ``256``. height : int, optional Spatial height of each crop. Default is ``256``. time_depth : int, optional - Number of timesteps in the sampled window. Default is ``24``. + Number of timesteps in each precomputed sampled window. Default is + ``24``. + storage_options : dict or None, optional + Options forwarded to ``xr.open_zarr``. Default is ``None``. """ def __init__( @@ -354,9 +295,7 @@ def __init__( zarr_path: str, csv_path: str, standard_names: list[str], - input_steps: int, - forecast_steps: int, - return_mask: bool = False, + sequence_steps: int, deterministic: bool = False, augment: bool = False, subset: dict[str, Any] | None = None, @@ -379,9 +318,7 @@ def __init__( super().__init__( zarr_path=zarr_path, standard_names=standard_names, - input_steps=input_steps, - forecast_steps=forecast_steps, - return_mask=return_mask, + sequence_steps=sequence_steps, deterministic=deterministic, augment=augment, width=width, @@ -399,48 +336,45 @@ def __init__( self.dt = time_depth - if self.steps > self.dt: - print(f"Warning: requested steps ({self.steps}) > sampled time window ({self.dt})") + if self.sequence_steps > self.dt: + print(f"Warning: requested sequence_steps ({self.sequence_steps}) > sampled time window ({self.dt})") - # Close the store: metadata has been extracted into plain attributes above. - # Each DataLoader worker will reopen it via the `ds` property in its own - # event loop, avoiding asyncio "Future attached to a different loop" errors. self._ds = None def __len__(self) -> int: - """Get the number of samples in the dataset. + """Get the number of precomputed crop coordinates. Returns ------- - length : int - Number of samples. + int + Number of available sequence samples. """ return len(self.coords) @jaxtyped(typechecker=beartype) - def __getitem__(self, idx: int) -> DatasetSample: - """Load and return a single crop sample. + def __getitem__(self, idx: int) -> Float[Tensor, "sequence_steps channels height width"]: + """Load and return a single normalized sequence tensor. + + Parameters + ---------- + idx : int + Index of the precomputed crop coordinate. Returns ------- - sample : DatasetSample - Dictionary with keys ``'input'`` of shape - ``(input_steps, C, H, W)`` and ``'target'`` of shape - ``(forecast_steps, C, H, W)``. If ``return_mask`` is ``True``, - also contains ``'target_mask'`` of shape - ``(forecast_steps, C, H, W)`` with 1 where the original data was - valid and 0 where it was NaN. + Float[Tensor, "sequence_steps channels height width"] + Normalized sequence tensor sampled from the source dataset. """ t0, x0, y0 = self.coords.iloc[idx] x_slice = slice(int(x0), int(x0) + self.w) y_slice = slice(int(y0), int(y0) + self.h) - if self.steps < self.dt: - t_start = self.rng.integers(t0, t0 + self.dt - self.steps + 1) + if self.sequence_steps < self.dt: + t_start = self.rng.integers(t0, t0 + self.dt - self.sequence_steps + 1) else: t_start = t0 - t_slice = slice(int(t_start), int(t_start) + self.steps) + t_slice = slice(int(t_start), int(t_start) + self.sequence_steps) channels = [] for std_name in self.standard_names: @@ -448,56 +382,45 @@ def __getitem__(self, idx: int) -> DatasetSample: norm_func = NORMALIZATION_REGISTRY[std_name] channels.append(norm_func(da_var.values)) - # swapaxes returns a view; make it contiguous and float32 before - # handing it to _build_sample()/torch.from_numpy(). - data = np.ascontiguousarray(np.swapaxes(np.stack(channels, axis=0), 0, 1), dtype=np.float32) - return self._build_sample(data) - + data = np.swapaxes(np.stack(channels, axis=0), 0, 1) + return self._build_sequence(data) -class SourceDataRandomSamplingDataset(SourceDataDatasetBase): - """PyTorch dataset that performs on-the-fly random spatial and temporal - slicing of a Zarr store spatio-temporal data array. - Each sample is a spatio-temporal crop of shape ``(T, C, H, W)`` - converted to normalized data. +class SourceDataRandomSequenceDataset(SourceDataSequenceDatasetBase): + """Sequence dataset with on-the-fly random spatial and temporal sampling. Parameters ---------- zarr_path : str Path to the Zarr dataset. standard_names : list of str - List of CF standard names of variables to load (e.g., ``["rainfall_rate"]``). - input_steps : int - Number of past timesteps fed to the network as input. - forecast_steps : int - Number of future timesteps the network should predict. - return_mask : bool, optional - If ``True``, also return a per-timestep validity mask for the target. - Default is ``False``. + CF standard names of variables to load. + sequence_steps : int + Number of timesteps to include in each sampled sequence. deterministic : bool, optional - If ``True``, use a fixed random seed (42) for reproducibility. Default is ``False``. + If ``True``, use deterministic random sampling. Default is ``False``. augment : bool, optional - If ``True``, apply random spatial augmentations (rotation, flips). Default is ``False``. + If ``True``, apply random spatial augmentations. Default is ``False``. subset : dict or None, optional Coordinate subsetting specification. Only ``{"time": (start, end)}`` - is supported, where the time range is inclusive and uses ISO strings. + is supported. Default is ``None``. width : int, optional Spatial width of each crop. Default is ``256``. height : int, optional Spatial height of each crop. Default is ``256``. epoch_size : int, optional - Number of random samples to generate per epoch. Default is ``1000``. + Number of random samples exposed per epoch. Default is ``1000``. + storage_options : dict or None, optional + Options forwarded to ``xr.open_zarr``. Default is ``None``. **kwargs : Any - Ignored extra arguments (e.g. ``csv_path``) to allow drop-in replacement. + Ignored extra arguments to allow partial config reuse. """ def __init__( self, zarr_path: str, standard_names: list[str], - input_steps: int, - forecast_steps: int, - return_mask: bool = False, + sequence_steps: int, deterministic: bool = False, augment: bool = False, subset: dict[str, Any] | None = None, @@ -521,9 +444,7 @@ def __init__( super().__init__( zarr_path=zarr_path, standard_names=standard_names, - input_steps=input_steps, - forecast_steps=forecast_steps, - return_mask=return_mask, + sequence_steps=sequence_steps, deterministic=deterministic, augment=augment, width=width, @@ -538,47 +459,46 @@ def __init__( self.max_y = da_first_var.sizes[self.y_dim] self.max_x = da_first_var.sizes[self.x_dim] - if self.steps > self.max_t: - raise ValueError(f"Requested steps ({self.steps}) > available time dimension ({self.max_t})") + if self.sequence_steps > self.max_t: + raise ValueError( + f"Requested sequence_steps ({self.sequence_steps}) > available time dimension ({self.max_t})" + ) if self.h > self.max_y: raise ValueError(f"Requested height ({self.h}) > available Y dimension ({self.max_y})") if self.w > self.max_x: raise ValueError(f"Requested width ({self.w}) > available X dimension ({self.max_x})") - # Close the store: metadata has been extracted into plain attributes above. - # Each DataLoader worker will reopen it via the `ds` property in its own - # event loop, avoiding asyncio "Future attached to a different loop" errors. self._ds = None def __len__(self) -> int: - """Get the number of samples in the dataset. + """Get the configured random epoch size. Returns ------- - length : int - Number of samples. + int + Number of random sequence samples exposed per epoch. """ return self.epoch_size @jaxtyped(typechecker=beartype) - def __getitem__(self, idx: int) -> DatasetSample: - """Load and return a single randomly sampled datacube. + def __getitem__(self, idx: int) -> Float[Tensor, "sequence_steps channels height width"]: + """Load and return a single randomly sampled normalized sequence. + + Parameters + ---------- + idx : int + Ignored sample index; each call draws a random crop. Returns ------- - sample : DatasetSample - Dictionary with keys ``'input'`` of shape - ``(input_steps, C, H, W)`` and ``'target'`` of shape - ``(forecast_steps, C, H, W)``. If ``return_mask`` is ``True``, - also contains ``'target_mask'`` of shape - ``(forecast_steps, C, H, W)`` with 1 where the original data was - valid and 0 where it was NaN. + Float[Tensor, "sequence_steps channels height width"] + Normalized sequence tensor sampled from the source dataset. """ - t_start = self.rng.integers(0, self.max_t - self.steps + 1) + t_start = self.rng.integers(0, self.max_t - self.sequence_steps + 1) y_start = self.rng.integers(0, self.max_y - self.h + 1) x_start = self.rng.integers(0, self.max_x - self.w + 1) - t_slice = slice(int(t_start), int(t_start) + self.steps) + t_slice = slice(int(t_start), int(t_start) + self.sequence_steps) y_slice = slice(int(y_start), int(y_start) + self.h) x_slice = slice(int(x_start), int(x_start) + self.w) @@ -588,7 +508,5 @@ def __getitem__(self, idx: int) -> DatasetSample: norm_func = NORMALIZATION_REGISTRY[std_name] channels.append(norm_func(da_var.values)) - # swapaxes returns a view; make it contiguous and float32 before - # handing it to _build_sample()/torch.from_numpy(). - data = np.ascontiguousarray(np.swapaxes(np.stack(channels, axis=0), 0, 1), dtype=np.float32) - return self._build_sample(data) + data = np.swapaxes(np.stack(channels, axis=0), 0, 1) + return self._build_sequence(data) diff --git a/src/mlcast/data/source_data_datamodule.py b/src/mlcast/data/source_data_datamodule.py deleted file mode 100644 index 44f99bd..0000000 --- a/src/mlcast/data/source_data_datamodule.py +++ /dev/null @@ -1,161 +0,0 @@ -"""PyTorch Lightning data module for spatio-temporal datasets. - -Handles train/val/test splitting and DataLoader creation from an injected -dataset factory. -""" - -from collections.abc import Callable -from typing import Any - -import pytorch_lightning as pl -from loguru import logger -from torch.utils.data import DataLoader, Dataset - -from mlcast.data.splits import ( - compute_split_ranges_from_splitting_ratios, - splitting_uses_fractions, - splitting_uses_tuple_ranges, - validate_splits, -) - - -class SourceDataDataModule(pl.LightningDataModule): - """PyTorch Lightning data module for spatio-temporal datasets. - - Handles train/val/test splitting and DataLoader creation by utilizing - an injected ``dataset_factory``. - - Parameters - ---------- - dataset_factory : Callable[..., Dataset] - A factory function (e.g., ``fdl.Partial``) that returns a Dataset instance. - It must accept ``subset`` and ``augment`` as keyword arguments. - splits : dict of {str: dict} - Nested mapping ``{coord: {split_name: value, ...}, ...}`` describing - train/val/test subsets. Currently only the ``time`` coordinate is - supported. Ratio mode uses float fractions, while datetime mode uses - inclusive ``(start, end)`` ISO 8601 string tuples. - **dataloader_kwargs : Any - Additional keyword arguments forwarded to ``DataLoader`` (e.g., - ``batch_size``, ``num_workers``, ``pin_memory``). - """ - - def __init__( - self, - dataset_factory: Callable[..., Dataset], - splits: dict[str, dict[str, Any]], - **dataloader_kwargs: Any, - ) -> None: - super().__init__() - self.dataset_factory = dataset_factory - self.splits = splits - self.dataloader_kwargs = dataloader_kwargs - validate_splits(self.splits) - - def setup(self, stage: str | None = None) -> None: - """Create train, validation, and test datasets. - - Splits are assembled into per-dataset ``subset`` dictionaries. - Datetime-mode splits are passed through unchanged, while ratio-mode - splits are first resolved against the zarr coordinate values and then - converted to inclusive coordinate ranges before dataset instantiation. - Dataset construction depends on the requested Lightning stage: - - - ``"fit"`` builds train and validation datasets; - - ``"validate"`` builds only the validation dataset; - - ``"test"`` builds only the test dataset; and - - ``None`` builds all configured datasets. - - Parameters - ---------- - stage : str | None, optional - Lightning stage hint controlling which datasets are constructed. - - Raises - ------ - ValueError - If ``stage`` is not one of ``None``, ``"fit"``, ``"validate"``, - or ``"test"``. - """ - if stage == "fit": - requested_splits = {"train", "val"} - elif stage == "validate": - requested_splits = {"val"} - elif stage == "test": - requested_splits = {"test"} - elif stage is None: - requested_splits = {"train", "val", "test"} - else: - raise ValueError(f"Unsupported LightningDataModule setup stage: {stage!r}") - - subset_per_split: dict[str, dict[str, Any] | None] = { - split_name: ( - {} - if split_name in requested_splits - and any(split_name in coord_splits for coord_splits in self.splits.values()) - else None - ) - for split_name in ("train", "val", "test") - } - - for coord, coord_splits in self.splits.items(): - if splitting_uses_tuple_ranges(coord_splits): - # tuple-based splits are expected to present the start and end - # of each split, and so are passed through directly as the - # subset values for each split - coord_values_per_split: dict[str, tuple[str, str] | None] = { - "train": coord_splits["train"], - "val": coord_splits["val"], - "test": coord_splits.get("test"), - } - elif splitting_uses_fractions(coord_splits): - # for ratio-based splits, the splitting start-end range tuples - # are constructed by breaking up the given coordinate in - # successive segments (the succession is defined from the order - # of the keys in the splits dict) - coord_values_per_split = compute_split_ranges_from_splitting_ratios( - self.dataset_factory, coord, coord_splits - ) - else: - raise NotImplementedError(f"Unsupported split mode for coordinate {coord!r}: {coord_splits!r}") - - for split_name, split_val in coord_values_per_split.items(): - if split_val is None: - subset_per_split[split_name] = None - elif subset_per_split[split_name] is not None: - subset_per_split[split_name][coord] = split_val - - augment_flags = {"train": True, "val": False, "test": False} - for split in ("train", "val", "test"): - subset = subset_per_split[split] - if subset is None: - setattr(self, f"{split}_dataset", None) - else: - setattr( - self, - f"{split}_dataset", - self.dataset_factory(subset=subset, augment=augment_flags[split]), - ) - - logger.info("{}.setup() complete, containing:", self.__class__.__name__) - for split in ("train", "val", "test"): - dataset = getattr(self, f"{split}_dataset", None) - if dataset is not None: - logger.info( - " {:5s}: {:>6d} samples, subset={}", - split, - len(dataset), - subset_per_split[split], - ) - - def train_dataloader(self) -> DataLoader: - """Return the training DataLoader.""" - return DataLoader(self.train_dataset, shuffle=True, **self.dataloader_kwargs) - - def val_dataloader(self) -> DataLoader: - """Return the validation DataLoader.""" - return DataLoader(self.val_dataset, shuffle=False, **self.dataloader_kwargs) - - def test_dataloader(self) -> DataLoader: - """Return the test DataLoader.""" - return DataLoader(self.test_dataset, shuffle=False, **self.dataloader_kwargs) diff --git a/src/mlcast/models/autoencoder/__init__.py b/src/mlcast/models/autoencoder/__init__.py new file mode 100644 index 0000000..fec971a --- /dev/null +++ b/src/mlcast/models/autoencoder/__init__.py @@ -0,0 +1,7 @@ +"""Autoencoder architecture components for reconstruction pretraining.""" + +from .decoder import Decoder, DecoderBlock +from .encoder import Encoder, EncoderBlock +from .net import AutoencoderNet + +__all__ = ["AutoencoderNet", "Decoder", "DecoderBlock", "Encoder", "EncoderBlock"] diff --git a/src/mlcast/models/autoencoder/decoder.py b/src/mlcast/models/autoencoder/decoder.py new file mode 100644 index 0000000..ea3b8a1 --- /dev/null +++ b/src/mlcast/models/autoencoder/decoder.py @@ -0,0 +1,120 @@ +"""Decoder blocks for the reconstruction autoencoder.""" + +import torch +import torch.nn as nn +from beartype import beartype +from jaxtyping import Float, jaxtyped + + +class DecoderBlock(nn.Module): + """Spatio-temporal decoder block with optional spatial upsampling. + + Parameters + ---------- + in_channels : int + Number of channels in the input tensor. + out_channels : int + Number of channels produced by the block. + upsample : bool, optional + If ``True``, double the spatial resolution with a transposed + convolution. Default is ``True``. + """ + + def __init__(self, in_channels: int, out_channels: int, upsample: bool = True) -> None: + super().__init__() + spatial_stride = 2 if upsample else 1 + output_padding = (0, 1, 1) if upsample else 0 + self.net = nn.Sequential( + nn.ConvTranspose3d( + in_channels, + out_channels, + kernel_size=3, + stride=(1, spatial_stride, spatial_stride), + padding=1, + output_padding=output_padding, + ), + nn.GroupNorm(num_groups=1, num_channels=out_channels), + nn.SiLU(), + nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1), + nn.GroupNorm(num_groups=1, num_channels=out_channels), + nn.SiLU(), + ) + + @jaxtyped(typechecker=beartype) + def forward( + self, x: Float[torch.Tensor, "batch channels time height width"] + ) -> Float[torch.Tensor, "batch out_channels time out_height out_width"]: + """Decode a channel-first spatio-temporal tensor. + + Parameters + ---------- + x : Float[torch.Tensor, "batch channels time height width"] + Input tensor. + + Returns + ------- + Float[torch.Tensor, "batch out_channels time out_height out_width"] + Decoded tensor. + """ + return self.net(x) + + +class Decoder(nn.Module): + """Convolutional decoder for sequence reconstruction. + + Parameters + ---------- + output_channels : int + Number of channels in the reconstructed source data. + hidden_channels : int, optional + Number of channels used near the output side of the decoder. Default is + ``16``. + latent_channels : int, optional + Number of channels in the latent representation. Default is ``32``. + num_blocks : int, optional + Number of spatial upsampling blocks. Default is ``2``. + """ + + def __init__( + self, + output_channels: int, + hidden_channels: int = 16, + latent_channels: int = 32, + num_blocks: int = 2, + ) -> None: + super().__init__() + if num_blocks < 1: + raise ValueError(f"num_blocks ({num_blocks}) must be at least 1.") + + self.output_channels = output_channels + self.hidden_channels = hidden_channels + self.latent_channels = latent_channels + self.num_blocks = num_blocks + + layers: list[nn.Module] = [] + in_channels = latent_channels + for block_idx in range(num_blocks): + is_last = block_idx == num_blocks - 1 + remaining_blocks = num_blocks - block_idx - 2 + out_channels = output_channels if is_last else hidden_channels * 2 ** max(remaining_blocks, 0) + layers.append(DecoderBlock(in_channels=in_channels, out_channels=out_channels, upsample=True)) + in_channels = out_channels + self.blocks = nn.Sequential(*layers) + + @jaxtyped(typechecker=beartype) + def forward( + self, z: Float[torch.Tensor, "batch latent_channels time latent_height latent_width"] + ) -> Float[torch.Tensor, "batch time channels height width"]: + """Decode a latent tensor into a time-first reconstruction tensor. + + Parameters + ---------- + z : Float[torch.Tensor, "batch latent_channels time latent_height latent_width"] + Latent tensor in channel-first 3D-convolution layout. + + Returns + ------- + Float[torch.Tensor, "batch time channels height width"] + Reconstructed sequence in the data-layer tensor layout. + """ + return self.blocks(z).movedim(1, 2) diff --git a/src/mlcast/models/autoencoder/encoder.py b/src/mlcast/models/autoencoder/encoder.py new file mode 100644 index 0000000..f319b05 --- /dev/null +++ b/src/mlcast/models/autoencoder/encoder.py @@ -0,0 +1,115 @@ +"""Encoder blocks for the reconstruction autoencoder.""" + +import torch +import torch.nn as nn +from beartype import beartype +from jaxtyping import Float, jaxtyped + + +class EncoderBlock(nn.Module): + """Spatio-temporal encoder block with optional spatial downsampling. + + Parameters + ---------- + in_channels : int + Number of channels in the input tensor. + out_channels : int + Number of channels produced by the block. + downsample : bool, optional + If ``True``, halve the spatial resolution with a stride-2 convolution. + Default is ``True``. + """ + + def __init__(self, in_channels: int, out_channels: int, downsample: bool = True) -> None: + super().__init__() + spatial_stride = 2 if downsample else 1 + self.net = nn.Sequential( + nn.Conv3d( + in_channels, + out_channels, + kernel_size=3, + stride=(1, spatial_stride, spatial_stride), + padding=1, + ), + nn.GroupNorm(num_groups=1, num_channels=out_channels), + nn.SiLU(), + nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1), + nn.GroupNorm(num_groups=1, num_channels=out_channels), + nn.SiLU(), + ) + + @jaxtyped(typechecker=beartype) + def forward( + self, x: Float[torch.Tensor, "batch channels time height width"] + ) -> Float[torch.Tensor, "batch out_channels time out_height out_width"]: + """Encode a channel-first spatio-temporal tensor. + + Parameters + ---------- + x : Float[torch.Tensor, "batch channels time height width"] + Input tensor. + + Returns + ------- + Float[torch.Tensor, "batch out_channels time out_height out_width"] + Encoded tensor. + """ + return self.net(x) + + +class Encoder(nn.Module): + """Convolutional encoder for sequence reconstruction. + + Parameters + ---------- + input_channels : int + Number of channels in the source data. + hidden_channels : int, optional + Number of channels used in the first encoder block. Default is ``16``. + latent_channels : int, optional + Number of channels in the latent representation. Default is ``32``. + num_blocks : int, optional + Number of spatial downsampling blocks. Default is ``2``. + """ + + def __init__( + self, + input_channels: int, + hidden_channels: int = 16, + latent_channels: int = 32, + num_blocks: int = 2, + ) -> None: + super().__init__() + if num_blocks < 1: + raise ValueError(f"num_blocks ({num_blocks}) must be at least 1.") + + self.input_channels = input_channels + self.hidden_channels = hidden_channels + self.latent_channels = latent_channels + self.num_blocks = num_blocks + + layers: list[nn.Module] = [] + in_channels = input_channels + for block_idx in range(num_blocks): + out_channels = latent_channels if block_idx == num_blocks - 1 else hidden_channels * 2**block_idx + layers.append(EncoderBlock(in_channels=in_channels, out_channels=out_channels, downsample=True)) + in_channels = out_channels + self.blocks = nn.Sequential(*layers) + + @jaxtyped(typechecker=beartype) + def forward( + self, x: Float[torch.Tensor, "batch time channels height width"] + ) -> Float[torch.Tensor, "batch latent_channels time latent_height latent_width"]: + """Encode a time-first sequence tensor into a latent tensor. + + Parameters + ---------- + x : Float[torch.Tensor, "batch time channels height width"] + Input sequence in the data-layer tensor layout. + + Returns + ------- + Float[torch.Tensor, "batch latent_channels time latent_height latent_width"] + Latent tensor in channel-first 3D-convolution layout. + """ + return self.blocks(x.movedim(2, 1)) diff --git a/src/mlcast/models/autoencoder/net.py b/src/mlcast/models/autoencoder/net.py new file mode 100644 index 0000000..84d4329 --- /dev/null +++ b/src/mlcast/models/autoencoder/net.py @@ -0,0 +1,80 @@ +"""Autoencoder network for reconstruction pretraining.""" + +import torch +import torch.nn as nn +from beartype import beartype +from jaxtyping import Float, jaxtyped + +from mlcast.models.autoencoder.decoder import Decoder +from mlcast.models.autoencoder.encoder import Encoder + + +class AutoencoderNet(nn.Module): + """Compose an encoder and decoder into a reconstruction network. + + Parameters + ---------- + encoder : Encoder + Encoder module that maps input sequences to latent tensors. + decoder : Decoder + Decoder module that maps latent tensors back to input-space sequences. + """ + + def __init__(self, encoder: Encoder, decoder: Decoder) -> None: + super().__init__() + self.encoder = encoder + self.decoder = decoder + + @jaxtyped(typechecker=beartype) + def encode( + self, x: Float[torch.Tensor, "batch time channels height width"] + ) -> Float[torch.Tensor, "batch latent_channels time latent_height latent_width"]: + """Encode an input sequence into latent space. + + Parameters + ---------- + x : Float[torch.Tensor, "batch time channels height width"] + Input sequence tensor. + + Returns + ------- + Float[torch.Tensor, "batch latent_channels time latent_height latent_width"] + Latent tensor produced by the encoder. + """ + return self.encoder(x) + + @jaxtyped(typechecker=beartype) + def decode( + self, z: Float[torch.Tensor, "batch latent_channels time latent_height latent_width"] + ) -> Float[torch.Tensor, "batch time channels height width"]: + """Decode a latent tensor into input space. + + Parameters + ---------- + z : Float[torch.Tensor, "batch latent_channels time latent_height latent_width"] + Latent tensor produced by the encoder. + + Returns + ------- + Float[torch.Tensor, "batch time channels height width"] + Reconstructed sequence tensor. + """ + return self.decoder(z) + + @jaxtyped(typechecker=beartype) + def forward( + self, x: Float[torch.Tensor, "batch time channels height width"] + ) -> Float[torch.Tensor, "batch time channels height width"]: + """Run an end-to-end reconstruction forward pass. + + Parameters + ---------- + x : Float[torch.Tensor, "batch time channels height width"] + Input sequence tensor. + + Returns + ------- + Float[torch.Tensor, "batch time channels height width"] + Reconstructed sequence tensor. + """ + return self.decode(self.encode(x)) diff --git a/src/mlcast/models/convgru.py b/src/mlcast/models/convgru.py index 2b0c3dc..4482264 100644 --- a/src/mlcast/models/convgru.py +++ b/src/mlcast/models/convgru.py @@ -195,6 +195,12 @@ class Encoder(nn.Module): Parameters ---------- + input_steps : int + Number of timesteps the model expects as input. + forecast_steps : int + Number of timesteps the model forecasts. + ensemble_size : int, optional + Number of ensemble members produced by the model. Default is ``1``. input_channels : int, optional Number of input channels. Default is ``1``. num_blocks : int, optional @@ -350,8 +356,27 @@ class ConvGruModel(nn.Module): :class:`Decoder`. """ - def __init__(self, input_channels: int = 1, num_blocks: int = 4, noisy_decoder: bool = False, **kwargs): + def __init__( + self, + input_steps: int, + forecast_steps: int, + ensemble_size: int = 1, + input_channels: int = 1, + num_blocks: int = 4, + noisy_decoder: bool = False, + **kwargs, + ): super().__init__() + if input_steps < 1: + raise ValueError(f"input_steps ({input_steps}) must be at least 1.") + if forecast_steps < 1: + raise ValueError(f"forecast_steps ({forecast_steps}) must be at least 1.") + if ensemble_size < 1: + raise ValueError(f"ensemble_size ({ensemble_size}) must be at least 1.") + + self.input_steps = input_steps + self.forecast_steps = forecast_steps + self.ensemble_size = ensemble_size self.input_channels = input_channels self.num_blocks = num_blocks self.noisy_decoder = noisy_decoder @@ -360,25 +385,23 @@ def __init__(self, input_channels: int = 1, num_blocks: int = 4, noisy_decoder: @jaxtyped(typechecker=beartype) def forward( - self, x: Float[torch.Tensor, "batch time channels height width"], steps: int, ensemble_size: int = 1 - ) -> Float[torch.Tensor, "batch steps _ height width"]: + self, x: Float[torch.Tensor, "batch time channels height width"] + ) -> Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels height width"]: """Forward the encoder-decoder model. Parameters ---------- x : Float[torch.Tensor, "batch time channels height width"] Input sequence. - steps : int - Number of future timesteps to forecast. - ensemble_size : int, optional - Number of ensemble members to generate. When ``> 1``, the decoder - is always run with noisy inputs. Default is ``1``. Returns ------- - preds : Float[torch.Tensor, "batch steps out_channels height width"] - Forecast tensor. + preds : Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels height width"] + Forecast tensor with an explicit ensemble dimension. """ + if x.shape[1] != self.input_steps: + raise ValueError(f"Expected {self.input_steps} input timesteps, got {x.shape[1]}.") + _, _, _, H, W = x.shape divisor = 2**self.num_blocks pad_h = (divisor - (H % divisor)) % divisor @@ -390,21 +413,21 @@ def forward( encoded = self.encoder(x) x_dec_shape = list(encoded[-1].shape) - x_dec_shape[1] = steps + x_dec_shape[1] = self.forecast_steps last_hidden_per_block = [e[:, -1] for e in reversed(encoded)] - if ensemble_size > 1: + if self.ensemble_size > 1: preds = [] - for _ in range(ensemble_size): + for _ in range(self.ensemble_size): x_dec = torch.randn(x_dec_shape, dtype=encoded[-1].dtype, device=encoded[-1].device) decoded = self.decoder(x_dec, last_hidden_per_block) preds.append(decoded) - out = torch.cat(preds, dim=2) + out = torch.stack(preds, dim=2) else: x_dec_func = torch.randn if self.noisy_decoder else torch.zeros x_dec = x_dec_func(x_dec_shape, dtype=encoded[-1].dtype, device=encoded[-1].device) - out = self.decoder(x_dec, last_hidden_per_block) + out = self.decoder(x_dec, last_hidden_per_block).unsqueeze(2) if pad_h > 0 or pad_w > 0: out = out[..., :H, :W] diff --git a/src/mlcast/models/diffusion/__init__.py b/src/mlcast/models/diffusion/__init__.py new file mode 100644 index 0000000..d237b22 --- /dev/null +++ b/src/mlcast/models/diffusion/__init__.py @@ -0,0 +1,17 @@ +"""Latent diffusion architecture components.""" + +from .conditioner import ConditionerBlock, ConditionerNet +from .denoiser import DenoiserUNet, TimestepEmbedding +from .loss import DiffusionLoss +from .net import LatentDiffusionNet +from .scheduler import DiffusionScheduler + +__all__ = [ + "ConditionerBlock", + "ConditionerNet", + "DenoiserUNet", + "DiffusionLoss", + "DiffusionScheduler", + "LatentDiffusionNet", + "TimestepEmbedding", +] diff --git a/src/mlcast/models/diffusion/conditioner.py b/src/mlcast/models/diffusion/conditioner.py new file mode 100644 index 0000000..c09e7d3 --- /dev/null +++ b/src/mlcast/models/diffusion/conditioner.py @@ -0,0 +1,86 @@ +"""Latent conditioning blocks for diffusion forecasting.""" + +import torch +import torch.nn as nn +from beartype import beartype +from jaxtyping import Float, jaxtyped + + +class ConditionerBlock(nn.Module): + """Residual 3D-convolution block for latent conditioning. + + Parameters + ---------- + channels : int + Number of latent conditioning channels. + """ + + def __init__(self, channels: int) -> None: + super().__init__() + self.net = nn.Sequential( + nn.Conv3d(channels, channels, kernel_size=3, padding=1), + nn.GroupNorm(num_groups=1, num_channels=channels), + nn.SiLU(), + nn.Conv3d(channels, channels, kernel_size=3, padding=1), + ) + + @jaxtyped(typechecker=beartype) + def forward( + self, x: Float[torch.Tensor, "batch channels time height width"] + ) -> Float[torch.Tensor, "batch channels time height width"]: + """Apply residual conditioning refinement. + + Parameters + ---------- + x : Float[torch.Tensor, "batch channels time height width"] + Latent conditioning tensor. + + Returns + ------- + Float[torch.Tensor, "batch channels time height width"] + Refined conditioning tensor. + """ + return x + self.net(x) + + +class ConditionerNet(nn.Module): + """Condition latent target denoising on encoded input history. + + Parameters + ---------- + latent_channels : int + Number of latent channels in the encoded input history. + hidden_channels : int, optional + Number of channels emitted as conditioning context. Default is ``32``. + num_blocks : int, optional + Number of residual conditioning blocks. Default is ``2``. + """ + + def __init__(self, latent_channels: int, hidden_channels: int = 32, num_blocks: int = 2) -> None: + super().__init__() + if num_blocks < 1: + raise ValueError(f"num_blocks ({num_blocks}) must be at least 1.") + + self.latent_channels = latent_channels + self.hidden_channels = hidden_channels + self.num_blocks = num_blocks + self.input_projection = nn.Conv3d(latent_channels, hidden_channels, kernel_size=1) + self.blocks = nn.Sequential(*(ConditionerBlock(hidden_channels) for _ in range(num_blocks))) + + @jaxtyped(typechecker=beartype) + def forward( + self, z: Float[torch.Tensor, "batch latent_channels input_time height width"] + ) -> Float[torch.Tensor, "batch hidden_channels input_time height width"]: + """Build conditioning context from input-history latents. + + Parameters + ---------- + z : Float[torch.Tensor, "batch latent_channels input_time height width"] + Encoded input-history latent tensor. + + Returns + ------- + Float[torch.Tensor, "batch hidden_channels input_time height width"] + Conditioning context for the denoiser. + """ + return self.blocks(self.input_projection(z)) diff --git a/src/mlcast/models/diffusion/denoiser.py b/src/mlcast/models/diffusion/denoiser.py new file mode 100644 index 0000000..c107027 --- /dev/null +++ b/src/mlcast/models/diffusion/denoiser.py @@ -0,0 +1,250 @@ +"""Timestep-aware denoising network for latent diffusion.""" + +import math + +import torch +import torch.nn as nn +from beartype import beartype +from jaxtyping import Float, jaxtyped + + +class TimestepEmbedding(nn.Module): + """Sinusoidal timestep embedding followed by a small MLP. + + Parameters + ---------- + embedding_dim : int + Number of channels in the generated timestep embedding. + """ + + def __init__(self, embedding_dim: int) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.projection = nn.Sequential( + nn.Linear(embedding_dim, embedding_dim), + nn.SiLU(), + nn.Linear(embedding_dim, embedding_dim), + ) + + @jaxtyped(typechecker=beartype) + def forward(self, timesteps: torch.Tensor) -> Float[torch.Tensor, "batch embedding_dim"]: + """Embed integer diffusion timesteps. + + Parameters + ---------- + timesteps : torch.Tensor + Integer diffusion timesteps. + + Returns + ------- + Float[torch.Tensor, "batch embedding_dim"] + Projected sinusoidal timestep embeddings. + """ + half_dim = self.embedding_dim // 2 + frequencies = torch.exp( + torch.arange(half_dim, device=timesteps.device, dtype=torch.float32) + * -(math.log(10_000.0) / max(half_dim - 1, 1)) + ) + args = timesteps.float()[:, None] * frequencies[None] + embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1) + if embedding.shape[-1] < self.embedding_dim: + embedding = torch.nn.functional.pad(embedding, (0, 1)) + return self.projection(embedding) + + +class _DenoiserBlock(nn.Module): + """Internal residual denoising block. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + timestep_channels : int + Number of channels in the timestep embedding. + """ + + def __init__(self, in_channels: int, out_channels: int, timestep_channels: int) -> None: + super().__init__() + self.timestep_projection = nn.Linear(timestep_channels, out_channels) + self.net = nn.Sequential( + nn.GroupNorm(num_groups=1, num_channels=in_channels), + nn.SiLU(), + nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1), + nn.GroupNorm(num_groups=1, num_channels=out_channels), + nn.SiLU(), + nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1), + ) + self.skip_connection = nn.Identity() if in_channels == out_channels else nn.Conv3d(in_channels, out_channels, 1) + + def forward(self, x: torch.Tensor, timestep_embedding: torch.Tensor) -> torch.Tensor: + """Apply timestep-conditioned residual denoising. + + Parameters + ---------- + x : torch.Tensor + Hidden denoising tensor. + timestep_embedding : torch.Tensor + Timestep embedding for each batch item. + + Returns + ------- + torch.Tensor + Updated hidden tensor. + """ + timestep_bias = self.timestep_projection(timestep_embedding)[:, :, None, None, None] + h = self.net(x) + return self.skip_connection(x) + h + timestep_bias + + +class _SpatialDownsample(nn.Module): + """Halve latent spatial resolution while preserving time.""" + + def __init__(self, channels: int) -> None: + super().__init__() + self.op = nn.Conv3d(channels, channels, kernel_size=3, stride=(1, 2, 2), padding=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Downsample the spatial dimensions of a latent tensor. + + Parameters + ---------- + x : torch.Tensor + Channel-first latent tensor. + + Returns + ------- + torch.Tensor + Tensor with half spatial resolution. + """ + return self.op(x) + + +class _SpatialUpsample(nn.Module): + """Double latent spatial resolution while preserving time.""" + + def __init__(self, channels: int) -> None: + super().__init__() + self.op = nn.ConvTranspose3d( + channels, + channels, + kernel_size=3, + stride=(1, 2, 2), + padding=1, + output_padding=(0, 1, 1), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Upsample the spatial dimensions of a latent tensor. + + Parameters + ---------- + x : torch.Tensor + Channel-first latent tensor. + + Returns + ------- + torch.Tensor + Tensor with doubled spatial resolution. + """ + return self.op(x) + + +class DenoiserUNet(nn.Module): + """Compact timestep-aware U-Net denoiser for latent tensors. + + This is a real U-Net-style architecture: it builds a spatial downsampling + path, applies a bottleneck at the lowest spatial resolution, upsamples back + to the original latent resolution, and concatenates matching-resolution + skip connections from the down path into the up path. It differs from a + plain image U-Net because it operates on 3D latent tensors and only changes + spatial resolution; the temporal dimension is preserved throughout. Each + residual block also receives a diffusion timestep embedding. + + Parameters + ---------- + latent_channels : int + Number of channels in the noisy target latent. + condition_channels : int + Number of channels emitted by the conditioner. + hidden_channels : int, optional + Number of hidden channels in the denoiser. Default is ``32``. + num_blocks : int, optional + Number of U-Net resolution levels. Default is ``2``. + """ + + def __init__( + self, + latent_channels: int, + condition_channels: int, + hidden_channels: int = 32, + num_blocks: int = 2, + ) -> None: + super().__init__() + if num_blocks < 1: + raise ValueError(f"num_blocks ({num_blocks}) must be at least 1.") + + self.latent_channels = latent_channels + self.condition_channels = condition_channels + self.hidden_channels = hidden_channels + self.num_blocks = num_blocks + self.timestep_embedding = TimestepEmbedding(hidden_channels) + self.input_projection = nn.Conv3d(latent_channels + condition_channels, hidden_channels, kernel_size=1) + self.down_blocks = nn.ModuleList( + _DenoiserBlock(hidden_channels, hidden_channels, hidden_channels) for _ in range(num_blocks) + ) + self.downsamples = nn.ModuleList(_SpatialDownsample(hidden_channels) for _ in range(num_blocks - 1)) + self.bottleneck = _DenoiserBlock(hidden_channels, hidden_channels, hidden_channels) + self.upsamples = nn.ModuleList(_SpatialUpsample(hidden_channels) for _ in range(num_blocks - 1)) + self.up_blocks = nn.ModuleList( + _DenoiserBlock(hidden_channels * 2, hidden_channels, hidden_channels) for _ in range(num_blocks - 1) + ) + self.output_projection = nn.Conv3d(hidden_channels, latent_channels, kernel_size=1) + + @jaxtyped(typechecker=beartype) + def forward( + self, + noisy: Float[torch.Tensor, "batch latent_channels forecast_time height width"], + timesteps: torch.Tensor, + context: Float[torch.Tensor, "batch condition_channels input_time height width"], + ) -> Float[torch.Tensor, "batch latent_channels forecast_time height width"]: + """Predict noise in a noised latent target. + + Parameters + ---------- + noisy : Float[torch.Tensor, "batch latent_channels forecast_time height width"] + Noised target latent. + timesteps : torch.Tensor + Diffusion timestep for each sample. + context : Float[torch.Tensor, "batch condition_channels input_time height width"] + Conditioning context from the input-history latent. + + Returns + ------- + Float[torch.Tensor, "batch latent_channels forecast_time height width"] + Predicted noise tensor. + """ + if context.shape[2] != noisy.shape[2]: + context = torch.nn.functional.interpolate(context, size=noisy.shape[2:], mode="nearest") + + x = self.input_projection(torch.cat([noisy, context], dim=1)) + timestep_embedding = self.timestep_embedding(timesteps) + + skips: list[torch.Tensor] = [] + for block_idx, block in enumerate(self.down_blocks): + x = block(x, timestep_embedding) + if block_idx < len(self.downsamples): + skips.append(x) + x = self.downsamples[block_idx](x) + + x = self.bottleneck(x, timestep_embedding) + + for upsample, block in zip(self.upsamples, self.up_blocks, strict=True): + x = upsample(x) + skip = skips.pop() + if x.shape[-2:] != skip.shape[-2:]: + x = torch.nn.functional.interpolate(x, size=skip.shape[2:], mode="nearest") + x = block(torch.cat([x, skip], dim=1), timestep_embedding) + + return self.output_projection(x) diff --git a/src/mlcast/models/diffusion/ema.py b/src/mlcast/models/diffusion/ema.py new file mode 100644 index 0000000..0ccd11f --- /dev/null +++ b/src/mlcast/models/diffusion/ema.py @@ -0,0 +1,61 @@ +"""Exponential moving average helpers for diffusion weights.""" + +import torch +import torch.nn as nn + + +class ExponentialMovingAverage: + """Track an exponential moving average of trainable module parameters. + + Parameters + ---------- + module : nn.Module + Module whose parameters should be tracked. + decay : float, optional + EMA decay factor. Default is ``0.999``. + """ + + def __init__(self, module: nn.Module, decay: float = 0.999) -> None: + if not 0.0 <= decay < 1.0: + raise ValueError(f"decay ({decay}) must be in [0, 1).") + self.module = module + self.decay = decay + self.shadow_params = [ + parameter.detach().clone() for parameter in module.parameters() if parameter.requires_grad + ] + self.backup_params: list[torch.Tensor] | None = None + + def update(self) -> None: + """Update EMA shadow parameters from the current module parameters.""" + trainable_params = [parameter for parameter in self.module.parameters() if parameter.requires_grad] + for i, (shadow_param, parameter) in enumerate(zip(self.shadow_params, trainable_params, strict=True)): + if shadow_param.device != parameter.device: + self.shadow_params[i] = shadow_param.to(parameter.device) + shadow_param = self.shadow_params[i] + shadow_param.mul_(self.decay).add_(parameter.detach(), alpha=1.0 - self.decay) + + def _align_device(self) -> None: + """Move shadow parameters to the current device of the module's parameters.""" + for i, shadow_param in enumerate(self.shadow_params): + ref_param = next(self.module.parameters()) + if shadow_param.device != ref_param.device: + self.shadow_params[i] = shadow_param.to(ref_param.device) + + def apply(self) -> None: + """Swap EMA parameters into the tracked module.""" + self._align_device() + trainable_params = [parameter for parameter in self.module.parameters() if parameter.requires_grad] + self.backup_params = [parameter.detach().clone() for parameter in trainable_params] + for parameter, shadow_param in zip(trainable_params, self.shadow_params, strict=True): + parameter.data.copy_(shadow_param.data) + + def restore(self) -> None: + """Restore module parameters saved before :meth:`apply`.""" + if self.backup_params is None: + raise RuntimeError("EMA restore() called before apply().") + trainable_params = [parameter for parameter in self.module.parameters() if parameter.requires_grad] + for parameter, backup_param in zip(trainable_params, self.backup_params, strict=True): + if backup_param.device != parameter.device: + backup_param = backup_param.to(parameter.device) + parameter.data.copy_(backup_param.data) + self.backup_params = None diff --git a/src/mlcast/models/diffusion/loss.py b/src/mlcast/models/diffusion/loss.py new file mode 100644 index 0000000..6c1f252 --- /dev/null +++ b/src/mlcast/models/diffusion/loss.py @@ -0,0 +1,48 @@ +"""Loss helpers for latent diffusion training.""" + +import torch +import torch.nn as nn +from beartype import beartype +from jaxtyping import Float, jaxtyped + +from mlcast.models.diffusion.net import LatentDiffusionNet + + +class DiffusionLoss(nn.Module): + """Noise-prediction MSE loss for latent diffusion. + + Parameters + ---------- + net : LatentDiffusionNet + Diffusion network used to sample noised latents and predict noise. + """ + + def __init__(self, net: LatentDiffusionNet) -> None: + super().__init__() + self.net = net + + @jaxtyped(typechecker=beartype) + def forward( + self, + input_latents: Float[torch.Tensor, "batch latent_channels input_time height width"], + target_latents: Float[torch.Tensor, "batch latent_channels forecast_time height width"], + ) -> torch.Tensor: + """Compute a random-timestep noise-prediction loss. + + Parameters + ---------- + input_latents : Float[torch.Tensor, "batch latent_channels input_time height width"] + Encoded input-history latents used as conditioning. + target_latents : Float[torch.Tensor, "batch latent_channels forecast_time height width"] + Clean target latents to diffuse. + + Returns + ------- + torch.Tensor + Scalar mean squared error between predicted and sampled noise. + """ + timesteps = torch.randint(0, self.net.num_timesteps, (target_latents.shape[0],), device=target_latents.device) + noise = torch.randn_like(target_latents) + noised_target = self.net.q_sample(target_latents, timesteps=timesteps, noise=noise) + predicted_noise = self.net(noised_target, timesteps, input_latents) + return torch.nn.functional.mse_loss(predicted_noise, noise) diff --git a/src/mlcast/models/diffusion/net.py b/src/mlcast/models/diffusion/net.py new file mode 100644 index 0000000..59ff087 --- /dev/null +++ b/src/mlcast/models/diffusion/net.py @@ -0,0 +1,102 @@ +"""Latent diffusion network composed from conditioner and denoiser modules.""" + +import torch +import torch.nn as nn +from beartype import beartype +from jaxtyping import Float, jaxtyped + +from mlcast.models.diffusion.conditioner import ConditionerNet +from mlcast.models.diffusion.denoiser import DenoiserUNet +from mlcast.models.diffusion.scheduler import DiffusionScheduler, extract_schedule_value + + +class LatentDiffusionNet(nn.Module): + """Trainable latent diffusion denoising network. + + Parameters + ---------- + conditioner : ConditionerNet + Network that builds context from input-history latents. It must accept + ``Float[Tensor, "batch latent_channels input_time height width"]`` and + return ``Float[Tensor, "batch condition_channels input_time height width"]``. + denoiser : DenoiserUNet + Network that predicts noise from noised target latents. It must accept + ``noisy`` with shape + ``Float[Tensor, "batch latent_channels forecast_time height width"]``, + ``timesteps`` with shape ``(batch,)``, and ``context`` with shape + ``Float[Tensor, "batch condition_channels input_time height width"]``; + it must return + ``Float[Tensor, "batch latent_channels forecast_time height width"]``. + scheduler : DiffusionScheduler + Diffusion noise scheduler. Calling ``scheduler.buffers(device, dtype)`` + must return one-dimensional tensors of length ``scheduler.timesteps`` + for ``sqrt_alphas_cumprod`` and ``sqrt_one_minus_alphas_cumprod`` so + they can be gathered with timestep indices shaped ``(batch,)`` and + broadcast over latent tensors shaped + ``(batch, latent_channels, forecast_time, height, width)``. + """ + + def __init__(self, conditioner: ConditionerNet, denoiser: DenoiserUNet, scheduler: DiffusionScheduler) -> None: + super().__init__() + self.conditioner = conditioner + self.denoiser = denoiser + self.scheduler = scheduler + self.num_timesteps = scheduler.timesteps + for name, value in scheduler.buffers(device=torch.device("cpu")).items(): + self.register_buffer(name, value) + + @jaxtyped(typechecker=beartype) + def q_sample( + self, + x0: Float[torch.Tensor, "batch channels time height width"], + timesteps: torch.Tensor, + noise: Float[torch.Tensor, "batch channels time height width"] | None = None, + ) -> Float[torch.Tensor, "batch channels time height width"]: + """Diffuse clean latents to a chosen timestep. + + Parameters + ---------- + x0 : Float[torch.Tensor, "batch channels time height width"] + Clean target latent. + timesteps : torch.Tensor + Diffusion timestep for each sample. + noise : Float[torch.Tensor, "batch channels time height width"] or None, optional + Noise to add. If ``None``, standard Gaussian noise is sampled. + Default is ``None``. + + Returns + ------- + Float[torch.Tensor, "batch channels time height width"] + Noised target latent. + """ + if noise is None: + noise = torch.randn_like(x0) + sqrt_alpha = extract_schedule_value(self.sqrt_alphas_cumprod, timesteps, x0.shape) + sqrt_one_minus_alpha = extract_schedule_value(self.sqrt_one_minus_alphas_cumprod, timesteps, x0.shape) + return sqrt_alpha * x0 + sqrt_one_minus_alpha * noise + + @jaxtyped(typechecker=beartype) + def forward( + self, + noised_target: Float[torch.Tensor, "batch latent_channels forecast_time height width"], + timesteps: torch.Tensor, + input_latents: Float[torch.Tensor, "batch latent_channels input_time height width"], + ) -> Float[torch.Tensor, "batch latent_channels forecast_time height width"]: + """Predict noise from a noised target latent and input context. + + Parameters + ---------- + noised_target : Float[torch.Tensor, "batch latent_channels forecast_time height width"] + Noised target latent. + timesteps : torch.Tensor + Diffusion timestep for each sample. + input_latents : Float[torch.Tensor, "batch latent_channels input_time height width"] + Encoded input-history latents. + + Returns + ------- + Float[torch.Tensor, "batch latent_channels forecast_time height width"] + Predicted noise. + """ + context = self.conditioner(input_latents) + return self.denoiser(noised_target, timesteps, context=context) diff --git a/src/mlcast/models/diffusion/sampler.py b/src/mlcast/models/diffusion/sampler.py new file mode 100644 index 0000000..53d5159 --- /dev/null +++ b/src/mlcast/models/diffusion/sampler.py @@ -0,0 +1,60 @@ +"""Simple ancestral sampler for latent diffusion models.""" + +import torch +from beartype import beartype +from jaxtyping import Float, jaxtyped + +from mlcast.models.diffusion.net import LatentDiffusionNet +from mlcast.models.diffusion.scheduler import extract_schedule_value + + +class DiffusionSampler: + """Generate latent samples with a compact DDPM-style reverse process. + + Parameters + ---------- + net : LatentDiffusionNet + Trained diffusion network. + """ + + def __init__(self, net: LatentDiffusionNet) -> None: + self.net = net + + @jaxtyped(typechecker=beartype) + def sample( + self, + input_latents: Float[torch.Tensor, "batch latent_channels input_time height width"], + output_shape: tuple[int, int, int, int, int], + ) -> Float[torch.Tensor, "batch latent_channels forecast_time height width"]: + """Sample forecast latents conditioned on input latents. + + Parameters + ---------- + input_latents : Float[torch.Tensor, "batch latent_channels input_time height width"] + Encoded input-history latents. + output_shape : tuple of int + Shape of the forecast latent to sample, ordered as + ``(batch, channels, forecast_time, height, width)``. + + Returns + ------- + Float[torch.Tensor, "batch latent_channels forecast_time height width"] + Sampled forecast latent. + """ + x = torch.randn(output_shape, device=input_latents.device, dtype=input_latents.dtype) + for step in reversed(range(self.net.num_timesteps)): + timesteps = torch.full((output_shape[0],), step, device=input_latents.device, dtype=torch.long) + predicted_noise = self.net(x, timesteps, input_latents) + sqrt_alpha = extract_schedule_value(self.net.sqrt_alphas_cumprod, timesteps, x.shape) + sqrt_one_minus_alpha = extract_schedule_value(self.net.sqrt_one_minus_alphas_cumprod, timesteps, x.shape) + x0 = (x - sqrt_one_minus_alpha * predicted_noise) / sqrt_alpha.clamp_min(1e-6) + if step > 0: + prev_timesteps = timesteps - 1 + prev_sqrt_alpha = extract_schedule_value(self.net.sqrt_alphas_cumprod, prev_timesteps, x.shape) + prev_sqrt_one_minus_alpha = extract_schedule_value( + self.net.sqrt_one_minus_alphas_cumprod, prev_timesteps, x.shape + ) + x = prev_sqrt_alpha * x0 + prev_sqrt_one_minus_alpha * predicted_noise + else: + x = x0 + return x diff --git a/src/mlcast/models/diffusion/scheduler.py b/src/mlcast/models/diffusion/scheduler.py new file mode 100644 index 0000000..ceaa80a --- /dev/null +++ b/src/mlcast/models/diffusion/scheduler.py @@ -0,0 +1,71 @@ +"""Diffusion noise schedules.""" + +import torch + + +class DiffusionScheduler: + """Linear-beta diffusion scheduler. + + Parameters + ---------- + timesteps : int, optional + Number of diffusion timesteps. Default is ``100``. + beta_start : float, optional + Initial beta value. Default is ``1e-4``. + beta_end : float, optional + Final beta value. Default is ``2e-2``. + """ + + def __init__(self, timesteps: int = 100, beta_start: float = 1e-4, beta_end: float = 2e-2) -> None: + if timesteps < 1: + raise ValueError(f"timesteps ({timesteps}) must be at least 1.") + self.timesteps = timesteps + self.beta_start = beta_start + self.beta_end = beta_end + + def buffers(self, device: torch.device, dtype: torch.dtype = torch.float32) -> dict[str, torch.Tensor]: + """Build schedule tensors for registration as module buffers. + + Parameters + ---------- + device : torch.device + Device on which buffers should be allocated. + dtype : torch.dtype, optional + Floating-point dtype for schedule tensors. Default is + ``torch.float32``. + + Returns + ------- + dict of str to torch.Tensor + Schedule tensors used for forward and reverse diffusion. + """ + betas = torch.linspace(self.beta_start, self.beta_end, self.timesteps, device=device, dtype=dtype) + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + return { + "betas": betas, + "alphas": alphas, + "alphas_cumprod": alphas_cumprod, + "sqrt_alphas_cumprod": torch.sqrt(alphas_cumprod), + "sqrt_one_minus_alphas_cumprod": torch.sqrt(1.0 - alphas_cumprod), + } + + +def extract_schedule_value(values: torch.Tensor, timesteps: torch.Tensor, target_shape: torch.Size) -> torch.Tensor: + """Gather schedule values and reshape them for broadcasting. + + Parameters + ---------- + values : Float[torch.Tensor, "timesteps"] + One-dimensional schedule tensor. + timesteps : Int[torch.Tensor, "batch"] + Timestep index for each batch item. + target_shape : torch.Size + Shape of the target tensor the values should broadcast against. + + Returns + ------- + torch.Tensor + Gathered values reshaped to ``(batch, 1, ..., 1)``. + """ + return values.gather(0, timesteps).reshape(timesteps.shape[0], *([1] * (len(target_shape) - 1))) diff --git a/src/mlcast/modules/__init__.py b/src/mlcast/modules/__init__.py new file mode 100644 index 0000000..46f7d2a --- /dev/null +++ b/src/mlcast/modules/__init__.py @@ -0,0 +1,15 @@ +"""Training and task-level Lightning module wrappers.""" + +from .forecasting import ( + BaseForecastingTaskModule, + LatentDiffusionTaskModule, + OutputSpaceForecastingTaskModule, +) +from .reconstruction import ReconstructionTaskModule + +__all__ = [ + "BaseForecastingTaskModule", + "LatentDiffusionTaskModule", + "OutputSpaceForecastingTaskModule", + "ReconstructionTaskModule", +] diff --git a/src/mlcast/modules/forecasting.py b/src/mlcast/modules/forecasting.py new file mode 100644 index 0000000..fcdfd83 --- /dev/null +++ b/src/mlcast/modules/forecasting.py @@ -0,0 +1,652 @@ +"""Forecasting task-level Lightning module wrappers.""" + +from abc import abstractmethod +from collections.abc import Callable +from typing import Any + +import numpy as np +import pytorch_lightning as pl +import torch +from beartype import beartype +from einops import rearrange +from jaxtyping import Float, jaxtyped + +from mlcast.data.normalization import DENORMALIZATION_REGISTRY, NORMALIZATION_REGISTRY +from mlcast.losses import build_loss +from mlcast.models.autoencoder import AutoencoderNet +from mlcast.models.diffusion.ema import ExponentialMovingAverage +from mlcast.models.diffusion.loss import DiffusionLoss +from mlcast.models.diffusion.net import LatentDiffusionNet +from mlcast.models.diffusion.sampler import DiffusionSampler +from mlcast.visualization import log_images + + +class BaseForecastingTaskModule(pl.LightningModule): + """Base Lightning module for forecasting-shaped tasks. + + Purpose + ------- + This class provides the common PyTorch Lightning plumbing shared by + forecasting-oriented task modules. It centralizes the optimizer and + scheduler configuration interface, the train/validation/test step routing, + and the normalization-aware prediction helper used by forecasting tasks. + + Ownership + --------- + This base class owns: + + - optimizer and scheduler factories + - generic Lightning step orchestration + - normalization and denormalization logic for ``predict`` + + It does not own: + + - a specific forecasting architecture + - a concrete task loss + - the choice of which parameters are trainable + - any task-specific inference logic beyond normalized I/O handling + + Training behavior + ----------------- + Training, validation, and test steps all delegate to the subclass hook + :meth:`compute_loss`. Subclasses are also responsible for exposing the + exact parameter set to optimize through the :attr:`trainable_parameters` + property. + + Inference behavior + ------------------ + ``predict`` accepts unnormalized input observations, applies the configured + normalization for the requested standard name, delegates normalized + forecasting to :meth:`predict_normalized`, then denormalizes the model + outputs back to physical units. + + Parameters + ---------- + optimizer : Callable[..., torch.optim.Optimizer] or None, optional + Optimizer factory. Default is ``None`` (Adam). + lr_scheduler : Callable[..., torch.optim.lr_scheduler.LRScheduler] or None, optional + Learning-rate scheduler factory. Default is ``None``. + """ + + def __init__( + self, + optimizer: Callable[..., torch.optim.Optimizer] | None = None, + lr_scheduler: Callable[..., torch.optim.lr_scheduler.LRScheduler] | None = None, + ) -> None: + super().__init__() + self.optimizer_factory = optimizer + self.lr_scheduler_factory = lr_scheduler + + @property + @abstractmethod + def trainable_parameters(self) -> list[torch.nn.Parameter]: + """Return the parameters optimized for this forecasting task. + + Returns + ------- + list of torch.nn.Parameter + Trainable parameters owned by the concrete forecasting task. + """ + + @abstractmethod + def compute_loss(self, batch: dict[str, torch.Tensor], split: str = "train") -> torch.Tensor: + """Compute and log loss for one forecasting batch. + + Parameters + ---------- + batch : dict of str to torch.Tensor + Forecasting batch. + split : str, optional + Current data split. Default is ``"train"``. + + Returns + ------- + torch.Tensor + Scalar task loss. + """ + + @jaxtyped(typechecker=beartype) + def predict_normalized( + self, + x: Float[torch.Tensor, "batch time channels height width"], + ) -> Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels height width"]: + """Predict normalized forecasts from normalized inputs. + + Parameters + ---------- + x : Float[torch.Tensor, "batch time channels height width"] + Normalized forecasting input. + + Returns + ------- + Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels height width"] + Normalized forecast tensor with an explicit ensemble dimension. + """ + return self(x) + + def training_step(self, batch: dict[str, torch.Tensor], _batch_idx: int) -> torch.Tensor: + """Execute a training step. + + Parameters + ---------- + batch : dict of str to torch.Tensor + Forecasting batch. + _batch_idx : int + Batch index supplied by Lightning. + + Returns + ------- + torch.Tensor + Training loss. + """ + return self.compute_loss(batch, split="train") + + def validation_step(self, batch: dict[str, torch.Tensor], _batch_idx: int) -> torch.Tensor: + """Execute a validation step. + + Parameters + ---------- + batch : dict of str to torch.Tensor + Forecasting batch. + _batch_idx : int + Batch index supplied by Lightning. + + Returns + ------- + torch.Tensor + Validation loss. + """ + return self.compute_loss(batch, split="val") + + def test_step(self, batch: dict[str, torch.Tensor], _batch_idx: int) -> torch.Tensor: + """Execute a test step. + + Parameters + ---------- + batch : dict of str to torch.Tensor + Forecasting batch. + _batch_idx : int + Batch index supplied by Lightning. + + Returns + ------- + torch.Tensor + Test loss. + """ + return self.compute_loss(batch, split="test") + + def configure_optimizers(self) -> Any: + """Configure optimizer and optional scheduler. + + Returns + ------- + Any + PyTorch Lightning optimizer configuration. + """ + parameters = self.trainable_parameters + if self.optimizer_factory is not None: + optimizer = self.optimizer_factory(parameters) + else: + optimizer = torch.optim.Adam(parameters) + + if self.lr_scheduler_factory is not None: + lr_scheduler = self.lr_scheduler_factory(optimizer) + return {"optimizer": optimizer, "lr_scheduler": {"scheduler": lr_scheduler, "monitor": "val/loss"}} + return {"optimizer": optimizer} + + def predict(self, past: torch.Tensor, standard_name: str = "rainfall_rate") -> np.ndarray[Any, Any]: + """Generate unnormalized forecasts from unnormalized past observations. + + Parameters + ---------- + past : torch.Tensor + Past observations with shape ``(T, H, W)``. + standard_name : str, optional + CF standard name that selects normalization and denormalization + functions. Default is ``"rainfall_rate"``. + + Returns + ------- + np.ndarray + Forecast array shaped ``(ensemble_size, forecast_steps, H, W)`` for + single-channel outputs. + """ + if len(past.shape) != 3: + raise ValueError("Input must be of shape (T, H, W)") + + past_clean = np.nan_to_num(past.cpu().numpy()) + past_clean = past_clean[np.newaxis, :, np.newaxis, ...] + norm_func = NORMALIZATION_REGISTRY[standard_name] + norm_past = norm_func(past_clean) + + x = torch.from_numpy(norm_past).to(self.device) + self.eval() + with torch.no_grad(): + preds_tensor = self.predict_normalized(x) + + preds_np: np.ndarray[Any, Any] = preds_tensor.cpu().numpy() + denorm_func = DENORMALIZATION_REGISTRY[standard_name] + preds_np = denorm_func(preds_np) + preds_np = preds_np.squeeze(0) + preds_np = np.swapaxes(preds_np, 0, 1) + return preds_np + + +class OutputSpaceForecastingTaskModule(BaseForecastingTaskModule): + """Lightning task module for direct forecasting in output space. + + Purpose + ------- + This task module trains conventional forecasting models whose outputs can be + compared directly against forecast targets in the original normalized data + space. It is the generic wrapper used for models such as ConvGRU, where a + single forward pass produces forecast tensors that are supervised directly. + + Ownership + --------- + This class owns: + + - the forecasting network passed in as ``network`` + - the output-space forecasting loss + - optional masked-loss behavior using ``target_mask`` + - image and ensemble-statistic logging specific to direct forecast outputs + + It does not own: + + - source-data normalization rules outside the inherited ``predict`` helper + - latent-space encoding or decoding components + - sampler-driven generative forecast logic + + Training behavior + ----------------- + A forecasting batch provides ``input`` and ``target`` tensors, plus an + optional ``target_mask``. The module calls ``network(input)`` to obtain a + normalized forecast tensor, optionally applies masked loss, and compares the + network outputs directly against the target tensor in output space. + + Inference behavior + ------------------ + Inference is a direct forward pass through the forecasting network. The + inherited :meth:`predict` helper normalizes raw inputs, calls + :meth:`predict_normalized`, and denormalizes the resulting forecast back to + physical units. + + Parameters + ---------- + network : torch.nn.Module + Forecasting network to train. + loss_class : type[torch.nn.Module] or str, optional + Loss function class or registry name. Default is ``"mse"``. + loss_params : dict or None, optional + Keyword arguments for the loss constructor. Default is ``None``. + masked_loss : bool, optional + Whether to use masked-loss computation with ``target_mask`` from the + batch. Default is ``False``. + optimizer : Callable[..., torch.optim.Optimizer] or None, optional + Optimizer factory. Default is ``None`` (Adam). + lr_scheduler : Callable[..., torch.optim.lr_scheduler.LRScheduler] or None, optional + Learning-rate scheduler factory. Default is ``None``. + """ + + def __init__( + self, + network: torch.nn.Module, + loss_class: type[torch.nn.Module] | str = "mse", + loss_params: dict[str, Any] | None = None, + masked_loss: bool = False, + optimizer: Callable[..., torch.optim.Optimizer] | None = None, + lr_scheduler: Callable[..., torch.optim.lr_scheduler.LRScheduler] | None = None, + ) -> None: + super().__init__(optimizer=optimizer, lr_scheduler=lr_scheduler) + self.save_hyperparameters("loss_class", "loss_params", "masked_loss") + self.network = network + self.criterion = build_loss(loss_class=loss_class, loss_params=loss_params, masked_loss=masked_loss) + self.log_images_iterations = [50, 100, 200, 500, 750, 1000, 2000, 5000] + + @jaxtyped(typechecker=beartype) + def forward( + self, + x: Float[torch.Tensor, "batch time channels height width"], + ) -> Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels height width"]: + """Run the forecasting network. + + Parameters + ---------- + x : Float[torch.Tensor, "batch time channels height width"] + Normalized input history tensor. + + Returns + ------- + Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels height width"] + Normalized forecast tensor with an explicit ensemble dimension. + """ + return self.network(x) + + def compute_loss(self, batch: dict[str, torch.Tensor], split: str = "train") -> torch.Tensor: + """Compute and log forecasting loss for one batch. + + Parameters + ---------- + batch : dict of str to torch.Tensor + Forecasting batch containing ``input`` and ``target`` tensors, and + optionally ``target_mask`` when masked loss is enabled. + split : str, optional + Current data split. Default is ``"train"``. + + Returns + ------- + torch.Tensor + Scalar loss tensor. + """ + past = batch["input"] + future = batch["target"] + preds = self(past).clamp(min=-1, max=1) + + # Flatten ensemble and channel dims for loss functions that expect + # (B, T, M*C, H, W), preserving backward compatibility with CRPS etc. + preds_flat = rearrange(preds, "b t m c h w -> b t (m c) h w") + + ensemble_size = getattr(self.network, "ensemble_size", 1) + ensemble_std = preds.std(dim=2).mean() if ensemble_size > 1 else None + + if self.hparams["masked_loss"]: + mask = batch["target_mask"] + loss = self.criterion(preds_flat, future, mask) + else: + loss = self.criterion(preds_flat, future) + + if isinstance(loss, tuple): + loss, log_dict = loss + self.log_dict( + log_dict, prog_bar=False, logger=True, on_step=(split == "train"), on_epoch=True, sync_dist=True + ) + + self.log(f"{split}/loss", loss, prog_bar=True, on_epoch=True, on_step=(split == "train"), sync_dist=True) + + if ensemble_std is not None: + self.log(f"{split}/ensemble_std", ensemble_std, on_epoch=True, sync_dist=True) + + if ( + split == "train" + and self.logger is not None + and getattr(self.logger, "experiment", None) is not None + and ( + self.global_step in self.log_images_iterations or self.global_step % self.log_images_iterations[-1] == 0 + ) + ): + log_images( + past=past, + future=future, + preds=preds, + logger=self.logger, # type: ignore[arg-type] + global_step=self.global_step, + ensemble_size=ensemble_size, + split=split, + ) + return loss + + @property + def trainable_parameters(self) -> list[torch.nn.Parameter]: + """Return the forecasting network parameters. + + Returns + ------- + list of torch.nn.Parameter + Parameters optimized for direct forecasting. + """ + return list(self.network.parameters()) + + @classmethod + def from_checkpoint(cls, checkpoint_path: str, device: str = "cpu") -> "OutputSpaceForecastingTaskModule": + """Load a forecasting task module from checkpoint. + + Parameters + ---------- + checkpoint_path : str + Path to the saved Lightning checkpoint. + device : str, optional + Device to map parameters onto. Default is ``"cpu"``. + + Returns + ------- + OutputSpaceForecastingTaskModule + Loaded output-space forecasting task module. + """ + return cls.load_from_checkpoint( + checkpoint_path, + map_location=torch.device(device), + strict=True, + weights_only=False, + ) + + +class LatentDiffusionTaskModule(BaseForecastingTaskModule): + """Lightning task module for latent diffusion forecasting. + + Purpose + ------- + This task module trains a latent diffusion forecasting system that reuses a + stage-1 autoencoder. Forecast supervision is applied in latent space rather + than directly on decoded forecast tensors. At inference time, the module + samples forecast latents and decodes them back to the original data space. + + Ownership + --------- + This class owns: + + - the trained autoencoder reused from stage 1 + - the latent diffusion architecture + - the latent diffusion loss + - the diffusion sampler used for forecast generation + - optional EMA tracking over diffusion-network weights + + It does not own: + + - stage-1 autoencoder training + - output-space supervision for the diffusion loss + - the source-data normalization rules beyond the inherited ``predict`` + helper + + Training behavior + ----------------- + A forecasting batch provides raw normalized ``input`` and ``target`` + tensors. The reused autoencoder encoder maps both into latent space under + ``torch.no_grad()``. The module then computes a diffusion loss entirely on + latent tensors and exposes only the diffusion-network parameters through + :attr:`trainable_parameters`, so the reused autoencoder remains frozen. + + Inference behavior + ------------------ + Inference encodes the input history with the reused autoencoder, samples a + latent forecast with the diffusion sampler, then decodes the sampled latent + forecast back to data space. Ensemble generation is explicit here: the + module repeats encoded inputs per requested ensemble member, samples a + forecast latent for each member, and concatenates the decoded members along + the channel dimension. + + Parameters + ---------- + autoencoder : AutoencoderNet + Trained autoencoder reused from stage 1. The encoder is used during + stage-2 training to map forecasting inputs and targets into latent + space. The decoder is retained for forecast inference but is not used in + the stage-2 diffusion loss. + diffusion_net : LatentDiffusionNet + Latent diffusion architecture to train. + forecast_steps : int + Number of forecast timesteps decoded during inference. + ensemble_size : int, optional + Number of ensemble members decoded during inference. Default is ``1``. + loss : DiffusionLoss or None, optional + Latent diffusion loss module. If ``None``, ``DiffusionLoss`` is built + from ``diffusion_net``. Default is ``None``. + optimizer : Callable[..., torch.optim.Optimizer] or None, optional + Optimizer factory. Default is ``None`` (Adam). + lr_scheduler : Callable[..., torch.optim.lr_scheduler.LRScheduler] or None, optional + Learning-rate scheduler factory. Default is ``None``. + ema_decay : float or None, optional + If provided, track an exponential moving average of diffusion-net + parameters with this decay (commonly ``0.999`` or ``0.9999``). EMA- + smoothed weights are swapped in during validation, testing, and + prediction — the raw weights receive gradient updates during training. + This is standard practice in diffusion models: the iterative denoising + process amplifies small weight fluctuations, and EMA averaging + suppresses that noise for cleaner samples at inference time. Default is + ``None`` (no EMA). + """ + + def __init__( + self, + autoencoder: AutoencoderNet, + diffusion_net: LatentDiffusionNet, + forecast_steps: int, + ensemble_size: int = 1, + loss: DiffusionLoss | None = None, + optimizer: Callable[..., torch.optim.Optimizer] | None = None, + lr_scheduler: Callable[..., torch.optim.lr_scheduler.LRScheduler] | None = None, + ema_decay: float | None = None, + ) -> None: + super().__init__(optimizer=optimizer, lr_scheduler=lr_scheduler) + if forecast_steps < 1: + raise ValueError(f"forecast_steps ({forecast_steps}) must be at least 1.") + if ensemble_size < 1: + raise ValueError(f"ensemble_size ({ensemble_size}) must be at least 1.") + + self.save_hyperparameters("forecast_steps", "ensemble_size", "ema_decay") + self.autoencoder = autoencoder + self.diffusion_net = diffusion_net + self.loss_fn = loss if loss is not None else DiffusionLoss(diffusion_net) + self.sampler = DiffusionSampler(diffusion_net) + self.ema = ExponentialMovingAverage(diffusion_net, decay=ema_decay) if ema_decay is not None else None + + def _freeze_autoencoder(self) -> None: + """Freeze the reused autoencoder before stage-2 use. + + The same autoencoder instance is shared with stage-1 reconstruction + training, so freezing must happen when the diffusion stage begins rather + than in ``__init__``. + """ + self.autoencoder.eval() + for parameter in self.autoencoder.parameters(): + parameter.requires_grad = False + + @jaxtyped(typechecker=beartype) + def forward( + self, + x: Float[torch.Tensor, "batch input_steps channels height width"], + ) -> Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels height width"]: + """Generate decoded forecasts from normalized input histories. + + Parameters + ---------- + x : Float[torch.Tensor, "batch input_steps channels height width"] + Normalized input history tensor. + + Returns + ------- + Float[torch.Tensor, "batch forecast_steps ensemble_size out_channels height width"] + Decoded normalized forecast tensor with an explicit ensemble + dimension. + """ + input_latents = self.autoencoder.encode(x) + repeated_input_latents = input_latents.repeat_interleave(self.hparams["ensemble_size"], dim=0) + latent_shape = ( + x.shape[0] * self.hparams["ensemble_size"], + input_latents.shape[1], + self.hparams["forecast_steps"], + input_latents.shape[3], + input_latents.shape[4], + ) + forecast_latents = self.sampler.sample(repeated_input_latents, latent_shape) + decoded = self.autoencoder.decode(forecast_latents) + # Decoded latent has shape (B*E, T, C, H, W) because ensemble members + # were stacked in the batch dim via repeat_interleave. Unstack into an + # explicit ensemble dim and move time before ensemble for the standard + # (B, T, E, C, H, W) shape contract expected by loss functions etc. + return rearrange(decoded, "(b e) t c h w -> b t e c h w", e=self.hparams["ensemble_size"]) + + def compute_loss(self, batch: dict[str, torch.Tensor], split: str = "train") -> torch.Tensor: + """Compute latent diffusion loss for a forecasting batch. + + Parameters + ---------- + batch : dict of str to torch.Tensor + Forecasting batch containing ``input`` and ``target`` tensors. + split : str, optional + Current data split. Default is ``"train"``. + + Returns + ------- + torch.Tensor + Scalar latent diffusion loss. + """ + with torch.no_grad(): + input_latents = self.autoencoder.encode(batch["input"]) + target_latents = self.autoencoder.encode(batch["target"]) + loss = self.loss_fn(input_latents, target_latents) + self.log(f"{split}/loss", loss, prog_bar=True, on_epoch=True, on_step=(split == "train"), sync_dist=True) + return loss + + @property + def trainable_parameters(self) -> list[torch.nn.Parameter]: + """Return only the diffusion-network parameters. + + Returns + ------- + list of torch.nn.Parameter + Parameters optimized during stage-2 latent diffusion training. + """ + return list(self.diffusion_net.parameters()) + + def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None: + """Update EMA after each training batch when enabled. + + Parameters + ---------- + outputs : Any + Lightning training outputs. + batch : Any + Batch passed to the training step. + batch_idx : int + Batch index supplied by Lightning. + """ + del outputs, batch, batch_idx + if self.ema is not None: + self.ema.update() + + def on_fit_start(self) -> None: + """Freeze the reused autoencoder before diffusion training starts.""" + self._freeze_autoencoder() + + def on_validation_start(self) -> None: + """Swap EMA weights in before validation when enabled.""" + self._freeze_autoencoder() + if self.ema is not None: + self.ema.apply() + + def on_validation_end(self) -> None: + """Restore raw diffusion weights after validation when enabled.""" + if self.ema is not None: + self.ema.restore() + + def on_test_start(self) -> None: + """Swap EMA weights in before testing when enabled.""" + self._freeze_autoencoder() + if self.ema is not None: + self.ema.apply() + + def on_test_end(self) -> None: + """Restore raw diffusion weights after testing when enabled.""" + if self.ema is not None: + self.ema.restore() + + def on_predict_start(self) -> None: + """Swap EMA weights in before prediction when enabled.""" + self._freeze_autoencoder() + if self.ema is not None: + self.ema.apply() + + def on_predict_end(self) -> None: + """Restore raw diffusion weights after prediction when enabled.""" + if self.ema is not None: + self.ema.restore() diff --git a/src/mlcast/modules/reconstruction.py b/src/mlcast/modules/reconstruction.py new file mode 100644 index 0000000..fc302bd --- /dev/null +++ b/src/mlcast/modules/reconstruction.py @@ -0,0 +1,194 @@ +"""Lightning module wrappers for reconstruction tasks.""" + +from collections.abc import Callable +from typing import Any + +import pytorch_lightning as pl +import torch +from beartype import beartype +from jaxtyping import Float, jaxtyped + +from mlcast.losses import build_loss + + +class ReconstructionTaskModule(pl.LightningModule): + """Lightning task module for reconstruction training. + + Purpose + ------- + This task module trains reconstruction models on tensor-only batches from + ``ReconstructionDataset``. It is intended for stage-1 reconstruction or + autoencoder training, where the model learns to reproduce normalized + sequence windows. + + Ownership + --------- + This class owns: + + - the reconstruction network + - the reconstruction loss defined by ``loss_class`` and ``loss_params`` + - the optimizer and learning-rate scheduler factories + + It does not own: + + - source-data normalization rules + - forecasting-specific targets, masks, or ensemble behavior + - latent diffusion training or sampler-driven inference logic + + Training behavior + ----------------- + Each batch is a tensor-only reconstruction sample. The module uses that + tensor as both the model input and the reconstruction target, computes the + reconstruction loss directly in output space, and logs the resulting scalar + loss for the active split. + + Inference behavior + ------------------ + ``forward`` applies the reconstruction network to a normalized input tensor + and returns a reconstructed normalized tensor of the same shape. This + module does not implement forecasting-specific prediction helpers or any + sampler-based inference path. + + Parameters + ---------- + network : torch.nn.Module + Reconstruction model that maps an input tensor back to the same shape. + loss_class : type[torch.nn.Module] or str, optional + Loss function class or registry name. Default is ``"mse"``. + loss_params : dict or None, optional + Keyword arguments for the loss constructor. Default is ``None``. + optimizer : Callable[..., torch.optim.Optimizer] or None, optional + Optimizer factory used by :meth:`configure_optimizers`. Default is + ``None`` (Adam over ``self.parameters()``). + lr_scheduler : Callable[..., torch.optim.lr_scheduler.LRScheduler] or None, optional + Learning-rate scheduler factory used by :meth:`configure_optimizers`. + Default is ``None``. + """ + + def __init__( + self, + network: torch.nn.Module, + loss_class: type[torch.nn.Module] | str = "mse", + loss_params: dict[str, Any] | None = None, + optimizer: Callable[..., torch.optim.Optimizer] | None = None, + lr_scheduler: Callable[..., torch.optim.lr_scheduler.LRScheduler] | None = None, + ) -> None: + super().__init__() + self.save_hyperparameters("loss_class", "loss_params") + self.network = network + self.optimizer_factory = optimizer + self.lr_scheduler_factory = lr_scheduler + self.criterion = build_loss(loss_class=loss_class, loss_params=loss_params, masked_loss=False) + + @jaxtyped(typechecker=beartype) + def forward( + self, + x: Float[torch.Tensor, "batch time channels height width"], + ) -> Float[torch.Tensor, "batch time channels height width"]: + """Run the reconstruction network. + + Parameters + ---------- + x : Float[torch.Tensor, "batch time channels height width"] + Normalized reconstruction input. + + Returns + ------- + Float[torch.Tensor, "batch time channels height width"] + Reconstructed normalized tensor. + """ + return self.network(x) + + def shared_step(self, batch: torch.Tensor, split: str = "train") -> torch.Tensor: + """Compute reconstruction loss for one batch. + + Parameters + ---------- + batch : torch.Tensor + Tensor-only reconstruction batch. + split : str, optional + Current data split. Default is ``"train"``. + + Returns + ------- + torch.Tensor + Scalar reconstruction loss. + """ + preds = self(batch).clamp(min=-1, max=1) + loss = self.criterion(preds, batch) + if isinstance(loss, tuple): + loss, log_dict = loss + self.log_dict( + log_dict, prog_bar=False, logger=True, on_step=(split == "train"), on_epoch=True, sync_dist=True + ) + self.log(f"{split}/rec_loss", loss, prog_bar=True, on_epoch=True, on_step=(split == "train"), sync_dist=True) + return loss + + def training_step(self, batch: torch.Tensor, _batch_idx: int) -> torch.Tensor: + """Execute a training step. + + Parameters + ---------- + batch : torch.Tensor + Reconstruction batch. + _batch_idx : int + Batch index supplied by Lightning. + + Returns + ------- + torch.Tensor + Training loss. + """ + return self.shared_step(batch, split="train") + + def validation_step(self, batch: torch.Tensor, _batch_idx: int) -> torch.Tensor: + """Execute a validation step. + + Parameters + ---------- + batch : torch.Tensor + Reconstruction batch. + _batch_idx : int + Batch index supplied by Lightning. + + Returns + ------- + torch.Tensor + Validation loss. + """ + return self.shared_step(batch, split="val") + + def test_step(self, batch: torch.Tensor, _batch_idx: int) -> torch.Tensor: + """Execute a test step. + + Parameters + ---------- + batch : torch.Tensor + Reconstruction batch. + _batch_idx : int + Batch index supplied by Lightning. + + Returns + ------- + torch.Tensor + Test loss. + """ + return self.shared_step(batch, split="test") + + def configure_optimizers(self) -> Any: + """Configure optimizer and optional scheduler. + + Returns + ------- + Any + PyTorch Lightning optimizer configuration. + """ + if self.optimizer_factory is not None: + optimizer = self.optimizer_factory(self.parameters()) + else: + optimizer = torch.optim.Adam(self.parameters()) + + if self.lr_scheduler_factory is not None: + lr_scheduler = self.lr_scheduler_factory(optimizer) + return {"optimizer": optimizer, "lr_scheduler": {"scheduler": lr_scheduler, "monitor": "val/rec_loss"}} + return {"optimizer": optimizer} diff --git a/src/mlcast/nowcasting_module.py b/src/mlcast/nowcasting_module.py deleted file mode 100644 index 3ef5e19..0000000 --- a/src/mlcast/nowcasting_module.py +++ /dev/null @@ -1,315 +0,0 @@ -"""Generic Lightning module for radar precipitation nowcasting. - -Wraps an injected PyTorch :class:`nn.Module` (the network architecture) and -handles training, validation, and test steps including loss computation, -ensemble generation, and image logging. -""" - -from collections.abc import Callable -from typing import Any - -import numpy as np -import pytorch_lightning as pl -import torch -from beartype import beartype -from jaxtyping import Float, jaxtyped - -from mlcast.data.normalization import DENORMALIZATION_REGISTRY, NORMALIZATION_REGISTRY -from mlcast.losses import build_loss -from mlcast.visualization import log_images - - -class NowcastLightningModule(pl.LightningModule): - """Generic PyTorch Lightning module for nowcasting. - - Wraps an injected PyTorch `nn.Module` (the network architecture) and - handles training, validation, test steps, loss computation, ensemble - generation, and TensorBoard logging. - - Parameters - ---------- - network : torch.nn.Module - The PyTorch network architecture to train. - ensemble_size : int, optional - Number of ensemble members to generate. Default is ``1``. - loss_class : type[torch.nn.Module] or str, optional - Loss function class or its string name. Default is ``"mse"``. - loss_params : dict or None, optional - Keyword arguments for the loss constructor. Default is ``None``. - masked_loss : bool, optional - Whether to wrap the loss with :class:`MaskedLoss`. Default is ``False``. - optimizer : Callable[..., torch.optim.Optimizer] or None, optional - A callable (e.g., a ``functools.partial``) that takes network parameters - and returns an instantiated optimizer. Default is ``None`` (Adam). - lr_scheduler : Callable[..., torch.optim.lr_scheduler.LRScheduler] or None, optional - A callable (e.g., a ``functools.partial``) that takes an optimizer - and returns an instantiated learning rate scheduler. Default is ``None``. - """ - - def __init__( - self, - network: torch.nn.Module, - ensemble_size: int = 1, - loss_class: type[torch.nn.Module] | str = "mse", - loss_params: dict[str, Any] | None = None, - masked_loss: bool = False, - optimizer: Callable[..., torch.optim.Optimizer] | None = None, - lr_scheduler: Callable[..., torch.optim.lr_scheduler.LRScheduler] | None = None, - ) -> None: - super().__init__() - # Explicitly save hyperparameters that are accessed later via self.hparams - self.save_hyperparameters("ensemble_size", "loss_class", "loss_params", "masked_loss") - - self.network = network - self.optimizer_factory = optimizer - self.lr_scheduler_factory = lr_scheduler - - self.criterion = build_loss( - loss_class=loss_class, - loss_params=loss_params, - masked_loss=masked_loss, - ) - self.log_images_iterations = [50, 100, 200, 500, 750, 1000, 2000, 5000] - - @jaxtyped(typechecker=beartype) - def forward( - self, - x: Float[torch.Tensor, "batch time channels height width"], - forecast_steps: int, - ensemble_size: int | None = None, - ) -> Float[torch.Tensor, "batch forecast_steps out_channels height width"]: - """Run the network forward pass. - - Parameters - ---------- - x : Float[torch.Tensor, "batch time channels height width"] - Input tensor. - forecast_steps : int - Number of steps to forecast. - ensemble_size : int or None, optional - Number of ensemble members to generate. If ``None``, uses the initialized value. Default is ``None``. - - Returns - ------- - preds : Float[torch.Tensor, "batch forecast_steps out_channels height width"] - Forecast tensor. - """ - ensemble_size = self.hparams["ensemble_size"] if ensemble_size is None else ensemble_size - return self.network(x, steps=forecast_steps, ensemble_size=ensemble_size) - - def shared_step( - self, batch: dict[str, torch.Tensor], split: str = "train", ensemble_size: int | None = None - ) -> torch.Tensor: - """Shared forward step for training, validation, and testing. - - Parameters - ---------- - batch : dict of str to torch.Tensor - A dictionary containing the batched input data. Must contain the - key ``"data"`` and optionally ``"mask"`` if ``masked_loss`` is ``True``. - split : str, optional - The data split being processed (e.g., ``"train"``, ``"val"``, ``"test"``). - Used for logging. Default is ``"train"``. - ensemble_size : int or None, optional - The number of ensemble members to generate. If ``None``, uses the - default from hyper-parameters. Default is ``None``. - - Returns - ------- - loss : torch.Tensor - The computed loss for the batch. - """ - past = batch["input"] - future = batch["target"] - forecast_steps = future.shape[1] - - preds = self(past, forecast_steps=forecast_steps, ensemble_size=ensemble_size).clamp(min=-1, max=1) - - if self.hparams["masked_loss"]: - mask = batch["target_mask"] - loss = self.criterion(preds, future, mask) - else: - loss = self.criterion(preds, future) - - if isinstance(loss, tuple): - loss, log_dict = loss - self.log_dict( - log_dict, prog_bar=False, logger=True, on_step=(split == "train"), on_epoch=True, sync_dist=True - ) - - self.log(f"{split}_loss", loss, prog_bar=True, on_epoch=True, on_step=(split == "train"), sync_dist=True) - - if self.hparams["ensemble_size"] > 1: - ensemble_std = preds.std(dim=2).mean() - self.log(f"{split}_ensemble_std", ensemble_std, on_epoch=True, sync_dist=True) - - if ( - split == "train" - and self.logger is not None - and getattr(self.logger, "experiment", None) is not None - and ( - self.global_step in self.log_images_iterations or self.global_step % self.log_images_iterations[-1] == 0 - ) - ): - log_images( - past=past, - future=future, - preds=preds, - logger=self.logger, # type: ignore - global_step=self.global_step, - ensemble_size=self.hparams["ensemble_size"], - split=split, - ) - return loss - - def training_step(self, batch: dict[str, torch.Tensor], _batch_idx: int) -> torch.Tensor: - """Execute a single training step. - - Parameters - ---------- - batch : dict of str to torch.Tensor - A dictionary containing the batched input data. - batch_idx : int - The index of the current batch. - - Returns - ------- - loss : torch.Tensor - The training loss. - """ - return self.shared_step(batch, split="train") - - def validation_step(self, batch: dict[str, torch.Tensor], _batch_idx: int) -> torch.Tensor: - """Execute a single validation step. - - Parameters - ---------- - batch : dict of str to torch.Tensor - A dictionary containing the batched input data. - batch_idx : int - The index of the current batch. - - Returns - ------- - loss : torch.Tensor - The validation loss. - """ - return self.shared_step(batch, split="val", ensemble_size=10) - - def test_step(self, batch: dict[str, torch.Tensor], _batch_idx: int) -> torch.Tensor: - """Execute a single test step. - - Parameters - ---------- - batch : dict of str to torch.Tensor - A dictionary containing the batched input data. - batch_idx : int - The index of the current batch. - - Returns - ------- - loss : torch.Tensor - The test loss. - """ - return self.shared_step(batch, split="test", ensemble_size=10) - - def configure_optimizers(self) -> Any: - """Configure the optimizer and optional learning rate scheduler. - - Returns - ------- - config : dict of str to Any - A dictionary containing the instantiated ``"optimizer"`` and - optionally ``"lr_scheduler"`` configurations for PyTorch Lightning. - """ - if self.optimizer_factory is not None: - optimizer = self.optimizer_factory(self.parameters()) - else: - optimizer = torch.optim.Adam(self.parameters()) - - if self.lr_scheduler_factory is not None: - lr_scheduler = self.lr_scheduler_factory(optimizer) - return {"optimizer": optimizer, "lr_scheduler": {"scheduler": lr_scheduler, "monitor": "val_loss"}} - else: - return {"optimizer": optimizer} - - @classmethod - def from_checkpoint(cls, checkpoint_path: str, device: str = "cpu") -> "NowcastLightningModule": - """Load a model from a checkpoint file. - - Parameters - ---------- - checkpoint_path : str - Path to the saved PyTorch Lightning checkpoint (``.ckpt``) file. - device : str, optional - The device to map the model weights to (e.g., ``"cpu"`` or ``"cuda"``). - Default is ``"cpu"``. - - Returns - ------- - model : NowcastLightningModule - The loaded PyTorch Lightning model instance. - """ - return cls.load_from_checkpoint( - checkpoint_path, - map_location=torch.device(device), - strict=True, - weights_only=False, - ) - - def predict( - self, - past: torch.Tensor, - forecast_steps: int = 1, - ensemble_size: int | None = 1, - standard_name: str = "rainfall_rate", - ) -> np.ndarray[Any, Any]: - """Generate precipitation forecasts from past radar observations. - - Input should be raw unnormalized values. - - Parameters - ---------- - past : torch.Tensor - Past radar frames as unnormalized values (e.g., mm/h or kg m-2 s-1), of shape ``(T, H, W)``. - forecast_steps : int, optional - Number of future timesteps to forecast. Default is ``1``. - ensemble_size : int, optional - Number of ensemble members. Default is ``1``. - standard_name : str, optional - The CF standard name defining the input/output domain for normalization lookup. - Default is ``"rainfall_rate"``. - - Returns - ------- - preds : np.ndarray - Forecasted unnormalized values, of shape - ``(ensemble_size, forecast_steps, H, W)``. - """ - if len(past.shape) != 3: - raise ValueError("Input must be of shape (T, H, W)") - - ensemble_size = self.hparams["ensemble_size"] if ensemble_size is None else ensemble_size - - past_clean = np.nan_to_num(past.cpu().numpy()) - past_clean = past_clean[np.newaxis, :, np.newaxis, ...] - - norm_func = NORMALIZATION_REGISTRY[standard_name] - norm_past = norm_func(past_clean) - - x = torch.from_numpy(norm_past) - x = x.to(self.device) - - self.eval() - with torch.no_grad(): - preds_tensor = self.network(x, steps=forecast_steps, ensemble_size=ensemble_size) - - preds_np: np.ndarray[Any, Any] = preds_tensor.cpu().numpy() - - denorm_func = DENORMALIZATION_REGISTRY[standard_name] - preds_np = denorm_func(preds_np) - - preds_np = preds_np.squeeze(0) - preds_np = np.swapaxes(preds_np, 0, 1) - - return preds_np diff --git a/src/mlcast/visualization.py b/src/mlcast/visualization.py index a32eaee..6bd0759 100644 --- a/src/mlcast/visualization.py +++ b/src/mlcast/visualization.py @@ -95,11 +95,12 @@ def log_images( if ensemble_size > 1: preds_avg = preds_sample.mean(dim=1, keepdim=True) num_members_to_log = min(3, preds_sample.shape[1]) - rows = [future_sample, preds_avg] + [preds_sample[:, i : i + 1] for i in range(num_members_to_log)] + rows = [future_sample.unsqueeze(1), preds_avg] + [preds_sample[:, i : i + 1] for i in range(num_members_to_log)] + all_frames = torch.cat(rows, dim=0).squeeze(1) else: - rows = [future_sample, preds_sample] + rows = [future_sample, preds_sample.squeeze(1)] + all_frames = torch.cat(rows, dim=0) - all_frames = torch.cat(rows, dim=0) all_frames_norm = (all_frames + 1) / 2 all_frames_rgb = apply_radar_colormap(all_frames_norm) preds_grid = torchvision.utils.make_grid(all_frames_rgb, nrow=future_sample.shape[0]) diff --git a/tests/config/test_cli_examples.py b/tests/config/test_cli_examples.py index 5a54c6f..156db8d 100644 --- a/tests/config/test_cli_examples.py +++ b/tests/config/test_cli_examples.py @@ -2,24 +2,45 @@ import subprocess import sys -from mlcast.__main__ import get_cli_examples, get_fiddler_examples -from mlcast.config import training_experiment +from mlcast.__main__ import get_cli_examples, get_fiddler_examples, get_included_config_names +from mlcast.config import convgru_training_experiment -def test_cli_examples_parse_correctly(): +def test_cli_examples_parse_correctly() -> None: """Verify that every CLI override example given in the help text successfully parses.""" - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() examples = get_cli_examples(cfg) + get_fiddler_examples() for _desc, cmd in examples: - # Strip out the leading `mlcast train ` if we used it, but cmd is just `--config ...` args = shlex.split(cmd) - # We can run the __main__.py module using subprocess to ensure isolated absl flag parsing - process_args = [sys.executable, "-m", "mlcast", "train"] + args + ["--only_check_args"] + process_args = [sys.executable, "-m", "mlcast", "train"] + args + ["--print_config_and_exit"] result = subprocess.run(process_args, capture_output=True, text=True) - # absl prints "unknown flag: --only_check_args" in some versions, or handles it? - # Let's check what it does. assert result.returncode == 0, f"Command '{cmd}' failed to parse:\n{result.stderr}\n{result.stdout}" + + +def test_cli_requires_explicit_config() -> None: + """Train command should fail fast when no base config is provided.""" + result = subprocess.run( + [sys.executable, "-m", "mlcast", "train"], + capture_output=True, + text=True, + ) + + assert result.returncode != 0 + assert "base config is required" in result.stderr + + +def test_cli_help_lists_included_configs() -> None: + """Help text should advertise the built-in config entry points.""" + result = subprocess.run( + [sys.executable, "-m", "mlcast", "train", "--help"], + capture_output=True, + text=True, + ) + + assert result.returncode == 0 + for name in get_included_config_names(): + assert name in result.stdout diff --git a/tests/config/test_consistency_checks.py b/tests/config/test_consistency_checks.py index 3dc8ec3..56bfbe3 100644 --- a/tests/config/test_consistency_checks.py +++ b/tests/config/test_consistency_checks.py @@ -3,16 +3,16 @@ import pytest from loguru import logger -from mlcast.config import training_experiment, validate_config -from mlcast.data.source_data_datasets import SourceDataPrecomputedSamplingDataset +from mlcast.config import convgru_training_experiment, validate_config +from mlcast.data.sequence import SourceDataPrecomputedSequenceDataset def test_contract_1_input_channels() -> None: """Verify Contract 1: Network input_channels == len(dataset_factory.standard_names).""" - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() # Break Contract 1 cfg.pl_module.network.input_channels = 2 - cfg.data.dataset_factory.standard_names = ["rainfall_rate"] + cfg.data.sequence_dataset_factory.standard_names = ["rainfall_rate"] with pytest.raises(ValueError, match="Contract 1 violated:"): validate_config(cfg) @@ -20,9 +20,9 @@ def test_contract_1_input_channels() -> None: def test_contract_2_spatial_divisibility() -> None: """Verify Contract 2: Dataset width must be divisible by 2 \\*\\* network.num_blocks.""" - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() # Break Contract 2 - cfg.data.dataset_factory.width = 250 + cfg.data.sequence_dataset_factory.width = 250 cfg.pl_module.network.num_blocks = 4 with pytest.raises(ValueError, match="Contract 2 violated:"): @@ -31,7 +31,7 @@ def test_contract_2_spatial_divisibility() -> None: def test_contract_1_and_2_warn_when_network_lacks_attrs() -> None: """Verify Contracts 1 and 2 warn when the network lacks required attrs.""" - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() cfg.pl_module.network = SimpleNamespace() messages: list[str] = [] @@ -51,9 +51,9 @@ def capture(message: object) -> None: def test_contract_3_probabilistic_loss() -> None: """Verify Contract 3: Ensemble models require CRPS or AFCRPS.""" - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() # Break Contract 3 - cfg.pl_module.ensemble_size = 5 + cfg.pl_module.network.ensemble_size = 5 cfg.pl_module.loss_class = "mse" with pytest.raises(ValueError, match="Contract 3 violated:"): @@ -62,22 +62,41 @@ def test_contract_3_probabilistic_loss() -> None: def test_contract_4_masking_sync() -> None: """Verify Contract 4: Dataset return_mask must match model masked_loss.""" - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() # Break Contract 4 - cfg.data.dataset_factory.return_mask = True + cfg.data.return_mask = True cfg.pl_module.masked_loss = False with pytest.raises(ValueError, match="Contract 4 violated:"): validate_config(cfg) -def test_dataset_forecast_steps_guard() -> None: - """Verify that dataset raises ValueError when input_steps=0.""" - with pytest.raises(ValueError, match="input_steps"): - SourceDataPrecomputedSamplingDataset( +def test_contract_5_input_steps_sync() -> None: + """Verify Contract 5: data input_steps must match model input_steps.""" + cfg = convgru_training_experiment.as_buildable() + cfg.data.input_steps = 4 + cfg.pl_module.network.input_steps = 6 + + with pytest.raises(ValueError, match="Contract 5 violated:"): + validate_config(cfg) + + +def test_contract_6_forecast_steps_sync() -> None: + """Verify Contract 6: data forecast_steps must match model forecast_steps.""" + cfg = convgru_training_experiment.as_buildable() + cfg.data.forecast_steps = 10 + cfg.pl_module.network.forecast_steps = 12 + + with pytest.raises(ValueError, match="Contract 6 violated:"): + validate_config(cfg) + + +def test_dataset_sequence_steps_guard() -> None: + """Verify that sequence dataset raises ValueError when sequence_steps=0.""" + with pytest.raises(ValueError, match="sequence_steps"): + SourceDataPrecomputedSequenceDataset( zarr_path="dummy.zarr", csv_path="dummy.csv", standard_names=["rainfall_rate"], - input_steps=0, - forecast_steps=5, + sequence_steps=0, ) diff --git a/tests/config/test_fiddlers.py b/tests/config/test_fiddlers.py index e6f443f..b6e66d5 100644 --- a/tests/config/test_fiddlers.py +++ b/tests/config/test_fiddlers.py @@ -1,28 +1,73 @@ -from mlcast.config import set_variables, toggle_masking, training_experiment +import fiddle as fdl +from mlcast.config import ( + convgru_training_experiment, + latent_diffusion_experiment, + set_variables, + toggle_masking, + use_random_sampler, + use_ratio_splits, +) +from mlcast.data.sequence import SourceDataPrecomputedSequenceDataset, SourceDataRandomSequenceDataset -def test_fiddler_set_variables(): + +def test_fiddler_set_variables() -> None: """Verify set_variables syncs dataset variables and network input_channels.""" - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() - # Apply fiddler set_variables(cfg, ["rainfall_rate", "rainfall_flux"]) - # Check sync - assert cfg.data.dataset_factory.standard_names == ["rainfall_rate", "rainfall_flux"] + assert cfg.data.sequence_dataset_factory.standard_names == ["rainfall_rate", "rainfall_flux"] assert cfg.pl_module.network.input_channels == 2 -def test_fiddler_toggle_masking(): +def test_fiddler_toggle_masking() -> None: """Verify toggle_masking syncs dataset mask return and module masked_loss.""" - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() - # Disable masking toggle_masking(cfg, False) - assert cfg.data.dataset_factory.return_mask is False + assert cfg.data.return_mask is False assert cfg.pl_module.masked_loss is False - # Enable masking toggle_masking(cfg, True) - assert cfg.data.dataset_factory.return_mask is True + assert cfg.data.return_mask is True assert cfg.pl_module.masked_loss is True + + +def test_fiddler_set_variables_on_latent_diffusion() -> None: + """Verify set_variables applies to both stages of a LatentDiffusionTrainingExperiment.""" + cfg = latent_diffusion_experiment.as_buildable() + + set_variables(cfg, ["rainfall_rate", "rainfall_flux", "rainfall_intensity"]) + + # Both stages share the same sequence_dataset_factory object + expected_names = ["rainfall_rate", "rainfall_flux", "rainfall_intensity"] + assert cfg.stage1.data.sequence_dataset_factory.standard_names == expected_names + assert cfg.stage2.data.sequence_dataset_factory.standard_names == expected_names + + # Encoder (inside AutoencoderNet) has input_channels and should be updated + assert cfg.stage1.pl_module.network.encoder.input_channels == 3 + assert cfg.stage2.pl_module.autoencoder.encoder.input_channels == 3 + + +def test_fiddler_use_random_sampler_on_latent_diffusion() -> None: + """Verify use_random_sampler applies to both stages of LatentDiffusionTrainingExperiment.""" + cfg = latent_diffusion_experiment.as_buildable() + + assert fdl.get_callable(cfg.stage1.data.sequence_dataset_factory) is SourceDataPrecomputedSequenceDataset + assert fdl.get_callable(cfg.stage2.data.sequence_dataset_factory) is SourceDataPrecomputedSequenceDataset + + use_random_sampler(cfg) + + assert fdl.get_callable(cfg.stage1.data.sequence_dataset_factory) is SourceDataRandomSequenceDataset + assert fdl.get_callable(cfg.stage2.data.sequence_dataset_factory) is SourceDataRandomSequenceDataset + + +def test_fiddler_use_ratio_splits_on_latent_diffusion() -> None: + """Verify use_ratio_splits applies to both stages of LatentDiffusionTrainingExperiment.""" + cfg = latent_diffusion_experiment.as_buildable() + + use_ratio_splits(cfg, train=0.6, val=0.2) + + assert cfg.stage1.data.splits == {"time": {"train": 0.6, "val": 0.2, "test": 0.2}} + assert cfg.stage2.data.splits == {"time": {"train": 0.6, "val": 0.2, "test": 0.2}} diff --git a/tests/config/test_latent_diffusion_experiment.py b/tests/config/test_latent_diffusion_experiment.py new file mode 100644 index 0000000..5aebdd9 --- /dev/null +++ b/tests/config/test_latent_diffusion_experiment.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass + +import fiddle as fdl + +from mlcast.config import LatentDiffusionTrainingExperiment, latent_diffusion_experiment, validate_config +from mlcast.config.base import Experiment + + +@dataclass +class RecordingTrainer: + """Minimal trainer stub that records fit/test call order.""" + + events: list[str] + + def fit(self, pl_module, datamodule=None) -> None: # type: ignore[no-untyped-def] + self.events.append(f"fit:{pl_module}:{datamodule}") + + def test(self, pl_module, datamodule=None) -> None: # type: ignore[no-untyped-def] + self.events.append(f"test:{pl_module}:{datamodule}") + + +def test_latent_diffusion_experiment_runs_stages_in_order() -> None: + """LatentDiffusionTrainingExperiment should execute stage 1 fully before stage 2.""" + events: list[str] = [] + stage1 = Experiment(pl_module="stage1_module", data="stage1_data", trainer=RecordingTrainer(events=events)) + stage2 = Experiment(pl_module="stage2_module", data="stage2_data", trainer=RecordingTrainer(events=events)) + experiment = LatentDiffusionTrainingExperiment(stage1=stage1, stage2=stage2) + + experiment.run() + + assert events == [ + "fit:stage1_module:stage1_data", + "test:stage1_module:stage1_data", + "fit:stage2_module:stage2_data", + "test:stage2_module:stage2_data", + ] + + +def test_latent_diffusion_experiment_shares_autoencoder_identity() -> None: + """Stage 1 and stage 2 should reference the same built autoencoder instance.""" + cfg = latent_diffusion_experiment.as_buildable() + validate_config(cfg) + + experiment = fdl.build(cfg) + + assert experiment.stage1.pl_module.network is experiment.stage2.pl_module.autoencoder diff --git a/tests/config/test_orchestrator.py b/tests/config/test_orchestrator.py index 42d5a36..af1786a 100644 --- a/tests/config/test_orchestrator.py +++ b/tests/config/test_orchestrator.py @@ -1,12 +1,23 @@ +from pathlib import Path +from typing import Any from unittest.mock import patch -from mlcast.config import train_from_config, training_experiment +from mlcast.config import convgru_training_experiment, latent_diffusion_experiment, train_from_config @patch("mlcast.config.orchestrator.fdl.build") -def test_train_from_config_valid(mock_build, tmp_path): +def test_train_from_config_valid(mock_build: Any, tmp_path: Path) -> None: """Verify that a valid configuration passes validation and builds.""" mock_build.return_value.trainer.log_dir = str(tmp_path) - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() + train_from_config(cfg) + mock_build.assert_called_once() + + +@patch("mlcast.config.orchestrator.fdl.build") +def test_train_from_config_valid_latent_diffusion(mock_build: Any, tmp_path: Path) -> None: + """Verify that a valid latent diffusion configuration passes validation and builds.""" + mock_build.return_value.trainer.log_dir = str(tmp_path) + cfg = latent_diffusion_experiment.as_buildable() train_from_config(cfg) mock_build.assert_called_once() diff --git a/tests/data/test_data_module.py b/tests/data/test_data_module.py index e55bf58..8968500 100644 --- a/tests/data/test_data_module.py +++ b/tests/data/test_data_module.py @@ -3,28 +3,29 @@ import pandas as pd import pytest +import torch from torch.utils.data import DataLoader, Dataset -from mlcast.data.source_data_datamodule import SourceDataDataModule +from mlcast.data.datamodules import ForecastingDataModule, ReconstructionDataModule +from mlcast.data.forecasting import ForecastingDataset +from mlcast.data.reconstruction import ReconstructionDataset from mlcast.data.splits import splitting_uses_fractions, splitting_uses_tuple_ranges, validate_splits -class MockDataset(Dataset): - """Minimal dataset mock that records how it was constructed. - - ``__len__`` returns a fixed size so that dataloader batch-count assertions - work correctly. - """ +class MockSequenceDataset(Dataset): + """Minimal sequence dataset mock that records how it was constructed.""" def __init__( self, zarr_path: str, + sequence_steps: int, subset: dict | None = None, augment: bool = False, epoch_size: int = 100, **kwargs, ) -> None: self.zarr_path = zarr_path + self.sequence_steps = sequence_steps self.subset = subset self.augment = augment self.epoch_size = epoch_size @@ -33,8 +34,9 @@ def __init__( def __len__(self) -> int: return self.epoch_size - def __getitem__(self, idx: int) -> dict: - return {"data": idx} + def __getitem__(self, idx: int) -> torch.Tensor: + base = torch.arange(self.sequence_steps, dtype=torch.float32)[:, None, None, None] + return base.expand(-1, 1, 4, 4) def _mock_zarr(time_index: pd.DatetimeIndex) -> MagicMock: @@ -109,30 +111,41 @@ def test_splitting_mode_helpers_require_consistent_values() -> None: assert not splitting_uses_tuple_ranges({"train": object(), "val": object()}) -def test_data_module_ratio_splits() -> None: +def test_forecasting_data_module_ratio_splits() -> None: """DataModule ratio mode passes correct time subsets to the factory.""" n = 100 time_index = _make_time_index(n) - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr", foo="bar") + sequence_dataset_factory = functools.partial( + MockSequenceDataset, + zarr_path="mock.zarr", + sequence_steps=6, + foo="bar", + ) - dm = SourceDataDataModule( - dataset_factory=dataset_factory, splits={"time": {"train": 0.5, "val": 0.2, "test": 0.3}}, batch_size=2 + dm = ForecastingDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=2, + forecast_steps=4, + return_mask=True, + splits={"time": {"train": 0.5, "val": 0.2, "test": 0.3}}, + batch_size=2, ) with patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)): dm.setup(stage="fit") - assert dm.train_dataset.augment is True - assert dm.train_dataset.kwargs["foo"] == "bar" - train_start, train_end = dm.train_dataset.subset["time"] - val_start, val_end = dm.val_dataset.subset["time"] + assert isinstance(dm.train_dataset, ForecastingDataset) + assert dm.train_dataset.base_sequence_dataset.augment is True + assert dm.train_dataset.base_sequence_dataset.kwargs["foo"] == "bar" + train_start, train_end = dm.train_dataset.base_sequence_dataset.subset["time"] + val_start, val_end = dm.val_dataset.base_sequence_dataset.subset["time"] assert train_start == str(time_index[0]) assert train_end == str(time_index[49]) assert val_start == str(time_index[50]) assert val_end == str(time_index[69]) - assert dm.val_dataset.augment is False + assert dm.val_dataset.base_sequence_dataset.augment is False assert dm.test_dataset is None train_dl = dm.train_dataloader() @@ -147,18 +160,27 @@ class _NoZarrPathFactory: def __call__(self, **kwargs) -> Dataset: return MagicMock(spec=Dataset) - dm = SourceDataDataModule(dataset_factory=_NoZarrPathFactory(), splits={"time": {"train": 0.7, "val": 0.15}}) + dm = ForecastingDataModule( + sequence_dataset_factory=_NoZarrPathFactory(), + input_steps=2, + forecast_steps=4, + return_mask=True, + splits={"time": {"train": 0.7, "val": 0.15}}, + ) with pytest.raises((AttributeError, KeyError)): dm.setup() -def test_data_module_fraction_splits_without_test_do_not_create_test_dataset() -> None: - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") +def test_forecasting_data_module_fraction_splits_without_test_do_not_create_test_dataset() -> None: + sequence_dataset_factory = functools.partial(MockSequenceDataset, zarr_path="mock.zarr", sequence_steps=6) time_index = _make_time_index(100) - dm = SourceDataDataModule( - dataset_factory=dataset_factory, + dm = ForecastingDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=2, + forecast_steps=4, + return_mask=True, splits={"time": {"train": 0.5, "val": 0.2}}, batch_size=2, ) @@ -171,18 +193,22 @@ def test_data_module_fraction_splits_without_test_do_not_create_test_dataset() - assert dm.test_dataset is None -def test_data_module_split_lengths_and_batches() -> None: - """Test that dataset lengths and dataloader batch counts are correct after splitting. - - Dataloader batch counts are correct after splitting. - """ +def test_forecasting_data_module_split_lengths_and_batches() -> None: n_time = 240 batch_size = 10 time_index = _make_time_index(n_time) - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr", epoch_size=10) + sequence_dataset_factory = functools.partial( + MockSequenceDataset, + zarr_path="mock.zarr", + sequence_steps=6, + epoch_size=10, + ) - dm = SourceDataDataModule( - dataset_factory=dataset_factory, + dm = ForecastingDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=2, + forecast_steps=4, + return_mask=True, splits={"time": {"train": 1 / 2, "val": 1 / 3, "test": 1 / 6}}, batch_size=batch_size, ) @@ -195,11 +221,14 @@ def test_data_module_split_lengths_and_batches() -> None: assert len(dm.test_dataloader()) == 1 -def test_data_module_datetime_splits() -> None: - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") +def test_forecasting_data_module_datetime_splits() -> None: + sequence_dataset_factory = functools.partial(MockSequenceDataset, zarr_path="mock.zarr", sequence_steps=6) - dm = SourceDataDataModule( - dataset_factory=dataset_factory, + dm = ForecastingDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=2, + forecast_steps=4, + return_mask=True, splits={ "time": { "train": ("2016-01-01", "2021-12-31"), @@ -212,36 +241,20 @@ def test_data_module_datetime_splits() -> None: dm.setup() - assert dm.train_dataset.subset == {"time": ("2016-01-01", "2021-12-31")} - assert dm.val_dataset.subset == {"time": ("2022-01-01", "2023-12-31")} + assert dm.train_dataset.base_sequence_dataset.subset == {"time": ("2016-01-01", "2021-12-31")} + assert dm.val_dataset.base_sequence_dataset.subset == {"time": ("2022-01-01", "2023-12-31")} assert dm.test_dataset is None -def test_data_module_fraction_test_split_uses_explicit_fraction() -> None: - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") - time_index = _make_time_index(100) - - dm = SourceDataDataModule( - dataset_factory=dataset_factory, - splits={"time": {"train": 0.5, "val": 0.2, "test": 0.1}}, - batch_size=2, +def test_reconstruction_data_module_wraps_sequence_splits() -> None: + sequence_dataset_factory = functools.partial( + MockSequenceDataset, zarr_path="mock.zarr", sequence_steps=5, epoch_size=5 ) - - with patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)): - dm.setup() - - assert dm.test_dataset is not None - test_start, test_end = dm.test_dataset.subset["time"] - assert test_start == str(time_index[70]) - assert test_end == str(time_index[79]) - - -def test_data_module_fit_stage_creates_only_train_and_val() -> None: - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") time_index = _make_time_index(100) - dm = SourceDataDataModule( - dataset_factory=dataset_factory, + dm = ReconstructionDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=3, splits={"time": {"train": 0.5, "val": 0.2, "test": 0.1}}, batch_size=2, ) @@ -249,81 +262,59 @@ def test_data_module_fit_stage_creates_only_train_and_val() -> None: with patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)): dm.setup(stage="fit") - assert dm.train_dataset is not None - assert dm.val_dataset is not None + assert isinstance(dm.train_dataset, ReconstructionDataset) + assert isinstance(dm.val_dataset, ReconstructionDataset) assert dm.test_dataset is None + assert dm.train_dataset.base_sequence_dataset.augment is True + assert dm.val_dataset.base_sequence_dataset.augment is False + assert dm.train_dataset[0].shape == (3, 1, 4, 4) -def test_data_module_validate_stage_creates_only_val() -> None: - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") +def test_data_module_validate_test_and_logging_paths() -> None: + sequence_dataset_factory = functools.partial(MockSequenceDataset, zarr_path="mock.zarr", sequence_steps=6) time_index = _make_time_index(100) - dm = SourceDataDataModule( - dataset_factory=dataset_factory, + dm = ForecastingDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=2, + forecast_steps=4, + return_mask=True, splits={"time": {"train": 0.5, "val": 0.2, "test": 0.1}}, batch_size=2, ) with patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)): dm.setup(stage="validate") - assert dm.train_dataset is None assert dm.val_dataset is not None assert dm.test_dataset is None - -def test_data_module_test_stage_creates_only_test() -> None: - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") - time_index = _make_time_index(100) - - dm = SourceDataDataModule( - dataset_factory=dataset_factory, - splits={"time": {"train": 0.5, "val": 0.2, "test": 0.1}}, - batch_size=2, - ) - with patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)): dm.setup(stage="test") - assert dm.train_dataset is None assert dm.val_dataset is None assert dm.test_dataset is not None - -def test_data_module_rejects_unknown_stage() -> None: - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") - dm = SourceDataDataModule( - dataset_factory=dataset_factory, - splits={"time": {"train": 0.5, "val": 0.2, "test": 0.1}}, - batch_size=2, - ) - with pytest.raises(ValueError, match="Unsupported LightningDataModule setup stage"): dm.setup(stage="predict") - -def test_data_module_logs_split_summary() -> None: - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") - time_index = _make_time_index(100) - - dm = SourceDataDataModule( - dataset_factory=dataset_factory, - splits={"time": {"train": 0.5, "val": 0.2, "test": 0.1}}, - batch_size=2, - ) - with ( patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)), - patch("mlcast.data.source_data_datamodule.logger.info") as mock_info, + patch("mlcast.data.datamodules.logger.info") as mock_info, ): dm.setup() - assert mock_info.call_count == 4 def test_data_module_unsupported_split_mode() -> None: - dataset_factory = functools.partial(MockDataset, zarr_path="mock.zarr") - dm = SourceDataDataModule(dataset_factory=dataset_factory, splits={"time": {"train": 0.7, "val": 0.15}}) + sequence_dataset_factory = functools.partial(MockSequenceDataset, zarr_path="mock.zarr", sequence_steps=6) + dm = ForecastingDataModule( + sequence_dataset_factory=sequence_dataset_factory, + input_steps=2, + forecast_steps=4, + return_mask=True, + splits={"time": {"train": 0.7, "val": 0.15}}, + ) dm.splits = {"time": {"train": object(), "val": object()}} diff --git a/tests/data/test_source_data_datasets.py b/tests/data/test_source_data_datasets.py index c9a2f2f..c5b036c 100644 --- a/tests/data/test_source_data_datasets.py +++ b/tests/data/test_source_data_datasets.py @@ -4,11 +4,23 @@ import pytest import torch import xarray as xr +from torch.utils.data import Dataset -from mlcast.data.source_data_datasets import ( - SourceDataPrecomputedSamplingDataset, - SourceDataRandomSamplingDataset, -) +from mlcast.data.forecasting import ForecastingDataset +from mlcast.data.reconstruction import ReconstructionDataset +from mlcast.data.sequence import SourceDataPrecomputedSequenceDataset, SourceDataRandomSequenceDataset + + +class MockSequenceDataset(Dataset): + def __init__(self, sequence_steps: int, num_samples: int = 2) -> None: + self.sequence_steps = sequence_steps + self.num_samples = num_samples + + def __len__(self) -> int: + return self.num_samples + + def __getitem__(self, idx: int) -> torch.Tensor: + return torch.arange(self.sequence_steps, dtype=torch.float32)[:, None, None, None].expand(-1, 1, 2, 2) @pytest.fixture @@ -26,110 +38,75 @@ def mock_csv(tmp_path: Path) -> str: return str(csv_path) -def test_precomputed_sampling_dataset(fp_test_dataset: Path, mock_csv: str) -> None: - """Test that SourceDataPrecomputedSamplingDataset outputs the correct shape.""" - input_steps = 2 - forecast_steps = 1 - ds = SourceDataPrecomputedSamplingDataset( +def test_precomputed_sequence_dataset(fp_test_dataset: Path, mock_csv: str) -> None: + """Precomputed sequence dataset should output normalized sequence tensors.""" + sequence_steps = 3 + ds = SourceDataPrecomputedSequenceDataset( zarr_path=str(fp_test_dataset), csv_path=mock_csv, standard_names=["rainfall_flux"], - input_steps=input_steps, - forecast_steps=forecast_steps, + sequence_steps=sequence_steps, width=16, height=16, - return_mask=True, ) assert len(ds) == 3 sample = ds[0] + assert sample.shape == (sequence_steps, 1, 16, 16) + assert sample.dtype == torch.float32 - assert "input" in sample - assert "target" in sample - assert "target_mask" in sample - - input_t = sample["input"] - target_t = sample["target"] - target_mask_t = sample["target_mask"] - assert input_t.shape == (input_steps, 1, 16, 16) - assert target_t.shape == (forecast_steps, 1, 16, 16) - assert target_mask_t.shape == (forecast_steps, 1, 16, 16) - assert isinstance(input_t, torch.Tensor) - assert isinstance(target_t, torch.Tensor) - assert isinstance(target_mask_t, torch.Tensor) - - -def test_precomputed_sampling_dataset_time_subset(fp_test_dataset: Path, mock_csv: str) -> None: - """Test that subset correctly filters CSV rows by time range.""" +def test_precomputed_sequence_dataset_time_subset(fp_test_dataset: Path, mock_csv: str) -> None: + """Subset should correctly filter CSV rows by time range.""" zarr_ds = xr.open_zarr(str(fp_test_dataset)) time_index = zarr_ds.indexes["time"] - ds = SourceDataPrecomputedSamplingDataset( + ds = SourceDataPrecomputedSequenceDataset( zarr_path=str(fp_test_dataset), csv_path=mock_csv, standard_names=["rainfall_flux"], - input_steps=2, - forecast_steps=1, + sequence_steps=3, subset={"time": (str(time_index[0]), str(time_index[8]))}, ) assert len(ds) == 2 -def test_precomputed_sampling_dataset_forecast_steps_guard(fp_test_dataset: Path, mock_csv: str) -> None: - """Test that instantiation with input_steps=0 raises ValueError.""" - with pytest.raises(ValueError, match="input_steps"): - SourceDataPrecomputedSamplingDataset( +def test_precomputed_sequence_dataset_sequence_steps_guard(fp_test_dataset: Path, mock_csv: str) -> None: + """Instantiation with sequence_steps=0 should raise ValueError.""" + with pytest.raises(ValueError, match="sequence_steps"): + SourceDataPrecomputedSequenceDataset( zarr_path=str(fp_test_dataset), csv_path=mock_csv, standard_names=["rainfall_flux"], - input_steps=0, - forecast_steps=3, + sequence_steps=0, ) -def test_random_sampling_dataset(fp_test_dataset: Path) -> None: - """Test that SourceDataRandomSamplingDataset outputs the correct shape.""" - input_steps = 3 - forecast_steps = 2 - ds = SourceDataRandomSamplingDataset( +def test_random_sequence_dataset(fp_test_dataset: Path) -> None: + """Random sequence dataset should output normalized sequence tensors.""" + sequence_steps = 5 + ds = SourceDataRandomSequenceDataset( zarr_path=str(fp_test_dataset), standard_names=["rainfall_flux"], - input_steps=input_steps, - forecast_steps=forecast_steps, + sequence_steps=sequence_steps, width=32, height=32, epoch_size=10, - return_mask=True, ) assert len(ds) == 10 sample = ds[0] + assert sample.shape == (sequence_steps, 1, 32, 32) + assert sample.dtype == torch.float32 - assert "input" in sample - assert "target" in sample - assert "target_mask" in sample - - input_t = sample["input"] - target_t = sample["target"] - target_mask_t = sample["target_mask"] - - assert input_t.shape == (input_steps, 1, 32, 32) - assert target_t.shape == (forecast_steps, 1, 32, 32) - assert target_mask_t.shape == (forecast_steps, 1, 32, 32) - assert input_t.dtype == torch.float32 - assert target_t.dtype == torch.float32 - assert target_mask_t.dtype == torch.float32 - -def test_random_sampling_dataset_time_subset(fp_test_dataset: Path) -> None: - """Test that subset correctly slices the Zarr store.""" +def test_random_sequence_dataset_time_subset(fp_test_dataset: Path) -> None: + """Subset should correctly slice the Zarr store.""" zarr_ds = xr.open_zarr(str(fp_test_dataset)) time_index = zarr_ds.indexes["time"] - ds = SourceDataRandomSamplingDataset( + ds = SourceDataRandomSequenceDataset( zarr_path=str(fp_test_dataset), standard_names=["rainfall_flux"], - input_steps=3, - forecast_steps=2, + sequence_steps=5, subset={"time": (str(time_index[0]), str(time_index[49]))}, epoch_size=10, ) @@ -138,12 +115,26 @@ def test_random_sampling_dataset_time_subset(fp_test_dataset: Path) -> None: assert len(ds) == 10 -def test_random_sampling_dataset_forecast_steps_guard(fp_test_dataset: Path) -> None: - """Test that instantiation with input_steps=0 raises ValueError.""" - with pytest.raises(ValueError, match="input_steps"): - SourceDataRandomSamplingDataset( - zarr_path=str(fp_test_dataset), - standard_names=["rainfall_flux"], - input_steps=0, - forecast_steps=5, - ) +def test_forecasting_dataset_splits_sequence_and_derives_mask() -> None: + """ForecastingDataset should split one sequence into input and target tensors.""" + base_dataset = MockSequenceDataset(sequence_steps=6) + dataset = ForecastingDataset(base_dataset, input_steps=2, forecast_steps=4, return_mask=True) + + sample = dataset[0] + assert sample["input"].shape == (2, 1, 2, 2) + assert sample["target"].shape == (4, 1, 2, 2) + assert sample["target_mask"].shape == (4, 1, 2, 2) + assert torch.all(sample["target_mask"] == 1.0) + + +def test_reconstruction_dataset_creates_overlapping_windows() -> None: + """ReconstructionDataset should expose all overlapping windows.""" + base_dataset = MockSequenceDataset(sequence_steps=5, num_samples=2) + dataset = ReconstructionDataset(base_dataset, input_steps=3) + + assert len(dataset) == 6 + first_window = dataset[0] + second_window = dataset[1] + assert first_window.shape == (3, 1, 2, 2) + assert torch.equal(first_window[:, 0, 0, 0], torch.tensor([0.0, 1.0, 2.0])) + assert torch.equal(second_window[:, 0, 0, 0], torch.tensor([1.0, 2.0, 3.0])) diff --git a/tests/models/test_autoencoder.py b/tests/models/test_autoencoder.py new file mode 100644 index 0000000..dd5de50 --- /dev/null +++ b/tests/models/test_autoencoder.py @@ -0,0 +1,90 @@ +import torch +import torch.nn.functional as F + +from mlcast.models.autoencoder import AutoencoderNet, Decoder, Encoder + + +def test_encoder_output_shape() -> None: + """Encoder should preserve time and downsample spatial dimensions.""" + batch_size = 2 + input_steps = 4 + channels = 1 + height = 16 + width = 16 + latent_channels = 6 + + encoder = Encoder(input_channels=channels, hidden_channels=4, latent_channels=latent_channels, num_blocks=2) + x = torch.randn(batch_size, input_steps, channels, height, width) + + z = encoder(x) + + assert z.shape == (batch_size, latent_channels, input_steps, height // 4, width // 4) + + +def test_decoder_output_shape() -> None: + """Decoder should preserve time and upsample spatial dimensions.""" + batch_size = 2 + input_steps = 4 + channels = 1 + latent_channels = 6 + latent_height = 4 + latent_width = 4 + + decoder = Decoder(output_channels=channels, hidden_channels=4, latent_channels=latent_channels, num_blocks=2) + z = torch.randn(batch_size, latent_channels, input_steps, latent_height, latent_width) + + y = decoder(z) + + assert y.shape == (batch_size, input_steps, channels, latent_height * 4, latent_width * 4) + + +def test_autoencoder_reconstruction_forward_pass() -> None: + """Autoencoder should reconstruct tensors with the same shape as its input.""" + batch_size = 2 + input_steps = 3 + channels = 2 + height = 16 + width = 16 + + encoder = Encoder(input_channels=channels, hidden_channels=4, latent_channels=8, num_blocks=2) + decoder = Decoder(output_channels=channels, hidden_channels=4, latent_channels=8, num_blocks=2) + model = AutoencoderNet(encoder=encoder, decoder=decoder) + x = torch.randn(batch_size, input_steps, channels, height, width) + + y = model(x) + + assert y.shape == x.shape + + +def test_autoencoder_improves_reconstruction_loss() -> None: + """Autoencoder should reduce reconstruction loss on a tiny generated dataset.""" + torch.manual_seed(42) + batch_size = 8 + input_steps = 2 + channels = 1 + height = 8 + width = 8 + + encoder = Encoder(input_channels=channels, hidden_channels=4, latent_channels=4, num_blocks=1) + decoder = Decoder(output_channels=channels, hidden_channels=4, latent_channels=4, num_blocks=1) + model = AutoencoderNet(encoder=encoder, decoder=decoder) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-2) + + spatial_pattern = torch.linspace(-1.0, 1.0, height * width).reshape(1, 1, 1, height, width) + temporal_scale = torch.linspace(0.5, 1.5, input_steps).reshape(1, input_steps, 1, 1, 1) + samples = spatial_pattern * temporal_scale + samples = samples.repeat(batch_size, 1, channels, 1, 1) + + with torch.no_grad(): + initial_loss = F.mse_loss(model(samples), samples).item() + + for _ in range(40): + optimizer.zero_grad(set_to_none=True) + loss = F.mse_loss(model(samples), samples) + loss.backward() + optimizer.step() + + with torch.no_grad(): + final_loss = F.mse_loss(model(samples), samples).item() + + assert final_loss < initial_loss diff --git a/tests/models/test_convgru.py b/tests/models/test_convgru.py index a751320..0494968 100644 --- a/tests/models/test_convgru.py +++ b/tests/models/test_convgru.py @@ -3,7 +3,7 @@ from mlcast.models.convgru import ConvGruModel -def test_convgru_dynamic_padding(): +def test_convgru_dynamic_padding() -> None: """Verify that ConvGruModel dynamically pads non-power-of-2 inputs and crops the output.""" # Given an input with awkward spatial dimensions batch_size = 2 @@ -18,20 +18,19 @@ def test_convgru_dynamic_padding(): x = torch.randn(batch_size, time_steps, channels, height, width) - model = ConvGruModel(input_channels=channels, num_blocks=4) - model.eval() - forecast_steps = 4 + model = ConvGruModel(input_steps=time_steps, forecast_steps=forecast_steps, input_channels=channels, num_blocks=4) + model.eval() with torch.no_grad(): - preds = model(x, steps=forecast_steps, ensemble_size=1) + preds = model(x) - # Check that it didn't crash and the output shape is exactly (batch, steps, channels, height, width) - # The single ensemble member case returns out_channels = channels. - assert preds.shape == (batch_size, forecast_steps, channels, height, width) + # Check that it didn't crash and the output shape is exactly (batch, steps, 1, channels, height, width) + # The single ensemble member case adds an explicit ensemble dimension. + assert preds.shape == (batch_size, forecast_steps, 1, channels, height, width) -def test_convgru_dynamic_padding_ensemble(): +def test_convgru_dynamic_padding_ensemble() -> None: """Verify that ConvGruModel dynamically pads non-power-of-2 inputs and crops the output for ensemble generation.""" # Given an input with awkward spatial dimensions batch_size = 1 @@ -42,17 +41,33 @@ def test_convgru_dynamic_padding_ensemble(): x = torch.randn(batch_size, time_steps, channels, height, width) - model = ConvGruModel(input_channels=channels, num_blocks=3) - model.eval() - forecast_steps = 2 ensemble_size = 5 + model = ConvGruModel( + input_steps=time_steps, + forecast_steps=forecast_steps, + ensemble_size=ensemble_size, + input_channels=channels, + num_blocks=3, + ) + model.eval() with torch.no_grad(): - preds = model(x, steps=forecast_steps, ensemble_size=ensemble_size) + preds = model(x) + + # Check that it didn't crash and the output shape has an explicit ensemble dimension: + # (batch, forecast_steps, ensemble_size, channels, height, width) + assert preds.shape == (batch_size, forecast_steps, ensemble_size, channels, height, width) + + +def test_convgru_rejects_wrong_input_steps() -> None: + """ConvGruModel should reject inputs that violate its configured input length.""" + model = ConvGruModel(input_steps=3, forecast_steps=2, input_channels=1, num_blocks=2) + x = torch.randn(1, 2, 1, 32, 32) - # Check that it didn't crash and the output shape is exactly (batch, steps, ensemble_size * channels, height, width) - # Actually wait: The decoder block outputs the same number of channels as the final upsampling step. - # In the `ConvGruModel.forward` with `ensemble_size > 1`, `out` is `torch.cat(preds, dim=2)`. - # Let's verify the exact channel dimension. The original output channels per ensemble member is `channels`. - assert preds.shape == (batch_size, forecast_steps, channels * ensemble_size, height, width) + try: + model(x) + except ValueError as exc: + assert "Expected 3 input timesteps" in str(exc) + else: + raise AssertionError("Expected ConvGruModel to reject wrong input_steps") diff --git a/tests/models/test_diffusion.py b/tests/models/test_diffusion.py new file mode 100644 index 0000000..ac9258f --- /dev/null +++ b/tests/models/test_diffusion.py @@ -0,0 +1,74 @@ +import torch + +from mlcast.models.diffusion import ( + ConditionerNet, + DenoiserUNet, + DiffusionLoss, + DiffusionScheduler, + LatentDiffusionNet, +) + + +def _build_diffusion_net(latent_channels: int = 1, hidden_channels: int = 8, timesteps: int = 4) -> LatentDiffusionNet: + conditioner = ConditionerNet(latent_channels=latent_channels, hidden_channels=hidden_channels, num_blocks=1) + denoiser = DenoiserUNet( + latent_channels=latent_channels, + condition_channels=hidden_channels, + hidden_channels=hidden_channels, + num_blocks=1, + ) + scheduler = DiffusionScheduler(timesteps=timesteps) + return LatentDiffusionNet(conditioner=conditioner, denoiser=denoiser, scheduler=scheduler) + + +def test_latent_diffusion_net_api() -> None: + """LatentDiffusionNet should predict noise with the target latent shape.""" + input_time = 2 + forecast_steps = 3 + latent_channels = 1 + height = 4 + width = 4 + diffusion_net = _build_diffusion_net(latent_channels=latent_channels, hidden_channels=4, timesteps=2) + noised_target = torch.randn(2, latent_channels, forecast_steps, height, width) + input_latents = torch.randn(2, latent_channels, input_time, height, width) + timesteps = torch.zeros(2, dtype=torch.long) + + with torch.no_grad(): + predicted_noise = diffusion_net(noised_target, timesteps, input_latents) + + assert predicted_noise.shape == noised_target.shape + + +def test_diffusion_model_improves_loss_on_generated_latents() -> None: + """Diffusion model should reduce noise-prediction loss on generated latents.""" + torch.manual_seed(7) + batch_size = 8 + latent_channels = 1 + input_time = 2 + forecast_time = 3 + height = 4 + width = 4 + diffusion_net = _build_diffusion_net(latent_channels=latent_channels, hidden_channels=8, timesteps=1) + loss_fn = DiffusionLoss(diffusion_net) + optimizer = torch.optim.Adam(diffusion_net.parameters(), lr=5e-3) + + input_latents = torch.randn(batch_size, latent_channels, input_time, height, width) + target_base = input_latents.mean(dim=2, keepdim=True) + target_latents = target_base.repeat(1, 1, forecast_time, 1, 1) + target_latents = target_latents + 0.05 * torch.randn_like(target_latents) + + torch.manual_seed(42) + with torch.no_grad(): + initial_loss = loss_fn(input_latents, target_latents).item() + + for _ in range(80): + optimizer.zero_grad(set_to_none=True) + loss = loss_fn(input_latents, target_latents) + loss.backward() + optimizer.step() + + torch.manual_seed(42) + with torch.no_grad(): + final_loss = loss_fn(input_latents, target_latents).item() + + assert final_loss < initial_loss diff --git a/tests/test_cli_training.py b/tests/test_cli_training.py index 4f44da7..453dd79 100644 --- a/tests/test_cli_training.py +++ b/tests/test_cli_training.py @@ -4,7 +4,7 @@ from fiddle._src.experimental.yaml_serialization import dump_yaml -from mlcast.config import training_experiment +from mlcast.config import convgru_training_experiment from mlcast.config.fiddlers import use_random_sampler @@ -21,11 +21,13 @@ def test_cli_train_command(fp_test_dataset: Path, tmp_path: Path) -> None: "mlcast", "train", "--config", + "config:convgru_training_experiment", + "--config", "fiddler:use_random_sampler", "--config", - f"set:data.dataset_factory.zarr_path='{fp_test_dataset.absolute()}'", + f"set:data.sequence_dataset_factory.zarr_path='{fp_test_dataset.absolute()}'", "--config", - "set:data.dataset_factory.standard_names=['rainfall_flux']", + "set:data.sequence_dataset_factory.standard_names=['rainfall_flux']", "--config", "set:data.splits={'time': {'train': 0.4, 'val': 0.3, 'test': 0.3}}", "--config", @@ -54,11 +56,11 @@ def test_cli_train_from_yaml_config(fp_test_dataset: Path, tmp_path: Path) -> No the dataset path) before dumping to YAML, so the subprocess call needs no additional --config flags. This exercises the pure load-from-YAML path. """ - cfg = training_experiment.as_buildable() + cfg = convgru_training_experiment.as_buildable() # Switch to random sampler (no CSV required) and use the correct variable name use_random_sampler(cfg) - cfg.data.dataset_factory.standard_names = ["rainfall_flux"] - cfg.data.dataset_factory.zarr_path = str(fp_test_dataset.absolute()) + cfg.data.sequence_dataset_factory.standard_names = ["rainfall_flux"] + cfg.data.sequence_dataset_factory.zarr_path = str(fp_test_dataset.absolute()) cfg.data.splits = {"time": {"train": 0.4, "val": 0.3, "test": 0.3}} cfg.trainer.fast_dev_run = True cfg.data.batch_size = 1 diff --git a/tests/test_nowcasting_module.py b/tests/test_nowcasting_module.py new file mode 100644 index 0000000..6bf49c8 --- /dev/null +++ b/tests/test_nowcasting_module.py @@ -0,0 +1,43 @@ +import numpy as np +import torch + +from mlcast.modules.forecasting import OutputSpaceForecastingTaskModule + + +class DummyForecastNetwork(torch.nn.Module): + """Minimal fixed-shape forecasting network for module tests.""" + + def __init__(self, input_steps: int, forecast_steps: int, ensemble_size: int = 1) -> None: + super().__init__() + self.input_steps = input_steps + self.forecast_steps = forecast_steps + self.ensemble_size = ensemble_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + batch_size, _, channels, height, width = x.shape + return torch.zeros( + batch_size, self.forecast_steps, self.ensemble_size, channels, height, width, device=x.device + ) + + +def test_nowcasting_module_forward_uses_network_shape_contract() -> None: + """OutputSpaceForecastingTaskModule should call fixed-shape forecasting networks as network(x).""" + network = DummyForecastNetwork(input_steps=3, forecast_steps=5, ensemble_size=2) + module = OutputSpaceForecastingTaskModule(network=network, loss_class="crps") + x = torch.randn(4, 3, 1, 8, 8) + + preds = module(x) + + assert preds.shape == (4, 5, 2, 1, 8, 8) + + +def test_nowcasting_module_predict_uses_configured_output_shape() -> None: + """Prediction horizon and ensemble size should come from the configured network.""" + network = DummyForecastNetwork(input_steps=3, forecast_steps=4, ensemble_size=2) + module = OutputSpaceForecastingTaskModule(network=network, loss_class="crps") + past = torch.ones(3, 8, 8) + + preds = module.predict(past, standard_name="rainfall_rate") + + assert isinstance(preds, np.ndarray) + assert preds.shape == (2, 4, 1, 8, 8) diff --git a/tests/test_readme_snippets.py b/tests/test_readme_snippets.py index 7dc3ace..50ba094 100644 --- a/tests/test_readme_snippets.py +++ b/tests/test_readme_snippets.py @@ -13,7 +13,7 @@ import fiddle as fdl import pytest -from mlcast.config.fiddlers import set_variables, use_random_sampler +from mlcast.config.fiddlers import _iter_experiment_configs, set_variables, use_random_sampler _README = Path(__file__).parent.parent / "README.md" @@ -111,24 +111,43 @@ def _patch_cfg(cfg: fdl.Config, fp_dataset: Path, tmp_path: Path) -> None: Uses the ``set_variables`` fiddler (rather than direct assignment) so that ``network.input_channels`` is kept in sync with ``standard_names``. + Handles both flat ``Experiment`` configs and nested containers like + ``LatentDiffusionTrainingExperiment`` by finding all ``Experiment`` sub-configs + in the tree and patching each one. + Parameters ---------- cfg : fdl.Config - The Fiddle configuration graph to mutate in-place. + The Fiddle configuration to mutate in-place. + fp_dataset : Path + Local path to the cached test zarr store. + tmp_path : Path + Pytest-provided temporary directory for trainer outputs. + """ + for exp_cfg in _iter_experiment_configs(cfg): + _patch_single_experiment(exp_cfg, fp_dataset, tmp_path) + + +def _patch_single_experiment(exp_cfg: fdl.Config, fp_dataset: Path, tmp_path: Path) -> None: + """Apply lightweight training overrides to a single ``Experiment`` config. + + Parameters + ---------- + exp_cfg : fdl.Config + A single ``Experiment`` config node (has ``data``, ``trainer``). fp_dataset : Path Local path to the cached test zarr store. tmp_path : Path Pytest-provided temporary directory for trainer outputs. """ - cfg.data.dataset_factory.zarr_path = str(fp_dataset.absolute()) - set_variables(cfg, standard_names=["rainfall_flux"]) - # Switch to the on-the-fly random sampler so no pre-computed CSV is needed. - use_random_sampler(cfg) - cfg.data.splits = {"time": {"train": 0.4, "val": 0.3, "test": 0.3}} - cfg.trainer.fast_dev_run = True - cfg.data.batch_size = 1 - cfg.data.num_workers = 0 - cfg.trainer.default_root_dir = str(tmp_path) + exp_cfg.data.sequence_dataset_factory.zarr_path = str(fp_dataset.absolute()) + set_variables(exp_cfg, standard_names=["rainfall_flux"]) + use_random_sampler(exp_cfg) + exp_cfg.data.splits = {"time": {"train": 0.4, "val": 0.3, "test": 0.3}} + exp_cfg.trainer.fast_dev_run = True + exp_cfg.data.batch_size = 1 + exp_cfg.data.num_workers = 0 + exp_cfg.trainer.default_root_dir = str(tmp_path) def _inject_patch(snippet: str) -> ast.Module: diff --git a/tests/test_task_modules.py b/tests/test_task_modules.py new file mode 100644 index 0000000..d8275f7 --- /dev/null +++ b/tests/test_task_modules.py @@ -0,0 +1,98 @@ +import numpy as np +import torch + +from mlcast.models.autoencoder import AutoencoderNet, Decoder, Encoder +from mlcast.models.diffusion import ConditionerNet, DenoiserUNet, DiffusionScheduler, LatentDiffusionNet +from mlcast.modules.forecasting import LatentDiffusionTaskModule, OutputSpaceForecastingTaskModule +from mlcast.modules.reconstruction import ReconstructionTaskModule + + +class IdentityReconstructionNetwork(torch.nn.Module): + """Minimal reconstruction network used in wrapper tests.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Return the input unchanged.""" + return x + + +def test_reconstruction_module_uses_batch_as_target() -> None: + """ReconstructionTaskModule should compute loss against the input batch itself.""" + module = ReconstructionTaskModule(network=IdentityReconstructionNetwork(), loss_class="mse") + batch = torch.randn(2, 3, 1, 4, 4) + + loss = module.training_step(batch, 0) + + assert torch.isfinite(loss) + assert loss.ndim == 0 + + +def test_output_space_forecasting_task_module_trainable_parameters_match_network() -> None: + """OutputSpaceForecastingTaskModule should optimize the forecasting network parameters.""" + + class TinyForecastNetwork(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + network = TinyForecastNetwork() + module = OutputSpaceForecastingTaskModule(network=network, loss_class="mse") + + assert module.trainable_parameters == list(network.parameters()) + + +def _build_autoencoder() -> AutoencoderNet: + encoder = Encoder(input_channels=1, hidden_channels=4, latent_channels=4, num_blocks=1) + decoder = Decoder(output_channels=1, hidden_channels=4, latent_channels=4, num_blocks=1) + return AutoencoderNet(encoder=encoder, decoder=decoder) + + +def _build_diffusion_net() -> LatentDiffusionNet: + conditioner = ConditionerNet(latent_channels=4, hidden_channels=8, num_blocks=1) + denoiser = DenoiserUNet(latent_channels=4, condition_channels=8, hidden_channels=8, num_blocks=2) + scheduler = DiffusionScheduler(timesteps=2) + return LatentDiffusionNet(conditioner=conditioner, denoiser=denoiser, scheduler=scheduler) + + +def test_latent_diffusion_module_training_step_runs() -> None: + """LatentDiffusionTaskModule should encode forecasting batches and return scalar loss.""" + module = LatentDiffusionTaskModule( + autoencoder=_build_autoencoder(), diffusion_net=_build_diffusion_net(), forecast_steps=3 + ) + batch = { + "input": torch.randn(2, 2, 1, 8, 8), + "target": torch.randn(2, 3, 1, 8, 8), + } + + loss = module.training_step(batch, 0) + + assert torch.isfinite(loss) + assert loss.ndim == 0 + + +def test_latent_diffusion_task_module_trainable_parameters_exclude_autoencoder() -> None: + """LatentDiffusionTaskModule should optimize only diffusion-net parameters.""" + autoencoder = _build_autoencoder() + diffusion_net = _build_diffusion_net() + module = LatentDiffusionTaskModule(autoencoder=autoencoder, diffusion_net=diffusion_net, forecast_steps=3) + + assert module.trainable_parameters == list(diffusion_net.parameters()) + assert module.trainable_parameters != list(autoencoder.parameters()) + + +def test_latent_diffusion_module_predict_uses_configured_output_shape() -> None: + """LatentDiffusionTaskModule prediction should decode configured ensemble forecasts.""" + module = LatentDiffusionTaskModule( + autoencoder=_build_autoencoder(), + diffusion_net=_build_diffusion_net(), + forecast_steps=3, + ensemble_size=2, + ) + past = torch.ones(2, 8, 8) + + preds = module.predict(past, standard_name="rainfall_rate") + + assert isinstance(preds, np.ndarray) + assert preds.shape == (2, 3, 1, 8, 8) diff --git a/uv.lock b/uv.lock index a0a4af5..fc5a291 100644 --- a/uv.lock +++ b/uv.lock @@ -1026,6 +1026,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/d5/c5db1ea3394c6e1732fb3286b3bd878b59507a8f77d32a2cebda7d7b7cd4/donfig-0.8.1.post1-py3-none-any.whl", hash = "sha256:2a3175ce74a06109ff9307d90a230f81215cbac9a751f4d1c6194644b8204f9d", size = 21592, upload-time = "2024-05-23T14:13:55.283Z" }, ] +[[package]] +name = "einops" +version = "0.8.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/77/850bef8d72ffb9219f0b1aac23fbc1bf7d038ee6ea666f331fa273031aa2/einops-0.8.2.tar.gz", hash = "sha256:609da665570e5e265e27283aab09e7f279ade90c4f01bcfca111f3d3e13f2827", size = 56261, upload-time = "2026-01-26T04:13:17.638Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/09/f8d8f8f31e4483c10a906437b4ce31bdf3d6d417b73fe33f1a8b59e34228/einops-0.8.2-py3-none-any.whl", hash = "sha256:54058201ac7087911181bfec4af6091bb59380360f069276601256a76af08193", size = 65638, upload-time = "2026-01-26T04:13:18.546Z" }, +] + [[package]] name = "etils" version = "1.14.0" @@ -2237,6 +2246,7 @@ dependencies = [ { name = "absl-py" }, { name = "beartype" }, { name = "cf-xarray" }, + { name = "einops" }, { name = "etils" }, { name = "fiddle" }, { name = "fire" }, @@ -2286,6 +2296,7 @@ gpu-cu130 = [ [package.dev-dependencies] dev = [ + { name = "pre-commit" }, { name = "pytest" }, ] @@ -2295,6 +2306,7 @@ requires-dist = [ { name = "aiohttp", marker = "extra == 'dev'", specifier = ">=3.9.3" }, { name = "beartype", specifier = ">=0.18" }, { name = "cf-xarray", specifier = ">=0.10" }, + { name = "einops", specifier = ">=0.8" }, { name = "etils", specifier = ">=1.13" }, { name = "fiddle", specifier = ">=0.3" }, { name = "fire", specifier = ">=0.7" }, @@ -2331,7 +2343,10 @@ requires-dist = [ provides-extras = ["dev", "gpu-cu128", "gpu-cu130"] [package.metadata.requires-dev] -dev = [{ name = "pytest", specifier = ">=9.0.3" }] +dev = [ + { name = "pre-commit", specifier = ">=4.6" }, + { name = "pytest", specifier = ">=9.0.3" }, +] mlflow = [] [[package]]