diff --git a/docs/source/user_guide/graph.md b/docs/source/user_guide/graph.md index cfdcd3bb21..d725cbefea 100644 --- a/docs/source/user_guide/graph.md +++ b/docs/source/user_guide/graph.md @@ -125,6 +125,30 @@ def converge(x: qd.types.ndarray(qd.f32, ndim=1), keep_going[()] = 0 ``` +### Each top-level for-loop in the body is its own offloaded launch (grid-wide barriers survive) + +A `graph=True` kernel compiles **each top-level `for`-loop into its own offloaded GPU launch** (its own graph node), and consecutive top-level loops are separated by an **implicit grid-wide barrier** — loop *k* finishes across the entire grid before loop *k+1* starts. This is what makes the "many top-level loops" form in [Basic usage](#basic-usage) work. + +**This behavior is fully preserved inside a `graph_do_while` body.** The loop body is not a single fused launch — every top-level `for`-loop *inside* the `while qd.graph_do_while(...):` block is still its own offloaded node, with grid-wide barriers between consecutive loops, on every iteration. So you can place a **multi-phase, multi-launch algorithm that depends on grid-wide synchronization between phases directly inside `graph_do_while`** — for example a device-wide radix sort (per-digit *histogram → scan → scatter*, where each phase must complete across the whole grid before the next reads its output), or any solver phase chain (assemble → precondition → SpMV → reduce). Each phase sees the previous phase's writes from every block. + +```python +@qd.kernel(graph=True) +def solve(..., cond: qd.types.ndarray(qd.i32, ndim=0)): + while qd.graph_do_while(cond): + for i in range(n): # offload A (whole grid finishes) ... + histogram[...] = ... + for i in range(n): # ... then offload B sees all of A's writes + scanned[...] = ... + for i in range(n): # ... then offload C sees all of B's writes + out[...] = ... + for _ in range(1): # update the loop condition + cond[()] = ... +``` + +**What does break the grid-wide barrier:** nesting a `for`-loop inside *ordinary* runtime control flow — another `for`, an `if`, or a plain Python `while` — **demotes it from top-level position**, so it no longer becomes its own offloaded launch. Instead it runs as device code *within the enclosing launch*, and the grid-wide barrier between it and its siblings is lost (other blocks may not have produced their data yet). `graph_do_while` is **not** "ordinary runtime control flow" in this sense — it is precisely the construct designed to host a sequence of top-level offloaded loops, so loops directly in its body keep their barriers. Compile-time `qd.static(range(...))` loops are also fine: they unroll flat at compile time and keep their bodies at top-level position. + +> Rule of thumb: a `for`-loop must be **directly** at the top level of the kernel body — or **directly** inside a `graph_do_while` body — to become its own offloaded launch with grid-wide synchronization. Wrapping it in a runtime `for`/`if`/`while` collapses it into the enclosing launch. + ### Do-while semantics `graph_do_while` has **do-while** semantics: the kernel body always executes at least once before the condition is checked. This matches the behavior of CUDA conditional while nodes. The flag value must be >= 1 at launch time. Passing 0 with a kernel that decrements the counter will cause an infinite loop.