
Designing the training pipeline #9

@leifdenby


@arjj8 and I had some good discussions yesterday about how to design the training pipeline API. We want to make sure that the API is flexible enough to support different use cases, while also being easy to use and understand.

The first thing we did was to go through the steps of data transformation, from the source datasets (assuming the data structure that mlcast-dataset-validator imposes) to the final torch.Tensor that is fed into the model.

From the LDCast architecture (work by @franchg, @arjj8, @mfroelund, @martinbo-meteo, links below), we have the following steps:

  1. read source dataset
  2. tile this dataset in space and time (e.g. 12 time steps, 128x128 spatial tiles), eliminating tiles with too much missing data
  3. sample tiles to normalize the distribution of training samples (e.g. ensuring that we have a good distribution of samples across different precipitation intensities)
  4. normalize the data for training (e.g. scaling precipitation values to a certain range)
  5. batch the data into torch.Tensor batches for training
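Step 2 can be sketched with plain numpy on an in-memory array. This is a minimal illustration under stated assumptions, not the LDCast implementation: `tile_array` and its `max_nan_fraction` threshold are hypothetical, windows are assumed to be non-overlapping in both time and space, and a real version would operate lazily on a dask-backed `xr.DataArray`.

```python
import numpy as np


def tile_array(data, n_time_window, tile_size, max_nan_fraction=0.1):
    """Cut a [n_time, x, y] array into non-overlapping space-time tiles.

    Hypothetical helper. Returns (tiles, index): tiles has shape
    [tile_id, n_time_window, x_tile, y_tile] and index lists the
    (t_start, x_start, y_start) origin of each kept tile.
    """
    n_time, nx, ny = data.shape
    x_tile, y_tile = tile_size
    tiles, index = [], []
    # Step through time and space in non-overlapping strides; any
    # remainder at the edges that does not fill a whole tile is dropped.
    for t0 in range(0, n_time - n_time_window + 1, n_time_window):
        for x0 in range(0, nx - x_tile + 1, x_tile):
            for y0 in range(0, ny - y_tile + 1, y_tile):
                tile = data[t0:t0 + n_time_window, x0:x0 + x_tile, y0:y0 + y_tile]
                # Discard tiles where too large a fraction of values is NaN.
                if np.isnan(tile).mean() > max_nan_fraction:
                    continue
                tiles.append(tile)
                index.append((t0, x0, y0))
    if not tiles:
        return np.empty((0, n_time_window, x_tile, y_tile)), index
    return np.stack(tiles), index
```

A sliding time window (a stride smaller than `n_time_window`) would produce more, overlapping samples from the same data; which of the two to use is a design choice.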

We have summarised this in the following flowchart:

```mermaid
flowchart TB
    A["mlcast-datasets intake repo<br />zarr URL"]
    H["user-provided zarr dataset URL"]
    B["source dataset<br />shape: [n_time, x, y]"]
    C["Tiled xr.DataArray<br />shape: [tile_id, n_time_window, x_tile, y_tile]"]
    D["Sampled xr.DataArray<br />shape: [sampled_tile_id, n_time_sample, x_tile, y_tile]"]
    I["Normalized xr.DataArray<br />shape: [sampled_tile_id, n_time_sample, x_tile, y_tile]"]
    E["torch.Tensor<br />shape: [batch_size, n_time_sample, x_tile, y_tile]"]
    F["CSV index<br />{tiling_id}.samples.csv"]
    G["CSV index<br />{sampling_id}.samples.csv"]

    A -- "open dataset:<br />OpenMLCastDataset()" --> B
    H -- "open dataset:<br />OpenXarrayDataset()" --> B
    B -- "tiling step:<br />TilingSampler()" --> C
    C -- "write tiling index:<br />{tiling_id}.samples.csv<br />tiling_id = f#quot;{dataset_id}.{tile_size}.{n_time_window}#quot;" --> F
    C -- "sampling step:<br />BinNormSampler()" --> D
    D -- "write samples index:<br />{sampling_id}.samples.csv<br />sampling_id = f#quot;{tiling_id}.{n_time_sample}#quot;" --> G
    D -- "normalize step:<br />NormalizeForTraining()" --> I
    I -- "batching step" --> E
```
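The identifiers in the flowchart could be built, and the tiling index written, roughly as follows. This is a sketch: the function names and the CSV column names (`tile_id`, `t_start`, `x_start`, `y_start`) are assumptions, and `tile_size` is formatted as `128x128` because interpolating the tuple directly would put `(128, 128)`, with parentheses and a space, into the filename.

```python
import csv


def make_tiling_id(dataset_id, tile_size, n_time_window):
    # Follows tiling_id = f"{dataset_id}.{tile_size}.{n_time_window}",
    # but formats tile_size as "XxY" (an assumption) for clean filenames.
    x_tile, y_tile = tile_size
    return f"{dataset_id}.{x_tile}x{y_tile}.{n_time_window}"


def make_sampling_id(tiling_id, n_time_sample):
    # Follows sampling_id = f"{tiling_id}.{n_time_sample}"
    return f"{tiling_id}.{n_time_sample}"


def index_rows(tile_origins):
    """Header plus one row per kept tile: its id and (t, x, y) origin.

    Column names are hypothetical.
    """
    rows = [["tile_id", "t_start", "x_start", "y_start"]]
    for tile_id, (t0, x0, y0) in enumerate(tile_origins):
        rows.append([tile_id, t0, x0, y0])
    return rows


def write_index(tile_origins, fh):
    """Write the index to an already-open text file handle as CSV."""
    csv.writer(fh).writerows(index_rows(tile_origins))
```

For example, `make_tiling_id("dmi.precipitation.5min", (128, 128), 12)` gives `"dmi.precipitation.5min.128x128.12"`, and the index would then be written with `open(f"{tiling_id}.samples.csv", "w", newline="")` as the file handle.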

Because some of these steps might be optional (some architectures might not require tiling, for example) or might need to be done differently (e.g. with a different sampling strategy), I think it would be nice to have a design where we can define which steps to apply and in which order.

```python
# Example for LDCast architecture
pipeline = [
    MLCastCatalogDataset("dmi.precipitation.5min", var_name="rainrate"),
    TilingSampler(
        n_time_window=12,
        tile_size=(128, 128),
    ),
    BinNormSampler(
        aggregation="sum",
        n_scalar_total_bins=10,
    ),
    NormalizeForTraining(),
    ToTorchTensor()
]

# Example without sampling step
pipeline = [
    MLCastCatalogDataset("dmi.precipitation.5min", var_name="rainrate"),
    TilingSampler(
        n_time_window=12,
        tile_size=(128, 128),
    ),
    NormalizeForTraining(),
    ToTorchTensor()
]
```
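A minimal way to get this composability in plain Python is to treat each step as a callable and feed the output of one step into the next, in list order. The sketch assumes every step is a plain callable; `Scale` and `DropNegative` are hypothetical toy steps standing in for e.g. `NormalizeForTraining()` and a sampler, and they operate on lists rather than `xr.DataArray`s.

```python
from functools import reduce


class Scale:
    """Hypothetical toy step, standing in for a normalization step."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, data):
        return [x * self.factor for x in data]


class DropNegative:
    """Hypothetical toy step, standing in for a filtering/sampling step."""

    def __call__(self, data):
        return [x for x in data if x >= 0]


def run_pipeline(steps, data):
    """Apply each step to the output of the previous one, in list order."""
    return reduce(lambda acc, step: step(acc), steps, data)
```

Dropping a step is then just a matter of leaving it out of the list, as in the second example above. The dataset-opening step could fit the same pattern as a first callable that ignores its input, with the pipeline started from `data=None`.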

I had hoped to use torchdata datapipes for this, because they allow exactly this kind of composability when constructing data-processing pipelines. However, since yesterday I have found out that datapipes were deprecated from PyTorch in July 2024, and the last version to include them is torchdata==0.9.0. That version is supported by torch==2.5.0 but not torch==2.6.0, so we would have to stick to torch==2.5.x if we wanted to use datapipes. A replacement is being worked on, but it isn't clear to me what state it is in.

One question here is: how much of this computation do we want to do ahead of training time? Currently, @franchg, I think you are taking the view that we should:

  1. use the source data directly (i.e. not create a training dataset)
  2. create the tiling index and sampling indexes ahead of training
  3. use fixed scaling for normalization - can we get away with this in general?

Does this seem about right, @franchg? Also, what should we call this step? "data-preparation" seems like a good name to me; can we call it that?
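For point 3, a fixed scaling could be as simple as a log transform followed by standardization with constants chosen once, up front. The sketch below assumes such a transform: `normalize` and `denormalize` are hypothetical names, and the default `offset`, `mean` and `std` are placeholders, not values taken from LDCast. Because nothing depends on statistics of the dataset being trained on, it needs no pass over the data during data preparation and can run independently in each data-loading worker.

```python
import numpy as np


def normalize(rainrate, offset=0.1, mean=0.0, std=1.0):
    """Fixed (data-independent) scaling: log10 transform, then standardize.

    offset avoids log10(0) for dry pixels; offset, mean and std are
    placeholder constants, not tuned values.
    """
    return (np.log10(rainrate + offset) - mean) / std


def denormalize(z, offset=0.1, mean=0.0, std=1.0):
    """Invert normalize() to map model outputs back to rain rates."""
    return 10.0 ** (z * std + mean) - offset
```

The trade-off is that constants suited to one radar network's rain-rate distribution may fit another poorly, which is exactly the open question in point 3.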

Doing part of this processing, but not all, ahead of training makes the design a bit tricky here, but not impossible :)

What is the problem? At its core it comes down to mixing different implementations of parallelism:

  1. When computing for example the tile statistics over the whole dataset for the sampling step, it would be natural to use dask with xarray to compute these statistics in parallel across the whole dataset.
  2. When we want to use the pipeline during training, we want to be able to use torch.utils.data.DataLoader to load batches of data in parallel using multiple workers.

As far as I can tell, there isn't a good way to mix these two types of parallelism in a single pipeline (I think this is why torch datapipes were removed, and there is no obvious alternative yet). For example, if we had multiple workers loading data with torch.utils.data.DataLoader, we couldn't then compute the tile statistics for the sampling step in parallel across the whole dataset using dask, because each worker would only see a subset of the data.
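To make the dependency concrete, here is a sketch of bin-balanced sampling over tiles. It assumes that `aggregation="sum"` means summing precipitation over each tile and that `n_scalar_total_bins` means equal-width bins over those sums; `bin_balanced_sample` and `n_per_bin` are hypothetical, and this is not necessarily how `BinNormSampler` works. Note that the bin edges depend on the minimum and maximum over all tiles, which is exactly the whole-dataset reduction that a single `DataLoader` worker cannot compute from its own subset.

```python
import numpy as np


def bin_balanced_sample(tiles, n_bins=10, n_per_bin=1, seed=0):
    """Pick tile indices so that each intensity bin is equally represented.

    tiles: array of shape [tile_id, n_time, x_tile, y_tile].
    """
    # Per-tile statistic: total precipitation over the tile (the
    # aggregation="sum" case). NaNs are ignored.
    totals = np.nansum(tiles, axis=(1, 2, 3))
    # Equal-width bins over the range of totals across ALL tiles.
    edges = np.linspace(totals.min(), totals.max(), n_bins + 1)
    # Use interior edges only, so bin ids run from 0 to n_bins - 1.
    bin_ids = np.digitize(totals, edges[1:-1])
    rng = np.random.default_rng(seed)
    chosen = []
    for b in range(n_bins):
        members = np.flatnonzero(bin_ids == b)
        if members.size == 0:
            continue  # empty bin: nothing to draw from
        take = min(n_per_bin, members.size)
        chosen.extend(rng.choice(members, size=take, replace=False).tolist())
    return sorted(chosen)
```

With dask-backed xarray, the `totals` reduction would be the part computed in parallel ahead of training; the per-bin draw is cheap and yields the list of samples written to `{sampling_id}.samples.csv`.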

So where does this leave us? I have asked @arjj8 to refactor @franchg's tiling and sampling code as pure xarray code that can be run in parallel with dask. This code returns the resulting tile set as an xr.DataArray and, as a side effect, writes the tiling and sampling CSV files. With this, we could then:

  1. have separate calls to create the processing pipeline for data-preparation and for training (during data-preparation, either creating a tiled and sampled dataset on disk, or only writing the indexes, which are then used during training to load the relevant tiles and samples from the source dataset),
  2. still use torch datapipes and restrict ourselves to torch<2.6.0 (I'm not sure whether this would be an issue), or
  3. come up with an abstraction where we can use dask during `pre-rain
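The index-based variant of option 1 (loading the relevant tiles from the source dataset at training time) could look roughly like this. It is a sketch: `IndexedTileDataset` and `read_index` are hypothetical, and the CSV is assumed to have `tile_id,t_start,x_start,y_start` columns. The class implements `__len__` and `__getitem__`, the map-style dataset protocol that `torch.utils.data.DataLoader` works with, so each worker can fetch tiles by index without any global computation. With an `xr.DataArray` source, the slicing would more naturally use `.isel(...)`.

```python
import csv

import numpy as np


class IndexedTileDataset:
    """Map-style dataset serving tiles by index from a source array.

    source: array-like of shape [n_time, x, y], ideally lazily loaded.
    index_rows: rows of (tile_id, t_start, x_start, y_start).
    """

    def __init__(self, source, index_rows, n_time_window, tile_size):
        self.source = source
        self.rows = [tuple(int(v) for v in row) for row in index_rows]
        self.n_time_window = n_time_window
        self.x_tile, self.y_tile = tile_size

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, i):
        # Each lookup touches only one tile, so different workers can
        # serve different indices independently.
        _, t0, x0, y0 = self.rows[i]
        return np.asarray(
            self.source[t0:t0 + self.n_time_window,
                        x0:x0 + self.x_tile,
                        y0:y0 + self.y_tile]
        )


def read_index(lines):
    """Parse CSV index lines (or an open file), skipping the header row."""
    reader = csv.reader(lines)
    next(reader)  # header: tile_id,t_start,x_start,y_start
    return list(reader)
```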

All this discussion might be overkill at this stage. But I think it is worth considering how we design this, because once we want to experiment with different approaches and combine different datasets, this could all become quite complex. We want a design that is flexible enough to support different use cases while also being easy to use and understand.

Links to relevant discussions and code:

Thoughts?
