diff --git a/docs/api/kernels.rst b/docs/api/kernels.rst index 6cc0d41b0..a490099fa 100644 --- a/docs/api/kernels.rst +++ b/docs/api/kernels.rst @@ -16,7 +16,11 @@ GEMM Kernels MoE (Mixture-of-Experts) Kernels ---------------------------------- -- ``kernels.moe_gemm_2stage`` -- MoE GEMM with 2-stage pipeline (stage1 + stage2) +- ``kernels.moe_gemm_2stage`` -- CDNA / MFMA MoE GEMM with 2-stage pipeline +- ``kernels.rdna_moe_gemm_2stage`` -- RDNA4 (``gfx120x`` / ``gfx1201``) MoE + GEMM 2-stage, fp16/bf16 WMMA +- ``kernels.moe_gemm_2stage_wmma_gfx1250`` -- gfx1250 (MI450) MoE GEMM + 2-stage, fp16/bf16 WMMA with TDM - ``kernels.mixed_moe_gemm_2stage`` -- Mixed-precision MoE GEMM - ``kernels.moe_blockscale_2stage`` -- MoE with block-scale quantization (MXFP4) - ``kernels.moe_reduce`` -- MoE reduction kernel: sums over the topk dimension diff --git a/docs/architecture_guide.md b/docs/architecture_guide.md index 6b0eb9e21..f475564ca 100644 --- a/docs/architecture_guide.md +++ b/docs/architecture_guide.md @@ -82,6 +82,7 @@ FlyDSL/ │ ├── blockscale_preshuffle_gemm.py # Blockscale GEMM │ ├── hgemm_splitk.py # FP16 GEMM split-K │ ├── moe_gemm_2stage.py # MoE GEMM (2-stage gate/up + reduce) +│ ├── rdna_moe_gemm_2stage.py # RDNA4 (gfx120x) MoE GEMM (fp16/bf16 WMMA) │ ├── moe_blockscale_2stage.py # MoE Blockscale GEMM │ ├── mixed_moe_gemm_2stage.py # Mixed-precision MoE GEMM │ ├── pa_decode_fp8.py # Paged attention decode (FP8) diff --git a/docs/prebuilt_kernels_guide.md b/docs/prebuilt_kernels_guide.md index 018b122f1..83f02d8b3 100644 --- a/docs/prebuilt_kernels_guide.md +++ b/docs/prebuilt_kernels_guide.md @@ -283,8 +283,13 @@ What operation do you need? ├── MoE (Mixture of Experts) │ ├── Blockscale MoE (gate+up+reduce) │ │ └── → kernels/moe_blockscale_2stage.py -│ └── Standard MoE (fp8/f16/bf16/int8/int4) -│ └── → kernels/moe_gemm_2stage.py +│ ├── Standard MoE (CDNA / MFMA, fp8/f16/bf16/int8/int4) +│ │ └── → kernels/moe_gemm_2stage.py +│ ├── RDNA4 MoE (gfx120x / gfx1201, fp16/bf16 WMMA) +│ │ └── → kernels/rdna_moe_gemm_2stage.py +│ └── GFX1250 MoE (MI450, WMMA fp16/bf16 + MXScale fp4/fp8/a8w4) +│ ├── → kernels/moe_gemm_2stage_wmma_gfx1250.py +│ └── → kernels/moe_gemm_2stage_mxscale_gfx1250.py │ └── Building blocks ├── Warp/block reduction → kernels_common.py @@ -301,7 +306,10 @@ What operation do you need? | `kernels/preshuffle_gemm.py` | GEMM (preshuffle layout) | | `kernels/blockscale_preshuffle_gemm.py` | Blockscale GEMM | | `kernels/hgemm_splitk.py` | FP16 GEMM split-K | -| `kernels/moe_gemm_2stage.py` | MoE GEMM 2-stage (gate/up + reduce) | +| `kernels/moe_gemm_2stage.py` | MoE GEMM 2-stage (gate/up + reduce), CDNA / MFMA | +| `kernels/rdna_moe_gemm_2stage.py` | RDNA4 (gfx120x) MoE GEMM 2-stage, fp16/bf16 WMMA | +| `kernels/moe_gemm_2stage_wmma_gfx1250.py` | gfx1250 MoE GEMM 2-stage, fp16/bf16 WMMA | +| `kernels/moe_gemm_2stage_mxscale_gfx1250.py` | gfx1250 MoE GEMM 2-stage, fp4/fp8/a8w4 MXScale | | `kernels/moe_blockscale_2stage.py` | MoE Blockscale 2-stage | | `kernels/mixed_moe_gemm_2stage.py` | Mixed-precision MoE GEMM | | `kernels/pa_decode_fp8.py` | Paged attention decode (FP8) | @@ -330,6 +338,7 @@ What operation do you need? 
| `tests/kernels/test_blockscale_preshuffle_gemm.py` | Blockscale GEMM | | `tests/kernels/test_hgemm_splitk.py` | FP16 GEMM split-K | | `tests/kernels/test_moe_gemm.py` | MoE GEMM | +| `tests/kernels/test_moe_gemm_rdna4.py` | RDNA4 MoE GEMM | | `tests/kernels/test_moe_blockscale.py` | MoE Blockscale GEMM | | `tests/kernels/test_moe_reduce.py` | MoE reduce kernel | | `tests/kernels/test_pa.py` | Paged attention decode | @@ -345,3 +354,18 @@ What operation do you need? | `tests/kernels/test_vec_add.py` | Vector addition | | `tests/kernels/test_quant.py` | Quantization utilities | | `tests/kernels/benchmark_common.py` | Shared benchmark infrastructure | + +## 9. RDNA4 MoE Notes + +`kernels/rdna_moe_gemm_2stage.py` targets `gfx120x` only (Radeon RDNA4, +including `gfx1201`). It uses ``wmma_f32_16x16x16_{f16,bf16}`` with a simple +LDS pipeline and reuses the public `compile_moe_gemm1` / `compile_moe_gemm2` +/ `compile_moe_gemm2_ex` contract via the `make_moe_public_api` factory in +`kernels/moe_gemm_2stage.py`. + +Measured starting points on `gfx1201`: + +- Stage1: `tile_k=128`, `tile_n=64` for `tile_m` 16/32, and `tile_n=128` for `tile_m=64` +- Stage2: `tile_k=128`, `tile_n=64` +- `waves_per_eu=2` often helps stage1, while stage2 remains workload-dependent +- Reduce mode can outperform atomic mode for medium and large routed workloads, so both modes should be benchmarked on target shapes diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index 06eddb248..cf77f3a18 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -3392,6 +3392,103 @@ def mode(self) -> str: return MoeGemm2Mode.REDUCE +# --------------------------------------------------------------------------- +# Arch-agnostic MoE public API factory +# --------------------------------------------------------------------------- +# +# Arch-specific MoE kernel modules (CDNA MFMA here, RDNA4 in +# ``rdna_moe_gemm_2stage.py``, gfx1250 in ``moe_gemm_2stage_wmma_gfx1250.py``) +# share the same public builder shape. ``make_moe_public_api`` generates +# ``compile_moe_gemm1`` / ``compile_moe_gemm2`` / ``compile_moe_gemm2_ex`` +# bound to a given arch-specific ``compile_impl`` so each arch file does not +# have to hand-roll the same wrappers. + + +# Extra kwargs accepted at the public API layer so callers can stay uniform +# across CDNA / gfx1250 / RDNA4 even if some options are arch-specific; we +# strip the ones the target ``compile_impl`` does not actually use. +_MOE_PUBLIC_EXTRA_KWARGS = ( + "group_size", + "use_cshuffle_epilog", + "num_buffers", + "use_tdm_gather", + "use_tdm_store", + "inst_prefetch", + "wave_specialized_tdm", + "cluster_m", + "cluster_n", +) + + +def _moe_strip_extras(kw: dict, allowed_extras: tuple = ()) -> dict: + result = dict(kw) + for key in _MOE_PUBLIC_EXTRA_KWARGS: + if key in allowed_extras: + continue + result.pop(key, None) + return result + + +def make_moe_public_api(compile_impl, *, pass_through_kwargs: tuple = ()): + """Create ``compile_moe_gemm1`` / ``compile_moe_gemm2`` / ``compile_moe_gemm2_ex``. + + ``compile_impl`` must accept ``stage``, ``doweight``, ``accumulate`` and + the usual MoE kwargs (``model_dim``, ``inter_dim``, ``experts``, ``topk``, + ``tile_m``, ``tile_n``, ``tile_k``, ``in_dtype``, ``out_dtype``, + ``waves_per_eu``, ``expert_sched_mode``). ``pass_through_kwargs`` lets + arch-specific builders opt in to receiving extra public kwargs (e.g. + gfx1250 TDM / cluster knobs) that would otherwise be stripped. 
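+
+    A minimal usage sketch, mirroring how the gfx1250 module binds its
+    ``compile_impl`` (the shape and tile values below are illustrative, not
+    tuned defaults)::
+
+        compile_moe_gemm1, compile_moe_gemm2, compile_moe_gemm2_ex = (
+            make_moe_public_api(_compile_moe_wmma_gemm)
+        )
+        stage1_exe = compile_moe_gemm1(
+            doweight_stage1=False,
+            model_dim=4096, inter_dim=14336, experts=8, topk=2,
+            tile_m=32, tile_n=64, tile_k=128,
+            in_dtype="fp16", out_dtype="f16",
+        )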
+ """ + + def compile_moe_gemm1(*, doweight_stage1, **kw): + kw = _moe_strip_extras(kw, pass_through_kwargs) + return compile_impl(stage=1, doweight=doweight_stage1, **kw) + + def compile_moe_gemm2(*, doweight_stage2, accumulate=True, **kw): + kw = _moe_strip_extras(kw, pass_through_kwargs) + return compile_impl( + stage=2, + doweight=doweight_stage2, + accumulate=accumulate, + **kw, + ) + + def compile_moe_gemm2_ex( + *, + mode=MoeGemm2Mode.ATOMIC, + valid_mask=None, + zero_intermediate=True, + **kw, + ): + if mode == MoeGemm2Mode.REDUCE: + gemm2_exe = compile_moe_gemm2(accumulate=False, **kw) + out_s = str(kw.get("out_dtype", "f16")).strip().lower() + if out_s in ("f16", "fp16", "half"): + dtype_str = "f16" + elif out_s in ("bf16", "bfloat16"): + dtype_str = "bf16" + else: + dtype_str = "f32" + reduce_exe = compile_moe_reduction( + topk=kw["topk"], + model_dim=kw["model_dim"], + dtype_str=dtype_str, + use_mask=(valid_mask is not None), + ) + return _MoeGemm2ReduceWrapper( + gemm2_exe=gemm2_exe, + reduce_exe=reduce_exe, + topk=kw["topk"], + model_dim=kw["model_dim"], + out_dtype_str=dtype_str, + use_mask=(valid_mask is not None), + zero_intermediate=zero_intermediate, + ) + return compile_moe_gemm2(accumulate=True, **kw) + + return compile_moe_gemm1, compile_moe_gemm2, compile_moe_gemm2_ex + + def compile_moe_gemm2_ex( *, model_dim: int, diff --git a/kernels/moe_gemm_2stage_wmma_gfx1250.py b/kernels/moe_gemm_2stage_wmma_gfx1250.py index 0dd7d9355..b17cb6260 100644 --- a/kernels/moe_gemm_2stage_wmma_gfx1250.py +++ b/kernels/moe_gemm_2stage_wmma_gfx1250.py @@ -2,11 +2,11 @@ # Copyright (c) 2025 FlyDSL Project Contributors -"""gfx1250 MoE 2-stage fp16 WMMA kernels. +"""gfx1250 (MI450 / GFX12) MoE 2-stage fp16/bf16 WMMA kernels. -Implements stage1/stage2 single-kernel inline paths using the -``wmma_f32_16x16x32_f16`` instruction for fp16 (and bf16 via host -conversion) inputs. +Implements the single-kernel stage1/stage2 paths for gfx1250 using the +``wmma_f32_16x16x32_f16`` instruction together with TDM helpers. RDNA4 +(gfx120x) has a different WMMA ISA and lives in ``rdna_moe_gemm_2stage.py``. 
""" from __future__ import annotations @@ -15,10 +15,7 @@ from flydsl.runtime.device import get_rocm_arch as get_hip_arch -from kernels.moe_gemm_2stage import ( - MoeGemm2Mode, - compile_moe_reduction, -) +from kernels.moe_gemm_2stage import MoeGemm2Mode, make_moe_public_api from kernels.moe_gemm_2stage_common_gfx1250 import ( _bf16_to_f16_wrapper, _emit_stage1_gate_up_epilogue, @@ -813,6 +810,7 @@ def launch_fp16_stage2_single( # Public API entry points for fp16/bf16 # --------------------------------------------------------------------------- + @functools.lru_cache(maxsize=1024) def _compile_moe_wmma_gemm( *, @@ -872,41 +870,6 @@ def _compile_moe_wmma_gemm( return exe -def compile_moe_gemm1(*, doweight_stage1, group_size=-1, use_cshuffle_epilog=None, - num_buffers=1, use_tdm_gather=True, use_tdm_store=False, - inst_prefetch=False, wave_specialized_tdm=False, - cluster_m=1, cluster_n=1, **kw): - return _compile_moe_wmma_gemm(stage=1, doweight=doweight_stage1, **kw) - - -def compile_moe_gemm2(*, doweight_stage2, accumulate=True, group_size=-1, - use_cshuffle_epilog=None, - num_buffers=1, use_tdm_gather=True, use_tdm_store=False, - inst_prefetch=False, wave_specialized_tdm=False, - cluster_m=1, cluster_n=1, **kw): - return _compile_moe_wmma_gemm(stage=2, doweight=doweight_stage2, accumulate=accumulate, **kw) - - -def compile_moe_gemm2_ex(*, mode=MoeGemm2Mode.ATOMIC, valid_mask=None, zero_intermediate=True, **kw): - if mode == MoeGemm2Mode.REDUCE: - gemm2_exe = compile_moe_gemm2(accumulate=False, **kw) - out_s = str(kw.get("out_dtype", "f16")).strip().lower() - if out_s in ("f16", "fp16", "half"): - dtype_str = "f16" - elif out_s in ("bf16", "bfloat16"): - dtype_str = "bf16" - else: - dtype_str = "f32" - reduce_exe = compile_moe_reduction( - topk=kw["topk"], model_dim=kw["model_dim"], - dtype_str=dtype_str, use_mask=(valid_mask is not None), - ) - from kernels.moe_gemm_2stage import _MoeGemm2ReduceWrapper - return _MoeGemm2ReduceWrapper( - gemm2_exe=gemm2_exe, reduce_exe=reduce_exe, - topk=kw["topk"], model_dim=kw["model_dim"], - out_dtype_str=dtype_str, - use_mask=(valid_mask is not None), - zero_intermediate=zero_intermediate, - ) - return compile_moe_gemm2(accumulate=True, **kw) +compile_moe_gemm1, compile_moe_gemm2, compile_moe_gemm2_ex = make_moe_public_api( + _compile_moe_wmma_gemm +) diff --git a/kernels/rdna_moe_gemm_2stage.py b/kernels/rdna_moe_gemm_2stage.py new file mode 100644 index 000000000..a18f7198d --- /dev/null +++ b/kernels/rdna_moe_gemm_2stage.py @@ -0,0 +1,1069 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""RDNA4 (gfx120x) MoE 2-stage fp16/bf16 WMMA kernels. + +This path targets the Radeon RDNA4 WMMA ISA (``gfx120x``, including +``gfx1201``) using ``wmma_f32_16x16x16_{f16,bf16}`` and a simple LDS pipeline. +It is intentionally separate from the gfx1250 (MI450 / GFX12) TDM-based WMMA +path in ``moe_gemm_2stage_wmma_gfx1250.py``. 
+ +Measured starting points on ``gfx1201``: +- stage1: ``tile_k=128``, ``tile_n=64`` for ``tile_m`` 16/32, ``tile_n=128`` for ``tile_m=64`` +- stage2: ``tile_k=128``, ``tile_n=64`` +- ``waves_per_eu=2`` often helps stage1, while stage2 remains workload-dependent +""" + +from __future__ import annotations + +import functools + +from flydsl.runtime.device import get_rocm_arch as get_hip_arch + +from kernels.moe_gemm_2stage import MoeGemm2Mode, make_moe_public_api +from kernels.rdna_moe_gemm_2stage_common import ( + _emit_stage1_gate_up_epilogue, + _emit_stage2_store_epilogue, + _finalize_alloc_and_launch_2d, + _make_moe_wave_layout, + _make_wmma_sub_tiles, + _moe_out_elem_ty, + _pick_fp16_launch_shape, + _require_gfx120x, +) + + +def _validate_vectorized_tile( + tile_rows: int, tile_k: int, block_threads: int, tile_name: str +) -> int: + load_vec = 8 # 8 bf16/f16 elements = 16 bytes + total = int(tile_rows) * int(tile_k) + denom = int(block_threads) * load_vec + if total % denom != 0: + raise ValueError( + f"{tile_name} tile ({tile_rows}x{tile_k}) must be divisible by " + f"block_threads*{load_vec} ({denom}) for RDNA4 vectorized loads" + ) + return total // denom + + +def _set_expert_sched_hint(jit_fn, enabled: bool) -> None: + if enabled: + jit_fn.compile_hints["llvm_options"] = { + "amdgpu-expert-scheduling-mode": True, + } + + +class _MoeGemm2AtomicCastWrapper: + """Run an internal f32 atomic stage2 kernel and cast back. + + RDNA4 cannot lower scalar f16/bf16 buffer atomics cleanly, so for those + output dtypes we run an f32 atomic kernel into a temporary buffer and cast + back to the requested dtype on the host. + """ + + def __init__(self, gemm2_exe, model_dim: int, out_dtype_str: str): + self._gemm2_exe = gemm2_exe + self._model_dim = int(model_dim) + self._out_dtype_str = str(out_dtype_str).strip().lower() + self._cache = {} + for attr in ("compile_hints",): + if hasattr(gemm2_exe, attr): + setattr(self, attr, getattr(gemm2_exe, attr)) + + def _resolve_torch_stream(self, *, arg_out, stream): + import torch + + device_index = arg_out.device.index + if device_index is None: + device_index = torch.cuda.current_device() + if stream is None: + return torch.cuda.current_stream(device=device_index) + if isinstance(stream, int): + return torch.cuda.ExternalStream(stream, device=device_index) + return stream + + def _get_tmp(self, arg_out, tokens_in: int, torch_stream): + import torch + + device_index = arg_out.device.index + if device_index is None: + device_index = torch.cuda.current_device() + key = (int(device_index), int(tokens_in), int(torch_stream.cuda_stream)) + cached = self._cache.get(key) + if cached is not None: + return cached + with torch.cuda.device(device_index), torch.cuda.stream(torch_stream): + tmp = torch.empty( + int(tokens_in), + self._model_dim, + device=arg_out.device, + dtype=torch.float32, + ) + tmp.record_stream(torch_stream) + self._cache[key] = tmp + return tmp + + def __call__( + self, + arg_out, + arg_x, + arg_w, + arg_scale_x, + arg_scale_w, + arg_sorted_token_ids, + arg_expert_ids, + arg_sorted_weights, + arg_num_valid_ids, + tokens_in, + n_in, + k_in, + size_expert_ids_in, + stream=None, + ): + import torch + + torch_stream = self._resolve_torch_stream(arg_out=arg_out, stream=stream) + device_index = arg_out.device.index + if device_index is None: + device_index = torch.cuda.current_device() + tmp = self._get_tmp(arg_out, tokens_in, torch_stream) + with torch.cuda.device(device_index), torch.cuda.stream(torch_stream): + tmp.zero_() + self._gemm2_exe( + tmp, + 
arg_x, + arg_w, + arg_scale_x, + arg_scale_w, + arg_sorted_token_ids, + arg_expert_ids, + arg_sorted_weights, + arg_num_valid_ids, + tokens_in, + n_in, + k_in, + size_expert_ids_in, + torch_stream, + ) + arg_out_view = arg_out.view(int(tokens_in), self._model_dim) + arg_out_view.copy_(tmp) + tmp.record_stream(torch_stream) + + @property + def mode(self) -> str: + return MoeGemm2Mode.ATOMIC + + +@functools.lru_cache(maxsize=64) +def _compile_stage1_kernel_impl( + *, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + route_tile_m: int, + tile_m: int, + tile_n: int, + tile_k: int, + m_warp: int, + n_warp: int, + doweight_stage1: bool, + in_dtype: str, + out_dtype: str, + waves_per_eu: int | None, + expert_sched_mode: bool = True, +): + """Compile RDNA4 stage1 single-kernel WMMA MoE path.""" + + import flydsl.compiler as flyc + import flydsl.expr as fx + from flydsl._mlir import ir + from flydsl._mlir.dialects import scf + from flydsl.compiler.kernel_function import CompilationContext + from flydsl.expr import arith, buffer_ops, gpu, idx2crd, range_constexpr, rocdl, vector + from flydsl.expr.typing import T + from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr, get_op_result_or_value + + WMMA_M, WMMA_N, WMMA_K = 16, 16, 16 + WAVE_SIZE = 32 + LDS_PAD_A = 8 + LDS_PAD_B = 8 + LOAD_VEC = 8 + ELEM_BYTES = 2 + + if in_dtype not in ("fp16", "bf16"): + raise ValueError(f"RDNA4 stage1 only supports fp16/bf16, got {in_dtype!r}") + if out_dtype not in ("f16", "bf16"): + raise ValueError(f"RDNA4 stage1 only supports f16/bf16 outputs, got {out_dtype!r}") + if int(model_dim) % int(tile_k) != 0: + raise ValueError(f"model_dim={model_dim} must be divisible by tile_k={tile_k}") + if int(tile_k) % WMMA_K != 0: + raise ValueError(f"tile_k={tile_k} must be divisible by {WMMA_K}") + if int(tile_m) % WMMA_M != 0 or int(tile_n) % WMMA_N != 0: + raise ValueError(f"tile_m/tile_n must be multiples of 16, got ({tile_m},{tile_n})") + + block_threads = int(m_warp) * int(n_warp) * WAVE_SIZE + warp_tile_m = int(tile_m) // int(m_warp) + warp_tile_n = int(tile_n) // int(n_warp) + wmma_m_rep = warp_tile_m // WMMA_M + wmma_n_rep = warp_tile_n // WMMA_N + if wmma_m_rep <= 0 or wmma_n_rep <= 0: + raise ValueError( + "Invalid RDNA4 stage1 warp tiling: " + f"wmma_m_rep={wmma_m_rep}, wmma_n_rep={wmma_n_rep}" + ) + + num_k_tiles = int(model_dim) // int(tile_k) + k_wmma_steps = int(tile_k) // WMMA_K + n_total = int(2 * inter_dim) + num_a_loads = _validate_vectorized_tile(tile_m, tile_k, block_threads, "stage1 A") + num_b_loads = _validate_vectorized_tile(tile_n, tile_k, block_threads, "stage1 B") + sub_tiles = _make_wmma_sub_tiles( + wmma_m_rep=wmma_m_rep, + wmma_n_rep=wmma_n_rep, + WMMA_M=WMMA_M, + is_fp4=False, + ) + + lds_a_stride = int(tile_k) + LDS_PAD_A + lds_b_stride = int(tile_k) + LDS_PAD_B + lds_a_elems = int(tile_m) * lds_a_stride + LDS_PAD_A + lds_b_elems = int(tile_n) * lds_b_stride + LDS_PAD_B + + gpu_arch = str(get_hip_arch()) + alloc = SmemAllocator(None, arch=gpu_arch, global_sym_name="moe_rdna4_s1") + off_a = alloc._align(alloc.ptr, 16) + alloc.ptr = off_a + lds_a_elems * ELEM_BYTES + off_b = alloc._align(alloc.ptr, 16) + alloc.ptr = off_b + lds_b_elems * ELEM_BYTES + + @flyc.kernel(known_block_size=[block_threads, 1, 1]) + def moe_rdna4_stage1( + arg_out: fx.Tensor, + arg_x: fx.Tensor, + arg_w: fx.Tensor, + arg_scale_x: fx.Tensor, + arg_scale_w: fx.Tensor, + arg_sorted_token_ids: fx.Tensor, + arg_expert_ids: fx.Tensor, + arg_sorted_weights: fx.Tensor, + arg_num_valid_ids: fx.Tensor, + 
i32_tokens_in: fx.Int32, + i32_inter_in: fx.Int32, + i32_k_in: fx.Int32, + i32_size_expert_ids_in: fx.Int32, + ): + _ = (arg_scale_x, arg_scale_w, arg_num_valid_ids, i32_k_in) + + in_ir_ty = T.bf16 if in_dtype == "bf16" else T.f16 + v8_in_ty = T.vec(8, in_ir_ty) + v4f32_ty = T.f32x4 + v8f32_ty = T.vec(8, T.f32) + v8i16_ty = T.vec(8, T.i16) + zero_raw = arith.constant_vector(0.0, v4f32_ty) + + tx = gpu.thread_id("x") + bx = gpu.block_id("x") # inter tile + by = gpu.block_id("y") # expert block + + tokens_idx = arith.index_cast(T.index, i32_tokens_in) + inter_idx = arith.index_cast(T.index, i32_inter_in) + size_expert_ids = arith.index_cast(T.index, i32_size_expert_ids_in) + + sorted_num = size_expert_ids * arith.index(int(route_tile_m)) + sorted_nbytes = sorted_num * arith.index(4) + eid_nbytes = size_expert_ids * arith.index(4) + x_nbytes = tokens_idx * arith.index(int(model_dim)) * arith.index(2) + w_nbytes = arith.index(int(experts * n_total * int(model_dim) * ELEM_BYTES)) + + sorted_rsrc = buffer_ops.create_buffer_resource( + arg_sorted_token_ids, max_size=False, num_records_bytes=sorted_nbytes + ) + eid_rsrc = buffer_ops.create_buffer_resource( + arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes + ) + x_rsrc = buffer_ops.create_buffer_resource( + arg_x, max_size=False, num_records_bytes=x_nbytes + ) + w_rsrc = buffer_ops.create_buffer_resource( + arg_w, max_size=False, num_records_bytes=w_nbytes + ) + out_rsrc = buffer_ops.create_buffer_resource(arg_out, max_size=True) + sw_rsrc = buffer_ops.create_buffer_resource(arg_sorted_weights, max_size=True) + + eid_i32 = buffer_ops.buffer_load( + eid_rsrc, arith.index_cast(T.i32, by), vec_width=1, dtype=T.i32 + ) + eid_ok0 = arith.cmpi( + arith.CmpIPredicate.sge, eid_i32, arith.constant(0, type=T.i32) + ) + eid_ok1 = arith.cmpi( + arith.CmpIPredicate.slt, eid_i32, arith.constant(int(experts), type=T.i32) + ) + eid_ok = arith.andi(eid_ok0, eid_ok1) + + layout_thr = _make_moe_wave_layout( + m_warp=m_warp, n_warp=n_warp, WAVE_SIZE=WAVE_SIZE, fx=fx + ) + thr_coord = idx2crd(tx, layout_thr) + wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( + fx.get(thr_coord, 0), + fx.get(thr_coord, 1), + fx.get(thr_coord, 2), + fx.get(thr_coord, 3), + ) + warp_m_base = wave_m_idx * arith.index(warp_tile_m) + warp_n_base = wave_n_idx * arith.index(warp_tile_n) + blk_n = bx * arith.index(int(tile_n)) + model_dim_i32 = arith.constant(int(model_dim), type=T.i32) + n_total_i32 = arith.constant(int(n_total), type=T.i32) + c2_i32 = arith.constant(2, type=T.i32) + base8 = lane_kgrp * fx.Index(8) + + base_ptr = alloc.get_base() + smem_a = SmemPtr(base_ptr, off_a, in_ir_ty, shape=(lds_a_elems,)) + smem_b = SmemPtr(base_ptr, off_b, in_ir_ty, shape=(lds_b_elems,)) + lds_a = get_op_result_or_value(smem_a.get()) + lds_b = get_op_result_or_value(smem_b.get()) + + def silu(x): + t = x * (-1.4426950408889634) + emu = rocdl.exp2(T.f32, t) + den = 1.0 + emu + sig = rocdl.rcp(T.f32, den) + return x * sig + + def _wmma_op(result_type, a_vec, b_vec, acc): + if in_dtype == "bf16": + a_i16 = vector.bitcast(v8i16_ty, a_vec) + b_i16 = vector.bitcast(v8i16_ty, b_vec) + return rocdl.wmma_f32_16x16x16_bf16( + result_type, a_i16, b_i16, arith.unwrap(acc) + ).result + return rocdl.wmma_f32_16x16x16_f16( + result_type, arith.unwrap(a_vec), arith.unwrap(b_vec), arith.unwrap(acc) + ).result + + a_lds_info = [] + for al in range_constexpr(num_a_loads): + a_lin = tx * fx.Index(LOAD_VEC) + fx.Index(al * block_threads * LOAD_VEC) + a_load_row = a_lin // fx.Index(tile_k) + a_load_col = a_lin % 
fx.Index(tile_k) + lds_rel = a_load_row * fx.Index(lds_a_stride) + a_load_col + a_lds_info.append((a_load_row, a_load_col, lds_rel)) + + b_lds_info = [] + for bl in range_constexpr(num_b_loads): + b_lin = tx * fx.Index(LOAD_VEC) + fx.Index(bl * block_threads * LOAD_VEC) + b_load_row = b_lin // fx.Index(tile_k) + b_load_col = b_lin % fx.Index(tile_k) + lds_rel = b_load_row * fx.Index(lds_b_stride) + b_load_col + b_lds_info.append((b_load_row, b_load_col, lds_rel)) + + def _load_a_tile(k_base): + raw_data = [] + for al in range_constexpr(num_a_loads): + a_load_row, a_load_col, _ = a_lds_info[al] + sorted_row = by * arith.index(int(tile_m)) + a_load_row + row_i32 = arith.index_cast(T.i32, a_load_row) + sorted_i32 = arith.index_cast(T.i32, sorted_row) + row_in_route = arith.cmpi( + arith.CmpIPredicate.ult, + row_i32, + arith.constant(int(route_tile_m), type=T.i32), + ) + sorted_safe = arith.select( + row_in_route, + sorted_i32, + arith.index_cast(T.i32, by * arith.index(int(route_tile_m))), + ) + fused = buffer_ops.buffer_load( + sorted_rsrc, sorted_safe, vec_width=1, dtype=T.i32 + ) + tok = fused & arith.constant((1 << 24) - 1, type=T.i32) + tok_ok = arith.cmpi(arith.CmpIPredicate.ult, tok, i32_tokens_in) + load_ok = arith.andi(row_in_route, tok_ok) + elem_off = tok * model_dim_i32 + arith.index_cast(T.i32, k_base + a_load_col) + f32_off = elem_off // c2_i32 + raw_if = scf.IfOp(load_ok, results_=[v4f32_ty], has_else=True) + with ir.InsertionPoint(raw_if.then_block): + scf.YieldOp( + [buffer_ops.buffer_load(x_rsrc, f32_off, vec_width=4, dtype=T.f32)] + ) + with ir.InsertionPoint(raw_if.else_block): + scf.YieldOp([zero_raw]) + raw_data.append(raw_if.results[0]) + return raw_data + + def _load_b_tile(k_base, row_shift: int): + raw_data = [] + base_row = eid_i32 * n_total_i32 + arith.index_cast(T.i32, blk_n) + arith.constant( + int(row_shift), type=T.i32 + ) + for bl in range_constexpr(num_b_loads): + b_load_row, b_load_col, _ = b_lds_info[bl] + row_i32 = base_row + arith.index_cast(T.i32, b_load_row) + elem_off = row_i32 * model_dim_i32 + arith.index_cast(T.i32, k_base + b_load_col) + f32_off = elem_off // c2_i32 + raw_data.append( + buffer_ops.buffer_load(w_rsrc, f32_off, vec_width=4, dtype=T.f32) + ) + return raw_data + + def _store_a_tile(raw_data): + for al in range_constexpr(num_a_loads): + _, _, lds_rel = a_lds_info[al] + a_vec = vector.bitcast(v8_in_ty, raw_data[al]) + vector.store(a_vec, lds_a, [lds_rel]) + + def _store_b_tile(raw_data): + for bl in range_constexpr(num_b_loads): + _, _, lds_rel = b_lds_info[bl] + b_vec = vector.bitcast(v8_in_ty, raw_data[bl]) + vector.store(b_vec, lds_b, [lds_rel]) + + def _load_a_single_from_lds(rk, rm_val): + col_base = fx.Index(rk * WMMA_K) + base8 + row = warp_m_base + fx.Index(rm_val * WMMA_M) + lane16 + lds_idx = row * fx.Index(lds_a_stride) + col_base + return vector.load_op(v8_in_ty, lds_a, [lds_idx]) + + def _load_b_from_lds(rk): + vecs = [] + col_base = fx.Index(rk * WMMA_K) + base8 + for rn in range_constexpr(wmma_n_rep): + row = warp_n_base + fx.Index(rn * WMMA_N) + lane16 + lds_idx = row * fx.Index(lds_b_stride) + col_base + vecs.append(vector.load_op(v8_in_ty, lds_b, [lds_idx])) + return vecs + + def _do_compute_rk(accs_in, rk): + new_accs = list(accs_in) + b_vecs = _load_b_from_lds(rk) + for wm in range_constexpr(wmma_m_rep): + a_vec = _load_a_single_from_lds(rk, wm) + for wn in range_constexpr(wmma_n_rep): + idx = wm * wmma_n_rep + wn + new_accs[idx] = _wmma_op(v8f32_ty, a_vec, b_vecs[wn], new_accs[idx]) + return new_accs + + acc_zero = 
arith.constant_vector(0.0, v8f32_ty) + acc_gate = [acc_zero] * (wmma_m_rep * wmma_n_rep) + acc_up = [acc_zero] * (wmma_m_rep * wmma_n_rep) + + _if_eid = scf.IfOp(eid_ok) + with ir.InsertionPoint(_if_eid.then_block): + for kt in range_constexpr(num_k_tiles): + k_base = fx.Index(kt * int(tile_k)) + a_data = _load_a_tile(k_base) + gate_data = _load_b_tile(k_base, 0) + _store_a_tile(a_data) + _store_b_tile(gate_data) + gpu.barrier() + + for ks in range_constexpr(k_wmma_steps): + acc_gate = _do_compute_rk(acc_gate, ks) + gpu.barrier() + + up_data = _load_b_tile(k_base, int(inter_dim)) + _store_b_tile(up_data) + gpu.barrier() + + for ks in range_constexpr(k_wmma_steps): + acc_up = _do_compute_rk(acc_up, ks) + gpu.barrier() + + out_elem_ty = _moe_out_elem_ty(out_dtype, T) + + def _load_gate_up_sub8(acc_idx, _vec_base): + return acc_gate[acc_idx], acc_up[acc_idx] + + _emit_stage1_gate_up_epilogue( + sub_tiles=sub_tiles, + by=by, + tile_m=int(tile_m), + route_tile_m=int(route_tile_m), + warp_m_base=warp_m_base, + warp_n_base=warp_n_base, + blk_n=blk_n, + lane16=lane16, + lane_kgrp=lane_kgrp, + WMMA_N=WMMA_N, + i32_tokens_in=i32_tokens_in, + i32_inter_in=i32_inter_in, + topk=int(topk), + sorted_rsrc=sorted_rsrc, + tw_rsrc=sw_rsrc, + out_rsrc=out_rsrc, + doweight_stage1=bool(doweight_stage1), + out_elem_ty=out_elem_ty, + load_gate_up_sub8=_load_gate_up_sub8, + silu_fn=silu, + ir=ir, + fx=fx, + arith=arith, + buffer_ops=buffer_ops, + scf=scf, + vector=vector, + range_constexpr=range_constexpr, + T=T, + ) + scf.YieldOp([]) + + @flyc.jit + def launch_stage1( + arg_out: fx.Tensor, + arg_x: fx.Tensor, + arg_w: fx.Tensor, + arg_scale_x: fx.Tensor, + arg_scale_w: fx.Tensor, + arg_sorted_token_ids: fx.Tensor, + arg_expert_ids: fx.Tensor, + arg_sorted_weights: fx.Tensor, + arg_num_valid_ids: fx.Tensor, + i32_tokens_in: fx.Int32, + i32_inter_in: fx.Int32, + i32_k_in: fx.Int32, + i32_size_expert_ids_in: fx.Int32, + stream: fx.Stream, + ): + _ = (arg_num_valid_ids, i32_k_in) + ctx = CompilationContext.get_current() + inter_in = arith.index_cast(T.index, i32_inter_in) + size_expert_ids_in = arith.index_cast(T.index, i32_size_expert_ids_in) + gx = (inter_in + fx.Index(int(tile_n) - 1)) // fx.Index(int(tile_n)) + gy = size_expert_ids_in + launcher = moe_rdna4_stage1( + arg_out, + arg_x, + arg_w, + arg_scale_x, + arg_scale_w, + arg_sorted_token_ids, + arg_expert_ids, + arg_sorted_weights, + arg_num_valid_ids, + i32_tokens_in, + i32_inter_in, + i32_k_in, + i32_size_expert_ids_in, + ) + _finalize_alloc_and_launch_2d( + ctx=ctx, + alloc=alloc, + launcher=launcher, + gx=gx, + gy=gy, + block_threads=block_threads, + stream=stream, + waves_per_eu=waves_per_eu, + ir=ir, + ) + + _set_expert_sched_hint(launch_stage1, expert_sched_mode) + return launch_stage1 + + +@functools.lru_cache(maxsize=64) +def _compile_stage2_kernel_impl( + *, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + route_tile_m: int, + tile_m: int, + tile_n: int, + tile_k: int, + m_warp: int, + n_warp: int, + doweight_stage2: bool, + in_dtype: str, + out_dtype: str, + accumulate: bool, + waves_per_eu: int | None, + expert_sched_mode: bool = True, +): + """Compile RDNA4 stage2 single-kernel WMMA MoE path.""" + + import flydsl.compiler as flyc + import flydsl.expr as fx + from flydsl._mlir import ir + from flydsl._mlir.dialects import scf + from flydsl.compiler.kernel_function import CompilationContext + from flydsl.expr import arith, buffer_ops, gpu, idx2crd, range_constexpr, rocdl, vector + from flydsl.expr.typing import T + from 
flydsl.utils.smem_allocator import SmemAllocator, SmemPtr, get_op_result_or_value + + WMMA_M, WMMA_N, WMMA_K = 16, 16, 16 + WAVE_SIZE = 32 + LDS_PAD_A = 8 + LDS_PAD_B = 8 + LOAD_VEC = 8 + ELEM_BYTES = 2 + + out_s = str(out_dtype).strip().lower() + out_is_f32 = out_s in ("f32", "fp32", "float") + if in_dtype not in ("fp16", "bf16"): + raise ValueError(f"RDNA4 stage2 only supports fp16/bf16, got {in_dtype!r}") + if out_s not in ("f16", "fp16", "half", "bf16", "bfloat16", "f32", "fp32", "float"): + raise ValueError(f"RDNA4 stage2 only supports f16/bf16/f32 outputs, got {out_dtype!r}") + if (not bool(accumulate)) and out_is_f32: + raise ValueError( + "RDNA4 compile_moe_gemm2(accumulate=False) only supports out_dtype in {'f16','bf16'}" + ) + if int(inter_dim) % int(tile_k) != 0: + raise ValueError(f"inter_dim={inter_dim} must be divisible by tile_k={tile_k}") + if int(tile_k) % WMMA_K != 0: + raise ValueError(f"tile_k={tile_k} must be divisible by {WMMA_K}") + if int(tile_m) % WMMA_M != 0 or int(tile_n) % WMMA_N != 0: + raise ValueError(f"tile_m/tile_n must be multiples of 16, got ({tile_m},{tile_n})") + + block_threads = int(m_warp) * int(n_warp) * WAVE_SIZE + warp_tile_m = int(tile_m) // int(m_warp) + warp_tile_n = int(tile_n) // int(n_warp) + wmma_m_rep = warp_tile_m // WMMA_M + wmma_n_rep = warp_tile_n // WMMA_N + if wmma_m_rep <= 0 or wmma_n_rep <= 0: + raise ValueError( + "Invalid RDNA4 stage2 warp tiling: " + f"wmma_m_rep={wmma_m_rep}, wmma_n_rep={wmma_n_rep}" + ) + + num_k_tiles = int(inter_dim) // int(tile_k) + k_wmma_steps = int(tile_k) // WMMA_K + num_a_loads = _validate_vectorized_tile(tile_m, tile_k, block_threads, "stage2 A") + num_b_loads = _validate_vectorized_tile(tile_n, tile_k, block_threads, "stage2 B") + sub_tiles = _make_wmma_sub_tiles( + wmma_m_rep=wmma_m_rep, + wmma_n_rep=wmma_n_rep, + WMMA_M=WMMA_M, + is_fp4=False, + ) + + lds_a_stride = int(tile_k) + LDS_PAD_A + lds_b_stride = int(tile_k) + LDS_PAD_B + lds_a_elems = int(tile_m) * lds_a_stride + LDS_PAD_A + lds_b_elems = int(tile_n) * lds_b_stride + LDS_PAD_B + + gpu_arch = str(get_hip_arch()) + alloc = SmemAllocator(None, arch=gpu_arch, global_sym_name="moe_rdna4_s2") + off_a = alloc._align(alloc.ptr, 16) + alloc.ptr = off_a + lds_a_elems * ELEM_BYTES + off_b = alloc._align(alloc.ptr, 16) + alloc.ptr = off_b + lds_b_elems * ELEM_BYTES + + @flyc.kernel(known_block_size=[block_threads, 1, 1]) + def moe_rdna4_stage2( + arg_out: fx.Tensor, + arg_x: fx.Tensor, + arg_w: fx.Tensor, + arg_scale_x: fx.Tensor, + arg_scale_w: fx.Tensor, + arg_sorted_token_ids: fx.Tensor, + arg_expert_ids: fx.Tensor, + arg_sorted_weights: fx.Tensor, + arg_num_valid_ids: fx.Tensor, + i32_tokens_in: fx.Int32, + i32_n_in: fx.Int32, + i32_k_in: fx.Int32, + i32_size_expert_ids_in: fx.Int32, + ): + _ = (arg_scale_x, arg_scale_w, i32_k_in) + + in_ir_ty = T.bf16 if in_dtype == "bf16" else T.f16 + v8_in_ty = T.vec(8, in_ir_ty) + v4f32_ty = T.f32x4 + v8f32_ty = T.vec(8, T.f32) + v8i16_ty = T.vec(8, T.i16) + zero_raw = arith.constant_vector(0.0, v4f32_ty) + + tx = gpu.thread_id("x") + bx = gpu.block_id("x") # N tile + by = gpu.block_id("y") # expert block + + tokens_idx = arith.index_cast(T.index, i32_tokens_in) + n_idx = arith.index_cast(T.index, i32_n_in) + size_expert_ids = arith.index_cast(T.index, i32_size_expert_ids_in) + num_valid_i32 = buffer_ops.buffer_load( + buffer_ops.create_buffer_resource(arg_num_valid_ids, max_size=True), + arith.constant(0, type=T.i32), + vec_width=1, + dtype=T.i32, + ) + + sorted_num = size_expert_ids * 
arith.index(int(route_tile_m)) + sorted_nbytes = sorted_num * arith.index(4) + eid_nbytes = size_expert_ids * arith.index(4) + x_rows = tokens_idx * arith.index(int(topk)) + x_nbytes = x_rows * arith.index(int(inter_dim)) * arith.index(2) + out_elem_bytes = 4 if out_is_f32 else 2 + out_nbytes = tokens_idx * n_idx * arith.index(out_elem_bytes) + if not bool(accumulate): + out_nbytes = x_rows * n_idx * arith.index(out_elem_bytes) + + sorted_rsrc = buffer_ops.create_buffer_resource( + arg_sorted_token_ids, max_size=False, num_records_bytes=sorted_nbytes + ) + eid_rsrc = buffer_ops.create_buffer_resource( + arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes + ) + x_rsrc = buffer_ops.create_buffer_resource( + arg_x, max_size=False, num_records_bytes=x_nbytes + ) + w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=True) + out_rsrc = buffer_ops.create_buffer_resource( + arg_out, max_size=False, num_records_bytes=out_nbytes + ) + sw_rsrc = buffer_ops.create_buffer_resource(arg_sorted_weights, max_size=True) + + eid_i32 = buffer_ops.buffer_load( + eid_rsrc, arith.index_cast(T.i32, by), vec_width=1, dtype=T.i32 + ) + eid_ok0 = arith.cmpi( + arith.CmpIPredicate.sge, eid_i32, arith.constant(0, type=T.i32) + ) + eid_ok1 = arith.cmpi( + arith.CmpIPredicate.slt, eid_i32, arith.constant(int(experts), type=T.i32) + ) + block_row_start = arith.index_cast(T.i32, by * arith.index(int(route_tile_m))) + block_in_valid = arith.cmpi( + arith.CmpIPredicate.slt, block_row_start, num_valid_i32 + ) + block_ok = arith.andi(block_in_valid, arith.andi(eid_ok0, eid_ok1)) + + layout_thr = _make_moe_wave_layout( + m_warp=m_warp, n_warp=n_warp, WAVE_SIZE=WAVE_SIZE, fx=fx + ) + thr_coord = idx2crd(tx, layout_thr) + wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( + fx.get(thr_coord, 0), + fx.get(thr_coord, 1), + fx.get(thr_coord, 2), + fx.get(thr_coord, 3), + ) + warp_m_base = wave_m_idx * arith.index(warp_tile_m) + warp_n_base = wave_n_idx * arith.index(warp_tile_n) + blk_n = bx * arith.index(int(tile_n)) + model_dim_i32 = arith.constant(int(model_dim), type=T.i32) + inter_dim_i32 = arith.constant(int(inter_dim), type=T.i32) + topk_i32 = arith.constant(int(topk), type=T.i32) + c2_i32 = arith.constant(2, type=T.i32) + base8 = lane_kgrp * fx.Index(8) + + base_ptr = alloc.get_base() + smem_a = SmemPtr(base_ptr, off_a, in_ir_ty, shape=(lds_a_elems,)) + smem_b = SmemPtr(base_ptr, off_b, in_ir_ty, shape=(lds_b_elems,)) + lds_a = get_op_result_or_value(smem_a.get()) + lds_b = get_op_result_or_value(smem_b.get()) + + def _wmma_op(result_type, a_vec, b_vec, acc): + if in_dtype == "bf16": + a_i16 = vector.bitcast(v8i16_ty, a_vec) + b_i16 = vector.bitcast(v8i16_ty, b_vec) + return rocdl.wmma_f32_16x16x16_bf16( + result_type, a_i16, b_i16, arith.unwrap(acc) + ).result + return rocdl.wmma_f32_16x16x16_f16( + result_type, arith.unwrap(a_vec), arith.unwrap(b_vec), arith.unwrap(acc) + ).result + + a_lds_info = [] + for al in range_constexpr(num_a_loads): + a_lin = tx * fx.Index(LOAD_VEC) + fx.Index(al * block_threads * LOAD_VEC) + a_load_row = a_lin // fx.Index(tile_k) + a_load_col = a_lin % fx.Index(tile_k) + lds_rel = a_load_row * fx.Index(lds_a_stride) + a_load_col + a_lds_info.append((a_load_row, a_load_col, lds_rel)) + + b_lds_info = [] + for bl in range_constexpr(num_b_loads): + b_lin = tx * fx.Index(LOAD_VEC) + fx.Index(bl * block_threads * LOAD_VEC) + b_load_row = b_lin // fx.Index(tile_k) + b_load_col = b_lin % fx.Index(tile_k) + lds_rel = b_load_row * fx.Index(lds_b_stride) + b_load_col + 
b_lds_info.append((b_load_row, b_load_col, lds_rel)) + + def _load_a_tile(k_base): + raw_data = [] + for al in range_constexpr(num_a_loads): + a_load_row, a_load_col, _ = a_lds_info[al] + sorted_row = by * arith.index(int(tile_m)) + a_load_row + row_i32 = arith.index_cast(T.i32, a_load_row) + sorted_i32 = arith.index_cast(T.i32, sorted_row) + row_in_route = arith.cmpi( + arith.CmpIPredicate.ult, + row_i32, + arith.constant(int(route_tile_m), type=T.i32), + ) + row_in_valid = arith.cmpi(arith.CmpIPredicate.slt, sorted_i32, num_valid_i32) + row_ok = arith.andi(row_in_route, row_in_valid) + sorted_safe = arith.select(row_ok, sorted_i32, block_row_start) + fused = buffer_ops.buffer_load( + sorted_rsrc, sorted_safe, vec_width=1, dtype=T.i32 + ) + tok = fused & arith.constant((1 << 24) - 1, type=T.i32) + slot = fused >> arith.constant(24, type=T.i32) + tok_ok = arith.cmpi(arith.CmpIPredicate.ult, tok, i32_tokens_in) + slot_ok0 = arith.cmpi( + arith.CmpIPredicate.sge, slot, arith.constant(0, type=T.i32) + ) + slot_ok1 = arith.cmpi(arith.CmpIPredicate.slt, slot, topk_i32) + ts_ok = arith.andi(tok_ok, arith.andi(slot_ok0, slot_ok1)) + load_ok = arith.andi(row_ok, ts_ok) + ts = tok * topk_i32 + slot + elem_off = ts * inter_dim_i32 + arith.index_cast(T.i32, k_base + a_load_col) + f32_off = elem_off // c2_i32 + raw_if = scf.IfOp(load_ok, results_=[v4f32_ty], has_else=True) + with ir.InsertionPoint(raw_if.then_block): + scf.YieldOp( + [buffer_ops.buffer_load(x_rsrc, f32_off, vec_width=4, dtype=T.f32)] + ) + with ir.InsertionPoint(raw_if.else_block): + scf.YieldOp([zero_raw]) + raw_data.append(raw_if.results[0]) + return raw_data + + def _load_b_tile(k_base): + raw_data = [] + base_row = eid_i32 * model_dim_i32 + arith.index_cast(T.i32, blk_n) + for bl in range_constexpr(num_b_loads): + b_load_row, b_load_col, _ = b_lds_info[bl] + row_i32 = base_row + arith.index_cast(T.i32, b_load_row) + elem_off = row_i32 * inter_dim_i32 + arith.index_cast(T.i32, k_base + b_load_col) + f32_off = elem_off // c2_i32 + raw_data.append( + buffer_ops.buffer_load(w_rsrc, f32_off, vec_width=4, dtype=T.f32) + ) + return raw_data + + def _store_a_tile(raw_data): + for al in range_constexpr(num_a_loads): + _, _, lds_rel = a_lds_info[al] + a_vec = vector.bitcast(v8_in_ty, raw_data[al]) + vector.store(a_vec, lds_a, [lds_rel]) + + def _store_b_tile(raw_data): + for bl in range_constexpr(num_b_loads): + _, _, lds_rel = b_lds_info[bl] + b_vec = vector.bitcast(v8_in_ty, raw_data[bl]) + vector.store(b_vec, lds_b, [lds_rel]) + + def _load_a_single_from_lds(rk, rm_val): + col_base = fx.Index(rk * WMMA_K) + base8 + row = warp_m_base + fx.Index(rm_val * WMMA_M) + lane16 + lds_idx = row * fx.Index(lds_a_stride) + col_base + return vector.load_op(v8_in_ty, lds_a, [lds_idx]) + + def _load_b_from_lds(rk): + vecs = [] + col_base = fx.Index(rk * WMMA_K) + base8 + for rn in range_constexpr(wmma_n_rep): + row = warp_n_base + fx.Index(rn * WMMA_N) + lane16 + lds_idx = row * fx.Index(lds_b_stride) + col_base + vecs.append(vector.load_op(v8_in_ty, lds_b, [lds_idx])) + return vecs + + def _do_compute_rk(accs_in, rk): + new_accs = list(accs_in) + b_vecs = _load_b_from_lds(rk) + for wm in range_constexpr(wmma_m_rep): + a_vec = _load_a_single_from_lds(rk, wm) + for wn in range_constexpr(wmma_n_rep): + idx = wm * wmma_n_rep + wn + new_accs[idx] = _wmma_op(v8f32_ty, a_vec, b_vecs[wn], new_accs[idx]) + return new_accs + + acc_zero = arith.constant_vector(0.0, v8f32_ty) + acc = [acc_zero] * (wmma_m_rep * wmma_n_rep) + + _if_blk = scf.IfOp(block_ok) + with 
ir.InsertionPoint(_if_blk.then_block): + for kt in range_constexpr(num_k_tiles): + k_base = fx.Index(kt * int(tile_k)) + a_data = _load_a_tile(k_base) + b_data = _load_b_tile(k_base) + _store_a_tile(a_data) + _store_b_tile(b_data) + gpu.barrier() + for ks in range_constexpr(k_wmma_steps): + acc = _do_compute_rk(acc, ks) + gpu.barrier() + + out_elem_ty = _moe_out_elem_ty(out_dtype, T) + + def _load_sub8(acc_idx, _vec_base): + return acc[acc_idx] + + _emit_stage2_store_epilogue( + sub_tiles=sub_tiles, + by=by, + tile_m=int(tile_m), + route_tile_m=int(route_tile_m), + warp_m_base=warp_m_base, + warp_n_base=warp_n_base, + blk_n=blk_n, + lane16=lane16, + lane_kgrp=lane_kgrp, + WMMA_N=WMMA_N, + i32_tokens_in=i32_tokens_in, + i32_n_in=i32_n_in, + topk=int(topk), + num_valid_i32=num_valid_i32, + block_row_start=block_row_start, + sorted_rsrc=sorted_rsrc, + tw_rsrc=sw_rsrc, + out_rsrc=out_rsrc, + doweight_stage2=bool(doweight_stage2), + accumulate=bool(accumulate), + out_elem_ty=out_elem_ty, + out_is_f32=bool(out_is_f32), + load_sub8=_load_sub8, + ir=ir, + fx=fx, + arith=arith, + buffer_ops=buffer_ops, + scf=scf, + vector=vector, + range_constexpr=range_constexpr, + rocdl=rocdl, + T=T, + ) + scf.YieldOp([]) + + @flyc.jit + def launch_stage2( + arg_out: fx.Tensor, + arg_x: fx.Tensor, + arg_w: fx.Tensor, + arg_scale_x: fx.Tensor, + arg_scale_w: fx.Tensor, + arg_sorted_token_ids: fx.Tensor, + arg_expert_ids: fx.Tensor, + arg_sorted_weights: fx.Tensor, + arg_num_valid_ids: fx.Tensor, + i32_tokens_in: fx.Int32, + i32_n_in: fx.Int32, + i32_k_in: fx.Int32, + i32_size_expert_ids_in: fx.Int32, + stream: fx.Stream, + ): + _ = i32_k_in + ctx = CompilationContext.get_current() + n_in = arith.index_cast(T.index, i32_n_in) + size_expert_ids_in = arith.index_cast(T.index, i32_size_expert_ids_in) + gx = (n_in + fx.Index(int(tile_n) - 1)) // fx.Index(int(tile_n)) + gy = size_expert_ids_in + launcher = moe_rdna4_stage2( + arg_out, + arg_x, + arg_w, + arg_scale_x, + arg_scale_w, + arg_sorted_token_ids, + arg_expert_ids, + arg_sorted_weights, + arg_num_valid_ids, + i32_tokens_in, + i32_n_in, + i32_k_in, + i32_size_expert_ids_in, + ) + _finalize_alloc_and_launch_2d( + ctx=ctx, + alloc=alloc, + launcher=launcher, + gx=gx, + gy=gy, + block_threads=block_threads, + stream=stream, + waves_per_eu=waves_per_eu, + ir=ir, + ) + + _set_expert_sched_hint(launch_stage2, expert_sched_mode) + return launch_stage2 + + +@functools.lru_cache(maxsize=1024) +def _compile_moe_gemm( + *, + stage: int, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n: int, + tile_k: int, + doweight: bool, + in_dtype: str = "fp16", + out_dtype: str = "f16", + accumulate: bool = True, + waves_per_eu: int | None = None, + expert_sched_mode: bool = True, +): + _require_gfx120x() + if waves_per_eu is not None and int(waves_per_eu) < 1: + raise ValueError(f"waves_per_eu must be >= 1, got {waves_per_eu!r}") + if in_dtype not in ("fp16", "bf16"): + raise ValueError( + f"Unsupported in_dtype for RDNA4 MoE stage{stage}: {in_dtype!r}; " + "expected 'fp16' or 'bf16'" + ) + + single_tile_m, single_tile_n, single_m_warp, single_n_warp = _pick_fp16_launch_shape( + int(tile_m), + int(tile_n), + int(tile_k), + max_total_warps=4, + ) + common = dict( + model_dim=int(model_dim), + inter_dim=int(inter_dim), + experts=int(experts), + topk=int(topk), + route_tile_m=int(tile_m), + tile_m=int(single_tile_m), + tile_n=int(single_tile_n), + tile_k=int(tile_k), + m_warp=int(single_m_warp), + n_warp=int(single_n_warp), + in_dtype=in_dtype, + 
out_dtype=out_dtype, + waves_per_eu=waves_per_eu, + expert_sched_mode=expert_sched_mode, + ) + + if stage == 1: + return _compile_stage1_kernel_impl( + doweight_stage1=bool(doweight), + **common, + ) + # RDNA4 stage2 with f16/bf16 atomic accumulate is implemented by running an + # internal f32-output kernel and casting back on the host; the LLVM backend + # cannot lower scalar half-atomics cleanly for this path. + out_s = str(out_dtype).strip().lower() + if bool(accumulate) and out_s in ("f16", "fp16", "half", "bf16", "bfloat16"): + f32_exe = _compile_stage2_kernel_impl( + doweight_stage2=bool(doweight), + accumulate=True, + out_dtype="f32", + **{k: v for k, v in common.items() if k != "out_dtype"}, + ) + return _MoeGemm2AtomicCastWrapper( + f32_exe, + model_dim=int(model_dim), + out_dtype_str=str(out_dtype), + ) + return _compile_stage2_kernel_impl( + doweight_stage2=bool(doweight), + accumulate=bool(accumulate), + **common, + ) + + +compile_moe_gemm1, compile_moe_gemm2, compile_moe_gemm2_ex = make_moe_public_api( + _compile_moe_gemm +) diff --git a/kernels/rdna_moe_gemm_2stage_common.py b/kernels/rdna_moe_gemm_2stage_common.py new file mode 100644 index 000000000..38e5d21b4 --- /dev/null +++ b/kernels/rdna_moe_gemm_2stage_common.py @@ -0,0 +1,370 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Shared utilities for RDNA4 (gfx120x) MoE 2-stage WMMA kernels. + +Helpers that the RDNA4 WMMA MoE path (``rdna_moe_gemm_2stage.py``) pulls in. +""" + +from __future__ import annotations + +from flydsl.runtime.device import get_rocm_arch as get_hip_arch + + +def _require_gfx120x() -> None: + arch = str(get_hip_arch()) + if not arch.startswith("gfx120"): + raise RuntimeError(f"Expected gfx120x (RDNA4) architecture, got {arch!r}") + + +def _align_up(v: int, a: int) -> int: + return ((int(v) + int(a) - 1) // int(a)) * int(a) + + +def _moe_out_elem_ty(out_dtype: str, T): + """RDNA4 MoE output element type mapping (f16, bf16, or f32).""" + + out_s = str(out_dtype).strip().lower() + if out_s in ("f16", "fp16", "half"): + return T.f16 + if out_s in ("bf16", "bfloat16"): + return T.bf16 + if out_s in ("f32", "fp32", "float"): + return T.f32 + raise ValueError(f"Unsupported out_dtype {out_dtype!r}") + + +def _make_moe_wave_layout(*, m_warp: int, n_warp: int, WAVE_SIZE: int, fx): + return fx.make_layout( + (int(m_warp), int(n_warp), 2, 16), + (int(n_warp) * WAVE_SIZE, WAVE_SIZE, 16, 1), + ) + + +def _make_wmma_sub_tiles( + *, wmma_m_rep: int, wmma_n_rep: int, WMMA_M: int, is_fp4: bool = False +) -> list: + sub_tiles = [] + for wm in range(wmma_m_rep): + for wn in range(wmma_n_rep): + if is_fp4: + for half in range(2): + sub_tiles.append( + (wm * wmma_n_rep + wn, half * 8, wm * WMMA_M, wn * 2 + half) + ) + else: + sub_tiles.append((wm * wmma_n_rep + wn, 0, wm * WMMA_M, wn)) + return sub_tiles + + +def _finalize_alloc_and_launch_2d( + *, + ctx, + alloc, + launcher, + gx, + gy, + block_threads: int, + stream, + waves_per_eu, + ir, +): + with ir.InsertionPoint(ctx.gpu_module_body): + alloc.finalized = False + alloc.finalize() + for op in ctx.gpu_module_body.operations: + if hasattr(op, "attributes") and op.OPERATION_NAME == "gpu.func": + if waves_per_eu is not None and int(waves_per_eu) >= 1: + op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get( + ir.IntegerType.get_signless(32), int(waves_per_eu) + ) + launcher.launch( + grid=(gx, gy, 1), + block=(block_threads, 1, 1), + stream=stream, + ) + + +def _fp16_tile_lds_bytes( + tile_m: int, + tile_n: int, + 
tile_k: int, + *, + num_b_tiles: int = 1, + lds_pad_a: int = 8, + lds_pad_b: int = 8, + elem_bytes: int = 2, +) -> int: + """Estimate LDS bytes for the RDNA4 stage1/stage2 WMMA tile layout.""" + + lds_a_stride = int(tile_k) + int(lds_pad_a) + lds_b_stride = int(tile_k) + int(lds_pad_b) + lds_a_elems = int(tile_m) * lds_a_stride + int(lds_pad_a) + lds_b_elems = int(tile_n) * lds_b_stride + int(lds_pad_b) + return (lds_a_elems + int(num_b_tiles) * lds_b_elems) * int(elem_bytes) + + +def _pick_fp16_launch_shape( + route_tile_m: int, + route_tile_n: int, + tile_k: int, + *, + max_total_warps: int = 4, + lds_budget_bytes: int = 60 * 1024, +) -> tuple[int, int, int, int]: + """Pick a legal launch shape for the RDNA4 fp16/bf16 WMMA MoE path. + + The returned tuple is ``(tile_m, tile_n, m_warp, n_warp)``. + """ + + tile_m = _align_up(int(route_tile_m), 16) + tile_n = _align_up(int(route_tile_n), 16) + + lds_bytes = _fp16_tile_lds_bytes(tile_m, tile_n, int(tile_k)) + if lds_bytes > int(lds_budget_bytes): + raise ValueError( + f"RDNA4 MoE LDS budget exceeded for tile=({tile_m},{tile_n},{tile_k}): " + f"{lds_bytes} bytes > {lds_budget_bytes} bytes" + ) + + preferred = ( + (2, 2), + (1, 4), + (4, 1), + (2, 1), + (1, 2), + (1, 1), + ) + for mw, nw in preferred: + if mw * nw > int(max_total_warps): + continue + if tile_m % mw != 0 or tile_n % nw != 0: + continue + if (tile_m // mw) % 16 != 0 or (tile_n // nw) % 16 != 0: + continue + return tile_m, tile_n, mw, nw + + raise ValueError( + "Cannot find legal RDNA4 WMMA launch shape for " + f"tile_m={route_tile_m}, tile_n={route_tile_n}, tile_k={tile_k}" + ) + + +def _emit_stage1_gate_up_epilogue( + *, + sub_tiles, + by, + tile_m: int, + route_tile_m: int, + warp_m_base, + warp_n_base, + blk_n, + lane16, + lane_kgrp, + WMMA_N: int, + i32_tokens_in, + i32_inter_in, + topk: int, + num_valid_i32=None, + block_row_start=None, + sorted_rsrc, + tw_rsrc, + out_rsrc, + doweight_stage1: bool, + out_elem_ty, + load_gate_up_sub8, + silu_fn, + ir, + fx, + arith, + buffer_ops, + scf, + vector, + range_constexpr, + T, +): + """RDNA4 stage1 gate/up epilogue (WMMA_K=16 accumulator lane layout).""" + + c_topk_i32 = arith.constant(int(topk), type=T.i32) + default_block_row_start = arith.index_cast(T.i32, by * arith.index(int(route_tile_m))) + row_base_i32 = block_row_start if block_row_start is not None else default_block_row_start + for acc_idx, vec_base, m_off, wn in sub_tiles: + sub8g, sub8u = load_gate_up_sub8(acc_idx, vec_base) + col = blk_n + warp_n_base + fx.Index(wn * WMMA_N) + lane16 + col_i32 = arith.index_cast(T.i32, col) + col_ok = arith.cmpi(arith.CmpIPredicate.ult, col_i32, i32_inter_in) + for vi in range_constexpr(8): + row_local = warp_m_base + fx.Index(m_off) + lane_kgrp * fx.Index(8) + fx.Index(vi) + sorted_row = by * arith.index(int(tile_m)) + row_local + row_i32 = arith.index_cast(T.i32, row_local) + sorted_i32 = arith.index_cast(T.i32, sorted_row) + row_in_route = arith.cmpi( + arith.CmpIPredicate.ult, + row_i32, + arith.constant(int(route_tile_m), type=T.i32), + ) + if num_valid_i32 is None: + row_ok_meta = row_in_route + else: + row_in_valid = arith.cmpi(arith.CmpIPredicate.slt, sorted_i32, num_valid_i32) + row_ok_meta = arith.andi(row_in_route, row_in_valid) + sorted_safe = arith.select(row_ok_meta, sorted_i32, row_base_i32) + fused = buffer_ops.buffer_load(sorted_rsrc, sorted_safe, vec_width=1, dtype=T.i32) + tok = fused & arith.constant((1 << 24) - 1, type=T.i32) + slot = fused >> arith.constant(24, type=T.i32) + tok_ok = 
arith.cmpi(arith.CmpIPredicate.ult, tok, i32_tokens_in) + slot_ok0 = arith.cmpi(arith.CmpIPredicate.sge, slot, arith.constant(0, type=T.i32)) + slot_ok1 = arith.cmpi( + arith.CmpIPredicate.slt, slot, arith.constant(int(topk), type=T.i32) + ) + row_ok = arith.andi( + row_ok_meta, arith.andi(tok_ok, arith.andi(slot_ok0, slot_ok1)) + ) + tw = ( + buffer_ops.buffer_load(tw_rsrc, sorted_safe, vec_width=1, dtype=T.f32) + if bool(doweight_stage1) + else arith.constant(1.0, type=T.f32) + ) + out_ok = arith.andi(row_ok, col_ok) + _if_out = scf.IfOp(out_ok) + with ir.InsertionPoint(_if_out.then_block): + vg = vector.extract(sub8g, static_position=[vi], dynamic_position=[]) + vu = vector.extract(sub8u, static_position=[vi], dynamic_position=[]) + y = silu_fn(vg) * vu + if bool(doweight_stage1): + y = y * tw + out_v = arith.trunc_f(out_elem_ty, y) + out_idx = (tok * c_topk_i32 + slot) * i32_inter_in + col_i32 + buffer_ops.buffer_store(out_v, out_rsrc, out_idx) + scf.YieldOp([]) + + +def _emit_stage2_store_epilogue( + *, + sub_tiles, + by, + tile_m: int, + route_tile_m: int, + warp_m_base, + warp_n_base, + blk_n, + lane16, + lane_kgrp, + WMMA_N: int, + i32_tokens_in, + i32_n_in, + topk: int, + num_valid_i32, + block_row_start, + sorted_rsrc, + tw_rsrc, + out_rsrc, + doweight_stage2: bool, + accumulate: bool, + out_elem_ty, + out_is_f32: bool, + load_sub8, + ir, + fx, + arith, + buffer_ops, + scf, + vector, + range_constexpr, + rocdl, + T, +): + """RDNA4 stage2 store epilogue (WMMA_K=16 accumulator lane layout).""" + + c_topk_i32 = arith.constant(int(topk), type=T.i32) + c4_i32 = arith.constant(4, type=T.i32) + zero_i32 = arith.constant(0, type=T.i32) + + for acc_idx, vec_base, m_off, wn in sub_tiles: + sub8 = load_sub8(acc_idx, vec_base) + col = blk_n + warp_n_base + fx.Index(wn * WMMA_N) + lane16 + col_i32 = arith.index_cast(T.i32, col) + col_ok = arith.cmpi(arith.CmpIPredicate.ult, col_i32, i32_n_in) + if bool(accumulate): + for vi in range_constexpr(8): + row_local = warp_m_base + fx.Index(m_off) + lane_kgrp * fx.Index(8) + fx.Index(vi) + sorted_row = by * arith.index(int(tile_m)) + row_local + row_i32 = arith.index_cast(T.i32, row_local) + sorted_i32 = arith.index_cast(T.i32, sorted_row) + row_in_route = arith.cmpi( + arith.CmpIPredicate.ult, + row_i32, + arith.constant(int(route_tile_m), type=T.i32), + ) + row_in_valid = arith.cmpi(arith.CmpIPredicate.slt, sorted_i32, num_valid_i32) + row_ok = arith.andi(row_in_route, row_in_valid) + sorted_safe = arith.select(row_ok, sorted_i32, block_row_start) + fused = buffer_ops.buffer_load(sorted_rsrc, sorted_safe, vec_width=1, dtype=T.i32) + tok = fused & arith.constant((1 << 24) - 1, type=T.i32) + slot = fused >> arith.constant(24, type=T.i32) + tok_ok = arith.cmpi(arith.CmpIPredicate.ult, tok, i32_tokens_in) + slot_ok0 = arith.cmpi(arith.CmpIPredicate.sge, slot, arith.constant(0, type=T.i32)) + slot_ok1 = arith.cmpi(arith.CmpIPredicate.slt, slot, c_topk_i32) + row_store_ok = arith.andi( + row_ok, arith.andi(tok_ok, arith.andi(slot_ok0, slot_ok1)) + ) + tw = ( + buffer_ops.buffer_load(tw_rsrc, sorted_safe, vec_width=1, dtype=T.f32) + if bool(doweight_stage2) + else arith.constant(1.0, type=T.f32) + ) + out_ok = arith.andi(row_store_ok, col_ok) + _if_out = scf.IfOp(out_ok) + with ir.InsertionPoint(_if_out.then_block): + v = vector.extract(sub8, static_position=[vi], dynamic_position=[]) + if bool(doweight_stage2): + v = v * tw + out_idx = tok * i32_n_in + col_i32 + byte_off = out_idx * (c4_i32 if bool(out_is_f32) else arith.constant(2, type=T.i32)) + if 
bool(out_is_f32): + rocdl.raw_ptr_buffer_atomic_fadd(v, out_rsrc, byte_off, zero_i32, zero_i32) + else: + out_v = arith.trunc_f(out_elem_ty, v) + rocdl.raw_ptr_buffer_atomic_fadd( + out_v, out_rsrc, byte_off, zero_i32, zero_i32 + ) + scf.YieldOp([]) + else: + for vi in range_constexpr(8): + row_local = warp_m_base + fx.Index(m_off) + lane_kgrp * fx.Index(8) + fx.Index(vi) + sorted_row = by * arith.index(int(tile_m)) + row_local + row_i32 = arith.index_cast(T.i32, row_local) + sorted_i32 = arith.index_cast(T.i32, sorted_row) + row_in_route = arith.cmpi( + arith.CmpIPredicate.ult, + row_i32, + arith.constant(int(route_tile_m), type=T.i32), + ) + row_in_valid = arith.cmpi(arith.CmpIPredicate.slt, sorted_i32, num_valid_i32) + row_ok = arith.andi(row_in_route, row_in_valid) + sorted_safe = arith.select(row_ok, sorted_i32, block_row_start) + fused = buffer_ops.buffer_load(sorted_rsrc, sorted_safe, vec_width=1, dtype=T.i32) + tok = fused & arith.constant((1 << 24) - 1, type=T.i32) + slot = fused >> arith.constant(24, type=T.i32) + tok_ok = arith.cmpi(arith.CmpIPredicate.ult, tok, i32_tokens_in) + slot_ok0 = arith.cmpi(arith.CmpIPredicate.sge, slot, arith.constant(0, type=T.i32)) + slot_ok1 = arith.cmpi(arith.CmpIPredicate.slt, slot, c_topk_i32) + row_store_ok = arith.andi( + row_ok, arith.andi(tok_ok, arith.andi(slot_ok0, slot_ok1)) + ) + ts = tok * c_topk_i32 + slot + tw = ( + buffer_ops.buffer_load(tw_rsrc, sorted_safe, vec_width=1, dtype=T.f32) + if bool(doweight_stage2) + else arith.constant(1.0, type=T.f32) + ) + out_ok = arith.andi(row_store_ok, col_ok) + _if_out = scf.IfOp(out_ok) + with ir.InsertionPoint(_if_out.then_block): + v = vector.extract(sub8, static_position=[vi], dynamic_position=[]) + if bool(doweight_stage2): + v = v * tw + out_idx = ts * i32_n_in + col_i32 + out_v = arith.trunc_f(out_elem_ty, v) + buffer_ops.buffer_store(out_v, out_rsrc, out_idx) + scf.YieldOp([]) diff --git a/scripts/run_benchmark.sh b/scripts/run_benchmark.sh index c8da01507..b15808b22 100755 --- a/scripts/run_benchmark.sh +++ b/scripts/run_benchmark.sh @@ -202,6 +202,47 @@ _emit_row() { printf "%-14.14s %-34.34s %-10.10s %10s %10s\n" "${op}" "${shape}" "${dtype}" "${tbps}" "${tflops}" } +_emit_rdna4_moe_rows_from_log() { + log="$1" + [ -f "${log}" ] || return 0 + awk ' + function trim(s) { + sub(/^[ \t]+/, "", s) + sub(/[ \t]+$/, "", s) + return s + } + /\|/ && $0 ~ /dim=/ && $0 ~ /inter=/ { + n = split($0, parts, /\|/) + if (n >= 2) { + cfg = trim(parts[1]) + dtype = trim(parts[2]) + } + next + } + /Stage 1/ { stage = "moe_rdna4_s1"; next } + /Stage 2 atomic/ { stage = "moe_rdna4_s2a"; next } + /Stage 2 reduce/ { stage = "moe_rdna4_s2r"; next } + /^[[:space:]]*[0-9]+[[:space:]]+/ { + if (stage == "" || cfg == "" || dtype == "") { + next + } + tokens = $1 + m_eff = $2 + tflops = $4 + tbps = $5 + status = $7 + if (status == "" || status == "FAIL") { + next + } + shape = cfg "_t" tokens "_m" m_eff + printf "%s\t%s\t%s\t%s\t%s\n", stage, shape, dtype, tbps, tflops + } + ' "${log}" | while IFS="$(printf '\t')" read -r op shape dtype tbps tflops; do + [ -n "${op}" ] || continue + _emit_row "${op}" "${shape}" "${dtype}" "${tbps}" "${tflops}" + done +} + _normalize_op() { # Normalize aliases to canonical op names. 
op="${1:-}" @@ -789,6 +830,30 @@ if [ "${RUN_MOE}" -eq 1 ] && [ "${IS_CDNA}" = "true" ]; then done fi +# RDNA4 MoE (gfx120x only — uses WMMA fp16/bf16) +if [ "${RUN_MOE}" -eq 1 ] && [ "${IS_RDNA4}" = "true" ]; then + echo "" + echo "========================================================================" + echo "RDNA4 MoE Benchmarks" + echo "========================================================================" + log="${BENCH_LOG_DIR}/moe_rdna4.log" + if python3 tests/kernels/test_moe_gemm_rdna4.py \ + --bench \ + --bench-dtype fp16,bf16 \ + --bench-warmup 5 \ + --bench-iters 20 \ + --bench-no-ref \ + >"${log}" 2>&1; then + SUCCESS_COUNT=$((SUCCESS_COUNT + 1)) + cat "${log}" + _emit_rdna4_moe_rows_from_log "${log}" + else + FAIL_COUNT=$((FAIL_COUNT + 1)) + echo "moe rdna4 failed. Log: ${log}" >&2 + _show_fail_log "${log}" "moe_rdna4" + fi +fi + # RDNA4 WMMA GEMM benchmarks (via benchmark_common.py) if [ "${IS_RDNA4}" = "true" ]; then echo "" diff --git a/tests/kernels/benchmark_common.py b/tests/kernels/benchmark_common.py index e30cee712..ff50feae5 100644 --- a/tests/kernels/benchmark_common.py +++ b/tests/kernels/benchmark_common.py @@ -800,8 +800,13 @@ def moe_bench_main( ) check_ref = not args.bench_no_ref + try: + from flydsl.runtime.device import get_rocm_arch as _get_rocm_arch + _arch_label = str(_get_rocm_arch()) + except Exception: + _arch_label = "rocm" print("=" * 110) - print(" AMD gfx1250 MOE GEMM Kernel Performance Benchmark") + print(f" AMD {_arch_label} MoE GEMM Kernel Performance Benchmark") print(f" PyTorch {torch.__version__}") print(f" Device: {torch.cuda.get_device_name(0)}") props = torch.cuda.get_device_properties(0) diff --git a/tests/kernels/moe_test_utils.py b/tests/kernels/moe_test_utils.py new file mode 100644 index 000000000..8339b3d77 --- /dev/null +++ b/tests/kernels/moe_test_utils.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Shared test helpers for MoE kernel harnesses.""" + +from __future__ import annotations + +import os +from typing import Optional, Tuple + +import torch + +# Optional: use aiter's exact routing/sorting implementation (matches +# `aiter/op_tests/test_moe_2stage.py`). Some environments ship aiter python but +# miss required JIT .so dependencies; we fall back gracefully. 
+try: + import aiter # noqa: F401 + from aiter.fused_moe import moe_sorting as aiter_moe_sorting + + HAS_AITER = True +except Exception: + HAS_AITER = False + aiter_moe_sorting = None + + +RoutingBuffers = Tuple[ + torch.Tensor, # sorted_token_ids + torch.Tensor, # sorted_weights + torch.Tensor, # sorted_expert_ids + torch.Tensor, # num_valid_ids (shape [1], i32) + int, # sorted_size + int, # blocks +] + + +def moe_sorting_torch_native( + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + *, + num_experts: int, + block_size: int, + expert_mask: Optional[torch.Tensor] = None, + num_local_tokens: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Torch reference for aiter's moe_sorting.""" + + assert topk_ids.is_cuda and topk_weights.is_cuda + device = topk_ids.device + tokens, topk = topk_ids.shape + topk = topk_ids.shape[1] + + max_num_tokens_padded = int( + topk_ids.numel() + int(num_experts) * int(block_size) - int(topk) + ) + max_num_m_blocks = int( + (max_num_tokens_padded + int(block_size) - 1) // int(block_size) + ) + + init_val = (int(topk) << 24) | int(tokens) + sorted_ids = torch.full( + (max_num_tokens_padded,), init_val, dtype=torch.int32, device=device + ) + sorted_weights = torch.empty( + (max_num_tokens_padded,), dtype=torch.float32, device=device + ) + sorted_expert_ids = torch.full( + (max_num_m_blocks,), -1, dtype=torch.int32, device=device + ) + num_tokens_post_pad = torch.empty((2,), dtype=torch.int32, device=device) + + if num_local_tokens is not None: + topk_ids = topk_ids[: num_local_tokens.item()] + + sorted_ids_begin = 0 + sorted_expert_ids_begin = 0 + skip_expert_num = 0 + for expert_id in range(int(num_experts)): + if expert_mask is not None and int(expert_mask[expert_id].item()) == 0: + skip_expert_num += 1 + continue + token_id, topk_id = torch.where(topk_ids == expert_id) + tokens_num = int(token_id.numel()) + sorted_expert_ids_num = int( + (tokens_num + int(block_size) - 1) // int(block_size) + ) + tokens_num_pad = int(sorted_expert_ids_num * int(block_size)) + sorted_ids[sorted_ids_begin : sorted_ids_begin + tokens_num] = ( + (topk_id.to(torch.int32) << 24) | token_id.to(torch.int32) + ) + sorted_weights[sorted_ids_begin : sorted_ids_begin + tokens_num] = ( + topk_weights[token_id, topk_id].to(torch.float32) + ) + sorted_ids_begin = int(sorted_ids_begin + tokens_num_pad) + sorted_expert_ids[ + sorted_expert_ids_begin : sorted_expert_ids_begin + sorted_expert_ids_num + ] = int(expert_id - skip_expert_num) + sorted_expert_ids_begin = int( + sorted_expert_ids_begin + sorted_expert_ids_num + ) + + num_tokens_post_pad[0] = int(sorted_ids_begin) + num_tokens_post_pad[1] = int(topk_ids.shape[0]) + return sorted_ids, sorted_weights, sorted_expert_ids, num_tokens_post_pad + + +def _maybe_aiter_moe_sorting( + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + *, + num_experts: int, + model_dim: int, + block_m: int, +) -> Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: + """Return (sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids) or None.""" + + if not HAS_AITER or aiter_moe_sorting is None: + return None + try: + topk_ids_i32 = topk_ids.to(torch.int32) + topk_w_f32 = topk_weights.to(torch.float32) + sorted_ids, sorted_w, sorted_expert_ids, num_valid_ids, _moe_buf = ( + aiter_moe_sorting( + topk_ids_i32, + topk_w_f32, + num_experts, + model_dim, + torch.float16, + block_m, + ) + ) + if num_valid_ids.numel() > 1: + num_valid_ids = num_valid_ids[:1].contiguous() + return 
sorted_ids, sorted_w, sorted_expert_ids, num_valid_ids + except Exception: + return None + + +def build_routing_buffers( + *, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + experts: int, + tile_m: int, + model_dim: Optional[int] = None, + moe_sort_mode: Optional[str] = None, + expert_mask: Optional[torch.Tensor] = None, + num_local_tokens: Optional[torch.Tensor] = None, +) -> RoutingBuffers: + """Build routing buffers once, reusable across stage1 + stage2.""" + + default_mode = "aiter" if HAS_AITER else "torch" + sort_mode = str( + moe_sort_mode or os.environ.get("flydsl_MOE_SORT_MODE", default_mode) + ).lower().strip() + if sort_mode not in ("aiter", "torch"): + raise ValueError( + f"invalid moe_sort_mode={sort_mode!r} (expected 'aiter' or 'torch')" + ) + + if sort_mode == "torch": + ( + sorted_token_ids, + sorted_weights, + sorted_expert_ids, + num_tokens_post_pad, + ) = moe_sorting_torch_native( + topk_ids=topk_ids.to(torch.int32), + topk_weights=topk_weights.to(torch.float32), + num_experts=int(experts), + block_size=int(tile_m), + expert_mask=expert_mask, + num_local_tokens=num_local_tokens, + ) + num_valid_ids = num_tokens_post_pad[:1].contiguous() + sorted_size = int(sorted_token_ids.numel()) + blocks = int(sorted_expert_ids.numel()) + return ( + sorted_token_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + sorted_size, + blocks, + ) + + if not HAS_AITER: + raise RuntimeError( + "aiter is not available; cannot build routing buffers (moe_sort_mode='aiter')." + ) + if model_dim is None: + raise ValueError("model_dim is required when moe_sort_mode='aiter'") + + res = _maybe_aiter_moe_sorting( + topk_ids, + topk_weights, + num_experts=experts, + model_dim=model_dim, + block_m=tile_m, + ) + if res is None: + raise RuntimeError("aiter moe_sorting failed/unavailable; cannot build routing buffers.") + sorted_token_ids, sorted_weights, sorted_expert_ids, num_valid_ids = res + sorted_token_ids = sorted_token_ids.contiguous() + sorted_weights = sorted_weights.contiguous() + sorted_expert_ids = sorted_expert_ids.contiguous() + sorted_size = int(sorted_token_ids.numel()) + blocks = int(sorted_expert_ids.numel()) + return ( + sorted_token_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + sorted_size, + blocks, + ) + + +def get_topk_valid_mask( + topk_ids: torch.Tensor, + expert_mask: Optional[torch.Tensor] = None, + *, + dtype: torch.dtype = torch.int8, +) -> torch.Tensor: + """Build valid_mask [tokens, topk] for optional EP-style masking.""" + + if expert_mask is None: + return torch.ones(topk_ids.shape, dtype=dtype, device=topk_ids.device) + return expert_mask[topk_ids].to(dtype) diff --git a/tests/kernels/test_moe_gemm_rdna4.py b/tests/kernels/test_moe_gemm_rdna4.py new file mode 100644 index 000000000..8f2b9376e --- /dev/null +++ b/tests/kernels/test_moe_gemm_rdna4.py @@ -0,0 +1,933 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""MoE GEMM tests for RDNA4 (gfx120x) WMMA fp16/bf16 kernels.""" + +from __future__ import annotations + +import math +import os +import sys +from typing import Optional + +import pytest +import torch + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +_PYTHON_CANDIDATES = [ + os.path.join(_REPO_ROOT, "build", "python_packages"), + _REPO_ROOT, +] +for _p in reversed(_PYTHON_CANDIDATES): + if os.path.isdir(_p) and _p not in sys.path: + sys.path.insert(0, _p) + +from tests.kernels.test_ref import torch_moe_gemm1, torch_moe_gemm2 
+from tests.kernels.moe_test_utils import ( + RoutingBuffers, + build_routing_buffers, + get_topk_valid_mask, +) +from tests.test_common import verify_output, run_perftest +from flydsl.runtime.device import get_rocm_arch +from kernels.moe_gemm_2stage import MoeGemm2Mode +from kernels.rdna_moe_gemm_2stage import ( + compile_moe_gemm1, + compile_moe_gemm2, + compile_moe_gemm2_ex, +) + +ARCH = str(get_rocm_arch()) + +pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower] + +if not torch.cuda.is_available(): + pytest.skip("CUDA/ROCm not available. Skipping GPU tests.", allow_module_level=True) + +if not ARCH.startswith("gfx120"): + pytest.skip(f"RDNA4 MoE tests require gfx120x, got {ARCH}", allow_module_level=True) + + +def _make_reduce_mode_compile_fn(use_valid_mask: bool = False): + def _compile( + *, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n: int, + tile_k: int, + doweight_stage2: bool, + in_dtype: str = "fp16", + group_size: int = -1, + out_dtype: str = "f16", + waves_per_eu: Optional[int] = None, + expert_sched_mode: bool = True, + ): + _ = group_size + return compile_moe_gemm2_ex( + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage2=doweight_stage2, + in_dtype=in_dtype, + out_dtype=out_dtype, + waves_per_eu=waves_per_eu, + valid_mask=(True if bool(use_valid_mask) else None), + mode=MoeGemm2Mode.REDUCE, + zero_intermediate=True, + expert_sched_mode=bool(expert_sched_mode), + ) + + return _compile + + +def _make_inputs( + *, + tokens: int, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + even_dispatch: bool = False, + seed: int = 0, + init_scale: float = 0.2, +): + device = torch.device("cuda") + torch.manual_seed(int(seed)) + + s = float(init_scale) + x_fp32 = torch.randn((tokens, model_dim), device=device, dtype=torch.float32) * s + w1_fp32 = torch.randn( + (experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32 + ) * s + w2_fp32 = torch.randn( + (experts, model_dim, inter_dim), device=device, dtype=torch.float32 + ) * (s / math.sqrt(inter_dim)) + + if even_dispatch: + topk_ids = torch.stack( + [ + torch.arange(topk, device=device, dtype=torch.int32) + ((t * topk) % experts) + for t in range(tokens) + ] + ) % experts + topk_weights = torch.full( + (tokens, topk), 1.0 / topk, device=device, dtype=torch.float32 + ) + else: + score = torch.rand((tokens, experts), device=device, dtype=torch.float32) + topk_vals, topk_ids = torch.topk(score, k=topk, dim=1) + topk_weights = torch.softmax(topk_vals, dim=1).to(torch.float32) + + return x_fp32, w1_fp32, w2_fp32, topk_ids, topk_weights + + +def _make_partial_valid_mask(topk_ids: torch.Tensor, experts: int) -> torch.Tensor: + expert_mask = torch.ones((experts,), device=topk_ids.device, dtype=torch.uint8) + expert_mask[1::2] = 0 + valid_mask = get_topk_valid_mask( + topk_ids, + expert_mask=expert_mask, + dtype=torch.uint8, + ).contiguous() + num_valid = int(valid_mask.sum().item()) + if num_valid <= 0 or num_valid >= valid_mask.numel(): + raise ValueError("expected partial valid_mask with both zero and non-zero entries") + return valid_mask + + +def run_moe_stage1( + *, + tokens: int, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n: int, + tile_k: int, + doweight_stage1: bool, + in_dtype: str = "fp16", + seed: int = 0, + num_iters: int = 3, + num_warmup: int = 1, + test_graph: bool = False, + waves_per_eu: Optional[int] = None, + 
expert_sched_mode: bool = True, + even_dispatch: bool = False, + x_fp32_in: Optional[torch.Tensor] = None, + w1_fp32_in: Optional[torch.Tensor] = None, + w2_fp32_in: Optional[torch.Tensor] = None, + topk_ids_in: Optional[torch.Tensor] = None, + topk_weights_in: Optional[torch.Tensor] = None, + routing_in: Optional[RoutingBuffers] = None, + return_outputs: bool = False, + skip_ref: bool = False, + # Accepted for moe_bench_main compatibility; RDNA4 has no TDM hardware. + use_tdm_store: bool = False, + inst_prefetch: bool = False, + wave_specialized_tdm: bool = False, +): + _ = (w2_fp32_in, use_tdm_store, inst_prefetch, wave_specialized_tdm) + assert model_dim % tile_k == 0 + assert inter_dim % tile_n == 0 + + x_fp32, w1_fp32, _, topk_ids, topk_weights = _make_inputs( + tokens=tokens, + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + even_dispatch=even_dispatch, + seed=seed, + ) + if x_fp32_in is not None: + x_fp32 = x_fp32_in + if w1_fp32_in is not None: + w1_fp32 = w1_fp32_in + if topk_ids_in is not None: + topk_ids = topk_ids_in + if topk_weights_in is not None: + topk_weights = topk_weights_in + + routing = routing_in or build_routing_buffers( + topk_ids=topk_ids, + topk_weights=topk_weights, + experts=experts, + tile_m=tile_m, + model_dim=model_dim, + moe_sort_mode="torch", + ) + ( + sorted_token_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + _sorted_size, + blocks, + ) = routing + + cast = torch.float16 if in_dtype == "fp16" else torch.bfloat16 + stage1_out_dtype = "f16" if in_dtype == "fp16" else "bf16" + out_torch_dtype = cast + + x_q = x_fp32.to(cast).contiguous() + w1_q = w1_fp32.to(cast) + w1_q_flat = w1_q.view(experts * (2 * inter_dim), model_dim).contiguous() + + out = torch.zeros((tokens, topk, inter_dim), device=x_q.device, dtype=out_torch_dtype) + scale_x_1d = torch.empty((0,), device=x_q.device, dtype=torch.float32) + scale_w1_1d = torch.empty((0,), device=x_q.device, dtype=torch.float32) + sorted_weights_1d = sorted_weights.contiguous().view(-1) + + exe = compile_moe_gemm1( + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + in_dtype=in_dtype, + out_dtype=stage1_out_dtype, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage1=bool(doweight_stage1), + waves_per_eu=waves_per_eu, + expert_sched_mode=bool(expert_sched_mode), + ) + + def launch(o, x, w, sx, sw, st, eids, sw_sorted): + stream = torch.cuda.current_stream() + exe( + o, + x, + w, + sx, + sw, + st, + eids, + sw_sorted, + num_valid_ids, + tokens, + inter_dim, + model_dim, + int(blocks), + stream, + ) + + _, us = run_perftest( + launch, + out, + x_q, + w1_q_flat, + scale_x_1d, + scale_w1_1d, + sorted_token_ids, + sorted_expert_ids, + sorted_weights_1d, + num_iters=int(num_iters), + num_warmup=int(num_warmup), + testGraph=test_graph, + ) + torch.cuda.synchronize() + + if not bool(skip_ref): + ref = torch_moe_gemm1( + x_q, + w1_q_flat, + None, + None, + topk_ids.to(torch.int64), + topk_weights, + inter_dim=inter_dim, + doweight_stage1=bool(doweight_stage1), + ) + assert verify_output(out.to(torch.float32), ref, rtol=0.25, atol=0.25) + + if return_outputs: + return out, us + return None + + +def run_moe_stage2( + *, + tokens: int, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n: int, + tile_k: int, + doweight_stage1: bool, + in_dtype: str = "fp16", + out_dtype: str = "f16", + seed: int = 0, + num_iters: int = 3, + num_warmup: int = 1, + test_graph: bool = False, + waves_per_eu: Optional[int] = 
None, + expert_sched_mode: bool = True, + even_dispatch: bool = False, + use_reduce: bool = False, + use_valid_mask: bool = False, + valid_mask_in: Optional[torch.Tensor] = None, + x_fp32_in: Optional[torch.Tensor] = None, + w1_fp32_in: Optional[torch.Tensor] = None, + w2_fp32_in: Optional[torch.Tensor] = None, + topk_ids_in: Optional[torch.Tensor] = None, + topk_weights_in: Optional[torch.Tensor] = None, + routing_in: Optional[RoutingBuffers] = None, + a2_in: Optional[torch.Tensor] = None, + # ``moe_bench_main`` uses the gfx1250 API naming; alias to ``a2_in`` and + # drop ``a2_scale_in`` because RDNA4 fp16/bf16 MoE has no A2 scale. + a2_fp8_in: Optional[torch.Tensor] = None, + a2_scale_in: Optional[torch.Tensor] = None, + return_outputs: bool = False, + skip_ref: bool = False, + # Accepted for moe_bench_main compatibility; RDNA4 has no TDM hardware. + use_tdm_store: bool = False, + inst_prefetch: bool = False, + wave_specialized_tdm: bool = False, +): + _ = (a2_scale_in, use_tdm_store, inst_prefetch, wave_specialized_tdm) + if valid_mask_in is not None and (not bool(use_reduce) or not bool(use_valid_mask)): + raise ValueError("valid_mask_in requires use_reduce=True and use_valid_mask=True") + if a2_in is None and a2_fp8_in is not None: + a2_in = a2_fp8_in + if model_dim % tile_n != 0: + raise ValueError( + f"Invalid stage2 tiling: model_dim ({model_dim}) must be divisible by tile_n ({tile_n})." + ) + if inter_dim % tile_k != 0: + raise ValueError( + f"Invalid stage2 tiling: inter_dim ({inter_dim}) must be divisible by tile_k ({tile_k})." + ) + + x_fp32, w1_fp32, w2_fp32, topk_ids, topk_weights = _make_inputs( + tokens=tokens, + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + even_dispatch=even_dispatch, + seed=seed, + ) + if x_fp32_in is not None: + x_fp32 = x_fp32_in + if w1_fp32_in is not None: + w1_fp32 = w1_fp32_in + if w2_fp32_in is not None: + w2_fp32 = w2_fp32_in + if topk_ids_in is not None: + topk_ids = topk_ids_in + if topk_weights_in is not None: + topk_weights = topk_weights_in + + routing = routing_in or build_routing_buffers( + topk_ids=topk_ids, + topk_weights=topk_weights, + experts=experts, + tile_m=tile_m, + model_dim=model_dim, + moe_sort_mode="torch", + ) + ( + sorted_token_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + _sorted_size, + blocks, + ) = routing + + cast = torch.float16 if in_dtype == "fp16" else torch.bfloat16 + w2_q = w2_fp32.to(cast).contiguous() + w2_kernel = w2_q.view(experts * model_dim, inter_dim).contiguous().view(-1) + + if a2_in is not None: + a2_q = a2_in.contiguous() + else: + out1_ref = torch_moe_gemm1( + x_fp32.to(cast), + w1_fp32.to(cast).view(experts * (2 * inter_dim), model_dim).contiguous(), + None, + None, + topk_ids.to(torch.int64), + topk_weights, + inter_dim=inter_dim, + doweight_stage1=bool(doweight_stage1), + ) + a2_q = out1_ref.to(cast).contiguous() + + a2_scale_1d = torch.empty((0,), device=a2_q.device, dtype=torch.float32) + w2_scale_1d = torch.empty((0,), device=a2_q.device, dtype=torch.float32) + sorted_weights_1d = sorted_weights.contiguous().view(-1) + + out_s = str(out_dtype).strip().lower() + if out_s in ("f16", "fp16", "half"): + out_torch_dtype = torch.float16 + elif out_s in ("f32", "fp32", "float"): + out_torch_dtype = torch.float32 + else: + raise ValueError(f"out_dtype must be 'f16' or 'f32', got {out_dtype!r}") + + if bool(use_reduce) and out_torch_dtype == torch.float32: + pytest.skip("reduce mode does not support out_dtype='f32'") + + out = torch.zeros((tokens, 
model_dim), device=a2_q.device, dtype=out_torch_dtype) + out_perf = torch.zeros_like(out) + + compile_fn = ( + _make_reduce_mode_compile_fn(use_valid_mask=bool(use_valid_mask)) + if bool(use_reduce) + else compile_moe_gemm2 + ) + exe = compile_fn( + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage2=not bool(doweight_stage1), + in_dtype=in_dtype, + out_dtype=out_dtype, + waves_per_eu=waves_per_eu, + expert_sched_mode=bool(expert_sched_mode), + ) + is_reduce_exe = getattr(exe, "mode", None) == MoeGemm2Mode.REDUCE + + def launch(o, x, w, sx, sw, st, eids, sw_sorted): + stream = torch.cuda.current_stream() + valid_mask = None + if is_reduce_exe and bool(use_valid_mask): + if valid_mask_in is not None: + valid_mask = valid_mask_in.to( + device=topk_ids.device, + dtype=torch.uint8, + ).contiguous() + else: + valid_mask = get_topk_valid_mask( + topk_ids, + expert_mask=None, + dtype=torch.uint8, + ).contiguous() + if is_reduce_exe: + exe( + o, + x, + w, + sx, + sw, + st, + eids, + sw_sorted, + num_valid_ids, + tokens, + model_dim, + inter_dim, + int(blocks), + valid_mask, + stream, + ) + else: + exe( + o, + x, + w, + sx, + sw, + st, + eids, + sw_sorted, + num_valid_ids, + tokens, + model_dim, + inter_dim, + int(blocks), + stream, + ) + + _, us = run_perftest( + launch, + out_perf, + a2_q.view(-1), + w2_kernel.view(-1), + a2_scale_1d, + w2_scale_1d, + sorted_token_ids, + sorted_expert_ids, + sorted_weights_1d, + num_iters=int(num_iters), + num_warmup=int(num_warmup), + testGraph=test_graph, + ) + torch.cuda.synchronize() + + out.zero_() + launch( + out, + a2_q.view(-1), + w2_kernel.view(-1), + a2_scale_1d, + w2_scale_1d, + sorted_token_ids, + sorted_expert_ids, + sorted_weights_1d, + ) + torch.cuda.synchronize() + + if not bool(skip_ref): + a2_ref = a2_q + if bool(use_reduce) and bool(use_valid_mask): + if valid_mask_in is not None: + valid_mask_ref = valid_mask_in.to(device=a2_q.device, dtype=a2_q.dtype) + else: + valid_mask_ref = get_topk_valid_mask( + topk_ids, + expert_mask=None, + dtype=a2_q.dtype, + ) + a2_ref = a2_q * valid_mask_ref.view(tokens, topk, 1) + ref2 = torch_moe_gemm2( + a2_ref, + w2_q, + None, + None, + topk_ids.to(torch.int64), + topk_weights, + model_dim=model_dim, + doweight_stage2=not bool(doweight_stage1), + ) + assert verify_output(out.to(torch.float32), ref2, rtol=0.5, atol=0.5) + + if return_outputs: + return out, us + return None + + +def run_moe_gemm_2stage( + *, + tokens: int, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n1: int, + tile_k1: int, + tile_n2: int, + tile_k2: int, + doweight_stage1: bool, + in_dtype: str, + out_dtype: str, + use_reduce: bool, + use_valid_mask: bool, + test_graph: bool, + waves_per_eu: Optional[int] = None, + even_dispatch: bool = False, + expert_sched_mode: bool = True, + seed: int = 0, + skip_ref: bool = False, +): + x_fp32, w1_fp32, w2_fp32, topk_ids, topk_weights = _make_inputs( + tokens=tokens, + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + even_dispatch=even_dispatch, + seed=seed, + ) + routing = build_routing_buffers( + topk_ids=topk_ids, + topk_weights=topk_weights, + experts=experts, + tile_m=tile_m, + model_dim=model_dim, + moe_sort_mode="torch", + ) + + stage1_out, _ = run_moe_stage1( + tokens=tokens, + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n=tile_n1, + tile_k=tile_k1, + 
doweight_stage1=doweight_stage1, + in_dtype=in_dtype, + waves_per_eu=waves_per_eu, + test_graph=test_graph, + expert_sched_mode=expert_sched_mode, + x_fp32_in=x_fp32, + w1_fp32_in=w1_fp32, + topk_ids_in=topk_ids, + topk_weights_in=topk_weights, + routing_in=routing, + return_outputs=True, + skip_ref=bool(skip_ref), + ) + + run_moe_stage2( + tokens=tokens, + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n=tile_n2, + tile_k=tile_k2, + doweight_stage1=doweight_stage1, + in_dtype=in_dtype, + out_dtype=out_dtype, + waves_per_eu=waves_per_eu, + test_graph=test_graph, + expert_sched_mode=expert_sched_mode, + x_fp32_in=x_fp32, + w1_fp32_in=w1_fp32, + w2_fp32_in=w2_fp32, + topk_ids_in=topk_ids, + topk_weights_in=topk_weights, + routing_in=routing, + a2_in=stage1_out, + use_reduce=use_reduce, + use_valid_mask=use_valid_mask, + skip_ref=bool(skip_ref), + ) + + +@pytest.mark.parametrize("waves_per_eu", [1, 2], ids=["wpe1", "wpe2"]) +def test_moe_2stage_waves_per_eu_smoke(waves_per_eu: int): + shape = dict(tokens=32, model_dim=256, inter_dim=128, experts=4, topk=2, tile_m=16) + stage1_out, _ = run_moe_stage1( + **shape, + tile_n=64, + tile_k=128, + doweight_stage1=False, + in_dtype="fp16", + waves_per_eu=waves_per_eu, + return_outputs=True, + skip_ref=True, + ) + stage2_out, _ = run_moe_stage2( + **shape, + tile_n=128, + tile_k=128, + doweight_stage1=False, + in_dtype="fp16", + out_dtype="f16", + waves_per_eu=waves_per_eu, + a2_in=stage1_out.to(torch.float16), + return_outputs=True, + skip_ref=True, + ) + assert torch.isfinite(stage1_out).all() + assert torch.isfinite(stage2_out).all() + + +def test_moe_reduce_valid_mask_masks_invalid_routes(): + shape = dict(tokens=48, model_dim=256, inter_dim=128, experts=6, topk=3, tile_m=16) + x_fp32, w1_fp32, w2_fp32, topk_ids, topk_weights = _make_inputs( + tokens=shape["tokens"], + model_dim=shape["model_dim"], + inter_dim=shape["inter_dim"], + experts=shape["experts"], + topk=shape["topk"], + even_dispatch=True, + seed=7, + ) + routing = build_routing_buffers( + topk_ids=topk_ids, + topk_weights=topk_weights, + experts=shape["experts"], + tile_m=shape["tile_m"], + model_dim=shape["model_dim"], + moe_sort_mode="torch", + ) + valid_mask = _make_partial_valid_mask(topk_ids, experts=shape["experts"]) + + stage1_out, _ = run_moe_stage1( + **shape, + tile_n=64, + tile_k=128, + doweight_stage1=False, + in_dtype="fp16", + x_fp32_in=x_fp32, + w1_fp32_in=w1_fp32, + topk_ids_in=topk_ids, + topk_weights_in=topk_weights, + routing_in=routing, + return_outputs=True, + skip_ref=False, + ) + stage2_out, _ = run_moe_stage2( + **shape, + tile_n=64, + tile_k=128, + doweight_stage1=False, + in_dtype="fp16", + out_dtype="f16", + use_reduce=True, + use_valid_mask=True, + valid_mask_in=valid_mask, + x_fp32_in=x_fp32, + w1_fp32_in=w1_fp32, + w2_fp32_in=w2_fp32, + topk_ids_in=topk_ids, + topk_weights_in=topk_weights, + routing_in=routing, + a2_in=stage1_out, + return_outputs=True, + skip_ref=False, + ) + + w2_q = w2_fp32.to(torch.float16).contiguous() + unmasked_ref = torch_moe_gemm2( + stage1_out, + w2_q, + None, + None, + topk_ids.to(torch.int64), + topk_weights, + model_dim=shape["model_dim"], + doweight_stage2=True, + ) + assert not torch.allclose( + stage2_out.to(torch.float32), + unmasked_ref, + rtol=1e-3, + atol=1e-3, + ) + + +@pytest.mark.parametrize( + "tokens, model_dim, inter_dim, experts, topk, tile_m, tile_n1, tile_k1, tile_n2, tile_k2, doweight_stage1", + [ + pytest.param(64, 256, 128, 4, 2, 16, 64, 128, 64, 128, 
False, id="S"), + ], +) +@pytest.mark.parametrize("in_dtype", ["fp16", "bf16"]) +@pytest.mark.parametrize("out_dtype", ["f16", "f32"], ids=["out_f16", "out_f32"]) +@pytest.mark.parametrize("use_reduce", [False, True], ids=["atomic", "reduce"]) +@pytest.mark.parametrize("use_valid_mask", [False, True], ids=["nomask", "mask"]) +@pytest.mark.parametrize("test_graph", [False, True], ids=["eager", "graph"]) +def test_moe_gemm_2stage_smoke( + tokens: int, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n1: int, + tile_k1: int, + tile_n2: int, + tile_k2: int, + doweight_stage1: bool, + in_dtype: str, + out_dtype: str, + use_reduce: bool, + use_valid_mask: bool, + test_graph: bool, +): + if (not bool(use_reduce)) and bool(use_valid_mask): + pytest.skip("valid_mask is only used in reduce mode.") + run_moe_gemm_2stage( + tokens=tokens, + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n1=tile_n1, + tile_k1=tile_k1, + tile_n2=tile_n2, + tile_k2=tile_k2, + doweight_stage1=doweight_stage1, + in_dtype=in_dtype, + out_dtype=out_dtype, + use_reduce=use_reduce, + use_valid_mask=use_valid_mask, + test_graph=test_graph, + ) + + +@pytest.mark.parametrize( + "tokens, model_dim, inter_dim, experts, topk, tile_m, tile_n1, tile_k1, tile_n2, tile_k2", + [ + pytest.param(129, 1024, 256, 8, 2, 32, 64, 128, 64, 128, id="M"), + pytest.param( + 333, + 4096, + 2048, + 17, + 9, + 64, + 128, + 128, + 64, + 128, + id="L", + marks=pytest.mark.large_shape, + ), + ], +) +def test_moe_gemm_2stage_perf_smoke( + tokens: int, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n1: int, + tile_k1: int, + tile_n2: int, + tile_k2: int, +): + run_moe_gemm_2stage( + tokens=tokens, + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n1=tile_n1, + tile_k1=tile_k1, + tile_n2=tile_n2, + tile_k2=tile_k2, + doweight_stage1=False, + in_dtype="fp16", + out_dtype="f16", + use_reduce=False, + use_valid_mask=False, + test_graph=False, + skip_ref=(tokens >= 129), + ) + + +# --------------------------------------------------------------------------- +# Benchmark entry point (mirrors tests/kernels/test_moe_gemm_wmma_gfx1250.py) +# --------------------------------------------------------------------------- + + +def _bench_setup_data(tokens, model_dim, inter_dim, experts, topk, tile_m, seed=42): + """Build random MoE data + routing buffers for RDNA4 bench sweeps.""" + device = torch.device("cuda") + torch.manual_seed(seed) + s = 0.2 + x_fp32 = torch.randn((tokens, model_dim), device=device, dtype=torch.float32) * s + w1_fp32 = torch.randn((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) * s + w2_fp32 = torch.randn((experts, model_dim, inter_dim), device=device, dtype=torch.float32) * ( + s / math.sqrt(inter_dim) + ) + score = torch.rand((tokens, experts), device=device, dtype=torch.float32) + topk_vals, topk_ids = torch.topk(score, k=topk, dim=1) + topk_weights = torch.softmax(topk_vals, dim=1).to(torch.float32) + routing = build_routing_buffers( + topk_ids=topk_ids, + topk_weights=topk_weights, + experts=experts, + tile_m=tile_m, + model_dim=model_dim, + moe_sort_mode="torch", + ) + return x_fp32, w1_fp32, w2_fp32, topk_ids, topk_weights, routing + + +def _bench_prepare_a2(out1_fp16, _tokens, _topk, _inter_dim, in_dtype): + """Convert stage1 fp16 output to stage2 activation input (fp16 or bf16).""" + if in_dtype == "fp16": + return out1_fp16, None + if 
in_dtype == "bf16": + return out1_fp16.to(torch.bfloat16), None + raise ValueError(f"in_dtype must be 'fp16' or 'bf16', got {in_dtype!r}") + + +if __name__ == "__main__": + import argparse + import sys + from tests.kernels.benchmark_common import add_moe_bench_args, moe_bench_main + + torch.set_default_device("cuda") + + parser = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter, + description="MoE 2-stage (FlyDSL RDNA4 / gfx120x WMMA fp16/bf16) benchmark", + ) + parser.add_argument( + "--in_dtype", + type=str, + default="fp16", + choices=["fp16", "bf16"], + help="Kernel input dtype (default: fp16).", + ) + # RDNA4 has no TDM hardware; accept the same flags as the gfx1250 harness so + # scripts/run_benchmark.sh can invoke both paths uniformly. They are ignored. + parser.add_argument("--use_tdm_store", action="store_true", default=False) + parser.add_argument("--inst_prefetch", action="store_true", default=False) + parser.add_argument("--wave_specialized_tdm", action="store_true", default=False) + add_moe_bench_args(parser) + args = parser.parse_args() + + if not args.bench: + print("Use --bench to run the RDNA4 MoE benchmark sweep.", file=sys.stderr) + sys.exit(2) + + moe_bench_main( + args, + stage1_fn=run_moe_stage1, + stage2_fn=run_moe_stage2, + setup_data_fn=_bench_setup_data, + prepare_a2_fn=_bench_prepare_a2, + )