diff --git a/examples/ezmsg_generator.py b/examples/ezmsg_generator.py index cb58295..065ea65 100644 --- a/examples/ezmsg_generator.py +++ b/examples/ezmsg_generator.py @@ -1,4 +1,9 @@ """ +.. note:: + This example demonstrates legacy generator-based patterns from + ``ezmsg.util.generator``. New code should use ``ezmsg.baseproc`` + (BaseTransformer, BaseProducer, etc.) instead. + This ezmsg example showcases a design pattern where core computational logic is encapsulated within a Python generator. This approach is beneficial for several reasons: @@ -22,14 +27,61 @@ """ import asyncio +import traceback import ezmsg.core as ez import numpy as np from typing import Any -from collections.abc import Generator +from collections.abc import AsyncGenerator, Callable, Generator +from functools import wraps, reduce from ezmsg.util.messages.axisarray import AxisArray, replace from ezmsg.util.debuglog import DebugLog from ezmsg.util.gen_to_unit import gen_to_unit -from ezmsg.util.generator import consumer, compose, Gen + + +# --- Inlined helpers (previously from ezmsg.util.generator, now deprecated) --- + +def consumer(func): + """Prime a generator by advancing it to the first yield statement.""" + @wraps(func) + def wrapper(*args, **kwargs): + gen = func(*args, **kwargs) + next(gen) + return gen + return wrapper + + +def compose(*funcs): + """Compose a chain of generator functions into a single callable.""" + return lambda x: reduce(lambda f, g: g.send(f), list(funcs), x) + + +class GenState(ez.State): + gen: Generator[Any, Any, None] + + +class Gen(ez.Unit): + STATE = GenState + + INPUT = ez.InputStream(Any) + OUTPUT = ez.OutputStream(Any) + + async def initialize(self) -> None: + self.construct_generator() + + def construct_generator(self): + raise NotImplementedError + + @ez.subscriber(INPUT) + @ez.publisher(OUTPUT) + async def on_message(self, message: Any) -> AsyncGenerator: + try: + ret = self.STATE.gen.send(message) + if ret is not None: + yield self.OUTPUT, ret + except (StopIteration, GeneratorExit): + ez.logger.debug(f"Generator closed in {self.address}") + except Exception: + ez.logger.info(traceback.format_exc()) @consumer diff --git a/src/ezmsg/util/__init__.py b/src/ezmsg/util/__init__.py index c722ac6..04a26a3 100644 --- a/src/ezmsg/util/__init__.py +++ b/src/ezmsg/util/__init__.py @@ -6,7 +6,7 @@ Key modules: - :mod:`ezmsg.util.debuglog`: Debug logging utilities -- :mod:`ezmsg.util.generator`: Generator-based message processing decorators +- :mod:`ezmsg.util.generator`: Generator-based message processing decorators (for newer code, use ``ezmsg.baseproc`` instead) - :mod:`ezmsg.util.messagecodec`: JSON encoding/decoding for message logging - :mod:`ezmsg.util.messagegate`: Message flow control and gating - :mod:`ezmsg.util.messagelogger`: File-based message logging diff --git a/src/ezmsg/util/generator.py b/src/ezmsg/util/generator.py index 7dc04ec..53fb3c7 100644 --- a/src/ezmsg/util/generator.py +++ b/src/ezmsg/util/generator.py @@ -1,3 +1,15 @@ +""" +Generator-based message processing decorators (legacy). + +.. note:: + For new code, prefer ``ezmsg.baseproc`` (BaseTransformer, BaseProducer, etc.) + over the generator pattern in this module. The baseproc classes provide a cleaner + separation of state management, hashing, and processing, and are the recommended + approach for all new ezmsg units. + + This module is maintained for backward compatibility and simple stateless use cases. +""" + import ezmsg.core as ez import traceback from collections.abc import AsyncGenerator, Callable, Generator diff --git a/src/ezmsg/util/messages/modify.py b/src/ezmsg/util/messages/modify.py index 49a9bf9..7cf434d 100644 --- a/src/ezmsg/util/messages/modify.py +++ b/src/ezmsg/util/messages/modify.py @@ -1,73 +1,93 @@ -from collections.abc import AsyncGenerator, Generator -import traceback +import warnings +from collections.abc import AsyncGenerator import numpy as np import ezmsg.core as ez from .axisarray import AxisArray, replace -from ..generator import consumer, GenState -@consumer -def modify_axis( - name_map: dict[str, str | None] | None = None, -) -> Generator[AxisArray, AxisArray, None]: +class ModifyAxisSettings(ez.Settings): """ - Modify an AxisArray's axes and dims according to a name_map. + Settings for ModifyAxisTransformer and ModifyAxis unit. :param name_map: A dictionary where the keys are the names of the old dims and the values are the new names. Use None as a value to drop the dimension. If the dropped dimension is not len==1 then an error is raised. :type name_map: dict[str, str | None] | None - :return: A primed generator object ready to yield an AxisArray with modified axes for each .send(axis_array). - :rtype: collections.abc.Generator[AxisArray, AxisArray, None] """ - # State variables - axis_arr_in = AxisArray(np.array([]), dims=[""]) - axis_arr_out = AxisArray(np.array([]), dims=[""]) - - while True: - axis_arr_in = yield axis_arr_out - - if name_map is not None: - new_dims = [name_map.get(old_k, old_k) for old_k in axis_arr_in.dims] - new_axes = { - name_map.get(old_k, old_k): v for old_k, v in axis_arr_in.axes.items() - } - drop_ax_ix = [ - ix - for ix, old_dim in enumerate(axis_arr_in.dims) - if new_dims[ix] is None - ] - if len(drop_ax_ix) > 0: - # Rewrite new_dims and new_axes without the dropped axes - new_dims = [_ for _ in new_dims if _ is not None] - new_axes.pop(None, None) - # Reshape data - # np.squeeze will raise ValueError if not len==1 - axis_arr_out = replace( - axis_arr_in, - data=np.squeeze(axis_arr_in.data, axis=tuple(drop_ax_ix)), - dims=new_dims, - axes=new_axes, - ) - else: - axis_arr_out = replace(axis_arr_in, dims=new_dims, axes=new_axes) - else: - axis_arr_out = axis_arr_in + name_map: dict[str, str | None] | None = None -class ModifyAxisSettings(ez.Settings): + +class ModifyAxisTransformer: + """ + Modify an AxisArray's axes and dims according to a name_map. + + This is a stateless transformer that renames dimensions and axes, with support + for dropping length-1 dimensions. + + :param settings: Settings for the transformer. + :type settings: ModifyAxisSettings + """ + + def __init__(self, settings: ModifyAxisSettings = ModifyAxisSettings()): + self.settings = settings + + def __call__(self, message: AxisArray) -> AxisArray: + name_map = self.settings.name_map + if name_map is None: + return message + + new_dims = [name_map.get(old_k, old_k) for old_k in message.dims] + new_axes = { + name_map.get(old_k, old_k): v for old_k, v in message.axes.items() + } + drop_ax_ix = [ + ix + for ix, old_dim in enumerate(message.dims) + if new_dims[ix] is None + ] + if len(drop_ax_ix) > 0: + new_dims = [d for d in new_dims if d is not None] + new_axes.pop(None, None) + return replace( + message, + data=np.squeeze(message.data, axis=tuple(drop_ax_ix)), + dims=new_dims, + axes=new_axes, + ) + return replace(message, dims=new_dims, axes=new_axes) + + _send_warned: bool = False + + def send(self, message: AxisArray) -> AxisArray: + """Alias for __call__. Deprecated — use ``__call__`` directly instead.""" + if not ModifyAxisTransformer._send_warned: + ModifyAxisTransformer._send_warned = True + warnings.warn( + "ModifyAxisTransformer.send() is deprecated. Use __call__() instead.", + DeprecationWarning, + stacklevel=2, + ) + return self(message) + + +def modify_axis( + name_map: dict[str, str | None] | None = None, +) -> ModifyAxisTransformer: """ - Settings for ModifyAxis unit. + Create a :class:`ModifyAxisTransformer` instance. - Configuration for modifying axis names and dimensions of AxisArray messages. + This function exists for backward compatibility with code that previously used + the generator-based ``modify_axis``. The returned object supports ``.send(msg)``. :param name_map: A dictionary where the keys are the names of the old dims and the values are the new names. Use None as a value to drop the dimension. If the dropped dimension is not len==1 then an error is raised. :type name_map: dict[str, str | None] | None + :return: A ModifyAxisTransformer instance. + :rtype: ModifyAxisTransformer """ - - name_map: dict[str, str | None] | None = None + return ModifyAxisTransformer(ModifyAxisSettings(name_map=name_map)) class ModifyAxis(ez.Unit): @@ -75,62 +95,28 @@ class ModifyAxis(ez.Unit): Unit for modifying axis names and dimensions of AxisArray messages. Renames dimensions and axes according to a name mapping, with support for - dropping dimensions. Uses zero-copy operations for efficient processing. + dropping dimensions. """ - STATE = GenState SETTINGS = ModifyAxisSettings INPUT_SIGNAL = ez.InputStream(AxisArray) OUTPUT_SIGNAL = ez.OutputStream(AxisArray) INPUT_SETTINGS = ez.InputStream(ModifyAxisSettings) - async def initialize(self) -> None: - """ - Initialize the ModifyAxis unit. - - Sets up the generator for axis modification operations. - """ - self.construct_generator() + _transformer: ModifyAxisTransformer - def construct_generator(self): - """ - Construct the axis-modifying generator with current settings. - - Creates a new modify_axis generator instance using the unit's name mapping. - """ - self.STATE.gen = modify_axis(name_map=self.SETTINGS.name_map) + async def initialize(self) -> None: + self._transformer = ModifyAxisTransformer(self.SETTINGS) @ez.subscriber(INPUT_SETTINGS) async def on_settings(self, msg: ez.Settings) -> None: - """ - Handle incoming settings updates. - - :param msg: New settings to apply. - :type msg: ez.Settings - """ self.apply_settings(msg) - self.construct_generator() + self._transformer = ModifyAxisTransformer(self.SETTINGS) @ez.subscriber(INPUT_SIGNAL, zero_copy=True) @ez.publisher(OUTPUT_SIGNAL) async def on_message(self, message: AxisArray) -> AsyncGenerator: - """ - Process incoming AxisArray messages and modify their axes. - - Uses zero-copy operations to efficiently modify axis names and dimensions - while preserving data integrity. - - :param message: Input AxisArray to modify. - :type message: AxisArray - :return: Async generator yielding AxisArray with modified axes. - :rtype: collections.abc.AsyncGenerator - """ - try: - ret = self.STATE.gen.send(message) - if ret is not None: - yield self.OUTPUT_SIGNAL, ret - except (StopIteration, GeneratorExit): - ez.logger.debug(f"ModifyAxis closed in {self.address}") - except Exception: - ez.logger.info(traceback.format_exc()) + ret = self._transformer(message) + if ret is not None: + yield self.OUTPUT_SIGNAL, ret diff --git a/tests/test_generator.py b/tests/test_generator.py index 7736567..0c01965 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -1,3 +1,6 @@ +# NOTE: This file tests the legacy ezmsg.util.generator module. +# While this pattern is typically not recommended, it has some narrow use cases. +# The @consumer usages here are the subject of the tests and should remain. from collections.abc import AsyncGenerator, Generator import copy import json