Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "danom"
version = "0.13.2"
version = "0.14.0"
description = "Functional streams and monads"
readme = "README.md"
license = "MIT"
Expand Down
60 changes: 45 additions & 15 deletions src/danom/_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
P = ParamSpec("P")

Mappable = Callable[Concatenate[T_co, P], U_co]
Bindable = Callable[Concatenate[T_co, P], "Result[U_co, E_co]"]
Bindable = Callable[Concatenate[T_co, P], "Result[T_co, E_co]"]
Recoverable = Callable[Concatenate[E_co, P], "Result[T_co, E_co]"]


@attrs.define(frozen=True)
Expand Down Expand Up @@ -61,7 +62,7 @@ def is_ok(self) -> bool:
...

@abstractmethod
def map(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]:
def map[**P](self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Result[T_co, E_co]:
"""Pipe a pure function and wrap the return value with ``Ok``.
Given an ``Err`` will return self.

Expand All @@ -75,7 +76,7 @@ def map(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co,
...

@abstractmethod
def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]:
def map_err[**P](self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Result[T_co, E_co]:
"""Pipe a pure function and wrap the return value with ``Err``.
Given an ``Ok`` will return self.

Expand All @@ -89,7 +90,9 @@ def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Result[U
...

@abstractmethod
def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]:
def and_then[**P](
self, func: Bindable, *args: P.args, **kwargs: P.kwargs
) -> Result[T_co, E_co]:
"""Pipe another function that returns a monad. For ``Err`` will return original error.

.. code-block:: python
Expand All @@ -104,7 +107,9 @@ def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[
...

@abstractmethod
def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]:
def or_else[**P](
self, func: Recoverable, *args: P.args, **kwargs: P.kwargs
) -> Result[T_co, E_co]:
"""Pipe a function that returns a monad to recover from an ``Err``. For ``Ok`` will return original ``Result``.

.. code-block:: python
Expand Down Expand Up @@ -169,6 +174,27 @@ def result_unwrap(result: Result[T_co, E_co]) -> T_co:
"""
return result.unwrap()

def flatten(self) -> Result[T_co, E_co]:
"""Flatten the monad. Will return the first ``Err`` or the lowest ``Ok`` instance.

.. doctest::

>>> from danom import Err, Ok, Stream, Result

>>> Ok(Ok(Ok(1))).flatten() == Ok(1)
True

>>> Ok(Ok(Err())).flatten() == Err()
True

"""
current = self

while isinstance(current, Ok) and isinstance(current.inner, Result):
current = current.inner

return current


@attrs.define(frozen=True, hash=True)
class Ok(Result[T_co, Never]):
Expand All @@ -177,16 +203,18 @@ class Ok(Result[T_co, Never]):
def is_ok(self) -> Literal[True]:
return True

def map(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Ok[U_co]:
def map[**P](self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Ok[T_co]:
return Ok(func(self.inner, *args, **kwargs))

def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002
def map_err[**P](self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002
return self

def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]:
return func(self.inner, *args, **kwargs)
def and_then[**P](
self, func: Bindable, *args: P.args, **kwargs: P.kwargs
) -> Result[T_co, E_co]:
return Ok(func(self.inner, *args, **kwargs)).flatten()

def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002
def or_else[**P](self, func: Recoverable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002
return self

def unwrap(self) -> T_co:
Expand Down Expand Up @@ -229,17 +257,19 @@ def _extract_details(self, tb: TracebackType | None) -> list[dict[str, Any]]:
def is_ok(self) -> Literal[False]:
return False

def map(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002
def map[**P](self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002
return self

def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Err[F_co]:
def map_err[**P](self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Err[E_co]:
return Err(func(self.error, *args, **kwargs))

def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002
def and_then[**P](self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002
return self

def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Result[U_co, E_co]:
return func(self.error, *args, **kwargs)
def or_else[**P](
self, func: Recoverable, *args: P.args, **kwargs: P.kwargs
) -> Result[T_co, E_co]: # ty: ignore[invalid-method-override]
return Ok(func(self.error, *args, **kwargs)).flatten()

def unwrap(self) -> T_co:
if isinstance(self.error, Exception):
Expand Down
32 changes: 32 additions & 0 deletions tests/test_result.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from contextlib import nullcontext

import pytest
from hypothesis import given
from hypothesis import strategies as st

from danom import Err, Ok, Result
from danom._utils import identity
from tests.conftest import add_one


Expand Down Expand Up @@ -120,3 +123,32 @@ def test_staticmethod_result_is_ok(monad, expected_result):
def test_staticmethod_result_unwrap(monad, expected_result, expected_context):
with expected_context:
assert Result.result_unwrap(monad) == expected_result


@pytest.mark.parametrize(
("monad", "expected_result"),
[pytest.param(Ok(Ok(Ok(1))), Ok(1)), pytest.param(Ok(Ok(Err())), Err())],
)
def test_flatten(monad, expected_result) -> None:
assert monad.flatten() == expected_result


st_results = st.integers().map(Ok) | st.text().map(Err)
st_nested_results = st.recursive(
st_results, lambda children: st.one_of(children.map(Ok)), max_leaves=10
)


@given(monad=st_nested_results)
def test_flatten_idempotent(monad) -> None:
assert monad.flatten().flatten() == monad.flatten()


@given(monad=st_results)
def test_flatten_noop_for_flat_monad(monad) -> None:
assert monad.flatten() == monad


@given(monad=st_nested_results)
def test_and_then_flattens(monad):
assert monad.flatten() == monad.and_then(identity)
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading