From 80aa33199d6abfe6e8215f3cd8a595b0e8d95975 Mon Sep 17 00:00:00 2001 From: ed cuss Date: Wed, 6 May 2026 20:08:10 +0100 Subject: [PATCH] fix: flatten for Either --- src/danom/_either.py | 31 +++++++++++++---- src/danom/_result.py | 11 +++--- tests/test_either.py | 13 +++++++ tests/test_monad_laws.py | 73 +++++++++++++++++++++++++++++++++------- tests/test_result.py | 31 ++++------------- 5 files changed, 112 insertions(+), 47 deletions(-) diff --git a/src/danom/_either.py b/src/danom/_either.py index 1171939..c5ad9db 100644 --- a/src/danom/_either.py +++ b/src/danom/_either.py @@ -22,6 +22,8 @@ class Either[T_co, E_co: object](ABC): Each monad is a frozen instance to prevent further mutation. """ + inner: Any = attrs.field(default=None) + @classmethod def unit(cls, inner: T_co) -> Right[T_co]: """Unit method. Given an item of type ``T`` return ``Right(T)`` @@ -164,11 +166,30 @@ def either_unwrap(result: Either[T_co, E_co]) -> T_co: """ return result.unwrap() + def flatten(self) -> Either[T_co, E_co]: + """Flatten the monad. Will return the first ``Left`` or the lowest ``Right`` instance. + + .. doctest:: + + >>> from danom import Left, Right, Stream, Either + + >>> Right(Right(Right(1))).flatten() == Right(1) + True + + >>> Right(Right(Left())).flatten() == Left() + True + + """ + current = self + + while isinstance(current, Either) and isinstance(current.inner, Either): + current = current.inner + + return current + @attrs.define(frozen=True, hash=True) class Right(Either[T_co, Never]): - inner: Any = attrs.field(default=None) - def is_ok(self) -> Literal[True]: return True @@ -179,7 +200,7 @@ def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Self: # return self def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Either[U_co, E_co]: - return func(self.inner, *args, **kwargs) + return Right(func(self.inner, *args, **kwargs)).flatten() def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002 return self @@ -190,8 +211,6 @@ def unwrap(self) -> T_co: @attrs.define(frozen=True, hash=True) class Left(Either[Never, E_co]): - inner: Any = attrs.field(default=None) - def is_ok(self) -> Literal[False]: return False @@ -205,7 +224,7 @@ def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Self: return self def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Either[U_co, E_co]: - return func(self.inner, *args, **kwargs) + return Left(func(self.inner, *args, **kwargs)).flatten() # ty: ignore[invalid-return-type] def unwrap(self) -> T_co: return self.inner diff --git a/src/danom/_result.py b/src/danom/_result.py index b4d99c6..f8b431d 100644 --- a/src/danom/_result.py +++ b/src/danom/_result.py @@ -190,10 +190,13 @@ def flatten(self) -> Result[T_co, E_co]: """ current = self - while isinstance(current, Ok) and isinstance(current.inner, Result): - current = current.inner - - return current + while True: + if isinstance(current, Ok) and isinstance(current.inner, Result): + current = current.inner + elif isinstance(current, Err) and isinstance(current.error, Result): + current = current.error + else: + return current @attrs.define(frozen=True, hash=True) diff --git a/tests/test_either.py b/tests/test_either.py index dce22d6..ab5b29a 100644 --- a/tests/test_either.py +++ b/tests/test_either.py @@ -82,3 +82,16 @@ def test_staticmethod_result_is_ok(monad, expected_result): @pytest.mark.parametrize("inner", [pytest.param(0), pytest.param("something"), pytest.param([])]) def test_staticmethod_result_unwrap(monad, inner): assert Either.either_unwrap(monad(inner)) == inner + + +@pytest.mark.parametrize( + ("monad", "expected_result"), + [ + pytest.param(Right(Right()), Right()), + pytest.param(Right(Left()), Left()), + pytest.param(Left(Right()), Right()), + pytest.param(Left(Left()), Left()), + ], +) +def test_flatten(monad, expected_result) -> None: + assert monad.flatten() == expected_result diff --git a/tests/test_monad_laws.py b/tests/test_monad_laws.py index 155aaca..c694f61 100644 --- a/tests/test_monad_laws.py +++ b/tests/test_monad_laws.py @@ -1,39 +1,88 @@ from __future__ import annotations +import pytest from hypothesis import given from hypothesis import strategies as st -from danom import Either, Err, Left, Result +from danom import Either, Err, Left, Ok, Result, Right, identity -def monad_tests(parent: type[Result | Either], err_monad: type[Err | Left]): +def monad_tests( + parent: type[Result | Either], ok_monad: type[Ok | Right], err_monad: type[Err | Left] +): inners = st.one_of(st.integers(), st.text(), st.floats(allow_nan=False, allow_infinity=False)) results = st.one_of(inners.map(parent.unit), st.just(err_monad(1))) safe_fns = st.sampled_from([lambda x: parent.unit(x * 2), err_monad]) @given(inner=inners, f=safe_fns) - def test_monadic_left_identity(inner, f): + def test_monadic_left_identity(inner, f) -> None: assert parent.unit(inner).and_then(f) == f(inner) @given(results) - def test_monadic_right_identity(monad): + def test_monadic_right_identity(monad) -> None: assert monad.and_then(parent.unit) == monad @given(monad=results, f=safe_fns, g=safe_fns, h=safe_fns) - def test_monadic_associativity(monad, f, g, h): + def test_monadic_associativity(monad, f, g, h) -> None: assert monad.and_then(f).and_then(g).or_else(h) == monad.and_then( lambda x: f(x).and_then(g) ).or_else(h) - return test_monadic_left_identity, test_monadic_right_identity, test_monadic_associativity + st_results = st.integers().map(ok_monad) | st.text().map(err_monad) + st_nested_results = st.recursive( + st_results, lambda children: st.one_of(children.map(ok_monad)), max_leaves=5 + ) + st_nested_errs = st.recursive( + st_results, lambda children: st.one_of(children.map(err_monad)), max_leaves=5 + ) + @given(monad=st_nested_results) + def test_flatten_idempotent(monad) -> None: + assert monad.flatten().flatten() == monad.flatten() -test_result_left_identity, test_result_right_identity, test_result_associativity = monad_tests( - Result, Err -) + @given(monad=st_results) + def test_flatten_noop_for_flat_monad(monad) -> None: + assert monad.flatten() == monad + @given(monad=st_nested_results) + def test_and_then_flattens(monad) -> None: + assert monad.flatten() == monad.and_then(identity) -test_either_left_identity, test_either_right_identity, test_either_associativity = monad_tests( - Either, Left -) + @given(monad=st_nested_errs) + def test_or_else_flattens(monad) -> None: + if isinstance(monad, Result): + pytest.skip("or_else flatten law not defined for Result") + assert monad.flatten() == monad.or_else(identity) + + return ( + test_monadic_left_identity, + test_monadic_right_identity, + test_monadic_associativity, + test_flatten_idempotent, + test_flatten_noop_for_flat_monad, + test_and_then_flattens, + test_or_else_flattens, + ) + + +( + test_result_left_identity, + test_result_right_identity, + test_result_associativity, + test_result_flatten_idempotent, + test_result_flatten_noop_for_flat_monad, + test_result_and_then_flattens, + test_result_or_else_flattens, +) = monad_tests(Result, Ok, Err) + + +( + test_either_left_identity, + test_either_right_identity, + test_either_associativity, + test_either_flatten_idempotent, + test_either_flatten_noop_for_flat_monad, + test_either_and_then_flattens, + test_either_or_else_flattens, +) = monad_tests(Either, Right, Left) diff --git a/tests/test_result.py b/tests/test_result.py index 99f41f5..01a053a 100644 --- a/tests/test_result.py +++ b/tests/test_result.py @@ -1,11 +1,8 @@ from contextlib import nullcontext import pytest -from hypothesis import given -from hypothesis import strategies as st from danom import Err, Ok, Result -from danom._utils import identity from tests.conftest import add_one @@ -127,28 +124,12 @@ def test_staticmethod_result_unwrap(monad, expected_result, expected_context): @pytest.mark.parametrize( ("monad", "expected_result"), - [pytest.param(Ok(Ok(Ok(1))), Ok(1)), pytest.param(Ok(Ok(Err())), Err())], + [ + pytest.param(Ok(Ok()), Ok()), + pytest.param(Ok(Err()), Err()), + pytest.param(Err(Ok()), Ok()), + pytest.param(Err(Err()), Err()), + ], ) def test_flatten(monad, expected_result) -> None: assert monad.flatten() == expected_result - - -st_results = st.integers().map(Ok) | st.text().map(Err) -st_nested_results = st.recursive( - st_results, lambda children: st.one_of(children.map(Ok)), max_leaves=10 -) - - -@given(monad=st_nested_results) -def test_flatten_idempotent(monad) -> None: - assert monad.flatten().flatten() == monad.flatten() - - -@given(monad=st_results) -def test_flatten_noop_for_flat_monad(monad) -> None: - assert monad.flatten() == monad - - -@given(monad=st_nested_results) -def test_and_then_flattens(monad): - assert monad.flatten() == monad.and_then(identity)