Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 64 additions & 3 deletions pyenumerable/implementations/pure_python/_enumerable.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,13 +541,15 @@ def zip[TSecond](
def reverse(self, /) -> PurePythonEnumerable[TSource]:
return PurePythonEnumerable(*reversed(self.source))

def union(
def intersect(
self,
second: PurePythonEnumerable[TSource],
/,
*,
comparer: Comparer[TSource] | None = None,
) -> PurePythonEnumerable[TSource]:
if len(self.source) == 0 or len(second.source) == 0:
return PurePythonEnumerable()
comparer_: Comparer[TSource] = (
comparer if comparer is not None else lambda i, o: i == o
)
Expand All @@ -562,14 +564,16 @@ def union(
out.append(inner)
return PurePythonEnumerable(*out)

def union_by[TKey](
def intersect_by[TKey](
self,
second: PurePythonEnumerable[TSource],
key_selector: Callable[[TSource], TKey],
/,
*,
comparer: Comparer[TKey] | None = None,
) -> PurePythonEnumerable[TSource]:
if len(self.source) == 0 or len(second.source) == 0:
return PurePythonEnumerable()
comparer_: Comparer[TKey] = (
comparer if comparer is not None else lambda i, o: i == o
)
Expand All @@ -596,7 +600,6 @@ def sequence_equal(
) -> bool:
if len(self.source) != len(other.source):
return False

comparer_: Comparer[TSource] = (
comparer if comparer is not None else lambda i, o: i == o
)
Expand Down Expand Up @@ -674,6 +677,64 @@ def aggregate(
curr = func(curr, item)
return curr

def union(
self,
second: PurePythonEnumerable[TSource],
/,
*,
comparer: Comparer[TSource] | None = None,
) -> PurePythonEnumerable[TSource]:
if comparer is not None:
out: list[TSource] = []
for inner in self.source:
for captured in out:
if comparer(inner, captured):
break
else:
out.append(inner)
for outer in second.source:
for captured in out:
if comparer(outer, captured):
break
else:
out.append(outer)
return PurePythonEnumerable(*out)
try:
return PurePythonEnumerable(
*dict.fromkeys((*self.source, *second.source)).keys()
)
except TypeError as te:
msg = "TSource doesn't implement __hash__; Comparer isn't given"
raise TypeError(msg) from te

def union_by[TKey](
self,
second: PurePythonEnumerable[TSource],
key_selector: Callable[[TSource], TKey],
/,
*,
comparer: Comparer[TKey] | None = None,
) -> PurePythonEnumerable[TSource]:
comparer_: Comparer[TKey] = (
comparer if comparer is not None else lambda i, o: i == o
)
out: list[TSource] = []
for inner in self.source:
inner_key = key_selector(inner)
for captured in out:
if comparer_(inner_key, key_selector(captured)):
break
else:
out.append(inner)
for outer in second.source:
outer_key = key_selector(outer)
for captured in out:
if comparer_(outer_key, key_selector(captured)):
break
else:
out.append(outer)
return PurePythonEnumerable(*out)

@staticmethod
def _assume_not_empty(instance: PurePythonEnumerable[Any]) -> None:
if len(instance.source) == 0:
Expand Down
4 changes: 2 additions & 2 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Implementation of .NET's [IEnumerable](https://learn.microsoft.com/en-us/dotnet/
- [ ] Implement `Enumerable` for PP Implementation
- [x] Any
- [x] All
- [ ] Aggregate
- [x] Aggregate
- [x] Chunk
- [x] Average
- [x] Append
Expand All @@ -19,7 +19,7 @@ Implementation of .NET's [IEnumerable](https://learn.microsoft.com/en-us/dotnet/
- [x] Contains
- [x] Concat
- [ ] Join
- [ ] Intersect
- [x] Intersect
- [ ] Group join
- [ ] Group by
- [x] Prepend
Expand Down
Empty file.
83 changes: 83 additions & 0 deletions test/unit/pure_python/intersect/test_intersect_by_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from pyenumerable.implementations.pure_python import PurePythonEnumerable
from test.unit.pure_python.test_utility import Person, Point


class TestIntersectByMethod:
def test_when_self_empty(self) -> None:
first_object = PurePythonEnumerable(Point(0, 1), Point(1, 0))
second_object: PurePythonEnumerable[Point] = PurePythonEnumerable()

res = first_object.intersect_by(second_object, lambda point: point.y)

assert res.source == ()

def test_when_second_empty(self) -> None:
first_object: PurePythonEnumerable[Point] = PurePythonEnumerable()
second_object = PurePythonEnumerable(Point(0, 1), Point(1, 0))

res = first_object.intersect_by(second_object, lambda point: point.y)

assert res.source == ()

def test_without_comparer(self) -> None:
first_object = PurePythonEnumerable(
first := Person("john doe", 12, Person("marray doe", 27)),
Person("jane doe", 10, Person("larry doe", 31)),
Person("james doe", 11, None),
second := Person("jacob doe", 17, Person("harry doe", 41)),
third := Person(" doe", 14, Person("jerry doe", 34)),
)
second_object = PurePythonEnumerable(
Person("john doe", 12, Person("arry doe", 27)),
Person("jane doe", 10, Person("curry doe", 35)),
Person("jacob doe", 17, Person("harry doe", 41)),
Person(" doe", 14, Person("jerry doe", 34)),
)

res = first_object.intersect_by(
second_object,
lambda person: None
if person.parent is None
else person.parent.age,
)

assert res.source == (first, second, third)

def test_with_comparer(self) -> None:
first_object = PurePythonEnumerable(
first := Point(5, 1),
Point(3, 3),
Point(4, 5),
second := Point(2, 7),
third := Point(3, 9),
)
second_object = PurePythonEnumerable(
Point(4, -1), Point(3, 2), Point(1, -7), Point(2, -9), Point(5, -8)
)

res = first_object.intersect_by(
second_object,
lambda point: point.y,
comparer=lambda first_y, second_y: abs(first_y) == abs(second_y),
)

assert res.source == (first, second, third)

def test_overlap_remove(self) -> None:
first_object = PurePythonEnumerable(
first := Point(5, 1),
Point(6, 1),
second := Point(2, 7),
third := Point(3, 9),
)
second_object = PurePythonEnumerable(
Point(4, -1), Point(3, 2), Point(1, -7), Point(2, -9), Point(5, -8)
)

res = first_object.intersect_by(
second_object,
lambda point: point.y,
comparer=lambda first_y, second_y: abs(first_y) == abs(second_y),
)

assert res.source == (first, second, third)
73 changes: 73 additions & 0 deletions test/unit/pure_python/intersect/test_intersect_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from pyenumerable.implementations.pure_python import PurePythonEnumerable
from test.unit.pure_python.test_utility import Point


class TestIntersectMethod:
def test_when_self_empty(self) -> None:
first_object = PurePythonEnumerable(*range(7))
second_object: PurePythonEnumerable[int] = PurePythonEnumerable()

res = first_object.intersect(second_object)

assert res.source == ()

def test_when_second_empty(self) -> None:
first_object: PurePythonEnumerable[int] = PurePythonEnumerable()
second_object = PurePythonEnumerable(*range(7))

res = first_object.intersect(second_object)

assert res.source == ()

def test_without_comparer(self) -> None:
first_object = PurePythonEnumerable(*range(end := 7))
second_object = PurePythonEnumerable(*range(start := 3, 9))

res = first_object.intersect(second_object)

assert res.source == tuple(range(start, end))

def test_with_comparer(self) -> None:
first_object = PurePythonEnumerable(
first := Point(0, 1),
Point(0, 3),
Point(0, 4),
second := Point(0, 7),
third := Point(0, 9),
)
second_object = PurePythonEnumerable(
Point(0, -2),
Point(0, 1),
Point(0, 5),
Point(0, 7),
Point(0, 9),
)

res = first_object.intersect(
second_object,
comparer=lambda first_point, second_point: first_point.y
== second_point.y,
)

assert res.source == (first, second, third)

def test_overlap_remove(self) -> None:
first_object = PurePythonEnumerable(
first := Point(0, 1),
Point(3, 1),
second := Point(0, 7),
third := Point(0, 9),
)
second_object = PurePythonEnumerable(
Point(0, 1),
Point(0, 7),
Point(0, 9),
)

res = first_object.intersect(
second_object,
comparer=lambda first_point, second_point: first_point.y
== second_point.y,
)

assert res.source == (first, second, third)
87 changes: 62 additions & 25 deletions test/unit/pure_python/union/test_union_by_method.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,50 @@
from pyenumerable.implementations.pure_python import PurePythonEnumerable
from test.unit.pure_python.test_utility import Person, Point
from test.unit.pure_python.test_utility import Point


class TestUnionByMethod:
def test_without_comparer(self) -> None:
first_object = PurePythonEnumerable(
first := Person("john doe", 12, Person("marray doe", 27)),
Person("jane doe", 10, Person("larry doe", 31)),
Person("james doe", 11, None),
second := Person("jacob doe", 17, Person("harry doe", 41)),
third := Person(" doe", 14, Person("jerry doe", 34)),
*(
first_items := (
Point(0, 1),
Point(0, 2),
Point(0, 3),
)
)
)
second_object = PurePythonEnumerable(
Person("john doe", 12, Person("arry doe", 27)),
Person("jane doe", 10, Person("curry doe", 35)),
Person("jacob doe", 17, Person("harry doe", 41)),
Person(" doe", 14, Person("jerry doe", 34)),
*(
second_items := (
Point(0, 4),
Point(0, 5),
Point(0, 6),
)
)
)

res = first_object.union_by(
second_object,
lambda person: None
if person.parent is None
else person.parent.age,
)
res = first_object.union_by(second_object, lambda point: point.y)

assert res.source == (first, second, third)
assert res.source == first_items + second_items

def test_with_comparer(self) -> None:
first_object = PurePythonEnumerable(
first := Point(5, 1),
Point(6, 1),
Point(3, 3),
Point(4, 5),
second := Point(2, 7),
third := Point(3, 9),
*(
first_items := (
Point(0, 1),
Point(0, 2),
Point(0, 3),
)
)
)
second_object = PurePythonEnumerable(
Point(4, -1), Point(3, 2), Point(1, -7), Point(2, -9), Point(5, -8)
*(
second_items := (
Point(0, -4),
Point(0, -5),
Point(0, -6),
)
)
)

res = first_object.union_by(
Expand All @@ -46,4 +53,34 @@ def test_with_comparer(self) -> None:
comparer=lambda first_y, second_y: abs(first_y) == abs(second_y),
)

assert res.source == (first, second, third)
assert res.source == first_items + second_items

def test_overlap_remove_for_self(self) -> None:
first_object = PurePythonEnumerable(
first := Point(0, 1),
Point(1, 1),
second := Point(0, 3),
)
second_object = PurePythonEnumerable(
third := Point(4, 5),
fourth := Point(6, 7),
)

res = first_object.union_by(second_object, lambda point: point.y)

assert res.source == (first, second, third, fourth)

def test_overlap_remove_for_second(self) -> None:
first_object = PurePythonEnumerable(
first := Point(0, 1),
second := Point(0, 3),
)
second_object = PurePythonEnumerable(
third := Point(4, 5),
Point(5, 5),
fourth := Point(6, 7),
)

res = first_object.union_by(second_object, lambda point: point.y)

assert res.source == (first, second, third, fourth)
Loading