Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ jobs:
- name: Install dependencies
run: |
pip install -U pip # upgrade pip
pip install '.[develop]' './lstm_ewts'
pip install git+https://github.com/ngwpc/nwm-ewts.git#subdirectory=runtime/python/ewts
pip install '.[develop]'
- name: Echo dependency versions
run: |
pip freeze
Expand Down
113 changes: 100 additions & 13 deletions lstm/bmi_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,26 @@
import typing
from dataclasses import dataclass
from pathlib import Path
import logging

import numpy as np
import numpy.typing as npt
import pandas as pd
import pickle
import torch
import yaml
try:
from yaml import CSafeLoader as SafeLoader
except ImportError:
from yaml import SafeLoader

from . import nextgen_cuda_lstm
from .base import BmiBase
from lstm_ewts import configure_logging, MODULE_NAME
from .model_state import State, StateFacade, Var

import logging
LOG = logging.getLogger(MODULE_NAME)
import ewts
LOG = ewts.get_logger(ewts.LSTM_ID)


# -------------- Dynamic Attributes -----------------------------
_dynamic_input_vars = [
Expand Down Expand Up @@ -147,7 +153,7 @@ def __init__(self, cfg: dict[str, typing.Any], output_scaling_factor_cms: float)
# load training feature scales
scaler_file = cfg["run_dir"] / "train_data/train_data_scaler.yml"
with scaler_file.open("r") as fp:
train_data_scaler = yaml.safe_load(fp)
train_data_scaler = yaml.load(fp, Loader=SafeLoader)
self.scalars = load_training_scalars(cfg, train_data_scaler)

# initialize torch lstm object
Expand Down Expand Up @@ -200,6 +206,16 @@ def update(self, state: Valuer) -> typing.Iterable[Var]:
precipitation_mm_h
)

def serialize(self):
return {
"h": self.h_t.numpy(),
"c": self.c_t.numpy()
}

def deserialize(self, data: dict):
self.h_t = torch.from_numpy(data["h"])
self.c_t = torch.from_numpy(data["c"])


def bmi_array(arr: list[float]) -> npt.NDArray:
"""Trivial wrapper function to ensure the expected numpy array datatype is used."""
Expand Down Expand Up @@ -419,16 +435,20 @@ def __init__(self) -> None:
self.cfg_bmi: dict[str, typing.Any]
self.ensemble_members: list[EnsembleMember]

def initialize(self, config_file: str) -> None:
# statically stored seriaized data
self._serialized_size = np.array([0], dtype=np.uint64)
self._serialized = np.array([], dtype=np.uint8)

# configure the Error Warning and Trapping System logger
configure_logging()
def initialize(self, config_file: str) -> None:

# This is required prior to the first log message is issued by t-route.
LOG.bind()

LOG.info(f"Initializing with {config_file}")

# read and setup main configuration file
with open(config_file, "r") as fp:
self.cfg_bmi = yaml.safe_load(fp)
self.cfg_bmi = yaml.load(fp, Loader=SafeLoader)
coerce_config(self.cfg_bmi)

# ----------- The output is area normalized, this is needed to un-normalize it
Expand All @@ -440,7 +460,7 @@ def initialize(self, config_file: str) -> None:
# initialize ensemble members
self.ensemble_members = []
for member_cfg_file in self.cfg_bmi["train_cfg_file"]:
cfg = yaml.safe_load(member_cfg_file.read_text())
cfg = yaml.load(member_cfg_file.read_text(), Loader=SafeLoader)
coerce_config(cfg)
member = EnsembleMember(cfg, output_factor_cms)
self.ensemble_members.append(member)
Expand Down Expand Up @@ -515,16 +535,38 @@ def get_var_grid(self, name: str) -> int:
return 0

def get_var_type(self, name: str) -> str:
if name == "serialization_state":
return self._serialized.dtype.name
elif name == "serialization_size" or name == "serialization_create":
return self._serialized_size.dtype.name
elif name == "serialization_free":
return np.dtype(np.intc).name
elif name == "reset_time":
return np.dtype(np.double).name
return self.get_value_ptr(name).dtype.name

def get_var_units(self, name: str) -> str:
return first_containing(name, self._outputs, self._dynamic_inputs).unit(name)

def get_var_itemsize(self, name: str) -> int:
if name == "serialization_state":
return self._serialized.dtype.itemsize
if name == "serialization_size" or name == "serialization_create":
return self._serialized_size.dtype.itemsize
if name == "serialization_free":
return np.dtype(np.intc).itemsize
if name == "reset_time":
return np.dtype(np.double).itemsize
return self.get_value_ptr(name).itemsize

def get_var_nbytes(self, name: str) -> int:
return self.get_var_itemsize(name) * len(self.get_value_ptr(name))
if name == "serialization_create":
return self._serialized_size.nbytes
if name == "serialization_free":
return np.dtype(np.intc).itemsize
if name == "reset_time":
return np.dtype(np.double).itemsize
return self.get_value_ptr(name).nbytes

def get_var_location(self, name: str) -> str:
# raises KeyError on failure
Expand Down Expand Up @@ -553,6 +595,10 @@ def get_value(self, name: str, dest: np.ndarray) -> np.ndarray:

def get_value_ptr(self, name: str) -> np.ndarray:
"""Returns a _reference_ to a variable's np.NDArray."""
if name == "serialization_state":
return self._serialized
elif name == "serialization_size":
return self._serialized_size
return first_containing(name, self._outputs, self._dynamic_inputs).value(name)

def get_value_at_indices(
Expand All @@ -563,9 +609,18 @@ def get_value_at_indices(
).value_at_indices(name, dest, inds)

def set_value(self, name: str, src: np.ndarray) -> None:
return first_containing(name, self._outputs, self._dynamic_inputs).set_value(
name, src
)
if name == "serialization_state":
self._deserialize(src)
elif name == "serialization_create":
self._serialize()
elif name == "serialization_free":
self._free_serialized()
elif name == "reset_time":
self._timestep = 0
else:
return first_containing(name, self._outputs, self._dynamic_inputs).set_value(
name, src
)

def set_value_at_indices(
self, name: str, inds: np.ndarray, src: np.ndarray
Expand Down Expand Up @@ -593,6 +648,38 @@ def get_grid_type(self, grid: int) -> str:
return "scalar"
raise RuntimeError(f"unsupported grid type: {grid!s}. only support 0")

def _serialize(self):
"""Convert all dynamic properties that can change after the `bmi_LSTM` has had `initialize()` called into an object that can be serialized through `pickle`.
Then, set the BMI's `_serialized` property to the byte representation of that pickled data and adjust the static `_serialized_size` property."""
data = {
"dynamic_inputs": self._dynamic_inputs.serialize(),
"static_inputs": self._static_inputs.serialize(),
"outputs": self._outputs.serialize(),
"ensemble": [em.serialize() for em in self.ensemble_members],
"timestep": self._timestep,
}
serialized = pickle.dumps(data)
self._serialized = np.array(bytearray(serialized), dtype=np.uint8)
self._serialized_size[0] = len(self._serialized)

def _deserialize(self, array: np.ndarray):
"""Interpret the bytes of the numpy array as previously pickled data from `_serialize()` and update the current values.
No data structure check will be made on the input array or loaded bytes. It will be assumed that the input data is of the same structure as what is generated from `_serialize()`."""
data = bytes(array)
deserialized = pickle.loads(data)
self._dynamic_inputs.deserialize(deserialized["dynamic_inputs"])
self._static_inputs.deserialize(deserialized["static_inputs"])
self._outputs.deserialize(deserialized["outputs"])
for bmi_em, data_em in zip(self.ensemble_members, deserialized["ensemble"], strict=True):
bmi_em.deserialize(data_em)
self._timestep = deserialized["timestep"]
self._free_serialized()

def _free_serialized(self):
"""Clear the current serialized data and set the size property value to 0."""
self._serialized_size[0] = 0
self._serialized = np.array([], dtype=self._serialized.dtype)


def coerce_config(cfg: dict[str, typing.Any]):
for key, val in cfg.items():
Expand Down
17 changes: 17 additions & 0 deletions lstm/model_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,23 @@ def __iter__(self) -> typing.Iterator[Var]:
def __len__(self) -> int:
return len(self._name_mapping)

def serialize(self):
"""Return the State represented as a list of dicts representing the `Var` properties."""
return [
{"name": var.name, "unit": var.unit, "value": var.value}
for var in self._name_mapping.values()
]

def deserialize(self, values: list):
"""Replace the current Vars with values from the intput list."""
self._name_mapping.clear()
for var in values:
self._name_mapping[var["name"]] = Var(
name=var["name"],
unit=var["unit"],
value=var["value"]
)


class StateFacade:
"""
Expand Down
13 changes: 0 additions & 13 deletions lstm_ewts/pyproject.toml

This file was deleted.

34 changes: 0 additions & 34 deletions lstm_ewts/src/lstm_ewts/__init__.py

This file was deleted.

Loading
Loading