[MLA] Fix qh128 ASM nullptr write; enable native qh128 fp8 on gfx950 #2907

Draft
inkcherry wants to merge 3 commits into ROCm:main from inkcherry:aiter_mla_

Conversation

@inkcherry
Contributor

@inkcherry inkcherry commented Apr 24, 2026

For the OSL/ISL=8k/1k, conc=256 InferenceMax DSV3 decode case. cc @Duyi-Wang @slippedJim
Restore the native qh128 fp8 decode path on gfx950 that PR #2204 had to
disable for this case, by fixing the underlying nullptr-write bug in the qh128 ASM kernel
call site rather than working around it. This recovers MLA decode performance
on gfx950 for nhead=128, fp8/fp8 while keeping gfx942 unchanged.

python3 op_tests/test_mla_persistent.py \
    -n 128,1 -d fp8 -kvd fp8 \
    -b  32 64 128 256 -c 8192
| bs  | without this PR     | with this PR       | speedup |
|----:|---------------------|--------------------|---------|
|  32 | 210 µs / 0.75 TB/s  | 105 µs / 1.50 TB/s | 2.0×    |
|  64 | 413 µs / 0.76 TB/s  | 174 µs / 1.82 TB/s | 2.4×    |
| 128 | 800 µs / 0.79 TB/s  | 300 µs / 2.10 TB/s | 2.7×    |
| 256 | 1812 µs / 0.70 TB/s | 526 µs / 2.40 TB/s | 3.4×    |

…lback

The qh128 fp8 MLA stage1 ASM kernel
(mla_a8w8_qh128_m32x4_n16x2_msk0_ps) unconditionally writes to its
ptr_LSEP argument. Callers that requested no LSE were previously
allowed to pass nullptr/None for final_lse, which causes a GPU memory
access fault at sufficiently large batch × context (e.g. b=256,
c=4096) on gfx950, crashing the host with a core dump.

This was previously masked by a gfx942-only routing guard which
forced gfx950 down a qh16 fold fallback path. The latent kernel bug
became reachable once that guard was removed and was further exposed
by the LSE-aware dispatcher refactor (ROCm#2378), which made the test
harness pass final_lse=None.

Two minimal, low-risk fixes:

1. aiter/mla.py: always allocate the final_lse output buffer in the
   two MLA decode call sites that route to mla_decode_stage1_asm_fwd,
   and pass that buffer to the kernel. The caller-visible value of
   `final_lse` (returned to the user) still tracks return_lse and is
   None when not requested - public semantics are unchanged.

2. csrc/py_itfs_cu/asm_mla.cu: make get_heuristic_kernel_mla() do a
   2-pass lookup. First match the requested `lse` flag exactly; if
   that fails AND lse=1 was requested, retry with lse=0. The qh128
   binary always writes LSE, so picking it when a non-null buffer is
   supplied is safe even when its CSV row is marked lse=0.
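The 2-pass lookup in fix (2) can be sketched in Python (illustrative only — the real implementation is C++ in csrc/py_itfs_cu/asm_mla.cu, and the `rows`/`find_kernel` names here are hypothetical stand-ins for the CSV kernel table):

```python
# Sketch of the 2-pass kernel lookup described in fix (2). The row
# dicts and the find_kernel helper are hypothetical; the real lookup
# lives in get_heuristic_kernel_mla() in csrc/py_itfs_cu/asm_mla.cu.

def find_kernel(rows, want_lse):
    """Return the first CSV row whose lse flag matches want_lse, else None."""
    for row in rows:
        if row["lse"] == want_lse:
            return row
    return None

def get_heuristic_kernel_mla(rows, want_lse):
    # Pass 0: require an exact match on the requested lse flag.
    kernel = find_kernel(rows, want_lse)
    if kernel is None and want_lse == 1:
        # Pass 1: lse=1 was requested but no exact entry exists; retry
        # with lse=0. The qh128 binary writes ptr_LSEP unconditionally,
        # so picking it is safe when a non-null buffer is supplied.
        kernel = find_kernel(rows, 0)
    return kernel

# The qh128 row is marked lse=0 in the CSV, but an lse=1 request
# still resolves to it via the pass-1 fallback.
rows = [{"name": "mla_a8w8_qh128_m32x4_n16x2_msk0_ps", "lse": 0}]
picked = get_heuristic_kernel_mla(rows, want_lse=1)
```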

With both changes:
- gfx950 large-batch persistent decode no longer crashes;
- gfx942 behaviour is unchanged (exact-match path is hit on pass 0);
- a small, transient (total_s, nhead) fp32 buffer is always
  allocated whether or not the caller requested LSE.

Validated end-to-end with a 4P+1D DeepSeek-R1 fp4 PD-disagg run on
gfx950 (8k input / 1k output / 2048 conc / 20480 prompts): no GPU
faults, throughput restored to the pre-regression baseline.

Made-with: Cursor

Now that the qh128 ASM kernel can no longer be reached with a null
final_lse pointer (preceding commit), it is safe to let gfx950 use
the same native num_heads==128 fp8 routing as gfx942. This avoids
the qh16 fold fallback and recovers a measurable amount of decode
throughput on gfx950.

Three sites updated to accept either gfx942 or gfx950:
- aiter/mla.py             (decode dispatch in mla_decode_fwd)
- aiter/ops/attention.py   (get_mla_metadata_info_v1, decode_update_mla_metadata_v1)
- csrc/kernels/mla/metadata/v1_2_device.cuh (natively_supported)

The use_qseqlen_fold fallback below remains untouched - it is
gated on `!natively_supported` and so naturally turns off when
gfx950 is treated as natively supported.
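The routing change at those three sites can be sketched as a single predicate (names are illustrative; the real checks live in aiter/mla.py, aiter/ops/attention.py and v1_2_device.cuh):

```python
# Sketch of the gating extended in this commit. The function names and
# string-typed dtypes are hypothetical; the real checks are spread
# across the three files listed above.

def natively_supported(arch, num_heads, q_dtype, kv_dtype):
    # Before this commit only gfx942 took the native qh128 fp8 path;
    # gfx950 is now accepted as well, since the kernel can no longer
    # be reached with a null final_lse pointer.
    return (
        arch in ("gfx942", "gfx950")
        and num_heads == 128
        and q_dtype == "fp8"
        and kv_dtype == "fp8"
    )

def use_qseqlen_fold(arch, num_heads, q_dtype, kv_dtype):
    # The fold fallback is gated on !natively_supported, so it turns
    # off automatically once gfx950 is treated as native.
    return not natively_supported(arch, num_heads, q_dtype, kv_dtype)
```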

Validated end-to-end on gfx950 (DeepSeek-R1 fp4 4P+1D PD-disagg,
8k/1k, 2048 conc, 20480 prompts): warmup time 03:49 (vs 03:52
pre-regression baseline), output throughput within ~1% of baseline,
no GPU faults.

Made-with: Cursor

The reference path inside torch_mla_extend_split_kv classifies which
configurations are "natively supported" (i.e. the kernel produces
tensors in the qh128 layout, not the qh16 fold layout). Before this
PR the gfx942/gfx950 split for nheads=128 fp8 lived only in
aiter/mla.py; the persistent test ref still treated gfx950 nheads=128
fp8 as fold-layout. After the preceding two commits enable the native
qh128 path on gfx950, the existing default sweep entry
(nhead=128, decode_qlen=2) + fp8/fp8 + bs=256 + ctx=8192 hits a layout
mismatch in torch_mla_extend_split_kv and crashes
(IndexError: max(): Expected reduction dim 2 to have non-zero size).

Mirror the same gfx942 -> (gfx942, gfx950) extension here so the ref
follows the kernel. After the fix, the existing (128, 2) sweep entry
runs to completion (~1040 us, 1.26 TB/s on gfx950) and the equivalent
(128, 1) configuration that motivated this PR can be reproduced via:

  python op_tests/test_mla_persistent.py \
      -n 128,1 -d fp8 -kvd fp8 -b 256 -c 8192

(~544 us, 2.33 TB/s on gfx950 with this PR; pre-fix the qh128 stage1
kernel either crashed on a null final_lse pointer or fell back to the
slower qh16 fold path.)

No other test cases or argparse defaults change.

Made-with: Cursor

@inkcherry inkcherry requested review from a team and Copilot April 24, 2026 12:05
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
|---|---|
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or `gh pr edit 2907 --add-label <label>`

Contributor

Copilot AI left a comment


Pull request overview

Fixes a gfx950 crash in the qh128 MLA ASM decode path (nullptr write to ptr_LSEP) and re-enables native qh128 fp8/fp8 decoding on gfx950 (previously folded/disabled), aiming to restore decode performance while keeping gfx942 behavior consistent.

Changes:

  • Expand “natively supported” gating to include gfx950 for nhead=128, fp8/fp8 across Python tests/metadata and device-side metadata generation.
  • Update ASM kernel selection logic to allow an lse=0 fallback when lse is requested but no exact-match entry exists.
  • Always allocate and pass a final_lse buffer to the ASM stage1 decode call to avoid passing nullptr.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.

| File | Description |
|---|---|
| op_tests/test_mla_persistent.py | Treat gfx950 like gfx942 for nhead=128 fp8/fp8 "native support" gating in the reference/test path. |
| csrc/py_itfs_cu/asm_mla.cu | Add 2-pass heuristic kernel lookup with lse fallback behavior. |
| csrc/kernels/mla/metadata/v1_2_device.cuh | Mark gfx950 as natively supported for num_heads==128 with fp8 q/kv in the v1.2 metadata path. |
| aiter/ops/attention.py | Align Python metadata heuristics to allow the gfx950 native 128 fp8/fp8 path. |
| aiter/mla.py | Allocate/pass a final_lse buffer unconditionally to avoid qh128 ASM nullptr writes; broaden gfx gating to include gfx950 for 128 fp8/fp8. |


Comment thread aiter/mla.py
Comment on lines +248 to 255
```python
# Plan A safety: always allocate final_lse buffer.
# qh128 ASM kernel (mla_a8w8_qh128_m32x4_n16x2_msk0_ps) writes ptr_LSEP
# unconditionally; passing nullptr crashes on gfx950 at large batch.
final_lse_buf = torch.empty(
    (total_s, nhead), dtype=dtypes.fp32, device=device
)
final_lse = final_lse_buf if return_lse else None
```


Copilot AI Apr 24, 2026


final_lse_buf is now always allocated and passed into mla_decode_stage1_asm_fwd, even when return_lse is false. Because the C++ side sets lse_flag = (lse != nullptr), this changes kernel selection on gfx950 for configs that have both lse=0 and lse=1 entries (e.g. *_lse_* kernels in hsa/gfx950/mla/mla_asm.csv), potentially forcing extra LSE computation/stores and hurting perf/memory without the caller needing LSE.

Consider only passing a non-null LSE pointer when return_lse is true, and using a non-null scratch buffer only for the specific qh128 fp8/fp8 gfx950 case that crashes on nullptr (so other paths preserve lse_flag==0).
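The reviewer's suggested gating could be sketched roughly as follows (a sketch only, not code from the PR; the `needs_lse_scratch` predicate and `alloc` helper are hypothetical stand-ins for the gfx950 qh128 fp8/fp8 check and the `torch.empty((total_s, nhead))` allocation):

```python
# Sketch of the gating the review suggests: only pass a non-null LSE
# pointer when the caller asked for LSE, except on the specific qh128
# fp8/fp8 gfx950 path that crashes on nullptr. Names are hypothetical.

def pick_final_lse(return_lse, needs_lse_scratch, alloc):
    """Decide what to pass as the kernel's LSE argument.

    alloc() stands in for torch.empty((total_s, nhead), dtype=fp32).
    Returns (kernel_arg, user_visible_value).
    """
    if return_lse:
        buf = alloc()
        return buf, buf           # caller receives the LSE tensor
    if needs_lse_scratch:
        return alloc(), None      # scratch keeps the ASM kernel safe
    return None, None             # preserves lse_flag == 0 in C++

# gfx950 qh128 fp8/fp8 without return_lse: kernel gets a scratch
# buffer, the user still sees None.
arg, user = pick_final_lse(False, True, alloc=lambda: [0.0])
```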

Comment thread aiter/mla.py
Comment on lines +425 to 432
```python
# Plan A safety: always allocate final_lse buffer.
# qh128 ASM kernel (mla_a8w8_qh128_m32x4_n16x2_msk0_ps) writes ptr_LSEP
# unconditionally; passing nullptr crashes on gfx950 at large batch.
final_lse_buf = torch.empty(
    (total_s, nhead), dtype=dtypes.fp32, device=device
)
final_lse = final_lse_buf if return_lse else None
```


Copilot AI Apr 24, 2026


Same issue in the persistent-mode branch: final_lse_buf is always allocated/passed to mla_decode_stage1_asm_fwd, which makes lse_flag true in C++ even when return_lse is false. This can unintentionally select lse=1 kernels (and do extra work) on gfx950.

Recommend gating the non-null LSE pointer to return_lse and only using a scratch buffer for the specific qh128 fp8/fp8 gfx950 call path that requires a non-null pointer for safety.

Comment on lines +65 to +69
```cpp
// Plan A: 2-pass lookup. First require exact lse match; if not found and
// lse was requested, fall back to lse=0 entries. The qh128 binary writes
// ptr_LSEP unconditionally, so passing a non-null buffer is safe even
// when CSV row is marked lse=0.
for (int pass = 0; pass < 2; ++pass) {
```

Copilot AI Apr 24, 2026


The new generic 2-pass fallback (lse requested → fall back to lse=0 entries) changes the meaning of lse: callers requesting LSE output may silently get a non-LSE kernel if no lse=1 entry exists. Given the config CSVs include distinct lse=0 vs lse=1 kernels (e.g. *_lse_* on gfx950), this risks returning incorrect/undefined final_lse instead of failing fast.

Safer options: restrict this fallback to the known qh128 fp8/fp8 gfx950 decode configuration where the CSV’s lse flag is known to be inaccurate, or keep strict matching and fix the qh128 lse metadata in the config source instead.

@inkcherry inkcherry marked this pull request as draft April 24, 2026 12:12