From 9d21f2c2a5c93dae8a33796ed9bf63b790762e3c Mon Sep 17 00:00:00 2001 From: Josh Cunningham Date: Wed, 1 Oct 2025 21:04:07 -0500 Subject: [PATCH 01/10] Eager debug fix (#62) * fix debug message eager evaluation * use faster libyaml parser * Add libyaml fallback Apply suggestion from @aaraney Co-authored-by: Austin Raney * replace f-strings in logger --------- Co-authored-by: Austin Raney --- lstm/bmi_lstm.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/lstm/bmi_lstm.py b/lstm/bmi_lstm.py index 7e2697f..b1372a5 100644 --- a/lstm/bmi_lstm.py +++ b/lstm/bmi_lstm.py @@ -59,6 +59,10 @@ import pandas as pd import torch import yaml +try: + from yaml import CSafeLoader as SafeLoader +except ImportError: + from yaml import SafeLoader from . import nextgen_cuda_lstm from .base import BmiBase @@ -143,7 +147,7 @@ def __init__(self, cfg: dict[str, typing.Any], output_scaling_factor_cms: float) # load training feature scales scaler_file = cfg["run_dir"] / "train_data/train_data_scaler.yml" with scaler_file.open("r") as fp: - train_data_scaler = yaml.safe_load(fp) + train_data_scaler = yaml.load(fp, Loader=SafeLoader) self.scalars = load_training_scalars(cfg, train_data_scaler) # initialize torch lstm object @@ -290,14 +294,13 @@ def gather_inputs( value = state.value(bmi_name) assert value.size == 1, "`value` should a single scalar in a 1d array" input_list.append(value[0]) - - logger.debug(f" {lstm_name=}") - logger.debug(f" {bmi_name=}") - logger.debug(f" {type(value)=}") - logger.debug(f" {value=}") + logger.debug(" lstm_name=%s", lstm_name) + logger.debug(" bmi_name=%s", bmi_name) + logger.debug(" type(value)=%s", type(value)) + logger.debug(" value=%s", value) collected = bmi_array(input_list) - logger.debug(f"Collected inputs: {collected}") + logger.debug("Collected inputs: %s",collected) return collected @@ -310,10 +313,10 @@ def scale_inputs( # Center and scale the input values for use in torch input_array_scaled = (input - mean) / std - logger.debug(f"### input_array ={input}") - logger.debug(f"### dtype(input_array) ={input.dtype}") - logger.debug(f"### type(input_array_scaled) ={type(input_array_scaled)}") - logger.debug(f"### dtype(input_array_scaled) ={input_array_scaled.dtype}") + logger.debug("### input_array =%s", input) + logger.debug("### dtype(input_array) =%s", input.dtype) + logger.debug("### type(input_array_scaled) =%s", type(input_array_scaled)) + logger.debug("### dtype(input_array_scaled) =%s", input_array_scaled.dtype) return input_array_scaled @@ -324,7 +327,7 @@ def scale_outputs( output_std: npt.NDArray, output_scale_factor_cms: float, ): - logger.debug(f"model output: {output[0, 0, 0].numpy().tolist()}") + logger.debug("model output: %s", output[0, 0, 0].numpy().tolist()) if cfg["target_variables"][0] in ["qobs_mm_per_hour", "QObs(mm/hr)", "QObs(mm/h)"]: surface_runoff_mm = output[0, 0, 0].numpy() * output_std + output_mean @@ -405,7 +408,7 @@ def __init__(self) -> None: def initialize(self, config_file: str) -> None: # read and setup main configuration file with open(config_file, "r") as fp: - self.cfg_bmi = yaml.safe_load(fp) + self.cfg_bmi = yaml.load(fp, Loader=SafeLoader) coerce_config(self.cfg_bmi) # TODO: aaraney: config logging levels to python logging levels @@ -422,7 +425,7 @@ def initialize(self, config_file: str) -> None: # initialize ensemble members self.ensemble_members = [] for member_cfg_file in self.cfg_bmi["train_cfg_file"]: - cfg = yaml.safe_load(member_cfg_file.read_text()) + cfg = yaml.load(member_cfg_file.read_text(), Loader=SafeLoader) coerce_config(cfg) member = EnsembleMember(cfg, output_factor_cms) self.ensemble_members.append(member) @@ -458,7 +461,7 @@ def update(self) -> None: def update_until(self, time: float) -> None: if time <= self.get_current_time(): current_time = self.get_current_time() - logger.warning(f"no update performed: {time=} <= {current_time=}") + logger.warning("no update performed: time=%s <= current_time=%s", time, current_time) return None n_steps, remainder = divmod( @@ -467,7 +470,7 @@ def update_until(self, time: float) -> None: if remainder != 0: logger.warning( - f"time is not multiple of time step size. updating until: {time - remainder=} " + "time is not multiple of time step size. updating until: %s", (time - remainder) ) for _ in range(int(n_steps)): From 0691d769ed79ebdf0f5442e49996ce820d949aae Mon Sep 17 00:00:00 2001 From: Ian Todd Date: Tue, 27 Jan 2026 08:54:03 -0500 Subject: [PATCH 02/10] Implement serialization --- lstm/bmi_lstm.py | 88 ++++++++++++++++++++++++++++++++++++++++++--- lstm/model_state.py | 17 +++++++++ 2 files changed, 101 insertions(+), 4 deletions(-) diff --git a/lstm/bmi_lstm.py b/lstm/bmi_lstm.py index 85422e5..7e0aaf1 100644 --- a/lstm/bmi_lstm.py +++ b/lstm/bmi_lstm.py @@ -57,6 +57,7 @@ import numpy as np import numpy.typing as npt import pandas as pd +import pickle import torch import yaml @@ -200,6 +201,16 @@ def update(self, state: Valuer) -> typing.Iterable[Var]: precipitation_mm_h ) + def serialize(self): + return { + "h": self.h_t.numpy(), + "c": self.c_t.numpy() + } + + def deserialize(self, data: dict): + self.h_t = torch.from_numpy(data["h"]) + self.c_t = torch.from_numpy(data["c"]) + def bmi_array(arr: list[float]) -> npt.NDArray: """Trivial wrapper function to ensure the expected numpy array datatype is used.""" @@ -419,6 +430,10 @@ def __init__(self) -> None: self.cfg_bmi: dict[str, typing.Any] self.ensemble_members: list[EnsembleMember] + # statically stored seriaized data + self._serialized_size = np.array([0], dtype=np.uint64) + self._serialized = np.array([], dtype=np.uint8) + def initialize(self, config_file: str) -> None: # configure the Error Warning and Trapping System logger @@ -515,16 +530,38 @@ def get_var_grid(self, name: str) -> int: return 0 def get_var_type(self, name: str) -> str: + if name == "serialization_state": + return self._serialized.dtype.name + elif name == "serialization_size" or name == "serialization_create": + return self._serialized_size.dtype.name + elif name == "serialization_free": + return np.dtype(np.intc).name + elif name == "reset_time": + return np.dtype(np.double).name return self.get_value_ptr(name).dtype.name def get_var_units(self, name: str) -> str: return first_containing(name, self._outputs, self._dynamic_inputs).unit(name) def get_var_itemsize(self, name: str) -> int: + if name == "serialization_state": + return self._serialized.dtype.itemsize + if name == "serialization_size" or name == "serialization_create": + return self._serialized_size.dtype.itemsize + if name == "serialization_free": + return np.dtype(np.intc).itemsize + if name == "reset_time": + return np.dtype(np.double).itemsize return self.get_value_ptr(name).itemsize def get_var_nbytes(self, name: str) -> int: - return self.get_var_itemsize(name) * len(self.get_value_ptr(name)) + if name == "serialization_create": + return self._serialized_size.nbytes + if name == "serialization_free": + return np.dtype(np.intc).itemsize + if name == "reset_time": + return np.dtype(np.double).itemsize + return self.get_value_ptr(name).nbytes def get_var_location(self, name: str) -> str: # raises KeyError on failure @@ -553,6 +590,10 @@ def get_value(self, name: str, dest: np.ndarray) -> np.ndarray: def get_value_ptr(self, name: str) -> np.ndarray: """Returns a _reference_ to a variable's np.NDArray.""" + if name == "serialization_state": + return self._serialized + elif name == "serialization_size": + return self._serialized_size return first_containing(name, self._outputs, self._dynamic_inputs).value(name) def get_value_at_indices( @@ -563,9 +604,18 @@ def get_value_at_indices( ).value_at_indices(name, dest, inds) def set_value(self, name: str, src: np.ndarray) -> None: - return first_containing(name, self._outputs, self._dynamic_inputs).set_value( - name, src - ) + if name == "serialization_state": + deserialize_bmi(src, self) + elif name == "serialization_create": + serialize_bmi(self) + elif name == "serialization_free": + free_serialized_bmi(self) + elif name == "reset_time": + self._timestep = 0 + else: + return first_containing(name, self._outputs, self._dynamic_inputs).set_value( + name, src + ) def set_value_at_indices( self, name: str, inds: np.ndarray, src: np.ndarray @@ -594,6 +644,36 @@ def get_grid_type(self, grid: int) -> str: raise RuntimeError(f"unsupported grid type: {grid!s}. only support 0") +def serialize_bmi(bmi: bmi_LSTM): + data = { + "dynamic_inputs": bmi._dynamic_inputs.serialize(), + "static_inputs": bmi._static_inputs.serialize(), + "outputs": bmi._outputs.serialize(), + "ensemble": [em.serialize() for em in bmi.ensemble_members], + "timestep": bmi._timestep, + } + serialized = pickle.dumps(data) + bmi._serialized = np.array(bytearray(serialized), dtype=np.uint8) + bmi._serialized_size[0] = len(bmi._serialized) + + +def deserialize_bmi(array: np.ndarray, bmi: bmi_LSTM): + data = bytes(array) + deserialized = pickle.loads(data) + bmi._dynamic_inputs.deserialize(deserialized["dynamic_inputs"]) + bmi._static_inputs.deserialize(deserialized["static_inputs"]) + bmi._outputs.deserialize(deserialized["outputs"]) + for bmi_em, data_em in zip(bmi.ensemble_members, deserialized["ensemble"], strict=True): + bmi_em.deserialize(data_em) + bmi._timestep = deserialized["timestep"] + free_serialized_bmi(bmi) + + +def free_serialized_bmi(bmi: bmi_LSTM): + bmi._serialized_size[0] = 0 + bmi._serialized = np.array([], dtype=bmi._serialized.dtype) + + def coerce_config(cfg: dict[str, typing.Any]): for key, val in cfg.items(): # Handle 'train_cfg_file' specifically to ensure it is always a list diff --git a/lstm/model_state.py b/lstm/model_state.py index c13fcfc..575d46d 100644 --- a/lstm/model_state.py +++ b/lstm/model_state.py @@ -94,6 +94,23 @@ def __iter__(self) -> typing.Iterator[Var]: def __len__(self) -> int: return len(self._name_mapping) + def serialize(self): + """Return the State represented as a list.""" + return [ + {"name": var.name, "unit": var.unit, "value": var.value} + for var in self._name_mapping.values() + ] + + def deserialize(self, values: list): + """Replace the current Vars with values from the intput list.""" + self._name_mapping.clear() + for var in values: + self._name_mapping[var["name"]] = Var( + name=var["name"], + unit=var["unit"], + value=var["value"] + ) + class StateFacade: """ From 72c2ab73d80652aac18a3d35b2b5af212eb88122 Mon Sep 17 00:00:00 2001 From: Ian Todd Date: Mon, 2 Feb 2026 11:15:34 -0500 Subject: [PATCH 03/10] Add docstrings --- lstm/bmi_lstm.py | 66 +++++++++++++++++++++++---------------------- lstm/model_state.py | 2 +- 2 files changed, 35 insertions(+), 33 deletions(-) diff --git a/lstm/bmi_lstm.py b/lstm/bmi_lstm.py index 7e0aaf1..1653f16 100644 --- a/lstm/bmi_lstm.py +++ b/lstm/bmi_lstm.py @@ -605,11 +605,11 @@ def get_value_at_indices( def set_value(self, name: str, src: np.ndarray) -> None: if name == "serialization_state": - deserialize_bmi(src, self) + self._deserialize(src, self) elif name == "serialization_create": - serialize_bmi(self) + self._serialize(self) elif name == "serialization_free": - free_serialized_bmi(self) + self._free_serialized(self) elif name == "reset_time": self._timestep = 0 else: @@ -643,35 +643,37 @@ def get_grid_type(self, grid: int) -> str: return "scalar" raise RuntimeError(f"unsupported grid type: {grid!s}. only support 0") - -def serialize_bmi(bmi: bmi_LSTM): - data = { - "dynamic_inputs": bmi._dynamic_inputs.serialize(), - "static_inputs": bmi._static_inputs.serialize(), - "outputs": bmi._outputs.serialize(), - "ensemble": [em.serialize() for em in bmi.ensemble_members], - "timestep": bmi._timestep, - } - serialized = pickle.dumps(data) - bmi._serialized = np.array(bytearray(serialized), dtype=np.uint8) - bmi._serialized_size[0] = len(bmi._serialized) - - -def deserialize_bmi(array: np.ndarray, bmi: bmi_LSTM): - data = bytes(array) - deserialized = pickle.loads(data) - bmi._dynamic_inputs.deserialize(deserialized["dynamic_inputs"]) - bmi._static_inputs.deserialize(deserialized["static_inputs"]) - bmi._outputs.deserialize(deserialized["outputs"]) - for bmi_em, data_em in zip(bmi.ensemble_members, deserialized["ensemble"], strict=True): - bmi_em.deserialize(data_em) - bmi._timestep = deserialized["timestep"] - free_serialized_bmi(bmi) - - -def free_serialized_bmi(bmi: bmi_LSTM): - bmi._serialized_size[0] = 0 - bmi._serialized = np.array([], dtype=bmi._serialized.dtype) + def _serialize(self): + """Convert all dynamic properties that can change after the `bmi_LSTM` has had `initialize()` called into an object that can be serialized through `pickle`. + Then, set the BMI's `_serialized` property to the byte representation of that pickled data and adjust the static `_serialized_size` property.""" + data = { + "dynamic_inputs": self._dynamic_inputs.serialize(), + "static_inputs": self._static_inputs.serialize(), + "outputs": self._outputs.serialize(), + "ensemble": [em.serialize() for em in self.ensemble_members], + "timestep": self._timestep, + } + serialized = pickle.dumps(data) + self._serialized = np.array(bytearray(serialized), dtype=np.uint8) + self._serialized_size[0] = len(self._serialized) + + def _deserialize(self, array: np.ndarray): + """Interpret the bytes of the numpy array as previously pickled data from `_serialize()` and update the current values. + No data structure check will be made on the input array or loaded bytes. It will be assumed that the input data is of the same structure as what is generated from `_serialize()`.""" + data = bytes(array) + deserialized = pickle.loads(data) + self._dynamic_inputs.deserialize(deserialized["dynamic_inputs"]) + self._static_inputs.deserialize(deserialized["static_inputs"]) + self._outputs.deserialize(deserialized["outputs"]) + for bmi_em, data_em in zip(self.ensemble_members, deserialized["ensemble"], strict=True): + bmi_em.deserialize(data_em) + self._timestep = deserialized["timestep"] + self._free_serialized() + + def _free_serialized(self): + """Clear the current serialized data and set the size property value to 0.""" + self._serialized_size[0] = 0 + self._serialized = np.array([], dtype=self._serialized.dtype) def coerce_config(cfg: dict[str, typing.Any]): diff --git a/lstm/model_state.py b/lstm/model_state.py index 575d46d..bbacbd6 100644 --- a/lstm/model_state.py +++ b/lstm/model_state.py @@ -95,7 +95,7 @@ def __len__(self) -> int: return len(self._name_mapping) def serialize(self): - """Return the State represented as a list.""" + """Return the State represented as a list of dicts representing the `Var` properties.""" return [ {"name": var.name, "unit": var.unit, "value": var.value} for var in self._name_mapping.values() From a3d5b4b66a853e5db9f118fd75556f5c52bba8af Mon Sep 17 00:00:00 2001 From: Ian Todd Date: Thu, 12 Feb 2026 14:54:33 -0500 Subject: [PATCH 04/10] Merge OWP master changes --- lstm/bmi_lstm.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/lstm/bmi_lstm.py b/lstm/bmi_lstm.py index 85422e5..372ba2a 100644 --- a/lstm/bmi_lstm.py +++ b/lstm/bmi_lstm.py @@ -59,6 +59,10 @@ import pandas as pd import torch import yaml +try: + from yaml import CSafeLoader as SafeLoader +except ImportError: + from yaml import SafeLoader from . import nextgen_cuda_lstm from .base import BmiBase @@ -147,7 +151,7 @@ def __init__(self, cfg: dict[str, typing.Any], output_scaling_factor_cms: float) # load training feature scales scaler_file = cfg["run_dir"] / "train_data/train_data_scaler.yml" with scaler_file.open("r") as fp: - train_data_scaler = yaml.safe_load(fp) + train_data_scaler = yaml.load(fp, Loader=SafeLoader) self.scalars = load_training_scalars(cfg, train_data_scaler) # initialize torch lstm object @@ -428,7 +432,7 @@ def initialize(self, config_file: str) -> None: # read and setup main configuration file with open(config_file, "r") as fp: - self.cfg_bmi = yaml.safe_load(fp) + self.cfg_bmi = yaml.load(fp, Loader=SafeLoader) coerce_config(self.cfg_bmi) # ----------- The output is area normalized, this is needed to un-normalize it @@ -440,7 +444,7 @@ def initialize(self, config_file: str) -> None: # initialize ensemble members self.ensemble_members = [] for member_cfg_file in self.cfg_bmi["train_cfg_file"]: - cfg = yaml.safe_load(member_cfg_file.read_text()) + cfg = yaml.load(member_cfg_file.read_text(), Loader=SafeLoader) coerce_config(cfg) member = EnsembleMember(cfg, output_factor_cms) self.ensemble_members.append(member) From 650d5ed6a3986c8fda103a121f6da724cb7220a7 Mon Sep 17 00:00:00 2001 From: Ian Todd Date: Wed, 18 Feb 2026 15:07:00 -0500 Subject: [PATCH 05/10] Remove self arguments from serialization methods --- lstm/bmi_lstm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lstm/bmi_lstm.py b/lstm/bmi_lstm.py index 1653f16..44b93b7 100644 --- a/lstm/bmi_lstm.py +++ b/lstm/bmi_lstm.py @@ -605,11 +605,11 @@ def get_value_at_indices( def set_value(self, name: str, src: np.ndarray) -> None: if name == "serialization_state": - self._deserialize(src, self) + self._deserialize(src) elif name == "serialization_create": - self._serialize(self) + self._serialize() elif name == "serialization_free": - self._free_serialized(self) + self._free_serialized() elif name == "reset_time": self._timestep = 0 else: From 39cb3246f01db4df507ab3c2133e4e8049d5a525 Mon Sep 17 00:00:00 2001 From: "Carolyn.Maynard" Date: Wed, 18 Feb 2026 08:07:56 -0800 Subject: [PATCH 06/10] Use ewts helper to get os env var --- lstm_ewts/src/lstm_ewts/config.py | 5 +++-- lstm_ewts/src/lstm_ewts/helper.py | 32 +++++++++++++++++++++++++++++++ lstm_ewts/src/lstm_ewts/paths.py | 5 +++-- 3 files changed, 38 insertions(+), 4 deletions(-) create mode 100644 lstm_ewts/src/lstm_ewts/helper.py diff --git a/lstm_ewts/src/lstm_ewts/config.py b/lstm_ewts/src/lstm_ewts/config.py index f5f0267..cddc4c6 100644 --- a/lstm_ewts/src/lstm_ewts/config.py +++ b/lstm_ewts/src/lstm_ewts/config.py @@ -43,6 +43,7 @@ ) from .formatter import CustomFormatter from .paths import get_log_file_path +from .helper import getenv_any def translate_ngwpc_log_level(level: str) -> str: level = level.strip().upper() @@ -75,7 +76,7 @@ def configure_logging(): return logger # logger already initialized, nothing else to do # Default to enabled if flag not set or is set to disabled - raw_value = os.getenv(EV_EWTS_LOGGING) + raw_value = getenv_any(EV_EWTS_LOGGING, "") normalized = (raw_value or "").strip().lower() # convert None or "" to "", lowercase for easy comparison # Determine if logging is enabled @@ -102,7 +103,7 @@ def configure_logging(): ) log_level = translate_ngwpc_log_level( - os.getenv(EV_MODULE_LOGLEVEL, "INFO") + getenv_any(EV_MODULE_LOGLEVEL, "INFO").strip() ) module_fmt = MODULE_NAME.upper().ljust(LOG_MODULE_NAME_LEN)[:LOG_MODULE_NAME_LEN] diff --git a/lstm_ewts/src/lstm_ewts/helper.py b/lstm_ewts/src/lstm_ewts/helper.py new file mode 100644 index 0000000..db90ea3 --- /dev/null +++ b/lstm_ewts/src/lstm_ewts/helper.py @@ -0,0 +1,32 @@ +import os + +# NOTE: +# ngen sets some env vars from C++ after the Python interpreter has started. +# In embedded Python, os.environ may not reflect those changes. +# getenv_any() falls back to libc getenv() and syncs os.environ. +def getenv_any(key: str, default: str = "") -> str: + """ + Get an environment variable reliably even when it is set from C/C++ + after the Python interpreter has started (embedded Python). + Prefers os.environ/os.getenv, falls back to libc getenv. + """ + # First try Python's mapping + v = os.environ.get(key) + if v is not None: + return v + + # Fallback: direct libc getenv (sees process env even if Python mapping is stale) + try: + import ctypes, ctypes.util + libc = ctypes.CDLL(ctypes.util.find_library("c")) + libc.getenv.restype = ctypes.c_char_p + b = libc.getenv(key.encode("utf-8")) + if not b: + return default + s = b.decode("utf-8") + + # Sync back into os.environ so future lookups work normally + os.environ[key] = s + return s + except Exception: + return default diff --git a/lstm_ewts/src/lstm_ewts/paths.py b/lstm_ewts/src/lstm_ewts/paths.py index f896480..a647fb2 100644 --- a/lstm_ewts/src/lstm_ewts/paths.py +++ b/lstm_ewts/src/lstm_ewts/paths.py @@ -30,6 +30,7 @@ import os from datetime import datetime, timezone +from .helper import getenv_any from .constants import ( MODULE_NAME, EV_NGEN_LOGFILEPATH, @@ -67,12 +68,12 @@ def get_log_file_path(): moduleLogFileExists = False # Determine if a log file has laready been opened for this module (either the ngen log or default) - moduleEnvVar = os.getenv(EV_MODULE_LOGFILEPATH, "") + moduleEnvVar = getenv_any(EV_MODULE_LOGFILEPATH).strip() if moduleEnvVar: logFilePath = moduleEnvVar moduleLogFileExists = True else: - ngenEnvVar = os.getenv(EV_NGEN_LOGFILEPATH, "") + ngenEnvVar = getenv_any(EV_NGEN_LOGFILEPATH).strip() if ngenEnvVar: logFilePath = ngenEnvVar else: From 2733d8ca9e08aea84acbb8cbeb24e2c79cc06cf4 Mon Sep 17 00:00:00 2001 From: Ian Todd Date: Fri, 20 Feb 2026 13:18:48 -0500 Subject: [PATCH 07/10] Merge resolution of changes from OWP --- lstm/bmi_lstm.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/lstm/bmi_lstm.py b/lstm/bmi_lstm.py index 8997c28..372ba2a 100644 --- a/lstm/bmi_lstm.py +++ b/lstm/bmi_lstm.py @@ -302,13 +302,14 @@ def gather_inputs( value = state.value(bmi_name) assert value.size == 1, "`value` should a single scalar in a 1d array" input_list.append(value[0]) - logger.debug(" lstm_name=%s", lstm_name) - logger.debug(" bmi_name=%s", bmi_name) - logger.debug(" type(value)=%s", type(value)) - logger.debug(" value=%s", value) + + LOG.debug(f" {lstm_name=}") + LOG.debug(f" {bmi_name=}") + LOG.debug(f" {type(value)=}") + LOG.debug(f" {value=}") collected = bmi_array(input_list) - logger.debug("Collected inputs: %s",collected) + LOG.debug(f"Collected inputs: %s", collected) return collected @@ -321,10 +322,10 @@ def scale_inputs( # Center and scale the input values for use in torch input_array_scaled = (input - mean) / std - logger.debug("### input_array =%s", input) - logger.debug("### dtype(input_array) =%s", input.dtype) - logger.debug("### type(input_array_scaled) =%s", type(input_array_scaled)) - logger.debug("### dtype(input_array_scaled) =%s", input_array_scaled.dtype) + LOG.debug("### input_array = %s", input) + LOG.debug("### dtype(input_array) = %s", input.dtype) + LOG.debug("### type(input_array_scaled) = %s", type(input_array_scaled)) + LOG.debug("### dtype(input_array_scaled) = %s", input_array_scaled.dtype) return input_array_scaled @@ -336,7 +337,7 @@ def scale_outputs( output_scale_factor_cms: float, precipitation_value: npt.NDArray, ): - logger.debug("model output: %s", output[0, 0, 0].numpy().tolist()) + LOG.debug(f"model output: {output[0, 0, 0].numpy().tolist()}") if cfg["target_variables"][0] in ["qobs_mm_per_hour", "QObs(mm/hr)", "QObs(mm/h)"]: surface_runoff_mm = output[0, 0, 0].numpy() * output_std + output_mean @@ -479,7 +480,7 @@ def update(self) -> None: def update_until(self, time: float) -> None: if time <= self.get_current_time(): current_time = self.get_current_time() - logger.warning("no update performed: time=%s <= current_time=%s", time, current_time) + LOG.warning(f"no update performed: {time=} <= {current_time=}") return None n_steps, remainder = divmod( @@ -487,8 +488,8 @@ def update_until(self, time: float) -> None: ) if remainder != 0: - logger.warning( - "time is not multiple of time step size. updating until: %s", (time - remainder) + LOG.warning( + f"time is not multiple of time step size. updating until: {time - remainder=} " ) for _ in range(int(n_steps)): From 8e90c914ec0deeccf087dc3fabf3132c50501dfa Mon Sep 17 00:00:00 2001 From: "Carolyn.Maynard" Date: Mon, 16 Mar 2026 17:02:21 -0700 Subject: [PATCH 08/10] Updates to use the nwm-ewts library --- lstm/bmi_lstm.py | 13 +-- lstm_ewts/pyproject.toml | 13 --- lstm_ewts/src/lstm_ewts/__init__.py | 34 ------- lstm_ewts/src/lstm_ewts/config.py | 127 --------------------------- lstm_ewts/src/lstm_ewts/constants.py | 40 --------- lstm_ewts/src/lstm_ewts/formatter.py | 60 ------------- lstm_ewts/src/lstm_ewts/helper.py | 32 ------- lstm_ewts/src/lstm_ewts/paths.py | 116 ------------------------ tests/lstm_ewts/conftest.py | 25 ------ tests/lstm_ewts/test_config.py | 81 ----------------- tests/lstm_ewts/test_constants.py | 10 --- tests/lstm_ewts/test_formatter.py | 65 -------------- tests/lstm_ewts/test_paths.py | 115 ------------------------ 13 files changed, 7 insertions(+), 724 deletions(-) delete mode 100644 lstm_ewts/pyproject.toml delete mode 100644 lstm_ewts/src/lstm_ewts/__init__.py delete mode 100644 lstm_ewts/src/lstm_ewts/config.py delete mode 100644 lstm_ewts/src/lstm_ewts/constants.py delete mode 100644 lstm_ewts/src/lstm_ewts/formatter.py delete mode 100644 lstm_ewts/src/lstm_ewts/helper.py delete mode 100644 lstm_ewts/src/lstm_ewts/paths.py delete mode 100644 tests/lstm_ewts/conftest.py delete mode 100644 tests/lstm_ewts/test_config.py delete mode 100644 tests/lstm_ewts/test_constants.py delete mode 100644 tests/lstm_ewts/test_formatter.py delete mode 100644 tests/lstm_ewts/test_paths.py diff --git a/lstm/bmi_lstm.py b/lstm/bmi_lstm.py index 6eba587..a8e68ef 100644 --- a/lstm/bmi_lstm.py +++ b/lstm/bmi_lstm.py @@ -53,6 +53,7 @@ import typing from dataclasses import dataclass from pathlib import Path +import logging import numpy as np import numpy.typing as npt @@ -67,11 +68,11 @@ from . import nextgen_cuda_lstm from .base import BmiBase -from lstm_ewts import configure_logging, MODULE_NAME from .model_state import State, StateFacade, Var -import logging -LOG = logging.getLogger(MODULE_NAME) +import ewts +LOG = ewts.get_logger(ewts.LSTM_ID) + # -------------- Dynamic Attributes ----------------------------- _dynamic_input_vars = [ @@ -440,9 +441,9 @@ def __init__(self) -> None: def initialize(self, config_file: str) -> None: - # configure the Error Warning and Trapping System logger - configure_logging() - + # This is required prior to the first log message is issued by t-route. + LOG.bind() + LOG.info(f"Initializing with {config_file}") # read and setup main configuration file diff --git a/lstm_ewts/pyproject.toml b/lstm_ewts/pyproject.toml deleted file mode 100644 index aa30740..0000000 --- a/lstm_ewts/pyproject.toml +++ /dev/null @@ -1,13 +0,0 @@ -[build-system] -requires = ["setuptools>=70"] -build-backend = "setuptools.build_meta" - -[project] -name = "lstm-ewts" -version = "0.1.0" -description = "EWTS helper package for LSTM" -requires-python = ">=3.8" - -[tool.setuptools.packages.find] -where = ["src"] -include = ["lstm_ewts*"] diff --git a/lstm_ewts/src/lstm_ewts/__init__.py b/lstm_ewts/src/lstm_ewts/__init__.py deleted file mode 100644 index 7c8a90b..0000000 --- a/lstm_ewts/src/lstm_ewts/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -Error Warning and Trapping System (EWTS) Package API - -This package provides a centralized, named logging configuration for the -Error, Warning, and Trapping System used throughout the codebase. - -EWTS configures a single, shared logger in the Python logging framework, -identified by a fixed module name. All modules that participate in EWTS -logging retrieve this logger by name via the standard logging API. - -Logging configuration should be performed once at application startup by -calling configure_logging(). The configuration function is idempotent: -subsequent calls have no effect and will not reconfigure handlers or levels. - -The logger name is exposed to allow any module to obtain the configured -logger without importing internal implementation details. - -Typical usage: - - At application startup: - from lstm_ewts import configure_logging - configure_logging() - - Within other modules: - import logging - from lstm_ewts import MODULE_NAME - - LOG = logging.getLogger(MODULE_NAME) -""" - -from .constants import MODULE_NAME -from .config import configure_logging - -__all__ = ["MODULE_NAME", "configure_logging"] diff --git a/lstm_ewts/src/lstm_ewts/config.py b/lstm_ewts/src/lstm_ewts/config.py deleted file mode 100644 index cddc4c6..0000000 --- a/lstm_ewts/src/lstm_ewts/config.py +++ /dev/null @@ -1,127 +0,0 @@ -""" -Logging configuration for the Error Warning and Trapping System (EWTS). - -This module defines the centralized logging configuration used by EWTS. -It is responsible for creating and configuring a single, named logger -within the Python logging framework, based on environment variables -provided by the runtime environment (e.g., ngen). - -Logging configuration is performed via configure_logging(), which applies -handlers, formatters, and log levels to the EWTS logger. The configuration -function is idempotent: once the logger has been initialized, subsequent -calls return immediately without modifying the existing configuration. - -Configuration behavior is controlled by environment variables, whose names -are defined in constants.py: - - - EV_EWTS_LOGGING: - Enables or disables EWTS logging. If set to "DISABLED", logging is - disabled entirely for the EWTS logger. If unset, logging is enabled - by default. - - - EV_MODULE_LOGLEVEL: - Specifies the log level for the EWTS logger. Supported values include - standard Python logging levels as well as ngen-style levels (e.g., - "SEVERE", "FATAL"), which are translated to Python equivalents. - -Log output is directed to a file determined by the path-resolution utilities -in paths.py. If a log file cannot be created, logging falls back to stdout. - -This module does not expose logging APIs directly; callers are expected to -retrieve the configured logger by name using logging.getLogger(MODULE_NAME). -""" - -import logging -import sys -import os - -from .constants import ( - MODULE_NAME, - EV_EWTS_LOGGING, - EV_MODULE_LOGLEVEL, - LOG_MODULE_NAME_LEN, -) -from .formatter import CustomFormatter -from .paths import get_log_file_path -from .helper import getenv_any - -def translate_ngwpc_log_level(level: str) -> str: - level = level.strip().upper() - return { - "SEVERE": "ERROR", - "FATAL": "CRITICAL", - }.get(level, level) - - -def force_info(handler, logger, msg, *args): - record = logger.makeRecord( - logger.name, - logging.INFO, - __file__, - 0, - msg, - args, - None, - ) - handler.emit(record) - - -def configure_logging(): - ''' - Set logging level and specify logger configuration based on environment variables set by ngen - ''' - logger = logging.getLogger(MODULE_NAME) - - if getattr(logger, "_initialized", False): - return logger # logger already initialized, nothing else to do - - # Default to enabled if flag not set or is set to disabled - raw_value = getenv_any(EV_EWTS_LOGGING, "") - normalized = (raw_value or "").strip().lower() # convert None or "" to "", lowercase for easy comparison - - # Determine if logging is enabled - enabled = normalized != "disabled" - - # Inform user if logging is enabled by default (env not explicitly set to "enabled") - if enabled and normalized not in ("enabled",): - print(f"{EV_EWTS_LOGGING} not explicitly set to 'ENABLED'; logging ENABLED by default", flush=True) - - if not enabled: - logger.disabled = True - logger._initialized = True - print(f"Module {MODULE_NAME} Logging DISABLED", flush=True) - return logger - - print(f"Module {MODULE_NAME} Logging ENABLED", flush=True) - - logFilePath, appendEntries = get_log_file_path() - - handler = ( - logging.FileHandler(logFilePath, mode="a" if appendEntries else "w") - if logFilePath - else logging.StreamHandler(sys.stdout) - ) - - log_level = translate_ngwpc_log_level( - getenv_any(EV_MODULE_LOGLEVEL, "INFO").strip() - ) - - module_fmt = MODULE_NAME.upper().ljust(LOG_MODULE_NAME_LEN)[:LOG_MODULE_NAME_LEN] - - formatter = CustomFormatter( - fmt=f"%(asctime)s.%(msecs)03d {module_fmt} %(levelname_padded)s %(message)s", - datefmt="%Y-%m-%dT%H:%M:%S", - ) - handler.setFormatter(formatter) - - # Setup logger - logger.handlers.clear() # Clear any default handlers - logger.setLevel(log_level) - logger.addHandler(handler) - - # Write log level INFO message to log regradless of the actual log level - force_info(handler, logger, "Log level set to %s", log_level) - print(f"Module {MODULE_NAME} Log Level set to {log_level}", flush=True) - - logger._initialized = True - return logger diff --git a/lstm_ewts/src/lstm_ewts/constants.py b/lstm_ewts/src/lstm_ewts/constants.py deleted file mode 100644 index de196db..0000000 --- a/lstm_ewts/src/lstm_ewts/constants.py +++ /dev/null @@ -1,40 +0,0 @@ -""" -Constants and configuration keys for the Error Warning and Trapping System (EWTS). - -This module defines all constant values used by EWTS for logging configuration, -environment variable integration, and log file naming. These values represent -the stable interface between EWTS, ngen, and participating Python modules. - -Constants are grouped into two categories: - - 1) Module-specific constants: - Values that uniquely identify the current ngen module, including the - logger name and module-specific environment variables. - - 2) Common constants: - Values shared across ngen modules that control global logging behavior, - filesystem layout, and integration with the ngen runtime environment. - -These constants are intentionally centralized to ensure consistent behavior -across the codebase and to avoid hard-coded strings in implementation logic. -Callers should treat these values as read-only. -""" - - -# Values unique to each ngen module -MODULE_NAME = "LSTM" -EV_MODULE_LOGLEVEL = "LSTM_LOGLEVEL" # This modules log level -EV_MODULE_LOGFILEPATH = "LSTM_LOGFILEPATH" # This modules log full log filename - -# Values common to all ngen modules -EV_NGEN_LOGFILEPATH = "NGEN_LOG_FILE_PATH" # Environment variable name with the log file location typically set by ngen -EV_EWTS_LOGGING = "NGEN_EWTS_LOGGING" # Environment variable name with the enable/disable state for the Error Warning - # and Trapping System typically set by ngen - -DS = "/" # Directory separator -LOG_DIR_DEFAULT = "run-logs" # Default parent log directory string if env var empty & ngencerf doesn't exist -LOG_DIR_NGENCERF = "/ngencerf/data" # ngenCERF log directory string if environement var empty. -LOG_FILE_EXT = "log" # Log file name extension -LOG_MODULE_NAME_LEN = 8 # Width of module name for log entries - - diff --git a/lstm_ewts/src/lstm_ewts/formatter.py b/lstm_ewts/src/lstm_ewts/formatter.py deleted file mode 100644 index f2531c1..0000000 --- a/lstm_ewts/src/lstm_ewts/formatter.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Custom log record formatting for the Error Warning and Trapping System (EWTS). - -This module defines a custom logging formatter used by EWTS to produce -consistent, ngen-compatible log output across all participating modules. - -The formatter applies the following behaviors: - - - Forces all timestamps to UTC, independent of system locale settings. - - Formats timestamps with millisecond precision. - - Maps Python logging levels to ngen-style severity names - (e.g., ERROR → SEVERE, CRITICAL → FATAL). - - Pads and normalizes level names to fixed width for column alignment. - - Strips trailing whitespace and newline characters from log messages. - -The formatter operates entirely within the Python logging framework and does -not modify logger configuration or handler behavior. It is intended to be used -by the EWTS logging configuration layer and not instantiated directly by -application code. -""" - -import logging -import time - -class CustomFormatter(logging.Formatter): - LEVEL_NAME_MAP = { - logging.DEBUG: "DEBUG", - logging.INFO: "INFO", - logging.WARNING: "WARNING", - logging.ERROR: "SEVERE", - logging.CRITICAL: "FATAL" - } - - # Apply custom formatter (UTC timestamps applied only to this formatter) - def converter(self, timestamp): - """Override time converter to return UTC time tuple""" - return time.gmtime(timestamp) - - def formatTime(self, record, datefmt=None): - """Use our UTC converter""" - ct = self.converter(record.created) - if datefmt: - return time.strftime(datefmt, ct) - t = time.strftime("%Y-%m-%d %H:%M:%S", ct) - return f"{t},{int(record.msecs):03d}" - - def format(self, record): - # Strip trailing whitespace/newlines from the message - if record.msg: - record.msg = str(record.msg).rstrip() - - # Map level names - original_levelname = record.levelname - record.levelname = self.LEVEL_NAME_MAP.get(record.levelno, original_levelname) - record.levelname_padded = record.levelname.ljust(7)[:7] # Exactly 7 chars - formatted = super().format(record) - - # Restore original levelname - record.levelname = original_levelname # Restore original in case it's reused - return formatted diff --git a/lstm_ewts/src/lstm_ewts/helper.py b/lstm_ewts/src/lstm_ewts/helper.py deleted file mode 100644 index db90ea3..0000000 --- a/lstm_ewts/src/lstm_ewts/helper.py +++ /dev/null @@ -1,32 +0,0 @@ -import os - -# NOTE: -# ngen sets some env vars from C++ after the Python interpreter has started. -# In embedded Python, os.environ may not reflect those changes. -# getenv_any() falls back to libc getenv() and syncs os.environ. -def getenv_any(key: str, default: str = "") -> str: - """ - Get an environment variable reliably even when it is set from C/C++ - after the Python interpreter has started (embedded Python). - Prefers os.environ/os.getenv, falls back to libc getenv. - """ - # First try Python's mapping - v = os.environ.get(key) - if v is not None: - return v - - # Fallback: direct libc getenv (sees process env even if Python mapping is stale) - try: - import ctypes, ctypes.util - libc = ctypes.CDLL(ctypes.util.find_library("c")) - libc.getenv.restype = ctypes.c_char_p - b = libc.getenv(key.encode("utf-8")) - if not b: - return default - s = b.decode("utf-8") - - # Sync back into os.environ so future lookups work normally - os.environ[key] = s - return s - except Exception: - return default diff --git a/lstm_ewts/src/lstm_ewts/paths.py b/lstm_ewts/src/lstm_ewts/paths.py deleted file mode 100644 index a647fb2..0000000 --- a/lstm_ewts/src/lstm_ewts/paths.py +++ /dev/null @@ -1,116 +0,0 @@ -""" -Log file path resolution utilities for the Error Warning and Trapping System (EWTS). - -This module provides helper functions for constructing and validating log file -paths used by the EWTS logging configuration. Log file selection follows a -well-defined precedence based on environment variables and runtime availability. - -Log file path precedence: - - 1. If the NGEN-provided log file path is available via the environment variable - defined in EV_NGEN_LOGFILEPATH, use that path. - - 2. Otherwise, create a default, module-specific log file: - 2.1) Create a base log directory under the ngenCERF data directory if it - exists; otherwise fall back to the user's home directory. - 2.2) Create a child directory using the current username if available, - otherwise use the current UTC date (YYYYMMDD). - 2.3) Construct a log filename using the module name and a UTC timestamp. - -The resolved log file path is validated by attempting to open the file. Upon -successful creation or reuse, the full log file path is stored in the -EV_MODULE_LOGFILEPATH environment variable so subsequent calls reuse the same -file. If log file creation fails, entries will be written to stdout. - -This module does not configure loggers directly; it only resolves filesystem -paths and associated metadata required by the logging configuration layer. -""" - -import getpass -import os -from datetime import datetime, timezone - -from .helper import getenv_any -from .constants import ( - MODULE_NAME, - EV_NGEN_LOGFILEPATH, - EV_MODULE_LOGFILEPATH, - DS, - LOG_DIR_DEFAULT, - LOG_DIR_NGENCERF, - LOG_FILE_EXT, -) - -def create_timestamp(date_only=False, iso=False, append_ms=False): - now = datetime.now(timezone.utc) - - if date_only: - ts = now.strftime("%Y%m%d") - elif iso: - ts = now.strftime("%Y-%m-%dT%H:%M:%S") - else: - ts = now.strftime("%Y%m%dT%H%M%S") - - if append_ms: - ts += f".{now.microsecond // 1000:03d}" - - return ts - -def get_log_file_path(): - # Determine the log file path using the following precedence: - # 1) Use the ngen-provided log file path if available in the NGEN_LOG_FILE_PATH environment variable - # 2) Otherwise, create a default module-specific log file using the module name and a UTC timestamp. - # 2.1) First create a subdirectory under the ngenCERF data directory if available, otherwise the user home directory. - # 2.2) Next create a subdirectory name using the username, if available, otherwise use the YYYYMMDD. - # 2.3) Attempt to open the log file and upon failure, use stdout. - - appendEntries = True - moduleLogFileExists = False - - # Determine if a log file has laready been opened for this module (either the ngen log or default) - moduleEnvVar = getenv_any(EV_MODULE_LOGFILEPATH).strip() - if moduleEnvVar: - logFilePath = moduleEnvVar - moduleLogFileExists = True - else: - ngenEnvVar = getenv_any(EV_NGEN_LOGFILEPATH).strip() - if ngenEnvVar: - logFilePath = ngenEnvVar - else: - print(f"Module {MODULE_NAME} Env var {EV_NGEN_LOGFILEPATH} not found. Creating default log name.") - appendEntries = False - baseDir = ( - f"{LOG_DIR_NGENCERF}{DS}{LOG_DIR_DEFAULT}" - if os.path.isdir(LOG_DIR_NGENCERF) - else f"{os.path.expanduser('~')}{DS}{LOG_DIR_DEFAULT}" - ) - try: - os.makedirs(baseDir, exist_ok=True) - - childDir = getpass.getuser() or create_timestamp(True) - logFileDir = f"{baseDir}{DS}{childDir}" - os.makedirs(logFileDir, exist_ok=True) - - logFilePath = ( - f"{logFileDir}{DS}{MODULE_NAME}_{create_timestamp()}.{LOG_FILE_EXT}" - ) - except Exception as e: - print(f"Module {MODULE_NAME} {e}", flush=True) - logFilePath = "" - - # Ensure log file can be opened and set module env var - try: - if (logFilePath): - mode = "a" if appendEntries else "w" - with open(logFilePath, mode): - pass - if not moduleLogFileExists: - os.environ[EV_MODULE_LOGFILEPATH] = logFilePath - print(f"Module {MODULE_NAME} Log File: {logFilePath}", flush=True) - else: - raise IOError - except Exception: - print(f"Module {MODULE_NAME} Unable to open log file: {logFilePath}", flush=True) - print(f"Module {MODULE_NAME} Log entries will be writen to stdout", flush=True) - - return logFilePath, appendEntries diff --git a/tests/lstm_ewts/conftest.py b/tests/lstm_ewts/conftest.py deleted file mode 100644 index 2d9f623..0000000 --- a/tests/lstm_ewts/conftest.py +++ /dev/null @@ -1,25 +0,0 @@ -import logging -import pytest - - -@pytest.fixture -def clean_ewts_env(monkeypatch): - """ - Ensure EWTS-related environment variables are unset and - logging is reset before each test. - """ - # EWTS / module env vars - monkeypatch.delenv("NGEN_LOG_FILE_PATH", raising=False) - monkeypatch.delenv("LSTM_LOGLEVEL", raising=False) - monkeypatch.delenv("LSTM_LOGFILEPATH", raising=False) - monkeypatch.delenv("NGEN_EWTS_LOGGING", raising=False) - - # Reset logging state (important!) - logging.shutdown() - for handler in logging.root.handlers[:]: - logging.root.removeHandler(handler) - - yield - - # Cleanup after test (defensive) - logging.shutdown() diff --git a/tests/lstm_ewts/test_config.py b/tests/lstm_ewts/test_config.py deleted file mode 100644 index 238a03d..0000000 --- a/tests/lstm_ewts/test_config.py +++ /dev/null @@ -1,81 +0,0 @@ -import pytest - -import logging -from lstm_ewts.config import configure_logging, translate_ngwpc_log_level -from lstm_ewts.constants import MODULE_NAME, EV_EWTS_LOGGING - -# ------------------------------ -def test_configure_logging_default(clean_ewts_env): - logger = configure_logging() - - assert logger.name == MODULE_NAME - assert logger.level == logging.INFO - assert not logger.disabled - -# ------------------------------ -def test_configure_logging_idempotent(clean_ewts_env): - logger1 = configure_logging() - logger2 = configure_logging() - - assert logger1 is logger2 - assert getattr(logger1, "_initialized", False) - -# ------------------------------ -@pytest.mark.parametrize("inp,expected", [ - ("INFO", "INFO"), - ("SeVeRe", "ERROR"), - ("fatal", "CRITICAL"), - (" debug ", "DEBUG"), -]) -def test_translate_ngwpc_log_level(inp, expected): - assert translate_ngwpc_log_level(inp) == expected - -# ------------------------------ -@pytest.mark.parametrize("env_value,expected_enabled", [ - (None, True), # default: enabled - ("DISABLED", False), - ("ENABLED", True), - ("disabled", False), - ("enabled", True), - ("anystring", True), - ("", True), -]) -@pytest.mark.parametrize("level_input,expected_level", [ - ("DEBUG", logging.DEBUG), - ("INFO", logging.INFO), - ("SEVERE", logging.ERROR), - ("FATAL", logging.CRITICAL), -]) -def test_ewts_logger_matrix(clean_ewts_env, monkeypatch, capsys, env_value, expected_enabled, level_input, expected_level): - # Set environment variables - if env_value is None: - monkeypatch.delenv("NGEN_EWTS_LOGGING", raising=False) - else: - monkeypatch.setenv("NGEN_EWTS_LOGGING", env_value) - - monkeypatch.setenv("LSTM_LOGLEVEL", level_input) - - # Force logger re-initialization - logger = logging.getLogger(MODULE_NAME) - logger.handlers.clear() - logger._initialized = False - logger.disabled = False # ensure proper reset - - # Configure logger - logger = configure_logging() - - # Capture stdout - captured = capsys.readouterr() - - # Assertions - assert logger.name == MODULE_NAME - assert (not logger.disabled) == expected_enabled # True if enabled - if expected_enabled: - assert logger.level == expected_level - - # Assertions for default-enabled print - if expected_enabled and (env_value is None or env_value not in ("ENABLED", "enabled")): - assert f"{EV_EWTS_LOGGING} not explicitly set" in captured.out - else: - assert f"{EV_EWTS_LOGGING} not explicitly set" not in captured.out - diff --git a/tests/lstm_ewts/test_constants.py b/tests/lstm_ewts/test_constants.py deleted file mode 100644 index 4499bc5..0000000 --- a/tests/lstm_ewts/test_constants.py +++ /dev/null @@ -1,10 +0,0 @@ -from lstm_ewts.constants import ( - MODULE_NAME, - LOG_MODULE_NAME_LEN, -) - -def test_module_name_is_string(): - assert isinstance(MODULE_NAME, str) - -def test_module_name_length_fits_field(): - assert len(MODULE_NAME) <= LOG_MODULE_NAME_LEN diff --git a/tests/lstm_ewts/test_formatter.py b/tests/lstm_ewts/test_formatter.py deleted file mode 100644 index 3a6af0c..0000000 --- a/tests/lstm_ewts/test_formatter.py +++ /dev/null @@ -1,65 +0,0 @@ -import logging -import pytest -from lstm_ewts.formatter import CustomFormatter -from lstm_ewts.constants import MODULE_NAME - -@pytest.fixture -def formatter(): - fmt = "%(asctime)s %(levelname_padded)s %(message)s" - return CustomFormatter(fmt=fmt, datefmt="%Y-%m-%dT%H:%M:%S") - -@pytest.mark.parametrize( - "level,expected", - [ - (logging.DEBUG, "DEBUG"), - (logging.INFO, "INFO"), - (logging.WARNING, "WARNING"), - (logging.ERROR, "SEVERE"), - (logging.CRITICAL, "FATAL"), - ] -) -def test_level_name_mapping(formatter, level, expected): - record = logging.LogRecord( - name=MODULE_NAME, - level=level, - pathname="test", - lineno=0, - msg="Test message", - args=None, - exc_info=None - ) - formatted = formatter.format(record) - # Level name should appear in formatted string - assert expected in formatted - -def test_utc_timestamp(formatter): - record = logging.LogRecord( - name=MODULE_NAME, - level=logging.INFO, - pathname="test", - lineno=0, - msg="UTC test", - args=None, - exc_info=None - ) - formatted = formatter.format(record) - # Timestamp should be in UTC format "YYYY-MM-DDTHH:MM:SS" - ts_str = formatted.split()[0] - from datetime import datetime - dt = datetime.strptime(ts_str, "%Y-%m-%dT%H:%M:%S") - # It's enough to check it parses without error - -def test_trailing_whitespace_stripped(formatter): - record = logging.LogRecord( - name=MODULE_NAME, - level=logging.INFO, - pathname="test", - lineno=0, - msg="Message with space \n", - args=None, - exc_info=None - ) - formatted = formatter.format(record) - # Trailing whitespace/newline should be removed - assert " \n" not in formatted - assert formatted.endswith("Message with space") diff --git a/tests/lstm_ewts/test_paths.py b/tests/lstm_ewts/test_paths.py deleted file mode 100644 index f0072e4..0000000 --- a/tests/lstm_ewts/test_paths.py +++ /dev/null @@ -1,115 +0,0 @@ -import os -import getpass -from datetime import datetime -import pytest -from lstm_ewts import paths -from lstm_ewts.paths import create_timestamp, get_log_file_path -from lstm_ewts.constants import MODULE_NAME, EV_MODULE_LOGFILEPATH, EV_NGEN_LOGFILEPATH - -# ------------------------------- -# Fixture for a clean log environment -# ------------------------------- -@pytest.fixture -def clean_log_env(tmp_path, monkeypatch): - """Set up a temporary log environment and clean env vars. - - Yields a dict with: - tmp_dir : Path of temporary base directory - monkeypatch : the pytest monkeypatch object for further tweaks - """ - # Clear env vars - monkeypatch.delenv(EV_MODULE_LOGFILEPATH, raising=False) - monkeypatch.delenv(EV_NGEN_LOGFILEPATH, raising=False) - - # Patch constants to use tmp_path - monkeypatch.setattr(paths, "LOG_DIR_NGENCERF", tmp_path) - monkeypatch.setattr(paths, "LOG_DIR_DEFAULT", "run-logs") - - yield {"tmp_dir": tmp_path, "monkeypatch": monkeypatch} - - -# ------------------------------- -# Tests for create_timestamp() -# ------------------------------- -def test_create_timestamp_default(): - ts = create_timestamp() - assert len(ts) >= 15 - assert "T" in ts - -def test_create_timestamp_date_only(): - ts = create_timestamp(date_only=True) - assert len(ts) == 8 - -def test_create_timestamp_iso(): - ts = create_timestamp(iso=True) - assert "T" in ts and "-" in ts and ":" in ts - -def test_create_timestamp_append_ms(): - ts = create_timestamp(append_ms=True) - assert "." in ts - - -# ------------------------------- -# Tests for get_log_file_path() -# ------------------------------- -def test_get_log_file_path_uses_module_env(clean_log_env): - tmp_path = clean_log_env["tmp_dir"] - monkeypatch = clean_log_env["monkeypatch"] - - logfile = tmp_path / "test_module.log" - monkeypatch.setenv(EV_MODULE_LOGFILEPATH, str(logfile)) - - path, append = get_log_file_path() - assert path == str(logfile) - assert append is True - - -def test_get_log_file_path_uses_ngen_env(clean_log_env): - monkeypatch = clean_log_env["monkeypatch"] - tmp_path = clean_log_env["tmp_dir"] - - monkeypatch.delenv(EV_MODULE_LOGFILEPATH, raising=False) - ngen_file = tmp_path / "ngen.log" - monkeypatch.setenv(EV_NGEN_LOGFILEPATH, str(ngen_file)) - - path, append = get_log_file_path() - assert path == str(ngen_file) - assert append is True - - -def test_get_log_file_path_creates_user_subdir(clean_log_env): - tmp_path = clean_log_env["tmp_dir"] - monkeypatch = clean_log_env["monkeypatch"] - - monkeypatch.delenv(EV_MODULE_LOGFILEPATH, raising=False) - monkeypatch.delenv(EV_NGEN_LOGFILEPATH, raising=False) - - # Use real username - monkeypatch.setattr(getpass, "getuser", lambda: "alice") - - path, append = get_log_file_path() - - # Subdirectory should be username - subdir = os.path.basename(os.path.dirname(path)) - assert subdir == "alice" - assert path.endswith(".log") - assert os.path.exists(path) - - -def test_get_log_file_path_fallback_username(clean_log_env): - tmp_path = clean_log_env["tmp_dir"] - monkeypatch = clean_log_env["monkeypatch"] - - monkeypatch.delenv(EV_MODULE_LOGFILEPATH, raising=False) - monkeypatch.delenv(EV_NGEN_LOGFILEPATH, raising=False) - - # Simulate getuser() returning None - monkeypatch.setattr(getpass, "getuser", lambda: None) - - path, append = get_log_file_path() - - subdir = os.path.basename(os.path.dirname(path)) - # Should fall back to YYYYMMDD - assert len(subdir) == 8 and subdir.isdigit() - assert path.endswith(".log") - assert os.path.exists(path) From 0ee7686eb42e70cc1bf10fe2a1acf232a3db44ca Mon Sep 17 00:00:00 2001 From: Carolyn Maynard Date: Wed, 18 Mar 2026 11:05:54 -0700 Subject: [PATCH 09/10] Update unit tests workflow to install only develop dependencies Removed installation of './lstm_ewts' from unit tests workflow. --- .github/workflows/unit_tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 2bb9dd3..318a375 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -32,7 +32,7 @@ jobs: - name: Install dependencies run: | pip install -U pip # upgrade pip - pip install '.[develop]' './lstm_ewts' + pip install '.[develop]' - name: Echo dependency versions run: | pip freeze From 88bffb9f60ca2806099f2505d6e3518180b100e4 Mon Sep 17 00:00:00 2001 From: Carolyn Maynard Date: Wed, 18 Mar 2026 11:33:13 -0700 Subject: [PATCH 10/10] Add installation of ewts from GitHub repository --- .github/workflows/unit_tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 318a375..6c582bb 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -32,6 +32,7 @@ jobs: - name: Install dependencies run: | pip install -U pip # upgrade pip + pip install git+https://github.com/ngwpc/nwm-ewts.git#subdirectory=runtime/python/ewts pip install '.[develop]' - name: Echo dependency versions run: |