diff --git a/configurations/surface-dummy-model_DINI/.gitignore b/configurations/surface-dummy-model_DINI/.gitignore new file mode 100644 index 0000000..6ec839b --- /dev/null +++ b/configurations/surface-dummy-model_DINI/.gitignore @@ -0,0 +1,6 @@ +*.zip +*.zarr/ +inference_artifact/ +*.yaml +inference_workdir/ +.env diff --git a/configurations/surface-dummy-model_DINI/Containerfile b/configurations/surface-dummy-model_DINI/Containerfile index a6744fe..889bf36 100644 --- a/configurations/surface-dummy-model_DINI/Containerfile +++ b/configurations/surface-dummy-model_DINI/Containerfile @@ -5,6 +5,7 @@ WORKDIR /workspace COPY pyproject.toml . COPY *.yaml ./ COPY entry.sh ./ +COPY src/ ./src # Download inference artifact from S3 ARG DEFAULT_ARTIFACT="s3://mlwm-artifacts/inference-artifacts/surface-dummy-model_DINI.zip" diff --git a/configurations/surface-dummy-model_DINI/DEVELOPING.md b/configurations/surface-dummy-model_DINI/DEVELOPING.md new file mode 100644 index 0000000..d99937a --- /dev/null +++ b/configurations/surface-dummy-model_DINI/DEVELOPING.md @@ -0,0 +1,51 @@ +# Development notes + +## Local development + +- currently image build only works on amd64 machines (i.e. not on macos) + +- image build requires `aws` cli which can be retrieved from https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip + +- to load AWS crendentials from `.aws/credentials` you can use the following script (drop in e.g. `~/.bashrc`): + +```bash +aws-load-creds() { + local profile=$1 + if [[ -z "$profile" ]]; then + echo "❌ Usage: aws-load-creds " + return 1 + fi + + local access_key + local secret_key + + access_key=$(aws configure get aws_access_key_id --profile "$profile" 2>/dev/null) + secret_key=$(aws configure get aws_secret_access_key --profile "$profile" 2>/dev/null) + + if [[ -z "$access_key" || -z "$secret_key" ]]; then + echo "❌ The config profile '$profile' could not be found or is incomplete." + return 1 + fi + + export AWS_ACCESS_KEY_ID="$access_key" + export AWS_SECRET_ACCESS_KEY="$secret_key" + + echo "✅ Loaded AWS credentials from profile: $profile" +} + +aws-list-profiles() { + echo "📂 AWS profiles found:" + grep '^\[profile ' ~/.aws/config 2>/dev/null | sed 's/^\[profile //' | sed 's/\]//' + grep '^\[' ~/.aws/credentials 2>/dev/null | sed 's/^\[//' | sed 's/\]//' +} +``` + +- to set the environment variables for `./entry.sh` you can use a `.env` file. E.g. to run with DINI forecast data you would use: + +```bash +# .env +ANALYSIS_TIME="2025-09-22T120000Z" +DINI_ZARR="s3://harmonie-zarr/dini/control/${ANALYSIS_TIME}/single_levels.zarr/" +DATASTORE_INPUT_PATHS="danra.danra_surface=${DINI_ZARR},danra.danra_static=${DINI_ZARR}" +TIME_DIMENSIONS="time" +``` diff --git a/configurations/surface-dummy-model_DINI/README.md b/configurations/surface-dummy-model_DINI/README.md index e69de29..3582644 100644 --- a/configurations/surface-dummy-model_DINI/README.md +++ b/configurations/surface-dummy-model_DINI/README.md @@ -0,0 +1,83 @@ +# surface-dummy-model_DINI + +The model configuration in this directory is a dummy model that was trained on +surface variables from DANRA, only 10 days of data and only trained 10 +epochs. It is intended only as a demonstration of the inference pipeline and is +expected to give very poor results. + +## Building image and running inference + +To build the image on "superjuice" (`27sj894.dmi.dk`) we need to set the AWS tokens to read the inference artifact and also use the local http proxy for pulling the base image: + +```bash +export AWS_SECRET_ACCESS_KEY= +export AWS_ACCESS_KEY_ID= +export MLWM_PULL_PROXY=http://squid1.dmi.dk:3128 +``` + + + +## Upstream package change requirements + +Relative to the `main` branch on both github.com/mllam/mllam-data-prep and +github.com/mllam/neural-lam and number of pieces of functionality are currently +required to run this configuration: + +**mllam-data-prep**: + +using branch `feat/inference-cli-args` on +https://github.com/leifdenby/mllam-data-prep@feat/inference-cli-args, which adds: + +- functionality to invert datasets created by `mllam-data-prep` back to the + structure of the input datasets that we were used. In the current + configuration that is used to restructure the forecast zarr dataset that + `neural-lam` outputs during inference back to the structure of the input + forecast dataset. + + - also in seperate branch and PR: https://github.com/leifdenby/mllam-data-prep/tree/feat/inverse-ops + +- use of cf-compliant encoding of `xarray/pandas` `MultiIndex` coordinates to + store stacked coordinates. This is required since we `MultiIndex` coordinates + can't natively be stored in zarr/netcdf files, but fortunately `cf_xarray` + have implemented the cf-compliant way of handling this (see + https://cf-xarray.readthedocs.io/en/latest/coding.html) + + - needs its own branch and PR + +- support for supplying statistics from the training dataset during creation of + the inference dataset, so that the inference dataset can be normalised in the + same way as the training dataset. + + - needs its own branch and PR + +- support for selecting only a single value from a variable/coordinate in the + configuration. This is used to select only a single analysis time during + creation of the inference dataset. + + - needs its own branch and PR + + +**neural-lam**: + +using branch `dev/first-inference-image` on +https://github.com/leifdenby/neural-lam/tree/dev/first-inference-image, which +adds: + +- support for decoding cf-compliant `MultiIndex` encoded coordinates when reading + datasets produced with mllam-data-prep. + + - this needs its own branch and PR, and needs to be implemented so datasets + made with previous versions of `mllam-data-prep` are still usable in `neural-lam` + +- support for writing output from inference (i.e. `--eval` mode) to a zarr + dataset. Needs to be merged after the multiindex decoding above. + + - also in seperate branch and PR: https://github.com/leifdenby/neural-lam/tree/feat/write-to-zarr + +- support for using forecast data in in mllam-data-prep datastore (`MDPDatastore`) + + - needs its own branch and PR + +- make logging of validation steps optional in the training CLI (i.e. `--eval` mode) + + - needs its own branch and PR diff --git a/configurations/surface-dummy-model_DINI/datastore.yaml b/configurations/surface-dummy-model_DINI/datastore.yaml deleted file mode 100644 index fcf227f..0000000 --- a/configurations/surface-dummy-model_DINI/datastore.yaml +++ /dev/null @@ -1,165 +0,0 @@ -schema_version: v0.5.0 -dataset_version: v0.1.0 - -output: - variables: - static: [grid_index, static_feature] - state: [time, grid_index, state_feature] - forcing: [time, grid_index, forcing_feature] - coord_ranges: - time: - start: - end: - step: PT3H - chunking: - time: 1 - state_feature: 20 - splitting: - dim: time - splits: - train: - start: 2000-01-01T00:00 - end: 2018-10-29T00:00 - compute_statistics: - ops: [mean, std, diff_mean, diff_std] - dims: [grid_index, time] - val: - start: 2018-11-05T00:00 - end: 2019-10-22T00:00 - test: - start: 2019-10-29T00:00 - end: 2020-10-29T00:00 - -inputs: - danra_sl_state: - path: /harmonie_cy40/danra/w12p05_s45p65_e24p52_n64p40/dx2p5km_dy2p5km//single_levels.zarr/ - dims: [time, x, y] - variables: - - pres_seasurface - - t2m - - u10m - - v10m - - pres0m - - lwavr0m - - swavr0m - dim_mapping: - time: - method: rename - dim: time - grid_index: - method: stack - dims: [x, y] - state_feature: - method: stack_variables_by_var_name - name_format: "{var_name}" - target_output_variable: state - - danra_pl_state: - path: /harmonie_cy40/danra/w12p05_s45p65_e24p52_n64p40/dx2p5km_dy2p5km//pressure_levels.zarr/ - dims: [time, x, y, pressure] - variables: - z: - pressure: - values: [100, 200, 400, 600, 700, 850, 925, 1000,] - units: hPa - t: - pressure: - values: [100, 200, 400, 600, 700, 850, 925, 1000,] - units: hPa - r: - pressure: - values: [100, 200, 400, 600, 700, 850, 925, 1000,] - units: hPa - u: - pressure: - values: [100, 200, 400, 600, 700, 850, 925, 1000,] - units: hPa - v: - pressure: - values: [100, 200, 400, 600, 700, 850, 925, 1000,] - units: hPa - tw: - pressure: - values: [100, 200, 400, 600, 700, 850, 925, 1000,] - units: hPa - dim_mapping: - time: - method: rename - dim: time - state_feature: - method: stack_variables_by_var_name - dims: [pressure] - name_format: "{var_name}{pressure}" - grid_index: - method: stack - dims: [x, y] - target_output_variable: state - - danra_static: - path: /harmonie_cy40/danra/w12p05_s45p65_e24p52_n64p40/dx2p5km_dy2p5km//single_levels.zarr/ - dims: [x, y] - variables: - - lsm - - orography - dim_mapping: - grid_index: - method: stack - dims: [x, y] - static_feature: - method: stack_variables_by_var_name - name_format: "{var_name}" - target_output_variable: static - - danra_forcing: - path: /harmonie_cy40/danra/w12p05_s45p65_e24p52_n64p40/dx2p5km_dy2p5km//single_levels.zarr/ - dims: [time, x, y] - derived_variables: - # derive variables to be used as forcings - toa_radiation: - kwargs: - time: ds_input.time - lat: ds_input.lat - lon: ds_input.lon - function: mllam_data_prep.ops.derive_variable.physical_field.calculate_toa_radiation - hour_of_day_sin: - kwargs: - time: ds_input.time - component: sin - function: mllam_data_prep.ops.derive_variable.time_components.calculate_hour_of_day - hour_of_day_cos: - kwargs: - time: ds_input.time - component: cos - function: mllam_data_prep.ops.derive_variable.time_components.calculate_hour_of_day - day_of_year_sin: - kwargs: - time: ds_input.time - component: sin - function: mllam_data_prep.ops.derive_variable.time_components.calculate_day_of_year - day_of_year_cos: - kwargs: - time: ds_input.time - component: cos - function: mllam_data_prep.ops.derive_variable.time_components.calculate_day_of_year - dim_mapping: - time: - method: rename - dim: time - grid_index: - method: stack - dims: [x, y] - forcing_feature: - method: stack_variables_by_var_name - name_format: "{var_name}" - target_output_variable: forcing - -extra: - projection: - class_name: LambertConformal - kwargs: - central_longitude: 25.0 - central_latitude: 56.7 - standard_parallels: [56.7, 56.7] - globe: - semimajor_axis: 6367470.0 - semiminor_axis: 6367470.0 diff --git a/configurations/surface-dummy-model_DINI/entry.sh b/configurations/surface-dummy-model_DINI/entry.sh index 376ba89..c8a2037 100755 --- a/configurations/surface-dummy-model_DINI/entry.sh +++ b/configurations/surface-dummy-model_DINI/entry.sh @@ -3,44 +3,110 @@ # # This script is intended to be run in a container, and assumes that during the # container image build that the inference artifact was unpacked to -# inference_artifact/ +# inference_artifact/. You can also run this script interactively if you have +# extracted the inference artifact yourself. +# +# The selection of datasets to use for input to the model, analysis time and +# forecast duration is controller by the following environment variables: +# DATASTORE_INPUT_PATHS, ANALYSIS_TIME, FORECAST_DURATION and NUM_EVAL_STEPS +# (the latter should be inferred from FORECAST_DURATION, but that is TODO) +# +# - DATASTORE_INPUT_PATHS is a comma-separated list of mappings of +# {datastore_name}.{input_name}={input_path} +# - ANALYSIS_TIME is the analysis time to start the forecast from is ISO8601 +# format +# - FORECAST_DURATION is the duration of the forecast in ISO8601 duration +# format and effects the length of the produced inference dataset +# - NUM_EVAL_STEPS is the number of autoregressive steps to run during +# inference. This should be consistent with FORECAST_DURATION and the model +# configuration (e.g. if the model was trained on 3-hourly data and +# FORECAST_DURATION is PT18H then NUM_EVAL_STEPS should be 6 +# make this script fail on any error +set -e -INFERENCE_ARTIFACT_PATH="./inference_artifact" -# XXX: these mount points could come from config.yaml for the model run configuration -INPUT_DATASETS_ROOT_PATH="/volume/inputs" -OUTPUT_DATASETS_ROOT_PATH="/volume/outputs" +## Runtime configuration (variable expected to change on every execution) +# enable use of .env so that during development we can set environment (e.g. +# paths to replace in datastore config) +if [ -f .env ] ; then + echo "Sourcing local .env file" + set -a && source .env && set +a +fi +## Model specific inference configuration (same across all executions) +NUM_HIDDEN_DIMS=2 +GRAPH_NAME="multiscale" +HIEARCHICAL_GRAPH=false +MODEL_TIMESTEP="PT3H" # model trained on 3-hourly data + +# set default override of input paths in the datastore config used for creating the +# inference dataset if environment variable isn't set +DATASTORE_INPUT_PATHS=${DATASTORE_INPUT_PATHS:-"\ +danra.danra_surface=https://object-store.os-api.cci1.ecmwf.int/danra/v0.6.0dev1/single_levels.zarr/,\ +danra.danra_static=https://object-store.os-api.cci1.ecmwf.int/danra/v0.5.0/single_levels.zarr/"} +TIME_DIMENSIONS=${TIME_DIMENSIONS:-"analysis_time,elapsed_forecast_duration"} +ANALYSIS_TIME=${ANALYSIS_TIME:-"2019-02-04T12:00"} # assumed to be in UTC # forecast out to 18 hours, which means 6 steps of 3 hours each (the model was # trained on 3-hourly analysis data) -NUM_EVAL_STEPS=6 +FORECAST_DURATION=${FORECAST_DURATION:-"PT18H"} +NUM_EVAL_STEPS=${NUM_EVAL_STEPS:-6} +INFERENCE_WORKDIR=${INFERENCE_WORKDIR:-"./inference_workdir"} + +echo "Creating forecast using following runtime args:" +echo " DATASTORE_INPUT_PATHS=${DATASTORE_INPUT_PATHS}" +echo " TIME_DIMENSIONS=${TIME_DIMENSIONS}" +echo " ANALYSIS_TIME=${ANALYSIS_TIME}" +echo " FORECAST_DURATION=${FORECAST_DURATION}" +echo " NUM_EVAL_STEPS=${NUM_EVAL_STEPS}" +echo " INFERENCE_WORKDIR=${INFERENCE_WORKDIR}" + +# set cli argument for creating hierarchical graph if needed +if [ "$HIEARCHICAL_GRAPH" = true ] ; then + CREATE_GRAPH_ARG="--hierarchical" +else + CREATE_GRAPH_ARG="" +fi + +## Setup working directories +INFERENCE_ARTIFACT_PATH="./inference_artifact" +INPUT_DATASETS_ROOT_PATH="${INFERENCE_WORKDIR}/inputs" +OUTPUT_DATASETS_ROOT_PATH="${INFERENCE_WORKDIR}/outputs" +mkdir -p ${OUTPUT_DATASETS_ROOT_PATH} + +# disable weights and biases logging, without this --eval with neural-lam fails +# because it tries to set up the logging and there is no WANDB_API_KEY set +uv run wandb disabled ## 1. Create inference dataset # This uses a cli stored within mlwm to called mllam-data-prep to create the # inference dataset. The inference dataset is created by modifying the -# configuration used during training to a) change the paths to the input datasets, -# b) include the statistics from the training dataset and c) set the dimensions -# in the configuration to have `analysis_time` and `elapsed_forecast_duration` -# instead of just `time`. -uv run python -m mlwm.create_inference_dataset \ - --config_path ${INFERENCE_ARTIFACT_PATH}/config.yaml \ - --override_input_paths \ - danra_surface=${INPUT_DATASETS_ROOT_PATH}/single_levels.zarr \ - danra_surface_forcing=${INPUT_DATASETS_ROOT_PATH}/single_levels.zarr \ - danra_static=${INPUT_DATASETS_ROOT_PATH}/single_levels.zarr \ - --use_stats_from_path ${INFERENCE_ARTIFACT_PATH}/danra.datastore.stats.zarr \ - --output_root_path inference/ +# configuration used during training to +# a) change the paths to the input datasets, +# b) include the statistics from the training dataset and +# c) set the dimensions in the configuration to have `analysis_time` and +# `elapsed_forecast_duration` instead of just `time`. +DATASTORE_INPUT_PATHS=${DATASTORE_INPUT_PATHS} \ +ANALYSIS_TIME=${ANALYSIS_TIME} \ +FORECAST_DURATION=${FORECAST_DURATION} \ +TIME_DIMENSIONS=${TIME_DIMENSIONS} \ +INFERENCE_WORKDIR=${INFERENCE_WORKDIR} \ +uv run python src/create_inference_dataset.py ## 2. Create graph -uv run python -m neural_lam.create_graph --config_path inference/config.yaml +# TODO: could cache this, although that isn't implemented at the moment +uv run python -m neural_lam.create_graph --config_path ${INFERENCE_WORKDIR}/config.yaml \ + --name ${GRAPH_NAME} ${CREATE_GRAPH_ARG} ## 3. Run inference -uv run python -m neural_lam.train_model --config_path inference/config.yaml \ - --eval \ - --graph multiscale \ - --hidden_dim 2 \ +# NB: parallel write of zarr over multiple GPUs not implemented yet, so can ony use one gpu for now +uv run python -m neural_lam.train_model --config_path ${INFERENCE_WORKDIR}/config.yaml \ + --eval test\ + --devices 1\ + --graph ${GRAPH_NAME} \ + --hidden_dim ${NUM_HIDDEN_DIMS} \ --ar_steps_eval ${NUM_EVAL_STEPS} \ - --load ${INFERENCE_ARTIFACT_PATH}/checkpoint.ckpt \ + --val_steps_to_log \ + --load ${INFERENCE_ARTIFACT_PATH}/checkpoint.pkl \ --save_eval_to_zarr_path ${OUTPUT_DATASETS_ROOT_PATH}/inference_output.zarr ## 4. Transform inference output back to original grid and variables @@ -49,5 +115,10 @@ uv run python -m neural_lam.train_model --config_path inference/config.yaml \ # means that we will have `danra_surface.zarr` in this case. We rename name # that manually here but maybe mllam-data-prep should be able to merge inputs # originating from the same zarr dataset path? -uv run python -m mllam_data_prep.recreate_inputs ${OUTPUT_DATASETS_ROOT_PATH}/inference_output.zarr -rename ${OUTPUT_DATASETS_ROOT_PATH}/danra_surface.zarr ${OUTPUT_DATASETS_ROOT_PATH}/single_levels.zarr +uv run python -m mllam_data_prep.recreate_inputs \ + --config-path ${INFERENCE_WORKDIR}/danra.datastore.yaml \ + --output-path-format "${OUTPUT_DATASETS_ROOT_PATH}/{input_name}.zarr" \ + ${OUTPUT_DATASETS_ROOT_PATH}/inference_output.zarr + +echo "Renaming ${OUTPUT_DATASETS_ROOT_PATH}/danra_surface.zarr to ${OUTPUT_DATASETS_ROOT_PATH}/single_levels.zarr" +mv ${OUTPUT_DATASETS_ROOT_PATH}/danra_surface.zarr ${OUTPUT_DATASETS_ROOT_PATH}/single_levels.zarr diff --git a/configurations/surface-dummy-model_DINI/meta.yaml b/configurations/surface-dummy-model_DINI/meta.yaml deleted file mode 100644 index e69de29..0000000 diff --git a/configurations/surface-dummy-model_DINI/pyproject.toml b/configurations/surface-dummy-model_DINI/pyproject.toml index b72f305..7ee2dfb 100644 --- a/configurations/surface-dummy-model_DINI/pyproject.toml +++ b/configurations/surface-dummy-model_DINI/pyproject.toml @@ -1,7 +1,39 @@ [project] name = "surface-dummy-model_DINI" version = "0.1.0" -requires-python = ">=3.10" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.11" +authors = [ + {name = "Leif Denby", email = "lcd@dmi.dk"}, + {name = "Kasper Hintz", email = "kah@dmi.dk"}, +] dependencies = [ - "neural-lam @ git+https://github.com/joeloskarsson/neural-lam-dev.git@aab339f" + "parse>=1.20.2", + "dask>=2025.4.1", + "dotenv>=0.9.9", + "ipdb>=0.13.13", + "s3fs>=2025.3.2", + "tqdm>=4.67.1", + "universal-pathlib>=0.2.6", + "zarr>=3.0", + "ipython>=8.37.0", + "mllam-data-prep", + "neural-lam", +] + +[dependency-groups] +dev = [ + "pre-commit>=4.2.0", + "pytest>=8.3.5", ] + +[tool.isort] +profile = "black" + +[tool.uv.sources] +mllam-data-prep = { git = "https://github.com/leifdenby/mllam-data-prep", rev = "feat/inference-cli-args" } +neural-lam = { git = "https://github.com/leifdenby/neural-lam", rev = "dev/first-inference-image" } +[build-system] +requires = ["setuptools>=61", "setuptools_scm"] +build-backend = "setuptools.build_meta" diff --git a/configurations/surface-dummy-model_DINI/run_inference_with_dini.sh b/configurations/surface-dummy-model_DINI/run_inference_with_dini.sh new file mode 100644 index 0000000..8dd18cc --- /dev/null +++ b/configurations/surface-dummy-model_DINI/run_inference_with_dini.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +# This script runs the inference container using initial conditions from DINI +# stored on AWS + +# The script takes only one argument: the analysis time to use for +# inference, in ISO8601 format (e.g. 2025-11-05T090000Z). + +if [ "$#" -ne 1 ]; then + echo "Usage: $0 " + exit 1 +fi +ANALYSIS_TIME="$1" + +DINI_ZARR="s3://harmonie-zarr/dini/control/${ANALYSIS_TIME}/single_levels.zarr/" +DATASTORE_INPUT_PATHS="danra.danra_surface=${DINI_ZARR},danra.danra_static=${DINI_ZARR}" + +ANALYSIS_TIME=${ANALYSIS_TIME} +DATASTORE_INPUT_PATHS=${DATASTORE_INPUT_PATHS} +TIME_DIMENSIONS=time + +podman run --rm \ + --device /dev/nvidia0 \ + --device /dev/nvidiactl \ + --device /dev/nvidia-uvm \ + --device /dev/nvidia-uvm-tools \ + --device /dev/nvidia-modeset \ + -v /lib/x86_64-linux-gnu/libcuda.so.1:/lib/x86_64-linux-gnu/libcuda.so.1:ro \ + -v /lib/x86_64-linux-gnu/libnvidia-ml.so.1:/lib/x86_64-linux-gnu/libnvidia-ml.so.1:ro \ + -v /lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1:/lib/x86_64-linux-gnu/libnvidia-ptxjitcompiler.so.1:ro \ + --shm-size=32g \ + -v ./inference_workdir/:/workspace/inference_workdir/ \ + -e DATASTORE_INPUT_PATHS="${DATASTORE_INPUT_PATHS}" \ + -e TIME_DIMENSIONS="${TIME_DIMENSIONS}" \ + -e ANALYSIS_TIME="${ANALYSIS_TIME}" \ + -e FORECAST_DURATION="PT18H" \ + -e NUM_EVAL_STEPS=6 \ + localhost/surface-dummy-model_dini:latest diff --git a/configurations/surface-dummy-model_DINI/src/create_inference_dataset.py b/configurations/surface-dummy-model_DINI/src/create_inference_dataset.py new file mode 100644 index 0000000..aae1692 --- /dev/null +++ b/configurations/surface-dummy-model_DINI/src/create_inference_dataset.py @@ -0,0 +1,459 @@ +import copy +import datetime +import os +from pathlib import Path +from typing import Dict + +import isodate +import mllam_data_prep as mdp +import mllam_data_prep.config as mdp_config +import parse +import xarray as xr +from loguru import logger +from neural_lam.config import DatastoreSelection, NeuralLAMConfig + +FP_TRAINING_CONFIG = "inference_artifact/configs/config.yaml" +DATASTORE_INPUT_PATH_FORMAT = "{datastore_name}.{input_name}={input_path}" + + +def _parse_datastore_input_paths(s: str) -> Dict[str, Dict[str, str]]: + """ + Parse a comma-separated list of {datastore_name}.{input_name}={input_path} + into a dictionary of dictionaries. + + Parameters + ---------- + s : str + The string to parse. + + Returns + ------- + Dict[str, Dict[str, str]] + A dictionary of dictionaries. + """ + result = {} + for item in s.split(","): + parts = parse.parse(DATASTORE_INPUT_PATH_FORMAT, item) + if parts is None: + raise ValueError( + f"Invalid format for DATASTORE_INPUT_PATHS item: {item}. " + f"Expected format is {DATASTORE_INPUT_PATH_FORMAT}" + ) + datastore_name = parts["datastore_name"] + input_name = parts["input_name"] + input_path = parts["input_path"] + + if datastore_name not in result: + result[datastore_name] = {} + elif input_name in result[datastore_name]: + raise ValueError( + f"Duplicate input name {input_name} for datastore " + f"{datastore_name} in DATASTORE_INPUT_PATHS" + ) + result[datastore_name][input_name] = input_path + return result + + +REQUIRED_ENV_VARS = { + # comma-separated list of {datastore_name}:{input_name}={input_path} + "DATASTORE_INPUT_PATHS": _parse_datastore_input_paths, + # iso8601 datetime string, e.g. 2019-02-04T12:00+0000 + "ANALYSIS_TIME": isodate.parse_datetime, + # iso8160 duration string, e.g. PT6H for 6 hours + "FORECAST_DURATION": isodate.parse_duration, + # comma-separated list of time dimensions to replace, e.g. + # time,forecast_reference_time + "TIME_DIMENSIONS": lambda s: s.split(","), + # inference working directory, relative to where inference config and + # datasets are saved + "INFERENCE_WORKDIR": str, +} + + +def _parse_env_vars() -> Dict[str, any]: + """ + Parse and validate required environment variables. + + Returns + ------- + Dict[str, any] + A dictionary of parsed environment variables. + """ + env_vars = {} + for var, parser in REQUIRED_ENV_VARS.items(): + value = os.getenv(var) + if value is None: + raise EnvironmentError(f"Environment variable {var} is not set.") + try: + env_vars[var] = parser(value) + except Exception as e: + raise ValueError(f"Error parsing environment variable {var}: {e}") + return env_vars + + +def _create_inference_datastore_config( + training_config: mdp.Config, + forecast_analysis_time: datetime.datetime, + forecast_duration: datetime.timedelta, + time_dimensions: list[str], + overwrite_input_paths: Dict[str, str] = {}, +) -> mdp.Config: + """ + From a training datastore config, create an inference datastore config that: + - samples along a new sampling dimension `sampling_dim` (default: + `analysis_time`) instead of `time` + - has a single split called "test" with a single time slice given by the + `forecast_analysis_time` argument + - optionally overwrites input paths with the `overwrite_input_paths` argument + - ensures that the output variables have the correct dimensions, for example + replacing `time` with [`analysis_time`, `elapsed_forecast_duration`] + - ensures that the input datasets have the correct dimensions and dim_mappings, + i.e. replacing `time` with [`analysis_time`, `elapsed_forecast_duration` + + Parameters + ---------- + training_config : mdp.Config + The training config to base the inference config on + forecast_analysis_time : datetime.datetime + The analysis time to use for the inference config + forecast_duration : datetime.timedelta + The forecast duration to use for the inference config + time_dimensions : list[str], optional + The list of time dimensions to replace `time` with, for example + replacing `time` with [`analysis_time`, `elapsed_forecast_duration`], + the first dimension is assumed to be the sampling dimension (e.g. the + analysis time) + overwrite_input_paths : Dict[str, str], optional + A dictionary of input names and paths to overwrite in the training config, + by default {} + + Returns + ------- + mdp.Config + The inference config + """ + # the new sampling dimension is `analysis_time` + old_sampling_dim = "time" + if not isinstance(time_dimensions, list) or len(time_dimensions) == 0: + raise ValueError( + "time_dimensions must be a non-empty list of strings, got " + f"{time_dimensions}" + ) + sampling_dim = time_dimensions[0] + # instead of only having `time` as dimension, the input forecast datasets + # have two dimensions that describe the time value [analysis_time, + # elapsed_forecast_duration] + dim_replacements = dict( + time=time_dimensions, + ) + # there will be a single split called "test" + # split_name = "test" + # which will have a single time slice, given by the analysis time argument + # to the script + sampling_coord_range = dict( + start=forecast_analysis_time, + end=forecast_analysis_time + forecast_duration, + ) + + inference_config = copy.deepcopy(training_config) + + if len(overwrite_input_paths) > 0: + for key, value in overwrite_input_paths.items(): + if key not in training_config.inputs: + raise ValueError( + f"Key {key} not found in config inputs. " + f"Available keys are: {list(training_config.inputs.keys())}" + ) + logger.info( + f"Overwriting input path for {key} with {value} previously " + f"{training_config.inputs[key].path}" + ) + inference_config.inputs[key].path = value + + # setup the split (test) for the dataset with a coordinate range along the + # sampling dimension (analysis_time) of length 1 + # inference_config.output.splitting = mdp_config.Splitting( + # dim=sampling_dim, + # splits={split_name: mdp_config.Split(**sampling_coord_range)}, + # ) + + # XXX: currently (as of 0.4.0) neural-lam requires that `train`, `val` and + # `test` splits are always present, even if they are not used. So we + # create empty `train` and `val` splits here + inference_config.output.splitting = mdp_config.Splitting( + dim="time", + splits={ + "train": mdp_config.Split( + start=forecast_analysis_time, end=forecast_analysis_time + ), + "val": mdp_config.Split( + start=forecast_analysis_time, end=forecast_analysis_time + ), + "test": mdp_config.Split( + start=forecast_analysis_time, + end=forecast_analysis_time + forecast_duration, + ), + }, + ) + + # ensure the output data is sampled along the sampling dimension + # (analysis_time) too + inference_config.output.coord_ranges = { + sampling_dim: mdp_config.Range(**sampling_coord_range) + } + + inference_config.output.chunking = {sampling_dim: 1} + + # replace old sampling_dimension (time) dimension in outputs with + # [`analysis_time`, `elapsed_forecast_time`] + for variable, dims in training_config.output.variables.items(): + if old_sampling_dim in dims: + orig_sampling_dim_index = dims.index(old_sampling_dim) + dims.remove(old_sampling_dim) + for dim in dim_replacements[old_sampling_dim][::-1]: + dims.insert(orig_sampling_dim_index, dim) + inference_config.output.variables[variable] = dims + logger.info( + f"Replaced {old_sampling_dim} dimension with" + f" {dim_replacements[old_sampling_dim]} for {variable}" + ) + + # these dimensions should also be "renamed" from the input datasets + for input_name in training_config.inputs.keys(): + if "time" in training_config.inputs[input_name].dim_mapping: + dims = training_config.inputs[input_name].dims + orig_sampling_dim_index = dims.index(old_sampling_dim) + dims.remove(old_sampling_dim) + for dim in dim_replacements[old_sampling_dim][::-1]: + dims.insert(orig_sampling_dim_index, dim) + inference_config.inputs[input_name].dims = dims + + del inference_config.inputs[input_name].dim_mapping[ + old_sampling_dim + ] + + # add new "rename" dim-mappins for `analysis_time` and + # `elapsed_forecast_duration` + for dim in dim_replacements[old_sampling_dim]: + inference_config.inputs[input_name].dim_mapping[ + dim + ] = mdp_config.DimMapping(method="rename", dim=dim) + + return inference_config + + +def _prepare_inference_dataset_zarr( + datastore_name: str, + datastore_input_paths: Dict[str, str], + fp_inference_workdir: str, + analysis_time: datetime.datetime, + forecast_duration: datetime.timedelta, + time_dimensions: list[str], +) -> str: + """ + Prepare the inference dataset for a single datastore. + + Parameters + ---------- + datastore_name : str + The name of the datastore to prepare the inference dataset for, this + sets the expected path of the training datastore config and stats. + datastore_input_paths : Dict[str, str] + A dictionary of input names and paths to overwrite in the training + config. + fp_inference_workdir : str + The path to the inference working directory, where the inference + datastore config(s) and zarr dataset(s) will be saved. + analysis_time : datetime.datetime + The analysis time to use for the inference dataset. + forecast_duration : datetime.timedelta + The forecast duration to use for the inference dataset. + time_dimensions : list[str] + The list of time dimensions to replace `time` with, for example + replacing `time` with [`analysis_time`, `elapsed_forecast_duration`] + + Returns + ------- + str + The path to the inference datastore config file. The inference dataset + is saved as a zarr store in the same directory as the config file, with + the same name but with a .zarr extension instead of .yaml. + """ + fp_training_datastore_stats = ( + f"inference_artifact/stats/{datastore_name}.datastore.stats.zarr" + ) + ds_stats = xr.open_dataset(fp_training_datastore_stats) + logger.debug(f"Opened stats dataset: {ds_stats}") + + fp_training_datastore_config = ( + f"inference_artifact/configs/{datastore_name}.datastore.yaml" + ) + + logger.debug( + f"Loading training datastore config from {fp_training_datastore_config}" + ) + datastore_training_config = mdp.Config.from_yaml_file( + fp_training_datastore_config + ) + + inference_config = _create_inference_datastore_config( + training_config=datastore_training_config, + forecast_analysis_time=analysis_time, + forecast_duration=forecast_duration, + overwrite_input_paths=datastore_input_paths, + time_dimensions=time_dimensions, + ) + + fp_inference_datastore_config = ( + f"{fp_inference_workdir}/{datastore_name}.datastore.yaml" + ) + + Path(fp_inference_datastore_config).parent.mkdir( + parents=True, exist_ok=True + ) + logger.info( + f"Saving inference datastore config to {fp_inference_datastore_config}" + ) + + # neural-lam's convention is to have the same name for the zarr store + # as the config file, but with .zarr extension + fp_dataset = fp_inference_datastore_config.replace(".yaml", ".zarr") + inference_config.to_yaml_file(fp_inference_datastore_config) + + ds = mdp.create_dataset(config=inference_config, ds_stats=ds_stats) + logger.info(f"Writing inference dataset to {fp_dataset}") + ds.to_zarr(fp_dataset) + + return fp_inference_datastore_config + + +def _prepare_all_inference_dataset_zarr( + analysis_time: datetime.datetime, + forecast_duration: datetime.timedelta, + datastore_input_paths: Dict[str, Dict[str, str]], + fp_inference_workdir: str, + time_dimensions: list[str], +) -> str: + """ + Prepare the inference dataset. + + Parameters + ---------- + analysis_time : datetime.datetime + The analysis time to use for the inference dataset(s). + forecast_duration : datetime.timedelta + The forecast duration to use for the inference dataset(s). + datastore_input_paths : Dict[str, Dict[str,str]] + A dictionary of datastore names and their corresponding input names + and paths to overwrite in the training config. + fp_inference_workdir : str + The path to the inference working directory, where the inference + datastore config(s) and zarr dataset(s) will be saved. + time_dimensions : list[str] + The list of time dimensions to replace `time` with, for example + replacing `time` with [`analysis_time`, `elapsed_forecast_duration`] + + Returns + ------- + Dict[str, str] + A dictionary of datastore names and the path to their corresponding + inference datastore config file. The inference dataset is saved as a + zarr store in the same directory as the config file, with the same + name but with a .zarr extension instead of .yaml. + """ + fps_datastore_configs = {} + for datastore_name, input_paths in datastore_input_paths.items(): + logger.info(f"Processing {datastore_name} datastore for inference") + fp_training_datastore_config = _prepare_inference_dataset_zarr( + datastore_name=datastore_name, + datastore_input_paths=input_paths, + fp_inference_workdir=fp_inference_workdir, + analysis_time=analysis_time, + forecast_duration=forecast_duration, + time_dimensions=time_dimensions, + ) + + fps_datastore_configs[datastore_name] = fp_training_datastore_config + + return fps_datastore_configs + + +def _create_inference_config( + fps_inference_datastore_config: Dict[str, str], fp_inference_workdir: str +) -> str: + training_config = NeuralLAMConfig.from_yaml_file(FP_TRAINING_CONFIG) + inference_config = copy.deepcopy(training_config) + + fp_inference_config = f"{fp_inference_workdir}/config.yaml" + + def _set_datastore_config_path(node: DatastoreSelection, fp: str): + node.config_path = Path(fp).relative_to( + Path(fp_inference_config).parent + ) + # XXX: There is a bug in neural-lam here that means that the datastore kind + # doesn't correctly get serialised to a string in the config file when + # saved to yaml + node.kind = str(node.kind) + + # see if the neural-lam config was for single or multiple datastores + if hasattr(training_config, "datastores"): + # using multiple datastores + for ( + datastore_name, + fp_datastore_config, + ) in fps_inference_datastore_config.items(): + if datastore_name not in inference_config.datastores: + raise ValueError( + f"Datastore {datastore_name} not found in training config. " + f"Available datastores are: " + f"{list(inference_config.datastores.keys())}" + ) + _set_datastore_config_path( + node=inference_config.datastores[datastore_name], + fp=fp_datastore_config, + ) + else: + fp_datastore_config = list(fps_inference_datastore_config.values())[0] + # using a single datastore + _set_datastore_config_path( + node=inference_config.datastore, fp=fp_datastore_config + ) + + inference_config.to_yaml_file(fp_inference_config) + logger.info(f"Saved inference config to {fp_inference_config}") + + return fp_inference_config + + +@logger.catch(reraise=True) +def main(): + env_vars = _parse_env_vars() + # convert analysis time to UTC and strip timezone info + analysis_time = ( + env_vars["ANALYSIS_TIME"] + .astimezone(datetime.timezone.utc) + .replace(tzinfo=None) + ) + + fps_inference_datastore_config = _prepare_all_inference_dataset_zarr( + analysis_time=analysis_time, + forecast_duration=env_vars["FORECAST_DURATION"], + datastore_input_paths=env_vars["DATASTORE_INPUT_PATHS"], + fp_inference_workdir=env_vars["INFERENCE_WORKDIR"], + time_dimensions=env_vars["TIME_DIMENSIONS"], + ) + _create_inference_config( + fps_inference_datastore_config=fps_inference_datastore_config, + fp_inference_workdir=env_vars["INFERENCE_WORKDIR"], + ) + + +if __name__ == "__main__": + with_debugger = os.getenv("MLWM_DEBUGGER", "0") + if with_debugger == "ipdb": + import ipdb + + with ipdb.launch_ipdb_on_exception(): + main() + else: + main() diff --git a/pyproject.toml b/pyproject.toml index 7d8b728..d3d52af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "mlwm-deployment" dynamic = ["version"] description = "Add your description here" readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.11" dependencies = [ "parse>=1.20.2", "dask>=2025.4.1", @@ -15,11 +15,14 @@ dependencies = [ "s3fs>=2025.3.2", "tqdm>=4.67.1", "universal-pathlib>=0.2.6", - "zarr", + "zarr>=3.0.0", + "xarray>=2025.5.0", ] [dependency-groups] dev = [ + "ipykernel>=6.30.1", + "jinja2>=3.1.6", "pre-commit>=4.2.0", "pytest>=8.3.5", ] diff --git a/src/mlwm/tests/test_paths.py b/src/mlwm/tests/test_paths.py index e8eaa9e..91adf24 100644 --- a/src/mlwm/tests/test_paths.py +++ b/src/mlwm/tests/test_paths.py @@ -1,8 +1,7 @@ import datetime -import pytest - import mlwm.paths as mlwm_paths +import pytest @pytest.mark.parametrize(