diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index fb09307..673deb6 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -19,7 +19,7 @@ from typing import Any from .stream import Stream, InputStream, OutputStream -from .unit import Unit, TIMEIT_ATTR, SUBSCRIBES_ATTR, ZERO_COPY_ATTR +from .unit import Unit, TIMEIT_ATTR, SUBSCRIBES_ATTR from .messagechannel import LeakyQueue from .graphcontext import GraphContext @@ -374,8 +374,6 @@ async def wrapped_task(msg: Any = None) -> None: result = call_fn(msg) if inspect.isasyncgen(result): async for stream, obj in result: - if obj and getattr(task, ZERO_COPY_ATTR, False) and obj is msg: - obj = deepcopy(obj) await pub_fn(stream, obj) elif asyncio.iscoroutine(result): diff --git a/src/ezmsg/core/unit.py b/src/ezmsg/core/unit.py index f4a77cf..8527e71 100644 --- a/src/ezmsg/core/unit.py +++ b/src/ezmsg/core/unit.py @@ -1,6 +1,7 @@ import time import inspect import functools +import warnings from .stream import InputStream, OutputStream from .component import ComponentMeta, Component from .settings import Settings @@ -23,6 +24,8 @@ LEAKY_ATTR = "__ez_leaky__" MAX_QUEUE_ATTR = "__ez_max_queue__" +_ZERO_COPY_SENTINEL = object() + class UnitMeta(ComponentMeta): def __init__( @@ -162,17 +165,18 @@ def pub_factory(func): return pub_factory -def subscriber(stream: InputStream, zero_copy: bool = False): +def subscriber(stream: InputStream, zero_copy: Any = _ZERO_COPY_SENTINEL): """ A decorator for a method that subscribes to a stream in the task/messaging thread. An async function will run once per message received from the :obj:`InputStream` it subscribes to. A function can have both ``@subscriber`` and ``@publisher`` decorators. + The ``zero_copy`` argument is deprecated and ignored. Subscribers always receive + zero-copy messages, so callers can omit it. + :param stream: The input stream to subscribe to :type stream: InputStream - :param zero_copy: Whether to use zero-copy message passing (default: False) - :type zero_copy: bool :return: Decorated function that can subscribe to the stream :rtype: collections.abc.Callable :raises ValueError: If stream is not an InputStream @@ -183,20 +187,30 @@ def subscriber(stream: InputStream, zero_copy: bool = False): INPUT = ez.InputStream(Message) - @subscriber(INPUT) - async def print_message(self, message: Message) -> None: - print(message) + @subscriber(INPUT) + async def print_message(self, message: Message) -> None: + print(message) """ if not isinstance(stream, InputStream): raise ValueError(f"Cannot subscribe to object of type {type(stream)}") + if zero_copy is not _ZERO_COPY_SENTINEL: + warnings.warn( + "The `zero_copy` argument to @subscriber is deprecated and ignored. " + "Zero-copy behavior is now determined by the InputStream's `leaky` property " + "(non-leaky subscribers use zero-copy; leaky subscribers receive deep-copied " + "messages). Remove any explicit `zero_copy=...` usage.", + DeprecationWarning, + stacklevel=2, + ) + def sub_factory(func): subscribed_streams: InputStream | None = getattr(func, SUBSCRIBES_ATTR, None) if subscribed_streams is not None: raise Exception(f"{func} cannot subscribe to more than one stream") setattr(func, SUBSCRIBES_ATTR, stream) - setattr(func, ZERO_COPY_ATTR, zero_copy) + setattr(func, ZERO_COPY_ATTR, True) setattr(func, LEAKY_ATTR, stream.leaky) setattr(func, MAX_QUEUE_ATTR, stream.max_queue) return task(func) diff --git a/src/ezmsg/util/debuglog.py b/src/ezmsg/util/debuglog.py index 3790588..3c714af 100644 --- a/src/ezmsg/util/debuglog.py +++ b/src/ezmsg/util/debuglog.py @@ -32,7 +32,7 @@ class DebugLog(ez.Unit): OUTPUT = ez.OutputStream(Any) """Send messages back out to continue through the graph.""" - @ez.subscriber(INPUT, zero_copy=True) + @ez.subscriber(INPUT) @ez.publisher(OUTPUT) async def log(self, msg: Any) -> AsyncGenerator: logstr = f"{self.SETTINGS.name} - {msg=}" diff --git a/src/ezmsg/util/messages/key.py b/src/ezmsg/util/messages/key.py index 1850f05..d3227fd 100644 --- a/src/ezmsg/util/messages/key.py +++ b/src/ezmsg/util/messages/key.py @@ -83,7 +83,7 @@ async def on_settings(self, msg: ez.Settings) -> None: self.apply_settings(msg) self.construct_generator() - @ez.subscriber(INPUT_SIGNAL, zero_copy=True) + @ez.subscriber(INPUT_SIGNAL) @ez.publisher(OUTPUT_SIGNAL) async def on_message(self, message: AxisArray) -> AsyncGenerator: """ @@ -125,7 +125,7 @@ class FilterOnKey(ez.Unit): INPUT_SIGNAL = ez.InputStream(AxisArray) OUTPUT_SIGNAL = ez.OutputStream(AxisArray) - @ez.subscriber(INPUT_SIGNAL, zero_copy=True) + @ez.subscriber(INPUT_SIGNAL) @ez.publisher(OUTPUT_SIGNAL) async def on_message(self, message: AxisArray) -> AsyncGenerator: """ diff --git a/src/ezmsg/util/messages/modify.py b/src/ezmsg/util/messages/modify.py index 7cf434d..0f76733 100644 --- a/src/ezmsg/util/messages/modify.py +++ b/src/ezmsg/util/messages/modify.py @@ -114,7 +114,7 @@ async def on_settings(self, msg: ez.Settings) -> None: self.apply_settings(msg) self._transformer = ModifyAxisTransformer(self.SETTINGS) - @ez.subscriber(INPUT_SIGNAL, zero_copy=True) + @ez.subscriber(INPUT_SIGNAL) @ez.publisher(OUTPUT_SIGNAL) async def on_message(self, message: AxisArray) -> AsyncGenerator: ret = self._transformer(message) diff --git a/src/ezmsg/util/perf/impl.py b/src/ezmsg/util/perf/impl.py index 0e8460c..6bc83d1 100644 --- a/src/ezmsg/util/perf/impl.py +++ b/src/ezmsg/util/perf/impl.py @@ -130,7 +130,7 @@ class LoadTestRelay(ez.Unit): INPUT = ez.InputStream(LoadTestSample) OUTPUT = ez.OutputStream(LoadTestSample) - @ez.subscriber(INPUT, zero_copy=True) + @ez.subscriber(INPUT) @ez.publisher(OUTPUT) async def on_msg(self, msg: LoadTestSample) -> typing.AsyncGenerator: yield self.OUTPUT, msg @@ -152,7 +152,7 @@ class LoadTestReceiver(ez.Unit): async def initialize(self) -> None: ez.logger.info(f"Load test subscriber started. (PID: {os.getpid()})") - @ez.subscriber(INPUT, zero_copy=True) + @ez.subscriber(INPUT) async def receive(self, sample: LoadTestSample) -> None: counter = self.STATE.counters.get(sample.key, -1) if sample.counter != counter + 1: @@ -166,7 +166,7 @@ async def receive(self, sample: LoadTestSample) -> None: class LoadTestSink(LoadTestReceiver): INPUT = ez.InputStream(LoadTestSample) - @ez.subscriber(INPUT, zero_copy=True) + @ez.subscriber(INPUT) async def receive(self, sample: LoadTestSample) -> None: await super().receive(sample) if len(self.STATE.received_data) == self.SETTINGS.num_msgs: