Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 54 additions & 2 deletions examples/ezmsg_generator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
"""
.. note::
This example demonstrates legacy generator-based patterns from
``ezmsg.util.generator``. New code should use ``ezmsg.baseproc``
(BaseTransformer, BaseProducer, etc.) instead.

This ezmsg example showcases a design pattern where core computational
logic is encapsulated within a Python generator. This approach is
beneficial for several reasons:
Expand All @@ -22,14 +27,61 @@
"""

import asyncio
import traceback
import ezmsg.core as ez
import numpy as np
from typing import Any
from collections.abc import Generator
from collections.abc import AsyncGenerator, Callable, Generator
from functools import wraps, reduce
from ezmsg.util.messages.axisarray import AxisArray, replace
from ezmsg.util.debuglog import DebugLog
from ezmsg.util.gen_to_unit import gen_to_unit
from ezmsg.util.generator import consumer, compose, Gen


# --- Inlined helpers (previously from ezmsg.util.generator, now deprecated) ---

def consumer(func):
"""Prime a generator by advancing it to the first yield statement."""
@wraps(func)
def wrapper(*args, **kwargs):
gen = func(*args, **kwargs)
next(gen)
return gen
return wrapper


def compose(*funcs):
"""Compose a chain of generator functions into a single callable."""
return lambda x: reduce(lambda f, g: g.send(f), list(funcs), x)


class GenState(ez.State):
gen: Generator[Any, Any, None]


class Gen(ez.Unit):
STATE = GenState

INPUT = ez.InputStream(Any)
OUTPUT = ez.OutputStream(Any)

async def initialize(self) -> None:
self.construct_generator()

def construct_generator(self):
raise NotImplementedError

@ez.subscriber(INPUT)
@ez.publisher(OUTPUT)
async def on_message(self, message: Any) -> AsyncGenerator:
try:
ret = self.STATE.gen.send(message)
if ret is not None:
yield self.OUTPUT, ret
except (StopIteration, GeneratorExit):
ez.logger.debug(f"Generator closed in {self.address}")
except Exception:
ez.logger.info(traceback.format_exc())


@consumer
Expand Down
2 changes: 1 addition & 1 deletion src/ezmsg/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

Key modules:
- :mod:`ezmsg.util.debuglog`: Debug logging utilities
- :mod:`ezmsg.util.generator`: Generator-based message processing decorators
- :mod:`ezmsg.util.generator`: Generator-based message processing decorators (for newer code, use ``ezmsg.baseproc`` instead)
- :mod:`ezmsg.util.messagecodec`: JSON encoding/decoding for message logging
- :mod:`ezmsg.util.messagegate`: Message flow control and gating
- :mod:`ezmsg.util.messagelogger`: File-based message logging
Expand Down
12 changes: 12 additions & 0 deletions src/ezmsg/util/generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
"""
Generator-based message processing decorators (legacy).

.. note::
For new code, prefer ``ezmsg.baseproc`` (BaseTransformer, BaseProducer, etc.)
over the generator pattern in this module. The baseproc classes provide a cleaner
separation of state management, hashing, and processing, and are the recommended
approach for all new ezmsg units.

This module is maintained for backward compatibility and simple stateless use cases.
"""

import ezmsg.core as ez
import traceback
from collections.abc import AsyncGenerator, Callable, Generator
Expand Down
166 changes: 76 additions & 90 deletions src/ezmsg/util/messages/modify.py
Original file line number Diff line number Diff line change
@@ -1,136 +1,122 @@
from collections.abc import AsyncGenerator, Generator
import traceback
import warnings
from collections.abc import AsyncGenerator

import numpy as np

import ezmsg.core as ez
from .axisarray import AxisArray, replace
from ..generator import consumer, GenState


@consumer
def modify_axis(
name_map: dict[str, str | None] | None = None,
) -> Generator[AxisArray, AxisArray, None]:
class ModifyAxisSettings(ez.Settings):
"""
Modify an AxisArray's axes and dims according to a name_map.
Settings for ModifyAxisTransformer and ModifyAxis unit.

:param name_map: A dictionary where the keys are the names of the old dims and the values are the new names.
Use None as a value to drop the dimension. If the dropped dimension is not len==1 then an error is raised.
:type name_map: dict[str, str | None] | None
:return: A primed generator object ready to yield an AxisArray with modified axes for each .send(axis_array).
:rtype: collections.abc.Generator[AxisArray, AxisArray, None]
"""
# State variables
axis_arr_in = AxisArray(np.array([]), dims=[""])
axis_arr_out = AxisArray(np.array([]), dims=[""])

while True:
axis_arr_in = yield axis_arr_out

if name_map is not None:
new_dims = [name_map.get(old_k, old_k) for old_k in axis_arr_in.dims]
new_axes = {
name_map.get(old_k, old_k): v for old_k, v in axis_arr_in.axes.items()
}
drop_ax_ix = [
ix
for ix, old_dim in enumerate(axis_arr_in.dims)
if new_dims[ix] is None
]
if len(drop_ax_ix) > 0:
# Rewrite new_dims and new_axes without the dropped axes
new_dims = [_ for _ in new_dims if _ is not None]
new_axes.pop(None, None)
# Reshape data
# np.squeeze will raise ValueError if not len==1
axis_arr_out = replace(
axis_arr_in,
data=np.squeeze(axis_arr_in.data, axis=tuple(drop_ax_ix)),
dims=new_dims,
axes=new_axes,
)
else:
axis_arr_out = replace(axis_arr_in, dims=new_dims, axes=new_axes)
else:
axis_arr_out = axis_arr_in

name_map: dict[str, str | None] | None = None

class ModifyAxisSettings(ez.Settings):

class ModifyAxisTransformer:
"""
Modify an AxisArray's axes and dims according to a name_map.

This is a stateless transformer that renames dimensions and axes, with support
for dropping length-1 dimensions.

:param settings: Settings for the transformer.
:type settings: ModifyAxisSettings
"""

def __init__(self, settings: ModifyAxisSettings = ModifyAxisSettings()):
self.settings = settings

def __call__(self, message: AxisArray) -> AxisArray:
name_map = self.settings.name_map
if name_map is None:
return message

new_dims = [name_map.get(old_k, old_k) for old_k in message.dims]
new_axes = {
name_map.get(old_k, old_k): v for old_k, v in message.axes.items()
}
drop_ax_ix = [
ix
for ix, old_dim in enumerate(message.dims)
if new_dims[ix] is None
]
if len(drop_ax_ix) > 0:
new_dims = [d for d in new_dims if d is not None]
new_axes.pop(None, None)
return replace(
message,
data=np.squeeze(message.data, axis=tuple(drop_ax_ix)),
dims=new_dims,
axes=new_axes,
)
return replace(message, dims=new_dims, axes=new_axes)

_send_warned: bool = False

def send(self, message: AxisArray) -> AxisArray:
"""Alias for __call__. Deprecated — use ``__call__`` directly instead."""
if not ModifyAxisTransformer._send_warned:
ModifyAxisTransformer._send_warned = True
warnings.warn(
"ModifyAxisTransformer.send() is deprecated. Use __call__() instead.",
DeprecationWarning,
stacklevel=2,
)
return self(message)


def modify_axis(
name_map: dict[str, str | None] | None = None,
) -> ModifyAxisTransformer:
"""
Settings for ModifyAxis unit.
Create a :class:`ModifyAxisTransformer` instance.

Configuration for modifying axis names and dimensions of AxisArray messages.
This function exists for backward compatibility with code that previously used
the generator-based ``modify_axis``. The returned object supports ``.send(msg)``.

:param name_map: A dictionary where the keys are the names of the old dims and the values are the new names.
Use None as a value to drop the dimension. If the dropped dimension is not len==1 then an error is raised.
:type name_map: dict[str, str | None] | None
:return: A ModifyAxisTransformer instance.
:rtype: ModifyAxisTransformer
"""

name_map: dict[str, str | None] | None = None
return ModifyAxisTransformer(ModifyAxisSettings(name_map=name_map))


class ModifyAxis(ez.Unit):
"""
Unit for modifying axis names and dimensions of AxisArray messages.

Renames dimensions and axes according to a name mapping, with support for
dropping dimensions. Uses zero-copy operations for efficient processing.
dropping dimensions.
"""

STATE = GenState
SETTINGS = ModifyAxisSettings

INPUT_SIGNAL = ez.InputStream(AxisArray)
OUTPUT_SIGNAL = ez.OutputStream(AxisArray)
INPUT_SETTINGS = ez.InputStream(ModifyAxisSettings)

async def initialize(self) -> None:
"""
Initialize the ModifyAxis unit.

Sets up the generator for axis modification operations.
"""
self.construct_generator()
_transformer: ModifyAxisTransformer

def construct_generator(self):
"""
Construct the axis-modifying generator with current settings.

Creates a new modify_axis generator instance using the unit's name mapping.
"""
self.STATE.gen = modify_axis(name_map=self.SETTINGS.name_map)
async def initialize(self) -> None:
self._transformer = ModifyAxisTransformer(self.SETTINGS)

@ez.subscriber(INPUT_SETTINGS)
async def on_settings(self, msg: ez.Settings) -> None:
"""
Handle incoming settings updates.

:param msg: New settings to apply.
:type msg: ez.Settings
"""
self.apply_settings(msg)
self.construct_generator()
self._transformer = ModifyAxisTransformer(self.SETTINGS)

@ez.subscriber(INPUT_SIGNAL, zero_copy=True)
@ez.publisher(OUTPUT_SIGNAL)
async def on_message(self, message: AxisArray) -> AsyncGenerator:
"""
Process incoming AxisArray messages and modify their axes.

Uses zero-copy operations to efficiently modify axis names and dimensions
while preserving data integrity.

:param message: Input AxisArray to modify.
:type message: AxisArray
:return: Async generator yielding AxisArray with modified axes.
:rtype: collections.abc.AsyncGenerator
"""
try:
ret = self.STATE.gen.send(message)
if ret is not None:
yield self.OUTPUT_SIGNAL, ret
except (StopIteration, GeneratorExit):
ez.logger.debug(f"ModifyAxis closed in {self.address}")
except Exception:
ez.logger.info(traceback.format_exc())
ret = self._transformer(message)
if ret is not None:
yield self.OUTPUT_SIGNAL, ret
3 changes: 3 additions & 0 deletions tests/test_generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# NOTE: This file tests the legacy ezmsg.util.generator module.
# While this pattern is typically not recommended, it has some narrow use cases.
# The @consumer usages here are the subject of the tests and should remain.
from collections.abc import AsyncGenerator, Generator
import copy
import json
Expand Down