From b0101fa2089a43ef0098570b44af4fd8d34ef989 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 23 Jun 2026 09:51:55 -0700 Subject: [PATCH 1/5] [Graph] Support member ndarrays as qd.checkpoint(yield_on=) / qd.graph_do_while() arguments `qd.checkpoint(yield_on=...)` and `qd.graph_do_while(...)` previously required the argument to be a bare kernel parameter (`ast.Name`). With this change they also accept attribute chains -- both `@qd.data_oriented` member ndarrays (`self.flag`, `self.counter`) and `@dataclasses.dataclass` parameter members (`params.flag`, `params.counter`) -- resolved to a flat C++ arg-id at AST-build time via a new shared `ASTTransformer._resolve_ndarray_kernel_arg_id` helper that builds the expression and reads the resolved `ExternalTensorExpression.arg_id` via a new `get_external_tensor_arg_id` accessor on `export_lang.cpp`. Any attribute chain that flattens to a kernel ndarray argument works the same way as a bare parameter name, so users no longer have to forward flag / counter members as top-level kernel parameters. The launch path now forwards `Kernel.checkpoint_yield_on_cpp_arg_ids` and `GraphDoWhileLevel.cond_cpp_arg_id` directly to the launch context, removing the per-launch name-matching step. The fast-cache schema bumps to v3 to round-trip the AST-resolved arg-ids alongside the existing graph_do_while level table and the checkpoint yield_on / user-label tables. --- docs/source/user_guide/graph.md | 4 +- .../lang/_fast_caching/src_hasher.py | 53 ++++--- python/quadrants/lang/ast/ast_transformer.py | 86 ++++++++--- .../checkpoint_transformer.py | 49 ++++--- .../function_def_transformer.py | 1 + python/quadrants/lang/kernel.py | 72 ++++++---- python/quadrants/lang/kernel_checkpoint.py | 37 ++--- quadrants/python/export_lang.cpp | 9 ++ tests/python/test_checkpoint.py | 134 +++++++++++++++++- tests/python/test_graph_do_while.py | 77 +++++++++- 10 files changed, 403 insertions(+), 119 deletions(-) diff --git a/docs/source/user_guide/graph.md b/docs/source/user_guide/graph.md index b1c239b080..2a11b0e252 100644 --- a/docs/source/user_guide/graph.md +++ b/docs/source/user_guide/graph.md @@ -87,7 +87,7 @@ solve(x, counter) # x is now incremented 10 times; counter is 0 ``` -The argument to `qd.graph_do_while()` must be the name of a scalar `qd.i32` ndarray parameter. The loop body repeats while this value is non-zero. +The argument to `qd.graph_do_while()` must reference a scalar `qd.i32` ndarray that the kernel can access — a bare kernel parameter (`qd.graph_do_while(counter)`), a `@qd.data_oriented` member ndarray (`qd.graph_do_while(self.counter)`), or a `@dataclasses.dataclass` parameter member (`qd.graph_do_while(params.counter)`). The loop body repeats while this value is non-zero. - On CUDA SM 9.0+ (Hopper), this uses CUDA conditional while nodes — the entire iteration runs on the GPU with no host involvement. - On older CUDA GPUs, AMDGPU, and non-GPU backends, it falls back to a host-side do-while loop (see the [backend support table](#backend-support)). @@ -256,7 +256,7 @@ while status.yielded: - Must be used inside `@qd.kernel(graph=True, checkpoints=True)`. Without the flag, `qd.checkpoint(...)` raises `QuadrantsSyntaxError` at compile time with a fix-it pointing at `checkpoints=True`. - `cp_id` must be an int literal or an `IntEnum` value, and must be unique across the kernel. -- `yield_on=` must be a kernel parameter that is a 0-d `qd.types.ndarray(qd.i32, ndim=0)`; expressions are not supported. +- `yield_on=` must reference a 0-d `qd.types.ndarray(qd.i32, ndim=0)` — a bare kernel parameter (`yield_on=flag`), a `@qd.data_oriented` member ndarray (`yield_on=self.flag`), or a `@dataclasses.dataclass` parameter member (`yield_on=params.flag`). Arbitrary expressions are not supported. - Checkpoints cannot be nested inside other checkpoints. Checkpoints inside a `qd.graph_do_while` body are fine. - The body of a `with qd.checkpoint(...)` block cannot contain bare top-level statements (assignments, augmented assignments, or bare call/expression statements). Every top-level statement must be inside a `for`-loop (or other control-flow construct). A docstring as the first statement is allowed. Bare statements raise `QuadrantsSyntaxError` at compile time with a fix-it pointing at the explicit one-iteration `for`-wrap: diff --git a/python/quadrants/lang/_fast_caching/src_hasher.py b/python/quadrants/lang/_fast_caching/src_hasher.py index 76a6a0160a..6b6c0227df 100644 --- a/python/quadrants/lang/_fast_caching/src_hasher.py +++ b/python/quadrants/lang/_fast_caching/src_hasher.py @@ -18,8 +18,10 @@ from .python_side_cache import PythonSideCache # Bumped whenever the persisted CacheValue schema changes (see create_cache_key). v2 replaced the single -# graph_do_while_arg string with a nested level table. -_CACHE_VALUE_SCHEMA_VERSION = "cachevalue-v2-gdw-levels" +# graph_do_while_arg string with a nested level table. v3 added the AST-resolved flat C++ arg-ids for +# qd.graph_do_while conditions and qd.checkpoint(yield_on=...) targets so the launch path can forward them +# directly without per-launch name matching (necessary for @qd.data_oriented member ndarrays). +_CACHE_VALUE_SCHEMA_VERSION = "cachevalue-v3-ast-resolved-ids" def create_cache_key( @@ -69,9 +71,20 @@ class CacheValue(BaseModel): frontend_cache_key: str hashed_function_source_infos: list[HashedFunctionSourceInfo] used_py_dataclass_parameters: set[str] - # Nested graph_do_while level table as (cond_arg_name, parent_id) pairs, indexed by level id. None / empty for - # kernels without graph_do_while. - graph_do_while_levels: list[tuple[str, int]] | None = None + # Nested graph_do_while level table as (cond_arg_name, parent_id, cond_cpp_arg_id) triples, indexed by level + # id. None / empty for kernels without graph_do_while. ``cond_cpp_arg_id`` is the flat C++ arg-id resolved at + # AST-build time by ``ASTTransformer._resolve_ndarray_kernel_arg_id`` and is required by the launch path to + # support `@qd.data_oriented` member conditions (`qd.graph_do_while(self.counter)`) -- name-matching against + # ``arg_metas`` only resolves top-level parameters. + graph_do_while_levels: list[tuple[str, int, int]] | None = None + # AST-build-time-resolved checkpoint metadata, indexed by internal cp_id. Empty for kernels without any + # `with qd.checkpoint(...)` block. See `Kernel.checkpoint_yield_on_args` / + # `Kernel.checkpoint_yield_on_cpp_arg_ids` / `Kernel.checkpoint_user_labels_by_cp_id` for what each entry means. + # Restored alongside the C++-side cached kernel so the launch path can forward `yield_on=` arg-ids and + # translate `from_checkpoint=` labels without re-running the AST transformer. + checkpoint_yield_on_args: list[str | None] = [] + checkpoint_yield_on_cpp_arg_ids: list[int] = [] + checkpoint_user_labels_by_cp_id: list[int | None] = [] def store( @@ -79,7 +92,10 @@ def store( fast_cache_key: str, function_source_infos: Iterable[FunctionSourceInfo], used_py_dataclass_parameters: set[str], - graph_do_while_levels: list[tuple[str, int]] | None = None, + graph_do_while_levels: list[tuple[str, int, int]] | None = None, + checkpoint_yield_on_args: list[str | None] | None = None, + checkpoint_yield_on_cpp_arg_ids: list[int] | None = None, + checkpoint_user_labels_by_cp_id: list[int | None] | None = None, ) -> None: """ Note that unlike other caches, this cache is not going to store the actual value we want. @@ -108,6 +124,9 @@ def store( hashed_function_source_infos=list(hashed_function_source_infos), used_py_dataclass_parameters=used_py_dataclass_parameters, graph_do_while_levels=graph_do_while_levels, + checkpoint_yield_on_args=checkpoint_yield_on_args or [], + checkpoint_yield_on_cpp_arg_ids=checkpoint_yield_on_cpp_arg_ids or [], + checkpoint_user_labels_by_cp_id=checkpoint_user_labels_by_cp_id or [], ) cache.store(fast_cache_key, cache_value_obj.model_dump_json()) @@ -125,23 +144,19 @@ def _try_load(cache_key: str) -> CacheValue | None: return cache_value_obj -def load( - cache_key: str, -) -> tuple[set[str], str, list[tuple[str, int]] | None] | tuple[None, None, None]: - """ - loads function source infos from cache, if available - checks the hashes against the current source code +def load(cache_key: str) -> CacheValue | None: + """Load a validated ``CacheValue`` for *cache_key* if one exists and its source hashes still match, else None. + + Returns the full ``CacheValue`` (rather than the historical 3-tuple) so callers can pick off the + AST-transformer-produced metadata (graph_do_while levels, checkpoint tables) without the loader having to grow + a new return slot every time we cache a new piece of AST output. """ cache_value = _try_load(cache_key) if cache_value is None: - return None, None, None + return None if function_hasher.validate_hashed_function_infos(cache_value.hashed_function_source_infos): - return ( - cache_value.used_py_dataclass_parameters, - cache_value.frontend_cache_key, - cache_value.graph_do_while_levels, - ) - return None, None, None + return cache_value + return None def dump_stats() -> None: diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 0297fffa8c..29f86dde33 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -1352,19 +1352,67 @@ def build_For(ctx: ASTTransformerFuncContext, node: ast.For) -> None: return ASTTransformer.build_struct_for(ctx, node, is_grouped=False) @staticmethod - def _is_graph_do_while_call(node: ast.expr) -> str | None: - """If *node* is ``qd.graph_do_while(var)`` return the arg name, else None.""" + def _is_graph_do_while_call(node: ast.expr) -> ast.expr | None: + """If *node* is ``qd.graph_do_while(arg)`` return the arg AST node, else None. + + ``arg`` may be an ``ast.Name`` (a bare kernel parameter, e.g. ``counter``) or an ``ast.Attribute`` chain + (a ``@qd.data_oriented`` member ndarray such as ``self.counter`` or a ``@dataclasses.dataclass`` parameter + member such as ``params.counter``). The actual resolution to a kernel ndarray argument happens in + ``build_While`` via ``_resolve_ndarray_kernel_arg_id``. + """ if not isinstance(node, ast.Call): return None func = node.func - if isinstance(func, ast.Attribute) and func.attr == "graph_do_while": - if len(node.args) == 1 and isinstance(node.args[0], ast.Name): - return node.args[0].id - if isinstance(func, ast.Name) and func.id == "graph_do_while": - if len(node.args) == 1 and isinstance(node.args[0], ast.Name): - return node.args[0].id + is_gdw = (isinstance(func, ast.Attribute) and func.attr == "graph_do_while") or ( + isinstance(func, ast.Name) and func.id == "graph_do_while" + ) + if not is_gdw: + return None + if len(node.args) == 1 and isinstance(node.args[0], (ast.Name, ast.Attribute)): + return node.args[0] return None + @staticmethod + def _resolve_ndarray_kernel_arg_id( + ctx: ASTTransformerFuncContext, + kernel, + node: ast.expr, + usage: str, + ) -> tuple[str, int]: + """Resolve an ndarray-referencing expression to ``(label, flat_cpp_arg_id)`` at AST-build time. + + Shared between ``qd.checkpoint(yield_on=...)`` and ``qd.graph_do_while(...)`` to turn the control-flag + argument into the flat C++ arg-id the runtime matches against. ``node`` is an ``ast.Name`` (a bare kernel + parameter, e.g. ``flag``) or an ``ast.Attribute`` chain (e.g. ``self.flag`` for a ``@qd.data_oriented`` + owner, or ``params.flag`` where ``params`` is a ``@dataclasses.dataclass`` kernel parameter). We build the + expression through the normal AST machinery and read the arg-id off the resulting external-tensor + expression -- this unifies the bare-param and member-ndarray cases, since both flatten to a real ndarray + kernel argument carrying its arg-id on the ``ExternalTensorExpression``. + + ``usage`` is the call form (e.g. ``"qd.checkpoint(yield_on=...)"``) used in the error message. Raises + ``QuadrantsSyntaxError`` if the expression does not resolve to an ndarray kernel argument. + """ + from quadrants.lang.any_array import AnyArray # pylint: disable=C0415 + + label = ast.unparse(node) + bad = QuadrantsSyntaxError( + f"{usage} got {label!r} which does not resolve to an ndarray kernel parameter of " + f"{kernel.func.__name__!r}. The argument must reference an ndarray kernel parameter (e.g. " + f"`flag`) or a @qd.data_oriented member ndarray (e.g. `self.flag`); other expressions are not " + f"supported." + ) + try: + built = build_stmt(ctx, node) + except Exception as e: # noqa: BLE001 - any resolution failure is a user-facing misuse + raise bad from e + resolved_expr = built.ptr if isinstance(built, AnyArray) else built + if not (hasattr(resolved_expr, "is_external_tensor_expr") and resolved_expr.is_external_tensor_expr()): + raise bad + arg_id = _qd_core.get_external_tensor_arg_id(resolved_expr) + if not arg_id: + raise bad + return label, int(arg_id[0]) + @staticmethod def _is_checkpoint_call(node: ast.expr, global_vars: dict): """Thin forwarding wrapper around ``CheckpointTransformer.is_checkpoint_call``; the actual logic lives in module @@ -1377,18 +1425,11 @@ def build_While(ctx: ASTTransformerFuncContext, node: ast.While) -> None: if node.orelse: raise QuadrantsSyntaxError("'else' clause for 'while' not supported in Quadrants kernels") - graph_do_while_arg = ASTTransformer._is_graph_do_while_call(node.test) - if graph_do_while_arg is not None: + graph_do_while_node = ASTTransformer._is_graph_do_while_call(node.test) + if graph_do_while_node is not None: from quadrants.lang.kernel import GraphDoWhileLevel # pylint: disable=C0415 kernel = ctx.global_context.current_kernel - arg_names = [m.name for m in kernel.arg_metas] - if graph_do_while_arg not in arg_names: - raise QuadrantsSyntaxError( - f"qd.graph_do_while({graph_do_while_arg!r}) does not match any " - f"parameter of kernel {kernel.func.__name__!r}. " - f"Available parameters: {arg_names}" - ) if not kernel.use_graph: raise QuadrantsSyntaxError("qd.graph_do_while() requires @qd.kernel(graph=True)") # graph_do_while emits no loop IR; its body's for-loops must be top-level (offloaded) tasks. So it may only @@ -1399,15 +1440,22 @@ def build_While(ctx: ASTTransformerFuncContext, node: ast.While) -> None: "qd.graph_do_while() must be at the kernel top level or directly nested inside " "another qd.graph_do_while(); it cannot appear inside a for-loop." ) + # Resolve the condition ndarray (bare parameter or @qd.data_oriented member) to its flat C++ arg-id at + # AST-build time -- the same id the runtime needs -- so the launch path forwards it directly with no + # per-launch name matching. ``cond_arg_name`` keeps the readable label (e.g. "counter" or "self.counter") + # for introspection and for the legacy ``graph_do_while_arg`` alias surfaced on Kernel. + cond_label, cond_cpp_arg_id = ASTTransformer._resolve_ndarray_kernel_arg_id( + ctx, kernel, graph_do_while_node, "qd.graph_do_while(...)" + ) # Register this loop as a new nesting level (the body restriction is validated up-front in # FunctionDefTransformer). Outer loops get lower ids than the inner loops they contain. parent_id = kernel._graph_do_while_level_stack[-1] if kernel._graph_do_while_level_stack else -1 level_id = len(kernel.graph_do_while_levels) kernel.graph_do_while_levels.append( - GraphDoWhileLevel(cond_arg_name=graph_do_while_arg, parent_id=parent_id) + GraphDoWhileLevel(cond_arg_name=cond_label, parent_id=parent_id, cond_cpp_arg_id=cond_cpp_arg_id) ) if level_id == 0: - kernel.graph_do_while_arg = graph_do_while_arg + kernel.graph_do_while_arg = cond_label kernel._graph_do_while_level_stack.append(level_id) ctx.ast_builder.set_graph_do_while_level_id(level_id) try: diff --git a/python/quadrants/lang/ast/ast_transformers/checkpoint_transformer.py b/python/quadrants/lang/ast/ast_transformers/checkpoint_transformer.py index ab87271698..9ac922a1e7 100644 --- a/python/quadrants/lang/ast/ast_transformers/checkpoint_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/checkpoint_transformer.py @@ -38,13 +38,17 @@ class CheckpointCallInfo: - ``cp_id``: the user-supplied label (an ``int`` or ``IntEnum`` value), or ``None`` for an auto-wrap implicit checkpoint. - - ``yield_on``: name of the kernel parameter passed as ``yield_on=`` (an ``ast.Name`` is required), or ``None`` for - an implicit checkpoint. + - ``yield_on_node``: the ``ast.expr`` passed as ``yield_on=`` -- either an ``ast.Name`` (bare kernel parameter, + e.g. ``flag``) or an ``ast.Attribute`` chain (e.g. ``self.flag`` for a ``@qd.data_oriented`` owner, or + ``params.flag`` where ``params`` is a ``@dataclasses.dataclass`` kernel parameter). ``None`` for an implicit + checkpoint. ``build_checkpoint_with`` resolves the node to a flat C++ arg-id via + ``ASTTransformer._resolve_ndarray_kernel_arg_id`` so the runtime can forward it directly without per-launch + name matching. - ``is_implicit``: ``True`` iff this Call was synthesised by ``auto_wrap_for_loops``. """ cp_id: int | None - yield_on: str | None + yield_on_node: ast.expr | None is_implicit: bool @@ -133,7 +137,7 @@ def is_checkpoint_call(node: ast.expr, global_vars: dict) -> CheckpointCallInfo return None # Auto-wrap-synthesised implicit checkpoint: no user-facing args, no validation. if getattr(node, _IMPLICIT_MARKER_ATTR, False): - return CheckpointCallInfo(cp_id=None, yield_on=None, is_implicit=True) + return CheckpointCallInfo(cp_id=None, yield_on_node=None, is_implicit=True) # User-written `qd.checkpoint(cp_id, yield_on)` -- both args are required. if len(node.args) + len(node.keywords) == 0: raise QuadrantsSyntaxError("qd.checkpoint() takes two arguments: `qd.checkpoint(cp_id, yield_on=flag)`.") @@ -173,13 +177,18 @@ def is_checkpoint_call(node: ast.expr, global_vars: dict) -> CheckpointCallInfo "qd.checkpoint() is missing required argument `yield_on` (e.g. " "`qd.checkpoint(0, yield_on=overflow_flag)`)" ) - if not isinstance(yield_on_arg, ast.Name): + # `yield_on=` must point at an ndarray kernel argument -- a bare parameter (`yield_on=flag`), a + # `@qd.data_oriented` member (`yield_on=self.flag`), or a `@dataclasses.dataclass` parameter member + # (`yield_on=params.flag`). Other expressions can't be lowered to a flat arg-id and are rejected here so + # the user gets a clear compile-time error at the `with` site. + if not isinstance(yield_on_arg, (ast.Name, ast.Attribute)): raise QuadrantsSyntaxError( - "qd.checkpoint(yield_on=...) must be the bare name of a kernel parameter (e.g. " - "`yield_on=overflow_flag`); expressions are not supported" + "qd.checkpoint(yield_on=...) must reference a kernel ndarray argument -- e.g. `yield_on=flag` for " + "a bare parameter, `yield_on=self.flag` for a @qd.data_oriented member, or `yield_on=params.flag` " + "for a @dataclasses.dataclass parameter member; arbitrary expressions are not supported" ) cp_id_value = CheckpointTransformer._resolve_cp_id(cp_id_arg, global_vars) - return CheckpointCallInfo(cp_id=cp_id_value, yield_on=yield_on_arg.id, is_implicit=False) + return CheckpointCallInfo(cp_id=cp_id_value, yield_on_node=yield_on_arg, is_implicit=False) @staticmethod def build_checkpoint_with( @@ -216,14 +225,21 @@ def build_checkpoint_with( "same kernel must be flat siblings (a checkpoint inside qd.graph_do_while is fine)" ) + yield_on_label: str | None = None + yield_on_cpp_arg_id: int = -1 if not info.is_implicit: - # Validate `yield_on=` names a real kernel parameter. - arg_names = [m.name for m in kernel.arg_metas] - if info.yield_on not in arg_names: - raise QuadrantsSyntaxError( - f"qd.checkpoint(yield_on={info.yield_on!r}) does not match any parameter of kernel " - f"{kernel.func.__name__!r}. Available parameters: {arg_names}" - ) + # Resolve `yield_on=` (a bare parameter or `@qd.data_oriented` member ndarray) to its flat C++ arg-id at + # AST-build time. ``_resolve_ndarray_kernel_arg_id`` raises a user-facing ``QuadrantsSyntaxError`` if the + # expression does not name a real ndarray kernel argument, which keeps the diagnostic at the `with` site + # instead of leaking into the launcher. Both the label (for ``checkpoint_yield_on_args`` / introspection) + # and the resolved arg-id (for the runtime) are stashed and forwarded to the launch path below. + # Local import to avoid an ast_transformers -> ast_transformer cycle. + # pylint: disable-next=C0415,import-outside-toplevel + from quadrants.lang.ast.ast_transformer import ASTTransformer + + yield_on_label, yield_on_cpp_arg_id = ASTTransformer._resolve_ndarray_kernel_arg_id( + ctx, kernel, info.yield_on_node, "qd.checkpoint(yield_on=...)" + ) # Reject duplicate user-supplied cp_id labels. existing = [lbl for lbl in kernel.checkpoint_user_labels_by_cp_id if lbl is not None] if info.cp_id in existing: @@ -272,7 +288,8 @@ def build_checkpoint_with( f" ...\n" ) - kernel.checkpoint_yield_on_args.append(info.yield_on) + kernel.checkpoint_yield_on_args.append(yield_on_label) + kernel.checkpoint_yield_on_cpp_arg_ids.append(yield_on_cpp_arg_id) kernel.checkpoint_user_labels_by_cp_id.append(info.cp_id) # Hand control to the C++ ASTBuilder so that every for-loop emitted by `build_stmts` below is tagged with this # checkpoint's internal `cp_id` on its `ForLoopConfig.checkpoint_id`. The C++ counter is the source of truth for diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 4e7b8e5154..6a7497f1e3 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -514,6 +514,7 @@ def build_FunctionDef( # different argument shape) start from an empty list. Mirrors how `graph_do_while_arg` gets overwritten # unconditionally during AST traversal. kernel.checkpoint_yield_on_args = [] + kernel.checkpoint_yield_on_cpp_arg_ids = [] kernel.checkpoint_user_labels_by_cp_id = [] # Auto-wrap pass for `@qd.kernel(graph=True, checkpoints=True)` kernels. Mutates `node.body` in place so # every top-level for-loop (and every for-loop inside a `qd.graph_do_while` body) that the user did not diff --git a/python/quadrants/lang/kernel.py b/python/quadrants/lang/kernel.py index 8a0cf0867d..89dfa4519e 100644 --- a/python/quadrants/lang/kernel.py +++ b/python/quadrants/lang/kernel.py @@ -344,9 +344,20 @@ def __init__(self, _func: Callable, autodiff_mode: AutodiffMode, _is_classkernel self._graph_do_while_level_stack: list[int] = [] # Per-checkpoint metadata, one entry per `with qd.checkpoint(...)` block (explicit AND auto-injected implicit) # in declaration order. List index is the checkpoint's internal `cp_id` (0, 1, 2, ... dense, flat across the - # kernel). Each entry is the name of the `yield_on=` kernel parameter, or `None` for implicit checkpoints (which - # never yield). Populated by the AST transformer; empty means the kernel uses no checkpoints. + # kernel). Each entry is the readable label of the `yield_on=` argument (e.g. "flag" or "self.flag"), or + # `None` for implicit checkpoints (which never yield). Populated by the AST transformer; empty means the + # kernel uses no checkpoints. Used for error messages / introspection only -- the runtime forwards the flat + # C++ arg-id from `checkpoint_yield_on_cpp_arg_ids` below. self.checkpoint_yield_on_args: list[str | None] = [] + # Flat C++ arg-ids (post-template) of each explicit checkpoint's `yield_on=` ndarray, resolved at AST-build + # time by `CheckpointTransformer.build_checkpoint_with` via `ASTTransformer._resolve_ndarray_kernel_arg_id`. + # Same indexing as `checkpoint_yield_on_args`: entry `i` is the flat arg-id the runtime uses to look up the + # ndarray's device pointer for the checkpoint whose internal cp_id is `i`. `-1` for implicit checkpoints + # (which never yield). Resolving at AST-build time uniformly handles bare kernel parameters + # (`yield_on=flag`), `@qd.data_oriented` member ndarrays (`yield_on=self.flag`), and + # `@dataclasses.dataclass` parameter members (`yield_on=params.flag`); the attribute forms cannot be + # resolved by the per-launch name match because `arg_metas[i].name` only carries top-level parameter names. + self.checkpoint_yield_on_cpp_arg_ids: list[int] = [] # User-facing labels for explicit checkpoints. Same indexing as `checkpoint_yield_on_args`: entry `i` is the int # (or IntEnum value) the user passed as the first positional arg of `qd.checkpoint(cp_id, yield_on)` for the # checkpoint whose internal cp_id is `i`. Implicit checkpoints (auto-wrapped) get `None` (they have no @@ -396,14 +407,12 @@ def _try_load_fastcache(self, args: tuple[Any, ...], key: "CompiledKernelKeyType self.fast_checksum = src_hasher.create_cache_key( self.raise_on_templated_floats, kernel_source_info, args, self.arg_metas ) - used_py_dataclass_parameters = None - cached_graph_do_while_levels: list[tuple[str, int]] | None = None + cache_value = None if self.fast_checksum: self.src_ll_cache_observations.cache_key_generated = True - used_py_dataclass_parameters, frontend_cache_key, cached_graph_do_while_levels = src_hasher.load( # type: ignore[reportAssignmentType] - self.fast_checksum - ) - if used_py_dataclass_parameters is not None and frontend_cache_key is not None: + cache_value = src_hasher.load(self.fast_checksum) + if cache_value is not None: + frontend_cache_key = cache_value.frontend_cache_key self.src_ll_cache_observations.cache_validated = True prog = impl.get_runtime().prog assert self.fast_checksum is not None @@ -415,16 +424,22 @@ def _try_load_fastcache(self, args: tuple[Any, ...], key: "CompiledKernelKeyType ) if self.compiled_kernel_data_by_key[key]: self.src_ll_cache_observations.cache_loaded = True - self.used_py_dataclass_parameters_by_key_enforcing[key] = used_py_dataclass_parameters - # Fast-cache restore skips AST transformation, so rebuild the gdw level table (and the legacy - # outermost-arg alias) from the cached (cond_arg_name, parent_id) pairs. - if cached_graph_do_while_levels: + self.used_py_dataclass_parameters_by_key_enforcing[key] = cache_value.used_py_dataclass_parameters + # Fast-cache restore skips AST transformation, so rebuild the AST-transformer-produced metadata + # from the cache value: nested graph_do_while level table (with the AST-resolved flat C++ arg-id) + # plus the per-checkpoint yield_on / user-label tables. Mirrors what + # `function_def_transformer.py` + `checkpoint_transformer.py` + `build_While` would have written. + if cache_value.graph_do_while_levels: self.graph_do_while_levels = [ - GraphDoWhileLevel(cond_arg_name=name, parent_id=parent) - for name, parent in cached_graph_do_while_levels + GraphDoWhileLevel(cond_arg_name=name, parent_id=parent, cond_cpp_arg_id=cpp_arg_id) + for name, parent, cpp_arg_id in cache_value.graph_do_while_levels ] self.graph_do_while_arg = self.graph_do_while_levels[0].cond_arg_name - return used_py_dataclass_parameters + if cache_value.checkpoint_yield_on_args: + self.checkpoint_yield_on_args = list(cache_value.checkpoint_yield_on_args) + self.checkpoint_yield_on_cpp_arg_ids = list(cache_value.checkpoint_yield_on_cpp_arg_ids) + self.checkpoint_user_labels_by_cp_id = list(cache_value.checkpoint_user_labels_by_cp_id) + return cache_value.used_py_dataclass_parameters elif self.quadrants_callable and not self.quadrants_callable.is_pure and self.runtime.print_non_pure: # The bit in caps should not be modified without updating corresponding test @@ -587,13 +602,9 @@ def launch_kernel( is_launch_ctx_cacheable = True template_num = 0 i_out = 0 - # Hoist the `kernel has any yield_on= checkpoint` predicate out of the per-arg loop so non-checkpoint - # kernels (the overwhelming majority) skip the helper function call entirely on every arg. Gated on - # `use_checkpoints` first so non-checkpoint kernels pay only one attribute lookup, not the list-truthy - # check on every cache-miss build. - _kernel_has_yield_on_checkpoint = self.use_checkpoints and bool(self.checkpoint_yield_on_args) - if _kernel_has_yield_on_checkpoint: - _checkpoint_helpers.init_yield_on_arg_id_table(self) + # `checkpoint_yield_on_cpp_arg_ids` is populated at AST-build time (see + # `CheckpointTransformer.build_checkpoint_with`); no per-arg name match is needed here. The launch path + # below forwards the table to the launch context with a single `forward_yield_on_table_to_ctx` call. for i_in, val in enumerate(args): needed_ = self.arg_metas[i_in].annotation if needed_ is template or type(needed_) is template: @@ -607,12 +618,11 @@ def launch_kernel( # which weakens API/type safety and can route the wrong struct type through launch. if getattr(val, "_qd_all_field", False) and getattr(needed_, _FIELDS, None) is not None: continue - if self.graph_do_while_levels: - for _gdw_level in self.graph_do_while_levels: - if self.arg_metas[i_in].name == _gdw_level.cond_arg_name: - _gdw_level.cond_cpp_arg_id = i_out - template_num - if _kernel_has_yield_on_checkpoint: - _checkpoint_helpers.maybe_record_yield_on_arg(self, self.arg_metas[i_in].name, i_out - template_num) + # `graph_do_while_levels[*].cond_cpp_arg_id` is also populated at AST-build time (see + # `ASTTransformer.build_While` -> `_resolve_ndarray_kernel_arg_id`), so the launch path forwards it + # directly below without per-arg name matching here. This uniformly handles bare parameter conditions + # (`qd.graph_do_while(counter)`) and `@qd.data_oriented` member conditions + # (`qd.graph_do_while(self.counter)`). num_args_, is_launch_ctx_cacheable_ = self._recursive_set_args( self.used_py_dataclass_parameters_by_key_enforcing[key], self.arg_metas[i_in].name, @@ -675,8 +685,12 @@ def launch_kernel( self.visited_functions, self.used_py_dataclass_parameters_by_key_enforcing[key], graph_do_while_levels=[ # type: ignore[reportCallIssue] - (level.cond_arg_name, level.parent_id) for level in self.graph_do_while_levels + (level.cond_arg_name, level.parent_id, level.cond_cpp_arg_id) + for level in self.graph_do_while_levels ], + checkpoint_yield_on_args=list(self.checkpoint_yield_on_args), + checkpoint_yield_on_cpp_arg_ids=list(self.checkpoint_yield_on_cpp_arg_ids), + checkpoint_user_labels_by_cp_id=list(self.checkpoint_user_labels_by_cp_id), ) self.src_ll_cache_observations.cache_stored = True self._last_compiled_kernel_data = compiled_kernel_data diff --git a/python/quadrants/lang/kernel_checkpoint.py b/python/quadrants/lang/kernel_checkpoint.py index a3d44eac5f..6b80491f67 100644 --- a/python/quadrants/lang/kernel_checkpoint.py +++ b/python/quadrants/lang/kernel_checkpoint.py @@ -46,36 +46,17 @@ def translate_user_label_to_internal_cp_id(kernel: Any, user_label: int) -> int: ) -def init_yield_on_arg_id_table(kernel: Any) -> None: - """Allocate / reset the per-launch ``cp_id -> C++ arg-id`` table at the top of ``launch_kernel``'s arg iteration. - - Each entry defaults to ``-1`` ("no yield_on"); the per-arg loop below fills in the C++ arg id when it visits the - named parameter. Sized to the kernel's checkpoint count once per launch so any changes to the checkpoint set (only - possible via re-AST-walk) reset the table cleanly. No-op for kernels with no ``yield_on=`` checkpoints. - """ - if kernel.checkpoint_yield_on_args: - kernel._checkpoint_yield_on_cpp_arg_ids = [-1] * len(kernel.checkpoint_yield_on_args) - - -def maybe_record_yield_on_arg(kernel: Any, arg_name: str, cpp_arg_id: int) -> None: - """Fill the ``cp_id -> C++ arg-id`` slot when the arg iterator visits a named ``yield_on=`` kernel parameter. - - Walked once per kernel arg in ``launch_kernel``; cheap O(checkpoints) match. A single parameter can be the - ``yield_on=`` for multiple checkpoints (the inner loop fills every matching slot). - """ - if not kernel.checkpoint_yield_on_args: - return - for cp_idx, yield_name in enumerate(kernel.checkpoint_yield_on_args): - if yield_name is not None and arg_name == yield_name: - kernel._checkpoint_yield_on_cpp_arg_ids[cp_idx] = cpp_arg_id - - def forward_yield_on_table_to_ctx(kernel: Any, launch_ctx: Any) -> None: - """Copy the resolved ``cp_id -> C++ arg-id`` table onto the launch context so the runtime can find each - ``yield_on=`` ndarray's device address at launch. + """Copy the ``cp_id -> C++ arg-id`` table onto the launch context so the runtime can find each ``yield_on=`` + ndarray's device address at launch. + + The table (``kernel.checkpoint_yield_on_cpp_arg_ids``) is populated at AST-build time by + ``CheckpointTransformer.build_checkpoint_with`` via ``ASTTransformer._resolve_ndarray_kernel_arg_id``, which + uniformly handles bare kernel parameters (``yield_on=flag``) and ``@qd.data_oriented`` member ndarrays + (``yield_on=self.flag``). No per-launch arg iteration / name match is required. """ - if kernel.checkpoint_yield_on_args and hasattr(kernel, "_checkpoint_yield_on_cpp_arg_ids"): - launch_ctx.checkpoint_yield_on_arg_ids = tuple(kernel._checkpoint_yield_on_cpp_arg_ids) + if kernel.checkpoint_yield_on_cpp_arg_ids: + launch_ctx.checkpoint_yield_on_arg_ids = tuple(kernel.checkpoint_yield_on_cpp_arg_ids) def maybe_build_graph_status(kernel: Any, default_ret: Any) -> Any: diff --git a/quadrants/python/export_lang.cpp b/quadrants/python/export_lang.cpp index 573eb95b49..037cbec097 100644 --- a/quadrants/python/export_lang.cpp +++ b/quadrants/python/export_lang.cpp @@ -844,6 +844,15 @@ void export_lang(py::module &m) { } }); + // The (post-template) C++ arg-id vector of an external-tensor (ndarray) expression. For a top-level ndarray parameter + // or a flattened `@qd.data_oriented` member ndarray this is a single-element vector whose `[0]` entry is the flat + // arg-id the runtime matches against (e.g. for `qd.checkpoint(yield_on=...)` and `qd.graph_do_while(...)` + // AST-build-time resolution of bare-parameter vs `self.member` ndarray arguments). + m.def("get_external_tensor_arg_id", [](const Expr &expr) { + QD_ASSERT(expr.is()); + return expr.cast()->arg_id; + }); + m.def("get_external_tensor_shape_along_axis", Expr::make); diff --git a/tests/python/test_checkpoint.py b/tests/python/test_checkpoint.py index cf03d6c095..26c1c0e6ff 100644 --- a/tests/python/test_checkpoint.py +++ b/tests/python/test_checkpoint.py @@ -208,7 +208,7 @@ def k(x: qd.types.ndarray(qd.i32, ndim=1), flag: qd.types.ndarray(qd.i32, ndim=0 @test_utils.test() def test_checkpoint_yield_on_nonexistent_arg_raises(): - """``yield_on`` must name a kernel parameter; typos / scope mismatches must error early.""" + """``yield_on`` must reference an ndarray kernel argument; typos / scope mismatches must error early.""" @qd.kernel(graph=True, checkpoints=True) def k(x: qd.types.ndarray(qd.i32, ndim=1), flag: qd.types.ndarray(qd.i32, ndim=0)): @@ -218,13 +218,14 @@ def k(x: qd.types.ndarray(qd.i32, ndim=1), flag: qd.types.ndarray(qd.i32, ndim=0 x = qd.ndarray(qd.i32, shape=(4,)) flag = qd.ndarray(qd.i32, shape=()) - with pytest.raises(qd.QuadrantsSyntaxError, match="does not match any parameter"): + with pytest.raises(qd.QuadrantsSyntaxError, match="does not resolve to an ndarray kernel parameter"): k(x, flag) @test_utils.test() -def test_checkpoint_yield_on_must_be_bare_name(): - """``yield_on=`` must be a bare ``ast.Name`` (a kernel parameter); expressions are not supported. Pinning the +def test_checkpoint_yield_on_must_be_name_or_attribute(): + """``yield_on=`` must reference an ndarray kernel argument -- either a bare ``ast.Name`` or an ``ast.Attribute`` + chain (for ``@qd.data_oriented`` member ndarrays). Arbitrary expressions are not supported; pinning the diagnostic so the user knows to refactor.""" @qd.kernel(graph=True, checkpoints=True) @@ -235,7 +236,7 @@ def k(x: qd.types.ndarray(qd.i32, ndim=1), flag: qd.types.ndarray(qd.i32, ndim=0 x = qd.ndarray(qd.i32, shape=(4,)) flag = qd.ndarray(qd.i32, shape=()) - with pytest.raises(qd.QuadrantsSyntaxError, match=r"must be the bare name of a kernel parameter"): + with pytest.raises(qd.QuadrantsSyntaxError, match=r"must reference a kernel ndarray argument"): k(x, flag) @@ -843,6 +844,129 @@ def k( assert status.checkpoint == 0 +# ---------------------------------------------------------------------------------------------------------------------- +# Member-ndarray support for `yield_on=` (both `@qd.data_oriented` self-members and `@dataclasses.dataclass` parameter +# members). +# +# ``qd.checkpoint(yield_on=self.flag)`` and ``qd.checkpoint(yield_on=params.flag)`` (where ``params`` is a kernel +# parameter typed as a ``@dataclasses.dataclass``) both resolve the member ndarray to a flat C++ arg-id at AST-build +# time via ``ASTTransformer._resolve_ndarray_kernel_arg_id``: it builds the expression and reads the resolved +# ``ExternalTensorExpression.arg_id``, so any attribute chain that ends up as a kernel ndarray arg works the same +# way as a bare parameter name. This frees users from having to forward flag members as bare kernel parameters when +# the rest of the kernel already operates on the dataclass / data-oriented owner. +# ---------------------------------------------------------------------------------------------------------------------- + + +@test_utils.test() +def test_checkpoint_yield_on_data_oriented_member_metadata(): + """`yield_on=self.flag` is accepted and the resolved label is stored verbatim (``"self.flag"``) in + ``checkpoint_yield_on_args``, while ``checkpoint_yield_on_cpp_arg_ids`` carries the flat C++ arg-id the + runtime forwards to the launch context. Verifies the AST-build-time resolution path without booting the + backend.""" + N = 4 + + @qd.data_oriented + class Sim: + def __init__(self): + self.x = qd.ndarray(qd.i32, shape=(N,)) + self.flag = qd.ndarray(qd.i32, shape=()) + + @qd.kernel(graph=True, checkpoints=True) + def step(self): + with qd.checkpoint(0, yield_on=self.flag): + for i in range(self.x.shape[0]): + self.x[i] = self.x[i] + 1 + + sim = Sim() + sim.x.from_numpy(np.zeros(N, dtype=np.int32)) + sim.flag.from_numpy(np.array(0, dtype=np.int32)) + sim.step() + np.testing.assert_array_equal(sim.x.to_numpy(), np.ones(N, dtype=np.int32)) + assert sim.step._primal.checkpoint_user_labels_by_cp_id == [0] + assert sim.step._primal.checkpoint_yield_on_args == ["self.flag"] + cpp_ids = sim.step._primal.checkpoint_yield_on_cpp_arg_ids + assert len(cpp_ids) == 1 and cpp_ids[0] >= 0 + + +@test_utils.test() +def test_checkpoint_yield_on_dataclass_member_metadata(): + """`yield_on=params.flag` for a ``@dataclasses.dataclass`` kernel parameter takes the same AST-build-time + resolution path as ``self.flag`` for a ``@qd.data_oriented`` owner -- the resolved label round-trips into + ``checkpoint_yield_on_args`` and the flat arg-id lands in ``checkpoint_yield_on_cpp_arg_ids``.""" + import dataclasses # pylint: disable=import-outside-toplevel + + N = 4 + + @dataclasses.dataclass + class Params: + x: qd.types.NDArray[qd.i32, 1] + flag: qd.types.NDArray[qd.i32, 0] + + @qd.kernel(graph=True, checkpoints=True) + def step(params: Params): + with qd.checkpoint(0, yield_on=params.flag): + for i in range(params.x.shape[0]): + params.x[i] = params.x[i] + 1 + + params = Params( + x=qd.ndarray(qd.i32, shape=(N,)), + flag=qd.ndarray(qd.i32, shape=()), + ) + params.x.from_numpy(np.zeros(N, dtype=np.int32)) + params.flag.from_numpy(np.array(0, dtype=np.int32)) + step(params) + np.testing.assert_array_equal(params.x.to_numpy(), np.ones(N, dtype=np.int32)) + assert step._primal.checkpoint_user_labels_by_cp_id == [0] + # Dataclass-parameter member access gets pre-rewritten by the AST pipeline to a flattened parameter name + # (`__qd_params__qd_flag`) before the checkpoint transformer sees it, so the label round-trips in the + # flattened form. The functional contract -- a valid flat C++ arg-id is resolved and the kernel mutates the + # right ndarray -- is the same as for the bare-param / `self.flag` forms. + labels = step._primal.checkpoint_yield_on_args + assert len(labels) == 1 and labels[0] is not None and "flag" in labels[0] + cpp_ids = step._primal.checkpoint_yield_on_cpp_arg_ids + assert len(cpp_ids) == 1 and cpp_ids[0] >= 0 + + +@test_utils.test() +def test_checkpoint_yield_on_data_oriented_member_yields_and_resumes(): + """Behavioural round-trip for `yield_on=self.flag`: setting the member flag from inside the kernel yields, and + ``kernel.resume(from_checkpoint=...)`` skips ahead to the named checkpoint. Same surface contract as the bare- + parameter form (`test_checkpoint_yield_on_yields_and_resumes`); the only difference is where the flag lives.""" + if not _supports_checkpoint_yield_resume(): + pytest.skip("backend does not implement checkpoint yield/resume") + N = 4 + + @qd.data_oriented + class Sim: + def __init__(self): + self.x = qd.ndarray(qd.i32, shape=(N,)) + self.flag = qd.ndarray(qd.i32, shape=()) + + @qd.kernel(graph=True, checkpoints=True) + def step(self): + with qd.checkpoint(7, yield_on=self.flag): + for i in range(self.x.shape[0]): + self.x[i] = self.x[i] + 1 + self.flag[()] = 1 + with qd.checkpoint(8, yield_on=self.flag): + for i in range(self.x.shape[0]): + self.x[i] = self.x[i] + 10 + + sim = Sim() + sim.x.from_numpy(np.zeros(N, dtype=np.int32)) + sim.flag.from_numpy(np.array(0, dtype=np.int32)) + status = sim.step() + # Checkpoint 7 set the flag in the first iter so the kernel yields before running checkpoint 8. + assert status.yielded + assert status.checkpoint == 7 + np.testing.assert_array_equal(sim.x.to_numpy(), np.ones(N, dtype=np.int32)) + # User clears the flag and resumes from the post-yield checkpoint (skipping the +1 loop entirely). + sim.flag.from_numpy(np.array(0, dtype=np.int32)) + status = sim.step.resume(from_checkpoint=8) + assert not status.yielded + np.testing.assert_array_equal(sim.x.to_numpy(), np.full(N, 11, dtype=np.int32)) + + # ---------------------------------------------------------------------------------------------------------------------- # CUDA-native introspection (slice 1c). # ---------------------------------------------------------------------------------------------------------------------- diff --git a/tests/python/test_graph_do_while.py b/tests/python/test_graph_do_while.py index 2a7ba506b7..6e807100f9 100644 --- a/tests/python/test_graph_do_while.py +++ b/tests/python/test_graph_do_while.py @@ -320,10 +320,85 @@ def k(x: qd.types.ndarray(qd.i32, ndim=1), c: qd.types.ndarray(qd.i32, ndim=0)): x = qd.ndarray(qd.i32, shape=(4,)) c = qd.ndarray(qd.i32, shape=()) c.from_numpy(np.array(1, dtype=np.int32)) - with pytest.raises(qd.QuadrantsSyntaxError, match="does not match any parameter"): + with pytest.raises(qd.QuadrantsSyntaxError, match="does not resolve to an ndarray kernel parameter"): k(x, c) +@test_utils.test() +def test_graph_do_while_with_dataclass_member_counter(): + """`qd.graph_do_while(params.counter)` for a ``@dataclasses.dataclass`` kernel parameter takes the same + AST-build-time resolution path as the ``self.counter`` form -- the loop drives entirely from the device-side + counter just like the bare-parameter case.""" + import dataclasses # pylint: disable=import-outside-toplevel + + N = 4 + + @dataclasses.dataclass + class Params: + x: qd.types.NDArray[qd.i32, 1] + counter: qd.types.NDArray[qd.i32, 0] + + @qd.kernel(graph=True) + def step(params: Params): + while qd.graph_do_while(params.counter): + for i in range(params.x.shape[0]): + params.x[i] = params.x[i] + 1 + for _ in range(1): + params.counter[()] = params.counter[()] - 1 + + params = Params( + x=qd.ndarray(qd.i32, shape=(N,)), + counter=qd.ndarray(qd.i32, shape=()), + ) + params.x.from_numpy(np.zeros(N, dtype=np.int32)) + params.counter.from_numpy(np.array(3, dtype=np.int32)) + step(params) + np.testing.assert_array_equal(params.x.to_numpy(), np.full(N, 3, dtype=np.int32)) + assert params.counter.to_numpy() == 0 + levels = step._primal.graph_do_while_levels + assert len(levels) == 1 + # Dataclass-parameter member access gets pre-rewritten to a flattened parameter name + # (`__qd_params__qd_counter`) before the graph_do_while transformer sees it, so the readable label round-trips + # in the flattened form. The functional contract -- a valid flat C++ arg-id resolves and the loop drives off + # the device-side counter -- is the same as for the bare-param / `self.counter` forms. + assert "counter" in levels[0].cond_arg_name + assert levels[0].cond_cpp_arg_id >= 0 + + +@test_utils.test() +def test_graph_do_while_with_data_oriented_member_counter(): + """`qd.graph_do_while(self.counter)` resolves the member ndarray to the loop condition's flat C++ arg-id at + AST-build time via ``ASTTransformer._resolve_ndarray_kernel_arg_id``, lifting the previous bare-parameter + restriction. The metadata exposed on the kernel records the readable label (``"self.counter"``) plus the + resolved arg-id; the loop behaviour matches the bare-parameter form below.""" + N = 4 + + @qd.data_oriented + class Sim: + def __init__(self): + self.x = qd.ndarray(qd.i32, shape=(N,)) + self.counter = qd.ndarray(qd.i32, shape=()) + + @qd.kernel(graph=True) + def step(self): + while qd.graph_do_while(self.counter): + for i in range(self.x.shape[0]): + self.x[i] = self.x[i] + 1 + for _ in range(1): + self.counter[()] = self.counter[()] - 1 + + sim = Sim() + sim.x.from_numpy(np.zeros(N, dtype=np.int32)) + sim.counter.from_numpy(np.array(3, dtype=np.int32)) + sim.step() + np.testing.assert_array_equal(sim.x.to_numpy(), np.full(N, 3, dtype=np.int32)) + assert sim.counter.to_numpy() == 0 + levels = sim.step._primal.graph_do_while_levels + assert len(levels) == 1 + assert levels[0].cond_arg_name == "self.counter" + assert levels[0].cond_cpp_arg_id >= 0 + + @qd.kernel(graph=True, fastcache=True) def _fastcache_do_while_kernel(x: qd.types.ndarray(qd.i32, ndim=1), counter: qd.types.ndarray(qd.i32, ndim=0)): while qd.graph_do_while(counter): From d3de1296a735daab72ee38ba230713254b7000e5 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 24 Jun 2026 06:22:33 -0700 Subject: [PATCH 2/5] [Graph] Add tests for member-ndarray yield_on= / graph_do_while + schema-v3 round-trip Follows up on the member-ndarray support commit with a wider test surface: - Behavioural yield/resume for `yield_on=params.flag` (dataclass mirror of the existing data_oriented test). - Error paths for `yield_on=self.nonexistent_attr`, `qd.graph_do_while(self.nonexistent_attr)`, and `yield_on=self.scalar_attr` (non-ndarray) -- pins the user-facing diagnostic for the attribute forms. - Nested `qd.graph_do_while(self.outer)` containing `qd.graph_do_while(self.inner)` -- exercises the level-table machinery with `@qd.data_oriented` member ndarrays end-to-end. - Direct ``CacheValue`` round-trip unit test for schema v3 (`cachevalue-v3-ast-resolved-ids`): covers the new 3-tuple `graph_do_while_levels` + `checkpoint_yield_on_args` / `checkpoint_yield_on_cpp_arg_ids` / `checkpoint_user_labels_by_cp_id` fields the loader/storer now plumb through. - Cross-process fast-cache restore test for a `@qd.kernel(graph=True, checkpoints=True, fastcache=True)` kernel with `yield_on=self.flag` -- without the schema-v3 restore the launch path's `forward_yield_on_table_to_ctx` would be a no-op and yield/resume would silently break on fast-cached checkpoint kernels. Also fixes an existing `test_src_hasher_store_validate` assertion that indexed `src_hasher.load(...)` by position; the loader now returns a `CacheValue` (or `None`) so the test is updated to use attribute access plus asserts the new default-empty fields are present on the round-tripped object. --- .../lang/fast_caching/test_src_hasher.py | 69 +++++- tests/python/test_checkpoint.py | 208 ++++++++++++++++++ tests/python/test_graph_do_while.py | 77 +++++++ 3 files changed, 348 insertions(+), 6 deletions(-) diff --git a/tests/python/quadrants/lang/fast_caching/test_src_hasher.py b/tests/python/quadrants/lang/fast_caching/test_src_hasher.py index e7a2d9952b..53b03d8753 100644 --- a/tests/python/quadrants/lang/fast_caching/test_src_hasher.py +++ b/tests/python/quadrants/lang/fast_caching/test_src_hasher.py @@ -109,19 +109,17 @@ def get_fileinfos(functions: list[Callable]) -> list[_wrap_inspect.FunctionSourc kernel_cache_key = "I'm a kernel cache key" - load_res = src_hasher.load(fast_cache_key) - assert load_res[0] is None and load_res[1] is None + assert src_hasher.load(fast_cache_key) is None some_used_vars = {"fee", "fi", "fo"} src_hasher.store(kernel_cache_key, fast_cache_key, fileinfos, some_used_vars) def assert_loaded(cache_key: str) -> None: res = src_hasher.load(cache_key) - assert res[0] is not None and res[1] is not None + assert res is not None and res.frontend_cache_key == kernel_cache_key def assert_not_loaded(cache_key: str) -> None: - res = src_hasher.load(cache_key) - assert res[0] is None and res[1] is None + assert src_hasher.load(cache_key) is None assert_loaded(fast_cache_key) @@ -136,7 +134,66 @@ def assert_not_loaded(cache_key: str) -> None: assert_not_loaded("abcdefg") - assert src_hasher.load(fast_cache_key)[0] == some_used_vars + loaded = src_hasher.load(fast_cache_key) + assert loaded is not None + assert loaded.used_py_dataclass_parameters == some_used_vars + # The new schema-v3 AST-resolved fields default to empty for kernels with no graph_do_while / checkpoint + # metadata, exercising the BaseModel default path on round-trip. + assert loaded.graph_do_while_levels is None + assert loaded.checkpoint_yield_on_args == [] + assert loaded.checkpoint_yield_on_cpp_arg_ids == [] + assert loaded.checkpoint_user_labels_by_cp_id == [] + + +@test_utils.test() +def test_src_hasher_store_validate_round_trips_schema_v3_metadata( + monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path, temporary_module +) -> None: + """Schema v3 (`cachevalue-v3-ast-resolved-ids`) added AST-resolved arg-id fields to the persisted ``CacheValue`` + so the launch path can forward them after a fast-cache restore (which skips AST transformation). This test + pins the round-trip for the new fields -- ``graph_do_while_levels`` as 3-tuples carrying ``cond_cpp_arg_id``, + plus ``checkpoint_yield_on_args`` / ``checkpoint_yield_on_cpp_arg_ids`` / ``checkpoint_user_labels_by_cp_id``. + Without this, a schema bug (wrong tuple arity, dropped field, mis-typed BaseModel default) would only surface + via a hard-to-debug functional regression in a fast-cached checkpoint / graph_do_while kernel.""" + test_files_path = pathlib.Path("tests/python/quadrants/lang/fast_caching/test_files") + + offline_cache_path = tmp_path / "cache" + temp_import_path = tmp_path / "temp_import" + temp_import_path.mkdir(exist_ok=True) + + qd_init_same_arch(offline_cache_file_path=str(offline_cache_path)) + + monkeypatch.syspath_prepend(temp_import_path) + shutil.copy2(test_files_path / "child_diff_base.py", temp_import_path / "child_diff_schema_v3.py") + mod = temporary_module("child_diff_schema_v3") + info, _src = _wrap_inspect.get_source_info_and_src(mod.f1.fn) + fileinfos = [info] + fast_cache_key = src_hasher.create_cache_key(False, info, [], []) + assert fast_cache_key is not None + + gdw_levels = [("self.outer", -1, 4), ("self.inner", 0, 5)] + cp_yield_args = ["self.flag", None, "params.flag"] + cp_yield_cpp_ids = [3, -1, 7] + cp_user_labels = [10, None, 20] + + src_hasher.store( + "kernel_cache_key_v3", + fast_cache_key, + fileinfos, + {"used_var"}, + graph_do_while_levels=gdw_levels, + checkpoint_yield_on_args=cp_yield_args, + checkpoint_yield_on_cpp_arg_ids=cp_yield_cpp_ids, + checkpoint_user_labels_by_cp_id=cp_user_labels, + ) + + loaded = src_hasher.load(fast_cache_key) + assert loaded is not None + assert loaded.frontend_cache_key == "kernel_cache_key_v3" + assert loaded.graph_do_while_levels == [("self.outer", -1, 4), ("self.inner", 0, 5)] + assert loaded.checkpoint_yield_on_args == cp_yield_args + assert loaded.checkpoint_yield_on_cpp_arg_ids == cp_yield_cpp_ids + assert loaded.checkpoint_user_labels_by_cp_id == cp_user_labels # Should be enough to run these on cpu I think, and anything involving diff --git a/tests/python/test_checkpoint.py b/tests/python/test_checkpoint.py index 26c1c0e6ff..5711311764 100644 --- a/tests/python/test_checkpoint.py +++ b/tests/python/test_checkpoint.py @@ -19,9 +19,14 @@ (IF conditional node count) are guarded behind ``_is_checkpoint_if_path_native``. """ +import os +import pathlib +import subprocess +import sys from enum import IntEnum import numpy as np +import pydantic import pytest import quadrants as qd @@ -29,6 +34,9 @@ from tests import test_utils +TEST_RAN = "test ran" +RET_SUCCESS = 42 + def _on_cuda(): return impl.current_cfg().arch == qd.cuda @@ -927,6 +935,102 @@ def step(params: Params): assert len(cpp_ids) == 1 and cpp_ids[0] >= 0 +@test_utils.test() +def test_checkpoint_yield_on_dataclass_member_yields_and_resumes(): + """Behavioural round-trip for `yield_on=params.flag` -- mirror of the `self.flag` test below, using a + `@dataclasses.dataclass` kernel parameter instead of a `@qd.data_oriented` owner. The dataclass-member access + is pre-rewritten to a flattened parameter, so verifying the full yield/resume contract end-to-end is the only + way to confirm the right ndarray is wired up at launch.""" + import dataclasses # pylint: disable=import-outside-toplevel + + if not _supports_checkpoint_yield_resume(): + pytest.skip("backend does not implement checkpoint yield/resume") + N = 4 + + @dataclasses.dataclass + class Params: + x: qd.types.NDArray[qd.i32, 1] + flag: qd.types.NDArray[qd.i32, 0] + + @qd.kernel(graph=True, checkpoints=True) + def step(params: Params): + with qd.checkpoint(7, yield_on=params.flag): + for i in range(params.x.shape[0]): + params.x[i] = params.x[i] + 1 + params.flag[()] = 1 + with qd.checkpoint(8, yield_on=params.flag): + for i in range(params.x.shape[0]): + params.x[i] = params.x[i] + 10 + + params = Params( + x=qd.ndarray(qd.i32, shape=(N,)), + flag=qd.ndarray(qd.i32, shape=()), + ) + params.x.from_numpy(np.zeros(N, dtype=np.int32)) + params.flag.from_numpy(np.array(0, dtype=np.int32)) + status = step(params) + assert status.yielded + assert status.checkpoint == 7 + np.testing.assert_array_equal(params.x.to_numpy(), np.ones(N, dtype=np.int32)) + params.flag.from_numpy(np.array(0, dtype=np.int32)) + # `step` is a free-function kernel (not a bound class kernel), so `params` must be passed positionally to + # `resume` -- the data_oriented sibling test above can omit it because the dataclass member access is + # implicit through `sim.step`'s bound `self`. + status = step.resume(params, from_checkpoint=8) + assert not status.yielded + np.testing.assert_array_equal(params.x.to_numpy(), np.full(N, 11, dtype=np.int32)) + + +@test_utils.test() +def test_checkpoint_yield_on_member_nonexistent_attribute_raises(): + """`yield_on=self.nonexistent_attr` (attribute does not exist on the `@qd.data_oriented` owner) must raise a + user-facing `QuadrantsSyntaxError` at the `with` site -- the AST-time resolver wraps the underlying attribute + lookup failure in the same `does not resolve to an ndarray kernel parameter` diagnostic as the bare-name + nonexistent case, so users see one consistent error pattern.""" + N = 4 + + @qd.data_oriented + class Sim: + def __init__(self): + self.x = qd.ndarray(qd.i32, shape=(N,)) + self.flag = qd.ndarray(qd.i32, shape=()) + + @qd.kernel(graph=True, checkpoints=True) + def step(self): + with qd.checkpoint(0, yield_on=self.nonexistent_flag): # type: ignore[attr-defined] + for i in range(self.x.shape[0]): + self.x[i] = self.x[i] + 1 + + sim = Sim() + with pytest.raises(qd.QuadrantsSyntaxError, match="does not resolve to an ndarray kernel parameter"): + sim.step() + + +@test_utils.test() +def test_checkpoint_yield_on_member_non_ndarray_attribute_raises(): + """`yield_on=self.scalar` where `self.scalar` is a Python int (not an ndarray) must raise the same + `does not resolve to an ndarray kernel parameter` diagnostic -- the AST-time resolver builds the expression + but rejects it because the resulting Expr is not an `ExternalTensorExpression`. Pinning this so future + refactors of the resolver can't silently accept non-ndarray attributes and crash later in the launcher.""" + N = 4 + + @qd.data_oriented + class Sim: + def __init__(self): + self.x = qd.ndarray(qd.i32, shape=(N,)) + self.scalar = 7 + + @qd.kernel(graph=True, checkpoints=True) + def step(self): + with qd.checkpoint(0, yield_on=self.scalar): + for i in range(self.x.shape[0]): + self.x[i] = self.x[i] + 1 + + sim = Sim() + with pytest.raises(qd.QuadrantsSyntaxError, match="does not resolve to an ndarray kernel parameter"): + sim.step() + + @test_utils.test() def test_checkpoint_yield_on_data_oriented_member_yields_and_resumes(): """Behavioural round-trip for `yield_on=self.flag`: setting the member flag from inside the kernel yields, and @@ -967,6 +1071,102 @@ def step(self): np.testing.assert_array_equal(sim.x.to_numpy(), np.full(N, 11, dtype=np.int32)) +# Module-level kernel for the fastcache-restoration test below. Lives outside any test so the child subprocess can +# import the test module and reach it without re-creating the (closure-captured) outer scope. The kernel has to be +# annotated with `fastcache=True` (=> implies `pure`) and lifted out of any decorator-bound owner so it qualifies +# for the src_ll_cache path. We model the data_oriented owner as the `_FastcacheYieldOnSelfCheckpoint` class below. + + +@qd.data_oriented +class _FastcacheYieldOnSelfCheckpoint: + def __init__(self, n: int): + self.x = qd.ndarray(qd.i32, shape=(n,)) + self.flag = qd.ndarray(qd.i32, shape=()) + + @qd.kernel(graph=True, checkpoints=True, fastcache=True) + def step(self): + with qd.checkpoint(0, yield_on=self.flag): + for i in range(self.x.shape[0]): + self.x[i] = self.x[i] + 1 + + +class _FastcacheCheckpointArgs(pydantic.BaseModel): + arch: str + offline_cache_file_path: str + expect_loaded_from_fastcache: bool + + +def _fastcache_checkpoint_child(args: list[str]) -> None: + args_obj = _FastcacheCheckpointArgs.model_validate_json(args[0]) + qd.init( + arch=getattr(qd, args_obj.arch), + offline_cache=True, + offline_cache_file_path=args_obj.offline_cache_file_path, + src_ll_cache=True, + ) + + N = 8 + sim = _FastcacheYieldOnSelfCheckpoint(N) + sim.x.from_numpy(np.zeros(N, dtype=np.int32)) + sim.flag.from_numpy(np.array(0, dtype=np.int32)) + sim.step() + np.testing.assert_array_equal(sim.x.to_numpy(), np.ones(N, dtype=np.int32)) + + primal = type(sim).step._primal + # The schema-v3 fast-cache restore path must repopulate `checkpoint_yield_on_args` and + # `checkpoint_yield_on_cpp_arg_ids` from the cached `CacheValue` (since AST transformation is skipped on a + # cache hit). A regression here would surface as an empty `_forward_yield_on_table_to_ctx` call, silently + # breaking yield/resume on fast-cached checkpoint kernels. + labels = primal.checkpoint_yield_on_args + cpp_ids = primal.checkpoint_yield_on_cpp_arg_ids + assert ( + labels and len(labels) == 1 and labels[0] is not None and "flag" in labels[0] + ), f"checkpoint_yield_on_args should round-trip with one slot containing 'flag', got {labels!r}" + assert ( + len(cpp_ids) == 1 and cpp_ids[0] >= 0 + ), f"checkpoint_yield_on_cpp_arg_ids should round-trip with one valid id, got {cpp_ids!r}" + assert primal.checkpoint_user_labels_by_cp_id == [ + 0 + ], f"checkpoint_user_labels_by_cp_id should round-trip as [0], got {primal.checkpoint_user_labels_by_cp_id!r}" + assert primal.src_ll_cache_observations.cache_loaded == args_obj.expect_loaded_from_fastcache, ( + f"cache_loaded={primal.src_ll_cache_observations.cache_loaded!r} but expected " + f"{args_obj.expect_loaded_from_fastcache!r}" + ) + + print(TEST_RAN) + sys.exit(RET_SUCCESS) + + +@test_utils.test() +def test_checkpoint_fastcache_restores_self_member_yield_on(tmp_path: pathlib.Path): + """After a fast-cache restore in a fresh process, a `@qd.kernel(graph=True, checkpoints=True, fastcache=True)` + kernel with `yield_on=self.flag` must repopulate `checkpoint_yield_on_args` / + `checkpoint_yield_on_cpp_arg_ids` / `checkpoint_user_labels_by_cp_id` from the persisted ``CacheValue`` -- + not from the AST transformer, which is skipped on a cache hit. Without the schema-v3 round-trip the launch + path's `forward_yield_on_table_to_ctx` would be a no-op and yield/resume would silently break for fast-cached + checkpoint kernels.""" + assert qd.lang is not None + arch = qd.lang.impl.current_cfg().arch.name + env = dict(os.environ) + env["PYTHONPATH"] = "." + + for expect_loaded in [False, True]: + args_obj = _FastcacheCheckpointArgs( + arch=arch, + offline_cache_file_path=str(tmp_path / "cache"), + expect_loaded_from_fastcache=expect_loaded, + ) + cmd_line = [sys.executable, __file__, _fastcache_checkpoint_child.__name__, args_obj.model_dump_json()] + proc = subprocess.run(cmd_line, capture_output=True, text=True, env=env) + if proc.returncode != RET_SUCCESS: + print(" ".join(cmd_line)) + print(proc.stdout) + print("-" * 100) + print(proc.stderr) + assert TEST_RAN in proc.stdout + assert proc.returncode == RET_SUCCESS + + # ---------------------------------------------------------------------------------------------------------------------- # CUDA-native introspection (slice 1c). # ---------------------------------------------------------------------------------------------------------------------- @@ -1001,3 +1201,11 @@ def k(x: qd.types.ndarray(qd.i32, ndim=1), flag: qd.types.ndarray(qd.i32, ndim=0 assert ( _num_checkpoints_on_last_call() == 3 ), f"expected 3 IF conditional nodes (2 implicit + 1 explicit), got {_num_checkpoints_on_last_call()}" + + +# Subprocess dispatch for fast-cache restoration tests above (mirrors the pattern in `test_graph_do_while.py`). The +# parent test invokes us via `subprocess.run([sys.executable, __file__, , ])` so the +# child runs in a fresh interpreter with a clean `qd.init` -- the only way to exercise the cross-process fast-cache +# load path that ``Kernel._try_load_fastcache`` takes after a previous run has populated the on-disk cache. +if __name__ == "__main__": + globals()[sys.argv[1]](sys.argv[2:]) diff --git a/tests/python/test_graph_do_while.py b/tests/python/test_graph_do_while.py index 6e807100f9..18ca6a8ba7 100644 --- a/tests/python/test_graph_do_while.py +++ b/tests/python/test_graph_do_while.py @@ -365,6 +365,83 @@ def step(params: Params): assert levels[0].cond_cpp_arg_id >= 0 +@test_utils.test() +def test_graph_do_while_with_member_nonexistent_attribute_raises(): + """`qd.graph_do_while(self.nonexistent_attr)` must raise the same user-facing + `does not resolve to an ndarray kernel parameter` diagnostic as the bare-name nonexistent case. The AST-time + resolver wraps the underlying attribute lookup failure so the user sees one consistent error pattern across + bare-name and attribute forms.""" + N = 4 + + @qd.data_oriented + class Sim: + def __init__(self): + self.x = qd.ndarray(qd.i32, shape=(N,)) + self.counter = qd.ndarray(qd.i32, shape=()) + + @qd.kernel(graph=True) + def step(self): + while qd.graph_do_while(self.nonexistent_counter): # type: ignore[attr-defined] + for i in range(self.x.shape[0]): + self.x[i] = self.x[i] + 1 + + sim = Sim() + with pytest.raises(qd.QuadrantsSyntaxError, match="does not resolve to an ndarray kernel parameter"): + sim.step() + + +@test_utils.test() +def test_graph_do_while_with_data_oriented_member_nested(): + """Nested `qd.graph_do_while(self.outer)` containing `qd.graph_do_while(self.inner)` exercises the level-table + machinery with member ndarrays: each level resolves its own flat C++ arg-id at AST-build time, the parent_id + chain links inner -> outer, and the loop body iterates `outer_iters * inner_iters` times the same as the + bare-parameter version (see `test_graph_do_while_nested_two_levels`).""" + if not _is_graph_do_while_natively_supported() and not ( + impl.current_cfg().arch in (qd.x64, qd.arm64, qd.amdgpu, qd.vulkan, qd.metal) + ): + pytest.skip("backend does not implement graph_do_while") + N = 4 + + @qd.data_oriented + class Sim: + def __init__(self): + self.x = qd.ndarray(qd.i32, shape=(N,)) + self.outer = qd.ndarray(qd.i32, shape=()) + self.inner = qd.ndarray(qd.i32, shape=()) + self.inner_start = qd.ndarray(qd.i32, shape=()) + + @qd.kernel(graph=True) + def step(self): + while qd.graph_do_while(self.outer): + for _ in range(1): + self.inner[()] = self.inner_start[()] + while qd.graph_do_while(self.inner): + for i in range(self.x.shape[0]): + self.x[i] = self.x[i] + 1 + for _ in range(1): + self.inner[()] = self.inner[()] - 1 + for _ in range(1): + self.outer[()] = self.outer[()] - 1 + + sim = Sim() + sim.x.from_numpy(np.zeros(N, dtype=np.int32)) + sim.outer.from_numpy(np.array(3, dtype=np.int32)) + sim.inner.from_numpy(np.array(2, dtype=np.int32)) + sim.inner_start.from_numpy(np.array(2, dtype=np.int32)) + sim.step() + np.testing.assert_array_equal(sim.x.to_numpy(), np.full(N, 6, dtype=np.int32)) + assert sim.outer.to_numpy() == 0 + levels = sim.step._primal.graph_do_while_levels + assert len(levels) == 2 + assert levels[0].cond_arg_name == "self.outer" + assert levels[1].cond_arg_name == "self.inner" + assert levels[0].parent_id == -1 + assert levels[1].parent_id == 0 + assert levels[0].cond_cpp_arg_id >= 0 + assert levels[1].cond_cpp_arg_id >= 0 + assert levels[0].cond_cpp_arg_id != levels[1].cond_cpp_arg_id + + @test_utils.test() def test_graph_do_while_with_data_oriented_member_counter(): """`qd.graph_do_while(self.counter)` resolves the member ndarray to the loop condition's flat C++ arg-id at From 58bba23cd0d6dd84a7afec999048b3dd502fac06 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 24 Jun 2026 07:17:11 -0700 Subject: [PATCH 3/5] [Graph] Preserve IntEnum identity for checkpoint labels across fast-cache restore Bumps fast-cache schema to v4 (`cachevalue-v4-intenum-qualnames`) to fix a P2 regression flagged on PR #760 where ``qd.checkpoint(Stage.X, ...)`` round-tripped through fast-cache as the raw ``int`` rather than the original ``IntEnum`` member. Pydantic coerces ``IntEnum`` to ``int`` at ``CacheValue`` construction time (the field is typed ``list[int | None]``), so persisting only the int column was lossy. The fix stores a parallel ``checkpoint_user_label_enum_qualnames`` column with the original member's ``module.ClassQualName.MEMBER`` string and rebuilds the enum on load via ``_resolve_intenum_member`` (importlib + attribute walk; falls back to the persisted int if the enum was moved/renamed since the cache was written, so stale caches degrade gracefully rather than crashing). The store-side helper ``_intenum_member_qualname`` returns ``None`` for plain-int labels so non-enum users pay nothing. Tests: - ``test_checkpoint_fastcache_preserves_intenum_label_identity`` -- end-to-end subprocess cache miss + hit, asserts ``isinstance(label, _FastcacheStage)`` after fast-cache restore (not just int equality). - ``test_src_hasher_intenum_qualname_round_trip`` -- direct ``CacheValue`` unit test covering mixed IntEnum/None/int label slots, the parallel qualname derivation, and the fallback when a qualname no longer resolves. Older v3 caches drop into the same raw-int fallback path the loader uses for plain-int labels (the missing column defaults to an empty list, padded to None per slot), so no migration is required -- the version bump just invalidates them via ``create_cache_key``. --- .../lang/_fast_caching/src_hasher.py | 91 ++++++++++++++++++- python/quadrants/lang/kernel.py | 15 ++- .../lang/fast_caching/test_src_hasher.py | 72 ++++++++++++++- tests/python/test_checkpoint.py | 86 ++++++++++++++++++ 4 files changed, 259 insertions(+), 5 deletions(-) diff --git a/python/quadrants/lang/_fast_caching/src_hasher.py b/python/quadrants/lang/_fast_caching/src_hasher.py index 6b6c0227df..45b7fc9428 100644 --- a/python/quadrants/lang/_fast_caching/src_hasher.py +++ b/python/quadrants/lang/_fast_caching/src_hasher.py @@ -1,6 +1,8 @@ +import importlib import json import os import warnings +from enum import IntEnum from typing import Any, Iterable, Sequence import pydantic @@ -20,8 +22,75 @@ # Bumped whenever the persisted CacheValue schema changes (see create_cache_key). v2 replaced the single # graph_do_while_arg string with a nested level table. v3 added the AST-resolved flat C++ arg-ids for # qd.graph_do_while conditions and qd.checkpoint(yield_on=...) targets so the launch path can forward them -# directly without per-launch name matching (necessary for @qd.data_oriented member ndarrays). -_CACHE_VALUE_SCHEMA_VERSION = "cachevalue-v3-ast-resolved-ids" +# directly without per-launch name matching (necessary for @qd.data_oriented member ndarrays). v4 added the +# per-slot `checkpoint_user_label_enum_qualnames` table so an IntEnum cp_id (e.g. `qd.checkpoint(Stage.SIM, ...)`) +# round-trips through fast-cache restore as the original IntEnum member rather than the underlying int. +_CACHE_VALUE_SCHEMA_VERSION = "cachevalue-v4-intenum-qualnames" + + +def _intenum_member_qualname(value: Any) -> str | None: + """Return ``"module.ClassQualName.MEMBER"`` for an ``IntEnum`` member, else ``None``. + + Stored alongside ``checkpoint_user_labels_by_cp_id`` so that ``_resolve_intenum_member`` can rebuild the + original enum member on fast-cache restore -- pydantic coerces ``IntEnum`` to plain ``int`` at ``CacheValue`` + construction time (it sees ``list[int | None]``), which would otherwise silently break the documented + contract that ``qd.checkpoint(Stage.X, ...)`` round-trips ``Stage.X`` rather than the raw int through + ``status.checkpoint``. Returns ``None`` for plain ints, ``None`` labels, anonymous enums (no ``__module__``), + and other unsupported shapes -- the loader falls back to the raw int in those cases. + """ + if not isinstance(value, IntEnum): + return None + cls = type(value) + module = getattr(cls, "__module__", None) + qualname = getattr(cls, "__qualname__", None) + name = getattr(value, "name", None) + if not module or not qualname or not name: + return None + return f"{module}.{qualname}.{name}" + + +def _resolve_intenum_member(qualname: str | None, fallback: int | None) -> int | IntEnum | None: + """Inverse of ``_intenum_member_qualname``: look up the enum member by ``"module.ClassQualName.MEMBER"``. + + Returns the resolved ``IntEnum`` member if every step (module import, attribute walk) succeeds AND the member's + int value matches ``fallback`` (the raw int from ``checkpoint_user_labels_by_cp_id`` we already persisted). + Mismatch or any failure -- module renamed since the cache was written, enum class refactored, member removed, + etc. -- falls back to ``fallback`` so the user still gets a usable (if enum-identity-less) label rather than a + hard crash. ``None`` qualname / ``None`` fallback short-circuit to ``fallback`` for the plain-int label case. + """ + if qualname is None or fallback is None: + return fallback + try: + # qualname is "module.path.Class[.Nested].MEMBER"; the MEMBER tail is always one segment, so rsplit once. + # The remaining cls_path mixes dotted module path + dotted class qualname; we try progressively shorter + # module prefixes until one imports, then resolve the rest as attribute chain. This handles top-level + # enums (``mymod.Stage.LOAD``), enums nested in classes (``mymod.Outer.Inner.MEMBER``), and enums in + # subpackages (``a.b.Stage.LOAD``) without needing the user to declare which prefix is the module. + cls_path, _, member_name = qualname.rpartition(".") + if not cls_path or not member_name: + return fallback + module = None + cls_attr_path = "" + segments = cls_path.split(".") + for i in range(len(segments), 0, -1): + try: + module = importlib.import_module(".".join(segments[:i])) + cls_attr_path = ".".join(segments[i:]) + break + except ImportError: + continue + if module is None: + return fallback + obj: Any = module + if cls_attr_path: + for seg in cls_attr_path.split("."): + obj = getattr(obj, seg) + obj = getattr(obj, member_name) + except (AttributeError, ValueError): + return fallback + if isinstance(obj, IntEnum) and int(obj) == int(fallback): + return obj + return fallback def create_cache_key( @@ -85,6 +154,14 @@ class CacheValue(BaseModel): checkpoint_yield_on_args: list[str | None] = [] checkpoint_yield_on_cpp_arg_ids: list[int] = [] checkpoint_user_labels_by_cp_id: list[int | None] = [] + # Parallel to ``checkpoint_user_labels_by_cp_id``: each entry is the dotted ``module.ClassQualName.MEMBER`` of + # the original ``IntEnum`` member the user passed as ``cp_id``, or ``None`` if the user passed a plain int (or + # for implicit auto-wrap checkpoints). On fast-cache restore the loader runs each entry through + # ``_resolve_intenum_member`` to rebuild the IntEnum, preserving the documented contract that + # ``qd.checkpoint(Stage.X, ...)`` round-trips ``Stage.X`` (not the underlying int) through + # ``status.checkpoint`` and ``kernel.resume(from_checkpoint=...)`` -- pydantic coerces IntEnum to int at + # ``CacheValue`` construction time so the parallel qualname column is what carries the enum identity. + checkpoint_user_label_enum_qualnames: list[str | None] = [] def store( @@ -97,6 +174,11 @@ def store( checkpoint_yield_on_cpp_arg_ids: list[int] | None = None, checkpoint_user_labels_by_cp_id: list[int | None] | None = None, ) -> None: + # `checkpoint_user_label_enum_qualnames` is derived from `checkpoint_user_labels_by_cp_id` here (rather than + # being plumbed through a separate kwarg from `Kernel.materialize`) so callers never have to think about the + # parallel column: they pass the live label list (which still holds the original ``IntEnum`` instances at + # store time, before pydantic's int-coercion strips identity in ``CacheValue.__init__``), and the qualname + # snapshot is recorded once here for the loader to consume. """ Note that unlike other caches, this cache is not going to store the actual value we want. This cache is only used for verification that our cache key is valid. Big picture: @@ -119,6 +201,8 @@ def store( assert frontend_cache_key is not None cache = PythonSideCache() hashed_function_source_infos = function_hasher.hash_functions(function_source_infos) + labels = checkpoint_user_labels_by_cp_id or [] + enum_qualnames = [_intenum_member_qualname(lbl) for lbl in labels] cache_value_obj = CacheValue( frontend_cache_key=frontend_cache_key, hashed_function_source_infos=list(hashed_function_source_infos), @@ -126,7 +210,8 @@ def store( graph_do_while_levels=graph_do_while_levels, checkpoint_yield_on_args=checkpoint_yield_on_args or [], checkpoint_yield_on_cpp_arg_ids=checkpoint_yield_on_cpp_arg_ids or [], - checkpoint_user_labels_by_cp_id=checkpoint_user_labels_by_cp_id or [], + checkpoint_user_labels_by_cp_id=labels, + checkpoint_user_label_enum_qualnames=enum_qualnames, ) cache.store(fast_cache_key, cache_value_obj.model_dump_json()) diff --git a/python/quadrants/lang/kernel.py b/python/quadrants/lang/kernel.py index d31356dd09..9d7efe84f3 100644 --- a/python/quadrants/lang/kernel.py +++ b/python/quadrants/lang/kernel.py @@ -438,7 +438,20 @@ def _try_load_fastcache(self, args: tuple[Any, ...], key: "CompiledKernelKeyType if cache_value.checkpoint_yield_on_args: self.checkpoint_yield_on_args = list(cache_value.checkpoint_yield_on_args) self.checkpoint_yield_on_cpp_arg_ids = list(cache_value.checkpoint_yield_on_cpp_arg_ids) - self.checkpoint_user_labels_by_cp_id = list(cache_value.checkpoint_user_labels_by_cp_id) + # Pydantic coerces IntEnum -> int at CacheValue construction time, so the raw labels are + # plain ints after JSON round-trip. ``checkpoint_user_label_enum_qualnames`` carries the + # parallel ``module.ClassQualName.MEMBER`` strings that ``_resolve_intenum_member`` uses + # to rebuild the original ``IntEnum`` member -- preserving the documented contract that + # ``qd.checkpoint(Stage.X, ...)`` surfaces as ``Stage.X`` (not the raw int) on + # ``status.checkpoint``. Older v3 caches predate the qualname column, so we default any + # missing slots to ``None`` -> raw-int fallback (the same behaviour they had on v3). + raw_labels = list(cache_value.checkpoint_user_labels_by_cp_id) + qualnames = list(cache_value.checkpoint_user_label_enum_qualnames) or [None] * len(raw_labels) + if len(qualnames) != len(raw_labels): + qualnames = [None] * len(raw_labels) + self.checkpoint_user_labels_by_cp_id = [ + src_hasher._resolve_intenum_member(qn, lbl) for qn, lbl in zip(qualnames, raw_labels) + ] return cache_value.used_py_dataclass_parameters elif self.quadrants_callable and not self.quadrants_callable.is_pure and self.runtime.print_non_pure: diff --git a/tests/python/quadrants/lang/fast_caching/test_src_hasher.py b/tests/python/quadrants/lang/fast_caching/test_src_hasher.py index 53b03d8753..2dd6dd6683 100644 --- a/tests/python/quadrants/lang/fast_caching/test_src_hasher.py +++ b/tests/python/quadrants/lang/fast_caching/test_src_hasher.py @@ -4,6 +4,7 @@ import shutil import subprocess import sys +from enum import IntEnum from typing import Callable import pydantic @@ -137,12 +138,13 @@ def assert_not_loaded(cache_key: str) -> None: loaded = src_hasher.load(fast_cache_key) assert loaded is not None assert loaded.used_py_dataclass_parameters == some_used_vars - # The new schema-v3 AST-resolved fields default to empty for kernels with no graph_do_while / checkpoint + # The new schema-v3+v4 AST-resolved fields default to empty for kernels with no graph_do_while / checkpoint # metadata, exercising the BaseModel default path on round-trip. assert loaded.graph_do_while_levels is None assert loaded.checkpoint_yield_on_args == [] assert loaded.checkpoint_yield_on_cpp_arg_ids == [] assert loaded.checkpoint_user_labels_by_cp_id == [] + assert loaded.checkpoint_user_label_enum_qualnames == [] @test_utils.test() @@ -194,6 +196,74 @@ def test_src_hasher_store_validate_round_trips_schema_v3_metadata( assert loaded.checkpoint_yield_on_args == cp_yield_args assert loaded.checkpoint_yield_on_cpp_arg_ids == cp_yield_cpp_ids assert loaded.checkpoint_user_labels_by_cp_id == cp_user_labels + # Plain-int labels record `None` in the parallel qualname column (no IntEnum identity to preserve). + assert loaded.checkpoint_user_label_enum_qualnames == [None, None, None] + + +@test_utils.test() +def test_src_hasher_intenum_qualname_round_trip( + monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path, temporary_module +) -> None: + """Schema v4 (`cachevalue-v4-intenum-qualnames`) added a parallel `checkpoint_user_label_enum_qualnames` + column so an ``IntEnum`` cp_id round-trips through fast-cache restore as the original enum member rather than + the underlying int. ``src_hasher.store`` derives the qualname column from the live label list (which still + holds the original ``IntEnum`` instances) before pydantic int-coerces them; ``_resolve_intenum_member`` + re-imports the enum class on load. This test covers both the store-side derivation (mixed IntEnum / plain int + / None) and the load-side resolution (verifies identity is preserved, not just int equality).""" + test_files_path = pathlib.Path("tests/python/quadrants/lang/fast_caching/test_files") + offline_cache_path = tmp_path / "cache" + temp_import_path = tmp_path / "temp_import" + temp_import_path.mkdir(exist_ok=True) + qd_init_same_arch(offline_cache_file_path=str(offline_cache_path)) + monkeypatch.syspath_prepend(temp_import_path) + shutil.copy2(test_files_path / "child_diff_base.py", temp_import_path / "child_diff_v4_intenum.py") + mod = temporary_module("child_diff_v4_intenum") + info, _src = _wrap_inspect.get_source_info_and_src(mod.f1.fn) + fast_cache_key = src_hasher.create_cache_key(False, info, [], []) + assert fast_cache_key is not None + + # Reference the module-level enum below so it has a real importable qualname. + src_hasher.store( + "kernel_cache_key_v4", + fast_cache_key, + [info], + {"used_var"}, + checkpoint_yield_on_args=["flag", None, "flag"], + checkpoint_yield_on_cpp_arg_ids=[1, -1, 1], + checkpoint_user_labels_by_cp_id=[_HasherTestStage.LOAD, None, _HasherTestStage.REDUCE], + ) + + loaded = src_hasher.load(fast_cache_key) + assert loaded is not None + # Persisted raw labels are plain ints (pydantic coerced them); the qualname column is what carries identity. + assert loaded.checkpoint_user_labels_by_cp_id == [5, None, 9] + assert loaded.checkpoint_user_label_enum_qualnames == [ + f"{_HasherTestStage.__module__}.{_HasherTestStage.__qualname__}.LOAD", + None, + f"{_HasherTestStage.__module__}.{_HasherTestStage.__qualname__}.REDUCE", + ] + + # Resolver round-trip: rebuild each slot through `_resolve_intenum_member` and confirm enum identity (not + # just int-equality) is preserved. + resolved = [ + src_hasher._resolve_intenum_member(qn, lbl) + for qn, lbl in zip(loaded.checkpoint_user_label_enum_qualnames, loaded.checkpoint_user_labels_by_cp_id) + ] + assert resolved == [_HasherTestStage.LOAD, None, _HasherTestStage.REDUCE] + assert isinstance(resolved[0], _HasherTestStage) + assert isinstance(resolved[2], _HasherTestStage) + + # Resolver fallback: an unresolvable qualname (e.g. enum class moved/renamed since cache write) must drop + # back to the persisted int rather than raising, so a stale cache entry degrades gracefully. + assert src_hasher._resolve_intenum_member("nonexistent.Module.Stage.LOAD", 5) == 5 + + +# Top-level IntEnum used by `test_src_hasher_intenum_qualname_round_trip` so the resolver can re-import it via +# `importlib.import_module("tests.python.quadrants.lang.fast_caching.test_src_hasher")`. Lives at module scope +# (not inside the test) for the same reason `_FastcacheStage` / `_Stage` do in `test_checkpoint.py`. +class _HasherTestStage(IntEnum): + LOAD = 5 + REDUCE = 9 # Should be enough to run these on cpu I think, and anything involving diff --git a/tests/python/test_checkpoint.py b/tests/python/test_checkpoint.py index 5711311764..4f8769d519 100644 --- a/tests/python/test_checkpoint.py +++ b/tests/python/test_checkpoint.py @@ -1167,6 +1167,92 @@ def test_checkpoint_fastcache_restores_self_member_yield_on(tmp_path: pathlib.Pa assert proc.returncode == RET_SUCCESS +# Module-level IntEnum so `_resolve_intenum_member` can find it via importlib from the persisted qualname +# (`tests.python.test_checkpoint._FastcacheStage.LOAD`). The kernel below uses it as the cp_id so the fast-cache +# round-trip exercises the schema-v4 enum-identity preservation path. +class _FastcacheStage(IntEnum): + LOAD = 10 + REDUCE = 20 + + +@qd.kernel(graph=True, checkpoints=True, fastcache=True) +def _fastcache_intenum_kernel( + x: qd.types.ndarray(qd.i32, ndim=1), + flag: qd.types.ndarray(qd.i32, ndim=0), +): + with qd.checkpoint(_FastcacheStage.LOAD, yield_on=flag): + for i in range(x.shape[0]): + x[i] = x[i] + 1 + with qd.checkpoint(_FastcacheStage.REDUCE, yield_on=flag): + for i in range(x.shape[0]): + x[i] = x[i] + 10 + + +def _fastcache_intenum_child(args: list[str]) -> None: + args_obj = _FastcacheCheckpointArgs.model_validate_json(args[0]) + qd.init( + arch=getattr(qd, args_obj.arch), + offline_cache=True, + offline_cache_file_path=args_obj.offline_cache_file_path, + src_ll_cache=True, + ) + + N = 4 + x = qd.ndarray(qd.i32, shape=(N,)) + flag = qd.ndarray(qd.i32, shape=()) + x.from_numpy(np.zeros(N, dtype=np.int32)) + flag.from_numpy(np.array(0, dtype=np.int32)) + _fastcache_intenum_kernel(x, flag) + np.testing.assert_array_equal(x.to_numpy(), np.full(N, 11, dtype=np.int32)) + + primal = _fastcache_intenum_kernel._primal + labels = primal.checkpoint_user_labels_by_cp_id + # The schema-v4 round-trip must rebuild the IntEnum identity, not just the int equality. A regression here + # would show up as `labels == [10, 20]` (plain ints) breaking the documented contract that + # `qd.checkpoint(Stage.X, ...)` surfaces as `Stage.X` (not the raw int) on `status.checkpoint`. + assert labels == [ + _FastcacheStage.LOAD, + _FastcacheStage.REDUCE, + ], f"checkpoint_user_labels_by_cp_id should round-trip with IntEnum identity, got {labels!r}" + assert all( + isinstance(lbl, _FastcacheStage) for lbl in labels + ), f"every label slot must be a _FastcacheStage instance, got {[type(lbl).__name__ for lbl in labels]!r}" + assert primal.src_ll_cache_observations.cache_loaded == args_obj.expect_loaded_from_fastcache + + print(TEST_RAN) + sys.exit(RET_SUCCESS) + + +@test_utils.test() +def test_checkpoint_fastcache_preserves_intenum_label_identity(tmp_path: pathlib.Path): + """Fast-cache restore must rebuild ``checkpoint_user_labels_by_cp_id`` with the original ``IntEnum`` members, + not just int-equal plain ints. Schema v4 adds a parallel ``checkpoint_user_label_enum_qualnames`` column so + ``_resolve_intenum_member`` can re-import the enum class on cache hit -- pydantic coerces ``IntEnum`` to + ``int`` at ``CacheValue`` construction, which would otherwise silently drop enum identity and break the + documented contract that ``qd.checkpoint(Stage.X, ...)`` surfaces as ``Stage.X`` (not the raw int) on + ``status.checkpoint`` after a fast-cache hit.""" + assert qd.lang is not None + arch = qd.lang.impl.current_cfg().arch.name + env = dict(os.environ) + env["PYTHONPATH"] = "." + + for expect_loaded in [False, True]: + args_obj = _FastcacheCheckpointArgs( + arch=arch, + offline_cache_file_path=str(tmp_path / "cache"), + expect_loaded_from_fastcache=expect_loaded, + ) + cmd_line = [sys.executable, __file__, _fastcache_intenum_child.__name__, args_obj.model_dump_json()] + proc = subprocess.run(cmd_line, capture_output=True, text=True, env=env) + if proc.returncode != RET_SUCCESS: + print(" ".join(cmd_line)) + print(proc.stdout) + print("-" * 100) + print(proc.stderr) + assert TEST_RAN in proc.stdout + assert proc.returncode == RET_SUCCESS + + # ---------------------------------------------------------------------------------------------------------------------- # CUDA-native introspection (slice 1c). # ---------------------------------------------------------------------------------------------------------------------- From 9c45a87c8645f3822fe4fe07501fcf45baeeee13 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 24 Jun 2026 09:10:12 -0700 Subject: [PATCH 4/5] [Graph] Extract _resolve_ndarray_kernel_arg_id into ast_transformers/ndarray_arg_resolver.py CI's `Check feature factorization` flagged the ~40-line `ASTTransformer._resolve_ndarray_kernel_arg_id` static method as carving a new feature into the central 1705-line `ast_transformer.py`. The bot's suggested fix matches the existing pattern adjacent to `_is_checkpoint_call` (thin forwarding wrapper on `ASTTransformer`, real logic in a sibling `ast_transformers/*_transformer.py` file). Moved the resolver to `python/quadrants/lang/ast/ast_transformers/ndarray_arg_resolver.py` as a free function `resolve_ndarray_kernel_arg_id`. `ASTTransformer._resolve_ndarray_kernel_arg_id` is now a one-line forwarding wrapper (kept so existing call sites in `build_While` and any third-party callers continue to work), and `CheckpointTransformer.build_checkpoint_with` imports the free function directly instead of going through the wrapper. The local-import dance dodges the `ast_transformers -> ast_transformer` cycle the same way the existing `_is_checkpoint_call` / `CheckpointTransformer` split does. --- python/quadrants/lang/ast/ast_transformer.py | 41 +++-------- .../checkpoint_transformer.py | 18 ++--- .../ast_transformers/ndarray_arg_resolver.py | 71 +++++++++++++++++++ 3 files changed, 90 insertions(+), 40 deletions(-) create mode 100644 python/quadrants/lang/ast/ast_transformers/ndarray_arg_resolver.py diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 29f86dde33..4a9d3e1883 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -1379,39 +1379,16 @@ def _resolve_ndarray_kernel_arg_id( node: ast.expr, usage: str, ) -> tuple[str, int]: - """Resolve an ndarray-referencing expression to ``(label, flat_cpp_arg_id)`` at AST-build time. - - Shared between ``qd.checkpoint(yield_on=...)`` and ``qd.graph_do_while(...)`` to turn the control-flag - argument into the flat C++ arg-id the runtime matches against. ``node`` is an ``ast.Name`` (a bare kernel - parameter, e.g. ``flag``) or an ``ast.Attribute`` chain (e.g. ``self.flag`` for a ``@qd.data_oriented`` - owner, or ``params.flag`` where ``params`` is a ``@dataclasses.dataclass`` kernel parameter). We build the - expression through the normal AST machinery and read the arg-id off the resulting external-tensor - expression -- this unifies the bare-param and member-ndarray cases, since both flatten to a real ndarray - kernel argument carrying its arg-id on the ``ExternalTensorExpression``. - - ``usage`` is the call form (e.g. ``"qd.checkpoint(yield_on=...)"``) used in the error message. Raises - ``QuadrantsSyntaxError`` if the expression does not resolve to an ndarray kernel argument. - """ - from quadrants.lang.any_array import AnyArray # pylint: disable=C0415 - - label = ast.unparse(node) - bad = QuadrantsSyntaxError( - f"{usage} got {label!r} which does not resolve to an ndarray kernel parameter of " - f"{kernel.func.__name__!r}. The argument must reference an ndarray kernel parameter (e.g. " - f"`flag`) or a @qd.data_oriented member ndarray (e.g. `self.flag`); other expressions are not " - f"supported." + """Thin forwarding wrapper around ``ndarray_arg_resolver.resolve_ndarray_kernel_arg_id``; the actual logic + lives in module ``ast_transformers/ndarray_arg_resolver.py`` to keep this file from growing per-feature + (same pattern as ``_is_checkpoint_call`` / ``CheckpointTransformer``). Returns ``(label, flat_cpp_arg_id)`` + or raises ``QuadrantsSyntaxError``.""" + # pylint: disable-next=C0415,import-outside-toplevel + from quadrants.lang.ast.ast_transformers.ndarray_arg_resolver import ( + resolve_ndarray_kernel_arg_id, ) - try: - built = build_stmt(ctx, node) - except Exception as e: # noqa: BLE001 - any resolution failure is a user-facing misuse - raise bad from e - resolved_expr = built.ptr if isinstance(built, AnyArray) else built - if not (hasattr(resolved_expr, "is_external_tensor_expr") and resolved_expr.is_external_tensor_expr()): - raise bad - arg_id = _qd_core.get_external_tensor_arg_id(resolved_expr) - if not arg_id: - raise bad - return label, int(arg_id[0]) + + return resolve_ndarray_kernel_arg_id(ctx, kernel, node, usage) @staticmethod def _is_checkpoint_call(node: ast.expr, global_vars: dict): diff --git a/python/quadrants/lang/ast/ast_transformers/checkpoint_transformer.py b/python/quadrants/lang/ast/ast_transformers/checkpoint_transformer.py index 9ac922a1e7..3e092bb2e8 100644 --- a/python/quadrants/lang/ast/ast_transformers/checkpoint_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/checkpoint_transformer.py @@ -228,16 +228,18 @@ def build_checkpoint_with( yield_on_label: str | None = None yield_on_cpp_arg_id: int = -1 if not info.is_implicit: - # Resolve `yield_on=` (a bare parameter or `@qd.data_oriented` member ndarray) to its flat C++ arg-id at - # AST-build time. ``_resolve_ndarray_kernel_arg_id`` raises a user-facing ``QuadrantsSyntaxError`` if the - # expression does not name a real ndarray kernel argument, which keeps the diagnostic at the `with` site - # instead of leaking into the launcher. Both the label (for ``checkpoint_yield_on_args`` / introspection) - # and the resolved arg-id (for the runtime) are stashed and forwarded to the launch path below. - # Local import to avoid an ast_transformers -> ast_transformer cycle. + # Resolve `yield_on=` (a bare parameter or `@qd.data_oriented` / `@dataclasses.dataclass` member + # ndarray) to its flat C++ arg-id at AST-build time. ``resolve_ndarray_kernel_arg_id`` raises a + # user-facing ``QuadrantsSyntaxError`` if the expression does not name a real ndarray kernel argument, + # which keeps the diagnostic at the `with` site instead of leaking into the launcher. Both the label + # (for ``checkpoint_yield_on_args`` / introspection) and the resolved arg-id (for the runtime) are + # stashed and forwarded to the launch path below. # pylint: disable-next=C0415,import-outside-toplevel - from quadrants.lang.ast.ast_transformer import ASTTransformer + from quadrants.lang.ast.ast_transformers.ndarray_arg_resolver import ( + resolve_ndarray_kernel_arg_id, + ) - yield_on_label, yield_on_cpp_arg_id = ASTTransformer._resolve_ndarray_kernel_arg_id( + yield_on_label, yield_on_cpp_arg_id = resolve_ndarray_kernel_arg_id( ctx, kernel, info.yield_on_node, "qd.checkpoint(yield_on=...)" ) # Reject duplicate user-supplied cp_id labels. diff --git a/python/quadrants/lang/ast/ast_transformers/ndarray_arg_resolver.py b/python/quadrants/lang/ast/ast_transformers/ndarray_arg_resolver.py new file mode 100644 index 0000000000..590b40a570 --- /dev/null +++ b/python/quadrants/lang/ast/ast_transformers/ndarray_arg_resolver.py @@ -0,0 +1,71 @@ +# type: ignore +"""Resolve an ndarray-referencing AST expression to its flat C++ kernel arg-id at AST-build time. + +Lives alongside ``checkpoint_transformer.py`` / ``call_transformer.py`` so the central ``ast_transformer.py`` file +doesn't have to grow per-feature. Both ``qd.checkpoint(yield_on=...)`` (via +``CheckpointTransformer.build_checkpoint_with``) and ``qd.graph_do_while(...)`` (via ``ASTTransformer.build_While``) +share this helper: each form takes a control-flag / counter ndarray argument that may be a bare kernel parameter +(e.g. ``flag``), a ``@qd.data_oriented`` member ndarray (e.g. ``self.flag``), or a ``@dataclasses.dataclass`` +parameter member (e.g. ``params.flag``). All three flatten to the same ``ExternalTensorExpression`` after AST build, +so resolving the arg-id here -- once, at kernel build time -- means the runtime launch path can forward it directly +with no per-launch name matching (which was the original ``_checkpoint_helpers`` approach and is incompatible with +member-ndarray flattening, since the flattened name is synthesised). + +See ``docs/source/user_guide/graph.md`` for the user-facing surface and ``perso_hugh/doc/qipc/reentrant.md`` for the +design. +""" + +from __future__ import annotations + +import ast + +from quadrants.lang.ast.ast_transformer_utils import ASTTransformerFuncContext +from quadrants.lang.exception import QuadrantsSyntaxError + + +def resolve_ndarray_kernel_arg_id( + ctx: ASTTransformerFuncContext, + kernel, + node: ast.expr, + usage: str, +) -> tuple[str, int]: + """Resolve ``node`` to ``(label, flat_cpp_arg_id)`` at AST-build time. + + Shared between ``qd.checkpoint(yield_on=...)`` and ``qd.graph_do_while(...)`` to turn the control-flag argument + into the flat C++ arg-id the runtime matches against. ``node`` is an ``ast.Name`` (a bare kernel parameter, + e.g. ``flag``) or an ``ast.Attribute`` chain (e.g. ``self.flag`` for a ``@qd.data_oriented`` owner, or + ``params.flag`` where ``params`` is a ``@dataclasses.dataclass`` kernel parameter). We build the expression + through the normal AST machinery and read the arg-id off the resulting external-tensor expression -- this + unifies the bare-param and member-ndarray cases, since both flatten to a real ndarray kernel argument carrying + its arg-id on the ``ExternalTensorExpression``. + + ``usage`` is the call form (e.g. ``"qd.checkpoint(yield_on=...)"``) used in the error message. Raises + ``QuadrantsSyntaxError`` if the expression does not resolve to an ndarray kernel argument. + """ + # Local imports to avoid an ast_transformers -> ast_transformer / any_array import cycle at module load: + # ``ast_transformer`` is the central transformer module that imports ``checkpoint_transformer`` (sibling of + # this file), and ``any_array`` pulls in core ndarray bindings that aren't needed for module import. + # pylint: disable-next=C0415,import-outside-toplevel + from quadrants.lang.ast.ast_transformer import _qd_core, build_stmt + + # pylint: disable-next=C0415,import-outside-toplevel + from quadrants.lang.any_array import AnyArray + + label = ast.unparse(node) + bad = QuadrantsSyntaxError( + f"{usage} got {label!r} which does not resolve to an ndarray kernel parameter of " + f"{kernel.func.__name__!r}. The argument must reference an ndarray kernel parameter (e.g. " + f"`flag`) or a @qd.data_oriented member ndarray (e.g. `self.flag`); other expressions are not " + f"supported." + ) + try: + built = build_stmt(ctx, node) + except Exception as e: # noqa: BLE001 - any resolution failure is a user-facing misuse + raise bad from e + resolved_expr = built.ptr if isinstance(built, AnyArray) else built + if not (hasattr(resolved_expr, "is_external_tensor_expr") and resolved_expr.is_external_tensor_expr()): + raise bad + arg_id = _qd_core.get_external_tensor_arg_id(resolved_expr) + if not arg_id: + raise bad + return label, int(arg_id[0]) From f3c9cbe940cacae4b6cc42ce057fb30ae828db10 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Thu, 25 Jun 2026 05:36:36 -0700 Subject: [PATCH 5/5] [Graph] Rewrap PR comments/docstrings to project's 120c target MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CI's line-wrap check flagged several comments / docstrings in this PR wrapped at the AI default (~80-90c) instead of the project's 120c target. Reflowed every prose run in files I authored that the cursor `rewrap-comments-120c` skill surfaced (cursor.directory rule: rewrap reported runs unless they're tables / doctests / oversized tokens). Pre-existing paragraphs touched only by sed-style renames in earlier commits (`taichi`->`quadrants` etc.) are intentionally left alone, per the prior PR feedback to only rewrap prose I actually changed. All my added lines now ≤120c; pre-commit is clean. --- .../lang/_fast_caching/src_hasher.py | 84 ++++++++-------- python/quadrants/lang/ast/ast_transformer.py | 24 ++--- .../checkpoint_transformer.py | 28 +++--- .../ast_transformers/ndarray_arg_resolver.py | 24 ++--- python/quadrants/lang/kernel.py | 36 +++---- .../lang/fast_caching/test_src_hasher.py | 36 +++---- tests/python/test_checkpoint.py | 96 +++++++++---------- tests/python/test_graph_do_while.py | 25 +++-- 8 files changed, 175 insertions(+), 178 deletions(-) diff --git a/python/quadrants/lang/_fast_caching/src_hasher.py b/python/quadrants/lang/_fast_caching/src_hasher.py index 45b7fc9428..fc443ac481 100644 --- a/python/quadrants/lang/_fast_caching/src_hasher.py +++ b/python/quadrants/lang/_fast_caching/src_hasher.py @@ -20,23 +20,23 @@ from .python_side_cache import PythonSideCache # Bumped whenever the persisted CacheValue schema changes (see create_cache_key). v2 replaced the single -# graph_do_while_arg string with a nested level table. v3 added the AST-resolved flat C++ arg-ids for -# qd.graph_do_while conditions and qd.checkpoint(yield_on=...) targets so the launch path can forward them -# directly without per-launch name matching (necessary for @qd.data_oriented member ndarrays). v4 added the -# per-slot `checkpoint_user_label_enum_qualnames` table so an IntEnum cp_id (e.g. `qd.checkpoint(Stage.SIM, ...)`) -# round-trips through fast-cache restore as the original IntEnum member rather than the underlying int. +# graph_do_while_arg string with a nested level table. v3 added the AST-resolved flat C++ arg-ids for qd.graph_do_while +# conditions and qd.checkpoint(yield_on=...) targets so the launch path can forward them directly without per-launch +# name matching (necessary for @qd.data_oriented member ndarrays). v4 added the per-slot +# `checkpoint_user_label_enum_qualnames` table so an IntEnum cp_id (e.g. `qd.checkpoint(Stage.SIM, ...)`) round-trips +# through fast-cache restore as the original IntEnum member rather than the underlying int. _CACHE_VALUE_SCHEMA_VERSION = "cachevalue-v4-intenum-qualnames" def _intenum_member_qualname(value: Any) -> str | None: """Return ``"module.ClassQualName.MEMBER"`` for an ``IntEnum`` member, else ``None``. - Stored alongside ``checkpoint_user_labels_by_cp_id`` so that ``_resolve_intenum_member`` can rebuild the - original enum member on fast-cache restore -- pydantic coerces ``IntEnum`` to plain ``int`` at ``CacheValue`` - construction time (it sees ``list[int | None]``), which would otherwise silently break the documented - contract that ``qd.checkpoint(Stage.X, ...)`` round-trips ``Stage.X`` rather than the raw int through - ``status.checkpoint``. Returns ``None`` for plain ints, ``None`` labels, anonymous enums (no ``__module__``), - and other unsupported shapes -- the loader falls back to the raw int in those cases. + Stored alongside ``checkpoint_user_labels_by_cp_id`` so that ``_resolve_intenum_member`` can rebuild the original + enum member on fast-cache restore -- pydantic coerces ``IntEnum`` to plain ``int`` at ``CacheValue`` construction + time (it sees ``list[int | None]``), which would otherwise silently break the documented contract that + ``qd.checkpoint(Stage.X, ...)`` round-trips ``Stage.X`` rather than the raw int through ``status.checkpoint``. + Returns ``None`` for plain ints, ``None`` labels, anonymous enums (no ``__module__``), and other unsupported + shapes -- the loader falls back to the raw int in those cases. """ if not isinstance(value, IntEnum): return None @@ -52,20 +52,20 @@ def _intenum_member_qualname(value: Any) -> str | None: def _resolve_intenum_member(qualname: str | None, fallback: int | None) -> int | IntEnum | None: """Inverse of ``_intenum_member_qualname``: look up the enum member by ``"module.ClassQualName.MEMBER"``. - Returns the resolved ``IntEnum`` member if every step (module import, attribute walk) succeeds AND the member's - int value matches ``fallback`` (the raw int from ``checkpoint_user_labels_by_cp_id`` we already persisted). - Mismatch or any failure -- module renamed since the cache was written, enum class refactored, member removed, - etc. -- falls back to ``fallback`` so the user still gets a usable (if enum-identity-less) label rather than a - hard crash. ``None`` qualname / ``None`` fallback short-circuit to ``fallback`` for the plain-int label case. + Returns the resolved ``IntEnum`` member if every step (module import, attribute walk) succeeds AND the member's int + value matches ``fallback`` (the raw int from ``checkpoint_user_labels_by_cp_id`` we already persisted). Mismatch or + any failure -- module renamed since the cache was written, enum class refactored, member removed, etc. -- falls back + to ``fallback`` so the user still gets a usable (if enum-identity-less) label rather than a hard crash. ``None`` + qualname / ``None`` fallback short-circuit to ``fallback`` for the plain-int label case. """ if qualname is None or fallback is None: return fallback try: - # qualname is "module.path.Class[.Nested].MEMBER"; the MEMBER tail is always one segment, so rsplit once. - # The remaining cls_path mixes dotted module path + dotted class qualname; we try progressively shorter - # module prefixes until one imports, then resolve the rest as attribute chain. This handles top-level - # enums (``mymod.Stage.LOAD``), enums nested in classes (``mymod.Outer.Inner.MEMBER``), and enums in - # subpackages (``a.b.Stage.LOAD``) without needing the user to declare which prefix is the module. + # qualname is "module.path.Class[.Nested].MEMBER"; the MEMBER tail is always one segment, so rsplit once. The + # remaining cls_path mixes dotted module path + dotted class qualname; we try progressively shorter module + # prefixes until one imports, then resolve the rest as attribute chain. This handles top-level enums + # (``mymod.Stage.LOAD``), enums nested in classes (``mymod.Outer.Inner.MEMBER``), and enums in subpackages + # (``a.b.Stage.LOAD``) without needing the user to declare which prefix is the module. cls_path, _, member_name = qualname.rpartition(".") if not cls_path or not member_name: return fallback @@ -140,27 +140,27 @@ class CacheValue(BaseModel): frontend_cache_key: str hashed_function_source_infos: list[HashedFunctionSourceInfo] used_py_dataclass_parameters: set[str] - # Nested graph_do_while level table as (cond_arg_name, parent_id, cond_cpp_arg_id) triples, indexed by level - # id. None / empty for kernels without graph_do_while. ``cond_cpp_arg_id`` is the flat C++ arg-id resolved at - # AST-build time by ``ASTTransformer._resolve_ndarray_kernel_arg_id`` and is required by the launch path to - # support `@qd.data_oriented` member conditions (`qd.graph_do_while(self.counter)`) -- name-matching against - # ``arg_metas`` only resolves top-level parameters. + # Nested graph_do_while level table as (cond_arg_name, parent_id, cond_cpp_arg_id) triples, indexed by level id. + # None / empty for kernels without graph_do_while. ``cond_cpp_arg_id`` is the flat C++ arg-id resolved at AST-build + # time by ``ASTTransformer._resolve_ndarray_kernel_arg_id`` and is required by the launch path to support + # `@qd.data_oriented` member conditions (`qd.graph_do_while(self.counter)`) -- name-matching against ``arg_metas`` + # only resolves top-level parameters. graph_do_while_levels: list[tuple[str, int, int]] | None = None # AST-build-time-resolved checkpoint metadata, indexed by internal cp_id. Empty for kernels without any # `with qd.checkpoint(...)` block. See `Kernel.checkpoint_yield_on_args` / # `Kernel.checkpoint_yield_on_cpp_arg_ids` / `Kernel.checkpoint_user_labels_by_cp_id` for what each entry means. - # Restored alongside the C++-side cached kernel so the launch path can forward `yield_on=` arg-ids and - # translate `from_checkpoint=` labels without re-running the AST transformer. + # Restored alongside the C++-side cached kernel so the launch path can forward `yield_on=` arg-ids and translate + # `from_checkpoint=` labels without re-running the AST transformer. checkpoint_yield_on_args: list[str | None] = [] checkpoint_yield_on_cpp_arg_ids: list[int] = [] checkpoint_user_labels_by_cp_id: list[int | None] = [] - # Parallel to ``checkpoint_user_labels_by_cp_id``: each entry is the dotted ``module.ClassQualName.MEMBER`` of - # the original ``IntEnum`` member the user passed as ``cp_id``, or ``None`` if the user passed a plain int (or - # for implicit auto-wrap checkpoints). On fast-cache restore the loader runs each entry through + # Parallel to ``checkpoint_user_labels_by_cp_id``: each entry is the dotted ``module.ClassQualName.MEMBER`` of the + # original ``IntEnum`` member the user passed as ``cp_id``, or ``None`` if the user passed a plain int (or for + # implicit auto-wrap checkpoints). On fast-cache restore the loader runs each entry through # ``_resolve_intenum_member`` to rebuild the IntEnum, preserving the documented contract that - # ``qd.checkpoint(Stage.X, ...)`` round-trips ``Stage.X`` (not the underlying int) through - # ``status.checkpoint`` and ``kernel.resume(from_checkpoint=...)`` -- pydantic coerces IntEnum to int at - # ``CacheValue`` construction time so the parallel qualname column is what carries the enum identity. + # ``qd.checkpoint(Stage.X, ...)`` round-trips ``Stage.X`` (not the underlying int) through ``status.checkpoint`` and + # ``kernel.resume(from_checkpoint=...)`` -- pydantic coerces IntEnum to int at ``CacheValue`` construction time so + # the parallel qualname column is what carries the enum identity. checkpoint_user_label_enum_qualnames: list[str | None] = [] @@ -174,11 +174,11 @@ def store( checkpoint_yield_on_cpp_arg_ids: list[int] | None = None, checkpoint_user_labels_by_cp_id: list[int | None] | None = None, ) -> None: - # `checkpoint_user_label_enum_qualnames` is derived from `checkpoint_user_labels_by_cp_id` here (rather than - # being plumbed through a separate kwarg from `Kernel.materialize`) so callers never have to think about the - # parallel column: they pass the live label list (which still holds the original ``IntEnum`` instances at - # store time, before pydantic's int-coercion strips identity in ``CacheValue.__init__``), and the qualname - # snapshot is recorded once here for the loader to consume. + # `checkpoint_user_label_enum_qualnames` is derived from `checkpoint_user_labels_by_cp_id` here (rather than being + # plumbed through a separate kwarg from `Kernel.materialize`) so callers never have to think about the parallel + # column: they pass the live label list (which still holds the original ``IntEnum`` instances at store time, before + # pydantic's int-coercion strips identity in ``CacheValue.__init__``), and the qualname snapshot is recorded once + # here for the loader to consume. """ Note that unlike other caches, this cache is not going to store the actual value we want. This cache is only used for verification that our cache key is valid. Big picture: @@ -232,9 +232,9 @@ def _try_load(cache_key: str) -> CacheValue | None: def load(cache_key: str) -> CacheValue | None: """Load a validated ``CacheValue`` for *cache_key* if one exists and its source hashes still match, else None. - Returns the full ``CacheValue`` (rather than the historical 3-tuple) so callers can pick off the - AST-transformer-produced metadata (graph_do_while levels, checkpoint tables) without the loader having to grow - a new return slot every time we cache a new piece of AST output. + Returns the full ``CacheValue`` (rather than the historical 3-tuple) so callers can pick off the AST-transformer- + produced metadata (graph_do_while levels, checkpoint tables) without the loader having to grow a new return slot + every time we cache a new piece of AST output. """ cache_value = _try_load(cache_key) if cache_value is None: diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 4a9d3e1883..cd6a85e33d 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -1355,10 +1355,10 @@ def build_For(ctx: ASTTransformerFuncContext, node: ast.For) -> None: def _is_graph_do_while_call(node: ast.expr) -> ast.expr | None: """If *node* is ``qd.graph_do_while(arg)`` return the arg AST node, else None. - ``arg`` may be an ``ast.Name`` (a bare kernel parameter, e.g. ``counter``) or an ``ast.Attribute`` chain - (a ``@qd.data_oriented`` member ndarray such as ``self.counter`` or a ``@dataclasses.dataclass`` parameter - member such as ``params.counter``). The actual resolution to a kernel ndarray argument happens in - ``build_While`` via ``_resolve_ndarray_kernel_arg_id``. + ``arg`` may be an ``ast.Name`` (a bare kernel parameter, e.g. ``counter``) or an ``ast.Attribute`` chain (a + ``@qd.data_oriented`` member ndarray such as ``self.counter`` or a ``@dataclasses.dataclass`` parameter member + such as ``params.counter``). The actual resolution to a kernel ndarray argument happens in ``build_While`` via + ``_resolve_ndarray_kernel_arg_id``. """ if not isinstance(node, ast.Call): return None @@ -1379,10 +1379,10 @@ def _resolve_ndarray_kernel_arg_id( node: ast.expr, usage: str, ) -> tuple[str, int]: - """Thin forwarding wrapper around ``ndarray_arg_resolver.resolve_ndarray_kernel_arg_id``; the actual logic - lives in module ``ast_transformers/ndarray_arg_resolver.py`` to keep this file from growing per-feature - (same pattern as ``_is_checkpoint_call`` / ``CheckpointTransformer``). Returns ``(label, flat_cpp_arg_id)`` - or raises ``QuadrantsSyntaxError``.""" + """Thin forwarding wrapper around ``ndarray_arg_resolver.resolve_ndarray_kernel_arg_id``; the actual logic lives + in module ``ast_transformers/ndarray_arg_resolver.py`` to keep this file from growing per-feature (same pattern + as ``_is_checkpoint_call`` / ``CheckpointTransformer``). Returns ``(label, flat_cpp_arg_id)`` or raises + ``QuadrantsSyntaxError``.""" # pylint: disable-next=C0415,import-outside-toplevel from quadrants.lang.ast.ast_transformers.ndarray_arg_resolver import ( resolve_ndarray_kernel_arg_id, @@ -1417,10 +1417,10 @@ def build_While(ctx: ASTTransformerFuncContext, node: ast.While) -> None: "qd.graph_do_while() must be at the kernel top level or directly nested inside " "another qd.graph_do_while(); it cannot appear inside a for-loop." ) - # Resolve the condition ndarray (bare parameter or @qd.data_oriented member) to its flat C++ arg-id at - # AST-build time -- the same id the runtime needs -- so the launch path forwards it directly with no - # per-launch name matching. ``cond_arg_name`` keeps the readable label (e.g. "counter" or "self.counter") - # for introspection and for the legacy ``graph_do_while_arg`` alias surfaced on Kernel. + # Resolve the condition ndarray (bare parameter or @qd.data_oriented member) to its flat C++ arg-id at AST- + # build time -- the same id the runtime needs -- so the launch path forwards it directly with no per-launch + # name matching. ``cond_arg_name`` keeps the readable label (e.g. "counter" or "self.counter") for + # introspection and for the legacy ``graph_do_while_arg`` alias surfaced on Kernel. cond_label, cond_cpp_arg_id = ASTTransformer._resolve_ndarray_kernel_arg_id( ctx, kernel, graph_do_while_node, "qd.graph_do_while(...)" ) diff --git a/python/quadrants/lang/ast/ast_transformers/checkpoint_transformer.py b/python/quadrants/lang/ast/ast_transformers/checkpoint_transformer.py index 3e092bb2e8..251537347f 100644 --- a/python/quadrants/lang/ast/ast_transformers/checkpoint_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/checkpoint_transformer.py @@ -38,12 +38,12 @@ class CheckpointCallInfo: - ``cp_id``: the user-supplied label (an ``int`` or ``IntEnum`` value), or ``None`` for an auto-wrap implicit checkpoint. - - ``yield_on_node``: the ``ast.expr`` passed as ``yield_on=`` -- either an ``ast.Name`` (bare kernel parameter, - e.g. ``flag``) or an ``ast.Attribute`` chain (e.g. ``self.flag`` for a ``@qd.data_oriented`` owner, or - ``params.flag`` where ``params`` is a ``@dataclasses.dataclass`` kernel parameter). ``None`` for an implicit - checkpoint. ``build_checkpoint_with`` resolves the node to a flat C++ arg-id via - ``ASTTransformer._resolve_ndarray_kernel_arg_id`` so the runtime can forward it directly without per-launch - name matching. + - ``yield_on_node``: the ``ast.expr`` passed as ``yield_on=`` -- either an ``ast.Name`` (bare kernel parameter, e.g. + ``flag``) or an ``ast.Attribute`` chain (e.g. ``self.flag`` for a ``@qd.data_oriented`` owner, or ``params.flag`` + where ``params`` is a ``@dataclasses.dataclass`` kernel parameter). ``None`` for an implicit checkpoint. + ``build_checkpoint_with`` resolves the node to a flat C++ arg-id via + ``ASTTransformer._resolve_ndarray_kernel_arg_id`` so the runtime can forward it directly without per-launch name + matching. - ``is_implicit``: ``True`` iff this Call was synthesised by ``auto_wrap_for_loops``. """ @@ -179,8 +179,8 @@ def is_checkpoint_call(node: ast.expr, global_vars: dict) -> CheckpointCallInfo ) # `yield_on=` must point at an ndarray kernel argument -- a bare parameter (`yield_on=flag`), a # `@qd.data_oriented` member (`yield_on=self.flag`), or a `@dataclasses.dataclass` parameter member - # (`yield_on=params.flag`). Other expressions can't be lowered to a flat arg-id and are rejected here so - # the user gets a clear compile-time error at the `with` site. + # (`yield_on=params.flag`). Other expressions can't be lowered to a flat arg-id and are rejected here so the + # user gets a clear compile-time error at the `with` site. if not isinstance(yield_on_arg, (ast.Name, ast.Attribute)): raise QuadrantsSyntaxError( "qd.checkpoint(yield_on=...) must reference a kernel ndarray argument -- e.g. `yield_on=flag` for " @@ -228,12 +228,12 @@ def build_checkpoint_with( yield_on_label: str | None = None yield_on_cpp_arg_id: int = -1 if not info.is_implicit: - # Resolve `yield_on=` (a bare parameter or `@qd.data_oriented` / `@dataclasses.dataclass` member - # ndarray) to its flat C++ arg-id at AST-build time. ``resolve_ndarray_kernel_arg_id`` raises a - # user-facing ``QuadrantsSyntaxError`` if the expression does not name a real ndarray kernel argument, - # which keeps the diagnostic at the `with` site instead of leaking into the launcher. Both the label - # (for ``checkpoint_yield_on_args`` / introspection) and the resolved arg-id (for the runtime) are - # stashed and forwarded to the launch path below. + # Resolve `yield_on=` (a bare parameter or `@qd.data_oriented` / `@dataclasses.dataclass` member ndarray) to + # its flat C++ arg-id at AST-build time. ``resolve_ndarray_kernel_arg_id`` raises a user-facing + # ``QuadrantsSyntaxError`` if the expression does not name a real ndarray kernel argument, which keeps the + # diagnostic at the `with` site instead of leaking into the launcher. Both the label (for + # ``checkpoint_yield_on_args`` / introspection) and the resolved arg-id (for the runtime) are stashed and + # forwarded to the launch path below. # pylint: disable-next=C0415,import-outside-toplevel from quadrants.lang.ast.ast_transformers.ndarray_arg_resolver import ( resolve_ndarray_kernel_arg_id, diff --git a/python/quadrants/lang/ast/ast_transformers/ndarray_arg_resolver.py b/python/quadrants/lang/ast/ast_transformers/ndarray_arg_resolver.py index 590b40a570..21f328a93b 100644 --- a/python/quadrants/lang/ast/ast_transformers/ndarray_arg_resolver.py +++ b/python/quadrants/lang/ast/ast_transformers/ndarray_arg_resolver.py @@ -31,25 +31,25 @@ def resolve_ndarray_kernel_arg_id( ) -> tuple[str, int]: """Resolve ``node`` to ``(label, flat_cpp_arg_id)`` at AST-build time. - Shared between ``qd.checkpoint(yield_on=...)`` and ``qd.graph_do_while(...)`` to turn the control-flag argument - into the flat C++ arg-id the runtime matches against. ``node`` is an ``ast.Name`` (a bare kernel parameter, - e.g. ``flag``) or an ``ast.Attribute`` chain (e.g. ``self.flag`` for a ``@qd.data_oriented`` owner, or - ``params.flag`` where ``params`` is a ``@dataclasses.dataclass`` kernel parameter). We build the expression - through the normal AST machinery and read the arg-id off the resulting external-tensor expression -- this - unifies the bare-param and member-ndarray cases, since both flatten to a real ndarray kernel argument carrying - its arg-id on the ``ExternalTensorExpression``. + Shared between ``qd.checkpoint(yield_on=...)`` and ``qd.graph_do_while(...)`` to turn the control-flag argument into + the flat C++ arg-id the runtime matches against. ``node`` is an ``ast.Name`` (a bare kernel parameter, e.g. + ``flag``) or an ``ast.Attribute`` chain (e.g. ``self.flag`` for a ``@qd.data_oriented`` owner, or ``params.flag`` + where ``params`` is a ``@dataclasses.dataclass`` kernel parameter). We build the expression through the normal AST + machinery and read the arg-id off the resulting external-tensor expression -- this unifies the bare-param and + member-ndarray cases, since both flatten to a real ndarray kernel argument carrying its arg-id on the + ``ExternalTensorExpression``. ``usage`` is the call form (e.g. ``"qd.checkpoint(yield_on=...)"``) used in the error message. Raises ``QuadrantsSyntaxError`` if the expression does not resolve to an ndarray kernel argument. """ # Local imports to avoid an ast_transformers -> ast_transformer / any_array import cycle at module load: - # ``ast_transformer`` is the central transformer module that imports ``checkpoint_transformer`` (sibling of - # this file), and ``any_array`` pulls in core ndarray bindings that aren't needed for module import. - # pylint: disable-next=C0415,import-outside-toplevel + # ``ast_transformer`` is the central transformer module that imports ``checkpoint_transformer`` (sibling of this + # file), and ``any_array`` pulls in core ndarray bindings that aren't needed for module import. + # pylint: disable=import-outside-toplevel + from quadrants.lang.any_array import AnyArray from quadrants.lang.ast.ast_transformer import _qd_core, build_stmt - # pylint: disable-next=C0415,import-outside-toplevel - from quadrants.lang.any_array import AnyArray + # pylint: enable=import-outside-toplevel label = ast.unparse(node) bad = QuadrantsSyntaxError( diff --git a/python/quadrants/lang/kernel.py b/python/quadrants/lang/kernel.py index 9d7efe84f3..ecee93e0aa 100644 --- a/python/quadrants/lang/kernel.py +++ b/python/quadrants/lang/kernel.py @@ -349,14 +349,14 @@ def __init__(self, _func: Callable, autodiff_mode: AutodiffMode, _is_classkernel # kernel uses no checkpoints. Used for error messages / introspection only -- the runtime forwards the flat # C++ arg-id from `checkpoint_yield_on_cpp_arg_ids` below. self.checkpoint_yield_on_args: list[str | None] = [] - # Flat C++ arg-ids (post-template) of each explicit checkpoint's `yield_on=` ndarray, resolved at AST-build - # time by `CheckpointTransformer.build_checkpoint_with` via `ASTTransformer._resolve_ndarray_kernel_arg_id`. - # Same indexing as `checkpoint_yield_on_args`: entry `i` is the flat arg-id the runtime uses to look up the - # ndarray's device pointer for the checkpoint whose internal cp_id is `i`. `-1` for implicit checkpoints - # (which never yield). Resolving at AST-build time uniformly handles bare kernel parameters - # (`yield_on=flag`), `@qd.data_oriented` member ndarrays (`yield_on=self.flag`), and - # `@dataclasses.dataclass` parameter members (`yield_on=params.flag`); the attribute forms cannot be - # resolved by the per-launch name match because `arg_metas[i].name` only carries top-level parameter names. + # Flat C++ arg-ids (post-template) of each explicit checkpoint's `yield_on=` ndarray, resolved at AST-build time + # by `CheckpointTransformer.build_checkpoint_with` via `ASTTransformer._resolve_ndarray_kernel_arg_id`. Same + # indexing as `checkpoint_yield_on_args`: entry `i` is the flat arg-id the runtime uses to look up the ndarray's + # device pointer for the checkpoint whose internal cp_id is `i`. `-1` for implicit checkpoints (which never + # yield). Resolving at AST-build time uniformly handles bare kernel parameters (`yield_on=flag`), + # `@qd.data_oriented` member ndarrays (`yield_on=self.flag`), and `@dataclasses.dataclass` parameter members + # (`yield_on=params.flag`); the attribute forms cannot be resolved by the per-launch name match because + # `arg_metas[i].name` only carries top-level parameter names. self.checkpoint_yield_on_cpp_arg_ids: list[int] = [] # User-facing labels for explicit checkpoints. Same indexing as `checkpoint_yield_on_args`: entry `i` is the int # (or IntEnum value) the user passed as the first positional arg of `qd.checkpoint(cp_id, yield_on)` for the @@ -425,10 +425,10 @@ def _try_load_fastcache(self, args: tuple[Any, ...], key: "CompiledKernelKeyType if self.compiled_kernel_data_by_key[key]: self.src_ll_cache_observations.cache_loaded = True self.used_py_dataclass_parameters_by_key_enforcing[key] = cache_value.used_py_dataclass_parameters - # Fast-cache restore skips AST transformation, so rebuild the AST-transformer-produced metadata - # from the cache value: nested graph_do_while level table (with the AST-resolved flat C++ arg-id) - # plus the per-checkpoint yield_on / user-label tables. Mirrors what - # `function_def_transformer.py` + `checkpoint_transformer.py` + `build_While` would have written. + # Fast-cache restore skips AST transformation, so rebuild the AST-transformer-produced metadata from + # the cache value: nested graph_do_while level table (with the AST-resolved flat C++ arg-id) plus + # the per-checkpoint yield_on / user-label tables. Mirrors what `function_def_transformer.py` + + # `checkpoint_transformer.py` + `build_While` would have written. if cache_value.graph_do_while_levels: self.graph_do_while_levels = [ GraphDoWhileLevel(cond_arg_name=name, parent_id=parent, cond_cpp_arg_id=cpp_arg_id) @@ -438,13 +438,13 @@ def _try_load_fastcache(self, args: tuple[Any, ...], key: "CompiledKernelKeyType if cache_value.checkpoint_yield_on_args: self.checkpoint_yield_on_args = list(cache_value.checkpoint_yield_on_args) self.checkpoint_yield_on_cpp_arg_ids = list(cache_value.checkpoint_yield_on_cpp_arg_ids) - # Pydantic coerces IntEnum -> int at CacheValue construction time, so the raw labels are - # plain ints after JSON round-trip. ``checkpoint_user_label_enum_qualnames`` carries the - # parallel ``module.ClassQualName.MEMBER`` strings that ``_resolve_intenum_member`` uses - # to rebuild the original ``IntEnum`` member -- preserving the documented contract that + # Pydantic coerces IntEnum -> int at CacheValue construction time, so the raw labels are plain + # ints after JSON round-trip. ``checkpoint_user_label_enum_qualnames`` carries the parallel + # ``module.ClassQualName.MEMBER`` strings that ``_resolve_intenum_member`` uses to rebuild the + # original ``IntEnum`` member -- preserving the documented contract that # ``qd.checkpoint(Stage.X, ...)`` surfaces as ``Stage.X`` (not the raw int) on - # ``status.checkpoint``. Older v3 caches predate the qualname column, so we default any - # missing slots to ``None`` -> raw-int fallback (the same behaviour they had on v3). + # ``status.checkpoint``. Older v3 caches predate the qualname column, so we default any missing + # slots to ``None`` -> raw-int fallback (the same behaviour they had on v3). raw_labels = list(cache_value.checkpoint_user_labels_by_cp_id) qualnames = list(cache_value.checkpoint_user_label_enum_qualnames) or [None] * len(raw_labels) if len(qualnames) != len(raw_labels): diff --git a/tests/python/quadrants/lang/fast_caching/test_src_hasher.py b/tests/python/quadrants/lang/fast_caching/test_src_hasher.py index 2dd6dd6683..b2131e5817 100644 --- a/tests/python/quadrants/lang/fast_caching/test_src_hasher.py +++ b/tests/python/quadrants/lang/fast_caching/test_src_hasher.py @@ -151,12 +151,12 @@ def assert_not_loaded(cache_key: str) -> None: def test_src_hasher_store_validate_round_trips_schema_v3_metadata( monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path, temporary_module ) -> None: - """Schema v3 (`cachevalue-v3-ast-resolved-ids`) added AST-resolved arg-id fields to the persisted ``CacheValue`` - so the launch path can forward them after a fast-cache restore (which skips AST transformation). This test - pins the round-trip for the new fields -- ``graph_do_while_levels`` as 3-tuples carrying ``cond_cpp_arg_id``, - plus ``checkpoint_yield_on_args`` / ``checkpoint_yield_on_cpp_arg_ids`` / ``checkpoint_user_labels_by_cp_id``. - Without this, a schema bug (wrong tuple arity, dropped field, mis-typed BaseModel default) would only surface - via a hard-to-debug functional regression in a fast-cached checkpoint / graph_do_while kernel.""" + """Schema v3 (`cachevalue-v3-ast-resolved-ids`) added AST-resolved arg-id fields to the persisted ``CacheValue`` so + the launch path can forward them after a fast-cache restore (which skips AST transformation). This test pins the + round-trip for the new fields -- ``graph_do_while_levels`` as 3-tuples carrying ``cond_cpp_arg_id``, plus + ``checkpoint_yield_on_args`` / ``checkpoint_yield_on_cpp_arg_ids`` / ``checkpoint_user_labels_by_cp_id``. Without + this, a schema bug (wrong tuple arity, dropped field, mis-typed BaseModel default) would only surface via a hard-to- + debug functional regression in a fast-cached checkpoint / graph_do_while kernel.""" test_files_path = pathlib.Path("tests/python/quadrants/lang/fast_caching/test_files") offline_cache_path = tmp_path / "cache" @@ -204,12 +204,12 @@ def test_src_hasher_store_validate_round_trips_schema_v3_metadata( def test_src_hasher_intenum_qualname_round_trip( monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path, temporary_module ) -> None: - """Schema v4 (`cachevalue-v4-intenum-qualnames`) added a parallel `checkpoint_user_label_enum_qualnames` - column so an ``IntEnum`` cp_id round-trips through fast-cache restore as the original enum member rather than - the underlying int. ``src_hasher.store`` derives the qualname column from the live label list (which still - holds the original ``IntEnum`` instances) before pydantic int-coerces them; ``_resolve_intenum_member`` - re-imports the enum class on load. This test covers both the store-side derivation (mixed IntEnum / plain int - / None) and the load-side resolution (verifies identity is preserved, not just int equality).""" + """Schema v4 (`cachevalue-v4-intenum-qualnames`) added a parallel `checkpoint_user_label_enum_qualnames` column so + an ``IntEnum`` cp_id round-trips through fast-cache restore as the original enum member rather than the underlying + int. ``src_hasher.store`` derives the qualname column from the live label list (which still holds the original + ``IntEnum`` instances) before pydantic int-coerces them; ``_resolve_intenum_member`` re-imports the enum class on + load. This test covers both the store-side derivation (mixed IntEnum / plain int / None) and the load-side + resolution (verifies identity is preserved, not just int equality).""" test_files_path = pathlib.Path("tests/python/quadrants/lang/fast_caching/test_files") offline_cache_path = tmp_path / "cache" temp_import_path = tmp_path / "temp_import" @@ -243,8 +243,8 @@ def test_src_hasher_intenum_qualname_round_trip( f"{_HasherTestStage.__module__}.{_HasherTestStage.__qualname__}.REDUCE", ] - # Resolver round-trip: rebuild each slot through `_resolve_intenum_member` and confirm enum identity (not - # just int-equality) is preserved. + # Resolver round-trip: rebuild each slot through `_resolve_intenum_member` and confirm enum identity (not just int- + # equality) is preserved. resolved = [ src_hasher._resolve_intenum_member(qn, lbl) for qn, lbl in zip(loaded.checkpoint_user_label_enum_qualnames, loaded.checkpoint_user_labels_by_cp_id) @@ -253,14 +253,14 @@ def test_src_hasher_intenum_qualname_round_trip( assert isinstance(resolved[0], _HasherTestStage) assert isinstance(resolved[2], _HasherTestStage) - # Resolver fallback: an unresolvable qualname (e.g. enum class moved/renamed since cache write) must drop - # back to the persisted int rather than raising, so a stale cache entry degrades gracefully. + # Resolver fallback: an unresolvable qualname (e.g. enum class moved/renamed since cache write) must drop back to + # the persisted int rather than raising, so a stale cache entry degrades gracefully. assert src_hasher._resolve_intenum_member("nonexistent.Module.Stage.LOAD", 5) == 5 # Top-level IntEnum used by `test_src_hasher_intenum_qualname_round_trip` so the resolver can re-import it via -# `importlib.import_module("tests.python.quadrants.lang.fast_caching.test_src_hasher")`. Lives at module scope -# (not inside the test) for the same reason `_FastcacheStage` / `_Stage` do in `test_checkpoint.py`. +# `importlib.import_module("tests.python.quadrants.lang.fast_caching.test_src_hasher")`. Lives at module scope (not +# inside the test) for the same reason `_FastcacheStage` / `_Stage` do in `test_checkpoint.py`. class _HasherTestStage(IntEnum): LOAD = 5 REDUCE = 9 diff --git a/tests/python/test_checkpoint.py b/tests/python/test_checkpoint.py index 4f8769d519..e65c782f59 100644 --- a/tests/python/test_checkpoint.py +++ b/tests/python/test_checkpoint.py @@ -233,8 +233,8 @@ def k(x: qd.types.ndarray(qd.i32, ndim=1), flag: qd.types.ndarray(qd.i32, ndim=0 @test_utils.test() def test_checkpoint_yield_on_must_be_name_or_attribute(): """``yield_on=`` must reference an ndarray kernel argument -- either a bare ``ast.Name`` or an ``ast.Attribute`` - chain (for ``@qd.data_oriented`` member ndarrays). Arbitrary expressions are not supported; pinning the - diagnostic so the user knows to refactor.""" + chain (for ``@qd.data_oriented`` member ndarrays). Arbitrary expressions are not supported; pinning the diagnostic + so the user knows to refactor.""" @qd.kernel(graph=True, checkpoints=True) def k(x: qd.types.ndarray(qd.i32, ndim=1), flag: qd.types.ndarray(qd.i32, ndim=0)): @@ -859,18 +859,17 @@ def k( # ``qd.checkpoint(yield_on=self.flag)`` and ``qd.checkpoint(yield_on=params.flag)`` (where ``params`` is a kernel # parameter typed as a ``@dataclasses.dataclass``) both resolve the member ndarray to a flat C++ arg-id at AST-build # time via ``ASTTransformer._resolve_ndarray_kernel_arg_id``: it builds the expression and reads the resolved -# ``ExternalTensorExpression.arg_id``, so any attribute chain that ends up as a kernel ndarray arg works the same -# way as a bare parameter name. This frees users from having to forward flag members as bare kernel parameters when -# the rest of the kernel already operates on the dataclass / data-oriented owner. +# ``ExternalTensorExpression.arg_id``, so any attribute chain that ends up as a kernel ndarray arg works the same way +# as a bare parameter name. This frees users from having to forward flag members as bare kernel parameters when the +# rest of the kernel already operates on the dataclass / data-oriented owner. # ---------------------------------------------------------------------------------------------------------------------- @test_utils.test() def test_checkpoint_yield_on_data_oriented_member_metadata(): """`yield_on=self.flag` is accepted and the resolved label is stored verbatim (``"self.flag"``) in - ``checkpoint_yield_on_args``, while ``checkpoint_yield_on_cpp_arg_ids`` carries the flat C++ arg-id the - runtime forwards to the launch context. Verifies the AST-build-time resolution path without booting the - backend.""" + ``checkpoint_yield_on_args``, while ``checkpoint_yield_on_cpp_arg_ids`` carries the flat C++ arg-id the runtime + forwards to the launch context. Verifies the AST-build-time resolution path without booting the backend.""" N = 4 @qd.data_oriented @@ -898,8 +897,8 @@ def step(self): @test_utils.test() def test_checkpoint_yield_on_dataclass_member_metadata(): - """`yield_on=params.flag` for a ``@dataclasses.dataclass`` kernel parameter takes the same AST-build-time - resolution path as ``self.flag`` for a ``@qd.data_oriented`` owner -- the resolved label round-trips into + """`yield_on=params.flag` for a ``@dataclasses.dataclass`` kernel parameter takes the same AST-build-time resolution + path as ``self.flag`` for a ``@qd.data_oriented`` owner -- the resolved label round-trips into ``checkpoint_yield_on_args`` and the flat arg-id lands in ``checkpoint_yield_on_cpp_arg_ids``.""" import dataclasses # pylint: disable=import-outside-toplevel @@ -926,9 +925,9 @@ def step(params: Params): np.testing.assert_array_equal(params.x.to_numpy(), np.ones(N, dtype=np.int32)) assert step._primal.checkpoint_user_labels_by_cp_id == [0] # Dataclass-parameter member access gets pre-rewritten by the AST pipeline to a flattened parameter name - # (`__qd_params__qd_flag`) before the checkpoint transformer sees it, so the label round-trips in the - # flattened form. The functional contract -- a valid flat C++ arg-id is resolved and the kernel mutates the - # right ndarray -- is the same as for the bare-param / `self.flag` forms. + # (`__qd_params__qd_flag`) before the checkpoint transformer sees it, so the label round-trips in the flattened + # form. The functional contract -- a valid flat C++ arg-id is resolved and the kernel mutates the right ndarray -- + # is the same as for the bare-param / `self.flag` forms. labels = step._primal.checkpoint_yield_on_args assert len(labels) == 1 and labels[0] is not None and "flag" in labels[0] cpp_ids = step._primal.checkpoint_yield_on_cpp_arg_ids @@ -938,9 +937,9 @@ def step(params: Params): @test_utils.test() def test_checkpoint_yield_on_dataclass_member_yields_and_resumes(): """Behavioural round-trip for `yield_on=params.flag` -- mirror of the `self.flag` test below, using a - `@dataclasses.dataclass` kernel parameter instead of a `@qd.data_oriented` owner. The dataclass-member access - is pre-rewritten to a flattened parameter, so verifying the full yield/resume contract end-to-end is the only - way to confirm the right ndarray is wired up at launch.""" + `@dataclasses.dataclass` kernel parameter instead of a `@qd.data_oriented` owner. The dataclass-member access is + pre-rewritten to a flattened parameter, so verifying the full yield/resume contract end-to-end is the only way to + confirm the right ndarray is wired up at launch.""" import dataclasses # pylint: disable=import-outside-toplevel if not _supports_checkpoint_yield_resume(): @@ -974,8 +973,8 @@ def step(params: Params): np.testing.assert_array_equal(params.x.to_numpy(), np.ones(N, dtype=np.int32)) params.flag.from_numpy(np.array(0, dtype=np.int32)) # `step` is a free-function kernel (not a bound class kernel), so `params` must be passed positionally to - # `resume` -- the data_oriented sibling test above can omit it because the dataclass member access is - # implicit through `sim.step`'s bound `self`. + # `resume` -- the data_oriented sibling test above can omit it because the dataclass member access is implicit + # through `sim.step`'s bound `self`. status = step.resume(params, from_checkpoint=8) assert not status.yielded np.testing.assert_array_equal(params.x.to_numpy(), np.full(N, 11, dtype=np.int32)) @@ -983,10 +982,10 @@ def step(params: Params): @test_utils.test() def test_checkpoint_yield_on_member_nonexistent_attribute_raises(): - """`yield_on=self.nonexistent_attr` (attribute does not exist on the `@qd.data_oriented` owner) must raise a - user-facing `QuadrantsSyntaxError` at the `with` site -- the AST-time resolver wraps the underlying attribute - lookup failure in the same `does not resolve to an ndarray kernel parameter` diagnostic as the bare-name - nonexistent case, so users see one consistent error pattern.""" + """`yield_on=self.nonexistent_attr` (attribute does not exist on the `@qd.data_oriented` owner) must raise a user- + facing `QuadrantsSyntaxError` at the `with` site -- the AST-time resolver wraps the underlying attribute lookup + failure in the same `does not resolve to an ndarray kernel parameter` diagnostic as the bare-name nonexistent case, + so users see one consistent error pattern.""" N = 4 @qd.data_oriented @@ -1008,10 +1007,10 @@ def step(self): @test_utils.test() def test_checkpoint_yield_on_member_non_ndarray_attribute_raises(): - """`yield_on=self.scalar` where `self.scalar` is a Python int (not an ndarray) must raise the same - `does not resolve to an ndarray kernel parameter` diagnostic -- the AST-time resolver builds the expression - but rejects it because the resulting Expr is not an `ExternalTensorExpression`. Pinning this so future - refactors of the resolver can't silently accept non-ndarray attributes and crash later in the launcher.""" + """`yield_on=self.scalar` where `self.scalar` is a Python int (not an ndarray) must raise the same `does not resolve + to an ndarray kernel parameter` diagnostic -- the AST-time resolver builds the expression but rejects it because the + resulting Expr is not an `ExternalTensorExpression`. Pinning this so future refactors of the resolver can't silently + accept non-ndarray attributes and crash later in the launcher.""" N = 4 @qd.data_oriented @@ -1034,8 +1033,8 @@ def step(self): @test_utils.test() def test_checkpoint_yield_on_data_oriented_member_yields_and_resumes(): """Behavioural round-trip for `yield_on=self.flag`: setting the member flag from inside the kernel yields, and - ``kernel.resume(from_checkpoint=...)`` skips ahead to the named checkpoint. Same surface contract as the bare- - parameter form (`test_checkpoint_yield_on_yields_and_resumes`); the only difference is where the flag lives.""" + ``kernel.resume(from_checkpoint=...)`` skips ahead to the named checkpoint. Same surface contract as the + bare-parameter form (`test_checkpoint_yield_on_yields_and_resumes`); the only difference is where the flag lives.""" if not _supports_checkpoint_yield_resume(): pytest.skip("backend does not implement checkpoint yield/resume") N = 4 @@ -1073,8 +1072,8 @@ def step(self): # Module-level kernel for the fastcache-restoration test below. Lives outside any test so the child subprocess can # import the test module and reach it without re-creating the (closure-captured) outer scope. The kernel has to be -# annotated with `fastcache=True` (=> implies `pure`) and lifted out of any decorator-bound owner so it qualifies -# for the src_ll_cache path. We model the data_oriented owner as the `_FastcacheYieldOnSelfCheckpoint` class below. +# annotated with `fastcache=True` (=> implies `pure`) and lifted out of any decorator-bound owner so it qualifies for +# the src_ll_cache path. We model the data_oriented owner as the `_FastcacheYieldOnSelfCheckpoint` class below. @qd.data_oriented @@ -1114,9 +1113,9 @@ def _fastcache_checkpoint_child(args: list[str]) -> None: primal = type(sim).step._primal # The schema-v3 fast-cache restore path must repopulate `checkpoint_yield_on_args` and - # `checkpoint_yield_on_cpp_arg_ids` from the cached `CacheValue` (since AST transformation is skipped on a - # cache hit). A regression here would surface as an empty `_forward_yield_on_table_to_ctx` call, silently - # breaking yield/resume on fast-cached checkpoint kernels. + # `checkpoint_yield_on_cpp_arg_ids` from the cached `CacheValue` (since AST transformation is skipped on a cache + # hit). A regression here would surface as an empty `_forward_yield_on_table_to_ctx` call, silently breaking + # yield/resume on fast-cached checkpoint kernels. labels = primal.checkpoint_yield_on_args cpp_ids = primal.checkpoint_yield_on_cpp_arg_ids assert ( @@ -1140,11 +1139,10 @@ def _fastcache_checkpoint_child(args: list[str]) -> None: @test_utils.test() def test_checkpoint_fastcache_restores_self_member_yield_on(tmp_path: pathlib.Path): """After a fast-cache restore in a fresh process, a `@qd.kernel(graph=True, checkpoints=True, fastcache=True)` - kernel with `yield_on=self.flag` must repopulate `checkpoint_yield_on_args` / - `checkpoint_yield_on_cpp_arg_ids` / `checkpoint_user_labels_by_cp_id` from the persisted ``CacheValue`` -- - not from the AST transformer, which is skipped on a cache hit. Without the schema-v3 round-trip the launch - path's `forward_yield_on_table_to_ctx` would be a no-op and yield/resume would silently break for fast-cached - checkpoint kernels.""" + kernel with `yield_on=self.flag` must repopulate `checkpoint_yield_on_args` / `checkpoint_yield_on_cpp_arg_ids` / + `checkpoint_user_labels_by_cp_id` from the persisted ``CacheValue`` -- not from the AST transformer, which is + skipped on a cache hit. Without the schema-v3 round-trip the launch path's `forward_yield_on_table_to_ctx` would + be a no-op and yield/resume would silently break for fast-cached checkpoint kernels.""" assert qd.lang is not None arch = qd.lang.impl.current_cfg().arch.name env = dict(os.environ) @@ -1207,8 +1205,8 @@ def _fastcache_intenum_child(args: list[str]) -> None: primal = _fastcache_intenum_kernel._primal labels = primal.checkpoint_user_labels_by_cp_id - # The schema-v4 round-trip must rebuild the IntEnum identity, not just the int equality. A regression here - # would show up as `labels == [10, 20]` (plain ints) breaking the documented contract that + # The schema-v4 round-trip must rebuild the IntEnum identity, not just the int equality. A regression here would + # show up as `labels == [10, 20]` (plain ints) breaking the documented contract that # `qd.checkpoint(Stage.X, ...)` surfaces as `Stage.X` (not the raw int) on `status.checkpoint`. assert labels == [ _FastcacheStage.LOAD, @@ -1225,12 +1223,12 @@ def _fastcache_intenum_child(args: list[str]) -> None: @test_utils.test() def test_checkpoint_fastcache_preserves_intenum_label_identity(tmp_path: pathlib.Path): - """Fast-cache restore must rebuild ``checkpoint_user_labels_by_cp_id`` with the original ``IntEnum`` members, - not just int-equal plain ints. Schema v4 adds a parallel ``checkpoint_user_label_enum_qualnames`` column so - ``_resolve_intenum_member`` can re-import the enum class on cache hit -- pydantic coerces ``IntEnum`` to - ``int`` at ``CacheValue`` construction, which would otherwise silently drop enum identity and break the - documented contract that ``qd.checkpoint(Stage.X, ...)`` surfaces as ``Stage.X`` (not the raw int) on - ``status.checkpoint`` after a fast-cache hit.""" + """Fast-cache restore must rebuild ``checkpoint_user_labels_by_cp_id`` with the original ``IntEnum`` members, not + just int-equal plain ints. Schema v4 adds a parallel ``checkpoint_user_label_enum_qualnames`` column so + ``_resolve_intenum_member`` can re-import the enum class on cache hit -- pydantic coerces ``IntEnum`` to ``int`` at + ``CacheValue`` construction, which would otherwise silently drop enum identity and break the documented contract + that ``qd.checkpoint(Stage.X, ...)`` surfaces as ``Stage.X`` (not the raw int) on ``status.checkpoint`` after a + fast-cache hit.""" assert qd.lang is not None arch = qd.lang.impl.current_cfg().arch.name env = dict(os.environ) @@ -1290,8 +1288,8 @@ def k(x: qd.types.ndarray(qd.i32, ndim=1), flag: qd.types.ndarray(qd.i32, ndim=0 # Subprocess dispatch for fast-cache restoration tests above (mirrors the pattern in `test_graph_do_while.py`). The -# parent test invokes us via `subprocess.run([sys.executable, __file__, , ])` so the -# child runs in a fresh interpreter with a clean `qd.init` -- the only way to exercise the cross-process fast-cache -# load path that ``Kernel._try_load_fastcache`` takes after a previous run has populated the on-disk cache. +# parent test invokes us via `subprocess.run([sys.executable, __file__, , ])` so the child +# runs in a fresh interpreter with a clean `qd.init` -- the only way to exercise the cross-process fast-cache load +# path that ``Kernel._try_load_fastcache`` takes after a previous run has populated the on-disk cache. if __name__ == "__main__": globals()[sys.argv[1]](sys.argv[2:]) diff --git a/tests/python/test_graph_do_while.py b/tests/python/test_graph_do_while.py index 18ca6a8ba7..e5859cd730 100644 --- a/tests/python/test_graph_do_while.py +++ b/tests/python/test_graph_do_while.py @@ -357,20 +357,19 @@ def step(params: Params): assert params.counter.to_numpy() == 0 levels = step._primal.graph_do_while_levels assert len(levels) == 1 - # Dataclass-parameter member access gets pre-rewritten to a flattened parameter name - # (`__qd_params__qd_counter`) before the graph_do_while transformer sees it, so the readable label round-trips - # in the flattened form. The functional contract -- a valid flat C++ arg-id resolves and the loop drives off - # the device-side counter -- is the same as for the bare-param / `self.counter` forms. + # Dataclass-parameter member access gets pre-rewritten to a flattened parameter name (`__qd_params__qd_counter`) + # before the graph_do_while transformer sees it, so the readable label round-trips in the flattened form. The + # functional contract -- a valid flat C++ arg-id resolves and the loop drives off the device-side counter -- is the + # same as for the bare-param / `self.counter` forms. assert "counter" in levels[0].cond_arg_name assert levels[0].cond_cpp_arg_id >= 0 @test_utils.test() def test_graph_do_while_with_member_nonexistent_attribute_raises(): - """`qd.graph_do_while(self.nonexistent_attr)` must raise the same user-facing - `does not resolve to an ndarray kernel parameter` diagnostic as the bare-name nonexistent case. The AST-time - resolver wraps the underlying attribute lookup failure so the user sees one consistent error pattern across - bare-name and attribute forms.""" + """`qd.graph_do_while(self.nonexistent_attr)` must raise the same user-facing `does not resolve to an ndarray kernel + parameter` diagnostic as the bare-name nonexistent case. The AST-time resolver wraps the underlying attribute lookup + failure so the user sees one consistent error pattern across bare-name and attribute forms.""" N = 4 @qd.data_oriented @@ -393,9 +392,9 @@ def step(self): @test_utils.test() def test_graph_do_while_with_data_oriented_member_nested(): """Nested `qd.graph_do_while(self.outer)` containing `qd.graph_do_while(self.inner)` exercises the level-table - machinery with member ndarrays: each level resolves its own flat C++ arg-id at AST-build time, the parent_id - chain links inner -> outer, and the loop body iterates `outer_iters * inner_iters` times the same as the - bare-parameter version (see `test_graph_do_while_nested_two_levels`).""" + machinery with member ndarrays: each level resolves its own flat C++ arg-id at AST-build time, the parent_id chain + links inner -> outer, and the loop body iterates `outer_iters * inner_iters` times the same as the bare-parameter + version (see `test_graph_do_while_nested_two_levels`).""" if not _is_graph_do_while_natively_supported() and not ( impl.current_cfg().arch in (qd.x64, qd.arm64, qd.amdgpu, qd.vulkan, qd.metal) ): @@ -446,8 +445,8 @@ def step(self): def test_graph_do_while_with_data_oriented_member_counter(): """`qd.graph_do_while(self.counter)` resolves the member ndarray to the loop condition's flat C++ arg-id at AST-build time via ``ASTTransformer._resolve_ndarray_kernel_arg_id``, lifting the previous bare-parameter - restriction. The metadata exposed on the kernel records the readable label (``"self.counter"``) plus the - resolved arg-id; the loop behaviour matches the bare-parameter form below.""" + restriction. The metadata exposed on the kernel records the readable label (``"self.counter"``) plus the resolved + arg-id; the loop behaviour matches the bare-parameter form below.""" N = 4 @qd.data_oriented