From e32c279139a697edf88b6c39f0f169c033f603a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Fr=C3=B8lund?= Date: Thu, 27 Nov 2025 09:43:36 +0100 Subject: [PATCH 1/2] Updated README. Added docstrings to the data_handling module. Adjusted the CLI slightly. --- README.md | 194 ++++++++++++++++++++----------- ldcast/cli.py | 14 +-- ldcast/features/data_handling.py | 134 ++++++++++++++++++++- 3 files changed, 264 insertions(+), 78 deletions(-) diff --git a/README.md b/README.md index 5cbf2b7..a539547 100644 --- a/README.md +++ b/README.md @@ -1,106 +1,164 @@ LDCast is a precipitation nowcasting model based on a latent diffusion model (LDM, used by e.g. [Stable Diffusion](https://github.com/CompVis/stable-diffusion)). -This repository contains the code for using LDCast to make predictions and the code used to generate the analysis in the LDCast paper (a preprint is available at https://arxiv.org/abs/2304.12891). - -A GPU is recommended for both using and training LDCast, although you may be able to generate some samples with a CPU and enough patience. +This repository contains the code for using LDCast to make predictions. The code is reworked from https://github.com/MeteoSwiss/ldcast. # Installation -It is recommended you install the code in its own virtual environment (created with e.g. pyenv or conda). - -Clone the repository, then, in the main directory, run +The package uses the `uv` package manager to handle dependencies. First, install `uv` by executing ```bash -$ pip install -e . +curl -LsSf https://astral.sh/uv/install.sh | sh ``` -This should automatically install the required packages (which might take some minutes). In the paper, we used PyTorch 11.2 but are not aware of any problems with newer versions. - -If you don't want the requirements to be installed (e.g. if you installed them manually with conda), use: +Then, clone this repository and install the dependencies by executing ```bash -$ pip install --no-dependencies -e . +git clone git@github.com:dmidk/ldcast-dmi.git +uv sync --all-extras ``` # Using LDCast -## Pretrained models +The package defines a command line interface, which can be inspected by executing +```bash +uv run ldcast - --help +``` +Output: +``` +NAME + ldcast - Main CLI class. + +SYNOPSIS + ldcast - GROUP | COMMAND | VALUE -The pretrained models are available at the Zenodo repository https://doi.org/10.5281/zenodo.7780914. Unzip the file `ldcast-models.zip`. The default is to unzip it to the `models` directory, but you can also use another location. +DESCRIPTION + Main CLI class. -## Producing predictions +GROUPS + GROUP is one of the following: -The easiest way to produce predictions is to use the `ldcast.forecast.Forecast` class, which will set up all models and data transformations and is callable with a past precipitation array. -```python -from ldcast import forecast + train + Cli setup for executing training. -fc = forecast.Forecast( - ldm_weights_fn=ldm_weights_fn, autoenc_weights_fn=autoenc_weights_fn -) -R_pred = fc(R_past) -``` -Here, `ldm_weights_fn` is the path to the LDM weights and `autoenc_weights_fn` is the path to the autoencoder weights. `R_past` is a NumPy array of precipitation rates with shape `(timesteps, height, width)` where `timesteps` must be 4 and `height` and `width` must be divisible by 32. + visualize + Cli setup for visualization. -### Ensemble predictions +COMMANDS + COMMAND is one of the following: -If want to process multiple cases at once and/or generate several ensemble members, there is the `ldcast.forecast.ForecastDistributed` class. The usage is similar to the `Forecast` class, for example: -```python -from ldcast import forecast + forecast + Cli entry point for running forecast without training. -fc = forecast.ForecastDistributed( - ldm_weights_fn=ldm_weights_fn, autoenc_weights_fn=autoenc_weights_fn -) -R_pred = fc(R_past, ensemble_members=32) -``` -Here, `R_past` should be of shape `(cases, timesteps, height, width)` where `cases` is the number of cases you want to process. For each case, `ensemble_members` predictions are produced (this is the last axis of `R_pred`). `ForecastDistributed` automatically distributes the workload to multiple GPUs if you have them. + sample + Cli entry point for running sampling without training. -## Demo +VALUES + VALUE is one of the following: -For a practical example, you can run the demo in the `scripts` directory. First download the `ldcast-demo-20210622.zip` file from the [Zenodo repository](https://doi.org/10.5281/zenodo.7780914), then unzip it in the `data` directory. Then run -```bash -$ python forecast_demo.py + config ``` -A sample output can be found in the file `ldcast-demo-video-20210622.zip` in the data repository. See the function `forecast_demo` in `forecast_demo.py` see how the `Forecast` class works. To run an ensemble mean of 8 members using the `ForecastDistributed` class, you can use: + +To show the available commands of the groups, execute e.g. ```bash -$ python forecast_demo.py --ensemble-members=8 +uv run ldcast train --help +``` +Output: ``` +NAME + ldcast train - Cli setup for executing training. -The demo for a single ensemble member runs in a couple of minutes on our system using one V100 GPU; with a CPU around 10 minutes or more would be expected. A progress bar will show the status of the generation. +SYNOPSIS + ldcast train COMMAND | VALUE -# Training +DESCRIPTION + Cli setup for executing training. -## Training data +COMMANDS + COMMAND is one of the following: -The preprocessed training data, needed to rerun the LDCast training, can be found at the [Zenodo repository](https://doi.org/10.5281/zenodo.7780914). Unzip the `ldcast-datasets.zip` file to the `data` directory. + all + Execute all training pipelines. -## Training the autoencoder + autoenc + Execute the autoencoder training pipeline. -In the `scripts` directory, run -```bash -$ python train_autoenc.py --model_dir="../models/autoenc_train" -``` -to run the training of the autoencoder with the default parameters. The training checkpoints will be saved in the `../models/autoenc_train` directory (feel free to change this). + genforecast + Execute the genforecast training pipeline. -It has been reported that this training may encounter a condition where the loss goes to `nan`. If this happens, try restarting from the latest checkpoint: -```bash -$ python train_autoenc.py --model_dir="../models/autoenc_train" --ckpt_path="../models/autoenc_train/" -``` -where `` should be the latest checkpoint in the `../models/autoenc_train/` directory. +VALUES + VALUE is one of the following: -## Training the diffusion model + config -In the `scripts` directory, run -```bash -$ python train_genforecast.py --model_dir="../models/genforecast_train" + num_nodes + + save_model ``` -to run the training of the diffusion model with the default parameters, or + +As an example, to train the autoencoder one would execute ```bash -$ python train_genforecast.py --model_dir="../models/genforecast_train" --config= +uv run ldcast train autoenc --config path/to/config.yaml --save_model -num_nodes 1 ``` -to run the training with different parameters. Some config files can be found in the `config` directory. The training checkpoints will be saved in the `../models/genforecast_train` directory (again, this can be changed freely). -# Evaluation +# Configuration +The package specifies configurations in a YAML file. +An example configuration file can be found at `./example_config.yaml`. +The configuration file is parsed using the [pydantic](https://pydantic-docs.helpmanual.io/) library. +The configuration is structured into 4 main sections: `general`, `datasets`, `preprocessing`, and `models`, where the `general`, and the `datasets` sections are required for all commands. +The `preprocessing` section is optional, while for the `model` section it is required to specify at least one model configuration (i.e. `autoenc`, `genforecast`, or `forecast`). +In addition to various model parameters, the model sections specify what input datasets/weigths are needed, and what output datasets/weights are produced. -You can find scripts for evaluating models in the `scripts` directory: -* `eval_genforecast.py` to evaluate LDCast -* `eval_dgmr.py` to evaluate DGMR (requires tensorflow installation and the DGMR model from https://github.com/deepmind/deepmind-research/tree/master/nowcasting placed in the `models/dgmr` directory) -* `eval_pysteps.py` to evaluate PySTEPS (requires pysteps installation) -* `metrics.py` to produce metrics from the evaluation results produced with the functions in scripts above -* `plot_genforecast.py` to make plots from the results generated +# Package structure +The main structure of the package is outlined below: +``` +. +├── DataPreprocessing # Scripts for splitting radar data into patches +│   ├── merge_datasets.py +│   ├── radarForML.py +│   ├── radarToZarrML.py +│   └── settings/ +├── ldcast +│   ├── analysis/ # Scripts for analyzing model performance +│   ├── features # Various common features used across models +│   │   ├── radar/ +│   │   ├── data_handling.py # Main pytorch datahandling module +│   │   ├── debug.py +│   │   ├── io.py +│   │   ├── patches.py +│   │   ├── re_patch.py # Script for re-patching data into custom new patch sizes +│   │   ├── sampling.py # Sampling of dataset, e.g. EqualFrequencySampler +│   │   ├── split.py # Script for splitting data into train/val/test sets +│   │   ├── transform.py +│   │   └── utils.py +│   ├── models # Main model implementations +│   │   ├── autoenc +│   │   │   ├── autoenc.py +│   │   │   ├── encoder.py +│   │   │   ├── __init__.py # Main entry point to run autoencoder +│   │   │   └── training.py +│   │   ├── benchmarks/ +│   │   ├── blocks/ +│   │   ├── diffusion/ +│   │   ├── genforecast +│   │   │   ├── analysis.py +│   │   │   ├── __init__.py # Main entry point to run genforecast +│   │   │   ├── training.py +│   │   │   └── unet.py +│   │   ├── nowcast/ +│   │   ├── distributions.py +│   │   ├── forecast.py # Main entry point to run forecast +│   │   └── utils.py +│   ├── visualization/ +│   ├── __init__.py +│   ├── __main__.py +│   ├── cli.py # Defines the command line interface +│   └── config_parser.py # Configuration parser using pydantic +├── test/ +├── LICENSE +├── README.md +├── leonardo_config.yaml # Example configuration for running on Leonardo +├── example_config.yaml # Example configuration to start with +├── pyproject.toml +├── run_autoenc.sh # Bash script to run training of the autoencoder on Leonardo +├── run_forecast.sh # Bash script to run forecasting on Leonardo +├── run_genforecast.sh # Bash script to run training of the genforecast model on Leonardo +├── run_sampler.sh # Bash script to run data sampling on Leonardo +└── uv.lock +``` \ No newline at end of file diff --git a/ldcast/cli.py b/ldcast/cli.py index 3a4809d..f84989a 100644 --- a/ldcast/cli.py +++ b/ldcast/cli.py @@ -97,9 +97,7 @@ class TrainCLI(object): def __init__(self, config: cp.Config, save_model: bool, num_nodes: int): self.config = config self.save_model = save_model - self.sampling_grp = SamplingCLI(config) - self.autoenc_grp = AutoencCLI(config, save_model, num_nodes) - self.genforecast_grp = GenforecastCLI(config, save_model, num_nodes) + self.num_nodes = num_nodes def autoenc(self): """Execute the autoencoder training pipeline. @@ -110,17 +108,17 @@ def autoenc(self): specified configurations. """ - self.autoenc_grp.train() + AutoencCLI(self.config, self.save_model, self.num_nodes).train() def genforecast(self): """Execute the genforecast training pipeline.""" - self.genforecast_grp.train() + GenforecastCLI(self.config, self.save_model, self.num_nodes).train() def all(self): """Execute all training pipelines.""" - self.sampling_grp.run() - self.autoenc_grp.train() - self.genforecast_grp.train() + SamplingCLI(self.config).run() + AutoencCLI(self.config, self.save_model, self.num_nodes).train() + GenforecastCLI(self.config, self.save_model, self.num_nodes).train() class SamplingCLI(object): diff --git a/ldcast/features/data_handling.py b/ldcast/features/data_handling.py index cbb032b..25d0099 100644 --- a/ldcast/features/data_handling.py +++ b/ldcast/features/data_handling.py @@ -24,11 +24,28 @@ # Define the dataset and dataloader def get_sampled_data(path: Path, split: str): + """Open the sampled dataset with xarray. + + Args: + path (Path): The path to the sampled dataset. + split (str): The data split to open, i.e. 'train', 'valid', or 'test'. + + Returns: + xr.Dataset: The sampled dataset for the given split. + """ logger.debug(f"Opening sampled dataset for split '{split}") return xr.open_zarr(path, group=split) def get_patch_data(path: Path): + """Open the patch dataset with xarray. + + Args: + path (Path): The path to the patch dataset. + + Returns: + xr.Dataset: + """ logger.info("Opening patch dataset") ds_patches = xr.open_zarr(path) ds_patches = cfxr.decode_compress_to_multi_index( @@ -39,6 +56,8 @@ def get_patch_data(path: Path): class PredictDataset(Dataset): + """Dataset for prediction/forecasting.""" + def __init__( self, dataset: cp.Dataset, @@ -62,6 +81,7 @@ def __len__(self): return self.past_timesteps def __getitem__(self, index): + """Return radar data for a given timestep index loaded from .h5 files.""" leadtime: datetime = ( self.basetime - (self.past_timesteps - 1 - index) * self.interval ) @@ -79,6 +99,13 @@ def __getitem__(self, index): class EnsembleBatchSampler(torch.utils.data.BatchSampler): + """Custom batch sampler to generate batches for ensemble members. + + Supports distributed sampling across multiple ranks. + Any remaining members defined as self.ensemble_members % self.world_size + are assigned to the last rank. + """ + def __init__(self, sampler, batch_size, ensemble_members: int, drop_last=False): super().__init__(sampler, batch_size, drop_last=drop_last) self.batch_size = batch_size @@ -87,6 +114,7 @@ def __init__(self, sampler, batch_size, ensemble_members: int, drop_last=False): self.world_size = int(os.environ.get("WORLD_SIZE", 1)) def __iter__(self): + """Yield batches for the ensemble members assigned to the current rank.""" chunk_size = self.ensemble_members // self.world_size remainder = self.ensemble_members % self.world_size start = self.rank * chunk_size @@ -100,6 +128,7 @@ def __iter__(self): yield list(range(self.batch_size)) def __len__(self): + """Return the number of ensemble members for the current rank.""" len_ = self.ensemble_members // self.world_size if self.rank == self.world_size - 1: len_ += self.ensemble_members % self.world_size @@ -107,6 +136,8 @@ def __len__(self): class PredictDataModule(pl.LightningDataModule): + """Data module for prediction/forecasting.""" + def __init__( self, dataset: cp.Dataset, @@ -129,7 +160,13 @@ def __init__( self.predict_dataset = None self.ensemble_members = ensemble_members - def setup(self, stage=None): + def setup(self, stage: str = None): + """Setup the data module for prediction. + + Args: + stage (str, optional): The stage to setup the data module for. + Defaults to None. + """ self.predict_dataset = PredictDataset( self.dataset, self.basetime, @@ -139,6 +176,7 @@ def setup(self, stage=None): ) def predict_dataloader(self): + """Return the dataloader for prediction.""" return DataLoader( self.predict_dataset, batch_sampler=EnsembleBatchSampler( @@ -151,12 +189,15 @@ def predict_dataloader(self): ) def collate_fn(self, batch_): + """Stack the batch into a single tensor.""" return torch.stack( list(torch.from_numpy(batch_[i]) for i in range(self.batch_size)) ) class PatchesDataset(Dataset): + """Dataset for loading patches from patch dataset based on sampled dataset.""" + def __init__( self, ds_sampled: xr.Dataset, @@ -165,6 +206,15 @@ def __init__( zero_value=0, missing_value=0, ): + """Initialize the PatchesDataset. + + Args: + ds_sampled (xr.Dataset): The xarray dataset containing sampled patch locations. + ds_patches (xr.Dataset): The xarray dataset containing the patches. + timesteps (np.ndarray): The timesteps to load patches for. + zero_value (int, optional): The value to assign to zero patches. Defaults to 0. + missing_value (int, optional): The value to assign to missing patches. Defaults to 0. + """ super().__init__() self.ds_sampled = ds_sampled self.ds_patches = ds_patches @@ -179,9 +229,20 @@ def __init__( self.missing_array = np.ones(self.patch_shape, dtype=np.float32) * missing_value def __len__(self): + """Return the number of samples in the dataset.""" return len(self.ds_sampled.index) - def __getitem__(self, ind): + def __getitem__(self, ind: int): + """Return the patches and time array for a given index. + + The patches are aggregated according to the sample shape and timesteps. + + Args: + ind (int): The index to retrieve patches for. + + Returns: + Tuple[da.Array, np.ndarray]: The patches and time array. + """ item = self.ds_sampled.isel(index=ind) t, i, j = item.start_time.values, item.start_i.values, item.start_j.values @@ -218,7 +279,17 @@ def __getitem__(self, ind): class CombinedDataLoader(CombinedLoader): + """DataLoader to combine multiple dataloaders for target and prediction datasets.""" + def __init__(self, *loaders, augment=False, target_in_predictions: bool = False): + """Initialize the CombinedDataLoader. + + Args: + *loaders: The dataloaders to combine. + augment (bool, optional): Whether to apply data augmentation. Defaults to False. + target_in_predictions (bool, optional): Whether the target dataset should also be in the + prediction datasets. Defaults to False. + """ super().__init__(loaders) self.augment = augment self.target_in_predictions = target_in_predictions @@ -226,6 +297,17 @@ def __init__(self, *loaders, augment=False, target_in_predictions: bool = False) def batch_augmentation( self, batch: torch.Tensor, transpose: int, flipud: int, fliplr: int ): + """Apply data augmentation to a batch. + + Args: + batch (torch.Tensor): The batch to augment. + transpose (int): Whether to transpose the batch. + flipud (int): Whether to flip the batch upside down. + fliplr (int): Whether to flip the batch left to right. + + Returns: + torch.Tensor: The augmented batch. + """ if transpose: batch = torch.transpose(batch, -2, -1) flips = [] @@ -243,9 +325,12 @@ def __next__(self) -> _ITERATOR_RETURN: out = next(self._iterator) if isinstance(self._iterator, _Sequential): return out + + # Get output from all dataloaders out, batch_idx, dataloader_idx = out target, *predictions = tree_unflatten(out, self._spec)[0].values() + # Apply the same data augmentations to both target and predictions if self.augment: # Decide on augmentations - common to both target and predictions transpose = da.random.randint(2) @@ -267,6 +352,7 @@ def __next__(self) -> _ITERATOR_RETURN: fliplr=fliplr, ) + # Add target to predictions if self.target_in_predictions: predictions.insert(0, target) @@ -280,6 +366,8 @@ def __next__(self) -> _ITERATOR_RETURN: class TrainDataModule(pl.LightningDataModule): + """Data module for loading training, validation, and test data.""" + def __init__( self, batch_size, @@ -291,6 +379,18 @@ def __init__( batches_per_epoch: Optional[int] = None, num_workers: int = 0, ): + """Initialize the TrainDataModule. + + Args: + batch_size (int): The batch size for the dataloaders. + target_dataset (cp.Dataset): The target dataset. + prediction_datasets (Dict[str, cp.Dataset]): The prediction datasets. + sampled_dataset (cp.Dataset): The sampled dataset. + target_timesteps (np.ndarray): The timesteps for the target dataset. + prediction_timesteps (Dict[str, np.ndarray]): The timesteps for the prediction datasets. + batches_per_epoch (Optional[int], optional): The number of batches per epoch. Defaults to None. + num_workers (int, optional): The number of workers to use in the data loaders. Defaults to 0. + """ super().__init__() self.train_target_dataset = None self.val_target_dataset = None @@ -318,6 +418,7 @@ def __init__( @staticmethod def get_dataset(ds_patches, ds_sampled, timesteps): + """Get the PatchesDataset for the given patches and sampled datasets.""" return PatchesDataset( ds_sampled=ds_sampled, ds_patches=ds_patches, @@ -327,6 +428,12 @@ def get_dataset(ds_patches, ds_sampled, timesteps): ) def setup(self, stage=None): + """Setup the data module for training, validation, and testing. + + Args: + stage (str, optional): The stage to setup the data module for. + Defaults to None. + """ if stage == "fit" or stage is None: # Load training data ds_sampled_train = get_sampled_data( @@ -380,6 +487,17 @@ def setup(self, stage=None): raise NotImplementedError def get_dataloader(self, stage: Literal["train", "val", "test"], augment=False): + """Get a combined dataloader for the given stage for target and prediction. + + The dataloader uses DistributedSampler to support multi-device training. + + Args: + stage (Literal["train", "val", "test"]): The stage to get the dataloader for. + augment (bool, optional): Whether to apply data augmentation. Defaults to False. + + Returns: + CombinedDataLoader: The combined dataloader for the given stage. + """ # Retrieve relevant datasets target_dataset: dict = getattr(self, f"{stage}_target_dataset") prediction_datasets: dict = getattr(self, f"{stage}_prediction_datasets") @@ -394,6 +512,7 @@ def get_dataloader(self, stage: Literal["train", "val", "test"], augment=False): prediction_dataset, range(total_samples) ) + # Use DistributedSampler to be able to run on multiple devices. dataloaders = { "target": DataLoader( target_dataset, @@ -419,15 +538,26 @@ def get_dataloader(self, stage: Literal["train", "val", "test"], augment=False): ) def train_dataloader(self): + """Get the training dataloader.""" return self.get_dataloader("train", augment=True) def val_dataloader(self): + """Get the validation dataloader.""" return self.get_dataloader("val") def test_dataloader(self): + """Get the test dataloader.""" return self.get_dataloader("test") def collate_fn(self, batch_: Tuple[da.Array]): + """Stack the batch into one tensor for the patches, and one for the time. + + Args: + batch_ (Tuple[da.Array]): The batch to collate. + + Returns: + Tuple[torch.Tensor]: The collated batch. + """ patches_batch, time_batch = zip(*batch_) patches_batch: da.Array = da.stack(patches_batch, axis=0) time_batch: da.Array = da.stack(time_batch, axis=0) From 79bcde9a00d295e1b8aa89eb40526a85331be0c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Fr=C3=B8lund?= Date: Thu, 27 Nov 2025 09:44:07 +0100 Subject: [PATCH 2/2] Remove obsolete file --- ldcast/features/batch.py | 471 --------------------------------------- 1 file changed, 471 deletions(-) delete mode 100644 ldcast/features/batch.py diff --git a/ldcast/features/batch.py b/ldcast/features/batch.py deleted file mode 100644 index f9cb92d..0000000 --- a/ldcast/features/batch.py +++ /dev/null @@ -1,471 +0,0 @@ -import os -from typing import List, Optional - -import dask.array as da -import numba.typed as nb_typed -import numpy as np -import xarray as xr -from loguru import logger -from numba import njit, prange, types -from torch.utils.data import Dataset, IterableDataset - -from ldcast import GeneralConstants -from ldcast.features.sampling import EqualFrequencySampler -from ldcast.features.utils import recursive_merge_lists_of_dicts - - -class BatchGenerator: - def __init__( - self, - variables, - ds_patches, - ds_zero_patches, - predictors, - target, - primary_var, - time_range_sampling=(-np.timedelta64(10, "m"), 2 * np.timedelta64(10, "m")), - forecast_raw_vars=(), - sampling_bins=None, - sampler_file=None, - sample_shape=(2, 2), - batch_size=32, - interval=np.timedelta64(10, "m"), - augment=False, - split=None, - ): - super().__init__() - self.batch_size = batch_size - self.interval = interval - self.variables = variables - self.predictors = predictors - self.target = target - self.used_variables = predictors + [target] - self.augment = augment - self.timestep_secs_dict = { - var: self.variables[var]["timesteps"][None, :] - .astype("timedelta64[s]") - .astype(np.int32) - for var in self.variables - } - - # setup indices for retrieving source raw data - self.sources = set.union( - *(set(variables[v]["sources"]) for v in self.used_variables) - ) - self.forecast_raw_vars = set(forecast_raw_vars) & self.sources - self.patch_index = {} - - for raw_name_base in self.sources: - if raw_name_base in forecast_raw_vars: - raw_names = ( - rn for rn in ds_patches if rn.startswith(raw_name_base + "-") - ) - else: - raw_names = (raw_name_base,) - for raw_name in raw_names: - self.setup_index(raw_name, ds_patches, ds_zero_patches, sample_shape) - for raw_name in self.forecast_raw_vars: - patch_index_var = { - k: v - for (k, v) in self.patch_index.items() - if k.startswith(raw_name + "-") - } - self.patch_index[raw_name] = ForecastPatchIndexWrapper(patch_index_var) - - # setup samplers - if ( - (sampler_file is None) - or ( - os.path.split(sampler_file)[1] - not in os.listdir(GeneralConstants.DEFAULT_DATASETS_DIR) - ) - or (split not in os.listdir(sampler_file)) - ): # not os.path.isfile(sampler_file): - logger.info("No cached sampler found, creating a new one...") - primary_raw_var = variables[primary_var]["sources"][0] - t0 = t1 = None - for _, var_data in variables.items(): - timesteps = var_data["timesteps"][[0, -1]].copy() - timesteps[0] -= interval - t0 = timesteps[0] if t0 is None else min(t0, timesteps[0]) - t1 = timesteps[-1] if t1 is None else max(t1, timesteps[-1]) - - # Convert from timedelta64 to int - time_range_valid = tuple( - map(lambda x: np.timedelta64(x, "s").astype(int), (t0, t1 + interval)) - ) - time_range_sampling = tuple( - map(lambda x: np.timedelta64(x, "s").astype(int), time_range_sampling) - ) - sampler = EqualFrequencySampler( - sampling_bins, - self.patch_index[primary_raw_var], - sample_shape, - time_range_valid, - ds=ds_patches, - ds_zero_patches=ds_zero_patches, - time_range_sampling=time_range_sampling, - timestep_secs=self.interval.astype("timedelta64[s]").astype(int), - ) - self.ds_sampled = sampler.ds_sampled - - # think in a way to see if a group already exist in the file - if sampler_file is not None: - logger.info(f"Caching sampler to {sampler_file}/{split}.") - sampler.to_zarr(sampler_file, group=split) - else: - primary_raw_var = variables[primary_var]["sources"][0] - self.ds_sampled = EqualFrequencySampler.from_zarr( - sampler_file, - split, - ) - - def setup_index(self, raw_name, ds_patches, ds_zero_patches, box_size): - logger.info("Setting up PatchIndex") - zero_value = -32 # raw_data.get("zero_value", 0) - missing_value = zero_value # raw_data.get("missing_value", zero_value) - - self.patch_index[raw_name] = PatchIndex( - ds_patches=ds_patches, - ds_zero_patches=ds_zero_patches, - zero_value=zero_value, - missing_value=missing_value, - box_size=box_size, - ) - - def augmentations(self): - return tuple(np.random.randint(2, size=3)) - - def augment_batch(self, batch, transpose, flipud, fliplr): - if self.augment: - if transpose: - axes = list(range(batch.ndim)) - axes = axes[:-2] + [axes[-1], axes[-2]] - batch = batch.transpose(axes) - flips = [] - if flipud: - flips.append(-2) - if fliplr: - flips.append(-1) - if flips: - batch = np.flip(batch, axis=flips) - return batch.copy() - - def batch(self, ds_samples=None, batch_size=None): - if batch_size is None: - batch_size = self.batch_size - - if ds_samples is None: # here is where the random selection occurs - # NOTE: this is temporary, since shuffling should be handled by torch - random_indices = da.random.randint( - self.ds_sampled.index.shape[0], size=self.batch_size - ) - ds_samples = self.ds_sampled.isel(index=random_indices) - - t0, i0, j0 = ds_samples.start_time, ds_samples.start_i, ds_samples.start_j - - # if self.augment: - # augmentations = self.augmentations() - - batch = {} - for var_name in self.used_variables: - var_data = self.variables[var_name] - - timestep_secs: int = ds_samples.timestep_secs - t_shift = -(t0 % timestep_secs) - t0_shifted = t0 + t_shift - t = ( - t0_shifted.expand_dims("timestep", axis=1) - + self.timestep_secs_dict[var_name] - ) - t_relative = (t - t0) / timestep_secs - - # read raw data from index - batch_var = self.patch_index[var_data["sources"][0]](t, i0, j0) - - # transform to model variable - # batch_var = var_data["transform"](*raw_data) - - # add channel dimension if not already present - add_dims = (1,) if batch_var.ndim == 4 else () - batch_var = np.expand_dims(batch_var, add_dims) - - # data augmentation - # if self.augment: - # batch_var = self.augment_batch(batch_var, *augmentations) - - # bundle with time coordinates - batch[var_name] = ( - batch_var.astype(np.float32), - t_relative.astype(np.float32).compute().values, - ) - - pred_batch = [batch[v] for v in self.predictors] - target_batch = batch[self.target][0] # no time coordinates for target - return (pred_batch, target_batch) - - def batches(self, *args, num=None, **kwargs): - if num is not None: - for _ in range(num): - yield self.batch(*args, **kwargs) - else: - while True: - yield self.batch(*args, **kwargs) - - -class StreamBatchDataset(IterableDataset): - def __init__(self, batch_gen, batches_per_epoch): - super().__init__() - self.batch_gen = batch_gen - self.batches_per_epoch = batches_per_epoch - - def __iter__(self): - batches = self.batch_gen.batches(num=self.batches_per_epoch) - yield from batches - - -# also modify this one so it agreees with the new structure -class DeterministicBatchDataset(Dataset): - def __init__(self, batch_gen: BatchGenerator, batches_per_epoch): - super().__init__() - - self.batch_gen = batch_gen - self.batches_per_epoch = batches_per_epoch - - self.samples = [] - random_indices = da.random.permutation(self.batch_gen.ds_sampled.index.shape[0]) - for i in range(self.batches_per_epoch): - self.samples.append( - self.batch_gen.ds_sampled.isel( - index=random_indices[ - slice( - i * self.batch_gen.batch_size, - (i + 1) * self.batch_gen.batch_size, - ) - ] - ) - ) - - def __len__(self): - return self.batches_per_epoch - - def __getitem__(self, ind): - return self.batch_gen.batch(ds_samples=self.samples[ind]) - - -class PatchIndex: - IDX_ZERO = -1 - IDX_MISSING = -2 - - def __init__( - self, - ds_patches: xr.Dataset, - ds_zero_patches: xr.Dataset, - box_size=(4, 4), - zero_value=0, - missing_value=0, - ): - self.box_size = box_size - self.zero_value = zero_value - self.missing_value = missing_value - self.patch_data = ds_patches.patches - self.sample_shape = ( - ds_patches.patches.shape[1] * box_size[0], - ds_patches.patches.shape[2] * box_size[1], - ) - patch_indices = init_index( - ds_patches.patch_times_seconds.compute().values, - ds_patches.patch_coords.compute().values, - ) - zero_patch_indices = init_index( - ds_zero_patches.zero_patch_times_seconds.compute().values, - ds_zero_patches.zero_patch_coords.compute().values, - value=PatchIndex.IDX_ZERO, - ) - self.patch_index = recursive_merge_lists_of_dicts( - patch_indices, zero_patch_indices - ) - - self._batch = None - - def _alloc_batch(self, batch_size, num_timesteps): - needs_rebuild = ( - (self._batch is None) - or (self._batch.shape[0] < batch_size) - or (self._batch.shape[1] < num_timesteps) - ) - if needs_rebuild: - del self._batch - self._batch = np.zeros( - (batch_size, num_timesteps) + self.sample_shape, self.patch_data.dtype - ) - return self._batch - - def __call__(self, t, i0_all, j0_all): - t = t.compute().values - i0_all = i0_all.compute().values - j0_all = j0_all.compute().values - batch = self._alloc_batch(*t.shape) - - i1_all = i0_all + self.box_size[0] - j1_all = j0_all + self.box_size[1] - bi_size = self.patch_data.shape[1] - bj_size = self.patch_data.shape[2] - - build_batch( - batch, - self.patch_data, - self.patch_index, - t, - i0_all, - i1_all, - j0_all, - j1_all, - bi_size, - bj_size, - self.zero_value, - self.missing_value, - ) - return batch[:, : t.shape[1], ...] - - -# -@njit -def init_index( - patch_times_seconds: np.ndarray, - patch_coords: np.ndarray, - value: Optional[int] = None, -) -> List[dict]: - """Initialize the patch indices. - - Args: - patch_times_seconds (np.ndarray): The patch times in seconds - patch_coords (np.ndarray): The patch coordinates - value (Optional[int], optional): The value to assign for each index. Defaults to None. - - Returns: - List[dict]: The list of patch indices - """ - dict_type: dict = nb_typed.Dict.empty( - key_type=types.int64, - value_type=nb_typed.Dict.empty( - key_type=types.int64, - value_type=nb_typed.Dict.empty( - key_type=types.int64, value_type=types.int64 - ), - ), - ) - dict_list: List[dict] = nb_typed.List.empty_list(dict_type) - - if value is None: - value_array = np.arange(patch_times_seconds.shape[0], dtype=np.int64) - else: - value_array = np.full(patch_times_seconds.shape[0], value, dtype=np.int64) - - for k, t in enumerate(patch_times_seconds): - i = np.int64(patch_coords[k, 0]) - j = np.int64(patch_coords[k, 1]) - dict_list.append({np.int64(t): {i: {j: value_array[k]}}}) - return dict_list - - -# numba can't find these values from PatchIndex -IDX_ZERO = PatchIndex.IDX_ZERO -IDX_MISSING = PatchIndex.IDX_MISSING - - -# @njit(parallel=True) -def build_batch( - batch, - patch_data, - patch_index, - t_all, - i0_all, - i1_all, - j0_all, - j1_all, - bi_size, - bj_size, - zero_value, - missing_value, -): - for k in range(t_all.shape[0]): - i0 = i0_all[k] - i1 = i1_all[k] - j0 = j0_all[k] - j1 = j1_all[k] - - for bt, t in enumerate(t_all[k, :]): - for i in range(i0, i1): - bi0 = (i - i0) * bi_size - bi1 = bi0 + bi_size - for j in range(j0, j1): - - ind = IDX_MISSING - t_dict = patch_index.get(t) - if t_dict is not None: - i_dict = t_dict.get(i) - if i_dict is not None: - ind = int(i_dict.get(j, IDX_MISSING)) - - bj0 = (j - j0) * bj_size - bj1 = bj0 + bj_size - if ind >= 0: - batch[k, bt, bi0:bi1, bj0:bj1] = patch_data[ind].compute() - elif ind == IDX_ZERO: - batch[k, bt, bi0:bi1, bj0:bj1] = zero_value - elif ind == IDX_MISSING: - batch[k, bt, bi0:bi1, bj0:bj1] = missing_value - - -class ForecastPatchIndexWrapper(PatchIndex): - def __init__(self, patch_index): - - self.patch_index = patch_index - raw_names = {"-".join(v.split("-")[:-1]) for v in patch_index} - if len(raw_names) != 1: - raise ValueError("Can only wrap variables with the same base name") - self.raw_name = list(raw_names)[0] - lags_hour = [int(v.split("-")[-1]) for v in patch_index] - self.lags_hour = set(lags_hour) - forecast_interval_hour = np.diff(sorted(lags_hour)) - if len(set(forecast_interval_hour)) != 1: - raise ValueError("Lags must be evenly spaced") - forecast_interval_hour = forecast_interval_hour[0] - if 24 % forecast_interval_hour: - raise ValueError("24 hours must be a multiple of the forecast interval") - self.forecast_interval_hour = forecast_interval_hour - self.forecast_interval = 3600 * forecast_interval_hour - - # need to set these for _alloc_batch to work - self._batch = None - v = list(self.patch_index.keys())[0] - self.sample_shape = self.patch_index[v].sample_shape - self.patch_data = self.patch_index[v].patch_data - - def __call__(self, t, i0, j0): - batch = self._alloc_batch(*t.shape) - - # ensure that all data come from the same forecast - t0 = t[:, :1] - start_time_from_fc = t0 % self.forecast_interval - time_from_fc = start_time_from_fc + (t - t0) - lags_hour = ( - time_from_fc // self.forecast_interval - ) * self.forecast_interval_hour - - for lag in self.lags_hour: - raw_name_lag = f"{self.raw_name}-{lag}" - batch_lag = self.patch_index[raw_name_lag](t, i0, j0) - lag_mask = lags_hour == lag - copy_masked_times(batch_lag, batch, lag_mask) - - return batch[:, : t.shape[1], ...] - - -# @njit(parallel=True) -def copy_masked_times(from_batch, to_batch, mask): - for k in prange(from_batch.shape[0]): - for bt in range(from_batch.shape[1]): - if mask[k, bt]: - to_batch[k, bt, :, :] = from_batch[k, bt, :, :]