Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions config/CLI/dataset/titan.yaml
Original file line number Diff line number Diff line change
@@ -1,29 +1,32 @@
data:
#args forwarded (linked) to model
dataset_name: titan_aro_arp
dataset_name: titan_full
num_input_steps: 1
num_pred_steps_train: 1
num_pred_steps_val_test: 1
batch_size: 2
batch_size: 3 # per device

noise_strategy: "CondLayerNorm" # "forcing" or "CondLayerNorm" or "None"
noise_members: 4 # total number of members
ensemble_metrics: True # spread over members, requires ensemble dataset and noise_members=1
#other args
num_workers: 10
num_workers: 2
prefetch_factor: null
pin_memory: False
dataset_conf:
periods:
train:
start: 20200101
end: 20221231
start: 2021010100
end: 2021010100
obs_step: 3600
valid:
start: 20230101
end: 20231231
start: 2021010100
end: 2021010100
obs_step: 3600
obs_step_btw_t0: 10800
test:
start: 20240101
end: 20240831
start: 2023050122
end: 2023080122
obs_step: 3600
obs_step_btw_t0: 10800
grid:
Expand Down
7 changes: 4 additions & 3 deletions config/CLI/model/unetrpp.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
model:
model_name: UNETRPP
loss_name: mse # mse or mae
model_name: UNetRPP
loss_name: afcrps # mse or mae or afcrps
num_inter_steps: 1 # Number of intermediary steps (without any data)
num_samples_to_plot: 1
training_strategy: scaled_ar # diff_ar or scaled_ar or downscaling_only
Expand All @@ -27,4 +27,5 @@ model:
decoder_proj_size: 64
encoder_proj_sizes: [64, 64, 64, 32]
add_skip_connections: true
attention_code: "torch"
attention_code: "torch"
CondLayerNorm: True # use ConditionalLayerNorm instead of LayerNorm
12 changes: 12 additions & 0 deletions py4cast/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ def add_arguments_to_parser(self, parser):
"data.dataset_conf",
"model.dataset_conf",
)
parser.link_arguments(
"data.noise_members",
"model.noise_members",
)
parser.link_arguments(
"data.noise_strategy",
"model.noise_strategy",
)
parser.link_arguments(
"data.ensemble_metrics",
"model.ensemble_metrics",
)
parser.link_arguments(
"data.train_dataset_info",
"model.dataset_info",
Expand Down
6 changes: 5 additions & 1 deletion py4cast/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import traceback
import warnings
from typing import Dict, Tuple
from typing import Dict, Tuple, Literal

from .base import DatasetABC # noqa: F401

Expand Down Expand Up @@ -47,6 +47,8 @@ def get_datasets(
num_input_steps: int,
num_pred_steps_train: int,
num_pred_steps_val_test: int,
noise_members: int,
noise_strategy: Literal["forcing", "CondLayerNorm", "None"],
dataset_conf: Dict | None = None,
) -> Tuple[DatasetABC, DatasetABC, DatasetABC]:
"""
Expand Down Expand Up @@ -76,4 +78,6 @@ def get_datasets(
num_input_steps,
num_pred_steps_train,
num_pred_steps_val_test,
noise_members,
noise_strategy,
)
2 changes: 2 additions & 0 deletions py4cast/datasets/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,8 @@ class SamplePreprocSettings:
standardize: bool = True
file_format: Literal["npy", "grib"] = "grib"
members: Optional[Tuple[int]] = None
noise_members: int = 0
noise_strategy: Literal["forcing", "CondLayerNorm", "None"] = "forcing"
add_landsea_mask: bool = False


Expand Down
13 changes: 13 additions & 0 deletions py4cast/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def is_valid(self) -> bool:
param=param,
timestamps=self.timestamps,
file_format=self.settings.file_format,
num_input_steps=self.settings.num_input_steps,
):
return False
return True
Expand Down Expand Up @@ -721,6 +722,7 @@ def torch_dataloader(
shuffle: bool = False,
prefetch_factor: Union[int, None] = None,
pin_memory: bool = False,
drop_last: bool = False,
) -> DataLoader:
"""
Builds a torch dataloader from self.
Expand All @@ -733,6 +735,7 @@ def torch_dataloader(
prefetch_factor=prefetch_factor,
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=drop_last,
)

@cached_property
Expand Down Expand Up @@ -862,6 +865,8 @@ def from_dict(
num_input_steps: int,
num_pred_steps_train: int,
num_pred_steps_val_test: int,
noise_members: int,
noise_strategy: Literal["forcing", "CondLayerNorm", "None"],
) -> Tuple[Type["DatasetABC"], Type["DatasetABC"], Type["DatasetABC"]]:
grid = Grid(load_grid_info_func=accessor_kls.load_grid_info, **conf["grid"])

Expand All @@ -877,6 +882,8 @@ def from_dict(
num_input_steps=num_input_steps,
num_pred_steps=num_pred_steps_train,
members=members,
noise_members=noise_members,
noise_strategy=noise_strategy,
**conf["settings"],
)
train_period = Period(**conf["periods"]["train"], name="train")
Expand All @@ -889,6 +896,8 @@ def from_dict(
num_input_steps=num_input_steps,
num_pred_steps=num_pred_steps_val_test,
members=members,
noise_members=noise_members,
noise_strategy=noise_strategy,
**conf["settings"],
)
valid_period = Period(**conf["periods"]["valid"], name="valid")
Expand All @@ -911,6 +920,8 @@ def from_json(
num_input_steps: int,
num_pred_steps_train: int,
num_pred_steps_val_tests: int,
noise_members: int,
noise_strategy: Literal["forcing", "CondLayerNorm", "None"],
predict_conf: Union[Dict, None] = None,
) -> Tuple[Type["DatasetABC"], Type["DatasetABC"], Type["DatasetABC"]]:
"""
Expand All @@ -931,4 +942,6 @@ def from_json(
num_input_steps,
num_pred_steps_train,
num_pred_steps_val_tests,
noise_members,
noise_strategy,
)
8 changes: 7 additions & 1 deletion py4cast/datasets/titan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,14 @@ def exists(
param: WeatherParam,
timestamps: Timestamps,
file_format: Literal["npy", "grib"] = "grib",
num_input_steps: int = 1,
) -> bool:
for date in timestamps.validity_times:
if param.kind == "input":
# inputs/forcings only required after num_input_steps
valid_times = timestamps.validity_times[num_input_steps:]
else:
valid_times = timestamps.validity_times
for date in valid_times:
filepath = self.get_filepath(ds_name, param, date, file_format)
if not filepath.exists():
return False
Expand Down
4 changes: 4 additions & 0 deletions py4cast/datasets/titan/titan_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def prepare(
num_input_steps: int = 1,
num_pred_steps_train: int = 1,
num_pred_steps_val_test: int = 1,
noise_members: int = 0,
noise_strategy: str = "None",
convert_grib2npy: bool = False,
compute_stats: bool = True,
):
Expand All @@ -72,6 +74,8 @@ def prepare(
num_input_steps=num_input_steps,
num_pred_steps_train=num_pred_steps_train,
num_pred_steps_val_test=num_pred_steps_val_test,
noise_members=noise_members,
noise_strategy=noise_strategy,
)
train_ds.cache_dir.mkdir(exist_ok=True)
data_dir = train_ds.cache_dir / "data"
Expand Down
Loading
Loading