
Add A16W4 MoE GEMM stage2 kernel (BF16 activations x MXFP4 weights) #431

Draft
apicciau wants to merge 5 commits into main from apicciau/a16w4-moe-gemm2

Conversation

@apicciau

Motivation

This PR ports the A16W4 MoE GEMM stage2 kernel from aiter into FlyDSL. The kernel targets BF16 activations with MXFP4 (FP4 E2M1, per-1x32 block scale) weights — the down-projection GEMM in a Mixture-of-Experts FFN block.

Performance on GPT-OSS shapes (model_dim=3072, inter_dim=3072, E=128, topk=4, MI355X) shows stage2 consistently outperforming CK Tile across all token counts:

| tokens | FlyDSL (µs) | CK Tile (µs) | fly/CK |
|-------:|------------:|-------------:|-------:|
| 1      | 17.37       | 18.47        | 0.940x |
| 2      | 18.21       | 19.29        | 0.944x |
| 4      | 21.15       | 23.63        | 0.895x |
| 8      | 30.92       | 31.11        | 0.994x |
| 16     | 53.04       | 55.36        | 0.958x |
| 32     | 97.16       | 98.71        | 0.984x |
| 64     | 97.57       | 98.63        | 0.989x |
| 128    | 98.96       | 101.34       | 0.977x |

A fly/CK ratio below 1.0 means FlyDSL is faster. Numbers are from Zan Zhang's original aiter implementation, which is functionally identical to this port.

The stage2 kernel is landed first as a self-contained unit. A follow-up PR will add the stage1 kernel (compile_a16w4_moe_gemm1).

Technical Details

Three existing files are modified — no new files are added.

kernels/mfma_preshuffle_pipeline.py — new load/unpack helpers for the kpack=16 MXFP4 preshuffle format (authored by Zan Zhang):

  • load_b_raw_mxfp4: loads 4 bytes (8 FP4 nibbles) for one ku step using shuffle_weight_a16w4 addressing
  • load_b_raw_mxfp4_dwordx4: dwordx4 variant loading the full 16-byte kpack in a single buffer load
  • unpack_b_mxfp4_bf16: unpacks 8 FP4 E2M1 nibbles to two i64 values (8 bf16) for mfma_f32_16x16x32_bf16; dispatches to the GFX950 hardware path (v_cvt_scalef32_pk_bf16_fp4, 4 VALU) or a software fallback (a decode sketch follows this list)
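
For intuition, the E2M1 decode done by the software fallback can be expressed on the host with a small lookup table. This is only a minimal PyTorch sketch of the number format, not the in-kernel register code; the low-nibble-first packing order is an assumption:

```python
import torch

# Positive magnitudes of FP4 E2M1 (1 sign / 2 exponent / 1 mantissa bit),
# indexed by the low 3 bits of a nibble; bit 3 is the sign.
_E2M1_LUT = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def unpack_fp4_e2m1_to_bf16(packed: torch.Tensor) -> torch.Tensor:
    """Decode uint8 bytes holding two FP4 values each into bf16,
    assuming low nibble first; the last dim doubles in length."""
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    nib = torch.stack((lo, hi), dim=-1).flatten(-2)    # interleave nibbles
    mag = _E2M1_LUT[(nib & 0x7).long()]                # 3-bit magnitude index
    sign = torch.where((nib & 0x8) != 0, -1.0, 1.0)    # bit 3 is the sign
    return (sign * mag).to(torch.bfloat16)
```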

kernels/mixed_moe_gemm_2stage.py — new stage2 kernel (authored by Zan Zhang):

  • _decode_e8m0_byte_to_f32: decodes an E8M0 scale byte to f32 = 2^(e-127) (a bit-trick sketch follows this list)
  • _barrier: emits s_waitcnt + s_barrier as inline asm, bypassing LLVM's conservative barrier insertion
  • compile_a16w4_moe_gemm2: the full stage2 down-projection GEMM. Takes BF16 activations and MXFP4 weights pre-shuffled by shuffle_weight_a16w4/shuffle_scale_a16w4. Uses mfma_f32_16x16x32_bf16 with a ping-pong double-buffered pipeline. Supports atomic accumulation and reduce modes, optional per-row routing weights, and optional bias.
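
The E8M0 decode is a single bit shift: placing the scale byte in the f32 exponent field yields 2^(e-127) directly. A host-side PyTorch sketch of the same trick (the kernel does the equivalent on registers):

```python
import torch

def decode_e8m0_to_f32(scale_bytes: torch.Tensor) -> torch.Tensor:
    """f32 = 2^(e-127): shift the E8M0 byte into the f32 exponent field
    (bits 23..30) and reinterpret the bits as float32.
    Caveat: the plain shift maps e=0 to 0.0 (not 2^-127) and e=255 to +inf
    (E8M0 reserves 0xFF for NaN); in-range scales avoid both extremes."""
    return (scale_bytes.to(torch.int32) << 23).view(torch.float32)
```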

Ported from aiter/aiter/ops/flydsl/kernels/a16w4_moe_gemm_2stage.py with two mechanical adaptations required by FlyDSL conventions:

  • idx2crd(...) → fx.idx2crd(...), and layout_get(coord, i) → fx.get(coord, i) (fx.idx2crd returns a !fly.int_tuple MLIR value, not a Python tuple)
  • cache_modifier dropped from _buffer_load_vec (not supported in FlyDSL; the call site passed cache_modifier=0, i.e. default policy, so there is no correctness or performance impact)

Test Plan

New test test_moe_gemm2_a16w4 in tests/kernels/test_moe_gemm.py:

  • Calls compile_a16w4_moe_gemm2 directly with two shapes (small: 16 tokens, medium: 128 tokens)
  • Generates BF16 activations and MXFP4 weights quantised with aiter.get_torch_quant(QuantType.per_1x32)
  • Pre-shuffles weights and scales via shuffle_weight_a16w4 / shuffle_scale_a16w4 from aiter.ops.shuffle
  • Compares kernel output against the torch_moe_gemm2 reference (FP32 dequantised computation; a sketch of the idea follows this list)
  • Gated on gfx950+ (MI350/MI355X); skipped on older architectures
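
For intuition, a dequantised stage2 reference computes, per token, a routing-weighted sum of down-projections over the selected experts. A minimal sketch with hypothetical names and layouts — not the actual torch_moe_gemm2 signature:

```python
import torch

def moe_gemm2_ref(x, w, topk_ids, topk_weights):
    """FP32 reference for the stage2 (down-projection) MoE GEMM.
    x:            [tokens, topk, inter_dim]   per-slot activations
    w:            [E, model_dim, inter_dim]   dequantised expert weights
    topk_ids:     [tokens, topk]              expert index per slot
    topk_weights: [tokens, topk]              per-row routing weights
    Returns [tokens, model_dim], summed over the topk slots."""
    tokens, topk, _ = x.shape
    out = torch.zeros(tokens, w.shape[1], dtype=torch.float32)
    for t in range(tokens):
        for k in range(topk):
            e = topk_ids[t, k]
            out[t] += topk_weights[t, k] * (w[e].float() @ x[t, k].float())
    return out
```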

Test Result

Both parametrisations pass on MI355X (gfx950):

tests/kernels/test_moe_gemm.py::test_moe_gemm2_a16w4[a16w4-s2-small]  PASSED
tests/kernels/test_moe_gemm.py::test_moe_gemm2_a16w4[a16w4-s2-medium] PASSED

Submission Checklist

Zzz9990 and others added 5 commits April 23, 2026 10:54
* Add _cvt_scalef32_pk_bf16_fp4: GFX950 hardware path for converting 2 FP4 E2M1 nibbles to 2 bf16 via v_cvt_scalef32_pk_bf16_fp4 (1 VALU instruction vs ~36 in the software path).
* Add _fp4x4_in_i32_to_bf16x4_i64: software fallback converting 4 FP4 nibbles (packed in 4 bytes of i32) to 4 bf16 packed as i64.
* Add load_b_raw_mxfp4: loads 4 bytes (8 FP4 nibbles) from a kpack=16 preshuffle layout (shuffle_weight_a16w4 format) using ku-based k0/klane addressing.
* Add load_b_raw_mxfp4_dwordx4: dwordx4 variant loading the full 16-byte kpack for one sub-lane in a single buffer_load.
* Add unpack_b_mxfp4_bf16: dispatches to hw or sw path, returning (b0, b1) i64 pair for mfma_f32_16x16x32_bf16.
* Add _decode_e8m0_byte_to_f32: converts an E8M0 byte (i8) to f32 = 2^(e-127) via bit shift into position 23.
* Add _barrier: emits s_waitcnt + s_barrier as inline asm, bypassing LLVM SIInsertWaitcnts conservative insertion.
* Add compile_a16w4_moe_gemm2: stage2 down-projection GEMM for BF16 activations x MXFP4 (FP4 E2M1) weights with E8M0 block scales, using mfma_f32_16x16x32_bf16. Ported from aiter a16w4_moe_gemm_2stage.py with mechanical adaptations (fx.idx2crd, tuple indexing in place of layout_get).
* Add test_moe_gemm2_a16w4: exercises compile_a16w4_moe_gemm2 with BF16 activations and MXFP4 E2M1 weights pre-shuffled via shuffle_weight_a16w4/shuffle_scale_a16w4. Compares kernel output against torch_moe_gemm2 reference. Gated on gfx950+.

* Fix IntTuple component extraction: layout_get(coord, i) must map to fx.get(coord, i), not coord[i]. FlyDSL's fx.idx2crd returns a !fly.int_tuple MLIR value; Python indexing returns another IntTuple, not a scalar index.
* Drop cache_modifier from load_b_raw_mxfp4_dwordx4: FlyDSL's _buffer_load_vec does not support this parameter. The argument is an optional cache-policy hint with no correctness impact.
* Use inter_dim=3072 shapes so inter_dim/tile_k >= 2 (ping-pong pipeline requires at least 2 K-tile iterations).
* View float4 weight and scale tensors as uint8 before passing to the kernel (DLPack does not support float4_e2m1fn_x2 or float8_e8m0fnu; see the sketch below).
* Remove test_name kwarg from verify_output call (not part of its signature).
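
The dtype workaround above is a zero-copy reinterpret on the host side; a minimal sketch, with the helper name being hypothetical:

```python
import torch

def as_kernel_bytes(t: torch.Tensor) -> torch.Tensor:
    """Reinterpret a packed float4_e2m1fn_x2 or float8_e8m0fnu tensor as
    uint8. Both dtypes are one byte per element, so .view() is a zero-copy
    bit reinterpretation that DLPack can then describe."""
    return t.view(torch.uint8)
```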
@apicciau apicciau self-assigned this Apr 23, 2026
