Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ jobs:
with:
enable-cache: true

# The local `config-diagram` hook renders the Fiddle config graph with
# graphviz, which shells out to the `dot` binary.
- run: sudo apt-get update && sudo apt-get install -y graphviz

# Run pre-commit through uv so CI uses the same project-managed environment
# as local development, including local hooks that invoke `uv run ...`.
- run: uv run pre-commit run --all-files
# pre-commit lives in the `dev` extra, so it must be synced for `uv run`.
- run: uv run --extra dev pre-commit run --all-files
93 changes: 90 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ mlcast train --config fiddler:use_random_sampler --print_config_and_exit

Run `mlcast train --help` for a full list of examples and available fiddlers.

Beyond training, the CLI provides two data-prep subcommands — `mlcast stats`
and `mlcast validate-stats` — for building and checking the datacube index that
training reads. See [Preparing training data](#preparing-training-data).

### Python API

The Python API gives you full programmatic control over the config graph before
Expand Down Expand Up @@ -257,14 +261,92 @@ experiment.run() # trainer.fit() + trainer.test()
| `set_variables` | `standard_names` | Sets the list of input variables on the dataset and updates `network.input_channels` to match |
| `toggle_masking` | `enabled` | Toggles masked-loss mode by setting both `dataset_factory.return_mask` and `pl_module.masked_loss` to the same value |
| `use_anon_s3_dataset` | `zarr_path`, `endpoint_url` | Points the dataset at an anonymous S3 object store; sets `zarr_path` and the required `storage_options` together |
| `use_random_sampler` | *(none)* | Switches the dataset factory to the on-the-fly random sampler (useful during development when no precomputed CSV is available) |
| `use_ratio_splits` | `train`, `val` | Sets fraction-based train/val/test time splits on the data module (test = 1 − train − val) |
| `use_random_sampler` | *(none)* | Switches the dataset factory to the on-the-fly random sampler (useful during development when no precomputed sampling index is available) |

## Preparing training data

Training reads a **stats parquet**: a precomputed index of candidate datacubes —
fixed-size `time_depth × width × height` crops of a source radar Zarr — each
tagged with per-cube statistics (`nan_count`, `sum`, `mean`, `frac_wet`). The
training dataset [`SourceDataIndexedDataset`](src/mlcast/data/source_data_datasets.py)
iterates this index (it is the dataset factory's `index_path`), and a **sampler**
reshapes the candidate pool at dataset init. The producer, the schema, and the
samplers all live in [`mlcast.sampling`](src/mlcast/sampling/); two CLI
subcommands build and check the file.

### `mlcast stats` — scan a Zarr dataset → stats parquet

Slides a window over the `(time, x, y)` grid and, for every candidate datacube,
computes its statistics in O(1) per window via a cumulative-sum trick. Windows
are filtered by a maximum-NaN budget, a spatial/temporal stride, and
time-continuity (no frame gaps); the survivors are streamed to a single
zstd-compressed parquet whose footer carries every parameter as metadata (the
contract in [`stats_spec`](src/mlcast/sampling/stats_spec.py)), so downstream
commands never have to parse the filename.

```bash
# A year of radar → 24-frame 256×256 datacubes, stride 3×16×16
mlcast stats /data/radar.zarr \
--start-date 2020-01-01 --end-date 2020-12-31 \
--time-depth 24 --width 256 --height 256 \
--step-t 3 --step-x 16 --step-y 16 \
--max-nan 10000 \
-o stats_2020.parquet
```

Common flags (`mlcast stats -h` lists them all):

| Flag | Default | Purpose |
|------|---------|---------|
| `--time-depth` / `--width` / `--height` | 24 / 256 / 256 | Datacube shape `(T, X, Y)` |
| `--step-t` / `--step-x` / `--step-y` | 3 / 16 / 16 | Stride between candidate window starts |
| `--max-nan` | 10000 | Drop any datacube with more NaNs than this |
| `--wet-threshold` | auto | Wet-pixel threshold; auto = 0.1 mm/h (rain rate) or 7 dBZ (reflectivity) |
| `--device` | auto | Compute backend: `auto` / `cpu` (bottleneck) / `cuda` |
| `--workers` | 8 | CPU worker processes, or GPU chunk-reader threads |
| `--data-var` / `--time-var` | RR / time | Names of the Zarr data and time variables |
| `-o` / `--output` | auto | Output path; auto-named from the parameters if omitted |

### `mlcast validate-stats` — check a parquet against the contract

Checks a stats parquet's column schema, metadata payload, and (unless
`--no-data-checks`) its per-row value invariants, then prints the file's
parameters, its table structure, and a preview of the first 10 rows.

```bash
mlcast validate-stats stats_2020.parquet

# Footer only — schema + metadata, skip the per-row checks
mlcast validate-stats stats_2020.parquet --no-data-checks
```

### Samplers

At training time [`SourceDataDataModule`](src/mlcast/data/source_data_datamodule.py)
applies a `Sampler` to the index **once**, at dataset init, turning the full
candidate pool into the training set via a per-row keep/discard draw — so the
dataset length is fixed and known up front, and the same set is reused every
epoch. Samplers are pluggable through a registry in
[`mlcast.sampling`](src/mlcast/sampling/samplers.py):

| Sampler | Parameters | What it does |
|---------|------------|--------------|
| `UniformSampler` | `keep_fraction` | Keep each candidate with a fixed probability, independent of its stats |
| `ImportanceSampler` | `column`, `q_min`, `scale`, `mean_weight` | Keep each candidate with probability rising with one of its statistic columns (`mean` by default) — oversampling high-rainfall datacubes without duplication |

The default config applies an `ImportanceSampler` to the **train** split and a
`UniformSampler(keep_fraction=0.1)` to **val/test**, so importance sampling
reshapes only training while validation and test stay representative. Add a
scheme by subclassing `Sampler`, decorating it with `@register_sampler("name")`,
and selecting it from a config via `get_sampler("name", ...)`.

## Project Structure

```
mlcast/
├── src/mlcast/
│ ├── __main__.py # CLI entry point (mlcast train)
│ ├── __main__.py # CLI entry point (train / stats / validate-stats)
│ ├── nowcasting_module.py # Generic Lightning module for nowcasting
│ ├── losses.py # CRPS, AFCRPS, MSE loss functions
│ ├── callbacks.py # Training callbacks
Expand All @@ -277,8 +359,13 @@ mlcast/
│ │ └── orchestrator.py # train_from_config, config persistence
│ ├── data/
│ │ ├── source_data_datamodule.py # Lightning DataModule
│ │ ├── source_data_datasets.py # Zarr-backed PyTorch datasets
│ │ ├── source_data_datasets.py # Zarr-backed PyTorch datasets (reads the stats index)
│ │ └── normalization.py # Normalisation registry
│ ├── sampling/ # Data-prep: stats parquet producer + sampler registry
│ │ ├── commands/ # stats / validate-stats CLI implementations
│ │ ├── samplers.py # Sampler registry (uniform, importance)
│ │ ├── stats_spec.py # Stats parquet schema + validation (the contract)
│ │ └── units.py # Data-kind + wet-threshold detection
│ └── models/
│ └── convgru.py # ConvGRU encoder-decoder
├── tests/
Expand Down
Loading
Loading