diff --git a/pyenumerable/implementations/pure_python/_enumerable.py b/pyenumerable/implementations/pure_python/_enumerable.py index fe49eba..f25c229 100644 --- a/pyenumerable/implementations/pure_python/_enumerable.py +++ b/pyenumerable/implementations/pure_python/_enumerable.py @@ -541,13 +541,15 @@ def zip[TSecond]( def reverse(self, /) -> PurePythonEnumerable[TSource]: return PurePythonEnumerable(*reversed(self.source)) - def union( + def intersect( self, second: PurePythonEnumerable[TSource], /, *, comparer: Comparer[TSource] | None = None, ) -> PurePythonEnumerable[TSource]: + if len(self.source) == 0 or len(second.source) == 0: + return PurePythonEnumerable() comparer_: Comparer[TSource] = ( comparer if comparer is not None else lambda i, o: i == o ) @@ -562,7 +564,7 @@ def union( out.append(inner) return PurePythonEnumerable(*out) - def union_by[TKey]( + def intersect_by[TKey]( self, second: PurePythonEnumerable[TSource], key_selector: Callable[[TSource], TKey], @@ -570,6 +572,8 @@ def union_by[TKey]( *, comparer: Comparer[TKey] | None = None, ) -> PurePythonEnumerable[TSource]: + if len(self.source) == 0 or len(second.source) == 0: + return PurePythonEnumerable() comparer_: Comparer[TKey] = ( comparer if comparer is not None else lambda i, o: i == o ) @@ -596,7 +600,6 @@ def sequence_equal( ) -> bool: if len(self.source) != len(other.source): return False - comparer_: Comparer[TSource] = ( comparer if comparer is not None else lambda i, o: i == o ) @@ -674,6 +677,64 @@ def aggregate( curr = func(curr, item) return curr + def union( + self, + second: PurePythonEnumerable[TSource], + /, + *, + comparer: Comparer[TSource] | None = None, + ) -> PurePythonEnumerable[TSource]: + if comparer is not None: + out: list[TSource] = [] + for inner in self.source: + for captured in out: + if comparer(inner, captured): + break + else: + out.append(inner) + for outer in second.source: + for captured in out: + if comparer(outer, captured): + break + else: + out.append(outer) + return PurePythonEnumerable(*out) + try: + return PurePythonEnumerable( + *dict.fromkeys((*self.source, *second.source)).keys() + ) + except TypeError as te: + msg = "TSource doesn't implement __hash__; Comparer isn't given" + raise TypeError(msg) from te + + def union_by[TKey]( + self, + second: PurePythonEnumerable[TSource], + key_selector: Callable[[TSource], TKey], + /, + *, + comparer: Comparer[TKey] | None = None, + ) -> PurePythonEnumerable[TSource]: + comparer_: Comparer[TKey] = ( + comparer if comparer is not None else lambda i, o: i == o + ) + out: list[TSource] = [] + for inner in self.source: + inner_key = key_selector(inner) + for captured in out: + if comparer_(inner_key, key_selector(captured)): + break + else: + out.append(inner) + for outer in second.source: + outer_key = key_selector(outer) + for captured in out: + if comparer_(outer_key, key_selector(captured)): + break + else: + out.append(outer) + return PurePythonEnumerable(*out) + @staticmethod def _assume_not_empty(instance: PurePythonEnumerable[Any]) -> None: if len(instance.source) == 0: diff --git a/readme.md b/readme.md index aac0f2f..531d12d 100644 --- a/readme.md +++ b/readme.md @@ -9,7 +9,7 @@ Implementation of .NET's [IEnumerable](https://learn.microsoft.com/en-us/dotnet/ - [ ] Implement `Enumerable` for PP Implementation - [x] Any - [x] All - - [ ] Aggregate + - [x] Aggregate - [x] Chunk - [x] Average - [x] Append @@ -19,7 +19,7 @@ Implementation of .NET's [IEnumerable](https://learn.microsoft.com/en-us/dotnet/ - [x] Contains - [x] Concat - [ ] Join - - [ ] Intersect + - [x] Intersect - [ ] Group join - [ ] Group by - [x] Prepend diff --git a/test/unit/pure_python/intersect/__init__.py b/test/unit/pure_python/intersect/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/unit/pure_python/intersect/test_intersect_by_method.py b/test/unit/pure_python/intersect/test_intersect_by_method.py new file mode 100644 index 0000000..b1d5841 --- /dev/null +++ b/test/unit/pure_python/intersect/test_intersect_by_method.py @@ -0,0 +1,83 @@ +from pyenumerable.implementations.pure_python import PurePythonEnumerable +from test.unit.pure_python.test_utility import Person, Point + + +class TestIntersectByMethod: + def test_when_self_empty(self) -> None: + first_object = PurePythonEnumerable(Point(0, 1), Point(1, 0)) + second_object: PurePythonEnumerable[Point] = PurePythonEnumerable() + + res = first_object.intersect_by(second_object, lambda point: point.y) + + assert res.source == () + + def test_when_second_empty(self) -> None: + first_object: PurePythonEnumerable[Point] = PurePythonEnumerable() + second_object = PurePythonEnumerable(Point(0, 1), Point(1, 0)) + + res = first_object.intersect_by(second_object, lambda point: point.y) + + assert res.source == () + + def test_without_comparer(self) -> None: + first_object = PurePythonEnumerable( + first := Person("john doe", 12, Person("marray doe", 27)), + Person("jane doe", 10, Person("larry doe", 31)), + Person("james doe", 11, None), + second := Person("jacob doe", 17, Person("harry doe", 41)), + third := Person(" doe", 14, Person("jerry doe", 34)), + ) + second_object = PurePythonEnumerable( + Person("john doe", 12, Person("arry doe", 27)), + Person("jane doe", 10, Person("curry doe", 35)), + Person("jacob doe", 17, Person("harry doe", 41)), + Person(" doe", 14, Person("jerry doe", 34)), + ) + + res = first_object.intersect_by( + second_object, + lambda person: None + if person.parent is None + else person.parent.age, + ) + + assert res.source == (first, second, third) + + def test_with_comparer(self) -> None: + first_object = PurePythonEnumerable( + first := Point(5, 1), + Point(3, 3), + Point(4, 5), + second := Point(2, 7), + third := Point(3, 9), + ) + second_object = PurePythonEnumerable( + Point(4, -1), Point(3, 2), Point(1, -7), Point(2, -9), Point(5, -8) + ) + + res = first_object.intersect_by( + second_object, + lambda point: point.y, + comparer=lambda first_y, second_y: abs(first_y) == abs(second_y), + ) + + assert res.source == (first, second, third) + + def test_overlap_remove(self) -> None: + first_object = PurePythonEnumerable( + first := Point(5, 1), + Point(6, 1), + second := Point(2, 7), + third := Point(3, 9), + ) + second_object = PurePythonEnumerable( + Point(4, -1), Point(3, 2), Point(1, -7), Point(2, -9), Point(5, -8) + ) + + res = first_object.intersect_by( + second_object, + lambda point: point.y, + comparer=lambda first_y, second_y: abs(first_y) == abs(second_y), + ) + + assert res.source == (first, second, third) diff --git a/test/unit/pure_python/intersect/test_intersect_method.py b/test/unit/pure_python/intersect/test_intersect_method.py new file mode 100644 index 0000000..7613d62 --- /dev/null +++ b/test/unit/pure_python/intersect/test_intersect_method.py @@ -0,0 +1,73 @@ +from pyenumerable.implementations.pure_python import PurePythonEnumerable +from test.unit.pure_python.test_utility import Point + + +class TestIntersectMethod: + def test_when_self_empty(self) -> None: + first_object = PurePythonEnumerable(*range(7)) + second_object: PurePythonEnumerable[int] = PurePythonEnumerable() + + res = first_object.intersect(second_object) + + assert res.source == () + + def test_when_second_empty(self) -> None: + first_object: PurePythonEnumerable[int] = PurePythonEnumerable() + second_object = PurePythonEnumerable(*range(7)) + + res = first_object.intersect(second_object) + + assert res.source == () + + def test_without_comparer(self) -> None: + first_object = PurePythonEnumerable(*range(end := 7)) + second_object = PurePythonEnumerable(*range(start := 3, 9)) + + res = first_object.intersect(second_object) + + assert res.source == tuple(range(start, end)) + + def test_with_comparer(self) -> None: + first_object = PurePythonEnumerable( + first := Point(0, 1), + Point(0, 3), + Point(0, 4), + second := Point(0, 7), + third := Point(0, 9), + ) + second_object = PurePythonEnumerable( + Point(0, -2), + Point(0, 1), + Point(0, 5), + Point(0, 7), + Point(0, 9), + ) + + res = first_object.intersect( + second_object, + comparer=lambda first_point, second_point: first_point.y + == second_point.y, + ) + + assert res.source == (first, second, third) + + def test_overlap_remove(self) -> None: + first_object = PurePythonEnumerable( + first := Point(0, 1), + Point(3, 1), + second := Point(0, 7), + third := Point(0, 9), + ) + second_object = PurePythonEnumerable( + Point(0, 1), + Point(0, 7), + Point(0, 9), + ) + + res = first_object.intersect( + second_object, + comparer=lambda first_point, second_point: first_point.y + == second_point.y, + ) + + assert res.source == (first, second, third) diff --git a/test/unit/pure_python/union/test_union_by_method.py b/test/unit/pure_python/union/test_union_by_method.py index 35ede6d..adcf256 100644 --- a/test/unit/pure_python/union/test_union_by_method.py +++ b/test/unit/pure_python/union/test_union_by_method.py @@ -1,43 +1,50 @@ from pyenumerable.implementations.pure_python import PurePythonEnumerable -from test.unit.pure_python.test_utility import Person, Point +from test.unit.pure_python.test_utility import Point class TestUnionByMethod: def test_without_comparer(self) -> None: first_object = PurePythonEnumerable( - first := Person("john doe", 12, Person("marray doe", 27)), - Person("jane doe", 10, Person("larry doe", 31)), - Person("james doe", 11, None), - second := Person("jacob doe", 17, Person("harry doe", 41)), - third := Person(" doe", 14, Person("jerry doe", 34)), + *( + first_items := ( + Point(0, 1), + Point(0, 2), + Point(0, 3), + ) + ) ) second_object = PurePythonEnumerable( - Person("john doe", 12, Person("arry doe", 27)), - Person("jane doe", 10, Person("curry doe", 35)), - Person("jacob doe", 17, Person("harry doe", 41)), - Person(" doe", 14, Person("jerry doe", 34)), + *( + second_items := ( + Point(0, 4), + Point(0, 5), + Point(0, 6), + ) + ) ) - res = first_object.union_by( - second_object, - lambda person: None - if person.parent is None - else person.parent.age, - ) + res = first_object.union_by(second_object, lambda point: point.y) - assert res.source == (first, second, third) + assert res.source == first_items + second_items def test_with_comparer(self) -> None: first_object = PurePythonEnumerable( - first := Point(5, 1), - Point(6, 1), - Point(3, 3), - Point(4, 5), - second := Point(2, 7), - third := Point(3, 9), + *( + first_items := ( + Point(0, 1), + Point(0, 2), + Point(0, 3), + ) + ) ) second_object = PurePythonEnumerable( - Point(4, -1), Point(3, 2), Point(1, -7), Point(2, -9), Point(5, -8) + *( + second_items := ( + Point(0, -4), + Point(0, -5), + Point(0, -6), + ) + ) ) res = first_object.union_by( @@ -46,4 +53,34 @@ def test_with_comparer(self) -> None: comparer=lambda first_y, second_y: abs(first_y) == abs(second_y), ) - assert res.source == (first, second, third) + assert res.source == first_items + second_items + + def test_overlap_remove_for_self(self) -> None: + first_object = PurePythonEnumerable( + first := Point(0, 1), + Point(1, 1), + second := Point(0, 3), + ) + second_object = PurePythonEnumerable( + third := Point(4, 5), + fourth := Point(6, 7), + ) + + res = first_object.union_by(second_object, lambda point: point.y) + + assert res.source == (first, second, third, fourth) + + def test_overlap_remove_for_second(self) -> None: + first_object = PurePythonEnumerable( + first := Point(0, 1), + second := Point(0, 3), + ) + second_object = PurePythonEnumerable( + third := Point(4, 5), + Point(5, 5), + fourth := Point(6, 7), + ) + + res = first_object.union_by(second_object, lambda point: point.y) + + assert res.source == (first, second, third, fourth) diff --git a/test/unit/pure_python/union/test_union_method.py b/test/unit/pure_python/union/test_union_method.py index 448fb58..05b01cf 100644 --- a/test/unit/pure_python/union/test_union_method.py +++ b/test/unit/pure_python/union/test_union_method.py @@ -1,32 +1,56 @@ +import pytest + from pyenumerable.implementations.pure_python import PurePythonEnumerable from test.unit.pure_python.test_utility import Point class TestUnionMethod: + def test_exc_raise_when_unhashable(self) -> None: + first_object = PurePythonEnumerable(Point(0, 1), Point(1, 0)) + second_object = PurePythonEnumerable(Point(1, 0), Point(0, 1)) + + with pytest.raises(TypeError): + first_object.union(second_object) + def test_without_comparer(self) -> None: - first_object = PurePythonEnumerable(*range(end := 7)) - second_object = PurePythonEnumerable(*range(start := 3, 9)) + first_object = PurePythonEnumerable(*(first_items := tuple(range(3)))) + second_object = PurePythonEnumerable( + *(second_items := tuple(range(7, 10))) + ) res = first_object.union(second_object) - assert res.source == tuple(range(start, end)) + assert res.source == first_items + second_items def test_with_comparer(self) -> None: - first_object = PurePythonEnumerable( - first := Point(0, 1), - Point(3, 1), - Point(0, 3), - Point(0, 4), - second := Point(0, 7), - third := Point(0, 9), + first_object = PurePythonEnumerable(*(items := tuple(range(7)))) + second_object = PurePythonEnumerable(*(-i for i in items)) + + res = first_object.union( + second_object, comparer=lambda x, y: abs(x) == abs(y) ) + + assert res.source == items + + def test_overlap_remove_for_self(self) -> None: + first_object = PurePythonEnumerable(first := Point(0, 1), Point(1, 1)) second_object = PurePythonEnumerable( - Point(0, -2), - Point(0, 1), - Point(0, 5), - Point(0, 7), - Point(0, 9), + second := Point(2, 3), third := Point(4, 5) + ) + + res = first_object.union( + second_object, + comparer=lambda first_point, second_point: first_point.y + == second_point.y, + ) + + assert res.source == (first, second, third) + + def test_overlap_remove_for_second(self) -> None: + first_object = PurePythonEnumerable( + first := Point(2, 3), second := Point(4, 5) ) + second_object = PurePythonEnumerable(third := Point(0, 1), Point(1, 1)) res = first_object.union( second_object,