Fix 2-stage fused_allreduce_rmsnorm memory ordering #2890

Open

hubertlu-tw wants to merge 1 commit into ROCm:main from
Motivation
The 2-stage `fused_allreduce_rmsnorm` path in `csrc/include/custom_all_reduce.cuh` can return wrong results on ROCm once per-rank volume crosses roughly ~1.2 MB at TP=8. Output tracks the unfused `all_reduce → bf16 → residual_add → RMSNorm` reference for many consecutive calls, then diverges sharply (max-abs error on the order of 10² bf16 units) and stays wrong.

Downstream, SGLang with `--enable-aiter-allreduce-fusion` shows the same pattern on `GptOssForCausalLM` 120B (hidden 2880, bf16, TP=8): early GSM8K passes look normal, but later back-to-back runs collapse toward 0.000 accuracy; disabling fusion removes the failure.

Root cause is cross-device memory ordering in the ROCm implementation of `end_sync` for the stage that writes peer-visible IPC `tmps[]` before a separate kernel reads them. Three issues stack: stage 1 used `end_sync<ngpus, true>` (relaxed atomics); the acquire load used `__MEMORY_SCOPE_DEVICE` while the release store used `__MEMORY_SCOPE_SYSTEM`; and only the polling threads performed the acquire, so other threads in the block could still read stale peer writes after `__syncthreads()`.
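Concretely, the broken hand-off looks roughly like the sketch below. This is an illustration, not code from the repository: the flag layout, names, and signature are assumptions; only the atomic orderings and scopes are the ones called out above.

```cpp
#include <hip/hip_runtime.h>

// Pre-fix shape of the ROCm barrier (illustrative sketch, not the real
// end_sync): flag layout, names, and signature are assumptions.
template <int ngpus>
__device__ void end_sync_pre_fix_sketch(int* peer_flags[], int* self_flags) {
  __syncthreads();
  if (threadIdx.x < ngpus) {
    // The release store was already SYSTEM-scoped...
    __scoped_atomic_store_n(peer_flags[threadIdx.x], 1,
                            __ATOMIC_RELEASE, __MEMORY_SCOPE_SYSTEM);
    // ...but the paired acquire load polled at DEVICE scope (issue 2).
    // Worse, stage 1 instantiated the relaxed variant end_sync<ngpus, true>,
    // so even this release/acquire pair was never used there (issue 1).
    while (__scoped_atomic_load_n(&self_flags[threadIdx.x], __ATOMIC_ACQUIRE,
                                  __MEMORY_SCOPE_DEVICE) != 1) {
    }
  }
  __syncthreads();
  // Issue 3: only the ngpus polling threads hold acquire semantics here.
  // __syncthreads() does not extend that to the rest of the block, which
  // can still read stale peer writes from tmps[] past this point.
}
```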
Modifications

`csrc/include/custom_all_reduce.cuh`:

- `reduce_scatter_cross_device_store`: call `end_sync<ngpus, false>` after the cross-device writes into each rank's `tmps` so peers synchronize with release/acquire instead of relaxed flags. `cross_device_reduce_2stage_naive` still uses `final_sync=true` because its consumer is on-device only.
- `end_sync` (ROCm path): when `final_sync=false`, use `__MEMORY_SCOPE_SYSTEM` on the acquire load to match the system-scoped release store.
- `end_sync`: when `final_sync=false`, issue `__scoped_atomic_thread_fence(__ATOMIC_ACQUIRE, __MEMORY_SCOPE_SYSTEM)` after `__syncthreads()` so every thread in the block sees peer-written memory, not only the `ngpus` threads that polled the flag.

Unified diff (reference)
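The diff itself is collapsed above. As a hedged counterpart to the pre-fix sketch, the following is roughly what the three fixes yield; the flag layout, names, and signature remain assumptions, while the orderings, scopes, and the trailing fence follow the PR description.

```cpp
#include <hip/hip_runtime.h>

// Post-fix barrier sketch. final_sync=true keeps cheap relaxed flags for an
// on-device consumer; final_sync=false is the peer-visible hand-off.
template <int ngpus, bool final_sync>
__device__ void end_sync_post_fix_sketch(int* peer_flags[], int* self_flags) {
  __syncthreads();  // all of this block's tmps[] writes are issued
  if (threadIdx.x < ngpus) {
    // Publish "this rank's writes are done" to peer threadIdx.x.
    __scoped_atomic_store_n(peer_flags[threadIdx.x], 1,
                            final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
                            __MEMORY_SCOPE_SYSTEM);
    // Fix 2: the acquire load now polls at SYSTEM scope, pairing with the
    // SYSTEM-scoped release store performed on the peer.
    while (__scoped_atomic_load_n(
               &self_flags[threadIdx.x],
               final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
               __MEMORY_SCOPE_SYSTEM) != 1) {
    }
  }
  __syncthreads();  // the whole block waits on the ngpus polling threads
  // Fix 3: __syncthreads() orders execution, not cross-device visibility.
  // The block-wide SYSTEM-scope acquire fence extends the polling threads'
  // acquire to every thread that reads peer-written tmps[] afterwards.
  if (!final_sync) {
    __scoped_atomic_thread_fence(__ATOMIC_ACQUIRE, __MEMORY_SCOPE_SYSTEM);
  }
}
```

Keeping `final_sync=true` on the on-device consumer path avoids paying the system-scope fence where device-scope visibility already suffices, which is why `cross_device_reduce_2stage_naive` is left unchanged.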
`op_tests/multigpu_tests/test_fused_ar_rms_memory_order.py` (new): a multi-process regression test for TP world size 2, 4, or 8 (`--tp`). It sweeps shapes across the 1-stage (≤128 KB total) and 2-stage regimes, runs 2000 iterations per shape by default with fresh random inputs, and compares the fused output against the unfused `all_reduce + residual_add + RMSNorm` reference, failing if the max-abs error ever exceeds `0.5`. `--benchmark` runs a latency microbenchmark (`BENCHMARK_SHAPES`, `cuda.Event` per call, rank-averaged p50/p90/p99) without correctness checks. Without the kernel fix, the TP=8 sweep exits with code 1 (primary repro); TP=2/4 provide smoke and perf coverage.

Accuracy Tests
Primary repro for the historical ordering bug remains TP=8 on the default 13-shape sweep; TP=4 and TP=2 are smoke coverage on the same shapes.
The test script used to reproduce the error without the fix: test_fused_ar_rms_memory_order.py.
TP=8 — regression test (primary repro)
Result: exit 0; all 13 shapes PASS; worst `max_abs_err` per shape in `{0.0625, 0.1250}`; `first_bad=-1` everywhere.

TP=4 — regression test (smoke)
Result: exit 0; all 13 shapes PASS; worst `max_abs_err` per shape in `{0.0625, 0.1250}`.

TP=2 — regression test (smoke)
Result: exit 0; all 13 shapes PASS; worst `max_abs_err` per shape in `{0.0625, 0.1250}`.

Kernel-level (pre-fix vs post-fix, same invocation): without the C++ fix, shapes at large M×hidden fail deterministically in the sweep (example pre-fix excerpt):
`M=16384` can still PASS pre-fix because the failure is timing-sensitive; the test's tight loop is what makes the smaller large-M shapes fail reliably.

Post-fix: all shapes in the sweep PASS; max-abs error stays in `{0.0625, 0.1250}`.

SGLang GSM8K (gpt-oss-120B, fusion on): five back-to-back client runs (`--num-questions 1319 --parallel 1319 --num-shots 5`).

Server launch (SGLang)
```bash
SGLANG_USE_AITER_UNIFIED_ATTN=1 SGLANG_USE_AITER=1 \
python3 -m sglang.launch_server \
  --model-path /data/gpt-oss-120b --tp 8 \
  --attention-backend aiter --trust-remote-code \
  --enable-aiter-allreduce-fusion \
  --page-size 16 --disable-radix-cache
```

Client loop
Post-fix runs stay in the same band as the fusion-off baseline on this stack (~0.83–0.87).
Benchmarking and Profiling
Absolute latency by TP, patched vs unpatched (`test_fused_ar_rms_memory_order.py --benchmark`).

Same host, same ROCm stack, `bench_warmup=200`, `bench_iters=1000`, bf16, rank-averaged p50/p90/p99 (µs) from `cuda.Event` timing per call. Shapes = `BENCHMARK_SHAPES` (10 rows). Tables are script-native (fused op only; no unfused reference in the timed loop). Baselines were collected by checking out `HEAD~1` of `csrc/include/custom_all_reduce.cuh` and flushing the JIT cache; the patched tables were re-captured immediately after restoring the fix.

Improvement % uses the lower-is-better formula `(baseline − patched) / baseline × 100` on p50. Positive = patched is faster; negative = patched is slower (the cost of correct ordering).

TP=8 — microbenchmark (baseline vs patched)
TP=4 — microbenchmark (baseline vs patched)
TP=2 — microbenchmark (baseline vs patched)
Takeaways (all TPs):
- Shapes that show no cost either still use the relaxed barrier (`final_sync=true`) or are dominated by bulk transfer.
- Gating the fused path at `total_bytes > ~4 MB` instead would lose it for every decode batch above ~700 tokens at hidden=2880 (700 × 2880 × 2 B ≈ 4 MB in bf16), a strictly worse perf trade-off; a sketch of that rejected gate follows below.
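For concreteness, the rejected alternative would amount to a size gate like the hypothetical sketch below; the function name and exact cutoff constant are illustrative, not code from this PR.

```cpp
#include <cstddef>

// Hypothetical sketch of the rejected alternative: gate the fused 2-stage
// path by payload size instead of fixing the ordering. At hidden=2880 in
// bf16, 4 MB / (2880 * 2 B) ~= 700 tokens, so every larger decode batch
// would fall back to the slower unfused path permanently.
constexpr std::size_t kFusedCutoffBytes = 4ULL * 1024 * 1024;  // assumed value

inline bool use_fused_allreduce_rmsnorm(std::size_t total_bytes) {
  return total_bytes <= kFusedCutoffBytes;
}
```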
Checklist

- New regression/benchmark test: `op_tests/multigpu_tests/test_fused_ar_rms_memory_order.py`.

Review and Merge Process
Changes are limited to `csrc/` and `op_tests/multigpu_tests/`.

CC: @valarLip @TennyWang1223 @HaiShaw @kkHuang-amd