Adds Grouped and Batched GEMM kernels with blockscaling matching DeepGEMM API #433

Open
aryaman-gupta wants to merge 44 commits into main from aryaman/group-gemm

Conversation


@aryaman-gupta aryaman-gupta commented Apr 23, 2026

Summary

Adds two new grouped and batched FP8 GEMM kernels with blockscaling that mirror the DeepGEMM API on AMD CDNA GPUs:

| FlyDSL kernel | DeepGEMM op |
| --- | --- |
| `kernels/grouped_gemm_blockscale_contiguous.py`: `compile_grouped_gemm_blockscale_contiguous` | `m_grouped_fp8_gemm_nt_contiguous` |
| `kernels/grouped_gemm_blockscale_masked.py`: `compile_grouped_gemm_blockscale_masked` | `m_grouped_fp8_gemm_nt_masked` |

These ops are core to MoE inference workloads where the gate/up/down projections of multiple experts are batched into a single grouped GEMM. DeepSeek-V3 is the most prominent example — its expert MLPs use FP8 with per-token activation scaling and per-block weight scaling, exactly the configuration these kernels accept. Adding FlyDSL implementations unblocks running such models on AMD hardware via the same call sites that already use DeepGEMM on NVIDIA.

The Python signatures are designed to match the DeepGEMM ops exactly, so call-site code can switch backends by import alone. That covers tensor shapes, dtypes, the scale_a / scale_b layouts (including the transposed [scale_k, M] activation-scale layout), grouped_layout semantics with -1 for padding, masked_m semantics for the masked variant, and the (1, 128) × (128, 128) block-scale granularity.
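
For intuition, here is a NumPy sketch of the contiguous variant's semantics. It is illustrative only: `grouped_gemm_reference` and its argument names are hypothetical, not the FlyDSL or DeepGEMM API, and it omits FP8 quantization and block scaling entirely.

```python
import numpy as np

def grouped_gemm_reference(a, b, grouped_layout):
    """Illustrative reference for the contiguous grouped GEMM semantics.

    a:              [M_total, K]       activations, rows of all groups concatenated
    b:              [num_groups, N, K] per-expert weights (NT layout: B stored [N, K])
    grouped_layout: [M_total]          group index per row; -1 marks padding rows
    """
    M, K = a.shape
    num_groups, N, _ = b.shape
    out = np.zeros((M, N), dtype=a.dtype)
    for g in range(num_groups):
        rows = grouped_layout == g      # rows belonging to expert g
        out[rows] = a[rows] @ b[g].T    # NT: multiply by B transposed
    return out                          # padding rows (-1) stay zero

a = np.random.rand(6, 4).astype(np.float32)
b = np.random.rand(2, 3, 4).astype(np.float32)
layout = np.array([0, 0, -1, 1, 1, -1])  # two groups, two padding rows
c = grouped_gemm_reference(a, b, layout)
```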

Test plan

  • Unit / correctness: tests/kernels/test_grouped_gemm_blockscale_contiguous.py and ..._masked.py. Coverage spans 1–8 groups, m_per_group from 100 (unaligned) to 1024, plus DeepSeek-V3 shapes (N=2048 K=7168 and N=7168 K=2304) at both bf16 and f16 outputs.
  • 30 / 30 tests pass on MI350 (gfx950) at logits_diff_threshold=1e-3.
  • Tests registered in tests/arch_compat.py:CDNA_ONLY_TESTS so non-CDNA CI auto-skips.
  • pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower] for correct CI bucketing.

Run locally:

```
PYTHONPATH=./ FLYDSL_RUNTIME_ENABLE_CACHE=0 pytest \
  tests/kernels/test_grouped_gemm_blockscale_contiguous.py \
  tests/kernels/test_grouped_gemm_blockscale_masked.py
```

DeepGEMM conformance details

A few things were chosen explicitly to match DeepGEMM's behavior so any divergence in numerical results vs the NVIDIA path would surface as a real bug rather than a tolerance / convention mismatch:

  • Tolerance. The correctness threshold uses the same calc_diff formula as DeepGEMM (cosine-similarity-style normalized error) at logits_diff_threshold=1e-3 — matching DeepGEMM's own FP8 GEMM tests (tests/test_legacy.py:35, tests/test_fp8_fp4.py:194).
  • E8M0 quantization on the host. Scale tensors are rounded with ceil_to_ue8m0 (matching DeepGEMM's deep_gemm/utils/math.py:13). Truncation would shrink the scale and cause FP8 saturation on every block; the ceiling is what keeps x / scale_e8m0 ≤ fp8_max.
  • Reference convention. Tests compute the reference from pre-quantization FP32 inputs (a_f32 @ b_f32.T) — same as DeepGEMM's tests. The reference contains zero quantization error, so the diff measured against it is the actual end-to-end FP8 → ground-truth error budget.
  • Hardware-vs-software scaling, matching DeepGEMM's CUDA / SM strategy. On MI350 (gfx950) the kernel uses the hardware E8M0 path via mfma_scale_f32_16x16x128_f8f6f4, with E8M0 bytes pre-extracted on the host and loaded as uint8 (analogous to DeepGEMM's pack_ue8m0_to_int for the SM100 path). On MI300 (gfx942) where the MFMA-scale instruction is unavailable, scaling is applied in software.
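
A minimal sketch of the first two host-side conventions, assuming `calc_diff` is the cosine-similarity-style formula and `ceil_to_ue8m0` rounds a scale up to the next power of two, as described above (names and formulas per my reading of DeepGEMM, not copied from this PR's code):

```python
import numpy as np

def ceil_to_ue8m0(x):
    """Round a positive FP32 scale UP to the nearest power of two (E8M0 has no
    mantissa). The ceiling guarantees value / scale_e8m0 <= fp8_max; truncating
    down would shrink the scale and saturate the FP8 range on every block."""
    return 2.0 ** np.ceil(np.log2(x))

def calc_diff(x, y):
    """Cosine-similarity-style normalized error used as the correctness metric:
    0 for identical tensors, larger as they diverge."""
    x = x.astype(np.float64)
    y = y.astype(np.float64)
    return 1.0 - 2.0 * (x * y).sum() / ((x * x).sum() + (y * y).sum())

fp8_max = 448.0                        # e4m3 maximum finite value
amax = 300.0                           # example per-block absolute max
scale = ceil_to_ue8m0(amax / fp8_max)  # rounds 0.6696... up to 1.0
assert amax / scale <= fp8_max         # ceiling keeps quantized values in range

truncated = 2.0 ** np.floor(np.log2(amax / fp8_max))  # 0.5
assert amax / truncated > fp8_max      # truncation would saturate
```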

Notes for reviewers

  • Deliberate test-style divergence from test_blockscale_preshuffle_gemm.py. Reference uses pre-quantization FP32 instead of dequant-then-matmul, and tolerance is 1e-3 instead of the repo default 2e-3 — both deliberately chosen to match DeepGEMM's convention as described above.

aryaman-gupta and others added 30 commits March 27, 2026 18:58
Replace hardcoded test calls with argparse-based __main__ matching
the pattern used by other kernel tests (blockscale, moe, preshuffle).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace inline correctness checks and manual CUDA event benchmarking
with shared utilities from tests.test_common, matching the pattern
used by blockscale_preshuffle_gemm and other kernel tests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
tile_m/n/k, out_dtype, num_iters, and num_warmup were parsed but
never passed to the test functions, which hardcoded their own values.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Explicitly create the CPU reference output tensor with device='cpu'
to avoid conflict when torch.set_default_device('cuda') is active.
The reference matmul stays on CPU due to hipBLAS issues on this ROCm.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Split load_a_tile into prefetch_a_tile (Global→VGPR) and
store_a_tile_to_lds (VGPR→LDS). Moves ds_write after compute_tile
to match the MoE blockscale 2-stage pipeline, enabling future
instruction scheduling to interleave ds_write with trailing MFMAs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Adds hot_loop_scheduler() with coarse-grained sched_group_barrier
hints matching the moe_blockscale_2stage pattern. Placed after
store_a_tile_to_lds and before gpu.barrier(), only emitted when
a next tile actually exists (avoids LLVM assertion from mismatched
instruction counts on tail iterations).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Adds optional waves_per_eu parameter to compile functions. When set,
applies rocdl.waves_per_eu attribute to gpu.func for occupancy
tuning. Matches the pattern from blockscale_preshuffle_gemm and
moe_blockscale_2stage kernels.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Contiguous test:
- Generate unaligned M group sizes with -1 padding rows (DeepGEMM convention)
- Add unaligned M test cases (2g-100m, 4g-200m)
- Add DeepSeek-V3 shapes (2112x7168, 7168x2304)
- Add out_dtype parametrization (bf16 + f16)
- Zero out padding rows before comparison
- Add --waves_per_eu CLI arg

Masked test:
- Fix output buffer dtype bug (was hardcoded bf16, now respects out_dtype)
- Add sparse masking test (4g-512max-50m)
- Add DeepSeek-V3 shapes
- Add out_dtype parametrization (bf16 + f16)
- Wire out_dtype through generate_masked_grouped_gemm_inputs
- Add --waves_per_eu CLI arg

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
aryaman-gupta and others added 12 commits April 17, 2026 10:27
…divisibility

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace unity scaling (0x7F7F7F7F) with actual E8M0 hardware scale
application in the mfma_scale_f32_16x16x128_f8f6f4 instruction.
Converts FP32 block scales to E8M0 (exponent extraction via >> 23),
loads per-lane scaleA (varies by lane_mod_16) and uniform scaleB,
and accumulates directly into the running accumulator. Eliminates
software multiply + FMA scale application on gfx950.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
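
The exponent extraction mentioned above (">> 23") can be sketched on the host in Python; this is an illustrative standalone helper, not the kernel code:

```python
import struct

def fp32_to_e8m0(scale: float) -> int:
    """Extract the 8-bit biased exponent of a positive power-of-two FP32 scale.
    Shifting right by 23 drops the mantissa bits; & 0xFF masks off the sign bit.
    The resulting byte is the E8M0 value the MFMA-scale instruction consumes."""
    bits = struct.unpack("<I", struct.pack("<f", scale))[0]
    return (bits >> 23) & 0xFF

assert fp32_to_e8m0(1.0) == 127  # exponent bias is 127, so 2**0 -> 127
assert fp32_to_e8m0(2.0) == 128
assert fp32_to_e8m0(0.5) == 126
```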
Replace raw arith.shrui/arith.andi calls with ArithValue operator
overloading (>> and &) for the FP32-to-E8M0 conversion. The raw
dialect ops require ir.Value operands, but fx.Int32() creates DSL
wrappers. ArithValue operators handle the unwrapping automatically.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Use DeepGEMM's ceil_to_ue8m0 (round scale UP, not truncate) so that
x / scale_e8m0 ≤ fp8_max — truncation caused FP8 saturation and a
systematic per-element bias on every block, failing every config.

Switch reference to FP32-input GEMM (DeepGEMM convention) so the test
measures total kernel + quantization error against ground truth, and
tighten logits_diff threshold to 1e-3 to match DeepGEMM.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
The scale_b address is wave-uniform (no lane dependence), but every lane
was issuing the same buffer_load. Promote the value via rocdl.readfirstlane
so downstream consumers can use it from an SGPR-style broadcast instead of
a per-lane VGPR.

Modest but consistent gain on memory-leaning DS-V3 shapes (N=2048 K=7168
m=256: +15% TFLOPS); ~neutral elsewhere within run-to-run noise. Verified
via ISA: +56 v_readfirstlane instructions, no VGPR change, no occupancy
hit. Correctness 30/30 at 1e-3 logits_diff.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Tests extract the E8M0 byte after fp32_to_e8m0 and return uint8 scale
tensors. Kernel buffer resources sized at 1 byte/scale on gfx950; HW
path does buffer_load(T.i8) + extui-to-i32, dropping the in-kernel
bitcast/shrui/andi extraction.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
The s_a_vecs f32 loads are unused on the gfx950 HW path and would index
out-of-bounds against the int8-sized scale buffer if MLIR DCE ever
failed to eliminate them. Gating makes the gfx950/942 split explicit.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Add `prefetch_scales(k_tile_idx_py)` helper that loads the E8M0 byte for
each (mi, ni) of the next K-tile into VGPRs ahead of `compute_tile`.
Issued before `load_b_tile` in the ping-pong loop so scale-VMEM latency
overlaps the prior tile's MFMAs and the next B-tile load.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
- Register both tests in tests/arch_compat.py:CDNA_ONLY_TESTS so RDNA CI
  auto-skips them.
- Add `pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower]`
  so CI buckets them correctly.
- Add `torch.cuda.is_available()` module-level skip guard.
- Drop dead `--waves_per_eu` argparse arg (was accepted but never
  forwarded).
- Merge per-file `*_correctness` and `*_performance` into a single
  `test_grouped_fp8_gemm` / `test_masked_grouped_fp8_gemm` matching the
  test_blockscale_preshuffle_gemm convention.
- Move the per-group reference matmul from CPU to GPU (hipBLASLt). Test
  suite runtime drops ~70s → ~34s.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
- Remove unused `import os` from both kernel files.
- Remove orphan `# Helper: compute one K-tile from LDS + B tile` banner
  from both kernels (the function it labeled was renamed/refactored).
- Remove duplicate `c_scale_k = fx.Index(scale_k)` reassignment in the
  masked kernel (already in scope from the earlier definition).
- Drop the drift-prone "Optimizations applied:" lists from kernel
  module docstrings; correct the now-stale `scale_a` / `scale_b` dtype
  to reflect uint8 on gfx950 / FP32 on gfx942.
- Simplify the "Per-group matmul" comment in both test files; drop the
  specific backend (hipBLASLt) claim.
- Add missing `device`, `scale_block_k`, `scale_block_n` entries to the
  masked test's `generate_*_inputs` Args docstring.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
- compile_grouped_fp8_gemm           -> compile_grouped_gemm_blockscale_contiguous
- compile_masked_grouped_fp8_gemm    -> compile_grouped_gemm_blockscale_masked

The new names mirror the file names exactly (drop "fp8_gemm",
incorporate "blockscale" + the contiguous/masked variant), making
the call site self-documenting. Internal kernel/launcher symbols and
the JIT cache-key strings are renamed in lockstep. Test imports and
call sites updated. DeepGEMM op references in the docstrings
(`m_grouped_fp8_gemm_nt_contiguous` / `..._masked`) are unchanged —
those are DeepGEMM's actual symbol names.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
@aryaman-gupta aryaman-gupta changed the title Adds Grouped GEMM kernels matching DeepGEMM API Adds Grouped and Batched GEMM kernels matching DeepGEMM API Apr 24, 2026
@aryaman-gupta aryaman-gupta changed the title Adds Grouped and Batched GEMM kernels matching DeepGEMM API Adds Grouped and Batched GEMM kernels with blockscaling matching DeepGEMM API Apr 24, 2026