Fix: correct mxfp4 for K not divisible by 256 #2900

Open

Wanzizhu wants to merge 1 commit into main from zw/fix_f4_oob

Conversation

@Wanzizhu

Root cause

In _dynamic_mxfp4_quant_kernel_asm_layout (shuffle=True path), the scale buffer row-group stride was computed as bs_offs_0 * 2 * 16 * scaleN
where scaleN = scaleN_valid = ceil(K/32). However, the f4gemm kernel reads the scale buffer using the physical column width scaleN_pad =
ceil(scaleN_valid/8)*8.

When K % 256 != 0, scaleN_valid is not a multiple of 8, so scaleN_valid != scaleN_pad. The quant kernel therefore writes valid scales at offsets based on stride scaleN_valid, while the GEMM kernel reads them at offsets based on stride scaleN_pad: every row group after the first (i.e., all rows from 32 onward) is misaligned, corrupting ~99% of output elements.

Additionally, out-of-bounds scale entries (for K_pad > K_actual) were filled with 127 (= 1.0 in e8m0), causing padding K tiles to contribute
garbage values to the GEMM output instead of zero.
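
To make the mismatch concrete, here is a minimal plain-Python sketch (not the Triton kernel itself; the loop is illustrative only) of the two offset computations for K=2880:

```python
import math

K = 2880
scaleN_valid = math.ceil(K / 32)              # 90: stride the quant kernel used
scaleN_pad = math.ceil(scaleN_valid / 8) * 8  # 96: stride the f4gemm kernel reads

for row_group in range(4):
    write_off = row_group * 2 * 16 * scaleN_valid  # buggy write placement
    read_off = row_group * 2 * 16 * scaleN_pad     # where the GEMM actually looks
    print(row_group, write_off, read_off, read_off - write_off)

# Row group 0 agrees (offset 0); every later group drifts by
# row_group * 32 * (96 - 90) = row_group * 192 entries, so all scales past
# the first 32 rows are read from the wrong place.
```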

Fix

  • Replace scaleN with scaleN_pad in the row-group stride term so the write layout matches what the GEMM kernel reads.
  • Fill out-of-bounds scale entries with 0 (= 2^-127 ≈ 0 in e8m0) so padding K tiles contribute negligibly to the output; the short decode sketch below shows why the fill value matters.
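
For the second bullet, recall that e8m0 is a biased 8-bit exponent, so a scale byte decodes as 2^(byte - 127), consistent with the encoding described above:

```python
def e8m0_to_float(byte: int) -> float:
    """Decode an e8m0 scale byte: an 8-bit exponent with bias 127."""
    return 2.0 ** (byte - 127)

print(e8m0_to_float(127))  # 1.0:      old fill, padded tiles enter at full weight
print(e8m0_to_float(0))    # ~5.9e-39: new fill, padded tiles contribute ~nothing
```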

Affected shapes

Any shape where K % 256 != 0 and ceil(K/32) % 8 != 0, e.g. K=2880 (scaleN_valid=90, scaleN_pad=96). K=2816 and K=3072 are both divisible by 256 and were therefore unaffected.
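
A quick predicate (the helper name is ours, not from the repo) to check whether a given K hits the bug:

```python
import math

def hits_scale_stride_bug(K: int) -> bool:
    scaleN_valid = math.ceil(K / 32)
    return scaleN_valid % 8 != 0  # valid width differs from the padded width

for K in (2880, 2816, 3072):
    print(K, hits_scale_stride_bug(K))  # 2880 True; 2816 False; 3072 False
```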

Alternative fix

Pad B's packed K dimension to a multiple of 128 (= 256 actual fp4 elements) with zeros before shuffling, so the padding K tiles the GEMM reads contain only zeros. This PR fixes it on the scale side instead; a sketch of the padding approach follows.
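
A sketch of what that alternative could look like (hypothetical helper; assumes B is packed fp4 with two elements per byte and the packed K dimension last):

```python
import torch
import torch.nn.functional as F

def pad_packed_k(B_packed: torch.Tensor) -> torch.Tensor:
    """Zero-pad the packed K dim (bytes) to a multiple of 128 (= 256 fp4 values)."""
    pad = (-B_packed.shape[-1]) % 128
    return F.pad(B_packed, (0, pad))  # fp4 byte 0x00 decodes to two zeros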

@Wanzizhu Wanzizhu requested review from a team and Copilot April 24, 2026 07:12
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label            Tests
ci:triton-300x   Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang        SGLang integration tests
ci:atom          ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm          vLLM benchmark
ci:all           All of the above

Add labels via the sidebar or gh pr edit 2900 --add-label <label>

@Wanzizhu Wanzizhu changed the title from "fix: correct mxfp4 scale shuffle layout for K not divisible by 256" to "Fix: correct mxfp4 for K not divisible by 256" on Apr 24, 2026
Contributor

Copilot AI left a comment


Pull request overview

Fixes incorrect MXFP4 blockscale shuffle layout when K % 256 != 0 by aligning the quant kernel’s scale write stride with the GEMM kernel’s padded scale width, and ensuring padded K tiles contribute negligibly.

Changes:

  • Use scaleN_pad (padded-to-multiple-of-8) in the shuffled blockscale row-group stride to match the GEMM read layout.
  • Fill out-of-bounds/padded scale entries with 0 (min e8m0 scale) instead of 127 (≈1.0) to avoid garbage contributions from padded K tiles.


Comment on lines +389 to +393

-    + bs_offs_0 * 2 * 16 * scaleN
+    + bs_offs_0 * 2 * 16 * scaleN_pad
 )
 bs_mask1 = (bs_offs_m < M)[:, None] & (bs_offs_n < scaleN)[None, :]
 bs_mask2 = (bs_offs_m < scaleM_pad)[:, None] & (bs_offs_n < scaleN_pad)[None, :]
-bs_e8m0 = tl.where(bs_mask1, bs_e8m0, 127)
+bs_e8m0 = tl.where(bs_mask1, bs_e8m0, 0)

Copilot AI Apr 24, 2026


The shuffle=True layout is critical for correctness when scaleN_valid is not a multiple of 8 (e.g., K % 256 != 0). There doesn't appear to be a test that exercises dynamic_mxfp4_quant(..., shuffle=True) for such shapes to validate the padded stride/layout and the padding-fill behavior; please add a regression test (e.g., K=2880, where scaleN_valid=90 and scaleN_pad=96) comparing against a reference shuffle (pad then shuffle) or a round-trip unshuffle.
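
A hedged sketch of such a regression test (dynamic_mxfp4_quant and its shuffle flag are named in this PR; the return convention and the reference_shuffle helper are assumptions for illustration):

```python
import torch
import torch.nn.functional as F

def test_mxfp4_scale_shuffle_k2880():
    # dynamic_mxfp4_quant import omitted: it is the op under review in this PR.
    x = torch.randn(128, 2880, dtype=torch.bfloat16, device="cuda")
    # Assumed return convention: (quantized fp4 tensor, e8m0 blockscales).
    _, bs_shuffled = dynamic_mxfp4_quant(x, shuffle=True)
    _, bs_plain = dynamic_mxfp4_quant(x, shuffle=False)
    # Reference path: zero-pad plain scales from scaleN_valid=90 to scaleN_pad=96,
    # then apply a known-good host-side shuffle (reference_shuffle is hypothetical).
    bs_ref = reference_shuffle(F.pad(bs_plain, (0, 96 - 90)))
    assert torch.equal(bs_shuffled.view(-1), bs_ref.view(-1))
```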

 bs_mask1 = (bs_offs_m < M)[:, None] & (bs_offs_n < scaleN)[None, :]
 bs_mask2 = (bs_offs_m < scaleM_pad)[:, None] & (bs_offs_n < scaleN_pad)[None, :]
-bs_e8m0 = tl.where(bs_mask1, bs_e8m0, 127)
+bs_e8m0 = tl.where(bs_mask1, bs_e8m0, 0)

Copilot AI Apr 24, 2026


tl.where(bs_mask1, bs_e8m0, 0) introduces a new magic value where 0 is intended to mean the smallest e8m0 scale (2^-127). Consider defining a named constant (or adding a short inline comment) to make it clear this is an intentional “near-zero” padding scale rather than an arbitrary zero byte.

Suggested change

-bs_e8m0 = tl.where(bs_mask1, bs_e8m0, 0)
+E8M0_PAD_SCALE = 0  # e8m0 encoding for the smallest padding scale (2^-127)
+bs_e8m0 = tl.where(bs_mask1, bs_e8m0, E8M0_PAD_SCALE)
