diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 2bb9dd3..6c582bb 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -32,7 +32,8 @@ jobs: - name: Install dependencies run: | pip install -U pip # upgrade pip - pip install '.[develop]' './lstm_ewts' + pip install git+https://github.com/ngwpc/nwm-ewts.git#subdirectory=runtime/python/ewts + pip install '.[develop]' - name: Echo dependency versions run: | pip freeze diff --git a/lstm/bmi_lstm.py b/lstm/bmi_lstm.py index 85422e5..a8e68ef 100644 --- a/lstm/bmi_lstm.py +++ b/lstm/bmi_lstm.py @@ -53,20 +53,26 @@ import typing from dataclasses import dataclass from pathlib import Path +import logging import numpy as np import numpy.typing as npt import pandas as pd +import pickle import torch import yaml +try: + from yaml import CSafeLoader as SafeLoader +except ImportError: + from yaml import SafeLoader from . import nextgen_cuda_lstm from .base import BmiBase -from lstm_ewts import configure_logging, MODULE_NAME from .model_state import State, StateFacade, Var -import logging -LOG = logging.getLogger(MODULE_NAME) +import ewts +LOG = ewts.get_logger(ewts.LSTM_ID) + # -------------- Dynamic Attributes ----------------------------- _dynamic_input_vars = [ @@ -147,7 +153,7 @@ def __init__(self, cfg: dict[str, typing.Any], output_scaling_factor_cms: float) # load training feature scales scaler_file = cfg["run_dir"] / "train_data/train_data_scaler.yml" with scaler_file.open("r") as fp: - train_data_scaler = yaml.safe_load(fp) + train_data_scaler = yaml.load(fp, Loader=SafeLoader) self.scalars = load_training_scalars(cfg, train_data_scaler) # initialize torch lstm object @@ -200,6 +206,16 @@ def update(self, state: Valuer) -> typing.Iterable[Var]: precipitation_mm_h ) + def serialize(self): + return { + "h": self.h_t.numpy(), + "c": self.c_t.numpy() + } + + def deserialize(self, data: dict): + self.h_t = torch.from_numpy(data["h"]) + self.c_t = torch.from_numpy(data["c"]) + def bmi_array(arr: list[float]) -> npt.NDArray: """Trivial wrapper function to ensure the expected numpy array datatype is used.""" @@ -419,16 +435,20 @@ def __init__(self) -> None: self.cfg_bmi: dict[str, typing.Any] self.ensemble_members: list[EnsembleMember] - def initialize(self, config_file: str) -> None: + # statically stored seriaized data + self._serialized_size = np.array([0], dtype=np.uint64) + self._serialized = np.array([], dtype=np.uint8) - # configure the Error Warning and Trapping System logger - configure_logging() + def initialize(self, config_file: str) -> None: + # This is required prior to the first log message is issued by t-route. + LOG.bind() + LOG.info(f"Initializing with {config_file}") # read and setup main configuration file with open(config_file, "r") as fp: - self.cfg_bmi = yaml.safe_load(fp) + self.cfg_bmi = yaml.load(fp, Loader=SafeLoader) coerce_config(self.cfg_bmi) # ----------- The output is area normalized, this is needed to un-normalize it @@ -440,7 +460,7 @@ def initialize(self, config_file: str) -> None: # initialize ensemble members self.ensemble_members = [] for member_cfg_file in self.cfg_bmi["train_cfg_file"]: - cfg = yaml.safe_load(member_cfg_file.read_text()) + cfg = yaml.load(member_cfg_file.read_text(), Loader=SafeLoader) coerce_config(cfg) member = EnsembleMember(cfg, output_factor_cms) self.ensemble_members.append(member) @@ -515,16 +535,38 @@ def get_var_grid(self, name: str) -> int: return 0 def get_var_type(self, name: str) -> str: + if name == "serialization_state": + return self._serialized.dtype.name + elif name == "serialization_size" or name == "serialization_create": + return self._serialized_size.dtype.name + elif name == "serialization_free": + return np.dtype(np.intc).name + elif name == "reset_time": + return np.dtype(np.double).name return self.get_value_ptr(name).dtype.name def get_var_units(self, name: str) -> str: return first_containing(name, self._outputs, self._dynamic_inputs).unit(name) def get_var_itemsize(self, name: str) -> int: + if name == "serialization_state": + return self._serialized.dtype.itemsize + if name == "serialization_size" or name == "serialization_create": + return self._serialized_size.dtype.itemsize + if name == "serialization_free": + return np.dtype(np.intc).itemsize + if name == "reset_time": + return np.dtype(np.double).itemsize return self.get_value_ptr(name).itemsize def get_var_nbytes(self, name: str) -> int: - return self.get_var_itemsize(name) * len(self.get_value_ptr(name)) + if name == "serialization_create": + return self._serialized_size.nbytes + if name == "serialization_free": + return np.dtype(np.intc).itemsize + if name == "reset_time": + return np.dtype(np.double).itemsize + return self.get_value_ptr(name).nbytes def get_var_location(self, name: str) -> str: # raises KeyError on failure @@ -553,6 +595,10 @@ def get_value(self, name: str, dest: np.ndarray) -> np.ndarray: def get_value_ptr(self, name: str) -> np.ndarray: """Returns a _reference_ to a variable's np.NDArray.""" + if name == "serialization_state": + return self._serialized + elif name == "serialization_size": + return self._serialized_size return first_containing(name, self._outputs, self._dynamic_inputs).value(name) def get_value_at_indices( @@ -563,9 +609,18 @@ def get_value_at_indices( ).value_at_indices(name, dest, inds) def set_value(self, name: str, src: np.ndarray) -> None: - return first_containing(name, self._outputs, self._dynamic_inputs).set_value( - name, src - ) + if name == "serialization_state": + self._deserialize(src) + elif name == "serialization_create": + self._serialize() + elif name == "serialization_free": + self._free_serialized() + elif name == "reset_time": + self._timestep = 0 + else: + return first_containing(name, self._outputs, self._dynamic_inputs).set_value( + name, src + ) def set_value_at_indices( self, name: str, inds: np.ndarray, src: np.ndarray @@ -593,6 +648,38 @@ def get_grid_type(self, grid: int) -> str: return "scalar" raise RuntimeError(f"unsupported grid type: {grid!s}. only support 0") + def _serialize(self): + """Convert all dynamic properties that can change after the `bmi_LSTM` has had `initialize()` called into an object that can be serialized through `pickle`. + Then, set the BMI's `_serialized` property to the byte representation of that pickled data and adjust the static `_serialized_size` property.""" + data = { + "dynamic_inputs": self._dynamic_inputs.serialize(), + "static_inputs": self._static_inputs.serialize(), + "outputs": self._outputs.serialize(), + "ensemble": [em.serialize() for em in self.ensemble_members], + "timestep": self._timestep, + } + serialized = pickle.dumps(data) + self._serialized = np.array(bytearray(serialized), dtype=np.uint8) + self._serialized_size[0] = len(self._serialized) + + def _deserialize(self, array: np.ndarray): + """Interpret the bytes of the numpy array as previously pickled data from `_serialize()` and update the current values. + No data structure check will be made on the input array or loaded bytes. It will be assumed that the input data is of the same structure as what is generated from `_serialize()`.""" + data = bytes(array) + deserialized = pickle.loads(data) + self._dynamic_inputs.deserialize(deserialized["dynamic_inputs"]) + self._static_inputs.deserialize(deserialized["static_inputs"]) + self._outputs.deserialize(deserialized["outputs"]) + for bmi_em, data_em in zip(self.ensemble_members, deserialized["ensemble"], strict=True): + bmi_em.deserialize(data_em) + self._timestep = deserialized["timestep"] + self._free_serialized() + + def _free_serialized(self): + """Clear the current serialized data and set the size property value to 0.""" + self._serialized_size[0] = 0 + self._serialized = np.array([], dtype=self._serialized.dtype) + def coerce_config(cfg: dict[str, typing.Any]): for key, val in cfg.items(): diff --git a/lstm/model_state.py b/lstm/model_state.py index c13fcfc..bbacbd6 100644 --- a/lstm/model_state.py +++ b/lstm/model_state.py @@ -94,6 +94,23 @@ def __iter__(self) -> typing.Iterator[Var]: def __len__(self) -> int: return len(self._name_mapping) + def serialize(self): + """Return the State represented as a list of dicts representing the `Var` properties.""" + return [ + {"name": var.name, "unit": var.unit, "value": var.value} + for var in self._name_mapping.values() + ] + + def deserialize(self, values: list): + """Replace the current Vars with values from the intput list.""" + self._name_mapping.clear() + for var in values: + self._name_mapping[var["name"]] = Var( + name=var["name"], + unit=var["unit"], + value=var["value"] + ) + class StateFacade: """ diff --git a/lstm_ewts/pyproject.toml b/lstm_ewts/pyproject.toml deleted file mode 100644 index aa30740..0000000 --- a/lstm_ewts/pyproject.toml +++ /dev/null @@ -1,13 +0,0 @@ -[build-system] -requires = ["setuptools>=70"] -build-backend = "setuptools.build_meta" - -[project] -name = "lstm-ewts" -version = "0.1.0" -description = "EWTS helper package for LSTM" -requires-python = ">=3.8" - -[tool.setuptools.packages.find] -where = ["src"] -include = ["lstm_ewts*"] diff --git a/lstm_ewts/src/lstm_ewts/__init__.py b/lstm_ewts/src/lstm_ewts/__init__.py deleted file mode 100644 index 7c8a90b..0000000 --- a/lstm_ewts/src/lstm_ewts/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -Error Warning and Trapping System (EWTS) Package API - -This package provides a centralized, named logging configuration for the -Error, Warning, and Trapping System used throughout the codebase. - -EWTS configures a single, shared logger in the Python logging framework, -identified by a fixed module name. All modules that participate in EWTS -logging retrieve this logger by name via the standard logging API. - -Logging configuration should be performed once at application startup by -calling configure_logging(). The configuration function is idempotent: -subsequent calls have no effect and will not reconfigure handlers or levels. - -The logger name is exposed to allow any module to obtain the configured -logger without importing internal implementation details. - -Typical usage: - - At application startup: - from lstm_ewts import configure_logging - configure_logging() - - Within other modules: - import logging - from lstm_ewts import MODULE_NAME - - LOG = logging.getLogger(MODULE_NAME) -""" - -from .constants import MODULE_NAME -from .config import configure_logging - -__all__ = ["MODULE_NAME", "configure_logging"] diff --git a/lstm_ewts/src/lstm_ewts/config.py b/lstm_ewts/src/lstm_ewts/config.py deleted file mode 100644 index f5f0267..0000000 --- a/lstm_ewts/src/lstm_ewts/config.py +++ /dev/null @@ -1,126 +0,0 @@ -""" -Logging configuration for the Error Warning and Trapping System (EWTS). - -This module defines the centralized logging configuration used by EWTS. -It is responsible for creating and configuring a single, named logger -within the Python logging framework, based on environment variables -provided by the runtime environment (e.g., ngen). - -Logging configuration is performed via configure_logging(), which applies -handlers, formatters, and log levels to the EWTS logger. The configuration -function is idempotent: once the logger has been initialized, subsequent -calls return immediately without modifying the existing configuration. - -Configuration behavior is controlled by environment variables, whose names -are defined in constants.py: - - - EV_EWTS_LOGGING: - Enables or disables EWTS logging. If set to "DISABLED", logging is - disabled entirely for the EWTS logger. If unset, logging is enabled - by default. - - - EV_MODULE_LOGLEVEL: - Specifies the log level for the EWTS logger. Supported values include - standard Python logging levels as well as ngen-style levels (e.g., - "SEVERE", "FATAL"), which are translated to Python equivalents. - -Log output is directed to a file determined by the path-resolution utilities -in paths.py. If a log file cannot be created, logging falls back to stdout. - -This module does not expose logging APIs directly; callers are expected to -retrieve the configured logger by name using logging.getLogger(MODULE_NAME). -""" - -import logging -import sys -import os - -from .constants import ( - MODULE_NAME, - EV_EWTS_LOGGING, - EV_MODULE_LOGLEVEL, - LOG_MODULE_NAME_LEN, -) -from .formatter import CustomFormatter -from .paths import get_log_file_path - -def translate_ngwpc_log_level(level: str) -> str: - level = level.strip().upper() - return { - "SEVERE": "ERROR", - "FATAL": "CRITICAL", - }.get(level, level) - - -def force_info(handler, logger, msg, *args): - record = logger.makeRecord( - logger.name, - logging.INFO, - __file__, - 0, - msg, - args, - None, - ) - handler.emit(record) - - -def configure_logging(): - ''' - Set logging level and specify logger configuration based on environment variables set by ngen - ''' - logger = logging.getLogger(MODULE_NAME) - - if getattr(logger, "_initialized", False): - return logger # logger already initialized, nothing else to do - - # Default to enabled if flag not set or is set to disabled - raw_value = os.getenv(EV_EWTS_LOGGING) - normalized = (raw_value or "").strip().lower() # convert None or "" to "", lowercase for easy comparison - - # Determine if logging is enabled - enabled = normalized != "disabled" - - # Inform user if logging is enabled by default (env not explicitly set to "enabled") - if enabled and normalized not in ("enabled",): - print(f"{EV_EWTS_LOGGING} not explicitly set to 'ENABLED'; logging ENABLED by default", flush=True) - - if not enabled: - logger.disabled = True - logger._initialized = True - print(f"Module {MODULE_NAME} Logging DISABLED", flush=True) - return logger - - print(f"Module {MODULE_NAME} Logging ENABLED", flush=True) - - logFilePath, appendEntries = get_log_file_path() - - handler = ( - logging.FileHandler(logFilePath, mode="a" if appendEntries else "w") - if logFilePath - else logging.StreamHandler(sys.stdout) - ) - - log_level = translate_ngwpc_log_level( - os.getenv(EV_MODULE_LOGLEVEL, "INFO") - ) - - module_fmt = MODULE_NAME.upper().ljust(LOG_MODULE_NAME_LEN)[:LOG_MODULE_NAME_LEN] - - formatter = CustomFormatter( - fmt=f"%(asctime)s.%(msecs)03d {module_fmt} %(levelname_padded)s %(message)s", - datefmt="%Y-%m-%dT%H:%M:%S", - ) - handler.setFormatter(formatter) - - # Setup logger - logger.handlers.clear() # Clear any default handlers - logger.setLevel(log_level) - logger.addHandler(handler) - - # Write log level INFO message to log regradless of the actual log level - force_info(handler, logger, "Log level set to %s", log_level) - print(f"Module {MODULE_NAME} Log Level set to {log_level}", flush=True) - - logger._initialized = True - return logger diff --git a/lstm_ewts/src/lstm_ewts/constants.py b/lstm_ewts/src/lstm_ewts/constants.py deleted file mode 100644 index de196db..0000000 --- a/lstm_ewts/src/lstm_ewts/constants.py +++ /dev/null @@ -1,40 +0,0 @@ -""" -Constants and configuration keys for the Error Warning and Trapping System (EWTS). - -This module defines all constant values used by EWTS for logging configuration, -environment variable integration, and log file naming. These values represent -the stable interface between EWTS, ngen, and participating Python modules. - -Constants are grouped into two categories: - - 1) Module-specific constants: - Values that uniquely identify the current ngen module, including the - logger name and module-specific environment variables. - - 2) Common constants: - Values shared across ngen modules that control global logging behavior, - filesystem layout, and integration with the ngen runtime environment. - -These constants are intentionally centralized to ensure consistent behavior -across the codebase and to avoid hard-coded strings in implementation logic. -Callers should treat these values as read-only. -""" - - -# Values unique to each ngen module -MODULE_NAME = "LSTM" -EV_MODULE_LOGLEVEL = "LSTM_LOGLEVEL" # This modules log level -EV_MODULE_LOGFILEPATH = "LSTM_LOGFILEPATH" # This modules log full log filename - -# Values common to all ngen modules -EV_NGEN_LOGFILEPATH = "NGEN_LOG_FILE_PATH" # Environment variable name with the log file location typically set by ngen -EV_EWTS_LOGGING = "NGEN_EWTS_LOGGING" # Environment variable name with the enable/disable state for the Error Warning - # and Trapping System typically set by ngen - -DS = "/" # Directory separator -LOG_DIR_DEFAULT = "run-logs" # Default parent log directory string if env var empty & ngencerf doesn't exist -LOG_DIR_NGENCERF = "/ngencerf/data" # ngenCERF log directory string if environement var empty. -LOG_FILE_EXT = "log" # Log file name extension -LOG_MODULE_NAME_LEN = 8 # Width of module name for log entries - - diff --git a/lstm_ewts/src/lstm_ewts/formatter.py b/lstm_ewts/src/lstm_ewts/formatter.py deleted file mode 100644 index f2531c1..0000000 --- a/lstm_ewts/src/lstm_ewts/formatter.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Custom log record formatting for the Error Warning and Trapping System (EWTS). - -This module defines a custom logging formatter used by EWTS to produce -consistent, ngen-compatible log output across all participating modules. - -The formatter applies the following behaviors: - - - Forces all timestamps to UTC, independent of system locale settings. - - Formats timestamps with millisecond precision. - - Maps Python logging levels to ngen-style severity names - (e.g., ERROR → SEVERE, CRITICAL → FATAL). - - Pads and normalizes level names to fixed width for column alignment. - - Strips trailing whitespace and newline characters from log messages. - -The formatter operates entirely within the Python logging framework and does -not modify logger configuration or handler behavior. It is intended to be used -by the EWTS logging configuration layer and not instantiated directly by -application code. -""" - -import logging -import time - -class CustomFormatter(logging.Formatter): - LEVEL_NAME_MAP = { - logging.DEBUG: "DEBUG", - logging.INFO: "INFO", - logging.WARNING: "WARNING", - logging.ERROR: "SEVERE", - logging.CRITICAL: "FATAL" - } - - # Apply custom formatter (UTC timestamps applied only to this formatter) - def converter(self, timestamp): - """Override time converter to return UTC time tuple""" - return time.gmtime(timestamp) - - def formatTime(self, record, datefmt=None): - """Use our UTC converter""" - ct = self.converter(record.created) - if datefmt: - return time.strftime(datefmt, ct) - t = time.strftime("%Y-%m-%d %H:%M:%S", ct) - return f"{t},{int(record.msecs):03d}" - - def format(self, record): - # Strip trailing whitespace/newlines from the message - if record.msg: - record.msg = str(record.msg).rstrip() - - # Map level names - original_levelname = record.levelname - record.levelname = self.LEVEL_NAME_MAP.get(record.levelno, original_levelname) - record.levelname_padded = record.levelname.ljust(7)[:7] # Exactly 7 chars - formatted = super().format(record) - - # Restore original levelname - record.levelname = original_levelname # Restore original in case it's reused - return formatted diff --git a/lstm_ewts/src/lstm_ewts/paths.py b/lstm_ewts/src/lstm_ewts/paths.py deleted file mode 100644 index f896480..0000000 --- a/lstm_ewts/src/lstm_ewts/paths.py +++ /dev/null @@ -1,115 +0,0 @@ -""" -Log file path resolution utilities for the Error Warning and Trapping System (EWTS). - -This module provides helper functions for constructing and validating log file -paths used by the EWTS logging configuration. Log file selection follows a -well-defined precedence based on environment variables and runtime availability. - -Log file path precedence: - - 1. If the NGEN-provided log file path is available via the environment variable - defined in EV_NGEN_LOGFILEPATH, use that path. - - 2. Otherwise, create a default, module-specific log file: - 2.1) Create a base log directory under the ngenCERF data directory if it - exists; otherwise fall back to the user's home directory. - 2.2) Create a child directory using the current username if available, - otherwise use the current UTC date (YYYYMMDD). - 2.3) Construct a log filename using the module name and a UTC timestamp. - -The resolved log file path is validated by attempting to open the file. Upon -successful creation or reuse, the full log file path is stored in the -EV_MODULE_LOGFILEPATH environment variable so subsequent calls reuse the same -file. If log file creation fails, entries will be written to stdout. - -This module does not configure loggers directly; it only resolves filesystem -paths and associated metadata required by the logging configuration layer. -""" - -import getpass -import os -from datetime import datetime, timezone - -from .constants import ( - MODULE_NAME, - EV_NGEN_LOGFILEPATH, - EV_MODULE_LOGFILEPATH, - DS, - LOG_DIR_DEFAULT, - LOG_DIR_NGENCERF, - LOG_FILE_EXT, -) - -def create_timestamp(date_only=False, iso=False, append_ms=False): - now = datetime.now(timezone.utc) - - if date_only: - ts = now.strftime("%Y%m%d") - elif iso: - ts = now.strftime("%Y-%m-%dT%H:%M:%S") - else: - ts = now.strftime("%Y%m%dT%H%M%S") - - if append_ms: - ts += f".{now.microsecond // 1000:03d}" - - return ts - -def get_log_file_path(): - # Determine the log file path using the following precedence: - # 1) Use the ngen-provided log file path if available in the NGEN_LOG_FILE_PATH environment variable - # 2) Otherwise, create a default module-specific log file using the module name and a UTC timestamp. - # 2.1) First create a subdirectory under the ngenCERF data directory if available, otherwise the user home directory. - # 2.2) Next create a subdirectory name using the username, if available, otherwise use the YYYYMMDD. - # 2.3) Attempt to open the log file and upon failure, use stdout. - - appendEntries = True - moduleLogFileExists = False - - # Determine if a log file has laready been opened for this module (either the ngen log or default) - moduleEnvVar = os.getenv(EV_MODULE_LOGFILEPATH, "") - if moduleEnvVar: - logFilePath = moduleEnvVar - moduleLogFileExists = True - else: - ngenEnvVar = os.getenv(EV_NGEN_LOGFILEPATH, "") - if ngenEnvVar: - logFilePath = ngenEnvVar - else: - print(f"Module {MODULE_NAME} Env var {EV_NGEN_LOGFILEPATH} not found. Creating default log name.") - appendEntries = False - baseDir = ( - f"{LOG_DIR_NGENCERF}{DS}{LOG_DIR_DEFAULT}" - if os.path.isdir(LOG_DIR_NGENCERF) - else f"{os.path.expanduser('~')}{DS}{LOG_DIR_DEFAULT}" - ) - try: - os.makedirs(baseDir, exist_ok=True) - - childDir = getpass.getuser() or create_timestamp(True) - logFileDir = f"{baseDir}{DS}{childDir}" - os.makedirs(logFileDir, exist_ok=True) - - logFilePath = ( - f"{logFileDir}{DS}{MODULE_NAME}_{create_timestamp()}.{LOG_FILE_EXT}" - ) - except Exception as e: - print(f"Module {MODULE_NAME} {e}", flush=True) - logFilePath = "" - - # Ensure log file can be opened and set module env var - try: - if (logFilePath): - mode = "a" if appendEntries else "w" - with open(logFilePath, mode): - pass - if not moduleLogFileExists: - os.environ[EV_MODULE_LOGFILEPATH] = logFilePath - print(f"Module {MODULE_NAME} Log File: {logFilePath}", flush=True) - else: - raise IOError - except Exception: - print(f"Module {MODULE_NAME} Unable to open log file: {logFilePath}", flush=True) - print(f"Module {MODULE_NAME} Log entries will be writen to stdout", flush=True) - - return logFilePath, appendEntries diff --git a/tests/lstm_ewts/conftest.py b/tests/lstm_ewts/conftest.py deleted file mode 100644 index 2d9f623..0000000 --- a/tests/lstm_ewts/conftest.py +++ /dev/null @@ -1,25 +0,0 @@ -import logging -import pytest - - -@pytest.fixture -def clean_ewts_env(monkeypatch): - """ - Ensure EWTS-related environment variables are unset and - logging is reset before each test. - """ - # EWTS / module env vars - monkeypatch.delenv("NGEN_LOG_FILE_PATH", raising=False) - monkeypatch.delenv("LSTM_LOGLEVEL", raising=False) - monkeypatch.delenv("LSTM_LOGFILEPATH", raising=False) - monkeypatch.delenv("NGEN_EWTS_LOGGING", raising=False) - - # Reset logging state (important!) - logging.shutdown() - for handler in logging.root.handlers[:]: - logging.root.removeHandler(handler) - - yield - - # Cleanup after test (defensive) - logging.shutdown() diff --git a/tests/lstm_ewts/test_config.py b/tests/lstm_ewts/test_config.py deleted file mode 100644 index 238a03d..0000000 --- a/tests/lstm_ewts/test_config.py +++ /dev/null @@ -1,81 +0,0 @@ -import pytest - -import logging -from lstm_ewts.config import configure_logging, translate_ngwpc_log_level -from lstm_ewts.constants import MODULE_NAME, EV_EWTS_LOGGING - -# ------------------------------ -def test_configure_logging_default(clean_ewts_env): - logger = configure_logging() - - assert logger.name == MODULE_NAME - assert logger.level == logging.INFO - assert not logger.disabled - -# ------------------------------ -def test_configure_logging_idempotent(clean_ewts_env): - logger1 = configure_logging() - logger2 = configure_logging() - - assert logger1 is logger2 - assert getattr(logger1, "_initialized", False) - -# ------------------------------ -@pytest.mark.parametrize("inp,expected", [ - ("INFO", "INFO"), - ("SeVeRe", "ERROR"), - ("fatal", "CRITICAL"), - (" debug ", "DEBUG"), -]) -def test_translate_ngwpc_log_level(inp, expected): - assert translate_ngwpc_log_level(inp) == expected - -# ------------------------------ -@pytest.mark.parametrize("env_value,expected_enabled", [ - (None, True), # default: enabled - ("DISABLED", False), - ("ENABLED", True), - ("disabled", False), - ("enabled", True), - ("anystring", True), - ("", True), -]) -@pytest.mark.parametrize("level_input,expected_level", [ - ("DEBUG", logging.DEBUG), - ("INFO", logging.INFO), - ("SEVERE", logging.ERROR), - ("FATAL", logging.CRITICAL), -]) -def test_ewts_logger_matrix(clean_ewts_env, monkeypatch, capsys, env_value, expected_enabled, level_input, expected_level): - # Set environment variables - if env_value is None: - monkeypatch.delenv("NGEN_EWTS_LOGGING", raising=False) - else: - monkeypatch.setenv("NGEN_EWTS_LOGGING", env_value) - - monkeypatch.setenv("LSTM_LOGLEVEL", level_input) - - # Force logger re-initialization - logger = logging.getLogger(MODULE_NAME) - logger.handlers.clear() - logger._initialized = False - logger.disabled = False # ensure proper reset - - # Configure logger - logger = configure_logging() - - # Capture stdout - captured = capsys.readouterr() - - # Assertions - assert logger.name == MODULE_NAME - assert (not logger.disabled) == expected_enabled # True if enabled - if expected_enabled: - assert logger.level == expected_level - - # Assertions for default-enabled print - if expected_enabled and (env_value is None or env_value not in ("ENABLED", "enabled")): - assert f"{EV_EWTS_LOGGING} not explicitly set" in captured.out - else: - assert f"{EV_EWTS_LOGGING} not explicitly set" not in captured.out - diff --git a/tests/lstm_ewts/test_constants.py b/tests/lstm_ewts/test_constants.py deleted file mode 100644 index 4499bc5..0000000 --- a/tests/lstm_ewts/test_constants.py +++ /dev/null @@ -1,10 +0,0 @@ -from lstm_ewts.constants import ( - MODULE_NAME, - LOG_MODULE_NAME_LEN, -) - -def test_module_name_is_string(): - assert isinstance(MODULE_NAME, str) - -def test_module_name_length_fits_field(): - assert len(MODULE_NAME) <= LOG_MODULE_NAME_LEN diff --git a/tests/lstm_ewts/test_formatter.py b/tests/lstm_ewts/test_formatter.py deleted file mode 100644 index 3a6af0c..0000000 --- a/tests/lstm_ewts/test_formatter.py +++ /dev/null @@ -1,65 +0,0 @@ -import logging -import pytest -from lstm_ewts.formatter import CustomFormatter -from lstm_ewts.constants import MODULE_NAME - -@pytest.fixture -def formatter(): - fmt = "%(asctime)s %(levelname_padded)s %(message)s" - return CustomFormatter(fmt=fmt, datefmt="%Y-%m-%dT%H:%M:%S") - -@pytest.mark.parametrize( - "level,expected", - [ - (logging.DEBUG, "DEBUG"), - (logging.INFO, "INFO"), - (logging.WARNING, "WARNING"), - (logging.ERROR, "SEVERE"), - (logging.CRITICAL, "FATAL"), - ] -) -def test_level_name_mapping(formatter, level, expected): - record = logging.LogRecord( - name=MODULE_NAME, - level=level, - pathname="test", - lineno=0, - msg="Test message", - args=None, - exc_info=None - ) - formatted = formatter.format(record) - # Level name should appear in formatted string - assert expected in formatted - -def test_utc_timestamp(formatter): - record = logging.LogRecord( - name=MODULE_NAME, - level=logging.INFO, - pathname="test", - lineno=0, - msg="UTC test", - args=None, - exc_info=None - ) - formatted = formatter.format(record) - # Timestamp should be in UTC format "YYYY-MM-DDTHH:MM:SS" - ts_str = formatted.split()[0] - from datetime import datetime - dt = datetime.strptime(ts_str, "%Y-%m-%dT%H:%M:%S") - # It's enough to check it parses without error - -def test_trailing_whitespace_stripped(formatter): - record = logging.LogRecord( - name=MODULE_NAME, - level=logging.INFO, - pathname="test", - lineno=0, - msg="Message with space \n", - args=None, - exc_info=None - ) - formatted = formatter.format(record) - # Trailing whitespace/newline should be removed - assert " \n" not in formatted - assert formatted.endswith("Message with space") diff --git a/tests/lstm_ewts/test_paths.py b/tests/lstm_ewts/test_paths.py deleted file mode 100644 index f0072e4..0000000 --- a/tests/lstm_ewts/test_paths.py +++ /dev/null @@ -1,115 +0,0 @@ -import os -import getpass -from datetime import datetime -import pytest -from lstm_ewts import paths -from lstm_ewts.paths import create_timestamp, get_log_file_path -from lstm_ewts.constants import MODULE_NAME, EV_MODULE_LOGFILEPATH, EV_NGEN_LOGFILEPATH - -# ------------------------------- -# Fixture for a clean log environment -# ------------------------------- -@pytest.fixture -def clean_log_env(tmp_path, monkeypatch): - """Set up a temporary log environment and clean env vars. - - Yields a dict with: - tmp_dir : Path of temporary base directory - monkeypatch : the pytest monkeypatch object for further tweaks - """ - # Clear env vars - monkeypatch.delenv(EV_MODULE_LOGFILEPATH, raising=False) - monkeypatch.delenv(EV_NGEN_LOGFILEPATH, raising=False) - - # Patch constants to use tmp_path - monkeypatch.setattr(paths, "LOG_DIR_NGENCERF", tmp_path) - monkeypatch.setattr(paths, "LOG_DIR_DEFAULT", "run-logs") - - yield {"tmp_dir": tmp_path, "monkeypatch": monkeypatch} - - -# ------------------------------- -# Tests for create_timestamp() -# ------------------------------- -def test_create_timestamp_default(): - ts = create_timestamp() - assert len(ts) >= 15 - assert "T" in ts - -def test_create_timestamp_date_only(): - ts = create_timestamp(date_only=True) - assert len(ts) == 8 - -def test_create_timestamp_iso(): - ts = create_timestamp(iso=True) - assert "T" in ts and "-" in ts and ":" in ts - -def test_create_timestamp_append_ms(): - ts = create_timestamp(append_ms=True) - assert "." in ts - - -# ------------------------------- -# Tests for get_log_file_path() -# ------------------------------- -def test_get_log_file_path_uses_module_env(clean_log_env): - tmp_path = clean_log_env["tmp_dir"] - monkeypatch = clean_log_env["monkeypatch"] - - logfile = tmp_path / "test_module.log" - monkeypatch.setenv(EV_MODULE_LOGFILEPATH, str(logfile)) - - path, append = get_log_file_path() - assert path == str(logfile) - assert append is True - - -def test_get_log_file_path_uses_ngen_env(clean_log_env): - monkeypatch = clean_log_env["monkeypatch"] - tmp_path = clean_log_env["tmp_dir"] - - monkeypatch.delenv(EV_MODULE_LOGFILEPATH, raising=False) - ngen_file = tmp_path / "ngen.log" - monkeypatch.setenv(EV_NGEN_LOGFILEPATH, str(ngen_file)) - - path, append = get_log_file_path() - assert path == str(ngen_file) - assert append is True - - -def test_get_log_file_path_creates_user_subdir(clean_log_env): - tmp_path = clean_log_env["tmp_dir"] - monkeypatch = clean_log_env["monkeypatch"] - - monkeypatch.delenv(EV_MODULE_LOGFILEPATH, raising=False) - monkeypatch.delenv(EV_NGEN_LOGFILEPATH, raising=False) - - # Use real username - monkeypatch.setattr(getpass, "getuser", lambda: "alice") - - path, append = get_log_file_path() - - # Subdirectory should be username - subdir = os.path.basename(os.path.dirname(path)) - assert subdir == "alice" - assert path.endswith(".log") - assert os.path.exists(path) - - -def test_get_log_file_path_fallback_username(clean_log_env): - tmp_path = clean_log_env["tmp_dir"] - monkeypatch = clean_log_env["monkeypatch"] - - monkeypatch.delenv(EV_MODULE_LOGFILEPATH, raising=False) - monkeypatch.delenv(EV_NGEN_LOGFILEPATH, raising=False) - - # Simulate getuser() returning None - monkeypatch.setattr(getpass, "getuser", lambda: None) - - path, append = get_log_file_path() - - subdir = os.path.basename(os.path.dirname(path)) - # Should fall back to YYYYMMDD - assert len(subdir) == 8 and subdir.isdigit() - assert path.endswith(".log") - assert os.path.exists(path)