
Add BF16xFP4 MoE GEMM stage1 kernel #424

Draft

apicciau wants to merge 2 commits into main from apicciau/bf16xfp4-moe-gemm-v2

Conversation

@apicciau

Adds BF16 x FP4 (MXFP4 E2M1) support to the MoE GEMM stage1 kernel on gfx950 (MI350/MI355X). In this configuration, activations stay in full BF16 precision while weights are stored in FP4 and dequantized to BF16 in software before each MFMA call. This avoids the accuracy loss of quantizing both operands and matches the W4A16 deployment pattern used in production MoE inference.

The implementation uses mfma_f32_16x16x32_bf16 with two K32 sub-steps per K64 outer tile, software dequant via v_cvt_scalef32_pk_bf16_fp4, and MXFP4 E8M0 block scale loading with correct byte-shift extraction from the packed i32 scale layout. Two bugs affecting all dtype paths are also fixed as part of this work.

  • compute_bf16xfp4_tile: software dequant via v_cvt_scalef32_pk_bf16_fp4, E8M0 scale loading with correct k_mid/m1 byte-shift extraction, and k1_override for the two K32 sub-steps per K64 outer tile.
  • col_n_valid guard in _stage1_store_row: prevents OOB CTAs (by=1 when tile_n==inter_dim) from writing zero rows into neighbouring token-slot output cells via a race condition on the strided output layout.
  • swiglu MLIR fix: convert Python float literals to arith.constant values before passing to arith.minimumf/maximumf.
  • test_moe_gemm1_bf16xfp4: 6 parametrizations (tile_m in {16,32,64} x act in {silu,swiglu}), all passing with logits_diff < 0.002.

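As a reference for the decode path described above, the following is a minimal Python model of the E2M1 nibble format and the per-byte, two-nibbles-at-a-time conversion that `v_cvt_scalef32_pk_bf16_fp4` performs in hardware. The function names are illustrative (not taken from the kernel), and the low-nibble-first ordering within each byte is an assumption of this sketch:

```python
# Hypothetical reference model of the per-lane FP4 kpack decode: 8 E2M1
# nibbles packed in one 32-bit word, decoded two per byte, then scaled.

# E2M1 magnitudes for the 3 low bits (2 exponent, 1 mantissa, bias 1):
# e=0 is subnormal (0, 0.5); e>0 gives (1 + m/2) * 2**(e-1).
E2M1_TABLE = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit E2M1 value: 1 sign, 2 exponent, 1 mantissa bits."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_TABLE[nibble & 0x7]

def dequant_kpack(kpack: int, scale: float) -> list[float]:
    """Emulate four per-byte conversions, each turning 2 nibbles
    (low nibble first -- an assumption) into 2 scaled values."""
    out = []
    for byte_idx in range(4):
        byte = (kpack >> (8 * byte_idx)) & 0xFF
        out.append(decode_e2m1(byte & 0xF) * scale)         # low nibble
        out.append(decode_e2m1((byte >> 4) & 0xF) * scale)  # high nibble
    return out
```

For example, `dequant_kpack(0x00000021, 1.0)` decodes byte 0 (`0x21`) to `0.5` and `1.0`, with the remaining six lanes zero.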
Motivation

BF16xFP4 is the target dtype configuration for GPT-OSS on MI350/MI355X. Keeping activations in BF16 avoids the dual-quantization accuracy penalty while FP4 weights reduce memory bandwidth and model size. This PR makes the stage1 kernel usable for that deployment scenario.

Technical Details

The core challenge is that mfma_f32_16x16x32_bf16 requires BF16 operands, so FP4 weights cannot feed the MFMA directly. Each 32-bit FP4 kpack (8 nibbles per lane) is decoded to vec<8, bf16> via four calls to v_cvt_scalef32_pk_bf16_fp4, one per byte, each converting 2 nibbles to 2 BF16 values scaled by the MXFP4 E8M0 block scale. The E8M0 byte is placed into the FP32 exponent field via a left-shift by 23, producing the exact 2^(e-127) scale value the instruction expects.
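The exponent-field trick can be sketched in a few lines of Python (`e8m0_to_fp32` is a hypothetical helper name; the E8M0 edge encodings, e=0 landing in the FP32 subnormal range and e=255 decoding to infinity, are ignored here):

```python
import struct

def e8m0_to_fp32(e8m0_byte: int) -> float:
    """Place the E8M0 exponent byte into the FP32 exponent field
    (bits 30..23): the bit pattern e << 23 decodes to 2**(e - 127)."""
    bits = (e8m0_byte & 0xFF) << 23
    return struct.unpack("<f", struct.pack("<I", bits))[0]
```

With the IEEE 754 binary32 bias of 127, a scale byte of 127 yields 1.0 and 130 yields 8.0, which is exactly the `2^(e-127)` value the instruction expects.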

The col_n_valid fix addresses a pre-existing correctness issue: when tile_n == inter_dim, the grid launches 2 CTAs along N (2*inter_dim // tile_n = 2), but one CTA already covers both the gate and up halves of W1. The second CTA has all output columns out of bounds, loads zero scales, produces zero accumulation, and without the guard writes those zeros into adjacent token-slot rows — corrupting roughly half the output.
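A sketch of the guard condition as described (function names are hypothetical, and this models only the failure mode above, not the kernel's actual index arithmetic):

```python
# Illustrative model of the N-grid launch and the store guard. When
# tile_n == inter_dim, a single CTA already spans the gate and up halves
# of W1, so the second CTA along N has no valid output columns.

def n_grid(inter_dim: int, tile_n: int) -> int:
    """CTAs launched along N, covering both halves of W1."""
    return (2 * inter_dim) // tile_n

def col_n_valid(by: int, tile_n: int, inter_dim: int) -> bool:
    """A CTA may store only if its first output column falls inside the
    inter_dim-wide output; fully-OOB CTAs must skip the store entirely."""
    return by * tile_n < inter_dim
```

With `inter_dim == tile_n == 256`, `n_grid` launches 2 CTAs but only `by=0` passes the guard; without it, the `by=1` CTA's zero accumulators race with valid writes on the strided output layout.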

Test Plan

PYTHONPATH=./ pytest tests/kernels/test_moe_gemm.py::test_moe_gemm1_bf16xfp4 -v

Test Result

All 6 parametrizations pass (logits_diff < 0.002, rtol=0.15, atol=0.15):

test_moe_gemm1_bf16xfp4[silu-tile_m16]   PASSED
test_moe_gemm1_bf16xfp4[silu-tile_m32]   PASSED
test_moe_gemm1_bf16xfp4[silu-tile_m64]   PASSED
test_moe_gemm1_bf16xfp4[swiglu-tile_m16] PASSED
test_moe_gemm1_bf16xfp4[swiglu-tile_m32] PASSED
test_moe_gemm1_bf16xfp4[swiglu-tile_m64] PASSED

Performance numbers against GPT-OSS shapes on MI350 to be added before merging.

Commits

1. Implements compile_mixed_moe_gemm1 for BF16 activations x FP4 E2M1 weights using mfma_f32_16x16x32_bf16 on gfx950; see the summary above for the kernel details.
2. Fixes stage 2 x_lds_elem(), which referenced is_bf16_a without it ever being declared, causing a NameError on any FP4 path; also fixes a_elem_bytes to account for BF16 (2 bytes) alongside FP16.
