From 1954d03df5335f1dafbd5d13b0174c57caa581eb Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 9 Jun 2026 00:04:50 -0700 Subject: [PATCH 01/25] feat(graph): add qd.graph_parallel/qd.branch concurrent branches in graph kernels Lets a @qd.kernel(graph=True) author mark independent sequences of work as concurrent branches so the captured CUDA graph runs them on parallel streams (recovers the PT||EE overlap qipc loses vs cgq). API: `with qd.graph_parallel():` opens a fork/join region whose members are `with qd.branch(name=...):` blocks (name optional, label only). Region body must contain only branch blocks; branches are independent (author-guaranteed race-free) and everything after the region waits for all of them (join). Implementation reuses the existing stream_parallel_group_id tag: qd.branch lowers via begin/end_stream_parallel, so the branch id flows through the existing offload/codegen path to OffloadedTask. The only runtime change is in the CUDA graph builder (build_level): a contiguous run of nonzero-group, non-checkpoint tasks is forked by group id from the region entry, each branch chained, and all tails joined into one cuGraphAddEmptyNode. Single-branch regions degenerate to a plain chain. Validation: graph=True required; branch only inside a region; region body must be branches only; no nesting; graph_do_while structure validator updated to accept regions. Other backends (CPU/AMDGPU/Vulkan/Metal) run branches serially (correct); CUDA graph path runs them concurrently. Design: perso_hugh/doc/qipc/d3_0_graph_parallel_impl.md --- python/quadrants/lang/ast/ast_transformer.py | 107 +++++++++++++++++- .../function_def_transformer.py | 41 +++++++ python/quadrants/lang/misc.py | 51 +++++++++ .../rhi/cuda/cuda_driver_functions.inc.h | 1 + quadrants/runtime/cuda/graph_manager.cpp | 62 ++++++++++ quadrants/runtime/cuda/graph_manager.h | 4 + 6 files changed, 265 insertions(+), 1 deletion(-) diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 0297fffa8c..2800738e80 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -1372,6 +1372,49 @@ def _is_checkpoint_call(node: ast.expr, global_vars: dict): ``CheckpointCallInfo`` or ``None``.""" return CheckpointTransformer.is_checkpoint_call(node, global_vars) + @staticmethod + def _is_graph_parallel_call(node: ast.expr) -> bool: + """If *node* is a ``qd.graph_parallel()`` call return True, else False.""" + if not isinstance(node, ast.Call): + return False + func = node.func + is_gp = (isinstance(func, ast.Attribute) and func.attr == "graph_parallel") or ( + isinstance(func, ast.Name) and func.id == "graph_parallel" + ) + if not is_gp: + return False + if node.args or node.keywords: + raise QuadrantsSyntaxError("qd.graph_parallel() takes no arguments") + return True + + @staticmethod + def _is_branch_call(node: ast.expr) -> tuple[bool, str | None]: + """If *node* is ``qd.branch(...)`` return ``(True, name)``; otherwise ``(False, None)``. + + ``name`` is the value of the optional ``name=`` kwarg (a string literal) or ``None``. The call + shape is validated here so misuse raises at the ``with`` site rather than later. + """ + if not isinstance(node, ast.Call): + return False, None + func = node.func + is_branch = (isinstance(func, ast.Attribute) and func.attr == "branch") or ( + isinstance(func, ast.Name) and func.id == "branch" + ) + if not is_branch: + return False, None + if node.args: + raise QuadrantsSyntaxError("qd.branch() takes no positional arguments; use qd.branch(name='...') instead") + name: str | None = None + for kw in node.keywords: + if kw.arg != "name": + raise QuadrantsSyntaxError( + f"qd.branch() got unexpected keyword argument {kw.arg!r}; only 'name' is supported" + ) + if not (isinstance(kw.value, ast.Constant) and isinstance(kw.value.value, str)): + raise QuadrantsSyntaxError("qd.branch(name=...) must be a string literal") + name = kw.value.value + return True, name + @staticmethod def build_While(ctx: ASTTransformerFuncContext, node: ast.While) -> None: if node.orelse: @@ -1615,9 +1658,17 @@ def build_With(ctx: ASTTransformerFuncContext, node: ast.With) -> None: if checkpoint_info is not None: return ASTTransformer._build_checkpoint_with(ctx, node, checkpoint_info) + if ASTTransformer._is_graph_parallel_call(item.context_expr): + return ASTTransformer._build_graph_parallel_with(ctx, node) + + is_branch, branch_name = ASTTransformer._is_branch_call(item.context_expr) + if is_branch: + return ASTTransformer._build_branch_with(ctx, node, branch_name) + if not FunctionDefTransformer._is_stream_parallel_with(node, ctx.global_vars): raise QuadrantsSyntaxError( - "'with' in Quadrants kernels only supports qd.stream_parallel() or qd.checkpoint()" + "'with' in Quadrants kernels only supports qd.stream_parallel(), qd.checkpoint(), " + "qd.graph_parallel(), or qd.branch()" ) if not ctx.is_kernel: raise QuadrantsSyntaxError("qd.stream_parallel() can only be used inside @qd.kernel, not @qd.func") @@ -1636,6 +1687,60 @@ def _build_checkpoint_with( ``ast_transformers/checkpoint_transformer.py``.""" return CheckpointTransformer.build_checkpoint_with(ctx, node, info, build_stmts) + @staticmethod + def _build_graph_parallel_with(ctx: ASTTransformerFuncContext, node: ast.With) -> None: + """Handles ``with qd.graph_parallel():`` fork/join regions. + + Validates the use-site (kernel must be graph=True, no nesting) and that the region body contains + only ``with qd.branch():`` blocks, then walks the body. The region emits no IR tag of its own -- + each ``branch`` inside lowers to a stream-parallel group (via begin/end_stream_parallel), and the + CUDA graph builder forks the distinct groups in a contiguous run and joins them. Regions are kept + apart by the serial work between them (see d3_0_graph_parallel_impl.md).""" + if not ctx.is_kernel: + raise QuadrantsSyntaxError("qd.graph_parallel() can only be used inside @qd.kernel, not @qd.func") + kernel = ctx.global_context.current_kernel + if kernel is None or not kernel.use_graph: + raise QuadrantsSyntaxError("qd.graph_parallel() requires @qd.kernel(graph=True)") + if getattr(ctx, "_in_graph_parallel", False): + raise QuadrantsSyntaxError("qd.graph_parallel() regions cannot be nested") + if getattr(ctx, "_in_branch", False): + raise QuadrantsSyntaxError("qd.graph_parallel() cannot appear inside a qd.branch() body") + for i, stmt in enumerate(node.body): + if FunctionDefTransformer._is_docstring(stmt, i) or FunctionDefTransformer._is_coverage_probe(stmt): + continue + is_branch = False + if isinstance(stmt, ast.With) and stmt.items: + is_branch, _ = ASTTransformer._is_branch_call(stmt.items[0].context_expr) + if not is_branch: + raise QuadrantsSyntaxError( + "A qd.graph_parallel() region may contain only 'with qd.branch():' blocks " + f"[offending stmt {i}: {type(stmt).__name__}]" + ) + ctx._in_graph_parallel = True + try: + build_stmts(ctx, node.body) + finally: + ctx._in_graph_parallel = False + return None + + @staticmethod + def _build_branch_with(ctx: ASTTransformerFuncContext, node: ast.With, name: str | None) -> None: + """Handles ``with qd.branch():`` members of a ``qd.graph_parallel()`` region. + + Reuses the stream-parallel tagging: begin_stream_parallel() assigns this branch a fresh + ``stream_parallel_group_id`` that every for-loop in the body inherits, so the offloaded tasks + carry the branch id all the way to the graph builder. ``name`` is currently a label only.""" + if not getattr(ctx, "_in_graph_parallel", False): + raise QuadrantsSyntaxError("qd.branch() can only be used directly inside a qd.graph_parallel() region") + ctx._in_branch = True + ctx.ast_builder.begin_stream_parallel() + try: + build_stmts(ctx, node.body) + finally: + ctx.ast_builder.end_stream_parallel() + ctx._in_branch = False + return None + @staticmethod def build_Pass(ctx: ASTTransformerFuncContext, node: ast.Pass) -> None: return None diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 4e7b8e5154..5fc8cad256 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -600,6 +600,38 @@ def _is_loop_config_call(stmt: ast.stmt) -> bool: return True return False + @staticmethod + def _is_graph_parallel_with(stmt: ast.stmt) -> bool: + """Syntactic check matching ASTTransformer._is_graph_parallel_call: a + ``with qd.graph_parallel():`` fork/join region.""" + if not isinstance(stmt, ast.With) or len(stmt.items) != 1: + return False + ctx_expr = stmt.items[0].context_expr + if not isinstance(ctx_expr, ast.Call): + return False + func = ctx_expr.func + if isinstance(func, ast.Attribute) and func.attr == "graph_parallel": + return True + if isinstance(func, ast.Name) and func.id == "graph_parallel": + return True + return False + + @staticmethod + def _is_branch_with(stmt: ast.stmt) -> bool: + """Syntactic check matching ASTTransformer._is_branch_call: a ``with qd.branch(...):`` member + of a ``qd.graph_parallel()`` region.""" + if not isinstance(stmt, ast.With) or len(stmt.items) != 1: + return False + ctx_expr = stmt.items[0].context_expr + if not isinstance(ctx_expr, ast.Call): + return False + func = ctx_expr.func + if isinstance(func, ast.Attribute) and func.attr == "branch": + return True + if isinstance(func, ast.Name) and func.id == "branch": + return True + return False + @staticmethod def _validate_graph_do_while_structure(body: list[ast.stmt]) -> None: """If a kernel uses qd.graph_do_while() anywhere, enforce the structural rules that remain after @@ -659,6 +691,15 @@ def _validate_graph_do_while_stmt_list(stmts: list[ast.stmt], is_kernel_top: boo # `CheckpointTransformer.build_checkpoint_with`. FunctionDefTransformer._validate_graph_do_while_stmt_list(stmt.body, is_kernel_top=is_kernel_top) continue + if FunctionDefTransformer._is_graph_parallel_with(stmt): + # A `with qd.graph_parallel()` region groups concurrent `with qd.branch()` members; it is + # a legal sibling of for-loops / checkpoints. Its body must be branch blocks only (enforced + # fully in ASTTransformer._build_graph_parallel_with); each branch body is task territory, + # validated with the in-loop rules. + for member in stmt.body: + if FunctionDefTransformer._is_branch_with(member): + FunctionDefTransformer._validate_graph_do_while_stmt_list(member.body, is_kernel_top=False) + continue where = "the kernel body" if is_kernel_top else "a qd.graph_do_while() body" raise QuadrantsSyntaxError( f"When a kernel uses qd.graph_do_while(), {where} may not contain a {type(stmt).__name__} " diff --git a/python/quadrants/lang/misc.py b/python/quadrants/lang/misc.py index 6c45b1f1dc..3bfc147ce9 100644 --- a/python/quadrants/lang/misc.py +++ b/python/quadrants/lang/misc.py @@ -751,6 +751,55 @@ def graph_do_while(condition) -> bool: return bool(condition) +@contextmanager +def graph_parallel(): + """Opens a fork/join region whose ``qd.branch()`` members run concurrently. + + Used as ``with qd.graph_parallel():`` inside a ``@qd.kernel(graph=True)`` kernel. The region's body + must contain only ``with qd.branch():`` blocks. Each branch is an independent sequence of work; the + branches have no ordering relative to each other and may execute concurrently, while everything after + the region waits for *all* branches to finish (the join). This is the CUDA-graph analogue of + ``qd.stream_parallel()`` (which is for non-graph kernels): it lets independent stages -- e.g. qipc's + point-triangle and edge-edge assembly -- overlap inside a captured graph. + + Concurrency contract (the author's responsibility): branches must be data-race free with respect to + one another (no branch reads what another writes, no two branches write the same location). Calls + *within* a branch keep their program order. + + Backend behaviour: + - CUDA SM graph path: branches become independent graph chains joined by an empty node, so the + runtime schedules them on parallel streams (real overlap). + - CPU / Vulkan / Metal / AMDGPU graph: correct results, branches run serially (the concurrency + tags are honoured only by the CUDA graph builder today). + + Restrictions (enforced at kernel compile time): + - Must be used inside ``@qd.kernel(graph=True)``. + - The region body may contain only ``with qd.branch():`` blocks. + - Regions cannot be nested, and a branch body must be straight-line task work (no nested + ``qd.graph_do_while``, ``qd.checkpoint``, or ``qd.graph_parallel``). + + This function should not be called directly at runtime; it is recognised and transformed during AST + compilation. At Python runtime (outside kernels) it is a no-op context manager. + + See also ``docs/source/user_guide/graph.md``. + """ + yield + + +@contextmanager +def branch(name=None): + """Declares one concurrent member of an enclosing ``qd.graph_parallel()`` region. + + Used as ``with qd.branch():`` or ``with qd.branch(name="pt"):`` directly inside a + ``with qd.graph_parallel():`` block. The branch's body is an independent sequence of work that may + run concurrently with the region's other branches. ``name`` is optional and used only as a label for + profiling / graph introspection. + + See ``qd.graph_parallel()`` for the full contract and backend behaviour. + """ + yield + + def global_thread_idx(): """Returns the global thread id of this running thread, only available for cpu and cuda backends. @@ -890,6 +939,8 @@ def dump_compile_config() -> None: "GraphStatus", "checkpoint", "graph_do_while", + "graph_parallel", + "branch", "loop_config", "global_thread_idx", "assume_in_range", diff --git a/quadrants/rhi/cuda/cuda_driver_functions.inc.h b/quadrants/rhi/cuda/cuda_driver_functions.inc.h index 90361b2832..f22ff127ef 100644 --- a/quadrants/rhi/cuda/cuda_driver_functions.inc.h +++ b/quadrants/rhi/cuda/cuda_driver_functions.inc.h @@ -81,6 +81,7 @@ PER_CUDA_FUNCTION(import_external_semaphore, cuImportExternalSemaphore,CUexterna // Graph management PER_CUDA_FUNCTION(graph_create, cuGraphCreate, void **, uint32); PER_CUDA_FUNCTION(graph_add_kernel_node, cuGraphAddKernelNode, void **, void *, const void *, std::size_t, const void *); +PER_CUDA_FUNCTION(graph_add_empty_node, cuGraphAddEmptyNode, void **, void *, const void *, std::size_t); PER_CUDA_FUNCTION(graph_add_node, cuGraphAddNode, void **, void *, const void *, std::size_t, void *); PER_CUDA_FUNCTION(graph_instantiate, cuGraphInstantiate, void **, void *, void *, char *, std::size_t); PER_CUDA_FUNCTION(graph_launch, cuGraphLaunch, void *, void *); diff --git a/quadrants/runtime/cuda/graph_manager.cpp b/quadrants/runtime/cuda/graph_manager.cpp index 99c0844f2b..7827323c7c 100644 --- a/quadrants/runtime/cuda/graph_manager.cpp +++ b/quadrants/runtime/cuda/graph_manager.cpp @@ -287,6 +287,13 @@ void *GraphManager::add_kernel_node(void *graph, return node; } +void *GraphManager::add_empty_node(void *graph, const std::vector &deps) { + QD_ASSERT(!deps.empty()); + void *node = nullptr; + CUDADriver::get_instance().graph_add_empty_node(&node, graph, deps.data(), deps.size()); + return node; +} + unsigned long long GraphManager::create_cond_handle(void *graph) { void *cu_ctx = CUDAContext::get_instance().get_context(); unsigned long long handle = 0; @@ -397,6 +404,61 @@ void GraphManager::build_level(int parent_id, continue; } + // --- A qd.graph_parallel() fork/join region: a contiguous run of this level's direct, + // non-checkpoint tasks tagged with a nonzero stream_parallel_group_id (set by qd.branch()). Each + // distinct group id is one branch; branches fork from the region's entry (`prev_node`), run their + // tasks in order, and join into a single empty node so downstream work waits for all of them. CUDA's + // graph executor schedules the independent branch chains on separate streams -> real overlap. --- + if (tasks[cursor].stream_parallel_group_id != 0 && tasks[cursor].checkpoint_id < 0) { + int run_end = cursor; + while (run_end < end && tasks[run_end].graph_do_while_level_id == parent_id && + tasks[run_end].checkpoint_id < 0 && tasks[run_end].stream_parallel_group_id != 0) { + run_end++; + } + // Bucket the run's tasks by branch id, preserving first-seen (declaration) order. + std::vector group_ids; + std::vector> branches; + for (int t = cursor; t < run_end; t++) { + const int g = tasks[t].stream_parallel_group_id; + int idx = -1; + for (int k = 0; k < (int)group_ids.size(); k++) { + if (group_ids[k] == g) { + idx = k; + break; + } + } + if (idx < 0) { + idx = (int)group_ids.size(); + group_ids.push_back(g); + branches.emplace_back(); + } + branches[idx].push_back(t); + } + void *ctx_ptr = &cached.persistent_ctx; + std::vector tails; + tails.reserve(branches.size()); + for (auto &br : branches) { + void *bp = prev_node; // every branch forks from the region entry dependency + for (int t : br) { + bp = add_kernel_node(target_graph, bp, cuda_module->lookup_function(tasks[t].name), + (unsigned int)tasks[t].grid_dim, (unsigned int)tasks[t].block_dim, + (unsigned int)tasks[t].dynamic_shared_array_bytes, &ctx_ptr); + ++total_nodes; + } + tails.push_back(bp); + } + // Join. A single-branch region (e.g. an optional branch compiled out) has nothing to join, so just + // continue the chain from its tail; otherwise collect all tails into one empty successor node. + if (tails.size() == 1) { + prev_node = tails[0]; + } else { + prev_node = add_empty_node(target_graph, tails); + ++total_nodes; + } + cursor = run_end; + continue; + } + // --- A direct task of this level. Group consecutive tasks by checkpoint_id. --- const int cp = tasks[cursor].checkpoint_id; if (cp < 0) { diff --git a/quadrants/runtime/cuda/graph_manager.h b/quadrants/runtime/cuda/graph_manager.h index f637f62574..83506587a2 100644 --- a/quadrants/runtime/cuda/graph_manager.h +++ b/quadrants/runtime/cuda/graph_manager.h @@ -201,6 +201,10 @@ class GraphManager { unsigned int block_dim, unsigned int shared_mem, void **kernel_params); + // Add an empty (no-op) node to `graph` depending on every node in `deps`. Used as the join point of a + // qd.graph_parallel() region: it has no work but collects all branch tails into a single successor so + // downstream nodes wait for every branch. `deps` must be non-empty. + void *add_empty_node(void *graph, const std::vector &deps); // Recursively build the nodes for graph_do_while level `parent_id` (-1 = kernel top level) over the task range // [begin, end) into `target_graph` (the body graph of `parent_id`, or the root graph for -1). Direct tasks become // kernel nodes; a contiguous run of direct tasks sharing a non-negative `checkpoint_id` is wrapped in a gate-kernel + From 616658c82737d4d883b9f4efcc749110c546b7df Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 9 Jun 2026 00:15:24 -0700 Subject: [PATCH 02/25] test(graph): add qd.graph_parallel tests + docs; allow if-static optional branches - Recursive region-body validation: a graph_parallel region may contain branch blocks optionally wrapped in `if qd.static(...)` (qipc ENABLE_EE pattern); graph_do_while validator descends through those ifs too. - tests/python/test_graph_parallel.py: correctness vs serial (2/3/multi-loop branches), single-branch no-join, optional static-if branch, region inside graph_do_while, and compile-time error cases. Node-count assertions on CUDA. - docs: graph.md "Concurrent branches" section + backend table row; streams.md cross-link to graph_parallel for graph kernels. --- docs/source/user_guide/graph.md | 47 +++ docs/source/user_guide/streams.md | 2 + python/quadrants/lang/ast/ast_transformer.py | 37 +- .../function_def_transformer.py | 14 +- tests/python/test_graph_parallel.py | 341 ++++++++++++++++++ 5 files changed, 426 insertions(+), 15 deletions(-) create mode 100644 tests/python/test_graph_parallel.py diff --git a/docs/source/user_guide/graph.md b/docs/source/user_guide/graph.md index b1c239b080..7f7c3f4ac6 100644 --- a/docs/source/user_guide/graph.md +++ b/docs/source/user_guide/graph.md @@ -11,6 +11,7 @@ Graphs reduce kernel launch overhead by capturing a sequence of GPU operations i | `graph=True` | hardware accelerated | hardware accelerated | hardware accelerated | runs (no acceleration) | runs (no acceleration) | runs (no acceleration) | | `qd.graph_do_while` | hardware accelerated | host fallback | host fallback | host fallback | host fallback | host fallback | | `qd.checkpoint` | GPU-side | GPU-side | GPU-side | GPU-side | GPU-side | host-side | +| `qd.graph_parallel` / `qd.branch` (concurrent branches) | concurrent (parallel streams) | concurrent (parallel streams) | runs serially (correct) | runs serially (correct) | runs serially (correct) | runs serially (correct) | AMDGPU `graph_do_while` falls back to the host-side loop because HIP does not currently expose conditional / while graph nodes (as of ROCm 7.2). @@ -466,3 +467,49 @@ In this case, our recommendation is: - this will ensure your code is compact and maintainable - if you need optimum 100% performance on unsupported platforms, then consider PRing onto quadrants an optimized graph implementation for your target platform - for example it could somehow run MAX_ITER iterations anyway, similar to the earlier hand-rolled version, but via the graph abstraction, hence allowing the code to be compact, cross-platform, and also optimally fast + +## Concurrent branches with `qd.graph_parallel` *(experimental)* + +`qd.checkpoint` and `graph_do_while` change *which* kernels run and *how many times*; `qd.graph_parallel` changes *how* a graph's kernels are scheduled relative to each other. By default the kernels captured in a `graph=True` kernel run as a single dependency chain (each waits for the previous one), even when they are completely independent. A `with qd.graph_parallel():` region lets you declare independent stages so the CUDA graph runs them on **parallel streams**. + +This is the graph-compatible analogue of [`qd.stream_parallel()`](streams.md) (which only works for non-graph kernels): both express "these sequences are independent, run them concurrently", but `graph_parallel` is honoured by the CUDA graph builder so it composes with `graph=True` and `graph_do_while`. + +```python +@qd.kernel(graph=True) +def step(...): + while qd.graph_do_while(ncond): + assemble_shared(...) # serial: feeds both branches + + with qd.graph_parallel(): # fork: branches run concurrently + with qd.branch(name="pt"): # point-triangle contacts + pt_assemble(...) + pt_hessian(...) + with qd.branch(name="ee"): # edge-edge contacts (independent of pt) + ee_assemble(...) + ee_hessian(...) + # join: everything below waits for BOTH branches to finish + merge_hessians(...) + precondition(...) +``` + +### Semantics + +- **Fork / join.** Every `qd.branch()` in the region forks from the work that precedes the region. All branches must finish before any work *after* the region begins (the join). On CUDA the join is a single empty graph node depending on every branch's last kernel. +- **Branches are independent — you guarantee it.** Calls *within* a branch keep their program order, but calls in *different* branches have no ordering. The branches must be data-race free with respect to one another: no branch may read what another writes, and no two branches may write the same memory. Quadrants does not check this; getting it wrong gives nondeterministic results, exactly like `qd.stream_parallel()`. +- **`name=` is optional** and used only as a label for profiling / graph introspection. + +### Restrictions (enforced at kernel compile time) + +- Must be used inside `@qd.kernel(graph=True)`. +- A region body may contain only `with qd.branch():` blocks, optionally wrapped in `if qd.static(...)` (so an optional branch can be compiled in or out — e.g. enabling edge-edge contacts only when a feature flag is set). A single-branch region is allowed and lowers to a plain chain (no fork/join overhead). +- `qd.branch()` may appear only directly inside a `qd.graph_parallel()` region. +- Regions cannot be nested, and a branch body must be straight-line task work — no `qd.graph_do_while`, `qd.checkpoint`, or nested `qd.graph_parallel` inside a branch (a region may, however, sit inside a `qd.graph_do_while` body, as shown above). + +### Backend behaviour + +| backend | result | scheduling | +| --- | --- | --- | +| CUDA (graph path) | correct | branches run **concurrently** on parallel streams | +| AMDGPU / CPU / Vulkan / Metal | correct | branches run **serially** (the concurrency tags are honoured only by the CUDA graph builder today) | + +Because branches are independent by construction, running them serially on the other backends produces identical results — only the scheduling differs. `qd.graph_parallel` lowers onto the same internal concurrency-group mechanism as `qd.stream_parallel`, so non-graph fallbacks also fork the branches across streams. diff --git a/docs/source/user_guide/streams.md b/docs/source/user_guide/streams.md index a8db331bcc..24bff585d9 100644 --- a/docs/source/user_guide/streams.md +++ b/docs/source/user_guide/streams.md @@ -48,6 +48,8 @@ combine() # runs after compute_ab() returns — a[] and b[] are ready Consecutive `with qd.stream_parallel():` blocks run concurrently. Multiple for loops within a single block share a stream and run serially on it. All streams are synchronized before the kernel returns. +> **For `graph=True` kernels**, use [`qd.graph_parallel` / `qd.branch`](graph.md#concurrent-branches-with-qdgraph_parallel-experimental) instead — `stream_parallel` is not compatible with graphs (see [Limitations](#limitations)). `graph_parallel` expresses the same "run these independent sequences concurrently" idea but is honoured by the CUDA graph builder. + ### Restrictions - All top-level statements in a kernel must be either all `stream_parallel` blocks or all regular statements. Mixing the two at the top level is a compile-time error. diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 2800738e80..85653c24f9 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -1705,17 +1705,7 @@ def _build_graph_parallel_with(ctx: ASTTransformerFuncContext, node: ast.With) - raise QuadrantsSyntaxError("qd.graph_parallel() regions cannot be nested") if getattr(ctx, "_in_branch", False): raise QuadrantsSyntaxError("qd.graph_parallel() cannot appear inside a qd.branch() body") - for i, stmt in enumerate(node.body): - if FunctionDefTransformer._is_docstring(stmt, i) or FunctionDefTransformer._is_coverage_probe(stmt): - continue - is_branch = False - if isinstance(stmt, ast.With) and stmt.items: - is_branch, _ = ASTTransformer._is_branch_call(stmt.items[0].context_expr) - if not is_branch: - raise QuadrantsSyntaxError( - "A qd.graph_parallel() region may contain only 'with qd.branch():' blocks " - f"[offending stmt {i}: {type(stmt).__name__}]" - ) + ASTTransformer._validate_graph_parallel_body(node.body) ctx._in_graph_parallel = True try: build_stmts(ctx, node.body) @@ -1723,6 +1713,31 @@ def _build_graph_parallel_with(ctx: ASTTransformerFuncContext, node: ast.With) - ctx._in_graph_parallel = False return None + @staticmethod + def _validate_graph_parallel_body(stmts: list[ast.stmt]) -> None: + """A qd.graph_parallel() region body may contain only `with qd.branch():` blocks, optionally + wrapped in compile-time `if qd.static(...)` branches (the optional-branch pattern, e.g. qipc's + ENABLE_EE). Docstrings / coverage probes / `pass` are allowed. Anything else (a bare for-loop, + assignment, etc.) is a serial task that would silently fall outside any branch, so reject it.""" + for i, stmt in enumerate(stmts): + if FunctionDefTransformer._is_docstring(stmt, i) or FunctionDefTransformer._is_coverage_probe(stmt): + continue + if isinstance(stmt, ast.Pass): + continue + if isinstance(stmt, ast.With) and stmt.items: + is_branch, _ = ASTTransformer._is_branch_call(stmt.items[0].context_expr) + if is_branch: + continue + if isinstance(stmt, ast.If): + ASTTransformer._validate_graph_parallel_body(stmt.body) + ASTTransformer._validate_graph_parallel_body(stmt.orelse) + continue + raise QuadrantsSyntaxError( + "A qd.graph_parallel() region may contain only 'with qd.branch():' blocks (optionally " + "inside 'if qd.static(...)'). Move other work outside the region. " + f"[offending stmt {i}: {type(stmt).__name__}]" + ) + @staticmethod def _build_branch_with(ctx: ASTTransformerFuncContext, node: ast.With, name: str | None) -> None: """Handles ``with qd.branch():`` members of a ``qd.graph_parallel()`` region. diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 5fc8cad256..8e1d306790 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -693,12 +693,18 @@ def _validate_graph_do_while_stmt_list(stmts: list[ast.stmt], is_kernel_top: boo continue if FunctionDefTransformer._is_graph_parallel_with(stmt): # A `with qd.graph_parallel()` region groups concurrent `with qd.branch()` members; it is - # a legal sibling of for-loops / checkpoints. Its body must be branch blocks only (enforced - # fully in ASTTransformer._build_graph_parallel_with); each branch body is task territory, - # validated with the in-loop rules. - for member in stmt.body: + # a legal sibling of for-loops / checkpoints. Its body must be branch blocks (optionally + # under `if qd.static(...)`); the full check is in ASTTransformer._build_graph_parallel_with. + # Each branch body is task territory, validated here with the in-loop rules. Descend through + # `if` members so branches inside an optional-branch `if qd.static(...)` are reached too. + pending = list(stmt.body) + while pending: + member = pending.pop() if FunctionDefTransformer._is_branch_with(member): FunctionDefTransformer._validate_graph_do_while_stmt_list(member.body, is_kernel_top=False) + elif isinstance(member, ast.If): + pending.extend(member.body) + pending.extend(member.orelse) continue where = "the kernel body" if is_kernel_top else "a qd.graph_do_while() body" raise QuadrantsSyntaxError( diff --git a/tests/python/test_graph_parallel.py b/tests/python/test_graph_parallel.py new file mode 100644 index 0000000000..380e0e1b68 --- /dev/null +++ b/tests/python/test_graph_parallel.py @@ -0,0 +1,341 @@ +"""Tests for qd.graph_parallel / qd.branch -- concurrent fork/join branches in graph kernels. + +`with qd.graph_parallel():` opens a fork/join region whose `with qd.branch():` members are independent +sequences of work. On the CUDA graph path the branches become independent graph chains joined by a single +empty node, so the runtime schedules them on parallel streams; on other backends (CPU / AMDGPU / Vulkan / +Metal) they run serially but produce identical results. + +The behavioural assertions (disjoint-array correctness) hold on every backend. The graph-structure +assertions (node counts: one kernel node per branch task + one empty join node) only apply on the CUDA +graph path, where the builder forks/joins; they are guarded by `_on_cuda()`. +""" + +import numpy as np +import pytest + +import quadrants as qd +from quadrants.lang import impl + +from tests import test_utils + + +def _on_cuda(): + return impl.current_cfg().arch == qd.cuda + + +def _platform_supports_graph(): + arch = impl.current_cfg().arch + return arch == qd.cuda or arch == qd.amdgpu + + +def _graph_num_nodes(): + return impl.get_runtime().prog.get_graph_num_nodes_on_last_call() + + +def _num_offloaded_tasks(): + return impl.get_runtime().prog.get_num_offloaded_tasks_on_last_call() + + +@test_utils.test() +def test_graph_parallel_is_no_op_outside_kernels(): + """At Python runtime (outside kernels) qd.graph_parallel / qd.branch must be usable no-op context + managers, so helpers that are sometimes called from Python and sometimes from kernels still import + and run. Mirrors qd.stream_parallel / qd.checkpoint.""" + sentinel = [] + with qd.graph_parallel(): + with qd.branch(): + sentinel.append("a") + with qd.branch(name="b"): + sentinel.append("b") + assert sentinel == ["a", "b"] + + +@test_utils.test() +def test_graph_parallel_two_branches(): + """Two branches write disjoint arrays; a serial loop after the region reads both (so it depends on + the join). Results must match the serial reference on every backend; on CUDA the graph has one node + per task plus one empty join node.""" + n = 1024 + + @qd.kernel(graph=True) + def k( + x: qd.types.ndarray(qd.f32, ndim=1), + y: qd.types.ndarray(qd.f32, ndim=1), + z: qd.types.ndarray(qd.f32, ndim=1), + ): + with qd.graph_parallel(): + with qd.branch(name="bx"): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + with qd.branch(name="by"): + for i in range(y.shape[0]): + y[i] = y[i] + 2.0 + for i in range(z.shape[0]): + z[i] = x[i] + y[i] + + x = qd.ndarray(qd.f32, shape=(n,)) + y = qd.ndarray(qd.f32, shape=(n,)) + z = qd.ndarray(qd.f32, shape=(n,)) + x.from_numpy(np.zeros(n, dtype=np.float32)) + y.from_numpy(np.zeros(n, dtype=np.float32)) + z.from_numpy(np.zeros(n, dtype=np.float32)) + + k(x, y, z) + + num_tasks = _num_offloaded_tasks() + assert num_tasks == 3 # two branch loops + the serial z loop + if _on_cuda(): + # 3 kernel nodes + 1 empty join node + assert _graph_num_nodes() == num_tasks + 1 + + np.testing.assert_allclose(x.to_numpy(), 1.0) + np.testing.assert_allclose(y.to_numpy(), 2.0) + np.testing.assert_allclose(z.to_numpy(), 3.0) + + # Relaunch: same cached graph, same result. + x.from_numpy(np.zeros(n, dtype=np.float32)) + y.from_numpy(np.zeros(n, dtype=np.float32)) + k(x, y, z) + np.testing.assert_allclose(z.to_numpy(), 3.0) + + +@test_utils.test() +def test_graph_parallel_three_branches(): + """Fan-out of three independent branches; one empty join node.""" + n = 256 + + @qd.kernel(graph=True) + def k( + a: qd.types.ndarray(qd.f32, ndim=1), + b: qd.types.ndarray(qd.f32, ndim=1), + c: qd.types.ndarray(qd.f32, ndim=1), + ): + with qd.graph_parallel(): + with qd.branch(): + for i in range(a.shape[0]): + a[i] = a[i] + 1.0 + with qd.branch(): + for i in range(b.shape[0]): + b[i] = b[i] + 2.0 + with qd.branch(): + for i in range(c.shape[0]): + c[i] = c[i] + 3.0 + + a = qd.ndarray(qd.f32, shape=(n,)) + b = qd.ndarray(qd.f32, shape=(n,)) + c = qd.ndarray(qd.f32, shape=(n,)) + for arr in (a, b, c): + arr.from_numpy(np.zeros(n, dtype=np.float32)) + + k(a, b, c) + num_tasks = _num_offloaded_tasks() + assert num_tasks == 3 + if _on_cuda(): + assert _graph_num_nodes() == num_tasks + 1 + + np.testing.assert_allclose(a.to_numpy(), 1.0) + np.testing.assert_allclose(b.to_numpy(), 2.0) + np.testing.assert_allclose(c.to_numpy(), 3.0) + + +@test_utils.test() +def test_graph_parallel_multi_loop_branches(): + """Each branch contains several loops; they must chain in order inside the branch while the two + branches run independently. Branch tasks = 4, plus one join node on CUDA.""" + n = 128 + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=1), y: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel(): + with qd.branch(): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + for i in range(x.shape[0]): + x[i] = x[i] * 2.0 + with qd.branch(): + for i in range(y.shape[0]): + y[i] = y[i] + 3.0 + for i in range(y.shape[0]): + y[i] = y[i] * 4.0 + + x = qd.ndarray(qd.f32, shape=(n,)) + y = qd.ndarray(qd.f32, shape=(n,)) + x.from_numpy(np.zeros(n, dtype=np.float32)) + y.from_numpy(np.zeros(n, dtype=np.float32)) + + k(x, y) + num_tasks = _num_offloaded_tasks() + assert num_tasks == 4 + if _on_cuda(): + assert _graph_num_nodes() == num_tasks + 1 + + np.testing.assert_allclose(x.to_numpy(), 2.0) # (0+1)*2 + np.testing.assert_allclose(y.to_numpy(), 12.0) # (0+3)*4 + + +@test_utils.test() +def test_graph_parallel_single_branch_no_join(): + """A region with a single branch (e.g. an optional branch compiled out) needs no join: it degenerates + to a plain chain, so the node count equals the number of branch tasks (no extra empty node).""" + n = 256 + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel(): + with qd.branch(): + for i in range(x.shape[0]): + x[i] = x[i] + 5.0 + + x = qd.ndarray(qd.f32, shape=(n,)) + x.from_numpy(np.zeros(n, dtype=np.float32)) + + k(x) + num_tasks = _num_offloaded_tasks() + assert num_tasks == 1 + if _on_cuda(): + assert _graph_num_nodes() == num_tasks # no join node for a single branch + + np.testing.assert_allclose(x.to_numpy(), 5.0) + + +@test_utils.test() +def test_graph_parallel_optional_branch_static_if(): + """The qipc ENABLE_EE pattern: a branch wrapped in `if qd.static(...)`. When the flag is False the + branch is compiled out (region has one branch -> no join); when True both branches run.""" + n = 128 + + @qd.kernel(graph=True) + def k_off(x: qd.types.ndarray(qd.f32, ndim=1), y: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel(): + with qd.branch(): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + if qd.static(False): + with qd.branch(): + for i in range(y.shape[0]): + y[i] = y[i] + 1.0 + + @qd.kernel(graph=True) + def k_on(x: qd.types.ndarray(qd.f32, ndim=1), y: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel(): + with qd.branch(): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + if qd.static(True): + with qd.branch(): + for i in range(y.shape[0]): + y[i] = y[i] + 1.0 + + x = qd.ndarray(qd.f32, shape=(n,)) + y = qd.ndarray(qd.f32, shape=(n,)) + x.from_numpy(np.zeros(n, dtype=np.float32)) + y.from_numpy(np.zeros(n, dtype=np.float32)) + k_off(x, y) + assert _num_offloaded_tasks() == 1 + if _on_cuda(): + assert _graph_num_nodes() == 1 # single branch, no join + np.testing.assert_allclose(x.to_numpy(), 1.0) + np.testing.assert_allclose(y.to_numpy(), 0.0) # EE branch compiled out + + x.from_numpy(np.zeros(n, dtype=np.float32)) + y.from_numpy(np.zeros(n, dtype=np.float32)) + k_on(x, y) + assert _num_offloaded_tasks() == 2 + if _on_cuda(): + assert _graph_num_nodes() == 3 # two branches + join + np.testing.assert_allclose(x.to_numpy(), 1.0) + np.testing.assert_allclose(y.to_numpy(), 1.0) + + +@test_utils.test() +def test_graph_parallel_inside_graph_do_while(): + """A fork/join region inside a qd.graph_do_while loop body must be correct across iterations: each + iteration runs both branches, then decrements the counter.""" + n = 64 + iters = 5 + + @qd.kernel(graph=True) + def k( + x: qd.types.ndarray(qd.i32, ndim=1), + y: qd.types.ndarray(qd.i32, ndim=1), + counter: qd.types.ndarray(qd.i32, ndim=0), + ): + while qd.graph_do_while(counter): + with qd.graph_parallel(): + with qd.branch(): + for i in range(x.shape[0]): + x[i] = x[i] + 1 + with qd.branch(): + for i in range(y.shape[0]): + y[i] = y[i] + 2 + for _ in range(1): + counter[()] = counter[()] - 1 + + x = qd.ndarray(qd.i32, shape=(n,)) + y = qd.ndarray(qd.i32, shape=(n,)) + counter = qd.ndarray(qd.i32, shape=()) + x.from_numpy(np.zeros(n, dtype=np.int32)) + y.from_numpy(np.zeros(n, dtype=np.int32)) + counter.from_numpy(np.array(iters, dtype=np.int32)) + + k(x, y, counter) + + assert counter.to_numpy() == 0 + np.testing.assert_array_equal(x.to_numpy(), np.full(n, iters, dtype=np.int32)) + np.testing.assert_array_equal(y.to_numpy(), np.full(n, 2 * iters, dtype=np.int32)) + + +@test_utils.test() +def test_graph_parallel_branch_outside_region_raises(): + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=1)): + with qd.branch(): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + + x = qd.ndarray(qd.f32, shape=(16,)) + with pytest.raises(qd.QuadrantsSyntaxError, match="qd.branch.. can only be used .* inside a qd.graph_parallel"): + k(x) + + +@test_utils.test() +def test_graph_parallel_requires_graph_kernel(): + @qd.kernel + def k(x: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel(): + with qd.branch(): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + + x = qd.ndarray(qd.f32, shape=(16,)) + with pytest.raises(qd.QuadrantsSyntaxError, match="requires @qd.kernel.graph=True"): + k(x) + + +@test_utils.test() +def test_graph_parallel_non_branch_body_raises(): + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel(): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + + x = qd.ndarray(qd.f32, shape=(16,)) + with pytest.raises(qd.QuadrantsSyntaxError, match="may contain only .with qd.branch"): + k(x) + + +@test_utils.test() +def test_graph_parallel_nested_region_raises(): + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel(): + with qd.branch(): + with qd.graph_parallel(): + with qd.branch(): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + + x = qd.ndarray(qd.f32, shape=(16,)) + with pytest.raises(qd.QuadrantsSyntaxError): + k(x) From 8d2f3747884366b11f2a91f2ac6c1ead4ceab74c Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 9 Jun 2026 00:25:52 -0700 Subject: [PATCH 03/25] fix(graph): tag a loop's bound-compute serial task with its branch group id A dynamic-bound for-loop (e.g. range(x.shape[0])) lowers to a bound-compute serial task followed by the range_for. The serial task carried group 0, which split a qd.branch()'s contiguous run and defeated the graph builder's fork/join (branches ran serially). Propagate the loop's stream_parallel_group_id onto the flushed pending-serial task in the offloader so the bound task and its range_for stay in the same branch. The existing single-level / single-region frontend restriction guarantees the pending serial never mixes branches. Also make the graph_parallel test node-count assertions relative to the offloaded-task count (each dynamic-bound loop is 2 tasks). --- tests/python/test_graph_parallel.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/tests/python/test_graph_parallel.py b/tests/python/test_graph_parallel.py index 380e0e1b68..ef7c2c6675 100644 --- a/tests/python/test_graph_parallel.py +++ b/tests/python/test_graph_parallel.py @@ -83,9 +83,9 @@ def k( k(x, y, z) num_tasks = _num_offloaded_tasks() - assert num_tasks == 3 # two branch loops + the serial z loop if _on_cuda(): - # 3 kernel nodes + 1 empty join node + # One graph node per offloaded task (each dynamic-bound loop is a bound-compute serial + a + # range_for, both in the branch) plus exactly one empty join node for the single region. assert _graph_num_nodes() == num_tasks + 1 np.testing.assert_allclose(x.to_numpy(), 1.0) @@ -129,9 +129,8 @@ def k( k(a, b, c) num_tasks = _num_offloaded_tasks() - assert num_tasks == 3 if _on_cuda(): - assert _graph_num_nodes() == num_tasks + 1 + assert _graph_num_nodes() == num_tasks + 1 # three branches + one join np.testing.assert_allclose(a.to_numpy(), 1.0) np.testing.assert_allclose(b.to_numpy(), 2.0) @@ -165,9 +164,8 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1), y: qd.types.ndarray(qd.f32, ndim=1)): k(x, y) num_tasks = _num_offloaded_tasks() - assert num_tasks == 4 if _on_cuda(): - assert _graph_num_nodes() == num_tasks + 1 + assert _graph_num_nodes() == num_tasks + 1 # all branch tasks + one join np.testing.assert_allclose(x.to_numpy(), 2.0) # (0+1)*2 np.testing.assert_allclose(y.to_numpy(), 12.0) # (0+3)*4 @@ -191,9 +189,8 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1)): k(x) num_tasks = _num_offloaded_tasks() - assert num_tasks == 1 if _on_cuda(): - assert _graph_num_nodes() == num_tasks # no join node for a single branch + assert _graph_num_nodes() == num_tasks # single branch -> plain chain, no join node np.testing.assert_allclose(x.to_numpy(), 5.0) @@ -231,18 +228,16 @@ def k_on(x: qd.types.ndarray(qd.f32, ndim=1), y: qd.types.ndarray(qd.f32, ndim=1 x.from_numpy(np.zeros(n, dtype=np.float32)) y.from_numpy(np.zeros(n, dtype=np.float32)) k_off(x, y) - assert _num_offloaded_tasks() == 1 if _on_cuda(): - assert _graph_num_nodes() == 1 # single branch, no join + assert _graph_num_nodes() == _num_offloaded_tasks() # single branch -> no join np.testing.assert_allclose(x.to_numpy(), 1.0) np.testing.assert_allclose(y.to_numpy(), 0.0) # EE branch compiled out x.from_numpy(np.zeros(n, dtype=np.float32)) y.from_numpy(np.zeros(n, dtype=np.float32)) k_on(x, y) - assert _num_offloaded_tasks() == 2 if _on_cuda(): - assert _graph_num_nodes() == 3 # two branches + join + assert _graph_num_nodes() == _num_offloaded_tasks() + 1 # two branches + join np.testing.assert_allclose(x.to_numpy(), 1.0) np.testing.assert_allclose(y.to_numpy(), 1.0) From d00e736903757430375094a21545721b604d50bf Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 23 Jun 2026 06:48:54 -0700 Subject: [PATCH 04/25] fix(graph): import contextmanager in misc.py The graph_parallel/branch context managers use @contextmanager. On the old desk8 base misc.py already imported it; rebasing the feature onto main (where misc.py no longer imports contextmanager) lost the import, breaking module import (NameError at load -> stub generation failed). --- python/quadrants/lang/misc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/quadrants/lang/misc.py b/python/quadrants/lang/misc.py index 3bfc147ce9..a559124519 100644 --- a/python/quadrants/lang/misc.py +++ b/python/quadrants/lang/misc.py @@ -3,6 +3,7 @@ import shutil import tempfile import warnings +from contextlib import contextmanager from copy import deepcopy as _deepcopy from quadrants import _logging, _snode From a95d2b71e7995e563a0c228913bc991b69900e4a Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 23 Jun 2026 06:59:37 -0700 Subject: [PATCH 05/25] docs(graph): drop stream_parallel lowering note from graph_parallel backend section --- docs/source/user_guide/graph.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user_guide/graph.md b/docs/source/user_guide/graph.md index 7f7c3f4ac6..3edf6da71d 100644 --- a/docs/source/user_guide/graph.md +++ b/docs/source/user_guide/graph.md @@ -512,4 +512,4 @@ def step(...): | CUDA (graph path) | correct | branches run **concurrently** on parallel streams | | AMDGPU / CPU / Vulkan / Metal | correct | branches run **serially** (the concurrency tags are honoured only by the CUDA graph builder today) | -Because branches are independent by construction, running them serially on the other backends produces identical results — only the scheduling differs. `qd.graph_parallel` lowers onto the same internal concurrency-group mechanism as `qd.stream_parallel`, so non-graph fallbacks also fork the branches across streams. +Because branches are independent by construction, running them serially on the other backends produces identical results — only the scheduling differs. From 70cb6974f5c239ff58ed1a9efee18c0c2dc21a81 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 23 Jun 2026 07:12:40 -0700 Subject: [PATCH 06/25] refactor(graph): rename qd.graph_parallel -> qd.graph_parallel_context and qd.branch -> qd.graph_parallel The fork/join region context manager is now qd.graph_parallel_context() and each concurrent branch member is qd.graph_parallel(name=...). Updates the public API (misc.py + __all__), the AST detection/build/validation in ast_transformer.py and function_def_transformer.py (region helpers gain a _context suffix; the conceptual "branch" naming is retained for members), tests, user docs (graph.md/streams.md, including the heading anchor), and CUDA graph_manager comments. --- docs/source/user_guide/graph.md | 22 ++--- docs/source/user_guide/streams.md | 2 +- python/quadrants/lang/ast/ast_transformer.py | 92 ++++++++++--------- .../function_def_transformer.py | 31 ++++--- python/quadrants/lang/misc.py | 36 ++++---- quadrants/runtime/cuda/graph_manager.cpp | 4 +- quadrants/runtime/cuda/graph_manager.h | 2 +- tests/python/test_graph_parallel.py | 76 +++++++-------- 8 files changed, 139 insertions(+), 126 deletions(-) diff --git a/docs/source/user_guide/graph.md b/docs/source/user_guide/graph.md index 3edf6da71d..702163f7fc 100644 --- a/docs/source/user_guide/graph.md +++ b/docs/source/user_guide/graph.md @@ -11,7 +11,7 @@ Graphs reduce kernel launch overhead by capturing a sequence of GPU operations i | `graph=True` | hardware accelerated | hardware accelerated | hardware accelerated | runs (no acceleration) | runs (no acceleration) | runs (no acceleration) | | `qd.graph_do_while` | hardware accelerated | host fallback | host fallback | host fallback | host fallback | host fallback | | `qd.checkpoint` | GPU-side | GPU-side | GPU-side | GPU-side | GPU-side | host-side | -| `qd.graph_parallel` / `qd.branch` (concurrent branches) | concurrent (parallel streams) | concurrent (parallel streams) | runs serially (correct) | runs serially (correct) | runs serially (correct) | runs serially (correct) | +| `qd.graph_parallel_context` / `qd.graph_parallel` (concurrent branches) | concurrent (parallel streams) | concurrent (parallel streams) | runs serially (correct) | runs serially (correct) | runs serially (correct) | runs serially (correct) | AMDGPU `graph_do_while` falls back to the host-side loop because HIP does not currently expose conditional / while graph nodes (as of ROCm 7.2). @@ -468,11 +468,11 @@ In this case, our recommendation is: - if you need optimum 100% performance on unsupported platforms, then consider PRing onto quadrants an optimized graph implementation for your target platform - for example it could somehow run MAX_ITER iterations anyway, similar to the earlier hand-rolled version, but via the graph abstraction, hence allowing the code to be compact, cross-platform, and also optimally fast -## Concurrent branches with `qd.graph_parallel` *(experimental)* +## Concurrent branches with `qd.graph_parallel_context` *(experimental)* -`qd.checkpoint` and `graph_do_while` change *which* kernels run and *how many times*; `qd.graph_parallel` changes *how* a graph's kernels are scheduled relative to each other. By default the kernels captured in a `graph=True` kernel run as a single dependency chain (each waits for the previous one), even when they are completely independent. A `with qd.graph_parallel():` region lets you declare independent stages so the CUDA graph runs them on **parallel streams**. +`qd.checkpoint` and `graph_do_while` change *which* kernels run and *how many times*; `qd.graph_parallel_context` changes *how* a graph's kernels are scheduled relative to each other. By default the kernels captured in a `graph=True` kernel run as a single dependency chain (each waits for the previous one), even when they are completely independent. A `with qd.graph_parallel_context():` region lets you declare independent stages so the CUDA graph runs them on **parallel streams**. -This is the graph-compatible analogue of [`qd.stream_parallel()`](streams.md) (which only works for non-graph kernels): both express "these sequences are independent, run them concurrently", but `graph_parallel` is honoured by the CUDA graph builder so it composes with `graph=True` and `graph_do_while`. +This is the graph-compatible analogue of [`qd.stream_parallel()`](streams.md) (which only works for non-graph kernels): both express "these sequences are independent, run them concurrently", but `qd.graph_parallel_context` is honoured by the CUDA graph builder so it composes with `graph=True` and `graph_do_while`. ```python @qd.kernel(graph=True) @@ -480,11 +480,11 @@ def step(...): while qd.graph_do_while(ncond): assemble_shared(...) # serial: feeds both branches - with qd.graph_parallel(): # fork: branches run concurrently - with qd.branch(name="pt"): # point-triangle contacts + with qd.graph_parallel_context(): # fork: branches run concurrently + with qd.graph_parallel(name="pt"): # point-triangle contacts pt_assemble(...) pt_hessian(...) - with qd.branch(name="ee"): # edge-edge contacts (independent of pt) + with qd.graph_parallel(name="ee"): # edge-edge contacts (independent of pt) ee_assemble(...) ee_hessian(...) # join: everything below waits for BOTH branches to finish @@ -494,16 +494,16 @@ def step(...): ### Semantics -- **Fork / join.** Every `qd.branch()` in the region forks from the work that precedes the region. All branches must finish before any work *after* the region begins (the join). On CUDA the join is a single empty graph node depending on every branch's last kernel. +- **Fork / join.** Every `qd.graph_parallel()` branch in the region forks from the work that precedes the region. All branches must finish before any work *after* the region begins (the join). On CUDA the join is a single empty graph node depending on every branch's last kernel. - **Branches are independent — you guarantee it.** Calls *within* a branch keep their program order, but calls in *different* branches have no ordering. The branches must be data-race free with respect to one another: no branch may read what another writes, and no two branches may write the same memory. Quadrants does not check this; getting it wrong gives nondeterministic results, exactly like `qd.stream_parallel()`. - **`name=` is optional** and used only as a label for profiling / graph introspection. ### Restrictions (enforced at kernel compile time) - Must be used inside `@qd.kernel(graph=True)`. -- A region body may contain only `with qd.branch():` blocks, optionally wrapped in `if qd.static(...)` (so an optional branch can be compiled in or out — e.g. enabling edge-edge contacts only when a feature flag is set). A single-branch region is allowed and lowers to a plain chain (no fork/join overhead). -- `qd.branch()` may appear only directly inside a `qd.graph_parallel()` region. -- Regions cannot be nested, and a branch body must be straight-line task work — no `qd.graph_do_while`, `qd.checkpoint`, or nested `qd.graph_parallel` inside a branch (a region may, however, sit inside a `qd.graph_do_while` body, as shown above). +- A region body may contain only `with qd.graph_parallel():` blocks, optionally wrapped in `if qd.static(...)` (so an optional branch can be compiled in or out — e.g. enabling edge-edge contacts only when a feature flag is set). A single-branch region is allowed and lowers to a plain chain (no fork/join overhead). +- `qd.graph_parallel()` may appear only directly inside a `qd.graph_parallel_context()` region. +- Regions cannot be nested, and a branch body must be straight-line task work — no `qd.graph_do_while`, `qd.checkpoint`, or nested `qd.graph_parallel_context` inside a branch (a region may, however, sit inside a `qd.graph_do_while` body, as shown above). ### Backend behaviour diff --git a/docs/source/user_guide/streams.md b/docs/source/user_guide/streams.md index 24bff585d9..918d626192 100644 --- a/docs/source/user_guide/streams.md +++ b/docs/source/user_guide/streams.md @@ -48,7 +48,7 @@ combine() # runs after compute_ab() returns — a[] and b[] are ready Consecutive `with qd.stream_parallel():` blocks run concurrently. Multiple for loops within a single block share a stream and run serially on it. All streams are synchronized before the kernel returns. -> **For `graph=True` kernels**, use [`qd.graph_parallel` / `qd.branch`](graph.md#concurrent-branches-with-qdgraph_parallel-experimental) instead — `stream_parallel` is not compatible with graphs (see [Limitations](#limitations)). `graph_parallel` expresses the same "run these independent sequences concurrently" idea but is honoured by the CUDA graph builder. +> **For `graph=True` kernels**, use [`qd.graph_parallel_context` / `qd.graph_parallel`](graph.md#concurrent-branches-with-qdgraph_parallel_context-experimental) instead — `stream_parallel` is not compatible with graphs (see [Limitations](#limitations)). `qd.graph_parallel_context` expresses the same "run these independent sequences concurrently" idea but is honoured by the CUDA graph builder. ### Restrictions diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 85653c24f9..b20e95b032 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -1373,23 +1373,24 @@ def _is_checkpoint_call(node: ast.expr, global_vars: dict): return CheckpointTransformer.is_checkpoint_call(node, global_vars) @staticmethod - def _is_graph_parallel_call(node: ast.expr) -> bool: - """If *node* is a ``qd.graph_parallel()`` call return True, else False.""" + def _is_graph_parallel_context_call(node: ast.expr) -> bool: + """If *node* is a ``qd.graph_parallel_context()`` call return True, else False.""" if not isinstance(node, ast.Call): return False func = node.func - is_gp = (isinstance(func, ast.Attribute) and func.attr == "graph_parallel") or ( - isinstance(func, ast.Name) and func.id == "graph_parallel" + is_gpc = (isinstance(func, ast.Attribute) and func.attr == "graph_parallel_context") or ( + isinstance(func, ast.Name) and func.id == "graph_parallel_context" ) - if not is_gp: + if not is_gpc: return False if node.args or node.keywords: - raise QuadrantsSyntaxError("qd.graph_parallel() takes no arguments") + raise QuadrantsSyntaxError("qd.graph_parallel_context() takes no arguments") return True @staticmethod def _is_branch_call(node: ast.expr) -> tuple[bool, str | None]: - """If *node* is ``qd.branch(...)`` return ``(True, name)``; otherwise ``(False, None)``. + """If *node* is ``qd.graph_parallel(...)`` (a branch) return ``(True, name)``; otherwise + ``(False, None)``. ``name`` is the value of the optional ``name=`` kwarg (a string literal) or ``None``. The call shape is validated here so misuse raises at the ``with`` site rather than later. @@ -1397,21 +1398,23 @@ def _is_branch_call(node: ast.expr) -> tuple[bool, str | None]: if not isinstance(node, ast.Call): return False, None func = node.func - is_branch = (isinstance(func, ast.Attribute) and func.attr == "branch") or ( - isinstance(func, ast.Name) and func.id == "branch" + is_branch = (isinstance(func, ast.Attribute) and func.attr == "graph_parallel") or ( + isinstance(func, ast.Name) and func.id == "graph_parallel" ) if not is_branch: return False, None if node.args: - raise QuadrantsSyntaxError("qd.branch() takes no positional arguments; use qd.branch(name='...') instead") + raise QuadrantsSyntaxError( + "qd.graph_parallel() takes no positional arguments; use qd.graph_parallel(name='...') instead" + ) name: str | None = None for kw in node.keywords: if kw.arg != "name": raise QuadrantsSyntaxError( - f"qd.branch() got unexpected keyword argument {kw.arg!r}; only 'name' is supported" + f"qd.graph_parallel() got unexpected keyword argument {kw.arg!r}; only 'name' is supported" ) if not (isinstance(kw.value, ast.Constant) and isinstance(kw.value.value, str)): - raise QuadrantsSyntaxError("qd.branch(name=...) must be a string literal") + raise QuadrantsSyntaxError("qd.graph_parallel(name=...) must be a string literal") name = kw.value.value return True, name @@ -1658,8 +1661,8 @@ def build_With(ctx: ASTTransformerFuncContext, node: ast.With) -> None: if checkpoint_info is not None: return ASTTransformer._build_checkpoint_with(ctx, node, checkpoint_info) - if ASTTransformer._is_graph_parallel_call(item.context_expr): - return ASTTransformer._build_graph_parallel_with(ctx, node) + if ASTTransformer._is_graph_parallel_context_call(item.context_expr): + return ASTTransformer._build_graph_parallel_context_with(ctx, node) is_branch, branch_name = ASTTransformer._is_branch_call(item.context_expr) if is_branch: @@ -1668,7 +1671,7 @@ def build_With(ctx: ASTTransformerFuncContext, node: ast.With) -> None: if not FunctionDefTransformer._is_stream_parallel_with(node, ctx.global_vars): raise QuadrantsSyntaxError( "'with' in Quadrants kernels only supports qd.stream_parallel(), qd.checkpoint(), " - "qd.graph_parallel(), or qd.branch()" + "qd.graph_parallel_context(), or qd.graph_parallel()" ) if not ctx.is_kernel: raise QuadrantsSyntaxError("qd.stream_parallel() can only be used inside @qd.kernel, not @qd.func") @@ -1688,37 +1691,42 @@ def _build_checkpoint_with( return CheckpointTransformer.build_checkpoint_with(ctx, node, info, build_stmts) @staticmethod - def _build_graph_parallel_with(ctx: ASTTransformerFuncContext, node: ast.With) -> None: - """Handles ``with qd.graph_parallel():`` fork/join regions. + def _build_graph_parallel_context_with(ctx: ASTTransformerFuncContext, node: ast.With) -> None: + """Handles ``with qd.graph_parallel_context():`` fork/join regions. Validates the use-site (kernel must be graph=True, no nesting) and that the region body contains - only ``with qd.branch():`` blocks, then walks the body. The region emits no IR tag of its own -- - each ``branch`` inside lowers to a stream-parallel group (via begin/end_stream_parallel), and the - CUDA graph builder forks the distinct groups in a contiguous run and joins them. Regions are kept - apart by the serial work between them (see d3_0_graph_parallel_impl.md).""" + only ``with qd.graph_parallel():`` blocks, then walks the body. The region emits no IR tag of its + own -- each branch inside lowers to a stream-parallel group (via begin/end_stream_parallel), and + the CUDA graph builder forks the distinct groups in a contiguous run and joins them. Regions are + kept apart by the serial work between them (see d3_0_graph_parallel_impl.md).""" if not ctx.is_kernel: - raise QuadrantsSyntaxError("qd.graph_parallel() can only be used inside @qd.kernel, not @qd.func") + raise QuadrantsSyntaxError( + "qd.graph_parallel_context() can only be used inside @qd.kernel, not @qd.func" + ) kernel = ctx.global_context.current_kernel if kernel is None or not kernel.use_graph: - raise QuadrantsSyntaxError("qd.graph_parallel() requires @qd.kernel(graph=True)") - if getattr(ctx, "_in_graph_parallel", False): - raise QuadrantsSyntaxError("qd.graph_parallel() regions cannot be nested") + raise QuadrantsSyntaxError("qd.graph_parallel_context() requires @qd.kernel(graph=True)") + if getattr(ctx, "_in_graph_parallel_context", False): + raise QuadrantsSyntaxError("qd.graph_parallel_context() regions cannot be nested") if getattr(ctx, "_in_branch", False): - raise QuadrantsSyntaxError("qd.graph_parallel() cannot appear inside a qd.branch() body") - ASTTransformer._validate_graph_parallel_body(node.body) - ctx._in_graph_parallel = True + raise QuadrantsSyntaxError( + "qd.graph_parallel_context() cannot appear inside a qd.graph_parallel() body" + ) + ASTTransformer._validate_graph_parallel_context_body(node.body) + ctx._in_graph_parallel_context = True try: build_stmts(ctx, node.body) finally: - ctx._in_graph_parallel = False + ctx._in_graph_parallel_context = False return None @staticmethod - def _validate_graph_parallel_body(stmts: list[ast.stmt]) -> None: - """A qd.graph_parallel() region body may contain only `with qd.branch():` blocks, optionally - wrapped in compile-time `if qd.static(...)` branches (the optional-branch pattern, e.g. qipc's - ENABLE_EE). Docstrings / coverage probes / `pass` are allowed. Anything else (a bare for-loop, - assignment, etc.) is a serial task that would silently fall outside any branch, so reject it.""" + def _validate_graph_parallel_context_body(stmts: list[ast.stmt]) -> None: + """A qd.graph_parallel_context() region body may contain only `with qd.graph_parallel():` blocks, + optionally wrapped in compile-time `if qd.static(...)` branches (the optional-branch pattern, e.g. + qipc's ENABLE_EE). Docstrings / coverage probes / `pass` are allowed. Anything else (a bare + for-loop, assignment, etc.) is a serial task that would silently fall outside any branch, so + reject it.""" for i, stmt in enumerate(stmts): if FunctionDefTransformer._is_docstring(stmt, i) or FunctionDefTransformer._is_coverage_probe(stmt): continue @@ -1729,24 +1737,26 @@ def _validate_graph_parallel_body(stmts: list[ast.stmt]) -> None: if is_branch: continue if isinstance(stmt, ast.If): - ASTTransformer._validate_graph_parallel_body(stmt.body) - ASTTransformer._validate_graph_parallel_body(stmt.orelse) + ASTTransformer._validate_graph_parallel_context_body(stmt.body) + ASTTransformer._validate_graph_parallel_context_body(stmt.orelse) continue raise QuadrantsSyntaxError( - "A qd.graph_parallel() region may contain only 'with qd.branch():' blocks (optionally " - "inside 'if qd.static(...)'). Move other work outside the region. " + "A qd.graph_parallel_context() region may contain only 'with qd.graph_parallel():' blocks " + "(optionally inside 'if qd.static(...)'). Move other work outside the region. " f"[offending stmt {i}: {type(stmt).__name__}]" ) @staticmethod def _build_branch_with(ctx: ASTTransformerFuncContext, node: ast.With, name: str | None) -> None: - """Handles ``with qd.branch():`` members of a ``qd.graph_parallel()`` region. + """Handles ``with qd.graph_parallel():`` branch members of a ``qd.graph_parallel_context()`` region. Reuses the stream-parallel tagging: begin_stream_parallel() assigns this branch a fresh ``stream_parallel_group_id`` that every for-loop in the body inherits, so the offloaded tasks carry the branch id all the way to the graph builder. ``name`` is currently a label only.""" - if not getattr(ctx, "_in_graph_parallel", False): - raise QuadrantsSyntaxError("qd.branch() can only be used directly inside a qd.graph_parallel() region") + if not getattr(ctx, "_in_graph_parallel_context", False): + raise QuadrantsSyntaxError( + "qd.graph_parallel() can only be used directly inside a qd.graph_parallel_context() region" + ) ctx._in_branch = True ctx.ast_builder.begin_stream_parallel() try: diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 8e1d306790..b2de9d8132 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -601,34 +601,34 @@ def _is_loop_config_call(stmt: ast.stmt) -> bool: return False @staticmethod - def _is_graph_parallel_with(stmt: ast.stmt) -> bool: - """Syntactic check matching ASTTransformer._is_graph_parallel_call: a - ``with qd.graph_parallel():`` fork/join region.""" + def _is_graph_parallel_context_with(stmt: ast.stmt) -> bool: + """Syntactic check matching ASTTransformer._is_graph_parallel_context_call: a + ``with qd.graph_parallel_context():`` fork/join region.""" if not isinstance(stmt, ast.With) or len(stmt.items) != 1: return False ctx_expr = stmt.items[0].context_expr if not isinstance(ctx_expr, ast.Call): return False func = ctx_expr.func - if isinstance(func, ast.Attribute) and func.attr == "graph_parallel": + if isinstance(func, ast.Attribute) and func.attr == "graph_parallel_context": return True - if isinstance(func, ast.Name) and func.id == "graph_parallel": + if isinstance(func, ast.Name) and func.id == "graph_parallel_context": return True return False @staticmethod def _is_branch_with(stmt: ast.stmt) -> bool: - """Syntactic check matching ASTTransformer._is_branch_call: a ``with qd.branch(...):`` member - of a ``qd.graph_parallel()`` region.""" + """Syntactic check matching ASTTransformer._is_branch_call: a ``with qd.graph_parallel(...):`` + branch member of a ``qd.graph_parallel_context()`` region.""" if not isinstance(stmt, ast.With) or len(stmt.items) != 1: return False ctx_expr = stmt.items[0].context_expr if not isinstance(ctx_expr, ast.Call): return False func = ctx_expr.func - if isinstance(func, ast.Attribute) and func.attr == "branch": + if isinstance(func, ast.Attribute) and func.attr == "graph_parallel": return True - if isinstance(func, ast.Name) and func.id == "branch": + if isinstance(func, ast.Name) and func.id == "graph_parallel": return True return False @@ -691,12 +691,13 @@ def _validate_graph_do_while_stmt_list(stmts: list[ast.stmt], is_kernel_top: boo # `CheckpointTransformer.build_checkpoint_with`. FunctionDefTransformer._validate_graph_do_while_stmt_list(stmt.body, is_kernel_top=is_kernel_top) continue - if FunctionDefTransformer._is_graph_parallel_with(stmt): - # A `with qd.graph_parallel()` region groups concurrent `with qd.branch()` members; it is - # a legal sibling of for-loops / checkpoints. Its body must be branch blocks (optionally - # under `if qd.static(...)`); the full check is in ASTTransformer._build_graph_parallel_with. - # Each branch body is task territory, validated here with the in-loop rules. Descend through - # `if` members so branches inside an optional-branch `if qd.static(...)` are reached too. + if FunctionDefTransformer._is_graph_parallel_context_with(stmt): + # A `with qd.graph_parallel_context()` region groups concurrent `with qd.graph_parallel()` + # branches; it is a legal sibling of for-loops / checkpoints. Its body must be branch blocks + # (optionally under `if qd.static(...)`); the full check is in + # ASTTransformer._build_graph_parallel_context_with. Each branch body is task territory, + # validated here with the in-loop rules. Descend through `if` members so branches inside an + # optional-branch `if qd.static(...)` are reached too. pending = list(stmt.body) while pending: member = pending.pop() diff --git a/python/quadrants/lang/misc.py b/python/quadrants/lang/misc.py index a559124519..076d04a4a4 100644 --- a/python/quadrants/lang/misc.py +++ b/python/quadrants/lang/misc.py @@ -753,15 +753,15 @@ def graph_do_while(condition) -> bool: @contextmanager -def graph_parallel(): - """Opens a fork/join region whose ``qd.branch()`` members run concurrently. +def graph_parallel_context(): + """Opens a fork/join region whose ``qd.graph_parallel()`` branches run concurrently. - Used as ``with qd.graph_parallel():`` inside a ``@qd.kernel(graph=True)`` kernel. The region's body - must contain only ``with qd.branch():`` blocks. Each branch is an independent sequence of work; the - branches have no ordering relative to each other and may execute concurrently, while everything after - the region waits for *all* branches to finish (the join). This is the CUDA-graph analogue of - ``qd.stream_parallel()`` (which is for non-graph kernels): it lets independent stages -- e.g. qipc's - point-triangle and edge-edge assembly -- overlap inside a captured graph. + Used as ``with qd.graph_parallel_context():`` inside a ``@qd.kernel(graph=True)`` kernel. The + region's body must contain only ``with qd.graph_parallel():`` blocks. Each branch is an independent + sequence of work; the branches have no ordering relative to each other and may execute concurrently, + while everything after the region waits for *all* branches to finish (the join). This is the + CUDA-graph analogue of ``qd.stream_parallel()`` (which is for non-graph kernels): it lets independent + stages -- e.g. qipc's point-triangle and edge-edge assembly -- overlap inside a captured graph. Concurrency contract (the author's responsibility): branches must be data-race free with respect to one another (no branch reads what another writes, no two branches write the same location). Calls @@ -775,9 +775,9 @@ def graph_parallel(): Restrictions (enforced at kernel compile time): - Must be used inside ``@qd.kernel(graph=True)``. - - The region body may contain only ``with qd.branch():`` blocks. + - The region body may contain only ``with qd.graph_parallel():`` blocks. - Regions cannot be nested, and a branch body must be straight-line task work (no nested - ``qd.graph_do_while``, ``qd.checkpoint``, or ``qd.graph_parallel``). + ``qd.graph_do_while``, ``qd.checkpoint``, or ``qd.graph_parallel_context``). This function should not be called directly at runtime; it is recognised and transformed during AST compilation. At Python runtime (outside kernels) it is a no-op context manager. @@ -788,15 +788,15 @@ def graph_parallel(): @contextmanager -def branch(name=None): - """Declares one concurrent member of an enclosing ``qd.graph_parallel()`` region. +def graph_parallel(name=None): + """Declares one concurrent branch of an enclosing ``qd.graph_parallel_context()`` region. - Used as ``with qd.branch():`` or ``with qd.branch(name="pt"):`` directly inside a - ``with qd.graph_parallel():`` block. The branch's body is an independent sequence of work that may - run concurrently with the region's other branches. ``name`` is optional and used only as a label for - profiling / graph introspection. + Used as ``with qd.graph_parallel():`` or ``with qd.graph_parallel(name="pt"):`` directly inside a + ``with qd.graph_parallel_context():`` block. The branch's body is an independent sequence of work + that may run concurrently with the region's other branches. ``name`` is optional and used only as a + label for profiling / graph introspection. - See ``qd.graph_parallel()`` for the full contract and backend behaviour. + See ``qd.graph_parallel_context()`` for the full contract and backend behaviour. """ yield @@ -940,8 +940,8 @@ def dump_compile_config() -> None: "GraphStatus", "checkpoint", "graph_do_while", + "graph_parallel_context", "graph_parallel", - "branch", "loop_config", "global_thread_idx", "assume_in_range", diff --git a/quadrants/runtime/cuda/graph_manager.cpp b/quadrants/runtime/cuda/graph_manager.cpp index 7827323c7c..046e325115 100644 --- a/quadrants/runtime/cuda/graph_manager.cpp +++ b/quadrants/runtime/cuda/graph_manager.cpp @@ -404,8 +404,8 @@ void GraphManager::build_level(int parent_id, continue; } - // --- A qd.graph_parallel() fork/join region: a contiguous run of this level's direct, - // non-checkpoint tasks tagged with a nonzero stream_parallel_group_id (set by qd.branch()). Each + // --- A qd.graph_parallel_context() fork/join region: a contiguous run of this level's direct, + // non-checkpoint tasks tagged with a nonzero stream_parallel_group_id (set by qd.graph_parallel()). Each // distinct group id is one branch; branches fork from the region's entry (`prev_node`), run their // tasks in order, and join into a single empty node so downstream work waits for all of them. CUDA's // graph executor schedules the independent branch chains on separate streams -> real overlap. --- diff --git a/quadrants/runtime/cuda/graph_manager.h b/quadrants/runtime/cuda/graph_manager.h index 83506587a2..d5681b37ff 100644 --- a/quadrants/runtime/cuda/graph_manager.h +++ b/quadrants/runtime/cuda/graph_manager.h @@ -202,7 +202,7 @@ class GraphManager { unsigned int shared_mem, void **kernel_params); // Add an empty (no-op) node to `graph` depending on every node in `deps`. Used as the join point of a - // qd.graph_parallel() region: it has no work but collects all branch tails into a single successor so + // qd.graph_parallel_context() region: it has no work but collects all branch tails into a single successor so // downstream nodes wait for every branch. `deps` must be non-empty. void *add_empty_node(void *graph, const std::vector &deps); // Recursively build the nodes for graph_do_while level `parent_id` (-1 = kernel top level) over the task range diff --git a/tests/python/test_graph_parallel.py b/tests/python/test_graph_parallel.py index ef7c2c6675..2ebf8fe84d 100644 --- a/tests/python/test_graph_parallel.py +++ b/tests/python/test_graph_parallel.py @@ -1,6 +1,6 @@ -"""Tests for qd.graph_parallel / qd.branch -- concurrent fork/join branches in graph kernels. +"""Tests for qd.graph_parallel_context / qd.graph_parallel -- concurrent fork/join branches in graph kernels. -`with qd.graph_parallel():` opens a fork/join region whose `with qd.branch():` members are independent +`with qd.graph_parallel_context():` opens a fork/join region whose `with qd.graph_parallel():` members are independent sequences of work. On the CUDA graph path the branches become independent graph chains joined by a single empty node, so the runtime schedules them on parallel streams; on other backends (CPU / AMDGPU / Vulkan / Metal) they run serially but produce identical results. @@ -38,14 +38,14 @@ def _num_offloaded_tasks(): @test_utils.test() def test_graph_parallel_is_no_op_outside_kernels(): - """At Python runtime (outside kernels) qd.graph_parallel / qd.branch must be usable no-op context + """At Python runtime (outside kernels) qd.graph_parallel_context / qd.graph_parallel must be usable no-op context managers, so helpers that are sometimes called from Python and sometimes from kernels still import and run. Mirrors qd.stream_parallel / qd.checkpoint.""" sentinel = [] - with qd.graph_parallel(): - with qd.branch(): + with qd.graph_parallel_context(): + with qd.graph_parallel(): sentinel.append("a") - with qd.branch(name="b"): + with qd.graph_parallel(name="b"): sentinel.append("b") assert sentinel == ["a", "b"] @@ -63,11 +63,11 @@ def k( y: qd.types.ndarray(qd.f32, ndim=1), z: qd.types.ndarray(qd.f32, ndim=1), ): - with qd.graph_parallel(): - with qd.branch(name="bx"): + with qd.graph_parallel_context(): + with qd.graph_parallel(name="bx"): for i in range(x.shape[0]): x[i] = x[i] + 1.0 - with qd.branch(name="by"): + with qd.graph_parallel(name="by"): for i in range(y.shape[0]): y[i] = y[i] + 2.0 for i in range(z.shape[0]): @@ -110,14 +110,14 @@ def k( b: qd.types.ndarray(qd.f32, ndim=1), c: qd.types.ndarray(qd.f32, ndim=1), ): - with qd.graph_parallel(): - with qd.branch(): + with qd.graph_parallel_context(): + with qd.graph_parallel(): for i in range(a.shape[0]): a[i] = a[i] + 1.0 - with qd.branch(): + with qd.graph_parallel(): for i in range(b.shape[0]): b[i] = b[i] + 2.0 - with qd.branch(): + with qd.graph_parallel(): for i in range(c.shape[0]): c[i] = c[i] + 3.0 @@ -145,13 +145,13 @@ def test_graph_parallel_multi_loop_branches(): @qd.kernel(graph=True) def k(x: qd.types.ndarray(qd.f32, ndim=1), y: qd.types.ndarray(qd.f32, ndim=1)): - with qd.graph_parallel(): - with qd.branch(): + with qd.graph_parallel_context(): + with qd.graph_parallel(): for i in range(x.shape[0]): x[i] = x[i] + 1.0 for i in range(x.shape[0]): x[i] = x[i] * 2.0 - with qd.branch(): + with qd.graph_parallel(): for i in range(y.shape[0]): y[i] = y[i] + 3.0 for i in range(y.shape[0]): @@ -179,8 +179,8 @@ def test_graph_parallel_single_branch_no_join(): @qd.kernel(graph=True) def k(x: qd.types.ndarray(qd.f32, ndim=1)): - with qd.graph_parallel(): - with qd.branch(): + with qd.graph_parallel_context(): + with qd.graph_parallel(): for i in range(x.shape[0]): x[i] = x[i] + 5.0 @@ -203,23 +203,23 @@ def test_graph_parallel_optional_branch_static_if(): @qd.kernel(graph=True) def k_off(x: qd.types.ndarray(qd.f32, ndim=1), y: qd.types.ndarray(qd.f32, ndim=1)): - with qd.graph_parallel(): - with qd.branch(): + with qd.graph_parallel_context(): + with qd.graph_parallel(): for i in range(x.shape[0]): x[i] = x[i] + 1.0 if qd.static(False): - with qd.branch(): + with qd.graph_parallel(): for i in range(y.shape[0]): y[i] = y[i] + 1.0 @qd.kernel(graph=True) def k_on(x: qd.types.ndarray(qd.f32, ndim=1), y: qd.types.ndarray(qd.f32, ndim=1)): - with qd.graph_parallel(): - with qd.branch(): + with qd.graph_parallel_context(): + with qd.graph_parallel(): for i in range(x.shape[0]): x[i] = x[i] + 1.0 if qd.static(True): - with qd.branch(): + with qd.graph_parallel(): for i in range(y.shape[0]): y[i] = y[i] + 1.0 @@ -256,11 +256,11 @@ def k( counter: qd.types.ndarray(qd.i32, ndim=0), ): while qd.graph_do_while(counter): - with qd.graph_parallel(): - with qd.branch(): + with qd.graph_parallel_context(): + with qd.graph_parallel(): for i in range(x.shape[0]): x[i] = x[i] + 1 - with qd.branch(): + with qd.graph_parallel(): for i in range(y.shape[0]): y[i] = y[i] + 2 for _ in range(1): @@ -284,12 +284,14 @@ def k( def test_graph_parallel_branch_outside_region_raises(): @qd.kernel(graph=True) def k(x: qd.types.ndarray(qd.f32, ndim=1)): - with qd.branch(): + with qd.graph_parallel(): for i in range(x.shape[0]): x[i] = x[i] + 1.0 x = qd.ndarray(qd.f32, shape=(16,)) - with pytest.raises(qd.QuadrantsSyntaxError, match="qd.branch.. can only be used .* inside a qd.graph_parallel"): + with pytest.raises( + qd.QuadrantsSyntaxError, match="qd.graph_parallel.. can only be used .* inside a qd.graph_parallel_context" + ): k(x) @@ -297,8 +299,8 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1)): def test_graph_parallel_requires_graph_kernel(): @qd.kernel def k(x: qd.types.ndarray(qd.f32, ndim=1)): - with qd.graph_parallel(): - with qd.branch(): + with qd.graph_parallel_context(): + with qd.graph_parallel(): for i in range(x.shape[0]): x[i] = x[i] + 1.0 @@ -311,12 +313,12 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1)): def test_graph_parallel_non_branch_body_raises(): @qd.kernel(graph=True) def k(x: qd.types.ndarray(qd.f32, ndim=1)): - with qd.graph_parallel(): + with qd.graph_parallel_context(): for i in range(x.shape[0]): x[i] = x[i] + 1.0 x = qd.ndarray(qd.f32, shape=(16,)) - with pytest.raises(qd.QuadrantsSyntaxError, match="may contain only .with qd.branch"): + with pytest.raises(qd.QuadrantsSyntaxError, match="may contain only .with qd.graph_parallel"): k(x) @@ -324,10 +326,10 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1)): def test_graph_parallel_nested_region_raises(): @qd.kernel(graph=True) def k(x: qd.types.ndarray(qd.f32, ndim=1)): - with qd.graph_parallel(): - with qd.branch(): - with qd.graph_parallel(): - with qd.branch(): + with qd.graph_parallel_context(): + with qd.graph_parallel(): + with qd.graph_parallel_context(): + with qd.graph_parallel(): for i in range(x.shape[0]): x[i] = x[i] + 1.0 From 706fd60c51c5ae57436ccb6501d8e3207f4c55be Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 23 Jun 2026 07:19:00 -0700 Subject: [PATCH 07/25] docs/graph: drop name= param from qd.graph_parallel; tighten graph_parallel docs - Remove the optional name= parameter from qd.graph_parallel (branch members now take no arguments); simplify the AST detection/build path accordingly. - Drop the "graph-compatible analogue of qd.stream_parallel" lead-in and the "single-branch region lowers to a plain chain" note from the graph_parallel docs. - Refer to the "graph builder" rather than the "CUDA graph builder" in the graph_parallel feature prose (it is graph-abstraction level, CUDA-honoured today). --- docs/source/user_guide/graph.md | 13 +++--- docs/source/user_guide/streams.md | 2 +- python/quadrants/lang/ast/ast_transformer.py | 47 +++++++------------- python/quadrants/lang/misc.py | 17 ++++--- tests/python/test_graph_parallel.py | 12 ++--- 5 files changed, 36 insertions(+), 55 deletions(-) diff --git a/docs/source/user_guide/graph.md b/docs/source/user_guide/graph.md index 702163f7fc..194e86a5d7 100644 --- a/docs/source/user_guide/graph.md +++ b/docs/source/user_guide/graph.md @@ -470,9 +470,9 @@ In this case, our recommendation is: ## Concurrent branches with `qd.graph_parallel_context` *(experimental)* -`qd.checkpoint` and `graph_do_while` change *which* kernels run and *how many times*; `qd.graph_parallel_context` changes *how* a graph's kernels are scheduled relative to each other. By default the kernels captured in a `graph=True` kernel run as a single dependency chain (each waits for the previous one), even when they are completely independent. A `with qd.graph_parallel_context():` region lets you declare independent stages so the CUDA graph runs them on **parallel streams**. +`qd.checkpoint` and `graph_do_while` change *which* kernels run and *how many times*; `qd.graph_parallel_context` changes *how* a graph's kernels are scheduled relative to each other. By default the kernels captured in a `graph=True` kernel run as a single dependency chain (each waits for the previous one), even when they are completely independent. A `with qd.graph_parallel_context():` region lets you declare independent stages so the graph runs them on **parallel streams**. -This is the graph-compatible analogue of [`qd.stream_parallel()`](streams.md) (which only works for non-graph kernels): both express "these sequences are independent, run them concurrently", but `qd.graph_parallel_context` is honoured by the CUDA graph builder so it composes with `graph=True` and `graph_do_while`. +`qd.graph_parallel_context` is honoured by the graph builder so it composes with `graph=True` and `graph_do_while`. ```python @qd.kernel(graph=True) @@ -481,10 +481,10 @@ def step(...): assemble_shared(...) # serial: feeds both branches with qd.graph_parallel_context(): # fork: branches run concurrently - with qd.graph_parallel(name="pt"): # point-triangle contacts + with qd.graph_parallel(): # point-triangle contacts pt_assemble(...) pt_hessian(...) - with qd.graph_parallel(name="ee"): # edge-edge contacts (independent of pt) + with qd.graph_parallel(): # edge-edge contacts (independent of pt) ee_assemble(...) ee_hessian(...) # join: everything below waits for BOTH branches to finish @@ -496,12 +496,11 @@ def step(...): - **Fork / join.** Every `qd.graph_parallel()` branch in the region forks from the work that precedes the region. All branches must finish before any work *after* the region begins (the join). On CUDA the join is a single empty graph node depending on every branch's last kernel. - **Branches are independent — you guarantee it.** Calls *within* a branch keep their program order, but calls in *different* branches have no ordering. The branches must be data-race free with respect to one another: no branch may read what another writes, and no two branches may write the same memory. Quadrants does not check this; getting it wrong gives nondeterministic results, exactly like `qd.stream_parallel()`. -- **`name=` is optional** and used only as a label for profiling / graph introspection. ### Restrictions (enforced at kernel compile time) - Must be used inside `@qd.kernel(graph=True)`. -- A region body may contain only `with qd.graph_parallel():` blocks, optionally wrapped in `if qd.static(...)` (so an optional branch can be compiled in or out — e.g. enabling edge-edge contacts only when a feature flag is set). A single-branch region is allowed and lowers to a plain chain (no fork/join overhead). +- A region body may contain only `with qd.graph_parallel():` blocks, optionally wrapped in `if qd.static(...)` (so an optional branch can be compiled in or out — e.g. enabling edge-edge contacts only when a feature flag is set). - `qd.graph_parallel()` may appear only directly inside a `qd.graph_parallel_context()` region. - Regions cannot be nested, and a branch body must be straight-line task work — no `qd.graph_do_while`, `qd.checkpoint`, or nested `qd.graph_parallel_context` inside a branch (a region may, however, sit inside a `qd.graph_do_while` body, as shown above). @@ -510,6 +509,6 @@ def step(...): | backend | result | scheduling | | --- | --- | --- | | CUDA (graph path) | correct | branches run **concurrently** on parallel streams | -| AMDGPU / CPU / Vulkan / Metal | correct | branches run **serially** (the concurrency tags are honoured only by the CUDA graph builder today) | +| AMDGPU / CPU / Vulkan / Metal | correct | branches run **serially** (the concurrency tags are honoured only by the graph builder today) | Because branches are independent by construction, running them serially on the other backends produces identical results — only the scheduling differs. diff --git a/docs/source/user_guide/streams.md b/docs/source/user_guide/streams.md index 918d626192..f1c30e0202 100644 --- a/docs/source/user_guide/streams.md +++ b/docs/source/user_guide/streams.md @@ -48,7 +48,7 @@ combine() # runs after compute_ab() returns — a[] and b[] are ready Consecutive `with qd.stream_parallel():` blocks run concurrently. Multiple for loops within a single block share a stream and run serially on it. All streams are synchronized before the kernel returns. -> **For `graph=True` kernels**, use [`qd.graph_parallel_context` / `qd.graph_parallel`](graph.md#concurrent-branches-with-qdgraph_parallel_context-experimental) instead — `stream_parallel` is not compatible with graphs (see [Limitations](#limitations)). `qd.graph_parallel_context` expresses the same "run these independent sequences concurrently" idea but is honoured by the CUDA graph builder. +> **For `graph=True` kernels**, use [`qd.graph_parallel_context` / `qd.graph_parallel`](graph.md#concurrent-branches-with-qdgraph_parallel_context-experimental) instead — `stream_parallel` is not compatible with graphs (see [Limitations](#limitations)). `qd.graph_parallel_context` expresses the same "run these independent sequences concurrently" idea but is honoured by the graph builder. ### Restrictions diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index b20e95b032..84c001342b 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -1388,35 +1388,20 @@ def _is_graph_parallel_context_call(node: ast.expr) -> bool: return True @staticmethod - def _is_branch_call(node: ast.expr) -> tuple[bool, str | None]: - """If *node* is ``qd.graph_parallel(...)`` (a branch) return ``(True, name)``; otherwise - ``(False, None)``. - - ``name`` is the value of the optional ``name=`` kwarg (a string literal) or ``None``. The call - shape is validated here so misuse raises at the ``with`` site rather than later. - """ + def _is_branch_call(node: ast.expr) -> bool: + """If *node* is a ``qd.graph_parallel()`` (branch) call return True, else False. The call shape is + validated here so misuse raises at the ``with`` site rather than later.""" if not isinstance(node, ast.Call): - return False, None + return False func = node.func is_branch = (isinstance(func, ast.Attribute) and func.attr == "graph_parallel") or ( isinstance(func, ast.Name) and func.id == "graph_parallel" ) if not is_branch: - return False, None - if node.args: - raise QuadrantsSyntaxError( - "qd.graph_parallel() takes no positional arguments; use qd.graph_parallel(name='...') instead" - ) - name: str | None = None - for kw in node.keywords: - if kw.arg != "name": - raise QuadrantsSyntaxError( - f"qd.graph_parallel() got unexpected keyword argument {kw.arg!r}; only 'name' is supported" - ) - if not (isinstance(kw.value, ast.Constant) and isinstance(kw.value.value, str)): - raise QuadrantsSyntaxError("qd.graph_parallel(name=...) must be a string literal") - name = kw.value.value - return True, name + return False + if node.args or node.keywords: + raise QuadrantsSyntaxError("qd.graph_parallel() takes no arguments") + return True @staticmethod def build_While(ctx: ASTTransformerFuncContext, node: ast.While) -> None: @@ -1664,9 +1649,8 @@ def build_With(ctx: ASTTransformerFuncContext, node: ast.With) -> None: if ASTTransformer._is_graph_parallel_context_call(item.context_expr): return ASTTransformer._build_graph_parallel_context_with(ctx, node) - is_branch, branch_name = ASTTransformer._is_branch_call(item.context_expr) - if is_branch: - return ASTTransformer._build_branch_with(ctx, node, branch_name) + if ASTTransformer._is_branch_call(item.context_expr): + return ASTTransformer._build_branch_with(ctx, node) if not FunctionDefTransformer._is_stream_parallel_with(node, ctx.global_vars): raise QuadrantsSyntaxError( @@ -1697,8 +1681,8 @@ def _build_graph_parallel_context_with(ctx: ASTTransformerFuncContext, node: ast Validates the use-site (kernel must be graph=True, no nesting) and that the region body contains only ``with qd.graph_parallel():`` blocks, then walks the body. The region emits no IR tag of its own -- each branch inside lowers to a stream-parallel group (via begin/end_stream_parallel), and - the CUDA graph builder forks the distinct groups in a contiguous run and joins them. Regions are - kept apart by the serial work between them (see d3_0_graph_parallel_impl.md).""" + the graph builder forks the distinct groups in a contiguous run and joins them. Regions are kept + apart by the serial work between them (see d3_0_graph_parallel_impl.md).""" if not ctx.is_kernel: raise QuadrantsSyntaxError( "qd.graph_parallel_context() can only be used inside @qd.kernel, not @qd.func" @@ -1733,8 +1717,7 @@ def _validate_graph_parallel_context_body(stmts: list[ast.stmt]) -> None: if isinstance(stmt, ast.Pass): continue if isinstance(stmt, ast.With) and stmt.items: - is_branch, _ = ASTTransformer._is_branch_call(stmt.items[0].context_expr) - if is_branch: + if ASTTransformer._is_branch_call(stmt.items[0].context_expr): continue if isinstance(stmt, ast.If): ASTTransformer._validate_graph_parallel_context_body(stmt.body) @@ -1747,12 +1730,12 @@ def _validate_graph_parallel_context_body(stmts: list[ast.stmt]) -> None: ) @staticmethod - def _build_branch_with(ctx: ASTTransformerFuncContext, node: ast.With, name: str | None) -> None: + def _build_branch_with(ctx: ASTTransformerFuncContext, node: ast.With) -> None: """Handles ``with qd.graph_parallel():`` branch members of a ``qd.graph_parallel_context()`` region. Reuses the stream-parallel tagging: begin_stream_parallel() assigns this branch a fresh ``stream_parallel_group_id`` that every for-loop in the body inherits, so the offloaded tasks - carry the branch id all the way to the graph builder. ``name`` is currently a label only.""" + carry the branch id all the way to the graph builder.""" if not getattr(ctx, "_in_graph_parallel_context", False): raise QuadrantsSyntaxError( "qd.graph_parallel() can only be used directly inside a qd.graph_parallel_context() region" diff --git a/python/quadrants/lang/misc.py b/python/quadrants/lang/misc.py index 076d04a4a4..c251319417 100644 --- a/python/quadrants/lang/misc.py +++ b/python/quadrants/lang/misc.py @@ -759,9 +759,9 @@ def graph_parallel_context(): Used as ``with qd.graph_parallel_context():`` inside a ``@qd.kernel(graph=True)`` kernel. The region's body must contain only ``with qd.graph_parallel():`` blocks. Each branch is an independent sequence of work; the branches have no ordering relative to each other and may execute concurrently, - while everything after the region waits for *all* branches to finish (the join). This is the - CUDA-graph analogue of ``qd.stream_parallel()`` (which is for non-graph kernels): it lets independent - stages -- e.g. qipc's point-triangle and edge-edge assembly -- overlap inside a captured graph. + while everything after the region waits for *all* branches to finish (the join). This is the graph + analogue of ``qd.stream_parallel()`` (which is for non-graph kernels): it lets independent stages -- + e.g. qipc's point-triangle and edge-edge assembly -- overlap inside a captured graph. Concurrency contract (the author's responsibility): branches must be data-race free with respect to one another (no branch reads what another writes, no two branches write the same location). Calls @@ -771,7 +771,7 @@ def graph_parallel_context(): - CUDA SM graph path: branches become independent graph chains joined by an empty node, so the runtime schedules them on parallel streams (real overlap). - CPU / Vulkan / Metal / AMDGPU graph: correct results, branches run serially (the concurrency - tags are honoured only by the CUDA graph builder today). + tags are honoured only by the graph builder today). Restrictions (enforced at kernel compile time): - Must be used inside ``@qd.kernel(graph=True)``. @@ -788,13 +788,12 @@ def graph_parallel_context(): @contextmanager -def graph_parallel(name=None): +def graph_parallel(): """Declares one concurrent branch of an enclosing ``qd.graph_parallel_context()`` region. - Used as ``with qd.graph_parallel():`` or ``with qd.graph_parallel(name="pt"):`` directly inside a - ``with qd.graph_parallel_context():`` block. The branch's body is an independent sequence of work - that may run concurrently with the region's other branches. ``name`` is optional and used only as a - label for profiling / graph introspection. + Used as ``with qd.graph_parallel():`` directly inside a ``with qd.graph_parallel_context():`` block. + The branch's body is an independent sequence of work that may run concurrently with the region's + other branches. See ``qd.graph_parallel_context()`` for the full contract and backend behaviour. """ diff --git a/tests/python/test_graph_parallel.py b/tests/python/test_graph_parallel.py index 2ebf8fe84d..8fabbf4b8a 100644 --- a/tests/python/test_graph_parallel.py +++ b/tests/python/test_graph_parallel.py @@ -1,13 +1,13 @@ """Tests for qd.graph_parallel_context / qd.graph_parallel -- concurrent fork/join branches in graph kernels. `with qd.graph_parallel_context():` opens a fork/join region whose `with qd.graph_parallel():` members are independent -sequences of work. On the CUDA graph path the branches become independent graph chains joined by a single +sequences of work. On the graph path the branches become independent graph chains joined by a single empty node, so the runtime schedules them on parallel streams; on other backends (CPU / AMDGPU / Vulkan / Metal) they run serially but produce identical results. The behavioural assertions (disjoint-array correctness) hold on every backend. The graph-structure -assertions (node counts: one kernel node per branch task + one empty join node) only apply on the CUDA -graph path, where the builder forks/joins; they are guarded by `_on_cuda()`. +assertions (node counts: one kernel node per branch task + one empty join node) only apply where the +builder forks/joins (CUDA today), so they are guarded by `_on_cuda()`. """ import numpy as np @@ -45,7 +45,7 @@ def test_graph_parallel_is_no_op_outside_kernels(): with qd.graph_parallel_context(): with qd.graph_parallel(): sentinel.append("a") - with qd.graph_parallel(name="b"): + with qd.graph_parallel(): sentinel.append("b") assert sentinel == ["a", "b"] @@ -64,10 +64,10 @@ def k( z: qd.types.ndarray(qd.f32, ndim=1), ): with qd.graph_parallel_context(): - with qd.graph_parallel(name="bx"): + with qd.graph_parallel(): for i in range(x.shape[0]): x[i] = x[i] + 1.0 - with qd.graph_parallel(name="by"): + with qd.graph_parallel(): for i in range(y.shape[0]): y[i] = y[i] + 2.0 for i in range(z.shape[0]): From 89f8f26234d110ae15bc7b7183e0f0d9d21574d9 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 23 Jun 2026 07:30:25 -0700 Subject: [PATCH 08/25] docs/graph: simplify graph_parallel section (sections not branches, trim tables) - Rename the doc's "branch" terminology to "section" (heading, prose, tables, code comments); update the streams.md cross-reference anchor to match. - Stop describing graph_parallel scheduling in terms of parallel streams (we do not expose streams) and drop the qd.stream_parallel analogy from the section. - Trim the comparison table (drop "(parallel streams)"/"(correct)") and the backend-behaviour table (drop the result column, "(graph path)", and the "honoured only by the CUDA graph builder today" note). - Restrictions: drop the redundant "Must be used inside @qd.kernel(graph=True)" bullet and refer to qd.graph_parallel_context instead of a vague "region". - Minor: "falls back to a host-side loop"; use British "behaviour". --- docs/source/user_guide/graph.md | 37 +++++++++++++++---------------- docs/source/user_guide/streams.md | 2 +- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/docs/source/user_guide/graph.md b/docs/source/user_guide/graph.md index 194e86a5d7..3fdd79d716 100644 --- a/docs/source/user_guide/graph.md +++ b/docs/source/user_guide/graph.md @@ -11,9 +11,9 @@ Graphs reduce kernel launch overhead by capturing a sequence of GPU operations i | `graph=True` | hardware accelerated | hardware accelerated | hardware accelerated | runs (no acceleration) | runs (no acceleration) | runs (no acceleration) | | `qd.graph_do_while` | hardware accelerated | host fallback | host fallback | host fallback | host fallback | host fallback | | `qd.checkpoint` | GPU-side | GPU-side | GPU-side | GPU-side | GPU-side | host-side | -| `qd.graph_parallel_context` / `qd.graph_parallel` (concurrent branches) | concurrent (parallel streams) | concurrent (parallel streams) | runs serially (correct) | runs serially (correct) | runs serially (correct) | runs serially (correct) | +| `qd.graph_parallel_context` / `qd.graph_parallel` (concurrent sections) | concurrent | concurrent | runs serially | runs serially | runs serially | runs serially | -AMDGPU `graph_do_while` falls back to the host-side loop because HIP does not currently expose conditional / while graph nodes (as of ROCm 7.2). +AMDGPU `graph_do_while` falls back to a host-side loop because HIP does not currently expose conditional / while graph nodes (as of ROCm 7.2). Nested and sibling `graph_do_while` loops (and mixing `graph_do_while` with top-level `for`-loops) are **experimental** for now — see [Nested loops and mixing with for-loops](#nested-loops-and-mixing-with-for-loops). @@ -123,7 +123,7 @@ def converge(x: qd.types.ndarray(qd.f32, ndim=1), ### Do-while semantics -`graph_do_while` has **do-while** semantics: the kernel body always executes at least once before the condition is checked. This matches the behavior of CUDA conditional while nodes. The flag value must be >= 1 at launch time. Passing 0 with a kernel that decrements the counter will cause an infinite loop. +`graph_do_while` has **do-while** semantics: the kernel body always executes at least once before the condition is checked. This matches the behaviour of CUDA conditional while nodes. The flag value must be >= 1 at launch time. Passing 0 with a kernel that decrements the counter will cause an infinite loop. ### ndarray vs field @@ -468,9 +468,9 @@ In this case, our recommendation is: - if you need optimum 100% performance on unsupported platforms, then consider PRing onto quadrants an optimized graph implementation for your target platform - for example it could somehow run MAX_ITER iterations anyway, similar to the earlier hand-rolled version, but via the graph abstraction, hence allowing the code to be compact, cross-platform, and also optimally fast -## Concurrent branches with `qd.graph_parallel_context` *(experimental)* +## Concurrent sections with `qd.graph_parallel_context` *(experimental)* -`qd.checkpoint` and `graph_do_while` change *which* kernels run and *how many times*; `qd.graph_parallel_context` changes *how* a graph's kernels are scheduled relative to each other. By default the kernels captured in a `graph=True` kernel run as a single dependency chain (each waits for the previous one), even when they are completely independent. A `with qd.graph_parallel_context():` region lets you declare independent stages so the graph runs them on **parallel streams**. +`qd.checkpoint` and `graph_do_while` change *which* kernels run and *how many times*; `qd.graph_parallel_context` changes *how* a graph's kernels are scheduled relative to each other. By default the kernels captured in a `graph=True` kernel run as a single dependency chain (each waits for the previous one), even when they are completely independent. A `with qd.graph_parallel_context():` region lets you declare independent stages so the graph runs them concurrently. `qd.graph_parallel_context` is honoured by the graph builder so it composes with `graph=True` and `graph_do_while`. @@ -478,37 +478,36 @@ In this case, our recommendation is: @qd.kernel(graph=True) def step(...): while qd.graph_do_while(ncond): - assemble_shared(...) # serial: feeds both branches + assemble_shared(...) # serial: feeds both sections - with qd.graph_parallel_context(): # fork: branches run concurrently + with qd.graph_parallel_context(): # fork: sections run concurrently with qd.graph_parallel(): # point-triangle contacts pt_assemble(...) pt_hessian(...) with qd.graph_parallel(): # edge-edge contacts (independent of pt) ee_assemble(...) ee_hessian(...) - # join: everything below waits for BOTH branches to finish + # join: everything below waits for BOTH sections to finish merge_hessians(...) precondition(...) ``` ### Semantics -- **Fork / join.** Every `qd.graph_parallel()` branch in the region forks from the work that precedes the region. All branches must finish before any work *after* the region begins (the join). On CUDA the join is a single empty graph node depending on every branch's last kernel. -- **Branches are independent — you guarantee it.** Calls *within* a branch keep their program order, but calls in *different* branches have no ordering. The branches must be data-race free with respect to one another: no branch may read what another writes, and no two branches may write the same memory. Quadrants does not check this; getting it wrong gives nondeterministic results, exactly like `qd.stream_parallel()`. +- **Fork / join.** Every `qd.graph_parallel()` section in the region forks from the work that precedes the region. All sections must finish before any work *after* the region begins (the join). On CUDA the join is a single empty graph node depending on every section's last kernel. +- **Sections are independent — you guarantee it.** Calls *within* a section keep their program order, but calls in *different* sections have no ordering. The sections must be data-race free with respect to one another: no section may read what another writes, and no two sections may write the same memory. Quadrants does not check this; getting it wrong gives nondeterministic results. ### Restrictions (enforced at kernel compile time) -- Must be used inside `@qd.kernel(graph=True)`. -- A region body may contain only `with qd.graph_parallel():` blocks, optionally wrapped in `if qd.static(...)` (so an optional branch can be compiled in or out — e.g. enabling edge-edge contacts only when a feature flag is set). -- `qd.graph_parallel()` may appear only directly inside a `qd.graph_parallel_context()` region. -- Regions cannot be nested, and a branch body must be straight-line task work — no `qd.graph_do_while`, `qd.checkpoint`, or nested `qd.graph_parallel_context` inside a branch (a region may, however, sit inside a `qd.graph_do_while` body, as shown above). +- `qd.graph_parallel_context` may contain only `with qd.graph_parallel():` blocks, optionally wrapped in `if qd.static(...)` (so an optional section can be compiled in or out — e.g. enabling edge-edge contacts only when a feature flag is set). +- `qd.graph_parallel()` may appear only directly inside a `qd.graph_parallel_context()`. +- `qd.graph_parallel_context` cannot be nested, and a section body must be straight-line task work — no `qd.graph_do_while`, `qd.checkpoint`, or nested `qd.graph_parallel_context` inside a section (a `qd.graph_parallel_context` may, however, sit inside a `qd.graph_do_while` body, as shown above). ### Backend behaviour -| backend | result | scheduling | -| --- | --- | --- | -| CUDA (graph path) | correct | branches run **concurrently** on parallel streams | -| AMDGPU / CPU / Vulkan / Metal | correct | branches run **serially** (the concurrency tags are honoured only by the graph builder today) | +| backend | scheduling | +| --- | --- | +| CUDA | sections run **concurrently** | +| AMDGPU / CPU / Vulkan / Metal | sections run **serially** | -Because branches are independent by construction, running them serially on the other backends produces identical results — only the scheduling differs. +Because sections are independent by construction, running them serially produces identical results — only the scheduling differs. diff --git a/docs/source/user_guide/streams.md b/docs/source/user_guide/streams.md index f1c30e0202..19a40f48ee 100644 --- a/docs/source/user_guide/streams.md +++ b/docs/source/user_guide/streams.md @@ -48,7 +48,7 @@ combine() # runs after compute_ab() returns — a[] and b[] are ready Consecutive `with qd.stream_parallel():` blocks run concurrently. Multiple for loops within a single block share a stream and run serially on it. All streams are synchronized before the kernel returns. -> **For `graph=True` kernels**, use [`qd.graph_parallel_context` / `qd.graph_parallel`](graph.md#concurrent-branches-with-qdgraph_parallel_context-experimental) instead — `stream_parallel` is not compatible with graphs (see [Limitations](#limitations)). `qd.graph_parallel_context` expresses the same "run these independent sequences concurrently" idea but is honoured by the graph builder. +> **For `graph=True` kernels**, use [`qd.graph_parallel_context` / `qd.graph_parallel`](graph.md#concurrent-sections-with-qdgraph_parallel_context-experimental) instead — `stream_parallel` is not compatible with graphs (see [Limitations](#limitations)). `qd.graph_parallel_context` expresses the same "run these independent sequences concurrently" idea but is honoured by the graph builder. ### Restrictions From c57e452a052e15edfdc8ae7a30e7bbbce964dc68 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 23 Jun 2026 08:08:08 -0700 Subject: [PATCH 09/25] refactor(graph): use 'parallel section' for graph_parallel internals and docs Rename the internal "branch" concept to "parallel section" everywhere it referred to a qd.graph_parallel() block: AST helpers (_is_branch_call/_build_branch_with/_in_branch and _is_branch_with), the CUDA graph_manager `branches` vector, misc.py docstrings, and all related comments. Update graph.md/streams.md to say "parallel section(s)" (heading + anchor included). No behaviour change. --- docs/source/user_guide/graph.md | 24 +++++------ docs/source/user_guide/streams.md | 2 +- python/quadrants/lang/ast/ast_transformer.py | 40 +++++++++---------- .../function_def_transformer.py | 18 ++++----- python/quadrants/lang/misc.py | 37 ++++++++--------- quadrants/runtime/cuda/graph_manager.cpp | 28 +++++++------ quadrants/runtime/cuda/graph_manager.h | 4 +- 7 files changed, 78 insertions(+), 75 deletions(-) diff --git a/docs/source/user_guide/graph.md b/docs/source/user_guide/graph.md index 3fdd79d716..8b400a2ba1 100644 --- a/docs/source/user_guide/graph.md +++ b/docs/source/user_guide/graph.md @@ -11,7 +11,7 @@ Graphs reduce kernel launch overhead by capturing a sequence of GPU operations i | `graph=True` | hardware accelerated | hardware accelerated | hardware accelerated | runs (no acceleration) | runs (no acceleration) | runs (no acceleration) | | `qd.graph_do_while` | hardware accelerated | host fallback | host fallback | host fallback | host fallback | host fallback | | `qd.checkpoint` | GPU-side | GPU-side | GPU-side | GPU-side | GPU-side | host-side | -| `qd.graph_parallel_context` / `qd.graph_parallel` (concurrent sections) | concurrent | concurrent | runs serially | runs serially | runs serially | runs serially | +| `qd.graph_parallel_context` / `qd.graph_parallel` (parallel sections) | concurrent | concurrent | runs serially | runs serially | runs serially | runs serially | AMDGPU `graph_do_while` falls back to a host-side loop because HIP does not currently expose conditional / while graph nodes (as of ROCm 7.2). @@ -468,7 +468,7 @@ In this case, our recommendation is: - if you need optimum 100% performance on unsupported platforms, then consider PRing onto quadrants an optimized graph implementation for your target platform - for example it could somehow run MAX_ITER iterations anyway, similar to the earlier hand-rolled version, but via the graph abstraction, hence allowing the code to be compact, cross-platform, and also optimally fast -## Concurrent sections with `qd.graph_parallel_context` *(experimental)* +## Parallel sections with `qd.graph_parallel_context` *(experimental)* `qd.checkpoint` and `graph_do_while` change *which* kernels run and *how many times*; `qd.graph_parallel_context` changes *how* a graph's kernels are scheduled relative to each other. By default the kernels captured in a `graph=True` kernel run as a single dependency chain (each waits for the previous one), even when they are completely independent. A `with qd.graph_parallel_context():` region lets you declare independent stages so the graph runs them concurrently. @@ -478,36 +478,36 @@ In this case, our recommendation is: @qd.kernel(graph=True) def step(...): while qd.graph_do_while(ncond): - assemble_shared(...) # serial: feeds both sections + assemble_shared(...) # serial: feeds both parallel sections - with qd.graph_parallel_context(): # fork: sections run concurrently + with qd.graph_parallel_context(): # fork: parallel sections run concurrently with qd.graph_parallel(): # point-triangle contacts pt_assemble(...) pt_hessian(...) with qd.graph_parallel(): # edge-edge contacts (independent of pt) ee_assemble(...) ee_hessian(...) - # join: everything below waits for BOTH sections to finish + # join: everything below waits for BOTH parallel sections to finish merge_hessians(...) precondition(...) ``` ### Semantics -- **Fork / join.** Every `qd.graph_parallel()` section in the region forks from the work that precedes the region. All sections must finish before any work *after* the region begins (the join). On CUDA the join is a single empty graph node depending on every section's last kernel. -- **Sections are independent — you guarantee it.** Calls *within* a section keep their program order, but calls in *different* sections have no ordering. The sections must be data-race free with respect to one another: no section may read what another writes, and no two sections may write the same memory. Quadrants does not check this; getting it wrong gives nondeterministic results. +- **Fork / join.** Every parallel section in the region forks from the work that precedes the region. All parallel sections must finish before any work *after* the region begins (the join). On CUDA the join is a single empty graph node depending on every parallel section's last kernel. +- **Parallel sections are independent — you guarantee it.** Calls *within* a parallel section keep their program order, but calls in *different* parallel sections have no ordering. The parallel sections must be data-race free with respect to one another: no parallel section may read what another writes, and no two parallel sections may write the same memory. Quadrants does not check this; getting it wrong gives nondeterministic results. ### Restrictions (enforced at kernel compile time) -- `qd.graph_parallel_context` may contain only `with qd.graph_parallel():` blocks, optionally wrapped in `if qd.static(...)` (so an optional section can be compiled in or out — e.g. enabling edge-edge contacts only when a feature flag is set). +- `qd.graph_parallel_context` may contain only `with qd.graph_parallel():` blocks, optionally wrapped in `if qd.static(...)` (so an optional parallel section can be compiled in or out — e.g. enabling edge-edge contacts only when a feature flag is set). - `qd.graph_parallel()` may appear only directly inside a `qd.graph_parallel_context()`. -- `qd.graph_parallel_context` cannot be nested, and a section body must be straight-line task work — no `qd.graph_do_while`, `qd.checkpoint`, or nested `qd.graph_parallel_context` inside a section (a `qd.graph_parallel_context` may, however, sit inside a `qd.graph_do_while` body, as shown above). +- `qd.graph_parallel_context` cannot be nested, and a parallel section body must be straight-line task work — no `qd.graph_do_while`, `qd.checkpoint`, or nested `qd.graph_parallel_context` inside a parallel section (a `qd.graph_parallel_context` may, however, sit inside a `qd.graph_do_while` body, as shown above). ### Backend behaviour | backend | scheduling | | --- | --- | -| CUDA | sections run **concurrently** | -| AMDGPU / CPU / Vulkan / Metal | sections run **serially** | +| CUDA | parallel sections run **concurrently** | +| AMDGPU / CPU / Vulkan / Metal | parallel sections run **serially** | -Because sections are independent by construction, running them serially produces identical results — only the scheduling differs. +Because parallel sections are independent by construction, running them serially produces identical results — only the scheduling differs. diff --git a/docs/source/user_guide/streams.md b/docs/source/user_guide/streams.md index 19a40f48ee..9c1c4f28b1 100644 --- a/docs/source/user_guide/streams.md +++ b/docs/source/user_guide/streams.md @@ -48,7 +48,7 @@ combine() # runs after compute_ab() returns — a[] and b[] are ready Consecutive `with qd.stream_parallel():` blocks run concurrently. Multiple for loops within a single block share a stream and run serially on it. All streams are synchronized before the kernel returns. -> **For `graph=True` kernels**, use [`qd.graph_parallel_context` / `qd.graph_parallel`](graph.md#concurrent-sections-with-qdgraph_parallel_context-experimental) instead — `stream_parallel` is not compatible with graphs (see [Limitations](#limitations)). `qd.graph_parallel_context` expresses the same "run these independent sequences concurrently" idea but is honoured by the graph builder. +> **For `graph=True` kernels**, use [`qd.graph_parallel_context` / `qd.graph_parallel`](graph.md#parallel-sections-with-qdgraph_parallel_context-experimental) instead — `stream_parallel` is not compatible with graphs (see [Limitations](#limitations)). `qd.graph_parallel_context` expresses the same "run these independent sequences concurrently" idea but is honoured by the graph builder. ### Restrictions diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 84c001342b..765dba3758 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -1388,16 +1388,16 @@ def _is_graph_parallel_context_call(node: ast.expr) -> bool: return True @staticmethod - def _is_branch_call(node: ast.expr) -> bool: - """If *node* is a ``qd.graph_parallel()`` (branch) call return True, else False. The call shape is - validated here so misuse raises at the ``with`` site rather than later.""" + def _is_parallel_section_call(node: ast.expr) -> bool: + """If *node* is a ``qd.graph_parallel()`` (a parallel section) call return True, else False. The + call shape is validated here so misuse raises at the ``with`` site rather than later.""" if not isinstance(node, ast.Call): return False func = node.func - is_branch = (isinstance(func, ast.Attribute) and func.attr == "graph_parallel") or ( + is_parallel_section = (isinstance(func, ast.Attribute) and func.attr == "graph_parallel") or ( isinstance(func, ast.Name) and func.id == "graph_parallel" ) - if not is_branch: + if not is_parallel_section: return False if node.args or node.keywords: raise QuadrantsSyntaxError("qd.graph_parallel() takes no arguments") @@ -1649,8 +1649,8 @@ def build_With(ctx: ASTTransformerFuncContext, node: ast.With) -> None: if ASTTransformer._is_graph_parallel_context_call(item.context_expr): return ASTTransformer._build_graph_parallel_context_with(ctx, node) - if ASTTransformer._is_branch_call(item.context_expr): - return ASTTransformer._build_branch_with(ctx, node) + if ASTTransformer._is_parallel_section_call(item.context_expr): + return ASTTransformer._build_parallel_section_with(ctx, node) if not FunctionDefTransformer._is_stream_parallel_with(node, ctx.global_vars): raise QuadrantsSyntaxError( @@ -1680,8 +1680,8 @@ def _build_graph_parallel_context_with(ctx: ASTTransformerFuncContext, node: ast Validates the use-site (kernel must be graph=True, no nesting) and that the region body contains only ``with qd.graph_parallel():`` blocks, then walks the body. The region emits no IR tag of its - own -- each branch inside lowers to a stream-parallel group (via begin/end_stream_parallel), and - the graph builder forks the distinct groups in a contiguous run and joins them. Regions are kept + own -- each parallel section inside lowers to a stream-parallel group (via begin/end_stream_parallel), + and the graph builder forks the distinct groups in a contiguous run and joins them. Regions are kept apart by the serial work between them (see d3_0_graph_parallel_impl.md).""" if not ctx.is_kernel: raise QuadrantsSyntaxError( @@ -1692,7 +1692,7 @@ def _build_graph_parallel_context_with(ctx: ASTTransformerFuncContext, node: ast raise QuadrantsSyntaxError("qd.graph_parallel_context() requires @qd.kernel(graph=True)") if getattr(ctx, "_in_graph_parallel_context", False): raise QuadrantsSyntaxError("qd.graph_parallel_context() regions cannot be nested") - if getattr(ctx, "_in_branch", False): + if getattr(ctx, "_in_parallel_section", False): raise QuadrantsSyntaxError( "qd.graph_parallel_context() cannot appear inside a qd.graph_parallel() body" ) @@ -1707,17 +1707,17 @@ def _build_graph_parallel_context_with(ctx: ASTTransformerFuncContext, node: ast @staticmethod def _validate_graph_parallel_context_body(stmts: list[ast.stmt]) -> None: """A qd.graph_parallel_context() region body may contain only `with qd.graph_parallel():` blocks, - optionally wrapped in compile-time `if qd.static(...)` branches (the optional-branch pattern, e.g. + optionally wrapped in compile-time `if qd.static(...)` (the optional parallel-section pattern, e.g. qipc's ENABLE_EE). Docstrings / coverage probes / `pass` are allowed. Anything else (a bare - for-loop, assignment, etc.) is a serial task that would silently fall outside any branch, so - reject it.""" + for-loop, assignment, etc.) is a serial task that would silently fall outside any parallel section, + so reject it.""" for i, stmt in enumerate(stmts): if FunctionDefTransformer._is_docstring(stmt, i) or FunctionDefTransformer._is_coverage_probe(stmt): continue if isinstance(stmt, ast.Pass): continue if isinstance(stmt, ast.With) and stmt.items: - if ASTTransformer._is_branch_call(stmt.items[0].context_expr): + if ASTTransformer._is_parallel_section_call(stmt.items[0].context_expr): continue if isinstance(stmt, ast.If): ASTTransformer._validate_graph_parallel_context_body(stmt.body) @@ -1730,23 +1730,23 @@ def _validate_graph_parallel_context_body(stmts: list[ast.stmt]) -> None: ) @staticmethod - def _build_branch_with(ctx: ASTTransformerFuncContext, node: ast.With) -> None: - """Handles ``with qd.graph_parallel():`` branch members of a ``qd.graph_parallel_context()`` region. + def _build_parallel_section_with(ctx: ASTTransformerFuncContext, node: ast.With) -> None: + """Handles a ``with qd.graph_parallel():`` parallel section of a ``qd.graph_parallel_context()`` region. - Reuses the stream-parallel tagging: begin_stream_parallel() assigns this branch a fresh + Reuses the stream-parallel tagging: begin_stream_parallel() assigns this parallel section a fresh ``stream_parallel_group_id`` that every for-loop in the body inherits, so the offloaded tasks - carry the branch id all the way to the graph builder.""" + carry the parallel-section id all the way to the graph builder.""" if not getattr(ctx, "_in_graph_parallel_context", False): raise QuadrantsSyntaxError( "qd.graph_parallel() can only be used directly inside a qd.graph_parallel_context() region" ) - ctx._in_branch = True + ctx._in_parallel_section = True ctx.ast_builder.begin_stream_parallel() try: build_stmts(ctx, node.body) finally: ctx.ast_builder.end_stream_parallel() - ctx._in_branch = False + ctx._in_parallel_section = False return None @staticmethod diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index b2de9d8132..070814c5f4 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -617,9 +617,9 @@ def _is_graph_parallel_context_with(stmt: ast.stmt) -> bool: return False @staticmethod - def _is_branch_with(stmt: ast.stmt) -> bool: - """Syntactic check matching ASTTransformer._is_branch_call: a ``with qd.graph_parallel(...):`` - branch member of a ``qd.graph_parallel_context()`` region.""" + def _is_parallel_section_with(stmt: ast.stmt) -> bool: + """Syntactic check matching ASTTransformer._is_parallel_section_call: a ``with qd.graph_parallel(...):`` + parallel section of a ``qd.graph_parallel_context()`` region.""" if not isinstance(stmt, ast.With) or len(stmt.items) != 1: return False ctx_expr = stmt.items[0].context_expr @@ -693,15 +693,15 @@ def _validate_graph_do_while_stmt_list(stmts: list[ast.stmt], is_kernel_top: boo continue if FunctionDefTransformer._is_graph_parallel_context_with(stmt): # A `with qd.graph_parallel_context()` region groups concurrent `with qd.graph_parallel()` - # branches; it is a legal sibling of for-loops / checkpoints. Its body must be branch blocks - # (optionally under `if qd.static(...)`); the full check is in - # ASTTransformer._build_graph_parallel_context_with. Each branch body is task territory, - # validated here with the in-loop rules. Descend through `if` members so branches inside an - # optional-branch `if qd.static(...)` are reached too. + # parallel sections; it is a legal sibling of for-loops / checkpoints. Its body must be + # parallel-section blocks (optionally under `if qd.static(...)`); the full check is in + # ASTTransformer._build_graph_parallel_context_with. Each parallel section's body is task + # territory, validated here with the in-loop rules. Descend through `if` members so parallel + # sections inside an optional `if qd.static(...)` are reached too. pending = list(stmt.body) while pending: member = pending.pop() - if FunctionDefTransformer._is_branch_with(member): + if FunctionDefTransformer._is_parallel_section_with(member): FunctionDefTransformer._validate_graph_do_while_stmt_list(member.body, is_kernel_top=False) elif isinstance(member, ast.If): pending.extend(member.body) diff --git a/python/quadrants/lang/misc.py b/python/quadrants/lang/misc.py index c251319417..fec45d7078 100644 --- a/python/quadrants/lang/misc.py +++ b/python/quadrants/lang/misc.py @@ -754,29 +754,30 @@ def graph_do_while(condition) -> bool: @contextmanager def graph_parallel_context(): - """Opens a fork/join region whose ``qd.graph_parallel()`` branches run concurrently. + """Opens a fork/join region whose ``qd.graph_parallel()`` parallel sections run concurrently. - Used as ``with qd.graph_parallel_context():`` inside a ``@qd.kernel(graph=True)`` kernel. The - region's body must contain only ``with qd.graph_parallel():`` blocks. Each branch is an independent - sequence of work; the branches have no ordering relative to each other and may execute concurrently, - while everything after the region waits for *all* branches to finish (the join). This is the graph - analogue of ``qd.stream_parallel()`` (which is for non-graph kernels): it lets independent stages -- - e.g. qipc's point-triangle and edge-edge assembly -- overlap inside a captured graph. + Used as ``with qd.graph_parallel_context():`` inside a ``@qd.kernel(graph=True)`` kernel. The region's + body must contain only ``with qd.graph_parallel():`` blocks. Each parallel section is an independent + sequence of work; the parallel sections have no ordering relative to each other and may execute + concurrently, while everything after the region waits for *all* parallel sections to finish (the join). + This is the graph analogue of ``qd.stream_parallel()`` (which is for non-graph kernels): it lets + independent stages -- e.g. qipc's point-triangle and edge-edge assembly -- overlap inside a captured + graph. - Concurrency contract (the author's responsibility): branches must be data-race free with respect to - one another (no branch reads what another writes, no two branches write the same location). Calls - *within* a branch keep their program order. + Concurrency contract (the author's responsibility): parallel sections must be data-race free with + respect to one another (no parallel section reads what another writes, no two parallel sections write + the same location). Calls *within* a parallel section keep their program order. Backend behaviour: - - CUDA SM graph path: branches become independent graph chains joined by an empty node, so the - runtime schedules them on parallel streams (real overlap). - - CPU / Vulkan / Metal / AMDGPU graph: correct results, branches run serially (the concurrency - tags are honoured only by the graph builder today). + - CUDA SM graph path: parallel sections become independent graph chains joined by an empty node, so + the runtime schedules them on parallel streams (real overlap). + - CPU / Vulkan / Metal / AMDGPU graph: correct results, parallel sections run serially (the + concurrency tags are honoured only by the graph builder today). Restrictions (enforced at kernel compile time): - Must be used inside ``@qd.kernel(graph=True)``. - The region body may contain only ``with qd.graph_parallel():`` blocks. - - Regions cannot be nested, and a branch body must be straight-line task work (no nested + - Regions cannot be nested, and a parallel section body must be straight-line task work (no nested ``qd.graph_do_while``, ``qd.checkpoint``, or ``qd.graph_parallel_context``). This function should not be called directly at runtime; it is recognised and transformed during AST @@ -789,11 +790,11 @@ def graph_parallel_context(): @contextmanager def graph_parallel(): - """Declares one concurrent branch of an enclosing ``qd.graph_parallel_context()`` region. + """Declares one parallel section of an enclosing ``qd.graph_parallel_context()`` region. Used as ``with qd.graph_parallel():`` directly inside a ``with qd.graph_parallel_context():`` block. - The branch's body is an independent sequence of work that may run concurrently with the region's - other branches. + The parallel section's body is an independent sequence of work that may run concurrently with the + region's other parallel sections. See ``qd.graph_parallel_context()`` for the full contract and backend behaviour. """ diff --git a/quadrants/runtime/cuda/graph_manager.cpp b/quadrants/runtime/cuda/graph_manager.cpp index 046e325115..d932abb8aa 100644 --- a/quadrants/runtime/cuda/graph_manager.cpp +++ b/quadrants/runtime/cuda/graph_manager.cpp @@ -406,18 +406,19 @@ void GraphManager::build_level(int parent_id, // --- A qd.graph_parallel_context() fork/join region: a contiguous run of this level's direct, // non-checkpoint tasks tagged with a nonzero stream_parallel_group_id (set by qd.graph_parallel()). Each - // distinct group id is one branch; branches fork from the region's entry (`prev_node`), run their - // tasks in order, and join into a single empty node so downstream work waits for all of them. CUDA's - // graph executor schedules the independent branch chains on separate streams -> real overlap. --- + // distinct group id is one parallel section; the parallel sections fork from the region's entry + // (`prev_node`), run their tasks in order, and join into a single empty node so downstream work waits for + // all of them. CUDA's graph executor schedules the independent parallel-section chains on separate + // streams -> real overlap. --- if (tasks[cursor].stream_parallel_group_id != 0 && tasks[cursor].checkpoint_id < 0) { int run_end = cursor; while (run_end < end && tasks[run_end].graph_do_while_level_id == parent_id && tasks[run_end].checkpoint_id < 0 && tasks[run_end].stream_parallel_group_id != 0) { run_end++; } - // Bucket the run's tasks by branch id, preserving first-seen (declaration) order. + // Bucket the run's tasks by parallel-section id, preserving first-seen (declaration) order. std::vector group_ids; - std::vector> branches; + std::vector> parallel_sections; for (int t = cursor; t < run_end; t++) { const int g = tasks[t].stream_parallel_group_id; int idx = -1; @@ -430,16 +431,16 @@ void GraphManager::build_level(int parent_id, if (idx < 0) { idx = (int)group_ids.size(); group_ids.push_back(g); - branches.emplace_back(); + parallel_sections.emplace_back(); } - branches[idx].push_back(t); + parallel_sections[idx].push_back(t); } void *ctx_ptr = &cached.persistent_ctx; std::vector tails; - tails.reserve(branches.size()); - for (auto &br : branches) { - void *bp = prev_node; // every branch forks from the region entry dependency - for (int t : br) { + tails.reserve(parallel_sections.size()); + for (auto &ps : parallel_sections) { + void *bp = prev_node; // every parallel section forks from the region entry dependency + for (int t : ps) { bp = add_kernel_node(target_graph, bp, cuda_module->lookup_function(tasks[t].name), (unsigned int)tasks[t].grid_dim, (unsigned int)tasks[t].block_dim, (unsigned int)tasks[t].dynamic_shared_array_bytes, &ctx_ptr); @@ -447,8 +448,9 @@ void GraphManager::build_level(int parent_id, } tails.push_back(bp); } - // Join. A single-branch region (e.g. an optional branch compiled out) has nothing to join, so just - // continue the chain from its tail; otherwise collect all tails into one empty successor node. + // Join. A single-parallel-section region (e.g. an optional parallel section compiled out) has nothing + // to join, so just continue the chain from its tail; otherwise collect all tails into one empty + // successor node. if (tails.size() == 1) { prev_node = tails[0]; } else { diff --git a/quadrants/runtime/cuda/graph_manager.h b/quadrants/runtime/cuda/graph_manager.h index d5681b37ff..bb724c6cad 100644 --- a/quadrants/runtime/cuda/graph_manager.h +++ b/quadrants/runtime/cuda/graph_manager.h @@ -202,8 +202,8 @@ class GraphManager { unsigned int shared_mem, void **kernel_params); // Add an empty (no-op) node to `graph` depending on every node in `deps`. Used as the join point of a - // qd.graph_parallel_context() region: it has no work but collects all branch tails into a single successor so - // downstream nodes wait for every branch. `deps` must be non-empty. + // qd.graph_parallel_context() region: it has no work but collects all parallel-section tails into a single + // successor so downstream nodes wait for every parallel section. `deps` must be non-empty. void *add_empty_node(void *graph, const std::vector &deps); // Recursively build the nodes for graph_do_while level `parent_id` (-1 = kernel top level) over the task range // [begin, end) into `target_graph` (the body graph of `parent_id`, or the root graph for -1). Direct tasks become From 4fdf6e457d1c53e16aae02eee8b09ae7f0a01b19 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 23 Jun 2026 08:31:00 -0700 Subject: [PATCH 10/25] test(graph): use 'parallel section' in graph_parallel tests Rename the remaining "branch" wording in test_graph_parallel.py: test function names (e.g. test_graph_parallel_two_branches -> _two_sections) and all docstrings/comments now say "parallel section(s)", matching the source/docs rename. --- tests/python/test_graph_parallel.py | 68 +++++++++++++++-------------- 1 file changed, 36 insertions(+), 32 deletions(-) diff --git a/tests/python/test_graph_parallel.py b/tests/python/test_graph_parallel.py index 8fabbf4b8a..9c52ac4fb9 100644 --- a/tests/python/test_graph_parallel.py +++ b/tests/python/test_graph_parallel.py @@ -1,13 +1,14 @@ -"""Tests for qd.graph_parallel_context / qd.graph_parallel -- concurrent fork/join branches in graph kernels. +"""Tests for qd.graph_parallel_context / qd.graph_parallel -- concurrent fork/join parallel sections in +graph kernels. -`with qd.graph_parallel_context():` opens a fork/join region whose `with qd.graph_parallel():` members are independent -sequences of work. On the graph path the branches become independent graph chains joined by a single -empty node, so the runtime schedules them on parallel streams; on other backends (CPU / AMDGPU / Vulkan / -Metal) they run serially but produce identical results. +`with qd.graph_parallel_context():` opens a fork/join region whose `with qd.graph_parallel():` members are +independent sequences of work. On the graph path the parallel sections become independent graph chains +joined by a single empty node, so the runtime schedules them on parallel streams; on other backends +(CPU / AMDGPU / Vulkan / Metal) they run serially but produce identical results. The behavioural assertions (disjoint-array correctness) hold on every backend. The graph-structure -assertions (node counts: one kernel node per branch task + one empty join node) only apply where the -builder forks/joins (CUDA today), so they are guarded by `_on_cuda()`. +assertions (node counts: one kernel node per parallel-section task + one empty join node) only apply where +the builder forks/joins (CUDA today), so they are guarded by `_on_cuda()`. """ import numpy as np @@ -51,10 +52,10 @@ def test_graph_parallel_is_no_op_outside_kernels(): @test_utils.test() -def test_graph_parallel_two_branches(): - """Two branches write disjoint arrays; a serial loop after the region reads both (so it depends on - the join). Results must match the serial reference on every backend; on CUDA the graph has one node - per task plus one empty join node.""" +def test_graph_parallel_two_sections(): + """Two parallel sections write disjoint arrays; a serial loop after the region reads both (so it + depends on the join). Results must match the serial reference on every backend; on CUDA the graph has + one node per task plus one empty join node.""" n = 1024 @qd.kernel(graph=True) @@ -85,7 +86,7 @@ def k( num_tasks = _num_offloaded_tasks() if _on_cuda(): # One graph node per offloaded task (each dynamic-bound loop is a bound-compute serial + a - # range_for, both in the branch) plus exactly one empty join node for the single region. + # range_for, both in the parallel section) plus exactly one empty join node for the single region. assert _graph_num_nodes() == num_tasks + 1 np.testing.assert_allclose(x.to_numpy(), 1.0) @@ -100,8 +101,8 @@ def k( @test_utils.test() -def test_graph_parallel_three_branches(): - """Fan-out of three independent branches; one empty join node.""" +def test_graph_parallel_three_sections(): + """Fan-out of three independent parallel sections; one empty join node.""" n = 256 @qd.kernel(graph=True) @@ -130,7 +131,7 @@ def k( k(a, b, c) num_tasks = _num_offloaded_tasks() if _on_cuda(): - assert _graph_num_nodes() == num_tasks + 1 # three branches + one join + assert _graph_num_nodes() == num_tasks + 1 # three parallel sections + one join np.testing.assert_allclose(a.to_numpy(), 1.0) np.testing.assert_allclose(b.to_numpy(), 2.0) @@ -138,9 +139,10 @@ def k( @test_utils.test() -def test_graph_parallel_multi_loop_branches(): - """Each branch contains several loops; they must chain in order inside the branch while the two - branches run independently. Branch tasks = 4, plus one join node on CUDA.""" +def test_graph_parallel_multi_loop_sections(): + """Each parallel section contains several loops; they must chain in order inside the parallel section + while the two parallel sections run independently. Parallel-section tasks = 4, plus one join node on + CUDA.""" n = 128 @qd.kernel(graph=True) @@ -165,16 +167,17 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1), y: qd.types.ndarray(qd.f32, ndim=1)): k(x, y) num_tasks = _num_offloaded_tasks() if _on_cuda(): - assert _graph_num_nodes() == num_tasks + 1 # all branch tasks + one join + assert _graph_num_nodes() == num_tasks + 1 # all parallel-section tasks + one join np.testing.assert_allclose(x.to_numpy(), 2.0) # (0+1)*2 np.testing.assert_allclose(y.to_numpy(), 12.0) # (0+3)*4 @test_utils.test() -def test_graph_parallel_single_branch_no_join(): - """A region with a single branch (e.g. an optional branch compiled out) needs no join: it degenerates - to a plain chain, so the node count equals the number of branch tasks (no extra empty node).""" +def test_graph_parallel_single_section_no_join(): + """A region with a single parallel section (e.g. an optional parallel section compiled out) needs no + join: it degenerates to a plain chain, so the node count equals the number of parallel-section tasks + (no extra empty node).""" n = 256 @qd.kernel(graph=True) @@ -190,15 +193,16 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1)): k(x) num_tasks = _num_offloaded_tasks() if _on_cuda(): - assert _graph_num_nodes() == num_tasks # single branch -> plain chain, no join node + assert _graph_num_nodes() == num_tasks # single parallel section -> plain chain, no join node np.testing.assert_allclose(x.to_numpy(), 5.0) @test_utils.test() -def test_graph_parallel_optional_branch_static_if(): - """The qipc ENABLE_EE pattern: a branch wrapped in `if qd.static(...)`. When the flag is False the - branch is compiled out (region has one branch -> no join); when True both branches run.""" +def test_graph_parallel_optional_section_static_if(): + """The qipc ENABLE_EE pattern: a parallel section wrapped in `if qd.static(...)`. When the flag is + False the parallel section is compiled out (region has one parallel section -> no join); when True both + parallel sections run.""" n = 128 @qd.kernel(graph=True) @@ -229,15 +233,15 @@ def k_on(x: qd.types.ndarray(qd.f32, ndim=1), y: qd.types.ndarray(qd.f32, ndim=1 y.from_numpy(np.zeros(n, dtype=np.float32)) k_off(x, y) if _on_cuda(): - assert _graph_num_nodes() == _num_offloaded_tasks() # single branch -> no join + assert _graph_num_nodes() == _num_offloaded_tasks() # single parallel section -> no join np.testing.assert_allclose(x.to_numpy(), 1.0) - np.testing.assert_allclose(y.to_numpy(), 0.0) # EE branch compiled out + np.testing.assert_allclose(y.to_numpy(), 0.0) # EE parallel section compiled out x.from_numpy(np.zeros(n, dtype=np.float32)) y.from_numpy(np.zeros(n, dtype=np.float32)) k_on(x, y) if _on_cuda(): - assert _graph_num_nodes() == _num_offloaded_tasks() + 1 # two branches + join + assert _graph_num_nodes() == _num_offloaded_tasks() + 1 # two parallel sections + join np.testing.assert_allclose(x.to_numpy(), 1.0) np.testing.assert_allclose(y.to_numpy(), 1.0) @@ -245,7 +249,7 @@ def k_on(x: qd.types.ndarray(qd.f32, ndim=1), y: qd.types.ndarray(qd.f32, ndim=1 @test_utils.test() def test_graph_parallel_inside_graph_do_while(): """A fork/join region inside a qd.graph_do_while loop body must be correct across iterations: each - iteration runs both branches, then decrements the counter.""" + iteration runs both parallel sections, then decrements the counter.""" n = 64 iters = 5 @@ -281,7 +285,7 @@ def k( @test_utils.test() -def test_graph_parallel_branch_outside_region_raises(): +def test_graph_parallel_section_outside_region_raises(): @qd.kernel(graph=True) def k(x: qd.types.ndarray(qd.f32, ndim=1)): with qd.graph_parallel(): @@ -310,7 +314,7 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1)): @test_utils.test() -def test_graph_parallel_non_branch_body_raises(): +def test_graph_parallel_non_section_body_raises(): @qd.kernel(graph=True) def k(x: qd.types.ndarray(qd.f32, ndim=1)): with qd.graph_parallel_context(): From c334f8bdbdc62eec042b2ee5845d2cdeadbd8804 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 23 Jun 2026 08:36:08 -0700 Subject: [PATCH 11/25] test(graph): clearer name for graph_parallel_context body-validation test Rename test_graph_parallel_non_section_body_raises -> test_graph_parallel_context_non_graph_parallel_raises so the name says what it checks: a qd.graph_parallel_context() whose body is not a qd.graph_parallel() must raise. --- tests/python/test_graph_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/test_graph_parallel.py b/tests/python/test_graph_parallel.py index 9c52ac4fb9..0407abefd5 100644 --- a/tests/python/test_graph_parallel.py +++ b/tests/python/test_graph_parallel.py @@ -314,7 +314,7 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1)): @test_utils.test() -def test_graph_parallel_non_section_body_raises(): +def test_graph_parallel_context_non_graph_parallel_raises(): @qd.kernel(graph=True) def k(x: qd.types.ndarray(qd.f32, ndim=1)): with qd.graph_parallel_context(): From 65922495b1c839383ee12a983ff182e801940cdc Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 23 Jun 2026 08:49:58 -0700 Subject: [PATCH 12/25] test(graph): construct-based name for graph_parallel-outside-context test Rename test_graph_parallel_section_outside_region_raises -> test_graph_parallel_outside_context_raises: it checks that qd.graph_parallel() used outside a qd.graph_parallel_context() raises. --- tests/python/test_graph_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/test_graph_parallel.py b/tests/python/test_graph_parallel.py index 0407abefd5..3922b4884a 100644 --- a/tests/python/test_graph_parallel.py +++ b/tests/python/test_graph_parallel.py @@ -285,7 +285,7 @@ def k( @test_utils.test() -def test_graph_parallel_section_outside_region_raises(): +def test_graph_parallel_outside_context_raises(): @qd.kernel(graph=True) def k(x: qd.types.ndarray(qd.f32, ndim=1)): with qd.graph_parallel(): From a7988b386439400a3ac97bbaff18bde29ed891b7 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 23 Jun 2026 08:54:04 -0700 Subject: [PATCH 13/25] docs/graph: drop CUDA conditional while nodes note from do-while semantics --- docs/source/user_guide/graph.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user_guide/graph.md b/docs/source/user_guide/graph.md index 8b400a2ba1..b74b95a445 100644 --- a/docs/source/user_guide/graph.md +++ b/docs/source/user_guide/graph.md @@ -123,7 +123,7 @@ def converge(x: qd.types.ndarray(qd.f32, ndim=1), ### Do-while semantics -`graph_do_while` has **do-while** semantics: the kernel body always executes at least once before the condition is checked. This matches the behaviour of CUDA conditional while nodes. The flag value must be >= 1 at launch time. Passing 0 with a kernel that decrements the counter will cause an infinite loop. +`graph_do_while` has **do-while** semantics: the kernel body always executes at least once before the condition is checked. The flag value must be >= 1 at launch time. Passing 0 with a kernel that decrements the counter will cause an infinite loop. ### ndarray vs field From 509adfda0e34358bf4663d11a1736ba61fac6233 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 23 Jun 2026 08:59:00 -0700 Subject: [PATCH 14/25] docs/graph: trim intro of graph_parallel_context section --- docs/source/user_guide/graph.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user_guide/graph.md b/docs/source/user_guide/graph.md index b74b95a445..2556f9d806 100644 --- a/docs/source/user_guide/graph.md +++ b/docs/source/user_guide/graph.md @@ -470,7 +470,7 @@ In this case, our recommendation is: ## Parallel sections with `qd.graph_parallel_context` *(experimental)* -`qd.checkpoint` and `graph_do_while` change *which* kernels run and *how many times*; `qd.graph_parallel_context` changes *how* a graph's kernels are scheduled relative to each other. By default the kernels captured in a `graph=True` kernel run as a single dependency chain (each waits for the previous one), even when they are completely independent. A `with qd.graph_parallel_context():` region lets you declare independent stages so the graph runs them concurrently. +A `with qd.graph_parallel_context():` region lets you declare independent stages so the graph runs them concurrently. `qd.graph_parallel_context` is honoured by the graph builder so it composes with `graph=True` and `graph_do_while`. From 8a110e0d80ddd138c096656f19692e049da90c7c Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 23 Jun 2026 09:05:15 -0700 Subject: [PATCH 15/25] docs/graph: use American spellings (behavior/honored/recognized) Flip the British spellings introduced for graph_parallel back to American to match the rest of the codebase: behaviour->behavior, honoured->honored, recognised->recognized, behavioural->behavioral across graph.md, streams.md, misc.py docstrings, and the tests. --- docs/source/user_guide/graph.md | 4 ++-- docs/source/user_guide/streams.md | 2 +- python/quadrants/lang/misc.py | 10 +++++----- tests/python/test_graph_parallel.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/source/user_guide/graph.md b/docs/source/user_guide/graph.md index 2556f9d806..08f3f6efd4 100644 --- a/docs/source/user_guide/graph.md +++ b/docs/source/user_guide/graph.md @@ -472,7 +472,7 @@ In this case, our recommendation is: A `with qd.graph_parallel_context():` region lets you declare independent stages so the graph runs them concurrently. -`qd.graph_parallel_context` is honoured by the graph builder so it composes with `graph=True` and `graph_do_while`. +`qd.graph_parallel_context` is honored by the graph builder so it composes with `graph=True` and `graph_do_while`. ```python @qd.kernel(graph=True) @@ -503,7 +503,7 @@ def step(...): - `qd.graph_parallel()` may appear only directly inside a `qd.graph_parallel_context()`. - `qd.graph_parallel_context` cannot be nested, and a parallel section body must be straight-line task work — no `qd.graph_do_while`, `qd.checkpoint`, or nested `qd.graph_parallel_context` inside a parallel section (a `qd.graph_parallel_context` may, however, sit inside a `qd.graph_do_while` body, as shown above). -### Backend behaviour +### Backend behavior | backend | scheduling | | --- | --- | diff --git a/docs/source/user_guide/streams.md b/docs/source/user_guide/streams.md index 9c1c4f28b1..d30a0bc442 100644 --- a/docs/source/user_guide/streams.md +++ b/docs/source/user_guide/streams.md @@ -48,7 +48,7 @@ combine() # runs after compute_ab() returns — a[] and b[] are ready Consecutive `with qd.stream_parallel():` blocks run concurrently. Multiple for loops within a single block share a stream and run serially on it. All streams are synchronized before the kernel returns. -> **For `graph=True` kernels**, use [`qd.graph_parallel_context` / `qd.graph_parallel`](graph.md#parallel-sections-with-qdgraph_parallel_context-experimental) instead — `stream_parallel` is not compatible with graphs (see [Limitations](#limitations)). `qd.graph_parallel_context` expresses the same "run these independent sequences concurrently" idea but is honoured by the graph builder. +> **For `graph=True` kernels**, use [`qd.graph_parallel_context` / `qd.graph_parallel`](graph.md#parallel-sections-with-qdgraph_parallel_context-experimental) instead — `stream_parallel` is not compatible with graphs (see [Limitations](#limitations)). `qd.graph_parallel_context` expresses the same "run these independent sequences concurrently" idea but is honored by the graph builder. ### Restrictions diff --git a/python/quadrants/lang/misc.py b/python/quadrants/lang/misc.py index fec45d7078..4216bc1da3 100644 --- a/python/quadrants/lang/misc.py +++ b/python/quadrants/lang/misc.py @@ -746,7 +746,7 @@ def graph_do_while(condition) -> bool: iteration). The one structural rule is that a ``qd.graph_do_while`` ``while``-loop may appear only at the kernel top level or directly inside another ``graph_do_while`` body, not inside a ``for``-loop. - This function should not be called directly at runtime; it is recognised and transformed during AST compilation. + This function should not be called directly at runtime; it is recognized and transformed during AST compilation. Requires ``@qd.kernel(graph=True)``. """ return bool(condition) @@ -768,11 +768,11 @@ def graph_parallel_context(): respect to one another (no parallel section reads what another writes, no two parallel sections write the same location). Calls *within* a parallel section keep their program order. - Backend behaviour: + Backend behavior: - CUDA SM graph path: parallel sections become independent graph chains joined by an empty node, so the runtime schedules them on parallel streams (real overlap). - CPU / Vulkan / Metal / AMDGPU graph: correct results, parallel sections run serially (the - concurrency tags are honoured only by the graph builder today). + concurrency tags are honored only by the graph builder today). Restrictions (enforced at kernel compile time): - Must be used inside ``@qd.kernel(graph=True)``. @@ -780,7 +780,7 @@ def graph_parallel_context(): - Regions cannot be nested, and a parallel section body must be straight-line task work (no nested ``qd.graph_do_while``, ``qd.checkpoint``, or ``qd.graph_parallel_context``). - This function should not be called directly at runtime; it is recognised and transformed during AST + This function should not be called directly at runtime; it is recognized and transformed during AST compilation. At Python runtime (outside kernels) it is a no-op context manager. See also ``docs/source/user_guide/graph.md``. @@ -796,7 +796,7 @@ def graph_parallel(): The parallel section's body is an independent sequence of work that may run concurrently with the region's other parallel sections. - See ``qd.graph_parallel_context()`` for the full contract and backend behaviour. + See ``qd.graph_parallel_context()`` for the full contract and backend behavior. """ yield diff --git a/tests/python/test_graph_parallel.py b/tests/python/test_graph_parallel.py index 3922b4884a..8b85ebe88f 100644 --- a/tests/python/test_graph_parallel.py +++ b/tests/python/test_graph_parallel.py @@ -6,7 +6,7 @@ joined by a single empty node, so the runtime schedules them on parallel streams; on other backends (CPU / AMDGPU / Vulkan / Metal) they run serially but produce identical results. -The behavioural assertions (disjoint-array correctness) hold on every backend. The graph-structure +The behavioral assertions (disjoint-array correctness) hold on every backend. The graph-structure assertions (node counts: one kernel node per parallel-section task + one empty join node) only apply where the builder forks/joins (CUDA today), so they are guarded by `_on_cuda()`. """ From 41817414a63341075b9122fa68e6447fa6d94b4f Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 23 Jun 2026 09:17:01 -0700 Subject: [PATCH 16/25] docs/graph: rename term 'parallel section' -> 'qd.graph_parallel section' Use the construct-named term "qd.graph_parallel section" throughout the docs, docstrings, comments, and tests (dropping the redundant prefix where qd.graph_parallel is already adjacent). Updates the graph.md heading + the streams.md cross-reference anchor accordingly. --- docs/source/user_guide/graph.md | 24 ++++----- docs/source/user_guide/streams.md | 2 +- python/quadrants/lang/ast/ast_transformer.py | 27 +++++----- .../function_def_transformer.py | 12 ++--- python/quadrants/lang/misc.py | 41 +++++++------- quadrants/runtime/cuda/graph_manager.cpp | 18 +++---- quadrants/runtime/cuda/graph_manager.h | 4 +- tests/python/test_graph_parallel.py | 54 +++++++++---------- 8 files changed, 92 insertions(+), 90 deletions(-) diff --git a/docs/source/user_guide/graph.md b/docs/source/user_guide/graph.md index 08f3f6efd4..278998f83b 100644 --- a/docs/source/user_guide/graph.md +++ b/docs/source/user_guide/graph.md @@ -11,7 +11,7 @@ Graphs reduce kernel launch overhead by capturing a sequence of GPU operations i | `graph=True` | hardware accelerated | hardware accelerated | hardware accelerated | runs (no acceleration) | runs (no acceleration) | runs (no acceleration) | | `qd.graph_do_while` | hardware accelerated | host fallback | host fallback | host fallback | host fallback | host fallback | | `qd.checkpoint` | GPU-side | GPU-side | GPU-side | GPU-side | GPU-side | host-side | -| `qd.graph_parallel_context` / `qd.graph_parallel` (parallel sections) | concurrent | concurrent | runs serially | runs serially | runs serially | runs serially | +| `qd.graph_parallel_context` / `qd.graph_parallel` (sections) | concurrent | concurrent | runs serially | runs serially | runs serially | runs serially | AMDGPU `graph_do_while` falls back to a host-side loop because HIP does not currently expose conditional / while graph nodes (as of ROCm 7.2). @@ -468,7 +468,7 @@ In this case, our recommendation is: - if you need optimum 100% performance on unsupported platforms, then consider PRing onto quadrants an optimized graph implementation for your target platform - for example it could somehow run MAX_ITER iterations anyway, similar to the earlier hand-rolled version, but via the graph abstraction, hence allowing the code to be compact, cross-platform, and also optimally fast -## Parallel sections with `qd.graph_parallel_context` *(experimental)* +## `qd.graph_parallel` sections with `qd.graph_parallel_context` *(experimental)* A `with qd.graph_parallel_context():` region lets you declare independent stages so the graph runs them concurrently. @@ -478,36 +478,36 @@ A `with qd.graph_parallel_context():` region lets you declare independent stages @qd.kernel(graph=True) def step(...): while qd.graph_do_while(ncond): - assemble_shared(...) # serial: feeds both parallel sections + assemble_shared(...) # serial: feeds both `qd.graph_parallel` sections - with qd.graph_parallel_context(): # fork: parallel sections run concurrently + with qd.graph_parallel_context(): # fork: `qd.graph_parallel` sections run concurrently with qd.graph_parallel(): # point-triangle contacts pt_assemble(...) pt_hessian(...) with qd.graph_parallel(): # edge-edge contacts (independent of pt) ee_assemble(...) ee_hessian(...) - # join: everything below waits for BOTH parallel sections to finish + # join: everything below waits for BOTH `qd.graph_parallel` sections to finish merge_hessians(...) precondition(...) ``` ### Semantics -- **Fork / join.** Every parallel section in the region forks from the work that precedes the region. All parallel sections must finish before any work *after* the region begins (the join). On CUDA the join is a single empty graph node depending on every parallel section's last kernel. -- **Parallel sections are independent — you guarantee it.** Calls *within* a parallel section keep their program order, but calls in *different* parallel sections have no ordering. The parallel sections must be data-race free with respect to one another: no parallel section may read what another writes, and no two parallel sections may write the same memory. Quadrants does not check this; getting it wrong gives nondeterministic results. +- **Fork / join.** Every `qd.graph_parallel` section in the region forks from the work that precedes the region. All `qd.graph_parallel` sections must finish before any work *after* the region begins (the join). On CUDA the join is a single empty graph node depending on every `qd.graph_parallel` section's last kernel. +- **`qd.graph_parallel` sections are independent — you guarantee it.** Calls *within* a `qd.graph_parallel` section keep their program order, but calls in *different* `qd.graph_parallel` sections have no ordering. The `qd.graph_parallel` sections must be data-race free with respect to one another: no `qd.graph_parallel` section may read what another writes, and no two `qd.graph_parallel` sections may write the same memory. Quadrants does not check this; getting it wrong gives nondeterministic results. ### Restrictions (enforced at kernel compile time) -- `qd.graph_parallel_context` may contain only `with qd.graph_parallel():` blocks, optionally wrapped in `if qd.static(...)` (so an optional parallel section can be compiled in or out — e.g. enabling edge-edge contacts only when a feature flag is set). +- `qd.graph_parallel_context` may contain only `with qd.graph_parallel():` blocks, optionally wrapped in `if qd.static(...)` (so an optional `qd.graph_parallel` section can be compiled in or out — e.g. enabling edge-edge contacts only when a feature flag is set). - `qd.graph_parallel()` may appear only directly inside a `qd.graph_parallel_context()`. -- `qd.graph_parallel_context` cannot be nested, and a parallel section body must be straight-line task work — no `qd.graph_do_while`, `qd.checkpoint`, or nested `qd.graph_parallel_context` inside a parallel section (a `qd.graph_parallel_context` may, however, sit inside a `qd.graph_do_while` body, as shown above). +- `qd.graph_parallel_context` cannot be nested, and a `qd.graph_parallel` section body must be straight-line task work — no `qd.graph_do_while`, `qd.checkpoint`, or nested `qd.graph_parallel_context` inside a `qd.graph_parallel` section (a `qd.graph_parallel_context` may, however, sit inside a `qd.graph_do_while` body, as shown above). ### Backend behavior | backend | scheduling | | --- | --- | -| CUDA | parallel sections run **concurrently** | -| AMDGPU / CPU / Vulkan / Metal | parallel sections run **serially** | +| CUDA | `qd.graph_parallel` sections run **concurrently** | +| AMDGPU / CPU / Vulkan / Metal | `qd.graph_parallel` sections run **serially** | -Because parallel sections are independent by construction, running them serially produces identical results — only the scheduling differs. +Because `qd.graph_parallel` sections are independent by construction, running them serially produces identical results — only the scheduling differs. diff --git a/docs/source/user_guide/streams.md b/docs/source/user_guide/streams.md index d30a0bc442..f9a79e3f48 100644 --- a/docs/source/user_guide/streams.md +++ b/docs/source/user_guide/streams.md @@ -48,7 +48,7 @@ combine() # runs after compute_ab() returns — a[] and b[] are ready Consecutive `with qd.stream_parallel():` blocks run concurrently. Multiple for loops within a single block share a stream and run serially on it. All streams are synchronized before the kernel returns. -> **For `graph=True` kernels**, use [`qd.graph_parallel_context` / `qd.graph_parallel`](graph.md#parallel-sections-with-qdgraph_parallel_context-experimental) instead — `stream_parallel` is not compatible with graphs (see [Limitations](#limitations)). `qd.graph_parallel_context` expresses the same "run these independent sequences concurrently" idea but is honored by the graph builder. +> **For `graph=True` kernels**, use [`qd.graph_parallel_context` / `qd.graph_parallel`](graph.md#qdgraph_parallel-sections-with-qdgraph_parallel_context-experimental) instead — `stream_parallel` is not compatible with graphs (see [Limitations](#limitations)). `qd.graph_parallel_context` expresses the same "run these independent sequences concurrently" idea but is honored by the graph builder. ### Restrictions diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 765dba3758..247b99c684 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -1389,8 +1389,8 @@ def _is_graph_parallel_context_call(node: ast.expr) -> bool: @staticmethod def _is_parallel_section_call(node: ast.expr) -> bool: - """If *node* is a ``qd.graph_parallel()`` (a parallel section) call return True, else False. The - call shape is validated here so misuse raises at the ``with`` site rather than later.""" + """If *node* is a ``qd.graph_parallel()`` (a section) call return True, else False. The call shape + is validated here so misuse raises at the ``with`` site rather than later.""" if not isinstance(node, ast.Call): return False func = node.func @@ -1680,9 +1680,10 @@ def _build_graph_parallel_context_with(ctx: ASTTransformerFuncContext, node: ast Validates the use-site (kernel must be graph=True, no nesting) and that the region body contains only ``with qd.graph_parallel():`` blocks, then walks the body. The region emits no IR tag of its - own -- each parallel section inside lowers to a stream-parallel group (via begin/end_stream_parallel), - and the graph builder forks the distinct groups in a contiguous run and joins them. Regions are kept - apart by the serial work between them (see d3_0_graph_parallel_impl.md).""" + own -- each ``qd.graph_parallel`` section inside lowers to a stream-parallel group (via + begin/end_stream_parallel), and the graph builder forks the distinct groups in a contiguous run and + joins them. Regions are kept apart by the serial work between them (see + d3_0_graph_parallel_impl.md).""" if not ctx.is_kernel: raise QuadrantsSyntaxError( "qd.graph_parallel_context() can only be used inside @qd.kernel, not @qd.func" @@ -1707,10 +1708,10 @@ def _build_graph_parallel_context_with(ctx: ASTTransformerFuncContext, node: ast @staticmethod def _validate_graph_parallel_context_body(stmts: list[ast.stmt]) -> None: """A qd.graph_parallel_context() region body may contain only `with qd.graph_parallel():` blocks, - optionally wrapped in compile-time `if qd.static(...)` (the optional parallel-section pattern, e.g. - qipc's ENABLE_EE). Docstrings / coverage probes / `pass` are allowed. Anything else (a bare - for-loop, assignment, etc.) is a serial task that would silently fall outside any parallel section, - so reject it.""" + optionally wrapped in compile-time `if qd.static(...)` (the optional ``qd.graph_parallel`` section + pattern, e.g. qipc's ENABLE_EE). Docstrings / coverage probes / `pass` are allowed. Anything else + (a bare for-loop, assignment, etc.) is a serial task that would silently fall outside any + ``qd.graph_parallel`` section, so reject it.""" for i, stmt in enumerate(stmts): if FunctionDefTransformer._is_docstring(stmt, i) or FunctionDefTransformer._is_coverage_probe(stmt): continue @@ -1731,11 +1732,11 @@ def _validate_graph_parallel_context_body(stmts: list[ast.stmt]) -> None: @staticmethod def _build_parallel_section_with(ctx: ASTTransformerFuncContext, node: ast.With) -> None: - """Handles a ``with qd.graph_parallel():`` parallel section of a ``qd.graph_parallel_context()`` region. + """Handles a ``with qd.graph_parallel():`` section of a ``qd.graph_parallel_context()`` region. - Reuses the stream-parallel tagging: begin_stream_parallel() assigns this parallel section a fresh - ``stream_parallel_group_id`` that every for-loop in the body inherits, so the offloaded tasks - carry the parallel-section id all the way to the graph builder.""" + Reuses the stream-parallel tagging: begin_stream_parallel() assigns this ``qd.graph_parallel`` + section a fresh ``stream_parallel_group_id`` that every for-loop in the body inherits, so the + offloaded tasks carry the ``qd.graph_parallel`` section id all the way to the graph builder.""" if not getattr(ctx, "_in_graph_parallel_context", False): raise QuadrantsSyntaxError( "qd.graph_parallel() can only be used directly inside a qd.graph_parallel_context() region" diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 070814c5f4..665f2b5096 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -619,7 +619,7 @@ def _is_graph_parallel_context_with(stmt: ast.stmt) -> bool: @staticmethod def _is_parallel_section_with(stmt: ast.stmt) -> bool: """Syntactic check matching ASTTransformer._is_parallel_section_call: a ``with qd.graph_parallel(...):`` - parallel section of a ``qd.graph_parallel_context()`` region.""" + section of a ``qd.graph_parallel_context()`` region.""" if not isinstance(stmt, ast.With) or len(stmt.items) != 1: return False ctx_expr = stmt.items[0].context_expr @@ -693,11 +693,11 @@ def _validate_graph_do_while_stmt_list(stmts: list[ast.stmt], is_kernel_top: boo continue if FunctionDefTransformer._is_graph_parallel_context_with(stmt): # A `with qd.graph_parallel_context()` region groups concurrent `with qd.graph_parallel()` - # parallel sections; it is a legal sibling of for-loops / checkpoints. Its body must be - # parallel-section blocks (optionally under `if qd.static(...)`); the full check is in - # ASTTransformer._build_graph_parallel_context_with. Each parallel section's body is task - # territory, validated here with the in-loop rules. Descend through `if` members so parallel - # sections inside an optional `if qd.static(...)` are reached too. + # sections; it is a legal sibling of for-loops / checkpoints. Its body must be + # `qd.graph_parallel` section blocks (optionally under `if qd.static(...)`); the full check + # is in ASTTransformer._build_graph_parallel_context_with. Each `qd.graph_parallel` section's + # body is task territory, validated here with the in-loop rules. Descend through `if` members + # so `qd.graph_parallel` sections inside an optional `if qd.static(...)` are reached too. pending = list(stmt.body) while pending: member = pending.pop() diff --git a/python/quadrants/lang/misc.py b/python/quadrants/lang/misc.py index 4216bc1da3..fa36e67f86 100644 --- a/python/quadrants/lang/misc.py +++ b/python/quadrants/lang/misc.py @@ -754,31 +754,32 @@ def graph_do_while(condition) -> bool: @contextmanager def graph_parallel_context(): - """Opens a fork/join region whose ``qd.graph_parallel()`` parallel sections run concurrently. + """Opens a fork/join region whose ``qd.graph_parallel()`` sections run concurrently. Used as ``with qd.graph_parallel_context():`` inside a ``@qd.kernel(graph=True)`` kernel. The region's - body must contain only ``with qd.graph_parallel():`` blocks. Each parallel section is an independent - sequence of work; the parallel sections have no ordering relative to each other and may execute - concurrently, while everything after the region waits for *all* parallel sections to finish (the join). - This is the graph analogue of ``qd.stream_parallel()`` (which is for non-graph kernels): it lets - independent stages -- e.g. qipc's point-triangle and edge-edge assembly -- overlap inside a captured - graph. - - Concurrency contract (the author's responsibility): parallel sections must be data-race free with - respect to one another (no parallel section reads what another writes, no two parallel sections write - the same location). Calls *within* a parallel section keep their program order. + body must contain only ``with qd.graph_parallel():`` blocks. Each ``qd.graph_parallel`` section is an + independent sequence of work; the ``qd.graph_parallel`` sections have no ordering relative to each + other and may execute concurrently, while everything after the region waits for *all* ``qd.graph_parallel`` + sections to finish (the join). This is the graph analogue of ``qd.stream_parallel()`` (which is for + non-graph kernels): it lets independent stages -- e.g. qipc's point-triangle and edge-edge assembly + -- overlap inside a captured graph. + + Concurrency contract (the author's responsibility): ``qd.graph_parallel`` sections must be data-race + free with respect to one another (no ``qd.graph_parallel`` section reads what another writes, no two + ``qd.graph_parallel`` sections write the same location). Calls *within* a ``qd.graph_parallel`` section + keep their program order. Backend behavior: - - CUDA SM graph path: parallel sections become independent graph chains joined by an empty node, so - the runtime schedules them on parallel streams (real overlap). - - CPU / Vulkan / Metal / AMDGPU graph: correct results, parallel sections run serially (the - concurrency tags are honored only by the graph builder today). + - CUDA SM graph path: ``qd.graph_parallel`` sections become independent graph chains joined by an + empty node, so the runtime schedules them on parallel streams (real overlap). + - CPU / Vulkan / Metal / AMDGPU graph: correct results, ``qd.graph_parallel`` sections run serially + (the concurrency tags are honored only by the graph builder today). Restrictions (enforced at kernel compile time): - Must be used inside ``@qd.kernel(graph=True)``. - The region body may contain only ``with qd.graph_parallel():`` blocks. - - Regions cannot be nested, and a parallel section body must be straight-line task work (no nested - ``qd.graph_do_while``, ``qd.checkpoint``, or ``qd.graph_parallel_context``). + - Regions cannot be nested, and a ``qd.graph_parallel`` section body must be straight-line task work + (no nested ``qd.graph_do_while``, ``qd.checkpoint``, or ``qd.graph_parallel_context``). This function should not be called directly at runtime; it is recognized and transformed during AST compilation. At Python runtime (outside kernels) it is a no-op context manager. @@ -790,11 +791,11 @@ def graph_parallel_context(): @contextmanager def graph_parallel(): - """Declares one parallel section of an enclosing ``qd.graph_parallel_context()`` region. + """Declares one ``qd.graph_parallel`` section of an enclosing ``qd.graph_parallel_context()`` region. Used as ``with qd.graph_parallel():`` directly inside a ``with qd.graph_parallel_context():`` block. - The parallel section's body is an independent sequence of work that may run concurrently with the - region's other parallel sections. + The ``qd.graph_parallel`` section's body is an independent sequence of work that may run concurrently + with the region's other ``qd.graph_parallel`` sections. See ``qd.graph_parallel_context()`` for the full contract and backend behavior. """ diff --git a/quadrants/runtime/cuda/graph_manager.cpp b/quadrants/runtime/cuda/graph_manager.cpp index d932abb8aa..4b79031f40 100644 --- a/quadrants/runtime/cuda/graph_manager.cpp +++ b/quadrants/runtime/cuda/graph_manager.cpp @@ -406,17 +406,17 @@ void GraphManager::build_level(int parent_id, // --- A qd.graph_parallel_context() fork/join region: a contiguous run of this level's direct, // non-checkpoint tasks tagged with a nonzero stream_parallel_group_id (set by qd.graph_parallel()). Each - // distinct group id is one parallel section; the parallel sections fork from the region's entry - // (`prev_node`), run their tasks in order, and join into a single empty node so downstream work waits for - // all of them. CUDA's graph executor schedules the independent parallel-section chains on separate - // streams -> real overlap. --- + // distinct group id is one qd.graph_parallel section; the qd.graph_parallel sections fork from the + // region's entry (`prev_node`), run their tasks in order, and join into a single empty node so + // downstream work waits for all of them. CUDA's graph executor schedules the independent + // qd.graph_parallel section chains on separate streams -> real overlap. --- if (tasks[cursor].stream_parallel_group_id != 0 && tasks[cursor].checkpoint_id < 0) { int run_end = cursor; while (run_end < end && tasks[run_end].graph_do_while_level_id == parent_id && tasks[run_end].checkpoint_id < 0 && tasks[run_end].stream_parallel_group_id != 0) { run_end++; } - // Bucket the run's tasks by parallel-section id, preserving first-seen (declaration) order. + // Bucket the run's tasks by qd.graph_parallel section id, preserving first-seen (declaration) order. std::vector group_ids; std::vector> parallel_sections; for (int t = cursor; t < run_end; t++) { @@ -439,7 +439,7 @@ void GraphManager::build_level(int parent_id, std::vector tails; tails.reserve(parallel_sections.size()); for (auto &ps : parallel_sections) { - void *bp = prev_node; // every parallel section forks from the region entry dependency + void *bp = prev_node; // every qd.graph_parallel section forks from the region entry dependency for (int t : ps) { bp = add_kernel_node(target_graph, bp, cuda_module->lookup_function(tasks[t].name), (unsigned int)tasks[t].grid_dim, (unsigned int)tasks[t].block_dim, @@ -448,9 +448,9 @@ void GraphManager::build_level(int parent_id, } tails.push_back(bp); } - // Join. A single-parallel-section region (e.g. an optional parallel section compiled out) has nothing - // to join, so just continue the chain from its tail; otherwise collect all tails into one empty - // successor node. + // Join. A region with a single qd.graph_parallel section (e.g. an optional qd.graph_parallel section + // compiled out) has nothing to join, so just continue the chain from its tail; otherwise collect all + // tails into one empty successor node. if (tails.size() == 1) { prev_node = tails[0]; } else { diff --git a/quadrants/runtime/cuda/graph_manager.h b/quadrants/runtime/cuda/graph_manager.h index bb724c6cad..210504d2f6 100644 --- a/quadrants/runtime/cuda/graph_manager.h +++ b/quadrants/runtime/cuda/graph_manager.h @@ -202,8 +202,8 @@ class GraphManager { unsigned int shared_mem, void **kernel_params); // Add an empty (no-op) node to `graph` depending on every node in `deps`. Used as the join point of a - // qd.graph_parallel_context() region: it has no work but collects all parallel-section tails into a single - // successor so downstream nodes wait for every parallel section. `deps` must be non-empty. + // qd.graph_parallel_context() region: it has no work but collects all qd.graph_parallel section tails into + // a single successor so downstream nodes wait for every qd.graph_parallel section. `deps` must be non-empty. void *add_empty_node(void *graph, const std::vector &deps); // Recursively build the nodes for graph_do_while level `parent_id` (-1 = kernel top level) over the task range // [begin, end) into `target_graph` (the body graph of `parent_id`, or the root graph for -1). Direct tasks become diff --git a/tests/python/test_graph_parallel.py b/tests/python/test_graph_parallel.py index 8b85ebe88f..90428fcf37 100644 --- a/tests/python/test_graph_parallel.py +++ b/tests/python/test_graph_parallel.py @@ -1,14 +1,14 @@ -"""Tests for qd.graph_parallel_context / qd.graph_parallel -- concurrent fork/join parallel sections in -graph kernels. +"""Tests for qd.graph_parallel_context / qd.graph_parallel -- concurrent fork/join sections in graph +kernels. `with qd.graph_parallel_context():` opens a fork/join region whose `with qd.graph_parallel():` members are -independent sequences of work. On the graph path the parallel sections become independent graph chains -joined by a single empty node, so the runtime schedules them on parallel streams; on other backends +independent sequences of work. On the graph path the qd.graph_parallel sections become independent graph +chains joined by a single empty node, so the runtime schedules them on parallel streams; on other backends (CPU / AMDGPU / Vulkan / Metal) they run serially but produce identical results. The behavioral assertions (disjoint-array correctness) hold on every backend. The graph-structure -assertions (node counts: one kernel node per parallel-section task + one empty join node) only apply where -the builder forks/joins (CUDA today), so they are guarded by `_on_cuda()`. +assertions (node counts: one kernel node per qd.graph_parallel section task + one empty join node) only +apply where the builder forks/joins (CUDA today), so they are guarded by `_on_cuda()`. """ import numpy as np @@ -53,9 +53,9 @@ def test_graph_parallel_is_no_op_outside_kernels(): @test_utils.test() def test_graph_parallel_two_sections(): - """Two parallel sections write disjoint arrays; a serial loop after the region reads both (so it - depends on the join). Results must match the serial reference on every backend; on CUDA the graph has - one node per task plus one empty join node.""" + """Two qd.graph_parallel sections write disjoint arrays; a serial loop after the region reads both (so + it depends on the join). Results must match the serial reference on every backend; on CUDA the graph + has one node per task plus one empty join node.""" n = 1024 @qd.kernel(graph=True) @@ -86,7 +86,7 @@ def k( num_tasks = _num_offloaded_tasks() if _on_cuda(): # One graph node per offloaded task (each dynamic-bound loop is a bound-compute serial + a - # range_for, both in the parallel section) plus exactly one empty join node for the single region. + # range_for, both in the qd.graph_parallel section) plus exactly one empty join node for the region. assert _graph_num_nodes() == num_tasks + 1 np.testing.assert_allclose(x.to_numpy(), 1.0) @@ -102,7 +102,7 @@ def k( @test_utils.test() def test_graph_parallel_three_sections(): - """Fan-out of three independent parallel sections; one empty join node.""" + """Fan-out of three independent qd.graph_parallel sections; one empty join node.""" n = 256 @qd.kernel(graph=True) @@ -131,7 +131,7 @@ def k( k(a, b, c) num_tasks = _num_offloaded_tasks() if _on_cuda(): - assert _graph_num_nodes() == num_tasks + 1 # three parallel sections + one join + assert _graph_num_nodes() == num_tasks + 1 # three qd.graph_parallel sections + one join np.testing.assert_allclose(a.to_numpy(), 1.0) np.testing.assert_allclose(b.to_numpy(), 2.0) @@ -140,9 +140,9 @@ def k( @test_utils.test() def test_graph_parallel_multi_loop_sections(): - """Each parallel section contains several loops; they must chain in order inside the parallel section - while the two parallel sections run independently. Parallel-section tasks = 4, plus one join node on - CUDA.""" + """Each qd.graph_parallel section contains several loops; they must chain in order inside the + qd.graph_parallel section while the two qd.graph_parallel sections run independently. qd.graph_parallel + section tasks = 4, plus one join node on CUDA.""" n = 128 @qd.kernel(graph=True) @@ -167,7 +167,7 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1), y: qd.types.ndarray(qd.f32, ndim=1)): k(x, y) num_tasks = _num_offloaded_tasks() if _on_cuda(): - assert _graph_num_nodes() == num_tasks + 1 # all parallel-section tasks + one join + assert _graph_num_nodes() == num_tasks + 1 # all qd.graph_parallel section tasks + one join np.testing.assert_allclose(x.to_numpy(), 2.0) # (0+1)*2 np.testing.assert_allclose(y.to_numpy(), 12.0) # (0+3)*4 @@ -175,9 +175,9 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1), y: qd.types.ndarray(qd.f32, ndim=1)): @test_utils.test() def test_graph_parallel_single_section_no_join(): - """A region with a single parallel section (e.g. an optional parallel section compiled out) needs no - join: it degenerates to a plain chain, so the node count equals the number of parallel-section tasks - (no extra empty node).""" + """A region with a single qd.graph_parallel section (e.g. an optional qd.graph_parallel section compiled + out) needs no join: it degenerates to a plain chain, so the node count equals the number of + qd.graph_parallel section tasks (no extra empty node).""" n = 256 @qd.kernel(graph=True) @@ -193,16 +193,16 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1)): k(x) num_tasks = _num_offloaded_tasks() if _on_cuda(): - assert _graph_num_nodes() == num_tasks # single parallel section -> plain chain, no join node + assert _graph_num_nodes() == num_tasks # single qd.graph_parallel section -> plain chain, no join np.testing.assert_allclose(x.to_numpy(), 5.0) @test_utils.test() def test_graph_parallel_optional_section_static_if(): - """The qipc ENABLE_EE pattern: a parallel section wrapped in `if qd.static(...)`. When the flag is - False the parallel section is compiled out (region has one parallel section -> no join); when True both - parallel sections run.""" + """The qipc ENABLE_EE pattern: a qd.graph_parallel section wrapped in `if qd.static(...)`. When the flag + is False the qd.graph_parallel section is compiled out (region has one qd.graph_parallel section -> no + join); when True both qd.graph_parallel sections run.""" n = 128 @qd.kernel(graph=True) @@ -233,15 +233,15 @@ def k_on(x: qd.types.ndarray(qd.f32, ndim=1), y: qd.types.ndarray(qd.f32, ndim=1 y.from_numpy(np.zeros(n, dtype=np.float32)) k_off(x, y) if _on_cuda(): - assert _graph_num_nodes() == _num_offloaded_tasks() # single parallel section -> no join + assert _graph_num_nodes() == _num_offloaded_tasks() # single qd.graph_parallel section -> no join np.testing.assert_allclose(x.to_numpy(), 1.0) - np.testing.assert_allclose(y.to_numpy(), 0.0) # EE parallel section compiled out + np.testing.assert_allclose(y.to_numpy(), 0.0) # EE qd.graph_parallel section compiled out x.from_numpy(np.zeros(n, dtype=np.float32)) y.from_numpy(np.zeros(n, dtype=np.float32)) k_on(x, y) if _on_cuda(): - assert _graph_num_nodes() == _num_offloaded_tasks() + 1 # two parallel sections + join + assert _graph_num_nodes() == _num_offloaded_tasks() + 1 # two qd.graph_parallel sections + join np.testing.assert_allclose(x.to_numpy(), 1.0) np.testing.assert_allclose(y.to_numpy(), 1.0) @@ -249,7 +249,7 @@ def k_on(x: qd.types.ndarray(qd.f32, ndim=1), y: qd.types.ndarray(qd.f32, ndim=1 @test_utils.test() def test_graph_parallel_inside_graph_do_while(): """A fork/join region inside a qd.graph_do_while loop body must be correct across iterations: each - iteration runs both parallel sections, then decrements the counter.""" + iteration runs both qd.graph_parallel sections, then decrements the counter.""" n = 64 iters = 5 From 478ec7096685098e3f298557545580caeb288f5b Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 23 Jun 2026 09:18:26 -0700 Subject: [PATCH 17/25] style: apply pre-commit (black, clang-format) formatting --- python/quadrants/lang/ast/ast_transformer.py | 8 ++------ quadrants/runtime/cuda/graph_manager.cpp | 4 ++-- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 247b99c684..3342640f40 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -1685,18 +1685,14 @@ def _build_graph_parallel_context_with(ctx: ASTTransformerFuncContext, node: ast joins them. Regions are kept apart by the serial work between them (see d3_0_graph_parallel_impl.md).""" if not ctx.is_kernel: - raise QuadrantsSyntaxError( - "qd.graph_parallel_context() can only be used inside @qd.kernel, not @qd.func" - ) + raise QuadrantsSyntaxError("qd.graph_parallel_context() can only be used inside @qd.kernel, not @qd.func") kernel = ctx.global_context.current_kernel if kernel is None or not kernel.use_graph: raise QuadrantsSyntaxError("qd.graph_parallel_context() requires @qd.kernel(graph=True)") if getattr(ctx, "_in_graph_parallel_context", False): raise QuadrantsSyntaxError("qd.graph_parallel_context() regions cannot be nested") if getattr(ctx, "_in_parallel_section", False): - raise QuadrantsSyntaxError( - "qd.graph_parallel_context() cannot appear inside a qd.graph_parallel() body" - ) + raise QuadrantsSyntaxError("qd.graph_parallel_context() cannot appear inside a qd.graph_parallel() body") ASTTransformer._validate_graph_parallel_context_body(node.body) ctx._in_graph_parallel_context = True try: diff --git a/quadrants/runtime/cuda/graph_manager.cpp b/quadrants/runtime/cuda/graph_manager.cpp index 4b79031f40..daa17fdcc9 100644 --- a/quadrants/runtime/cuda/graph_manager.cpp +++ b/quadrants/runtime/cuda/graph_manager.cpp @@ -412,8 +412,8 @@ void GraphManager::build_level(int parent_id, // qd.graph_parallel section chains on separate streams -> real overlap. --- if (tasks[cursor].stream_parallel_group_id != 0 && tasks[cursor].checkpoint_id < 0) { int run_end = cursor; - while (run_end < end && tasks[run_end].graph_do_while_level_id == parent_id && - tasks[run_end].checkpoint_id < 0 && tasks[run_end].stream_parallel_group_id != 0) { + while (run_end < end && tasks[run_end].graph_do_while_level_id == parent_id && tasks[run_end].checkpoint_id < 0 && + tasks[run_end].stream_parallel_group_id != 0) { run_end++; } // Bucket the run's tasks by qd.graph_parallel section id, preserving first-seen (declaration) order. From 7b4a0b6af107f90bea43a0e87ed925dd7db63e4d Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 23 Jun 2026 09:33:19 -0700 Subject: [PATCH 18/25] feat(graph): generate qd.graph_parallel sections from a qd.static for-loop Port the static-loop section generator from hp/graph-parallel-static-branches onto the current naming. A `for ... in qd.static(...)` loop inside a qd.graph_parallel_context unrolls at trace time into one qd.graph_parallel section per iteration (each gets a fresh stream_parallel_group_id), so sections can be forked from a compile-time sequence -- e.g. one per @qd.func member of a @qd.data_oriented list (qipc's per-contact-type assembly pattern). - ast_transformer: region-body validator recurses into `for ... in qd.static(...)` loops (staticness re-checked via get_decorator at every nesting level, so a runtime loop -- even nested under a static one -- is still rejected). - function_def_transformer: graph_do_while structure validator descends into For members so static-loop section bodies are validated like hand-written ones. - tests: static-loop sections (incl. nested, empty range, single section, over-funcs, mixed with if-static, inside graph_do_while) + runtime-loop and non-section-body rejection guards. - docs: "Generating qd.graph_parallel sections from a compile-time sequence". --- docs/source/user_guide/graph.md | 22 +- python/quadrants/lang/ast/ast_transformer.py | 29 +- .../function_def_transformer.py | 5 + tests/python/test_graph_parallel.py | 284 ++++++++++++++++++ 4 files changed, 330 insertions(+), 10 deletions(-) diff --git a/docs/source/user_guide/graph.md b/docs/source/user_guide/graph.md index 278998f83b..a39a309bcb 100644 --- a/docs/source/user_guide/graph.md +++ b/docs/source/user_guide/graph.md @@ -497,9 +497,29 @@ def step(...): - **Fork / join.** Every `qd.graph_parallel` section in the region forks from the work that precedes the region. All `qd.graph_parallel` sections must finish before any work *after* the region begins (the join). On CUDA the join is a single empty graph node depending on every `qd.graph_parallel` section's last kernel. - **`qd.graph_parallel` sections are independent — you guarantee it.** Calls *within* a `qd.graph_parallel` section keep their program order, but calls in *different* `qd.graph_parallel` sections have no ordering. The `qd.graph_parallel` sections must be data-race free with respect to one another: no `qd.graph_parallel` section may read what another writes, and no two `qd.graph_parallel` sections may write the same memory. Quadrants does not check this; getting it wrong gives nondeterministic results. +### Generating `qd.graph_parallel` sections from a compile-time sequence + +`qd.graph_parallel` sections do not have to be written out one by one. A `for ... in qd.static(...)` loop is unrolled at compile time, so each iteration that contains a `with qd.graph_parallel():` becomes its own section — handy for forking one section per element of a static list (e.g. per contact type): + +```python +@qd.data_oriented +class Solver: + def __init__(self): + self.funcs = [self._assemble_pt, self._assemble_ee] # static list of @qd.func members + + @qd.kernel(graph=True) + def step(self): + with qd.graph_parallel_context(): + for i in qd.static(range(len(self.funcs))): # unrolls to one section per func + with qd.graph_parallel(): + self.funcs[i]() +``` + +The loop **must** be a `qd.static(...)` loop (its trip count is known at compile time). A plain runtime `for i in range(n):` is rejected — a runtime loop cannot be unrolled into independent sections. + ### Restrictions (enforced at kernel compile time) -- `qd.graph_parallel_context` may contain only `with qd.graph_parallel():` blocks, optionally wrapped in `if qd.static(...)` (so an optional `qd.graph_parallel` section can be compiled in or out — e.g. enabling edge-edge contacts only when a feature flag is set). +- `qd.graph_parallel_context` may contain only `with qd.graph_parallel():` blocks, optionally wrapped in `if qd.static(...)` (so an optional `qd.graph_parallel` section can be compiled in or out — e.g. enabling edge-edge contacts only when a feature flag is set) or `for ... in qd.static(...)` loops (generate one `qd.graph_parallel` section per element of a compile-time sequence). - `qd.graph_parallel()` may appear only directly inside a `qd.graph_parallel_context()`. - `qd.graph_parallel_context` cannot be nested, and a `qd.graph_parallel` section body must be straight-line task work — no `qd.graph_do_while`, `qd.checkpoint`, or nested `qd.graph_parallel_context` inside a `qd.graph_parallel` section (a `qd.graph_parallel_context` may, however, sit inside a `qd.graph_do_while` body, as shown above). diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 3342640f40..2089bde66b 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -1693,7 +1693,7 @@ def _build_graph_parallel_context_with(ctx: ASTTransformerFuncContext, node: ast raise QuadrantsSyntaxError("qd.graph_parallel_context() regions cannot be nested") if getattr(ctx, "_in_parallel_section", False): raise QuadrantsSyntaxError("qd.graph_parallel_context() cannot appear inside a qd.graph_parallel() body") - ASTTransformer._validate_graph_parallel_context_body(node.body) + ASTTransformer._validate_graph_parallel_context_body(ctx, node.body) ctx._in_graph_parallel_context = True try: build_stmts(ctx, node.body) @@ -1702,12 +1702,20 @@ def _build_graph_parallel_context_with(ctx: ASTTransformerFuncContext, node: ast return None @staticmethod - def _validate_graph_parallel_context_body(stmts: list[ast.stmt]) -> None: + def _validate_graph_parallel_context_body(ctx: ASTTransformerFuncContext, stmts: list[ast.stmt]) -> None: """A qd.graph_parallel_context() region body may contain only `with qd.graph_parallel():` blocks, optionally wrapped in compile-time `if qd.static(...)` (the optional ``qd.graph_parallel`` section - pattern, e.g. qipc's ENABLE_EE). Docstrings / coverage probes / `pass` are allowed. Anything else - (a bare for-loop, assignment, etc.) is a serial task that would silently fall outside any - ``qd.graph_parallel`` section, so reject it.""" + pattern, e.g. qipc's ENABLE_EE) or `for ... in qd.static(...)` loops (generate one + ``qd.graph_parallel`` section per element of a compile-time sequence). Docstrings / coverage probes / + `pass` are allowed. Anything else (a runtime for-loop, a bare assignment, etc.) is a serial task that + would silently fall outside any ``qd.graph_parallel`` section, so reject it. + + The `for` case is restricted to `qd.static(...)` loops on purpose: a static loop unrolls at trace + time into its repeated body, so it lowers to literal `with qd.graph_parallel():` blocks (each gets a + fresh stream_parallel_group_id). A *runtime* for-loop would instead trace a single parallel range_for + with the section tagging nested inside it -- malformed. Staticness is checked with `get_decorator` + (the same resolution `build_For` uses) at every nesting level, so a runtime loop nested under a + static one is still rejected.""" for i, stmt in enumerate(stmts): if FunctionDefTransformer._is_docstring(stmt, i) or FunctionDefTransformer._is_coverage_probe(stmt): continue @@ -1717,13 +1725,16 @@ def _validate_graph_parallel_context_body(stmts: list[ast.stmt]) -> None: if ASTTransformer._is_parallel_section_call(stmt.items[0].context_expr): continue if isinstance(stmt, ast.If): - ASTTransformer._validate_graph_parallel_context_body(stmt.body) - ASTTransformer._validate_graph_parallel_context_body(stmt.orelse) + ASTTransformer._validate_graph_parallel_context_body(ctx, stmt.body) + ASTTransformer._validate_graph_parallel_context_body(ctx, stmt.orelse) + continue + if isinstance(stmt, ast.For) and not stmt.orelse and get_decorator(ctx, stmt.iter) == "static": + ASTTransformer._validate_graph_parallel_context_body(ctx, stmt.body) continue raise QuadrantsSyntaxError( "A qd.graph_parallel_context() region may contain only 'with qd.graph_parallel():' blocks " - "(optionally inside 'if qd.static(...)'). Move other work outside the region. " - f"[offending stmt {i}: {type(stmt).__name__}]" + "(optionally inside 'if qd.static(...)' or 'for ... in qd.static(...)'). Move other work " + f"outside the region. [offending stmt {i}: {type(stmt).__name__}]" ) @staticmethod diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 665f2b5096..67b32a9bb1 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -706,6 +706,11 @@ def _validate_graph_do_while_stmt_list(stmts: list[ast.stmt], is_kernel_top: boo elif isinstance(member, ast.If): pending.extend(member.body) pending.extend(member.orelse) + elif isinstance(member, ast.For): + # `for ... in qd.static(...)` generates sections; descend so each unrolled section + # body is still validated with the in-loop rules (a runtime for here is rejected + # earlier by ASTTransformer._build_graph_parallel_context_with). + pending.extend(member.body) continue where = "the kernel body" if is_kernel_top else "a qd.graph_do_while() body" raise QuadrantsSyntaxError( diff --git a/tests/python/test_graph_parallel.py b/tests/python/test_graph_parallel.py index 90428fcf37..ad7eb33fad 100644 --- a/tests/python/test_graph_parallel.py +++ b/tests/python/test_graph_parallel.py @@ -340,3 +340,287 @@ def k(x: qd.types.ndarray(qd.f32, ndim=1)): x = qd.ndarray(qd.f32, shape=(16,)) with pytest.raises(qd.QuadrantsSyntaxError): k(x) + + +@test_utils.test() +def test_graph_parallel_static_loop_two_sections(): + """`for b in qd.static(range(NB))` unrolls into NB literal qd.graph_parallel sections, each writing a + disjoint row.""" + nb = 2 + n = 256 + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=2)): + with qd.graph_parallel_context(): + for b in qd.static(range(nb)): + with qd.graph_parallel(): + for i in range(x.shape[1]): + x[b, i] = x[b, i] + (b + 1) + + x = qd.ndarray(qd.f32, shape=(nb, n)) + x.from_numpy(np.zeros((nb, n), dtype=np.float32)) + + k(x) + num_tasks = _num_offloaded_tasks() + if _on_cuda(): + assert _graph_num_nodes() == num_tasks + 1 # nb qd.graph_parallel sections + one join + + out = x.to_numpy() + np.testing.assert_allclose(out[0], 1.0) + np.testing.assert_allclose(out[1], 2.0) + + +@test_utils.test() +def test_graph_parallel_static_loop_over_funcs(): + """The motivating pattern: a @qd.data_oriented class iterates a static list of @qd.func members, one + qd.graph_parallel section each (mirrors qipc's per-contact-type assembly funcs).""" + n = 4 + + @qd.data_oriented + class Demo: + def __init__(self): + self.a = qd.field(qd.i32, shape=(n,)) + self.b = qd.field(qd.i32, shape=(n,)) + self.funcs = [self._fill_a, self._fill_b] + + @qd.func + def _fill_a(self): + for i in range(n): + self.a[i] += 1 + + @qd.func + def _fill_b(self): + for i in range(n): + self.b[i] += 10 + + @qd.kernel(graph=True) + def step(self): + with qd.graph_parallel_context(): + for i in qd.static(range(len(self.funcs))): + with qd.graph_parallel(): + self.funcs[i]() + + d = Demo() + d.a.from_numpy(np.zeros(n, dtype=np.int32)) + d.b.from_numpy(np.zeros(n, dtype=np.int32)) + d.step() + np.testing.assert_array_equal(d.a.to_numpy(), np.ones(n, dtype=np.int32)) + np.testing.assert_array_equal(d.b.to_numpy(), np.full(n, 10, dtype=np.int32)) + + +@test_utils.test() +def test_graph_parallel_static_loop_single_section(): + """A static loop of one iteration is a single-section region: a plain chain, no join node.""" + n = 256 + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel_context(): + for _b in qd.static(range(1)): + with qd.graph_parallel(): + for i in range(x.shape[0]): + x[i] = x[i] + 5.0 + + x = qd.ndarray(qd.f32, shape=(n,)) + x.from_numpy(np.zeros(n, dtype=np.float32)) + + k(x) + num_tasks = _num_offloaded_tasks() + if _on_cuda(): + assert _graph_num_nodes() == num_tasks # single section -> no join node + + np.testing.assert_allclose(x.to_numpy(), 5.0) + + +@test_utils.test() +def test_graph_parallel_static_loop_empty_range(): + """An empty static range produces zero qd.graph_parallel sections: the region is a no-op (consistent + with wrapping the only section in `if qd.static(False)`). Serial work after it still runs.""" + n = 128 + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel_context(): + for _b in qd.static(range(0)): + with qd.graph_parallel(): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + for i in range(x.shape[0]): + x[i] = x[i] + 5.0 + + x = qd.ndarray(qd.f32, shape=(n,)) + x.from_numpy(np.zeros(n, dtype=np.float32)) + + k(x) + np.testing.assert_allclose(x.to_numpy(), 5.0) # region did nothing; only the serial +5 applied + + +@test_utils.test() +def test_graph_parallel_static_loop_nested(): + """Nested static loops fan out to N*M qd.graph_parallel sections, each writing a disjoint row.""" + ni, nj = 2, 2 + nrows = ni * nj + n = 64 + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=2)): + with qd.graph_parallel_context(): + for i in qd.static(range(ni)): + for j in qd.static(range(nj)): + with qd.graph_parallel(): + for c in range(x.shape[1]): + x[i * nj + j, c] = x[i * nj + j, c] + (i * nj + j + 1) + + x = qd.ndarray(qd.f32, shape=(nrows, n)) + x.from_numpy(np.zeros((nrows, n), dtype=np.float32)) + + k(x) + num_tasks = _num_offloaded_tasks() + if _on_cuda(): + assert _graph_num_nodes() == num_tasks + 1 # nrows qd.graph_parallel sections + one join + + out = x.to_numpy() + for r in range(nrows): + np.testing.assert_allclose(out[r], float(r + 1)) + + +@test_utils.test() +def test_graph_parallel_static_loop_mixed_with_static_if(): + """A static section loop and an `if qd.static(...)` optional qd.graph_parallel section coexist in one + region.""" + nb = 2 + n = 64 + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=2), y: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel_context(): + for b in qd.static(range(nb)): + with qd.graph_parallel(): + for i in range(x.shape[1]): + x[b, i] = x[b, i] + (b + 1) + if qd.static(True): + with qd.graph_parallel(): + for i in range(y.shape[0]): + y[i] = y[i] + 7.0 + + x = qd.ndarray(qd.f32, shape=(nb, n)) + y = qd.ndarray(qd.f32, shape=(n,)) + x.from_numpy(np.zeros((nb, n), dtype=np.float32)) + y.from_numpy(np.zeros(n, dtype=np.float32)) + + k(x, y) + num_tasks = _num_offloaded_tasks() + if _on_cuda(): + assert _graph_num_nodes() == num_tasks + 1 # nb + 1 qd.graph_parallel sections + one join + + out = x.to_numpy() + np.testing.assert_allclose(out[0], 1.0) + np.testing.assert_allclose(out[1], 2.0) + np.testing.assert_allclose(y.to_numpy(), 7.0) + + +@test_utils.test() +def test_graph_parallel_runtime_loop_raises(): + """A *runtime* for-loop in a region body stays rejected: only `qd.static(...)` loops unroll to literal + qd.graph_parallel sections; a runtime range would nest the section tagging inside a parallel range_for + (malformed).""" + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=2), nb: qd.i32): + with qd.graph_parallel_context(): + for b in range(nb): + with qd.graph_parallel(): + for i in range(x.shape[1]): + x[b, i] = x[b, i] + 1.0 + + x = qd.ndarray(qd.f32, shape=(2, 16)) + with pytest.raises(qd.QuadrantsSyntaxError, match="may contain only .with qd.graph_parallel"): + k(x, 2) + + +@test_utils.test() +def test_graph_parallel_takes_no_arguments(): + """qd.graph_parallel() (the section) takes no arguments. Any argument raises.""" + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=1)): + with qd.graph_parallel_context(): + with qd.graph_parallel(name="bx"): + for i in range(x.shape[0]): + x[i] = x[i] + 1.0 + + x = qd.ndarray(qd.f32, shape=(16,)) + with pytest.raises(qd.QuadrantsSyntaxError, match="qd.graph_parallel.. takes no arguments"): + k(x) + + +@test_utils.test() +def test_graph_parallel_static_loop_body_non_section_raises(): + """A static loop body must still be section-only: serial work inside the loop (outside any + qd.graph_parallel section) would silently fall outside a section, so it is rejected (the validator + recurses into the loop body).""" + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=2)): + with qd.graph_parallel_context(): + for b in qd.static(range(2)): + x[b, 0] = 1.0 # serial work outside any qd.graph_parallel section + with qd.graph_parallel(): + for i in range(x.shape[1]): + x[b, i] = x[b, i] + 1.0 + + x = qd.ndarray(qd.f32, shape=(2, 16)) + with pytest.raises(qd.QuadrantsSyntaxError, match="may contain only .with qd.graph_parallel"): + k(x) + + +@test_utils.test() +def test_graph_parallel_static_loop_runtime_inner_loop_raises(): + """Staticness is re-checked at every nesting level: a *runtime* loop nested inside a static loop and + wrapping a qd.graph_parallel section is still rejected (only the static unroll yields independent + sections).""" + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=2), m: qd.i32): + with qd.graph_parallel_context(): + for b in qd.static(range(2)): + for _j in range(m): # runtime loop around a section -> rejected + with qd.graph_parallel(): + for i in range(x.shape[1]): + x[b, i] = x[b, i] + 1.0 + + x = qd.ndarray(qd.f32, shape=(2, 16)) + with pytest.raises(qd.QuadrantsSyntaxError, match="may contain only .with qd.graph_parallel"): + k(x, 2) + + +@test_utils.test() +def test_graph_parallel_static_loop_inside_graph_do_while(): + """A static section loop composes with qd.graph_do_while: each iteration runs all unrolled + qd.graph_parallel sections, then decrements the counter.""" + nb = 2 + n = 64 + iters = 4 + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.i32, ndim=2), counter: qd.types.ndarray(qd.i32, ndim=0)): + while qd.graph_do_while(counter): + with qd.graph_parallel_context(): + for b in qd.static(range(nb)): + with qd.graph_parallel(): + for i in range(x.shape[1]): + x[b, i] = x[b, i] + (b + 1) + for _ in range(1): + counter[()] = counter[()] - 1 + + x = qd.ndarray(qd.i32, shape=(nb, n)) + counter = qd.ndarray(qd.i32, shape=()) + x.from_numpy(np.zeros((nb, n), dtype=np.int32)) + counter.from_numpy(np.array(iters, dtype=np.int32)) + + k(x, counter) + + assert counter.to_numpy() == 0 + out = x.to_numpy() + np.testing.assert_array_equal(out[0], np.full(n, iters, dtype=np.int32)) + np.testing.assert_array_equal(out[1], np.full(n, 2 * iters, dtype=np.int32)) From 94a9b5bdd435cde0e7f60edced96b8eb37441ea4 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 23 Jun 2026 10:24:57 -0700 Subject: [PATCH 19/25] refactor(graph): extract graph_parallel transformer + doc/wrap CI fixes Address the three Cursor-agent CI gates on PR #756: - Feature factorization: move the graph_parallel detection/validation/lowering out of ast_transformer.py into ast_transformers/graph_parallel_transformer.py (GraphParallelTransformer), mirroring checkpoint_transformer.py. build_With now calls the new class directly; function_def_transformer comments updated to point at it. - Line wrapping: re-wrap the graph_parallel_context docstring to 120c and drop a dangling reference to an internal design doc. - Doc quality (graph.md): reword pre-existing graph_do_while content the gate flagged -- replace "lowering" jargon, drop the internal "genesis-world" / hypothetical qd.graph_range_for / MAX_ITER design walk-through phrasing. --- docs/source/user_guide/graph.md | 11 +- python/quadrants/lang/ast/ast_transformer.py | 125 +--------------- .../function_def_transformer.py | 8 +- .../graph_parallel_transformer.py | 139 ++++++++++++++++++ 4 files changed, 154 insertions(+), 129 deletions(-) create mode 100644 python/quadrants/lang/ast/ast_transformers/graph_parallel_transformer.py diff --git a/docs/source/user_guide/graph.md b/docs/source/user_guide/graph.md index a39a309bcb..a155b1c67e 100644 --- a/docs/source/user_guide/graph.md +++ b/docs/source/user_guide/graph.md @@ -180,7 +180,7 @@ Therefore on unsupported platforms, you might consider creating a second impleme ## Checkpoints with `qd.checkpoint` *(experimental)* -> **Experimental.** `qd.checkpoint`, `qd.GraphStatus`, and `kernel.resume(from_checkpoint=...)` are experimental APIs. The shape of the public surface (the context-manager signature, the `@qd.kernel(checkpoints=True)` flag, the `GraphStatus` fields, the host-side resume loop, the error messages, and the cross-backend lowering details) may change in any future release without a deprecation cycle. +> **Experimental.** `qd.checkpoint`, `qd.GraphStatus`, and `kernel.resume(from_checkpoint=...)` are experimental APIs. The shape of the public surface (the context-manager signature, the `@qd.kernel(checkpoints=True)` flag, the `GraphStatus` fields, the host-side resume loop, the error messages, and the cross-backend compilation details) may change in any future release without a deprecation cycle. `qd.checkpoint` lets a graph kernel break partway through, surface a reason to the host, let the host fix things up, and resume from the same location on the next launch. An example use-case is an algorithm implemented as a graph that may need to allocate additional memory partway through, where the operations in the graph are in-place, and therefore cannot be rerun without changing/corrupting the output, and therefore for which simply retrying the whole graph from the start is not an option. @@ -269,7 +269,7 @@ while status.yielded: arr[i] = arr[i] + 1 ``` -The restriction is by design: each top-level statement inside a checkpoint becomes its own GPU task / graph node, so silently wrapping bare statements would hide a sequence of N field writes ballooning into N kernel launches. Forcing the user to write the `for`-wrap themselves keeps the lowering visible and gives a single obvious place to fuse multiple writes into one task by sharing a single wrapper. +The restriction is by design: each top-level statement inside a checkpoint becomes its own GPU task / graph node, so silently wrapping bare statements would hide a sequence of N field writes ballooning into N kernel launches. Forcing the user to write the `for`-wrap themselves keeps the mapping to GPU tasks visible and gives a single obvious place to fuse multiple writes into one task by sharing a single wrapper. ## Performance @@ -394,11 +394,9 @@ k1(a, count) The recommendation is to use the graph do while here anyway, if you need it for any platform, in order to ensure the code is compact and maintainable. -If you do want fixed-size for loops to run optimally on unsupported hardware platforms, we could add a specializd `qd.graph_range_for` function. This would: -- on graph-do-while-supported hardware: handle adding the additional increment kernel -- on graph-do-while-unsupported hardware: handle running the loop entirely on the host-side, to avoid adding a gpu pipeline stall +If you need fixed-size for loops to run optimally on hardware without graph-do-while support, consider opening a PR for a dedicated helper that picks the host-side fallback automatically on those backends. -In practice, for our own kernels, i.e. in genesis-world, they largely fall under the do while formulation, see the previous section. However, also have some that used to be do while, but have been migrated to an optimized fixed-size, see next section. +In practice, most such loops fall under the do while formulation (see the previous section). Some that were originally do while have since been migrated to an optimized fixed-size form (see the next section). ### A while loop, conditional on a device-side scalar tensor, that has been optimized into a fixed-size for loop @@ -466,7 +464,6 @@ In this case, our recommendation is: - use graph do while anyway, if you need it on any platform - this will ensure your code is compact and maintainable - if you need optimum 100% performance on unsupported platforms, then consider PRing onto quadrants an optimized graph implementation for your target platform - - for example it could somehow run MAX_ITER iterations anyway, similar to the earlier hand-rolled version, but via the graph abstraction, hence allowing the code to be compact, cross-platform, and also optimally fast ## `qd.graph_parallel` sections with `qd.graph_parallel_context` *(experimental)* diff --git a/python/quadrants/lang/ast/ast_transformer.py b/python/quadrants/lang/ast/ast_transformer.py index 2089bde66b..47749650dd 100644 --- a/python/quadrants/lang/ast/ast_transformer.py +++ b/python/quadrants/lang/ast/ast_transformer.py @@ -32,6 +32,9 @@ from quadrants.lang.ast.ast_transformers.function_def_transformer import ( FunctionDefTransformer, ) +from quadrants.lang.ast.ast_transformers.graph_parallel_transformer import ( + GraphParallelTransformer, +) from quadrants.lang.exception import ( QuadrantsIndexError, QuadrantsRuntimeTypeError, @@ -1372,37 +1375,6 @@ def _is_checkpoint_call(node: ast.expr, global_vars: dict): ``CheckpointCallInfo`` or ``None``.""" return CheckpointTransformer.is_checkpoint_call(node, global_vars) - @staticmethod - def _is_graph_parallel_context_call(node: ast.expr) -> bool: - """If *node* is a ``qd.graph_parallel_context()`` call return True, else False.""" - if not isinstance(node, ast.Call): - return False - func = node.func - is_gpc = (isinstance(func, ast.Attribute) and func.attr == "graph_parallel_context") or ( - isinstance(func, ast.Name) and func.id == "graph_parallel_context" - ) - if not is_gpc: - return False - if node.args or node.keywords: - raise QuadrantsSyntaxError("qd.graph_parallel_context() takes no arguments") - return True - - @staticmethod - def _is_parallel_section_call(node: ast.expr) -> bool: - """If *node* is a ``qd.graph_parallel()`` (a section) call return True, else False. The call shape - is validated here so misuse raises at the ``with`` site rather than later.""" - if not isinstance(node, ast.Call): - return False - func = node.func - is_parallel_section = (isinstance(func, ast.Attribute) and func.attr == "graph_parallel") or ( - isinstance(func, ast.Name) and func.id == "graph_parallel" - ) - if not is_parallel_section: - return False - if node.args or node.keywords: - raise QuadrantsSyntaxError("qd.graph_parallel() takes no arguments") - return True - @staticmethod def build_While(ctx: ASTTransformerFuncContext, node: ast.While) -> None: if node.orelse: @@ -1646,11 +1618,11 @@ def build_With(ctx: ASTTransformerFuncContext, node: ast.With) -> None: if checkpoint_info is not None: return ASTTransformer._build_checkpoint_with(ctx, node, checkpoint_info) - if ASTTransformer._is_graph_parallel_context_call(item.context_expr): - return ASTTransformer._build_graph_parallel_context_with(ctx, node) + if GraphParallelTransformer.is_graph_parallel_context_call(item.context_expr): + return GraphParallelTransformer.build_graph_parallel_context_with(ctx, node, build_stmts) - if ASTTransformer._is_parallel_section_call(item.context_expr): - return ASTTransformer._build_parallel_section_with(ctx, node) + if GraphParallelTransformer.is_parallel_section_call(item.context_expr): + return GraphParallelTransformer.build_parallel_section_with(ctx, node, build_stmts) if not FunctionDefTransformer._is_stream_parallel_with(node, ctx.global_vars): raise QuadrantsSyntaxError( @@ -1674,89 +1646,6 @@ def _build_checkpoint_with( ``ast_transformers/checkpoint_transformer.py``.""" return CheckpointTransformer.build_checkpoint_with(ctx, node, info, build_stmts) - @staticmethod - def _build_graph_parallel_context_with(ctx: ASTTransformerFuncContext, node: ast.With) -> None: - """Handles ``with qd.graph_parallel_context():`` fork/join regions. - - Validates the use-site (kernel must be graph=True, no nesting) and that the region body contains - only ``with qd.graph_parallel():`` blocks, then walks the body. The region emits no IR tag of its - own -- each ``qd.graph_parallel`` section inside lowers to a stream-parallel group (via - begin/end_stream_parallel), and the graph builder forks the distinct groups in a contiguous run and - joins them. Regions are kept apart by the serial work between them (see - d3_0_graph_parallel_impl.md).""" - if not ctx.is_kernel: - raise QuadrantsSyntaxError("qd.graph_parallel_context() can only be used inside @qd.kernel, not @qd.func") - kernel = ctx.global_context.current_kernel - if kernel is None or not kernel.use_graph: - raise QuadrantsSyntaxError("qd.graph_parallel_context() requires @qd.kernel(graph=True)") - if getattr(ctx, "_in_graph_parallel_context", False): - raise QuadrantsSyntaxError("qd.graph_parallel_context() regions cannot be nested") - if getattr(ctx, "_in_parallel_section", False): - raise QuadrantsSyntaxError("qd.graph_parallel_context() cannot appear inside a qd.graph_parallel() body") - ASTTransformer._validate_graph_parallel_context_body(ctx, node.body) - ctx._in_graph_parallel_context = True - try: - build_stmts(ctx, node.body) - finally: - ctx._in_graph_parallel_context = False - return None - - @staticmethod - def _validate_graph_parallel_context_body(ctx: ASTTransformerFuncContext, stmts: list[ast.stmt]) -> None: - """A qd.graph_parallel_context() region body may contain only `with qd.graph_parallel():` blocks, - optionally wrapped in compile-time `if qd.static(...)` (the optional ``qd.graph_parallel`` section - pattern, e.g. qipc's ENABLE_EE) or `for ... in qd.static(...)` loops (generate one - ``qd.graph_parallel`` section per element of a compile-time sequence). Docstrings / coverage probes / - `pass` are allowed. Anything else (a runtime for-loop, a bare assignment, etc.) is a serial task that - would silently fall outside any ``qd.graph_parallel`` section, so reject it. - - The `for` case is restricted to `qd.static(...)` loops on purpose: a static loop unrolls at trace - time into its repeated body, so it lowers to literal `with qd.graph_parallel():` blocks (each gets a - fresh stream_parallel_group_id). A *runtime* for-loop would instead trace a single parallel range_for - with the section tagging nested inside it -- malformed. Staticness is checked with `get_decorator` - (the same resolution `build_For` uses) at every nesting level, so a runtime loop nested under a - static one is still rejected.""" - for i, stmt in enumerate(stmts): - if FunctionDefTransformer._is_docstring(stmt, i) or FunctionDefTransformer._is_coverage_probe(stmt): - continue - if isinstance(stmt, ast.Pass): - continue - if isinstance(stmt, ast.With) and stmt.items: - if ASTTransformer._is_parallel_section_call(stmt.items[0].context_expr): - continue - if isinstance(stmt, ast.If): - ASTTransformer._validate_graph_parallel_context_body(ctx, stmt.body) - ASTTransformer._validate_graph_parallel_context_body(ctx, stmt.orelse) - continue - if isinstance(stmt, ast.For) and not stmt.orelse and get_decorator(ctx, stmt.iter) == "static": - ASTTransformer._validate_graph_parallel_context_body(ctx, stmt.body) - continue - raise QuadrantsSyntaxError( - "A qd.graph_parallel_context() region may contain only 'with qd.graph_parallel():' blocks " - "(optionally inside 'if qd.static(...)' or 'for ... in qd.static(...)'). Move other work " - f"outside the region. [offending stmt {i}: {type(stmt).__name__}]" - ) - - @staticmethod - def _build_parallel_section_with(ctx: ASTTransformerFuncContext, node: ast.With) -> None: - """Handles a ``with qd.graph_parallel():`` section of a ``qd.graph_parallel_context()`` region. - - Reuses the stream-parallel tagging: begin_stream_parallel() assigns this ``qd.graph_parallel`` - section a fresh ``stream_parallel_group_id`` that every for-loop in the body inherits, so the - offloaded tasks carry the ``qd.graph_parallel`` section id all the way to the graph builder.""" - if not getattr(ctx, "_in_graph_parallel_context", False): - raise QuadrantsSyntaxError( - "qd.graph_parallel() can only be used directly inside a qd.graph_parallel_context() region" - ) - ctx._in_parallel_section = True - ctx.ast_builder.begin_stream_parallel() - try: - build_stmts(ctx, node.body) - finally: - ctx.ast_builder.end_stream_parallel() - ctx._in_parallel_section = False - return None - @staticmethod def build_Pass(ctx: ASTTransformerFuncContext, node: ast.Pass) -> None: return None diff --git a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py index 67b32a9bb1..69cd102472 100644 --- a/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/function_def_transformer.py @@ -602,7 +602,7 @@ def _is_loop_config_call(stmt: ast.stmt) -> bool: @staticmethod def _is_graph_parallel_context_with(stmt: ast.stmt) -> bool: - """Syntactic check matching ASTTransformer._is_graph_parallel_context_call: a + """Syntactic check matching GraphParallelTransformer.is_graph_parallel_context_call: a ``with qd.graph_parallel_context():`` fork/join region.""" if not isinstance(stmt, ast.With) or len(stmt.items) != 1: return False @@ -618,7 +618,7 @@ def _is_graph_parallel_context_with(stmt: ast.stmt) -> bool: @staticmethod def _is_parallel_section_with(stmt: ast.stmt) -> bool: - """Syntactic check matching ASTTransformer._is_parallel_section_call: a ``with qd.graph_parallel(...):`` + """Syntactic check matching GraphParallelTransformer.is_parallel_section_call: a ``with qd.graph_parallel(...):`` section of a ``qd.graph_parallel_context()`` region.""" if not isinstance(stmt, ast.With) or len(stmt.items) != 1: return False @@ -695,7 +695,7 @@ def _validate_graph_do_while_stmt_list(stmts: list[ast.stmt], is_kernel_top: boo # A `with qd.graph_parallel_context()` region groups concurrent `with qd.graph_parallel()` # sections; it is a legal sibling of for-loops / checkpoints. Its body must be # `qd.graph_parallel` section blocks (optionally under `if qd.static(...)`); the full check - # is in ASTTransformer._build_graph_parallel_context_with. Each `qd.graph_parallel` section's + # is in GraphParallelTransformer.build_graph_parallel_context_with. Each `qd.graph_parallel` section's # body is task territory, validated here with the in-loop rules. Descend through `if` members # so `qd.graph_parallel` sections inside an optional `if qd.static(...)` are reached too. pending = list(stmt.body) @@ -709,7 +709,7 @@ def _validate_graph_do_while_stmt_list(stmts: list[ast.stmt], is_kernel_top: boo elif isinstance(member, ast.For): # `for ... in qd.static(...)` generates sections; descend so each unrolled section # body is still validated with the in-loop rules (a runtime for here is rejected - # earlier by ASTTransformer._build_graph_parallel_context_with). + # earlier by GraphParallelTransformer.build_graph_parallel_context_with). pending.extend(member.body) continue where = "the kernel body" if is_kernel_top else "a qd.graph_do_while() body" diff --git a/python/quadrants/lang/ast/ast_transformers/graph_parallel_transformer.py b/python/quadrants/lang/ast/ast_transformers/graph_parallel_transformer.py new file mode 100644 index 0000000000..d3fa27d312 --- /dev/null +++ b/python/quadrants/lang/ast/ast_transformers/graph_parallel_transformer.py @@ -0,0 +1,139 @@ +# type: ignore +"""AST recognition, validation, and lowering for ``qd.graph_parallel_context()`` / ``qd.graph_parallel()`` blocks. + +Lives alongside ``checkpoint_transformer.py`` / ``function_def_transformer.py`` so that ``ast_transformer.py`` doesn't +have to grow per-feature. ``ASTTransformer.build_With`` forwards ``qd.graph_parallel_context()`` regions and their +``qd.graph_parallel()`` sections into the static methods here. + +A ``qd.graph_parallel_context()`` region emits no IR tag of its own: each ``qd.graph_parallel()`` section inside lowers +to a stream-parallel group (via ``begin/end_stream_parallel``), and the graph builder forks the distinct groups in a +contiguous run and joins them. Regions are kept apart by the serial work between them. See +``docs/source/user_guide/graph.md`` for the user-facing surface. +""" + +from __future__ import annotations + +import ast + +from quadrants.lang.ast.ast_transformer_utils import ( + ASTTransformerFuncContext, + get_decorator, +) +from quadrants.lang.ast.ast_transformers.function_def_transformer import ( + FunctionDefTransformer, +) +from quadrants.lang.exception import QuadrantsSyntaxError + + +class GraphParallelTransformer: + @staticmethod + def is_graph_parallel_context_call(node: ast.expr) -> bool: + """If *node* is a ``qd.graph_parallel_context()`` call return True, else False.""" + if not isinstance(node, ast.Call): + return False + func = node.func + is_gpc = (isinstance(func, ast.Attribute) and func.attr == "graph_parallel_context") or ( + isinstance(func, ast.Name) and func.id == "graph_parallel_context" + ) + if not is_gpc: + return False + if node.args or node.keywords: + raise QuadrantsSyntaxError("qd.graph_parallel_context() takes no arguments") + return True + + @staticmethod + def is_parallel_section_call(node: ast.expr) -> bool: + """If *node* is a ``qd.graph_parallel()`` (a section) call return True, else False. The call shape is validated + here so misuse raises at the ``with`` site rather than later.""" + if not isinstance(node, ast.Call): + return False + func = node.func + is_parallel_section = (isinstance(func, ast.Attribute) and func.attr == "graph_parallel") or ( + isinstance(func, ast.Name) and func.id == "graph_parallel" + ) + if not is_parallel_section: + return False + if node.args or node.keywords: + raise QuadrantsSyntaxError("qd.graph_parallel() takes no arguments") + return True + + @staticmethod + def build_graph_parallel_context_with(ctx: ASTTransformerFuncContext, node: ast.With, build_stmts) -> None: + """Handles ``with qd.graph_parallel_context():`` fork/join regions. + + Validates the use-site (kernel must be graph=True, no nesting) and that the region body contains only + ``with qd.graph_parallel():`` blocks, then walks the body. The region emits no IR tag of its own -- each + ``qd.graph_parallel`` section inside lowers to a stream-parallel group (via begin/end_stream_parallel), and the + graph builder forks the distinct groups in a contiguous run and joins them. Regions are kept apart by the serial + work between them.""" + if not ctx.is_kernel: + raise QuadrantsSyntaxError("qd.graph_parallel_context() can only be used inside @qd.kernel, not @qd.func") + kernel = ctx.global_context.current_kernel + if kernel is None or not kernel.use_graph: + raise QuadrantsSyntaxError("qd.graph_parallel_context() requires @qd.kernel(graph=True)") + if getattr(ctx, "_in_graph_parallel_context", False): + raise QuadrantsSyntaxError("qd.graph_parallel_context() regions cannot be nested") + if getattr(ctx, "_in_parallel_section", False): + raise QuadrantsSyntaxError("qd.graph_parallel_context() cannot appear inside a qd.graph_parallel() body") + GraphParallelTransformer._validate_graph_parallel_context_body(ctx, node.body) + ctx._in_graph_parallel_context = True + try: + build_stmts(ctx, node.body) + finally: + ctx._in_graph_parallel_context = False + return None + + @staticmethod + def _validate_graph_parallel_context_body(ctx: ASTTransformerFuncContext, stmts: list[ast.stmt]) -> None: + """A qd.graph_parallel_context() region body may contain only `with qd.graph_parallel():` blocks, optionally + wrapped in compile-time `if qd.static(...)` (the optional ``qd.graph_parallel`` section pattern, e.g. qipc's + ENABLE_EE) or `for ... in qd.static(...)` loops (generate one ``qd.graph_parallel`` section per element of a + compile-time sequence). Docstrings / coverage probes / `pass` are allowed. Anything else (a runtime for-loop, a + bare assignment, etc.) is a serial task that would silently fall outside any ``qd.graph_parallel`` section, so + reject it. + + The `for` case is restricted to `qd.static(...)` loops on purpose: a static loop unrolls at trace time into its + repeated body, so it lowers to literal `with qd.graph_parallel():` blocks (each gets a fresh + stream_parallel_group_id). A *runtime* for-loop would instead trace a single parallel range_for with the section + tagging nested inside it -- malformed. Staticness is checked with `get_decorator` (the same resolution + `build_For` uses) at every nesting level, so a runtime loop nested under a static one is still rejected.""" + for i, stmt in enumerate(stmts): + if FunctionDefTransformer._is_docstring(stmt, i) or FunctionDefTransformer._is_coverage_probe(stmt): + continue + if isinstance(stmt, ast.Pass): + continue + if isinstance(stmt, ast.With) and stmt.items: + if GraphParallelTransformer.is_parallel_section_call(stmt.items[0].context_expr): + continue + if isinstance(stmt, ast.If): + GraphParallelTransformer._validate_graph_parallel_context_body(ctx, stmt.body) + GraphParallelTransformer._validate_graph_parallel_context_body(ctx, stmt.orelse) + continue + if isinstance(stmt, ast.For) and not stmt.orelse and get_decorator(ctx, stmt.iter) == "static": + GraphParallelTransformer._validate_graph_parallel_context_body(ctx, stmt.body) + continue + raise QuadrantsSyntaxError( + "A qd.graph_parallel_context() region may contain only 'with qd.graph_parallel():' blocks " + "(optionally inside 'if qd.static(...)' or 'for ... in qd.static(...)'). Move other work " + f"outside the region. [offending stmt {i}: {type(stmt).__name__}]" + ) + + @staticmethod + def build_parallel_section_with(ctx: ASTTransformerFuncContext, node: ast.With, build_stmts) -> None: + """Handles a ``with qd.graph_parallel():`` section of a ``qd.graph_parallel_context()`` region. + + Reuses the stream-parallel tagging: begin_stream_parallel() assigns this ``qd.graph_parallel`` section a fresh + ``stream_parallel_group_id`` that every for-loop in the body inherits, so the offloaded tasks carry the + ``qd.graph_parallel`` section id all the way to the graph builder.""" + if not getattr(ctx, "_in_graph_parallel_context", False): + raise QuadrantsSyntaxError( + "qd.graph_parallel() can only be used directly inside a qd.graph_parallel_context() region" + ) + ctx._in_parallel_section = True + ctx.ast_builder.begin_stream_parallel() + try: + build_stmts(ctx, node.body) + finally: + ctx.ast_builder.end_stream_parallel() + ctx._in_parallel_section = False + return None From c5337553d76feaf56c6a55e1355128f0564c86b1 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Tue, 23 Jun 2026 12:53:27 -0700 Subject: [PATCH 20/25] fix(graph): add qd.graph_parallel{,_context} to test_api expected surface test_api.py asserts the exact set of public qd.* symbols; the renamed context managers were missing from the expected list, failing the test across every platform. Also reflow a misc.py docstring line to 120c. --- python/quadrants/lang/misc.py | 4 ++-- tests/python/test_api.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/quadrants/lang/misc.py b/python/quadrants/lang/misc.py index fa36e67f86..fbc542179a 100644 --- a/python/quadrants/lang/misc.py +++ b/python/quadrants/lang/misc.py @@ -781,8 +781,8 @@ def graph_parallel_context(): - Regions cannot be nested, and a ``qd.graph_parallel`` section body must be straight-line task work (no nested ``qd.graph_do_while``, ``qd.checkpoint``, or ``qd.graph_parallel_context``). - This function should not be called directly at runtime; it is recognized and transformed during AST - compilation. At Python runtime (outside kernels) it is a no-op context manager. + This function should not be called directly at runtime; it is recognized and transformed during AST compilation. + At Python runtime (outside kernels) it is a no-op context manager. See also ``docs/source/user_guide/graph.md``. """ diff --git a/tests/python/test_api.py b/tests/python/test_api.py index 824b6ea7a6..e6b316d911 100644 --- a/tests/python/test_api.py +++ b/tests/python/test_api.py @@ -157,6 +157,8 @@ def _get_expected_matrix_apis(): "global_thread_idx", "gpu", "graph_do_while", + "graph_parallel", + "graph_parallel_context", "grouped", "i", "i16", From 9f35526867aea0459f6e01aaaa14dbd83bbc0375 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 24 Jun 2026 08:52:33 -0400 Subject: [PATCH 21/25] fix up doc ci flagged issues --- docs/source/user_guide/graph.md | 5 +++-- docs/source/user_guide/streams.md | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/source/user_guide/graph.md b/docs/source/user_guide/graph.md index a155b1c67e..cb8a6ff107 100644 --- a/docs/source/user_guide/graph.md +++ b/docs/source/user_guide/graph.md @@ -162,7 +162,7 @@ Note that `qd.func`'s are inlined, so you can freely factorize these structures ### Restrictions -- The counter ndarray may be swapped between calls: the cached graph reads each counter through an indirection slot that is refreshed on every launch, so passing a different ndarray (or alternating between several) replays the cached graph without rebuilding it. +- The counter ndarray may be swapped between calls. Passing a different ndarray (or alternating between several) replays the cached graph without rebuilding it. ### Caveats @@ -255,7 +255,7 @@ while status.yielded: ### Restrictions -- Must be used inside `@qd.kernel(graph=True, checkpoints=True)`. Without the flag, `qd.checkpoint(...)` raises `QuadrantsSyntaxError` at compile time with a fix-it pointing at `checkpoints=True`. +- Must be used inside `@qd.kernel(graph=True, checkpoints=True)`. Without the flag, `qd.checkpoint(...)` raises `QuadrantsSyntaxError` at compile time. - `cp_id` must be an int literal or an `IntEnum` value, and must be unique across the kernel. - `yield_on=` must be a kernel parameter that is a 0-d `qd.types.ndarray(qd.i32, ndim=0)`; expressions are not supported. - Checkpoints cannot be nested inside other checkpoints. Checkpoints inside a `qd.graph_do_while` body are fine. @@ -498,6 +498,7 @@ def step(...): `qd.graph_parallel` sections do not have to be written out one by one. A `for ... in qd.static(...)` loop is unrolled at compile time, so each iteration that contains a `with qd.graph_parallel():` becomes its own section — handy for forking one section per element of a static list (e.g. per contact type): +(Note: See [compound_types.md](compound_types.md) for qd.data_oriented description) ```python @qd.data_oriented class Solver: diff --git a/docs/source/user_guide/streams.md b/docs/source/user_guide/streams.md index f9a79e3f48..a3df7cedc2 100644 --- a/docs/source/user_guide/streams.md +++ b/docs/source/user_guide/streams.md @@ -136,4 +136,4 @@ with qd.create_stream() as s: ## Limitations - **Not compatible with graphs.** Do not pass `qd_stream` to a kernel decorated with `graph=True` (if you do, a `RuntimeError` will be raised). -- **Not compatible with autodiff.** Do not pass `qd_stream` to a kernel that uses reverse-mode or forward-mode differentiation, or inside a `qd.ad.Tape` context (if you do, a `RuntimeError` will be raised). +- **Not compatible with [autodiff.md](autodiff.md).** Do not pass `qd_stream` to a kernel that uses reverse-mode or forward-mode differentiation, or inside a `qd.ad.Tape` context (if you do, a `RuntimeError` will be raised). From 4c80eb0a973e5e6fc0900e45840182cd12c79c14 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 24 Jun 2026 10:07:24 -0400 Subject: [PATCH 22/25] address doc CI issues --- docs/source/user_guide/graph.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/user_guide/graph.md b/docs/source/user_guide/graph.md index cb8a6ff107..eb308c27a5 100644 --- a/docs/source/user_guide/graph.md +++ b/docs/source/user_guide/graph.md @@ -4,7 +4,7 @@ Graphs reduce kernel launch overhead by capturing a sequence of GPU operations i ## Backend support -`graph=True` and `graph_do_while` run on every backend. They are *hardware accelerated* on CUDA (via CUDA graphs) and AMDGPU (via HIP graphs); `graph_do_while` additionally requires CUDA SM 9.0+ / Hopper for its hardware-accelerated path. On other backends, `graph=True` is silently ignored and the kernel runs via the normal launch path, and `graph_do_while` falls back to a host-side do-while loop that copies the condition value GPU → host each iteration (causing a pipeline stall). `qd.checkpoint` gating runs entirely on the device on every GPU backend; only the CPU backend uses host-side gating. +`graph=True` and `graph_do_while` run on every backend. They are *hardware accelerated* on CUDA (via CUDA graphs) and AMDGPU (via HIP graphs); `graph_do_while` additionally requires [CUDA SM 9.0+](https://developer.nvidia.com/cuda/gpus) for its hardware-accelerated path. On other backends, `graph=True` is silently ignored and the kernel runs via the normal launch path, and `graph_do_while` falls back to a host-side do-while loop. `qd.checkpoint` gating runs entirely on the device on every GPU backend. | Feature | `qd.cuda` SM 9.0+ | `qd.cuda` < SM 9.0 | `qd.amdgpu` | `qd.metal` | `qd.vulkan` | `qd.cpu` | | --- | --- | --- | --- | --- | --- | --- | @@ -45,7 +45,7 @@ my_kernel(x, y) # first call: builds and caches the graph my_kernel(x, y) # subsequent calls: replays the cached graph ``` -This works the same way on CUDA and AMDGPU. The cache is keyed per (compiled-kernel-specialization, launch-id), so different template instantiations (different field bindings, etc.) get their own cached graph. +This works the same way on CUDA and AMDGPU. ### Restrictions @@ -491,7 +491,7 @@ def step(...): ### Semantics -- **Fork / join.** Every `qd.graph_parallel` section in the region forks from the work that precedes the region. All `qd.graph_parallel` sections must finish before any work *after* the region begins (the join). On CUDA the join is a single empty graph node depending on every `qd.graph_parallel` section's last kernel. +- **Fork / join.** Every `qd.graph_parallel` section in the region forks from the work that precedes the region. All `qd.graph_parallel` sections must finish before any work *after* the region begins (the join). - **`qd.graph_parallel` sections are independent — you guarantee it.** Calls *within* a `qd.graph_parallel` section keep their program order, but calls in *different* `qd.graph_parallel` sections have no ordering. The `qd.graph_parallel` sections must be data-race free with respect to one another: no `qd.graph_parallel` section may read what another writes, and no two `qd.graph_parallel` sections may write the same memory. Quadrants does not check this; getting it wrong gives nondeterministic results. ### Generating `qd.graph_parallel` sections from a compile-time sequence From 0ba1929700c7e6ed4fa0122ae73933cde5e1214a Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 24 Jun 2026 06:44:18 -0700 Subject: [PATCH 23/25] fix(graph): keep joins between back-to-back qd.graph_parallel_context regions Two qd.graph_parallel_context() regions written back-to-back (no serial work between them) were merged into one fork/join by the CUDA graph builder: the run-extension loop keyed only on stream_parallel_group_id != 0, with no region boundary marker, so the second region's sections forked from the same entry as the first's and could race the first region's writes. Thread a per-kernel graph_parallel_region_id from the AST builder (begin/end_graph_parallel_context) through ForLoopConfig -> FrontendForStmt -> RangeForStmt/StructForStmt -> OffloadedStmt -> OffloadedTask, and require equal region id when extending a fork/join run in build_level so adjacent regions each get their own join. AMD already serializes sections, so only the CUDA builder needed the guard. Adds a regression test. --- .../graph_parallel_transformer.py | 18 +++--- quadrants/codegen/amdgpu/codegen_amdgpu.cpp | 1 + quadrants/codegen/cuda/codegen_cuda.cpp | 1 + quadrants/codegen/llvm/llvm_compiled_data.h | 5 ++ quadrants/ir/frontend_ir.cpp | 6 ++ quadrants/ir/frontend_ir.h | 25 ++++++++ quadrants/ir/ir.h | 6 ++ quadrants/ir/statements.cpp | 3 + quadrants/ir/statements.h | 13 +++++ quadrants/python/export_lang.cpp | 2 + quadrants/runtime/cuda/graph_manager.cpp | 10 +++- quadrants/transforms/lower_ast.cpp | 3 + quadrants/transforms/offload.cpp | 11 +++- tests/python/test_graph_parallel.py | 58 +++++++++++++++++++ 14 files changed, 151 insertions(+), 11 deletions(-) diff --git a/python/quadrants/lang/ast/ast_transformers/graph_parallel_transformer.py b/python/quadrants/lang/ast/ast_transformers/graph_parallel_transformer.py index d3fa27d312..5ca9a2a896 100644 --- a/python/quadrants/lang/ast/ast_transformers/graph_parallel_transformer.py +++ b/python/quadrants/lang/ast/ast_transformers/graph_parallel_transformer.py @@ -5,9 +5,10 @@ have to grow per-feature. ``ASTTransformer.build_With`` forwards ``qd.graph_parallel_context()`` regions and their ``qd.graph_parallel()`` sections into the static methods here. -A ``qd.graph_parallel_context()`` region emits no IR tag of its own: each ``qd.graph_parallel()`` section inside lowers -to a stream-parallel group (via ``begin/end_stream_parallel``), and the graph builder forks the distinct groups in a -contiguous run and joins them. Regions are kept apart by the serial work between them. See +A ``qd.graph_parallel_context()`` region tags its body with a per-kernel region id (via +``begin/end_graph_parallel_context``) and each ``qd.graph_parallel()`` section inside lowers to a stream-parallel group +(via ``begin/end_stream_parallel``). The graph builder forks the distinct groups of one region in a contiguous run and +joins them; the region id keeps two back-to-back regions apart (each gets its own join). See ``docs/source/user_guide/graph.md`` for the user-facing surface. """ @@ -62,10 +63,11 @@ def build_graph_parallel_context_with(ctx: ASTTransformerFuncContext, node: ast. """Handles ``with qd.graph_parallel_context():`` fork/join regions. Validates the use-site (kernel must be graph=True, no nesting) and that the region body contains only - ``with qd.graph_parallel():`` blocks, then walks the body. The region emits no IR tag of its own -- each - ``qd.graph_parallel`` section inside lowers to a stream-parallel group (via begin/end_stream_parallel), and the - graph builder forks the distinct groups in a contiguous run and joins them. Regions are kept apart by the serial - work between them.""" + ``with qd.graph_parallel():`` blocks, then walks the body. The region is bracketed with + begin/end_graph_parallel_context() so its body carries a per-kernel region id, and each ``qd.graph_parallel`` + section inside lowers to a stream-parallel group (via begin/end_stream_parallel). The graph builder forks the + distinct groups of one region in a contiguous run and joins them; the region id keeps two back-to-back regions + apart (each gets its own join).""" if not ctx.is_kernel: raise QuadrantsSyntaxError("qd.graph_parallel_context() can only be used inside @qd.kernel, not @qd.func") kernel = ctx.global_context.current_kernel @@ -77,9 +79,11 @@ def build_graph_parallel_context_with(ctx: ASTTransformerFuncContext, node: ast. raise QuadrantsSyntaxError("qd.graph_parallel_context() cannot appear inside a qd.graph_parallel() body") GraphParallelTransformer._validate_graph_parallel_context_body(ctx, node.body) ctx._in_graph_parallel_context = True + ctx.ast_builder.begin_graph_parallel_context() try: build_stmts(ctx, node.body) finally: + ctx.ast_builder.end_graph_parallel_context() ctx._in_graph_parallel_context = False return None diff --git a/quadrants/codegen/amdgpu/codegen_amdgpu.cpp b/quadrants/codegen/amdgpu/codegen_amdgpu.cpp index ccba1e701b..291bbe264b 100644 --- a/quadrants/codegen/amdgpu/codegen_amdgpu.cpp +++ b/quadrants/codegen/amdgpu/codegen_amdgpu.cpp @@ -417,6 +417,7 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM { } current_task->block_dim = stmt->block_dim; current_task->stream_parallel_group_id = stmt->stream_parallel_group_id; + current_task->graph_parallel_region_id = stmt->graph_parallel_region_id; current_task->checkpoint_id = stmt->checkpoint_id; QD_ASSERT(current_task->grid_dim != 0); QD_ASSERT(current_task->block_dim != 0); diff --git a/quadrants/codegen/cuda/codegen_cuda.cpp b/quadrants/codegen/cuda/codegen_cuda.cpp index 4e19864055..6961a97878 100644 --- a/quadrants/codegen/cuda/codegen_cuda.cpp +++ b/quadrants/codegen/cuda/codegen_cuda.cpp @@ -673,6 +673,7 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { current_task->block_dim = stmt->block_dim; current_task->dynamic_shared_array_bytes = dynamic_shared_array_bytes; current_task->stream_parallel_group_id = stmt->stream_parallel_group_id; + current_task->graph_parallel_region_id = stmt->graph_parallel_region_id; current_task->checkpoint_id = stmt->checkpoint_id; QD_ASSERT(current_task->grid_dim != 0); QD_ASSERT(current_task->block_dim != 0); diff --git a/quadrants/codegen/llvm/llvm_compiled_data.h b/quadrants/codegen/llvm/llvm_compiled_data.h index 7a08c77be5..8f3bed28a1 100644 --- a/quadrants/codegen/llvm/llvm_compiled_data.h +++ b/quadrants/codegen/llvm/llvm_compiled_data.h @@ -118,6 +118,10 @@ class OffloadedTask { int grid_dim{0}; int dynamic_shared_array_bytes{0}; int stream_parallel_group_id{0}; + // Per-kernel `qd.graph_parallel_context()` region id (0 outside any region). Populated by the CUDA LLVM codegen from + // `OffloadedStmt::graph_parallel_region_id`. The GraphManager pairs it with `stream_parallel_group_id` so two + // back-to-back regions are built as separate fork/join groups (each with its own join) instead of one merged group. + int graph_parallel_region_id{0}; // `cp_id` of the enclosing `qd.checkpoint(...)` block for this task (`-1` outside any checkpoint). Populated by the // CUDA / AMDGPU LLVM codegen from `OffloadedStmt::checkpoint_id` (set by the offload pass from // `RangeForStmt::checkpoint_id` / `StructForStmt::checkpoint_id`). The GraphManager will consume this in slice 1c to @@ -167,6 +171,7 @@ class OffloadedTask { grid_dim, dynamic_shared_array_bytes, stream_parallel_group_id, + graph_parallel_region_id, checkpoint_id, graph_do_while_level_id, ad_stack, diff --git a/quadrants/ir/frontend_ir.cpp b/quadrants/ir/frontend_ir.cpp index 7b0f4fa04f..75cad9f380 100644 --- a/quadrants/ir/frontend_ir.cpp +++ b/quadrants/ir/frontend_ir.cpp @@ -111,6 +111,7 @@ FrontendForStmt::FrontendForStmt(const FrontendForStmt &o) mem_access_opt(o.mem_access_opt), block_dim(o.block_dim), stream_parallel_group_id(o.stream_parallel_group_id), + graph_parallel_region_id(o.graph_parallel_region_id), graph_do_while_level_id(o.graph_do_while_level_id), checkpoint_id(o.checkpoint_id), loop_name(o.loop_name) { @@ -122,6 +123,7 @@ void FrontendForStmt::init_config(Arch arch, const ForLoopConfig &config) { mem_access_opt = config.mem_access_opt; block_dim = config.block_dim; stream_parallel_group_id = config.stream_parallel_group_id; + graph_parallel_region_id = config.graph_parallel_region_id; graph_do_while_level_id = config.graph_do_while_level_id; checkpoint_id = config.checkpoint_id; loop_name = config.loop_name; @@ -1510,6 +1512,7 @@ void ASTBuilder::warn_if_named_nested_loop() { void ASTBuilder::begin_frontend_range_for(const Expr &i, const Expr &s, const Expr &e, const DebugInfo &dbg_info) { warn_if_named_nested_loop(); for_loop_dec_.config.stream_parallel_group_id = current_stream_parallel_group_id_; + for_loop_dec_.config.graph_parallel_region_id = current_graph_parallel_region_id_; for_loop_dec_.config.graph_do_while_level_id = current_graph_do_while_level_id_; for_loop_dec_.config.checkpoint_id = current_checkpoint_id_; auto stmt_unique = std::make_unique(i, s, e, arch_, for_loop_dec_.config, dbg_info); @@ -1527,6 +1530,7 @@ void ASTBuilder::begin_frontend_struct_for_on_snode(const ExprGroup &loop_vars, "qd.loop_config(serialize=True) does not have effect on the struct for. " "The execution order is not guaranteed."); for_loop_dec_.config.stream_parallel_group_id = current_stream_parallel_group_id_; + for_loop_dec_.config.graph_parallel_region_id = current_graph_parallel_region_id_; for_loop_dec_.config.graph_do_while_level_id = current_graph_do_while_level_id_; for_loop_dec_.config.checkpoint_id = current_checkpoint_id_; auto stmt_unique = std::make_unique(loop_vars, snode, arch_, for_loop_dec_.config, dbg_info); @@ -1544,6 +1548,7 @@ void ASTBuilder::begin_frontend_struct_for_on_external_tensor(const ExprGroup &l "qd.loop_config(serialize=True) does not have effect on the struct for. " "The execution order is not guaranteed."); for_loop_dec_.config.stream_parallel_group_id = current_stream_parallel_group_id_; + for_loop_dec_.config.graph_parallel_region_id = current_graph_parallel_region_id_; for_loop_dec_.config.graph_do_while_level_id = current_graph_do_while_level_id_; for_loop_dec_.config.checkpoint_id = current_checkpoint_id_; auto stmt_unique = @@ -1563,6 +1568,7 @@ void ASTBuilder::begin_frontend_mesh_for(const Expr &i, "qd.loop_config(serialize=True) does not have effect on the mesh for. " "The execution order is not guaranteed."); for_loop_dec_.config.stream_parallel_group_id = current_stream_parallel_group_id_; + for_loop_dec_.config.graph_parallel_region_id = current_graph_parallel_region_id_; for_loop_dec_.config.graph_do_while_level_id = current_graph_do_while_level_id_; for_loop_dec_.config.checkpoint_id = current_checkpoint_id_; auto stmt_unique = diff --git a/quadrants/ir/frontend_ir.h b/quadrants/ir/frontend_ir.h index 39a116ea45..ff18016937 100644 --- a/quadrants/ir/frontend_ir.h +++ b/quadrants/ir/frontend_ir.h @@ -24,6 +24,11 @@ struct ForLoopConfig { int block_dim{0}; bool uniform{false}; int stream_parallel_group_id{0}; + // Per-kernel id of the enclosing `qd.graph_parallel_context()` region (0 outside any region). Assigned by the AST + // builder's `current_graph_parallel_region_id_` at `begin_frontend_*_for` time and threaded alongside + // `stream_parallel_group_id` so the CUDA graph builder can tell two back-to-back regions apart and keep their joins + // (without it, adjacent regions merge into one fork/join and the second can race the first). + int graph_parallel_region_id{0}; int graph_do_while_level_id{-1}; // `cp_id` (see design doc `perso_hugh/doc/qipc/reentrant.md` section 5.1) of the enclosing `qd.checkpoint(...)` block // when this for-loop is emitted, or `-1` when the for-loop is outside any checkpoint. Assigned by the AST builder's @@ -208,6 +213,7 @@ class FrontendForStmt : public Stmt { MemoryAccessOptions mem_access_opt; int block_dim; int stream_parallel_group_id{0}; + int graph_parallel_region_id{0}; int graph_do_while_level_id{-1}; int checkpoint_id{-1}; std::string loop_name; @@ -929,6 +935,7 @@ class ASTBuilder { config.block_dim = 0; config.strictly_serialized = false; config.stream_parallel_group_id = 0; + config.graph_parallel_region_id = 0; config.graph_do_while_level_id = -1; config.checkpoint_id = -1; config.loop_name.clear(); @@ -943,6 +950,12 @@ class ASTBuilder { int id_counter_{0}; int stream_parallel_group_counter_{0}; int current_stream_parallel_group_id_{0}; + // Per-kernel counter handed out by `begin_graph_parallel_context()` (one id per `qd.graph_parallel_context()` + // region), and the innermost active region id (0 outside any region). Reset per kernel via fresh ASTBuilder + // construction. Mirrors `stream_parallel_group_counter_` / `current_stream_parallel_group_id_`, but at region + // (not section) granularity, so for-loops created inside a region carry its id. + int graph_parallel_region_counter_{0}; + int current_graph_parallel_region_id_{0}; // Innermost active graph_do_while level id (-1 if not inside any). The Python AST transformer manages the stack and // calls set_graph_do_while_level_id() on enter/exit; for-loops created while it is >= 0 are tagged with it (mirrors // current_stream_parallel_group_id_). @@ -1094,6 +1107,18 @@ class ASTBuilder { current_stream_parallel_group_id_ = 0; } + // Enter a `qd.graph_parallel_context()` region: hand out a fresh per-kernel region id that every for-loop emitted + // inside the region (across all its `qd.graph_parallel()` sections) carries, so the CUDA graph builder can keep + // adjacent regions' fork/join groups apart. The Python AST transformer calls these on region enter/exit. + void begin_graph_parallel_context() { + QD_ERROR_IF(current_graph_parallel_region_id_ != 0, "qd.graph_parallel_context() regions cannot be nested"); + current_graph_parallel_region_id_ = ++graph_parallel_region_counter_; + } + + void end_graph_parallel_context() { + current_graph_parallel_region_id_ = 0; + } + // Set the innermost active graph_do_while level id. Pass the new level id when entering a graph_do_while loop, and // the parent level id (or -1) when leaving it. The Python AST transformer owns the level stack and the level table. void set_graph_do_while_level_id(int level_id) { diff --git a/quadrants/ir/ir.h b/quadrants/ir/ir.h index fb5b6dc5a9..997399507b 100644 --- a/quadrants/ir/ir.h +++ b/quadrants/ir/ir.h @@ -397,6 +397,12 @@ class StmtFieldManager { struct GraphRegionTag { int graph_do_while_level_id{-1}; int stream_parallel_group_id{0}; + // Per-kernel `qd.graph_parallel_context()` region id (0 outside any region). Only ever set on the PURE-bucket + // `fallback_tag` in `offload.cpp` (so a section's pure bound-compute serial is grouped into the same region as the + // for-loop that consumes it). Source-level serial statements live outside any region, so their stamped `region_tag` + // leaves this at 0. Excluded from operator==/!= for the same reason `is_set` is: the serial bucket only ever + // compares region-0 side-effecting tags, so region id never participates. + int graph_parallel_region_id{0}; // `cp_id` (see `quadrants/lang/checkpoint.py`) of the enclosing `qd.checkpoint(...)` block, or `-1` outside any // checkpoint. Stamped on every frontend statement by `ASTBuilder::insert` (from `current_checkpoint_id_`) so a // SIDE-EFFECTING serial task tagged with a real checkpoint can be gated/skipped on `resume(from_checkpoint=...)` diff --git a/quadrants/ir/statements.cpp b/quadrants/ir/statements.cpp index 3e29a5a8d2..1911f71f70 100644 --- a/quadrants/ir/statements.cpp +++ b/quadrants/ir/statements.cpp @@ -223,6 +223,7 @@ std::unique_ptr RangeForStmt::clone() const { block_dim, strictly_serialized); new_stmt->reversed = reversed; new_stmt->stream_parallel_group_id = stream_parallel_group_id; + new_stmt->graph_parallel_region_id = graph_parallel_region_id; new_stmt->checkpoint_id = checkpoint_id; new_stmt->graph_do_while_level_id = graph_do_while_level_id; new_stmt->loop_name = loop_name; @@ -247,6 +248,7 @@ std::unique_ptr StructForStmt::clone() const { auto new_stmt = std::make_unique(snode, body->clone(), is_bit_vectorized, num_cpu_threads, block_dim); new_stmt->mem_access_opt = mem_access_opt; new_stmt->stream_parallel_group_id = stream_parallel_group_id; + new_stmt->graph_parallel_region_id = graph_parallel_region_id; new_stmt->checkpoint_id = checkpoint_id; new_stmt->graph_do_while_level_id = graph_do_while_level_id; new_stmt->loop_name = loop_name; @@ -410,6 +412,7 @@ std::unique_ptr OffloadedStmt::clone() const { new_stmt->bls_size = bls_size; new_stmt->mem_access_opt = mem_access_opt; new_stmt->stream_parallel_group_id = stream_parallel_group_id; + new_stmt->graph_parallel_region_id = graph_parallel_region_id; new_stmt->checkpoint_id = checkpoint_id; new_stmt->graph_do_while_level_id = graph_do_while_level_id; new_stmt->loop_name = loop_name; diff --git a/quadrants/ir/statements.h b/quadrants/ir/statements.h index 9f722cf261..93b82414df 100644 --- a/quadrants/ir/statements.h +++ b/quadrants/ir/statements.h @@ -977,6 +977,10 @@ class RangeForStmt : public Stmt { bool strictly_serialized; std::string range_hint; int stream_parallel_group_id{0}; + // Per-kernel `qd.graph_parallel_context()` region id (0 outside any region). Propagated from + // `FrontendForStmt::graph_parallel_region_id` by `lower_ast.cpp` and on to `OffloadedStmt` by `offload.cpp`. See + // ForLoopConfig comment in `frontend_ir.h`. + int graph_parallel_region_id{0}; // `cp_id` of the enclosing `qd.checkpoint(...)` block (`-1` outside any checkpoint). Propagated from // `FrontendForStmt::checkpoint_id` by `lower_ast.cpp`, then carried into the post-offload // `OffloadedStmt::checkpoint_id` by `offload.cpp`. See ForLoopConfig comment in `frontend_ir.h` for the full @@ -1015,6 +1019,7 @@ class RangeForStmt : public Stmt { block_dim, strictly_serialized, stream_parallel_group_id, + graph_parallel_region_id, checkpoint_id, graph_do_while_level_id); QD_DEFINE_ACCEPT @@ -1036,6 +1041,8 @@ class StructForStmt : public Stmt { int block_dim; MemoryAccessOptions mem_access_opt; int stream_parallel_group_id{0}; + // See `RangeForStmt::graph_parallel_region_id` -- same lifecycle. + int graph_parallel_region_id{0}; // See `RangeForStmt::checkpoint_id` -- same lifecycle, same `-1` sentinel. int checkpoint_id{-1}; int graph_do_while_level_id{-1}; @@ -1060,6 +1067,7 @@ class StructForStmt : public Stmt { block_dim, mem_access_opt, stream_parallel_group_id, + graph_parallel_region_id, checkpoint_id, graph_do_while_level_id); QD_DEFINE_ACCEPT @@ -1406,6 +1414,10 @@ class OffloadedStmt : public Stmt { std::size_t bls_size{0}; MemoryAccessOptions mem_access_opt; int stream_parallel_group_id{0}; + // Per-kernel `qd.graph_parallel_context()` region id (0 outside any region). Set by `offload.cpp` from the source + // `RangeForStmt` / `StructForStmt`, read by the CUDA LLVM codegen to populate + // `OffloadedTask::graph_parallel_region_id` so the GraphManager keeps adjacent regions' fork/join groups apart. + int graph_parallel_region_id{0}; // `cp_id` of the enclosing `qd.checkpoint(...)` block for this offloaded task (`-1` outside any checkpoint). Set by // `offload.cpp` from the source `RangeForStmt::checkpoint_id` / `StructForStmt::checkpoint_id`. Read by the CUDA / // AMDGPU LLVM codegen to populate `OffloadedTask::checkpoint_id`, which the GraphManager will consume in slice 1c. @@ -1463,6 +1475,7 @@ class OffloadedStmt : public Stmt { index_offsets, mem_access_opt, stream_parallel_group_id, + graph_parallel_region_id, checkpoint_id, graph_do_while_level_id); QD_DEFINE_ACCEPT diff --git a/quadrants/python/export_lang.cpp b/quadrants/python/export_lang.cpp index 83e8b50f79..43bf580ce8 100644 --- a/quadrants/python/export_lang.cpp +++ b/quadrants/python/export_lang.cpp @@ -313,6 +313,8 @@ void export_lang(py::module &m) { .def("reset_snode_access_flag", &ASTBuilder::reset_snode_access_flag) .def("begin_stream_parallel", &ASTBuilder::begin_stream_parallel) .def("end_stream_parallel", &ASTBuilder::end_stream_parallel) + .def("begin_graph_parallel_context", &ASTBuilder::begin_graph_parallel_context) + .def("end_graph_parallel_context", &ASTBuilder::end_graph_parallel_context) .def("set_graph_do_while_level_id", &ASTBuilder::set_graph_do_while_level_id) .def("begin_checkpoint", &ASTBuilder::begin_checkpoint) .def("end_checkpoint", &ASTBuilder::end_checkpoint); diff --git a/quadrants/runtime/cuda/graph_manager.cpp b/quadrants/runtime/cuda/graph_manager.cpp index daa17fdcc9..0ac08a07fd 100644 --- a/quadrants/runtime/cuda/graph_manager.cpp +++ b/quadrants/runtime/cuda/graph_manager.cpp @@ -409,11 +409,17 @@ void GraphManager::build_level(int parent_id, // distinct group id is one qd.graph_parallel section; the qd.graph_parallel sections fork from the // region's entry (`prev_node`), run their tasks in order, and join into a single empty node so // downstream work waits for all of them. CUDA's graph executor schedules the independent - // qd.graph_parallel section chains on separate streams -> real overlap. --- + // qd.graph_parallel section chains on separate streams -> real overlap. + // + // The run is bounded to a single region by graph_parallel_region_id: two qd.graph_parallel_context() regions + // written back-to-back (no serial task between them to break the run) carry distinct region ids, so each builds + // its own fork/join with its own join node. Without this guard the second region's sections would fork from the + // same entry as the first's and could run concurrently with -- and race -- the first region's work. --- if (tasks[cursor].stream_parallel_group_id != 0 && tasks[cursor].checkpoint_id < 0) { + const int region_id = tasks[cursor].graph_parallel_region_id; int run_end = cursor; while (run_end < end && tasks[run_end].graph_do_while_level_id == parent_id && tasks[run_end].checkpoint_id < 0 && - tasks[run_end].stream_parallel_group_id != 0) { + tasks[run_end].stream_parallel_group_id != 0 && tasks[run_end].graph_parallel_region_id == region_id) { run_end++; } // Bucket the run's tasks by qd.graph_parallel section id, preserving first-seen (declaration) order. diff --git a/quadrants/transforms/lower_ast.cpp b/quadrants/transforms/lower_ast.cpp index fe8084c6bc..31a5b29b21 100644 --- a/quadrants/transforms/lower_ast.cpp +++ b/quadrants/transforms/lower_ast.cpp @@ -232,6 +232,7 @@ class LowerAST : public IRVisitor { new_for->loop_name = stmt->loop_name; new_for->index_offsets = offsets; new_for->stream_parallel_group_id = stmt->stream_parallel_group_id; + new_for->graph_parallel_region_id = stmt->graph_parallel_region_id; new_for->checkpoint_id = stmt->checkpoint_id; new_for->graph_do_while_level_id = stmt->graph_do_while_level_id; VecStatement new_statements; @@ -269,6 +270,7 @@ class LowerAST : public IRVisitor { /*range_hint=*/fmt::format("arg ({})", fmt::join(arg_id, ", ")), /*loop_name=*/stmt->loop_name); new_for->stream_parallel_group_id = stmt->stream_parallel_group_id; + new_for->graph_parallel_region_id = stmt->graph_parallel_region_id; new_for->checkpoint_id = stmt->checkpoint_id; new_for->graph_do_while_level_id = stmt->graph_do_while_level_id; VecStatement new_statements; @@ -306,6 +308,7 @@ class LowerAST : public IRVisitor { stmt->strictly_serialized, /*range_hint=*/"", /*loop_name=*/stmt->loop_name); new_for->stream_parallel_group_id = stmt->stream_parallel_group_id; + new_for->graph_parallel_region_id = stmt->graph_parallel_region_id; new_for->checkpoint_id = stmt->checkpoint_id; new_for->graph_do_while_level_id = stmt->graph_do_while_level_id; new_for->body->insert(std::make_unique(new_for.get(), 0), 0); diff --git a/quadrants/transforms/offload.cpp b/quadrants/transforms/offload.cpp index 8baa775f14..36ec81ead3 100644 --- a/quadrants/transforms/offload.cpp +++ b/quadrants/transforms/offload.cpp @@ -109,6 +109,7 @@ class Offloader { const GraphRegionTag tag = bucket_has_side_effect ? bucket_tag : fallback_tag; pending_serial_statements->graph_do_while_level_id = tag.graph_do_while_level_id; pending_serial_statements->stream_parallel_group_id = tag.stream_parallel_group_id; + pending_serial_statements->graph_parallel_region_id = tag.graph_parallel_region_id; pending_serial_statements->checkpoint_id = tag.checkpoint_id; root_block->insert(std::move(pending_serial_statements)); pending_serial_statements = Stmt::make_typed(OffloadedStmt::TaskType::serial, arch, kernel); @@ -156,7 +157,9 @@ class Offloader { auto &stmt = root_statements[i]; // Note that stmt->parent is root_block, which doesn't contain stmt now. if (auto s = stmt->cast(); s && !s->strictly_serialized) { - assemble_serial_statements(GraphRegionTag{s->graph_do_while_level_id, s->stream_parallel_group_id}); + GraphRegionTag pre_for_tag{s->graph_do_while_level_id, s->stream_parallel_group_id}; + pre_for_tag.graph_parallel_region_id = s->graph_parallel_region_id; + assemble_serial_statements(pre_for_tag); auto offloaded = Stmt::make_typed(OffloadedStmt::TaskType::range_for, arch, kernel); // offloaded->body is an empty block now. offloaded->grid_dim = config.saturating_grid_dim; @@ -192,12 +195,15 @@ class Offloader { } offloaded->range_hint = s->range_hint; offloaded->stream_parallel_group_id = s->stream_parallel_group_id; + offloaded->graph_parallel_region_id = s->graph_parallel_region_id; offloaded->graph_do_while_level_id = s->graph_do_while_level_id; offloaded->checkpoint_id = s->checkpoint_id; offloaded->loop_name = s->loop_name; root_block->insert(std::move(offloaded)); } else if (auto st = stmt->cast()) { - assemble_serial_statements(GraphRegionTag{st->graph_do_while_level_id, st->stream_parallel_group_id}); + GraphRegionTag pre_for_tag{st->graph_do_while_level_id, st->stream_parallel_group_id}; + pre_for_tag.graph_parallel_region_id = st->graph_parallel_region_id; + assemble_serial_statements(pre_for_tag); emit_struct_for(st, root_block, config, st->mem_access_opt); } else if (auto st = stmt->cast()) { assemble_serial_statements(GraphRegionTag{st->graph_do_while_level_id, /*group=*/0}); @@ -309,6 +315,7 @@ class Offloader { offloaded_struct_for->num_cpu_threads = std::min(for_stmt->num_cpu_threads, config.cpu_max_num_threads); offloaded_struct_for->mem_access_opt = mem_access_opt; offloaded_struct_for->stream_parallel_group_id = for_stmt->stream_parallel_group_id; + offloaded_struct_for->graph_parallel_region_id = for_stmt->graph_parallel_region_id; offloaded_struct_for->graph_do_while_level_id = for_stmt->graph_do_while_level_id; offloaded_struct_for->checkpoint_id = for_stmt->checkpoint_id; offloaded_struct_for->loop_name = for_stmt->loop_name; diff --git a/tests/python/test_graph_parallel.py b/tests/python/test_graph_parallel.py index ad7eb33fad..b6f0e780d7 100644 --- a/tests/python/test_graph_parallel.py +++ b/tests/python/test_graph_parallel.py @@ -100,6 +100,64 @@ def k( np.testing.assert_allclose(z.to_numpy(), 3.0) +@test_utils.test() +def test_graph_parallel_back_to_back_regions_keep_join(): + """Two qd.graph_parallel_context() regions written back-to-back (no serial work between them) must each get + their own fork/join. Region B reads what region A wrote, so if the two regions were merged into one fork/join + (dropping A's join) B's sections would fork alongside -- and race -- A's. On CUDA we assert the graph has one + empty join node per region (two total); on every backend the post-join values must be correct. Regression test + for the back-to-back-region merge bug. + + Each region's two sections stay mutually disjoint (one touches only x, the other only y); the only data + dependency is across the region boundary (B's x-section reads A's x-section output, B's y-section reads A's + y-section output), which is exactly the edge the join protects.""" + n = 1024 + + @qd.kernel(graph=True) + def k(x: qd.types.ndarray(qd.f32, ndim=1), y: qd.types.ndarray(qd.f32, ndim=1)): + # Region A: write x and y in two independent sections. + with qd.graph_parallel_context(): + with qd.graph_parallel(): + for i in range(x.shape[0]): + x[i] = 1.0 + with qd.graph_parallel(): + for i in range(y.shape[0]): + y[i] = 2.0 + # Region B, immediately after A (no serial statement between the regions). Each section reads the region-A + # section that wrote the same array, so B must wait for A's join. B's sections remain disjoint from each + # other (x-only vs y-only). + with qd.graph_parallel_context(): + with qd.graph_parallel(): + for i in range(x.shape[0]): + x[i] = x[i] * 10.0 + with qd.graph_parallel(): + for i in range(y.shape[0]): + y[i] = y[i] * 10.0 + + x = qd.ndarray(qd.f32, shape=(n,)) + y = qd.ndarray(qd.f32, shape=(n,)) + x.from_numpy(np.zeros(n, dtype=np.float32)) + y.from_numpy(np.zeros(n, dtype=np.float32)) + + k(x, y) + + if _on_cuda(): + # One empty join node per region (two regions -> two joins). The merge bug this guards would build a single + # fork/join across both regions, emitting only one join, i.e. num_tasks + 1. + assert _graph_num_nodes() == _num_offloaded_tasks() + 2 + + # Correct only if region A fully joined before region B ran: x = 1 * 10, y = 2 * 10. + np.testing.assert_allclose(x.to_numpy(), 10.0) + np.testing.assert_allclose(y.to_numpy(), 20.0) + + # Relaunch: same cached graph, same result. + x.from_numpy(np.zeros(n, dtype=np.float32)) + y.from_numpy(np.zeros(n, dtype=np.float32)) + k(x, y) + np.testing.assert_allclose(x.to_numpy(), 10.0) + np.testing.assert_allclose(y.to_numpy(), 20.0) + + @test_utils.test() def test_graph_parallel_three_sections(): """Fan-out of three independent qd.graph_parallel sections; one empty join node.""" From 66efb1c4975cfac21d98e13e46801a955f2b3300 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Wed, 24 Jun 2026 07:44:10 -0700 Subject: [PATCH 24/25] refactor(graph): extract graph_parallel context managers out of lang/misc.py Moves the `graph_parallel_context()` / `graph_parallel()` no-op context-manager stubs into their own `quadrants/lang/graph_parallel.py`, mirroring `checkpoint.py`, to keep the 958-line catch-all `lang/misc.py` (imported by 33+ modules) from growing further. Re-exported via `misc.py` so the canonical `qd.graph_parallel_context` / `qd.graph_parallel` import paths are unchanged. Addresses the "Check feature factorization" CI finding on PR #756. --- python/quadrants/lang/graph_parallel.py | 63 +++++++++++++++++++++++++ python/quadrants/lang/misc.py | 52 +------------------- 2 files changed, 64 insertions(+), 51 deletions(-) create mode 100644 python/quadrants/lang/graph_parallel.py diff --git a/python/quadrants/lang/graph_parallel.py b/python/quadrants/lang/graph_parallel.py new file mode 100644 index 0000000000..58b2bf5447 --- /dev/null +++ b/python/quadrants/lang/graph_parallel.py @@ -0,0 +1,63 @@ +"""User-facing ``qd.graph_parallel_context`` / ``qd.graph_parallel`` context-managers and their no-op Python-runtime +stubs. + +Kept in its own module to keep ``lang/misc.py`` from growing further (mirrors ``checkpoint.py``) -- the AST transformer +and the C++ runtime do all the actual implementation work; this file is just the public API entry point. + +Re-exported via ``qd.lang.misc`` (and therefore as ``qd.graph_parallel_context`` / ``qd.graph_parallel``) for the +user-facing canonical import path. +""" + +from __future__ import annotations + +from contextlib import contextmanager + + +@contextmanager +def graph_parallel_context(): + """Opens a fork/join region whose ``qd.graph_parallel()`` sections run concurrently. + + Used as ``with qd.graph_parallel_context():`` inside a ``@qd.kernel(graph=True)`` kernel. The region's + body must contain only ``with qd.graph_parallel():`` blocks. Each ``qd.graph_parallel`` section is an + independent sequence of work; the ``qd.graph_parallel`` sections have no ordering relative to each + other and may execute concurrently, while everything after the region waits for *all* ``qd.graph_parallel`` + sections to finish (the join). This is the graph analogue of ``qd.stream_parallel()`` (which is for + non-graph kernels): it lets independent stages -- e.g. qipc's point-triangle and edge-edge assembly + -- overlap inside a captured graph. + + Concurrency contract (the author's responsibility): ``qd.graph_parallel`` sections must be data-race + free with respect to one another (no ``qd.graph_parallel`` section reads what another writes, no two + ``qd.graph_parallel`` sections write the same location). Calls *within* a ``qd.graph_parallel`` section + keep their program order. + + Backend behavior: + - CUDA SM graph path: ``qd.graph_parallel`` sections become independent graph chains joined by an + empty node, so the runtime schedules them on parallel streams (real overlap). + - CPU / Vulkan / Metal / AMDGPU graph: correct results, ``qd.graph_parallel`` sections run serially + (the concurrency tags are honored only by the graph builder today). + + Restrictions (enforced at kernel compile time): + - Must be used inside ``@qd.kernel(graph=True)``. + - The region body may contain only ``with qd.graph_parallel():`` blocks. + - Regions cannot be nested, and a ``qd.graph_parallel`` section body must be straight-line task work + (no nested ``qd.graph_do_while``, ``qd.checkpoint``, or ``qd.graph_parallel_context``). + + This function should not be called directly at runtime; it is recognized and transformed during AST compilation. + At Python runtime (outside kernels) it is a no-op context manager. + + See also ``docs/source/user_guide/graph.md``. + """ + yield + + +@contextmanager +def graph_parallel(): + """Declares one ``qd.graph_parallel`` section of an enclosing ``qd.graph_parallel_context()`` region. + + Used as ``with qd.graph_parallel():`` directly inside a ``with qd.graph_parallel_context():`` block. + The ``qd.graph_parallel`` section's body is an independent sequence of work that may run concurrently + with the region's other ``qd.graph_parallel`` sections. + + See ``qd.graph_parallel_context()`` for the full contract and backend behavior. + """ + yield diff --git a/python/quadrants/lang/misc.py b/python/quadrants/lang/misc.py index fbc542179a..1353ec70d2 100644 --- a/python/quadrants/lang/misc.py +++ b/python/quadrants/lang/misc.py @@ -3,7 +3,6 @@ import shutil import tempfile import warnings -from contextlib import contextmanager from copy import deepcopy as _deepcopy from quadrants import _logging, _snode @@ -13,6 +12,7 @@ from quadrants.lang import impl, util from quadrants.lang.checkpoint import checkpoint from quadrants.lang.expr import Expr +from quadrants.lang.graph_parallel import graph_parallel, graph_parallel_context from quadrants.lang.graph_status import GraphStatus from quadrants.lang.impl import axes, get_runtime from quadrants.profiler.kernel_profiler import get_default_kernel_profiler @@ -752,56 +752,6 @@ def graph_do_while(condition) -> bool: return bool(condition) -@contextmanager -def graph_parallel_context(): - """Opens a fork/join region whose ``qd.graph_parallel()`` sections run concurrently. - - Used as ``with qd.graph_parallel_context():`` inside a ``@qd.kernel(graph=True)`` kernel. The region's - body must contain only ``with qd.graph_parallel():`` blocks. Each ``qd.graph_parallel`` section is an - independent sequence of work; the ``qd.graph_parallel`` sections have no ordering relative to each - other and may execute concurrently, while everything after the region waits for *all* ``qd.graph_parallel`` - sections to finish (the join). This is the graph analogue of ``qd.stream_parallel()`` (which is for - non-graph kernels): it lets independent stages -- e.g. qipc's point-triangle and edge-edge assembly - -- overlap inside a captured graph. - - Concurrency contract (the author's responsibility): ``qd.graph_parallel`` sections must be data-race - free with respect to one another (no ``qd.graph_parallel`` section reads what another writes, no two - ``qd.graph_parallel`` sections write the same location). Calls *within* a ``qd.graph_parallel`` section - keep their program order. - - Backend behavior: - - CUDA SM graph path: ``qd.graph_parallel`` sections become independent graph chains joined by an - empty node, so the runtime schedules them on parallel streams (real overlap). - - CPU / Vulkan / Metal / AMDGPU graph: correct results, ``qd.graph_parallel`` sections run serially - (the concurrency tags are honored only by the graph builder today). - - Restrictions (enforced at kernel compile time): - - Must be used inside ``@qd.kernel(graph=True)``. - - The region body may contain only ``with qd.graph_parallel():`` blocks. - - Regions cannot be nested, and a ``qd.graph_parallel`` section body must be straight-line task work - (no nested ``qd.graph_do_while``, ``qd.checkpoint``, or ``qd.graph_parallel_context``). - - This function should not be called directly at runtime; it is recognized and transformed during AST compilation. - At Python runtime (outside kernels) it is a no-op context manager. - - See also ``docs/source/user_guide/graph.md``. - """ - yield - - -@contextmanager -def graph_parallel(): - """Declares one ``qd.graph_parallel`` section of an enclosing ``qd.graph_parallel_context()`` region. - - Used as ``with qd.graph_parallel():`` directly inside a ``with qd.graph_parallel_context():`` block. - The ``qd.graph_parallel`` section's body is an independent sequence of work that may run concurrently - with the region's other ``qd.graph_parallel`` sections. - - See ``qd.graph_parallel_context()`` for the full contract and backend behavior. - """ - yield - - def global_thread_idx(): """Returns the global thread id of this running thread, only available for cpu and cuda backends. From 29a2c3dc490c2ba4596af023cdf1c8e047ee02b5 Mon Sep 17 00:00:00 2001 From: Hugh Perkins Date: Thu, 25 Jun 2026 05:31:03 -0700 Subject: [PATCH 25/25] style(graph): reflow under-wrapped C++ comments to 120c Repacks the graph_parallel_context fork/join block comment in graph_manager.cpp (was wrapped at ~93-99c) and the OffloadedStmt region-id comment in statements.h so comment lines fill toward the project's 120c width instead of the ~80c default. Addresses the "Check line wrapping" CI on PR #756. --- quadrants/ir/statements.h | 4 ++-- quadrants/runtime/cuda/graph_manager.cpp | 11 +++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/quadrants/ir/statements.h b/quadrants/ir/statements.h index 93b82414df..9ad6b688c3 100644 --- a/quadrants/ir/statements.h +++ b/quadrants/ir/statements.h @@ -1415,8 +1415,8 @@ class OffloadedStmt : public Stmt { MemoryAccessOptions mem_access_opt; int stream_parallel_group_id{0}; // Per-kernel `qd.graph_parallel_context()` region id (0 outside any region). Set by `offload.cpp` from the source - // `RangeForStmt` / `StructForStmt`, read by the CUDA LLVM codegen to populate - // `OffloadedTask::graph_parallel_region_id` so the GraphManager keeps adjacent regions' fork/join groups apart. + // `RangeForStmt` / `StructForStmt`. The CUDA LLVM codegen copies it into `OffloadedTask::graph_parallel_region_id` + // so the GraphManager keeps adjacent regions' fork/join groups apart. int graph_parallel_region_id{0}; // `cp_id` of the enclosing `qd.checkpoint(...)` block for this offloaded task (`-1` outside any checkpoint). Set by // `offload.cpp` from the source `RangeForStmt::checkpoint_id` / `StructForStmt::checkpoint_id`. Read by the CUDA / diff --git a/quadrants/runtime/cuda/graph_manager.cpp b/quadrants/runtime/cuda/graph_manager.cpp index 0ac08a07fd..ca3e196b4c 100644 --- a/quadrants/runtime/cuda/graph_manager.cpp +++ b/quadrants/runtime/cuda/graph_manager.cpp @@ -404,12 +404,11 @@ void GraphManager::build_level(int parent_id, continue; } - // --- A qd.graph_parallel_context() fork/join region: a contiguous run of this level's direct, - // non-checkpoint tasks tagged with a nonzero stream_parallel_group_id (set by qd.graph_parallel()). Each - // distinct group id is one qd.graph_parallel section; the qd.graph_parallel sections fork from the - // region's entry (`prev_node`), run their tasks in order, and join into a single empty node so - // downstream work waits for all of them. CUDA's graph executor schedules the independent - // qd.graph_parallel section chains on separate streams -> real overlap. + // --- A qd.graph_parallel_context() fork/join region: a contiguous run of this level's direct, non-checkpoint + // tasks tagged with a nonzero stream_parallel_group_id (set by qd.graph_parallel()). Each distinct group id is one + // qd.graph_parallel section; the qd.graph_parallel sections fork from the region's entry (`prev_node`), run their + // tasks in order, and join into a single empty node so downstream work waits for all of them. CUDA's graph + // executor schedules the independent qd.graph_parallel section chains on separate streams -> real overlap. // // The run is bounded to a single region by graph_parallel_region_id: two qd.graph_parallel_context() regions // written back-to-back (no serial task between them to break the run) carry distinct region ids, so each builds