@arjj8 and I had some good discussions yesterday about how to design the training pipeline API. We want to make sure that the API is flexible enough to support different use cases, while also being easy to use and understand.
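The first thing we did was walk through the data transformation steps, from the source datasets (assuming the data structure that mlcast-dataset-validator imposes) to the final torch.Tensor that is fed into the model. From the LDCast architecture (work by @franchg, @arjj8, @mfroelund, @martinbo-meteo, links below), we have the following steps: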
- tile this dataset in space and time (e.g. 12 time steps, 128x128 spatial tiles), eliminating tiles with too much missing data
- sample tiles to normalize the distribution of training samples (e.g. ensuring that we have a good distribution of samples across different precipitation intensities)
- normalize the data for training (e.g. scaling precipitation values to a certain range)
- batch the data into torch.Tensor batches for training
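
To make the tiling step concrete, here is a minimal sketch of how a tiling index could be built. It is not the LDCast implementation: it uses plain numpy rather than xarray, and the function name `build_tiling_index`, the use of NaN to mark missing data, the non-overlapping tile layout and the `max_missing_fraction` threshold are all assumptions for illustration.

```python
import numpy as np


def build_tiling_index(data, n_time_window=12, tile_size=(128, 128),
                       max_missing_fraction=0.1):
    """Return the (t0, x0, y0) origins of the tiles that are kept.

    Assumptions (not from the original code): `data` has shape
    [n_time, x, y], missing values are NaN, and tiles do not overlap.
    """
    n_time, nx, ny = data.shape
    tx, ty = tile_size
    kept = []
    for t0 in range(0, n_time - n_time_window + 1, n_time_window):
        for x0 in range(0, nx - tx + 1, tx):
            for y0 in range(0, ny - ty + 1, ty):
                tile = data[t0:t0 + n_time_window, x0:x0 + tx, y0:y0 + ty]
                # Drop tiles where too large a fraction of values is missing.
                if np.isnan(tile).mean() <= max_missing_fraction:
                    kept.append((t0, x0, y0))
    return kept


# Toy example: 4 time steps on an 8x8 grid, with one tile fully missing.
rng = np.random.default_rng(0)
field = rng.random((4, 8, 8))
field[:2, :4, :4] = np.nan
index = build_tiling_index(field, n_time_window=2, tile_size=(4, 4),
                           max_missing_fraction=0.5)
```

The returned origins are exactly the kind of information a tiling index file would need to store, so the tiles themselves never have to be copied.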
We have summarised this in the following flowchart:
flowchart TB
A["mlcast-datasets intake repo<br />zarr URL"]
H["user-provided zarr dataset URL"]
B["source dataset<br />shape: [n_time, x, y]"]
C["Tiled xr.DataArray<br />shape: [tile_id, n_time_window, x_tile, y_tile]"]
D["Sampled xr.DataArray<br />shape: [sampled_tile_id, n_time_sample, x_tile, y_tile]"]
I["Normalized xr.DataArray<br />shape: [sampled_tile_id, n_time_sample, x_tile, y_tile]"]
E["torch.Tensor<br />shape: [batch_size, n_time_sample, x_tile, y_tile]"]
F["CSV index<br />{tiling_id}.samples.csv"]
G["CSV index<br />{sampling_id}.samples.csv"]
A -- "open dataset:<br />OpenMLCastDataset()" --> B
H -- "open dataset:<br />OpenXarrayDataset()" --> B
B -- "tiling step:<br />TilingSampler()" --> C
C -- "write tiling index:<br />{tiling_id}.samples.csv<br />tiling_id = f\"{dataset_id}.{tile_size}.{n_time_window}\"" --> F
C -- "sampling step:<br />BinNormSampler()" --> D
D -- "write samples index:<br />{sampling_id}.samples.csv<br />sampling_id = f\"{tiling_id}.{n_time_sample}\"" --> G
D -- "normalize step:<br />NormalizeForTraining()" --> I
I -- "batching step" --> E
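
The identifier scheme in the flowchart can be sketched as a few small helpers. The f-string patterns are taken from the flowchart; how `tile_size` is rendered as a string (here `128x128`), the helper names and the example value `n_time_sample=8` are assumptions.

```python
def make_tiling_id(dataset_id, tile_size, n_time_window):
    # Assumption: a (128, 128) tile size is written as "128x128".
    tile_str = "x".join(str(n) for n in tile_size)
    return f"{dataset_id}.{tile_str}.{n_time_window}"


def make_sampling_id(tiling_id, n_time_sample):
    return f"{tiling_id}.{n_time_sample}"


def index_filename(index_id):
    # Both the tiling and the sampling index use the same suffix.
    return f"{index_id}.samples.csv"


tiling_id = make_tiling_id("dmi.precipitation.5min", (128, 128), 12)
sampling_id = make_sampling_id(tiling_id, n_time_sample=8)
print(index_filename(tiling_id))
print(index_filename(sampling_id))
```

Because the sampling id embeds the tiling id, a sampling index can always be traced back to the tiling it was derived from.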
Because some of these steps might be optional (some architectures might not require tiling, for example), or we might want them done in a different way (a different sampling strategy, for example), I think it would be nice to have a design where we can define which steps to apply and in which order.

    # Example for LDCast architecture
    pipeline = [
        MLCastCatalogDataset("dmi.precipitation.5min", var_name="rainrate"),
        TilingSampler(
            n_time_window=12,
            tile_size=(128, 128),
        ),
        BinNormSampler(
            aggregation="sum",
            n_scalar_total_bins=10,
        ),
        NormalizeForTraining(),
        ToTorchTensor(),
    ]


    # Example without sampling step
    pipeline = [
        MLCastCatalogDataset("dmi.precipitation.5min", var_name="rainrate"),
        TilingSampler(
            n_time_window=12,
            tile_size=(128, 128),
        ),
        NormalizeForTraining(),
        ToTorchTensor(),
    ]

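
One simple way such a list could be executed is to treat every entry as a callable that receives the previous step's output. The sketch below uses toy stand-in steps on plain Python lists; `ListSource`, `DropNegative`, `Scale` and `run_pipeline` are hypothetical names, not the proposed API.

```python
class ListSource:
    """Stand-in for a dataset-opening step; ignores its input."""

    def __init__(self, values):
        self.values = values

    def __call__(self, _=None):
        return list(self.values)


class DropNegative:
    """Stand-in for a filtering step such as tiling or sampling."""

    def __call__(self, data):
        return [v for v in data if v >= 0]


class Scale:
    """Stand-in for a normalization step."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, data):
        return [v * self.factor for v in data]


def run_pipeline(pipeline):
    """Apply the steps in order, feeding each output to the next step."""
    data = None
    for step in pipeline:
        data = step(data)
    return data


with_filter = run_pipeline([ListSource([1, -2, 3]), DropNegative(), Scale(2)])
without_filter = run_pipeline([ListSource([1, -2, 3]), Scale(2)])
```

Dropping or reordering a step is then just a matter of editing the list, which is the flexibility the two examples above are after.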
One question here is: how much of this computation do we want to do ahead of training time? Currently, I think @franchg you are taking the view that we should:
- use the source data directly (i.e. not create a training dataset)
- create the tiling and sampling indexes ahead of training
Does this seem about right @franchg? Also, what should we call this step? It seems like "data-preparation" to me, can we call it that?
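
As a rough picture of what "use the source data directly, with precomputed indexes" could look like at training time, here is a sketch of a map-style dataset that slices tiles out of the source array on demand. The CSV column names (`t0`, `x0`, `y0`), the class name `IndexedTileDataset` and the use of an in-memory numpy array as the source are assumptions; the real index layout may differ.

```python
import csv
import io

import numpy as np


def read_tile_index(fileobj):
    """Read tile origins from a CSV index with assumed columns t0, x0, y0."""
    return [(int(r["t0"]), int(r["x0"]), int(r["y0"]))
            for r in csv.DictReader(fileobj)]


class IndexedTileDataset:
    """Map-style dataset: tiles are sliced from the source on access."""

    def __init__(self, source, origins, n_time_window, tile_size):
        self.source = source
        self.origins = origins
        self.n_time_window = n_time_window
        self.tile_size = tile_size

    def __len__(self):
        return len(self.origins)

    def __getitem__(self, i):
        t0, x0, y0 = self.origins[i]
        tx, ty = self.tile_size
        return self.source[t0:t0 + self.n_time_window, x0:x0 + tx, y0:y0 + ty]


# Toy source of shape [n_time, x, y] and a two-row index held in memory.
source = np.arange(4 * 8 * 8, dtype=float).reshape(4, 8, 8)
index_csv = io.StringIO("t0,x0,y0\n0,0,4\n2,4,0\n")
dataset = IndexedTileDataset(source, read_tile_index(index_csv),
                             n_time_window=2, tile_size=(4, 4))
```

Since the class only implements `__len__` and `__getitem__`, it should be usable as a map-style dataset with torch.utils.data.DataLoader, with conversion to torch.Tensor left to a later step.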
Doing part of this processing, but not all, ahead of training makes the design a bit tricky here, but not impossible :)
What is the problem? At its core it comes down to mixing different implementations of parallelism:
- When computing, for example, the tile statistics for the sampling step, it would be natural to use dask with xarray to compute them in parallel across the whole dataset.
- When using the pipeline during training, we want to use torch.utils.data.DataLoader to load batches of data in parallel using multiple workers.
As far as I can tell, there isn't a good way to mix these two types of parallelism in a single pipeline (I think this is why torch datapipes were removed, and there is no obvious alternative yet). If, for example, we had multiple workers doing data-loading with torch.utils.data.DataLoader, we couldn't then produce the tile statistics for the sampling step using dask in parallel across the whole dataset, because each worker would only see a subset of the data.
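
The "each worker only sees a subset" problem can be illustrated without torch or dask. The sketch below bins some made-up per-tile precipitation totals and compares the global bin counts with the counts each of two emulated workers would see. The round-robin split is a simplification of how a DataLoader distributes work, and the values and bin edges are invented for illustration.

```python
from collections import Counter


def bin_of(value, edges):
    """Index of the bin that `value` falls into, given increasing edges."""
    return sum(value >= e for e in edges)


# Hypothetical per-tile totals and bin edges.
tile_totals = [0.0, 0.1, 0.2, 0.3, 5.0, 6.0, 40.0, 0.0]
edges = [1.0, 10.0]

# Emulate two workers that each handle every second tile.
n_workers = 2
shards = [tile_totals[w::n_workers] for w in range(n_workers)]

global_counts = Counter(bin_of(v, edges) for v in tile_totals)
shard_counts = [Counter(bin_of(v, edges) for v in shard) for shard in shards]

print("global:", dict(global_counts))
for w, counts in enumerate(shard_counts):
    print(f"worker {w}:", dict(counts))
```

In this toy case the second worker never sees the heaviest bin at all, so any sampling weights it derived from its own shard would differ from weights computed over the whole dataset.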
So where does this leave us? I have asked @arjj8 to refactor @franchg's tiling and sampling code as pure xarray code that can be run in parallel with dask, which returns the resulting tile-set as an xr.DataArray and has the side effect of creating the tiling and sampling CSV files. With this we can then either 1) have separate calls to create the processing pipeline during data-preparation (i.e. either creating a tiled+sampled dataset on disk, or using the indexes during training to load the relevant tiles and samples from the source dataset) and during training, 2) still use torch datapipes and restrict ourselves to torch<2.6.0 (I'm not sure whether this would be an issue), or 3) come up with an abstraction where we can use dask during `pre-rain
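
To illustrate the "return the result, write the CSV index as a side effect" shape of such a function, here is a sketch of a bin-balanced sampler. It is only one possible reading of what a sampler like BinNormSampler might do, written with numpy instead of xarray and dask; the function name `sample_balanced`, the equal-width bins, the `n_per_bin` parameter and the CSV columns are assumptions.

```python
import csv
import io

import numpy as np


def sample_balanced(tile_totals, n_bins, n_per_bin, index_file, seed=0):
    """Pick up to `n_per_bin` tiles from each bin and write a CSV index.

    Returns the sorted ids of the chosen tiles; the index rows are
    written to `index_file` as a side effect.
    """
    totals = np.asarray(tile_totals, dtype=float)
    edges = np.linspace(totals.min(), totals.max(), n_bins + 1)
    # Using only the interior edges gives bin ids in 0..n_bins-1.
    bin_ids = np.digitize(totals, edges[1:-1])

    rng = np.random.default_rng(seed)
    chosen = []
    for b in range(n_bins):
        members = np.flatnonzero(bin_ids == b)
        if members.size == 0:
            continue
        take = min(n_per_bin, members.size)
        chosen.extend(rng.choice(members, size=take, replace=False).tolist())
    chosen.sort()

    writer = csv.writer(index_file)
    writer.writerow(["tile_id", "bin", "total"])
    for i in chosen:
        writer.writerow([i, int(bin_ids[i]), float(totals[i])])
    return chosen


# Toy example: six dry tiles and two wet ones, balanced to two per bin.
buffer = io.StringIO()
chosen = sample_balanced([0, 0, 0, 0, 0, 0, 5, 10], n_bins=2, n_per_bin=2,
                         index_file=buffer)
```

Run once during data-preparation, a function of this shape has access to the statistics of the whole dataset, and training only needs to read the resulting index.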
This discussion might all be overkill at this stage. But I think it is worth considering how we design this: once we start experimenting with different approaches and combining different datasets, this could become quite complex, and we want a design that is flexible enough to support different use cases while still being easy to use and understand.
I had hoped to use torchdata datapipes to compose these pipeline steps, because datapipes allow exactly this kind of composability when constructing data processing pipelines. However, since yesterday I have found out that datapipes were deprecated by PyTorch in July 2024 and that the last version to include them is torchdata==0.9.0 (which is supported by torch==2.5.0, but not torch==2.6.0, so we'd have to stick to torch==2.5.x if we wanted to use datapipes). Although a replacement is being worked on, it isn't clear to me where things are at.
Links to relevant discussions and code:
Thoughts?