From aae960339a77b0b8055f4c77e2a652e7c7ae8c63 Mon Sep 17 00:00:00 2001 From: Gabriele Franch Date: Sat, 6 Jun 2026 13:51:11 +0200 Subject: [PATCH 01/15] Integrate dataset-sampler contract + candidate sampling into mlcast Phase 1 of folding the standalone mlcast-dataset-sampler into mlcast: the training dataset now consumes the sampler's stats-parquet contract directly and applies importance sampling at training time. - Add mlcast.sampling subpackage: - stats_spec: canonical stats-parquet contract (schema + pydantic StatsMetadata), shared with the offline sampler - samplers: pluggable per-row candidate-selection schemes via SAMPLER_REGISTRY (Sampler ABC, UniformSampler, ImportanceSampler). ImportanceSampler keeps each row with prob w/max(w) on a selectable stats column (default 'mean'), reshaping toward extremes, no duplication - units: rain-rate/reflectivity classification + wet thresholds - Rename SourceDataPrecomputedSamplingDataset -> SourceDataIndexedDataset; reads a stats parquet OR legacy CSV via index_path, applies an optional sampler once at init (fixed, reproducible kept set) - Inject the sampler per split from the datamodule (train_sampler vs eval_sampler), like augment, so val/test stay representative - Add pyarrow + pydantic deps Tests: parquet/CSV indexing, sampler registry + schemes, per-split injection. Full suite green. Note: docs/config_diagram.svg + examples/config.ipynb still need regeneration (graphviz) to reflect the rename/index_path. --- README.md | 2 +- examples/scripts/simple_train.py | 5 +- pyproject.toml | 2 + src/mlcast/config/base.py | 8 +- src/mlcast/data/__init__.py | 4 +- src/mlcast/data/source_data_datamodule.py | 14 +- src/mlcast/data/source_data_datasets.py | 57 ++++-- src/mlcast/sampling/__init__.py | 47 +++++ src/mlcast/sampling/samplers.py | 116 +++++++++++++ src/mlcast/sampling/stats_spec.py | 203 ++++++++++++++++++++++ src/mlcast/sampling/units.py | 99 +++++++++++ tests/config/test_consistency_checks.py | 6 +- tests/data/test_data_module.py | 22 +++ tests/data/test_source_data_datasets.py | 137 +++++++++++++-- tests/sampling/test_samplers.py | 91 ++++++++++ uv.lock | 4 + 16 files changed, 781 insertions(+), 36 deletions(-) create mode 100644 src/mlcast/sampling/__init__.py create mode 100644 src/mlcast/sampling/samplers.py create mode 100644 src/mlcast/sampling/stats_spec.py create mode 100644 src/mlcast/sampling/units.py create mode 100644 tests/sampling/test_samplers.py diff --git a/README.md b/README.md index 28db293..ab27301 100644 --- a/README.md +++ b/README.md @@ -257,7 +257,7 @@ experiment.run() # trainer.fit() + trainer.test() | `set_variables` | `standard_names` | Sets the list of input variables on the dataset and updates `network.input_channels` to match | | `toggle_masking` | `enabled` | Toggles masked-loss mode by setting both `dataset_factory.return_mask` and `pl_module.masked_loss` to the same value | | `use_anon_s3_dataset` | `zarr_path`, `endpoint_url` | Points the dataset at an anonymous S3 object store; sets `zarr_path` and the required `storage_options` together | -| `use_random_sampler` | *(none)* | Switches the dataset factory to the on-the-fly random sampler (useful during development when no precomputed CSV is available) | +| `use_random_sampler` | *(none)* | Switches the dataset factory to the on-the-fly random sampler (useful during development when no precomputed sampling index is available) | ## Project Structure diff --git a/examples/scripts/simple_train.py b/examples/scripts/simple_train.py index cae7a0b..903024e 100644 --- a/examples/scripts/simple_train.py +++ b/examples/scripts/simple_train.py @@ -6,13 +6,12 @@ python -m mlcast train \\ --config config:convgru_experiment \\ --config set:data.zarr_path=/path/to/data.zarr \\ - --config set:data.csv_path=/path/to/sampled.csv \\ + --config set:data.index_path=/path/to/sampled.parquet \\ --config set:data.batch_size=32 \\ --config set:trainer.max_epochs=50 """ import fiddle as fdl - from mlcast.configs import convgru_experiment @@ -20,7 +19,7 @@ def main(): # Get the config graph — all parameters are overridable via dot-access cfg = convgru_experiment.as_buildable( zarr_path="/path/to/data.zarr", - csv_path="/path/to/sampled.csv", + index_path="/path/to/sampled.parquet", variable_name="RR", ) diff --git a/pyproject.toml b/pyproject.toml index 9817fe8..7ef7064 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,8 @@ dependencies = [ "nvidia-ml-py[mlflow]>=13.595.45", "pandas>=2", "psutil[mlflow]>=7.2.2", + "pyarrow>=17", + "pydantic>=2", "pytorch-lightning>=2", "pyyaml>=6", "rich>=13", diff --git a/src/mlcast/config/base.py b/src/mlcast/config/base.py index fca8117..e3c8c99 100644 --- a/src/mlcast/config/base.py +++ b/src/mlcast/config/base.py @@ -30,9 +30,10 @@ from pytorch_lightning.loggers import TensorBoardLogger from ..data.source_data_datamodule import SourceDataDataModule -from ..data.source_data_datasets import SourceDataPrecomputedSamplingDataset +from ..data.source_data_datasets import SourceDataIndexedDataset from ..models.convgru import ConvGruModel from ..nowcasting_module import NowcastLightningModule +from ..sampling import ImportanceSampler @dataclass @@ -63,9 +64,9 @@ def training_experiment() -> Experiment: Configured experiment with model, data, and trainer. """ dataset_factory = fdl.Partial( - SourceDataPrecomputedSamplingDataset, + SourceDataIndexedDataset, zarr_path="./data/radar.zarr", - csv_path="./data/sampled_datacubes.csv", + index_path="./data/sampled_datacubes.parquet", standard_names=["rainfall_rate"], input_steps=6, forecast_steps=12, @@ -76,6 +77,7 @@ def training_experiment() -> Experiment: data = SourceDataDataModule( dataset_factory=dataset_factory, splits={"time": {"train": 0.70, "val": 0.15, "test": 0.15}}, + train_sampler=ImportanceSampler(), batch_size=16, num_workers=8, pin_memory=True, diff --git a/src/mlcast/data/__init__.py b/src/mlcast/data/__init__.py index e4b2449..8c7344a 100644 --- a/src/mlcast/data/__init__.py +++ b/src/mlcast/data/__init__.py @@ -1,4 +1,4 @@ from .source_data_datamodule import SourceDataDataModule -from .source_data_datasets import SourceDataPrecomputedSamplingDataset +from .source_data_datasets import SourceDataIndexedDataset -__all__ = ["SourceDataDataModule", "SourceDataPrecomputedSamplingDataset"] +__all__ = ["SourceDataDataModule", "SourceDataIndexedDataset"] diff --git a/src/mlcast/data/source_data_datamodule.py b/src/mlcast/data/source_data_datamodule.py index 44f99bd..8d36358 100644 --- a/src/mlcast/data/source_data_datamodule.py +++ b/src/mlcast/data/source_data_datamodule.py @@ -17,6 +17,7 @@ splitting_uses_tuple_ranges, validate_splits, ) +from mlcast.sampling import Sampler class SourceDataDataModule(pl.LightningDataModule): @@ -29,12 +30,16 @@ class SourceDataDataModule(pl.LightningDataModule): ---------- dataset_factory : Callable[..., Dataset] A factory function (e.g., ``fdl.Partial``) that returns a Dataset instance. - It must accept ``subset`` and ``augment`` as keyword arguments. + It must accept ``subset``, ``augment``, and ``sampler`` keyword arguments. splits : dict of {str: dict} Nested mapping ``{coord: {split_name: value, ...}, ...}`` describing train/val/test subsets. Currently only the ``time`` coordinate is supported. Ratio mode uses float fractions, while datetime mode uses inclusive ``(start, end)`` ISO 8601 string tuples. + train_sampler, eval_sampler : Sampler or None, optional + Per-split candidate sampler passed to the factory (train vs val/test), + like ``augment``. Default ``None`` uses the full index. Keep importance + sampling on ``train_sampler`` so val/test stay representative. **dataloader_kwargs : Any Additional keyword arguments forwarded to ``DataLoader`` (e.g., ``batch_size``, ``num_workers``, ``pin_memory``). @@ -44,11 +49,15 @@ def __init__( self, dataset_factory: Callable[..., Dataset], splits: dict[str, dict[str, Any]], + train_sampler: Sampler | None = None, + eval_sampler: Sampler | None = None, **dataloader_kwargs: Any, ) -> None: super().__init__() self.dataset_factory = dataset_factory self.splits = splits + self.train_sampler = train_sampler + self.eval_sampler = eval_sampler self.dataloader_kwargs = dataloader_kwargs validate_splits(self.splits) @@ -126,6 +135,7 @@ def setup(self, stage: str | None = None) -> None: subset_per_split[split_name][coord] = split_val augment_flags = {"train": True, "val": False, "test": False} + sampler_flags = {"train": self.train_sampler, "val": self.eval_sampler, "test": self.eval_sampler} for split in ("train", "val", "test"): subset = subset_per_split[split] if subset is None: @@ -134,7 +144,7 @@ def setup(self, stage: str | None = None) -> None: setattr( self, f"{split}_dataset", - self.dataset_factory(subset=subset, augment=augment_flags[split]), + self.dataset_factory(subset=subset, augment=augment_flags[split], sampler=sampler_flags[split]), ) logger.info("{}.setup() complete, containing:", self.__class__.__name__) diff --git a/src/mlcast/data/source_data_datasets.py b/src/mlcast/data/source_data_datasets.py index b8e3c80..1e83b58 100644 --- a/src/mlcast/data/source_data_datasets.py +++ b/src/mlcast/data/source_data_datasets.py @@ -1,16 +1,19 @@ """PyTorch datasets for loading spatio-temporal data from Zarr stores. -Provides pre-computed sampling and (soon) random sampling datasets. +Provides an indexed dataset (crops from a precomputed index) and a +random-sampling dataset. """ import time import warnings from abc import ABC, abstractmethod +from pathlib import Path from typing import Any, TypedDict import cf_xarray # noqa: F401 import numpy as np import pandas as pd +import pyarrow.parquet as pq import torch import xarray as xr from beartype import beartype @@ -18,6 +21,25 @@ from torch.utils.data import Dataset from mlcast.data.normalization import NORMALIZATION_REGISTRY +from mlcast.sampling import Sampler + + +def _load_sampling_index(path: str) -> pd.DataFrame: + """Load a precomputed sampling index as a DataFrame. + + Accepts a stats parquet (the dataset sampler's output) or a legacy + ``.csv``. Returns at least the ``t, x, y`` crop-corner columns, plus the + per-datacube ``mean`` column when the file carries it (parquet only) — the + latter feeds importance sampling. Only the needed columns are read. + """ + suffix = Path(path).suffix.lower() + if suffix == ".csv": + return pd.read_csv(path) + if suffix in (".parquet", ".pq"): + available = set(pq.read_schema(path).names) + columns = [c for c in ("t", "x", "y", "mean") if c in available] + return pq.read_table(path, columns=columns).to_pandas() + raise ValueError(f"Unsupported sampling index format {suffix!r} for {path!r}; expected .parquet or .csv") def _time_range_to_index_slice( @@ -311,9 +333,8 @@ def __len__(self) -> int: ... def __getitem__(self, idx: int) -> DatasetSample: ... -class SourceDataPrecomputedSamplingDataset(SourceDataDatasetBase): - """PyTorch dataset that loads spatio-temporal data from a Zarr store using - pre-sampled spatial-temporal coordinates from a CSV file. +class SourceDataIndexedDataset(SourceDataDatasetBase): + """PyTorch dataset yielding Zarr crops at locations read from a precomputed index. Each sample is a spatio-temporal crop of shape ``(T, C, H, W)`` converted to normalized data. @@ -322,9 +343,10 @@ class SourceDataPrecomputedSamplingDataset(SourceDataDatasetBase): ---------- zarr_path : str Path to the Zarr dataset. - csv_path : str - Path to the CSV file with columns ``(t, x, y)`` specifying the - top-left corner of each crop. + index_path : str + Path to the sampling index of ``(t, x, y)`` crop corners: a stats + parquet (the candidate pool, optionally filtered by ``sampler``) or a + legacy ``.csv`` (already sampled, used as-is). standard_names : list of str List of CF standard names of variables to load (e.g., ``["rainfall_rate"]``). input_steps : int @@ -347,12 +369,18 @@ class SourceDataPrecomputedSamplingDataset(SourceDataDatasetBase): Spatial height of each crop. Default is ``256``. time_depth : int, optional Number of timesteps in the sampled window. Default is ``24``. + sampler : Sampler or None, optional + Optional sampler to filter the candidate index with a chosen strategy, + applied once at init (see :mod:`mlcast.sampling.samplers`). Default + ``None`` keeps every candidate. + sampling_seed : int, optional + Seed for the sampler's one-time selection. Default ``42``. """ def __init__( self, zarr_path: str, - csv_path: str, + index_path: str, standard_names: list[str], input_steps: int, forecast_steps: int, @@ -364,6 +392,8 @@ def __init__( height: int = 256, time_depth: int = 24, storage_options: dict[str, Any] | None = None, + sampler: Sampler | None = None, + sampling_seed: int = 42, ) -> None: if subset: for key in subset: @@ -389,7 +419,7 @@ def __init__( storage_options=storage_options, ) - self.coords = pd.read_csv(csv_path).sort_values("t") + self.coords = _load_sampling_index(index_path).sort_values("t") if self._time_index_slice is not None: t_start = self._time_index_slice.start t_stop = self._time_index_slice.stop @@ -397,6 +427,10 @@ def __init__( drop=True ) + if sampler is not None: + selected = sampler.select(self.coords, np.random.default_rng(sampling_seed)) + self.coords = self.coords.iloc[selected].reset_index(drop=True) + self.dt = time_depth if self.steps > self.dt: @@ -431,7 +465,8 @@ def __getitem__(self, idx: int) -> DatasetSample: ``(forecast_steps, C, H, W)`` with 1 where the original data was valid and 0 where it was NaN. """ - t0, x0, y0 = self.coords.iloc[idx] + row = self.coords.iloc[idx] + t0, x0, y0 = row["t"], row["x"], row["y"] x_slice = slice(int(x0), int(x0) + self.w) y_slice = slice(int(y0), int(y0) + self.h) @@ -488,7 +523,7 @@ class SourceDataRandomSamplingDataset(SourceDataDatasetBase): epoch_size : int, optional Number of random samples to generate per epoch. Default is ``1000``. **kwargs : Any - Ignored extra arguments (e.g. ``csv_path``) to allow drop-in replacement. + Ignored extra arguments (e.g. ``index_path``) to allow drop-in replacement. """ def __init__( diff --git a/src/mlcast/sampling/__init__.py b/src/mlcast/sampling/__init__.py new file mode 100644 index 0000000..483c47d --- /dev/null +++ b/src/mlcast/sampling/__init__.py @@ -0,0 +1,47 @@ +"""Datacube sampling support for mlcast. + +This subpackage holds the contract and helpers shared between the offline +*dataset sampler* (which scans a source Zarr and writes a per-datacube stats +parquet) and the training-time dataset that consumes that parquet: + +- :mod:`mlcast.sampling.stats_spec` — the canonical stats-parquet contract + (column schema + validated metadata), read by the precomputed-sampling + dataset instead of re-parsing a filename. +- :mod:`mlcast.sampling.samplers` — pluggable candidate-selection schemes + (``Sampler`` + ``SAMPLER_REGISTRY``), e.g. :class:`ImportanceSampler`. +- :mod:`mlcast.sampling.units` — rain-rate vs reflectivity classification and + default wet-pixel thresholds, from CF attributes. +""" + +from .samplers import ( + SAMPLER_REGISTRY, + ImportanceSampler, + Sampler, + UniformSampler, + get_sampler, + register_sampler, +) +from .stats_spec import ( + STATS_SCHEMA, + StatsMetadata, + ValidationReport, + read_metadata, + validate_stats_parquet, +) +from .units import default_wet_threshold, detect_data_kind + +__all__ = [ + "SAMPLER_REGISTRY", + "STATS_SCHEMA", + "ImportanceSampler", + "Sampler", + "StatsMetadata", + "UniformSampler", + "ValidationReport", + "default_wet_threshold", + "detect_data_kind", + "get_sampler", + "read_metadata", + "register_sampler", + "validate_stats_parquet", +] diff --git a/src/mlcast/sampling/samplers.py b/src/mlcast/sampling/samplers.py new file mode 100644 index 0000000..d3d232c --- /dev/null +++ b/src/mlcast/sampling/samplers.py @@ -0,0 +1,116 @@ +"""Pluggable candidate-selection schemes for the precomputed-sampling dataset. + +A :class:`Sampler` filters the candidate pool (the rows of a stats parquet) to +a training subset once, at dataset init. Add a scheme by subclassing +:class:`Sampler`, decorating it with ``@register_sampler("name")``, and +implementing :meth:`Sampler.select`; it is then available via +:func:`get_sampler` (e.g. from a config). +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod + +import numpy as np +import pandas as pd + + +class Sampler(ABC): + """Selects a subset of candidate rows via a per-row keep/discard decision.""" + + @abstractmethod + def select(self, coords: pd.DataFrame, rng: np.random.Generator) -> np.ndarray: + """Return the positions of the kept rows (each selected at most once).""" + + +SAMPLER_REGISTRY: dict[str, type[Sampler]] = {} + + +def register_sampler(name: str): + """Class decorator registering a :class:`Sampler` subclass under ``name``.""" + + def decorator(cls: type[Sampler]) -> type[Sampler]: + SAMPLER_REGISTRY[name] = cls + cls.sampler_name = name + return cls + + return decorator + + +@register_sampler("uniform") +class UniformSampler(Sampler): + """Keep each candidate with a fixed probability, independent of its stats. + + Parameters + ---------- + keep_fraction : float + Per-row keep probability in ``[0, 1]``. ``1.0`` (default) keeps the + whole pool; smaller values take a random uniform subsample. + """ + + def __init__(self, keep_fraction: float = 1.0) -> None: + if not 0.0 <= keep_fraction <= 1.0: + raise ValueError(f"keep_fraction must be in [0, 1], got {keep_fraction}") + self.keep_fraction = keep_fraction + + def select(self, coords: pd.DataFrame, rng: np.random.Generator) -> np.ndarray: + if self.keep_fraction >= 1.0: + return np.arange(len(coords)) + return np.flatnonzero(rng.random(len(coords)) < self.keep_fraction) + + +@register_sampler("importance") +class ImportanceSampler(Sampler): + """Keep each candidate with probability ``w / w.max()``, where the weight + ``w = q_min + mean_weight * (1 - exp(-s / scale))`` rises with a per-row + statistic ``s`` (the ``column``). High-statistic datacubes are kept + preferentially and common ones thinned out, without duplication. Needs the + chosen ``column`` (a legacy CSV index has none). + + Parameters + ---------- + column : str + The stats-parquet column to weight on, e.g. ``"mean"`` (default), + ``"sum"``, or ``"frac_wet"``. + q_min : float + Floor weight on every candidate, keeping some low-statistic windows. + scale : float + Saturation scale of ``1 - exp(-s / scale)``; set it on the order of the + column's typical magnitude (``mean``/``frac_wet`` ~ O(1); ``sum`` large). + mean_weight : float + Weight given to the statistic; relative to ``q_min`` it sets how hard + low-statistic windows are thinned versus high ones. + """ + + def __init__(self, column: str = "mean", q_min: float = 1e-4, scale: float = 1.0, mean_weight: float = 0.1) -> None: + if scale <= 0: + raise ValueError(f"scale must be positive, got {scale}") + if q_min < 0 or mean_weight < 0: + raise ValueError(f"q_min and mean_weight must be non-negative, got {q_min}, {mean_weight}") + self.column = column + self.q_min = q_min + self.scale = scale + self.mean_weight = mean_weight + + def select(self, coords: pd.DataFrame, rng: np.random.Generator) -> np.ndarray: + if self.column not in coords.columns: + raise ValueError( + f"ImportanceSampler needs the {self.column!r} statistic column, absent from this " + f"index (columns: {list(coords.columns)}); a legacy CSV index carries only " + f"(t, x, y), so use it without a sampler (sampler=None)." + ) + # floor weight + a saturating response to the statistic; NaNs floored to q_min + stat = np.nan_to_num(coords[self.column].to_numpy(dtype=float), nan=0.0) + weights = self.q_min + self.mean_weight * (1.0 - np.exp(-stat / self.scale)) + w_max = weights.max(initial=0.0) + probs = weights / w_max if w_max > 0 else np.zeros_like(weights) + return np.flatnonzero(rng.random(len(coords)) < probs) + + +def get_sampler(name: str, **kwargs) -> Sampler: + """Construct a registered sampler by name, e.g. ``get_sampler("importance", scale=2.0)``.""" + try: + cls = SAMPLER_REGISTRY[name] + except KeyError: + raise ValueError(f"Unknown sampler {name!r}; available: {sorted(SAMPLER_REGISTRY)}") from None + return cls(**kwargs) diff --git a/src/mlcast/sampling/stats_spec.py b/src/mlcast/sampling/stats_spec.py new file mode 100644 index 0000000..a77fab8 --- /dev/null +++ b/src/mlcast/sampling/stats_spec.py @@ -0,0 +1,203 @@ +"""Canonical contract for a stats parquet file. + +A "stats parquet" is the output of the dataset sampler's ``stats`` command: +one row per surviving datacube candidate, plus schema-level JSON metadata +carrying the sampling parameters. Both halves of that contract live here: + +- ``STATS_SCHEMA`` — the column layout, which the sampler writes from. +- ``StatsMetadata`` — a pydantic model of the sampling parameters, with + field- and cross-field validation. Consumers read it via + :func:`read_metadata` instead of re-parsing a filename. + +:func:`validate_stats_parquet` checks an arbitrary file against both. +""" + +from __future__ import annotations + +import json +from datetime import datetime +from typing import Literal + +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq +from pydantic import ( + BaseModel, + ConfigDict, + NonNegativeFloat, + NonNegativeInt, + PositiveInt, + model_validator, +) + +SCHEMA_VERSION = 1 +STATS_METADATA_KEY = b"mlcast.stats" + +#: Single source of truth for the column layout of a stats parquet. +STATS_SCHEMA: pa.Schema = pa.schema( + [ + ("t", pa.int32()), + ("x", pa.int32()), + ("y", pa.int32()), + ("nan_count", pa.int32()), + ("sum", pa.float32()), + ("mean", pa.float32()), + ("frac_wet", pa.float32()), + ] +) + +#: Column names in canonical order, derived from STATS_SCHEMA. +STAT_COLUMNS: tuple[str, ...] = tuple(STATS_SCHEMA.names) + + +class StatsMetadata(BaseModel): + """Sampling parameters carried in a stats parquet's schema metadata. + + Field names match the JSON keys stored under ``mlcast.stats``. Build from + a raw dict with ``StatsMetadata.model_validate(payload)`` (unknown keys are + ignored); serialise with ``model_dump(mode="json")``. Constraints are + declared on the field types; the only imperative rules are cross-field. + """ + + model_config = ConfigDict(frozen=True) + + zarr_path: str + data_var: str + time_var: str + start_date: datetime + end_date: datetime + time_step_minutes: PositiveInt + time_depth: PositiveInt + width: PositiveInt + height: PositiveInt + step_t: PositiveInt + step_x: PositiveInt + step_y: PositiveInt + max_nan: NonNegativeInt + wet_threshold: NonNegativeFloat + data_kind: Literal["rainrate", "reflectivity"] # mirrors units.DEFAULT_WET_THRESHOLD + units: str | None = None + schema_version: int = SCHEMA_VERSION + + @property + def total_px(self) -> int: + """Number of pixels in one datacube (time_depth * width * height).""" + return self.time_depth * self.width * self.height + + @model_validator(mode="after") + def _check_cross_field(self) -> StatsMetadata: + if self.max_nan > self.total_px: + raise ValueError(f"max_nan ({self.max_nan}) exceeds datacube size ({self.total_px})") + if self.start_date > self.end_date: + raise ValueError(f"start_date ({self.start_date}) is after end_date ({self.end_date})") + return self + + +def build_schema(metadata: StatsMetadata) -> pa.Schema: + """Return STATS_SCHEMA with this file's metadata attached. + + This is what the sampler hands to its ``ParquetWriter`` so the column + layout and the metadata payload come from one place. + """ + payload = metadata.model_dump(mode="json") + encoded = {STATS_METADATA_KEY: json.dumps(payload, sort_keys=True).encode()} + return STATS_SCHEMA.with_metadata(encoded) + + +def read_metadata(path: str) -> StatsMetadata: + """Load and validate the ``mlcast.stats`` metadata from a parquet file. + + Raises + ------ + KeyError + If the file carries no ``mlcast.stats`` metadata key. + pydantic.ValidationError + If the payload is malformed or violates a constraint. + """ + schema = pq.read_schema(path) + if schema.metadata is None or STATS_METADATA_KEY not in schema.metadata: + raise KeyError( + f"{path}: no 'mlcast.stats' metadata found in parquet schema. " + f"Not a stats file produced by the mlcast dataset sampler?" + ) + payload = json.loads(schema.metadata[STATS_METADATA_KEY].decode()) + return StatsMetadata.model_validate(payload) + + +class ValidationReport(BaseModel): + """Result of validating a stats parquet file.""" + + path: str + errors: list[str] = [] + warnings: list[str] = [] + + @property + def ok(self) -> bool: + return not self.errors + + +def validate_stats_parquet(path: str, *, check_data: bool = True) -> ValidationReport: + """Validate a stats parquet file against the canonical contract. + + Always checks the column schema (names + Arrow dtypes) and the metadata + payload. When ``check_data`` is true, additionally reads the columns and + asserts per-row value invariants. Returns a :class:`ValidationReport`; + never raises for contract violations (only for unreadable files). + """ + report = ValidationReport(path=path) + + # --- Metadata (ValidationError is a ValueError subclass) ------------------ + meta: StatsMetadata | None = None + try: + meta = read_metadata(path) + except KeyError as e: + report.errors.append(str(e)) + except ValueError as e: + report.errors.append(f"metadata: {e}") + + if meta is not None and meta.schema_version != SCHEMA_VERSION: + report.warnings.append(f"schema_version {meta.schema_version} != current {SCHEMA_VERSION}") + + # --- Column schema -------------------------------------------------------- + schema = pq.read_schema(path) + actual = dict(zip(schema.names, schema.types, strict=True)) + missing = [c for c in STATS_SCHEMA.names if c not in actual] + if missing: + report.errors.append(f"missing columns: {missing}") + for name in STATS_SCHEMA.names: + expected = STATS_SCHEMA.field(name).type + if name in actual and actual[name] != expected: + report.errors.append(f"column {name!r} has dtype {actual[name]}, expected {expected}") + + if not check_data or missing: + return report + + # --- Per-row value sanity (full read) ------------------------------------ + table = pq.read_table(path, columns=list(STATS_SCHEMA.names)) + if table.num_rows == 0: + report.warnings.append("file has zero rows") + return report + cols = {name: table.column(name).to_numpy() for name in STATS_SCHEMA.names} + + for name in ("t", "x", "y", "nan_count"): + if (cols[name] < 0).any(): + report.errors.append(f"column {name!r} has negative values") + + nan_count = cols["nan_count"] + if meta is not None: + if (nan_count > meta.total_px).any(): + report.errors.append(f"nan_count exceeds datacube size {meta.total_px} for some rows") + if (nan_count > meta.max_nan).any(): + report.errors.append( + f"nan_count exceeds metadata max_nan ({meta.max_nan}) for some rows " + f"(the stats command applies this as a hard filter)" + ) + + frac_wet = cols["frac_wet"] + if (~((frac_wet >= 0.0) & (frac_wet <= 1.0))).any(): + report.errors.append("frac_wet has values outside [0, 1]") + for name in ("sum", "mean", "frac_wet"): + if np.isinf(cols[name]).any(): + report.errors.append(f"column {name!r} contains +/-inf") + + return report diff --git a/src/mlcast/sampling/units.py b/src/mlcast/sampling/units.py new file mode 100644 index 0000000..20e9b95 --- /dev/null +++ b/src/mlcast/sampling/units.py @@ -0,0 +1,99 @@ +"""Detect whether a zarr data variable is rain rate or radar reflectivity. + +Different MLCast source datasets use different units — e.g. IT-DPC stores +rainfall flux in `kg m-2 h-1` (equivalent to mm/h), DMI stores radar +reflectivity in `dBZ`. Sampling statistics that depend on a "wet pixel" +threshold (e.g. `frac_wet`) need a threshold in the same units as the +data, so this module inspects CF-convention attributes to classify the +variable and returns a sensible default threshold. + +The classification order is: +1. `standard_name` (most authoritative, CF convention). +2. `units` (normalized to lowercase, whitespace-stripped). +3. Fallback: raise, unless the caller provides an override. +""" + +from __future__ import annotations + +from collections.abc import Mapping + +RAIN_RATE_STANDARD_NAMES = { + "rainfall_flux", + "rainfall_rate", + "precipitation_flux", + "lwe_precipitation_rate", +} +REFLECTIVITY_STANDARD_NAMES = { + "equivalent_reflectivity_factor", + "radar_reflectivity", +} + +RAIN_RATE_UNITS = { + "mm/h", + "mmh-1", + "mmh1", + "mm/hr", + "mmhr-1", + "mmh^-1", + "kgm-2h-1", + "kgm2h1", +} +REFLECTIVITY_UNITS = {"dbz", "db"} + +# Defaults chosen to match conventional "meaningful rain" cutoffs: +# - 0.1 mm/h is the standard drizzle floor +# - 7 dBZ corresponds to R ≈ 0.08 mm/h via Marshall-Palmer, matching it +DEFAULT_WET_THRESHOLD = { + "rainrate": 0.1, + "reflectivity": 7.0, +} + + +def _normalize_units(units: str | None) -> str: + if not units: + return "" + return units.strip().lower().replace(" ", "").replace("**", "^") + + +def detect_data_kind(attrs: Mapping[str, object]) -> str: + """Classify a variable as 'rainrate' or 'reflectivity' from its attrs. + + Parameters + ---------- + attrs + The attribute dict of a zarr/xarray variable. + + Returns + ------- + 'rainrate' or 'reflectivity' + + Raises + ------ + ValueError + If the attributes don't match any known rain-rate or reflectivity + indicator. In that case the caller should fall back to an explicit + CLI override. + """ + std = str(attrs.get("standard_name", "") or "").strip().lower() + if std in RAIN_RATE_STANDARD_NAMES: + return "rainrate" + if std in REFLECTIVITY_STANDARD_NAMES: + return "reflectivity" + + units = _normalize_units(str(attrs.get("units", "") or "")) + if units in RAIN_RATE_UNITS: + return "rainrate" + if units in REFLECTIVITY_UNITS: + return "reflectivity" + + raise ValueError( + f"Cannot auto-detect data kind: standard_name={std!r}, units={units!r}. Pass --data-kind explicitly." + ) + + +def default_wet_threshold(data_kind: str) -> float: + """Conventional wet-pixel threshold for a given data kind.""" + try: + return DEFAULT_WET_THRESHOLD[data_kind] + except KeyError: + raise ValueError(f"Unknown data_kind {data_kind!r}") from None diff --git a/tests/config/test_consistency_checks.py b/tests/config/test_consistency_checks.py index 3dc8ec3..a37c68b 100644 --- a/tests/config/test_consistency_checks.py +++ b/tests/config/test_consistency_checks.py @@ -4,7 +4,7 @@ from loguru import logger from mlcast.config import training_experiment, validate_config -from mlcast.data.source_data_datasets import SourceDataPrecomputedSamplingDataset +from mlcast.data.source_data_datasets import SourceDataIndexedDataset def test_contract_1_input_channels() -> None: @@ -74,9 +74,9 @@ def test_contract_4_masking_sync() -> None: def test_dataset_forecast_steps_guard() -> None: """Verify that dataset raises ValueError when input_steps=0.""" with pytest.raises(ValueError, match="input_steps"): - SourceDataPrecomputedSamplingDataset( + SourceDataIndexedDataset( zarr_path="dummy.zarr", - csv_path="dummy.csv", + index_path="dummy.csv", standard_names=["rainfall_rate"], input_steps=0, forecast_steps=5, diff --git a/tests/data/test_data_module.py b/tests/data/test_data_module.py index e55bf58..1e3f1d6 100644 --- a/tests/data/test_data_module.py +++ b/tests/data/test_data_module.py @@ -7,6 +7,7 @@ from mlcast.data.source_data_datamodule import SourceDataDataModule from mlcast.data.splits import splitting_uses_fractions, splitting_uses_tuple_ranges, validate_splits +from mlcast.sampling import UniformSampler class MockDataset(Dataset): @@ -329,3 +330,24 @@ def test_data_module_unsupported_split_mode() -> None: with pytest.raises(NotImplementedError, match="Unsupported split mode"): dm.setup() + + +def test_data_module_injects_train_sampler_only_on_train() -> None: + """train gets train_sampler; val/test get eval_sampler (representative).""" + train_s = UniformSampler(keep_fraction=0.5) + eval_s = UniformSampler(keep_fraction=1.0) + time_index = _make_time_index(100) + dm = SourceDataDataModule( + dataset_factory=functools.partial(MockDataset, zarr_path="mock.zarr"), + splits={"time": {"train": 0.5, "val": 0.2, "test": 0.3}}, + train_sampler=train_s, + eval_sampler=eval_s, + batch_size=2, + ) + + with patch("mlcast.data.splits.xr.open_zarr", return_value=_mock_zarr(time_index)): + dm.setup() + + assert dm.train_dataset.kwargs["sampler"] is train_s + assert dm.val_dataset.kwargs["sampler"] is eval_s + assert dm.test_dataset.kwargs["sampler"] is eval_s diff --git a/tests/data/test_source_data_datasets.py b/tests/data/test_source_data_datasets.py index c9a2f2f..d3a9f3e 100644 --- a/tests/data/test_source_data_datasets.py +++ b/tests/data/test_source_data_datasets.py @@ -1,14 +1,19 @@ from pathlib import Path +import numpy as np import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq import pytest import torch import xarray as xr from mlcast.data.source_data_datasets import ( - SourceDataPrecomputedSamplingDataset, + SourceDataIndexedDataset, SourceDataRandomSamplingDataset, ) +from mlcast.sampling import ImportanceSampler, UniformSampler +from mlcast.sampling.stats_spec import StatsMetadata, build_schema @pytest.fixture @@ -26,13 +31,123 @@ def mock_csv(tmp_path: Path) -> str: return str(csv_path) -def test_precomputed_sampling_dataset(fp_test_dataset: Path, mock_csv: str) -> None: - """Test that SourceDataPrecomputedSamplingDataset outputs the correct shape.""" +@pytest.fixture +def mock_parquet(tmp_path: Path) -> str: + """Create a temporary stats parquet (the sampler's output) with a mean column.""" + meta = StatsMetadata( + zarr_path="dummy.zarr", + data_var="RR", + time_var="time", + start_date="2016-01-01", + end_date="2016-12-31", + time_step_minutes=5, + time_depth=2, + width=16, + height=16, + step_t=1, + step_x=1, + step_y=1, + max_nan=0, + wet_threshold=0.1, + data_kind="rainrate", + units="mm/h", + ) + table = pa.table( + { + "t": pa.array([0, 5, 10], pa.int32()), + "x": pa.array([10, 20, 30], pa.int32()), + "y": pa.array([10, 20, 30], pa.int32()), + "nan_count": pa.array([0, 0, 0], pa.int32()), + "sum": pa.array([1.0, 2.0, 3.0], pa.float32()), + "mean": pa.array([0.05, 2.0, 10.0], pa.float32()), + "frac_wet": pa.array([0.0, 0.1, 0.5], pa.float32()), + }, + schema=build_schema(meta), + ) + parquet_path = tmp_path / "stats.parquet" + pq.write_table(table, parquet_path) + return str(parquet_path) + + +def test_indexed_sampling_dataset_parquet(fp_test_dataset: Path, mock_parquet: str) -> None: + """Reads a stats parquet index; with no sampler the full pool is used.""" + ds = SourceDataIndexedDataset( + zarr_path=str(fp_test_dataset), + index_path=mock_parquet, + standard_names=["rainfall_flux"], + input_steps=2, + forecast_steps=1, + width=16, + height=16, + return_mask=True, + ) + + assert len(ds) == 3 # all candidates, as-is + sample = ds[0] + assert sample["input"].shape == (2, 1, 16, 16) + assert sample["target"].shape == (1, 1, 16, 16) + + +def test_indexed_importance_selection_is_fixed_and_keeps_extremes(fp_test_dataset: Path, mock_parquet: str) -> None: + """ImportanceSampler selects a fixed, reproducible subset that keeps extremes.""" + kwargs = dict( + zarr_path=str(fp_test_dataset), + index_path=mock_parquet, + standard_names=["rainfall_flux"], + input_steps=2, + forecast_steps=1, + width=16, + height=16, + sampler=ImportanceSampler(), + sampling_seed=0, + ) + ds = SourceDataIndexedDataset(**kwargs) + + # a subset of the 3 candidates, with the wettest (t=10) always kept + assert 1 <= len(ds) <= 3 + assert 10 in ds.coords["t"].to_numpy() + # reproducible: same seed -> identical kept set + ds2 = SourceDataIndexedDataset(**kwargs) + assert np.array_equal(ds.coords["t"].to_numpy(), ds2.coords["t"].to_numpy()) + + +def test_indexed_importance_sampler_requires_mean_column(fp_test_dataset: Path, mock_csv: str) -> None: + """ImportanceSampler rejects a CSV index that has no mean column.""" + with pytest.raises(ValueError, match="mean"): + SourceDataIndexedDataset( + zarr_path=str(fp_test_dataset), + index_path=mock_csv, + standard_names=["rainfall_flux"], + input_steps=2, + forecast_steps=1, + width=16, + height=16, + sampler=ImportanceSampler(), + ) + + +def test_indexed_uniform_sampler_works_on_csv(fp_test_dataset: Path, mock_csv: str) -> None: + """A non-importance sampler (uniform) needs no mean column, so works on a CSV.""" + ds = SourceDataIndexedDataset( + zarr_path=str(fp_test_dataset), + index_path=mock_csv, + standard_names=["rainfall_flux"], + input_steps=2, + forecast_steps=1, + width=16, + height=16, + sampler=UniformSampler(keep_fraction=1.0), + ) + assert len(ds) == 3 # keep_fraction=1.0 -> the whole (3-row) index + + +def test_indexed_sampling_dataset(fp_test_dataset: Path, mock_csv: str) -> None: + """Test that SourceDataIndexedDataset outputs the correct shape.""" input_steps = 2 forecast_steps = 1 - ds = SourceDataPrecomputedSamplingDataset( + ds = SourceDataIndexedDataset( zarr_path=str(fp_test_dataset), - csv_path=mock_csv, + index_path=mock_csv, standard_names=["rainfall_flux"], input_steps=input_steps, forecast_steps=forecast_steps, @@ -60,13 +175,13 @@ def test_precomputed_sampling_dataset(fp_test_dataset: Path, mock_csv: str) -> N assert isinstance(target_mask_t, torch.Tensor) -def test_precomputed_sampling_dataset_time_subset(fp_test_dataset: Path, mock_csv: str) -> None: +def test_indexed_sampling_dataset_time_subset(fp_test_dataset: Path, mock_csv: str) -> None: """Test that subset correctly filters CSV rows by time range.""" zarr_ds = xr.open_zarr(str(fp_test_dataset)) time_index = zarr_ds.indexes["time"] - ds = SourceDataPrecomputedSamplingDataset( + ds = SourceDataIndexedDataset( zarr_path=str(fp_test_dataset), - csv_path=mock_csv, + index_path=mock_csv, standard_names=["rainfall_flux"], input_steps=2, forecast_steps=1, @@ -75,12 +190,12 @@ def test_precomputed_sampling_dataset_time_subset(fp_test_dataset: Path, mock_cs assert len(ds) == 2 -def test_precomputed_sampling_dataset_forecast_steps_guard(fp_test_dataset: Path, mock_csv: str) -> None: +def test_indexed_sampling_dataset_forecast_steps_guard(fp_test_dataset: Path, mock_csv: str) -> None: """Test that instantiation with input_steps=0 raises ValueError.""" with pytest.raises(ValueError, match="input_steps"): - SourceDataPrecomputedSamplingDataset( + SourceDataIndexedDataset( zarr_path=str(fp_test_dataset), - csv_path=mock_csv, + index_path=mock_csv, standard_names=["rainfall_flux"], input_steps=0, forecast_steps=3, diff --git a/tests/sampling/test_samplers.py b/tests/sampling/test_samplers.py new file mode 100644 index 0000000..d956612 --- /dev/null +++ b/tests/sampling/test_samplers.py @@ -0,0 +1,91 @@ +"""Unit tests for the candidate-selection schemes and the sampler registry.""" + +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest + +from mlcast.sampling import ImportanceSampler, UniformSampler, get_sampler +from mlcast.sampling.samplers import SAMPLER_REGISTRY + + +def _dry_heavy_pool(n_dry: int = 900, n_wet: int = 100) -> pd.DataFrame: + """Mostly-dry pool with a wet tail (mean 0.01 vs 20.0).""" + mean = np.concatenate([np.full(n_dry, 0.01), np.full(n_wet, 20.0)]) + return pd.DataFrame({"t": np.arange(len(mean)), "x": 0, "y": 0, "mean": mean}) + + +def test_importance_selection_reshapes_toward_extremes_and_is_reproducible() -> None: + pool = _dry_heavy_pool() + sampler = ImportanceSampler() + kept = sampler.select(pool, np.random.default_rng(0)) + + # a subset, each row at most once + assert kept.ndim == 1 and len(kept) <= len(pool) + assert len(np.unique(kept)) == len(kept) + # wet rows are only 10% of the pool but dominate the kept set + wet_frac_pool = (pool["mean"].to_numpy() > 1.0).mean() + wet_frac_kept = (pool["mean"].to_numpy()[kept] > 1.0).mean() + assert wet_frac_pool == pytest.approx(0.1) + assert wet_frac_kept > 0.5 + # reproducible given the rng + again = sampler.select(pool, np.random.default_rng(0)) + assert np.array_equal(kept, again) + + +def test_importance_tuning_changes_how_much_is_kept() -> None: + pool = _dry_heavy_pool() + # a higher floor (q_min) keeps more of the dry majority + low = ImportanceSampler(q_min=1e-4).select(pool, np.random.default_rng(1)) + high = ImportanceSampler(q_min=0.05).select(pool, np.random.default_rng(1)) + assert len(high) > len(low) + + +def test_importance_sampler_requires_mean_column() -> None: + pool = pd.DataFrame({"t": [0, 1], "x": [0, 0], "y": [0, 0]}) + with pytest.raises(ValueError, match="mean"): + ImportanceSampler().select(pool, np.random.default_rng(0)) + + +def test_importance_sampler_can_weight_on_a_different_column() -> None: + # 'mean' is flat (would keep ~all); we instead weight on 'sum', whose high + # tail should dominate the kept set + pool = pd.DataFrame( + { + "t": np.arange(1000), + "x": 0, + "y": 0, + "mean": 0.0, + "sum": np.concatenate([np.full(900, 1.0), np.full(100, 1000.0)]), + } + ) + kept = ImportanceSampler(column="sum", scale=1000.0).select(pool, np.random.default_rng(0)) + assert (pool["sum"].to_numpy()[kept] > 100).mean() > 0.5 + + +def test_importance_sampler_missing_chosen_column_raises() -> None: + pool = pd.DataFrame({"t": [0, 1], "x": 0, "y": 0, "mean": [0.1, 0.2]}) + with pytest.raises(ValueError, match="frac_wet"): + ImportanceSampler(column="frac_wet").select(pool, np.random.default_rng(0)) + + +def test_uniform_sampler_keep_fraction() -> None: + pool = pd.DataFrame({"t": np.arange(1000), "x": 0, "y": 0}) # no mean column needed + assert len(UniformSampler(keep_fraction=1.0).select(pool, np.random.default_rng(0))) == 1000 + half = UniformSampler(keep_fraction=0.5).select(pool, np.random.default_rng(0)) + assert 400 < len(half) < 600 + + +def test_uniform_sampler_rejects_bad_fraction() -> None: + with pytest.raises(ValueError, match="keep_fraction"): + UniformSampler(keep_fraction=1.5) + + +def test_registry_lookup_and_unknown() -> None: + assert {"importance", "uniform"} <= set(SAMPLER_REGISTRY) + sampler = get_sampler("importance", scale=2.0) + assert isinstance(sampler, ImportanceSampler) and sampler.scale == 2.0 + assert isinstance(get_sampler("uniform"), UniformSampler) + with pytest.raises(ValueError, match="Unknown sampler"): + get_sampler("does-not-exist") diff --git a/uv.lock b/uv.lock index a0a4af5..b5175f4 100644 --- a/uv.lock +++ b/uv.lock @@ -2249,6 +2249,8 @@ dependencies = [ { name = "nvidia-ml-py" }, { name = "pandas" }, { name = "psutil" }, + { name = "pyarrow" }, + { name = "pydantic" }, { name = "pytorch-lightning" }, { name = "pyyaml" }, { name = "rich" }, @@ -2310,6 +2312,8 @@ requires-dist = [ { name = "pandas", specifier = ">=2" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.5" }, { name = "psutil", extras = ["mlflow"], specifier = ">=7.2.2" }, + { name = "pyarrow", specifier = ">=17" }, + { name = "pydantic", specifier = ">=2" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.1" }, { name = "pytorch-lightning", specifier = ">=2" }, From ed80a39e2d999f985df8fc6e83d17dbe395a8672 Mon Sep 17 00:00:00 2001 From: Gabriele Franch Date: Sat, 6 Jun 2026 14:25:25 +0200 Subject: [PATCH 02/15] config: bound val/test with a uniform eval_sampler MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The full candidate pool is impractical for val/test on large datasets, so the default experiment now uses UniformSampler(keep_fraction=0.1) for eval — representative (unweighted) but bounded — while train keeps ImportanceSampler. --- src/mlcast/config/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mlcast/config/base.py b/src/mlcast/config/base.py index e3c8c99..0442a30 100644 --- a/src/mlcast/config/base.py +++ b/src/mlcast/config/base.py @@ -33,7 +33,7 @@ from ..data.source_data_datasets import SourceDataIndexedDataset from ..models.convgru import ConvGruModel from ..nowcasting_module import NowcastLightningModule -from ..sampling import ImportanceSampler +from ..sampling import ImportanceSampler, UniformSampler @dataclass @@ -78,6 +78,7 @@ def training_experiment() -> Experiment: dataset_factory=dataset_factory, splits={"time": {"train": 0.70, "val": 0.15, "test": 0.15}}, train_sampler=ImportanceSampler(), + eval_sampler=UniformSampler(keep_fraction=0.1), batch_size=16, num_workers=8, pin_memory=True, From 15ade3fb2ab6bd165e2b0638b26e449efac4720e Mon Sep 17 00:00:00 2001 From: Gabriele Franch Date: Sat, 6 Jun 2026 14:31:54 +0200 Subject: [PATCH 03/15] docs: regenerate config diagram for the new sampling API --- docs/config_diagram.svg | 682 +++++++++++++++++++++------------------- 1 file changed, 364 insertions(+), 318 deletions(-) diff --git a/docs/config_diagram.svg b/docs/config_diagram.svg index 76275eb..9c37543 100644 --- a/docs/config_diagram.svg +++ b/docs/config_diagram.svg @@ -4,11 +4,11 @@ - - + + %3 - + 2 @@ -35,60 +35,60 @@ 1 - - -Config: - NowcastLightningModule - - -network - - - - - -ensemble_size - -2 - - -loss_class - -'crps' - - -loss_params - - - -dict - - -'temporal_lambda' - -0.01 - - -masked_loss - -True - - -optimizer - - - - - -lr_scheduler - - - + + +Config: + NowcastLightningModule + + +network + + + + + +ensemble_size + +2 + + +loss_class + +'crps' + + +loss_params + + + +dict + + +'temporal_lambda' + +0.01 + + +masked_loss + +True + + +optimizer + + + + + +lr_scheduler + + + 1:c--2:c - + @@ -111,7 +111,7 @@ 1:c--3:c - + @@ -139,341 +139,387 @@ 1:c--4:c - + 0 - - -Config: - Experiment - - -pl_module - - - - - -data - - - - - -trainer - - - + + +Config: + Experiment + + +pl_module + + + + + +data + + + + + +trainer + + + 0:c--1:c - + 5 - - -Config: - SourceDataDataModule - - -dataset_factory - - - - - -splits - - - -dict - - -'time' - - - -dict - - -'train' - -0.7 - - -'val' - -0.15 - - -'test' - -0.15 - - -batch_size - -16 - - -num_workers - -8 - - -pin_memory - -True + + +Config: + SourceDataDataModule + + +dataset_factory + + + + + +splits + + + +dict + + +'time' + + + +dict + + +'train' + +0.7 + + +'val' + +0.15 + + +'test' + +0.15 + + +train_sampler + + + + + +eval_sampler + + + + + +batch_size + +16 + + +num_workers + +8 + + +pin_memory + +True - + 0:c--5:c - + - - -7 - - -Config: - Trainer - - -accelerator - -'auto' - - -logger - - - - - -callbacks - - - -list - - - - - - - - - - - - - - -0 - - -1 - - -2 - - -3 - - -max_epochs - -100 + + +9 + + +Config: + Trainer + + +accelerator + +'auto' + + +logger + + + + + +callbacks + + + +list + + + + + + + + + + + + + + +0 + + +1 + + +2 + + +3 + + +max_epochs + +100 - - -0:c--7:c - + + +0:c--9:c + 6 - - -Partial: - SourceDataPrecomputedSamplingDataset + + +Partial: + SourceDataIndexedDataset zarr_path - + './data/radar.zarr' -csv_path - -'./data/sampled_datacubes.csv' +index_path + +'./data/sampled_datacubes.parquet' standard_names - - - -list - -'rainfall_rate' - - -0 + + + +list + +'rainfall_rate' + + +0 input_steps - + 6 forecast_steps - + 12 return_mask - + True deterministic - + False 5:c--6:c - + - + -8 - - -Config: - TensorBoardLogger - - -save_dir - -'logs' - - -name - -'mlcast' +7 + + +Config: + ImportanceSampler + + +no arguments - - -7:c--8:c - + + +5:c--7:c + - - -9 - - -Config: - ModelCheckpoint - - -monitor - -'val_loss' - - -save_top_k - -1 - - -mode - -'min' + + +8 + + +Config: + UniformSampler + + +keep_fraction + +0.1 - - -7:c--9:c - + + +5:c--8:c + - + 10 - - -Config: - ModelCheckpoint - - -monitor - -'train_loss_epoch' - - -save_top_k - -1 - - -mode - -'min' + + +Config: + TensorBoardLogger + + +save_dir + +'logs' + + +name + +'mlcast' - + -7:c--10:c - +9:c--10:c + 11 - - -Config: - EarlyStopping - - -monitor - -'val_loss' - - -patience - -100 - - -mode - -'min' + + +Config: + ModelCheckpoint + + +monitor + +'val_loss' + + +save_top_k + +1 + + +mode + +'min' - + -7:c--11:c - +9:c--11:c + 12 - - -Config: - LearningRateMonitor - - -logging_interval - -'step' + + +Config: + ModelCheckpoint + + +monitor + +'train_loss_epoch' + + +save_top_k + +1 + + +mode + +'min' - + -7:c--12:c - +9:c--12:c + + + + +13 + + +Config: + EarlyStopping + + +monitor + +'val_loss' + + +patience + +100 + + +mode + +'min' + + + +9:c--13:c + + + + +14 + + +Config: + LearningRateMonitor + + +logging_interval + +'step' + + + +9:c--14:c + From b3d427fc599159cef473ff97cfeaaa9f9ff6dd48 Mon Sep 17 00:00:00 2001 From: Gabriele Franch Date: Sat, 6 Jun 2026 15:35:44 +0200 Subject: [PATCH 04/15] Phase 2: bring the stats producer + CLI into mlcast.sampling Folds the rest of mlcast-dataset-sampler into mlcast: the `stats` producer (zarr scan -> stats parquet, bottleneck CPU + torch GPU windowing) and `validate-stats`, exposed as `mlcast` subcommands. - mlcast/sampling/commands/{stats,validate_stats,_stats_gpu}.py + console.py, mirroring the source layout so relative imports to stats_spec/units resolve unchanged. - mlcast CLI: add `stats` and `validate-stats` argparse subcommands beside `train`. Their modules (and bottleneck) are imported lazily in dispatch, so `mlcast train` and `import mlcast.sampling` never pull the producer deps. - pyproject: [sampling] extra = bottleneck (everything else is already core); GPU windowing reuses core torch. - Port producer tests (stats_process/_gpu/_spec) with importorskip(bottleneck); drop test_sampling (importance_weights now inlined into ImportanceSampler). Full suite with --extra sampling green (114 passed, 1 skipped). --- pyproject.toml | 8 + src/mlcast/__main__.py | 35 + src/mlcast/sampling/commands/__init__.py | 5 + src/mlcast/sampling/commands/_stats_gpu.py | 108 +++ src/mlcast/sampling/commands/stats.py | 616 ++++++++++++++++++ .../sampling/commands/validate_stats.py | 145 +++++ src/mlcast/sampling/console.py | 12 + tests/sampling/test_stats_gpu.py | 78 +++ tests/sampling/test_stats_process.py | 239 +++++++ tests/sampling/test_stats_spec.py | 197 ++++++ uv.lock | 59 +- 11 files changed, 1501 insertions(+), 1 deletion(-) create mode 100644 src/mlcast/sampling/commands/__init__.py create mode 100644 src/mlcast/sampling/commands/_stats_gpu.py create mode 100644 src/mlcast/sampling/commands/stats.py create mode 100644 src/mlcast/sampling/commands/validate_stats.py create mode 100644 src/mlcast/sampling/console.py create mode 100644 tests/sampling/test_stats_gpu.py create mode 100644 tests/sampling/test_stats_process.py create mode 100644 tests/sampling/test_stats_spec.py diff --git a/pyproject.toml b/pyproject.toml index 7ef7064..93047cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,11 @@ optional-dependencies.gpu-cu130 = [ "torch>=2.7.1", "torchvision>=0.20", ] +# Producer side: the `mlcast stats` / `mlcast validate-stats` data-prep CLI. +# Only `bottleneck` is extra; zarr/xarray/pyarrow/rich/torch are already core. +optional-dependencies.sampling = [ + "bottleneck>=1.6", +] urls.Documentation = "https://github.com/mlcast-community/mlcast" urls.Homepage = "https://github.com/mlcast-community/mlcast" urls.Issues = "https://github.com/mlcast-community/mlcast/issues" @@ -111,6 +116,9 @@ lint.select = [ "W", # pycodestyle warnings ] lint.ignore = [ "F722" ] +# These tests `importorskip` the sampling extra before importing the command +# modules, so their imports are intentionally not at the top of the file. +lint.per-file-ignores."tests/sampling/test_stats_*.py" = [ "E402" ] [tool.mypy] strict = false diff --git a/src/mlcast/__main__.py b/src/mlcast/__main__.py index 31ffc0b..0d7124e 100644 --- a/src/mlcast/__main__.py +++ b/src/mlcast/__main__.py @@ -332,6 +332,25 @@ def train_main(argv: list[str]) -> None: train_from_config(_config.value) +def _run_sampling_command(name: str, remaining: list[str]) -> None: + """Dispatch the data-prep subcommands (``stats`` / ``validate-stats``). + + Imported lazily so ``mlcast train`` never pulls the sampling deps. + """ + try: + from mlcast.sampling import commands + except ImportError as exc: # bottleneck (the [sampling] extra) not installed + raise SystemExit(f"`mlcast {name}` needs the sampling extra: pip install 'mlcast[sampling]' ({exc})") from None + from loguru import logger + + logger.remove() + logger.add(sys.stderr, level="INFO") + module = commands.stats if name == "stats" else commands.validate_stats + sub = argparse.ArgumentParser(prog=f"mlcast {name}") + module.add_arguments(sub) + sys.exit(module.run(sub.parse_args(remaining))) + + def cli() -> None: """Console script entry point for the ``mlcast`` command. @@ -368,6 +387,20 @@ def cli() -> None: "-h", "--help", action="help", default=argparse.SUPPRESS, help="Show this message and exit." ) + # Data-prep subcommands (plain argparse, no Fiddle). Defined bare here; the + # command modules — and the heavy windowing deps — are imported lazily in + # the dispatch below, so `mlcast train` never pays for them. + subparsers.add_parser( + "stats", + add_help=False, + help="Scan a Zarr dataset and write per-datacube stats to parquet (needs mlcast[sampling]).", + ) + subparsers.add_parser( + "validate-stats", + add_help=False, + help="Validate a stats parquet file against the canonical contract.", + ) + args, remaining = parser.parse_known_args() if args.command == "train": @@ -395,6 +428,8 @@ def cli() -> None: _seed_fiddle_flag_from_yaml(yaml_path) app.run(train_main, argv=[sys.argv[0]] + remaining) + elif args.command in ("stats", "validate-stats"): + _run_sampling_command(args.command, remaining) if __name__ == "__main__": diff --git a/src/mlcast/sampling/commands/__init__.py b/src/mlcast/sampling/commands/__init__.py new file mode 100644 index 0000000..5baf714 --- /dev/null +++ b/src/mlcast/sampling/commands/__init__.py @@ -0,0 +1,5 @@ +"""Sampler commands.""" + +from . import stats, validate_stats + +__all__ = ["stats", "validate_stats"] diff --git a/src/mlcast/sampling/commands/_stats_gpu.py b/src/mlcast/sampling/commands/_stats_gpu.py new file mode 100644 index 0000000..9392df2 --- /dev/null +++ b/src/mlcast/sampling/commands/_stats_gpu.py @@ -0,0 +1,108 @@ +"""GPU (PyTorch) backend for the per-chunk stats windowing. + +Mirrors the CPU `stats._process_chunk` on CUDA tensors: the chunk is moved +to the GPU once, the three windowed stats are reduced onto the strided +candidate grid, and only the survivors are copied back. `nan_count` and +`frac_wet` match the CPU exactly; `sum`/`mean` agree to a few float32 ULP +(the GPU sums in a different order). + +This module is imported only when ``--device cuda`` is selected, so torch +stays an optional dependency (the ``gpu`` extra). +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import torch + +if TYPE_CHECKING: + from numpy.typing import NDArray + + +def _dim_window(a: torch.Tensor, dim: int, delta: int) -> torch.Tensor: + """Sliding-window sum of size `delta` along `dim` (output len = n-delta+1).""" + cs = torch.cumsum(a, dim=dim) + pad_shape = list(cs.shape) + pad_shape[dim] = 1 + padded = torch.cat([torch.zeros(pad_shape, dtype=cs.dtype, device=cs.device), cs], dim=dim) + n = a.shape[dim] + return padded.narrow(dim, delta, n - delta + 1) - padded.narrow(dim, 0, n - delta + 1) + + +def _strided_window( + a: torch.Tensor, + deltas: tuple[int, int, int], + off_t: int, + steps: tuple[int, int, int], + keep_t: torch.Tensor, +) -> torch.Tensor: + """Windowed sum reduced to the strided, gap-free candidate grid (see CPU twin).""" + Dt, w, h = deltas + step_t, step_x, step_y = steps + s = _dim_window(a, 0, Dt)[off_t::step_t][keep_t] + s = _dim_window(s, 1, w)[:, 0::step_x] + s = _dim_window(s, 2, h)[:, :, 0::step_y] + return s + + +def process_chunk( + time_range: tuple[int, int], + t_start_idx: int, + chunk_np: NDArray, + max_nan: int, + wet_threshold: float, + deltas: tuple[int, int, int], + steps: tuple[int, int, int], + valid_start_mask: NDArray[np.bool_], + device: torch.device, +) -> dict[str, NDArray]: + """GPU twin of `stats._process_chunk`. `chunk_np` is read on the CPU; this + moves it to `device`, computes the strided stats, and returns CPU numpy + arrays in the same column layout as the CPU path. + """ + start_t, end_t = time_range + Dt, w, h = deltas + step_t, step_x, step_y = steps + total_px = Dt * w * h + off_t = (-(start_t + t_start_idx)) % step_t + + chunk = torch.from_numpy(chunk_np).to(device, non_blocking=True) + + # Strided, gap-free time-window starts (computed on the CPU, cheap). + nt_win = chunk.shape[0] - Dt + 1 + t_rel_strided = np.arange(off_t, nt_win, step_t, dtype=np.int64) + keep_np = valid_start_mask[t_rel_strided + start_t] + keep_t = torch.from_numpy(keep_np).to(device) + t_rel_kept = torch.from_numpy(t_rel_strided[keep_np]).to(device) + + nan_mask = torch.isnan(chunk) + + # Pass A: nan_count. cumsum keeps exact integer counts (< 2^31). + ncw = _strided_window(nan_mask.to(torch.int32), deltas, off_t, steps, keep_t) + a, b, c = torch.nonzero(ncw <= max_nan, as_tuple=True) + nan_count = ncw[a, b, c] + + # Pass B/C on the zero-filled chunk. + chunk = torch.nan_to_num(chunk, nan=0.0) + sum_vals = _strided_window(chunk, deltas, off_t, steps, keep_t)[a, b, c] + wet_count = _strided_window((chunk > wet_threshold).to(torch.int32), deltas, off_t, steps, keep_t)[a, b, c] + + idx_t_abs = (t_rel_kept[a] + (start_t + t_start_idx)).to(torch.int32) + idx_x = (b * step_x).to(torch.int32) + idx_y = (c * step_y).to(torch.int32) + + valid_count = (total_px - nan_count).to(torch.float32) + mean_vals = torch.where(valid_count > 0, sum_vals / valid_count, torch.full_like(sum_vals, float("nan"))) + frac_wet = wet_count.to(torch.float32) / total_px + + return { + "t": idx_t_abs.cpu().numpy(), + "x": idx_x.cpu().numpy(), + "y": idx_y.cpu().numpy(), + "nan_count": nan_count.to(torch.int32).cpu().numpy(), + "sum": sum_vals.to(torch.float32).cpu().numpy(), + "mean": mean_vals.to(torch.float32).cpu().numpy(), + "frac_wet": frac_wet.cpu().numpy(), + } diff --git a/src/mlcast/sampling/commands/stats.py b/src/mlcast/sampling/commands/stats.py new file mode 100644 index 0000000..f87b43f --- /dev/null +++ b/src/mlcast/sampling/commands/stats.py @@ -0,0 +1,616 @@ +"""Per-datacube statistics via cumsum-based sliding windows. + +Scans a Zarr dataset for valid datacube candidates and computes, for each +one, `nan_count`, `sum`, `mean`, and `frac_wet`, each in O(1) per window +amortized via a prefix-sum (cumsum) trick. The survivors (those passing +the `max_nan`, stride, and time-continuity filters) are written to a stats +parquet whose contract is defined in `stats_spec`. Downstream, +`SourceDataIndexedDataset` reads this parquet and an `ImportanceSampler` +(see `samplers`) selects candidates by the `mean` column; no separate +sampling pass is needed. + +Heavy stats that cannot be computed with cumsum (max, quantiles) are out +of scope here and could be added as extra columns in a future pass. +""" + +from __future__ import annotations + +import argparse +import os +import time +from functools import partial +from multiprocessing import Pool +from queue import Queue +from threading import Thread +from typing import TYPE_CHECKING + +import bottleneck as bn +import numpy as np +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq +import xarray as xr +import zarr +from loguru import logger +from rich.panel import Panel +from rich.progress import ( + BarColumn, + MofNCompleteColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) +from rich.table import Table + +from ..console import console +from ..stats_spec import STAT_COLUMNS, StatsMetadata, build_schema +from ..units import default_wet_threshold, detect_data_kind + +if TYPE_CHECKING: + from numpy.typing import NDArray + + +def add_arguments(parser: argparse.ArgumentParser) -> None: + """Add stats specific arguments to the parser.""" + parser.add_argument("zarr_path", type=str, help="Path to the Zarr dataset.") + parser.add_argument( + "-o", + "--output", + type=str, + default=None, + help="Output Parquet file path. If not specified, auto-generated from parameters.", + ) + parser.add_argument("--start-date", type=str, default=None, help="Start date (YYYY-MM-DD).") + parser.add_argument("--end-date", type=str, default=None, help="End date (YYYY-MM-DD).") + parser.add_argument("--time-depth", type=int, default=24, help="Time depth of datacubes.") + parser.add_argument("--width", type=int, default=256, help="Spatial width of datacubes.") + parser.add_argument("--height", type=int, default=256, help="Spatial height of datacubes.") + parser.add_argument("--step-t", type=int, default=3, help="Time step between datacubes.") + parser.add_argument("--step-x", type=int, default=16, help="X step between datacubes.") + parser.add_argument("--step-y", type=int, default=16, help="Y step between datacubes.") + parser.add_argument( + "--max-nan", + type=int, + default=10000, + help="Maximum NaN count per datacube (hard filter on output).", + ) + parser.add_argument( + "--wet-threshold", + type=float, + default=None, + help="Wet-pixel threshold in the same units as the data var. " + "If omitted, auto-detected: 0.1 mm/h for rain rate, 7 dBZ for reflectivity.", + ) + parser.add_argument( + "--data-kind", + choices=["rainrate", "reflectivity"], + default=None, + help="Override the data-kind auto-detection from zarr attrs. " + "Needed only if the variable has non-standard attributes.", + ) + parser.add_argument( + "--time-step-minutes", + type=int, + default=5, + help="Expected time step between consecutive frames in minutes.", + ) + parser.add_argument( + "--device", + choices=["auto", "cpu", "cuda"], + default="auto", + help="Compute backend. 'auto' (default) uses CUDA if a GPU is available, else the CPU. " + "'cuda' requires a CUDA GPU.", + ) + parser.add_argument( + "--workers", + type=int, + default=8, + help="CPU: number of worker processes. GPU: number of chunk-reader threads.", + ) + parser.add_argument("--data-var", type=str, default="RR", help="Name of the zarr data variable.") + parser.add_argument("--time-var", type=str, default="time", help="Name of the zarr time variable.") + parser.add_argument("--overwrite", action="store_true", help="Overwrite output file if it exists.") + + +def _dim_cumsum_window( + arr: NDArray, + dim: int, + delta: int, + dim_len: int, +) -> NDArray: + """3D sliding-window sum along one axis via a prefix-sum difference. + + Works for any numeric dtype (int for counting, float for summing). For + every window of size `delta` along `dim`, returns the sum of the + elements inside that window. O(n) per axis regardless of `delta`. + + The output length along `dim` is ``dim_len - delta + 1`` (one entry per + valid window start, including the final one at ``dim_len - delta``). + """ + # Use int32 (not the numpy default int64) for int inputs to halve memory. + # A window count is bounded by Dt*w*h <= 2^31, safe in int32. + cumsum = np.cumsum(arr, axis=dim, dtype=arr.dtype if arr.dtype.kind == "f" else np.int32) + + # Prepend a zero plane so window [k, k+delta) == padded[k+delta] - padded[k]. + pad_width = [(1, 0) if i == dim else (0, 0) for i in range(arr.ndim)] + padded = np.pad(cumsum, pad_width=pad_width, mode="constant", constant_values=0) + + # Window starts k = 0 .. dim_len-delta, so padded[k] spans [0, dim_len-delta] + # and padded[k+delta] spans [delta, dim_len] (note the inclusive +1 ends). + slices_start = [slice(dim_len - delta + 1) if i == dim else slice(None) for i in range(arr.ndim)] + slices_end = [slice(delta, dim_len + 1) if i == dim else slice(None) for i in range(arr.ndim)] + + return padded[tuple(slices_end)] - padded[tuple(slices_start)] + + +def _datacube_window_sum( + arr: NDArray, + deltas: tuple[int, int, int], + dim_lengths: tuple[int, int, int], + order: tuple[int, int, int] = (0, 1, 2), +) -> NDArray: + """3-axis cumsum-window sum over a (T, X, Y) array (all positions). + + `order` is the sequence in which the axes are windowed. Windowing one + axis only shrinks that axis, so for an associative reduction the result + is independent of `order`. The production path uses `_strided_window` + instead (it evaluates only the kept positions); this full-cube variant + is kept for tests and reference. + """ + s = arr + for ax in order: + s = _dim_cumsum_window(s, dim=ax, delta=deltas[ax], dim_len=dim_lengths[ax]) + return s + + +def _strided_window( + arr: NDArray, + deltas: tuple[int, int, int], + dim_lengths: tuple[int, int, int], + off_t: int, + steps: tuple[int, int, int], + keep_t: NDArray[np.bool_], +) -> NDArray: + """Windowed sum reduced to the strided, gap-free candidate grid. + + Windows the time axis, keeps only the strided, gap-free t-slices, then + windows x and y on the already-strided (smaller) arrays. Selecting the + candidate grid *between* the axes shrinks the x cumsum by ``step_t`` and + the y cumsum by ``step_t * step_x`` while leaving the kept values + unchanged: windowing one axis is independent of which positions are kept + on the others. + + The full-array time axis uses bottleneck's ``move_sum`` (faster than a + numpy cumsum on this large, non-contiguous axis); the small strided + spatial axes stay on numpy. ``move_sum`` needs float input and emits NaN + for the first ``Dt-1`` incomplete windows, dropped with ``[Dt-1:]``; + integer counts are cast back to int32 (exact to 2^31). + + `keep_t` is the boolean continuity mask over the strided t-grid + ``arange(off_t, dim_lengths[0] - deltas[0] + 1, step_t)``; its length + must match that grid. + """ + Dt, w, h = deltas + step_t, step_x, step_y = steps + s = bn.move_sum(arr.astype(np.float32, copy=False), Dt, axis=0)[Dt - 1 :] + if arr.dtype.kind in "bi": + s = s.astype(np.int32) + s = s[off_t::step_t][keep_t] + s = _dim_cumsum_window(s, dim=1, delta=w, dim_len=dim_lengths[1]) + s = s[:, 0::step_x] + s = _dim_cumsum_window(s, dim=2, delta=h, dim_len=dim_lengths[2]) + return s[:, :, 0::step_y] + + +def _process_chunk( + time_range: tuple[int, int], + t_start_idx: int, + data: zarr.Array, + max_nan: int, + wet_threshold: float, + deltas: tuple[int, int, int], + steps: tuple[int, int, int], + valid_start_mask: NDArray[np.bool_], +) -> dict[str, NDArray]: + """Compute the windowed stats for one time chunk's candidate datacubes. + + Returns the survivors — windows passing the `max_nan`, stride, and + time-continuity filters — with their `nan_count`, `sum`, `mean`, and + `frac_wet`, as a dict of numpy arrays matching `STATS_SCHEMA`. + + The three reductions (nan_count, sum, wet_count) are each computed with + `_strided_window` and freed before the next, keeping peak memory near a + single windowed array. `valid_start_mask` is a boolean lookup over the + filtered time axis (True at gap-free window starts). + """ + start_t, end_t = time_range + chunk = data[start_t + t_start_idx : end_t + t_start_idx, :, :].astype(np.float32, copy=False) + dim_lengths = chunk.shape + Dt, w, h = deltas + step_t, step_x, step_y = steps + total_px = Dt * w * h + + # Strided, gap-free time-window starts, absolute-index aligned. The time + # window count is dim_lengths[0] - Dt + 1 (see _dim_cumsum_window); the + # x/y stride offsets are 0 because chunk x/y are already absolute. + off_t = (-(start_t + t_start_idx)) % step_t + t_rel_strided = np.arange(off_t, dim_lengths[0] - Dt + 1, step_t, dtype=np.int32) + keep_t = valid_start_mask[t_rel_strided + start_t] # time-continuity filter + t_rel_kept = t_rel_strided[keep_t] + + # Build the NaN mask once; (a) drives nan_count, (b) zero-fills the sum pass. + nan_mask = np.isnan(chunk) # bool, 1 byte/element + + # --- Pass A: nan_count on the strided candidate grid ------------------------- + ncw_s = _strided_window(nan_mask, deltas, dim_lengths, off_t, steps, keep_t) + surv_t, surv_x, surv_y = np.where(ncw_s <= max_nan) + nan_count = ncw_s[surv_t, surv_x, surv_y].astype(np.int32) + del ncw_s + + # Map strided survivor indices back to chunk-relative / absolute coords. + idx_t_rel = t_rel_kept[surv_t] + idx_x = (surv_x * step_x).astype(np.int32) + idx_y = (surv_y * step_y).astype(np.int32) + idx_t_abs = (idx_t_rel + (start_t + t_start_idx)).astype(np.int32) + + # --- Pass B: sum ------------------------------------------------------------- + chunk[nan_mask] = 0.0 + sum_vals = _strided_window(chunk, deltas, dim_lengths, off_t, steps, keep_t)[surv_t, surv_x, surv_y] + + # --- Pass C: wet_count ------------------------------------------------------- + # `chunk` is now zero where it was NaN, so `> wet_threshold` is equivalent + # to (value > threshold AND not NaN). + wet_mask = chunk > wet_threshold + del chunk, nan_mask + wet_count = _strided_window(wet_mask, deltas, dim_lengths, off_t, steps, keep_t)[surv_t, surv_x, surv_y] + + # Derived stats. + valid_count = total_px - nan_count + with np.errstate(invalid="ignore", divide="ignore"): + mean_vals = np.where(valid_count > 0, sum_vals / valid_count, np.nan).astype(np.float32) + frac_wet = wet_count.astype(np.float32) / total_px + + return { + "t": idx_t_abs, + "x": idx_x, + "y": idx_y, + "nan_count": nan_count, + "sum": sum_vals.astype(np.float32), + "mean": mean_vals, + "frac_wet": frac_wet, + } + + +def _parquet_writer( + output_queue: Queue, + filename: str, + schema: pa.Schema, +) -> None: + """Drain the queue and stream rows to a Parquet file. + + Each queue item is the dict returned by `_process_chunk`. We buffer + into Arrow RecordBatches and append to a single ParquetWriter so the + on-disk file stays a single self-contained parquet. `schema` is the + canonical STATS_SCHEMA with the mlcast sampling parameters attached as + metadata (see `stats_spec.build_schema`), so downstream commands don't + need to parse the filename. + """ + # Column encodings tuned to the data: `t` is written in ascending order so + # it delta-encodes to almost nothing; x/y/nan_count are low-cardinality + # (dictionary); the floats compress better split into byte planes. + writer = pq.ParquetWriter( + filename, + schema, + compression="zstd", + compression_level=9, + use_dictionary=["x", "y", "nan_count"], + column_encoding={ + "t": "DELTA_BINARY_PACKED", + "sum": "BYTE_STREAM_SPLIT", + "mean": "BYTE_STREAM_SPLIT", + "frac_wet": "BYTE_STREAM_SPLIT", + }, + ) + total_rows = 0 + try: + while True: + item = output_queue.get() + if item is None: + break + if item["t"].size == 0: + continue + batch = pa.record_batch( + [pa.array(item[c]) for c in STAT_COLUMNS], + schema=schema, + ) + writer.write_batch(batch) + total_rows += batch.num_rows + finally: + writer.close() + logger.info(f"Wrote {total_rows} rows to {filename}") + + +def _resolve_device(requested: str) -> tuple[str, str]: + """Resolve the compute backend to ('cpu'|'cuda', human label). + + 'auto' picks CUDA when PyTorch + a GPU are importable/available, else CPU. + 'cuda' raises ValueError if PyTorch or a GPU is missing. 'cpu' is forced. + """ + if requested == "cpu": + return "cpu", "cpu (bottleneck)" + try: + import torch + except ImportError: + if requested == "cuda": + raise ValueError("--device cuda requested but PyTorch is not installed.") from None + return "cpu", "cpu (bottleneck)" + if torch.cuda.is_available(): + return "cuda", f"cuda ({torch.cuda.get_device_name(0)})" + if requested == "cuda": + raise ValueError("--device cuda requested but no CUDA GPU is available.") + return "cpu", "cpu (bottleneck)" + + +def _prefetched(read_fn, items, lookahead: int, n_workers: int): + """Yield ``read_fn(item)`` results in submission order, keeping up to + `lookahead` reads in flight across `n_workers` threads. + + Used by the GPU path to overlap chunk reads/decompression (which release + the GIL) with GPU compute, so the device stays fed. Memory is bounded by + `lookahead` chunks. + """ + from collections import deque + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(max_workers=n_workers) as ex: + it = iter(items) + inflight: deque = deque() + for _ in range(lookahead): + try: + inflight.append(ex.submit(read_fn, next(it))) + except StopIteration: + break + while inflight: + fut = inflight.popleft() + try: + inflight.append(ex.submit(read_fn, next(it))) + except StopIteration: + pass + yield fut.result() + + +def run(args: argparse.Namespace) -> int: + """Execute the stats command.""" + start_time = time.time() + Dt = args.time_depth + w = args.width + h = args.height + step_T = args.step_t + step_X = args.step_x + step_Y = args.step_y + max_nan = args.max_nan + n_workers = args.workers + time_chunk_size = 3 * Dt + + try: + device, device_label = _resolve_device(args.device) + except ValueError as e: + logger.error(str(e)) + return 1 + logger.info(f"Compute backend: {device_label}") + + logger.info(f"Opening Zarr dataset: {args.zarr_path}") + try: + zg = zarr.open(args.zarr_path, mode="r") + data = zg[args.data_var] + ds = xr.open_zarr(args.zarr_path) + time_array_full = pd.DatetimeIndex(ds[args.time_var].values) + logger.info(f"Full dataset shape: T={data.shape[0]}, X={data.shape[1]}, Y={data.shape[2]}") + logger.info(f"Time range: {time_array_full[0]} to {time_array_full[-1]}") + var_attrs = dict(ds[args.data_var].attrs) + except Exception as e: + logger.error(f"Error loading Zarr dataset: {e}") + return 1 + + # Detect or override the data kind, then resolve the wet-pixel threshold. + if args.data_kind is not None: + data_kind = args.data_kind + logger.info(f"Data kind overridden via --data-kind: {data_kind}") + else: + try: + data_kind = detect_data_kind(var_attrs) + except ValueError as e: + logger.error(str(e)) + return 1 + logger.info( + f"Detected data kind: {data_kind} " + f"(standard_name={var_attrs.get('standard_name')!r}, " + f"units={var_attrs.get('units')!r})" + ) + + wet_threshold = args.wet_threshold if args.wet_threshold is not None else default_wet_threshold(data_kind) + units_str = var_attrs.get("units", "?") + logger.info(f"Wet-pixel threshold: {wet_threshold} {units_str}") + + start_date = pd.to_datetime(args.start_date) if args.start_date else time_array_full[0] + end_date = pd.to_datetime(args.end_date) if args.end_date else time_array_full[-1] + + mask = (time_array_full >= start_date) & (time_array_full <= end_date) + valid_indices = np.where(mask)[0] + if len(valid_indices) == 0: + logger.error(f"No data found between {start_date} and {end_date}") + return 1 + + t_start_idx = valid_indices[0] + t_end_idx = valid_indices[-1] + 1 + size_T = t_end_idx - t_start_idx + size_X = data.shape[1] + size_Y = data.shape[2] + time_array = time_array_full[t_start_idx:t_end_idx] + + logger.info(f"Filtered dataset shape: T={size_T}, X={size_X}, Y={size_Y}") + logger.info(f"Filtered time range: {time_array[0]} to {time_array[-1]}") + max_t = size_T - Dt + 1 + + logger.info("Checking time continuity...") + expected_step = pd.Timedelta(minutes=args.time_step_minutes) + time_diffs = time_array[1:] - time_array[:-1] + gaps = (time_diffs != expected_step).astype(int) + window_sum = np.convolve(gaps, np.ones(Dt - 1, dtype=int), mode="valid") + valid_starts_gap = np.where(window_sum == 0)[0] + logger.info(f"Found {len(valid_starts_gap)} valid time starts without gaps") + # Boolean lookup over the filtered time axis for an O(1) continuity test + # per candidate window start. + valid_start_mask = np.zeros(size_T, dtype=bool) + valid_start_mask[valid_starts_gap] = True + + # Peak memory per worker ~= chunk (float32) + cumsum working set + nan_mask (bool). + # The three cumsum reductions run sequentially, so only one window array is + # alive at a time. + chunk_bytes = (time_chunk_size + Dt - 1) * size_X * size_Y * 4 + per_chunk_gb = 2 * chunk_bytes / (1024**3) + logger.info(f"Estimated memory per chunk: {per_chunk_gb:.2f} GB (pipelined cumsums)") + logger.info(f"Estimated total memory ({n_workers} workers): {per_chunk_gb * n_workers:.2f} GB") + + t_starts = np.arange(0, max_t, time_chunk_size) + t_ends = np.minimum(t_starts + time_chunk_size + Dt - 1, size_T) + t_pairs = np.stack((t_starts, t_ends), axis=1) + + start_str = start_date.strftime("%Y-%m-%d") + end_str = end_date.strftime("%Y-%m-%d") + if args.output: + output_file = args.output + else: + output_file = f"stats_{start_str}-{end_str}_{Dt}x{w}x{h}_{step_T}x{step_X}x{step_Y}_{max_nan}.parquet" + if os.path.exists(output_file) and not args.overwrite: + logger.error(f"File {output_file} already exists. Use --overwrite to replace.") + return 1 + logger.info(f"Output file: {output_file}") + + metadata = StatsMetadata( + zarr_path=args.zarr_path, + data_var=args.data_var, + time_var=args.time_var, + start_date=start_date.isoformat(), + end_date=end_date.isoformat(), + time_step_minutes=args.time_step_minutes, + time_depth=Dt, + width=w, + height=h, + step_t=step_T, + step_x=step_X, + step_y=step_Y, + max_nan=max_nan, + wet_threshold=wet_threshold, + data_kind=data_kind, + units=var_attrs.get("units"), + ) + schema = build_schema(metadata) + + cfg = Table.grid(padding=(0, 2)) + cfg.add_column(justify="right", style="bold cyan") + cfg.add_column() + cfg.add_row("Dataset", f"📦 T={size_T:,} X={size_X:,} Y={size_Y:,}") + cfg.add_row("Time range", f"🕐 {time_array[0]:%Y-%m-%d %H:%M} → {time_array[-1]:%Y-%m-%d %H:%M}") + cfg.add_row("Time step", f"⏱️ {args.time_step_minutes} min") + cfg.add_row("Datacube", f"🧊 {Dt} × {w} × {h} stride {step_T} × {step_X} × {step_Y}") + cfg.add_row("Valid starts", f"✅ {len(valid_starts_gap):,} gap-free") + cfg.add_row("Filters", f"🔍 max_nan={max_nan:,} wet > {wet_threshold:g} {units_str}") + cfg.add_row("Data kind", f"💧 {data_kind}") + cfg.add_row("Device", f"⚡ {device_label}") + if device == "cuda": + cfg.add_row("Readers", f"🧵 {n_workers} threads") + else: + cfg.add_row("Workers", f"🧵 {n_workers} ~{per_chunk_gb * n_workers:.1f} GB peak") + cfg.add_row("Output", f"💾 {output_file}") + console.print( + Panel( + cfg, + title="[bold]📊 mlcast stats[/]", + subtitle=f"[dim]{os.path.basename(args.zarr_path)}[/]", + border_style="blue", + expand=False, + ) + ) + + output_queue: Queue = Queue(maxsize=100) + writer_thread = Thread(target=_parquet_writer, args=(output_queue, output_file, schema)) + writer_thread.daemon = False + writer_thread.start() + + progress = Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + MofNCompleteColumn(), + TaskProgressColumn(), + TimeElapsedColumn(), + TimeRemainingColumn(), + console=console, + ) + if device == "cuda": + import torch + + from . import _stats_gpu + + dev = torch.device("cuda") + + def _read_chunk(tp): + s0, e0 = int(tp[0]), int(tp[1]) + arr = np.asarray(data[s0 + t_start_idx : e0 + t_start_idx, :, :], dtype=np.float32) + return (s0, e0), arr + + with progress: + task = progress.add_task("🔍 Scanning time chunks (GPU)", total=len(t_starts)) + for time_range, chunk_np in _prefetched( + _read_chunk, t_pairs, lookahead=n_workers + 2, n_workers=max(1, n_workers) + ): + hits = _stats_gpu.process_chunk( + time_range, + t_start_idx, + chunk_np, + max_nan, + wet_threshold, + (Dt, w, h), + (step_T, step_X, step_Y), + valid_start_mask, + dev, + ) + output_queue.put(hits) + progress.advance(task) + else: + process_chunk_partial = partial( + _process_chunk, + t_start_idx=t_start_idx, + data=data, + max_nan=max_nan, + wet_threshold=wet_threshold, + deltas=(Dt, w, h), + steps=(step_T, step_X, step_Y), + valid_start_mask=valid_start_mask, + ) + with progress: + task = progress.add_task("🔍 Scanning time chunks", total=len(t_starts)) + with Pool(n_workers) as pool: + for hits in pool.imap(process_chunk_partial, t_pairs, chunksize=1): + output_queue.put(hits) + progress.advance(task) + + output_queue.put(None) + writer_thread.join() + + n_rows = pq.read_metadata(output_file).num_rows + console.print( + Panel( + f"✅ Wrote [bold]{n_rows:,}[/] datacube candidates " + f"in [bold]{time.time() - start_time:.1f}s[/]\n" + f"[dim]💾 {output_file}[/]", + title="[bold green]🎉 stats complete[/]", + border_style="green", + expand=False, + ) + ) + return 0 diff --git a/src/mlcast/sampling/commands/validate_stats.py b/src/mlcast/sampling/commands/validate_stats.py new file mode 100644 index 0000000..64db5bc --- /dev/null +++ b/src/mlcast/sampling/commands/validate_stats.py @@ -0,0 +1,145 @@ +"""Validate a stats parquet file against the canonical contract. + +Thin CLI wrapper over :func:`stats_spec.validate_stats_parquet`. Checks the +column schema and metadata payload, and (unless ``--no-data-checks``) the +per-row value invariants. +""" + +from __future__ import annotations + +import argparse + +import pyarrow.parquet as pq +from rich import box +from rich.panel import Panel +from rich.table import Table + +from ..console import console +from ..stats_spec import read_metadata, validate_stats_parquet + + +def add_arguments(parser: argparse.ArgumentParser) -> None: + """Add validate-stats specific arguments to the parser.""" + parser.add_argument("parquet_path", type=str, help="Path to the stats parquet file.") + parser.add_argument( + "--no-data-checks", + action="store_true", + help="Only check the schema and metadata (read the footer, not the rows).", + ) + + +def _summary_grid(path: str) -> Table: + """Compact grid of the file's key parameters, shown on success.""" + meta = read_metadata(path) + grid = Table.grid(padding=(0, 2)) + grid.add_column(justify="right", style="bold cyan") + grid.add_column() + grid.add_row("Source", f"📂 {meta.zarr_path} var={meta.data_var} time={meta.time_var}") + grid.add_row( + "Datacube", + f"🧊 {meta.time_depth} × {meta.width} × {meta.height} stride {meta.step_t} × {meta.step_x} × {meta.step_y}", + ) + grid.add_row("Range", f"📅 {meta.start_date:%Y-%m-%d} → {meta.end_date:%Y-%m-%d}") + grid.add_row("Time step", f"🕒 {meta.time_step_minutes} min") + grid.add_row("Max NaN", f"🔍 {meta.max_nan:,} per datacube") + grid.add_row("Data kind", f"💧 {meta.data_kind} (wet > {meta.wet_threshold:g} {meta.units or '?'})") + grid.add_row("Rows", f"🔢 {pq.read_metadata(path).num_rows:,}") + grid.add_row("Schema", f"📐 v{meta.schema_version}") + return grid + + +def _fmt(value: object) -> str: + """Compact cell formatting: grouped ints, 6-significant-figure floats.""" + if isinstance(value, bool): + return str(value) + if isinstance(value, int): + return f"{value:,}" + if isinstance(value, float): + return f"{value:.6g}" + return str(value) + + +def _schema_table(path: str) -> Table: + """Column-by-column view of the parquet schema (read from the footer).""" + schema = pq.read_schema(path) + table = Table( + title="🧱 table structure", title_style="bold cyan", box=box.SIMPLE_HEAD, header_style="bold", pad_edge=False + ) + table.add_column("#", justify="right", style="dim") + table.add_column("column", style="bold") + table.add_column("type") + table.add_column("nullable") + for i, field in enumerate(schema): + table.add_row(str(i), field.name, str(field.type), "yes" if field.nullable else "no") + return table + + +def _preview_table(path: str, n: int = 10) -> Table: + """The first ``n`` rows, reading only the leading row group(s).""" + pf = pq.ParquetFile(path) + total = pf.metadata.num_rows + names = pf.schema_arrow.names + table = Table( + title=f"🔎 first {min(n, total)} rows", + title_style="bold cyan", + box=box.SIMPLE_HEAD, + header_style="bold", + pad_edge=False, + ) + for name in names: + table.add_column(name, justify="right") + if total == 0: + table.add_row(*["—"] * len(names)) + return table + batch = next(pf.iter_batches(batch_size=n)) + cols = [batch.column(i).to_pylist() for i in range(batch.num_columns)] + for r in range(min(n, batch.num_rows)): + table.add_row(*[_fmt(col[r]) for col in cols]) + return table + + +def run(args: argparse.Namespace) -> int: + """Execute the validate-stats command.""" + path = args.parquet_path + report = validate_stats_parquet(path, check_data=not args.no_data_checks) + + if report.ok and not report.warnings: + console.print( + Panel( + _summary_grid(path), + title="[bold green]✅ valid stats parquet[/]", + subtitle=f"[dim]{path}[/]", + border_style="green", + expand=False, + ) + ) + else: + issues = Table(show_header=True, header_style="bold", box=None, pad_edge=False) + issues.add_column("") + issues.add_column("message") + for w in report.warnings: + issues.add_row("⚠️", w) + for e in report.errors: + issues.add_row("❌", e) + + if report.ok: + title = f"[bold yellow]⚠️ valid — {len(report.warnings)} warning(s)[/]" + border = "yellow" + else: + title = f"[bold red]❌ invalid — {len(report.errors)} violation(s)[/]" + border = "red" + console.print(Panel(issues, title=title, subtitle=f"[dim]{path}[/]", border_style=border, expand=False)) + + # Structure and a row preview — useful for eyeballing a file, and helpful + # for diagnosis even when the report flagged problems. Best-effort: a file + # too broken to read here has already been reported above. + try: + console.print(_schema_table(path)) + if args.no_data_checks: + console.print("[dim]row preview skipped (--no-data-checks)[/]") + else: + console.print(_preview_table(path)) + except Exception as exc: # noqa: BLE001 - preview is non-essential + console.print(f"[yellow]could not read table preview: {exc}[/]") + + return 0 if report.ok else 1 diff --git a/src/mlcast/sampling/console.py b/src/mlcast/sampling/console.py new file mode 100644 index 0000000..a5f073e --- /dev/null +++ b/src/mlcast/sampling/console.py @@ -0,0 +1,12 @@ +"""Shared rich console for user-facing CLI output. + +Writes to stderr so stdout stays clean for piping. Both the rich progress +display and the loguru logs target stderr, and no logs are emitted during +the progress live region, so they don't clobber each other. +""" + +from __future__ import annotations + +from rich.console import Console + +console = Console(stderr=True) diff --git a/tests/sampling/test_stats_gpu.py b/tests/sampling/test_stats_gpu.py new file mode 100644 index 0000000..e383094 --- /dev/null +++ b/tests/sampling/test_stats_gpu.py @@ -0,0 +1,78 @@ +"""GPU backend tests — skipped unless torch + a CUDA GPU are available. + +Validates the torch `_stats_gpu.process_chunk` against the CPU +`_process_chunk` and the independent brute-force oracle: integer columns +exact, float sum/mean to ``allclose`` (GPU float reduction order differs). +""" + +from __future__ import annotations + +import numpy as np +import pytest + +torch = pytest.importorskip("torch") +pytest.importorskip("bottleneck") # the stats command pulls bottleneck (the sampling extra) +if not torch.cuda.is_available(): # pragma: no cover + pytest.skip("no CUDA GPU available", allow_module_level=True) + +# Reuse the CPU test's data generator, brute force, lexsort, and cases. +from test_stats_process import CASES, _brute_force, _lexsort, _make_data + +from mlcast.sampling.commands import _stats_gpu +from mlcast.sampling.commands.stats import _process_chunk + + +@pytest.mark.parametrize("seed,deltas,steps,max_nan,wet_thr,start_t,t_start_idx", CASES) +def test_gpu_matches_cpu_and_brute(seed, deltas, steps, max_nan, wet_thr, start_t, t_start_idx): + T_total, X, Y = 60, 44, 40 + data = _make_data(seed, T_total, X, Y) + size_T = T_total - t_start_idx + valid_start_mask = np.ones(size_T, dtype=bool) + end_t = size_T - start_t + + chunk = data[start_t + t_start_idx : end_t + t_start_idx, :, :].astype(np.float32) # snapshot + + gpu = _lexsort( + _stats_gpu.process_chunk( + (start_t, end_t), + t_start_idx, + chunk, + max_nan, + wet_thr, + deltas, + steps, + valid_start_mask, + torch.device("cuda"), + ) + ) + cpu = _lexsort( + _process_chunk( + (start_t, end_t), + t_start_idx, + data.copy(), + max_nan, + wet_thr, + deltas, + steps, + valid_start_mask, + ) + ) + ref = _brute_force(chunk, deltas, steps, max_nan, wet_thr, start_t, t_start_idx, valid_start_mask) + + assert gpu["t"].size == cpu["t"].size == ref["t"].size + # dtypes match the parquet schema + for k in ("t", "x", "y", "nan_count"): + assert gpu[k].dtype == np.int32 + for k in ("sum", "mean", "frac_wet"): + assert gpu[k].dtype == np.float32 + + # exact integer columns vs CPU and brute force + for k in ("t", "x", "y", "nan_count"): + assert np.array_equal(gpu[k], cpu[k]), f"{k} differs from CPU" + assert np.array_equal(gpu[k].astype(np.int64), ref[k].astype(np.int64)), f"{k} differs from brute" + + # float columns: allclose. wet_count is exact (integer), but frac_wet's + # division and the sum reduction round ~1 ULP differently on the GPU. + assert np.allclose(gpu["frac_wet"], cpu["frac_wet"], rtol=1e-6, atol=1e-7), "frac_wet vs CPU" + assert np.allclose(gpu["sum"], ref["sum"], rtol=1e-4, atol=1e-2), "sum mismatch vs brute" + assert np.allclose(gpu["mean"], ref["mean"], rtol=1e-4, atol=1e-3, equal_nan=True), "mean mismatch" diff --git a/tests/sampling/test_stats_process.py b/tests/sampling/test_stats_process.py new file mode 100644 index 0000000..bde4a53 --- /dev/null +++ b/tests/sampling/test_stats_process.py @@ -0,0 +1,239 @@ +"""Correctness tests for `_process_chunk`. + +`_process_chunk` is checked against two independent reference +implementations: + +1. ``_reference_all_positions`` — windows every position, then filters by + max_nan / stride / continuity (no striding between the cumsum axes). +2. ``_brute_force`` — sums each candidate window directly, with no cumsum + trick at all. + +Integer columns (t, x, y, nan_count, frac_wet) must match exactly; sum and +mean are compared with ``allclose``, since float summation order differs +between the implementations. +""" + +from __future__ import annotations + +import numpy as np +import pytest + +pytest.importorskip("bottleneck") # the stats command pulls bottleneck (the sampling extra) + +from mlcast.sampling.commands.stats import ( + _datacube_window_sum, + _process_chunk, +) + +STAT_KEYS = ("t", "x", "y", "nan_count", "sum", "mean", "frac_wet") + + +# --- Reference 1: window every position, then filter ------------------------- + + +def _reference_all_positions(time_range, t_start_idx, data, max_nan, wet_threshold, deltas, steps, valid_starts_gap): + start_t, end_t = time_range + chunk = data[start_t + t_start_idx : end_t + t_start_idx, :, :].astype(np.float32, copy=False) + dim_lengths = chunk.shape + Dt, w, h = deltas + total_px = Dt * w * h + nan_mask = np.isnan(chunk) + nan_count_win = _datacube_window_sum(nan_mask.astype(np.int16), deltas, dim_lengths) + valid_mask = nan_count_win <= max_nan + idx_t_rel, idx_x, idx_y = np.where(valid_mask) + idx_t_rel = idx_t_rel.astype(np.int32) + idx_x = idx_x.astype(np.int32) + idx_y = idx_y.astype(np.int32) + idx_t_abs_rel = idx_t_rel + start_t + time_mask = np.isin(idx_t_abs_rel, valid_starts_gap) + idx_t_abs = idx_t_abs_rel + t_start_idx + stride_mask = (idx_t_abs % steps[0] == 0) & (idx_x % steps[1] == 0) & (idx_y % steps[2] == 0) + keep = time_mask & stride_mask + idx_t_rel = idx_t_rel[keep] + idx_x = idx_x[keep] + idx_y = idx_y[keep] + idx_t_abs = idx_t_abs[keep] + nan_count = nan_count_win[idx_t_rel, idx_x, idx_y] + chunk[nan_mask] = 0.0 + sum_win = _datacube_window_sum(chunk, deltas, dim_lengths) + sum_vals = sum_win[idx_t_rel, idx_x, idx_y] + wet_mask_i = (chunk > wet_threshold).astype(np.int16) + wet_count_win = _datacube_window_sum(wet_mask_i, deltas, dim_lengths) + wet_count = wet_count_win[idx_t_rel, idx_x, idx_y] + valid_count = total_px - nan_count + with np.errstate(invalid="ignore", divide="ignore"): + mean_vals = np.where(valid_count > 0, sum_vals / valid_count, np.nan).astype(np.float32) + frac_wet = wet_count.astype(np.float32) / total_px + return { + "t": idx_t_abs, + "x": idx_x, + "y": idx_y, + "nan_count": nan_count, + "sum": sum_vals.astype(np.float32), + "mean": mean_vals, + "frac_wet": frac_wet, + } + + +# --- Reference 2: naive per-window brute force ------------------------------- + + +def _brute_force(chunk, deltas, steps, max_nan, wet_threshold, start_t, t_start_idx, valid_start_mask): + Dt, w, h = deltas + step_t, step_x, step_y = steps + T, X, Y = chunk.shape + total_px = Dt * w * h + # Every valid window start: it in [0, T-Dt], ix in [0, X-w], iy in [0, Y-h] + # (inclusive upper bound — the final window on each axis is included). + rows = [] + for it in range(T - Dt + 1): + if not valid_start_mask[it + start_t]: + continue + if (it + start_t + t_start_idx) % step_t != 0: + continue + for ix in range(0, X - w + 1, step_x): + for iy in range(0, Y - h + 1, step_y): + win = chunk[it : it + Dt, ix : ix + w, iy : iy + h] + filled = np.where(np.isnan(win), np.float32(0.0), win) + nan_c = int(np.isnan(win).sum()) + if nan_c > max_nan: + continue + valid = total_px - nan_c + rows.append( + ( + it + start_t + t_start_idx, + ix, + iy, + nan_c, + float(filled.sum()), + float(filled.sum() / valid) if valid > 0 else np.nan, + int((filled > wet_threshold).sum()) / total_px, + ) + ) + rows.sort(key=lambda r: (r[0], r[1], r[2])) + cols = list(zip(*rows, strict=True)) if rows else [()] * 7 + return {k: np.array(c) for k, c in zip(STAT_KEYS, cols, strict=True)} + + +def _lexsort(d): + o = np.lexsort((d["y"], d["x"], d["t"])) + return {k: d[k][o] for k in STAT_KEYS} + + +# --- Test data --------------------------------------------------------------- + + +def _make_data(seed, T_total, X, Y, nan_blocks=True): + rng = np.random.default_rng(seed) + data = (rng.random((T_total, X, Y), dtype=np.float32) ** 6) * 30.0 + # transient scattered NaNs + data[rng.random((T_total, X, Y)) < 0.03] = np.nan + if nan_blocks: + # a static "out of domain" block (like the real radar mask) + data[:, : X // 4, :] = np.nan + # a fully-NaN frame region to exercise the mean==NaN path + data[2:5, X // 2 : X // 2 + 14, Y // 2 : Y // 2 + 12] = np.nan + return data + + +# step_t such that off_t != 0 is exercised via t_start_idx +CASES = [ + # (seed, deltas, steps, max_nan, wet_thr, start_t, t_start_idx) + (0, (6, 12, 10), (2, 3, 3), 5, 0.5, 0, 0), + (1, (6, 12, 10), (3, 4, 4), 50, 0.5, 0, 2), # off_t = (-2)%3 = 1 + (2, (8, 10, 10), (2, 5, 5), 0, 1.0, 4, 1), # off_t with start_t and t_start_idx + (3, (6, 14, 12), (4, 3, 3), 6 * 14 * 12, 0.5, 0, 3), # max_nan == total_px -> all-NaN survives +] + + +@pytest.mark.parametrize("seed,deltas,steps,max_nan,wet_thr,start_t,t_start_idx", CASES) +def test_matches_reference(seed, deltas, steps, max_nan, wet_thr, start_t, t_start_idx): + T_total, X, Y = 60, 44, 40 + data = _make_data(seed, T_total, X, Y) + size_T = T_total - t_start_idx + Dt = deltas[0] + # gap-free starts, with a couple of gaps punched in + rng = np.random.default_rng(seed + 100) + valid_starts_gap = np.arange(size_T - Dt + 1, dtype=np.int64) + drop = rng.choice(valid_starts_gap, size=2, replace=False) + valid_starts_gap = np.setdiff1d(valid_starts_gap, drop) + valid_start_mask = np.zeros(size_T, dtype=bool) + valid_start_mask[valid_starts_gap] = True + + end_t = size_T - start_t # cover the rest of the filtered region from start_t + time_range = (start_t, end_t) + + # _process_chunk zero-fills its chunk in place. For a zarr array `data[slice]` + # returns a fresh array so that's harmless; for a plain numpy `data` the slice + # is a *view*, so pass each impl its own copy to avoid cross-call mutation. + out = _process_chunk(time_range, t_start_idx, data.copy(), max_nan, wet_thr, deltas, steps, valid_start_mask) + ref = _reference_all_positions( + time_range, t_start_idx, data.copy(), max_nan, wet_thr, deltas, steps, valid_starts_gap + ) + + # Exact on the integer columns; allclose on the float sum/mean (float + # summation order differs between the two implementations). + for k in ("t", "x", "y", "nan_count", "frac_wet"): + assert np.array_equal(out[k], ref[k]), f"column {k} differs from reference" + assert np.allclose(out["sum"], ref["sum"], rtol=1e-5, atol=1e-3), "sum diverges from reference" + assert np.allclose(out["mean"], ref["mean"], rtol=1e-5, atol=1e-4, equal_nan=True), "mean diverges" + # dtypes must match the parquet schema expectations + for k in ("t", "x", "y", "nan_count"): + assert out[k].dtype == np.int32, f"{k} dtype {out[k].dtype}" + for k in ("sum", "mean", "frac_wet"): + assert out[k].dtype == np.float32, f"{k} dtype {out[k].dtype}" + + +@pytest.mark.parametrize("seed,deltas,steps,max_nan,wet_thr,start_t,t_start_idx", CASES) +def test_matches_brute_force(seed, deltas, steps, max_nan, wet_thr, start_t, t_start_idx): + T_total, X, Y = 60, 44, 40 + data = _make_data(seed, T_total, X, Y) + size_T = T_total - t_start_idx + valid_start_mask = np.ones(size_T, dtype=bool) # all continuous for the oracle + + end_t = size_T - start_t + chunk = data[start_t + t_start_idx : end_t + t_start_idx, :, :].astype(np.float32) # snapshot copy + + new = _lexsort( + _process_chunk((start_t, end_t), t_start_idx, data.copy(), max_nan, wet_thr, deltas, steps, valid_start_mask) + ) + ref = _brute_force(chunk, deltas, steps, max_nan, wet_thr, start_t, t_start_idx, valid_start_mask) + + assert new["t"].size == ref["t"].size, "different number of survivors" + for k in ("t", "x", "y", "nan_count"): + assert np.array_equal(new[k].astype(np.int64), ref[k].astype(np.int64)), f"{k} mismatch vs brute force" + assert np.allclose(new["sum"], ref["sum"], rtol=1e-4, atol=1e-2), "sum mismatch vs brute force" + assert np.allclose(new["mean"], ref["mean"], rtol=1e-4, atol=1e-3, equal_nan=True), "mean mismatch" + assert np.allclose(new["frac_wet"], ref["frac_wet"], rtol=0, atol=1e-6), "frac_wet mismatch" + + +def test_window_sum_order_invariant_for_integers(): + rng = np.random.default_rng(0) + mask = (rng.random((20, 30, 28)) < 0.3).astype(np.int16) + deltas, dl = (5, 8, 7), mask.shape + a = _datacube_window_sum(mask, deltas, dl, order=(0, 1, 2)) + b = _datacube_window_sum(mask, deltas, dl, order=(2, 1, 0)) + assert np.array_equal(a, b) + + +def test_window_sum_includes_last_window(): + # Output length must be dim_len - delta + 1, and the final window start + # (at dim_len - delta) must hold the correct sum. + rng = np.random.default_rng(1) + arr = rng.integers(0, 5, size=(11, 13, 9)).astype(np.int16) + deltas, dl = (4, 5, 3), arr.shape + out = _datacube_window_sum(arr, deltas, dl) + assert out.shape == (dl[0] - deltas[0] + 1, dl[1] - deltas[1] + 1, dl[2] - deltas[2] + 1) + # last window on every axis, compared to a direct sum + lt, lx, ly = dl[0] - deltas[0], dl[1] - deltas[1], dl[2] - deltas[2] + expected = int(arr[lt : lt + deltas[0], lx : lx + deltas[1], ly : ly + deltas[2]].sum()) + assert int(out[lt, lx, ly]) == expected + + +def test_empty_when_all_windows_fail(): + # max_nan = -1 -> nothing passes + data = np.zeros((20, 30, 28), dtype=np.float32) + mask = np.ones(20, dtype=bool) + out = _process_chunk((0, 20), 0, data, -1, 0.5, (5, 8, 7), (2, 4, 4), mask) + for k in STAT_KEYS: + assert out[k].size == 0 diff --git a/tests/sampling/test_stats_spec.py b/tests/sampling/test_stats_spec.py new file mode 100644 index 0000000..cc283c2 --- /dev/null +++ b/tests/sampling/test_stats_spec.py @@ -0,0 +1,197 @@ +"""Tests for the stats-parquet contract: schema, metadata, validation.""" + +from __future__ import annotations + +import json + +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq +import pytest +from pydantic import ValidationError + +from mlcast.sampling.stats_spec import ( + SCHEMA_VERSION, + STATS_METADATA_KEY, + STATS_SCHEMA, + StatsMetadata, + build_schema, + read_metadata, + validate_stats_parquet, +) + + +def _good_params(**overrides): + params = { + "zarr_path": "/data/it-dpc.zarr", + "data_var": "RR", + "time_var": "time", + "start_date": "2020-01-01T00:00:00", + "end_date": "2020-12-31T00:00:00", + "time_step_minutes": 5, + "time_depth": 24, + "width": 256, + "height": 256, + "step_t": 3, + "step_x": 16, + "step_y": 16, + "max_nan": 10000, + "wet_threshold": 0.1, + "data_kind": "rainrate", + "units": "mm/h", + } + params.update(overrides) + return params + + +def _write_parquet(path, metadata: StatsMetadata, rows: dict): + schema = build_schema(metadata) + batch = pa.record_batch([pa.array(rows[c]) for c in STATS_SCHEMA.names], schema=schema) + with pq.ParquetWriter(path, schema, compression="zstd") as w: + w.write_batch(batch) + + +def _good_rows(meta: StatsMetadata, n: int = 5): + total = meta.total_px + # All nan_count <= max_nan (the hard filter) and < total_px, so every + # window has valid pixels and a finite mean — what the stats command emits. + nan_count = np.array([0, 1, 100, meta.max_nan // 2, meta.max_nan], dtype=np.int32)[:n] + valid = total - nan_count + mean = np.where(valid > 0, 0.5, np.nan).astype(np.float32) + return { + "t": (np.arange(n, dtype=np.int32) * meta.step_t), + "x": (np.arange(n, dtype=np.int32) * meta.step_x), + "y": (np.arange(n, dtype=np.int32) * meta.step_y), + "nan_count": nan_count, + "sum": (mean * valid).astype(np.float32), + "mean": mean, + "frac_wet": np.full(n, 0.3, dtype=np.float32), + } + + +# --- StatsMetadata model ------------------------------------------------------ + + +def test_metadata_roundtrip(): + meta = StatsMetadata.model_validate(_good_params()) + assert meta.total_px == 24 * 256 * 256 + assert meta.schema_version == SCHEMA_VERSION + again = StatsMetadata.model_validate(meta.model_dump(mode="json")) + assert again == meta + + +def test_metadata_ignores_unknown_keys(): + meta = StatsMetadata.model_validate(_good_params(some_future_field=123)) + assert not hasattr(meta, "some_future_field") + + +@pytest.mark.parametrize( + "override", + [ + {"time_depth": 0}, + {"width": -1}, + {"max_nan": -5}, + {"data_kind": "snowfall"}, + {"wet_threshold": -0.1}, + {"start_date": "2021-01-01T00:00:00", "end_date": "2020-01-01T00:00:00"}, + {"max_nan": 24 * 256 * 256 + 1}, + ], +) +def test_metadata_rejects_bad_values(override): + with pytest.raises(ValidationError): + StatsMetadata.model_validate(_good_params(**override)) + + +def test_metadata_missing_required_key(): + params = _good_params() + del params["data_kind"] + with pytest.raises(ValidationError): + StatsMetadata.model_validate(params) + + +# --- File validation ---------------------------------------------------------- + + +def test_validate_clean_file(tmp_path): + meta = StatsMetadata.model_validate(_good_params()) + path = str(tmp_path / "stats.parquet") + _write_parquet(path, meta, _good_rows(meta)) + + report = validate_stats_parquet(path) + assert report.ok, report.errors + assert read_metadata(path) == meta + + +def test_validate_detects_bad_frac_wet(tmp_path): + meta = StatsMetadata.model_validate(_good_params()) + rows = _good_rows(meta) + rows["frac_wet"] = rows["frac_wet"].copy() + rows["frac_wet"][0] = 1.5 + path = str(tmp_path / "stats.parquet") + _write_parquet(path, meta, rows) + + report = validate_stats_parquet(path) + assert not report.ok + assert any("frac_wet" in e for e in report.errors) + + +def test_validate_detects_nan_count_over_max(tmp_path): + meta = StatsMetadata.model_validate(_good_params()) + rows = _good_rows(meta) + rows["nan_count"] = rows["nan_count"].copy() + rows["nan_count"][0] = meta.max_nan + 1 + path = str(tmp_path / "stats.parquet") + _write_parquet(path, meta, rows) + + report = validate_stats_parquet(path) + assert not report.ok + assert any("max_nan" in e for e in report.errors) + + +def test_validate_detects_wrong_dtype(tmp_path): + meta = StatsMetadata.model_validate(_good_params()) + # Build a schema where 't' is int64 instead of int32. + bad_schema = pa.schema( + [("t", pa.int64())] + [(n, STATS_SCHEMA.field(n).type) for n in STATS_SCHEMA.names[1:]] + ).with_metadata(build_schema(meta).metadata) + rows = _good_rows(meta) + path = str(tmp_path / "stats.parquet") + batch = pa.record_batch( + [pa.array(rows[c], type=bad_schema.field(c).type) for c in STATS_SCHEMA.names], + schema=bad_schema, + ) + with pq.ParquetWriter(path, bad_schema) as w: + w.write_batch(batch) + + report = validate_stats_parquet(path) + assert not report.ok + assert any("dtype" in e and "t'" in e for e in report.errors) + + +def test_validate_missing_metadata_key(tmp_path): + # A parquet with correct columns but no mlcast.stats metadata. + meta = StatsMetadata.model_validate(_good_params()) + rows = _good_rows(meta) + path = str(tmp_path / "plain.parquet") + batch = pa.record_batch([pa.array(rows[c]) for c in STATS_SCHEMA.names], schema=STATS_SCHEMA) + with pq.ParquetWriter(path, STATS_SCHEMA) as w: + w.write_batch(batch) + + report = validate_stats_parquet(path) + assert not report.ok + assert any("mlcast.stats" in e for e in report.errors) + + +def test_validate_corrupt_metadata_payload(tmp_path): + meta = StatsMetadata.model_validate(_good_params()) + rows = _good_rows(meta) + bad_meta = {STATS_METADATA_KEY: json.dumps({"time_depth": -1}).encode()} + schema = STATS_SCHEMA.with_metadata(bad_meta) + path = str(tmp_path / "corrupt.parquet") + batch = pa.record_batch([pa.array(rows[c]) for c in STATS_SCHEMA.names], schema=schema) + with pq.ParquetWriter(path, schema) as w: + w.write_batch(batch) + + report = validate_stats_parquet(path) + assert not report.ok + assert any("metadata" in e for e in report.errors) diff --git a/uv.lock b/uv.lock index b5175f4..7b3b8fa 100644 --- a/uv.lock +++ b/uv.lock @@ -277,6 +277,59 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b1/fc/24cc0a47c824f13933e210e9ad034b4fba22f7185b8d904c0fbf5a3b2be8/botocore-1.42.91-py3-none-any.whl", hash = "sha256:7a28c3cc6bfab5724ad18899d52402b776a0de7d87fa20c3c5270bcaaf199ce8", size = 14897344, upload-time = "2026-04-17T19:30:44.245Z" }, ] +[[package]] +name = "bottleneck" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/14/d8/6d641573e210768816023a64966d66463f2ce9fc9945fa03290c8a18f87c/bottleneck-1.6.0.tar.gz", hash = "sha256:028d46ee4b025ad9ab4d79924113816f825f62b17b87c9e1d0d8ce144a4a0e31", size = 104311, upload-time = "2025-09-08T16:30:38.617Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/96/9d51012d729f97de1e75aad986f3ba50956742a40fc99cbab4c2aa896c1c/bottleneck-1.6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:69ef4514782afe39db2497aaea93b1c167ab7ab3bc5e3930500ef9cf11841db7", size = 100400, upload-time = "2025-09-08T16:29:44.464Z" }, + { url = "https://files.pythonhosted.org/packages/16/f4/4fcbebcbc42376a77e395a6838575950587e5eb82edf47d103f8daa7ba22/bottleneck-1.6.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:727363f99edc6dc83d52ed28224d4cb858c07a01c336c7499c0c2e5dd4fd3e4a", size = 375920, upload-time = "2025-09-08T16:29:45.52Z" }, + { url = "https://files.pythonhosted.org/packages/36/13/7fa8cdc41cbf2dfe0540f98e1e0caf9ffbd681b1a0fc679a91c2698adaf9/bottleneck-1.6.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:847671a9e392220d1dfd2ff2524b4d61ec47b2a36ea78e169d2aa357fd9d933a", size = 367922, upload-time = "2025-09-08T16:29:46.743Z" }, + { url = "https://files.pythonhosted.org/packages/13/7d/dccfa4a2792c1bdc0efdde8267e527727e517df1ff0d4976b84e0268c2f9/bottleneck-1.6.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:daef2603ab7b4ec4f032bb54facf5fa92dacd3a264c2fd9677c9fc22bcb5a245", size = 361379, upload-time = "2025-09-08T16:29:48.042Z" }, + { url = "https://files.pythonhosted.org/packages/93/42/21c0fad823b71c3a8904cbb847ad45136d25573a2d001a9cff48d3985fab/bottleneck-1.6.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:fc7f09bda980d967f2e9f1a746eda57479f824f66de0b92b9835c431a8c922d4", size = 371911, upload-time = "2025-09-08T16:29:49.366Z" }, + { url = "https://files.pythonhosted.org/packages/3b/b0/830ff80f8c74577d53034c494639eac7a0ffc70935c01ceadfbe77f590c2/bottleneck-1.6.0-cp311-cp311-win32.whl", hash = "sha256:1f78bad13ad190180f73cceb92d22f4101bde3d768f4647030089f704ae7cac7", size = 107831, upload-time = "2025-09-08T16:29:51.397Z" }, + { url = "https://files.pythonhosted.org/packages/6f/42/01d4920b0aa51fba503f112c90714547609bbe17b6ecfc1c7ae1da3183df/bottleneck-1.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:8f2adef59fdb9edf2983fe3a4c07e5d1b677c43e5669f4711da2c3daad8321ad", size = 113358, upload-time = "2025-09-08T16:29:52.602Z" }, + { url = "https://files.pythonhosted.org/packages/8d/72/7e3593a2a3dd69ec831a9981a7b1443647acb66a5aec34c1620a5f7f8498/bottleneck-1.6.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3bb16a16a86a655fdbb34df672109a8a227bb5f9c9cf5bb8ae400a639bc52fa3", size = 100515, upload-time = "2025-09-08T16:29:55.141Z" }, + { url = "https://files.pythonhosted.org/packages/b5/d4/e7bbea08f4c0f0bab819d38c1a613da5f194fba7b19aae3e2b3a27e78886/bottleneck-1.6.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:0fbf5d0787af9aee6cef4db9cdd14975ce24bd02e0cc30155a51411ebe2ff35f", size = 377451, upload-time = "2025-09-08T16:29:56.718Z" }, + { url = "https://files.pythonhosted.org/packages/fe/80/a6da430e3b1a12fd85f9fe90d3ad8fe9a527ecb046644c37b4b3f4baacfc/bottleneck-1.6.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d08966f4a22384862258940346a72087a6f7cebb19038fbf3a3f6690ee7fd39f", size = 368303, upload-time = "2025-09-08T16:29:57.834Z" }, + { url = "https://files.pythonhosted.org/packages/30/11/abd30a49f3251f4538430e5f876df96f2b39dabf49e05c5836820d2c31fe/bottleneck-1.6.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:604f0b898b43b7bc631c564630e936a8759d2d952641c8b02f71e31dbcd9deaa", size = 361232, upload-time = "2025-09-08T16:29:59.104Z" }, + { url = "https://files.pythonhosted.org/packages/1d/ac/1c0e09d8d92b9951f675bd42463ce76c3c3657b31c5bf53ca1f6dd9eccff/bottleneck-1.6.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d33720bad761e642abc18eda5f188ff2841191c9f63f9d0c052245decc0faeb9", size = 373234, upload-time = "2025-09-08T16:30:00.488Z" }, + { url = "https://files.pythonhosted.org/packages/fb/ea/382c572ae3057ba885d484726bb63629d1f63abedf91c6cd23974eb35a9b/bottleneck-1.6.0-cp312-cp312-win32.whl", hash = "sha256:a1e5907ec2714efbe7075d9207b58c22ab6984a59102e4ecd78dced80dab8374", size = 108020, upload-time = "2025-09-08T16:30:01.773Z" }, + { url = "https://files.pythonhosted.org/packages/48/ad/d71da675eef85ac153eef5111ca0caa924548c9591da00939bcabba8de8e/bottleneck-1.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:81e3822499f057a917b7d3972ebc631ac63c6bbcc79ad3542a66c4c40634e3a6", size = 113493, upload-time = "2025-09-08T16:30:02.872Z" }, + { url = "https://files.pythonhosted.org/packages/97/1a/e117cd5ff7056126d3291deb29ac8066476e60b852555b95beb3fc9d62a0/bottleneck-1.6.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d015de414ca016ebe56440bdf5d3d1204085080527a3c51f5b7b7a3e704fe6fd", size = 100521, upload-time = "2025-09-08T16:30:03.89Z" }, + { url = "https://files.pythonhosted.org/packages/bd/22/05555a9752357e24caa1cd92324d1a7fdde6386aab162fcc451f8f8eedc2/bottleneck-1.6.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:456757c9525b0b12356f472e38020ed4b76b18375fd76e055f8d33fb62956f5e", size = 377719, upload-time = "2025-09-08T16:30:05.135Z" }, + { url = "https://files.pythonhosted.org/packages/11/ee/76593af47097d9633109bed04dbcf2170707dd84313ca29f436f9234bc51/bottleneck-1.6.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1c65254d51b6063c55f6272f175e867e2078342ae75f74be29d6612e9627b2c0", size = 368577, upload-time = "2025-09-08T16:30:06.387Z" }, + { url = "https://files.pythonhosted.org/packages/f9/f7/4dcacaf637d2b8d89ea746c74159adda43858d47358978880614c3fa4391/bottleneck-1.6.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a172322895fbb79c6127474f1b0db0866895f0b804a18d5c6b841fea093927fe", size = 361441, upload-time = "2025-09-08T16:30:07.613Z" }, + { url = "https://files.pythonhosted.org/packages/05/34/21eb1eb1c42cb7be2872d0647c292fc75768d14e1f0db66bf907b24b2464/bottleneck-1.6.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d5e81b642eb0d5a5bf00312598d7ed142d389728b694322a118c26813f3d1fa9", size = 373416, upload-time = "2025-09-08T16:30:08.899Z" }, + { url = "https://files.pythonhosted.org/packages/48/cb/7957ff40367a151139b5f1854616bf92e578f10804d226fbcdecfd73aead/bottleneck-1.6.0-cp313-cp313-win32.whl", hash = "sha256:543d3a89d22880cd322e44caff859af6c0489657bf9897977d1f5d3d3f77299c", size = 108029, upload-time = "2025-09-08T16:30:09.909Z" }, + { url = "https://files.pythonhosted.org/packages/90/a8/735df4156fa5595501d5d96a6ee102f49c13d2ce9e2a287ad51806bc3ba0/bottleneck-1.6.0-cp313-cp313-win_amd64.whl", hash = "sha256:48a44307d604ceb81e256903e5d57d3adb96a461b1d3c6a69baa2c67e823bd36", size = 113497, upload-time = "2025-09-08T16:30:10.82Z" }, + { url = "https://files.pythonhosted.org/packages/c7/5c/8c1260df8ade7cebc2a8af513a27082b5e36aa4a5fb762d56ea6d969d893/bottleneck-1.6.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:547e6715115867c4657c9ae8cc5ddac1fec8fdad66690be3a322a7488721b06b", size = 101606, upload-time = "2025-09-08T16:30:11.935Z" }, + { url = "https://files.pythonhosted.org/packages/ce/ea/f03e2944e91ee962922c834ed21e5be6d067c8395681f5dc6c67a0a26853/bottleneck-1.6.0-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:5e4a4a6e05b6f014c307969129e10d1a0afd18f3a2c127b085532a4a76677aef", size = 391804, upload-time = "2025-09-08T16:30:13.13Z" }, + { url = "https://files.pythonhosted.org/packages/0b/58/2b356b8a81eb97637dccee6cf58237198dd828890e38be9afb4e5e58e38e/bottleneck-1.6.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2baae0d1589b4a520b2f9cf03528c0c8b20717b3f05675e212ec2200cf628f12", size = 383443, upload-time = "2025-09-08T16:30:14.318Z" }, + { url = "https://files.pythonhosted.org/packages/55/52/cf7d09ed3736ad0d50c624787f9b580ae3206494d95cc0f4814b93eef728/bottleneck-1.6.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:2e407139b322f01d8d5b6b2e8091b810f48a25c7fa5c678cfcdc420dfe8aea0a", size = 375458, upload-time = "2025-09-08T16:30:15.379Z" }, + { url = "https://files.pythonhosted.org/packages/c4/e9/7c87a34a24e339860064f20fac49f6738e94f1717bc8726b9c47705601d8/bottleneck-1.6.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:1adefb89b92aba6de9c6ea871d99bcd29d519f4fb012cc5197917813b4fc2c7f", size = 386384, upload-time = "2025-09-08T16:30:17.012Z" }, + { url = "https://files.pythonhosted.org/packages/59/57/db51855e18a47671801180be748939b4c9422a0544849af1919116346b5f/bottleneck-1.6.0-cp313-cp313t-win32.whl", hash = "sha256:64b8690393494074923780f6abdf5f5577d844b9d9689725d1575a936e74e5f0", size = 109448, upload-time = "2025-09-08T16:30:18.076Z" }, + { url = "https://files.pythonhosted.org/packages/bd/1e/683c090b624f13a5bf88a0be2241dc301e98b2fb72a45812a7ae6e456cc4/bottleneck-1.6.0-cp313-cp313t-win_amd64.whl", hash = "sha256:cb67247f65dcdf62af947c76c6c8b77d9f0ead442cac0edbaa17850d6da4e48d", size = 115190, upload-time = "2025-09-08T16:30:19.106Z" }, + { url = "https://files.pythonhosted.org/packages/77/e2/eb7c08964a3f3c4719f98795ccd21807ee9dd3071a0f9ad652a5f19196ff/bottleneck-1.6.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:98f1d789042511a0f042b3bdcd2903e8567e956d3aa3be189cce3746daeb8550", size = 100544, upload-time = "2025-09-08T16:30:20.22Z" }, + { url = "https://files.pythonhosted.org/packages/99/ec/c6f3be848f37689f481797ce7d9807d5f69a199d7fc0e46044f9b708c468/bottleneck-1.6.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:1fad24c99e39ad7623fc2a76d37feb26bd32e4dd170885edf4dbf4bfce2199a3", size = 378315, upload-time = "2025-09-08T16:30:21.409Z" }, + { url = "https://files.pythonhosted.org/packages/bf/8f/2d6600836e2ea8f14fcefac592dc83497e5b88d381470c958cb9cdf88706/bottleneck-1.6.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:643e61e50a6f993debc399b495a1609a55b3bd76b057e433e4089505d9f605c7", size = 368978, upload-time = "2025-09-08T16:30:23.458Z" }, + { url = "https://files.pythonhosted.org/packages/9b/b5/bf72b49f5040212873b985feef5050015645e0a02204b591e1d265fc522a/bottleneck-1.6.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:fa668efbe4c6b200524ea0ebd537212da9b9801287138016fdf64119d6fcf201", size = 362074, upload-time = "2025-09-08T16:30:24.71Z" }, + { url = "https://files.pythonhosted.org/packages/1d/c8/c4891a0604eb680031390182c6e264247e3a9a8d067d654362245396fadf/bottleneck-1.6.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:9f7dd35262e89e28fedd79d45022394b1fa1aceb61d2e747c6d6842e50546daa", size = 374019, upload-time = "2025-09-08T16:30:26.438Z" }, + { url = "https://files.pythonhosted.org/packages/e6/2d/ed096f8d1b9147e84914045dd89bc64e3c32eee49b862d1e20d573a9ab0d/bottleneck-1.6.0-cp314-cp314-win32.whl", hash = "sha256:bd90bec3c470b7fdfafc2fbdcd7a1c55a4e57b5cdad88d40eea5bc9bab759bf1", size = 110173, upload-time = "2025-09-08T16:30:27.521Z" }, + { url = "https://files.pythonhosted.org/packages/33/70/1414acb6ae378a15063cfb19a0a39d69d1b6baae1120a64d2b069902549b/bottleneck-1.6.0-cp314-cp314-win_amd64.whl", hash = "sha256:b43b6d36a62ffdedc6368cf9a708e4d0a30d98656c2b5f33d88894e1bcfd6857", size = 115899, upload-time = "2025-09-08T16:30:28.524Z" }, + { url = "https://files.pythonhosted.org/packages/4e/ed/4570b5d8c1c85ce3c54963ebc37472231ed54f0b0d8dbb5dde14303f775f/bottleneck-1.6.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:53296707a8e195b5dcaa804b714bd222b5e446bd93cd496008122277eb43fa87", size = 101615, upload-time = "2025-09-08T16:30:29.556Z" }, + { url = "https://files.pythonhosted.org/packages/2d/93/c148faa07ae91f266be1f3fad1fde95aa2449e12937f3f3df2dd720b86e0/bottleneck-1.6.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d6df19cc48a83efd70f6d6874332aa31c3f5ca06a98b782449064abbd564cf0e", size = 392411, upload-time = "2025-09-08T16:30:31.186Z" }, + { url = "https://files.pythonhosted.org/packages/6e/1c/e6ad221d345a059e7efb2ad1d46a22d9fdae0486faef70555766e1123966/bottleneck-1.6.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:96bb3a52cb3c0aadfedce3106f93ab940a49c9d35cd4ed612e031f6deb27e80f", size = 384022, upload-time = "2025-09-08T16:30:32.364Z" }, + { url = "https://files.pythonhosted.org/packages/4f/40/5b15c01eb8c59d59bc84c94d01d3d30797c961f10ec190f53c27e05d62ab/bottleneck-1.6.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:d1db9e831b69d5595b12e79aeb04cb02873db35576467c8dd26cdc1ee6b74581", size = 376004, upload-time = "2025-09-08T16:30:33.731Z" }, + { url = "https://files.pythonhosted.org/packages/74/f6/cb228f5949553a5c01d1d5a3c933f0216d78540d9e0bf8dd4343bb449681/bottleneck-1.6.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:4dd7ac619570865fcb7a0e8925df418005f076286ad2c702dd0f447231d7a055", size = 386909, upload-time = "2025-09-08T16:30:34.973Z" }, + { url = "https://files.pythonhosted.org/packages/09/9a/425065c37a67a9120bf53290371579b83d05bf46f3212cce65d8c01d470a/bottleneck-1.6.0-cp314-cp314t-win32.whl", hash = "sha256:7fb694165df95d428fe00b98b9ea7d126ef786c4a4b7d43ae2530248396cadcb", size = 111636, upload-time = "2025-09-08T16:30:36.044Z" }, + { url = "https://files.pythonhosted.org/packages/ad/23/c41006e42909ec5114a8961818412310aa54646d1eae0495dbff3598a095/bottleneck-1.6.0-cp314-cp314t-win_amd64.whl", hash = "sha256:174b80930ce82bd8456c67f1abb28a5975c68db49d254783ce2cb6983b4fea40", size = 117611, upload-time = "2025-09-08T16:30:37.055Z" }, +] + [[package]] name = "cachetools" version = "7.0.6" @@ -2285,6 +2338,9 @@ gpu-cu130 = [ { name = "torch", version = "2.11.0+cu130", source = { registry = "https://download.pytorch.org/whl/cu130" } }, { name = "torchvision", version = "0.26.0+cu130", source = { registry = "https://download.pytorch.org/whl/cu130" } }, ] +sampling = [ + { name = "bottleneck" }, +] [package.dev-dependencies] dev = [ @@ -2296,6 +2352,7 @@ requires-dist = [ { name = "absl-py", specifier = ">=2.4" }, { name = "aiohttp", marker = "extra == 'dev'", specifier = ">=3.9.3" }, { name = "beartype", specifier = ">=0.18" }, + { name = "bottleneck", marker = "extra == 'sampling'", specifier = ">=1.6" }, { name = "cf-xarray", specifier = ">=0.10" }, { name = "etils", specifier = ">=1.13" }, { name = "fiddle", specifier = ">=0.3" }, @@ -2332,7 +2389,7 @@ requires-dist = [ { name = "xarray", specifier = ">=2025.6.1" }, { name = "zarr", specifier = ">=3" }, ] -provides-extras = ["dev", "gpu-cu128", "gpu-cu130"] +provides-extras = ["dev", "gpu-cu128", "gpu-cu130", "sampling"] [package.metadata.requires-dev] dev = [{ name = "pytest", specifier = ">=9.0.3" }] From cc7d3aeacc9955bc06a3fb84e955ac97a5fd4c7f Mon Sep 17 00:00:00 2001 From: Gabriele Franch Date: Sat, 6 Jun 2026 15:42:45 +0200 Subject: [PATCH 05/15] Simplify: make bottleneck core, drop the [sampling] extra and fire bottleneck is a tiny C lib (pandas already loads it when present), so gating the data-prep CLI behind an extra was over-engineering. Making it core lets us delete the extra, the lazy 'install the extra' error path, the importorskip('bottleneck') test guards, and the E402 per-file ignore. Also drop fire, which was an unused dependency (+ its termcolor transitive). Command modules are still imported lazily in the CLI dispatch (keeps `mlcast train` startup light); producer tests now run by default (no --extra). Full suite green (114 passed, 1 skipped). --- pyproject.toml | 10 +-------- src/mlcast/__main__.py | 14 ++++++------- tests/sampling/test_stats_gpu.py | 8 +++---- tests/sampling/test_stats_process.py | 2 -- uv.lock | 31 +++------------------------- 5 files changed, 13 insertions(+), 52 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 93047cd..7c17f19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,10 +34,10 @@ dynamic = [ "version" ] dependencies = [ "absl-py>=2.4", "beartype>=0.18", + "bottleneck>=1.6", "cf-xarray>=0.10", "etils>=1.13", "fiddle>=0.3", - "fire>=0.7", "importlib-resources>=5", "ipykernel>=6.29.5", "jaxtyping>=0.2", @@ -78,11 +78,6 @@ optional-dependencies.gpu-cu130 = [ "torch>=2.7.1", "torchvision>=0.20", ] -# Producer side: the `mlcast stats` / `mlcast validate-stats` data-prep CLI. -# Only `bottleneck` is extra; zarr/xarray/pyarrow/rich/torch are already core. -optional-dependencies.sampling = [ - "bottleneck>=1.6", -] urls.Documentation = "https://github.com/mlcast-community/mlcast" urls.Homepage = "https://github.com/mlcast-community/mlcast" urls.Issues = "https://github.com/mlcast-community/mlcast/issues" @@ -116,9 +111,6 @@ lint.select = [ "W", # pycodestyle warnings ] lint.ignore = [ "F722" ] -# These tests `importorskip` the sampling extra before importing the command -# modules, so their imports are intentionally not at the top of the file. -lint.per-file-ignores."tests/sampling/test_stats_*.py" = [ "E402" ] [tool.mypy] strict = false diff --git a/src/mlcast/__main__.py b/src/mlcast/__main__.py index 0d7124e..0701d13 100644 --- a/src/mlcast/__main__.py +++ b/src/mlcast/__main__.py @@ -335,14 +335,12 @@ def train_main(argv: list[str]) -> None: def _run_sampling_command(name: str, remaining: list[str]) -> None: """Dispatch the data-prep subcommands (``stats`` / ``validate-stats``). - Imported lazily so ``mlcast train`` never pulls the sampling deps. + Imported lazily to keep ``mlcast train`` startup light. """ - try: - from mlcast.sampling import commands - except ImportError as exc: # bottleneck (the [sampling] extra) not installed - raise SystemExit(f"`mlcast {name}` needs the sampling extra: pip install 'mlcast[sampling]' ({exc})") from None from loguru import logger + from mlcast.sampling import commands + logger.remove() logger.add(sys.stderr, level="INFO") module = commands.stats if name == "stats" else commands.validate_stats @@ -388,12 +386,12 @@ def cli() -> None: ) # Data-prep subcommands (plain argparse, no Fiddle). Defined bare here; the - # command modules — and the heavy windowing deps — are imported lazily in - # the dispatch below, so `mlcast train` never pays for them. + # command modules are imported lazily in the dispatch below to keep + # `mlcast train` startup light. subparsers.add_parser( "stats", add_help=False, - help="Scan a Zarr dataset and write per-datacube stats to parquet (needs mlcast[sampling]).", + help="Scan a Zarr dataset and write per-datacube stats to parquet.", ) subparsers.add_parser( "validate-stats", diff --git a/tests/sampling/test_stats_gpu.py b/tests/sampling/test_stats_gpu.py index e383094..01c810d 100644 --- a/tests/sampling/test_stats_gpu.py +++ b/tests/sampling/test_stats_gpu.py @@ -9,11 +9,7 @@ import numpy as np import pytest - -torch = pytest.importorskip("torch") -pytest.importorskip("bottleneck") # the stats command pulls bottleneck (the sampling extra) -if not torch.cuda.is_available(): # pragma: no cover - pytest.skip("no CUDA GPU available", allow_module_level=True) +import torch # Reuse the CPU test's data generator, brute force, lexsort, and cases. from test_stats_process import CASES, _brute_force, _lexsort, _make_data @@ -21,6 +17,8 @@ from mlcast.sampling.commands import _stats_gpu from mlcast.sampling.commands.stats import _process_chunk +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="no CUDA GPU available") + @pytest.mark.parametrize("seed,deltas,steps,max_nan,wet_thr,start_t,t_start_idx", CASES) def test_gpu_matches_cpu_and_brute(seed, deltas, steps, max_nan, wet_thr, start_t, t_start_idx): diff --git a/tests/sampling/test_stats_process.py b/tests/sampling/test_stats_process.py index bde4a53..e0cffc1 100644 --- a/tests/sampling/test_stats_process.py +++ b/tests/sampling/test_stats_process.py @@ -18,8 +18,6 @@ import numpy as np import pytest -pytest.importorskip("bottleneck") # the stats command pulls bottleneck (the sampling extra) - from mlcast.sampling.commands.stats import ( _datacube_window_sum, _process_chunk, diff --git a/uv.lock b/uv.lock index 7b3b8fa..13c1f17 100644 --- a/uv.lock +++ b/uv.lock @@ -1137,18 +1137,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/47/dd9a212ef6e343a6857485ffe25bba537304f1913bdbed446a23f7f592e1/filelock-3.29.0-py3-none-any.whl", hash = "sha256:96f5f6344709aa1572bbf631c640e4ebeeb519e08da902c39a001882f30ac258", size = 39812, upload-time = "2026-04-19T15:39:08.752Z" }, ] -[[package]] -name = "fire" -version = "0.7.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "termcolor" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/c0/00/f8d10588d2019d6d6452653def1ee807353b21983db48550318424b5ff18/fire-0.7.1.tar.gz", hash = "sha256:3b208f05c736de98fb343310d090dcc4d8c78b2a89ea4f32b837c586270a9cbf", size = 88720, upload-time = "2025-08-16T20:20:24.175Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e5/4c/93d0f85318da65923e4b91c1c2ff03d8a458cbefebe3bc612a6693c7906d/fire-0.7.1-py3-none-any.whl", hash = "sha256:e43fd8a5033a9001e7e2973bab96070694b9f12f2e0ecf96d4683971b5ab1882", size = 115945, upload-time = "2025-08-16T20:20:22.87Z" }, -] - [[package]] name = "flask" version = "3.1.3" @@ -2289,10 +2277,10 @@ source = { editable = "." } dependencies = [ { name = "absl-py" }, { name = "beartype" }, + { name = "bottleneck" }, { name = "cf-xarray" }, { name = "etils" }, { name = "fiddle" }, - { name = "fire" }, { name = "importlib-resources" }, { name = "ipykernel" }, { name = "jaxtyping" }, @@ -2338,9 +2326,6 @@ gpu-cu130 = [ { name = "torch", version = "2.11.0+cu130", source = { registry = "https://download.pytorch.org/whl/cu130" } }, { name = "torchvision", version = "0.26.0+cu130", source = { registry = "https://download.pytorch.org/whl/cu130" } }, ] -sampling = [ - { name = "bottleneck" }, -] [package.dev-dependencies] dev = [ @@ -2352,11 +2337,10 @@ requires-dist = [ { name = "absl-py", specifier = ">=2.4" }, { name = "aiohttp", marker = "extra == 'dev'", specifier = ">=3.9.3" }, { name = "beartype", specifier = ">=0.18" }, - { name = "bottleneck", marker = "extra == 'sampling'", specifier = ">=1.6" }, + { name = "bottleneck", specifier = ">=1.6" }, { name = "cf-xarray", specifier = ">=0.10" }, { name = "etils", specifier = ">=1.13" }, { name = "fiddle", specifier = ">=0.3" }, - { name = "fire", specifier = ">=0.7" }, { name = "fsspec", marker = "extra == 'dev'", specifier = ">=2024.2" }, { name = "importlib-resources", specifier = ">=5" }, { name = "ipykernel", specifier = ">=6.29.5" }, @@ -2389,7 +2373,7 @@ requires-dist = [ { name = "xarray", specifier = ">=2025.6.1" }, { name = "zarr", specifier = ">=3" }, ] -provides-extras = ["dev", "gpu-cu128", "gpu-cu130", "sampling"] +provides-extras = ["dev", "gpu-cu128", "gpu-cu130"] [package.metadata.requires-dev] dev = [{ name = "pytest", specifier = ">=9.0.3" }] @@ -4406,15 +4390,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/73/c6/825dab04195756cf8ff2e12698f22513b3db2f64925bdd41671bfb33aaa5/tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530", size = 6590363, upload-time = "2023-10-23T21:23:35.583Z" }, ] -[[package]] -name = "termcolor" -version = "3.3.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/46/79/cf31d7a93a8fdc6aa0fbb665be84426a8c5a557d9240b6239e9e11e35fc5/termcolor-3.3.0.tar.gz", hash = "sha256:348871ca648ec6a9a983a13ab626c0acce02f515b9e1983332b17af7979521c5", size = 14434, upload-time = "2025-12-29T12:55:21.882Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/33/d1/8bb87d21e9aeb323cc03034f5eaf2c8f69841e40e4853c2627edf8111ed3/termcolor-3.3.0-py3-none-any.whl", hash = "sha256:cf642efadaf0a8ebbbf4bc7a31cec2f9b5f21a9f726f4ccbb08192c9c26f43a5", size = 7734, upload-time = "2025-12-29T12:55:20.718Z" }, -] - [[package]] name = "threadpoolctl" version = "3.6.0" From a76b7154759bd3c546a9035f9507c50b2bb3d3e9 Mon Sep 17 00:00:00 2001 From: Gabriele Franch Date: Sat, 6 Jun 2026 15:59:18 +0200 Subject: [PATCH 06/15] Speed up the mlcast CLI: defer the training stack to the train path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `mlcast -h` took 4.2s because importing __main__ eagerly pulled the whole training stack (torch, Fiddle, absl, the model/data config) and cli() built the Fiddle config graph just to render help — none of which -h, stats, or validate-stats need. - from __future__ import annotations + TYPE_CHECKING so the heavy imports aren't needed for type hints. - Move torch/Fiddle/absl/config imports and the absl flag definitions into a lazy _define_train_flags() + per-function imports, run only for `train`. - Build the rich `train` help lazily (a factory on the parser), so the config graph is constructed only for `mlcast train -h`. mlcast -h and import __main__: 4.22s -> 0.08s. train / train -h behaviour unchanged; full suite green. --- src/mlcast/__main__.py | 112 ++++++++++++++++++++++++++--------------- 1 file changed, 71 insertions(+), 41 deletions(-) diff --git a/src/mlcast/__main__.py b/src/mlcast/__main__.py index 0701d13..9c4202e 100644 --- a/src/mlcast/__main__.py +++ b/src/mlcast/__main__.py @@ -22,34 +22,45 @@ python -m mlcast train --config=config:another_experiment_function """ +from __future__ import annotations + import argparse import ast import sys - -import fiddle as fdl -import torch -from absl import app, flags -from fiddle import absl_flags -from rich import print as rprint -from rich.console import Console -from rich.text import Text - -from . import config # noqa: F401 — module must be importable for absl_flags -from .config import load_yaml_config, train_from_config, training_experiment - -FLAGS = flags.FLAGS - -_config = absl_flags.DEFINE_fiddle_config( - "config", - default_module=config, - help_string="Experiment configuration. Default is training_experiment.", -) - -flags.DEFINE_boolean( - "print_config_and_exit", - False, - "Print the resolved experiment config and exit without training.", -) +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import fiddle as fdl + from rich.text import Text + +# The training stack (torch, Fiddle, absl, the model/data config) is heavy, so +# it is imported only on the `train` path; `mlcast -h`, `mlcast stats`, and +# `mlcast validate-stats` stay fast. These globals are populated by +# `_define_train_flags()`, called from `cli()` for the `train` command. +FLAGS = None +_config = None + + +def _define_train_flags() -> None: + """Import the Fiddle/absl machinery and define the `train` flags (once).""" + global FLAGS, _config + from absl import flags + from fiddle import absl_flags + + from . import config + + FLAGS = flags.FLAGS + if _config is None: + _config = absl_flags.DEFINE_fiddle_config( + "config", + default_module=config, + help_string="Experiment configuration. Default is training_experiment.", + ) + flags.DEFINE_boolean( + "print_config_and_exit", + False, + "Print the resolved experiment config and exit without training.", + ) def get_cli_examples(cfg: fdl.Buildable) -> list[tuple[str, str]]: @@ -104,6 +115,8 @@ def get_fiddler_examples() -> list[tuple[str, str]]: def _build_help_text(cfg: fdl.Buildable) -> Text: """Build Rich-highlighted help text for the ``train`` subcommand.""" + from rich.text import Text + t = Text() t.append("Train a model using a Fiddle configuration.\n\n", style="bold") @@ -156,17 +169,21 @@ def _build_help_text(cfg: fdl.Buildable) -> Text: class _RichHelpParser(argparse.ArgumentParser): """ArgumentParser that renders the description with Rich when ``--help`` is requested.""" - _rich_description: Text | None = None + # A callable returning a rich Text, invoked only when help is printed — so + # the (expensive) train config graph is built only for `mlcast train -h`. + _rich_description_factory = None def print_help(self, file=None) -> None: # type: ignore[override] + from rich.console import Console + console = Console(file=file or sys.stdout) # Print usage line first (plain argparse) formatter = self._get_formatter() formatter.add_usage(self.usage, self._actions, self._mutually_exclusive_groups) console.print(formatter.format_help(), end="") - # Rich description - if self._rich_description is not None: - console.print(self._rich_description) + # Rich description (built on demand) + if self._rich_description_factory is not None: + console.print(self._rich_description_factory()) else: console.print(self.description or "") # Standard options section @@ -178,6 +195,19 @@ def print_help(self, file=None) -> None: # type: ignore[override] console.print(formatter2.format_help(), end="") +def _build_train_help() -> Text: + """Build the (expensive) Rich help for ``train`` — only when ``-h`` is asked.""" + from .config import training_experiment + + try: + cfg = training_experiment.as_buildable() + return _build_help_text(cfg) + except Exception: + from rich.text import Text + + return Text("Train a model. Overrides can be passed via --config set:key=value") + + def auto_quote_fiddle_strings(remaining: list[str]) -> list[str]: """Auto-quotes unquoted string values in Fiddle set: overrides. @@ -282,6 +312,8 @@ def _seed_fiddle_flag_from_yaml(yaml_path: str) -> None: yaml_path : str Path to the YAML config file to load. """ + from .config import load_yaml_config + cfg = load_yaml_config(yaml_path) FLAGS["config"]._value = cfg FLAGS["config"].first_command = "config" @@ -299,6 +331,10 @@ def train_main(argv: list[str]) -> None: argv : list of str The list of command-line arguments passed by absl. """ + import torch + from rich import print as rprint + + from .config import train_from_config # Catch legacy Fiddle flags and guide the user to the correct syntax. legacy_flags_used = [] @@ -356,15 +392,6 @@ def cli() -> None: overrides if no base configuration is provided, formats the `--help` output, and safely passes execution over to `absl.app.run`. """ - - # Dynamically generate help text showing Fiddle overrides - try: - cfg = training_experiment.as_buildable() - description_text = _build_help_text(cfg) - except Exception: - # Fallback if config generation fails during CLI initialization - description_text = "Train a model. Overrides can be passed via --config set:key=value" - parser = argparse.ArgumentParser( prog="mlcast", description="Entry point for mlcast. Uses Fiddle's absl_flags integration.", @@ -377,10 +404,10 @@ def cli() -> None: formatter_class=argparse.RawDescriptionHelpFormatter, add_help=False, ) - # Swap in our Rich-aware parser class and attach the highlighted description + # Swap in our Rich-aware parser; the highlighted description (which builds + # the Fiddle config graph) is produced lazily, only when `train -h` runs. train_parser.__class__ = _RichHelpParser - if isinstance(description_text, Text): - train_parser._rich_description = description_text # type: ignore[attr-defined] + train_parser._rich_description_factory = _build_train_help # type: ignore[attr-defined] train_parser.add_argument( "-h", "--help", action="help", default=argparse.SUPPRESS, help="Show this message and exit." ) @@ -402,6 +429,9 @@ def cli() -> None: args, remaining = parser.parse_known_args() if args.command == "train": + _define_train_flags() + from absl import app + # Case 1: user supplied a YAML file path as the base config # e.g. --config /path/to/config.yaml # Extract it from remaining; any set:/fiddler: flags that follow are From 6af947122bb588547831626e86ecd5614cc2e489 Mon Sep 17 00:00:00 2001 From: Gabriele Franch Date: Sat, 6 Jun 2026 16:19:13 +0200 Subject: [PATCH 07/15] Fix test_orchestrator leaking a MagicMock/ dir into the repo The patched fdl.build returns a MagicMock trainer, so the config dump falls through to Path(trainer.default_root_dir)/config.yaml. An unconfigured mock's __fspath__ renders as MagicMock//, so every run wrote junk under the repo root. The old line set trainer.log_dir, which that fallback branch never reads. Point default_root_dir at tmp_path so the write lands in pytest's tmp dir. --- tests/config/test_orchestrator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/config/test_orchestrator.py b/tests/config/test_orchestrator.py index 42d5a36..3d82218 100644 --- a/tests/config/test_orchestrator.py +++ b/tests/config/test_orchestrator.py @@ -6,7 +6,10 @@ @patch("mlcast.config.orchestrator.fdl.build") def test_train_from_config_valid(mock_build, tmp_path): """Verify that a valid configuration passes validation and builds.""" - mock_build.return_value.trainer.log_dir = str(tmp_path) + # default_root_dir is where _log_experiment_config_yaml_file falls back to + # writing config.yaml when the (mocked) logger isn't a recognised type. + # Point it at tmp_path so the write lands in pytest's tmp dir, not the repo. + mock_build.return_value.trainer.default_root_dir = str(tmp_path) cfg = training_experiment.as_buildable() train_from_config(cfg) mock_build.assert_called_once() From 5e20962ea0ec26d7beb8908d18487912b7e4c302 Mon Sep 17 00:00:00 2001 From: Gabriele Franch Date: Sat, 6 Jun 2026 16:29:03 +0200 Subject: [PATCH 08/15] docs: document the stats / validate-stats data-prep workflow Add a 'Preparing training data' README section covering mlcast stats (cumsum window scan -> stats parquet, with a flags table and example), mlcast validate-stats, and the sampler registry (uniform / importance, and the per-split default). Also document the missing use_ratio_splits fiddler, cross-link from the CLI section, and refresh the project tree with the sampling/ subpackage. --- README.md | 91 +++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 89 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ab27301..33af9b2 100644 --- a/README.md +++ b/README.md @@ -136,6 +136,10 @@ mlcast train --config fiddler:use_random_sampler --print_config_and_exit Run `mlcast train --help` for a full list of examples and available fiddlers. +Beyond training, the CLI provides two data-prep subcommands — `mlcast stats` +and `mlcast validate-stats` — for building and checking the datacube index that +training reads. See [Preparing training data](#preparing-training-data). + ### Python API The Python API gives you full programmatic control over the config graph before @@ -257,14 +261,92 @@ experiment.run() # trainer.fit() + trainer.test() | `set_variables` | `standard_names` | Sets the list of input variables on the dataset and updates `network.input_channels` to match | | `toggle_masking` | `enabled` | Toggles masked-loss mode by setting both `dataset_factory.return_mask` and `pl_module.masked_loss` to the same value | | `use_anon_s3_dataset` | `zarr_path`, `endpoint_url` | Points the dataset at an anonymous S3 object store; sets `zarr_path` and the required `storage_options` together | +| `use_ratio_splits` | `train`, `val` | Sets fraction-based train/val/test time splits on the data module (test = 1 − train − val) | | `use_random_sampler` | *(none)* | Switches the dataset factory to the on-the-fly random sampler (useful during development when no precomputed sampling index is available) | +## Preparing training data + +Training reads a **stats parquet**: a precomputed index of candidate datacubes — +fixed-size `time_depth × width × height` crops of a source radar Zarr — each +tagged with per-cube statistics (`nan_count`, `sum`, `mean`, `frac_wet`). The +training dataset [`SourceDataIndexedDataset`](src/mlcast/data/source_data_datasets.py) +iterates this index (it is the dataset factory's `index_path`), and a **sampler** +reshapes the candidate pool at dataset init. The producer, the schema, and the +samplers all live in [`mlcast.sampling`](src/mlcast/sampling/); two CLI +subcommands build and check the file. + +### `mlcast stats` — scan a Zarr dataset → stats parquet + +Slides a window over the `(time, x, y)` grid and, for every candidate datacube, +computes its statistics in O(1) per window via a cumulative-sum trick. Windows +are filtered by a maximum-NaN budget, a spatial/temporal stride, and +time-continuity (no frame gaps); the survivors are streamed to a single +zstd-compressed parquet whose footer carries every parameter as metadata (the +contract in [`stats_spec`](src/mlcast/sampling/stats_spec.py)), so downstream +commands never have to parse the filename. + +```bash +# A year of radar → 24-frame 256×256 datacubes, stride 3×16×16 +mlcast stats /data/radar.zarr \ + --start-date 2020-01-01 --end-date 2020-12-31 \ + --time-depth 24 --width 256 --height 256 \ + --step-t 3 --step-x 16 --step-y 16 \ + --max-nan 10000 \ + -o stats_2020.parquet +``` + +Common flags (`mlcast stats -h` lists them all): + +| Flag | Default | Purpose | +|------|---------|---------| +| `--time-depth` / `--width` / `--height` | 24 / 256 / 256 | Datacube shape `(T, X, Y)` | +| `--step-t` / `--step-x` / `--step-y` | 3 / 16 / 16 | Stride between candidate window starts | +| `--max-nan` | 10000 | Drop any datacube with more NaNs than this | +| `--wet-threshold` | auto | Wet-pixel threshold; auto = 0.1 mm/h (rain rate) or 7 dBZ (reflectivity) | +| `--device` | auto | Compute backend: `auto` / `cpu` (bottleneck) / `cuda` | +| `--workers` | 8 | CPU worker processes, or GPU chunk-reader threads | +| `--data-var` / `--time-var` | RR / time | Names of the Zarr data and time variables | +| `-o` / `--output` | auto | Output path; auto-named from the parameters if omitted | + +### `mlcast validate-stats` — check a parquet against the contract + +Checks a stats parquet's column schema, metadata payload, and (unless +`--no-data-checks`) its per-row value invariants, then prints the file's +parameters, its table structure, and a preview of the first 10 rows. + +```bash +mlcast validate-stats stats_2020.parquet + +# Footer only — schema + metadata, skip the per-row checks +mlcast validate-stats stats_2020.parquet --no-data-checks +``` + +### Samplers + +At training time [`SourceDataDataModule`](src/mlcast/data/source_data_datamodule.py) +applies a `Sampler` to the index **once**, at dataset init, turning the full +candidate pool into the training set via a per-row keep/discard draw — so the +dataset length is fixed and known up front, and the same set is reused every +epoch. Samplers are pluggable through a registry in +[`mlcast.sampling`](src/mlcast/sampling/samplers.py): + +| Sampler | Parameters | What it does | +|---------|------------|--------------| +| `UniformSampler` | `keep_fraction` | Keep each candidate with a fixed probability, independent of its stats | +| `ImportanceSampler` | `column`, `q_min`, `scale`, `mean_weight` | Keep each candidate with probability rising with one of its statistic columns (`mean` by default) — oversampling high-rainfall datacubes without duplication | + +The default config applies an `ImportanceSampler` to the **train** split and a +`UniformSampler(keep_fraction=0.1)` to **val/test**, so importance sampling +reshapes only training while validation and test stay representative. Add a +scheme by subclassing `Sampler`, decorating it with `@register_sampler("name")`, +and selecting it from a config via `get_sampler("name", ...)`. + ## Project Structure ``` mlcast/ ├── src/mlcast/ -│ ├── __main__.py # CLI entry point (mlcast train) +│ ├── __main__.py # CLI entry point (train / stats / validate-stats) │ ├── nowcasting_module.py # Generic Lightning module for nowcasting │ ├── losses.py # CRPS, AFCRPS, MSE loss functions │ ├── callbacks.py # Training callbacks @@ -277,8 +359,13 @@ mlcast/ │ │ └── orchestrator.py # train_from_config, config persistence │ ├── data/ │ │ ├── source_data_datamodule.py # Lightning DataModule -│ │ ├── source_data_datasets.py # Zarr-backed PyTorch datasets +│ │ ├── source_data_datasets.py # Zarr-backed PyTorch datasets (reads the stats index) │ │ └── normalization.py # Normalisation registry +│ ├── sampling/ # Data-prep: stats parquet producer + sampler registry +│ │ ├── commands/ # stats / validate-stats CLI implementations +│ │ ├── samplers.py # Sampler registry (uniform, importance) +│ │ ├── stats_spec.py # Stats parquet schema + validation (the contract) +│ │ └── units.py # Data-kind + wet-threshold detection │ └── models/ │ └── convgru.py # ConvGRU encoder-decoder ├── tests/ From 3ddd2e63854e795c2d9fe1f767c216bb1b2eef09 Mon Sep 17 00:00:00 2001 From: Gabriele Franch Date: Sat, 6 Jun 2026 21:19:30 +0200 Subject: [PATCH 09/15] style: restore import-group blank line in the example script ruff (the pinned v0.8.6 in pre-commit) flags the missing blank line between the third-party and first-party import groups. --- examples/scripts/simple_train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/scripts/simple_train.py b/examples/scripts/simple_train.py index 903024e..f6f5016 100644 --- a/examples/scripts/simple_train.py +++ b/examples/scripts/simple_train.py @@ -12,6 +12,7 @@ """ import fiddle as fdl + from mlcast.configs import convgru_experiment From b57d2606179681d30c3ace4da81f8e85aef0b273 Mon Sep 17 00:00:00 2001 From: Gabriele Franch Date: Sat, 6 Jun 2026 21:19:30 +0200 Subject: [PATCH 10/15] ci: make the linting workflow actually run pre-commit uv run does not sync the dev extra, so pre-commit was never installed (the job failed to spawn it). Run it via --extra dev, and install graphviz for the local config-diagram hook's dot dependency. --- .github/workflows/pre-commit.yml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 85fed30..8d64cdf 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -18,6 +18,11 @@ jobs: with: enable-cache: true + # The local `config-diagram` hook renders the Fiddle config graph with + # graphviz, which shells out to the `dot` binary. + - run: sudo apt-get update && sudo apt-get install -y graphviz + # Run pre-commit through uv so CI uses the same project-managed environment # as local development, including local hooks that invoke `uv run ...`. - - run: uv run pre-commit run --all-files + # pre-commit lives in the `dev` extra, so it must be synced for `uv run`. + - run: uv run --extra dev pre-commit run --all-files From ba14ea80f5b1d43427c8c5444f6682dd4bd64de5 Mon Sep 17 00:00:00 2001 From: Gabriele Franch Date: Sat, 6 Jun 2026 21:42:06 +0200 Subject: [PATCH 11/15] stats: match the validate-stats panel emoji MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use the same alignment-safe clock/calendar emoji as the validate-stats summary grid (📅 Time range, 🕒 Time step). The old ⏱️ carried a U+FE0F variation selector that rich mismeasures, misaligning the panel's right border. --- src/mlcast/sampling/commands/stats.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mlcast/sampling/commands/stats.py b/src/mlcast/sampling/commands/stats.py index f87b43f..f0fdbca 100644 --- a/src/mlcast/sampling/commands/stats.py +++ b/src/mlcast/sampling/commands/stats.py @@ -514,8 +514,8 @@ def run(args: argparse.Namespace) -> int: cfg.add_column(justify="right", style="bold cyan") cfg.add_column() cfg.add_row("Dataset", f"📦 T={size_T:,} X={size_X:,} Y={size_Y:,}") - cfg.add_row("Time range", f"🕐 {time_array[0]:%Y-%m-%d %H:%M} → {time_array[-1]:%Y-%m-%d %H:%M}") - cfg.add_row("Time step", f"⏱️ {args.time_step_minutes} min") + cfg.add_row("Time range", f"📅 {time_array[0]:%Y-%m-%d %H:%M} → {time_array[-1]:%Y-%m-%d %H:%M}") + cfg.add_row("Time step", f"🕒 {args.time_step_minutes} min") cfg.add_row("Datacube", f"🧊 {Dt} × {w} × {h} stride {step_T} × {step_X} × {step_Y}") cfg.add_row("Valid starts", f"✅ {len(valid_starts_gap):,} gap-free") cfg.add_row("Filters", f"🔍 max_nan={max_nan:,} wet > {wet_threshold:g} {units_str}") From 205fdf061f1057b85f4fb32456e88b0b1b46106d Mon Sep 17 00:00:00 2001 From: Gabriele Franch Date: Sat, 6 Jun 2026 23:54:42 +0200 Subject: [PATCH 12/15] stats: follow the spec's (time, y, x) axis order MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit mlcast stats assigned x to axis 1 and y to axis 2, but the MLCast source-data spec (radar_precipitation §4.3) mandates dimension order (time, y=height, x=width), and the training dataset crops by dimension *name*. For a spec-compliant (time, y, x) store this transposed the parquet's x/y columns: x offsets ran over the y axis and vice-versa, so the dataset cropped the wrong region and ran out of bounds (e.g. a 112-wide crop instead of 256). Bind height/step_y to axis 1 and width/step_x to axis 2 in both the CPU and CUDA backends, and label survivor offsets y (axis 1) / x (axis 2). The two independent test oracles are flipped to the same convention. Verified end to end against a real (time, 1400, 1200) zarr: regenerated parquet now has x<=944 (fits the 1200 x dim) and y<=1144 (fits the 1400 y dim), and the dataset yields full 256x256 crops including the max-x/max-y rows. Note: parquets produced by the old code have x/y swapped and must be regenerated (or have their x/y columns swapped) to be consumable. --- src/mlcast/sampling/commands/_stats_gpu.py | 13 ++--- src/mlcast/sampling/commands/stats.py | 37 +++++++------ tests/sampling/test_stats_process.py | 60 +++++++++++----------- 3 files changed, 59 insertions(+), 51 deletions(-) diff --git a/src/mlcast/sampling/commands/_stats_gpu.py b/src/mlcast/sampling/commands/_stats_gpu.py index 9392df2..23775ad 100644 --- a/src/mlcast/sampling/commands/_stats_gpu.py +++ b/src/mlcast/sampling/commands/_stats_gpu.py @@ -63,9 +63,10 @@ def process_chunk( arrays in the same column layout as the CPU path. """ start_t, end_t = time_range - Dt, w, h = deltas - step_t, step_x, step_y = steps - total_px = Dt * w * h + # Data axes follow the MLCast source-data spec (§4.3): (time, y=height, x=width). + Dt, dy, dx = deltas + step_t, step_y, step_x = steps + total_px = Dt * dy * dx off_t = (-(start_t + t_start_idx)) % step_t chunk = torch.from_numpy(chunk_np).to(device, non_blocking=True) @@ -81,7 +82,7 @@ def process_chunk( # Pass A: nan_count. cumsum keeps exact integer counts (< 2^31). ncw = _strided_window(nan_mask.to(torch.int32), deltas, off_t, steps, keep_t) - a, b, c = torch.nonzero(ncw <= max_nan, as_tuple=True) + a, b, c = torch.nonzero(ncw <= max_nan, as_tuple=True) # a, b, c = axes time, y, x nan_count = ncw[a, b, c] # Pass B/C on the zero-filled chunk. @@ -90,8 +91,8 @@ def process_chunk( wet_count = _strided_window((chunk > wet_threshold).to(torch.int32), deltas, off_t, steps, keep_t)[a, b, c] idx_t_abs = (t_rel_kept[a] + (start_t + t_start_idx)).to(torch.int32) - idx_x = (b * step_x).to(torch.int32) - idx_y = (c * step_y).to(torch.int32) + idx_y = (b * step_y).to(torch.int32) + idx_x = (c * step_x).to(torch.int32) valid_count = (total_px - nan_count).to(torch.float32) mean_vals = torch.where(valid_count > 0, sum_vals / valid_count, torch.full_like(sum_vals, float("nan"))) diff --git a/src/mlcast/sampling/commands/stats.py b/src/mlcast/sampling/commands/stats.py index f0fdbca..107659f 100644 --- a/src/mlcast/sampling/commands/stats.py +++ b/src/mlcast/sampling/commands/stats.py @@ -229,9 +229,10 @@ def _process_chunk( start_t, end_t = time_range chunk = data[start_t + t_start_idx : end_t + t_start_idx, :, :].astype(np.float32, copy=False) dim_lengths = chunk.shape - Dt, w, h = deltas - step_t, step_x, step_y = steps - total_px = Dt * w * h + # Data axes follow the MLCast source-data spec (§4.3): (time, y=height, x=width). + Dt, dy, dx = deltas + step_t, step_y, step_x = steps + total_px = Dt * dy * dx # Strided, gap-free time-window starts, absolute-index aligned. The time # window count is dim_lengths[0] - Dt + 1 (see _dim_cumsum_window); the @@ -246,26 +247,26 @@ def _process_chunk( # --- Pass A: nan_count on the strided candidate grid ------------------------- ncw_s = _strided_window(nan_mask, deltas, dim_lengths, off_t, steps, keep_t) - surv_t, surv_x, surv_y = np.where(ncw_s <= max_nan) - nan_count = ncw_s[surv_t, surv_x, surv_y].astype(np.int32) + surv_t, surv_y, surv_x = np.where(ncw_s <= max_nan) + nan_count = ncw_s[surv_t, surv_y, surv_x].astype(np.int32) del ncw_s # Map strided survivor indices back to chunk-relative / absolute coords. idx_t_rel = t_rel_kept[surv_t] - idx_x = (surv_x * step_x).astype(np.int32) idx_y = (surv_y * step_y).astype(np.int32) + idx_x = (surv_x * step_x).astype(np.int32) idx_t_abs = (idx_t_rel + (start_t + t_start_idx)).astype(np.int32) # --- Pass B: sum ------------------------------------------------------------- chunk[nan_mask] = 0.0 - sum_vals = _strided_window(chunk, deltas, dim_lengths, off_t, steps, keep_t)[surv_t, surv_x, surv_y] + sum_vals = _strided_window(chunk, deltas, dim_lengths, off_t, steps, keep_t)[surv_t, surv_y, surv_x] # --- Pass C: wet_count ------------------------------------------------------- # `chunk` is now zero where it was NaN, so `> wet_threshold` is equivalent # to (value > threshold AND not NaN). wet_mask = chunk > wet_threshold del chunk, nan_mask - wet_count = _strided_window(wet_mask, deltas, dim_lengths, off_t, steps, keep_t)[surv_t, surv_x, surv_y] + wet_count = _strided_window(wet_mask, deltas, dim_lengths, off_t, steps, keep_t)[surv_t, surv_y, surv_x] # Derived stats. valid_count = total_px - nan_count @@ -395,6 +396,12 @@ def run(args: argparse.Namespace) -> int: n_workers = args.workers time_chunk_size = 3 * Dt + # Window extents and strides in the array's (time, y, x) axis order. Per the + # MLCast source-data spec the spatial axes are y=height (axis 1) and + # x=width (axis 2), so height/step_y bind to axis 1 and width/step_x to axis 2. + deltas = (Dt, h, w) + steps = (step_T, step_Y, step_X) + try: device, device_label = _resolve_device(args.device) except ValueError as e: @@ -408,7 +415,7 @@ def run(args: argparse.Namespace) -> int: data = zg[args.data_var] ds = xr.open_zarr(args.zarr_path) time_array_full = pd.DatetimeIndex(ds[args.time_var].values) - logger.info(f"Full dataset shape: T={data.shape[0]}, X={data.shape[1]}, Y={data.shape[2]}") + logger.info(f"Full dataset shape: T={data.shape[0]}, Y={data.shape[1]}, X={data.shape[2]}") logger.info(f"Time range: {time_array_full[0]} to {time_array_full[-1]}") var_attrs = dict(ds[args.data_var].attrs) except Exception as e: @@ -447,8 +454,8 @@ def run(args: argparse.Namespace) -> int: t_start_idx = valid_indices[0] t_end_idx = valid_indices[-1] + 1 size_T = t_end_idx - t_start_idx - size_X = data.shape[1] - size_Y = data.shape[2] + size_Y = data.shape[1] # axis 1 = y (height), per the source-data spec + size_X = data.shape[2] # axis 2 = x (width) time_array = time_array_full[t_start_idx:t_end_idx] logger.info(f"Filtered dataset shape: T={size_T}, X={size_X}, Y={size_Y}") @@ -574,8 +581,8 @@ def _read_chunk(tp): chunk_np, max_nan, wet_threshold, - (Dt, w, h), - (step_T, step_X, step_Y), + deltas, + steps, valid_start_mask, dev, ) @@ -588,8 +595,8 @@ def _read_chunk(tp): data=data, max_nan=max_nan, wet_threshold=wet_threshold, - deltas=(Dt, w, h), - steps=(step_T, step_X, step_Y), + deltas=deltas, + steps=steps, valid_start_mask=valid_start_mask, ) with progress: diff --git a/tests/sampling/test_stats_process.py b/tests/sampling/test_stats_process.py index e0cffc1..4151648 100644 --- a/tests/sampling/test_stats_process.py +++ b/tests/sampling/test_stats_process.py @@ -33,39 +33,39 @@ def _reference_all_positions(time_range, t_start_idx, data, max_nan, wet_thresho start_t, end_t = time_range chunk = data[start_t + t_start_idx : end_t + t_start_idx, :, :].astype(np.float32, copy=False) dim_lengths = chunk.shape - Dt, w, h = deltas - total_px = Dt * w * h + Dt, dy, dx = deltas # axes (time, y, x) + total_px = Dt * dy * dx nan_mask = np.isnan(chunk) nan_count_win = _datacube_window_sum(nan_mask.astype(np.int16), deltas, dim_lengths) valid_mask = nan_count_win <= max_nan - idx_t_rel, idx_x, idx_y = np.where(valid_mask) - idx_t_rel = idx_t_rel.astype(np.int32) - idx_x = idx_x.astype(np.int32) - idx_y = idx_y.astype(np.int32) - idx_t_abs_rel = idx_t_rel + start_t - time_mask = np.isin(idx_t_abs_rel, valid_starts_gap) - idx_t_abs = idx_t_abs_rel + t_start_idx - stride_mask = (idx_t_abs % steps[0] == 0) & (idx_x % steps[1] == 0) & (idx_y % steps[2] == 0) + it, iy, ix = np.where(valid_mask) # axis 0, 1 (y), 2 (x) + it = it.astype(np.int32) + iy = iy.astype(np.int32) + ix = ix.astype(np.int32) + it_abs_rel = it + start_t + time_mask = np.isin(it_abs_rel, valid_starts_gap) + it_abs = it_abs_rel + t_start_idx + stride_mask = (it_abs % steps[0] == 0) & (iy % steps[1] == 0) & (ix % steps[2] == 0) keep = time_mask & stride_mask - idx_t_rel = idx_t_rel[keep] - idx_x = idx_x[keep] - idx_y = idx_y[keep] - idx_t_abs = idx_t_abs[keep] - nan_count = nan_count_win[idx_t_rel, idx_x, idx_y] + it = it[keep] + iy = iy[keep] + ix = ix[keep] + it_abs = it_abs[keep] + nan_count = nan_count_win[it, iy, ix] chunk[nan_mask] = 0.0 sum_win = _datacube_window_sum(chunk, deltas, dim_lengths) - sum_vals = sum_win[idx_t_rel, idx_x, idx_y] + sum_vals = sum_win[it, iy, ix] wet_mask_i = (chunk > wet_threshold).astype(np.int16) wet_count_win = _datacube_window_sum(wet_mask_i, deltas, dim_lengths) - wet_count = wet_count_win[idx_t_rel, idx_x, idx_y] + wet_count = wet_count_win[it, iy, ix] valid_count = total_px - nan_count with np.errstate(invalid="ignore", divide="ignore"): mean_vals = np.where(valid_count > 0, sum_vals / valid_count, np.nan).astype(np.float32) frac_wet = wet_count.astype(np.float32) / total_px return { - "t": idx_t_abs, - "x": idx_x, - "y": idx_y, + "t": it_abs, + "y": iy, # axis 1 = y (height) + "x": ix, # axis 2 = x (width) "nan_count": nan_count, "sum": sum_vals.astype(np.float32), "mean": mean_vals, @@ -77,11 +77,11 @@ def _reference_all_positions(time_range, t_start_idx, data, max_nan, wet_thresho def _brute_force(chunk, deltas, steps, max_nan, wet_threshold, start_t, t_start_idx, valid_start_mask): - Dt, w, h = deltas - step_t, step_x, step_y = steps - T, X, Y = chunk.shape - total_px = Dt * w * h - # Every valid window start: it in [0, T-Dt], ix in [0, X-w], iy in [0, Y-h] + Dt, dy, dx = deltas # axes (time, y, x) + step_t, step_y, step_x = steps + T, NY, NX = chunk.shape + total_px = Dt * dy * dx + # Every valid window start: it in [0, T-Dt], iy in [0, NY-dy], ix in [0, NX-dx] # (inclusive upper bound — the final window on each axis is included). rows = [] for it in range(T - Dt + 1): @@ -89,9 +89,9 @@ def _brute_force(chunk, deltas, steps, max_nan, wet_threshold, start_t, t_start_ continue if (it + start_t + t_start_idx) % step_t != 0: continue - for ix in range(0, X - w + 1, step_x): - for iy in range(0, Y - h + 1, step_y): - win = chunk[it : it + Dt, ix : ix + w, iy : iy + h] + for iy in range(0, NY - dy + 1, step_y): + for ix in range(0, NX - dx + 1, step_x): + win = chunk[it : it + Dt, iy : iy + dy, ix : ix + dx] filled = np.where(np.isnan(win), np.float32(0.0), win) nan_c = int(np.isnan(win).sum()) if nan_c > max_nan: @@ -100,8 +100,8 @@ def _brute_force(chunk, deltas, steps, max_nan, wet_threshold, start_t, t_start_ rows.append( ( it + start_t + t_start_idx, - ix, - iy, + ix, # x = axis 2 (width) + iy, # y = axis 1 (height) nan_c, float(filled.sum()), float(filled.sum() / valid) if valid > 0 else np.nan, From d6c054aaa0d91231a75a09c8b1602e43f8c7c880 Mon Sep 17 00:00:00 2001 From: Gabriele Franch Date: Sun, 7 Jun 2026 00:01:57 +0200 Subject: [PATCH 13/15] config: default standard_names to spec-sanctioned rainfall_flux MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The default training config used 'rainfall_rate', which is not among the standard_names allowed by the source-data spec (radar_precipitation §4.4: rainfall_flux, precipitation_flux, equivalent_reflectivity_factor, precipitation_amount, rainfall_amount). Default to 'rainfall_flux' so a spec-compliant dataset trains without a set_variables override. Align the dataset docstring examples and regenerate the config diagram. --- docs/config_diagram.svg | 2 +- src/mlcast/config/base.py | 2 +- src/mlcast/data/source_data_datasets.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/config_diagram.svg b/docs/config_diagram.svg index 9c37543..6e0811f 100644 --- a/docs/config_diagram.svg +++ b/docs/config_diagram.svg @@ -332,7 +332,7 @@ list -'rainfall_rate' +'rainfall_flux' 0 diff --git a/src/mlcast/config/base.py b/src/mlcast/config/base.py index 0442a30..d8b083c 100644 --- a/src/mlcast/config/base.py +++ b/src/mlcast/config/base.py @@ -67,7 +67,7 @@ def training_experiment() -> Experiment: SourceDataIndexedDataset, zarr_path="./data/radar.zarr", index_path="./data/sampled_datacubes.parquet", - standard_names=["rainfall_rate"], + standard_names=["rainfall_flux"], input_steps=6, forecast_steps=12, return_mask=True, diff --git a/src/mlcast/data/source_data_datasets.py b/src/mlcast/data/source_data_datasets.py index 1e83b58..986ddc3 100644 --- a/src/mlcast/data/source_data_datasets.py +++ b/src/mlcast/data/source_data_datasets.py @@ -348,7 +348,7 @@ class SourceDataIndexedDataset(SourceDataDatasetBase): parquet (the candidate pool, optionally filtered by ``sampler``) or a legacy ``.csv`` (already sampled, used as-is). standard_names : list of str - List of CF standard names of variables to load (e.g., ``["rainfall_rate"]``). + List of CF standard names of variables to load (e.g., ``["rainfall_flux"]``). input_steps : int Number of past timesteps fed to the network as input. forecast_steps : int @@ -501,7 +501,7 @@ class SourceDataRandomSamplingDataset(SourceDataDatasetBase): zarr_path : str Path to the Zarr dataset. standard_names : list of str - List of CF standard names of variables to load (e.g., ``["rainfall_rate"]``). + List of CF standard names of variables to load (e.g., ``["rainfall_flux"]``). input_steps : int Number of past timesteps fed to the network as input. forecast_steps : int From a348c8ae6a99165a9a136065caeb5d1ffca04eff Mon Sep 17 00:00:00 2001 From: Gabriele Franch Date: Sun, 7 Jun 2026 00:25:15 +0200 Subject: [PATCH 14/15] data: rebase index t onto the sliced store for time subsets SourceDataIndexedDataset slices the zarr to the split's time subset (0-based) but kept the index's t as an absolute zarr index. For any split not starting at t=0 (val/test), __getitem__ indexed the sliced store with a huge absolute t, xarray clipped it to an empty crop, and the ConvGRU encoder hit 'stack expects a non-empty TensorList'. Filter to windows whose full depth fits the subset (no cross-split leakage), then rebase t = t - subset_start so it indexes the sliced store. Verified end to end: training runs to completion on the real 2010-2025 dataset, and crops carry the rain the parquet reports. The subset test now exercises a non-zero-start slice and asserts the rebased coordinates. --- src/mlcast/data/source_data_datasets.py | 12 +++++++++--- tests/data/test_source_data_datasets.py | 9 +++++++-- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/mlcast/data/source_data_datasets.py b/src/mlcast/data/source_data_datasets.py index 986ddc3..50d0fda 100644 --- a/src/mlcast/data/source_data_datasets.py +++ b/src/mlcast/data/source_data_datasets.py @@ -423,9 +423,15 @@ def __init__( if self._time_index_slice is not None: t_start = self._time_index_slice.start t_stop = self._time_index_slice.stop - self.coords = self.coords[(self.coords["t"] >= t_start) & (self.coords["t"] < t_stop)].reset_index( - drop=True - ) + # `t` is an absolute index into the full zarr, but `self.ds` is sliced to + # this subset (its time axis is 0-based). Keep only windows whose full + # depth fits inside the subset, then rebase `t` onto the sliced axis so + # `__getitem__` indexes it correctly and splits don't leak across the + # boundary. + self.coords = self.coords[ + (self.coords["t"] >= t_start) & (self.coords["t"] + time_depth <= t_stop) + ].reset_index(drop=True) + self.coords["t"] = self.coords["t"] - t_start if sampler is not None: selected = sampler.select(self.coords, np.random.default_rng(sampling_seed)) diff --git a/tests/data/test_source_data_datasets.py b/tests/data/test_source_data_datasets.py index d3a9f3e..f4d7a3b 100644 --- a/tests/data/test_source_data_datasets.py +++ b/tests/data/test_source_data_datasets.py @@ -176,18 +176,23 @@ def test_indexed_sampling_dataset(fp_test_dataset: Path, mock_csv: str) -> None: def test_indexed_sampling_dataset_time_subset(fp_test_dataset: Path, mock_csv: str) -> None: - """Test that subset correctly filters CSV rows by time range.""" + """Subset keeps only index rows whose full window fits the subset, and + rebases their absolute ``t`` onto the sliced (0-based) store.""" zarr_ds = xr.open_zarr(str(fp_test_dataset)) time_index = zarr_ds.indexes["time"] + # mock_csv has t = [0, 5, 10]; subset [3, 21) with time_depth 3 drops t=0 + # (before the start) and keeps t=5, 10, rebased onto the slice to [2, 7]. ds = SourceDataIndexedDataset( zarr_path=str(fp_test_dataset), index_path=mock_csv, standard_names=["rainfall_flux"], input_steps=2, forecast_steps=1, - subset={"time": (str(time_index[0]), str(time_index[8]))}, + time_depth=3, + subset={"time": (str(time_index[3]), str(time_index[20]))}, ) assert len(ds) == 2 + assert sorted(ds.coords["t"].tolist()) == [2, 7] def test_indexed_sampling_dataset_forecast_steps_guard(fp_test_dataset: Path, mock_csv: str) -> None: From 8410b88c695032831026ccc04a761d7f4b4fd26f Mon Sep 17 00:00:00 2001 From: Gabriele Franch Date: Sun, 7 Jun 2026 17:11:12 +0200 Subject: [PATCH 15/15] data: collapse the target validity mask over the whole sequence MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mask a grid cell whenever it is NaN at any step of the sequence (inputs or targets), instead of per-timestep. A temporal discontinuity at a cell makes its forecast trajectory ill-defined — and the temporal-consistency loss term meaningless — so the cell should not be scored anywhere in the sequence (matching dpc-nowcasting's mask semantics). The mask is emitted as a single (1, C, H, W) tensor; the loss broadcasts it over the forecast steps, so no (forecast_steps, C, H, W) copy is materialised on the GPU. Also clarify the MaskedLoss broadcast_factor contract (the mask must broadcast into the equal-or-larger elementwise loss). Adds tests for the collapse semantics and the masked-loss broadcasting over a collapsed mask. --- src/mlcast/data/source_data_datasets.py | 21 ++++++++++++----- src/mlcast/losses.py | 3 +++ tests/data/test_source_data_datasets.py | 30 +++++++++++++++++++++++-- tests/test_losses.py | 23 +++++++++++++++++++ 4 files changed, 69 insertions(+), 8 deletions(-) diff --git a/src/mlcast/data/source_data_datasets.py b/src/mlcast/data/source_data_datasets.py index 50d0fda..f7451cd 100644 --- a/src/mlcast/data/source_data_datasets.py +++ b/src/mlcast/data/source_data_datasets.py @@ -69,14 +69,16 @@ class DatasetSample(TypedDict, total=False): Past frames fed to the network as input. target : Float[torch.Tensor, "forecast_steps channels height width"] Future frames the network should predict. - target_mask : Float[torch.Tensor, "forecast_steps channels height width"] - Per-timestep, per-channel validity mask for the target (1 = valid, - 0 = NaN in original data). Only present when ``return_mask=True``. + target_mask : Float[torch.Tensor, "1 channels height width"] + Per-cell validity mask collapsed over the whole sequence: a cell is 1 + only if it was finite at every step (inputs and targets), else 0. Its + leading axis is a single step that the loss broadcasts over the forecast + steps. Only present when ``return_mask=True``. """ input: Float[torch.Tensor, "input_steps channels height width"] target: Float[torch.Tensor, "forecast_steps channels height width"] - target_mask: Float[torch.Tensor, "forecast_steps channels height width"] + target_mask: Float[torch.Tensor, "1 channels height width"] def _detect_axes(ds: xr.Dataset, standard_name: str) -> tuple[str, str, str]: @@ -297,9 +299,16 @@ def _build_sample(self, data: np.ndarray) -> DatasetSample: Dictionary with ``'input'`` and ``'target'`` tensors, and optionally ``'target_mask'`` if ``self.return_mask`` is ``True``. """ - # Capture target mask before NaNs are filled + # Validity mask, collapsed over the whole sequence: a cell is scored only + # if it is finite at EVERY step (inputs and targets). A temporal + # discontinuity anywhere makes the forecast trajectory at that cell + # ill-defined — and the temporal-consistency loss term meaningless — so we + # mask it for the whole sequence. Kept as (1, C, H, W); the loss broadcasts + # it over the forecast steps, so no T copies are materialised on the GPU. + # Computed before NaNs are filled below. if self.return_mask: - target_mask_t = torch.from_numpy((~np.isnan(data[self.input_steps :])).astype(np.float32)) + valid = ~np.isnan(data).any(axis=0, keepdims=True) # (1, C, H, W) + target_mask_t = torch.from_numpy(valid.astype(np.float32)) # source data may be float64, but the model and the rest of the # training pipeline operate in float32. diff --git a/src/mlcast/losses.py b/src/mlcast/losses.py index 290c426..4c306dc 100644 --- a/src/mlcast/losses.py +++ b/src/mlcast/losses.py @@ -77,6 +77,9 @@ def forward(self, preds: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) elementwise_loss = self.elementwise_loss(preds, target) masked_loss = elementwise_loss * mask + # The mask is expected to broadcast into the (equal-or-larger) elementwise + # loss — e.g. a sequence-collapsed (1, C, H, W) mask over a (T, C, H, W) + # loss — so each mask element accounts for `broadcast_factor` loss elements. broadcast_factor = elementwise_loss.numel() // mask.numel() valid_pixels = mask.sum() * broadcast_factor if valid_pixels > 0: diff --git a/tests/data/test_source_data_datasets.py b/tests/data/test_source_data_datasets.py index f4d7a3b..5d2ebf7 100644 --- a/tests/data/test_source_data_datasets.py +++ b/tests/data/test_source_data_datasets.py @@ -169,7 +169,7 @@ def test_indexed_sampling_dataset(fp_test_dataset: Path, mock_csv: str) -> None: assert input_t.shape == (input_steps, 1, 16, 16) assert target_t.shape == (forecast_steps, 1, 16, 16) - assert target_mask_t.shape == (forecast_steps, 1, 16, 16) + assert target_mask_t.shape == (1, 1, 16, 16) # collapsed over the sequence assert isinstance(input_t, torch.Tensor) assert isinstance(target_t, torch.Tensor) assert isinstance(target_mask_t, torch.Tensor) @@ -235,12 +235,38 @@ def test_random_sampling_dataset(fp_test_dataset: Path) -> None: assert input_t.shape == (input_steps, 1, 32, 32) assert target_t.shape == (forecast_steps, 1, 32, 32) - assert target_mask_t.shape == (forecast_steps, 1, 32, 32) + assert target_mask_t.shape == (1, 1, 32, 32) # collapsed over the sequence assert input_t.dtype == torch.float32 assert target_t.dtype == torch.float32 assert target_mask_t.dtype == torch.float32 +def test_target_mask_collapses_over_sequence(fp_test_dataset: Path) -> None: + """A cell that is NaN at *any* step (input or target) must be masked across + every forecast step — a temporal discontinuity makes the trajectory there + ill-defined, so the cell is not scored anywhere in the sequence.""" + input_steps, forecast_steps = 3, 2 + ds = SourceDataRandomSamplingDataset( + zarr_path=str(fp_test_dataset), + standard_names=["rainfall_flux"], + input_steps=input_steps, + forecast_steps=forecast_steps, + width=8, + height=8, + return_mask=True, + ) + steps = input_steps + forecast_steps + data = np.zeros((steps, 1, 8, 8), dtype=np.float32) + data[1, 0, 2, 3] = np.nan # NaN at one input step + data[steps - 1, 0, 5, 5] = np.nan # NaN at the last forecast step + mask = ds._build_sample(data)["target_mask"].numpy() + + assert mask.shape == (1, 1, 8, 8) # collapsed over the sequence + assert mask[0, 0, 2, 3] == 0 # NaN at an input step masks the cell + assert mask[0, 0, 5, 5] == 0 # NaN at a forecast step masks the cell + assert mask[0, 0, 0, 0] == 1 # a cell valid at every step stays valid + + def test_random_sampling_dataset_time_subset(fp_test_dataset: Path) -> None: """Test that subset correctly slices the Zarr store.""" zarr_ds = xr.open_zarr(str(fp_test_dataset)) diff --git a/tests/test_losses.py b/tests/test_losses.py index 18fc222..e6b6080 100644 --- a/tests/test_losses.py +++ b/tests/test_losses.py @@ -51,3 +51,26 @@ def test_crps_loss_shapes(loss_class): # Temporal regularization usually removes the T dimension or broadcasts over it. # The current implementation returns (B, 1, 1, H, W) when temporal_lambda > 0 assert loss_temporal.shape == (B, 1, 1, H, W), f"{loss_class.__name__} temporal shape mismatch" + + +@pytest.mark.parametrize("temporal_lambda", [0.0, 0.1]) +def test_masked_loss_broadcasts_collapsed_mask(temporal_lambda): + """The dataset emits a sequence-collapsed ``(B, 1, 1, H, W)`` mask; MaskedLoss + must broadcast it over the forecast steps and normalise by the valid count — + for both the per-step loss and the time-collapsed ``temporal_lambda > 0`` one. + """ + B, T, M, H, W = 2, 4, 3, 16, 16 + torch.manual_seed(0) + preds = torch.randn(B, T, M, H, W) + target = torch.randn(B, T, 1, H, W) + criterion = build_loss(loss_class="crps", loss_params={"temporal_lambda": temporal_lambda}, masked_loss=True) + + full = criterion(preds, target, torch.ones(B, 1, 1, H, W)) + assert full > 0, "a fully-valid mask must give a non-zero loss" + # with every cell valid, the masked mean equals the plain elementwise mean + expected = criterion.elementwise_loss(preds, target).mean() + assert torch.allclose(full, expected, rtol=1e-5) + + # a genuinely empty (all-zero) mask still yields exactly zero + empty = criterion(preds, target, torch.zeros(B, 1, 1, H, W)) + assert empty == 0