diff --git a/coconut/compiler/compiler.py b/coconut/compiler/compiler.py
index 757451d6..125bc50c 100644
--- a/coconut/compiler/compiler.py
+++ b/coconut/compiler/compiler.py
@@ -865,6 +865,14 @@ def bind(cls):
             cls.method("has_expr_setname_manage"),
             include_in_packrat_context=True,
         )
+        cls.inner_match_comp_for <<= attach(
+            cls.inner_match_comp_for_ref,
+            cls.method("inner_match_comp_for_handle"),
+        )
+        cls.implicit_inner_match_comp_for <<= attach(
+            cls.implicit_inner_match_comp_for_ref,
+            cls.method("inner_match_comp_for_handle"),
+        )
         cls.normal_comp_expr <<= manage(
             cls.normal_comp_expr_ref,
             cls.method("has_expr_setname_manage"),
@@ -4594,50 +4602,130 @@ def stmt_lambdef_handle(self, original, loc, tokens):
 
         return self.handle_expr_scope_closure(name, loc)
 
-    def match_comp_expr_handle(self, original, loc, tokens, dict_val=None):
-        """Build a match comprehension by creating a temp match function.
-        For dict comps, expr is the key and dict_val is the value."""
-        expr, (matches, iterable) = tokens
+    def _build_match_comp_filter_func(self, original, loc, matches, iterable):
+        """Generate a filter function for a pattern-matching comprehension clause.
+        Returns (func_expr, iter_var, unpack) where:
+        - func_expr: the compiled function reference (possibly a closure)
+        - iter_var: temp var to use as the outer loop variable over iterable
+        - unpack: the target expression to use in 'for unpack in func_expr(iter_var)'
+        The generated function takes one item and returns [] on no match or
+        [(bound_var1, bound_var2, ...)] on match, so it can be used as:
+            for iter_var in iterable for unpack in func_expr(iter_var)
+        """
 
         func_name = self.get_temp_var("match_comp", loc)
+        item_param = self.get_temp_var("match_comp_item", loc)
         iter_var = self.get_temp_var("match_comp_iter", loc)
         check_var = self.get_temp_var("match_check", loc)
 
-        matcher = self.get_matcher(original, loc, check_var)
-        matcher.match(matches, iter_var)
-
+        bound_names = []
+        matcher = self.get_matcher(original, loc, check_var, name_list=bound_names)
+        matcher.match(matches, item_param)
         match_code = matcher.build()
-        match_error = self.pattern_error(original, loc, iter_var, check_var)
-
-        if dict_val is not None:
-            return_expr = "(" + expr + ", " + dict_val + ")"
+        if bound_names:
+            bound_tuple = "(" + ", ".join(bound_names) + ("," if len(bound_names) == 1 else "") + ")"
+            unpack = bound_tuple
         else:
-            return_expr = expr
+            bound_tuple = "None"
+            dummy_var = self.get_temp_var("match_comp_dummy", loc)
+            unpack = dummy_var
 
         funcdef = handle_indentation(
             """
-def {func_name}({iter_var}):
+def {func_name}({item_param}):
     {match_code}
-    {match_error}
-    return {return_expr}
+    if not {check_var}:
+        return []
+    return [{bound_tuple}]
             """,
             add_newline=True,
         ).format(
             func_name=func_name,
-            iter_var=iter_var,
+            item_param=item_param,
             match_code=match_code,
-            match_error=match_error,
-            return_expr=return_expr,
+            check_var=check_var,
+            bound_tuple=bound_tuple,
         )
         self.add_code_before[func_name] = self.decoratable_funcdef_stmt_handle(original, loc, [funcdef], is_stmt_lambda=True)
 
+        # Update expr_setnames context to exclude variables bound by the pattern.
+        expr_setname_ctx = self.current_parsing_context("expr_setnames")
+        if expr_setname_ctx is not None and bound_names:
+            expr_setname_ctx["new_names"] -= set(bound_names)
+        func_expr = self.handle_expr_scope_closure(func_name, loc)
+        return func_expr, iter_var, unpack
+
+    def match_comp_expr_handle(self, original, loc, tokens, dict_val=None):
+        """Build a match comprehension by creating a per-item filter function.
+        For dict comps, expr is the key and dict_val is the value.
+        The generated comprehension expands to:
+            expr for iter_var in iterable for unpack in func(iter_var) [extra_comp_clauses]
+        This puts all extra clauses (for/if) outside the filter function so they
+        have access to both pattern-bound variables and any outer loop variables.
+        """
+        expr = tokens[0]
+        match_for_group = tokens[1]
+
+        if len(match_for_group) == 2:
+            matches, iterable = match_for_group
+            extra_comp_clauses = ""
+        else:
+            matches = match_for_group[0]
+            iterable = match_for_group[1]
+            extra_comp_clauses = " " + "".join(match_for_group[2:])
+
+        func_expr, iter_var, unpack = self._build_match_comp_filter_func(original, loc, matches, iterable)
+
+        # Update expr_setnames for variables introduced by extra comprehension clauses.
+        if extra_comp_clauses:
+            expr_setname_ctx = self.current_parsing_context("expr_setnames")
+            if expr_setname_ctx is not None:
+                comp_for_vars = set()
+                for m in re.finditer(r"\bfor\s+(.*?)\s+in\b", extra_comp_clauses):
+                    comp_for_vars.update(re.findall(r"[a-zA-Z_]\w*", m.group(1)))
+                expr_setname_ctx["new_names"] -= comp_for_vars
+
+        # Build the comprehension body: extra_comp_clauses go OUTSIDE the function
+        # so they can reference both pattern-bound variables and outer loop variables.
         if dict_val is not None:
-            return "_coconut.dict((" + func_expr + "(" + iter_var + ") for " + iter_var + " in " + iterable + "))"
+            inner_expr = "(" + expr + ", " + dict_val + ")"
+            return "_coconut.dict(" + inner_expr + " for " + iter_var + " in " + iterable + " for " + unpack + " in " + func_expr + "(" + iter_var + ")" + extra_comp_clauses + ")"
+        else:
+            return expr + " for " + iter_var + " in " + iterable + " for " + unpack + " in " + func_expr + "(" + iter_var + ")" + extra_comp_clauses
+
+    def inner_match_comp_for_handle(self, original, loc, tokens):
+        """Handle a pattern-matching for clause in a non-first position.
+
+        Generates a per-item filter function and returns the for-clause string:
+            for iter_var in iterable for unpack in func(iter_var) [extra_clauses]
+        This string is inserted into the surrounding comprehension by the grammar,
+        allowing full access to any outer loop variables in subsequent clauses.
+        """
+        match_for_group = tokens[0]
+
+        if len(match_for_group) == 2:
+            matches, iterable = match_for_group
+            extra_clauses = ""
         else:
-            return func_expr + "(" + iter_var + ") for " + iter_var + " in " + iterable
+            matches = match_for_group[0]
+            iterable = match_for_group[1]
+            extra_clauses = " " + "".join(match_for_group[2:])
+
+        func_expr, iter_var, unpack = self._build_match_comp_filter_func(original, loc, matches, iterable)
+
+        # Update expr_setnames for variables introduced by extra comprehension clauses.
+        if extra_clauses:
+            expr_setname_ctx = self.current_parsing_context("expr_setnames")
+            if expr_setname_ctx is not None:
+                comp_for_vars = set()
+                for m in re.finditer(r"\bfor\s+(.*?)\s+in\b", extra_clauses):
+                    comp_for_vars.update(re.findall(r"[a-zA-Z_]\w*", m.group(1)))
+                expr_setname_ctx["new_names"] -= comp_for_vars
+
+        return "for " + iter_var + " in " + iterable + " for " + unpack + " in " + func_expr + "(" + iter_var + ")" + extra_clauses
 
     def get_parent_expr_setnames(self):
         """Get all expr_setnames in parent contexts, but not the current context."""
diff --git a/coconut/compiler/grammar.py b/coconut/compiler/grammar.py
index fe227f95..5313730e 100644
--- a/coconut/compiler/grammar.py
+++ b/coconut/compiler/grammar.py
@@ -1044,6 +1044,8 @@ class Grammar(object):
     new_namedexpr_test = Forward()
     comp_for = Forward()
     match_comp_for = Forward()
+    inner_match_comp_for = Forward()
+    implicit_inner_match_comp_for = Forward()
     normal_comp_expr = Forward()
     match_comp_expr = Forward()
     comprehension_expr = Forward()
@@ -1086,7 +1088,7 @@ class Grammar(object):
         | yield_classic
     )
     normal_dict_comp_ref = test + colon.suppress() + test + comp_for
-    match_dict_comp_ref = test + colon.suppress() + test + Optional(keyword("match").suppress()) + match_comp_for
+    match_dict_comp_ref = test + colon.suppress() + test + match_comp_for
     dict_comp = lbrace.suppress() + (
         normal_dict_comp
         | match_dict_comp
@@ -1960,14 +1962,31 @@ class Grammar(object):
     async_comp_for_ref = addspace(keyword("async") + base_comp_for)
     comp_for <<= base_comp_for | async_comp_for
     comp_if = addspace(keyword("if") + test_no_cond + Optional(comp_iter))
-    comp_iter <<= any_of(comp_for, comp_if)
+    inner_match_comp_for_ref = Group(
+        keyword("for").suppress()
+        + keyword("match").suppress()
+        + many_match
+        + keyword("in").suppress()
+        + comp_it_item
+        + Optional(comp_iter)
+    )
+    implicit_inner_match_comp_for_ref = Group(
+        keyword("for").suppress()
+        + many_match
+        + keyword("in").suppress()
+        + comp_it_item
+        + Optional(comp_iter)
+    )
+    comp_iter <<= any_of(comp_for, comp_if, inner_match_comp_for, implicit_inner_match_comp_for)
     match_comp_for <<= Group(
         keyword("for").suppress()
+        + Optional(keyword("match").suppress())
         + many_match
         + keyword("in").suppress()
         + comp_it_item
+        + Optional(comp_iter)
     )
-    match_comp_expr_ref = namedexpr_test + Optional(keyword("match").suppress()) + match_comp_for
+    match_comp_expr_ref = namedexpr_test + match_comp_for
     normal_comp_expr_ref = addspace(namedexpr_test + comp_for)
     comprehension_expr <<= (
         normal_comp_expr
diff --git a/coconut/tests/src/cocotest/agnostic/primary_2.coco b/coconut/tests/src/cocotest/agnostic/primary_2.coco
index 51496eea..38a4d73b 100644
--- a/coconut/tests/src/cocotest/agnostic/primary_2.coco
+++ b/coconut/tests/src/cocotest/agnostic/primary_2.coco
@@ -758,4 +758,86 @@ def primary_test_2() -> bool:
     lazy from coconut_nonexistent_test_module4.submodule import missing as lazy_missing2
     assert_raises((def => lazy_missing2.attr = 1), ImportError)
 
+    # pattern-matching comprehensions (issue #887)
+    data CompPair(a, b)
+    data CompTriple(x, y, z)
+    comp_pairs = [CompPair(1, 2), CompPair(3, 4), CompPair(5, 6)]
+    comp_triples = [CompTriple(1, 2, 3), CompTriple(4, 5, 6)]
+
+    # basic match comp
+    assert [a + b for CompPair(a, b) in comp_pairs] == [3, 7, 11]
+
+    # extra for clause
+    assert [a + b for CompPair(a, b) in comp_pairs for _ in range(3)] == [3, 3, 3, 7, 7, 7, 11, 11, 11]
+
+    # for + if on the loop variable (not pattern variable)
+    assert [a + b for CompPair(a, b) in comp_pairs for x in range(4) if x % 2 == 0] == [3, 3, 7, 7, 11, 11]
+
+    # for with extra list
+    assert [a + b + i for CompPair(a, b) in comp_pairs for i in [10, 20]] == [13, 23, 17, 27, 21, 31]
+
+    # triple data type, basic match comp
+    assert [x + y + z for CompTriple(x, y, z) in comp_triples] == [6, 15]
+
+    # triple + extra for
+    assert [x * y for CompTriple(x, y, z) in comp_triples for _ in range(2)] == [2, 2, 20, 20]
+
+    # if True (no filtering)
+    assert [a + b for CompPair(a, b) in comp_pairs if True] == [3, 7, 11]
+
+    # for + if, only first iteration kept
+    assert [a + b for CompPair(a, b) in comp_pairs for x in range(2) if x == 0] == [3, 7, 11]
+
+    # generator expression with sum
+    assert sum(a + b for CompPair(a, b) in comp_pairs) == 21
+
+    # if filtering on pattern variables
+    assert [a + b for CompPair(a, b) in comp_pairs if a > 1] == [7, 11]
+
+    # if with multiple conditions on pattern variables
+    assert [a + b for CompPair(a, b) in comp_pairs if a > 1 if b < 6] == [7]
+
+    # if with combined condition on pattern variables
+    assert [a for CompPair(a, b) in comp_pairs if a % 2 == 1 and b % 2 == 0] == [1, 3, 5]
+
+    # dict comprehension with pattern match and if on pattern variable
+    assert {a: b for CompPair(a, b) in comp_pairs if a != 3} == {1: 2, 5: 6}
+
+    # set comprehension with pattern match and if on pattern variable
+    assert {a + b for CompPair(a, b) in comp_pairs if b > 2} == {7, 11}
+
+    # generator expression with if on pattern variable
+    assert list(a * b for CompPair(a, b) in comp_pairs if a > 2) == [12, 30]
+
+    # triple with if filtering on pattern variable
+    assert [x + y + z for CompTriple(x, y, z) in comp_triples if z > 3] == [15]
+
+    # arbitrary mixing: normal for followed by pattern-matching for
+    assert [a + b for z in [comp_pairs] for match CompPair(a, b) in z] == [3, 7, 11]
+
+    # nested pattern: a data pattern inside a list pattern in a single pattern-matching for
+    nested_pairs = [[CompPair(1, 2), "x"], [CompPair(3, 4), "y"]]
+    assert [a + b for [CompPair(a, b), _] in nested_pairs] == [3, 7]
+
+    # inner PM for with if guard on outer loop variable
+    assert [a for z in [1, 3] for match CompPair(a, b) in comp_pairs if a == z] == [1, 3]
+
+    # explicit 'match' in first position: pattern variables available in subsequent clause
+    assert [a + i for match CompPair(a, b) in comp_pairs for i in range(b - a)] == [1, 3, 5]
+
+    # inner PM for with dict comprehension
+    assert {a: b for z in [comp_pairs] for match CompPair(a, b) in z} == {1: 2, 3: 4, 5: 6}
+
+    # implicit inner PM for (no 'match' keyword) in non-first position
+    assert [a + b for z in [comp_pairs] for CompPair(a, b) in z] == [3, 7, 11]
+
+    # implicit inner PM for with if filter
+    assert [a + b for z in [comp_pairs] for CompPair(a, b) in z if a > 1] == [7, 11]
+
+    # implicit inner PM for: two consecutive pattern-matching clauses (first via match_comp_for, second via implicit_inner_match_comp_for)
+    assert [a + b + x for CompPair(a, b) in comp_pairs for CompTriple(x, y, z) in comp_triples] == [4, 7, 8, 11, 12, 15]
+
+    # implicit inner PM for: dict comprehension
+    assert {a: b for z in [comp_pairs] for CompPair(a, b) in z} == {1: 2, 3: 4, 5: 6}
+
     return True