
Adds Grouped GEMM kernels matching DeepGEMM API #412

Merged
aryaman-gupta merged 29 commits into aryaman/group-gemm from main
Apr 23, 2026

Conversation

@aryaman-gupta

No description provided.

coderfeli and others added 19 commits April 10, 2026 18:04
* [FIX] Support AOT cross-compilation with COMPILE_ONLY cache save
* [FIX] Simplify aot_example.py precompile path

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
---------

Co-authored-by: root <root@smci355-ccs-aus-n11-05.cs-aus.dcgpu>
Co-authored-by: Claude Opus 4 <noreply@anthropic.com>
Co-authored-by: Felix Li <felix.li@amd.com>

* Add get_leaves op and support converting dyn_tuple to py_tuple

* fix comments
* enhance get_scalar op so it only requires dyn_leaf_cnt = 1

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---------

Co-authored-by: yashao@amd.com <yashao@amd.com@tus1-p3-g24.cluster.local>
Co-authored-by: Felix Li <felix.li@amd.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* [OPT] Add pass: convert-atom-call-to-ssa-form

* update mlir tests
* [OPT] Add pass promote-regmem-to-vectorssa

* fix comments

* fix empty yield case
…404)

* Add fused epilogue support to preshuffle GEMM: bias + ReLU/SiLU/GeLU

Modified body_row to apply bias + activation in registers before
the output store, eliminating separate epilogue kernel launches.

MI300X results: 2.73x avg speedup vs hipBLAS+bias+SiLU
- O-proj: 3.25x, MoE-dn: 3.33x, QKV: 2.92x
- Zero epilogue overhead (fused ops hidden by store latency)

New parameter: epilogue='none'|'bias'|'bias_relu'|'bias_silu'|'bias_gelu'
New kernel arg: arg_bias (N-element bias tensor)
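In reference terms, the five epilogue modes compute the following (a NumPy sketch of the semantics, not the kernel's register-level code; the function name and the tanh-approx GeLU constant are illustrative):

```python
import numpy as np

def epilogue_reference(acc, bias, mode="none"):
    """Reference semantics of the fused epilogue: bias add followed by
    an optional activation, applied to the fp32 accumulator before the
    output store. Illustrative only, not the FlyDSL kernel body."""
    if mode == "none":
        return acc
    y = acc + bias  # bias is an N-element vector, broadcast along rows
    if mode == "bias":
        return y
    if mode == "bias_relu":
        return np.maximum(y, 0.0)
    if mode == "bias_silu":
        return y / (1.0 + np.exp(-y))
    if mode == "bias_gelu":  # tanh approximation, sqrt(2/pi) constant
        return 0.5 * y * (1.0 + np.tanh(0.7978845608 * (y + 0.044715 * y**3)))
    raise ValueError(f"unknown epilogue mode: {mode}")

acc = np.random.randn(4, 8).astype(np.float32)
bias = np.random.randn(8).astype(np.float32)
out = epilogue_reference(acc, bias, "bias_relu")
```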

* Fix test: pass dummy bias tensor for fused epilogue kernel signature

The fused epilogue commit added arg_bias to kernel_gemm and launch_gemm
unconditionally (needed for epilogue='none' to maintain a single kernel
definition). The test's _gemm_args and _w4_args functions need to pass
a dummy bias tensor to match the updated launch_gemm signature.

Without this fix, args shift: M goes to arg_bias slot, N to i32_m,
stream to i32_n, causing 'missing a required argument: stream' error.
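The argument shift can be reproduced with a toy signature (the names below are illustrative stand-ins for the real launch_gemm, which this sketch does not import):

```python
# Hypothetical stand-in for the updated launch_gemm signature after
# arg_bias was inserted; the real function lives in the kernel module.
def launch_gemm(c, a, b, arg_bias, i32_m, i32_n, stream):
    return (arg_bias, i32_m, i32_n, stream)

c = a = b = object()
bias, M, N, s = "bias", 64, 128, "stream0"

# A call site written for the old signature without arg_bias: every
# later positional argument slides one slot left (M -> arg_bias,
# N -> i32_m, stream -> i32_n) and 'stream' itself goes unfilled.
try:
    launch_gemm(c, a, b, M, N, s)
except TypeError as err:
    msg = str(err)  # "...missing 1 required positional argument: 'stream'"

# Passing a (dummy) bias restores the alignment.
assert launch_gemm(c, a, b, bias, M, N, s) == (bias, M, N, s)
```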

* Add CUDAGraph capture test for preshuffle GEMM

Verifies that FlyDSL kernels are correctly captured by torch.cuda.CUDAGraph
when torch.cuda.current_stream() is passed as the stream argument.

Test flow:
1. Regular execution → reference result
2. CUDAGraph capture on a dedicated stream
3. Graph replay → verify result matches reference
4. Assert non-zero output (kernel was captured)

Tests both BF16 and FP8 paths.

* Fix CI test failures: torch.dtype kwarg + pertoken_quant kwarg name

Two test-only fixes in tests/kernels/test_preshuffle_gemm.py:

1. test_mfma_a8_flyc_preshuffle (line 206): torch.empty was passed
   the string out_dtype ("bf16"/"fp16") instead of a torch.dtype.
   Use torch_out_dtype (defined on line 116) which is the actual
   torch.dtype, matching c_out_raw on line 190.
   Fixes 192 failing parametrizations across mi325-1/mi355-1 runners.

2. test_cudagraph_capture_preshuffle[fp8] (lines 500-501): wrong
   kwarg name for pertoken_quant -- uses 'dtype=' but the function
   signature in tests/utils.py expects 'quant_dtype='. Switch to
   the correct kwarg.

No production code touched. CI for both mi325 and mi355 should
return to green.

* Address review: epilogue+cshuffle guard, bias max_size, SiLU rcp, fused tests

Four fixes for review comments on PR #404:

1. Reject epilogue!='none' + use_cshuffle_epilog=True (correctness bug).
   The cshuffle path returns from write_row_to_lds before body_row, so the
   bias/activation fusion would silently be dropped. Now raises ValueError
   at compile time.

2. Drop hardcoded 'c_n * 2' for the bias buffer size; use max_size=True
   like the other resources. The previous code assumed 2-byte output and
   would break on any future fp32 path.

3. Rewrite SiLU as 'val * (1/denom)' instead of 'val / denom' so the
   compiler lowers to v_rcp_f32 + v_mul_f32 (~4x faster than v_div_* on
   AMD GPUs).

4. Add fused-epilogue correctness tests:
   - test_fused_epilogue_correctness parametrized over bias / bias_relu /
     bias_silu / bias_gelu, comparing against a torch reference.
   - test_fused_epilogue_rejects_cshuffle covering the new guard from #1.
   Previously every test ran with epilogue='none' + dummy bias, so the
   actual fusion path had no coverage.
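The reciprocal rewrite from fix 3 can be sanity-checked numerically (a NumPy sketch; the real transformation happens at IR lowering, and the hardware v_rcp_f32 is an approximate op, unlike the exact division below):

```python
import numpy as np

x = np.linspace(-10, 10, 1001, dtype=np.float32)
denom = 1.0 + np.exp(-x)

silu_div = x / denom           # 'val / denom' -> lowers to slow v_div_* sequence
silu_rcp = x * (1.0 / denom)   # 'val * (1/denom)' -> v_rcp_f32 + v_mul_f32

# In exact fp32 arithmetic the two forms agree to within an ulp or so;
# on the GPU the rcp path trades a little precision for throughput.
assert np.allclose(silu_div, silu_rcp, rtol=1e-6)
```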

* Fix latent bias_relu/bias_gelu codegen bugs caught by new epilogue tests

When running the new test_fused_epilogue_correctness tests on MI300X
(gfx942) two pre-existing latent bugs in body_row's epilogue path
showed up. Neither was reachable before because no test ever exercised
epilogue != 'none'.

1. bias_relu: arith.cmpf was called with the string predicate 'ogt',
   which the underlying MLIR binding rejects (it expects an integer
   CmpFPredicate enum value). Replaced cmpf+select with arith.maximumf,
   which is both correct and more concise.

2. bias_gelu: math.tanh has no AMD libcall ('no libcall available for
   ftanh'), so any kernel using the tanh-approx GeLU failed to lower.
   Rewrote tanh in terms of math.exp using a numerically stable form
   that only ever evaluates exp(-2|y|) (in [0, 1]), so we never overflow
   fp32 even for large activations. The (1 + tanh(y)) factor used by
   GeLU is then formed branchlessly via cmpf+select on the sign of y.
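The overflow-safe expansion described in point 2 can be written out as follows (a NumPy sketch of the identity; the actual codegen emits MLIR math.exp plus cmpf+select, not NumPy calls):

```python
import numpy as np

def tanh_via_exp(y):
    """tanh(y) expressed using only exp(-2|y|), whose argument is always
    <= 0 and whose value lies in [0, 1] -- so fp32 never overflows even
    for very large |y|. The sign is restored branchlessly at the end."""
    e = np.exp(-2.0 * np.abs(y))
    t = (1.0 - e) / (1.0 + e)        # tanh(|y|)
    return np.where(y >= 0, t, -t)   # select on the sign of y

y = np.array([-100.0, -3.0, 0.0, 0.5, 3.0, 100.0], dtype=np.float32)
assert np.allclose(tanh_via_exp(y), np.tanh(y), atol=1e-6)
# A naive exp(2y)-based form would overflow fp32 here; this one saturates.
assert np.isfinite(tanh_via_exp(np.float32(1e4)))
```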

Verified on MI300X: 4/4 fused epilogue correctness tests pass, the
guard test passes, and all 103 pre-existing tests still pass.

* Tighten epilogue tests + remove dead code in GeLU rewrite

- Remove unused is_pos/is_neg lines and stale comments in the GeLU
  branchless tanh expansion (the cmpf+select on the sign of y is the
  only thing actually used).
- Tighten test_fused_epilogue_correctness: explicit NaN/Inf assertions
  before the value comparison, and document the bf16 tolerance choice
  (atol=2.0, rtol=0.05) based on K=8192 reduction error.

Verified: 103 passed, 20 skipped on gfx942.

---------

Co-authored-by: Andy Luo <andyluo7@users.noreply.github.com>
Co-authored-by: Felix Li <felix.li@amd.com>
* [ROCDL] Add CDNA4_MFMAScaleType
…h acc_res (#419)

- Raise regression thresholds to 10us absolute / 15% relative
- Replace the aggregate 'error' field (always NaN) with 'acc_res' showing pass/fail

Co-authored-by: root <root@tus1-p3-g24.cluster.local>
* [Agent] New skill: add-target-atom-op
* fix comments
* [OPT] Add fly-int-swizzle-simplify pass

Add an algebraic simplification pass that recognizes the canonical
three-instruction swizzle sequence (andi/shrui/xori) emitted by
`applySwizzle` and peels period-aligned addends out of the swizzle:

    swizzle(base + d)  →  swizzle(base) + d   when d % period == 0
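The rewrite can be checked on a toy xor-swizzle (the shift/mask values below are illustrative, not the specific constants the pass matches):

```python
SHIFT, MASK = 4, 0x7        # illustrative: xor bits 4..6 into bits 0..2
# Bits at or above SHIFT + 3 never feed the xor, so the swizzle is
# periodic with period 2^(SHIFT + 3) = 128.
PERIOD = 1 << (SHIFT + 3)

def swizzle(x):
    # canonical andi/shrui/xori sequence: x ^ ((x >> SHIFT) & MASK)
    return x ^ ((x >> SHIFT) & MASK)

# Period-aligned addends peel out of the swizzle: adding a multiple of
# PERIOD changes no bit that the and/shr/xor sequence reads or writes.
for base in range(PERIOD):
    for d in (0, PERIOD, 3 * PERIOD):
        assert swizzle(base + d) == swizzle(base) + d

# Non-aligned addends do not peel, so the guard d % period == 0 matters.
assert swizzle(15 + 1) != swizzle(15) + 1
```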
---------

Co-authored-by: Claude Opus 4 <noreply@anthropic.com>
lalala-sh and others added 2 commits April 23, 2026 22:46
* [Perf] Port aiter mixed_moe kernel optimizations for stage1/stage2

Port performance-critical optimizations from aiter's mixed_moe_gemm_2stage
kernel body (both stage1 and stage2) into FlyDSL, along with supporting
infrastructure changes.

Key changes:
- mixed_moe_gemm_2stage.py: Full kernel body replacement with aiter version
  featuring dual SmemAllocator (ping-pong), unified MFMA pipeline schedule,
  _barrier() for fine-grained waitcnt control, and new parameters (persist_m,
  fuse_fp4_quant, fuse_sort_scale, use_async_copy, sort_block_m, etc.)
- layout_utils.py: New file ported from aiter for layout index arithmetic
  (crd2idx, idx2crd, _div_pow2, _mod_pow2)
- silu_and_mul_fq.py: New file ported from aiter for split-K + fp4 quant
  after silu fusion
- mfma_preshuffle_pipeline.py: Added k_major support, cache_modifier param,
  bitwise-AND optimization in swizzle_xor16, PreshuffleScaleLayout additions
- kernels_common.py: Extracted shared _if_then context manager and
  validate_moe_dtypes helper
- mfma_epilogues.py: Replaced local _if_then with shared import

Performance (DeepSeek TP8 FP4, 7168x256, E=256, K=8):
- Stage1 Decode t=1: 37.3 -> 26.2 us (-29.8%)
- Stage1 Decode t=8: 45.0 -> 31.0 us (-31.1%)
- Stage1 Prefill 8K: 561.8 -> 348.8 us (-37.9%)
- Stage2 Prefill 8K reduce: 569.1 -> 534.8 us (-6.0%)
- FP8 stage2 unchanged (within noise)
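For the layout_utils.py helpers, the core idea of crd2idx/idx2crd is mapping between multi-dimensional coordinates and linear offsets under a shape/stride layout. A minimal sketch of the assumed semantics (the actual aiter signatures may differ, and _div_pow2/_mod_pow2 replace the divisions below with shifts when strides are powers of two):

```python
def crd2idx(coord, shape, stride):
    """Linear offset of a coordinate: sum of coord[i] * stride[i]."""
    assert len(coord) == len(shape) == len(stride)
    return sum(c * s for c, s in zip(coord, stride))

def idx2crd(idx, shape, stride):
    """Invert crd2idx for a compact layout by peeling dimensions in
    decreasing-stride order (each divmod extracts one coordinate)."""
    coord = [0] * len(shape)
    for i in sorted(range(len(shape)), key=lambda j: -stride[j]):
        coord[i] = idx // stride[i]
        idx %= stride[i]
    return tuple(coord)

shape, stride = (4, 8), (8, 1)  # row-major 4x8
assert crd2idx((2, 5), shape, stride) == 21
assert idx2crd(21, shape, stride) == (2, 5)
```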
* fix ci

* merge a8w4 moe

* add a8w4 bench
* improve fused rope kernel

* address comments
aryaman-gupta merged commit 98c9d10 into aryaman/group-gemm on Apr 23, 2026
13 of 14 checks passed
