The scale shuffle used scaleN_valid as the row-group stride, but the GEMM kernel reads the scale buffer using the physical column width (scaleN_pad = ceil(scaleN_valid/8)*8). When K % 256 != 0 these differ, causing misaligned scale reads for all row groups beyond the first 32 rows (~99% of output elements wrong).

Two fixes in `_dynamic_mxfp4_quant_kernel_asm_layout`:

- Use `scaleN_pad` instead of `scaleN` in the row-group stride term so the write layout matches what the GEMM kernel expects.
- Fill out-of-bounds scale entries (K >= actual K) with 0 instead of 127, so padding K tiles contribute ~0 to the GEMM output.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Pull request overview
Fixes incorrect MXFP4 blockscale shuffle layout when K % 256 != 0 by aligning the quant kernel’s scale write stride with the GEMM kernel’s padded scale width, and ensuring padded K tiles contribute negligibly.
Changes:
- Use `scaleN_pad` (padded to a multiple of 8) in the shuffled blockscale row-group stride to match the GEMM read layout.
- Fill out-of-bounds/padded scale entries with `0` (min e8m0 scale) instead of `127` (≈1.0) to avoid garbage contributions from padded K tiles.
```diff
+        + bs_offs_0 * 2 * 16 * scaleN_pad
     )
     bs_mask1 = (bs_offs_m < M)[:, None] & (bs_offs_n < scaleN)[None, :]
     bs_mask2 = (bs_offs_m < scaleM_pad)[:, None] & (bs_offs_n < scaleN_pad)[None, :]
-    bs_e8m0 = tl.where(bs_mask1, bs_e8m0, 127)
+    bs_e8m0 = tl.where(bs_mask1, bs_e8m0, 0)
```
The shuffle=True layout is critical for correctness when scaleN_valid is not a multiple of 8 (e.g., K%256!=0). There doesn’t appear to be a test that exercises dynamic_mxfp4_quant(..., shuffle=True) for such shapes to validate the padded stride/layout and padding-fill behavior; please add a regression test (e.g., N=2880 where scaleN_valid=90, scaleN_pad=96) comparing against a reference shuffle (pad then shuffle) or round-trip unshuffle.
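A pad-then-shuffle reference of the kind suggested above can be sketched in pure Python. The layout model here is deliberately simplified (it captures only the padded row stride being validated, not the kernel's within-group interleaving), and `shuffle_ref`/`unshuffle_ref` are hypothetical helper names, not functions from the repository:

```python
# Simplified model of the shuffled blockscale layout: zero-pad each row
# to scaleN_pad columns, then store rows contiguously in a flat buffer.
# NOTE: the real asm layout also interleaves rows/columns within each
# 32-row group; this sketch only models the padded stride.
def shuffle_ref(scales, scaleN_pad):
    flat = []
    for row in scales:
        flat.extend(row + [0] * (scaleN_pad - len(row)))  # zero-pad columns
    return flat

def unshuffle_ref(buf, M, scaleN_valid, scaleN_pad):
    # Invert shuffle_ref: slice each padded row back to its valid width.
    return [buf[r * scaleN_pad : r * scaleN_pad + scaleN_valid] for r in range(M)]

M, K = 64, 2880
scaleN_valid = (K + 31) // 32             # ceil(K/32) = 90
scaleN_pad = (scaleN_valid + 7) // 8 * 8  # ceil(90/8)*8 = 96

scales = [[(r * 31 + c) % 256 for c in range(scaleN_valid)] for r in range(M)]
roundtrip = unshuffle_ref(shuffle_ref(scales, scaleN_pad), M, scaleN_valid, scaleN_pad)
assert roundtrip == scales  # round-trip through the padded layout is lossless
```

A real regression test would replace `shuffle_ref` with the actual kernel output and compare against this kind of reference (or round-trip through an unshuffle of the true layout).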
```diff
     bs_mask1 = (bs_offs_m < M)[:, None] & (bs_offs_n < scaleN)[None, :]
     bs_mask2 = (bs_offs_m < scaleM_pad)[:, None] & (bs_offs_n < scaleN_pad)[None, :]
-    bs_e8m0 = tl.where(bs_mask1, bs_e8m0, 127)
+    bs_e8m0 = tl.where(bs_mask1, bs_e8m0, 0)
```
tl.where(bs_mask1, bs_e8m0, 0) introduces a new magic value where 0 is intended to mean the smallest e8m0 scale (2^-127). Consider defining a named constant (or adding a short inline comment) to make it clear this is an intentional “near-zero” padding scale rather than an arbitrary zero byte.
Suggested change:

```diff
-    bs_e8m0 = tl.where(bs_mask1, bs_e8m0, 0)
+    E8M0_PAD_SCALE = 0  # e8m0 encoding for the smallest padding scale (2^-127)
+    bs_e8m0 = tl.where(bs_mask1, bs_e8m0, E8M0_PAD_SCALE)
```
Root cause
In _dynamic_mxfp4_quant_kernel_asm_layout (shuffle=True path), the scale buffer row-group stride was computed as bs_offs_0 * 2 * 16 * scaleN, where scaleN = scaleN_valid = ceil(K/32). However, the f4gemm kernel reads the scale buffer using the physical column width scaleN_pad = ceil(scaleN_valid/8)*8.

When K % 256 != 0, scaleN_valid is not a multiple of 8, so scaleN_valid != scaleN_pad. This mismatch causes the quant kernel to write valid scales at offsets based on stride scaleN_valid, while the GEMM kernel reads them at offsets based on stride scaleN_pad. Every row group beyond the first 32 rows is misaligned, corrupting ~99% of output elements.
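The misalignment is plain offset arithmetic. A minimal sketch using the kernel's row-group stride term (2 * 16 * width); `row_group_base` is a hypothetical helper for illustration:

```python
# Base offset of row group g in the flat scale buffer, using the
# kernel's row-group stride term: 2 * 16 * width = 32 * width.
def row_group_base(g: int, width: int) -> int:
    return g * 2 * 16 * width

K = 2880
scaleN_valid = (K + 31) // 32             # ceil(K/32) = 90
scaleN_pad = (scaleN_valid + 7) // 8 * 8  # ceil(90/8)*8 = 96

for g in range(3):
    wrote = row_group_base(g, scaleN_valid)  # buggy write stride
    read = row_group_base(g, scaleN_pad)     # stride the GEMM reads with
    status = "OK" if wrote == read else "misaligned"
    print(f"group {g}: written at {wrote}, read from {read} -> {status}")
```

Only group 0 agrees (both offsets are 0); every later group is read from a base offset 32 * (scaleN_pad - scaleN_valid) * g bytes past where it was written.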
Additionally, out-of-bounds scale entries (for K_pad > K_actual) were filled with 127 (= 1.0 in e8m0), causing padding K tiles to contribute garbage values to the GEMM output instead of zero.
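Why the fill value matters: e8m0 is a bias-127 power-of-two exponent, so the two encodings decode ~38 orders of magnitude apart. A quick check of the standard decoding (the NaN encoding at 255 is omitted here):

```python
def e8m0_to_float(v: int) -> float:
    # e8m0: unsigned 8-bit exponent with bias 127 -> 2^(v - 127).
    # (v = 255 encodes NaN in the OCP MX spec; not handled here.)
    return 2.0 ** (v - 127)

print(e8m0_to_float(127))  # 1.0 -> padded K tiles contributed at full scale
print(e8m0_to_float(0))    # 2^-127, ~5.9e-39 -> contribution is effectively zero
```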
Fix
- Use scaleN_pad instead of scaleN in the row-group stride term of _dynamic_mxfp4_quant_kernel_asm_layout, so the write layout matches what the GEMM kernel expects.
- Fill out-of-bounds scale entries with 0 (the smallest e8m0 scale, 2^-127) instead of 127 (= 1.0), so padding K tiles contribute ~0 to the GEMM output.
Affected shapes
Any shape where K % 256 != 0 and ceil(K/32) % 8 != 0, e.g. K=2880 (scaleN_valid=90, scaleN_pad=96). K=2816 and K=3072 are both divisible by 256 and were therefore unaffected.
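The condition is easy to evaluate per shape; a quick sketch where `scale_widths` is a hypothetical helper mirroring the ceil arithmetic above:

```python
import math

def scale_widths(K: int) -> tuple[int, int]:
    scaleN_valid = math.ceil(K / 32)             # one e8m0 scale per 32 elements
    scaleN_pad = math.ceil(scaleN_valid / 8) * 8  # physical width, padded to 8
    return scaleN_valid, scaleN_pad

for K in (2816, 2880, 3072):
    valid, pad = scale_widths(K)
    status = "affected" if valid != pad else "unaffected"
    print(f"K={K}: K%256={K % 256}, scaleN_valid={valid}, scaleN_pad={pad} -> {status}")
```

For K=2880 this yields (90, 96), matching the affected shape above; K=2816 gives (88, 88) and K=3072 gives (96, 96), both unaffected.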
Alternative fix
Pad B's packed K dimension to a multiple of 128 (= 256 actual fp4 elements) with zeros before shuffling, so the GEMM never reads padding K tiles. This PR fixes the issue on the scale side instead.