From c63e1a5fb7a31c5117c8ec1b5432530614c1764c Mon Sep 17 00:00:00 2001 From: Mengxuan Cai Date: Tue, 3 Mar 2026 17:08:28 -0500 Subject: [PATCH] Update tag decorator to support attributes with string value --- src/pydsl/transform.py | 24 ++++++++++++++++++------ tests/e2e/test_transform.py | 14 ++++++++++++++ 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/pydsl/transform.py b/src/pydsl/transform.py index 87d3175..9eaf317 100644 --- a/src/pydsl/transform.py +++ b/src/pydsl/transform.py @@ -6,7 +6,7 @@ from mlir.dialects import transform from mlir.dialects.transform import loop, structured -from mlir.ir import IndexType, IntegerAttr, OpView, UnitAttr +from mlir.ir import IndexType, IntegerAttr, OpView, UnitAttr, Attribute from pydsl.macro import CallMacro, Compiled, Evaluated, Uncompiled from pydsl.protocols import SubtreeOut, ToMLIRBase, lower_single @@ -98,17 +98,29 @@ def decorator(mlir): @CallMacro.generate() -def tag(visitor: ToMLIRBase, attr_name: Evaluated[str]) -> Evaluated[Callable]: +def tag( + visitor: ToMLIRBase, + attr_name: Evaluated[str], + attr_value: Evaluated[str | None] = None, +) -> Evaluated[Callable]: """ - Tags the `mlir` MLIR operation with a MLIR unit attribute with name + Tags the `mlir` MLIR operation with a MLIR attribute with name `attr_name`. Arguments: - `mlir`: AST. The AST node whose equivalent MLIR Operator is to be tagged - with the unit attribute - - `attr_name`: str. The name of the unit attribute + with the attribute + - `attr_name`: str. The name of the attribute + - `attr_value`: str. An optional argument that will convert the attribute + from a unit attribute to a key/value attribute pair. """ - return attr_setter(attr_name, UnitAttr.get()) + if type(attr_name) is not str: + raise TypeError("Attribute name is not a string") + + if attr_value is None: + return attr_setter(attr_name, UnitAttr.get()) + else: + return attr_setter(attr_name, Attribute.parse(attr_value)) @CallMacro.generate() diff --git a/tests/e2e/test_transform.py b/tests/e2e/test_transform.py index 839c185..b33e29e 100644 --- a/tests/e2e/test_transform.py +++ b/tests/e2e/test_transform.py @@ -128,8 +128,22 @@ def f(m: MemRef[UInt32, 4, 4]): ) +def test_tag_with_string_value(): + @compile(globals()) + @tag("mytag", "0") + def f(): + a: F32 = 0.0 + b: F32 = 1.0 + + mlir = f.emit_mlir() + + # No runtime check: testing function of tag only + assert r"mytag = 0" in mlir + + if __name__ == "__main__": run(test_multiple_recursively_tag) run(test_multiple_recursively_int_attr) run(test_cse_then_coalesce) run(test_outline_loop) + run(test_tag_with_string_value)