Latent diffusion with fiddle configuration#16

Open
leifdenby wants to merge 37 commits into
mlcast-community:mainfrom
leifdenby:wip/ldcast-fiddle-refactor
Conversation

@leifdenby

@leifdenby leifdenby commented Jun 4, 2026

Member

I'm making this PR to share my progress on refactoring the great work by @martinbo-meteo on #4 to fit better into the fiddle-based configuration and class hierarchy already in mlcast.

I am intentionally calling this latent diffusion and not LDCast because what I have created so far is not a true reproduction of the architecture and training procedure defined in the LDCast paper. Instead it is an attempt to structure how we in general could train and run inference on latent diffusion models in mlcast.

Brief list of changes and additions (relative to main)

Data pipeline

  • New sequence-first data layer with SourceDataSequenceDatasetBase, SourceDataPrecomputedSequenceDataset, SourceDataRandomSequenceDataset replacing the old source-data datasets
  • ForecastingDataset: splits sequences into input/target pairs for forecasting
  • ReconstructionDataset: sliding-window wrapper on sequence datasets for autoencoder training
  • ForecastingDataModule / ReconstructionDataModule replacing SourceDataDataModule
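
A minimal pure-Python sketch of the two wrappers described above, for illustration only. The function names `split_forecasting_sample` and `sliding_windows` are hypothetical and are not the mlcast API; real samples would be tensors rather than lists.

```python
from typing import List, Sequence, Tuple


def split_forecasting_sample(
    sequence: Sequence[float], input_steps: int, forecast_steps: int
) -> Tuple[List[float], List[float]]:
    """Split one sequence into an (input, target) pair for forecasting."""
    if len(sequence) < input_steps + forecast_steps:
        raise ValueError("sequence is shorter than input_steps + forecast_steps")
    inputs = list(sequence[:input_steps])
    targets = list(sequence[input_steps : input_steps + forecast_steps])
    return inputs, targets


def sliding_windows(
    sequence: Sequence[float], window: int, stride: int = 1
) -> List[List[float]]:
    """Cut one sequence into fixed-length windows for reconstruction training."""
    last_start = len(sequence) - window
    return [list(sequence[i : i + window]) for i in range(0, last_start + 1, stride)]


# Toy "sequence" of six time steps (hypothetical data).
frames = [0, 1, 2, 3, 4, 5]
inputs, targets = split_forecasting_sample(frames, input_steps=4, forecast_steps=2)
windows = sliding_windows(frames, window=3, stride=2)
```

Both helpers operate on the same underlying sequence, which is the point of a sequence-first data layer: the forecasting and reconstruction views differ only in how a sequence is sliced.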

Model architecture

  • Autoencoder: AutoencoderNet + Encoder / Decoder submodules in models/autoencoder/
  • Latent diffusion: LatentDiffusionNet, DenoiserUNet, ConditionerNet, DiffusionScheduler, DiffusionLoss, DiffusionSampler, ExponentialMovingAverage in models/diffusion/
  • Explicit ensemble dimension (B, T, M, C, H, W) throughout the forecasting pipeline
  • Refactored ConvGRU to respect the new fixed-shape model contract (input_steps, forecast_steps, ensemble_size at init)

Training modules (modules/)

  • BaseForecastingTaskModule — shared Lightning plumbing for all forecasting tasks
  • OutputSpaceForecastingTaskModule — direct forecasting (ConvGRU, custom networks)
  • ReconstructionTaskModule — autoencoder training
  • LatentDiffusionTaskModule — two-stage latent diffusion with frozen encoder, EMA support, latent-space loss

Configuration system

  • Fiddle @auto_config entry points: convgru_training_experiment, latent_diffusion_experiment
  • Experiment dataclass as the flat config schema
  • LatentDiffusionTrainingExperiment — two-stage config sharing autoencoder identity across stages
  • @applies_to_experiments decorator so fiddlers automatically handle nested multi-stage configs
  • Config consistency checks for both single and two-stage experiments
  • Config diagram generation (Graphviz SVG)
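
To illustrate the idea behind a decorator like @applies_to_experiments, here is a rough sketch using plain dataclasses instead of Fiddle configs. Everything in it (`Experiment`, `TwoStageExperiment`, `iter_experiments`, `set_batch_size`) is a hypothetical stand-in, not the implementation in this PR.

```python
from dataclasses import dataclass, fields, is_dataclass
from functools import wraps
from typing import Any, Callable, Iterator


@dataclass
class Experiment:
    """Stand-in for a flat single-stage experiment config."""
    name: str
    batch_size: int


@dataclass
class TwoStageExperiment:
    """Stand-in for a nested container with one Experiment per stage."""
    autoencoder: Experiment
    diffusion: Experiment


def iter_experiments(cfg: Any) -> Iterator[Experiment]:
    """Yield every Experiment in cfg, descending into dataclass fields."""
    if isinstance(cfg, Experiment):
        yield cfg
    elif is_dataclass(cfg) and not isinstance(cfg, type):
        for f in fields(cfg):
            yield from iter_experiments(getattr(cfg, f.name))


def applies_to_experiments(fiddler: Callable[..., None]) -> Callable[..., Any]:
    """Wrap a fiddler written for one Experiment so it also accepts containers."""
    @wraps(fiddler)
    def wrapper(cfg: Any, *args: Any, **kwargs: Any) -> Any:
        for experiment in iter_experiments(cfg):
            fiddler(experiment, *args, **kwargs)
        return cfg
    return wrapper


@applies_to_experiments
def set_batch_size(experiment: Experiment, batch_size: int) -> None:
    experiment.batch_size = batch_size


# The same fiddler works on a nested two-stage config and on a flat one.
nested = TwoStageExperiment(Experiment("autoencoder", 16), Experiment("diffusion", 8))
set_batch_size(nested, batch_size=4)
flat = Experiment("convgru", 32)
set_batch_size(flat, batch_size=2)
```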

Infrastructure

  • Metric naming migrated to TensorBoard convention: split/name format, rec_loss for reconstruction, loss for forecasting/diffusion
  • Config values aligned with DMI/Martinbo: AdamW(betas=(0.5,0.9), weight_decay=1e-3), per-stage LR, ReduceLROnPlateau(factor=0.25, patience=3), EMA decay 0.9999, timesteps=1000, early stopping patience=6 with check_finite=False, save_top_k=3, accumulate_grad_batches=2, reduced batch sizes
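
For reference, the EMA settings mentioned here can be sketched in a few lines of plain Python. This is an assumption-laden illustration, not the ExponentialMovingAverage class from this PR: `ema_decay` and `ema_update` are hypothetical names, parameters are plain floats, and the adaptive rule is the min(0.9999, (1+n)/(10+n)) form listed for the DMI fork in the comparison table, with n taken to be the number of updates so far.

```python
from typing import List


def ema_decay(num_updates: int, max_decay: float = 0.9999, adaptive: bool = False) -> float:
    """Return the decay for the current update (fixed, or warmed up adaptively)."""
    if adaptive:
        return min(max_decay, (1 + num_updates) / (10 + num_updates))
    return max_decay


def ema_update(shadow: List[float], params: List[float], decay: float) -> List[float]:
    """Move the shadow (averaged) parameters a small step towards the live ones."""
    return [decay * s + (1.0 - decay) * p for s, p in zip(shadow, params)]


# With a fixed decay the shadow weights barely move at first; the adaptive
# rule starts with a small decay so early updates track the model closely.
fixed_first = ema_update([0.0], [1.0], ema_decay(0))
adaptive_first = ema_update([0.0], [1.0], ema_decay(0, adaptive=True))
```

The practical difference is mostly in the first few thousand steps: the adaptive rule converges to the same 0.9999 decay once n is large.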

Comparison with reference implementations

To make clear how the functionality in this PR differs from the LDCast fork by @martinbo-meteo and from what @mfroelund worked on, starting from the original LDCast implementation from MeteoSwiss, I've created this table:

| Category | mlcast (this PR) | DMI fork (ldcast-dmi/) | Martinbo fork |
|---|---|---|---|
| Autoencoder | Deterministic AutoencoderNet, no KL loss | AutoencoderKL (VAE with KL regularization, latent sampling) | Same as DMI (AutoencoderKLNet, VAE with L1 + KL loss) |
| Denoiser | DenoiserUNet: 32 hidden channels, 2-level UNet, 3D conv, no attention | UNetModel: 128 model channels, attention blocks, 8 heads, 3D conv | Same as DMI |
| Context encoder | ConditionerNet: lightweight residual conv3d, single resolution | AFNONowcastNetCascade: multi-resolution AFNO spectral feature pyramid | Same as DMI |
| Inference sampler | DiffusionSampler: DDPM ancestral, full 1000 steps | PLMSSampler: accelerated PLMS, ~50 steps | Same as DMI |
| EMA | Fixed decay 0.9999, full LatentDiffusionNet | Adaptive min(0.9999, (1+n)/(10+n)), full model | Adaptive decay 0.9999, denoiser only |
| Beta schedules | Linear only | Linear, cosine, sqrt_linear, sqrt | Linear |
| Loss | L2 (MSE), eps prediction | L1 or L2, eps or x0 prediction | L2 (MSE), eps prediction |
| Classifier-free guidance | Not supported | Supported in PLMS sampler | Not supported |
| Training framework | PyTorch Lightning + Fiddle config | PyTorch Lightning + YAML (Pydantic) | PyTorch Lightning + YAML |
| Configuration | Python @auto_config + dataclasses | YAML → Pydantic models | YAML |
| Config naming | latent_diffusion_experiment | genforecast | ldcast |
| Optimizer | AdamW(betas=(0.5,0.9), wd=1e-3) | Same | Same |
| LR scheduler | ReduceLROnPlateau(factor=0.25, patience=3) | Same | Same |
| Diffusion timesteps | 1000, linear β ∈ [1e-4, 2e-2] | Same | Same |
| Early stopping | patience=6, check_finite=False (diffusion) | patience=6 | patience=6, check_finite=False |
| Model checkpointing | save_top_k=3 | save_top_k=3 | save_top_k=1 (Lightning default) |
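
The "Diffusion timesteps" row (1000 steps, linear β ∈ [1e-4, 2e-2]) can be written out as a short sketch. The helper names `linear_beta_schedule` and `cumulative_alphas` are hypothetical, and plain Python floats stand in for the tensors a DiffusionScheduler would presumably hold.

```python
from typing import List


def linear_beta_schedule(
    timesteps: int = 1000, beta_start: float = 1e-4, beta_end: float = 2e-2
) -> List[float]:
    """Linearly spaced noise variances beta_1 .. beta_T."""
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]


def cumulative_alphas(betas: List[float]) -> List[float]:
    """Running product of (1 - beta_t), i.e. alpha_bar_t for every timestep."""
    alpha_bars = []
    product = 1.0
    for beta in betas:
        product *= 1.0 - beta
        alpha_bars.append(product)
    return alpha_bars


betas = linear_beta_schedule()
alpha_bars = cumulative_alphas(betas)
```

With these values alpha_bar decreases monotonically from just below 1 to a value close to zero at the last step, so the final noised latent is almost pure Gaussian noise.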

I am sharing this now because I think the core components are in place and working, and I would like to get feedback on the overall structure and design decisions. There are a few things that I am unsure about, for example whether it would be better to split the two-stage experiment into separate experiment configs, and how they should depend on each other in that case. I also haven't yet tested the actual training for either the autoencoder or the latent diffusion stage, so there are likely to be bugs and issues in the training modules and loops that I haven't yet seen. I also haven't yet implemented the inference/sampling loop for the latent diffusion model, which will be important to validate that the trained models can actually generate forecasts.

One thing you might want to start with is the README. I've added the same kind of fiddle-based config diagram as for ConvGRU, and there is also a description of the latent diffusion architecture.

Hope you find this interesting @martinbo-meteo, @sidekock, @ladc, @franchg (and whoever else might be interested!), and feel free to respond with any feedback that you have :) Thank you

leifdenby added 30 commits May 19, 2026 14:40
- make split definitions coordinate-explicit so splitting can be extended beyond time in future

- support both fraction-based splits and explicit tuple-range splits such as ("2020-01-01T12:00", "2021-01-01T12:00")

- replace the fixed train_ratio/val_ratio API with nested split config and dataset subset handling across data, config, and tests
- only create a test dataset for fraction-mode splits when a test fraction is explicitly configured

- warn when configured split fractions do not sum to 1.0 because any remainder will be unused

- require SourceDataDataModule splits to be provided explicitly and update examples and tests accordingly
- emit per-split sample counts and resolved subset ranges directly from SourceDataDataModule.setup()

- remove the unused count_split_samples helper now that the setup logging provides the same operational visibility
- build only train and validation datasets for the Lightning fit stage

- build only the requested validation or test datasets for validate and test stages

- document the stage-dependent setup behavior and cover it with datamodule tests
- convert normalized sample arrays to float32 before torch.from_numpy so float64 source data does not leak into training tensors

- make stacked channel arrays contiguous float32 views in both dataset sampling paths and cover the returned tensor dtypes in tests
- change the broad data/ ignore rule to /data/ so src/mlcast/data is no longer matched accidentally

- avoid needing force-adds for tracked source files under the mlcast.data package
- Add LDCastTrainingExperiment dataclass and ldcast_training_experiment() config
- Add LDCast-specific validation in consistency_checks.py
- Freeze autoencoder on fit/val/test/predict start in LatentDiffusionTaskModule
- Export LDCast config from mlcast.config
- Add identity-sharing and sequencing tests
- Add train_from_config test for LDCast
- Document DMI alignment differences in plan
…n comparison

- Add LDCast section to README covering autoencoder, latent diffusion,
  and two-stage training experiment
- Add ldcast_config_diagram.svg (Fiddle config graph render)
- Add generate_ldcast_config_diagram.py script
- Add pre-commit hook to keep LDCast diagram in sync
- Add Martinbo alignment comparison notes to ldcast-refactor-plan.md
…line

ConvGruModel, task modules, and visualization now use an explicit
ensemble dim (B, T, M, C, H, W) instead of flattening into channels.
Task modules flatten via einops only at loss computation time.
Add _iter_experiment_configs tree-walker, applies_to_experiments
decorator, and _find_nn_modules_with_input_channels helper so
fiddlers work on both flat Experiment configs (convgru) and
nested containers like LDCastTrainingExperiment. Refactor
set_variables to walk pl_module for any nn.Module with
input_channels instead of hardcoding cfg.pl_module.network.
…ckage

Split Experiment dataclass (base.py) from experiment-specific config
functions. Move convgru_training_experiment to archetype/convgru.py
and ldcast_training_experiment to archetype/ldcast.py.
- Fix 5D tensor shape mismatch in visualization.log_images
  (ensemble dim was not squeezed before apply_radar_colormap)
- Fix EMA device mismatch in ExponentialMovingAverage when module
  moves between GPU stages (shadow_params stayed on CPU)
- Fix README LDCast snippet: pass experiment configs to fiddlers
  instead of data module configs directly
- Fix test harness _patch_cfg for nested LDCastTrainingExperiment
  (iterate Experiment sub-configs instead of assuming flat cfg.data)
- Migrate metric names to TensorBoard convention: split/name format
  with rec_loss for reconstruction, loss for forecasting/diffusion
- Switch both stages to AdamW with betas=(0.5, 0.9), weight_decay=1e-3
- Raise autoencoder LR to 1e-3, keep diffusion LR at 1e-4
- Reduce LR scheduler factor to 0.25, patience to 3
- Increase EMA decay to 0.9999
- Increase diffusion timesteps to 1000 with default beta range
- Reduce early stopping patience to 6, add check_finite=False on stage2
- Increase model checkpoint save_top_k to 3
- Reduce batch sizes (stage1: 16->4, stage2: 8->1)
- Add accumulate_grad_batches=2 to both stages
leifdenby added 7 commits June 4, 2026 01:52
… covers both stages via @applies_to_experiments
- Use codebase convention for jaxtyping dim names (time/channels/height/width)
- Replace unsqueeze/cat with einops.rearrange for ensemble dim
- Remove redundant/incorrect shape comment (pattern is self-documenting)
- Clean up input_channels property docstring
Our architecture (deterministic autoencoder, small denoiser with no
attention, single-resolution conditioner) is a lightweight latent
diffusion model, not a true LDCast as defined by the DMI/Martinbo
references. Renaming sets accurate expectations:

- LDCastTrainingExperiment -> LatentDiffusionTrainingExperiment
- ldcast_training_experiment -> latent_diffusion_experiment
- ldcast.py -> latent_diffusion.py
- All docs, tests, diagrams, and config references updated
@leifdenby leifdenby changed the title Latent diffusion with fiddle configuration Latent diffusion training and inference with fiddle configuration Jun 4, 2026
@leifdenby leifdenby changed the title Latent diffusion training and inference with fiddle configuration Latent diffusion training with fiddle configuration Jun 4, 2026
@leifdenby leifdenby changed the title Latent diffusion training with fiddle configuration Latent diffusion with fiddle configuration Jun 4, 2026
@sidekock

sidekock commented Jun 4, 2026

Member

Hi Leif,
You can definitely add me as a reviewer here. I have set aside some time next week to commit to this PR.

@leifdenby leifdenby requested a review from sidekock June 4, 2026 16:13
@martinbo-meteo

Hi Leif,
That is really nice, the code seems quite clear to me!

Here are some thoughts:

  • I think that the trick of putting the forward diffusion and the training denoising in the loss function is nice, but I feel that it is quite uncommon (or is it common for diffusion models?) and I wonder if this could be an issue...
  • I was also thinking that we could cut the diffusion part out of the LatentDiffusionTaskModule so that it only handles the latent part, and could be reused if other models internally work in a latent space (I am not sure there are a lot of such models for the moment though). The diffusion could be done by a Forecaster, and line 560 of mlcast/modules/forecasting.py would just be replaced by something like forecast_latents = self.forecaster(repeated_input_latents). But it would maybe be more difficult to adapt the fact that the loss function handles the forward diffusion and the training denoising.
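
A rough sketch of the separation suggested in the second point, under heavy assumptions: `LatentSpaceTask` is a hypothetical class, the encoder, forecaster and decoder are toy callables on lists, and only the call `self.forecaster(repeated_input_latents)` is taken from the comment above.

```python
from typing import Callable, List

Latents = List[float]


class LatentSpaceTask:
    """Owns encoder and decoder; delegates the latent forecast to any forecaster."""

    def __init__(
        self,
        encoder: Callable[[List[float]], Latents],
        forecaster: Callable[[List[Latents]], List[Latents]],
        decoder: Callable[[Latents], List[float]],
        ensemble_size: int,
    ) -> None:
        self.encoder = encoder
        self.forecaster = forecaster
        self.decoder = decoder
        self.ensemble_size = ensemble_size

    def forecast(self, inputs: List[float]) -> List[List[float]]:
        input_latents = self.encoder(inputs)
        # One copy of the conditioning latents per ensemble member.
        repeated_input_latents = [list(input_latents) for _ in range(self.ensemble_size)]
        forecast_latents = self.forecaster(repeated_input_latents)
        return [self.decoder(z) for z in forecast_latents]


# Toy components: the "forecaster" is plain persistence in latent space.
task = LatentSpaceTask(
    encoder=lambda x: [0.5 * v for v in x],
    forecaster=lambda members: members,
    decoder=lambda z: [2.0 * v for v in z],
    ensemble_size=3,
)
members = task.forecast([1.0, 2.0, 4.0])
```

In this layout a diffusion sampler would be just one possible forecaster; the open question raised above, where the forward diffusion and training loss live, is not addressed by the sketch.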

About the question on 1 vs 2 experiments, I think it is quite important to be able to get a feeling of what the autoencoder is actually doing before training the diffusion part.

  • I guess it would be quite easy to set keywords in the experiment configs in such a way that one of the two stages does not happen (or is reduced to 0 epochs), so one possibility would be to just work that way if one wants to train only the autoencoder
  • On the other hand, when looking at mlcast/config/archetype/latent_diffusion.py, it seems that the two stages only share the sequence_dataset_factory object (which mainly knows about the zarr and csv files, and the chunking of timesteps for the autoencoder). So we could try to provide different meaningful experiments for the autoencoder and the diffusion parts, which would be used to construct the full LatentDiffusionTrainingExperiment if needed

I have not run the code yet, but I will take the time to do it!

@leifdenby

Member Author

Thanks for your feedback @martinbo-meteo! I am glad you like what I've written in general :)

  • I think that the trick of putting the forward diffusion and the training denoising in the loss function is nice, but I feel that it is quite uncommon (or is it common for diffusion models?) and I wonder if this could be an issue...

Ah, you're right, that is weird. I agree it shouldn't be there. I think I made this mistake because I wasn't iterating over how to structure the task-based pytorch-lightning module, and the LLM agent I was using must have moved the forward diffusion there. I will move that!
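
For concreteness, a minimal sketch of what the moved version could look like, with the forward diffusion done in the training step and the loss reduced to a plain MSE on the predicted noise. The names `q_sample`, `mse_loss` and `training_step` are hypothetical, lists of floats stand in for latent tensors, and the toy `alpha_bars` values are made up.

```python
import math
import random
from typing import Callable, List


def q_sample(x0: List[float], eps: List[float], alpha_bar_t: float) -> List[float]:
    """Forward diffusion: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    a = math.sqrt(alpha_bar_t)
    b = math.sqrt(1.0 - alpha_bar_t)
    return [a * x + b * e for x, e in zip(x0, eps)]


def mse_loss(pred: List[float], target: List[float]) -> float:
    """The loss only compares tensors; it knows nothing about diffusion."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def training_step(
    x0: List[float],
    denoiser: Callable[[List[float], int], List[float]],
    alpha_bars: List[float],
    rng: random.Random,
) -> float:
    t = rng.randrange(len(alpha_bars))      # random timestep
    eps = [rng.gauss(0.0, 1.0) for _ in x0]  # target noise
    x_t = q_sample(x0, eps, alpha_bars[t])   # noising happens here, not in the loss
    return mse_loss(denoiser(x_t, t), eps)   # eps prediction, L2


x0 = [0.5, -1.0, 2.0]
alpha_bars = [0.9, 0.5, 0.1]  # toy schedule for illustration


def oracle(x_t: List[float], t: int) -> List[float]:
    """Cheating denoiser that knows x0 and inverts q_sample exactly."""
    a, b = math.sqrt(alpha_bars[t]), math.sqrt(1.0 - alpha_bars[t])
    return [(xt - a * x) / b for xt, x in zip(x_t, x0)]


oracle_loss = training_step(x0, oracle, alpha_bars, random.Random(0))
zero_loss = training_step(x0, lambda x_t, t: [0.0] * len(x_t), alpha_bars, random.Random(0))
```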

Just a heads-up to @sidekock that I will move this (because I made a mistake here).

  • I was also thinking that we could cut the diffusion part out of the LatentDiffusionTaskModule so that it only handles the latent part, and could be reused if other models internally work in a latent space (I am not sure there are a lot of such models for the moment though). The diffusion could be done by a Forecaster, and line 560 of mlcast/modules/forecasting.py would just be replaced by something like forecast_latents = self.forecaster(repeated_input_latents). But it would maybe be more difficult to adapt the fact that the loss function handles the forward diffusion and the training denoising.

Yes, I have been wondering about that too. What that would mean conceptually to me, though, is that the "data" coming into the model architecture is not in physical space but in a latent space, and so we would be using the trained encoder inside the torch.Dataset. I felt a bit weird about having torch.nn.Modules inside the dataset class itself, since I feel the role of torch datasets should only be to load data from disk and return tensors. But of course a data augmentation (say rotation) is also a transformation of tensors (like an encoder would be) and so could also be a torch.nn.Module. So I am not sure, to be honest. The other slight oddity with this approach is how to get back into real, non-latent space. Said differently, which class would own the decoder? To get physical-space tensors out of the pytorch-lightning module, the decoder would have to be available to the task module, but then the inputs and outputs of the task module are in different spaces, and the encoder is given to the dataset instance whereas the decoder is given to the lightning module representing the task. This also feels a bit odd to me. What do you think?

@leifdenby

Member Author
  • So we could try to provide different meaningful experiments for the autoencoder and the diffusion parts, which would be used to construct the full LatentDiffusionTrainingExperiment if needed

Yes, I was thinking that would be a better way of going about this. In that case one could maybe have a stage-1 experiment that could be instantiated from the CLI, and then the stage-2 experiment would take a stage-1 experiment as input (as well as the parameters of the stage-2 experiment itself). I haven't tried coding this up yet though...
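
A rough sketch of that dependency using plain dataclasses rather than Fiddle configs, purely to show the shape of the idea. All class names, the `zarr_path` value and the fields are hypothetical; only the learning rates (1e-3 for the autoencoder, 1e-4 for diffusion) come from the commit notes in this PR.

```python
from dataclasses import dataclass


@dataclass
class SequenceDatasetFactory:
    """Stand-in for the shared object that knows about the source data."""
    zarr_path: str


@dataclass
class AutoencoderExperiment:
    """Stage 1: can be built and trained on its own."""
    dataset_factory: SequenceDatasetFactory
    learning_rate: float = 1e-3


@dataclass
class DiffusionExperiment:
    """Stage 2: takes a stage-1 experiment as input plus its own parameters."""
    stage1: AutoencoderExperiment
    learning_rate: float = 1e-4

    @property
    def dataset_factory(self) -> SequenceDatasetFactory:
        # Reuse the very same factory object instead of copying it,
        # so both stages are guaranteed to see identical data settings.
        return self.stage1.dataset_factory


stage1 = AutoencoderExperiment(SequenceDatasetFactory(zarr_path="radar.zarr"))
stage2 = DiffusionExperiment(stage1=stage1)
```

Because stage 2 holds a reference to stage 1 rather than a copy, changing the dataset settings on the stage-1 experiment is automatically reflected in stage 2.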

3 participants