diff --git a/pyproject.toml b/pyproject.toml index d138330..e15f811 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "danom" -version = "0.13.2" +version = "0.14.0" description = "Functional streams and monads" readme = "README.md" license = "MIT" diff --git a/src/danom/_result.py b/src/danom/_result.py index ed303c3..b4d99c6 100644 --- a/src/danom/_result.py +++ b/src/danom/_result.py @@ -15,7 +15,8 @@ P = ParamSpec("P") Mappable = Callable[Concatenate[T_co, P], U_co] -Bindable = Callable[Concatenate[T_co, P], "Result[U_co, E_co]"] +Bindable = Callable[Concatenate[T_co, P], "Result[T_co, E_co]"] +Recoverable = Callable[Concatenate[E_co, P], "Result[T_co, E_co]"] @attrs.define(frozen=True) @@ -61,7 +62,7 @@ def is_ok(self) -> bool: ... @abstractmethod - def map(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]: + def map[**P](self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Result[T_co, E_co]: """Pipe a pure function and wrap the return value with ``Ok``. Given an ``Err`` will return self. @@ -75,7 +76,7 @@ def map(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, ... @abstractmethod - def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]: + def map_err[**P](self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Result[T_co, E_co]: """Pipe a pure function and wrap the return value with ``Err``. Given an ``Ok`` will return self. @@ -89,7 +90,9 @@ def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Result[U ... @abstractmethod - def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]: + def and_then[**P]( + self, func: Bindable, *args: P.args, **kwargs: P.kwargs + ) -> Result[T_co, E_co]: """Pipe another function that returns a monad. For ``Err`` will return original error. .. code-block:: python @@ -104,7 +107,9 @@ def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[ ... @abstractmethod - def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]: + def or_else[**P]( + self, func: Recoverable, *args: P.args, **kwargs: P.kwargs + ) -> Result[T_co, E_co]: """Pipe a function that returns a monad to recover from an ``Err``. For ``Ok`` will return original ``Result``. .. code-block:: python @@ -169,6 +174,27 @@ def result_unwrap(result: Result[T_co, E_co]) -> T_co: """ return result.unwrap() + def flatten(self) -> Result[T_co, E_co]: + """Flatten the monad. Will return the first ``Err`` or the lowest ``Ok`` instance. + + .. doctest:: + + >>> from danom import Err, Ok, Stream, Result + + >>> Ok(Ok(Ok(1))).flatten() == Ok(1) + True + + >>> Ok(Ok(Err())).flatten() == Err() + True + + """ + current = self + + while isinstance(current, Ok) and isinstance(current.inner, Result): + current = current.inner + + return current + @attrs.define(frozen=True, hash=True) class Ok(Result[T_co, Never]): @@ -177,16 +203,18 @@ class Ok(Result[T_co, Never]): def is_ok(self) -> Literal[True]: return True - def map(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Ok[U_co]: + def map[**P](self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Ok[T_co]: return Ok(func(self.inner, *args, **kwargs)) - def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002 + def map_err[**P](self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002 return self - def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]: - return func(self.inner, *args, **kwargs) + def and_then[**P]( + self, func: Bindable, *args: P.args, **kwargs: P.kwargs + ) -> Result[T_co, E_co]: + return Ok(func(self.inner, *args, **kwargs)).flatten() - def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002 + def or_else[**P](self, func: Recoverable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002 return self def unwrap(self) -> T_co: @@ -229,17 +257,19 @@ def _extract_details(self, tb: TracebackType | None) -> list[dict[str, Any]]: def is_ok(self) -> Literal[False]: return False - def map(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002 + def map[**P](self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002 return self - def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Err[F_co]: + def map_err[**P](self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Err[E_co]: return Err(func(self.error, *args, **kwargs)) - def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002 + def and_then[**P](self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002 return self - def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]: - return func(self.error, *args, **kwargs) + def or_else[**P]( + self, func: Recoverable, *args: P.args, **kwargs: P.kwargs + ) -> Result[T_co, E_co]: # ty: ignore[invalid-method-override] + return Ok(func(self.error, *args, **kwargs)).flatten() def unwrap(self) -> T_co: if isinstance(self.error, Exception): diff --git a/tests/test_result.py b/tests/test_result.py index a225039..99f41f5 100644 --- a/tests/test_result.py +++ b/tests/test_result.py @@ -1,8 +1,11 @@ from contextlib import nullcontext import pytest +from hypothesis import given +from hypothesis import strategies as st from danom import Err, Ok, Result +from danom._utils import identity from tests.conftest import add_one @@ -120,3 +123,32 @@ def test_staticmethod_result_is_ok(monad, expected_result): def test_staticmethod_result_unwrap(monad, expected_result, expected_context): with expected_context: assert Result.result_unwrap(monad) == expected_result + + +@pytest.mark.parametrize( + ("monad", "expected_result"), + [pytest.param(Ok(Ok(Ok(1))), Ok(1)), pytest.param(Ok(Ok(Err())), Err())], +) +def test_flatten(monad, expected_result) -> None: + assert monad.flatten() == expected_result + + +st_results = st.integers().map(Ok) | st.text().map(Err) +st_nested_results = st.recursive( + st_results, lambda children: st.one_of(children.map(Ok)), max_leaves=10 +) + + +@given(monad=st_nested_results) +def test_flatten_idempotent(monad) -> None: + assert monad.flatten().flatten() == monad.flatten() + + +@given(monad=st_results) +def test_flatten_noop_for_flat_monad(monad) -> None: + assert monad.flatten() == monad + + +@given(monad=st_nested_results) +def test_and_then_flattens(monad): + assert monad.flatten() == monad.and_then(identity) diff --git a/uv.lock b/uv.lock index d5071cc..0e56250 100644 --- a/uv.lock +++ b/uv.lock @@ -343,7 +343,7 @@ wheels = [ [[package]] name = "danom" -version = "0.13.2" +version = "0.14.0" source = { editable = "." } dependencies = [ { name = "attrs" },