diff --git a/docs/source/user_guide/compound_types.md b/docs/source/user_guide/compound_types.md index e2d000bcbb..87a3286aed 100644 --- a/docs/source/user_guide/compound_types.md +++ b/docs/source/user_guide/compound_types.md @@ -169,6 +169,31 @@ sim.step() `@qd.data_oriented` objects can also be passed as `qd.Template` parameters to kernels defined outside the class, and they support nesting (one `@qd.data_oriented` struct containing another). +### stable_members + +**Recommended for any `@qd.data_oriented` class whose ndarray members are allocated once (typically in `__init__`) and not subsequently rebound — the common case.** Decorate with `stable_members=True`: + +```python +@qd.data_oriented(stable_members=True) +class Simulation: + def __init__(self, n): + self.x = qd.ndarray(qd.f32, shape=(n,)) + self.v = qd.ndarray(qd.f32, shape=(n,)) + # ... more ndarray / field / primitive members +``` + +This skips a per-call walk that Quadrants otherwise runs to detect ndarray member rebinding between kernel launches. The walk is O(number of ndarray members) per kernel call, so the savings scale with the container's size. + +Microbenchmark on an RTX PRO 6000 Blackwell with a container holding 30 `qd.ndarray` members across two nesting levels, calling a trivial kernel that takes the container as a `qd.template()` arg: + +| | Per-launch Python overhead | +|---|---| +| `stable_members=False` (default) | 18.5 µs/call | +| `stable_members=True` | 13.5 µs/call | +| | **−5 µs/call (−28%)** | + +**Trade-off:** with `stable_members=True`, reassigning an ndarray member on an instance is undefined behavior — the previously compiled kernel will be reused even if the new ndarray has a different `dtype`, `ndim`, or layout, silently bit-reinterpreting the new array's storage. Set it only on classes whose ndarray members are allocated once (typically in `__init__`) and never rebound. See [Reassigning ndarray members](#reassigning-ndarray-members) below for the supported alternative. + ### Primitive members Primitive members on `self` (e.g. `int`, `float`, `bool`, `enum.Enum`) are supported, but they are treated as **template values**: each distinct primitive value across instances triggers a new kernel compilation, with the value baked into the kernel IR. @@ -318,6 +343,8 @@ Practical consequence: For `@qd.data_oriented` containers passed via `qd.Template`, reassigning an ndarray member between kernel launches is supported, including changes to `dtype`, `ndim`, or layout. A new specialised kernel is compiled and cached for the new shape; subsequent launches with the original shape continue to use the original cached kernel. (For `@dataclasses.dataclass` containers — passed via the dataclass-type annotation — the member binding follows the standard dataclass mutability rules: frozen dataclasses can't rebind, non-frozen ones can, and a rebind triggers a fresh kernel arg setup on the next launch.) +This support is only available on `@qd.data_oriented` classes *without* the [`stable_members=True`](#stable_members) opt-in. Setting `stable_members=True` is a promise that ndarray members on instances of the class are never reassigned; if you break that promise the previously compiled kernel is silently reused against the new ndarray. + ### Restrictions - **`@qd.dataclass` cannot contain `qd.ndarray` or `qd.field` members.** See the [`@qd.dataclass`](#qddataclass-qdtypesstruct) section above for the full list of allowed member types. (The function-form factory `qd.types.struct(...)` has the same restrictions.) diff --git a/docs/source/user_guide/fastcache.md b/docs/source/user_guide/fastcache.md index ab07483e97..8c89fee53f 100644 --- a/docs/source/user_guide/fastcache.md +++ b/docs/source/user_guide/fastcache.md @@ -95,14 +95,17 @@ Fastcache supports the following parameter types: | `qd.types.NDArray` (scalar, vector, matrix) | Yes | dtype, ndim, layout | | `torch.Tensor` | Yes | dtype, ndim | | `numpy.ndarray` | Yes | dtype, ndim | -| `dataclasses.dataclass` | Yes | member types recursively; member values if annotated with `FIELD_METADATA_CACHE_VALUE` (see [Appendix — compound-type cache keying](#compound-type-cache-keying)) | -| `@qd.data_oriented` objects | Yes | member types recursively; primitive member types and values baked into kernel (see [Appendix — compound-type cache keying](#compound-type-cache-keying)) | +| `dataclasses.dataclass` | Yes | member types recursively (narrowed to members the kernel reads or writes); member values if annotated with `FIELD_METADATA_CACHE_VALUE` (see [Advanced — compound-type cache keying](#compound-type-cache-keying)) | +| `@qd.data_oriented` objects | Yes | member types recursively (narrowed to members the kernel reads or writes); primitive member types and values baked into kernel (see [Advanced — compound-type cache keying](#compound-type-cache-keying)) | | `qd.Template` primitives (int, float, bool) | Yes | type and value (baked into kernel) | | Non-template primitives (int, float, bool) | Yes | type only | | `enum.Enum` | Yes | name and value | -| `qd.field` / `ScalarField` / `MatrixField` | **No** | — | +| `qd.field` / `ScalarField` / `MatrixField` at a kernel-read path | **No** | — | +| Anything else at a kernel-read path | **No** | — | -If any parameter is of an unsupported type, fastcache is disabled for that call and the kernel falls back to normal compilation. For `qd.field` / `ScalarField` / `MatrixField` arriving through a `qd.Tensor`-annotated parameter, this is silent — no warning is emitted. For other unsupported types, a warning is logged at the `warn` level identifying the offending parameter. +If any kernel-used parameter is of an unsupported type, fastcache is disabled for that call and the kernel falls back to normal compilation. For `qd.field` / `ScalarField` / `MatrixField` arriving through a `qd.Tensor`-annotated parameter, this is silent — no warning is emitted. For other unsupported types, a warning is logged at the `warn` level identifying the offending parameter. + +Kernel-unused members of any type — including unrecognised ones — do **not** disable fastcache. Fastcache skips them entirely, so opaque metadata (UUIDs, Pydantic configs, parent back-pointers) attached to a `@qd.data_oriented` or `dataclasses.dataclass` instance is harmless as long as the kernel doesn't read it. ### 3. Source code must be available @@ -120,6 +123,12 @@ Each compiled artifact is stored under a key derived from all of the following: When any of these change, the resulting key is different, so a new compilation occurs and a new entry is stored. Previous entries remain on disk — multiple cached versions coexist. You do not need to manually clear the cache when making code changes — the hash mismatch causes a transparent recompilation. +### Two strict invariants + +1. **If the kernel does not read or write a variable, it is entirely ignored by fastcache.** It will not cause fastcache to fail, nor emit a warning, nor emit an error. + +2. **Unrecognised types at variables the kernel reads or writes must not be silently dropped or hashed by type-name.** If the value of such a variable has a type fastcache doesn't explicitly handle (Pydantic models, UUIDs, third-party tensor wrappers, …), fastcache is disabled for the call with a one-shot `[FASTCACHE][UNKNOWN_TYPE]` warning identifying the offending type plus an `[INVALID_FUNC]` log line confirming the cache is off. + ## Advanced ### Diagnostics @@ -143,32 +152,25 @@ print(obs.cache_stored) # True if the compiled kernel was stored to cach On the first run you'll see `cache_stored=True` but `cache_loaded=False`. On the second run (after `qd.init`), `cache_loaded=True`. -## Appendix - ### Compound-type cache keying -The args hasher walks compound-type kernel parameters recursively. For each leaf member it decides what (if anything) contributes to the cache key. The headline rules: +For `@qd.data_oriented` and `dataclasses.dataclass` kernel parameters, fastcache walks members recursively. Any members that are not themselves read or written by the kernel, nor contain members read or written by the kernel, are skipped during the walk (per the [strict invariants](#two-strict-invariants) above). Member-by-member behavior: -**`@qd.data_oriented`:** the walker descends into `vars(obj)`. For each child: +- **`qd.ndarray` member** — `(dtype, ndim, layout)` is included in the cache key. Element values are not. +- **Primitive (`int` / `float` / `bool` / `enum.Enum`) member.** The handling depends on the enclosing container: + - In a `@qd.data_oriented` instance — value is baked into the kernel, same as a `qd.Template` primitive. Two instances of the same class with different primitive member values get different cache entries. + - In a `dataclasses.dataclass` instance — only the type is included by default. To include the value too, annotate the field with `FIELD_METADATA_CACHE_VALUE`: -- `qd.ndarray` member — `(dtype, ndim, layout)` is included in the cache key. Element values are not. -- Primitive (`int` / `float` / `bool` / `enum.Enum`) member — value is baked into the kernel (same semantics as a `qd.Template` primitive). Two instances of the same class with different primitive member values get different cache entries. -- Nested `@qd.data_oriented` member — recurses. -- Nested `dataclasses.dataclass` member — recurses (with the dataclass rules below). -- `qd.field` member — fastcache is disabled for the entire kernel call. The kernel still runs via normal compilation; a warn-level log line is emitted. - -**`dataclasses.dataclass`:** the walker descends into the declared members. For each member, only the *type* is included in the cache key by default — **not** the value. To include a member's value, annotate it: - -```python -import dataclasses -from quadrants.lang._fast_caching import FIELD_METADATA_CACHE_VALUE - -@dataclasses.dataclass -class SimConfig: - num_layers: int = dataclasses.field(metadata={FIELD_METADATA_CACHE_VALUE: True}) - dt: float = dataclasses.field(metadata={FIELD_METADATA_CACHE_VALUE: True}) -``` + ```python + import dataclasses + from quadrants.lang._fast_caching import FIELD_METADATA_CACHE_VALUE -This is necessary whenever the compiled kernel depends on the member's *value* rather than just its type (for example, when the value is used as a loop bound that the compiler bakes into the generated code). Without the annotation, two `SimConfig` instances with different `num_layers` values would share a fastcache key, and the second instance would silently load a kernel compiled for the wrong value. + @dataclasses.dataclass + class SimConfig: + num_layers: int = dataclasses.field(metadata={FIELD_METADATA_CACHE_VALUE: True}) + dt: float = dataclasses.field(metadata={FIELD_METADATA_CACHE_VALUE: True}) + ``` -Note the asymmetry: `@qd.data_oriented` primitive members are baked into the kernel automatically (same semantics as `qd.Template`); `dataclasses.dataclass` members contribute only their *type* to the cache key unless you opt in per-member. + Annotate any member whose *value* (not just type) affects the compiled kernel. Primarily this means any variable used inside [`qd.static`](static.md). +- **Nested `@qd.data_oriented` or `dataclasses.dataclass` member** — recurses with the same rules (so an `int` inside a nested `@qd.data_oriented` is still baked into the kernel; an `int` inside a nested `dataclasses.dataclass` still needs `FIELD_METADATA_CACHE_VALUE` to bake its value). +- **`qd.field` member** — fastcache is disabled for the entire kernel call. The kernel still runs via normal compilation; a warn-level log line is emitted. diff --git a/python/quadrants/lang/_fast_caching/args_hasher.py b/python/quadrants/lang/_fast_caching/args_hasher.py index 1a949d3007..b8967412df 100644 --- a/python/quadrants/lang/_fast_caching/args_hasher.py +++ b/python/quadrants/lang/_fast_caching/args_hasher.py @@ -11,11 +11,13 @@ from quadrants._tensor_wrapper import Tensor as _TensorWrapper from quadrants.types.annotations import Template +from .._dataclass_util import create_flat_name from .._ndarray import ScalarNdarray +from .._quadrants_callable import BoundQuadrantsCallable, QuadrantsCallable from ..field import ScalarField from ..kernel_arguments import ArgMetadata from ..matrix import MatrixField, MatrixNdarray, VectorNdarray -from ..util import is_data_oriented +from ..util import is_data_oriented, is_dataclass_instance from .hash_utils import hash_iterable_strings _FIELD_TYPES = (ScalarField, MatrixField) @@ -40,6 +42,26 @@ _DC_REPR_NONE = object() +# Sentinel returned by ``stringify_obj_type`` whenever fastcache cannot safely hash a value: +# - Recognised-but-unsupported tensor-like type (``ScalarField`` / ``MatrixField``). +# - Unrecognised type at a kernel-read path (no qualname fallback — see rules in fastcache.md). +# +# Containers (``dataclass_to_repr``, ``data_oriented`` branch, top-level ``hash_args`` loop) must propagate it upward +# — fastcache is disabled for the whole call and the caller writes the appropriate diagnostic. +class _FailFastcache: + """Singleton sentinel; identity-compared.""" + + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + +_FAIL_FASTCACHE = _FailFastcache() + + class FastcacheSkip(enum.Enum): """Why fastcache does not apply to this call.""" @@ -47,11 +69,23 @@ class FastcacheSkip(enum.Enum): WARN = "warn" -# Set when the fastcache skip is something callers should warn about (as opposed to a Field arriving through a -# qd.Tensor annotation, which is a normal silent path). Reset at the start of each hash_args call. +# Set when the fastcache skip is something callers should warn about (as opposed to a ``Field`` arriving through a +# ``qd.Tensor`` annotation, which is a normal silent path). Reset at the start of each ``hash_args`` call. _should_warn = False +# Set of ``type(v).__qualname__`` strings we've already emitted the "unknown type at a kernel-read path" +# warning for. Lets the loop run thousands of times without spamming the log while still telling the user once +# that fastcache encountered an unrecognised type. Cleared by ``reset_unknown_type_warn_state`` (called from +# ``qd.init``) so each new test sees a clean log. +_warned_unknown_types: set[str] = set() + + +def reset_unknown_type_warn_state() -> None: + """Clear the once-per-process warned-unknown-types set. Called from test setup / ``qd.init``.""" + _warned_unknown_types.clear() + + def _mark_warn_if_not_tensor_annotation(arg_meta: ArgMetadata | None) -> None: """Flag that a warning is needed if the Field didn't arrive through a qd.Tensor annotation.""" global _should_warn # pylint: disable=global-statement @@ -64,40 +98,131 @@ def _mark_should_warn() -> None: _should_warn = True -def dataclass_to_repr(raise_on_templated_floats: bool, path: tuple[str, ...], arg: Any) -> str | None: - # PERF: For frozen dataclasses, the repr never changes. Cache it on the instance to avoid repeated +def _fail_unknown_type(obj: object, path: tuple[str, ...]) -> _FailFastcache: + """Disable fastcache for the call when an unrecognised type appears at a kernel-read path. + + Two rules at work here (see ``docs/source/user_guide/fastcache.md`` "Pruning-driven argument hashing"): + + 1. The fastcache key may *only* contain contributions from kernel-pruned paths — never a + ``type(v).__qualname__`` fallback for an unrecognised type, because that hash captures type identity + only and would silently mask a value-affecting change (e.g. a new tensor-like type whose dtype matters). + + 2. We may not silently *discard* something at a kernel-read path on the basis that it's unrecognised — + that would let unrecognised but codegen-affecting values escape the cache key and serve stale results. + + The only way to honour both rules is to fail the call's fastcache loudly, with a one-shot warning per type + so the user can add explicit handling in ``stringify_obj_type``. + """ + t = type(obj) + qualname = f"{getattr(t, '__module__', '')}.{getattr(t, '__qualname__', t.__name__)}" + if qualname not in _warned_unknown_types: + _warned_unknown_types.add(qualname) + _logging.warn( + f"[FASTCACHE][UNKNOWN_TYPE] Unrecognised type {qualname} reached at kernel-read path {path}. " + f"Fastcache is disabled for this call. Add explicit handling for this type to " + f"``quadrants/lang/_fast_caching/args_hasher.py::stringify_obj_type``, or refactor the kernel " + f"so it does not read this member." + ) + _mark_should_warn() + return _FAIL_FASTCACHE + + +def _child_flat(parent_flat: str | None, child_name: str) -> str | None: + """Compute the flat name a kernel parameter would have if it pointed at this container's child. + + For a top-level arg ``state`` with child ``x``: ``__qd_state__qd_x``. + For a deeper child ``state.dofs.x``: ``__qd_state__qd_dofs__qd_x`` (built incrementally). + + ``parent_flat`` is the *kernel-side* representation of this container's root: + - top-level arg of a kernel: ``arg_meta.name`` (e.g. ``"state"``, ``"self"``) — no ``__qd_`` prefix. + - any nested level: the already-computed ``__qd_…`` flat name. + + Returns ``None`` when ``parent_flat`` itself is ``None``, indicating "no path info available" — the caller + must walk the child unconditionally (i.e. ignore ``pruning_paths`` for this branch). + """ + if parent_flat is None: + return None + return create_flat_name(parent_flat, child_name) + + +def _is_path_used(pruning_paths: set[str] | None, child_flat: str | None) -> bool: + """Return True if a child at ``child_flat`` should be hashed. + + - ``pruning_paths is None``: pre-pruning-info compile — hash everything. + - ``child_flat is None``: caller could not compute a flat-name path (no parent_flat available) — hash + everything as well, so we never accidentally drop a child we couldn't classify. + - both non-None: only hash children whose flat name is in the set. Pruning's prefix-expansion step in + ``Kernel.materialize`` guarantees that if any descendant of ``__qd_a__qd_b`` is used, ``__qd_a__qd_b`` + itself is also in the set, so this single membership check is sufficient to decide whether to descend. + """ + if pruning_paths is None or child_flat is None: + return True + return child_flat in pruning_paths + + +def dataclass_to_repr( + raise_on_templated_floats: bool, + path: tuple[str, ...], + arg: Any, + pruning_paths: set[str] | None = None, + parent_flat: str | None = None, +) -> str | _FailFastcache: + """Hash a dataclass instance, optionally narrowed by pruning information. + + Returns ``_FAIL_FASTCACHE`` if any field's subtree hits a recognised-but-unsupported tensor type (``ScalarField`` / + ``MatrixField``); otherwise a string. + + Pruning: if ``pruning_paths`` is non-None, only descend into fields whose flat name is in the set. Pruning's + prefix-expansion step ensures the set already contains all ancestors of used leaves, so checking the immediate + child's flat name is sufficient. + """ + # PERF: For frozen dataclasses the repr never changes. Cache it on the instance to avoid repeated # ``dataclasses.fields()`` calls (which are slow due to extra runtime checks — see _template_mapper_hotpath.py - # module docstring). The cache is stored as ``_qd_dc_repr`` via ``object.__setattr__`` to bypass frozen guards. - # A cached ``None`` is stored as the sentinel ``_DC_REPR_NONE`` to distinguish "not yet computed" from - # "computed but not fast-cacheable". + # module docstring). The cache is stored as ``_qd_dc_repr`` via ``object.__setattr__`` to bypass frozen guards. A + # cached ``_DC_REPR_NONE`` sentinel distinguishes "computed but not fast-cacheable" from "not yet computed". + # + # The cache is keyed by ``(is_frozen, pruning_paths is None)`` because a frozen dataclass's pruned repr depends on + # the pruning_paths set — we use separate cache slots for pruned vs unpruned to avoid serving the wrong narrowing. + cache_attr = "_qd_dc_repr" if pruning_paths is None else "_qd_dc_repr_narrow" is_frozen = type(arg).__hash__ is not None if is_frozen: - cached = getattr(arg, "_qd_dc_repr", None) + cached = getattr(arg, cache_attr, None) if cached is _DC_REPR_NONE: - return None - if cached is not None: + return _FAIL_FASTCACHE + if cached is not None and pruning_paths is None: + # Narrow cache may be stale if pruning_paths set changed; only reuse the unpruned cache. return cached repr_l = [] for field in dataclasses.fields(arg): + child_flat = _child_flat(parent_flat, field.name) + if not _is_path_used(pruning_paths, child_flat): + continue child_value = getattr(arg, field.name) - _repr = stringify_obj_type(raise_on_templated_floats, path + (field.name,), child_value, arg_meta=None) - if _repr is None: + _repr = stringify_obj_type( + raise_on_templated_floats, + path + (field.name,), + child_value, + arg_meta=None, + pruning_paths=pruning_paths, + parent_flat=child_flat, + ) + if _repr is _FAIL_FASTCACHE: if isinstance(child_value, _FIELD_TYPES) and field.type is not _TensorWrapper: _mark_should_warn() if is_frozen: try: - object.__setattr__(arg, "_qd_dc_repr", _DC_REPR_NONE) + object.__setattr__(arg, cache_attr, _DC_REPR_NONE) except AttributeError: pass - return None + return _FAIL_FASTCACHE full_repr = f"{field.name}: ({_repr})" if field.metadata.get(FIELD_METADATA_CACHE_VALUE, False): full_repr += f" = {child_value}" repr_l.append(full_repr) result = "[" + ",".join(repr_l) + "]" - if is_frozen: + if is_frozen and pruning_paths is None: try: - object.__setattr__(arg, "_qd_dc_repr", result) + object.__setattr__(arg, cache_attr, result) except AttributeError: pass return result @@ -111,36 +236,53 @@ def _is_template(arg_meta: ArgMetadata | None) -> bool: def stringify_obj_type( - raise_on_templated_floats: bool, path: tuple[str, ...], obj: object, arg_meta: ArgMetadata | None -) -> str | None: - """ - Convert an object into a string representation that only depends on its type. - - String should somehow represent the type of obj. Doesnt have to be hashed, nor does it have - to be the actual python type string, just a string that is representative of the type, and won't collide - with different (allowed) types. String should be non-empty. - - Note that fields are not included in fast cache. - - arg_meta should only be non-None for the top level arguments and for data oriented objects. It is - used currently to determine whether a value is added to the cache key, as well as the name. eg - - at the top level, primitive types have their values added to the cache key if their annotation is qd.Template, - since they are baked into the kernel - - in data oriented objects, the values of all primitive types are added to the cache key, since they are baked - into the kernel, and require a kernel recompilation, when they change + raise_on_templated_floats: bool, + path: tuple[str, ...], + obj: object, + arg_meta: ArgMetadata | None, + pruning_paths: set[str] | None = None, + parent_flat: str | None = None, +) -> str | _FailFastcache: + """Convert ``obj`` into a deterministic string that contributes to the fastcache key. + + Return contract: + - ``str``: hashable; the returned string contributes to the cache key. + - ``_FAIL_FASTCACHE``: fastcache cannot safely hash this value — caller must propagate upward and + disable fastcache for the whole call. Triggered by: + * Recognised-but-unsupported tensor-like type (``ScalarField`` / ``MatrixField``). + * Unrecognised type at this kernel-read path (see ``_fail_unknown_type``). + + Two rules from ``docs/source/user_guide/fastcache.md`` "Pruning-driven argument hashing" govern this function: + + 1. The cache key may *only* include contributions from paths that pruning has marked kernel-accessed + (``pruning_paths``). Container walkers (dataclass + data_oriented) check ``_is_path_used`` per child and + skip non-pruned subtrees — kernel-unread paths are *guaranteed* not to affect codegen so this is safe by + construction. + + 2. At paths the kernel *does* read, unrecognised types must not be silently dropped or hashed by type-name — + fastcache fails the call (loudly, with a one-shot warning) so the gap can be closed. + + Parameters: + - ``arg_meta``: non-``None`` only for top-level kernel args and for ``@qd.data_oriented`` members. Determines + whether primitive values are baked into the cache key (template-position primitives and all primitive members + of data-oriented containers). + - ``pruning_paths``: optional set of kernel-accessed flat names from L1 cache. When provided, + ``dataclass_to_repr`` and the ``data_oriented`` branch below descend only into children whose flat name is in + the set. Pruning info is populated by ``ASTTransformer.build_Name`` / ``build_Attribute`` (kernel-arg-rooted + chains) plus ``Pruning.fold_struct_nd_paths`` (ndarray accesses through data_oriented containers). + - ``parent_flat``: the flat-name prefix for ``obj``'s children (e.g. ``__qd_self`` if ``obj`` is the ``self`` + arg of a data_oriented kernel). Used together with ``pruning_paths`` to compute each child's flat name for + the narrow-walk lookup. """ - # ``qd.Tensor`` wrappers passed as struct fields. The top-level kernel-arg unwrap hook in ``Kernel.__call__`` strips - # wrappers off positional / keyword args before the fastcache hasher sees them, but the dataclass / data-oriented - # walkers below (``dataclass_to_repr`` and the ``is_data_oriented`` branch) do raw ``getattr`` to fetch struct - # fields, so a wrapper stored as a struct field arrives here un-stripped. Without this branch the hasher falls - # through to the ``[FASTCACHE][PARAM_INVALID]`` warning and disables the fast path for the whole call. See - # ``perso_hugh/doc/quadrants-tensor.md`` §8.14. - # ``qd.Tensor`` wrappers: unwrap to the bare impl so the type checks below match. After unwrap, ``_qd_layout`` (if - # any) is on the impl. + # ``qd.Tensor`` wrappers passed as struct fields. The top-level kernel-arg unwrap hook in ``Kernel.__call__`` + # strips wrappers off positional / keyword args before the fastcache hasher sees them, but the dataclass / + # data-oriented walkers below do raw ``getattr`` to fetch struct fields, so a wrapper stored as a struct field + # arrives here un-stripped. Without this branch the hasher would hash the wrapper as an unknown type instead of + # unwrapping to the recognised impl. See ``perso_hugh/doc/quadrants-tensor.md`` §8.14. # - # PERF-CRITICAL: The _any_tensor_constructed guard makes this check zero-cost when no qd.Tensor has been created. - # ``type(obj) in _TENSOR_WRAPPER_TYPES`` is used instead of ``isinstance`` because it is a pointer comparison (~10 - # ns) vs an MRO walk (~100–200 ns). Do not replace with isinstance or remove the guard. + # PERF-CRITICAL: the ``_any_tensor_constructed`` guard makes this check zero-cost when no ``qd.Tensor`` has been + # created. ``type(obj) in _TENSOR_WRAPPER_TYPES`` is used instead of ``isinstance`` because it is a pointer + # comparison (~10 ns) vs an MRO walk (~100–200 ns). Do not replace with isinstance or remove the guard. if ( _tensor_wrapper._any_tensor_constructed and type(obj) in _TENSOR_WRAPPER_TYPES ): # pyright: ignore[reportOptionalMemberAccess] @@ -148,18 +290,27 @@ def stringify_obj_type( arg_type = type(obj) _layout = getattr(obj, "_qd_layout", None) _layout_tag = "" if _layout is None else f"-L{_layout!r}" + # needs_grad is part of the parameter struct layout that ``insert_ndarray_param`` bakes into the compiled + # artifact (the slot includes a grad pointer iff needs_grad=True). Two ndarrays with identical dtype + ndim + # but differing needs_grad MUST hash distinctly, otherwise the L2 narrow args_hash collides and the cached + # artifact's slot is mis-matched at launch (the launch picks the _QD_ARRAY vs _QD_ARRAY_WITH_GRAD bucket + # off ``v.grad is not None``, against a slot whose grad-presence was fixed at compile time) — yielding + # silent miscomputation or runtime OOB depending on slot offset alignment. if isinstance(obj, ScalarNdarray): - return f"[nd-{obj.dtype}-{len(obj.shape)}{_layout_tag}]" # type: ignore[arg-type] + _grad_tag = "-g" if obj.grad is not None else "" + return f"[nd-{obj.dtype}-{len(obj.shape)}{_layout_tag}{_grad_tag}]" # type: ignore[arg-type] if isinstance(obj, VectorNdarray): - return f"[ndv-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}]" # type: ignore[arg-type] + _grad_tag = "-g" if obj.grad is not None else "" + return f"[ndv-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}{_grad_tag}]" # type: ignore[arg-type] if isinstance(obj, ScalarField): # disabled for now, because we need to think about how to handle field offset # etc # TODO: think about whether there is a way to include fields _mark_warn_if_not_tensor_annotation(arg_meta) - return None + return _FAIL_FASTCACHE if isinstance(obj, MatrixNdarray): - return f"[ndm-{obj.m}-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}]" # type: ignore[arg-type] + _grad_tag = "-g" if obj.grad is not None else "" + return f"[ndm-{obj.m}-{obj.n}-{obj.dtype}-{len(obj.shape)}{_layout_tag}{_grad_tag}]" # type: ignore[arg-type] if isinstance(obj, torch_type): return f"[pt-{obj.dtype}-{obj.ndim}]" # type: ignore if isinstance(obj, np.ndarray): @@ -169,30 +320,44 @@ def stringify_obj_type( # etc # TODO: think about whether there is a way to include fields _mark_warn_if_not_tensor_annotation(arg_meta) - return None - if dataclasses.is_dataclass(obj): - return dataclass_to_repr(raise_on_templated_floats, path, obj) + return _FAIL_FASTCACHE + if is_dataclass_instance(obj): + return dataclass_to_repr( + raise_on_templated_floats, path, obj, pruning_paths=pruning_paths, parent_flat=parent_flat + ) if is_data_oriented(obj): + # Walk the data_oriented container's members, narrowed by pruning info — the kernel-compile path records + # every kernel-accessed attribute chain (ndarrays via ``_promote_ndarray_if_declared`` + + # ``Pruning.fold_struct_nd_paths``; primitives, opaque members, nested structs via + # ``ASTTransformer.build_Attribute``'s ``_qd_arg_chain`` propagation calling ``pruning.mark_used``). Members + # not in ``pruning_paths`` are *guaranteed* not to affect kernel codegen because the kernel cannot read them. + # Dropping them from the hash satisfies rule 1 (cache only pruned paths). child_repr_l = ["da"] - _dict = {} try: - # pyright is ok with this approach _asdict = getattr(obj, "_asdict") _dict = _asdict() except AttributeError: _dict = obj.__dict__ for k, v in _dict.items(): - _child_repr = stringify_obj_type(raise_on_templated_floats, (*path, k), v, ArgMetadata(Template, "")) - if _child_repr is None: - if _should_warn: - _logging.warn( - f"A kernel that has been marked as eligible for fast cache was passed 1 or more parameters " - f"that are not, in fact, eligible for fast cache: one of the parameters was a " - f"@qd.data_oriented object, and one of its children was not eligible. The data oriented " - f"object was of type {type(obj)} and the child {k}={type(v)} was not eligible. For " - f"information, the path of the value was {path}." - ) - return None + # Skip Quadrants method-descriptor cache entries. ``QuadrantsCallable.__get__`` stashes the per-instance + # ``BoundQuadrantsCallable`` on ``instance.__dict__`` so subsequent ``instance.method`` lookups skip the + # descriptor allocation; those entries are not data and must not invalidate the fastcache key. + v_type = type(v) + if v_type is QuadrantsCallable or v_type is BoundQuadrantsCallable: + continue + child_flat = _child_flat(parent_flat, k) + if not _is_path_used(pruning_paths, child_flat): + continue + _child_repr = stringify_obj_type( + raise_on_templated_floats, + (*path, k), + v, + ArgMetadata(Template, ""), + pruning_paths=pruning_paths, + parent_flat=child_flat, + ) + if _child_repr is _FAIL_FASTCACHE: + return _FAIL_FASTCACHE child_repr_l.append(f"{k}: {_child_repr}") return ", ".join(child_repr_l) if issubclass(arg_type, (numbers.Number, np.number)): @@ -210,21 +375,33 @@ def stringify_obj_type( return "np.bool_" if isinstance(obj, enum.Enum): return f"enum-{obj.name}-{obj.value}" - _mark_should_warn() - # The bit in caps should not be modified without updating corresponding test - # The rest of free text can be freely modified - # (will probably formalize this in more general doc / contributor guidelines at some point) - _logging.warn( - f"[FASTCACHE][PARAM_INVALID] Parameter with path {path} and type {arg_type} not allowed by fast cache." - ) - return None + # Unrecognised type at a kernel-read path — fail fastcache loudly. See ``_fail_unknown_type``. + return _fail_unknown_type(obj, path) def hash_args( - raise_on_templated_floats: bool, args: Sequence[Any], arg_metas: Sequence[ArgMetadata | None] + raise_on_templated_floats: bool, + args: Sequence[Any], + arg_metas: Sequence[ArgMetadata | None], + pruning_paths: set[str] | None = None, ) -> str | FastcacheSkip: - """Return the args hash string, or a HashFailure explaining why hashing failed.""" - global g_num_calls, g_num_args, g_hashing_time, g_repr_time, g_num_ignored_calls, _should_warn + """Return the args hash string, or a ``FastcacheSkip`` explaining why hashing failed. + + Parameters: + - ``pruning_paths``: optional set of kernel-accessed flat names from the L1 cache (or freshly populated + after a cold compile). When provided, the container walkers skip children whose flat name is not in + the set; this is what keeps the cache key narrow and brittleness-free (no opaque-typed member can + affect the key unless the kernel actually reads it). + + Fastcache is disabled (``FastcacheSkip`` returned) when either: + - a recognised-but-unsupported tensor-like type (``ScalarField`` / ``MatrixField``) is encountered at a + kernel-read path, OR + - an unrecognised type is encountered at a kernel-read path (see ``_fail_unknown_type``). + + Both cases are loud: ``FastcacheSkip.WARN`` triggers an ``[INVALID_FUNC]`` log line and the unknown-type + branch additionally emits a one-shot ``[UNKNOWN_TYPE]`` warning identifying the offending type. + """ + global g_num_calls, g_num_args, g_hashing_time, g_repr_time, g_num_ignored_calls, _should_warn # pylint: disable=global-statement _should_warn = False g_num_calls += 1 g_num_args += len(args) @@ -235,11 +412,23 @@ def hash_args( ) for i_arg, arg in enumerate(args): start = time.time() - _hash = stringify_obj_type(raise_on_templated_floats, (str(i_arg),), arg, arg_metas[i_arg]) + arg_meta = arg_metas[i_arg] + # Top-level arg flat name: matches the kernel-side ``arg_meta.name`` (no ``__qd_`` prefix at the root). + # Used by the narrow walk to construct child flat names compatible with ``pruning.used_vars_by_func_id``. + top_flat = arg_meta.name if arg_meta is not None else None + _hash = stringify_obj_type( + raise_on_templated_floats, + (str(i_arg),), + arg, + arg_meta, + pruning_paths=pruning_paths, + parent_flat=top_flat, + ) g_repr_time += time.time() - start - if not _hash: + if _hash is _FAIL_FASTCACHE: g_num_ignored_calls += 1 return FastcacheSkip.WARN if _should_warn else FastcacheSkip.FIELD_VIA_TENSOR + # All other return values are valid strings (qualname fallback handles unrecognised types). hash_l.append(_hash) start = time.time() res = hash_iterable_strings(hash_l) diff --git a/python/quadrants/lang/_fast_caching/src_hasher.py b/python/quadrants/lang/_fast_caching/src_hasher.py index 1c03bf737b..789beffd41 100644 --- a/python/quadrants/lang/_fast_caching/src_hasher.py +++ b/python/quadrants/lang/_fast_caching/src_hasher.py @@ -1,3 +1,30 @@ +"""Two-level fastcache key derivation and persistence. + +Two-level cache +--------------- +The fastcache now exposes pruning information (already produced during compile) as a first-class lookup so the args +hash can walk *only* paths the kernel reads: + + - L1 (this module's ``make_source_config_key`` + ``load_pruning_info`` / ``store_pruning_info``): keyed by + source+config only (no args). Stores ``PruningInfo`` — the set of kernel-accessed flat names (e.g. + ``__qd_state__qd_x``) plus the ``graph_do_while_arg`` (also a kernel-source property). + + - L2 (``make_full_cache_key`` + ``load_full`` / ``store_full``): keyed by L1 key + the *narrow* args hash computed + with pruning info from L1. Stores the C++ ``frontend_cache_key`` that names the compiled artifact. + +Lookup flow on a warm call: L1 lookup → narrow args hash (paths from L1) → L2 lookup → load artifact. + +Cold compile flow: L1 miss → cold compile (pass 0 + pass 1) → store L1 → compute narrow args hash → store L2. + +Safety implication +------------------ +A kernel-unused path's contents (any type, including unrecognised tensor-likes) is *guaranteed* not to affect kernel +codegen, so dropping it from the hash is correct by construction. Paths the kernel *does* read still go through +``args_hasher.stringify_obj_type``; if it encounters an unrecognised type at such a path it fails the call's fastcache +loudly (one-shot ``[FASTCACHE][UNKNOWN_TYPE]`` warning identifying the offending ``type(v).__qualname__``), so a missed +type registration is impossible to miss and cannot serve stale cached results. +""" + import json import os import warnings @@ -17,21 +44,55 @@ from .hash_utils import hash_iterable_strings from .python_side_cache import PythonSideCache +# Prefix bytes mixed into L1 / L2 keys so they cannot collide even if the underlying inputs happen to hash to the +# same string. The original single-level cache key (kept for backward-compat reads via ``load`` below) had no such +# prefix; the new two-level scheme uses ``l1:`` and ``l2:`` markers so old single-level entries from prior Quadrants +# installs are simply ignored rather than mis-served. +_L1_MARKER = "l1" +_L2_MARKER = "l2" + + +def make_source_config_key(kernel_source_info: FunctionSourceInfo) -> str: + """Build the L1 cache key: source + config + version, with no dependence on args. + + Used by ``_try_load_fastcache`` before any args walking. The same key drives ``load_pruning_info`` / + ``store_pruning_info``; the matching ``make_full_cache_key`` derives the L2 key from this plus the narrow args + hash. + """ + kernel_hash = function_hasher.hash_kernel(kernel_source_info) + config_hash = config_hasher.hash_compile_config() + return hash_iterable_strings( + ( + _L1_MARKER, + quadrants.__version_str__, + kernel_hash, + config_hash, + kernel_source_info.filepath, + str(kernel_source_info.start_lineno), + "pruned", + "kcov" if os.environ.get("QD_KERNEL_COVERAGE") == "1" else "", + ) + ) + + +def make_full_cache_key(source_config_key: str, narrow_args_hash: str) -> str: + """Build the L2 cache key from the L1 key + narrow args hash. See module docstring.""" + return hash_iterable_strings((_L2_MARKER, source_config_key, narrow_args_hash)) -def create_cache_key( + +def compute_narrow_args_hash( raise_on_templated_floats: bool, kernel_source_info: FunctionSourceInfo, args: Sequence[Any], arg_metas: Sequence[ArgMetadata], + pruning_paths: set[str] | None, ) -> str | None: + """Compute the args hash narrowed by ``pruning_paths`` (or wide if ``pruning_paths is None``). + + Returns ``None`` if a recognised-but-unsupported tensor-like type forces fastcache off — the caller emits + the appropriate user-visible diagnostic via the ``FastcacheSkip.WARN`` branch. """ - cache key takes into account: - - arg types - - cache value arg values - - kernel function (but not sub functions) - - compilation config (which includes arch, and debug) - """ - args_hash = args_hasher.hash_args(raise_on_templated_floats, args, arg_metas) + args_hash = args_hasher.hash_args(raise_on_templated_floats, args, arg_metas, pruning_paths=pruning_paths) if isinstance(args_hash, FastcacheSkip): if args_hash is FastcacheSkip.WARN: # the bit in caps at start should not be modified without modifying corresponding text @@ -41,24 +102,138 @@ def create_cache_key( "fast cached, because one or more parameter types were invalid" ) return None - kernel_hash = function_hasher.hash_kernel(kernel_source_info) - config_hash = config_hasher.hash_compile_config() - cache_key = hash_iterable_strings( - ( - quadrants.__version_str__, - kernel_hash, - args_hash, - config_hash, - kernel_source_info.filepath, - str(kernel_source_info.start_lineno), - "pruned", - "kcov" if os.environ.get("QD_KERNEL_COVERAGE") == "1" else "", - ) + return args_hash + + +class L1CacheValue(BaseModel): + """Persisted L1 entry — pruning info that's source-and-config-deterministic (not args-dependent). + + Pruning info is the set of *flat names* (``__qd___qd___qd_…``) that the kernel actually reads. + Computed during compile (``Pruning.used_vars_by_func_id``); persisted here so subsequent calls can build + a narrow args hash without having to recompile. + + ``graph_do_while_arg`` is also stored here because it's a property of the kernel source (not of any + particular arg value). + + ``hashed_function_source_infos`` is the same content-hash list used for L2 validation; an L1 hit is + rejected if any helper source has changed since the L1 entry was written, even if the kernel source + itself hasn't (kernel_hash only covers the entry point). + """ + + used_py_dataclass_parameters: set[str] + hashed_function_source_infos: list[HashedFunctionSourceInfo] + graph_do_while_arg: str | None = None + + +def store_pruning_info( + source_config_key: str, + function_source_infos: Iterable[FunctionSourceInfo], + used_py_dataclass_parameters: set[str], + graph_do_while_arg: str | None = None, +) -> None: + """Persist the L1 entry after a cold compile. See ``L1CacheValue`` for what's stored / why.""" + if not source_config_key: + return + cache = PythonSideCache() + hashed_function_source_infos = function_hasher.hash_functions(function_source_infos) + cache_value = L1CacheValue( + used_py_dataclass_parameters=used_py_dataclass_parameters, + hashed_function_source_infos=list(hashed_function_source_infos), + graph_do_while_arg=graph_do_while_arg, ) - return cache_key + cache.store(source_config_key, cache_value.model_dump_json()) + + +def persist_l1_and_set_l2_key( + *, + l1_key: str | None, + kernel_source_info: FunctionSourceInfo | None, + used_py_dataclass_parameters: set[str] | None, + visited_functions: Iterable[FunctionSourceInfo], + graph_do_while_arg: str | None, + pruning_paths_from_l1: set[str] | None, + fast_checksum: str | None, + raise_on_templated_floats: bool, + py_args: tuple[Any, ...], + arg_metas: Sequence[ArgMetadata], +) -> tuple[str | None, bool]: + """After a successful materialize, persist L1 (if missing) and derive the L2 key. + + Two responsibilities: + + 1. If L1 was missing (``pruning_paths_from_l1 is None``), write the freshly-computed pruning info so the next + call from a new process can skip the args-walk warm-up. + + 2. If ``fast_checksum`` is still ``None`` (either L1 was missing, or L1 hit but phase 2 of the warm-call load + path saw a FIELD-related ``FastcacheSkip`` and kept ``None``), compute the narrow args hash now using the + just-populated pruning info and derive the L2 key. + + Returns ``(new_fast_checksum, generated)`` where ``generated`` is True iff this call freshly produced a non-None + L2 key (i.e. ``fast_checksum`` was ``None`` on entry and is non-None on return). The caller assigns + ``new_fast_checksum`` back to its kernel and uses ``generated`` to update its cache-observations counter. + + Returns ``(None, False)`` if fastcache is inactive for this kernel (``l1_key`` falsy / source info missing / + used-params not recorded), or ``(fast_checksum, False)`` if nothing changed. + """ + if not l1_key: + return None, False + if kernel_source_info is None: + return fast_checksum, False + if used_py_dataclass_parameters is None: + return fast_checksum, False + if pruning_paths_from_l1 is None: + store_pruning_info( + l1_key, + visited_functions, + used_py_dataclass_parameters, + graph_do_while_arg=graph_do_while_arg, + ) + # If phase 2 didn't run (L1 cold) or returned None (FIELD encountered earlier — but in that case post-compile + # narrow hashing would also see the FIELD and produce None, which is fine: we want fast_checksum to stay None + # so no L2 entry is stored), compute the narrow args hash now. + if fast_checksum is None: + narrow_args_hash = compute_narrow_args_hash( + raise_on_templated_floats, + kernel_source_info, + py_args, + arg_metas, + used_py_dataclass_parameters, + ) + if narrow_args_hash is not None: + return make_full_cache_key(l1_key, narrow_args_hash), True + return fast_checksum, False + + +def load_pruning_info( + source_config_key: str, +) -> tuple[set[str], str | None] | tuple[None, None]: + """Look up L1 cache. Returns (pruning_paths, graph_do_while_arg) on hit, (None, None) on miss / invalid. + + Validates ``hashed_function_source_infos`` against the current on-disk source; if any helper has changed + since the entry was written, the entry is invalid and we treat the lookup as a miss so the caller does a + cold compile (which will overwrite the stale L1 entry). + """ + cache = PythonSideCache() + maybe_value_json = cache.try_load(source_config_key) + if maybe_value_json is None: + return None, None + try: + cache_value = L1CacheValue.model_validate_json(maybe_value_json) + except (pydantic.ValidationError, json.JSONDecodeError, UnicodeDecodeError) as e: + warnings.warn(f"Failed to parse L1 cache entry: {e}") + return None, None + if not function_hasher.validate_hashed_function_infos(cache_value.hashed_function_source_infos): + return None, None + return cache_value.used_py_dataclass_parameters, cache_value.graph_do_while_arg class CacheValue(BaseModel): + """Persisted L2 entry — frontend cache key for the compiled artifact + source-validation metadata. + + The full pruning info is duplicated here for backward-compat with existing on-disk caches; it's the same + set that L1 also stores. The L1 set is the source of truth for narrowing the args hash on warm calls. + """ + frontend_cache_key: str hashed_function_source_infos: list[HashedFunctionSourceInfo] used_py_dataclass_parameters: set[str] @@ -72,22 +247,10 @@ def store( used_py_dataclass_parameters: set[str], graph_do_while_arg: str | None = None, ) -> None: - """ - Note that unlike other caches, this cache is not going to store the actual value we want. - This cache is only used for verification that our cache key is valid. Big picture: - - we have a cache key, based on args and top level kernel function - - we want to use this to look up LLVM IR, in C++ side cache - - however, before doing that, we first want to validate that the source code didn't change - - i.e. is our cache key still valid? - - the python side cache contains information we will use to verify that our cache key is valid - - ie the list of function source infos - - Update! We are now going to store parameter pruning infomation, which is: - - used_py_dataclass_parameters: set[str] - - Update 2: we are going to store the cache key used by the c++ kernel cache, so that we can use that - to retrieve the immutable cached c++ kernel later, rather than, before, we were storing the c++ - cached kernel using the fast cache key, leading to bugs, when cached kernel file then had to be mutable. + """Persist the L2 entry — the C++ frontend cache key that names the compiled artifact for this call. + + ``fast_cache_key`` is the L2 key from ``make_full_cache_key``. The L1 entry has typically been stored + earlier by ``store_pruning_info`` during the same materialize. """ if not fast_cache_key: return @@ -117,9 +280,9 @@ def _try_load(cache_key: str) -> CacheValue | None: def load(cache_key: str) -> tuple[set[str], str, str | None] | tuple[None, None, None]: - """ - loads function source infos from cache, if available - checks the hashes against the current source code + """Look up L2 cache. Returns (used_pruning_paths, frontend_cache_key, graph_do_while_arg) on hit. + + Validates helper-source hashes against the live source; an L2 entry is invalidated if any helper changed. """ cache_value = _try_load(cache_key) if cache_value is None: diff --git a/python/quadrants/lang/_pruning.py b/python/quadrants/lang/_pruning.py index 3289365767..aaa71620ce 100644 --- a/python/quadrants/lang/_pruning.py +++ b/python/quadrants/lang/_pruning.py @@ -1,13 +1,38 @@ -from ast import Name, Starred, expr, keyword +from ast import Attribute, Name, Starred, expr, keyword from collections import defaultdict from typing import TYPE_CHECKING, Any +from ._dataclass_util import create_flat_name from ._exceptions import raise_exception from ._quadrants_callable import BoundQuadrantsCallable, QuadrantsCallable from .exception import QuadrantsSyntaxError from .func import Func from .kernel_arguments import ArgMetadata + +def _flatten_arg_node(node: expr) -> tuple[str, str] | None: + """Flatten an AST arg node into ``(flat_name, root_name_id)`` (or ``None`` if the node isn't a recognisable + name/attribute chain rooted at a plain Name). + + Returns both the full flat name (e.g. ``__qd_self__qd_dofs`` for ``self.dofs``) and the root Name's id (``self``). + Callers use the root id to distinguish kernel-arg-rooted chains (``self.dofs`` → root ``self``) from already- + flattened dataclass-arg references (``__qd_self__qd_dofs`` → root ``__qd_self__qd_dofs``). The flat path alone is + ambiguous because ``__qd_self__qd_dofs`` could be either an attribute chain *or* a single flattened Name. + + Mirrors ``FlattenAttributeNameTransformer._flatten_attribute_name`` but on the raw call-arg AST. + Used by ``record_after_call`` to handle ``f(self.dofs)`` etc. — without this the callee's pruning + info for attribute-chain args is dropped at the call boundary.""" + if isinstance(node, Name): + return node.id, node.id + if isinstance(node, Attribute): + parent = _flatten_arg_node(node.value) + if parent is None: + return None + parent_flat, root_id = parent + return create_flat_name(parent_flat, node.attr), root_id + return None + + if TYPE_CHECKING: import ast @@ -39,11 +64,123 @@ def __init__(self, kernel_used_parameters: set[str] | None) -> None: self.used_vars_by_func_id[Pruning.KERNEL_FUNC_ID].update(kernel_used_parameters) # only needed for args, not kwargs self.callee_param_by_caller_arg_name_by_func_id: dict[int, dict[str, str]] = defaultdict(dict) + # id(ndarray) -> seen during the first compile pass via ``_promote_ndarray_if_declared``. Populated by the + # AST builder when a chain like ``self.x.y`` resolves to an ndarray that was pre-declared by + # ``_predeclare_struct_ndarrays``. On the second (enforcing) pass, ``_predeclare_struct_ndarrays`` only + # registers ndarrays whose id is in this set — dropping every reachable-but-unused ndarray from the kernel's + # parameter list. + self.used_struct_ndarray_ids: set[int] = set() + # Whether the non-enforcing first pass actually ran for this kernel materialize. When fastcache hits, we skip + # pass 0 entirely and ``used_struct_ndarray_ids`` is therefore unreliable — in that case + # ``_predeclare_struct_ndarrays`` falls back to registering every reachable ndarray (same as historical + # behavior). + self.pass_0_ran: bool = False + # Kernel-arg-rooted attribute chains used by each func, in flat-name form (``__qd_self__qd_dofs__qd_x``). + # Populated by ``ASTTransformer.build_Attribute`` for non-flattened kernel args (data_oriented / + # qd.template). Kept *separate* from ``used_vars_by_func_id`` because the latter drives ``struct_locals`` on + # the enforcing pass (line ~230 of kernel.py), and ``FlattenAttributeNameTransformer`` would rewrite ``s.x`` + # → ``Name('__qd_s__qd_x')`` if these chain names appeared there — yielding a ``QuadrantsNameError: Name + # "__qd_s__qd_x" is not defined``. ``record_after_call`` propagates entries from callee to caller (so + # ``f(self.dofs)`` where ``f`` reads ``s.x`` ends up with ``__qd_self__qd_dofs__qd_x`` in the kernel's set). + # After both compile passes, ``Pruning.fold_kernel_arg_chain_paths`` merges the kernel's set into + # ``used_vars_by_func_id[KERNEL_FUNC_ID]`` so fastcache stores them in L1 and the args_hasher narrow walk + # picks them up. + self.kernel_arg_chain_paths_by_func_id: dict[int, set[str]] = defaultdict(set) def mark_used(self, func_id: int, parameter_flat_name: str) -> None: assert not self.enforcing self.used_vars_by_func_id[func_id].add(parameter_flat_name) + def mark_kernel_arg_chain_used(self, func_id: int, chain_flat_name: str) -> None: + """Record a kernel-arg-rooted attribute chain (e.g. ``__qd_self__qd_dofs__qd_x``). + + Stored separately from ``used_vars_by_func_id`` — see the docstring on ``kernel_arg_chain_paths_by_func_id`` + for why.""" + assert not self.enforcing + self.kernel_arg_chain_paths_by_func_id[func_id].add(chain_flat_name) + + def fold_struct_nd_paths( + self, struct_ndarray_launch_info: list[tuple[Any, int, tuple[str, ...]]], arg_metas: list[ArgMetadata] + ) -> None: + """Add data_oriented (and dataclass-nested) ndarray attribute chains to the kernel's pruning flat name set so + ``args_hasher.hash_args`` narrow-walks them correctly. + + Background: ``used_vars_by_func_id[KERNEL_FUNC_ID]`` is populated by AST walking of flat names produced by + ``FlattenAttributeNameTransformer`` — but that transformer only flattens *dataclass* args. + ``@qd.data_oriented`` args (template-typed) stay as ``Attribute(value=Name(self), attr=…)`` in the AST and + don't contribute to ``used_vars_by_func_id``. Their kernel-accessed ndarray paths *are* recorded — in + ``struct_ndarray_launch_info`` as ``(arg_id_vec[0], arg_idx, attr_chain)`` — but only for ndarray members. + + Convert each ``(arg_idx, attr_chain)`` to a flat name like ``__qd___qd___qd_…`` and union + all prefixes into the pruning set. After this fold, narrowing in args_hasher matches the same convention used + for dataclass args. + + Limitation: non-ndarray data_oriented members (primitive ints/floats whose values are baked in at compile, + opaque Python objects) are *not* tracked anywhere as kernel-accessed. The narrow walk cannot distinguish + "kernel reads this primitive" from "kernel does not read this primitive". The + ``args_hasher.stringify_obj_type`` data_oriented branch handles this conservatively by walking *all* attrs of + a data_oriented container — narrowing only suppresses subtrees explicitly absent from the pruning set. So for + a data_oriented arg with mostly-ndarray members, the cache key correctly depends on the ndarray paths it + uses; for one with primitive members whose values matter, those members are still folded into the hash + (qualname-fallback / value paths). + """ + if not struct_ndarray_launch_info: + return + kernel_used: set[str] = self.used_vars_by_func_id[Pruning.KERNEL_FUNC_ID] + for _arg_id_cpp, arg_idx, attr_chain in struct_ndarray_launch_info: + if arg_idx < 0 or arg_idx >= len(arg_metas): + continue + arg_name = arg_metas[arg_idx].name + if not arg_name: + continue + flat = arg_name + for attr in attr_chain: + flat = create_flat_name(flat, attr) + kernel_used.add(flat) + + def fold_kernel_arg_chain_paths(self) -> None: + """Merge the kernel's chain-paths set into ``used_vars_by_func_id[KERNEL_FUNC_ID]`` *after* both compile + passes have completed. + + Background: ``ASTTransformer.build_Attribute`` records every kernel-arg-rooted attribute chain (e.g. + ``__qd_self__qd_n``, ``__qd_self__qd_cfg``) into ``kernel_arg_chain_paths_by_func_id`` rather than + ``used_vars_by_func_id``, because the latter is read on the enforcing pass to build ``struct_locals`` for + ``FlattenAttributeNameTransformer``. If chain names appeared there, the transformer would rewrite ``self.n`` + into ``Name('__qd_self__qd_n')`` and ``build_Name`` would fail to find such a variable. + + Doing the merge here — after pass 1, just like ``fold_struct_nd_paths`` — avoids that interaction while + still making the chain paths available to the fastcache args-hash narrow walk. The set on + ``used_py_dataclass_parameters_by_key_enforcing[key]`` is the *same* object as + ``used_vars_by_func_id[KERNEL_FUNC_ID]`` (assigned by reference at end of pass 0), so updating one updates + both. + """ + kernel_chain_paths = self.kernel_arg_chain_paths_by_func_id.get(Pruning.KERNEL_FUNC_ID) + if not kernel_chain_paths: + return + self.used_vars_by_func_id[Pruning.KERNEL_FUNC_ID].update(kernel_chain_paths) + + @staticmethod + def _propagate_chain_paths( + callee_chain_paths: set[str], + callee_param_name: str, + caller_flat: str, + chain_paths_to_propagate: set[str], + ) -> None: + """When ``f(self.dofs)`` is called and ``f``'s body reads ``s.x`` (callee param ``s`` bound to caller + attribute chain ``self.dofs``), the callee's chain-paths set contains ``__qd_s__qd_x`` but the + caller's chain-paths set must record ``__qd_self__qd_dofs__qd_x``. This helper does that + prefix substitution. Only chain paths starting with ``__qd___qd_`` are propagated + (chains rooted in unrelated callee args don't apply to this caller arg).""" + prefix = f"__qd_{callee_param_name}__qd_" + for sub in callee_chain_paths: + if sub.startswith(prefix): + rest = sub[len(prefix) :] + if caller_flat.startswith("__qd_"): + new_flat = f"{caller_flat}__qd_{rest}" + else: + new_flat = f"__qd_{caller_flat}__qd_{rest}" + chain_paths_to_propagate.add(new_flat) + def enforce(self) -> None: self.enforcing = True @@ -70,7 +207,9 @@ def record_after_call( callee_func_id = func.wrapper.func_id # type: ignore # Copy the used parameters from the child function into our own function. callee_used_vars = self.used_vars_by_func_id[callee_func_id] + callee_chain_paths = self.kernel_arg_chain_paths_by_func_id.get(callee_func_id, set()) vars_to_unprune: set[str] = set() + chain_paths_to_propagate: set[str] = set() arg_id = 0 # node.args ordering will match that of the called function's metas_expanded, # because of the way calling with sequential args works. @@ -88,6 +227,20 @@ def record_after_call( callee_param_name = callee_func.arg_metas_expanded[arg_id + self_offset].name # type: ignore if callee_param_name in callee_used_vars: vars_to_unprune.add(caller_arg_name) + # Propagate kernel-arg-rooted chain paths through attribute-chain args (``f(self.dofs)``) AND through + # plain-Name args of non-flattened types (``f(self)``). Gate on the *root* Name id, not the resulting + # flat string: ``self.dofs`` flattens to ``__qd_self__qd_dofs`` (which starts with ``__qd_``) but its + # root is the bare kernel arg ``self`` — we still need to propagate. Already-flattened dataclass refs + # like ``Name('__qd_self__qd_dofs')`` have a ``__qd_*`` root and are handled by the ``vars_to_unprune`` + # path above. + flat = _flatten_arg_node(arg) + if flat is not None: + caller_flat, root_id = flat + if not root_id.startswith("__qd_"): + callee_param_name = callee_func.arg_metas_expanded[arg_id + self_offset].name # type: ignore + self._propagate_chain_paths( + callee_chain_paths, callee_param_name, caller_flat, chain_paths_to_propagate + ) arg_id += 1 # Note that our own arg_metas ordering will in general NOT match that of the child's. That's # because our ordering is based on the order in which we pass arguments to the function, but the @@ -101,8 +254,20 @@ def record_after_call( callee_param_name = kwarg.arg if callee_param_name in callee_used_vars: vars_to_unprune.add(caller_arg_name) + flat = _flatten_arg_node(kwarg.value) + if flat is not None: + caller_flat, root_id = flat + if not root_id.startswith("__qd_"): + callee_param_name = kwarg.arg + # ``kwarg.arg`` is ``None`` for double-star unpacking (``**kwargs``); chain propagation requires + # a concrete parameter name so just skip. + if callee_param_name is not None: + self._propagate_chain_paths( + callee_chain_paths, callee_param_name, caller_flat, chain_paths_to_propagate + ) arg_id += 1 self.used_vars_by_func_id[my_func_id].update(vars_to_unprune) + self.kernel_arg_chain_paths_by_func_id[my_func_id].update(chain_paths_to_propagate) used_callee_vars = self.used_vars_by_func_id[callee_func_id] child_arg_id = 0 diff --git a/python/quadrants/lang/_quadrants_callable.py b/python/quadrants/lang/_quadrants_callable.py index ba7e7b8217..0c071c6919 100644 --- a/python/quadrants/lang/_quadrants_callable.py +++ b/python/quadrants/lang/_quadrants_callable.py @@ -90,15 +90,31 @@ def __init__(self, fn: Callable, wrapper: Callable) -> None: self._adjoint: "Kernel | None" = None self.grad: "Kernel | None" = None self.is_pure: bool = False + self._attr_name: str | None = None update_wrapper(self, fn) + def __set_name__(self, owner: type, name: str) -> None: + # Captured at class-body time. ``data_oriented.make_kernel_indirect`` sets this explicitly on its replacement + # callable since setattr-after-class doesn't trigger __set_name__. + self._attr_name = name + def __call__(self, *args, **kwargs): return self.wrapper.__call__(*args, **kwargs) def __get__(self, instance, owner): if instance is None: return self - return BoundQuadrantsCallable(instance, self) + bound = BoundQuadrantsCallable(instance, self) + # Non-data descriptor (no __set__): a __dict__ entry on the instance wins over the descriptor on subsequent + # attribute lookups. Stash the bound callable there so future ``instance.method`` accesses skip __get__ + # allocation entirely (~0.6-1.2 us/call). Skip if the class uses __slots__ (no __dict__) or the attribute name + # wasn't captured. + name = self._attr_name + if name is not None: + inst_dict = getattr(instance, "__dict__", None) + if inst_dict is not None: + inst_dict[name] = bound + return bound class BoundQuadrantsCallable: diff --git a/python/quadrants/lang/_template_mapper.py b/python/quadrants/lang/_template_mapper.py index c8c0deb3b7..07d32892c1 100644 --- a/python/quadrants/lang/_template_mapper.py +++ b/python/quadrants/lang/_template_mapper.py @@ -15,26 +15,28 @@ _struct_nd_paths_for, ) -# Per-``type(arg)`` precomputed dispatch for the args_hash ndarray-id walk in ``TemplateMapper.lookup``. Each entry -# is either the cached attribute path list (when the class is data_oriented and actually holds ndarrays) or ``None`` -# (when the per-call walk is a no-op — covers the common case of typed-dataclass args, non-data_oriented composite -# args, primitives, and data_oriented classes with no ndarray members). One dict lookup per template-slot arg per -# call, ~30 ns, replacing the previous unconditional ``is_data_oriented(arg)`` + ``type(arg).__dict__.get`` chain -# that cost ~15% FPS on small-step CPU benches (anymal_zero CPU bs=0). Missing-key (``KeyError``) signals first -# sighting and triggers ``_classify_for_args_hash``; cached ``None`` short-circuits the walk for known-no-op types. -_arg_nd_paths_or_none: "dict[type, list[tuple] | None]" = {} - - -def _classify_for_args_hash(arg: Any) -> "list[tuple] | None": - """First-sighting classification for ``type(arg)`` in the args_hash walk. Returns the path list to walk (when the - arg is a data_oriented container that actually contains ndarrays), or ``None`` to skip subsequent per-call work - for this type.""" +# Per-class disposition for the args_hash ndarray-id walk in ``TemplateMapper.lookup``: one of ``_SKIP`` (this class +# never contributes — non-data_oriented, or ``@qd.data_oriented(stable_members=True)``) or ``_PER_INSTANCE`` (delegate +# to ``_struct_nd_paths_for`` for a per-instance walk). The disposition depends only on type (data_oriented? +# stable_members?), so caching by class is correct. The *actual* path list is per-instance because @qd.data_oriented +# classes can have polymorphic attribute structure across instances (Genesis ``DataManager`` is the motivating case). +_arg_disposition: dict[type, object] = {} +_SKIP = object() +_PER_INSTANCE = object() + + +def _classify_disposition(arg: Any) -> object: + """First-sighting per-class disposition for the args_hash walk. Returns ``_SKIP`` (no per-call walk for this + class) or ``_PER_INSTANCE`` (delegate to ``_struct_nd_paths_for`` for a per-instance walk). + + ``_qd_stable_members`` here is a *launch-time perf hint only* (see ``@qd.data_oriented(stable_members=...)``). + It promises that ndarray members are never reassigned, which lets us skip the per-call walk entirely. It does + not affect fastcache key derivation.""" if not is_data_oriented(arg): - return None - paths = _struct_nd_paths_for(arg) - if not paths: - return None - return paths + return _SKIP + if type(arg).__dict__.get("_qd_stable_members"): + return _SKIP + return _PER_INSTANCE Key: TypeAlias = tuple[Any, ...] @@ -103,25 +105,29 @@ def lookup(self, raise_on_templated_floats: bool, args: tuple[Any, ...]) -> tupl # serve a stale entry and the new ndarray's dtype/ndim would be wrong. Fold the reachable ndarray ids into the # hash for the (small) set of arg positions that need it. # - # ``template_slot_locations`` already gives us the subset of arg positions annotated as ``qd.template()`` — - # the only positions where a data_oriented container could appear (typed-dataclass args carry a specific - # dataclass type by construction and a data_oriented class is never a dataclass). Iterating just those - # positions instead of all args trims the per-call work proportionally (Genesis main ``kernel_step_1``: 4 - # template positions of 16 args). + # The kernel's ``template_slot_locations`` already gives us the subset of arg positions annotated as + # ``qd.template()`` — the only positions where a data_oriented container could appear (typed-dataclass args + # carry a specific dataclass type by construction and a data_oriented class is never a dataclass). So we only + # iterate ``template_slot_locations`` instead of all args (Genesis main kernel_step_1: 4 template positions + # of 16 args; Genesis branch step_1/step_2: 4 of 4). # - # Per-``type(arg)`` cache (``_arg_nd_paths_or_none``) maps each seen type to either the path list to walk or - # ``None`` to skip — one ``dict.get`` per candidate per call after warmup, replacing the previous unconditional - # ``is_data_oriented`` + ``__dict__.get`` chain that cost ~15% FPS on small-step CPU benches. + # For each candidate, ``_arg_disposition`` caches the per-class decision (skip vs walk-per-instance) and the + # actual paths come from ``_struct_nd_paths_for`` (per-instance, stashed on ``arg._qd_nd_paths``). Per-instance + # path caching is load-bearing for correctness — @qd.data_oriented classes can have polymorphic attribute + # structure across instances (Genesis ``DataManager`` only allocates adjoint-cache members when + # ``requires_grad=True``); a per-class cache populated from one instance can't safely be reused for another. nd_ids: list = [] for i in self.template_slot_locations: arg = args[i] cls = type(arg) - try: - paths = _arg_nd_paths_or_none[cls] - except KeyError: - paths = _classify_for_args_hash(arg) - _arg_nd_paths_or_none[cls] = paths - if paths is None: + disposition = _arg_disposition.get(cls) + if disposition is None: + disposition = _classify_disposition(arg) + _arg_disposition[cls] = disposition + if disposition is _SKIP: + continue + paths = _struct_nd_paths_for(arg) + if not paths: continue for chain in paths: v = arg diff --git a/python/quadrants/lang/_template_mapper_hotpath.py b/python/quadrants/lang/_template_mapper_hotpath.py index 2f16c4dfb8..27127976af 100644 --- a/python/quadrants/lang/_template_mapper_hotpath.py +++ b/python/quadrants/lang/_template_mapper_hotpath.py @@ -76,41 +76,32 @@ _primitive_types = {int, float, bool} -# Per-instance ndarray-path cache, stored OFF-instance in a module-level ``id(arg) -> list[paths]`` dict and cleaned -# up via ``weakref.finalize``. We can't stash it on ``arg.__dict__`` because the fastcache args walker iterates -# every key in ``__dict__`` and rejects any unsupported type (``list`` isn't whitelisted) — that disabled the L1 -# cache for ~10 Genesis kernels and broke the ``test_static`` / ``test_num_envs`` / ``test_ndarray_no_compile`` -# subprocess tests. +# Per-instance cache of ndarray attribute paths, stashed on the instance via ``object.__setattr__`` (compatible with +# frozen dataclasses). Used by both ``TemplateMapper.lookup``'s args_hash walk and the ``_extract_arg`` data_oriented +# descriptor walk. Per-instance caching is necessary because @qd.data_oriented classes can have *different attribute +# structures across instances of the same class* — Genesis ``DataManager``, for instance, only allocates +# ``*_adjoint_cache`` members when ``requires_grad=True``. A class-level cache populated from the first-ever instance +# would either crash on missing attributes (forward direction, "first instance has, second misses") or silently miss +# new ones (inverse direction), both of which produce wrong-shape kernel reuse. # -# Per-instance (not per-class) caching is correctness-load-bearing: ``@qd.data_oriented`` and ``qd.Tensor`` -# containers can have polymorphic attribute structures across instances of the same class — Genesis ``DataManager`` -# only allocates ``*_adjoint_cache`` ndarrays when ``requires_grad=True``, and a ``qd.Tensor`` field can wrap an -# ``Ndarray`` on one instance and a ``MatrixField`` on another. A per-class cache populated from the first-walked -# instance reused those paths on a sibling and crashed ``_collect_struct_nd_descriptors`` (reading ``element_type`` -# off a ``MatrixField``) — affecting ~60 Genesis tests under the ndarray backend. +# Steady-state cost: one ``__dict__`` lookup per arg per call (~30ns), same order as the previous class-level +# ``dict.get``. The walk itself (``_build_struct_nd_paths``) is paid once per instance lifetime at first kernel +# launch with that instance — typically O(10) instances per Genesis scene, so ~10us total at scene build. # -# ``_struct_nd_paths_cache`` (per-class) remains as a fallback for objects that don't support ``weakref.finalize`` -# (e.g. ``__slots__`` classes without ``__weakref__``). Genesis containers all support weakrefs so the fast path -# below is what runs in practice. +# ``_struct_nd_paths_cache`` (below) is a fallback for ``__slots__`` classes that have no ``__dict__`` and so can't +# accept the ``object.__setattr__`` stash. Such classes inherit the legacy per-class-cache behaviour (and its +# polymorphic-instance limitations). Genesis data_oriented containers don't use ``__slots__``, so this branch is +# unreachable in practice. _struct_nd_paths_cache: dict[type, list[tuple]] = {} -_struct_nd_paths_instance_cache: dict[int, list[tuple]] = {} -def _drop_instance_paths(arg_id: int) -> None: - _struct_nd_paths_instance_cache.pop(arg_id, None) - - -def _build_struct_nd_paths(obj: Any, prefix: tuple, out: list, _seen: set | None = None) -> None: - # Cycle protection: real Genesis containers form attribute graphs with shared references and back-pointers (e.g. - # ``sim.rigid_solver.sim is sim``). Without ``_seen`` this recurses infinitely on the back-edge and blows the - # Python stack on first launch. Tracked by ``id(obj)`` so we don't accidentally rely on ``__hash__`` for arbitrary - # user types — and so primitives like equal-but-distinct dataclass instances are still walked independently. +def _build_struct_nd_paths(obj: Any, prefix: tuple, out: list, _seen: "set[int] | None" = None) -> None: + # Cycle-safe walker. Genesis object graphs have cross-references (e.g. ``solver -> scene -> sim -> solver``) and + # Pydantic-options-style children. ``_seen`` tracks ``id(obj)`` for the current traversal to avoid re-entering a + # node we've already expanded. Cheap (one ``set`` op per frame, only allocated when we actually start recursing) + # and bounds the walk to a finite depth regardless of the graph shape. if _seen is None: - _seen = set() - obj_id = id(obj) - if obj_id in _seen: - return - _seen.add(obj_id) + _seen = {id(obj)} if is_dataclass_instance(obj): children = ((f.name, getattr(obj, f.name)) for f in dataclasses.fields(obj)) else: @@ -130,24 +121,40 @@ def _build_struct_nd_paths(obj: Any, prefix: tuple, out: list, _seen: set | None if issubclass(v_type, Ndarray): out.append(chain) elif is_data_oriented(v) or is_dataclass_instance(v): + v_id = id(v) + if v_id in _seen: + continue + _seen.add(v_id) _build_struct_nd_paths(v, chain, out, _seen) def _struct_nd_paths_for(arg: Any) -> list[tuple]: - """Return the cached attribute paths (each a tuple of attr-name strings) at which ``Ndarray`` instances are - reachable from ``arg``. First call for an *instance* walks ``arg`` once via ``_build_struct_nd_paths`` and - caches the path list in ``_struct_nd_paths_instance_cache`` keyed by ``id(arg)``; ``weakref.finalize`` evicts - the entry when ``arg`` is garbage-collected. Subsequent calls are a dict lookup keyed by ``id(arg)``. + """Return the per-instance cached attribute paths (each a tuple of attr-name strings) at which ``Ndarray`` + instances are reachable from ``arg``. First call walks ``arg`` once via ``_build_struct_nd_paths`` and stashes + the result on the instance as ``_qd_nd_paths`` (via ``object.__setattr__`` so it works for frozen dataclasses + and ``@qd.data_oriented`` containers alike); subsequent calls fetch it via instance ``__dict__`` lookup. + + Per-instance caching is correctness-load-bearing (this is the fix for Codex #3 on PR #704, + https://github.com/Genesis-Embodied-AI/quadrants/pull/704#discussion_r3253281957): ``@qd.data_oriented`` classes + can have different attribute sets across instances of the same class (e.g. Genesis ``DataManager`` with vs + without ``requires_grad``), and even within an instance's lifetime a ``qd.Tensor`` member can swap backends, so + a per-class cache populated from one instance can't safely be reused for another. ``__slots__`` classes without + a ``__dict__`` fall back to per-class caching (see ``_struct_nd_paths_cache``) and retain the legacy limitation. - Per-instance (not per-class) caching is correctness-load-bearing — see the module-level comment on - ``_struct_nd_paths_instance_cache``. Objects that don't support ``weakref.finalize`` (e.g. ``__slots__`` classes - without ``__weakref__``) fall back to the legacy per-class cache; Genesis containers all support weakrefs so - the per-instance branch is the one that runs in practice. + Limitation: the path list is recorded once per instance. If a new ndarray attribute is attached to an instance + *after* its first kernel call (uncommon — Genesis containers declare all ndarrays in ``__init__``), it won't be + tracked until the cache is invalidated. Workaround: ``del arg.__dict__['_qd_nd_paths']`` (or restart the + process). """ - arg_id = id(arg) - paths = _struct_nd_paths_instance_cache.get(arg_id) - if paths is not None: - return paths + # Fast path: instance already walked. ``__dict__["…"]`` skips descriptor / ``__getattr__`` machinery (some + # third-party metaclasses, e.g. Pydantic, recurse infinitely on probe-style ``getattr`` for unknown names — + # see ``is_data_oriented`` for the same defensiveness). + try: + return arg.__dict__["_qd_nd_paths"] + except (AttributeError, KeyError): + pass + # ``__slots__`` fallback or first-sighting of this instance: check the class-level cache too, so that a + # ``__slots__`` class doesn't re-walk on every call. cls = type(arg) paths = _struct_nd_paths_cache.get(cls) if paths is not None: @@ -155,11 +162,11 @@ def _struct_nd_paths_for(arg: Any) -> list[tuple]: paths = [] _build_struct_nd_paths(arg, (), paths) try: - weakref.finalize(arg, _drop_instance_paths, arg_id) - except TypeError: + object.__setattr__(arg, "_qd_nd_paths", paths) + except AttributeError: + # ``__slots__`` class without a ``_qd_nd_paths`` slot — degrade to per-class caching. Loses correctness + # under polymorphic-instance attribute structure, but Genesis data_oriented containers don't use slots. _struct_nd_paths_cache[cls] = paths - else: - _struct_nd_paths_instance_cache[arg_id] = paths return paths @@ -191,13 +198,13 @@ def _collect_struct_nd_descriptors(arg: Any, out: list) -> None: """Emit per-ndarray shape descriptors ``(joined-path, element_type, ndim, needs_grad, layout)`` for every ndarray reachable from ``arg``. Used by the template-mapper to refine the spec key for ``@qd.data_oriented`` args holding ndarrays — see the data_oriented branch in ``_extract_arg``. - - The path cache is per-instance (see ``_struct_nd_paths_for``) so polymorphic-instance attribute structure is - handled correctly. Within a single instance's lifetime, a cached path's leaf may still cease to be an ``Ndarray`` - (e.g. a ``qd.Tensor`` member swapped from an ``Ndarray``-backed impl to a ``MatrixField``-backed one); when that - happens we silently skip the descriptor — the spec key still includes ``weakref(arg)`` so cache discrimination - remains correct. """ + # The path cache is per-instance (see ``_struct_nd_paths_for``) so polymorphic-instance attribute structure is + # handled correctly. Within a single instance's lifetime, a cached path's leaf may still cease to be an + # ``Ndarray`` (e.g. ``qd.Tensor``'s underlying impl swapped between an ``Ndarray`` and a ``MatrixField``); when + # that happens we silently skip the descriptor — ``v.element_type`` / ``v.shape`` / ``v._qd_layout`` are + # Ndarray-only accessors. The per-instance ``weakref(arg)`` part of the spec key still ensures correct cache + # discrimination across instances. for chain in _struct_nd_paths_for(arg): v = arg for a in chain: @@ -206,6 +213,8 @@ def _collect_struct_nd_descriptors(arg: Any, out: list) -> None: v = v._unwrap() if not isinstance(v, Ndarray): continue + # ``Ndarray.shape`` can legitimately be ``None`` (uninitialised ``_physical_shape``); such an instance + # has no meaningful spec contribution, so skip it rather than crashing on ``len(None)``. shape = v.shape if shape is None: continue @@ -287,6 +296,11 @@ def _extract_arg(raise_on_templated_floats: bool, arg: Any, annotation: Annotati # # Containers with no ndarrays keep the original short-path (one spec per instance via weakref) so this is # a no-op for the existing data_oriented + qd.field workloads (genesis field-backend). + # + # Opt-out: ``_qd_stable_members = True`` on the class (or ``@qd.data_oriented(stable_members=True)``) + # skips the per-call descriptor walk. + if type(arg).__dict__.get("_qd_stable_members"): + return weakref.ref(arg) nd_descriptors: list = [] _collect_struct_nd_descriptors(arg, nd_descriptors) if nd_descriptors: diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 509995b5b5..816cc9dd74 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -16,6 +16,7 @@ from quadrants._lib import core as _qd_core from quadrants.lang import exception, expr, impl, matrix, mesh from quadrants.lang import ops as qd_ops +from quadrants.lang._dataclass_util import create_flat_name from quadrants.lang._ndrange import _Ndrange from quadrants.lang._unpacked import _UnpackedVectorRef from quadrants.lang.ast.ast_transformer_utils import ( @@ -86,6 +87,21 @@ def build_Name(ctx: ASTTransformerFuncContext, node: ast.Name): pruning = ctx.global_context.pruning if not pruning.enforcing and not ctx.expanding_dataclass_call_parameters and node.id.startswith("__qd_"): ctx.global_context.pruning.mark_used(ctx.func.func_id, node.id) + # Track chains rooted at non-flattened parameter names: top-level ``@qd.kernel`` args + # (``ctx.kernel_args``) and ``@qd.func`` params (``ctx.fn_param_names``). Both appear in the AST as bare + # names (``self`` for a data_oriented kernel arg; ``static_rigid_sim_config`` for a ``qd.template()`` func + # arg bound to a ``@qd.data_oriented`` instance). ``build_Attribute`` propagates this annotation through + # ``state.dofs.x`` chains and ``mark_kernel_arg_chain_used``-s the flat name. The kernel's pruning narrow + # walk picks them up directly (kernel case) or after ``record_after_call`` propagates the callee's func-arg + # chains back through the call boundary (func case): e.g. ``func(s=self._sub)`` where ``func`` reads ``s.x`` + # ends up with ``__qd_self__qd__sub__qd_x`` recorded in the kernel's pruning, so the args-hasher hashes that + # primitive value into the fastcache key. + # Dataclass args go through ``FlattenAttributeNameTransformer`` and reach this branch as already-flat + # ``__qd_…`` Names, handled by the block above via ``mark_used``. + if not node.id.startswith("__qd_") and (node.id in ctx.kernel_args or node.id in ctx.fn_param_names): + node._qd_arg_chain = node.id # type: ignore[attr-defined] + else: + node._qd_arg_chain = None # type: ignore[attr-defined] node.violates_pure, node.ptr, node.violates_pure_reason = ctx.get_var_by_name(node.id) # Flattened struct fields (``__qd_foo__qd_bar``) injected by ``populate_global_vars_from_dataclass`` are raw # ``Ndarray`` instances. ``build_Attribute`` already promotes these via ``_promote_ndarray_if_declared`` but @@ -668,14 +684,37 @@ def build_attribute_if_is_dynamic_snode_method(ctx: ASTTransformerFuncContext, n @staticmethod def _promote_ndarray_if_declared(ctx: ASTTransformerFuncContext, value: Any) -> Any: """If *value* is a bare ``Ndarray`` that was pre-declared as a kernel arg (in ``_predeclare_struct_ndarrays``), - return the ``AnyArray`` proxy from the cache. Otherwise return *value* unchanged.""" + return the ``AnyArray`` proxy from the cache. Otherwise return *value* unchanged. + + Also records the source ndarray id in ``pruning.used_struct_ndarray_ids`` on the non-enforcing first pass, so + that the enforcing second-pass ``_predeclare_struct_ndarrays`` can skip ndarrays that the kernel never actually + accesses. Both ``Ndarray`` instances and pre-existing ``AnyArray`` proxies (tagged with + ``_qd_source_ndarray_id``) are handled — the latter is the case for accesses in inlined ``@qd.func`` bodies + whose params were bound to already-promoted proxies by Option A in ``call_transformer``. + """ from quadrants.lang._ndarray import Ndarray # pylint: disable=C0415 - if not isinstance(value, Ndarray): + pruning = ctx.global_context.pruning + # Mirror ``build_Name``'s mark_used gate: only mark on the non-enforcing first pass and not during synthetic + # per-leaf argument expansion for ``@qd.func`` calls. The callee body's own accesses (which run with + # ``expanding_dataclass_call_parameters = False``) are what we want to count. + should_mark = not pruning.enforcing and not ctx.expanding_dataclass_call_parameters + if isinstance(value, Ndarray): + cache = ctx.global_context.ndarray_to_any_array + key = id(value) + arr = cache.get(key) + if arr is not None: + if should_mark: + pruning.used_struct_ndarray_ids.add(key) + return arr return value - cache = ctx.global_context.ndarray_to_any_array - arr = cache.get(id(value)) - return arr if arr is not None else value + # Pre-promoted ``AnyArray`` flowing through an inlined ``@qd.func`` body. Mark the underlying ndarray as used + # so it survives the enforcing-pass pruning. + if should_mark: + src_id = getattr(value, "_qd_source_ndarray_id", None) + if src_id is not None: + pruning.used_struct_ndarray_ids.add(src_id) + return value @staticmethod def build_Attribute(ctx: ASTTransformerFuncContext, node: ast.Attribute): @@ -794,6 +833,26 @@ def build_Attribute(ctx: ASTTransformerFuncContext, node: ast.Attribute): warnings.warn(message) else: raise exception.QuadrantsCompilationError(message) + # Propagate the kernel-arg-rooted chain annotation and record this access in pruning's *separate* chain-paths + # set. ``build_Name`` sets ``_qd_arg_chain`` on non-flattened kernel args (e.g. data_oriented ``self``); each + # Attribute access in the chain extends it (``self`` → ``__qd_self__qd_x`` → ``__qd_self__qd_x__qd_y``). + # + # Why not ``mark_used``? On the enforcing pass, ``Kernel.materialize`` uses ``pruning.used_vars_by_func_id`` as + # ``struct_locals``, which drives ``FlattenAttributeNameTransformer`` — adding ``__qd_self__qd_x`` there would + # make the transformer rewrite ``self.x`` into ``Name('__qd_self__qd_x')``, and ``build_Name`` would then fail + # to find such a variable. ``mark_kernel_arg_chain_used`` puts the chain into a *separate* per-func set that's + # merged into ``used_vars_by_func_id[KERNEL_FUNC_ID]`` only *after* both compile passes, by + # ``Pruning.fold_kernel_arg_chain_paths`` — so the fastcache args-hash narrow walk picks them up without + # breaking codegen. + parent_chain = getattr(node.value, "_qd_arg_chain", None) + if parent_chain is not None: + flat = create_flat_name(parent_chain, node.attr) + node._qd_arg_chain = flat # type: ignore[attr-defined] + pruning = ctx.global_context.pruning + if not pruning.enforcing and not ctx.expanding_dataclass_call_parameters: + pruning.mark_kernel_arg_chain_used(ctx.func.func_id, flat) + else: + node._qd_arg_chain = None # type: ignore[attr-defined] return node.ptr @staticmethod diff --git a/python/quadrants/lang/ast/ast_transformer_utils.py b/python/quadrants/lang/ast/ast_transformer_utils.py index 506778c683..fa784a3522 100644 --- a/python/quadrants/lang/ast/ast_transformer_utils.py +++ b/python/quadrants/lang/ast/ast_transformer_utils.py @@ -247,6 +247,17 @@ def __init__( self.visited_funcdef = False self.is_real_function = is_real_function self.kernel_args: list = [] + # Names of the bare (non-flattened) parameters of a ``@qd.func`` being processed. Used by ``build_Name`` to + # seed ``_qd_arg_chain`` for attribute accesses rooted at a func param (e.g. + # ``static_rigid_sim_config.para_level`` where ``static_rigid_sim_config`` is a ``qd.template()`` arg bound to + # a ``@qd.data_oriented`` instance). Without this, chains rooted at func params would not be recorded in + # pruning, and the args-hasher would skip over kernel-read primitive members of nested data_oriented + # containers — leading to stale fastcache hits when those members change between calls. + # ``kernel_args`` only tracks top-level ``@qd.kernel`` args; ``_transform_func_arg`` for a ``@qd.func`` does + # not append to it (see function_def_transformer.py). This separate set avoids piggy-backing on + # ``kernel_args`` so the existing "kernel arg is immutable" diagnostic in ``build_assign_annotated`` doesn't + # start firing for func params. + self.fn_param_names: set[str] = set() self.only_parse_function_def: bool = False self.autodiff_mode = autodiff_mode self.loop_depth: int = 0 diff --git a/python/quadrants/lang/ast/ast_transformers/call_transformer.py b/python/quadrants/lang/ast/ast_transformers/call_transformer.py index 0d709ebd01..2bc22e8650 100644 --- a/python/quadrants/lang/ast/ast_transformers/call_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/call_transformer.py @@ -166,17 +166,23 @@ def _canonicalize_formatted_string(raw_string: str, *raw_args: list, **raw_keywo @staticmethod def _expand_Call_dataclass_args( - ctx: ASTTransformerFuncContext, args: tuple[ast.stmt, ...] + ctx: ASTTransformerFuncContext, + args: tuple[ast.stmt, ...], + called_needed: set[str] | None = None, + callee_arg_names: list[str] | None = None, ) -> tuple[tuple[ast.stmt, ...], tuple[ast.stmt, ...]]: """ - We require that each node has a .ptr attribute added to it, that contains - the associated Python object + We require that each node has a .ptr attribute added to it, that contains the associated Python object. + + ``called_needed`` and ``callee_arg_names`` are used only for the attribute-accessed-instance branch (Option A + for data_oriented @qd.func calls): the caller cannot construct a flat name from its own ``arg.id`` (the arg is + an ast.Attribute), so we look up pruning against the callee's parameter name at the same positional index. """ args_new = [] added_args = [] pruning = ctx.global_context.pruning func_id = ctx.func.func_id - for arg in args: + for arg_idx, arg in enumerate(args): val = arg.ptr if dataclasses.is_dataclass(val) and isinstance(val, type): dataclass_type = val @@ -204,6 +210,54 @@ def _expand_Call_dataclass_args( else: args_new.append(arg_node) added_args.append(arg_node) + elif dataclasses.is_dataclass(val) and not isinstance(val, type): + # Dataclass *instance* passed positionally (e.g. ``self.state`` inside a @qd.data_oriented kernel + # method). Expand into per-leaf attribute accesses against the same AST node, mirroring the typed-arg + # (instance-of-type) path above but emitting ``ast.Attribute`` children rather than ``ast.Name``. + # ``added_args`` items must not carry ``.ptr`` (build_stmt populates it downstream); only the + # intermediate node used for recursion does. + dataclass_type = type(val) + # For pruning, match the callee's flat name (it may have pruned unused fields). Use the callee's + # parameter name at this positional index. + callee_param = ( + callee_arg_names[arg_idx] + if (called_needed is not None and callee_arg_names is not None and arg_idx < len(callee_arg_names)) + else None + ) + for field in dataclasses.fields(dataclass_type): + if called_needed is not None and callee_param is not None: + callee_flat_name = create_flat_name(callee_param, field.name) + if callee_flat_name not in called_needed: + continue + child_val = getattr(val, field.name) + load_ctx = ast.Load() + child_node = ast.Attribute( + value=arg, + attr=field.name, + ctx=load_ctx, + lineno=arg.lineno, + end_lineno=arg.end_lineno, + col_offset=arg.col_offset, + end_col_offset=arg.end_col_offset, + ) + if dataclasses.is_dataclass(child_val) and not isinstance(child_val, type): + child_node.ptr = child_val + # Recurse, threading the renamed scope: the callee's expanded flat name (e.g. + # ``__qd_state__inner``) is the synthetic param name for the nested level. + nested_callee_param = ( + create_flat_name(callee_param, field.name) if callee_param is not None else None + ) + _added_args, _args_new = CallTransformer._expand_Call_dataclass_args( + ctx, + (child_node,), + called_needed=called_needed, + callee_arg_names=[nested_callee_param] if nested_callee_param is not None else None, + ) + args_new.extend(_args_new) + added_args.extend(_added_args) + else: + args_new.append(child_node) + added_args.append(child_node) else: args_new.append(arg) return tuple(added_args), tuple(args_new) @@ -261,6 +315,46 @@ def _expand_Call_dataclass_kwargs( else: kwargs_new.append(kwarg_node) added_kwargs.append(kwarg_node) + elif dataclasses.is_dataclass(val) and not isinstance(val, type): + # Dataclass *instance* passed as a keyword arg (e.g. ``write(state=self.state)`` inside a + # @qd.data_oriented kernel method). Expand into per-leaf keyword args whose values are attribute + # accesses against the original value node (e.g. ``__qd_state__x=self.state.x``). + dataclass_type = type(val) + for field in dataclasses.fields(dataclass_type): + child_name = create_flat_name(kwarg.arg, field.name) + if used_args is not None and child_name not in used_args: + continue + child_val = getattr(val, field.name) + load_ctx = ast.Load() + src_node = ast.Attribute( + value=kwarg.value, + attr=field.name, + ctx=load_ctx, + lineno=kwarg.lineno, + end_lineno=kwarg.end_lineno, + col_offset=kwarg.col_offset, + end_col_offset=kwarg.end_col_offset, + ) + src_node.ptr = child_val + kwarg_node = ast.keyword( + arg=child_name, + value=src_node, + ctx=load_ctx, + lineno=kwarg.lineno, + end_lineno=kwarg.end_lineno, + col_offset=kwarg.col_offset, + end_col_offset=kwarg.end_col_offset, + ) + if dataclasses.is_dataclass(child_val) and not isinstance(child_val, type): + kwarg_node.ptr = {child_name: child_val} + _added_kwargs, _kwargs_new = CallTransformer._expand_Call_dataclass_kwargs( + ctx, [kwarg_node], used_args + ) + kwargs_new.extend(_kwargs_new) + added_kwargs.extend(_added_kwargs) + else: + kwargs_new.append(kwarg_node) + added_kwargs.append(kwarg_node) else: kwargs_new.append(kwarg) return added_kwargs, kwargs_new @@ -286,11 +380,21 @@ def build_Call(ctx: ASTTransformerFuncContext, node: ast.Call, build_stmt, build is_func_base_wrapper = func_type in {QuadrantsCallable, BoundQuadrantsCallable} pruning = ctx.global_context.pruning called_needed = None + callee_arg_names: list[str] | None = None if pruning.enforcing and is_func_base_wrapper: called_func_id_ = func.wrapper.func_id # type: ignore called_needed = pruning.used_vars_by_func_id[called_func_id_] + if is_func_base_wrapper: + # callee param names (used by the attribute-instance positional-expansion path so it can match the + # callee's already-pruned flat names). + try: + callee_arg_names = [m.name for m in func.wrapper.arg_metas] # type: ignore[attr-defined] + except AttributeError: + callee_arg_names = None - added_args, node_args = CallTransformer._expand_Call_dataclass_args(ctx, node.args) + added_args, node_args = CallTransformer._expand_Call_dataclass_args( + ctx, node.args, called_needed=called_needed, callee_arg_names=callee_arg_names + ) added_keywords, node_keywords = CallTransformer._expand_Call_dataclass_kwargs(ctx, node.keywords, called_needed) # Create variables for the now-expanded dataclass members. diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 9f17c5c5ee..7668620467 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -156,8 +156,8 @@ def _transform_kernel_arg( elif isinstance(field.type, type) and getattr(field.type, "_data_oriented", False): # ``@qd.data_oriented`` field type inside a typed-dataclass kernel arg. The two patterns are # semantically incompatible at this layer: dataclass kernel-arg recursion uses annotations to - # flatten leaf fields into per-leaf kernel args at compile time, but data_oriented containers don't - # carry per-attribute type annotations — they need a value-driven walk + # flatten leaf fields into per-leaf kernel args at compile time, but data_oriented containers + # don't carry per-attribute type annotations — they need a value-driven walk # (``_predeclare_struct_ndarrays``), which only fires for ``qd.template()`` / ``qd.Tensor`` # annotations. Rather than silently miscompile, raise a clear error pointing users to the # recommended pattern. @@ -231,20 +231,40 @@ def _predeclare_struct_ndarrays(ctx: ASTTransformerFuncContext) -> None: Also stores ``(arg_id, template_arg_idx, attr_chain)`` tuples in ``ctx.global_context.struct_ndarray_launch_info`` so the launch path can populate the corresponding slots in the launch context. + + Pruning: in the enforcing (second) compile pass, ``pruning.used_struct_ndarray_ids`` contains the set of + ``id(ndarray)`` values that ``_promote_ndarray_if_declared`` observed being accessed during the first pass + (directly in the kernel body, or transitively through ``@qd.func`` inlining). We register only those, dropping + every unused ndarray from the kernel's parameter list. On the first pass the set is empty / not yet populated, + so we register everything as today (correctness: the first pass needs every reachable ndarray in the cache for + ``build_Attribute`` to resolve the accesses that *will* populate the set). """ + from quadrants.lang._pruning import Pruning # pylint: disable=C0415 from quadrants.lang.util import cook_dtype # pylint: disable=C0415 cache = ctx.global_context.ndarray_to_any_array launch_info = ctx.global_context.struct_ndarray_launch_info + pruning = ctx.global_context.pruning + used_ids = getattr(pruning, "used_struct_ndarray_ids", None) + # Only prune on the enforcing pass when we actually ran pass 0 to populate the used-ndarray set. On a + # fastcache hit pass 0 is skipped and the set is empty. + prune = pruning.enforcing and used_ids is not None and getattr(pruning, "pass_0_ran", False) + # On a fastcache hit (enforcing without a pass-0 run), the `id(nd)` set is empty, but the *flat-name* set on + # ``used_vars_by_func_id[KERNEL_FUNC_ID]`` was loaded from cache and already contains every kernel-accessed + # leaf path (folded in by ``Pruning.fold_struct_nd_paths`` during the compile that produced the cache entry). + # Use that to prune the walk so we register the exact same ndarray set as the originating compile produced — + # without this, every reachable ndarray gets registered, the kernel's arg slots get rebound to the wrong + # ndarrays at launch, and physics silently breaks. + prune_from_flat_names = pruning.enforcing and not getattr(pruning, "pass_0_ran", False) + kernel_used_flat_names = ( + pruning.used_vars_by_func_id.get(Pruning.KERNEL_FUNC_ID, set()) if prune_from_flat_names else None + ) - # ``_seen`` set guards against attribute-graph cycles in user containers (e.g. Genesis ``sim.solver.sim is - # sim``). Without it this walker recurses infinitely on the back-edge and blows the Python stack at compile - # time. Tracked by ``id(obj)`` to avoid relying on ``__hash__`` for arbitrary user types. + # Cycle-safe walker: Genesis object graphs have cross-references (e.g. solver <-> scene <-> sim) so we must + # avoid re-entering the same node. ``seen`` is shared across the whole arg's traversal — ``id(obj)`` is + # stable for the duration of this compile and we never need to revisit a node since the ndarray-set rooted at + # it doesn't depend on the path we took to reach it. def _walk_obj(obj, arg_idx, path, seen): - obj_id = id(obj) - if obj_id in seen: - return - seen.add(obj_id) if is_dataclass_instance(obj): for field in dataclasses.fields(obj): child = getattr(obj, field.name) @@ -253,6 +273,10 @@ def _walk_obj(obj, arg_idx, path, seen): if isinstance(child, _ndarray.Ndarray): _register_ndarray(child, arg_idx, (*path, field.name)) elif is_dataclass_instance(child) or is_data_oriented(child): + child_id = id(child) + if child_id in seen: + continue + seen.add(child_id) _walk_obj(child, arg_idx, (*path, field.name), seen) else: for attr_name, attr_val in vars(obj).items(): @@ -261,12 +285,31 @@ def _walk_obj(obj, arg_idx, path, seen): if isinstance(attr_val, _ndarray.Ndarray): _register_ndarray(attr_val, arg_idx, (*path, attr_name)) elif is_dataclass_instance(attr_val) or is_data_oriented(attr_val): + attr_id = id(attr_val) + if attr_id in seen: + continue + seen.add(attr_id) _walk_obj(attr_val, arg_idx, (*path, attr_name), seen) def _register_ndarray(nd, arg_idx, attr_chain): key = id(nd) if key in cache: return + if prune and key not in used_ids: + return + if prune_from_flat_names: + # Build the leaf flat name (e.g. ``__qd_self__qd__collider_state__qd_active_buffer``) + # and skip registration when the kernel's cached pruning set doesn't contain it. + if arg_idx < 0 or arg_idx >= len(ctx.func.arg_metas): + return + arg_name = ctx.func.arg_metas[arg_idx].name + if not arg_name: + return + flat = arg_name + for attr in attr_chain: + flat = create_flat_name(flat, attr) + if flat not in kernel_used_flat_names: + return from quadrants._lib import core as _qd_core # pylint: disable=C0415 element_type = cook_dtype(nd.element_type) @@ -281,6 +324,10 @@ def _register_ndarray(nd, arg_idx, attr_chain): _qd_core.make_external_tensor_expr(element_type, ndim, arg_id_vec, needs_grad, BoundaryMode.UNSAFE), _qd_layout=layout, ) + # Tag the AnyArray with the source ndarray id so ``_promote_ndarray_if_declared`` can mark this ndarray + # as used even when the access reaches it via an already-promoted AnyArray (e.g. callee bodies bound to + # per-leaf args by Option A). + arr._qd_source_ndarray_id = key cache[key] = arr launch_info.append((arg_id_vec[0], arg_idx, attr_chain)) @@ -297,9 +344,9 @@ def _register_ndarray(nd, arg_idx, attr_chain): if isinstance(val, _ndarray.Ndarray): continue if is_dataclass_instance(val): - _walk_obj(val, i, (), set()) + _walk_obj(val, i, (), {id(val)}) elif hasattr(val, "__dict__"): - _walk_obj(val, i, (), set()) + _walk_obj(val, i, (), {id(val)}) @staticmethod def _unwrap_tensor(data: Any) -> Any: @@ -315,6 +362,15 @@ def _transform_func_arg( argument_type: Any, data: Any, ) -> None: + # Record the bare (non-flattened) func param name so ``build_Name`` can seed ``_qd_arg_chain`` for attribute + # accesses rooted at this param. Critical for ``qd.template()`` args bound to ``@qd.data_oriented`` instances + # (e.g. ``static_rigid_sim_config.para_level`` inside a ``@qd.func``): without this, the kernel's pruning set + # never learns about ``.para_level``, the args-hasher skips the value, and different ``para_level`` + # configurations collide in the fastcache key. Flat names starting with ``__qd_`` arrive here too via the + # dataclass-flatten recursion below; they're harmless to add (``build_Name``'s chain branch gates on + # ``not node.id.startswith("__qd_")``) but the bare-name entries are what enables propagation. + ctx.fn_param_names.add(argument_name) + # Template arguments are passed by reference. if isinstance(argument_type, annotations.template): ctx.create_variable(argument_name, data) diff --git a/python/quadrants/lang/kernel.py b/python/quadrants/lang/kernel.py index 8955b7f565..f4533906ff 100644 --- a/python/quadrants/lang/kernel.py +++ b/python/quadrants/lang/kernel.py @@ -343,23 +343,60 @@ def reset(self) -> None: self.fe_ll_cache_observations = FeLlCacheObservations() def _try_load_fastcache(self, args: tuple[Any, ...], key: "CompiledKernelKeyType") -> set[str] | None: - frontend_cache_key: str | None = None + """Two-phase fastcache lookup. + + Phase 1 — L1 lookup keyed by source+config only (no args). Returns the set of kernel-accessed flat + names (pruning info). Hit OR miss, this only determines whether we have pruning info for the narrow + args walk; it never on its own justifies skipping pass 0 — that requires the C++ artifact to load. + + Phase 2 — narrow args walk + L2 lookup + artifact load. Only when *all three* succeed do we return + non-None and let ``materialize`` skip pass 0. The reason: pass 0 is what populates pruning info for + *every called ``@qd.func``* (not just the kernel itself). Skipping pass 0 is only safe when pass 1 + runs in ``only_parse_function_def`` mode (i.e. the C++ artifact is already loaded so the AST walker + never enters any callee body); otherwise callee variables can't be found in their func's empty + ``used_vars_by_func_id`` set and the build fails with "Name __qd_… is not defined". + + Side effects: populates ``self._l1_key`` (always when fastcache is active), ``self._pruning_paths_from_l1`` + (the L1 pruning info, or None if L1 miss — used by ``materialize`` for L1-store skipping and for + post-compile narrow-hash construction), and ``self.fast_checksum`` (the L2 key, when phase 2 computed + the narrow args hash). All three are read by the post-compile path in ``_maybe_persist_l1_and_set_l2_key``. + """ + self._l1_key = None # type: ignore[attr-defined] + self._pruning_paths_from_l1 = None # type: ignore[attr-defined] + self.fast_checksum = None if self.runtime.src_ll_cache and self.quadrants_callable and self.quadrants_callable.is_pure: kernel_source_info, _src = get_source_info_and_src(self.func) - self.fast_checksum = src_hasher.create_cache_key( - self.raise_on_templated_floats, kernel_source_info, args, self.arg_metas + self._kernel_source_info_cached = kernel_source_info # reused by materialize / launch_kernel + self._l1_key = src_hasher.make_source_config_key(kernel_source_info) + + # Phase 1: L1 lookup — pruning info only, no args walk yet. + pruning_paths, cached_graph_do_while_arg = src_hasher.load_pruning_info(self._l1_key) + if pruning_paths is None: + # Cold L1. ``materialize`` will compile pass 0 + pass 1 to populate pruning info, then we + # store L1 + L2 after compile. ``cache_key_generated`` is intentionally NOT flipped to True + # here: it tracks "fastcache produced a valid L2 args hash" (the pre-refactor semantic), and + # we don't know yet whether the narrow args walk will succeed. + return None + self._pruning_paths_from_l1 = pruning_paths + + # Phase 2: narrow args hash + L2 lookup. + narrow_args_hash = src_hasher.compute_narrow_args_hash( + self.raise_on_templated_floats, kernel_source_info, args, self.arg_metas, pruning_paths + ) + if narrow_args_hash is None: + # Recognised-but-unsupported tensor-like (Field / MatrixField) — fastcache off for this call. + # ``self.fast_checksum`` stays None so no L2 entry is written; ``cache_key_generated`` stays + # False to match the pre-refactor "Field disables fastcache key generation" contract. + return None + self.fast_checksum = src_hasher.make_full_cache_key(self._l1_key, narrow_args_hash) + self.src_ll_cache_observations.cache_key_generated = True + + used_py_dataclass_parameters, frontend_cache_key, cached_graph_do_while_arg_l2 = src_hasher.load( + self.fast_checksum ) - used_py_dataclass_parameters = None - cached_graph_do_while_arg: str | None = None - if self.fast_checksum: - self.src_ll_cache_observations.cache_key_generated = True - used_py_dataclass_parameters, frontend_cache_key, cached_graph_do_while_arg = src_hasher.load( # type: ignore[reportAssignmentType] - self.fast_checksum - ) if used_py_dataclass_parameters is not None and frontend_cache_key is not None: self.src_ll_cache_observations.cache_validated = True prog = impl.get_runtime().prog - assert self.fast_checksum is not None self.compiled_kernel_data_by_key[key] = prog.load_fast_cache( frontend_cache_key, self.func.__name__, @@ -369,11 +406,15 @@ def _try_load_fastcache(self, args: tuple[Any, ...], key: "CompiledKernelKeyType if self.compiled_kernel_data_by_key[key]: self.src_ll_cache_observations.cache_loaded = True self.used_py_dataclass_parameters_by_key_enforcing[key] = used_py_dataclass_parameters - if cached_graph_do_while_arg is not None: - self.graph_do_while_arg = cached_graph_do_while_arg + self.graph_do_while_arg = cached_graph_do_while_arg_l2 or cached_graph_do_while_arg return used_py_dataclass_parameters + # L2 miss or artifact load failed: report cold so ``materialize`` does pass 0 + pass 1 (needed + # to populate per-callee pruning info). ``self.fast_checksum`` is still set so the post-compile + # ``src_hasher.store`` will write a fresh L2 entry under the narrow-args key. + self.graph_do_while_arg = cached_graph_do_while_arg or self.graph_do_while_arg + return None - elif self.quadrants_callable and not self.quadrants_callable.is_pure and self.runtime.print_non_pure: + if self.quadrants_callable and not self.quadrants_callable.is_pure and self.runtime.print_non_pure: # The bit in caps should not be modified without updating corresponding test # freetext can be freely modified. # As for why we are using `print` rather than eg logger.info, it is because this is only printed when @@ -384,7 +425,6 @@ def _try_load_fastcache(self, args: tuple[Any, ...], key: "CompiledKernelKeyType def materialize(self, key: "CompiledKernelKeyType | None", py_args: tuple[Any, ...], arg_features=None): if key is None: key = (self.func, 0, self.autodiff_mode) - self.fast_checksum = None if key in self.materialized_kernels: return @@ -428,6 +468,8 @@ def materialize(self, key: "CompiledKernelKeyType | None", py_args: tuple[Any, . range_begin = 0 if used_py_dataclass_parameters is None else 1 runtime = impl.get_runtime() for _pass in range(range_begin, 2): + if _pass == 0: + pruning.pass_0_ran = True if _pass >= 1: pruning.enforce() tree, ctx = self.get_tree_and_ctx( @@ -461,6 +503,23 @@ def materialize(self, key: "CompiledKernelKeyType | None", py_args: tuple[Any, . self._struct_ndarray_launch_info_by_key[key] = getattr( ctx.global_context, "struct_ndarray_launch_info", [] ) + # Fold data_oriented ndarray attribute chains into the kernel's used-flat-names set so + # ``args_hasher.hash_args`` can narrow data_oriented walks too. ``used_vars_by_func_id`` only + # contains flat names from dataclass-arg expansion in ``extract_struct_locals_from_context``; + # data_oriented args don't go through that expansion, so accesses like ``self.x`` on an ndarray + # member are only tracked via ``struct_ndarray_launch_info``. Without this fold, narrow hashing + # for data_oriented args walks nothing — every (arg_idx, attr_chain) pair gets the same hash + # regardless of dtype, so changing ``state.x``'s dtype no longer invalidates the cache (the + # ``test_data_oriented_ndarray_fastcache_dtype_key_distinct`` pin caught this). + pruning.fold_struct_nd_paths(self._struct_ndarray_launch_info_by_key.get(key, []), self.arg_metas) + # Fold non-ndarray kernel-arg-rooted chain paths (primitives, opaque members, nested struct + # paths) collected by ``ASTTransformer.build_Attribute``'s ``_qd_arg_chain`` tracking. Kept + # separate from ``used_vars_by_func_id`` during compile (would otherwise poison ``struct_locals`` + # and break codegen) — see the field-level docstring on + # ``Pruning.kernel_arg_chain_paths_by_func_id``. This fold + the existing ``used_vars`` assignment + # to ``used_py_dataclass_parameters_by_key_enforcing`` share the same set by reference, so the + # final fastcache L1 entry sees all kernel-accessed paths. + pruning.fold_kernel_arg_chain_paths() else: for used_parameters in pruning.used_vars_by_func_id.values(): new_used_parameters = set() @@ -478,6 +537,27 @@ def materialize(self, key: "CompiledKernelKeyType | None", py_args: tuple[Any, . ] runtime._current_global_context = None + # Post-compile fastcache bookkeeping. See ``_maybe_persist_l1_and_set_l2_key`` docstring. + self._maybe_persist_l1_and_set_l2_key(key, py_args) + + def _maybe_persist_l1_and_set_l2_key(self, key: "CompiledKernelKeyType", py_args: tuple[Any, ...]) -> None: + """Thin delegate to ``src_hasher.persist_l1_and_set_l2_key``; see that function's docstring for behaviour.""" + new_fast_checksum, generated = src_hasher.persist_l1_and_set_l2_key( + l1_key=getattr(self, "_l1_key", None), + kernel_source_info=getattr(self, "_kernel_source_info_cached", None), + used_py_dataclass_parameters=self.used_py_dataclass_parameters_by_key_enforcing.get(key), + visited_functions=self.visited_functions, + graph_do_while_arg=self.graph_do_while_arg, + pruning_paths_from_l1=getattr(self, "_pruning_paths_from_l1", None), + fast_checksum=self.fast_checksum, + raise_on_templated_floats=self.raise_on_templated_floats, + py_args=py_args, + arg_metas=self.arg_metas, + ) + if generated: + self.fast_checksum = new_fast_checksum + self.src_ll_cache_observations.cache_key_generated = True + def launch_kernel( self, key, t_kernel: KernelCxx, compiled_kernel_data: CompiledKernelData | None, *args, qd_stream=None ) -> Any: @@ -502,15 +582,26 @@ def launch_kernel( # container wrapping a mutable inner container that holds the ndarray (e.g. frozen dataclass -> data_oriented # -> ndarray), id(outer) alone does not capture leaf rebinding because the inner container can still reassign # ``.x``. So we OR-fold the mutability check across every parent along ``chain`` from the root down to (but - # excluding) the leaf attribute. + # excluding) the leaf attribute. See ``chain_has_mutable_container`` in ``_template_mapper_hotpath`` for the + # exact predicate. + # + # ``_qd_stable_members = True`` on a ``@qd.data_oriented`` class (or ``@qd.data_oriented(stable_members=True)``) + # is a launch-time opt-out: the user promises ndarray members are never reassigned on instances of that class, + # so we can skip the per-call ``_resolve_struct_ndarray`` walk entirely for args of that type. if key != self._mutable_nd_cached_key: if self._struct_ndarray_launch_info_by_key: struct_nd_info = self._struct_ndarray_launch_info_by_key.get(key) if struct_nd_info: + # Data_oriented containers marked ``_qd_stable_members = True`` (or decorated with + # ``@qd.data_oriented(stable_members=True)``) promise their ndarray members are never reassigned, + # so we exclude them from the per-call ``_resolve_struct_ndarray`` walk that builds ``args_hash``. + # This is a *launch-time perf hint only* and has no fastcache role — fastcache derives its key + # from kernel-pruning info regardless of this flag. self._mutable_nd_cached_val = [ (idx, chain) for _, idx, chain in struct_nd_info - if chain_has_mutable_container(args, idx, chain) + if not type(args[idx]).__dict__.get("_qd_stable_members") + and chain_has_mutable_container(args, idx, chain) ] else: self._mutable_nd_cached_val = [] diff --git a/python/quadrants/lang/kernel_impl.py b/python/quadrants/lang/kernel_impl.py index 01c74a256d..f3dedca01e 100644 --- a/python/quadrants/lang/kernel_impl.py +++ b/python/quadrants/lang/kernel_impl.py @@ -275,7 +275,7 @@ def grad(self, *args, **kwargs) -> "Kernel": return self._adjoint(self._kernel_owner, *args, **kwargs) -def data_oriented(cls): +def data_oriented(cls=None, *, stable_members: bool = False): """Marks a class as Quadrants compatible. To allow for modularized code, Quadrants provides this decorator so that @@ -299,21 +299,45 @@ def data_oriented(cls): >>> a.inc() Args: - cls (Class): the class to be decorated + cls (Class): the class to be decorated. + stable_members (bool): launch-context perf hint — if ``True``, declares that the class's ndarray-typed members + are allocated once and never reassigned between kernel calls. Quadrants will skip the per-call ndarray- + reference walk that ``Kernel.launch_kernel`` uses to detect ndarray reassignment on mutable containers + (~1-2 us/call savings on Genesis-style containers with dozens of ndarray attrs). Reassigning a member on + a ``stable_members`` class is undefined behaviour — the previously-compiled kernel will be reused even if + the new ndarray has different dtype/ndim/layout. May also be set as a class-level attribute + ``_qd_stable_members = True`` (equivalent). + + Note: this flag is *purely* a launch-time perf hint. It no longer affects fastcache argument hashing — the + fastcache key is derived from pruning info (the set of flat names the kernel actually reads), and + unrecognised types at kernel-read paths fail fastcache loudly with a one-shot ``[UNKNOWN_TYPE]`` + + ``[INVALID_FUNC]`` diagnostic (no qualname fallback). See ``docs/source/user_guide/fastcache.md``. Returns: - The decorated class. + The decorated class (or, when called with arguments, a decorator). """ + if cls is None: + return lambda c: data_oriented(c, stable_members=stable_members) + + def make_kernel_indirect(fun, is_property, attr_name): + # Capture the primal at decoration time so the per-call path skips the ``_BoundedDifferentiableMethod`` + # allocation. The class itself is validated when ``_BoundedDifferentiableMethod`` is invoked via the + # ``.grad()`` path; for the common primal call here we replicate the check inline. + primal = fun._primal - def make_kernel_indirect(fun, is_property): @wraps(fun) def _kernel_indirect(self, *args, **kwargs): - nonlocal fun - ret = _BoundedDifferentiableMethod(self, fun) - ret.__name__ = fun.__name__ # type: ignore - return ret(*args, **kwargs) + try: + return primal(self, *args, **kwargs) + except (QuadrantsCompilationError, QuadrantsRuntimeError) as e: + if impl.get_runtime().print_full_traceback: + raise e + raise type(e)("\n" + str(e)) from None ret = QuadrantsCallable(fun, _kernel_indirect) + # setattr-after-class doesn't trigger __set_name__; set the name explicitly so QuadrantsCallable.__get__ can + # cache the BoundQuadrantsCallable on instance.__dict__. + ret._attr_name = attr_name if is_property: ret = property(ret) return ret @@ -331,8 +355,10 @@ def _kernel_indirect(self, *args, **kwargs): if isinstance(fun, (BoundQuadrantsCallable, QuadrantsCallable)): if fun._is_wrapped_kernel: if fun._is_classkernel and attr_type is not staticmethod: - setattr(cls, name, make_kernel_indirect(fun, is_property)) + setattr(cls, name, make_kernel_indirect(fun, is_property, name)) cls._data_oriented = True + if stable_members: + cls._qd_stable_members = True return cls diff --git a/tests/python/quadrants/lang/ast/test_function_def_transformer.py b/tests/python/quadrants/lang/ast/test_function_def_transformer.py index a46d5e2cbc..20b422b20b 100644 --- a/tests/python/quadrants/lang/ast/test_function_def_transformer.py +++ b/tests/python/quadrants/lang/ast/test_function_def_transformer.py @@ -81,6 +81,9 @@ def test_process_func_arg(argument_name: str, argument_type: Any, expected_varia class MockContext: def __init__(self) -> None: self.variables: dict[str, Any] = {} + # Mirror the real ``ASTTransformerFuncContext.fn_param_names`` so ``_transform_func_arg`` can record bare + # param names without crashing. + self.fn_param_names: set[str] = set() def create_variable(self, name: str, data: Any) -> None: assert name not in self.variables diff --git a/tests/python/quadrants/lang/fast_caching/test_fastcache_field_warnings.py b/tests/python/quadrants/lang/fast_caching/test_fastcache_field_warnings.py index a4bda2a1b1..f6f4b010aa 100644 --- a/tests/python/quadrants/lang/fast_caching/test_fastcache_field_warnings.py +++ b/tests/python/quadrants/lang/fast_caching/test_fastcache_field_warnings.py @@ -160,7 +160,14 @@ def k(x: qd.Template): @test_utils.test(arch=qd.cpu) @pytest.mark.skipif(sys.platform.startswith("win"), reason="Windows stderr not working with capfd") def test_fastcache_field_warnings_warn_struct_template_field(tmp_path, capfd): - """Struct with qd.Template-annotated field containing a Field — warning should fire.""" + """Struct with qd.Template-annotated field containing a Field — warning should fire when the field is + actually read by the kernel. + + Pruning-driven narrowing of args hashing only walks members the kernel reads; an unused dataclass field cannot + affect kernel codegen so it's correctly omitted from the hash (and from the Field-disables-fastcache check). For + the warning path to fire, the kernel must reference the Field — that matches the user-visible contract that + fastcache fails iff a "live" Field argument prevents safe parametrisation. + """ qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) @dataclasses.dataclass(frozen=True) @@ -173,7 +180,7 @@ class S: @qd.pure @qd.kernel def k(x: S): - pass + x.a[0] = 1 capfd.readouterr() k(s) diff --git a/tests/python/quadrants/lang/fast_caching/test_src_hasher.py b/tests/python/quadrants/lang/fast_caching/test_src_hasher.py index e7a2d9952b..9a0fcfd271 100644 --- a/tests/python/quadrants/lang/fast_caching/test_src_hasher.py +++ b/tests/python/quadrants/lang/fast_caching/test_src_hasher.py @@ -23,6 +23,13 @@ @test_utils.test() def test_src_hasher_create_cache_key_vary_config() -> None: + """Source+config key (L1) is stable across re-init with identical config, changes when the config changes. + + Updated from the pre-refactor ``create_cache_key`` API (single-level, args-dependent) to the two-level + ``make_source_config_key`` (L1 — source+config only, no args). The L1 key is the right level to test + because config changes only affect the L1 layer; L2 adds the args-narrow hash on top. + """ + @qd.kernel def f1() -> None: pass @@ -31,15 +38,15 @@ def f1() -> None: # so we are forcing it to false each initialization for now qd_init_same_arch(print_ir_dbg_info=False) kernel_info, _src = get_source_info_and_src(f1.fn) - cache_key_base = src_hasher.create_cache_key(False, kernel_info, [], []) + cache_key_base = src_hasher.make_source_config_key(kernel_info) qd_init_same_arch(print_ir_dbg_info=False) kernel_info, _src = get_source_info_and_src(f1.fn) - cache_key_same = src_hasher.create_cache_key(False, kernel_info, [], []) + cache_key_same = src_hasher.make_source_config_key(kernel_info) qd_init_same_arch(print_ir_dbg_info=False, random_seed=123) kernel_info, _src = get_source_info_and_src(f1.fn) - cache_key_diff = src_hasher.create_cache_key(False, kernel_info, [], []) + cache_key_diff = src_hasher.make_source_config_key(kernel_info) assert cache_key_base == cache_key_same assert cache_key_same != cache_key_diff @@ -103,7 +110,9 @@ def get_fileinfos(functions: list[Callable]) -> list[_wrap_inspect.FunctionSourc mod = temporary_module("child_diff_test_src_hasher_store_validate") kernel_info = get_fileinfos([mod.f1.fn])[0] fileinfos = get_fileinfos([mod.f1.fn, mod.f2.fn]) - fast_cache_key = src_hasher.create_cache_key(False, kernel_info, [], []) + # L2 key: source+config (L1) + narrow-args-hash. Use an empty narrow-args-hash since the test isn't + # exercising args at all — it tests the helper-source-change invalidation logic, which lives in L2. + fast_cache_key = src_hasher.make_full_cache_key(src_hasher.make_source_config_key(kernel_info), narrow_args_hash="") assert fast_cache_key is not None @@ -202,7 +211,9 @@ def src_hasher_vary_kernel_func_child(args: list[str]) -> None: sys.path.append(args_obj.module_file_path) mod = importlib.import_module(args_obj.module_name) info, _src = _wrap_inspect.get_source_info_and_src(mod.f1.fn) - cache_key = src_hasher.create_cache_key(False, info, [], []) + # Source+config key (L1) — varies with the *kernel source* (the property this test exercises) and is + # the same level as the pre-refactor ``create_cache_key`` call site, just without the args-dependent tail. + cache_key = src_hasher.make_source_config_key(info) print(f"CACHE_KEY={cache_key}") print(TEST_RAN) diff --git a/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py b/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py index 711839cf5d..819f9e3701 100644 --- a/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py +++ b/tests/python/quadrants/lang/fast_caching/test_src_ll_cache.py @@ -142,10 +142,16 @@ def k1(foo: qd.Template) -> None: k1(foo=RandomClass()) _out, err = capfd.readouterr() - assert "[FASTCACHE][PARAM_INVALID]" in err + # Unrecognised types at a (top-level) kernel-read path now fail fastcache loudly: a one-shot ``[UNKNOWN_TYPE]`` + # warning identifies the offending type, and ``[INVALID_FUNC]`` then reports the disabled cache. The old silent + # ``[PARAM_INVALID]`` dead-end is gone — the two rules driving this are documented in + # ``args_hasher.py::_fail_unknown_type`` and ``fastcache.md`` "Pruning-driven argument hashing": (1) only pruned + # paths may contribute to the cache key (so no qualname fallback), (2) unrecognised types at pruned paths must + # not be silently dropped. + assert "[FASTCACHE][UNKNOWN_TYPE]" in err assert RandomClass.__name__ in err assert "[FASTCACHE][INVALID_FUNC]" in err - assert k1.__name__ in err + assert "[FASTCACHE][PARAM_INVALID]" not in err @qd.kernel def not_pure_k1(foo: qd.Template) -> None: @@ -153,8 +159,10 @@ def not_pure_k1(foo: qd.Template) -> None: not_pure_k1(foo=RandomClass()) _out, err = capfd.readouterr() + # Without ``@qd.pure``, fastcache is not active at all — neither the new UNKNOWN_TYPE nor the old + # PARAM_INVALID / INVALID_FUNC warnings should fire. + assert "[FASTCACHE][UNKNOWN_TYPE]" not in err assert "[FASTCACHE][PARAM_INVALID]" not in err - assert RandomClass.__name__ not in err assert "[FASTCACHE][INVALID_FUNC]" not in err assert k1.__name__ not in err @@ -433,6 +441,137 @@ def k1(self) -> tuple[qd.i32, qd.i32]: assert my_do.k1._primal.src_ll_cache_observations.cache_validated +@test_utils.test() +def test_src_ll_cache_needs_grad_distinguishes_args_hash(tmp_path: pathlib.Path) -> None: + """Pin: fastcache narrow args_hash MUST fold in ``needs_grad`` for every ndarray leaf. Without this, two scenes + that differ only by whether their ndarrays carry ``.grad`` (e.g. Genesis ``requires_grad=True`` vs ``False``) + collide on the L2 key, and the second scene loads the artifact compiled with the first scene's needs_grad + flag. The kernel's compiled parameter slots have a fixed needs_grad (``insert_ndarray_param`` bakes it into + the struct type), and the launch path branches on ``v.grad is not None`` to pick between ``_QD_ARRAY`` and + ``_QD_ARRAY_WITH_GRAD`` buckets — bind a needs_grad=True ndarray to a slot declared without grad and the + parameter struct's primal pointer ends up at the wrong offset, producing silent wrong results or runtime OOB. + + Reproduces the Genesis pattern (``kernel_init_link_fields`` taking a frozen-dataclass ``LinksState`` whose + members carry ``needs_grad`` from the scene's ``requires_grad``) with the smallest possible surface: a frozen + dataclass with two ``qd.f32`` ndarray members, a kernel that writes only the second one. First process compiles + without grad and stores L1+L2; second process (via ``qd.reset()`` + ``qd.init()``) runs the same kernel with + ``needs_grad=True`` members and asserts the second result is correct *and* that the L2 entry was a miss + (so the per-call needs_grad is correctly part of the cache key). + """ + import dataclasses + + import numpy as np + + arch = getattr(qd, qd.lang.impl.current_cfg().arch.name) + N = 4 + + @dataclasses.dataclass(frozen=True) + class State: + a: qd.types.NDArray[qd.f32, 1] + b: qd.types.NDArray[qd.f32, 1] + + @qd.pure + @qd.kernel + def write_b(s: State) -> None: + for i in range(N): + s.b[i] = qd.cast(i + 1, qd.f32) * 7.0 + + # Cold run: needs_grad=False (default). Populates L1 (pruning info) + L2 (artifact compiled with the slot for + # ``s.b`` declared needs_grad=False) using the narrow args_hash from ``stringify_obj_type`` on the without-grad + # ndarray ``[nd-f32-1]``. + qd.reset() + qd.init(arch=arch, offline_cache_file_path=str(tmp_path), offline_cache=True) + a1 = qd.ndarray(qd.f32, shape=(N,)) + b1 = qd.ndarray(qd.f32, shape=(N,)) + state1 = State(a=a1, b=b1) + write_b(state1) + assert write_b._primal.src_ll_cache_observations.cache_key_generated + assert not write_b._primal.src_ll_cache_observations.cache_loaded + expected = np.array([7, 14, 21, 28], dtype=np.float32) + np.testing.assert_allclose(b1.to_numpy(), expected) + + # Hot run: needs_grad=True. With the bug, ``stringify_obj_type`` yields the same ``[nd-f32-1]`` string for the + # with-grad ndarray, the narrow args_hash collides, and L2 returns the without-grad artifact. The launch path + # then routes ``b2`` through ``_QD_ARRAY_WITH_GRAD`` because ``b2.grad`` is not None, against a slot the + # cached kernel declared as plain ``_QD_ARRAY`` — silent miscomputation or OOB. + # + # After the fix, the args_hash differs (needs_grad folded into the ndarray descriptor), L2 misses, the kernel + # is recompiled with the correct needs_grad=True slot, and the launch is well-typed. + qd.reset() + qd.init(arch=arch, offline_cache_file_path=str(tmp_path), offline_cache=True) + a2 = qd.ndarray(qd.f32, shape=(N,), needs_grad=True) + b2 = qd.ndarray(qd.f32, shape=(N,), needs_grad=True) + state2 = State(a=a2, b=b2) + write_b(state2) + # Diagnostic: the L2 must NOT load the no-grad artifact. After the fix this is a cache miss. + assert not write_b._primal.src_ll_cache_observations.cache_loaded, ( + "fastcache hit between needs_grad=False (cold) and needs_grad=True (hot) — narrow args_hash is " + "missing needs_grad, the without-grad artifact will be launched against with-grad ndarrays" + ) + # Correctness: the kernel writes the expected values, regardless of cache state. + np.testing.assert_allclose(b2.to_numpy(), expected) + # ``b2.grad`` is allocated but not written by this kernel — sanity check it survived as zero (i.e. the + # launch didn't smear primal data into the grad slot via a misaligned param struct). + np.testing.assert_allclose(b2.grad.to_numpy(), np.zeros(N, dtype=np.float32)) + + +@test_utils.test() +def test_src_ll_cache_hit_predeclare_struct_ndarrays_pruned(tmp_path: pathlib.Path) -> None: + """Pin the cache-hit fix for ``_predeclare_struct_ndarrays``: on a fastcache hit pass 0 is skipped so the + ``id(nd)``-keyed used-ndarray set is empty; without flat-name fallback pruning every reachable ndarray gets + registered, scrambling the kernel's arg-slot bindings (e.g. a kernel compiled to write ``state.b`` ends up + writing ``state.a`` at launch). The fix uses the cached ``used_vars_by_func_id[KERNEL_FUNC_ID]`` flat-name + set to gate registration on the cache-hit branch, reproducing the exact ndarray set the originating compile + produced. + + The test exercises both the cold (cache-store) and hot (cache-load) paths in the same process via + ``qd.reset()`` cycles, and asserts both that the ndarray the kernel writes to is the *correct* one and that + the other ndarrays are untouched — without the fix the value would land in ``state.a`` (the first + insertion-order ndarray) instead of ``state.b``. + """ + import numpy as np # local import keeps the test module's top-level deps unchanged + + arch = getattr(qd, qd.lang.impl.current_cfg().arch.name) + N = 4 + + @qd.data_oriented + class State: + def __init__(self) -> None: + self.a = qd.ndarray(qd.i32, shape=(N,)) + self.b = qd.ndarray(qd.i32, shape=(N,)) + self.c = qd.ndarray(qd.i32, shape=(N,)) + + @qd.pure + @qd.kernel + def write_b(s: qd.template()) -> None: + for i in range(N): + s.b[i] = (i + 1) * 17 + + # Cold: cache-miss path populates the fastcache (including the kernel-used flat-name set folded in by + # ``_fold_struct_nd_paths_into_pruning``). + qd.reset() + qd.init(arch=arch, offline_cache_file_path=str(tmp_path), offline_cache=True) + state = State() + write_b(state) + assert write_b._primal.src_ll_cache_observations.cache_key_generated + assert not write_b._primal.src_ll_cache_observations.cache_loaded + np.testing.assert_array_equal(state.b.to_numpy(), np.array([17, 34, 51, 68], dtype=np.int32)) + np.testing.assert_array_equal(state.a.to_numpy(), np.zeros(N, dtype=np.int32)) + np.testing.assert_array_equal(state.c.to_numpy(), np.zeros(N, dtype=np.int32)) + + # Hot: cache-hit path skips pass 0; this is the branch the fix protects. Without flat-name pruning all three + # ndarrays would be registered in insertion order, displacing ``state.b`` from the slot the kernel was + # compiled to write — and the write would land in ``state.a`` instead. + qd.reset() + qd.init(arch=arch, offline_cache_file_path=str(tmp_path), offline_cache=True) + state = State() + write_b(state) + assert write_b._primal.src_ll_cache_observations.cache_loaded, "expected a fastcache hit on the second run" + np.testing.assert_array_equal(state.b.to_numpy(), np.array([17, 34, 51, 68], dtype=np.int32)) + np.testing.assert_array_equal(state.a.to_numpy(), np.zeros(N, dtype=np.int32)) + np.testing.assert_array_equal(state.c.to_numpy(), np.zeros(N, dtype=np.int32)) + + class ModifySubFuncKernelArgs(pydantic.BaseModel): arch: str offline_cache_file_path: str diff --git a/tests/python/test_data_oriented_ndarray.py b/tests/python/test_data_oriented_ndarray.py index fb8137cfc9..028aee30b2 100644 --- a/tests/python/test_data_oriented_ndarray.py +++ b/tests/python/test_data_oriented_ndarray.py @@ -1,12 +1,13 @@ -"""Tests for ``@qd.data_oriented`` classes whose members are raw ``qd.ndarray`` (not ``qd.field``, not ``qd.Tensor`` -wrappers). +"""Tests for ``@qd.data_oriented`` classes whose members are raw ``qd.ndarray`` (not ``qd.field``, not +``qd.Tensor`` wrappers). The user-guide doc ``docs/source/user_guide/compound_types.md`` claims this pattern is not supported ("can contain ndarray? no" for ``@qd.data_oriented``). But the in-tree error message in ``python/quadrants/lang/impl.py`` lists ``@qd.data_oriented / frozen-dataclass template`` as a *supported* route, and the ndarray-in-struct infrastructure added by ``#561 [Type] Tensor 24`` (2026-04-28) — specifically ``_predeclare_struct_ndarrays`` in ``python/quadrants/lang/ast/ast_transformers/function_def_transformer.py`` — explicitly walks both -``dataclasses.is_dataclass(val)`` and ``hasattr(val, "__dict__")`` containers, the latter being the data_oriented case. +``dataclasses.is_dataclass(val)`` and ``hasattr(val, "__dict__")`` containers, the latter being the data_oriented +case. This file pins what actually works, and documents the gaps. See ``perso_hugh/doc/data_oriented_ndarray.md`` for the design analysis. @@ -168,8 +169,8 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- # 6. Mutation: same instance, reassign ndarray attribute to a *same-shape* ndarray between calls. The launch-time -# stale-cache guard (``_mutable_nd_cached_val`` in kernel.py) is supposed to fold the live ndarray id into args_hash -# so the launch context is not served stale. We pin that behaviour here for the data_oriented case. +# stale-cache guard (``_mutable_nd_cached_val`` in kernel.py) is supposed to fold the live ndarray id into +# args_hash so the launch context is not served stale. We pin that behaviour here for the data_oriented case. # --------------------------------------------------------------------------- @@ -202,13 +203,13 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- # 7. Mutation cross-shape: reassign ndarray attribute to a *different-dtype* ndarray. The template-mapper -# specialisation key (in ``_template_mapper_hotpath._extract_arg``) returns ``weakref.ref(arg)`` for -# ``is_data_oriented(arg)``; it does NOT descend into ndarray children to compute a dtype/ndim-dependent spec key. -# So if the data_oriented instance's id is unchanged but its ndarray attribute is reassigned to a different dtype, -# we expect either: -# - a graceful recompile/raise, or -# - silent miscompilation (the bug case — current expected outcome per static analysis). -# Mark xfail with strict=False so we record the actual outcome without breaking CI. +# specialisation key (in ``_template_mapper_hotpath._extract_arg``) returns ``weakref.ref(arg)`` for +# ``is_data_oriented(arg)``; it does NOT descend into ndarray children to compute a dtype/ndim-dependent spec key. +# So if the data_oriented instance's id is unchanged but its ndarray attribute is reassigned to a different dtype, +# we expect either: +# - a graceful recompile/raise, or +# - silent miscompilation (the bug case — current expected outcome per static analysis). +# Mark xfail with strict=False so we record the actual outcome without breaking CI. # --------------------------------------------------------------------------- @@ -240,8 +241,8 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- # 8. Distinct instances of same class -> spec-key behaviour. Documents that today each fresh instance triggers a -# recompile (because the spec key is ``weakref.ref(arg)`` identity). This is a perf concern, not a correctness one. -# We assert correctness here; the recompile count is documented as a perf note. +# recompile (because the spec key is ``weakref.ref(arg)`` identity). This is a perf concern, not a correctness +# one. We assert correctness here; the recompile count is documented as a perf note. # --------------------------------------------------------------------------- @@ -271,18 +272,18 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 9. Fastcache cold then warm. Per the fastcache doc (``user_guide/fastcache.md`` line 129), ``@qd.data_oriented`` -# objects are supported in the cache key. We don't assert cross-process here (that requires a fresh interpreter); we -# assert that ``cache_stored`` becomes True on the first call and ``cache_key_generated`` is True (i.e. no -# PARAM_INVALID fallthrough due to the ndarray member). +# 9. Fastcache cold then warm. Per the fastcache doc (``user_guide/fastcache.md`` line 129), +# ``@qd.data_oriented`` objects are supported in the cache key. We don't assert cross-process here (that requires +# a fresh interpreter); we assert that ``cache_stored`` becomes True on the first call and +# ``cache_key_generated`` is True (i.e. no PARAM_INVALID fallthrough due to the ndarray member). # --------------------------------------------------------------------------- # --------------------------------------------------------------------------- # 9b. Fastcache end-to-end with ``@qd.data_oriented`` holding ndarrays. Pattern adapted from -# ``test_cache.test_fastcache``: call ``qd_init_same_arch`` twice with the same cache directory to simulate two -# processes, monkeypatch ``launch_kernel`` to capture whether ``compiled_kernel_data`` was loaded from disk. On the -# second init the data_oriented + ndarray kernel should be served from the on-disk fastcache. +# ``test_cache.test_fastcache``: call ``qd_init_same_arch`` twice with the same cache directory to simulate two +# processes, monkeypatch ``launch_kernel`` to capture whether ``compiled_kernel_data`` was loaded from disk. On +# the second init the data_oriented + ndarray kernel should be served from the on-disk fastcache. # --------------------------------------------------------------------------- @@ -328,7 +329,7 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- # 9c. Same as 9b but with a *nested* ``@qd.data_oriented`` holding an ndarray. Pins that the fastcache args_hasher -# recursion handles nested data_oriented containers correctly across processes. +# recursion handles nested data_oriented containers correctly across processes. # --------------------------------------------------------------------------- @@ -376,7 +377,7 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- # 9d. Fastcache key is dtype-sensitive: same kernel source, different ndarray dtype in the data_oriented member -> -# two distinct disk cache entries. Pins the args_hasher's ``[nd-{dtype}-{ndim}{layout}]`` repr. +# two distinct disk cache entries. Pins the args_hasher's ``[nd-{dtype}-{ndim}{layout}]`` repr. # --------------------------------------------------------------------------- @@ -425,9 +426,9 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- # 9e. Documented fallback: a @qd.data_oriented containing a qd.field disables fastcache for the whole call -# (args_hasher returns None for ScalarField). The kernel still runs correctly via non-fastcache compilation. This -# test pins the documented fallback so a future "support fields in fastcache" change explicitly chooses to update -# this test. +# (args_hasher returns None for ScalarField). The kernel still runs correctly via non-fastcache compilation. This +# test pins the documented fallback so a future "support fields in fastcache" change explicitly chooses to update +# this test. # --------------------------------------------------------------------------- @@ -480,7 +481,7 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- # 10. Pure validation: a @qd.pure @qd.kernel taking a data_oriented arg with an ndarray member should compile and -# run, mirroring the existing ``test_pure_validation_data_oriented_as_param`` test which only covers ``qd.field``. +# run, mirroring the existing ``test_pure_validation_data_oriented_as_param`` test which only covers ``qd.field``. # --------------------------------------------------------------------------- @@ -507,8 +508,8 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 11. Counter-test: confirm a dataclass-of-NDArray works (sanity check that the existing supported route still works; -# if this fails, the test environment itself is broken, not the data_oriented path). +# 11. Counter-test: confirm a dataclass-of-NDArray works (sanity check that the existing supported route still +# works; if this fails, the test environment itself is broken, not the data_oriented path). # --------------------------------------------------------------------------- @@ -534,7 +535,7 @@ def run(s: State): # --------------------------------------------------------------------------- # 12. data_oriented holding a (frozen) dataclass that holds an ndarray. Exercises the ``else`` branch of ``_walk_obj`` -# recursing through a dataclass child — added by the Bug 1 fix. +# recursing through a dataclass child — added by the Bug 1 fix. # --------------------------------------------------------------------------- @@ -565,11 +566,11 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- # 13. Frozen dataclass holding a data_oriented holding an ndarray, kernel-arg via ``qd.template()``. Exercises the -# dataclass branch of ``_walk_obj`` recursing through a data_oriented child — added by the Bug 1 fix. The outer -# dataclass must be frozen because (i) non-frozen dataclasses are unhashable in Python (``__hash__ is None``) and the -# template-mapper key tuple needs the value to be hashable, and (ii) the typed-dataclass-arg form -# (``def run(s: Outer):``) goes through ``_transform_kernel_arg`` which does not currently recurse on data_oriented -# field *types* (as opposed to values) — that's a separate follow-up. +# dataclass branch of ``_walk_obj`` recursing through a data_oriented child — added by the Bug 1 fix. The outer +# dataclass must be frozen because (i) non-frozen dataclasses are unhashable in Python (``__hash__ is None``) and +# the template-mapper key tuple needs the value to be hashable, and (ii) the typed-dataclass-arg form (``def +# run(s: Outer):``) goes through ``_transform_kernel_arg`` which does not currently recurse on data_oriented +# field *types* (as opposed to values) — that's a separate follow-up. # --------------------------------------------------------------------------- @@ -636,7 +637,7 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- # 15. Mutation on a nested ndarray: outer.inner.x reassigned between kernel calls. Verifies the Bug 2 stale-cache -# guard fires even when the ndarray lives several attribute hops deep. +# guard fires even when the ndarray lives several attribute hops deep. # --------------------------------------------------------------------------- @@ -718,8 +719,8 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 16. Same data_oriented instance, two kernels sharing it. Verifies the launch-info per-kernel bookkeeping is independent -# (each kernel's compile sets up its own pre-declared ndarray args). +# 16. Same data_oriented instance, two kernels sharing it. Verifies the launch-info per-kernel bookkeeping is +# independent (each kernel's compile sets up its own pre-declared ndarray args). # --------------------------------------------------------------------------- @@ -755,8 +756,8 @@ def fill_y_from_x(s: qd.template()): # --------------------------------------------------------------------------- # 17. data_oriented + ndarray + @qd.func sub-call. Pins that the AST-time attribute resolution in ``build_Attribute`` -# (which uses the predeclared AnyArray cache) works when the access happens inside a func, not just the top-level -# kernel. +# (which uses the predeclared AnyArray cache) works when the access happens inside a func, not just the top-level +# kernel. # --------------------------------------------------------------------------- @@ -786,9 +787,9 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 18. Reassign ndarray to a *different ndim* on the same data_oriented instance. Complementary to test 7 -# (different-dtype). Spec key must change so a 2D-specialised kernel is not reused for a 1D ndarray. Pins the Gap A -# fix from the dtype side. +# 18. Reassign ndarray to a *different ndim* on the same data_oriented instance. Complementary to test 7 (different- +# dtype). Spec key must change so a 2D-specialised kernel is not reused for a 1D ndarray. Pins the Gap A fix from +# the dtype side. # --------------------------------------------------------------------------- @@ -823,7 +824,7 @@ def fill_2d(s: qd.template()): # --------------------------------------------------------------------------- # 19. Spec-key descent for nested data_oriented + ndarray reassign at the leaf. Confirms the recursive walker in -# ``_collect_struct_nd_descriptors`` reaches through nested data_oriented. +# ``_collect_struct_nd_descriptors`` reaches through nested data_oriented. # --------------------------------------------------------------------------- @@ -862,16 +863,16 @@ def run_f32(s: qd.template()): # --------------------------------------------------------------------------- -# 20. No spec-key regression for data_oriented containers WITHOUT ndarrays. The Gap A fix prepends ndarray -# descriptors only when ndarrays are present; otherwise the original ``weakref.ref(arg)`` spec key is preserved (one -# spec per instance). This test pins the no-ndarray case. +# 20. No spec-key regression for data_oriented containers WITHOUT ndarrays. The Gap A fix prepends ndarray descriptors +# only when ndarrays are present; otherwise the original ``weakref.ref(arg)`` spec key is preserved (one spec per +# instance). This test pins the no-ndarray case. # --------------------------------------------------------------------------- # --------------------------------------------------------------------------- -# 21. Typed-dataclass kernel arg with a ``@qd.data_oriented`` field type — should error clearly pointing the user -# to ``qd.template()``. The two patterns are incompatible at the kernel-arg layer: dataclass kernel args are -# flattened using annotations, data_oriented containers need a value-driven walk. Pins the helpful error message. +# 21. Typed-dataclass kernel arg with a ``@qd.data_oriented`` field type — should error clearly pointing the user to +# ``qd.template()``. The two patterns are incompatible at the kernel-arg layer: dataclass kernel args are +# flattened using annotations, data_oriented containers need a value-driven walk. Pins the helpful error message. # --------------------------------------------------------------------------- @@ -923,70 +924,542 @@ def run(s: qd.template()): # --------------------------------------------------------------------------- -# 22. @qd.data_oriented holding a qd.Tensor wrapper around an ndarray. -# -# Both ``_build_struct_nd_paths`` and ``_collect_struct_nd_descriptors`` in -# ``_template_mapper_hotpath.py`` have a ``if type(v) in _TENSOR_WRAPPER_TYPES: v = v._unwrap()`` branch that the rest -# of the file doesn't exercise (every other test attaches a bare ``qd.ndarray``). This test covers that unwrap path -# for the ndarray-backed wrapper: the struct-walker should treat ``state.a`` as if it were a bare ndarray (paths -# cached on the class, shape descriptors collected from the unwrapped impl). +# 22. Robustness: object graphs with Pydantic-style metaclass ``__getattr__`` recursion, and cyclic attribute +# references. Real-world container classes (notably Genesis's ``RigidOptions`` / ``SimOptions``) inherit from +# ``pydantic.BaseModel`` whose ``ModelMetaclass.__getattr__`` recurses infinitely on missing class attributes. +# Quadrants' walker must not blow the stack when it traverses a ``data_oriented`` arg that contains such an +# object, or that contains a back-reference to itself / its parent (e.g. ``solver.scene.solver``). # --------------------------------------------------------------------------- +def test_is_data_oriented_safe_on_pydantic_like_metaclass(): + """``is_data_oriented`` must not invoke ``__getattr__`` on the class (or metaclass), so it stays safe in the + presence of pathological metaclasses whose ``__getattr__`` blows the Python recursion limit on arbitrary + attribute lookups (e.g. Pydantic's ``ModelMetaclass`` when probed for a name not in its private-attrs cache). + """ + + from quadrants.lang.util import is_data_oriented + + class RecursingMeta(type): + def __getattr__(cls, item): + return cls.__getattr__(item) + + class Pathological(metaclass=RecursingMeta): + pass + + # Pre-fix this raised RecursionError; with the MRO+__dict__ lookup it just returns False. + assert is_data_oriented(Pathological()) is False + + @test_utils.test(arch=qd.cpu) -def test_data_oriented_ndarray_wrapper(): - N = 6 +def test_data_oriented_with_pydantic_like_child(): + """A ``@qd.data_oriented`` class holding a child whose metaclass has the recursing ``__getattr__`` + (Pydantic-style). Walker must classify the child as non-data-oriented and continue without blowing the stack. + """ + N = 4 + + class RecursingMeta(type): + def __getattr__(cls, item): + return cls.__getattr__(item) + + class Options(metaclass=RecursingMeta): + pass @qd.data_oriented class State: - def __init__(self, a): - self.a = a + def __init__(self, x, opts): + self.x = x + self.opts = opts - a = qd.tensor(qd.i32, shape=(N,), backend=qd.Backend.NDARRAY) - state = State(a=a) + x = qd.ndarray(qd.i32, shape=(N,)) + state = State(x=x, opts=Options()) @qd.kernel - def run(s: qd.Template): + def run(s: qd.template()): for i in range(N): - s.a[i] = i + 1 + s.x[i] = i + 1 run(state) - np.testing.assert_array_equal(a.to_numpy(), np.arange(1, N + 1)) + np.testing.assert_array_equal(x.to_numpy(), np.arange(1, N + 1)) - run(state) +@test_utils.test(arch=qd.cpu) +def test_data_oriented_polymorphic_attr_across_instances(): + """Some real-world ``@qd.data_oriented`` containers (Genesis FEMSolver / MPMSolver / SPHSolver, etc.) hold + polymorphic children whose types differ between instances — e.g. ``self.material.x`` is an ``Ndarray`` on + instance A and a ``qd.field`` (``MatrixField``) on instance B. The per-instance path cache walks each instance + fresh, but ``_collect_struct_nd_descriptors`` must additionally tolerate a path's leaf no longer being an + ``Ndarray`` *within a single instance's lifetime* (e.g. ``qd.Tensor`` impl swap), and silently skip the stale + entry rather than crash on ``v.element_type``.""" + N = 4 -# --------------------------------------------------------------------------- -# 24. Cycle-detection regression: a ``@qd.data_oriented`` container with an attribute-graph back-edge -# (``sim.solver.sim is sim``) must not blow the Python stack when walked by either the launch-time hotpath -# ``_build_struct_nd_paths`` or the compile-time ``function_def_transformer._walk_obj``. Pre-fix this recursed -# indefinitely; both walkers now carry a ``seen`` set keyed by ``id(obj)``. -# --------------------------------------------------------------------------- + @qd.data_oriented + class State: + def __init__(self, x): + self.x = x + + # First instance: ``self.x`` is an Ndarray. The walker emits path ``('x',)`` and caches it. + x_nd = qd.ndarray(qd.i32, shape=(N,)) + state_a = State(x=x_nd) + + @qd.kernel + def run(s: qd.template()): + for i in range(N): + s.x[i] = i + 1 + + run(state_a) + np.testing.assert_array_equal(x_nd.to_numpy(), np.arange(1, N + 1)) + + # Second instance of the SAME class, ``self.x`` is now a ``qd.field`` (MatrixField via Vector.field). + # The cached path ``('x',)`` from instance A points to a non-Ndarray on this instance — the descriptor + # walk must skip it cleanly rather than crash on ``v.element_type``. + f = qd.Vector.field(2, qd.i32, shape=(N,)) + state_b = State(x=f) + + @qd.kernel + def run_field(s: qd.template()): + for i in range(N): + s.x[i] = [i, i + 1] + + run_field(state_b) @test_utils.test(arch=qd.cpu) -def test_data_oriented_attribute_cycle_does_not_recurse_infinitely(): +def test_data_oriented_polymorphic_attribute_set_across_instances(): + """Models the Genesis ``DataManager`` failure mode: a ``@qd.data_oriented`` class whose ``__init__`` conditionally + allocates attributes based on a construction flag. Different instances of the same class then have different + attribute *sets* (not just different value types at the same paths). + + With a per-class path cache populated from the first instance walked, this would either AttributeError when the + second instance lacks an attribute the first had (forward direction) or silently miss an ndarray the second + instance has but the first didn't (inverse direction). Per-instance caching walks each instance fresh so both + directions work.""" N = 4 @qd.data_oriented - class Solver: - def __init__(self, sim): - self.sim = sim - self.a = qd.ndarray(qd.i32, shape=(N,)) + class PolyState: + def __init__(self, with_extra: bool): + self.x = qd.ndarray(qd.i32, shape=(N,)) + if with_extra: + self.extra = qd.ndarray(qd.i32, shape=(N,)) + @qd.kernel + def run(s: qd.template()): + for i in range(N): + s.x[i] = i + 1 + + # Forward direction: first instance has 'extra', second doesn't. Used to AttributeError on the cached + # ('extra',) path when running with state_lean. + state_full = PolyState(with_extra=True) + run(state_full) + state_lean = PolyState(with_extra=False) + run(state_lean) + np.testing.assert_array_equal(state_lean.x.to_numpy(), np.arange(1, N + 1)) + + # Inverse direction: a different class so per-class cache (if used by __slots__ fallback) starts fresh; first + # instance lacks 'extra', second has it. The kernel actually *reads* ``s.extra`` so the inverse-direction + # silent miscache (which only manifests when the kernel touches the conditional attr) is exercised end-to-end. @qd.data_oriented - class Sim: + class PolyState2: + def __init__(self, with_extra: bool): + self.x = qd.ndarray(qd.i32, shape=(N,)) + if with_extra: + self.extra = qd.ndarray(qd.i32, shape=(N,)) + + @qd.kernel + def run_using_extra(s: qd.template()): + for i in range(N): + s.x[i] = s.extra[i] * 10 + + # Walk the lean instance first (no 'extra'), populating any per-class state with the *narrow* attribute set. + # With the old per-class cache, this would lock in paths = [('x',)] for the class — and the next instance's + # ``extra`` would be silently absent from args_hash and from the kernel spec, leading to a wrong-shape kernel + # or a stale-cache hit when ``extra`` is later reassigned. + state_lean2 = PolyState2(with_extra=False) + run(state_lean2) + np.testing.assert_array_equal(state_lean2.x.to_numpy(), np.arange(1, N + 1)) + + # Now the polymorphic-attr-bearing instance. The per-instance walk must include ``('extra',)`` so that + # ``state_full2.extra``'s shape/id participates in the spec and the kernel compiles correctly. + state_full2 = PolyState2(with_extra=True) + state_full2.extra.from_numpy(np.array([2, 3, 5, 7], dtype=np.int32)) + run_using_extra(state_full2) + np.testing.assert_array_equal(state_full2.x.to_numpy(), np.array([20, 30, 50, 70], dtype=np.int32)) + + # Reassignment-detection check: swap ``state_full2.extra`` to a different ndarray. The per-instance walk caches + # the *path list* ([('x',), ('extra',)]) on the instance, but the per-call args_hash still folds in + # ``id(getattr(state_full2, 'extra'))`` — so a swap should miss the spec-key cache and re-specialise. + state_full2.extra = qd.ndarray(qd.i32, shape=(N,)) + state_full2.extra.from_numpy(np.array([11, 13, 17, 19], dtype=np.int32)) + run_using_extra(state_full2) + np.testing.assert_array_equal(state_full2.x.to_numpy(), np.array([110, 130, 170, 190], dtype=np.int32)) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_with_cyclic_attr_graph(): + """A ``@qd.data_oriented`` class whose attribute graph contains a cycle (``parent.child.parent is parent``). + Walker must not re-enter the cycle.""" + N = 4 + + @qd.data_oriented + class Child: def __init__(self): - self.solver = Solver(self) + self.parent = None + + @qd.data_oriented + class Parent: + def __init__(self, x): + self.x = x + self.child = Child() + self.child.parent = self # cycle + + x = qd.ndarray(qd.i32, shape=(N,)) + p = Parent(x=x) + + @qd.kernel + def run(s: qd.template()): + for i in range(N): + s.x[i] = i + 10 + + run(p) + np.testing.assert_array_equal(x.to_numpy(), np.arange(10, 10 + N)) + + +# --------------------------------------------------------------------------- +# Pruning-driven fastcache behaviour for @qd.data_oriented containers. +# +# These pin the three rules enforced by the args hasher (see fastcache.md "Pruning-driven argument hashing"): +# 1. The cache key may only include contributions from kernel-pruned paths. +# 2. Unrecognised types at kernel-read paths must not be silently dropped. +# 3. Fastcache works for @qd.data_oriented kernel args end-to-end. +# --------------------------------------------------------------------------- + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_kernel_unused_opaque_member_does_not_affect_cache(tmp_path, monkeypatch): + """Rule 1: kernel-unused opaque members do not affect the fastcache key. + + Two ``State`` instances differ only in an opaque ``uuid`` member that the kernel never reads. Both must hit the + same compiled artifact on the second process — proof that the args hasher's pruning narrow walk skips the opaque + attribute (no qualname-fallback, no spurious miss).""" + import uuid + + from quadrants._test_tools import qd_init_same_arch + + launch_kernel_orig = qd.lang.kernel_impl.Kernel.launch_kernel + captured = [] + + def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=None): + if self.func.__name__ == "run": + captured.append(compiled_kernel_data) + return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) + + monkeypatch.setattr("quadrants.lang.kernel_impl.Kernel.launch_kernel", launch_kernel) + + @qd.data_oriented + class State: + def __init__(self, x): + self.x = x + self.uuid = uuid.uuid4() # opaque member, kernel does not read it + + @qd.kernel(fastcache=True) + def run(s: qd.template()): + for i in range(4): + s.x[i] = s.x[i] + 1 + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(4,))) + b = State(x=qd.ndarray(qd.i32, shape=(4,))) + run(a) + run(b) + + # Second process: cold-start, must load from disk. If the uuid had leaked into the cache key, different uuid → + # different L2 key → no artifact would load. + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(4,))) + b = State(x=qd.ndarray(qd.i32, shape=(4,))) + run(a) + run(b) + assert captured[-2] is not None, "first instance should load from disk" + assert captured[-1] is not None, "second instance (different uuid) should ALSO load from disk" + assert run._primal.src_ll_cache_observations.cache_loaded + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_kernel_read_opaque_member_fails_fastcache(tmp_path, capfd) -> None: + """Rule 2: when the kernel actually reads an unrecognised-type member, fastcache fails loudly with [UNKNOWN_TYPE] + + [INVALID_FUNC] — no silent drop, no qualname fallback. The kernel still runs via normal compilation.""" + from quadrants._test_tools import qd_init_same_arch + from quadrants.lang._fast_caching.args_hasher import reset_unknown_type_warn_state + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + reset_unknown_type_warn_state() + + class CustomConfig: + def __init__(self, scale: int) -> None: + self.scale = scale + + @qd.data_oriented + class State: + def __init__(self, x, cfg): + self.x = x + self.cfg = cfg + + x = qd.ndarray(qd.i32, shape=(4,)) + state = State(x=x, cfg=CustomConfig(scale=3)) + + @qd.kernel(fastcache=True) + def run(s: qd.template()): + scale = s.cfg.scale # makes ``__qd_s__qd_cfg`` and ``__qd_s__qd_cfg__qd_scale`` live + for i in range(4): + s.x[i] = i * scale + + run(state) + _out, err = capfd.readouterr() + np.testing.assert_array_equal(x.to_numpy(), np.arange(4) * 3) + + obs = run._primal.src_ll_cache_observations + assert obs.cache_key_generated is False, "unrecognised type at kernel-read path must disable fastcache" + assert "[FASTCACHE][UNKNOWN_TYPE]" in err + assert CustomConfig.__name__ in err + assert "[FASTCACHE][INVALID_FUNC]" in err + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_kernel_read_primitive_distinguishes_cache_key(tmp_path, monkeypatch) -> None: + """Rule 3 (data_oriented works) + pruning correctness: when the kernel reads a primitive member, its value is + baked into the kernel and must drive a distinct cache entry per value. Two State instances differing only in + ``n`` (read by the kernel) cold-compile separately and both load from disk on the second process.""" + from quadrants._test_tools import qd_init_same_arch + + launch_kernel_orig = qd.lang.kernel_impl.Kernel.launch_kernel + captured = [] + + def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=None): + if self.func.__name__ == "run": + captured.append(compiled_kernel_data) + return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) + + monkeypatch.setattr("quadrants.lang.kernel_impl.Kernel.launch_kernel", launch_kernel) + + @qd.data_oriented + class State: + def __init__(self, x, n): + self.x = x + self.n = n # primitive, baked into kernel via ``for i in range(s.n)`` + + @qd.kernel(fastcache=True) + def run(s: qd.template()): + for i in range(s.n): + s.x[i] = i + s.n + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(4,)), n=2) + b = State(x=qd.ndarray(qd.i32, shape=(4,)), n=3) + run(a) + run(b) + assert captured[-2] is None and captured[-1] is None, "different ``n`` → both cold-compile" + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(4,)), n=2) + b = State(x=qd.ndarray(qd.i32, shape=(4,)), n=3) + run(a) + run(b) + assert captured[-2] is not None and captured[-1] is not None, "both ``n`` values should load distinct artifacts" + np.testing.assert_array_equal(a.x.to_numpy()[:2], np.array([2, 3], dtype=np.int32)) + np.testing.assert_array_equal(b.x.to_numpy()[:3], np.array([3, 4, 5], dtype=np.int32)) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_kernel_unread_primitive_does_not_affect_cache(tmp_path, monkeypatch) -> None: + """Rule 1: kernel-unused primitive members do not affect the cache key. Mirror of the opaque case for + primitives. Two State instances differing only in ``unused_n`` must share the cache.""" + from quadrants._test_tools import qd_init_same_arch + + launch_kernel_orig = qd.lang.kernel_impl.Kernel.launch_kernel + captured = [] + + def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=None): + if self.func.__name__ == "run": + captured.append(compiled_kernel_data) + return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) + + monkeypatch.setattr("quadrants.lang.kernel_impl.Kernel.launch_kernel", launch_kernel) + + @qd.data_oriented + class State: + def __init__(self, x, unused_n): + self.x = x + self.unused_n = unused_n # kernel never reads this + + @qd.kernel(fastcache=True) + def run(s: qd.template()): + for i in range(4): + s.x[i] = s.x[i] + 1 + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(4,)), unused_n=2) + b = State(x=qd.ndarray(qd.i32, shape=(4,)), unused_n=99) + run(a) + run(b) + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(4,)), unused_n=2) + b = State(x=qd.ndarray(qd.i32, shape=(4,)), unused_n=99) + run(a) + run(b) + assert captured[-2] is not None, "first instance should load from disk" + assert captured[-1] is not None, "second instance (different unused_n) should ALSO load from disk" + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_qd_func_chain_propagation_distinguishes_cache_key(tmp_path, monkeypatch) -> None: + """Pruning chain propagation through ``@qd.func`` calls (``record_after_call`` extension): when the kernel calls + ``f(self.dofs)`` and ``f`` reads ``s.x``, the kernel's pruning set must include ``__qd_self__qd_dofs__qd_x`` so + that changes to the inner ndarray's dtype invalidate the cache. Two States differing in ``dofs.x``'s dtype must + cold-compile separately.""" + from quadrants._test_tools import qd_init_same_arch + + launch_kernel_orig = qd.lang.kernel_impl.Kernel.launch_kernel + captured = [] + + def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=None): + if self.func.__name__ == "run": + captured.append(compiled_kernel_data) + return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) + + monkeypatch.setattr("quadrants.lang.kernel_impl.Kernel.launch_kernel", launch_kernel) + + @qd.data_oriented + class Dofs: + def __init__(self, x): + self.x = x + + @qd.data_oriented + class State: + def __init__(self, dofs): + self.dofs = dofs + + @qd.func + def write_dofs(d: qd.template(), v: qd.i32): + d.x[0] = v + + @qd.kernel(fastcache=True) + def run(s: qd.template()): + write_dofs(s.dofs, 7) + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(dofs=Dofs(x=qd.ndarray(qd.i32, shape=(4,)))) + b = State(dofs=Dofs(x=qd.ndarray(qd.f32, shape=(4,)))) + run(a) + run(b) + assert captured[-2] is None and captured[-1] is None, "differing dofs.x dtype → both cold-compile" + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(dofs=Dofs(x=qd.ndarray(qd.i32, shape=(4,)))) + b = State(dofs=Dofs(x=qd.ndarray(qd.f32, shape=(4,)))) + run(a) + run(b) + assert captured[-2] is not None and captured[-1] is not None, "both dtypes load distinct artifacts" + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_nested_primitive_via_qd_func_distinguishes_cache_key(tmp_path, monkeypatch) -> None: + """Pruning chain propagation through ``f(self.child)`` for *primitive* members of nested data_oriented containers. + + Regression test for a bug where ``record_after_call`` skipped chain-path propagation whenever the caller-side arg + flattened to a ``__qd_*``-prefixed name (which Attribute chains always do — ``self.cfg`` → + ``__qd_self__qd_cfg``). When that happened, primitive members read inside the callee (``cfg.n`` → + ``__qd_cfg__qd_n`` in the callee's chain set) never made it into the kernel's pruning set, so the args-hasher + walked ``self.cfg`` as data_oriented and found no pruned children, yielding an identical hash for *any* value of + ``cfg.n``. Two configs that should produce different kernels (different ``range(s.cfg.n)`` trip counts baked into + codegen) would then share a fastcache entry — leading to stale-kernel hits and silent miscompiles (e.g. Genesis' + ``test_ndarray_no_compile`` was failing with iter-N kernels reused for iter-N+1 scenes that have a different + ``RigidSimStaticConfig.para_level`` baked into their ``qd.static`` branches). + + The fix in ``_pruning.py`` gates propagation on the *root Name* of the chain (``self``, not the flat result), so + both ``f(self)`` and ``f(self.cfg)`` propagate, while already-flattened dataclass refs + (``Name('__qd_state__qd_x')``) are still skipped.""" + from quadrants._test_tools import qd_init_same_arch + + launch_kernel_orig = qd.lang.kernel_impl.Kernel.launch_kernel + captured = [] + + def launch_kernel(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=None): + if self.func.__name__ == "run": + captured.append(compiled_kernel_data) + return launch_kernel_orig(self, key, t_kernel, compiled_kernel_data, *args, qd_stream=qd_stream) + + monkeypatch.setattr("quadrants.lang.kernel_impl.Kernel.launch_kernel", launch_kernel) + + @qd.data_oriented + class Cfg: + def __init__(self, n): + self.n = n # primitive read by ``write_x`` — drives codegen via ``range(c.n)`` + + @qd.data_oriented + class State: + def __init__(self, x, cfg): + self.x = x + self.cfg = cfg + + @qd.func + def write_x(x: qd.template(), c: qd.template()): + for i in range(c.n): + x[i] = i + c.n + + @qd.kernel(fastcache=True) + def run(s: qd.template()): + write_x(s.x, s.cfg) + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(8,)), cfg=Cfg(n=2)) + b = State(x=qd.ndarray(qd.i32, shape=(8,)), cfg=Cfg(n=3)) + run(a) + run(b) + assert captured[-2] is None and captured[-1] is None, "different cfg.n → both cold-compile" + np.testing.assert_array_equal(a.x.to_numpy()[:2], np.array([2, 3], dtype=np.int32)) + np.testing.assert_array_equal(b.x.to_numpy()[:3], np.array([3, 4, 5], dtype=np.int32)) + + qd_init_same_arch(offline_cache_file_path=str(tmp_path), offline_cache=True) + a = State(x=qd.ndarray(qd.i32, shape=(8,)), cfg=Cfg(n=2)) + b = State(x=qd.ndarray(qd.i32, shape=(8,)), cfg=Cfg(n=3)) + run(a) + run(b) + assert captured[-2] is not None and captured[-1] is not None, "both cfg.n values load distinct artifacts" + np.testing.assert_array_equal(a.x.to_numpy()[:2], np.array([2, 3], dtype=np.int32)) + np.testing.assert_array_equal(b.x.to_numpy()[:3], np.array([3, 4, 5], dtype=np.int32)) + + +# --------------------------------------------------------------------------- +# 23. @qd.data_oriented holding a qd.Tensor wrapper around an ndarray. +# +# Both ``_build_struct_nd_paths`` and ``_collect_struct_nd_descriptors`` in ``_template_mapper_hotpath.py`` have a +# ``if type(v) in _TENSOR_WRAPPER_TYPES: v = v._unwrap()`` branch that the rest of the file doesn't exercise (every +# other test attaches a bare ``qd.ndarray``). This test covers that unwrap path for the ndarray-backed wrapper: the +# struct-walker should treat ``state.a`` as if it were a bare ndarray (paths cached on the class, shape descriptors +# collected from the unwrapped impl). +# --------------------------------------------------------------------------- - sim = Sim() - assert sim.solver.sim is sim + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_ndarray_wrapper(): + N = 6 + + @qd.data_oriented + class State: + def __init__(self, a): + self.a = a + + a = qd.tensor(qd.i32, shape=(N,), backend=qd.Backend.NDARRAY) + state = State(a=a) @qd.kernel def run(s: qd.Template): for i in range(N): - s.solver.a[i] = i + 7 + s.a[i] = i + 1 - run(sim) - np.testing.assert_array_equal(sim.solver.a.to_numpy(), np.arange(N) + 7) - run(sim) + run(state) + np.testing.assert_array_equal(a.to_numpy(), np.arange(1, N + 1)) + + run(state) diff --git a/tests/python/test_data_oriented_qd_func_dataclass.py b/tests/python/test_data_oriented_qd_func_dataclass.py new file mode 100644 index 0000000000..b504aae31c --- /dev/null +++ b/tests/python/test_data_oriented_qd_func_dataclass.py @@ -0,0 +1,320 @@ +"""Tests for calling @qd.func that takes a typed-dataclass arg, from a @qd.kernel method of a @qd.data_oriented +class, passing ``self.dataclass_member`` as the arg. + +Genesis's @qd.func helpers declare typed-dataclass parameters (e.g. ``def func(links_state: LinksState, ...):``) and +are designed to be called from kernels that also take typed-dataclass kernel args (so the dataclass is flattened into +per-leaf kernel-locals on both sides of the call boundary). + +When migrating Genesis modules to @qd.data_oriented, we'd like to call the same @qd.func helpers from a data_oriented +kernel method, passing ``self.links_state`` as the arg. Today this fails at AST resolution: + + Missing argument '__qd_links_state__qd_cinr_inertial'. + Unexpected argument 'links_state'. + +These tests pin down the failure modes so we can fix them. +""" + +import dataclasses + +import numpy as np + +import quadrants as qd + +from tests import test_utils + +# ----- typed-dataclass kernel-arg baseline (works) ---------------------------- + + +@test_utils.test(arch=qd.cpu) +def test_baseline_typed_dataclass_kernel_arg_calls_qd_func(): + """Baseline: typed-dataclass kernel arg + qd.func taking same dataclass type — works.""" + N = 4 + + @dataclasses.dataclass + class State: + x: qd.types.NDArray[qd.i32, 1] + y: qd.types.NDArray[qd.i32, 1] + + @qd.func + def write_x(state: State, i: qd.i32, v: qd.i32): + state.x[i] = v + + @qd.kernel + def run(state: State): + for i in range(N): + write_x(state, i, i * 3) + + state = State( + x=qd.ndarray(qd.i32, shape=(N,)), + y=qd.ndarray(qd.i32, shape=(N,)), + ) + run(state) + np.testing.assert_array_equal(state.x.to_numpy(), np.arange(N) * 3) + + +# ----- data_oriented self-method calling qd.func (the broken case) ----------- + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_calls_qd_func_with_dataclass_member(): + """data_oriented holds a dataclass; self-kernel calls a @qd.func taking that dataclass.""" + N = 4 + + @dataclasses.dataclass + class State: + x: qd.types.NDArray[qd.i32, 1] + y: qd.types.NDArray[qd.i32, 1] + + @qd.func + def write_x(state: State, i: qd.i32, v: qd.i32): + state.x[i] = v + + @qd.data_oriented + class Solver: + def __init__(self): + self.state = State( + x=qd.ndarray(qd.i32, shape=(N,)), + y=qd.ndarray(qd.i32, shape=(N,)), + ) + + @qd.kernel + def run(self): + for i in range(N): + write_x(self.state, i, i * 5) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.state.x.to_numpy(), np.arange(N) * 5) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_calls_qd_func_with_keyword_dataclass_member(): + """Same as above but the qd.func arg is passed by keyword (Genesis pattern).""" + N = 4 + + @dataclasses.dataclass + class State: + x: qd.types.NDArray[qd.i32, 1] + + @qd.func + def write_x(state: State, i: qd.i32, v: qd.i32): + state.x[i] = v + + @qd.data_oriented + class Solver: + def __init__(self): + self.state = State(x=qd.ndarray(qd.i32, shape=(N,))) + + @qd.kernel + def run(self): + for i in range(N): + write_x(state=self.state, i=i, v=i * 7) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.state.x.to_numpy(), np.arange(N) * 7) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_stable_members_method_calls_qd_func_with_dataclass_member(): + """Same as above but with stable_members=True (the FPS-relevant case).""" + N = 4 + + @dataclasses.dataclass + class State: + x: qd.types.NDArray[qd.i32, 1] + + @qd.func + def write_x(state: State, i: qd.i32, v: qd.i32): + state.x[i] = v + + @qd.data_oriented(stable_members=True) + class Solver: + def __init__(self): + self.state = State(x=qd.ndarray(qd.i32, shape=(N,))) + + @qd.kernel + def run(self): + for i in range(N): + write_x(state=self.state, i=i, v=i * 11) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.state.x.to_numpy(), np.arange(N) * 11) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_calls_qd_func_with_two_dataclass_members(): + """Two dataclass members, qd.func takes both — Genesis-shaped scenario.""" + N = 4 + + @dataclasses.dataclass + class StateA: + a: qd.types.NDArray[qd.i32, 1] + + @dataclasses.dataclass + class StateB: + b: qd.types.NDArray[qd.i32, 1] + + @qd.func + def write_both(sa: StateA, sb: StateB, i: qd.i32, va: qd.i32, vb: qd.i32): + sa.a[i] = va + sb.b[i] = vb + + @qd.data_oriented(stable_members=True) + class Solver: + def __init__(self): + self.sa = StateA(a=qd.ndarray(qd.i32, shape=(N,))) + self.sb = StateB(b=qd.ndarray(qd.i32, shape=(N,))) + + @qd.kernel + def run(self): + for i in range(N): + write_both(sa=self.sa, sb=self.sb, i=i, va=i * 2, vb=i * 13) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.sa.a.to_numpy(), np.arange(N) * 2) + np.testing.assert_array_equal(solver.sb.b.to_numpy(), np.arange(N) * 13) + + +# ----- nested dataclass -------------------------------------------------------- + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_calls_qd_func_with_nested_dataclass_member(): + """data_oriented holds an Outer{ Inner{ ndarray } } and passes ``self.outer`` to a @qd.func that expands the + nested dataclass into flat leaves on both sides.""" + N = 4 + + @dataclasses.dataclass + class Inner: + x: qd.types.NDArray[qd.i32, 1] + + @dataclasses.dataclass + class Outer: + inner: Inner + + @qd.func + def write_inner_x(outer: Outer, i: qd.i32, v: qd.i32): + outer.inner.x[i] = v + + @qd.data_oriented + class Solver: + def __init__(self): + self.outer = Outer(inner=Inner(x=qd.ndarray(qd.i32, shape=(N,)))) + + @qd.kernel + def run(self): + for i in range(N): + write_inner_x(self.outer, i, i * 17) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.outer.inner.x.to_numpy(), np.arange(N) * 17) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_calls_qd_func_with_nested_dataclass_kwarg(): + """Same as above but the dataclass arg is passed by keyword.""" + N = 4 + + @dataclasses.dataclass + class Inner: + x: qd.types.NDArray[qd.i32, 1] + + @dataclasses.dataclass + class Outer: + inner: Inner + + @qd.func + def write_inner_x(outer: Outer, i: qd.i32, v: qd.i32): + outer.inner.x[i] = v + + @qd.data_oriented(stable_members=True) + class Solver: + def __init__(self): + self.outer = Outer(inner=Inner(x=qd.ndarray(qd.i32, shape=(N,)))) + + @qd.kernel + def run(self): + for i in range(N): + write_inner_x(outer=self.outer, i=i, v=i * 19) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.outer.inner.x.to_numpy(), np.arange(N) * 19) + + +# ----- chained @qd.func calls (qd.func -> qd.func, dataclass threaded through) --- + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_qd_func_chain_with_dataclass_member(): + """data_oriented kernel calls outer @qd.func, which in turn calls inner @qd.func, threading the same dataclass + arg through. Both qd.funcs have the typed-dataclass parameter; only the outermost call site (data_oriented method + body) uses self.X. The two inner call sites use the typed-arg path that already worked.""" + N = 4 + + @dataclasses.dataclass + class State: + x: qd.types.NDArray[qd.i32, 1] + + @qd.func + def inner_write(state: State, i: qd.i32, v: qd.i32): + state.x[i] = v + + @qd.func + def outer_write(state: State, i: qd.i32, v: qd.i32): + inner_write(state, i, v) + + @qd.data_oriented + class Solver: + def __init__(self): + self.state = State(x=qd.ndarray(qd.i32, shape=(N,))) + + @qd.kernel + def run(self): + for i in range(N): + outer_write(self.state, i, i * 23) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.state.x.to_numpy(), np.arange(N) * 23) + + +@test_utils.test(arch=qd.cpu) +def test_data_oriented_method_qd_func_chain_with_nested_dataclass_member(): + """Combination: nested dataclass passed through a chain of two @qd.func calls from a @qd.data_oriented + self-method via self.outer.""" + N = 4 + + @dataclasses.dataclass + class Inner: + x: qd.types.NDArray[qd.i32, 1] + + @dataclasses.dataclass + class Outer: + inner: Inner + + @qd.func + def inner_write(outer: Outer, i: qd.i32, v: qd.i32): + outer.inner.x[i] = v + + @qd.func + def outer_write(outer: Outer, i: qd.i32, v: qd.i32): + inner_write(outer, i, v) + + @qd.data_oriented(stable_members=True) + class Solver: + def __init__(self): + self.outer = Outer(inner=Inner(x=qd.ndarray(qd.i32, shape=(N,)))) + + @qd.kernel + def run(self): + for i in range(N): + outer_write(self.outer, i, i * 29) + + solver = Solver() + solver.run() + np.testing.assert_array_equal(solver.outer.inner.x.to_numpy(), np.arange(N) * 29) diff --git a/tests/python/test_template_typing.py b/tests/python/test_template_typing.py index 69e9ee990b..11e9d5da72 100644 --- a/tests/python/test_template_typing.py +++ b/tests/python/test_template_typing.py @@ -57,24 +57,35 @@ class DataOrientedWithoutFloat: def __init__(self) -> None: self.an_int = 123 self.a_bool = True + self.scratch = qd.ndarray(qd.i32, shape=(1,)) @qd.data_oriented class DataOrientedWithFloat: def __init__(self) -> None: self.an_int = 123 self.a_float = 1.23 + self.scratch = qd.ndarray(qd.i32, shape=(1,)) + # Read the primitive members so the fastcache narrow walk includes them in the hash. Pre-pruning the args_hasher + # walked every member of every container arg blindly; with pruning the kernel must actually access ``a.a_float`` + # for the raise-on-templated-floats guard to fire (the value being baked-in only matters when the kernel reads + # it). @qd.kernel(fastcache=True) - def k1(a: qd.Template) -> None: ... + def k1f(a: qd.Template) -> None: + a.scratch[0] = qd.cast(a.a_float, qd.i32) + + @qd.kernel(fastcache=True) + def k1i(a: qd.Template) -> None: + a.scratch[0] = a.an_int my_do1 = DataOrientedWithoutFloat() - k1(my_do1) + k1i(my_do1) my_do2 = DataOrientedWithFloat() if raise_on_templated_floats: with pytest.raises(ValueError): - k1(my_do2) + k1f(my_do2) else: - k1(my_do2) + k1f(my_do2) @pytest.mark.parametrize("raise_on_templated_floats", [False, True])