From 19c26ba753973e9346e3f8d4264505f1d8aa3db0 Mon Sep 17 00:00:00 2001 From: AdrienAudren Date: Tue, 8 Jul 2025 18:10:50 +0200 Subject: [PATCH 1/6] discontinuous forcings --- py4cast/datasets/base.py | 1 + py4cast/datasets/titan/__init__.py | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/py4cast/datasets/base.py b/py4cast/datasets/base.py index 241b007f..5edb115e 100644 --- a/py4cast/datasets/base.py +++ b/py4cast/datasets/base.py @@ -418,6 +418,7 @@ def is_valid(self) -> bool: param=param, timestamps=self.timestamps, file_format=self.settings.file_format, + num_input_steps=self.settings.num_input_steps, ): return False return True diff --git a/py4cast/datasets/titan/__init__.py b/py4cast/datasets/titan/__init__.py index 4a5bccb8..06da5c7c 100644 --- a/py4cast/datasets/titan/__init__.py +++ b/py4cast/datasets/titan/__init__.py @@ -157,8 +157,14 @@ def exists( param: WeatherParam, timestamps: Timestamps, file_format: Literal["npy", "grib"] = "grib", + num_input_steps: int = 1, ) -> bool: - for date in timestamps.validity_times: + if param.kind == "input": + # inputs/forcings only required after num_input_steps + valid_times = timestamps.validity_times[num_input_steps:] + else: + valid_times = timestamps.validity_times + for date in valid_times: filepath = self.get_filepath(ds_name, param, date, file_format) if not filepath.exists(): return False From a3b0acd7e9dcbe67c289a2380f8991f758f34822 Mon Sep 17 00:00:00 2001 From: AdrienAudren Date: Fri, 5 Sep 2025 17:27:22 +0200 Subject: [PATCH 2/6] Handles training with members, loss crps --- config/CLI/dataset/titan.yaml | 21 ++++++----- config/CLI/model/unetrpp.yaml | 7 ++-- py4cast/cli.py | 8 ++++ py4cast/datasets/__init__.py | 6 ++- py4cast/datasets/access.py | 2 + py4cast/datasets/base.py | 29 ++++++++++++++- py4cast/datasets/titan/titan_cli.py | 4 ++ py4cast/lightning.py | 36 +++++++++++++++--- py4cast/losses.py | 58 ++++++++++++++++++++++++++++- 9 files changed, 149 insertions(+), 22 
deletions(-) diff --git a/config/CLI/dataset/titan.yaml b/config/CLI/dataset/titan.yaml index 4f984662..331caec9 100644 --- a/config/CLI/dataset/titan.yaml +++ b/config/CLI/dataset/titan.yaml @@ -1,29 +1,32 @@ data: #args forwarded (linked) to model - dataset_name: titan_aro_arp + dataset_name: titan_full num_input_steps: 1 num_pred_steps_train: 1 num_pred_steps_val_test: 1 - batch_size: 2 + batch_size: 4 # per device If afcrps loss used, has to be set as effective_batch_size*noise_members + noise_strategy: "CondLayerNorm" # "forcing" or "CondLayerNorm" or "None" + noise_members: 4 # total number of members + #other args - num_workers: 10 + num_workers: 2 prefetch_factor: null pin_memory: False dataset_conf: periods: train: - start: 20200101 - end: 20221231 + start: 2021010100 + end: 2021010100 obs_step: 3600 valid: - start: 20230101 - end: 20231231 + start: 2021010100 + end: 2021010100 obs_step: 3600 obs_step_btw_t0: 10800 test: - start: 20240101 - end: 20240831 + start: 2023050122 + end: 2023080122 obs_step: 3600 obs_step_btw_t0: 10800 grid: diff --git a/config/CLI/model/unetrpp.yaml b/config/CLI/model/unetrpp.yaml index 715de378..356e8f49 100644 --- a/config/CLI/model/unetrpp.yaml +++ b/config/CLI/model/unetrpp.yaml @@ -1,6 +1,6 @@ model: - model_name: UNETRPP - loss_name: mse # mse or mae + model_name: UNetRPP + loss_name: afcrps # mse or mae or afcrps num_inter_steps: 1 # Number of intermediary steps (without any data) num_samples_to_plot: 1 training_strategy: scaled_ar # diff_ar or scaled_ar or downscaling_only @@ -27,4 +27,5 @@ model: decoder_proj_size: 64 encoder_proj_sizes: [64, 64, 64, 32] add_skip_connections: true - attention_code: "torch" \ No newline at end of file + attention_code: "torch" + CondLayerNorm: True # use ConditionalLayerNorm instead of LayerNorm \ No newline at end of file diff --git a/py4cast/cli.py b/py4cast/cli.py index d542912a..b3e0d232 100644 --- a/py4cast/cli.py +++ b/py4cast/cli.py @@ -44,6 +44,14 @@ def 
add_arguments_to_parser(self, parser): "data.dataset_conf", "model.dataset_conf", ) + parser.link_arguments( + "data.noise_members", + "model.noise_members", + ) + parser.link_arguments( + "data.noise_strategy", + "model.noise_strategy", + ) parser.link_arguments( "data.train_dataset_info", "model.dataset_info", diff --git a/py4cast/datasets/__init__.py b/py4cast/datasets/__init__.py index 209a7a7c..385d1673 100644 --- a/py4cast/datasets/__init__.py +++ b/py4cast/datasets/__init__.py @@ -1,6 +1,6 @@ import traceback import warnings -from typing import Dict, Tuple +from typing import Dict, Tuple, Literal from .base import DatasetABC # noqa: F401 @@ -47,6 +47,8 @@ def get_datasets( num_input_steps: int, num_pred_steps_train: int, num_pred_steps_val_test: int, + noise_members: int, + noise_strategy: Literal["forcing", "CondLayerNorm", "None"], dataset_conf: Dict | None = None, ) -> Tuple[DatasetABC, DatasetABC, DatasetABC]: """ @@ -76,4 +78,6 @@ def get_datasets( num_input_steps, num_pred_steps_train, num_pred_steps_val_test, + noise_members, + noise_strategy, ) diff --git a/py4cast/datasets/access.py b/py4cast/datasets/access.py index 06aa0022..ad155cdf 100644 --- a/py4cast/datasets/access.py +++ b/py4cast/datasets/access.py @@ -398,6 +398,8 @@ class SamplePreprocSettings: standardize: bool = True file_format: Literal["npy", "grib"] = "grib" members: Optional[Tuple[int]] = None + noise_members: int = 0 + noise_strategy: Literal["forcing", "CondLayerNorm", "None"] = "forcing", add_landsea_mask: bool = False diff --git a/py4cast/datasets/base.py b/py4cast/datasets/base.py index 5edb115e..52b1559c 100644 --- a/py4cast/datasets/base.py +++ b/py4cast/datasets/base.py @@ -500,6 +500,15 @@ def load(self, no_standardize: bool = False) -> Item: timedeltas=self.output_timestamps.timedeltas, grid=self.grid, ) + # Additional noise channel as forcing + if self.settings.noise_strategy == "forcing": + external_forcings.append( + NamedTensor( + feature_names=["noise"], + 
tensor=torch.randn_like(external_forcings[-1].tensor), + names=["timestep", "lat", "lon", "features"], + ) + ) for forcing in external_forcings: forcing.unsqueeze_and_expand_from_(loutputs[0]) @@ -707,7 +716,12 @@ def sample_list(self) -> List[Sample]: member, ) if sample.is_valid(): - samples.append(sample) + # replicate samples to match the number of noise members + if self.settings.noise_members > 0: + for k in range(self.settings.noise_members): + samples.append(sample) + else: + samples.append(sample) else: invalid_samples += 1 print( @@ -722,6 +736,7 @@ def torch_dataloader( shuffle: bool = False, prefetch_factor: Union[int, None] = None, pin_memory: bool = False, + drop_last: bool = False, ) -> DataLoader: """ Builds a torch dataloader from self. @@ -734,6 +749,7 @@ def torch_dataloader( prefetch_factor=prefetch_factor, collate_fn=collate_fn, pin_memory=pin_memory, + drop_last=drop_last, ) @cached_property @@ -743,6 +759,7 @@ def input_dim(self) -> int: """ res = 4 # For date res += 1 # For solar forcing + res += 1 if self.settings.noise_strategy == "forcing" else 0 # additional noise channel for param in self.params: if param.kind == "input": @@ -863,6 +880,8 @@ def from_dict( num_input_steps: int, num_pred_steps_train: int, num_pred_steps_val_test: int, + noise_members: int, + noise_strategy: Literal["forcing", "CondLayerNorm", "None"], ) -> Tuple[Type["DatasetABC"], Type["DatasetABC"], Type["DatasetABC"]]: grid = Grid(load_grid_info_func=accessor_kls.load_grid_info, **conf["grid"]) @@ -878,6 +897,8 @@ def from_dict( num_input_steps=num_input_steps, num_pred_steps=num_pred_steps_train, members=members, + noise_members=noise_members, + noise_strategy=noise_strategy, **conf["settings"], ) train_period = Period(**conf["periods"]["train"], name="train") @@ -890,6 +911,8 @@ def from_dict( num_input_steps=num_input_steps, num_pred_steps=num_pred_steps_val_test, members=members, + noise_members=noise_members, + noise_strategy=noise_strategy, **conf["settings"], ) 
valid_period = Period(**conf["periods"]["valid"], name="valid") @@ -912,6 +935,8 @@ def from_json( num_input_steps: int, num_pred_steps_train: int, num_pred_steps_val_tests: int, + noise_members: int, + noise_strategy: Literal["forcing", "CondLayerNorm", "None"], predict_conf: Union[Dict, None] = None, ) -> Tuple[Type["DatasetABC"], Type["DatasetABC"], Type["DatasetABC"]]: """ @@ -932,4 +957,6 @@ def from_json( num_input_steps, num_pred_steps_train, num_pred_steps_val_tests, + noise_members, + noise_strategy, ) diff --git a/py4cast/datasets/titan/titan_cli.py b/py4cast/datasets/titan/titan_cli.py index 82a10313..ab191128 100644 --- a/py4cast/datasets/titan/titan_cli.py +++ b/py4cast/datasets/titan/titan_cli.py @@ -50,6 +50,8 @@ def prepare( num_input_steps: int = 1, num_pred_steps_train: int = 1, num_pred_steps_val_test: int = 1, + noise_members: int = 0, + noise_strategy: str = "None", convert_grib2npy: bool = False, compute_stats: bool = True, ): @@ -72,6 +74,8 @@ def prepare( num_input_steps=num_input_steps, num_pred_steps_train=num_pred_steps_train, num_pred_steps_val_test=num_pred_steps_val_test, + noise_members=noise_members, + noise_strategy=noise_strategy, ) train_ds.cache_dir.mkdir(exist_ok=True) data_dir = train_ds.cache_dir / "data" diff --git a/py4cast/lightning.py b/py4cast/lightning.py index 62d2fe32..638d5874 100644 --- a/py4cast/lightning.py +++ b/py4cast/lightning.py @@ -66,6 +66,8 @@ def __init__( prefetch_factor: int | None = None, pin_memory: bool = False, dataset_conf: Dict | None = None, + noise_members: int = 0, + noise_strategy: Literal["forcing", "CondLayerNorm", "None"] = "forcing", ): super().__init__() self.num_input_steps = num_input_steps @@ -79,6 +81,8 @@ def __init__( self.num_workers = num_workers self.prefetch_factor = prefetch_factor self.pin_memory = pin_memory + self.noise_members = noise_members + self.noise_strategy = noise_strategy # Get dataset in initialisation to have access to this attribute before method trainer.fit 
self.train_ds, self.val_ds, self.test_ds = get_datasets( @@ -86,6 +90,8 @@ def __init__( num_input_steps, num_pred_steps_train, num_pred_steps_val_test, + noise_members, + noise_strategy, dataset_conf, ) @@ -108,6 +114,7 @@ def train_dataloader(self): shuffle=True, prefetch_factor=self.prefetch_factor, pin_memory=self.pin_memory, + drop_last=True, ) def val_dataloader(self): @@ -117,6 +124,7 @@ def val_dataloader(self): shuffle=False, prefetch_factor=self.prefetch_factor, pin_memory=self.pin_memory, + drop_last=True, ) def test_dataloader(self): @@ -126,6 +134,7 @@ def test_dataloader(self): shuffle=False, prefetch_factor=self.prefetch_factor, pin_memory=self.pin_memory, + drop_last=True, ) def predict_dataloader(self): @@ -161,9 +170,11 @@ def __init__( num_pred_steps_train: int = 1, num_pred_steps_val_test: int = 1, batch_size: int = 2, + noise_members: int = 0, + noise_strategy: Literal["forcing", "CondLayerNorm", "None"] = "forcing", # non-linked args model_name: Literal[tuple(model_registry.keys())] = "HalfUNet", - loss_name: Literal["mse", "mae"] = "mse", + loss_name: Literal["mse", "mae", "afcrps"] = "mse", num_inter_steps: int = 1, num_samples_to_plot: int = 1, training_strategy: Literal[ @@ -187,6 +198,8 @@ def __init__( self.dataset_conf = dataset_conf self.dataset_info = dataset_info self.batch_size = batch_size + self.noise_members = noise_members + self.noise_strategy = noise_strategy self.model_name = model_name self.num_input_steps = num_input_steps self.num_pred_steps_train = num_pred_steps_train @@ -303,6 +316,8 @@ def __init__( self.loss = WeightedLoss("MSELoss", reduction="none") elif loss_name == "mae": self.loss = WeightedLoss("L1Loss", reduction="none") + elif loss_name == "afcrps": + self.loss = WeightedLoss("AFCRPS", reduction="none") else: raise TypeError(f"Unknown loss function: {loss_name}") self.loss.prepare(self, statics.interior_mask, dataset_info) @@ -562,6 +577,9 @@ def _common_step( # Should be greater or equal to 1 (otherwise 
nothing is done). for k in range(num_inter_steps): x = self._next_x(batch, prev_states, i) + if self.noise_strategy == "CondLayerNorm": + # generate (32,) noise vector for stochastic conditional layer normalization + epsilon = torch.randn(self.batch_size, 32, device=prev_states.device) # Graph (B, N_grid, d_f) or Conv (B, N_lat,N_lon d_f) if self.channels_last: x = x.to(memory_format=torch.channels_last) @@ -570,10 +588,16 @@ def _common_step( # Here we adapt our tensors to the order of dimensions of CNNs and ViTs if self.model.features_second: x = features_last_to_second(x) - y = self.model(x) + if self.noise_strategy == "CondLayerNorm": + y = self.model(x, cond_z=epsilon) + else: + y = self.model(x) y = features_second_to_last(y) else: - y = self.model(x) + if self.noise_strategy == "CondLayerNorm": + y = self.model(x, cond_z=epsilon) + else: + y = self.model(x) ds = self.training_strategy == "downscaling_only" @@ -783,7 +807,7 @@ def training_step(self, batch: ItemBatch, batch_idx: int) -> torch.Tensor: mask = self.get_mask_on_nan(target) # Compute loss: mean over unrolled times and batch - batch_loss = torch.mean(self.loss(prediction, target, mask=mask)) + batch_loss = torch.mean(self.loss(prediction, target, mask=mask, noise_members=self.noise_members)) self.training_step_losses.append(batch_loss) @@ -858,7 +882,7 @@ def validation_step(self, batch: ItemBatch, batch_idx: int): mask = self.get_mask_on_nan(target) - time_step_loss = torch.mean(self.loss(prediction, target, mask), dim=0) + time_step_loss = torch.mean(self.loss(prediction, target, mask, noise_members=self.noise_members), dim=0) mean_loss = torch.mean(time_step_loss) if self.logging_enabled: @@ -988,7 +1012,7 @@ def test_step(self, batch: ItemBatch, batch_idx: int): mask = self.get_mask_on_nan(target) - time_step_loss = torch.mean(self.loss(prediction, target, mask), dim=0) + time_step_loss = torch.mean(self.loss(prediction, target, mask, noise_members=self.noise_members), dim=0) mean_loss = 
torch.mean(time_step_loss) if self.logging_enabled: diff --git a/py4cast/losses.py b/py4cast/losses.py index f5e8ce99..a43d3b5b 100644 --- a/py4cast/losses.py +++ b/py4cast/losses.py @@ -13,6 +13,49 @@ from py4cast.datasets.base import DatasetInfo, NamedTensor +class AlmostFairCRPS(torch.nn.Module): # Batched Almost Fair CRPS + """ + AIFS-CRPS: Ensemble forecasting using a model trained with a loss function based on the Continuous Ranked Probability Score + https://arxiv.org/abs/2412.15832 + """ + def __init__(self, alpha=0.95): + super().__init__() + self.alpha = alpha + + def forward(self, preds, targets, noise_members=2): + + B = targets.shape[0] // noise_members # effective batch size + M = noise_members # Number of ensemble members + epsilon = (1 - self.alpha) / M + + # Reshape predictions and targets + preds = preds.view(B, M, *preds.shape[1:]) # (B, M, T, W, H, d_f) + targets = targets.view(B, M, *targets.shape[1:]) # (B, M, T, W, H, d_f) + + abs_diff = torch.abs(preds - targets) # (B, M, T, W, H, d_f) + + # Pairwise differences: |x_j - x_k| + preds_j = preds.unsqueeze(2) # (B, M, 1, T, W, H, d_f) + preds_k = preds.unsqueeze(1) # (B, 1, M, T, W, H, d_f) + pairwise_diff = torch.abs(preds_j - preds_k) # (B, M, M, T, W, H, d_f) + + # |x_j - y| + |x_k - y| + abs_j = abs_diff.unsqueeze(2) # (B, M, 1, T, W, H, d_f) + abs_k = abs_diff.unsqueeze(1) # (B, 1, M, T, W, H, d_f) + pairwise_abs_sum = abs_j + abs_k # (B, M, M, T, W, H, d_f) + + # Create off-diagonal mask (M, M) + mask = ~torch.eye(M, dtype=torch.bool, device=preds.device) # (M, M) + + # Apply mask per batch element + pairwise_diff = pairwise_diff[:, mask].view(B, M, M - 1, *preds.shape[2:]) + pairwise_abs_sum = pairwise_abs_sum[:, mask].view(B, M, M - 1, *preds.shape[2:]) + + # Final computation + afcrps = (pairwise_abs_sum - (1 - epsilon) * pairwise_diff).mean(dim=(1, 2)) / 2 # (B, T, W, H, d_f) + return afcrps + + class Py4CastLoss(ABC): """ Abstract class to force the user to implement the prepare and 
forward method because @@ -21,7 +64,12 @@ class Py4CastLoss(ABC): """ def __init__(self, loss: str, *args, **kwargs) -> None: - self.loss = getattr(torch.nn, loss)(*args, **kwargs) + if loss in ["MSELoss", "L1Loss"]: + self.loss = getattr(torch.nn, loss)(*args, **kwargs) + elif loss == "AFCRPS": + self.loss = AlmostFairCRPS() + else: + raise ValueError("Unrecognized loss function") @abstractmethod def prepare( @@ -109,6 +157,7 @@ def forward( prediction: NamedTensor, target: NamedTensor, mask: torch.Tensor, + noise_members: int = 2, reduce_spatial_dim=True, ) -> torch.Tensor: """ @@ -117,7 +166,12 @@ def forward( returns (B, pred_steps) """ # Compute Torch loss (defined in the parent class when this Mixin is used) - torch_loss = self.loss(prediction.tensor * mask, target.tensor * mask) + if noise_members > 1: + torch_loss = self.loss(prediction.tensor * mask, target.tensor * mask, noise_members=noise_members) + else: + if self.loss.__class__ == AlmostFairCRPS: + raise ValueError("When using AFCRPS loss, noise_members must be > 1") + torch_loss = self.loss(prediction.tensor * mask, target.tensor * mask) # Retrieve the weights for each feature weights = self.weights(tuple(prediction.feature_names), prediction.device) From f274c8f0f634e01689c6258eb27a12bcdd280c2c Mon Sep 17 00:00:00 2001 From: AdrienAudren Date: Tue, 9 Sep 2025 14:23:04 +0200 Subject: [PATCH 3/6] fix members handling --- py4cast/datasets/base.py | 16 +------------- py4cast/lightning.py | 46 ++++++++++++++++++++++++++++++++-------- py4cast/losses.py | 18 +++++++--------- 3 files changed, 46 insertions(+), 34 deletions(-) diff --git a/py4cast/datasets/base.py b/py4cast/datasets/base.py index 52b1559c..043943d3 100644 --- a/py4cast/datasets/base.py +++ b/py4cast/datasets/base.py @@ -500,15 +500,6 @@ def load(self, no_standardize: bool = False) -> Item: timedeltas=self.output_timestamps.timedeltas, grid=self.grid, ) - # Additional noise channel as forcing - if self.settings.noise_strategy == "forcing": - 
external_forcings.append( - NamedTensor( - feature_names=["noise"], - tensor=torch.randn_like(external_forcings[-1].tensor), - names=["timestep", "lat", "lon", "features"], - ) - ) for forcing in external_forcings: forcing.unsqueeze_and_expand_from_(loutputs[0]) @@ -716,12 +707,7 @@ def sample_list(self) -> List[Sample]: member, ) if sample.is_valid(): - # replicate samples to match the number of noise members - if self.settings.noise_members > 0: - for k in range(self.settings.noise_members): - samples.append(sample) - else: - samples.append(sample) + samples.append(sample) else: invalid_samples += 1 print( diff --git a/py4cast/lightning.py b/py4cast/lightning.py index 638d5874..17d9b597 100644 --- a/py4cast/lightning.py +++ b/py4cast/lightning.py @@ -577,9 +577,16 @@ def _common_step( # Should be greater or equal to 1 (otherwise nothing is done). for k in range(num_inter_steps): x = self._next_x(batch, prev_states, i) + if self.noise_strategy == "CondLayerNorm": # generate (32,) noise vector for stochastic conditional layer normalization - epsilon = torch.randn(self.batch_size, 32, device=prev_states.device) + # Skillful joint probabilistic weather forecasting from marginals: https://arxiv.org/pdf/2506.10772 + epsilon = torch.randn(self.batch_size*self.noise_members, 32, device=prev_states.device) + + elif self.noise_strategy == "forcing": + # Add noise channel as forcing with noise=True + x = self._next_x(batch, prev_states, i, noise=True) + # Graph (B, N_grid, d_f) or Conv (B, N_lat,N_lon d_f) if self.channels_last: x = x.to(memory_format=torch.channels_last) @@ -589,12 +596,14 @@ def _common_step( if self.model.features_second: x = features_last_to_second(x) if self.noise_strategy == "CondLayerNorm": + x = x.unsqueeze(1).expand(-1, self.noise_members, -1, -1, -1).reshape(self.batch_size*self.noise_members, *x.shape[1:]) y = self.model(x, cond_z=epsilon) else: y = self.model(x) y = features_second_to_last(y) else: if self.noise_strategy == "CondLayerNorm": + x 
= x.unsqueeze(1).expand(-1, self.noise_members, -1, -1, -1).reshape(self.batch_size*self.noise_members, *x.shape[1:]) y = self.model(x, cond_z=epsilon) else: y = self.model(x) @@ -606,6 +615,8 @@ def _common_step( if self.mask_on_nan: last_prev_state = torch.nan_to_num(last_prev_state, nan=0) + if self.noise_members > 1: + last_prev_state = last_prev_state.unsqueeze(1).expand(-1, self.noise_members, -1, -1, -1).reshape(self.batch_size*self.noise_members, *last_prev_state.shape[1:]) # We update the latest of our prev_states with the network output if scale_y: predicted_state = ( @@ -621,11 +632,12 @@ def _common_step( # Force it to true state for all intermediary step if not (phase == "inference") and force_border: new_state = ( - self.border_mask * border_state - + self.interior_mask * predicted_state + self.border_mask.expand_as(predicted_state) * border_state.repeat(self.noise_members, 1, 1, 1) + + self.interior_mask.expand_as(predicted_state) * predicted_state ) else: new_state = predicted_state + # Only update the prev_states if we are not at the last step if i < batch.num_pred_steps - 1 or k < num_inter_steps - 1: @@ -704,7 +716,7 @@ def _step_diffs( return step_diff_std, step_diff_mean def _next_x( - self, batch: ItemBatch, prev_states: NamedTensor, step_idx: int + self, batch: ItemBatch, prev_states: NamedTensor, step_idx: int, noise: bool = False ) -> torch.Tensor: """ Build the next x input for the model at timestep step_idx using the : @@ -753,11 +765,21 @@ def _next_x( # If downscaling only, inputs are not concatenated: only use static features and forcings. 
x = torch.cat( - inputs * (1 - ds) # = [] if downscaling strategy - + [self.grid_static_features[: batch.batch_size], forcing.tensor] - + mask_list, - dim=forcing.dim_index("features"), - ) + inputs * (1 - ds) # = [] if downscaling strategy + + [self.grid_static_features[: batch.batch_size], forcing.tensor] + + mask_list, + dim=forcing.dim_index("features"), + ) + + if noise and self.noise_strategy == "forcing": + # concatenate noise channel as a forcing + x = torch.cat( + [ + torch.cat([x, torch.randn_like(forcing.tensor[..., 0].unsqueeze(-1))], dim=forcing.dim_index("features")).unsqueeze(1) for _ in range(self.noise_members) + ], + dim=1, + ) + x = x.reshape(self.batch_size*self.noise_members, *x.shape[2:]) return x @@ -904,6 +926,12 @@ def validation_step(self, batch: ItemBatch, batch_idx: int): self.val_mean_loss = mean_loss + if self.noise_members > 1: + # select random member for preds + prediction_tensor = prediction.tensor.reshape(self.batch_size, self.noise_members, *prediction.tensor.shape[1:]) + member = prediction_tensor[:, torch.randint(0, self.noise_members, (1,)).item()] + prediction = NamedTensor.new_like(member.type_as(prediction.tensor), prediction) + self.validation_step_logging(batch, prediction, target, mask) def validation_step_logging( diff --git a/py4cast/losses.py b/py4cast/losses.py index a43d3b5b..43e15a3f 100644 --- a/py4cast/losses.py +++ b/py4cast/losses.py @@ -24,13 +24,14 @@ def __init__(self, alpha=0.95): def forward(self, preds, targets, noise_members=2): - B = targets.shape[0] // noise_members # effective batch size + B = targets.shape[0] # effective batch size M = noise_members # Number of ensemble members epsilon = (1 - self.alpha) / M - # Reshape predictions and targets - preds = preds.view(B, M, *preds.shape[1:]) # (B, M, T, W, H, d_f) - targets = targets.view(B, M, *targets.shape[1:]) # (B, M, T, W, H, d_f) + # Reshape predictions + preds = preds.reshape(B, M, *preds.shape[1:]) # (B, M, T, W, H, d_f) + # Replicate targets 
along the members dimension + targets = targets.unsqueeze(1).expand(-1, M, -1, -1, -1, -1) # (B, M, T, W, H, d_f) abs_diff = torch.abs(preds - targets) # (B, M, T, W, H, d_f) @@ -141,12 +142,9 @@ def prepare( # build the dictionnary of weight loss_state_weight = {} - exponent = 2.0 if self.loss.__class__ == MSELoss else 1.0 - for name in dataset_info.state_weights: - loss_state_weight[name] = dataset_info.state_weights[name] / ( - dataset_info.diff_stats[name]["std"] ** exponent - ) + loss_state_weight[name] = torch.tensor(dataset_info.state_weights[name]) + self.register_loss_state_buffers( lm, interior_mask, loss_state_weight, squeeze_mask=True ) @@ -167,7 +165,7 @@ def forward( """ # Compute Torch loss (defined in the parent class when this Mixin is used) if noise_members > 1: - torch_loss = self.loss(prediction.tensor * mask, target.tensor * mask, noise_members=noise_members) + torch_loss = self.loss(prediction.tensor * mask.repeat(noise_members, 1, 1, 1, 1) , target.tensor * mask, noise_members=noise_members) else: if self.loss.__class__ == AlmostFairCRPS: raise ValueError("When using AFCRPS loss, noise_members must be > 1") From fc4ae682687be34da91951c43e5fe55531fd3f56 Mon Sep 17 00:00:00 2001 From: AdrienAudren Date: Tue, 9 Sep 2025 14:28:50 +0200 Subject: [PATCH 4/6] change config file comment --- config/CLI/dataset/titan.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/config/CLI/dataset/titan.yaml b/config/CLI/dataset/titan.yaml index 331caec9..2dc78066 100644 --- a/config/CLI/dataset/titan.yaml +++ b/config/CLI/dataset/titan.yaml @@ -4,11 +4,10 @@ data: num_input_steps: 1 num_pred_steps_train: 1 num_pred_steps_val_test: 1 - batch_size: 4 # per device If afcrps loss used, has to be set as effective_batch_size*noise_members + batch_size: 3 # per device noise_strategy: "CondLayerNorm" # "forcing" or "CondLayerNorm" or "None" noise_members: 4 # total number of members - #other args num_workers: 2 prefetch_factor: null From 
3451585f0524c98331676c3e6e695dab06561fbc Mon Sep 17 00:00:00 2001 From: AdrienAudren Date: Tue, 9 Sep 2025 18:39:03 +0200 Subject: [PATCH 5/6] spread over domain, spread maps, some fixes --- config/CLI/dataset/titan.yaml | 1 + py4cast/cli.py | 4 + py4cast/lightning.py | 22 ++++- py4cast/losses.py | 2 +- py4cast/metrics.py | 102 ++++++++++++++++++++- py4cast/plots.py | 168 +++++++++++++++++++++++++++++++++- 6 files changed, 291 insertions(+), 8 deletions(-) diff --git a/config/CLI/dataset/titan.yaml b/config/CLI/dataset/titan.yaml index 2dc78066..9cbf0777 100644 --- a/config/CLI/dataset/titan.yaml +++ b/config/CLI/dataset/titan.yaml @@ -8,6 +8,7 @@ data: noise_strategy: "CondLayerNorm" # "forcing" or "CondLayerNorm" or "None" noise_members: 4 # total number of members + ensemble_metrics: True # spread over members, requires ensemble dataset and noise_members=1 #other args num_workers: 2 prefetch_factor: null diff --git a/py4cast/cli.py b/py4cast/cli.py index b3e0d232..164d5837 100644 --- a/py4cast/cli.py +++ b/py4cast/cli.py @@ -52,6 +52,10 @@ def add_arguments_to_parser(self, parser): "data.noise_strategy", "model.noise_strategy", ) + parser.link_arguments( + "data.ensemble_metrics", + "model.ensemble_metrics", + ) parser.link_arguments( "data.train_dataset_info", "model.dataset_info", diff --git a/py4cast/lightning.py b/py4cast/lightning.py index 17d9b597..e256b969 100644 --- a/py4cast/lightning.py +++ b/py4cast/lightning.py @@ -31,7 +31,7 @@ save_named_tensors_to_grib, ) from py4cast.losses import ScaledLoss, WeightedLoss -from py4cast.metrics import MetricACC, MetricPSDK, MetricPSDVar +from py4cast.metrics import MetricACC, MetricPSDK, MetricPSDVar, MetricSpread, MetricMapSpread from py4cast.models import build_model_from_settings, get_model_kls_and_settings from py4cast.models import registry as model_registry from py4cast.plots import ( @@ -39,6 +39,7 @@ PredictionTimestepPlot, SpatialErrorPlot, StateErrorPlot, + SpreadTimestepPlot, ) from py4cast.utils import 
str_to_dtype @@ -68,6 +69,7 @@ def __init__( dataset_conf: Dict | None = None, noise_members: int = 0, noise_strategy: Literal["forcing", "CondLayerNorm", "None"] = "forcing", + ensemble_metrics: bool = False, ): super().__init__() self.num_input_steps = num_input_steps @@ -83,6 +85,7 @@ def __init__( self.pin_memory = pin_memory self.noise_members = noise_members self.noise_strategy = noise_strategy + self.ensemble_metrics = ensemble_metrics # Get dataset in initialisation to have access to this attribute before method trainer.fit self.train_ds, self.val_ds, self.test_ds = get_datasets( @@ -172,6 +175,7 @@ def __init__( batch_size: int = 2, noise_members: int = 0, noise_strategy: Literal["forcing", "CondLayerNorm", "None"] = "forcing", + ensemble_metrics: bool = False, # non-linked args model_name: Literal[tuple(model_registry.keys())] = "HalfUNet", loss_name: Literal["mse", "mae", "afcrps"] = "mse", @@ -200,6 +204,7 @@ def __init__( self.batch_size = batch_size self.noise_members = noise_members self.noise_strategy = noise_strategy + self.ensemble_metrics = ensemble_metrics self.model_name = model_name self.num_input_steps = num_input_steps self.num_pred_steps_train = num_pred_steps_train @@ -334,6 +339,8 @@ def setup(self, stage=None): self.rmse_psd_plot_metric = MetricPSDVar(pred_step=max_pred_step) self.psd_plot_metric = MetricPSDK(self.save_path, pred_step=max_pred_step) self.acc_metric = MetricACC(self.dataset_info) + self.spread_metric = MetricSpread(self.dataset_info, pred_step=max_pred_step+1) + self.spread_map_metric = MetricMapSpread(self.dataset_info, pred_step=max_pred_step+1) self.configure_loggers() def configure_loggers(self): @@ -1022,6 +1029,9 @@ def on_test_start(self): loss.prepare(self, self.interior_mask, self.dataset_info) metrics[alias] = loss + if self.ensemble_metrics: + metrics["std"] = self.spread_metric + self.test_plotters = [ StateErrorPlot(metrics, save_path=self.save_path), SpatialErrorPlot(), @@ -1033,6 +1043,16 @@ def 
on_test_start(self): ), ] + if self.ensemble_metrics: + self.test_plotters.append(SpreadTimestepPlot( + metric=self.spread_map_metric, + dataset_name = self.dataset_info.name, + num_samples_to_plot=self.num_samples_to_plot, + num_features_to_plot=6, + prefix="Test", + save_path=self.save_path, + )) + def test_step(self, batch: ItemBatch, batch_idx: int): """Runs test on single batch""" with torch.no_grad(): diff --git a/py4cast/losses.py b/py4cast/losses.py index 43e15a3f..97e707dc 100644 --- a/py4cast/losses.py +++ b/py4cast/losses.py @@ -164,7 +164,7 @@ def forward( returns (B, pred_steps) """ # Compute Torch loss (defined in the parent class when this Mixin is used) - if noise_members > 1: + if noise_members > 1 and self.loss.__class__ == AlmostFairCRPS: torch_loss = self.loss(prediction.tensor * mask.repeat(noise_members, 1, 1, 1, 1) , target.tensor * mask, noise_members=noise_members) else: if self.loss.__class__ == AlmostFairCRPS: diff --git a/py4cast/metrics.py b/py4cast/metrics.py index d322aa81..e979452c 100644 --- a/py4cast/metrics.py +++ b/py4cast/metrics.py @@ -5,7 +5,7 @@ import torch from scipy.fftpack import dct from torchmetrics import Metric - +from torchmetrics.utilities import dim_zero_cat from py4cast.datasets.base import DatasetInfo, NamedTensor from py4cast.plots import plot_log_psd @@ -452,3 +452,103 @@ def compute(self, prefix: str = "val") -> dict: self.reset() return metric_log_dict + +class MetricSpread(Metric): + """ + Compute the spatially averaged, per pred step members spread for both the target (PE-AROME) and the prediction + """ + + def __init__(self, dataset_info: DatasetInfo, pred_step: int = 0): + super().__init__() + self.pred_step = pred_step + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("targets", default=[], dist_reduce_fx="cat") + # dist_reduce_fx="cat" because we gather multiple batch before computing std + + self.dataset_name = dataset_info.name + + def update(self, pred: NamedTensor, 
target: NamedTensor, *args): + """ + Assuming a batch contains the members of a same ensemble, + compute the spread of batch=members for both target and pred. + prediction/target: (B=Mb, pred_steps, N_grid, d_f) or (B=Mb, pred_steps, W, H, d_f) + called at each end of step + """ + + # a priori unknown number of spatial dims + # but they are all after pred_steps and before features + self.spatial_dims = tuple(pred.spatial_dim_idx) + if pred.tensor.shape != target.tensor.shape: + raise ValueError("preds and target must have the same shape") + + self.preds.append(pred.tensor) + self.targets.append(target.tensor) + + def compute(self, prefix: str = "val") -> dict: + """ + Compute spread mean for each channels/features, return a dict. + Should be called at each epoch's end + """ + # Ensure concatenating into a single tensor along device dimension + preds = dim_zero_cat(self.preds) + targets = dim_zero_cat(self.targets) + # Spread computation + preds_std = preds.std(dim=0).mean(dim=tuple([spatial_dim-1 for spatial_dim in self.spatial_dims])) + targets_std = targets.std(dim=0).mean(dim=tuple([spatial_dim-1 for spatial_dim in self.spatial_dims])) + # Reconvert to initial type then reset metric's state + self.preds = [] + self.targets = [] + self.reset() + + return {"preds": preds_std, + f"{self.dataset_name}": targets_std, + } + +class MetricMapSpread(Metric): + """ + Compute the spatially averaged, per pred step members spread for both the target (PE-AROME) and the prediction + """ + + def __init__(self, dataset_info: DatasetInfo, pred_step: int = 0): + super().__init__() + self.pred_step = pred_step + + self.add_state("preds", default=[], dist_reduce_fx="cat") + self.add_state("targets", default=[], dist_reduce_fx="cat") + # dist_reduce_fx="cat" because we gather multiple batch before computing std + + self.dataset_name = dataset_info.name + + def update(self, pred: torch.Tensor, target: torch.Tensor, *args): + """ + Assuming a batch contains the members of a same 
class SpreadTimestepPlot(MapPlot):
    """
    Observer used to plot prediction and target spread map for each timestep.

    Feeds a spread metric with rescaled prediction/target slices, then renders
    one figure per (timestep, feature) and logs it to tensorboard (and mlflow
    when configured), optionally saving PNGs and building per-feature GIFs.
    """

    def __init__(
        self,
        metric: Metric,
        dataset_name: str,
        num_samples_to_plot: int,
        num_features_to_plot: Union[None, int] = None,
        prefix: str = "Test",
        save_path: Path = None,
    ):
        """
        metric: spread metric whose compute() must return a dict with
            "preds" and "targets" keys (e.g. MetricMapSpread).
            NOTE(review): MetricSpread returns the dataset name as key
            instead of "targets" — passing it here would KeyError; confirm
            which metric callers wire in.
        dataset_name: name used for labelling; stored but not read below.
        num_samples_to_plot / num_features_to_plot / prefix / save_path:
            forwarded to MapPlot.
        """
        super().__init__(
            num_samples_to_plot=num_samples_to_plot,
            num_features_to_plot=num_features_to_plot,
            prefix=prefix,
            save_path=save_path,
        )
        self.metric = metric
        self.dataset_name = dataset_name

    def update(
        self,
        obj: "AutoRegressiveLightning",
        batch: "ItemBatch",
        prediction: NamedTensor,
        target: NamedTensor,
        mask: torch.Tensor,
    ) -> None:
        """
        Update. Should be called by on_{training/validation/test}_step

        Only runs on the global-zero rank and stops once
        num_samples_to_plot examples have been plotted. `mask` is accepted
        for observer-interface compatibility but unused here.
        """
        pred = deepcopy(prediction).tensor  # don't modify input
        targ = deepcopy(target).tensor
        batch_copy = deepcopy(batch)

        # Reshape outputs from GNNs to grid
        if prediction.num_spatial_dims == 1:
            pred = einops.rearrange(pred, "b t (x y) n -> b t x y n", x=obj.grid_shape[0])
            targ = einops.rearrange(targ, "b t (x y) n -> b t x y n", x=obj.grid_shape[0])

        if obj.trainer.is_global_zero and self.plotted_examples < self.num_samples_to_plot:
            # Never plot more examples than this batch provides.
            n_additional_examples = min(
                pred.shape[0], self.num_samples_to_plot - self.plotted_examples
            )

            # Rescale to original data scale
            std = obj.stats.to_list("std", prediction.feature_names).to(pred, non_blocking=True)
            mean = obj.stats.to_list("mean", prediction.feature_names).to(pred, non_blocking=True)
            prediction_rescaled = pred * std + mean
            target_rescaled = targ * std + mean

            for pred_slice, target_slice in zip(
                prediction_rescaled[:n_additional_examples],
                target_rescaled[:n_additional_examples],
            ):
                self.plotted_examples += 1

                # compute variance ranges
                # NOTE(review): each slice is (pred_steps, H, W, features);
                # the metric's std is thus taken over pred_steps here —
                # confirm this is the intended member dimension.
                self.metric.update(pred_slice, target_slice)
                std_dict = self.metric.compute(prefix="test")
                prediction_std = std_dict["preds"]
                target_std = std_dict["targets"]

                # Per-feature colour ranges from the target spread:
                # flatten (steps, H, W) and take min/max per feature.
                var_vmin = target_std.flatten(0, 2).min(dim=0)[0].cpu().numpy()
                var_vmax = target_std.flatten(0, 2).max(dim=0)[0].cpu().numpy()
                var_vranges = list(zip(var_vmin, var_vmax))

                feature_names = (
                    prediction.feature_names[: self.num_features_to_plot]
                    if self.num_features_to_plot
                    else prediction.feature_names
                )

                self.plot_map(
                    obj,
                    batch_copy,
                    prediction_std,
                    target_std,
                    feature_names,
                    var_vranges,
                )

    def plot_map(
        self,
        obj: "AutoRegressiveLightning",
        batch: "ItemBatch",
        prediction: torch.Tensor,
        target: torch.Tensor,
        feature_names: List[str],
        var_vranges: List,
    ) -> None:
        """
        Render and log one spread figure per (timestep, feature).

        Prediction and target: (pred_steps, Nlat, Nlon, features).
        `batch` is accepted for interface parity with sibling plotters but
        unused here.
        """
        # Collect saved-PNG paths per feature so GIFs can be built at the end.
        paths_dict = defaultdict(list)
        for t_i, (pred_t, target_t) in enumerate(zip(prediction, target), start=1):
            units = [obj.dataset_info.units[name] for name in feature_names]
            var_figs = [
                plot_prediction(
                    pred_t[:, :, var_i],
                    target_t[:, :, var_i],
                    obj.interior_2d[:, :, 0],
                    title=f"{var_name} ({var_unit}), "
                    f"t={t_i} ({obj.dataset_info.pred_step*t_i} h)",
                    vrange=var_vrange,
                    domain_info=obj.dataset_info.domain_info,
                    cmap="viridis",
                )
                for var_i, (var_name, var_unit, var_vrange) in enumerate(
                    zip(feature_names, units, var_vranges)
                )
            ]

            tensorboard = obj.logger.experiment
            for var_name, fig in zip(feature_names, var_figs):
                fig_name = f"timestep_spread_evol_per_param/{var_name}_example_{self.plotted_examples}"
                tensorboard.add_figure(fig_name, fig, t_i)
                fig_full_name = f"{fig_name}_{t_i}.png"

                # Save PNGs only when an existing save_path was configured.
                if self.save_path is not None and self.save_path.exists():
                    dest_file = self.save_path / fig_full_name
                    paths_dict[var_name].append(dest_file)
                    # NOTE(review): no parents=True — fails if the
                    # grandparent directory is missing; confirm save_path
                    # layout guarantees it exists.
                    dest_file.parent.mkdir(exist_ok=True)
                    fig.savefig(dest_file)

                if obj.mlflow_logger:
                    run_id = obj.mlflow_logger.version
                    obj.mlflow_logger.experiment.log_figure(
                        run_id=run_id,
                        figure=fig,
                        artifact_file=f"figures/{fig_full_name}",
                    )

                # Free the figure to avoid matplotlib memory growth.
                plt.close(fig)

        # build gifs
        for var_name, paths in paths_dict.items():
            if len(paths) > 1:
                make_gif(
                    paths,
                    self.save_path / f"timestep_spread_evol_per_param/{var_name}.gif",
                )
Append to a dictionnary """ for name in self.metrics: - self.losses[name].append( - obj.trainer.strategy.reduce( - self.metrics[name](prediction, target, mask), reduce_op="mean" - ).cpu() - ) + + if name == "std": + if not self.initialized: + self.metrics[name].update(prediction, target) + self.epoch_dict_loss = self.metrics[name].compute(prefix="test") + else: + self.metrics[name].update(prediction, target) + dict_loss = self.metrics[name].compute(prefix="test") + for key, value in dict_loss.items(): + self.epoch_dict_loss[key] += value + self.count+=1 + + else: + self.losses[name].append( + obj.trainer.strategy.reduce( + self.metrics[name](prediction, target, mask), reduce_op="mean" + ).cpu() + ) + if not self.initialized: self.shortnames = prediction.feature_names self.units = [ From 2fa6189732068a09abc0314449d8f0923655ccfe Mon Sep 17 00:00:00 2001 From: AdrienAudren Date: Thu, 18 Sep 2025 18:16:24 +0200 Subject: [PATCH 6/6] fix noise as forcing strat --- py4cast/datasets/base.py | 1 - py4cast/lightning.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/py4cast/datasets/base.py b/py4cast/datasets/base.py index 043943d3..3dd809bd 100644 --- a/py4cast/datasets/base.py +++ b/py4cast/datasets/base.py @@ -745,7 +745,6 @@ def input_dim(self) -> int: """ res = 4 # For date res += 1 # For solar forcing - res += 1 if self.settings.noise_strategy == "forcing" else 0 # additional noise channel for param in self.params: if param.kind == "input": diff --git a/py4cast/lightning.py b/py4cast/lightning.py index e256b969..316124e7 100644 --- a/py4cast/lightning.py +++ b/py4cast/lightning.py @@ -274,6 +274,7 @@ def __init__( + num_grid_static_features + dataset_info.forcing_dim + self.mask_on_nan + + (1 if self.noise_strategy == "forcing" else 0) ) num_output_features = dataset_info.weather_dim