Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 108 additions & 20 deletions coconut/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,14 @@ def bind(cls):
cls.method("has_expr_setname_manage"),
include_in_packrat_context=True,
)
cls.inner_match_comp_for <<= attach(
cls.inner_match_comp_for_ref,
cls.method("inner_match_comp_for_handle"),
)
cls.implicit_inner_match_comp_for <<= attach(
cls.implicit_inner_match_comp_for_ref,
cls.method("inner_match_comp_for_handle"),
)
cls.normal_comp_expr <<= manage(
cls.normal_comp_expr_ref,
cls.method("has_expr_setname_manage"),
Expand Down Expand Up @@ -4594,50 +4602,130 @@ def stmt_lambdef_handle(self, original, loc, tokens):

return self.handle_expr_scope_closure(name, loc)

def match_comp_expr_handle(self, original, loc, tokens, dict_val=None):
"""Build a match comprehension by creating a temp match function.
For dict comps, expr is the key and dict_val is the value."""
expr, (matches, iterable) = tokens
def _build_match_comp_filter_func(self, original, loc, matches, iterable):
"""Generate a filter function for a pattern-matching comprehension clause.

Returns (func_expr, iter_var, unpack, extra_func_setup) where:
- func_expr: the compiled function reference (possibly a closure)
- iter_var: temp var to use as the outer loop variable over iterable
- unpack: the target expression to use in 'for unpack in func_expr(iter_var)'
The generated function takes one item and returns [] on no match or
[(bound_var1, bound_var2, ...)] on match, so it can be used as:
for iter_var in iterable for unpack in func_expr(iter_var)
"""
func_name = self.get_temp_var("match_comp", loc)
item_param = self.get_temp_var("match_comp_item", loc)
iter_var = self.get_temp_var("match_comp_iter", loc)
check_var = self.get_temp_var("match_check", loc)

matcher = self.get_matcher(original, loc, check_var)
matcher.match(matches, iter_var)

bound_names = []
matcher = self.get_matcher(original, loc, check_var, name_list=bound_names)
matcher.match(matches, item_param)
match_code = matcher.build()
match_error = self.pattern_error(original, loc, iter_var, check_var)

if dict_val is not None:
return_expr = "(" + expr + ", " + dict_val + ")"
if bound_names:
bound_tuple = "(" + ", ".join(bound_names) + ("," if len(bound_names) == 1 else "") + ")"
unpack = bound_tuple
else:
return_expr = expr
bound_tuple = "None"
dummy_var = self.get_temp_var("match_comp_dummy", loc)
unpack = dummy_var

funcdef = handle_indentation(
"""
def {func_name}({iter_var}):
def {func_name}({item_param}):
{match_code}
{match_error}
return {return_expr}
if not {check_var}:
return []
return [{bound_tuple}]
""",
add_newline=True,
).format(
func_name=func_name,
iter_var=iter_var,
item_param=item_param,
match_code=match_code,
match_error=match_error,
return_expr=return_expr,
check_var=check_var,
bound_tuple=bound_tuple,
)

self.add_code_before[func_name] = self.decoratable_funcdef_stmt_handle(original, loc, [funcdef], is_stmt_lambda=True)

# Update expr_setnames context to exclude variables bound by the pattern.
expr_setname_ctx = self.current_parsing_context("expr_setnames")
if expr_setname_ctx is not None and bound_names:
expr_setname_ctx["new_names"] -= set(bound_names)

func_expr = self.handle_expr_scope_closure(func_name, loc)
return func_expr, iter_var, unpack

def match_comp_expr_handle(self, original, loc, tokens, dict_val=None):
"""Build a match comprehension by creating a per-item filter function.
For dict comps, expr is the key and dict_val is the value.

The generated comprehension expands to:
expr for iter_var in iterable for unpack in func(iter_var) [extra_comp_clauses]
This puts all extra clauses (for/if) outside the filter function so they
have access to both pattern-bound variables and any outer loop variables.
"""
expr = tokens[0]
match_for_group = tokens[1]

if len(match_for_group) == 2:
matches, iterable = match_for_group
extra_comp_clauses = ""
else:
matches = match_for_group[0]
iterable = match_for_group[1]
extra_comp_clauses = " " + "".join(match_for_group[2:])

func_expr, iter_var, unpack = self._build_match_comp_filter_func(original, loc, matches, iterable)

# Update expr_setnames for variables introduced by extra comprehension clauses.
if extra_comp_clauses:
expr_setname_ctx = self.current_parsing_context("expr_setnames")
if expr_setname_ctx is not None:
comp_for_vars = set()
for m in re.finditer(r"\bfor\s+(.*?)\s+in\b", extra_comp_clauses):
comp_for_vars.update(re.findall(r"[a-zA-Z_]\w*", m.group(1)))
expr_setname_ctx["new_names"] -= comp_for_vars

# Build the comprehension body: extra_comp_clauses go OUTSIDE the function
# so they can reference both pattern-bound variables and outer loop variables.
if dict_val is not None:
return "_coconut.dict((" + func_expr + "(" + iter_var + ") for " + iter_var + " in " + iterable + "))"
inner_expr = "(" + expr + ", " + dict_val + ")"
return "_coconut.dict(" + inner_expr + " for " + iter_var + " in " + iterable + " for " + unpack + " in " + func_expr + "(" + iter_var + ")" + extra_comp_clauses + ")"
else:
return expr + " for " + iter_var + " in " + iterable + " for " + unpack + " in " + func_expr + "(" + iter_var + ")" + extra_comp_clauses

def inner_match_comp_for_handle(self, original, loc, tokens):
"""Handle a pattern-matching for clause in a non-first position.

Generates a per-item filter function and returns the for-clause string:
for iter_var in iterable for unpack in func(iter_var) [extra_clauses]
This string is inserted into the surrounding comprehension by the grammar,
allowing full access to any outer loop variables in subsequent clauses.
"""
match_for_group = tokens[0]

if len(match_for_group) == 2:
matches, iterable = match_for_group
extra_clauses = ""
else:
return func_expr + "(" + iter_var + ") for " + iter_var + " in " + iterable
matches = match_for_group[0]
iterable = match_for_group[1]
extra_clauses = " " + "".join(match_for_group[2:])

func_expr, iter_var, unpack = self._build_match_comp_filter_func(original, loc, matches, iterable)

# Update expr_setnames for variables introduced by extra comprehension clauses.
if extra_clauses:
expr_setname_ctx = self.current_parsing_context("expr_setnames")
if expr_setname_ctx is not None:
comp_for_vars = set()
for m in re.finditer(r"\bfor\s+(.*?)\s+in\b", extra_clauses):
comp_for_vars.update(re.findall(r"[a-zA-Z_]\w*", m.group(1)))
expr_setname_ctx["new_names"] -= comp_for_vars

return "for " + iter_var + " in " + iterable + " for " + unpack + " in " + func_expr + ("(""") + iter_var + ")" + extra_clauses

def get_parent_expr_setnames(self):
"""Get all expr_setnames in parent contexts, but not the current context."""
Expand Down
25 changes: 22 additions & 3 deletions coconut/compiler/grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,6 +1044,8 @@ class Grammar(object):
new_namedexpr_test = Forward()
comp_for = Forward()
match_comp_for = Forward()
inner_match_comp_for = Forward()
implicit_inner_match_comp_for = Forward()
normal_comp_expr = Forward()
match_comp_expr = Forward()
comprehension_expr = Forward()
Expand Down Expand Up @@ -1086,7 +1088,7 @@ class Grammar(object):
| yield_classic
)
normal_dict_comp_ref = test + colon.suppress() + test + comp_for
match_dict_comp_ref = test + colon.suppress() + test + Optional(keyword("match").suppress()) + match_comp_for
match_dict_comp_ref = test + colon.suppress() + test + match_comp_for
dict_comp = lbrace.suppress() + (
normal_dict_comp
| match_dict_comp
Expand Down Expand Up @@ -1960,14 +1962,31 @@ class Grammar(object):
async_comp_for_ref = addspace(keyword("async") + base_comp_for)
comp_for <<= base_comp_for | async_comp_for
comp_if = addspace(keyword("if") + test_no_cond + Optional(comp_iter))
comp_iter <<= any_of(comp_for, comp_if)
inner_match_comp_for_ref = Group(
keyword("for").suppress()
+ keyword("match").suppress()
+ many_match
+ keyword("in").suppress()
+ comp_it_item
+ Optional(comp_iter)
)
implicit_inner_match_comp_for_ref = Group(
keyword("for").suppress()
+ many_match
+ keyword("in").suppress()
+ comp_it_item
+ Optional(comp_iter)
)
comp_iter <<= any_of(comp_for, comp_if, inner_match_comp_for, implicit_inner_match_comp_for)
match_comp_for <<= Group(
keyword("for").suppress()
+ Optional(keyword("match").suppress())
+ many_match
+ keyword("in").suppress()
+ comp_it_item
+ Optional(comp_iter)
)
match_comp_expr_ref = namedexpr_test + Optional(keyword("match").suppress()) + match_comp_for
match_comp_expr_ref = namedexpr_test + match_comp_for
normal_comp_expr_ref = addspace(namedexpr_test + comp_for)
comprehension_expr <<= (
normal_comp_expr
Expand Down
82 changes: 82 additions & 0 deletions coconut/tests/src/cocotest/agnostic/primary_2.coco
Original file line number Diff line number Diff line change
Expand Up @@ -758,4 +758,86 @@ def primary_test_2() -> bool:
lazy from coconut_nonexistent_test_module4.submodule import missing as lazy_missing2
assert_raises((def => lazy_missing2.attr = 1), ImportError)

# pattern-matching comprehensions (issue #887)
data CompPair(a, b)
data CompTriple(x, y, z)
comp_pairs = [CompPair(1, 2), CompPair(3, 4), CompPair(5, 6)]
comp_triples = [CompTriple(1, 2, 3), CompTriple(4, 5, 6)]

# basic match comp
assert [a + b for CompPair(a, b) in comp_pairs] == [3, 7, 11]

# extra for clause
assert [a + b for CompPair(a, b) in comp_pairs for _ in range(3)] == [3, 3, 3, 7, 7, 7, 11, 11, 11]

# for + if on the loop variable (not pattern variable)
assert [a + b for CompPair(a, b) in comp_pairs for x in range(4) if x % 2 == 0] == [3, 3, 7, 7, 11, 11]

# for with extra list
assert [a + b + i for CompPair(a, b) in comp_pairs for i in [10, 20]] == [13, 23, 17, 27, 21, 31]

# triple data type, basic match comp
assert [x + y + z for CompTriple(x, y, z) in comp_triples] == [6, 15]

# triple + extra for
assert [x * y for CompTriple(x, y, z) in comp_triples for _ in range(2)] == [2, 2, 20, 20]

# if True (no filtering)
assert [a + b for CompPair(a, b) in comp_pairs if True] == [3, 7, 11]

# for + if, only first iteration kept
assert [a + b for CompPair(a, b) in comp_pairs for x in range(2) if x == 0] == [3, 7, 11]

# generator expression with sum
assert sum(a + b for CompPair(a, b) in comp_pairs) == 21

# if filtering on pattern variables
assert [a + b for CompPair(a, b) in comp_pairs if a > 1] == [7, 11]

# if with multiple conditions on pattern variables
assert [a + b for CompPair(a, b) in comp_pairs if a > 1 if b < 6] == [7]

# if with combined condition on pattern variables
assert [a for CompPair(a, b) in comp_pairs if a % 2 == 1 and b % 2 == 0] == [1, 3, 5]

# dict comprehension with pattern match and if on pattern variable
assert {a: b for CompPair(a, b) in comp_pairs if a != 3} == {1: 2, 5: 6}

# set comprehension with pattern match and if on pattern variable
assert {a + b for CompPair(a, b) in comp_pairs if b > 2} == {7, 11}

# generator expression with if on pattern variable
assert list(a * b for CompPair(a, b) in comp_pairs if a > 2) == [12, 30]

# triple with if filtering on pattern variable
assert [x + y + z for CompTriple(x, y, z) in comp_triples if z > 3] == [15]

# arbitrary mixing: normal for followed by pattern-matching for
assert [a + b for z in [comp_pairs] for match CompPair(a, b) in z] == [3, 7, 11]

# arbitrary mixing: pattern-matching for followed by another pattern-matching for
nested_pairs = [[CompPair(1, 2), "x"], [CompPair(3, 4), "y"]]
assert [a + b for [CompPair(a, b), _] in nested_pairs] == [3, 7]

# inner PM for with if guard on outer loop variable
assert [a for z in [1, 3] for match CompPair(a, b) in comp_pairs if a == z] == [1, 3]

# inner PM for: pattern variable available in subsequent clause
assert [a + i for match CompPair(a, b) in comp_pairs for i in range(b - a)] == [1, 3, 3, 5, 5]

# inner PM for with dict comprehension
assert {a: b for z in [comp_pairs] for match CompPair(a, b) in z} == {1: 2, 3: 4, 5: 6}

# implicit inner PM for (no 'match' keyword) in non-first position
assert [a + b for z in [comp_pairs] for CompPair(a, b) in z] == [3, 7, 11]

# implicit inner PM for with if filter
assert [a + b for z in [comp_pairs] for CompPair(a, b) in z if a > 1] == [7, 11]

# implicit inner PM for: two consecutive pattern-matching clauses (first via match_comp_for, second via implicit_inner_match_comp_for)
assert [a + b + x for CompPair(a, b) in comp_pairs for CompTriple(x, y, z) in comp_triples] == [4, 7, 8, 11, 12, 15]

# implicit inner PM for: dict comprehension
assert {a: b for z in [comp_pairs] for CompPair(a, b) in z} == {1: 2, 3: 4, 5: 6}

return True