diff --git a/docs/source/user_guide/graph.md b/docs/source/user_guide/graph.md index b018a6194f..27cb346939 100644 --- a/docs/source/user_guide/graph.md +++ b/docs/source/user_guide/graph.md @@ -87,7 +87,7 @@ solve(x, counter) # x is now incremented 10 times; counter is 0 ``` -The argument to `qd.graph_do_while()` must be the name of a scalar `qd.i32` ndarray parameter. The loop body repeats while this value is non-zero. +The argument to `qd.graph_do_while()` must reference a scalar `qd.i32` ndarray that the kernel can access — a bare kernel parameter (`qd.graph_do_while(counter)`), a `@qd.data_oriented` member ndarray (`qd.graph_do_while(self.counter)`), or a `@dataclasses.dataclass` parameter member (`qd.graph_do_while(params.counter)`). The loop body repeats while this value is non-zero. - On CUDA SM 9.0+ (Hopper), this uses CUDA conditional while nodes — the entire iteration runs on the GPU with no host involvement. - On older CUDA GPUs, AMDGPU, and non-GPU backends, it falls back to a host-side do-while loop (see the [backend support table](#backend-support)). @@ -256,7 +256,7 @@ while status.yielded: - Must be used inside `@qd.kernel(graph=True, checkpoints=True)`. Without the flag, `qd.checkpoint(...)` raises `QuadrantsSyntaxError` at compile time with a fix-it pointing at `checkpoints=True`. - `cp_id` must be an int literal or an `IntEnum` value, and must be unique across the kernel. -- `yield_on=` must be a kernel parameter that is a 0-d `qd.types.ndarray(qd.i32, ndim=0)`; expressions are not supported. +- `yield_on=` must reference a 0-d `qd.types.ndarray(qd.i32, ndim=0)` — a bare kernel parameter (`yield_on=flag`), a `@qd.data_oriented` member ndarray (`yield_on=self.flag`), or a `@dataclasses.dataclass` parameter member (`yield_on=params.flag`). Arbitrary expressions are not supported. - Checkpoints cannot be nested inside other checkpoints. Checkpoints inside a `qd.graph_do_while` body are fine. 
- The body of a `with qd.checkpoint(...)` block cannot contain bare top-level statements (assignments, augmented assignments, or bare call/expression statements). Every top-level statement must be inside a `for`-loop (or other control-flow construct). A docstring as the first statement is allowed. Bare statements raise `QuadrantsSyntaxError` at compile time with a fix-it pointing at the explicit one-iteration `for`-wrap: diff --git a/python/quadrants/lang/_fast_caching/src_hasher.py b/python/quadrants/lang/_fast_caching/src_hasher.py index 76a6a0160a..fc443ac481 100644 --- a/python/quadrants/lang/_fast_caching/src_hasher.py +++ b/python/quadrants/lang/_fast_caching/src_hasher.py @@ -1,6 +1,8 @@ +import importlib import json import os import warnings +from enum import IntEnum from typing import Any, Iterable, Sequence import pydantic @@ -18,8 +20,77 @@ from .python_side_cache import PythonSideCache # Bumped whenever the persisted CacheValue schema changes (see create_cache_key). v2 replaced the single -# graph_do_while_arg string with a nested level table. -_CACHE_VALUE_SCHEMA_VERSION = "cachevalue-v2-gdw-levels" +# graph_do_while_arg string with a nested level table. v3 added the AST-resolved flat C++ arg-ids for qd.graph_do_while +# conditions and qd.checkpoint(yield_on=...) targets so the launch path can forward them directly without per-launch +# name matching (necessary for @qd.data_oriented member ndarrays). v4 added the per-slot +# `checkpoint_user_label_enum_qualnames` table so an IntEnum cp_id (e.g. `qd.checkpoint(Stage.SIM, ...)`) round-trips +# through fast-cache restore as the original IntEnum member rather than the underlying int. +_CACHE_VALUE_SCHEMA_VERSION = "cachevalue-v4-intenum-qualnames" + + +def _intenum_member_qualname(value: Any) -> str | None: + """Return ``"module.ClassQualName.MEMBER"`` for an ``IntEnum`` member, else ``None``. 
+ + Stored alongside ``checkpoint_user_labels_by_cp_id`` so that ``_resolve_intenum_member`` can rebuild the original + enum member on fast-cache restore -- pydantic coerces ``IntEnum`` to plain ``int`` at ``CacheValue`` construction + time (it sees ``list[int | None]``), which would otherwise silently break the documented contract that + ``qd.checkpoint(Stage.X, ...)`` round-trips ``Stage.X`` rather than the raw int through ``status.checkpoint``. + Returns ``None`` for plain ints, ``None`` labels, anonymous enums (no ``__module__``), and other unsupported + shapes -- the loader falls back to the raw int in those cases. + """ + if not isinstance(value, IntEnum): + return None + cls = type(value) + module = getattr(cls, "__module__", None) + qualname = getattr(cls, "__qualname__", None) + name = getattr(value, "name", None) + if not module or not qualname or not name: + return None + return f"{module}.{qualname}.{name}" + + +def _resolve_intenum_member(qualname: str | None, fallback: int | None) -> int | IntEnum | None: + """Inverse of ``_intenum_member_qualname``: look up the enum member by ``"module.ClassQualName.MEMBER"``. + + Returns the resolved ``IntEnum`` member if every step (module import, attribute walk) succeeds AND the member's int + value matches ``fallback`` (the raw int from ``checkpoint_user_labels_by_cp_id`` we already persisted). Mismatch or + any failure -- module renamed since the cache was written, enum class refactored, member removed, etc. -- falls back + to ``fallback`` so the user still gets a usable (if enum-identity-less) label rather than a hard crash. ``None`` + qualname / ``None`` fallback short-circuit to ``fallback`` for the plain-int label case. + """ + if qualname is None or fallback is None: + return fallback + try: + # qualname is "module.path.Class[.Nested].MEMBER"; the MEMBER tail is always one segment, so rsplit once. 
The + # remaining cls_path mixes dotted module path + dotted class qualname; we try progressively shorter module + # prefixes until one imports, then resolve the rest as attribute chain. This handles top-level enums + # (``mymod.Stage.LOAD``), enums nested in classes (``mymod.Outer.Inner.MEMBER``), and enums in subpackages + # (``a.b.Stage.LOAD``) without needing the user to declare which prefix is the module. + cls_path, _, member_name = qualname.rpartition(".") + if not cls_path or not member_name: + return fallback + module = None + cls_attr_path = "" + segments = cls_path.split(".") + for i in range(len(segments), 0, -1): + try: + module = importlib.import_module(".".join(segments[:i])) + cls_attr_path = ".".join(segments[i:]) + break + except ImportError: + continue + if module is None: + return fallback + obj: Any = module + if cls_attr_path: + for seg in cls_attr_path.split("."): + obj = getattr(obj, seg) + obj = getattr(obj, member_name) + except (AttributeError, ValueError): + return fallback + if isinstance(obj, IntEnum) and int(obj) == int(fallback): + return obj + return fallback def create_cache_key( @@ -69,9 +140,28 @@ class CacheValue(BaseModel): frontend_cache_key: str hashed_function_source_infos: list[HashedFunctionSourceInfo] used_py_dataclass_parameters: set[str] - # Nested graph_do_while level table as (cond_arg_name, parent_id) pairs, indexed by level id. None / empty for - # kernels without graph_do_while. - graph_do_while_levels: list[tuple[str, int]] | None = None + # Nested graph_do_while level table as (cond_arg_name, parent_id, cond_cpp_arg_id) triples, indexed by level id. + # None / empty for kernels without graph_do_while. 
``cond_cpp_arg_id`` is the flat C++ arg-id resolved at AST-build + # time by ``ASTTransformer._resolve_ndarray_kernel_arg_id`` and is required by the launch path to support + # `@qd.data_oriented` member conditions (`qd.graph_do_while(self.counter)`) -- name-matching against ``arg_metas`` + # only resolves top-level parameters. + graph_do_while_levels: list[tuple[str, int, int]] | None = None + # AST-build-time-resolved checkpoint metadata, indexed by internal cp_id. Empty for kernels without any + # `with qd.checkpoint(...)` block. See `Kernel.checkpoint_yield_on_args` / + # `Kernel.checkpoint_yield_on_cpp_arg_ids` / `Kernel.checkpoint_user_labels_by_cp_id` for what each entry means. + # Restored alongside the C++-side cached kernel so the launch path can forward `yield_on=` arg-ids and translate + # `from_checkpoint=` labels without re-running the AST transformer. + checkpoint_yield_on_args: list[str | None] = [] + checkpoint_yield_on_cpp_arg_ids: list[int] = [] + checkpoint_user_labels_by_cp_id: list[int | None] = [] + # Parallel to ``checkpoint_user_labels_by_cp_id``: each entry is the dotted ``module.ClassQualName.MEMBER`` of the + # original ``IntEnum`` member the user passed as ``cp_id``, or ``None`` if the user passed a plain int (or for + # implicit auto-wrap checkpoints). On fast-cache restore the loader runs each entry through + # ``_resolve_intenum_member`` to rebuild the IntEnum, preserving the documented contract that + # ``qd.checkpoint(Stage.X, ...)`` round-trips ``Stage.X`` (not the underlying int) through ``status.checkpoint`` and + # ``kernel.resume(from_checkpoint=...)`` -- pydantic coerces IntEnum to int at ``CacheValue`` construction time so + # the parallel qualname column is what carries the enum identity. 
+ checkpoint_user_label_enum_qualnames: list[str | None] = [] def store( @@ -79,8 +169,16 @@ def store( fast_cache_key: str, function_source_infos: Iterable[FunctionSourceInfo], used_py_dataclass_parameters: set[str], - graph_do_while_levels: list[tuple[str, int]] | None = None, + graph_do_while_levels: list[tuple[str, int, int]] | None = None, + checkpoint_yield_on_args: list[str | None] | None = None, + checkpoint_yield_on_cpp_arg_ids: list[int] | None = None, + checkpoint_user_labels_by_cp_id: list[int | None] | None = None, ) -> None: + # `checkpoint_user_label_enum_qualnames` is derived from `checkpoint_user_labels_by_cp_id` here (rather than being + # plumbed through a separate kwarg from `Kernel.materialize`) so callers never have to think about the parallel + # column: they pass the live label list (which still holds the original ``IntEnum`` instances at store time, before + # pydantic's int-coercion strips identity in ``CacheValue.__init__``), and the qualname snapshot is recorded once + # here for the loader to consume. """ Note that unlike other caches, this cache is not going to store the actual value we want. This cache is only used for verification that our cache key is valid. 
Big picture: @@ -103,11 +201,17 @@ def store( assert frontend_cache_key is not None cache = PythonSideCache() hashed_function_source_infos = function_hasher.hash_functions(function_source_infos) + labels = checkpoint_user_labels_by_cp_id or [] + enum_qualnames = [_intenum_member_qualname(lbl) for lbl in labels] cache_value_obj = CacheValue( frontend_cache_key=frontend_cache_key, hashed_function_source_infos=list(hashed_function_source_infos), used_py_dataclass_parameters=used_py_dataclass_parameters, graph_do_while_levels=graph_do_while_levels, + checkpoint_yield_on_args=checkpoint_yield_on_args or [], + checkpoint_yield_on_cpp_arg_ids=checkpoint_yield_on_cpp_arg_ids or [], + checkpoint_user_labels_by_cp_id=labels, + checkpoint_user_label_enum_qualnames=enum_qualnames, ) cache.store(fast_cache_key, cache_value_obj.model_dump_json()) @@ -125,23 +229,19 @@ def _try_load(cache_key: str) -> CacheValue | None: return cache_value_obj -def load( - cache_key: str, -) -> tuple[set[str], str, list[tuple[str, int]] | None] | tuple[None, None, None]: - """ - loads function source infos from cache, if available - checks the hashes against the current source code +def load(cache_key: str) -> CacheValue | None: + """Load a validated ``CacheValue`` for *cache_key* if one exists and its source hashes still match, else None. + + Returns the full ``CacheValue`` (rather than the historical 3-tuple) so callers can pick off the AST-transformer- + produced metadata (graph_do_while levels, checkpoint tables) without the loader having to grow a new return slot + every time we cache a new piece of AST output. 
""" cache_value = _try_load(cache_key) if cache_value is None: - return None, None, None + return None if function_hasher.validate_hashed_function_infos(cache_value.hashed_function_source_infos): - return ( - cache_value.used_py_dataclass_parameters, - cache_value.frontend_cache_key, - cache_value.graph_do_while_levels, - ) - return None, None, None + return cache_value + return None def dump_stats() -> None: diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 0297fffa8c..cd6a85e33d 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -1352,19 +1352,44 @@ def build_For(ctx: ASTTransformerFuncContext, node: ast.For) -> None: return ASTTransformer.build_struct_for(ctx, node, is_grouped=False) @staticmethod - def _is_graph_do_while_call(node: ast.expr) -> str | None: - """If *node* is ``qd.graph_do_while(var)`` return the arg name, else None.""" + def _is_graph_do_while_call(node: ast.expr) -> ast.expr | None: + """If *node* is ``qd.graph_do_while(arg)`` return the arg AST node, else None. + + ``arg`` may be an ``ast.Name`` (a bare kernel parameter, e.g. ``counter``) or an ``ast.Attribute`` chain (a + ``@qd.data_oriented`` member ndarray such as ``self.counter`` or a ``@dataclasses.dataclass`` parameter member + such as ``params.counter``). The actual resolution to a kernel ndarray argument happens in ``build_While`` via + ``_resolve_ndarray_kernel_arg_id``. 
+ """ if not isinstance(node, ast.Call): return None func = node.func - if isinstance(func, ast.Attribute) and func.attr == "graph_do_while": - if len(node.args) == 1 and isinstance(node.args[0], ast.Name): - return node.args[0].id - if isinstance(func, ast.Name) and func.id == "graph_do_while": - if len(node.args) == 1 and isinstance(node.args[0], ast.Name): - return node.args[0].id + is_gdw = (isinstance(func, ast.Attribute) and func.attr == "graph_do_while") or ( + isinstance(func, ast.Name) and func.id == "graph_do_while" + ) + if not is_gdw: + return None + if len(node.args) == 1 and isinstance(node.args[0], (ast.Name, ast.Attribute)): + return node.args[0] return None + @staticmethod + def _resolve_ndarray_kernel_arg_id( + ctx: ASTTransformerFuncContext, + kernel, + node: ast.expr, + usage: str, + ) -> tuple[str, int]: + """Thin forwarding wrapper around ``ndarray_arg_resolver.resolve_ndarray_kernel_arg_id``; the actual logic lives + in module ``ast_transformers/ndarray_arg_resolver.py`` to keep this file from growing per-feature (same pattern + as ``_is_checkpoint_call`` / ``CheckpointTransformer``). 
Returns ``(label, flat_cpp_arg_id)`` or raises + ``QuadrantsSyntaxError``.""" + # pylint: disable-next=C0415,import-outside-toplevel + from quadrants.lang.ast.ast_transformers.ndarray_arg_resolver import ( + resolve_ndarray_kernel_arg_id, + ) + + return resolve_ndarray_kernel_arg_id(ctx, kernel, node, usage) + @staticmethod def _is_checkpoint_call(node: ast.expr, global_vars: dict): """Thin forwarding wrapper around ``CheckpointTransformer.is_checkpoint_call``; the actual logic lives in module @@ -1377,18 +1402,11 @@ def build_While(ctx: ASTTransformerFuncContext, node: ast.While) -> None: if node.orelse: raise QuadrantsSyntaxError("'else' clause for 'while' not supported in Quadrants kernels") - graph_do_while_arg = ASTTransformer._is_graph_do_while_call(node.test) - if graph_do_while_arg is not None: + graph_do_while_node = ASTTransformer._is_graph_do_while_call(node.test) + if graph_do_while_node is not None: from quadrants.lang.kernel import GraphDoWhileLevel # pylint: disable=C0415 kernel = ctx.global_context.current_kernel - arg_names = [m.name for m in kernel.arg_metas] - if graph_do_while_arg not in arg_names: - raise QuadrantsSyntaxError( - f"qd.graph_do_while({graph_do_while_arg!r}) does not match any " - f"parameter of kernel {kernel.func.__name__!r}. " - f"Available parameters: {arg_names}" - ) if not kernel.use_graph: raise QuadrantsSyntaxError("qd.graph_do_while() requires @qd.kernel(graph=True)") # graph_do_while emits no loop IR; its body's for-loops must be top-level (offloaded) tasks. So it may only @@ -1399,15 +1417,22 @@ def build_While(ctx: ASTTransformerFuncContext, node: ast.While) -> None: "qd.graph_do_while() must be at the kernel top level or directly nested inside " "another qd.graph_do_while(); it cannot appear inside a for-loop." 
) + # Resolve the condition ndarray (bare parameter or @qd.data_oriented member) to its flat C++ arg-id at AST- + # build time -- the same id the runtime needs -- so the launch path forwards it directly with no per-launch + # name matching. ``cond_arg_name`` keeps the readable label (e.g. "counter" or "self.counter") for + # introspection and for the legacy ``graph_do_while_arg`` alias surfaced on Kernel. + cond_label, cond_cpp_arg_id = ASTTransformer._resolve_ndarray_kernel_arg_id( + ctx, kernel, graph_do_while_node, "qd.graph_do_while(...)" + ) # Register this loop as a new nesting level (the body restriction is validated up-front in # FunctionDefTransformer). Outer loops get lower ids than the inner loops they contain. parent_id = kernel._graph_do_while_level_stack[-1] if kernel._graph_do_while_level_stack else -1 level_id = len(kernel.graph_do_while_levels) kernel.graph_do_while_levels.append( - GraphDoWhileLevel(cond_arg_name=graph_do_while_arg, parent_id=parent_id) + GraphDoWhileLevel(cond_arg_name=cond_label, parent_id=parent_id, cond_cpp_arg_id=cond_cpp_arg_id) ) if level_id == 0: - kernel.graph_do_while_arg = graph_do_while_arg + kernel.graph_do_while_arg = cond_label kernel._graph_do_while_level_stack.append(level_id) ctx.ast_builder.set_graph_do_while_level_id(level_id) try: diff --git a/python/quadrants/lang/ast/ast_transformers/checkpoint_transformer.py b/python/quadrants/lang/ast/ast_transformers/checkpoint_transformer.py index ab87271698..251537347f 100644 --- a/python/quadrants/lang/ast/ast_transformers/checkpoint_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/checkpoint_transformer.py @@ -38,13 +38,17 @@ class CheckpointCallInfo: - ``cp_id``: the user-supplied label (an ``int`` or ``IntEnum`` value), or ``None`` for an auto-wrap implicit checkpoint. - - ``yield_on``: name of the kernel parameter passed as ``yield_on=`` (an ``ast.Name`` is required), or ``None`` for - an implicit checkpoint. 
+ - ``yield_on_node``: the ``ast.expr`` passed as ``yield_on=`` -- either an ``ast.Name`` (bare kernel parameter, e.g. + ``flag``) or an ``ast.Attribute`` chain (e.g. ``self.flag`` for a ``@qd.data_oriented`` owner, or ``params.flag`` + where ``params`` is a ``@dataclasses.dataclass`` kernel parameter). ``None`` for an implicit checkpoint. + ``build_checkpoint_with`` resolves the node to a flat C++ arg-id via + ``resolve_ndarray_kernel_arg_id`` so the runtime can forward it directly without per-launch name + matching. - ``is_implicit``: ``True`` iff this Call was synthesised by ``auto_wrap_for_loops``. """ cp_id: int | None - yield_on: str | None + yield_on_node: ast.expr | None is_implicit: bool @@ -133,7 +137,7 @@ def is_checkpoint_call(node: ast.expr, global_vars: dict) -> CheckpointCallInfo return None # Auto-wrap-synthesised implicit checkpoint: no user-facing args, no validation. if getattr(node, _IMPLICIT_MARKER_ATTR, False): - return CheckpointCallInfo(cp_id=None, yield_on=None, is_implicit=True) + return CheckpointCallInfo(cp_id=None, yield_on_node=None, is_implicit=True) # User-written `qd.checkpoint(cp_id, yield_on)` -- both args are required. if len(node.args) + len(node.keywords) == 0: raise QuadrantsSyntaxError("qd.checkpoint() takes two arguments: `qd.checkpoint(cp_id, yield_on=flag)`.") @@ -173,13 +177,18 @@ def is_checkpoint_call(node: ast.expr, global_vars: dict) -> CheckpointCallInfo "qd.checkpoint() is missing required argument `yield_on` (e.g. " "`qd.checkpoint(0, yield_on=overflow_flag)`)" ) - if not isinstance(yield_on_arg, ast.Name): + # `yield_on=` must point at an ndarray kernel argument -- a bare parameter (`yield_on=flag`), a + # `@qd.data_oriented` member (`yield_on=self.flag`), or a `@dataclasses.dataclass` parameter member + # (`yield_on=params.flag`). Other expressions can't be lowered to a flat arg-id and are rejected here so the + # user gets a clear compile-time error at the `with` site. 
+ if not isinstance(yield_on_arg, (ast.Name, ast.Attribute)): raise QuadrantsSyntaxError( - "qd.checkpoint(yield_on=...) must be the bare name of a kernel parameter (e.g. " - "`yield_on=overflow_flag`); expressions are not supported" + "qd.checkpoint(yield_on=...) must reference a kernel ndarray argument -- e.g. `yield_on=flag` for " + "a bare parameter, `yield_on=self.flag` for a @qd.data_oriented member, or `yield_on=params.flag` " + "for a @dataclasses.dataclass parameter member; arbitrary expressions are not supported" ) cp_id_value = CheckpointTransformer._resolve_cp_id(cp_id_arg, global_vars) - return CheckpointCallInfo(cp_id=cp_id_value, yield_on=yield_on_arg.id, is_implicit=False) + return CheckpointCallInfo(cp_id=cp_id_value, yield_on_node=yield_on_arg, is_implicit=False) @staticmethod def build_checkpoint_with( @@ -216,14 +225,23 @@ def build_checkpoint_with( "same kernel must be flat siblings (a checkpoint inside qd.graph_do_while is fine)" ) + yield_on_label: str | None = None + yield_on_cpp_arg_id: int = -1 if not info.is_implicit: - # Validate `yield_on=` names a real kernel parameter. - arg_names = [m.name for m in kernel.arg_metas] - if info.yield_on not in arg_names: - raise QuadrantsSyntaxError( - f"qd.checkpoint(yield_on={info.yield_on!r}) does not match any parameter of kernel " - f"{kernel.func.__name__!r}. Available parameters: {arg_names}" - ) + # Resolve `yield_on=` (a bare parameter or `@qd.data_oriented` / `@dataclasses.dataclass` member ndarray) to + # its flat C++ arg-id at AST-build time. ``resolve_ndarray_kernel_arg_id`` raises a user-facing + # ``QuadrantsSyntaxError`` if the expression does not name a real ndarray kernel argument, which keeps the + # diagnostic at the `with` site instead of leaking into the launcher. Both the label (for + # ``checkpoint_yield_on_args`` / introspection) and the resolved arg-id (for the runtime) are stashed and + # forwarded to the launch path below. 
+ # pylint: disable-next=C0415,import-outside-toplevel + from quadrants.lang.ast.ast_transformers.ndarray_arg_resolver import ( + resolve_ndarray_kernel_arg_id, + ) + + yield_on_label, yield_on_cpp_arg_id = resolve_ndarray_kernel_arg_id( + ctx, kernel, info.yield_on_node, "qd.checkpoint(yield_on=...)" + ) # Reject duplicate user-supplied cp_id labels. existing = [lbl for lbl in kernel.checkpoint_user_labels_by_cp_id if lbl is not None] if info.cp_id in existing: @@ -272,7 +290,8 @@ def build_checkpoint_with( f" ...\n" ) - kernel.checkpoint_yield_on_args.append(info.yield_on) + kernel.checkpoint_yield_on_args.append(yield_on_label) + kernel.checkpoint_yield_on_cpp_arg_ids.append(yield_on_cpp_arg_id) kernel.checkpoint_user_labels_by_cp_id.append(info.cp_id) # Hand control to the C++ ASTBuilder so that every for-loop emitted by `build_stmts` below is tagged with this # checkpoint's internal `cp_id` on its `ForLoopConfig.checkpoint_id`. The C++ counter is the source of truth for diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 4e7b8e5154..6a7497f1e3 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -514,6 +514,7 @@ def build_FunctionDef( # different argument shape) start from an empty list. Mirrors how `graph_do_while_arg` gets overwritten # unconditionally during AST traversal. kernel.checkpoint_yield_on_args = [] + kernel.checkpoint_yield_on_cpp_arg_ids = [] kernel.checkpoint_user_labels_by_cp_id = [] # Auto-wrap pass for `@qd.kernel(graph=True, checkpoints=True)` kernels. 
Mutates `node.body` in place so # every top-level for-loop (and every for-loop inside a `qd.graph_do_while` body) that the user did not diff --git a/python/quadrants/lang/ast/ast_transformers/ndarray_arg_resolver.py b/python/quadrants/lang/ast/ast_transformers/ndarray_arg_resolver.py new file mode 100644 index 0000000000..21f328a93b --- /dev/null +++ b/python/quadrants/lang/ast/ast_transformers/ndarray_arg_resolver.py @@ -0,0 +1,71 @@ +# type: ignore +"""Resolve an ndarray-referencing AST expression to its flat C++ kernel arg-id at AST-build time. + +Lives alongside ``checkpoint_transformer.py`` / ``call_transformer.py`` so the central ``ast_transformer.py`` file +doesn't have to grow per-feature. Both ``qd.checkpoint(yield_on=...)`` (via +``CheckpointTransformer.build_checkpoint_with``) and ``qd.graph_do_while(...)`` (via ``ASTTransformer.build_While``) +share this helper: each form takes a control-flag / counter ndarray argument that may be a bare kernel parameter +(e.g. ``flag``), a ``@qd.data_oriented`` member ndarray (e.g. ``self.flag``), or a ``@dataclasses.dataclass`` +parameter member (e.g. ``params.flag``). All three flatten to the same ``ExternalTensorExpression`` after AST build, +so resolving the arg-id here -- once, at kernel build time -- means the runtime launch path can forward it directly +with no per-launch name matching (which was the original ``_checkpoint_helpers`` approach and is incompatible with +member-ndarray flattening, since the flattened name is synthesised). + +See ``docs/source/user_guide/graph.md`` for the user-facing surface and ``perso_hugh/doc/qipc/reentrant.md`` for the +design. 
+""" + +from __future__ import annotations + +import ast + +from quadrants.lang.ast.ast_transformer_utils import ASTTransformerFuncContext +from quadrants.lang.exception import QuadrantsSyntaxError + + +def resolve_ndarray_kernel_arg_id( + ctx: ASTTransformerFuncContext, + kernel, + node: ast.expr, + usage: str, +) -> tuple[str, int]: + """Resolve ``node`` to ``(label, flat_cpp_arg_id)`` at AST-build time. + + Shared between ``qd.checkpoint(yield_on=...)`` and ``qd.graph_do_while(...)`` to turn the control-flag argument into + the flat C++ arg-id the runtime matches against. ``node`` is an ``ast.Name`` (a bare kernel parameter, e.g. + ``flag``) or an ``ast.Attribute`` chain (e.g. ``self.flag`` for a ``@qd.data_oriented`` owner, or ``params.flag`` + where ``params`` is a ``@dataclasses.dataclass`` kernel parameter). We build the expression through the normal AST + machinery and read the arg-id off the resulting external-tensor expression -- this unifies the bare-param and + member-ndarray cases, since both flatten to a real ndarray kernel argument carrying its arg-id on the + ``ExternalTensorExpression``. + + ``usage`` is the call form (e.g. ``"qd.checkpoint(yield_on=...)"``) used in the error message. Raises + ``QuadrantsSyntaxError`` if the expression does not resolve to an ndarray kernel argument. + """ + # Local imports to avoid an ast_transformers -> ast_transformer / any_array import cycle at module load: + # ``ast_transformer`` is the central transformer module that imports ``checkpoint_transformer`` (sibling of this + # file), and ``any_array`` pulls in core ndarray bindings that aren't needed for module import. 
+ # pylint: disable=import-outside-toplevel + from quadrants.lang.any_array import AnyArray + from quadrants.lang.ast.ast_transformer import _qd_core, build_stmt + + # pylint: enable=import-outside-toplevel + + label = ast.unparse(node) + bad = QuadrantsSyntaxError( + f"{usage} got {label!r} which does not resolve to an ndarray kernel parameter of " + f"{kernel.func.__name__!r}. The argument must reference an ndarray kernel parameter (e.g. " + f"`flag`) or a @qd.data_oriented member ndarray (e.g. `self.flag`); other expressions are not " + f"supported." + ) + try: + built = build_stmt(ctx, node) + except Exception as e: # noqa: BLE001 - any resolution failure is a user-facing misuse + raise bad from e + resolved_expr = built.ptr if isinstance(built, AnyArray) else built + if not (hasattr(resolved_expr, "is_external_tensor_expr") and resolved_expr.is_external_tensor_expr()): + raise bad + arg_id = _qd_core.get_external_tensor_arg_id(resolved_expr) + if not arg_id: + raise bad + return label, int(arg_id[0]) diff --git a/python/quadrants/lang/kernel.py b/python/quadrants/lang/kernel.py index e663c6ed9e..ecee93e0aa 100644 --- a/python/quadrants/lang/kernel.py +++ b/python/quadrants/lang/kernel.py @@ -344,9 +344,20 @@ def __init__(self, _func: Callable, autodiff_mode: AutodiffMode, _is_classkernel self._graph_do_while_level_stack: list[int] = [] # Per-checkpoint metadata, one entry per `with qd.checkpoint(...)` block (explicit AND auto-injected implicit) # in declaration order. List index is the checkpoint's internal `cp_id` (0, 1, 2, ... dense, flat across the - # kernel). Each entry is the name of the `yield_on=` kernel parameter, or `None` for implicit checkpoints (which - # never yield). Populated by the AST transformer; empty means the kernel uses no checkpoints. + # kernel). Each entry is the readable label of the `yield_on=` argument (e.g. "flag" or "self.flag"), or + # `None` for implicit checkpoints (which never yield). 
Populated by the AST transformer; empty means the + # kernel uses no checkpoints. Used for error messages / introspection only -- the runtime forwards the flat + # C++ arg-id from `checkpoint_yield_on_cpp_arg_ids` below. self.checkpoint_yield_on_args: list[str | None] = [] + # Flat C++ arg-ids (post-template) of each explicit checkpoint's `yield_on=` ndarray, resolved at AST-build time + # by `CheckpointTransformer.build_checkpoint_with` via `resolve_ndarray_kernel_arg_id`. Same + # indexing as `checkpoint_yield_on_args`: entry `i` is the flat arg-id the runtime uses to look up the ndarray's + # device pointer for the checkpoint whose internal cp_id is `i`. `-1` for implicit checkpoints (which never + # yield). Resolving at AST-build time uniformly handles bare kernel parameters (`yield_on=flag`), + # `@qd.data_oriented` member ndarrays (`yield_on=self.flag`), and `@dataclasses.dataclass` parameter members + # (`yield_on=params.flag`); the attribute forms cannot be resolved by the per-launch name match because + # `arg_metas[i].name` only carries top-level parameter names. + self.checkpoint_yield_on_cpp_arg_ids: list[int] = [] # User-facing labels for explicit checkpoints. Same indexing as `checkpoint_yield_on_args`: entry `i` is the int # (or IntEnum value) the user passed as the first positional arg of `qd.checkpoint(cp_id, yield_on)` for the # checkpoint whose internal cp_id is `i`. 
Implicit checkpoints (auto-wrapped) get `None` (they have no @@ -396,14 +407,12 @@ def _try_load_fastcache(self, args: tuple[Any, ...], key: "CompiledKernelKeyType self.fast_checksum = src_hasher.create_cache_key( self.raise_on_templated_floats, kernel_source_info, args, self.arg_metas ) - used_py_dataclass_parameters = None - cached_graph_do_while_levels: list[tuple[str, int]] | None = None + cache_value = None if self.fast_checksum: self.src_ll_cache_observations.cache_key_generated = True - used_py_dataclass_parameters, frontend_cache_key, cached_graph_do_while_levels = src_hasher.load( # type: ignore[reportAssignmentType] - self.fast_checksum - ) - if used_py_dataclass_parameters is not None and frontend_cache_key is not None: + cache_value = src_hasher.load(self.fast_checksum) + if cache_value is not None: + frontend_cache_key = cache_value.frontend_cache_key self.src_ll_cache_observations.cache_validated = True prog = impl.get_runtime().prog assert self.fast_checksum is not None @@ -415,16 +424,35 @@ def _try_load_fastcache(self, args: tuple[Any, ...], key: "CompiledKernelKeyType ) if self.compiled_kernel_data_by_key[key]: self.src_ll_cache_observations.cache_loaded = True - self.used_py_dataclass_parameters_by_key_enforcing[key] = used_py_dataclass_parameters - # Fast-cache restore skips AST transformation, so rebuild the gdw level table (and the legacy - # outermost-arg alias) from the cached (cond_arg_name, parent_id) pairs. - if cached_graph_do_while_levels: + self.used_py_dataclass_parameters_by_key_enforcing[key] = cache_value.used_py_dataclass_parameters + # Fast-cache restore skips AST transformation, so rebuild the AST-transformer-produced metadata from + # the cache value: nested graph_do_while level table (with the AST-resolved flat C++ arg-id) plus + # the per-checkpoint yield_on / user-label tables. Mirrors what `function_def_transformer.py` + + # `checkpoint_transformer.py` + `build_While` would have written. 
+ if cache_value.graph_do_while_levels: self.graph_do_while_levels = [ - GraphDoWhileLevel(cond_arg_name=name, parent_id=parent) - for name, parent in cached_graph_do_while_levels + GraphDoWhileLevel(cond_arg_name=name, parent_id=parent, cond_cpp_arg_id=cpp_arg_id) + for name, parent, cpp_arg_id in cache_value.graph_do_while_levels ] self.graph_do_while_arg = self.graph_do_while_levels[0].cond_arg_name - return used_py_dataclass_parameters + if cache_value.checkpoint_yield_on_args: + self.checkpoint_yield_on_args = list(cache_value.checkpoint_yield_on_args) + self.checkpoint_yield_on_cpp_arg_ids = list(cache_value.checkpoint_yield_on_cpp_arg_ids) + # Pydantic coerces IntEnum -> int at CacheValue construction time, so the raw labels are plain + # ints after JSON round-trip. ``checkpoint_user_label_enum_qualnames`` carries the parallel + # ``module.ClassQualName.MEMBER`` strings that ``_resolve_intenum_member`` uses to rebuild the + # original ``IntEnum`` member -- preserving the documented contract that + # ``qd.checkpoint(Stage.X, ...)`` surfaces as ``Stage.X`` (not the raw int) on + # ``status.checkpoint``. Older v3 caches predate the qualname column, so we default any missing + # slots to ``None`` -> raw-int fallback (the same behaviour they had on v3). 
+ raw_labels = list(cache_value.checkpoint_user_labels_by_cp_id) + qualnames = list(cache_value.checkpoint_user_label_enum_qualnames) or [None] * len(raw_labels) + if len(qualnames) != len(raw_labels): + qualnames = [None] * len(raw_labels) + self.checkpoint_user_labels_by_cp_id = [ + src_hasher._resolve_intenum_member(qn, lbl) for qn, lbl in zip(qualnames, raw_labels) + ] + return cache_value.used_py_dataclass_parameters elif self.quadrants_callable and not self.quadrants_callable.is_pure and self.runtime.print_non_pure: # The bit in caps should not be modified without updating corresponding test @@ -587,13 +615,9 @@ def launch_kernel( is_launch_ctx_cacheable = True template_num = 0 i_out = 0 - # Hoist the `kernel has any yield_on= checkpoint` predicate out of the per-arg loop so non-checkpoint - # kernels (the overwhelming majority) skip the helper function call entirely on every arg. Gated on - # `use_checkpoints` first so non-checkpoint kernels pay only one attribute lookup, not the list-truthy - # check on every cache-miss build. - _kernel_has_yield_on_checkpoint = self.use_checkpoints and bool(self.checkpoint_yield_on_args) - if _kernel_has_yield_on_checkpoint: - _checkpoint_helpers.init_yield_on_arg_id_table(self) + # `checkpoint_yield_on_cpp_arg_ids` is populated at AST-build time (see + # `CheckpointTransformer.build_checkpoint_with`); no per-arg name match is needed here. The launch path + # below forwards the table to the launch context with a single `forward_yield_on_table_to_ctx` call. for i_in, val in enumerate(args): needed_ = self.arg_metas[i_in].annotation if needed_ is template or type(needed_) is template: @@ -607,12 +631,11 @@ def launch_kernel( # which weakens API/type safety and can route the wrong struct type through launch. 
if getattr(val, "_qd_all_field", False) and getattr(needed_, _FIELDS, None) is not None: continue - if self.graph_do_while_levels: - for _gdw_level in self.graph_do_while_levels: - if self.arg_metas[i_in].name == _gdw_level.cond_arg_name: - _gdw_level.cond_cpp_arg_id = i_out - template_num - if _kernel_has_yield_on_checkpoint: - _checkpoint_helpers.maybe_record_yield_on_arg(self, self.arg_metas[i_in].name, i_out - template_num) + # `graph_do_while_levels[*].cond_cpp_arg_id` is also populated at AST-build time (see + # `ASTTransformer.build_While` -> `_resolve_ndarray_kernel_arg_id`), so the launch path forwards it + # directly below without per-arg name matching here. This uniformly handles bare parameter conditions + # (`qd.graph_do_while(counter)`) and `@qd.data_oriented` member conditions + # (`qd.graph_do_while(self.counter)`). num_args_, is_launch_ctx_cacheable_ = self._recursive_set_args( self.used_py_dataclass_parameters_by_key_enforcing[key], self.arg_metas[i_in].name, @@ -675,8 +698,12 @@ def launch_kernel( self.visited_functions, self.used_py_dataclass_parameters_by_key_enforcing[key], graph_do_while_levels=[ # type: ignore[reportCallIssue] - (level.cond_arg_name, level.parent_id) for level in self.graph_do_while_levels + (level.cond_arg_name, level.parent_id, level.cond_cpp_arg_id) + for level in self.graph_do_while_levels ], + checkpoint_yield_on_args=list(self.checkpoint_yield_on_args), + checkpoint_yield_on_cpp_arg_ids=list(self.checkpoint_yield_on_cpp_arg_ids), + checkpoint_user_labels_by_cp_id=list(self.checkpoint_user_labels_by_cp_id), ) self.src_ll_cache_observations.cache_stored = True self._last_compiled_kernel_data = compiled_kernel_data diff --git a/python/quadrants/lang/kernel_checkpoint.py b/python/quadrants/lang/kernel_checkpoint.py index a3d44eac5f..6b80491f67 100644 --- a/python/quadrants/lang/kernel_checkpoint.py +++ b/python/quadrants/lang/kernel_checkpoint.py @@ -46,36 +46,17 @@ def translate_user_label_to_internal_cp_id(kernel: Any, 
user_label: int) -> int: ) -def init_yield_on_arg_id_table(kernel: Any) -> None: - """Allocate / reset the per-launch ``cp_id -> C++ arg-id`` table at the top of ``launch_kernel``'s arg iteration. - - Each entry defaults to ``-1`` ("no yield_on"); the per-arg loop below fills in the C++ arg id when it visits the - named parameter. Sized to the kernel's checkpoint count once per launch so any changes to the checkpoint set (only - possible via re-AST-walk) reset the table cleanly. No-op for kernels with no ``yield_on=`` checkpoints. - """ - if kernel.checkpoint_yield_on_args: - kernel._checkpoint_yield_on_cpp_arg_ids = [-1] * len(kernel.checkpoint_yield_on_args) - - -def maybe_record_yield_on_arg(kernel: Any, arg_name: str, cpp_arg_id: int) -> None: - """Fill the ``cp_id -> C++ arg-id`` slot when the arg iterator visits a named ``yield_on=`` kernel parameter. - - Walked once per kernel arg in ``launch_kernel``; cheap O(checkpoints) match. A single parameter can be the - ``yield_on=`` for multiple checkpoints (the inner loop fills every matching slot). - """ - if not kernel.checkpoint_yield_on_args: - return - for cp_idx, yield_name in enumerate(kernel.checkpoint_yield_on_args): - if yield_name is not None and arg_name == yield_name: - kernel._checkpoint_yield_on_cpp_arg_ids[cp_idx] = cpp_arg_id - - def forward_yield_on_table_to_ctx(kernel: Any, launch_ctx: Any) -> None: - """Copy the resolved ``cp_id -> C++ arg-id`` table onto the launch context so the runtime can find each - ``yield_on=`` ndarray's device address at launch. + """Copy the ``cp_id -> C++ arg-id`` table onto the launch context so the runtime can find each ``yield_on=`` + ndarray's device address at launch. 
+ + The table (``kernel.checkpoint_yield_on_cpp_arg_ids``) is populated at AST-build time by + ``CheckpointTransformer.build_checkpoint_with`` via ``ASTTransformer._resolve_ndarray_kernel_arg_id``, which + uniformly handles bare kernel parameters (``yield_on=flag``) and ``@qd.data_oriented`` member ndarrays + (``yield_on=self.flag``). No per-launch arg iteration / name match is required. """ - if kernel.checkpoint_yield_on_args and hasattr(kernel, "_checkpoint_yield_on_cpp_arg_ids"): - launch_ctx.checkpoint_yield_on_arg_ids = tuple(kernel._checkpoint_yield_on_cpp_arg_ids) + if kernel.checkpoint_yield_on_cpp_arg_ids: + launch_ctx.checkpoint_yield_on_arg_ids = tuple(kernel.checkpoint_yield_on_cpp_arg_ids) def maybe_build_graph_status(kernel: Any, default_ret: Any) -> Any: diff --git a/quadrants/python/export_lang.cpp b/quadrants/python/export_lang.cpp index e589ea907f..afdfb8471d 100644 --- a/quadrants/python/export_lang.cpp +++ b/quadrants/python/export_lang.cpp @@ -862,6 +862,15 @@ void export_lang(nb::module_ &m) { } }); + // The (post-template) C++ arg-id vector of an external-tensor (ndarray) expression. For a top-level ndarray parameter + // or a flattened `@qd.data_oriented` member ndarray this is a single-element vector whose `[0]` entry is the flat + // arg-id the runtime matches against (e.g. for `qd.checkpoint(yield_on=...)` and `qd.graph_do_while(...)` + // AST-build-time resolution of bare-parameter vs `self.member` ndarray arguments). 
+ m.def("get_external_tensor_arg_id", [](const Expr &expr) { + QD_ASSERT(expr.is<ExternalTensorExpression>()); + return expr.cast<ExternalTensorExpression>()->arg_id; + }); + m.def("get_external_tensor_shape_along_axis", Expr::make); diff --git a/tests/python/quadrants/lang/fast_caching/test_src_hasher.py b/tests/python/quadrants/lang/fast_caching/test_src_hasher.py index e7a2d9952b..b2131e5817 100644 --- a/tests/python/quadrants/lang/fast_caching/test_src_hasher.py +++ b/tests/python/quadrants/lang/fast_caching/test_src_hasher.py @@ -4,6 +4,7 @@ import shutil import subprocess import sys +from enum import IntEnum from typing import Callable import pydantic @@ -109,19 +110,17 @@ def get_fileinfos(functions: list[Callable]) -> list[_wrap_inspect.FunctionSourc kernel_cache_key = "I'm a kernel cache key" - load_res = src_hasher.load(fast_cache_key) - assert load_res[0] is None and load_res[1] is None + assert src_hasher.load(fast_cache_key) is None some_used_vars = {"fee", "fi", "fo"} src_hasher.store(kernel_cache_key, fast_cache_key, fileinfos, some_used_vars) def assert_loaded(cache_key: str) -> None: res = src_hasher.load(cache_key) - assert res[0] is not None and res[1] is not None + assert res is not None and res.frontend_cache_key == kernel_cache_key def assert_not_loaded(cache_key: str) -> None: - res = src_hasher.load(cache_key) - assert res[0] is None and res[1] is None + assert src_hasher.load(cache_key) is None assert_loaded(fast_cache_key) @@ -136,7 +135,135 @@ def assert_not_loaded(cache_key: str) -> None: assert_not_loaded("abcdefg") - assert src_hasher.load(fast_cache_key)[0] == some_used_vars + loaded = src_hasher.load(fast_cache_key) + assert loaded is not None + assert loaded.used_py_dataclass_parameters == some_used_vars + # The new schema-v3+v4 AST-resolved fields default to empty for kernels with no graph_do_while / checkpoint + # metadata, exercising the BaseModel default path on round-trip.
+ assert loaded.graph_do_while_levels is None + assert loaded.checkpoint_yield_on_args == [] + assert loaded.checkpoint_yield_on_cpp_arg_ids == [] + assert loaded.checkpoint_user_labels_by_cp_id == [] + assert loaded.checkpoint_user_label_enum_qualnames == [] + + +@test_utils.test() +def test_src_hasher_store_validate_round_trips_schema_v3_metadata( + monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path, temporary_module +) -> None: + """Schema v3 (`cachevalue-v3-ast-resolved-ids`) added AST-resolved arg-id fields to the persisted ``CacheValue`` so + the launch path can forward them after a fast-cache restore (which skips AST transformation). This test pins the + round-trip for the new fields -- ``graph_do_while_levels`` as 3-tuples carrying ``cond_cpp_arg_id``, plus + ``checkpoint_yield_on_args`` / ``checkpoint_yield_on_cpp_arg_ids`` / ``checkpoint_user_labels_by_cp_id``. Without + this, a schema bug (wrong tuple arity, dropped field, mis-typed BaseModel default) would only surface via a hard-to- + debug functional regression in a fast-cached checkpoint / graph_do_while kernel.""" + test_files_path = pathlib.Path("tests/python/quadrants/lang/fast_caching/test_files") + + offline_cache_path = tmp_path / "cache" + temp_import_path = tmp_path / "temp_import" + temp_import_path.mkdir(exist_ok=True) + + qd_init_same_arch(offline_cache_file_path=str(offline_cache_path)) + + monkeypatch.syspath_prepend(temp_import_path) + shutil.copy2(test_files_path / "child_diff_base.py", temp_import_path / "child_diff_schema_v3.py") + mod = temporary_module("child_diff_schema_v3") + info, _src = _wrap_inspect.get_source_info_and_src(mod.f1.fn) + fileinfos = [info] + fast_cache_key = src_hasher.create_cache_key(False, info, [], []) + assert fast_cache_key is not None + + gdw_levels = [("self.outer", -1, 4), ("self.inner", 0, 5)] + cp_yield_args = ["self.flag", None, "params.flag"] + cp_yield_cpp_ids = [3, -1, 7] + cp_user_labels = [10, None, 20] + + src_hasher.store( + 
"kernel_cache_key_v3", + fast_cache_key, + fileinfos, + {"used_var"}, + graph_do_while_levels=gdw_levels, + checkpoint_yield_on_args=cp_yield_args, + checkpoint_yield_on_cpp_arg_ids=cp_yield_cpp_ids, + checkpoint_user_labels_by_cp_id=cp_user_labels, + ) + + loaded = src_hasher.load(fast_cache_key) + assert loaded is not None + assert loaded.frontend_cache_key == "kernel_cache_key_v3" + assert loaded.graph_do_while_levels == [("self.outer", -1, 4), ("self.inner", 0, 5)] + assert loaded.checkpoint_yield_on_args == cp_yield_args + assert loaded.checkpoint_yield_on_cpp_arg_ids == cp_yield_cpp_ids + assert loaded.checkpoint_user_labels_by_cp_id == cp_user_labels + # Plain-int labels record `None` in the parallel qualname column (no IntEnum identity to preserve). + assert loaded.checkpoint_user_label_enum_qualnames == [None, None, None] + + +@test_utils.test() +def test_src_hasher_intenum_qualname_round_trip( + monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path, temporary_module +) -> None: + """Schema v4 (`cachevalue-v4-intenum-qualnames`) added a parallel `checkpoint_user_label_enum_qualnames` column so + an ``IntEnum`` cp_id round-trips through fast-cache restore as the original enum member rather than the underlying + int. ``src_hasher.store`` derives the qualname column from the live label list (which still holds the original + ``IntEnum`` instances) before pydantic int-coerces them; ``_resolve_intenum_member`` re-imports the enum class on + load. 
This test covers both the store-side derivation (mixed IntEnum / plain int / None) and the load-side + resolution (verifies identity is preserved, not just int equality).""" + test_files_path = pathlib.Path("tests/python/quadrants/lang/fast_caching/test_files") + offline_cache_path = tmp_path / "cache" + temp_import_path = tmp_path / "temp_import" + temp_import_path.mkdir(exist_ok=True) + qd_init_same_arch(offline_cache_file_path=str(offline_cache_path)) + monkeypatch.syspath_prepend(temp_import_path) + shutil.copy2(test_files_path / "child_diff_base.py", temp_import_path / "child_diff_v4_intenum.py") + mod = temporary_module("child_diff_v4_intenum") + info, _src = _wrap_inspect.get_source_info_and_src(mod.f1.fn) + fast_cache_key = src_hasher.create_cache_key(False, info, [], []) + assert fast_cache_key is not None + + # Reference the module-level enum below so it has a real importable qualname. + src_hasher.store( + "kernel_cache_key_v4", + fast_cache_key, + [info], + {"used_var"}, + checkpoint_yield_on_args=["flag", None, "flag"], + checkpoint_yield_on_cpp_arg_ids=[1, -1, 1], + checkpoint_user_labels_by_cp_id=[_HasherTestStage.LOAD, None, _HasherTestStage.REDUCE], + ) + + loaded = src_hasher.load(fast_cache_key) + assert loaded is not None + # Persisted raw labels are plain ints (pydantic coerced them); the qualname column is what carries identity. + assert loaded.checkpoint_user_labels_by_cp_id == [5, None, 9] + assert loaded.checkpoint_user_label_enum_qualnames == [ + f"{_HasherTestStage.__module__}.{_HasherTestStage.__qualname__}.LOAD", + None, + f"{_HasherTestStage.__module__}.{_HasherTestStage.__qualname__}.REDUCE", + ] + + # Resolver round-trip: rebuild each slot through `_resolve_intenum_member` and confirm enum identity (not just int- + # equality) is preserved. 
+ resolved = [ + src_hasher._resolve_intenum_member(qn, lbl) + for qn, lbl in zip(loaded.checkpoint_user_label_enum_qualnames, loaded.checkpoint_user_labels_by_cp_id) + ] + assert resolved == [_HasherTestStage.LOAD, None, _HasherTestStage.REDUCE] + assert isinstance(resolved[0], _HasherTestStage) + assert isinstance(resolved[2], _HasherTestStage) + + # Resolver fallback: an unresolvable qualname (e.g. enum class moved/renamed since cache write) must drop back to + # the persisted int rather than raising, so a stale cache entry degrades gracefully. + assert src_hasher._resolve_intenum_member("nonexistent.Module.Stage.LOAD", 5) == 5 + + +# Top-level IntEnum used by `test_src_hasher_intenum_qualname_round_trip` so the resolver can re-import it via +# `importlib.import_module("tests.python.quadrants.lang.fast_caching.test_src_hasher")`. Lives at module scope (not +# inside the test) for the same reason `_FastcacheStage` / `_Stage` do in `test_checkpoint.py`. +class _HasherTestStage(IntEnum): + LOAD = 5 + REDUCE = 9 # Should be enough to run these on cpu I think, and anything involving diff --git a/tests/python/test_checkpoint.py b/tests/python/test_checkpoint.py index cf03d6c095..e65c782f59 100644 --- a/tests/python/test_checkpoint.py +++ b/tests/python/test_checkpoint.py @@ -19,9 +19,14 @@ (IF conditional node count) are guarded behind ``_is_checkpoint_if_path_native``. 
""" +import os +import pathlib +import subprocess +import sys from enum import IntEnum import numpy as np +import pydantic import pytest import quadrants as qd @@ -29,6 +34,9 @@ from tests import test_utils +TEST_RAN = "test ran" +RET_SUCCESS = 42 + def _on_cuda(): return impl.current_cfg().arch == qd.cuda @@ -208,7 +216,7 @@ def k(x: qd.types.ndarray(qd.i32, ndim=1), flag: qd.types.ndarray(qd.i32, ndim=0 @test_utils.test() def test_checkpoint_yield_on_nonexistent_arg_raises(): - """``yield_on`` must name a kernel parameter; typos / scope mismatches must error early.""" + """``yield_on`` must reference an ndarray kernel argument; typos / scope mismatches must error early.""" @qd.kernel(graph=True, checkpoints=True) def k(x: qd.types.ndarray(qd.i32, ndim=1), flag: qd.types.ndarray(qd.i32, ndim=0)): @@ -218,14 +226,15 @@ def k(x: qd.types.ndarray(qd.i32, ndim=1), flag: qd.types.ndarray(qd.i32, ndim=0 x = qd.ndarray(qd.i32, shape=(4,)) flag = qd.ndarray(qd.i32, shape=()) - with pytest.raises(qd.QuadrantsSyntaxError, match="does not match any parameter"): + with pytest.raises(qd.QuadrantsSyntaxError, match="does not resolve to an ndarray kernel parameter"): k(x, flag) @test_utils.test() -def test_checkpoint_yield_on_must_be_bare_name(): - """``yield_on=`` must be a bare ``ast.Name`` (a kernel parameter); expressions are not supported. Pinning the - diagnostic so the user knows to refactor.""" +def test_checkpoint_yield_on_must_be_name_or_attribute(): + """``yield_on=`` must reference an ndarray kernel argument -- either a bare ``ast.Name`` or an ``ast.Attribute`` + chain (for ``@qd.data_oriented`` member ndarrays). 
Arbitrary expressions are not supported; pinning the diagnostic + so the user knows to refactor.""" @qd.kernel(graph=True, checkpoints=True) def k(x: qd.types.ndarray(qd.i32, ndim=1), flag: qd.types.ndarray(qd.i32, ndim=0)): @@ -235,7 +244,7 @@ def k(x: qd.types.ndarray(qd.i32, ndim=1), flag: qd.types.ndarray(qd.i32, ndim=0 x = qd.ndarray(qd.i32, shape=(4,)) flag = qd.ndarray(qd.i32, shape=()) - with pytest.raises(qd.QuadrantsSyntaxError, match=r"must be the bare name of a kernel parameter"): + with pytest.raises(qd.QuadrantsSyntaxError, match=r"must reference a kernel ndarray argument"): k(x, flag) @@ -843,6 +852,405 @@ def k( assert status.checkpoint == 0 +# ---------------------------------------------------------------------------------------------------------------------- +# Member-ndarray support for `yield_on=` (both `@qd.data_oriented` self-members and `@dataclasses.dataclass` parameter +# members). +# +# ``qd.checkpoint(yield_on=self.flag)`` and ``qd.checkpoint(yield_on=params.flag)`` (where ``params`` is a kernel +# parameter typed as a ``@dataclasses.dataclass``) both resolve the member ndarray to a flat C++ arg-id at AST-build +# time via ``ASTTransformer._resolve_ndarray_kernel_arg_id``: it builds the expression and reads the resolved +# ``ExternalTensorExpression.arg_id``, so any attribute chain that ends up as a kernel ndarray arg works the same way +# as a bare parameter name. This frees users from having to forward flag members as bare kernel parameters when the +# rest of the kernel already operates on the dataclass / data-oriented owner. 
+# ---------------------------------------------------------------------------------------------------------------------- + + +@test_utils.test() +def test_checkpoint_yield_on_data_oriented_member_metadata(): + """`yield_on=self.flag` is accepted and the resolved label is stored verbatim (``"self.flag"``) in + ``checkpoint_yield_on_args``, while ``checkpoint_yield_on_cpp_arg_ids`` carries the flat C++ arg-id the runtime + forwards to the launch context. Verifies the AST-build-time resolution path without booting the backend.""" + N = 4 + + @qd.data_oriented + class Sim: + def __init__(self): + self.x = qd.ndarray(qd.i32, shape=(N,)) + self.flag = qd.ndarray(qd.i32, shape=()) + + @qd.kernel(graph=True, checkpoints=True) + def step(self): + with qd.checkpoint(0, yield_on=self.flag): + for i in range(self.x.shape[0]): + self.x[i] = self.x[i] + 1 + + sim = Sim() + sim.x.from_numpy(np.zeros(N, dtype=np.int32)) + sim.flag.from_numpy(np.array(0, dtype=np.int32)) + sim.step() + np.testing.assert_array_equal(sim.x.to_numpy(), np.ones(N, dtype=np.int32)) + assert sim.step._primal.checkpoint_user_labels_by_cp_id == [0] + assert sim.step._primal.checkpoint_yield_on_args == ["self.flag"] + cpp_ids = sim.step._primal.checkpoint_yield_on_cpp_arg_ids + assert len(cpp_ids) == 1 and cpp_ids[0] >= 0 + + +@test_utils.test() +def test_checkpoint_yield_on_dataclass_member_metadata(): + """`yield_on=params.flag` for a ``@dataclasses.dataclass`` kernel parameter takes the same AST-build-time resolution + path as ``self.flag`` for a ``@qd.data_oriented`` owner -- the resolved label round-trips into + ``checkpoint_yield_on_args`` and the flat arg-id lands in ``checkpoint_yield_on_cpp_arg_ids``.""" + import dataclasses # pylint: disable=import-outside-toplevel + + N = 4 + + @dataclasses.dataclass + class Params: + x: qd.types.NDArray[qd.i32, 1] + flag: qd.types.NDArray[qd.i32, 0] + + @qd.kernel(graph=True, checkpoints=True) + def step(params: Params): + with qd.checkpoint(0, 
yield_on=params.flag): + for i in range(params.x.shape[0]): + params.x[i] = params.x[i] + 1 + + params = Params( + x=qd.ndarray(qd.i32, shape=(N,)), + flag=qd.ndarray(qd.i32, shape=()), + ) + params.x.from_numpy(np.zeros(N, dtype=np.int32)) + params.flag.from_numpy(np.array(0, dtype=np.int32)) + step(params) + np.testing.assert_array_equal(params.x.to_numpy(), np.ones(N, dtype=np.int32)) + assert step._primal.checkpoint_user_labels_by_cp_id == [0] + # Dataclass-parameter member access gets pre-rewritten by the AST pipeline to a flattened parameter name + # (`__qd_params__qd_flag`) before the checkpoint transformer sees it, so the label round-trips in the flattened + # form. The functional contract -- a valid flat C++ arg-id is resolved and the kernel mutates the right ndarray -- + # is the same as for the bare-param / `self.flag` forms. + labels = step._primal.checkpoint_yield_on_args + assert len(labels) == 1 and labels[0] is not None and "flag" in labels[0] + cpp_ids = step._primal.checkpoint_yield_on_cpp_arg_ids + assert len(cpp_ids) == 1 and cpp_ids[0] >= 0 + + +@test_utils.test() +def test_checkpoint_yield_on_dataclass_member_yields_and_resumes(): + """Behavioural round-trip for `yield_on=params.flag` -- mirror of the `self.flag` test below, using a + `@dataclasses.dataclass` kernel parameter instead of a `@qd.data_oriented` owner. 
The dataclass-member access is + pre-rewritten to a flattened parameter, so verifying the full yield/resume contract end-to-end is the only way to + confirm the right ndarray is wired up at launch.""" + import dataclasses # pylint: disable=import-outside-toplevel + + if not _supports_checkpoint_yield_resume(): + pytest.skip("backend does not implement checkpoint yield/resume") + N = 4 + + @dataclasses.dataclass + class Params: + x: qd.types.NDArray[qd.i32, 1] + flag: qd.types.NDArray[qd.i32, 0] + + @qd.kernel(graph=True, checkpoints=True) + def step(params: Params): + with qd.checkpoint(7, yield_on=params.flag): + for i in range(params.x.shape[0]): + params.x[i] = params.x[i] + 1 + params.flag[()] = 1 + with qd.checkpoint(8, yield_on=params.flag): + for i in range(params.x.shape[0]): + params.x[i] = params.x[i] + 10 + + params = Params( + x=qd.ndarray(qd.i32, shape=(N,)), + flag=qd.ndarray(qd.i32, shape=()), + ) + params.x.from_numpy(np.zeros(N, dtype=np.int32)) + params.flag.from_numpy(np.array(0, dtype=np.int32)) + status = step(params) + assert status.yielded + assert status.checkpoint == 7 + np.testing.assert_array_equal(params.x.to_numpy(), np.ones(N, dtype=np.int32)) + params.flag.from_numpy(np.array(0, dtype=np.int32)) + # `step` is a free-function kernel (not a bound class kernel), so `params` must be passed positionally to + # `resume` -- the data_oriented sibling test below can omit it because its member access is implicit + # through `sim.step`'s bound `self`.
+ status = step.resume(params, from_checkpoint=8) + assert not status.yielded + np.testing.assert_array_equal(params.x.to_numpy(), np.full(N, 11, dtype=np.int32)) + + +@test_utils.test() +def test_checkpoint_yield_on_member_nonexistent_attribute_raises(): + """`yield_on=self.nonexistent_attr` (attribute does not exist on the `@qd.data_oriented` owner) must raise a user- + facing `QuadrantsSyntaxError` at the `with` site -- the AST-time resolver wraps the underlying attribute lookup + failure in the same `does not resolve to an ndarray kernel parameter` diagnostic as the bare-name nonexistent case, + so users see one consistent error pattern.""" + N = 4 + + @qd.data_oriented + class Sim: + def __init__(self): + self.x = qd.ndarray(qd.i32, shape=(N,)) + self.flag = qd.ndarray(qd.i32, shape=()) + + @qd.kernel(graph=True, checkpoints=True) + def step(self): + with qd.checkpoint(0, yield_on=self.nonexistent_flag): # type: ignore[attr-defined] + for i in range(self.x.shape[0]): + self.x[i] = self.x[i] + 1 + + sim = Sim() + with pytest.raises(qd.QuadrantsSyntaxError, match="does not resolve to an ndarray kernel parameter"): + sim.step() + + +@test_utils.test() +def test_checkpoint_yield_on_member_non_ndarray_attribute_raises(): + """`yield_on=self.scalar` where `self.scalar` is a Python int (not an ndarray) must raise the same `does not resolve + to an ndarray kernel parameter` diagnostic -- the AST-time resolver builds the expression but rejects it because the + resulting Expr is not an `ExternalTensorExpression`. 
Pinning this so future refactors of the resolver can't silently + accept non-ndarray attributes and crash later in the launcher.""" + N = 4 + + @qd.data_oriented + class Sim: + def __init__(self): + self.x = qd.ndarray(qd.i32, shape=(N,)) + self.scalar = 7 + + @qd.kernel(graph=True, checkpoints=True) + def step(self): + with qd.checkpoint(0, yield_on=self.scalar): + for i in range(self.x.shape[0]): + self.x[i] = self.x[i] + 1 + + sim = Sim() + with pytest.raises(qd.QuadrantsSyntaxError, match="does not resolve to an ndarray kernel parameter"): + sim.step() + + +@test_utils.test() +def test_checkpoint_yield_on_data_oriented_member_yields_and_resumes(): + """Behavioural round-trip for `yield_on=self.flag`: setting the member flag from inside the kernel yields, and + ``kernel.resume(from_checkpoint=...)`` skips ahead to the named checkpoint. Same surface contract as the + bare-parameter form (`test_checkpoint_yield_on_yields_and_resumes`); the only difference is where the flag lives.""" + if not _supports_checkpoint_yield_resume(): + pytest.skip("backend does not implement checkpoint yield/resume") + N = 4 + + @qd.data_oriented + class Sim: + def __init__(self): + self.x = qd.ndarray(qd.i32, shape=(N,)) + self.flag = qd.ndarray(qd.i32, shape=()) + + @qd.kernel(graph=True, checkpoints=True) + def step(self): + with qd.checkpoint(7, yield_on=self.flag): + for i in range(self.x.shape[0]): + self.x[i] = self.x[i] + 1 + self.flag[()] = 1 + with qd.checkpoint(8, yield_on=self.flag): + for i in range(self.x.shape[0]): + self.x[i] = self.x[i] + 10 + + sim = Sim() + sim.x.from_numpy(np.zeros(N, dtype=np.int32)) + sim.flag.from_numpy(np.array(0, dtype=np.int32)) + status = sim.step() + # Checkpoint 7 set the flag in the first iter so the kernel yields before running checkpoint 8. 
+ assert status.yielded + assert status.checkpoint == 7 + np.testing.assert_array_equal(sim.x.to_numpy(), np.ones(N, dtype=np.int32)) + # User clears the flag and resumes from the post-yield checkpoint (skipping the +1 loop entirely). + sim.flag.from_numpy(np.array(0, dtype=np.int32)) + status = sim.step.resume(from_checkpoint=8) + assert not status.yielded + np.testing.assert_array_equal(sim.x.to_numpy(), np.full(N, 11, dtype=np.int32)) + + +# Module-level kernel for the fastcache-restoration test below. Lives outside any test so the child subprocess can +# import the test module and reach it without re-creating the (closure-captured) outer scope. The kernel has to be +# annotated with `fastcache=True` (=> implies `pure`) and lifted out of any decorator-bound owner so it qualifies for +# the src_ll_cache path. We model the data_oriented owner as the `_FastcacheYieldOnSelfCheckpoint` class below. + + +@qd.data_oriented +class _FastcacheYieldOnSelfCheckpoint: + def __init__(self, n: int): + self.x = qd.ndarray(qd.i32, shape=(n,)) + self.flag = qd.ndarray(qd.i32, shape=()) + + @qd.kernel(graph=True, checkpoints=True, fastcache=True) + def step(self): + with qd.checkpoint(0, yield_on=self.flag): + for i in range(self.x.shape[0]): + self.x[i] = self.x[i] + 1 + + +class _FastcacheCheckpointArgs(pydantic.BaseModel): + arch: str + offline_cache_file_path: str + expect_loaded_from_fastcache: bool + + +def _fastcache_checkpoint_child(args: list[str]) -> None: + args_obj = _FastcacheCheckpointArgs.model_validate_json(args[0]) + qd.init( + arch=getattr(qd, args_obj.arch), + offline_cache=True, + offline_cache_file_path=args_obj.offline_cache_file_path, + src_ll_cache=True, + ) + + N = 8 + sim = _FastcacheYieldOnSelfCheckpoint(N) + sim.x.from_numpy(np.zeros(N, dtype=np.int32)) + sim.flag.from_numpy(np.array(0, dtype=np.int32)) + sim.step() + np.testing.assert_array_equal(sim.x.to_numpy(), np.ones(N, dtype=np.int32)) + + primal = type(sim).step._primal + # The schema-v3 
fast-cache restore path must repopulate `checkpoint_yield_on_args` and + # `checkpoint_yield_on_cpp_arg_ids` from the cached `CacheValue` (since AST transformation is skipped on a cache + # hit). A regression here would turn the `forward_yield_on_table_to_ctx` call into a no-op, silently breaking + # yield/resume on fast-cached checkpoint kernels. + labels = primal.checkpoint_yield_on_args + cpp_ids = primal.checkpoint_yield_on_cpp_arg_ids + assert ( + labels and len(labels) == 1 and labels[0] is not None and "flag" in labels[0] + ), f"checkpoint_yield_on_args should round-trip with one slot containing 'flag', got {labels!r}" + assert ( + len(cpp_ids) == 1 and cpp_ids[0] >= 0 + ), f"checkpoint_yield_on_cpp_arg_ids should round-trip with one valid id, got {cpp_ids!r}" + assert primal.checkpoint_user_labels_by_cp_id == [ + 0 + ], f"checkpoint_user_labels_by_cp_id should round-trip as [0], got {primal.checkpoint_user_labels_by_cp_id!r}" + assert primal.src_ll_cache_observations.cache_loaded == args_obj.expect_loaded_from_fastcache, ( + f"cache_loaded={primal.src_ll_cache_observations.cache_loaded!r} but expected " + f"{args_obj.expect_loaded_from_fastcache!r}" + ) + + print(TEST_RAN) + sys.exit(RET_SUCCESS) + + +@test_utils.test() +def test_checkpoint_fastcache_restores_self_member_yield_on(tmp_path: pathlib.Path): + """After a fast-cache restore in a fresh process, a `@qd.kernel(graph=True, checkpoints=True, fastcache=True)` + kernel with `yield_on=self.flag` must repopulate `checkpoint_yield_on_args` / `checkpoint_yield_on_cpp_arg_ids` / + `checkpoint_user_labels_by_cp_id` from the persisted ``CacheValue`` -- not from the AST transformer, which is + skipped on a cache hit.
Without the schema-v3 round-trip the launch path's `forward_yield_on_table_to_ctx` would + be a no-op and yield/resume would silently break for fast-cached checkpoint kernels.""" + assert qd.lang is not None + arch = qd.lang.impl.current_cfg().arch.name + env = dict(os.environ) + env["PYTHONPATH"] = "." + + for expect_loaded in [False, True]: + args_obj = _FastcacheCheckpointArgs( + arch=arch, + offline_cache_file_path=str(tmp_path / "cache"), + expect_loaded_from_fastcache=expect_loaded, + ) + cmd_line = [sys.executable, __file__, _fastcache_checkpoint_child.__name__, args_obj.model_dump_json()] + proc = subprocess.run(cmd_line, capture_output=True, text=True, env=env) + if proc.returncode != RET_SUCCESS: + print(" ".join(cmd_line)) + print(proc.stdout) + print("-" * 100) + print(proc.stderr) + assert TEST_RAN in proc.stdout + assert proc.returncode == RET_SUCCESS + + +# Module-level IntEnum so `_resolve_intenum_member` can find it via importlib from the persisted qualname +# (`tests.python.test_checkpoint._FastcacheStage.LOAD`). The kernel below uses it as the cp_id so the fast-cache +# round-trip exercises the schema-v4 enum-identity preservation path. 
+class _FastcacheStage(IntEnum): + LOAD = 10 + REDUCE = 20 + + +@qd.kernel(graph=True, checkpoints=True, fastcache=True) +def _fastcache_intenum_kernel( + x: qd.types.ndarray(qd.i32, ndim=1), + flag: qd.types.ndarray(qd.i32, ndim=0), +): + with qd.checkpoint(_FastcacheStage.LOAD, yield_on=flag): + for i in range(x.shape[0]): + x[i] = x[i] + 1 + with qd.checkpoint(_FastcacheStage.REDUCE, yield_on=flag): + for i in range(x.shape[0]): + x[i] = x[i] + 10 + + +def _fastcache_intenum_child(args: list[str]) -> None: + args_obj = _FastcacheCheckpointArgs.model_validate_json(args[0]) + qd.init( + arch=getattr(qd, args_obj.arch), + offline_cache=True, + offline_cache_file_path=args_obj.offline_cache_file_path, + src_ll_cache=True, + ) + + N = 4 + x = qd.ndarray(qd.i32, shape=(N,)) + flag = qd.ndarray(qd.i32, shape=()) + x.from_numpy(np.zeros(N, dtype=np.int32)) + flag.from_numpy(np.array(0, dtype=np.int32)) + _fastcache_intenum_kernel(x, flag) + np.testing.assert_array_equal(x.to_numpy(), np.full(N, 11, dtype=np.int32)) + + primal = _fastcache_intenum_kernel._primal + labels = primal.checkpoint_user_labels_by_cp_id + # The schema-v4 round-trip must rebuild the IntEnum identity, not just the int equality. A regression here would + # show up as `labels == [10, 20]` (plain ints) breaking the documented contract that + # `qd.checkpoint(Stage.X, ...)` surfaces as `Stage.X` (not the raw int) on `status.checkpoint`. 
+ assert labels == [ + _FastcacheStage.LOAD, + _FastcacheStage.REDUCE, + ], f"checkpoint_user_labels_by_cp_id should round-trip with IntEnum identity, got {labels!r}" + assert all( + isinstance(lbl, _FastcacheStage) for lbl in labels + ), f"every label slot must be a _FastcacheStage instance, got {[type(lbl).__name__ for lbl in labels]!r}" + assert primal.src_ll_cache_observations.cache_loaded == args_obj.expect_loaded_from_fastcache + + print(TEST_RAN) + sys.exit(RET_SUCCESS) + + +@test_utils.test() +def test_checkpoint_fastcache_preserves_intenum_label_identity(tmp_path: pathlib.Path): + """Fast-cache restore must rebuild ``checkpoint_user_labels_by_cp_id`` with the original ``IntEnum`` members, not + just int-equal plain ints. Schema v4 adds a parallel ``checkpoint_user_label_enum_qualnames`` column so + ``_resolve_intenum_member`` can re-import the enum class on cache hit -- pydantic coerces ``IntEnum`` to ``int`` at + ``CacheValue`` construction, which would otherwise silently drop enum identity and break the documented contract + that ``qd.checkpoint(Stage.X, ...)`` surfaces as ``Stage.X`` (not the raw int) on ``status.checkpoint`` after a + fast-cache hit.""" + assert qd.lang is not None + arch = qd.lang.impl.current_cfg().arch.name + env = dict(os.environ) + env["PYTHONPATH"] = "." 
+ + for expect_loaded in [False, True]: + args_obj = _FastcacheCheckpointArgs( + arch=arch, + offline_cache_file_path=str(tmp_path / "cache"), + expect_loaded_from_fastcache=expect_loaded, + ) + cmd_line = [sys.executable, __file__, _fastcache_intenum_child.__name__, args_obj.model_dump_json()] + proc = subprocess.run(cmd_line, capture_output=True, text=True, env=env) + if proc.returncode != RET_SUCCESS: + print(" ".join(cmd_line)) + print(proc.stdout) + print("-" * 100) + print(proc.stderr) + assert TEST_RAN in proc.stdout + assert proc.returncode == RET_SUCCESS + + # ---------------------------------------------------------------------------------------------------------------------- # CUDA-native introspection (slice 1c). # ---------------------------------------------------------------------------------------------------------------------- @@ -877,3 +1285,11 @@ def k(x: qd.types.ndarray(qd.i32, ndim=1), flag: qd.types.ndarray(qd.i32, ndim=0 assert ( _num_checkpoints_on_last_call() == 3 ), f"expected 3 IF conditional nodes (2 implicit + 1 explicit), got {_num_checkpoints_on_last_call()}" + + +# Subprocess dispatch for fast-cache restoration tests above (mirrors the pattern in `test_graph_do_while.py`). The +# parent test invokes us via `subprocess.run([sys.executable, __file__, child_fn_name, args_json])` so the child +# runs in a fresh interpreter with a clean `qd.init` -- the only way to exercise the cross-process fast-cache load +# path that ``Kernel._try_load_fastcache`` takes after a previous run has populated the on-disk cache.
+if __name__ == "__main__": + globals()[sys.argv[1]](sys.argv[2:]) diff --git a/tests/python/test_graph_do_while.py b/tests/python/test_graph_do_while.py index 2a7ba506b7..e5859cd730 100644 --- a/tests/python/test_graph_do_while.py +++ b/tests/python/test_graph_do_while.py @@ -320,10 +320,161 @@ def k(x: qd.types.ndarray(qd.i32, ndim=1), c: qd.types.ndarray(qd.i32, ndim=0)): x = qd.ndarray(qd.i32, shape=(4,)) c = qd.ndarray(qd.i32, shape=()) c.from_numpy(np.array(1, dtype=np.int32)) - with pytest.raises(qd.QuadrantsSyntaxError, match="does not match any parameter"): + with pytest.raises(qd.QuadrantsSyntaxError, match="does not resolve to an ndarray kernel parameter"): k(x, c) +@test_utils.test() +def test_graph_do_while_with_dataclass_member_counter(): + """`qd.graph_do_while(params.counter)` for a ``@dataclasses.dataclass`` kernel parameter takes the same + AST-build-time resolution path as the ``self.counter`` form -- the loop drives entirely from the device-side + counter just like the bare-parameter case.""" + import dataclasses # pylint: disable=import-outside-toplevel + + N = 4 + + @dataclasses.dataclass + class Params: + x: qd.types.NDArray[qd.i32, 1] + counter: qd.types.NDArray[qd.i32, 0] + + @qd.kernel(graph=True) + def step(params: Params): + while qd.graph_do_while(params.counter): + for i in range(params.x.shape[0]): + params.x[i] = params.x[i] + 1 + for _ in range(1): + params.counter[()] = params.counter[()] - 1 + + params = Params( + x=qd.ndarray(qd.i32, shape=(N,)), + counter=qd.ndarray(qd.i32, shape=()), + ) + params.x.from_numpy(np.zeros(N, dtype=np.int32)) + params.counter.from_numpy(np.array(3, dtype=np.int32)) + step(params) + np.testing.assert_array_equal(params.x.to_numpy(), np.full(N, 3, dtype=np.int32)) + assert params.counter.to_numpy() == 0 + levels = step._primal.graph_do_while_levels + assert len(levels) == 1 + # Dataclass-parameter member access gets pre-rewritten to a flattened parameter name (`__qd_params__qd_counter`) + # before the 
graph_do_while transformer sees it, so the readable label round-trips in the flattened form. The + # functional contract -- a valid flat C++ arg-id resolves and the loop drives off the device-side counter -- is the + # same as for the bare-param / `self.counter` forms. + assert "counter" in levels[0].cond_arg_name + assert levels[0].cond_cpp_arg_id >= 0 + + +@test_utils.test() +def test_graph_do_while_with_member_nonexistent_attribute_raises(): + """`qd.graph_do_while(self.nonexistent_attr)` must raise the same user-facing `does not resolve to an ndarray kernel + parameter` diagnostic as the bare-name nonexistent case. The AST-time resolver wraps the underlying attribute lookup + failure so the user sees one consistent error pattern across bare-name and attribute forms.""" + N = 4 + + @qd.data_oriented + class Sim: + def __init__(self): + self.x = qd.ndarray(qd.i32, shape=(N,)) + self.counter = qd.ndarray(qd.i32, shape=()) + + @qd.kernel(graph=True) + def step(self): + while qd.graph_do_while(self.nonexistent_counter): # type: ignore[attr-defined] + for i in range(self.x.shape[0]): + self.x[i] = self.x[i] + 1 + + sim = Sim() + with pytest.raises(qd.QuadrantsSyntaxError, match="does not resolve to an ndarray kernel parameter"): + sim.step() + + +@test_utils.test() +def test_graph_do_while_with_data_oriented_member_nested(): + """Nested `qd.graph_do_while(self.outer)` containing `qd.graph_do_while(self.inner)` exercises the level-table + machinery with member ndarrays: each level resolves its own flat C++ arg-id at AST-build time, the parent_id chain + links inner -> outer, and the loop body iterates `outer_iters * inner_iters` times the same as the bare-parameter + version (see `test_graph_do_while_nested_two_levels`).""" + if not _is_graph_do_while_natively_supported() and not ( + impl.current_cfg().arch in (qd.x64, qd.arm64, qd.amdgpu, qd.vulkan, qd.metal) + ): + pytest.skip("backend does not implement graph_do_while") + N = 4 + + @qd.data_oriented + class Sim: + 
def __init__(self): + self.x = qd.ndarray(qd.i32, shape=(N,)) + self.outer = qd.ndarray(qd.i32, shape=()) + self.inner = qd.ndarray(qd.i32, shape=()) + self.inner_start = qd.ndarray(qd.i32, shape=()) + + @qd.kernel(graph=True) + def step(self): + while qd.graph_do_while(self.outer): + for _ in range(1): + self.inner[()] = self.inner_start[()] + while qd.graph_do_while(self.inner): + for i in range(self.x.shape[0]): + self.x[i] = self.x[i] + 1 + for _ in range(1): + self.inner[()] = self.inner[()] - 1 + for _ in range(1): + self.outer[()] = self.outer[()] - 1 + + sim = Sim() + sim.x.from_numpy(np.zeros(N, dtype=np.int32)) + sim.outer.from_numpy(np.array(3, dtype=np.int32)) + sim.inner.from_numpy(np.array(2, dtype=np.int32)) + sim.inner_start.from_numpy(np.array(2, dtype=np.int32)) + sim.step() + np.testing.assert_array_equal(sim.x.to_numpy(), np.full(N, 6, dtype=np.int32)) + assert sim.outer.to_numpy() == 0 + levels = sim.step._primal.graph_do_while_levels + assert len(levels) == 2 + assert levels[0].cond_arg_name == "self.outer" + assert levels[1].cond_arg_name == "self.inner" + assert levels[0].parent_id == -1 + assert levels[1].parent_id == 0 + assert levels[0].cond_cpp_arg_id >= 0 + assert levels[1].cond_cpp_arg_id >= 0 + assert levels[0].cond_cpp_arg_id != levels[1].cond_cpp_arg_id + + +@test_utils.test() +def test_graph_do_while_with_data_oriented_member_counter(): + """`qd.graph_do_while(self.counter)` resolves the member ndarray to the loop condition's flat C++ arg-id at + AST-build time via ``ASTTransformer._resolve_ndarray_kernel_arg_id``, lifting the previous bare-parameter + restriction. 
The metadata exposed on the kernel records the readable label (``"self.counter"``) plus the resolved + arg-id; the loop behaviour matches the bare-parameter form below.""" + N = 4 + + @qd.data_oriented + class Sim: + def __init__(self): + self.x = qd.ndarray(qd.i32, shape=(N,)) + self.counter = qd.ndarray(qd.i32, shape=()) + + @qd.kernel(graph=True) + def step(self): + while qd.graph_do_while(self.counter): + for i in range(self.x.shape[0]): + self.x[i] = self.x[i] + 1 + for _ in range(1): + self.counter[()] = self.counter[()] - 1 + + sim = Sim() + sim.x.from_numpy(np.zeros(N, dtype=np.int32)) + sim.counter.from_numpy(np.array(3, dtype=np.int32)) + sim.step() + np.testing.assert_array_equal(sim.x.to_numpy(), np.full(N, 3, dtype=np.int32)) + assert sim.counter.to_numpy() == 0 + levels = sim.step._primal.graph_do_while_levels + assert len(levels) == 1 + assert levels[0].cond_arg_name == "self.counter" + assert levels[0].cond_cpp_arg_id >= 0 + + @qd.kernel(graph=True, fastcache=True) def _fastcache_do_while_kernel(x: qd.types.ndarray(qd.i32, ndim=1), counter: qd.types.ndarray(qd.i32, ndim=0)): while qd.graph_do_while(counter):