Implements compile_mixed_moe_gemm1 for BF16 activations x FP4 E2M1 weights
using mfma_f32_16x16x32_bf16 on gfx950. Key additions:
- compute_bf16xfp4_tile: software dequant via v_cvt_scalef32_pk_bf16_fp4,
E8M0 scale loading with correct k_mid/m1 byte-shift extraction, and
k1_override for the two K32 sub-steps per K64 outer tile (see the
reference decode sketch after this list).
- col_n_valid guard in _stage1_store_row: prevents OOB CTAs (by=1 when
tile_n==inter_dim) from writing zero rows into neighbouring token-slot
output cells via a race condition on the strided output layout.
- swiglu MLIR fix: convert Python float literals to arith.constant values
before passing to arith.minimumf/maximumf.
- test_moe_gemm1_bf16xfp4: 6 parametrizations (tile_m in {16,32,64} x
act in {silu,swiglu}), all passing with logits_diff < 0.002.
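As a reference for what the software dequant computes, here is a minimal NumPy sketch. This is a host-side model, not the kernel code; the little-endian nibble ordering within the kpack is an assumption.

```python
import numpy as np

# FP4 E2M1 magnitudes indexed by the low 3 bits of a nibble; bit 3 is the sign.
E2M1_MAG = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)

def decode_fp4_kpack(kpack: int, scale: float) -> np.ndarray:
    """Decode one 32-bit kpack (8 E2M1 nibbles) into 8 scaled values,
    mirroring four v_cvt_scalef32_pk_bf16_fp4 calls (one per byte,
    2 nibbles each)."""
    out = np.empty(8, dtype=np.float32)
    for i in range(8):
        nib = (kpack >> (4 * i)) & 0xF
        mag = E2M1_MAG[nib & 0x7]
        out[i] = (-mag if nib & 0x8 else mag) * scale
    return out
```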
Stage 2 x_lds_elem() referenced is_bf16_a, which was never declared, causing a NameError on any fp4 path. Also fixes a_elem_bytes to account for bf16 (2 bytes) alongside fp16.
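A plausible shape of that fix, sketched below; the surrounding dtype plumbing is hypothetical, and only x_lds_elem, is_bf16_a, and a_elem_bytes come from the commit.

```python
# Hypothetical dtype plumbing around the fix; a_dtype is illustrative.
a_dtype = "bf16"
is_fp16_a = a_dtype == "fp16"
is_bf16_a = a_dtype == "bf16"  # previously never declared -> NameError in x_lds_elem()
a_elem_bytes = 2 if (is_fp16_a or is_bf16_a) else 1  # bf16 is 2 bytes, like fp16
```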
Adds BF16 x FP4 (MXFP4 E2M1) support to the MoE GEMM stage1 kernel on gfx950 (MI350/MI355X). In this configuration, activations stay in full BF16 precision while weights are stored in FP4 and dequantized to BF16 in software before each MFMA call. This avoids the accuracy loss of quantizing both operands and matches the W4A16 deployment pattern used in production MoE inference.
The implementation uses mfma_f32_16x16x32_bf16 with two K32 sub-steps per K64 outer tile, software dequant via v_cvt_scalef32_pk_bf16_fp4, and MXFP4 E8M0 block scale loading with correct byte-shift extraction from the packed i32 scale layout. Two bugs affecting all dtype paths are also fixed as part of this work.

- compute_bf16xfp4_tile: software dequant via v_cvt_scalef32_pk_bf16_fp4, E8M0 scale loading with correct k_mid/m1 byte-shift extraction, and k1_override for the two K32 sub-steps per K64 outer tile.
- col_n_valid guard in _stage1_store_row: prevents OOB CTAs (by=1 when tile_n == inter_dim) from writing zero rows into neighbouring token-slot output cells via a race condition on the strided output layout.
- swiglu MLIR fix: convert Python float literals to arith.constant values before passing to arith.minimumf/arith.maximumf (see the sketch after this list).
- test_moe_gemm1_bf16xfp4: 6 parametrizations (tile_m in {16, 32, 64} x act in {silu, swiglu}), all passing with logits_diff < 0.002.
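A sketch of the swiglu fix pattern, using a hypothetical builder object `b` in place of the project's actual MLIR emitters; only the arith op names are from the PR.

```python
def clamp_for_swiglu(b, x, limit: float):
    """Clamp an f32 SSA value to [-limit, limit] inside the swiglu body.

    Before the fix, raw Python floats were passed straight to
    arith.minimumf / arith.maximumf, which expect SSA values, so the
    literals must be materialized as arith.constant ops first.
    """
    hi = b.constant_f32(limit)   # emits arith.constant
    lo = b.constant_f32(-limit)  # emits arith.constant
    x = b.minimumf(x, hi)        # emits arith.minimumf
    return b.maximumf(x, lo)     # emits arith.maximumf
```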
Motivation

BF16xFP4 is the target dtype configuration for GPT-OSS on MI350/MI355X. Keeping activations in BF16 avoids the dual-quantization accuracy penalty while FP4 weights reduce memory bandwidth and model size. This PR makes the stage1 kernel usable for that deployment scenario.
Technical Details
The core challenge is that mfma_f32_16x16x32_bf16 requires BF16 operands, so FP4 weights cannot feed the MFMA directly. Each 32-bit FP4 kpack (8 nibbles per lane) is decoded to vec<8, bf16> via four calls to v_cvt_scalef32_pk_bf16_fp4, one per byte, each converting 2 nibbles to 2 BF16 values scaled by the MXFP4 E8M0 block scale. The E8M0 byte is placed into the FP32 exponent field via a left-shift by 23, producing the exact 2^(e-127) scale value the instruction expects.
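The exponent-field trick is easy to verify on the host; a minimal sketch in pure Python (not kernel code):

```python
import struct

def e8m0_to_f32(e: int) -> float:
    """Place the E8M0 byte into the FP32 exponent field (left-shift by
    23); for e in [1, 254] the result is exactly 2^(e-127)."""
    return struct.unpack("<f", struct.pack("<I", (e & 0xFF) << 23))[0]

assert e8m0_to_f32(127) == 1.0        # 2^0
assert e8m0_to_f32(130) == 8.0        # 2^3
assert e8m0_to_f32(120) == 2.0 ** -7  # 2^-7
```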
The col_n_valid fix addresses a pre-existing correctness issue: when tile_n == inter_dim, the grid launches 2 CTAs along N (2*inter_dim // tile_n = 2), but one CTA already covers both the gate and up halves of W1. The second CTA has all output columns out of bounds, loads zero scales, produces zero accumulation, and without the guard writes those zeros into adjacent token-slot rows, corrupting roughly half the output.
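A sketch of the guard's index math in plain Python; everything except col_n_valid, tile_n, and inter_dim is a hypothetical name.

```python
def stage1_store_row(by: int, local_n: int, tile_n: int, inter_dim: int,
                     row_val, out_row) -> bool:
    # Stage1 output has inter_dim columns (gate and up halves are
    # combined), so a CTA launched at by=1 with tile_n == inter_dim
    # has its entire column range out of bounds.
    col_n = by * tile_n + local_n
    col_n_valid = col_n < inter_dim
    if col_n_valid:  # without the guard, the OOB CTA's zero rows
        out_row[col_n] = row_val  # clobbered valid token-slot cells
    return col_n_valid
```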
Test Plan

Test Result
All 6 parametrizations pass (logits_diff < 0.002, rtol=0.15, atol=0.15). Performance numbers against GPT-OSS shapes on MI350 to be added before merging.
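The test matrix, sketched as pytest parametrizations; the harness helper and its return value are assumptions, only the test name, parameters, and threshold come from the PR.

```python
import pytest

@pytest.mark.parametrize("act", ["silu", "swiglu"])
@pytest.mark.parametrize("tile_m", [16, 32, 64])
def test_moe_gemm1_bf16xfp4(tile_m, act):
    # run_moe_gemm1_bf16xfp4 is a hypothetical harness returning the
    # max logits diff of the BF16xFP4 kernel against the reference.
    diff = run_moe_gemm1_bf16xfp4(tile_m=tile_m, act=act)
    assert diff < 0.002
```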
Submission Checklist