From fbeccc703084a9e23217f385a94380bc822f2705 Mon Sep 17 00:00:00 2001 From: ed cuss Date: Wed, 15 Apr 2026 18:37:20 +0100 Subject: [PATCH] feat: pass in args and kwargs into map, filter and tap --- pyproject.toml | 2 +- src/danom/_stream.py | 73 +++++++++++++++++++++++----------------- tests/test_monad_laws.py | 2 +- tests/test_stream.py | 29 +++++++++------- uv.lock | 2 +- 5 files changed, 62 insertions(+), 46 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9cd6c10..8256c50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "danom" -version = "0.12.1" +version = "0.13.0" description = "Functional streams and monads" readme = "README.md" license = "MIT" diff --git a/src/danom/_stream.py b/src/danom/_stream.py index 8342f7a..cdc07ae 100644 --- a/src/danom/_stream.py +++ b/src/danom/_stream.py @@ -8,7 +8,7 @@ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from copy import deepcopy from enum import Enum -from functools import reduce +from functools import partial, reduce from itertools import batched from typing import ParamSpec, TypeVar, cast @@ -22,20 +22,20 @@ E = TypeVar("E") P = ParamSpec("P") -MapFn = Callable[[T], U] -FilterFn = Callable[[T], bool] -TapFn = Callable[[T], None] +MapFn = Callable[P, U] +FilterFn = Callable[P, bool] +TapFn = Callable[P, None] -AsyncMapFn = Callable[[T], Awaitable[U]] -AsyncFilterFn = Callable[[T], Awaitable[bool]] -AsyncTapFn = Callable[[T], Awaitable[None]] +AsyncMapFn = Callable[P, Awaitable[U]] +AsyncFilterFn = Callable[P, Awaitable[bool]] +AsyncTapFn = Callable[P, Awaitable[None]] StreamFn = MapFn | FilterFn | TapFn AsyncStreamFn = AsyncMapFn | AsyncFilterFn | AsyncTapFn @attrs.define(frozen=True) -class _BaseStream(ABC): +class _BaseStream[T](ABC): seq: tuple = attrs.field(validator=attrs.validators.instance_of(tuple)) ops: tuple = attrs.field(default=(), validator=attrs.validators.instance_of(tuple), repr=False) @@ -44,16 +44,24 @@ class _BaseStream(ABC): def from_iterable(cls, it: Iterable) -> _BaseStream[T]: ... @abstractmethod - def map(self, *fns: MapFn | AsyncMapFn) -> _BaseStream[T]: ... + def map[**P]( + self, fn: MapFn | AsyncMapFn, *args: P.args, **kwargs: P.kwargs + ) -> _BaseStream[T]: ... @abstractmethod - def filter(self, *fns: FilterFn | AsyncFilterFn) -> _BaseStream[T]: ... + def filter[**P]( + self, fn: FilterFn | AsyncFilterFn, *args: P.args, **kwargs: P.kwargs + ) -> _BaseStream[T]: ... @abstractmethod - def tap(self, *fns: TapFn | AsyncTapFn) -> _BaseStream[T]: ... + def tap[**P]( + self, fn: TapFn | AsyncTapFn, *args: P.args, **kwargs: P.kwargs + ) -> _BaseStream[T]: ... @abstractmethod - def partition(self, fn: FilterFn) -> tuple[_BaseStream[T], _BaseStream[U]]: ... + def partition[U]( + self, fn: FilterFn, *, workers: int = 1, use_threads: bool = False + ) -> tuple[_BaseStream[T], _BaseStream[U]]: ... @abstractmethod def fold( @@ -74,7 +82,7 @@ def __bool__(self) -> bool: @attrs.define(frozen=True) -class Stream[Type](_BaseStream): +class Stream[T](_BaseStream): """An immutable lazy iterator with functional operations. Why bother? @@ -161,6 +169,11 @@ class Stream[Type](_BaseStream): keyword breakdown: `{}` The business logic is arguably much clearer like this. + + Version changes + ---------- + ``0.13.0``: ``Stream.map``, ``Stream.filter`` and ``Stream.tap`` now take kwargs and ``partial`` them into the passed in function. + """ @classmethod @@ -178,7 +191,7 @@ def from_iterable(cls, it: Iterable) -> Stream[T]: it = [it] return cls(seq=tuple(it)) - def map(self, *fns: MapFn | AsyncMapFn) -> Stream[U]: + def map[**P](self, fn: MapFn | AsyncMapFn, *args: P.args, **kwargs: P.kwargs) -> Stream[T]: """Map a function to the elements in the ``Stream``. Will return a new ``Stream`` with the modified sequence. .. code-block:: python @@ -202,20 +215,22 @@ def two_div_value(x: float) -> float: Stream.from_iterable([0, 1, 2, 4]).map(two_div_value).collect() == (Err(error=ZeroDivisionError('division by zero')), Ok(2.0), Ok(1.0), Ok(0.5)) - Simple functions can be passed in sequence to compose more complex transformations + Keyword arguments can be passed into the functions: .. code-block:: python from danom import Stream - Stream.from_iterable(range(5)).map(mul_two, add_one).collect() == (1, 3, 5, 7, 9) + Stream.from_iterable(range(5)).map(mul, b=2).map(add, b=1).collect() == (1, 3, 5, 7, 9) """ - plan = (*self.ops, *tuple((_MAP, fn) for fn in fns)) + plan = (*self.ops, (_MAP, partial(fn, *args, **kwargs))) object.__setattr__(self, "ops", plan) return self - def filter(self, *fns: FilterFn | AsyncFilterFn) -> Stream[T]: + def filter[**P]( + self, fn: FilterFn | AsyncFilterFn, *args: P.args, **kwargs: P.kwargs + ) -> Stream[T]: """Filter the stream based on a predicate. Will return a new ``Stream`` with the modified sequence. .. doctest:: @@ -225,20 +240,20 @@ def filter(self, *fns: FilterFn | AsyncFilterFn) -> Stream[T]: >>> Stream.from_iterable([0, 1, 2, 3]).filter(lambda x: x % 2 == 0).collect() == (0, 2) True - Simple functions can be passed in sequence to compose more complex filters + Keyword arguments can be passed into the functions: .. code-block:: python from danom import Stream - Stream.from_iterable(range(20)).filter(divisible_by_3, divisible_by_5).collect() == (0, 15) + Stream.from_iterable(range(20)).filter(divisible_by, x=3).filter(divisible_by, x=5).collect() == (0, 15) """ - plan = (*self.ops, *tuple((_FILTER, fn) for fn in fns)) + plan = (*self.ops, (_FILTER, partial(fn, *args, **kwargs))) object.__setattr__(self, "ops", plan) return self - def tap(self, *fns: TapFn | AsyncTapFn) -> Stream[T]: + def tap[**P](self, fn: TapFn | AsyncTapFn, *args: P.args, **kwargs: P.kwargs) -> Stream[T]: """Tap the values to another process that returns None. Will return a new ``Stream`` with the modified sequence. The value passed to the tap function will be deep-copied to avoid any modification to the ``Stream`` item for downstream consumers. @@ -256,7 +271,7 @@ def tap(self, *fns: TapFn | AsyncTapFn) -> Stream[T]: from danom import Stream - Stream.from_iterable([0, 1, 2, 3]).tap(log_value, print_value).collect() == (0, 1, 2, 3) + Stream.from_iterable([0, 1, 2, 3]).tap(log_value).tap(print_value).collect() == (0, 1, 2, 3) ``tap`` is useful for logging and similar actions without effecting the individual items, in this example eligible and dormant users are logged using ``tap``: @@ -269,16 +284,12 @@ def tap(self, *fns: TapFn | AsyncTapFn) -> Stream[T]: Stream.from_iterable(users).map(parse_user_objects).partition(inactive_users) ) - active_users.filter(eligible_for_promotion).tap(log_eligible_users).map( - construct_promo_email, send_with_confirmation - ).collect() + active_users.filter(eligible_for_promotion).tap(log_eligible_users).map(construct_promo_email).map(send_with_confirmation).collect() - inactive_users.tap(log_inactive_users).map( - create_dormant_user_entry, add_to_dormant_table - ).collect() + inactive_users.tap(log_inactive_users).map(create_dormant_user_entry).map(add_to_dormant_table).collect() """ - plan = (*self.ops, *tuple((_TAP, fn) for fn in fns)) + plan = (*self.ops, (_TAP, partial(fn, *args, **kwargs))) object.__setattr__(self, "ops", plan) return self @@ -305,7 +316,7 @@ def partition( from danom import Stream - Stream.from_iterable(range(10)).map(add_one, add_one).partition(divisible_by_3, workers=4) + Stream.from_iterable(range(10)).map(add_one).map(add_one).partition(divisible_by_3, workers=4) part1.map(add_one).par_collect() == (4, 7, 10) part2.collect() == (2, 4, 5, 7, 8, 10, 11) diff --git a/tests/test_monad_laws.py b/tests/test_monad_laws.py index 2691097..155aaca 100644 --- a/tests/test_monad_laws.py +++ b/tests/test_monad_laws.py @@ -10,7 +10,7 @@ def monad_tests(parent: type[Result | Either], err_monad: type[Err | Left]): inners = st.one_of(st.integers(), st.text(), st.floats(allow_nan=False, allow_infinity=False)) results = st.one_of(inners.map(parent.unit), st.just(err_monad(1))) - safe_fns = st.sampled_from([lambda x: parent.unit(x * 2), lambda x: err_monad(x)]) + safe_fns = st.sampled_from([lambda x: parent.unit(x * 2), err_monad]) @given(inner=inners, f=safe_fns) def test_monadic_left_identity(inner, f): diff --git a/tests/test_stream.py b/tests/test_stream.py index 3f847c1..f432c02 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -45,7 +45,7 @@ def test_stream_pipeline(collect_fn, kwargs, it, expected_part1, expected_part2) part1, part2 = Stream.from_iterable(it).map(add_one).partition(is_even) assert ( - _get_attr_collect(part1.map(add_one, add_one).filter(divisible_by_3), collect_fn, kwargs) + _get_attr_collect(part1.map(add, 1).map(add_one).filter(divisible_by_3), collect_fn, kwargs) == expected_part1 ) assert _get_attr_collect(part2, collect_fn, kwargs) == expected_part2 @@ -70,7 +70,11 @@ def test_stream_pipeline(collect_fn, kwargs, it, expected_part1, expected_part2) def test_collect_methods(collect_fn, kwargs, it, expected_result): assert ( _get_attr_collect( - Stream.from_iterable(it).map(add_one, add_one).filter(divisible_by_3, divisible_by_5), + Stream.from_iterable(it) + .map(add_one) + .map(add_one) + .filter(divisible_by_3) + .filter(divisible_by_5), collect_fn, kwargs, ) @@ -89,7 +93,7 @@ def test_collect_methods(collect_fn, kwargs, it, expected_result): ) def test_stream_to_par_stream(collect_fn, kwargs): part1, part2 = ( - Stream.from_iterable(range(10)).map(add_one, add_one).partition(divisible_by_3, workers=4) + Stream.from_iterable(range(10)).map(add, b=2).partition(divisible_by_3, workers=4) ) assert _get_attr_collect(part1.map(add_one), collect_fn, kwargs) == (4, 7, 10) assert _get_attr_collect(part2, collect_fn, kwargs) == (2, 4, 5, 7, 8, 10, 11) @@ -122,7 +126,11 @@ def test_tap(collect_fn, kwargs): val_logger = ValueLogger(values) assert _get_attr_collect( - Stream.from_iterable(range(4)).map(add_one).tap(val_logger, val_logger).map(add_one), + Stream.from_iterable(range(4)) + .map(add, b=1) + .tap(val_logger) + .tap(val_logger) + .map(add_one), collect_fn, kwargs, ) == (2, 3, 4, 5) @@ -175,7 +183,7 @@ def test_sequence(kwargs, seq, expected_result, expected_context): @pytest.mark.asyncio async def test_async_collect(): assert await Stream.from_iterable( - sorted(Path(f"{REPO_ROOT}/tests/mock_data").glob("*")) + sorted(Path(f"{REPO_ROOT}/tests/mock_data").glob("*")) # noqa: ASYNC240 ).filter(async_is_file).map(async_read_text).async_collect() == ( "", "x = 1\n", @@ -187,7 +195,7 @@ async def test_async_collect(): @pytest.mark.asyncio async def test_async_collect_no_fns(): assert await Stream.from_iterable( - sorted(Path(f"{REPO_ROOT}/tests/mock_data").glob("*")) + sorted(Path(f"{REPO_ROOT}/tests/mock_data").glob("*")) # noqa: ASYNC240 ).async_collect() == ( Path(f"{REPO_ROOT}/tests/mock_data/__init__.py"), Path(f"{REPO_ROOT}/tests/mock_data/dir_should_skip"), @@ -202,12 +210,9 @@ async def test_async_tap(): val_logger = AsyncValueLogger() val_logger_2 = AsyncValueLogger() - assert await Stream.from_iterable(range(4)).tap(val_logger, val_logger_2).async_collect() == ( - 0, - 1, - 2, - 3, - ) + assert await Stream.from_iterable(range(4)).tap(val_logger).tap( + val_logger_2 + ).async_collect() == (0, 1, 2, 3) assert sorted(val_logger.values) == [0, 1, 2, 3] assert sorted(val_logger_2.values) == [0, 1, 2, 3] diff --git a/uv.lock b/uv.lock index 1a53c8d..338db63 100644 --- a/uv.lock +++ b/uv.lock @@ -343,7 +343,7 @@ wheels = [ [[package]] name = "danom" -version = "0.12.1" +version = "0.13.0" source = { editable = "." } dependencies = [ { name = "attrs" },