Skip to content
4 changes: 2 additions & 2 deletions docs/source/user_guide/graph.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ solve(x, counter)
# x is now incremented 10 times; counter is 0
```

The argument to `qd.graph_do_while()` must be the name of a scalar `qd.i32` ndarray parameter. The loop body repeats while this value is non-zero.
The argument to `qd.graph_do_while()` must reference a scalar `qd.i32` ndarray that the kernel can access — a bare kernel parameter (`qd.graph_do_while(counter)`), a `@qd.data_oriented` member ndarray (`qd.graph_do_while(self.counter)`), or a `@dataclasses.dataclass` parameter member (`qd.graph_do_while(params.counter)`). The loop body repeats while this value is non-zero.

- On CUDA SM 9.0+ (Hopper), this uses CUDA conditional while nodes — the entire iteration runs on the GPU with no host involvement.
- On older CUDA GPUs, AMDGPU, and non-GPU backends, it falls back to a host-side do-while loop (see the [backend support table](#backend-support)).
Expand Down Expand Up @@ -256,7 +256,7 @@ while status.yielded:

- Must be used inside `@qd.kernel(graph=True, checkpoints=True)`. Without the flag, `qd.checkpoint(...)` raises `QuadrantsSyntaxError` at compile time with a fix-it pointing at `checkpoints=True`.
- `cp_id` must be an int literal or an `IntEnum` value, and must be unique across the kernel.
- `yield_on=` must be a kernel parameter that is a 0-d `qd.types.ndarray(qd.i32, ndim=0)`; expressions are not supported.
- `yield_on=` must reference a 0-d `qd.types.ndarray(qd.i32, ndim=0)` — a bare kernel parameter (`yield_on=flag`), a `@qd.data_oriented` member ndarray (`yield_on=self.flag`), or a `@dataclasses.dataclass` parameter member (`yield_on=params.flag`). Arbitrary expressions are not supported.
- Checkpoints cannot be nested inside other checkpoints. Checkpoints inside a `qd.graph_do_while` body are fine.
- The body of a `with qd.checkpoint(...)` block cannot contain bare top-level statements (assignments, augmented assignments, or bare call/expression statements). Every top-level statement must be inside a `for`-loop (or other control-flow construct). A docstring as the first statement is allowed. Bare statements raise `QuadrantsSyntaxError` at compile time with a fix-it pointing at the explicit one-iteration `for`-wrap:

Expand Down
138 changes: 119 additions & 19 deletions python/quadrants/lang/_fast_caching/src_hasher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import importlib
import json
import os
import warnings
from enum import IntEnum
from typing import Any, Iterable, Sequence

import pydantic
Expand All @@ -18,8 +20,77 @@
from .python_side_cache import PythonSideCache

# Bumped whenever the persisted CacheValue schema changes (see create_cache_key). v2 replaced the single
# graph_do_while_arg string with a nested level table.
_CACHE_VALUE_SCHEMA_VERSION = "cachevalue-v2-gdw-levels"
# graph_do_while_arg string with a nested level table. v3 added the AST-resolved flat C++ arg-ids for qd.graph_do_while
# conditions and qd.checkpoint(yield_on=...) targets so the launch path can forward them directly without per-launch
# name matching (necessary for @qd.data_oriented member ndarrays). v4 added the per-slot
# `checkpoint_user_label_enum_qualnames` table so an IntEnum cp_id (e.g. `qd.checkpoint(Stage.SIM, ...)`) round-trips
# through fast-cache restore as the original IntEnum member rather than the underlying int.
_CACHE_VALUE_SCHEMA_VERSION = "cachevalue-v4-intenum-qualnames"


def _intenum_member_qualname(value: Any) -> str | None:
"""Return ``"module.ClassQualName.MEMBER"`` for an ``IntEnum`` member, else ``None``.

Stored alongside ``checkpoint_user_labels_by_cp_id`` so that ``_resolve_intenum_member`` can rebuild the original
enum member on fast-cache restore -- pydantic coerces ``IntEnum`` to plain ``int`` at ``CacheValue`` construction
time (it sees ``list[int | None]``), which would otherwise silently break the documented contract that
``qd.checkpoint(Stage.X, ...)`` round-trips ``Stage.X`` rather than the raw int through ``status.checkpoint``.
Returns ``None`` for plain ints, ``None`` labels, anonymous enums (no ``__module__``), and other unsupported
shapes -- the loader falls back to the raw int in those cases.
"""
if not isinstance(value, IntEnum):
return None
cls = type(value)
module = getattr(cls, "__module__", None)
qualname = getattr(cls, "__qualname__", None)
name = getattr(value, "name", None)
if not module or not qualname or not name:
return None
return f"{module}.{qualname}.{name}"


def _resolve_intenum_member(qualname: str | None, fallback: int | None) -> int | IntEnum | None:
"""Inverse of ``_intenum_member_qualname``: look up the enum member by ``"module.ClassQualName.MEMBER"``.

Returns the resolved ``IntEnum`` member if every step (module import, attribute walk) succeeds AND the member's int
value matches ``fallback`` (the raw int from ``checkpoint_user_labels_by_cp_id`` we already persisted). Mismatch or
any failure -- module renamed since the cache was written, enum class refactored, member removed, etc. -- falls back
to ``fallback`` so the user still gets a usable (if enum-identity-less) label rather than a hard crash. ``None``
qualname / ``None`` fallback short-circuit to ``fallback`` for the plain-int label case.
"""
if qualname is None or fallback is None:
return fallback
try:
# qualname is "module.path.Class[.Nested].MEMBER"; the MEMBER tail is always one segment, so rsplit once. The
# remaining cls_path mixes dotted module path + dotted class qualname; we try progressively shorter module
# prefixes until one imports, then resolve the rest as attribute chain. This handles top-level enums
# (``mymod.Stage.LOAD``), enums nested in classes (``mymod.Outer.Inner.MEMBER``), and enums in subpackages
# (``a.b.Stage.LOAD``) without needing the user to declare which prefix is the module.
cls_path, _, member_name = qualname.rpartition(".")
if not cls_path or not member_name:
return fallback
module = None
cls_attr_path = ""
segments = cls_path.split(".")
for i in range(len(segments), 0, -1):
try:
module = importlib.import_module(".".join(segments[:i]))
cls_attr_path = ".".join(segments[i:])
break
except ImportError:
continue
if module is None:
return fallback
obj: Any = module
if cls_attr_path:
for seg in cls_attr_path.split("."):
obj = getattr(obj, seg)
obj = getattr(obj, member_name)
except (AttributeError, ValueError):
return fallback
if isinstance(obj, IntEnum) and int(obj) == int(fallback):
return obj
return fallback


def create_cache_key(
Expand Down Expand Up @@ -69,18 +140,45 @@ class CacheValue(BaseModel):
frontend_cache_key: str
hashed_function_source_infos: list[HashedFunctionSourceInfo]
used_py_dataclass_parameters: set[str]
# Nested graph_do_while level table as (cond_arg_name, parent_id) pairs, indexed by level id. None / empty for
# kernels without graph_do_while.
graph_do_while_levels: list[tuple[str, int]] | None = None
# Nested graph_do_while level table as (cond_arg_name, parent_id, cond_cpp_arg_id) triples, indexed by level id.
# None / empty for kernels without graph_do_while. ``cond_cpp_arg_id`` is the flat C++ arg-id resolved at AST-build
# time by ``ASTTransformer._resolve_ndarray_kernel_arg_id`` and is required by the launch path to support
# `@qd.data_oriented` member conditions (`qd.graph_do_while(self.counter)`) -- name-matching against ``arg_metas``
# only resolves top-level parameters.
graph_do_while_levels: list[tuple[str, int, int]] | None = None
# AST-build-time-resolved checkpoint metadata, indexed by internal cp_id. Empty for kernels without any
# `with qd.checkpoint(...)` block. See `Kernel.checkpoint_yield_on_args` /
# `Kernel.checkpoint_yield_on_cpp_arg_ids` / `Kernel.checkpoint_user_labels_by_cp_id` for what each entry means.
# Restored alongside the C++-side cached kernel so the launch path can forward `yield_on=` arg-ids and translate
# `from_checkpoint=` labels without re-running the AST transformer.
checkpoint_yield_on_args: list[str | None] = []
checkpoint_yield_on_cpp_arg_ids: list[int] = []
checkpoint_user_labels_by_cp_id: list[int | None] = []

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve IntEnum labels when restoring fast-cache checkpoints

When a checkpoint uses an IntEnum cp_id and the kernel is restored from src_ll_cache, this new cache field is serialized through JSON as plain integers and _try_load_fastcache restores checkpoint_user_labels_by_cp_id as [1] rather than [Stage.LOAD]. maybe_build_graph_status() then returns the raw int for status.checkpoint on cache hits, breaking the documented/API contract that qd.checkpoint(Stage.X, ...) round-trips the enum value rather than the underlying int. Persist enough enum metadata/expression information to reconstruct the label, or avoid lossy restoration for enum labels.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch -- confirmed and fixed in 58bba23.

Pydantic coerces IntEnum to int at CacheValue.__init__ time (the field is typed list[int | None]), so just persisting the int column was lossy even before JSON. Schema bumped to cachevalue-v4-intenum-qualnames, which adds a parallel checkpoint_user_label_enum_qualnames column. src_hasher.store derives the per-slot module.ClassQualName.MEMBER string from the live label list (still holding the original IntEnum instances at store time, before pydantic strips identity), and _resolve_intenum_member re-imports the enum class via importlib on load. Mismatch / failed import (enum moved or renamed since the cache was written) falls back to the persisted int rather than raising, so stale caches degrade gracefully.

Tests:

  • test_checkpoint_fastcache_preserves_intenum_label_identity -- subprocess cache miss + hit, asserts isinstance(label, _FastcacheStage) after restore (not just int equality).
  • test_src_hasher_intenum_qualname_round_trip -- direct CacheValue unit test for mixed IntEnum / None / plain-int slots, qualname derivation, and the resolver fallback.

Both pass on x64 and CUDA on the cluster. Older v3 caches just invalidate via the version bump, no migration needed.

# Parallel to ``checkpoint_user_labels_by_cp_id``: each entry is the dotted ``module.ClassQualName.MEMBER`` of the
# original ``IntEnum`` member the user passed as ``cp_id``, or ``None`` if the user passed a plain int (or for
# implicit auto-wrap checkpoints). On fast-cache restore the loader runs each entry through
# ``_resolve_intenum_member`` to rebuild the IntEnum, preserving the documented contract that
# ``qd.checkpoint(Stage.X, ...)`` round-trips ``Stage.X`` (not the underlying int) through ``status.checkpoint`` and
# ``kernel.resume(from_checkpoint=...)`` -- pydantic coerces IntEnum to int at ``CacheValue`` construction time so
# the parallel qualname column is what carries the enum identity.
checkpoint_user_label_enum_qualnames: list[str | None] = []


def store(
frontend_cache_key: str,
fast_cache_key: str,
function_source_infos: Iterable[FunctionSourceInfo],
used_py_dataclass_parameters: set[str],
graph_do_while_levels: list[tuple[str, int]] | None = None,
graph_do_while_levels: list[tuple[str, int, int]] | None = None,
checkpoint_yield_on_args: list[str | None] | None = None,
checkpoint_yield_on_cpp_arg_ids: list[int] | None = None,
checkpoint_user_labels_by_cp_id: list[int | None] | None = None,
) -> None:
# `checkpoint_user_label_enum_qualnames` is derived from `checkpoint_user_labels_by_cp_id` here (rather than being
# plumbed through a separate kwarg from `Kernel.materialize`) so callers never have to think about the parallel
# column: they pass the live label list (which still holds the original ``IntEnum`` instances at store time, before
# pydantic's int-coercion strips identity in ``CacheValue.__init__``), and the qualname snapshot is recorded once
# here for the loader to consume.
"""
Note that unlike other caches, this cache is not going to store the actual value we want.
This cache is only used for verification that our cache key is valid. Big picture:
Expand All @@ -103,11 +201,17 @@ def store(
assert frontend_cache_key is not None
cache = PythonSideCache()
hashed_function_source_infos = function_hasher.hash_functions(function_source_infos)
labels = checkpoint_user_labels_by_cp_id or []
enum_qualnames = [_intenum_member_qualname(lbl) for lbl in labels]
cache_value_obj = CacheValue(
frontend_cache_key=frontend_cache_key,
hashed_function_source_infos=list(hashed_function_source_infos),
used_py_dataclass_parameters=used_py_dataclass_parameters,
graph_do_while_levels=graph_do_while_levels,
checkpoint_yield_on_args=checkpoint_yield_on_args or [],
checkpoint_yield_on_cpp_arg_ids=checkpoint_yield_on_cpp_arg_ids or [],
checkpoint_user_labels_by_cp_id=labels,
checkpoint_user_label_enum_qualnames=enum_qualnames,
)
cache.store(fast_cache_key, cache_value_obj.model_dump_json())

Expand All @@ -125,23 +229,19 @@ def _try_load(cache_key: str) -> CacheValue | None:
return cache_value_obj


def load(
cache_key: str,
) -> tuple[set[str], str, list[tuple[str, int]] | None] | tuple[None, None, None]:
"""
loads function source infos from cache, if available
checks the hashes against the current source code
def load(cache_key: str) -> CacheValue | None:
"""Load a validated ``CacheValue`` for *cache_key* if one exists and its source hashes still match, else None.

Returns the full ``CacheValue`` (rather than the historical 3-tuple) so callers can pick off the AST-transformer-
produced metadata (graph_do_while levels, checkpoint tables) without the loader having to grow a new return slot
every time we cache a new piece of AST output.
"""
cache_value = _try_load(cache_key)
if cache_value is None:
return None, None, None
return None
if function_hasher.validate_hashed_function_infos(cache_value.hashed_function_source_infos):
return (
cache_value.used_py_dataclass_parameters,
cache_value.frontend_cache_key,
cache_value.graph_do_while_levels,
)
return None, None, None
return cache_value
return None


def dump_stats() -> None:
Expand Down
63 changes: 44 additions & 19 deletions python/quadrants/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,19 +1352,44 @@ def build_For(ctx: ASTTransformerFuncContext, node: ast.For) -> None:
return ASTTransformer.build_struct_for(ctx, node, is_grouped=False)

@staticmethod
def _is_graph_do_while_call(node: ast.expr) -> str | None:
"""If *node* is ``qd.graph_do_while(var)`` return the arg name, else None."""
def _is_graph_do_while_call(node: ast.expr) -> ast.expr | None:
"""If *node* is ``qd.graph_do_while(arg)`` return the arg AST node, else None.

``arg`` may be an ``ast.Name`` (a bare kernel parameter, e.g. ``counter``) or an ``ast.Attribute`` chain (a
``@qd.data_oriented`` member ndarray such as ``self.counter`` or a ``@dataclasses.dataclass`` parameter member
such as ``params.counter``). The actual resolution to a kernel ndarray argument happens in ``build_While`` via
``_resolve_ndarray_kernel_arg_id``.
"""
if not isinstance(node, ast.Call):
return None
func = node.func
if isinstance(func, ast.Attribute) and func.attr == "graph_do_while":
if len(node.args) == 1 and isinstance(node.args[0], ast.Name):
return node.args[0].id
if isinstance(func, ast.Name) and func.id == "graph_do_while":
if len(node.args) == 1 and isinstance(node.args[0], ast.Name):
return node.args[0].id
is_gdw = (isinstance(func, ast.Attribute) and func.attr == "graph_do_while") or (
isinstance(func, ast.Name) and func.id == "graph_do_while"
)
if not is_gdw:
return None
if len(node.args) == 1 and isinstance(node.args[0], (ast.Name, ast.Attribute)):
return node.args[0]
return None

@staticmethod
def _resolve_ndarray_kernel_arg_id(
ctx: ASTTransformerFuncContext,
kernel,
node: ast.expr,
usage: str,
) -> tuple[str, int]:
"""Thin forwarding wrapper around ``ndarray_arg_resolver.resolve_ndarray_kernel_arg_id``; the actual logic lives
in module ``ast_transformers/ndarray_arg_resolver.py`` to keep this file from growing per-feature (same pattern
as ``_is_checkpoint_call`` / ``CheckpointTransformer``). Returns ``(label, flat_cpp_arg_id)`` or raises
``QuadrantsSyntaxError``."""
# pylint: disable-next=C0415,import-outside-toplevel
from quadrants.lang.ast.ast_transformers.ndarray_arg_resolver import (
resolve_ndarray_kernel_arg_id,
)

return resolve_ndarray_kernel_arg_id(ctx, kernel, node, usage)

@staticmethod
def _is_checkpoint_call(node: ast.expr, global_vars: dict):
"""Thin forwarding wrapper around ``CheckpointTransformer.is_checkpoint_call``; the actual logic lives in module
Expand All @@ -1377,18 +1402,11 @@ def build_While(ctx: ASTTransformerFuncContext, node: ast.While) -> None:
if node.orelse:
raise QuadrantsSyntaxError("'else' clause for 'while' not supported in Quadrants kernels")

graph_do_while_arg = ASTTransformer._is_graph_do_while_call(node.test)
if graph_do_while_arg is not None:
graph_do_while_node = ASTTransformer._is_graph_do_while_call(node.test)
if graph_do_while_node is not None:
from quadrants.lang.kernel import GraphDoWhileLevel # pylint: disable=C0415

kernel = ctx.global_context.current_kernel
arg_names = [m.name for m in kernel.arg_metas]
if graph_do_while_arg not in arg_names:
raise QuadrantsSyntaxError(
f"qd.graph_do_while({graph_do_while_arg!r}) does not match any "
f"parameter of kernel {kernel.func.__name__!r}. "
f"Available parameters: {arg_names}"
)
if not kernel.use_graph:
raise QuadrantsSyntaxError("qd.graph_do_while() requires @qd.kernel(graph=True)")
# graph_do_while emits no loop IR; its body's for-loops must be top-level (offloaded) tasks. So it may only
Expand All @@ -1399,15 +1417,22 @@ def build_While(ctx: ASTTransformerFuncContext, node: ast.While) -> None:
"qd.graph_do_while() must be at the kernel top level or directly nested inside "
"another qd.graph_do_while(); it cannot appear inside a for-loop."
)
# Resolve the condition ndarray (bare parameter or @qd.data_oriented member) to its flat C++ arg-id at AST-
# build time -- the same id the runtime needs -- so the launch path forwards it directly with no per-launch
# name matching. ``cond_arg_name`` keeps the readable label (e.g. "counter" or "self.counter") for
# introspection and for the legacy ``graph_do_while_arg`` alias surfaced on Kernel.
cond_label, cond_cpp_arg_id = ASTTransformer._resolve_ndarray_kernel_arg_id(
ctx, kernel, graph_do_while_node, "qd.graph_do_while(...)"
)
# Register this loop as a new nesting level (the body restriction is validated up-front in
# FunctionDefTransformer). Outer loops get lower ids than the inner loops they contain.
parent_id = kernel._graph_do_while_level_stack[-1] if kernel._graph_do_while_level_stack else -1
level_id = len(kernel.graph_do_while_levels)
kernel.graph_do_while_levels.append(
GraphDoWhileLevel(cond_arg_name=graph_do_while_arg, parent_id=parent_id)
GraphDoWhileLevel(cond_arg_name=cond_label, parent_id=parent_id, cond_cpp_arg_id=cond_cpp_arg_id)
)
if level_id == 0:
kernel.graph_do_while_arg = graph_do_while_arg
kernel.graph_do_while_arg = cond_label
kernel._graph_do_while_level_stack.append(level_id)
ctx.ast_builder.set_graph_do_while_level_id(level_id)
try:
Expand Down
Loading
Loading