Latent diffusion with fiddle configuration #16
- make split definitions coordinate-explicit so splitting can be extended beyond time in future
- support both fraction-based splits and explicit tuple-range splits such as ("2020-01-01T12:00", "2021-01-01T12:00")
- replace the fixed train_ratio/val_ratio API with nested split config and dataset subset handling across data, config, and tests
- only create a test dataset for fraction-mode splits when a test fraction is explicitly configured
- warn when configured split fractions do not sum to 1.0 because any remainder will be unused
- require SourceDataDataModule splits to be provided explicitly and update examples and tests accordingly

- emit per-split sample counts and resolved subset ranges directly from SourceDataDataModule.setup()
- remove the unused count_split_samples helper now that the setup logging provides the same operational visibility

- build only train and validation datasets for the Lightning fit stage
- build only the requested validation or test datasets for validate and test stages
- document the stage-dependent setup behavior and cover it with datamodule tests

- convert normalized sample arrays to float32 before torch.from_numpy so float64 source data does not leak into training tensors
- make stacked channel arrays contiguous float32 views in both dataset sampling paths and cover the returned tensor dtypes in tests

- change the broad data/ ignore rule to /data/ so src/mlcast/data is no longer matched accidentally
- avoid needing force-adds for tracked source files under the mlcast.data package

- Add LDCastTrainingExperiment dataclass and ldcast_training_experiment() config
- Add LDCast-specific validation in consistency_checks.py
- Freeze autoencoder on fit/val/test/predict start in LatentDiffusionTaskModule
- Export LDCast config from mlcast.config
- Add identity-sharing and sequencing tests
- Add train_from_config test for LDCast
- Document DMI alignment differences in plan

…n comparison
- Add LDCast section to README covering autoencoder, latent diffusion, and two-stage training experiment
- Add ldcast_config_diagram.svg (Fiddle config graph render)
- Add generate_ldcast_config_diagram.py script
- Add pre-commit hook to keep LDCast diagram in sync
- Add Martinbo alignment comparison notes to ldcast-refactor-plan.md
…line ConvGruModel, task modules, and visualization now use an explicit ensemble dim (B, T, M, C, H, W) instead of flattening into channels. Task modules flatten via einops only at loss computation time.
Add _iter_experiment_configs tree-walker, applies_to_experiments decorator, and _find_nn_modules_with_input_channels helper so fiddlers work on both flat Experiment configs (convgru) and nested containers like LDCastTrainingExperiment. Refactor set_variables to walk pl_module for any nn.Module with input_channels instead of hardcoding cfg.pl_module.network.
…ckage Split Experiment dataclass (base.py) from experiment-specific config functions. Move convgru_training_experiment to archetype/convgru.py and ldcast_training_experiment to archetype/ldcast.py.
…o wip/ldcast-fiddle-refactor
- Fix 5D tensor shape mismatch in visualization.log_images (ensemble dim was not squeezed before apply_radar_colormap)
- Fix EMA device mismatch in ExponentialMovingAverage when module moves between GPU stages (shadow_params stayed on CPU)
- Fix README LDCast snippet: pass experiment configs to fiddlers instead of data module configs directly
- Fix test harness _patch_cfg for nested LDCastTrainingExperiment (iterate Experiment sub-configs instead of assuming flat cfg.data)
- Migrate metric names to TensorBoard convention: split/name format with rec_loss for reconstruction, loss for forecasting/diffusion

- Switch both stages to AdamW with betas=(0.5, 0.9), weight_decay=1e-3
- Raise autoencoder LR to 1e-3, keep diffusion LR at 1e-4
- Reduce LR scheduler factor to 0.25, patience to 3
- Increase EMA decay to 0.9999
- Increase diffusion timesteps to 1000 with default beta range
- Reduce early stopping patience to 6, add check_finite=False on stage2
- Increase model checkpoint save_top_k to 3
- Reduce batch sizes (stage1: 16->4, stage2: 8->1)
- Add accumulate_grad_batches=2 to both stages
… covers both stages via @applies_to_experiments
- Use codebase convention for jaxtyping dim names (time/channels/height/width)
- Replace unsqueeze/cat with einops.rearrange for ensemble dim
- Remove redundant/incorrect shape comment (pattern is self-documenting)
- Clean up input_channels property docstring

Our architecture (deterministic autoencoder, small denoiser with no attention, single-resolution conditioner) is a lightweight latent diffusion model, not a true LDCast as defined by the DMI/Martinbo references. Renaming sets accurate expectations:
- LDCastTrainingExperiment -> LatentDiffusionTrainingExperiment
- ldcast_training_experiment -> latent_diffusion_experiment
- ldcast.py -> latent_diffusion.py
- All docs, tests, diagrams, and config references updated
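
The commits above mention a tree-walker (_iter_experiment_configs) and an @applies_to_experiments decorator that let fiddlers work on both flat and nested experiment configs. The following is only a minimal sketch of that idea, not the code in this PR: plain dataclasses stand in for Fiddle config objects, and DataConfig, TwoStageExperiment, iter_experiments, and set_batch_size are hypothetical names chosen for illustration.

```python
from dataclasses import dataclass, fields, is_dataclass
from functools import wraps
from typing import Any, Callable, Iterator


@dataclass
class DataConfig:
    batch_size: int = 4


@dataclass
class Experiment:
    """Stand-in for a flat, single-stage experiment config."""
    name: str
    data: DataConfig


@dataclass
class TwoStageExperiment:
    """Stand-in for a nested container holding several experiments."""
    stage1: Experiment
    stage2: Experiment


def iter_experiments(cfg: Any) -> Iterator[Experiment]:
    """Yield every Experiment reachable from cfg, depth-first."""
    if isinstance(cfg, Experiment):
        yield cfg
    elif is_dataclass(cfg) and not isinstance(cfg, type):
        for f in fields(cfg):
            yield from iter_experiments(getattr(cfg, f.name))


def applies_to_experiments(fiddler: Callable[..., None]) -> Callable[..., None]:
    """Lift a fiddler written for one Experiment to any container of them."""
    @wraps(fiddler)
    def wrapper(cfg: Any, *args: Any, **kwargs: Any) -> None:
        for experiment in iter_experiments(cfg):
            fiddler(experiment, *args, **kwargs)
    return wrapper


@applies_to_experiments
def set_batch_size(experiment: Experiment, batch_size: int) -> None:
    # Written against a single Experiment; the decorator handles nesting.
    experiment.data.batch_size = batch_size


flat = Experiment("convgru", DataConfig())
nested = TwoStageExperiment(
    Experiment("autoencoder", DataConfig()),
    Experiment("diffusion", DataConfig()),
)
set_batch_size(flat, 8)    # touches the one flat experiment
set_batch_size(nested, 2)  # touches both stages
```

The fiddler body never needs to know whether it was handed a flat or a nested config, which is the property the commit message describes.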
Hi Leif, here are some thoughts:
About the question of 1 vs 2 experiments, I think it is quite important to be able to get a feeling for what the autoencoder is actually doing before training the diffusion part.
I have not run the code yet, but I will take the time to do so!
Thanks for your feedback @martinbo-meteo! I am glad you like what I've written in general :)
Ah, you're right, that is weird. It shouldn't be there, I agree. I think I made this mistake because I wasn't iterating over how to structure the task-based pytorch-lightning module, and the LLM agent I was using must have moved that diffusion forward to there. I will move that! Just a heads-up to @sidekock that I will move this (because I made a mistake here).
Yes, I have been wondering about that too. Conceptually, though, that would mean saying that the "data" coming into the model architecture is not in physical space, but rather in a latent space. And that would mean that we would be using the trained encoder inside the
Yes, I was thinking that would be a better way of going about this. In that case one could maybe have a stage-1 experiment that could be instantiated from the CLI, and then the stage-2 experiment would take a stage-1 experiment as input (as well as the parameters of the stage-2 experiment itself). I haven't tried coding this up yet though...
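
To make the "stage-2 takes a stage-1 experiment as input" idea concrete, here is a rough sketch using plain dataclasses. None of these classes exist in the PR: AutoencoderConfig, Stage1Experiment, Stage2Experiment, and latent_channels are hypothetical, and only the learning rates and timestep count are taken from the commit messages above.

```python
from dataclasses import dataclass, field


@dataclass
class AutoencoderConfig:
    latent_channels: int = 8  # illustrative value, not from the PR


@dataclass
class Stage1Experiment:
    """Autoencoder training; could be instantiated on its own from a CLI."""
    autoencoder: AutoencoderConfig = field(default_factory=AutoencoderConfig)
    learning_rate: float = 1e-3


@dataclass
class Stage2Experiment:
    """Latent diffusion training; depends on a stage-1 experiment."""
    stage1: Stage1Experiment
    learning_rate: float = 1e-4
    timesteps: int = 1000

    @property
    def autoencoder(self) -> AutoencoderConfig:
        # Reuse the stage-1 object rather than copying it, so both stages
        # always describe the same autoencoder.
        return self.stage1.autoencoder


stage1 = Stage1Experiment()
stage2 = Stage2Experiment(stage1=stage1)

# A change made on the stage-1 config is visible from stage 2.
stage1.autoencoder.latent_channels = 16
```

With this shape, stage 1 can be run and inspected by itself, while stage 2 cannot be built without naming the stage-1 experiment it depends on.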
I'm making this PR to share my progress on refactoring the great work by @martinbo-meteo on #4 to fit more into the fiddle-based configuration and class hierarchy already in mlcast.

I am intentionally calling this latent diffusion and not LDCast because what I have created so far is not a true reproduction of the architecture and training procedure defined in the LDCast paper. Instead it is an attempt to structure how we in general could train and run inference on latent diffusion models in mlcast.

Brief list of changes and additions (relative to main)

Data pipeline
- SourceDataSequenceDatasetBase, SourceDataPrecomputedSequenceDataset, SourceDataRandomSequenceDataset replacing the old source-data datasets
- ForecastingDataset: splits sequences into input/target pairs for forecasting
- ReconstructionDataset: sliding-window wrapper on sequence datasets for autoencoder training
- ForecastingDataModule / ReconstructionDataModule replacing SourceDataDataModule

Model architecture
- AutoencoderNet + Encoder/Decoder submodules in models/autoencoder/
- LatentDiffusionNet, DenoiserUNet, ConditionerNet, DiffusionScheduler, DiffusionLoss, DiffusionSampler, ExponentialMovingAverage in models/diffusion/
- Explicit ensemble dimension (B, T, M, C, H, W) throughout the forecasting pipeline
- … (input_steps, forecast_steps, ensemble_size at init)

Training modules (modules/)
- BaseForecastingTaskModule: shared Lightning plumbing for all forecasting tasks
- OutputSpaceForecastingTaskModule: direct forecasting (ConvGRU, custom networks)
- ReconstructionTaskModule: autoencoder training
- LatentDiffusionTaskModule: two-stage latent diffusion with frozen encoder, EMA support, latent-space loss

Configuration system
- @auto_config entry points: convgru_training_experiment, latent_diffusion_experiment
- Experiment dataclass as the flat config schema
- LatentDiffusionTrainingExperiment: two-stage config sharing autoencoder identity across stages
- @applies_to_experiments decorator so fiddlers automatically handle nested multi-stage configs

Infrastructure
- Metric names in TensorBoard convention: split/name format, rec_loss for reconstruction, loss for forecasting/diffusion
- AdamW(betas=(0.5,0.9), weight_decay=1e-3), per-stage LR, ReduceLROnPlateau(factor=0.25, patience=3), EMA decay 0.9999, timesteps=1000, early stopping patience=6 with check_finite=False, save_top_k=3, accumulate_grad_batches=2, reduced batch sizes

Comparison with reference implementations
To make clear how the functionality in this PR differs from the LDCast by @martinbo-meteo and from what @mfroelund worked on, based on the original LDCast implementation from MeteoSwiss, I've created this table:
… ldcast-dmi/)
- AutoencoderNet, no KL loss | AutoencoderKL (VAE with KL regularization, latent sampling) | … (AutoencoderKLNet, VAE with L1 + KL loss)
- DenoiserUNet: 32 hidden channels, 2-level UNet, 3D conv, no attention | UNetModel: 128 model channels, attention blocks, 8 heads, 3D conv
- ConditionerNet: lightweight residual conv3d, single resolution | AFNONowcastNetCascade: multi-resolution AFNO spectral feature pyramid
- DiffusionSampler: DDPM ancestral, full 1000 steps | PLMSSampler: accelerated PLMS, ~50 steps
- 0.9999, full LatentDiffusionNet | min(0.9999, (1+n)/(10+n)), full model | 0.9999, denoiser only
- eps prediction | eps or x0 prediction | eps prediction
- @auto_config + dataclasses
- latent_diffusion_experiment | genforecast | ldcast
- AdamW(betas=(0.5,0.9), wd=1e-3)
- ReduceLROnPlateau(factor=0.25, patience=3)
- check_finite=False (diffusion) | check_finite=False
- save_top_k=3 | save_top_k=3 | save_top_k=1 (Lightning default)

I am sharing this now because I think the core components are in place and working, and I would like to get feedback on the overall structure and design decisions. There are a few things that I am unsure about, for example whether it would be better to split the two-stage experiment into separate experiment configs, and how they should depend on each other in that case. I also haven't yet tested the actual training for either the autoencoder or the latent diffusion stage, so there are likely to be bugs and issues in the training modules and loops that I haven't seen yet. I also haven't yet implemented the inference/sampling loop for the latent diffusion model, which will be important to validate that the trained models can actually generate forecasts.
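
The EMA entries in the comparison above list two decay choices: a constant 0.9999 and a warm-up schedule min(0.9999, (1+n)/(10+n)). As a small illustration of how they differ early in training, the sketch below applies both to a single scalar. It assumes n counts the EMA updates made so far and uses the usual update rule shadow = decay * shadow + (1 - decay) * param; the function names are illustrative and not taken from any of the compared codebases.

```python
def ema_decay(num_updates: int, max_decay: float = 0.9999) -> float:
    """Warm-up decay: min(max_decay, (1 + n) / (10 + n))."""
    return min(max_decay, (1 + num_updates) / (10 + num_updates))


def ema_update(shadow: float, param: float, decay: float) -> float:
    """One EMA step: shadow <- decay * shadow + (1 - decay) * param."""
    return decay * shadow + (1.0 - decay) * param


# Track a parameter that sits at 1.0, starting both averages from 0.0.
constant = warmup = 0.0
for n in range(100):
    constant = ema_update(constant, 1.0, 0.9999)   # fixed decay
    warmup = ema_update(warmup, 1.0, ema_decay(n))  # decay grows with n
```

After 100 updates the constant-decay average has barely moved away from its initial value, while the warm-up variant is already close to the tracked parameter. The warm-up decay only reaches the 0.9999 cap after roughly 90,000 updates.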
One thing you might want to start with is the README. I've added the same fiddle-based config diagram as for ConvGRU, for example, and there is also a description of the latent diffusion architecture.

Hope you find this interesting @martinbo-meteo, @sidekock, @ladc, @franchg (and whoever else might be interested!), and feel free to respond with any feedback that you have :) Thank you