diff --git a/docs/source/user_guide/graph.md b/docs/source/user_guide/graph.md index b018a6194f..a9d355db45 100644 --- a/docs/source/user_guide/graph.md +++ b/docs/source/user_guide/graph.md @@ -4,15 +4,16 @@ Graphs reduce kernel launch overhead by capturing a sequence of GPU operations i ## Backend support -`graph=True` and `graph_do_while` run on every backend. They are *hardware accelerated* on CUDA (via CUDA graphs) and AMDGPU (via HIP graphs); `graph_do_while` additionally requires CUDA SM 9.0+ / Hopper for its hardware-accelerated path. On other backends, `graph=True` is silently ignored and the kernel runs via the normal launch path, and `graph_do_while` falls back to a host-side do-while loop that copies the condition value GPU → host each iteration (causing a pipeline stall). `qd.checkpoint` gating runs entirely on the device on every GPU backend; only the CPU backend uses host-side gating. +`graph=True` and `graph_do_while` run on every backend. They are *hardware accelerated* on CUDA (via CUDA graphs) and AMDGPU (via HIP graphs); `graph_do_while` additionally requires [CUDA SM 9.0+](https://developer.nvidia.com/cuda/gpus) for its hardware-accelerated path. On other backends, `graph=True` is silently ignored and the kernel runs via the normal launch path, and `graph_do_while` falls back to a host-side do-while loop. `qd.checkpoint` gating runs entirely on the device on every GPU backend. | Feature | `qd.cuda` SM 9.0+ | `qd.cuda` < SM 9.0 | `qd.amdgpu` | `qd.metal` | `qd.vulkan` | `qd.cpu` | | --- | --- | --- | --- | --- | --- | --- | | `graph=True` | hardware accelerated | hardware accelerated | hardware accelerated | runs (no acceleration) | runs (no acceleration) | runs (no acceleration) | | `qd.graph_do_while` | hardware accelerated | host fallback | host fallback | host fallback | host fallback | host fallback | | `qd.checkpoint` | GPU-side | GPU-side | GPU-side | GPU-side | GPU-side | host-side | +| `qd.graph_parallel_context` / `qd.graph_parallel` (sections) | concurrent | concurrent | runs serially | runs serially | runs serially | runs serially | -AMDGPU `graph_do_while` falls back to the host-side loop because HIP does not currently expose conditional / while graph nodes (as of ROCm 7.2). +AMDGPU `graph_do_while` falls back to a host-side loop because HIP does not currently expose conditional / while graph nodes (as of ROCm 7.2). Nested and sibling `graph_do_while` loops (and mixing `graph_do_while` with top-level `for`-loops) are **experimental** for now — see [Nested loops and mixing with for-loops](#nested-loops-and-mixing-with-for-loops). @@ -44,7 +45,7 @@ my_kernel(x, y) # first call: builds and caches the graph my_kernel(x, y) # subsequent calls: replays the cached graph ``` -This works the same way on CUDA and AMDGPU. The cache is keyed per (compiled-kernel-specialization, launch-id), so different template instantiations (different field bindings, etc.) get their own cached graph. +This works the same way on CUDA and AMDGPU. ### Restrictions @@ -122,7 +123,7 @@ def converge(x: qd.types.ndarray(qd.f32, ndim=1), ### Do-while semantics -`graph_do_while` has **do-while** semantics: the kernel body always executes at least once before the condition is checked. This matches the behavior of CUDA conditional while nodes. The flag value must be >= 1 at launch time. Passing 0 with a kernel that decrements the counter will cause an infinite loop. +`graph_do_while` has **do-while** semantics: the kernel body always executes at least once before the condition is checked. The flag value must be >= 1 at launch time. Passing 0 with a kernel that decrements the counter will cause an infinite loop. ### ndarray vs field @@ -161,7 +162,7 @@ Note that `qd.func`'s are inlined, so you can freely factorize these structures ### Restrictions -- The counter ndarray may be swapped between calls: the cached graph reads each counter through an indirection slot that is refreshed on every launch, so passing a different ndarray (or alternating between several) replays the cached graph without rebuilding it. +- The counter ndarray may be swapped between calls. Passing a different ndarray (or alternating between several) replays the cached graph without rebuilding it. ### Caveats @@ -179,7 +180,7 @@ Therefore on unsupported platforms, you might consider creating a second impleme ## Checkpoints with `qd.checkpoint` *(experimental)* -> **Experimental.** `qd.checkpoint`, `qd.GraphStatus`, and `kernel.resume(from_checkpoint=...)` are experimental APIs. The shape of the public surface (the context-manager signature, the `@qd.kernel(checkpoints=True)` flag, the `GraphStatus` fields, the host-side resume loop, the error messages, and the cross-backend lowering details) may change in any future release without a deprecation cycle. +> **Experimental.** `qd.checkpoint`, `qd.GraphStatus`, and `kernel.resume(from_checkpoint=...)` are experimental APIs. The shape of the public surface (the context-manager signature, the `@qd.kernel(checkpoints=True)` flag, the `GraphStatus` fields, the host-side resume loop, the error messages, and the cross-backend compilation details) may change in any future release without a deprecation cycle. `qd.checkpoint` lets a graph kernel break partway through, surface a reason to the host, let the host fix things up, and resume from the same location on the next launch. An example use-case is an algorithm implemented as a graph that may need to allocate additional memory partway through, where the operations in the graph are in-place, and therefore cannot be rerun without changing/corrupting the output, and therefore for which simply retrying the whole graph from the start is not an option. @@ -254,7 +255,7 @@ while status.yielded: ### Restrictions -- Must be used inside `@qd.kernel(graph=True, checkpoints=True)`. Without the flag, `qd.checkpoint(...)` raises `QuadrantsSyntaxError` at compile time with a fix-it pointing at `checkpoints=True`. +- Must be used inside `@qd.kernel(graph=True, checkpoints=True)`. Without the flag, `qd.checkpoint(...)` raises `QuadrantsSyntaxError` at compile time. - `cp_id` must be an int literal or an `IntEnum` value, and must be unique across the kernel. - `yield_on=` must be a kernel parameter that is a 0-d `qd.types.ndarray(qd.i32, ndim=0)`; expressions are not supported. - Checkpoints cannot be nested inside other checkpoints. Checkpoints inside a `qd.graph_do_while` body are fine. @@ -268,7 +269,7 @@ while status.yielded: arr[i] = arr[i] + 1 ``` -The restriction is by design: each top-level statement inside a checkpoint becomes its own GPU task / graph node, so silently wrapping bare statements would hide a sequence of N field writes ballooning into N kernel launches. Forcing the user to write the `for`-wrap themselves keeps the lowering visible and gives a single obvious place to fuse multiple writes into one task by sharing a single wrapper. +The restriction is by design: each top-level statement inside a checkpoint becomes its own GPU task / graph node, so silently wrapping bare statements would hide a sequence of N field writes ballooning into N kernel launches. Forcing the user to write the `for`-wrap themselves keeps the mapping to GPU tasks visible and gives a single obvious place to fuse multiple writes into one task by sharing a single wrapper. ## Performance @@ -393,11 +394,9 @@ k1(a, count) The recommendation is to use the graph do while here anyway, if you need it for any platform, in order to ensure the code is compact and maintainable. -If you do want fixed-size for loops to run optimally on unsupported hardware platforms, we could add a specializd `qd.graph_range_for` function. This would: -- on graph-do-while-supported hardware: handle adding the additional increment kernel -- on graph-do-while-unsupported hardware: handle running the loop entirely on the host-side, to avoid adding a gpu pipeline stall +If you need fixed-size for loops to run optimally on hardware without graph-do-while support, consider opening a PR for a dedicated helper that picks the host-side fallback automatically on those backends. -In practice, for our own kernels, i.e. in genesis-world, they largely fall under the do while formulation, see the previous section. However, also have some that used to be do while, but have been migrated to an optimized fixed-size, see next section. +In practice, most such loops fall under the do while formulation (see the previous section). Some that were originally do while have since been migrated to an optimized fixed-size form (see the next section). ### A while loop, conditional on a device-side scalar tensor, that has been optimized into a fixed-size for loop @@ -465,4 +464,68 @@ In this case, our recommendation is: - use graph do while anyway, if you need it on any platform - this will ensure your code is compact and maintainable - if you need optimum 100% performance on unsupported platforms, then consider PRing onto quadrants an optimized graph implementation for your target platform - - for example it could somehow run MAX_ITER iterations anyway, similar to the earlier hand-rolled version, but via the graph abstraction, hence allowing the code to be compact, cross-platform, and also optimally fast + +## `qd.graph_parallel` sections with `qd.graph_parallel_context` *(experimental)* + +A `with qd.graph_parallel_context():` region lets you declare independent stages so the graph runs them concurrently. + +`qd.graph_parallel_context` is honored by the graph builder so it composes with `graph=True` and `graph_do_while`. + +```python +@qd.kernel(graph=True) +def step(...): + while qd.graph_do_while(ncond): + assemble_shared(...) # serial: feeds both `qd.graph_parallel` sections + + with qd.graph_parallel_context(): # fork: `qd.graph_parallel` sections run concurrently + with qd.graph_parallel(): # point-triangle contacts + pt_assemble(...) + pt_hessian(...) + with qd.graph_parallel(): # edge-edge contacts (independent of pt) + ee_assemble(...) + ee_hessian(...) + # join: everything below waits for BOTH `qd.graph_parallel` sections to finish + merge_hessians(...) + precondition(...) +``` + +### Semantics + +- **Fork / join.** Every `qd.graph_parallel` section in the region forks from the work that precedes the region. All `qd.graph_parallel` sections must finish before any work *after* the region begins (the join). +- **`qd.graph_parallel` sections are independent — you guarantee it.** Calls *within* a `qd.graph_parallel` section keep their program order, but calls in *different* `qd.graph_parallel` sections have no ordering. The `qd.graph_parallel` sections must be data-race free with respect to one another: no `qd.graph_parallel` section may read what another writes, and no two `qd.graph_parallel` sections may write the same memory. Quadrants does not check this; getting it wrong gives nondeterministic results. + +### Generating `qd.graph_parallel` sections from a compile-time sequence + +`qd.graph_parallel` sections do not have to be written out one by one. A `for ... in qd.static(...)` loop is unrolled at compile time, so each iteration that contains a `with qd.graph_parallel():` becomes its own section — handy for forking one section per element of a static list (e.g. per contact type): + +(Note: See [compound_types.md](compound_types.md) for qd.data_oriented description) +```python +@qd.data_oriented +class Solver: + def __init__(self): + self.funcs = [self._assemble_pt, self._assemble_ee] # static list of @qd.func members + + @qd.kernel(graph=True) + def step(self): + with qd.graph_parallel_context(): + for i in qd.static(range(len(self.funcs))): # unrolls to one section per func + with qd.graph_parallel(): + self.funcs[i]() +``` + +The loop **must** be a `qd.static(...)` loop (its trip count is known at compile time). A plain runtime `for i in range(n):` is rejected — a runtime loop cannot be unrolled into independent sections. + +### Restrictions (enforced at kernel compile time) + +- `qd.graph_parallel_context` may contain only `with qd.graph_parallel():` blocks, optionally wrapped in `if qd.static(...)` (so an optional `qd.graph_parallel` section can be compiled in or out — e.g. enabling edge-edge contacts only when a feature flag is set) or `for ... in qd.static(...)` loops (generate one `qd.graph_parallel` section per element of a compile-time sequence). +- `qd.graph_parallel()` may appear only directly inside a `qd.graph_parallel_context()`. +- `qd.graph_parallel_context` cannot be nested, and a `qd.graph_parallel` section body must be straight-line task work — no `qd.graph_do_while`, `qd.checkpoint`, or nested `qd.graph_parallel_context` inside a `qd.graph_parallel` section (a `qd.graph_parallel_context` may, however, sit inside a `qd.graph_do_while` body, as shown above). + +### Backend behavior + +| backend | scheduling | +| --- | --- | +| CUDA | `qd.graph_parallel` sections run **concurrently** | +| AMDGPU / CPU / Vulkan / Metal | `qd.graph_parallel` sections run **serially** | + +Because `qd.graph_parallel` sections are independent by construction, running them serially produces identical results — only the scheduling differs. diff --git a/docs/source/user_guide/streams.md b/docs/source/user_guide/streams.md index a8db331bcc..a3df7cedc2 100644 --- a/docs/source/user_guide/streams.md +++ b/docs/source/user_guide/streams.md @@ -48,6 +48,8 @@ combine() # runs after compute_ab() returns — a[] and b[] are ready Consecutive `with qd.stream_parallel():` blocks run concurrently. Multiple for loops within a single block share a stream and run serially on it. All streams are synchronized before the kernel returns. +> **For `graph=True` kernels**, use [`qd.graph_parallel_context` / `qd.graph_parallel`](graph.md#qdgraph_parallel-sections-with-qdgraph_parallel_context-experimental) instead — `stream_parallel` is not compatible with graphs (see [Limitations](#limitations)). `qd.graph_parallel_context` expresses the same "run these independent sequences concurrently" idea but is honored by the graph builder. + ### Restrictions - All top-level statements in a kernel must be either all `stream_parallel` blocks or all regular statements. Mixing the two at the top level is a compile-time error. @@ -134,4 +136,4 @@ with qd.create_stream() as s: ## Limitations - **Not compatible with graphs.** Do not pass `qd_stream` to a kernel decorated with `graph=True` (if you do, a `RuntimeError` will be raised). -- **Not compatible with autodiff.** Do not pass `qd_stream` to a kernel that uses reverse-mode or forward-mode differentiation, or inside a `qd.ad.Tape` context (if you do, a `RuntimeError` will be raised). +- **Not compatible with [autodiff.md](autodiff.md).** Do not pass `qd_stream` to a kernel that uses reverse-mode or forward-mode differentiation, or inside a `qd.ad.Tape` context (if you do, a `RuntimeError` will be raised). diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 0297fffa8c..47749650dd 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -32,6 +32,9 @@ from quadrants.lang.ast.ast_transformers.function_def_transformer import ( FunctionDefTransformer, ) +from quadrants.lang.ast.ast_transformers.graph_parallel_transformer import ( + GraphParallelTransformer, +) from quadrants.lang.exception import ( QuadrantsIndexError, QuadrantsRuntimeTypeError, @@ -1615,9 +1618,16 @@ def build_With(ctx: ASTTransformerFuncContext, node: ast.With) -> None: if checkpoint_info is not None: return ASTTransformer._build_checkpoint_with(ctx, node, checkpoint_info) + if GraphParallelTransformer.is_graph_parallel_context_call(item.context_expr): + return GraphParallelTransformer.build_graph_parallel_context_with(ctx, node, build_stmts) + + if GraphParallelTransformer.is_parallel_section_call(item.context_expr): + return GraphParallelTransformer.build_parallel_section_with(ctx, node, build_stmts) + if not FunctionDefTransformer._is_stream_parallel_with(node, ctx.global_vars): raise QuadrantsSyntaxError( - "'with' in Quadrants kernels only supports qd.stream_parallel() or qd.checkpoint()" + "'with' in Quadrants kernels only supports qd.stream_parallel(), qd.checkpoint(), " + "qd.graph_parallel_context(), or qd.graph_parallel()" ) if not ctx.is_kernel: raise QuadrantsSyntaxError("qd.stream_parallel() can only be used inside @qd.kernel, not @qd.func") diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 4e7b8e5154..69cd102472 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -600,6 +600,38 @@ def _is_loop_config_call(stmt: ast.stmt) -> bool: return True return False + @staticmethod + def _is_graph_parallel_context_with(stmt: ast.stmt) -> bool: + """Syntactic check matching GraphParallelTransformer.is_graph_parallel_context_call: a + ``with qd.graph_parallel_context():`` fork/join region.""" + if not isinstance(stmt, ast.With) or len(stmt.items) != 1: + return False + ctx_expr = stmt.items[0].context_expr + if not isinstance(ctx_expr, ast.Call): + return False + func = ctx_expr.func + if isinstance(func, ast.Attribute) and func.attr == "graph_parallel_context": + return True + if isinstance(func, ast.Name) and func.id == "graph_parallel_context": + return True + return False + + @staticmethod + def _is_parallel_section_with(stmt: ast.stmt) -> bool: + """Syntactic check matching GraphParallelTransformer.is_parallel_section_call: a ``with qd.graph_parallel(...):`` + section of a ``qd.graph_parallel_context()`` region.""" + if not isinstance(stmt, ast.With) or len(stmt.items) != 1: + return False + ctx_expr = stmt.items[0].context_expr + if not isinstance(ctx_expr, ast.Call): + return False + func = ctx_expr.func + if isinstance(func, ast.Attribute) and func.attr == "graph_parallel": + return True + if isinstance(func, ast.Name) and func.id == "graph_parallel": + return True + return False + @staticmethod def _validate_graph_do_while_structure(body: list[ast.stmt]) -> None: """If a kernel uses qd.graph_do_while() anywhere, enforce the structural rules that remain after @@ -659,6 +691,27 @@ def _validate_graph_do_while_stmt_list(stmts: list[ast.stmt], is_kernel_top: boo # `CheckpointTransformer.build_checkpoint_with`. FunctionDefTransformer._validate_graph_do_while_stmt_list(stmt.body, is_kernel_top=is_kernel_top) continue + if FunctionDefTransformer._is_graph_parallel_context_with(stmt): + # A `with qd.graph_parallel_context()` region groups concurrent `with qd.graph_parallel()` + # sections; it is a legal sibling of for-loops / checkpoints. Its body must be + # `qd.graph_parallel` section blocks (optionally under `if qd.static(...)`); the full check + # is in GraphParallelTransformer.build_graph_parallel_context_with. Each `qd.graph_parallel` section's + # body is task territory, validated here with the in-loop rules. Descend through `if` members + # so `qd.graph_parallel` sections inside an optional `if qd.static(...)` are reached too. + pending = list(stmt.body) + while pending: + member = pending.pop() + if FunctionDefTransformer._is_parallel_section_with(member): + FunctionDefTransformer._validate_graph_do_while_stmt_list(member.body, is_kernel_top=False) + elif isinstance(member, ast.If): + pending.extend(member.body) + pending.extend(member.orelse) + elif isinstance(member, ast.For): + # `for ... in qd.static(...)` generates sections; descend so each unrolled section + # body is still validated with the in-loop rules (a runtime for here is rejected + # earlier by GraphParallelTransformer.build_graph_parallel_context_with). + pending.extend(member.body) + continue where = "the kernel body" if is_kernel_top else "a qd.graph_do_while() body" raise QuadrantsSyntaxError( f"When a kernel uses qd.graph_do_while(), {where} may not contain a {type(stmt).__name__} " diff --git a/python/quadrants/lang/ast/ast_transformers/graph_parallel_transformer.py b/python/quadrants/lang/ast/ast_transformers/graph_parallel_transformer.py new file mode 100644 index 0000000000..5ca9a2a896 --- /dev/null +++ b/python/quadrants/lang/ast/ast_transformers/graph_parallel_transformer.py @@ -0,0 +1,143 @@ +# type: ignore +"""AST recognition, validation, and lowering for ``qd.graph_parallel_context()`` / ``qd.graph_parallel()`` blocks. + +Lives alongside ``checkpoint_transformer.py`` / ``function_def_transformer.py`` so that ``ast_transformer.py`` doesn't +have to grow per-feature. ``ASTTransformer.build_With`` forwards ``qd.graph_parallel_context()`` regions and their +``qd.graph_parallel()`` sections into the static methods here. + +A ``qd.graph_parallel_context()`` region tags its body with a per-kernel region id (via +``begin/end_graph_parallel_context``) and each ``qd.graph_parallel()`` section inside lowers to a stream-parallel group +(via ``begin/end_stream_parallel``). The graph builder forks the distinct groups of one region in a contiguous run and +joins them; the region id keeps two back-to-back regions apart (each gets its own join). See +``docs/source/user_guide/graph.md`` for the user-facing surface. +""" + +from __future__ import annotations + +import ast + +from quadrants.lang.ast.ast_transformer_utils import ( + ASTTransformerFuncContext, + get_decorator, +) +from quadrants.lang.ast.ast_transformers.function_def_transformer import ( + FunctionDefTransformer, +) +from quadrants.lang.exception import QuadrantsSyntaxError + + +class GraphParallelTransformer: + @staticmethod + def is_graph_parallel_context_call(node: ast.expr) -> bool: + """If *node* is a ``qd.graph_parallel_context()`` call return True, else False.""" + if not isinstance(node, ast.Call): + return False + func = node.func + is_gpc = (isinstance(func, ast.Attribute) and func.attr == "graph_parallel_context") or ( + isinstance(func, ast.Name) and func.id == "graph_parallel_context" + ) + if not is_gpc: + return False + if node.args or node.keywords: + raise QuadrantsSyntaxError("qd.graph_parallel_context() takes no arguments") + return True + + @staticmethod + def is_parallel_section_call(node: ast.expr) -> bool: + """If *node* is a ``qd.graph_parallel()`` (a section) call return True, else False. The call shape is validated + here so misuse raises at the ``with`` site rather than later.""" + if not isinstance(node, ast.Call): + return False + func = node.func + is_parallel_section = (isinstance(func, ast.Attribute) and func.attr == "graph_parallel") or ( + isinstance(func, ast.Name) and func.id == "graph_parallel" + ) + if not is_parallel_section: + return False + if node.args or node.keywords: + raise QuadrantsSyntaxError("qd.graph_parallel() takes no arguments") + return True + + @staticmethod + def build_graph_parallel_context_with(ctx: ASTTransformerFuncContext, node: ast.With, build_stmts) -> None: + """Handles ``with qd.graph_parallel_context():`` fork/join regions. + + Validates the use-site (kernel must be graph=True, no nesting) and that the region body contains only + ``with qd.graph_parallel():`` blocks, then walks the body. The region is bracketed with + begin/end_graph_parallel_context() so its body carries a per-kernel region id, and each ``qd.graph_parallel`` + section inside lowers to a stream-parallel group (via begin/end_stream_parallel). The graph builder forks the + distinct groups of one region in a contiguous run and joins them; the region id keeps two back-to-back regions + apart (each gets its own join).""" + if not ctx.is_kernel: + raise QuadrantsSyntaxError("qd.graph_parallel_context() can only be used inside @qd.kernel, not @qd.func") + kernel = ctx.global_context.current_kernel + if kernel is None or not kernel.use_graph: + raise QuadrantsSyntaxError("qd.graph_parallel_context() requires @qd.kernel(graph=True)") + if getattr(ctx, "_in_graph_parallel_context", False): + raise QuadrantsSyntaxError("qd.graph_parallel_context() regions cannot be nested") + if getattr(ctx, "_in_parallel_section", False): + raise QuadrantsSyntaxError("qd.graph_parallel_context() cannot appear inside a qd.graph_parallel() body") + GraphParallelTransformer._validate_graph_parallel_context_body(ctx, node.body) + ctx._in_graph_parallel_context = True + ctx.ast_builder.begin_graph_parallel_context() + try: + build_stmts(ctx, node.body) + finally: + ctx.ast_builder.end_graph_parallel_context() + ctx._in_graph_parallel_context = False + return None + + @staticmethod + def _validate_graph_parallel_context_body(ctx: ASTTransformerFuncContext, stmts: list[ast.stmt]) -> None: + """A qd.graph_parallel_context() region body may contain only `with qd.graph_parallel():` blocks, optionally + wrapped in compile-time `if qd.static(...)` (the optional ``qd.graph_parallel`` section pattern, e.g. qipc's + ENABLE_EE) or `for ... in qd.static(...)` loops (generate one ``qd.graph_parallel`` section per element of a + compile-time sequence). Docstrings / coverage probes / `pass` are allowed. Anything else (a runtime for-loop, a + bare assignment, etc.) is a serial task that would silently fall outside any ``qd.graph_parallel`` section, so + reject it. + + The `for` case is restricted to `qd.static(...)` loops on purpose: a static loop unrolls at trace time into its + repeated body, so it lowers to literal `with qd.graph_parallel():` blocks (each gets a fresh + stream_parallel_group_id). A *runtime* for-loop would instead trace a single parallel range_for with the section + tagging nested inside it -- malformed. Staticness is checked with `get_decorator` (the same resolution + `build_For` uses) at every nesting level, so a runtime loop nested under a static one is still rejected.""" + for i, stmt in enumerate(stmts): + if FunctionDefTransformer._is_docstring(stmt, i) or FunctionDefTransformer._is_coverage_probe(stmt): + continue + if isinstance(stmt, ast.Pass): + continue + if isinstance(stmt, ast.With) and stmt.items: + if GraphParallelTransformer.is_parallel_section_call(stmt.items[0].context_expr): + continue + if isinstance(stmt, ast.If): + GraphParallelTransformer._validate_graph_parallel_context_body(ctx, stmt.body) + GraphParallelTransformer._validate_graph_parallel_context_body(ctx, stmt.orelse) + continue + if isinstance(stmt, ast.For) and not stmt.orelse and get_decorator(ctx, stmt.iter) == "static": + GraphParallelTransformer._validate_graph_parallel_context_body(ctx, stmt.body) + continue + raise QuadrantsSyntaxError( + "A qd.graph_parallel_context() region may contain only 'with qd.graph_parallel():' blocks " + "(optionally inside 'if qd.static(...)' or 'for ... in qd.static(...)'). Move other work " + f"outside the region. [offending stmt {i}: {type(stmt).__name__}]" + ) + + @staticmethod + def build_parallel_section_with(ctx: ASTTransformerFuncContext, node: ast.With, build_stmts) -> None: + """Handles a ``with qd.graph_parallel():`` section of a ``qd.graph_parallel_context()`` region. + + Reuses the stream-parallel tagging: begin_stream_parallel() assigns this ``qd.graph_parallel`` section a fresh + ``stream_parallel_group_id`` that every for-loop in the body inherits, so the offloaded tasks carry the + ``qd.graph_parallel`` section id all the way to the graph builder.""" + if not getattr(ctx, "_in_graph_parallel_context", False): + raise QuadrantsSyntaxError( + "qd.graph_parallel() can only be used directly inside a qd.graph_parallel_context() region" + ) + ctx._in_parallel_section = True + ctx.ast_builder.begin_stream_parallel() + try: + build_stmts(ctx, node.body) + finally: + ctx.ast_builder.end_stream_parallel() + ctx._in_parallel_section = False + return None diff --git a/python/quadrants/lang/graph_parallel.py b/python/quadrants/lang/graph_parallel.py new file mode 100644 index 0000000000..58b2bf5447 --- /dev/null +++ b/python/quadrants/lang/graph_parallel.py @@ -0,0 +1,63 @@ +"""User-facing ``qd.graph_parallel_context`` / ``qd.graph_parallel`` context-managers and their no-op Python-runtime +stubs. + +Kept in its own module to keep ``lang/misc.py`` from growing further (mirrors ``checkpoint.py``) -- the AST transformer +and the C++ runtime do all the actual implementation work; this file is just the public API entry point. + +Re-exported via ``qd.lang.misc`` (and therefore as ``qd.graph_parallel_context`` / ``qd.graph_parallel``) for the +user-facing canonical import path. +""" + +from __future__ import annotations + +from contextlib import contextmanager + + +@contextmanager +def graph_parallel_context(): + """Opens a fork/join region whose ``qd.graph_parallel()`` sections run concurrently. + + Used as ``with qd.graph_parallel_context():`` inside a ``@qd.kernel(graph=True)`` kernel. The region's + body must contain only ``with qd.graph_parallel():`` blocks. Each ``qd.graph_parallel`` section is an + independent sequence of work; the ``qd.graph_parallel`` sections have no ordering relative to each + other and may execute concurrently, while everything after the region waits for *all* ``qd.graph_parallel`` + sections to finish (the join). This is the graph analogue of ``qd.stream_parallel()`` (which is for + non-graph kernels): it lets independent stages -- e.g. qipc's point-triangle and edge-edge assembly + -- overlap inside a captured graph. + + Concurrency contract (the author's responsibility): ``qd.graph_parallel`` sections must be data-race + free with respect to one another (no ``qd.graph_parallel`` section reads what another writes, no two + ``qd.graph_parallel`` sections write the same location). Calls *within* a ``qd.graph_parallel`` section + keep their program order. + + Backend behavior: + - CUDA SM graph path: ``qd.graph_parallel`` sections become independent graph chains joined by an + empty node, so the runtime schedules them on parallel streams (real overlap). + - CPU / Vulkan / Metal / AMDGPU graph: correct results, ``qd.graph_parallel`` sections run serially + (the concurrency tags are honored only by the graph builder today). + + Restrictions (enforced at kernel compile time): + - Must be used inside ``@qd.kernel(graph=True)``. + - The region body may contain only ``with qd.graph_parallel():`` blocks. + - Regions cannot be nested, and a ``qd.graph_parallel`` section body must be straight-line task work + (no nested ``qd.graph_do_while``, ``qd.checkpoint``, or ``qd.graph_parallel_context``). + + This function should not be called directly at runtime; it is recognized and transformed during AST compilation. + At Python runtime (outside kernels) it is a no-op context manager. + + See also ``docs/source/user_guide/graph.md``. + """ + yield + + +@contextmanager +def graph_parallel(): + """Declares one ``qd.graph_parallel`` section of an enclosing ``qd.graph_parallel_context()`` region. + + Used as ``with qd.graph_parallel():`` directly inside a ``with qd.graph_parallel_context():`` block. + The ``qd.graph_parallel`` section's body is an independent sequence of work that may run concurrently + with the region's other ``qd.graph_parallel`` sections. + + See ``qd.graph_parallel_context()`` for the full contract and backend behavior. + """ + yield diff --git a/python/quadrants/lang/misc.py b/python/quadrants/lang/misc.py index 6c45b1f1dc..1353ec70d2 100644 --- a/python/quadrants/lang/misc.py +++ b/python/quadrants/lang/misc.py @@ -12,6 +12,7 @@ from quadrants.lang import impl, util from quadrants.lang.checkpoint import checkpoint from quadrants.lang.expr import Expr +from quadrants.lang.graph_parallel import graph_parallel, graph_parallel_context from quadrants.lang.graph_status import GraphStatus from quadrants.lang.impl import axes, get_runtime from quadrants.profiler.kernel_profiler import get_default_kernel_profiler @@ -745,7 +746,7 @@ def graph_do_while(condition) -> bool: iteration). The one structural rule is that a ``qd.graph_do_while`` ``while``-loop may appear only at the kernel top level or directly inside another ``graph_do_while`` body, not inside a ``for``-loop. - This function should not be called directly at runtime; it is recognised and transformed during AST compilation. + This function should not be called directly at runtime; it is recognized and transformed during AST compilation. Requires ``@qd.kernel(graph=True)``. """ return bool(condition) @@ -890,6 +891,8 @@ def dump_compile_config() -> None: "GraphStatus", "checkpoint", "graph_do_while", + "graph_parallel_context", + "graph_parallel", "loop_config", "global_thread_idx", "assume_in_range", diff --git a/quadrants/codegen/amdgpu/codegen_amdgpu.cpp b/quadrants/codegen/amdgpu/codegen_amdgpu.cpp index ccba1e701b..291bbe264b 100644 --- a/quadrants/codegen/amdgpu/codegen_amdgpu.cpp +++ b/quadrants/codegen/amdgpu/codegen_amdgpu.cpp @@ -417,6 +417,7 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM { } current_task->block_dim = stmt->block_dim; current_task->stream_parallel_group_id = stmt->stream_parallel_group_id; + current_task->graph_parallel_region_id = stmt->graph_parallel_region_id; current_task->checkpoint_id = stmt->checkpoint_id; QD_ASSERT(current_task->grid_dim != 0); QD_ASSERT(current_task->block_dim != 0); diff --git a/quadrants/codegen/cuda/codegen_cuda.cpp b/quadrants/codegen/cuda/codegen_cuda.cpp index 4e19864055..6961a97878 100644 --- a/quadrants/codegen/cuda/codegen_cuda.cpp +++ b/quadrants/codegen/cuda/codegen_cuda.cpp @@ -673,6 +673,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { current_task->block_dim = stmt->block_dim; current_task->dynamic_shared_array_bytes = dynamic_shared_array_bytes; current_task->stream_parallel_group_id = stmt->stream_parallel_group_id; + current_task->graph_parallel_region_id = stmt->graph_parallel_region_id; current_task->checkpoint_id = stmt->checkpoint_id; QD_ASSERT(current_task->grid_dim != 0); QD_ASSERT(current_task->block_dim != 0); diff --git a/quadrants/codegen/llvm/llvm_compiled_data.h b/quadrants/codegen/llvm/llvm_compiled_data.h index 7a08c77be5..8f3bed28a1 100644 --- a/quadrants/codegen/llvm/llvm_compiled_data.h +++ b/quadrants/codegen/llvm/llvm_compiled_data.h @@ -118,6 +118,10 @@ class OffloadedTask { int grid_dim{0}; int dynamic_shared_array_bytes{0}; int stream_parallel_group_id{0}; + // Per-kernel `qd.graph_parallel_context()` region id (0 outside any region). Populated by the CUDA LLVM codegen from + // `OffloadedStmt::graph_parallel_region_id`. The GraphManager pairs it with `stream_parallel_group_id` so two + // back-to-back regions are built as separate fork/join groups (each with its own join) instead of one merged group. + int graph_parallel_region_id{0}; // `cp_id` of the enclosing `qd.checkpoint(...)` block for this task (`-1` outside any checkpoint). Populated by the // CUDA / AMDGPU LLVM codegen from `OffloadedStmt::checkpoint_id` (set by the offload pass from // `RangeForStmt::checkpoint_id` / `StructForStmt::checkpoint_id`). The GraphManager will consume this in slice 1c to @@ -167,6 +171,7 @@ class OffloadedTask { grid_dim, dynamic_shared_array_bytes, stream_parallel_group_id, + graph_parallel_region_id, checkpoint_id, graph_do_while_level_id, ad_stack, diff --git a/quadrants/ir/frontend_ir.cpp b/quadrants/ir/frontend_ir.cpp index 7b0f4fa04f..75cad9f380 100644 --- a/quadrants/ir/frontend_ir.cpp +++ b/quadrants/ir/frontend_ir.cpp @@ -111,6 +111,7 @@ FrontendForStmt::FrontendForStmt(const FrontendForStmt &o) mem_access_opt(o.mem_access_opt), block_dim(o.block_dim), stream_parallel_group_id(o.stream_parallel_group_id), + graph_parallel_region_id(o.graph_parallel_region_id), graph_do_while_level_id(o.graph_do_while_level_id), checkpoint_id(o.checkpoint_id), loop_name(o.loop_name) { @@ -122,6 +123,7 @@ void FrontendForStmt::init_config(Arch arch, const ForLoopConfig &config) { mem_access_opt = config.mem_access_opt; block_dim = config.block_dim; stream_parallel_group_id = config.stream_parallel_group_id; + graph_parallel_region_id = config.graph_parallel_region_id; graph_do_while_level_id = config.graph_do_while_level_id; checkpoint_id = config.checkpoint_id; loop_name = config.loop_name; @@ -1510,6 +1512,7 @@ void ASTBuilder::warn_if_named_nested_loop() { void ASTBuilder::begin_frontend_range_for(const Expr &i, const Expr &s, const Expr &e, const DebugInfo &dbg_info) { warn_if_named_nested_loop(); for_loop_dec_.config.stream_parallel_group_id = current_stream_parallel_group_id_; + for_loop_dec_.config.graph_parallel_region_id = current_graph_parallel_region_id_; for_loop_dec_.config.graph_do_while_level_id = current_graph_do_while_level_id_; for_loop_dec_.config.checkpoint_id = current_checkpoint_id_; auto stmt_unique = std::make_unique(i, s, e, arch_, for_loop_dec_.config, dbg_info); @@ -1527,6 +1530,7 @@ void ASTBuilder::begin_frontend_struct_for_on_snode(const ExprGroup &loop_vars, "qd.loop_config(serialize=True) does not have effect on the struct for. " "The execution order is not guaranteed."); for_loop_dec_.config.stream_parallel_group_id = current_stream_parallel_group_id_; + for_loop_dec_.config.graph_parallel_region_id = current_graph_parallel_region_id_; for_loop_dec_.config.graph_do_while_level_id = current_graph_do_while_level_id_; for_loop_dec_.config.checkpoint_id = current_checkpoint_id_; auto stmt_unique = std::make_unique(loop_vars, snode, arch_, for_loop_dec_.config, dbg_info); @@ -1544,6 +1548,7 @@ void ASTBuilder::begin_frontend_struct_for_on_external_tensor(const ExprGroup &l "qd.loop_config(serialize=True) does not have effect on the struct for. " "The execution order is not guaranteed."); for_loop_dec_.config.stream_parallel_group_id = current_stream_parallel_group_id_; + for_loop_dec_.config.graph_parallel_region_id = current_graph_parallel_region_id_; for_loop_dec_.config.graph_do_while_level_id = current_graph_do_while_level_id_; for_loop_dec_.config.checkpoint_id = current_checkpoint_id_; auto stmt_unique = @@ -1563,6 +1568,7 @@ void ASTBuilder::begin_frontend_mesh_for(const Expr &i, "qd.loop_config(serialize=True) does not have effect on the mesh for. " "The execution order is not guaranteed."); for_loop_dec_.config.stream_parallel_group_id = current_stream_parallel_group_id_; + for_loop_dec_.config.graph_parallel_region_id = current_graph_parallel_region_id_; for_loop_dec_.config.graph_do_while_level_id = current_graph_do_while_level_id_; for_loop_dec_.config.checkpoint_id = current_checkpoint_id_; auto stmt_unique = diff --git a/quadrants/ir/frontend_ir.h b/quadrants/ir/frontend_ir.h index 39a116ea45..ff18016937 100644 --- a/quadrants/ir/frontend_ir.h +++ b/quadrants/ir/frontend_ir.h @@ -24,6 +24,11 @@ struct ForLoopConfig { int block_dim{0}; bool uniform{false}; int stream_parallel_group_id{0}; + // Per-kernel id of the enclosing `qd.graph_parallel_context()` region (0 outside any region). Assigned by the AST + // builder's `current_graph_parallel_region_id_` at `begin_frontend_*_for` time and threaded alongside + // `stream_parallel_group_id` so the CUDA graph builder can tell two back-to-back regions apart and keep their joins + // (without it, adjacent regions merge into one fork/join and the second can race the first). + int graph_parallel_region_id{0}; int graph_do_while_level_id{-1}; // `cp_id` (see design doc `perso_hugh/doc/qipc/reentrant.md` section 5.1) of the enclosing `qd.checkpoint(...)` block // when this for-loop is emitted, or `-1` when the for-loop is outside any checkpoint. Assigned by the AST builder's @@ -208,6 +213,7 @@ class FrontendForStmt : public Stmt { MemoryAccessOptions mem_access_opt; int block_dim; int stream_parallel_group_id{0}; + int graph_parallel_region_id{0}; int graph_do_while_level_id{-1}; int checkpoint_id{-1}; std::string loop_name; @@ -929,6 +935,7 @@ class ASTBuilder { config.block_dim = 0; config.strictly_serialized = false; config.stream_parallel_group_id = 0; + config.graph_parallel_region_id = 0; config.graph_do_while_level_id = -1; config.checkpoint_id = -1; config.loop_name.clear(); @@ -943,6 +950,12 @@ class ASTBuilder { int id_counter_{0}; int stream_parallel_group_counter_{0}; int current_stream_parallel_group_id_{0}; + // Per-kernel counter handed out by `begin_graph_parallel_context()` (one id per `qd.graph_parallel_context()` + // region), and the innermost active region id (0 outside any region). Reset per kernel via fresh ASTBuilder + // construction. Mirrors `stream_parallel_group_counter_` / `current_stream_parallel_group_id_`, but at region + // (not section) granularity, so for-loops created inside a region carry its id. + int graph_parallel_region_counter_{0}; + int current_graph_parallel_region_id_{0}; // Innermost active graph_do_while level id (-1 if not inside any). The Python AST transformer manages the stack and // calls set_graph_do_while_level_id() on enter/exit; for-loops created while it is >= 0 are tagged with it (mirrors // current_stream_parallel_group_id_). @@ -1094,6 +1107,18 @@ class ASTBuilder { current_stream_parallel_group_id_ = 0; } + // Enter a `qd.graph_parallel_context()` region: hand out a fresh per-kernel region id that every for-loop emitted + // inside the region (across all its `qd.graph_parallel()` sections) carries, so the CUDA graph builder can keep + // adjacent regions' fork/join groups apart. The Python AST transformer calls these on region enter/exit. + void begin_graph_parallel_context() { + QD_ERROR_IF(current_graph_parallel_region_id_ != 0, "qd.graph_parallel_context() regions cannot be nested"); + current_graph_parallel_region_id_ = ++graph_parallel_region_counter_; + } + + void end_graph_parallel_context() { + current_graph_parallel_region_id_ = 0; + } + // Set the innermost active graph_do_while level id. Pass the new level id when entering a graph_do_while loop, and // the parent level id (or -1) when leaving it. The Python AST transformer owns the level stack and the level table. void set_graph_do_while_level_id(int level_id) { diff --git a/quadrants/ir/ir.h b/quadrants/ir/ir.h index fb5b6dc5a9..997399507b 100644 --- a/quadrants/ir/ir.h +++ b/quadrants/ir/ir.h @@ -397,6 +397,12 @@ class StmtFieldManager { struct GraphRegionTag { int graph_do_while_level_id{-1}; int stream_parallel_group_id{0}; + // Per-kernel `qd.graph_parallel_context()` region id (0 outside any region). Only ever set on the PURE-bucket + // `fallback_tag` in `offload.cpp` (so a section's pure bound-compute serial is grouped into the same region as the + // for-loop that consumes it). Source-level serial statements live outside any region, so their stamped `region_tag` + // leaves this at 0. Excluded from operator==/!= for the same reason `is_set` is: the serial bucket only ever + // compares region-0 side-effecting tags, so region id never participates. + int graph_parallel_region_id{0}; // `cp_id` (see `quadrants/lang/checkpoint.py`) of the enclosing `qd.checkpoint(...)` block, or `-1` outside any // checkpoint. Stamped on every frontend statement by `ASTBuilder::insert` (from `current_checkpoint_id_`) so a // SIDE-EFFECTING serial task tagged with a real checkpoint can be gated/skipped on `resume(from_checkpoint=...)` diff --git a/quadrants/ir/statements.cpp b/quadrants/ir/statements.cpp index 3e29a5a8d2..1911f71f70 100644 --- a/quadrants/ir/statements.cpp +++ b/quadrants/ir/statements.cpp @@ -223,6 +223,7 @@ std::unique_ptr RangeForStmt::clone() const { block_dim, strictly_serialized); new_stmt->reversed = reversed; new_stmt->stream_parallel_group_id = stream_parallel_group_id; + new_stmt->graph_parallel_region_id = graph_parallel_region_id; new_stmt->checkpoint_id = checkpoint_id; new_stmt->graph_do_while_level_id = graph_do_while_level_id; new_stmt->loop_name = loop_name; @@ -247,6 +248,7 @@ std::unique_ptr StructForStmt::clone() const { auto new_stmt = std::make_unique(snode, body->clone(), is_bit_vectorized, num_cpu_threads, block_dim); new_stmt->mem_access_opt = mem_access_opt; new_stmt->stream_parallel_group_id = stream_parallel_group_id; + new_stmt->graph_parallel_region_id = graph_parallel_region_id; new_stmt->checkpoint_id = checkpoint_id; new_stmt->graph_do_while_level_id = graph_do_while_level_id; new_stmt->loop_name = loop_name; @@ -410,6 +412,7 @@ std::unique_ptr OffloadedStmt::clone() const { new_stmt->bls_size = bls_size; new_stmt->mem_access_opt = mem_access_opt; new_stmt->stream_parallel_group_id = stream_parallel_group_id; + new_stmt->graph_parallel_region_id = graph_parallel_region_id; new_stmt->checkpoint_id = checkpoint_id; new_stmt->graph_do_while_level_id = graph_do_while_level_id; new_stmt->loop_name = loop_name; diff --git a/quadrants/ir/statements.h b/quadrants/ir/statements.h index 9f722cf261..9ad6b688c3 100644 --- a/quadrants/ir/statements.h +++ b/quadrants/ir/statements.h @@ -977,6 +977,10 @@ class RangeForStmt : public Stmt { bool strictly_serialized; std::string range_hint; int stream_parallel_group_id{0}; + // Per-kernel `qd.graph_parallel_context()` region id (0 outside any region). Propagated from + // `FrontendForStmt::graph_parallel_region_id` by `lower_ast.cpp` and on to `OffloadedStmt` by `offload.cpp`. See + // ForLoopConfig comment in `frontend_ir.h`. + int graph_parallel_region_id{0}; // `cp_id` of the enclosing `qd.checkpoint(...)` block (`-1` outside any checkpoint). Propagated from // `FrontendForStmt::checkpoint_id` by `lower_ast.cpp`, then carried into the post-offload // `OffloadedStmt::checkpoint_id` by `offload.cpp`. See ForLoopConfig comment in `frontend_ir.h` for the full @@ -1015,6 +1019,7 @@ class RangeForStmt : public Stmt { block_dim, strictly_serialized, stream_parallel_group_id, + graph_parallel_region_id, checkpoint_id, graph_do_while_level_id); QD_DEFINE_ACCEPT @@ -1036,6 +1041,8 @@ class StructForStmt : public Stmt { int block_dim; MemoryAccessOptions mem_access_opt; int stream_parallel_group_id{0}; + // See `RangeForStmt::graph_parallel_region_id` -- same lifecycle. + int graph_parallel_region_id{0}; // See `RangeForStmt::checkpoint_id` -- same lifecycle, same `-1` sentinel. int checkpoint_id{-1}; int graph_do_while_level_id{-1}; @@ -1060,6 +1067,7 @@ class StructForStmt : public Stmt { block_dim, mem_access_opt, stream_parallel_group_id, + graph_parallel_region_id, checkpoint_id, graph_do_while_level_id); QD_DEFINE_ACCEPT @@ -1406,6 +1414,10 @@ class OffloadedStmt : public Stmt { std::size_t bls_size{0}; MemoryAccessOptions mem_access_opt; int stream_parallel_group_id{0}; + // Per-kernel `qd.graph_parallel_context()` region id (0 outside any region). Set by `offload.cpp` from the source + // `RangeForStmt` / `StructForStmt`. The CUDA LLVM codegen copies it into `OffloadedTask::graph_parallel_region_id` + // so the GraphManager keeps adjacent regions' fork/join groups apart. + int graph_parallel_region_id{0}; // `cp_id` of the enclosing `qd.checkpoint(...)` block for this offloaded task (`-1` outside any checkpoint). Set by // `offload.cpp` from the source `RangeForStmt::checkpoint_id` / `StructForStmt::checkpoint_id`. Read by the CUDA / // AMDGPU LLVM codegen to populate `OffloadedTask::checkpoint_id`, which the GraphManager will consume in slice 1c. @@ -1463,6 +1475,7 @@ class OffloadedStmt : public Stmt { index_offsets, mem_access_opt, stream_parallel_group_id, + graph_parallel_region_id, checkpoint_id, graph_do_while_level_id); QD_DEFINE_ACCEPT diff --git a/quadrants/python/export_lang.cpp b/quadrants/python/export_lang.cpp index e589ea907f..2bada9d9cb 100644 --- a/quadrants/python/export_lang.cpp +++ b/quadrants/python/export_lang.cpp @@ -320,6 +320,8 @@ void export_lang(nb::module_ &m) { .def("reset_snode_access_flag", &ASTBuilder::reset_snode_access_flag) .def("begin_stream_parallel", &ASTBuilder::begin_stream_parallel) .def("end_stream_parallel", &ASTBuilder::end_stream_parallel) + .def("begin_graph_parallel_context", &ASTBuilder::begin_graph_parallel_context) + .def("end_graph_parallel_context", &ASTBuilder::end_graph_parallel_context) .def("set_graph_do_while_level_id", &ASTBuilder::set_graph_do_while_level_id) .def("begin_checkpoint", &ASTBuilder::begin_checkpoint) .def("end_checkpoint", &ASTBuilder::end_checkpoint); diff --git a/quadrants/rhi/cuda/cuda_driver_functions.inc.h b/quadrants/rhi/cuda/cuda_driver_functions.inc.h index 90361b2832..f22ff127ef 100644 --- a/quadrants/rhi/cuda/cuda_driver_functions.inc.h +++ b/quadrants/rhi/cuda/cuda_driver_functions.inc.h @@ -81,6 +81,7 @@ PER_CUDA_FUNCTION(import_external_semaphore, cuImportExternalSemaphore,CUexterna // Graph management PER_CUDA_FUNCTION(graph_create, cuGraphCreate, void **, uint32); PER_CUDA_FUNCTION(graph_add_kernel_node, cuGraphAddKernelNode, void **, void *, const void *, std::size_t, const void *); +PER_CUDA_FUNCTION(graph_add_empty_node, cuGraphAddEmptyNode, void **, void *, const void *, std::size_t); PER_CUDA_FUNCTION(graph_add_node, cuGraphAddNode, void **, void *, const void *, std::size_t, void *); PER_CUDA_FUNCTION(graph_instantiate, cuGraphInstantiate, void **, void *, void *, char *, std::size_t); PER_CUDA_FUNCTION(graph_launch, cuGraphLaunch, void *, void *); diff --git a/quadrants/runtime/cuda/graph_manager.cpp b/quadrants/runtime/cuda/graph_manager.cpp index 99c0844f2b..ca3e196b4c 100644 --- a/quadrants/runtime/cuda/graph_manager.cpp +++ b/quadrants/runtime/cuda/graph_manager.cpp @@ -287,6 +287,13 @@ void *GraphManager::add_kernel_node(void *graph, return node; } +void *GraphManager::add_empty_node(void *graph, const std::vector &deps) { + QD_ASSERT(!deps.empty()); + void *node = nullptr; + CUDADriver::get_instance().graph_add_empty_node(&node, graph, deps.data(), deps.size()); + return node; +} + unsigned long long GraphManager::create_cond_handle(void *graph) { void *cu_ctx = CUDAContext::get_instance().get_context(); unsigned long long handle = 0; @@ -397,6 +404,68 @@ void GraphManager::build_level(int parent_id, continue; } + // --- A qd.graph_parallel_context() fork/join region: a contiguous run of this level's direct, non-checkpoint + // tasks tagged with a nonzero stream_parallel_group_id (set by qd.graph_parallel()). Each distinct group id is one + // qd.graph_parallel section; the qd.graph_parallel sections fork from the region's entry (`prev_node`), run their + // tasks in order, and join into a single empty node so downstream work waits for all of them. CUDA's graph + // executor schedules the independent qd.graph_parallel section chains on separate streams -> real overlap. + // + // The run is bounded to a single region by graph_parallel_region_id: two qd.graph_parallel_context() regions + // written back-to-back (no serial task between them to break the run) carry distinct region ids, so each builds + // its own fork/join with its own join node. Without this guard the second region's sections would fork from the + // same entry as the first's and could run concurrently with -- and race -- the first region's work. --- + if (tasks[cursor].stream_parallel_group_id != 0 && tasks[cursor].checkpoint_id < 0) { + const int region_id = tasks[cursor].graph_parallel_region_id; + int run_end = cursor; + while (run_end < end && tasks[run_end].graph_do_while_level_id == parent_id && tasks[run_end].checkpoint_id < 0 && + tasks[run_end].stream_parallel_group_id != 0 && tasks[run_end].graph_parallel_region_id == region_id) { + run_end++; + } + // Bucket the run's tasks by qd.graph_parallel section id, preserving first-seen (declaration) order. + std::vector group_ids; + std::vector> parallel_sections; + for (int t = cursor; t < run_end; t++) { + const int g = tasks[t].stream_parallel_group_id; + int idx = -1; + for (int k = 0; k < (int)group_ids.size(); k++) { + if (group_ids[k] == g) { + idx = k; + break; + } + } + if (idx < 0) { + idx = (int)group_ids.size(); + group_ids.push_back(g); + parallel_sections.emplace_back(); + } + parallel_sections[idx].push_back(t); + } + void *ctx_ptr = &cached.persistent_ctx; + std::vector tails; + tails.reserve(parallel_sections.size()); + for (auto &ps : parallel_sections) { + void *bp = prev_node; // every qd.graph_parallel section forks from the region entry dependency + for (int t : ps) { + bp = add_kernel_node(target_graph, bp, cuda_module->lookup_function(tasks[t].name), + (unsigned int)tasks[t].grid_dim, (unsigned int)tasks[t].block_dim, + (unsigned int)tasks[t].dynamic_shared_array_bytes, &ctx_ptr); + ++total_nodes; + } + tails.push_back(bp); + } + // Join. A region with a single qd.graph_parallel section (e.g. an optional qd.graph_parallel section + // compiled out) has nothing to join, so just continue the chain from its tail; otherwise collect all + // tails into one empty successor node. + if (tails.size() == 1) { + prev_node = tails[0]; + } else { + prev_node = add_empty_node(target_graph, tails); + ++total_nodes; + } + cursor = run_end; + continue; + } + // --- A direct task of this level. Group consecutive tasks by checkpoint_id. --- const int cp = tasks[cursor].checkpoint_id; if (cp < 0) { diff --git a/quadrants/runtime/cuda/graph_manager.h b/quadrants/runtime/cuda/graph_manager.h index f637f62574..210504d2f6 100644 --- a/quadrants/runtime/cuda/graph_manager.h +++ b/quadrants/runtime/cuda/graph_manager.h @@ -201,6 +201,10 @@ class GraphManager { unsigned int block_dim, unsigned int shared_mem, void **kernel_params); + // Add an empty (no-op) node to `graph` depending on every node in `deps`. Used as the join point of a + // qd.graph_parallel_context() region: it has no work but collects all qd.graph_parallel section tails into + // a single successor so downstream nodes wait for every qd.graph_parallel section. `deps` must be non-empty. + void *add_empty_node(void *graph, const std::vector &deps); // Recursively build the nodes for graph_do_while level `parent_id` (-1 = kernel top level) over the task range // [begin, end) into `target_graph` (the body graph of `parent_id`, or the root graph for -1). Direct tasks become // kernel nodes; a contiguous run of direct tasks sharing a non-negative `checkpoint_id` is wrapped in a gate-kernel + diff --git a/quadrants/transforms/lower_ast.cpp b/quadrants/transforms/lower_ast.cpp index fe8084c6bc..31a5b29b21 100644 --- a/quadrants/transforms/lower_ast.cpp +++ b/quadrants/transforms/lower_ast.cpp @@ -232,6 +232,7 @@ class LowerAST : public IRVisitor { new_for->loop_name = stmt->loop_name; new_for->index_offsets = offsets; new_for->stream_parallel_group_id = stmt->stream_parallel_group_id; + new_for->graph_parallel_region_id = stmt->graph_parallel_region_id; new_for->checkpoint_id = stmt->checkpoint_id; new_for->graph_do_while_level_id = stmt->graph_do_while_level_id; VecStatement new_statements; @@ -269,6 +270,7 @@ class LowerAST : public IRVisitor { /*range_hint=*/fmt::format("arg ({})", fmt::join(arg_id, ", ")), /*loop_name=*/stmt->loop_name); new_for->stream_parallel_group_id = stmt->stream_parallel_group_id; + new_for->graph_parallel_region_id = stmt->graph_parallel_region_id; new_for->checkpoint_id = stmt->checkpoint_id; new_for->graph_do_while_level_id = stmt->graph_do_while_level_id; VecStatement new_statements; @@ -306,6 +308,7 @@ class LowerAST : public IRVisitor { stmt->strictly_serialized, /*range_hint=*/"", /*loop_name=*/stmt->loop_name); new_for->stream_parallel_group_id = stmt->stream_parallel_group_id; + new_for->graph_parallel_region_id = stmt->graph_parallel_region_id; new_for->checkpoint_id = stmt->checkpoint_id; new_for->graph_do_while_level_id = stmt->graph_do_while_level_id; new_for->body->insert(std::make_unique(new_for.get(), 0), 0); diff --git a/quadrants/transforms/offload.cpp b/quadrants/transforms/offload.cpp index 8baa775f14..36ec81ead3 100644 --- a/quadrants/transforms/offload.cpp +++ b/quadrants/transforms/offload.cpp @@ -109,6 +109,7 @@ class Offloader { const GraphRegionTag tag = bucket_has_side_effect ? bucket_tag : fallback_tag; pending_serial_statements->graph_do_while_level_id = tag.graph_do_while_level_id; pending_serial_statements->stream_parallel_group_id = tag.stream_parallel_group_id; + pending_serial_statements->graph_parallel_region_id = tag.graph_parallel_region_id; pending_serial_statements->checkpoint_id = tag.checkpoint_id; root_block->insert(std::move(pending_serial_statements)); pending_serial_statements = Stmt::make_typed(OffloadedStmt::TaskType::serial, arch, kernel); @@ -156,7 +157,9 @@ class Offloader { auto &stmt = root_statements[i]; // Note that stmt->parent is root_block, which doesn't contain stmt now. if (auto s = stmt->cast(); s && !s->strictly_serialized) { - assemble_serial_statements(GraphRegionTag{s->graph_do_while_level_id, s->stream_parallel_group_id}); + GraphRegionTag pre_for_tag{s->graph_do_while_level_id, s->stream_parallel_group_id}; + pre_for_tag.graph_parallel_region_id = s->graph_parallel_region_id; + assemble_serial_statements(pre_for_tag); auto offloaded = Stmt::make_typed(OffloadedStmt::TaskType::range_for, arch, kernel); // offloaded->body is an empty block now. offloaded->grid_dim = config.saturating_grid_dim; @@ -192,12 +195,15 @@ class Offloader { } offloaded->range_hint = s->range_hint; offloaded->stream_parallel_group_id = s->stream_parallel_group_id; + offloaded->graph_parallel_region_id = s->graph_parallel_region_id; offloaded->graph_do_while_level_id = s->graph_do_while_level_id; offloaded->checkpoint_id = s->checkpoint_id; offloaded->loop_name = s->loop_name; root_block->insert(std::move(offloaded)); } else if (auto st = stmt->cast()) { - assemble_serial_statements(GraphRegionTag{st->graph_do_while_level_id, st->stream_parallel_group_id}); + GraphRegionTag pre_for_tag{st->graph_do_while_level_id, st->stream_parallel_group_id}; + pre_for_tag.graph_parallel_region_id = st->graph_parallel_region_id; + assemble_serial_statements(pre_for_tag); emit_struct_for(st, root_block, config, st->mem_access_opt); } else if (auto st = stmt->cast()) { assemble_serial_statements(GraphRegionTag{st->graph_do_while_level_id, /*group=*/0}); @@ -309,6 +315,7 @@ class Offloader { offloaded_struct_for->num_cpu_threads = std::min(for_stmt->num_cpu_threads, config.cpu_max_num_threads); offloaded_struct_for->mem_access_opt = mem_access_opt; offloaded_struct_for->stream_parallel_group_id = for_stmt->stream_parallel_group_id; + offloaded_struct_for->graph_parallel_region_id = for_stmt->graph_parallel_region_id; offloaded_struct_for->graph_do_while_level_id = for_stmt->graph_do_while_level_id; offloaded_struct_for->checkpoint_id = for_stmt->checkpoint_id; offloaded_struct_for->loop_name = for_stmt->loop_name; diff --git a/tests/python/test_api.py b/tests/python/test_api.py index 824b6ea7a6..e6b316d911 100644 --- a/tests/python/test_api.py +++ b/tests/python/test_api.py @@ -157,6 +157,8 @@ def _get_expected_matrix_apis(): "global_thread_idx", "gpu", "graph_do_while", + "graph_parallel", + "graph_parallel_context", "grouped", "i", "i16", diff --git a/tests/python/test_graph_parallel.py b/tests/python/test_graph_parallel.py new file mode 100644 index 0000000000..b6f0e780d7 --- /dev/null +++ b/tests/python/test_graph_parallel.py @@ -0,0 +1,684 @@ +"""Tests for qd.graph_parallel_context / qd.graph_parallel -- concurrent fork/join sections in graph +kernels. + +`with qd.graph_parallel_context():` opens a fork/join region whose `with qd.graph_parallel():` members are +independent sequences of work. On the graph path the qd.graph_parallel sections become independent graph +chains joined by a single empty node, so the runtime schedules them on parallel streams; on other backends +(CPU / AMDGPU / Vulkan / Metal) they run serially but produce identical results. + +The behavioral assertions (disjoint-array correctness) hold on every backend. The graph-structure +assertions (node counts: one kernel node per qd.graph_parallel section task + one empty join node) only +apply where the builder forks/joins (CUDA today), so they are guarded by `_on_cuda()`. +""" + +import numpy as np +import pytest + +import quadrants as qd +from quadrants.lang import impl + +from tests import test_utils + + +def _on_cuda(): + return impl.current_cfg().arch == qd.cuda + + +def _platform_supports_graph(): + arch = impl.current_cfg().arch + return arch == qd.cuda or arch == qd.amdgpu + + +def _graph_num_nodes(): + return impl.get_runtime().prog.get_graph_num_nodes_on_last_call() + + +def _num_offloaded_tasks(): + return impl.get_runtime().prog.get_num_offloaded_tasks_on_last_call() + + +@test_utils.test() +def test_graph_parallel_is_no_op_outside_kernels(): + """At Python runtime (outside kernels) qd.graph_parallel_context / qd.graph_parallel must be usable no-op context + managers, so helpers that are sometimes called from Python and sometimes from kernels still import + and run. Mirrors qd.stream_parallel / qd.checkpoint.""" + sentinel = [] + with qd.graph_parallel_context(): + with qd.graph_parallel(): + sentinel.append("a") + with qd.graph_parallel(): + sentinel.append("b") + assert sentinel == ["a", "b"] + + +@test_utils.test() +def test_graph_parallel_two_sections(): + """Two qd.graph_parallel sections write disjoint arrays; a serial loop after the region reads both (so + it depends on the join). Results must match the serial reference on every backend; on CUDA the graph + has one node per task plus one empty join node.""" + n = 1024 + + @qd.kernel(graph=True) + def k( + x: qd.types.ndarray(qd.f32, ndim=1), + y: qd.types.ndarray(qd.f32, ndim=1), + z: qd.types.ndarray(qd.f32, ndim=1), + ): + with qd.graph_parallel_context(): + with qd.graph_parallel(): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + with qd.graph_parallel(): + for i in range(y.shape[0]): + y[i] = y[i] + 2.0 + for i in range(z.shape[0]): + z[i] = x[i] + y[i] + + x = qd.ndarray(qd.f32, shape=(n,)) + y = qd.ndarray(qd.f32, shape=(n,)) + z = qd.ndarray(qd.f32, shape=(n,)) + x.from_numpy(np.zeros(n, dtype=np.float32)) + y.from_numpy(np.zeros(n, dtype=np.float32)) + z.from_numpy(np.zeros(n, dtype=np.float32)) + + k(x, y, z) + + num_tasks = _num_offloaded_tasks() + if _on_cuda(): + # One graph node per offloaded task (each dynamic-bound loop is a bound-compute serial + a + # range_for, both in the qd.graph_parallel section) plus exactly one empty join node for the region. + assert _graph_num_nodes() == num_tasks + 1 + + np.testing.assert_allclose(x.to_numpy(), 1.0) + np.testing.assert_allclose(y.to_numpy(), 2.0) + np.testing.assert_allclose(z.to_numpy(), 3.0) + + # Relaunch: same cached graph, same result. + x.from_numpy(np.zeros(n, dtype=np.float32)) + y.from_numpy(np.zeros(n, dtype=np.float32)) + k(x, y, z) + np.testing.assert_allclose(z.to_numpy(), 3.0) + + +@test_utils.test() +def test_graph_parallel_back_to_back_regions_keep_join(): + """Two qd.graph_parallel_context() regions written back-to-back (no serial work between them) must each get + their own fork/join. Region B reads what region A wrote, so if the two regions were merged into one fork/join + (dropping A's join) B's sections would fork alongside -- and race -- A's. On CUDA we assert the graph has one + empty join node per region (two total); on every backend the post-join values must be correct. Regression test + for the back-to-back-region merge bug. + + Each region's two sections stay mutually disjoint (one touches only x, the other only y); the only data + dependency is across the region boundary (B's x-section reads A's x-section output, B's y-section reads A's + y-section output), which is exactly the edge the join protects.""" + n = 1024 + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=1), y: qd.types.ndarray(qd.f32, ndim=1)): + # Region A: write x and y in two independent sections. + with qd.graph_parallel_context(): + with qd.graph_parallel(): + for i in range(x.shape[0]): + x[i] = 1.0 + with qd.graph_parallel(): + for i in range(y.shape[0]): + y[i] = 2.0 + # Region B, immediately after A (no serial statement between the regions). Each section reads the region-A + # section that wrote the same array, so B must wait for A's join. B's sections remain disjoint from each + # other (x-only vs y-only). + with qd.graph_parallel_context(): + with qd.graph_parallel(): + for i in range(x.shape[0]): + x[i] = x[i] * 10.0 + with qd.graph_parallel(): + for i in range(y.shape[0]): + y[i] = y[i] * 10.0 + + x = qd.ndarray(qd.f32, shape=(n,)) + y = qd.ndarray(qd.f32, shape=(n,)) + x.from_numpy(np.zeros(n, dtype=np.float32)) + y.from_numpy(np.zeros(n, dtype=np.float32)) + + k(x, y) + + if _on_cuda(): + # One empty join node per region (two regions -> two joins). The merge bug this guards would build a single + # fork/join across both regions, emitting only one join, i.e. num_tasks + 1. + assert _graph_num_nodes() == _num_offloaded_tasks() + 2 + + # Correct only if region A fully joined before region B ran: x = 1 * 10, y = 2 * 10. + np.testing.assert_allclose(x.to_numpy(), 10.0) + np.testing.assert_allclose(y.to_numpy(), 20.0) + + # Relaunch: same cached graph, same result. + x.from_numpy(np.zeros(n, dtype=np.float32)) + y.from_numpy(np.zeros(n, dtype=np.float32)) + k(x, y) + np.testing.assert_allclose(x.to_numpy(), 10.0) + np.testing.assert_allclose(y.to_numpy(), 20.0) + + +@test_utils.test() +def test_graph_parallel_three_sections(): + """Fan-out of three independent qd.graph_parallel sections; one empty join node.""" + n = 256 + + @qd.kernel(graph=True) + def k( + a: qd.types.ndarray(qd.f32, ndim=1), + b: qd.types.ndarray(qd.f32, ndim=1), + c: qd.types.ndarray(qd.f32, ndim=1), + ): + with qd.graph_parallel_context(): + with qd.graph_parallel(): + for i in range(a.shape[0]): + a[i] = a[i] + 1.0 + with qd.graph_parallel(): + for i in range(b.shape[0]): + b[i] = b[i] + 2.0 + with qd.graph_parallel(): + for i in range(c.shape[0]): + c[i] = c[i] + 3.0 + + a = qd.ndarray(qd.f32, shape=(n,)) + b = qd.ndarray(qd.f32, shape=(n,)) + c = qd.ndarray(qd.f32, shape=(n,)) + for arr in (a, b, c): + arr.from_numpy(np.zeros(n, dtype=np.float32)) + + k(a, b, c) + num_tasks = _num_offloaded_tasks() + if _on_cuda(): + assert _graph_num_nodes() == num_tasks + 1 # three qd.graph_parallel sections + one join + + np.testing.assert_allclose(a.to_numpy(), 1.0) + np.testing.assert_allclose(b.to_numpy(), 2.0) + np.testing.assert_allclose(c.to_numpy(), 3.0) + + +@test_utils.test() +def test_graph_parallel_multi_loop_sections(): + """Each qd.graph_parallel section contains several loops; they must chain in order inside the + qd.graph_parallel section while the two qd.graph_parallel sections run independently. qd.graph_parallel + section tasks = 4, plus one join node on CUDA.""" + n = 128 + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=1), y: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel_context(): + with qd.graph_parallel(): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + for i in range(x.shape[0]): + x[i] = x[i] * 2.0 + with qd.graph_parallel(): + for i in range(y.shape[0]): + y[i] = y[i] + 3.0 + for i in range(y.shape[0]): + y[i] = y[i] * 4.0 + + x = qd.ndarray(qd.f32, shape=(n,)) + y = qd.ndarray(qd.f32, shape=(n,)) + x.from_numpy(np.zeros(n, dtype=np.float32)) + y.from_numpy(np.zeros(n, dtype=np.float32)) + + k(x, y) + num_tasks = _num_offloaded_tasks() + if _on_cuda(): + assert _graph_num_nodes() == num_tasks + 1 # all qd.graph_parallel section tasks + one join + + np.testing.assert_allclose(x.to_numpy(), 2.0) # (0+1)*2 + np.testing.assert_allclose(y.to_numpy(), 12.0) # (0+3)*4 + + +@test_utils.test() +def test_graph_parallel_single_section_no_join(): + """A region with a single qd.graph_parallel section (e.g. an optional qd.graph_parallel section compiled + out) needs no join: it degenerates to a plain chain, so the node count equals the number of + qd.graph_parallel section tasks (no extra empty node).""" + n = 256 + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel_context(): + with qd.graph_parallel(): + for i in range(x.shape[0]): + x[i] = x[i] + 5.0 + + x = qd.ndarray(qd.f32, shape=(n,)) + x.from_numpy(np.zeros(n, dtype=np.float32)) + + k(x) + num_tasks = _num_offloaded_tasks() + if _on_cuda(): + assert _graph_num_nodes() == num_tasks # single qd.graph_parallel section -> plain chain, no join + + np.testing.assert_allclose(x.to_numpy(), 5.0) + + +@test_utils.test() +def test_graph_parallel_optional_section_static_if(): + """The qipc ENABLE_EE pattern: a qd.graph_parallel section wrapped in `if qd.static(...)`. When the flag + is False the qd.graph_parallel section is compiled out (region has one qd.graph_parallel section -> no + join); when True both qd.graph_parallel sections run.""" + n = 128 + + @qd.kernel(graph=True) + def k_off(x: qd.types.ndarray(qd.f32, ndim=1), y: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel_context(): + with qd.graph_parallel(): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + if qd.static(False): + with qd.graph_parallel(): + for i in range(y.shape[0]): + y[i] = y[i] + 1.0 + + @qd.kernel(graph=True) + def k_on(x: qd.types.ndarray(qd.f32, ndim=1), y: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel_context(): + with qd.graph_parallel(): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + if qd.static(True): + with qd.graph_parallel(): + for i in range(y.shape[0]): + y[i] = y[i] + 1.0 + + x = qd.ndarray(qd.f32, shape=(n,)) + y = qd.ndarray(qd.f32, shape=(n,)) + x.from_numpy(np.zeros(n, dtype=np.float32)) + y.from_numpy(np.zeros(n, dtype=np.float32)) + k_off(x, y) + if _on_cuda(): + assert _graph_num_nodes() == _num_offloaded_tasks() # single qd.graph_parallel section -> no join + np.testing.assert_allclose(x.to_numpy(), 1.0) + np.testing.assert_allclose(y.to_numpy(), 0.0) # EE qd.graph_parallel section compiled out + + x.from_numpy(np.zeros(n, dtype=np.float32)) + y.from_numpy(np.zeros(n, dtype=np.float32)) + k_on(x, y) + if _on_cuda(): + assert _graph_num_nodes() == _num_offloaded_tasks() + 1 # two qd.graph_parallel sections + join + np.testing.assert_allclose(x.to_numpy(), 1.0) + np.testing.assert_allclose(y.to_numpy(), 1.0) + + +@test_utils.test() +def test_graph_parallel_inside_graph_do_while(): + """A fork/join region inside a qd.graph_do_while loop body must be correct across iterations: each + iteration runs both qd.graph_parallel sections, then decrements the counter.""" + n = 64 + iters = 5 + + @qd.kernel(graph=True) + def k( + x: qd.types.ndarray(qd.i32, ndim=1), + y: qd.types.ndarray(qd.i32, ndim=1), + counter: qd.types.ndarray(qd.i32, ndim=0), + ): + while qd.graph_do_while(counter): + with qd.graph_parallel_context(): + with qd.graph_parallel(): + for i in range(x.shape[0]): + x[i] = x[i] + 1 + with qd.graph_parallel(): + for i in range(y.shape[0]): + y[i] = y[i] + 2 + for _ in range(1): + counter[()] = counter[()] - 1 + + x = qd.ndarray(qd.i32, shape=(n,)) + y = qd.ndarray(qd.i32, shape=(n,)) + counter = qd.ndarray(qd.i32, shape=()) + x.from_numpy(np.zeros(n, dtype=np.int32)) + y.from_numpy(np.zeros(n, dtype=np.int32)) + counter.from_numpy(np.array(iters, dtype=np.int32)) + + k(x, y, counter) + + assert counter.to_numpy() == 0 + np.testing.assert_array_equal(x.to_numpy(), np.full(n, iters, dtype=np.int32)) + np.testing.assert_array_equal(y.to_numpy(), np.full(n, 2 * iters, dtype=np.int32)) + + +@test_utils.test() +def test_graph_parallel_outside_context_raises(): + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel(): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + + x = qd.ndarray(qd.f32, shape=(16,)) + with pytest.raises( + qd.QuadrantsSyntaxError, match="qd.graph_parallel.. can only be used .* inside a qd.graph_parallel_context" + ): + k(x) + + +@test_utils.test() +def test_graph_parallel_requires_graph_kernel(): + @qd.kernel + def k(x: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel_context(): + with qd.graph_parallel(): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + + x = qd.ndarray(qd.f32, shape=(16,)) + with pytest.raises(qd.QuadrantsSyntaxError, match="requires @qd.kernel.graph=True"): + k(x) + + +@test_utils.test() +def test_graph_parallel_context_non_graph_parallel_raises(): + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel_context(): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + + x = qd.ndarray(qd.f32, shape=(16,)) + with pytest.raises(qd.QuadrantsSyntaxError, match="may contain only .with qd.graph_parallel"): + k(x) + + +@test_utils.test() +def test_graph_parallel_nested_region_raises(): + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel_context(): + with qd.graph_parallel(): + with qd.graph_parallel_context(): + with qd.graph_parallel(): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + + x = qd.ndarray(qd.f32, shape=(16,)) + with pytest.raises(qd.QuadrantsSyntaxError): + k(x) + + +@test_utils.test() +def test_graph_parallel_static_loop_two_sections(): + """`for b in qd.static(range(NB))` unrolls into NB literal qd.graph_parallel sections, each writing a + disjoint row.""" + nb = 2 + n = 256 + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=2)): + with qd.graph_parallel_context(): + for b in qd.static(range(nb)): + with qd.graph_parallel(): + for i in range(x.shape[1]): + x[b, i] = x[b, i] + (b + 1) + + x = qd.ndarray(qd.f32, shape=(nb, n)) + x.from_numpy(np.zeros((nb, n), dtype=np.float32)) + + k(x) + num_tasks = _num_offloaded_tasks() + if _on_cuda(): + assert _graph_num_nodes() == num_tasks + 1 # nb qd.graph_parallel sections + one join + + out = x.to_numpy() + np.testing.assert_allclose(out[0], 1.0) + np.testing.assert_allclose(out[1], 2.0) + + +@test_utils.test() +def test_graph_parallel_static_loop_over_funcs(): + """The motivating pattern: a @qd.data_oriented class iterates a static list of @qd.func members, one + qd.graph_parallel section each (mirrors qipc's per-contact-type assembly funcs).""" + n = 4 + + @qd.data_oriented + class Demo: + def __init__(self): + self.a = qd.field(qd.i32, shape=(n,)) + self.b = qd.field(qd.i32, shape=(n,)) + self.funcs = [self._fill_a, self._fill_b] + + @qd.func + def _fill_a(self): + for i in range(n): + self.a[i] += 1 + + @qd.func + def _fill_b(self): + for i in range(n): + self.b[i] += 10 + + @qd.kernel(graph=True) + def step(self): + with qd.graph_parallel_context(): + for i in qd.static(range(len(self.funcs))): + with qd.graph_parallel(): + self.funcs[i]() + + d = Demo() + d.a.from_numpy(np.zeros(n, dtype=np.int32)) + d.b.from_numpy(np.zeros(n, dtype=np.int32)) + d.step() + np.testing.assert_array_equal(d.a.to_numpy(), np.ones(n, dtype=np.int32)) + np.testing.assert_array_equal(d.b.to_numpy(), np.full(n, 10, dtype=np.int32)) + + +@test_utils.test() +def test_graph_parallel_static_loop_single_section(): + """A static loop of one iteration is a single-section region: a plain chain, no join node.""" + n = 256 + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel_context(): + for _b in qd.static(range(1)): + with qd.graph_parallel(): + for i in range(x.shape[0]): + x[i] = x[i] + 5.0 + + x = qd.ndarray(qd.f32, shape=(n,)) + x.from_numpy(np.zeros(n, dtype=np.float32)) + + k(x) + num_tasks = _num_offloaded_tasks() + if _on_cuda(): + assert _graph_num_nodes() == num_tasks # single section -> no join node + + np.testing.assert_allclose(x.to_numpy(), 5.0) + + +@test_utils.test() +def test_graph_parallel_static_loop_empty_range(): + """An empty static range produces zero qd.graph_parallel sections: the region is a no-op (consistent + with wrapping the only section in `if qd.static(False)`). Serial work after it still runs.""" + n = 128 + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel_context(): + for _b in qd.static(range(0)): + with qd.graph_parallel(): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + for i in range(x.shape[0]): + x[i] = x[i] + 5.0 + + x = qd.ndarray(qd.f32, shape=(n,)) + x.from_numpy(np.zeros(n, dtype=np.float32)) + + k(x) + np.testing.assert_allclose(x.to_numpy(), 5.0) # region did nothing; only the serial +5 applied + + +@test_utils.test() +def test_graph_parallel_static_loop_nested(): + """Nested static loops fan out to N*M qd.graph_parallel sections, each writing a disjoint row.""" + ni, nj = 2, 2 + nrows = ni * nj + n = 64 + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=2)): + with qd.graph_parallel_context(): + for i in qd.static(range(ni)): + for j in qd.static(range(nj)): + with qd.graph_parallel(): + for c in range(x.shape[1]): + x[i * nj + j, c] = x[i * nj + j, c] + (i * nj + j + 1) + + x = qd.ndarray(qd.f32, shape=(nrows, n)) + x.from_numpy(np.zeros((nrows, n), dtype=np.float32)) + + k(x) + num_tasks = _num_offloaded_tasks() + if _on_cuda(): + assert _graph_num_nodes() == num_tasks + 1 # nrows qd.graph_parallel sections + one join + + out = x.to_numpy() + for r in range(nrows): + np.testing.assert_allclose(out[r], float(r + 1)) + + +@test_utils.test() +def test_graph_parallel_static_loop_mixed_with_static_if(): + """A static section loop and an `if qd.static(...)` optional qd.graph_parallel section coexist in one + region.""" + nb = 2 + n = 64 + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=2), y: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel_context(): + for b in qd.static(range(nb)): + with qd.graph_parallel(): + for i in range(x.shape[1]): + x[b, i] = x[b, i] + (b + 1) + if qd.static(True): + with qd.graph_parallel(): + for i in range(y.shape[0]): + y[i] = y[i] + 7.0 + + x = qd.ndarray(qd.f32, shape=(nb, n)) + y = qd.ndarray(qd.f32, shape=(n,)) + x.from_numpy(np.zeros((nb, n), dtype=np.float32)) + y.from_numpy(np.zeros(n, dtype=np.float32)) + + k(x, y) + num_tasks = _num_offloaded_tasks() + if _on_cuda(): + assert _graph_num_nodes() == num_tasks + 1 # nb + 1 qd.graph_parallel sections + one join + + out = x.to_numpy() + np.testing.assert_allclose(out[0], 1.0) + np.testing.assert_allclose(out[1], 2.0) + np.testing.assert_allclose(y.to_numpy(), 7.0) + + +@test_utils.test() +def test_graph_parallel_runtime_loop_raises(): + """A *runtime* for-loop in a region body stays rejected: only `qd.static(...)` loops unroll to literal + qd.graph_parallel sections; a runtime range would nest the section tagging inside a parallel range_for + (malformed).""" + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=2), nb: qd.i32): + with qd.graph_parallel_context(): + for b in range(nb): + with qd.graph_parallel(): + for i in range(x.shape[1]): + x[b, i] = x[b, i] + 1.0 + + x = qd.ndarray(qd.f32, shape=(2, 16)) + with pytest.raises(qd.QuadrantsSyntaxError, match="may contain only .with qd.graph_parallel"): + k(x, 2) + + +@test_utils.test() +def test_graph_parallel_takes_no_arguments(): + """qd.graph_parallel() (the section) takes no arguments. Any argument raises.""" + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel_context(): + with qd.graph_parallel(name="bx"): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + + x = qd.ndarray(qd.f32, shape=(16,)) + with pytest.raises(qd.QuadrantsSyntaxError, match="qd.graph_parallel.. takes no arguments"): + k(x) + + +@test_utils.test() +def test_graph_parallel_static_loop_body_non_section_raises(): + """A static loop body must still be section-only: serial work inside the loop (outside any + qd.graph_parallel section) would silently fall outside a section, so it is rejected (the validator + recurses into the loop body).""" + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=2)): + with qd.graph_parallel_context(): + for b in qd.static(range(2)): + x[b, 0] = 1.0 # serial work outside any qd.graph_parallel section + with qd.graph_parallel(): + for i in range(x.shape[1]): + x[b, i] = x[b, i] + 1.0 + + x = qd.ndarray(qd.f32, shape=(2, 16)) + with pytest.raises(qd.QuadrantsSyntaxError, match="may contain only .with qd.graph_parallel"): + k(x) + + +@test_utils.test() +def test_graph_parallel_static_loop_runtime_inner_loop_raises(): + """Staticness is re-checked at every nesting level: a *runtime* loop nested inside a static loop and + wrapping a qd.graph_parallel section is still rejected (only the static unroll yields independent + sections).""" + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=2), m: qd.i32): + with qd.graph_parallel_context(): + for b in qd.static(range(2)): + for _j in range(m): # runtime loop around a section -> rejected + with qd.graph_parallel(): + for i in range(x.shape[1]): + x[b, i] = x[b, i] + 1.0 + + x = qd.ndarray(qd.f32, shape=(2, 16)) + with pytest.raises(qd.QuadrantsSyntaxError, match="may contain only .with qd.graph_parallel"): + k(x, 2) + + +@test_utils.test() +def test_graph_parallel_static_loop_inside_graph_do_while(): + """A static section loop composes with qd.graph_do_while: each iteration runs all unrolled + qd.graph_parallel sections, then decrements the counter.""" + nb = 2 + n = 64 + iters = 4 + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.i32, ndim=2), counter: qd.types.ndarray(qd.i32, ndim=0)): + while qd.graph_do_while(counter): + with qd.graph_parallel_context(): + for b in qd.static(range(nb)): + with qd.graph_parallel(): + for i in range(x.shape[1]): + x[b, i] = x[b, i] + (b + 1) + for _ in range(1): + counter[()] = counter[()] - 1 + + x = qd.ndarray(qd.i32, shape=(nb, n)) + counter = qd.ndarray(qd.i32, shape=()) + x.from_numpy(np.zeros((nb, n), dtype=np.int32)) + counter.from_numpy(np.array(iters, dtype=np.int32)) + + k(x, counter) + + assert counter.to_numpy() == 0 + out = x.to_numpy() + np.testing.assert_array_equal(out[0], np.full(n, iters, dtype=np.int32)) + np.testing.assert_array_equal(out[1], np.full(n, 2 * iters, dtype=np.int32))