Open
28 commits
1954d03
feat(graph): add qd.graph_parallel/qd.branch concurrent branches in g…
hughperkins Jun 9, 2026
616658c
test(graph): add qd.graph_parallel tests + docs; allow if-static opti…
hughperkins Jun 9, 2026
8d2f374
fix(graph): tag a loop's bound-compute serial task with its branch gr…
hughperkins Jun 9, 2026
d00e736
fix(graph): import contextmanager in misc.py
hughperkins Jun 23, 2026
a95d2b7
docs(graph): drop stream_parallel lowering note from graph_parallel b…
hughperkins Jun 23, 2026
70cb697
refactor(graph): rename qd.graph_parallel -> qd.graph_parallel_contex…
hughperkins Jun 23, 2026
706fd60
docs/graph: drop name= param from qd.graph_parallel; tighten graph_pa…
hughperkins Jun 23, 2026
89f8f26
docs/graph: simplify graph_parallel section (sections not branches, t…
hughperkins Jun 23, 2026
c57e452
refactor(graph): use 'parallel section' for graph_parallel internals …
hughperkins Jun 23, 2026
4fdf6e4
test(graph): use 'parallel section' in graph_parallel tests
hughperkins Jun 23, 2026
c334f8b
test(graph): clearer name for graph_parallel_context body-validation …
hughperkins Jun 23, 2026
6592249
test(graph): construct-based name for graph_parallel-outside-context …
hughperkins Jun 23, 2026
a7988b3
docs/graph: drop CUDA conditional while nodes note from do-while sema…
hughperkins Jun 23, 2026
509adfd
docs/graph: trim intro of graph_parallel_context section
hughperkins Jun 23, 2026
8a110e0
docs/graph: use American spellings (behavior/honored/recognized)
hughperkins Jun 23, 2026
4181741
docs/graph: rename term 'parallel section' -> 'qd.graph_parallel sect…
hughperkins Jun 23, 2026
478ec70
style: apply pre-commit (black, clang-format) formatting
hughperkins Jun 23, 2026
16dfba6
Merge branch 'main' into hp/graph-parallel-main
hughperkins Jun 23, 2026
7b4a0b6
feat(graph): generate qd.graph_parallel sections from a qd.static for…
hughperkins Jun 23, 2026
94a9b5b
refactor(graph): extract graph_parallel transformer + doc/wrap CI fixes
hughperkins Jun 23, 2026
c533755
fix(graph): add qd.graph_parallel{,_context} to test_api expected sur…
hughperkins Jun 23, 2026
9f35526
fix up doc ci flagged issues
hughperkins Jun 24, 2026
0a7cb98
Merge branch 'main' into hp/graph-parallel-main
hughperkins Jun 24, 2026
4c80eb0
address doc CI issues
hughperkins Jun 24, 2026
0ba1929
fix(graph): keep joins between back-to-back qd.graph_parallel_context…
hughperkins Jun 24, 2026
66efb1c
refactor(graph): extract graph_parallel context managers out of lang/…
hughperkins Jun 24, 2026
29a2c3d
style(graph): reflow under-wrapped C++ comments to 120c
hughperkins Jun 25, 2026
b304f57
Merge branch 'main' into hp/graph-parallel-main
hughperkins Jun 27, 2026
89 changes: 76 additions & 13 deletions docs/source/user_guide/graph.md
@@ -4,15 +4,16 @@ Graphs reduce kernel launch overhead by capturing a sequence of GPU operations i

## Backend support

`graph=True` and `graph_do_while` run on every backend. They are *hardware accelerated* on CUDA (via CUDA graphs) and AMDGPU (via HIP graphs); `graph_do_while` additionally requires CUDA SM 9.0+ / Hopper for its hardware-accelerated path. On other backends, `graph=True` is silently ignored and the kernel runs via the normal launch path, and `graph_do_while` falls back to a host-side do-while loop that copies the condition value GPU → host each iteration (causing a pipeline stall). `qd.checkpoint` gating runs entirely on the device on every GPU backend; only the CPU backend uses host-side gating.
`graph=True` and `graph_do_while` run on every backend. They are *hardware accelerated* on CUDA (via CUDA graphs) and AMDGPU (via HIP graphs); `graph_do_while` additionally requires [CUDA SM 9.0+](https://developer.nvidia.com/cuda/gpus) for its hardware-accelerated path. On other backends, `graph=True` is silently ignored and the kernel runs via the normal launch path, and `graph_do_while` falls back to a host-side do-while loop. `qd.checkpoint` gating runs entirely on the device on every GPU backend.

| Feature | `qd.cuda` SM 9.0+ | `qd.cuda` < SM 9.0 | `qd.amdgpu` | `qd.metal` | `qd.vulkan` | `qd.cpu` |
| --- | --- | --- | --- | --- | --- | --- |
| `graph=True` | hardware accelerated | hardware accelerated | hardware accelerated | runs (no acceleration) | runs (no acceleration) | runs (no acceleration) |
| `qd.graph_do_while` | hardware accelerated | host fallback | host fallback | host fallback | host fallback | host fallback |
| `qd.checkpoint` | GPU-side | GPU-side | GPU-side | GPU-side | GPU-side | host-side |
| `qd.graph_parallel_context` / `qd.graph_parallel` (sections) | concurrent | concurrent | runs serially | runs serially | runs serially | runs serially |

AMDGPU `graph_do_while` falls back to the host-side loop because HIP does not currently expose conditional / while graph nodes (as of ROCm 7.2).
AMDGPU `graph_do_while` falls back to a host-side loop because HIP does not currently expose conditional / while graph nodes (as of ROCm 7.2).

Nested and sibling `graph_do_while` loops (and mixing `graph_do_while` with top-level `for`-loops) are **experimental** for now — see [Nested loops and mixing with for-loops](#nested-loops-and-mixing-with-for-loops).

@@ -44,7 +45,7 @@ my_kernel(x, y) # first call: builds and caches the graph
my_kernel(x, y) # subsequent calls: replays the cached graph
```

This works the same way on CUDA and AMDGPU. The cache is keyed per (compiled-kernel-specialization, launch-id), so different template instantiations (different field bindings, etc.) get their own cached graph.
This works the same way on CUDA and AMDGPU.

### Restrictions

@@ -122,7 +123,7 @@ def converge(x: qd.types.ndarray(qd.f32, ndim=1),

### Do-while semantics

`graph_do_while` has **do-while** semantics: the kernel body always executes at least once before the condition is checked. This matches the behavior of CUDA conditional while nodes. The flag value must be >= 1 at launch time. Passing 0 with a kernel that decrements the counter will cause an infinite loop.
`graph_do_while` has **do-while** semantics: the kernel body always executes at least once before the condition is checked. The flag value must be >= 1 at launch time. Passing 0 with a kernel that decrements the counter will cause an infinite loop.
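
The do-while behavior can be emulated in plain Python. This is a host-side sketch only, not the Quadrants API: `run_do_while` and `max_iters` are hypothetical names introduced for illustration, and the sketch assumes the loop continues while the flag is nonzero. The cap exists only so that the example terminates.

```python
def run_do_while(counter: int, max_iters: int = 10) -> tuple[int, int]:
    """Emulate do-while: run the body first, then test the flag.

    Returns (iterations_run, final_counter). `max_iters` is a safety cap
    for this sketch only (hypothetical); it is not part of any real API.
    """
    iterations = 0
    while True:
        counter -= 1  # body: decrement the flag
        iterations += 1
        if counter == 0:  # the condition is checked only after the body
            break
        if iterations >= max_iters:
            break  # without the cap this case would keep looping
    return iterations, counter


print(run_do_while(3))  # (3, 0): flag reaches zero after three iterations
print(run_do_while(0))  # (10, -10): the flag skips zero, only the cap stops it
```

Starting at 0, the first decrement moves the flag to -1 before the condition is ever tested, so it never equals zero again. That is why the flag must be at least 1 at launch.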

### ndarray vs field

@@ -161,7 +162,7 @@ Note that `qd.func`'s are inlined, so you can freely factorize these structures

### Restrictions

- The counter ndarray may be swapped between calls: the cached graph reads each counter through an indirection slot that is refreshed on every launch, so passing a different ndarray (or alternating between several) replays the cached graph without rebuilding it.
- The counter ndarray may be swapped between calls. Passing a different ndarray (or alternating between several) replays the cached graph without rebuilding it.

### Caveats

@@ -179,7 +180,7 @@ Therefore on unsupported platforms, you might consider creating a second impleme

## Checkpoints with `qd.checkpoint` *(experimental)*

> **Experimental.** `qd.checkpoint`, `qd.GraphStatus`, and `kernel.resume(from_checkpoint=...)` are experimental APIs. The shape of the public surface (the context-manager signature, the `@qd.kernel(checkpoints=True)` flag, the `GraphStatus` fields, the host-side resume loop, the error messages, and the cross-backend lowering details) may change in any future release without a deprecation cycle.
> **Experimental.** `qd.checkpoint`, `qd.GraphStatus`, and `kernel.resume(from_checkpoint=...)` are experimental APIs. The shape of the public surface (the context-manager signature, the `@qd.kernel(checkpoints=True)` flag, the `GraphStatus` fields, the host-side resume loop, the error messages, and the cross-backend compilation details) may change in any future release without a deprecation cycle.

`qd.checkpoint` lets a graph kernel break partway through, surface a reason to the host, let the host fix things up, and resume from the same location on the next launch. An example use-case is an algorithm implemented as a graph that may need to allocate additional memory partway through, where the operations in the graph are in-place, and therefore cannot be rerun without changing/corrupting the output, and therefore for which simply retrying the whole graph from the start is not an option.

@@ -254,7 +255,7 @@ while status.yielded:

### Restrictions

- Must be used inside `@qd.kernel(graph=True, checkpoints=True)`. Without the flag, `qd.checkpoint(...)` raises `QuadrantsSyntaxError` at compile time with a fix-it pointing at `checkpoints=True`.
- Must be used inside `@qd.kernel(graph=True, checkpoints=True)`. Without the flag, `qd.checkpoint(...)` raises `QuadrantsSyntaxError` at compile time.
- `cp_id` must be an int literal or an `IntEnum` value, and must be unique across the kernel.
- `yield_on=` must be a kernel parameter that is a 0-d `qd.types.ndarray(qd.i32, ndim=0)`; expressions are not supported.
- Checkpoints cannot be nested inside other checkpoints. Checkpoints inside a `qd.graph_do_while` body are fine.
@@ -268,7 +269,7 @@ while status.yielded:
arr[i] = arr[i] + 1
```

The restriction is by design: each top-level statement inside a checkpoint becomes its own GPU task / graph node, so silently wrapping bare statements would hide a sequence of N field writes ballooning into N kernel launches. Forcing the user to write the `for`-wrap themselves keeps the lowering visible and gives a single obvious place to fuse multiple writes into one task by sharing a single wrapper.
The restriction is by design: each top-level statement inside a checkpoint becomes its own GPU task / graph node, so silently wrapping bare statements would hide a sequence of N field writes ballooning into N kernel launches. Forcing the user to write the `for`-wrap themselves keeps the mapping to GPU tasks visible and gives a single obvious place to fuse multiple writes into one task by sharing a single wrapper.

## Performance

@@ -393,11 +394,9 @@ k1(a, count)

The recommendation is to use the graph do while here anyway, if you need it for any platform, in order to ensure the code is compact and maintainable.

If you do want fixed-size for loops to run optimally on unsupported hardware platforms, we could add a specializd `qd.graph_range_for` function. This would:
- on graph-do-while-supported hardware: handle adding the additional increment kernel
- on graph-do-while-unsupported hardware: handle running the loop entirely on the host-side, to avoid adding a gpu pipeline stall
If you need fixed-size for loops to run optimally on hardware without graph-do-while support, consider opening a PR for a dedicated helper that picks the host-side fallback automatically on those backends.

In practice, for our own kernels, i.e. in genesis-world, they largely fall under the do while formulation, see the previous section. However, also have some that used to be do while, but have been migrated to an optimized fixed-size, see next section.
In practice, most such loops fall under the do while formulation (see the previous section). Some that were originally do while have since been migrated to an optimized fixed-size form (see the next section).

### A while loop, conditional on a device-side scalar tensor, that has been optimized into a fixed-size for loop

@@ -465,4 +464,68 @@ In this case, our recommendation is:
- use graph do while anyway, if you need it on any platform
- this will ensure your code is compact and maintainable
- if you need optimum 100% performance on unsupported platforms, then consider PRing onto quadrants an optimized graph implementation for your target platform
- for example it could somehow run MAX_ITER iterations anyway, similar to the earlier hand-rolled version, but via the graph abstraction, hence allowing the code to be compact, cross-platform, and also optimally fast

## `qd.graph_parallel` sections with `qd.graph_parallel_context` *(experimental)*

A `with qd.graph_parallel_context():` region lets you declare independent stages so the graph runs them concurrently.

`qd.graph_parallel_context` is honored by the graph builder so it composes with `graph=True` and `graph_do_while`.

```python
@qd.kernel(graph=True)
def step(...):
    while qd.graph_do_while(ncond):
        assemble_shared(...)  # serial: feeds both `qd.graph_parallel` sections

        with qd.graph_parallel_context():  # fork: `qd.graph_parallel` sections run concurrently
            with qd.graph_parallel():  # point-triangle contacts
                pt_assemble(...)
                pt_hessian(...)
            with qd.graph_parallel():  # edge-edge contacts (independent of pt)
                ee_assemble(...)
                ee_hessian(...)
        # join: everything below waits for BOTH `qd.graph_parallel` sections to finish
        merge_hessians(...)
        precondition(...)
```

### Semantics

- **Fork / join.** Every `qd.graph_parallel` section in the region forks from the work that precedes the region. All `qd.graph_parallel` sections must finish before any work *after* the region begins (the join).
- **`qd.graph_parallel` sections are independent — you guarantee it.** Calls *within* a `qd.graph_parallel` section keep their program order, but calls in *different* `qd.graph_parallel` sections have no ordering. The `qd.graph_parallel` sections must be data-race free with respect to one another: no `qd.graph_parallel` section may read what another writes, and no two `qd.graph_parallel` sections may write the same memory. Quadrants does not check this; getting it wrong gives nondeterministic results.
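
The fork / join shape described above can be pictured with a host-side analogy in plain Python. This is only a sketch using `concurrent.futures` threads, not how Quadrants schedules graph nodes; `run_region`, `pt_section`, and `ee_section` are hypothetical names chosen for illustration.

```python
from concurrent.futures import ThreadPoolExecutor


def run_region(shared: list[float]) -> dict[str, float]:
    """Fork two independent sections, join, then run the dependent work."""
    out: dict[str, float] = {}

    def pt_section() -> None:
        # Calls inside one section keep their program order.
        out["pt"] = sum(shared)
        out["pt"] *= 2.0

    def ee_section() -> None:
        # Writes a different key, so it never races with pt_section.
        out["ee"] = max(shared)

    with ThreadPoolExecutor(max_workers=2) as pool:  # fork
        futures = [pool.submit(pt_section), pool.submit(ee_section)]
        for future in futures:
            future.result()  # join: wait for both sections, re-raise errors

    out["merged"] = out["pt"] + out["ee"]  # runs only after the join
    return out


print(run_region([1.0, 2.0, 3.0]))
```

Both sections only read `shared` and write disjoint keys, which mirrors the data-race-free requirement. If they wrote the same key, the result would depend on thread timing.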

### Generating `qd.graph_parallel` sections from a compile-time sequence

`qd.graph_parallel` sections do not have to be written out one by one. A `for ... in qd.static(...)` loop is unrolled at compile time, so each iteration that contains a `with qd.graph_parallel():` becomes its own section — handy for forking one section per element of a static list (e.g. per contact type):

(Note: See [compound_types.md](compound_types.md) for qd.data_oriented description)
```python
@qd.data_oriented
class Solver:
    def __init__(self):
        self.funcs = [self._assemble_pt, self._assemble_ee]  # static list of @qd.func members

    @qd.kernel(graph=True)
    def step(self):
        with qd.graph_parallel_context():
            for i in qd.static(range(len(self.funcs))):  # unrolls to one section per func
                with qd.graph_parallel():
                    self.funcs[i]()
```

The loop **must** be a `qd.static(...)` loop (its trip count is known at compile time). A plain runtime `for i in range(n):` is rejected — a runtime loop cannot be unrolled into independent sections.

### Restrictions (enforced at kernel compile time)

- `qd.graph_parallel_context` may contain only `with qd.graph_parallel():` blocks, optionally wrapped in `if qd.static(...)` (so an optional `qd.graph_parallel` section can be compiled in or out — e.g. enabling edge-edge contacts only when a feature flag is set) or `for ... in qd.static(...)` loops (generate one `qd.graph_parallel` section per element of a compile-time sequence).
- `qd.graph_parallel()` may appear only directly inside a `qd.graph_parallel_context()`.
- `qd.graph_parallel_context` cannot be nested, and a `qd.graph_parallel` section body must be straight-line task work — no `qd.graph_do_while`, `qd.checkpoint`, or nested `qd.graph_parallel_context` inside a `qd.graph_parallel` section (a `qd.graph_parallel_context` may, however, sit inside a `qd.graph_do_while` body, as shown above).
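
The first restriction can be approximated with a purely syntactic walk over Python's standard `ast` module. The sketch below is a simplified stand-in, not the real transformer: `_is_with_call` and `context_body_is_valid` are hypothetical helpers, and the check does not verify that an `if` or `for` is actually wrapped in `qd.static(...)`.

```python
import ast


def _is_with_call(node: ast.AST, name: str) -> bool:
    """True for `with <name>():` or `with <anything>.<name>():` with a single item."""
    if not isinstance(node, ast.With) or len(node.items) != 1:
        return False
    call = node.items[0].context_expr
    if not isinstance(call, ast.Call):
        return False
    func = call.func
    if isinstance(func, ast.Attribute):
        return func.attr == name
    return isinstance(func, ast.Name) and func.id == name


def context_body_is_valid(source: str) -> bool:
    """True if every graph_parallel_context body holds only sections, `if`, or `for`."""
    for node in ast.walk(ast.parse(source)):
        if not _is_with_call(node, "graph_parallel_context"):
            continue
        pending = list(node.body)
        while pending:
            member = pending.pop()
            if _is_with_call(member, "graph_parallel"):
                continue  # a section; its body is not inspected in this sketch
            if isinstance(member, ast.If):
                pending.extend(member.body + member.orelse)
            elif isinstance(member, ast.For):
                pending.extend(member.body)
            else:
                return False  # a bare statement directly inside the region
    return True


GOOD = """
with qd.graph_parallel_context():
    with qd.graph_parallel():
        a()
    if qd.static(flag):
        with qd.graph_parallel():
            b()
"""

BAD = """
with qd.graph_parallel_context():
    a()
"""

print(context_body_is_valid(GOOD), context_body_is_valid(BAD))  # True False
```

Because the source is only parsed and never executed, `qd`, `a`, `b`, and `flag` do not need to exist for the check to run.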

### Backend behavior

| backend | scheduling |
| --- | --- |
| CUDA | `qd.graph_parallel` sections run **concurrently** |
| AMDGPU / CPU / Vulkan / Metal | `qd.graph_parallel` sections run **serially** |

Because `qd.graph_parallel` sections are independent by construction, running them serially produces identical results — only the scheduling differs.
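
The claim that scheduling does not affect results can be checked on a toy model. The following plain-Python sketch (hypothetical names, no Quadrants calls) runs two sections that write disjoint keys in every possible order and compares the final states.

```python
from itertools import permutations


def pt_section(state: dict) -> None:
    state["pt"] = state["shared"] * 2  # reads shared input, writes only "pt"


def ee_section(state: dict) -> None:
    state["ee"] = state["shared"] + 1  # reads shared input, writes only "ee"


def run_in_order(sections) -> dict:
    """Run the sections serially, in the given order, on a fresh state."""
    state = {"shared": 5}
    for section in sections:
        section(state)
    return state


results = [run_in_order(order) for order in permutations([pt_section, ee_section])]
print(all(result == results[0] for result in results))  # True
```

If one section read a key that the other writes, the orders would no longer agree. That is the situation the independence guarantee rules out.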
4 changes: 3 additions & 1 deletion docs/source/user_guide/streams.md
@@ -48,6 +48,8 @@ combine() # runs after compute_ab() returns — a[] and b[] are ready

Consecutive `with qd.stream_parallel():` blocks run concurrently. Multiple for loops within a single block share a stream and run serially on it. All streams are synchronized before the kernel returns.

> **For `graph=True` kernels**, use [`qd.graph_parallel_context` / `qd.graph_parallel`](graph.md#qdgraph_parallel-sections-with-qdgraph_parallel_context-experimental) instead — `stream_parallel` is not compatible with graphs (see [Limitations](#limitations)). `qd.graph_parallel_context` expresses the same "run these independent sequences concurrently" idea but is honored by the graph builder.

### Restrictions

- All top-level statements in a kernel must be either all `stream_parallel` blocks or all regular statements. Mixing the two at the top level is a compile-time error.
@@ -134,4 +136,4 @@ with qd.create_stream() as s:
## Limitations

- **Not compatible with graphs.** Do not pass `qd_stream` to a kernel decorated with `graph=True` (if you do, a `RuntimeError` will be raised).
- **Not compatible with autodiff.** Do not pass `qd_stream` to a kernel that uses reverse-mode or forward-mode differentiation, or inside a `qd.ad.Tape` context (if you do, a `RuntimeError` will be raised).
- **Not compatible with [autodiff.md](autodiff.md).** Do not pass `qd_stream` to a kernel that uses reverse-mode or forward-mode differentiation, or inside a `qd.ad.Tape` context (if you do, a `RuntimeError` will be raised).
12 changes: 11 additions & 1 deletion python/quadrants/lang/ast/ast_transformer.py
@@ -32,6 +32,9 @@
from quadrants.lang.ast.ast_transformers.function_def_transformer import (
    FunctionDefTransformer,
)
from quadrants.lang.ast.ast_transformers.graph_parallel_transformer import (
    GraphParallelTransformer,
)
from quadrants.lang.exception import (
    QuadrantsIndexError,
    QuadrantsRuntimeTypeError,
@@ -1615,9 +1618,16 @@ def build_With(ctx: ASTTransformerFuncContext, node: ast.With) -> None:
        if checkpoint_info is not None:
            return ASTTransformer._build_checkpoint_with(ctx, node, checkpoint_info)

        if GraphParallelTransformer.is_graph_parallel_context_call(item.context_expr):
            return GraphParallelTransformer.build_graph_parallel_context_with(ctx, node, build_stmts)

        if GraphParallelTransformer.is_parallel_section_call(item.context_expr):
            return GraphParallelTransformer.build_parallel_section_with(ctx, node, build_stmts)

        if not FunctionDefTransformer._is_stream_parallel_with(node, ctx.global_vars):
            raise QuadrantsSyntaxError(
                "'with' in Quadrants kernels only supports qd.stream_parallel() or qd.checkpoint()"
                "'with' in Quadrants kernels only supports qd.stream_parallel(), qd.checkpoint(), "
                "qd.graph_parallel_context(), or qd.graph_parallel()"
            )
        if not ctx.is_kernel:
            raise QuadrantsSyntaxError("qd.stream_parallel() can only be used inside @qd.kernel, not @qd.func")
@@ -600,6 +600,38 @@ def _is_loop_config_call(stmt: ast.stmt) -> bool:
return True
return False

    @staticmethod
    def _is_graph_parallel_context_with(stmt: ast.stmt) -> bool:
        """Syntactic check matching GraphParallelTransformer.is_graph_parallel_context_call: a
        ``with qd.graph_parallel_context():`` fork/join region."""
        if not isinstance(stmt, ast.With) or len(stmt.items) != 1:
            return False
        ctx_expr = stmt.items[0].context_expr
        if not isinstance(ctx_expr, ast.Call):
            return False
        func = ctx_expr.func
        if isinstance(func, ast.Attribute) and func.attr == "graph_parallel_context":
            return True
        if isinstance(func, ast.Name) and func.id == "graph_parallel_context":
            return True
        return False

    @staticmethod
    def _is_parallel_section_with(stmt: ast.stmt) -> bool:
        """Syntactic check matching GraphParallelTransformer.is_parallel_section_call: a ``with qd.graph_parallel(...):``
        section of a ``qd.graph_parallel_context()`` region."""
        if not isinstance(stmt, ast.With) or len(stmt.items) != 1:
            return False
        ctx_expr = stmt.items[0].context_expr
        if not isinstance(ctx_expr, ast.Call):
            return False
        func = ctx_expr.func
        if isinstance(func, ast.Attribute) and func.attr == "graph_parallel":
            return True
        if isinstance(func, ast.Name) and func.id == "graph_parallel":
            return True
        return False

@staticmethod
def _validate_graph_do_while_structure(body: list[ast.stmt]) -> None:
"""If a kernel uses qd.graph_do_while() anywhere, enforce the structural rules that remain after
@@ -659,6 +691,27 @@ def _validate_graph_do_while_stmt_list(stmts: list[ast.stmt], is_kernel_top: boo
                # `CheckpointTransformer.build_checkpoint_with`.
                FunctionDefTransformer._validate_graph_do_while_stmt_list(stmt.body, is_kernel_top=is_kernel_top)
                continue
            if FunctionDefTransformer._is_graph_parallel_context_with(stmt):
                # A `with qd.graph_parallel_context()` region groups concurrent `with qd.graph_parallel()`
                # sections; it is a legal sibling of for-loops / checkpoints. Its body must be
                # `qd.graph_parallel` section blocks (optionally under `if qd.static(...)`); the full check
                # is in GraphParallelTransformer.build_graph_parallel_context_with. Each `qd.graph_parallel` section's
                # body is task territory, validated here with the in-loop rules. Descend through `if` members
                # so `qd.graph_parallel` sections inside an optional `if qd.static(...)` are reached too.
                pending = list(stmt.body)
                while pending:
                    member = pending.pop()
                    if FunctionDefTransformer._is_parallel_section_with(member):
                        FunctionDefTransformer._validate_graph_do_while_stmt_list(member.body, is_kernel_top=False)
                    elif isinstance(member, ast.If):
                        pending.extend(member.body)
                        pending.extend(member.orelse)
                    elif isinstance(member, ast.For):
                        # `for ... in qd.static(...)` generates sections; descend so each unrolled section
                        # body is still validated with the in-loop rules (a runtime for here is rejected
                        # earlier by GraphParallelTransformer.build_graph_parallel_context_with).
                        pending.extend(member.body)
                continue
            where = "the kernel body" if is_kernel_top else "a qd.graph_do_while() body"
            raise QuadrantsSyntaxError(
                f"When a kernel uses qd.graph_do_while(), {where} may not contain a {type(stmt).__name__} "