Add A16W4 MoE GEMM stage2 kernel (BF16 activations x MXFP4 weights) #431
* Add _cvt_scalef32_pk_bf16_fp4: GFX950 hardware path for converting 2 FP4 E2M1 nibbles to 2 bf16 via v_cvt_scalef32_pk_bf16_fp4 (1 VALU vs ~36 software).
* Add _fp4x4_in_i32_to_bf16x4_i64: software fallback converting 4 FP4 nibbles (packed in 4 bytes of an i32) to 4 bf16 packed as i64.
* Add load_b_raw_mxfp4: loads 4 bytes (8 FP4 nibbles) from a kpack=16 preshuffle layout (shuffle_weight_a16w4 format) using ku-based k0/klane addressing.
* Add load_b_raw_mxfp4_dwordx4: dwordx4 variant loading the full 16-byte kpack for one sub-lane in a single buffer_load.
* Add unpack_b_mxfp4_bf16: dispatches to the hardware or software path, returning a (b0, b1) i64 pair for mfma_f32_16x16x32_bf16.
* Add _decode_e8m0_byte_to_f32: converts an E8M0 byte (i8) to f32 = 2^(e-127) via a bit shift into position 23.
* Add _barrier: emits s_waitcnt + s_barrier as inline asm, bypassing the conservative insertion done by LLVM's SIInsertWaitcnts pass.
* Add compile_a16w4_moe_gemm2: stage2 down-projection GEMM for BF16 activations x MXFP4 (FP4 E2M1) weights with E8M0 block scales, using mfma_f32_16x16x32_bf16. Ported from aiter a16w4_moe_gemm_2stage.py with mechanical adaptations (fx.idx2crd, tuple indexing in place of layout_get).
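The E8M0 decode above can be sketched in a few lines of Python (an illustration of the bit trick, not the kernel code): shifting the exponent byte into the f32 exponent field at bit 23 yields 2^(e-127) for the normal range 1 <= e <= 254.

```python
# Sketch of the "shift into position 23" E8M0 decode. Note the edge bytes
# (0 and 255, which the OCP MX spec treats as 2^-127 and NaN) would need
# separate handling in a complete decoder; this shows only the common case.
import struct

def decode_e8m0_byte_to_f32(e: int) -> float:
    assert 0 <= e <= 255
    bits = (e & 0xFF) << 23  # exponent field of an IEEE-754 binary32
    (val,) = struct.unpack("<f", struct.pack("<I", bits))
    return val
```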
* Add test_moe_gemm2_a16w4: exercises compile_a16w4_moe_gemm2 with BF16 activations and MXFP4 E2M1 weights pre-shuffled via shuffle_weight_a16w4/shuffle_scale_a16w4. Compares kernel output against torch_moe_gemm2 reference. Gated on gfx950+.
…ordx4
* Fix IntTuple component extraction: layout_get(coord, i) must map to fx.get(coord, i), not coord[i]. FlyDSL's fx.idx2crd returns a !fly.int_tuple MLIR value; Python indexing returns another IntTuple, not a scalar index.
* Drop cache_modifier from load_b_raw_mxfp4_dwordx4: FlyDSL's _buffer_load_vec does not support this parameter. The argument is an optional cache-policy hint with no correctness impact.
* Use inter_dim=3072 shapes so inter_dim/tile_k >= 2 (the ping-pong pipeline requires at least 2 K-tile iterations).
* View float4 weight and scale tensors as uint8 before passing them to the kernel (DLPack does not support float4_e2m1fn_x2 or float8_e8m0fnu).
* Remove the test_name kwarg from the verify_output call (not part of its signature).
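The >= 2 K-tile constraint falls out of how a ping-pong double buffer is structured: the prologue prefetches tile 0, and each steady-state step prefetches tile i+1 into one buffer while consuming tile i from the other. A plain-Python illustration of the schedule (not kernel code; names are illustrative):

```python
# Generic ping-pong double-buffer schedule: overlap "load next" with
# "consume current". With fewer than 2 tiles there is no next tile to
# overlap, so the pipeline shape degenerates.
def pingpong_reduce(tiles, consume):
    assert len(tiles) >= 2, "ping-pong pipeline requires >= 2 K tiles"
    buf = [tiles[0], None]              # prologue: prefetch tile 0
    acc = 0
    for i in range(len(tiles)):
        if i + 1 < len(tiles):
            buf[(i + 1) % 2] = tiles[i + 1]  # "load" stage: prefetch next
        acc = consume(acc, buf[i % 2])       # "math" stage: consume current
    return acc
```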
Motivation
This PR ports the A16W4 MoE GEMM stage2 kernel from aiter into FlyDSL. The kernel targets BF16 activations with MXFP4 (FP4 E2M1, per-1x32 block scale) weights — the down-projection GEMM in a Mixture-of-Experts FFN block.
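As a hedged NumPy sketch of what "MXFP4 with per-1x32 block scales" means for the weight matrix (illustrative names, not the kernel's API): every 32 consecutive K elements share one E8M0 scale byte, and the dequantised value is the FP4 E2M1 value times 2^(scale - 127).

```python
# Reference-style dequantisation of an MXFP4 weight matrix, assuming one
# uint8 FP4 code per element and one uint8 E8M0 scale per 1x32 block.
import numpy as np

# All 16 FP4 E2M1 code points (low 3 bits = magnitude, bit 3 = sign).
E2M1 = np.array([0, .5, 1, 1.5, 2, 3, 4, 6,
                 -0., -.5, -1, -1.5, -2, -3, -4, -6], dtype=np.float32)

def dequant_mxfp4(codes: np.ndarray, scales: np.ndarray) -> np.ndarray:
    """codes: (N, K) uint8 FP4 indices; scales: (N, K//32) uint8 E8M0."""
    vals = E2M1[codes]                              # (N, K) f32
    s = np.exp2(scales.astype(np.float32) - 127.0)  # 2^(e - 127) per block
    return vals * np.repeat(s, 32, axis=1)          # broadcast scale over K
```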
Performance on GPT-OSS shapes (model_dim=3072, inter_dim=3072, E=128, topk=4, MI355X) shows stage2 consistently outperforming CK Tile across all token counts (fly/CK < 1.0 means FlyDSL is faster). Numbers come from Zan Zhang's original aiter implementation, which is functionally identical to this port.
The stage2 kernel is landed first as a self-contained unit. A follow-up PR will add the stage1 kernel (compile_a16w4_moe_gemm1).

Technical Details
Three existing files are modified — no new files are added.
kernels/mfma_preshuffle_pipeline.py (new load/unpack helpers for the kpack=16 MXFP4 preshuffle format, authored by Zan Zhang):
* load_b_raw_mxfp4: loads 4 bytes (8 FP4 nibbles) for one ku step using shuffle_weight_a16w4 addressing
* load_b_raw_mxfp4_dwordx4: dwordx4 variant loading the full 16-byte kpack in a single buffer load
* unpack_b_mxfp4_bf16: unpacks 8 FP4 E2M1 nibbles to two i64 values (8 bf16) for mfma_f32_16x16x32_bf16; dispatches to the GFX950 hardware path (v_cvt_scalef32_pk_bf16_fp4, 4 VALU) or a software fallback

kernels/mixed_moe_gemm_2stage.py (new stage2 kernel, authored by Zan Zhang):
* _decode_e8m0_byte_to_f32: decodes an E8M0 scale byte to f32 = 2^(e-127)
* _barrier: emits s_waitcnt + s_barrier as inline asm, bypassing LLVM's conservative barrier insertion
* compile_a16w4_moe_gemm2: the full stage2 down-projection GEMM. Takes BF16 activations and MXFP4 weights pre-shuffled by shuffle_weight_a16w4/shuffle_scale_a16w4. Uses mfma_f32_16x16x32_bf16 with a ping-pong double-buffered pipeline. Supports atomic accumulation and reduce modes, optional per-row routing weights, and optional bias.

Ported from aiter/aiter/ops/flydsl/kernels/a16w4_moe_gemm_2stage.py with two mechanical adaptations required by FlyDSL conventions:
* idx2crd(...) becomes fx.idx2crd(...), with layout_get(coord, i) becoming fx.get(coord, i) (FlyDSL returns a !fly.int_tuple MLIR value, not a Python tuple)
* cache_modifier dropped from _buffer_load_vec (not supported in FlyDSL; the call site passed cache_modifier=0, i.e. default policy, so there is no correctness or performance impact)

Test Plan
New test test_moe_gemm2_a16w4 in tests/kernels/test_moe_gemm.py:
* Calls compile_a16w4_moe_gemm2 directly with two shapes (small: 16 tokens, medium: 128 tokens)
* Quantises weights with aiter.get_torch_quant(QuantType.per_1x32)
* Pre-shuffles weights and scales with shuffle_weight_a16w4/shuffle_scale_a16w4 from aiter.ops.shuffle
* Compares kernel output against the torch_moe_gemm2 reference (FP32 dequantised computation)
* Requires gfx950+ (MI350/MI355X); skipped on older architectures

Test Result
Both parametrisations pass on MI355X (gfx950).
Submission Checklist