Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/api/kernels.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@ GEMM Kernels
MoE (Mixture-of-Experts) Kernels
----------------------------------

- ``kernels.moe_gemm_2stage`` -- MoE GEMM with 2-stage pipeline (stage1 + stage2)
- ``kernels.moe_gemm_2stage`` -- CDNA / MFMA MoE GEMM with 2-stage pipeline
- ``kernels.rdna_moe_gemm_2stage`` -- RDNA4 (``gfx120x`` / ``gfx1201``) MoE
GEMM 2-stage, fp16/bf16 WMMA
- ``kernels.moe_gemm_2stage_wmma_gfx1250`` -- gfx1250 (MI450) MoE GEMM
2-stage, fp16/bf16 WMMA with TDM
- ``kernels.mixed_moe_gemm_2stage`` -- Mixed-precision MoE GEMM
- ``kernels.moe_blockscale_2stage`` -- MoE with block-scale quantization (MXFP4)
- ``kernels.moe_reduce`` -- MoE reduction kernel: sums over the topk dimension
Expand Down
1 change: 1 addition & 0 deletions docs/architecture_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ FlyDSL/
│ ├── blockscale_preshuffle_gemm.py # Blockscale GEMM
│ ├── hgemm_splitk.py # FP16 GEMM split-K
│ ├── moe_gemm_2stage.py # MoE GEMM (2-stage gate/up + reduce)
│ ├── rdna_moe_gemm_2stage.py # RDNA4 (gfx120x) MoE GEMM (fp16/bf16 WMMA)
│ ├── moe_blockscale_2stage.py # MoE Blockscale GEMM
│ ├── mixed_moe_gemm_2stage.py # Mixed-precision MoE GEMM
│ ├── pa_decode_fp8.py # Paged attention decode (FP8)
Expand Down
30 changes: 27 additions & 3 deletions docs/prebuilt_kernels_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,13 @@ What operation do you need?
├── MoE (Mixture of Experts)
│ ├── Blockscale MoE (gate+up+reduce)
│ │ └── → kernels/moe_blockscale_2stage.py
│ └── Standard MoE (fp8/f16/bf16/int8/int4)
│ └── → kernels/moe_gemm_2stage.py
│ ├── Standard MoE (CDNA / MFMA, fp8/f16/bf16/int8/int4)
│ │ └── → kernels/moe_gemm_2stage.py
│ ├── RDNA4 MoE (gfx120x / gfx1201, fp16/bf16 WMMA)
│ │ └── → kernels/rdna_moe_gemm_2stage.py
│ └── GFX1250 MoE (MI450, WMMA fp16/bf16 + MXScale fp4/fp8/a8w4)
│ ├── → kernels/moe_gemm_2stage_wmma_gfx1250.py
│ └── → kernels/moe_gemm_2stage_mxscale_gfx1250.py
└── Building blocks
├── Warp/block reduction → kernels_common.py
Expand All @@ -301,7 +306,10 @@ What operation do you need?
| `kernels/preshuffle_gemm.py` | GEMM (preshuffle layout) |
| `kernels/blockscale_preshuffle_gemm.py` | Blockscale GEMM |
| `kernels/hgemm_splitk.py` | FP16 GEMM split-K |
| `kernels/moe_gemm_2stage.py` | MoE GEMM 2-stage (gate/up + reduce) |
| `kernels/moe_gemm_2stage.py` | MoE GEMM 2-stage (gate/up + reduce), CDNA / MFMA |
| `kernels/rdna_moe_gemm_2stage.py` | RDNA4 (gfx120x) MoE GEMM 2-stage, fp16/bf16 WMMA |
| `kernels/moe_gemm_2stage_wmma_gfx1250.py` | gfx1250 MoE GEMM 2-stage, fp16/bf16 WMMA |
| `kernels/moe_gemm_2stage_mxscale_gfx1250.py` | gfx1250 MoE GEMM 2-stage, fp4/fp8/a8w4 MXScale |
| `kernels/moe_blockscale_2stage.py` | MoE Blockscale 2-stage |
| `kernels/mixed_moe_gemm_2stage.py` | Mixed-precision MoE GEMM |
| `kernels/pa_decode_fp8.py` | Paged attention decode (FP8) |
Expand Down Expand Up @@ -330,6 +338,7 @@ What operation do you need?
| `tests/kernels/test_blockscale_preshuffle_gemm.py` | Blockscale GEMM |
| `tests/kernels/test_hgemm_splitk.py` | FP16 GEMM split-K |
| `tests/kernels/test_moe_gemm.py` | MoE GEMM |
| `tests/kernels/test_moe_gemm_rdna4.py` | RDNA4 MoE GEMM |
| `tests/kernels/test_moe_blockscale.py` | MoE Blockscale GEMM |
| `tests/kernels/test_moe_reduce.py` | MoE reduce kernel |
| `tests/kernels/test_pa.py` | Paged attention decode |
Expand All @@ -345,3 +354,18 @@ What operation do you need?
| `tests/kernels/test_vec_add.py` | Vector addition |
| `tests/kernels/test_quant.py` | Quantization utilities |
| `tests/kernels/benchmark_common.py` | Shared benchmark infrastructure |

## 9. RDNA4 MoE Notes

`kernels/rdna_moe_gemm_2stage.py` targets `gfx120x` only (Radeon RDNA4,
including `gfx1201`). It uses `wmma_f32_16x16x16_{f16,bf16}` with a simple
LDS pipeline and reuses the public `compile_moe_gemm1` / `compile_moe_gemm2`
/ `compile_moe_gemm2_ex` contract via the `make_moe_public_api` factory in
`kernels/moe_gemm_2stage.py`.

Measured starting points on `gfx1201`:

- Stage1: `tile_k=128`, `tile_n=64` for `tile_m` 16/32, and `tile_n=128` for `tile_m=64`
- Stage2: `tile_k=128`, `tile_n=64`
- `waves_per_eu=2` often helps stage1, while stage2 remains workload-dependent
- Reduce mode can outperform atomic mode for medium and large routed workloads, so both modes should be benchmarked on target shapes
97 changes: 97 additions & 0 deletions kernels/moe_gemm_2stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3392,6 +3392,103 @@ def mode(self) -> str:
return MoeGemm2Mode.REDUCE


# ---------------------------------------------------------------------------
# Arch-agnostic MoE public API factory
# ---------------------------------------------------------------------------
#
# Arch-specific MoE kernel modules (CDNA MFMA here, RDNA4 in
# ``rdna_moe_gemm_2stage.py``, gfx1250 in ``moe_gemm_2stage_wmma_gfx1250.py``)
# share the same public builder shape. ``make_moe_public_api`` generates
# ``compile_moe_gemm1`` / ``compile_moe_gemm2`` / ``compile_moe_gemm2_ex``
# bound to a given arch-specific ``compile_impl`` so each arch file does not
# have to hand-roll the same wrappers.


# Extra kwargs accepted at the public API layer so callers can stay uniform
# across CDNA / gfx1250 / RDNA4 even if some options are arch-specific; we
# strip the ones the target ``compile_impl`` does not actually use.
_MOE_PUBLIC_EXTRA_KWARGS = (
"group_size",
"use_cshuffle_epilog",
"num_buffers",
"use_tdm_gather",
"use_tdm_store",
"inst_prefetch",
"wave_specialized_tdm",
"cluster_m",
"cluster_n",
)


def _moe_strip_extras(kw: dict, allowed_extras: tuple = ()) -> dict:
result = dict(kw)
for key in _MOE_PUBLIC_EXTRA_KWARGS:
if key in allowed_extras:
continue
result.pop(key, None)
return result


def make_moe_public_api(compile_impl, *, pass_through_kwargs: tuple = ()):
    """Build the ``compile_moe_gemm1`` / ``compile_moe_gemm2`` / ``compile_moe_gemm2_ex`` trio.

    ``compile_impl`` must accept ``stage``, ``doweight``, ``accumulate`` and
    the usual MoE kwargs (``model_dim``, ``inter_dim``, ``experts``, ``topk``,
    ``tile_m``, ``tile_n``, ``tile_k``, ``in_dtype``, ``out_dtype``,
    ``waves_per_eu``, ``expert_sched_mode``). ``pass_through_kwargs`` lets
    arch-specific builders opt in to receiving extra public kwargs (e.g.
    gfx1250 TDM / cluster knobs) that would otherwise be stripped.

    Returns the three public builder callables bound to ``compile_impl``.
    """

    def compile_moe_gemm1(*, doweight_stage1, **kw):
        # Stage-1 (gate/up) builder: drop extras this arch did not opt into.
        forwarded = _moe_strip_extras(kw, pass_through_kwargs)
        return compile_impl(stage=1, doweight=doweight_stage1, **forwarded)

    def compile_moe_gemm2(*, doweight_stage2, accumulate=True, **kw):
        # Stage-2 (down-projection) builder.
        forwarded = _moe_strip_extras(kw, pass_through_kwargs)
        return compile_impl(
            stage=2,
            doweight=doweight_stage2,
            accumulate=accumulate,
            **forwarded,
        )

    def compile_moe_gemm2_ex(
        *,
        mode=MoeGemm2Mode.ATOMIC,
        valid_mask=None,
        zero_intermediate=True,
        **kw,
    ):
        # Atomic mode is the plain accumulate-into-output stage-2 path.
        if mode != MoeGemm2Mode.REDUCE:
            return compile_moe_gemm2(accumulate=True, **kw)
        # REDUCE mode: stage-2 writes non-accumulated partials, then a
        # separate reduction kernel sums over the topk dimension.
        gemm2_exe = compile_moe_gemm2(accumulate=False, **kw)
        out_key = str(kw.get("out_dtype", "f16")).strip().lower()
        # Canonicalize the reduce-kernel dtype tag; anything unrecognized
        # falls back to f32, mirroring the widest accumulator type.
        dtype_str = {
            "f16": "f16", "fp16": "f16", "half": "f16",
            "bf16": "bf16", "bfloat16": "bf16",
        }.get(out_key, "f32")
        masked = valid_mask is not None
        reduce_exe = compile_moe_reduction(
            topk=kw["topk"],
            model_dim=kw["model_dim"],
            dtype_str=dtype_str,
            use_mask=masked,
        )
        return _MoeGemm2ReduceWrapper(
            gemm2_exe=gemm2_exe,
            reduce_exe=reduce_exe,
            topk=kw["topk"],
            model_dim=kw["model_dim"],
            out_dtype_str=dtype_str,
            use_mask=masked,
            zero_intermediate=zero_intermediate,
        )

    return compile_moe_gemm1, compile_moe_gemm2, compile_moe_gemm2_ex


def compile_moe_gemm2_ex(
*,
model_dim: int,
Expand Down
55 changes: 9 additions & 46 deletions kernels/moe_gemm_2stage_wmma_gfx1250.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
# Copyright (c) 2025 FlyDSL Project Contributors


"""gfx1250 MoE 2-stage fp16 WMMA kernels.
"""gfx1250 (MI450 / GFX12) MoE 2-stage fp16/bf16 WMMA kernels.

Implements stage1/stage2 single-kernel inline paths using the
``wmma_f32_16x16x32_f16`` instruction for fp16 (and bf16 via host
conversion) inputs.
Implements the single-kernel stage1/stage2 paths for gfx1250 using the
``wmma_f32_16x16x32_f16`` instruction together with TDM helpers. RDNA4
(gfx120x) has a different WMMA ISA and lives in ``rdna_moe_gemm_2stage.py``.
"""

from __future__ import annotations
Expand All @@ -15,10 +15,7 @@

from flydsl.runtime.device import get_rocm_arch as get_hip_arch

from kernels.moe_gemm_2stage import (
MoeGemm2Mode,
compile_moe_reduction,
)
from kernels.moe_gemm_2stage import MoeGemm2Mode, make_moe_public_api
from kernels.moe_gemm_2stage_common_gfx1250 import (
_bf16_to_f16_wrapper,
_emit_stage1_gate_up_epilogue,
Expand Down Expand Up @@ -813,6 +810,7 @@ def launch_fp16_stage2_single(
# Public API entry points for fp16/bf16
# ---------------------------------------------------------------------------


@functools.lru_cache(maxsize=1024)
def _compile_moe_wmma_gemm(
*,
Expand Down Expand Up @@ -872,41 +870,6 @@ def _compile_moe_wmma_gemm(
return exe


def compile_moe_gemm1(*, doweight_stage1, group_size=-1, use_cshuffle_epilog=None,
                      num_buffers=1, use_tdm_gather=True, use_tdm_store=False,
                      inst_prefetch=False, wave_specialized_tdm=False,
                      cluster_m=1, cluster_n=1, **kw):
    """Compile the stage-1 (gate/up) MoE GEMM executable for this arch.

    The named keyword-only options mirror the uniform public MoE builder
    signature shared with the other arch backends; this builder does not
    forward them — only the remaining kwargs reach the cached compiler.
    """
    # Explicitly discard the uniform-signature extras that are unused here.
    del group_size, use_cshuffle_epilog, num_buffers, use_tdm_gather
    del use_tdm_store, inst_prefetch, wave_specialized_tdm, cluster_m, cluster_n
    return _compile_moe_wmma_gemm(stage=1, doweight=doweight_stage1, **kw)


def compile_moe_gemm2(*, doweight_stage2, accumulate=True, group_size=-1,
                      use_cshuffle_epilog=None,
                      num_buffers=1, use_tdm_gather=True, use_tdm_store=False,
                      inst_prefetch=False, wave_specialized_tdm=False,
                      cluster_m=1, cluster_n=1, **kw):
    """Compile the stage-2 (down-projection) MoE GEMM executable.

    As with ``compile_moe_gemm1``, the explicitly named options exist only
    for signature uniformity across arch backends and are not forwarded.
    """
    # Drop the uniform-signature extras; pass everything else through.
    del group_size, use_cshuffle_epilog, num_buffers, use_tdm_gather
    del use_tdm_store, inst_prefetch, wave_specialized_tdm, cluster_m, cluster_n
    return _compile_moe_wmma_gemm(
        stage=2,
        doweight=doweight_stage2,
        accumulate=accumulate,
        **kw,
    )


def compile_moe_gemm2_ex(*, mode=MoeGemm2Mode.ATOMIC, valid_mask=None, zero_intermediate=True, **kw):
    """Stage-2 entry point supporting both ATOMIC and REDUCE output modes.

    ATOMIC compiles a single accumulating stage-2 kernel. REDUCE compiles a
    non-accumulating stage-2 kernel plus a topk-reduction kernel and wraps
    them in a single launchable object.
    """
    # Plain path: one accumulating stage-2 kernel.
    if mode != MoeGemm2Mode.REDUCE:
        return compile_moe_gemm2(accumulate=True, **kw)
    # REDUCE path: partials are written un-accumulated, then summed.
    gemm2_exe = compile_moe_gemm2(accumulate=False, **kw)
    out_key = str(kw.get("out_dtype", "f16")).strip().lower()
    # Canonical dtype tag for the reduce kernel; unknown names fall back to f32.
    dtype_str = {
        "f16": "f16", "fp16": "f16", "half": "f16",
        "bf16": "bf16", "bfloat16": "bf16",
    }.get(out_key, "f32")
    masked = valid_mask is not None
    reduce_exe = compile_moe_reduction(
        topk=kw["topk"], model_dim=kw["model_dim"],
        dtype_str=dtype_str, use_mask=masked,
    )
    # Imported lazily to avoid importing the wrapper at module load time.
    from kernels.moe_gemm_2stage import _MoeGemm2ReduceWrapper
    return _MoeGemm2ReduceWrapper(
        gemm2_exe=gemm2_exe, reduce_exe=reduce_exe,
        topk=kw["topk"], model_dim=kw["model_dim"],
        out_dtype_str=dtype_str,
        use_mask=masked,
        zero_intermediate=zero_intermediate,
    )
# Public entry points for this backend, generated by the shared factory so
# the module exposes the same compile_moe_gemm1 / compile_moe_gemm2 /
# compile_moe_gemm2_ex contract as the other arch-specific MoE modules.
compile_moe_gemm1, compile_moe_gemm2, compile_moe_gemm2_ex = make_moe_public_api(
    _compile_moe_wmma_gemm
)
Loading
Loading