Adds Grouped and Batched GEMM kernels with blockscaling matching DeepGEMM API #433
Open

aryaman-gupta wants to merge 44 commits into main from
Conversation
Replace hardcoded test calls with argparse-based __main__ matching the pattern used by other kernel tests (blockscale, moe, preshuffle). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace inline correctness checks and manual CUDA event benchmarking with shared utilities from tests.test_common, matching the pattern used by blockscale_preshuffle_gemm and other kernel tests. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
tile_m/n/k, out_dtype, num_iters, and num_warmup were parsed but never passed to the test functions which hardcoded their own values. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Explicitly create the CPU reference output tensor with device='cpu'
to avoid conflict when torch.set_default_device('cuda') is active.
The reference matmul stays on CPU due to hipBLAS issues on this ROCm.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Split load_a_tile into prefetch_a_tile (Global→VGPR) and store_a_tile_to_lds (VGPR→LDS). Moves ds_write after compute_tile to match the MoE blockscale 2-stage pipeline, enabling future instruction scheduling to interleave ds_write with trailing MFMAs. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Adds hot_loop_scheduler() with coarse-grained sched_group_barrier hints matching the moe_blockscale_2stage pattern. Placed after store_a_tile_to_lds and before gpu.barrier(), only emitted when a next tile actually exists (avoids LLVM assertion from mismatched instruction counts on tail iterations). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Adds optional waves_per_eu parameter to compile functions. When set, applies rocdl.waves_per_eu attribute to gpu.func for occupancy tuning. Matches the pattern from blockscale_preshuffle_gemm and moe_blockscale_2stage kernels. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Contiguous test:
- Generate unaligned M group sizes with -1 padding rows (DeepGEMM convention)
- Add unaligned M test cases (2g-100m, 4g-200m)
- Add DeepSeek-V3 shapes (2112x7168, 7168x2304)
- Add out_dtype parametrization (bf16 + f16)
- Zero out padding rows before comparison
- Add --waves_per_eu CLI arg

Masked test:
- Fix output buffer dtype bug (was hardcoded bf16, now respects out_dtype)
- Add sparse masking test (4g-512max-50m)
- Add DeepSeek-V3 shapes
- Add out_dtype parametrization (bf16 + f16)
- Wire out_dtype through generate_masked_grouped_gemm_inputs
- Add --waves_per_eu CLI arg

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…divisibility Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace unity scaling (0x7F7F7F7F) with actual E8M0 hardware scale application in the mfma_scale_f32_16x16x128_f8f6f4 instruction. Converts FP32 block scales to E8M0 (exponent extraction via >> 23), loads per-lane scaleA (varies by lane_mod_16) and uniform scaleB, and accumulates directly into the running accumulator. Eliminates software multiply + FMA scale application on gfx950. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace raw arith.shrui/arith.andi calls with ArithValue operator overloading (>> and &) for the FP32-to-E8M0 conversion. The raw dialect ops require ir.Value operands, but fx.Int32() creates DSL wrappers. ArithValue operators handle the unwrapping automatically. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Use DeepGEMM's ceil_to_ue8m0 (round scale UP, not truncate) so that x / scale_e8m0 ≤ fp8_max — truncation caused FP8 saturation and a systematic per-element bias on every block, failing every config. Switch reference to FP32-input GEMM (DeepGEMM convention) so the test measures total kernel + quantization error against ground truth, and tighten logits_diff threshold to 1e-3 to match DeepGEMM. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
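The round-up behavior the commit relies on can be sketched in plain Python (a hedged scalar reconstruction, not DeepGEMM's actual code, which operates on torch tensors):

```python
import math

def ceil_to_ue8m0(x: float) -> float:
    """Round a positive scale UP to the nearest power of two (the UE8M0 grid).

    Rounding up guarantees block_max / scale <= fp8_max, so quantized values
    stay representable. A truncating (round-down) scale would shrink the
    divisor, saturate FP8 on every block, and bias every element.
    """
    assert x > 0.0, "scales are strictly positive"
    return 2.0 ** math.ceil(math.log2(x))
```

For example, a block-max-derived scale of 3.0 rounds up to 4.0 rather than truncating down to 2.0.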
The scale_b address is wave-uniform (no lane dependence), but every lane was issuing the same buffer_load. Promote the value via rocdl.readfirstlane so downstream consumers can use it from an SGPR-style broadcast instead of a per-lane VGPR. Modest but consistent gain on memory-leaning DS-V3 shapes (N=2048 K=7168 m=256: +15% TFLOPS); ~neutral elsewhere within run-to-run noise. Verified via ISA: +56 v_readfirstlane instructions, no VGPR change, no occupancy hit. Correctness 30/30 at 1e-3 logits_diff. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Tests extract the E8M0 byte after fp32_to_e8m0 and return uint8 scale tensors. Kernel buffer resources sized at 1 byte/scale on gfx950; HW path does buffer_load(T.i8) + extui-to-i32, dropping the in-kernel bitcast/shrui/andi extraction. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
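The host-side byte extraction described here can be sketched as follows (helper name is illustrative; the tests perform the equivalent bit manipulation on torch tensors):

```python
import struct

def fp32_to_e8m0_byte(scale: float) -> int:
    """Extract the E8M0 byte from an FP32 scale that is an exact power of two.

    For a power of two, the IEEE-754 mantissa field is zero, so the E8M0
    encoding is simply the 8-bit biased exponent: bits [30:23] of the FP32
    bit pattern, i.e. (bits >> 23) & 0xFF.
    """
    bits = struct.unpack("<I", struct.pack("<f", scale))[0]
    assert bits & 0x007FFFFF == 0, "scale must be an exact power of two"
    return (bits >> 23) & 0xFF
```

Pre-extracting this byte on the host is what lets the kernel buffer hold 1 byte per scale and do a plain buffer_load(T.i8) + extui instead of the in-kernel bitcast/shrui/andi sequence.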
The s_a_vecs f32 loads are unused on the gfx950 HW path and would index out-of-bounds against the int8-sized scale buffer if MLIR DCE ever failed to eliminate them. Gating makes the gfx950/942 split explicit. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Add `prefetch_scales(k_tile_idx_py)` helper that loads the E8M0 byte for each (mi, ni) of the next K-tile into VGPRs ahead of `compute_tile`. Issued before `load_b_tile` in the ping-pong loop so scale-VMEM latency overlaps the prior tile's MFMAs and the next B-tile load. Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
- Register both tests in tests/arch_compat.py:CDNA_ONLY_TESTS so RDNA CI auto-skips them.
- Add `pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower]` so CI buckets them correctly.
- Add `torch.cuda.is_available()` module-level skip guard.
- Drop dead `--waves_per_eu` argparse arg (was accepted but never forwarded).
- Merge per-file `*_correctness` and `*_performance` into a single `test_grouped_fp8_gemm` / `test_masked_grouped_fp8_gemm` matching the test_blockscale_preshuffle_gemm convention.
- Move the per-group reference matmul from CPU to GPU (hipBLASLt). Test suite runtime drops ~70s → ~34s.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
- Remove unused `import os` from both kernel files.
- Remove orphan `# Helper: compute one K-tile from LDS + B tile` banner from both kernels (the function it labeled was renamed/refactored).
- Remove duplicate `c_scale_k = fx.Index(scale_k)` reassignment in the masked kernel (already in scope from the earlier definition).
- Drop the drift-prone "Optimizations applied:" lists from kernel module docstrings; correct the now-stale `scale_a` / `scale_b` dtype to reflect uint8 on gfx950 / FP32 on gfx942.
- Simplify the "Per-group matmul" comment in both test files; drop the specific backend (hipBLASLt) claim.
- Add missing `device`, `scale_block_k`, `scale_block_n` entries to the masked test's `generate_*_inputs` Args docstring.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
- compile_grouped_fp8_gemm -> compile_grouped_gemm_blockscale_contiguous
- compile_masked_grouped_fp8_gemm -> compile_grouped_gemm_blockscale_masked

The new names mirror the file names exactly (drop "fp8_gemm", incorporate "blockscale" + the contiguous/masked variant), making the call sites self-documenting. Internal kernel/launcher symbols and the JIT cache-key strings are renamed in lockstep. Test imports and call sites updated. DeepGEMM op references in the docstrings (`m_grouped_fp8_gemm_nt_contiguous` / `..._masked`) are unchanged; those are DeepGEMM's actual symbol names.

Co-Authored-By: Claude Opus 4 (1M context) <noreply@anthropic.com>
Summary
Adds two new grouped and batched FP8 GEMM kernels with blockscaling that mirror the DeepGEMM API on AMD CDNA GPUs:
- kernels/grouped_gemm_blockscale_contiguous.py → compile_grouped_gemm_blockscale_contiguous (mirrors DeepGEMM's m_grouped_fp8_gemm_nt_contiguous)
- kernels/grouped_gemm_blockscale_masked.py → compile_grouped_gemm_blockscale_masked (mirrors DeepGEMM's m_grouped_fp8_gemm_nt_masked)

These ops are core to MoE inference workloads, where the gate/up/down projections of multiple experts are batched into a single grouped GEMM. DeepSeek-V3 is the most prominent example: its expert MLPs use FP8 with per-token activation scaling and per-block weight scaling, exactly the configuration these kernels accept. Adding FlyDSL implementations unblocks running such models on AMD hardware via the same call sites that already use DeepGEMM on NVIDIA.
The Python signatures (tensor shapes, dtypes, scale_a/scale_b layouts including the transposed [scale_k, M] activation-scale layout, grouped_layout semantics with -1 for padding, masked_m semantics for the masked variant, (1, 128) × (128, 128) block-scale granularity) are designed to match the DeepGEMM ops byte-for-byte, so call-site code can switch backends by import alone.

Test plan
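The scale-tensor shape arithmetic implied by that layout can be sketched as follows (the helper name and the exact scale_b orientation are assumptions for illustration, not symbols from this PR):

```python
def blockscale_shapes(m: int, n: int, k: int, block: int = 128):
    """Per-group scale-tensor shapes under (1, 128) activation granularity
    and (128, 128) weight granularity.

    scale_a is stored transposed as [scale_k, M] (the DeepGEMM activation
    layout); scale_b carries one FP32/E8M0 scale per 128x128 weight block.
    """
    scale_k = -(-k // block)                 # ceil(k / 128)
    scale_a_shape = (scale_k, m)             # transposed [scale_k, M]
    scale_b_shape = (-(-n // block), scale_k)
    return scale_a_shape, scale_b_shape
```

For the DeepSeek-V3 shape N=2048, K=7168 at m=256 this gives scale_a of shape (56, 256) and scale_b of shape (16, 56).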
- tests/kernels/test_grouped_gemm_blockscale_contiguous.py and ..._masked.py. Coverage spans 1–8 groups, m_per_group from 100 (unaligned) to 1024, plus DeepSeek-V3 shapes (N=2048 K=7168 and N=7168 K=2304) at both bf16 and f16 outputs.
- Correctness checked at logits_diff_threshold=1e-3.
- Registered in tests/arch_compat.py:CDNA_ONLY_TESTS so non-CDNA CI auto-skips.
- pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower] for correct CI bucketing.

Run locally:
DeepGEMM conformance details
A few things were chosen explicitly to match DeepGEMM's behavior so any divergence in numerical results vs the NVIDIA path would surface as a real bug rather than a tolerance / convention mismatch:
- Same calc_diff formula as DeepGEMM (cosine-similarity-style normalized error) at logits_diff_threshold=1e-3, matching DeepGEMM's own FP8 GEMM tests (tests/test_legacy.py:35, tests/test_fp8_fp4.py:194).
- Scales are rounded up with ceil_to_ue8m0 (matching DeepGEMM's deep_gemm/utils/math.py:13). Truncation would shrink the scale and cause FP8 saturation on every block; the ceiling is what keeps x / scale_e8m0 ≤ fp8_max.
- The reference GEMM runs on the pre-quantization FP32 inputs (a_f32 @ b_f32.T), same as DeepGEMM's tests. The reference contains zero quantization error, so the diff measured against it is the actual end-to-end FP8 → ground-truth error budget.
- On gfx950, scaling is applied in hardware via mfma_scale_f32_16x16x128_f8f6f4, with E8M0 bytes pre-extracted on the host and loaded as uint8 (analogous to DeepGEMM's pack_ue8m0_to_int for the SM100 path). On MI300 (gfx942), where the MFMA-scale instruction is unavailable, scaling is applied in software.

Notes for reviewers
Test structure follows test_blockscale_preshuffle_gemm.py. The reference uses pre-quantization FP32 instead of dequant-then-matmul, and the tolerance is 1e-3 instead of the repo default 2e-3; both are deliberately chosen to match DeepGEMM's convention as described above.
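For reference, the calc_diff criterion used throughout can be sketched over plain Python lists (the actual tests operate on torch tensors, but the formula is the same):

```python
def calc_diff(x, y):
    """Cosine-similarity-style normalized error: 1 - 2<x,y> / (<x,x> + <y,y>).

    Returns 0.0 for identical inputs and grows toward 1.0 as the tensors
    decorrelate; the tests assert it stays below the 1e-3 threshold.
    """
    xy = sum(a * b for a, b in zip(x, y))
    xx = sum(a * a for a in x)
    yy = sum(b * b for b in y)
    return 1.0 - 2.0 * xy / (xx + yy)
```

Unlike a max-abs-error check, this metric normalizes by the magnitudes of both tensors, so it is stable across output scales (hence the "logits_diff" name).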