
Fix 2-stage fused_allreduce_rmsnorm memory ordering #2890

Open
hubertlu-tw wants to merge 1 commit into ROCm:main from hubertlu-tw:2_stage_fused_ar_fix

Conversation

@hubertlu-tw
Contributor

Motivation

The 2-stage fused_allreduce_rmsnorm path in csrc/include/custom_all_reduce.cuh can return wrong results on ROCm once per-rank volume crosses roughly 1.2 MB at TP=8. The output matches the unfused all_reduce → bf16 → residual_add → RMSNorm reference for many consecutive calls, then diverges sharply (max-abs error on the order of 10² bf16 units) and stays wrong.

Downstream, SGLang with --enable-aiter-allreduce-fusion shows the same pattern on GptOssForCausalLM 120B (hidden 2880, bf16, TP=8): early GSM8K passes look normal, later back-to-back runs collapse toward 0.000 accuracy; disabling fusion removes the failure.

Root cause is cross-device memory ordering in the ROCm implementation of end_sync for the stage that writes peer-visible IPC tmps[] before a separate kernel reads them. Three issues stack: stage 1 used end_sync<ngpus, true> (relaxed atomics), the acquire load used __MEMORY_SCOPE_DEVICE while the release store used __MEMORY_SCOPE_SYSTEM, and only the polling threads performed the acquire—other threads in the block could still read stale peer writes after __syncthreads().

Modifications

  • csrc/include/custom_all_reduce.cuh
    • reduce_scatter_cross_device_store: call end_sync<ngpus, false> after cross-device writes into each rank’s tmps so peers synchronize with release/acquire instead of relaxed flags. cross_device_reduce_2stage_naive still uses final_sync=true because its consumer is on-device only.
    • end_sync (ROCm path): when final_sync=false, use __MEMORY_SCOPE_SYSTEM on the acquire load to match the system-scoped release store.
    • end_sync: when final_sync=false, issue __scoped_atomic_thread_fence(__ATOMIC_ACQUIRE, __MEMORY_SCOPE_SYSTEM) after __syncthreads() so every thread in the block sees peer-written memory, not only the ngpus threads that polled the flag.
Unified diff (reference)
diff --git a/csrc/include/custom_all_reduce.cuh b/csrc/include/custom_all_reduce.cuh
@@ -217,13 +217,27 @@ DINLINE void end_sync(const RankSignals& sg,
                                 flag,
                                 final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
                                 __MEMORY_SCOPE_SYSTEM);
-        // wait until we got true from all ranks
+        // wait until we got true from all ranks.
+        //
+        // When final_sync=false this barrier also needs to synchronize
+        // peer-GPU writes that preceded the release store. For
+        // release/acquire to actually propagate cross-device writes,
+        // the acquire must match the release's __MEMORY_SCOPE_SYSTEM.
         while(__scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
                                      final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
-                                     __MEMORY_SCOPE_DEVICE) < flag)
+                                     final_sync ? __MEMORY_SCOPE_DEVICE
+                                                : __MEMORY_SCOPE_SYSTEM) < flag)
             ;
     }
     __syncthreads();
+    if(!final_sync)
+    {
+        // Only the ngpus polling threads did the acquire above.
+        // __syncthreads() is a within-block barrier and does not
+        // carry system-scope acquire ordering. Upgrade with an
+        // explicit block-wide system-scope acquire fence so all
+        // threads see peer-written memory after this point.
+        __scoped_atomic_thread_fence(__ATOMIC_ACQUIRE, __MEMORY_SCOPE_SYSTEM);
+    }
@@ -1149,7 +1149,10 @@ __global__ void __launch_bounds__(512, 1) reduce_scatter_cross_device_store(
         tmps[warp_id][rank * part + idx] = rslt;
     }
-    end_sync<ngpus, true>(sg, self_sg, rank);
+    // Stage 2 (local_device_load_rmsnorm*) reads `tmps` on each rank's
+    // own memory, which contains IPC writes from peer ranks' stage-1.
+    // Use final_sync=false so the release/acquire pair synchronizes
+    // those cross-device writes.
+    end_sync<ngpus, false>(sg, self_sg, rank);
 }
  • op_tests/multigpu_tests/test_fused_ar_rms_memory_order.py (new): multi-process regression for TP world size 2, 4, or 8 (--tp).
    • Sweeps shapes across the 1-stage (≤128 KB total) and 2-stage regimes; default 2000 iterations per shape with fresh random inputs.
    • Compares fused output to unfused all_reduce + residual_add + RMSNorm; fails if max-abs error ever exceeds 0.5.
    • --benchmark runs a latency microbenchmark (BENCHMARK_SHAPES, cuda.Event per call, rank-averaged p50/p90/p99) without correctness checks.
    • Exits 1 without the kernel fix on the TP=8 sweep (primary repro); TP=2/4 are smoke and perf coverage.

Accuracy Tests

Primary repro for the historical ordering bug remains TP=8 on the default 13-shape sweep; TP=4 and TP=2 are smoke coverage on the same shapes.

The test script used to reproduce the error without the fix: test_fused_ar_rms_memory_order.py.

TP=8 — regression test (primary repro)
python3 op_tests/multigpu_tests/test_fused_ar_rms_memory_order.py --tp 8 --iters 2000

Result: exit 0; all 13 shapes PASS; worst max_abs_err per shape in {0.0625, 0.1250}; first_bad=-1 everywhere.

TP=4 — regression test (smoke)
python3 op_tests/multigpu_tests/test_fused_ar_rms_memory_order.py --tp 4 --iters 2000

Result: exit 0; all 13 shapes PASS; worst max_abs_err per shape in {0.0625, 0.1250}.

TP=2 — regression test (smoke)
python3 op_tests/multigpu_tests/test_fused_ar_rms_memory_order.py --tp 2 --iters 2000

Result: exit 0; all 13 shapes PASS; worst max_abs_err per shape in {0.0625, 0.1250}.

Kernel-level (pre-fix vs post-fix, same invocation): Without the C++ fix, shapes at large M×hidden fail deterministically in the sweep (example pre-fix excerpt):

| Result | Shape (M × hidden) | max_abs_err | first_bad row |
|--------|--------------------|-------------|---------------|
| PASS | M=1, hidden=2880 | 0.0625 | — |
| FAIL | M=1319, hidden=2880 | 149.19 | 1209 |
| FAIL | M=2048, hidden=2880 | 237.75 | 247 |
| PASS | M=16384, hidden=2880 | 0.1250 | — |

M=16384 can still PASS pre-fix because the failure is timing-sensitive; the test’s tight loop is what makes the intermediate large-M shapes (e.g. M=1319, M=2048) fail reliably.

Post-fix: all shapes in the sweep PASS; max-abs error stays in {0.0625, 0.1250}.

SGLang GSM8K (gpt-oss-120B, fusion on): five back-to-back client runs (--num-questions 1319 --parallel 1319 --num-shots 5).

| Run | Accuracy pre-fix | Accuracy post-fix |
|-----|------------------|-------------------|
| 1 | 0.854 | 0.850 |
| 2 | 0.850 | 0.857 |
| 3 | 0.835 | 0.848 |
| 4 | 0.000 | 0.857 |
| 5 | 0.000 | 0.839 |
Server launch (SGLang)
SGLANG_USE_AITER_UNIFIED_ATTN=1 SGLANG_USE_AITER=1 \
  python3 -m sglang.launch_server \
    --model-path /data/gpt-oss-120b --tp 8 \
    --attention-backend aiter --trust-remote-code \
    --enable-aiter-allreduce-fusion \
    --page-size 16 --disable-radix-cache
Client loop
for i in 1 2 3 4 5; do
  python3 benchmark/gsm8k/bench_sglang.py \
    --num-questions 1319 --parallel 1319 --num-shots 5 --port 9000
done

Post-fix runs stay in the same band as fusion-off baseline on this stack (~0.83–0.87).

Benchmarking and Profiling

Absolute latency by TP, patched vs unpatched (test_fused_ar_rms_memory_order.py --benchmark)

Same host, same ROCm stack, bench_warmup=200, bench_iters=1000, bf16, rank-averaged p50/p90/p99 (µs) from cuda.Event per call. Shapes = BENCHMARK_SHAPES (10 rows). Tables are script-native (fused op only; no unfused reference in the timed loop). Baselines were collected by checking out HEAD~1 of csrc/include/custom_all_reduce.cuh and flushing the JIT cache; tables were re-captured immediately after restoring the fix.

Improvement % uses the lower-is-better formula (baseline − patched) / baseline × 100 on p50. Positive = patched is faster; negative = patched is slower (cost of correct ordering).

TP=8 — microbenchmark (baseline vs patched)
python3 op_tests/multigpu_tests/test_fused_ar_rms_memory_order.py --tp 8 --benchmark \
  --bench-warmup 200 --bench-iters 1000
| M | H | bytes | Baseline p50 (µs) | Patched p50 (µs) | Δ p50 (µs) | Improvement % (p50) | Baseline p90 | Patched p90 | Baseline p99 | Patched p99 |
|---|---|-------|-------------------|------------------|------------|---------------------|--------------|-------------|--------------|-------------|
| 1 | 4096 | 8192 | 101.51 | 99.14 | −2.37 | +2.3 | 119.78 | 122.73 | 184.58 | 196.65 |
| 16 | 4096 | 131072 | 99.57 | 98.38 | −1.19 | +1.2 | 114.86 | 105.10 | 191.79 | 188.33 |
| 64 | 2880 | 368640 | 102.53 | 102.72 | +0.19 | −0.2 | 116.94 | 132.73 | 208.02 | 193.30 |
| 128 | 2880 | 737280 | 100.59 | 99.77 | −0.82 | +0.8 | 109.41 | 113.60 | 184.63 | 201.09 |
| 1024 | 2880 | 5898240 | 88.85 | 87.73 | −1.12 | +1.3 | 97.30 | 96.85 | 171.76 | 186.25 |
| 1319 | 2880 | 7597440 | 89.00 | 87.79 | −1.21 | +1.4 | 97.93 | 96.18 | 169.77 | 186.21 |
| 2048 | 2880 | 11796480 | 92.04 | 96.77 | +4.73 | −5.1 | 94.49 | 99.37 | 98.55 | 103.49 |
| 4096 | 2880 | 23592960 | 162.82 | 167.61 | +4.79 | −2.9 | 165.69 | 170.02 | 173.24 | 173.26 |
| 8192 | 2880 | 47185920 | 311.46 | 316.71 | +5.25 | −1.7 | 318.45 | 324.37 | 323.69 | 329.15 |
| 16384 | 2880 | 94371840 | 613.20 | 613.59 | +0.39 | −0.1 | 616.85 | 617.30 | 620.11 | 620.83 |
TP=4 — microbenchmark (baseline vs patched)
python3 op_tests/multigpu_tests/test_fused_ar_rms_memory_order.py --tp 4 --benchmark \
  --bench-warmup 200 --bench-iters 1000
| M | H | bytes | Baseline p50 (µs) | Patched p50 (µs) | Δ p50 (µs) | Improvement % (p50) | Baseline p90 | Patched p90 | Baseline p99 | Patched p99 |
|---|---|-------|-------------------|------------------|------------|---------------------|--------------|-------------|--------------|-------------|
| 1 | 4096 | 8192 | 90.51 | 90.80 | +0.29 | −0.3 | 101.23 | 103.22 | 173.51 | 179.94 |
| 16 | 4096 | 131072 | 89.66 | 89.66 | 0.00 | 0.0 | 108.91 | 95.17 | 187.30 | 168.05 |
| 64 | 2880 | 368640 | 92.45 | 93.29 | +0.84 | −0.9 | 98.35 | 98.52 | 186.07 | 179.82 |
| 128 | 2880 | 737280 | 91.59 | 92.55 | +0.96 | −1.0 | 101.61 | 101.85 | 191.04 | 183.81 |
| 1024 | 2880 | 5898240 | 85.36 | 87.69 | +2.33 | −2.7 | 93.05 | 92.79 | 173.16 | 172.62 |
| 1319 | 2880 | 7597440 | 103.62 | 108.24 | +4.62 | −4.5 | 106.38 | 110.91 | 110.40 | 114.31 |
| 2048 | 2880 | 11796480 | 145.96 | 151.57 | +5.61 | −3.8 | 148.90 | 154.65 | 152.59 | 158.10 |
| 4096 | 2880 | 23592960 | 268.63 | 273.43 | +4.80 | −1.8 | 270.98 | 275.92 | 274.02 | 278.70 |
| 8192 | 2880 | 47185920 | 522.40 | 526.85 | +4.45 | −0.9 | 529.95 | 531.90 | 535.30 | 538.98 |
| 16384 | 2880 | 94371840 | 1048.95 | 1047.36 | −1.59 | +0.2 | 1054.70 | 1053.26 | 1058.82 | 1058.09 |
TP=2 — microbenchmark (baseline vs patched)
python3 op_tests/multigpu_tests/test_fused_ar_rms_memory_order.py --tp 2 --benchmark \
  --bench-warmup 200 --bench-iters 1000
| M | H | bytes | Baseline p50 (µs) | Patched p50 (µs) | Δ p50 (µs) | Improvement % (p50) | Baseline p90 | Patched p90 | Baseline p99 | Patched p99 |
|---|---|-------|-------------------|------------------|------------|---------------------|--------------|-------------|--------------|-------------|
| 1 | 4096 | 8192 | 75.02 | 75.82 | +0.80 | −1.1 | 83.80 | 86.21 | 147.97 | 147.77 |
| 16 | 4096 | 131072 | 73.88 | 74.98 | +1.10 | −1.5 | 78.50 | 101.61 | 141.74 | 147.46 |
| 64 | 2880 | 368640 | 77.78 | 80.28 | +2.50 | −3.2 | 96.80 | 84.66 | 152.32 | 155.83 |
| 128 | 2880 | 737280 | 80.73 | 83.34 | +2.61 | −3.2 | 101.59 | 88.59 | 155.17 | 153.98 |
| 1024 | 2880 | 5898240 | 132.40 | 136.78 | +4.38 | −3.3 | 133.67 | 137.94 | 138.20 | 142.28 |
| 1319 | 2880 | 7597440 | 167.86 | 172.64 | +4.78 | −2.8 | 169.28 | 174.06 | 173.06 | 177.51 |
| 2048 | 2880 | 11796480 | 247.78 | 252.50 | +4.72 | −1.9 | 249.07 | 253.80 | 252.90 | 257.46 |
| 4096 | 2880 | 23592960 | 477.94 | 482.31 | +4.37 | −0.9 | 479.21 | 483.86 | 482.19 | 487.07 |
| 8192 | 2880 | 47185920 | 943.48 | 947.85 | +4.37 | −0.5 | 946.47 | 950.69 | 953.95 | 956.47 |
| 16384 | 2880 | 94371840 | 1896.39 | 1896.86 | +0.47 | −0.0 | 1900.68 | 1901.02 | 1905.26 | 1905.16 |

Takeaways (all TPs):

  • Small-M (1-stage, ≤128 KB) and very large-M (M=16384) rows sit inside ±1–3 µs run-to-run noise in either direction; those paths either skip the new fence entirely (1-stage final_sync=true) or are dominated by bulk transfer.
  • On the 2-stage path, the fix carries a ~+4.5 µs fixed overhead from the system-scope acquire fence. It shows up uniformly at TP=2 and TP=4 and appears at M ≥ 2048 at TP=8; as a fraction of call time it is ≤5% on decode-sized shapes and shrinks toward ~1% at multi-megabyte payloads.
  • The alternative — disabling fusion for total_bytes > ~4 MB — would lose the fused path for every decode batch above ~700 tokens at hidden=2880, a strictly worse perf trade-off.

Checklist

  • Format C++/Python per aiter code guide (pre-commit / project formatter if applicable).
  • Add or extend tests per aiter testing docs; new file op_tests/multigpu_tests/test_fused_ar_rms_memory_order.py.
  • Update user-facing docs only if behavior flags or public API change (none here).
  • Attach accuracy evidence (regression test + optional SGLang GSM8K table above) and perf table for kernel behavior change.
  • Match existing style in touched headers (no unrelated refactors).

Review and Merge Process

  1. Request review from Aiter maintainers / code owners for csrc/ and op_tests/multigpu_tests/.
  2. Run multi-GPU CI or the in-repo test on ROCm (TP=2/4 smoke; TP=8 for the primary repro sweep).
  3. After approvals and green checks, merge per ROCm/aiter repository policy.

CC: @valarLip @TennyWang1223 @HaiShaw @kkHuang-amd

@hubertlu-tw hubertlu-tw requested review from a team and TennyWang1223 April 23, 2026 23:16
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
|-------|-------|
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or gh pr edit 2890 --add-label <label>
