Skip to content

Fold the dataset sampler into mlcast (mlcast.sampling)#17

Open
franchg wants to merge 15 commits into
mlcast-community:mainfrom
franchg:feat/dataset-sampler-merge
Open

Fold the dataset sampler into mlcast (mlcast.sampling)#17
franchg wants to merge 15 commits into
mlcast-community:mainfrom
franchg:feat/dataset-sampler-merge

Conversation

@franchg

@franchg franchg commented Jun 6, 2026

Copy link
Copy Markdown
Member

Summary

Folds the previously separate mlcast-dataset-sampler tool into the mlcast
package as a new mlcast.sampling subpackage. This removes the duplication
between the two repos and gives training a single in-package path from a source
radar Zarr to a sampled training set.

The design is a clean producer → consumer seam:

  • Producer (mlcast stats) — scans a source Zarr and writes a stats
    parquet
    : an index of candidate datacubes, each tagged with nan_count /
    sum / mean / frac_wet. The parquet schema + metadata footer is the
    contract (mlcast.sampling.stats_spec).
  • Consumer (training) — SourceDataIndexedDataset reads that index, and a
    pluggable Sampler reshapes the candidate pool at dataset init.

What's new

Data-prep CLI (folded into the existing mlcast argparse CLI as subcommands):

  • mlcast stats <zarr> — cumsum sliding-window scan → zstd stats parquet (CPU
    via bottleneck, or --device cuda for the GPU backend).
  • mlcast validate-stats <parquet> — validates a parquet against the contract;
    prints the schema, metadata, and a row preview.

Sampler registry (mlcast.sampling.samplers):

  • Sampler ABC + @register_sampler / get_sampler (mirrors the existing
    normalization registry).
  • UniformSampler(keep_fraction) and
    ImportanceSampler(column, q_min, scale, mean_weight) — per-row accept/reject,
    run once at init so the dataset length is fixed and reused every epoch.
  • Applied per split by SourceDataDataModule (train_sampler / eval_sampler):
    importance on train, uniform on val/test, so validation/test stay representative.

Training dataset:

  • SourceDataIndexedDataset reads the parquet index_path (and still accepts a
    legacy CSV index, used without a sampler).

Notable changes

  • bottleneck added to core dependencies; the unused fire dependency removed.
  • mlcast CLI startup made lazy — mlcast -h ~4.2s → ~0.08s by deferring the
    training stack to the train path.
  • Example updated: csv_pathindex_path.
  • README gained a "Preparing training data" section documenting the workflow.

Testing

  • uv run pytest114 passed, 1 skipped (the skip is the CUDA-only GPU
    backend parity test).
  • ruff check clean.
  • New tests cover sampler semantics + registry, the stats _process_chunk
    against two independent reference oracles, GPU/CPU parity, and the parquet
    contract.

Out of scope / follow-ups

  • Decommissioning the standalone mlcast-dataset-sampler repo (archive, or a
    thin shim pointing here).

franchg added 8 commits June 6, 2026 13:51
Phase 1 of folding the standalone mlcast-dataset-sampler into mlcast: the
training dataset now consumes the sampler's stats-parquet contract directly
and applies importance sampling at training time.

- Add mlcast.sampling subpackage:
  - stats_spec: canonical stats-parquet contract (schema + pydantic
    StatsMetadata), shared with the offline sampler
  - samplers: pluggable per-row candidate-selection schemes via
    SAMPLER_REGISTRY (Sampler ABC, UniformSampler, ImportanceSampler).
    ImportanceSampler keeps each row with prob w/max(w) on a selectable
    stats column (default 'mean'), reshaping toward extremes, no duplication
  - units: rain-rate/reflectivity classification + wet thresholds
- Rename SourceDataPrecomputedSamplingDataset -> SourceDataIndexedDataset;
  reads a stats parquet OR legacy CSV via index_path, applies an optional
  sampler once at init (fixed, reproducible kept set)
- Inject the sampler per split from the datamodule (train_sampler vs
  eval_sampler), like augment, so val/test stay representative
- Add pyarrow + pydantic deps

Tests: parquet/CSV indexing, sampler registry + schemes, per-split
injection. Full suite green.

Note: docs/config_diagram.svg + examples/config.ipynb still need
regeneration (graphviz) to reflect the rename/index_path.
The full candidate pool is impractical for val/test on large datasets, so the
default experiment now uses UniformSampler(keep_fraction=0.1) for eval —
representative (unweighted) but bounded — while train keeps ImportanceSampler.
Folds the rest of mlcast-dataset-sampler into mlcast: the `stats` producer
(zarr scan -> stats parquet, bottleneck CPU + torch GPU windowing) and
`validate-stats`, exposed as `mlcast` subcommands.

- mlcast/sampling/commands/{stats,validate_stats,_stats_gpu}.py + console.py,
  mirroring the source layout so relative imports to stats_spec/units resolve
  unchanged.
- mlcast CLI: add `stats` and `validate-stats` argparse subcommands beside
  `train`. Their modules (and bottleneck) are imported lazily in dispatch, so
  `mlcast train` and `import mlcast.sampling` never pull the producer deps.
- pyproject: [sampling] extra = bottleneck (everything else is already core);
  GPU windowing reuses core torch.
- Port producer tests (stats_process/_gpu/_spec) with importorskip(bottleneck);
  drop test_sampling (importance_weights now inlined into ImportanceSampler).

Full suite with --extra sampling green (114 passed, 1 skipped).
bottleneck is a tiny C lib (pandas already loads it when present), so gating
the data-prep CLI behind an extra was over-engineering. Making it core lets us
delete the extra, the lazy 'install the extra' error path, the
importorskip('bottleneck') test guards, and the E402 per-file ignore. Also
drop fire, which was an unused dependency (+ its termcolor transitive).

Command modules are still imported lazily in the CLI dispatch (keeps
`mlcast train` startup light); producer tests now run by default (no --extra).
Full suite green (114 passed, 1 skipped).
`mlcast -h` took 4.2s because importing __main__ eagerly pulled the whole
training stack (torch, Fiddle, absl, the model/data config) and cli() built the
Fiddle config graph just to render help — none of which -h, stats, or
validate-stats need.

- from __future__ import annotations + TYPE_CHECKING so the heavy imports
  aren't needed for type hints.
- Move torch/Fiddle/absl/config imports and the absl flag definitions into a
  lazy _define_train_flags() + per-function imports, run only for `train`.
- Build the rich `train` help lazily (a factory on the parser), so the config
  graph is constructed only for `mlcast train -h`.

mlcast -h and import __main__: 4.22s -> 0.08s. train / train -h behaviour
unchanged; full suite green.
The patched fdl.build returns a MagicMock trainer, so the config dump
falls through to Path(trainer.default_root_dir)/config.yaml. An
unconfigured mock's __fspath__ renders as MagicMock/<name>/<id>, so
every run wrote junk under the repo root. The old line set
trainer.log_dir, which that fallback branch never reads. Point
default_root_dir at tmp_path so the write lands in pytest's tmp dir.
Add a 'Preparing training data' README section covering mlcast stats
(cumsum window scan -> stats parquet, with a flags table and example),
mlcast validate-stats, and the sampler registry (uniform / importance,
and the per-split default). Also document the missing use_ratio_splits
fiddler, cross-link from the CLI section, and refresh the project tree
with the sampling/ subpackage.
@franchg franchg requested a review from leifdenby June 6, 2026 14:37
franchg added 7 commits June 6, 2026 21:19
ruff (the pinned v0.8.6 in pre-commit) flags the missing blank line
between the third-party and first-party import groups.
uv run does not sync the dev extra, so pre-commit was never installed
(the job failed to spawn it). Run it via --extra dev, and install
graphviz for the local config-diagram hook's dot dependency.
Use the same alignment-safe clock/calendar emoji as the validate-stats
summary grid (📅 Time range, 🕒 Time step). The old ⏱️ carried a U+FE0F
variation selector that rich mismeasures, misaligning the panel's
right border.
mlcast stats assigned x to axis 1 and y to axis 2, but the MLCast
source-data spec (radar_precipitation §4.3) mandates dimension order
(time, y=height, x=width), and the training dataset crops by dimension
*name*. For a spec-compliant (time, y, x) store this transposed the
parquet's x/y columns: x offsets ran over the y axis and vice-versa,
so the dataset cropped the wrong region and ran out of bounds (e.g. a
112-wide crop instead of 256).

Bind height/step_y to axis 1 and width/step_x to axis 2 in both the
CPU and CUDA backends, and label survivor offsets y (axis 1) / x
(axis 2). The two independent test oracles are flipped to the same
convention. Verified end to end against a real (time, 1400, 1200)
zarr: regenerated parquet now has x<=944 (fits the 1200 x dim) and
y<=1144 (fits the 1400 y dim), and the dataset yields full 256x256
crops including the max-x/max-y rows.

Note: parquets produced by the old code have x/y swapped and must be
regenerated (or have their x/y columns swapped) to be consumable.
The default training config used 'rainfall_rate', which is not among
the standard_names allowed by the source-data spec (radar_precipitation
§4.4: rainfall_flux, precipitation_flux, equivalent_reflectivity_factor,
precipitation_amount, rainfall_amount). Default to 'rainfall_flux' so a
spec-compliant dataset trains without a set_variables override. Align
the dataset docstring examples and regenerate the config diagram.
SourceDataIndexedDataset slices the zarr to the split's time subset
(0-based) but kept the index's t as an absolute zarr index. For any
split not starting at t=0 (val/test), __getitem__ indexed the sliced
store with a huge absolute t, xarray clipped it to an empty crop, and
the ConvGRU encoder hit 'stack expects a non-empty TensorList'.

Filter to windows whose full depth fits the subset (no cross-split
leakage), then rebase t = t - subset_start so it indexes the sliced
store. Verified end to end: training runs to completion on the real
2010-2025 dataset, and crops carry the rain the parquet reports.
The subset test now exercises a non-zero-start slice and asserts the
rebased coordinates.
Mask a grid cell whenever it is NaN at any step of the sequence (inputs
or targets), instead of per-timestep. A temporal discontinuity at a cell
makes its forecast trajectory ill-defined — and the temporal-consistency
loss term meaningless — so the cell should not be scored anywhere in the
sequence (matching dpc-nowcasting's mask semantics).

The mask is emitted as a single (1, C, H, W) tensor; the loss broadcasts
it over the forecast steps, so no (forecast_steps, C, H, W) copy is
materialised on the GPU. Also clarify the MaskedLoss broadcast_factor
contract (the mask must broadcast into the equal-or-larger elementwise
loss). Adds tests for the collapse semantics and the masked-loss
broadcasting over a collapsed mask.
@franchg franchg force-pushed the feat/dataset-sampler-merge branch from b8bc3a4 to 8410b88 Compare June 7, 2026 15:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant