Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions src/danom/_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,27 @@ def result_unwrap(result: Result[T_co, E_co]) -> T_co:
"""
return result.unwrap()

def flatten(self) -> Result[T_co, E_co]:
"""Flatten the monad. Will return the first ``Err`` or the lowest ``Ok`` instance.

.. doctest::

>>> from danom import Err, Ok, Stream, Result

>>> Ok(Ok(Ok(1))).flatten() == Ok(1)
True

>>> Ok(Ok(Err())).flatten() == Err()
True

"""
current = self

while isinstance(current, Ok) and isinstance(current.inner, Result):
current = current.inner

return current


@attrs.define(frozen=True, hash=True)
class Ok(Result[T_co, Never]):
Expand All @@ -191,7 +212,7 @@ def map_err[**P](self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Sel
def and_then[**P](
self, func: Bindable, *args: P.args, **kwargs: P.kwargs
) -> Result[T_co, E_co]:
return func(self.inner, *args, **kwargs)
return Ok(func(self.inner, *args, **kwargs)).flatten()

def or_else[**P](self, func: Recoverable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002
return self
Expand Down Expand Up @@ -248,7 +269,7 @@ def and_then[**P](self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Se
def or_else[**P](
self, func: Recoverable, *args: P.args, **kwargs: P.kwargs
) -> Result[T_co, E_co]: # ty: ignore[invalid-method-override]
return func(self.error, *args, **kwargs)
return Ok(func(self.error, *args, **kwargs)).flatten()

def unwrap(self) -> T_co:
if isinstance(self.error, Exception):
Expand Down
32 changes: 32 additions & 0 deletions tests/test_result.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from contextlib import nullcontext

import pytest
from hypothesis import given
from hypothesis import strategies as st

from danom import Err, Ok, Result
from danom._utils import identity
from tests.conftest import add_one


Expand Down Expand Up @@ -120,3 +123,32 @@ def test_staticmethod_result_is_ok(monad, expected_result):
def test_staticmethod_result_unwrap(monad, expected_result, expected_context):
with expected_context:
assert Result.result_unwrap(monad) == expected_result


@pytest.mark.parametrize(
("monad", "expected_result"),
[pytest.param(Ok(Ok(Ok(1))), Ok(1)), pytest.param(Ok(Ok(Err())), Err())],
)
def test_flatten(monad, expected_result) -> None:
assert monad.flatten() == expected_result


st_results = st.integers().map(Ok) | st.text().map(Err)
st_nested_results = st.recursive(
st_results, lambda children: st.one_of(children.map(Ok)), max_leaves=10
)


@given(monad=st_nested_results)
def test_flatten_idempotent(monad) -> None:
assert monad.flatten().flatten() == monad.flatten()


@given(monad=st_results)
def test_flatten_noop_for_flat_monad(monad) -> None:
assert monad.flatten() == monad


@given(monad=st_nested_results)
def test_and_then_flattens(monad):
assert monad.flatten() == monad.and_then(identity)
Loading