diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index 8b0bcae..5b42a9b 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -62,6 +62,10 @@ jobs: image: rabbitmq:4.0.2-alpine ports: - 5672:5672 + redis: + image: redis:7.4.1-alpine + ports: + - 6379:6379 strategy: fail-fast: false matrix: @@ -88,7 +92,7 @@ jobs: - name: Install the project dependencies run: poetry install --all-extras - name: Run pytest - run: poetry run pytest --broker-url=amqp://guest:guest@localhost:5672/ --broker-retry-delay=3 --broker-retry-delay-mode=constant --broker-retries-limit=10 --cov-report=xml + run: poetry run pytest --broker-retry-delay=3 --broker-retry-delay-mode=constant --broker-retries-limit=10 --cov-report=xml - name: Upload results to Codecov uses: codecov/codecov-action@v4 with: diff --git a/README.md b/README.md index ba68e4b..7655090 100644 --- a/README.md +++ b/README.md @@ -12,13 +12,14 @@ message brokers. ## key features +* builtin CLI to work with broker * same protobuf structures as in gRPC * similar calls as in gRPC * unary-unary * (TODO) unary-stream * (TODO) stream-unary * (TODO) stream-stream -* declarative style, abstract from broker commands (such as declare_exchange / queue_bind) +* resource declarative style, no broker commands (just `async with broker.publisher(...)` or `async with broker.consumer(...)`) * publisher & consumer middlewares * message serializers @@ -30,27 +31,56 @@ pyprotostuben project example for more details. You may configure codegen output using protobuf extensions from [buf schema registry](https://buf.build/zerlok/brokrpc). -## supported brokers & protocols +## drivers -* [AMQP](https://www.rabbitmq.com/tutorials/amqp-concepts) - * [aiormq](https://github.com/mosquito/aiormq) -* (TODO) redis -* (TODO) kafka -* (TODO) NATS +| driver | broker | +|------------|----------| +| `aiormq` | RabbitMQ | +| `aioredis` | Redis | +| (TODO) | Kafka | +| (TODO) | NATs | ## usage [pypi package](https://pypi.python.org/pypi/BrokRPC) -install with your favorite python package manager +install with your favorite python package manager, specify driver in extra (e.g. for `aiormq`): ```shell pip install BrokRPC[aiormq] ``` +### CLI + +`brokrpc` commands either publish messages from stdin or prints consumed messages to stdout. See help for more info. + +```shell +brokrpc --help +``` + +#### examples + +Start JSON publisher (press `ctrl+d` to stop): + +```shell +brokrpc amqp://guest:guest@localhost:5672/ publish -e my-exchange my-rk --encoder=json +``` + +Start JSON consumer (press `ctrl+d` to stop): + +```shell +brokrpc amqp://guest:guest@localhost:5672/ consume -e my-exchange my-rk --decoder=json +``` + +Check broker connection (returns `0` exit code on success and `1` on failure): + +```shell +brokrpc --retry-delay=3 --retries-limit=10 amqp://guest:guest@localhost:5672/ check +``` + ### Broker -use Broker as high level API to create consumers & publishers +Broker class is a part of high-level API. All consumers & publishers are created via this class. ```python from brokrpc.broker import Broker @@ -62,6 +92,10 @@ async with Broker(...) as broker: ... ``` +### BrokerDriver + +Driver is a part of low-level API. Broker delegates calls to driver. + ### Consumer ```python @@ -75,7 +109,7 @@ from brokrpc.serializer.json import JSONSerializer async def register_consumer(broker: Broker) -> None: # define consumer function (you also can use async function & `Consumer` interface). def consume_binary_message(message: Message[bytes]) -> None: - print(message) + print(message) # consumer is not attached yet @@ -204,7 +238,7 @@ async def main() -> None: # connect to broker async with broker: - # start RPC server until SIGINT or SIGTERM + # run RPC server until SIGINT or SIGTERM await server.run_until_terminated() diff --git a/docker-compose.yml b/docker-compose.yml index 1a8d115..18949d2 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,4 +1,4 @@ -version: '3.3' +version: '3.7' services: rabbitmq: @@ -6,3 +6,8 @@ services: ports: - "5672:5672" - "15672:15672" + + redis: + image: redis:7.4.1-alpine + ports: + - "6379:6379" diff --git a/poetry.lock b/poetry.lock index b415cc8..555f3b4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,9 +1,10 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. [[package]] name = "aiofiles" version = "24.1.0" description = "File support for asyncio." +category = "main" optional = true python-versions = ">=3.8" files = [ @@ -15,6 +16,7 @@ files = [ name = "aiormq" version = "6.8.1" description = "Pure python AMQP asynchronous client library" +category = "main" optional = true python-versions = "<4.0,>=3.8" files = [ @@ -30,6 +32,7 @@ yarl = "*" name = "attrs" version = "24.2.0" description = "Classes Without Boilerplate" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -49,6 +52,7 @@ tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." +category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" files = [ @@ -60,6 +64,7 @@ files = [ name = "coverage" version = "7.6.8" description = "Code coverage measurement for Python" +category = "dev" optional = false python-versions = ">=3.9" files = [ @@ -134,6 +139,7 @@ toml = ["tomli"] name = "decorator" version = "5.1.1" description = "Decorators for Humans" +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -145,6 +151,7 @@ files = [ name = "googleapis-common-protos" version = "1.66.0" description = "Common protobufs used in Google APIs" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -162,6 +169,7 @@ grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] name = "idna" version = "3.10" description = "Internationalized Domain Names in Applications (IDNA)" +category = "main" optional = false python-versions = ">=3.6" files = [ @@ -176,6 +184,7 @@ all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2 name = "iniconfig" version = "2.0.0" description = "brain-dead simple config-ini parsing" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -187,6 +196,7 @@ files = [ name = "jinja2" version = "3.1.4" description = "A very fast and expressive template engine." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -204,6 +214,7 @@ i18n = ["Babel (>=2.7)"] name = "jsonschema" version = "4.23.0" description = "An implementation of JSON Schema validation for Python" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -225,6 +236,7 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339- name = "jsonschema-specifications" version = "2024.10.1" description = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" +category = "dev" optional = false python-versions = ">=3.9" files = [ @@ -239,6 +251,7 @@ referencing = ">=0.31.0" name = "markupsafe" version = "3.0.2" description = "Safely add untrusted strings to HTML/XML markup." +category = "dev" optional = false python-versions = ">=3.9" files = [ @@ -309,6 +322,7 @@ files = [ name = "multidict" version = "6.1.0" description = "multidict implementation" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -410,6 +424,7 @@ files = [ name = "mypy" version = "1.13.0" description = "Optional static typing for Python" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -462,6 +477,7 @@ reports = ["lxml"] name = "mypy-extensions" version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." +category = "dev" optional = false python-versions = ">=3.5" files = [ @@ -473,6 +489,7 @@ files = [ name = "packaging" version = "24.2" description = "Core utilities for Python packages" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -484,6 +501,7 @@ files = [ name = "pamqp" version = "3.3.0" description = "RabbitMQ Focused AMQP low-level library" +category = "main" optional = true python-versions = ">=3.7" files = [ @@ -499,6 +517,7 @@ testing = ["coverage", "flake8", "flake8-comprehensions", "flake8-deprecated", " name = "pluggy" version = "1.5.0" description = "plugin and hook calling mechanisms for python" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -514,6 +533,7 @@ testing = ["pytest", "pytest-benchmark"] name = "propcache" version = "0.2.0" description = "Accelerated property cache" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -621,6 +641,7 @@ files = [ name = "protobuf" version = "5.29.0" description = "" +category = "main" optional = false python-versions = ">=3.8" files = [ @@ -641,6 +662,7 @@ files = [ name = "pyprotostuben" version = "0.3.3" description = "Generate Python MyPy stub modules from protobuf files." +category = "dev" optional = false python-versions = "<4.0,>=3.9" files = [ @@ -655,6 +677,7 @@ protobuf = ">=5.28.3,<6.0.0" name = "pytest" version = "8.3.4" description = "pytest: simple powerful testing with Python" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -675,6 +698,7 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments name = "pytest-asyncio" version = "0.24.0" description = "Pytest support for asyncio" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -693,6 +717,7 @@ testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] name = "pytest-cov" version = "6.0.0" description = "Pytest plugin for measuring coverage." +category = "dev" optional = false python-versions = ">=3.9" files = [ @@ -711,6 +736,7 @@ testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"] name = "pytest-mypy-plugins" version = "3.1.2" description = "pytest plugin for writing tests for mypy plugins" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -733,6 +759,7 @@ tomlkit = ">=0.11" name = "pytest-timeout" version = "2.3.1" description = "pytest plugin to abort hanging tests" +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -747,6 +774,7 @@ pytest = ">=7.0.0" name = "pyyaml" version = "6.0.2" description = "YAML parser and emitter for Python" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -805,10 +833,27 @@ files = [ {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, ] +[[package]] +name = "redis" +version = "5.2.0" +description = "Python client for Redis database and key-value store" +category = "main" +optional = true +python-versions = ">=3.8" +files = [ + {file = "redis-5.2.0-py3-none-any.whl", hash = "sha256:ae174f2bb3b1bf2b09d54bf3e51fbc1469cf6c10aa03e21141f51969801a7897"}, + {file = "redis-5.2.0.tar.gz", hash = "sha256:0b1087665a771b1ff2e003aa5bdd354f15a70c9e25d5a7dbf9c722c16528a7b0"}, +] + +[package.extras] +hiredis = ["hiredis (>=3.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==23.2.1)", "requests (>=2.31.0)"] + [[package]] name = "referencing" version = "0.35.1" description = "JSON Referencing + Python" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -824,6 +869,7 @@ rpds-py = ">=0.7.0" name = "regex" version = "2024.11.6" description = "Alternative regular expression module, to replace re." +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -927,6 +973,7 @@ files = [ name = "rpds-py" version = "0.21.0" description = "Python bindings to Rust's persistent data structures (rpds)" +category = "dev" optional = false python-versions = ">=3.9" files = [ @@ -1026,6 +1073,7 @@ files = [ name = "ruff" version = "0.8.1" description = "An extremely fast Python linter and code formatter, written in Rust." +category = "dev" optional = false python-versions = ">=3.7" files = [ @@ -1053,6 +1101,7 @@ files = [ name = "tomlkit" version = "0.13.2" description = "Style preserving TOML library" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1064,6 +1113,7 @@ files = [ name = "types-aiofiles" version = "24.1.0.20240626" description = "Typing stubs for aiofiles" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1075,6 +1125,7 @@ files = [ name = "types-protobuf" version = "5.28.3.20241030" description = "Typing stubs for protobuf" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1086,6 +1137,7 @@ files = [ name = "typing-extensions" version = "4.12.2" description = "Backported and Experimental Type Hints for Python 3.8+" +category = "dev" optional = false python-versions = ">=3.8" files = [ @@ -1097,6 +1149,7 @@ files = [ name = "yarl" version = "1.18.0" description = "Yet another URL library" +category = "main" optional = false python-versions = ">=3.9" files = [ @@ -1193,8 +1246,9 @@ propcache = ">=0.2.0" aiormq = ["aiormq"] cli = ["aiofiles"] protobuf = ["googleapis-common-protos", "protobuf"] +redis = ["redis"] [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "7f490756fb5a0f8f0c622478b1ef0db6691636ed034bc363542979fe66a58009" +content-hash = "b729a8471582ae647e5b43cd5b0d270e884be3c6f86ac96974cdd9889439a0cb" diff --git a/pyproject.toml b/pyproject.toml index 74ae393..133303f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,11 +33,13 @@ protobuf = {version = "^5.26.1", optional = true} googleapis-common-protos = {version = "^1.65.0", optional = true} aiormq = {version = "^6.8.1", optional = true} aiofiles = {version = "^24.1.0", optional = true} +redis = {version = "^5.2.0", optional = true} [tool.poetry.extras] cli = ["aiofiles"] protobuf = ["protobuf", "googleapis-common-protos"] aiormq = ["aiormq"] +redis = ["redis"] [tool.poetry.scripts] brokrpc = "brokrpc.cli:main" @@ -84,6 +86,7 @@ ignore = [ "tests/**" = [ "ARG001", # it's ok to use a fixture with a side effect in a test. "PT004", # fixture may add side effect and doesn't return anything. + "PLR0913", # test functions & fixtures may have a lot of arguments. ] diff --git a/src/brokrpc/abc.py b/src/brokrpc/abc.py index 91b3f35..238045b 100644 --- a/src/brokrpc/abc.py +++ b/src/brokrpc/abc.py @@ -30,6 +30,7 @@ async def consume(self, inner: S, message: U) -> V: raise NotImplementedError +# TODO: consider interface (e.g. make is_alive async) class BoundConsumer(metaclass=abc.ABCMeta): @abc.abstractmethod def is_alive(self) -> bool: diff --git a/src/brokrpc/broker.py b/src/brokrpc/broker.py index a925782..7372376 100644 --- a/src/brokrpc/broker.py +++ b/src/brokrpc/broker.py @@ -60,6 +60,7 @@ def __init__( ConsumerMiddleware[BinaryConsumer, BinaryMessage, ConsumerResult] ] = default_consumer_middlewares or () + self.__cm_enter_count = 0 self.__stack = AsyncExitStack() self.__lock = asyncio.Lock() self.__opened = asyncio.Event() @@ -85,10 +86,15 @@ def __repr__(self) -> str: async def __aenter__(self) -> t.Self: await self.connect() + self.__cm_enter_count += 1 + return self async def __aexit__(self, exc_type: object, exc_val: object, exc_tb: object) -> bool | None: - await self.disconnect() + self.__cm_enter_count -= 1 + if self.__cm_enter_count <= 0: + await self.disconnect() + return None async def connect(self) -> None: @@ -262,6 +268,11 @@ def connect(options: BrokerConnectOptions) -> t.AsyncContextManager[BrokerDriver return AiormqBrokerDriver.connect(clean_options) + elif clean_options.driver == "redis": + from brokrpc.driver.redis import RedisBrokerDriver + + return RedisBrokerDriver.connect(clean_options) + else: details = "unsupported driver" raise ValueError(details, clean_options) diff --git a/src/brokrpc/cli.py b/src/brokrpc/cli.py index f0329f8..ea6fd3f 100644 --- a/src/brokrpc/cli.py +++ b/src/brokrpc/cli.py @@ -1,24 +1,177 @@ +from __future__ import annotations + import abc import asyncio +import json import sys import typing as t from argparse import ArgumentParser -from functools import partial -from signal import Signals +from dataclasses import dataclass, replace +from datetime import timedelta from uuid import uuid4 from aiofiles import stderr_bytes, stdin_bytes, stdout_bytes from yarl import URL -from brokrpc.abc import Serializer -from brokrpc.broker import Broker -from brokrpc.message import AppMessage, BinaryMessage, Message +from brokrpc.abc import Consumer, Serializer +from brokrpc.broker import Broker, parse_options +from brokrpc.message import AppMessage, BinaryMessage, Message, PackedMessage from brokrpc.middleware import AbortBadMessageMiddleware -from brokrpc.options import BindingOptions, ExchangeOptions, PublisherOptions, QueueOptions +from brokrpc.model import BrokerError, ConsumerResult +from brokrpc.options import BindingOptions, BrokerOptions, ExchangeOptions, PublisherOptions, QueueOptions from brokrpc.plugin import Loader +from brokrpc.retry import ConstantDelay, DelayRetryStrategy, ExponentialDelay, MultiplierDelay from brokrpc.serializer.ident import IdentSerializer +@dataclass() +class Options: + url: URL | None = None + retry_delay: timedelta | None = None + retry_delay_mode: t.Literal["constant", "multiplier", "exponential"] | None = None + retry_max_delay: timedelta | None = None + retries_timeout: timedelta | None = None + retries_limit: int | None = None + mode: t.Literal["publish", "consume", "check"] | None = None + exchange: str | None = None + routing_key: str | None = None + encoder: Serializer[BinaryMessage, BinaryMessage] | None = None + decoder: Serializer[bytes, BinaryMessage] | None = None + message_mode: t.Literal["body", "wide"] | None = None + stdin: AsyncStream | None = None + stdout: AsyncStream | None = None + stderr: AsyncStream | None = None + quiet: bool | None = None + done: asyncio.Event | None = None + + +class Factory: + def __init__(self, options: Options) -> None: + self.__options = options + + def create_app(self) -> App: + match self.__options.mode: + case None: + details = "mode is required" + raise ValueError(details, self) + + case "publish": + return self.__create_publisher_app() + + case "consume": + return self.__create_consumer_app() + + case "check": + return self.__create_waiter_app() + + case _: + t.assert_never(self.__options.mode) + + def __create_publisher_app(self) -> PublisherApp: + if self.__options.routing_key is None: + details = "routing key is required for publisher mode" + raise ValueError(details, self) + + return PublisherApp( + broker=self.__create_broker(), + options=( + PublisherOptions( + name=self.__options.exchange, + prefetch_count=1, + ) + ), + routing_key=self.__options.routing_key, + serializer=self.__options.encoder or IdentSerializer(), + console=self.__create_console(), + ) + + def __create_consumer_app(self) -> ConsumerApp: + if self.__options.routing_key is None: + details = "routing key is required for publisher mode" + raise ValueError(details, self) + + return ConsumerApp( + broker=self.__create_broker(), + options=BindingOptions( + exchange=ExchangeOptions(name=self.__options.exchange), + binding_keys=[self.__options.routing_key], + queue=QueueOptions( + name=f"{__package__}.{uuid4().hex}", + exclusive=True, + auto_delete=True, + prefetch_count=1, + ), + ), + serializer=self.__options.decoder or IdentSerializer(), + console=self.__create_console(), + ) + + def __create_waiter_app(self) -> WaiterApp: + return WaiterApp( + broker=self.__create_broker(), + console=self.__create_console(), + ) + + def __create_broker(self) -> Broker: + return Broker( + options=self.__create_broker_options(), + default_publisher_middlewares=[AbortBadMessageMiddleware()], + default_consumer_middlewares=[AbortBadMessageMiddleware()], + ) + + def __create_console(self) -> Console: + if self.__options.stdin is None: + details = "stdin is required" + raise ValueError(details, self.__options) + + return Console( + stdin=self.__options.stdin, + stdout=self.__options.stdout, + stderr=self.__options.stderr if not self.__options.quiet else None, + ) + + def __create_broker_options(self) -> BrokerOptions: + if self.__options.url is None: + details = "url must be set" + raise ValueError(details, self.__options) + + return replace( + parse_options(self.__options.url), + retry_delay=self.__create_retry_delay_strategy(), + retries_timeout=self.__options.retries_timeout, + retries_limit=self.__options.retries_limit, + ) + + def __create_retry_delay_strategy(self) -> DelayRetryStrategy | None: + if self.__options.retry_delay is None: + return None + + match self.__options.retry_delay_mode: + case "constant": + return ConstantDelay( + delay=self.__options.retry_delay, + ) + + case "multiplier": + return MultiplierDelay( + delay=self.__options.retry_delay, + max_delay=self.__options.retry_max_delay, + ) + + case "exponential": + return ExponentialDelay( + delay=self.__options.retry_delay, + max_delay=self.__options.retry_max_delay, + ) + + case None: + return None + + case _: + details = "unknown mode" + raise ValueError(details, self.__options.retry_delay_mode) + + class AsyncStream(t.Protocol): @abc.abstractmethod def __aiter__(self) -> t.AsyncIterator[bytes]: @@ -45,227 +198,434 @@ async def close(self) -> None: raise NotImplementedError -class CLIOptions: - url: URL - exchange: ExchangeOptions | None - routing_key: str - kind: t.Literal["consumer", "publisher"] - serializer: Serializer[Message[object], BinaryMessage] - output_mode: t.Literal["wide", "body"] - quiet: bool - input: AsyncStream - output: AsyncStream - err: AsyncStream - done: asyncio.Event - - def __str__(self) -> str: - data = ", ".join(f"{k}={v!r}" for k, v in self.__dict__.items()) - return f"{self.__class__.__name__}({data})" - - __repr__ = __str__ - - -async def run(options: CLIOptions) -> int: - async with ( - Broker( - options.url, - default_publisher_middlewares=[ - AbortBadMessageMiddleware(), - ], - default_consumer_middlewares=[ - AbortBadMessageMiddleware(), - ], - ) as broker, - ): - match options.kind: - case "consumer": - return await run_consumer(broker, options) - - case "publisher": - return await run_publisher(broker, options) +class Console: + def __init__( + self, + stdin: AsyncStream, + stdout: AsyncStream | None, + stderr: AsyncStream | None, + ) -> None: + self.__stdin = stdin + self.__stdout = stdout + self.__stderr = stderr - case _: - t.assert_never(options.kind) - - -async def run_consumer(broker: Broker, options: CLIOptions) -> int: - async with broker.consumer( - partial(consume_message, options.output) - if options.output_mode == "wide" - else partial(consume_message_body, options.output), - BindingOptions( - exchange=options.exchange, - binding_keys=[options.routing_key], - queue=QueueOptions( - name=f"{__package__}.{uuid4().hex}", - exclusive=True, - auto_delete=True, - prefetch_count=1, + def read(self) -> t.AsyncIterable[bytes]: + return aiter(self.__stdin) + + async def write(self, *lines: bytes | str) -> None: + if not lines or self.__stdout is None: + return + + for line in lines: + await self.__stdout.write((line if isinstance(line, bytes) else line.encode()) + b"\n") + + await self.__stdout.flush() + + async def log(self, line: str) -> None: + if self.__stderr is None: + return + + await self.__stderr.write(line.encode() + b"\n") + await self.__stderr.flush() + + async def error(self, err: Exception) -> None: + await self.log(f"error: {err!r}") + + async def close(self) -> None: + await self.__stdin.close() + + +class App(metaclass=abc.ABCMeta): + @abc.abstractmethod + async def run(self) -> int: + raise NotImplementedError + + +class PublisherApp(App): + def __init__( + self, + broker: Broker, + console: Console, + options: PublisherOptions, + routing_key: str, + serializer: Serializer[BinaryMessage, BinaryMessage], + ) -> None: + self.__broker = broker + self.__console = console + self.__options = options + self.__routing_key = routing_key + self.__serializer = serializer + + async def run(self) -> int: + async with ( + self.__broker, + self.__broker.publisher(options=self.__options, serializer=self.__serializer) as publisher, + ): + async for body in self.__console.read(): + message = AppMessage(body=body, routing_key=self.__routing_key) + result = await publisher.publish(message) + + match result: + case True | None: + await self.__console.log("message sent") + + case False: + await self.__console.log(f"failed to send: {body!r}") + + case _: + t.assert_never(result) + + return 0 + + +class ConsumerApp(App, Consumer[bytes, ConsumerResult]): + def __init__( + self, + broker: Broker, + options: BindingOptions, + serializer: Serializer[bytes, BinaryMessage], + console: Console, + ) -> None: + self.__broker = broker + self.__options = options + self.__serializer = serializer + self.__console = console + + async def run(self) -> int: + async with ( + self.__broker, + self.__broker.consumer( + consumer=self, + options=self.__options, + serializer=self.__serializer, ), - ), - serializer=options.serializer, - ): - await options.done.wait() + ): + async for _ in self.__console.read(): + pass + + return 0 + + async def consume(self, message: bytes) -> ConsumerResult: + await self.__console.write(message) + await self.__console.log("message consumed") + return True + + +class WaiterApp(App): + def __init__(self, broker: Broker, console: Console) -> None: + self.__broker = broker + self.__console = console + + async def run(self) -> int: + try: + async with self.__broker: + pass + + except BrokerError as err: + await self.__console.error(err) + return 1 + + else: + await self.__console.log("successfully connected to broker") + return 0 + + # TODO: implement + async def stop(self) -> None: + pass + + +class Parser: + def __init__(self, *, exit_on_error: bool = True) -> None: + self.__impl = self.__build_impl(exit_on_error=exit_on_error) + + def parse(self, args: t.Sequence[str] | None = None, default: Options | None = None) -> Options: + return self.__impl.parse_args( + args=args, + namespace=default + if default is not None + else Options( + stdin=stdin_bytes, + stdout=stdout_bytes, + stderr=stderr_bytes, + done=asyncio.Event(), + ), + ) - return 0 + @classmethod + def __build_impl(cls, *, exit_on_error: bool) -> ArgumentParser: + parser = ArgumentParser( + description="communicate with broker", + exit_on_error=exit_on_error, + ) + parser.register("type", "seconds", cls.__parse_seconds) + cls.__add_main_opts(parser) + modes_subparser = parser.add_subparsers( + title="modes", + required=True, + ) -async def run_publisher(broker: Broker, options: CLIOptions) -> int: - async with broker.publisher( - PublisherOptions( - name=options.exchange.name, - prefetch_count=1, + # publisher mode + publisher_parser = modes_subparser.add_parser( + name="publish", + help="publish messages from stdin", + exit_on_error=exit_on_error, ) - if options.exchange is not None - else PublisherOptions(prefetch_count=1), - serializer=options.serializer, - ) as publisher: - async for body in options.input: - message = AppMessage(body=body, routing_key=options.routing_key) - result = await publisher.publish(message) - - match result: - case True | None: - if not options.quiet: - await options.err.write(b"sent\n") - await options.err.flush() - - case False: - if not options.quiet: - await options.err.write(f"failed to send: {message!r}\n".encode()) - await options.err.flush() - - case _: - t.assert_never(result) - - return 0 - - -async def consume_message(out: AsyncStream, message: Message[object]) -> None: - await out.write(f"{message!r}\n".encode()) - await out.flush() - - -async def consume_message_body(out: AsyncStream, message: Message[bytes]) -> None: - await out.write(message.body + b"\n") - await out.flush() - - -def build_parser() -> ArgumentParser: - parser = ArgumentParser( - description="""Exchange messages with broker.""", - ) - parser.add_argument( - "url", - type=URL, - help="broker url to connect to", - ) - parser.add_argument( - "routing_key", - type=str, - help="routing key to use for message publishing / consumption", - ) - parser.add_argument( - "-e", - "--exchange", - type=parse_exchange_name, - help="exchange name to connect to", - default=None, - ) - - group = parser.add_mutually_exclusive_group(required=True) - group.add_argument( - "-c", - "--consumer", - dest="kind", - action="store_const", - const="consumer", - help="consumer mode. Print received messages to stdout.", - ) - group.add_argument( - "-p", - "--publisher", - dest="kind", - action="store_const", - const="publisher", - help="publisher mode. Send messages from stdin.", - ) - - parser.add_argument( - "-s", - "--serializer", - type=parse_serializer, - help="customize serializer. Available options: json, protobuf:{plugin to protobuf message class}, " - "{plugin to serializer class}. Default: none (pass binary data to body as is).", - default=IdentSerializer(), - ) - - parser.add_argument( - "-o", - "--output", - type=str, - dest="output_mode", - choices=["wide", "body"], - help="customize output mode. Available options: wide, body. Default: body", - default="body", - ) - - parser.add_argument( - "-q", - "--quiet", - dest="quiet", - action="store_const", - const=True, - help="don't print to stderr", - default=False, - ) - - return parser - - -def parse_exchange_name(value: str) -> ExchangeOptions: - return ExchangeOptions(name=value) - - -def parse_serializer(value: str | None) -> Serializer[t.Any, BinaryMessage]: - if value is None: - return IdentSerializer() - - match value.split(":", maxsplit=1): - case ["raw"]: - return IdentSerializer() - - case ["json"]: - from brokrpc.serializer.json import JSONSerializer - - return JSONSerializer() - - case ["protobuf", msg_type]: - from google.protobuf.message import Message as ProtobufMessage - - from brokrpc.serializer.protobuf import JSONProtobufSerializer - - return JSONProtobufSerializer(Loader[ProtobufMessage]().load_class(msg_type)) - - case _: - return Loader[Serializer[t.Any, BinaryMessage]]().load_instance(value) + publisher_parser.set_defaults(mode="publish") + cls.__add_routing_opts(publisher_parser) + publisher_parser.add_argument( + "--encoder", + # TODO: find a way to register custom type with `parser.register` + type=cls.__parse_encoder, + help=""" + set message encoder; + predefined: raw, json, protobuf:{entrypoint to protobuf message class}, {entrypoint to serializer callable}; + default: raw (keep binary data from body as is) + """, + default=None, + ) -def main() -> None: - options = build_parser().parse_args(namespace=CLIOptions()) + # consumer mode + consumer_parser = modes_subparser.add_parser( + name="consume", + help="consume messages to stdout", + exit_on_error=exit_on_error, + ) + consumer_parser.set_defaults(mode="consume") + cls.__add_routing_opts(consumer_parser) + consumer_parser.add_argument( + "--decoder", + # TODO: find a way to register custom type with `parser.register` + type=cls.__parse_decoder, + help=""" + set message decoder; + predefined: raw, repr, json, protobuf:{entrypoint to protobuf message class}, + {entrypoint to serializer callable}; + default: raw (keep binary data from body as is) + """, + default=None, + ) + + # checker mode + checker_parser = modes_subparser.add_parser( + name="check", + help="check connection with broker", + exit_on_error=exit_on_error, + ) + checker_parser.set_defaults(mode="check") + + return parser + + @classmethod + def __add_main_opts(cls, parser: ArgumentParser) -> None: + parser.add_argument( + "url", + type=URL, + help="broker url to connect to", + ) - options.input = stdin_bytes - options.output = stdout_bytes - options.err = stderr_bytes - done = options.done = asyncio.Event() + retry_group = parser.add_argument_group( + title="retry options", + description="customize retry logic when connecting to broker", + ) + retry_group.add_argument( + "--retry-delay", + type="seconds", + default=None, + metavar="S", + ) + retry_group.add_argument( + "--retry-delay-mode", + type=str, + choices=["constant", "multiplier", "exponential"], + default=None, + ) + retry_group.add_argument( + "--retry-max-delay", + type="seconds", + default=None, + metavar="S", + ) + retry_group.add_argument( + "--retries-timeout", + type="seconds", + default=None, + metavar="S", + ) + retry_group.add_argument( + "--retries-limit", + type=int, + default=None, + metavar="N", + ) + + parser.add_argument( + "-q", + "--quiet", + dest="quiet", + action="store_true", + help="don't print to stderr; default: %(default)s", + default=False, + ) + + @classmethod + def __add_routing_opts(cls, parser: ArgumentParser) -> None: + group = parser.add_argument_group( + title="routing options", + ) + + group.add_argument( + "routing_key", + type=str, + help="set routing key to use for messages", + default=None, + ) + group.add_argument( + "-e", + "--exchange", + type=str, + help="exchange name to connect to", + default=None, + ) + + @classmethod + def __parse_seconds(cls, value: str) -> timedelta: + return timedelta(seconds=float(value)) + + @classmethod + def __parse_encoder(cls, value: str) -> Serializer[BinaryMessage, BinaryMessage] | None: + match value.split(":", maxsplit=1): + case ["raw"]: + return IdentSerializer() + + case ["json"]: + return cls.__create_json_encoder() + + case ["protobuf", msg_type]: + return cls.__create_protobuf_encoder(msg_type) + + case _: + return Loader[Serializer[BinaryMessage, BinaryMessage]]().load_instance(value) + + @classmethod + def __create_protobuf_encoder(cls, msg: str) -> Serializer[BinaryMessage, BinaryMessage]: + from google.protobuf.message import Message as ProtobufMessage + + from brokrpc.serializer.protobuf import ProtobufSerializer + + message_class = Loader[ProtobufMessage]().load_class(msg) + + class ProtobufEncoder(Serializer[BinaryMessage, BinaryMessage]): + def __init__(self, inner: Serializer[Message[ProtobufMessage], BinaryMessage]) -> None: + self.__inner = inner + + def dump_message(self, message: BinaryMessage) -> BinaryMessage: + return self.__inner.dump_message(PackedMessage(original=message, body=json.loads(message.body))) + + def load_message(self, message: BinaryMessage) -> BinaryMessage: + return message + + return ProtobufEncoder(ProtobufSerializer(message_class)) + + @classmethod + def __create_json_encoder(cls) -> Serializer[BinaryMessage, BinaryMessage]: + from brokrpc.serializer.json import JSONSerializer + + class JSONEncoder(Serializer[BinaryMessage, BinaryMessage]): + def __init__(self, inner: Serializer[Message[object], BinaryMessage]) -> None: + self.__inner = inner - async def _main() -> int: - for sig in (Signals.SIGINT, Signals.SIGTERM): - asyncio.get_running_loop().add_signal_handler(sig, done.set) + def dump_message(self, message: BinaryMessage) -> BinaryMessage: + return self.__inner.dump_message(PackedMessage(original=message, body=json.loads(message.body))) - return await run(options) + def load_message(self, message: BinaryMessage) -> BinaryMessage: + return message - sys.exit(asyncio.run(_main())) + return JSONEncoder(JSONSerializer()) + + @classmethod + def __parse_decoder(cls, value: str) -> Serializer[bytes, BinaryMessage] | None: + match value.split(":", maxsplit=1): + case ["raw"]: + return cls.__create_raw_decoder() + + case ["repr"]: + return cls.__create_repr_decoder() + + case ["json"]: + return cls.__create_json_decoder() + + case ["protobuf", msg_type]: + return cls.__create_protobuf_decoder(msg_type) + + case _: + return Loader[Serializer[bytes, BinaryMessage]]().load_instance(value) + + @staticmethod + def __create_raw_decoder() -> Serializer[bytes, BinaryMessage]: + class RawDecoder(Serializer[bytes, BinaryMessage]): + def dump_message(self, message: bytes) -> BinaryMessage: + raise NotImplementedError + + def load_message(self, message: BinaryMessage) -> bytes: + return message.body + + return RawDecoder() + + @staticmethod + def __create_repr_decoder() -> Serializer[bytes, BinaryMessage]: + class ReprDecoder(Serializer[bytes, BinaryMessage]): + def dump_message(self, message: bytes) -> BinaryMessage: + raise NotImplementedError + + def load_message(self, message: BinaryMessage) -> bytes: + return repr(message).encode() + + return ReprDecoder() + + @staticmethod + def __create_json_decoder() -> Serializer[bytes, BinaryMessage]: + from brokrpc.serializer.json import JSONSerializer + + class JSONDecoder(Serializer[bytes, BinaryMessage]): + def __init__(self, inner: Serializer[Message[object], BinaryMessage]) -> None: + self.__inner = inner + + def dump_message(self, message: bytes) -> BinaryMessage: + raise NotImplementedError + + def load_message(self, message: BinaryMessage) -> bytes: + json_message = self.__inner.load_message(message) + return json.dumps(json_message.body, indent=None, separators=(",", ":")).encode() + + return JSONDecoder(JSONSerializer()) + + @staticmethod + def __create_protobuf_decoder(msg: str) -> Serializer[bytes, BinaryMessage]: + from google.protobuf.message import Message as ProtobufMessage + + from brokrpc.serializer.protobuf import ProtobufSerializer + + message_class = Loader[ProtobufMessage]().load_class(msg) + + class ProtobufDecoder(Serializer[bytes, BinaryMessage]): + def __init__(self, inner: Serializer[Message[ProtobufMessage], BinaryMessage]) -> None: + self.__inner = inner + + def dump_message(self, message: bytes) -> BinaryMessage: + raise NotImplementedError + + def load_message(self, message: BinaryMessage) -> bytes: + protobuf_message = self.__inner.load_message(message) + return str(protobuf_message.body).encode() + + return ProtobufDecoder(ProtobufSerializer(message_class)) + + +def main() -> None: + app = Factory(Parser().parse()).create_app() + sys.exit(asyncio.run(app.run())) if __name__ == "__main__": diff --git a/src/brokrpc/driver/aiormq.py b/src/brokrpc/driver/aiormq.py index cacdd6c..d38ca9d 100644 --- a/src/brokrpc/driver/aiormq.py +++ b/src/brokrpc/driver/aiormq.py @@ -8,10 +8,11 @@ from datetime import UTC, datetime, timedelta from uuid import uuid4 -from aiormq import Channel, Connection, ProtocolSyntaxError, spec +from aiormq import AMQPError, Channel, ChannelInvalidStateError, Connection, ProtocolSyntaxError, spec from pamqp.common import FieldTable -from brokrpc.retry import create_delay_retryer +from brokrpc.errors import ErrorTransformer +from brokrpc.retry import RetryerError, create_delay_retryer from brokrpc.stringify import to_str_obj if t.TYPE_CHECKING: @@ -20,6 +21,8 @@ from brokrpc.abc import BinaryConsumer, BinaryPublisher, BoundConsumer, BrokerDriver, Publisher from brokrpc.message import BinaryMessage, Message from brokrpc.model import ( + BrokerConnectionError, + BrokerError, ConsumerAck, ConsumerReject, ConsumerResult, @@ -35,6 +38,20 @@ QueueOptions, ) +_ERROR_TRANSFORMER = ErrorTransformer() + + +@_ERROR_TRANSFORMER.register +def _transform_conn_err( + err: ConnectionError | ProtocolSyntaxError | ChannelInvalidStateError, +) -> BrokerConnectionError: + return BrokerConnectionError(str(err)) + + +@_ERROR_TRANSFORMER.register +def _transform_generic_err(err: AMQPError) -> BrokerError: + return BrokerError(str(err)) + class AiormqMessage(Message[bytes]): __slots__ = ("__impl",) @@ -125,6 +142,7 @@ def __init__( self.__message_id_gen = message_id_gen self.__now = now + @_ERROR_TRANSFORMER.wrap async def publish(self, message: BinaryMessage) -> PublisherResult: await self.__channel.basic_publish( body=message.body, @@ -162,6 +180,7 @@ def __init__(self, channel: Channel, inner: BinaryConsumer) -> None: self.__channel = channel self.__inner = inner + @_ERROR_TRANSFORMER.wrap async def __call__(self, aiormq_message: DeliveredMessage) -> None: assert isinstance(aiormq_message.delivery_tag, int) assert isinstance(aiormq_message.routing_key, str) @@ -227,13 +246,18 @@ def log_warning(err: Exception) -> None: ) conn = Connection(options.url) - try: - await retryer.do(conn.connect) + with _ERROR_TRANSFORMER: + try: + await retryer.do(conn.connect) - yield cls(conn) + yield cls(conn) - finally: - await conn.close() + except RetryerError as err: + details = "can't connect to broker" + raise BrokerConnectionError(details, options) from err + + finally: + await conn.close() def __init__(self, connection: Connection) -> None: self.__connection = connection @@ -243,31 +267,33 @@ def __str__(self) -> str: @asynccontextmanager async def provide_publisher(self, options: PublisherOptions | None = None) -> t.AsyncIterator[BinaryPublisher]: - async with self.__provide_channel(options) as channel: - yield AiormqPublisher(channel, options, self.__gen_message_id, self.__get_now) + with _ERROR_TRANSFORMER: + async with self.__provide_channel(options) as channel: + yield AiormqPublisher(channel, options, self.__gen_message_id, self.__get_now) @asynccontextmanager async def bind_consumer(self, consumer: BinaryConsumer, options: BindingOptions) -> t.AsyncIterator[BoundConsumer]: channel: Channel - async with self.__provide_channel(options.queue) as channel: - callback = AiormqConsumerCallback(channel, consumer) + with _ERROR_TRANSFORMER: + async with self.__provide_channel(options.queue) as channel: + callback = AiormqConsumerCallback(channel, consumer) - bindings = await self.__bind(channel, options) - assert isinstance(bindings.queue, QueueOptions) - assert isinstance(bindings.queue.name, str) + bindings = await self.__bind(channel, options) + assert isinstance(bindings.queue, QueueOptions) + assert isinstance(bindings.queue.name, str) - consume_ok = await channel.basic_consume( - queue=bindings.queue.name, - consumer_callback=callback, - ) - assert isinstance(consume_ok.consumer_tag, str) + consume_ok = await channel.basic_consume( + queue=bindings.queue.name, + consumer_callback=callback, + ) + assert isinstance(consume_ok.consumer_tag, str) - try: - yield AiormqBoundConsumer(channel, callback, bindings, consume_ok.consumer_tag) + try: + yield AiormqBoundConsumer(channel, callback, bindings, consume_ok.consumer_tag) - finally: - await channel.basic_cancel(consume_ok.consumer_tag) + finally: + await channel.basic_cancel(consume_ok.consumer_tag) @asynccontextmanager async def __provide_channel(self, options: QOSOptions | None) -> t.AsyncIterator[Channel]: diff --git a/src/brokrpc/driver/redis.py b/src/brokrpc/driver/redis.py new file mode 100644 index 0000000..e7ded29 --- /dev/null +++ b/src/brokrpc/driver/redis.py @@ -0,0 +1,693 @@ +from __future__ import annotations + +import asyncio +import logging +import typing as t +from contextlib import asynccontextmanager, nullcontext, suppress +from dataclasses import replace +from datetime import datetime, timedelta +from uuid import uuid4 + +from redis import ConnectionError as RedisConnectionError +from redis import RedisError +from redis.asyncio import Redis +from redis.asyncio.retry import Retry +from redis.backoff import ConstantBackoff + +if t.TYPE_CHECKING: + from types import TracebackType + + from redis.asyncio.client import PubSub + +from brokrpc.abc import BinaryConsumer, BinaryPublisher, BoundConsumer, BrokerDriver, Publisher +from brokrpc.errors import ErrorTransformer +from brokrpc.message import BinaryMessage, Message +from brokrpc.model import ( + BrokerConnectionError, + BrokerError, + ConsumerAck, + ConsumerReject, + ConsumerRetry, + PublisherResult, +) +from brokrpc.options import BindingOptions, BrokerOptions, ExchangeOptions, PublisherOptions, QueueOptions +from brokrpc.stringify import to_str_obj + +_ERROR_TRANSFORMER = ErrorTransformer() + + +# TODO: check redis error transformation +@_ERROR_TRANSFORMER.register +def _transform_conn_err(err: RedisConnectionError) -> BrokerConnectionError: + return BrokerConnectionError(str(err)) + + +@_ERROR_TRANSFORMER.register +def _transform_generic_err(err: RedisError) -> BrokerError: + return BrokerError(str(err)) + + +class RedisPubSubSubscribeCommand(t.TypedDict): + type: t.Literal["subscribe"] + pattern: str | None + channel: bytes + data: int + + +class RedisPubSubPatternSubscribeCommand(t.TypedDict): + type: t.Literal["psubscribe"] + pattern: str | None + channel: bytes + data: int + + +class RedisPubSubMessageCommand(t.TypedDict): + type: t.Literal["message"] + pattern: str | None + channel: bytes + data: bytes + + +class RedisPubSubPatternMessageCommand(t.TypedDict): + type: t.Literal["pmessage"] + pattern: str | None + channel: bytes + data: bytes + + +class RedisPubSubUnsubscribeCommand(t.TypedDict): + type: t.Literal["unsubscribe"] + pattern: str | None + channel: bytes + data: int + + +class RedisPubSubPatternUnsubscribeCommand(t.TypedDict): + type: t.Literal["punsubscribe"] + pattern: str | None + channel: bytes + data: int + + +RedisPubSubCommand = ( + RedisPubSubSubscribeCommand + | RedisPubSubPatternSubscribeCommand + | RedisPubSubMessageCommand + | RedisPubSubPatternMessageCommand + | RedisPubSubUnsubscribeCommand + | RedisPubSubPatternUnsubscribeCommand +) +""" +type: One of the following: `subscribe`, `unsubscribe`, `psubscribe`, `punsubscribe`, `message`, `pmessage` + +channel: The channel [un]subscribed to or the channel a message was published to + +pattern: The pattern that matched a published message's channel. Will be None in all cases except for `pmessage` types. + +data: The message data. With [un]subscribe messages, this value will be the number of channels and patterns the +connection is currently subscribed to. With [p]message messages, this value will be the actual published message. +""" + + +class RedisPubSubMessage(Message[bytes]): + __slots__ = ("__impl",) + + def __init__(self, impl: RedisPubSubMessageCommand | RedisPubSubPatternMessageCommand) -> None: + self.__impl = impl + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.__impl!r})" + + @property + def body(self) -> bytes: + return self.__impl["data"] + + @property + def routing_key(self) -> str: + return self.__impl["channel"].decode() + + @property + def exchange(self) -> str | None: + return None + + @property + def content_type(self) -> str | None: + return None + + @property + def content_encoding(self) -> str | None: + return None + + @property + def headers(self) -> t.Mapping[str, str] | None: + return None + + @property + def delivery_mode(self) -> int | None: + return None + + @property + def priority(self) -> int | None: + return None + + @property + def correlation_id(self) -> str | None: + return None + + @property + def reply_to(self) -> str | None: + return None + + @property + def timeout(self) -> timedelta | None: + return None + + @property + def message_id(self) -> str | None: + return None + + @property + def timestamp(self) -> datetime | None: + return None + + @property + def message_type(self) -> str | None: + return None + + @property + def user_id(self) -> str | None: + return None + + @property + def app_id(self) -> str | None: + return None + + +class RedisStreamMessage(Message[bytes]): + __slots__ = ( + "__app_id", + "__body", + "__content_encoding", + "__content_type", + "__correlation_id", + "__delivery_mode", + "__exchange", + "__headers", + "__message_id", + "__message_type", + "__priority", + "__reply_to", + "__routing_key", + "__timeout", + "__timestamp", + "__user_id", + ) + + # NOTE: message constructor has a lot of options to set up a structure (dataclass) + def __init__( # noqa: PLR0913 + self, + *, + body: bytes, + routing_key: str, + exchange: str | None = None, + content_type: str | None = None, + content_encoding: str | None = None, + headers: t.Mapping[str, str] | None = None, + delivery_mode: int | None = None, + priority: int | None = None, + correlation_id: str | None = None, + reply_to: str | None = None, + timeout: timedelta | None = None, + message_id: str | None = None, + timestamp: datetime | None = None, + message_type: str | None = None, + user_id: str | None = None, + app_id: str | None = None, + ) -> None: + self.__body = body + self.__routing_key = routing_key + self.__exchange = exchange + self.__content_type = content_type + self.__content_encoding = content_encoding + self.__headers = headers + self.__delivery_mode = delivery_mode + self.__priority = priority + self.__correlation_id = correlation_id + self.__reply_to = reply_to + self.__timeout = timeout + self.__message_id = message_id + self.__timestamp = timestamp + self.__message_type = message_type + self.__user_id = user_id + self.__app_id = app_id + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"body={self.__body!r}, " + f"routing_key={self.__routing_key!r}, " + f"exchange={self.__exchange!r}, " + f"content_type={self.__content_type!r}, " + f"content_encoding={self.__content_encoding!r}, " + f"headers={self.__headers!r}, " + f"delivery_mode={self.__delivery_mode!r}, " + f"priority={self.__priority!r}, " + f"correlation_id={self.__correlation_id!r}, " + f"reply_to={self.__reply_to!r}, " + f"timeout={self.__timeout!r}, " + f"message_id={self.__message_id!r}, " + f"timestamp={self.__timestamp!r}, " + f"message_type={self.__message_type!r}, " + f"user_id={self.__user_id!r}, " + f"app_id={self.__app_id!r}" + ")" + ) + + @property + def body(self) -> bytes: + return self.__body + + @property + def routing_key(self) -> str: + return self.__routing_key + + @property + def exchange(self) -> str | None: + return self.__exchange + + @property + def content_type(self) -> str | None: + return self.__content_type + + @property + def content_encoding(self) -> str | None: + return self.__content_encoding + + @property + def headers(self) -> t.Mapping[str, str] | None: + return self.__headers + + @property + def delivery_mode(self) -> int | None: + return self.__delivery_mode + + @property + def priority(self) -> int | None: + return self.__priority + + @property + def correlation_id(self) -> str | None: + return self.__correlation_id + + @property + def reply_to(self) -> str | None: + return self.__reply_to + + @property + def timeout(self) -> timedelta | None: + return self.__timeout + + @property + def message_id(self) -> str | None: + return self.__message_id + + @property + def timestamp(self) -> datetime | None: + return self.__timestamp + + @property + def message_type(self) -> str | None: + return self.__message_type + + @property + def user_id(self) -> str | None: + return self.__user_id + + @property + def app_id(self) -> str | None: + return self.__app_id + + +@t.overload +def _dump_key(options: BindingOptions) -> bytes: ... + + +@t.overload +def _dump_key(options: str | ExchangeOptions | None, routing_key: str) -> bytes: ... + + +def _dump_key(options: str | ExchangeOptions | BindingOptions | None, routing_key: str | None = None) -> bytes: + exchange_name = ( + options.exchange.name + if isinstance(options, BindingOptions) and options.exchange is not None + else options.name + if isinstance(options, ExchangeOptions) + else options + ) + rk = options.queue.name if isinstance(options, BindingOptions) and options.queue is not None else routing_key + assert rk is not None + + return rk.encode() if not exchange_name else f"{exchange_name}:{rk}".encode() + + +def _extract_key(key: bytes) -> tuple[str | None, str]: + *other, rk = key.decode().split(":", maxsplit=1) + return other[0] if other else None, rk + + +class RedisPubSubPublisher(Publisher[BinaryMessage, PublisherResult]): + def __init__( + self, + redis: Redis, + options: PublisherOptions | None, + message_id_gen: t.Callable[[], str | None], + ) -> None: + self.__redis = redis + self.__options = options + self.__message_id_gen = message_id_gen + + @_ERROR_TRANSFORMER.wrap + async def publish(self, message: BinaryMessage) -> PublisherResult: + # TODO: find a way to pass RPC metadata with redis message. JSON & protobuf imposes additional dependencies to + # server & client. Also, keep in mind RPC abstraction from language implementations (not python only). + assert message.correlation_id is None + assert message.reply_to is None + await self.__redis.publish(_dump_key(message.exchange or self.__options, message.routing_key), message.body) + return None + + +class RedisStreamPublisher(Publisher[BinaryMessage, PublisherResult]): + def __init__( + self, + redis: Redis, + options: PublisherOptions | None, + ) -> None: + self.__redis = redis + self.__options = options + + @_ERROR_TRANSFORMER.wrap + async def publish(self, message: BinaryMessage) -> PublisherResult: + await self.__redis.xadd( + name=_dump_key(message.exchange or self.__options, message.routing_key), + fields={ + b"body": message.body, + b"content_type": message.content_type or b"", + b"content_encoding": message.content_encoding or b"", + **{f"header.{key}".encode(): value.encode() for key, value in (message.headers or {}).items()}, + b"delivery_mode": message.delivery_mode or b"", + b"priority": message.priority or b"", + b"correlation_id": message.correlation_id or b"", + b"reply_to": message.reply_to or b"", + # b"timeout": message.timeout or b"", + b"timestamp": message.timestamp.isoformat() if message.timestamp is not None else b"", + b"message_type": message.message_type or b"", + b"user_id": message.user_id or b"", + b"app_id": message.app_id or b"", + }, + id=message.message_id if message.message_id is not None else "*", + ) + + return None + + +class RedisPubSubSubscriber(t.AsyncContextManager["RedisPubSubSubscriber"], BoundConsumer): + def __init__( + self, + pubsub: PubSub, + consumer: BinaryConsumer, + options: BindingOptions, + ) -> None: + self.__pubsub = pubsub + self.__consumer = consumer + self.__options = options + + self.__task: asyncio.Task[None] | None = None + + async def __aenter__(self) -> t.Self: + if self.__task is not None: + raise RuntimeError + + await self.__pubsub.subscribe(*self.__options.binding_keys) + self.__task = asyncio.create_task(self.__run()) + + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + /, + ) -> None: + try: + # Redis automatically handles message dispatch, so no need for manual consumer cancellation + await self.__pubsub.unsubscribe() + + finally: + task, self.__task = self.__task, None + + # TODO: stop consumer task gracefully + if task is not None: + task.cancel() + with suppress(asyncio.CancelledError): + await task + + def is_alive(self) -> bool: + return bool(self.__pubsub.subscribed) and self.__task is not None and not self.__task.done() + + def get_options(self) -> BindingOptions: + return self.__options + + @_ERROR_TRANSFORMER.wrap + async def __run(self) -> None: + cmd: RedisPubSubCommand + + async for cmd in self.__pubsub.listen(): + assert isinstance(cmd, dict) # todo: ensure typing is valid + + if cmd["type"] == "subscribe": + # todo subscribe + pass + + elif cmd["type"] == "psubscribe": + # todo psubscribe + pass + + elif cmd["type"] == "message" or cmd["type"] == "pmessage": + try: + await self.__consumer.consume(RedisPubSubMessage(cmd)) + + except Exception as err: + logging.exception("fatal consumer error", exc_info=err) + + elif cmd["type"] == "unsubscribe": + # todo unsubscribe + pass + + elif cmd["type"] == "punsubscribe": + # todo punsubscribe + pass + + else: + t.assert_never(cmd["type"]) + + +class RedisStreamConsumerGroupSubscriber(t.AsyncContextManager["RedisStreamConsumerGroupSubscriber"], BoundConsumer): + def __init__( + self, + redis: Redis, + consumer: BinaryConsumer, + options: BindingOptions, + consumer_group_id: str, + consumer_id: str, + ) -> None: + self.__redis = redis + self.__consumer = consumer + self.__options = replace( + options, + queue=replace(options.queue if options.queue is not None else QueueOptions(), name=consumer_group_id), + ) + self.__consumer_group_id = consumer_group_id + self.__consumer_id = consumer_id + # TODO: make stream ID configurable? + self.__streams: dict[bytes, bytes] = { + _dump_key(self.__options.exchange, key): b">" for key in self.__options.binding_keys + } + + self.__task: asyncio.Task[None] | None = None + + @_ERROR_TRANSFORMER.wrap + async def __aenter__(self) -> t.Self: + if self.__task is not None: + raise RuntimeError + + # TODO: get deeper understanding of stream arg passed to XGROUP CREATE & stream args passed to XREADGROUP + # https://github.com/redis/redis/issues/9523 - may help + # https://github.com/redis/redis/discussions/13680 + await asyncio.gather( + *( + self.__redis.xgroup_create( + name=stream, + groupname=self.__consumer_group_id, + mkstream=True, + ) + for stream in self.__streams + ), + # TODO: suppress only `group already exists error` + return_exceptions=True, + ) + + self.__task = asyncio.create_task(self.__run()) + + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + /, + ) -> None: + # No need to manually unsubscribe, Redis Streams manage group consumption + task, self.__task = self.__task, None + + # TODO: stop consumer task gracefully + if task is not None: + task.cancel() + with suppress(asyncio.CancelledError): + await task + + def is_alive(self) -> bool: + return self.__task is not None and not self.__task.done() + + def get_options(self) -> BindingOptions: + return self.__options + + @_ERROR_TRANSFORMER.wrap + async def __run(self) -> None: + while True: + read_response = await self.__redis.xreadgroup( + groupname=self.__consumer_group_id, + consumername=self.__consumer_id, + # FIXME: remove type ignore here + streams=self.__streams, # type: ignore[arg-type] + count=self.__options.queue.prefetch_count if self.__options.queue is not None else 1, + # TODO: make it configurable + block=300_000, # 5 minutes + ) + + for stream, stream_id, message in self.__extract_messages(read_response): + try: + consumer_result = await self.__consumer.consume(message) + + except Exception as err: + logging.exception("fatal consumer error", exc_info=err) + + else: + match consumer_result: + case ConsumerAck() | True | None: + await self.__redis.xack(stream, self.__consumer_group_id, stream_id) + self.__streams[stream] = stream_id + + case ConsumerReject() | False: + # TODO: reject message + pass + + case ConsumerRetry(): + # TODO: retry message consumption with delay + pass + + def __extract_messages( + self, + response: t.Sequence[tuple[bytes, t.Sequence[tuple[bytes, t.Mapping[bytes, bytes]]]]], + ) -> t.Iterable[tuple[bytes, bytes, RedisStreamMessage]]: + for key, messages in response: + exchange, routing_key = _extract_key(key) + + for message_id, fields in messages: + yield ( + key, + message_id, + RedisStreamMessage( + body=fields[b"body"], + routing_key=routing_key, + exchange=exchange, + content_type=fields[b"content_type"].decode() if fields[b"content_type"] else None, + content_encoding=fields[b"content_encoding"].decode() if fields[b"content_encoding"] else None, + headers={ + key[len(b"header.") :].decode(): value.decode() + for key, value in fields.items() + if key.startswith(b"header.") + }, + delivery_mode=int.from_bytes(fields[b"delivery_mode"]) if fields[b"delivery_mode"] else None, + priority=int.from_bytes(fields[b"priority"]) if fields[b"delivery_mode"] else None, + correlation_id=fields[b"correlation_id"].decode() if fields[b"correlation_id"] else None, + reply_to=fields[b"reply_to"].decode() if fields[b"reply_to"] else None, + # TODO: add timeout, + message_id=message_id.decode(), + timestamp=datetime.fromisoformat(fields[b"timestamp"].decode()) + if fields[b"timestamp"] + else None, + message_type=fields[b"message_type"].decode() if fields[b"message_type"] else None, + user_id=fields[b"user_id"].decode() if fields[b"user_id"] else None, + app_id=fields[b"app_id"].decode() if fields[b"app_id"] else None, + ), + ) + + +# TODO: consider streams +# https://redis.readthedocs.io/en/latest/examples/redis-stream-example.html +# https://redis.io/docs/latest/develop/data-types/streams/ +class RedisBrokerDriver(BrokerDriver): + @classmethod + @asynccontextmanager + async def connect(cls, options: BrokerOptions) -> t.AsyncIterator[RedisBrokerDriver]: + redis: Redis = Redis.from_url( + url=str(options.url), + retry=Retry( + backoff=ConstantBackoff(backoff=3.0), + retries=10, + # TODO: use our retry with logging, adapt it to redis error + supported_errors=(RedisConnectionError, ConnectionError), # type: ignore[arg-type] + ), + ) + + with _ERROR_TRANSFORMER: + try: + await redis.initialize() + yield cls(redis) + + finally: + await redis.aclose() + + def __init__(self, redis: Redis) -> None: + self.__redis = redis + + def __str__(self) -> str: + return to_str_obj(self, redis=self.__redis) + + def provide_publisher(self, options: PublisherOptions | None = None) -> t.AsyncContextManager[BinaryPublisher]: + return nullcontext(RedisStreamPublisher(self.__redis, options)) + + @asynccontextmanager + async def bind_consumer(self, consumer: BinaryConsumer, options: BindingOptions) -> t.AsyncIterator[BoundConsumer]: + # TODO: use PubSub if queue is not set (no consumer groups) + assert options.queue is not None + assert options.queue.name + + async with ( + _ERROR_TRANSFORMER, + # self.__redis.pubsub() as pubsub, + RedisStreamConsumerGroupSubscriber( + self.__redis, + consumer, + options, + consumer_group_id=options.queue.name, + # TODO: make it configurable + consumer_id="my-consumer", + ) as subscriber, + ): + yield subscriber + + def __gen_message_id(self) -> str: + return uuid4().hex diff --git a/src/brokrpc/errors.py b/src/brokrpc/errors.py new file mode 100644 index 0000000..bb065f3 --- /dev/null +++ b/src/brokrpc/errors.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import typing as t +from functools import singledispatch, wraps + +if t.TYPE_CHECKING: + from types import TracebackType + + +class ErrorTransformer(t.ContextManager[None], t.AsyncContextManager[None]): + def __init__(self) -> None: + self.__transform = singledispatch(self.__transform_default) + + def __enter__(self) -> None: + pass + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + /, + ) -> None: + self.__handle(exc_value) + + async def __aenter__(self) -> None: + pass + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + /, + ) -> None: + self.__handle(exc_value) + + def register[U: BaseException]( + self, + func: t.Callable[[U], BaseException | None], + ) -> t.Callable[[U], BaseException | None]: + self.__transform.register(func) + return func + + def wrap[**U, V]( + self, + func: t.Callable[U, t.Coroutine[t.Any, t.Any, V]], + ) -> t.Callable[U, t.Coroutine[t.Any, t.Any, V]]: + @wraps(func) + async def wrapper(*args: U.args, **kwargs: U.kwargs) -> V: + with self: + return await func(*args, **kwargs) + + return wrapper + + def __transform_default(self, err: BaseException) -> BaseException | None: + return err + + def __handle(self, err: BaseException | None) -> None: + if err is None: + return + + transformed_err = self.__transform(err) + if transformed_err is None or transformed_err is err: + return + + else: + raise transformed_err from err diff --git a/src/brokrpc/model.py b/src/brokrpc/model.py index 4be81bf..9a264d1 100644 --- a/src/brokrpc/model.py +++ b/src/brokrpc/model.py @@ -36,6 +36,10 @@ class BrokerError(BrokRPCError): pass +class BrokerConnectionError(BrokerError): + pass + + class PublisherError(BrokRPCError): pass diff --git a/src/brokrpc/serializer/json.py b/src/brokrpc/serializer/json.py index c07380c..4faa146 100644 --- a/src/brokrpc/serializer/json.py +++ b/src/brokrpc/serializer/json.py @@ -13,10 +13,12 @@ class JSONSerializer(Serializer[Message[object], Message[bytes]], RPCSerializer[ def __init__( self, + *, encoder: JSONEncoder | None = None, decoder: JSONDecoder | None = None, encoding: str | None = None, default: t.Callable[[object], object] | None = None, + strict: bool = False, ) -> None: self.__encoder = ( encoder @@ -29,6 +31,7 @@ def __init__( ) self.__decoder = decoder if decoder is not None else JSONDecoder() self.__encoding = encoding if encoding is not None else "utf-8" + self.__strict = strict def dump_message(self, message: Message[object]) -> PackedMessage[bytes]: encoding = message.content_encoding if message.content_encoding is not None else self.__encoding @@ -63,7 +66,7 @@ def load_unary_response(self, response: BinaryMessage) -> Response[object]: return self.load_message(response) def __load(self, message: BinaryMessage) -> object: - if message.content_type != self.__CONTENT_TYPE: + if message.content_type != self.__CONTENT_TYPE and (self.__strict or message.content_type is not None): details = f"invalid content type: {message.content_type}" raise SerializerLoadError(details, message) diff --git a/src/brokrpc/serializer/protobuf.py b/src/brokrpc/serializer/protobuf.py index 96fb634..3a9a80f 100644 --- a/src/brokrpc/serializer/protobuf.py +++ b/src/brokrpc/serializer/protobuf.py @@ -1,7 +1,7 @@ import typing as t +from functools import cached_property from google.protobuf.any_pb2 import Any -from google.protobuf.json_format import Error, MessageToDict, MessageToJson from google.protobuf.message import DecodeError, EncodeError from google.protobuf.message import Message as ProtobufMessage @@ -10,14 +10,19 @@ from brokrpc.model import SerializerDumpError, SerializerLoadError from brokrpc.rpc.abc import RPCSerializer from brokrpc.rpc.model import BinaryRequest, BinaryResponse, Request, Response -from brokrpc.serializer.json import JSONSerializer class ProtobufSerializer[T: ProtobufMessage](Serializer[Message[T], Message[bytes]]): __CONTENT_TYPE: t.Final[str] = "application/protobuf" - def __init__(self, message_type: type[T]) -> None: + def __init__( + self, + message_type: type[T], + *, + strict: bool = False, + ) -> None: self.__message_type = message_type + self.__strict = strict def dump_message(self, message: Message[T]) -> PackedMessage[bytes]: assert isinstance(message.body, self.__message_type) @@ -34,15 +39,15 @@ def dump_message(self, message: Message[T]) -> PackedMessage[bytes]: body=body, content_type=self.__CONTENT_TYPE, content_encoding=None, - message_type=message.body.DESCRIPTOR.full_name, + message_type=self.__message_qualname, ) def load_message(self, message: BinaryMessage) -> UnpackedMessage[T]: - if message.content_type != self.__CONTENT_TYPE: + if message.content_type != self.__CONTENT_TYPE and (self.__strict or message.content_type is not None): details = f"invalid content type: {message.content_type}" raise SerializerLoadError(details, message) - if message.message_type != self.__message_type.DESCRIPTOR.full_name: + if message.message_type != self.__message_qualname and (self.__strict or message.message_type is not None): details = f"invalid message type: {message.message_type}" raise SerializerLoadError(details, message) @@ -58,96 +63,16 @@ def load_message(self, message: BinaryMessage) -> UnpackedMessage[T]: body=body, ) - -class DictProtobufSerializer[T: ProtobufMessage](Serializer[Message[dict[str, object]], Message[bytes]]): - def __init__(self, message_type: type[T]) -> None: - self.__message_type = message_type - self.__proto = ProtobufSerializer(message_type) - - def dump_message(self, message: Message[dict[str, object]]) -> Message[bytes]: - try: - payload = self.__message_type(**message.body) - - except (TypeError, ValueError) as err: - details = "can't construct protobuf message" - raise SerializerDumpError(details, message.body) from err - - return self.__proto.dump_message(PackedMessage(body=payload, original=message)) - - def load_message(self, message: Message[bytes]) -> Message[dict[str, object]]: - protobuf_message = self.__proto.load_message(message) - - try: - dict_payload = MessageToDict(protobuf_message.body, preserving_proto_field_name=True) - - except Error as err: - details = "can't covert protobuf message to dict" - raise SerializerLoadError(details, protobuf_message.body) from err - - return UnpackedMessage( - original=protobuf_message, - body=dict_payload, - ) - - -class JSONProtobufSerializer[T: ProtobufMessage](Serializer[Message[bytes], Message[bytes]]): - def __init__(self, message_type: type[T], encoding: str | None = None) -> None: - self.__message_type = message_type - self.__json = JSONSerializer(encoding=encoding) - self.__proto = ProtobufSerializer(message_type) - - def dump_message(self, message: Message[bytes]) -> Message[bytes]: - try: - json_message = self.__json.load_message( - PackedMessage( - body=message.body, - content_type="application/json", - content_encoding=None, - message_type=None, - original=message, - ) - ) - - except SerializerLoadError as err: - details = "can't load json message" - raise SerializerDumpError(details, message) from err - - if not self.__check_json_message(json_message.body): - details = "can't dump non object json to protobuf" - raise SerializerDumpError(details, json_message) - - try: - payload = self.__message_type(**json_message.body) - - except (TypeError, ValueError) as err: - details = "can't construct protobuf message" - raise SerializerDumpError(details, message.body) from err - - return self.__proto.dump_message(PackedMessage(body=payload, original=message)) - - def load_message(self, message: Message[bytes]) -> Message[bytes]: - protobuf_message = self.__proto.load_message(message) - - try: - json_payload = MessageToJson(protobuf_message.body, preserving_proto_field_name=True, indent=None) - - except Error as err: - details = "can't covert protobuf message to json" - raise SerializerLoadError(details, protobuf_message.body) from err - - return UnpackedMessage( - original=protobuf_message, - body=json_payload.encode(), - ) - - def __check_json_message(self, obj: object) -> t.TypeGuard[dict[str, object]]: - return isinstance(obj, dict) and all(isinstance(key, str) for key in obj) + @cached_property + def __message_qualname(self) -> str: + # NOTE: actual type of full_name in message descriptor is str. + return t.cast(str, self.__message_type.DESCRIPTOR.full_name) class RPCProtobufSerializer[U: ProtobufMessage, V: ProtobufMessage](RPCSerializer[U, V]): - def __init__(self, request_type: type[U], response_type: type[V]) -> None: - self.__request = ProtobufSerializer(request_type) - self.__response = ProtobufSerializer(response_type) + def __init__(self, request_type: type[U], response_type: type[V], *, strict: bool = False) -> None: + self.__request = ProtobufSerializer(request_type, strict=strict) + self.__response = ProtobufSerializer(response_type, strict=strict) def dump_unary_request(self, request: Request[U]) -> BinaryRequest: return self.__request.dump_message(request) diff --git a/tests/conftest.py b/tests/conftest.py index d2148d2..c4f9174 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import typing as t from contextlib import nullcontext +from dataclasses import asdict from datetime import timedelta from unittest.mock import create_autospec @@ -15,7 +16,7 @@ from brokrpc.broker import Broker from brokrpc.message import BinaryMessage from brokrpc.model import ConsumerResult -from brokrpc.options import BindingOptions, BrokerOptions +from brokrpc.options import BindingOptions, BrokerOptions, ExchangeOptions, PublisherOptions, QueueOptions from brokrpc.retry import ConstantDelay, DelayRetryStrategy, ExponentialDelay, MultiplierDelay from brokrpc.rpc.abc import RPCSerializer, UnaryUnaryHandler from tests.stub.driver import StubBrokerDriver, StubConsumer @@ -35,15 +36,20 @@ def parse_seconds(value: str) -> timedelta: return timedelta(seconds=float(value)) parser.addoption( - "--broker-url", + "--broker-drivers", + type=str, + action="append", + default=None, + ) + parser.addoption( + "--aiormq-url", type=URL, default=URL("amqp://guest:guest@localhost:5672/"), ) parser.addoption( - "--broker-driver", - type=str, - choices=["aiormq"], - default="aiormq", + "--redis-url", + type=URL, + default=URL("redis://localhost:6379/"), ) parser.addoption( "--broker-retry-delay", @@ -73,16 +79,6 @@ def parse_seconds(value: str) -> timedelta: ) -def parse_broker_options(config: Config) -> BrokerOptions: - return BrokerOptions( - url=config.getoption("broker_url"), - driver=config.getoption("broker_driver"), - retry_delay=parse_retry_delay_strategy(config), - retries_timeout=config.getoption("broker_retries_timeout"), - retries_limit=config.getoption("broker_retries_limit"), - ) - - def parse_retry_delay_strategy(config: Config) -> DelayRetryStrategy | None: match mode := config.getoption("broker_retry_delay_mode"): case "constant": @@ -110,6 +106,28 @@ def parse_retry_delay_strategy(config: Config) -> DelayRetryStrategy | None: raise ValueError(details, mode) +@pytest.fixture( + params=[ + pytest.param("aiormq"), + pytest.param("redis"), + ], +) +def broker_options(request: SubRequest) -> BrokerOptions: + driver = request.param + drivers = request.config.getoption("broker_drivers", None) + + if drivers is not None and driver not in drivers: + pytest.skip(reason=f"driver is not enabled: {driver}") + + return BrokerOptions( + url=request.config.getoption(f"{driver}_url"), + driver=driver, + retry_delay=parse_retry_delay_strategy(request.config), + retries_timeout=request.config.getoption("broker_retries_timeout"), + retries_limit=request.config.getoption("broker_retries_limit"), + ) + + @pytest.fixture def stub_broker_driver() -> StubBrokerDriver: return StubBrokerDriver() @@ -150,14 +168,53 @@ def mock_unary_unary_handler() -> UnaryUnaryHandler[object, object]: return create_autospec(UnaryUnaryHandler) +@pytest.fixture +def stub_exchange_name(request: SubRequest) -> str: + return f"{request.node.originalname}-exchange" + + @pytest.fixture def stub_routing_key(request: SubRequest) -> str: - return request.node.name + return f"{request.node.originalname}-rk" + + +@pytest.fixture +def stub_queue_name(request: SubRequest) -> str: + return f"{request.node.originalname}-queue" + + +@pytest.fixture +def stub_exchange(stub_exchange_name: str) -> ExchangeOptions: + return ExchangeOptions( + name=stub_exchange_name, + auto_delete=True, + ) @pytest.fixture -def stub_binding_options(stub_routing_key: str) -> BindingOptions: - return BindingOptions(binding_keys=(stub_routing_key,)) +def stub_queue(stub_queue_name: str) -> QueueOptions: + return QueueOptions( + name=stub_queue_name, + auto_delete=True, + ) + + +@pytest.fixture +def stub_publisher_options(stub_exchange: ExchangeOptions) -> PublisherOptions: + return PublisherOptions(**asdict(stub_exchange)) + + +@pytest.fixture +def stub_binding_options( + stub_exchange: ExchangeOptions, + stub_routing_key: str, + stub_queue: QueueOptions, +) -> BindingOptions: + return BindingOptions( + exchange=stub_exchange, + binding_keys=(stub_routing_key,), + queue=stub_queue, + ) @pytest.fixture diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 3bf3719..6cc8ff2 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,8 +1,6 @@ import typing as t -from dataclasses import replace import pytest -from _pytest.fixtures import SubRequest from brokrpc.broker import Broker from brokrpc.options import BrokerOptions @@ -11,31 +9,18 @@ from brokrpc.rpc.server import Server from brokrpc.serializer.json import JSONSerializer from brokrpc.serializer.protobuf import RPCProtobufSerializer -from tests.conftest import parse_broker_options from tests.stub.proto.greeting_pb2 import GreetingRequest, GreetingResponse -@pytest.fixture( - params=[ - pytest.param("aiormq"), - ] -) -def rabbitmq_options(request: SubRequest) -> BrokerOptions: - return replace( - parse_broker_options(request.config), - driver=request.param, - ) - - @pytest.fixture -async def rabbitmq_broker(rabbitmq_options: BrokerOptions) -> t.AsyncIterator[Broker]: - async with Broker(rabbitmq_options) as broker: +async def broker(broker_options: BrokerOptions) -> t.AsyncIterator[Broker]: + async with Broker(broker_options) as broker: yield broker @pytest.fixture -def rpc_server(rabbitmq_broker: Broker) -> Server: - return Server(rabbitmq_broker) +def rpc_server(broker: Broker) -> Server: + return Server(broker) @pytest.fixture @@ -45,8 +30,8 @@ async def running_rpc_server(rpc_server: Server) -> t.AsyncIterator[Server]: @pytest.fixture -def rpc_client(rabbitmq_broker: Broker) -> Client: - return Client(rabbitmq_broker) +def rpc_client(broker: Broker) -> Client: + return Client(broker) @pytest.fixture diff --git a/tests/integration/test_cli.py b/tests/integration/test_cli.py index 67b8710..148e8d1 100644 --- a/tests/integration/test_cli.py +++ b/tests/integration/test_cli.py @@ -1,132 +1,138 @@ import asyncio -import typing as t -from argparse import ArgumentParser +from asyncio import TaskGroup import pytest -from brokrpc.cli import CLIOptions, build_parser, run -from brokrpc.options import BrokerOptions +from brokrpc.abc import Serializer +from brokrpc.broker import Broker +from brokrpc.cli import AsyncStream, Console, ConsumerApp, PublisherApp +from brokrpc.message import BinaryMessage +from brokrpc.options import BindingOptions, PublisherOptions from tests.stub.stream import InMemoryStream @pytest.mark.parametrize("body", [b"test-cli-body"]) async def test_publisher_consumer_messaging( - consumer: asyncio.Task[int], - publisher: asyncio.Task[int], - publisher_options: CLIOptions, - consumer_options: CLIOptions, + *, + consumer_app: ConsumerApp, + consumer_stdin: AsyncStream, + consumer_stdout: AsyncStream, + consumer_stderr: AsyncStream, + publisher_app: PublisherApp, + publisher_stdin: AsyncStream, + publisher_stderr: AsyncStream, body: bytes, ) -> None: - await asyncio.sleep(1.0) # wait for consumer is registered - await publisher_options.input.write(body) - await publisher_options.input.close() + async with TaskGroup() as tg: + tg.create_task(consumer_app.run()) + tg.create_task(publisher_app.run()) - await asyncio.sleep(1.0) # wait for message consumption - consumer_options.done.set() + # wait for consumer is registered + await asyncio.sleep(1.0) - await asyncio.gather(publisher, consumer) + # publish message + await publisher_stdin.write(body + b"\n") + await publisher_stdin.flush() + await publisher_stdin.close() - publisher_logs = await publisher_options.err.readlines() - consumer_output = await consumer_options.output.readlines() + # wait for message consumption + await asyncio.sleep(1.0) + await consumer_stdin.close() - assert publisher_logs == [b"sent"] - assert consumer_output == [body] + publisher_logs = await publisher_stderr.readlines() + consumer_output = await consumer_stdout.readlines() + consumer_logs = await consumer_stderr.readlines() + assert (publisher_logs, consumer_output, consumer_logs) == ([b"message sent"], [body], [b"message consumed"]) -@pytest.fixture -def publisher(event_loop: asyncio.AbstractEventLoop, publisher_options: CLIOptions) -> t.Iterator[asyncio.Task[int]]: - task = event_loop.create_task(run(publisher_options)) - try: - yield task - finally: - if not task.done(): - task.cancel() +@pytest.fixture +def publisher_app( + *, + broker: Broker, + publisher_stdin: AsyncStream, + publisher_stdout: AsyncStream, + publisher_stderr: AsyncStream, + stub_publisher_options: PublisherOptions, + stub_routing_key: str, + raw_encoder: Serializer[BinaryMessage, BinaryMessage], +) -> PublisherApp: + return PublisherApp( + broker=broker, + options=stub_publisher_options, + console=Console(publisher_stdin, publisher_stdout, publisher_stderr), + routing_key=stub_routing_key, + serializer=raw_encoder, + ) @pytest.fixture -def consumer(event_loop: asyncio.AbstractEventLoop, consumer_options: CLIOptions) -> t.Iterator[asyncio.Task[int]]: - task = event_loop.create_task(run(consumer_options)) - try: - yield task - - finally: - if not task.done(): - task.cancel() +def consumer_app( + *, + broker: Broker, + consumer_stdin: AsyncStream, + consumer_stdout: AsyncStream, + consumer_stderr: AsyncStream, + stub_binding_options: BindingOptions, + raw_decoder: Serializer[bytes, BinaryMessage], +) -> ConsumerApp: + return ConsumerApp( + broker=broker, + options=stub_binding_options, + serializer=raw_decoder, + console=Console(consumer_stdin, consumer_stdout, consumer_stderr), + ) @pytest.fixture -def parser() -> ArgumentParser: - parser = build_parser() +def raw_encoder() -> Serializer[BinaryMessage, BinaryMessage]: + class RawEncoder(Serializer[BinaryMessage, BinaryMessage]): + def dump_message(self, message: BinaryMessage) -> BinaryMessage: + return message - # NOTE: raise exceptions instead of `sys.exit` - parser.exit_on_error = False # type: ignore[attr-defined] + def load_message(self, message: BinaryMessage) -> BinaryMessage: + return message - return parser + return RawEncoder() @pytest.fixture -def routing_key() -> str: - return "test-cli-messaging" +def raw_decoder() -> Serializer[bytes, BinaryMessage]: + class RawDecoder(Serializer[bytes, BinaryMessage]): + def dump_message(self, message: bytes) -> BinaryMessage: + raise NotImplementedError + + def load_message(self, message: BinaryMessage) -> bytes: + return message.body + + return RawDecoder() @pytest.fixture -def exchange_name() -> str: - return "test-cli" +def publisher_stdin() -> AsyncStream: + return InMemoryStream() @pytest.fixture -def publisher_options( - rabbitmq_options: BrokerOptions, - parser: ArgumentParser, - routing_key: str, - exchange_name: str, - done: asyncio.Event, -) -> CLIOptions: - options = parser.parse_args( - [ - str(rabbitmq_options.url), - routing_key, - "--publisher", - f"--exchange={exchange_name}", - ], - namespace=CLIOptions(), - ) +def publisher_stdout() -> AsyncStream: + return InMemoryStream() - options.input = InMemoryStream() - options.output = InMemoryStream() - options.err = InMemoryStream() - options.done = done - return options +@pytest.fixture +def publisher_stderr() -> AsyncStream: + return InMemoryStream() @pytest.fixture -def consumer_options( - rabbitmq_options: BrokerOptions, - parser: ArgumentParser, - routing_key: str, - exchange_name: str, - done: asyncio.Event, -) -> CLIOptions: - options = parser.parse_args( - [ - str(rabbitmq_options.url), - routing_key, - "--consumer", - f"--exchange={exchange_name}", - ], - namespace=CLIOptions(), - ) +def consumer_stdin() -> AsyncStream: + return InMemoryStream() - options.input = InMemoryStream() - options.output = InMemoryStream() - options.err = InMemoryStream() - options.done = done - return options +@pytest.fixture +def consumer_stdout() -> AsyncStream: + return InMemoryStream() @pytest.fixture -def done() -> asyncio.Event: - return asyncio.Event() +def consumer_stderr() -> AsyncStream: + return InMemoryStream() diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py index 6ac2d21..96d1078 100644 --- a/tests/integration/test_simple.py +++ b/tests/integration/test_simple.py @@ -1,6 +1,5 @@ import asyncio import typing as t -from dataclasses import asdict import pytest from _pytest.fixtures import SubRequest @@ -8,7 +7,7 @@ from brokrpc.abc import BinaryPublisher, Publisher, Serializer from brokrpc.broker import Broker from brokrpc.message import AppMessage, Message -from brokrpc.options import BindingOptions, ExchangeOptions, PublisherOptions, QueueOptions +from brokrpc.options import BindingOptions, PublisherOptions from brokrpc.rpc.abc import Caller, RPCSerializer from brokrpc.rpc.client import Client from brokrpc.rpc.server import Server @@ -77,39 +76,38 @@ def receive_waiter_handler( @pytest.fixture async def consumer( - rabbitmq_broker: Broker, + broker: Broker, receive_waiter_consumer: ReceiveWaiterConsumer, json_serializer: Serializer[Message[object], Message[bytes]], - binding_options: BindingOptions, + stub_binding_options: BindingOptions, ) -> t.AsyncIterator[ReceiveWaiterConsumer]: - async with rabbitmq_broker.consumer(receive_waiter_consumer, binding_options, serializer=json_serializer): + async with broker.consumer(receive_waiter_consumer, stub_binding_options, serializer=json_serializer): yield receive_waiter_consumer @pytest.fixture async def publisher( - rabbitmq_broker: Broker, + broker: Broker, json_serializer: Serializer[Message[object], Message[bytes]], - publisher_options: PublisherOptions, + stub_publisher_options: PublisherOptions, ) -> t.AsyncIterator[BinaryPublisher]: - async with rabbitmq_broker.publisher(publisher_options, serializer=json_serializer) as pub: + async with broker.publisher(stub_publisher_options, serializer=json_serializer) as pub: yield pub @pytest.fixture def greeting_handler( - routing_key: str, - binding_options: BindingOptions, + stub_binding_options: BindingOptions, rpc_server: Server, rpc_greeting_serializer: RPCSerializer[GreetingRequest, GreetingResponse], receive_waiter_handler: GreetingHandler, ) -> GreetingHandler: rpc_server.register_unary_unary_handler( func=receive_waiter_handler, - routing_key=routing_key, + routing_key=stub_binding_options.binding_keys[0], serializer=rpc_greeting_serializer, - exchange=binding_options.exchange, - queue=binding_options.queue, + exchange=stub_binding_options.exchange, + queue=stub_binding_options.queue, ) return receive_waiter_handler @@ -117,48 +115,22 @@ def greeting_handler( @pytest.fixture async def caller( - exchange: ExchangeOptions, - routing_key: str, + stub_publisher_options: PublisherOptions, + stub_routing_key: str, rpc_client: Client, rpc_greeting_serializer: RPCSerializer[GreetingRequest, GreetingResponse], ) -> t.AsyncIterator[Caller[GreetingRequest, GreetingResponse]]: async with rpc_client.unary_unary_caller( - routing_key=routing_key, + routing_key=stub_routing_key, serializer=rpc_greeting_serializer, - exchange=exchange, + exchange=stub_publisher_options, ) as caller: yield caller @pytest.fixture -def exchange() -> ExchangeOptions: - return ExchangeOptions(name="simple-test", auto_delete=True) - - -@pytest.fixture -def publisher_options(exchange: ExchangeOptions) -> PublisherOptions: - return PublisherOptions(**asdict(exchange)) - - -@pytest.fixture -def binding_options(publisher_options: PublisherOptions, routing_key: str) -> BindingOptions: - return BindingOptions( - exchange=publisher_options, - binding_keys=(routing_key,), - queue=QueueOptions( - auto_delete=True, - ), - ) - - -@pytest.fixture -def routing_key() -> str: - return "test-simple" - - -@pytest.fixture -def json_message(json_content: object, routing_key: str) -> Message[object]: - return AppMessage(body=json_content, routing_key=routing_key) +def json_message(json_content: object, stub_routing_key: str) -> Message[object]: + return AppMessage(body=json_content, routing_key=stub_routing_key) @pytest.fixture( diff --git a/tests/test_mypy.yml b/tests/test_mypy.yml index 6434c75..6b2b0ca 100644 --- a/tests/test_mypy.yml +++ b/tests/test_mypy.yml @@ -41,8 +41,9 @@ class Bar: pass + # FIXME: this prevents mypy errors on 3rd-party libs during tests. Find a way to avoid mypy config patching & error ignore. mypy_config: |- - [mypy-aiormq.*,pamqp.*,pkg_resources.*] + [mypy-aiormq.*,pamqp.*,pkg_resources.*,cryptography.*,redis.*] ignore_missing_imports = True ignore_errors = True @@ -88,8 +89,9 @@ class Bar: pass + # FIXME: this prevents mypy errors on 3rd-party libs during tests. Find a way to avoid mypy config patching & error ignore. mypy_config: |- - [mypy-aiormq.*,pamqp.*,pkg_resources.*] + [mypy-aiormq.*,pamqp.*,pkg_resources.*,cryptography.*,redis.*] ignore_missing_imports = True ignore_errors = True diff --git a/tests/unit/serializer/test_protobuf.py b/tests/unit/serializer/test_protobuf.py index 87cf9d0..e1b37cc 100644 --- a/tests/unit/serializer/test_protobuf.py +++ b/tests/unit/serializer/test_protobuf.py @@ -1,12 +1,10 @@ -import json - import pytest from google.protobuf.empty_pb2 import Empty from google.protobuf.message import Message as ProtobufMessage from brokrpc.message import AppMessage, Message from brokrpc.model import SerializerLoadError -from brokrpc.serializer.protobuf import JSONProtobufSerializer, ProtobufSerializer +from brokrpc.serializer.protobuf import ProtobufSerializer from tests.stub.proto.greeting_pb2 import GreetingRequest @@ -24,10 +22,6 @@ ProtobufSerializer(GreetingRequest), GreetingRequest(name="Jack"), ), - pytest.param( - JSONProtobufSerializer(GreetingRequest), - json.dumps({"name": "Jack"}).encode(), - ), ], ) def test_dump_load_ok[T: ProtobufMessage]( diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 4ebfbe1..4e507cc 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -1,135 +1,104 @@ +import re import typing as t -from argparse import ArgumentError, ArgumentParser +from argparse import ArgumentError +from datetime import timedelta import pytest from yarl import URL -from brokrpc.cli import CLIOptions, build_parser -from brokrpc.serializer.ident import IdentSerializer -from brokrpc.serializer.json import JSONSerializer -from tests.check import check_instance - - -class TestCLIOptions: - @pytest.mark.parametrize( - "values", - [ - pytest.param( - { - "url": URL("amqp://guest:guest@localhost:5672/consume"), - "routing_key": "test-rk-consume", - "kind": "consumer", - "exchange": None, - "output_mode": "body", - "quiet": False, - "serializer": JSONSerializer(), - } - ), - ], - ) - def test_to_str(self, options: CLIOptions, values: t.Mapping[str, object] | None) -> None: - cli_options_str = str(options) - assert all(key in cli_options_str and str(value) in cli_options_str for key, value in (values or {}).items()) +from brokrpc.cli import Options, Parser class TestParser: @pytest.mark.parametrize( - ("args", "expected_ns"), + ("args", "expected"), [ pytest.param( - ["amqp://guest:guest@localhost:5672/consume", "test-rk-consume", "--consumer"], - { - "url": URL("amqp://guest:guest@localhost:5672/consume"), - "routing_key": "test-rk-consume", - "kind": "consumer", - "exchange": None, - "output_mode": "body", - "quiet": False, - "serializer": check_instance(IdentSerializer), - }, + ["amqp://guest:guest@localhost:5672/consume", "consume", "test-rk-consume"], + Options( + url=URL("amqp://guest:guest@localhost:5672/consume"), + mode="consume", + routing_key="test-rk-consume", + ), id="simple consumer", ), pytest.param( - ["amqp://guest:guest@localhost:5672/publish", "test-rk-pub", "--publisher"], - { - "url": URL("amqp://guest:guest@localhost:5672/publish"), - "routing_key": "test-rk-pub", - "kind": "publisher", - "exchange": None, - "output_mode": "body", - "quiet": False, - "serializer": check_instance(IdentSerializer), - }, + ["amqp://guest:guest@localhost:5672/publish", "publish", "test-rk-pub"], + Options( + url=URL("amqp://guest:guest@localhost:5672/publish"), + mode="publish", + routing_key="test-rk-pub", + ), id="simple publisher", ), pytest.param( - ["amqp://guest:guest@localhost:5672/json-pub", "test-rk-json-pub", "--publisher", "--serializer=json"], - { - "url": URL("amqp://guest:guest@localhost:5672/json-pub"), - "routing_key": "test-rk-json-pub", - "kind": "publisher", - "exchange": None, - "output_mode": "body", - "quiet": False, - "serializer": check_instance(JSONSerializer), - }, - id="json publisher", + ["amqp://guest:guest@localhost:5672/publish", "check"], + Options( + url=URL("amqp://guest:guest@localhost:5672/publish"), + mode="check", + ), + id="simple checker", + ), + pytest.param( + [ + "amqp://guest:guest@localhost:5672/publish", + "--retry-delay=3", + "--retry-delay-mode=exponential", + "--retries-limit=4", + "check", + ], + Options( + url=URL("amqp://guest:guest@localhost:5672/publish"), + mode="check", + retry_delay=timedelta(seconds=3), + retry_delay_mode="exponential", + retries_limit=4, + ), + id="checker with exponential retryer", ), ], ) - def test_parse_ok(self, parser: ArgumentParser, args: t.Sequence[str], expected_ns: t.Mapping[str, object]) -> None: - assert parser.parse_args(args, namespace=CLIOptions()).__dict__ == expected_ns + def test_parse_ok(self, parser: Parser, args: t.Sequence[str], expected: Options) -> None: + assert parser.parse(args, Options()) == expected @pytest.mark.parametrize( ("args", "expected_message"), [ pytest.param( [], - "the following arguments are required: url, routing_key", - id="url & routing key are required", + "the following arguments are required: url", + id="url is required", ), pytest.param( - ["amqp://guest:guest@localhost:5672/", "test-rk"], - "one of the arguments -c/--consumer -p/--publisher is required", - id="consumer or publisher is required", + ["amqp://guest:guest@localhost:5672/"], + "the following arguments are required: {publish,consume,check}", + id="mode is required", ), pytest.param( - ["amqp://guest:guest@localhost:5672/", "test-rk", "--consumer", "--unknown-param"], + ["amqp://guest:guest@localhost:5672/", "consume", "test-rk", "--unknown-param"], "unrecognized arguments: --unknown-param", id="unknown param", ), pytest.param( - ["amqp://guest:guest@localhost:5672/", "test-rk", "--consumer", "--serializer=invalid"], - "argument -s/--serializer: invalid parse_serializer value: 'invalid'", + ["amqp://guest:guest@localhost:5672/", "consume", "test-rk", "--decoder=invalid"], + "argument --decoder: invalid __parse_decoder value: 'invalid'", id="invalid serializer", ), + pytest.param( + ["amqp://guest:guest@localhost:5672/", "--retry-delay-mode=unknown", "checker"], + re.escape( + "argument --retry-delay-mode: invalid choice: 'unknown' (choose from 'constant', 'multiplier', " + "'exponential')" + ), + id="invalid retry delay mode", + ), ], ) - def test_parse_error(self, parser: ArgumentParser, args: t.Sequence[str], expected_message: str) -> None: + def test_parse_error(self, parser: Parser, args: t.Sequence[str], expected_message: str) -> None: with pytest.raises(ArgumentError, match=expected_message): - parser.parse_args(args) - - -@pytest.fixture -def values() -> t.Mapping[str, object] | None: - return None + parser.parse(args) @pytest.fixture -def options(values: t.Mapping[str, object] | None) -> CLIOptions: - options = CLIOptions() - - for key, value in (values or {}).items(): - setattr(options, key, value) - - return options - - -@pytest.fixture -def parser() -> ArgumentParser: - parser = build_parser() - - # NOTE: raise exceptions instead of `sys.exit` - parser.exit_on_error = False # type: ignore[attr-defined] - - return parser +def parser() -> Parser: + return Parser(exit_on_error=False)