diff --git a/examples/lowlevel_api.py b/examples/lowlevel_api.py deleted file mode 100644 index 9bc5d58..0000000 --- a/examples/lowlevel_api.py +++ /dev/null @@ -1,119 +0,0 @@ -import asyncio - -import ezmsg.core as ez - -PORT = 12345 -MAX_COUNT = 100 -TOPIC = "/TEST" - - -async def handle_pub(pub: ez.Publisher) -> None: - print("Publisher Task Launched") - - count = 0 - - while True: - await pub.broadcast(f"{count=}") - await asyncio.sleep(0.1) - count += 1 - if count >= MAX_COUNT: - break - - print("Publisher Task Concluded") - - -async def handle_sub(sub: ez.Subscriber) -> None: - print("Subscriber Task Launched") - - rx_count = 0 - while True: - async with sub.recv_zero_copy() as msg: - # Uncomment if you want to witness backpressure! - # await asyncio.sleep(0.15) - print(msg) - - rx_count += 1 - if rx_count >= MAX_COUNT: - break - - print("Subscriber Task Concluded") - - -async def host(host: str = "127.0.0.1"): - # Manually create a GraphServer - server = ez.GraphServer() - server.start((host, PORT)) - - print(f"Created GraphServer @ {server.address}") - - try: - test_pub = await ez.Publisher.create(TOPIC, (host, PORT), host=host) - test_sub1 = await ez.Subscriber.create(TOPIC, (host, PORT)) - test_sub2 = await ez.Subscriber.create(TOPIC, (host, PORT)) - - await asyncio.sleep(1.0) - - pub_task = asyncio.Task(handle_pub(test_pub)) - sub_task_1 = asyncio.Task(handle_sub(test_sub1)) - sub_task_2 = asyncio.Task(handle_sub(test_sub2)) - - await asyncio.wait([pub_task, sub_task_1, sub_task_2]) - - test_pub.close() - test_sub1.close() - test_sub2.close() - - for future in asyncio.as_completed( - [ - test_pub.wait_closed(), - test_sub1.wait_closed(), - test_sub2.wait_closed(), - ] - ): - await future - - finally: - server.stop() - - print("Done") - - -async def attach_client(host: str = "127.0.0.1"): - - sub = await ez.Subscriber.create(TOPIC, (host, PORT)) - - try: - while True: - async with sub.recv_zero_copy() as msg: - # Uncomment if you want to see EXTREME backpressure! - # await asyncio.sleep(1.0) - print(msg) - - except asyncio.CancelledError: - pass - - finally: - sub.close() - await sub.wait_closed() - print("Detached") - - -if __name__ == "__main__": - from dataclasses import dataclass - from argparse import ArgumentParser - - parser = ArgumentParser() - parser.add_argument("--attach", action="store_true", help="attach to running graph") - parser.add_argument("--host", default="0.0.0.0", help="hostname for graphserver") - - @dataclass - class Args: - attach: bool - host: str - - args = Args(**vars(parser.parse_args())) - - if args.attach: - asyncio.run(attach_client(host=args.host)) - else: - asyncio.run(host(host=args.host)) diff --git a/examples/simple_async_publisher.py b/examples/simple_async_publisher.py new file mode 100644 index 0000000..a3a886b --- /dev/null +++ b/examples/simple_async_publisher.py @@ -0,0 +1,33 @@ +import asyncio + +import ezmsg.core as ez + +TOPIC = "/TEST" + + +async def main(host: str = "127.0.0.1", port: int = 12345) -> None: + async with ez.GraphContext((host, port), auto_start=True) as ctx: + pub = await ctx.publisher(TOPIC) + try: + print("Publisher Task Launched") + count = 0 + while True: + await pub.broadcast(f"{count=}") + await asyncio.sleep(0.1) + count += 1 + except asyncio.CancelledError: + pass + finally: + print("Publisher Task Concluded") + + +if __name__ == "__main__": + from argparse import ArgumentParser + + parser = ArgumentParser() + parser.add_argument("--host", default="127.0.0.1", help="hostname for graphserver") + parser.add_argument("--port", default=12345, type=int, help="port for graphserver") + + args = parser.parse_args() + + asyncio.run(main(host=args.host, port=args.port)) diff --git a/examples/simple_async_subscriber.py b/examples/simple_async_subscriber.py new file mode 100644 index 0000000..d2a382c --- /dev/null +++ b/examples/simple_async_subscriber.py @@ -0,0 +1,35 @@ +import asyncio + +import ezmsg.core as ez + +PORT = 12345 +TOPIC = "/TEST" + + +async def main(host: str = "127.0.0.1", port: int = 12345) -> None: + async with ez.GraphContext((host, port), auto_start=True) as ctx: + sub = await ctx.subscriber(TOPIC) + try: + print("Subscriber Task Launched") + while True: + async with sub.recv_zero_copy() as msg: + # Uncomment if you want to witness backpressure! + # await asyncio.sleep(1.0) + print(msg) + except asyncio.CancelledError: + pass + finally: + print("Subscriber Task Concluded") + print("Detached") + + +if __name__ == "__main__": + from argparse import ArgumentParser + + parser = ArgumentParser() + parser.add_argument("--host", default="127.0.0.1", help="hostname for graphserver") + parser.add_argument("--port", default=12345, type=int, help="port for graphserver") + + args = parser.parse_args() + + asyncio.run(main(host=args.host, port=args.port)) diff --git a/examples/simple_publisher.py b/examples/simple_publisher.py new file mode 100644 index 0000000..5454787 --- /dev/null +++ b/examples/simple_publisher.py @@ -0,0 +1,35 @@ +import time + +import ezmsg.core as ez + +TOPIC = "/TEST" + +def main(host: str = "127.0.0.1", port: int = 12345) -> None: + with ez.sync.init((host, port), auto_start=True) as ctx: + pub = ctx.create_publisher(TOPIC, force_tcp=True) + + print("Publisher Task Launched") + count = 0 + try: + while True: + output = f"{count=}" + pub.publish(output) + print(output) + time.sleep(0.1) + count += 1 + except KeyboardInterrupt: + pass + print("Publisher Task Concluded") + + print("Done") + + +if __name__ == "__main__": + from argparse import ArgumentParser + + parser = ArgumentParser() + parser.add_argument("--host", default="127.0.0.1", help="hostname for graphserver") + parser.add_argument("--port", default=12345, type=int, help="port for graphserver") + args = parser.parse_args() + + main(host=args.host, port=args.port) diff --git a/examples/simple_subscriber.py b/examples/simple_subscriber.py new file mode 100644 index 0000000..dc7bf9e --- /dev/null +++ b/examples/simple_subscriber.py @@ -0,0 +1,30 @@ +import time +import ezmsg.core as ez + +TOPIC = "/TEST" + + +def main(host: str = "127.0.0.1", port: int = 12345) -> None: + with ez.sync.init((host, port), auto_start=True) as ctx: + print("Subscriber Task Launched") + + def on_message(msg: str) -> None: + # Uncomment if you want to witness backpressure! + # time.sleep(1.0) + print(msg) + + ctx.create_subscription(TOPIC, callback=on_message) + ez.sync.spin(ctx) + + print("Subscriber Task Concluded") + + +if __name__ == "__main__": + from argparse import ArgumentParser + + parser = ArgumentParser() + parser.add_argument("--host", default="127.0.0.1", help="hostname for graphserver") + parser.add_argument("--port", default=12345, type=int, help="port for graphserver") + args = parser.parse_args() + + main(host=args.host, port=args.port) diff --git a/src/ezmsg/core/__init__.py b/src/ezmsg/core/__init__.py index f377cb3..bc56a3b 100644 --- a/src/ezmsg/core/__init__.py +++ b/src/ezmsg/core/__init__.py @@ -25,6 +25,10 @@ "NormalTermination", "GraphServer", "GraphContext", + "sync", + "SyncContext", + "SyncPublisher", + "SyncSubscriber", "run_command", "Publisher", "Subscriber", @@ -45,6 +49,8 @@ from .backendprocess import Complete, NormalTermination from .graphserver import GraphServer from .graphcontext import GraphContext +from . import sync +from .sync import SyncContext, SyncPublisher, SyncSubscriber from .command import run_command from .pubclient import Publisher from .subclient import Subscriber diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index f773569..1d62c2b 100644 --- a/src/ezmsg/core/graphcontext.py +++ b/src/ezmsg/core/graphcontext.py @@ -25,6 +25,10 @@ class GraphContext: :param graph_service: Optional graph service instance to use :type graph_service: GraphService | None + :param auto_start: Whether to auto-start a GraphServer if connection fails. + If None, defaults to auto-start only when graph_address is not provided + and no environment override is set. + :type auto_start: bool | None .. note:: The GraphContext is typically managed automatically by the ezmsg runtime @@ -40,11 +44,13 @@ class GraphContext: def __init__( self, graph_address: AddressType | None = None, + auto_start: bool | None = None, ) -> None: self._clients = set() self._edges = set() self._graph_address = graph_address self._graph_server = None + self._auto_start = auto_start @property def graph_address(self) -> AddressType | None: @@ -130,7 +136,9 @@ async def resume(self) -> None: await GraphService(self.graph_address).resume() async def _ensure_servers(self) -> None: - self._graph_server = await GraphService(self.graph_address).ensure() + self._graph_server = await GraphService(self.graph_address).ensure( + auto_start=self._auto_start + ) async def _shutdown_servers(self) -> None: if self._graph_server is not None: diff --git a/src/ezmsg/core/graphserver.py b/src/ezmsg/core/graphserver.py index cb97b9e..f6d2d2d 100644 --- a/src/ezmsg/core/graphserver.py +++ b/src/ezmsg/core/graphserver.py @@ -95,6 +95,7 @@ def start(self, address: AddressType | None = None) -> None: # type: ignore[ove self._loop = asyncio.new_event_loop() super().start() self._server_up.wait() + logger.info(f'Started GraphServer at {address}') def stop(self) -> None: self._shutdown.set() @@ -453,15 +454,18 @@ def create_server(self) -> GraphServer: self._address = server.address return server - async def ensure(self) -> GraphServer | None: + async def ensure(self, auto_start: bool | None = None) -> GraphServer | None: """ Try connecting to an existing server. If none is listening and no explicit address/environment is set, start one and return it. If an existing one is - found, return None. + found, return None. If auto_start is provided, it overrides the default + behavior. """ server = None ensure_server = False - if self._address is None: + if auto_start is not None: + ensure_server = auto_start + elif self._address is None: # Only auto-start if env var not forcing a location ensure_server = self.ADDR_ENV not in os.environ diff --git a/src/ezmsg/core/sync.py b/src/ezmsg/core/sync.py new file mode 100644 index 0000000..4ed6447 --- /dev/null +++ b/src/ezmsg/core/sync.py @@ -0,0 +1,426 @@ +from __future__ import annotations + +import asyncio +import logging +import threading +from concurrent.futures import TimeoutError as FutureTimeoutError +from copy import deepcopy +from typing import Any, Callable, Iterable + +from .backendprocess import new_threaded_event_loop +from .graphcontext import GraphContext +from .messagecache import CacheMiss +from .netprotocol import AddressType +from .pubclient import Publisher +from .subclient import Subscriber + +logger = logging.getLogger("ezmsg") + + +def _future_result(future, timeout: float | None): + try: + return future.result(timeout=timeout) + except FutureTimeoutError as exc: + future.cancel() + raise TimeoutError("Timed out waiting for async operation") from exc + + +class SyncPublisher: + """Synchronous wrapper around ezmsg Publisher.""" + + def __init__(self, pub: Publisher, loop: asyncio.AbstractEventLoop) -> None: + self._pub = pub + self._loop = loop + + def publish(self, obj: Any, timeout: float | None = None) -> None: + fut = asyncio.run_coroutine_threadsafe(self._pub.broadcast(obj), self._loop) + _future_result(fut, timeout) + + def broadcast(self, obj: Any, timeout: float | None = None) -> None: + self.publish(obj, timeout=timeout) + + def sync(self, timeout: float | None = None) -> None: + fut = asyncio.run_coroutine_threadsafe(self._pub.sync(), self._loop) + _future_result(fut, timeout) + + def pause(self) -> None: + self._pub.pause() + + def resume(self) -> None: + self._pub.resume() + + def close(self) -> None: + self._pub.close() + + def wait_closed(self, timeout: float | None = None) -> None: + fut = asyncio.run_coroutine_threadsafe(self._pub.wait_closed(), self._loop) + _future_result(fut, timeout) + + +class _SyncZeroCopy: + def __init__( + self, + sub: Subscriber, + loop: asyncio.AbstractEventLoop, + timeout: float | None, + ) -> None: + self._sub = sub + self._loop = loop + self._timeout = timeout + self._cm = None + + def __enter__(self) -> Any: + self._cm = self._sub.recv_zero_copy() + fut = asyncio.run_coroutine_threadsafe(self._cm.__aenter__(), self._loop) + return _future_result(fut, self._timeout) + + def __exit__(self, exc_type, exc, tb) -> bool: + assert self._cm is not None + fut = asyncio.run_coroutine_threadsafe( + self._cm.__aexit__(exc_type, exc, tb), self._loop + ) + return bool(_future_result(fut, self._timeout)) + + +class SyncSubscriber: + """Synchronous wrapper around ezmsg Subscriber.""" + + def __init__( + self, + sub: Subscriber, + loop: asyncio.AbstractEventLoop, + ) -> None: + self._sub = sub + self._loop = loop + + def recv(self, timeout: float | None = None) -> Any: + fut = asyncio.run_coroutine_threadsafe(self._sub.recv(), self._loop) + return _future_result(fut, timeout) + + def recv_zero_copy(self, timeout: float | None = None) -> _SyncZeroCopy: + return _SyncZeroCopy(self._sub, self._loop, timeout) + + def close(self) -> None: + self._sub.close() + + def wait_closed(self, timeout: float | None = None) -> None: + fut = asyncio.run_coroutine_threadsafe(self._sub.wait_closed(), self._loop) + _future_result(fut, timeout) + + +class SyncContext: + """Synchronous interface for ezmsg's low-level pub/sub API.""" + + def __init__( + self, graph_address: AddressType | None = None, auto_start: bool | None = None + ) -> None: + self._graph_address = graph_address + self._auto_start = auto_start + self._graph_context: GraphContext | None = None + self._loop_cm = None + self._loop: asyncio.AbstractEventLoop | None = None + self._shutdown_requested = threading.Event() + self._spinning = False + self._closed = False + self._spin_future = None + self._callback_subscriptions: list[ + tuple[SyncSubscriber, Callable[[Any], None], bool] + ] = [] + + @property + def graph_address(self) -> AddressType | None: + if self._graph_context is None: + return self._graph_address + return self._graph_context.graph_address + + def __enter__(self) -> "SyncContext": + if self._loop_cm is not None: + return self + + self._loop_cm = new_threaded_event_loop() + self._loop = self._loop_cm.__enter__() + + graph_context = GraphContext(self._graph_address, auto_start=self._auto_start) + + async def _enter() -> GraphContext: + return await graph_context.__aenter__() + + self._graph_context = _future_result( + asyncio.run_coroutine_threadsafe(_enter(), self._loop), None + ) + return self + + def __exit__(self, exc_type, exc, tb) -> bool: + self.shutdown() + return False + + def publisher(self, topic: str, **kwargs) -> SyncPublisher: + return self.create_publisher(topic, **kwargs) + + def create_publisher(self, topic: str, **kwargs) -> SyncPublisher: + self._ensure_started() + assert self._graph_context is not None + assert self._loop is not None + fut = asyncio.run_coroutine_threadsafe( + self._graph_context.publisher(topic, **kwargs), self._loop + ) + pub = _future_result(fut, None) + return SyncPublisher(pub, self._loop) + + def subscriber(self, topic: str, **kwargs) -> SyncSubscriber: + return self.create_subscription(topic, **kwargs) + + def create_subscription( + self, + topic: str, + callback: Callable[[Any], None] | None = None, + *, + zero_copy: bool = False, + **kwargs, + ) -> SyncSubscriber: + self._ensure_started() + assert self._graph_context is not None + assert self._loop is not None + fut = asyncio.run_coroutine_threadsafe( + self._graph_context.subscriber(topic, **kwargs), self._loop + ) + sub = _future_result(fut, None) + sync_sub = SyncSubscriber(sub, self._loop) + if callback is not None: + self._callback_subscriptions.append((sync_sub, callback, zero_copy)) + return sync_sub + + def connect(self, from_topic: str, to_topic: str, timeout: float | None = None) -> None: + self._ensure_started() + assert self._graph_context is not None + assert self._loop is not None + fut = asyncio.run_coroutine_threadsafe( + self._graph_context.connect(from_topic, to_topic), self._loop + ) + _future_result(fut, timeout) + + def disconnect( + self, from_topic: str, to_topic: str, timeout: float | None = None + ) -> None: + self._ensure_started() + assert self._graph_context is not None + assert self._loop is not None + fut = asyncio.run_coroutine_threadsafe( + self._graph_context.disconnect(from_topic, to_topic), self._loop + ) + _future_result(fut, timeout) + + def sync(self, timeout: float | None = None) -> None: + self._ensure_started() + assert self._graph_context is not None + assert self._loop is not None + fut = asyncio.run_coroutine_threadsafe(self._graph_context.sync(), self._loop) + _future_result(fut, timeout) + + def pause(self, timeout: float | None = None) -> None: + self._ensure_started() + assert self._graph_context is not None + assert self._loop is not None + fut = asyncio.run_coroutine_threadsafe(self._graph_context.pause(), self._loop) + _future_result(fut, timeout) + + def resume(self, timeout: float | None = None) -> None: + self._ensure_started() + assert self._graph_context is not None + assert self._loop is not None + fut = asyncio.run_coroutine_threadsafe(self._graph_context.resume(), self._loop) + _future_result(fut, timeout) + + def spin(self, poll_interval: float = 0.1) -> None: + self._ensure_started() + if self._shutdown_requested.is_set(): + return + + self._spinning = True + try: + while not self._shutdown_requested.is_set(): + self.spin_once(timeout=poll_interval) + except KeyboardInterrupt: + self.shutdown() + finally: + self._spinning = False + if self._shutdown_requested.is_set() and not self._closed: + self._finalize_shutdown() + + def spin_once(self, timeout: float | None = 0.0) -> bool: + """ + Process at most one subscription callback. + + :param timeout: Seconds to wait for a callback. Use None to block forever. + :return: True if a callback was processed, False otherwise. + """ + self._ensure_started() + if self._shutdown_requested.is_set(): + return False + + if not self._callback_subscriptions: + if timeout is None: + self._shutdown_requested.wait() + elif timeout > 0: + self._shutdown_requested.wait(timeout) + return False + + assert self._loop is not None + fut = asyncio.run_coroutine_threadsafe( + _recv_any(self._callback_subscriptions, timeout), self._loop + ) + self._spin_future = fut + keep_future = False + try: + result = _future_result(fut, None) + except KeyboardInterrupt: + if not fut.done(): + fut.cancel() + keep_future = True + raise + finally: + if not keep_future: + self._spin_future = None + if result is None: + return False + + entry, cm, msg = result + _, callback, zero_copy = entry + + try: + if not zero_copy: + msg = deepcopy(msg) + callback(msg) + except Exception: + logger.exception("Unhandled exception in subscription callback") + finally: + exit_fut = asyncio.run_coroutine_threadsafe( + cm.__aexit__(None, None, None), self._loop + ) + try: + _future_result(exit_fut, None) + except CacheMiss: + logger.warning( + "Cache miss while releasing message; publisher likely exited." + ) + except Exception: + logger.exception("Failed while releasing message backpressure") + return True + + def shutdown(self) -> None: + self._shutdown_requested.set() + if self._spinning: + return + self._finalize_shutdown() + + def _finalize_shutdown(self) -> None: + if self._closed: + return + self._closed = True + + if self._spin_future is not None and not self._spin_future.done(): + self._spin_future.cancel() + try: + self._spin_future.result(timeout=1.0) + except Exception: + pass + self._spin_future = None + + if self._loop is not None: + + async def _cleanup() -> None: + if self._graph_context is not None: + await self._graph_context.__aexit__(None, None, None) + + fut = asyncio.run_coroutine_threadsafe(_cleanup(), self._loop) + _future_result(fut, None) + + if self._loop_cm is not None: + self._loop_cm.__exit__(None, None, None) + + self._loop_cm = None + self._loop = None + self._graph_context = None + + def _ensure_started(self) -> None: + if self._loop is None or self._graph_context is None: + raise RuntimeError("SyncContext must be entered with 'with' before use") + + +def init( + graph_address: AddressType | None = None, auto_start: bool | None = None +) -> SyncContext: + """Create a SyncContext for low-level synchronous usage.""" + return SyncContext(graph_address=graph_address, auto_start=auto_start) + + +def spin(context: SyncContext, poll_interval: float = 0.1) -> None: + """Spin a SyncContext, dispatching subscription callbacks.""" + context.spin(poll_interval=poll_interval) + + +def spin_once(context: SyncContext, timeout: float | None = 0.0) -> bool: + """Process at most one subscription callback.""" + return context.spin_once(timeout=timeout) + + +async def _recv_any( + entries: Iterable[tuple[SyncSubscriber, Callable[[Any], None], bool]], + timeout: float | None, +) -> tuple[tuple[SyncSubscriber, Callable[[Any], None], bool], Any, Any] | None: + async def _recv_entry( + entry: tuple[SyncSubscriber, Callable[[Any], None], bool] + ) -> tuple[tuple[SyncSubscriber, Callable[[Any], None], bool], Any, Any]: + sub, _, _ = entry + cm = sub._sub.recv_zero_copy() + msg = await cm.__aenter__() + return entry, cm, msg + + loop = asyncio.get_running_loop() + deadline = None if timeout is None else loop.time() + timeout + + while True: + remaining = None if deadline is None else max(0.0, deadline - loop.time()) + tasks = [asyncio.create_task(_recv_entry(entry)) for entry in entries] + try: + done, pending = await asyncio.wait( + tasks, timeout=remaining, return_when=asyncio.FIRST_COMPLETED + ) + if not done: + for task in pending: + try: + task.cancel() + except RuntimeError: + pass + await asyncio.gather(*pending, return_exceptions=True) + return None + + for task in pending: + try: + task.cancel() + except RuntimeError: + pass + await asyncio.gather(*pending, return_exceptions=True) + + for task in done: + try: + return task.result() + except CacheMiss: + # Likely stale notification after publisher exit; keep waiting. + continue + except asyncio.CancelledError: + continue + except Exception: + logger.exception("Sync subscription receive failed") + continue + + # Only CacheMiss/cancelled/error occurred; continue within timeout window. + if deadline is not None and loop.time() >= deadline: + return None + finally: + for task in tasks: + if not task.done(): + try: + task.cancel() + except RuntimeError: + pass diff --git a/tests/perf_sync_overhead.py b/tests/perf_sync_overhead.py new file mode 100644 index 0000000..bbf4680 --- /dev/null +++ b/tests/perf_sync_overhead.py @@ -0,0 +1,119 @@ +import argparse +import asyncio +import statistics +import time + +import ezmsg.core as ez + + +TOPIC = "/PERF_SYNC_OVERHEAD" + + +def _format_result(label: str, seconds: float, count: int) -> str: + per_msg_us = (seconds / count) * 1e6 + rate = count / seconds if seconds > 0 else float("inf") + return f"{label}: {seconds:.4f}s total, {per_msg_us:.2f} us/msg, {rate:,.0f} msg/s" + + +async def _async_roundtrip( + count: int, + warmup: int, + host: str, + port: int, + force_tcp: bool, +) -> float: + async with ez.GraphContext((host, port), auto_start=True) as ctx: + pub = await ctx.publisher(TOPIC, host=host, force_tcp=force_tcp) + sub = await ctx.subscriber(TOPIC) + + for i in range(warmup): + await pub.broadcast(i) + await sub.recv() + + start = time.perf_counter() + for i in range(count): + await pub.broadcast(i) + await sub.recv() + return time.perf_counter() - start + + +def _sync_roundtrip( + count: int, + warmup: int, + host: str, + port: int, + force_tcp: bool, +) -> float: + with ez.sync.init((host, port), auto_start=True) as ctx: + pub = ctx.create_publisher(TOPIC, host=host, force_tcp=force_tcp) + sub = ctx.create_subscription(TOPIC) + + for i in range(warmup): + pub.publish(i) + sub.recv() + + start = time.perf_counter() + for i in range(count): + pub.publish(i) + sub.recv() + return time.perf_counter() - start + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Measure sync wrapper overhead vs async API." + ) + parser.add_argument("--count", type=int, default=10_000, help="messages per run") + parser.add_argument("--warmup", type=int, default=1_000, help="warmup messages") + parser.add_argument("--repeats", type=int, default=3, help="number of repeats") + parser.add_argument("--host", default="127.0.0.1", help="graphserver host") + parser.add_argument( + "--port", + type=int, + default=0, + help="graphserver port (0 uses ephemeral port)", + ) + parser.add_argument( + "--force-tcp", + action="store_true", + help="force TCP transport (disable SHM)", + ) + args = parser.parse_args() + + async_times = [] + sync_times = [] + + for _ in range(args.repeats): + async_times.append( + asyncio.run( + _async_roundtrip( + args.count, + args.warmup, + args.host, + args.port, + args.force_tcp, + ) + ) + ) + sync_times.append( + _sync_roundtrip( + args.count, + args.warmup, + args.host, + args.port, + args.force_tcp, + ) + ) + + async_med = statistics.median(async_times) + sync_med = statistics.median(sync_times) + + print(_format_result("async", async_med, args.count)) + print(_format_result("sync ", sync_med, args.count)) + if async_med > 0: + overhead = (sync_med / async_med - 1.0) * 100.0 + print(f"overhead: {overhead:.1f}%") + + +if __name__ == "__main__": + main() diff --git a/tests/test_sync_api.py b/tests/test_sync_api.py new file mode 100644 index 0000000..4cd307a --- /dev/null +++ b/tests/test_sync_api.py @@ -0,0 +1,79 @@ +import socket +import threading +import time +from contextlib import asynccontextmanager + +import pytest + +import ezmsg.core as ez +from ezmsg.core import sync as sync_mod +from ezmsg.core.messagecache import CacheMiss + + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname()[1] + + +def test_sync_autostart_and_spin_once_no_messages(): + host = "127.0.0.1" + port = _free_port() + + with ez.sync.init((host, port), auto_start=True) as ctx: + ctx.create_subscription("/TEST", callback=lambda _: None) + assert ctx.spin_once(timeout=0.01) is False + + +def test_sync_backpressure_blocks_publish(): + host = "127.0.0.1" + port = _free_port() + + with ez.sync.init((host, port), auto_start=True) as ctx: + received = [] + done = threading.Event() + + def on_msg(msg: str) -> None: + time.sleep(0.2) + received.append(msg) + if len(received) >= 2: + done.set() + + ctx.create_subscription("/TEST", callback=on_msg, zero_copy=True) + pub = ctx.create_publisher("/TEST", num_buffers=1, force_tcp=True) + + spin_thread = threading.Thread( + target=ctx.spin, + kwargs={"poll_interval": 0.01}, + daemon=True, + ) + spin_thread.start() + + time.sleep(0.05) + pub.publish("one") + t1 = time.monotonic() + pub.publish("two") + t2 = time.monotonic() + + assert t2 - t1 >= 0.15 + done.wait(2.0) + + ctx.shutdown() + spin_thread.join(timeout=1.0) + + +@pytest.mark.asyncio +async def test_recv_any_cachemiss_does_not_raise(): + class DummySub: + @asynccontextmanager + async def recv_zero_copy(self): + raise CacheMiss + yield + + class DummySyncSub: + def __init__(self) -> None: + self._sub = DummySub() + + entry = (DummySyncSub(), lambda _: None, True) + result = await sync_mod._recv_any([entry], timeout=0.01) + assert result is None