Fold the dataset sampler into mlcast (mlcast.sampling)#17
Open
franchg wants to merge 15 commits into
Open
Conversation
Phase 1 of folding the standalone mlcast-dataset-sampler into mlcast: the
training dataset now consumes the sampler's stats-parquet contract directly
and applies importance sampling at training time.
- Add mlcast.sampling subpackage:
- stats_spec: canonical stats-parquet contract (schema + pydantic
StatsMetadata), shared with the offline sampler
- samplers: pluggable per-row candidate-selection schemes via
SAMPLER_REGISTRY (Sampler ABC, UniformSampler, ImportanceSampler).
ImportanceSampler keeps each row with prob w/max(w) on a selectable
stats column (default 'mean'), reshaping toward extremes, no duplication
- units: rain-rate/reflectivity classification + wet thresholds
- Rename SourceDataPrecomputedSamplingDataset -> SourceDataIndexedDataset;
reads a stats parquet OR legacy CSV via index_path, applies an optional
sampler once at init (fixed, reproducible kept set)
- Inject the sampler per split from the datamodule (train_sampler vs
eval_sampler), like augment, so val/test stay representative
- Add pyarrow + pydantic deps
Tests: parquet/CSV indexing, sampler registry + schemes, per-split
injection. Full suite green.
Note: docs/config_diagram.svg + examples/config.ipynb still need
regeneration (graphviz) to reflect the rename/index_path.
The full candidate pool is impractical for val/test on large datasets, so the default experiment now uses UniformSampler(keep_fraction=0.1) for eval — representative (unweighted) but bounded — while train keeps ImportanceSampler.
Folds the rest of mlcast-dataset-sampler into mlcast: the `stats` producer
(zarr scan -> stats parquet, bottleneck CPU + torch GPU windowing) and
`validate-stats`, exposed as `mlcast` subcommands.
- mlcast/sampling/commands/{stats,validate_stats,_stats_gpu}.py + console.py,
mirroring the source layout so relative imports to stats_spec/units resolve
unchanged.
- mlcast CLI: add `stats` and `validate-stats` argparse subcommands beside
`train`. Their modules (and bottleneck) are imported lazily in dispatch, so
`mlcast train` and `import mlcast.sampling` never pull the producer deps.
- pyproject: [sampling] extra = bottleneck (everything else is already core);
GPU windowing reuses core torch.
- Port producer tests (stats_process/_gpu/_spec) with importorskip(bottleneck);
drop test_sampling (importance_weights now inlined into ImportanceSampler).
Full suite with --extra sampling green (114 passed, 1 skipped).
bottleneck is a tiny C lib (pandas already loads it when present), so gating
the data-prep CLI behind an extra was over-engineering. Making it core lets us
delete the extra, the lazy 'install the extra' error path, the
importorskip('bottleneck') test guards, and the E402 per-file ignore. Also
drop fire, which was an unused dependency (+ its termcolor transitive).
Command modules are still imported lazily in the CLI dispatch (keeps
`mlcast train` startup light); producer tests now run by default (no --extra).
Full suite green (114 passed, 1 skipped).
`mlcast -h` took 4.2s because importing __main__ eagerly pulled the whole training stack (torch, Fiddle, absl, the model/data config) and cli() built the Fiddle config graph just to render help — none of which -h, stats, or validate-stats need. - from __future__ import annotations + TYPE_CHECKING so the heavy imports aren't needed for type hints. - Move torch/Fiddle/absl/config imports and the absl flag definitions into a lazy _define_train_flags() + per-function imports, run only for `train`. - Build the rich `train` help lazily (a factory on the parser), so the config graph is constructed only for `mlcast train -h`. mlcast -h and import __main__: 4.22s -> 0.08s. train / train -h behaviour unchanged; full suite green.
The patched fdl.build returns a MagicMock trainer, so the config dump falls through to Path(trainer.default_root_dir)/config.yaml. An unconfigured mock's __fspath__ renders as MagicMock/<name>/<id>, so every run wrote junk under the repo root. The old line set trainer.log_dir, which that fallback branch never reads. Point default_root_dir at tmp_path so the write lands in pytest's tmp dir.
Add a 'Preparing training data' README section covering mlcast stats (cumsum window scan -> stats parquet, with a flags table and example), mlcast validate-stats, and the sampler registry (uniform / importance, and the per-split default). Also document the missing use_ratio_splits fiddler, cross-link from the CLI section, and refresh the project tree with the sampling/ subpackage.
ruff (the pinned v0.8.6 in pre-commit) flags the missing blank line between the third-party and first-party import groups.
uv run does not sync the dev extra, so pre-commit was never installed (the job failed to spawn it). Run it via --extra dev, and install graphviz for the local config-diagram hook's dot dependency.
Use the same alignment-safe clock/calendar emoji as the validate-stats summary grid (📅 Time range, 🕒 Time step). The old ⏱️ carried a U+FE0F variation selector that rich mismeasures, misaligning the panel's right border.
mlcast stats assigned x to axis 1 and y to axis 2, but the MLCast source-data spec (radar_precipitation §4.3) mandates dimension order (time, y=height, x=width), and the training dataset crops by dimension *name*. For a spec-compliant (time, y, x) store this transposed the parquet's x/y columns: x offsets ran over the y axis and vice-versa, so the dataset cropped the wrong region and ran out of bounds (e.g. a 112-wide crop instead of 256). Bind height/step_y to axis 1 and width/step_x to axis 2 in both the CPU and CUDA backends, and label survivor offsets y (axis 1) / x (axis 2). The two independent test oracles are flipped to the same convention. Verified end to end against a real (time, 1400, 1200) zarr: regenerated parquet now has x<=944 (fits the 1200 x dim) and y<=1144 (fits the 1400 y dim), and the dataset yields full 256x256 crops including the max-x/max-y rows. Note: parquets produced by the old code have x/y swapped and must be regenerated (or have their x/y columns swapped) to be consumable.
The default training config used 'rainfall_rate', which is not among the standard_names allowed by the source-data spec (radar_precipitation §4.4: rainfall_flux, precipitation_flux, equivalent_reflectivity_factor, precipitation_amount, rainfall_amount). Default to 'rainfall_flux' so a spec-compliant dataset trains without a set_variables override. Align the dataset docstring examples and regenerate the config diagram.
SourceDataIndexedDataset slices the zarr to the split's time subset (0-based) but kept the index's t as an absolute zarr index. For any split not starting at t=0 (val/test), __getitem__ indexed the sliced store with a huge absolute t, xarray clipped it to an empty crop, and the ConvGRU encoder hit 'stack expects a non-empty TensorList'. Filter to windows whose full depth fits the subset (no cross-split leakage), then rebase t = t - subset_start so it indexes the sliced store. Verified end to end: training runs to completion on the real 2010-2025 dataset, and crops carry the rain the parquet reports. The subset test now exercises a non-zero-start slice and asserts the rebased coordinates.
Mask a grid cell whenever it is NaN at any step of the sequence (inputs or targets), instead of per-timestep. A temporal discontinuity at a cell makes its forecast trajectory ill-defined — and the temporal-consistency loss term meaningless — so the cell should not be scored anywhere in the sequence (matching dpc-nowcasting's mask semantics). The mask is emitted as a single (1, C, H, W) tensor; the loss broadcasts it over the forecast steps, so no (forecast_steps, C, H, W) copy is materialised on the GPU. Also clarify the MaskedLoss broadcast_factor contract (the mask must broadcast into the equal-or-larger elementwise loss). Adds tests for the collapse semantics and the masked-loss broadcasting over a collapsed mask.
b8bc3a4 to
8410b88
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Folds the previously separate
mlcast-dataset-samplertool into themlcastpackage as a new
mlcast.samplingsubpackage. This removes the duplicationbetween the two repos and gives training a single in-package path from a source
radar Zarr to a sampled training set.
The design is a clean producer → consumer seam:
mlcast stats) — scans a source Zarr and writes a statsparquet: an index of candidate datacubes, each tagged with
nan_count/sum/mean/frac_wet. The parquet schema + metadata footer is thecontract (
mlcast.sampling.stats_spec).SourceDataIndexedDatasetreads that index, and apluggable
Samplerreshapes the candidate pool at dataset init.What's new
Data-prep CLI (folded into the existing
mlcastargparse CLI as subcommands):mlcast stats <zarr>— cumsum sliding-window scan → zstd stats parquet (CPUvia
bottleneck, or--device cudafor the GPU backend).mlcast validate-stats <parquet>— validates a parquet against the contract;prints the schema, metadata, and a row preview.
Sampler registry (
mlcast.sampling.samplers):SamplerABC +@register_sampler/get_sampler(mirrors the existingnormalization registry).
UniformSampler(keep_fraction)andImportanceSampler(column, q_min, scale, mean_weight)— per-row accept/reject,run once at init so the dataset length is fixed and reused every epoch.
SourceDataDataModule(train_sampler/eval_sampler):importance on train, uniform on val/test, so validation/test stay representative.
Training dataset:
SourceDataIndexedDatasetreads the parquetindex_path(and still accepts alegacy CSV index, used without a sampler).
Notable changes
bottleneckadded to core dependencies; the unusedfiredependency removed.mlcastCLI startup made lazy —mlcast -h~4.2s → ~0.08s by deferring thetraining stack to the
trainpath.csv_path→index_path.Testing
uv run pytest→ 114 passed, 1 skipped (the skip is the CUDA-only GPUbackend parity test).
ruff checkclean._process_chunkagainst two independent reference oracles, GPU/CPU parity, and the parquet
contract.
Out of scope / follow-ups
mlcast-dataset-samplerrepo (archive, or athin shim pointing here).