diff --git a/docs/source/user_guide/fastcache.md b/docs/source/user_guide/fastcache.md index ab07483e97..d7f615efcf 100644 --- a/docs/source/user_guide/fastcache.md +++ b/docs/source/user_guide/fastcache.md @@ -86,6 +86,8 @@ Sub-functions called by the kernel are also checked — they must not capture ex Other named constants (non-enum, non-module) captured from scope will raise a `QuadrantsCompilationError`, except for `UPPERCASE` names which emit a warning instead. +Wrapping a captured global in `qd.static(...)` does **not** exempt it from this check. `qd.static` only controls compile-time evaluation; it does not put the value into the cache key, so a `qd.static`-wrapped global is still flagged — though during the current transition period this emits a warning rather than raising. To use such a constant in a fastcache kernel, pass it as a parameter (template primitive, `@qd.data_oriented` member, or dataclass field) or make it one of the allowed captures above. + ### 2. Supported parameter types Fastcache supports the following parameter types: diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 509995b5b5..c1068430a8 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -94,11 +94,18 @@ def build_Name(ctx: ASTTransformerFuncContext, node: ast.Name): if isinstance(node, (ast.stmt, ast.expr)) and isinstance(node.ptr, Expr): node.ptr.dbg_info = _qd_core.DebugInfo(ctx.get_pos_info(node)) node.ptr.ptr.set_dbg_info(node.ptr.dbg_info) - if ctx.is_pure and node.violates_pure and not ctx.static_scope_status.is_in_static_scope: - if isinstance(node.ptr, (float, int, Field)): + # ``qd.static`` is intentionally NOT a purity escape hatch: a captured module global is still flagged inside + # a static scope, since its value never enters the fastcache key regardless of static wrapping. 
+ if ctx.is_pure and node.violates_pure: + # ``str`` is included alongside the numeric/``Field`` types: a captured string can reach a kernel through + # compile-time ``qd.static`` branches or direct use (e.g. ``print``), yet its value never enters the fastcache + # key, so it is cache-unsafe in exactly the same way as a captured int/float. + if isinstance(node.ptr, (float, int, str, Field)): if not _is_quadrants_internal_file(ctx.file): message = f"[PURE.VIOLATION] WARNING: Accessing global variable {node.id} {type(node.ptr)} {node.violates_pure_reason}" - if node.id.upper() == node.id: + # Transition period: violations inside a ``qd.static`` scope only warn instead of raising, giving + # downstream code time to migrate such constants to kernel params. ``UPPERCASE`` names also warn. + if node.id.upper() == node.id or ctx.is_in_static_scope(): warnings.warn(message) else: raise exception.QuadrantsCompilationError(message) @@ -779,8 +786,10 @@ def build_Attribute(ctx: ASTTransformerFuncContext, node: ast.Attribute): node.violates_pure = node.value.violates_pure if node.violates_pure: node.violates_pure_reason = node.value.violates_pure_reason - if ctx.is_pure and node.violates_pure and not ctx.static_scope_status.is_in_static_scope: - if isinstance(node.ptr, (int, float, Field)): + # ``qd.static`` is intentionally NOT a purity escape hatch (see ``build_Name``). + if ctx.is_pure and node.violates_pure: + # ``str`` included for the same reason as in ``build_Name``: a captured string is cache-unsafe. 
+ if isinstance(node.ptr, (int, float, str, Field)): violation = True if violation and isinstance(node.ptr, enum.Enum): violation = False @@ -790,7 +799,8 @@ def build_Attribute(ctx: ASTTransformerFuncContext, node: ast.Attribute): violation = False if violation: message = f"[PURE.VIOLATION] WARNING: Accessing global var {node.attr} from outside function scope within pure kernel {node.value.violates_pure_reason}" - if node.attr.upper() == node.attr: + # Transition period (see ``build_Name``): ``qd.static`` scope downgrades this to a warning. + if node.attr.upper() == node.attr or ctx.is_in_static_scope(): warnings.warn(message) else: raise exception.QuadrantsCompilationError(message) diff --git a/tests/python/quadrants/lang/fast_caching/test_pure_validation.py b/tests/python/quadrants/lang/fast_caching/test_pure_validation.py index 5d417377eb..a0f51bb70b 100644 --- a/tests/python/quadrants/lang/fast_caching/test_pure_validation.py +++ b/tests/python/quadrants/lang/fast_caching/test_pure_validation.py @@ -42,6 +42,20 @@ def k2(): k2() +@test_utils.test() +def test_pure_validation_str(): + # A captured ``str`` global is cache-unsafe in the same way as a captured int/float, so it must trigger a purity + # violation. Direct access (not wrapped in ``qd.static``) of a lowercase-named global raises. + s = "hello" + + @qd.kernel(pure=True) + def k1(): + print(s) + + with pytest.raises(qd.QuadrantsCompilationError): + k1() + + @test_utils.test() def test_pure_validation_field(): a = qd.field(qd.i32, (10,)) @@ -282,3 +296,43 @@ def k1() -> qd.i32: with pytest.warns(UserWarning, match=r"\[PURE\.VIOLATION\]"): assert k1() == 32 + + +# Restricted to a single (CPU) arch on purpose: the purity check is a Python-side AST analysis and is entirely +# arch-independent, and running it across multiple archs in one worker lets a fastcache hit from one arch suppress the +# warning on the next, which makes ``pytest.warns`` flaky. 
+@test_utils.test(arch=qd.cpu) +def test_pure_validation_static_scope_warns(): + # Transition period: a captured global accessed inside a ``qd.static`` scope of a pure kernel only warns instead of + # raising, to give downstream code time to migrate such constants to kernel parameters. + assert qd.lang is not None + arch = qd.lang.impl.current_cfg().arch + qd.init(arch=arch, offline_cache=False) + + use_alias = True + + @qd.kernel(pure=True) + def k1() -> qd.i32: + ret = 0 + if qd.static(use_alias): + ret = 1 + return ret + + with pytest.warns(UserWarning, match=r"\[PURE\.VIOLATION\]"): + assert k1() == 1 + + class Cfg: + def __init__(self) -> None: + self.flag = True + + cfg = Cfg() + + @qd.kernel(pure=True) + def k2() -> qd.i32: + ret = 0 + if qd.static(cfg.flag): + ret = 1 + return ret + + with pytest.warns(UserWarning, match=r"\[PURE\.VIOLATION\]"): + assert k2() == 1 diff --git a/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py b/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py index 711839cf5d..453ec5621e 100644 --- a/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py +++ b/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py @@ -294,14 +294,14 @@ def src_ll_cache_has_return_child(args: list[str]) -> None: @qd.pure @qd.kernel - def k1(a: qd.i32, output: qd.types.NDArray[qd.i32, 1]) -> bool: + def k1(a: qd.i32, output: qd.types.NDArray[qd.i32, 1], return_something: qd.Template) -> bool: output[0] = a - if qd.static(args_obj.return_something): + if qd.static(return_something): return True output = qd.ndarray(qd.i32, (10,)) if args_obj.return_something: - assert k1(3, output) + assert k1(3, output, args_obj.return_something) # Sanity check that the kernel actually ran, and did something. 
assert output[0] == 3 assert k1._primal.src_ll_cache_observations.cache_key_generated == args_obj.expect_used_src_ll_cache @@ -314,7 +314,7 @@ def k1(a: qd.i32, output: qd.types.NDArray[qd.i32, 1]) -> bool: with pytest.raises( qd.QuadrantsSyntaxError, match="Kernel has a return type but does not have a return statement" ): - k1(3, output) + k1(3, output, args_obj.return_something) print(TEST_RAN) sys.exit(RET_SUCCESS) diff --git a/tests/python/test_tile.py b/tests/python/test_tile.py index d71b514f92..58e39a1d9d 100644 --- a/tests/python/test_tile.py +++ b/tests/python/test_tile.py @@ -85,18 +85,18 @@ def test_zeros(TILE, make_tile, tdim, m_size, tensor_type, qd_dtype, use_zeros_a Ann = _ann(tensor_type, qd_dtype, 2) @qd.kernel(fastcache=True) - def k1(dst_arr: Ann, N: qd.Template): + def k1(dst_arr: Ann, N: qd.Template, use_alias: qd.Template): qd.loop_config(block_dim=N) tile_size = N for _ in range(tile_size): - if qd.static(use_zeros_alias): + if qd.static(use_alias): t = Tile.zeros() t._store(dst_arr, 0, tile_size, 0, tile_size) else: t = Tile() t._store(dst_arr, 0, tile_size, 0, tile_size) - k1(dst, tdim) + k1(dst, tdim, use_zeros_alias) np.testing.assert_allclose(dst.to_numpy(), np.zeros((tdim, tdim), dtype=np_dtype)) @@ -114,7 +114,7 @@ def test_eye(TILE, make_tile, tdim, m_size, tensor_type, qd_dtype, inplace): Ann = _ann(tensor_type, qd_dtype, 2) @qd.kernel(fastcache=True) - def k1(src_arr: Ann, dst_arr: Ann, N: qd.Template): + def k1(src_arr: Ann, dst_arr: Ann, N: qd.Template, inplace: qd.Template): qd.loop_config(block_dim=N) tile_size = N for _ in range(tile_size): @@ -129,7 +129,7 @@ def k1(src_arr: Ann, dst_arr: Ann, N: qd.Template): data = np.arange(tdim * tdim, dtype=np_dtype).reshape(tdim, tdim) + 100.0 src.from_numpy(data) - k1(src, dst, tdim) + k1(src, dst, tdim, inplace) np.testing.assert_allclose(dst.to_numpy(), np.eye(tdim, dtype=np_dtype)) @@ -904,7 +904,7 @@ def test_load_slice_errors(TILE, make_tile, tdim, m_size, bad_slice, match): dst = 
qd.ndarray(qd.f32, (tdim, tdim)) @qd.kernel(fastcache=True) - def k1(s: qd.types.NDArray[qd.f32, 2], d: qd.types.NDArray[qd.f32, 2], N: qd.Template): + def k1(s: qd.types.NDArray[qd.f32, 2], d: qd.types.NDArray[qd.f32, 2], N: qd.Template, bad_slice: qd.Template): qd.loop_config(block_dim=N) tile_size = N for _ in range(tile_size): @@ -920,7 +920,7 @@ def k1(s: qd.types.NDArray[qd.f32, 2], d: qd.types.NDArray[qd.f32, 2], N: qd.Tem d[0:tile_size, 0:tile_size] = t with pytest.raises(QuadrantsSyntaxError, match=match): - k1(src, dst, tdim) + k1(src, dst, tdim, bad_slice) @pytest.mark.parametrize( @@ -939,7 +939,7 @@ def test_store_slice_errors(TILE, make_tile, tdim, m_size, bad_slice, match): dst = qd.ndarray(qd.f32, (tdim, tdim)) @qd.kernel(fastcache=True) - def k1(s: qd.types.NDArray[qd.f32, 2], d: qd.types.NDArray[qd.f32, 2], N: qd.Template): + def k1(s: qd.types.NDArray[qd.f32, 2], d: qd.types.NDArray[qd.f32, 2], N: qd.Template, bad_slice: qd.Template): qd.loop_config(block_dim=N) tile_size = N for _ in range(tile_size): @@ -955,7 +955,7 @@ def k1(s: qd.types.NDArray[qd.f32, 2], d: qd.types.NDArray[qd.f32, 2], N: qd.Tem d[0:, 0:tile_size] = t with pytest.raises(QuadrantsSyntaxError, match=match): - k1(src, dst, tdim) + k1(src, dst, tdim, bad_slice) @test_utils.test(arch=qd.gpu) @@ -1101,7 +1101,7 @@ def test_vec_slice_errors(TILE, make_tile, tdim, m_size, bad_slice): dst = qd.ndarray(qd.f32, (tdim, tdim)) @qd.kernel(fastcache=True) - def k1(s: qd.types.NDArray[qd.f32, 2], d: qd.types.NDArray[qd.f32, 2], N: qd.Template): + def k1(s: qd.types.NDArray[qd.f32, 2], d: qd.types.NDArray[qd.f32, 2], N: qd.Template, bad_slice: qd.Template): qd.loop_config(block_dim=N) tile_size = N for _ in range(tile_size): @@ -1114,7 +1114,7 @@ def k1(s: qd.types.NDArray[qd.f32, 2], d: qd.types.NDArray[qd.f32, 2], N: qd.Tem d[0:tile_size, 0:tile_size] = t with pytest.raises(QuadrantsSyntaxError, match="both start and stop"): - k1(src, dst, tdim) + k1(src, dst, tdim, bad_slice) # 
============================================================================= @@ -1325,7 +1325,14 @@ def test_shared_array_partial_cols(TILE, make_tile, tdim, m_size, partial_store, dst = qd.field(dtype=qd.f32, shape=(tdim, tdim)) @qd.kernel(fastcache=True) - def k1(src_f: qd.Template, dst_f: qd.Template, NCOLS: qd.i32, N: qd.Template): + def k1( + src_f: qd.Template, + dst_f: qd.Template, + NCOLS: qd.i32, + N: qd.Template, + partial_store: qd.Template, + partial_load: qd.Template, + ): qd.loop_config(block_dim=N) tile_size = N for _ in range(tile_size): @@ -1353,7 +1360,7 @@ def k1(src_f: qd.Template, dst_f: qd.Template, NCOLS: qd.i32, N: qd.Template): data = np.arange(tdim * tdim, dtype=np.float32).reshape(tdim, tdim) + 1.0 src.from_numpy(data) - k1(src, dst, NCOLS, tdim) + k1(src, dst, NCOLS, tdim, partial_store, partial_load) result = dst.to_numpy() np.testing.assert_allclose(result[:, :NCOLS], data[:, :NCOLS]) if partial_load: