From a9b30b3c3a7b70dd263e5fb883826f4ecd7c2fa5 Mon Sep 17 00:00:00 2001 From: Adam Forest Date: Tue, 24 Feb 2026 14:07:13 -0500 Subject: [PATCH 1/2] Implementing issue #887 --- coconut/compiler/compiler.py | 47 ++++++++++++---- coconut/compiler/grammar.py | 1 + .../src/cocotest/agnostic/primary_2.coco | 54 +++++++++++++++++++ 3 files changed, 93 insertions(+), 9 deletions(-) diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index 757451d6..178cb3e0 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -4597,28 +4597,50 @@ def stmt_lambdef_handle(self, original, loc, tokens): def match_comp_expr_handle(self, original, loc, tokens, dict_val=None): """Build a match comprehension by creating a temp match function. For dict comps, expr is the key and dict_val is the value.""" - expr, (matches, iterable) = tokens + expr = tokens[0] + match_for_group = tokens[1] + + if len(match_for_group) == 2: + matches, iterable = match_for_group + extra_comp_clauses = "" + else: + matches = match_for_group[0] + iterable = match_for_group[1] + extra_comp_clauses = " " + "".join(match_for_group[2:]) func_name = self.get_temp_var("match_comp", loc) iter_var = self.get_temp_var("match_comp_iter", loc) check_var = self.get_temp_var("match_check", loc) + val_var = self.get_temp_var("match_comp_val", loc) matcher = self.get_matcher(original, loc, check_var) matcher.match(matches, iter_var) match_code = matcher.build() - match_error = self.pattern_error(original, loc, iter_var, check_var) if dict_val is not None: - return_expr = "(" + expr + ", " + dict_val + ")" + inner_expr = "(" + expr + ", " + dict_val + ")" else: - return_expr = expr + inner_expr = expr + + # Always return a list: [] on no match, [inner_expr] (or a comprehension) on match. + # This filters non-matching elements instead of raising MatchError. 
+ if extra_comp_clauses and extra_comp_clauses.lstrip().startswith("if "): + # `if` guard: use dummy var trick so the guard is a proper comprehension filter + guard_dummy_var = self.get_temp_var("guard_dummy", loc) + return_expr = "[" + inner_expr + " for " + guard_dummy_var + " in [None]" + extra_comp_clauses + "]" + elif extra_comp_clauses: + # `for` clause(s): put them inside the function so pattern vars are in scope + return_expr = "[" + inner_expr + extra_comp_clauses + "]" + else: + return_expr = "[" + inner_expr + "]" funcdef = handle_indentation( """ def {func_name}({iter_var}): {match_code} - {match_error} + if not {check_var}: + return [] return {return_expr} """, add_newline=True, @@ -4626,19 +4648,26 @@ def {func_name}({iter_var}): func_name=func_name, iter_var=iter_var, match_code=match_code, - match_error=match_error, + check_var=check_var, return_expr=return_expr, ) self.add_code_before[func_name] = self.decoratable_funcdef_stmt_handle(original, loc, [funcdef], is_stmt_lambda=True) + + if extra_comp_clauses: + expr_setname_ctx = self.current_parsing_context("expr_setnames") + if expr_setname_ctx is not None: + comp_for_vars = set() + for m in re.finditer(r"\bfor\s+(.*?)\s+in\b", extra_comp_clauses): + comp_for_vars.update(re.findall(r"[a-zA-Z_]\w*", m.group(1))) + expr_setname_ctx["new_names"] -= comp_for_vars func_expr = self.handle_expr_scope_closure(func_name, loc) if dict_val is not None: - return "_coconut.dict((" + func_expr + "(" + iter_var + ") for " + iter_var + " in " + iterable + "))" + return "_coconut.dict(" + val_var + " for " + iter_var + " in " + iterable + " for " + val_var + " in " + func_expr + "(" + iter_var + "))" else: - return func_expr + "(" + iter_var + ") for " + iter_var + " in " + iterable - + return val_var + " for " + iter_var + " in " + iterable + " for " + val_var + " in " + func_expr + "(" + iter_var + ")" def get_parent_expr_setnames(self): """Get all expr_setnames in parent contexts, but not the current context.""" 
expr_setname_context = self.current_parsing_context("expr_setnames") diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index fe227f95..3d528e95 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -1966,6 +1966,7 @@ class Grammar(object): + many_match + keyword("in").suppress() + comp_it_item + + Optional(comp_iter) ) match_comp_expr_ref = namedexpr_test + Optional(keyword("match").suppress()) + match_comp_for normal_comp_expr_ref = addspace(namedexpr_test + comp_for) diff --git a/coconut/tests/src/cocotest/agnostic/primary_2.coco b/coconut/tests/src/cocotest/agnostic/primary_2.coco index 51496eea..3dd18b5c 100644 --- a/coconut/tests/src/cocotest/agnostic/primary_2.coco +++ b/coconut/tests/src/cocotest/agnostic/primary_2.coco @@ -758,4 +758,58 @@ def primary_test_2() -> bool: lazy from coconut_nonexistent_test_module4.submodule import missing as lazy_missing2 assert_raises((def => lazy_missing2.attr = 1), ImportError) + # pattern-matching comprehensions (issue #887) + data CompPair(a, b) + data CompTriple(x, y, z) + comp_pairs = [CompPair(1, 2), CompPair(3, 4), CompPair(5, 6)] + comp_triples = [CompTriple(1, 2, 3), CompTriple(4, 5, 6)] + + # basic match comp + assert [a + b for CompPair(a, b) in comp_pairs] == [3, 7, 11] + + # extra for clause + assert [a + b for CompPair(a, b) in comp_pairs for _ in range(3)] == [3, 3, 3, 7, 7, 7, 11, 11, 11] + + # for + if on the loop variable (not pattern variable) + assert [a + b for CompPair(a, b) in comp_pairs for x in range(4) if x % 2 == 0] == [3, 3, 7, 7, 11, 11] + + # for with extra list + assert [a + b + i for CompPair(a, b) in comp_pairs for i in [10, 20]] == [13, 23, 17, 27, 21, 31] + + # triple data type, basic match comp + assert [x + y + z for CompTriple(x, y, z) in comp_triples] == [6, 15] + + # triple + extra for + assert [x * y for CompTriple(x, y, z) in comp_triples for _ in range(2)] == [2, 2, 20, 20] + + # if True (no filtering) + assert [a + b for CompPair(a, b) 
in comp_pairs if True] == [3, 7, 11] + + # for + if, only first iteration kept + assert [a + b for CompPair(a, b) in comp_pairs for x in range(2) if x == 0] == [3, 7, 11] + + # generator expression with sum + assert sum(a + b for CompPair(a, b) in comp_pairs) == 21 + + # if filtering on pattern variables + assert [a + b for CompPair(a, b) in comp_pairs if a > 1] == [7, 11] + + # if with multiple conditions on pattern variables + assert [a + b for CompPair(a, b) in comp_pairs if a > 1 if b < 6] == [7] + + # if with combined condition on pattern variables + assert [a for CompPair(a, b) in comp_pairs if a % 2 == 1 and b % 2 == 0] == [1, 3, 5] + + # dict comprehension with pattern match and if on pattern variable + assert {a: b for CompPair(a, b) in comp_pairs if a != 3} == {1: 2, 5: 6} + + # set comprehension with pattern match and if on pattern variable + assert {a + b for CompPair(a, b) in comp_pairs if b > 2} == {7, 11} + + # generator expression with if on pattern variable + assert list(a * b for CompPair(a, b) in comp_pairs if a > 2) == [12, 30] + + # triple with if filtering on pattern variable + assert [x + y + z for CompTriple(x, y, z) in comp_triples if z > 3] == [15] + return True From b2e3fa643d65af13ffea319e503ef49c3e3552b8 Mon Sep 17 00:00:00 2001 From: Adam Forest Date: Sun, 1 Mar 2026 21:22:41 -0500 Subject: [PATCH 2/2] Supporting the arbitrary mixing and matching of standard and pattern-matching comprehensions --- coconut/compiler/compiler.py | 139 +++++++++++++----- coconut/compiler/grammar.py | 24 ++- .../src/cocotest/agnostic/primary_2.coco | 28 ++++ 3 files changed, 148 insertions(+), 43 deletions(-) diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py index 178cb3e0..125bc50c 100644 --- a/coconut/compiler/compiler.py +++ b/coconut/compiler/compiler.py @@ -865,6 +865,14 @@ def bind(cls): cls.method("has_expr_setname_manage"), include_in_packrat_context=True, ) + cls.inner_match_comp_for <<= attach( + 
cls.inner_match_comp_for_ref, + cls.method("inner_match_comp_for_handle"), + ) + cls.implicit_inner_match_comp_for <<= attach( + cls.implicit_inner_match_comp_for_ref, + cls.method("inner_match_comp_for_handle"), + ) cls.normal_comp_expr <<= manage( cls.normal_comp_expr_ref, cls.method("has_expr_setname_manage"), @@ -4594,67 +4602,84 @@ def stmt_lambdef_handle(self, original, loc, tokens): return self.handle_expr_scope_closure(name, loc) - def match_comp_expr_handle(self, original, loc, tokens, dict_val=None): - """Build a match comprehension by creating a temp match function. - For dict comps, expr is the key and dict_val is the value.""" - expr = tokens[0] - match_for_group = tokens[1] - - if len(match_for_group) == 2: - matches, iterable = match_for_group - extra_comp_clauses = "" - else: - matches = match_for_group[0] - iterable = match_for_group[1] - extra_comp_clauses = " " + "".join(match_for_group[2:]) + def _build_match_comp_filter_func(self, original, loc, matches, iterable): + """Generate a filter function for a pattern-matching comprehension clause. 
+ Returns (func_expr, iter_var, unpack) where: + - func_expr: the compiled function reference (possibly a closure) + - iter_var: temp var to use as the outer loop variable over iterable + - unpack: the target expression to use in 'for unpack in func_expr(iter_var)' + The generated function takes one item and returns [] on no match or + [(bound_var1, bound_var2, ...)] on match, so it can be used as: + for iter_var in iterable for unpack in func_expr(iter_var) + """ func_name = self.get_temp_var("match_comp", loc) + item_param = self.get_temp_var("match_comp_item", loc) iter_var = self.get_temp_var("match_comp_iter", loc) check_var = self.get_temp_var("match_check", loc) - val_var = self.get_temp_var("match_comp_val", loc) - - matcher = self.get_matcher(original, loc, check_var) - matcher.match(matches, iter_var) + bound_names = [] + matcher = self.get_matcher(original, loc, check_var, name_list=bound_names) + matcher.match(matches, item_param) match_code = matcher.build() - - if dict_val is not None: - inner_expr = "(" + expr + ", " + dict_val + ")" + if bound_names: + bound_tuple = "(" + ", ".join(bound_names) + ("," if len(bound_names) == 1 else "") + ")" + unpack = bound_tuple else: - inner_expr = expr - - # Always return a list: [] on no match, [inner_expr] (or a comprehension) on match. - # This filters non-matching elements instead of raising MatchError. 
- if extra_comp_clauses and extra_comp_clauses.lstrip().startswith("if "): - # `if` guard: use dummy var trick so the guard is a proper comprehension filter - guard_dummy_var = self.get_temp_var("guard_dummy", loc) - return_expr = "[" + inner_expr + " for " + guard_dummy_var + " in [None]" + extra_comp_clauses + "]" - elif extra_comp_clauses: - # `for` clause(s): put them inside the function so pattern vars are in scope - return_expr = "[" + inner_expr + extra_comp_clauses + "]" - else: - return_expr = "[" + inner_expr + "]" + bound_tuple = "None" + dummy_var = self.get_temp_var("match_comp_dummy", loc) + unpack = dummy_var funcdef = handle_indentation( """ -def {func_name}({iter_var}): +def {func_name}({item_param}): {match_code} if not {check_var}: return [] - return {return_expr} + return [{bound_tuple}] """, add_newline=True, ).format( func_name=func_name, - iter_var=iter_var, + item_param=item_param, match_code=match_code, check_var=check_var, - return_expr=return_expr, + bound_tuple=bound_tuple, ) self.add_code_before[func_name] = self.decoratable_funcdef_stmt_handle(original, loc, [funcdef], is_stmt_lambda=True) + # Update expr_setnames context to exclude variables bound by the pattern. + expr_setname_ctx = self.current_parsing_context("expr_setnames") + if expr_setname_ctx is not None and bound_names: + expr_setname_ctx["new_names"] -= set(bound_names) + + func_expr = self.handle_expr_scope_closure(func_name, loc) + return func_expr, iter_var, unpack + + def match_comp_expr_handle(self, original, loc, tokens, dict_val=None): + """Build a match comprehension by creating a per-item filter function. + For dict comps, expr is the key and dict_val is the value. + The generated comprehension expands to: + expr for iter_var in iterable for unpack in func(iter_var) [extra_comp_clauses] + This puts all extra clauses (for/if) outside the filter function so they + have access to both pattern-bound variables and any outer loop variables. 
+ """ + expr = tokens[0] + match_for_group = tokens[1] + + if len(match_for_group) == 2: + matches, iterable = match_for_group + extra_comp_clauses = "" + else: + matches = match_for_group[0] + iterable = match_for_group[1] + extra_comp_clauses = " " + "".join(match_for_group[2:]) + + func_expr, iter_var, unpack = self._build_match_comp_filter_func(original, loc, matches, iterable) + + # Update expr_setnames for variables introduced by extra comprehension clauses. if extra_comp_clauses: expr_setname_ctx = self.current_parsing_context("expr_setnames") if expr_setname_ctx is not None: @@ -4662,12 +4687,46 @@ def {func_name}({iter_var}): for m in re.finditer(r"\bfor\s+(.*?)\s+in\b", extra_comp_clauses): comp_for_vars.update(re.findall(r"[a-zA-Z_]\w*", m.group(1))) expr_setname_ctx["new_names"] -= comp_for_vars - func_expr = self.handle_expr_scope_closure(func_name, loc) + # Build the comprehension body: extra_comp_clauses go OUTSIDE the function + # so they can reference both pattern-bound variables and outer loop variables. if dict_val is not None: - return "_coconut.dict(" + val_var + " for " + iter_var + " in " + iterable + " for " + val_var + " in " + func_expr + "(" + iter_var + "))" + inner_expr = "(" + expr + ", " + dict_val + ")" + return "_coconut.dict(" + inner_expr + " for " + iter_var + " in " + iterable + " for " + unpack + " in " + func_expr + "(" + iter_var + ")" + extra_comp_clauses + ")" + else: + return expr + " for " + iter_var + " in " + iterable + " for " + unpack + " in " + func_expr + "(" + iter_var + ")" + extra_comp_clauses + + def inner_match_comp_for_handle(self, original, loc, tokens): + """Handle a pattern-matching for clause in a non-first position. 
+ + Generates a per-item filter function and returns the for-clause string: + for iter_var in iterable for unpack in func(iter_var) [extra_clauses] + This string is inserted into the surrounding comprehension by the grammar, + allowing full access to any outer loop variables in subsequent clauses. + """ + match_for_group = tokens[0] + + if len(match_for_group) == 2: + matches, iterable = match_for_group + extra_clauses = "" else: - return val_var + " for " + iter_var + " in " + iterable + " for " + val_var + " in " + func_expr + "(" + iter_var + ")" + matches = match_for_group[0] + iterable = match_for_group[1] + extra_clauses = " " + "".join(match_for_group[2:]) + + func_expr, iter_var, unpack = self._build_match_comp_filter_func(original, loc, matches, iterable) + + # Update expr_setnames for variables introduced by extra comprehension clauses. + if extra_clauses: + expr_setname_ctx = self.current_parsing_context("expr_setnames") + if expr_setname_ctx is not None: + comp_for_vars = set() + for m in re.finditer(r"\bfor\s+(.*?)\s+in\b", extra_clauses): + comp_for_vars.update(re.findall(r"[a-zA-Z_]\w*", m.group(1))) + expr_setname_ctx["new_names"] -= comp_for_vars + + return "for " + iter_var + " in " + iterable + " for " + unpack + " in " + func_expr + "(" + iter_var + ")" + extra_clauses + def get_parent_expr_setnames(self): """Get all expr_setnames in parent contexts, but not the current context.""" expr_setname_context = self.current_parsing_context("expr_setnames") diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py index 3d528e95..5313730e 100644 --- a/coconut/compiler/grammar.py +++ b/coconut/compiler/grammar.py @@ -1044,6 +1044,8 @@ class Grammar(object): new_namedexpr_test = Forward() comp_for = Forward() match_comp_for = Forward() + inner_match_comp_for = Forward() + implicit_inner_match_comp_for = Forward() normal_comp_expr = Forward() match_comp_expr = Forward() comprehension_expr = Forward() @@ -1086,7 +1088,7 @@ class 
Grammar(object): | yield_classic ) normal_dict_comp_ref = test + colon.suppress() + test + comp_for - match_dict_comp_ref = test + colon.suppress() + test + Optional(keyword("match").suppress()) + match_comp_for + match_dict_comp_ref = test + colon.suppress() + test + match_comp_for dict_comp = lbrace.suppress() + ( normal_dict_comp | match_dict_comp @@ -1960,15 +1962,31 @@ class Grammar(object): async_comp_for_ref = addspace(keyword("async") + base_comp_for) comp_for <<= base_comp_for | async_comp_for comp_if = addspace(keyword("if") + test_no_cond + Optional(comp_iter)) - comp_iter <<= any_of(comp_for, comp_if) + inner_match_comp_for_ref = Group( + keyword("for").suppress() + + keyword("match").suppress() + + many_match + + keyword("in").suppress() + + comp_it_item + + Optional(comp_iter) + ) + implicit_inner_match_comp_for_ref = Group( + keyword("for").suppress() + + many_match + + keyword("in").suppress() + + comp_it_item + + Optional(comp_iter) + ) + comp_iter <<= any_of(comp_for, comp_if, inner_match_comp_for, implicit_inner_match_comp_for) match_comp_for <<= Group( keyword("for").suppress() + + Optional(keyword("match").suppress()) + many_match + keyword("in").suppress() + comp_it_item + Optional(comp_iter) ) - match_comp_expr_ref = namedexpr_test + Optional(keyword("match").suppress()) + match_comp_for + match_comp_expr_ref = namedexpr_test + match_comp_for normal_comp_expr_ref = addspace(namedexpr_test + comp_for) comprehension_expr <<= ( normal_comp_expr diff --git a/coconut/tests/src/cocotest/agnostic/primary_2.coco b/coconut/tests/src/cocotest/agnostic/primary_2.coco index 3dd18b5c..38a4d73b 100644 --- a/coconut/tests/src/cocotest/agnostic/primary_2.coco +++ b/coconut/tests/src/cocotest/agnostic/primary_2.coco @@ -812,4 +812,32 @@ def primary_test_2() -> bool: # triple with if filtering on pattern variable assert [x + y + z for CompTriple(x, y, z) in comp_triples if z > 3] == [15] + # arbitrary mixing: normal for followed by pattern-matching for + 
assert [a + b for z in [comp_pairs] for match CompPair(a, b) in z] == [3, 7, 11] + + # arbitrary mixing: pattern-matching for followed by another pattern-matching for + nested_pairs = [[CompPair(1, 2), "x"], [CompPair(3, 4), "y"]] + assert [a + b for [CompPair(a, b), _] in nested_pairs] == [3, 7] + + # inner PM for with if guard on outer loop variable + assert [a for z in [1, 3] for match CompPair(a, b) in comp_pairs if a == z] == [1, 3] + + # inner PM for: pattern variable available in subsequent clause + assert [a + i for match CompPair(a, b) in comp_pairs for i in range(b - a)] == [1, 3, 5] + + # inner PM for with dict comprehension + assert {a: b for z in [comp_pairs] for match CompPair(a, b) in z} == {1: 2, 3: 4, 5: 6} + + # implicit inner PM for (no 'match' keyword) in non-first position + assert [a + b for z in [comp_pairs] for CompPair(a, b) in z] == [3, 7, 11] + + # implicit inner PM for with if filter + assert [a + b for z in [comp_pairs] for CompPair(a, b) in z if a > 1] == [7, 11] + + # implicit inner PM for: two consecutive pattern-matching clauses (first via match_comp_for, second via implicit_inner_match_comp_for) + assert [a + b + x for CompPair(a, b) in comp_pairs for CompTriple(x, y, z) in comp_triples] == [4, 7, 8, 11, 12, 15] + + # implicit inner PM for: dict comprehension + assert {a: b for z in [comp_pairs] for CompPair(a, b) in z} == {1: 2, 3: 4, 5: 6} + return True