Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions src/danom/_either.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class Either[T_co, E_co: object](ABC):
Each monad is a frozen instance to prevent further mutation.
"""

inner: Any = attrs.field(default=None)

@classmethod
def unit(cls, inner: T_co) -> Right[T_co]:
"""Unit method. Given an item of type ``T`` return ``Right(T)``
Expand Down Expand Up @@ -164,11 +166,30 @@ def either_unwrap(result: Either[T_co, E_co]) -> T_co:
"""
return result.unwrap()

def flatten(self) -> Either[T_co, E_co]:
"""Flatten the monad. Will return the first ``Left`` or the lowest ``Right`` instance.

.. doctest::

>>> from danom import Left, Right, Stream, Either

>>> Right(Right(Right(1))).flatten() == Right(1)
True

>>> Right(Right(Left())).flatten() == Left()
True

"""
current = self

while isinstance(current, Either) and isinstance(current.inner, Either):
current = current.inner

return current


@attrs.define(frozen=True, hash=True)
class Right(Either[T_co, Never]):
inner: Any = attrs.field(default=None)

def is_ok(self) -> Literal[True]:
return True

Expand All @@ -179,7 +200,7 @@ def map_err(self, func: Mappable, *args: P.args, **kwargs: P.kwargs) -> Self: #
return self

def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Either[U_co, E_co]:
return func(self.inner, *args, **kwargs)
return Right(func(self.inner, *args, **kwargs)).flatten()

def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Self: # noqa: ARG002
return self
Expand All @@ -190,8 +211,6 @@ def unwrap(self) -> T_co:

@attrs.define(frozen=True, hash=True)
class Left(Either[Never, E_co]):
inner: Any = attrs.field(default=None)

def is_ok(self) -> Literal[False]:
return False

Expand All @@ -205,7 +224,7 @@ def and_then(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Self:
return self

def or_else(self, func: Bindable, *args: P.args, **kwargs: P.kwargs) -> Either[U_co, E_co]:
return func(self.inner, *args, **kwargs)
return Left(func(self.inner, *args, **kwargs)).flatten() # ty: ignore[invalid-return-type]

def unwrap(self) -> T_co:
return self.inner
11 changes: 7 additions & 4 deletions src/danom/_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,13 @@ def flatten(self) -> Result[T_co, E_co]:
"""
current = self

while isinstance(current, Ok) and isinstance(current.inner, Result):
current = current.inner

return current
while True:
if isinstance(current, Ok) and isinstance(current.inner, Result):
current = current.inner
elif isinstance(current, Err) and isinstance(current.error, Result):
current = current.error
else:
return current


@attrs.define(frozen=True, hash=True)
Expand Down
13 changes: 13 additions & 0 deletions tests/test_either.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,16 @@ def test_staticmethod_result_is_ok(monad, expected_result):
@pytest.mark.parametrize("inner", [pytest.param(0), pytest.param("something"), pytest.param([])])
def test_staticmethod_result_unwrap(monad, inner):
assert Either.either_unwrap(monad(inner)) == inner


@pytest.mark.parametrize(
("monad", "expected_result"),
[
pytest.param(Right(Right()), Right()),
pytest.param(Right(Left()), Left()),
pytest.param(Left(Right()), Right()),
pytest.param(Left(Left()), Left()),
],
)
def test_flatten(monad, expected_result) -> None:
assert monad.flatten() == expected_result
73 changes: 61 additions & 12 deletions tests/test_monad_laws.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,88 @@
from __future__ import annotations

import pytest
from hypothesis import given
from hypothesis import strategies as st

from danom import Either, Err, Left, Result
from danom import Either, Err, Left, Ok, Result, Right, identity


def monad_tests(parent: type[Result | Either], err_monad: type[Err | Left]):
def monad_tests(
parent: type[Result | Either], ok_monad: type[Ok | Right], err_monad: type[Err | Left]
):
inners = st.one_of(st.integers(), st.text(), st.floats(allow_nan=False, allow_infinity=False))

results = st.one_of(inners.map(parent.unit), st.just(err_monad(1)))
safe_fns = st.sampled_from([lambda x: parent.unit(x * 2), err_monad])

@given(inner=inners, f=safe_fns)
def test_monadic_left_identity(inner, f):
def test_monadic_left_identity(inner, f) -> None:
assert parent.unit(inner).and_then(f) == f(inner)

@given(results)
def test_monadic_right_identity(monad):
def test_monadic_right_identity(monad) -> None:
assert monad.and_then(parent.unit) == monad

@given(monad=results, f=safe_fns, g=safe_fns, h=safe_fns)
def test_monadic_associativity(monad, f, g, h):
def test_monadic_associativity(monad, f, g, h) -> None:
assert monad.and_then(f).and_then(g).or_else(h) == monad.and_then(
lambda x: f(x).and_then(g)
).or_else(h)

return test_monadic_left_identity, test_monadic_right_identity, test_monadic_associativity
st_results = st.integers().map(ok_monad) | st.text().map(err_monad)
st_nested_results = st.recursive(
st_results, lambda children: st.one_of(children.map(ok_monad)), max_leaves=5
)
st_nested_errs = st.recursive(
st_results, lambda children: st.one_of(children.map(err_monad)), max_leaves=5
)

@given(monad=st_nested_results)
def test_flatten_idempotent(monad) -> None:
assert monad.flatten().flatten() == monad.flatten()

test_result_left_identity, test_result_right_identity, test_result_associativity = monad_tests(
Result, Err
)
@given(monad=st_results)
def test_flatten_noop_for_flat_monad(monad) -> None:
assert monad.flatten() == monad

@given(monad=st_nested_results)
def test_and_then_flattens(monad) -> None:
assert monad.flatten() == monad.and_then(identity)

test_either_left_identity, test_either_right_identity, test_either_associativity = monad_tests(
Either, Left
)
@given(monad=st_nested_errs)
def test_or_else_flattens(monad) -> None:
if isinstance(monad, Result):
pytest.skip("or_else flatten law not defined for Result")
assert monad.flatten() == monad.or_else(identity)

return (
test_monadic_left_identity,
test_monadic_right_identity,
test_monadic_associativity,
test_flatten_idempotent,
test_flatten_noop_for_flat_monad,
test_and_then_flattens,
test_or_else_flattens,
)


(
test_result_left_identity,
test_result_right_identity,
test_result_associativity,
test_result_flatten_idempotent,
test_result_flatten_noop_for_flat_monad,
test_result_and_then_flattens,
test_result_or_else_flattens,
) = monad_tests(Result, Ok, Err)


(
test_either_left_identity,
test_either_right_identity,
test_either_associativity,
test_either_flatten_idempotent,
test_either_flatten_noop_for_flat_monad,
test_either_and_then_flattens,
test_either_or_else_flattens,
) = monad_tests(Either, Right, Left)
31 changes: 6 additions & 25 deletions tests/test_result.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from contextlib import nullcontext

import pytest
from hypothesis import given
from hypothesis import strategies as st

from danom import Err, Ok, Result
from danom._utils import identity
from tests.conftest import add_one


Expand Down Expand Up @@ -127,28 +124,12 @@ def test_staticmethod_result_unwrap(monad, expected_result, expected_context):

@pytest.mark.parametrize(
("monad", "expected_result"),
[pytest.param(Ok(Ok(Ok(1))), Ok(1)), pytest.param(Ok(Ok(Err())), Err())],
[
pytest.param(Ok(Ok()), Ok()),
pytest.param(Ok(Err()), Err()),
pytest.param(Err(Ok()), Ok()),
pytest.param(Err(Err()), Err()),
],
)
def test_flatten(monad, expected_result) -> None:
assert monad.flatten() == expected_result


st_results = st.integers().map(Ok) | st.text().map(Err)
st_nested_results = st.recursive(
st_results, lambda children: st.one_of(children.map(Ok)), max_leaves=10
)


@given(monad=st_nested_results)
def test_flatten_idempotent(monad) -> None:
assert monad.flatten().flatten() == monad.flatten()


@given(monad=st_results)
def test_flatten_noop_for_flat_monad(monad) -> None:
assert monad.flatten() == monad


@given(monad=st_nested_results)
def test_and_then_flattens(monad):
assert monad.flatten() == monad.and_then(identity)
Loading