From 3346bc91319136376eb090eb3bb1e171c94f4cd9 Mon Sep 17 00:00:00 2001 From: Felix Li Date: Fri, 10 Apr 2026 18:04:23 +0800 Subject: [PATCH 01/29] implement Vector type, remove vector dialect in kernels and use soffset for quant kernel (#377) --- kernels/fused_rope_cache_kernel.py | 302 +++++------ kernels/layernorm_kernel.py | 134 ++--- kernels/rmsnorm_kernel.py | 95 ++-- kernels/softmax_kernel.py | 47 +- python/flydsl/expr/numeric.py | 65 ++- python/flydsl/expr/typing.py | 212 +++++++- python/flydsl/expr/utils/arith.py | 13 +- python/flydsl/expr/vector.py | 30 +- tests/kernels/test_quant.py | 103 ++-- tests/unit/test_vector.py | 799 +++++++++++++++++++++++++++++ 10 files changed, 1375 insertions(+), 425 deletions(-) create mode 100644 tests/unit/test_vector.py diff --git a/kernels/fused_rope_cache_kernel.py b/kernels/fused_rope_cache_kernel.py index 07cbbeee..d9258487 100644 --- a/kernels/fused_rope_cache_kernel.py +++ b/kernels/fused_rope_cache_kernel.py @@ -26,84 +26,17 @@ import flydsl.compiler as flyc import flydsl.expr as fx -from flydsl.expr import vector, range_constexpr -from flydsl.expr.arith import ArithValue +from flydsl.expr import range_constexpr from flydsl.expr.typing import T -from flydsl.expr import buffer_ops +from flydsl.expr.numeric import Numeric +from flydsl.expr.vector import full from kernels.kernels_common import dtype_to_elem_type -from kernels.mfma_preshuffle_pipeline import crd2idx WARP_SIZE = 64 VEC_WIDTH = 8 -def _layout_to_dword_off(coord, layout, elem_bytes): - """Coordinate → dword offset for buffer_load/buffer_store. - - crd2idx(coord, layout) → element offset (index) → byte offset (i32) → dword offset (i32). - """ - elem_off = ArithValue(crd2idx(coord, layout)).index_cast(T.i32) - return (ArithValue(elem_off) * elem_bytes) >> fx.Int32(2) - - -def _make_rope_copy_helpers(elem_type, elem_bits): - """Build copy atom and register types for RoPE vector loads/stores.""" - copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) - vec_reg_ty = fx.MemRefType.get( - elem_type, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register - ) - vec_reg_lay = fx.make_layout(VEC_WIDTH, 1) - return copy_atom, vec_reg_ty, vec_reg_lay - - -def _load_vec_buf(copy_atom, vec_reg_ty, vec_reg_lay, div_tensor, idx): - """Vector load via layout API: div_tensor[:, idx] → register vec.""" - r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) - fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) - return ArithValue(fx.memref_load_vec(r)) - - -def _store_vec_buf(copy_atom, vec_reg_ty, vec_reg_lay, val, div_tensor, idx): - """Vector store via layout API: register vec → div_tensor[:, idx].""" - r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) - fx.memref_store_vec(val, r) - fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx))) - - -def _apply_neox_rope(qk_div, cos_div, sin_div, pair_div, - qk_tid, cos_tid, pair_tid, is_first_half, - copy_atom, vec_reg_ty, vec_reg_lay): - """Load, rotate (NeoX), and return the rotated vector. - - Performs: - out[first_half] = qk * cos - pair * sin - out[second_half] = qk * cos + pair * sin - - Uses buffer-backed tensor layout API for vector loads. 
- - Args: - qk_tid: index into qk_div for current thread's vector - cos_tid: index into cos_div/sin_div (tid % vecs_per_half) - pair_tid: index into pair_div for partner vector - - Returns: - rot_e: rotated vector in element type - """ - qk_e = _load_vec_buf(copy_atom, vec_reg_ty, vec_reg_lay, qk_div, qk_tid) - cos_e = _load_vec_buf(copy_atom, vec_reg_ty, vec_reg_lay, cos_div, cos_tid) - sin_e = _load_vec_buf(copy_atom, vec_reg_ty, vec_reg_lay, sin_div, cos_tid) - pair_e = _load_vec_buf(copy_atom, vec_reg_ty, vec_reg_lay, pair_div, pair_tid) - - # NeoX sign: first half uses -sin, second half uses +sin - qk_cos = ArithValue(qk_e) * ArithValue(cos_e) - pair_sin = ArithValue(pair_e) * ArithValue(sin_e) - sin_term = ArithValue(is_first_half).select(-pair_sin, pair_sin) - rot_e = ArithValue(qk_cos) + ArithValue(sin_term) - - return rot_e - - def build_fused_rope_cache_module( head_dim: int = 64, rotary_dim: int = -1, @@ -142,8 +75,6 @@ def build_fused_rope_cache_module( f"(f32 is not supported: kernel uses 2-byte elem_bytes and vec8 vectorization)" ) half_dim = rotary_dim // 2 - elem_bytes = 2 # bf16 and f16 are both 2 bytes - vec_dwords = (VEC_WIDTH * elem_bytes) // 4 # 4 dwords for vec8 of 2-byte elements vecs_per_half = half_dim // VEC_WIDTH # number of VEC_WIDTH-wide vectors covering half_dim vecs_per_head = head_dim // VEC_WIDTH # number of VEC_WIDTH-wide vectors covering head_dim x_size = 16 # x-packing factor for non-flash key_cache @@ -179,17 +110,6 @@ def build_fused_rope_cache_module( ) BLOCK_THREADS = WARP_SIZE - # Layout shape/stride tuples (plain Python ints) — materialized as - # fx.make_layout inside each kernel where an MLIR context is active. - # None is used for dynamic/unknown extents (token count, position range, - # block count) so the layout shape matches the actual indexing domain. 
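# Reference only: a host-side NumPy sketch of the NeoX rotation that the kernels
# below now inline, following the removed _apply_neox_rope docstring:
#   out[first_half]  = qk * cos - pair * sin
#   out[second_half] = qk * cos + pair * sin
# The function name and array shapes here are illustrative assumptions, not part
# of the kernel code.
import numpy as np

def neox_rope_reference(qk, cos, sin):
    # qk: (rotary_dim,) slice of one head; cos/sin: (rotary_dim // 2,) for one position
    half = qk.shape[0] // 2
    first, second = qk[:half], qk[half:]
    out = np.empty_like(qk)
    out[:half] = first * cos - second * sin   # pair of the first half lives in the second half
    out[half:] = second * cos + first * sin   # pair of the second half lives in the first half
    return out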
- _q_shape = (None, num_q_heads, vecs_per_head) - _q_stride = (num_q_heads * head_dim, head_dim, VEC_WIDTH) - _kv_shape = (None, num_kv_heads, vecs_per_head) - _kv_stride = (num_kv_heads * head_dim, head_dim, VEC_WIDTH) - _cos_shape = (None, vecs_per_half) - _cos_stride = (half_dim, VEC_WIDTH) - # ----- Kernel 1: Q RoPE ----- # Grid: (T * QH, 1, 1), one program per (token, q_head) # Each program: vecs_per_head threads process head_dim elements @@ -212,16 +132,42 @@ def q_rope_kernel( Qo_buf = fx.rocdl.make_buffer_tensor(Q_out) Cos_buf = fx.rocdl.make_buffer_tensor(CosCache) Sin_buf = fx.rocdl.make_buffer_tensor(SinCache) - pos_rsrc = buffer_ops.create_buffer_resource(Positions, max_size=True) + Pos_buf = fx.rocdl.make_buffer_tensor(Positions) - copy_atom, vec_reg_ty, vec_reg_lay = _make_rope_copy_helpers(elem_type, elem_bits) + copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) + vec_reg_ty = fx.MemRefType.get( + elem_type, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register + ) + vec_reg_lay = fx.make_layout(VEC_WIDTH, 1) + + copy_atom_i32 = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) + i32_reg_ty = fx.MemRefType.get(T.i32, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) + i32_reg_lay = fx.make_layout(1, 1) + + def load_scalar_i32(buf_tensor, elem_offset): + """Scalar i32 load using soffset for dynamic indexing.""" + div = fx.logical_divide(buf_tensor, fx.make_layout(1, 1)) + base_view = fx.slice(div, (None, fx.Int32(0))) + atom = copy_atom_i32.set_value("soffset", elem_offset) + r = fx.memref_alloca(i32_reg_ty, i32_reg_lay) + fx.copy_atom_call(atom, base_view, r) + return fx.memref_load_vec(r)[0] + + def load_vec(div_tensor, idx): + r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) + fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) + return fx.memref_load_vec(r) + + def store_vec(val, div_tensor, idx): + r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) + fx.memref_store_vec(val, r) + fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx))) if tid < fx.Int32(vecs_per_head): pid_t = pid // num_q_heads pid_hq = pid % num_q_heads - # Load position - pos_val = buffer_ops.buffer_load(pos_rsrc, pid_t, vec_width=1, dtype=T.i32) + pos_val = load_scalar_i32(Pos_buf, pid_t) # Q[pid_t, pid_hq, :] tiled by VEC_WIDTH q_row = fx.slice(Q_buf, (pid_t, fx.Int32(pid_hq), None)) @@ -239,16 +185,20 @@ def q_rope_kernel( # NeoX rotation: pair with opposite half is_first_half = tid < fx.Int32(vecs_per_half) - pair_tid = ArithValue(is_first_half).select(tid + vecs_per_half, tid - vecs_per_half) - # tid % vecs_per_half wraps into cos/sin range + pair_tid = is_first_half.select(tid + vecs_per_half, tid - vecs_per_half) cos_vec_idx = tid % vecs_per_half - rot_e = _apply_neox_rope( - q_div, cos_div, sin_div, q_div, - tid, cos_vec_idx, pair_tid, is_first_half, - copy_atom, vec_reg_ty, vec_reg_lay, - ) - _store_vec_buf(copy_atom, vec_reg_ty, vec_reg_lay, rot_e, qo_div, tid) + qk_e = load_vec(q_div, tid) + cos_e = load_vec(cos_div, cos_vec_idx) + sin_e = load_vec(sin_div, cos_vec_idx) + pair_e = load_vec(q_div, pair_tid) + + qk_cos = qk_e * cos_e + pair_sin = pair_e * sin_e + sin_term = is_first_half.select(-pair_sin, pair_sin) + rot_e = qk_cos + sin_term + + store_vec(rot_e, qo_div, tid) # ----- Kernel 2: K RoPE + KV cache write ----- # Grid: (T * KH, 1, 1), one program per (token, kv_head) @@ -269,8 +219,7 @@ def k_cache_kernel( tid = fx.thread_idx.x # 0..63 elem_type = dtype_to_elem_type(dtype_str) - vec_type_e = T.vec(VEC_WIDTH, elem_type) - i32_vec_ty = 
T.vec(vec_dwords, T.i32) + elem_dtype = Numeric.from_ir_type(elem_type) elem_bits = 16 # bf16/f16 only # Buffer-backed tensors via layout API @@ -279,23 +228,58 @@ def k_cache_kernel( Ko_buf = fx.rocdl.make_buffer_tensor(K_out) Cos_buf = fx.rocdl.make_buffer_tensor(CosCache) Sin_buf = fx.rocdl.make_buffer_tensor(SinCache) - pos_rsrc = buffer_ops.create_buffer_resource(Positions, max_size=True) - slot_rsrc = buffer_ops.create_buffer_resource(SlotMapping, max_size=True) - # KV cache: keep buffer_ops for complex scattered writes - kc_rsrc = buffer_ops.create_buffer_resource(KeyCache, max_size=True) - vc_rsrc = buffer_ops.create_buffer_resource(ValueCache, max_size=True) + Pos_buf = fx.rocdl.make_buffer_tensor(Positions) + Slot_buf = fx.rocdl.make_buffer_tensor(SlotMapping) + KC_buf = fx.rocdl.make_buffer_tensor(KeyCache) + VC_buf = fx.rocdl.make_buffer_tensor(ValueCache) + + copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) + vec_reg_ty = fx.MemRefType.get( + elem_type, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register + ) + vec_reg_lay = fx.make_layout(VEC_WIDTH, 1) - copy_atom, vec_reg_ty, vec_reg_lay = _make_rope_copy_helpers(elem_type, elem_bits) + copy_atom_i32 = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) + i32_reg_ty = fx.MemRefType.get(T.i32, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) + i32_reg_lay = fx.make_layout(1, 1) - # Layouts for KV cache (used in non-layout-API scatter paths) - kv_layout = fx.make_layout(_kv_shape, _kv_stride) + if not flash_layout: + copy_atom_elem = fx.make_copy_atom(fx.rocdl.BufferCopy16b(), elem_bits) + elem_reg_ty = fx.MemRefType.get( + elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register + ) + elem_reg_lay = fx.make_layout(1, 1) + + def load_scalar_i32(buf_tensor, elem_offset): + """Scalar i32 load using soffset for dynamic indexing.""" + div = fx.logical_divide(buf_tensor, fx.make_layout(1, 1)) + base_view = fx.slice(div, (None, fx.Int32(0))) + atom = copy_atom_i32.set_value("soffset", elem_offset) + r = fx.memref_alloca(i32_reg_ty, i32_reg_lay) + fx.copy_atom_call(atom, base_view, r) + return fx.memref_load_vec(r)[0] + + def load_vec(div_tensor, idx): + r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) + fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) + return fx.memref_load_vec(r) + + def store_vec(val, div_tensor, idx): + r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) + fx.memref_store_vec(val, r) + fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx))) + + def store_scalar(val, div_tensor, idx): + r = fx.memref_alloca(elem_reg_ty, elem_reg_lay) + ts = full(1, elem_dtype(val), elem_dtype) + fx.memref_store_vec(ts, r) + fx.copy_atom_call(copy_atom_elem, r, fx.slice(div_tensor, (None, idx))) if tid < fx.Int32(vecs_per_head): pid_t = pid // num_kv_heads pid_hk = pid % num_kv_heads - # Load position - pos_val = buffer_ops.buffer_load(pos_rsrc, pid_t, vec_width=1, dtype=T.i32) + pos_val = load_scalar_i32(Pos_buf, pid_t) # K[pid_t, pid_hk, :] tiled by VEC_WIDTH k_row = fx.slice(K_buf, (pid_t, fx.Int32(pid_hk), None)) @@ -313,82 +297,58 @@ def k_cache_kernel( # NeoX rotation is_first_half = tid < fx.Int32(vecs_per_half) - pair_tid = ArithValue(is_first_half).select(tid + vecs_per_half, tid - vecs_per_half) + pair_tid = is_first_half.select(tid + vecs_per_half, tid - vecs_per_half) cos_vec_idx = tid % vecs_per_half - k_rot_e = _apply_neox_rope( - k_div, cos_div, sin_div, k_div, - tid, cos_vec_idx, pair_tid, is_first_half, - copy_atom, vec_reg_ty, vec_reg_lay, - ) - 
_store_vec_buf(copy_atom, vec_reg_ty, vec_reg_lay, k_rot_e, ko_div, tid) + qk_e = load_vec(k_div, tid) + cos_e = load_vec(cos_div, cos_vec_idx) + sin_e = load_vec(sin_div, cos_vec_idx) + pair_e = load_vec(k_div, pair_tid) + + qk_cos = qk_e * cos_e + pair_sin = pair_e * sin_e + sin_term = is_first_half.select(-pair_sin, pair_sin) + k_rot_e = qk_cos + sin_term + + store_vec(k_rot_e, ko_div, tid) # --- KV Cache write --- - slot_val = buffer_ops.buffer_load(slot_rsrc, pid_t, vec_width=1, dtype=T.i32) + slot_val = load_scalar_i32(Slot_buf, pid_t) if slot_val >= fx.Int32(0): - pid_t_slot = ArithValue(slot_val) // block_size - pid_b = ArithValue(slot_val) % block_size + pid_t_slot = slot_val // block_size + pid_b = slot_val % block_size - # Load V via layout API + # Load V v_row = fx.slice(V_buf, (pid_t, fx.Int32(pid_hk), None)) v_div = fx.logical_divide(v_row, fx.make_layout(VEC_WIDTH, 1)) - v_e = _load_vec_buf(copy_atom, vec_reg_ty, vec_reg_lay, v_div, tid) - - # Bitcast for KV cache stores (buffer_ops needs i32 vecs) - k_rot_i32 = vector.bitcast(i32_vec_ty, k_rot_e) - v_raw = vector.bitcast(i32_vec_ty, v_e) - - # KV cache dword offset for stores (scattered, keep buffer_ops) - kv_coord = (pid_t, fx.Int32(pid_hk), tid) - k_dw = _layout_to_dword_off(kv_coord, kv_layout, elem_bytes) + v_e = load_vec(v_div, tid) if flash_layout: - kc_flash_layout = fx.make_layout( - (None, block_size, num_kv_heads, vecs_per_head), - (block_size * num_kv_heads * head_dim, - num_kv_heads * head_dim, - head_dim, - VEC_WIDTH), - ) - kc_coord = (pid_t_slot, pid_b, pid_hk, tid) - kc_dw = _layout_to_dword_off(kc_coord, kc_flash_layout, elem_bytes) - - buffer_ops.buffer_store(k_rot_i32, kc_rsrc, kc_dw) - buffer_ops.buffer_store(v_raw, vc_rsrc, kc_dw) + # Flash: [num_blocks, block_size, KH, D] → 1D, tile by VEC_WIDTH + kc_row = fx.slice(KC_buf, (pid_t_slot, pid_b, fx.Int32(pid_hk), None)) + kc_div = fx.logical_divide(kc_row, fx.make_layout(VEC_WIDTH, 1)) + vc_row = fx.slice(VC_buf, (pid_t_slot, pid_b, fx.Int32(pid_hk), None)) + vc_div = fx.logical_divide(vc_row, fx.make_layout(VEC_WIDTH, 1)) + + store_vec(k_rot_e, kc_div, tid) + store_vec(v_e, vc_div, tid) else: # Non-flash key_cache: [num_blocks, KH, D//x, BS, x] - d_start = ArithValue(tid) * VEC_WIDTH - dim_group = d_start // x_size - dim_within = d_start % x_size - - kc_nf_layout = fx.make_layout( - (None, num_kv_heads, head_dim // x_size, block_size, x_size), - (num_kv_heads * (head_dim // x_size) * block_size * x_size, - (head_dim // x_size) * block_size * x_size, - block_size * x_size, - x_size, - 1), - ) - kc_coord_nf = (pid_t_slot, pid_hk, dim_group, pid_b, dim_within) - kc_dw_nf = _layout_to_dword_off(kc_coord_nf, kc_nf_layout, elem_bytes) - - buffer_ops.buffer_store(k_rot_i32, kc_rsrc, kc_dw_nf) - - # Non-flash value_cache: scalar stores (non-contiguous layout) - vc_nf_layout = fx.make_layout( - (None, num_kv_heads, head_dim, block_size), - (num_kv_heads * head_dim * block_size, - head_dim * block_size, - block_size, - 1), - ) + dim_group = (tid * VEC_WIDTH) // x_size + sub_tile = tid % (x_size // VEC_WIDTH) + + kc_nf_row = fx.slice(KC_buf, (pid_t_slot, fx.Int32(pid_hk), dim_group, pid_b, None)) + kc_nf_div = fx.logical_divide(kc_nf_row, fx.make_layout(VEC_WIDTH, 1)) + store_vec(k_rot_e, kc_nf_div, sub_tile) + + # Non-flash value_cache: [num_blocks, KH, D, block_size] for vi in range_constexpr(VEC_WIDTH): - v_scalar = vector.extract(v_e, static_position=[vi]) - d_idx = ArithValue(tid) * VEC_WIDTH + vi - vc_coord = (pid_t_slot, pid_hk, d_idx, pid_b) - vc_elem_off = 
ArithValue(crd2idx(vc_coord, vc_nf_layout)).index_cast(T.i32) - buffer_ops.buffer_store(v_scalar, vc_rsrc, vc_elem_off) + v_scalar = v_e[vi] + d_idx = tid * VEC_WIDTH + vi + vc_row = fx.slice(VC_buf, (pid_t_slot, fx.Int32(pid_hk), d_idx, None)) + vc_div = fx.logical_divide(vc_row, fx.make_layout(1, 1)) + store_scalar(v_scalar, vc_div, pid_b) @flyc.jit def launch_fused_rope_cache( @@ -407,7 +367,7 @@ def launch_fused_rope_cache( stream: fx.Stream = fx.Stream(None), ): # Kernel 1: Q RoPE - n_q = ArithValue(num_tokens) * num_q_heads + n_q = num_tokens * num_q_heads q_launcher = q_rope_kernel(Q, Positions, CosCache, SinCache, Q_out) q_launcher.launch( grid=(n_q, 1, 1), @@ -416,7 +376,7 @@ def launch_fused_rope_cache( ) # Kernel 2: K RoPE + KV cache write - n_k = ArithValue(num_tokens) * num_kv_heads + n_k = num_tokens * num_kv_heads k_launcher = k_cache_kernel( K, V, Positions, CosCache, SinCache, SlotMapping, KeyCache, ValueCache, K_out, diff --git a/kernels/layernorm_kernel.py b/kernels/layernorm_kernel.py index d8ceaafd..6f1441ea 100644 --- a/kernels/layernorm_kernel.py +++ b/kernels/layernorm_kernel.py @@ -15,9 +15,11 @@ import flydsl.expr as fx from flydsl.compiler.kernel_function import CompilationContext -from flydsl.expr import arith, vector, gpu, range_constexpr +from flydsl.expr import arith, gpu, range_constexpr from flydsl.expr.arith import ArithValue from flydsl.expr.typing import T, Int32 +from flydsl.expr.vector import ReductionOp, full +from flydsl.expr.numeric import Numeric, Float32, Uint32 from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr from flydsl.runtime.device import get_rocm_arch as get_hip_arch @@ -128,8 +130,6 @@ def block_reduce_add2(val0, val1): return s_sum.load([c0_idx]), s_sumsq.load([c0_idx]) def compute_mean_rstd(sum_val, sumsq_val): - from flydsl.expr.arith import ArithValue - inv_n = arith.constant(1.0 / float(N), type=compute_type) s = ArithValue(sum_val) ss = ArithValue(sumsq_val) @@ -150,19 +150,14 @@ def compute_mean_rstd(sum_val, sumsq_val): # memory access (same approach as preshuffle_gemm). 
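    # Reference only: the sum / sumsq accumulated below feed the usual LayerNorm
    # identities, mean = sum / N and var = sumsq / N - mean^2, with
    # rstd = 1 / sqrt(var + eps), so pass 2 can emit y = (x - mean) * rstd * gamma + beta.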
# ================================================================== if N == (BLOCK_THREADS * VEC_WIDTH * 4) and elem_bits <= 16: - from flydsl.expr.arith import ArithValue - num_tiles_py = 4 + elem_dtype = Numeric.from_ir_type(elem_type) c_zero_f = arith.constant(0.0, type=compute_type) thread_sum = c_zero_f thread_sumsq = c_zero_f - cache_as_elem = (dtype_str != "f32") in_local = [] - vec_type_c = T.vec(VEC_WIDTH, compute_type) - vec_type_e = T.vec(VEC_WIDTH, elem_type) - # ── Layout API: buffer-backed tensors + tiled access ───── Input_buf = fx.rocdl.make_buffer_tensor(Input) Output_buf = fx.rocdl.make_buffer_tensor(Output) @@ -186,7 +181,7 @@ def compute_mean_rstd(sum_val, sumsq_val): def _load_vec(div_tensor, idx): r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) - return ArithValue(fx.memref_load_vec(r)) + return fx.memref_load_vec(r) def _store_vec(val, div_tensor, idx): r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) @@ -196,106 +191,55 @@ def _store_vec(val, div_tensor, idx): # ── Pass 1: load input, accumulate sum / sumsq ─────────────── for tile_i in range_constexpr(num_tiles_py): idx = tid + tile_i * BLOCK_THREADS - vec_e = _load_vec(in_div, idx) - - if cache_as_elem: - in_local.append(vec_e) - x = vec_e.extf(vec_type_c) - else: - x = vec_e - in_local.append(x) + vec = _load_vec(in_div, idx) + in_local.append(vec) + x = vec.to(Float32) - x_av = ArithValue(x) - x2 = x_av * x_av - red = vector.reduction( - compute_type, vector.CombiningKind.ADD, - x, fastmath=fm_fast, - ) - red2 = vector.reduction( - compute_type, vector.CombiningKind.ADD, - x2, fastmath=fm_fast, - ) + x2 = x * x + red = x.reduce(ReductionOp.ADD, fastmath=fm_fast) + red2 = x2.reduce(ReductionOp.ADD, fastmath=fm_fast) thread_sum = ArithValue(thread_sum) + red thread_sumsq = ArithValue(thread_sumsq) + red2 sum_val, sumsq_val = block_reduce_add2(thread_sum, thread_sumsq) mean, rstd = compute_mean_rstd(sum_val, sumsq_val) - mean_splat = vector.broadcast(vec_type_c, mean) - rstd_splat = vector.broadcast(vec_type_c, rstd) - mean_splat_av = ArithValue(mean_splat) - rstd_splat_av = ArithValue(rstd_splat) - - g_e_cur = _load_vec(gamma_div, tid) - b_e_cur = _load_vec(beta_div, tid) - g_cur = ( - g_e_cur - if dtype_str == "f32" - else g_e_cur.extf(vec_type_c) - ) - b_cur = ( - b_e_cur - if dtype_str == "f32" - else b_e_cur.extf(vec_type_c) - ) + g_cur = _load_vec(gamma_div, tid).to(Float32) + b_cur = _load_vec(beta_div, tid).to(Float32) # ── Pass 2: normalize + affine + store ─────────────────────── for tile_i in range_constexpr(num_tiles_py): if tile_i + 1 < num_tiles_py: next_idx = tid + (tile_i + 1) * BLOCK_THREADS - g_e_next = _load_vec(gamma_div, next_idx) - b_e_next = _load_vec(beta_div, next_idx) - g_next = ( - g_e_next - if dtype_str == "f32" - else g_e_next.extf(vec_type_c) - ) - b_next = ( - b_e_next - if dtype_str == "f32" - else b_e_next.extf(vec_type_c) - ) + g_next = _load_vec(gamma_div, next_idx).to(Float32) + b_next = _load_vec(beta_div, next_idx).to(Float32) else: g_next = g_cur b_next = b_cur - x = in_local[tile_i] - if cache_as_elem: - x = x.extf(vec_type_c) - - x_av = ArithValue(x) - g_av = ArithValue(g_cur) - b_av = ArithValue(b_cur) - y = (x_av - mean_splat_av) * rstd_splat_av - y = (y * g_av) + b_av - y_val = y + x = in_local[tile_i].to(Float32) + y = (x - mean) * rstd + y = y * g_cur + b_cur if dtype_str == "bf16": if USE_HW_CVT_PK_BF16_F32: - out_e = y_val.truncf(vec_type_e) + out_e = y.to(elem_dtype) else: - vec_i32_ty = T.vec(VEC_WIDTH, 
T.i32) - vec4_i32_ty = T.vec(VEC_WIDTH // 2, T.i32) - vec_bf16_ty = T.vec(VEC_WIDTH, elem_type) - c16_i32 = arith.constant(16, type=T.i32) - c16_v = vector.broadcast(vec_i32_ty, c16_i32) - u = y_val.bitcast(vec_i32_ty) - upper = u.shrui(c16_v) - c1_v = vector.broadcast(vec_i32_ty, arith.constant(1, type=T.i32)) - lsb = upper & c1_v - c7fff_v = vector.broadcast(vec_i32_ty, arith.constant(0x7FFF, type=T.i32)) - bias = ArithValue(c7fff_v) + lsb - u_round = u + bias - bf16_bits = u_round.shrui(c16_v) - even = vector.shuffle(bf16_bits, bf16_bits, [0, 2, 4, 6]) - odd = vector.shuffle(bf16_bits, bf16_bits, [1, 3, 5, 7]) - odd_sh = odd << vector.broadcast(vec4_i32_ty, c16_i32) + u = y.bitcast(Uint32) + upper = u >> 16 + lsb = upper & 1 + bias = lsb + 0x7FFF + u_round = y.bitcast(Uint32) + bias + bf16_bits = u_round >> 16 + even = bf16_bits.shuffle(bf16_bits, [0, 2, 4, 6]) + odd = bf16_bits.shuffle(bf16_bits, [1, 3, 5, 7]) + odd_sh = odd << 16 packed = even | odd_sh - out_e = vector.bitcast(vec_bf16_ty, packed) + out_e = packed.bitcast(elem_dtype) elif dtype_str == "f32": - out_e = y_val + out_e = y else: - out_e = y_val.truncf(vec_type_e) + out_e = y.to(elem_dtype) out_idx = tid + tile_i * BLOCK_THREADS _store_vec(out_e, out_div, out_idx) @@ -307,7 +251,7 @@ def _store_vec(val, div_tensor, idx): # ============================================================== # Generic path: 2-pass scalar implementation for arbitrary N # ============================================================== - from flydsl.expr.arith import ArithValue + elem_dtype = Numeric.from_ir_type(elem_type) Input_buf = fx.rocdl.make_buffer_tensor(Input) Output_buf = fx.rocdl.make_buffer_tensor(Output) @@ -339,14 +283,12 @@ def _load_scalar(divided_tensor, index): view = fx.slice(divided_tensor, (None, index)) r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) fx.copy_atom_call(copy_atom_s, view, r) - v = fx.memref_load_vec(r) - return vector.extract(v, static_position=[0]) + return fx.memref_load_vec(r)[0].ir_value() def _store_scalar(divided_tensor, index, val): r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) - vec_ty = T.vec(1, elem_type) - v = vector.from_elements(vec_ty, [val]) - fx.memref_store_vec(v, r) + ts = full(1, elem_dtype(val), elem_dtype) + fx.memref_store_vec(ts, r) view = fx.slice(divided_tensor, (None, index)) fx.copy_atom_call(copy_atom_s, r, view) @@ -397,9 +339,9 @@ def _store_scalar(divided_tensor, index, val): else b_e.extf(compute_type) ) diff = ArithValue(x) - mean - norm = diff * ArithValue(rstd) - scaled = norm * ArithValue(g) - y = scaled + ArithValue(b) + norm = diff * rstd + scaled = norm * g + y = scaled + b if dtype_str == "bf16": y_e = y.truncf(elem_type) elif dtype_str == "f32": diff --git a/kernels/rmsnorm_kernel.py b/kernels/rmsnorm_kernel.py index b22681e3..09f0c88e 100644 --- a/kernels/rmsnorm_kernel.py +++ b/kernels/rmsnorm_kernel.py @@ -14,9 +14,11 @@ import flydsl.expr as fx from flydsl.compiler.kernel_function import CompilationContext -from flydsl.expr import arith, vector, gpu, range_constexpr +from flydsl.expr import arith, gpu, range_constexpr from flydsl.expr.arith import ArithValue from flydsl.expr.typing import T, Int32 +from flydsl.expr.vector import ReductionOp, full +from flydsl.expr.numeric import Numeric, Float32, Uint32 from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr from flydsl.runtime.device import get_rocm_arch as get_hip_arch @@ -128,12 +130,8 @@ def block_reduce_add2(val0, val1): # Fast path: N is a multiple of tile_cols # 
================================================================== if N >= tile_cols and N % tile_cols == 0 and elem_bits <= 16: - from flydsl.expr.arith import ArithValue - num_tiles = N // tile_cols - - vec_type_c = T.vec(VEC_WIDTH, compute_type) - vec_type_e = T.vec(VEC_WIDTH, elem_type) + elem_dtype = Numeric.from_ir_type(elem_type) # ── Layout API: buffer-backed tensors + tiled access ───── Input_buf = fx.rocdl.make_buffer_tensor(Input) @@ -156,7 +154,7 @@ def block_reduce_add2(val0, val1): def _load_vec(div_tensor, idx): r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) - return ArithValue(fx.memref_load_vec(r)) + return fx.memref_load_vec(r) def _store_vec(val, div_tensor, idx): r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) @@ -166,75 +164,52 @@ def _store_vec(val, div_tensor, idx): c_zero_f = arith.constant(0.0, type=compute_type) thread_sumsq = c_zero_f thread_dummy = c_zero_f - cache_as_elem = (dtype_str != "f32") in_local = [] # Pass 1: load + cache + sumsq for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS - vec_e = _load_vec(in_div, idx) - - if cache_as_elem: - in_local.append(vec_e) - x = vec_e.extf(vec_type_c) - else: - x = vec_e - in_local.append(x) + vec = _load_vec(in_div, idx) + in_local.append(vec) + x = vec.to(Float32) - x_av = ArithValue(x) - x2 = x_av * x_av - red2 = vector.reduction(compute_type, vector.CombiningKind.ADD, x2, fastmath=fm_fast) + x2 = x * x + red2 = x2.reduce(ReductionOp.ADD, fastmath=fm_fast) thread_sumsq = ArithValue(thread_sumsq) + red2 _, sum_sq = block_reduce_add2(thread_dummy, thread_sumsq) mean_sq = ArithValue(sum_sq) / n_float - ms_eps = ArithValue(mean_sq) + eps_c + ms_eps = mean_sq + eps_c rrms = ms_eps.rsqrt(fastmath=fm_fast) - rrms_splat = vector.broadcast(vec_type_c, rrms) - rrms_splat_av = ArithValue(rrms_splat) # Pass 2: normalize + gamma + store (reuse cached input) for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS - g_e = _load_vec(gamma_div, idx) - g = g_e if dtype_str == "f32" else g_e.extf(vec_type_c) + g = _load_vec(gamma_div, idx).to(Float32) + x = in_local[tile_i].to(Float32) - x = in_local[tile_i] - if cache_as_elem: - x = x.extf(vec_type_c) - - x_av = ArithValue(x) - g_av = ArithValue(g) - y = (x_av * rrms_splat_av) * g_av - y_val = y + y = (x * rrms) * g if dtype_str == "bf16": if USE_HW_CVT_PK_BF16_F32: - out_e = y_val.truncf(vec_type_e) + out_e = y.to(elem_dtype) else: - vec_i32_ty = T.vec(VEC_WIDTH, T.i32) - vec4_i32_ty = T.vec(VEC_WIDTH // 2, T.i32) - vec_bf16_ty = T.vec(VEC_WIDTH, elem_type) - c16_i32 = arith.constant(16, type=T.i32) - c16_v = vector.broadcast(vec_i32_ty, c16_i32) - u = y_val.bitcast(vec_i32_ty) - upper = u.shrui(c16_v) - c1_v = vector.broadcast(vec_i32_ty, arith.constant(1, type=T.i32)) - lsb = upper & c1_v - c7fff_v = vector.broadcast(vec_i32_ty, arith.constant(0x7FFF, type=T.i32)) - bias = ArithValue(c7fff_v) + ArithValue(lsb) - u_round = ArithValue(u) + bias - bf16_bits = u_round.shrui(c16_v) - even = vector.shuffle(bf16_bits, bf16_bits, [0, 2, 4, 6]) - odd = vector.shuffle(bf16_bits, bf16_bits, [1, 3, 5, 7]) - odd_sh = odd << vector.broadcast(vec4_i32_ty, c16_i32) + u = y.bitcast(Uint32) + upper = u >> 16 + lsb = upper & 1 + bias = lsb + 0x7FFF + u_round = y.bitcast(Uint32) + bias + bf16_bits = u_round >> 16 + even = bf16_bits.shuffle(bf16_bits, [0, 2, 4, 6]) + odd = bf16_bits.shuffle(bf16_bits, [1, 3, 5, 7]) + odd_sh = odd << 16 packed = even | odd_sh - out_e = 
vector.bitcast(vec_bf16_ty, packed) + out_e = packed.bitcast(elem_dtype) elif dtype_str == "f32": - out_e = y_val + out_e = y else: - out_e = y_val.truncf(vec_type_e) + out_e = y.to(elem_dtype) out_idx = tid + tile_i * BLOCK_THREADS _store_vec(out_e, out_div, out_idx) @@ -243,7 +218,7 @@ def _store_vec(val, div_tensor, idx): # ============================================================== # Generic path: scalar 2-pass for arbitrary N # ============================================================== - from flydsl.expr.arith import ArithValue + elem_dtype = Numeric.from_ir_type(elem_type) Input_buf = fx.rocdl.make_buffer_tensor(Input) Output_buf = fx.rocdl.make_buffer_tensor(Output) @@ -267,14 +242,12 @@ def _load_scalar(divided_tensor, index): view = fx.slice(divided_tensor, (None, index)) r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) fx.copy_atom_call(copy_atom_s, view, r) - v = fx.memref_load_vec(r) - return vector.extract(v, static_position=[0]) + return fx.memref_load_vec(r)[0].ir_value() def _store_scalar(divided_tensor, index, val): r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) - vec_ty = T.vec(1, elem_type) - v = vector.from_elements(vec_ty, [val]) - fx.memref_store_vec(v, r) + ts = full(1, elem_dtype(val), elem_dtype) + fx.memref_store_vec(ts, r) view = fx.slice(divided_tensor, (None, index)) fx.copy_atom_call(copy_atom_s, r, view) @@ -296,7 +269,7 @@ def _store_scalar(divided_tensor, index, val): sum_sq = block_reduce_add(thread_sumsq) mean_sq = ArithValue(sum_sq) / n_float - ms_eps = ArithValue(mean_sq) + eps_c + ms_eps = mean_sq + eps_c rrms = ms_eps.rsqrt(fastmath=fm_fast) for base_idx_int in range_constexpr(0, N, BLOCK_THREADS): @@ -307,8 +280,8 @@ def _store_scalar(divided_tensor, index, val): g_e = _load_scalar(gamma_div, idx) x = x_e if dtype_str == "f32" else x_e.extf(compute_type) g = g_e if dtype_str == "f32" else g_e.extf(compute_type) - norm = ArithValue(x) * ArithValue(rrms) - y = norm * ArithValue(g) + norm = ArithValue(x) * rrms + y = norm * g if dtype_str == "f32": y_e = y elif dtype_str == "bf16": diff --git a/kernels/softmax_kernel.py b/kernels/softmax_kernel.py index 870add5b..8f5d4d32 100644 --- a/kernels/softmax_kernel.py +++ b/kernels/softmax_kernel.py @@ -17,9 +17,11 @@ import flydsl.expr as fx from flydsl.compiler.kernel_function import CompilationContext -from flydsl.expr import arith, vector, gpu, range_constexpr +from flydsl.expr import arith, gpu, range_constexpr from flydsl.expr.arith import ArithValue from flydsl.expr.typing import T, Int32 +from flydsl.expr.vector import ReductionOp, full +from flydsl.expr.numeric import Numeric, Float32 from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr from flydsl.runtime.device import get_rocm_arch as get_hip_arch @@ -122,12 +124,10 @@ def block_reduce(val, mode): # Fast path: N is a multiple of tile_cols # ================================================================== if False and N >= tile_cols and N % tile_cols == 0: - from flydsl.expr.arith import ArithValue + from flydsl.expr import math as fmath num_tiles = N // tile_cols - - vec_type_c = T.vec(VEC_WIDTH, compute_type) - vec_type_e = T.vec(VEC_WIDTH, elem_type) + elem_dtype = Numeric.from_ir_type(elem_type) # ── Layout API: buffer-backed tensors + tiled access ───── A_buf = fx.rocdl.make_buffer_tensor(A) @@ -148,7 +148,7 @@ def block_reduce(val, mode): def _load_vec(div_tensor, idx): r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) - return 
ArithValue(fx.memref_load_vec(r)) + return fx.memref_load_vec(r) def _store_vec(val, div_tensor, idx): r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) @@ -161,26 +161,23 @@ def _store_vec(val, div_tensor, idx): for tile_i in range_constexpr(num_tiles): idx = tid + tile_i * BLOCK_THREADS - vec_e = _load_vec(a_div, idx) - x = vec_e if dtype_str == "f32" else vec_e.extf(vec_type_c) + vec = _load_vec(a_div, idx) + x = vec.to(Float32) row_buffer.append(x) - red_max = vector.reduction(compute_type, vector.CombiningKind.MAXNUMF, x) + red_max = x.reduce(ReductionOp.MAX) thread_max = thread_max.maximumf(red_max) global_max = block_reduce(thread_max, "max") # 2. Exp + local sum - g_max_splat = vector.broadcast(vec_type_c, global_max) - log2e_splat = vector.broadcast(vec_type_c, c_log2e) thread_sum = c_zero_f for i in range_constexpr(num_tiles): x = row_buffer[i] - sub = ArithValue(x) - ArithValue(g_max_splat) - scaled = sub * ArithValue(log2e_splat) - exp_val = scaled.exp2(fastmath=fm_fast) + scaled = (x - global_max) * c_log2e + exp_val = fmath.exp2(scaled, fastmath=True) row_buffer[i] = exp_val - red_sum = vector.reduction(compute_type, vector.CombiningKind.ADD, exp_val, fastmath=fm_fast) + red_sum = exp_val.reduce(ReductionOp.ADD, fastmath=fm_fast) thread_sum = thread_sum + red_sum global_sum = block_reduce(thread_sum, "sum") @@ -188,16 +185,10 @@ def _store_vec(val, div_tensor, idx): # 3. Normalize + store c_one = arith.constant(1.0, type=compute_type) inv_sum = c_one / ArithValue(global_sum) - inv_sum_splat = vector.broadcast(vec_type_c, inv_sum) for tile_i in range_constexpr(num_tiles): - exp_vec = row_buffer[tile_i] - norm_vec = ArithValue(exp_vec) * ArithValue(inv_sum_splat) - - if dtype_str == "f32": - out_e = norm_vec - else: - out_e = norm_vec.truncf(vec_type_e) + norm_vec = row_buffer[tile_i] * inv_sum + out_e = norm_vec if dtype_str == "f32" else norm_vec.to(elem_dtype) out_idx = tid + tile_i * BLOCK_THREADS _store_vec(out_e, c_div, out_idx) @@ -206,7 +197,7 @@ def _store_vec(val, div_tensor, idx): # ============================================================== # Generic path: scalar for arbitrary N # ============================================================== - from flydsl.expr.arith import ArithValue + elem_dtype = Numeric.from_ir_type(elem_type) A_buf = fx.rocdl.make_buffer_tensor(A) C_buf = fx.rocdl.make_buffer_tensor(C) @@ -228,14 +219,12 @@ def _load_scalar(divided, index): view = fx.slice(divided, (None, index)) r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) fx.copy_atom_call(copy_atom_s, view, r) - v = fx.memref_load_vec(r) - return vector.extract(v, static_position=[0]) + return fx.memref_load_vec(r)[0].ir_value() def _store_scalar(divided, index, val): r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) - vec_ty = T.vec(1, elem_type) - v = vector.from_elements(vec_ty, [val]) - fx.memref_store_vec(v, r) + ts = full(1, elem_dtype(val), elem_dtype) + fx.memref_store_vec(ts, r) view = fx.slice(divided, (None, index)) fx.copy_atom_call(copy_atom_s, r, view) diff --git a/python/flydsl/expr/numeric.py b/python/flydsl/expr/numeric.py index 36591d0f..315357be 100644 --- a/python/flydsl/expr/numeric.py +++ b/python/flydsl/expr/numeric.py @@ -145,7 +145,7 @@ def _widen_narrow_int(x, widen_bool=False): def _resolve_float_type(ta, tb): """Pick the wider float type, or the one with higher rank at equal width.""" - _FLOAT_RANK = {Float64: 3, Float32: 2, Float16: 1, BFloat16: 1} + # Use module-level _FLOAT_RANK (defined after all classes) if ta.is_float and not tb.is_float: return ta 
if tb.is_float and not ta.is_float: @@ -256,7 +256,6 @@ def __hash__(self): def select(self, true_value, false_value, *, loc=None): """Ternary select (for Boolean conditions from Int32 comparisons).""" - from .utils.arith import ArithValue return ArithValue(self).select(true_value, false_value, loc=loc) @property @@ -370,6 +369,7 @@ def from_ir_type(ir_type): T.f8E5M2(): Float8E5M2, T.f8E4M3(): Float8E4M3, T.f8E4M3FN(): Float8E4M3FN, + Float8E4M3FNUZ.ir_type: Float8E4M3FNUZ, # not in upstream MLIR extras T T.f8E4M3B11FNUZ(): Float8E4M3B11FNUZ, T.f8E8M0FNU(): Float8E8M0FNU, T.f6E2M3FN(): Float6E2M3FN, @@ -683,7 +683,7 @@ class Float8E5M2(Float, metaclass=NumericMeta, width=8, ir_type=T.f8E5M2): ... class Float8E4M3FN(Float, metaclass=NumericMeta, width=8, ir_type=T.f8E4M3FN): ... -class Float8E4M3FNUZ(Float, metaclass=NumericMeta, width=8, ir_type=lambda: ir.Float8E4M3FNUZType.get()): ... +class Float8E4M3FNUZ(Float, metaclass=NumericMeta, width=8, ir_type=lambda: ir.Float8E4M3FNUZType.get()): ... # not in upstream MLIR extras T class Float8E4M3B11FNUZ(Float, metaclass=NumericMeta, width=8, ir_type=T.f8E4M3B11FNUZ): ... @@ -705,6 +705,65 @@ class Float4E2M1FN(Float, metaclass=NumericMeta, width=4, ir_type=T.f4E2M1FN): . +# Float type rank for promotion (must be after class definitions) +_FLOAT_RANK = {Float64: 3, Float32: 2, Float16: 1, BFloat16: 1} + +# ── Type promotion (added to Numeric after all subclasses exist) ────── + +_FLOAT_BY_MIN_WIDTH = {16: Float16, 32: Float32, 64: Float64} + + +def _widen_float(float_type, min_width): + """Return the narrowest standard float type with width >= *min_width*.""" + if float_type.width >= min_width: + return float_type + for w in (32, 64): + if w >= min_width: + return _FLOAT_BY_MIN_WIDTH[w] + return Float64 + + +@classmethod +def _promote(cls, a_type, b_type): + """Resolve the promoted result type for two Numeric types. + + :param a_type: Left Numeric class (e.g. Float16) + :param b_type: Right Numeric class (e.g. Float32) + :return: The common Numeric class both can be safely promoted to + """ + if a_type is b_type: + return a_type + + a_float = a_type.is_float + b_float = b_type.is_float + + if a_float and not b_float: + return _widen_float(a_type, b_type.width) + if b_float and not a_float: + return _widen_float(b_type, a_type.width) + + if a_float and b_float: + aw, bw = a_type.width, b_type.width + if aw > bw and aw >= 16: + return a_type + if bw > aw and bw >= 16: + return b_type + if aw == bw: + ra = _FLOAT_RANK.get(a_type, 0) + rb = _FLOAT_RANK.get(b_type, 0) + return a_type if ra >= rb else b_type + raise ValueError(f"cannot promote {a_type} and {b_type}; cast explicitly") + + # Both integers + if a_type.signed == b_type.signed: + return a_type if a_type.width >= b_type.width else b_type + u, s = (a_type, b_type) if not a_type.signed else (b_type, a_type) + return u if u.width >= s.width else s + + +Numeric.promote = _promote + + class Index(Integer, metaclass=NumericMeta, width=64, signed=False, ir_type=lambda: ir.IndexType.get()): """DSL Numeric for MLIR index type. Replaces arith.index(N). 
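# Reference only: a minimal sketch of how the new Numeric.promote resolves mixed
# operand types. The import path matches the one used elsewhere in this patch;
# the result comments are as implied by the _promote rules above, not separately
# verified behavior.
from flydsl.expr.numeric import Numeric, Float16, Float32, Int16, Int32, Uint32

Numeric.promote(Float16, Float32)   # -> Float32: the wider float wins
Numeric.promote(Int32, Float16)     # -> Float32: the float is widened to cover the 32-bit int
Numeric.promote(Int32, Int16)       # -> Int32:   same signedness, the wider width wins
Numeric.promote(Int32, Uint32)      # -> Uint32:  equal width, the unsigned type wins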
diff --git a/python/flydsl/expr/typing.py b/python/flydsl/expr/typing.py index c0d37cd7..852d2071 100644 --- a/python/flydsl/expr/typing.py +++ b/python/flydsl/expr/typing.py @@ -2,13 +2,25 @@ # Copyright (c) 2025 FlyDSL Project Contributors import ctypes -from typing import Generic, TypeVar +import enum +from inspect import isclass +from typing import Generic, Type, TypeVar from flydsl.runtime.device import get_rocm_arch from .._mlir import ir from .._mlir.dialects import gpu +from .._mlir.dialects import vector as _vector from .meta import traced_op +from .utils.arith import ( + ArithValue, + element_type, + fp_to_fp, + fp_to_int, + int_to_fp, + int_to_int, + _to_raw, +) from .numeric import ( BFloat16, Boolean, @@ -31,6 +43,7 @@ Int16, Int32, Int64, + Integer, Numeric, Uint8, Uint16, @@ -258,6 +271,12 @@ def vec(self, n: int, elem: ir.Type) -> ir.Type: "TiledMma", "Stream", "Tuple3D", + # Vector types + "Vector", + "ReductionOp", + "full", + "full_like", + "zeros_like", ] @@ -796,3 +815,194 @@ def __getattr__(self, name): def __iter__(self): return iter((self.x, self.y, self.z)) + + +# ═══════════════════════════════════════════════════════════════════════ +# Vector — register vector with value semantics +# ═══════════════════════════════════════════════════════════════════════ + +class ReductionOp(enum.Enum): + ADD = "add" + MUL = "mul" + MAX = "max" + MIN = "min" + + +_REDUCE_KINDS = { + "add": (_vector.CombiningKind.ADD, _vector.CombiningKind.ADD, _vector.CombiningKind.ADD), + "mul": (_vector.CombiningKind.MUL, _vector.CombiningKind.MUL, _vector.CombiningKind.MUL), + "max": (_vector.CombiningKind.MAXNUMF, _vector.CombiningKind.MAXSI, _vector.CombiningKind.MAXUI), + "min": (_vector.CombiningKind.MINIMUMF, _vector.CombiningKind.MINSI, _vector.CombiningKind.MINUI), +} + + +def _resolve_combining_kind(op, is_float, signed): + if isinstance(op, _vector.CombiningKind): + return op + if isinstance(op, ReductionOp): + key = op.value + elif isinstance(op, str): + key = op.lower() + else: + raise TypeError(f"reduce op must be str, ReductionOp, or CombiningKind, got {type(op)}") + triple = _REDUCE_KINDS.get(key) + if triple is None: + raise ValueError(f"unknown reduction kind {op!r}; expected one of {list(_REDUCE_KINDS)}") + return triple[0] if is_float else (triple[1] if signed else triple[2]) + + +@ir.register_value_caster(ir.VectorType.static_typeid, replace=True) +class Vector(ArithValue): + """Thread-local register vector with value semantics. + + Wraps a flat ``vector`` ir.Value with shape and dtype metadata. + Arithmetic operators are inherited from ArithValue; scalar operands + are auto-broadcast via ``_coerce_other``. 
+ """ + + def __init__(self, value, shape=None, dtype=None): + if not isinstance(value, ir.Value) and hasattr(value, "ir_value"): + value = value.ir_value() + if shape is None: + vty = ir.VectorType(value.type) + shape = tuple(vty.shape) + dtype = Numeric.from_ir_type(vty.element_type) + signed = dtype.signed if isclass(dtype) and issubclass(dtype, Integer) else False + super().__init__(value, signed) + self._shape = (shape,) if isinstance(shape, int) else tuple(shape) + self._dtype = dtype + + @property + def dtype(self) -> Type[Numeric]: + return self._dtype + + @property + def element_type(self) -> Type[Numeric]: + return self._dtype + + @property + def shape(self): + return self._shape + + @property + def numel(self) -> int: + r = 1 + for s in self._shape: + r *= s + return r + + def __str__(self): + return f"Vector({self.type} o {self._shape}, {self._dtype.__name__})" + + def __repr__(self): + return self.__str__() + + def __fly_values__(self): + return [self] + + @classmethod + def __fly_construct__(cls, values): + return values[0] + + def to(self, dtype: Type[Numeric], *, loc=None, ip=None) -> "Vector": + if dtype is ir.Value: + return self + if not isclass(dtype) or not issubclass(dtype, Numeric): + raise TypeError(f"dtype must be a Numeric type, got {type(dtype)}") + src_dtype = self._dtype + if src_dtype is dtype: + return self + src_float = getattr(src_dtype, 'is_float', False) + dst_float = getattr(dtype, 'is_float', False) + if src_float and dst_float: + res = fp_to_fp(self, dtype.ir_type, loc=loc, ip=ip) + elif src_float: + res = fp_to_int(self, dtype.signed, dtype.ir_type, loc=loc, ip=ip) + elif dst_float: + res = int_to_fp(self, src_dtype.signed, dtype.ir_type, loc=loc, ip=ip) + else: + res = int_to_int(self, dtype, loc=loc, ip=ip) + return Vector(res, self._shape, dtype) + + def ir_value(self, *, loc=None, ip=None): + return self + + def reduce(self, op, init_val=None, reduction_profile=None, + *, fastmath=None, loc=None, ip=None): + is_fp = self._dtype.is_float + signed = getattr(self._dtype, 'signed', True) + kind = _resolve_combining_kind(op, is_fp, signed) + et = element_type(self.type) + kwargs = {} + if fastmath is not None: + kwargs["fastmath"] = fastmath + if init_val is not None: + if isinstance(init_val, Numeric): + init_val = init_val.ir_value(loc=loc, ip=ip) + kwargs["acc"] = _to_raw(init_val) + res = _vector.reduction(et, kind, self, loc=loc, ip=ip, **kwargs) + return self._dtype(res) + + def __getitem__(self, idx): + if isinstance(idx, int): + res = _vector.ExtractOp( + self, static_position=[idx], dynamic_position=[] + ).result + return self._dtype(res) + raise TypeError(f"unsupported index type: {type(idx)}") + + def bitcast(self, dtype: Type[Numeric], *, loc=None, ip=None) -> "Vector": + src_bits = self.numel * self._dtype.width + dst_count = src_bits // dtype.width + dst_vec_ty = ir.VectorType.get([dst_count], dtype.ir_type) + res = _vector.BitCastOp(dst_vec_ty, self, loc=loc, ip=ip).result + return Vector(res, (dst_count,), dtype) + + def shuffle(self, other, mask, *, loc=None, ip=None) -> "Vector": + other_val = other if not isinstance(other, Vector) else ir.Value(other) + res = _vector.shuffle(self, other_val, mask, loc=loc, ip=ip) + return Vector(res, (len(mask),), self._dtype) + + @classmethod + def filled(cls, shape, fill_value, dtype: Type[Numeric], + *, loc=None, ip=None) -> "Vector": + shape = (shape,) if isinstance(shape, int) else tuple(shape) + n = 1 + for s in shape: + n *= s + if isinstance(fill_value, (int, float, bool)): + fill_value = 
dtype(fill_value) + elif isinstance(fill_value, Numeric): + fill_value = fill_value.to(dtype, loc=loc, ip=ip) + else: + raise ValueError(f"expected numeric fill_value, got {type(fill_value)}") + vec_ty = ir.VectorType.get([n], dtype.ir_type) + val = _vector.broadcast(vec_ty, fill_value.ir_value(loc=loc, ip=ip), + loc=loc, ip=ip) + return cls(val, shape, dtype) + + @classmethod + def filled_like(cls, template: "Vector", fill_value, dtype=None, + *, loc=None, ip=None) -> "Vector": + if dtype is None: + dtype = template.dtype + return cls.filled(template.shape, fill_value, dtype, loc=loc, ip=ip) + + @classmethod + def zeros_like(cls, template: "Vector", dtype=None, + *, loc=None, ip=None) -> "Vector": + if dtype is None: + dtype = template.dtype + return cls.filled(template.shape, 0.0 if dtype.is_float else 0, dtype, loc=loc, ip=ip) + + +def full(shape, fill_value, dtype: Type[Numeric], *, loc=None, ip=None) -> Vector: + return Vector.filled(shape, fill_value, dtype, loc=loc, ip=ip) + + +def full_like(a: Vector, fill_value, dtype=None, *, loc=None, ip=None) -> Vector: + return Vector.filled_like(a, fill_value, dtype, loc=loc, ip=ip) + + +def zeros_like(a: Vector, dtype=None, *, loc=None, ip=None) -> Vector: + return Vector.zeros_like(a, dtype, loc=loc, ip=ip) diff --git a/python/flydsl/expr/utils/arith.py b/python/flydsl/expr/utils/arith.py index c255d310..64b01c03 100644 --- a/python/flydsl/expr/utils/arith.py +++ b/python/flydsl/expr/utils/arith.py @@ -120,8 +120,13 @@ def _coerce_other(self, other, *, loc=None, ip=None): if not isinstance(other, ArithValue): # Accept DSL Numeric types (Int32, Float32, etc.) by unwrapping via ir_value() if hasattr(other, "ir_value"): - return ArithValue(other.ir_value()) - return NotImplemented + other = ArithValue(other.ir_value()) + else: + return NotImplemented + # Broadcast scalar to vector when self is a vector and other is scalar + if isinstance(self.type, ir.VectorType) and not isinstance(other.type, ir.VectorType): + from ..._mlir.dialects import vector as _vector + return _vector.broadcast(self.type, _to_raw(other), loc=loc, ip=ip) return other @@ -275,12 +280,12 @@ def _neg_op(self, *, loc=None, ip=None): raise TypeError("negation is not supported for boolean type") if self.is_float: return arith.negf(self, loc=loc, ip=ip) - c0 = arith.constant(self.type, 0, loc=loc, ip=ip) + c0 = arith_const(0, self.type, loc=loc, ip=ip) return arith.subi(c0, self, loc=loc, ip=ip) def _invert_op(self, *, loc=None, ip=None): - return arith.xori(self, arith.constant(self.type, -1)) + return arith.xori(self, arith_const(-1, self.type, loc=loc, ip=ip)) @ir.register_value_caster(ir.Float4E2M1FNType.static_typeid) diff --git a/python/flydsl/expr/vector.py b/python/flydsl/expr/vector.py index c4af4c0e..002adac9 100644 --- a/python/flydsl/expr/vector.py +++ b/python/flydsl/expr/vector.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2025 FlyDSL Project Contributors -"""Vector dialect helpers. +"""Vector dialect helpers and re-exports. -This module exists so tests can import vector ops through `flydsl._mlir_helpers` -instead of directly importing from `_mlir.dialects.*`. +The ``Vector`` class itself lives in ``typing.py`` alongside other builtin +DSL types. This module re-exports it for convenience and provides thin +wrappers around upstream ``_mlir.dialects.vector`` ops. 
""" from __future__ import annotations @@ -12,18 +13,21 @@ from .._mlir.dialects import vector as _vector from .meta import traced_op - -# Re-export everything from the upstream dialect module for convenience. +# Re-export upstream dialect for ``from flydsl.expr import vector; vector.broadcast(...)`` from .._mlir.dialects.vector import * # noqa: F401,F403,E402 +# Re-export Vector and friends so ``from flydsl.expr.vector import Vector`` works +from .typing import Vector, ReductionOp, full, full_like, zeros_like # noqa: F401 + + +# ═══════════════════════════════════════════════════════════════════════ +# Dialect helper wrappers (legacy, will be deprecated) +# Prefer using Vector methods or _mlir.dialects.vector directly. +# ═══════════════════════════════════════════════════════════════════════ @traced_op def from_elements(*args, loc=None, ip=None, **kwargs): - """Construct a vector from scalar elements, auto-unwrapping ArithValue wrappers. - - Accepts the same arguments as ``vector.from_elements`` but transparently - handles ``ArithValue`` scalars so callers don't need explicit unwrapping. - """ + """Construct a vector from scalar elements, auto-unwrapping ArithValue wrappers.""" from . import arith as _arith_ext if len(args) >= 2: @@ -51,11 +55,6 @@ def store(value, memref, indices, *, loc=None, ip=None, **kwargs): ) -# ----------------------------------------------------------------------------- -# Thin wrappers for common op classes that otherwise require `.result` access. -# ----------------------------------------------------------------------------- - - @traced_op def extract(vector, static_position=None, dynamic_position=None, *, loc=None, ip=None): """Wrapper around `vector.ExtractOp(...).result`.""" @@ -100,4 +99,3 @@ def bitcast(result_type, source, *, loc=None, ip=None): loc=loc, ip=ip, ).result - diff --git a/tests/kernels/test_quant.py b/tests/kernels/test_quant.py index 64d4b982..e0c842c5 100644 --- a/tests/kernels/test_quant.py +++ b/tests/kernels/test_quant.py @@ -50,13 +50,14 @@ import flydsl.compiler as flyc import flydsl.expr as fx from flydsl.compiler.kernel_function import CompilationContext -from flydsl.expr import arith, vector, gpu, range_constexpr +from flydsl.expr import arith, gpu, range_constexpr from flydsl.expr.typing import T, Int32 from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr from flydsl.runtime.device import get_rocm_arch from flydsl._mlir import ir -from flydsl.expr import buffer_ops from flydsl.expr.arith import ArithValue +from flydsl.expr.vector import Vector, full, ReductionOp +from flydsl.expr.numeric import Float16, Float32, Int8 from tests.test_common import run_perftest BLOCK_THREADS = 256 @@ -148,37 +149,49 @@ def block_reduce_max(val): # ── Layout API: buffer-backed tensors ──────────────────────────── Input_buf = fx.rocdl.make_buffer_tensor(Input) - scales_rsrc = buffer_ops.create_buffer_resource(Scales, max_size=True) - # i8 output: keep buffer_ops (BufferCopy64b unsupported by LLVM backend) - out_rsrc = buffer_ops.create_buffer_resource(Output, max_size=True) + Out_buf = fx.rocdl.make_buffer_tensor(Output) + Scales_buf = fx.rocdl.make_buffer_tensor(Scales) - # Slice row for this block, tile by VEC_WIDTH - row_in = fx.slice(Input_buf, (bid, None)) + # Slice at row 0; actual row offset via soffset (SGPR) + bid_row_offset = ArithValue(bid) * fx.Int32(N) + + row_in = fx.slice(Input_buf, (0, None)) in_div = fx.logical_divide(row_in, fx.make_layout(VEC_WIDTH, 1)) - # Copy atom for f16 loads: 8 x f16 = 128b - copy_atom_in = 
fx.make_copy_atom(fx.rocdl.BufferCopy128b(), 16) + # Copy atom for f16 loads: 8 x f16 = 128b, with soffset for row + copy_atom_in_base = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), 16) + copy_atom_in = copy_atom_in_base.set_value("soffset", bid_row_offset) vec_reg_ty_f16 = fx.MemRefType.get( T.f16, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register ) vec_reg_lay = fx.make_layout(VEC_WIDTH, 1) - elem_bytes_i8 = 1 - vec_dwords_i8 = (VEC_WIDTH * elem_bytes_i8) // 4 # 2 - row_soffset_out = ArithValue(bid) * (N * elem_bytes_i8) - thr_col_bytes_i8 = ArithValue(tid) * (VEC_WIDTH * elem_bytes_i8) + # Copy atom for i8 output: 8 x i8 = 64b, with soffset for row + copy_atom_out_base = fx.make_copy_atom(fx.rocdl.BufferCopy64b(), 8) + copy_atom_out = copy_atom_out_base.set_value("soffset", bid_row_offset) + out_row = fx.slice(Out_buf, (0, None)) + out_div = fx.logical_divide(out_row, fx.make_layout(VEC_WIDTH, 1)) + i8_vec_reg_ty = fx.MemRefType.get( + T.i8, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register + ) + + # Copy atom for f32 scales: scalar store with soffset for bid + copy_atom_f32_base = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) + copy_atom_f32 = copy_atom_f32_base.set_value("soffset", bid) + scales_div = fx.logical_divide(Scales_buf, fx.make_layout(1, 1)) + scales_base = fx.slice(scales_div, (None, fx.Int32(0))) + f32_reg_ty = fx.MemRefType.get( + T.f32, fx.LayoutType.get(1, 1), fx.AddressSpace.Register + ) + f32_reg_lay = fx.make_layout(1, 1) def _load_vec_f16(div_tensor, idx): r = fx.memref_alloca(vec_reg_ty_f16, vec_reg_lay) fx.copy_atom_call(copy_atom_in, fx.slice(div_tensor, (None, idx)), r) - return ArithValue(fx.memref_load_vec(r)) + return Vector(fx.memref_load_vec(r), VEC_WIDTH, Float16) - vec_type_f32 = T.vec(VEC_WIDTH, T.f32) - - # abs via sign-bit clearing (|x| = bitcast(bitcast(x, i32) & 0x7FFFFFFF, f32)) - i32_vec_ty = T.vec(VEC_WIDTH, T.i32) - abs_mask = arith.constant_vector(0x7FFFFFFF, i32_vec_ty) + abs_mask = full(VEC_WIDTH, Int32(0x7FFFFFFF), Int32) c_zero_f = arith.constant(0.0, type=T.f32) local_max = c_zero_f @@ -191,15 +204,15 @@ def _load_vec_f16(div_tensor, idx): is_valid = col_end <= N vec_f16 = _load_vec_f16(in_div, idx) - vec_f32 = vec_f16.extf(vec_type_f32) + vec_f32 = vec_f16.to(Float32) cached_vecs.append(vec_f32) - vec_i32 = vec_f32.bitcast(i32_vec_ty) - vec_abs_i32 = arith.andi(vec_i32, abs_mask) - vec_abs = vector.bitcast(vec_type_f32, vec_abs_i32) + vec_i32 = vec_f32.bitcast(Int32) + vec_abs_i32 = vec_i32 & abs_mask + vec_abs = vec_abs_i32.bitcast(Float32) - chunk_max = vector.reduction(T.f32, "maxnumf", vec_abs) - chunk_max_safe = arith.select(is_valid, chunk_max, c_zero_f) + chunk_max = vec_abs.reduce(ReductionOp.MAX) + chunk_max_safe = arith.select(is_valid, chunk_max.ir_value(), c_zero_f) local_max = local_max.maximumf(chunk_max_safe) reduced_max = block_reduce_max(local_max) @@ -213,31 +226,33 @@ def _load_vec_f16(div_tensor, idx): # thread 0 stores Scales[bid] if arith.cmpi(arith.CmpIPredicate.eq, tid, Int32(0)): - buffer_ops.buffer_store(final_scale, scales_rsrc, bid) + r_sc = fx.memref_alloca(f32_reg_ty, f32_reg_lay) + ts_sc = full(1, Float32(final_scale), Float32) + fx.memref_store_vec(ts_sc, r_sc) + fx.copy_atom_call(copy_atom_f32, r_sc, scales_base) # ── Pass 2: quantize f32 → i8, store ───────────────────────────── inv_scale = ArithValue(c_1) / ArithValue(final_scale) - inv_scale_splat = vector.broadcast(vec_type_f32, inv_scale) - - vec_type_i8 = T.vec(VEC_WIDTH, T.i8) - i32_vec_ty_out = T.vec(vec_dwords_i8, T.i32) for 
tile_i in range_constexpr(num_tiles): - col_end = ArithValue(tid) * VEC_WIDTH + (tile_i * tile_cols + VEC_WIDTH) - is_valid = col_end <= N - vec_f32 = cached_vecs[tile_i] - vec_scaled = ArithValue(vec_f32) * ArithValue(inv_scale_splat) - - vec_i8 = arith.FPToSIOp(vec_type_i8, arith.unwrap(vec_scaled)).result - out_packed = vector.bitcast(i32_vec_ty_out, vec_i8) - - col_bytes_out = ArithValue(thr_col_bytes_i8) + (tile_i * tile_cols * elem_bytes_i8) - dw_out = col_bytes_out.shrui(arith.constant(2, type=T.i32)) - buffer_ops.buffer_store( - out_packed, out_rsrc, dw_out, - soffset_bytes=row_soffset_out, mask=is_valid, - ) + vec_scaled = vec_f32 * inv_scale + vec_i8 = vec_scaled.to(Int8) + idx_out = tid + tile_i * BLOCK_THREADS + + # Python-level check: only the last tile of non-aligned N needs a guard + last_partial = (N % tile_cols != 0) and (tile_i == num_tiles - 1) + if last_partial: + col_end = ArithValue(tid) * VEC_WIDTH + (tile_i * tile_cols + VEC_WIDTH) + is_valid = col_end <= N + if is_valid: + r_out = fx.memref_alloca(i8_vec_reg_ty, vec_reg_lay) + fx.memref_store_vec(vec_i8, r_out) + fx.copy_atom_call(copy_atom_out, r_out, fx.slice(out_div, (None, idx_out))) + else: + r_out = fx.memref_alloca(i8_vec_reg_ty, vec_reg_lay) + fx.memref_store_vec(vec_i8, r_out) + fx.copy_atom_call(copy_atom_out, r_out, fx.slice(out_div, (None, idx_out))) @flyc.jit def launch_quant( diff --git a/tests/unit/test_vector.py b/tests/unit/test_vector.py new file mode 100644 index 00000000..758449c8 --- /dev/null +++ b/tests/unit/test_vector.py @@ -0,0 +1,799 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Unit tests for Vector, ReductionOp, math (vector support), and factory functions. + +All tests are IR-level (no GPU required). They build MLIR modules using +Vector operations and verify the generated IR text. 
+""" + +import pytest + +from flydsl._mlir import ir +from flydsl._mlir.dialects import arith, func + +from flydsl.expr.vector import ( + Vector, + ReductionOp, + full, + full_like, + zeros_like, +) +from flydsl.expr.numeric import ( + Float32, + Float16, + BFloat16, + Float64, + Int32, + Int16, + Uint32, + Boolean, + Numeric, +) +from flydsl.expr import math as fmath + +pytestmark = pytest.mark.l0_backend_agnostic + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _build_module(build_fn, arg_types=None): + """Build an MLIR module, call *build_fn* with block arguments, return IR text.""" + with ir.Context() as ctx: + ctx.allow_unregistered_dialects = True + with ir.Location.unknown(ctx): + if arg_types is None: + types = [ir.VectorType.get([8], ir.F32Type.get())] + else: + types = [t() if callable(t) else t for t in arg_types] + module = ir.Module.create() + with ir.InsertionPoint(module.body): + ftype = ir.FunctionType.get(types, []) + f = func.FuncOp("test", ftype) + with ir.InsertionPoint(f.add_entry_block()): + build_fn(*f.entry_block.arguments) + func.ReturnOp([]) + module.operation.verify() + return str(module) + + +def _vec_f32(): + return ir.VectorType.get([8], ir.F32Type.get()) + +def _vec_f16(): + return ir.VectorType.get([8], ir.F16Type.get()) + +def _vec_bf16(): + return ir.VectorType.get([8], ir.BF16Type.get()) + +def _vec_i32(): + return ir.VectorType.get([8], ir.IntegerType.get_signless(32)) + +def _vec_i16(): + return ir.VectorType.get([8], ir.IntegerType.get_signless(16)) + + +# =========================================================================== +# A. Construction & properties +# =========================================================================== + +class TestConstruction: + + def test_init_from_vector(self): + def build(raw): + t = Vector(raw, 8, Float32) + assert t.shape == (8,) + assert t.dtype is Float32 + assert t.element_type is Float32 + assert t.numel == 8 + _build_module(build) + + def test_init_shape_int_vs_tuple(self): + def build(raw): + t1 = Vector(raw, 8, Float32) + t2 = Vector(raw, (8,), Float32) + assert t1.shape == t2.shape == (8,) + _build_module(build) + + def test_signed_false_for_float(self): + def build(raw): + t = Vector(raw, 8, Float32) + assert t.signed is False + _build_module(build) + + def test_signed_true_for_int32(self): + def build(raw): + t = Vector(raw, 8, Int32) + assert t.signed is True + _build_module(build, [_vec_i32]) + + def test_signed_false_for_uint32(self): + def build(raw): + t = Vector(raw, 8, Uint32) + assert t.signed is False + _build_module(build, [_vec_i32]) + + def test_str_repr(self): + def build(raw): + t = Vector(raw, 8, Float32) + s = str(t) + assert "Vector" in s + assert "Float32" in s + _build_module(build) + + +# =========================================================================== +# B. 
Operators +# =========================================================================== + +class TestOperators: + + def test_add_two_tensors(self): + def build(a, b): + ta = Vector(a, 8, Float32) + tb = Vector(b, 8, Float32) + _ = ta + tb + ir_text = _build_module(build, [_vec_f32, _vec_f32]) + assert "arith.addf" in ir_text + + def test_mul_scalar_broadcast(self): + def build(a): + ta = Vector(a, 8, Float32) + _ = ta * 2.0 + ir_text = _build_module(build) + # Scalar 2.0 is splatted into a vector constant via arith_const + assert "arith.mulf" in ir_text + + def test_sub_reverse(self): + def build(a): + ta = Vector(a, 8, Float32) + _ = 1.0 - ta + ir_text = _build_module(build) + # Scalar 1.0 is splatted into a vector constant via arith_const + assert "arith.subf" in ir_text + + def test_int_add(self): + def build(a, b): + ta = Vector(a, 8, Int32) + tb = Vector(b, 8, Int32) + _ = ta + tb + ir_text = _build_module(build, [_vec_i32, _vec_i32]) + assert "arith.addi" in ir_text + + def test_comparison_returns_boolean(self): + def build(a, b): + ta = Vector(a, 8, Float32) + tb = Vector(b, 8, Float32) + result = ta < tb + assert isinstance(result, Vector) + assert result.dtype is Boolean + _build_module(build, [_vec_f32, _vec_f32]) + + def test_bitwise_and_or_xor(self): + def build(a, b): + ta = Vector(a, 8, Uint32) + tb = Vector(b, 8, Uint32) + _ = ta & tb + _ = ta | tb + _ = ta ^ tb + ir_text = _build_module(build, [_vec_i32, _vec_i32]) + assert "arith.andi" in ir_text + assert "arith.ori" in ir_text + assert "arith.xori" in ir_text + + def test_shift_ops(self): + def build(a): + ta = Vector(a, 8, Uint32) + _ = ta >> 16 + _ = ta << 8 + ir_text = _build_module(build, [_vec_i32]) + assert "arith.shrui" in ir_text + assert "arith.shli" in ir_text + + def test_unsigned_shift_uses_shrui(self): + """Uint32 Vector >> must use shrui, not shrsi.""" + def build(a): + ta = Vector(a, 8, Uint32) + _ = ta >> 16 + ir_text = _build_module(build, [_vec_i32]) + assert "arith.shrui" in ir_text + assert "arith.shrsi" not in ir_text + + def test_signed_shift_uses_shrsi(self): + """Int32 Vector >> must use shrsi.""" + def build(a): + ta = Vector(a, 8, Int32) + _ = ta >> 16 + ir_text = _build_module(build, [_vec_i32]) + assert "arith.shrsi" in ir_text + + def test_neg(self): + def build(a): + ta = Vector(a, 8, Float32) + _ = -ta + ir_text = _build_module(build) + assert "arith.negf" in ir_text + + def test_truediv(self): + def build(a, b): + ta = Vector(a, 8, Float32) + tb = Vector(b, 8, Float32) + _ = ta / tb + ir_text = _build_module(build, [_vec_f32, _vec_f32]) + assert "arith.divf" in ir_text + + def test_pow(self): + def build(a, b): + ta = Vector(a, 8, Float32) + tb = Vector(b, 8, Float32) + _ = ta ** tb + ir_text = _build_module(build, [_vec_f32, _vec_f32]) + assert "math.powf" in ir_text + + def test_floordiv_int(self): + def build(a, b): + ta = Vector(a, 8, Int32) + tb = Vector(b, 8, Int32) + _ = ta // tb + ir_text = _build_module(build, [_vec_i32, _vec_i32]) + assert "arith.floordivsi" in ir_text + + def test_mod_int(self): + def build(a, b): + ta = Vector(a, 8, Int32) + tb = Vector(b, 8, Int32) + _ = ta % tb + ir_text = _build_module(build, [_vec_i32, _vec_i32]) + assert "arith.remsi" in ir_text + + def test_neg_int(self): + """Negating an integer Vector should produce arith.subi (0 - x).""" + def build(a): + ta = Vector(a, 8, Int32) + result = -ta + assert isinstance(result, Vector) + ir_text = _build_module(build, [_vec_i32]) + assert "arith.subi" in ir_text + + def test_reverse_bitwise(self): + def 
build(a, b): + ta = Vector(a, 8, Uint32) + tb = Vector(b, 8, Uint32) + # reverse ops: rhs.__rand__ etc. + r1 = tb & ta + r2 = tb | ta + r3 = tb ^ ta + assert isinstance(r1, Vector) + assert isinstance(r2, Vector) + assert isinstance(r3, Vector) + ir_text = _build_module(build, [_vec_i32, _vec_i32]) + assert "arith.andi" in ir_text + assert "arith.ori" in ir_text + assert "arith.xori" in ir_text + + +# =========================================================================== +# C. Type promotion +# =========================================================================== + +class TestTypePromotion: + + def test_same_type(self): + assert Numeric.promote(Float32, Float32) is Float32 + + def test_f16_f32(self): + assert Numeric.promote(Float16, Float32) is Float32 + + def test_bf16_f32(self): + assert Numeric.promote(BFloat16, Float32) is Float32 + + def test_int_float(self): + """Int32 + Float32 → Float32.""" + assert Numeric.promote(Int32, Float32) is Float32 + + def test_int_wider_than_float(self): + """Float16 + Int32 → Float32 (int width 32 > float width 16).""" + assert Numeric.promote(Float16, Int32) is Float32 + + def test_int_same_width_as_float(self): + """Float32 + Int32 → Float32 (same width, float wins).""" + assert Numeric.promote(Float32, Int32) is Float32 + + def test_int_narrower_than_float(self): + """Float32 + Int16 → Float32 (int is narrower).""" + assert Numeric.promote(Float32, Int16) is Float32 + + def test_int64_with_float32(self): + """Float32 + Int64 → Float64 (int width 64 > float width 32).""" + from flydsl.expr.numeric import Int64 + assert Numeric.promote(Float32, Int64) is Float64 + + def test_f16_f64(self): + assert Numeric.promote(Float16, Float64) is Float64 + + def test_promote_in_operator(self): + """Mixed-type vector ops require explicit .to() conversion (no auto-promote).""" + def build(a, b): + ta = Vector(a, 8, Float16) + tb = Vector(b, 8, Float32) + ta_f32 = ta.to(Float32) + result = ta_f32 + tb + assert result.dtype is Float32 + ir_text = _build_module(build, [_vec_f16, _vec_f32]) + assert "arith.extf" in ir_text + assert "arith.addf" in ir_text + + def test_mixed_signed_unsigned_int(self): + """Int32 + Uint32 → Uint32 (unsigned wins at same width).""" + assert Numeric.promote(Int32, Uint32) is Uint32 + assert Numeric.promote(Uint32, Int32) is Uint32 + + def test_promote_bf16_scalar(self): + """BFloat16 tensor + scalar → explicit .to() needed for mixed-type ops.""" + def build(a): + ta = Vector(a, 8, BFloat16) + ta_f32 = ta.to(Float32) + result = ta_f32 + 1.0 + assert result.dtype is Float32 + ir_text = _build_module(build, [_vec_bf16]) + assert "arith.extf" in ir_text + assert "arith.addf" in ir_text + + +# =========================================================================== +# D. 
Type conversion (.to()) +# =========================================================================== + +class TestToConversion: + + def test_same_type_noop(self): + def build(a): + ta = Vector(a, 8, Float32) + result = ta.to(Float32) + assert result is ta + _build_module(build) + + def test_float_to_float_truncf(self): + def build(a): + ta = Vector(a, 8, Float32) + _ = ta.to(BFloat16) + ir_text = _build_module(build) + assert "arith.truncf" in ir_text + + def test_float_to_float_extf(self): + def build(a): + ta = Vector(a, 8, Float16) + _ = ta.to(Float32) + ir_text = _build_module(build, [_vec_f16]) + assert "arith.extf" in ir_text + + def test_float_to_int(self): + def build(a): + ta = Vector(a, 8, Float32) + _ = ta.to(Int32) + ir_text = _build_module(build) + assert "arith.fptosi" in ir_text + + def test_int_to_float(self): + def build(a): + ta = Vector(a, 8, Int32) + _ = ta.to(Float32) + ir_text = _build_module(build, [_vec_i32]) + assert "arith.sitofp" in ir_text + + def test_uint_to_float(self): + """Uint32 → Float32 should use uitofp, not sitofp.""" + def build(a): + ta = Vector(a, 8, Uint32) + result = ta.to(Float32) + assert result.dtype is Float32 + ir_text = _build_module(build, [_vec_i32]) + assert "arith.uitofp" in ir_text + assert "arith.sitofp" not in ir_text + + def test_float_to_uint(self): + """Float32 → Uint32 should use fptoui, not fptosi.""" + def build(a): + ta = Vector(a, 8, Float32) + result = ta.to(Uint32) + assert result.dtype is Uint32 + ir_text = _build_module(build) + assert "arith.fptoui" in ir_text + assert "arith.fptosi" not in ir_text + + def test_int16_to_int32(self): + """Int16 → Int32 should use extsi.""" + def build(a): + ta = Vector(a, 8, Int16) + result = ta.to(Int32) + assert result.dtype is Int32 + assert result.shape == (8,) + ir_text = _build_module(build, [_vec_i16]) + assert "arith.extsi" in ir_text + + def test_to_ir_value_returns_self(self): + """to(ir.Value) should return self unchanged.""" + def build(a): + ta = Vector(a, 8, Float32) + result = ta.to(ir.Value) + assert result is ta + _build_module(build) + + def test_to_preserves_shape(self): + def build(a): + ta = Vector(a, 8, Float32) + result = ta.to(BFloat16) + assert result.shape == (8,) + assert result.dtype is BFloat16 + _build_module(build) + + +# =========================================================================== +# E. 
Reduction +# =========================================================================== + +class TestReduction: + + def test_reduce_add(self): + def build(a): + ta = Vector(a, 8, Float32) + _ = ta.reduce(ReductionOp.ADD) + ir_text = _build_module(build) + assert "vector.reduction " in ir_text + + def test_reduce_max(self): + def build(a): + ta = Vector(a, 8, Float32) + _ = ta.reduce(ReductionOp.MAX) + ir_text = _build_module(build) + assert "vector.reduction " in ir_text + + def test_reduce_min(self): + def build(a): + ta = Vector(a, 8, Float32) + _ = ta.reduce(ReductionOp.MIN) + ir_text = _build_module(build) + assert "vector.reduction " in ir_text + + def test_reduce_with_fastmath(self): + def build(a): + ta = Vector(a, 8, Float32) + fm = arith.FastMathFlags.fast + _ = ta.reduce(ReductionOp.ADD, fastmath=fm) + ir_text = _build_module(build) + assert "fastmath" in ir_text.lower() or "fast" in ir_text + + def test_reduce_returns_numeric(self): + """reduce() should return Numeric, not raw ir.Value.""" + def build(a): + ta = Vector(a, 8, Float32) + result = ta.reduce(ReductionOp.ADD) + assert isinstance(result, Float32) + _build_module(build) + + def test_int_reduce_add(self): + def build(a): + ta = Vector(a, 8, Int32) + result = ta.reduce(ReductionOp.ADD) + assert isinstance(result, Int32) + ir_text = _build_module(build, [_vec_i32]) + assert "vector.reduction " in ir_text + + def test_int_reduce_max_signed(self): + """Int32 MAX should use maxsi, not maxnumf.""" + def build(a): + ta = Vector(a, 8, Int32) + result = ta.reduce(ReductionOp.MAX) + assert isinstance(result, Int32) + ir_text = _build_module(build, [_vec_i32]) + assert "vector.reduction " in ir_text + + def test_int_reduce_max_unsigned(self): + """Uint32 MAX should use maxui.""" + def build(a): + ta = Vector(a, 8, Uint32) + result = ta.reduce(ReductionOp.MAX) + assert isinstance(result, Uint32) + ir_text = _build_module(build, [_vec_i32]) + assert "vector.reduction " in ir_text + + def test_int_reduce_min_signed(self): + """Int32 MIN should use minsi.""" + def build(a): + ta = Vector(a, 8, Int32) + result = ta.reduce(ReductionOp.MIN) + assert isinstance(result, Int32) + ir_text = _build_module(build, [_vec_i32]) + assert "vector.reduction " in ir_text + + def test_int_reduce_min_unsigned(self): + """Uint32 MIN should use minui.""" + def build(a): + ta = Vector(a, 8, Uint32) + result = ta.reduce(ReductionOp.MIN) + assert isinstance(result, Uint32) + ir_text = _build_module(build, [_vec_i32]) + assert "vector.reduction " in ir_text + + def test_reduce_with_init_val(self): + """reduce() with init_val should pass acc to vector.reduction.""" + def build(a): + ta = Vector(a, 8, Float32) + init = Float32(0.0) + result = ta.reduce(ReductionOp.ADD, init_val=init) + assert isinstance(result, Float32) + ir_text = _build_module(build) + assert "vector.reduction " in ir_text + + def test_reduce_string_add(self): + """reduce() accepts plain string 'add'.""" + def build(a): + ta = Vector(a, 8, Float32) + result = ta.reduce("add") + assert isinstance(result, Float32) + ir_text = _build_module(build) + assert "vector.reduction " in ir_text + + def test_reduce_string_max(self): + """reduce() accepts plain string 'max'.""" + def build(a): + ta = Vector(a, 8, Float32) + _ = ta.reduce("max") + ir_text = _build_module(build) + assert "vector.reduction " in ir_text + + def test_reduce_combining_kind_direct(self): + """reduce() accepts raw CombiningKind.""" + from flydsl._mlir.dialects.vector import CombiningKind + def build(a): + ta = Vector(a, 8, 
Float32) + result = ta.reduce(CombiningKind.ADD) + assert isinstance(result, Float32) + ir_text = _build_module(build) + assert "vector.reduction " in ir_text + + def test_reduce_bad_op_raises(self): + """reduce() raises on invalid op type.""" + def build(a): + ta = Vector(a, 8, Float32) + with pytest.raises(TypeError): + ta.reduce(42) + _build_module(build) + + def test_reduce_bad_string_raises(self): + """reduce() raises on unknown string.""" + def build(a): + ta = Vector(a, 8, Float32) + with pytest.raises(ValueError, match="unknown"): + ta.reduce("foobar") + _build_module(build) + + +# =========================================================================== +# F. Element access +# =========================================================================== + +class TestElementAccess: + + def test_getitem_int(self): + def build(a): + ta = Vector(a, 8, Float32) + elem = ta[0] + assert isinstance(elem, Float32) + ir_text = _build_module(build) + assert "vector.extract" in ir_text + + def test_getitem_invalid_type(self): + def build(a): + ta = Vector(a, 8, Float32) + with pytest.raises(TypeError): + ta["bad"] + _build_module(build) + + +# =========================================================================== +# G. Vector ops +# =========================================================================== + +class TestVectorOps: + + def test_bitcast(self): + def build(a): + ta = Vector(a, 8, Float32) + result = ta.bitcast(Uint32) + assert result.shape == (8,) + assert result.dtype is Uint32 + ir_text = _build_module(build) + assert "vector.bitcast" in ir_text + + def test_bitcast_width_change(self): + """f32 → f16 bitcast: 8 elements * 32 bits = 256 bits → 16 elements * 16 bits.""" + def build(a): + ta = Vector(a, 8, Float32) + result = ta.bitcast(Float16) + assert result.shape == (16,) + assert result.dtype is Float16 + ir_text = _build_module(build) + assert "vector.bitcast" in ir_text + + def test_shuffle(self): + def build(a, b): + ta = Vector(a, 8, Float32) + tb = Vector(b, 8, Float32) + result = ta.shuffle(tb, [0, 2, 4, 6]) + assert result.shape == (4,) + assert result.dtype is Float32 + ir_text = _build_module(build, [_vec_f32, _vec_f32]) + assert "vector.shuffle" in ir_text + + +# =========================================================================== +# H. 
Factory functions +# =========================================================================== + +class TestFactories: + + def test_full(self): + def build(a): + t = full(8, 1.0, Float32) + assert t.shape == (8,) + assert t.dtype is Float32 + ir_text = _build_module(build) + assert "vector.broadcast" in ir_text + + def test_full_like(self): + def build(a): + ta = Vector(a, 8, Float32) + t = full_like(ta, 0.0) + assert t.shape == ta.shape + assert t.dtype == ta.dtype + ir_text = _build_module(build) + assert "vector.broadcast" in ir_text + + def test_zeros_like(self): + def build(a): + ta = Vector(a, 8, Float32) + t = zeros_like(ta) + assert t.shape == ta.shape + assert t.dtype == ta.dtype + ir_text = _build_module(build) + assert "vector.broadcast" in ir_text + + def test_full_with_numeric_fill(self): + """full() with Numeric fill_value should work.""" + def build(a): + t = full(8, Float32(2.5), Float32) + assert t.shape == (8,) + assert t.dtype is Float32 + ir_text = _build_module(build) + assert "vector.broadcast" in ir_text + + def test_classmethod_filled(self): + def build(a): + t = Vector.filled(8, 1.0, Float32) + assert t.shape == (8,) + assert t.dtype is Float32 + ir_text = _build_module(build) + assert "vector.broadcast" in ir_text + + def test_classmethod_filled_like(self): + def build(a): + ta = Vector(a, 8, Float32) + t = Vector.filled_like(ta, 2.0) + assert t.shape == ta.shape + assert t.dtype is Float32 + ir_text = _build_module(build) + assert "vector.broadcast" in ir_text + + def test_classmethod_zeros_like(self): + def build(a): + ta = Vector(a, 8, Float32) + t = Vector.zeros_like(ta) + assert t.shape == ta.shape + assert t.dtype is Float32 + ir_text = _build_module(build) + assert "vector.broadcast" in ir_text + + +# =========================================================================== +# I. 
fmath +# =========================================================================== + +class TestFmath: + + def test_exp2_tensor(self): + def build(a): + ta = Vector(a, 8, Float32) + result = fmath.exp2(ta) + assert isinstance(result, Vector) + assert result.dtype is Float32 + assert result.shape == (8,) + ir_text = _build_module(build) + assert "math.exp2" in ir_text + + def test_rsqrt_tensor(self): + def build(a): + ta = Vector(a, 8, Float32) + _ = fmath.rsqrt(ta) + ir_text = _build_module(build) + assert "math.rsqrt" in ir_text + + def test_fastmath_flag(self): + from flydsl.expr.arith import FastMathFlags + def build(a): + ta = Vector(a, 8, Float32) + _ = fmath.exp2(ta, fastmath=FastMathFlags.fast) + ir_text = _build_module(build) + assert "fast" in ir_text + + def test_scalar_float(self): + """math on scalar Float32 returns Float32 Numeric.""" + def build(raw): + x = Float32(raw) + result = fmath.sqrt(x) + assert not isinstance(result, Vector) + assert isinstance(result, Float32) + _build_module(build, [ir.F32Type.get]) + + def test_int_scalar_math(self): + """math.exp2 on raw integer ir.Value (not through Numeric) is allowed by MLIR.""" + # math.py passes through to MLIR ops which accept any float-like; + # integer scalars wrapped in Int32 Numeric are handled by _traced_math_op + pass + + def test_vector_scalar_atan2(self): + """atan2 with Vector and scalar broadcasts the scalar.""" + def build(a, raw_scalar): + ta = Vector(a, 8, Float32) + # scalar is broadcast to match vector type via _coerce_other + _ = fmath.atan2(ta, ta) + _build_module(build, [_vec_f32, ir.F32Type.get]) + + def test_new_functions_exist(self): + """Verify all newly added functions are accessible.""" + for name in ["erf", "acos", "asin", "atan", "atan2", "tan", "log10"]: + assert hasattr(fmath, name), f"fmath.{name} missing" + + def test_erf_tensor(self): + def build(a): + ta = Vector(a, 8, Float32) + _ = fmath.erf(ta) + ir_text = _build_module(build) + assert "math.erf" in ir_text + + def test_atan2_tensor(self): + def build(a, b): + ta = Vector(a, 8, Float32) + tb = Vector(b, 8, Float32) + result = fmath.atan2(ta, tb) + assert isinstance(result, Vector) + ir_text = _build_module(build, [_vec_f32, _vec_f32]) + assert "math.atan2" in ir_text + + +# =========================================================================== +# J. 
scf.for integration +# =========================================================================== + +class TestProtocol: + + def test_fly_values_roundtrip(self): + def build(a): + ta = Vector(a, 8, Float32) + values = ta.__fly_values__() + assert len(values) == 1 + reconstructed = Vector.__fly_construct__(values) + assert isinstance(reconstructed, ir.Value) + _build_module(build) + + def test_hash(self): + """Vector must be hashable since __eq__ is overridden.""" + def build(a): + ta = Vector(a, 8, Float32) + h = hash(ta) + assert isinstance(h, int) + _build_module(build) From 9a84cc9c0b04374c3695f0462ec5b4d09e1c2293 Mon Sep 17 00:00:00 2001 From: Felix Li Date: Sun, 12 Apr 2026 08:14:19 +0800 Subject: [PATCH 02/29] [FIX] Support AOT cross-compilation with COMPILE_ONLY cache save (#383) * [FIX] Support AOT cross-compilation with COMPILE_ONLY cache save * [FIX] Simplify aot_example.py precompile path --------- Co-authored-by: Claude Opus 4.6 (1M context) --- lib/Bindings/Python/DLTensorAdaptor.h | 3 --- python/flydsl/_version.py | 1 + python/flydsl/compiler/jit_function.py | 8 +++--- tests/python/examples/aot_example.py | 37 ++++++++++++++++++++++---- 4 files changed, 37 insertions(+), 12 deletions(-) create mode 100644 python/flydsl/_version.py diff --git a/lib/Bindings/Python/DLTensorAdaptor.h b/lib/Bindings/Python/DLTensorAdaptor.h index f7b8b99e..298a2e34 100644 --- a/lib/Bindings/Python/DLTensorAdaptor.h +++ b/lib/Bindings/Python/DLTensorAdaptor.h @@ -242,9 +242,6 @@ class DLTensorAdaptor { LayoutAttr layoutAttr = LayoutAttr::get(ctx, shapeAttr, strideAttr); - if (getAddressSpace() != 1) { - throw std::runtime_error("Only device address space is supported"); - } AddressSpaceAttr addrSpaceAttr = AddressSpaceAttr::get(ctx, AddressSpace::Global); assert(alignment_ > 0 && "alignment must be positive"); diff --git a/python/flydsl/_version.py b/python/flydsl/_version.py new file mode 100644 index 00000000..b1aa5330 --- /dev/null +++ b/python/flydsl/_version.py @@ -0,0 +1 @@ +__version__ = "0.1.3.1" diff --git a/python/flydsl/compiler/jit_function.py b/python/flydsl/compiler/jit_function.py index 94e2c10a..f8be93aa 100644 --- a/python/flydsl/compiler/jit_function.py +++ b/python/flydsl/compiler/jit_function.py @@ -829,10 +829,6 @@ def __call__(self, *args, **kwargs): compiled_module = MlirCompiler.compile(module, arch=backend.target.arch, func_name=self.func.__name__) - if env.compile.compile_only: - print(f"[flydsl] COMPILE_ONLY=1, compilation succeeded (arch={backend.target.arch})") - return None - compiled_func = CompiledArtifact( compiled_module, self.func.__name__, @@ -851,6 +847,10 @@ def __call__(self, *args, **kwargs): str_key = self._cache_key_to_str(cache_key) self.cache_manager.set(str_key, compiled_func) + if env.compile.compile_only: + print(f"[flydsl] COMPILE_ONLY=1, compilation succeeded (arch={backend.target.arch})") + return None + result = compiled_func(*jit_args) # Build CallState so subsequent calls skip DLPack. 
The in-process diff --git a/tests/python/examples/aot_example.py b/tests/python/examples/aot_example.py index b9462b40..f61038a9 100644 --- a/tests/python/examples/aot_example.py +++ b/tests/python/examples/aot_example.py @@ -27,7 +27,6 @@ import os import sys import time - _REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")) if _REPO_ROOT not in sys.path: sys.path.insert(0, _REPO_ROOT) @@ -35,7 +34,33 @@ from kernels.preshuffle_gemm import compile_preshuffle_gemm_a8 -def _run_kernel( +def precompile_to_cache(launch_fn, M: int, N: int, K: int, in_dtype: str): + """Trigger JIT compilation with CPU dummy tensors and COMPILE_ONLY=1.""" + import torch + + is_low_prec = in_dtype not in ("fp16", "bf16") + a_dtype = torch.int8 if is_low_prec else (torch.float16 if in_dtype == "fp16" else torch.bfloat16) + b_elems = (N * K) // 2 if in_dtype == "int4" else N * K + + prev = os.environ.get("COMPILE_ONLY") + os.environ["COMPILE_ONLY"] = "1" + try: + launch_fn( + torch.zeros(M * N, dtype=torch.float16), + torch.zeros(M * K, dtype=a_dtype), + torch.zeros(b_elems, dtype=torch.int8 if is_low_prec else a_dtype), + torch.zeros(M, dtype=torch.float32) if is_low_prec else torch.empty(0, dtype=torch.float32), + torch.zeros(N, dtype=torch.float32) if is_low_prec else torch.empty(0, dtype=torch.float32), + M, N, 0, + ) + finally: + if prev is None: + os.environ.pop("COMPILE_ONLY", None) + else: + os.environ["COMPILE_ONLY"] = prev + + +def run_and_verify( launch_fn, M: int, N: int, @@ -211,12 +236,14 @@ def compile_one_config( in_dtype=in_dtype, lds_stage=lds_stage, ) + if run_kernel: + run_and_verify(launch_fn, M, N, K, in_dtype) + else: + precompile_to_cache(launch_fn, M, N, K, in_dtype) + elapsed = time.time() - t0 result["compile_time"] = elapsed print(f" [OK] compile {elapsed:6.1f}s {shape_str}") - - if run_kernel and launch_fn is not None: - _run_kernel(launch_fn, M, N, K, in_dtype) except Exception as e: print(f" [FAIL] compile {shape_str}: {e}") From ab2c1cb545a4e60b0888e303301bf7f6e3d7c660 Mon Sep 17 00:00:00 2001 From: Felix Li Date: Sun, 12 Apr 2026 09:14:01 +0800 Subject: [PATCH 03/29] [Docs] Add MI355X/gfx1201 and MI450/gfx1250 to platform docs (#384) --- CLAUDE.md | 9 +++++---- README.md | 6 +++--- docs/architecture_guide.md | 4 +++- docs/kernel_authoring_guide.md | 4 +++- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 489de90f..15c787e5 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,6 +1,6 @@ # FlyDSL Project Guide -FlyDSL (Flexible Layout Python DSL) — a Python DSL and MLIR-based compiler stack for authoring high-performance GPU kernels with explicit layouts and tiling on AMD GPUs (MI300X/MI350/MI450). +FlyDSL (Flexible Layout Python DSL) — a Python DSL and MLIR-based compiler stack for authoring high-performance GPU kernels with explicit layouts and tiling on AMD GPUs (MI300X/MI350/MI355X/MI450). 
## Repository Layout @@ -59,9 +59,10 @@ FLYDSL_DUMP_IR=1 PYTHONPATH=./ python tests/kernels/test_pa.py # Dump MLIR IR at | Arch | Chips | Wave size | MMA | Key features | |---|---|---|---|---| -| **CDNA3** | gfx942/gfx950 (MI300X) | 64 | MFMA | BufferCopy, preshuffle GEMM | -| **RDNA** | gfx10xx/gfx11xx/gfx12xx | 32 | WMMA | RDNA-specific GEMM | -| **gfx1250** | MI400 | 32 | WMMA | TDM ops, FP8/FP4 GEMM, multi-stage pipeline | +| **CDNA3** | gfx942 (MI300X) | 64 | MFMA | BufferCopy, preshuffle GEMM | +| **CDNA4** | gfx950 (MI350/MI355X) | 64 | MFMA | MFMA_SCALE, FP4, 160KB LDS | +| **RDNA4** | gfx1201 (Radeon AI PRO R9700) | 32 | WMMA | RDNA-specific GEMM | +| **gfx1250** | MI450 | 32 | WMMA | TDM ops, FP8/FP4 GEMM, multi-stage pipeline | ## Key Conventions & Pitfalls diff --git a/README.md b/README.md index bc471b47..d040c81e 100644 --- a/README.md +++ b/README.md @@ -357,8 +357,8 @@ See `examples/` for more examples including tiled copy (`02-tiledCopy.py`), tile | **MoE GEMM** | `test_moe_gemm.py` | MoE 2-stage (gate/up + reduce) | | **MoE Blockscale** | `test_moe_blockscale.py` | MoE blockscale 2-stage | | **MoE Reduce** | `test_moe_reduce.py` | MoE reduce kernel | -| **PagedAttention** | `test_pa.py` | Paged attention decode (FP8) | -| **FlashAttention** | `test_flash_attn_func.py` | Flash attention | +| **PagedAttention** | `test_pa.py` | Paged attention decode (FP8) — *WIP perf tuning* | +| **FlashAttention** | `test_flash_attn_func.py` | Flash attention — *WIP perf tuning* | | **LayerNorm** | `test_layernorm.py` | LayerNorm (layout API) | | **RMSNorm** | `test_rmsnorm.py` | RMSNorm (layout API) | | **Softmax** | `test_softmax.py` | Softmax (layout API) | @@ -371,7 +371,7 @@ See `examples/` for more examples including tiled copy (`02-tiledCopy.py`), tile | **Quantization** | `test_quant.py` | Quantization utilities | **Verified Platforms**: -* AMD MI300X/MI308X (gfx942), AMD MI350 (gfx950), AMD MI450 (gfx1250) +* AMD MI300X/MI308X (gfx942), AMD MI350/MI355X (gfx950), AMD MI450 (gfx1250), Radeon AI PRO R9700 (gfx1201) * Linux / ROCm 6.x, 7.x ## 🙏 Acknowledgements diff --git a/docs/architecture_guide.md b/docs/architecture_guide.md index 696c9c0a..6b0eb9e2 100644 --- a/docs/architecture_guide.md +++ b/docs/architecture_guide.md @@ -358,7 +358,9 @@ Transforms Python control flow to MLIR ops at the AST level: | Architecture | GPU | LDS per CU | Notes | |---|---|---|---| | `gfx942` | MI300A / MI300X | 64 KB | CDNA 3, primary development target | -| `gfx950` | MI350 | 160 KB | CDNA 4, larger LDS | +| `gfx950` | MI350 / MI355X | 160 KB | CDNA 4, larger LDS | +| `gfx1201` | Radeon AI PRO R9700 | 64 KB | RDNA 4 | +| `gfx1250` | MI450 | 320 KB | GFX12, wave32, WMMA, TDM ops | | `gfx90a` | MI250X | 64 KB | CDNA 2 (verified platform) | --- diff --git a/docs/kernel_authoring_guide.md b/docs/kernel_authoring_guide.md index 046e9d7b..031a956e 100644 --- a/docs/kernel_authoring_guide.md +++ b/docs/kernel_authoring_guide.md @@ -415,7 +415,9 @@ with ir.InsertionPoint(comp_ctx.gpu_module_body): | Architecture | LDS per CU | |---|---| | `gfx942` (MI300X) | 64 KB | -| `gfx950` (MI350) | 160 KB | +| `gfx950` (MI350/MI355X) | 160 KB | +| `gfx1201` (Radeon AI PRO R9700) | 64 KB | +| `gfx1250` (MI450) | 320 KB | --- From b1688aa538a94ee7093ba2daf24fd1dddab776ec Mon Sep 17 00:00:00 2001 From: Felix Li Date: Sun, 12 Apr 2026 09:38:51 +0800 Subject: [PATCH 04/29] update to v0.1.3 (#385) --- python/flydsl/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/flydsl/__init__.py 
b/python/flydsl/__init__.py index b215ad92..99582a11 100644 --- a/python/flydsl/__init__.py +++ b/python/flydsl/__init__.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2025 FlyDSL Project Contributors -_BASE_VERSION = "0.1.2" +_BASE_VERSION = "0.1.3.1" # FFM simulator compatibility shim (no-op outside simulator sessions). from ._compat import _maybe_preload_system_comgr # noqa: E402 From bf6a8d074fba3a6887b69034b533dc5890b460e5 Mon Sep 17 00:00:00 2001 From: yadaish Date: Mon, 13 Apr 2026 13:09:33 +0800 Subject: [PATCH 05/29] Pr/a16wi4 group (#370) --------- Co-authored-by: root Co-authored-by: Claude Opus 4 Co-authored-by: Felix Li --- kernels/hgemm_splitk.py | 8 +- kernels/mfma_preshuffle_pipeline.py | 262 +++++++- kernels/moe_gemm_2stage.py | 810 ++++++++++++++++++------- python/flydsl/expr/rocdl/__init__.py | 24 +- python/flydsl/expr/rocdl/inline_asm.py | 65 ++ scripts/run_benchmark.sh | 56 ++ tests/kernels/test_moe_gemm.py | 170 +++++- tests/kernels/test_ref.py | 4 +- tests/utils.py | 23 +- 9 files changed, 1136 insertions(+), 286 deletions(-) create mode 100644 python/flydsl/expr/rocdl/inline_asm.py diff --git a/kernels/hgemm_splitk.py b/kernels/hgemm_splitk.py index be8c80d0..723b4ab5 100644 --- a/kernels/hgemm_splitk.py +++ b/kernels/hgemm_splitk.py @@ -78,12 +78,12 @@ def __init__(self, dtype: str): self.dtype = dtype def __call__(self, a_frag, b_frag, c_frag): + res_ty = T.vec(self.WMMA_C_FRAG_VALUES, T.f32) + operands = [a_frag, b_frag, c_frag, 0, 0, 0] if self.dtype == 'bf16': - c_frag_new = rocdl.mfma_f32_16x16x32_bf16(T.vec(self.WMMA_C_FRAG_VALUES, T.f32), a_frag, b_frag, c_frag, 0, 0, 0).res - return c_frag_new + return rocdl.mfma_f32_16x16x32_bf16(res_ty, operands) else: - c_frag_new = rocdl.mfma_f32_16x16x32_f16(T.vec(self.WMMA_C_FRAG_VALUES, T.f32), a_frag, b_frag, c_frag, 0, 0, 0).res - return c_frag_new + return rocdl.mfma_f32_16x16x32_f16(res_ty, operands) class OnlineScheduler: diff --git a/kernels/mfma_preshuffle_pipeline.py b/kernels/mfma_preshuffle_pipeline.py index 1de69d06..41e67047 100644 --- a/kernels/mfma_preshuffle_pipeline.py +++ b/kernels/mfma_preshuffle_pipeline.py @@ -187,6 +187,30 @@ def make_preshuffle_scale_layout( ) +def _unpack_int4_to_int8_pair(packed32): + """Split packed int4 dword into two int8 dwords (even/odd nibbles). + + 7-op bit manipulation shared by all int4 unpack paths (W4A8, W4A16, W4A_FP8). + """ + c_08 = fx.Int32(0x08080808) + c_0f = fx.Int32(0x0F0F0F0F) + c_1e = fx.Int32(0x1E) + c_4 = fx.Int32(4) + s0 = (packed32 & c_08) * c_1e + even = (packed32 & c_0f) | s0 + t = packed32 >> c_4 + s1 = (t & c_08) * c_1e + odd = (t & c_0f) | s1 + return even, odd + + +def _pack_i32_pair_to_i64(lo, hi, vector): + """Pack two i32 values into one i64 via vector bitcast.""" + v2 = vector.from_elements(T.vec(2, T.i32), [lo, hi]) + v64 = vector.bitcast(T.vec(1, T.i64), v2) + return vector.extract(v64, static_position=[0], dynamic_position=[]) + + def _i8x4_in_i32_to_bf16x4_i64(val_i32, arith, vector, scale_val=None): """Convert one i32 (4 signed int8 bytes) to 4 bf16 packed as i64. 
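The sign-extension step in `_unpack_int4_to_int8_pair` works because `(x & 0x08) * 0x1E` equals `0xF0` inside each byte lane (8 * 30 = 240, which never carries into the neighbouring byte), so OR-ing it onto the masked low nibble yields the two's-complement int8 value. Below is a minimal plain-Python sketch of the same arithmetic, useful for checking the trick outside MLIR; the helper names are illustrative only, not part of the kernel API.

def unpack_int4_dword(packed32: int) -> tuple[int, int]:
    """Split a dword of 8 packed int4 values into two dwords of 4 int8 each."""
    mask32 = 0xFFFFFFFF
    s0 = ((packed32 & 0x08080808) * 0x1E) & mask32   # 0xF0 in every byte whose int4 is negative
    even = (packed32 & 0x0F0F0F0F) | s0              # even nibbles, sign-extended to int8
    t = (packed32 >> 4) & mask32
    s1 = ((t & 0x08080808) * 0x1E) & mask32
    odd = (t & 0x0F0F0F0F) | s1                      # odd nibbles, sign-extended to int8
    return even, odd

def ref_sign_extend(nibble: int) -> int:
    """Reference: interpret a 4-bit value as two's-complement int4."""
    return nibble - 16 if nibble & 0x8 else nibble

# Check every nibble of a sample dword against the reference.
packed = 0x8F172A3C  # nibbles low->high: C,3,A,2,7,1,F,8  ->  -4,3,-6,2,7,1,-1,-8
even, odd = unpack_int4_dword(packed)
for byte in range(4):
    lo = (packed >> (8 * byte)) & 0xF
    hi = (packed >> (8 * byte + 4)) & 0xF
    assert (even >> (8 * byte)) & 0xFF == ref_sign_extend(lo) & 0xFF
    assert (odd >> (8 * byte)) & 0xFF == ref_sign_extend(hi) & 0xFF

In the kernel the `even`/`odd` dwords become the `b0`/`b1` operands of the two K32 MFMA micro-steps (see `unpack_b_w4a16` and `load_b_pack_k32`).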
@@ -254,6 +278,7 @@ def load_b_raw_w4a16( c4_idx = fx.Index(4) k0_base = base_k // c64 + k1_layout_offset = ku * 2 lane_div_32 = lane_div_16 // c2_idx total_k1 = fx.Index(k1_layout_offset) + lane_div_32 @@ -278,24 +303,80 @@ def load_b_raw_w4a16( return packed32 -def unpack_b_w4a16(packed32, arith, vector, scale_val=None): +def _int4_to_bf16x4_i64_gfx950(packed32, nibble_offsets, arith, vector, scale_val=None, defer_scale16=False): + """Convert 4 int4 nibbles to 4 bf16 packed as i64 using gfx950 instructions. + + Uses v_cvt_off_f32_i4_sdwa with byte_sel to avoid per-nibble shifts. + Even nibbles (0,2,4,6) → SDWA BYTE_0/1/2/3 on original src. + Odd nibbles (1,3,5,7) → SDWA BYTE_0/1/2/3 on (src >> 4). + Only 1 shift total instead of 7. + + When defer_scale16=True, the ×16 correction factor for v_cvt_off_f32_i4 is + omitted and must be applied later (e.g. in the epilogue). This saves VALU + in the hot loop and uses v_cvt_pk_bf16_f32 for proper f32→bf16 conversion. + """ + from flydsl.expr import rocdl + from flydsl._mlir.dialects._arith_ops_gen import MulFOp as _MulFOp + + _uw = _arith._to_raw + _av = _arith.ArithValue + + src_even = packed32 + src_odd = packed32 >> fx.Int32(4) + + f32_vals = [] + for nib in nibble_offsets: + byte_idx = nib // 2 + src = src_odd if (nib % 2) else src_even + v = rocdl.cvt_off_f32_i4(src, byte_sel=byte_idx) + f32_vals.append(v) + + if defer_scale16: + # Skip ×16; multiply by scale_val only if groupwise. + if scale_val is not None: + raw_scale = _uw(scale_val) + f32_vals = [_MulFOp(v, raw_scale).result for v in f32_vals] + # Use v_cvt_pk_bf16_f32 for proper f32→bf16 (no bit-shift trick needed). + i32_lo = rocdl.cvt_pk_bf16_f32(f32_vals[0], f32_vals[1]) + i32_hi = rocdl.cvt_pk_bf16_f32(f32_vals[2], f32_vals[3]) + else: + c16 = fx.Float32(16.0) + if scale_val is not None: + effective_scale = scale_val * c16 + else: + effective_scale = c16 + raw_scale = _uw(effective_scale) + f32_vals = [_MulFOp(v, raw_scale).result for v in f32_vals] + # Truncate f32→bf16 via bit-shift (exact for scaled int values). + c16_shift = fx.Int32(16) + c_ffff0000 = fx.Int32(0xFFFF0000) + bf16_vals = [arith.bitcast(T.i32, _av(v)) for v in f32_vals] + i32_lo = (bf16_vals[0] >> c16_shift) | (bf16_vals[1] & c_ffff0000) + i32_hi = (bf16_vals[2] >> c16_shift) | (bf16_vals[3] & c_ffff0000) + + v2 = vector.from_elements(T.vec(2, T.i32), [i32_lo, i32_hi]) + v64 = vector.bitcast(T.vec(1, T.i64), v2) + return vector.extract(v64, static_position=[0], dynamic_position=[]) + + +def unpack_b_w4a16(packed32, arith, vector, scale_val=None, use_gfx950_cvt=False, defer_scale16=False): """Phase 2 of W4A16 B load: unpack int4->int8 + convert int8->bf16. Takes raw packed32 from load_b_raw_w4a16 and produces (b0, b1) -- two i64 values each containing 4 bf16 for one MFMA. - """ - c_08080808 = fx.Int32(0x08080808) - c_0f0f0f0f = fx.Int32(0x0F0F0F0F) - c_1e = fx.Int32(0x1E) - c_4_i32 = fx.Int32(4) - - s0 = (packed32 & c_08080808) * c_1e - even = (packed32 & c_0f0f0f0f) | s0 - t = packed32 >> c_4_i32 - s1 = (t & c_08080808) * c_1e - odd = (t & c_0f0f0f0f) | s1 + When use_gfx950_cvt=True, uses v_cvt_off_f32_i4 + v_cvt_pk_bf16_f32 + for ~2x fewer VALU instructions. + When defer_scale16=True (requires use_gfx950_cvt=True), the ×16 + correction for v_cvt_off_f32_i4 is omitted; caller must apply it + in the epilogue. 
+ """ + if use_gfx950_cvt: + b0 = _int4_to_bf16x4_i64_gfx950(packed32, [0, 2, 4, 6], arith, vector, scale_val, defer_scale16=defer_scale16) + b1 = _int4_to_bf16x4_i64_gfx950(packed32, [1, 3, 5, 7], arith, vector, scale_val, defer_scale16=defer_scale16) + return (b0, b1) + even, odd = _unpack_int4_to_int8_pair(packed32) b0 = _i8x4_in_i32_to_bf16x4_i64(even, arith, vector, scale_val=scale_val) b1 = _i8x4_in_i32_to_bf16x4_i64(odd, arith, vector, scale_val=scale_val) return (b0, b1) @@ -352,22 +433,8 @@ def load_b_pack_k32( static_position=[0], dynamic_position=[], ) - - c_08080808 = fx.Int32(0x08080808) - c_0f0f0f0f = fx.Int32(0x0F0F0F0F) - c_1e = fx.Int32(0x1E) - c_4_i32 = fx.Int32(4) - - s0 = (packed32 & c_08080808) * c_1e - even = (packed32 & c_0f0f0f0f) | s0 - - t = packed32 >> c_4_i32 - s1 = (t & c_08080808) * c_1e - odd = (t & c_0f0f0f0f) | s1 - - v2 = vector.from_elements(T.vec(2, T.i32), [even, odd]) - v64 = vector.bitcast(T.vec(1, T.i64), v2) - return vector.extract(v64, static_position=[0], dynamic_position=[]) + even, odd = _unpack_int4_to_int8_pair(packed32) + return _pack_i32_pair_to_i64(even, odd, vector) vec_elems = kpack_bytes // int(elem_bytes) b16 = _buffer_load_vec( @@ -548,15 +615,152 @@ def lds_load_pack_k32( __all__ = [ "PreshuffleBLayout", + "PreshuffleScaleLayout", "buffer_copy_gmem16_dwordx4", - "lds_row_major_idx", "lds_load_pack_k32", + "lds_row_major_idx", "lds_store_4b_xor16", "lds_store_8b_xor16", "lds_store_16b_xor16", "make_preshuffle_b_layout", + "make_preshuffle_scale_layout", "load_b_pack_k32", + "load_b_raw_w4a16", + "unpack_b_w4a16", + "load_b_raw_w4a16_groupwise", + "unpack_b_w4a16_groupwise", + "extract_bf16_scale", "split_row_major_2d", "swizzle_xor16", "tile_chunk_coord_i32", ] + + +# --------------------------------------------------------------------------- +# Groupwise scale load helper (shared by W4A16 and W4A8 groupwise paths) +# --------------------------------------------------------------------------- + +def _load_groupwise_scale( + buffer_ops, + arith, + *, + scale_rsrc, + expert_offset, + n_blk, + n_intra, + k_pos, + num_groups: int, + group_size: int, + n_per_expert: int, + scale_dtype=None, +): + """Load one per-group scale value from the scale buffer. + + Computes the linear index into the scale tensor from expert offset, + N position, and group index derived from ``k_pos``. + + For bf16 scales the tensor uses ``(E, G//2, N, 2)`` layout — two + adjacent groups for the same N position are packed into one dword. + We load the raw i32 dword (no extraction) so it can be carried as + loop state without register copies. Use :func:`extract_bf16_scale` + in the compute phase to obtain the f32 value. + """ + c16 = fx.Index(16) + n_global = n_blk * c16 + n_intra + c_group_size = fx.Index(group_size) + c_npe = fx.Index(n_per_expert) + group_idx = k_pos // c_group_size + if scale_dtype is None: + scale_dtype = T.f32 + + if scale_dtype == T.bf16: + # (E, G//2, N, 2) layout: dword at [e, pair, n] holds bf16 scales + # for groups 2*pair and 2*pair+1. + pair_idx = group_idx >> fx.Index(1) # group_idx // 2 + # Dword index: same flat formula but with G//2 groups + num_pairs = num_groups // 2 + c_npm1 = fx.Index(num_pairs - 1) + dword_base = expert_offset * c_npm1 + n_global + dword_elem = dword_base + pair_idx * c_npe + dword_idx = arith.index_cast(T.i32, dword_elem) + # Return raw i32 dword — extraction deferred to compute phase. 
+ scale_val = buffer_ops.buffer_load(scale_rsrc, dword_idx, vec_width=1, dtype=T.i32) + else: + # (E, G, N) layout with f32 dtype + c_gm1 = fx.Index(num_groups - 1) + base_scale = expert_offset * c_gm1 + n_global + elem_idx = base_scale + group_idx * c_npe + scale_idx_i32 = arith.index_cast(T.i32, elem_idx) + scale_val = buffer_ops.buffer_load(scale_rsrc, scale_idx_i32, vec_width=1, dtype=T.f32) + return scale_val + + +def extract_bf16_scale(arith, scale_raw_i32, ku: int): + """Extract f32 scale from raw i32 dword loaded by bf16 groupwise path. + + In the ``(E, G//2, N, 2)`` layout two adjacent groups share one dword. + ``ku`` determines which half: even ku → low bf16, odd ku → high bf16. + """ + if ku % 2 == 0: + # Low bf16: shift left by 16 to place in upper 16 bits → f32 + return arith.bitcast(T.f32, scale_raw_i32 << fx.Int32(16)) + else: + # High bf16: mask upper 16 bits → f32 + return arith.bitcast(T.f32, scale_raw_i32 & fx.Int32(0xFFFF0000)) + + +# --------------------------------------------------------------------------- +# W4A16 groupwise load / unpack helpers +# --------------------------------------------------------------------------- + +def load_b_raw_w4a16_groupwise( + buffer_ops, + arith, + vector, + *, + arg_b, + b_rsrc, + layout_b, + base_k, + ku: int, + n_blk, + n_intra, + lane_div_16, + elem_type, + scale_rsrc, + expert_offset, + num_groups: int, + group_size: int, + n_per_expert: int, + kpack_bytes: int = 8, + scale_dtype=None, +): + """Phase 1 of W4A16 groupwise B load: buffer_loads for weight + scale. + + Reuses :func:`load_b_raw_w4a16` for the weight load, then issues an + additional ``buffer_load_dword`` for the per-group scale. + + Returns ``(packed32, scale_val)``. + """ + packed32 = load_b_raw_w4a16( + buffer_ops, arith, vector, + arg_b=arg_b, b_rsrc=b_rsrc, layout_b=layout_b, + base_k=base_k, ku=ku, + n_blk=n_blk, n_intra=n_intra, + lane_div_16=lane_div_16, elem_type=elem_type, + kpack_bytes=kpack_bytes, + ) + k_pos = base_k + fx.Index(ku * 32) + scale_val = _load_groupwise_scale( + buffer_ops, arith, + scale_rsrc=scale_rsrc, expert_offset=expert_offset, + n_blk=n_blk, n_intra=n_intra, k_pos=k_pos, + num_groups=num_groups, group_size=group_size, n_per_expert=n_per_expert, + scale_dtype=scale_dtype, + ) + return (packed32, scale_val) + + +def unpack_b_w4a16_groupwise(packed32, scale_val, arith, vector, use_gfx950_cvt=False): + """Phase 2 of W4A16 groupwise: unpack + scale + convert to bf16.""" + return unpack_b_w4a16(packed32, arith, vector, scale_val=scale_val, use_gfx950_cvt=use_gfx950_cvt) diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index b9333570..bce43ece 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -55,6 +55,9 @@ def bf16_global_atomics_arch_description() -> str: load_b_pack_k32, load_b_raw_w4a16, unpack_b_w4a16, + load_b_raw_w4a16_groupwise, + unpack_b_w4a16_groupwise, + extract_bf16_scale, tile_chunk_coord_i32, swizzle_xor16, ) @@ -105,6 +108,7 @@ def compile_moe_gemm1( group_size: int = -1, out_dtype: str = "f16", use_cshuffle_epilog: bool | None = None, + scale_is_bf16: bool = False, ): """Compile stage1 kernel (`moe_gemm1`) and return the compiled executable. @@ -118,20 +122,23 @@ def compile_moe_gemm1( have a distinct input scaling before quantization). 
- "int4": W4A8 path: X is int8, W is packed int4 (2 values per byte) unpacked to int8 in-kernel - "int4_bf16": W4A16 path: X is bf16, W is packed int4 unpacked to bf16 in-kernel + scale_is_bf16: When True, groupwise scales are bf16 (halves scale bandwidth). """ gpu_arch = get_hip_arch() allocator = SmemAllocator(None, arch=gpu_arch) _state = {} # legacy; kept until stage2/reduction are migrated - if in_dtype not in ("fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16"): + _valid_dtypes = ("fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16") + if in_dtype not in _valid_dtypes: raise ValueError( - f"in_dtype must be one of ('fp8','fp16','bf16','int8','int8smooth','int4','int4_bf16'), got {in_dtype!r}" + f"in_dtype must be one of {_valid_dtypes}, got {in_dtype!r}" ) - is_int4_bf16 = in_dtype == "int4_bf16" + is_int4_bf16 = in_dtype == "int4_bf16" # W4A16: bf16 activations, packed int4 weights is_f16 = in_dtype == "fp16" is_bf16 = is_int4_bf16 or in_dtype == "bf16" is_f16_or_bf16 = is_f16 or is_bf16 + needs_scale_x = not is_f16_or_bf16 needs_scale_w = (not is_f16_or_bf16) or is_int4_bf16 elem_bytes = 2 if is_f16_or_bf16 else 1 if out_dtype not in ("f16", "bf16"): @@ -156,6 +163,28 @@ def compile_moe_gemm1( # "int8smooth" still uses int8 MFMA, but X/scale_x are provided per (token,slot). is_int8 = is_int8 or x_is_token_slot + # w_is_int4: True for any variant where weights are packed int4. + w_is_int4 = is_int4 or is_int4_bf16 + + # Group-wise scale support for W4A16 + # NOTE: Only group_size=32 is supported due to int4 preshuffle layout constraints. + use_groupwise_scale = w_is_int4 and group_size > 0 + if use_groupwise_scale and group_size != 32: + raise ValueError( + f"FlyDSL groupwise scale only supports group_size=32, got {group_size}. " + f"This is due to int4 preshuffle layout constraints. " + f"Please use Triton kernel for other group sizes." + ) + is_int4_bf16_groupwise = is_int4_bf16 and use_groupwise_scale + num_groups = model_dim // group_size if use_groupwise_scale else 1 + _scale_is_bf16 = scale_is_bf16 and use_groupwise_scale + scale_w_size_stage1 = experts * (2 * inter_dim) * num_groups + # For groupwise scale, weight scale is applied per-group in the K loop, + # so epilogue can skip weight scale multiplication (uses 1.0 for sw). + + _is_gfx950 = "gfx95" in get_hip_arch() + use_gfx950_cvt = is_int4_bf16 and _is_gfx950 + mfma_i32_k32 = None if is_int8: mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr( @@ -178,11 +207,14 @@ def compile_moe_gemm1( "(or `rocdl.mfma_f32_16x16x16_bf16_1k`)." ) + # gfx950: use 16x16x32 MFMA for f16/bf16 (K=32 per MFMA, vs K=16 on gfx942). + _use_mfma_k32 = _is_gfx950 and (is_f16 or is_bf16) + DYN = ir.ShapedType.get_dynamic_size() size_out = DYN size_x = DYN - # W is packed int4 for W4A8: 2 values per byte. - size_w = (experts * (2 * inter_dim) * model_dim) // 2 if (is_int4 or is_int4_bf16) else (experts * (2 * inter_dim) * model_dim) + # W is packed int4 for W4A8/W4A16/W4A_FP8: 2 values per byte. + size_w = (experts * (2 * inter_dim) * model_dim) // 2 if w_is_int4 else (experts * (2 * inter_dim) * model_dim) size_sorted = DYN size_expert_ids = DYN @@ -213,10 +245,13 @@ def compile_moe_gemm1( epilog_tag = "cshuffle" if use_cshuffle_epilog else "direct" # IMPORTANT: module name participates in FlyDSL's compile cache key. # Keep an explicit ABI tag so signature changes can't accidentally reuse an old binary. 
+ _gs_tag = f"_g{group_size}" if use_groupwise_scale else "" + scale_tag = "_sbf16" if _scale_is_bf16 else "" module_name = ( f"mfma_moe1_{in_dtype}_{out_dtype}_{epilog_tag}" f"_t{tile_m}x{tile_n}x{tile_k}" - f"_abi3" # also mask sentinel token ids on loads (X/scale_x) to avoid illegal address faults + f"{_gs_tag}{scale_tag}" + f"_abi3" # also mask sentinel token ids on loads (X/scale_x) to avoid illegal address faults ).replace("-", "_") # ── LDS sizing (pure Python; no MLIR Context needed) ───────────────────── @@ -261,7 +296,8 @@ def moe_gemm1( k_i32_v = i32_k_in x_elem = T.bf16 if is_bf16 else (T.f16 if is_f16 else (T.i8 if is_int8 else T.f8)) # For int4/int4_bf16, weights are stored as packed bytes (i8) and unpacked in-kernel. - w_elem = T.i8 if (is_int4 or is_int4_bf16) else (T.bf16 if is_bf16 else (T.f16 if is_f16 else (T.i8 if is_int8 else T.f8))) + w_elem = T.i8 if w_is_int4 else (T.bf16 if is_bf16 else (T.f16 if is_f16 else (T.i8 if is_int8 else T.f8))) + scale_dtype = T.bf16 if _scale_is_bf16 else T.f32 vec16_elems = 16 if elem_bytes == 1 else 8 vec8_elems = 8 if elem_bytes == 1 else 4 vec4_elems = 4 if elem_bytes == 1 else 2 @@ -287,15 +323,16 @@ def silu(x): if is_int8 else arith.constant_vector(0.0, T.f32x4) ) + zero_f32_acc = arith.constant_vector(0.0, T.f32x4) if is_int4_bf16_groupwise else None # Layouts (use i32 values; fly.make_shape requires i32/i64, not index) layout_x = fx.make_layout((tokens_i32_v, k_i32_v), stride=(k_i32_v, 1)) # B preshuffle layout: match GEMM test helper exactly. c_n_total = arith.index(experts * (2 * inter_dim)) - # For packed int4 (W4A8/W4A16), kpack_bytes=8. - kpack_bytes = 8 if (is_int4 or is_int4_bf16) else 16 - w_elem_bytes = 1 if (is_int4 or is_int4_bf16) else elem_bytes + # For packed int4 (W4A8/W4A16/W4A_FP8), kpack_bytes=8. + kpack_bytes = 8 if w_is_int4 else 16 + w_elem_bytes = 1 if w_is_int4 else elem_bytes b_layout = make_preshuffle_b_layout( arith, c_n=c_n_total, c_k=k_in, kpack_bytes=kpack_bytes, elem_bytes=w_elem_bytes ) @@ -600,13 +637,37 @@ def load_b_pack(base_k, ki_step, ni, blk_list, intra_list): def load_b_tile(base_k, blk_list, intra_list): """Prefetch the entire per-thread B tile (gmem -> regs) for a given K base. - + Returns a list of length `k_unroll`, where each entry is a tuple: (packs_half0[ni], packs_half1[ni]) for the K64 micro-step. + For groupwise variants, each entry also includes per-group scales: + (packs0[ni], packs1[ni], scales0[ni], scales1[ni]) """ - if is_int4_bf16: - # W4A16: 2-phase load+unpack for VMEM latency hiding - # Phase 1: Issue ALL buffer_loads first. + if is_int4_bf16_groupwise: + # W4A16 groupwise: load raw packed32 + scale; defer dequant to compute_tile. + raw_data = [] + for ku in range_constexpr(k_unroll): + raw_ku = [] + for ni in range_constexpr(num_acc_n): + packed32, scale_val = load_b_raw_w4a16_groupwise( + buffer_ops, arith, vector, + arg_b=arg_w, b_rsrc=w_rsrc, layout_b=layout_b, + base_k=base_k, ku=ku, + n_blk=blk_list[ni], n_intra=intra_list[ni], + lane_div_16=lane_div_16, elem_type=w_elem, + scale_rsrc=sw_rsrc, + expert_offset=expert_off_idx, + num_groups=num_groups, + group_size=group_size, + n_per_expert=2*inter_dim, + kpack_bytes=kpack_bytes, + scale_dtype=scale_dtype, + ) + raw_ku.append((packed32, scale_val)) + raw_data.append(raw_ku) + return raw_data + elif is_int4_bf16: + # W4A16 per-row: load raw packed32; defer dequant to compute_tile. 
raw_data = [] for ku in range_constexpr(k_unroll): raw_ku = [] @@ -621,30 +682,22 @@ def load_b_tile(base_k, blk_list, intra_list): ) raw_ku.append(raw) raw_data.append(raw_ku) - # Phase 2: Unpack ALL (by now early loads have completed). + return raw_data + else: + # fp8/int8/bf16/fp16: original code path b_tile = [] for ku in range_constexpr(k_unroll): packs0 = [] packs1 = [] for ni in range_constexpr(num_acc_n): - b0, b1 = unpack_b_w4a16(raw_data[ku][ni], arith, vector) + ki0 = (ku * 2) + 0 + ki1 = (ku * 2) + 1 + b0 = load_b_pack(base_k, ki0, ni, blk_list, intra_list) + b1 = load_b_pack(base_k, ki1, ni, blk_list, intra_list) packs0.append(b0) packs1.append(b1) b_tile.append((packs0, packs1)) return b_tile - b_tile = [] - for ku in range_constexpr(k_unroll): - packs0 = [] - packs1 = [] - for ni in range_constexpr(num_acc_n): - ki0 = (ku * 2) + 0 - ki1 = (ku * 2) + 1 - b0 = load_b_pack(base_k, ki0, ni, blk_list, intra_list) - b1 = load_b_pack(base_k, ki1, ni, blk_list, intra_list) - packs0.append(b0) - packs1.append(b1) - b_tile.append((packs0, packs1)) - return b_tile acc_gate = [acc_init] * (num_acc_n * m_repeat) acc_up = [acc_init] * (num_acc_n * m_repeat) @@ -727,20 +780,23 @@ def compute_tile( gate_list = list(acc_gate_in) up_list = list(acc_up_in) mfma_res_ty = T.i32x4 if is_int8 else T.f32x4 - mfma_fn = ( - mfma_i32_k32 - if is_int8 - else ( - mfma_f32_bf16_k16 - if is_bf16 - else (rocdl.mfma_f32_16x16x16f16 if is_f16 else rocdl.mfma_f32_16x16x32_fp8_fp8) + if _use_mfma_k32: + mfma_fn = rocdl.mfma_f32_16x16x32_f16 if is_f16 else rocdl.mfma_f32_16x16x32_bf16 + else: + mfma_fn = ( + mfma_i32_k32 + if is_int8 + else ( + mfma_f32_bf16_k16 + if is_bf16 + else (rocdl.mfma_f32_16x16x16f16 if is_f16 else rocdl.mfma_f32_16x16x32_fp8_fp8) + ) ) - ) - + # Optional: prefetch epilogue scales while we are about to run the last MFMA tile, # matching the preshuffle GEMM pattern of overlapping scale loads with MFMA. 
epilogue_pf = None - if prefetch_epilogue: + if prefetch_epilogue and not use_groupwise_scale: expert_off_pf = expert_off_idx sw_gate_pf = [] sw_up_pf = [] @@ -768,7 +824,24 @@ def _i64_to_v4i16(x_i64): v1 = vector.from_elements(T.vec(1, T.i64), [x_i64]) return vector.bitcast(T.i16x4, v1) + def _i64x2_to_v8f16(lo, hi): + v2 = vector.from_elements(T.i64x2, [lo, hi]) + return vector.bitcast(T.f16x8, v2) + + def _i64x2_to_v8bf16(lo, hi): + v2 = vector.from_elements(T.i64x2, [lo, hi]) + return vector.bitcast(T.bf16x8, v2) + def mfma_k64(acc_in, a0, a1, b0, b1): + if _use_mfma_k32: + # gfx950: single 16x16x32 MFMA consuming all 128 bits (K=32 f16/bf16) + if is_f16: + av = _i64x2_to_v8f16(a0, a1) + bv = _i64x2_to_v8f16(b0, b1) + else: + av = _i64x2_to_v8bf16(a0, a1) + bv = _i64x2_to_v8bf16(b0, b1) + return mfma_fn(mfma_res_ty, [av, bv, acc_in, 0, 0, 0]) if is_f16: a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) @@ -785,38 +858,99 @@ def mfma_k64(acc_in, a0, a1, b0, b1): return mfma_fn(mfma_res_ty, [a1v, b1v, acc_mid, 0, 0, 0]) acc_mid = mfma_fn(mfma_res_ty, [a0, b0, acc_in, 0, 0, 0]) return mfma_fn(mfma_res_ty, [a1, b1, acc_mid, 0, 0, 0]) - - for ku in range_constexpr(k_unroll): - b_gate_packs0, b_gate_packs1 = b_gate_tile_in[ku] - b_up_packs0, b_up_packs1 = b_up_tile_in[ku] - ki64 = arith.index(ku * 64) - col_base = col_offset_base_bytes + ki64 - - for mi in range_constexpr(m_repeat): - mi_val = arith.index(mi * 16) - curr_row_a_lds = row_a_lds + mi_val - - if (a0_prefetch is not None) and (ku == 0) and (mi == 0): - a0, a1 = a0_prefetch - else: - a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) - - for ni in range_constexpr(num_acc_n): - acc_idx = mi * num_acc_n + ni - gate_list[acc_idx] = mfma_k64( - gate_list[acc_idx], - a0, - a1, - b_gate_packs0[ni], - b_gate_packs1[ni], - ) - up_list[acc_idx] = mfma_k64( - up_list[acc_idx], - a0, - a1, - b_up_packs0[ni], - b_up_packs1[ni], - ) + + def _acc_scaled_f32(f32_acc_vec, f32_partial_vec, scale_val): + """MFMA f32 partial -> scale -> add to f32 accumulator via math.fma on vector.""" + from flydsl._mlir.dialects._math_ops_gen import fma as _math_fma + _uw = arith._to_raw + scale_vec = _uw(vector.broadcast(T.f32x4, scale_val)) + return arith.ArithValue(_math_fma(scale_vec, _uw(f32_partial_vec), _uw(f32_acc_vec))) + + if is_int4_bf16 or is_int4_bf16_groupwise: + # W4A16: deferred dequant — unpack int4->bf16 right before MFMA + # to minimize VGPR lifetime of dequantized bf16 values. + _pending_gate_up = None + for ku in range_constexpr(k_unroll): + b_gate_raw = b_gate_tile_in[ku] + b_up_raw = b_up_tile_in[ku] + ki64 = arith.index(ku * 64) + col_base = col_offset_base_bytes + ki64 + + for mi in range_constexpr(m_repeat): + mi_val = arith.index(mi * 16) + curr_row_a_lds = row_a_lds + mi_val + + if (a0_prefetch is not None) and (ku == 0) and (mi == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + + for ni in range_constexpr(num_acc_n): + acc_idx = mi * num_acc_n + ni + if is_int4_bf16_groupwise: + packed_g, sc_g = b_gate_raw[ni] + packed_u, sc_u = b_up_raw[ni] + if _scale_is_bf16: + sc_g = extract_bf16_scale(arith, sc_g, ku) + sc_u = extract_bf16_scale(arith, sc_u, ku) + else: + packed_g, sc_g = b_gate_raw[ni], None + packed_u, sc_u = b_up_raw[ni], None + if is_int4_bf16_groupwise and use_gfx950_cvt: + # Defer group scale to post-MFMA FMA with pipeline: + # Issue current MFMA, then apply FMA for previous iteration's result. 
+ bg0, bg1 = unpack_b_w4a16(packed_g, arith, vector, scale_val=None, use_gfx950_cvt=True, defer_scale16=True) + tmp_g = mfma_k64(zero_f32_acc, a0, a1, bg0, bg1) + bu0, bu1 = unpack_b_w4a16(packed_u, arith, vector, scale_val=None, use_gfx950_cvt=True, defer_scale16=True) + tmp_u = mfma_k64(zero_f32_acc, a0, a1, bu0, bu1) + # Apply FMA for previous pending result (MFMA already completed). + if _pending_gate_up is not None: + p_idx, p_g, p_u, p_sc_g, p_sc_u = _pending_gate_up + gate_list[p_idx] = _acc_scaled_f32(gate_list[p_idx], p_g, p_sc_g) + up_list[p_idx] = _acc_scaled_f32(up_list[p_idx], p_u, p_sc_u) + _pending_gate_up = (acc_idx, tmp_g, tmp_u, sc_g, sc_u) + else: + bg0, bg1 = unpack_b_w4a16(packed_g, arith, vector, scale_val=sc_g, use_gfx950_cvt=use_gfx950_cvt, defer_scale16=use_gfx950_cvt) + gate_list[acc_idx] = mfma_k64(gate_list[acc_idx], a0, a1, bg0, bg1) + bu0, bu1 = unpack_b_w4a16(packed_u, arith, vector, scale_val=sc_u, use_gfx950_cvt=use_gfx950_cvt, defer_scale16=use_gfx950_cvt) + up_list[acc_idx] = mfma_k64(up_list[acc_idx], a0, a1, bu0, bu1) + # Drain last pending FMA. + if _pending_gate_up is not None: + p_idx, p_g, p_u, p_sc_g, p_sc_u = _pending_gate_up + gate_list[p_idx] = _acc_scaled_f32(gate_list[p_idx], p_g, p_sc_g) + up_list[p_idx] = _acc_scaled_f32(up_list[p_idx], p_u, p_sc_u) + else: + for ku in range_constexpr(k_unroll): + b_gate_packs0, b_gate_packs1 = b_gate_tile_in[ku] + b_up_packs0, b_up_packs1 = b_up_tile_in[ku] + ki64 = arith.index(ku * 64) + col_base = col_offset_base_bytes + ki64 + + for mi in range_constexpr(m_repeat): + mi_val = arith.index(mi * 16) + curr_row_a_lds = row_a_lds + mi_val + + if (a0_prefetch is not None) and (ku == 0) and (mi == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + + for ni in range_constexpr(num_acc_n): + acc_idx = mi * num_acc_n + ni + gate_list[acc_idx] = mfma_k64( + gate_list[acc_idx], + a0, + a1, + b_gate_packs0[ni], + b_gate_packs1[ni], + ) + up_list[acc_idx] = mfma_k64( + up_list[acc_idx], + a0, + a1, + b_up_packs0[ni], + b_up_packs1[ni], + ) return gate_list, up_list, epilogue_pf # ---------------- 2-stage pipeline (ping-pong LDS + B tile prefetch) ---------------- @@ -828,6 +962,8 @@ def mfma_k64(acc_in, a0, a1, b0, b1): rocdl.sched_barrier(0) def hot_loop_scheduler(): + rocdl.sched_barrier(0) + return mfma_group = num_acc_n * 2 # K64 micro-step: 2x K32 MFMA per gemm. mfma_total = (k_unroll * 2) * m_repeat * mfma_group @@ -871,70 +1007,132 @@ def hot_loop_scheduler(): # tile we are about to compute from LDS, to overlap with upcoming VMEM. a0_prefetch_pong = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_pong) - # Unrolled ping-pong main loop (2 tiles per iteration), leaving 2 tail tiles. - # Keep this as constexpr expansion to avoid SCF child-region dominance issues - # when carrying MFMA accumulators/prefetch values into the tail section. + # Ping-pong main loop (2 tiles per iteration), leaving 2 tail tiles. + # Uses scf.for with loop-carried accumulators, B-tile prefetch, and A0 LDS prefetch. c2_tile_k = arith.index(tile_k * 2) + c_tile_k = arith.index(tile_k) total_tiles = int(model_dim) // int(tile_k) pair_iters = max((total_tiles - 2) // 2, 0) - for pair_i in range_constexpr(pair_iters): - k_iv = arith.index(pair_i * (tile_k * 2)) + # B-tile data layout per k_unroll entry (3 variants): + # + # 1) int4 + groupwise scale (is_int4_bf16_groupwise): + # [(packed_w4, scale), (packed_w4, scale), ...] 
per ni + # Each ni has a (packed_weights, groupwise_scale) pair. + # Flattened as: [packed_0..N, scale_0..N] → 2 * num_acc_n values + # + # 2) int4_bf16 without groupwise scale (int4_bf16_single_field): + # [raw_i64, raw_i64, ...] per ni + # Single packed i64 per ni, already contains both weight halves. + # Flattened as: [raw_0..N] → 1 * num_acc_n values + # + # 3) fp8/int8/bf16/fp16 (default — two register packs per ku): + # (packs_even_list, packs_odd_list) + # Two lists of num_acc_n regs for even/odd MFMA operands. + # Flattened as: [even_0..N, odd_0..N] → 2 * num_acc_n values + # + int4_bf16_single_field = is_int4_bf16 and not is_int4_bf16_groupwise + _fields_per_ku = 1 if int4_bf16_single_field else 2 + _vals_per_b_tile = k_unroll * _fields_per_ku * num_acc_n + + def _flatten_b_tile(b_tile): + """Flatten B tile to a 1-D list for scf.for loop-carried state.""" + flat = [] + for ku_entry in b_tile: + if is_int4_bf16_groupwise: + # [(packed, scale), ...] → [packed_0..N, scale_0..N] + flat.extend(t[0] for t in ku_entry) + flat.extend(t[1] for t in ku_entry) + elif int4_bf16_single_field: + # [raw_i64, ...] → [raw_0..N] + flat.extend(ku_entry) + else: + # (packs_even, packs_odd) → [even_0..N, odd_0..N] + flat.extend(ku_entry[0]) + flat.extend(ku_entry[1]) + return flat + + def _unflatten_b_tile(vals): + """Reconstruct B tile from flattened scf.for loop-carried state.""" + b_tile, idx = [], 0 + for _ in range_constexpr(k_unroll): + if is_int4_bf16_groupwise: + packed = list(vals[idx:idx + num_acc_n]) + idx += num_acc_n + scales = list(vals[idx:idx + num_acc_n]) + idx += num_acc_n + b_tile.append([(packed[ni], scales[ni]) for ni in range_constexpr(num_acc_n)]) + elif int4_bf16_single_field: + b_tile.append(list(vals[idx:idx + num_acc_n])) + idx += num_acc_n + else: + packs_even = list(vals[idx:idx + num_acc_n]) + idx += num_acc_n + packs_odd = list(vals[idx:idx + num_acc_n]) + idx += num_acc_n + b_tile.append((packs_even, packs_odd)) + return b_tile + + init_state = ( + list(acc_gate) + list(acc_up) + + _flatten_b_tile(b_gate_cur) + _flatten_b_tile(b_up_cur) + + list(a0_prefetch_pong) + ) + + _n_acc = m_repeat * num_acc_n + _p_bg = 2 * _n_acc + _p_bu = _p_bg + _vals_per_b_tile + _p_a0 = _p_bu + _vals_per_b_tile + + for pair_iv, state in range(0, pair_iters, 1, init=init_state): + _ag = list(state[:_n_acc]) + _au = list(state[_n_acc:_p_bg]) + _bg = _unflatten_b_tile(list(state[_p_bg:_p_bu])) + _bu = _unflatten_b_tile(list(state[_p_bu:_p_a0])) + _a0pf = (state[_p_a0], state[_p_a0 + 1]) + + k_iv = pair_iv * (c_tile_k + c_tile_k) + # ---- stage 0: prefetch+store ping, compute pong ---- - next_k1 = k_iv + tile_k + next_k1 = k_iv + c_tile_k x_regs_ping = load_x_tile(next_k1) - b_gate_ping = load_b_tile(next_k1, n_blk_gate, n_intra_gate) - b_up_ping = load_b_tile(next_k1, n_blk_up, n_intra_up) - - acc_gate, acc_up, _ = compute_tile( - acc_gate, - acc_up, - b_gate_cur, - b_up_cur, - lds_base_pong, - a0_prefetch=a0_prefetch_pong, - ) - a0_prefetch_pong = None + _bg_ping = load_b_tile(next_k1, n_blk_gate, n_intra_gate) + _bu_ping = load_b_tile(next_k1, n_blk_up, n_intra_up) + + _ag, _au, _ = compute_tile(_ag, _au, _bg, _bu, lds_base_pong, a0_prefetch=_a0pf) store_x_tile_to_lds(x_regs_ping, lds_base_ping) hot_loop_scheduler() gpu.barrier() - - # Cross-tile prefetch for the ping tile we are about to compute. 
- a0_prefetch_ping = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_ping) - + + _a0pf_ping = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_ping) + # ---- stage 1: prefetch+store pong, compute ping ---- - next_k2 = k_iv + c2_tile_k + next_k2 = k_iv + c_tile_k + c_tile_k x_regs_pong = load_x_tile(next_k2) - b_gate_next = load_b_tile(next_k2, n_blk_gate, n_intra_gate) - b_up_next = load_b_tile(next_k2, n_blk_up, n_intra_up) - - acc_gate, acc_up, _ = compute_tile( - acc_gate, - acc_up, - b_gate_ping, - b_up_ping, - lds_base_ping, - a0_prefetch=a0_prefetch_ping, - ) - a0_prefetch_ping = None + _bg_next = load_b_tile(next_k2, n_blk_gate, n_intra_gate) + _bu_next = load_b_tile(next_k2, n_blk_up, n_intra_up) + + _ag, _au, _ = compute_tile(_ag, _au, _bg_ping, _bu_ping, lds_base_ping, a0_prefetch=_a0pf_ping) store_x_tile_to_lds(x_regs_pong, lds_base_pong) hot_loop_scheduler() gpu.barrier() - - # Cross-tile prefetch for the next pong tile. - a0_prefetch_pong = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_pong) - - # Advance pong state to next_k2 for next iteration. - b_gate_cur = b_gate_next - b_up_cur = b_up_next - - # Tail: 2 remaining tiles at (k_in - 2*tile_k) and (k_in - tile_k). - # Rebuild prefetch in the current block: values produced inside the `range(...)` - # loop body may live in a child region and cannot be used here. - k_tail0 = k_in - c2_tile_k - b_gate_cur = load_b_tile(k_tail0, n_blk_gate, n_intra_gate) - b_up_cur = load_b_tile(k_tail0, n_blk_up, n_intra_up) - a0_prefetch_pong = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_pong) + + _a0pf_new = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_pong) + + loop_results = yield ( + list(_ag) + list(_au) + + _flatten_b_tile(_bg_next) + _flatten_b_tile(_bu_next) + + list(_a0pf_new) + ) + + # After scf.for: extract final state from yielded results. + SmemPtr._view_cache = None + if pair_iters > 0: + acc_gate = list(loop_results[:_n_acc]) + acc_up = list(loop_results[_n_acc:_p_bg]) + b_gate_cur = _unflatten_b_tile(list(loop_results[_p_bg:_p_bu])) + b_up_cur = _unflatten_b_tile(list(loop_results[_p_bu:_p_a0])) + a0_prefetch_pong = (loop_results[_p_a0], loop_results[_p_a0 + 1]) k_tail1 = k_in - tile_k x_regs_ping = load_x_tile(k_tail1) b_gate_ping = load_b_tile(k_tail1, n_blk_gate, n_intra_gate) @@ -975,7 +1173,10 @@ def hot_loop_scheduler(): inter_i32_v = fx.Int32(inter_dim) mask24_i32 = fx.Int32(0xFFFFFF) - if epilogue_pf is not None: + if use_groupwise_scale: + sw_gate_vals = [arith.constant(1.0, type=T.f32)] * num_acc_n + sw_up_vals = [arith.constant(1.0, type=T.f32)] * num_acc_n + elif epilogue_pf is not None: sw_gate_vals, sw_up_vals = epilogue_pf else: sw_gate_vals = [] @@ -995,6 +1196,13 @@ def hot_loop_scheduler(): else buffer_ops.buffer_load(sw_rsrc, row_up_idx, vec_width=1, dtype=T.f32) ) + # When defer_scale16 was used, the x16 correction for v_cvt_off_f32_i4 + # was omitted from the hot loop. Fold it into the epilogue scale. + if use_gfx950_cvt: + _c16 = fx.Float32(16.0) + sw_gate_vals = [v * _c16 for v in sw_gate_vals] + sw_up_vals = [v * _c16 for v in sw_up_vals] + # Epilogue hoists to keep IR + Python build time small: col_i32_list = [] for ni in range_constexpr(num_acc_n): @@ -1279,6 +1487,7 @@ def compile_moe_gemm2( out_dtype: str = "f16", use_cshuffle_epilog: bool | None = None, accumulate: bool = True, + scale_is_bf16: bool = False, ): """Compile stage2 kernel (`moe_gemm2`) and return the compiled executable. 
@@ -1289,6 +1498,7 @@ def compile_moe_gemm2( - "int8": A2/W are int8 - "int4": W4A8 path: A2 is int8, W is packed int4 unpacked to int8 in-kernel - "int4_bf16": W4A16 path: A2 is bf16, W is packed int4 unpacked to bf16 in-kernel + scale_is_bf16: When True, groupwise scales are bf16 (halves scale bandwidth). Stage2 output supports: - out_dtype="f16": fp16 half2 atomics (fast, can overflow to +/-inf for bf16 workloads) @@ -1301,14 +1511,16 @@ def compile_moe_gemm2( allocator = SmemAllocator(None, arch=gpu_arch) _state = {} - if in_dtype not in ("fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16"): + _valid_dtypes = ("fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16") + if in_dtype not in _valid_dtypes: raise ValueError( - f"in_dtype must be one of ('fp8','fp16','bf16','int8','int8smooth','int4','int4_bf16'), got {in_dtype!r}" + f"in_dtype must be one of {_valid_dtypes}, got {in_dtype!r}" ) - is_int4_bf16 = in_dtype == "int4_bf16" + is_int4_bf16 = in_dtype == "int4_bf16" # W4A16: bf16 activations, packed int4 weights is_f16 = in_dtype == "fp16" is_bf16 = is_int4_bf16 or in_dtype == "bf16" is_f16_or_bf16 = is_f16 or is_bf16 + needs_scale_x = not is_f16_or_bf16 needs_scale_w = (not is_f16_or_bf16) or is_int4_bf16 elem_bytes = 2 if is_f16_or_bf16 else 1 out_s = str(out_dtype).strip().lower() @@ -1319,9 +1531,28 @@ def compile_moe_gemm2( if (not bool(accumulate)) and out_is_f32: raise ValueError("compile_moe_gemm2(accumulate=False) only supports out_dtype in {'f16','bf16'}") is_int4 = in_dtype == "int4" + # w_is_int4: True for any variant where weights are packed int4. + w_is_int4 = is_int4 or is_int4_bf16 # INT4 here means W4A8: A2 is int8, W is packed int4 and unpacked to int8 in-kernel. is_int8 = (in_dtype in ("int8", "int8smooth")) or is_int4 + # Group-wise scale support for W4A16 + use_groupwise_scale = w_is_int4 and group_size > 0 + if use_groupwise_scale and group_size != 32: + raise ValueError( + f"FlyDSL groupwise scale only supports group_size=32, got {group_size}. " + f"This is due to int4 preshuffle layout constraints. " + f"Please use Triton kernel for other group sizes." + ) + is_int4_bf16_groupwise = is_int4_bf16 and use_groupwise_scale + # Stage2 K dimension is inter_dim (weight shape: [E, model_dim, inter_dim]) + num_groups = inter_dim // group_size if use_groupwise_scale else 1 + _scale_is_bf16 = scale_is_bf16 and use_groupwise_scale + scale_w_size_stage2 = experts * model_dim * num_groups + + _is_gfx950 = "gfx95" in get_hip_arch() + use_gfx950_cvt = is_int4_bf16 and _is_gfx950 + mfma_i32_k32 = None if is_int8: mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr( @@ -1344,14 +1575,17 @@ def compile_moe_gemm2( "(or `rocdl.mfma_f32_16x16x16_bf16_1k`)." ) + # gfx950: use 16x16x32 MFMA for f16/bf16 (K=32 per MFMA, vs K=16 on gfx942). + _use_mfma_k32 = _is_gfx950 and (is_f16 or is_bf16) + DYN = ir.ShapedType.get_dynamic_size() size_out = DYN size_x = DYN size_sorted = DYN size_expert_ids_shape = DYN size_scale_x = DYN - # W is packed int4 for W4A8/W4A16: 2 values per byte. - size_w = (experts * model_dim * inter_dim) // 2 if (is_int4 or is_int4_bf16) else (experts * model_dim * inter_dim) + # W is packed int4 for W4A8/W4A16/W4A_FP8: 2 values per byte. + size_w = (experts * model_dim * inter_dim) // 2 if w_is_int4 else (experts * model_dim * inter_dim) total_threads = 256 tile_k_bytes = int(tile_k) * int(elem_bytes) @@ -1413,9 +1647,12 @@ def out_elem(): # IMPORTANT: module name participates in FlyDSL's compile cache key. 
# Dynamic-shape variant: safe to reuse across (tokens/sorted_size/size_expert_ids) at runtime. # Keep a distinct ABI tag so the compile cache never mixes with historical signatures. + _gs_tag = f"_g{group_size}" if use_groupwise_scale else "" + scale_tag = "_sbf16" if _scale_is_bf16 else "" module_name = ( f"mfma_moe2_{in_dtype}_{out_s}_{epilog_tag}" f"_t{tile_m}x{tile_n}x{tile_k}" + f"{_gs_tag}{scale_tag}" f"_abi2" # mask sentinel token ids on loads/stores to avoid illegal address faults ).replace("-", "_") @@ -1471,7 +1708,8 @@ def moe_gemm2( k_i32_v = i32_k_in x_elem = T.bf16 if is_bf16 else (T.f16 if is_f16 else (T.i8 if is_int8 else T.f8)) # For int4/int4_bf16, weights are stored as packed bytes (i8) and unpacked in-kernel. - w_elem = T.i8 if (is_int4 or is_int4_bf16) else (T.bf16 if is_bf16 else (T.f16 if is_f16 else (T.i8 if is_int8 else T.f8))) + w_elem = T.i8 if w_is_int4 else (T.bf16 if is_bf16 else (T.f16 if is_f16 else (T.i8 if is_int8 else T.f8))) + scale_dtype = T.bf16 if _scale_is_bf16 else T.f32 vec16_elems = 16 if elem_bytes == 1 else 8 vec8_elems = 8 if elem_bytes == 1 else 4 vec4_elems = 4 if elem_bytes == 1 else 2 @@ -1483,6 +1721,7 @@ def moe_gemm2( if is_int8 else arith.constant_vector(0.0, T.f32x4) ) + zero_f32_acc = arith.constant_vector(0.0, T.f32x4) if is_int4_bf16_groupwise else None # A2 layout (flatten token-slot -> M; use i32 for fly.make_shape). topk_idx = fx.Index(topk) @@ -1492,8 +1731,9 @@ def moe_gemm2( # B preshuffle layout: [experts*model_dim, inter_dim] c_n_total = arith.index(experts * model_dim) - kpack_bytes = 8 if (is_int4 or is_int4_bf16) else 16 - w_elem_bytes = 1 if (is_int4 or is_int4_bf16) else elem_bytes + # For packed int4 (W4A8/W4A16/W4A_FP8), kpack_bytes=8. + kpack_bytes = 8 if w_is_int4 else 16 + w_elem_bytes = 1 if w_is_int4 else elem_bytes b_layout = make_preshuffle_b_layout( arith, c_n=c_n_total, c_k=k_in, kpack_bytes=kpack_bytes, elem_bytes=w_elem_bytes ) @@ -1791,12 +2031,37 @@ def load_b_pack(base_k, ki_step, ni): def load_b_tile(base_k): """Prefetch the entire per-thread B tile (gmem -> regs) for a given K base. - + Returns a list of length `k_unroll`, where each entry is a tuple: (packs_half0[ni], packs_half1[ni]) for the K64 micro-step. + For groupwise variants, each entry also includes per-group scales: + (packs0[ni], packs1[ni], scales0[ni], scales1[ni]) """ - if is_int4_bf16: - # W4A16: 2-phase load+unpack for VMEM latency hiding + if is_int4_bf16_groupwise: + # W4A16 groupwise: load raw packed32 + scale; defer dequant to compute_tile. + raw_data = [] + for ku in range_constexpr(k_unroll): + raw_ku = [] + for ni in range_constexpr(num_acc_n): + packed32, scale_val = load_b_raw_w4a16_groupwise( + buffer_ops, arith, vector, + arg_b=arg_w, b_rsrc=w_rsrc, layout_b=layout_b, + base_k=base_k, ku=ku, + n_blk=n_blk_list[ni], n_intra=n_intra_list[ni], + lane_div_16=lane_div_16, elem_type=w_elem, + scale_rsrc=sw_rsrc, + expert_offset=expert_off_idx, + num_groups=num_groups, + group_size=group_size, + n_per_expert=model_dim, + kpack_bytes=kpack_bytes, + scale_dtype=scale_dtype, + ) + raw_ku.append((packed32, scale_val)) + raw_data.append(raw_ku) + return raw_data + elif is_int4_bf16: + # W4A16 per-row: load raw packed32; defer dequant to compute_tile. 
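+ # Illustrative return shape (hypothetical k_unroll=2, num_acc_n=4):
+ # [[raw_i64]*4, [raw_i64]*4]: one packed i64 per (ku, ni), dequantized later in compute_tile.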
raw_data = [] for ku in range_constexpr(k_unroll): raw_ku = [] @@ -1811,29 +2076,22 @@ def load_b_tile(base_k): ) raw_ku.append(raw) raw_data.append(raw_ku) + return raw_data + else: + # fp8/int8/bf16/fp16: original code path b_tile = [] for ku in range_constexpr(k_unroll): packs0 = [] packs1 = [] for ni in range_constexpr(num_acc_n): - b0, b1 = unpack_b_w4a16(raw_data[ku][ni], arith, vector) + ki0 = (ku * 2) + 0 + ki1 = (ku * 2) + 1 + b0 = load_b_pack(base_k, ki0, ni) + b1 = load_b_pack(base_k, ki1, ni) packs0.append(b0) packs1.append(b1) b_tile.append((packs0, packs1)) return b_tile - b_tile = [] - for ku in range_constexpr(k_unroll): - packs0 = [] - packs1 = [] - for ni in range_constexpr(num_acc_n): - ki0 = (ku * 2) + 0 - ki1 = (ku * 2) + 1 - b0 = load_b_pack(base_k, ki0, ni) - b1 = load_b_pack(base_k, ki1, ni) - packs0.append(b0) - packs1.append(b1) - b_tile.append((packs0, packs1)) - return b_tile # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- def store_x_tile_to_lds(vec_x_in_parts, lds_base): @@ -1903,18 +2161,21 @@ def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): def compute_tile(acc_in, b_tile_in, lds_base, *, prefetch_epilogue: bool = False, a0_prefetch=None): acc_list = list(acc_in) mfma_res_ty = T.i32x4 if is_int8 else T.f32x4 - mfma_fn = ( - mfma_i32_k32 - if is_int8 - else ( - mfma_f32_bf16_k16 - if is_bf16 - else (rocdl.mfma_f32_16x16x16f16 if is_f16 else rocdl.mfma_f32_16x16x32_fp8_fp8) + if _use_mfma_k32: + mfma_fn = rocdl.mfma_f32_16x16x32_f16 if is_f16 else rocdl.mfma_f32_16x16x32_bf16 + else: + mfma_fn = ( + mfma_i32_k32 + if is_int8 + else ( + mfma_f32_bf16_k16 + if is_bf16 + else (rocdl.mfma_f32_16x16x16f16 if is_f16 else rocdl.mfma_f32_16x16x32_fp8_fp8) + ) ) - ) - + epilogue_pf = None - if prefetch_epilogue: + if prefetch_epilogue and not use_groupwise_scale: expert_off_pf = expert_off_idx sw_pf = [] for ni in range_constexpr(num_acc_n): @@ -1952,7 +2213,24 @@ def _i64_to_v4i16(x_i64): v1 = vector.from_elements(T.vec(1, T.i64), [x_i64]) return vector.bitcast(T.i16x4, v1) + def _i64x2_to_v8f16(lo, hi): + v2 = vector.from_elements(T.i64x2, [lo, hi]) + return vector.bitcast(T.f16x8, v2) + + def _i64x2_to_v8bf16(lo, hi): + v2 = vector.from_elements(T.i64x2, [lo, hi]) + return vector.bitcast(T.bf16x8, v2) + def mfma_k64(acc0, a0, a1, b0, b1): + if _use_mfma_k32: + # gfx950: single 16x16x32 MFMA consuming all 128 bits (K=32 f16/bf16) + if is_f16: + av = _i64x2_to_v8f16(a0, a1) + bv = _i64x2_to_v8f16(b0, b1) + else: + av = _i64x2_to_v8bf16(a0, a1) + bv = _i64x2_to_v8bf16(b0, b1) + return mfma_fn(mfma_res_ty, [av, bv, acc0, 0, 0, 0]) if is_f16: a0v = _i64_to_v4f16(a0) a1v = _i64_to_v4f16(a1) @@ -1969,30 +2247,78 @@ def mfma_k64(acc0, a0, a1, b0, b1): return mfma_fn(mfma_res_ty, [a1v, b1v, acc1, 0, 0, 0]) acc1 = mfma_fn(mfma_res_ty, [a0, b0, acc0, 0, 0, 0]) return mfma_fn(mfma_res_ty, [a1, b1, acc1, 0, 0, 0]) - - for ku in range_constexpr(k_unroll): - b_packs0, b_packs1 = b_tile_in[ku] - ki64 = arith.index(ku * 64) - col_base = col_offset_base_bytes + ki64 - - for mi in range_constexpr(m_repeat): - mi_val = arith.index(mi * 16) - curr_row_a_lds = row_a_lds + mi_val - - if (a0_prefetch is not None) and (ku == 0) and (mi == 0): - a0, a1 = a0_prefetch - else: - a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) - - for ni in range_constexpr(num_acc_n): - acc_idx = mi * num_acc_n + ni - acc_list[acc_idx] = mfma_k64( - acc_list[acc_idx], - a0, - a1, - b_packs0[ni], - b_packs1[ni], - ) + + def _acc_scaled_f32(f32_acc_vec, 
f32_partial_vec, scale_val): + """MFMA f32 partial -> scale -> add to f32 accumulator via math.fma on vector.""" + from flydsl._mlir.dialects._math_ops_gen import fma as _math_fma + _uw = arith._to_raw + scale_vec = _uw(vector.broadcast(T.f32x4, scale_val)) + return arith.ArithValue(_math_fma(scale_vec, _uw(f32_partial_vec), _uw(f32_acc_vec))) + + if is_int4_bf16 or is_int4_bf16_groupwise: + # W4A16: deferred dequant -- unpack int4->bf16 right before MFMA + # to minimize VGPR lifetime of dequantized bf16 values. + _pending_acc = None + for ku in range_constexpr(k_unroll): + b_raw = b_tile_in[ku] + ki64 = arith.index(ku * 64) + col_base = col_offset_base_bytes + ki64 + + for mi in range_constexpr(m_repeat): + mi_val = arith.index(mi * 16) + curr_row_a_lds = row_a_lds + mi_val + + if (a0_prefetch is not None) and (ku == 0) and (mi == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + + for ni in range_constexpr(num_acc_n): + acc_idx = mi * num_acc_n + ni + if is_int4_bf16_groupwise: + packed, sc = b_raw[ni] + if _scale_is_bf16: + sc = extract_bf16_scale(arith, sc, ku) + else: + packed, sc = b_raw[ni], None + if is_int4_bf16_groupwise and use_gfx950_cvt: + b0, b1 = unpack_b_w4a16(packed, arith, vector, scale_val=None, use_gfx950_cvt=True, defer_scale16=True) + tmp = mfma_k64(zero_f32_acc, a0, a1, b0, b1) + if _pending_acc is not None: + p_idx, p_tmp, p_sc = _pending_acc + acc_list[p_idx] = _acc_scaled_f32(acc_list[p_idx], p_tmp, p_sc) + _pending_acc = (acc_idx, tmp, sc) + else: + b0, b1 = unpack_b_w4a16(packed, arith, vector, scale_val=sc, use_gfx950_cvt=use_gfx950_cvt, defer_scale16=use_gfx950_cvt) + acc_list[acc_idx] = mfma_k64(acc_list[acc_idx], a0, a1, b0, b1) + # Drain last pending FMA. + if _pending_acc is not None: + p_idx, p_tmp, p_sc = _pending_acc + acc_list[p_idx] = _acc_scaled_f32(acc_list[p_idx], p_tmp, p_sc) + else: + for ku in range_constexpr(k_unroll): + b_packs0, b_packs1 = b_tile_in[ku] + ki64 = arith.index(ku * 64) + col_base = col_offset_base_bytes + ki64 + + for mi in range_constexpr(m_repeat): + mi_val = arith.index(mi * 16) + curr_row_a_lds = row_a_lds + mi_val + + if (a0_prefetch is not None) and (ku == 0) and (mi == 0): + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + + for ni in range_constexpr(num_acc_n): + acc_idx = mi * num_acc_n + ni + acc_list[acc_idx] = mfma_k64( + acc_list[acc_idx], + a0, + a1, + b_packs0[ni], + b_packs1[ni], + ) return acc_list, epilogue_pf # ---------------- 2-stage pipeline (ping-pong LDS + B tile prefetch) ---------------- @@ -2039,6 +2365,8 @@ def mfma_k64(acc0, a0, a1, b0, b1): # rocdl.sched_barrier(0) def hot_loop_scheduler(): + rocdl.sched_barrier(0) + return # - MFMA group size per "slot": num_acc_n # - Total MFMA per tile: (2*K32 per K64) * k_unroll * m_repeat * num_acc_n # - We emit (mfma_group + dsrd + mfma_group) per scheduler iteration. @@ -2108,39 +2436,93 @@ def hot_loop_scheduler(): k_main2_py = 0 c2_tile_k = arith.index(tile_k * 2) + c_tile_k_s2 = arith.index(tile_k) pair_iters = k_main2_py // (int(tile_k) * 2) - for pair_i in range_constexpr(pair_iters): - k_iv = arith.index(pair_i * (tile_k * 2)) - next_k1 = k_iv + tile_k + + # B-tile data layout per k_unroll entry (3 variants): + # See gemm1 _flatten_b_tile for full layout documentation. 
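+ # Worked sizing example (hypothetical k_unroll=2, num_acc_n=4, m_repeat=1, groupwise):
+ #   _fields_per_ku=2, _vals_per_b_tile=2*2*4=16, _n_acc=4, so the loop-carried state is
+ #   [acc(4) | b_tile(16) | a0_prefetch(2)] = 22 scf.for iter_args.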
+ int4_bf16_single_field = is_int4_bf16 and not is_int4_bf16_groupwise + _fields_per_ku = 1 if int4_bf16_single_field else 2 + _vals_per_b_tile = k_unroll * _fields_per_ku * num_acc_n + _n_acc = m_repeat * num_acc_n + _p_b = _n_acc + _p_a0 = _p_b + _vals_per_b_tile + + def _flatten_b_tile(b_tile): + """Flatten B tile to a 1-D list for scf.for loop-carried state.""" + flat = [] + for ku_entry in b_tile: + if is_int4_bf16_groupwise: + flat.extend(t[0] for t in ku_entry) + flat.extend(t[1] for t in ku_entry) + elif int4_bf16_single_field: + flat.extend(ku_entry) + else: + flat.extend(ku_entry[0]) + flat.extend(ku_entry[1]) + return flat + + def _unflatten_b_tile(vals): + """Reconstruct B tile from flattened scf.for loop-carried state.""" + b_tile, idx = [], 0 + for _ in range_constexpr(k_unroll): + if is_int4_bf16_groupwise: + packed = list(vals[idx:idx + num_acc_n]) + idx += num_acc_n + scales = list(vals[idx:idx + num_acc_n]) + idx += num_acc_n + b_tile.append([(packed[ni], scales[ni]) for ni in range_constexpr(num_acc_n)]) + elif int4_bf16_single_field: + b_tile.append(list(vals[idx:idx + num_acc_n])) + idx += num_acc_n + else: + packs_even = list(vals[idx:idx + num_acc_n]) + idx += num_acc_n + packs_odd = list(vals[idx:idx + num_acc_n]) + idx += num_acc_n + b_tile.append((packs_even, packs_odd)) + return b_tile + + init_state = list(acc) + _flatten_b_tile(b_cur) + list(a0_prefetch_pong) + + for pair_iv, state in range(0, pair_iters, 1, init=init_state): + _ac = list(state[:_n_acc]) + _bc = _unflatten_b_tile(list(state[_p_b:_p_a0])) + _a0 = (state[_p_a0], state[_p_a0 + 1]) + + k_iv = pair_iv * (c_tile_k_s2 + c_tile_k_s2) + + next_k1 = k_iv + c_tile_k_s2 x_regs_ping = load_x_tile(next_k1) - b_ping = load_b_tile(next_k1) - - acc, _ = compute_tile(acc, b_cur, lds_base_pong, a0_prefetch=a0_prefetch_pong) - a0_prefetch_pong = None + _bp = load_b_tile(next_k1) + + _ac, _ = compute_tile(_ac, _bc, lds_base_pong, a0_prefetch=_a0) store_x_tile_to_lds(x_regs_ping, lds_base_ping) hot_loop_scheduler() gpu.barrier() - - # Cross-tile prefetch for the ping tile we are about to compute. - a0_prefetch_ping = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_ping) - - next_k2 = k_iv + c2_tile_k + + _a0p = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_ping) + + next_k2 = k_iv + c_tile_k_s2 + c_tile_k_s2 x_regs_pong = load_x_tile(next_k2) - b_next = load_b_tile(next_k2) - - acc, _ = compute_tile(acc, b_ping, lds_base_ping, a0_prefetch=a0_prefetch_ping) - a0_prefetch_ping = None + _bn = load_b_tile(next_k2) + + _ac, _ = compute_tile(_ac, _bp, lds_base_ping, a0_prefetch=_a0p) store_x_tile_to_lds(x_regs_pong, lds_base_pong) hot_loop_scheduler() gpu.barrier() - - # Cross-tile prefetch for the next pong tile. - a0_prefetch_pong = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_pong) - - b_cur = b_next - + + _a0n = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_pong) + + loop_results = yield list(_ac) + _flatten_b_tile(_bn) + list(_a0n) + + SmemPtr._view_cache = None + if pair_iters > 0: + acc = list(loop_results[:_n_acc]) + b_cur = _unflatten_b_tile(list(loop_results[_p_b:_p_a0])) + a0_prefetch_pong = (loop_results[_p_a0], loop_results[_p_a0 + 1]) + if odd_k_tiles: - # Tail: single remaining tile (already in `b_cur` / `lds_base_pong`). acc, epilogue_pf = compute_tile( acc, b_cur, @@ -2149,18 +2531,15 @@ def hot_loop_scheduler(): a0_prefetch=a0_prefetch_pong, ) else: - # Tail: 2 remaining tiles. 
k_tail1 = k_in - tile_k x_regs_ping = load_x_tile(k_tail1) b_ping = load_b_tile(k_tail1) - + acc, _ = compute_tile(acc, b_cur, lds_base_pong, a0_prefetch=a0_prefetch_pong) - a0_prefetch_pong = None store_x_tile_to_lds(x_regs_ping, lds_base_ping) hot_loop_scheduler() gpu.barrier() - - # Epilogue tile with sw prefetch. + a0_prefetch_ping = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_ping) acc, epilogue_pf = compute_tile( acc, b_ping, lds_base_ping, prefetch_epilogue=True, a0_prefetch=a0_prefetch_ping @@ -2194,7 +2573,10 @@ def atomic_add_f16x2(val_f16x2, byte_off_i32): sw_pf, tw_pf = epilogue_pf # Weight scales for the N tile (col_g depends on lane/wave/by but not on (t,s)). - if sw_pf is not None: + if use_groupwise_scale: + # Groupwise: weight scale already applied per-group in K-loop. + sw_vals = [arith.constant(1.0, type=T.f32)] * num_acc_n + elif sw_pf is not None: sw_vals = sw_pf else: sw_vals = [] @@ -2207,6 +2589,12 @@ def atomic_add_f16x2(val_f16x2, byte_off_i32): else buffer_ops.buffer_load(sw_rsrc, row_w_idx, vec_width=1, dtype=T.f32) ) + # When defer_scale16 was used, the x16 correction for v_cvt_off_f32_i4 + # was omitted from the hot loop. Fold it into the epilogue scale. + if use_gfx950_cvt: + _c16 = fx.Float32(16.0) + sw_vals = [v * _c16 for v in sw_vals] + if out_is_f32: # origin/dev_a16w4: f32 output uses scalar f32 atomics and skips CShuffle/LDS. c4_i32 = fx.Int32(4) @@ -2824,6 +3212,7 @@ def compile_moe_gemm2_ex( mode: str = MoeGemm2Mode.ATOMIC, valid_mask = None, zero_intermediate: bool = True, + scale_is_bf16: bool = False, ): """Compile MoE GEMM2 kernel with optional reduction. @@ -2860,6 +3249,7 @@ def compile_moe_gemm2_ex( out_dtype=out_dtype, use_cshuffle_epilog=use_cshuffle_epilog, accumulate=False, + scale_is_bf16=scale_is_bf16, ) # Compile reduction kernel with masking support out_s = str(out_dtype).strip().lower() diff --git a/python/flydsl/expr/rocdl/__init__.py b/python/flydsl/expr/rocdl/__init__.py index 255495fe..9112b8a2 100644 --- a/python/flydsl/expr/rocdl/__init__.py +++ b/python/flydsl/expr/rocdl/__init__.py @@ -34,8 +34,11 @@ _ods_mfma_f32_16x16x16bf16_1k = globals().get("mfma_f32_16x16x16bf16_1k", None) _ods_mfma_f32_16x16x32_fp8_fp8 = mfma_f32_16x16x32_fp8_fp8 _ods_mfma_i32_16x16x32_i8 = mfma_i32_16x16x32_i8 -_ods_mfma_scale_f32_16x16x128_f8f6f4 = globals().get("mfma_scale_f32_16x16x128_f8f6f4", None) or globals().get( - "mfma_scale_f32_16x16x128_f8f6f4_", None +_ods_mfma_f32_16x16x32_f16 = globals().get("mfma_f32_16x16x32_f16", None) +_ods_mfma_f32_16x16x32_bf16 = globals().get("mfma_f32_16x16x32_bf16", None) +_ods_mfma_scale_f32_16x16x128_f8f6f4 = ( + globals().get("mfma_scale_f32_16x16x128_f8f6f4", None) + or globals().get("mfma_scale_f32_16x16x128_f8f6f4_", None) ) mask_mfma = 0x008 mask_vmem_rd = 0x020 @@ -110,6 +113,22 @@ def mfma_i32_16x16x32_i8(result_type, operands, *, loc=None, ip=None): return _ods_mfma_i32_16x16x32_i8(result_type, a, b, c, cbsz, abid, blgp, loc=loc, ip=ip).result +@traced_op +def mfma_f32_16x16x32_f16(result_type, operands, *, loc=None, ip=None): + if _ods_mfma_f32_16x16x32_f16 is None: + raise AttributeError("ROCDL op not found: mfma_f32_16x16x32_f16 (gfx950+)") + a, b, c, cbsz, abid, blgp = _split_mfma_operands(operands, loc=loc) + return _ods_mfma_f32_16x16x32_f16(result_type, a, b, c, cbsz, abid, blgp, loc=loc, ip=ip).result + + +@traced_op +def mfma_f32_16x16x32_bf16(result_type, operands, *, loc=None, ip=None): + if _ods_mfma_f32_16x16x32_bf16 is None: + raise AttributeError("ROCDL op not found: 
mfma_f32_16x16x32_bf16 (gfx950+)") + a, b, c, cbsz, abid, blgp = _split_mfma_operands(operands, loc=loc) + return _ods_mfma_f32_16x16x32_bf16(result_type, a, b, c, cbsz, abid, blgp, loc=loc, ip=ip).result + + @traced_op def mfma_scale_f32_16x16x128_f8f6f4(result_type, operands, *, loc=None, ip=None): if _ods_mfma_scale_f32_16x16x128_f8f6f4 is None: @@ -378,6 +397,7 @@ def lds_transpose_load(result_type, lds_memref, elem_offset, elem_bytes): # ── New high-level helpers from universal.py ────────────────────────── from .universal import * # noqa: F401,F403 +from .inline_asm import * # noqa: F401,F403 # ── Wrappers: accept DSL Numeric args (fx.Int32, fx.Float32, etc.) ───────── # ODS-generated ops require raw ir.Value. These wrappers auto-convert. diff --git a/python/flydsl/expr/rocdl/inline_asm.py b/python/flydsl/expr/rocdl/inline_asm.py new file mode 100644 index 00000000..78e4facf --- /dev/null +++ b/python/flydsl/expr/rocdl/inline_asm.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""GCN/CDNA inline assembly wrappers for ROCm GPU instructions. + +These emit LLVM inline asm ops for instructions that have no corresponding +MLIR ROCDL dialect op yet. The underlying ISA instructions are defined in +LLVM's AMDGPU backend (VOP1Instructions.td / VOP3Instructions.td) but the +MLIR ROCDLOps.td tablegen does not surface them. + +TODO: Remove these inline asm wrappers once upstream MLIR adds proper ROCDL +dialect ops for v_cvt_off_f32_i4 and v_cvt_pk_bf16_f32. +""" + + +def _to_ir(v): + """Coerce DSL Numeric to ir.Value if needed.""" + from ..._mlir import ir as _ir + if not isinstance(v, _ir.Value) and hasattr(v, 'ir_value'): + return v.ir_value() + return v + + +def cvt_off_f32_i4(src_i32, byte_sel=None): + """gfx9xx: v_cvt_off_f32_i4 — convert low nibble (bits[3:0]) to f32. + + With byte_sel=0..3, uses SDWA to select the byte before conversion, + avoiding an explicit shift. byte_sel=None uses the plain VOP1 form. + """ + from ..._mlir.dialects import llvm as _llvm + from ..._mlir import ir + + if byte_sel is not None: + sel = ["BYTE_0", "BYTE_1", "BYTE_2", "BYTE_3"][int(byte_sel)] + return _llvm.inline_asm( + ir.F32Type.get(), + [_to_ir(src_i32)], + f"v_cvt_off_f32_i4_sdwa $0, $1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:{sel}", + "=v,v", + has_side_effects=False, + ) + return _llvm.inline_asm( + ir.F32Type.get(), + [_to_ir(src_i32)], + "v_cvt_off_f32_i4 $0, $1", + "=v,v", + has_side_effects=False, + ) + + +def cvt_pk_bf16_f32(src_a_f32, src_b_f32): + """gfx950: v_cvt_pk_bf16_f32 vdst, vsrc0, vsrc1. + + Pack two f32 values into 2xbf16 in i32. + dst[15:0] = bf16(src_a), dst[31:16] = bf16(src_b). 
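+
+ Illustrative example (values exactly representable in bf16, so no rounding):
+ src_a=1.0 (0x3F800000), src_b=2.0 (0x40000000) -> dst = 0x40003F80
+ (low half 0x3F80 = bf16(1.0), high half 0x4000 = bf16(2.0)).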
+ """ + from ..._mlir.dialects import llvm as _llvm + from ..._mlir import ir + return _llvm.inline_asm( + ir.IntegerType.get_signless(32), + [_to_ir(src_a_f32), _to_ir(src_b_f32)], + "v_cvt_pk_bf16_f32 $0, $1, $2", + "=v,v,v", + has_side_effects=False, + ) diff --git a/scripts/run_benchmark.sh b/scripts/run_benchmark.sh index 8373fd4e..c8da0150 100755 --- a/scripts/run_benchmark.sh +++ b/scripts/run_benchmark.sh @@ -128,6 +128,14 @@ MOE_FP4_SHAPES=' 32768,7168,2048,32,8,64,256,256,256,256 ' +# MoE W4A16 groupwise shapes (int4_bf16, group_size=32): same format as MOE_SHAPES +# Kimi 2.5 TP8: model_dim=7168, inter_dim=256, E=384, topk=8 +MOE_W4A16_SHAPES=' +128,7168,256,384,8,16,128,128,128,256 +256,7168,256,384,8,16,128,128,128,256 +512,7168,256,384,8,16,128,128,128,256 +' + # Memory bound threshold (M or tokens <= threshold => memory bound) MEMORY_BOUND_THRESHOLD=512 @@ -731,6 +739,54 @@ if [ "${RUN_MOE}" -eq 1 ] && [ "${IS_CDNA}" = "true" ]; then fi fi done + + # MoE W4A16 groupwise (int4_bf16, group_size=32) + for shape in $MOE_W4A16_SHAPES; do + [ -z "$shape" ] && continue + oldIFS=$IFS + IFS=, + set -- $shape + IFS=$oldIFS + tokens=$1; model_dim=$2; inter_dim=$3; experts=$4; topk=$5; tile_m=$6; tile_n=$7; tile_k=$8; tile_n2=$9; tile_k2=${10} + log="${BENCH_LOG_DIR}/moe_w4a16_t${tokens}_md${model_dim}_id${inter_dim}_e${experts}_k${topk}.log" + if python3 tests/kernels/test_moe_gemm.py \ + --in_dtype int4_bf16 \ + --group_size 32 \ + -dim "$model_dim,$inter_dim" \ + -t "$tokens" \ + -e "$experts" \ + -k "$topk" \ + --num_warmup 10 \ + --num_iters 100 \ + --tile_m "$tile_m" \ + --tile_n "$tile_n" \ + --tile_k "$tile_k" \ + --tile_n2 "$tile_n2" \ + --tile_k2 "$tile_k2" \ + --skip_ref false \ + --compare_aiter_ck false >"${log}" 2>&1; then + SUCCESS_COUNT=$((SUCCESS_COUNT + 1)) + else + FAIL_COUNT=$((FAIL_COUNT + 1)) + echo "moe w4a16 failed. Log: ${log}" >&2 + _show_fail_log "${log}" "moe_w4a16" + fi + shape_moe="t${tokens}-d${model_dim}x${inter_dim}-e${experts}k${topk}" + + dt_s1="$(grep -Eo 'FlyDSL MoE stage1\[[^]]+\]:' "${log}" | tail -1 | cut -d'[' -f2 | cut -d']' -f1 || true)" + tf_s1="$(grep -Eo 'FlyDSL MoE stage1\[[^]]+\]:.* ([0-9.]+) TFLOPS' "${log}" | tail -1 | awk '{print $(NF-1)}' || true)" + tb_s1="$(grep -Eo 'FlyDSL MoE stage1\[[^]]+\]:.* ([0-9.]+) TB/s' "${log}" | tail -1 | awk '{print $(NF-1)}' || true)" + if [ -n "${dt_s1}" ] && [ -n "${tf_s1}" ] && [ -n "${tb_s1}" ]; then + _emit_row "moe_w4a16_s1" "${shape_moe}" "${dt_s1}" "${tb_s1}" "${tf_s1}" + fi + + dt_s2="$(grep -Eo 'FlyDSL MoE stage2 \[[^]]+\] [^ ]+' "${log}" | tail -1 | awk '{print $NF}' || true)" + tf_s2="$(grep -Eo 'FlyDSL MoE stage2 .* ([0-9.]+) TFLOPS' "${log}" | tail -1 | awk '{print $(NF-1)}' || true)" + tb_s2="$(grep -Eo 'FlyDSL MoE stage2 .* ([0-9.]+) TB/s' "${log}" | tail -1 | awk '{print $(NF-1)}' || true)" + if [ -n "${dt_s2}" ] && [ -n "${tf_s2}" ] && [ -n "${tb_s2}" ]; then + _emit_row "moe_w4a16_s2" "${shape_moe}" "${dt_s2}" "${tb_s2}" "${tf_s2}" + fi + done fi # RDNA4 WMMA GEMM benchmarks (via benchmark_common.py) diff --git a/tests/kernels/test_moe_gemm.py b/tests/kernels/test_moe_gemm.py index c802092a..ca819bba 100644 --- a/tests/kernels/test_moe_gemm.py +++ b/tests/kernels/test_moe_gemm.py @@ -333,6 +333,9 @@ def run_moe_stage1( test_graph: bool = False, # Optional override for pre-built groupwise scale tensor [E, K//group_size, 2*inter_dim] (Opt 0 layout). 
scale_w1_groups_in: Optional[torch.Tensor] = None, + scale_dtype: str = "f32", + even_dispatch: bool = False, + out_dtype: str = "f16", ): assert model_dim % 64 == 0 assert model_dim % tile_k == 0 @@ -362,9 +365,17 @@ def run_moe_stage1( # Routing: aiter uses fused_topk; we use torch topk+softmax for portability/determinism. if topk_ids_in is None or topk_weights_in is None: - score = torch.randn((tokens, experts), device=device, dtype=torch.float32) - topk_vals, topk_ids = torch.topk(score, k=topk, dim=1) - topk_weights = torch.softmax(topk_vals, dim=1).to(torch.float32) + if even_dispatch: + # Evenly distribute tokens across experts: each token picks topk consecutive experts + # cycling through all E experts. Guarantees uniform load. + topk_ids = torch.stack( + [torch.arange(topk, device=device, dtype=torch.int32) + ((t * topk) % experts) for t in range(tokens)] + ) % experts + topk_weights = torch.full((tokens, topk), 1.0 / topk, device=device, dtype=torch.float32) + else: + score = torch.randn((tokens, experts), device=device, dtype=torch.float32) + topk_vals, topk_ids = torch.topk(score, k=topk, dim=1) + topk_weights = torch.softmax(topk_vals, dim=1).to(torch.float32) else: topk_ids = topk_ids_in topk_weights = topk_weights_in @@ -390,12 +401,23 @@ def run_moe_stage1( blocks, ) = routing + # # Print token distribution across experts. + # tokens_per_expert = torch.bincount(topk_ids.reshape(-1).to(torch.int64), minlength=experts) + # active = int((tokens_per_expert > 0).sum()) + # print( + # f" Token dispatch: {tokens} tokens x topk={topk} -> {active}/{experts} active experts, " + # f"tokens/expert min={int(tokens_per_expert.min())} max={int(tokens_per_expert.max())} " + # f"mean={tokens_per_expert.float().mean():.1f}" + # f"{' (even)' if even_dispatch else ' (random)'}" + # ) + if in_dtype not in ("fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16", "fp4"): raise ValueError( f"in_dtype must be one of ('fp8','fp16','bf16','int8','int8smooth','int4','int4_bf16','fp4'), got {in_dtype!r}" ) is_int4 = in_dtype == "int4" is_int4_bf16 = in_dtype == "int4_bf16" # W4A16: bf16 activations, packed int4 weights + w_is_int4 = is_int4 or is_int4_bf16 is_int8 = in_dtype in ("int8", "int8smooth", "int4") is_int8smooth = in_dtype == "int8smooth" is_fp4 = in_dtype == "fp4" @@ -467,7 +489,7 @@ def run_moe_stage1( w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=torch.int8, dtypeMax=7) w2_q, _scale_w2_unused = pertoken_quant(w2_fp32, quant_dtype=torch.int8, dtypeMax=7) - # --- Groupwise scale for W4A16 --- + # --- Groupwise scale for int4 variants --- use_groupwise_scale = is_int4_bf16 and group_size > 0 scale_w1_groups = None # [E, K//group_size, 2*inter_dim] for kernel (Opt 0 layout) if use_groupwise_scale: @@ -477,11 +499,12 @@ def run_moe_stage1( scale_w1_groups = scale_w1_groups_in else: # Generate random groupwise scale [E, num_groups, N] (Opt 0: cache-friendly). + _scale_torch_dtype = torch.bfloat16 if scale_dtype == "bf16" else torch.float32 scale_w1_groups = ( torch.rand(experts, num_groups_s1, N_total, device=device, dtype=torch.float32) * 0.05 + 0.005 - ) - # Prepare scale for kernel (no-op shuffle, returns contiguous). + ).to(_scale_torch_dtype) + # Prepare scale for kernel (handles both f32 and bf16 layouts). scale_w1_prepared = shuffle_scale_for_int4(scale_w1_groups, group_size=group_size) # Override per-row scale (kernel uses groupwise scale instead). 
scale_w1 = None @@ -539,8 +562,9 @@ def run_moe_stage1( scale_w1_1d = scale_w1_flat.view(-1).contiguous() sorted_weights_1d = sorted_weights.contiguous().view(-1) # [sorted_size] - # Output: [tokens, topk, inter_dim] fp16 - out = torch.empty((tokens, topk, inter_dim), device=device, dtype=torch.float16) + # Output: [tokens, topk, inter_dim] + _out_torch_dtype = torch.bfloat16 if out_dtype == "bf16" else torch.float16 + out = torch.empty((tokens, topk, inter_dim), device=device, dtype=_out_torch_dtype) if is_fp4: exe = compile_mixed_moe_gemm1( @@ -582,6 +606,8 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): tile_k=tile_k, doweight_stage1=bool(doweight_stage1), use_cshuffle_epilog=False, + scale_is_bf16=(scale_dtype == "bf16"), + out_dtype=out_dtype, ) def _s1_args(o, x, w, sx, sw, st, eids, sw_sorted): @@ -648,13 +674,23 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): tflops = flops / (us / 1e6) / 1e12 # Rough bytes-moved accounting (same spirit as GEMM tests: count each tensor once). - aE = min(tokens * topk, experts) - x_elem_bytes = 2 if (is_int4_bf16 or in_dtype in ("bf16", "fp16")) else (0.5 if is_fp4 else 1) # bf16/fp16 activations + # Only activated experts load weights/scales: E_active = min(E, tokens * topk). + active_experts = min(experts, tokens * topk) bytes_moved = 0 - bytes_moved += tokens * model_dim * x_elem_bytes - bytes_moved += aE * (inter_dim * 2) * model_dim * 0.5 if (use_packed_int4 or is_fp4) else 1 - bytes_moved += aE * (inter_dim * 2) * math.ceil(model_dim / group_size if group_size > 0 else (32 if is_fp4 else model_dim)) - bytes_moved += tokens * topk * inter_dim * 2 + is_f16_or_bf16_s1 = is_int4_bf16 or in_dtype in ("bf16", "fp16") + x_elem_bytes = 2 if is_f16_or_bf16_s1 else 1 + bytes_moved += (tokens * topk if is_int8smooth else tokens) * model_dim * x_elem_bytes # x (bf16 for W4A16, else fp8/int8) + bytes_moved += (active_experts * (2 * inter_dim) * model_dim) // (2 if use_packed_int4 else 1) # w (packed for int4) + bytes_moved += tokens * topk * inter_dim * 2 # out fp16 (logical) + bytes_moved += ((tokens * topk if is_int8smooth else tokens) * 4) if not is_f16_or_bf16_s1 else 0 # scale_x f32 + if use_groupwise_scale: + num_groups_s1 = model_dim // group_size + _scale_bytes = 2 if scale_dtype == "bf16" else 4 + bytes_moved += active_experts * num_groups_s1 * (2 * inter_dim) * _scale_bytes # groupwise scale + elif not is_f16_or_bf16_s1: + bytes_moved += active_experts * (2 * inter_dim) * 4 # per-row scale_w f32 + # Note: routing metadata (sorted_weights, sorted_token_ids, sorted_expert_ids) excluded + # from bytes_moved — they are negligible vs weight/activation/scale tensors. tbps = bytes_moved / 1e12 / (us / 1e6) print( @@ -784,6 +820,8 @@ def run_moe_stage2( test_graph: bool = False, # Optional override for pre-built groupwise scale tensor [E, inter_dim//group_size, model_dim] (Opt 0 layout). scale_w2_groups_in: Optional[torch.Tensor] = None, + scale_dtype: str = "f32", + even_dispatch: bool = False, ): """MoE stage2 (gemm2): out2[t] = sum_{slot} ( out1[t,slot] @ W2[expert]^T ) with optional routed weight.""" @@ -818,7 +856,7 @@ def run_moe_stage2( # Default compile function. 
if compile_fn is None: if use_reduce: - compile_fn = _make_reduce_mode_compile_fn(use_flydsl_reduce=True, use_valid_mask=bool(use_valid_mask)) + compile_fn = _make_reduce_mode_compile_fn(use_flydsl_reduce=True, use_valid_mask=bool(use_valid_mask), scale_dtype=scale_dtype) else: compile_fn = compile_moe_gemm2 @@ -846,9 +884,15 @@ def run_moe_stage2( # Routing: deterministic torch topk + softmax. if topk_ids_in is None or topk_weights_in is None: - score = torch.rand((tokens, experts), device=device, dtype=torch.float32) - topk_vals, topk_ids = torch.topk(score, k=topk, dim=1) - topk_weights = torch.softmax(topk_vals, dim=1).to(torch.float32) + if even_dispatch: + topk_ids = torch.stack( + [torch.arange(topk, device=device, dtype=torch.int32) + ((t * topk) % experts) for t in range(tokens)] + ) % experts + topk_weights = torch.full((tokens, topk), 1.0 / topk, device=device, dtype=torch.float32) + else: + score = torch.rand((tokens, experts), device=device, dtype=torch.float32) + topk_vals, topk_ids = torch.topk(score, k=topk, dim=1) + topk_weights = torch.softmax(topk_vals, dim=1).to(torch.float32) else: topk_ids = topk_ids_in topk_weights = topk_weights_in @@ -873,8 +917,16 @@ def run_moe_stage2( sorted_size, blocks, ) = routing - # NOTE: routing uses `moe_sorting` output directly (no host trim/pad). Extra launched blocks - # are gated by `num_valid_ids` inside the kernels. + + # # Print token distribution across experts. + # tokens_per_expert = torch.bincount(topk_ids.reshape(-1).to(torch.int64), minlength=experts) + # active = int((tokens_per_expert > 0).sum()) + # print( + # f" Token dispatch: {tokens} tokens x topk={topk} -> {active}/{experts} active experts, " + # f"tokens/expert min={int(tokens_per_expert.min())} max={int(tokens_per_expert.max())} " + # f"mean={tokens_per_expert.float().mean():.1f}" + # f"{' (even)' if even_dispatch else ' (random)'}" + # ) if in_dtype not in ("fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16", "fp4"): raise ValueError( @@ -882,6 +934,7 @@ def run_moe_stage2( ) is_int4 = in_dtype == "int4" is_int4_bf16 = in_dtype == "int4_bf16" # W4A16: bf16 activations, packed int4 weights + w_is_int4 = is_int4 or is_int4_bf16 is_int8 = in_dtype in ("int8", "int8smooth", "int4") is_int8smooth = in_dtype == "int8smooth" is_fp4 = in_dtype == "fp4" @@ -944,7 +997,7 @@ def run_moe_stage2( w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=torch.int8, dtypeMax=7) w2_q, scale_w2 = pertoken_quant(w2_fp32, quant_dtype=torch.int8, dtypeMax=7) - # --- Groupwise scale for W4A16 (stage 2) --- + # --- Groupwise scale for int4 variants (stage 2) --- use_groupwise_scale = is_int4_bf16 and group_size > 0 scale_w2_groups = None # [E, inter_dim//group_size, model_dim] Opt 0 layout if use_groupwise_scale: @@ -953,10 +1006,12 @@ def run_moe_stage2( scale_w2_groups = scale_w2_groups_in else: # Generate random groupwise scale [E, num_groups, N] (Opt 0: cache-friendly). + _scale_torch_dtype = torch.bfloat16 if scale_dtype == "bf16" else torch.float32 scale_w2_groups = ( torch.rand(experts, num_groups_s2, model_dim, device=device, dtype=torch.float32) * 0.05 + 0.005 - ) + ).to(_scale_torch_dtype) + # Prepare scale for kernel (handles both f32 and bf16 layouts). scale_w2_prepared = shuffle_scale_for_int4(scale_w2_groups, group_size=group_size) # Override per-row scale (kernel uses groupwise scale instead). 
scale_w2 = None @@ -1144,6 +1199,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): tile_n=tile_n, tile_k=tile_k, doweight_stage2=bool(doweight_stage2), + scale_is_bf16=(scale_dtype == "bf16"), ) is_reduce_exe = (getattr(exe, "mode", None) == MoeGemm2Mode.REDUCE) or bool(use_reduce) @@ -1233,13 +1289,23 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): flops = 2 * tokens * topk * model_dim * inter_dim tflops = flops / (us / 1e6) / 1e12 - aE = min(tokens * topk, experts) - a2_elem_bytes = 2 if in_dtype in ("int4_bf16", "bf16", "fp16") else (0.5 if is_fp4 else 1) # bf16/fp16 activations + # Only activated experts load weights/scales: E_active = min(E, tokens * topk). + active_experts = min(experts, tokens * topk) bytes_moved = 0 - bytes_moved += tokens * topk * inter_dim * a2_elem_bytes - bytes_moved += aE * model_dim * inter_dim * 0.5 if (use_packed_int4 or is_fp4) else 1 - bytes_moved += aE * model_dim * math.ceil(inter_dim / group_size if group_size > 0 else (32 if is_fp4 else inter_dim)) - bytes_moved += tokens * topk * model_dim * 2 + a2_elem_bytes = 2 if in_dtype in ("int4_bf16", "bf16", "fp16") else 1 # bf16/fp16 activations + bytes_moved += tokens * topk * inter_dim * a2_elem_bytes # a2 (logical) + bytes_moved += (active_experts * model_dim * inter_dim) // (2 if w_is_int4 else 1) # w2 (packed for int4) + bytes_moved += tokens * model_dim * (2 if out_torch_dtype in (torch.float16, torch.bfloat16) else 4) # out + is_f16_or_bf16_s2 = is_int4_bf16 or in_dtype in ("bf16", "fp16") + bytes_moved += (tokens * topk * 4) if not is_f16_or_bf16_s2 else 0 # a2_scale f32 (None for bf16) + if use_groupwise_scale: + num_groups_s2 = inter_dim // group_size + _scale_bytes = 2 if scale_dtype == "bf16" else 4 + bytes_moved += active_experts * num_groups_s2 * model_dim * _scale_bytes # groupwise scale + elif not is_f16_or_bf16_s2: + bytes_moved += active_experts * model_dim * 4 # per-row scale_w f32 + # Note: routing metadata (sorted_weights, sorted_token_ids, sorted_expert_ids) excluded + # from bytes_moved — they are negligible vs weight/activation/scale tensors. 
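+ # Illustrative tally (W4A16 with bf16 groupwise scales, hypothetical Kimi-2.5-TP8 shape:
+ # tokens=128, topk=8, model_dim=7168, inter_dim=256, E=384; active_experts=384):
+ #   a2     128*8*256*2          ~0.5 MB
+ #   w2     384*7168*256/2       ~352.3 MB
+ #   out    128*7168*2           ~1.8 MB
+ #   scale  384*(256/32)*7168*2  ~44.0 MB   => ~398.7 MB total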
tbps = bytes_moved / 1e12 / (us / 1e6) print( f"FlyDSL MoE stage2 [{kernel_name}] {in_dtype} {'reduce' if use_reduce else 'atomic'} | " @@ -1358,7 +1424,7 @@ def launch_ck(o, a2_, w1_, w2_, sorted_ids_, sorted_eids_, num_valid_, w2_scale_ "fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16", pytest.param("fp4", marks=pytest.mark.skipif("gfx95" not in ARCH, reason="FP4 requires gfx950+")), ]) -@pytest.mark.parametrize("out_dtype", ["f16", "f32"], ids=["out_f16", "out_f32"]) +@pytest.mark.parametrize("out_dtype", ["f16", "bf16", "f32"], ids=["out_f16", "out_bf16", "out_f32"]) @pytest.mark.parametrize("use_reduce", [False, True], ids=["atomic", "reduce"]) @pytest.mark.parametrize("use_valid_mask", [False, True], ids=["nomask", "mask"]) @pytest.mark.parametrize("test_graph", [ @@ -1404,7 +1470,7 @@ def test_moe_gemm_2stage( if bool(use_reduce) and out_s in ("f32", "fp32", "float"): pytest.skip("reduce mode does not support out_dtype='f32' (compile_moe_gemm2(accumulate=False) forbids it).") if group_size > 0 and in_dtype != "int4_bf16": - pytest.skip("groupwise scale only applies to int4_bf16") + pytest.skip("groupwise scale only applies to int4_bf16 (W4A16)") if in_dtype == "fp4": if bool(use_valid_mask): pytest.skip("FP4 does not support valid_mask") @@ -1584,7 +1650,7 @@ def _per_1x32_fp4_quant(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # Test Helpers for MoE GEMM2 Mode Comparison -def _make_reduce_mode_compile_fn(use_flydsl_reduce: bool = True, use_valid_mask: bool = False): +def _make_reduce_mode_compile_fn(use_flydsl_reduce: bool = True, use_valid_mask: bool = False, scale_dtype: str = "f32"): """Create a compile function that forces reduce mode. Args: @@ -1604,6 +1670,7 @@ def _compile( in_dtype: str = "fp8", group_size: int = -1, out_dtype: str = "f16", + scale_is_bf16: bool = False, ): if use_flydsl_reduce: return compile_moe_gemm2_ex( @@ -1621,6 +1688,7 @@ def _compile( valid_mask=(True if bool(use_valid_mask) else None), mode=MoeGemm2Mode.REDUCE, zero_intermediate=False, # test non-zeroed performance + scale_is_bf16=(scale_dtype == "bf16"), ) else: gemm2_exe = compile_moe_gemm2( @@ -1636,6 +1704,7 @@ def _compile( group_size=group_size, out_dtype=out_dtype, accumulate=False, + scale_is_bf16=(scale_dtype == "bf16"), ) return _TorchReduceWrapper(gemm2_exe, topk, model_dim) return _compile @@ -1710,6 +1779,45 @@ def test_moe_gemm_2stage_bf16_out(use_reduce): ) +@pytest.mark.parametrize("scale_dtype", ["f32", "bf16"], ids=["scale_f32", "scale_bf16"]) +def test_moe_gemm_w4a16_groupwise_scale(scale_dtype): + """Test W4A16 groupwise scale with f32 and bf16 (packed) scale dtypes.""" + tokens, model_dim, inter_dim, experts, topk = 64, 256, 128, 4, 2 + tile_m, tile_n, tile_k = 16, 64, 128 + device = torch.device("cuda") + s = 0.2 + x_fp32 = torch.randn((tokens, model_dim), device=device, dtype=torch.float32) * s + w1_fp32 = torch.randn((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) * s + w2_fp32 = torch.randn((experts, model_dim, inter_dim), device=device, dtype=torch.float32) * s + score = torch.rand((tokens, experts), device=device, dtype=torch.float32) + topk_vals, topk_ids = torch.topk(score, k=topk, dim=1) + topk_weights = torch.softmax(topk_vals, dim=1).to(torch.float32) + routing = build_routing_buffers( + topk_ids=topk_ids, topk_weights=topk_weights, + experts=experts, model_dim=model_dim, tile_m=tile_m, + ) + out1, _ = run_moe_stage1( + tokens=tokens, model_dim=model_dim, inter_dim=inter_dim, + experts=experts, topk=topk, 
in_dtype="int4_bf16", group_size=32, + tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, + doweight_stage1=False, num_iters=2, num_warmup=1, + x_fp32_in=x_fp32, w1_fp32_in=w1_fp32, w2_fp32_in=w2_fp32, + topk_ids_in=topk_ids, topk_weights_in=topk_weights, routing_in=routing, + return_outputs=True, skip_ref=False, scale_dtype=scale_dtype, + ) + a2 = out1.to(torch.bfloat16) + run_moe_stage2( + tokens=tokens, model_dim=model_dim, inter_dim=inter_dim, + experts=experts, topk=topk, in_dtype="int4_bf16", group_size=32, + tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, + doweight_stage1=False, num_iters=2, num_warmup=1, + x_fp32_in=x_fp32, w1_fp32_in=w1_fp32, w2_fp32_in=w2_fp32, + topk_ids_in=topk_ids, topk_weights_in=topk_weights, routing_in=routing, + a2_fp8_in=a2, a2_scale_in=None, + return_outputs=True, skip_ref=False, scale_dtype=scale_dtype, + ) + + @pytest.mark.parametrize( "tokens, model_dim, inter_dim, experts, topk, tile_m, tile_n, tile_k", [ diff --git a/tests/kernels/test_ref.py b/tests/kernels/test_ref.py index b0288b67..36dbe4c3 100644 --- a/tests/kernels/test_ref.py +++ b/tests/kernels/test_ref.py @@ -100,11 +100,11 @@ def torch_moe_gemm1( s_idx = idx[:, 1] x_in = x[t_idx, :] if x.dim() == 2 else x[t_idx, s_idx, :] y2 = F.linear(x_in, w1[e, :, :]) # [num, 2*inter_dim] - if doweight_stage1: - y2 = y2 * topk_weights[t_idx, s_idx].unsqueeze(-1) gate = y2[:, :inter_dim] up = y2[:, inter_dim:] y = F.silu(gate) * up + if doweight_stage1: + y = y * topk_weights[t_idx, s_idx].unsqueeze(-1) out[t_idx, s_idx, :] = y return out diff --git a/tests/utils.py b/tests/utils.py index 6b714460..e0fbd498 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -99,19 +99,21 @@ def shuffle_weight(x: torch.Tensor, layout=(16, 16), use_int4=False) -> torch.Te def shuffle_scale_for_int4(scale: torch.Tensor, group_size: int = 32, layout=(16, 16)) -> torch.Tensor: """Prepare scale tensor for W4A16 groupwise scale kernel. - NOTE: Despite the name, this function does NOT shuffle the scale tensor. - The kernel uses the [E, num_groups, N] layout (Opt 0: cache-friendly) where - adjacent threads read adjacent N elements (stride-1 access). + Input: scale tensor of shape ``[E, num_groups, N]``. - Scale indexing uses: scale_idx = expert_offset*(G-1) + n_global + group_idx*N_pe + For **f32** scales the kernel uses ``(E, G, N)`` layout directly, so this + is a contiguous no-op. + + For **bf16** scales the kernel uses ``(E, G//2, N, 2)`` layout — two + adjacent groups for the same N position are packed into one dword. Args: - scale: Scale tensor of shape [E, num_groups, N] where num_groups = K_dim // group_size - group_size: Group size for quantization (must be 32 for FlyDSL) - layout: Tile layout (unused, kept for API compatibility) + scale: Scale tensor of shape [E, num_groups, N]. + group_size: Group size for quantization (must be 32 for FlyDSL). + layout: Tile layout (unused, kept for API compatibility). Returns: - Scale tensor in [E, num_groups, N] layout, ready for kernel consumption. + Scale tensor ready for kernel consumption. """ if group_size != 32: raise ValueError( @@ -119,4 +121,9 @@ def shuffle_scale_for_int4(scale: torch.Tensor, group_size: int = 32, layout=(16 f"This is due to int4 preshuffle layout constraints." ) + if scale.dtype == torch.bfloat16: + # (E, G, N) bf16 → (E, G//2, N, 2) bf16 packed layout. 
+ E, G, N = scale.shape + return scale.view(E, G // 2, 2, N).permute(0, 1, 3, 2).contiguous() + return scale.contiguous() \ No newline at end of file From c5d83dfd89e940405e6bfbfe0d15cf15b954076b Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Mon, 13 Apr 2026 20:18:48 +0800 Subject: [PATCH 06/29] [FEAT] Add get_leaves op and support convert dyn_tuple to py_tuple (#391) * Add get_leaves op and support convert dyn_tuple to py_tuple * fix comments --- examples/02-tiledCopy.py | 6 +- include/flydsl/Dialect/Fly/IR/FlyOps.td | 5 +- lib/Dialect/Fly/IR/FlyOps.cpp | 15 +++- lib/Dialect/Fly/Transforms/LayoutLowering.cpp | 49 ++++++++++++- python/flydsl/expr/derived.py | 11 ++- python/flydsl/expr/primitive.py | 9 ++- python/flydsl/expr/typing.py | 71 ++++++++++--------- tests/mlir/Transforms/layout_lowering.mlir | 32 +++++++++ 8 files changed, 149 insertions(+), 49 deletions(-) diff --git a/examples/02-tiledCopy.py b/examples/02-tiledCopy.py index 0e781f44..d5a3e183 100644 --- a/examples/02-tiledCopy.py +++ b/examples/02-tiledCopy.py @@ -30,11 +30,9 @@ def copy_kernel( thr_layout = fx.make_layout((4, 1), (1, 1)) val_layout = fx.make_layout((1, 8), (1, 1)) copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), fx.Float32) - layout_thr_val = fx.raked_product(thr_layout, val_layout) + tile_mn, tv_layout = fx.make_layout_tv(thr_layout, val_layout) - tile_mn = fx.make_tile(4, 8) - - tiled_copy = fx.make_tiled_copy(copy_atom, layout_thr_val, tile_mn) + tiled_copy = fx.make_tiled_copy(copy_atom, tv_layout, tile_mn) thr_copy = tiled_copy.get_slice(tid) partition_src = thr_copy.partition_S(bA) diff --git a/include/flydsl/Dialect/Fly/IR/FlyOps.td b/include/flydsl/Dialect/Fly/IR/FlyOps.td index 2665df92..8c4d5f4e 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyOps.td +++ b/include/flydsl/Dialect/Fly/IR/FlyOps.td @@ -111,10 +111,9 @@ def Fly_GetScalarOp : Fly_Op<"get_scalar", [Pure, DeclareOpInterfaceMethods]> { - let arguments = (ins Fly_IntTuple:$input); - let results = (outs Fly_IntTuple:$result); + let arguments = (ins Fly_IntTuple:$input, DefaultValuedAttr:$dynamicOnly); + let results = (outs Variadic>:$results); } def Fly_GetShapeOp : Fly_Op<"get_shape", [Pure, DeclareOpInterfaceMethods]> { diff --git a/lib/Dialect/Fly/IR/FlyOps.cpp b/lib/Dialect/Fly/IR/FlyOps.cpp index 7f1d78e2..e289e869 100644 --- a/lib/Dialect/Fly/IR/FlyOps.cpp +++ b/lib/Dialect/Fly/IR/FlyOps.cpp @@ -300,8 +300,19 @@ FLY_INFER_RETURN_TYPES(GetLeavesOp) { if (!intTupleType) return emitOptionalError(location, "GetLeavesOp: expected IntTupleType, got ", operands[0].getType()); - // transform_leaf on intTupleAttr, - return failure(); + bool dynamicOnly = false; + if (properties) + dynamicOnly = properties.as()->dynamicOnly.getValue(); + IntTupleBuilder builder(context); + SmallVector flatLeaves; + intTupleFlattenToVector(builder, intTupleType.getAttr(), flatLeaves); + for (auto leaf : flatLeaves) { + auto intAttr = leaf.extractIntFromLeaf(); + if (dynamicOnly && intAttr.isStatic()) + continue; + inferredReturnTypes.push_back(IntegerType::get(context, std::max(32, intAttr.getWidth()))); + } + return success(); } FLY_INFER_RETURN_TYPES(GetShapeOp) { diff --git a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp index e1a717ab..e456beb4 100644 --- a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp +++ b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp @@ -417,6 +417,49 @@ class GetScalarLowering : public OpRewritePattern { } }; +class GetLeavesLowering : public OpRewritePattern { +public: + using 
OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GetLeavesOp op, PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value intTuple = op.getInput(); + + auto intTupleTy = dyn_cast(intTuple.getType()); + if (!intTupleTy) + return failure(); + if (!isNormalForm(cast>(intTuple))) + return failure(); + + auto defOp = intTuple.getDefiningOp(); + bool dynamicOnly = op.getDynamicOnly(); + + if (dynamicOnly) { + rewriter.replaceOp(op, defOp.getDyncElems()); + return success(); + } + + IntTupleBuilder builder(rewriter.getContext()); + SmallVector flatLeaves; + intTupleFlattenToVector(builder, intTupleTy.getAttr(), flatLeaves); + + SmallVector results; + auto dyncIter = defOp.getDyncElems().begin(); + for (auto leaf : flatLeaves) { + auto intAttr = leaf.extractIntFromLeaf(); + if (intAttr.isStatic()) { + results.push_back(arith::ConstantIntOp::create(rewriter, loc, intAttr.getValue(), + std::max(32, intAttr.getWidth()))); + } else { + results.push_back(*dyncIter++); + } + } + + rewriter.replaceOp(op, results); + return success(); + } +}; + class GetShapeLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -2633,9 +2676,9 @@ class FlyLayoutLoweringPass MakeLayoutLikeOpLowering, MakeFragmentLikeOpLowering>(context); // Extractors - patterns.add(context); + patterns.add(context); // IntTuple operations patterns.add ir.Type: @@ -330,12 +330,25 @@ def get_static_leaf_int(self) -> int: raise ValueError("IntTuple is not a static leaf") return self.type.get_static_leaf_int - def to_py_value(self): - if not self.is_static: - raise ValueError("IntTuple is not static") + @traced_op + def to_py_value(self, loc=None, ip=None): + if self.is_static: + if self.is_leaf: + return self.get_static_leaf_int + return tuple(get_(self, i).to_py_value() for i in range(self.rank)) + leaves = get_leaves(self, dynamic_only=True, loc=loc, ip=ip) + leaf_iter = iter(leaves) + return self._rebuild_py_value(leaf_iter) + + def _rebuild_py_value(self, leaf_iter): if self.is_leaf: - return self.get_static_leaf_int - return tuple(get_(self, i).to_py_value() for i in range(self.rank)) + if self.is_static: + return self.get_static_leaf_int + val = next(leaf_iter) + width = ir.IntegerType(val.type).width + wrapper = Int64 if width == 64 else Int32 + return wrapper(val) + return tuple(IntTuple(get_(self, i))._rebuild_py_value(leaf_iter) for i in range(self.rank)) @traced_op def __getitem__(self, mode, loc=None, ip=None): @@ -821,6 +834,7 @@ def __iter__(self): # Vector — register vector with value semantics # ═══════════════════════════════════════════════════════════════════════ + class ReductionOp(enum.Enum): ADD = "add" MUL = "mul" @@ -829,8 +843,8 @@ class ReductionOp(enum.Enum): _REDUCE_KINDS = { - "add": (_vector.CombiningKind.ADD, _vector.CombiningKind.ADD, _vector.CombiningKind.ADD), - "mul": (_vector.CombiningKind.MUL, _vector.CombiningKind.MUL, _vector.CombiningKind.MUL), + "add": (_vector.CombiningKind.ADD, _vector.CombiningKind.ADD, _vector.CombiningKind.ADD), + "mul": (_vector.CombiningKind.MUL, _vector.CombiningKind.MUL, _vector.CombiningKind.MUL), "max": (_vector.CombiningKind.MAXNUMF, _vector.CombiningKind.MAXSI, _vector.CombiningKind.MAXUI), "min": (_vector.CombiningKind.MINIMUMF, _vector.CombiningKind.MINSI, _vector.CombiningKind.MINUI), } @@ -912,8 +926,8 @@ def to(self, dtype: Type[Numeric], *, loc=None, ip=None) -> "Vector": src_dtype = self._dtype if src_dtype is dtype: return self - src_float = getattr(src_dtype, 'is_float', False) - 
dst_float = getattr(dtype, 'is_float', False) + src_float = getattr(src_dtype, "is_float", False) + dst_float = getattr(dtype, "is_float", False) if src_float and dst_float: res = fp_to_fp(self, dtype.ir_type, loc=loc, ip=ip) elif src_float: @@ -927,10 +941,9 @@ def to(self, dtype: Type[Numeric], *, loc=None, ip=None) -> "Vector": def ir_value(self, *, loc=None, ip=None): return self - def reduce(self, op, init_val=None, reduction_profile=None, - *, fastmath=None, loc=None, ip=None): + def reduce(self, op, init_val=None, reduction_profile=None, *, fastmath=None, loc=None, ip=None): is_fp = self._dtype.is_float - signed = getattr(self._dtype, 'signed', True) + signed = getattr(self._dtype, "signed", True) kind = _resolve_combining_kind(op, is_fp, signed) et = element_type(self.type) kwargs = {} @@ -945,9 +958,7 @@ def reduce(self, op, init_val=None, reduction_profile=None, def __getitem__(self, idx): if isinstance(idx, int): - res = _vector.ExtractOp( - self, static_position=[idx], dynamic_position=[] - ).result + res = _vector.ExtractOp(self, static_position=[idx], dynamic_position=[]).result return self._dtype(res) raise TypeError(f"unsupported index type: {type(idx)}") @@ -964,8 +975,7 @@ def shuffle(self, other, mask, *, loc=None, ip=None) -> "Vector": return Vector(res, (len(mask),), self._dtype) @classmethod - def filled(cls, shape, fill_value, dtype: Type[Numeric], - *, loc=None, ip=None) -> "Vector": + def filled(cls, shape, fill_value, dtype: Type[Numeric], *, loc=None, ip=None) -> "Vector": shape = (shape,) if isinstance(shape, int) else tuple(shape) n = 1 for s in shape: @@ -977,20 +987,17 @@ def filled(cls, shape, fill_value, dtype: Type[Numeric], else: raise ValueError(f"expected numeric fill_value, got {type(fill_value)}") vec_ty = ir.VectorType.get([n], dtype.ir_type) - val = _vector.broadcast(vec_ty, fill_value.ir_value(loc=loc, ip=ip), - loc=loc, ip=ip) + val = _vector.broadcast(vec_ty, fill_value.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) return cls(val, shape, dtype) @classmethod - def filled_like(cls, template: "Vector", fill_value, dtype=None, - *, loc=None, ip=None) -> "Vector": + def filled_like(cls, template: "Vector", fill_value, dtype=None, *, loc=None, ip=None) -> "Vector": if dtype is None: dtype = template.dtype return cls.filled(template.shape, fill_value, dtype, loc=loc, ip=ip) @classmethod - def zeros_like(cls, template: "Vector", dtype=None, - *, loc=None, ip=None) -> "Vector": + def zeros_like(cls, template: "Vector", dtype=None, *, loc=None, ip=None) -> "Vector": if dtype is None: dtype = template.dtype return cls.filled(template.shape, 0.0 if dtype.is_float else 0, dtype, loc=loc, ip=ip) diff --git a/tests/mlir/Transforms/layout_lowering.mlir b/tests/mlir/Transforms/layout_lowering.mlir index cd39ebe4..35df9986 100644 --- a/tests/mlir/Transforms/layout_lowering.mlir +++ b/tests/mlir/Transforms/layout_lowering.mlir @@ -253,3 +253,35 @@ func.func @test_right_inverse() -> !fly.layout<(4,2):(2,1)> { %result = fly.right_inverse(%layout) : (!fly.layout<(2,4):(4,1)>) -> !fly.layout<(4,2):(2,1)> return %result : !fly.layout<(4,2):(2,1)> } + +// ----- + +// === GetLeavesOp Lowering === + +// dynamicOnly=false: all leaves returned, static as arith.constant, dynamic forwarded. +// Mixed i32/i64 dynamic leaves. 
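+// The !fly.int_tuple<(4, ?, ?{i64})> type carries one static leaf (4), one dynamic
+// i32 leaf (?) and one dynamic i64 leaf (?{i64}), matching the (i32, i64) operands below.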
+// CHECK-LABEL: @test_get_leaves_all +// CHECK-SAME: (%[[X:.*]]: i32, %[[Y:.*]]: i64) +func.func @test_get_leaves_all(%x: i32, %y: i64) -> (i32, i32, i64) { + %t = fly.make_int_tuple(%x, %y) : (i32, i64) -> !fly.int_tuple<(4, ?, ?{i64})> + // CHECK-NOT: fly.get_leaves + // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : i32 + // CHECK: return %[[C4]], %[[X]], %[[Y]] + %0:3 = fly.get_leaves(%t) : (!fly.int_tuple<(4, ?, ?{i64})>) -> (i32, i32, i64) + return %0#0, %0#1, %0#2 : i32, i32, i64 +} + +// ----- + +// dynamicOnly=true: only dynamic leaves returned, static skipped. +// Mixed i32/i64 dynamic leaves. +// CHECK-LABEL: @test_get_leaves_dynamic_only +// CHECK-SAME: (%[[X:.*]]: i32, %[[Y:.*]]: i64) +func.func @test_get_leaves_dynamic_only(%x: i32, %y: i64) -> (i32, i64) { + %t = fly.make_int_tuple(%x, %y) : (i32, i64) -> !fly.int_tuple<(4, ?, ?{i64})> + // CHECK-NOT: fly.get_leaves + // CHECK-NOT: arith.constant + // CHECK: return %[[X]], %[[Y]] + %0:2 = fly.get_leaves(%t) {dynamicOnly = true} : (!fly.int_tuple<(4, ?, ?{i64})>) -> (i32, i64) + return %0#0, %0#1 : i32, i64 +} From c5c4f6ffbfc7e28fa9919445dc72e33f17a607f0 Mon Sep 17 00:00:00 2001 From: Felix Li Date: Mon, 13 Apr 2026 21:00:57 +0800 Subject: [PATCH 07/29] remove ci massive llvm log (#394) --- tests/unit/test_compile_hints.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/unit/test_compile_hints.py b/tests/unit/test_compile_hints.py index bc11bb69..f303f76d 100644 --- a/tests/unit/test_compile_hints.py +++ b/tests/unit/test_compile_hints.py @@ -67,6 +67,7 @@ def test_bool_round_trip(self): restored = _fly.set_llvm_option_bool("enable-post-misched", True) assert restored is False + @pytest.mark.skip(reason="Temporarily disabled: opt-bisect-limit leaks LLVM BISECT logs across pytest.") def test_int_round_trip(self): _fly = self._get_fly() # Use a large limit so it doesn't affect compilation @@ -110,6 +111,7 @@ def test_bool_scoping(self): val = _fly.set_llvm_option_bool("enable-post-misched", baseline) assert val == baseline + @pytest.mark.skip(reason="Temporarily disabled: opt-bisect-limit leaks LLVM BISECT logs across pytest.") def test_mixed_types(self): from flydsl.compiler.llvm_options import llvm_options @@ -170,6 +172,8 @@ def test_fp_math_reaches_pipeline(self, monkeypatch): """Verify fast_fp_math/unsafe_fp_math appear in rocdl-attach-target.""" from flydsl.compiler.backends import rocm captured = {} + monkeypatch.setenv("FLYDSL_RUNTIME_ENABLE_CACHE", "0") + _reset_jit_caches(_noop_launch) orig = rocm.RocmBackend.pipeline_fragments @@ -187,6 +191,7 @@ def patched(self, *, compile_hints): def test_llvm_options_in_compile_hints(self): """Verify llvm_options key is accepted and doesn't crash.""" + _reset_jit_caches(_noop_launch) exe = flyc.compile[{ "llvm_options": {"enable-post-misched": False}, }](_noop_launch) From e4ab0a90e579d0a1e14b4fe00cf465afcb417e8d Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Mon, 13 Apr 2026 21:18:08 +0800 Subject: [PATCH 08/29] Change result type of elem_less/equal to i1 (#392) --- include/flydsl/Dialect/Fly/IR/FlyOps.td | 11 ++- .../flydsl/Dialect/Fly/Utils/IntTupleUtils.h | 4 + lib/Dialect/Fly/IR/FlyOps.cpp | 28 ------ lib/Dialect/Fly/Transforms/LayoutLowering.cpp | 86 ++++++++++++++++--- lib/Dialect/Fly/Utils/IntTupleUtils.cpp | 2 +- 5 files changed, 86 insertions(+), 45 deletions(-) diff --git a/include/flydsl/Dialect/Fly/IR/FlyOps.td b/include/flydsl/Dialect/Fly/IR/FlyOps.td index 8c4d5f4e..a64e0384 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyOps.td +++ 
b/include/flydsl/Dialect/Fly/IR/FlyOps.td @@ -180,8 +180,15 @@ def Fly_IntTupleProductLikeOp : Fly_IntTupleBinaryOp<"int_tuple_product_like">; def Fly_ShapeDivOp : Fly_IntTupleBinaryOp<"shape_div">; def Fly_CeilDivOp : Fly_IntTupleBinaryOp<"ceil_div">; -def Fly_ElemLessOp : Fly_IntTupleBinaryOp<"elem_less">; -def Fly_EqualOp : Fly_IntTupleBinaryOp<"equal">; + +def Fly_ElemLessOp : Fly_Op<"elem_less", [Pure]> { + let arguments = (ins Fly_IntTuple:$lhs, Fly_IntTuple:$rhs); + let results = (outs I1:$result); +} +def Fly_EqualOp : Fly_Op<"equal", [Pure]> { + let arguments = (ins Fly_IntTuple:$lhs, Fly_IntTuple:$rhs); + let results = (outs I1:$result); +} //===----------------------------------------------------------------------===// // IntTupleLike operations diff --git a/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h b/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h index 17fb7d6f..4681e2a7 100644 --- a/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h +++ b/include/flydsl/Dialect/Fly/Utils/IntTupleUtils.h @@ -68,6 +68,7 @@ class IntTupleValueAdaptor { static IntTupleValueAdaptor create(Builder &builder, Value value, IntTupleAttr attr) { // Rebuild the adaptor from a normal-form IntTuple value while re-establishing // the leaf/non-leaf invariant above. + assert(isa>(value) && "Value must be a IntTuple"); assert(isNormalForm(cast>(value)) && "Value must be in normal form"); if (attr.isLeaf()) { if (attr.isStatic()) { @@ -96,6 +97,9 @@ class IntTupleValueAdaptor { IntTupleAttr getAttr() const { return attr; } friend class IntTupleBuilder; + friend IntTupleValueAdaptor + intTupleBasis2Tuple(const IntTupleBuilder &builder, + IntTupleValueAdaptor basis); }; template diff --git a/lib/Dialect/Fly/IR/FlyOps.cpp b/lib/Dialect/Fly/IR/FlyOps.cpp index e289e869..b07d29eb 100644 --- a/lib/Dialect/Fly/IR/FlyOps.cpp +++ b/lib/Dialect/Fly/IR/FlyOps.cpp @@ -560,35 +560,7 @@ FLY_INFER_RETURN_TYPES(CeilDivOp) { return success(); } -FLY_INFER_RETURN_TYPES(ElemLessOp) { - auto lhsTy = dyn_cast(operands[0].getType()); - auto rhsTy = dyn_cast(operands[1].getType()); - if (!lhsTy) - return emitOptionalError(location, "ElemLessOp: expected IntTupleType for lhs, got ", - operands[0].getType()); - if (!rhsTy) - return emitOptionalError(location, "ElemLessOp: expected IntTupleType for rhs, got ", - operands[1].getType()); - IntTupleBuilder builder(context); - inferredReturnTypes.assign( - {IntTupleType::get(intTupleElemLess(builder, lhsTy.getAttr(), rhsTy.getAttr()))}); - return success(); -} -FLY_INFER_RETURN_TYPES(EqualOp) { - auto lhsTy = dyn_cast(operands[0].getType()); - auto rhsTy = dyn_cast(operands[1].getType()); - if (!lhsTy) - return emitOptionalError(location, "EqualOp: expected IntTupleType for lhs, got ", - operands[0].getType()); - if (!rhsTy) - return emitOptionalError(location, "EqualOp: expected IntTupleType for rhs, got ", - operands[1].getType()); - IntTupleBuilder builder(context); - inferredReturnTypes.assign( - {IntTupleType::get(intTupleEqual(builder, lhsTy.getAttr(), rhsTy.getAttr()))}); - return success(); -} //===----------------------------------------------------------------------===// // IntTupleLike operations diff --git a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp index e456beb4..f40d38f8 100644 --- a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp +++ b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp @@ -693,18 +693,6 @@ struct IntTupleCeilDivFn { return intTupleCeilDiv(builder, lhs, rhs); } }; -struct IntTupleElemLessFn { - 
IntTupleValueAdaptor operator()(IntTupleBuilder &builder, - IntTupleValueAdaptor lhs, IntTupleValueAdaptor rhs) const { - return intTupleElemLess(builder, lhs, rhs); - } -}; -struct IntTupleEqualFn { - IntTupleValueAdaptor operator()(IntTupleBuilder &builder, - IntTupleValueAdaptor lhs, IntTupleValueAdaptor rhs) const { - return intTupleEqual(builder, lhs, rhs); - } -}; using IntTupleAddOpLowering = IntTupleBinaryOpLowering; using IntTupleSubOpLowering = IntTupleBinaryOpLowering; @@ -720,8 +708,78 @@ using IntTupleProductLikeOpLowering = using ShapeDivOpLowering = IntTupleBinaryOpLowering; using CeilDivOpLowering = IntTupleBinaryOpLowering; -using ElemLessOpLowering = IntTupleBinaryOpLowering; -using EqualOpLowering = IntTupleBinaryOpLowering; + +class ElemLessOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ElemLessOp op, PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto lhs = op.getLhs(); + auto rhs = op.getRhs(); + + auto lhsTy = dyn_cast(lhs.getType()); + auto rhsTy = dyn_cast(rhs.getType()); + if (!lhsTy || !rhsTy) + return failure(); + + if (!isNormalForm(cast>(lhs)) || + !isNormalForm(cast>(rhs))) + return failure(); + + IntTupleBuilder builder(rewriter, loc); + IntTupleValueAdaptor lhsAdaptor = IntTupleValueAdaptor::create(builder, lhs, lhsTy.getAttr()); + IntTupleValueAdaptor rhsAdaptor = IntTupleValueAdaptor::create(builder, rhs, rhsTy.getAttr()); + + auto result = intTupleElemLess(builder, lhsAdaptor, rhsAdaptor); + auto i1Ty = rewriter.getI1Type(); + Value i1Val; + if (result.isStatic()) { + int32_t staticVal = result.getAttr().extractIntFromLeaf().getValue(); + i1Val = arith::ConstantIntOp::create(rewriter, loc, i1Ty, staticVal != 0).getResult(); + } else { + i1Val = arith::TruncIOp::create(rewriter, loc, i1Ty, result.getValue()).getResult(); + } + rewriter.replaceOp(op, i1Val); + return success(); + } +}; + +class EqualOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(EqualOp op, PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto lhs = op.getLhs(); + auto rhs = op.getRhs(); + + auto lhsTy = dyn_cast(lhs.getType()); + auto rhsTy = dyn_cast(rhs.getType()); + if (!lhsTy || !rhsTy) + return failure(); + + if (!isNormalForm(cast>(lhs)) || + !isNormalForm(cast>(rhs))) + return failure(); + + IntTupleBuilder builder(rewriter, loc); + IntTupleValueAdaptor lhsAdaptor = IntTupleValueAdaptor::create(builder, lhs, lhsTy.getAttr()); + IntTupleValueAdaptor rhsAdaptor = IntTupleValueAdaptor::create(builder, rhs, rhsTy.getAttr()); + + auto result = intTupleEqual(builder, lhsAdaptor, rhsAdaptor); + auto i1Ty = rewriter.getI1Type(); + Value i1Val; + if (result.isStatic()) { + int32_t staticVal = result.getAttr().extractIntFromLeaf().getValue(); + i1Val = arith::ConstantIntOp::create(rewriter, loc, i1Ty, staticVal != 0).getResult(); + } else { + i1Val = arith::TruncIOp::create(rewriter, loc, i1Ty, result.getValue()).getResult(); + } + rewriter.replaceOp(op, i1Val); + return success(); + } +}; //===----------------------------------------------------------------------===// // IntTupleLike operations diff --git a/lib/Dialect/Fly/Utils/IntTupleUtils.cpp b/lib/Dialect/Fly/Utils/IntTupleUtils.cpp index ab2ac88b..c8c93fa9 100644 --- a/lib/Dialect/Fly/Utils/IntTupleUtils.cpp +++ b/lib/Dialect/Fly/Utils/IntTupleUtils.cpp @@ -682,7 +682,7 @@ IntTupleValueAdaptor intTupleBasis2Tuple(const 
IntTupleBuilder Date: Tue, 14 Apr 2026 13:21:55 +0800 Subject: [PATCH 09/29] [MLIR] enhance get_scalar op, only requires dyn_leaf_cnt = 1 (#398) * enhance get_scalar op, only requires dyn_leaf_cnt = 1 * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- lib/Dialect/Fly/IR/FlyOps.cpp | 10 +++++--- lib/Dialect/Fly/Transforms/LayoutLowering.cpp | 11 +++++---- tests/mlir/Transforms/layout_lowering.mlir | 23 +++++++++++++++++++ 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/lib/Dialect/Fly/IR/FlyOps.cpp b/lib/Dialect/Fly/IR/FlyOps.cpp index b07d29eb..438b4ded 100644 --- a/lib/Dialect/Fly/IR/FlyOps.cpp +++ b/lib/Dialect/Fly/IR/FlyOps.cpp @@ -288,9 +288,13 @@ FLY_INFER_RETURN_TYPES(GetScalarOp) { if (!intTupleType) return emitOptionalError(location, "GetScalarOp: expected IntTupleType, got ", operands[0].getType()); - if (!intTupleType.getAttr().isLeaf()) - return emitOptionalError(location, "GetScalarOp: expected a leaf IntTuple, got ", intTupleType); - auto intAttr = intTupleType.getAttr().extractIntFromLeaf(); + IntTupleAttr scalarAttr = intTupleType.getAttr(); + while (!scalarAttr.isLeaf() && scalarAttr.rank() == 1) + scalarAttr = scalarAttr.at(0); + if (!scalarAttr.isLeaf()) + return emitOptionalError(location, "GetScalarOp: expected a scalar IntTuple, got ", + intTupleType); + auto intAttr = scalarAttr.extractIntFromLeaf(); inferredReturnTypes.assign({IntegerType::get(context, intAttr.getWidth())}); return success(); } diff --git a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp index f40d38f8..4ab3f5a1 100644 --- a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp +++ b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp @@ -398,12 +398,15 @@ class GetScalarLowering : public OpRewritePattern { if (!isNormalForm(cast>(intTuple))) return failure(); - IntTupleAttr attr = intTupleTy.getAttr(); - assert(attr.isLeaf() && "IntTuple must be a leaf"); + IntTupleAttr scalarAttr = intTupleTy.getAttr(); - Type resultTy = op.getResult().getType(); - auto intAttr = attr.extractIntFromLeaf(); + while (!scalarAttr.isLeaf() && scalarAttr.rank() == 1) + scalarAttr = scalarAttr.at(0); + if (!scalarAttr.isLeaf()) + return rewriter.notifyMatchFailure(op, "expected leaf IntTupleAttr after unwrapping rank-1 chain"); + auto intAttr = scalarAttr.extractIntFromLeaf(); if (intAttr.isStatic()) { + Type resultTy = op.getResult().getType(); rewriter.replaceOp(op, arith::ConstantIntOp::create(rewriter, loc, resultTy, intAttr.getValue())); return success(); diff --git a/tests/mlir/Transforms/layout_lowering.mlir b/tests/mlir/Transforms/layout_lowering.mlir index 35df9986..8b3e2a1d 100644 --- a/tests/mlir/Transforms/layout_lowering.mlir +++ b/tests/mlir/Transforms/layout_lowering.mlir @@ -99,6 +99,29 @@ func.func @test_get_scalar_dynamic(%x: i32) -> i32 { return %s : i32 } +// Static get_scalar unwraps nested singleton tuples. +// CHECK-LABEL: @test_get_scalar_nested_static +func.func @test_get_scalar_nested_static() -> i32 { + %t = fly.make_int_tuple() : () -> !fly.int_tuple<((((42))))> + // CHECK-NOT: fly.get_scalar + // CHECK-NOT: fly.make_int_tuple + // CHECK: %[[C:.*]] = arith.constant 42 : i32 + // CHECK: return %[[C]] + %s = fly.get_scalar(%t) : (!fly.int_tuple<((((42))))>) -> i32 + return %s : i32 +} + +// Dynamic get_scalar unwraps nested singleton tuples to the leaf SSA value. 
+// CHECK-LABEL: @test_get_scalar_nested_dynamic +// CHECK-SAME: (%[[ARG:.*]]: i32) +func.func @test_get_scalar_nested_dynamic(%x: i32) -> i32 { + %t = fly.make_int_tuple(%x) : (i32) -> !fly.int_tuple<((((?))))> + // CHECK-NOT: fly.get_scalar + // CHECK: return %[[ARG]] + %s = fly.get_scalar(%t) : (!fly.int_tuple<((((?))))>) -> i32 + return %s : i32 +} + // ----- // === SizeOp Lowering === From 74449c4bfd3c51b79c7d4ab14cb4d98f1ebfdc59 Mon Sep 17 00:00:00 2001 From: Ao Li Date: Tue, 14 Apr 2026 17:08:11 +0800 Subject: [PATCH 10/29] support gptoss gemm shape by padding (#397) --- tests/kernels/test_gemm_fp8fp4_gfx1250.py | 211 +++++++++++++++++----- tests/kernels/test_wmma_gemm_gfx1250.py | 91 +++++++--- 2 files changed, 236 insertions(+), 66 deletions(-) diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index e6ec6c70..e7ccde68 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -122,6 +122,76 @@ def _a8w4_tolerances(a_scale: torch.Tensor, b_scale: torch.Tensor, return rtol, atol, diag +def _align_up(value: int, align: int) -> int: + return ((value + align - 1) // align) * align + + +def _mxscale_pack_factors(data_format: str) -> tuple[int, int]: + if data_format == "fp4": + return 2, 2 + if data_format == "a8w4": + return 1, 2 + if data_format == "fp8": + return 1, 1 + raise ValueError(f"unsupported data_format={data_format!r}") + + +def _get_padded_problem_shape( + data_format: str, + M: int, + N: int, + K: int, + tile_m: int, + tile_n: int, + tile_k: int, + split_k: int, +) -> dict[str, int]: + """Pad runtime problem to tile-aligned kernel dimensions.""" + if K % SCALE_BLOCK != 0: + raise ValueError(f"K={K} must be divisible by SCALE_BLOCK={SCALE_BLOCK}") + + pack_a, pack_b = _mxscale_pack_factors(data_format) + padded_k = _align_up(K, tile_k * split_k) + return { + "M": _align_up(M, tile_m), + "N": _align_up(N, tile_n), + "K": padded_k, + "K_scale": padded_k // SCALE_BLOCK, + "pack_a": pack_a, + "pack_b": pack_b, + } + + +def _pad_2d_tensor(tensor: torch.Tensor, rows: int, cols: int, fill_value: int) -> torch.Tensor: + if tensor.shape == (rows, cols): + return tensor + padded = torch.full((rows, cols), fill_value, dtype=tensor.dtype, device=tensor.device) + padded[:tensor.shape[0], :tensor.shape[1]] = tensor + return padded + + +def _pad_mxscale_inputs( + a: torch.Tensor, + b: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + padded_shape: dict[str, int], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Pad data/scale tensors so the kernel can run full tiles safely.""" + a = _pad_2d_tensor(a, padded_shape["M"], padded_shape["K"] // padded_shape["pack_a"], fill_value=0) + b = _pad_2d_tensor(b, padded_shape["N"], padded_shape["K"] // padded_shape["pack_b"], fill_value=0) + a_scale = _pad_2d_tensor(a_scale, padded_shape["M"], padded_shape["K_scale"], fill_value=127) + b_scale = _pad_2d_tensor(b_scale, padded_shape["N"], padded_shape["K_scale"], fill_value=127) + return a, b, a_scale, b_scale + + +def _format_kernel_pad(M: int, N: int, K: int, padded_shape: dict[str, int]) -> str: + padded_dims = (padded_shape["M"], padded_shape["N"], padded_shape["K"]) + if padded_dims == (M, N, K): + return "" + return f", kernel_pad={padded_dims}" + + def _run_mxscale_gemm_test( data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, num_buffers, use_tdm_store, out_dtype, @@ -142,11 +212,15 @@ def _run_mxscale_gemm_test( if use_scale_opsel and is_fp4: 
pytest.skip("FP4 32x16 WMMA scaleBType op_sel ignored by AM simulator") - if K % split_k != 0: - pytest.skip(f"K={K} must be divisible by split_k={split_k}") - local_k = K // split_k - if local_k % tile_k != 0: - pytest.skip(f"K/split_k={local_k} must be divisible by tile_k={tile_k}") + if K % SCALE_BLOCK != 0: + pytest.skip(f"K={K} must be divisible by SCALE_BLOCK={SCALE_BLOCK}") + + padded_shape = _get_padded_problem_shape( + data_format, M, N, K, tile_m, tile_n, tile_k, split_k) + padded_m = padded_shape["M"] + padded_n = padded_shape["N"] + padded_k = padded_shape["K"] + local_k = padded_k // split_k num_k_tiles = local_k // tile_k if num_buffers > 1 and num_k_tiles < num_buffers: @@ -166,7 +240,8 @@ def _run_mxscale_gemm_test( mcast_str = f", cluster=({cluster_m},{cluster_n})" \ if cluster_m > 1 or cluster_n > 1 else "" tdm_str = ", tdm_store" if use_tdm_store else ", buffer_store" - print(f"\nRunning {fmt_name} GEMM: M={M}, N={N}, K={K}, " + pad_str = _format_kernel_pad(M, N, K, padded_shape) + print(f"\nRunning {fmt_name} GEMM: M={M}, N={N}, K={K}{pad_str}, " f"tiles=({tile_m},{tile_n},{tile_k}), bufs={num_buffers}" f"{mcast_str}{tdm_str}, preshuffle, out={out_dtype}") @@ -196,6 +271,9 @@ def _run_mxscale_gemm_test( print(f"Ref stats: min={ref.min():.2f}, max={ref.max():.2f}, " f"mean={ref.mean():.2f}, std={ref.std():.2f}") + a, b, a_scale, b_scale = _pad_mxscale_inputs( + a, b, a_scale, b_scale, padded_shape) + # Preshuffle scales skt = tile_k // SCALE_BLOCK warp_tile_m = tile_m // m_warp @@ -204,19 +282,19 @@ def _run_mxscale_gemm_test( b_scale = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt) # Preshuffle B data - K_packed = K // 2 if (is_fp4 or is_a8w4) else K - b = fp4_utils.preshuffle_b_16x16(b, N, K_packed) + K_packed = padded_k // padded_shape["pack_b"] + b = fp4_utils.preshuffle_b_16x16(b, padded_n, K_packed) # Upload & launch a_gpu = a.cuda() b_gpu = b.cuda() as_gpu = a_scale.cuda() bs_gpu = b_scale.cuda() - c_gpu = torch.zeros(M, N, dtype=torch_out_dtype, device="cpu").cuda() + c_gpu = torch.zeros(padded_m, padded_n, dtype=torch_out_dtype, device="cuda") launch_fn = compile_mxscale_gemm( data_format=data_format, - M=M, N=N, K=K, + M=padded_m, N=padded_n, K=padded_k, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, m_warp=m_warp, n_warp=n_warp, num_buffers=num_buffers, @@ -237,11 +315,11 @@ def _run_mxscale_gemm_test( b_gpu.contiguous().view(-1), as_gpu.contiguous().view(-1), bs_gpu.contiguous().view(-1), - M, N, torch.cuda.current_stream(), + padded_m, padded_n, torch.cuda.current_stream(), ) torch.cuda.synchronize() - c_out = c_gpu.cpu() + c_out = c_gpu[:M, :N].cpu() print(f"Out stats: min={c_out.float().min():.2f}, max={c_out.float().max():.2f}, " f"mean={c_out.float().mean():.2f}, std={c_out.float().std():.2f}") @@ -312,9 +390,10 @@ def _extract_i64_metadata(compiled_ir: str, key: str) -> int: @pytest.mark.parametrize( "M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp", [ - (128, 128, 256, 128, 128, 128, 2, 2), - (128, 128, 512, 128, 128, 128, 2, 2), - (128, 128, 1024, 128, 128, 128, 2, 2), + (128, 512, 7168, 128, 128, 256, 2, 2), + (128, 7168, 256, 128, 256, 128, 2, 2), + (128, 4096, 7168, 128, 256, 256, 2, 2), + (128, 7168, 2048, 128, 256, 256, 2, 2), (1024, 1024, 1024, 256, 256, 256, 2, 2), ], ) @@ -380,9 +459,8 @@ def test_mxfp8_gemm(M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, @pytest.mark.parametrize( "M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp", [ - (128, 128, 256, 128, 128, 128, 2, 2), - (128, 128, 512, 128, 128, 128, 2, 2), - (128, 128, 
1024, 128, 128, 128, 2, 2), + (128, 5760, 2880, 128, 256, 256, 2, 2), + (128, 2880, 2880, 128, 256, 256, 2, 2), (1024, 1024, 1024, 128, 256, 128, 2, 4), ], ) @@ -399,6 +477,28 @@ def test_a8w4_gemm(M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, use_scale_opsel=use_scale_opsel) +@pytest.mark.parametrize( + "M, N, K, use_tdm_store", + [ + (13, 2880, 2880, True), + (33, 5760, 2880, False), + ], +) +def test_a8w4_gemm_irregular_m_tile16(M, N, K, use_tdm_store): + # Small-M path: pad M to 16 and dedicate one wave to the M dimension. + _run_mxscale_gemm_test( + "a8w4", + M, N, K, + 16, 256, 256, + 1, 4, + num_buffers=2, + use_tdm_store=use_tdm_store, + out_dtype="bf16", + l2_prefetch_distance=2, + use_scale_opsel=False, + ) + + @pytest.mark.parametrize( "M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, cluster_m, cluster_n", [ @@ -487,22 +587,32 @@ def _run_benchmark(args): data_format = args.data_format M, N, K = args.M, args.N, args.K + tile_m, tile_n, tile_k = args.tile_m, args.tile_n, args.tile_k + if K % SCALE_BLOCK != 0: + raise ValueError(f"K={K} must be divisible by SCALE_BLOCK={SCALE_BLOCK}") + + padded_shape = _get_padded_problem_shape( + data_format, M, N, K, tile_m, tile_n, tile_k, args.split_k) + padded_m = padded_shape["M"] + padded_n = padded_shape["N"] + padded_k = padded_shape["K"] + PACK_A = padded_shape["pack_a"] + PACK_B = padded_shape["pack_b"] + is_fp4 = data_format == "fp4" is_a8w4 = data_format == "a8w4" - PACK_A = 1 if (not is_fp4 and not is_a8w4) else (1 if data_format == "fp8" else 2 if is_fp4 else 1) - PACK_B = 2 if (is_fp4 or is_a8w4) else 1 - _dtype_map = {"f32": torch.float32, "bf16": torch.bfloat16, "f16": torch.float16} torch_out_dtype = _dtype_map[args.out_dtype] elem_bytes_d = 2 if args.out_dtype in ("bf16", "f16") else 4 - - tile_m, tile_n, tile_k = args.tile_m, args.tile_n, args.tile_k fmt_name = "A8W4" if is_a8w4 else ("MXFP4" if is_fp4 else "MXFP8") print("=" * 72) print(f" {fmt_name} GEMM Benchmark on gfx1250") print(f" PyTorch {torch.__version__}, Device: {torch.cuda.get_device_name(0)}") + needs_pad = (padded_m, padded_n, padded_k) != (M, N, K) print(f" Shape: M={M}, N={N}, K={K}") + if needs_pad: + print(f" Kernel pad: M={padded_m}, N={padded_n}, K={padded_k}") print(f" Tile: ({tile_m}, {tile_n}, {tile_k}), warps=({args.m_warp}x{args.n_warp})") print(f" Buffers={args.num_buffers}, out={args.out_dtype}, " f"opsel={args.use_scale_opsel}, inst_prefetch={args.inst_prefetch}") @@ -528,20 +638,23 @@ def _run_benchmark(args): a_scale = fp4_utils.random_e8m0(M, K // SCALE_BLOCK) b_scale = fp4_utils.random_e8m0(N, K // SCALE_BLOCK) + a, b, a_scale, b_scale = _pad_mxscale_inputs( + a, b, a_scale, b_scale, padded_shape) + skt = tile_k // SCALE_BLOCK warp_tile_m = tile_m // args.m_warp warp_tile_n = tile_n // args.n_warp a_scale = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt) b_scale = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt) - K_packed = K // 2 if (is_fp4 or is_a8w4) else K - b = fp4_utils.preshuffle_b_16x16(b, N, K_packed) + K_packed = padded_k // PACK_B + b = fp4_utils.preshuffle_b_16x16(b, padded_n, K_packed) a_gpu = a.cuda() b_gpu = b.cuda() as_gpu = a_scale.cuda() bs_gpu = b_scale.cuda() - c_gpu = torch.zeros(M, N, dtype=torch_out_dtype, device="cuda") + c_gpu = torch.zeros(padded_m, padded_n, dtype=torch_out_dtype, device="cuda") print(f"\n[1/3] Compiling kernel...") t0 = time.perf_counter() @@ -551,7 +664,7 @@ def _run_benchmark(args): use_tdm_store = False launch_fn = compile_mxscale_gemm( 
data_format=data_format, - M=M, N=N, K=K, + M=padded_m, N=padded_n, K=padded_k, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, m_warp=args.m_warp, n_warp=args.n_warp, num_buffers=args.num_buffers, @@ -569,18 +682,20 @@ def _run_benchmark(args): ) stream = torch.cuda.current_stream() + c_flat = c_gpu.view(-1) + a_flat = a_gpu.view(-1) + b_flat = b_gpu.view(-1) + as_flat = as_gpu.view(-1) + bs_flat = bs_gpu.view(-1) def prep_kernel(): c_gpu.zero_() def run_kernel(): launch_fn( - c_gpu.contiguous().view(-1), - a_gpu.contiguous().view(-1), - b_gpu.contiguous().view(-1), - as_gpu.contiguous().view(-1), - bs_gpu.contiguous().view(-1), - M, N, stream, + c_flat, a_flat, b_flat, + as_flat, bs_flat, + padded_m, padded_n, stream, ) prep_kernel() @@ -593,14 +708,16 @@ def run_kernel(): us = _bench_kernel_us(run_kernel, warmup=args.warmup, iters=args.iters, flush_l2=not args.no_flush_l2, prep_fn=prep_kernel) - flops = 2.0 * M * N * K + logical_flops = 2.0 * M * N * K + kernel_flops = 2.0 * padded_m * padded_n * padded_k time_s = us / 1e6 - tflops = flops / time_s / 1e12 if time_s > 0 else 0.0 + logical_tflops = logical_flops / time_s / 1e12 if time_s > 0 else 0.0 + kernel_tflops = kernel_flops / time_s / 1e12 if time_s > 0 else 0.0 - bytes_a = M * K // PACK_A - bytes_b = N * K // PACK_B - bytes_scale = (M + N) * (K // SCALE_BLOCK) - bytes_d = M * N * elem_bytes_d + bytes_a = padded_m * padded_k // PACK_A + bytes_b = padded_n * padded_k // PACK_B + bytes_scale = (padded_m + padded_n) * padded_shape["K_scale"] + bytes_d = padded_m * padded_n * elem_bytes_d read_bytes = bytes_a + bytes_b + bytes_scale write_bytes = bytes_d bytes_moved = read_bytes + write_bytes @@ -614,10 +731,10 @@ def run_kernel(): wmma_n_rep = warp_tile_n // WMMA_N_EFF k_wmma_steps = tile_k // WMMA_K wmma_per_tile = wmma_m_rep * wmma_n_rep * k_wmma_steps - m_tiles = (M + tile_m - 1) // tile_m - n_tiles = (N + tile_n - 1) // tile_n - k_tiles = K // tile_k - k_tiles_local = (K // args.split_k) // tile_k + m_tiles = padded_m // tile_m + n_tiles = padded_n // tile_n + k_tiles = padded_k // tile_k + k_tiles_local = (padded_k // args.split_k) // tile_k total_wmma = m_tiles * n_tiles * k_tiles * wmma_per_tile # Sequential WMMAs per workgroup (all k_tiles execute sequentially) seq_wmma = k_tiles_local * wmma_per_tile @@ -625,7 +742,10 @@ def run_kernel(): print(f"\n[3/3] Results:") print(f" Kernel time: {us:.1f} us ({us / 1e3:.4f} ms)") - print(f" TFLOPS: {tflops:.4f}") + if not needs_pad: + print(f" TFLOPS: {kernel_tflops:.4f}") + else: + print(f" TFLOPS: {logical_tflops:.4f} (logical), {kernel_tflops:.4f} (kernel)") print(f" Bandwidth: {bw_gbs:.1f} GB/s " f"(read: {read_bw_gbs:.1f} + write: {write_bw_gbs:.1f})") print(f" Bytes moved: {bytes_moved / 1e6:.1f} MB " @@ -646,7 +766,8 @@ def run_kernel(): f"WMMA_SCALE trap-handler emulation") print("=" * 72) - return us, tflops, bw_gbs + reported_tflops = kernel_tflops if not needs_pad else logical_tflops + return us, reported_tflops, bw_gbs if __name__ == "__main__": diff --git a/tests/kernels/test_wmma_gemm_gfx1250.py b/tests/kernels/test_wmma_gemm_gfx1250.py index a88b51d3..1bb3014f 100644 --- a/tests/kernels/test_wmma_gemm_gfx1250.py +++ b/tests/kernels/test_wmma_gemm_gfx1250.py @@ -41,21 +41,43 @@ def _validate_pipeline_depth(*, K, tile_k, num_buffers): ) +def _align_up(value: int, align: int) -> int: + return ((value + align - 1) // align) * align + + +def _get_padded_problem_shape(M: int, N: int, K: int, + tile_m: int, tile_n: int, tile_k: int) -> tuple[int, int, int]: + return ( + 
_align_up(M, tile_m), + _align_up(N, tile_n), + _align_up(K, tile_k), + ) + + +def _pad_2d_tensor(tensor: torch.Tensor, rows: int, cols: int, fill_value: float = 0.0) -> torch.Tensor: + if tensor.shape == (rows, cols): + return tensor + padded = torch.full((rows, cols), fill_value, dtype=tensor.dtype, device=tensor.device) + padded[:tensor.shape[0], :tensor.shape[1]] = tensor + return padded + + +def _format_kernel_pad(M: int, N: int, K: int, mpad: int, npad: int, kpad: int) -> str: + padded_dims = (mpad, npad, kpad) + if padded_dims == (M, N, K): + return "" + return f", kernel_pad={padded_dims}" + + @pytest.mark.parametrize("in_dtype", ["fp16", "bf16"]) @pytest.mark.parametrize( "M, N, K, tile_m, tile_n, tile_k", [ (128, 128, 64, 64, 128, 32), - (128, 128, 256, 64, 128, 128), (256, 256, 256, 64, 256, 128), - (256, 256, 192, 64, 256, 64), - (256, 512, 256, 64, 256, 128), (512, 512, 512, 64, 256, 128), - (201, 179, 128, 64, 128, 64), (300, 399, 256, 64, 256, 128), - (256, 256, 256, 256, 256, 128), (1024, 1024, 1024, 256, 256, 128), - (512, 512, 512, 256, 256, 128), ], ) @pytest.mark.parametrize("num_buffers", [2, 3]) @@ -70,7 +92,8 @@ def test_wmma_gemm_tdm(in_dtype, M, N, K, tile_m, tile_n, tile_k, if arch != "gfx1250": pytest.skip(f"WMMA requires gfx1250, got {arch}") - _validate_pipeline_depth(K=K, tile_k=tile_k, num_buffers=num_buffers) + mpad, npad, kpad = _get_padded_problem_shape(M, N, K, tile_m, tile_n, tile_k) + _validate_pipeline_depth(K=kpad, tile_k=tile_k, num_buffers=num_buffers) lds_pad = 8 elem_bytes = 2 @@ -85,8 +108,6 @@ def test_wmma_gemm_tdm(in_dtype, M, N, K, tile_m, tile_n, tile_k, torch_dtype = torch.float16 if in_dtype == "fp16" else torch.bfloat16 torch.manual_seed(0) - mpad = (M + tile_m - 1) // tile_m * tile_m - npad = (N + tile_n - 1) // tile_n * tile_n wg_m = mpad // tile_m wg_n = npad // tile_n @@ -104,25 +125,28 @@ def test_wmma_gemm_tdm(in_dtype, M, N, K, tile_m, tile_n, tile_k, f"wg_grid=({wg_m},{wg_n}), cluster=({cluster_m},{cluster_n})" ) + pad_str = _format_kernel_pad(M, N, K, mpad, npad, kpad) + print( + f"Running WMMA GEMM TDM: M={M}, N={N}, K={K}{pad_str}, ", + end="" + ) print( - f"Running WMMA GEMM TDM: M={M}, N={N}, K={K}, " f"dtype={in_dtype}, out={_eff_out}, bufs={num_buffers}, " f"tdm_store={use_tdm_store}, cluster=({cluster_m},{cluster_n}), " f"wave_spec_tdm={wave_specialized_tdm}, inst_prefetch={inst_prefetch}" ) - a = torch.randn((M, K), dtype=torch_dtype, device='cpu').cuda() - b = torch.randn((K, N), dtype=torch_dtype, device='cpu').cuda() + a = torch.randn((M, K), dtype=torch_dtype, device='cpu') + b = torch.randn((K, N), dtype=torch_dtype, device='cpu') + ref = torch.mm(a.to(torch.float32), b.to(torch.float32)) - a_pad = torch.zeros((mpad, K), dtype=torch_dtype, device='cpu').cuda() - b_pad = torch.zeros((K, npad), dtype=torch_dtype, device='cpu').cuda() - a_pad[:M, :] = a - b_pad[:, :N] = b + a_pad = _pad_2d_tensor(a, mpad, kpad).cuda() + b_pad = _pad_2d_tensor(b, kpad, npad).cuda() - c_pad = torch.zeros((mpad, npad), dtype=_out_torch, device='cpu').cuda() + c_pad = torch.zeros((mpad, npad), dtype=_out_torch, device='cuda') launch_fn = compile_wmma_gemm_tdm( - M=mpad, N=npad, K=K, + M=mpad, N=npad, K=kpad, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, m_warp=m_warp, n_warp=n_warp, in_dtype=in_dtype, out_dtype=out_dtype, @@ -142,7 +166,6 @@ def test_wmma_gemm_tdm(in_dtype, M, N, K, tile_m, tile_n, tile_k, ) torch.cuda.synchronize() - ref = torch.mm(a.cpu().to(torch.float32), b.cpu().to(torch.float32)) rtol = 3e-2 atol = 3e-2 assert 
verify_output(c_pad[:M, :N].cpu().to(torch.float32), ref, rtol=rtol, atol=atol) @@ -154,15 +177,14 @@ def test_wmma_gemm_tdm(in_dtype, M, N, K, tile_m, tile_n, tile_k, "M, N, K, tile_m, tile_n, tile_k", [ (1024, 1024, 1024, 128, 256, 128), - (2048, 2048, 1024, 128, 256, 128), (2048, 2048, 2048, 128, 256, 128), - (4096, 4096, 1024, 128, 256, 128), ], ) @pytest.mark.parametrize("cluster_m, cluster_n", [(2, 2), (4, 4)]) def test_wmma_gemm_tdm_mcast(in_dtype, M, N, K, tile_m, tile_n, tile_k, cluster_m, cluster_n): """Cluster multicast GEMM correctness test (large shapes only).""" + pytest.skip("Temporarily skip fp16 GEMM mcast tests.") test_wmma_gemm_tdm( in_dtype, M, N, K, tile_m, tile_n, tile_k, num_buffers=2, m_warp=2, n_warp=4, @@ -207,8 +229,35 @@ def test_wmma_gemm_tdm_tdm_store_tail_regression(in_dtype): ) +@pytest.mark.parametrize( + "M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, use_tdm_store", + [ + pytest.param(13, 512, 7168, 16, 128, 128, 1, 4, True, id="DS-TP-stage1"), + pytest.param(6, 7168, 256, 16, 256, 128, 1, 4, False, id="DS-TP-stage2"), + pytest.param(29, 3072, 5120, 32, 256, 128, 1, 4, True, id="DS-EP-stage1"), + pytest.param(32, 5120, 1536, 32, 256, 128, 1, 4, False, id="DS-EP-stage2"), + pytest.param(22, 5760, 2880, 32, 256, 128, 1, 4, True, id="GPTOSS-stage1"), + pytest.param(23, 2880, 2880, 32, 256, 128, 1, 4, False, id="GPTOSS-stage2"), + ], +) +def test_wmma_gemm_tdm_moe_shapes( + M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, use_tdm_store, +): + test_wmma_gemm_tdm( + "fp16", + M, N, K, + tile_m, tile_n, tile_k, + num_buffers=2, + m_warp=m_warp, + n_warp=n_warp, + l2_prefetch_distance=2, + use_tdm_store=use_tdm_store, + ) + + def test_wmma_gemm_tdm_mcast_tail(): """Exercise cluster mode with an even number of K tiles (tail includes a load).""" + pytest.skip("Temporarily skip fp16 GEMM mcast tests.") test_wmma_gemm_tdm( "fp16", 512, 512, 512, From 6cb26c42fa0a5ce49b81323112600b2f772ad9f4 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Tue, 14 Apr 2026 17:10:01 +0800 Subject: [PATCH 11/29] Add MakeFragmentLayoutLikeOp, improve robustness of prim funcs (#399) --- docs/api/dsl.rst | 1 + examples/02-tiledCopy.py | 5 ++--- examples/03-tiledMma.py | 11 ++++------- examples/04-preshuffle_gemm.py | 15 +++++++++------ include/flydsl/Dialect/Fly/IR/FlyOps.td | 5 +++++ lib/Dialect/Fly/IR/FlyOps.cpp | 17 +++++++++++++++++ lib/Dialect/Fly/Transforms/LayoutLowering.cpp | 19 ++++++++++++++++++- python/flydsl/expr/primitive.py | 16 ++++++++++++++++ python/flydsl/expr/typing.py | 16 ++++++++++++---- 9 files changed, 84 insertions(+), 21 deletions(-) diff --git a/docs/api/dsl.rst b/docs/api/dsl.rst index 7289826b..e94d6fa5 100644 --- a/docs/api/dsl.rst +++ b/docs/api/dsl.rst @@ -80,6 +80,7 @@ Memory Operations - **fx.memref_store(value, memref, indices)** -- scalar store to memref - **fx.memref_load_vec(memref)** -- load entire register as a vector - **fx.memref_store_vec(vec, memref)** -- store vector to register memref +- **fx.make_fragment_layout_like(layout_like)** -- compute the corresponding fragment layout - **fx.make_fragment_like(tensor)** -- allocate register fragment with same layout Copy & GEMM diff --git a/examples/02-tiledCopy.py b/examples/02-tiledCopy.py index d5a3e183..e6fc55da 100644 --- a/examples/02-tiledCopy.py +++ b/examples/02-tiledCopy.py @@ -17,13 +17,12 @@ def copy_kernel( block_m = 8 block_n = 24 - tile = fx.make_tile(fx.make_layout(block_m, 1), fx.make_layout(block_n, 1)) A = fx.rocdl.make_buffer_tensor(A) B = fx.rocdl.make_buffer_tensor(B) - bA = 
fx.zipped_divide(A, tile) - bB = fx.zipped_divide(B, tile) + bA = fx.zipped_divide(A, (block_m, block_n)) + bB = fx.zipped_divide(B, (block_m, block_n)) bA = fx.slice(bA, (None, bid)) bB = fx.slice(bB, (None, bid)) diff --git a/examples/03-tiledMma.py b/examples/03-tiledMma.py index 4cdaa034..eb948931 100644 --- a/examples/03-tiledMma.py +++ b/examples/03-tiledMma.py @@ -20,17 +20,13 @@ def gemm_kernel( tid = fx.thread_idx.x bid = fx.block_idx.x - tileA = fx.make_tile(block_m, block_k) - tileB = fx.make_tile(block_n, block_k) - tileC = fx.make_tile(block_m, block_n) - A = fx.rocdl.make_buffer_tensor(A) B = fx.rocdl.make_buffer_tensor(B) C = fx.rocdl.make_buffer_tensor(C) - bA = fx.zipped_divide(A, tileA) - bB = fx.zipped_divide(B, tileB) - bC = fx.zipped_divide(C, tileC) + bA = fx.zipped_divide(A, (block_m, block_k)) + bB = fx.zipped_divide(B, (block_n, block_k)) + bC = fx.zipped_divide(C, (block_m, block_n)) bA = fx.slice(bA, (None, bid)) bB = fx.slice(bB, (None, bid)) @@ -64,6 +60,7 @@ def gemm_kernel( fx.copy(copy_atom, copy_src_A, copy_frag_A, pred=None) fx.copy(copy_atom, copy_src_B, copy_frag_B, pred=None) + frag_C.fill(0) fx.gemm(mma_atom, frag_C, frag_A, frag_B, frag_C) fx.copy(copy_atom, copy_frag_C, copy_dst_C, pred=None) diff --git a/examples/04-preshuffle_gemm.py b/examples/04-preshuffle_gemm.py index 9cb3335e..d47b6956 100644 --- a/examples/04-preshuffle_gemm.py +++ b/examples/04-preshuffle_gemm.py @@ -28,9 +28,9 @@ def gemm_kernel( B = fx.rocdl.make_buffer_tensor(B) C = fx.rocdl.make_buffer_tensor(C) - gA_k = fx.flat_divide(A, fx.make_tile(BLOCK_M, BLOCK_K))[None, None, bid_x, None] # (BM, BK, k) - gB_k = fx.flat_divide(B, fx.make_tile(BLOCK_N, BLOCK_K))[None, None, bid_y, None] # (BN, BK, k) - gC = fx.flat_divide(C, fx.make_tile(BLOCK_M, BLOCK_N))[None, None, bid_x, bid_y] # (BM, BN) + gA_k = fx.flat_divide(A, (BLOCK_M, BLOCK_K))[None, None, bid_x, None] # (BM, BK, k) + gB_k = fx.flat_divide(B, (BLOCK_N, BLOCK_K))[None, None, bid_y, None] # (BN, BK, k) + gC = fx.flat_divide(C, (BLOCK_M, BLOCK_N))[None, None, bid_x, bid_y] # (BM, BN) thr_mma = tiled_mma.thr_slice(tid) thr_copy_g2s_A = tiled_copy_g2s_A.get_slice(tid) @@ -63,8 +63,10 @@ def gemm_kernel( mma_frag_A = thr_mma.make_fragment_A(sA[None, None, 0]) # (VA, VM, VN) mma_frag_B = fx.make_fragment_like( - fx.flat_product(thr_mma.partition_B(gB_k).layout(None, None, None, 0), fx.make_layout(2, 1)), - fx.Float16.ir_type, + fx.flat_product( + fx.make_fragment_layout_like(thr_mma.partition_B(gB_k).layout(None, None, None, 0)), fx.make_layout(2, 1) + ), + fx.Float16, ) # (VB, VM, VK, 2) mma_frag_C = thr_mma.make_fragment_C(gC) # (VC, VM, VN) @@ -101,6 +103,7 @@ def run_pipeline_stage(read_stage, next_k, read_next=True): mma_frag_A[None, None, (None, block_k_iter)], mma_frag_B[None, None, (None, block_k_iter), read_stage], mma_frag_C, + traversal_order=fx.GemmTraversalOrder.KNM, ) fx.copy(uni_copy_128b, copy_frag_A, thr_sA[None, None, None, write_stage]) @@ -136,7 +139,7 @@ def sched_main_iter(with_vmem=False, with_dswr=False): fx.copy(buffer_copy_128b, thr_gA_k[None, None, None, 0], copy_frag_A) fx.copy(buffer_copy_128b, thr_gB_k[None, None, None, 0], mma_frag_B_retile[None, None, None, 0]) - mma_frag_C.store(fx.arith.constant_vector(0.0, fx.T.VectorType.get([64], fx.T.f32()))) + mma_frag_C.fill(0) fx.copy(uni_copy_128b, copy_frag_A, thr_sA[None, None, None, 0]) fx.gpu.barrier() diff --git a/include/flydsl/Dialect/Fly/IR/FlyOps.td b/include/flydsl/Dialect/Fly/IR/FlyOps.td index a64e0384..a9fbb837 100644 --- 
a/include/flydsl/Dialect/Fly/IR/FlyOps.td +++ b/include/flydsl/Dialect/Fly/IR/FlyOps.td @@ -102,6 +102,11 @@ def Fly_MakeFragmentLikeOp : Fly_Op<"make_fragment_like", [Pure, DeclareOpInterf let assemblyFormat = "`(` $src (`,` $dtype^)? `)` attr-dict `:` functional-type(operands, results)"; } +def Fly_MakeFragmentLayoutLikeOp : Fly_Op<"make_fragment_layout_like", [Pure, DeclareOpInterfaceMethods]> { + let arguments = (ins AnyTypeOf<[Fly_Layout, Fly_MemRef]>:$src); + let results = (outs Fly_Layout:$result); +} + //===----------------------------------------------------------------------===// // Extractors //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Fly/IR/FlyOps.cpp b/lib/Dialect/Fly/IR/FlyOps.cpp index 438b4ded..a278fea3 100644 --- a/lib/Dialect/Fly/IR/FlyOps.cpp +++ b/lib/Dialect/Fly/IR/FlyOps.cpp @@ -248,6 +248,23 @@ FLY_INFER_RETURN_TYPES(MakeViewOp) { } } +FLY_INFER_RETURN_TYPES(MakeFragmentLayoutLikeOp) { + auto srcLayout = GetLayoutAttrFromLayoutLikeType(operands[0].getType()); + if (!srcLayout) + return emitOptionalError(location, + "MakeFragmentLayoutLikeOp: expected LayoutType or MemRefType, got ", + operands[0].getType()); + + if (!srcLayout.getShape().isStatic()) + return emitOptionalError( + location, "MakeFragmentLayoutLikeOp: expected static shape layout, got ", srcLayout); + + LayoutBuilder layoutBuilder(context); + LayoutAttr fragmentLayout = layoutMakeFragmentLayout(layoutBuilder, srcLayout); + inferredReturnTypes.assign({LayoutType::get(context, fragmentLayout)}); + return success(); +} + FLY_INFER_RETURN_TYPES(MakeFragmentLikeOp) { TypeAttr dtypeAttr; if (properties) diff --git a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp index 4ab3f5a1..7a15dd95 100644 --- a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp +++ b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp @@ -364,6 +364,22 @@ class MakeIdentityLayoutOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MakeFragmentLayoutLikeOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto resultTy = cast(op.getType()); + + LayoutBuilder layoutBuilder(rewriter, loc); + Value fragmentLayout = layoutBuilder.materializeConstantLayout(resultTy.getAttr()).getValue(); + rewriter.replaceOp(op, fragmentLayout); + return success(); + } +}; + class MakeFragmentLikeOpLowering : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -2734,7 +2750,8 @@ class FlyLayoutLoweringPass // Constructors patterns.add(context); + MakeLayoutLikeOpLowering, MakeFragmentLayoutLikeOpLowering, + MakeFragmentLikeOpLowering>(context); // Extractors patterns.add "Vector": @classmethod def filled(cls, shape, fill_value, dtype: Type[Numeric], *, loc=None, ip=None) -> "Vector": + def _shape_numel(dims): + n = 1 + for dim in dims: + if isinstance(dim, (tuple, list)): + n *= _shape_numel(dim) + else: + n *= dim + return n + shape = (shape,) if isinstance(shape, int) else tuple(shape) - n = 1 - for s in shape: - n *= s + n = _shape_numel(shape) if isinstance(fill_value, (int, float, bool)): fill_value = dtype(fill_value) elif isinstance(fill_value, Numeric): From fd9ca4fed12095099920eca19efa5f9940b93577 Mon Sep 17 00:00:00 2001 From: yadaish Date: Tue, 14 Apr 2026 20:06:00 +0800 Subject: [PATCH 12/29] support split-k algo for moe_gemm_2stage (#390) --- kernels/moe_gemm_2stage.py | 230 ++++++++++++++++++++++++++++++--- 
tests/kernels/test_moe_gemm.py | 22 +++- 2 files changed, 233 insertions(+), 19 deletions(-) diff --git a/kernels/moe_gemm_2stage.py b/kernels/moe_gemm_2stage.py index bce43ece..06eddb24 100644 --- a/kernels/moe_gemm_2stage.py +++ b/kernels/moe_gemm_2stage.py @@ -109,6 +109,7 @@ def compile_moe_gemm1( out_dtype: str = "f16", use_cshuffle_epilog: bool | None = None, scale_is_bf16: bool = False, + k_batch: int = 1, ): """Compile stage1 kernel (`moe_gemm1`) and return the compiled executable. @@ -123,6 +124,8 @@ def compile_moe_gemm1( - "int4": W4A8 path: X is int8, W is packed int4 (2 values per byte) unpacked to int8 in-kernel - "int4_bf16": W4A16 path: X is bf16, W is packed int4 unpacked to bf16 in-kernel scale_is_bf16: When True, groupwise scales are bf16 (halves scale bandwidth). + k_batch: Split-K factor. When >1, K is partitioned across k_batch CTAs that + atomically accumulate gate/up partials. Caller must pre-zero output. """ gpu_arch = get_hip_arch() @@ -185,6 +188,21 @@ def compile_moe_gemm1( _is_gfx950 = "gfx95" in get_hip_arch() use_gfx950_cvt = is_int4_bf16 and _is_gfx950 + # Split-K validation + _is_splitk = k_batch > 1 + if _is_splitk: + _k_per_batch = model_dim // k_batch + assert model_dim % k_batch == 0, f"model_dim={model_dim} not divisible by k_batch={k_batch}" + assert _k_per_batch % tile_k == 0, f"K_per_batch={_k_per_batch} not divisible by tile_k={tile_k}" + # The ping-pong K-loop requires an even number of K tiles (>=4). + _k_tiles = _k_per_batch // tile_k + assert _k_tiles >= 4 and _k_tiles % 2 == 0, ( + f"K_per_batch/tile_k={_k_tiles} must be even and >=4 for the ping-pong pipeline. " + f"Try a different k_batch (model_dim={model_dim}, tile_k={tile_k})." + ) + else: + _k_per_batch = model_dim + mfma_i32_k32 = None if is_int8: mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr( @@ -239,7 +257,8 @@ def compile_moe_gemm1( if use_cshuffle_epilog is None: use_cshuffle_epilog = os.environ.get("FLYDSL_MOE_STAGE1_CSHUFFLE", "1") in ("1", "true", "True", "YES", "yes") use_cshuffle_epilog = bool(use_cshuffle_epilog) - if out_dtype != "f16" and use_cshuffle_epilog: + # Split-K uses f32 atomic CShuffle regardless of out_dtype, so skip this check. + if out_dtype != "f16" and use_cshuffle_epilog and not _is_splitk: raise ValueError("stage1 cshuffle epilog currently supports only f16 output (out_dtype='f16')") epilog_tag = "cshuffle" if use_cshuffle_epilog else "direct" @@ -247,10 +266,11 @@ def compile_moe_gemm1( # Keep an explicit ABI tag so signature changes can't accidentally reuse an old binary. 
_gs_tag = f"_g{group_size}" if use_groupwise_scale else "" scale_tag = "_sbf16" if _scale_is_bf16 else "" + _split_k_tag = f"_splitk{k_batch}" if _is_splitk else "" module_name = ( f"mfma_moe1_{in_dtype}_{out_dtype}_{epilog_tag}" f"_t{tile_m}x{tile_n}x{tile_k}" - f"{_gs_tag}{scale_tag}" + f"{_gs_tag}{scale_tag}{_split_k_tag}" f"_abi3" # also mask sentinel token ids on loads (X/scale_x) to avoid illegal address faults ).replace("-", "_") @@ -259,8 +279,12 @@ def compile_moe_gemm1( # - ping-pong X tiles (2 * tile_m * lds_stride bytes) # - optional epilogue CShuffle tile (tile_m * tile_n f16 -> 2 * tile_m * tile_n bytes) _use_cshuffle_epilog = bool(use_cshuffle_epilog) + # Split-K requires CShuffle epilogue (f32 atomic adds via store_pair callback) + if _is_splitk: + _use_cshuffle_epilog = True + _cshuffle_elem_bytes = 4 if _is_splitk else 2 # f32 for split-K, f16 otherwise lds_x_bytes = 2 * int(tile_m) * int(lds_stride) * int(elem_bytes) - lds_out_bytes = 2 * int(tile_m) * int(tile_n) if _use_cshuffle_epilog else 0 + lds_out_bytes = _cshuffle_elem_bytes * int(tile_m) * int(tile_n) if _use_cshuffle_epilog else 0 lds_total_bytes = max(lds_x_bytes, lds_out_bytes) lds_total_elems = lds_total_bytes if elem_bytes == 1 else (lds_total_bytes // 2) @@ -350,6 +374,12 @@ def silu(x): by = gpu.block_id("x") # tile along inter_dim bx = gpu.block_id("y") # tile along sorted M + if _is_splitk: + bz = gpu.block_id("z") # K-batch id + k_base_idx = bz * arith.index(_k_per_batch) + else: + k_base_idx = arith.index(0) + # Block validity: compute as early as possible so invalid blocks skip all buffer-resource # setup, LDS pointer math, and gmem prefetch work. bx_m = bx * fx.Index(tile_m) @@ -381,9 +411,11 @@ def silu(x): shape=(lds_total_elems,), ) lds_x = lds_x_ptr.get() - # Alias LDS bytes as fp16 for optional CShuffle epilogue. + # Alias LDS bytes for optional CShuffle epilogue. + # Split-K uses f32 (4B) per element for atomic accumulation; normal uses f16 (2B). + _lds_out_elem_type = T.f32 if _is_splitk else T.f16 lds_out = ( - SmemPtr(base_ptr, lds_x_ptr.byte_offset, T.f16, shape=(tile_m * tile_n,)).get() + SmemPtr(base_ptr, lds_x_ptr.byte_offset, _lds_out_elem_type, shape=(tile_m * tile_n,)).get() if _use_cshuffle_epilog else None ) @@ -401,9 +433,12 @@ def silu(x): w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False) - # OUT: [tokens, topk, inter] f16/bf16 -> bytes = tokens*topk*inter*out_elem_bytes - out_elem_bytes = 2 # f16/bf16 - out_nbytes_idx = tokens_in * c_topk * inter_in * fx.Index(out_elem_bytes) + # OUT: normal=[tokens, topk, inter] f16/bf16, split-K=[tokens*topk, 2*inter] f32 + out_elem_bytes = 4 if _is_splitk else 2 + if _is_splitk: + out_nbytes_idx = tokens_in * c_topk * inter_in * fx.Index(2 * out_elem_bytes) + else: + out_nbytes_idx = tokens_in * c_topk * inter_in * fx.Index(out_elem_bytes) out_rsrc = buffer_ops.create_buffer_resource( arg_out, max_size=False, num_records_bytes=out_nbytes_idx ) @@ -992,26 +1027,26 @@ def hot_loop_scheduler(): rocdl.sched_barrier(0) # Prologue: prefetch tile0, store to LDS(cur), sync. - k0 = fx.Index(0) + k0 = k_base_idx x_regs0 = load_x_tile(k0) b_gate_cur = load_b_tile(k0, n_blk_gate, n_intra_gate) b_up_cur = load_b_tile(k0, n_blk_up, n_intra_up) store_x_tile_to_lds(x_regs0, lds_base_cur) gpu.barrier() - + # Loop-carried ping/pong state. 
lds_base_pong = lds_base_cur # current/compute lds_base_ping = lds_base_nxt # next/load+store - + # Cross-tile A0 LDS prefetch (default-on): prefetch the first A-pack (K64) for the # tile we are about to compute from LDS, to overlap with upcoming VMEM. a0_prefetch_pong = lds_load_packs_k64(row_a_lds, col_offset_base_bytes, lds_base_pong) - + # Ping-pong main loop (2 tiles per iteration), leaving 2 tail tiles. # Uses scf.for with loop-carried accumulators, B-tile prefetch, and A0 LDS prefetch. c2_tile_k = arith.index(tile_k * 2) c_tile_k = arith.index(tile_k) - total_tiles = int(model_dim) // int(tile_k) + total_tiles = int(_k_per_batch) // int(tile_k) pair_iters = max((total_tiles - 2) // 2, 0) # B-tile data layout per k_unroll entry (3 variants): @@ -1091,7 +1126,7 @@ def _unflatten_b_tile(vals): _bu = _unflatten_b_tile(list(state[_p_bu:_p_a0])) _a0pf = (state[_p_a0], state[_p_a0 + 1]) - k_iv = pair_iv * (c_tile_k + c_tile_k) + k_iv = k_base_idx + pair_iv * (c_tile_k + c_tile_k) # ---- stage 0: prefetch+store ping, compute pong ---- next_k1 = k_iv + c_tile_k @@ -1133,7 +1168,7 @@ def _unflatten_b_tile(vals): b_gate_cur = _unflatten_b_tile(list(loop_results[_p_bg:_p_bu])) b_up_cur = _unflatten_b_tile(list(loop_results[_p_bu:_p_a0])) a0_prefetch_pong = (loop_results[_p_a0], loop_results[_p_a0 + 1]) - k_tail1 = k_in - tile_k + k_tail1 = k_base_idx + arith.index(_k_per_batch - tile_k) x_regs_ping = load_x_tile(k_tail1) b_gate_ping = load_b_tile(k_tail1, n_blk_gate, n_intra_gate) b_up_ping = load_b_tile(k_tail1, n_blk_up, n_intra_up) @@ -1214,6 +1249,169 @@ def _unflatten_b_tile(vals): # Uses EVec=4 (buffer store "x4" of fp16 elements). use_cshuffle_epilog_flag = _use_cshuffle_epilog + # ─── Split-K epilogue: two-pass gate/up with f32 atomic fadd ─── + if _is_splitk: + if lds_out is None: + raise RuntimeError("Split-K epilogue requires lds_out (CShuffle)") + + out_base_idx = buffer_ops.extract_base_index(arg_out) + _split_k_out_row_stride = inter_dim * 2 * out_elem_bytes # bytes per row + _split_k_e_vec = 2 # f32 vec2 for atomic fadd + + # Mutable slot: 0 for gate pass, inter_dim for up pass + _split_k_n_offset = [0] + + # Mutable slots for two-pass gate/up selection + _split_k_acc = [acc_gate] + _split_k_sw_vals = [sw_gate_vals] + + def write_row_to_lds_splitk( + *, + mi: int, + ii: int, + row_in_tile, + row, + row_base_lds, + col_base_local, + num_acc_n: int, + lds_out, + ): + """Write scaled f32 partial sums to LDS (no silu, no doweight).""" + _acc = _split_k_acc[0] + _sw = _split_k_sw_vals[0] + # Load per-row scale_x (sx) — same logic as normal epilogue. 
+ fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=T.i32) + t2 = fused2 & mask24_i32 + t_valid = arith.cmpi(arith.CmpIPredicate.ult, t2, tokens_i32_v) + if x_is_token_slot: + s2 = fused2 >> 24 + ts2 = s2 * tokens_i32_v + t2 + sx = ( + fx.Float32(1.0) + if is_f16_or_bf16 + else arith.select( + t_valid, + buffer_ops.buffer_load(sx_rsrc, ts2, vec_width=1, dtype=T.f32), + fx.Float32(0.0), + ) + ) + else: + sx = ( + fx.Float32(1.0) + if is_f16_or_bf16 + else arith.select( + t_valid, + buffer_ops.buffer_load(sx_rsrc, t2, vec_width=1, dtype=T.f32), + fx.Float32(0.0), + ) + ) + for ni in range_constexpr(num_acc_n): + col_local = col_base_local + (ni * 16) + acc_idx = mi * num_acc_n + ni + v = vector.extract( + _acc[acc_idx], static_position=[ii], dynamic_position=[] + ) + if is_int8: + v = arith.sitofp(T.f32, v) + v = v * sx * _sw[ni] + lds_idx = row_base_lds + col_local + v1 = vector.from_elements(T.vec(1, T.f32), [v]) + vector.store(v1, lds_out, [lds_idx], alignment=4) + + def precompute_row_splitk(*, row_local, row): + fused2 = buffer_ops.buffer_load(sorted_rsrc, row, vec_width=1, dtype=T.i32) + t2 = fused2 & mask24_i32 + s2 = fused2 >> 24 + t_ok = arith.cmpi(arith.CmpIPredicate.ult, t2, tokens_i32_v) + t_idx = arith.index_cast(T.index, t2) + s_idx = arith.index_cast(T.index, s2) + ts_idx = t_idx * arith.index(topk) + s_idx + row_byte_base = out_base_idx + ts_idx * arith.index(_split_k_out_row_stride) + return (row_byte_base, t_ok) + + def store_pair_splitk(*, row_local, row, row_ctx, col_pair0, col_g0, frag): + row_byte_base = row_ctx + col_idx = col_g0 + arith.index(_split_k_n_offset[0]) + byte_off_col = col_idx * arith.index(out_elem_bytes) + ptr_addr_idx = row_byte_base + byte_off_col + out_ptr = buffer_ops.create_llvm_ptr(ptr_addr_idx, address_space=1) + out_ptr_v = out_ptr._value if hasattr(out_ptr, "_value") else out_ptr + frag_v = frag._value if hasattr(frag, "_value") else frag + llvm.AtomicRMWOp( + llvm.AtomicBinOp.fadd, + out_ptr_v, + frag_v, + llvm.AtomicOrdering.monotonic, + syncscope="agent", + alignment=_split_k_e_vec * out_elem_bytes, + ) + + _cshuffle_nlane_splitk = min(32, tile_n // _split_k_e_vec) + _splitk_frag_elem = ir.F32Type.get() + + # Pass 1: gate (offset=0) + _split_k_acc[0] = acc_gate + _split_k_sw_vals[0] = sw_gate_vals + _split_k_n_offset[0] = 0 + c_shuffle_epilog( + arith=arith, + vector=vector, + gpu=gpu, + scf=scf, + range_constexpr=range_constexpr, + tile_m=tile_m, + tile_n=tile_n, + e_vec=_split_k_e_vec, + cshuffle_nlane=_cshuffle_nlane_splitk, + block_size=total_threads, + m_repeat=m_repeat, + num_acc_n=num_acc_n, + tx=tx, + lane_div_16=lane_div_16, + lane_mod_16=lane_mod_16, + bx_m=bx_m, + by_n=by_n, + n_tile_base=n_tile_base, + lds_out=lds_out, + frag_elem_type=_splitk_frag_elem, + write_row_to_lds=write_row_to_lds_splitk, + precompute_row=precompute_row_splitk, + store_pair=store_pair_splitk, + ) + + gpu.barrier() + + # Pass 2: up (offset=inter_dim) + _split_k_acc[0] = acc_up + _split_k_sw_vals[0] = sw_up_vals + _split_k_n_offset[0] = inter_dim + c_shuffle_epilog( + arith=arith, + vector=vector, + gpu=gpu, + scf=scf, + range_constexpr=range_constexpr, + tile_m=tile_m, + tile_n=tile_n, + e_vec=_split_k_e_vec, + cshuffle_nlane=_cshuffle_nlane_splitk, + block_size=total_threads, + m_repeat=m_repeat, + num_acc_n=num_acc_n, + tx=tx, + lane_div_16=lane_div_16, + lane_mod_16=lane_mod_16, + bx_m=bx_m, + by_n=by_n, + n_tile_base=n_tile_base, + lds_out=lds_out, + frag_elem_type=_splitk_frag_elem, + write_row_to_lds=write_row_to_lds_splitk, + 
precompute_row=precompute_row_splitk, + store_pair=store_pair_splitk, + ) + return + if use_cshuffle_epilog_flag: if lds_out is None: raise RuntimeError("CShuffle epilogue enabled but lds_out is not allocated/aliased.") @@ -1463,7 +1661,7 @@ def launch_moe_gemm1( i32_k_in, i32_size_expert_ids_in, ).launch( - grid=(gx, gy, 1), + grid=(gx, gy, k_batch), block=(256, 1, 1), stream=stream, ) diff --git a/tests/kernels/test_moe_gemm.py b/tests/kernels/test_moe_gemm.py index ca819bba..b979b3fe 100644 --- a/tests/kernels/test_moe_gemm.py +++ b/tests/kernels/test_moe_gemm.py @@ -336,6 +336,7 @@ def run_moe_stage1( scale_dtype: str = "f32", even_dispatch: bool = False, out_dtype: str = "f16", + k_batch: int = 1, ): assert model_dim % 64 == 0 assert model_dim % tile_k == 0 @@ -562,9 +563,13 @@ def run_moe_stage1( scale_w1_1d = scale_w1_flat.view(-1).contiguous() sorted_weights_1d = sorted_weights.contiguous().view(-1) # [sorted_size] - # Output: [tokens, topk, inter_dim] + # Output: normal=[tokens, topk, inter_dim] f16/bf16, split-K=[tokens*topk, 2*inter_dim] f32 _out_torch_dtype = torch.bfloat16 if out_dtype == "bf16" else torch.float16 - out = torch.empty((tokens, topk, inter_dim), device=device, dtype=_out_torch_dtype) + _is_splitk = k_batch > 1 + if _is_splitk: + out = torch.zeros((tokens * topk, 2 * inter_dim), device=device, dtype=torch.float32) + else: + out = torch.empty((tokens, topk, inter_dim), device=device, dtype=_out_torch_dtype) if is_fp4: exe = compile_mixed_moe_gemm1( @@ -605,9 +610,10 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): tile_n=tile_n, tile_k=tile_k, doweight_stage1=bool(doweight_stage1), - use_cshuffle_epilog=False, + use_cshuffle_epilog=None if _is_splitk else False, scale_is_bf16=(scale_dtype == "bf16"), out_dtype=out_dtype, + k_batch=k_batch, ) def _s1_args(o, x, w, sx, sw, st, eids, sw_sorted): @@ -618,6 +624,8 @@ def _s1_args(o, x, w, sx, sw, st, eids, sw_sorted): compiled_exe = flyc.compile(exe, *_s1_args(out, x_q, w_kernel, scale_x_1d, scale_w1_1d, sorted_token_ids, sorted_expert_ids, sorted_weights_1d)) def launch(o, x, w, sx, sw, st, eids, sw_sorted): + if _is_splitk: + o.zero_() compiled_exe(*_s1_args(o, x, w, sx, sw, st, eids, sw_sorted)) _, us = run_perftest( @@ -636,6 +644,14 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): ) torch.cuda.synchronize() + # Split-K post-processing: apply silu(gate)*up on host, reshape to [tokens, topk, inter_dim] + # Note: the gfx950 v_cvt_off_f32_i4 x16 correction is already applied per-CTA in the kernel + # epilogue (linear factor commutes with summation: sum(x_i*16) = 16*sum(x_i)). + if _is_splitk: + gate = out[:, :inter_dim] # [tokens*topk, inter_dim] f32 + up = out[:, inter_dim:] # [tokens*topk, inter_dim] f32 + out = (torch.nn.functional.silu(gate) * up).to(_out_torch_dtype).view(tokens, topk, inter_dim) + if not bool(skip_ref): if is_int8smooth: # x_q is slot-major [topk, tokens, K]; convert to [tokens, topk, K] for ref. 
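Referring back to the split-K post-processing hunk added above: the per-CTA x16 correction can be applied before the atomic reduction only because the factor is linear. Below is a minimal standalone sketch (torch only, made-up shapes, not part of the patch) that checks the commuting identity sum_i(16*p_i) == 16*sum_i(p_i) and mirrors the host-side silu(gate)*up fusion from the test:

import torch

tokens, topk, inter_dim, k_batch = 4, 2, 8, 3

# Each of the k_batch CTAs atomically adds its f32 partial sums into `out`.
# A linear per-CTA correction (e.g. the x16 factor) is equivalent to applying
# it once after the reduction.
partials = torch.randn(k_batch, tokens * topk, 2 * inter_dim)
out = (16.0 * partials).sum(dim=0)   # correct per CTA, then reduce
ref = 16.0 * partials.sum(dim=0)     # reduce, then correct once
assert torch.allclose(out, ref, atol=1e-5)

# Host-side epilogue as in the test above: split gate/up halves, fuse, reshape.
gate, up = out[:, :inter_dim], out[:, inter_dim:]
result = (torch.nn.functional.silu(gate) * up).view(tokens, topk, inter_dim)
print(result.shape)  # torch.Size([4, 2, 8])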
From f65e930c551bff84a07e84c94d727e83ae80eecd Mon Sep 17 00:00:00 2001 From: XingerZhu Date: Thu, 16 Apr 2026 11:18:47 +0800 Subject: [PATCH 13/29] Gfx1250 moe (#402) --- kernels/gemm_common_gfx1250.py | 32 +- kernels/moe_gemm_2stage_common_gfx1250.py | 1089 +++++++ kernels/moe_gemm_2stage_mxscale_gfx1250.py | 2889 +++++++++++++++++ kernels/moe_gemm_2stage_wmma_gfx1250.py | 912 ++++++ python/flydsl/expr/rocdl/tdm_ops.py | 324 ++ tests/kernels/benchmark_common.py | 443 +++ .../kernels/test_moe_gemm_mxscale_gfx1250.py | 1756 ++++++++++ tests/kernels/test_moe_gemm_wmma_gfx1250.py | 1419 ++++++++ 8 files changed, 8856 insertions(+), 8 deletions(-) create mode 100644 kernels/moe_gemm_2stage_common_gfx1250.py create mode 100644 kernels/moe_gemm_2stage_mxscale_gfx1250.py create mode 100644 kernels/moe_gemm_2stage_wmma_gfx1250.py create mode 100644 tests/kernels/test_moe_gemm_mxscale_gfx1250.py create mode 100644 tests/kernels/test_moe_gemm_wmma_gfx1250.py diff --git a/kernels/gemm_common_gfx1250.py b/kernels/gemm_common_gfx1250.py index 07719f34..0fe466ff 100644 --- a/kernels/gemm_common_gfx1250.py +++ b/kernels/gemm_common_gfx1250.py @@ -1,7 +1,7 @@ """Shared utilities for gfx1250 GEMM kernels (fp16 / mxfp4 / mxfp8). """ from flydsl._mlir import ir -from flydsl._mlir.dialects import llvm as llvm_dialect +from flydsl._mlir.dialects import llvm as llvm_dialect, scf from flydsl.expr import arith, buffer_ops, gpu, rocdl, tdm_ops, vector from flydsl.expr.arith import _to_raw as _raw from flydsl.expr.typing import T @@ -97,18 +97,28 @@ def lds_transpose_load_raw(result_type, lds_base_idx, byte_offset): return _rocdl.ds_load_tr16_b128(result_type, ptr_val) -def pipeline_fence(outstanding=0, use_cluster=False): - """Fused READY+REUSE fence for gfx1250 multi-buffer pipeline. +def workgroup_barrier(use_cluster=False): + """Issue the appropriate barrier for LDS visibility. - Issues ``s_wait_tensorcnt`` followed by the appropriate barrier. + Cluster mode layers an inter-workgroup barrier on top of the regular + workgroup barrier protocol, so call sites can treat it as a single + "LDS is now readable" fence. """ - tdm_ops.tensor_wait(outstanding) if use_cluster: gpu.cluster_barrier() else: gpu.barrier() +def pipeline_fence(outstanding=0, use_cluster=False): + """Fused READY+REUSE fence for gfx1250 multi-buffer pipeline. + + Issues ``s_wait_tensorcnt`` followed by the appropriate barrier. 
+ """ + tdm_ops.tensor_wait(outstanding) + workgroup_barrier(use_cluster=use_cluster) + + WGP_BARRIER_ID = -1 @@ -146,10 +156,15 @@ def issue_tdm_loads(*descs, wave_specialized=False, wave_id=None): if wave_id is None: wave_id = rocdl.wave_id() for idx, desc in enumerate(descs): - if arith.cmpi( - arith.CmpIPredicate.eq, wave_id, arith.constant(idx, type=T.i32) - ): + is_loader_wave = arith.cmpi( + arith.CmpIPredicate.eq, + wave_id, + arith.constant(idx, type=T.i32), + ) + if_op = scf.IfOp(is_loader_wave) + with ir.InsertionPoint(if_op.then_block): tdm_ops.tensor_load_2d(desc) + scf.YieldOp([]) return for desc in descs: @@ -234,6 +249,7 @@ def store_acc_vec8_to_buffer(acc_vec8, c_rsrc, addr, "lds_load_b128_raw", "lds_transpose_load_raw", # Pipeline + "workgroup_barrier", "pipeline_fence", "pipeline_fence_signal", "pipeline_fence_wait", diff --git a/kernels/moe_gemm_2stage_common_gfx1250.py b/kernels/moe_gemm_2stage_common_gfx1250.py new file mode 100644 index 00000000..e6cb4a59 --- /dev/null +++ b/kernels/moe_gemm_2stage_common_gfx1250.py @@ -0,0 +1,1089 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + + +"""Shared utilities for gfx1250 MoE 2-stage kernels. + +Common helpers used by both the fp16 WMMA kernels and the mxscale +(fp4/fp8/a8w4) kernels. +""" + +from __future__ import annotations + +import inspect +from typing import Any + +from flydsl.runtime.device import get_rocm_arch as get_hip_arch + + +def _require_gfx1250() -> None: + arch = str(get_hip_arch()) + if not arch.startswith("gfx1250"): + raise RuntimeError(f"Expected gfx1250 architecture, got {arch!r}") + + +def _align_up(v: int, a: int) -> int: + return ((int(v) + int(a) - 1) // int(a)) * int(a) + + +def _pick_fp4_warp_shape(tile_m: int, tile_n: int) -> tuple[int, int]: + """Pick a legal (m_warp, n_warp) for compile_mxfp4_gemm constraints.""" + for m_warp in (4, 2, 1): + if tile_m % m_warp != 0: + continue + warp_tile_m = tile_m // m_warp + if (warp_tile_m % 16) != 0: + continue + for n_warp in (4, 2, 1): + if tile_n % n_warp != 0: + continue + warp_tile_n = tile_n // n_warp + if (warp_tile_n % 32) == 0: + return m_warp, n_warp + raise ValueError( + f"Cannot find legal (m_warp,n_warp) for FP4 GEMM with tile_m={tile_m}, tile_n={tile_n}. " + "Need warp_tile_m multiple of 16 and warp_tile_n multiple of 32." + ) + + +def _pick_fp16_single_launch_shape(route_tile_m: int, route_tile_n: int, + max_total_warps: int = 0) -> tuple[int, int, int, int]: + """Pick launch shape for fp16 stage1 single-kernel path. + + Single-kernel path should follow route tile size (not backend-expanded 128x*) + while keeping legal WMMA tile decomposition. 
+ """ + tile_m = _align_up(int(route_tile_m), 16) + tile_n = _align_up(int(route_tile_n), 16) + for mw in (4, 2, 1): + if tile_m % mw != 0: + continue + if (tile_m // mw) % 16 != 0: + continue + for nw in (8, 4, 2, 1): + if max_total_warps > 0 and mw * nw > max_total_warps: + continue + if tile_n % nw != 0: + continue + if (tile_n // nw) % 16 != 0: + continue + return tile_m, tile_n, mw, nw + raise ValueError( + f"Cannot find legal single-kernel fp16 shape for tile_m={route_tile_m}, tile_n={route_tile_n}" + ) + + +def _compile_with_optional_wpe(fn, kwargs: dict[str, Any]): + sig = inspect.signature(fn) + if "waves_per_eu" not in sig.parameters: + kwargs = {k: v for k, v in kwargs.items() if k != "waves_per_eu"} + return fn(**kwargs) + + +def _bf16_to_f16_wrapper(fp16_exe, x_arg: int, w_arg: int): + """Wrap a compiled fp16 kernel to accept bf16 inputs by converting them to fp16 on the host.""" + import torch + + def wrapper(*args, **kwargs): + args = list(args) + for idx in (x_arg, w_arg): + if idx < len(args) and hasattr(args[idx], 'dtype') and args[idx].dtype == torch.bfloat16: + args[idx] = args[idx].to(torch.float16) + return fp16_exe(*args, **kwargs) + + for attr in ('mode',): + if hasattr(fp16_exe, attr): + setattr(wrapper, attr, getattr(fp16_exe, attr)) + return wrapper + + +def _pick_mxscale_launch_shape(data_format: str, route_tile_m: int, tile_n: int) -> tuple[int, int, int, int]: + if data_format not in ("fp4", "fp8", "a8w4"): + raise ValueError(f"data_format must be 'fp4', 'fp8', or 'a8w4', got {data_format!r}") + if data_format == "fp4": + single_tile_m = _align_up(int(route_tile_m), 16) + single_tile_n = _align_up(int(tile_n), 32) + single_m_warp, single_n_warp = _pick_fp4_warp_shape(single_tile_m, single_tile_n) + return single_tile_m, single_tile_n, single_m_warp, single_n_warp + return _pick_fp16_single_launch_shape(int(route_tile_m), int(tile_n), max_total_warps=8) + + +def _make_moe_wave_layout(*, m_warp: int, n_warp: int, WAVE_SIZE: int, fx): + return fx.make_layout( + (int(m_warp), int(n_warp), 2, 16), + (int(n_warp) * WAVE_SIZE, WAVE_SIZE, 16, 1), + ) + + +def _make_wmma_sub_tiles( + *, wmma_m_rep: int, wmma_n_rep: int, WMMA_M: int, is_fp4: bool +) -> list[tuple[int, int, int, int]]: + sub_tiles = [] + for wm in range(wmma_m_rep): + for wn in range(wmma_n_rep): + if is_fp4: + for half in range(2): + sub_tiles.append((wm * wmma_n_rep + wn, half * 8, wm * WMMA_M, wn * 2 + half)) + else: + sub_tiles.append((wm * wmma_n_rep + wn, 0, wm * WMMA_M, wn)) + return sub_tiles + + +def _moe_out_elem_ty(out_dtype: str, T): + return T.f16 if out_dtype == "f16" else T.bf16 + + +def _extract_sub8(acc, vec_base: int, *, vector, range_constexpr, ACC_VEC_SIZE: int): + if ACC_VEC_SIZE == 8: + return acc + return vector.shuffle(acc, acc, [vec_base + i for i in range_constexpr(8)]) + + +def _finalize_alloc_and_launch_2d(*, ctx, alloc, launcher, gx, gy, block_threads: int, stream, waves_per_eu, ir, + cluster=None): + with ir.InsertionPoint(ctx.gpu_module_body): + alloc.finalized = False + alloc.finalize() + for op in ctx.gpu_module_body.operations: + if hasattr(op, "attributes") and op.OPERATION_NAME == "gpu.func": + if waves_per_eu is not None and int(waves_per_eu) >= 1: + op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get( + ir.IntegerType.get_signless(32), int(waves_per_eu) + ) + if cluster is not None: + op.attributes["rocdl.cluster_dims"] = ir.StringAttr.get( + f"{cluster[0]},{cluster[1]},{cluster[2]}") + launcher.launch( + grid=(gx, gy, 1), + block=(block_threads, 1, 1), + 
stream=stream, + cluster=cluster, + ) + + +def _emit_stage1_gate_up_epilogue( + *, + sub_tiles, + by, + tile_m: int, + route_tile_m: int, + warp_m_base, + warp_n_base, + blk_n, + lane16, + lane_kgrp, + WMMA_N: int, + i32_tokens_in, + i32_inter_in, + topk: int, + num_valid_i32=None, + block_row_start=None, + sorted_rsrc, + tw_rsrc, + out_rsrc, + doweight_stage1: bool, + out_elem_ty, + load_gate_up_sub8, + silu_fn, + ir, + fx, + arith, + buffer_ops, + scf, + vector, + range_constexpr, + T, +): + c_topk_i32 = arith.constant(int(topk), type=T.i32) + default_block_row_start = arith.index_cast(T.i32, by * arith.index(int(route_tile_m))) + row_base_i32 = block_row_start if block_row_start is not None else default_block_row_start + for acc_idx, vec_base, m_off, wn in sub_tiles: + row_local = warp_m_base + fx.Index(m_off) + lane16 + sorted_row = by * arith.index(int(tile_m)) + row_local + row_i32 = arith.index_cast(T.i32, row_local) + sorted_i32 = arith.index_cast(T.i32, sorted_row) + row_in_route = arith.cmpi( + arith.CmpIPredicate.ult, + row_i32, + arith.constant(int(route_tile_m), type=T.i32), + ) + if num_valid_i32 is None: + row_ok_meta = row_in_route + else: + row_in_valid = arith.cmpi(arith.CmpIPredicate.slt, sorted_i32, num_valid_i32) + row_ok_meta = arith.andi(row_in_route, row_in_valid) + sorted_safe = arith.select( + row_ok_meta, + sorted_i32, + row_base_i32, + ) + fused = buffer_ops.buffer_load(sorted_rsrc, sorted_safe, vec_width=1, dtype=T.i32) + tok = fused & arith.constant((1 << 24) - 1, type=T.i32) + slot = fused >> arith.constant(24, type=T.i32) + tok_ok = arith.cmpi(arith.CmpIPredicate.ult, tok, i32_tokens_in) + slot_ok0 = arith.cmpi(arith.CmpIPredicate.sge, slot, arith.constant(0, type=T.i32)) + slot_ok1 = arith.cmpi(arith.CmpIPredicate.slt, slot, arith.constant(int(topk), type=T.i32)) + row_ok = arith.andi(row_ok_meta, arith.andi(tok_ok, arith.andi(slot_ok0, slot_ok1))) + sub8g, sub8u = load_gate_up_sub8(acc_idx, vec_base) + tw = buffer_ops.buffer_load(tw_rsrc, sorted_safe, vec_width=1, dtype=T.f32) if bool(doweight_stage1) else arith.constant(1.0, type=T.f32) + col_base = blk_n + warp_n_base + fx.Index(wn * WMMA_N) + lane_kgrp * fx.Index(8) + for vi in range_constexpr(8): + col = col_base + fx.Index(vi) + col_ok = arith.cmpi(arith.CmpIPredicate.ult, arith.index_cast(T.i32, col), i32_inter_in) + out_ok = arith.andi(row_ok, col_ok) + _if_out = scf.IfOp(out_ok) + with ir.InsertionPoint(_if_out.then_block): + vg = vector.extract(sub8g, static_position=[vi], dynamic_position=[]) + vu = vector.extract(sub8u, static_position=[vi], dynamic_position=[]) + y = silu_fn(vg) * vu + if bool(doweight_stage1): + y = y * tw + out_v = arith.trunc_f(out_elem_ty, y) + out_idx = ((tok * c_topk_i32 + slot) * i32_inter_in + + arith.index_cast(T.i32, col)) + buffer_ops.buffer_store(out_v, out_rsrc, out_idx) + scf.YieldOp([]) + + +def _emit_stage2_store_epilogue( + *, + sub_tiles, + by, + tile_m: int, + route_tile_m: int, + warp_m_base, + warp_n_base, + blk_n, + lane16, + lane_kgrp, + WMMA_N: int, + i32_tokens_in, + i32_n_in, + topk: int, + num_valid_i32, + block_row_start, + sorted_rsrc, + tw_rsrc, + out_rsrc, + doweight_stage2: bool, + accumulate: bool, + out_elem_ty, + load_sub8, + ir, + fx, + arith, + buffer_ops, + scf, + vector, + range_constexpr, + rocdl, + T, +): + c_topk_i32 = arith.constant(int(topk), type=T.i32) + c2_i32 = arith.constant(2, type=T.i32) + zero_i32 = arith.constant(0, type=T.i32) + mask_even_i32 = arith.constant(0xFFFFFFFE, type=T.i32) + + def atomic_add_x2(val_x2, 
byte_off_i32): + rocdl.raw_ptr_buffer_atomic_fadd(val_x2, out_rsrc, byte_off_i32, zero_i32, zero_i32) + + for acc_idx, vec_base, m_off, wn in sub_tiles: + row_local = warp_m_base + fx.Index(m_off) + lane16 + sorted_row = by * arith.index(int(tile_m)) + row_local + row_i32 = arith.index_cast(T.i32, row_local) + sorted_i32 = arith.index_cast(T.i32, sorted_row) + row_in_route = arith.cmpi(arith.CmpIPredicate.ult, row_i32, arith.constant(int(route_tile_m), type=T.i32)) + row_in_valid = arith.cmpi(arith.CmpIPredicate.slt, sorted_i32, num_valid_i32) + row_ok = arith.andi(row_in_route, row_in_valid) + sorted_safe = arith.select(row_ok, sorted_i32, block_row_start) + fused = buffer_ops.buffer_load(sorted_rsrc, sorted_safe, vec_width=1, dtype=T.i32) + tok = fused & arith.constant((1 << 24) - 1, type=T.i32) + slot = fused >> arith.constant(24, type=T.i32) + tok_ok = arith.cmpi(arith.CmpIPredicate.ult, tok, i32_tokens_in) + slot_ok0 = arith.cmpi(arith.CmpIPredicate.sge, slot, arith.constant(0, type=T.i32)) + slot_ok1 = arith.cmpi(arith.CmpIPredicate.slt, slot, c_topk_i32) + row_store_ok = arith.andi(row_ok, arith.andi(tok_ok, arith.andi(slot_ok0, slot_ok1))) + ts = tok * c_topk_i32 + slot + sub8 = load_sub8(acc_idx, vec_base) + tw = buffer_ops.buffer_load(tw_rsrc, sorted_safe, vec_width=1, dtype=T.f32) if bool(doweight_stage2) else arith.constant(1.0, type=T.f32) + col_base = blk_n + warp_n_base + fx.Index(wn * WMMA_N) + lane_kgrp * fx.Index(8) + if bool(accumulate): + for vpair in range_constexpr(4): + vi0 = vpair * 2 + vi1 = vi0 + 1 + col0 = col_base + fx.Index(vi0) + col1 = col_base + fx.Index(vi1) + col0_i32 = arith.index_cast(T.i32, col0) + col1_i32 = arith.index_cast(T.i32, col1) + col0_ok = arith.cmpi(arith.CmpIPredicate.ult, col0_i32, i32_n_in) + col1_ok = arith.cmpi(arith.CmpIPredicate.ult, col1_i32, i32_n_in) + out_ok = arith.andi(row_store_ok, col0_ok) + _if_out = scf.IfOp(out_ok) + with ir.InsertionPoint(_if_out.then_block): + v0 = vector.extract(sub8, static_position=[vi0], dynamic_position=[]) + v1 = vector.extract(sub8, static_position=[vi1], dynamic_position=[]) + if bool(doweight_stage2): + v0 = v0 * tw + v1 = v1 * tw + v1 = arith.select(col1_ok, v1, arith.constant(0.0, type=T.f32)) + out0 = arith.trunc_f(out_elem_ty, v0) + out1 = arith.trunc_f(out_elem_ty, v1) + frag = vector.from_elements(T.vec(2, out_elem_ty), [out0, out1]) + idx0 = tok * i32_n_in + col0_i32 + idx_even = idx0 & mask_even_i32 + byte_off = idx_even * c2_i32 + atomic_add_x2(frag, byte_off) + scf.YieldOp([]) + else: + for vi in range_constexpr(8): + col = col_base + fx.Index(vi) + col_ok = arith.cmpi(arith.CmpIPredicate.ult, arith.index_cast(T.i32, col), i32_n_in) + out_ok = arith.andi(row_store_ok, col_ok) + _if_out = scf.IfOp(out_ok) + with ir.InsertionPoint(_if_out.then_block): + v = vector.extract(sub8, static_position=[vi], dynamic_position=[]) + if bool(doweight_stage2): + v = v * tw + col_i32 = arith.index_cast(T.i32, col) + out_idx = ts * i32_n_in + col_i32 + out_v = arith.trunc_f(out_elem_ty, v) + buffer_ops.buffer_store(out_v, out_rsrc, out_idx) + scf.YieldOp([]) + + +def _pack_stage1_gate_up_tiles(tensor, *, experts: int, inter_dim: int, tile_n: int, cols: int): + """Pack stage1 gate/up rows into [gate_tile0, up_tile0, gate_tile1, up_tile1, ...].""" + import torch + + if tensor is None: + return None + if not isinstance(tensor, torch.Tensor): + raise TypeError(f"Expected torch.Tensor for stage1 gate/up packing, got {type(tensor)!r}") + if tensor.numel() == 0: + return tensor + elems_per_expert = int(2 * 
inter_dim) * int(cols) + if tensor.numel() != int(experts) * elems_per_expert: + if tensor.numel() % elems_per_expert != 0: + raise ValueError( + "Unexpected stage1 tensor size for gate/up packing: " + f"numel={tensor.numel()} expected={int(experts) * elems_per_expert} " + f"(experts={experts}, inter_dim={inter_dim}, cols={cols})" + ) + experts = tensor.numel() // elems_per_expert + expected_rows = int(experts) * int(2 * inter_dim) + if int(inter_dim) % int(tile_n) != 0: + raise ValueError( + f"Stage1 gate/up packed layout requires inter_dim divisible by tile_n, got {inter_dim} and {tile_n}" + ) + + tensor_3d = tensor.contiguous().view(int(experts), int(2 * inter_dim), int(cols)) + gate = tensor_3d[:, :int(inter_dim), :] + up = tensor_3d[:, int(inter_dim):, :] + gate_tiles = gate.view(int(experts), int(inter_dim // tile_n), int(tile_n), int(cols)) + up_tiles = up.view(int(experts), int(inter_dim // tile_n), int(tile_n), int(cols)) + packed = torch.cat((gate_tiles, up_tiles), dim=2) + return packed.view(expected_rows, int(cols)) + + +class _Stage1GateUpPackedWrapper: + """Host-side wrapper that repacks stage1 W1 rows to match the merged gate/up TDM layout.""" + + def __init__( + self, + stage1_exe, + *, + experts: int, + inter_dim: int, + tile_n: int, + packed_cols_w: int, + packed_cols_scale: int, + ): + self._stage1_exe = stage1_exe + self._experts = int(experts) + self._inter_dim = int(inter_dim) + self._tile_n = int(tile_n) + self._packed_cols_w = int(packed_cols_w) + self._packed_cols_scale = int(packed_cols_scale) + self._cache = {} + + for attr in ("mode", "compile_hints"): + if hasattr(stage1_exe, attr): + setattr(self, attr, getattr(stage1_exe, attr)) + + def _get_packed_operands(self, arg_w, arg_scale_w): + key = (id(arg_w), id(arg_scale_w)) + cached = self._cache.get(key) + if cached is not None: + return cached[0] + + packed_w = _pack_stage1_gate_up_tiles( + arg_w, + experts=self._experts, + inter_dim=self._inter_dim, + tile_n=self._tile_n, + cols=self._packed_cols_w, + ) + if hasattr(arg_scale_w, "numel") and int(arg_scale_w.numel()) > 0: + packed_scale_w = _pack_stage1_gate_up_tiles( + arg_scale_w, + experts=self._experts, + inter_dim=self._inter_dim, + tile_n=self._tile_n, + cols=self._packed_cols_scale, + ) + else: + packed_scale_w = arg_scale_w + + # Store (result, original_refs) — the strong refs to originals + # prevent id() reuse while the entry is alive. 
+ self._cache[key] = ((packed_w, packed_scale_w), (arg_w, arg_scale_w)) + return packed_w, packed_scale_w + + def __call__(self, *args, **kwargs): + args = list(args) + if len(args) > 4: + args[2], args[4] = self._get_packed_operands(args[2], args[4]) + return self._stage1_exe(*args, **kwargs) + + +# --------------------------------------------------------------------------- +# MXScale format infrastructure helpers +# --------------------------------------------------------------------------- + + +def _mxscale_format_config(data_format: str) -> dict[str, int | bool]: + if data_format not in ("fp4", "fp8", "a8w4"): + raise ValueError(f"data_format must be 'fp4', 'fp8', or 'a8w4', got {data_format!r}") + is_fp4 = data_format == "fp4" + is_a8w4 = data_format == "a8w4" + pack_factor_a = 1 if not is_fp4 else 2 + pack_factor_b = 2 if (is_fp4 or is_a8w4) else 1 + wmma_n_eff = 32 if is_fp4 else 16 + acc_vec_size = 16 if is_fp4 else 8 + ds_loads_per_a_frag = 2 if is_fp4 else 4 + return { + "is_fp4": is_fp4, + "is_a8w4": is_a8w4, + "PACK_FACTOR_A": pack_factor_a, + "PACK_FACTOR_B": pack_factor_b, + "WMMA_N_EFF": wmma_n_eff, + "ACC_VEC_SIZE": acc_vec_size, + "DS_LOADS_PER_A_FRAG": ds_loads_per_a_frag, + } + + +def _mxscale_precompute_preshuffled_b_data_bases( + *, + packed_tile_k_b: int, + warp_tile_n, + wave_n_idx, + lane16, + lane_kgrp, + wmma_n_rep: int, + arith, + range_constexpr, +): + ngroup_stride = packed_tile_k_b * 16 + n_group_base = arith.index(warp_tile_n // 16) * wave_n_idx + row_off = lane16 * arith.index(16) + k_tile_off = lane_kgrp * arith.index(256) + bases = [] + for wn in range_constexpr(wmma_n_rep): + ngroup_off = n_group_base * arith.index(ngroup_stride) + arith.index(wn * ngroup_stride) + bases.append(ngroup_off + row_off + k_tile_off) + return bases + + +def _mxscale_precompute_a_scale_lane_bases( + *, + warp_m_base, + lane16, + wmma_m_rep: int, + interleaved_scale_cols_a: int, + arith, +): + warp_lds_row = warp_m_base / arith.index(wmma_m_rep) + lane16 + base = warp_lds_row * arith.index(interleaved_scale_cols_a) + return [base] + + +def _mxscale_load_scale_b128( + *, + lds_buffer, + scale_base, + reps: int, + ks, + SCALES_PER_WMMA: int, + _lds_load_b128, + arith, + vector, + range_constexpr, +): + ks_byte_off = ks * reps * SCALES_PER_WMMA + eff_base = scale_base if ks_byte_off == 0 else scale_base + arith.index(ks_byte_off) + num_loads = (reps + 3) // 4 + vecs = [] + for ld in range_constexpr(num_loads): + off = eff_base if ld == 0 else eff_base + arith.index(ld * 16) + vecs.append(_lds_load_b128(lds_buffer, off)) + results = [] + for i in range_constexpr(reps): + vi = vector.extract(vecs[i // 4], static_position=[i % 4], dynamic_position=[]) + results.append(vi) + return results + + +def _mxscale_load_preshuffled_b_frag( + *, + lds_buffer, + b_lane_bases, + wn: int, + ks, + is_fp4: bool, + is_a8w4: bool, + PACK_FACTOR_B: int, + WMMA_K: int, + _lds_load_b128, + arith, + vector, +): + num_tiles = WMMA_K // PACK_FACTOR_B // 16 + k_subtile_off = arith.index(ks * num_tiles * 256) + if is_fp4: + base0 = b_lane_bases[wn * 2] + k_subtile_off + base1 = b_lane_bases[wn * 2 + 1] + k_subtile_off + v0 = _lds_load_b128(lds_buffer, base0) + v1 = _lds_load_b128(lds_buffer, base0 + arith.index(512)) + v2 = _lds_load_b128(lds_buffer, base1) + v3 = _lds_load_b128(lds_buffer, base1 + arith.index(512)) + v01 = vector.shuffle(v0, v1, list(range(8))) + v23 = vector.shuffle(v2, v3, list(range(8))) + return vector.shuffle(v01, v23, list(range(16))) + base0 = b_lane_bases[wn] + k_subtile_off + v0 = 
_lds_load_b128(lds_buffer, base0) + v1 = _lds_load_b128(lds_buffer, base0 + arith.index(512)) + if is_a8w4: + return vector.shuffle(v0, v1, list(range(8))) + v2 = _lds_load_b128(lds_buffer, base0 + arith.index(1024)) + v3 = _lds_load_b128(lds_buffer, base0 + arith.index(1536)) + v01 = vector.shuffle(v0, v1, list(range(8))) + v23 = vector.shuffle(v2, v3, list(range(8))) + return vector.shuffle(v01, v23, list(range(16))) + + +def _mxscale_load_scale_i32( + *, + lds_buffer, + scale_base, + ks, + SCALES_PER_WMMA: int, + llvm_dialect, + ir, + arith, + T, +): + byte_off = scale_base + arith.index(ks * SCALES_PER_WMMA) + ptr_val = _mxscale_lds_ptr(lds_buffer, byte_off, ir=ir, arith=arith, T=T) + return llvm_dialect.load(ir.IntegerType.get_signless(32), ptr_val) + + +def _mxscale_precompute_a_data_bases( + *, + warp_m_base, + lane16, + lane_kgrp, + lds_a_stride_bytes: int, + wmma_m_rep: int, + WMMA_M: int, + is_fp4: bool, + arith, + range_constexpr, +): + row_base = (warp_m_base + lane16) * arith.index(lds_a_stride_bytes) + k_half_off = lane_kgrp * arith.index(32 if is_fp4 else 16) + return [ + row_base + arith.index(wm * WMMA_M * lds_a_stride_bytes) + k_half_off + for wm in range_constexpr(wmma_m_rep) + ] + + +def _mxscale_precompute_rowmajor_b_data_bases( + *, + warp_n_base, + lane16, + lane_kgrp, + lds_b_stride_bytes: int, + wmma_n_rep: int, + WMMA_N: int, + arith, + range_constexpr, +): + return [ + (warp_n_base + lane16) * arith.index(lds_b_stride_bytes) + + lane_kgrp * arith.index(32) + + arith.index(wnh * WMMA_N * lds_b_stride_bytes) + for wnh in range_constexpr(wmma_n_rep * 2) + ] + + +def _mxscale_precompute_rowmajor_scale_lane_bases( + *, + warp_base, + lane16, + scale_k_per_tile: int, + reps: int, + WMMA_DIM: int, + arith, + range_constexpr, +): + return [ + (warp_base + lane16) * arith.index(int(scale_k_per_tile)) + + arith.index(r * WMMA_DIM * int(scale_k_per_tile)) + for r in range_constexpr(reps) + ] + + +def _mxscale_lds_ptr(lds_buffer, byte_offset, *, ir, arith, T): + """Compute an ``!llvm.ptr<3>`` into LDS at *byte_offset*.""" + from flydsl._mlir.dialects import llvm as _llvm, memref as _memref + from flydsl.expr.arith import _to_raw as _raw + from flydsl.expr.arith import ArithValue as _AV + lds_ptr_ty = ir.Type.parse("!llvm.ptr<3>") + raw_memref = arith.unwrap(lds_buffer) + lds_base = _memref.extract_aligned_pointer_as_index(raw_memref) + total_byte = _AV(lds_base) + byte_offset + addr_i32 = _raw(arith.index_cast(T.i32, total_byte)) + return _llvm.inttoptr(lds_ptr_ty, addr_i32) + + +def _mxscale_lds_load_b128(lds_buffer, byte_offset, *, ir, arith, T, llvm_dialect): + """Load a vec4 (16 bytes) from LDS at the given byte offset.""" + ptr_val = _mxscale_lds_ptr(lds_buffer, byte_offset, ir=ir, arith=arith, T=T) + return llvm_dialect.load( + ir.VectorType.get([4], ir.IntegerType.get_signless(32)), ptr_val, + ) + + +def _mxscale_load_data_frag( + *, + lds_buffer, + lane_base, + ks, + PACK_FACTOR_A: int, + WMMA_K: int, + is_fp4: bool, + _lds_load_b128, + arith, + vector, +): + byte_off = lane_base + arith.index(ks * WMMA_K // PACK_FACTOR_A) + v0 = _lds_load_b128(lds_buffer, byte_off) + if is_fp4: + v1 = _lds_load_b128(lds_buffer, byte_off + arith.index(16)) + return vector.shuffle(v0, v1, list(range(8))) + v1 = _lds_load_b128(lds_buffer, byte_off + arith.index(32)) + v2 = _lds_load_b128(lds_buffer, byte_off + arith.index(64)) + v3 = _lds_load_b128(lds_buffer, byte_off + arith.index(96)) + v01 = vector.shuffle(v0, v1, list(range(8))) + v23 = vector.shuffle(v2, v3, list(range(8))) + 
return vector.shuffle(v01, v23, list(range(16))) + + +def _mxscale_load_rowmajor_b_frag( + *, + lds_buffer, + b_lane_bases, + wn: int, + ks, + PACK_FACTOR_B: int, + WMMA_K: int, + _lds_load_b128, + arith, + vector, +): + k_byte_off = arith.index(ks * WMMA_K // PACK_FACTOR_B) + base0 = b_lane_bases[wn * 2] + k_byte_off + base1 = b_lane_bases[wn * 2 + 1] + k_byte_off + v0 = _lds_load_b128(lds_buffer, base0) + v1 = _lds_load_b128(lds_buffer, base0 + arith.index(16)) + v2 = _lds_load_b128(lds_buffer, base1) + v3 = _lds_load_b128(lds_buffer, base1 + arith.index(16)) + v01 = vector.shuffle(v0, v1, list(range(8))) + v23 = vector.shuffle(v2, v3, list(range(8))) + return vector.shuffle(v01, v23, list(range(16))) + + +def _mxscale_emit_wmma( + *, + accs, + wm: int, + wn: int, + a_frag, + b_frags, + a_scales, + b_scales, + is_fp4: bool, + is_a8w4: bool, + use_scale_opsel: bool, + rocdl, + T, +): + idx = wm * len(b_frags) + wn + if use_scale_opsel: + a_scale_idx = wm // 2 + a_opsel = wm % 2 + else: + a_scale_idx = wm + a_opsel = 0 + + if is_fp4: + accs[idx] = rocdl.wmma_scale_f32_32x16x128_f4( + T.vec(16, T.f32), + b_frags[wn], a_frag, accs[idx], + b_scales[wn * 2], a_scales[a_scale_idx], + scaleAType=0, + scaleBType=a_opsel, + ) + return + + if use_scale_opsel: + b_scale_idx = wn // 2 + b_opsel = wn % 2 + else: + b_scale_idx = wn + b_opsel = 0 + accs[idx] = rocdl.wmma_scale_f32_16x16x128_f8f6f4( + T.vec(8, T.f32), + b_frags[wn], a_frag, accs[idx], + b_scales[b_scale_idx], a_scales[a_scale_idx], + fmtA=4 if is_a8w4 else 0, + fmtB=0, + scaleAType=b_opsel, + scaleBType=a_opsel, + ) + + +# --------------------------------------------------------------------------- +# Shared tiling / pipeline / loader helpers for mxscale stage1 & stage2 +# --------------------------------------------------------------------------- + + +def _compute_mxscale_tiling( + *, + data_format: str, + K: int, + tile_m: int, + tile_n: int, + tile_k: int, + m_warp: int, + n_warp: int, + out_dtype: str, + num_buffers: int, + cluster_m: int = 1, + cluster_n: int = 1, + stage_name: str = "", +) -> dict: + """Derive all shared tiling / format constants for an mxscale stage kernel.""" + fmt_cfg = _mxscale_format_config(data_format) + is_fp4 = bool(fmt_cfg["is_fp4"]) + is_a8w4 = bool(fmt_cfg["is_a8w4"]) + PACK_FACTOR_A = int(fmt_cfg["PACK_FACTOR_A"]) + PACK_FACTOR_B = int(fmt_cfg["PACK_FACTOR_B"]) + ACC_VEC_SIZE = int(fmt_cfg["ACC_VEC_SIZE"]) + WMMA_N_EFF = int(fmt_cfg["WMMA_N_EFF"]) + DS_LOADS_PER_A_FRAG = int(fmt_cfg["DS_LOADS_PER_A_FRAG"]) + + WMMA_M, WMMA_N, WMMA_K = 16, 16, 128 + SCALE_BLOCK = 32 + SCALES_PER_WMMA = WMMA_K // SCALE_BLOCK + WAVE_SIZE = 32 + LDS_PAD_A_BYTES = 16 + LDS_PAD_B_BYTES = 16 if is_fp4 else 0 + + if out_dtype not in ("f16", "bf16"): + raise ValueError( + f"mxscale {stage_name} single kernel supports out_dtype " + f"in ('f16','bf16'), got {out_dtype!r}" + ) + if (K % int(tile_k)) != 0: + raise ValueError(f"K={K} must be divisible by tile_k={tile_k}") + if (int(tile_k) % WMMA_K) != 0: + raise ValueError(f"tile_k={tile_k} must be divisible by {WMMA_K}") + if (int(tile_k) % SCALE_BLOCK) != 0: + raise ValueError(f"tile_k={tile_k} must be divisible by {SCALE_BLOCK}") + if int(num_buffers) not in (1, 2, 3, 4): + raise ValueError(f"num_buffers must be 1, 2, 3, or 4, got {num_buffers}") + use_cluster = int(cluster_m) > 1 or int(cluster_n) > 1 + if use_cluster and int(cluster_m) * int(cluster_n) > 16: + raise ValueError( + f"cluster_m * cluster_n must be <= 16, got {cluster_m}*{cluster_n}" + ) + + K_packed_a = K // 
PACK_FACTOR_A + K_packed_b = K // PACK_FACTOR_B + packed_tile_k_a = int(tile_k) // PACK_FACTOR_A + packed_tile_k_b = int(tile_k) // PACK_FACTOR_B + K_scale = K // SCALE_BLOCK + scale_k_per_tile = int(tile_k) // SCALE_BLOCK + block_threads = int(m_warp) * int(n_warp) * WAVE_SIZE + warp_tile_m = int(tile_m) // int(m_warp) + warp_tile_n = int(tile_n) // int(n_warp) + wmma_m_rep = warp_tile_m // WMMA_M + wmma_n_rep = warp_tile_n // WMMA_N_EFF + k_wmma_steps = int(tile_k) // WMMA_K + n_accs = wmma_m_rep * wmma_n_rep + num_k_tiles = K // int(tile_k) + b_scale_load_rep = (wmma_n_rep * 2) if is_fp4 else wmma_n_rep + interleaved_scale_cols_b = b_scale_load_rep * scale_k_per_tile + + if wmma_m_rep <= 0 or wmma_n_rep <= 0: + raise ValueError( + f"Invalid warp tiling for mxscale {stage_name} single kernel: " + f"wmma_m_rep={wmma_m_rep}, wmma_n_rep={wmma_n_rep}" + ) + + lds_a_stride_bytes = packed_tile_k_a + LDS_PAD_A_BYTES + lds_b_stride_bytes = packed_tile_k_b + LDS_PAD_B_BYTES + lds_a_data_bytes = int(tile_m) * lds_a_stride_bytes + lds_b_data_bytes = int(tile_n) * lds_b_stride_bytes + lds_a_scale_bytes = int(tile_m) * scale_k_per_tile + lds_b_scale_bytes = int(tile_n) * scale_k_per_tile + interleaved_scale_cols_a = wmma_m_rep * scale_k_per_tile + + return dict( + is_fp4=is_fp4, is_a8w4=is_a8w4, + PACK_FACTOR_A=PACK_FACTOR_A, PACK_FACTOR_B=PACK_FACTOR_B, + ACC_VEC_SIZE=ACC_VEC_SIZE, WMMA_N_EFF=WMMA_N_EFF, + DS_LOADS_PER_A_FRAG=DS_LOADS_PER_A_FRAG, + WMMA_M=WMMA_M, WMMA_N=WMMA_N, WMMA_K=WMMA_K, + SCALE_BLOCK=SCALE_BLOCK, SCALES_PER_WMMA=SCALES_PER_WMMA, + WAVE_SIZE=WAVE_SIZE, + LDS_PAD_A_BYTES=LDS_PAD_A_BYTES, LDS_PAD_B_BYTES=LDS_PAD_B_BYTES, + use_cluster=use_cluster, + K=K, K_packed_a=K_packed_a, K_packed_b=K_packed_b, + packed_tile_k_a=packed_tile_k_a, packed_tile_k_b=packed_tile_k_b, + K_scale=K_scale, scale_k_per_tile=scale_k_per_tile, + block_threads=block_threads, + warp_tile_m=warp_tile_m, warp_tile_n=warp_tile_n, + wmma_m_rep=wmma_m_rep, wmma_n_rep=wmma_n_rep, + k_wmma_steps=k_wmma_steps, n_accs=n_accs, + num_k_tiles=num_k_tiles, + b_scale_load_rep=b_scale_load_rep, + interleaved_scale_cols_b=interleaved_scale_cols_b, + lds_a_stride_bytes=lds_a_stride_bytes, + lds_b_stride_bytes=lds_b_stride_bytes, + lds_a_data_bytes=lds_a_data_bytes, + lds_b_data_bytes=lds_b_data_bytes, + lds_a_scale_bytes=lds_a_scale_bytes, + lds_b_scale_bytes=lds_b_scale_bytes, + interleaved_scale_cols_a=interleaved_scale_cols_a, + ) + + +def _compute_pipeline_plan( + *, + num_k_tiles: int, + num_buffers: int, + B_TDM_PER_STEP: int, + tile_m: int, + use_tdm_gather: bool, + wave_specialized_tdm: bool, + tdm_loader_waves: int, +) -> dict: + """Compute pipeline pre-load / tail plan shared by mxscale stages.""" + from kernels.pipeline_utils import make_tail_plan + + pre_loaded = int(num_buffers) - 1 + loop_iters = (num_k_tiles - pre_loaded) // int(num_buffers) + tail_start = loop_iters * int(num_buffers) + extra = num_k_tiles - tail_start - pre_loaded + A_GATHER_GROUPS = (int(tile_m) + 7) // 8 if bool(use_tdm_gather) else 0 + if bool(use_tdm_gather) and bool(wave_specialized_tdm): + A_GATHER_TDM_PER_STEP = ( + (A_GATHER_GROUPS + tdm_loader_waves - 1) // tdm_loader_waves + ) + else: + A_GATHER_TDM_PER_STEP = A_GATHER_GROUPS + TDM_PER_STEP = B_TDM_PER_STEP + A_GATHER_TDM_PER_STEP + fence_outstanding = TDM_PER_STEP * (int(num_buffers) - 2) + base_tail_plan = make_tail_plan(int(num_buffers), pre_loaded, extra) + tail_plan = [ + (ls, cs, o * TDM_PER_STEP // 2 if o > 0 else o) + for ls, cs, o in base_tail_plan + ] + if num_k_tiles < 
int(num_buffers): + raise ValueError( + f"{num_buffers}-stage buffering requires num_k_tiles >= {num_buffers}, " + f"got {num_k_tiles}" + ) + return dict( + pre_loaded=pre_loaded, + loop_iters=loop_iters, + tail_start=tail_start, + extra=extra, + A_GATHER_GROUPS=A_GATHER_GROUPS, + TDM_PER_STEP=TDM_PER_STEP, + fence_outstanding=fence_outstanding, + tail_plan=tail_plan, + ) + + +def _compute_tdm_store_layout( + *, + warp_tile_m: int, + warp_tile_n: int, + num_warps: int, + WMMA_N: int, + use_pipeline: bool, +) -> dict: + """Compute TDM-store D output LDS layout, shared by mxscale stages.""" + LDS_PAD_D_BYTES = 16 + elem_bytes_d = 2 # f16/bf16 + lds_d_row_stride = warp_tile_n * elem_bytes_d + LDS_PAD_D_BYTES + warp_d_bytes = warp_tile_m * lds_d_row_stride + total_d_bytes = num_warps * warp_d_bytes + return dict( + lds_d_row_stride=lds_d_row_stride, + warp_d_bytes=warp_d_bytes, + total_d_bytes=total_d_bytes, + d_output_off=0, + lds_d_stride_elems=lds_d_row_stride // 2, + warp_d_elems=warp_d_bytes // 2, + n_col_d_elems=WMMA_N * elem_bytes_d // 2, + d_need_epilogue_fence=use_pipeline, + ) + + +def _make_mxscale_data_loaders( + *, + tiling: dict, + warp_m_base, + warp_n_base, + wave_n_idx, + lane16, + lane_kgrp, + ir, + arith, + vector, + llvm_dialect, + T, + range_constexpr, +) -> dict: + """Create the 9 LDS data-loading adapter closures shared by mxscale stages. + + Returns a dict whose keys match the local names used inside the + ``moe_mxscale_stage*_single`` kernel functions. + """ + is_fp4 = tiling["is_fp4"] + is_a8w4 = tiling["is_a8w4"] + PACK_FACTOR_A = tiling["PACK_FACTOR_A"] + PACK_FACTOR_B = tiling["PACK_FACTOR_B"] + WMMA_K = tiling["WMMA_K"] + WMMA_M = tiling["WMMA_M"] + WMMA_N = tiling["WMMA_N"] + SCALES_PER_WMMA = tiling["SCALES_PER_WMMA"] + lds_a_stride_bytes = tiling["lds_a_stride_bytes"] + lds_b_stride_bytes = tiling["lds_b_stride_bytes"] + packed_tile_k_b = tiling["packed_tile_k_b"] + warp_tile_n = tiling["warp_tile_n"] + wmma_m_rep = tiling["wmma_m_rep"] + wmma_n_rep = tiling["wmma_n_rep"] + scale_k_per_tile = tiling["scale_k_per_tile"] + interleaved_scale_cols_a = tiling["interleaved_scale_cols_a"] + + def _lds_load_b128(lds_buffer, byte_offset): + return _mxscale_lds_load_b128( + lds_buffer, byte_offset, + ir=ir, arith=arith, T=T, llvm_dialect=llvm_dialect, + ) + + def load_data_frag(lds_buffer, lane_base, ks): + return _mxscale_load_data_frag( + lds_buffer=lds_buffer, lane_base=lane_base, ks=ks, + PACK_FACTOR_A=PACK_FACTOR_A, WMMA_K=WMMA_K, is_fp4=is_fp4, + _lds_load_b128=_lds_load_b128, arith=arith, vector=vector, + ) + + def load_b_frag(lds_buffer, b_lane_bases, wn, ks): + if is_fp4: + return _mxscale_load_rowmajor_b_frag( + lds_buffer=lds_buffer, b_lane_bases=b_lane_bases, + wn=wn, ks=ks, + PACK_FACTOR_B=PACK_FACTOR_B, WMMA_K=WMMA_K, + _lds_load_b128=_lds_load_b128, arith=arith, vector=vector, + ) + return _mxscale_load_preshuffled_b_frag( + lds_buffer=lds_buffer, b_lane_bases=b_lane_bases, + wn=wn, ks=ks, + is_fp4=is_fp4, is_a8w4=is_a8w4, + PACK_FACTOR_B=PACK_FACTOR_B, WMMA_K=WMMA_K, + _lds_load_b128=_lds_load_b128, arith=arith, vector=vector, + ) + + def load_scale_i32(lds_buffer, scale_base, ks): + return _mxscale_load_scale_i32( + lds_buffer=lds_buffer, scale_base=scale_base, ks=ks, + SCALES_PER_WMMA=SCALES_PER_WMMA, + llvm_dialect=llvm_dialect, ir=ir, arith=arith, T=T, + ) + + def _precompute_a_data_bases(): + return _mxscale_precompute_a_data_bases( + warp_m_base=warp_m_base, lane16=lane16, lane_kgrp=lane_kgrp, + lds_a_stride_bytes=lds_a_stride_bytes, 
wmma_m_rep=wmma_m_rep, + WMMA_M=WMMA_M, is_fp4=is_fp4, + arith=arith, range_constexpr=range_constexpr, + ) + + def _precompute_b_data_bases(): + if is_fp4: + return _mxscale_precompute_rowmajor_b_data_bases( + warp_n_base=warp_n_base, lane16=lane16, lane_kgrp=lane_kgrp, + lds_b_stride_bytes=lds_b_stride_bytes, wmma_n_rep=wmma_n_rep, + WMMA_N=WMMA_N, arith=arith, range_constexpr=range_constexpr, + ) + return _mxscale_precompute_preshuffled_b_data_bases( + packed_tile_k_b=packed_tile_k_b, warp_tile_n=warp_tile_n, + wave_n_idx=wave_n_idx, lane16=lane16, lane_kgrp=lane_kgrp, + wmma_n_rep=wmma_n_rep, arith=arith, range_constexpr=range_constexpr, + ) + + def _precompute_a_scale_lane_bases(): + if is_fp4: + return _mxscale_precompute_rowmajor_scale_lane_bases( + warp_base=warp_m_base, lane16=lane16, + scale_k_per_tile=scale_k_per_tile, reps=wmma_m_rep, + WMMA_DIM=WMMA_M, arith=arith, range_constexpr=range_constexpr, + ) + return _mxscale_precompute_a_scale_lane_bases( + warp_m_base=warp_m_base, lane16=lane16, + wmma_m_rep=wmma_m_rep, + interleaved_scale_cols_a=interleaved_scale_cols_a, arith=arith, + ) + + def _precompute_b_scale_lane_bases(): + return _mxscale_precompute_rowmajor_scale_lane_bases( + warp_base=warp_n_base, lane16=lane16, + scale_k_per_tile=scale_k_per_tile, reps=wmma_n_rep * 2, + WMMA_DIM=WMMA_N, arith=arith, range_constexpr=range_constexpr, + ) + + def load_scale_b128(lds_buffer, scale_base, reps, ks=0): + return _mxscale_load_scale_b128( + lds_buffer=lds_buffer, scale_base=scale_base, + reps=reps, ks=ks, SCALES_PER_WMMA=SCALES_PER_WMMA, + _lds_load_b128=_lds_load_b128, + arith=arith, vector=vector, range_constexpr=range_constexpr, + ) + + return dict( + _lds_load_b128=_lds_load_b128, + load_data_frag=load_data_frag, + load_b_frag=load_b_frag, + load_scale_i32=load_scale_i32, + _precompute_a_data_bases=_precompute_a_data_bases, + _precompute_b_data_bases=_precompute_b_data_bases, + _precompute_a_scale_lane_bases=_precompute_a_scale_lane_bases, + _precompute_b_scale_lane_bases=_precompute_b_scale_lane_bases, + load_scale_b128=load_scale_b128, + ) diff --git a/kernels/moe_gemm_2stage_mxscale_gfx1250.py b/kernels/moe_gemm_2stage_mxscale_gfx1250.py new file mode 100644 index 00000000..f3483ab8 --- /dev/null +++ b/kernels/moe_gemm_2stage_mxscale_gfx1250.py @@ -0,0 +1,2889 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + + +"""gfx1250 MoE 2-stage mxscale kernels (fp4/fp8/a8w4). + +Implements stage1/stage2 single-kernel inline paths using the +``wmma_scale_f32_16x16x128_f8f6f4`` and ``wmma_scale_f32_32x16x128_f4`` +instructions for microscaling block formats with E8M0 scales. 
+""" + +from __future__ import annotations + +import functools + +from flydsl.runtime.device import get_rocm_arch as get_hip_arch + +from kernels.moe_gemm_2stage import ( + MoeGemm2Mode, + compile_moe_reduction, +) +from kernels.moe_gemm_2stage_common_gfx1250 import ( + _Stage1GateUpPackedWrapper, + _compute_mxscale_tiling, + _compute_pipeline_plan, + _compute_tdm_store_layout, + _emit_stage1_gate_up_epilogue, + _emit_stage2_store_epilogue, + _extract_sub8, + _finalize_alloc_and_launch_2d, + _make_moe_wave_layout, + _make_mxscale_data_loaders, + _make_wmma_sub_tiles, + _moe_out_elem_ty, + _mxscale_emit_wmma, + _pick_mxscale_launch_shape, + _require_gfx1250, +) + + +@functools.lru_cache(maxsize=64) +def _compile_stage1_mxscale_kernel_impl( + *, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + route_tile_m: int, + tile_m: int, + tile_n: int, + tile_k: int, + m_warp: int, + n_warp: int, + doweight_stage1: bool, + out_dtype: str, + waves_per_eu: int | None, + data_format: str = "fp8", + expert_sched_mode: bool = True, + num_buffers: int = 1, + use_tdm_gather: bool = True, + use_tdm_store: bool = False, + inst_prefetch: bool = False, + wave_specialized_tdm: bool = False, + cluster_m: int = 1, + cluster_n: int = 1, +): + """Compile mxscale stage1 single kernel (route-pack + TDM + WMMA_SCALE + epilog).""" + import flydsl.compiler as flyc + import flydsl.expr as fx + from flydsl._mlir import ir + from flydsl._mlir.dialects import llvm as llvm_dialect + from flydsl._mlir.dialects import scf + from flydsl.compiler.kernel_function import CompilationContext + from flydsl.expr import arith, buffer_ops, gpu, idx2crd, range_constexpr, rocdl, tdm_ops, vector + from flydsl.expr.typing import T + from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr, get_op_result_or_value + + tp = _compute_mxscale_tiling( + data_format=data_format, K=int(model_dim), + tile_m=int(tile_m), tile_n=int(tile_n), tile_k=int(tile_k), + m_warp=int(m_warp), n_warp=int(n_warp), out_dtype=out_dtype, + num_buffers=int(num_buffers), cluster_m=int(cluster_m), + cluster_n=int(cluster_n), stage_name="stage1", + ) + is_fp4, is_a8w4 = tp["is_fp4"], tp["is_a8w4"] + PACK_FACTOR_A, PACK_FACTOR_B = tp["PACK_FACTOR_A"], tp["PACK_FACTOR_B"] + ACC_VEC_SIZE = tp["ACC_VEC_SIZE"] + DS_LOADS_PER_A_FRAG = tp["DS_LOADS_PER_A_FRAG"] + WMMA_M, WMMA_N, WMMA_K = tp["WMMA_M"], tp["WMMA_N"], tp["WMMA_K"] + SCALE_BLOCK, SCALES_PER_WMMA = tp["SCALE_BLOCK"], tp["SCALES_PER_WMMA"] + WAVE_SIZE = tp["WAVE_SIZE"] + LDS_PAD_A_BYTES, LDS_PAD_B_BYTES = tp["LDS_PAD_A_BYTES"], tp["LDS_PAD_B_BYTES"] + use_cluster = tp["use_cluster"] + K = tp["K"] + K_packed_a, K_packed_b = tp["K_packed_a"], tp["K_packed_b"] + packed_tile_k_a, packed_tile_k_b = tp["packed_tile_k_a"], tp["packed_tile_k_b"] + K_scale, scale_k_per_tile = tp["K_scale"], tp["scale_k_per_tile"] + block_threads = tp["block_threads"] + warp_tile_m, warp_tile_n = tp["warp_tile_m"], tp["warp_tile_n"] + wmma_m_rep, wmma_n_rep = tp["wmma_m_rep"], tp["wmma_n_rep"] + k_wmma_steps, n_accs = tp["k_wmma_steps"], tp["n_accs"] + num_k_tiles = tp["num_k_tiles"] + b_scale_load_rep = tp["b_scale_load_rep"] + interleaved_scale_cols_b = tp["interleaved_scale_cols_b"] + lds_a_stride_bytes = tp["lds_a_stride_bytes"] + lds_b_stride_bytes = tp["lds_b_stride_bytes"] + lds_a_data_bytes, lds_b_data_bytes = tp["lds_a_data_bytes"], tp["lds_b_data_bytes"] + lds_a_scale_bytes, lds_b_scale_bytes = tp["lds_a_scale_bytes"], tp["lds_b_scale_bytes"] + interleaved_scale_cols_a = tp["interleaved_scale_cols_a"] + + N = 
int(inter_dim) + _merge_gate_up_tdm = bool((data_format in ("fp8", "a8w4")) and (N % int(tile_n) == 0)) + num_warps_s1 = int(m_warp) * int(n_warp) + _tdm_loader_waves = 2 if _merge_gate_up_tdm else 4 + if bool(wave_specialized_tdm): + if num_warps_s1 < _tdm_loader_waves: + raise ValueError( + f"wave_specialized_tdm requires at least {_tdm_loader_waves} waves, got {num_warps_s1}") + tdm_desc_num_warps = 1 if bool(wave_specialized_tdm) else num_warps_s1 + effective_waves_per_eu = waves_per_eu + if use_cluster and effective_waves_per_eu is None: + effective_waves_per_eu = 2 + + _sub_tiles = _make_wmma_sub_tiles( + wmma_m_rep=wmma_m_rep, wmma_n_rep=wmma_n_rep, WMMA_M=WMMA_M, is_fp4=is_fp4 + ) + + # Pipeline calculations for multi-buffer + _use_pipeline = int(num_buffers) >= 2 + if _use_pipeline: + from kernels.gemm_common_gfx1250 import ( + pipeline_fence, pipeline_fence_signal, pipeline_fence_wait, + ) + if _merge_gate_up_tdm: + _B_TDM_PER_STEP = 1 if bool(wave_specialized_tdm) else 2 + else: + _B_TDM_PER_STEP = 1 if bool(wave_specialized_tdm) else 4 + _pp = _compute_pipeline_plan( + num_k_tiles=num_k_tiles, num_buffers=int(num_buffers), + B_TDM_PER_STEP=_B_TDM_PER_STEP, tile_m=int(tile_m), + use_tdm_gather=use_tdm_gather, + wave_specialized_tdm=wave_specialized_tdm, + tdm_loader_waves=_tdm_loader_waves, + ) + pre_loaded = _pp["pre_loaded"] + loop_iters = _pp["loop_iters"] + _tail_start = _pp["tail_start"] + extra = _pp["extra"] + _A_GATHER_GROUPS = _pp["A_GATHER_GROUPS"] + TDM_PER_STEP = _pp["TDM_PER_STEP"] + _fence_outstanding = _pp["fence_outstanding"] + _tail_plan = _pp["tail_plan"] + from kernels.gemm_common_gfx1250 import workgroup_barrier + + alloc = SmemAllocator( + None, + arch=str(get_hip_arch()), + global_sym_name=f"moe_mxscale_{data_format}_s1_single_g{int(bool(use_tdm_gather))}", + ) + _nb = int(num_buffers) + off_ag_list, off_as_list = [], [] + off_bg_list, off_bs_list = [], [] + off_bu_list, off_bsu_list = [], [] + off_bg_pair_list, off_bs_pair_list = [], [] + for _buf_i in range(_nb): + _o = alloc._align(alloc.ptr, 16); alloc.ptr = _o + lds_a_data_bytes; off_ag_list.append(_o) + if _merge_gate_up_tdm: + _o = alloc._align(alloc.ptr, 16); alloc.ptr = _o + 2 * lds_b_data_bytes; off_bg_pair_list.append(_o) + else: + _o = alloc._align(alloc.ptr, 16); alloc.ptr = _o + lds_b_data_bytes; off_bg_list.append(_o) + _o = alloc._align(alloc.ptr, 16); alloc.ptr = _o + lds_a_scale_bytes; off_as_list.append(_o) + if _merge_gate_up_tdm: + _o = alloc._align(alloc.ptr, 16); alloc.ptr = _o + 2 * lds_b_scale_bytes; off_bs_pair_list.append(_o) + else: + _o = alloc._align(alloc.ptr, 16); alloc.ptr = _o + lds_b_scale_bytes; off_bs_list.append(_o) + _o = alloc._align(alloc.ptr, 16); alloc.ptr = _o + lds_b_data_bytes; off_bu_list.append(_o) + _o = alloc._align(alloc.ptr, 16); alloc.ptr = _o + lds_b_scale_bytes; off_bsu_list.append(_o) + + if bool(use_tdm_store): + from kernels.gemm_common_gfx1250 import store_acc_vec8_to_lds + _ds1 = _compute_tdm_store_layout( + warp_tile_m=warp_tile_m, warp_tile_n=warp_tile_n, + num_warps=num_warps_s1, WMMA_N=WMMA_N, use_pipeline=_use_pipeline, + ) + lds_d_row_stride_s1 = _ds1["lds_d_row_stride"] + d_output_off_s1 = _ds1["d_output_off"] + _lds_d_stride_elems_s1 = _ds1["lds_d_stride_elems"] + _warp_d_elems_s1 = _ds1["warp_d_elems"] + _n_col_d_elems_s1 = _ds1["n_col_d_elems"] + d_need_epilogue_fence_s1 = _ds1["d_need_epilogue_fence"] + if _ds1["total_d_bytes"] > alloc.ptr: + alloc.ptr = _ds1["total_d_bytes"] + + @flyc.kernel(known_block_size=[block_threads, 1, 1]) + def 
moe_mxscale_stage1_single( + arg_out: fx.Tensor, + arg_x: fx.Tensor, + arg_w: fx.Tensor, + arg_scale_x: fx.Tensor, + arg_scale_w: fx.Tensor, + arg_sorted_token_ids: fx.Tensor, + arg_expert_ids: fx.Tensor, + arg_sorted_weights: fx.Tensor, + arg_num_valid_ids: fx.Tensor, + i32_tokens_in: fx.Int32, + i32_inter_in: fx.Int32, + i32_k_in: fx.Int32, + i32_size_expert_ids_in: fx.Int32, + ): + _ = i32_k_in + if inst_prefetch: + if arith.cmpi(arith.CmpIPredicate.eq, rocdl.wave_id(), + arith.constant(0, type=T.i32)): + _prefetch_lines = ["s_setreg_imm32_b32 hwreg(HW_REG_WAVE_MODE, 8, 1), 1"] + for _pg in range_constexpr(10): + _prefetch_lines.append( + f"s_prefetch_inst_pc_rel {_pg * 4096}, s0, 31") + llvm_dialect.inline_asm( + None, [], + "\n".join(_prefetch_lines), + "", has_side_effects=True, + ) + llvm_dialect.inline_asm( + None, [], + "s_setreg_imm32_b32 hwreg(26, 4, 1), 1", + "", + has_side_effects=True, + ) + + tx = gpu.thread_id("x") + bx = gpu.block_id("x") + by = gpu.block_id("y") + + tokens_idx = arith.index_cast(T.index, i32_tokens_in) + size_expert_ids = arith.index_cast(T.index, i32_size_expert_ids_in) + c_topk_i32 = arith.constant(int(topk), type=T.i32) + num_valid_i32 = buffer_ops.buffer_load( + buffer_ops.create_buffer_resource(arg_num_valid_ids, max_size=True), + arith.constant(0, type=T.i32), + vec_width=1, + dtype=T.i32, + ) + sorted_num = size_expert_ids * arith.index(int(route_tile_m)) + sorted_nbytes = sorted_num * arith.index(4) + eid_nbytes = size_expert_ids * arith.index(4) + x_nbytes = tokens_idx * arith.index(K_packed_a) + sx_nbytes = tokens_idx * arith.index(K_scale) + w_rows = arith.index(int(experts * (2 * N))) + w_nbytes = w_rows * arith.index(K_packed_b) + sw_nbytes = w_rows * arith.index(K_scale) + + sorted_rsrc = buffer_ops.create_buffer_resource(arg_sorted_token_ids, max_size=False, num_records_bytes=sorted_nbytes) + eid_rsrc = buffer_ops.create_buffer_resource(arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes) + x_rsrc = buffer_ops.create_buffer_resource(arg_x, max_size=False, num_records_bytes=x_nbytes) + sx_rsrc = buffer_ops.create_buffer_resource(arg_scale_x, max_size=False, num_records_bytes=sx_nbytes) + w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False, num_records_bytes=w_nbytes) + sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False, num_records_bytes=sw_nbytes) + out_rsrc = buffer_ops.create_buffer_resource(arg_out, max_size=True) + tw_rsrc = buffer_ops.create_buffer_resource(arg_sorted_weights, max_size=True) + + eid_i32 = buffer_ops.buffer_load(eid_rsrc, arith.index_cast(T.i32, by), vec_width=1, dtype=T.i32) + eid_ok0 = arith.cmpi(arith.CmpIPredicate.sge, eid_i32, arith.constant(0, type=T.i32)) + eid_ok1 = arith.cmpi(arith.CmpIPredicate.slt, eid_i32, arith.constant(int(experts), type=T.i32)) + block_row_start = arith.index_cast(T.i32, by * arith.index(int(route_tile_m))) + block_in_valid = arith.cmpi(arith.CmpIPredicate.slt, block_row_start, num_valid_i32) + block_ok = arith.andi(block_in_valid, arith.andi(eid_ok0, eid_ok1)) + + layout_thr = _make_moe_wave_layout(m_warp=m_warp, n_warp=n_warp, WAVE_SIZE=WAVE_SIZE, fx=fx) + thr_coord = idx2crd(tx, layout_thr) + wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( + fx.get(thr_coord, 0), fx.get(thr_coord, 1), fx.get(thr_coord, 2), fx.get(thr_coord, 3) + ) + warp_m_base = wave_m_idx * arith.index(warp_tile_m) + warp_n_base = wave_n_idx * arith.index(warp_tile_n) + blk_n = bx * arith.index(int(tile_n)) + + if use_cluster: + _local_x, _local_y = 
gpu.compute_cluster_position() + _a_mcast_mask, b_mcast_mask = gpu.compute_mcast_masks( + _local_x, _local_y, int(cluster_m), int(cluster_n)) + else: + b_mcast_mask = 0 + + base_ptr = alloc.get_base() + lds_ag_bufs, lds_as_bufs = [], [] + lds_bg_bufs, lds_bs_bufs = [], [] + lds_bu_bufs, lds_bsu_bufs = [], [] + lds_bg_pair_bufs, lds_bs_pair_bufs = [], [] + for _bi in range_constexpr(_nb): + lds_ag_bufs.append(get_op_result_or_value( + SmemPtr(base_ptr, off_ag_list[_bi], T.i8, shape=(lds_a_data_bytes,)).get())) + lds_as_bufs.append(get_op_result_or_value( + SmemPtr(base_ptr, off_as_list[_bi], T.i8, shape=(lds_a_scale_bytes,)).get())) + if _merge_gate_up_tdm: + lds_bg_pair_bufs.append(get_op_result_or_value( + SmemPtr(base_ptr, off_bg_pair_list[_bi], T.i8, shape=(2 * lds_b_data_bytes,)).get())) + lds_bs_pair_bufs.append(get_op_result_or_value( + SmemPtr(base_ptr, off_bs_pair_list[_bi], T.i8, shape=(2 * lds_b_scale_bytes,)).get())) + else: + lds_bg_bufs.append(get_op_result_or_value( + SmemPtr(base_ptr, off_bg_list[_bi], T.i8, shape=(lds_b_data_bytes,)).get())) + lds_bs_bufs.append(get_op_result_or_value( + SmemPtr(base_ptr, off_bs_list[_bi], T.i8, shape=(lds_b_scale_bytes,)).get())) + lds_bu_bufs.append(get_op_result_or_value( + SmemPtr(base_ptr, off_bu_list[_bi], T.i8, shape=(lds_b_data_bytes,)).get())) + lds_bsu_bufs.append(get_op_result_or_value( + SmemPtr(base_ptr, off_bsu_list[_bi], T.i8, shape=(lds_b_scale_bytes,)).get())) + + if bool(use_tdm_store): + from kernels.gemm_common_gfx1250 import get_lds_memref + d_lds_f16_count_s1 = total_d_bytes_s1 // 2 + d_smem_s1 = SmemPtr(base_ptr, d_output_off_s1, T.f16, + shape=(d_lds_f16_count_s1,)) + d_lds_buffer_s1 = get_lds_memref(d_smem_s1) + warp_lds_off_s1 = ( + (wave_m_idx * arith.index(int(n_warp)) + wave_n_idx) + * arith.index(_warp_d_elems_s1) + ) + d_lane_base_s1 = ( + warp_lds_off_s1 + + lane16 * arith.index(_lds_d_stride_elems_s1) + + lane_kgrp * arith.index(4 * elem_bytes_d_s1) + ) + wave_id_idx_s1 = arith.index_cast(T.index, rocdl.wave_id()) + d_warp_off_sgpr_s1 = ( + wave_id_idx_s1 * arith.index(warp_d_bytes_s1) + + arith.index(d_output_off_s1) + ) + warp_m_off_sgpr_s1 = ( + (wave_id_idx_s1 / arith.index(int(n_warp))) + * arith.index(warp_tile_m) + ) + warp_n_off_sgpr_s1 = ( + (wave_id_idx_s1 % arith.index(int(n_warp))) + * arith.index(warp_tile_n) + ) + # TDM store for MoE stage1 uses gather-store mode because the + # output rows are not contiguous — each sorted row maps to + # out[tok * topk + slot, :] which is a scattered layout. + # d_desc_s1 is built lazily in the epilogue after sorted_ids + # are decoded (see _emit_tdm_gather_store_s1 below). 
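+            # For example (illustrative values only): with topk=2, a sorted
+            # row that decodes to token 5 / slot 1 is stored at output row
+            # 5*2 + 1 = 11.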
+ + def silu(x): + t = x * (-1.4426950408889634) + emu = rocdl.exp2(T.f32, t) + den = 1.0 + emu + sig = rocdl.rcp(T.f32, den) + return x * sig + + def make_desc_a(k_base): + return k_base / arith.index(PACK_FACTOR_A) + + # TDM gather for A data + _use_tdm_gather_a = bool(use_tdm_gather) + + def issue_a_load(k_packed_base, target_lds): + total = int(tile_m * packed_tile_k_a) + rounds = (total + block_threads - 1) // block_threads + for it in range(rounds): + elem = tx + fx.Index(it * block_threads) + in_range = arith.cmpi(arith.CmpIPredicate.ult, arith.index_cast(T.i32, elem), arith.constant(total, type=T.i32)) + _if_elem = scf.IfOp(in_range) + with ir.InsertionPoint(_if_elem.then_block): + row = elem // arith.index(int(packed_tile_k_a)) + col = elem % arith.index(int(packed_tile_k_a)) + sorted_row = by * arith.index(int(tile_m)) + row + row_i32 = arith.index_cast(T.i32, row) + sorted_i32 = arith.index_cast(T.i32, sorted_row) + row_in_route = arith.cmpi(arith.CmpIPredicate.ult, row_i32, arith.constant(int(route_tile_m), type=T.i32)) + row_in_valid = arith.cmpi(arith.CmpIPredicate.slt, sorted_i32, num_valid_i32) + row_ok = arith.andi(row_in_route, row_in_valid) + sorted_safe = arith.select(row_ok, sorted_i32, block_row_start) + fused = buffer_ops.buffer_load(sorted_rsrc, sorted_safe, vec_width=1, dtype=T.i32) + tok = fused & arith.constant((1 << 24) - 1, type=T.i32) + tok_ok = arith.cmpi(arith.CmpIPredicate.ult, tok, i32_tokens_in) + load_ok = arith.andi(row_ok, tok_ok) + x_idx = tok * arith.constant(K_packed_a, type=T.i32) + arith.index_cast(T.i32, k_packed_base + col) + x_idx_safe = arith.select(load_ok, x_idx, arith.constant(0, type=T.i32)) + x_val = arith.select(load_ok, buffer_ops.buffer_load(x_rsrc, x_idx_safe, vec_width=1, dtype=T.i8), arith.constant(0, type=T.i8)) + lds_idx = row * arith.index(lds_a_stride_bytes) + col + v1 = vector.from_elements(T.vec(1, T.i8), [x_val]) + vector.store(v1, target_lds, [lds_idx], alignment=1) + scf.YieldOp([]) + + # Pre-compute token row indices for ALL tile_m rows (once, outside K-loop). 
+ # _a_tok_ids[i] = token_id for TDM gather A load + # _a_out_row_ids[i] = tok * topk + slot for TDM gather store output + _a_tok_ids = [] + _a_out_row_ids = [] + _a_load_valids = [] + _a_store_valids = [] + + def _sum_i32_values(_vals): + _acc = arith.constant(0, type=T.i32) + for _vi in range_constexpr(len(_vals)): + _acc = _acc + _vals[_vi] + return _acc + + def _precompute_a_row_indices(): + """Load sorted_ids for all tile_m rows and decode token_ids + output row indices.""" + _safe_row = arith.constant(0, type=T.i32) + _one_i32 = arith.constant(1, type=T.i32) + _zero_i32 = arith.constant(0, type=T.i32) + for _ri in range_constexpr(int(tile_m)): + _sorted_row = by * fx.Index(int(tile_m)) + fx.Index(_ri) + _sorted_i32 = arith.index_cast(T.i32, _sorted_row) + _row_in_route = arith.cmpi( + arith.CmpIPredicate.ult, + fx.Int32(_ri), + fx.Int32(int(route_tile_m)), + ) + _row_in_valid = arith.cmpi( + arith.CmpIPredicate.slt, + _sorted_i32, + num_valid_i32, + ) + _row_ok = arith.andi(_row_in_route, _row_in_valid) + _sorted_safe = arith.select( + _row_ok, _sorted_i32, + block_row_start, + ) + _fused = buffer_ops.buffer_load(sorted_rsrc, _sorted_safe, vec_width=1, dtype=T.i32) + _tok = _fused & fx.Int32((1 << 24) - 1) + _slot = _fused >> fx.Int32(24) + _tok_ok = arith.cmpi(arith.CmpIPredicate.ult, _tok, i32_tokens_in) + _slot_ok0 = arith.cmpi(arith.CmpIPredicate.sge, _slot, fx.Int32(0)) + _slot_ok1 = arith.cmpi(arith.CmpIPredicate.slt, _slot, c_topk_i32) + _slot_ok = arith.andi(_slot_ok0, _slot_ok1) + _row_tok_ok = arith.andi(_row_ok, _tok_ok) + _load_valid_i32 = arith.select(_row_tok_ok, _one_i32, _zero_i32) + _a_load_valids.append(rocdl.readfirstlane(T.i32, _load_valid_i32)) + _tok_safe = arith.select(_row_tok_ok, _tok, _safe_row) + _tok_sgpr = rocdl.readfirstlane(T.i32, _tok_safe) + _a_tok_ids.append(_tok_sgpr) + _out_row = _tok * c_topk_i32 + _slot + _row_fully_ok = arith.andi(_row_tok_ok, _slot_ok) + _store_valid_i32 = arith.select(_row_fully_ok, _one_i32, _zero_i32) + _a_store_valids.append(rocdl.readfirstlane(T.i32, _store_valid_i32)) + _out_row_safe = arith.select( + _row_fully_ok, _out_row, + _safe_row, + ) + _out_row_sgpr = rocdl.readfirstlane(T.i32, _out_row_safe) + _a_out_row_ids.append(_out_row_sgpr) + + _TDM_GATHER_CHUNK = 8 + _TDM_GATHER_GROUPS = (int(tile_m) + _TDM_GATHER_CHUNK - 1) // _TDM_GATHER_CHUNK + + _a_tokens_sgpr = None + _a_tokens_topk_sgpr = None + + def _get_tokens_sgpr(): + nonlocal _a_tokens_sgpr + if _a_tokens_sgpr is None: + _tok_i32 = arith.index_cast(T.i32, arith.index_cast(T.index, i32_tokens_in)) + _a_tokens_sgpr = rocdl.readfirstlane(T.i32, _tok_i32) + return _a_tokens_sgpr + + def _get_tokens_topk_sgpr(): + nonlocal _a_tokens_topk_sgpr + if _a_tokens_topk_sgpr is None: + _m_i32 = _get_tokens_sgpr() * c_topk_i32 + _a_tokens_topk_sgpr = rocdl.readfirstlane(T.i32, _m_i32) + return _a_tokens_topk_sgpr + + def issue_a_load_tdm_gather(k_base, target_lds): + """Load A data using TDM gather mode — one TDM instruction per 8 rows.""" + k_packed_base = k_base if PACK_FACTOR_A == 1 else k_base // fx.Index(PACK_FACTOR_A) + _tokens_dim1 = _get_tokens_sgpr() + _zero_i32 = arith.constant(0, type=T.i32) + for _gi in range_constexpr(_TDM_GATHER_GROUPS): + _start = _gi * _TDM_GATHER_CHUNK + _cnt = min(_TDM_GATHER_CHUNK, int(tile_m) - _start) + _row_indices = _a_tok_ids[_start:_start + _cnt] + _valid_count = _sum_i32_values(_a_load_valids[_start:_start + _cnt]) + _lds_off = fx.Index(_start * lds_a_stride_bytes) + _has_valid = arith.cmpi(arith.CmpIPredicate.sgt, _valid_count, 
_zero_i32) + _issue_pred = _has_valid + if wave_specialized_tdm: + _gather_owner = _gi % _tdm_loader_waves + _is_gather_loader = arith.cmpi( + arith.CmpIPredicate.eq, + _tdm_wave_id, + arith.constant(_gather_owner, type=T.i32), + ) + _issue_pred = arith.andi(_issue_pred, _is_gather_loader) + _if_issue = scf.IfOp(_issue_pred) + with ir.InsertionPoint(_if_issue.then_block): + desc = tdm_ops.make_tensor_gather_descriptor( + global_ptr=arg_x, + lds_memref=target_lds, + row_indices=_row_indices, + row_width=int(packed_tile_k_a), + tensor_dim0=K_packed_a, + tensor_dim1=_tokens_dim1, + stride=K_packed_a, + elem_bytes=1, + pad_interval=int(packed_tile_k_a) if LDS_PAD_A_BYTES > 0 else 0, + pad_amount=LDS_PAD_A_BYTES if LDS_PAD_A_BYTES > 0 else 0, + index_size=32, + gather_tile_dim1=_valid_count, + lds_byte_offset=_lds_off, + global_byte_offset=k_packed_base, + ) + tdm_ops.tensor_load_gather(desc) + scf.YieldOp([]) + + def make_desc_as(k_base): + return k_base / arith.index(SCALE_BLOCK) + + def issue_as_load(k_scale_base, target_lds): + total = int(tile_m * scale_k_per_tile) + rounds = (total + block_threads - 1) // block_threads + for it in range(rounds): + elem = tx + fx.Index(it * block_threads) + in_range = arith.cmpi(arith.CmpIPredicate.ult, arith.index_cast(T.i32, elem), arith.constant(total, type=T.i32)) + _if_elem = scf.IfOp(in_range) + with ir.InsertionPoint(_if_elem.then_block): + row = elem // arith.index(int(scale_k_per_tile)) + ksc = elem % arith.index(int(scale_k_per_tile)) + sorted_row = by * arith.index(int(tile_m)) + row + row_i32 = arith.index_cast(T.i32, row) + sorted_i32 = arith.index_cast(T.i32, sorted_row) + row_in_route = arith.cmpi(arith.CmpIPredicate.ult, row_i32, arith.constant(int(route_tile_m), type=T.i32)) + row_in_valid = arith.cmpi(arith.CmpIPredicate.slt, sorted_i32, num_valid_i32) + row_ok = arith.andi(row_in_route, row_in_valid) + sorted_safe = arith.select(row_ok, sorted_i32, block_row_start) + fused = buffer_ops.buffer_load(sorted_rsrc, sorted_safe, vec_width=1, dtype=T.i32) + tok = fused & arith.constant((1 << 24) - 1, type=T.i32) + tok_ok = arith.cmpi(arith.CmpIPredicate.ult, tok, i32_tokens_in) + load_ok = arith.andi(row_ok, tok_ok) + ksc_off = k_scale_base + ksc + sx_idx = tok * arith.constant(K_scale, type=T.i32) + arith.index_cast(T.i32, ksc_off) + sx_idx_safe = arith.select(load_ok, sx_idx, arith.constant(0, type=T.i32)) + sx_val = arith.select(load_ok, buffer_ops.buffer_load(sx_rsrc, sx_idx_safe, vec_width=1, dtype=T.i8), arith.constant(127, type=T.i8)) + if is_fp4: + lds_idx = row * arith.index(int(scale_k_per_tile)) + ksc + else: + warp_row_idx = row / arith.index(warp_tile_m) + local_row = row % arith.index(warp_tile_m) + lane_row = local_row % arith.index(WMMA_M) + local_wm_idx = local_row / arith.index(WMMA_M) + global_lds_row = warp_row_idx * arith.index(WMMA_M) + lane_row + ksc_blk = ksc / arith.index(SCALES_PER_WMMA) + ksc_sub = ksc % arith.index(SCALES_PER_WMMA) + lds_idx = ( + global_lds_row * arith.index(interleaved_scale_cols_a) + + ksc_blk * arith.index(wmma_m_rep * SCALES_PER_WMMA) + + local_wm_idx * arith.index(SCALES_PER_WMMA) + + ksc_sub + ) + v1 = vector.from_elements(T.vec(1, T.i8), [sx_val]) + vector.store(v1, target_lds, [lds_idx], alignment=1) + scf.YieldOp([]) + + def make_desc_b(lds_b_mem, n_off, k_base): + if is_fp4: + return tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_w, lds_memref=lds_b_mem, + global_offset=(n_off, k_base / arith.index(PACK_FACTOR_B)), + tensor_shape=(int(tile_n), int(packed_tile_k_b)), + 
strides=(K_packed_b, 1), + tile_shape=(int(tile_n), int(packed_tile_k_b)), + elem_bytes=1, pad_interval=int(packed_tile_k_b), pad_amount=LDS_PAD_B_BYTES, + num_warps=tdm_desc_num_warps, workgroup_mask=b_mcast_mask) + return tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_w, lds_memref=lds_b_mem, + global_offset=(n_off / arith.index(16), (k_base / arith.index(PACK_FACTOR_B)) * arith.index(16)), + tensor_shape=(int(experts * (2 * N) // 16), int(K_packed_b * 16)), + strides=(K_packed_b * 16, 1), + tile_shape=(int(tile_n // 16), int(packed_tile_k_b * 16)), + elem_bytes=1, + pad_interval=0, pad_amount=0, + num_warps=tdm_desc_num_warps, + workgroup_mask=b_mcast_mask) + + def make_desc_b_pair(lds_b_mem, n_off, k_base): + return tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_w, lds_memref=lds_b_mem, + global_offset=(n_off / arith.index(16), (k_base / arith.index(PACK_FACTOR_B)) * arith.index(16)), + tensor_shape=(int(experts * (2 * N) // 16), int(K_packed_b * 16)), + strides=(K_packed_b * 16, 1), + tile_shape=(int((2 * tile_n) // 16), int(packed_tile_k_b * 16)), + elem_bytes=1, + pad_interval=0, pad_amount=0, + num_warps=tdm_desc_num_warps, + workgroup_mask=b_mcast_mask) + + def make_desc_bs(lds_bs_mem, n_off, k_base): + return tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_scale_w, lds_memref=lds_bs_mem, + global_offset=(n_off, k_base / arith.index(SCALE_BLOCK)), + tensor_shape=(int(tile_n), int(scale_k_per_tile)), + strides=(K_scale, 1), + tile_shape=(int(tile_n), int(scale_k_per_tile)), + elem_bytes=1, pad_interval=0, pad_amount=0, + num_warps=tdm_desc_num_warps, workgroup_mask=b_mcast_mask) + + def make_desc_bs_pair(lds_bs_mem, n_off, k_base): + return tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_scale_w, lds_memref=lds_bs_mem, + global_offset=(n_off, k_base / arith.index(SCALE_BLOCK)), + tensor_shape=(int(2 * tile_n), int(scale_k_per_tile)), + strides=(K_scale, 1), + tile_shape=(int(2 * tile_n), int(scale_k_per_tile)), + elem_bytes=1, pad_interval=0, pad_amount=0, + num_warps=tdm_desc_num_warps, workgroup_mask=b_mcast_mask) + + def _stage1_pair_row_base(): + _eid_row = arith.index_cast(T.index, eid_i32) * arith.index(int(2 * N)) + _tile_idx = blk_n / arith.index(int(tile_n)) + return _eid_row + _tile_idx * arith.index(int(2 * tile_n)) + + _ldrs = _make_mxscale_data_loaders( + tiling=tp, warp_m_base=warp_m_base, warp_n_base=warp_n_base, + wave_n_idx=wave_n_idx, lane16=lane16, lane_kgrp=lane_kgrp, + ir=ir, arith=arith, vector=vector, llvm_dialect=llvm_dialect, + T=T, range_constexpr=range_constexpr, + ) + _lds_load_b128 = _ldrs["_lds_load_b128"] + load_data_frag = _ldrs["load_data_frag"] + load_b_frag = _ldrs["load_b_frag"] + load_scale_i32 = _ldrs["load_scale_i32"] + _precompute_a_data_bases = _ldrs["_precompute_a_data_bases"] + _precompute_b_data_bases = _ldrs["_precompute_b_data_bases"] + _precompute_a_scale_lane_bases = _ldrs["_precompute_a_scale_lane_bases"] + _precompute_b_scale_lane_bases = _ldrs["_precompute_b_scale_lane_bases"] + load_scale_b128 = _ldrs["load_scale_b128"] + + acc_zero = arith.constant_vector(0.0, T.vec(ACC_VEC_SIZE, T.f32)) + acc_g = [acc_zero] * n_accs + acc_u = [acc_zero] * n_accs + + _if_blk = scf.IfOp(block_ok) + with ir.InsertionPoint(_if_blk.then_block): + if _use_tdm_gather_a or bool(use_tdm_store): + _precompute_a_row_indices() + a_data_bases = _precompute_a_data_bases() + b_data_bases = _precompute_b_data_bases() + if _merge_gate_up_tdm: + b_u_data_bases = [ + _base + arith.index(lds_b_data_bytes) + for _base in b_data_bases + ] + else: + 
b_u_data_bases = b_data_bases + as_bases = _precompute_a_scale_lane_bases() + bs_bases = _precompute_b_scale_lane_bases() + if _merge_gate_up_tdm: + bsu_bases = [ + _base + arith.index(lds_b_scale_bytes) + for _base in bs_bases + ] + else: + bsu_bases = bs_bases + _use_scheduled_compute = _use_pipeline and not is_fp4 + _front_wm = (wmma_m_rep + 1) // 2 + _back_wm = wmma_m_rep - _front_wm + _front_wmma = 2 * _front_wm * wmma_n_rep + _back_wmma = 2 * _back_wm * wmma_n_rep + _b_frag_ds_loads_per_wn = 2 if is_a8w4 else 4 + _a_scale_ds_loads = wmma_m_rep if is_fp4 else (wmma_m_rep + 3) // 4 + _b_scale_ds_loads = b_scale_load_rep if is_fp4 else wmma_n_rep + _gate_up_ds_loads = ( + 2 * (wmma_n_rep * _b_frag_ds_loads_per_wn + _b_scale_ds_loads) + + _a_scale_ds_loads + ) + + # ── compute-tile helper (gate + up) ────────────────────── + def _load_gate_up_b_and_scales(buf_idx, ks): + if _merge_gate_up_tdm: + _gate_b_buf = lds_bg_pair_bufs[buf_idx] + _up_b_buf = lds_bg_pair_bufs[buf_idx] + _gate_bs_buf = lds_bs_pair_bufs[buf_idx] + _up_bs_buf = lds_bs_pair_bufs[buf_idx] + else: + _gate_b_buf = lds_bg_bufs[buf_idx] + _up_b_buf = lds_bu_bufs[buf_idx] + _gate_bs_buf = lds_bs_bufs[buf_idx] + _up_bs_buf = lds_bsu_bufs[buf_idx] + + b_g = [load_b_frag(_gate_b_buf, b_data_bases, wn, ks) + for wn in range_constexpr(wmma_n_rep)] + b_u = [load_b_frag(_up_b_buf, b_u_data_bases, wn, ks) + for wn in range_constexpr(wmma_n_rep)] + if is_fp4: + as_v = [load_scale_i32(lds_as_bufs[buf_idx], as_bases[wm], ks) + for wm in range_constexpr(wmma_m_rep)] + bs_gv = [load_scale_i32(_gate_bs_buf, bs_bases[bi], ks) + for bi in range_constexpr(b_scale_load_rep)] + bs_uv = [load_scale_i32(_up_bs_buf, bsu_bases[bi], ks) + for bi in range_constexpr(b_scale_load_rep)] + else: + as_v = load_scale_b128(lds_as_bufs[buf_idx], as_bases[0], + wmma_m_rep, ks) + bs_gv = [load_scale_i32(_gate_bs_buf, bs_bases[wn], ks) + for wn in range_constexpr(wmma_n_rep)] + bs_uv = [load_scale_i32(_up_bs_buf, bsu_bases[wn], ks) + for wn in range_constexpr(wmma_n_rep)] + return b_g, bs_gv, b_u, bs_uv, as_v + + def emit_wmma(accs, wm, wn, a_frag, b_frags, a_scales, b_scales): + _mxscale_emit_wmma( + accs=accs, wm=wm, wn=wn, + a_frag=a_frag, b_frags=b_frags, + a_scales=a_scales, b_scales=b_scales, + is_fp4=is_fp4, is_a8w4=is_a8w4, + use_scale_opsel=False, + rocdl=rocdl, T=T, + ) + + def _emit_rows(acg_in, acu_in, start_wm, a_frags, b_g, b_u, a_scales, bs_g, bs_u): + for frag_i in range_constexpr(len(a_frags)): + wm = start_wm + frag_i + for wn_raw in range_constexpr(wmma_n_rep): + wn = (wmma_n_rep - 1 - wn_raw) if (wm % 2 == 1) else wn_raw + emit_wmma(acg_in, wm, wn, a_frags[frag_i], b_g, a_scales, bs_g) + emit_wmma(acu_in, wm, wn, a_frags[frag_i], b_u, a_scales, bs_u) + + def _compute_k_tile(acg, acu, buf_idx, mid_compute_callback=None): + _mid_emit_ks = 0 + if k_wmma_steps > 1: + _mid_emit_wm = wmma_m_rep - 1 + _mid_emit_wn = wmma_n_rep - 1 + else: + _front_wn = (wmma_n_rep + 1) // 2 + if wmma_m_rep > 1: + _mid_emit_wm = _front_wm - 1 + _mid_emit_wn = wmma_n_rep - 1 + else: + _mid_emit_wm = 0 + _mid_emit_wn = _front_wn - 1 + _did_mid = False + for ks in range_constexpr(k_wmma_steps): + b_g, bs_gv, b_u, bs_uv, as_v = _load_gate_up_b_and_scales(buf_idx, ks) + for wm in range_constexpr(wmma_m_rep): + a_frag = load_data_frag(lds_ag_bufs[buf_idx], + a_data_bases[wm], ks) + for wn_raw in range_constexpr(wmma_n_rep): + wn = (wmma_n_rep - 1 - wn_raw) if (wm % 2 == 1) else wn_raw + emit_wmma(acg, wm, wn, a_frag, b_g, as_v, bs_gv) + emit_wmma(acu, wm, wn, a_frag, 
b_u, as_v, bs_uv) + if ( + not _did_mid + and mid_compute_callback is not None + and ks == _mid_emit_ks + and wm == _mid_emit_wm + and wn == _mid_emit_wn + ): + mid_compute_callback() + _did_mid = True + return acg, acu + + def _a_streaming_compute( + acg, + acu, + buf_idx, + b_g, + bs_gv, + b_u, + bs_uv, + as_v, + ks, + next_bs_info=None, + mid_compute_callback=None, + ): + next_result = None + a_frags_front = [ + load_data_frag(lds_ag_bufs[buf_idx], a_data_bases[wm], ks) + for wm in range_constexpr(_front_wm) + ] + _use_partial_drain = ( + next_bs_info is not None + and _front_wm * wmma_n_rep >= 4 + ) + + if _use_partial_drain: + _next_buf_idx, _next_ks = next_bs_info + next_result = _load_gate_up_b_and_scales(_next_buf_idx, _next_ks) + rocdl.s_wait_dscnt(_gate_up_ds_loads) + else: + rocdl.s_wait_dscnt(0) + + _emit_rows(acg, acu, 0, a_frags_front, b_g, b_u, as_v, bs_gv, bs_uv) + + if mid_compute_callback is not None: + rocdl.sched_barrier(0) + mid_compute_callback() + + if _back_wm > 0: + a_frags_back = [ + load_data_frag( + lds_ag_bufs[buf_idx], + a_data_bases[_front_wm + h], + ks, + ) + for h in range_constexpr(_back_wm) + ] + _back_drain = _gate_up_ds_loads if _use_partial_drain else 0 + rocdl.s_wait_dscnt(_back_drain) + _emit_rows( + acg, + acu, + _front_wm, + a_frags_back, + b_g, + b_u, + as_v, + bs_gv, + bs_uv, + ) + + if not _use_partial_drain and next_bs_info is not None: + _next_buf_idx, _next_ks = next_bs_info + next_result = _load_gate_up_b_and_scales(_next_buf_idx, _next_ks) + return acg, acu, next_result + + def _compute_k_tile_scheduled(acg, acu, buf_idx, mid_compute_callback=None): + current_g = list(acg) + current_u = list(acu) + if k_wmma_steps == 1: + b_g, bs_gv, b_u, bs_uv, as_v = _load_gate_up_b_and_scales(buf_idx, 0) + current_g, current_u, _ = _a_streaming_compute( + current_g, current_u, buf_idx, + b_g, bs_gv, b_u, bs_uv, as_v, 0, + mid_compute_callback=mid_compute_callback, + ) + else: + b_g, bs_gv, b_u, bs_uv, as_v = _load_gate_up_b_and_scales(buf_idx, 0) + for ks in range_constexpr(k_wmma_steps - 1): + _mid_cb = mid_compute_callback if ks == 0 else None + current_g, current_u, _next = _a_streaming_compute( + current_g, current_u, buf_idx, + b_g, bs_gv, b_u, bs_uv, as_v, ks, + next_bs_info=(buf_idx, ks + 1), + mid_compute_callback=_mid_cb, + ) + b_g, bs_gv, b_u, bs_uv, as_v = _next + current_g, current_u, _ = _a_streaming_compute( + current_g, current_u, buf_idx, + b_g, bs_gv, b_u, bs_uv, as_v, + k_wmma_steps - 1, + ) + return current_g, current_u + + def _hot_loop_scheduler_scheduled(): + if not _use_scheduled_compute: + return + _front_a_loads = _front_wm * DS_LOADS_PER_A_FRAG + _back_a_loads = _back_wm * DS_LOADS_PER_A_FRAG + for _ks in range_constexpr(k_wmma_steps): + if _ks == 0: + rocdl.sched_dsrd(_gate_up_ds_loads + _front_a_loads) + else: + rocdl.sched_dsrd(_front_a_loads) + rocdl.sched_mfma(_front_wmma) + if _back_wmma > 0: + rocdl.sched_dsrd(_back_a_loads) + rocdl.sched_mfma(_back_wmma) + if _ks < k_wmma_steps - 1: + rocdl.sched_dsrd(_gate_up_ds_loads) + rocdl.sched_barrier(0) + + if wave_specialized_tdm: + _tdm_wave_id = rocdl.wave_id() + _loader_waves = _tdm_loader_waves + _is_loader_wave = arith.cmpi( + arith.CmpIPredicate.ult, + _tdm_wave_id, + arith.constant(_loader_waves, type=T.i32), + ) + _tdm_pred = arith.constant(1, type=T.i32) + + def _select_wave_tdm_value(*values): + if len(values) != _loader_waves: + raise ValueError( + f"expected {_loader_waves} wave-specialized TDM values, got {len(values)}" + ) + _selected = values[-1] + for _sel_idx in 
range_constexpr(_loader_waves - 1): + _value_idx = _loader_waves - 2 - _sel_idx + _is_wave = arith.cmpi( + arith.CmpIPredicate.eq, + _tdm_wave_id, + arith.constant(_value_idx, type=T.i32), + ) + _selected = arith.select(_is_wave, values[_value_idx], _selected) + return _selected + + def _tdm_desc_lds_addr(desc): + return vector.extract( + desc.dgroup0, + static_position=[1], + dynamic_position=[], + ) + + def _tdm_desc_addr_lo(desc): + return vector.extract( + desc.dgroup0, + static_position=[2], + dynamic_position=[], + ) + + def _tdm_desc_addr_hi(desc): + return vector.extract( + desc.dgroup0, + static_position=[3], + dynamic_position=[], + ) + + _zero_k_base = arith.index(0) + _scale_adv_i32 = arith.constant(scale_k_per_tile, type=T.i32) + if _merge_gate_up_tdm: + _n_pair_init = _stage1_pair_row_base() + _data_adv_i32 = arith.constant(packed_tile_k_b * 16, type=T.i32) + + _stages_b_lds_addr = [ + _tdm_desc_lds_addr( + make_desc_b_pair( + lds_bg_pair_bufs[i], + _n_pair_init, + _zero_k_base, + ) + ) + for i in range_constexpr(_nb) + ] + _stages_bs_lds_addr = [ + _tdm_desc_lds_addr( + make_desc_bs_pair( + lds_bs_pair_bufs[i], + _n_pair_init, + _zero_k_base, + ) + ) + for i in range_constexpr(_nb) + ] + + _desc_b_init = make_desc_b_pair( + lds_bg_pair_bufs[0], + _n_pair_init, + _zero_k_base, + ) + _desc_bs_init = make_desc_bs_pair( + lds_bs_pair_bufs[0], + _n_pair_init, + _zero_k_base, + ) + + _active_stage_lds_addr = [ + _select_wave_tdm_value( + _stages_b_lds_addr[i], + _stages_bs_lds_addr[i], + ) + for i in range_constexpr(_nb) + ] + _active_addr_lo = _select_wave_tdm_value( + _tdm_desc_addr_lo(_desc_b_init), + _tdm_desc_addr_lo(_desc_bs_init), + ) + _active_addr_hi = _select_wave_tdm_value( + _tdm_desc_addr_hi(_desc_b_init), + _tdm_desc_addr_hi(_desc_bs_init), + ) + _active_dgroup1 = _select_wave_tdm_value( + _desc_b_init.dgroup1, + _desc_bs_init.dgroup1, + ) + _active_adv_i32 = _select_wave_tdm_value( + _data_adv_i32, + _scale_adv_i32, + ) + else: + _eid_row = ( + arith.index_cast(T.index, eid_i32) + * arith.index(int(2 * N)) + ) + _n_gate_init = _eid_row + blk_n + _n_up_init = _eid_row + blk_n + arith.index(int(N)) + _data_adv_i32 = arith.constant( + packed_tile_k_b if is_fp4 else packed_tile_k_b * 16, + type=T.i32, + ) + + _stages_bg_lds_addr = [ + _tdm_desc_lds_addr( + make_desc_b( + lds_bg_bufs[i], + _n_gate_init, + _zero_k_base, + ) + ) + for i in range_constexpr(_nb) + ] + _stages_bu_lds_addr = [ + _tdm_desc_lds_addr( + make_desc_b( + lds_bu_bufs[i], + _n_up_init, + _zero_k_base, + ) + ) + for i in range_constexpr(_nb) + ] + _stages_bs_lds_addr = [ + _tdm_desc_lds_addr( + make_desc_bs( + lds_bs_bufs[i], + _n_gate_init, + _zero_k_base, + ) + ) + for i in range_constexpr(_nb) + ] + _stages_bsu_lds_addr = [ + _tdm_desc_lds_addr( + make_desc_bs( + lds_bsu_bufs[i], + _n_up_init, + _zero_k_base, + ) + ) + for i in range_constexpr(_nb) + ] + + _desc_bg_init = make_desc_b( + lds_bg_bufs[0], + _n_gate_init, + _zero_k_base, + ) + _desc_bu_init = make_desc_b( + lds_bu_bufs[0], + _n_up_init, + _zero_k_base, + ) + _desc_bs_init = make_desc_bs( + lds_bs_bufs[0], + _n_gate_init, + _zero_k_base, + ) + _desc_bsu_init = make_desc_bs( + lds_bsu_bufs[0], + _n_up_init, + _zero_k_base, + ) + + _active_stage_lds_addr = [ + _select_wave_tdm_value( + _stages_bg_lds_addr[i], + _stages_bu_lds_addr[i], + _stages_bs_lds_addr[i], + _stages_bsu_lds_addr[i], + ) + for i in range_constexpr(_nb) + ] + _active_addr_lo = _select_wave_tdm_value( + _tdm_desc_addr_lo(_desc_bg_init), + _tdm_desc_addr_lo(_desc_bu_init), 
+ _tdm_desc_addr_lo(_desc_bs_init), + _tdm_desc_addr_lo(_desc_bsu_init), + ) + _active_addr_hi = _select_wave_tdm_value( + _tdm_desc_addr_hi(_desc_bg_init), + _tdm_desc_addr_hi(_desc_bu_init), + _tdm_desc_addr_hi(_desc_bs_init), + _tdm_desc_addr_hi(_desc_bsu_init), + ) + _active_dgroup1 = _select_wave_tdm_value( + _desc_bg_init.dgroup1, + _desc_bu_init.dgroup1, + _desc_bs_init.dgroup1, + _desc_bsu_init.dgroup1, + ) + _active_adv_i32 = _select_wave_tdm_value( + _data_adv_i32, + _data_adv_i32, + _scale_adv_i32, + _scale_adv_i32, + ) + + def _issue_active_b_tdm_only(stage_idx, curr_addr_lo): + _if_loader = scf.IfOp(_is_loader_wave) + with ir.InsertionPoint(_if_loader.then_block): + _dg0 = vector.from_elements(T.vec(4, T.i32), [ + _tdm_pred, + _active_stage_lds_addr[stage_idx], + curr_addr_lo, + _active_addr_hi, + ]) + tdm_ops.tensor_load_2d( + tdm_ops.TDMDescriptor2D(_dg0, _active_dgroup1) + ) + scf.YieldOp([]) + _next_addr_lo = arith.addi(curr_addr_lo, _active_adv_i32) + return arith.select( + _is_loader_wave, + _next_addr_lo, + curr_addr_lo, + ) + + # ── pipeline load helpers ───────────────────────────────── + def _issue_b_tdm_only(k_base, buf_idx): + if _merge_gate_up_tdm: + _n_pair = _stage1_pair_row_base() + tdm_ops.tensor_load_2d( + make_desc_b_pair(lds_bg_pair_bufs[buf_idx], _n_pair, k_base)) + tdm_ops.tensor_load_2d( + make_desc_bs_pair(lds_bs_pair_bufs[buf_idx], _n_pair, k_base)) + else: + _eid_row = (arith.index_cast(T.index, eid_i32) + * arith.index(int(2 * N))) + _n_gate = _eid_row + blk_n + _n_up = _eid_row + blk_n + arith.index(int(N)) + tdm_ops.tensor_load_2d( + make_desc_b(lds_bg_bufs[buf_idx], _n_gate, k_base)) + tdm_ops.tensor_load_2d( + make_desc_b(lds_bu_bufs[buf_idx], _n_up, k_base)) + tdm_ops.tensor_load_2d( + make_desc_bs(lds_bs_bufs[buf_idx], _n_gate, k_base)) + tdm_ops.tensor_load_2d( + make_desc_bs(lds_bsu_bufs[buf_idx], _n_up, k_base)) + + def _issue_scalar_loads(k_base, buf_idx): + if _use_tdm_gather_a: + issue_a_load_tdm_gather(k_base, lds_ag_bufs[buf_idx]) + else: + issue_a_load(make_desc_a(k_base), lds_ag_bufs[buf_idx]) + issue_as_load(make_desc_as(k_base), lds_as_bufs[buf_idx]) + + def _issue_all_loads(k_base, buf_idx): + _issue_b_tdm_only(k_base, buf_idx) + _issue_scalar_loads(k_base, buf_idx) + + def _compute_with_mid_loads(acg, acu, buf_idx, mid_load_callback=None): + if _use_scheduled_compute: + return _compute_k_tile_scheduled( + acg, acu, buf_idx, + mid_compute_callback=mid_load_callback, + ) + return _compute_k_tile( + acg, acu, buf_idx, + mid_compute_callback=mid_load_callback, + ) + + # ── main K-dimension reduction ──────────────────────────── + if not _use_pipeline: + if wave_specialized_tdm: + active_b_addr_lo = _active_addr_lo + for kt in range_constexpr(num_k_tiles): + k_base = fx.Index(kt * int(tile_k)) + active_b_addr_lo = _issue_active_b_tdm_only( + 0, active_b_addr_lo) + _issue_scalar_loads(k_base, 0) + tdm_ops.tensor_wait(0) + workgroup_barrier(use_cluster=use_cluster) + acc_g, acc_u = _compute_k_tile(acc_g, acc_u, 0) + workgroup_barrier(use_cluster=use_cluster) + else: + for kt in range_constexpr(num_k_tiles): + k_base = fx.Index(kt * int(tile_k)) + _issue_all_loads(k_base, 0) + tdm_ops.tensor_wait(0) + workgroup_barrier(use_cluster=use_cluster) + acc_g, acc_u = _compute_k_tile(acc_g, acc_u, 0) + workgroup_barrier(use_cluster=use_cluster) + else: + # ── prologue ── + if wave_specialized_tdm: + active_b_addr_lo = _active_addr_lo + for _pi in range_constexpr(pre_loaded): + active_b_addr_lo = _issue_active_b_tdm_only( + _pi, active_b_addr_lo) 
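+                # Prologue stage _pi: the call above issues this stage's
+                # B / B-scale TDM loads on the loader waves only and advances
+                # the per-wave descriptor address; the call below issues the
+                # A data and A-scale loads for K-tile _pi into LDS buffer _pi.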
+ _issue_scalar_loads(fx.Index(_pi * int(tile_k)), _pi) + else: + for _pi in range_constexpr(pre_loaded): + _issue_all_loads(fx.Index(_pi * int(tile_k)), _pi) + pipeline_fence(outstanding=0, use_cluster=use_cluster) + + # ── main pipelined loop ── + if loop_iters > 0: + if wave_specialized_tdm: + _init = list(acc_g) + list(acc_u) + [active_b_addr_lo] + for _li, _st in fx.range(0, loop_iters, 1, init=_init): + _ag = list(_st[:n_accs]) + _au = list(_st[n_accs:2 * n_accs]) + _cur_b_addr_lo = _st[2 * n_accs] + for _bi in range_constexpr(_nb): + _lb = (_bi + _nb - 1) % _nb + _kt = (_li * fx.Index(_nb) + + fx.Index(pre_loaded + _bi)) + _kb = _kt * fx.Index(int(tile_k)) + pipeline_fence_signal( + outstanding=_fence_outstanding, + use_cluster=use_cluster) + pipeline_fence_wait(use_cluster=use_cluster) + _cur_b_addr_lo = _issue_active_b_tdm_only( + _lb, _cur_b_addr_lo) + + def _mid_issue_scalar(_mid_kb=_kb, _mid_lb=_lb): + _issue_scalar_loads(_mid_kb, _mid_lb) + + if _use_scheduled_compute: + rocdl.sched_barrier(0) + _ag, _au = _compute_with_mid_loads( + _ag, + _au, + _bi, + _mid_issue_scalar, + ) + if _use_scheduled_compute: + _hot_loop_scheduler_scheduled() + _res = yield list(_ag) + list(_au) + [_cur_b_addr_lo] + acc_g = list(_res[:n_accs]) + acc_u = list(_res[n_accs:2 * n_accs]) + active_b_addr_lo = _res[2 * n_accs] + else: + _init = list(acc_g) + list(acc_u) + for _li, _st in fx.range(0, loop_iters, 1, init=_init): + _ag = list(_st[:n_accs]) + _au = list(_st[n_accs:2 * n_accs]) + for _bi in range_constexpr(_nb): + _lb = (_bi + _nb - 1) % _nb + _kt = (_li * fx.Index(_nb) + + fx.Index(pre_loaded + _bi)) + _kb = _kt * fx.Index(int(tile_k)) + pipeline_fence_signal( + outstanding=_fence_outstanding, + use_cluster=use_cluster) + pipeline_fence_wait(use_cluster=use_cluster) + _issue_b_tdm_only(_kb, _lb) + + def _mid_issue_scalar(_mid_kb=_kb, _mid_lb=_lb): + _issue_scalar_loads(_mid_kb, _mid_lb) + + if _use_scheduled_compute: + rocdl.sched_barrier(0) + _ag, _au = _compute_with_mid_loads( + _ag, + _au, + _bi, + _mid_issue_scalar, + ) + if _use_scheduled_compute: + _hot_loop_scheduler_scheduled() + _res = yield list(_ag) + list(_au) + acc_g = list(_res[:n_accs]) + acc_u = list(_res[n_accs:2 * n_accs]) + + # ── post-loop fence ── + if loop_iters > 0: + pipeline_fence(outstanding=0, use_cluster=use_cluster) + elif use_cluster: + gpu.cluster_barrier() + + # ── tail ── + _tail_li = 0 + _tail_had_load = False + for _ls, _cs, _out in _tail_plan: + if _out == -1: + if _tail_had_load: + pipeline_fence(outstanding=0, + use_cluster=use_cluster) + if _use_scheduled_compute: + rocdl.sched_barrier(0) + acc_g, acc_u = _compute_k_tile_scheduled( + acc_g, acc_u, _cs) + _hot_loop_scheduler_scheduled() + else: + acc_g, acc_u = _compute_k_tile( + acc_g, acc_u, _cs) + else: + pipeline_fence_signal(outstanding=_out, + use_cluster=use_cluster) + pipeline_fence_wait(use_cluster=use_cluster) + if _ls is not None: + _tail_had_load = True + _tkb = fx.Index( + (_tail_start + pre_loaded + _tail_li) + * int(tile_k)) + _tail_li += 1 + if wave_specialized_tdm: + active_b_addr_lo = _issue_active_b_tdm_only( + _ls, active_b_addr_lo) + else: + _issue_b_tdm_only(_tkb, _ls) + + def _tail_mid_issue_scalar(_mid_kb=_tkb, _mid_ls=_ls): + _issue_scalar_loads(_mid_kb, _mid_ls) + + if _use_scheduled_compute: + rocdl.sched_barrier(0) + acc_g, acc_u = _compute_with_mid_loads( + acc_g, + acc_u, + _cs, + _tail_mid_issue_scalar, + ) + if _use_scheduled_compute: + _hot_loop_scheduler_scheduled() + else: + if _use_scheduled_compute: + 
rocdl.sched_barrier(0) + acc_g, acc_u = _compute_k_tile_scheduled( + acc_g, acc_u, _cs) + _hot_loop_scheduler_scheduled() + else: + acc_g, acc_u = _compute_k_tile( + acc_g, acc_u, _cs) + + out_elem_ty = _moe_out_elem_ty(out_dtype, T) + + if bool(use_tdm_store): + # ── TDM store epilogue: silu(gate)*up → LDS → global (contiguous sorted output) ── + _scale_per_wm_s1 = [] + for _wm in range_constexpr(wmma_m_rep): + _m_off_val = _wm * WMMA_M + _row_local = warp_m_base + arith.index(_m_off_val) + lane16 + _sorted_row = by * arith.index(int(tile_m)) + _row_local + _sorted_i32 = arith.index_cast(T.i32, _sorted_row) + _row_in_route = arith.cmpi( + arith.CmpIPredicate.ult, + arith.index_cast(T.i32, _row_local), + arith.constant(int(route_tile_m), type=T.i32)) + if bool(doweight_stage1): + _sorted_safe = arith.select( + _row_in_route, _sorted_i32, + arith.index_cast(T.i32, + by * arith.index(int(route_tile_m)))) + _tw = buffer_ops.buffer_load( + tw_rsrc, _sorted_safe, vec_width=1, dtype=T.f32) + _sc = arith.select( + _row_in_route, _tw, + arith.constant(0.0, type=T.f32)) + else: + _sc = arith.select( + _row_in_route, + arith.constant(1.0, type=T.f32), + arith.constant(0.0, type=T.f32)) + _scale_per_wm_s1.append(_sc) + + if d_need_epilogue_fence_s1: + pipeline_fence(outstanding=0, use_cluster=use_cluster) + rocdl.sched_barrier(0) + + for _acc_idx, _vec_base, _m_off, _wn in _sub_tiles: + _wm_idx = _m_off // WMMA_M + _sc = _scale_per_wm_s1[_wm_idx] + _sub8g = _extract_sub8( + acc_g[_acc_idx], _vec_base, + vector=vector, + range_constexpr=range_constexpr, + ACC_VEC_SIZE=ACC_VEC_SIZE) + _sub8u = _extract_sub8( + acc_u[_acc_idx], _vec_base, + vector=vector, + range_constexpr=range_constexpr, + ACC_VEC_SIZE=ACC_VEC_SIZE) + _fused = [] + for _vi in range_constexpr(8): + _vg = vector.extract( + _sub8g, + static_position=[_vi], + dynamic_position=[]) + _vu = vector.extract( + _sub8u, + static_position=[_vi], + dynamic_position=[]) + _y = silu(_vg) * _vu * _sc + _fused.append(_y) + _fused_sub8 = vector.from_elements( + T.vec(8, T.f32), _fused) + _imm = (_m_off * _lds_d_stride_elems_s1 + + _wn * _n_col_d_elems_s1) + store_acc_vec8_to_lds( + d_lds_buffer_s1, d_lane_base_s1, _imm, + _fused_sub8, out_elem=out_elem_ty) + + rocdl.s_wait_dscnt(0) + # TDM gather store: each warp stores its warp_tile_m rows + # to scattered output positions tok*topk+slot. + _warp_row_start = arith.index_cast(T.i32, warp_m_base) + _warp_row_start_py = rocdl.readfirstlane(T.i32, _warp_row_start) + _d_store_chunk = 8 # 32-bit gather mode + _d_store_groups = (warp_tile_m + _d_store_chunk - 1) // _d_store_chunk + _tokens_topk_dim1 = _get_tokens_topk_sgpr() + for _dsi in range_constexpr(_d_store_groups): + _ds_start = _dsi * _d_store_chunk + _ds_cnt = min(_d_store_chunk, warp_tile_m - _ds_start) + # Global output row indices for this group + _ds_start_in_tile = _dsi * _d_store_chunk + rocdl.readfirstlane( + T.i32, arith.index_cast(T.i32, warp_m_base)) + # Can't do runtime add on SGPR easily; use compile-time + # warp offset from wave_id. But warp_m_base is runtime. + # Instead, index _a_out_row_ids which is tile-global. + # warp_m_base = wave_m_idx * warp_tile_m (runtime index) + # We need _a_out_row_ids[warp_m_base + _ds_start + i] + # Since warp_m_base depends on wave_id, we use scf.if + # per warp to select the correct slice. 
+ # Simpler: for num_warps_m = m_warp, unroll per warp: + _ds_indices = [] + _ds_valids = [] + for _wi in range_constexpr(int(m_warp)): + _tile_row = _wi * warp_tile_m + _ds_start + _warp_indices = _a_out_row_ids[_tile_row:_tile_row + _ds_cnt] + _warp_valids = _a_store_valids[_tile_row:_tile_row + _ds_cnt] + if _wi == 0: + _ds_indices = list(_warp_indices) + _ds_valids = list(_warp_valids) + else: + _is_this_warp = arith.cmpi( + arith.CmpIPredicate.eq, + rocdl.wave_id() % fx.Int32(int(n_warp * m_warp) // int(n_warp)), + fx.Int32(_wi)) + # Actually wave_m_idx is the M warp index + _is_this_warp = arith.cmpi( + arith.CmpIPredicate.eq, + arith.index_cast(T.i32, wave_m_idx), + fx.Int32(_wi)) + for _ii in range_constexpr(len(_ds_indices)): + _ds_indices[_ii] = arith.select( + _is_this_warp, + _warp_indices[_ii], + _ds_indices[_ii]) + _ds_valids[_ii] = arith.select( + _is_this_warp, + _warp_valids[_ii], + _ds_valids[_ii]) + # LDS offset within D buffer for this group + _ds_lds_off = arith.index( + _ds_start * lds_d_row_stride_s1) + d_warp_off_sgpr_s1 + # Column offset in output + _col_byte_off = (blk_n + warp_n_off_sgpr_s1) * arith.index(elem_bytes_d_s1) + # For store direction: TDM ignores pad_enable, so we + # expand tile_dim0 to include padding so LDS read + # addresses align. tensor_dim0 stays at warp_tile_n so + # the extra pad elements hit OOB and are dropped. + _pad_elems = LDS_PAD_D_BYTES_s1 // elem_bytes_d_s1 + _store_tile_w = warp_tile_n + _pad_elems + _ds_valid_count = _sum_i32_values(_ds_valids) + _zero_i32 = arith.constant(0, type=T.i32) + _has_store = arith.cmpi(arith.CmpIPredicate.sgt, _ds_valid_count, _zero_i32) + _if_store = scf.IfOp(_has_store) + with ir.InsertionPoint(_if_store.then_block): + _d_store_desc = tdm_ops.make_tensor_gather_descriptor( + global_ptr=arg_out, + lds_memref=base_ptr, + row_indices=_ds_indices, + row_width=_store_tile_w, + tensor_dim0=warp_tile_n, + tensor_dim1=_tokens_topk_dim1, + stride=N, + elem_bytes=elem_bytes_d_s1, + pad_interval=0, + pad_amount=0, + index_size=32, + gather_tile_dim1=_ds_valid_count, + lds_byte_offset=_ds_lds_off, + global_byte_offset=_col_byte_off, + ) + tdm_ops.tensor_store_gather(_d_store_desc) + scf.YieldOp([]) + tdm_ops.tensor_wait(0) + else: + def _load_gate_up_sub8(acc_idx, vec_base): + return ( + _extract_sub8( + acc_g[acc_idx], vec_base, vector=vector, range_constexpr=range_constexpr, ACC_VEC_SIZE=ACC_VEC_SIZE + ), + _extract_sub8( + acc_u[acc_idx], vec_base, vector=vector, range_constexpr=range_constexpr, ACC_VEC_SIZE=ACC_VEC_SIZE + ), + ) + + _emit_stage1_gate_up_epilogue( + sub_tiles=_sub_tiles, + by=by, + tile_m=int(tile_m), + route_tile_m=int(route_tile_m), + warp_m_base=warp_m_base, + warp_n_base=warp_n_base, + blk_n=blk_n, + lane16=lane16, + lane_kgrp=lane_kgrp, + WMMA_N=WMMA_N, + i32_tokens_in=i32_tokens_in, + i32_inter_in=i32_inter_in, + topk=int(topk), + num_valid_i32=num_valid_i32, + block_row_start=block_row_start, + sorted_rsrc=sorted_rsrc, + tw_rsrc=tw_rsrc, + out_rsrc=out_rsrc, + doweight_stage1=bool(doweight_stage1), + out_elem_ty=out_elem_ty, + load_gate_up_sub8=_load_gate_up_sub8, + silu_fn=silu, + ir=ir, + fx=fx, + arith=arith, + buffer_ops=buffer_ops, + scf=scf, + vector=vector, + range_constexpr=range_constexpr, + T=T, + ) + scf.YieldOp([]) + + @flyc.jit + def launch_mxscale_stage1_single( + arg_out: fx.Tensor, + arg_x: fx.Tensor, + arg_w: fx.Tensor, + arg_scale_x: fx.Tensor, + arg_scale_w: fx.Tensor, + arg_sorted_token_ids: fx.Tensor, + arg_expert_ids: fx.Tensor, + arg_sorted_weights: fx.Tensor, + 
arg_num_valid_ids: fx.Tensor, + i32_tokens_in: fx.Int32, + i32_inter_in: fx.Int32, + i32_k_in: fx.Int32, + i32_size_expert_ids_in: fx.Int32, + stream: fx.Stream, + ): + _ = i32_k_in + ctx = CompilationContext.get_current() + inter_in = arith.index_cast(T.index, i32_inter_in) + size_expert_ids_in = arith.index_cast(T.index, i32_size_expert_ids_in) + gx = (inter_in + fx.Index(int(tile_n) - 1)) // fx.Index(int(tile_n)) + gy = size_expert_ids_in + launcher = moe_mxscale_stage1_single( + arg_out, arg_x, arg_w, arg_scale_x, arg_scale_w, + arg_sorted_token_ids, arg_expert_ids, arg_sorted_weights, arg_num_valid_ids, + i32_tokens_in, i32_inter_in, i32_k_in, i32_size_expert_ids_in, + ) + _cluster_arg = (int(cluster_m), int(cluster_n), 1) if use_cluster else None + _finalize_alloc_and_launch_2d( + ctx=ctx, + alloc=alloc, + launcher=launcher, + gx=gx, + gy=gy, + block_threads=block_threads, + stream=stream, + waves_per_eu=effective_waves_per_eu, + ir=ir, + cluster=_cluster_arg, + ) + + if expert_sched_mode: + launch_mxscale_stage1_single.compile_hints["llvm_options"] = { + "amdgpu-expert-scheduling-mode": True, + } + + return launch_mxscale_stage1_single + + +@functools.lru_cache(maxsize=64) +def _compile_stage2_mxscale_kernel_impl( + *, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + route_tile_m: int, + tile_m: int, + tile_n: int, + tile_k: int, + m_warp: int, + n_warp: int, + doweight_stage2: bool, + out_dtype: str, + accumulate: bool, + waves_per_eu: int | None, + data_format: str = "fp8", + expert_sched_mode: bool = True, + num_buffers: int = 1, + use_tdm_gather: bool = True, + use_tdm_store: bool = False, + inst_prefetch: bool = False, + wave_specialized_tdm: bool = False, + cluster_m: int = 1, + cluster_n: int = 1, +): + """Compile mxscale stage2 single kernel (route-pack + TDM + WMMA_SCALE + epilog).""" + import flydsl.compiler as flyc + import flydsl.expr as fx + from flydsl._mlir import ir + from flydsl._mlir.dialects import llvm as llvm_dialect + from flydsl._mlir.dialects import scf + from flydsl.compiler.kernel_function import CompilationContext + from flydsl.expr import arith, buffer_ops, gpu, idx2crd, range_constexpr, rocdl, tdm_ops, vector + from flydsl.expr.typing import T + from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr, get_op_result_or_value + + if bool(use_tdm_store) and bool(accumulate): + raise ValueError("use_tdm_store is not compatible with accumulate=True in moe mxscale stage2") + + tp = _compute_mxscale_tiling( + data_format=data_format, K=int(inter_dim), + tile_m=int(tile_m), tile_n=int(tile_n), tile_k=int(tile_k), + m_warp=int(m_warp), n_warp=int(n_warp), out_dtype=out_dtype, + num_buffers=int(num_buffers), cluster_m=int(cluster_m), + cluster_n=int(cluster_n), stage_name="stage2", + ) + is_fp4, is_a8w4 = tp["is_fp4"], tp["is_a8w4"] + PACK_FACTOR_A, PACK_FACTOR_B = tp["PACK_FACTOR_A"], tp["PACK_FACTOR_B"] + ACC_VEC_SIZE = tp["ACC_VEC_SIZE"] + DS_LOADS_PER_A_FRAG = tp["DS_LOADS_PER_A_FRAG"] + WMMA_M, WMMA_N, WMMA_K = tp["WMMA_M"], tp["WMMA_N"], tp["WMMA_K"] + SCALE_BLOCK, SCALES_PER_WMMA = tp["SCALE_BLOCK"], tp["SCALES_PER_WMMA"] + WAVE_SIZE = tp["WAVE_SIZE"] + LDS_PAD_A_BYTES, LDS_PAD_B_BYTES = tp["LDS_PAD_A_BYTES"], tp["LDS_PAD_B_BYTES"] + use_cluster = tp["use_cluster"] + K = tp["K"] + K_packed_a, K_packed_b = tp["K_packed_a"], tp["K_packed_b"] + packed_tile_k_a, packed_tile_k_b = tp["packed_tile_k_a"], tp["packed_tile_k_b"] + K_scale, scale_k_per_tile = tp["K_scale"], tp["scale_k_per_tile"] + block_threads = tp["block_threads"] + 
warp_tile_m, warp_tile_n = tp["warp_tile_m"], tp["warp_tile_n"] + wmma_m_rep, wmma_n_rep = tp["wmma_m_rep"], tp["wmma_n_rep"] + k_wmma_steps, n_accs = tp["k_wmma_steps"], tp["n_accs"] + num_k_tiles = tp["num_k_tiles"] + b_scale_load_rep = tp["b_scale_load_rep"] + interleaved_scale_cols_b = tp["interleaved_scale_cols_b"] + lds_a_stride_bytes = tp["lds_a_stride_bytes"] + lds_b_stride_bytes = tp["lds_b_stride_bytes"] + lds_a_data_bytes, lds_b_data_bytes = tp["lds_a_data_bytes"], tp["lds_b_data_bytes"] + lds_a_scale_bytes, lds_b_scale_bytes = tp["lds_a_scale_bytes"], tp["lds_b_scale_bytes"] + interleaved_scale_cols_a = tp["interleaved_scale_cols_a"] + + N_total = int(model_dim) + num_warps = int(m_warp) * int(n_warp) + if bool(wave_specialized_tdm): + if num_warps < 2: + raise ValueError( + f"wave_specialized_tdm requires at least 2 waves (B + B_scale), got {num_warps}") + _tdm_loader_waves = 2 + tdm_desc_num_warps = 1 if bool(wave_specialized_tdm) else num_warps + effective_waves_per_eu = waves_per_eu + if use_cluster and effective_waves_per_eu is None: + effective_waves_per_eu = 2 + + _use_pipeline = int(num_buffers) >= 2 + if _use_pipeline: + from kernels.gemm_common_gfx1250 import ( + pipeline_fence, pipeline_fence_signal, pipeline_fence_wait, + ) + _B_TDM_PER_STEP = 1 if bool(wave_specialized_tdm) else 2 + _pp = _compute_pipeline_plan( + num_k_tiles=num_k_tiles, num_buffers=int(num_buffers), + B_TDM_PER_STEP=_B_TDM_PER_STEP, tile_m=int(tile_m), + use_tdm_gather=use_tdm_gather, + wave_specialized_tdm=wave_specialized_tdm, + tdm_loader_waves=_tdm_loader_waves, + ) + pre_loaded = _pp["pre_loaded"] + loop_iters = _pp["loop_iters"] + _tail_start = _pp["tail_start"] + extra = _pp["extra"] + _A_GATHER_GROUPS = _pp["A_GATHER_GROUPS"] + TDM_PER_STEP = _pp["TDM_PER_STEP"] + _fence_outstanding = _pp["fence_outstanding"] + _tail_plan = _pp["tail_plan"] + from kernels.gemm_common_gfx1250 import workgroup_barrier + + alloc = SmemAllocator( + None, + arch=str(get_hip_arch()), + global_sym_name=f"moe_mxscale_{data_format}_s2_single_g{int(bool(use_tdm_gather))}", + ) + _nb = int(num_buffers) + off_a_list, off_b_list, off_as_list, off_bs_list = [], [], [], [] + for _buf_i in range(_nb): + _oa = alloc._align(alloc.ptr, 16) + alloc.ptr = _oa + lds_a_data_bytes + off_a_list.append(_oa) + _ob = alloc._align(alloc.ptr, 16) + alloc.ptr = _ob + lds_b_data_bytes + off_b_list.append(_ob) + _oas = alloc._align(alloc.ptr, 16) + alloc.ptr = _oas + lds_a_scale_bytes + off_as_list.append(_oas) + _obs = alloc._align(alloc.ptr, 16) + alloc.ptr = _obs + lds_b_scale_bytes + off_bs_list.append(_obs) + + if bool(use_tdm_store): + from kernels.gemm_common_gfx1250 import store_acc_vec8_to_lds + _ds2 = _compute_tdm_store_layout( + warp_tile_m=warp_tile_m, warp_tile_n=warp_tile_n, + num_warps=num_warps, WMMA_N=WMMA_N, use_pipeline=_use_pipeline, + ) + lds_d_row_stride = _ds2["lds_d_row_stride"] + d_output_off = _ds2["d_output_off"] + _lds_d_stride_elems = _ds2["lds_d_stride_elems"] + _warp_d_elems = _ds2["warp_d_elems"] + _n_col_d_elems = _ds2["n_col_d_elems"] + d_need_epilogue_fence = _ds2["d_need_epilogue_fence"] + if _ds2["total_d_bytes"] > alloc.ptr: + alloc.ptr = _ds2["total_d_bytes"] + + _sub_tiles = _make_wmma_sub_tiles( + wmma_m_rep=wmma_m_rep, wmma_n_rep=wmma_n_rep, WMMA_M=WMMA_M, is_fp4=is_fp4 + ) + + @flyc.kernel(known_block_size=[block_threads, 1, 1]) + def moe_mxscale_stage2_single( + arg_out: fx.Tensor, + arg_x: fx.Tensor, + arg_w: fx.Tensor, + arg_scale_x: fx.Tensor, + arg_scale_w: fx.Tensor, + 
arg_sorted_token_ids: fx.Tensor, + arg_expert_ids: fx.Tensor, + arg_sorted_weights: fx.Tensor, + arg_num_valid_ids: fx.Tensor, + i32_tokens_in: fx.Int32, + i32_n_in: fx.Int32, + i32_k_in: fx.Int32, + i32_size_expert_ids_in: fx.Int32, + ): + _ = i32_k_in + if inst_prefetch: + if arith.cmpi(arith.CmpIPredicate.eq, rocdl.wave_id(), + arith.constant(0, type=T.i32)): + _prefetch_lines = ["s_setreg_imm32_b32 hwreg(HW_REG_WAVE_MODE, 8, 1), 1"] + for _pg in range_constexpr(10): + _prefetch_lines.append( + f"s_prefetch_inst_pc_rel {_pg * 4096}, s0, 31") + llvm_dialect.inline_asm( + None, [], + "\n".join(_prefetch_lines), + "", has_side_effects=True, + ) + llvm_dialect.inline_asm( + None, [], + "s_setreg_imm32_b32 hwreg(26, 4, 1), 1", + "", + has_side_effects=True, + ) + + tx = gpu.thread_id("x") + bx = gpu.block_id("x") + by = gpu.block_id("y") + + tokens_idx = arith.index_cast(T.index, i32_tokens_in) + n_idx = arith.index_cast(T.index, i32_n_in) + size_expert_ids = arith.index_cast(T.index, i32_size_expert_ids_in) + c_topk_i32 = arith.constant(int(topk), type=T.i32) + num_valid_i32 = buffer_ops.buffer_load( + buffer_ops.create_buffer_resource(arg_num_valid_ids, max_size=True), + arith.constant(0, type=T.i32), + vec_width=1, + dtype=T.i32, + ) + + sorted_num = size_expert_ids * arith.index(int(route_tile_m)) + sorted_nbytes = sorted_num * arith.index(4) + eid_nbytes = size_expert_ids * arith.index(4) + x_rows = tokens_idx * arith.index(int(topk)) + x_nbytes = x_rows * arith.index(K_packed_a) + sx_nbytes = x_rows * arith.index(K_scale) + w_rows = arith.index(int(experts)) * n_idx + w_nbytes = w_rows * arith.index(K_packed_b) + sw_nbytes = w_rows * arith.index(K_scale) + out_nbytes = tokens_idx * n_idx * arith.index(2) + if not bool(accumulate): + out_nbytes = x_rows * n_idx * arith.index(2) + + sorted_rsrc = buffer_ops.create_buffer_resource(arg_sorted_token_ids, max_size=False, num_records_bytes=sorted_nbytes) + eid_rsrc = buffer_ops.create_buffer_resource(arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes) + x_rsrc = buffer_ops.create_buffer_resource(arg_x, max_size=False, num_records_bytes=x_nbytes) + sx_rsrc = buffer_ops.create_buffer_resource(arg_scale_x, max_size=False, num_records_bytes=sx_nbytes) + w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False, num_records_bytes=w_nbytes) + sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False, num_records_bytes=sw_nbytes) + out_rsrc = buffer_ops.create_buffer_resource(arg_out, max_size=False, num_records_bytes=out_nbytes) + tw_rsrc = buffer_ops.create_buffer_resource(arg_sorted_weights, max_size=True) + + eid_i32 = buffer_ops.buffer_load(eid_rsrc, arith.index_cast(T.i32, by), vec_width=1, dtype=T.i32) + eid_ok0 = arith.cmpi(arith.CmpIPredicate.sge, eid_i32, arith.constant(0, type=T.i32)) + eid_ok1 = arith.cmpi(arith.CmpIPredicate.slt, eid_i32, arith.constant(int(experts), type=T.i32)) + block_row_start = arith.index_cast(T.i32, by * arith.index(int(route_tile_m))) + block_in_valid = arith.cmpi(arith.CmpIPredicate.slt, block_row_start, num_valid_i32) + block_ok = arith.andi(block_in_valid, arith.andi(eid_ok0, eid_ok1)) + + layout_thr = _make_moe_wave_layout(m_warp=m_warp, n_warp=n_warp, WAVE_SIZE=WAVE_SIZE, fx=fx) + thr_coord = idx2crd(tx, layout_thr) + wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( + fx.get(thr_coord, 0), fx.get(thr_coord, 1), fx.get(thr_coord, 2), fx.get(thr_coord, 3) + ) + warp_m_base = wave_m_idx * arith.index(warp_tile_m) + warp_n_base = wave_n_idx * arith.index(warp_tile_n) + blk_n = bx * 
arith.index(int(tile_n)) + + if use_cluster: + _local_x, _local_y = gpu.compute_cluster_position() + _a_mcast_mask, b_mcast_mask = gpu.compute_mcast_masks( + _local_x, _local_y, int(cluster_m), int(cluster_n)) + else: + b_mcast_mask = 0 + + base_ptr = alloc.get_base() + lds_a_bufs = [] + lds_b_bufs = [] + lds_as_bufs = [] + lds_bs_bufs = [] + for _bi in range_constexpr(_nb): + _sa = SmemPtr(base_ptr, off_a_list[_bi], T.i8, shape=(lds_a_data_bytes,)) + _sb = SmemPtr(base_ptr, off_b_list[_bi], T.i8, shape=(lds_b_data_bytes,)) + _sas = SmemPtr(base_ptr, off_as_list[_bi], T.i8, shape=(lds_a_scale_bytes,)) + _sbs = SmemPtr(base_ptr, off_bs_list[_bi], T.i8, shape=(lds_b_scale_bytes,)) + lds_a_bufs.append(get_op_result_or_value(_sa.get())) + lds_b_bufs.append(get_op_result_or_value(_sb.get())) + lds_as_bufs.append(get_op_result_or_value(_sas.get())) + lds_bs_bufs.append(get_op_result_or_value(_sbs.get())) + + if bool(use_tdm_store): + from kernels.gemm_common_gfx1250 import get_lds_memref + d_lds_f16_count = total_d_bytes // 2 + d_smem = SmemPtr(base_ptr, d_output_off, T.f16, + shape=(d_lds_f16_count,)) + d_lds_buffer = get_lds_memref(d_smem) + warp_lds_off = ( + (wave_m_idx * arith.index(int(n_warp)) + wave_n_idx) + * arith.index(_warp_d_elems) + ) + d_lane_base = ( + warp_lds_off + + lane16 * arith.index(_lds_d_stride_elems) + + lane_kgrp * arith.index(4 * elem_bytes_d) + ) + wave_id_idx = arith.index_cast(T.index, rocdl.wave_id()) + d_warp_off_sgpr = ( + wave_id_idx * arith.index(warp_d_bytes) + + arith.index(d_output_off) + ) + warp_m_off_sgpr = ( + (wave_id_idx / arith.index(int(n_warp))) + * arith.index(warp_tile_m) + ) + warp_n_off_sgpr = ( + (wave_id_idx % arith.index(int(n_warp))) + * arith.index(warp_tile_n) + ) + d_desc = tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_out, + lds_memref=base_ptr, + global_offset=( + by * arith.index(int(tile_m)) + warp_m_off_sgpr, + blk_n + warp_n_off_sgpr, + ), + tensor_shape=(warp_tile_m, warp_tile_n), + strides=(N_total, 1), + tile_shape=(warp_tile_m, warp_tile_n), + elem_bytes=elem_bytes_d, + pad_interval=warp_tile_n, + pad_amount=LDS_PAD_D_BYTES // elem_bytes_d, + num_warps=1, + lds_byte_offset=d_warp_off_sgpr, + for_store=True, + ) + + _use_tdm_gather_a = bool(use_tdm_gather) + _a_row_ids = [] + _a_row_valids = [] + _TDM_GATHER_CHUNK = 8 + _TDM_GATHER_GROUPS = (int(tile_m) + _TDM_GATHER_CHUNK - 1) // _TDM_GATHER_CHUNK + _tokens_topk_sgpr = None + + def _sum_i32_values(_vals): + _acc = arith.constant(0, type=T.i32) + for _vi in range_constexpr(len(_vals)): + _acc = _acc + _vals[_vi] + return _acc + + def _get_tokens_topk_sgpr(): + nonlocal _tokens_topk_sgpr + if _tokens_topk_sgpr is None: + _m_i32 = arith.index_cast( + T.i32, + tokens_idx * arith.index(int(topk)), + ) + _tokens_topk_sgpr = rocdl.readfirstlane(T.i32, _m_i32) + return _tokens_topk_sgpr + + def _precompute_a_row_indices(): + _safe_row = arith.constant(0, type=T.i32) + _one_i32 = arith.constant(1, type=T.i32) + _zero_i32 = arith.constant(0, type=T.i32) + for _ri in range_constexpr(int(tile_m)): + _sorted_row = by * fx.Index(int(tile_m)) + fx.Index(_ri) + _sorted_i32 = arith.index_cast(T.i32, _sorted_row) + _row_in_route = arith.cmpi( + arith.CmpIPredicate.ult, + fx.Int32(_ri), + fx.Int32(int(route_tile_m)), + ) + _row_in_valid = arith.cmpi( + arith.CmpIPredicate.slt, + _sorted_i32, + num_valid_i32, + ) + _row_ok = arith.andi(_row_in_route, _row_in_valid) + _sorted_safe = arith.select(_row_ok, _sorted_i32, block_row_start) + _fused = buffer_ops.buffer_load(sorted_rsrc, 
_sorted_safe, vec_width=1, dtype=T.i32) + _tok = _fused & fx.Int32((1 << 24) - 1) + _slot = _fused >> fx.Int32(24) + _tok_ok = arith.cmpi(arith.CmpIPredicate.ult, _tok, i32_tokens_in) + _slot_ok0 = arith.cmpi(arith.CmpIPredicate.sge, _slot, fx.Int32(0)) + _slot_ok1 = arith.cmpi(arith.CmpIPredicate.slt, _slot, c_topk_i32) + _ts = _tok * c_topk_i32 + _slot + _ts_ok = arith.andi(_tok_ok, arith.andi(_slot_ok0, _slot_ok1)) + _row_fully_ok = arith.andi(_row_ok, _ts_ok) + _row_valid_i32 = arith.select(_row_fully_ok, _one_i32, _zero_i32) + _a_row_valids.append(rocdl.readfirstlane(T.i32, _row_valid_i32)) + _ts_safe = arith.select(_row_fully_ok, _ts, _safe_row) + _a_row_ids.append(rocdl.readfirstlane(T.i32, _ts_safe)) + + def make_desc_a(k_base): + return k_base / arith.index(PACK_FACTOR_A) + + def issue_a_load(k_packed_base, target_lds): + total = int(tile_m * packed_tile_k_a) + rounds = (total + block_threads - 1) // block_threads + for it in range(rounds): + elem = tx + fx.Index(it * block_threads) + in_range = arith.cmpi(arith.CmpIPredicate.ult, arith.index_cast(T.i32, elem), arith.constant(total, type=T.i32)) + _if_elem = scf.IfOp(in_range) + with ir.InsertionPoint(_if_elem.then_block): + row = elem // arith.index(int(packed_tile_k_a)) + col = elem % arith.index(int(packed_tile_k_a)) + sorted_row = by * arith.index(int(tile_m)) + row + row_i32 = arith.index_cast(T.i32, row) + sorted_i32 = arith.index_cast(T.i32, sorted_row) + row_in_route = arith.cmpi(arith.CmpIPredicate.ult, row_i32, arith.constant(int(route_tile_m), type=T.i32)) + row_in_valid = arith.cmpi(arith.CmpIPredicate.slt, sorted_i32, num_valid_i32) + row_ok = arith.andi(row_in_route, row_in_valid) + sorted_safe = arith.select(row_ok, sorted_i32, block_row_start) + fused = buffer_ops.buffer_load(sorted_rsrc, sorted_safe, vec_width=1, dtype=T.i32) + tok = fused & arith.constant((1 << 24) - 1, type=T.i32) + slot = fused >> arith.constant(24, type=T.i32) + tok_ok = arith.cmpi(arith.CmpIPredicate.ult, tok, i32_tokens_in) + slot_ok0 = arith.cmpi(arith.CmpIPredicate.sge, slot, arith.constant(0, type=T.i32)) + slot_ok1 = arith.cmpi(arith.CmpIPredicate.slt, slot, c_topk_i32) + ts = tok * c_topk_i32 + slot + ts_ok = arith.andi(tok_ok, arith.andi(slot_ok0, slot_ok1)) + load_ok = arith.andi(row_ok, ts_ok) + x_idx = ts * arith.constant(K_packed_a, type=T.i32) + arith.index_cast(T.i32, k_packed_base + col) + x_idx_safe = arith.select(load_ok, x_idx, arith.constant(0, type=T.i32)) + x_val = arith.select(load_ok, buffer_ops.buffer_load(x_rsrc, x_idx_safe, vec_width=1, dtype=T.i8), arith.constant(0, type=T.i8)) + lds_idx = row * arith.index(lds_a_stride_bytes) + col + v1 = vector.from_elements(T.vec(1, T.i8), [x_val]) + vector.store(v1, target_lds, [lds_idx], alignment=1) + scf.YieldOp([]) + + def issue_a_load_tdm_gather(k_base, target_lds): + """Load stage2 A rows via TDM gather using token-slot row ids.""" + k_packed_base = k_base if PACK_FACTOR_A == 1 else k_base // fx.Index(PACK_FACTOR_A) + _tokens_topk = _get_tokens_topk_sgpr() + _zero_i32 = arith.constant(0, type=T.i32) + for _gi in range_constexpr(_TDM_GATHER_GROUPS): + _start = _gi * _TDM_GATHER_CHUNK + _cnt = min(_TDM_GATHER_CHUNK, int(tile_m) - _start) + _row_indices = _a_row_ids[_start:_start + _cnt] + _valid_count = _sum_i32_values(_a_row_valids[_start:_start + _cnt]) + _lds_off = fx.Index(_start * lds_a_stride_bytes) + _has_valid = arith.cmpi(arith.CmpIPredicate.sgt, _valid_count, _zero_i32) + _issue_pred = _has_valid + if wave_specialized_tdm: + _gather_owner = _gi % _tdm_loader_waves + 
_is_gather_loader = arith.cmpi( + arith.CmpIPredicate.eq, + _tdm_wave_id, + arith.constant(_gather_owner, type=T.i32), + ) + _issue_pred = arith.andi(_issue_pred, _is_gather_loader) + _if_issue = scf.IfOp(_issue_pred) + with ir.InsertionPoint(_if_issue.then_block): + desc = tdm_ops.make_tensor_gather_descriptor( + global_ptr=arg_x, + lds_memref=target_lds, + row_indices=_row_indices, + row_width=int(packed_tile_k_a), + tensor_dim0=K_packed_a, + tensor_dim1=_tokens_topk, + stride=K_packed_a, + elem_bytes=1, + pad_interval=int(packed_tile_k_a) if LDS_PAD_A_BYTES > 0 else 0, + pad_amount=LDS_PAD_A_BYTES if LDS_PAD_A_BYTES > 0 else 0, + index_size=32, + gather_tile_dim1=_valid_count, + lds_byte_offset=_lds_off, + global_byte_offset=k_packed_base, + ) + tdm_ops.tensor_load_gather(desc) + scf.YieldOp([]) + + def make_desc_as(k_base): + return k_base / arith.index(SCALE_BLOCK) + + def issue_as_load(k_scale_base, target_lds): + total = int(tile_m * scale_k_per_tile) + rounds = (total + block_threads - 1) // block_threads + for it in range(rounds): + elem = tx + fx.Index(it * block_threads) + in_range = arith.cmpi(arith.CmpIPredicate.ult, arith.index_cast(T.i32, elem), arith.constant(total, type=T.i32)) + _if_elem = scf.IfOp(in_range) + with ir.InsertionPoint(_if_elem.then_block): + row = elem // arith.index(int(scale_k_per_tile)) + ksc = elem % arith.index(int(scale_k_per_tile)) + sorted_row = by * arith.index(int(tile_m)) + row + row_i32 = arith.index_cast(T.i32, row) + sorted_i32 = arith.index_cast(T.i32, sorted_row) + row_in_route = arith.cmpi(arith.CmpIPredicate.ult, row_i32, arith.constant(int(route_tile_m), type=T.i32)) + row_in_valid = arith.cmpi(arith.CmpIPredicate.slt, sorted_i32, num_valid_i32) + row_ok = arith.andi(row_in_route, row_in_valid) + sorted_safe = arith.select(row_ok, sorted_i32, block_row_start) + fused = buffer_ops.buffer_load(sorted_rsrc, sorted_safe, vec_width=1, dtype=T.i32) + tok = fused & arith.constant((1 << 24) - 1, type=T.i32) + slot = fused >> arith.constant(24, type=T.i32) + tok_ok = arith.cmpi(arith.CmpIPredicate.ult, tok, i32_tokens_in) + slot_ok0 = arith.cmpi(arith.CmpIPredicate.sge, slot, arith.constant(0, type=T.i32)) + slot_ok1 = arith.cmpi(arith.CmpIPredicate.slt, slot, c_topk_i32) + ts = tok * c_topk_i32 + slot + ts_ok = arith.andi(tok_ok, arith.andi(slot_ok0, slot_ok1)) + load_ok = arith.andi(row_ok, ts_ok) + ksc_off = k_scale_base + ksc + sx_idx = ts * arith.constant(K_scale, type=T.i32) + arith.index_cast(T.i32, ksc_off) + sx_idx_safe = arith.select(load_ok, sx_idx, arith.constant(0, type=T.i32)) + sx_val = arith.select(load_ok, buffer_ops.buffer_load(sx_rsrc, sx_idx_safe, vec_width=1, dtype=T.i8), arith.constant(127, type=T.i8)) + if is_fp4: + lds_idx = row * arith.index(int(scale_k_per_tile)) + ksc + else: + warp_row_idx = row / arith.index(warp_tile_m) + local_row = row % arith.index(warp_tile_m) + lane_row = local_row % arith.index(WMMA_M) + local_wm_idx = local_row / arith.index(WMMA_M) + global_lds_row = warp_row_idx * arith.index(WMMA_M) + lane_row + ksc_blk = ksc / arith.index(SCALES_PER_WMMA) + ksc_sub = ksc % arith.index(SCALES_PER_WMMA) + lds_idx = ( + global_lds_row * arith.index(interleaved_scale_cols_a) + + ksc_blk * arith.index(wmma_m_rep * SCALES_PER_WMMA) + + local_wm_idx * arith.index(SCALES_PER_WMMA) + + ksc_sub + ) + v1 = vector.from_elements(T.vec(1, T.i8), [sx_val]) + vector.store(v1, target_lds, [lds_idx], alignment=1) + scf.YieldOp([]) + + def make_desc_b(n_off, k_base, target_lds): + if is_fp4: + return 
tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_w, lds_memref=target_lds, + global_offset=(n_off, k_base / arith.index(PACK_FACTOR_B)), + tensor_shape=(int(tile_n), int(packed_tile_k_b)), + strides=(K_packed_b, 1), + tile_shape=(int(tile_n), int(packed_tile_k_b)), + elem_bytes=1, pad_interval=int(packed_tile_k_b), pad_amount=LDS_PAD_B_BYTES, + num_warps=tdm_desc_num_warps, workgroup_mask=b_mcast_mask) + return tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_w, lds_memref=target_lds, + global_offset=(n_off / arith.index(16), (k_base / arith.index(PACK_FACTOR_B)) * arith.index(16)), + tensor_shape=(int(N_total // 16), int(K_packed_b * 16)), + strides=(int(K_packed_b * 16), 1), + tile_shape=(int(tile_n // 16), int(packed_tile_k_b * 16)), + elem_bytes=1, + pad_interval=0, pad_amount=0, + num_warps=tdm_desc_num_warps, + workgroup_mask=b_mcast_mask) + + def make_desc_bs(n_off, k_base, target_lds): + return tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_scale_w, lds_memref=target_lds, + global_offset=(n_off, k_base / arith.index(SCALE_BLOCK)), + tensor_shape=(int(tile_n), int(scale_k_per_tile)), + strides=(K_scale, 1), + tile_shape=(int(tile_n), int(scale_k_per_tile)), + elem_bytes=1, pad_interval=0, pad_amount=0, + num_warps=tdm_desc_num_warps, workgroup_mask=b_mcast_mask) + + def issue_b_load(k_base, target_lds_b, target_lds_bs): + eid_idx = arith.index_cast(T.index, eid_i32) + n_off = eid_idx * n_idx + blk_n + tdm_ops.tensor_load_2d(make_desc_b(n_off, k_base, target_lds_b)) + tdm_ops.tensor_load_2d(make_desc_bs(n_off, k_base, target_lds_bs)) + + _ldrs = _make_mxscale_data_loaders( + tiling=tp, warp_m_base=warp_m_base, warp_n_base=warp_n_base, + wave_n_idx=wave_n_idx, lane16=lane16, lane_kgrp=lane_kgrp, + ir=ir, arith=arith, vector=vector, llvm_dialect=llvm_dialect, + T=T, range_constexpr=range_constexpr, + ) + _lds_load_b128 = _ldrs["_lds_load_b128"] + load_data_frag = _ldrs["load_data_frag"] + load_b_frag = _ldrs["load_b_frag"] + load_scale_i32 = _ldrs["load_scale_i32"] + _precompute_a_data_bases = _ldrs["_precompute_a_data_bases"] + _precompute_b_data_bases = _ldrs["_precompute_b_data_bases"] + _precompute_a_scale_lane_bases = _ldrs["_precompute_a_scale_lane_bases"] + _precompute_b_scale_lane_bases = _ldrs["_precompute_b_scale_lane_bases"] + load_scale_b128 = _ldrs["load_scale_b128"] + + acc_zero = arith.constant_vector(0.0, T.vec(ACC_VEC_SIZE, T.f32)) + acc = [acc_zero] * n_accs + + _if_blk = scf.IfOp(block_ok) + with ir.InsertionPoint(_if_blk.then_block): + if _use_tdm_gather_a: + _precompute_a_row_indices() + a_data_bases = _precompute_a_data_bases() + b_data_bases = _precompute_b_data_bases() + as_bases = _precompute_a_scale_lane_bases() + bs_bases = _precompute_b_scale_lane_bases() + _use_scheduled_compute = _use_pipeline and not is_fp4 + _front_wm = (wmma_m_rep + 1) // 2 + _back_wm = wmma_m_rep - _front_wm + _front_wmma = _front_wm * wmma_n_rep + _back_wmma = _back_wm * wmma_n_rep + _b_frag_ds_loads_per_wn = 2 if is_a8w4 else 4 + _a_scale_ds_loads = wmma_m_rep if is_fp4 else (wmma_m_rep + 3) // 4 + _b_scale_ds_loads = b_scale_load_rep if is_fp4 else wmma_n_rep + _bs_ds_loads = ( + wmma_n_rep * _b_frag_ds_loads_per_wn + + _b_scale_ds_loads + + _a_scale_ds_loads + ) + + # ── compute-tile helper ────────────────────────────────── + def emit_wmma(accs, wm, wn, a_frag, b_frags, a_scales, b_scales): + _mxscale_emit_wmma( + accs=accs, wm=wm, wn=wn, + a_frag=a_frag, b_frags=b_frags, + a_scales=a_scales, b_scales=b_scales, + is_fp4=is_fp4, is_a8w4=is_a8w4, + use_scale_opsel=False, 
+ rocdl=rocdl, T=T, + ) + + def _compute_k_tile(accs_in, buf_idx, mid_compute_callback=None): + _mid_emit_ks = 0 + if k_wmma_steps > 1: + _mid_emit_wm = wmma_m_rep - 1 + _mid_emit_wn = wmma_n_rep - 1 + else: + _front_wm = (wmma_m_rep + 1) // 2 + _front_wn = (wmma_n_rep + 1) // 2 + if wmma_m_rep > 1: + _mid_emit_wm = _front_wm - 1 + _mid_emit_wn = wmma_n_rep - 1 + else: + _mid_emit_wm = 0 + _mid_emit_wn = _front_wn - 1 + _did_mid = False + for ks in range_constexpr(k_wmma_steps): + b_v = [load_b_frag(lds_b_bufs[buf_idx], b_data_bases, wn, ks) + for wn in range_constexpr(wmma_n_rep)] + if is_fp4: + as_v = [load_scale_i32(lds_as_bufs[buf_idx], as_bases[wm], ks) + for wm in range_constexpr(wmma_m_rep)] + bs_v = [load_scale_i32(lds_bs_bufs[buf_idx], bs_bases[bi], ks) + for bi in range_constexpr(b_scale_load_rep)] + else: + as_v = load_scale_b128(lds_as_bufs[buf_idx], as_bases[0], + wmma_m_rep, ks) + bs_v = [load_scale_i32(lds_bs_bufs[buf_idx], bs_bases[wn], ks) + for wn in range_constexpr(wmma_n_rep)] + for wm in range_constexpr(wmma_m_rep): + a_frag = load_data_frag(lds_a_bufs[buf_idx], + a_data_bases[wm], ks) + for wn in range_constexpr(wmma_n_rep): + emit_wmma(accs_in, wm, wn, a_frag, b_v, as_v, bs_v) + if ( + not _did_mid + and mid_compute_callback is not None + and ks == _mid_emit_ks + and wm == _mid_emit_wm + and wn == _mid_emit_wn + ): + mid_compute_callback() + _did_mid = True + return accs_in + + def _load_b_and_scales(buf_idx, ks): + b_v = [load_b_frag(lds_b_bufs[buf_idx], b_data_bases, wn, ks) + for wn in range_constexpr(wmma_n_rep)] + if is_fp4: + as_v = [load_scale_i32(lds_as_bufs[buf_idx], as_bases[wm], ks) + for wm in range_constexpr(wmma_m_rep)] + bs_v = [load_scale_i32(lds_bs_bufs[buf_idx], bs_bases[bi], ks) + for bi in range_constexpr(b_scale_load_rep)] + else: + as_v = load_scale_b128(lds_as_bufs[buf_idx], as_bases[0], + wmma_m_rep, ks) + bs_v = [load_scale_i32(lds_bs_bufs[buf_idx], bs_bases[wn], ks) + for wn in range_constexpr(wmma_n_rep)] + return b_v, bs_v, as_v + + def _emit_rows(accs_in, start_wm, a_frags, b_frags, a_scales, b_scales): + for frag_i in range_constexpr(len(a_frags)): + wm = start_wm + frag_i + for wn_raw in range_constexpr(wmma_n_rep): + wn = (wmma_n_rep - 1 - wn_raw) if (wm % 2 == 1) else wn_raw + emit_wmma(accs_in, wm, wn, a_frags[frag_i], b_frags, a_scales, b_scales) + + def _a_streaming_compute( + accs_in, + buf_idx, + b_frags, + b_scales, + a_scales, + ks, + next_bs_info=None, + mid_compute_callback=None, + ): + current_accs = accs_in + next_result = None + a_frags_front = [ + load_data_frag(lds_a_bufs[buf_idx], a_data_bases[wm], ks) + for wm in range_constexpr(_front_wm) + ] + _use_partial_drain = ( + next_bs_info is not None + and _front_wm * wmma_n_rep >= 4 + ) + + if _use_partial_drain: + _next_buf_idx, _next_ks = next_bs_info + next_result = _load_b_and_scales(_next_buf_idx, _next_ks) + rocdl.s_wait_dscnt(_bs_ds_loads) + else: + rocdl.s_wait_dscnt(0) + + _emit_rows(current_accs, 0, a_frags_front, b_frags, a_scales, b_scales) + + if mid_compute_callback is not None: + rocdl.sched_barrier(0) + mid_compute_callback() + + if _back_wm > 0: + a_frags_back = [ + load_data_frag( + lds_a_bufs[buf_idx], + a_data_bases[_front_wm + h], + ks, + ) + for h in range_constexpr(_back_wm) + ] + _back_drain = _bs_ds_loads if _use_partial_drain else 0 + rocdl.s_wait_dscnt(_back_drain) + _emit_rows( + current_accs, + _front_wm, + a_frags_back, + b_frags, + a_scales, + b_scales, + ) + + if _use_partial_drain: + return current_accs, next_result + if next_bs_info is not 
None: + _next_buf_idx, _next_ks = next_bs_info + next_result = _load_b_and_scales(_next_buf_idx, _next_ks) + return current_accs, next_result + return current_accs + + def _compute_k_tile_scheduled(accs_in, buf_idx, mid_compute_callback=None): + current_accs = list(accs_in) + if k_wmma_steps == 1: + b_v, bs_v, as_v = _load_b_and_scales(buf_idx, 0) + current_accs = _a_streaming_compute( + current_accs, + buf_idx, + b_v, + bs_v, + as_v, + 0, + mid_compute_callback=mid_compute_callback, + ) + else: + prev_b, prev_bs, prev_as = _load_b_and_scales(buf_idx, 0) + for ks in range_constexpr(k_wmma_steps - 1): + _mid_cb = mid_compute_callback if ks == 0 else None + current_accs, (prev_b, prev_bs, prev_as) = _a_streaming_compute( + current_accs, + buf_idx, + prev_b, + prev_bs, + prev_as, + ks, + next_bs_info=(buf_idx, ks + 1), + mid_compute_callback=_mid_cb, + ) + current_accs = _a_streaming_compute( + current_accs, + buf_idx, + prev_b, + prev_bs, + prev_as, + k_wmma_steps - 1, + ) + return current_accs + + def _hot_loop_scheduler_scheduled(): + if not _use_scheduled_compute: + return + _front_a_loads = _front_wm * DS_LOADS_PER_A_FRAG + _back_a_loads = _back_wm * DS_LOADS_PER_A_FRAG + for _ks in range_constexpr(k_wmma_steps): + if _ks == 0: + rocdl.sched_dsrd(_bs_ds_loads + _front_a_loads) + else: + rocdl.sched_dsrd(_front_a_loads) + rocdl.sched_mfma(_front_wmma) + if _back_wmma > 0: + rocdl.sched_dsrd(_back_a_loads) + rocdl.sched_mfma(_back_wmma) + if _ks < k_wmma_steps - 1: + rocdl.sched_dsrd(_bs_ds_loads) + rocdl.sched_barrier(0) + + if wave_specialized_tdm: + _tdm_wave_id = rocdl.wave_id() + _is_loader_wave = arith.cmpi( + arith.CmpIPredicate.ult, + _tdm_wave_id, + arith.constant(_tdm_loader_waves, type=T.i32), + ) + _tdm_pred = arith.constant(1, type=T.i32) + + def _select_wave_tdm_value(b_value, bs_value): + _wave_is_b = arith.cmpi( + arith.CmpIPredicate.eq, + _tdm_wave_id, + arith.constant(0, type=T.i32), + ) + return arith.select(_wave_is_b, b_value, bs_value) + + def _tdm_desc_lds_addr(desc): + return vector.extract( + desc.dgroup0, + static_position=[1], + dynamic_position=[], + ) + + def _tdm_desc_addr_lo(desc): + return vector.extract( + desc.dgroup0, + static_position=[2], + dynamic_position=[], + ) + + def _tdm_desc_addr_hi(desc): + return vector.extract( + desc.dgroup0, + static_position=[3], + dynamic_position=[], + ) + + _eid = arith.index_cast(T.index, eid_i32) + _n_init = _eid * n_idx + blk_n + _zero_k_base = arith.index(0) + _data_adv_i32 = arith.constant( + packed_tile_k_b if is_fp4 else packed_tile_k_b * 16, + type=T.i32, + ) + _scale_adv_i32 = arith.constant(scale_k_per_tile, type=T.i32) + + _stages_b_lds_addr = [ + _tdm_desc_lds_addr( + make_desc_b( + _n_init, + _zero_k_base, + lds_b_bufs[i], + ) + ) + for i in range_constexpr(_nb) + ] + _stages_bs_lds_addr = [ + _tdm_desc_lds_addr( + make_desc_bs( + _n_init, + _zero_k_base, + lds_bs_bufs[i], + ) + ) + for i in range_constexpr(_nb) + ] + + _desc_b_init = make_desc_b( + _n_init, + _zero_k_base, + lds_b_bufs[0], + ) + _desc_bs_init = make_desc_bs( + _n_init, + _zero_k_base, + lds_bs_bufs[0], + ) + + _active_stage_lds_addr = [ + _select_wave_tdm_value( + _stages_b_lds_addr[i], + _stages_bs_lds_addr[i], + ) + for i in range_constexpr(_nb) + ] + _active_addr_lo = _select_wave_tdm_value( + _tdm_desc_addr_lo(_desc_b_init), + _tdm_desc_addr_lo(_desc_bs_init), + ) + _active_addr_hi = _select_wave_tdm_value( + _tdm_desc_addr_hi(_desc_b_init), + _tdm_desc_addr_hi(_desc_bs_init), + ) + _active_dgroup1 = _select_wave_tdm_value( + 
_desc_b_init.dgroup1, + _desc_bs_init.dgroup1, + ) + _active_adv_i32 = _select_wave_tdm_value( + _data_adv_i32, + _scale_adv_i32, + ) + def _issue_active_b_tdm_only(stage_idx, curr_addr_lo): + _if_loader = scf.IfOp(_is_loader_wave) + with ir.InsertionPoint(_if_loader.then_block): + _dg0 = vector.from_elements(T.vec(4, T.i32), [ + _tdm_pred, + _active_stage_lds_addr[stage_idx], + curr_addr_lo, + _active_addr_hi, + ]) + tdm_ops.tensor_load_2d( + tdm_ops.TDMDescriptor2D(_dg0, _active_dgroup1) + ) + scf.YieldOp([]) + _next_addr_lo = arith.addi(curr_addr_lo, _active_adv_i32) + return arith.select( + _is_loader_wave, + _next_addr_lo, + curr_addr_lo, + ) + + # ── pipeline load helpers ───────────────────────────────── + def _issue_b_tdm_only(k_base, buf_idx): + _eid = arith.index_cast(T.index, eid_i32) + _n = _eid * n_idx + blk_n + tdm_ops.tensor_load_2d( + make_desc_b(_n, k_base, lds_b_bufs[buf_idx])) + tdm_ops.tensor_load_2d( + make_desc_bs(_n, k_base, lds_bs_bufs[buf_idx])) + + def _issue_scalar_loads(k_base, buf_idx): + if _use_tdm_gather_a: + issue_a_load_tdm_gather(k_base, lds_a_bufs[buf_idx]) + else: + issue_a_load(make_desc_a(k_base), lds_a_bufs[buf_idx]) + issue_as_load(make_desc_as(k_base), lds_as_bufs[buf_idx]) + + def _issue_all_loads(k_base, buf_idx): + _issue_b_tdm_only(k_base, buf_idx) + _issue_scalar_loads(k_base, buf_idx) + + def _compute_with_mid_loads(accs_in, buf_idx, mid_load_callback=None): + if _use_scheduled_compute: + return _compute_k_tile_scheduled( + accs_in, buf_idx, + mid_compute_callback=mid_load_callback, + ) + return _compute_k_tile( + accs_in, buf_idx, + mid_compute_callback=mid_load_callback, + ) + + # ── main K-dimension reduction ──────────────────────────── + if not _use_pipeline: + # Single-buffer path (num_buffers=1) + if wave_specialized_tdm: + active_b_addr_lo = _active_addr_lo + for kt in range_constexpr(num_k_tiles): + k_base = fx.Index(kt * int(tile_k)) + active_b_addr_lo = _issue_active_b_tdm_only( + 0, active_b_addr_lo) + _issue_scalar_loads(k_base, 0) + tdm_ops.tensor_wait(0) + workgroup_barrier(use_cluster=use_cluster) + acc = _compute_k_tile(acc, 0) + workgroup_barrier(use_cluster=use_cluster) + else: + for kt in range_constexpr(num_k_tiles): + k_base = fx.Index(kt * int(tile_k)) + _issue_all_loads(k_base, 0) + tdm_ops.tensor_wait(0) + workgroup_barrier(use_cluster=use_cluster) + acc = _compute_k_tile(acc, 0) + workgroup_barrier(use_cluster=use_cluster) + else: + # Multi-buffer pipeline + # ── prologue: pre-load first `pre_loaded` stages ── + if wave_specialized_tdm: + active_b_addr_lo = _active_addr_lo + for _pi in range_constexpr(pre_loaded): + active_b_addr_lo = _issue_active_b_tdm_only( + _pi, active_b_addr_lo) + _issue_scalar_loads(fx.Index(_pi * int(tile_k)), _pi) + else: + for _pi in range_constexpr(pre_loaded): + _issue_all_loads(fx.Index(_pi * int(tile_k)), _pi) + pipeline_fence(outstanding=0, use_cluster=use_cluster) + + # ── main pipelined loop ── + if loop_iters > 0: + if wave_specialized_tdm: + _init = list(acc) + [active_b_addr_lo] + for _li, _st in fx.range(0, loop_iters, 1, init=_init): + _acc = list(_st[:n_accs]) + _cur_b_addr_lo = _st[n_accs] + for _bi in range_constexpr(_nb): + _lb = (_bi + _nb - 1) % _nb + _kt = (_li * fx.Index(_nb) + + fx.Index(pre_loaded + _bi)) + _kb = _kt * fx.Index(int(tile_k)) + pipeline_fence_signal( + outstanding=_fence_outstanding, + use_cluster=use_cluster) + pipeline_fence_wait(use_cluster=use_cluster) + + _cur_b_addr_lo = _issue_active_b_tdm_only( + _lb, _cur_b_addr_lo) + + def 
_mid_issue_scalar(_mid_kb=_kb, _mid_lb=_lb): + _issue_scalar_loads(_mid_kb, _mid_lb) + + if _use_scheduled_compute: + rocdl.sched_barrier(0) + _acc = _compute_with_mid_loads( + _acc, + _bi, + _mid_issue_scalar, + ) + if _use_scheduled_compute: + _hot_loop_scheduler_scheduled() + _res = yield list(_acc) + [_cur_b_addr_lo] + acc = list(_res[:n_accs]) + active_b_addr_lo = _res[n_accs] + else: + _init = list(acc) + for _li, _st in fx.range(0, loop_iters, 1, init=_init): + _acc = list(_st[:n_accs]) if isinstance(_st, (list, tuple)) else [_st] + for _bi in range_constexpr(_nb): + _lb = (_bi + _nb - 1) % _nb + _kt = (_li * fx.Index(_nb) + + fx.Index(pre_loaded + _bi)) + _kb = _kt * fx.Index(int(tile_k)) + pipeline_fence_signal( + outstanding=_fence_outstanding, + use_cluster=use_cluster) + pipeline_fence_wait(use_cluster=use_cluster) + + _issue_b_tdm_only(_kb, _lb) + + def _mid_issue_scalar(_mid_kb=_kb, _mid_lb=_lb): + _issue_scalar_loads(_mid_kb, _mid_lb) + + if _use_scheduled_compute: + rocdl.sched_barrier(0) + _acc = _compute_with_mid_loads( + _acc, + _bi, + _mid_issue_scalar, + ) + if _use_scheduled_compute: + _hot_loop_scheduler_scheduled() + _res = yield list(_acc) + acc = list(_res[:n_accs]) if isinstance(_res, (list, tuple)) else [_res] + + # ── post-loop fence ── + if loop_iters > 0: + pipeline_fence(outstanding=0, use_cluster=use_cluster) + elif use_cluster: + gpu.cluster_barrier() + + # ── tail ── + _tail_li = 0 + _tail_had_load = False + for _ls, _cs, _out in _tail_plan: + if _out == -1: + if _tail_had_load: + pipeline_fence(outstanding=0, + use_cluster=use_cluster) + if _use_scheduled_compute: + rocdl.sched_barrier(0) + acc = _compute_k_tile_scheduled(acc, _cs) + _hot_loop_scheduler_scheduled() + else: + acc = _compute_k_tile(acc, _cs) + else: + pipeline_fence_signal(outstanding=_out, + use_cluster=use_cluster) + pipeline_fence_wait(use_cluster=use_cluster) + if _ls is not None: + _tail_had_load = True + _tkb = fx.Index( + (_tail_start + pre_loaded + _tail_li) + * int(tile_k)) + _tail_li += 1 + + if wave_specialized_tdm: + active_b_addr_lo = _issue_active_b_tdm_only( + _ls, active_b_addr_lo) + else: + _issue_b_tdm_only(_tkb, _ls) + + def _tail_mid_issue_scalar(_mid_kb=_tkb, _mid_ls=_ls): + _issue_scalar_loads(_mid_kb, _mid_ls) + + if _use_scheduled_compute: + rocdl.sched_barrier(0) + acc = _compute_with_mid_loads( + acc, + _cs, + _tail_mid_issue_scalar, + ) + if _use_scheduled_compute: + _hot_loop_scheduler_scheduled() + else: + if _use_scheduled_compute: + rocdl.sched_barrier(0) + acc = _compute_k_tile_scheduled(acc, _cs) + _hot_loop_scheduler_scheduled() + else: + acc = _compute_k_tile(acc, _cs) + + out_elem_ty = _moe_out_elem_ty(out_dtype, T) + + if bool(use_tdm_store): + # ── TDM store epilogue: acc → LDS → global (contiguous sorted output) ── + # Pre-compute per-wm row scale (weight × validity mask) + _scale_per_wm = [] + for _wm in range_constexpr(wmma_m_rep): + _m_off_val = _wm * WMMA_M + _row_local = warp_m_base + arith.index(_m_off_val) + lane16 + _sorted_row = by * arith.index(int(tile_m)) + _row_local + _sorted_i32 = arith.index_cast(T.i32, _sorted_row) + _row_in_route = arith.cmpi( + arith.CmpIPredicate.ult, + arith.index_cast(T.i32, _row_local), + arith.constant(int(route_tile_m), type=T.i32)) + _row_in_valid = arith.cmpi( + arith.CmpIPredicate.slt, _sorted_i32, num_valid_i32) + _row_ok = arith.andi(_row_in_route, _row_in_valid) + if bool(doweight_stage2): + _sorted_safe = arith.select( + _row_ok, _sorted_i32, block_row_start) + _tw = buffer_ops.buffer_load( + tw_rsrc, 
_sorted_safe, vec_width=1, dtype=T.f32) + _sc = arith.select( + _row_ok, _tw, + arith.constant(0.0, type=T.f32)) + else: + _sc = arith.select( + _row_ok, + arith.constant(1.0, type=T.f32), + arith.constant(0.0, type=T.f32)) + _scale_per_wm.append(_sc) + + if d_need_epilogue_fence: + pipeline_fence(outstanding=0, use_cluster=use_cluster) + rocdl.sched_barrier(0) + + for _acc_idx, _vec_base, _m_off, _wn in _sub_tiles: + _wm_idx = _m_off // WMMA_M + _sc = _scale_per_wm[_wm_idx] + _sub8 = _extract_sub8( + acc[_acc_idx], _vec_base, + vector=vector, + range_constexpr=range_constexpr, + ACC_VEC_SIZE=ACC_VEC_SIZE) + _scaled = [] + for _vi in range_constexpr(8): + _v = vector.extract( + _sub8, + static_position=[_vi], + dynamic_position=[]) + _scaled.append(_v * _sc) + _scaled_sub8 = vector.from_elements( + T.vec(8, T.f32), _scaled) + _imm = _m_off * _lds_d_stride_elems + _wn * _n_col_d_elems + store_acc_vec8_to_lds( + d_lds_buffer, d_lane_base, _imm, _scaled_sub8, + out_elem=out_elem_ty) + + rocdl.s_wait_dscnt(0) + tdm_ops.tensor_store_2d(d_desc) + tdm_ops.tensor_wait(0) + else: + def _load_sub8(acc_idx, vec_base): + return _extract_sub8( + acc[acc_idx], vec_base, vector=vector, range_constexpr=range_constexpr, ACC_VEC_SIZE=ACC_VEC_SIZE + ) + + _emit_stage2_store_epilogue( + sub_tiles=_sub_tiles, + by=by, + tile_m=int(tile_m), + route_tile_m=int(route_tile_m), + warp_m_base=warp_m_base, + warp_n_base=warp_n_base, + blk_n=blk_n, + lane16=lane16, + lane_kgrp=lane_kgrp, + WMMA_N=WMMA_N, + i32_tokens_in=i32_tokens_in, + i32_n_in=i32_n_in, + topk=int(topk), + num_valid_i32=num_valid_i32, + block_row_start=block_row_start, + sorted_rsrc=sorted_rsrc, + tw_rsrc=tw_rsrc, + out_rsrc=out_rsrc, + doweight_stage2=bool(doweight_stage2), + accumulate=bool(accumulate), + out_elem_ty=out_elem_ty, + load_sub8=_load_sub8, + ir=ir, + fx=fx, + arith=arith, + buffer_ops=buffer_ops, + scf=scf, + vector=vector, + range_constexpr=range_constexpr, + rocdl=rocdl, + T=T, + ) + scf.YieldOp([]) + + @flyc.jit + def launch_mxscale_stage2_single( + arg_out: fx.Tensor, + arg_x: fx.Tensor, + arg_w: fx.Tensor, + arg_scale_x: fx.Tensor, + arg_scale_w: fx.Tensor, + arg_sorted_token_ids: fx.Tensor, + arg_expert_ids: fx.Tensor, + arg_sorted_weights: fx.Tensor, + arg_num_valid_ids: fx.Tensor, + i32_tokens_in: fx.Int32, + i32_n_in: fx.Int32, + i32_k_in: fx.Int32, + i32_size_expert_ids_in: fx.Int32, + stream: fx.Stream, + ): + _ = i32_k_in + ctx = CompilationContext.get_current() + n_in = arith.index_cast(T.index, i32_n_in) + size_expert_ids_in = arith.index_cast(T.index, i32_size_expert_ids_in) + gx = (n_in + fx.Index(int(tile_n) - 1)) // fx.Index(int(tile_n)) + gy = size_expert_ids_in + launcher = moe_mxscale_stage2_single( + arg_out, arg_x, arg_w, arg_scale_x, arg_scale_w, + arg_sorted_token_ids, arg_expert_ids, arg_sorted_weights, arg_num_valid_ids, + i32_tokens_in, i32_n_in, i32_k_in, i32_size_expert_ids_in, + ) + _cluster_arg = (int(cluster_m), int(cluster_n), 1) if use_cluster else None + _finalize_alloc_and_launch_2d( + ctx=ctx, + alloc=alloc, + launcher=launcher, + gx=gx, + gy=gy, + block_threads=block_threads, + stream=stream, + waves_per_eu=effective_waves_per_eu, + ir=ir, + cluster=_cluster_arg, + ) + + if expert_sched_mode: + launch_mxscale_stage2_single.compile_hints["llvm_options"] = { + "amdgpu-expert-scheduling-mode": True, + } + + return launch_mxscale_stage2_single + + +# --------------------------------------------------------------------------- +# Public API entry points for fp4/fp8/a8w4 +# 
--------------------------------------------------------------------------- + +@functools.lru_cache(maxsize=1024) +def _compile_moe_mxscale_gemm( + *, + stage: int, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n: int, + tile_k: int, + doweight: bool, + in_dtype: str = "fp4", + out_dtype: str = "f16", + accumulate: bool = True, + waves_per_eu: int | None = None, + expert_sched_mode: bool = True, + num_buffers: int = 1, + use_tdm_gather: bool = True, + use_tdm_store: bool = False, + inst_prefetch: bool = False, + wave_specialized_tdm: bool = False, + cluster_m: int = 1, + cluster_n: int = 1, +): + _require_gfx1250() + if waves_per_eu is not None and int(waves_per_eu) < 1: + raise ValueError(f"waves_per_eu must be >= 1, got {waves_per_eu!r}") + if in_dtype not in ("fp4", "fp8", "a8w4"): + raise ValueError( + f"Unsupported in_dtype for MXScale stage{stage}: {in_dtype!r}, " + "expected 'fp4', 'fp8', or 'a8w4'" + ) + + single_tile_m, single_tile_n, single_m_warp, single_n_warp = _pick_mxscale_launch_shape( + in_dtype, int(tile_m), int(tile_n), + ) + common = dict( + model_dim=int(model_dim), inter_dim=int(inter_dim), + experts=int(experts), topk=int(topk), + route_tile_m=int(tile_m), + tile_m=int(single_tile_m), tile_n=int(single_tile_n), tile_k=int(tile_k), + m_warp=int(single_m_warp), n_warp=int(single_n_warp), + out_dtype=out_dtype, waves_per_eu=waves_per_eu, data_format=in_dtype, + expert_sched_mode=expert_sched_mode, num_buffers=int(num_buffers), + use_tdm_gather=bool(use_tdm_gather), use_tdm_store=bool(use_tdm_store), + inst_prefetch=bool(inst_prefetch), wave_specialized_tdm=bool(wave_specialized_tdm), + cluster_m=int(cluster_m), cluster_n=int(cluster_n), + ) + + if stage == 1: + exe = _compile_stage1_mxscale_kernel_impl(doweight_stage1=bool(doweight), **common) + if in_dtype in ("fp8", "a8w4") and (int(inter_dim) % int(single_tile_n) == 0): + return _Stage1GateUpPackedWrapper( + exe, + experts=int(experts), inter_dim=int(inter_dim), + tile_n=int(single_tile_n), + packed_cols_w=(int(model_dim) // 2) if in_dtype == "a8w4" else int(model_dim), + packed_cols_scale=int(model_dim) // 32, + ) + return exe + + return _compile_stage2_mxscale_kernel_impl( + doweight_stage2=bool(doweight), accumulate=bool(accumulate), **common, + ) + + +def compile_moe_gemm1(*, doweight_stage1, group_size=-1, use_cshuffle_epilog=None, **kw): + return _compile_moe_mxscale_gemm(stage=1, doweight=doweight_stage1, **kw) + + +def compile_moe_gemm2(*, doweight_stage2, accumulate=True, group_size=-1, use_cshuffle_epilog=None, **kw): + return _compile_moe_mxscale_gemm(stage=2, doweight=doweight_stage2, accumulate=accumulate, **kw) + + +def compile_moe_gemm2_ex(*, mode=MoeGemm2Mode.ATOMIC, valid_mask=None, zero_intermediate=True, **kw): + if mode == MoeGemm2Mode.REDUCE: + gemm2_exe = compile_moe_gemm2(accumulate=False, **kw) + out_s = str(kw.get("out_dtype", "f16")).strip().lower() + if out_s in ("f16", "fp16", "half"): + dtype_str = "f16" + elif out_s in ("bf16", "bfloat16"): + dtype_str = "bf16" + else: + dtype_str = "f32" + reduce_exe = compile_moe_reduction( + topk=kw["topk"], model_dim=kw["model_dim"], + dtype_str=dtype_str, use_mask=(valid_mask is not None), + ) + from kernels.moe_gemm_2stage import _MoeGemm2ReduceWrapper + return _MoeGemm2ReduceWrapper( + gemm2_exe=gemm2_exe, reduce_exe=reduce_exe, + topk=kw["topk"], model_dim=kw["model_dim"], + out_dtype_str=dtype_str, + use_mask=(valid_mask is not None), + zero_intermediate=zero_intermediate, + ) + return 
compile_moe_gemm2(accumulate=True, **kw) diff --git a/kernels/moe_gemm_2stage_wmma_gfx1250.py b/kernels/moe_gemm_2stage_wmma_gfx1250.py new file mode 100644 index 00000000..0dd7d935 --- /dev/null +++ b/kernels/moe_gemm_2stage_wmma_gfx1250.py @@ -0,0 +1,912 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + + +"""gfx1250 MoE 2-stage fp16 WMMA kernels. + +Implements stage1/stage2 single-kernel inline paths using the +``wmma_f32_16x16x32_f16`` instruction for fp16 (and bf16 via host +conversion) inputs. +""" + +from __future__ import annotations + +import functools + +from flydsl.runtime.device import get_rocm_arch as get_hip_arch + +from kernels.moe_gemm_2stage import ( + MoeGemm2Mode, + compile_moe_reduction, +) +from kernels.moe_gemm_2stage_common_gfx1250 import ( + _bf16_to_f16_wrapper, + _emit_stage1_gate_up_epilogue, + _emit_stage2_store_epilogue, + _finalize_alloc_and_launch_2d, + _make_moe_wave_layout, + _make_wmma_sub_tiles, + _moe_out_elem_ty, + _pick_fp16_single_launch_shape, + _require_gfx1250, +) + +@functools.lru_cache(maxsize=64) +def _compile_stage1_wmma_kernel_impl( + *, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + route_tile_m: int, + tile_m: int, + tile_n: int, + tile_k: int, + m_warp: int, + n_warp: int, + doweight_stage1: bool, + out_dtype: str, + waves_per_eu: int | None, + expert_sched_mode: bool = True, +): + """Compile dense stage1 single kernel: route-pack + TDM + WMMA + epilog.""" + import flydsl.compiler as flyc + import flydsl.expr as fx + from flydsl._mlir import ir + from flydsl._mlir.dialects import llvm as llvm_dialect + from flydsl._mlir.dialects import scf + from flydsl.compiler.kernel_function import CompilationContext + from flydsl.expr import arith, buffer_ops, gpu, idx2crd, range_constexpr, rocdl, tdm_ops, vector + from flydsl.expr.typing import T + from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr, get_op_result_or_value + + WMMA_M, WMMA_N, WMMA_K = 16, 16, 32 + WAVE_SIZE = 32 + LDS_PAD_A = 8 + LDS_PAD_B = 8 + elem_bytes = 2 + + if out_dtype not in ("f16", "bf16"): + raise ValueError(f"fp16 stage1 single kernel supports out_dtype in ('f16','bf16'), got {out_dtype!r}") + if (int(model_dim) % int(tile_k)) != 0: + raise ValueError(f"model_dim={model_dim} must be divisible by tile_k={tile_k}") + if (int(tile_k) % WMMA_K) != 0: + raise ValueError(f"tile_k={tile_k} must be divisible by {WMMA_K}") + if (int(tile_m) % WMMA_M) != 0 or (int(tile_n) % WMMA_N) != 0: + raise ValueError(f"tile_m/tile_n must be multiples of 16, got ({tile_m},{tile_n})") + + block_threads = int(m_warp) * int(n_warp) * WAVE_SIZE + warp_tile_m = int(tile_m) // int(m_warp) + warp_tile_n = int(tile_n) // int(n_warp) + wmma_m_rep = warp_tile_m // WMMA_M + wmma_n_rep = warp_tile_n // WMMA_N + if wmma_m_rep <= 0 or wmma_n_rep <= 0: + raise ValueError(f"Invalid warp tiling for fp16 single kernel: wmma_m_rep={wmma_m_rep}, wmma_n_rep={wmma_n_rep}") + + n_accs = wmma_m_rep * wmma_n_rep + num_k_tiles = int(model_dim) // int(tile_k) + k_wmma_steps = int(tile_k) // WMMA_K + n_total = int(2 * inter_dim) + _sub_tiles = _make_wmma_sub_tiles( + wmma_m_rep=wmma_m_rep, wmma_n_rep=wmma_n_rep, WMMA_M=WMMA_M, is_fp4=False + ) + + lds_a_stride = int(tile_k) + LDS_PAD_A + lds_b_stride = int(tile_n) + LDS_PAD_B + lds_a_elems = int(tile_m) * lds_a_stride + LDS_PAD_A + lds_b_elems = int(tile_k) * lds_b_stride + LDS_PAD_B + + alloc = SmemAllocator(None, arch=str(get_hip_arch()), global_sym_name="moe_fp16_s1_single") + off_bg = 
alloc._align(alloc.ptr, 16) + alloc.ptr = off_bg + lds_b_elems * elem_bytes + off_bu = alloc._align(alloc.ptr, 16) + alloc.ptr = off_bu + lds_b_elems * elem_bytes + off_a = alloc._align(alloc.ptr, 16) + alloc.ptr = off_a + lds_a_elems * elem_bytes + + @flyc.kernel(known_block_size=[block_threads, 1, 1]) + def moe_fp16_stage1_single( + arg_out: fx.Tensor, + arg_x: fx.Tensor, + arg_w: fx.Tensor, + arg_scale_x: fx.Tensor, + arg_scale_w: fx.Tensor, + arg_sorted_token_ids: fx.Tensor, + arg_expert_ids: fx.Tensor, + arg_sorted_weights: fx.Tensor, + arg_max_token_ids: fx.Tensor, + i32_tokens_in: fx.Int32, + i32_inter_in: fx.Int32, + i32_k_in: fx.Int32, + i32_size_expert_ids_in: fx.Int32, + ): + _ = (arg_scale_x, arg_scale_w, arg_max_token_ids, i32_k_in) + llvm_dialect.inline_asm( + None, [], # void result, no operands + "s_setreg_imm32_b32 hwreg(26, 4, 1), 1", + "", + has_side_effects=True, + ) + + tx = gpu.thread_id("x") + bx = gpu.block_id("x") # inter tile + by = gpu.block_id("y") # expert block + + tokens_idx = arith.index_cast(T.index, i32_tokens_in) + inter_idx = arith.index_cast(T.index, i32_inter_in) + size_expert_ids = arith.index_cast(T.index, i32_size_expert_ids_in) + + sorted_num = size_expert_ids * arith.index(int(route_tile_m)) + sorted_nbytes = sorted_num * arith.index(4) + eid_nbytes = size_expert_ids * arith.index(4) + x_nbytes = tokens_idx * arith.index(int(model_dim)) * arith.index(2) + w_nbytes = arith.index(int(experts * n_total * int(model_dim) * 2)) + + sorted_rsrc = buffer_ops.create_buffer_resource(arg_sorted_token_ids, max_size=False, num_records_bytes=sorted_nbytes) + eid_rsrc = buffer_ops.create_buffer_resource(arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes) + x_rsrc = buffer_ops.create_buffer_resource(arg_x, max_size=False, num_records_bytes=x_nbytes) + w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False, num_records_bytes=w_nbytes) + out_rsrc = buffer_ops.create_buffer_resource(arg_out, max_size=True) + sw_rsrc = buffer_ops.create_buffer_resource(arg_sorted_weights, max_size=True) + + eid_i32 = buffer_ops.buffer_load(eid_rsrc, arith.index_cast(T.i32, by), vec_width=1, dtype=T.i32) + eid_ok0 = arith.cmpi(arith.CmpIPredicate.sge, eid_i32, arith.constant(0, type=T.i32)) + eid_ok1 = arith.cmpi(arith.CmpIPredicate.slt, eid_i32, arith.constant(int(experts), type=T.i32)) + eid_ok = arith.andi(eid_ok0, eid_ok1) + + layout_thr = _make_moe_wave_layout(m_warp=m_warp, n_warp=n_warp, WAVE_SIZE=WAVE_SIZE, fx=fx) + thr_coord = idx2crd(tx, layout_thr) + wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( + fx.get(thr_coord, 0), fx.get(thr_coord, 1), fx.get(thr_coord, 2), fx.get(thr_coord, 3) + ) + warp_m_base = wave_m_idx * arith.index(warp_tile_m) + warp_n_base = wave_n_idx * arith.index(warp_tile_n) + blk_n = bx * arith.index(int(tile_n)) + + base_ptr = alloc.get_base() + smem_bg = SmemPtr(base_ptr, off_bg, T.f16, shape=(lds_b_elems,)) + smem_bu = SmemPtr(base_ptr, off_bu, T.f16, shape=(lds_b_elems,)) + smem_a = SmemPtr(base_ptr, off_a, T.f16, shape=(lds_a_elems,)) + lds_bg = get_op_result_or_value(smem_bg.get()) + lds_bu = get_op_result_or_value(smem_bu.get()) + lds_a = get_op_result_or_value(smem_a.get()) + + def silu(x): + t = x * (-1.4426950408889634) + emu = rocdl.exp2(T.f32, t) + den = 1.0 + emu + sig = rocdl.rcp(T.f32, den) + return x * sig + + def pack_a_to_lds(k_base): + total = int(tile_m * tile_k) + rounds = (total + block_threads - 1) // block_threads + for it in range(rounds): + elem = tx + fx.Index(it * block_threads) + in_range = arith.cmpi( + 
arith.CmpIPredicate.ult, + arith.index_cast(T.i32, elem), + arith.constant(total, type=T.i32), + ) + _if_elem = scf.IfOp(in_range) + with ir.InsertionPoint(_if_elem.then_block): + row = elem // arith.index(int(tile_k)) + col = elem % arith.index(int(tile_k)) + sorted_row = by * arith.index(int(tile_m)) + row + row_in_route = arith.cmpi( + arith.CmpIPredicate.ult, + arith.index_cast(T.i32, row), + arith.constant(int(route_tile_m), type=T.i32), + ) + sorted_row_safe = arith.select( + row_in_route, + arith.index_cast(T.i32, sorted_row), + arith.index_cast(T.i32, by * arith.index(int(route_tile_m))), + ) + fused = buffer_ops.buffer_load(sorted_rsrc, sorted_row_safe, vec_width=1, dtype=T.i32) + tok = fused & arith.constant((1 << 24) - 1, type=T.i32) + tok_ok0 = arith.cmpi(arith.CmpIPredicate.ult, tok, i32_tokens_in) + tok_ok = arith.andi(row_in_route, tok_ok0) + x_idx = tok * arith.constant(int(model_dim), type=T.i32) + arith.index_cast(T.i32, k_base + col) + x_idx_safe = arith.select(tok_ok, x_idx, arith.constant(0, type=T.i32)) + x_val = arith.select( + tok_ok, + buffer_ops.buffer_load(x_rsrc, x_idx_safe, vec_width=1, dtype=T.f16), + arith.constant(0.0, type=T.f16), + ) + lds_idx = row * arith.index(lds_a_stride) + col + v1 = vector.from_elements(T.vec(1, T.f16), [x_val]) + vector.store(v1, lds_a, [lds_idx], alignment=2) + scf.YieldOp([]) + + def copy_b_to_lds(k_base, lds_memref, up_shift): + eid_idx = arith.index_cast(T.index, eid_i32) + n_base = eid_idx * arith.index(n_total) + blk_n + arith.index(up_shift) + total = int(tile_k) * int(tile_n) + rounds = (total + block_threads - 1) // block_threads + for it in range(rounds): + elem = tx + fx.Index(it * block_threads) + in_range = arith.cmpi( + arith.CmpIPredicate.ult, + arith.index_cast(T.i32, elem), + arith.constant(total, type=T.i32), + ) + _if_elem = scf.IfOp(in_range) + with ir.InsertionPoint(_if_elem.then_block): + k_local = elem // arith.index(int(tile_n)) + n_local = elem % arith.index(int(tile_n)) + w_idx = (n_base + n_local) * arith.index(int(model_dim)) + k_base + k_local + w_val = buffer_ops.buffer_load( + w_rsrc, arith.index_cast(T.i32, w_idx), + vec_width=1, dtype=T.f16, + ) + lds_idx = k_local * arith.index(lds_b_stride) + n_local + v1 = vector.from_elements(T.vec(1, T.f16), [w_val]) + vector.store(v1, lds_memref, [lds_idx], alignment=2) + scf.YieldOp([]) + + def _precompute_a_lane_bases(): + row_stride_off = (warp_m_base + lane16) * arith.index(lds_a_stride) + k_lane_off = lane_kgrp * arith.index(8) + bases = [] + for wm in range_constexpr(wmma_m_rep): + a_base = row_stride_off + arith.index(wm * WMMA_M * lds_a_stride) + k_lane_off + bases.append(a_base) + return bases + + def _precompute_b_lane_bases(): + lane8 = lane16 % arith.index(8) + lane_ngrp = lane16 / arith.index(8) + k_lane_off = (lane_kgrp * arith.index(8) + lane8) * arith.index(lds_b_stride) + n_lane_off = lane_ngrp * arith.index(8) + bases = [] + for wn in range_constexpr(wmma_n_rep): + n_col = warp_n_base + arith.index(wn * WMMA_N) + n_lane_off + bases.append(k_lane_off + n_col) + return bases + + def load_a_frag(a_base, ks): + vec8_ty = ir.VectorType.get([8], T.f16) + off0 = a_base + arith.index(ks * WMMA_K) + off1 = a_base + arith.index(ks * WMMA_K + 16) + v0 = vector.load_op(vec8_ty, lds_a, [off0]) + v1 = vector.load_op(vec8_ty, lds_a, [off1]) + return vector.shuffle(v0, v1, list(range(16))) + + def load_b_frag(lds_buf, b_base, ks): + vec8_ty = ir.VectorType.get([8], T.f16) + results = [] + for k_half in range_constexpr(2): + k_row_off = (ks * WMMA_K + k_half 
* 16) * lds_b_stride + elem_off = b_base + arith.index(k_row_off) + v = rocdl.lds_transpose_load(vec8_ty, lds_buf, elem_off, elem_bytes) + results.append(v) + return vector.shuffle(results[0], results[1], list(range(16))) + + acc_zero = arith.constant_vector(0.0, T.vec(8, T.f32)) + acc_gate = [acc_zero] * n_accs + acc_up = [acc_zero] * n_accs + + _if_eid = scf.IfOp(eid_ok) + with ir.InsertionPoint(_if_eid.then_block): + a_bases = _precompute_a_lane_bases() + b_bases = _precompute_b_lane_bases() + for kt in range_constexpr(num_k_tiles): + k_base = fx.Index(kt * int(tile_k)) + pack_a_to_lds(k_base) + copy_b_to_lds(k_base, lds_bg, 0) + copy_b_to_lds(k_base, lds_bu, int(inter_dim)) + gpu.barrier() + + for ks in range_constexpr(k_wmma_steps): + b_gate_frags = [load_b_frag(lds_bg, b_bases[wn], ks) for wn in range_constexpr(wmma_n_rep)] + b_up_frags = [load_b_frag(lds_bu, b_bases[wn], ks) for wn in range_constexpr(wmma_n_rep)] + for wm in range_constexpr(wmma_m_rep): + a_frag = load_a_frag(a_bases[wm], ks) + for wn in range_constexpr(wmma_n_rep): + idx = wm * wmma_n_rep + wn + acc_gate[idx] = rocdl.wmma_f32_16x16x32_f16( + T.vec(8, T.f32), + b_gate_frags[wn], + a_frag, + acc_gate[idx], + signA=False, + signB=False, + modC=0, + reuseA=False, + reuseB=False, + ).result + acc_up[idx] = rocdl.wmma_f32_16x16x32_f16( + T.vec(8, T.f32), + b_up_frags[wn], + a_frag, + acc_up[idx], + signA=False, + signB=False, + modC=0, + reuseA=False, + reuseB=False, + ).result + gpu.barrier() + + out_elem_ty = _moe_out_elem_ty(out_dtype, T) + + def _load_gate_up_sub8(acc_idx, _vec_base): + return acc_gate[acc_idx], acc_up[acc_idx] + + _emit_stage1_gate_up_epilogue( + sub_tiles=_sub_tiles, + by=by, + tile_m=int(tile_m), + route_tile_m=int(route_tile_m), + warp_m_base=warp_m_base, + warp_n_base=warp_n_base, + blk_n=blk_n, + lane16=lane16, + lane_kgrp=lane_kgrp, + WMMA_N=WMMA_N, + i32_tokens_in=i32_tokens_in, + i32_inter_in=i32_inter_in, + topk=int(topk), + sorted_rsrc=sorted_rsrc, + tw_rsrc=sw_rsrc, + out_rsrc=out_rsrc, + doweight_stage1=bool(doweight_stage1), + out_elem_ty=out_elem_ty, + load_gate_up_sub8=_load_gate_up_sub8, + silu_fn=silu, + ir=ir, + fx=fx, + arith=arith, + buffer_ops=buffer_ops, + scf=scf, + vector=vector, + range_constexpr=range_constexpr, + T=T, + ) + scf.YieldOp([]) + + @flyc.jit + def launch_fp16_stage1_single( + arg_out: fx.Tensor, + arg_x: fx.Tensor, + arg_w: fx.Tensor, + arg_scale_x: fx.Tensor, + arg_scale_w: fx.Tensor, + arg_sorted_token_ids: fx.Tensor, + arg_expert_ids: fx.Tensor, + arg_sorted_weights: fx.Tensor, + arg_max_token_ids: fx.Tensor, + i32_tokens_in: fx.Int32, + i32_inter_in: fx.Int32, + i32_k_in: fx.Int32, + i32_size_expert_ids_in: fx.Int32, + stream: fx.Stream, + ): + _ = i32_k_in + ctx = CompilationContext.get_current() + inter_in = arith.index_cast(T.index, i32_inter_in) + size_expert_ids_in = arith.index_cast(T.index, i32_size_expert_ids_in) + gx = (inter_in + fx.Index(int(tile_n) - 1)) // fx.Index(int(tile_n)) + gy = size_expert_ids_in + launcher = moe_fp16_stage1_single( + arg_out, + arg_x, + arg_w, + arg_scale_x, + arg_scale_w, + arg_sorted_token_ids, + arg_expert_ids, + arg_sorted_weights, + arg_max_token_ids, + i32_tokens_in, + i32_inter_in, + i32_k_in, + i32_size_expert_ids_in, + ) + _finalize_alloc_and_launch_2d( + ctx=ctx, + alloc=alloc, + launcher=launcher, + gx=gx, + gy=gy, + block_threads=block_threads, + stream=stream, + waves_per_eu=waves_per_eu, + ir=ir, + ) + + if expert_sched_mode: + launch_fp16_stage1_single.compile_hints["llvm_options"] = { + 
"amdgpu-expert-scheduling-mode": True, + } + + return launch_fp16_stage1_single + + +@functools.lru_cache(maxsize=64) +def _compile_stage2_wmma_kernel_impl( + *, + inter_dim: int, + experts: int, + topk: int, + route_tile_m: int, + tile_m: int, + tile_n: int, + tile_k: int, + m_warp: int, + n_warp: int, + doweight_stage2: bool, + out_dtype: str, + accumulate: bool, + waves_per_eu: int | None, + expert_sched_mode: bool = True, +): + """Compile fp16 stage2 single kernel: route-pack + TDM + WMMA + epilog.""" + import flydsl.compiler as flyc + import flydsl.expr as fx + from flydsl._mlir import ir + from flydsl._mlir.dialects import llvm as llvm_dialect + from flydsl._mlir.dialects import scf + from flydsl.compiler.kernel_function import CompilationContext + from flydsl.expr import arith, buffer_ops, gpu, idx2crd, range_constexpr, rocdl, tdm_ops, vector + from flydsl.expr.typing import T + from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr, get_op_result_or_value + + WMMA_M, WMMA_N, WMMA_K = 16, 16, 32 + WAVE_SIZE = 32 + LDS_PAD_A = 8 + LDS_PAD_B = 8 + elem_bytes = 2 + + if out_dtype not in ("f16", "bf16"): + raise ValueError(f"fp16 stage2 single kernel supports out_dtype in ('f16','bf16'), got {out_dtype!r}") + if (int(inter_dim) % int(tile_k)) != 0: + raise ValueError(f"inter_dim={inter_dim} must be divisible by tile_k={tile_k}") + if (int(tile_k) % WMMA_K) != 0: + raise ValueError(f"tile_k={tile_k} must be divisible by {WMMA_K}") + + block_threads = int(m_warp) * int(n_warp) * WAVE_SIZE + warp_tile_m = int(tile_m) // int(m_warp) + warp_tile_n = int(tile_n) // int(n_warp) + wmma_m_rep = warp_tile_m // WMMA_M + wmma_n_rep = warp_tile_n // WMMA_N + if wmma_m_rep <= 0 or wmma_n_rep <= 0: + raise ValueError(f"Invalid warp tiling for fp16 stage2 single kernel: wmma_m_rep={wmma_m_rep}, wmma_n_rep={wmma_n_rep}") + + n_accs = wmma_m_rep * wmma_n_rep + num_k_tiles = int(inter_dim) // int(tile_k) + k_wmma_steps = int(tile_k) // WMMA_K + _sub_tiles = _make_wmma_sub_tiles( + wmma_m_rep=wmma_m_rep, wmma_n_rep=wmma_n_rep, WMMA_M=WMMA_M, is_fp4=False + ) + + lds_a_stride = int(tile_k) + LDS_PAD_A + lds_b_stride = int(tile_n) + LDS_PAD_B + lds_a_elems = int(tile_m) * lds_a_stride + LDS_PAD_A + lds_b_elems = int(tile_k) * lds_b_stride + LDS_PAD_B + + alloc = SmemAllocator(None, arch=str(get_hip_arch()), global_sym_name="moe_fp16_s2_single") + off_b = alloc._align(alloc.ptr, 16) + alloc.ptr = off_b + lds_b_elems * elem_bytes + off_a = alloc._align(alloc.ptr, 16) + alloc.ptr = off_a + lds_a_elems * elem_bytes + + @flyc.kernel(known_block_size=[block_threads, 1, 1]) + def moe_fp16_stage2_single( + arg_out: fx.Tensor, + arg_x: fx.Tensor, + arg_w: fx.Tensor, + arg_scale_x: fx.Tensor, + arg_scale_w: fx.Tensor, + arg_sorted_token_ids: fx.Tensor, + arg_expert_ids: fx.Tensor, + arg_sorted_weights: fx.Tensor, + arg_num_valid_ids: fx.Tensor, + i32_tokens_in: fx.Int32, + i32_n_in: fx.Int32, + i32_k_in: fx.Int32, + i32_size_expert_ids_in: fx.Int32, + ): + _ = (arg_scale_x, arg_scale_w, i32_k_in) + llvm_dialect.inline_asm( + None, [], # void result, no operands + "s_setreg_imm32_b32 hwreg(26, 4, 1), 1", + "", + has_side_effects=True, + ) + + tx = gpu.thread_id("x") + bx = gpu.block_id("x") # n tile + by = gpu.block_id("y") # expert block + + tokens_idx = arith.index_cast(T.index, i32_tokens_in) + n_idx = arith.index_cast(T.index, i32_n_in) + size_expert_ids = arith.index_cast(T.index, i32_size_expert_ids_in) + num_valid_i32 = buffer_ops.buffer_load( + buffer_ops.create_buffer_resource(arg_num_valid_ids, 
max_size=True), + arith.constant(0, type=T.i32), + vec_width=1, + dtype=T.i32, + ) + + sorted_num = size_expert_ids * arith.index(int(route_tile_m)) + sorted_nbytes = sorted_num * arith.index(4) + eid_nbytes = size_expert_ids * arith.index(4) + x_rows = tokens_idx * arith.index(int(topk)) + x_nbytes = x_rows * arith.index(int(inter_dim)) * arith.index(2) + out_nbytes = tokens_idx * n_idx * arith.index(2) + if not bool(accumulate): + out_nbytes = x_rows * n_idx * arith.index(2) + + sorted_rsrc = buffer_ops.create_buffer_resource(arg_sorted_token_ids, max_size=False, num_records_bytes=sorted_nbytes) + eid_rsrc = buffer_ops.create_buffer_resource(arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes) + x_rsrc = buffer_ops.create_buffer_resource(arg_x, max_size=False, num_records_bytes=x_nbytes) + w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=True) + out_rsrc = buffer_ops.create_buffer_resource(arg_out, max_size=False, num_records_bytes=out_nbytes) + sw_rsrc = buffer_ops.create_buffer_resource(arg_sorted_weights, max_size=True) + + eid_i32 = buffer_ops.buffer_load(eid_rsrc, arith.index_cast(T.i32, by), vec_width=1, dtype=T.i32) + eid_ok0 = arith.cmpi(arith.CmpIPredicate.sge, eid_i32, arith.constant(0, type=T.i32)) + eid_ok1 = arith.cmpi(arith.CmpIPredicate.slt, eid_i32, arith.constant(int(experts), type=T.i32)) + block_row_start = arith.index_cast(T.i32, by * arith.index(int(route_tile_m))) + block_in_valid = arith.cmpi(arith.CmpIPredicate.slt, block_row_start, num_valid_i32) + block_ok = arith.andi(block_in_valid, arith.andi(eid_ok0, eid_ok1)) + + layout_thr = _make_moe_wave_layout(m_warp=m_warp, n_warp=n_warp, WAVE_SIZE=WAVE_SIZE, fx=fx) + thr_coord = idx2crd(tx, layout_thr) + wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( + fx.get(thr_coord, 0), fx.get(thr_coord, 1), fx.get(thr_coord, 2), fx.get(thr_coord, 3) + ) + warp_m_base = wave_m_idx * arith.index(warp_tile_m) + warp_n_base = wave_n_idx * arith.index(warp_tile_n) + blk_n = bx * arith.index(int(tile_n)) + + base_ptr = alloc.get_base() + smem_b = SmemPtr(base_ptr, off_b, T.f16, shape=(lds_b_elems,)) + smem_a = SmemPtr(base_ptr, off_a, T.f16, shape=(lds_a_elems,)) + lds_b = get_op_result_or_value(smem_b.get()) + lds_a = get_op_result_or_value(smem_a.get()) + + def pack_a_to_lds(k_base): + total = int(tile_m * tile_k) + rounds = (total + block_threads - 1) // block_threads + for it in range(rounds): + elem = tx + fx.Index(it * block_threads) + in_range = arith.cmpi( + arith.CmpIPredicate.ult, + arith.index_cast(T.i32, elem), + arith.constant(total, type=T.i32), + ) + _if_elem = scf.IfOp(in_range) + with ir.InsertionPoint(_if_elem.then_block): + row = elem // arith.index(int(tile_k)) + col = elem % arith.index(int(tile_k)) + sorted_row = by * arith.index(int(tile_m)) + row + row_i32 = arith.index_cast(T.i32, row) + sorted_i32 = arith.index_cast(T.i32, sorted_row) + row_in_route = arith.cmpi( + arith.CmpIPredicate.ult, + row_i32, + arith.constant(int(route_tile_m), type=T.i32), + ) + row_in_valid = arith.cmpi(arith.CmpIPredicate.slt, sorted_i32, num_valid_i32) + row_ok = arith.andi(row_in_route, row_in_valid) + sorted_safe = arith.select(row_ok, sorted_i32, block_row_start) + fused = buffer_ops.buffer_load(sorted_rsrc, sorted_safe, vec_width=1, dtype=T.i32) + tok = fused & arith.constant((1 << 24) - 1, type=T.i32) + slot = fused >> arith.constant(24, type=T.i32) + tok_ok = arith.cmpi(arith.CmpIPredicate.ult, tok, i32_tokens_in) + slot_ok0 = arith.cmpi(arith.CmpIPredicate.sge, slot, arith.constant(0, type=T.i32)) + 
slot_ok1 = arith.cmpi(arith.CmpIPredicate.slt, slot, arith.constant(int(topk), type=T.i32)) + ts = tok * arith.constant(int(topk), type=T.i32) + slot + ts_ok = arith.andi(tok_ok, arith.andi(slot_ok0, slot_ok1)) + load_ok = arith.andi(row_ok, ts_ok) + x_idx = ts * arith.constant(int(inter_dim), type=T.i32) + arith.index_cast(T.i32, k_base + col) + x_idx_safe = arith.select(load_ok, x_idx, arith.constant(0, type=T.i32)) + x_val = arith.select( + load_ok, + buffer_ops.buffer_load(x_rsrc, x_idx_safe, vec_width=1, dtype=T.f16), + arith.constant(0.0, type=T.f16), + ) + lds_idx = row * arith.index(lds_a_stride) + col + v1 = vector.from_elements(T.vec(1, T.f16), [x_val]) + vector.store(v1, lds_a, [lds_idx], alignment=2) + scf.YieldOp([]) + + def copy_b_to_lds(k_base): + eid_idx = arith.index_cast(T.index, eid_i32) + n_base = eid_idx * n_idx + blk_n + total = int(tile_k) * int(tile_n) + rounds = (total + block_threads - 1) // block_threads + for it in range(rounds): + elem = tx + fx.Index(it * block_threads) + in_range = arith.cmpi( + arith.CmpIPredicate.ult, + arith.index_cast(T.i32, elem), + arith.constant(total, type=T.i32), + ) + _if_elem = scf.IfOp(in_range) + with ir.InsertionPoint(_if_elem.then_block): + k_local = elem // arith.index(int(tile_n)) + n_local = elem % arith.index(int(tile_n)) + w_idx = (n_base + n_local) * arith.index(int(inter_dim)) + k_base + k_local + w_val = buffer_ops.buffer_load( + w_rsrc, arith.index_cast(T.i32, w_idx), + vec_width=1, dtype=T.f16, + ) + lds_idx = k_local * arith.index(lds_b_stride) + n_local + v1 = vector.from_elements(T.vec(1, T.f16), [w_val]) + vector.store(v1, lds_b, [lds_idx], alignment=2) + scf.YieldOp([]) + + def _precompute_a_lane_bases(): + row_stride_off = (warp_m_base + lane16) * arith.index(lds_a_stride) + k_lane_off = lane_kgrp * arith.index(8) + bases = [] + for wm in range_constexpr(wmma_m_rep): + a_base = row_stride_off + arith.index(wm * WMMA_M * lds_a_stride) + k_lane_off + bases.append(a_base) + return bases + + def _precompute_b_lane_bases(): + lane8 = lane16 % arith.index(8) + lane_ngrp = lane16 / arith.index(8) + k_lane_off = (lane_kgrp * arith.index(8) + lane8) * arith.index(lds_b_stride) + n_lane_off = lane_ngrp * arith.index(8) + bases = [] + for wn in range_constexpr(wmma_n_rep): + n_col = warp_n_base + arith.index(wn * WMMA_N) + n_lane_off + bases.append(k_lane_off + n_col) + return bases + + def load_a_frag(a_base, ks): + vec8_ty = ir.VectorType.get([8], T.f16) + off0 = a_base + arith.index(ks * WMMA_K) + off1 = a_base + arith.index(ks * WMMA_K + 16) + v0 = vector.load_op(vec8_ty, lds_a, [off0]) + v1 = vector.load_op(vec8_ty, lds_a, [off1]) + return vector.shuffle(v0, v1, list(range(16))) + + def load_b_frag(b_base, ks): + vec8_ty = ir.VectorType.get([8], T.f16) + results = [] + for k_half in range_constexpr(2): + k_row_off = (ks * WMMA_K + k_half * 16) * lds_b_stride + elem_off = b_base + arith.index(k_row_off) + v = rocdl.lds_transpose_load(vec8_ty, lds_b, elem_off, elem_bytes) + results.append(v) + return vector.shuffle(results[0], results[1], list(range(16))) + + acc_zero = arith.constant_vector(0.0, T.vec(8, T.f32)) + acc = [acc_zero] * n_accs + + _if_blk = scf.IfOp(block_ok) + with ir.InsertionPoint(_if_blk.then_block): + a_bases = _precompute_a_lane_bases() + b_bases = _precompute_b_lane_bases() + + for kt in range_constexpr(num_k_tiles): + k_base = fx.Index(kt * int(tile_k)) + pack_a_to_lds(k_base) + copy_b_to_lds(k_base) + gpu.barrier() + + for ks in range_constexpr(k_wmma_steps): + b_frags = [load_b_frag(b_bases[wn], 
ks) for wn in range_constexpr(wmma_n_rep)] + for wm in range_constexpr(wmma_m_rep): + a_frag = load_a_frag(a_bases[wm], ks) + for wn in range_constexpr(wmma_n_rep): + idx = wm * wmma_n_rep + wn + acc[idx] = rocdl.wmma_f32_16x16x32_f16( + T.vec(8, T.f32), + b_frags[wn], + a_frag, + acc[idx], + signA=False, + signB=False, + modC=0, + reuseA=False, + reuseB=False, + ).result + gpu.barrier() + + out_elem_ty = _moe_out_elem_ty(out_dtype, T) + + def _load_sub8(acc_idx, _vec_base): + return acc[acc_idx] + + _emit_stage2_store_epilogue( + sub_tiles=_sub_tiles, + by=by, + tile_m=int(tile_m), + route_tile_m=int(route_tile_m), + warp_m_base=warp_m_base, + warp_n_base=warp_n_base, + blk_n=blk_n, + lane16=lane16, + lane_kgrp=lane_kgrp, + WMMA_N=WMMA_N, + i32_tokens_in=i32_tokens_in, + i32_n_in=i32_n_in, + topk=int(topk), + num_valid_i32=num_valid_i32, + block_row_start=block_row_start, + sorted_rsrc=sorted_rsrc, + tw_rsrc=sw_rsrc, + out_rsrc=out_rsrc, + doweight_stage2=bool(doweight_stage2), + accumulate=bool(accumulate), + out_elem_ty=out_elem_ty, + load_sub8=_load_sub8, + ir=ir, + fx=fx, + arith=arith, + buffer_ops=buffer_ops, + scf=scf, + vector=vector, + range_constexpr=range_constexpr, + rocdl=rocdl, + T=T, + ) + scf.YieldOp([]) + + @flyc.jit + def launch_fp16_stage2_single( + arg_out: fx.Tensor, + arg_x: fx.Tensor, + arg_w: fx.Tensor, + arg_scale_x: fx.Tensor, + arg_scale_w: fx.Tensor, + arg_sorted_token_ids: fx.Tensor, + arg_expert_ids: fx.Tensor, + arg_sorted_weights: fx.Tensor, + arg_num_valid_ids: fx.Tensor, + i32_tokens_in: fx.Int32, + i32_n_in: fx.Int32, + i32_k_in: fx.Int32, + i32_size_expert_ids_in: fx.Int32, + stream: fx.Stream, + ): + _ = i32_k_in + ctx = CompilationContext.get_current() + n_in = arith.index_cast(T.index, i32_n_in) + size_expert_ids_in = arith.index_cast(T.index, i32_size_expert_ids_in) + gx = (n_in + fx.Index(int(tile_n) - 1)) // fx.Index(int(tile_n)) + gy = size_expert_ids_in + launcher = moe_fp16_stage2_single( + arg_out, + arg_x, + arg_w, + arg_scale_x, + arg_scale_w, + arg_sorted_token_ids, + arg_expert_ids, + arg_sorted_weights, + arg_num_valid_ids, + i32_tokens_in, + i32_n_in, + i32_k_in, + i32_size_expert_ids_in, + ) + _finalize_alloc_and_launch_2d( + ctx=ctx, + alloc=alloc, + launcher=launcher, + gx=gx, + gy=gy, + block_threads=block_threads, + stream=stream, + waves_per_eu=waves_per_eu, + ir=ir, + ) + + if expert_sched_mode: + launch_fp16_stage2_single.compile_hints["llvm_options"] = { + "amdgpu-expert-scheduling-mode": True, + } + + return launch_fp16_stage2_single + + +# --------------------------------------------------------------------------- +# Public API entry points for fp16/bf16 +# --------------------------------------------------------------------------- + +@functools.lru_cache(maxsize=1024) +def _compile_moe_wmma_gemm( + *, + stage: int, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n: int, + tile_k: int, + doweight: bool, + in_dtype: str = "fp16", + out_dtype: str = "f16", + accumulate: bool = True, + waves_per_eu: int | None = None, + expert_sched_mode: bool = True, +): + _require_gfx1250() + if waves_per_eu is not None and int(waves_per_eu) < 1: + raise ValueError(f"waves_per_eu must be >= 1, got {waves_per_eu!r}") + if in_dtype not in ("fp16", "bf16"): + raise ValueError( + f"Unsupported in_dtype for WMMA stage{stage}: {in_dtype!r}, " + "expected 'fp16' or 'bf16'" + ) + + single_tile_m, single_tile_n, single_m_warp, single_n_warp = _pick_fp16_single_launch_shape( + int(tile_m), int(tile_n), 
max_total_warps=8, + ) + common = dict( + inter_dim=int(inter_dim), + experts=int(experts), + topk=int(topk), + route_tile_m=int(tile_m), + tile_m=int(single_tile_m), + tile_n=int(single_tile_n), + tile_k=int(tile_k), + m_warp=int(single_m_warp), + n_warp=int(single_n_warp), + out_dtype=out_dtype, + waves_per_eu=waves_per_eu, + expert_sched_mode=expert_sched_mode, + ) + + if stage == 1: + exe = _compile_stage1_wmma_kernel_impl( + model_dim=int(model_dim), doweight_stage1=bool(doweight), **common, + ) + else: + exe = _compile_stage2_wmma_kernel_impl( + doweight_stage2=bool(doweight), accumulate=bool(accumulate), **common, + ) + + if in_dtype == "bf16": + return _bf16_to_f16_wrapper(exe, x_arg=1, w_arg=2) + return exe + + +def compile_moe_gemm1(*, doweight_stage1, group_size=-1, use_cshuffle_epilog=None, + num_buffers=1, use_tdm_gather=True, use_tdm_store=False, + inst_prefetch=False, wave_specialized_tdm=False, + cluster_m=1, cluster_n=1, **kw): + return _compile_moe_wmma_gemm(stage=1, doweight=doweight_stage1, **kw) + + +def compile_moe_gemm2(*, doweight_stage2, accumulate=True, group_size=-1, + use_cshuffle_epilog=None, + num_buffers=1, use_tdm_gather=True, use_tdm_store=False, + inst_prefetch=False, wave_specialized_tdm=False, + cluster_m=1, cluster_n=1, **kw): + return _compile_moe_wmma_gemm(stage=2, doweight=doweight_stage2, accumulate=accumulate, **kw) + + +def compile_moe_gemm2_ex(*, mode=MoeGemm2Mode.ATOMIC, valid_mask=None, zero_intermediate=True, **kw): + if mode == MoeGemm2Mode.REDUCE: + gemm2_exe = compile_moe_gemm2(accumulate=False, **kw) + out_s = str(kw.get("out_dtype", "f16")).strip().lower() + if out_s in ("f16", "fp16", "half"): + dtype_str = "f16" + elif out_s in ("bf16", "bfloat16"): + dtype_str = "bf16" + else: + dtype_str = "f32" + reduce_exe = compile_moe_reduction( + topk=kw["topk"], model_dim=kw["model_dim"], + dtype_str=dtype_str, use_mask=(valid_mask is not None), + ) + from kernels.moe_gemm_2stage import _MoeGemm2ReduceWrapper + return _MoeGemm2ReduceWrapper( + gemm2_exe=gemm2_exe, reduce_exe=reduce_exe, + topk=kw["topk"], model_dim=kw["model_dim"], + out_dtype_str=dtype_str, + use_mask=(valid_mask is not None), + zero_intermediate=zero_intermediate, + ) + return compile_moe_gemm2(accumulate=True, **kw) diff --git a/python/flydsl/expr/rocdl/tdm_ops.py b/python/flydsl/expr/rocdl/tdm_ops.py index 4ade03cb..b1e7ead7 100644 --- a/python/flydsl/expr/rocdl/tdm_ops.py +++ b/python/flydsl/expr/rocdl/tdm_ops.py @@ -40,8 +40,13 @@ __all__ = [ "TDMDescriptor2D", + "TDMGatherDescriptor", "make_tensor_descriptor_2d", + "make_tensor_gather_dgroup0", + "make_tensor_gather_descriptor", "tensor_load_2d", + "tensor_load_gather", + "tensor_store_gather", "tensor_store_2d", "tensor_wait", "compute_padding_encoding", @@ -132,6 +137,22 @@ class TDMDescriptor2D: dgroup1: object # vector<8xi32> MLIR Value +@dataclass +class TDMGatherDescriptor: + """Holds GROUP0, GROUP1, GROUP2, GROUP3 for TDM gather mode. + + In gather mode, groups 2 and 3 carry row indices instead of + higher-dimension tensor metadata. 
+ + - 32-bit index mode: up to 8 row indices (4 per group) + - 16-bit index mode: up to 16 row indices (8 per group) + """ + dgroup0: object # vector<4xi32> MLIR Value + dgroup1: object # vector<8xi32> MLIR Value + dgroup2: object # vector<4xi32> MLIR Value — row indices [0..3] or [0..7] + dgroup3: object # vector<4xi32> MLIR Value — row indices [4..7] or [8..15] + + # --------------------------------------------------------------------------- # Internal helpers # --------------------------------------------------------------------------- @@ -390,6 +411,309 @@ def make_tensor_descriptor_2d( return TDMDescriptor2D(dgroup0=dgroup0, dgroup1=dgroup1) +def make_tensor_gather_descriptor( + global_ptr, + lds_memref, + row_indices, + row_width: int, + tensor_dim0: int, + tensor_dim1, + stride: int, + elem_bytes: int = 1, + pad_interval: int = 0, + pad_amount: int = 0, + index_size: int = 32, + gather_tile_dim1=None, + lds_byte_offset=None, + global_byte_offset=None, + workgroup_mask: Union[int, "ir.Value"] = 0, +) -> TDMGatherDescriptor: + """Build a TDM gather descriptor for loading arbitrary rows from global to LDS. + + In gather mode the TDM fetches rows specified by explicit indices in + descriptor groups 2 and 3, rather than iterating over contiguous dim1. + + Args: + global_ptr: The global tensor pointer (fx.Tensor). + lds_memref: The LDS memref base (SmemAllocator base). + row_indices: List of row index MLIR i32 Values. Max 8 for 32-bit + mode, max 16 for 16-bit mode. + row_width: Width of each row in data_size elements (= tile_dim0). + Must be a multiple of 4 bytes. + tensor_dim0: Full tensor dimension 0 (row width) for OOB check. + tensor_dim1: Full tensor dimension 1 (num rows) for OOB check. + Accepts a Python int (compile-time) or an MLIR i32 + Value / SGPR (runtime). Per ISA spec §4.10.3.2, + row indices >= tensor_dim1 are treated as OOB, so + this MUST be >= the actual number of rows (tokens). + stride: Stride of dim0 in elements (row stride of the global + matrix). + elem_bytes: Element size in bytes (1, 2, 4, or 8). + pad_interval: Padding interval in elements (0 to disable). + pad_amount: Padding amount in elements (0 to disable). + index_size: Row index width in bits (16 or 32). + gather_tile_dim1: + Optional override for gather-mode tile_dim1 (the number + of valid indices to consume from groups 2/3). Accepts a + Python int or runtime MLIR i32 Value / SGPR. Defaults to + len(row_indices), preserving the historical behavior. + lds_byte_offset: Additional LDS byte offset. + global_byte_offset: Additional global memory byte offset (MLIR index). + Used for K-tile column offsets. + workgroup_mask: Multicast mask. + + Returns: + TDMGatherDescriptor with groups 0-3 ready for tensor_load_gather. 
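+
+    Example (illustrative sketch only; the tensor names, tile shape and
+    row list below are assumptions for documentation, not values exercised
+    by this patch):
+
+        rows = [arith.index_cast(T.i32, r) for r in gathered_rows]  # <= 8 rows in 32-bit mode
+        desc = make_tensor_gather_descriptor(
+            global_ptr=arg_x, lds_memref=lds_a,
+            row_indices=rows, row_width=tile_k,
+            tensor_dim0=model_dim, tensor_dim1=i32_tokens_in,
+            stride=model_dim, elem_bytes=2,
+        )
+        tensor_load_gather(desc)
+        tensor_wait(0)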
+ """ + assert index_size in (16, 32), f"index_size must be 16 or 32, got {index_size}" + max_indices = 8 if index_size == 32 else 16 + num_indices = len(row_indices) + assert 0 < num_indices <= max_indices, ( + f"row_indices length {num_indices} exceeds max {max_indices} for {index_size}-bit mode" + ) + assert row_width * elem_bytes % 4 == 0, ( + f"row_width * elem_bytes must be multiple of 4, got {row_width * elem_bytes}" + ) + + dgroup0 = make_tensor_gather_dgroup0( + global_ptr=global_ptr, + lds_memref=lds_memref, + index_size=index_size, + lds_byte_offset=lds_byte_offset, + global_byte_offset=global_byte_offset, + ) + + # ================================================================ + # GROUP 1: config + tensor dims + tile + stride + # ================================================================ + data_size_code = int(math.log2(elem_bytes)) + + if pad_interval > 0 and pad_amount > 0: + elem_bits = elem_bytes * 8 + enc_interval, enc_amount = compute_padding_encoding( + pad_interval, pad_amount, elem_bits + ) + pad_enable = 1 + else: + enc_interval, enc_amount = 0, 0 + pad_enable = 0 + + if isinstance(workgroup_mask, int): + g1_s0_val = ( + (workgroup_mask & 0xFFFF) + | (data_size_code << 16) + | (0 << 18) # atomic_barrier_enable + | (0 << 19) # iterate_enable (ignored in gather) + | (pad_enable << 20) + | (0 << 21) # early_timeout + | (enc_interval << 22) + | (enc_amount << 25) + ) + g1_s0 = arith.constant(g1_s0_val, type=T.i32) + else: + upper = ( + (data_size_code << 16) + | (pad_enable << 20) + | (enc_interval << 22) + | (enc_amount << 25) + ) + g1_s0 = arith.ori( + arith.constant(upper, type=T.i32), + arith.andi(workgroup_mask, arith.constant(0xFFFF, type=T.i32)), + ) + + # tensor_dim0 (32 bits) packed into sgpr1[31:16] and sgpr2[15:0] + # tensor_dim1 (32 bits) packed into sgpr2[31:16] and sgpr3[15:0] + # + # tensor_dim1 may be a runtime MLIR i32 value (e.g. num_tokens) — + # the TDM hardware uses it for OOB checking on gather row indices. + _td1_is_runtime = not isinstance(tensor_dim1, int) + + g1_s1 = arith.constant((tensor_dim0 & 0xFFFF) << 16, type=T.i32) + + if _td1_is_runtime: + _td0_hi = arith.constant((tensor_dim0 >> 16) & 0xFFFF, type=T.i32) + _td1_lo = arith.andi(tensor_dim1, arith.constant(0xFFFF, type=T.i32)) + _td1_lo_shifted = arith.shli(_td1_lo, arith.constant(16, type=T.i32)) + g1_s2 = arith.ori(_td0_hi, _td1_lo_shifted) + + _td1_hi = arith.andi( + arith.shrui(tensor_dim1, arith.constant(16, type=T.i32)), + arith.constant(0xFFFF, type=T.i32), + ) + g1_s3 = arith.ori(_td1_hi, arith.constant(row_width << 16, type=T.i32)) + else: + g1_s2 = arith.constant( + ((tensor_dim0 >> 16) & 0xFFFF) | ((tensor_dim1 & 0xFFFF) << 16), + type=T.i32, + ) + g1_s3 = arith.constant( + ((tensor_dim1 >> 16) & 0xFFFF) | (row_width << 16), + type=T.i32, + ) + + # sgpr4: tile_dim1[15:0] — in gather mode, this is the number of valid + # indices consumed from descriptor groups 2/3. Allow kernels to override it + # at runtime so they can keep a fixed index vector while shrinking the valid + # prefix for padded MoE tiles. 
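+    # Illustrative note (assumed usage, not exercised by this patch): a kernel
+    # that keeps a fixed 8-entry index vector can pass a runtime value here,
+    # e.g. gather_tile_dim1=arith.index_cast(T.i32, rows_remaining), so only
+    # the first rows_remaining indices are consumed; non-int values are masked
+    # to 16 bits below.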
+ if gather_tile_dim1 is None: + g1_s4 = arith.constant(num_indices & 0xFFFF, type=T.i32) + elif isinstance(gather_tile_dim1, int): + g1_s4 = arith.constant(gather_tile_dim1 & 0xFFFF, type=T.i32) + else: + g1_s4 = arith.andi(gather_tile_dim1, arith.constant(0xFFFF, type=T.i32)) + + # sgpr5: tensor_dim0_stride (dim0 stride = row stride in elements) + g1_s5 = arith.constant(stride & 0xFFFFFFFF, type=T.i32) + + # sgpr6-7: tensor_dim1_stride (ignored in gather mode) + g1_s6 = arith.constant(0, type=T.i32) + g1_s7 = arith.constant(0, type=T.i32) + + dgroup1 = vector.from_elements( + T.vec(8, T.i32), + [g1_s0, g1_s1, g1_s2, g1_s3, g1_s4, g1_s5, g1_s6, g1_s7], + ) + + # ================================================================ + # GROUP 2 & 3: row indices + # ================================================================ + zero = arith.constant(0, type=T.i32) + + if index_size == 32: + # 32-bit mode: group2 has indices [0..3], group3 has [4..7] + g2_vals = [row_indices[i] if i < num_indices else zero for i in range(4)] + g3_vals = [row_indices[i + 4] if (i + 4) < num_indices else zero for i in range(4)] + else: + # 16-bit mode: pack 2 x 16-bit indices per 32-bit word + # Group 2: indices [0..7] packed into 4 x i32 + g2_vals = [] + for w in range(4): + lo_idx = w * 2 + hi_idx = w * 2 + 1 + lo = row_indices[lo_idx] if lo_idx < num_indices else zero + hi = row_indices[hi_idx] if hi_idx < num_indices else zero + lo_masked = arith.andi(lo, arith.constant(0xFFFF, type=T.i32)) + hi_shifted = arith.shli(arith.andi(hi, arith.constant(0xFFFF, type=T.i32)), + arith.constant(16, type=T.i32)) + g2_vals.append(arith.ori(lo_masked, hi_shifted)) + # Group 3: indices [8..15] packed into 4 x i32 + g3_vals = [] + for w in range(4): + lo_idx = 8 + w * 2 + hi_idx = 8 + w * 2 + 1 + lo = row_indices[lo_idx] if lo_idx < num_indices else zero + hi = row_indices[hi_idx] if hi_idx < num_indices else zero + lo_masked = arith.andi(lo, arith.constant(0xFFFF, type=T.i32)) + hi_shifted = arith.shli(arith.andi(hi, arith.constant(0xFFFF, type=T.i32)), + arith.constant(16, type=T.i32)) + g3_vals.append(arith.ori(lo_masked, hi_shifted)) + + dgroup2 = vector.from_elements(T.vec(4, T.i32), g2_vals) + dgroup3 = vector.from_elements(T.vec(4, T.i32), g3_vals) + + return TDMGatherDescriptor( + dgroup0=dgroup0, dgroup1=dgroup1, + dgroup2=dgroup2, dgroup3=dgroup3, + ) + + + +def make_tensor_gather_dgroup0( + global_ptr, + lds_memref, + *, + index_size: int = 32, + lds_byte_offset=None, + global_byte_offset=None, +): + """Build gather descriptor GROUP0 only. + + This is the dynamic address-bearing portion of a TDM gather descriptor. + Separating it lets kernels hoist static GROUP1/GROUP2/GROUP3 state and + only rebuild the per-issue address group close to the TDM instruction. 
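+
+    A minimal sketch of that hoisting pattern (variable names hypothetical)::
+
+        desc = make_tensor_gather_descriptor(...)    # groups 1-3 stay fixed
+        for k_off in k_tile_byte_offsets:
+            g0 = make_tensor_gather_dgroup0(a_ptr, smem_base,
+                                            global_byte_offset=k_off)
+            tensor_load_gather(TDMGatherDescriptor(
+                dgroup0=g0, dgroup1=desc.dgroup1,
+                dgroup2=desc.dgroup2, dgroup3=desc.dgroup3))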
+ """ + from ..._mlir.dialects import fly as _fly_d + + assert index_size in (16, 32), f"index_size must be 16 or 32, got {index_size}" + + glb_ptr_type = ir.Type.parse("!llvm.ptr<1>") + i64 = ir.IntegerType.get_signless(64) + a_raw = global_ptr.__fly_values__()[0] + glb_ptr = _fly_d.extract_aligned_pointer_as_index(glb_ptr_type, a_raw) + glb_base_i64 = _ArithValue(llvm_dialect.ptrtoint(i64, glb_ptr)) + if global_byte_offset is not None: + glb_byte_off_i64 = arith.index_cast(T.i64, global_byte_offset) + glb_base_i64 = glb_base_i64 + glb_byte_off_i64 + + lds_base_idx = _ArithValue(memref_dialect.extract_aligned_pointer_as_index(lds_memref)) + lds_total_off = lds_base_idx + if lds_byte_offset is not None: + lds_total_off = lds_total_off + lds_byte_offset + lds_addr_i32 = arith.index_cast(T.i32, lds_total_off) + + gather_index_bit = 1 if index_size == 32 else 0 + g0_pred = ( + 1 + | (gather_index_bit << 30) + | (1 << 31) + ) + g0_s0 = arith.constant(g0_pred, type=T.i32) + g0_s1 = lds_addr_i32 + + i32 = ir.IntegerType.get_signless(32) + g0_s2 = _ArithValue(std_arith.TruncIOp(i32, _raw(glb_base_i64)).result) + hi_raw = _ArithValue(_raw(glb_base_i64)).shrui(arith.constant(32, type=T.i64)) + g0_s3 = ( + _ArithValue(std_arith.TruncIOp(i32, _raw(hi_raw)).result) + | arith.constant(1 << 31, type=T.i32) + ) + return vector.from_elements( + T.vec(4, T.i32), [g0_s0, g0_s1, g0_s2, g0_s3] + ) + +def tensor_load_gather( + desc: TDMGatherDescriptor, + cache_policy: int = 0, +) -> None: + """Issue a TDM gather load (Global -> LDS) using row indices. + + Uses the 5-group tensor_load_to_lds intrinsic with groups 2 and 3 + carrying the gather row indices. + + Args: + desc: TDMGatherDescriptor from make_tensor_gather_descriptor. + cache_policy: Cache policy (0 = default). + """ + dg4 = _raw(_zero_dgroup_v8i32()) + rocdl.tensor_load_to_lds( + _raw(desc.dgroup0), _raw(desc.dgroup1), + _raw(desc.dgroup2), _raw(desc.dgroup3), + dg4, cache_policy, + ) + + +def tensor_store_gather( + desc: TDMGatherDescriptor, + cache_policy: int = 0, +) -> None: + """Issue a TDM gather store (LDS -> Global) using row indices. + + Uses the 5-group tensor_store_from_lds intrinsic with groups 2 and 3 + carrying the gather row indices. + + Args: + desc: TDMGatherDescriptor from make_tensor_gather_descriptor. + cache_policy: Cache policy (0 = default). 
+ """ + dg4 = _raw(_zero_dgroup_v8i32()) + rocdl.tensor_store_from_lds( + _raw(desc.dgroup0), _raw(desc.dgroup1), + _raw(desc.dgroup2), _raw(desc.dgroup3), + dg4, cache_policy, + ) + + def _zero_dgroup_v4i32(): """Create a zero vector<4xi32> for unused descriptor groups.""" z = arith.constant(0, type=T.i32) diff --git a/tests/kernels/benchmark_common.py b/tests/kernels/benchmark_common.py index 36926f8c..e30cee71 100644 --- a/tests/kernels/benchmark_common.py +++ b/tests/kernels/benchmark_common.py @@ -398,6 +398,449 @@ def run_wmma_sweep( return rows +# ── MOE bench common helpers ────────────────────────────────────────────── + +BENCH_WARMUP = 5 +BENCH_ITERS = 20 + +BENCH_MODEL_CONFIGS = [ + # name, model_dim, inter_dim, experts, topk + ("DeepSeek-TP", 7168, 256, 257, 9), + ("DeepSeek-EP", 7168, 2048, 32, 8), + ("GPToss", 2880, 2880, 128, 4), +] + +BENCH_DTYPE_TARGET_TILES = { + # dtype: (tile_m, target_n, target_k, wmma_k) + "fp4": (16, 256, 512, 128), + "fp8": (16, 256, 512, 128), + "a8w4": (16, 256, 512, 128), + "fp16": (32, 64, 64, 32), + "bf16": (32, 64, 64, 32), +} + +BENCH_DEFAULT_TOKEN_SWEEP = [1, 4, 8, 32, 64, 128, 256] +_BENCH_SCALE_GROUP = 32 + + +def bench_kernel_us(run_fn, warmup=10, iters=50, flush_l2=True, prep_fn=None): + """Per-iteration CUDA events timer with optional L2 flush and median latency.""" + import torch + + flush_buf = None + if flush_l2: + l2_bytes = getattr( + torch.cuda.get_device_properties(torch.cuda.current_device()), + "L2_cache_size", 4 * 1024 * 1024) + alloc_bytes = max(l2_bytes * 2, 8 * 1024 * 1024) + flush_buf = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda") + + for _ in range(warmup): + if flush_buf is not None: + flush_buf.zero_() + if prep_fn is not None: + prep_fn() + run_fn() + torch.cuda.synchronize() + + start_ev = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_ev = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + + for i in range(iters): + if flush_buf is not None: + flush_buf.zero_() + if prep_fn is not None: + prep_fn() + start_ev[i].record() + run_fn() + end_ev[i].record() + + torch.cuda.synchronize() + latencies = sorted(start_ev[i].elapsed_time(end_ev[i]) * 1e3 for i in range(iters)) + + n = len(latencies) + if n >= 8: + q1, q3 = latencies[n // 4], latencies[3 * n // 4] + iqr = q3 - q1 + lo, hi = q1 - 1.5 * iqr, q3 + 1.5 * iqr + filtered = [x for x in latencies if lo <= x <= hi] + if filtered: + latencies = filtered + + del flush_buf + return latencies[len(latencies) // 2] + + +def bench_best_tile(target, dim, align): + """Largest value <= target that divides dim and is a multiple of align.""" + for v in range(target, 0, -align): + if dim % v == 0: + return v + return None + + +def bench_resolve_tiles(in_dtype, model_dim, inter_dim): + """Compute the largest valid (tile_m, tile_n1, tile_k1, tile_n2, tile_k2) + for a given dtype and model shape, falling back from the target when + dimensions don't divide evenly.""" + tile_m, target_n, target_k, wmma_k = BENCH_DTYPE_TARGET_TILES[in_dtype] + + tile_n1 = bench_best_tile(target_n, inter_dim, 16) + tile_k1 = bench_best_tile(target_k, model_dim, wmma_k) + tile_n2 = bench_best_tile(target_n, model_dim, 16) + + tile_k2 = None + for k in range(target_k, 0, -wmma_k): + if inter_dim % k != 0: + continue + total = tile_m * k + if total % 256 == 0 and (total // 256) % 4 == 0: + tile_k2 = k + break + + if any(v is None for v in (tile_n1, tile_k1, tile_n2, tile_k2)): + return None + return (tile_m, tile_n1, tile_k1, tile_n2, tile_k2) + + +def 
bench_dtype_bpe(in_dtype): + """Return (a_bpe, w_bpe, w_scale_bpg) for bandwidth accounting.""" + if in_dtype == "fp4": + return 0.5, 0.5, 1 + if in_dtype == "a8w4": + return 1, 0.5, 1 + if in_dtype == "fp8": + return 1, 1, 1 + if in_dtype in ("fp16", "bf16"): + return 2, 2, 0 + return 1, 1, 1 + + +def bench_bytes_moved_stage1(tokens, topk, model_dim, inter_dim, experts, in_dtype): + import math + a_bpe, w_bpe, w_scale_bpg = bench_dtype_bpe(in_dtype) + aE = min(tokens * topk, experts) + b = 0 + b += tokens * model_dim * a_bpe + b += aE * (2 * inter_dim) * model_dim * w_bpe + b += aE * (2 * inter_dim) * math.ceil(model_dim / _BENCH_SCALE_GROUP) * w_scale_bpg + b += tokens * topk * inter_dim * 2 + return int(b) + + +def bench_bytes_moved_stage2(tokens, topk, model_dim, inter_dim, experts, in_dtype): + import math + a_bpe, w_bpe, w_scale_bpg = bench_dtype_bpe(in_dtype) + aE = min(tokens * topk, experts) + b = 0 + b += tokens * topk * inter_dim * a_bpe + b += aE * model_dim * inter_dim * w_bpe + b += aE * model_dim * math.ceil(inter_dim / _BENCH_SCALE_GROUP) * w_scale_bpg + b += tokens * topk * model_dim * 2 + return int(b) + + +def bench_print_banner(text): + print(f"\n{'=' * 110}") + print(f" {text}") + print(f"{'=' * 110}") + + +def bench_print_stage_header(): + print(f"{'Tokens':>7} {'M_eff':>7} {'Latency(us)':>12} {'TFLOPS':>9} " + f"{'BW(TB/s)':>10} {'Util%':>7} {'Status':>8}") + print("-" * 110) + + +def bench_print_stage_row(tokens, m_eff, us, tflops, tbps, util_pct, status): + print(f"{tokens:>7} {m_eff:>7} {us:>10.1f} {tflops:>8.2f} " + f"{tbps:>9.3f} {util_pct:>6.1f}% {status:>8}") + + +# ── MOE bench sweep system ───────────────────────────────────────────────── +# Generic benchmark sweep for MoE 2-stage kernels: sweeps model configs × +# dtypes × token counts. Callers provide the stage1/stage2 runner functions +# and data-setup helpers so this module stays kernel-agnostic. + + +def add_moe_bench_args(parser) -> None: + """Register the ``--bench`` argument group on *parser*. + + Call this from the test script's ``if __name__ == '__main__':`` block so + the user can ``python test_xxx.py --bench ...``. 
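+
+    A minimal wiring sketch (``my_setup`` and ``my_prep_a2`` stand for the
+    kernel-specific callables described in ``moe_bench_config``)::
+
+        parser = argparse.ArgumentParser()
+        add_moe_bench_args(parser)
+        args = parser.parse_args()
+        if args.bench:
+            moe_bench_main(args, stage1_fn=run_moe_stage1,
+                           stage2_fn=run_moe_stage2,
+                           setup_data_fn=my_setup, prepare_a2_fn=my_prep_a2)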
+ """ + import argparse # noqa: F811 – local import to avoid top-level dep + + bench_group = parser.add_argument_group( + "benchmark sweep", + "Options for --bench mode (sweep model configs × dtypes × token counts)", + ) + bench_group.add_argument("--bench", action="store_true", default=False, + help="Run benchmark sweep mode instead of normal test mode.") + bench_group.add_argument("--bench-dtype", type=str, default=None, + help="Comma-separated dtypes for bench (default: all keys in BENCH_DTYPE_TARGET_TILES).") + bench_group.add_argument("--bench-tokens", type=str, default=None, + help="Comma-separated token counts for bench (default: 1,4,8,32,64,128,256).") + bench_group.add_argument("--bench-config", type=str, default=None, + help="Config name filter for bench (DeepSeek-TP, DeepSeek-EP, GPToss).") + bench_group.add_argument("--bench-no-ref", action="store_true", default=False, + help="Skip correctness reference check in bench mode (pure perf).") + bench_group.add_argument("--bench-warmup", type=int, default=BENCH_WARMUP, + help=f"Warmup iterations for bench (default: {BENCH_WARMUP}).") + bench_group.add_argument("--bench-iters", type=int, default=BENCH_ITERS, + help=f"Measurement iterations for bench (default: {BENCH_ITERS}).") + bench_group.add_argument("--bench-peak-tflops", type=float, default=0, + help="Peak TFLOPS for utilization calculation in bench mode.") + + +def moe_bench_config( + name: str, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + in_dtype: str, + token_list: List[int], + check_ref: bool, + peak_tflops: float, + *, + stage1_fn: Callable, + stage2_fn: Callable, + setup_data_fn: Callable, + prepare_a2_fn: Callable, + warmup: int = BENCH_WARMUP, + iters: int = BENCH_ITERS, + use_tdm_store: bool = False, + inst_prefetch: bool = False, + wave_specialized_tdm: bool = False, +) -> None: + """Benchmark a single (model, dtype) configuration across all token counts. + + Parameters + ---------- + stage1_fn : callable + ``run_moe_stage1(...)`` from the test harness. + stage2_fn : callable + ``run_moe_stage2(...)`` from the test harness. 
+ setup_data_fn : callable + ``(tokens, model_dim, inter_dim, experts, topk, tile_m) -> (x, w1, w2, ids, wts, routing)`` + prepare_a2_fn : callable + ``(out1_fp16, tokens, topk, inter_dim, in_dtype) -> (a2_q, a2_scale)`` + """ + import torch + + tiles = bench_resolve_tiles(in_dtype, model_dim, inter_dim) + if tiles is None: + bench_print_banner(f"{name} | {in_dtype} | dim={model_dim} inter={inter_dim}") + print(f" SKIP: no valid tile for this shape (WMMA_K alignment)") + return + tile_m, tile_n1, tile_k1, tile_n2, tile_k2 = tiles + + bench_print_banner(f"{name} | {in_dtype} | dim={model_dim} inter={inter_dim} E={experts} K={topk}") + print(f" Tiles: stage1=({tile_m},{tile_n1},{tile_k1}) stage2=({tile_m},{tile_n2},{tile_k2})") + print(f" Warmup={warmup} Iters={iters} RefCheck={'ON' if check_ref else 'OFF'}") + print( + f" Knobs: use_tdm_store={bool(use_tdm_store)} " + f"inst_prefetch={bool(inst_prefetch)} " + f"wave_specialized_tdm={bool(wave_specialized_tdm)}" + ) + if peak_tflops > 0: + print(f" Peak compute reference: {peak_tflops:.0f} TFLOPS") + + # ── Stage 1 ── + print(f"\n ── Stage 1 (gate+up: [{model_dim}] -> [{2*inter_dim}]) ──") + bench_print_stage_header() + + s1_results = [] + for tok in token_list: + torch.cuda.empty_cache() + x, w1, w2, ids, wts, routing = setup_data_fn(tok, model_dim, inter_dim, experts, topk, tile_m) + try: + out1, us1 = stage1_fn( + tokens=tok, model_dim=model_dim, inter_dim=inter_dim, + experts=experts, topk=topk, in_dtype=in_dtype, + tile_m=tile_m, tile_n=tile_n1, tile_k=tile_k1, + doweight_stage1=False, seed=0, + num_iters=iters, num_warmup=warmup, + x_fp32_in=x, w1_fp32_in=w1, w2_fp32_in=w2, + topk_ids_in=ids, topk_weights_in=wts, routing_in=routing, + return_outputs=True, skip_ref=(not check_ref), + use_tdm_store=bool(use_tdm_store), + inst_prefetch=bool(inst_prefetch), + wave_specialized_tdm=bool(wave_specialized_tdm), + ) + status = "PASS" if check_ref else "OK" + except Exception as e: + status = "FAIL" + us1 = 0.0 + out1 = torch.zeros((tok, topk, inter_dim), device="cuda", dtype=torch.float16) + print(f" [{type(e).__name__}] tokens={tok}: {e}") + + m_eff = tok * topk + flops = 2 * tok * topk * (2 * inter_dim) * model_dim + tflops = flops / (us1 / 1e6) / 1e12 if us1 > 0 else 0 + bm = bench_bytes_moved_stage1(tok, topk, model_dim, inter_dim, experts, in_dtype) + tbps = bm / 1e12 / (us1 / 1e6) if us1 > 0 else 0 + util = (tflops / peak_tflops * 100) if (peak_tflops > 0 and tflops > 0) else 0 + + bench_print_stage_row(tok, m_eff, us1, tflops, tbps, util, status) + s1_results.append((tok, m_eff, us1, tflops, tbps, status, out1)) + + # ── Stage 2 atomic ── + print(f"\n ── Stage 2 atomic (down: [{inter_dim}] -> [{model_dim}]) ──") + bench_print_stage_header() + + for tok, m_eff, _, _, _, s1_status, out1 in s1_results: + torch.cuda.empty_cache() + x, w1, w2, ids, wts, routing = setup_data_fn(tok, model_dim, inter_dim, experts, topk, tile_m) + a2_q, a2_scale = prepare_a2_fn(out1, tok, topk, inter_dim, in_dtype) + try: + _, us2 = stage2_fn( + tokens=tok, model_dim=model_dim, inter_dim=inter_dim, + experts=experts, topk=topk, in_dtype=in_dtype, out_dtype="f16", + tile_m=tile_m, tile_n=tile_n2, tile_k=tile_k2, + doweight_stage1=False, seed=0, + num_iters=iters, num_warmup=warmup, + x_fp32_in=x, w1_fp32_in=w1, w2_fp32_in=w2, + topk_ids_in=ids, topk_weights_in=wts, routing_in=routing, + a2_fp8_in=a2_q, a2_scale_in=a2_scale, + return_outputs=True, skip_ref=(not check_ref), + use_reduce=False, + use_tdm_store=bool(use_tdm_store), + 
inst_prefetch=bool(inst_prefetch), + wave_specialized_tdm=bool(wave_specialized_tdm), + ) + status = "PASS" if check_ref else "OK" + except Exception as e: + status = "FAIL" + us2 = 0.0 + print(f" [{type(e).__name__}] tokens={tok}: {e}") + + flops = 2 * tok * topk * model_dim * inter_dim + tflops = flops / (us2 / 1e6) / 1e12 if us2 > 0 else 0 + bm = bench_bytes_moved_stage2(tok, topk, model_dim, inter_dim, experts, in_dtype) + tbps = bm / 1e12 / (us2 / 1e6) if us2 > 0 else 0 + util = (tflops / peak_tflops * 100) if (peak_tflops > 0 and tflops > 0) else 0 + + bench_print_stage_row(tok, m_eff, us2, tflops, tbps, util, status) + + # ── Stage 2 reduce ── + print(f"\n ── Stage 2 reduce (down: [{inter_dim}] -> [{model_dim}]) ──") + bench_print_stage_header() + + for tok, m_eff, _, _, _, s1_status, out1 in s1_results: + torch.cuda.empty_cache() + x, w1, w2, ids, wts, routing = setup_data_fn(tok, model_dim, inter_dim, experts, topk, tile_m) + a2_q, a2_scale = prepare_a2_fn(out1, tok, topk, inter_dim, in_dtype) + try: + _, us2r = stage2_fn( + tokens=tok, model_dim=model_dim, inter_dim=inter_dim, + experts=experts, topk=topk, in_dtype=in_dtype, out_dtype="f16", + tile_m=tile_m, tile_n=tile_n2, tile_k=tile_k2, + doweight_stage1=False, seed=0, + num_iters=iters, num_warmup=warmup, + x_fp32_in=x, w1_fp32_in=w1, w2_fp32_in=w2, + topk_ids_in=ids, topk_weights_in=wts, routing_in=routing, + a2_fp8_in=a2_q, a2_scale_in=a2_scale, + return_outputs=True, skip_ref=(not check_ref), + use_reduce=True, + use_tdm_store=bool(use_tdm_store), + inst_prefetch=bool(inst_prefetch), + wave_specialized_tdm=bool(wave_specialized_tdm), + ) + status = "PASS" if check_ref else "OK" + except Exception as e: + status = "FAIL" + us2r = 0.0 + print(f" [{type(e).__name__}] tokens={tok}: {e}") + + flops = 2 * tok * topk * model_dim * inter_dim + tflops = flops / (us2r / 1e6) / 1e12 if us2r > 0 else 0 + bm = bench_bytes_moved_stage2(tok, topk, model_dim, inter_dim, experts, in_dtype) + tbps = bm / 1e12 / (us2r / 1e6) if us2r > 0 else 0 + util = (tflops / peak_tflops * 100) if (peak_tflops > 0 and tflops > 0) else 0 + + bench_print_stage_row(tok, m_eff, us2r, tflops, tbps, util, status) + + del s1_results + torch.cuda.empty_cache() + + +def moe_bench_main( + args, + *, + stage1_fn: Callable, + stage2_fn: Callable, + setup_data_fn: Callable, + prepare_a2_fn: Callable, +) -> None: + """Entry point for ``--bench`` mode: sweep model configs × dtypes × token counts. + + Parameters + ---------- + args : argparse.Namespace + Parsed CLI args (must include the ``--bench-*`` group from ``add_moe_bench_args`` + and ``use_tdm_store``, ``inst_prefetch``, ``wave_specialized_tdm``). + stage1_fn, stage2_fn, setup_data_fn, prepare_a2_fn : + Kernel-specific callables (see ``moe_bench_config`` for signatures). 
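+
+    Example command line, assuming the test script registers these flags via
+    ``add_moe_bench_args``::
+
+        python test_moe_gemm_mxscale_gfx1250.py --bench \
+            --bench-dtype fp8 --bench-tokens 1,32,128 --bench-config DeepSeek-TP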
+ """ + import time + import torch + + os.environ["FLYDSL_RUNTIME_ENABLE_CACHE"] = "1" + + warmup = args.bench_warmup + iters = args.bench_iters + + dtypes = args.bench_dtype.split(",") if args.bench_dtype else list(BENCH_DTYPE_TARGET_TILES.keys()) + token_list = ( + [int(t) for t in args.bench_tokens.split(",")] + if args.bench_tokens + else BENCH_DEFAULT_TOKEN_SWEEP + ) + check_ref = not args.bench_no_ref + + print("=" * 110) + print(" AMD gfx1250 MOE GEMM Kernel Performance Benchmark") + print(f" PyTorch {torch.__version__}") + print(f" Device: {torch.cuda.get_device_name(0)}") + props = torch.cuda.get_device_properties(0) + print(f" CUs: {props.multi_processor_count}") + print(f" Memory: {props.total_memory / 1024**3:.0f} GB") + print(f" Warmup={warmup} Iters={iters} RefCheck={'ON' if check_ref else 'OFF'}") + print(f" Dtypes: {dtypes}") + print(f" Tokens: {token_list}") + print("=" * 110) + + t_start = time.time() + for cfg_name, mdim, idim, exp, topk in BENCH_MODEL_CONFIGS: + if args.bench_config and args.bench_config not in cfg_name: + continue + for dt in dtypes: + if dt not in BENCH_DTYPE_TARGET_TILES: + print(f"\n [SKIP] Unknown dtype: {dt}") + continue + try: + moe_bench_config( + cfg_name, mdim, idim, exp, topk, + dt, token_list, check_ref, args.bench_peak_tflops, + stage1_fn=stage1_fn, + stage2_fn=stage2_fn, + setup_data_fn=setup_data_fn, + prepare_a2_fn=prepare_a2_fn, + warmup=warmup, iters=iters, + use_tdm_store=bool(args.use_tdm_store), + inst_prefetch=bool(args.inst_prefetch), + wave_specialized_tdm=bool(args.wave_specialized_tdm), + ) + except Exception as e: + print(f"\n [ERROR] {cfg_name}/{dt}: {e}") + import traceback; traceback.print_exc() + + elapsed = time.time() - t_start + bench_print_banner(f"Done in {elapsed:.1f}s") + + def main() -> None: # CLI entrypoint: # BENCH_CONFIGS="M,N,dtype;..." AITER_IMPL=triton BENCH_WARMUP=10 BENCH_ITERS=50 python -m tests.kernels.benchmark_common diff --git a/tests/kernels/test_moe_gemm_mxscale_gfx1250.py b/tests/kernels/test_moe_gemm_mxscale_gfx1250.py new file mode 100644 index 00000000..a6c477d5 --- /dev/null +++ b/tests/kernels/test_moe_gemm_mxscale_gfx1250.py @@ -0,0 +1,1756 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""MoE GEMM tests for MXScale data types (fp4/fp8/a8w4) on gfx1250.""" +import math +import os +import sys +from typing import Optional + +import pytest +import torch +import torch.nn.functional as F + +# ----------------------------------------------------------------------------- +# Ensure we use the repo-local `flydsl` when running this file directly. +# +# Some environments have another `flydsl` (e.g. from a sibling checkout) earlier +# on `sys.path`, which can miss newer ROCDL wrappers (notably atomic fadd / MFMA). 
+# ----------------------------------------------------------------------------- +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +_PYTHON_CANDIDATES = [ + os.path.join(_REPO_ROOT, "build", "python_packages"), + _REPO_ROOT, +] +for _p in reversed(_PYTHON_CANDIDATES): + if os.path.isdir(_p) and _p not in sys.path: + sys.path.insert(0, _p) + +from tests.kernels.test_ref import torch_moe_gemm1, torch_moe_gemm2 +from tests.utils import get_dtype_max +from tests.test_common import verify_output, run_perftest +from flydsl.runtime.device import get_rocm_arch +from tests.kernels.utils import fp4_utils +from tests.kernels.benchmark_common import ( + bench_kernel_us as _bench_kernel_us, +) + +ARCH = get_rocm_arch() +# GFX950 (MI350) and newer typically use OCP standard float8_e4m3fn +# GFX940/941/942 (MI300) use float8_e4m3fnuz +if "gfx95" in ARCH: + DTYPE_FP8 = torch.float8_e4m3fn +else: + DTYPE_FP8 = torch.float8_e4m3fnuz + +SCALE_BLOCK = 32 + +if not torch.cuda.is_available(): + pytest.skip("CUDA/ROCm not available. Skipping GPU tests.", allow_module_level=True) + +if not str(ARCH).startswith("gfx1250"): + pytest.skip(f"MoE 2stage gfx1250 tests require gfx1250, got {ARCH}", allow_module_level=True) + + + +def _per_1x32_fp8_quant(x: torch.Tensor): + """Quantize fp32 tensor to raw FP8/E4M3 bytes with one E8M0 scale per 32-wide K block.""" + if x.shape[-1] % SCALE_BLOCK != 0: + raise ValueError(f"Last dim must be divisible by {SCALE_BLOCK}, got {x.shape[-1]}") + shape_original = x.shape + x2d = x.reshape(-1, shape_original[-1]).to(torch.float32) + m, n = x2d.shape + x_blk = x2d.view(-1, SCALE_BLOCK) + x_blk = torch.nan_to_num(x_blk, nan=0.0, posinf=0.0, neginf=0.0) + max_abs = torch.amax(torch.abs(x_blk), dim=1) + dtype_max = float(get_dtype_max(DTYPE_FP8)) + scale_e8m0 = fp4_utils.f32_to_e8m0(max_abs / dtype_max) + scale_f32 = fp4_utils.e8m0_to_f32(scale_e8m0) + scale_f32 = torch.nan_to_num(scale_f32, nan=1.0, posinf=1.0, neginf=1.0) + scale_f32[scale_f32 == 0] = 1.0 + y_f32 = x_blk / scale_f32.view(-1, 1) + # Clamp before casting to float8 to avoid generating NaN payloads. 
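+    # Worked example of the block-scale math above (illustrative numbers,
+    # assuming e4m3fn with dtype_max = 448): a block with max_abs = 7.0 gives
+    # 7 / 448 = 2**-6, so scale_e8m0 encodes exponent -6 and the rescaled
+    # block y_f32 = x / 2**-6 spans roughly [-448, 448] before this clamp.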
+ y_f32 = torch.clamp(y_f32, min=-dtype_max, max=dtype_max) + y = fp4_utils._f32_to_floatx_unpacked(y_f32.contiguous().view(-1), 4, 3).view(torch.uint8) + y = y.view(*shape_original) + scale = scale_e8m0.view(m, n // SCALE_BLOCK).view(torch.uint8) + return y, scale + + +def _dequant_blockscale_fp8(x_q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + if scale.dim() == x_q.dim() - 1: + scale = scale.view(*x_q.shape[:-1], scale.shape[-1]) + scale_f32 = fp4_utils.e8m0_to_f32(scale.view(torch.uint8)) + scale_expanded = scale_f32.repeat_interleave(SCALE_BLOCK, dim=-1)[..., : x_q.shape[-1]] + return fp4_utils.fp8_e4m3_to_f32(x_q.view(torch.uint8)) * scale_expanded + + +def _dequant_blockscale_fp4(x_q: torch.Tensor, scale: torch.Tensor, k_dim: int) -> torch.Tensor: + if scale.dim() == x_q.dim() - 1: + scale = scale.view(*x_q.shape[:-1], scale.shape[-1]) + scale_f32 = fp4_utils.e8m0_to_f32(scale.view(torch.uint8)) + scale_expanded = scale_f32.repeat_interleave(SCALE_BLOCK, dim=-1)[..., :k_dim] + return fp4_utils.mxfp4_to_f32(x_q.view(torch.uint8))[..., :k_dim] * scale_expanded + + +def _torch_moe_gemm1_a8w4( + x_fp8: torch.Tensor, + w1_fp4_flat: torch.Tensor, + scale_x: torch.Tensor, + scale_w1_flat: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + inter_dim: int, + doweight_stage1: bool, +) -> torch.Tensor: + topk = topk_ids.shape[1] + tokens, model_dim = x_fp8.shape + experts = int(w1_fp4_flat.shape[0] // (2 * inter_dim)) + x = _dequant_blockscale_fp8(x_fp8, scale_x) + w1 = _dequant_blockscale_fp4(w1_fp4_flat, scale_w1_flat, model_dim).view(experts, 2 * inter_dim, model_dim) + out = torch.zeros((tokens, topk, inter_dim), device=x.device, dtype=torch.float32) + for e in range(experts): + mask = topk_ids == e + idx = mask.nonzero(as_tuple=False) + if idx.numel() == 0: + continue + t_idx = idx[:, 0] + s_idx = idx[:, 1] + y2 = F.linear(x[t_idx, :], w1[e, :, :]) + gate = y2[:, :inter_dim] + up = y2[:, inter_dim:] + y = F.silu(gate) * up + if doweight_stage1: + y = y * topk_weights[t_idx, s_idx].unsqueeze(-1) + out[t_idx, s_idx, :] = y + return out + + +def _torch_moe_gemm2_a8w4( + a2_fp8: torch.Tensor, + w2_fp4: torch.Tensor, + scale_a2: torch.Tensor, + scale_w2: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + model_dim: int, + doweight_stage2: bool, +) -> torch.Tensor: + tokens, topk, inter_dim = a2_fp8.shape + experts = int(w2_fp4.shape[0]) if w2_fp4.dim() == 3 else int(w2_fp4.shape[0] // model_dim) + a2 = _dequant_blockscale_fp8(a2_fp8, scale_a2) + w2 = _dequant_blockscale_fp4(w2_fp4, scale_w2, inter_dim).view(experts, model_dim, inter_dim) + out = torch.zeros((tokens, model_dim), device=a2.device, dtype=torch.float32) + for e in range(experts): + mask = topk_ids == e + idx = mask.nonzero(as_tuple=False) + if idx.numel() == 0: + continue + t_idx = idx[:, 0] + s_idx = idx[:, 1] + y = F.linear(a2[t_idx, s_idx, :], w2[e, :, :]) + if doweight_stage2: + y = y * topk_weights[t_idx, s_idx].unsqueeze(-1) + out.index_add_(0, t_idx, y) + return out + + +# Reuse routing utilities from the base MoE GEMM test harness. +from tests.kernels.test_moe_gemm import ( + build_routing_buffers, + get_topk_valid_mask, + RoutingBuffers, +) + +# Kernel implementations live under `kernels/`; this test file is the harness. 
+from kernels.moe_gemm_2stage_mxscale_gfx1250 import ( + compile_moe_gemm1, + compile_moe_gemm2, + compile_moe_gemm2_ex, + MoeGemm2Mode, +) + + +# ---- Stage1/Stage2 runners (helpers; NOT pytest tests) ---- +def run_moe_stage1( + tokens: int, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n: int, + tile_k: int, + doweight_stage1: bool, + *, + in_dtype: str = "fp8", + seed: int = 0, + num_iters: int = 5, + num_warmup: int = 2, + moe_sort_mode: Optional[str] = None, + # Optional overrides (used by the 2-stage runner to avoid duplicated setup/sorting). + x_fp32_in: Optional[torch.Tensor] = None, + w1_fp32_in: Optional[torch.Tensor] = None, + topk_ids_in: Optional[torch.Tensor] = None, + topk_weights_in: Optional[torch.Tensor] = None, + routing_in: Optional[RoutingBuffers] = None, + return_outputs: bool = False, + skip_ref: bool = False, + test_graph: bool = False, + waves_per_eu: Optional[int] = None, + benchmark_mode: bool = False, + flush_l2: bool = True, + num_buffers: int = 1, + use_tdm_gather: bool = True, + use_tdm_store: bool = False, + inst_prefetch: bool = False, + wave_specialized_tdm: bool = False, + cluster_m: int = 1, + cluster_n: int = 1, + expert_sched_mode: bool = True, +): + assert model_dim % 64 == 0 + assert inter_dim % tile_n == 0 + + device = torch.device("cuda") + torch.manual_seed(int(seed)) + + x_fp32 = ( + x_fp32_in + if x_fp32_in is not None + else torch.randn((tokens, model_dim), device=device, dtype=torch.float32) + ) + w1_fp32 = ( + w1_fp32_in + if w1_fp32_in is not None + else torch.randn((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) + ) + + # Routing: aiter uses fused_topk; we use torch topk+softmax for portability/determinism. + if topk_ids_in is None or topk_weights_in is None: + score = torch.randn((tokens, experts), device=device, dtype=torch.float32) + topk_vals, topk_ids = torch.topk(score, k=topk, dim=1) + topk_weights = torch.softmax(topk_vals, dim=1).to(torch.float32) + else: + topk_ids = topk_ids_in + topk_weights = topk_weights_in + + routing = ( + routing_in + if routing_in is not None + else build_routing_buffers( + topk_ids=topk_ids, + topk_weights=topk_weights, + experts=experts, + model_dim=model_dim, + tile_m=tile_m, + moe_sort_mode=moe_sort_mode, + ) + ) + ( + sorted_token_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + sorted_size, + blocks, + ) = routing + + if in_dtype not in ("fp4", "fp8", "a8w4"): + raise ValueError(f"in_dtype must be one of ('fp4','fp8','a8w4'), got {in_dtype!r}") + is_fp4 = in_dtype == "fp4" + is_a8w4 = in_dtype == "a8w4" + + # Quantize inputs / weights (stage1 does not use W2). 
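+    # Layout reminder for the branches below: fp4 packs two elements per byte
+    # (last dim halved) with one E8M0 scale byte per 32 elements, fp8 keeps one
+    # byte per element with the same 1x32 scale granularity, and a8w4 pairs
+    # fp8 activations with fp4 weights.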
+ if in_dtype == "fp4": + x_fp4, scale_x_raw, _ = fp4_utils.per_1x32_f4_quant(x_fp32) + x_q = x_fp4.view(torch.uint8) + scale_x = scale_x_raw.view(torch.uint8).view(tokens, model_dim // 32) + w1_fp4, scale_w1_raw, _ = fp4_utils.per_1x32_f4_quant(w1_fp32.view(-1, model_dim)) + w1_q = w1_fp4.view(torch.uint8).view(experts, 2 * inter_dim, model_dim // 2) + scale_w1 = scale_w1_raw.view(torch.uint8).view(experts, 2 * inter_dim, model_dim // 32) + elif in_dtype == "fp8": + x_q, scale_x = _per_1x32_fp8_quant(x_fp32) + w1_q, scale_w1 = _per_1x32_fp8_quant(w1_fp32) + else: # a8w4 + x_q, scale_x = _per_1x32_fp8_quant(x_fp32) + w1_q, scale_w1, _ = fp4_utils.per_1x32_f4_quant(w1_fp32) + w1_q = w1_q.view(torch.uint8) + scale_w1 = scale_w1.view(torch.uint8) + + # --- K-dimension padding for non-aligned model_dim --- + _orig_model_dim = model_dim + if model_dim % tile_k != 0: + _pad_k = ((model_dim + tile_k - 1) // tile_k) * tile_k - model_dim + if is_fp4: + x_q = F.pad(x_q, (0, _pad_k // 2)) + else: + x_q = F.pad(x_q, (0, _pad_k)) + scale_x = F.pad(scale_x, (0, _pad_k // 32)) + if is_fp4 or is_a8w4: + w1_q = F.pad(w1_q, (0, _pad_k // 2)) + else: + w1_q = F.pad(w1_q, (0, _pad_k)) + if scale_w1 is not None: + scale_w1 = F.pad(scale_w1, (0, _pad_k // 32)) + model_dim = model_dim + _pad_k + + # Preshuffle weights — gfx1250 native kernels handle layout internally. + uses_fp4_weight_layout = is_fp4 or is_a8w4 + w1_shuffled = w1_q + if in_dtype in ("fp8", "a8w4"): + w1_rows = experts * (2 * inter_dim) + w1_cols = model_dim // 2 if is_a8w4 else model_dim + w1_shuffled = fp4_utils.preshuffle_b_16x16( + w1_q.contiguous().view(w1_rows, w1_cols), + w1_rows, + w1_cols, + ).view_as(w1_q) + + # Flatten W1 for our FlyDSL kernel (treat expert dim as part of N). + if uses_fp4_weight_layout: + w1_shuffled_flat = w1_shuffled.view(experts * (2 * inter_dim), model_dim // 2) + w1_q_flat = w1_q.view(experts * (2 * inter_dim), model_dim // 2) + scale_w1_ref_flat = None if scale_w1 is None else scale_w1.view(experts * (2 * inter_dim), model_dim // 32) + else: + w1_shuffled_flat = w1_shuffled.view(experts * (2 * inter_dim), model_dim) + w1_q_flat = w1_q.view(experts * (2 * inter_dim), model_dim) + scale_w1_ref_flat = None if scale_w1 is None else scale_w1.view(experts * (2 * inter_dim), model_dim // 32) + + x_q = ( + x_q.contiguous().view(tokens, model_dim // 2) + if is_fp4 + else x_q.contiguous().view(tokens, model_dim) + ) + w_kernel = w1_shuffled_flat.contiguous() + if uses_fp4_weight_layout: + w_kernel = w_kernel.view(experts * (2 * inter_dim), model_dim // 2) + else: + w_kernel = w_kernel.view(experts * (2 * inter_dim), model_dim) + + scale_x_1d = scale_x.view(-1).contiguous() + if scale_w1 is None: + scale_w1_1d = torch.empty((0,), device=device, dtype=torch.float32) + else: + scale_w1_1d = scale_w1.view(-1).contiguous() + sorted_weights_1d = sorted_weights.contiguous().view(-1) # [sorted_size] + + # Output: [tokens, topk, inter_dim] fp16 + out = torch.zeros((tokens, topk, inter_dim), device=device, dtype=torch.float16) + + exe = compile_moe_gemm1( + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + in_dtype=in_dtype, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage1=bool(doweight_stage1), + use_cshuffle_epilog=False, + waves_per_eu=waves_per_eu, + num_buffers=int(num_buffers), + use_tdm_gather=bool(use_tdm_gather), + use_tdm_store=bool(use_tdm_store), + inst_prefetch=bool(inst_prefetch), + wave_specialized_tdm=bool(wave_specialized_tdm), + cluster_m=int(cluster_m), + 
cluster_n=int(cluster_n), + expert_sched_mode=bool(expert_sched_mode), + ) + + def launch(o, x, w, sx, sw, st, eids, sw_sorted): + stream = torch.cuda.current_stream() + exe( + o, + x, + w, + sx, + sw, + st, + eids, + sw_sorted, + num_valid_ids, + tokens, + inter_dim, + model_dim, + int(blocks), + stream, + ) + + if bool(benchmark_mode) and not bool(test_graph): + def _prep_stage1(): + out.zero_() + + def _run_stage1(): + launch(out, x_q, w_kernel, scale_x_1d, scale_w1_1d, sorted_token_ids, sorted_expert_ids, sorted_weights_1d) + + us = _bench_kernel_us( + _run_stage1, + warmup=int(num_warmup), + iters=int(max(1, num_iters)), + flush_l2=bool(flush_l2), + prep_fn=_prep_stage1, + ) + torch.cuda.synchronize() + else: + _, us = run_perftest( + launch, + out, + x_q, + w_kernel, + scale_x_1d, + scale_w1_1d, + sorted_token_ids, + sorted_expert_ids, + sorted_weights_1d, + num_iters=int(num_iters), + num_warmup=int(num_warmup), + testGraph=test_graph, + ) + torch.cuda.synchronize() + + if not bool(skip_ref): + if in_dtype == "fp8": + x_ref = _dequant_blockscale_fp8(x_q.view(tokens, model_dim), scale_x.view(tokens, model_dim // 32)) + sx_ref = None + elif is_fp4: + x_ref = _dequant_blockscale_fp4(x_q.view(tokens, model_dim // 2), scale_x.view(tokens, model_dim // 32), model_dim) + sx_ref = None + else: # a8w4 + x_ref = x_q + sx_ref = scale_x + if is_a8w4: + ref = _torch_moe_gemm1_a8w4( + x_ref, + w1_q_flat, + sx_ref, + scale_w1_ref_flat, + topk_ids.to(torch.int64), + topk_weights, + inter_dim=inter_dim, + doweight_stage1=doweight_stage1, + ) + elif in_dtype == "fp8": + w_ref_f32 = _dequant_blockscale_fp8(w1_q_flat, scale_w1_ref_flat) + ref = torch_moe_gemm1( + x_ref, + w_ref_f32, + sx_ref, + None, + topk_ids.to(torch.int64), + topk_weights, + inter_dim=inter_dim, + doweight_stage1=doweight_stage1, + ) + else: # fp4 + w_ref_f32 = _dequant_blockscale_fp4(w1_q_flat, scale_w1_ref_flat, model_dim) + ref = torch_moe_gemm1( + x_ref, + w_ref_f32, + sx_ref, + None, + topk_ids.to(torch.int64), + topk_weights, + inter_dim=inter_dim, + doweight_stage1=doweight_stage1, + ) + + rtol = 0.5 if (is_a8w4 or is_fp4) else 0.25 + atol = 0.5 if is_a8w4 else 0.25 + assert verify_output(out.to(torch.float32), ref, rtol=rtol, atol=atol) + + # Note: kernel launches full expert-block range; effective work is gated by num_valid_ids. + flops = 2 * tokens * topk * (2 * inter_dim) * _orig_model_dim + tflops = flops / (us / 1e6) / 1e12 + + # Rough bytes-moved accounting (same spirit as GEMM tests: count each tensor once). 
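+    # For scale, with fp8 and the DeepSeek-TP shape (experts=257, inter_dim=256,
+    # model_dim=7168) the W1 term alone is 257 * 512 * 7168 bytes, roughly
+    # 0.94 GB, and is counted once per launch regardless of how many experts
+    # are actually routed.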
+ bytes_moved = 0 + bytes_x = tokens * _orig_model_dim # 1B elements (fp4/fp8/a8w4) + bytes_w = (experts * (2 * inter_dim) * _orig_model_dim) // (2 if is_a8w4 else 1) + bytes_out = tokens * topk * inter_dim * 2 + bytes_scale_x = tokens * 4 + bytes_scale_w = experts * (2 * inter_dim) * 4 + bytes_route = ( + int(sorted_weights.numel()) * 4 + + int(sorted_token_ids.numel()) * 4 + + int(sorted_expert_ids.numel()) * 4 + ) + bytes_moved += bytes_x + bytes_w + bytes_out + bytes_scale_x + bytes_scale_w + bytes_route + tbps = bytes_moved / 1e12 / (us / 1e6) + read_bytes = bytes_x + bytes_w + bytes_scale_x + bytes_scale_w + bytes_route + write_bytes = bytes_out + read_bw_gbs = read_bytes / 1e9 / (us / 1e6) + write_bw_gbs = write_bytes / 1e9 / (us / 1e6) + + if bool(benchmark_mode): + print( + f"FlyDSL MoE stage1[{in_dtype}] benchmark | " + f"shape=({tokens},{model_dim},{inter_dim}), E={experts}, K={topk}, " + f"tile=({tile_m},{tile_n},{tile_k})" + ) + print( + f" kernel: {us:.1f} us ({us / 1e3:.4f} ms) | " + f"{tflops:.2f} TFLOPS(logical, M={tokens*topk}) | {tbps:.3f} TB/s" + ) + print( + f" bandwidth: read {read_bw_gbs:.1f} GB/s + write {write_bw_gbs:.1f} GB/s | " + f"bytes: x={bytes_x/1e6:.1f}MB w={bytes_w/1e6:.1f}MB " + f"sx={bytes_scale_x/1e6:.1f}MB sw={bytes_scale_w/1e6:.1f}MB " + f"route={bytes_route/1e6:.1f}MB out={bytes_out/1e6:.1f}MB" + ) + else: + print( + f"FlyDSL MoE stage1[{in_dtype}]: " + f"{us:.1f} us, " + f"{tflops:.2f} TFLOPS(logical, M={tokens*topk}), " + f"{tbps:.3f} TB/s (doweight_stage1={doweight_stage1})" + ) + if return_outputs: + return out, us + return None + + +def run_moe_stage2( + tokens: int, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n: int, + tile_k: int, + doweight_stage1: bool, + *, + in_dtype: str = "fp8", + out_dtype: str = "f16", + seed: int = 0, + num_iters: int = 5, + num_warmup: int = 2, + moe_sort_mode: Optional[str] = None, + x_fp32_in: Optional[torch.Tensor] = None, + w1_fp32_in: Optional[torch.Tensor] = None, + w2_fp32_in: Optional[torch.Tensor] = None, + topk_ids_in: Optional[torch.Tensor] = None, + topk_weights_in: Optional[torch.Tensor] = None, + routing_in: Optional[RoutingBuffers] = None, + a2_fp8_in: Optional[torch.Tensor] = None, + a2_scale_in: Optional[torch.Tensor] = None, + return_outputs: bool = False, + skip_ref: bool = False, + init_scale: float = 0.2, + compile_fn=None, + kernel_name: str = "moe_gemm2", + use_reduce: bool = False, + use_valid_mask: bool = False, + test_graph: bool = False, + waves_per_eu: Optional[int] = None, + benchmark_mode: bool = False, + flush_l2: bool = True, + num_buffers: int = 1, + use_tdm_gather: bool = True, + use_tdm_store: bool = False, + inst_prefetch: bool = False, + wave_specialized_tdm: bool = False, + cluster_m: int = 1, + cluster_n: int = 1, + expert_sched_mode: bool = True, +): + """MoE stage2 (gemm2): out2[t] = sum_{slot} ( out1[t,slot] @ W2[expert]^T ) with optional routed weight.""" + + # Parameter sanity checks with actionable hints (avoid bare AssertionError). + if model_dim % tile_n != 0: + raise ValueError( + f"Invalid stage2 tiling: model_dim ({model_dim}) must be divisible by tile_n2 ({tile_n})." + ) + # Enforce the kernel's stage2 gmem->reg load mapping constraints. + # See: kernels/moe_gemm_2stage.py::compile_moe_gemm2 (x_load_bytes selection). + if (tile_m * tile_k) % 256 != 0: + raise ValueError( + f"Invalid stage2 tiling: tile_m*tile_k2 must be divisible by 256 (total_threads=256). 
" + f"Got tile_m={tile_m}, tile_k2={tile_k} -> tile_m*tile_k2={tile_m * tile_k}." + ) + bytes_per_thread_x = (tile_m * tile_k) // 256 # 1B elements + if bytes_per_thread_x % 4 != 0: + raise ValueError( + f"Invalid stage2 tiling for gmem loads: bytes_per_thread_x ((tile_m*tile_k2)/256) must be divisible by 4. " + f"Got tile_m={tile_m}, tile_k2={tile_k} -> bytes_per_thread_x={bytes_per_thread_x}. " + ) + + # Default compile function. + if compile_fn is None: + if use_reduce: + compile_fn = _make_reduce_mode_compile_fn(use_flydsl_reduce=True, use_valid_mask=bool(use_valid_mask)) + else: + compile_fn = compile_moe_gemm2 + + device = torch.device("cuda") + torch.manual_seed(int(seed)) + + s = float(init_scale) + + # Data: input and weights (aiter shapes) + x_fp32 = ( + x_fp32_in + if x_fp32_in is not None + else torch.rand((tokens, model_dim), device=device, dtype=torch.float32) * s + ) + w1_fp32 = ( + w1_fp32_in + if w1_fp32_in is not None + else torch.rand((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) * (s / math.sqrt(model_dim)) + ) + w2_fp32 = ( + w2_fp32_in + if w2_fp32_in is not None + else torch.rand((experts, model_dim, inter_dim), device=device, dtype=torch.float32) * (s / math.sqrt(inter_dim)) + ) + + # Routing: deterministic torch topk + softmax. + if topk_ids_in is None or topk_weights_in is None: + score = torch.rand((tokens, experts), device=device, dtype=torch.float32) + topk_vals, topk_ids = torch.topk(score, k=topk, dim=1) + topk_weights = torch.softmax(topk_vals, dim=1).to(torch.float32) + else: + topk_ids = topk_ids_in + topk_weights = topk_weights_in + + routing = ( + routing_in + if routing_in is not None + else build_routing_buffers( + topk_ids=topk_ids, + topk_weights=topk_weights, + experts=experts, + model_dim=model_dim, + tile_m=tile_m, + moe_sort_mode=moe_sort_mode, + ) + ) + ( + sorted_token_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + sorted_size, + blocks, + ) = routing + # NOTE: routing uses `moe_sorting` output directly (no host trim/pad). Extra launched blocks + # are gated by `num_valid_ids` inside the kernels. + + if in_dtype not in ("fp4", "fp8", "a8w4"): + raise ValueError(f"in_dtype must be one of ('fp4','fp8','a8w4'), got {in_dtype!r}") + is_fp4 = in_dtype == "fp4" + is_a8w4 = in_dtype == "a8w4" + + # Quantize inputs / weights. 
+ if in_dtype == "fp4": + x_fp4, scale_x_raw, _ = fp4_utils.per_1x32_f4_quant(x_fp32) + x_q = x_fp4.view(torch.uint8) + scale_x = scale_x_raw.view(torch.uint8).view(tokens, model_dim // 32) + w1_fp4, scale_w1_raw, _ = fp4_utils.per_1x32_f4_quant(w1_fp32.view(-1, model_dim)) + w1_q = w1_fp4.view(torch.uint8).view(experts, 2 * inter_dim, model_dim // 2) + scale_w1 = scale_w1_raw.view(torch.uint8).view(experts, 2 * inter_dim, model_dim // 32) + w2_fp4, scale_w2_raw, _ = fp4_utils.per_1x32_f4_quant(w2_fp32.view(-1, inter_dim)) + w2_q = w2_fp4.view(torch.uint8).view(experts, model_dim, inter_dim // 2) + scale_w2 = scale_w2_raw.view(torch.uint8).view(experts, model_dim, inter_dim // 32) + elif in_dtype == "fp8": + x_q, scale_x = _per_1x32_fp8_quant(x_fp32) + w1_q, scale_w1 = _per_1x32_fp8_quant(w1_fp32) + w2_q, scale_w2 = _per_1x32_fp8_quant(w2_fp32) + else: # a8w4 + x_q, scale_x = _per_1x32_fp8_quant(x_fp32) + w1_q, scale_w1, _ = fp4_utils.per_1x32_f4_quant(w1_fp32) + w2_q, scale_w2, _ = fp4_utils.per_1x32_f4_quant(w2_fp32) + w1_q = w1_q.view(torch.uint8) + scale_w1 = scale_w1.view(torch.uint8) + w2_q = w2_q.view(torch.uint8) + scale_w2 = scale_w2.view(torch.uint8) + + # --- K-dimension padding for non-aligned inter_dim (stage2 K=inter_dim) --- + _orig_inter_dim = inter_dim + if inter_dim % tile_k != 0: + _pad_k2 = ((inter_dim + tile_k - 1) // tile_k) * tile_k - inter_dim + if is_fp4 or is_a8w4: + w2_q = F.pad(w2_q, (0, _pad_k2 // 2)) + else: + w2_q = F.pad(w2_q, (0, _pad_k2)) + if scale_w2 is not None: + scale_w2 = F.pad(scale_w2, (0, _pad_k2 // 32)) + inter_dim = inter_dim + _pad_k2 + + # Preshuffle W2 — gfx1250 native kernels handle layout internally. + uses_fp4_weight_layout = is_fp4 or is_a8w4 + w2_shuffled = w2_q + if in_dtype in ("fp8", "a8w4"): + w2_rows = experts * model_dim + w2_cols = inter_dim // 2 if is_a8w4 else inter_dim + w2_shuffled = fp4_utils.preshuffle_b_16x16( + w2_q.contiguous().view(w2_rows, w2_cols), + w2_rows, + w2_cols, + ).view_as(w2_q) + + # Stage2 input (A2): either provided (gemm1->quantize chaining) or built from stage1 reference. + if a2_fp8_in is not None and a2_scale_in is not None: + a2_q = a2_fp8_in + a2_scale = a2_scale_in + else: + if is_fp4 or is_a8w4: + w1_q_flat = w1_q.view(experts * (2 * inter_dim), model_dim // 2) + scale_w1_flat = None if scale_w1 is None else scale_w1.view(experts * (2 * inter_dim), model_dim // 32) + else: # fp8 + w1_q_flat = w1_q.view(experts * (2 * inter_dim), model_dim) + scale_w1_flat = None if scale_w1 is None else scale_w1.view(experts * (2 * inter_dim), model_dim // 32) + if bool(skip_ref): + raise RuntimeError( + "run_moe_stage2(skip_ref=True) requires providing a2_fp8_in and a2_scale_in " + "(so we don't have to run the huge torch reference stage1)." 
+ ) + if in_dtype == "fp8": + x_dequant = _dequant_blockscale_fp8(x_q.view(-1, model_dim), scale_x.reshape(-1, model_dim // 32)) + w1_dequant = _dequant_blockscale_fp8(w1_q_flat, scale_w1_flat) + out1_ref = torch_moe_gemm1( + x_dequant, w1_dequant, None, None, + topk_ids.to(torch.int64), topk_weights, + inter_dim=inter_dim, doweight_stage1=bool(doweight_stage1), + ) + else: + out1_ref = torch_moe_gemm1( + x_q, + w1_q_flat, + scale_x, + scale_w1_flat, + topk_ids.to(torch.int64), + topk_weights, + inter_dim=inter_dim, + doweight_stage1=bool(doweight_stage1), + ) + if in_dtype in ("fp8", "a8w4"): + a2_q, a2_scale = _per_1x32_fp8_quant(out1_ref) + else: # fp4 + a2_q = fp4_utils.random_fp4_packed(tokens * topk, inter_dim, device=device) + a2_scale = fp4_utils.random_e8m0(tokens * topk, inter_dim // 32, device=device) + + # Pad A2 activation for non-aligned inter_dim (stage2 K-padding). + if _orig_inter_dim != inter_dim: + _pad_k2 = inter_dim - _orig_inter_dim + if is_fp4: + a2_q = F.pad(a2_q, (0, _pad_k2 // 2)) + else: + a2_q = F.pad(a2_q, (0, _pad_k2)) + if a2_scale is not None: + a2_scale = F.pad(a2_scale, (0, _pad_k2 // 32)) + + # Flatten weights/scales for the kernel. + if uses_fp4_weight_layout: + w2_shuffled_flat = w2_shuffled.view(experts * model_dim, inter_dim // 2) + scale_w2_ref_flat = None if scale_w2 is None else scale_w2.view(experts * model_dim, inter_dim // 32) + else: + w2_shuffled_flat = w2_shuffled.view(experts * model_dim, inter_dim) + scale_w2_ref_flat = None if scale_w2 is None else scale_w2.view(experts * model_dim, inter_dim // 32) + + w2_flat = w2_shuffled_flat.contiguous().view(-1) + w2_kernel = w2_flat + if uses_fp4_weight_layout: + w2_kernel = w2_kernel.view(experts * model_dim, inter_dim // 2) + else: + w2_kernel = w2_kernel.view(experts * model_dim, inter_dim) + + a2_scale_1d = a2_scale.view(-1).contiguous() + if scale_w2 is None: + w2_scale_1d = torch.empty((0,), device=device, dtype=torch.float32) + else: + w2_scale_1d = scale_w2.view(-1).contiguous() + sorted_weights_1d = sorted_weights.contiguous().view(-1) # [sorted_size] + + out_s = str(out_dtype).strip().lower() + if out_s in ("f16", "fp16", "half"): + out_torch_dtype = torch.float16 + elif out_s in ("f32", "fp32", "float"): + out_torch_dtype = torch.float32 + else: + raise ValueError(f"out_dtype must be 'f16' or 'f32', got {out_dtype!r}") + + out = torch.zeros((tokens, model_dim), device=device, dtype=out_torch_dtype) + out_perf = torch.zeros_like(out) + + doweight_stage2 = not bool(doweight_stage1) + compile_kwargs = dict( + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + in_dtype=in_dtype, + out_dtype=out_dtype, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage2=bool(doweight_stage2), + num_buffers=int(num_buffers), + use_tdm_gather=bool(use_tdm_gather), + use_tdm_store=bool(use_tdm_store), + inst_prefetch=bool(inst_prefetch), + wave_specialized_tdm=bool(wave_specialized_tdm), + cluster_m=int(cluster_m), + cluster_n=int(cluster_n), + expert_sched_mode=bool(expert_sched_mode), + ) + if waves_per_eu is not None: + compile_kwargs["waves_per_eu"] = waves_per_eu + try: + exe = compile_fn(**compile_kwargs) + except TypeError: + # Some wrapper compile_fns (e.g. local reduce wrappers) may not expose + # waves_per_eu; retry without it to keep compatibility. 
+ compile_kwargs.pop("waves_per_eu", None) + exe = compile_fn(**compile_kwargs) + is_reduce_exe = (getattr(exe, "mode", None) == MoeGemm2Mode.REDUCE) or bool(use_reduce) + + def launch(o, x, w, sx, sw, st, eids, sw_sorted): + stream = torch.cuda.current_stream() + valid_mask = None + if is_reduce_exe and bool(use_valid_mask): + # Default: non-EP (all ones). EP mode can be emulated by passing expert_mask. + valid_mask = get_topk_valid_mask(topk_ids, expert_mask=None).contiguous() + if is_reduce_exe: + exe( + o, + x, + w, + sx, + sw, + st, + eids, + sw_sorted, + num_valid_ids, + tokens, + model_dim, + inter_dim, + int(blocks), + valid_mask, + stream, + ) + else: + # Atomic mode does not take valid_mask. + exe( + o, + x, + w, + sx, + sw, + st, + eids, + sw_sorted, + num_valid_ids, + tokens, + model_dim, + inter_dim, + int(blocks), + stream, + ) + + if bool(benchmark_mode) and not bool(test_graph): + def _prep_stage2(): + out_perf.zero_() + + def _run_stage2(): + launch( + out_perf, + a2_q.view(-1), + w2_kernel.view(-1), + a2_scale_1d, + w2_scale_1d, + sorted_token_ids, + sorted_expert_ids, + sorted_weights_1d, + ) + + us = _bench_kernel_us( + _run_stage2, + warmup=int(num_warmup), + iters=int(max(1, num_iters)), + flush_l2=bool(flush_l2), + prep_fn=_prep_stage2, + ) + torch.cuda.synchronize() + else: + _, us = run_perftest( + launch, + out_perf, + a2_q.view(-1), + w2_kernel.view(-1), + a2_scale_1d, + w2_scale_1d, + sorted_token_ids, + sorted_expert_ids, + sorted_weights_1d, + num_iters=int(num_iters), + num_warmup=int(num_warmup), + testGraph=test_graph, + ) + torch.cuda.synchronize() + + # Correctness run (single launch into a clean zeroed output). + out.zero_() + launch( + out, + a2_q.view(-1), + w2_kernel.view(-1), + a2_scale_1d, + w2_scale_1d, + sorted_token_ids, + sorted_expert_ids, + sorted_weights_1d, + ) + torch.cuda.synchronize() + + if not bool(skip_ref): + if is_a8w4: + ref2 = _torch_moe_gemm2_a8w4( + a2_q, + w2_q, + a2_scale, + scale_w2, + topk_ids.to(torch.int64), + topk_weights, + model_dim=model_dim, + doweight_stage2=doweight_stage2, + ) + elif in_dtype == "fp8": + a2_dequant = _dequant_blockscale_fp8(a2_q.view(-1, inter_dim), a2_scale.reshape(-1, inter_dim // 32)) + a2_dequant = a2_dequant.view(a2_q.shape[0], a2_q.shape[1], inter_dim) + w2_dequant = _dequant_blockscale_fp8( + w2_q.view(experts * model_dim, inter_dim), + scale_w2.view(experts * model_dim, inter_dim // 32), + ) + ref2 = torch_moe_gemm2( + a2_dequant, w2_dequant.view_as(w2_q), None, None, + topk_ids.to(torch.int64), topk_weights, + model_dim=model_dim, doweight_stage2=doweight_stage2, + ) + else: # fp4 + a2_dequant = _dequant_blockscale_fp4( + a2_q.view(-1, inter_dim // 2), a2_scale.reshape(-1, inter_dim // 32), inter_dim + ) + a2_dequant = a2_dequant.view(a2_q.shape[0], a2_q.shape[1], inter_dim) + w2_dequant = _dequant_blockscale_fp4( + w2_q.view(experts * model_dim, inter_dim // 2), + scale_w2.view(experts * model_dim, inter_dim // 32), + inter_dim, + ) + ref2 = torch_moe_gemm2( + a2_dequant, w2_dequant.view(experts, model_dim, inter_dim), None, None, + topk_ids.to(torch.int64), topk_weights, + model_dim=model_dim, doweight_stage2=doweight_stage2, + ) + assert verify_output(out.to(torch.float32), ref2, rtol=0.5, atol=0.5) + + # Launches full expert-block range; effective work is gated by num_valid_ids. 
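+    # For scale, the small smoke shape (tokens=64, topk=2, model_dim=256,
+    # inter_dim=128) gives flops = 2 * 64 * 2 * 256 * 128, roughly 8.4 MFLOP,
+    # so microsecond-level timing noise dominates the reported TFLOPS there.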
+ flops = 2 * tokens * topk * model_dim * _orig_inter_dim + tflops = flops / (us / 1e6) / 1e12 + + bytes_moved = 0 + bytes_a2 = tokens * topk * _orig_inter_dim # 1B elements (fp4/fp8/a8w4) + bytes_w2 = (experts * model_dim * _orig_inter_dim) // (2 if is_a8w4 else 1) + bytes_out = tokens * model_dim * (2 if out_torch_dtype == torch.float16 else 4) + bytes_scale_a2 = tokens * topk * 4 + bytes_scale_w2 = experts * model_dim * 4 + bytes_route = ( + int(sorted_weights.numel()) * 4 + + int(sorted_token_ids.numel()) * 4 + + int(sorted_expert_ids.numel()) * 4 + ) + bytes_moved += bytes_a2 + bytes_w2 + bytes_out + bytes_scale_a2 + bytes_scale_w2 + bytes_route + tbps = bytes_moved / 1e12 / (us / 1e6) + read_bytes = bytes_a2 + bytes_w2 + bytes_scale_a2 + bytes_scale_w2 + bytes_route + write_bytes = bytes_out + read_bw_gbs = read_bytes / 1e9 / (us / 1e6) + write_bw_gbs = write_bytes / 1e9 / (us / 1e6) + if bool(benchmark_mode): + print( + f"FlyDSL MoE stage2[{kernel_name}] {in_dtype} {'reduce' if use_reduce else 'atomic'} benchmark | " + f"shape=({tokens},{model_dim},{inter_dim}), E={experts}, K={topk}, " + f"tile=({tile_m},{tile_n},{tile_k})" + ) + print( + f" kernel: {us:.1f} us ({us / 1e3:.4f} ms) | " + f"{tflops:.2f} TFLOPS(logical, M={tokens*topk}) | {tbps:.3f} TB/s" + ) + print( + f" bandwidth: read {read_bw_gbs:.1f} GB/s + write {write_bw_gbs:.1f} GB/s | " + f"bytes: a2={bytes_a2/1e6:.1f}MB w2={bytes_w2/1e6:.1f}MB " + f"sa2={bytes_scale_a2/1e6:.1f}MB sw2={bytes_scale_w2/1e6:.1f}MB " + f"route={bytes_route/1e6:.1f}MB out={bytes_out/1e6:.1f}MB" + ) + else: + print( + f"FlyDSL MoE stage2 [{kernel_name}] {in_dtype} {'reduce' if use_reduce else 'atomic'} | " + f"{model_dim}x{inter_dim}, E={experts}, K={topk}, M_eff={tokens*topk} | " + f"{us:.1f} us, {tflops:.2f} TFLOPS, {tbps:.3f} TB/s" + ) + # Print profile breakdown if the executor supports it + if hasattr(exe, 'print_profile_stats'): + exe.print_profile_stats() + + if return_outputs: + return out, us + return None + + +def run_moe_gemm_2stage( + tokens: int, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n1: int, + tile_k1: int, + tile_n2: int, + tile_k2: int, + doweight_stage1: bool, + in_dtype: str, + out_dtype: str, + use_reduce: bool, + use_valid_mask: bool, + test_graph: bool, + *, + seed: int = 0, + num_iters: int = 5, + num_warmup: int = 2, + moe_sort_mode: Optional[str] = None, + init_scale: float = 1.0, + skip_ref: bool = False, + w_fp4_kernel: bool = False, + benchmark_mode: bool = False, + flush_l2: bool = True, + num_buffers: int = 1, + use_tdm_store: bool = False, + inst_prefetch: bool = False, + wave_specialized_tdm: bool = False, + cluster_m: int = 1, + cluster_n: int = 1, +): + """Single 2-stage test: gemm1 -> quantize -> gemm2, with routing built once.""" + if (not bool(use_reduce)) and bool(use_valid_mask): + pytest.skip("valid_mask is only used in reduce mode (atomic mode ignores it).") + if out_dtype in ("f32", "fp32", "float"): + pytest.skip(f"gfx1250 {in_dtype} kernels only support out_dtype f16/bf16, not f32.") + if in_dtype in ("fp4", "a8w4") and os.environ.get("FLYDSL_SKIP_SHAPE_GUARD", "0") != "1": + is_small_shape = ( + tokens == 64 + and model_dim == 256 + and inter_dim == 128 + and experts == 4 + and topk == 2 + ) + if not is_small_shape: + pytest.skip(f"{in_dtype} in main matrix is enabled only for the small shape.") + out_s = str(out_dtype).strip().lower() + if bool(use_reduce) and out_s in ("f32", "fp32", "float"): + pytest.skip("reduce mode does not support out_dtype='f32' 
(compile_moe_gemm2(accumulate=False) forbids it).") + device = torch.device("cuda") + + if init_scale == 1.0: + init_scale = 0.2 + s = float(init_scale) + x_fp32 = torch.randn((tokens, model_dim), device=device, dtype=torch.float32) * s + w1_fp32 = torch.randn((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) * s + w2_fp32 = torch.randn((experts, model_dim, inter_dim), device=device, dtype=torch.float32) * (s / math.sqrt(inter_dim)) + + score = torch.rand((tokens, experts), device=device, dtype=torch.float32) + topk_vals, topk_ids = torch.topk(score, k=topk, dim=1) + topk_weights = torch.softmax(topk_vals, dim=1).to(torch.float32) + + if moe_sort_mode is None: + moe_sort_mode = "torch" + + routing = build_routing_buffers( + topk_ids=topk_ids, + topk_weights=topk_weights, + experts=experts, + model_dim=model_dim, + tile_m=tile_m, + moe_sort_mode=moe_sort_mode, + ) + + out1_fp16, _us1 = run_moe_stage1( + tokens=tokens, + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + in_dtype=in_dtype, + tile_m=tile_m, + tile_n=tile_n1, + tile_k=tile_k1, + doweight_stage1=bool(doweight_stage1), + seed=seed, + num_iters=num_iters, + num_warmup=num_warmup, + moe_sort_mode=moe_sort_mode, + x_fp32_in=x_fp32, + w1_fp32_in=w1_fp32, + topk_ids_in=topk_ids, + topk_weights_in=topk_weights, + routing_in=routing, + return_outputs=True, + skip_ref=bool(skip_ref), + test_graph=test_graph, + benchmark_mode=bool(benchmark_mode), + flush_l2=bool(flush_l2), + num_buffers=int(num_buffers), + use_tdm_store=bool(use_tdm_store), + inst_prefetch=bool(inst_prefetch), + wave_specialized_tdm=bool(wave_specialized_tdm), + cluster_m=int(cluster_m), + cluster_n=int(cluster_n), + ) + + a2_q, a2_scale = _prepare_a2_from_stage1( + out1_fp16, in_dtype, tokens, topk, inter_dim, + w_fp4_kernel=w_fp4_kernel, skip_ref=bool(skip_ref), + topk_ids=topk_ids, experts=experts, device=device, + ) + + _out2_fp32, _us2 = run_moe_stage2( + tokens=tokens, + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + in_dtype=in_dtype, + out_dtype=out_dtype, + tile_m=tile_m, + tile_n=tile_n2, + tile_k=tile_k2, + doweight_stage1=bool(doweight_stage1), + seed=seed, + num_iters=num_iters, + num_warmup=num_warmup, + moe_sort_mode=moe_sort_mode, + x_fp32_in=x_fp32, + w1_fp32_in=w1_fp32, + w2_fp32_in=w2_fp32, + topk_ids_in=topk_ids, + topk_weights_in=topk_weights, + routing_in=routing, + a2_fp8_in=a2_q, + a2_scale_in=a2_scale, + return_outputs=True, + skip_ref=bool(skip_ref), + use_reduce=bool(use_reduce), + use_valid_mask=use_valid_mask, + test_graph=test_graph, + benchmark_mode=bool(benchmark_mode), + flush_l2=bool(flush_l2), + num_buffers=int(num_buffers), + use_tdm_store=bool(use_tdm_store), + inst_prefetch=bool(inst_prefetch), + wave_specialized_tdm=bool(wave_specialized_tdm), + cluster_m=int(cluster_m), + cluster_n=int(cluster_n), + ) + + +# Test Helpers for MoE GEMM2 Mode Comparison +def _make_reduce_mode_compile_fn(use_flydsl_reduce: bool = True, use_valid_mask: bool = False): + """Create a compile function that forces reduce mode. + + Args: + use_flydsl_reduce: If True, use FlyDSL reduce kernel. + If False, use torch.sum (for baseline comparison). 
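+
+    Usage sketch: ``run_moe_stage2`` builds this compile function itself when
+    ``use_reduce=True`` and no ``compile_fn`` is given, but it can also be
+    passed explicitly::
+
+        compile_fn = _make_reduce_mode_compile_fn(use_flydsl_reduce=True)
+        run_moe_stage2(..., use_reduce=True, compile_fn=compile_fn)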
+ """ + def _compile( + *, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n: int, + tile_k: int, + doweight_stage2: bool, + in_dtype: str = "fp8", + out_dtype: str = "f16", + waves_per_eu: Optional[int] = None, + expert_sched_mode: bool = True, + num_buffers: int = 1, + use_tdm_gather: bool = True, + use_tdm_store: bool = False, + inst_prefetch: bool = False, + wave_specialized_tdm: bool = False, + cluster_m: int = 1, + cluster_n: int = 1, + ): + if use_flydsl_reduce: + return compile_moe_gemm2_ex( + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage2=doweight_stage2, + in_dtype=in_dtype, + out_dtype=out_dtype, + waves_per_eu=waves_per_eu, + valid_mask=(True if bool(use_valid_mask) else None), + mode=MoeGemm2Mode.REDUCE, + zero_intermediate=False, # test non-zeroed performance + expert_sched_mode=bool(expert_sched_mode), + num_buffers=int(num_buffers), + use_tdm_gather=bool(use_tdm_gather), + use_tdm_store=bool(use_tdm_store), + inst_prefetch=bool(inst_prefetch), + wave_specialized_tdm=bool(wave_specialized_tdm), + cluster_m=int(cluster_m), + cluster_n=int(cluster_n), + ) + else: + gemm2_exe = compile_moe_gemm2( + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage2=doweight_stage2, + in_dtype=in_dtype, + out_dtype=out_dtype, + accumulate=False, + waves_per_eu=waves_per_eu, + expert_sched_mode=bool(expert_sched_mode), + num_buffers=int(num_buffers), + use_tdm_gather=bool(use_tdm_gather), + use_tdm_store=bool(use_tdm_store), + inst_prefetch=bool(inst_prefetch), + wave_specialized_tdm=bool(wave_specialized_tdm), + cluster_m=int(cluster_m), + cluster_n=int(cluster_n), + ) + return _TorchReduceWrapper(gemm2_exe, topk, model_dim) + return _compile + + +class _TorchReduceWrapper: + """Wrapper for GEMM2 (accumulate=False) with torch.sum reduction. + + For baseline comparison only. Production code should use compile_moe_gemm2_ex. 
+ """ + + def __init__(self, gemm2_exe, topk: int, model_dim: int): + self._exe = gemm2_exe + self._topk = topk + self._model_dim = model_dim + self._intermediate = None + self._mode = MoeGemm2Mode.REDUCE + + def __call__( + self, + arg_out, + arg_x, + arg_w, + arg_scale_x, + arg_scale_w, + arg_sorted_token_ids, + arg_expert_ids, + arg_sorted_weights, + arg_num_valid_ids, + tokens_in, + n_in, + k_in, + size_expert_ids_in, + valid_mask, + stream, + ): + # Lazy allocate intermediate buffer + needed = tokens_in * self._topk * self._model_dim + if self._intermediate is None or self._intermediate.numel() < needed: + self._intermediate = torch.empty( + tokens_in * self._topk, self._model_dim, + device=arg_out.device, dtype=arg_out.dtype + ) + + intermediate = self._intermediate[:tokens_in * self._topk, :] + self._exe( + intermediate.view(-1), + arg_x, arg_w, arg_scale_x, arg_scale_w, + arg_sorted_token_ids, arg_expert_ids, arg_sorted_weights, + arg_num_valid_ids, tokens_in, n_in, k_in, size_expert_ids_in, + stream, + ) + X = intermediate.view(tokens_in, self._topk, self._model_dim) + if valid_mask is not None: + X = X * valid_mask.view(tokens_in, self._topk, 1).to(dtype=X.dtype) + torch.sum(X, dim=1, out=arg_out) + + @property + def mode(self) -> str: + return self._mode + + + +def _prepare_a2_from_stage1(out1_fp16, in_dtype, tokens, topk, inter_dim, + *, w_fp4_kernel=False, skip_ref=False, + topk_ids=None, experts=None, device=None): + """Convert stage1 fp16 output to appropriate stage2 activation input.""" + if w_fp4_kernel: + if in_dtype == "fp4": + dev = device or out1_fp16.device + if skip_ref: + a2_q = fp4_utils.random_fp4_packed(tokens * topk, inter_dim, device=dev).view( + tokens, topk, inter_dim // 2) + a2_scale = fp4_utils.random_e8m0(tokens * topk, inter_dim // 32, device=dev).view( + tokens, topk, inter_dim // 32) + else: + f32 = out1_fp16.to(torch.float32) + a2_fp4, a2_scale_raw, _ = fp4_utils.per_1x32_f4_quant(f32.view(-1, inter_dim)) + a2_q = a2_fp4.view(torch.uint8).view(tokens, topk, inter_dim // 2) + a2_scale = a2_scale_raw.view(torch.uint8).view(tokens, topk, inter_dim // 32) + else: + a2_q = out1_fp16.to(torch.float32) + a2_scale = None + return a2_q, a2_scale + + if in_dtype == "fp4": + f32 = out1_fp16.to(torch.float32) + a2_fp4, a2_scale_raw, _ = fp4_utils.per_1x32_f4_quant(f32.view(-1, inter_dim)) + a2_q = a2_fp4.view(torch.uint8).view(tokens, topk, inter_dim // 2) + a2_scale = a2_scale_raw.view(torch.uint8).view(tokens, topk, inter_dim // 32) + elif in_dtype in ("fp8", "a8w4"): + a2_q, a2_scale = _per_1x32_fp8_quant(out1_fp16.to(torch.float32)) + else: + raise ValueError(f"in_dtype must be one of ('fp4','fp8','a8w4'), got {in_dtype!r}") + return a2_q, a2_scale + + + +def test_moe_stage1_use_tdm_gather_smoke(): + """Stage1 should support toggling A TDM gather without changing numerics on padded routes.""" + tokens = 5 + model_dim = 256 + inter_dim = 128 + experts = 1 + topk = 1 + tile_m = 16 + tile_n = 64 + tile_k = 128 + + device = torch.device("cuda") + torch.manual_seed(0) + x_fp32 = torch.randn((tokens, model_dim), device=device, dtype=torch.float32) + w1_fp32 = torch.randn((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) + topk_ids = torch.zeros((tokens, topk), device=device, dtype=torch.int64) + topk_weights = torch.ones((tokens, topk), device=device, dtype=torch.float32) + routing = build_routing_buffers( + topk_ids=topk_ids, + topk_weights=topk_weights, + experts=experts, + model_dim=model_dim, + tile_m=tile_m, + moe_sort_mode="torch", + ) + 
+ common_args = dict( + tokens=tokens, + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage1=False, + in_dtype="fp8", + seed=123, + num_iters=1, + num_warmup=1, + moe_sort_mode="torch", + x_fp32_in=x_fp32, + w1_fp32_in=w1_fp32, + topk_ids_in=topk_ids, + topk_weights_in=topk_weights, + routing_in=routing, + return_outputs=True, + skip_ref=True, + ) + + out_scalar, _ = run_moe_stage1( + **common_args, + use_tdm_gather=False, + ) + out_gather, _ = run_moe_stage1( + **common_args, + use_tdm_gather=True, + ) + + assert torch.isfinite(out_scalar).all() + assert torch.isfinite(out_gather).all() + assert verify_output(out_gather.to(torch.float32), out_scalar.to(torch.float32), rtol=0.5, atol=0.5) + + +def test_moe_stage2_use_tdm_gather_smoke(): + """Stage2 should support toggling A TDM gather without changing numerics.""" + tokens = 32 + model_dim = 256 + inter_dim = 128 + experts = 4 + topk = 2 + tile_m = 16 + tile_n = 64 + tile_k = 128 + + torch.manual_seed(0) + a2_fp32 = torch.randn((tokens, topk, inter_dim), device="cuda", dtype=torch.float32) + a2_q, a2_scale = _per_1x32_fp8_quant(a2_fp32) + + common_args = dict( + tokens=tokens, + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage1=False, + in_dtype="fp8", + out_dtype="f16", + seed=123, + num_iters=1, + num_warmup=1, + return_outputs=True, + skip_ref=True, + a2_fp8_in=a2_q, + a2_scale_in=a2_scale, + ) + + out_scalar, _ = run_moe_stage2( + **common_args, + use_tdm_gather=False, + ) + out_gather, _ = run_moe_stage2( + **common_args, + use_tdm_gather=True, + ) + + assert torch.isfinite(out_scalar).all() + assert torch.isfinite(out_gather).all() + assert verify_output(out_gather.to(torch.float32), out_scalar.to(torch.float32), rtol=0.5, atol=0.5) + + +def test_moe_stage2_use_tdm_gather_padding_smoke(): + """Stage2 gather should match scalar load on deterministic padding-heavy routing.""" + tokens = 5 + model_dim = 256 + inter_dim = 128 + experts = 1 + topk = 2 + tile_m = 16 + tile_n = 64 + tile_k = 128 + + device = torch.device("cuda") + torch.manual_seed(1) + a2_fp32 = torch.randn((tokens, topk, inter_dim), device=device, dtype=torch.float32) + a2_q, a2_scale = _per_1x32_fp8_quant(a2_fp32) + x_fp32 = torch.randn((tokens, model_dim), device=device, dtype=torch.float32) + w1_fp32 = torch.randn((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) + w2_fp32 = torch.randn((experts, model_dim, inter_dim), device=device, dtype=torch.float32) + topk_ids = torch.zeros((tokens, topk), device=device, dtype=torch.int64) + topk_weights = torch.tensor([[0.75, 0.25]], device=device, dtype=torch.float32).expand(tokens, topk).contiguous() + routing = build_routing_buffers( + topk_ids=topk_ids, + topk_weights=topk_weights, + experts=experts, + model_dim=model_dim, + tile_m=tile_m, + moe_sort_mode="torch", + ) + + common_args = dict( + tokens=tokens, + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage1=False, + in_dtype="fp8", + out_dtype="f16", + seed=321, + num_iters=1, + num_warmup=1, + moe_sort_mode="torch", + x_fp32_in=x_fp32, + w1_fp32_in=w1_fp32, + w2_fp32_in=w2_fp32, + topk_ids_in=topk_ids, + topk_weights_in=topk_weights, + routing_in=routing, + a2_fp8_in=a2_q, + a2_scale_in=a2_scale, + return_outputs=True, + skip_ref=True, + ) + + out_scalar, _ = 
run_moe_stage2( + **common_args, + use_tdm_gather=False, + ) + out_gather, _ = run_moe_stage2( + **common_args, + use_tdm_gather=True, + ) + + assert torch.isfinite(out_scalar).all() + assert torch.isfinite(out_gather).all() + assert verify_output(out_gather.to(torch.float32), out_scalar.to(torch.float32), rtol=0.5, atol=0.5) + + +# --------------------------------------------------------------------------- +# FP4 smoke tests +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("use_reduce", [False, True], ids=["atomic", "reduce"]) +def test_moe_2stage_fp4_smoke(use_reduce: bool): + """Smoke test for gfx1250 fp4 stage1/stage2 path.""" + tokens = 32 + model_dim = 256 + inter_dim = 128 + experts = 4 + topk = 2 + tile_m = 16 + tile_n1 = 64 + tile_k1 = 128 + tile_n2 = 64 + tile_k2 = 128 + + stage1_out, _ = run_moe_stage1( + tokens=tokens, + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n=tile_n1, + tile_k=tile_k1, + doweight_stage1=False, + in_dtype="fp4", + num_iters=1, + num_warmup=1, + return_outputs=True, + skip_ref=True, + ) + + a2_fp4 = fp4_utils.random_fp4_packed(tokens * topk, inter_dim, device=stage1_out.device) + a2_scale = fp4_utils.random_e8m0(tokens * topk, inter_dim // 32, device=stage1_out.device) + stage2_out, _ = run_moe_stage2( + tokens=tokens, + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n=tile_n2, + tile_k=tile_k2, + doweight_stage1=False, + in_dtype="fp4", + out_dtype="f16", + num_iters=1, + num_warmup=1, + return_outputs=True, + skip_ref=True, + a2_fp8_in=a2_fp4, + a2_scale_in=a2_scale, + use_reduce=bool(use_reduce), + ) + + assert torch.isfinite(stage1_out).all() + assert torch.isfinite(stage2_out).all() + + +def test_moe_2stage_fp4_wfp4_reduce_reference(): + """wfp4 correctness path should still use stage1-derived fp4 A2 input.""" + tokens = 64 + model_dim = 256 + inter_dim = 128 + experts = 4 + topk = 2 + tile_m = 16 + tile_n = 64 + tile_k = 128 + seed = 0 + run_moe_gemm_2stage( + tokens=tokens, + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n1=tile_n, + tile_k1=tile_k, + tile_n2=tile_n, + tile_k2=tile_k, + doweight_stage1=False, + in_dtype="fp4", + out_dtype="f16", + seed=seed, + num_iters=1, + num_warmup=1, + moe_sort_mode="torch", + skip_ref=False, + use_reduce=True, + use_tdm_store=False, + w_fp4_kernel=True, + test_graph=False, + use_valid_mask=False, + ) + + +# --------------------------------------------------------------------------- +# Main parametrized 2-stage test — MXScale dtypes only +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize( + "tokens, model_dim, inter_dim, experts, topk, tile_m, tile_n1, tile_k1, tile_n2, tile_k2, doweight_stage1", + [ + pytest.param(64, 256, 128, 4, 2, 16, 64, 128, 64, 128, False, id="S"), + pytest.param(129, 1024, 256, 8, 2, 32, 128, 128, 128, 128, False, id="M"), + pytest.param(333, 4096, 2048, 17, 9, 64, 128, 128, 256, 128, False, id="L", marks=pytest.mark.large_shape), + ], +) +@pytest.mark.parametrize("in_dtype", ["fp4", "fp8", "a8w4"]) +@pytest.mark.parametrize("out_dtype", ["f16"], ids=["out_f16"]) +@pytest.mark.parametrize("use_reduce", [False, True], ids=["atomic", "reduce"]) +@pytest.mark.parametrize("use_valid_mask", [False, True], ids=["nomask", "mask"]) +@pytest.mark.parametrize("test_graph", [ + pytest.param(False, id="eager"), + 
pytest.param(True, id="graph"), +]) +def test_moe_gemm_2stage( + tokens: int, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n1: int, + tile_k1: int, + tile_n2: int, + tile_k2: int, + doweight_stage1: bool, + in_dtype: str, + out_dtype: str, + use_reduce: bool, + use_valid_mask: bool, + test_graph: bool, +): + run_moe_gemm_2stage( + tokens=tokens, + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n1=tile_n1, + tile_k1=tile_k1, + tile_n2=tile_n2, + tile_k2=tile_k2, + doweight_stage1=doweight_stage1, + in_dtype=in_dtype, + out_dtype=out_dtype, + use_reduce=use_reduce, + use_valid_mask=use_valid_mask, + test_graph=test_graph, + ) + + +# --------------------------------------------------------------------------- +# Standalone stage2 test (atomic vs reduce comparison) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize( + "tokens, model_dim, inter_dim, experts, topk, tile_m, tile_n, tile_k", + [ + pytest.param(8192, 7168, 256, 128, 8, 64, 256, 128, id="DS-TP8-prefill-S", marks=pytest.mark.large_shape), + pytest.param(16384, 7168, 256, 256, 8, 64, 256, 128, id="DS-TP8-prefill-M", marks=pytest.mark.large_shape), + pytest.param(32768, 7168, 256, 256, 8, 64, 256, 128, id="DS-TP8-prefill-L", marks=pytest.mark.large_shape), + pytest.param(1, 7168, 256, 256, 8, 16, 256, 128, id="DS-TP8-decode-bs1"), + pytest.param(8, 7168, 256, 256, 8, 32, 256, 128, id="DS-TP8-decode-bs8"), + pytest.param(1666, 5120, 1536, 64, 6, 64, 256, 128, id="EP-K6-prefill", marks=pytest.mark.large_shape), + pytest.param(32768, 5120, 1536, 64, 6, 64, 256, 128, id="EP-K6-prefill-L", marks=pytest.mark.large_shape), + pytest.param(1, 5120, 1536, 16, 6, 16, 128, 256, id="EP-K6-decode-bs1"), + pytest.param(8, 5120, 1536, 16, 6, 64, 128, 128, id="EP-K6-decode-bs8"), + ], +) +@pytest.mark.parametrize("in_dtype", ["fp8"]) +def test_moe_stage2_standalone( + tokens: int, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n: int, + tile_k: int, + in_dtype: str, + *, + seed: int = 0, + num_iters: int = 10, + num_warmup: int = 3, +): + """Standalone stage2 test comparing atomic vs reduce modes. + + Tests: + 1. Atomic mode: direct accumulation with atomics + 2. Reduce mode (torch): GEMM2 + torch.sum reduction + 3. 
Reduce mode (FlyDSL): GEMM2 + FlyDSL reduce kernel + """ + common_args = dict( + tokens=tokens, + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage1=False, + in_dtype=in_dtype, + seed=seed, + num_iters=num_iters, + num_warmup=num_warmup, + moe_sort_mode="torch", + skip_ref=False, + ) + + run_moe_stage2(**common_args, kernel_name="moe_gemm2_atomic") + + run_moe_stage2( + **common_args, + compile_fn=_make_reduce_mode_compile_fn(use_flydsl_reduce=False), + kernel_name="moe_gemm2_reduce_torch", + ) + + run_moe_stage2( + **common_args, + use_reduce=True, + kernel_name="moe_gemm2_reduce_flydsl", + ) + + run_moe_stage2( + **common_args, + use_reduce=True, + use_valid_mask=True, + kernel_name="moe_gemm2_reduce_flydsl_valid_mask", + ) diff --git a/tests/kernels/test_moe_gemm_wmma_gfx1250.py b/tests/kernels/test_moe_gemm_wmma_gfx1250.py new file mode 100644 index 00000000..737ed3e3 --- /dev/null +++ b/tests/kernels/test_moe_gemm_wmma_gfx1250.py @@ -0,0 +1,1419 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""MoE GEMM tests for WMMA data types (fp16/bf16) on gfx1250. +""" +import argparse +import logging +import math +import os +import sys +from typing import Tuple, Optional + +import pytest +import torch + +# ----------------------------------------------------------------------------- +# Ensure we use the repo-local `flydsl` when running this file directly. +# +# Some environments have another `flydsl` (e.g. from a sibling checkout) earlier +# on `sys.path`, which can miss newer ROCDL wrappers (notably atomic fadd / MFMA). +# ----------------------------------------------------------------------------- +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +_PYTHON_CANDIDATES = [ + os.path.join(_REPO_ROOT, "build", "python_packages"), + _REPO_ROOT, +] +for _p in reversed(_PYTHON_CANDIDATES): + if os.path.isdir(_p) and _p not in sys.path: + sys.path.insert(0, _p) + +from tests.kernels.test_ref import torch_moe_gemm1, torch_moe_gemm2 +from tests.test_common import verify_output, run_perftest +from flydsl.runtime.device import get_rocm_arch +from tests.kernels.benchmark_common import ( + bench_kernel_us as _bench_kernel_us, + add_moe_bench_args, + moe_bench_main, +) + +ARCH = get_rocm_arch() + +if not torch.cuda.is_available(): + pytest.skip("CUDA/ROCm not available. Skipping GPU tests.", allow_module_level=True) + +if not str(ARCH).startswith("gfx1250"): + pytest.skip(f"MoE 2stage gfx1250 tests require gfx1250, got {ARCH}", allow_module_level=True) + + +# Optional: use aiter's exact routing/sorting implementation (matches `aiter/op_tests/test_moe_2stage.py`). +# Some environments ship aiter python but miss required JIT .so dependencies; we fall back gracefully. +try: + import aiter + from aiter.fused_moe import moe_sorting as aiter_moe_sorting + + HAS_AITER = True +except Exception: + HAS_AITER = False + +# Kernel implementations live under `kernels/`; this test file is the harness. +from kernels.moe_gemm_2stage_wmma_gfx1250 import ( + compile_moe_gemm1, + compile_moe_gemm2, + compile_moe_gemm2_ex, + MoeGemm2Mode, +) + +logging.basicConfig(level=logging.INFO) + +# Reduce noisy aiter log spam (e.g. "type hints mismatch, override to --> ...") so test output +# stays readable. You can override via env: FLYDSL_AITER_LOG_LEVEL=INFO/WARNING/ERROR. 
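+# Example override: FLYDSL_AITER_LOG_LEVEL=WARNING pytest tests/kernels/test_moe_gemm_wmma_gfx1250.py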
+_aiter_level = os.environ.get("FLYDSL_AITER_LOG_LEVEL", "ERROR").upper().strip() +try: + logging.getLogger("aiter").setLevel(getattr(logging, _aiter_level, logging.ERROR)) +except Exception: + # Best-effort only; never break tests due to logging configuration. + pass + + +def moe_sorting_torch_native( + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + *, + num_experts: int, + block_size: int, + expert_mask: Optional[torch.Tensor] = None, + num_local_tokens: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Torch reference for aiter's moe_sorting. + + Returns: + - sorted_ids[int32]: fused (topk_slot<<24 | token_id) + - sorted_weights[fp32]: aligned with sorted_ids + - sorted_expert_ids[int32]: one expert id per M-block (size = num_blocks) + - num_tokens_post_pad[int32]: [0]=total padded tokens, [1]=num_tokens (logical) + + Notes: + - This function intentionally mirrors `aiter/op_tests/test_moe_sorting.py::moe_sorting_native`. + """ + assert topk_ids.is_cuda and topk_weights.is_cuda + device = topk_ids.device + M, topk = topk_ids.shape + topk = topk_ids.shape[1] + + # Upper bound allocation (matches aiter op_tests; not strictly required but keeps shapes predictable). + max_num_tokens_padded = int(topk_ids.numel() + int(num_experts) * int(block_size) - int(topk)) + max_num_m_blocks = int((max_num_tokens_padded + int(block_size) - 1) // int(block_size)) + + init_val = (int(topk) << 24) | int(M) + sorted_ids = torch.full((max_num_tokens_padded,), init_val, dtype=torch.int32, device=device) + sorted_weights = torch.empty((max_num_tokens_padded,), dtype=torch.float32, device=device) + sorted_expert_ids = torch.full((max_num_m_blocks,), -1, dtype=torch.int32, device=device) + num_tokens_post_pad = torch.empty((2,), dtype=torch.int32, device=device) + + if num_local_tokens is not None: + topk_ids = topk_ids[: num_local_tokens.item()] + + sorted_ids_begin = 0 + sorted_expert_ids_begin = 0 + skip_expert_num = 0 + for expertId in range(int(num_experts)): + if expert_mask is not None and int(expert_mask[expertId].item()) == 0: + skip_expert_num += 1 + continue + token_id, topk_id = torch.where(topk_ids == expertId) + tokensNum = int(token_id.numel()) + sorted_expert_ids_num = int((tokensNum + int(block_size) - 1) // int(block_size)) + tokensNumPad = int(sorted_expert_ids_num * int(block_size)) + sorted_ids[sorted_ids_begin : sorted_ids_begin + tokensNum] = ( + (topk_id.to(torch.int32) << 24) | token_id.to(torch.int32) + ) + sorted_weights[sorted_ids_begin : sorted_ids_begin + tokensNum] = topk_weights[ + token_id, topk_id + ].to(torch.float32) + sorted_ids_begin = int(sorted_ids_begin + tokensNumPad) + sorted_expert_ids[ + sorted_expert_ids_begin : sorted_expert_ids_begin + sorted_expert_ids_num + ] = int(expertId - skip_expert_num) + sorted_expert_ids_begin = int(sorted_expert_ids_begin + sorted_expert_ids_num) + + num_tokens_post_pad[0] = int(sorted_ids_begin) + num_tokens_post_pad[1] = int(topk_ids.shape[0]) + + return sorted_ids, sorted_weights, sorted_expert_ids, num_tokens_post_pad + + +def _maybe_aiter_moe_sorting( + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + *, + num_experts: int, + model_dim: int, + block_m: int, +) -> Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: + """Return (sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids) or None.""" + if not HAS_AITER: + return None + try: + # aiter expects i32 ids and fp32 weights + topk_ids_i32 = topk_ids.to(torch.int32) + topk_w_f32 = 
topk_weights.to(torch.float32) + sorted_ids, sorted_w, sorted_expert_ids, num_valid_ids, _moe_buf = aiter_moe_sorting( + topk_ids_i32, + topk_w_f32, + num_experts, + model_dim, + torch.float16, + block_m, + ) + # `num_valid_ids` is documented as [1]; some builds allocate [2]. Keep the first element. + if num_valid_ids.numel() > 1: + num_valid_ids = num_valid_ids[:1].contiguous() + return sorted_ids, sorted_w, sorted_expert_ids, num_valid_ids + except Exception: + return None + + +RoutingBuffers = Tuple[ + torch.Tensor, # sorted_token_ids + torch.Tensor, # sorted_weights + torch.Tensor, # sorted_expert_ids + torch.Tensor, # num_valid_ids (shape [1], i32) + int, # sorted_size + int, # blocks +] + + +def get_topk_valid_mask(topk_ids: torch.Tensor, expert_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + """Build valid_mask [tokens, topk] for (optional) EP-style masking. + + Mirrors `aiter.fused_moe.get_topk_valid_mask` semantics: + - If expert_mask is None: all slots are valid (all ones) + - Else: valid_mask[t, k] = expert_mask[topk_ids[t, k]] (cast to int8) + """ + if expert_mask is None: + return torch.ones(topk_ids.shape, dtype=torch.int8, device=topk_ids.device) + return expert_mask[topk_ids].to(torch.int8) + + +def build_routing_buffers( + *, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + experts: int, + model_dim: int, + tile_m: int, + moe_sort_mode: Optional[str] = None, +) -> RoutingBuffers: + """Build routing buffers once, reusable across stage1 + stage2. + + NOTE: + - `moe_sort_mode="aiter"` aligns with `aiter/aiter/test_moe_flydsl.py` (swap path): + - Use aiter's `moe_sorting` output directly (no host trim/pad of sorted buffers) + - Launch full expert-block range; kernels use `num_valid_ids` to early-exit extra blocks + - `moe_sort_mode="torch"` is a portable fallback when aiter isn't available: + - Mirrors `aiter/op_tests/test_moe_sorting.py::moe_sorting_native` for consistent semantics + """ + device = topk_ids.device + default_mode = "aiter" if HAS_AITER else "torch" + sort_mode = str(moe_sort_mode or os.environ.get("flydsl_MOE_SORT_MODE", default_mode)).lower().strip() + if sort_mode not in ("aiter", "torch"): + raise ValueError(f"invalid moe_sort_mode={sort_mode!r} (expected 'aiter' or 'torch')") + + if sort_mode == "torch": + sorted_token_ids, sorted_weights, sorted_expert_ids, num_tokens_post_pad = moe_sorting_torch_native( + topk_ids=topk_ids.to(torch.int32), + topk_weights=topk_weights.to(torch.float32), + num_experts=int(experts), + block_size=int(tile_m), + ) + # num_valid_ids[0] == total padded rows; kernels use this for early-exit. + num_valid_ids = num_tokens_post_pad[:1].contiguous() + sorted_size = int(sorted_token_ids.numel()) + blocks = int(sorted_expert_ids.numel()) + return ( + sorted_token_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + sorted_size, + blocks, + ) + + # aiter mode + if not HAS_AITER: + raise RuntimeError("aiter is not available; cannot build routing buffers (moe_sort_mode='aiter').") + + res = _maybe_aiter_moe_sorting( + topk_ids, + topk_weights, + num_experts=experts, + model_dim=model_dim, + block_m=tile_m, + ) + if res is None: + raise RuntimeError("aiter moe_sorting failed/unavailable; cannot build routing buffers.") + sorted_token_ids, sorted_weights, sorted_expert_ids, num_valid_ids = res + + # Keep moe_sorting outputs as-is (no host trim/pad). Launch full expert-block range. 
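+    # Entries of sorted_token_ids keep aiter's fused encoding, (topk_slot << 24) | token_id,
+    # i.e. the same layout the torch fallback above produces.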
+ sorted_token_ids = sorted_token_ids.contiguous() + sorted_weights = sorted_weights.contiguous() + sorted_expert_ids = sorted_expert_ids.contiguous() + sorted_size = int(sorted_token_ids.numel()) + blocks = int(sorted_expert_ids.numel()) + return ( + sorted_token_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + sorted_size, + blocks, + ) + + +# ---- Stage1/Stage2 runners (helpers; NOT pytest tests) ---- +def run_moe_stage1( + tokens: int, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n: int, + tile_k: int, + doweight_stage1: bool, + *, + in_dtype: str = "fp16", + seed: int = 0, + num_iters: int = 5, + num_warmup: int = 2, + moe_sort_mode: Optional[str] = None, + x_fp32_in: Optional[torch.Tensor] = None, + w1_fp32_in: Optional[torch.Tensor] = None, + topk_ids_in: Optional[torch.Tensor] = None, + topk_weights_in: Optional[torch.Tensor] = None, + routing_in: Optional[RoutingBuffers] = None, + return_outputs: bool = False, + skip_ref: bool = False, + test_graph: bool = False, + waves_per_eu: Optional[int] = None, + benchmark_mode: bool = False, + flush_l2: bool = True, + num_buffers: int = 1, + use_tdm_gather: bool = True, + use_tdm_store: bool = False, + inst_prefetch: bool = False, + wave_specialized_tdm: bool = False, + cluster_m: int = 1, + cluster_n: int = 1, + expert_sched_mode: bool = True, +): + assert model_dim % 64 == 0 + assert model_dim % tile_k == 0, f"model_dim={model_dim} must be divisible by tile_k={tile_k}" + assert inter_dim % tile_n == 0 + + if in_dtype not in ("fp16", "bf16"): + raise ValueError(f"in_dtype must be 'fp16' or 'bf16', got {in_dtype!r}") + + device = torch.device("cuda") + torch.manual_seed(int(seed)) + + # Data: input and weights (aiter shapes) + x_fp32 = ( + x_fp32_in + if x_fp32_in is not None + else torch.randn((tokens, model_dim), device=device, dtype=torch.float32) + ) + w1_fp32 = ( + w1_fp32_in + if w1_fp32_in is not None + else torch.randn((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) + ) + + # Routing: aiter uses fused_topk; we use torch topk+softmax for portability/determinism. 
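+    # topk_ids is [tokens, topk] (one expert index per slot); topk_weights is the softmax
+    # over the selected scores, kept in fp32.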
+ if topk_ids_in is None or topk_weights_in is None: + score = torch.randn((tokens, experts), device=device, dtype=torch.float32) + topk_vals, topk_ids = torch.topk(score, k=topk, dim=1) + topk_weights = torch.softmax(topk_vals, dim=1).to(torch.float32) + else: + topk_ids = topk_ids_in + topk_weights = topk_weights_in + + routing = ( + routing_in + if routing_in is not None + else build_routing_buffers( + topk_ids=topk_ids, + topk_weights=topk_weights, + experts=experts, + model_dim=model_dim, + tile_m=tile_m, + moe_sort_mode=moe_sort_mode, + ) + ) + ( + sorted_token_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + sorted_size, + blocks, + ) = routing + + cast = torch.float16 if in_dtype == "fp16" else torch.bfloat16 + x_q = x_fp32.to(cast) + w1_q = w1_fp32.to(cast) + + w1_q_flat = w1_q.view(experts * (2 * inter_dim), model_dim) + x_q = x_q.contiguous().view(tokens, model_dim) + w_kernel = w1_q_flat.contiguous() + + scale_x_1d = torch.empty((0,), device=device, dtype=torch.float32) + scale_w1_1d = torch.empty((0,), device=device, dtype=torch.float32) + sorted_weights_1d = sorted_weights.contiguous().view(-1) # [sorted_size] + + # Output: [tokens, topk, inter_dim] fp16 + out = torch.zeros((tokens, topk, inter_dim), device=device, dtype=torch.float16) + + from kernels.moe_gemm_2stage_wmma_gfx1250 import compile_moe_gemm1 + exe = compile_moe_gemm1( + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + in_dtype=in_dtype, + group_size=-1, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage1=bool(doweight_stage1), + use_cshuffle_epilog=False, + waves_per_eu=waves_per_eu, + num_buffers=int(num_buffers), + use_tdm_gather=bool(use_tdm_gather), + use_tdm_store=bool(use_tdm_store), + inst_prefetch=bool(inst_prefetch), + wave_specialized_tdm=bool(wave_specialized_tdm), + cluster_m=int(cluster_m), + cluster_n=int(cluster_n), + expert_sched_mode=bool(expert_sched_mode), + ) + + def launch(o, x, w, sx, sw, st, eids, sw_sorted): + stream = torch.cuda.current_stream() + exe( + o, + x, + w, + sx, + sw, + st, + eids, + sw_sorted, + num_valid_ids, + tokens, + inter_dim, + model_dim, + int(blocks), + stream, + ) + + if bool(benchmark_mode) and not bool(test_graph): + def _prep_stage1(): + out.zero_() + + def _run_stage1(): + launch(out, x_q, w_kernel, scale_x_1d, scale_w1_1d, sorted_token_ids, sorted_expert_ids, sorted_weights_1d) + + us = _bench_kernel_us( + _run_stage1, + warmup=int(num_warmup), + iters=int(max(1, num_iters)), + flush_l2=bool(flush_l2), + prep_fn=_prep_stage1, + ) + torch.cuda.synchronize() + else: + _, us = run_perftest( + launch, + out, + x_q, + w_kernel, + scale_x_1d, + scale_w1_1d, + sorted_token_ids, + sorted_expert_ids, + sorted_weights_1d, + num_iters=int(num_iters), + num_warmup=int(num_warmup), + testGraph=test_graph, + ) + torch.cuda.synchronize() + + if not bool(skip_ref): + ref = torch_moe_gemm1( + x_q, + w1_q_flat, + None, + None, + topk_ids.to(torch.int64), + topk_weights, + inter_dim=inter_dim, + doweight_stage1=doweight_stage1, + ) + assert verify_output(out.to(torch.float32), ref, rtol=0.25, atol=0.25) + + # Note: kernel launches full expert-block range; effective work is gated by num_valid_ids. + flops = 2 * tokens * topk * (2 * inter_dim) * model_dim + tflops = flops / (us / 1e6) / 1e12 + + # Rough bytes-moved accounting (same spirit as GEMM tests: count each tensor once). 
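+    # i.e. bytes_moved = x + w1 + out + scales + routing buffers, and TB/s is simply
+    # bytes_moved / elapsed time; a rough estimate rather than a cache-aware model.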
+ x_elem_bytes = 2 + bytes_x = tokens * model_dim * x_elem_bytes + bytes_w = experts * (2 * inter_dim) * model_dim + bytes_out = tokens * topk * inter_dim * 2 + bytes_scale_x = tokens * 4 + bytes_scale_w = experts * (2 * inter_dim) * 4 + bytes_route = ( + int(sorted_weights.numel()) * 4 + + int(sorted_token_ids.numel()) * 4 + + int(sorted_expert_ids.numel()) * 4 + ) + bytes_moved = bytes_x + bytes_w + bytes_out + bytes_scale_x + bytes_scale_w + bytes_route + tbps = bytes_moved / 1e12 / (us / 1e6) + read_bytes = bytes_x + bytes_w + bytes_scale_x + bytes_scale_w + bytes_route + write_bytes = bytes_out + read_bw_gbs = read_bytes / 1e9 / (us / 1e6) + write_bw_gbs = write_bytes / 1e9 / (us / 1e6) + + if bool(benchmark_mode): + print( + f"FlyDSL MoE stage1[{in_dtype}] benchmark | " + f"shape=({tokens},{model_dim},{inter_dim}), E={experts}, K={topk}, " + f"tile=({tile_m},{tile_n},{tile_k})" + ) + print( + f" kernel: {us:.1f} us ({us / 1e3:.4f} ms) | " + f"{tflops:.2f} TFLOPS(logical, M={tokens*topk}) | {tbps:.3f} TB/s" + ) + print( + f" bandwidth: read {read_bw_gbs:.1f} GB/s + write {write_bw_gbs:.1f} GB/s | " + f"bytes: x={bytes_x/1e6:.1f}MB w={bytes_w/1e6:.1f}MB " + f"sx={bytes_scale_x/1e6:.1f}MB sw={bytes_scale_w/1e6:.1f}MB " + f"route={bytes_route/1e6:.1f}MB out={bytes_out/1e6:.1f}MB" + ) + else: + print( + f"FlyDSL MoE stage1[{in_dtype}]: " + f"{us:.1f} us, " + f"{tflops:.2f} TFLOPS(logical, M={tokens*topk}), " + f"{tbps:.3f} TB/s (doweight_stage1={doweight_stage1})" + ) + if return_outputs: + return out, us + return None + + +def run_moe_stage2( + tokens: int, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n: int, + tile_k: int, + doweight_stage1: bool, + *, + in_dtype: str = "fp16", + out_dtype: str = "f16", + seed: int = 0, + num_iters: int = 5, + num_warmup: int = 2, + moe_sort_mode: Optional[str] = None, + x_fp32_in: Optional[torch.Tensor] = None, + w1_fp32_in: Optional[torch.Tensor] = None, + w2_fp32_in: Optional[torch.Tensor] = None, + topk_ids_in: Optional[torch.Tensor] = None, + topk_weights_in: Optional[torch.Tensor] = None, + routing_in: Optional[RoutingBuffers] = None, + a2_fp8_in: Optional[torch.Tensor] = None, + a2_scale_in: Optional[torch.Tensor] = None, + return_outputs: bool = False, + skip_ref: bool = False, + init_scale: float = 0.2, + compile_fn=None, + kernel_name: str = "moe_gemm2", + use_reduce: bool = False, + use_valid_mask: bool = False, + test_graph: bool = False, + waves_per_eu: Optional[int] = None, + benchmark_mode: bool = False, + flush_l2: bool = True, + num_buffers: int = 1, + use_tdm_gather: bool = True, + use_tdm_store: bool = False, + inst_prefetch: bool = False, + wave_specialized_tdm: bool = False, + cluster_m: int = 1, + cluster_n: int = 1, + expert_sched_mode: bool = True, +): + """MoE stage2 (gemm2): out2[t] = sum_{slot} ( out1[t,slot] @ W2[expert]^T ) with optional routed weight.""" + if in_dtype not in ("fp16", "bf16"): + raise ValueError(f"in_dtype must be 'fp16' or 'bf16', got {in_dtype!r}") + + # Parameter sanity checks with actionable hints (avoid bare AssertionError). + if model_dim % tile_n != 0: + raise ValueError( + f"Invalid stage2 tiling: model_dim ({model_dim}) must be divisible by tile_n2 ({tile_n})." + ) + if inter_dim % tile_k != 0: + raise ValueError( + "Invalid stage2 tiling: inter_dim ({inter_dim}) must be divisible by tile_k2 ({tile_k}). " + "Try setting `--tile_k2` to a divisor of inter_dim. 
" + "Tip: stage2 splits A2 loads across 256 threads; if you want smaller tile_k2, you may need a larger tile_m so (tile_m*tile_k2) stays divisible by 1024." + .format(inter_dim=inter_dim, tile_k=tile_k) + ) + # Enforce the kernel's stage2 gmem->reg load mapping constraints. + # See: kernels/moe_gemm_2stage.py::compile_moe_gemm2 (x_load_bytes selection). + if (tile_m * tile_k) % 256 != 0: + raise ValueError( + f"Invalid stage2 tiling: tile_m*tile_k2 must be divisible by 256 (total_threads=256). " + f"Got tile_m={tile_m}, tile_k2={tile_k} -> tile_m*tile_k2={tile_m * tile_k}." + ) + bytes_per_thread_x = (tile_m * tile_k) // 256 # 1B elements + if bytes_per_thread_x % 4 != 0: + raise ValueError( + f"Invalid stage2 tiling for gmem loads: bytes_per_thread_x ((tile_m*tile_k2)/256) must be divisible by 4. " + f"Got tile_m={tile_m}, tile_k2={tile_k} -> bytes_per_thread_x={bytes_per_thread_x}. " + ) + + # Default compile function. + if compile_fn is None: + if use_reduce: + compile_fn = _make_reduce_mode_compile_fn(use_flydsl_reduce=True, use_valid_mask=bool(use_valid_mask)) + else: + compile_fn = compile_moe_gemm2 + + device = torch.device("cuda") + torch.manual_seed(int(seed)) + + s = float(init_scale) + + # Data: input and weights (aiter shapes) + x_fp32 = ( + x_fp32_in + if x_fp32_in is not None + else torch.rand((tokens, model_dim), device=device, dtype=torch.float32) * s + ) + w1_fp32 = ( + w1_fp32_in + if w1_fp32_in is not None + else torch.rand((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) * (s / math.sqrt(model_dim)) + ) + w2_fp32 = ( + w2_fp32_in + if w2_fp32_in is not None + else torch.rand((experts, model_dim, inter_dim), device=device, dtype=torch.float32) * (s / math.sqrt(inter_dim)) + ) + + # Routing: deterministic torch topk + softmax. + if topk_ids_in is None or topk_weights_in is None: + score = torch.rand((tokens, experts), device=device, dtype=torch.float32) + topk_vals, topk_ids = torch.topk(score, k=topk, dim=1) + topk_weights = torch.softmax(topk_vals, dim=1).to(torch.float32) + else: + topk_ids = topk_ids_in + topk_weights = topk_weights_in + + routing = ( + routing_in + if routing_in is not None + else build_routing_buffers( + topk_ids=topk_ids, + topk_weights=topk_weights, + experts=experts, + model_dim=model_dim, + tile_m=tile_m, + moe_sort_mode=moe_sort_mode, + ) + ) + ( + sorted_token_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + sorted_size, + blocks, + ) = routing + # NOTE: routing uses `moe_sorting` output directly (no host trim/pad). Extra launched blocks + # are gated by `num_valid_ids` inside the kernels. + + cast = torch.float16 if in_dtype == "fp16" else torch.bfloat16 + x_q = x_fp32.to(cast) + w1_q = w1_fp32.to(cast) + w2_q = w2_fp32.to(cast) + + if a2_fp8_in is not None: + a2_q = a2_fp8_in + a2_scale = a2_scale_in + else: + w1_q_flat = w1_q.view(experts * (2 * inter_dim), model_dim) + if bool(skip_ref): + raise RuntimeError( + "run_moe_stage2(skip_ref=True) requires providing a2_fp8_in " + "(so we don't have to run the huge torch reference stage1)." 
+ ) + out1_ref = torch_moe_gemm1( + x_q, + w1_q_flat, + None, + None, + topk_ids.to(torch.int64), + topk_weights, + inter_dim=inter_dim, + doweight_stage1=bool(doweight_stage1), + ) + if in_dtype == "fp16": + a2_q = out1_ref.to(torch.float16) + a2_scale = None + else: + a2_q = out1_ref.to(torch.bfloat16) + a2_scale = None + + w2_shuffled_flat = w2_q.view(experts * model_dim, inter_dim) + w2_kernel = w2_shuffled_flat.contiguous().view(-1) + + a2_scale_1d = torch.empty((0,), device=device, dtype=torch.float32) + w2_scale_1d = torch.empty((0,), device=device, dtype=torch.float32) + sorted_weights_1d = sorted_weights.contiguous().view(-1) # [sorted_size] + + out_s = str(out_dtype).strip().lower() + if out_s in ("f16", "fp16", "half"): + out_torch_dtype = torch.float16 + elif out_s in ("f32", "fp32", "float"): + out_torch_dtype = torch.float32 + else: + raise ValueError(f"out_dtype must be 'f16' or 'f32', got {out_dtype!r}") + + out = torch.zeros((tokens, model_dim), device=device, dtype=out_torch_dtype) + out_perf = torch.zeros_like(out) + + doweight_stage2 = not bool(doweight_stage1) + compile_kwargs = dict( + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + in_dtype=in_dtype, + out_dtype=out_dtype, + group_size=-1, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage2=bool(doweight_stage2), + num_buffers=int(num_buffers), + use_tdm_gather=bool(use_tdm_gather), + use_tdm_store=bool(use_tdm_store), + inst_prefetch=bool(inst_prefetch), + wave_specialized_tdm=bool(wave_specialized_tdm), + cluster_m=int(cluster_m), + cluster_n=int(cluster_n), + expert_sched_mode=bool(expert_sched_mode), + ) + if waves_per_eu is not None: + compile_kwargs["waves_per_eu"] = waves_per_eu + try: + exe = compile_fn(**compile_kwargs) + except TypeError: + # Some wrapper compile_fns (e.g. local reduce wrappers) may not expose + # waves_per_eu; retry without it to keep compatibility. + compile_kwargs.pop("waves_per_eu", None) + exe = compile_fn(**compile_kwargs) + is_reduce_exe = (getattr(exe, "mode", None) == MoeGemm2Mode.REDUCE) or bool(use_reduce) + + def launch(o, x, w, sx, sw, st, eids, sw_sorted): + stream = torch.cuda.current_stream() + valid_mask = None + if is_reduce_exe and bool(use_valid_mask): + # Default: non-EP (all ones). EP mode can be emulated by passing expert_mask. + valid_mask = get_topk_valid_mask(topk_ids, expert_mask=None).contiguous() + if is_reduce_exe: + exe( + o, + x, + w, + sx, + sw, + st, + eids, + sw_sorted, + num_valid_ids, + tokens, + model_dim, + inter_dim, + int(blocks), + valid_mask, + stream, + ) + else: + # Atomic mode does not take valid_mask. + exe( + o, + x, + w, + sx, + sw, + st, + eids, + sw_sorted, + num_valid_ids, + tokens, + model_dim, + inter_dim, + int(blocks), + stream, + ) + + # NOTE: stage2 uses atomic-add into `out`, so we cannot reuse the same output buffer + # across perf iterations for correctness. Time into a dedicated buffer, then run + # a single clean launch for correctness verification below. 
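+    # Concretely: out_perf absorbs the repeated accumulation during timing, while the
+    # zeroed `out` buffer further down receives exactly one launch before verify_output.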
+ if bool(benchmark_mode) and not bool(test_graph): + def _prep_stage2(): + out_perf.zero_() + + def _run_stage2(): + launch( + out_perf, + a2_q.view(-1), + w2_kernel.view(-1), + a2_scale_1d, + w2_scale_1d, + sorted_token_ids, + sorted_expert_ids, + sorted_weights_1d, + ) + + us = _bench_kernel_us( + _run_stage2, + warmup=int(num_warmup), + iters=int(max(1, num_iters)), + flush_l2=bool(flush_l2), + prep_fn=_prep_stage2, + ) + torch.cuda.synchronize() + else: + _, us = run_perftest( + launch, + out_perf, + a2_q.view(-1), + w2_kernel.view(-1), + a2_scale_1d, + w2_scale_1d, + sorted_token_ids, + sorted_expert_ids, + sorted_weights_1d, + num_iters=int(num_iters), + num_warmup=int(num_warmup), + testGraph=test_graph, + ) + torch.cuda.synchronize() + + # Correctness run (single launch into a clean zeroed output). + out.zero_() + launch( + out, + a2_q.view(-1), + w2_kernel.view(-1), + a2_scale_1d, + w2_scale_1d, + sorted_token_ids, + sorted_expert_ids, + sorted_weights_1d, + ) + torch.cuda.synchronize() + + if not bool(skip_ref): + ref2 = torch_moe_gemm2( + a2_q, + w2_q, + None, + None, + topk_ids.to(torch.int64), + topk_weights, + model_dim=model_dim, + doweight_stage2=doweight_stage2, + ) + assert verify_output(out.to(torch.float32), ref2, rtol=0.5, atol=0.5) + + # Launches full expert-block range; effective work is gated by num_valid_ids. + flops = 2 * tokens * topk * model_dim * inter_dim + tflops = flops / (us / 1e6) / 1e12 + + a2_elem_bytes = 2 + bytes_a2 = tokens * topk * inter_dim * a2_elem_bytes + bytes_w2 = experts * model_dim * inter_dim + bytes_out = tokens * model_dim * (2 if out_torch_dtype == torch.float16 else 4) + bytes_scale_a2 = tokens * topk * 4 + bytes_scale_w2 = experts * model_dim * 4 + bytes_route = ( + int(sorted_weights.numel()) * 4 + + int(sorted_token_ids.numel()) * 4 + + int(sorted_expert_ids.numel()) * 4 + ) + bytes_moved = bytes_a2 + bytes_w2 + bytes_out + bytes_scale_a2 + bytes_scale_w2 + bytes_route + tbps = bytes_moved / 1e12 / (us / 1e6) + read_bytes = bytes_a2 + bytes_w2 + bytes_scale_a2 + bytes_scale_w2 + bytes_route + write_bytes = bytes_out + read_bw_gbs = read_bytes / 1e9 / (us / 1e6) + write_bw_gbs = write_bytes / 1e9 / (us / 1e6) + if bool(benchmark_mode): + print( + f"FlyDSL MoE stage2[{kernel_name}] {in_dtype} {'reduce' if use_reduce else 'atomic'} benchmark | " + f"shape=({tokens},{model_dim},{inter_dim}), E={experts}, K={topk}, " + f"tile=({tile_m},{tile_n},{tile_k})" + ) + print( + f" kernel: {us:.1f} us ({us / 1e3:.4f} ms) | " + f"{tflops:.2f} TFLOPS(logical, M={tokens*topk}) | {tbps:.3f} TB/s" + ) + print( + f" bandwidth: read {read_bw_gbs:.1f} GB/s + write {write_bw_gbs:.1f} GB/s | " + f"bytes: a2={bytes_a2/1e6:.1f}MB w2={bytes_w2/1e6:.1f}MB " + f"sa2={bytes_scale_a2/1e6:.1f}MB sw2={bytes_scale_w2/1e6:.1f}MB " + f"route={bytes_route/1e6:.1f}MB out={bytes_out/1e6:.1f}MB" + ) + else: + print( + f"FlyDSL MoE stage2 [{kernel_name}] {in_dtype} {'reduce' if use_reduce else 'atomic'} | " + f"{model_dim}x{inter_dim}, E={experts}, K={topk}, M_eff={tokens*topk} | " + f"{us:.1f} us, {tflops:.2f} TFLOPS, {tbps:.3f} TB/s" + ) + + # Print profile breakdown if the executor supports it + if hasattr(exe, 'print_profile_stats'): + exe.print_profile_stats() + + if return_outputs: + return out, us + return None + + +def run_moe_gemm_2stage( + tokens: int, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n1: int, + tile_k1: int, + tile_n2: int, + tile_k2: int, + doweight_stage1: bool, + in_dtype: str, + out_dtype: str, + 
use_reduce: bool, + use_valid_mask: bool, + test_graph: bool, + *, + seed: int = 0, + num_iters: int = 5, + num_warmup: int = 2, + moe_sort_mode: Optional[str] = None, + init_scale: float = 0.2, + skip_ref: bool = False, + benchmark_mode: bool = False, + flush_l2: bool = True, + num_buffers: int = 1, + use_tdm_store: bool = False, + inst_prefetch: bool = False, + wave_specialized_tdm: bool = False, + cluster_m: int = 1, + cluster_n: int = 1, +): + """Single 2-stage test: gemm1 -> quantize -> gemm2, with routing built once.""" + if (not bool(use_reduce)) and bool(use_valid_mask): + pytest.skip("valid_mask is only used in reduce mode (atomic mode ignores it).") + out_s = str(out_dtype).strip().lower() + if bool(use_reduce) and out_s in ("f32", "fp32", "float"): + pytest.skip("reduce mode does not support out_dtype='f32' (compile_moe_gemm2(accumulate=False) forbids it).") + device = torch.device("cuda") + + s = float(init_scale) + x_fp32 = torch.randn((tokens, model_dim), device=device, dtype=torch.float32) * s + w1_fp32 = torch.randn((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) * s + w2_fp32 = torch.randn((experts, model_dim, inter_dim), device=device, dtype=torch.float32) * (s / math.sqrt(inter_dim)) + + score = torch.rand((tokens, experts), device=device, dtype=torch.float32) + topk_vals, topk_ids = torch.topk(score, k=topk, dim=1) + topk_weights = torch.softmax(topk_vals, dim=1).to(torch.float32) + + routing = build_routing_buffers( + topk_ids=topk_ids, + topk_weights=topk_weights, + experts=experts, + model_dim=model_dim, + tile_m=tile_m, + moe_sort_mode=moe_sort_mode, + ) + + _shared = dict( + seed=seed, num_iters=num_iters, num_warmup=num_warmup, + moe_sort_mode=moe_sort_mode, + x_fp32_in=x_fp32, w1_fp32_in=w1_fp32, + topk_ids_in=topk_ids, topk_weights_in=topk_weights, + routing_in=routing, return_outputs=True, + skip_ref=bool(skip_ref), test_graph=test_graph, + benchmark_mode=bool(benchmark_mode), flush_l2=bool(flush_l2), + num_buffers=int(num_buffers), use_tdm_store=bool(use_tdm_store), + inst_prefetch=bool(inst_prefetch), wave_specialized_tdm=bool(wave_specialized_tdm), + cluster_m=int(cluster_m), cluster_n=int(cluster_n), + ) + + out1_fp16, _us1 = run_moe_stage1( + tokens=tokens, model_dim=model_dim, inter_dim=inter_dim, + experts=experts, topk=topk, in_dtype=in_dtype, + tile_m=tile_m, tile_n=tile_n1, tile_k=tile_k1, + doweight_stage1=bool(doweight_stage1), + **_shared, + ) + + a2_q, a2_scale = _prepare_a2_from_stage1(out1_fp16, in_dtype) + + _out2_fp32, _us2 = run_moe_stage2( + tokens=tokens, model_dim=model_dim, inter_dim=inter_dim, + experts=experts, topk=topk, in_dtype=in_dtype, out_dtype=out_dtype, + tile_m=tile_m, tile_n=tile_n2, tile_k=tile_k2, + doweight_stage1=bool(doweight_stage1), + w2_fp32_in=w2_fp32, + a2_fp8_in=a2_q, a2_scale_in=a2_scale, + use_reduce=bool(use_reduce), use_valid_mask=use_valid_mask, + **_shared, + ) + + +# Test Helpers for MoE GEMM2 Mode Comparison +def _make_reduce_mode_compile_fn(use_flydsl_reduce: bool = True, use_valid_mask: bool = False): + """Create a compile function that forces reduce mode. + + Args: + use_flydsl_reduce: If True, use FlyDSL reduce kernel. + If False, use torch.sum (for baseline comparison). 
+ """ + def _compile( + *, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n: int, + tile_k: int, + doweight_stage2: bool, + in_dtype: str = "fp16", + group_size: int = -1, + out_dtype: str = "f16", + waves_per_eu: Optional[int] = None, + expert_sched_mode: bool = True, + num_buffers: int = 1, + use_tdm_gather: bool = True, + use_tdm_store: bool = False, + inst_prefetch: bool = False, + wave_specialized_tdm: bool = False, + cluster_m: int = 1, + cluster_n: int = 1, + ): + if use_flydsl_reduce: + return compile_moe_gemm2_ex( + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage2=doweight_stage2, + in_dtype=in_dtype, + group_size=group_size, + out_dtype=out_dtype, + waves_per_eu=waves_per_eu, + valid_mask=(True if bool(use_valid_mask) else None), + mode=MoeGemm2Mode.REDUCE, + zero_intermediate=False, # test non-zeroed performance + expert_sched_mode=bool(expert_sched_mode), + num_buffers=int(num_buffers), + use_tdm_gather=bool(use_tdm_gather), + use_tdm_store=bool(use_tdm_store), + inst_prefetch=bool(inst_prefetch), + wave_specialized_tdm=bool(wave_specialized_tdm), + cluster_m=int(cluster_m), + cluster_n=int(cluster_n), + ) + else: + gemm2_exe = compile_moe_gemm2( + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage2=doweight_stage2, + in_dtype=in_dtype, + group_size=group_size, + out_dtype=out_dtype, + accumulate=False, + waves_per_eu=waves_per_eu, + expert_sched_mode=bool(expert_sched_mode), + num_buffers=int(num_buffers), + use_tdm_gather=bool(use_tdm_gather), + use_tdm_store=bool(use_tdm_store), + inst_prefetch=bool(inst_prefetch), + wave_specialized_tdm=bool(wave_specialized_tdm), + cluster_m=int(cluster_m), + cluster_n=int(cluster_n), + ) + return _TorchReduceWrapper(gemm2_exe, topk, model_dim) + return _compile + + +class _TorchReduceWrapper: + """Wrapper for GEMM2 (accumulate=False) with torch.sum reduction. + + For baseline comparison only. Production code should use compile_moe_gemm2_ex. 
+ """ + + def __init__(self, gemm2_exe, topk: int, model_dim: int): + self._exe = gemm2_exe + self._topk = topk + self._model_dim = model_dim + self._intermediate = None + self._mode = MoeGemm2Mode.REDUCE + + def __call__( + self, + arg_out, + arg_x, + arg_w, + arg_scale_x, + arg_scale_w, + arg_sorted_token_ids, + arg_expert_ids, + arg_sorted_weights, + arg_num_valid_ids, + tokens_in, + n_in, + k_in, + size_expert_ids_in, + valid_mask, + stream, + ): + # Lazy allocate intermediate buffer + needed = tokens_in * self._topk * self._model_dim + if self._intermediate is None or self._intermediate.numel() < needed: + self._intermediate = torch.empty( + tokens_in * self._topk, self._model_dim, + device=arg_out.device, dtype=arg_out.dtype + ) + + intermediate = self._intermediate[:tokens_in * self._topk, :] + self._exe( + intermediate.view(-1), + arg_x, arg_w, arg_scale_x, arg_scale_w, + arg_sorted_token_ids, arg_expert_ids, arg_sorted_weights, + arg_num_valid_ids, tokens_in, n_in, k_in, size_expert_ids_in, + stream, + ) + X = intermediate.view(tokens_in, self._topk, self._model_dim) + if valid_mask is not None: + X = X * valid_mask.view(tokens_in, self._topk, 1).to(dtype=X.dtype) + torch.sum(X, dim=1, out=arg_out) + + @property + def mode(self) -> str: + return self._mode + + +def _bench_setup_data(tokens, model_dim, inter_dim, experts, topk, tile_m, seed=42): + """Build random MoE data + routing buffers for bench sweeps.""" + device = torch.device("cuda") + torch.manual_seed(seed) + s = 0.2 + x_fp32 = torch.randn((tokens, model_dim), device=device, dtype=torch.float32) * s + w1_fp32 = torch.randn((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) * s + w2_fp32 = torch.randn((experts, model_dim, inter_dim), device=device, dtype=torch.float32) * (s / math.sqrt(inter_dim)) + score = torch.rand((tokens, experts), device=device, dtype=torch.float32) + topk_vals, topk_ids = torch.topk(score, k=topk, dim=1) + topk_weights = torch.softmax(topk_vals, dim=1).to(torch.float32) + routing = build_routing_buffers( + topk_ids=topk_ids, topk_weights=topk_weights, + experts=experts, model_dim=model_dim, tile_m=tile_m, + ) + return x_fp32, w1_fp32, w2_fp32, topk_ids, topk_weights, routing + + +def _prepare_a2_from_stage1(out1_fp16: torch.Tensor, in_dtype: str): + """Convert stage1 fp16 output to stage2 activation input (fp16 or bf16).""" + if in_dtype == "fp16": + return out1_fp16, None + if in_dtype == "bf16": + return out1_fp16.to(torch.bfloat16), None + raise ValueError(f"in_dtype must be 'fp16' or 'bf16', got {in_dtype!r}") + + +def _bench_prepare_a2(out1_fp16, _tokens, _topk, _inter_dim, in_dtype): + return _prepare_a2_from_stage1(out1_fp16, in_dtype) + + +if __name__ == "__main__": + torch.set_default_device("cuda") + # CLI (mirrors key knobs from aiter/op_tests/test_moe_2stage.py, stage1 subset) + def _str2bool(v): + if v is None: + return None + if isinstance(v, bool): + return v + s = str(v).strip().lower() + if s in {"1", "true", "t", "yes", "y", "on"}: + return True + if s in {"0", "false", "f", "no", "n", "off"}: + return False + raise argparse.ArgumentTypeError(f"invalid bool: {v} (use t/f, true/false, 1/0)") + + def _str2tuple_dim(v: str) -> Tuple[int, int]: + # aiter uses "-dim 6144,4096" meaning (model_dim, inter_dim) + s = str(v).strip() + parts = [p.strip() for p in s.split(",") if p.strip()] + if len(parts) != 2: + raise argparse.ArgumentTypeError(f"invalid -dim {v!r}; expected 'model_dim,inter_dim' e.g. 
6144,4096") + return int(parts[0]), int(parts[1]) + + parser = argparse.ArgumentParser( + formatter_class=argparse.RawTextHelpFormatter, + description="MoE 2-stage (FlyDSL WMMA fp16/bf16) test/benchmark on gfx1250.", + ) + parser.add_argument( + "--in_dtype", + type=str, + default="fp16", + choices=["fp16", "bf16", "all"], + help="Kernel input dtype: fp16, bf16, or all (default: fp16).", + ) + parser.add_argument("-dim", type=_str2tuple_dim, default=(6144, 4096), help="Model dimension: model_dim,inter_dim (e.g. -dim 6144,4096)") + parser.add_argument("-t", "--tokenNum", type=int, default=32, help="Number of tokens (e.g. -t 1024)") + parser.add_argument("-e", "--expert", type=int, default=8, help="Number of experts (e.g. -e 8)") + parser.add_argument("-k", "--topk", type=int, default=2, help="Top-k (e.g. -k 2)") + parser.add_argument("-s", "--doweight_stage1", type=_str2bool, nargs="?", const=True, default=False, help="Whether to multiply routed weight in stage1 (t/f).") + + # Stage1-specific kernel tiling knobs + parser.add_argument("--tile_m", type=int, default=16, help="Tile M / block_m (routing block size).") + parser.add_argument("--tile_n", type=int, default=256, help="Tile N (inter dim tile).") + parser.add_argument("--tile_k", type=int, default=512, help="Tile K (model dim tile).") + parser.add_argument("--tile_n2", type=int, default=None, help="Stage2 tile N (model dim tile). Default: 2*tile_n.") + parser.add_argument("--tile_k2", type=int, default=None, help="Stage2 tile K (inter dim tile). Default: tile_k.") + + parser.add_argument("--moe_sort_mode", type=str, default=None, choices=["aiter", "torch"], help="Routing buffer build mode (aiter moe_sorting vs torch fallback).") + parser.add_argument("--skip_ref", type=_str2bool, nargs="?", const=True, default=False, help="Skip torch reference correctness checks (benchmark-only).") + parser.add_argument( + "--gemm2_mode", + type=str, + default="both", + choices=["both", "atomic", "reduce"], + help="Stage2 accumulation mode: 'atomic', 'reduce', or 'both' (default: both).", + ) + parser.add_argument( + "--out_dtype", + type=str, + default="f16", + choices=["f16", "f32"], + help="Stage2 output dtype: f16 (half2 atomics) or f32 (scalar fp32 atomics).", + ) + parser.add_argument("--use_valid_mask", type=_str2bool, nargs="?", const=True, default=False, help="Use valid mask for optimization when reduce or not.") + + # Benchmark knobs + parser.add_argument("--benchmark", action="store_true", default=False, + help="Use GEMM-style per-iteration event timing with optional L2 flush.") + parser.add_argument("--no_flush_l2", action="store_true", default=False, + help="Disable L2 flush in benchmark mode.") + parser.add_argument("--num_buffers", type=int, default=1, choices=[1, 2, 3, 4], + help="Requested MXScale pipeline buffers for gfx1250 MoE kernels.") + parser.add_argument("--use_tdm_store", type=_str2bool, nargs="?", const=True, default=False, + help="Requested TDM store epilogue for gfx1250 MoE kernels.") + parser.add_argument("--inst_prefetch", type=_str2bool, nargs="?", const=True, default=False, + help="Enable instruction prefetch for gfx1250 MoE kernels.") + parser.add_argument("--wave_specialized_tdm", type=_str2bool, nargs="?", const=True, default=False, + help="Enable wave-specialized TDM loading for gfx1250 MoE kernels.") + parser.add_argument("--cluster_m", type=int, default=1, + help="Requested cluster_m for gfx1250 MoE kernels.") + parser.add_argument("--cluster_n", type=int, default=1, + help="Requested cluster_n for gfx1250 MoE 
kernels.") + parser.add_argument("--seed", type=int, default=0, help="torch.manual_seed(seed)") + parser.add_argument("--num_iters", type=int, default=2, help="Benchmark iters") + parser.add_argument("--num_warmup", type=int, default=1, help="Benchmark warmup iters") + + # graph mode test + parser.add_argument( + "--test_graph", + "-tg", + action="store_true", + default=False, + help="test with graph mode.", + ) + + # ── Benchmark sweep mode (--bench) ── + add_moe_bench_args(parser) + + args = parser.parse_args() + + # ── Bench sweep mode: run and exit ── + if args.bench: + moe_bench_main( + args, + stage1_fn=run_moe_stage1, + stage2_fn=run_moe_stage2, + setup_data_fn=_bench_setup_data, + prepare_a2_fn=_bench_prepare_a2, + ) + sys.exit(0) + + model_dim, inter_dim = args.dim + + tile_n2 = int(args.tile_n2) if args.tile_n2 is not None else int(args.tile_n) * 2 + tile_k2 = int(args.tile_k2) if args.tile_k2 is not None else args.tile_k + + # Determine which gemm2 modes to run. + if args.gemm2_mode == "both": + reduce_flags = [False, True] + elif args.gemm2_mode == "reduce": + reduce_flags = [True] + else: # "atomic" + reduce_flags = [False] + + # Common CLI arguments shared across stage1/stage2/2stage calls. + _common = dict( + tokens=int(args.tokenNum), model_dim=int(model_dim), inter_dim=int(inter_dim), + experts=int(args.expert), topk=int(args.topk), + doweight_stage1=bool(args.doweight_stage1), + tile_m=int(args.tile_m), + seed=int(args.seed), num_iters=int(args.num_iters), num_warmup=int(args.num_warmup), + moe_sort_mode=args.moe_sort_mode, + skip_ref=bool(args.skip_ref), + test_graph=bool(args.test_graph), + benchmark_mode=bool(args.benchmark), + flush_l2=not bool(args.no_flush_l2), + num_buffers=int(args.num_buffers), + use_tdm_store=bool(args.use_tdm_store), + inst_prefetch=bool(args.inst_prefetch), + wave_specialized_tdm=bool(args.wave_specialized_tdm), + cluster_m=int(args.cluster_m), cluster_n=int(args.cluster_n), + ) + + def run_one(dt: str, use_reduce: bool): + out_s = str(args.out_dtype).strip().lower() + if bool(use_reduce) and out_s in ("f32", "fp32", "float"): + print("[skip] reduce mode does not support out_dtype='f32'") + return + if (not bool(use_reduce)) and bool(args.use_valid_mask): + print("[skip] valid_mask is only used in reduce mode (atomic ignores it)") + return + run_moe_gemm_2stage( + **_common, in_dtype=dt, + out_dtype=str(args.out_dtype), + tile_n1=int(args.tile_n), tile_k1=int(args.tile_k), + tile_n2=tile_n2, tile_k2=tile_k2, + use_reduce=use_reduce, + use_valid_mask=bool(args.use_valid_mask), + ) + print(f"PASSED: dtype={dt} reduce={use_reduce}") + + # Run 2-stage (gemm1 -> quantize -> gemm2) aiter-style test/benchmark. + # Expand "all" to all supported dtypes. 
+ in_dtypes = args.in_dtype.split(",") + if "all" in in_dtypes: + in_dtypes = ["fp16", "bf16"] + for dt in in_dtypes: + for use_reduce in reduce_flags: + run_one(dt, use_reduce) + + + +# --------------------------------------------------------------------------- +# Smoke tests +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("waves_per_eu", [1, 2], ids=["wpe1", "wpe2"]) +def test_moe_2stage_waves_per_eu_smoke(waves_per_eu: int): + """Smoke test for stage1/stage2 waves_per_eu plumbing on gfx1250.""" + _shape = dict(tokens=32, model_dim=256, inter_dim=128, experts=4, topk=2, tile_m=16) + _fast = dict(num_iters=1, num_warmup=1, return_outputs=True, skip_ref=True, + in_dtype="fp16", doweight_stage1=False, waves_per_eu=waves_per_eu) + + stage1_out, _ = run_moe_stage1(**_shape, tile_n=64, tile_k=128, **_fast) + stage2_out, _ = run_moe_stage2( + **_shape, tile_n=64, tile_k=128, out_dtype="f16", + a2_fp8_in=stage1_out.to(torch.float16), a2_scale_in=None, **_fast, + ) + assert torch.isfinite(stage1_out).all() + assert torch.isfinite(stage2_out).all() + + +# --------------------------------------------------------------------------- +# Main parametrized 2-stage test — WMMA dtypes (fp16 / bf16) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize( + "tokens, model_dim, inter_dim, experts, topk, tile_m, tile_n1, tile_k1, tile_n2, tile_k2, doweight_stage1", + [ + pytest.param(64, 256, 128, 4, 2, 16, 64, 128, 64, 128, False, id="S"), + pytest.param(129, 1024, 256, 8, 2, 32, 128, 128, 128, 128, False, id="M"), + pytest.param(333, 4096, 2048, 17, 9, 64, 128, 128, 256, 128, False, id="L", marks=pytest.mark.large_shape), + ], +) +@pytest.mark.parametrize("in_dtype", ["fp16", "bf16"]) +@pytest.mark.parametrize("out_dtype", ["f16", "f32"], ids=["out_f16", "out_f32"]) +@pytest.mark.parametrize("use_reduce", [False, True], ids=["atomic", "reduce"]) +@pytest.mark.parametrize("use_valid_mask", [False, True], ids=["nomask", "mask"]) +@pytest.mark.parametrize("test_graph", [ + pytest.param(False, id="eager"), + pytest.param(True, id="graph"), +]) +def test_moe_gemm_2stage( + tokens: int, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n1: int, + tile_k1: int, + tile_n2: int, + tile_k2: int, + doweight_stage1: bool, + in_dtype: str, + out_dtype: str, + use_reduce: bool, + use_valid_mask: bool, + test_graph: bool, +): + run_moe_gemm_2stage( + tokens=tokens, + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n1=tile_n1, + tile_k1=tile_k1, + tile_n2=tile_n2, + tile_k2=tile_k2, + doweight_stage1=doweight_stage1, + in_dtype=in_dtype, + out_dtype=out_dtype, + use_reduce=use_reduce, + use_valid_mask=use_valid_mask, + test_graph=test_graph, + ) From 6e635c6e4c47b135cfe33080f4eb2d45d67dcda4 Mon Sep 17 00:00:00 2001 From: yanboshao Date: Thu, 16 Apr 2026 11:20:03 +0800 Subject: [PATCH 14/29] Add CI testcases and benchmark for allreduce (#387) --------- Co-authored-by: yashao@amd.com Co-authored-by: Felix Li Co-authored-by: Claude Opus 4.6 (1M context) --- .github/workflows/flydsl.yaml | 179 +++++++- README.md | 2 +- docs/prebuilt_kernels_guide.md | 2 +- kernels/custom_all_reduce.py | 12 +- kernels/custom_all_reduce_kernel.py | 423 ++++++++++++------ python/flydsl/_version.py | 2 +- python/flydsl/expr/__init__.py | 2 +- python/flydsl/expr/arith.py | 20 - python/flydsl/expr/buffer_ops.py | 33 +- python/flydsl/expr/mem_ops.py | 
199 -------- python/flydsl/expr/vector.py | 22 +- tests/arch_compat.py | 1 + tests/kernels/compare_allreduce_benchmark.py | 90 ++++ ..._flydsl_allreduce.py => test_allreduce.py} | 199 +++++++- tests/pytest.ini | 2 + 15 files changed, 805 insertions(+), 383 deletions(-) delete mode 100644 python/flydsl/expr/mem_ops.py create mode 100644 tests/kernels/compare_allreduce_benchmark.py rename tests/kernels/{test_flydsl_allreduce.py => test_allreduce.py} (79%) diff --git a/.github/workflows/flydsl.yaml b/.github/workflows/flydsl.yaml index 6b54a304..2896e9f3 100644 --- a/.github/workflows/flydsl.yaml +++ b/.github/workflows/flydsl.yaml @@ -9,6 +9,11 @@ on: - main workflow_dispatch: +permissions: + contents: read + actions: read + pull-requests: read + concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true @@ -19,10 +24,18 @@ env: GITHUB_COMMIT_SHA: ${{ github.event.pull_request.head.sha || github.event.head_commit.id }} jobs: + # --------------------------------------------------------------------------- + # Single-GPU tests: kernels, unit, examples, MLIR FileCheck, benchmarks. + # Runs on 1-GPU and Navi runners only. + # --------------------------------------------------------------------------- test: strategy: matrix: - runners: [ 'linux-flydsl-mi325-1', 'linux-flydsl-mi355-1', 'linux-flydsl-navi-2' ] + runners: [ + 'linux-flydsl-mi325-1', + 'linux-flydsl-mi355-1', + 'linux-flydsl-navi-2', + ] fail-fast: false runs-on: ${{ matrix.runners }} steps: @@ -169,3 +182,167 @@ jobs: run: | docker stop flydsl_test docker rm flydsl_test + + # --------------------------------------------------------------------------- + # Multi-GPU allreduce tests: ONLY for 8-GPU runners. + # Runs on BOTH linux-flydsl-mi325-8 AND linux-flydsl-mi355-8 independently. + # fail-fast: false ensures both runners always complete even if one fails. + # --------------------------------------------------------------------------- + multi-gpu: + needs: test + name: Multi-GPU AllReduce Tests (${{ matrix.runners }}) + timeout-minutes: 120 + strategy: + matrix: + runners: [ + 'linux-flydsl-mi325-8', + 'linux-flydsl-mi355-8', + ] + fail-fast: false + runs-on: ${{ matrix.runners }} + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + repository: ${{ env.GITHUB_REPO_NAME }} + ref: ${{ env.GITHUB_COMMIT_SHA }} + path: flydsl-test + + - name: Start CI container + run: | + echo "Clean up containers..." + docker ps -aq -f name=flydsl_test | xargs -r docker stop | xargs -r docker rm || true + + echo "Start CI container..." 
+ if [ -f "/etc/podinfo/gha-render-devices" ]; then + DEVICE_FLAG=$(cat /etc/podinfo/gha-render-devices) + else + DEVICE_FLAG="--device /dev/dri" + fi + + docker run -dt --network=host --user root --device=/dev/kfd $DEVICE_FLAG \ + -v "${GITHUB_WORKSPACE:-$PWD}/flydsl-test:/flydsl-test" \ + --ipc=host --group-add video \ + --shm-size 16g \ + --cap-add=SYS_PTRACE \ + --security-opt seccomp=unconfined \ + -w /flydsl-test \ + --name flydsl_test \ + ${{ env.DOCKER_IMAGE }} + env: + GITHUB_WORKSPACE: ${{ github.workspace }} + + - name: Install dependencies + run: | + docker exec flydsl_test bash -c "apt-get update && apt-get install -y cmake build-essential patchelf" + docker exec flydsl_test bash -c "python3 -m pip install -U pip setuptools wheel" + docker exec flydsl_test bash -c "python3 -m pip install ninja>=1.11.1" + docker exec flydsl_test bash -c "python3 -m pip install -U 'hypothesis>=6.82.0'" + docker exec flydsl_test bash -c "git config --global --add safe.directory /flydsl-test && cd /flydsl-test && git log" + + - name: Restore cached MLIR install tarball (if available) + id: mlir-cache + uses: actions/cache@v4 + with: + path: mlir_install.tgz + key: mlir-install-${{ matrix.runners }}-${{ hashFiles('flydsl-test/thirdparty/llvm-hash.txt', 'flydsl-test/scripts/build_llvm.sh', 'flydsl-test/CMakeLists.txt', 'flydsl-test/.github/workflows/flydsl.yaml') }} + + - name: Use cached MLIR install tarball (skip LLVM build) + if: steps.mlir-cache.outputs.cache-hit == 'true' + run: | + ls -lh mlir_install.tgz + docker cp mlir_install.tgz flydsl_test:/tmp/mlir_install.tgz + docker exec flydsl_test bash -c "rm -rf /llvm-project/mlir_install && mkdir -p /llvm-project && tar -xzf /tmp/mlir_install.tgz -C /llvm-project" + docker exec flydsl_test bash -c "ls -la /llvm-project/mlir_install/lib/cmake/mlir" + + - name: Build LLVM + if: steps.mlir-cache.outputs.cache-hit != 'true' + run: | + set -ex + docker exec flydsl_test bash -c "cd /flydsl-test && bash scripts/build_llvm.sh" + docker exec flydsl_test bash -c "ls -la /llvm-project/mlir_install/lib/cmake/mlir" + docker cp flydsl_test:/llvm-project/mlir_install.tgz ./mlir_install.tgz || true + + - name: Build FlyDSL (uses MLIR install prefix) + run: | + docker exec flydsl_test bash -c "export MLIR_PATH=/llvm-project/mlir_install && cd /flydsl-test && python3 -m pip install -e . --use-pep517" + + - name: Run multi-GPU allreduce tests + timeout-minutes: 30 + run: | + docker exec flydsl_test bash -c " + cd /flydsl-test + python3 -m pytest tests/kernels/test_allreduce.py \ + -m multi_gpu -v --no-header --tb=short + " + + - name: Run allreduce benchmark (PR) + timeout-minutes: 30 + run: | + docker exec flydsl_test bash -c " + cd /flydsl-test + python3 tests/kernels/test_allreduce.py \ + --world_size 8 --iters 51 --warmup 5 \ + --allreduce_impl flydsl --mode cudagraph \ + --shapes '2,7168,fp16;32,8192,fp32;128,8192,fp16;1024,7168,bf16;4096,8192,bf16' \ + --output_csv /tmp/bench_pr.csv + " + + - name: Build main branch baseline + id: build-main + timeout-minutes: 20 + continue-on-error: true + run: | + docker exec flydsl_test bash -c " + cd /flydsl-test + git fetch origin main --depth=1 + git worktree add /tmp/flydsl-main origin/main + cd /tmp/flydsl-main + export MLIR_PATH=/llvm-project/mlir_install + python3 -m pip install -e . 
--use-pep517 2>&1 | tail -5 + " + + - name: Run allreduce benchmark (main) + id: bench-main + if: steps.build-main.outcome == 'success' + timeout-minutes: 30 + continue-on-error: true + run: | + docker exec flydsl_test bash -c " + cp /flydsl-test/tests/kernels/test_allreduce.py \ + /tmp/flydsl-main/tests/kernels/test_allreduce.py + cd /tmp/flydsl-main + python3 tests/kernels/test_allreduce.py \ + --world_size 8 --iters 51 --warmup 5 \ + --allreduce_impl flydsl --mode cudagraph \ + --shapes '2,7168,fp16;32,8192,fp32;128,8192,fp16;1024,7168,bf16;4096,8192,bf16' \ + --output_csv /tmp/bench_main.csv + " + + - name: Check performance regression (PR vs main) + if: steps.bench-main.outcome != 'skipped' + timeout-minutes: 5 + run: | + docker exec flydsl_test bash -c " + cd /flydsl-test + python3 tests/kernels/compare_allreduce_benchmark.py \ + /tmp/bench_main.csv /tmp/bench_pr.csv + " + + - name: Show test logs + if: failure() + run: | + docker exec flydsl_test bash -c 'cd /tmp && tar czf /tmp/logs.tgz *.log 2>/dev/null || echo "no logs"' + docker cp flydsl_test:/tmp/logs.tgz . || true + if [ -f logs.tgz ]; then + tar -xzf logs.tgz || true + cat *.log || true + else + echo "logs.tgz not found; skipping log extraction" + fi + + - name: Clean up + if: always() + run: | + docker stop flydsl_test + docker rm flydsl_test diff --git a/README.md b/README.md index d040c81e..b6ddda36 100644 --- a/README.md +++ b/README.md @@ -363,7 +363,7 @@ See `examples/` for more examples including tiled copy (`02-tiledCopy.py`), tile | **RMSNorm** | `test_rmsnorm.py` | RMSNorm (layout API) | | **Softmax** | `test_softmax.py` | Softmax (layout API) | | **Fused RoPE** | `test_fused_rope_cache.py` | Fused RoPE + KV cache | -| **AllReduce** | `test_flydsl_allreduce.py` | Multi-GPU all-reduce | +| **AllReduce** | `test_allreduce.py` | Multi-GPU all-reduce | | **RDNA GEMM** | `test_rdna_gemm.py` | RDNA FP16/FP8 GEMM | | **GFX1250 GEMM** | `test_gemm_fp8fp4_gfx1250.py` | GFX1250 FP8/FP4 GEMM | | **WMMA GEMM** | `test_wmma_gemm_gfx1250.py` | GFX1250 WMMA GEMM | diff --git a/docs/prebuilt_kernels_guide.md b/docs/prebuilt_kernels_guide.md index 4d3745b5..018b122f 100644 --- a/docs/prebuilt_kernels_guide.md +++ b/docs/prebuilt_kernels_guide.md @@ -338,7 +338,7 @@ What operation do you need? | `tests/kernels/test_rmsnorm.py` | RMSNorm | | `tests/kernels/test_softmax.py` | Softmax | | `tests/kernels/test_fused_rope_cache.py` | Fused RoPE + KV cache | -| `tests/kernels/test_flydsl_allreduce.py` | Multi-GPU all-reduce | +| `tests/kernels/test_allreduce.py` | Multi-GPU all-reduce | | `tests/kernels/test_rdna_gemm.py` | RDNA GEMM | | `tests/kernels/test_gemm_fp8fp4_gfx1250.py` | GFX1250 FP8/FP4 GEMM | | `tests/kernels/test_wmma_gemm_gfx1250.py` | GFX1250 WMMA GEMM | diff --git a/kernels/custom_all_reduce.py b/kernels/custom_all_reduce.py index 5f6b7631..a00c0a1b 100644 --- a/kernels/custom_all_reduce.py +++ b/kernels/custom_all_reduce.py @@ -268,6 +268,14 @@ def __init__(self, *, group, device, max_size: int, world_size: int, rank: int, if self.world_size not in {2, 4, 8}: raise ValueError(f"world_size must be one of {{2, 4, 8}}, got {self.world_size}") + # Pre-initialize resource attributes so close() is safe on partial init failure. 
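+        # (e.g. if __init__ raises after this point, a caller that still invokes
+        #  close() should not hit AttributeError on these fields.)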
+ self._meta_ptr = None + self._meta_bases = [None] * self.world_size + self._input_buffer_bases = [None] * self.world_size + self._output_buffer_bases = [None] * self.world_size + self._graph_ipc_reg_list = [] + self._out_ptrs_cache = None + alloc_size = self._SIGNAL_SIZE + int(self.max_size) self._meta_ptr = self._alloc_uncached(alloc_size) @@ -373,7 +381,9 @@ def __init__(self, *, group, device, max_size: int, world_size: int, rank: int, def close(self): """Release IPC memory handles for peer GPU buffers.""" - for bases in [self._meta_bases, self._input_buffer_bases, self._output_buffer_bases]: + for bases in [getattr(self, '_meta_bases', []), + getattr(self, '_input_buffer_bases', []), + getattr(self, '_output_buffer_bases', [])]: for b in bases: if b is not None: self._close_mem_handle(int(b)) diff --git a/kernels/custom_all_reduce_kernel.py b/kernels/custom_all_reduce_kernel.py index 06bbaa47..cb753819 100644 --- a/kernels/custom_all_reduce_kernel.py +++ b/kernels/custom_all_reduce_kernel.py @@ -10,15 +10,135 @@ from __future__ import annotations +import math + import flydsl.compiler as flyc -from flydsl.expr import arith as ea, gpu, range_constexpr, mem_ops, vector as ev +from flydsl.expr import arith as ea, gpu, range_constexpr, vector as ev, buffer_ops from flydsl.expr.typing import T, Int32, Int64, Stream from flydsl._mlir import ir -from flydsl._mlir.dialects import scf +from flydsl._mlir.dialects import scf, llvm, rocdl from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr from kernels.custom_all_reduce import _KMAXBLOCKS as _MAX_BLOCKS + +# --------------------------------------------------------------------------- +# Low-level memory helpers — all operate on raw i64 device addresses. +# +# Cache modifier bits for buffer_load / buffer_store (AMD GFX942 aux field): +# bit 0 = SC0 — bypass L1/TCP cache +# bit 1 = SC1 — bypass L2/TCC cache +# bit 2 = NT — nontemporal (bypass hardware prefetcher) +# --------------------------------------------------------------------------- +_CM_CACHED = 0 # normal cached access +_CM_SC1 = 2 # bypass L2 only (reads from signal bufs across GPUs) +_CM_SC0_SC1 = 3 # bypass L1+L2 (writes to signal bufs: fully uncached) +_CM_NT = 4 # nontemporal (bulk data writes, bypasses L2 prefetch) + + +# ---- buffer resource descriptor helper ------------------------------------ + +def _make_rsrc(addr_i64): + """Create buffer resource descriptor from a wave-uniform i64 base address.""" + return buffer_ops.create_buffer_resource_from_addr(addr_i64) + + +# ---- bulk data: 16-byte (128-bit) load / store ---------------------------- +# These accept a pre-built rsrc descriptor and a per-lane element offset (i32). 
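+#
+# Typical usage (illustrative sketch; `base_addr_i64` and `pack_idx` are
+# placeholder names, not values defined in this module):
+#
+#   rsrc = _make_rsrc(base_addr_i64)                            # i64 base -> descriptor
+#   off  = pack_idx * ea.constant(_ELEMS_PER_PACK, type=T.i32)  # offset in i32 elements
+#   data = _load_v4i32(rsrc, off)                               # cached 16-byte load
+#   _store_v4i32_nt(rsrc, off, data)                            # nontemporal 16-byte store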
+ +def _load_v4i32(rsrc, elem_off_i32): + """Buffer-load vector<4xi32> (16 bytes) with pre-built descriptor.""" + return buffer_ops.buffer_load(rsrc, elem_off_i32, + vec_width=4, dtype=T.i32) + + +def _store_v4i32(rsrc, elem_off_i32, data): + """Buffer-store vector<4xi32> (16 bytes), cached.""" + buffer_ops.buffer_store(data, rsrc, elem_off_i32, + cache_modifier=_CM_CACHED) + + +def _store_v4i32_nt(rsrc, elem_off_i32, v4i32_val): + """Buffer-store vector<4xi32> nontemporal — bypasses L2 prefetcher.""" + buffer_ops.buffer_store(v4i32_val, rsrc, elem_off_i32, + cache_modifier=_CM_NT) + rocdl.s_waitcnt(0) + + +# ---- signal buffer: i32 load / store -------------------------------------- + +def _store_i32(rsrc, val_i32): + """Store i32 with default caching via pre-built rsrc descriptor.""" + buffer_ops.buffer_store(val_i32, rsrc, ea.constant(0, type=T.i32), + cache_modifier=_CM_CACHED) + + +def _load_i32_uncached(rsrc): + """Load i32 bypassing L2 (sc1) via pre-built rsrc descriptor.""" + val = buffer_ops.buffer_load(rsrc, ea.constant(0, type=T.i32), + vec_width=1, dtype=T.i32, + cache_modifier=_CM_SC1) + rocdl.s_waitcnt(0) + return val + + +def _store_i32_uncached(rsrc, val_i32): + """Store i32 bypassing L1+L2 (sc0+sc1) via pre-built rsrc descriptor.""" + buffer_ops.buffer_store(val_i32, rsrc, ea.constant(0, type=T.i32), + cache_modifier=_CM_SC0_SC1) + rocdl.s_waitcnt(0) + + +def _invalidate_l1(): + """Invalidate L1 scalar cache (buffer_inv sc1). + + Call inside a polling loop after an uncached load to discard stale L1 + lines so the next iteration sees fresh data from L2/HBM. + """ + llvm.InlineAsmOp(None, [], "buffer_inv sc1", "", has_side_effects=True) + + +def _store_i32_uncached_flush(rsrc, val_i32): + """Store i32 with L2 writeback then sc0+sc1 store via pre-built rsrc. + + buffer_wbl2 flushes dirty L2 lines to HBM before the signal store. + """ + llvm.InlineAsmOp(None, [], "buffer_wbl2 sc0 sc1", "", has_side_effects=True) + buffer_ops.buffer_store(val_i32, rsrc, ea.constant(0, type=T.i32), + cache_modifier=_CM_SC0_SC1) + rocdl.s_waitcnt(0) + + +# ---- pointer array helpers ----------------------------------------------- + +def _pack_i64_vec(values): + """Pack preloaded i64 values into vector for contiguous VGPR storage. + + On AMDGPU the subsequent ``ev.extract`` with a dynamic index lowers + through ``ConvertVectorToLLVM`` to ``llvm.extractelement`` which the + backend emits as ``v_movrels_b32`` (VGPR-relative addressing, ~3 insns) + instead of a chained ``arith.select`` costing 2*(N-1) insns. + """ + vec_type = T.vec(len(values), T.i64) + return ev.from_elements(vec_type, values) + + +def _extract_i64(vec, index): + """Extract i64 from a packed vector by dynamic index (VGPR-relative).""" + idx = ea.index_cast(T.index, index) + return ev.extract(vec, dynamic_position=[idx]) + + +def _load_device_ptr(array_base_i64, index): + """Load i64 pointer from a device-side pointer array at *index*. + + Uses buffer_load(dtype=i64): offset is in elements so buffer_load + automatically scales by 8 bytes internally. + """ + rsrc = buffer_ops.create_buffer_resource_from_addr(array_base_i64) + return buffer_ops.buffer_load(rsrc, index, vec_width=1, dtype=T.i64) + + # Signal buffer layout offsets (bytes), derived from _MAX_BLOCKS. 
# start[_MAX_BLOCKS][8] of uint32 | end[_MAX_BLOCKS][8] of uint32 | flag[_MAX_BLOCKS] of uint32 _SG_START_OFF_B = 0 @@ -30,6 +150,20 @@ # Element type helpers # --------------------------------------------------------------------------- +_BYTES_PER_PACK = 16 # sizeof(vector<4xi32>), the atomic load/store unit +_ELEMS_PER_PACK = _BYTES_PER_PACK // 4 # i32 elements per pack + + +def _elem_bytes(dtype_str: str) -> int: + """Return byte width of one scalar element for the given dtype.""" + d = (dtype_str or "").strip().lower() + if d in {"f32", "fp32"}: + return 4 + if d in {"f16", "fp16", "bf16"}: + return 2 + raise ValueError(f"unsupported dtype_str: {dtype_str!r}") + + def _elem_type(dtype_str: str) -> ir.Type: d = (dtype_str or "").strip().lower() if d in {"f16", "fp16"}: @@ -42,12 +176,8 @@ def _elem_type(dtype_str: str) -> ir.Type: def _pack_elems(dtype_str: str) -> int: - d = (dtype_str or "").strip().lower() - if d in {"f32", "fp32"}: - return 4 - if d in {"f16", "fp16", "bf16"}: - return 8 - raise ValueError(f"unsupported dtype_str: {dtype_str!r}") + """Number of elements per pack, derived from _BYTES_PER_PACK.""" + return _BYTES_PER_PACK // _elem_bytes(dtype_str) def _u(v): @@ -61,18 +191,18 @@ def _u(v): def _signal_start_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, ngpus: int): """Start-sync: write start flag to all peers, wait for all to arrive.""" - - i32, i64 = T.i32, T.i64 flag_addr = (self_sg_i64 + ea.constant(_SG_FLAG_OFF_B, type=i64) + bid_i32.extui(i64) * ea.constant(4, type=i64)) - flag = mem_ops.load_i32_uncached(flag_addr) + ea.constant(1, type=i32) + flag_rsrc = _make_rsrc(flag_addr) + flag = _load_i32_uncached(flag_rsrc) + ea.constant(1, type=i32) bid8 = bid_i32 * ea.constant(8, type=i32) lin_lane = bid8 + lane_i32 start_wait_addr = (self_sg_i64 + ea.constant(_SG_START_OFF_B, type=i64) + lin_lane.extui(i64) * ea.constant(4, type=i64)) + wait_rsrc = _make_rsrc(start_wait_addr) lin_rank = bid8 + rank_i32 start_rank_off = (ea.constant(_SG_START_OFF_B, type=i64) + lin_rank.extui(i64) * ea.constant(4, type=i64)) @@ -80,9 +210,10 @@ def _signal_start_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, ngp is_lane = _u(lane_i32) < ea.constant(ngpus, type=i32) if_op = scf.IfOp(is_lane, results_=[], has_else=False) with ir.InsertionPoint(if_op.then_block): - peer_sg = ea.select_by_index(lane_i32, sgs_i64) - mem_ops.store_i32_uncached_flush(peer_sg + start_rank_off, flag) - init_cur = mem_ops.load_i32_uncached(start_wait_addr) + peer_sg = _extract_i64(_pack_i64_vec(sgs_i64), lane_i32) + peer_rsrc = _make_rsrc(peer_sg + start_rank_off) + _store_i32_uncached(peer_rsrc, flag) + init_cur = _load_i32_uncached(wait_rsrc) w = scf.WhileOp([i32], [init_cur]) wb = ir.Block.create_at_start(w.before, [i32]) wa = ir.Block.create_at_start(w.after, [i32]) @@ -91,43 +222,34 @@ def _signal_start_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, ngp need_wait = _u(cur) < flag scf.ConditionOp(need_wait, [cur]) with ir.InsertionPoint(wa): - scf.YieldOp([mem_ops.load_i32_uncached(start_wait_addr)]) + scf.YieldOp([_load_i32_uncached(wait_rsrc)]) scf.YieldOp([]) gpu.barrier() is_t0 = lane_i32 == ea.constant(0, type=i32) if_t0 = scf.IfOp(is_t0, results_=[], has_else=False) with ir.InsertionPoint(if_t0.then_block): - mem_ops.store_i32(flag_addr, flag) + _store_i32(flag_rsrc, flag) scf.YieldOp([]) return flag_addr def _signal_end_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, - ngpus: int, need_wbl2: bool = False): - """End-sync: write end flag to all peers, wait 
for all to finish. - - Args: - need_wbl2: True → use st_xgmi_u32 (buffer_wbl2 + signal store). - Required after cached stores (st_global_16b) so - that L2 dirty lines reach HBM before the signal. - False → use st_signal_u32 (signal store only, no wbl2). - For nt data stores (st_nt_16b) which already bypass - L2; uses ATOMIC_RELAXED + MEMORY_SCOPE_SYSTEM. - """ - + ngpus: int): + """End-sync: write end flag to all peers, wait for all to finish.""" i32, i64 = T.i32, T.i64 - gpu.barrier() flag_addr = (self_sg_i64 + ea.constant(_SG_FLAG_OFF_B, type=i64) + bid_i32.extui(i64) * ea.constant(4, type=i64)) - flag = mem_ops.load_i32_uncached(flag_addr) + ea.constant(1, type=i32) + flag_rsrc = _make_rsrc(flag_addr) + flag = _load_i32_uncached(flag_rsrc) + ea.constant(1, type=i32) bid8 = bid_i32 * ea.constant(8, type=i32) lin_lane = bid8 + lane_i32 end_wait_addr = (self_sg_i64 + ea.constant(_SG_END_OFF_B, type=i64) + lin_lane.extui(i64) * ea.constant(4, type=i64)) + wait_rsrc = _make_rsrc(end_wait_addr) lin_rank = bid8 + rank_i32 end_rank_off = (ea.constant(_SG_END_OFF_B, type=i64) + lin_rank.extui(i64) * ea.constant(4, type=i64)) @@ -135,12 +257,10 @@ def _signal_end_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, is_lane = _u(lane_i32) < ea.constant(ngpus, type=i32) if_op = scf.IfOp(is_lane, results_=[], has_else=False) with ir.InsertionPoint(if_op.then_block): - peer_sg = ea.select_by_index(lane_i32, sgs_i64) - if need_wbl2: - mem_ops.store_i32_uncached_flush(peer_sg + end_rank_off, flag) - else: - mem_ops.store_i32_uncached(peer_sg + end_rank_off, flag) - init_cur = mem_ops.load_i32_uncached(end_wait_addr) + peer_sg = _extract_i64(_pack_i64_vec(sgs_i64), lane_i32) + peer_rsrc = _make_rsrc(peer_sg + end_rank_off) + _store_i32_uncached(peer_rsrc, flag) + init_cur = _load_i32_uncached(wait_rsrc) w = scf.WhileOp([i32], [init_cur]) wb = ir.Block.create_at_start(w.before, [i32]) wa = ir.Block.create_at_start(w.after, [i32]) @@ -149,8 +269,8 @@ def _signal_end_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, need_wait = _u(cur) < flag scf.ConditionOp(need_wait, [cur]) with ir.InsertionPoint(wa): - nxt = mem_ops.load_i32_uncached(end_wait_addr) - mem_ops.invalidate_l1() + nxt = _load_i32_uncached(wait_rsrc) + _invalidate_l1() scf.YieldOp([nxt]) scf.YieldOp([]) @@ -158,7 +278,7 @@ def _signal_end_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, is_t0 = lane_i32 == ea.constant(0, type=i32) if_t0 = scf.IfOp(is_t0, results_=[], has_else=False) with ir.InsertionPoint(if_t0.then_block): - mem_ops.store_i32(flag_addr, flag) + _store_i32(flag_rsrc, flag) scf.YieldOp([]) @@ -251,10 +371,7 @@ def allreduce_1stage_arr( Each warp loads data from one rank into shared memory, then warp 0 reduces across all warps and writes the result to global memory. 
""" - - i32, i64 = T.i32, T.i64 - idx = ir.IndexType.get() v4i32 = T.i32x4 if is_f32: v4f32 = T.f32x4 @@ -272,14 +389,15 @@ def allreduce_1stage_arr( in_ptrs_i64 = in_ptrs.ir_value() out_ptr_i64 = out_ptr.ir_value() - sgs = [mem_ops.load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] - in_ptrs_arr = [mem_ops.load_device_ptr(in_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] + sgs = [_load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] + in_ptrs_arr = [_load_device_ptr(in_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] + in_ptrs_vec = _pack_i64_vec(in_ptrs_arr) smem_sym = f"allreduce_1s_smem_ws{world_size}_t{threads}" n_smem = 2 * threads allocator = SmemAllocator(None, global_sym_name=smem_sym) smem_off = allocator._align(allocator.ptr, 16) - allocator.ptr = smem_off + n_smem * 16 + allocator.ptr = smem_off + n_smem * _BYTES_PER_PACK with ir.InsertionPoint.at_block_begin(gpu_func_op.operation.block): allocator.finalize() smem_ptr = SmemPtr(allocator.get_base(), smem_off, v4i32, shape=(n_smem,)) @@ -297,6 +415,9 @@ def allreduce_1stage_arr( tid_pack = bid_i32 * tnum_gpu_i32 + lane_id stride_pack = gpu.grid_dim.x.ir_value() * tnum_gpu_i32 + out_rsrc = _make_rsrc(out_ptr_i64) + in_rsrc = _make_rsrc(_extract_i64(in_ptrs_vec, warp_id)) + loop = scf.WhileOp([i32, i32], [tid_pack, ea.constant(0, type=i32)]) bfor = ir.Block.create_at_start(loop.before, [i32, i32]) afor = ir.Block.create_at_start(loop.after, [i32, i32]) @@ -308,12 +429,10 @@ def allreduce_1stage_arr( p = afor.arguments[0] parity = afor.arguments[1] - # Each warp loads data from its rank into shared memory - in_base = ea.select_by_index(warp_id, in_ptrs_arr) - off16 = p.extui(i64) * ea.constant(16, type=i64) - raw = mem_ops.load_v4i32(in_base + off16) + off_i32 = p * ea.constant(_ELEMS_PER_PACK, type=i32) + raw = _load_v4i32(in_rsrc, off_i32) sm_base = parity * ea.constant(threads, type=i32) - sm_idx = ea.index_cast(idx, sm_base + lane_i32) + sm_idx = ea.index_cast(T.index, sm_base + lane_i32) smem_ptr.store(raw, [sm_idx]) gpu.barrier() @@ -324,21 +443,21 @@ def allreduce_1stage_arr( acc = None for wi in range_constexpr(world_size): sm_i_idx = ea.index_cast( - idx, ea.constant(wi, type=i32) * tnum_gpu_i32 + lane_id + sm_base) + T.index, ea.constant(wi, type=i32) * tnum_gpu_i32 + lane_id + sm_base) raw_i = smem_ptr.load([sm_i_idx]) if is_f32: - vf = raw_i.bitcast(v4f32) + vf = ev.bitcast(v4f32, raw_i) acc = vf if acc is None else acc + vf else: v16 = ev.bitcast(v8half, raw_i) v32 = v16.extf(v8f32) acc = v32 if acc is None else acc + v32 if is_f32: - out_bits = acc.bitcast(v4i32) + out_bits = ev.bitcast(v4i32, acc) else: out_bits = ev.bitcast(v4i32, acc.truncf(v8half)) - dst_off = p.extui(i64) * ea.constant(16, type=i64) - mem_ops.store_v4i32(out_ptr_i64 + dst_off, out_bits) + dst_off_i32 = p * ea.constant(_ELEMS_PER_PACK, type=i32) + _store_v4i32(out_rsrc, dst_off_i32, out_bits) scf.YieldOp([]) # No barrier 2 needed: parity double-buffer ensures next iteration @@ -362,10 +481,7 @@ def allreduce_2stage_arr( tmp_ptrs: Int64, out_ptr: Int64, ): - - i32, i64 = T.i32, T.i64 - idx = ir.IndexType.get() v4i32 = T.i32x4 if is_f32: v4f32 = T.f32x4 @@ -384,9 +500,10 @@ def allreduce_2stage_arr( tmp_ptrs_i64 = tmp_ptrs.ir_value() out_ptr_i64 = out_ptr.ir_value() - sgs = [mem_ops.load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] - in_ptrs_arr = [mem_ops.load_device_ptr(in_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] - tmp_ptrs_arr = 
[mem_ops.load_device_ptr(tmp_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] + sgs = [_load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] + in_ptrs_arr = [_load_device_ptr(in_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] + tmp_ptrs_arr = [_load_device_ptr(tmp_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] + in_ptrs_vec = _pack_i64_vec(in_ptrs_arr) # Compute pack range for this rank's reduce-scatter partition start_p = rank_i32 * ea.constant(part_p, type=i32) @@ -408,26 +525,27 @@ def allreduce_2stage_arr( smem_slots = threads if _use_single_buf_2stage else 2 * threads allocator = SmemAllocator(None, global_sym_name=smem_sym) smem_off = allocator._align(allocator.ptr, 16) - allocator.ptr = smem_off + smem_slots * 16 + allocator.ptr = smem_off + smem_slots * _BYTES_PER_PACK with ir.InsertionPoint.at_block_begin(gpu_func_op.operation.block): allocator.finalize() smem_ptr = SmemPtr(allocator.get_base(), smem_off, v4i32, shape=(smem_slots,)) smem_ptr.get() - tmp_out_i64 = tmp_ptrs_arr[0] + tmp_out_rsrc = _make_rsrc(tmp_ptrs_arr[0]) # ---- Stage 1: reduce-scatter ---- # Two implementations selected at compile time via _use_single_buf_2stage: # Single-buffer (large tensor): 8KB LDS, 2 barriers/iter, higher occupancy. # Double-buffer (small tensor): 16KB LDS, 1 barrier/iter (parity trick). + in_rsrc = _make_rsrc(_extract_i64(in_ptrs_vec, warp_id)) def _build_reduce_body(cur, smem_base_expr=None): """Emit reduce body: load → smem → barrier1 → warp0 reduce → [barrier2].""" - in_base = ea.select_by_index(warp_id, in_ptrs_arr) - raw = mem_ops.load_v4i32(in_base + cur.extui(i64) * ea.constant(16, type=i64)) + off_i32 = cur * ea.constant(_ELEMS_PER_PACK, type=i32) + raw = _load_v4i32(in_rsrc, off_i32) if smem_base_expr is None: - sm_idx = ea.index_cast(idx, lane_i32) + sm_idx = ea.index_cast(T.index, lane_i32) else: - sm_idx = ea.index_cast(idx, smem_base_expr + lane_i32) + sm_idx = ea.index_cast(T.index, smem_base_expr + lane_i32) smem_ptr.store(raw, [sm_idx]) gpu.barrier() # barrier 1: all warps have written smem @@ -437,24 +555,24 @@ def _build_reduce_body(cur, smem_base_expr=None): acc = None for wi in range_constexpr(world_size): if smem_base_expr is None: - sm_r_idx = ea.index_cast(idx, ea.constant(wi, type=i32) * tnum_gpu_i32 + lane_id) + sm_r_idx = ea.index_cast(T.index, ea.constant(wi, type=i32) * tnum_gpu_i32 + lane_id) else: - sm_r_idx = ea.index_cast(idx, ea.constant(wi, type=i32) * tnum_gpu_i32 + lane_id + smem_base_expr) + sm_r_idx = ea.index_cast(T.index, ea.constant(wi, type=i32) * tnum_gpu_i32 + lane_id + smem_base_expr) raw_i = smem_ptr.load([sm_r_idx]) if is_f32: - vf = raw_i.bitcast(v4f32) + vf = ev.bitcast(v4f32, raw_i) acc = vf if acc is None else acc + vf else: v16 = ev.bitcast(v8half, raw_i) v32 = v16.extf(v8f32) acc = v32 if acc is None else acc + v32 if is_f32: - out_raw = acc.bitcast(v4i32) + out_raw = ev.bitcast(v4i32, acc) else: out_raw = ev.bitcast(v4i32, acc.truncf(v8half)) rel_p = cur - start_p - mem_ops.store_v4i32(tmp_out_i64 + rel_p.extui(i64) * ea.constant(16, type=i64), - out_raw) + rel_off_i32 = rel_p * ea.constant(_ELEMS_PER_PACK, type=i32) + _store_v4i32(tmp_out_rsrc, rel_off_i32, out_raw) scf.YieldOp([]) idx_p = start_p + tid_pack @@ -492,14 +610,18 @@ def _build_reduce_body(cur, smem_base_expr=None): # smem half, so warp-0 reads and all-warp writes are disjoint. 
scf.YieldOp([cur + stride_pack, ea.constant(1, type=i32) - parity]) + gpu.barrier() _signal_end_sync(lane_i32=lane_i32, rank_i32=rank_i32, bid_i32=bid_i32, self_sg_i64=self_sg_i64, sgs_i64=sgs, ngpus=world_size) # ---- Stage 2: all-gather ---- + out_rsrc = _make_rsrc(out_ptr_i64) + if vec_ok: + tmp_ptrs_vec = _pack_i64_vec(tmp_ptrs_arr) tid_pack2 = bid_i32 * tnum_gpu_i32 + lane_id stride_pack2 = gpu.grid_dim.x.ir_value() * tnum_gpu_i32 - + tmp_rsrc = _make_rsrc(_extract_i64(tmp_ptrs_vec, warp_id)) loop2 = scf.WhileOp([i32], [tid_pack2]) b2 = ir.Block.create_at_start(loop2.before, [i32]) a2 = ir.Block.create_at_start(loop2.after, [i32]) @@ -514,14 +636,14 @@ def _build_reduce_body(cur, smem_base_expr=None): dst_rank = sum_rw & ea.constant(world_size - 1, type=i32) else: dst_rank = _u(sum_rw) % ea.constant(world_size, type=i32) - tmp_base = ea.select_by_index(warp_id, tmp_ptrs_arr) - raw = mem_ops.load_v4i32(tmp_base + cur.extui(i64) * ea.constant(16, type=i64)) + src_off_i32 = cur * ea.constant(_ELEMS_PER_PACK, type=i32) + raw = _load_v4i32(tmp_rsrc, src_off_i32) dst_pack = dst_rank * ea.constant(part_p, type=i32) + cur - mem_ops.store_v4i32(out_ptr_i64 + dst_pack.extui(i64) * ea.constant(16, type=i64), - raw) + dst_off_i32 = dst_pack * ea.constant(_ELEMS_PER_PACK, type=i32) + _store_v4i32(out_rsrc, dst_off_i32, raw) scf.YieldOp([cur + stride_pack2]) else: - # Non-vectorized fallback (world_size=6 or num_packs % world_size != 0) + tmp_rsrcs = [_make_rsrc(tmp_ptrs_arr[i]) for i in range(world_size)] tid_i32 = bid_i32 * ea.constant(threads, type=i32) + lane_i32 stride_i32 = gpu.grid_dim.x.ir_value() * ea.constant(threads, type=i32) @@ -541,11 +663,11 @@ def _build_reduce_body(cur, smem_base_expr=None): ok = _u(cur) < ea.constant(part_p, type=i32) ifp = scf.IfOp(ok, results_=[], has_else=False) with ir.InsertionPoint(ifp.then_block): - src_off = cur.extui(i64) * ea.constant(16, type=i64) - raw = mem_ops.load_v4i32(tmp_ptrs_arr[p] + src_off) + src_off_i32 = cur * ea.constant(_ELEMS_PER_PACK, type=i32) + raw = _load_v4i32(tmp_rsrcs[p], src_off_i32) dst_pack_idx = ea.constant(p * part_p, type=i32) + cur - dst_off = dst_pack_idx.extui(i64) * ea.constant(16, type=i64) - mem_ops.store_v4i32(out_ptr_i64 + dst_off, raw) + dst_off_i32 = dst_pack_idx * ea.constant(_ELEMS_PER_PACK, type=i32) + _store_v4i32(out_rsrc, dst_off_i32, raw) scf.YieldOp([]) scf.YieldOp([cur + stride_i32]) @@ -562,11 +684,7 @@ def allreduce_2stage_write_mode( out_ptrs: Int64, tmp_ptrs: Int64, ): - import math - - i32, i64 = T.i32, T.i64 - idx = ir.IndexType.get() v4i32 = T.i32x4 if is_f32: v4f32 = T.f32x4 @@ -585,8 +703,11 @@ def allreduce_2stage_write_mode( out_ptrs_i64 = out_ptrs.ir_value() tmp_ptrs_i64 = tmp_ptrs.ir_value() - sgs = [mem_ops.load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] - out_ptrs_arr = [mem_ops.load_device_ptr(out_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] + sgs = [_load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] + out_ptrs_arr = [_load_device_ptr(out_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] + tmp_ptrs_arr = [_load_device_ptr(tmp_ptrs_i64, ea.constant(i, type=i32)) for i in range(world_size)] + tmp_ptrs_vec = _pack_i64_vec(tmp_ptrs_arr) + out_ptrs_vec = _pack_i64_vec(out_ptrs_arr) tnum_gpu_i32 = ea.constant(tnum_gpu, type=i32) log2_tnum = int(math.log2(tnum_gpu)) @@ -600,14 +721,16 @@ def allreduce_2stage_write_mode( n_smem_wm = 2 * threads allocator_wm = SmemAllocator(None, global_sym_name=smem_sym_wm) 
smem_wm_off = allocator_wm._align(allocator_wm.ptr, 16) - allocator_wm.ptr = smem_wm_off + n_smem_wm * 16 + allocator_wm.ptr = smem_wm_off + n_smem_wm * _BYTES_PER_PACK with ir.InsertionPoint.at_block_begin(gpu_func_op.operation.block): allocator_wm.finalize() smem_ptr = SmemPtr(allocator_wm.get_base(), smem_wm_off, v4i32, shape=(n_smem_wm,)) smem_ptr.get() - tmp_out_i64 = mem_ops.load_device_ptr(tmp_ptrs_i64, rank_i32) + tmp_out_i64 = _extract_i64(tmp_ptrs_vec, rank_i32) # ---- Stage 1: scatter local input to REMOTE tmp buffers ---- + inp_rsrc = _make_rsrc(inp_ptr_i64) + start_w = warp_id * ea.constant(part_p, type=i32) is_last_w = warp_id == ea.constant(world_size - 1, type=i32) end_w_if = scf.IfOp(is_last_w, results_=[i32], has_else=True) @@ -617,6 +740,13 @@ def allreduce_2stage_write_mode( scf.YieldOp([start_w + ea.constant(part_p, type=i32)]) end_w = end_w_if.results[0] + dst_tmp = _extract_i64(tmp_ptrs_vec, warp_id) + is_tmp_null = dst_tmp == ea.constant(0, type=i64) + dst_tmp_low4 = dst_tmp & ea.constant(0xF, type=i64) + is_tmp_misaligned = dst_tmp_low4 != ea.constant(0, type=i64) + bad_tmp_addr = is_tmp_null | is_tmp_misaligned + dst_tmp_rsrc = _make_rsrc(dst_tmp) + idx_s1 = start_w + tid_pack loop_s1 = scf.WhileOp([i32, i32], [idx_s1, stride_pack]) bs1 = ir.Block.create_at_start(loop_s1.before, [i32, i32]) @@ -628,20 +758,16 @@ def allreduce_2stage_write_mode( with ir.InsertionPoint(as1): cur = as1.arguments[0] stride_s1 = as1.arguments[1] - raw = mem_ops.load_v4i32(inp_ptr_i64 + cur.extui(i64) * ea.constant(16, type=i64)) + cur_off_i32 = cur * ea.constant(_ELEMS_PER_PACK, type=i32) + raw = _load_v4i32(inp_rsrc, cur_off_i32) rel_idx = cur - start_w dst_off = rank_i32 * ea.constant(part_p, type=i32) + rel_idx - dst_tmp = mem_ops.load_device_ptr(tmp_ptrs_i64, warp_id) - tmp_addr = dst_tmp + dst_off.extui(i64) * ea.constant(16, type=i64) - is_tmp_null = dst_tmp == ea.constant(0, type=i64) - tmp_low4 = tmp_addr & ea.constant(0xF, type=i64) - is_tmp_misaligned = tmp_low4 != ea.constant(0, type=i64) - bad_tmp_addr = is_tmp_null | is_tmp_misaligned if_tmp_ok = scf.IfOp(bad_tmp_addr, results_=[], has_else=True) with ir.InsertionPoint(if_tmp_ok.then_block): scf.YieldOp([]) with ir.InsertionPoint(if_tmp_ok.else_block): - mem_ops.store_v4i32(tmp_addr, raw) + dst_off_i32 = dst_off * ea.constant(_ELEMS_PER_PACK, type=i32) + _store_v4i32(dst_tmp_rsrc, dst_off_i32, raw) scf.YieldOp([]) scf.YieldOp([cur + stride_s1, stride_s1]) @@ -650,10 +776,8 @@ def allreduce_2stage_write_mode( self_sg_i64=self_sg_i64, sgs_i64=sgs, ngpus=world_size) # ---- Stage 2: reduce local tmp and write to REMOTE outputs ---- + tmp_out_rsrc = _make_rsrc(tmp_out_i64) part_p_i32 = ea.constant(part_p, type=i32) - # The last rank's output partition has largest_part_p elements - # (= part_p + num_packs % world_size). Use a runtime branch so that - # when num_packs is evenly divisible the overhead is minimal (same value). 
is_last_rank_s2 = rank_i32 == ea.constant(world_size - 1, type=i32) end_s2_if = scf.IfOp(is_last_rank_s2, results_=[i32], has_else=True) with ir.InsertionPoint(end_s2_if.then_block): @@ -661,6 +785,19 @@ def allreduce_2stage_write_mode( with ir.InsertionPoint(end_s2_if.else_block): scf.YieldOp([part_p_i32]) end_s2 = end_s2_if.results[0] + + is_tmpout_null = tmp_out_i64 == ea.constant(0, type=i64) + tmpout_low4 = tmp_out_i64 & ea.constant(0xF, type=i64) + is_load_misaligned = tmpout_low4 != ea.constant(0, type=i64) + bad_load_addr = is_tmpout_null | is_load_misaligned + + dst_ptr = _extract_i64(out_ptrs_vec, warp_id) + dst_out_rsrc = _make_rsrc(dst_ptr) + is_out_null = dst_ptr == ea.constant(0, type=i64) + dst_ptr_low4 = dst_ptr & ea.constant(0xF, type=i64) + is_out_misaligned = dst_ptr_low4 != ea.constant(0, type=i64) + bad_out_addr = is_out_null | is_out_misaligned + loop_s2 = scf.WhileOp([i32, i32], [tid_pack, stride_pack]) bs2 = ir.Block.create_at_start(loop_s2.before, [i32, i32]) as2 = ir.Block.create_at_start(loop_s2.after, [i32, i32]) @@ -672,70 +809,68 @@ def allreduce_2stage_write_mode( cur = as2.arguments[0] stride_s2 = as2.arguments[1] + # All warps load their chunk from tmp into smem src_off = warp_id * ea.constant(part_p, type=i32) + cur - load_addr = tmp_out_i64 + src_off.extui(i64) * ea.constant(16, type=i64) - is_tmpout_null = tmp_out_i64 == ea.constant(0, type=i64) - load_low4 = load_addr & ea.constant(0xF, type=i64) - is_load_misaligned = load_low4 != ea.constant(0, type=i64) - bad_load_addr = is_tmpout_null | is_load_misaligned + src_off_i32 = src_off * ea.constant(_ELEMS_PER_PACK, type=i32) raw_if = scf.IfOp(bad_load_addr, results_=[v4i32], has_else=True) with ir.InsertionPoint(raw_if.then_block): scf.YieldOp([ea.constant_vector(0, v4i32)]) with ir.InsertionPoint(raw_if.else_block): - scf.YieldOp([mem_ops.load_v4i32(load_addr)]) + scf.YieldOp([_load_v4i32(tmp_out_rsrc, src_off_i32)]) raw = raw_if.results[0] - sm_idx = ea.index_cast(idx, lane_i32) + sm_idx = ea.index_cast(T.index, lane_i32) smem_ptr.store(raw, [sm_idx]) gpu.barrier() - warp_id_local = _u(lane_i32) >> ea.constant(log2_tnum, type=i32) - lane_id_local = lane_i32 - warp_id_local * ea.constant(tnum_gpu, type=i32) - - raw_vals = [] - for wi in range_constexpr(world_size): - sm_i_idx = ea.index_cast(idx, ea.constant(wi * tnum_gpu, type=i32) + lane_id_local) - raw_vals.append(smem_ptr.load([sm_i_idx])) - - acc = None - for wi in range_constexpr(world_size): - raw_i = raw_vals[wi] + # Warp 0 reduces across all warps, writes result to res area + # (smem[threads .. threads+tnum_gpu-1]). Two-barrier pattern + # matching aiter: barrier1 guards tmp_smem, barrier2 guards + # res_smem; between iterations tmp and res are disjoint so no + # WAR hazard exists. 
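+            # LDS layout for this loop (illustrative):
+            #   smem[0 .. threads)                  per-thread raw packs (tmp area)
+            #   smem[threads .. threads+tnum_gpu)   warp-0 reduced packs (res area)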
+ is_w0 = warp_id == ea.constant(0, type=i32) + ifw0 = scf.IfOp(is_w0, results_=[], has_else=False) + with ir.InsertionPoint(ifw0.then_block): + acc = None + for wi in range_constexpr(world_size): + sm_i_idx = ea.index_cast( + T.index, ea.constant(wi * tnum_gpu, type=i32) + lane_id) + raw_i = smem_ptr.load([sm_i_idx]) + if is_f32: + vf = ev.bitcast(v4f32, raw_i) + acc = vf if acc is None else acc + vf + else: + v16 = ev.bitcast(v8half, raw_i) + v32 = v16.extf(v8f32) + acc = v32 if acc is None else acc + v32 if is_f32: - vf = raw_i.bitcast(v4f32) - acc = vf if acc is None else acc + vf + out_raw = ev.bitcast(v4i32, acc) else: - v16 = ev.bitcast(v8half, raw_i) - v32 = v16.extf(v8f32) - acc = v32 if acc is None else acc + v32 - if is_f32: - out_raw = acc.bitcast(v4i32) - else: - out_raw = ev.bitcast(v4i32, acc.truncf(v8half)) + out_raw = ev.bitcast(v4i32, acc.truncf(v8half)) + res_idx = ea.index_cast(T.index, ea.constant(threads, type=i32) + lane_id) + smem_ptr.store(out_raw, [res_idx]) + scf.YieldOp([]) + + gpu.barrier() + + # All warps read the same reduced result from res area and + # nontemporal-write to their respective remote output buffers. + res_read_idx = ea.index_cast(T.index, ea.constant(threads, type=i32) + lane_id) + reduced_val = smem_ptr.load([res_read_idx]) dst_out_off = rank_i32 * ea.constant(part_p, type=i32) + cur - dst_byte_off = dst_out_off.extui(i64) * ea.constant(16, type=i64) - - # Each warp writes its reduced partition directly to the target - # output via flat_store_dwordx4 nt. The nt hint bypasses L1/L2 - # and works for all memory types (including IPC-mapped addresses). - dst_ptr = out_ptrs_arr[0] - for w in range_constexpr(1, world_size): - is_warp_w = warp_id_local == ea.constant(w, type=i32) - dst_ptr = ea.select(is_warp_w, out_ptrs_arr[w], dst_ptr) - out_addr = dst_ptr + dst_byte_off - is_out_null = dst_ptr == ea.constant(0, type=i64) - out_low4 = out_addr & ea.constant(0xF, type=i64) - is_out_misaligned = out_low4 != ea.constant(0, type=i64) - bad_out_addr = is_out_null | is_out_misaligned + dst_off_i32 = dst_out_off * ea.constant(_ELEMS_PER_PACK, type=i32) + if_out_ok = scf.IfOp(bad_out_addr, results_=[], has_else=True) with ir.InsertionPoint(if_out_ok.then_block): scf.YieldOp([]) with ir.InsertionPoint(if_out_ok.else_block): - mem_ops.store_v4i32_nt(out_addr, out_raw) + _store_v4i32_nt(dst_out_rsrc, dst_off_i32, reduced_val) scf.YieldOp([]) scf.YieldOp([cur + stride_s2, stride_s2]) + gpu.barrier() _signal_end_sync(lane_i32=lane_i32, rank_i32=rank_i32, bid_i32=bid_i32, self_sg_i64=self_sg_i64, sgs_i64=sgs, ngpus=world_size) diff --git a/python/flydsl/_version.py b/python/flydsl/_version.py index b1aa5330..9556dd64 100644 --- a/python/flydsl/_version.py +++ b/python/flydsl/_version.py @@ -1 +1 @@ -__version__ = "0.1.3.1" +__version__ = "0.1.3.2" diff --git a/python/flydsl/expr/__init__.py b/python/flydsl/expr/__init__.py index 3fef52c7..7c2bd331 100644 --- a/python/flydsl/expr/__init__.py +++ b/python/flydsl/expr/__init__.py @@ -9,5 +9,5 @@ from . import utils -from . import arith, vector, gpu, buffer_ops, rocdl, math, mem_ops +from . 
import arith, vector, gpu, buffer_ops, rocdl, math from .rocdl import tdm_ops diff --git a/python/flydsl/expr/arith.py b/python/flydsl/expr/arith.py index 459ccbdd..38403f6c 100644 --- a/python/flydsl/expr/arith.py +++ b/python/flydsl/expr/arith.py @@ -64,24 +64,4 @@ def cmpf(predicate, lhs, rhs, **kwargs): return _mlir_arith.cmpf(predicate, _to_raw(lhs), _to_raw(rhs), **kwargs) -@traced_op -def select_by_index(index_val, values): - """Select one of *values* by integer *index_val* via chained ``arith.select``. - - Equivalent to a compile-time switch: returns ``values[index_val]``. - - Args: - index_val: Integer index (i32 ``ir.Value``). - values: List of ``ir.Value`` to select from. - - Returns: - The selected ``ir.Value``. - """ - out = values[0] - for i in range(1, len(values)): - pred = _mlir_arith.CmpIOp( - _mlir_arith.CmpIPredicate.eq, index_val, constant(i, type=index_val.type) - ).result - out = _mlir_arith.SelectOp(pred, values[i], out).result - return out diff --git a/python/flydsl/expr/buffer_ops.py b/python/flydsl/expr/buffer_ops.py index 6c41c311..9cebd709 100644 --- a/python/flydsl/expr/buffer_ops.py +++ b/python/flydsl/expr/buffer_ops.py @@ -72,6 +72,7 @@ def _get_buffer_flags(arch=None): 'create_llvm_ptr', 'get_element_ptr', 'create_buffer_resource', + 'create_buffer_resource_from_addr', 'buffer_load', 'buffer_store', 'BufferResourceDescriptor', @@ -323,6 +324,34 @@ def _num_records_from_memref_type() -> Optional[int]: return BufferResourceDescriptor(rsrc) +def create_buffer_resource_from_addr(addr_i64: ir.Value) -> ir.Value: + """Create AMD buffer resource descriptor from a raw i64 device address. + + Useful when working with runtime pointer arrays (e.g. IPC-mapped addresses + or device-side pointer tables) where no fly.memref is available. + The full address is encoded as the buffer base; callers should pass + byte offset 0 to buffer_load / buffer_store. + + Args: + addr_i64: Raw 64-bit device address (i64 MLIR value). + + Returns: + ROCDL buffer resource descriptor (!llvm.ptr<8>). + + Example: + >>> rsrc = create_buffer_resource_from_addr(raw_addr_i64) + >>> data = buffer_load(rsrc, i32_zero, vec_width=4, dtype=T.i32) + """ + addr_i64 = _unwrap_value(addr_i64) + ptr_type = ir.Type.parse('!llvm.ptr') + base_ptr = llvm.IntToPtrOp(ptr_type, addr_i64).result + flags = _create_i32_constant(_get_buffer_flags()) + stride = _create_i16_constant(0) + num_records = _create_i64_constant(0xFFFFFFFF) + rsrc_type = ir.Type.parse('!llvm.ptr<8>') + return rocdl.MakeBufferRsrcOp(rsrc_type, base_ptr, stride, num_records, flags).result + + @traced_op def create_buffer_resource(memref_val: ir.Value, stride: int = 0, @@ -361,10 +390,10 @@ def buffer_load(rsrc: ir.Value, cache_modifier: int = 0, soffset_bytes: Optional[Union[int, ir.Value]] = None) -> ir.Value: """AMD buffer load operation. - + Load data from global memory using buffer descriptor and offset. Uses hardware-level bounds checking and vectorization. - + Args: rsrc: Buffer resource descriptor (!llvm.ptr<8>) offset: Offset in elements (i32 type) diff --git a/python/flydsl/expr/mem_ops.py b/python/flydsl/expr/mem_ops.py deleted file mode 100644 index 4e74ea9e..00000000 --- a/python/flydsl/expr/mem_ops.py +++ /dev/null @@ -1,199 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Copyright (c) 2025 FlyDSL Project Contributors - -"""Low-level memory operations via inline assembly for multi-GPU kernels. 
- -Provides wrappers around GFX942 inline assembly instructions for: - -- **Uncached** loads/stores (``sc0 sc1`` — system-scope coherent, - for cross-GPU signal buffers allocated with ``hipDeviceMallocUncached``) -- **Nontemporal** stores (``nt`` — bypasses L1/L2 cache, works on any - memory type including regular ``hipMalloc`` / IPC-mapped addresses) -- **Cached** vector loads/stores (16-byte / ``v4i32``) -- Device-side pointer access - -All functions operate on raw ``ir.Value`` (i32/i64/vector<4xi32>). - -Example:: - - from flydsl.expr import mem_ops - - val = mem_ops.load_i32_uncached(addr) - mem_ops.store_i32_uncached_flush(peer_addr, flag) - data = mem_ops.load_v4i32(data_addr) -""" - -from __future__ import annotations - -from .._mlir import ir -from .._mlir.dialects import arith as _arith, llvm, rocdl -from .meta import traced_op -from .typing import T - - -# --------------------------------------------------------------------------- -# Uncached i32 operations (system-scope coherent, for signal buffers) -# --------------------------------------------------------------------------- - -@traced_op -def load_i32_uncached(addr_i64): - """Load i32 from global address, bypassing L1 cache (system-scope). - - Emits ``global_load_dword ... sc1`` on GFX942. - Typically used to poll cross-GPU signal buffers. - """ - v = llvm.InlineAsmOp( - T.i32, [addr_i64], - "global_load_dword $0, $1, off sc1", "=v,v", - has_side_effects=True, - ).result - rocdl.s_waitcnt(0) - return v - - -@traced_op -def store_i32_uncached_flush(addr_i64, val_i32): - """Store i32 with L2 flush + system-scope coherence for XGMI visibility. - - Emits ``buffer_wbl2 sc0 sc1`` followed by ``global_store_dword ... sc0 sc1``. - Use after cached data stores (``store_v4i32``) to ensure L2 dirty lines - reach HBM before the signal becomes visible to peer GPUs. - """ - llvm.InlineAsmOp(None, [], "buffer_wbl2 sc0 sc1", "", has_side_effects=True) - llvm.InlineAsmOp( - None, [addr_i64, val_i32], - "global_store_dword $0, $1, off sc0 sc1", "v,v", - has_side_effects=True, - ) - rocdl.s_waitcnt(0) - - -@traced_op -def store_i32_uncached(addr_i64, val_i32): - """Store i32 with system-scope coherence (no L2 flush). - - Emits ``global_store_dword ... sc0 sc1``. - Use after nontemporal data stores (``store_v4i32_nt``) which already - bypass L2 — no ``buffer_wbl2`` is needed. - """ - llvm.InlineAsmOp( - None, [addr_i64, val_i32], - "global_store_dword $0, $1, off sc0 sc1", "v,v", - has_side_effects=True, - ) - rocdl.s_waitcnt(0) - - -@traced_op -def store_i32(addr_i64, val_i32): - """Store i32 to global address (normal cached store). - - Emits ``global_store_dword ... off`` with no cache coherence flags. - Use for writes visible only to the local GPU (e.g. updating own signal). - """ - llvm.InlineAsmOp( - None, [addr_i64, val_i32], - "global_store_dword $0, $1, off", "v,v", - has_side_effects=True, - ) - rocdl.s_waitcnt(0) - - -# --------------------------------------------------------------------------- -# v4i32 (16-byte) vector operations -# --------------------------------------------------------------------------- - -@traced_op -def load_v4i32(addr_i64): - """Load 16 bytes (``vector<4xi32>``) from global address. - - Emits ``flat_load_dwordx4``. - """ - v = llvm.InlineAsmOp( - T.i32x4, [addr_i64], - "flat_load_dwordx4 $0, $1", "=v,v", - has_side_effects=True, - ).result - rocdl.s_waitcnt(0) - return v - - -@traced_op -def store_v4i32(addr_i64, v4i32_val): - """Store 16 bytes (``vector<4xi32>``) to global address. 
- - Emits ``global_store_dwordx4 ... off``. - """ - llvm.InlineAsmOp( - None, [addr_i64, v4i32_val], - "global_store_dwordx4 $0, $1, off", "v,v", - has_side_effects=True, - ) - rocdl.s_waitcnt(0) - - -@traced_op -def store_v4i32_nt(addr_i64, v4i32_val): - """Store 16 bytes with nontemporal hint, bypassing L1/L2 cache. - - Emits ``flat_store_dwordx4 ... nt``. - Suitable for large data writes across XGMI — works on any memory type - (regular ``hipMalloc``, IPC-mapped coarse-grained memory). - """ - llvm.InlineAsmOp( - None, [addr_i64, v4i32_val], - "flat_store_dwordx4 $0, $1 nt", "v,v", - has_side_effects=True, - ) - rocdl.s_waitcnt(0) - - -# --------------------------------------------------------------------------- -# Pointer helpers -# --------------------------------------------------------------------------- - -@traced_op -def load_device_ptr(array_base_i64, index): - """Load an i64 pointer from a device-side pointer array. - - Computes ``base + index * 8``, casts to ``!llvm.ptr``, and loads i64. - - Args: - array_base_i64: Base address of the pointer array (i64). - index: Array index (i32 or i64). - """ - from . import arith as ea - - i64 = T.i64 - if hasattr(index, 'type') and isinstance(index.type, ir.IntegerType) and index.type.width == 32: - index = _arith.ExtUIOp(i64, index).result - elem_addr = array_base_i64 + index * ea.constant(8, type=i64) - ptr = llvm.IntToPtrOp(ir.Type.parse("!llvm.ptr"), elem_addr).result - return llvm.LoadOp(i64, ptr).result - - -@traced_op -def invalidate_l1(): - """Invalidate L1 scalar cache (``buffer_inv sc1``). - - Use inside a polling loop after a remote-visible load to discard stale - L1 cache lines so the next iteration sees fresh data from L2/HBM. - """ - llvm.InlineAsmOp(None, [], "buffer_inv sc1", "", has_side_effects=True) - - -__all__ = [ - # Uncached i32 (system-scope coherent) - "load_i32_uncached", - "store_i32_uncached_flush", - "store_i32_uncached", - "store_i32", - # v4i32 (16-byte vector) - "load_v4i32", - "store_v4i32", - "store_v4i32_nt", - # Cache control - "invalidate_l1", - # Pointer helpers - "load_device_ptr", -] diff --git a/python/flydsl/expr/vector.py b/python/flydsl/expr/vector.py index 002adac9..182ab5bf 100644 --- a/python/flydsl/expr/vector.py +++ b/python/flydsl/expr/vector.py @@ -10,6 +10,7 @@ from __future__ import annotations +from .._mlir import ir from .._mlir.dialects import vector as _vector from .meta import traced_op @@ -55,9 +56,21 @@ def store(value, memref, indices, *, loc=None, ip=None, **kwargs): ) +# ----------------------------------------------------------------------------- +# Thin wrappers for common op classes that otherwise require `.result` access. +# ----------------------------------------------------------------------------- + + @traced_op def extract(vector, static_position=None, dynamic_position=None, *, loc=None, ip=None): - """Wrapper around `vector.ExtractOp(...).result`.""" + """Wrapper around `vector.ExtractOp(...).result`. + + When only ``dynamic_position`` is supplied (without explicit + ``static_position``), each dynamic index needs a corresponding + ``kDynamic`` sentinel in the static attribute so the ODS builder + pairs them correctly. This wrapper fills in the sentinels + automatically. + """ from . 
import arith as _arith_ext if static_position is None: @@ -65,6 +78,13 @@ def extract(vector, static_position=None, dynamic_position=None, *, loc=None, ip if dynamic_position is None: dynamic_position = [] dynamic_position = [_arith_ext.unwrap(i, index=True, loc=loc) for i in dynamic_position] + + n_static = len(static_position) + n_dynamic = len(dynamic_position) + if n_dynamic > 0 and n_static < n_dynamic: + kDynamic = ir.ShapedType.get_dynamic_size() + static_position = list(static_position) + [kDynamic] * (n_dynamic - n_static) + return _vector.ExtractOp( _arith_ext.unwrap(vector, loc=loc), static_position=static_position, diff --git a/tests/arch_compat.py b/tests/arch_compat.py index c79b7be6..fb6bab52 100644 --- a/tests/arch_compat.py +++ b/tests/arch_compat.py @@ -16,6 +16,7 @@ "test_moe_reduce.py", "test_pa.py", "test_quant.py", + "test_allreduce.py", # custom_all_reduce requires CDNA (gfx9xx) }) # Example scripts verified to work on RDNA (non-CDNA) GPUs. diff --git a/tests/kernels/compare_allreduce_benchmark.py b/tests/kernels/compare_allreduce_benchmark.py new file mode 100644 index 00000000..2b0942cc --- /dev/null +++ b/tests/kernels/compare_allreduce_benchmark.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python3 +"""Compare two allreduce benchmark CSVs (main vs PR) and flag regressions. + +Usage: + python3 compare_benchmark.py + +Exit code 1 if any case regresses more than BOTH thresholds: + - relative increase > MAX_REGRESSION_PCT (default 10%) + - absolute increase > MIN_ABS_REGRESSION_US (default 5 us) +""" +import sys +import pandas as pd + +MAX_REGRESSION_PCT = 10.0 +MIN_ABS_REGRESSION_US = 5.0 + + +def main(): + if len(sys.argv) != 3: + print(f"Usage: {sys.argv[0]} ") + sys.exit(2) + + main_csv, pr_csv = sys.argv[1], sys.argv[2] + + main_df = pd.read_csv(main_csv) + pr_df = pd.read_csv(pr_csv) + + main_agg = main_df[main_df["rank"] == "aggregate"].copy() + pr_agg = pr_df[pr_df["rank"] == "aggregate"].copy() + + # Detect cases that failed/skipped in PR but succeeded on main + pr_agg_indexed = pr_agg.set_index(["shape", "dtype"]) + main_agg_indexed = main_agg.set_index(["shape", "dtype"]) + + pr_broken = pr_agg_indexed[ + (pr_agg_indexed["avg_time_us"] <= 0) | pr_agg_indexed["avg_time_us"].isna() + ] + main_ok = main_agg_indexed[ + (main_agg_indexed["avg_time_us"] > 0) & main_agg_indexed["avg_time_us"].notna() + ] + newly_broken = pr_broken.index.intersection(main_ok.index) + + # Performance comparison for cases that both sides ran successfully + pr_valid = pr_agg_indexed[["avg_time_us"]].loc[ + (pr_agg_indexed["avg_time_us"] > 0) & pr_agg_indexed["avg_time_us"].notna() + ] + main_valid = main_agg_indexed[["avg_time_us"]].loc[ + (main_agg_indexed["avg_time_us"] > 0) & main_agg_indexed["avg_time_us"].notna() + ] + merged = pr_valid.join(main_valid, lsuffix="_pr", rsuffix="_main").dropna() + + fail_count = 0 + + if not merged.empty: + merged["delta_us"] = merged["avg_time_us_pr"] - merged["avg_time_us_main"] + merged["delta_pct"] = (merged["delta_us"] / merged["avg_time_us_main"]) * 100.0 + + print("=== Allreduce Benchmark: PR vs main ===") + for (shape, dtype), row in merged.iterrows(): + regressed = ( + row["delta_pct"] > MAX_REGRESSION_PCT + and row["delta_us"] > MIN_ABS_REGRESSION_US + ) + tag = "REGRESSION" if regressed else "OK" + if regressed: + fail_count += 1 + print( + f" {shape:>20s} {dtype:>4s} " + f"main={row['avg_time_us_main']:8.2f} us " + f"PR={row['avg_time_us_pr']:8.2f} us " + f"delta={row['delta_us']:+8.2f} us ({row['delta_pct']:+5.1f}%) " + f"[{tag}]" + ) + + if 
len(newly_broken) > 0: + print("\n=== Cases BROKEN in PR (work on main but fail on PR) ===") + for shape, dtype in newly_broken: + fail_count += 1 + err = pr_agg_indexed.loc[(shape, dtype)].get("error", "unknown") + print(f" {shape:>20s} {dtype:>4s} [BROKEN] error: {err}") + + if fail_count > 0: + print(f"\nFAILED: {fail_count} issue(s) detected.") + sys.exit(1) + else: + print("\nPASSED: No regression or breakage detected.") + + +if __name__ == "__main__": + main() diff --git a/tests/kernels/test_flydsl_allreduce.py b/tests/kernels/test_allreduce.py similarity index 79% rename from tests/kernels/test_flydsl_allreduce.py rename to tests/kernels/test_allreduce.py index 45d2d3fd..eb3d9269 100644 --- a/tests/kernels/test_flydsl_allreduce.py +++ b/tests/kernels/test_allreduce.py @@ -257,6 +257,13 @@ def _dist_worker( mode: "eager" or "cudagraph" - which path to run (separate flows) result_dict: Shared dictionary to collect results from all ranks """ + import warnings + warnings.filterwarnings("ignore") + _devnull_fd = os.open(os.devnull, os.O_WRONLY) + os.dup2(_devnull_fd, 1) + os.dup2(_devnull_fd, 2) + os.close(_devnull_fd) + torch.cuda.set_device(rank) device = torch.device(f"cuda:{rank}") @@ -305,15 +312,6 @@ def _dist_worker( out = torch.empty_like(x_flat) fa = init_custom_ar(meta, rank_data, handles, offsets, rank=rank, full_nvlink=True, out=out) - if rank == 0: - fa_mod = getattr(getattr(fa, "__class__", None), "__module__", None) - fa_name = getattr(getattr(fa, "__class__", None), "__name__", None) - print( - f"[custom_all_reduce] backend=aiter " - f"allreduce_impl={allreduce_impl!r} fa={fa_mod}.{fa_name}", - flush=True, - ) - # Warmup: align all ranks dist.all_reduce(torch.zeros(1, device=device), group=group) torch.cuda.synchronize() @@ -405,7 +403,7 @@ def _run_eager(): elif mode == "cudagraph": if not hasattr(fa, "capture"): if rank == 0: - print("[test_flydsl_allreduce] WARN: fa has no capture(); skipping cudagraph.", flush=True) + print("[test_allreduce] WARN: fa has no capture(); skipping cudagraph.", flush=True) result_dict[rank] = { "rank": rank, "shape": shape, "dtype": dtype_str, "world_size": world_size, "mode": "cudagraph", "max_error": float("nan"), "avg_time_us": 0.0, @@ -638,6 +636,7 @@ def run_all_tests( print(f" Avg time: mean={mean_avg_time:.3f} us/iter, min={min_avg_time:.3f}, max={max_avg_time:.3f}") # Add aggregate row + rank0 = rank_results[0] if rank_results else {} aggregate_result = { "rank": "aggregate", "shape": str(shape), @@ -649,9 +648,10 @@ def run_all_tests( "min_avg_time_us": min_avg_time, "max_avg_time_us": max_avg_time, "device_time_sum_us": sum(r["device_time_sum_us"] for r in rank_results), - "kernel_name": rank_results[0]["kernel_name"] if rank_results else "unknown", + "kernel_name": rank0.get("kernel_name", "unknown"), "num_iters": num_iters, "num_warmup": num_warmup, + "error": rank0.get("error"), } all_results.append(aggregate_result) @@ -678,12 +678,189 @@ def run_all_tests( if not aggregate_df.empty: print(aggregate_df.to_string(index=False)) print("=" * 80) + + failed = [ + r for r in all_results + if r.get("rank") == "aggregate" + and (r.get("kernel_name") in ("skip", "error") or r.get("error")) + ] + if failed: + print("\n✗ FAILED cases:") + for r in failed: + reason = r.get("error") or r.get("kernel_name", "unknown") + print(f" {r['shape']} {r['dtype']} {r['mode']} → {reason}") + sys.exit(1) + return df else: print("\nNo results to save.") return pd.DataFrame() +# 
============================================================================ +# Pytest test functions for 8-GPU allreduce CI testing +# ============================================================================ + +def _count_physical_gpus() -> int: + """Return number of physically available GPUs via a fresh subprocess. + + Using a subprocess bypasses both HIP_VISIBLE_DEVICES restrictions and + PyTorch's internal device-count cache in the parent pytest process. + """ + import subprocess as _sp + env = {k: v for k, v in os.environ.items() if k != "HIP_VISIBLE_DEVICES"} + try: + r = _sp.run( + [sys.executable, "-c", "import torch; print(torch.cuda.device_count())"], + capture_output=True, text=True, timeout=30, env=env, + ) + return int(r.stdout.strip()) if r.returncode == 0 else 0 + except Exception: + return 0 + + +# All 8-GPU test configurations (always run, no large_shape distinction). +_8GPU_PARAMS = [ + # (shape, dtype_str, mode) + # --- small shapes (edge-case coverage, aligned with aiter) --- + ((2, 7168), "bf16", "cudagraph"), # 14 K elements · BF16 · cudagraph (aiter shape) + ((16, 4096), "fp16", "eager"), # 64 K elements · FP16 · eager + # --- medium shapes --- + ((128, 8192), "bf16", "cudagraph"), # 1 M elements · BF16 · cudagraph + ((96, 4096), "fp16", "eager"), # 384 K elements · FP16 · eager + # --- eager + cudagraph cross-dtype --- + ((512, 8192), "bf16", "eager"), # 4 M elements · BF16 · eager + ((1024, 8192), "fp16", "cudagraph"), # 8 M elements · FP16 · cudagraph + # --- fp32 coverage --- + ((64, 4096), "fp32", "eager"), # 256 K elements · FP32 · eager +] + +# 4-GPU test configurations (fp32 + smaller world_size coverage). +_4GPU_PARAMS = [ + # (shape, dtype_str, mode) + ((64, 4096), "fp32", "eager"), # 256 K elements · FP32 · eager + ((128, 8192), "fp16", "eager"), # 1 M elements · FP16 · eager + ((64, 8192), "bf16", "cudagraph"), # 512 K elements · BF16 · cudagraph +] + + +# 8-GPU benchmark configurations: cover all 3 kernel paths × 3 dtypes, cudagraph mode. 
+# small (2×7168) → 1-stage kernel +# medium (128×8192) → 2-stage kernel +# large (1024×8192) → write-mode kernel +_BENCHMARK_PARAMS = [ + # (shape, dtype_str, mode) + ((2, 7168), "fp16", "cudagraph"), + ((32, 8192), "fp32", "cudagraph"), + ((128, 8192), "fp16", "cudagraph"), + ((1024, 7168), "bf16", "cudagraph"), + ((4096, 8192), "bf16", "cudagraph") +] + + +def _run_subprocess(*, world_size, shape, dtype_str, mode, iters=10, warmup=2, + output_csv=None, timeout=600): + """Launch the allreduce harness in a subprocess and assert success.""" + import subprocess as _sp + + env = {k: v for k, v in os.environ.items() if k != "HIP_VISIBLE_DEVICES"} + shape_str = ",".join(str(d) for d in shape) + f",{dtype_str}" + + cmd = [ + sys.executable, __file__, + "--world_size", str(world_size), + "--iters", str(iters), + "--warmup", str(warmup), + "--shapes", shape_str, + "--mode", mode, + "--allreduce_impl", "flydsl", + ] + if output_csv: + cmd += ["--output_csv", output_csv] + result = _sp.run(cmd, env=env, timeout=timeout, capture_output=True, text=True) + assert result.returncode == 0, ( + f"{world_size}-GPU allreduce FAILED: shape={shape}, dtype={dtype_str}, " + f"mode={mode} (exit code {result.returncode})\n" + f"stdout (last 2000 chars):\n{result.stdout[-2000:]}\n" + f"stderr (last 2000 chars):\n{result.stderr[-2000:]}" + ) + return result + + +def _run_subprocess_test(*, world_size, shape, dtype_str, mode): + """Launch the allreduce accuracy test in a subprocess.""" + _run_subprocess(world_size=world_size, shape=shape, dtype_str=dtype_str, mode=mode) + + +def _run_subprocess_benchmark(*, world_size, shape, dtype_str, mode): + """Launch the allreduce benchmark in a subprocess with more iterations. + + Returns the CSV output path for downstream baseline comparison. + """ + shape_tag = "x".join(str(d) for d in shape) + csv_path = f"/tmp/allreduce_bench_{shape_tag}_{dtype_str}_{mode}.csv" + result = _run_subprocess( + world_size=world_size, shape=shape, dtype_str=dtype_str, mode=mode, + iters=51, warmup=5, output_csv=csv_path, timeout=900, + ) + if result.stdout: + for line in result.stdout.splitlines(): + if "avg_time" in line.lower() or "max_error" in line.lower() or "aggregate" in line.lower(): + print(line) + return csv_path + + +def _param_id(shape, dtype_str, mode): + s = "x".join(str(d) for d in shape) + return f"{s}-{dtype_str}-{mode}" + + +@pytest.mark.multi_gpu +@pytest.mark.parametrize("shape,dtype_str,mode", _8GPU_PARAMS, + ids=[_param_id(*p) for p in _8GPU_PARAMS]) +def test_allreduce_8gpu_accuracy(shape, dtype_str, mode): + """8-GPU allreduce accuracy test. + + Runs the allreduce harness in a child subprocess so that + HIP_VISIBLE_DEVICES (auto-set by run_tests.sh to one GPU index) + does not limit device visibility inside the distributed workers. + + Skipped automatically on machines with fewer than 8 physical GPUs. + """ + phys_ng = _count_physical_gpus() + if phys_ng < 8: + pytest.skip(f"Requires >= 8 physical GPUs, found {phys_ng}.") + _run_subprocess_test(world_size=8, shape=shape, dtype_str=dtype_str, mode=mode) + + +@pytest.mark.multi_gpu +@pytest.mark.benchmark +@pytest.mark.parametrize("shape,dtype_str,mode", _BENCHMARK_PARAMS, + ids=[_param_id(*p) for p in _BENCHMARK_PARAMS]) +def test_allreduce_8gpu_benchmark(shape, dtype_str, mode): + """8-GPU allreduce benchmark test. + + Uses 51 iters / 5 warmup to get stable timing data. + Performance regression is checked at the CI workflow level by comparing + this PR's results against the main branch (run separately). 
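+    A case counts as a regression only when it exceeds both thresholds in
+    tests/kernels/compare_allreduce_benchmark.py: more than 10% slower AND more
+    than 5 us slower than main, so a case that moves from 3 us to 3.6 us
+    (+20%, but only +0.6 us) still passes the gate.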
+ """ + phys_ng = _count_physical_gpus() + if phys_ng < 8: + pytest.skip(f"Requires >= 8 physical GPUs, found {phys_ng}.") + _run_subprocess_benchmark(world_size=8, shape=shape, dtype_str=dtype_str, mode=mode) + + +@pytest.mark.multi_gpu +@pytest.mark.parametrize("shape,dtype_str,mode", _4GPU_PARAMS, + ids=[_param_id(*p) for p in _4GPU_PARAMS]) +def test_allreduce_4gpu_accuracy(shape, dtype_str, mode): + """4-GPU allreduce accuracy test (covers fp32 and world_size=4).""" + phys_ng = _count_physical_gpus() + if phys_ng < 4: + pytest.skip(f"Requires >= 4 physical GPUs, found {phys_ng}.") + _run_subprocess_test(world_size=4, shape=shape, dtype_str=dtype_str, mode=mode) + + if __name__ == "__main__": freeze_support() # Align with AIter harness: use spawn to avoid fork+CUDA issues. diff --git a/tests/pytest.ini b/tests/pytest.ini index 765a874a..af4f0e41 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -11,3 +11,5 @@ markers = l1b_target_dialect: requires a target lowering stack; pair with rocm_lower (or future backend markers) l2_device: requires GPU and full runtime stack for execution/correctness rocm_lower: L1b/L2 tests that assume the ROCDL lowering path + multi_gpu: marks tests that require multi-GPU (>=4 physical GPUs; skipped automatically when unavailable) + benchmark: marks performance benchmark tests (longer runtime, more iterations) From a28e1508c91d3085c6c2c494c3772345f2c3baef Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Thu, 16 Apr 2026 21:56:33 +0800 Subject: [PATCH 15/29] [OPT] Add pass: convert-atom-call-to-ssa-form (#407) * [OPT] Add pass: convert-atom-call-to-ssa-form * update mlir tests --- .../flydsl/Dialect/Fly/IR/FlyInterfaces.td | 42 +++ include/flydsl/Dialect/Fly/IR/FlyOps.td | 17 +- include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td | 13 +- .../flydsl/Dialect/Fly/Transforms/Passes.td | 17 +- .../flydsl/Dialect/Fly/Utils/PointerUtils.h | 10 +- lib/Conversion/FlyToROCDL/FlyToROCDL.cpp | 99 +++++-- lib/Dialect/Fly/CMakeLists.txt | 1 + lib/Dialect/Fly/IR/FlyTypeDefs.cpp | 23 ++ lib/Dialect/Fly/IR/FlyUniversalOps.cpp | 134 +++++++-- lib/Dialect/Fly/Transforms/Canonicalize.cpp | 8 - .../Transforms/ConvertAtomCallToSSAForm.cpp | 153 ++++++++++ lib/Dialect/Fly/Transforms/LayoutLowering.cpp | 17 +- lib/Dialect/Fly/Utils/PointerUtils.cpp | 23 +- lib/Dialect/FlyROCDL/CDNA3/CopyAtom.cpp | 226 +++++++++++---- lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp | 123 +++++---- lib/Dialect/FlyROCDL/CDNA4/CopyAtom.cpp | 68 +++-- lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp | 125 +++++---- tests/mlir/Conversion/copy_atom_stateful.mlir | 9 +- tests/mlir/Conversion/pointer_ops.mlir | 63 ----- .../convert-atom-call-to-ssa-form.mlir | 261 ++++++++++++++++++ 20 files changed, 1122 insertions(+), 310 deletions(-) create mode 100644 lib/Dialect/Fly/Transforms/ConvertAtomCallToSSAForm.cpp create mode 100644 tests/mlir/Transforms/convert-atom-call-to-ssa-form.mlir diff --git a/include/flydsl/Dialect/Fly/IR/FlyInterfaces.td b/include/flydsl/Dialect/Fly/IR/FlyInterfaces.td index da5fc892..414c8dab 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyInterfaces.td +++ b/include/flydsl/Dialect/Fly/IR/FlyInterfaces.td @@ -104,6 +104,34 @@ def Fly_CopyOpTypeInterface : TypeInterface<"CopyOpTypeInterface"> { "::mlir::Value":$atomVal, "::mlir::Value":$src, "::mlir::Value":$dst, + "::mlir::Value":$pred)>, + InterfaceMethod< + "Emit SSA-form copy: reads src, returns vector result.", + "::mlir::FailureOr<::mlir::Value>", + "emitAtomCallSSA", + (ins "::mlir::OpBuilder &":$builder, + "::mlir::Location":$loc, + 
"::mlir::Type":$resultTy, + "::mlir::Type":$copyAtomTy, + "::mlir::Type":$srcTy, + "::mlir::Type":$dstTy, + "::mlir::Value":$atomVal, + "::mlir::Value":$src, + "::mlir::Value":$dst)>, + InterfaceMethod< + "Emit SSA-form predicated copy: reads src under predicate, returns vector result.", + "::mlir::FailureOr<::mlir::Value>", + "emitAtomCallSSA", + (ins "::mlir::OpBuilder &":$builder, + "::mlir::Location":$loc, + "::mlir::Type":$resultTy, + "::mlir::Type":$copyAtomTy, + "::mlir::Type":$srcTy, + "::mlir::Type":$dstTy, + "::mlir::Type":$predTy, + "::mlir::Value":$atomVal, + "::mlir::Value":$src, + "::mlir::Value":$dst, "::mlir::Value":$pred)> ]; } @@ -135,6 +163,20 @@ def Fly_MmaOpTypeInterface : TypeInterface<"MmaOpTypeInterface"> { "::mlir::Value":$d, "::mlir::Value":$a, "::mlir::Value":$b, + "::mlir::Value":$c)>, + InterfaceMethod<"", "::mlir::FailureOr<::mlir::Value>", "emitAtomCallSSA", + (ins "::mlir::OpBuilder &":$builder, + "::mlir::Location":$loc, + "::mlir::Type":$resultTy, + "::mlir::Type":$mmaAtomTy, + "::mlir::Type":$dTy, + "::mlir::Type":$aTy, + "::mlir::Type":$bTy, + "::mlir::Type":$cTy, + "::mlir::Value":$atomVal, + "::mlir::Value":$d, + "::mlir::Value":$a, + "::mlir::Value":$b, "::mlir::Value":$c)> ]; } diff --git a/include/flydsl/Dialect/Fly/IR/FlyOps.td b/include/flydsl/Dialect/Fly/IR/FlyOps.td index a9fbb837..f0adac84 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyOps.td +++ b/include/flydsl/Dialect/Fly/IR/FlyOps.td @@ -341,8 +341,8 @@ def Fly_TileToShapeOp : Fly_Op<"tile_to_shape", [Pure, DeclareOpInterfaceMethods def Fly_MakeMmaAtomOp : Fly_Op<"make_mma_atom", [Pure]> { let arguments = (ins); - let results = (outs AnyType:$result); - let assemblyFormat = "attr-dict `:` type($result)"; + let results = (outs Fly_MmaAtom:$result); + let assemblyFormat = "attr-dict `:` qualified(type($result))"; } def Fly_MakeCopyAtomOp : Fly_Op<"make_copy_atom", [Pure]> { @@ -361,7 +361,18 @@ def Fly_CopyAtomCall : Fly_Op<"copy_atom_call"> { let arguments = (ins Fly_CopyAtom:$copyAtom, Fly_MemRef:$src, Fly_MemRef:$dst, Optional:$pred); } def Fly_MmaAtomCall : Fly_Op<"mma_atom_call"> { - let arguments = (ins AnyType:$mmaAtom, Fly_MemRef:$d, Fly_MemRef:$a, Fly_MemRef:$b, Fly_MemRef:$c); + let arguments = (ins Fly_MmaAtom:$mmaAtom, Fly_MemRef:$d, Fly_MemRef:$a, Fly_MemRef:$b, Fly_MemRef:$c); +} + +def Fly_CopyAtomCallSSA : Fly_Op<"copy_atom_call_ssa", [AttrSizedOperandSegments]> { + let arguments = (ins Fly_CopyAtom:$copyAtom, AnyType:$src, + Optional:$dst, Optional:$pred); + let results = (outs Variadic:$results); +} +def Fly_MmaAtomCallSSA : Fly_Op<"mma_atom_call_ssa"> { + let arguments = (ins Fly_MmaAtom:$mmaAtom, Optional:$d, + AnyType:$a, AnyType:$b, AnyType:$c); + let results = (outs Variadic:$results); } def Fly_MakeTiledCopyOp : Fly_Op<"make_tiled_copy", [Pure, DeclareOpInterfaceMethods]> { diff --git a/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td b/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td index 91a6459b..cc548970 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td +++ b/include/flydsl/Dialect/Fly/IR/FlyTypeDefs.td @@ -254,12 +254,20 @@ def Fly_CopyAtom : Fly_Type<"CopyAtom", "copy_atom", [ ::mlir::Attribute getThrValLayoutSrc(); ::mlir::Attribute getThrValLayoutDst(); ::mlir::Attribute getThrValLayoutRef(); + ::mlir::LogicalResult emitAtomCall(OpBuilder &builder, Location loc, Type copyAtomTy, Type srcMemTy, Type dstMemTy, Value atomVal, Value src, Value dst) const; ::mlir::LogicalResult emitAtomCall(OpBuilder &builder, Location loc, Type copyAtomTy, Type srcMemTy, Type 
dstMemTy, Type predMemTy, Value atomVal, Value src, Value dst, Value pred) const;
-
+
+    ::mlir::FailureOr<::mlir::Value> emitAtomCallSSA(OpBuilder &builder, Location loc, Type resultTy,
+                                                     Type copyAtomTy, Type srcTy, Type dstTy,
+                                                     Value atomVal, Value src, Value dst) const;
+    ::mlir::FailureOr<::mlir::Value> emitAtomCallSSA(OpBuilder &builder, Location loc, Type resultTy,
+                                                     Type copyAtomTy, Type srcTy, Type dstTy, Type predTy,
+                                                     Value atomVal, Value src, Value dst, Value pred) const;
+
     bool isStateful() const;
     ::mlir::Type getConvertedType(::mlir::MLIRContext *ctx) const;
     ::mlir::Value setAtomState(OpBuilder &builder, Location loc, Value atomStruct,
@@ -294,6 +302,9 @@ def Fly_MmaAtom : Fly_Type<"MmaAtom", "mma_atom", [
     ::mlir::LogicalResult emitAtomCall(OpBuilder &builder, Location loc, Type mmaAtomTy, Type dMemTy,
                                        Type aMemTy, Type bMemTy, Type cMemTy, Value atomVal,
                                        Value d, Value a, Value b, Value c) const;
+    ::mlir::FailureOr<::mlir::Value> emitAtomCallSSA(OpBuilder &builder, Location loc, Type resultTy,
+                                                     Type mmaAtomTy, Type dTy, Type aTy, Type bTy, Type cTy,
+                                                     Value atomVal, Value d, Value a, Value b, Value c) const;

     bool isStateful() const;
     ::mlir::Type getConvertedType(::mlir::MLIRContext *ctx) const;
diff --git a/include/flydsl/Dialect/Fly/Transforms/Passes.td b/include/flydsl/Dialect/Fly/Transforms/Passes.td
index 2f0c8d80..6277bd71 100644
--- a/include/flydsl/Dialect/Fly/Transforms/Passes.td
+++ b/include/flydsl/Dialect/Fly/Transforms/Passes.td
@@ -38,10 +38,10 @@ def FlyCanonicalizePass : Pass<"fly-canonicalize"> {
 def FlyLayoutLoweringPass : Pass<"fly-layout-lowering"> {
   let summary = "Lower layout algebra operations";
-  let description = [{ 
+  let description = [{
     Lowers layout algebra operations to simpler forms.
   }];
-  
+
   let dependentDialects = [
     "arith::ArithDialect",
     "gpu::GPUDialect",
@@ -49,4 +49,17 @@ def FlyLayoutLoweringPass : Pass<"fly-layout-lowering"> {
   ];
 }
 
+def FlyConvertAtomCallToSSAFormPass : Pass<"fly-convert-atom-call-to-ssa-form"> {
+  let summary = "Convert copy_atom_call/mma_atom_call to SSA form";
+  let description = [{
+    Converts copy_atom_call and mma_atom_call to their SSA counterparts
+    (copy_atom_call_ssa / mma_atom_call_ssa) where possible. Register-resident
+    tensors are promoted to scalar or vector SSA values; all other memref
+    operands are left unchanged.
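+    A register tensor qualifies for promotion only when its layout coalesces to a
+    single contiguous leaf (unit stride or size 1); such an operand is read with
+    ptr_load into a scalar or vector SSA value, and results are written back with
+    ptr_store.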
+ }]; + + let dependentDialects = [ + "vector::VectorDialect" + ]; +} + #endif // FLY_PASSES diff --git a/include/flydsl/Dialect/Fly/Utils/PointerUtils.h b/include/flydsl/Dialect/Fly/Utils/PointerUtils.h index a7fbb90f..44b62526 100644 --- a/include/flydsl/Dialect/Fly/Utils/PointerUtils.h +++ b/include/flydsl/Dialect/Fly/Utils/PointerUtils.h @@ -4,15 +4,19 @@ #ifndef FLYDSL_DIALECT_FLY_UTILS_POINTERUTILS_H #define FLYDSL_DIALECT_FLY_UTILS_POINTERUTILS_H +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Value.h" -#include "flydsl/Dialect/Fly/IR/FlyDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" namespace mlir::fly { -TypedValue applySwizzleOnPtr(OpBuilder &b, Location loc, TypedValue ptr, SwizzleAttr swizzle); +TypedValue applySwizzleOnPtr(OpBuilder &b, Location loc, + TypedValue ptr, + SwizzleAttr swizzle); + +Type RegMem2SSAType(fly::MemRefType memRefTy); } // namespace mlir::fly diff --git a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp index c98770ac..6582a58d 100644 --- a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp +++ b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp @@ -200,18 +200,6 @@ class PtrToIntOpLowering : public OpConversionPattern { } }; -class GetIterOpLowering : public OpConversionPattern { -public: - GetIterOpLowering(const TypeConverter &typeConverter, MLIRContext *context) - : OpConversionPattern(typeConverter, context) {} - - LogicalResult matchAndRewrite(GetIterOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOp(op, adaptor.getMemref()); - return success(); - } -}; - class ApplySwizzleOpLowering : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -537,6 +525,45 @@ class CopyAtomCallLowering : public OpConversionPattern { } }; +class CopyAtomCallSSALowering : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(CopyAtomCallSSA op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Type copyAtomType = op.getCopyAtom().getType(); + auto copyAtom = dyn_cast(copyAtomType); + if (!copyAtom) + return rewriter.notifyMatchFailure(op, "copyAtom is not CopyAtomType"); + + Location loc = op.getLoc(); + bool hasResult = op.getResults().size() > 0; + Type srcTy = op.getSrc().getType(); + Value pred = adaptor.getPred(); + + Type resultTy = hasResult ? op.getResult(0).getType() : Type{}; + Type dstTy = op.getDst() ? 
op.getDst().getType() : Type{}; + + FailureOr result; + if (pred) { + result = copyAtom.emitAtomCallSSA(rewriter, loc, resultTy, copyAtomType, srcTy, dstTy, + op.getPred().getType(), adaptor.getCopyAtom(), + adaptor.getSrc(), adaptor.getDst(), pred); + } else { + result = copyAtom.emitAtomCallSSA(rewriter, loc, resultTy, copyAtomType, srcTy, dstTy, + adaptor.getCopyAtom(), adaptor.getSrc(), adaptor.getDst()); + } + if (failed(result)) + return failure(); + + if (hasResult) + rewriter.replaceOp(op, *result); + else + rewriter.eraseOp(op); + return success(); + } +}; + class MmaAtomCallLowering : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -575,6 +602,38 @@ class MmaAtomCallLowering : public OpConversionPattern { } }; +class MmaAtomCallSSALowering : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(MmaAtomCallSSA op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto mmaAtomTy = dyn_cast(op.getMmaAtom().getType()); + if (!mmaAtomTy) + return rewriter.notifyMatchFailure(op, "expected MmaAtomType for mmaAtom operand"); + + Location loc = op.getLoc(); + bool hasResult = op.getResults().size() > 0; + + Type resultTy = hasResult ? op.getResult(0).getType() : Type{}; + Type dTy = op.getD() ? op.getD().getType() : Type{}; + Value dPtr = hasResult ? Value{} : adaptor.getD(); + + auto result = + mmaAtomTy.emitAtomCallSSA(rewriter, loc, resultTy, mmaAtomTy, dTy, op.getA().getType(), + op.getB().getType(), op.getC().getType(), adaptor.getMmaAtom(), + dPtr, adaptor.getA(), adaptor.getB(), adaptor.getC()); + if (failed(result)) + return failure(); + + if (hasResult) + rewriter.replaceOp(op, *result); + else + rewriter.eraseOp(op); + return success(); + } +}; + /// Lower `gpu.launch_func` kernel operands so that any `!fly.memref` values are /// replaced by their type-converted builtin `memref` values. 
This prevents /// `unrealized_conversion_cast` materializations from remaining live after @@ -739,8 +798,7 @@ class FlyToROCDLConversionPass patterns.add(typeConverter, context); patterns.add(typeConverter, context); - patterns.add(typeConverter, - context); + patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); @@ -748,6 +806,7 @@ class FlyToROCDLConversionPass patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); + patterns.add(typeConverter, context); patterns.add(typeConverter, context); // TODO: deprecated in the future @@ -802,15 +861,3 @@ class FlyROCDLClusterAttrPass }; } // namespace - -namespace impl { - -std::unique_ptr<::mlir::Pass> createFlyToROCDLConversionPass() { - return std::make_unique(); -} - -std::unique_ptr<::mlir::Pass> createFlyROCDLClusterAttrPass() { - return std::make_unique(); -} - -} // namespace impl diff --git a/lib/Dialect/Fly/CMakeLists.txt b/lib/Dialect/Fly/CMakeLists.txt index a4350259..3e0b6f14 100644 --- a/lib/Dialect/Fly/CMakeLists.txt +++ b/lib/Dialect/Fly/CMakeLists.txt @@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRFlyDialect Transforms/LayoutLowering.cpp Transforms/Canonicalize.cpp Transforms/RewriteFuncSignature.cpp + Transforms/ConvertAtomCallToSSAForm.cpp DEPENDS MLIRFlyIncGen diff --git a/lib/Dialect/Fly/IR/FlyTypeDefs.cpp b/lib/Dialect/Fly/IR/FlyTypeDefs.cpp index 44b5277e..3571d173 100644 --- a/lib/Dialect/Fly/IR/FlyTypeDefs.cpp +++ b/lib/Dialect/Fly/IR/FlyTypeDefs.cpp @@ -422,6 +422,22 @@ LogicalResult CopyAtomType::emitAtomCall(OpBuilder &builder, Location loc, Type pred); } +FailureOr CopyAtomType::emitAtomCallSSA(OpBuilder &builder, Location loc, Type resultTy, + Type copyAtomTy, Type srcTy, Type dstTy, + Value atomVal, Value src, Value dst) const { + return cast(getCopyOp()) + .emitAtomCallSSA(builder, loc, resultTy, copyAtomTy, srcTy, dstTy, atomVal, src, dst); +} + +FailureOr CopyAtomType::emitAtomCallSSA(OpBuilder &builder, Location loc, Type resultTy, + Type copyAtomTy, Type srcTy, Type dstTy, Type predTy, + Value atomVal, Value src, Value dst, + Value pred) const { + return cast(getCopyOp()) + .emitAtomCallSSA(builder, loc, resultTy, copyAtomTy, srcTy, dstTy, predTy, atomVal, src, dst, + pred); +} + bool CopyAtomType::isStateful() const { return isa(getCopyOp()); } Type CopyAtomType::getConvertedType(MLIRContext *ctx) const { @@ -476,6 +492,13 @@ LogicalResult MmaAtomType::emitAtomCall(OpBuilder &builder, Location loc, Type m return cast(getMmaOp()) .emitAtomCall(builder, loc, mmaAtomTy, dMemTy, aMemTy, bMemTy, cMemTy, atomVal, d, a, b, c); } +FailureOr MmaAtomType::emitAtomCallSSA(OpBuilder &builder, Location loc, Type resultTy, + Type mmaAtomTy, Type dTy, Type aTy, Type bTy, + Type cTy, Value atomVal, Value d, Value a, Value b, + Value c) const { + return cast(getMmaOp()) + .emitAtomCallSSA(builder, loc, resultTy, mmaAtomTy, dTy, aTy, bTy, cTy, atomVal, d, a, b, c); +} bool MmaAtomType::isStateful() const { return isa(getMmaOp()); } diff --git a/lib/Dialect/Fly/IR/FlyUniversalOps.cpp b/lib/Dialect/Fly/IR/FlyUniversalOps.cpp index 50baa5d2..f0b9eaf2 100644 --- a/lib/Dialect/Fly/IR/FlyUniversalOps.cpp +++ b/lib/Dialect/Fly/IR/FlyUniversalOps.cpp @@ -122,6 +122,62 @@ void MmaOpUniversalFMAType::print(AsmPrinter &printer) const { printer << ", (" << getElemTy() << ", " << getElemTy() << ") -> " << getElemTy() << ">"; } +FailureOr 
CopyOpUniversalCopyType::emitAtomCallSSA(OpBuilder &builder, Location loc, + Type resultTy, Type copyAtomTyArg, + Type srcTyArg, Type dstTyArg, + Value atomVal, Value src, + Value dst) const { + Value result; + if (isa(srcTyArg)) { + // src is memory + auto srcMemTy = cast(srcTyArg); + Type loadTy = resultTy ? resultTy : builder.getIntegerType(getBitSize()); + Value srcPtr = applySwizzleOnPtr(builder, loc, cast>(src), + srcMemTy.getSwizzle()); + result = LLVM::LoadOp::create(builder, loc, loadTy, srcPtr); + } else { + // src is register + result = src; + } + + if (!resultTy) { + // dst is memory + auto dstMemTy = cast(dstTyArg); + Value dstPtr = applySwizzleOnPtr(builder, loc, cast>(dst), + dstMemTy.getSwizzle()); + LLVM::StoreOp::create(builder, loc, result, dstPtr); + } + return result; +} + +FailureOr CopyOpUniversalCopyType::emitAtomCallSSA(OpBuilder &builder, Location loc, + Type resultTy, Type copyAtomTyArg, + Type srcTyArg, Type dstTyArg, + Type predTyArg, Value atomVal, Value src, + Value dst, Value pred) const { + OpBuilder::InsertionGuard guard(builder); + if (resultTy) { + auto ifOp = scf::IfOp::create(builder, loc, resultTy, pred, /*withElseRegion=*/true); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + auto result = emitAtomCallSSA(builder, loc, resultTy, copyAtomTyArg, srcTyArg, dstTyArg, + atomVal, src, dst); + if (failed(result)) + return failure(); + scf::YieldOp::create(builder, loc, *result); + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + scf::YieldOp::create(builder, loc, dst); + return ifOp.getResult(0); + } + + auto ifOp = scf::IfOp::create(builder, loc, TypeRange{}, pred, /*withElse=*/false); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + auto result = + emitAtomCallSSA(builder, loc, resultTy, copyAtomTyArg, srcTyArg, dstTyArg, atomVal, src, dst); + if (failed(result)) + return failure(); + return Value(); +} + LogicalResult CopyOpUniversalCopyType::emitAtomCall(OpBuilder &builder, Location loc, Type copyAtomTyArg, Type srcMemTyArg, Type dstMemTyArg, Value atomVal, Value src, @@ -172,22 +228,17 @@ static std::optional convertAtomicOp(AtomicOp binOp, bool isF return isFloat ? std::nullopt : std::optional(LLVM::AtomicBinOp::uinc_wrap); case AtomicOp::Dec: return isFloat ? std::nullopt : std::optional(LLVM::AtomicBinOp::udec_wrap); - default: - return std::nullopt; } + return std::nullopt; } -LogicalResult CopyOpUniversalAtomicType::emitAtomCall(OpBuilder &builder, Location loc, - Type copyAtomTyArg, Type srcMemTyArg, - Type dstMemTyArg, Value atomVal, Value src, - Value dst) const { - if (!isa(src.getType()) || !isa(dst.getType())) - return failure(); - - auto srcMemTy = cast(srcMemTyArg); - auto dstMemTy = cast(dstMemTyArg); - - if (srcMemTy.getAddressSpace().getValue() != AddressSpace::Register) +FailureOr CopyOpUniversalAtomicType::emitAtomCallSSA(OpBuilder &builder, Location loc, + Type resultTy, Type copyAtomTyArg, + Type srcTyArg, Type dstTyArg, + Value atomVal, Value src, + Value dst) const { + auto dstMemTy = dstTyArg ? 
dyn_cast(dstTyArg) : fly::MemRefType(); + if (!dstMemTy) return failure(); Type elemTy = getValType(); @@ -196,12 +247,50 @@ LogicalResult CopyOpUniversalAtomicType::emitAtomCall(OpBuilder &builder, Locati Value dstPtr = applySwizzleOnPtr(builder, loc, cast>(dst), dstMemTy.getSwizzle()); - Value loaded = LLVM::LoadOp::create(builder, loc, elemTy, src); - auto binOp = convertAtomicOp(getAtomicOp().getValue(), isFloat); if (!binOp) return failure(); - LLVM::AtomicRMWOp::create(builder, loc, *binOp, dstPtr, loaded, LLVM::AtomicOrdering::monotonic); + LLVM::AtomicRMWOp::create(builder, loc, *binOp, dstPtr, src, LLVM::AtomicOrdering::monotonic); + return src; +} + +FailureOr CopyOpUniversalAtomicType::emitAtomCallSSA( + OpBuilder &builder, Location loc, Type resultTy, Type copyAtomTyArg, Type srcTyArg, + Type dstTyArg, Type predTyArg, Value atomVal, Value src, Value dst, Value pred) const { + OpBuilder::InsertionGuard guard(builder); + if (resultTy) { + auto ifOp = scf::IfOp::create(builder, loc, resultTy, pred, /*withElseRegion=*/true); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + auto result = emitAtomCallSSA(builder, loc, resultTy, copyAtomTyArg, srcTyArg, dstTyArg, + atomVal, src, dst); + if (failed(result)) + return failure(); + scf::YieldOp::create(builder, loc, *result); + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + scf::YieldOp::create(builder, loc, dst); + return ifOp.getResult(0); + } + + auto ifOp = scf::IfOp::create(builder, loc, TypeRange{}, pred, /*withElse=*/false); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + auto result = + emitAtomCallSSA(builder, loc, resultTy, copyAtomTyArg, srcTyArg, dstTyArg, atomVal, src, dst); + if (failed(result)) + return failure(); + return Value(); +} + +LogicalResult CopyOpUniversalAtomicType::emitAtomCall(OpBuilder &builder, Location loc, + Type copyAtomTyArg, Type srcMemTyArg, + Type dstMemTyArg, Value atomVal, Value src, + Value dst) const { + auto srcMemTy = cast(srcMemTyArg); + auto srcSSATy = fly::RegMem2SSAType(srcMemTy); + Value srcVal = LLVM::LoadOp::create(builder, loc, srcSSATy, src); + auto res = emitAtomCallSSA(builder, loc, Type{}, copyAtomTyArg, srcSSATy, dstMemTyArg, atomVal, + srcVal, dst); + if (failed(res)) + return failure(); return success(); } @@ -218,6 +307,19 @@ LogicalResult CopyOpUniversalAtomicType::emitAtomCall(OpBuilder &builder, Locati return emitAtomCall(builder, loc, copyAtomTyArg, srcMemTyArg, dstMemTyArg, atomVal, src, dst); } +FailureOr MmaOpUniversalFMAType::emitAtomCallSSA(OpBuilder &builder, Location loc, + Type resultTy, Type mmaAtomTyArg, + Type dTyArg, Type aTyArg, Type bTyArg, + Type cTyArg, Value atomVal, Value d, + Value a, Value b, Value c) const { + Type elemTy = getElemTy(); + Value mul = LLVM::FMulOp::create(builder, loc, elemTy, a, b); + Value res = LLVM::FAddOp::create(builder, loc, elemTy, mul, c); + if (d) + LLVM::StoreOp::create(builder, loc, res, d); + return res; +} + LogicalResult MmaOpUniversalFMAType::emitAtomCall(OpBuilder &builder, Location loc, Type mmaAtomTy, Type dMemTy, Type aMemTy, Type bMemTy, Type cMemTy, Value atomVal, Value dPtr, diff --git a/lib/Dialect/Fly/Transforms/Canonicalize.cpp b/lib/Dialect/Fly/Transforms/Canonicalize.cpp index 79a4c453..85ccb22b 100644 --- a/lib/Dialect/Fly/Transforms/Canonicalize.cpp +++ b/lib/Dialect/Fly/Transforms/Canonicalize.cpp @@ -77,11 +77,3 @@ class FlyCanonicalizePass : public mlir::fly::impl::FlyCanonicalizePassBase createFlyCanonicalizePass() { - return std::make_unique(); 
-} - -} // namespace impl diff --git a/lib/Dialect/Fly/Transforms/ConvertAtomCallToSSAForm.cpp b/lib/Dialect/Fly/Transforms/ConvertAtomCallToSSAForm.cpp new file mode 100644 index 00000000..47cb5e63 --- /dev/null +++ b/lib/Dialect/Fly/Transforms/ConvertAtomCallToSSAForm.cpp @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2025 FlyDSL Project Contributors + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" + +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/Fly/Transforms/Passes.h" +#include "flydsl/Dialect/Fly/Utils/LayoutUtils.h" +#include "flydsl/Dialect/Fly/Utils/PointerUtils.h" + +using namespace mlir; +using namespace mlir::fly; + +namespace mlir { +namespace fly { +#define GEN_PASS_DEF_FLYCONVERTATOMCALLTOSSAFORMPASS +#include "flydsl/Dialect/Fly/Transforms/Passes.h.inc" +} // namespace fly +} // namespace mlir + +namespace { + +bool isEligibleToPromote(fly::MemRefType memRefTy) { + if (memRefTy.getAddressSpace().getValue() != AddressSpace::Register) + return false; + auto layoutAttr = dyn_cast(memRefTy.getLayout()); + if (!layoutAttr) + return false; + LayoutBuilder builder(memRefTy.getContext()); + auto coalesced = layoutCoalesce(builder, layoutAttr); + if (!coalesced.isLeaf()) + return false; + return coalesced.getStride().isLeafStaticValue(1) || coalesced.getShape().isLeafStaticValue(1); +} + +class FlyConvertAtomCallToSSAFormPass + : public mlir::fly::impl::FlyConvertAtomCallToSSAFormPassBase { +public: + using mlir::fly::impl::FlyConvertAtomCallToSSAFormPassBase< + FlyConvertAtomCallToSSAFormPass>::FlyConvertAtomCallToSSAFormPassBase; + + void runOnOperation() override { + auto moduleOp = getOperation(); + + SmallVector copyOpsToConvert; + SmallVector mmaOpsToConvert; + + moduleOp->walk([&](CopyAtomCall op) { + auto srcTy = cast(op.getSrc().getType()); + auto dstTy = cast(op.getDst().getType()); + if (isEligibleToPromote(srcTy) || isEligibleToPromote(dstTy)) + copyOpsToConvert.push_back(op); + }); + + moduleOp->walk([&](MmaAtomCall op) { + auto dTy = cast(op.getD().getType()); + auto aTy = cast(op.getA().getType()); + auto bTy = cast(op.getB().getType()); + auto cTy = cast(op.getC().getType()); + if (isEligibleToPromote(dTy) || isEligibleToPromote(aTy) || isEligibleToPromote(bTy) || + isEligibleToPromote(cTy)) + mmaOpsToConvert.push_back(op); + }); + + OpBuilder builder(moduleOp->getContext()); + + for (CopyAtomCall copyOp : copyOpsToConvert) { + auto srcTy = cast(copyOp.getSrc().getType()); + auto dstTy = cast(copyOp.getDst().getType()); + bool srcEligible = isEligibleToPromote(srcTy); + bool dstEligible = isEligibleToPromote(dstTy); + + builder.setInsertionPoint(copyOp); + Location loc = copyOp.getLoc(); + + Value srcVal = copyOp.getSrc(); + if (srcEligible) { + Value srcIter = srcVal.getDefiningOp().getIter(); + srcVal = PtrLoadOp::create(builder, loc, RegMem2SSAType(srcTy), srcIter); + } + + Value pred = copyOp.getPred(); + if (pred) { + auto predMemTy = cast(pred.getType()); + if (isEligibleToPromote(predMemTy)) { + Value predIter = pred.getDefiningOp().getIter(); + pred = PtrLoadOp::create(builder, loc, RegMem2SSAType(predMemTy), predIter); + } + } + + if (dstEligible) { + auto ssaTy = RegMem2SSAType(dstTy); + Value dstIter = copyOp.getDst().getDefiningOp().getIter(); + Value oldDst = pred ? 
PtrLoadOp::create(builder, loc, ssaTy, dstIter) : Value{}; + auto ssaOp = + CopyAtomCallSSA::create(builder, loc, TypeRange{ssaTy}, copyOp.getCopyAtom(), srcVal, + /*dst=*/oldDst, pred); + PtrStoreOp::create(builder, loc, ssaOp.getResult(0), dstIter); + } else { + CopyAtomCallSSA::create(builder, loc, TypeRange{}, copyOp.getCopyAtom(), srcVal, + /*dst=*/copyOp.getDst(), pred); + } + copyOp->erase(); + } + + for (MmaAtomCall mmaOp : mmaOpsToConvert) { + auto dTy = cast(mmaOp.getD().getType()); + auto aTy = cast(mmaOp.getA().getType()); + auto bTy = cast(mmaOp.getB().getType()); + auto cTy = cast(mmaOp.getC().getType()); + bool dEligible = isEligibleToPromote(dTy); + bool aEligible = isEligibleToPromote(aTy); + bool bEligible = isEligibleToPromote(bTy); + bool cEligible = isEligibleToPromote(cTy); + + builder.setInsertionPoint(mmaOp); + Location loc = mmaOp.getLoc(); + + Value aVal = mmaOp.getA(); + Value bVal = mmaOp.getB(); + Value cVal = mmaOp.getC(); + + if (aEligible) { + Value aIter = aVal.getDefiningOp().getIter(); + aVal = PtrLoadOp::create(builder, loc, RegMem2SSAType(aTy), aIter).getResult(); + } + if (bEligible) { + Value bIter = bVal.getDefiningOp().getIter(); + bVal = PtrLoadOp::create(builder, loc, RegMem2SSAType(bTy), bIter).getResult(); + } + if (cEligible) { + Value cIter = cVal.getDefiningOp().getIter(); + cVal = PtrLoadOp::create(builder, loc, RegMem2SSAType(cTy), cIter).getResult(); + } + + if (dEligible) { + auto ssaOp = MmaAtomCallSSA::create(builder, loc, TypeRange{RegMem2SSAType(dTy)}, + mmaOp.getMmaAtom(), /*d=*/nullptr, aVal, bVal, cVal); + Value dIter = mmaOp.getD().getDefiningOp().getIter(); + PtrStoreOp::create(builder, loc, ssaOp.getResult(0), dIter); + } else { + MmaAtomCallSSA::create(builder, loc, TypeRange{}, mmaOp.getMmaAtom(), /*d=*/mmaOp.getD(), + aVal, bVal, cVal); + } + mmaOp->erase(); + } + } +}; + +} // namespace diff --git a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp index 7a15dd95..9ecee02f 100644 --- a/lib/Dialect/Fly/Transforms/LayoutLowering.cpp +++ b/lib/Dialect/Fly/Transforms/LayoutLowering.cpp @@ -419,7 +419,8 @@ class GetScalarLowering : public OpRewritePattern { while (!scalarAttr.isLeaf() && scalarAttr.rank() == 1) scalarAttr = scalarAttr.at(0); if (!scalarAttr.isLeaf()) - return rewriter.notifyMatchFailure(op, "expected leaf IntTupleAttr after unwrapping rank-1 chain"); + return rewriter.notifyMatchFailure( + op, "expected leaf IntTupleAttr after unwrapping rank-1 chain"); auto intAttr = scalarAttr.extractIntFromLeaf(); if (intAttr.isStatic()) { Type resultTy = op.getResult().getType(); @@ -2749,9 +2750,9 @@ class FlyLayoutLoweringPass RewritePatternSet patterns(context); // Constructors - patterns.add(context); + patterns + .add(context); // Extractors patterns.add createFlyLayoutLoweringPass() { - return std::make_unique(); -} - -} // namespace impl diff --git a/lib/Dialect/Fly/Utils/PointerUtils.cpp b/lib/Dialect/Fly/Utils/PointerUtils.cpp index 4650c009..80210a01 100644 --- a/lib/Dialect/Fly/Utils/PointerUtils.cpp +++ b/lib/Dialect/Fly/Utils/PointerUtils.cpp @@ -1,13 +1,17 @@ // SPDX-License-Identifier: Apache-2.0 // Copyright (c) 2025 FlyDSL Project Contributors -#include "flydsl/Dialect/Fly/Utils/PointerUtils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "flydsl/Dialect/Fly/Utils/LayoutUtils.h" +#include "flydsl/Dialect/Fly/Utils/PointerUtils.h" + namespace mlir::fly { -TypedValue applySwizzleOnPtr(OpBuilder &b, Location loc, 
TypedValue ptr, SwizzleAttr swizzle) { +TypedValue applySwizzleOnPtr(OpBuilder &b, Location loc, + TypedValue ptr, + SwizzleAttr swizzle) { if (swizzle.isTrivialSwizzle()) return ptr; auto ptrTy = ptr.getType(); @@ -20,7 +24,20 @@ TypedValue applySwizzleOnPtr(OpBuilder &b, Location loc, Value masked = arith::AndIOp::create(b, loc, ptrInt, bitMask); Value shifted = arith::ShRUIOp::create(b, loc, masked, shiftAmt); Value swizzled = arith::XOrIOp::create(b, loc, ptrInt, shifted); - return cast>(LLVM::IntToPtrOp::create(b, loc, ptrTy, swizzled).getResult()); + return cast>( + LLVM::IntToPtrOp::create(b, loc, ptrTy, swizzled).getResult()); +} + +Type RegMem2SSAType(fly::MemRefType memRefTy) { + if (memRefTy.getAddressSpace().getValue() != AddressSpace::Register) + return Type(); + LayoutBuilder builder(memRefTy.getContext()); + auto layoutAttr = cast(memRefTy.getLayout()); + int32_t cosize = layoutCosize(builder, layoutAttr).getLeafAsInt().getValue(); + Type elemTy = memRefTy.getElemTy(); + if (cosize == 1) + return elemTy; + return VectorType::get({cosize}, elemTy); } } // namespace mlir::fly diff --git a/lib/Dialect/FlyROCDL/CDNA3/CopyAtom.cpp b/lib/Dialect/FlyROCDL/CDNA3/CopyAtom.cpp index f4cf7a41..a182dc72 100644 --- a/lib/Dialect/FlyROCDL/CDNA3/CopyAtom.cpp +++ b/lib/Dialect/FlyROCDL/CDNA3/CopyAtom.cpp @@ -8,6 +8,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/Fly/Utils/PointerUtils.h" #include "flydsl/Dialect/Fly/Utils/ThrValLayoutMacro.h.inc" #include "flydsl/Dialect/FlyROCDL/IR/Dialect.h" #include "flydsl/Dialect/FlyROCDL/Utils/BufferFatPtr.h" @@ -64,62 +65,122 @@ Attribute CopyOpCDNA3BufferCopyType::getThrBitLayoutRef() const { return FxLayout(FxShape(FxC(1), FxC(getBitSize())), FxStride(FxC(1), FxC(1))); } -LogicalResult CopyOpCDNA3BufferCopyType::emitAtomCall(OpBuilder &builder, Location loc, - Type copyAtomTyArg, Type srcMemTyArg, - Type dstMemTyArg, Value atomVal, Value src, - Value dst) const { - auto srcMemTy = cast(srcMemTyArg); - auto dstMemTy = cast(dstMemTyArg); - +FailureOr CopyOpCDNA3BufferCopyType::emitAtomCallSSA(OpBuilder &builder, Location loc, + Type resultTy, Type copyAtomTyArg, + Type srcTyArg, Type dstTyArg, + Value atomVal, Value src, + Value dst) const { IntegerType copyTy = builder.getIntegerType(getBitSize()); - AddressSpace srcAS = srcMemTy.getAddressSpace().getValue(); - AddressSpace dstAS = dstMemTy.getAddressSpace().getValue(); - - bool srcIsBuffer = (srcAS == AddressSpace::BufferDesc); - bool dstIsBuffer = (dstAS == AddressSpace::BufferDesc); - - if (srcIsBuffer == dstIsBuffer) - return failure(); - Value soffsetRaw = LLVM::ExtractValueOp::create( builder, loc, atomVal, ArrayRef{*getFieldIndex(AtomStateField::Soffset)}); - fly::MemRefType bufferMemTy = srcIsBuffer ? 
srcMemTy : dstMemTy; - int64_t elemBits = bufferMemTy.getElemTy().getIntOrFloatBitWidth(); - Value soffset; - if (elemBits == 8) { - soffset = soffsetRaw; - } else if (elemBits > 8 && elemBits % 8 == 0) { - Value scale = arith::ConstantIntOp::create(builder, loc, elemBits / 8, 32); - soffset = arith::MulIOp::create(builder, loc, soffsetRaw, scale); - } else { + auto computeSoffset = [&](int64_t elemBits) -> Value { + if (elemBits == 8) + return soffsetRaw; + if (elemBits > 8 && elemBits % 8 == 0) { + Value scale = arith::ConstantIntOp::create(builder, loc, elemBits / 8, 32); + return arith::MulIOp::create(builder, loc, soffsetRaw, scale); + } Value scale = arith::ConstantIntOp::create(builder, loc, elemBits, 32); Value bits = arith::MulIOp::create(builder, loc, soffsetRaw, scale); Value eight = arith::ConstantIntOp::create(builder, loc, 8, 32); - soffset = arith::DivUIOp::create(builder, loc, bits, eight); - } + return arith::DivUIOp::create(builder, loc, bits, eight); + }; Value zero = arith::ConstantIntOp::create(builder, loc, 0, 32); ArrayAttr noAttrs; - auto unpackBuffer = [&](Value val, fly::MemRefType flyTy) -> std::pair { - BufferFatPtr bp(flyTy.getPointerType(), val); - return {bp.bufferRsrc(builder, loc), bp.swizzleByteOffset(builder, loc)}; - }; + auto srcMemTy = srcTyArg ? dyn_cast(srcTyArg) : fly::MemRefType(); + auto dstMemTy = dstTyArg ? dyn_cast(dstTyArg) : fly::MemRefType(); + + if (srcMemTy && srcMemTy.getAddressSpace().getValue() == AddressSpace::BufferDesc) { + // buffer -> reg + Value soffset = computeSoffset(srcMemTy.getElemTy().getIntOrFloatBitWidth()); + BufferFatPtr bp(srcMemTy.getPointerType(), src); + Value srcRsrc = bp.bufferRsrc(builder, loc); + Value srcOff = bp.swizzleByteOffset(builder, loc); - if (srcIsBuffer && !dstIsBuffer) { - auto [srcRsrc, srcOff] = unpackBuffer(src, srcMemTy); Value loaded = ROCDL::RawPtrBufferLoadOp::create(builder, loc, copyTy, srcRsrc, srcOff, soffset, zero, noAttrs, noAttrs, noAttrs); - LLVM::StoreOp::create(builder, loc, loaded, dst); - } else if (!srcIsBuffer && dstIsBuffer) { - auto [dstRsrc, dstOff] = unpackBuffer(dst, dstMemTy); - Value loaded = LLVM::LoadOp::create(builder, loc, copyTy, src); - ROCDL::RawPtrBufferStoreOp::create(builder, loc, loaded, dstRsrc, dstOff, soffset, zero, + if (resultTy && loaded.getType() != resultTy) + loaded = LLVM::BitcastOp::create(builder, loc, resultTy, loaded); + return loaded; + } + + if (dstMemTy && dstMemTy.getAddressSpace().getValue() == AddressSpace::BufferDesc) { + // reg -> buffer + Value soffset = computeSoffset(dstMemTy.getElemTy().getIntOrFloatBitWidth()); + BufferFatPtr bp(dstMemTy.getPointerType(), dst); + Value dstRsrc = bp.bufferRsrc(builder, loc); + Value dstOff = bp.swizzleByteOffset(builder, loc); + + Value stored = src; + if (stored.getType() != copyTy) + stored = LLVM::BitcastOp::create(builder, loc, copyTy, stored); + ROCDL::RawPtrBufferStoreOp::create(builder, loc, stored, dstRsrc, dstOff, soffset, zero, noAttrs, noAttrs, noAttrs); + return stored; + } + + return failure(); +} + +FailureOr CopyOpCDNA3BufferCopyType::emitAtomCallSSA( + OpBuilder &builder, Location loc, Type resultTy, Type copyAtomTyArg, Type srcTyArg, + Type dstTyArg, Type predTyArg, Value atomVal, Value src, Value dst, Value pred) const { + OpBuilder::InsertionGuard guard(builder); + if (resultTy) { + // buffer -> reg + auto ifOp = scf::IfOp::create(builder, loc, resultTy, pred, /*withElseRegion=*/true); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + auto result = 
emitAtomCallSSA(builder, loc, resultTy, copyAtomTyArg, srcTyArg, dstTyArg, + atomVal, src, dst); + if (failed(result)) + return failure(); + scf::YieldOp::create(builder, loc, *result); + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + scf::YieldOp::create(builder, loc, dst); + return ifOp.getResult(0); } else { + // reg -> buffer + auto ifOp = scf::IfOp::create(builder, loc, TypeRange{}, pred, /*withElse=*/false); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + auto result = emitAtomCallSSA(builder, loc, resultTy, copyAtomTyArg, srcTyArg, dstTyArg, + atomVal, src, dst); + if (failed(result)) + return failure(); + return Value(); + } +} + +LogicalResult CopyOpCDNA3BufferCopyType::emitAtomCall(OpBuilder &builder, Location loc, + Type copyAtomTyArg, Type srcMemTyArg, + Type dstMemTyArg, Value atomVal, Value src, + Value dst) const { + auto srcMemTy = cast(srcMemTyArg); + auto dstMemTy = cast(dstMemTyArg); + + bool srcIsBuffer = (srcMemTy.getAddressSpace().getValue() == AddressSpace::BufferDesc); + bool dstIsBuffer = (dstMemTy.getAddressSpace().getValue() == AddressSpace::BufferDesc); + + if (srcIsBuffer == dstIsBuffer) return failure(); + + if (srcIsBuffer) { + auto dstSSATy = fly::RegMem2SSAType(dstMemTy); + auto res = emitAtomCallSSA(builder, loc, dstSSATy, copyAtomTyArg, srcMemTyArg, Type{}, atomVal, + src, Value{}); + if (failed(res)) + return failure(); + LLVM::StoreOp::create(builder, loc, *res, dst); + } else { + auto srcSSATy = fly::RegMem2SSAType(srcMemTy); + Value srcVal = LLVM::LoadOp::create(builder, loc, srcSSATy, src); + auto res = emitAtomCallSSA(builder, loc, Type{}, copyAtomTyArg, srcSSATy, dstMemTyArg, atomVal, + srcVal, dst); + if (failed(res)) + return failure(); } return success(); } @@ -129,6 +190,7 @@ LogicalResult CopyOpCDNA3BufferCopyType::emitAtomCall(OpBuilder &builder, Locati Type dstMemTyArg, Type predMemTyArg, Value atomVal, Value src, Value dst, Value pred) const { + OpBuilder::InsertionGuard guard(builder); auto predMemTy = cast(predMemTyArg); Value predVal = LLVM::LoadOp::create(builder, loc, predMemTy.getElemTy(), pred); auto ifOp = scf::IfOp::create(builder, loc, TypeRange{}, predVal, /*withElse=*/false); @@ -198,6 +260,25 @@ Attribute CopyOpCDNA3BufferCopyLDSType::getThrBitLayoutRef() const { return FxLayout(FxShape(FxC(1), FxC(getBitSize())), FxStride(FxC(1), FxC(1))); } +FailureOr CopyOpCDNA3BufferCopyLDSType::emitAtomCallSSA(OpBuilder &builder, Location loc, + Type resultTy, Type copyAtomTyArg, + Type srcTyArg, Type dstTyArg, + Value atomVal, Value src, + Value dst) const { + if (failed(emitAtomCall(builder, loc, copyAtomTyArg, srcTyArg, dstTyArg, atomVal, src, dst))) + return failure(); + return Value{}; +} + +FailureOr CopyOpCDNA3BufferCopyLDSType::emitAtomCallSSA( + OpBuilder &builder, Location loc, Type resultTy, Type copyAtomTyArg, Type srcTyArg, + Type dstTyArg, Type predTyArg, Value atomVal, Value src, Value dst, Value pred) const { + if (failed(emitAtomCall(builder, loc, copyAtomTyArg, srcTyArg, dstTyArg, predTyArg, atomVal, src, + dst, pred))) + return failure(); + return Value{}; +} + LogicalResult CopyOpCDNA3BufferCopyLDSType::emitAtomCall(OpBuilder &builder, Location loc, Type copyAtomTyArg, Type srcMemTyArg, Type dstMemTyArg, Value atomVal, Value src, @@ -250,6 +331,7 @@ LogicalResult CopyOpCDNA3BufferCopyLDSType::emitAtomCall(OpBuilder &builder, Loc Type dstMemTyArg, Type predMemTyArg, Value atomVal, Value src, Value dst, Value pred) const { + OpBuilder::InsertionGuard guard(builder); auto predMemTy 
= cast(predMemTyArg); Value predVal = LLVM::LoadOp::create(builder, loc, predMemTy.getElemTy(), pred); auto ifOp = scf::IfOp::create(builder, loc, TypeRange{}, predVal, /*withElse=*/false); @@ -310,17 +392,13 @@ Attribute CopyOpCDNA3BufferAtomicType::getThrBitLayoutRef() const { return FxLayout(FxShape(FxC(1), FxC(bits)), FxStride(FxC(1), FxC(1))); } -LogicalResult CopyOpCDNA3BufferAtomicType::emitAtomCall(OpBuilder &builder, Location loc, - Type copyAtomTyArg, Type srcMemTyArg, - Type dstMemTyArg, Value atomVal, Value src, - Value dst) const { - auto srcMemTy = cast(srcMemTyArg); - auto dstMemTy = cast(dstMemTyArg); - - AddressSpace srcAS = srcMemTy.getAddressSpace().getValue(); - AddressSpace dstAS = dstMemTy.getAddressSpace().getValue(); - - if (srcAS != AddressSpace::Register || dstAS != AddressSpace::BufferDesc) +FailureOr CopyOpCDNA3BufferAtomicType::emitAtomCallSSA(OpBuilder &builder, Location loc, + Type resultTy, Type copyAtomTyArg, + Type srcTyArg, Type dstTyArg, + Value atomVal, Value src, + Value dst) const { + auto dstMemTy = cast(dstTyArg); + if (dstMemTy.getAddressSpace().getValue() != AddressSpace::BufferDesc) return failure(); Type valTy = getValType(); @@ -329,8 +407,6 @@ LogicalResult CopyOpCDNA3BufferAtomicType::emitAtomCall(OpBuilder &builder, Loca scalarTy = vecTy.getElementType(); bool isFloat = isa(scalarTy); - Value loaded = LLVM::LoadOp::create(builder, loc, valTy, src); - BufferFatPtr bp(dstMemTy.getPointerType(), dst); Value dstRsrc = bp.bufferRsrc(builder, loc); Value dstOff = bp.swizzleByteOffset(builder, loc); @@ -361,27 +437,67 @@ LogicalResult CopyOpCDNA3BufferAtomicType::emitAtomCall(OpBuilder &builder, Loca case AtomicOp::Add: if (!isFloat) return failure(); - ROCDL::RawPtrBufferAtomicFaddOp::create(builder, loc, loaded, dstRsrc, dstOff, soffset, zero, + ROCDL::RawPtrBufferAtomicFaddOp::create(builder, loc, src, dstRsrc, dstOff, soffset, zero, noAttrs, noAttrs, noAttrs); break; case AtomicOp::Max: if (isFloat) - ROCDL::RawPtrBufferAtomicFmaxOp::create(builder, loc, loaded, dstRsrc, dstOff, soffset, zero, + ROCDL::RawPtrBufferAtomicFmaxOp::create(builder, loc, src, dstRsrc, dstOff, soffset, zero, noAttrs, noAttrs, noAttrs); else - ROCDL::RawPtrBufferAtomicSmaxOp::create(builder, loc, loaded, dstRsrc, dstOff, soffset, zero, + ROCDL::RawPtrBufferAtomicSmaxOp::create(builder, loc, src, dstRsrc, dstOff, soffset, zero, noAttrs, noAttrs, noAttrs); break; case AtomicOp::Min: if (isFloat) return failure(); - ROCDL::RawPtrBufferAtomicUminOp::create(builder, loc, loaded, dstRsrc, dstOff, soffset, zero, + ROCDL::RawPtrBufferAtomicUminOp::create(builder, loc, src, dstRsrc, dstOff, soffset, zero, noAttrs, noAttrs, noAttrs); break; default: return failure(); } + return src; +} + +FailureOr CopyOpCDNA3BufferAtomicType::emitAtomCallSSA( + OpBuilder &builder, Location loc, Type resultTy, Type copyAtomTyArg, Type srcTyArg, + Type dstTyArg, Type predTyArg, Value atomVal, Value src, Value dst, Value pred) const { + OpBuilder::InsertionGuard guard(builder); + if (resultTy) { + auto ifOp = scf::IfOp::create(builder, loc, resultTy, pred, /*withElseRegion=*/true); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + auto result = emitAtomCallSSA(builder, loc, resultTy, copyAtomTyArg, srcTyArg, dstTyArg, + atomVal, src, dst); + if (failed(result)) + return failure(); + scf::YieldOp::create(builder, loc, *result); + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + scf::YieldOp::create(builder, loc, dst); + return ifOp.getResult(0); + } + + auto 
ifOp = scf::IfOp::create(builder, loc, TypeRange{}, pred, /*withElse=*/false); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + auto result = + emitAtomCallSSA(builder, loc, resultTy, copyAtomTyArg, srcTyArg, dstTyArg, atomVal, src, dst); + if (failed(result)) + return failure(); + return Value(); +} + +LogicalResult CopyOpCDNA3BufferAtomicType::emitAtomCall(OpBuilder &builder, Location loc, + Type copyAtomTyArg, Type srcMemTyArg, + Type dstMemTyArg, Value atomVal, Value src, + Value dst) const { + auto srcMemTy = cast(srcMemTyArg); + auto srcSSATy = fly::RegMem2SSAType(srcMemTy); + Value srcVal = LLVM::LoadOp::create(builder, loc, srcSSATy, src); + auto res = emitAtomCallSSA(builder, loc, Type{}, copyAtomTyArg, srcSSATy, dstMemTyArg, atomVal, + srcVal, dst); + if (failed(res)) + return failure(); return success(); } diff --git a/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp b/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp index 40637c90..f8d898e0 100644 --- a/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp +++ b/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp @@ -1,13 +1,15 @@ // SPDX-License-Identifier: Apache-2.0 // Copyright (c) 2025 FlyDSL Project Contributors +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinTypes.h" #include "flydsl/Dialect/Fly/IR/FlyDialect.h" #include "flydsl/Dialect/Fly/Utils/ThrValLayoutMacro.h.inc" #include "flydsl/Dialect/FlyROCDL/IR/Dialect.h" -#include "mlir/IR/BuiltinTypes.h" using namespace mlir; using namespace mlir::fly; @@ -108,7 +110,7 @@ static Type getMfmaABType(MLIRContext *ctx, Type elemTy, int32_t k = 0) { return VectorType::get({4}, Float16Type::get(ctx)); if (elemTy.isBF16()) return VectorType::get({(k >= 16) ? 
4 : 2}, IntegerType::get(ctx, 16)); - if (isF8(elemTy)) + if (elemTy.getIntOrFloatBitWidth() == 8) return IntegerType::get(ctx, 64); return nullptr; } @@ -161,16 +163,68 @@ static int64_t getMfmaAccVecSize(int32_t m, int32_t k, Type elemTyA) { return 0; } -template -static LogicalResult emitMfma(OpBuilder &builder, Location loc, Type abTyA, Type abTyB, - VectorType accTy, Value aPtr, Value bPtr, Value cPtr, Value dPtr) { - Value a = LLVM::LoadOp::create(builder, loc, abTyA, aPtr); - Value b = LLVM::LoadOp::create(builder, loc, abTyB, bPtr); - Value c = LLVM::LoadOp::create(builder, loc, accTy, cPtr); - auto zeroAttr = builder.getI32IntegerAttr(0); - Value res = MfmaOp::create(builder, loc, accTy, a, b, c, zeroAttr, zeroAttr, zeroAttr); - LLVM::StoreOp::create(builder, loc, res, dPtr); - return success(); +FailureOr MmaOpCDNA3_MFMAType::emitAtomCallSSA(OpBuilder &builder, Location loc, + Type resultTy, Type mmaAtomTyArg, Type dTyArg, + Type aTyArg, Type bTyArg, Type cTyArg, + Value atomVal, Value d, Value a, Value b, + Value c) const { + int32_t m = getM(); + int32_t n = getN(); + int32_t k = getK(); + Type elemTyA = getElemTyA(); + Type elemTyB = getElemTyB(); + MLIRContext *ctx = builder.getContext(); + + Type abTyA = getMfmaABType(ctx, elemTyA, k); + Type abTyB = getMfmaABType(ctx, elemTyB, k); + if (!abTyA || !abTyB) + return failure(); + + int64_t accVecSize = getMfmaAccVecSize(m, k, elemTyA); + if (accVecSize == 0) + return failure(); + + Type accElemTy = getElemTyAcc(); + VectorType accTy = VectorType::get({accVecSize}, accElemTy); + +#define DISPATCH_MFMA_SSA(M_, K_, PRED, OP) \ + if (m == M_ && n == M_ && k == K_ && (PRED)) { \ + auto zeroAttr = builder.getI32IntegerAttr(0); \ + return ROCDL::OP::create(builder, loc, accTy, a, b, c, zeroAttr, zeroAttr, zeroAttr) \ + .getResult(); \ + } + + DISPATCH_MFMA_SSA(32, 1, elemTyA.isF32(), mfma_f32_32x32x1f32) + DISPATCH_MFMA_SSA(16, 1, elemTyA.isF32(), mfma_f32_16x16x1f32) + DISPATCH_MFMA_SSA(4, 1, elemTyA.isF32(), mfma_f32_4x4x1f32) + DISPATCH_MFMA_SSA(32, 2, elemTyA.isF32(), mfma_f32_32x32x2f32) + DISPATCH_MFMA_SSA(16, 4, elemTyA.isF32(), mfma_f32_16x16x4f32) + + DISPATCH_MFMA_SSA(32, 4, elemTyA.isF16(), mfma_f32_32x32x4f16) + DISPATCH_MFMA_SSA(16, 4, elemTyA.isF16(), mfma_f32_16x16x4f16) + DISPATCH_MFMA_SSA(4, 4, elemTyA.isF16(), mfma_f32_4x4x4f16) + DISPATCH_MFMA_SSA(32, 8, elemTyA.isF16(), mfma_f32_32x32x8f16) + DISPATCH_MFMA_SSA(16, 16, elemTyA.isF16(), mfma_f32_16x16x16f16) + + DISPATCH_MFMA_SSA(32, 2, elemTyA.isBF16(), mfma_f32_32x32x2bf16) + DISPATCH_MFMA_SSA(16, 2, elemTyA.isBF16(), mfma_f32_16x16x2bf16) + DISPATCH_MFMA_SSA(4, 2, elemTyA.isBF16(), mfma_f32_4x4x2bf16) + DISPATCH_MFMA_SSA(32, 4, elemTyA.isBF16(), mfma_f32_32x32x4bf16) + DISPATCH_MFMA_SSA(16, 8, elemTyA.isBF16(), mfma_f32_16x16x8bf16) + DISPATCH_MFMA_SSA(16, 16, elemTyA.isBF16(), mfma_f32_16x16x16bf16_1k) + + DISPATCH_MFMA_SSA(16, 32, isFP8(elemTyA) && isFP8(elemTyB), mfma_f32_16x16x32_fp8_fp8) + DISPATCH_MFMA_SSA(16, 32, isFP8(elemTyA) && isBF8(elemTyB), mfma_f32_16x16x32_fp8_bf8) + DISPATCH_MFMA_SSA(16, 32, isBF8(elemTyA) && isFP8(elemTyB), mfma_f32_16x16x32_bf8_fp8) + DISPATCH_MFMA_SSA(16, 32, isBF8(elemTyA) && isBF8(elemTyB), mfma_f32_16x16x32_bf8_bf8) + DISPATCH_MFMA_SSA(32, 16, isFP8(elemTyA) && isFP8(elemTyB), mfma_f32_32x32x16_fp8_fp8) + DISPATCH_MFMA_SSA(32, 16, isFP8(elemTyA) && isBF8(elemTyB), mfma_f32_32x32x16_fp8_bf8) + DISPATCH_MFMA_SSA(32, 16, isBF8(elemTyA) && isFP8(elemTyB), mfma_f32_32x32x16_bf8_fp8) + DISPATCH_MFMA_SSA(32, 16, isBF8(elemTyA) && 
isBF8(elemTyB), mfma_f32_32x32x16_bf8_bf8) + +#undef DISPATCH_MFMA_SSA + + return failure(); } LogicalResult MmaOpCDNA3_MFMAType::emitAtomCall(OpBuilder &builder, Location loc, Type mmaAtomTy, @@ -178,7 +232,6 @@ LogicalResult MmaOpCDNA3_MFMAType::emitAtomCall(OpBuilder &builder, Location loc Value atomVal, Value dPtr, Value aPtr, Value bPtr, Value cPtr) const { int32_t m = getM(); - int32_t n = getN(); int32_t k = getK(); Type elemTyA = getElemTyA(); Type elemTyB = getElemTyB(); @@ -196,41 +249,15 @@ LogicalResult MmaOpCDNA3_MFMAType::emitAtomCall(OpBuilder &builder, Location loc Type accElemTy = getElemTyAcc(); VectorType accTy = VectorType::get({accVecSize}, accElemTy); -#define DISPATCH_MFMA(M_, K_, PRED, OP) \ - if (m == M_ && n == M_ && k == K_ && (PRED)) \ - return emitMfma(builder, loc, abTyA, abTyB, accTy, aPtr, bPtr, cPtr, dPtr); - - DISPATCH_MFMA(32, 1, elemTyA.isF32(), mfma_f32_32x32x1f32) - DISPATCH_MFMA(16, 1, elemTyA.isF32(), mfma_f32_16x16x1f32) - DISPATCH_MFMA(4, 1, elemTyA.isF32(), mfma_f32_4x4x1f32) - DISPATCH_MFMA(32, 2, elemTyA.isF32(), mfma_f32_32x32x2f32) - DISPATCH_MFMA(16, 4, elemTyA.isF32(), mfma_f32_16x16x4f32) - - DISPATCH_MFMA(32, 4, elemTyA.isF16(), mfma_f32_32x32x4f16) - DISPATCH_MFMA(16, 4, elemTyA.isF16(), mfma_f32_16x16x4f16) - DISPATCH_MFMA(4, 4, elemTyA.isF16(), mfma_f32_4x4x4f16) - DISPATCH_MFMA(32, 8, elemTyA.isF16(), mfma_f32_32x32x8f16) - DISPATCH_MFMA(16, 16, elemTyA.isF16(), mfma_f32_16x16x16f16) - - DISPATCH_MFMA(32, 2, elemTyA.isBF16(), mfma_f32_32x32x2bf16) - DISPATCH_MFMA(16, 2, elemTyA.isBF16(), mfma_f32_16x16x2bf16) - DISPATCH_MFMA(4, 2, elemTyA.isBF16(), mfma_f32_4x4x2bf16) - DISPATCH_MFMA(32, 4, elemTyA.isBF16(), mfma_f32_32x32x4bf16) - DISPATCH_MFMA(16, 8, elemTyA.isBF16(), mfma_f32_16x16x8bf16) - DISPATCH_MFMA(16, 16, elemTyA.isBF16(), mfma_f32_16x16x16bf16_1k) - - DISPATCH_MFMA(16, 32, isFP8(elemTyA) && isFP8(elemTyB), mfma_f32_16x16x32_fp8_fp8) - DISPATCH_MFMA(16, 32, isFP8(elemTyA) && isBF8(elemTyB), mfma_f32_16x16x32_fp8_bf8) - DISPATCH_MFMA(16, 32, isBF8(elemTyA) && isFP8(elemTyB), mfma_f32_16x16x32_bf8_fp8) - DISPATCH_MFMA(16, 32, isBF8(elemTyA) && isBF8(elemTyB), mfma_f32_16x16x32_bf8_bf8) - DISPATCH_MFMA(32, 16, isFP8(elemTyA) && isFP8(elemTyB), mfma_f32_32x32x16_fp8_fp8) - DISPATCH_MFMA(32, 16, isFP8(elemTyA) && isBF8(elemTyB), mfma_f32_32x32x16_fp8_bf8) - DISPATCH_MFMA(32, 16, isBF8(elemTyA) && isFP8(elemTyB), mfma_f32_32x32x16_bf8_fp8) - DISPATCH_MFMA(32, 16, isBF8(elemTyA) && isBF8(elemTyB), mfma_f32_32x32x16_bf8_bf8) - -#undef DISPATCH_MFMA - - return failure(); + Value a = LLVM::LoadOp::create(builder, loc, abTyA, aPtr); + Value b = LLVM::LoadOp::create(builder, loc, abTyB, bPtr); + Value c = LLVM::LoadOp::create(builder, loc, accTy, cPtr); + auto res = emitAtomCallSSA(builder, loc, accTy, mmaAtomTy, Type{}, abTyA, abTyB, accTy, atomVal, + Value{}, a, b, c); + if (failed(res)) + return failure(); + LLVM::StoreOp::create(builder, loc, *res, dPtr); + return success(); } } // namespace mlir::fly_rocdl diff --git a/lib/Dialect/FlyROCDL/CDNA4/CopyAtom.cpp b/lib/Dialect/FlyROCDL/CDNA4/CopyAtom.cpp index ebca6e49..808b6aaf 100644 --- a/lib/Dialect/FlyROCDL/CDNA4/CopyAtom.cpp +++ b/lib/Dialect/FlyROCDL/CDNA4/CopyAtom.cpp @@ -8,6 +8,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/Fly/Utils/PointerUtils.h" #include "flydsl/Dialect/Fly/Utils/ThrValLayoutMacro.h.inc" #include "flydsl/Dialect/FlyROCDL/IR/Dialect.h" @@ -65,37 +66,65 @@ Attribute 
CopyOpCDNA4LdsReadTransposeType::getThrBitLayoutRef() const { return getThrBitLayoutDst(); } -LogicalResult CopyOpCDNA4LdsReadTransposeType::emitAtomCall(OpBuilder &builder, Location loc, - Type copyAtomTyArg, Type srcMemTyArg, - Type dstMemTyArg, Value atomVal, - Value src, Value dst) const { - auto srcMemTy = cast(srcMemTyArg); - auto dstMemTy = cast(dstMemTyArg); - - AddressSpace srcAS = srcMemTy.getAddressSpace().getValue(); - AddressSpace dstAS = dstMemTy.getAddressSpace().getValue(); - - if (srcAS != AddressSpace::Shared || dstAS != AddressSpace::Register) - return failure(); - +FailureOr CopyOpCDNA4LdsReadTransposeType::emitAtomCallSSA(OpBuilder &builder, Location loc, + Type resultTy, Type copyAtomTyArg, + Type srcTyArg, Type dstTyArg, + Value atomVal, Value src, + Value dst) const { int32_t bitSize = getBitSize(); int32_t transGranularity = getTransGranularity(); - IntegerType copyTy = builder.getIntegerType(getBitSize()); Value loaded; if (bitSize == 64 && transGranularity == 4) { - loaded = ROCDL::ds_read_tr4_b64::create(builder, loc, copyTy, src).getResult(); + auto intrTy = VectorType::get({2}, builder.getI32Type()); + loaded = ROCDL::ds_read_tr4_b64::create(builder, loc, intrTy, src); } else if (bitSize == 64 && transGranularity == 8) { - loaded = ROCDL::ds_read_tr8_b64::create(builder, loc, copyTy, src).getResult(); + auto intrTy = VectorType::get({2}, builder.getI32Type()); + loaded = ROCDL::ds_read_tr8_b64::create(builder, loc, intrTy, src); } else if (bitSize == 96 && transGranularity == 6) { - loaded = ROCDL::ds_read_tr6_b96::create(builder, loc, copyTy, src).getResult(); + auto intrTy = VectorType::get({3}, builder.getI32Type()); + loaded = ROCDL::ds_read_tr6_b96::create(builder, loc, intrTy, src); } else if (bitSize == 64 && transGranularity == 16) { - loaded = ROCDL::ds_read_tr16_b64::create(builder, loc, copyTy, src).getResult(); + auto intrTy = VectorType::get({4}, builder.getI16Type()); + loaded = ROCDL::ds_read_tr16_b64::create(builder, loc, intrTy, src); } else { return failure(); } - LLVM::StoreOp::create(builder, loc, loaded, dst); + if (resultTy && loaded.getType() != resultTy) + loaded = LLVM::BitcastOp::create(builder, loc, resultTy, loaded); + + return loaded; +} + +FailureOr CopyOpCDNA4LdsReadTransposeType::emitAtomCallSSA( + OpBuilder &builder, Location loc, Type resultTy, Type copyAtomTyArg, Type srcTyArg, + Type dstTyArg, Type predTyArg, Value atomVal, Value src, Value dst, Value pred) const { + assert(resultTy && "resultTy must be SSA Type"); + OpBuilder::InsertionGuard guard(builder); + auto ifOp = scf::IfOp::create(builder, loc, resultTy, pred, /*withElseRegion=*/true); + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + auto result = + emitAtomCallSSA(builder, loc, resultTy, copyAtomTyArg, srcTyArg, dstTyArg, atomVal, src, dst); + if (failed(result)) + return failure(); + scf::YieldOp::create(builder, loc, *result); + + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + scf::YieldOp::create(builder, loc, dst); + return ifOp.getResult(0); +} + +LogicalResult CopyOpCDNA4LdsReadTransposeType::emitAtomCall(OpBuilder &builder, Location loc, + Type copyAtomTyArg, Type srcMemTyArg, + Type dstMemTyArg, Value atomVal, + Value src, Value dst) const { + auto dstSSATy = fly::RegMem2SSAType(cast(dstMemTyArg)); + auto res = emitAtomCallSSA(builder, loc, dstSSATy, copyAtomTyArg, srcMemTyArg, Type{}, atomVal, + src, Value{}); + if (failed(res)) + return failure(); + LLVM::StoreOp::create(builder, loc, *res, dst); return success(); } 
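// A minimal illustrative sketch (the vector<2xi32> payload type, the i1 %pred, and
// the value names are assumed for illustration, not taken from this change): when a
// predicate is supplied, the SSA overload above wraps the transpose load in an
// scf.if whose else branch simply forwards the incoming destination value, e.g.
//
//   %r = scf.if %pred -> (vector<2xi32>) {
//     %v = <ds_read_tr8_b64 of %src>      // transpose load from LDS
//     scf.yield %v : vector<2xi32>
//   } else {
//     scf.yield %dst : vector<2xi32>      // predicate false: keep prior contents
//   }
//
// The memref-form emitAtomCall above then stores the yielded vector back through the
// destination pointer, so existing callers keep their load/store contract.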
@@ -104,6 +133,7 @@ LogicalResult CopyOpCDNA4LdsReadTransposeType::emitAtomCall(OpBuilder &builder, Type dstMemTyArg, Type predMemTyArg, Value atomVal, Value src, Value dst, Value pred) const { + OpBuilder::InsertionGuard guard(builder); auto predMemTy = cast(predMemTyArg); Value predVal = LLVM::LoadOp::create(builder, loc, predMemTy.getElemTy(), pred); auto ifOp = scf::IfOp::create(builder, loc, TypeRange{}, predVal, /*withElse=*/false); diff --git a/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp b/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp index 7a372006..0e9dbb41 100644 --- a/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp +++ b/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp @@ -212,11 +212,8 @@ static int64_t getWmmaAccVecSize(int32_t m, int32_t k, Type elemTyA, Type elemTy enum class WmmaVariant { ModsAllReuse, ModsC, ModsABClamp }; template -static LogicalResult emitWmma(OpBuilder &builder, Location loc, Type abTyA, Type abTyB, - VectorType accTy, Value aPtr, Value bPtr, Value cPtr, Value dPtr) { - Value a = LLVM::LoadOp::create(builder, loc, abTyA, aPtr); - Value b = LLVM::LoadOp::create(builder, loc, abTyB, bPtr); - Value c = LLVM::LoadOp::create(builder, loc, accTy, cPtr); +static FailureOr emitWmmaSSA(OpBuilder &builder, Location loc, VectorType accTy, Value a, + Value b, Value c) { Value res; if constexpr (Variant == WmmaVariant::ModsAllReuse) { res = WmmaOp::create(builder, loc, accTy, @@ -235,8 +232,72 @@ static LogicalResult emitWmma(OpBuilder &builder, Location loc, Type abTyA, Type /*reuseA=*/false, /*reuseB=*/false, /*clamp=*/false) .getResult(); } - LLVM::StoreOp::create(builder, loc, res, dPtr); - return success(); + return res; +} + +FailureOr MmaOpGFX1250_WMMAType::emitAtomCallSSA(OpBuilder &builder, Location loc, + Type resultTy, Type mmaAtomTyArg, + Type dTyArg, Type aTyArg, Type bTyArg, + Type cTyArg, Value atomVal, Value d, + Value a, Value b, Value c) const { + int32_t m = getM(); + int32_t n = getN(); + int32_t k = getK(); + Type elemTyA = getElemTyA(); + Type elemTyB = getElemTyB(); + Type elemTyAcc = getElemTyAcc(); + MLIRContext *ctx = builder.getContext(); + + Type abTyA = getWmmaABType(ctx, m, k, elemTyA); + Type abTyB = getWmmaABType(ctx, m, k, elemTyB); + if (!abTyA || !abTyB) + return failure(); + + int64_t accVecSize = getWmmaAccVecSize(m, k, elemTyA, elemTyB, elemTyAcc); + if (accVecSize == 0) + return failure(); + + VectorType accTy = VectorType::get({accVecSize}, elemTyAcc); + +#define DISPATCH_WMMA_SSA(M_, K_, PRED, OP, VARIANT) \ + if (m == M_ && n == M_ && k == K_ && (PRED)) { \ + return emitWmmaSSA(builder, loc, accTy, a, b, c); \ + } + +#define DISPATCH_WMMA_SSA_FP8(K_, ACC_PRED, ACC_PREFIX) \ + DISPATCH_WMMA_SSA(16, K_, isFP8(elemTyA) && isFP8(elemTyB) && ACC_PRED, \ + wmma_##ACC_PREFIX##_16x16x##K_##_fp8_fp8, ModsC) \ + DISPATCH_WMMA_SSA(16, K_, isFP8(elemTyA) && isBF8(elemTyB) && ACC_PRED, \ + wmma_##ACC_PREFIX##_16x16x##K_##_fp8_bf8, ModsC) \ + DISPATCH_WMMA_SSA(16, K_, isBF8(elemTyA) && isFP8(elemTyB) && ACC_PRED, \ + wmma_##ACC_PREFIX##_16x16x##K_##_bf8_fp8, ModsC) \ + DISPATCH_WMMA_SSA(16, K_, isBF8(elemTyA) && isBF8(elemTyB) && ACC_PRED, \ + wmma_##ACC_PREFIX##_16x16x##K_##_bf8_bf8, ModsC) + + DISPATCH_WMMA_SSA(16, 4, elemTyA.isF32() && elemTyB.isF32() && elemTyAcc.isF32(), + wmma_f32_16x16x4_f32, ModsAllReuse) + + DISPATCH_WMMA_SSA(16, 32, elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF32(), + wmma_f32_16x16x32_f16, ModsAllReuse) + DISPATCH_WMMA_SSA(16, 32, elemTyA.isBF16() && elemTyB.isBF16() && elemTyAcc.isF32(), + wmma_f32_16x16x32_bf16, ModsAllReuse) 
+ DISPATCH_WMMA_SSA(16, 32, elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF16(), + wmma_f16_16x16x32_f16, ModsAllReuse) + DISPATCH_WMMA_SSA(16, 32, elemTyA.isBF16() && elemTyB.isBF16() && elemTyAcc.isBF16(), + wmma_bf16_16x16x32_bf16, ModsAllReuse) + + DISPATCH_WMMA_SSA_FP8(64, elemTyAcc.isF32(), f32) + DISPATCH_WMMA_SSA_FP8(64, elemTyAcc.isF16(), f16) + DISPATCH_WMMA_SSA_FP8(128, elemTyAcc.isF32(), f32) + DISPATCH_WMMA_SSA_FP8(128, elemTyAcc.isF16(), f16) + + DISPATCH_WMMA_SSA(16, 64, elemTyA.isInteger(8) && elemTyB.isInteger(8) && elemTyAcc.isInteger(32), + wmma_i32_16x16x64_iu8, ModsABClamp) + +#undef DISPATCH_WMMA_SSA_FP8 +#undef DISPATCH_WMMA_SSA + + return failure(); } LogicalResult MmaOpGFX1250_WMMAType::emitAtomCall(OpBuilder &builder, Location loc, Type mmaAtomTy, @@ -252,7 +313,7 @@ LogicalResult MmaOpGFX1250_WMMAType::emitAtomCall(OpBuilder &builder, Location l MLIRContext *ctx = builder.getContext(); Type abTyA = getWmmaABType(ctx, m, k, elemTyA); - Type abTyB = getWmmaABType(ctx, m, k, elemTyB); + Type abTyB = getWmmaABType(ctx, n, k, elemTyB); if (!abTyA || !abTyB) return failure(); @@ -262,45 +323,15 @@ LogicalResult MmaOpGFX1250_WMMAType::emitAtomCall(OpBuilder &builder, Location l VectorType accTy = VectorType::get({accVecSize}, elemTyAcc); -#define DISPATCH_WMMA(M_, K_, PRED, OP, VARIANT) \ - if (m == M_ && n == M_ && k == K_ && (PRED)) \ - return emitWmma(builder, loc, abTyA, abTyB, accTy, aPtr, \ - bPtr, cPtr, dPtr); - -#define DISPATCH_WMMA_FP8(K_, ACC_PRED, ACC_PREFIX) \ - DISPATCH_WMMA(16, K_, isFP8(elemTyA) && isFP8(elemTyB) && ACC_PRED, \ - wmma_##ACC_PREFIX##_16x16x##K_##_fp8_fp8, ModsC) \ - DISPATCH_WMMA(16, K_, isFP8(elemTyA) && isBF8(elemTyB) && ACC_PRED, \ - wmma_##ACC_PREFIX##_16x16x##K_##_fp8_bf8, ModsC) \ - DISPATCH_WMMA(16, K_, isBF8(elemTyA) && isFP8(elemTyB) && ACC_PRED, \ - wmma_##ACC_PREFIX##_16x16x##K_##_bf8_fp8, ModsC) \ - DISPATCH_WMMA(16, K_, isBF8(elemTyA) && isBF8(elemTyB) && ACC_PRED, \ - wmma_##ACC_PREFIX##_16x16x##K_##_bf8_bf8, ModsC) - - DISPATCH_WMMA(16, 4, elemTyA.isF32() && elemTyB.isF32() && elemTyAcc.isF32(), - wmma_f32_16x16x4_f32, ModsAllReuse) - - DISPATCH_WMMA(16, 32, elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF32(), - wmma_f32_16x16x32_f16, ModsAllReuse) - DISPATCH_WMMA(16, 32, elemTyA.isBF16() && elemTyB.isBF16() && elemTyAcc.isF32(), - wmma_f32_16x16x32_bf16, ModsAllReuse) - DISPATCH_WMMA(16, 32, elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF16(), - wmma_f16_16x16x32_f16, ModsAllReuse) - DISPATCH_WMMA(16, 32, elemTyA.isBF16() && elemTyB.isBF16() && elemTyAcc.isBF16(), - wmma_bf16_16x16x32_bf16, ModsAllReuse) - - DISPATCH_WMMA_FP8(64, elemTyAcc.isF32(), f32) - DISPATCH_WMMA_FP8(64, elemTyAcc.isF16(), f16) - DISPATCH_WMMA_FP8(128, elemTyAcc.isF32(), f32) - DISPATCH_WMMA_FP8(128, elemTyAcc.isF16(), f16) - - DISPATCH_WMMA(16, 64, elemTyA.isInteger(8) && elemTyB.isInteger(8) && elemTyAcc.isInteger(32), - wmma_i32_16x16x64_iu8, ModsABClamp) - -#undef DISPATCH_WMMA_FP8 -#undef DISPATCH_WMMA - - return failure(); + Value a = LLVM::LoadOp::create(builder, loc, abTyA, aPtr); + Value b = LLVM::LoadOp::create(builder, loc, abTyB, bPtr); + Value c = LLVM::LoadOp::create(builder, loc, accTy, cPtr); + auto res = emitAtomCallSSA(builder, loc, Type{}, mmaAtomTy, accTy, abTyA, abTyB, accTy, atomVal, + Value{}, a, b, c); + if (failed(res)) + return failure(); + LLVM::StoreOp::create(builder, loc, *res, dPtr); + return success(); } } // namespace mlir::fly_rocdl diff --git a/tests/mlir/Conversion/copy_atom_stateful.mlir 
b/tests/mlir/Conversion/copy_atom_stateful.mlir index a1d4ed3a..103e6848 100644 --- a/tests/mlir/Conversion/copy_atom_stateful.mlir +++ b/tests/mlir/Conversion/copy_atom_stateful.mlir @@ -74,10 +74,11 @@ func.func @test_copy_atom_call_store_with_soffset( %atom: !fly.copy_atom, 32>, %src: !fly.memref, %dst: !fly.memref) { - // CHECK: %[[SOFF_RAW:.*]] = llvm.extractvalue %[[ATOM]][0] - // CHECK: %[[SOFF:.*]] = arith.muli %[[SOFF_RAW]], %{{.*}} - // CHECK: llvm.load %[[SRC]] - // CHECK: rocdl.raw.ptr.buffer.store %{{.*}}, %{{.*}}, %{{.*}}, %[[SOFF]] + // CHECK-DAG: %[[SOFF_RAW:.*]] = llvm.extractvalue %[[ATOM]][0] + // CHECK-DAG: %[[SOFF:.*]] = arith.muli %[[SOFF_RAW]], %{{.*}} + // CHECK-DAG: %[[VAL:.*]] = llvm.load %[[SRC]] + // CHECK-DAG: %[[CAST:.*]] = llvm.bitcast %[[VAL]] + // CHECK: rocdl.raw.ptr.buffer.store %[[CAST]], %{{.*}}, %{{.*}}, %[[SOFF]] fly.copy_atom_call(%atom, %src, %dst) : (!fly.copy_atom, 32>, !fly.memref, !fly.memref) -> () return } diff --git a/tests/mlir/Conversion/pointer_ops.mlir b/tests/mlir/Conversion/pointer_ops.mlir index 57809570..c3827c3b 100644 --- a/tests/mlir/Conversion/pointer_ops.mlir +++ b/tests/mlir/Conversion/pointer_ops.mlir @@ -3,52 +3,11 @@ // RUN: %fly-opt %s --convert-fly-to-rocdl | FileCheck %s // Pointer operation lowering tests: -// - fly.get_iter -> identity (memref ptr passthrough) // - fly.add_offset -> llvm.getelementptr // - fly.make_view -> identity/bitcast // ----- -// === GetIter (identity) === - -// get_iter extracts a raw pointer from a memref, then add_offset advances it. -// After lowering, get_iter becomes identity and add_offset becomes GEP. - -// CHECK-LABEL: @test_get_iter_global -// CHECK-SAME: (%[[MEM:.*]]: !llvm.ptr<1>) -func.func @test_get_iter_global(%mem: !fly.memref) { - // CHECK-NOT: fly.get_iter - %iter = fly.get_iter(%mem) : (!fly.memref) -> !fly.ptr - %offset = fly.make_int_tuple() : () -> !fly.int_tuple<8> - // CHECK: llvm.getelementptr %[[MEM]][{{.*}}] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32 - %result = fly.add_offset(%iter, %offset) : (!fly.ptr, !fly.int_tuple<8>) -> !fly.ptr - return -} - -// CHECK-LABEL: @test_get_iter_shared -// CHECK-SAME: (%[[MEM:.*]]: !llvm.ptr<3>) -func.func @test_get_iter_shared(%mem: !fly.memref) { - // CHECK-NOT: fly.get_iter - %iter = fly.get_iter(%mem) : (!fly.memref) -> !fly.ptr - %offset = fly.make_int_tuple() : () -> !fly.int_tuple<4> - // CHECK: llvm.getelementptr %[[MEM]][{{.*}}] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32 - %result = fly.add_offset(%iter, %offset) : (!fly.ptr, !fly.int_tuple<4>) -> !fly.ptr - return -} - -// CHECK-LABEL: @test_get_iter_register -// CHECK-SAME: (%[[MEM:.*]]: !llvm.ptr<5>) -func.func @test_get_iter_register(%mem: !fly.memref) { - // CHECK-NOT: fly.get_iter - %iter = fly.get_iter(%mem) : (!fly.memref) -> !fly.ptr - %offset = fly.make_int_tuple() : () -> !fly.int_tuple<2> - // CHECK: llvm.getelementptr %[[MEM]][{{.*}}] : (!llvm.ptr<5>, i32) -> !llvm.ptr<5>, f32 - %result = fly.add_offset(%iter, %offset) : (!fly.ptr, !fly.int_tuple<2>) -> !fly.ptr - return -} - -// ----- - // === AddOffset === // CHECK-LABEL: @test_add_offset_static @@ -89,25 +48,3 @@ gpu.module @dyn_shared_module { gpu.return } } - -// ----- - -// === MakeView (identity when address spaces match) === - -// CHECK-LABEL: @test_make_view -// CHECK-SAME: (%[[PTR:.*]]: !llvm.ptr<1>) -func.func @test_make_view(%ptr: !fly.ptr) -> f32 { - %s = fly.make_int_tuple() : () -> !fly.int_tuple<(4, 8)> - %d = fly.make_int_tuple() : () -> !fly.int_tuple<(1, 4)> - %layout = fly.make_layout(%s, %d) : 
(!fly.int_tuple<(4, 8)>, !fly.int_tuple<(1, 4)>) -> !fly.layout<(4, 8) : (1, 4)> - // CHECK-NOT: fly.make_view - %view = fly.make_view(%ptr, %layout) : (!fly.ptr, !fly.layout<(4, 8) : (1, 4)>) -> !fly.memref - %iter = fly.get_iter(%view) : (!fly.memref) -> !fly.ptr - %offset = fly.make_int_tuple() : () -> !fly.int_tuple<7> - // CHECK: llvm.getelementptr %[[PTR]][{{.*}}] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32 - %gep = fly.add_offset(%iter, %offset) : (!fly.ptr, !fly.int_tuple<7>) -> !fly.ptr - // CHECK: %[[VAL:.*]] = llvm.load - %val = fly.ptr.load(%gep) : (!fly.ptr) -> f32 - // CHECK: return %[[VAL]] - return %val : f32 -} diff --git a/tests/mlir/Transforms/convert-atom-call-to-ssa-form.mlir b/tests/mlir/Transforms/convert-atom-call-to-ssa-form.mlir new file mode 100644 index 00000000..7db68b08 --- /dev/null +++ b/tests/mlir/Transforms/convert-atom-call-to-ssa-form.mlir @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2025 FlyDSL Project Contributors +// RUN: %fly-opt %s --fly-convert-atom-call-to-ssa-form | FileCheck %s + +gpu.module @convert_atom_call_to_ssa_form { + + // Test 1: copy_atom_call with register dst (rank=1, stride=1) should be promoted + // CHECK-LABEL: gpu.func @copy_dst_register + // CHECK-NOT: fly.copy_atom_call( + // CHECK: %[[REG_PTR:.*]] = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + // CHECK: %[[SSA:.*]] = fly.copy_atom_call_ssa(%{{.*}}, %{{.*}}) {operandSegmentSizes = array} + // CHECK-SAME: : (!fly.copy_atom, 16>, !fly.memref) -> vector<4xf16> + // CHECK: fly.ptr.store(%[[SSA]], %[[REG_PTR]]) : (vector<4xf16>, !fly.ptr) -> () + gpu.func @copy_dst_register(%src: !fly.ptr) kernel { + %shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4> + %stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %vec4 = fly.make_layout(%shape4, %stride1) : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1> + + %src_view = fly.make_view(%src, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + %copy = fly.make_copy_atom {valBits = 16 : i32} : !fly.copy_atom, 16> + + %reg_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + %reg_view = fly.make_view(%reg_ptr, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + + fly.copy_atom_call(%copy, %src_view, %reg_view) : (!fly.copy_atom, 16>, !fly.memref, !fly.memref) -> () + gpu.return + } + + // Test 1b: copy_atom_call with register src (rank=1, stride=1) should be promoted + // src is pre-loaded via ptr.load, then passed to copy_atom_call_ssa as vector operand + // CHECK-LABEL: gpu.func @copy_src_register + // CHECK-NOT: fly.copy_atom_call( + // CHECK: %[[SRC_PTR:.*]] = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + // CHECK: %[[VEC:.*]] = fly.ptr.load(%[[SRC_PTR]]) : (!fly.ptr) -> vector<4xf16> + // CHECK: fly.copy_atom_call_ssa(%{{.*}}, %[[VEC]], %{{.*}}) {operandSegmentSizes = array} + // CHECK-SAME: : (!fly.copy_atom, 16>, vector<4xf16>, !fly.memref) -> () + gpu.func @copy_src_register(%dst: !fly.ptr) kernel { + %shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4> + %stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %vec4 = fly.make_layout(%shape4, %stride1) : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1> + + %dst_view = fly.make_view(%dst, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + %copy = fly.make_copy_atom {valBits = 16 : i32} : !fly.copy_atom, 16> + + %reg_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + %reg_view = fly.make_view(%reg_ptr, %vec4) : 
(!fly.ptr, !fly.layout<4:1>) -> !fly.memref + + fly.copy_atom_call(%copy, %reg_view, %dst_view) : (!fly.copy_atom, 16>, !fly.memref, !fly.memref) -> () + gpu.return + } + + // Test 1c: copy_atom_call with both src and dst register should be promoted + // src is pre-loaded via ptr.load, result stored back to dst register + // CHECK-LABEL: gpu.func @copy_both_register + // CHECK-NOT: fly.copy_atom_call( + // CHECK: %[[SRC_PTR:.*]] = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + // CHECK: %[[DST_PTR:.*]] = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + // CHECK: %[[VEC:.*]] = fly.ptr.load(%[[SRC_PTR]]) : (!fly.ptr) -> vector<4xf16> + // CHECK: %[[SSA:.*]] = fly.copy_atom_call_ssa(%{{.*}}, %[[VEC]]) {operandSegmentSizes = array} + // CHECK-SAME: : (!fly.copy_atom, 16>, vector<4xf16>) -> vector<4xf16> + // CHECK: fly.ptr.store(%[[SSA]], %[[DST_PTR]]) : (vector<4xf16>, !fly.ptr) -> () + gpu.func @copy_both_register() kernel { + %shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4> + %stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %vec4 = fly.make_layout(%shape4, %stride1) : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1> + + %copy = fly.make_copy_atom {valBits = 16 : i32} : !fly.copy_atom, 16> + + %src_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + %src_view = fly.make_view(%src_ptr, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + %dst_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + %dst_view = fly.make_view(%dst_ptr, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + + fly.copy_atom_call(%copy, %src_view, %dst_view) : (!fly.copy_atom, 16>, !fly.memref, !fly.memref) -> () + gpu.return + } + + // Test 2: copy_atom_call with non-register src and dst should NOT be promoted + // CHECK-LABEL: gpu.func @copy_global_unchanged + // CHECK: fly.copy_atom_call( + // CHECK-NOT: fly.copy_atom_call_ssa + gpu.func @copy_global_unchanged(%src: !fly.ptr, %dst: !fly.ptr) kernel { + %shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4> + %stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %vec4 = fly.make_layout(%shape4, %stride1) : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1> + + %src_view = fly.make_view(%src, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + %dst_view = fly.make_view(%dst, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + %copy = fly.make_copy_atom {valBits = 16 : i32} : !fly.copy_atom, 16> + + fly.copy_atom_call(%copy, %src_view, %dst_view) : (!fly.copy_atom, 16>, !fly.memref, !fly.memref) -> () + gpu.return + } + + // Test 3: copy_atom_call with register dst, non-leaf layout that coalesces to rank=1 stride=1 + // (4,1):(1,0) coalesces to 4:1, so should be promoted + // CHECK-LABEL: gpu.func @copy_dst_register_coalescable + // CHECK-NOT: fly.copy_atom_call( + // CHECK: %[[REG_PTR:.*]] = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + // CHECK: %[[SSA:.*]] = fly.copy_atom_call_ssa(%{{.*}}, %{{.*}}) {operandSegmentSizes = array} + // CHECK-SAME: : (!fly.copy_atom, 16>, !fly.memref) -> vector<4xf16> + // CHECK: fly.ptr.store(%[[SSA]], %[[REG_PTR]]) : (vector<4xf16>, !fly.ptr) -> () + gpu.func @copy_dst_register_coalescable(%src: !fly.ptr) kernel { + %shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4> + %stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %vec4 = fly.make_layout(%shape4, %stride1) : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1> + + %src_view = 
fly.make_view(%src, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + + %acc_shape = fly.make_int_tuple() : () -> !fly.int_tuple<(4,1)> + %acc_stride = fly.make_int_tuple() : () -> !fly.int_tuple<(1,0)> + %acc_layout = fly.make_layout(%acc_shape, %acc_stride) : (!fly.int_tuple<(4,1)>, !fly.int_tuple<(1,0)>) -> !fly.layout<(4,1):(1,0)> + + %copy = fly.make_copy_atom {valBits = 16 : i32} : !fly.copy_atom, 16> + + %reg_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + %reg_view = fly.make_view(%reg_ptr, %acc_layout) : (!fly.ptr, !fly.layout<(4,1):(1,0)>) -> !fly.memref + + fly.copy_atom_call(%copy, %src_view, %reg_view) : (!fly.copy_atom, 16>, !fly.memref, !fly.memref) -> () + gpu.return + } + + // Test 3b: copy_atom_call with register dst, non-coalescable layout should NOT be promoted + // (4,2):(1,8) cannot coalesce to rank=1 stride=1 + // CHECK-LABEL: gpu.func @copy_dst_register_non_coalescable + // CHECK: fly.copy_atom_call( + // CHECK-NOT: fly.copy_atom_call_ssa + gpu.func @copy_dst_register_non_coalescable(%src: !fly.ptr) kernel { + %shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4> + %stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %vec4 = fly.make_layout(%shape4, %stride1) : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1> + + %src_view = fly.make_view(%src, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + + %nc_shape = fly.make_int_tuple() : () -> !fly.int_tuple<(4,2)> + %nc_stride = fly.make_int_tuple() : () -> !fly.int_tuple<(1,8)> + %nc_layout = fly.make_layout(%nc_shape, %nc_stride) : (!fly.int_tuple<(4,2)>, !fly.int_tuple<(1,8)>) -> !fly.layout<(4,2):(1,8)> + + %copy = fly.make_copy_atom {valBits = 16 : i32} : !fly.copy_atom, 16> + + %reg_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 8 : i64}} : () -> !fly.ptr + %reg_view = fly.make_view(%reg_ptr, %nc_layout) : (!fly.ptr, !fly.layout<(4,2):(1,8)>) -> !fly.memref + + fly.copy_atom_call(%copy, %src_view, %reg_view) : (!fly.copy_atom, 16>, !fly.memref, !fly.memref) -> () + gpu.return + } + + // Test 4: mma_atom_call with register d (rank=1, stride=1) should be promoted + // a, b, c are also register eligible, so they get pre-loaded as vectors + // CHECK-LABEL: gpu.func @mma_d_register + // CHECK-NOT: fly.mma_atom_call( + // CHECK-DAG: %[[A_PTR:.*]] = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + // CHECK-DAG: %[[B_PTR:.*]] = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + // CHECK-DAG: %[[D_PTR:.*]] = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + // CHECK-DAG: %[[C_PTR:.*]] = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + // CHECK: %[[A:.*]] = fly.ptr.load(%[[A_PTR]]) : (!fly.ptr) -> vector<4xf16> + // CHECK: %[[B:.*]] = fly.ptr.load(%[[B_PTR]]) : (!fly.ptr) -> vector<4xf16> + // CHECK: %[[C:.*]] = fly.ptr.load(%[[C_PTR]]) : (!fly.ptr) -> vector<4xf32> + // CHECK: %[[SSA:.*]] = fly.mma_atom_call_ssa(%{{.*}}, %[[A]], %[[B]], %[[C]]) + // CHECK-SAME: -> vector<4xf32> + // CHECK: fly.ptr.store(%[[SSA]], %[[D_PTR]]) : (vector<4xf32>, !fly.ptr) -> () + gpu.func @mma_d_register(%out: !fly.ptr) kernel { + %shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4> + %stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %vec4_f16 = fly.make_layout(%shape4, %stride1) : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1> + %vec4_f32 = fly.make_layout(%shape4, %stride1) : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1> + + %a_ptr = fly.make_ptr() {dictAttrs = 
{allocaSize = 4 : i64}} : () -> !fly.ptr + %b_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + %d_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + %c_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + + %a_view = fly.make_view(%a_ptr, %vec4_f16) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + %b_view = fly.make_view(%b_ptr, %vec4_f16) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + %d_view = fly.make_view(%d_ptr, %vec4_f32) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + %c_view = fly.make_view(%c_ptr, %vec4_f32) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + + %atom = fly.make_mma_atom : !fly.mma_atom f32>> + + fly.mma_atom_call(%atom, %d_view, %a_view, %b_view, %c_view) : (!fly.mma_atom f32>>, !fly.memref, !fly.memref, !fly.memref, !fly.memref) -> () + gpu.return + } + + // Test 5: mma_atom_call with non-leaf register d layout that coalesces to rank=1 stride=1 + // (4,1):(1,0) coalesces to 4:1, so should be promoted. a, b, c also pre-loaded + // CHECK-LABEL: gpu.func @mma_d_register_coalescable + // CHECK-NOT: fly.mma_atom_call( + // CHECK-DAG: %[[A_PTR:.*]] = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + // CHECK-DAG: %[[B_PTR:.*]] = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + // CHECK-DAG: %[[D_PTR:.*]] = fly.make_ptr() {dictAttrs = {allocaSize = 8 : i64}} : () -> !fly.ptr + // CHECK-DAG: %[[C_PTR:.*]] = fly.make_ptr() {dictAttrs = {allocaSize = 8 : i64}} : () -> !fly.ptr + // CHECK: %[[A:.*]] = fly.ptr.load(%[[A_PTR]]) : (!fly.ptr) -> vector<4xf16> + // CHECK: %[[B:.*]] = fly.ptr.load(%[[B_PTR]]) : (!fly.ptr) -> vector<4xf16> + // CHECK: %[[C:.*]] = fly.ptr.load(%[[C_PTR]]) : (!fly.ptr) -> vector<4xf32> + // CHECK: %[[SSA:.*]] = fly.mma_atom_call_ssa(%{{.*}}, %[[A]], %[[B]], %[[C]]) + // CHECK-SAME: -> vector<4xf32> + // CHECK: fly.ptr.store(%[[SSA]], %[[D_PTR]]) : (vector<4xf32>, !fly.ptr) -> () + gpu.func @mma_d_register_coalescable(%out: !fly.ptr) kernel { + %shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4> + %stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %vec4_f16 = fly.make_layout(%shape4, %stride1) : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1> + + %acc_shape = fly.make_int_tuple() : () -> !fly.int_tuple<(4,1)> + %acc_stride = fly.make_int_tuple() : () -> !fly.int_tuple<(1,0)> + %acc_layout = fly.make_layout(%acc_shape, %acc_stride) : (!fly.int_tuple<(4,1)>, !fly.int_tuple<(1,0)>) -> !fly.layout<(4,1):(1,0)> + + %a_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + %b_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + %d_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 8 : i64}} : () -> !fly.ptr + %c_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 8 : i64}} : () -> !fly.ptr + + %a_view = fly.make_view(%a_ptr, %vec4_f16) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + %b_view = fly.make_view(%b_ptr, %vec4_f16) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + %d_view = fly.make_view(%d_ptr, %acc_layout) : (!fly.ptr, !fly.layout<(4,1):(1,0)>) -> !fly.memref + %c_view = fly.make_view(%c_ptr, %acc_layout) : (!fly.ptr, !fly.layout<(4,1):(1,0)>) -> !fly.memref + + %atom = fly.make_mma_atom : !fly.mma_atom f32>> + + fly.mma_atom_call(%atom, %d_view, %a_view, %b_view, %c_view) : (!fly.mma_atom f32>>, !fly.memref, !fly.memref, !fly.memref, !fly.memref) -> () + gpu.return + } + + // Test 5b: mma_atom_call with register d non-coalescable, but a/b are register eligible + // d 
and c have (4,2):(1,8) which cannot coalesce, but a/b have 4:1 which is eligible + // a and b should be pre-loaded as vectors, mma_atom_call_ssa is used + // d and c remain as memref (not promoted to SSA) + // CHECK-LABEL: gpu.func @mma_d_register_non_coalescable + // CHECK-DAG: %[[A_PTR:.*]] = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + // CHECK-DAG: %[[B_PTR:.*]] = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + // CHECK: %[[A:.*]] = fly.ptr.load(%[[A_PTR]]) : (!fly.ptr) -> vector<4xf16> + // CHECK: %[[B:.*]] = fly.ptr.load(%[[B_PTR]]) : (!fly.ptr) -> vector<4xf16> + // CHECK: fly.mma_atom_call_ssa(%{{.*}}, %{{.*}}, %[[A]], %[[B]], %{{.*}}) : + // CHECK-SAME: (!fly.mma_atom f32>>, + // CHECK-SAME: !fly.memref, vector<4xf16>, vector<4xf16>, + // CHECK-SAME: !fly.memref) -> () + // CHECK-NOT: fly.ptr.store + gpu.func @mma_d_register_non_coalescable(%out: !fly.ptr) kernel { + %shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4> + %stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %vec4_f16 = fly.make_layout(%shape4, %stride1) : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1> + + %nc_shape = fly.make_int_tuple() : () -> !fly.int_tuple<(4,2)> + %nc_stride = fly.make_int_tuple() : () -> !fly.int_tuple<(1,8)> + %nc_layout = fly.make_layout(%nc_shape, %nc_stride) : (!fly.int_tuple<(4,2)>, !fly.int_tuple<(1,8)>) -> !fly.layout<(4,2):(1,8)> + + %a_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + %b_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + %d_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 8 : i64}} : () -> !fly.ptr + %c_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 8 : i64}} : () -> !fly.ptr + + %a_view = fly.make_view(%a_ptr, %vec4_f16) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + %b_view = fly.make_view(%b_ptr, %vec4_f16) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + %d_view = fly.make_view(%d_ptr, %nc_layout) : (!fly.ptr, !fly.layout<(4,2):(1,8)>) -> !fly.memref + %c_view = fly.make_view(%c_ptr, %nc_layout) : (!fly.ptr, !fly.layout<(4,2):(1,8)>) -> !fly.memref + + %atom = fly.make_mma_atom : !fly.mma_atom f32>> + + fly.mma_atom_call(%atom, %d_view, %a_view, %b_view, %c_view) : (!fly.mma_atom f32>>, !fly.memref, !fly.memref, !fly.memref, !fly.memref) -> () + gpu.return + } +} From 2e71f65b494ca5b180b5d7c0424331fdefe9b76e Mon Sep 17 00:00:00 2001 From: Elton Date: Thu, 16 Apr 2026 21:57:36 +0800 Subject: [PATCH 16/29] [docs] add frontend semantic restrictions for MLIR kernel authoring (#406) --- .../skills/flydsl-kernel-authoring/SKILL.md | 33 +++++++++++++++++++ CLAUDE.md | 3 ++ 2 files changed, 36 insertions(+) diff --git a/.claude/skills/flydsl-kernel-authoring/SKILL.md b/.claude/skills/flydsl-kernel-authoring/SKILL.md index 03af59f2..1cbb2498 100644 --- a/.claude/skills/flydsl-kernel-authoring/SKILL.md +++ b/.claude/skills/flydsl-kernel-authoring/SKILL.md @@ -224,6 +224,39 @@ for i in range(runtime_value): ... ``` +### Frontend Semantic Restrictions +When writing or reviewing `@flyc.kernel` / `@flyc.jit` code, proactively avoid these patterns because they can conflict with MLIR construction even if they look valid in plain Python. + +1. **Do not define values inside `if/else` and use them later outside the branch.** Keep a single explicit definition path. + ```python + if cond: + dst = a + else: + dst = b + use(dst) # avoid this pattern + ``` + +2. 
**Do not mutate captured outer variables inside nested helper functions.** Read-only closure capture is acceptable, but writes should go through explicit parameters and return values. + ```python + def kernel(): + acc = fx.Float32(0.0) + + def helper(acc): + acc = acc + fx.Float32(1.0) + return acc + + acc = helper(acc) + ``` + +3. **Avoid early `return`, and do not place `return` / `yield` inside `if/else` branches.** Prefer a single explicit exit so the frontend can determine result types. + ```python + if cond: + out = v0 + else: + out = v1 + return out + ``` + ### scf.for with Loop-Carried Values (Software Pipelining) Use `init=` on `range()` to create an `scf.for` with explicit SSA phi nodes for loop-carried state. This is required for software pipelining (prefetch patterns) where data must flow across iterations. diff --git a/CLAUDE.md b/CLAUDE.md index 15c787e5..9a7f077a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -74,3 +74,6 @@ FLYDSL_DUMP_IR=1 PYTHONPATH=./ python tests/kernels/test_pa.py # Dump MLIR IR at - **Layout API vs buffer_ops**: New kernels should use `fx.rocdl.make_buffer_tensor()` + `copy_atom_call` (layout API). Raw `buffer_ops.create_buffer_resource()` is legacy - **Arch detection**: Use `from flydsl.runtime.device import get_rocm_arch` - **`range` vs `range_constexpr`**: Use `range_constexpr` for compile-time unrolled loops; `range(start, stop, step, init=[...])` for `scf.for` with loop-carried values +- **Branch-local defs**: Do not define a value inside `if/else` and then use it after the branch. Hoist the variable or rewrite the logic so later uses see a single explicit definition path. +- **Nested helper captures**: Inside `@flyc.kernel` / `@flyc.jit`, nested helper functions must not mutate captured outer variables. Read-only capture is acceptable, but writes should go through explicit parameters / returns. +- **Single-exit control flow**: Avoid early `return`. Do not place `return` or `yield` inside `if/else` branches; keep a single explicit exit path so MLIR result types stay well-defined. From f9120a8f1045f792ba72806ba8fecdbe18c89fbf Mon Sep 17 00:00:00 2001 From: Elton Date: Thu, 16 Apr 2026 21:59:47 +0800 Subject: [PATCH 17/29] Detach requires_grad tensors before FlyDSL DLPack export (#409) --- python/flydsl/compiler/jit_argument.py | 11 ++++++----- tests/unit/test_jit_stream_param.py | 12 ++++++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/python/flydsl/compiler/jit_argument.py b/python/flydsl/compiler/jit_argument.py index 58a05b25..7abda3bb 100644 --- a/python/flydsl/compiler/jit_argument.py +++ b/python/flydsl/compiler/jit_argument.py @@ -147,11 +147,12 @@ def __init__( assumed_align: Optional[int] = None, use_32bit_stride: bool = False, ): - self._tensor_keepalive = tensor - dlpack_tensor = tensor - if _FLOAT8_DTYPES and tensor.dtype in _FLOAT8_DTYPES: - dlpack_tensor = tensor.view(torch.uint8) - self._tensor_keepalive = dlpack_tensor + # Forward-only interop: DLPack export from torch rejects tensors that + # still participate in autograd, so detach before crossing into FlyDSL. 
+ dlpack_tensor = tensor.detach() if tensor.requires_grad else tensor + if _FLOAT8_DTYPES and dlpack_tensor.dtype in _FLOAT8_DTYPES: + dlpack_tensor = dlpack_tensor.view(torch.uint8) + self._tensor_keepalive = dlpack_tensor self.tensor_adaptor = DLTensorAdaptor(dlpack_tensor.__dlpack__(stream=-1), assumed_align, use_32bit_stride) self.assumed_align = assumed_align diff --git a/tests/unit/test_jit_stream_param.py b/tests/unit/test_jit_stream_param.py index b29a3e6d..0d35f7c4 100644 --- a/tests/unit/test_jit_stream_param.py +++ b/tests/unit/test_jit_stream_param.py @@ -188,3 +188,15 @@ def test_multiple_parameters(self): _vecadd(a, b, c, SIZE, BLOCK_DIM, VEC_WIDTH) torch.cuda.synchronize() assert torch.allclose(c, a.data + b.data, atol=1e-5) + + def test_parameter_requires_grad_forward_only(self): + """requires_grad=True Parameter should be accepted for forward-only use.""" + w = torch.nn.Parameter( + torch.randn(SIZE, device="cuda", dtype=torch.float32), + requires_grad=True, + ) + b = torch.randn(SIZE, device="cuda", dtype=torch.float32) + c = torch.empty_like(b) + _vecadd(w, b, c, SIZE, BLOCK_DIM, VEC_WIDTH) + torch.cuda.synchronize() + assert torch.allclose(c, w.detach() + b, atol=1e-5) From 68f57256c78a6d2ffc54bf874ce326489360857c Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Fri, 17 Apr 2026 09:23:03 +0800 Subject: [PATCH 18/29] [OPT] Add pass promote-regmem-to-vectorssa (#410) * [OPT] Add pass promote-regmem-to-vectorssa * fix comments * fix empty yield case --- examples/04-preshuffle_gemm.py | 15 +- .../flydsl/Dialect/Fly/Transforms/Passes.td | 16 + lib/Dialect/Fly/CMakeLists.txt | 1 + .../Transforms/PromoteRegMemToVectorSSA.cpp | 664 ++++++++++++++++++ python/flydsl/compiler/backends/rocm.py | 4 +- python/flydsl/expr/primitive.py | 4 + .../Transforms/promote_regmem_to_ssa.mlir | 380 ++++++++++ .../promote_regmem_to_ssa_copy_pred.mlir | 141 ++++ .../promote_regmem_to_ssa_invalid.mlir | 43 ++ 9 files changed, 1261 insertions(+), 7 deletions(-) create mode 100644 lib/Dialect/Fly/Transforms/PromoteRegMemToVectorSSA.cpp create mode 100644 tests/mlir/Transforms/promote_regmem_to_ssa.mlir create mode 100644 tests/mlir/Transforms/promote_regmem_to_ssa_copy_pred.mlir create mode 100644 tests/mlir/Transforms/promote_regmem_to_ssa_invalid.mlir diff --git a/examples/04-preshuffle_gemm.py b/examples/04-preshuffle_gemm.py index d47b6956..55a34516 100644 --- a/examples/04-preshuffle_gemm.py +++ b/examples/04-preshuffle_gemm.py @@ -72,8 +72,9 @@ def gemm_kernel( mma_frag_A_retile = thr_copy_s2r_A.retile(mma_frag_A) mma_frag_B_retile = thr_copy_g2r_B.retile(mma_frag_B) - mma_frag_C_f16 = fx.make_fragment_like(mma_frag_C, fx.Float16.ir_type) - mma_frag_C_retile = thr_copy_r2g_C.retile(mma_frag_C_f16) + + gA_k_stride = fx.get_scalar(gA_k.stride[2]) + gB_k_stride = fx.get_scalar(gB_k.stride[2]) def run_pipeline_stage(read_stage, next_k, read_next=True): write_stage = read_stage ^ 1 @@ -81,13 +82,13 @@ def run_pipeline_stage(read_stage, next_k, read_next=True): if read_next: next_k = fx.Int32(next_k) fx.copy( - buffer_copy_128b.set_value("soffset", next_k * BLOCK_K), - thr_gA_k[None, None, None, 0], # global offset is added on the soffset of buffer_copy_atom + buffer_copy_128b.set_value("soffset", next_k * gA_k_stride), + thr_gA_k[None, None, None, 0], # global offset is added on the soffset of buffer_copy_atom copy_frag_A, ) fx.copy( - buffer_copy_128b, - thr_gB_k[None, None, None, next_k], + buffer_copy_128b.set_value("soffset", next_k * gB_k_stride), + thr_gB_k[None, None, None, 0], 
mma_frag_B_retile[None, None, None, write_stage], ) @@ -151,6 +152,8 @@ def sched_main_iter(with_vmem=False, with_dswr=False): run_pipeline_stage(read_stage=0, next_k=K // BLOCK_K - 1) run_pipeline_stage(read_stage=1, next_k=None, read_next=False) + mma_frag_C_f16 = fx.make_fragment_like(mma_frag_C, fx.Float16.ir_type) + mma_frag_C_retile = thr_copy_r2g_C.retile(mma_frag_C_f16) mma_frag_C_f16.store(fx.arith.trunc_f(fx.T.VectorType.get([64], fx.T.f16()), mma_frag_C.load())) fx.copy(buffer_copy_16b, mma_frag_C_retile, thr_gC) diff --git a/include/flydsl/Dialect/Fly/Transforms/Passes.td b/include/flydsl/Dialect/Fly/Transforms/Passes.td index 6277bd71..c3b60aca 100644 --- a/include/flydsl/Dialect/Fly/Transforms/Passes.td +++ b/include/flydsl/Dialect/Fly/Transforms/Passes.td @@ -62,4 +62,20 @@ def FlyConvertAtomCallToSSAFormPass : Pass<"fly-convert-atom-call-to-ssa-form"> ]; } +def FlyPromoteRegMemToVectorSSAPass : Pass<"fly-promote-regmem-to-vectorssa"> { + let summary = "Promote register memory to vector SSA values"; + let description = [{ + Promotes fly.make_ptr(register) memory semantics to vector SSA values. + Requires fly-convert-atom-call-to-ssa-form to have run first. + }]; + + let dependentDialects = [ + "arith::ArithDialect", + "gpu::GPUDialect", + "scf::SCFDialect", + "vector::VectorDialect", + "ub::UBDialect" + ]; +} + #endif // FLY_PASSES diff --git a/lib/Dialect/Fly/CMakeLists.txt b/lib/Dialect/Fly/CMakeLists.txt index 3e0b6f14..a1d2b89b 100644 --- a/lib/Dialect/Fly/CMakeLists.txt +++ b/lib/Dialect/Fly/CMakeLists.txt @@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRFlyDialect Transforms/Canonicalize.cpp Transforms/RewriteFuncSignature.cpp Transforms/ConvertAtomCallToSSAForm.cpp + Transforms/PromoteRegMemToVectorSSA.cpp DEPENDS MLIRFlyIncGen diff --git a/lib/Dialect/Fly/Transforms/PromoteRegMemToVectorSSA.cpp b/lib/Dialect/Fly/Transforms/PromoteRegMemToVectorSSA.cpp new file mode 100644 index 00000000..9b636fe9 --- /dev/null +++ b/lib/Dialect/Fly/Transforms/PromoteRegMemToVectorSSA.cpp @@ -0,0 +1,664 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2025 FlyDSL Project Contributors + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/CSE.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallVector.h" + +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/Fly/Transforms/Passes.h" +#include "flydsl/Dialect/Fly/Utils/LayoutUtils.h" + +#include +#include + +using namespace mlir; +using namespace mlir::fly; + +namespace llvm { + +template <> struct DenseMapInfo : DenseMapInfo { + using Base = DenseMapInfo; + + static mlir::fly::MakePtrOp getEmptyKey() { return mlir::fly::MakePtrOp(Base::getEmptyKey()); } + + static mlir::fly::MakePtrOp getTombstoneKey() { + return mlir::fly::MakePtrOp(Base::getTombstoneKey()); + } +}; + +} // namespace llvm + +namespace mlir { +namespace fly { +#define GEN_PASS_DEF_FLYPROMOTEREGMEMTOVECTORSSAPASS +#include "flydsl/Dialect/Fly/Transforms/Passes.h.inc" +} // namespace fly +} // namespace mlir + +namespace { + +using VectorValue = TypedValue; +using RegPtrValue = TypedValue; +using RegMem2VectorSSAMap = DenseMap; + 
+struct RegAccessInfo { + MakePtrOp makePtrOp; + int32_t offset; + int32_t width; + Type elemTy; +}; + +struct RegAllocaInfo { + MakePtrOp makePtrOp; + int32_t allocaSize; + Type elemTy; + VectorType vectorSSATy; +}; + +bool isRegValue(Value value) { + if (auto ptrTy = dyn_cast(value.getType())) + return ptrTy.getAddressSpace().getValue() == AddressSpace::Register; + if (auto memRefTy = dyn_cast(value.getType())) + return memRefTy.getAddressSpace().getValue() == AddressSpace::Register; + return false; +} + +bool isRegOperandOrResult(Operation *op) { + for (Value result : op->getOpResults()) { + if (isRegValue(result)) + return true; + } + for (Value operand : op->getOperands()) { + if (isRegValue(operand)) + return true; + } + return false; +} + +std::optional> resolveRegOffset(Value ptr) { + assert(isa(ptr) && "expected register pointer"); + + if (auto makePtrOp = ptr.getDefiningOp()) { + return std::pair{makePtrOp, 0}; + } else if (auto addOffsetOp = ptr.getDefiningOp()) { + auto base = resolveRegOffset(addOffsetOp.getPtr()); + if (!base) + return std::nullopt; + IntAttr intAttr = addOffsetOp.getOffset().getType().getAttr().getLeafAsInt(); + return std::pair{base->first, base->second + intAttr.getValue()}; + } else { + return std::nullopt; + } +} + +class FlyPromoteRegMemToVectorSSAPass + : public mlir::fly::impl::FlyPromoteRegMemToVectorSSAPassBase { +public: + using mlir::fly::impl::FlyPromoteRegMemToVectorSSAPassBase< + FlyPromoteRegMemToVectorSSAPass>::FlyPromoteRegMemToVectorSSAPassBase; + + void runOnOperation() override { + auto moduleOp = getOperation(); + moduleOp->walk([&](gpu::GPUFuncOp funcOp) { + if (failed(processFunction(funcOp))) + signalPassFailure(); + }); + + auto context = moduleOp->getContext(); + IRRewriter rewriter(context); + DominanceInfo domInfo(getOperation()); + eliminateCommonSubExpressions(rewriter, domInfo, getOperation()); + } + +private: + DenseMap regAllocaInfos; + SmallVector allocaOrder; + + LogicalResult processFunction(gpu::GPUFuncOp funcOp) { + OpBuilder opBuilder(funcOp.getContext()); + + regAllocaInfos.clear(); + allocaOrder.clear(); + + funcOp.walk([&](MakePtrOp makePtrOp) { + if (!isRegValue(makePtrOp)) + return; + auto allocaSizeAttr = makePtrOp.getDictAttrs()->getAs("allocaSize"); + if (!allocaSizeAttr || allocaSizeAttr.getInt() <= 0) + return; + + PointerType ptrTy = cast(makePtrOp.getType()); + regAllocaInfos.try_emplace( + makePtrOp, + RegAllocaInfo{makePtrOp, static_cast(allocaSizeAttr.getInt()), ptrTy.getElemTy(), + VectorType::get({allocaSizeAttr.getInt()}, ptrTy.getElemTy())}); + allocaOrder.push_back(makePtrOp); + }); + + funcOp.walk([&](RecastIterOp recastOp) { + if (recastOp->use_empty()) + return; + for (Value v : {recastOp.getSrc(), recastOp.getResult()}) { + if (!isRegValue(v)) + continue; + auto root = resolveRegOffset(v); + if (root) + regAllocaInfos.erase(root->first); + } + }); + bool hasExcludedRoots = allocaOrder.size() != regAllocaInfos.size(); + llvm::erase_if(allocaOrder, [&](MakePtrOp r) { return !regAllocaInfos.count(r); }); + + if (allocaOrder.empty()) + return success(); + + Block *oldEntry = &funcOp.getBody().front(); + Block *newEntry = opBuilder.createBlock(&funcOp.getBody()); + for (BlockArgument arg : oldEntry->getArguments()) + newEntry->addArgument(arg.getType(), funcOp.getLoc()); + + IRMapping mapping; + for (auto it : llvm::zip(oldEntry->getArguments(), newEntry->getArguments())) + mapping.map(std::get<0>(it), std::get<1>(it)); + + RegMem2VectorSSAMap state; + if (failed(rewriteBlock(oldEntry, newEntry, 
mapping, state))) { + newEntry->erase(); + return failure(); + } + + oldEntry->erase(); + cleanupDeadOps(funcOp); + + if (!hasExcludedRoots) { + bool invalid = false; + funcOp.walk([&](Operation *op) { + if (invalid) + return; + if (isRegOperandOrResult(op)) { + op->emitOpError("register operand/result remain after rmem SSA promotion"); + invalid = true; + return; + } + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (BlockArgument arg : block.getArguments()) { + if (isRegValue(arg)) { + op->emitOpError("register block arguments remain after rmem SSA promotion"); + invalid = true; + return; + } + } + } + } + }); + if (invalid) + return failure(); + } + return success(); + } + + RegAllocaInfo *getAllocaInfo(MakePtrOp makePtrOp) { + auto it = regAllocaInfos.find(makePtrOp); + if (it == regAllocaInfos.end()) + return nullptr; + return &it->second; + } + + std::optional getRegAccessInfo(Value ptr, Type valueType) { + auto rootAndOffset = resolveRegOffset(ptr); + if (!rootAndOffset) + return std::nullopt; + + int32_t width = 0; + if (auto vecTy = dyn_cast(valueType)) + width = vecTy.getNumElements(); + else if (isa(valueType)) + width = 1; + else + return std::nullopt; + + const RegAllocaInfo *info = getAllocaInfo(rootAndOffset->first); + if (!info) + return std::nullopt; + + return RegAccessInfo{rootAndOffset->first, rootAndOffset->second, width, info->elemTy}; + } + + LogicalResult collectTouchedRegAllocaInRegion(Region ®ion, Operation *boundaryOp, + DenseSet &touchedRoots) { + auto tryRecord = [&](Value ptr) { + if (!isa(ptr)) + return; + auto root = resolveRegOffset(ptr); + if (root && getAllocaInfo(root->first) && !boundaryOp->isProperAncestor(root->first)) + touchedRoots.insert(root->first); + }; + for (Block &block : region) { + for (Operation &op : block) { + if (auto loadOp = dyn_cast(&op)) + tryRecord(loadOp.getPtr()); + else if (auto storeOp = dyn_cast(&op)) + tryRecord(storeOp.getPtr()); + + for (Region &nested : op.getRegions()) { + if (failed(collectTouchedRegAllocaInRegion(nested, boundaryOp, touchedRoots))) + return failure(); + } + } + } + return success(); + } + + LogicalResult collectTouchedRegAlloca(Operation *boundaryOp, SmallVectorImpl &roots) { + DenseSet touchedRoots; + for (Region ®ion : boundaryOp->getRegions()) { + if (failed(collectTouchedRegAllocaInRegion(region, boundaryOp, touchedRoots))) + return failure(); + } + for (MakePtrOp makePtrOp : allocaOrder) { + if (touchedRoots.contains(makePtrOp)) + roots.push_back(makePtrOp); + } + return success(); + } + + void appendVectorSSATypes(ArrayRef roots, SmallVectorImpl &types) { + for (MakePtrOp makePtrOp : roots) { + const RegAllocaInfo *info = getAllocaInfo(makePtrOp); + assert(info && "missing register makePtrOp info"); + types.push_back(info->vectorSSATy); + } + } + + void appendVectorSSAValues(RegMem2VectorSSAMap &state, ArrayRef roots, + SmallVectorImpl &values) { + for (MakePtrOp makePtrOp : roots) { + auto it = state.find(makePtrOp); + assert(it != state.end() && "missing state for register makePtrOp"); + values.push_back(it->second); + } + } + + LogicalResult rewritePtrStore(PtrStoreOp storeOp, OpBuilder &builder, IRMapping &mapping, + RegMem2VectorSSAMap &state) { + auto access = getRegAccessInfo(storeOp.getPtr(), storeOp.getValue().getType()); + if (!access) + return failure(); + + auto stateIt = state.find(access->makePtrOp); + assert(stateIt != state.end() && "missing state for register ptr.store"); + VectorValue currentVec = stateIt->second; + VectorValue updatedVec; + + auto 
loc = storeOp.getLoc(); + Value storedValue = mapping.lookupOrDefault(storeOp.getValue()); + + if (isa(storedValue.getType())) { + assert(access->width == 1 && "expected scalar type with width 1"); + updatedVec = vector::InsertOp::create(builder, loc, storedValue, currentVec, access->offset); + } else { + assert(isa(storedValue.getType()) && "expected vector type"); + assert(cast(storedValue.getType()).getNumElements() == access->width && + "expected vector type with same width as access"); + + updatedVec = vector::InsertStridedSliceOp::create(builder, loc, storedValue, currentVec, + ArrayRef{access->offset}, + ArrayRef{1}); + } + state[access->makePtrOp] = updatedVec; + return success(); + } + + LogicalResult rewritePtrLoad(PtrLoadOp loadOp, OpBuilder &builder, IRMapping &mapping, + RegMem2VectorSSAMap &state) { + auto access = getRegAccessInfo(loadOp.getPtr(), loadOp.getResult().getType()); + if (!access) + return failure(); + + auto stateIt = state.find(access->makePtrOp); + assert(stateIt != state.end() && "missing state for register ptr.load"); + VectorValue currentVec = stateIt->second; + + auto loc = loadOp.getLoc(); + Type resultType = loadOp.getResult().getType(); + Value extracted; + + if (isa(resultType)) { + assert(access->width == 1 && "expected scalar type with width 1"); + extracted = vector::ExtractOp::create(builder, loc, currentVec, access->offset); + } else { + assert(isa(resultType) && "expected vector type"); + assert(cast(resultType).getNumElements() == access->width && + "expected vector type with same width as access"); + + extracted = vector::ExtractStridedSliceOp::create( + builder, loc, currentVec, ArrayRef{access->offset}, + ArrayRef{access->width}, ArrayRef{1}); + } + mapping.map(loadOp, extracted); + return success(); + } + + LogicalResult rewriteIfOp(scf::IfOp oldIf, OpBuilder &builder, IRMapping &mapping, + RegMem2VectorSSAMap &state) { + SmallVector touchedRoots; + if (failed(collectTouchedRegAlloca(oldIf, touchedRoots))) + return failure(); + + SmallVector newResultTypes(oldIf.getResultTypes()); + appendVectorSSATypes(touchedRoots, newResultTypes); + + bool hasElse = !oldIf.getElseRegion().empty(); + bool withElse = hasElse || !touchedRoots.empty(); + auto newIf = scf::IfOp::create(builder, oldIf.getLoc(), TypeRange(newResultTypes), + mapping.lookupOrDefault(oldIf.getCondition()), withElse); + + { + // process then block + IRMapping thenMapping = mapping; + RegMem2VectorSSAMap thenState = state; + Block *oldThen = &oldIf.getThenRegion().front(); + Block *newThen = &newIf.getThenRegion().front(); + if (!newThen->empty()) + newThen->back().erase(); + if (failed(rewriteBlock(oldThen, newThen, thenMapping, thenState))) + return failure(); + + auto oldYield = cast(oldThen->getTerminator()); + SmallVector newYieldOperands; + for (Value yielded : oldYield.getOperands()) + newYieldOperands.push_back(thenMapping.lookupOrDefault(yielded)); + appendVectorSSAValues(thenState, touchedRoots, newYieldOperands); + + OpBuilder thenYieldBuilder = OpBuilder::atBlockEnd(newThen); + scf::YieldOp::create(thenYieldBuilder, oldYield.getLoc(), newYieldOperands); + } + + if (hasElse) { + IRMapping elseMapping = mapping; + RegMem2VectorSSAMap elseState = state; + Block *oldElse = &oldIf.getElseRegion().front(); + Block *newElse = &newIf.getElseRegion().front(); + if (!newElse->empty()) + newElse->back().erase(); + if (failed(rewriteBlock(oldElse, newElse, elseMapping, elseState))) + return failure(); + + auto oldYield = cast(oldElse->getTerminator()); + SmallVector newYieldOperands; 
+ for (Value yielded : oldYield.getOperands()) + newYieldOperands.push_back(elseMapping.lookupOrDefault(yielded)); + appendVectorSSAValues(elseState, touchedRoots, newYieldOperands); + + OpBuilder elseYieldBuilder = OpBuilder::atBlockEnd(newElse); + scf::YieldOp::create(elseYieldBuilder, oldYield.getLoc(), newYieldOperands); + } else if (withElse) { + Block *newElse = &newIf.getElseRegion().front(); + if (!newElse->empty()) + newElse->back().erase(); + SmallVector elseYieldOperands; + appendVectorSSAValues(state, touchedRoots, elseYieldOperands); + OpBuilder elseYieldBuilder = OpBuilder::atBlockEnd(newElse); + scf::YieldOp::create(elseYieldBuilder, oldIf.getLoc(), elseYieldOperands); + } + + for (unsigned i = 0; i < oldIf.getNumResults(); ++i) + mapping.map(oldIf.getResult(i), newIf.getResult(i)); + for (auto it : llvm::enumerate(touchedRoots)) + state[it.value()] = cast(newIf.getResult(oldIf.getNumResults() + it.index())); + + return success(); + } + + LogicalResult rewriteForOp(scf::ForOp oldFor, OpBuilder &builder, IRMapping &mapping, + RegMem2VectorSSAMap &state) { + SmallVector touchedRoots; + if (failed(collectTouchedRegAlloca(oldFor, touchedRoots))) + return failure(); + + SmallVector newInitArgs; + for (Value initArg : oldFor.getInitArgs()) + newInitArgs.push_back(mapping.lookupOrDefault(initArg)); + appendVectorSSAValues(state, touchedRoots, newInitArgs); + + auto newFor = scf::ForOp::create( + builder, oldFor.getLoc(), mapping.lookupOrDefault(oldFor.getLowerBound()), + mapping.lookupOrDefault(oldFor.getUpperBound()), mapping.lookupOrDefault(oldFor.getStep()), + newInitArgs, [](OpBuilder &, Location, Value, ValueRange) {}, oldFor.getUnsignedCmp()); + + IRMapping bodyMapping = mapping; + Block *oldBody = oldFor.getBody(); + Block *newBody = newFor.getBody(); + for (unsigned i = 0; i < oldBody->getNumArguments(); ++i) + bodyMapping.map(oldBody->getArgument(i), newBody->getArgument(i)); + + RegMem2VectorSSAMap bodyState = state; + unsigned carriedArgBase = 1 + oldFor.getRegionIterArgs().size(); + for (auto it : llvm::enumerate(touchedRoots)) + bodyState[it.value()] = cast(newBody->getArgument(carriedArgBase + it.index())); + + if (failed(rewriteBlock(oldBody, newBody, bodyMapping, bodyState))) + return failure(); + + auto oldYield = cast(oldBody->getTerminator()); + SmallVector newYieldOperands; + for (Value yielded : oldYield.getOperands()) + newYieldOperands.push_back(bodyMapping.lookupOrDefault(yielded)); + for (MakePtrOp makePtrOp : touchedRoots) + newYieldOperands.push_back(bodyState[makePtrOp]); + + OpBuilder yieldBuilder = OpBuilder::atBlockEnd(newBody); + scf::YieldOp::create(yieldBuilder, oldYield.getLoc(), newYieldOperands); + + for (unsigned i = 0; i < oldFor.getNumResults(); ++i) + mapping.map(oldFor.getResult(i), newFor.getResult(i)); + for (auto it : llvm::enumerate(touchedRoots)) + state[it.value()] = cast(newFor.getResult(oldFor.getNumResults() + it.index())); + + return success(); + } + + LogicalResult rewriteWhileOp(scf::WhileOp oldWhile, OpBuilder &builder, IRMapping &mapping, + RegMem2VectorSSAMap &state) { + SmallVector touchedRoots; + if (failed(collectTouchedRegAlloca(oldWhile, touchedRoots))) + return failure(); + + SmallVector newInitArgs; + for (Value initArg : oldWhile.getInits()) + newInitArgs.push_back(mapping.lookupOrDefault(initArg)); + appendVectorSSAValues(state, touchedRoots, newInitArgs); + + SmallVector newResultTypes(oldWhile.getResultTypes().begin(), + oldWhile.getResultTypes().end()); + appendVectorSSATypes(touchedRoots, newResultTypes); + + 
auto newWhile = + scf::WhileOp::create(builder, oldWhile.getLoc(), TypeRange(newResultTypes), newInitArgs); + + SmallVector beforeArgTypes; + beforeArgTypes.reserve(newInitArgs.size()); + for (Value initArg : newInitArgs) + beforeArgTypes.push_back(initArg.getType()); + SmallVector beforeArgLocs(beforeArgTypes.size(), oldWhile.getLoc()); + Block *newBefore = + builder.createBlock(&newWhile.getBefore(), {}, beforeArgTypes, beforeArgLocs); + + SmallVector afterArgTypes(oldWhile.getResultTypes().begin(), + oldWhile.getResultTypes().end()); + appendVectorSSATypes(touchedRoots, afterArgTypes); + SmallVector afterArgLocs(afterArgTypes.size(), oldWhile.getLoc()); + Block *newAfter = builder.createBlock(&newWhile.getAfter(), {}, afterArgTypes, afterArgLocs); + + { + // process before block + IRMapping beforeMapping = mapping; + Block *oldBefore = oldWhile.getBeforeBody(); + for (unsigned i = 0; i < oldBefore->getNumArguments(); ++i) + beforeMapping.map(oldBefore->getArgument(i), newBefore->getArgument(i)); + + RegMem2VectorSSAMap beforeState = state; + unsigned rootArgBase = oldBefore->getNumArguments(); + for (auto it : llvm::enumerate(touchedRoots)) + beforeState[it.value()] = + cast(newBefore->getArgument(rootArgBase + it.index())); + + if (failed(rewriteBlock(oldBefore, newBefore, beforeMapping, beforeState))) + return failure(); + + auto oldCondition = cast(oldBefore->getTerminator()); + SmallVector newConditionArgs; + for (Value arg : oldCondition.getArgs()) + newConditionArgs.push_back(beforeMapping.lookupOrDefault(arg)); + for (MakePtrOp makePtrOp : touchedRoots) + newConditionArgs.push_back(beforeState[makePtrOp]); + + OpBuilder condBuilder = OpBuilder::atBlockEnd(newBefore); + scf::ConditionOp::create(condBuilder, oldCondition.getLoc(), + beforeMapping.lookupOrDefault(oldCondition.getCondition()), + newConditionArgs); + } + + { + // process after block + IRMapping afterMapping = mapping; + Block *oldAfter = oldWhile.getAfterBody(); + for (unsigned i = 0; i < oldAfter->getNumArguments(); ++i) + afterMapping.map(oldAfter->getArgument(i), newAfter->getArgument(i)); + + RegMem2VectorSSAMap afterState = state; + unsigned rootArgBase = oldAfter->getNumArguments(); + for (auto it : llvm::enumerate(touchedRoots)) + afterState[it.value()] = cast(newAfter->getArgument(rootArgBase + it.index())); + + if (failed(rewriteBlock(oldAfter, newAfter, afterMapping, afterState))) + return failure(); + + auto oldYield = cast(oldAfter->getTerminator()); + SmallVector newYieldOperands; + for (Value yielded : oldYield.getOperands()) + newYieldOperands.push_back(afterMapping.lookupOrDefault(yielded)); + for (MakePtrOp makePtrOp : touchedRoots) + newYieldOperands.push_back(afterState[makePtrOp]); + + OpBuilder yieldBuilder = OpBuilder::atBlockEnd(newAfter); + scf::YieldOp::create(yieldBuilder, oldYield.getLoc(), newYieldOperands); + } + + for (unsigned i = 0; i < oldWhile.getNumResults(); ++i) + mapping.map(oldWhile.getResult(i), newWhile.getResult(i)); + for (auto it : llvm::enumerate(touchedRoots)) + state[it.value()] = + cast(newWhile.getResult(oldWhile.getNumResults() + it.index())); + + return success(); + } + + LogicalResult rewriteBlock(Block *oldBlock, Block *newBlock, IRMapping &mapping, + RegMem2VectorSSAMap &state) { + OpBuilder builder(oldBlock->getParentOp()->getContext()); + + for (Operation &op : *oldBlock) { + builder.setInsertionPointToEnd(newBlock); + + if (isa(op)) + continue; + + if (auto makePtrOp = dyn_cast(&op); makePtrOp && isRegValue(makePtrOp)) { + if (auto *info = 
getAllocaInfo(makePtrOp)) { + mapping.map( + makePtrOp.getResult(), + ub::PoisonOp::create(builder, makePtrOp.getLoc(), makePtrOp.getType()).getResult()); + state[makePtrOp] = cast( + ub::PoisonOp::create(builder, makePtrOp.getLoc(), info->vectorSSATy).getResult()); + } else { + builder.clone(op, mapping); + } + continue; + } + + if (auto storeOp = dyn_cast(&op); storeOp && isRegValue(storeOp.getPtr())) { + if (getRegAccessInfo(storeOp.getPtr(), storeOp.getValue().getType())) { + if (failed(rewritePtrStore(storeOp, builder, mapping, state))) + return failure(); + } else { + builder.clone(op, mapping); + } + continue; + } + + if (auto loadOp = dyn_cast(&op); loadOp && isRegValue(loadOp.getPtr())) { + if (getRegAccessInfo(loadOp.getPtr(), loadOp.getResult().getType())) { + if (failed(rewritePtrLoad(loadOp, builder, mapping, state))) + return failure(); + } else { + builder.clone(op, mapping); + } + continue; + } + + if (auto forOp = dyn_cast(&op)) { + if (failed(rewriteForOp(forOp, builder, mapping, state))) + return failure(); + continue; + } + if (auto ifOp = dyn_cast(&op)) { + if (failed(rewriteIfOp(ifOp, builder, mapping, state))) + return failure(); + continue; + } + if (auto whileOp = dyn_cast(&op)) { + if (failed(rewriteWhileOp(whileOp, builder, mapping, state))) + return failure(); + continue; + } + + if (op.getNumRegions() != 0) { + bool hasNestedRegOperandOrResult = false; + op.walk([&](Operation *nestedOp) { + if (nestedOp == &op) + return; + if (isRegOperandOrResult(nestedOp)) + hasNestedRegOperandOrResult = true; + }); + if (hasNestedRegOperandOrResult) + return op.emitOpError("unsupported region op with register values during rmem SSA"); + } + + builder.clone(op, mapping); + } + + return success(); + } + + void cleanupDeadOps(gpu::GPUFuncOp funcOp) { + bool changed = true; + while (changed) { + changed = false; + funcOp.walk([&](Block *block) { + for (Operation &op : llvm::make_early_inc_range(llvm::reverse(*block))) { + if (isOpTriviallyDead(&op)) { + op.erase(); + changed = true; + } + } + }); + } + } +}; + +} // namespace diff --git a/python/flydsl/compiler/backends/rocm.py b/python/flydsl/compiler/backends/rocm.py index 5188d152..052e0314 100644 --- a/python/flydsl/compiler/backends/rocm.py +++ b/python/flydsl/compiler/backends/rocm.py @@ -65,6 +65,8 @@ def pipeline_fragments(self, *, compile_hints: dict) -> List[str]: "fly-rewrite-func-signature", "fly-canonicalize", "fly-layout-lowering", + "fly-convert-atom-call-to-ssa-form", + "fly-promote-regmem-to-vectorssa", "convert-fly-to-rocdl", "canonicalize", f"gpu.module(convert-scf-to-cf,cse," @@ -74,8 +76,8 @@ def pipeline_fragments(self, *, compile_hints: dict) -> List[str]: "convert-scf-to-cf", "convert-cf-to-llvm", "gpu-to-llvm{use-bare-pointers-for-host=true use-bare-pointers-for-kernels=true}", - "convert-arith-to-llvm", "convert-vector-to-llvm", + "convert-arith-to-llvm", "convert-func-to-llvm", "reconcile-unrealized-casts", *( diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index 4451bbd2..f4f29f9a 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -799,11 +799,15 @@ def get_dyn_shared(loc=None, ip=None): @traced_op def inttoptr(result_type, src, loc=None, ip=None): + if result_type.address_space == AddressSpace.Register: + raise ValueError("inttoptr is not supported for register address space") return fly.inttoptr(result_type, src, loc=loc, ip=ip) @traced_op def ptrtoint(ptr, loc=None, ip=None): + if ptr.address_space == AddressSpace.Register: + raise 
ValueError("ptrtoint is not supported for register address space") return fly.ptrtoint(ptr, loc=loc, ip=ip) diff --git a/tests/mlir/Transforms/promote_regmem_to_ssa.mlir b/tests/mlir/Transforms/promote_regmem_to_ssa.mlir new file mode 100644 index 00000000..add21060 --- /dev/null +++ b/tests/mlir/Transforms/promote_regmem_to_ssa.mlir @@ -0,0 +1,380 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2025 FlyDSL Project Contributors +// RUN: %fly-opt %s --fly-convert-atom-call-to-ssa-form --fly-promote-regmem-to-vectorssa | FileCheck %s --check-prefix=CHECK + +// Tests for fly-promote-regmem-to-vectorssa pass: +// - register state is threaded across scf.for / scf.if / scf.while +// - tail mma_atom_call after the loop is rewritten to SSA form +// - final register ptr.load is replaced by vector extract ops + +// CHECK-LABEL: gpu.func @promote_accumulator_to_vector_ssa +// CHECK-SAME: (%[[OUT:.*]]: !fly.ptr) +// CHECK-NOT: register +// CHECK-NOT: fly.mma_atom_call( +// CHECK: %[[A_STATE:.*]] = vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [0], strides = [1]} : vector<4xf16> into vector<4xf16> +// CHECK: %[[B_STATE:.*]] = vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [0], strides = [1]} : vector<4xf16> into vector<4xf16> +// CHECK: %[[ACC_INIT:.*]] = vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32> +gpu.module @promote_rmem_to_vector_ssa { + gpu.func @promote_accumulator_to_vector_ssa(%out: !fly.ptr) kernel { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %a_init = arith.constant dense<1.000000e+00> : vector<4xf16> + %b_init = arith.constant dense<2.000000e+00> : vector<4xf16> + %zero = arith.constant dense<0.000000e+00> : vector<4xf32> + + %shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4> + %stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %vec4 = fly.make_layout(%shape4, %stride1) : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1> + %acc_shape = fly.make_int_tuple() : () -> !fly.int_tuple<(4,1)> + %acc_stride = fly.make_int_tuple() : () -> !fly.int_tuple<(1,0)> + %acc_layout = fly.make_layout(%acc_shape, %acc_stride) : (!fly.int_tuple<(4,1)>, !fly.int_tuple<(1,0)>) -> !fly.layout<(4,1):(1,0)> + + %a_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + %b_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + %acc_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 8 : i64}} : () -> !fly.ptr + + fly.ptr.store(%a_init, %a_ptr) : (vector<4xf16>, !fly.ptr) -> () + fly.ptr.store(%b_init, %b_ptr) : (vector<4xf16>, !fly.ptr) -> () + + %acc_off = fly.make_int_tuple() : () -> !fly.int_tuple<4> + %acc_slot = fly.add_offset(%acc_ptr, %acc_off) : (!fly.ptr, !fly.int_tuple<4>) -> !fly.ptr + fly.ptr.store(%zero, %acc_slot) : (vector<4xf32>, !fly.ptr) -> () + + %a_view = fly.make_view(%a_ptr, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + %b_view = fly.make_view(%b_ptr, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + %acc_view = fly.make_view(%acc_slot, %acc_layout) : (!fly.ptr, !fly.layout<(4,1):(1,0)>) -> !fly.memref + %atom = fly.make_mma_atom : !fly.mma_atom f32>> + + // CHECK: %{{.*}}:3 = scf.for {{.*}} iter_args(%[[A_ITER:.*]] = %[[A_STATE]], %[[B_ITER:.*]] = %[[B_STATE]], %[[ACC:.*]] = %[[ACC_INIT]]) -> (vector<4xf16>, vector<4xf16>, vector<8xf32>) { + // CHECK: %[[LOOP_A:.*]] = vector.extract_strided_slice %[[A_ITER]] {offsets = [0], sizes = [4], strides = [1]} : 
vector<4xf16> to vector<4xf16> + // CHECK: %[[LOOP_B:.*]] = vector.extract_strided_slice %[[B_ITER]] {offsets = [0], sizes = [4], strides = [1]} : vector<4xf16> to vector<4xf16> + // CHECK: %[[LOOP_C:.*]] = vector.extract_strided_slice %[[ACC]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32> + // CHECK: %[[LOOP_RES:.*]] = fly.mma_atom_call_ssa + // CHECK: %[[LOOP_ACC_NEXT:.*]] = vector.insert_strided_slice %[[LOOP_RES]], %[[ACC]] {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32> + // CHECK: scf.yield %[[A_ITER]], %[[B_ITER]], %[[LOOP_ACC_NEXT]] : vector<4xf16>, vector<4xf16>, vector<8xf32> + scf.for %iv = %c0 to %c2 step %c1 { + fly.mma_atom_call(%atom, %acc_view, %a_view, %b_view, %acc_view) : (!fly.mma_atom f32>>, !fly.memref, !fly.memref, !fly.memref, !fly.memref) -> () + } + + // CHECK: %[[TAIL_A:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0], sizes = [4], strides = [1]} : vector<4xf16> to vector<4xf16> + // CHECK: %[[TAIL_B:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0], sizes = [4], strides = [1]} : vector<4xf16> to vector<4xf16> + // CHECK: %[[TAIL_C:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [4], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32> + // CHECK: %[[TAIL_RES:.*]] = fly.mma_atom_call_ssa + // CHECK: %[[TAIL_ACC:.*]] = vector.insert_strided_slice %[[TAIL_RES]], %{{.*}} {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32> + fly.mma_atom_call(%atom, %acc_view, %a_view, %b_view, %acc_view) : (!fly.mma_atom f32>>, !fly.memref, !fly.memref, !fly.memref, !fly.memref) -> () + + // CHECK: %[[FINAL:.*]] = vector.extract_strided_slice %[[TAIL_ACC]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32> + // CHECK: %[[ELEM:.*]] = vector.extract %[[FINAL]][%{{.*}}] : f32 from vector<4xf32> + // CHECK: fly.ptr.store(%[[ELEM]], %[[OUT]]) : (f32, !fly.ptr) -> () + %final = fly.ptr.load(%acc_slot) : (!fly.ptr) -> vector<4xf32> + %elem0 = vector.extract %final[%c0] : f32 from vector<4xf32> + fly.ptr.store(%elem0, %out) : (f32, !fly.ptr) -> () + gpu.return + } + + // CHECK-LABEL: gpu.func @promote_fp8_mma_to_vector_ssa + // CHECK-NOT: register + // CHECK-NOT: fly.mma_atom_call( + // CHECK: %[[A_STATE:.*]] = vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [0], strides = [1]} : vector<8xf8E4M3FNUZ> into vector<8xf8E4M3FNUZ> + // CHECK: %[[B_STATE:.*]] = vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [0], strides = [1]} : vector<8xf8E4M3FNUZ> into vector<8xf8E4M3FNUZ> + // CHECK: %[[ACC_INIT:.*]] = vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32> + // CHECK: %[[A:.*]] = vector.extract_strided_slice %[[A_STATE]] {offsets = [0], sizes = [8], strides = [1]} : vector<8xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ> + // CHECK: %[[B:.*]] = vector.extract_strided_slice %[[B_STATE]] {offsets = [0], sizes = [8], strides = [1]} : vector<8xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ> + // CHECK: %[[C:.*]] = vector.extract_strided_slice %[[ACC_INIT]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32> + // CHECK: %[[RES:.*]] = fly.mma_atom_call_ssa(%{{.*}}, %[[A]], %[[B]], %[[C]]) + // CHECK-SAME: -> vector<4xf32> + gpu.func @promote_fp8_mma_to_vector_ssa(%out: !fly.ptr) kernel { + %c0 = arith.constant 0 : index + %a_init = arith.constant dense<1.000000e+00> : vector<8xf8E4M3FNUZ> + %b_init = arith.constant dense<2.000000e+00> : vector<8xf8E4M3FNUZ> + %zero = arith.constant dense<0.000000e+00> : 
vector<4xf32> + + %shape8 = fly.make_int_tuple() : () -> !fly.int_tuple<8> + %stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %vec8 = fly.make_layout(%shape8, %stride1) : (!fly.int_tuple<8>, !fly.int_tuple<1>) -> !fly.layout<8:1> + %acc_shape = fly.make_int_tuple() : () -> !fly.int_tuple<(4,1)> + %acc_stride = fly.make_int_tuple() : () -> !fly.int_tuple<(1,0)> + %acc_layout = fly.make_layout(%acc_shape, %acc_stride) : (!fly.int_tuple<(4,1)>, !fly.int_tuple<(1,0)>) -> !fly.layout<(4,1):(1,0)> + + %a_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 8 : i64}} : () -> !fly.ptr + %b_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 8 : i64}} : () -> !fly.ptr + %acc_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 8 : i64}} : () -> !fly.ptr + + fly.ptr.store(%a_init, %a_ptr) : (vector<8xf8E4M3FNUZ>, !fly.ptr) -> () + fly.ptr.store(%b_init, %b_ptr) : (vector<8xf8E4M3FNUZ>, !fly.ptr) -> () + + %acc_off = fly.make_int_tuple() : () -> !fly.int_tuple<4> + %acc_slot = fly.add_offset(%acc_ptr, %acc_off) : (!fly.ptr, !fly.int_tuple<4>) -> !fly.ptr + fly.ptr.store(%zero, %acc_slot) : (vector<4xf32>, !fly.ptr) -> () + + %a_view = fly.make_view(%a_ptr, %vec8) : (!fly.ptr, !fly.layout<8:1>) -> !fly.memref + %b_view = fly.make_view(%b_ptr, %vec8) : (!fly.ptr, !fly.layout<8:1>) -> !fly.memref + %acc_view = fly.make_view(%acc_slot, %acc_layout) : (!fly.ptr, !fly.layout<(4,1):(1,0)>) -> !fly.memref + %atom = fly.make_mma_atom : !fly.mma_atom f32>> + + fly.mma_atom_call(%atom, %acc_view, %a_view, %b_view, %acc_view) : (!fly.mma_atom f32>>, !fly.memref, !fly.memref, !fly.memref, !fly.memref) -> () + + %final = fly.ptr.load(%acc_slot) : (!fly.ptr) -> vector<4xf32> + %elem0 = vector.extract %final[%c0] : f32 from vector<4xf32> + fly.ptr.store(%elem0, %out) : (f32, !fly.ptr) -> () + gpu.return + } + + // CHECK-LABEL: gpu.func @promote_if_register_state_to_vector_ssa + // CHECK-NOT: register + // CHECK: %[[INIT:.*]] = vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [0], strides = [1]} : vector<4xf32> into vector<4xf32> + // CHECK: %[[IF_STATE:.*]] = scf.if %arg1 -> (vector<4xf32>) { + // CHECK: %[[THEN_STATE:.*]] = vector.insert_strided_slice %{{.*}}, %[[INIT]] {offsets = [0], strides = [1]} : vector<4xf32> into vector<4xf32> + // CHECK: scf.yield %[[THEN_STATE]] : vector<4xf32> + // CHECK: } else { + // CHECK: scf.yield %[[INIT]] : vector<4xf32> + // CHECK: } + // CHECK: %[[FINAL:.*]] = vector.extract_strided_slice %[[IF_STATE]] {offsets = [0], sizes = [4], strides = [1]} : vector<4xf32> to vector<4xf32> + gpu.func @promote_if_register_state_to_vector_ssa(%out: !fly.ptr, %pred: i1) kernel { + %c0 = arith.constant 0 : index + %zero = arith.constant dense<0.000000e+00> : vector<4xf32> + %one = arith.constant dense<1.000000e+00> : vector<4xf32> + %reg = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + + fly.ptr.store(%zero, %reg) : (vector<4xf32>, !fly.ptr) -> () + scf.if %pred { + fly.ptr.store(%one, %reg) : (vector<4xf32>, !fly.ptr) -> () + } + + %final = fly.ptr.load(%reg) : (!fly.ptr) -> vector<4xf32> + %elem0 = vector.extract %final[%c0] : f32 from vector<4xf32> + fly.ptr.store(%elem0, %out) : (f32, !fly.ptr) -> () + gpu.return + } + + // CHECK-LABEL: gpu.func @promote_while_register_state_to_vector_ssa + // CHECK-NOT: register + // CHECK: %[[INIT:.*]] = vector.insert %{{.*}}, %{{.*}}[0] : i32 into vector<1xi32> + // CHECK: %[[WHILE:.*]] = scf.while (%[[STATE:.*]] = %[[INIT]]) : (vector<1xi32>) -> vector<1xi32> { + // CHECK: %[[CUR:.*]] = vector.extract %[[STATE]][0] : i32 from 
vector<1xi32> + // CHECK: %[[COND:.*]] = arith.cmpi slt, %[[CUR]], %{{.*}} : i32 + // CHECK: scf.condition(%[[COND]]) %[[STATE]] : vector<1xi32> + // CHECK: } do { + // CHECK: ^bb0(%[[LOOP_STATE:.*]]: vector<1xi32>): + // CHECK: %[[CUR2:.*]] = vector.extract %[[LOOP_STATE]][0] : i32 from vector<1xi32> + // CHECK: %[[NEXT:.*]] = arith.addi %[[CUR2]], %{{.*}} : i32 + // CHECK: %[[NEXT_STATE:.*]] = vector.insert %[[NEXT]], %[[LOOP_STATE]] [0] : i32 into vector<1xi32> + // CHECK: scf.yield %[[NEXT_STATE]] : vector<1xi32> + // CHECK: } + // CHECK: %[[FINAL:.*]] = vector.extract %[[WHILE]][0] : i32 from vector<1xi32> + gpu.func @promote_while_register_state_to_vector_ssa(%out: !fly.ptr) kernel { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %reg = fly.make_ptr() {dictAttrs = {allocaSize = 1 : i64}} : () -> !fly.ptr + + fly.ptr.store(%c0_i32, %reg) : (i32, !fly.ptr) -> () + scf.while : () -> () { + %cur = fly.ptr.load(%reg) : (!fly.ptr) -> i32 + %cond = arith.cmpi slt, %cur, %c2_i32 : i32 + scf.condition(%cond) + } do { + %cur = fly.ptr.load(%reg) : (!fly.ptr) -> i32 + %next = arith.addi %cur, %c1_i32 : i32 + fly.ptr.store(%next, %reg) : (i32, !fly.ptr) -> () + scf.yield + } + + %final = fly.ptr.load(%reg) : (!fly.ptr) -> i32 + fly.ptr.store(%final, %out) : (i32, !fly.ptr) -> () + gpu.return + } + + // CHECK-LABEL: gpu.func @promote_if_with_nested_while_preserves_results_and_state + // CHECK-NOT: register + // CHECK: %[[INIT:.*]] = vector.insert %{{.*}}, %{{.*}} [0] : i32 into vector<1xi32> + // CHECK: %[[IF:.*]]:2 = scf.if %arg1 -> (i32, vector<1xi32>) { + // CHECK: %[[WHILE:.*]]:2 = scf.while (%[[ITER:.*]] = %c0_i32, %[[STATE:.*]] = %[[INIT]]) : (i32, vector<1xi32>) -> (i32, vector<1xi32>) { + // CHECK: %[[COND:.*]] = arith.cmpi slt, %[[ITER]], %{{.*}} : i32 + // CHECK: scf.condition(%[[COND]]) %[[ITER]], %[[STATE]] : i32, vector<1xi32> + // CHECK: } do { + // CHECK: ^bb0(%[[ITER_IN:.*]]: i32, %[[STATE_IN:.*]]: vector<1xi32>): + // CHECK: %[[CUR:.*]] = vector.extract %[[STATE_IN]][0] : i32 from vector<1xi32> + // CHECK: %[[NEXT_VAL:.*]] = arith.addi %[[CUR]], %{{.*}} : i32 + // CHECK: %[[STATE_NEXT:.*]] = vector.insert %[[NEXT_VAL]], %[[STATE_IN]] [0] : i32 into vector<1xi32> + // CHECK: %[[ITER_NEXT:.*]] = arith.addi %[[ITER_IN]], %{{.*}} : i32 + // CHECK: scf.yield %[[ITER_NEXT]], %[[STATE_NEXT]] : i32, vector<1xi32> + // CHECK: } + // CHECK: scf.yield %[[WHILE]]#0, %[[WHILE]]#1 : i32, vector<1xi32> + // CHECK: } else { + // CHECK: %[[ELSE_STATE:.*]] = vector.insert %{{.*}}, %[[INIT]] [0] : i32 into vector<1xi32> + // CHECK: scf.yield %{{.*}}, %[[ELSE_STATE]] : i32, vector<1xi32> + // CHECK: } + // CHECK: %[[FINAL:.*]] = vector.extract %[[IF]]#1[0] : i32 from vector<1xi32> + // CHECK: %[[SUM:.*]] = arith.addi %[[IF]]#0, %[[FINAL]] : i32 + // CHECK: fly.ptr.store(%[[SUM]], %arg0) : (i32, !fly.ptr) -> () + gpu.func @promote_if_with_nested_while_preserves_results_and_state(%out: !fly.ptr, %pred: i1) kernel { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c2_i32 = arith.constant 2 : i32 + %reg = fly.make_ptr() {dictAttrs = {allocaSize = 1 : i64}} : () -> !fly.ptr + + fly.ptr.store(%c0_i32, %reg) : (i32, !fly.ptr) -> () + %if_res = scf.if %pred -> (i32) { + %while_res = scf.while (%iter = %c0_i32) : (i32) -> (i32) { + %cur = fly.ptr.load(%reg) : (!fly.ptr) -> i32 + %cond = arith.cmpi slt, %iter, %c2_i32 : i32 + scf.condition(%cond) %iter : i32 + } do { + ^bb0(%iter_in: i32): + %cur = fly.ptr.load(%reg) : 
(!fly.ptr) -> i32 + %next = arith.addi %cur, %c1_i32 : i32 + fly.ptr.store(%next, %reg) : (i32, !fly.ptr) -> () + %iter_next = arith.addi %iter_in, %c1_i32 : i32 + scf.yield %iter_next : i32 + } + scf.yield %while_res : i32 + } else { + fly.ptr.store(%c2_i32, %reg) : (i32, !fly.ptr) -> () + scf.yield %c2_i32 : i32 + } + + %final = fly.ptr.load(%reg) : (!fly.ptr) -> i32 + %sum = arith.addi %if_res, %final : i32 + fly.ptr.store(%sum, %out) : (i32, !fly.ptr) -> () + gpu.return + } + + // CHECK-LABEL: gpu.func @promote_for_with_nested_if_while_preserves_results_and_state + // CHECK-NOT: register + // CHECK: %[[INIT:.*]] = vector.insert %{{.*}}, %{{.*}} [0] : i32 into vector<1xi32> + // CHECK: %[[FOR:.*]]:2 = scf.for {{.*}} iter_args(%[[SUM:.*]] = %c0_i32, %[[STATE:.*]] = %[[INIT]]) -> (i32, vector<1xi32>) { + // CHECK: %[[IF:.*]] = scf.if %arg1 -> (vector<1xi32>) { + // CHECK: %[[WHILE:.*]]:2 = scf.while (%[[INNER:.*]] = %c0_i32, %[[WHILE_STATE:.*]] = %[[STATE]]) : (i32, vector<1xi32>) -> (i32, vector<1xi32>) { + // CHECK: %[[INNER_COND:.*]] = arith.cmpi slt, %[[INNER]], %{{.*}} : i32 + // CHECK: scf.condition(%[[INNER_COND]]) %[[INNER]], %[[WHILE_STATE]] : i32, vector<1xi32> + // CHECK: } do { + // CHECK: ^bb0(%[[INNER_IN:.*]]: i32, %[[BODY_STATE:.*]]: vector<1xi32>): + // CHECK: %[[CUR:.*]] = vector.extract %[[BODY_STATE]][0] : i32 from vector<1xi32> + // CHECK: %[[NEXT_VAL:.*]] = arith.addi %[[CUR]], %{{.*}} : i32 + // CHECK: %[[NEXT_STATE:.*]] = vector.insert %[[NEXT_VAL]], %[[BODY_STATE]] [0] : i32 into vector<1xi32> + // CHECK: %[[INNER_NEXT:.*]] = arith.addi %[[INNER_IN]], %{{.*}} : i32 + // CHECK: scf.yield %[[INNER_NEXT]], %[[NEXT_STATE]] : i32, vector<1xi32> + // CHECK: } + // CHECK: scf.yield %[[WHILE]]#1 : vector<1xi32> + // CHECK: } else { + // CHECK: scf.yield %[[STATE]] : vector<1xi32> + // CHECK: } + // CHECK: %[[SUM_NEXT:.*]] = arith.addi %[[SUM]], %{{.*}} : i32 + // CHECK: scf.yield %[[SUM_NEXT]], %[[IF]] : i32, vector<1xi32> + // CHECK: } + // CHECK: %[[FINAL:.*]] = vector.extract %[[FOR]]#1[0] : i32 from vector<1xi32> + // CHECK: %[[OUTVAL:.*]] = arith.addi %[[FOR]]#0, %[[FINAL]] : i32 + // CHECK: fly.ptr.store(%[[OUTVAL]], %arg0) : (i32, !fly.ptr) -> () + gpu.func @promote_for_with_nested_if_while_preserves_results_and_state(%out: !fly.ptr, %pred: i1) kernel { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %reg = fly.make_ptr() {dictAttrs = {allocaSize = 1 : i64}} : () -> !fly.ptr + + fly.ptr.store(%c0_i32, %reg) : (i32, !fly.ptr) -> () + %sum = scf.for %iv = %c0 to %c2 step %c1 iter_args(%sum_iter = %c0_i32) -> (i32) { + scf.if %pred { + %while_res = scf.while (%inner = %c0_i32) : (i32) -> (i32) { + %cond = arith.cmpi slt, %inner, %c1_i32 : i32 + scf.condition(%cond) %inner : i32 + } do { + ^bb0(%inner_in: i32): + %cur = fly.ptr.load(%reg) : (!fly.ptr) -> i32 + %next = arith.addi %cur, %c1_i32 : i32 + fly.ptr.store(%next, %reg) : (i32, !fly.ptr) -> () + %inner_next = arith.addi %inner_in, %c1_i32 : i32 + scf.yield %inner_next : i32 + } + scf.yield + } + %sum_next = arith.addi %sum_iter, %c1_i32 : i32 + scf.yield %sum_next : i32 + } + + %final = fly.ptr.load(%reg) : (!fly.ptr) -> i32 + %out_val = arith.addi %sum, %final : i32 + fly.ptr.store(%out_val, %out) : (i32, !fly.ptr) -> () + gpu.return + } + + // CHECK-LABEL: gpu.func @promote_region_local_nested_register_state_to_vector_ssa + // CHECK-NOT: register + // CHECK: %[[IFRES:.*]] = scf.if %arg1 -> 
(i32) { + // CHECK: %[[LOCAL_INIT:.*]] = vector.insert %{{.*}}, %{{.*}} [0] : i32 into vector<1xi32> + // CHECK: %[[FOR:.*]] = scf.for {{.*}} iter_args(%[[STATE:.*]] = %[[LOCAL_INIT]]) -> (vector<1xi32>) { + // CHECK: %[[WHILE:.*]] = scf.while (%[[WHILE_STATE:.*]] = %[[STATE]]) : (vector<1xi32>) -> vector<1xi32> { + // CHECK: %[[CUR:.*]] = vector.extract %[[WHILE_STATE]][0] : i32 from vector<1xi32> + // CHECK: %[[COND:.*]] = arith.cmpi slt, %[[CUR]], %{{.*}} : i32 + // CHECK: scf.condition(%[[COND]]) %[[WHILE_STATE]] : vector<1xi32> + // CHECK: } do { + // CHECK: ^bb0(%[[BODY_STATE:.*]]: vector<1xi32>): + // CHECK: %[[CUR2:.*]] = vector.extract %[[BODY_STATE]][0] : i32 from vector<1xi32> + // CHECK: %[[NEXT_VAL:.*]] = arith.addi %[[CUR2]], %{{.*}} : i32 + // CHECK: %[[NEXT_STATE:.*]] = vector.insert %[[NEXT_VAL]], %[[BODY_STATE]] [0] : i32 into vector<1xi32> + // CHECK: scf.yield %[[NEXT_STATE]] : vector<1xi32> + // CHECK: } + // CHECK: scf.yield %[[WHILE]] : vector<1xi32> + // CHECK: } + // CHECK: %[[FINAL:.*]] = vector.extract %[[FOR]][0] : i32 from vector<1xi32> + // CHECK: fly.ptr.store(%[[FINAL]], %arg0) : (i32, !fly.ptr) -> () + // CHECK: scf.yield %{{.*}} : i32 + // CHECK: } else { + // CHECK: scf.yield %{{.*}} : i32 + // CHECK: } + gpu.func @promote_region_local_nested_register_state_to_vector_ssa(%out: !fly.ptr, %pred: i1) kernel { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + + %if_res = scf.if %pred -> (i32) { + %local = fly.make_ptr() {dictAttrs = {allocaSize = 1 : i64}} : () -> !fly.ptr + fly.ptr.store(%c0_i32, %local) : (i32, !fly.ptr) -> () + + scf.for %iv = %c0 to %c2 step %c1 { + scf.while : () -> () { + %cur = fly.ptr.load(%local) : (!fly.ptr) -> i32 + %cond = arith.cmpi slt, %cur, %c1_i32 : i32 + scf.condition(%cond) + } do { + %cur = fly.ptr.load(%local) : (!fly.ptr) -> i32 + %next = arith.addi %cur, %c1_i32 : i32 + fly.ptr.store(%next, %local) : (i32, !fly.ptr) -> () + scf.yield + } + scf.yield + } + + %final = fly.ptr.load(%local) : (!fly.ptr) -> i32 + fly.ptr.store(%final, %out) : (i32, !fly.ptr) -> () + scf.yield %c0_i32 : i32 + } else { + scf.yield %c1_i32 : i32 + } + %use_if_res = arith.addi %if_res, %c0_i32 : i32 + gpu.return + } + + // CHECK-LABEL: gpu.func @promote_void_if_with_region_local_register + // CHECK-NOT: register + // CHECK: scf.if %arg1 + // CHECK: %[[LOCAL_INIT:.*]] = ub.poison : vector<1xi32> + // CHECK: %[[AFTER_STORE:.*]] = vector.insert %{{.*}}, %[[LOCAL_INIT]] + // CHECK: %[[LOADED:.*]] = vector.extract %[[AFTER_STORE]] + // CHECK: fly.ptr.store(%[[LOADED]], %arg0) + gpu.func @promote_void_if_with_region_local_register(%out: !fly.ptr, %pred: i1) kernel { + %c42_i32 = arith.constant 42 : i32 + scf.if %pred { + %local = fly.make_ptr() {dictAttrs = {allocaSize = 1 : i64}} : () -> !fly.ptr + fly.ptr.store(%c42_i32, %local) : (i32, !fly.ptr) -> () + %val = fly.ptr.load(%local) : (!fly.ptr) -> i32 + fly.ptr.store(%val, %out) : (i32, !fly.ptr) -> () + } + gpu.return + } +} diff --git a/tests/mlir/Transforms/promote_regmem_to_ssa_copy_pred.mlir b/tests/mlir/Transforms/promote_regmem_to_ssa_copy_pred.mlir new file mode 100644 index 00000000..66d65d66 --- /dev/null +++ b/tests/mlir/Transforms/promote_regmem_to_ssa_copy_pred.mlir @@ -0,0 +1,141 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2025 FlyDSL Project Contributors +// RUN: %fly-opt %s --fly-convert-atom-call-to-ssa-form --fly-promote-regmem-to-vectorssa | 
FileCheck %s + + +// CHECK-LABEL: gpu.func @promote_copy_atoms_to_ssa +// CHECK-NOT: register +// CHECK: %[[POISON:.*]] = ub.poison : vector<4xf16> +// CHECK: %[[LOAD:.*]] = fly.copy_atom_call_ssa(%{{.*}}, %{{.*}}) {operandSegmentSizes = array} +// CHECK-SAME: : (!fly.copy_atom, 16>, !fly.memref) -> vector<4xf16> +// CHECK: %[[STATE:.*]] = vector.insert_strided_slice %[[LOAD]], %[[POISON]] {offsets = [0], strides = [1]} : vector<4xf16> into vector<4xf16> +// CHECK: %[[READ:.*]] = vector.extract_strided_slice %[[STATE]] {offsets = [0], sizes = [4], strides = [1]} : vector<4xf16> to vector<4xf16> +// CHECK: fly.copy_atom_call_ssa(%{{.*}}, %[[READ]], %{{.*}}) {operandSegmentSizes = array} +// CHECK-SAME: : (!fly.copy_atom, 16>, vector<4xf16>, !fly.memref) -> () + +gpu.module @promote_rmem_to_vector_ssa_copy { + gpu.func @promote_copy_atoms_to_ssa(%src: !fly.ptr, %dst: !fly.ptr) kernel { + %c0_i16 = arith.constant 0 : i16 + %c4_i32 = arith.constant 4 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %c4294967295_i64 = arith.constant 4294967295 : i64 + + %shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4> + %stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %vec4 = fly.make_layout(%shape4, %stride1) : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1> + + %src_desc = fly.make_ptr(%src, %c0_i16, %c4294967295_i64, %c1024_i32) : (!fly.ptr, i16, i64, i32) -> !fly.ptr + %src_view = fly.make_view(%src_desc, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + %dst_view = fly.make_view(%dst, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + + %copy_in = fly.make_copy_atom {valBits = 16 : i32} : !fly.copy_atom, 16> + %copy_in_soff = fly.atom.set_value(%copy_in, "soffset", %c4_i32) : (!fly.copy_atom, 16>, i32) -> !fly.copy_atom, 16> + %copy_out = fly.make_copy_atom {valBits = 16 : i32} : !fly.copy_atom, 16> + + %reg_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + %reg_view = fly.make_view(%reg_ptr, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + + fly.copy_atom_call(%copy_in_soff, %src_view, %reg_view) : (!fly.copy_atom, 16>, !fly.memref, !fly.memref) -> () + fly.copy_atom_call(%copy_out, %reg_view, %dst_view) : (!fly.copy_atom, 16>, !fly.memref, !fly.memref) -> () + gpu.return + } + +// CHECK-LABEL: gpu.func @promote_loop_local_copy_to_ssa +// CHECK-NOT: register +// CHECK: %[[LOOP_POISON:.*]] = ub.poison : vector<4xf16> +// CHECK: %{{.*}} = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITER_STATE:.*]] = %[[LOOP_POISON]]) -> (vector<4xf16>) { +// CHECK: %[[ITER_SSA:.*]] = fly.copy_atom_call_ssa(%{{.*}}, %{{.*}}) {operandSegmentSizes = array} +// CHECK-SAME: : (!fly.copy_atom, 16>, !fly.memref) -> vector<4xf16> +// CHECK: %[[NEW_STATE:.*]] = vector.insert_strided_slice %[[ITER_SSA]], %[[ITER_STATE]] {offsets = [0], strides = [1]} : vector<4xf16> into vector<4xf16> +// CHECK: %[[SLICE:.*]] = vector.extract_strided_slice %[[NEW_STATE]] {offsets = [0], sizes = [4], strides = [1]} : vector<4xf16> to vector<4xf16> +// CHECK: %[[ELEM:.*]] = vector.extract %[[SLICE]][%{{.*}}] : f16 from vector<4xf16> + gpu.func @promote_loop_local_copy_to_ssa(%src: !fly.ptr, %dst: !fly.ptr) kernel { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c0_i16 = arith.constant 0 : i16 + %c4_i32 = arith.constant 4 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %c4294967295_i64 = arith.constant 4294967295 : i64 + + %shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4> + %stride1 = 
fly.make_int_tuple() : () -> !fly.int_tuple<1> + %vec4 = fly.make_layout(%shape4, %stride1) : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1> + + %src_desc = fly.make_ptr(%src, %c0_i16, %c4294967295_i64, %c1024_i32) : (!fly.ptr, i16, i64, i32) -> !fly.ptr + %src_view = fly.make_view(%src_desc, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + + %copy_in = fly.make_copy_atom {valBits = 16 : i32} : !fly.copy_atom, 16> + %reg_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + %reg_view = fly.make_view(%reg_ptr, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + + scf.for %iv = %c0 to %c2 step %c1 { + %iv_i32 = arith.index_cast %iv : index to i32 + %soff = arith.muli %iv_i32, %c4_i32 : i32 + %copy_iter = fly.atom.set_value(%copy_in, "soffset", %soff) : (!fly.copy_atom, 16>, i32) -> !fly.copy_atom, 16> + fly.copy_atom_call(%copy_iter, %src_view, %reg_view) : (!fly.copy_atom, 16>, !fly.memref, !fly.memref) -> () + %vec = fly.ptr.load(%reg_ptr) : (!fly.ptr) -> vector<4xf16> + %elem = vector.extract %vec[%c0] : f16 from vector<4xf16> + fly.ptr.store(%elem, %dst) : (f16, !fly.ptr) -> () + } + gpu.return + } +} + + +// Verify that when a predicated copy_atom_call has a register-memref pred +// and a register-memref dst, after both passes: +// 1. pred is promoted to i1 SSA via vector<1xi1> state +// 2. old dst value is extracted from vector state before the SSA call +// 3. copy_atom_call_ssa receives old dst and pred as SSA operands +// 4. result is inserted back into vector state +// 5. all register types are eliminated + +// CHECK-LABEL: gpu.func @promote_pred_copy_to_ssa +// CHECK-NOT: register +// CHECK: %[[PRED_POISON:.*]] = ub.poison : vector<1xi1> +// CHECK: %[[PRED_STATE:.*]] = vector.insert %arg2, %[[PRED_POISON]] [0] : i1 into vector<1xi1> +// CHECK: %[[DST_POISON:.*]] = ub.poison : vector<4xf16> +// CHECK: %[[PRED_VAL:.*]] = vector.extract %[[PRED_STATE]][0] : i1 from vector<1xi1> +// CHECK: %[[OLD_DST:.*]] = vector.extract_strided_slice %[[DST_POISON]] {offsets = [0], sizes = [4], strides = [1]} : vector<4xf16> to vector<4xf16> +// CHECK: %[[SSA:.*]] = fly.copy_atom_call_ssa(%{{.*}}, %{{.*}}, %[[OLD_DST]], %[[PRED_VAL]]) {operandSegmentSizes = array} +// CHECK-SAME: : (!fly.copy_atom, 16>, !fly.memref, vector<4xf16>, i1) -> vector<4xf16> +// CHECK: %[[UPDATED:.*]] = vector.insert_strided_slice %[[SSA]], %[[DST_POISON]] {offsets = [0], strides = [1]} : vector<4xf16> into vector<4xf16> +// CHECK: %[[OUT_VEC:.*]] = vector.extract_strided_slice %[[UPDATED]] {offsets = [0], sizes = [4], strides = [1]} : vector<4xf16> to vector<4xf16> +// CHECK: fly.copy_atom_call_ssa(%{{.*}}, %[[OUT_VEC]], %{{.*}}) {operandSegmentSizes = array} +// CHECK-SAME: : (!fly.copy_atom, 16>, vector<4xf16>, !fly.memref) -> () +gpu.module @promote_rmem_to_vector_ssa_copy_pred { + gpu.func @promote_pred_copy_to_ssa(%src: !fly.ptr, %dst: !fly.ptr, %pred_val: i1) kernel { + %c0_i16 = arith.constant 0 : i16 + %c4_i32 = arith.constant 4 : i32 + %c1024_i32 = arith.constant 1024 : i32 + %c4294967295_i64 = arith.constant 4294967295 : i64 + + %shape4 = fly.make_int_tuple() : () -> !fly.int_tuple<4> + %stride1 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %vec4 = fly.make_layout(%shape4, %stride1) : (!fly.int_tuple<4>, !fly.int_tuple<1>) -> !fly.layout<4:1> + + %shape1 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %pred_layout = fly.make_layout(%shape1, %stride1) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + + %src_desc = fly.make_ptr(%src, %c0_i16, 
%c4294967295_i64, %c1024_i32) : (!fly.ptr, i16, i64, i32) -> !fly.ptr + %src_view = fly.make_view(%src_desc, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + %dst_view = fly.make_view(%dst, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + + %copy_in = fly.make_copy_atom {valBits = 16 : i32} : !fly.copy_atom, 16> + %copy_in_soff = fly.atom.set_value(%copy_in, "soffset", %c4_i32) : (!fly.copy_atom, 16>, i32) -> !fly.copy_atom, 16> + + %pred_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 1 : i64}} : () -> !fly.ptr + fly.ptr.store(%pred_val, %pred_ptr) : (i1, !fly.ptr) -> () + %pred_view = fly.make_view(%pred_ptr, %pred_layout) : (!fly.ptr, !fly.layout<1:1>) -> !fly.memref + + %reg_ptr = fly.make_ptr() {dictAttrs = {allocaSize = 4 : i64}} : () -> !fly.ptr + %reg_view = fly.make_view(%reg_ptr, %vec4) : (!fly.ptr, !fly.layout<4:1>) -> !fly.memref + + fly.copy_atom_call(%copy_in_soff, %src_view, %reg_view, %pred_view) : (!fly.copy_atom, 16>, !fly.memref, !fly.memref, !fly.memref) -> () + + %copy_out = fly.make_copy_atom {valBits = 16 : i32} : !fly.copy_atom, 16> + fly.copy_atom_call(%copy_out, %reg_view, %dst_view) : (!fly.copy_atom, 16>, !fly.memref, !fly.memref) -> () + gpu.return + } +} diff --git a/tests/mlir/Transforms/promote_regmem_to_ssa_invalid.mlir b/tests/mlir/Transforms/promote_regmem_to_ssa_invalid.mlir new file mode 100644 index 00000000..5c521f38 --- /dev/null +++ b/tests/mlir/Transforms/promote_regmem_to_ssa_invalid.mlir @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2025 FlyDSL Project Contributors +// RUN: %fly-opt %s --fly-promote-regmem-to-vectorssa | FileCheck %s + +// CHECK-LABEL: gpu.func @skip_register_recast_iter +// CHECK: fly.make_ptr() {dictAttrs = {allocaSize = 1 : i64}} : () -> !fly.ptr +// CHECK: fly.recast_iter +// CHECK: fly.ptr.load +// CHECK: fly.ptr.store +gpu.module @promote_regmem_to_ssa_recast_skip { + gpu.func @skip_register_recast_iter(%out: !fly.ptr) kernel { + %reg = fly.make_ptr() {dictAttrs = {allocaSize = 1 : i64}} : () -> !fly.ptr + %recast = "fly.recast_iter"(%reg) : (!fly.ptr) -> !fly.ptr + %val = fly.ptr.load(%recast) : (!fly.ptr) -> i32 + fly.ptr.store(%val, %out) : (i32, !fly.ptr) -> () + gpu.return + } + +// CHECK-LABEL: gpu.func @mixed_recast_and_normal +// CHECK: fly.make_ptr() {dictAttrs = {allocaSize = 1 : i64}} : () -> !fly.ptr +// CHECK: fly.recast_iter +// CHECK: fly.ptr.load(%{{.*}}) : (!fly.ptr) -> i32 +// CHECK-NOT: !fly.ptr +// CHECK: ub.poison : vector<1xi32> +// CHECK: %[[STATE:.*]] = vector.insert %{{.*}}, %{{.*}} [0] : i32 into vector<1xi32> +// CHECK: %[[FINAL:.*]] = vector.extract %[[STATE]][0] : i32 from vector<1xi32> + gpu.func @mixed_recast_and_normal(%out: !fly.ptr) kernel { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + + %recast_reg = fly.make_ptr() {dictAttrs = {allocaSize = 1 : i64}} : () -> !fly.ptr + %recast = "fly.recast_iter"(%recast_reg) : (!fly.ptr) -> !fly.ptr + %recast_val = fly.ptr.load(%recast) : (!fly.ptr) -> i32 + + %normal_reg = fly.make_ptr() {dictAttrs = {allocaSize = 1 : i64}} : () -> !fly.ptr + fly.ptr.store(%c1_i32, %normal_reg) : (i32, !fly.ptr) -> () + %normal_val = fly.ptr.load(%normal_reg) : (!fly.ptr) -> i32 + + %sum = arith.addi %recast_val, %normal_val : i32 + fly.ptr.store(%sum, %out) : (i32, !fly.ptr) -> () + gpu.return + } +} From 8d3456cfc0f316a07b7b1a09230e5e18899fd072 Mon Sep 17 00:00:00 2001 From: Felix Li Date: Sun, 19 Apr 2026 08:12:25 +0800 Subject: [PATCH 19/29] fix version err after uninstall (#413) --- 
.github/workflows/publish-pypi.yaml | 4 ++-- python/flydsl/__init__.py | 7 +------ python/flydsl/_version.py | 1 - python/flydsl/compiler/jit_function.py | 10 ++++------ setup.py | 18 ++++-------------- 5 files changed, 11 insertions(+), 29 deletions(-) delete mode 100644 python/flydsl/_version.py diff --git a/.github/workflows/publish-pypi.yaml b/.github/workflows/publish-pypi.yaml index 92d119c1..9588656d 100644 --- a/.github/workflows/publish-pypi.yaml +++ b/.github/workflows/publish-pypi.yaml @@ -25,10 +25,10 @@ jobs: id: version run: | TAG_VERSION="${GITHUB_REF_NAME#v}" - PACKAGE_VERSION="$(awk -F'"' '/^_BASE_VERSION = / {print $2; exit}' python/flydsl/__init__.py)" + PACKAGE_VERSION="$(awk -F'"' '/^__version__ = / {print $2; exit}' python/flydsl/__init__.py)" if [ -z "${PACKAGE_VERSION}" ]; then - echo "Failed to find _BASE_VERSION in python/flydsl/__init__.py" >&2 + echo "Failed to find __version__ in python/flydsl/__init__.py" >&2 exit 1 fi diff --git a/python/flydsl/__init__.py b/python/flydsl/__init__.py index 99582a11..872aa58f 100644 --- a/python/flydsl/__init__.py +++ b/python/flydsl/__init__.py @@ -1,16 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2025 FlyDSL Project Contributors -_BASE_VERSION = "0.1.3.1" +__version__ = "0.1.4" # FFM simulator compatibility shim (no-op outside simulator sessions). from ._compat import _maybe_preload_system_comgr # noqa: E402 _maybe_preload_system_comgr() -try: - from ._version import __version__ -except ImportError: - __version__ = _BASE_VERSION - from .autotune import autotune, Config diff --git a/python/flydsl/_version.py b/python/flydsl/_version.py deleted file mode 100644 index 9556dd64..00000000 --- a/python/flydsl/_version.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = "0.1.3.2" diff --git a/python/flydsl/compiler/jit_function.py b/python/flydsl/compiler/jit_function.py index f8be93aa..f7edd6ab 100644 --- a/python/flydsl/compiler/jit_function.py +++ b/python/flydsl/compiler/jit_function.py @@ -72,12 +72,10 @@ def _flydsl_key() -> str: except Exception: pass - # Also hash flydsl/__init__.py and _version.py. - for name in ("__init__.py", "_version.py"): - p = flydsl_root / name - if p.is_file(): - with open(p, "rb") as f: - contents.append(hashlib.sha256(f.read()).hexdigest()) + p = flydsl_root / "__init__.py" + if p.is_file(): + with open(p, "rb") as f: + contents.append(hashlib.sha256(f.read()).hexdigest()) # 2) Hash native shared libraries (C++ passes, runtime wrappers, bindings). backend = get_backend() diff --git a/setup.py b/setup.py index 7ff2a0ea..a332216e 100644 --- a/setup.py +++ b/setup.py @@ -124,12 +124,11 @@ def _read_version() -> str: release -> {base} (e.g. 
0.1.0) -> {base}.dev{commit_count} (legacy local dev builds) """ + import re + init_py = (PY_SRC / "flydsl" / "__init__.py").read_text(encoding="utf-8") - base_version = "0.0.0" - for line in init_py.splitlines(): - if line.startswith("_BASE_VERSION"): - base_version = line.split("=", 1)[1].strip().strip('"').strip("'") - break + m = re.search(r'^__version__\s*=\s*["\']([^"\']+)["\']', init_py, re.MULTILINE) + base_version = m.group(1) if m else "0.0.0" if "+" in base_version: return base_version @@ -156,14 +155,6 @@ def _read_version() -> str: return f"{base_version}.dev{commit_count}" -def _write_version_file(version: str) -> None: - """Generate _version.py so that runtime __version__ matches the build version.""" - version_file = PY_SRC / "flydsl" / "_version.py" - version_file.write_text( - f'__version__ = "{version}"\n', - encoding="utf-8", - ) - def _load_requirements() -> list[str]: req = REPO_ROOT / "requirements.txt" @@ -399,7 +390,6 @@ def _ensure_python_embedded_mlir_package() -> None: } _version = _read_version() -_write_version_file(_version) setup( name="flydsl", From 23f59ab2e765b7ce4aa34a7a590a94919fa435e7 Mon Sep 17 00:00:00 2001 From: andyluo7 <43718156+andyluo7@users.noreply.github.com> Date: Sun, 19 Apr 2026 19:43:44 -0700 Subject: [PATCH 20/29] Add fused epilogue support to preshuffle GEMM: bias + ReLU/SiLU/GeLU (#404) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add fused epilogue support to preshuffle GEMM: bias + ReLU/SiLU/GeLU Modified body_row to apply bias + activation in registers before the output store, eliminating separate epilogue kernel launches. MI300X results: 2.73x avg speedup vs hipBLAS+bias+SiLU - O-proj: 3.25x, MoE-dn: 3.33x, QKV: 2.92x - Zero epilogue overhead (fused ops hidden by store latency) New parameter: epilogue='none'|'bias'|'bias_relu'|'bias_silu'|'bias_gelu' New kernel arg: arg_bias (N-element bias tensor) * Fix test: pass dummy bias tensor for fused epilogue kernel signature The fused epilogue commit added arg_bias to kernel_gemm and launch_gemm unconditionally (needed for epilogue='none' to maintain a single kernel definition). The test's _gemm_args and _w4_args functions need to pass a dummy bias tensor to match the updated launch_gemm signature. Without this fix, args shift: M goes to arg_bias slot, N to i32_m, stream to i32_n, causing 'missing a required argument: stream' error. * Add CUDAGraph capture test for preshuffle GEMM Verifies that FlyDSL kernels are correctly captured by torch.cuda.CUDAGraph when torch.cuda.current_stream() is passed as the stream argument. Test flow: 1. Regular execution → reference result 2. CUDAGraph capture on a dedicated stream 3. Graph replay → verify result matches reference 4. Assert non-zero output (kernel was captured) Tests both BF16 and FP8 paths. * Fix CI test failures: torch.dtype kwarg + pertoken_quant kwarg name Two test-only fixes in tests/kernels/test_preshuffle_gemm.py: 1. test_mfma_a8_flyc_preshuffle (line 206): torch.empty was passed the string out_dtype ("bf16"/"fp16") instead of a torch.dtype. Use torch_out_dtype (defined on line 116) which is the actual torch.dtype, matching c_out_raw on line 190. Fixes 192 failing parametrizations across mi325-1/mi355-1 runners. 2. test_cudagraph_capture_preshuffle[fp8] (lines 500-501): wrong kwarg name for pertoken_quant -- uses 'dtype=' but the function signature in tests/utils.py expects 'quant_dtype='. Switch to the correct kwarg. No production code touched. 
CI for both mi325 and mi355 should return to green. * Address review: epilogue+cshuffle guard, bias max_size, SiLU rcp, fused tests Four fixes for review comments on PR #404: 1. Reject epilogue!='none' + use_cshuffle_epilog=True (correctness bug). The cshuffle path returns from write_row_to_lds before body_row, so the bias/activation fusion would silently be dropped. Now raises ValueError at compile time. 2. Drop hardcoded 'c_n * 2' for the bias buffer size; use max_size=True like the other resources. The previous code assumed 2-byte output and would break on any future fp32 path. 3. Rewrite SiLU as 'val * (1/denom)' instead of 'val / denom' so the compiler lowers to v_rcp_f32 + v_mul_f32 (~4x faster than v_div_* on AMD GPUs). 4. Add fused-epilogue correctness tests: - test_fused_epilogue_correctness parametrized over bias / bias_relu / bias_silu / bias_gelu, comparing against a torch reference. - test_fused_epilogue_rejects_cshuffle covering the new guard from #1. Previously every test ran with epilogue='none' + dummy bias, so the actual fusion path had no coverage. * Fix latent bias_relu/bias_gelu codegen bugs caught by new epilogue tests When running the new test_fused_epilogue_correctness tests on MI300X (gfx942) two pre-existing latent bugs in body_row's epilogue path showed up. Neither was reachable before because no test ever exercised epilogue != 'none'. 1. bias_relu: arith.cmpf was called with the string predicate 'ogt', which the underlying MLIR binding rejects (it expects an integer CmpFPredicate enum value). Replaced cmpf+select with arith.maximumf, which is both correct and more concise. 2. bias_gelu: math.tanh has no AMD libcall ('no libcall available for ftanh'), so any kernel using the tanh-approx GeLU failed to lower. Rewrote tanh in terms of math.exp using a numerically stable form that only ever evaluates exp(-2|y|) (in [0, 1]), so we never overflow fp32 even for large activations. The (1 + tanh(y)) factor used by GeLU is then formed branchlessly via cmpf+select on the sign of y. Verified on MI300X: 4/4 fused epilogue correctness tests pass, the guard test passes, and all 103 pre-existing tests still pass. * Tighten epilogue tests + remove dead code in GeLU rewrite - Remove unused is_pos/is_neg lines and stale comments in the GeLU branchless tanh expansion (the cmpf+select on the sign of y is the only thing actually used). - Tighten test_fused_epilogue_correctness: explicit NaN/Inf assertions before the value comparison, and document the bf16 tolerance choice (atol=2.0, rtol=0.05) based on K=8192 reduction error. Verified: 103 passed, 20 skipped on gfx942. 
--------- Co-authored-by: Andy Luo Co-authored-by: Felix Li --- kernels/preshuffle_gemm.py | 110 ++++++++++++- tests/kernels/test_preshuffle_gemm.py | 223 ++++++++++++++++++++++++++ 2 files changed, 331 insertions(+), 2 deletions(-) diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index 1e6d38ed..5c88b9c3 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -16,6 +16,7 @@ from flydsl.expr import arith, vector from flydsl.expr import gpu from flydsl.expr import buffer_ops, rocdl +from flydsl.expr import math from flydsl.expr.typing import T @@ -142,7 +143,8 @@ def compile_preshuffle_gemm_a8( waves_per_eu: Optional[int] = None, use_async_copy: bool = False, dsrd_preload: int = -1, - dvmem_preload: int = -1 + dvmem_preload: int = -1, + epilogue: str = "none", # "none", "bias", "bias_relu", "bias_silu", "bias_gelu" ): """Compile the preshuffle GEMM kernel using the @flyc.kernel API. @@ -306,6 +308,23 @@ def _out_elem(): allocator_pong.ptr = lds_alloc_offset + lds_total_elems * elem_bytes # ── Kernel function ──────────────────────────────────────────────────── + _has_epilogue = epilogue != "none" + _has_bias = epilogue in ("bias", "bias_relu", "bias_silu", "bias_gelu") + _has_relu = epilogue == "bias_relu" + _has_silu = epilogue == "bias_silu" + _has_gelu = epilogue == "bias_gelu" + + # Fused epilogue is implemented inside body_row (the direct store path). + # When use_cshuffle_epilog=True, the epilogue path goes through + # write_row_to_lds -> store_pair and returns before body_row, which would + # silently drop the bias + activation. Reject the unsupported combination. + if _has_epilogue and use_cshuffle_epilog: + raise ValueError( + "Fused epilogue (epilogue != 'none') is not supported with " + "use_cshuffle_epilog=True; the cshuffle path bypasses body_row " + "where the bias/activation fusion lives." + ) + @flyc.kernel def kernel_gemm( arg_c: fx.Tensor, @@ -313,6 +332,7 @@ def kernel_gemm( arg_b: fx.Tensor, arg_scale_a: fx.Tensor, arg_scale_b: fx.Tensor, + arg_bias: fx.Tensor, i32_m: fx.Int32, i32_n: fx.Int32, ): @@ -395,6 +415,14 @@ def kernel_gemm( _needs_per_token_scale = not is_f16_or_bf16 and not is_fp4 scale_a_rsrc = None if (is_f16_or_bf16) else buffer_ops.create_buffer_resource( arg_scale_a, max_size=False) + + # ---- Bias buffer resource (for fused epilogue) ---- + # Use max_size=True so the buffer descriptor's size is taken from the + # actual arg_bias tensor; this avoids hardcoding the output element + # size (was c_n * 2, which broke if out_dtype became fp32 etc.). + bias_rsrc = None + if _has_bias: + bias_rsrc = buffer_ops.create_buffer_resource(arg_bias, max_size=True) b_rsrc = buffer_ops.create_buffer_resource(arg_b, max_size=True) scale_b_rsrc = None if (is_f16_or_bf16) else buffer_ops.create_buffer_resource( arg_scale_b, max_size=True) @@ -985,6 +1013,83 @@ def body_row(*, mi, ii, row_in_tile, row): val_s = (val * s_a) * s_b_vals[ni] else: val_s = val + + # ── Fused epilogue: bias + activation ── + if _has_bias and bias_rsrc is not None: + col_idx = col_base + (ni * 16) + bias_val_f16 = buffer_ops.buffer_load( + bias_rsrc, col_idx, vec_width=1, + dtype=_out_elem()) + bias_val_f32 = arith.extf(T.f32, bias_val_f16) + val_s = val_s + bias_val_f32 + + if _has_relu: + # ReLU(x) = max(x, 0). 
Use maximumf rather than + # cmpf+select: the lower-level cmpf wrapper requires + # an integer CmpFPredicate enum value, not the string + # "ogt", so the previous form failed at compile time + # the moment the bias_relu epilogue was actually + # exercised (test coverage gap). + zero_f32 = arith.constant(0.0, type=T.f32) + val_s = arith.maximumf(val_s, zero_f32) + elif _has_silu: + # SiLU(x) = x * sigmoid(x). Compute as + # sigmoid_x = 1 / (1 + exp(-x)) # one rcp instead of fdiv + # val_s = val_s * sigmoid_x + # to lower to v_rcp_f32 + v_mul_f32 instead of v_div_* + # (~4x faster than fdiv on AMD GPUs). + neg_one = arith.constant(-1.0, type=T.f32) + neg_val = val_s * neg_one + exp_neg = math.exp(neg_val) + one_f32 = arith.constant(1.0, type=T.f32) + denom = one_f32 + exp_neg + sigmoid_x = arith.divf(one_f32, denom) + val_s = val_s * sigmoid_x + elif _has_gelu: + # GeLU approx: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + # math.tanh has no AMD libcall, so expand it via exp. + # Numerically stable form using only non-positive + # exponent (avoids fp32 overflow for large |x|): + # a = -2 * |y| (a <= 0, exp(a) in [0,1]) + # tanh(y) = sign(y) * (1 - exp(a)) / (1 + exp(a)) + # 1 + tanh(y) = 1 + sign(y) * (1 - exp(a))/(1+exp(a)) + # We compute (1 + tanh(y)) directly from y because we + # need the GeLU output, which is half * x * (1 + tanh). + half_f32 = arith.constant(0.5, type=T.f32) + coeff_f32 = arith.constant(0.044715, type=T.f32) + sqrt2pi_f32 = arith.constant(0.7978845608, type=T.f32) + neg_two_f32 = arith.constant(-2.0, type=T.f32) + one_f32 = arith.constant(1.0, type=T.f32) + zero_f32 = arith.constant(0.0, type=T.f32) + x3 = val_s * val_s * val_s + y = sqrt2pi_f32 * (val_s + coeff_f32 * x3) + # |y| via max(y, -y) — avoids math.absf dependency + neg_y = zero_f32 - y + abs_y = arith.maximumf(y, neg_y) + # exp(-2|y|) is in [0, 1], no overflow. 
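+ # Worked check (illustrative numbers): for y = 1, exp(-2|y|) = exp(-2)
+ # ~= 0.1353, so denom ~= 1.1353 and 1 + tanh(y) = 2 / denom ~= 1.7616,
+ # which matches 1 + tanh(1) ~= 1 + 0.7616. For y = -1 the numerator is
+ # 2 * 0.1353, giving ~= 0.2384 = 1 - tanh(1).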
+ e_neg2abs = math.exp(neg_two_f32 * abs_y) + denom = one_f32 + e_neg2abs + # tanh(|y|) = (1 - e_neg2abs) / denom + # tanh(y) = sign(y) * tanh(|y|) + # 1 + tanh(y): + # y >= 0: 1 + tanh(|y|) = (denom + (1 - e)) / denom + # = (2) / denom + # (because denom = 1 + e and + # denom + 1 - e = 2) + # y < 0: 1 - tanh(|y|) = (denom - (1 - e)) / denom + # = (2 * e) / denom + two_f32 = arith.constant(2.0, type=T.f32) + # numerator = 2 when y >= 0 + # = 2 * e_neg2abs when y < 0 + # OGT predicate id = 2 (CmpFPredicate.OGT) + sign_pred = arith.cmpf(2, y, zero_f32) + num_pos = two_f32 + num_neg = two_f32 * e_neg2abs + numerator = arith.select(sign_pred, num_pos, num_neg) + recip = arith.divf(one_f32, denom) + one_plus_tanh = numerator * recip + val_s = half_f32 * val_s * one_plus_tanh + val_f16 = arith.trunc_f(_out_elem(), val_s) idx_out = idx_base + (ni * 16) buffer_ops.buffer_store(val_f16, c_rsrc, idx_out) @@ -1384,6 +1489,7 @@ def launch_gemm( arg_b: fx.Tensor, arg_scale_a: fx.Tensor, arg_scale_b: fx.Tensor, + arg_bias: fx.Tensor, i32_m: fx.Int32, i32_n: fx.Int32, stream: fx.Stream, @@ -1398,7 +1504,7 @@ def launch_gemm( gx = (i32_m + (tile_m - 1)) // tile_m gy = i32_n // tile_n - launcher = kernel_gemm(arg_c, arg_a, arg_b, arg_scale_a, arg_scale_b, i32_m, i32_n) + launcher = kernel_gemm(arg_c, arg_a, arg_b, arg_scale_a, arg_scale_b, arg_bias, i32_m, i32_n) if waves_per_eu is not None: _wpe = int(waves_per_eu) if _wpe >= 1: diff --git a/tests/kernels/test_preshuffle_gemm.py b/tests/kernels/test_preshuffle_gemm.py index 1e444739..682a5328 100644 --- a/tests/kernels/test_preshuffle_gemm.py +++ b/tests/kernels/test_preshuffle_gemm.py @@ -202,12 +202,16 @@ def _pack_shuffled_int8_to_packed_int4_no_perm(x_shuf_i8): def _as_i8(t): return t.view(torch.int8) if "float8" in str(t.dtype) else t + # Create a dummy bias tensor (unused when epilogue="none") + _dummy_bias = torch.empty(0, dtype=torch_out_dtype, device=a_q.device) + def _gemm_args(c, a, b, sa, sb): return (c.contiguous().view(-1), _as_i8(a.contiguous().view(-1)), _as_i8(b.contiguous().view(-1)), sa.contiguous().view(-1) if sa.numel() > 0 else sa, sb.contiguous().view(-1) if sb.numel() > 0 else sb, + _dummy_bias, M, N, torch.cuda.current_stream()) compiled_fn = flyc.compile(launch_fn, *_gemm_args(c_out_raw, a_q, b_input, sa_flat, sb_flat)) @@ -362,12 +366,16 @@ def _to_bytes(t): return t return t.view(torch.uint8) + # Create a dummy bias tensor (unused when epilogue="none") + _dummy_bias_w4 = torch.empty(0, dtype=torch.bfloat16, device=a_q.device) + def _w4_args(c, a, b, sa, sb): return (c.contiguous().view(-1), _to_bytes(a).contiguous().view(-1), _to_bytes(b).contiguous().view(-1), _to_bytes(sa).contiguous().view(-1), _to_bytes(sb).contiguous().view(-1), + _dummy_bias_w4, M, N, torch.cuda.current_stream()) compiled_fn = flyc.compile(launch_fn, *_w4_args(c_out, a_q, b_shuffled, scale_a, scale_b_shuffled)) @@ -464,3 +472,218 @@ def launch_kernel(c, a, b, sa, sb): ) except pytest.skip.Exception as e: print(f"Skipped: {e}") + + +# ── CUDAGraph Capture Test ──────────────────────────────────────────────── + +@pytest.mark.parametrize("in_dtype", ["bf16", "fp8"]) +def test_cudagraph_capture_preshuffle(in_dtype): + """Verify FlyDSL preshuffle GEMM kernels are captured by CUDAGraph. + + This test ensures that passing torch.cuda.current_stream() correctly + routes the kernel launch to the capture stream during graph recording. + Without proper stream handling, CUDAGraph replay produces all-zeros. 
+ """ + device = "cuda:0" + M, N, K = 1, 8192, 8192 + tile_m, tile_n, tile_k = 16, 64, 256 + + arch = str(get_rocm_arch()) + if not arch.startswith("gfx94") and not arch.startswith("gfx95"): + pytest.skip(f"Unsupported arch: {arch}") + + # Prepare data + a_raw = torch.randn(M, K, dtype=torch.bfloat16, device=device) + b_raw = torch.randn(N, K, dtype=torch.bfloat16, device=device) + + if in_dtype == "fp8": + a_q, scale_a = pertoken_quant(a_raw, quant_dtype=torch.float8_e4m3fnuz) + b_q, scale_b = pertoken_quant(b_raw, quant_dtype=torch.float8_e4m3fnuz) + a_q = a_q.view(torch.int8) + b_input = shuffle_weight(b_q.view(torch.int8), layout=(16, 16)).contiguous().view(-1) + sa_flat = scale_a.contiguous().view(-1) + sb_flat = scale_b.contiguous().view(-1) + else: + a_q = a_raw + b_input = shuffle_weight(b_raw.contiguous(), layout=(16, 16)).contiguous().view(-1) + sa_flat = torch.empty(0, dtype=torch.float32, device=device) + sb_flat = torch.empty(0, dtype=torch.float32, device=device) + + c_out = torch.empty(M, N, dtype=torch.bfloat16, device=device) + _dummy_bias = torch.empty(0, dtype=torch.bfloat16, device=device) + + # Compile kernel + launch_fn = compile_preshuffle_gemm_a8( + M=M, N=N, K=K, + tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, + in_dtype=in_dtype, + epilogue="none", + ) + + def _args(c, a, b, sa, sb): + return (c.contiguous().view(-1), + a.contiguous().view(-1) if "int" not in str(a.dtype) else a.contiguous().view(-1), + b, + sa.contiguous().view(-1) if sa.numel() > 0 else sa, + sb.contiguous().view(-1) if sb.numel() > 0 else sb, + _dummy_bias, + M, N, torch.cuda.current_stream()) + + compiled_fn = flyc.compile(launch_fn, *_args(c_out, a_q, b_input, sa_flat, sb_flat)) + + # Warmup + compiled_fn(*_args(c_out, a_q, b_input, sa_flat, sb_flat)) + torch.cuda.synchronize() + + # ── Regular execution (reference) ── + c_out.zero_() + compiled_fn(*_args(c_out, a_q, b_input, sa_flat, sb_flat)) + torch.cuda.synchronize() + ref = c_out.clone() + assert ref.abs().max().item() > 0, "Regular execution produced all zeros" + + # ── CUDAGraph capture ── + g = torch.cuda.CUDAGraph() + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + + # Warmup on capture stream + with torch.cuda.stream(s): + compiled_fn(*_args(c_out, a_q, b_input, sa_flat, sb_flat)) + torch.cuda.current_stream().wait_stream(s) + torch.cuda.synchronize() + + # Record + c_out.zero_() + with torch.cuda.graph(g, stream=s): + compiled_fn(*_args(c_out, a_q, b_input, sa_flat, sb_flat)) + torch.cuda.synchronize() + + # ── Replay ── + c_out.zero_() + g.replay() + torch.cuda.synchronize() + graph_result = c_out.clone() + + # ── Verify ── + max_diff = (ref - graph_result).abs().max().item() + assert graph_result.abs().max().item() > 0, ( + f"CUDAGraph replay produced all zeros — kernel was NOT captured! " + f"ref max={ref.abs().max().item():.4f}" + ) + assert torch.allclose(ref, graph_result, atol=1e-2), ( + f"CUDAGraph result mismatch: max_diff={max_diff:.6f}, " + f"ref max={ref.abs().max().item():.4f}, graph max={graph_result.abs().max().item():.4f}" + ) + print(f"✓ CUDAGraph capture verified ({in_dtype}): max_diff={max_diff:.6f}") + + +# ── Fused epilogue correctness test ───────────────────────────────────────── + +@pytest.mark.parametrize("epilogue", ["bias", "bias_relu", "bias_silu", "bias_gelu"]) +def test_fused_epilogue_correctness(epilogue): + """Verify fused epilogue (bias + activation) matches a torch reference. 
+ + The previous test suite only exercised epilogue='none' with a dummy bias + tensor, so a regression in body_row's fused bias/activation path would + not have been caught. This test runs each of the four epilogue modes + end-to-end and compares against a torch reference. + """ + import torch.nn.functional as F + + arch = str(get_rocm_arch()) + if not arch.startswith("gfx94") and not arch.startswith("gfx95"): + pytest.skip(f"Unsupported arch: {arch}") + + device = "cuda:0" + M, N, K = 16, 5120, 8192 + tile_m, tile_n, tile_k = 16, 64, 512 + in_dtype = "bf16" + out_dtype = "bf16" + torch_out_dtype = torch.bfloat16 + + torch.manual_seed(0) + a_raw = torch.randn(M, K, dtype=torch_out_dtype, device=device) + b_raw = torch.randn(N, K, dtype=torch_out_dtype, device=device) + bias = torch.randn(N, dtype=torch_out_dtype, device=device) + + # Torch reference: GEMM + bias + activation + a_f32 = a_raw.to(torch.float32) + b_f32 = b_raw.to(torch.float32) + ref_f32 = a_f32 @ b_f32.T + bias.to(torch.float32) + if epilogue == "bias_relu": + ref_f32 = F.relu(ref_f32) + elif epilogue == "bias_silu": + ref_f32 = F.silu(ref_f32) + elif epilogue == "bias_gelu": + ref_f32 = F.gelu(ref_f32, approximate="tanh") + ref = ref_f32.to(torch_out_dtype) + + # FlyDSL kernel + b_input = shuffle_weight(b_raw.contiguous(), layout=(16, 16)).contiguous().view(-1) + sa_flat = torch.empty(0, dtype=torch.float32, device=device) + sb_flat = torch.empty(0, dtype=torch.float32, device=device) + c_out = torch.zeros(M, N, dtype=torch_out_dtype, device=device) + + launch_fn = compile_preshuffle_gemm_a8( + M=M, N=N, K=K, + tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, + in_dtype=in_dtype, + out_dtype=out_dtype, + epilogue=epilogue, + ) + + def _args(c, a, b, sa, sb, bs): + return ( + c.contiguous().view(-1), + a.contiguous().view(-1), + b, + sa.contiguous().view(-1) if sa.numel() > 0 else sa, + sb.contiguous().view(-1) if sb.numel() > 0 else sb, + bs, + M, N, torch.cuda.current_stream(), + ) + + compiled_fn = flyc.compile(launch_fn, *_args(c_out, a_raw, b_input, sa_flat, sb_flat, bias)) + compiled_fn(*_args(c_out, a_raw, b_input, sa_flat, sb_flat, bias)) + torch.cuda.synchronize() + + # bf16 has ~7 bits mantissa; for K=8192 reduction the per-element + # error is bounded by ~K * eps_bf16 ~ 8192 * 2^-7 ~= 64 ULP. We use + # rtol=0.05 (5%) and atol=2.0 (covers small-magnitude outputs). 
+ assert not torch.isnan(c_out).any(), ( + f"Epilogue {epilogue}: kernel produced NaN(s) " + f"(count={int(torch.isnan(c_out).sum().item())})" + ) + assert not torch.isinf(c_out).any(), ( + f"Epilogue {epilogue}: kernel produced Inf(s)" + ) + atol = 2.0 + rtol = 0.05 + diff = (c_out.to(torch.float32) - ref.to(torch.float32)).abs() + max_diff = diff.max().item() + rel = (diff / (ref.to(torch.float32).abs() + 1e-3)).max().item() + assert torch.allclose(c_out, ref, atol=atol, rtol=rtol), ( + f"Epilogue {epilogue} mismatch: max_abs_diff={max_diff:.4f} max_rel={rel:.4f}, " + f"ref max={ref.abs().max().item():.4f}, out max={c_out.abs().max().item():.4f}" + ) + print(f"✓ Fused epilogue {epilogue} correctness verified: " + f"max_abs_diff={max_diff:.4f}, max_rel={rel:.4f}, ref_max={ref.abs().max().item():.2f}") + + +def test_fused_epilogue_rejects_cshuffle(): + """Compile-time guard: epilogue != 'none' with use_cshuffle_epilog=True + must raise rather than silently produce wrong output.""" + arch = str(get_rocm_arch()) + if not arch.startswith("gfx94") and not arch.startswith("gfx95"): + pytest.skip(f"Unsupported arch: {arch}") + + with pytest.raises(ValueError, match="cshuffle"): + compile_preshuffle_gemm_a8( + M=16, N=64, K=512, + tile_m=16, tile_n=64, tile_k=512, + in_dtype="bf16", + out_dtype="bf16", + epilogue="bias_silu", + use_cshuffle_epilog=True, + ) From bdcfc1eafeb54afa9bfbdf7bf3886becbbf02847 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Mon, 20 Apr 2026 21:51:32 +0800 Subject: [PATCH 21/29] [Fix] eliminate llvm unsupported type (Float8/4) before llvm conversion (#415) --- .../flydsl/Dialect/Fly/Utils/PointerUtils.h | 10 +++- lib/Conversion/FlyToROCDL/FlyToROCDL.cpp | 16 +++++- lib/Dialect/Fly/IR/FlyUniversalOps.cpp | 2 +- .../Transforms/ConvertAtomCallToSSAForm.cpp | 14 +++--- .../Transforms/PromoteRegMemToVectorSSA.cpp | 49 +++++++++++++++++-- lib/Dialect/Fly/Utils/PointerUtils.cpp | 13 ++++- lib/Dialect/FlyROCDL/CDNA3/CopyAtom.cpp | 9 ++-- lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp | 11 +++-- lib/Dialect/FlyROCDL/CDNA4/CopyAtom.cpp | 2 +- lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp | 7 +++ tests/mlir/Conversion/mma_atom.mlir | 14 ++++++ .../Transforms/promote_regmem_to_ssa.mlir | 10 ++-- 12 files changed, 127 insertions(+), 30 deletions(-) diff --git a/include/flydsl/Dialect/Fly/Utils/PointerUtils.h b/include/flydsl/Dialect/Fly/Utils/PointerUtils.h index 44b62526..b2bd6265 100644 --- a/include/flydsl/Dialect/Fly/Utils/PointerUtils.h +++ b/include/flydsl/Dialect/Fly/Utils/PointerUtils.h @@ -16,7 +16,15 @@ TypedValue applySwizzleOnPtr(OpBuilder &b, Location loc, TypedValue ptr, SwizzleAttr swizzle); -Type RegMem2SSAType(fly::MemRefType memRefTy); +/// Project llvm unsupported small-float element types (Float8/Float6/Float4) to integer types of +/// the same bit-width. Non-small-float types are returned unchanged. +Type projectToLLVMCompatibleElemTy(Type elemTy); + +/// Compute the SSA-value type corresponding to a register memref. +/// +/// \p llvmCompatibleType controls whether llvm unsupported small-float element types (e.g. +/// f8E4M3FNUZ/f8E5M2FNUZ) are projected to their same-width integer counterpart. 
+Type RegMem2SSAType(fly::MemRefType memRefTy, bool llvmCompatibleType = false); } // namespace mlir::fly diff --git a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp index 6582a58d..e5ad20a4 100644 --- a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp +++ b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp @@ -75,7 +75,8 @@ class MakePtrOpLowering : public OpConversionPattern { unsigned llvmAS = mapToLLVMAddressSpace(AddressSpace::Register); auto llvmPtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), llvmAS); Value nElems = arith::ConstantIntOp::create(rewriter, loc, allocSize.getInt(), 64); - Value ptr = LLVM::AllocaOp::create(rewriter, loc, llvmPtrTy, flyPtrTy.getElemTy(), nElems, 0); + Type elemTy = projectToLLVMCompatibleElemTy(flyPtrTy.getElemTy()); + Value ptr = LLVM::AllocaOp::create(rewriter, loc, llvmPtrTy, elemTy, nElems, 0); rewriter.replaceOp(op, ptr); return success(); } else if (addrSpace == AddressSpace::BufferDesc) { @@ -261,7 +262,7 @@ class AddOffsetOpLowering : public OpConversionPattern { if (!ptrTy) return failure(); - Type elemTy = flyPtrTy.getElemTy(); + Type elemTy = projectToLLVMCompatibleElemTy(flyPtrTy.getElemTy()); Value gep = LLVM::GEPOp::create(rewriter, loc, ptrTy, elemTy, base, ValueRange{offsetVal}); rewriter.replaceOp(op, gep); return success(); @@ -690,6 +691,17 @@ class FlyTypeConverter : public TypeConverter { FlyTypeConverter() { addConversion([](Type type) { return type; }); + addConversion([&](FloatType floatTy) -> std::optional { + if (floatTy.getWidth() < 16) + return IntegerType::get(floatTy.getContext(), floatTy.getWidth()); + return std::nullopt; + }); + addConversion([&](VectorType vecTy) -> std::optional { + Type convertedElem = convertType(vecTy.getElementType()); + if (!convertedElem || convertedElem == vecTy.getElementType()) + return std::nullopt; + return VectorType::get(vecTy.getShape(), convertedElem, vecTy.getScalableDims()); + }); addConversion([&](fly::MemRefType flyMemRefTy) -> Type { if (flyMemRefTy.getAddressSpace().getValue() == AddressSpace::BufferDesc) return BufferFatPtr::getType(flyMemRefTy.getContext()); diff --git a/lib/Dialect/Fly/IR/FlyUniversalOps.cpp b/lib/Dialect/Fly/IR/FlyUniversalOps.cpp index f0b9eaf2..8864031a 100644 --- a/lib/Dialect/Fly/IR/FlyUniversalOps.cpp +++ b/lib/Dialect/Fly/IR/FlyUniversalOps.cpp @@ -285,7 +285,7 @@ LogicalResult CopyOpUniversalAtomicType::emitAtomCall(OpBuilder &builder, Locati Type dstMemTyArg, Value atomVal, Value src, Value dst) const { auto srcMemTy = cast(srcMemTyArg); - auto srcSSATy = fly::RegMem2SSAType(srcMemTy); + auto srcSSATy = fly::RegMem2SSAType(srcMemTy, /*llvmCompatibleType=*/true); Value srcVal = LLVM::LoadOp::create(builder, loc, srcSSATy, src); auto res = emitAtomCallSSA(builder, loc, Type{}, copyAtomTyArg, srcSSATy, dstMemTyArg, atomVal, srcVal, dst); diff --git a/lib/Dialect/Fly/Transforms/ConvertAtomCallToSSAForm.cpp b/lib/Dialect/Fly/Transforms/ConvertAtomCallToSSAForm.cpp index 47cb5e63..3e80d161 100644 --- a/lib/Dialect/Fly/Transforms/ConvertAtomCallToSSAForm.cpp +++ b/lib/Dialect/Fly/Transforms/ConvertAtomCallToSSAForm.cpp @@ -79,7 +79,7 @@ class FlyConvertAtomCallToSSAFormPass Value srcVal = copyOp.getSrc(); if (srcEligible) { Value srcIter = srcVal.getDefiningOp().getIter(); - srcVal = PtrLoadOp::create(builder, loc, RegMem2SSAType(srcTy), srcIter); + srcVal = PtrLoadOp::create(builder, loc, RegMem2SSAType(srcTy, true), srcIter); } Value pred = copyOp.getPred(); @@ -87,12 +87,12 @@ class FlyConvertAtomCallToSSAFormPass auto predMemTy 
= cast(pred.getType()); if (isEligibleToPromote(predMemTy)) { Value predIter = pred.getDefiningOp().getIter(); - pred = PtrLoadOp::create(builder, loc, RegMem2SSAType(predMemTy), predIter); + pred = PtrLoadOp::create(builder, loc, RegMem2SSAType(predMemTy, true), predIter); } } if (dstEligible) { - auto ssaTy = RegMem2SSAType(dstTy); + auto ssaTy = RegMem2SSAType(dstTy, true); Value dstIter = copyOp.getDst().getDefiningOp().getIter(); Value oldDst = pred ? PtrLoadOp::create(builder, loc, ssaTy, dstIter) : Value{}; auto ssaOp = @@ -125,19 +125,19 @@ class FlyConvertAtomCallToSSAFormPass if (aEligible) { Value aIter = aVal.getDefiningOp().getIter(); - aVal = PtrLoadOp::create(builder, loc, RegMem2SSAType(aTy), aIter).getResult(); + aVal = PtrLoadOp::create(builder, loc, RegMem2SSAType(aTy, true), aIter).getResult(); } if (bEligible) { Value bIter = bVal.getDefiningOp().getIter(); - bVal = PtrLoadOp::create(builder, loc, RegMem2SSAType(bTy), bIter).getResult(); + bVal = PtrLoadOp::create(builder, loc, RegMem2SSAType(bTy, true), bIter).getResult(); } if (cEligible) { Value cIter = cVal.getDefiningOp().getIter(); - cVal = PtrLoadOp::create(builder, loc, RegMem2SSAType(cTy), cIter).getResult(); + cVal = PtrLoadOp::create(builder, loc, RegMem2SSAType(cTy, true), cIter).getResult(); } if (dEligible) { - auto ssaOp = MmaAtomCallSSA::create(builder, loc, TypeRange{RegMem2SSAType(dTy)}, + auto ssaOp = MmaAtomCallSSA::create(builder, loc, TypeRange{RegMem2SSAType(dTy, true)}, mmaOp.getMmaAtom(), /*d=*/nullptr, aVal, bVal, cVal); Value dIter = mmaOp.getD().getDefiningOp().getIter(); PtrStoreOp::create(builder, loc, ssaOp.getResult(0), dIter); diff --git a/lib/Dialect/Fly/Transforms/PromoteRegMemToVectorSSA.cpp b/lib/Dialect/Fly/Transforms/PromoteRegMemToVectorSSA.cpp index 9b636fe9..f4f676f3 100644 --- a/lib/Dialect/Fly/Transforms/PromoteRegMemToVectorSSA.cpp +++ b/lib/Dialect/Fly/Transforms/PromoteRegMemToVectorSSA.cpp @@ -11,15 +11,13 @@ #include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" #include "mlir/Transforms/CSE.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SmallVector.h" #include "flydsl/Dialect/Fly/IR/FlyDialect.h" #include "flydsl/Dialect/Fly/Transforms/Passes.h" -#include "flydsl/Dialect/Fly/Utils/LayoutUtils.h" +#include "flydsl/Dialect/Fly/Utils/PointerUtils.h" #include #include @@ -141,10 +139,12 @@ class FlyPromoteRegMemToVectorSSAPass return; PointerType ptrTy = cast(makePtrOp.getType()); + Type originElemTy = ptrTy.getElemTy(); + Type ssaElemTy = projectToLLVMCompatibleElemTy(ptrTy.getElemTy()); regAllocaInfos.try_emplace( makePtrOp, - RegAllocaInfo{makePtrOp, static_cast(allocaSizeAttr.getInt()), ptrTy.getElemTy(), - VectorType::get({allocaSizeAttr.getInt()}, ptrTy.getElemTy())}); + RegAllocaInfo{makePtrOp, static_cast(allocaSizeAttr.getInt()), originElemTy, + VectorType::get({allocaSizeAttr.getInt()}, ssaElemTy)}); allocaOrder.push_back(makePtrOp); }); @@ -293,6 +293,40 @@ class FlyPromoteRegMemToVectorSSAPass } } + Value bitcastScalarViaVector(OpBuilder &builder, Location loc, Value value, Type targetTy) { + auto srcVecTy = VectorType::get({1}, value.getType()); + auto dstVecTy = VectorType::get({1}, targetTy); + Value srcVec = vector::FromElementsOp::create(builder, loc, srcVecTy, value); + Value dstVec = vector::BitCastOp::create(builder, loc, dstVecTy, srcVec); + return vector::ExtractOp::create(builder, loc, dstVec, ArrayRef{0}); + } + + Value 
bitcastToSSAElem(OpBuilder &builder, Location loc, Value value, Type ssaElemTy) { + Type valueTy = value.getType(); + if (auto vecTy = dyn_cast(valueTy)) { + if (vecTy.getElementType() == ssaElemTy) + return value; + auto targetTy = VectorType::get(vecTy.getShape(), ssaElemTy); + return vector::BitCastOp::create(builder, loc, targetTy, value); + } + if (valueTy == ssaElemTy) + return value; + return bitcastScalarViaVector(builder, loc, value, ssaElemTy); + } + + Value bitcastFromSSAElem(OpBuilder &builder, Location loc, Value value, Type originalTy) { + Type valueTy = value.getType(); + if (auto vecTy = dyn_cast(valueTy)) { + auto originalVecTy = cast(originalTy); + if (vecTy.getElementType() == originalVecTy.getElementType()) + return value; + return vector::BitCastOp::create(builder, loc, originalVecTy, value); + } + if (valueTy == originalTy) + return value; + return bitcastScalarViaVector(builder, loc, value, originalTy); + } + LogicalResult rewritePtrStore(PtrStoreOp storeOp, OpBuilder &builder, IRMapping &mapping, RegMem2VectorSSAMap &state) { auto access = getRegAccessInfo(storeOp.getPtr(), storeOp.getValue().getType()); @@ -307,6 +341,10 @@ class FlyPromoteRegMemToVectorSSAPass auto loc = storeOp.getLoc(); Value storedValue = mapping.lookupOrDefault(storeOp.getValue()); + const RegAllocaInfo *info = getAllocaInfo(access->makePtrOp); + assert(info && "missing alloca info for register ptr.store"); + storedValue = bitcastToSSAElem(builder, loc, storedValue, info->vectorSSATy.getElementType()); + if (isa(storedValue.getType())) { assert(access->width == 1 && "expected scalar type with width 1"); updatedVec = vector::InsertOp::create(builder, loc, storedValue, currentVec, access->offset); @@ -349,6 +387,7 @@ class FlyPromoteRegMemToVectorSSAPass builder, loc, currentVec, ArrayRef{access->offset}, ArrayRef{access->width}, ArrayRef{1}); } + extracted = bitcastFromSSAElem(builder, loc, extracted, resultType); mapping.map(loadOp, extracted); return success(); } diff --git a/lib/Dialect/Fly/Utils/PointerUtils.cpp b/lib/Dialect/Fly/Utils/PointerUtils.cpp index 80210a01..9858ecc6 100644 --- a/lib/Dialect/Fly/Utils/PointerUtils.cpp +++ b/lib/Dialect/Fly/Utils/PointerUtils.cpp @@ -28,13 +28,24 @@ TypedValue applySwizzleOnPtr(OpBuilder &b, Location loc, LLVM::IntToPtrOp::create(b, loc, ptrTy, swizzled).getResult()); } -Type RegMem2SSAType(fly::MemRefType memRefTy) { +Type projectToLLVMCompatibleElemTy(Type elemTy) { + if (auto floatTy = dyn_cast(elemTy)) { + unsigned width = floatTy.getWidth(); + if (width < 16) + return IntegerType::get(elemTy.getContext(), width); + } + return elemTy; +} + +Type RegMem2SSAType(fly::MemRefType memRefTy, bool llvmCompatibleType) { if (memRefTy.getAddressSpace().getValue() != AddressSpace::Register) return Type(); LayoutBuilder builder(memRefTy.getContext()); auto layoutAttr = cast(memRefTy.getLayout()); int32_t cosize = layoutCosize(builder, layoutAttr).getLeafAsInt().getValue(); Type elemTy = memRefTy.getElemTy(); + if (llvmCompatibleType) + elemTy = projectToLLVMCompatibleElemTy(elemTy); if (cosize == 1) return elemTy; return VectorType::get({cosize}, elemTy); diff --git a/lib/Dialect/FlyROCDL/CDNA3/CopyAtom.cpp b/lib/Dialect/FlyROCDL/CDNA3/CopyAtom.cpp index a182dc72..abf0e973 100644 --- a/lib/Dialect/FlyROCDL/CDNA3/CopyAtom.cpp +++ b/lib/Dialect/FlyROCDL/CDNA3/CopyAtom.cpp @@ -168,14 +168,14 @@ LogicalResult CopyOpCDNA3BufferCopyType::emitAtomCall(OpBuilder &builder, Locati return failure(); if (srcIsBuffer) { - auto dstSSATy = fly::RegMem2SSAType(dstMemTy); 
+ auto dstSSATy = fly::RegMem2SSAType(dstMemTy, true); auto res = emitAtomCallSSA(builder, loc, dstSSATy, copyAtomTyArg, srcMemTyArg, Type{}, atomVal, src, Value{}); if (failed(res)) return failure(); LLVM::StoreOp::create(builder, loc, *res, dst); } else { - auto srcSSATy = fly::RegMem2SSAType(srcMemTy); + auto srcSSATy = fly::RegMem2SSAType(srcMemTy, true); Value srcVal = LLVM::LoadOp::create(builder, loc, srcSSATy, src); auto res = emitAtomCallSSA(builder, loc, Type{}, copyAtomTyArg, srcSSATy, dstMemTyArg, atomVal, srcVal, dst); @@ -214,8 +214,9 @@ std::optional CopyOpCDNA3BufferCopyLDSType::getFieldIndex(AtomStateFie return 0; case AtomStateField::ImmOffset: return 1; + default: + return std::nullopt; } - return std::nullopt; } Type CopyOpCDNA3BufferCopyLDSType::getConvertedType(MLIRContext *ctx) const { @@ -492,7 +493,7 @@ LogicalResult CopyOpCDNA3BufferAtomicType::emitAtomCall(OpBuilder &builder, Loca Type dstMemTyArg, Value atomVal, Value src, Value dst) const { auto srcMemTy = cast(srcMemTyArg); - auto srcSSATy = fly::RegMem2SSAType(srcMemTy); + auto srcSSATy = fly::RegMem2SSAType(srcMemTy, /*llvmCompatibleType=*/true); Value srcVal = LLVM::LoadOp::create(builder, loc, srcSSATy, src); auto res = emitAtomCallSSA(builder, loc, Type{}, copyAtomTyArg, srcSSATy, dstMemTyArg, atomVal, srcVal, dst); diff --git a/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp b/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp index f8d898e0..9180e0b2 100644 --- a/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp +++ b/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp @@ -1,10 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // Copyright (c) 2025 FlyDSL Project Contributors -#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" #include "flydsl/Dialect/Fly/IR/FlyDialect.h" @@ -32,8 +30,6 @@ LayoutAttr getThrValLayoutAB(MLIRContext *ctx, int32_t M, int32_t N, int32_t K, } // namespace cdna3 -namespace cdna4 {} - namespace mlir::fly_rocdl { bool MmaOpCDNA3_MFMAType::isStatic() const { return true; } @@ -187,6 +183,13 @@ FailureOr MmaOpCDNA3_MFMAType::emitAtomCallSSA(OpBuilder &builder, Locati Type accElemTy = getElemTyAcc(); VectorType accTy = VectorType::get({accVecSize}, accElemTy); + if (a.getType() != abTyA) + a = LLVM::BitcastOp::create(builder, loc, abTyA, a); + if (b.getType() != abTyB) + b = LLVM::BitcastOp::create(builder, loc, abTyB, b); + if (c.getType() != accTy) + c = LLVM::BitcastOp::create(builder, loc, accTy, c); + #define DISPATCH_MFMA_SSA(M_, K_, PRED, OP) \ if (m == M_ && n == M_ && k == K_ && (PRED)) { \ auto zeroAttr = builder.getI32IntegerAttr(0); \ diff --git a/lib/Dialect/FlyROCDL/CDNA4/CopyAtom.cpp b/lib/Dialect/FlyROCDL/CDNA4/CopyAtom.cpp index 808b6aaf..cdf88650 100644 --- a/lib/Dialect/FlyROCDL/CDNA4/CopyAtom.cpp +++ b/lib/Dialect/FlyROCDL/CDNA4/CopyAtom.cpp @@ -119,7 +119,7 @@ LogicalResult CopyOpCDNA4LdsReadTransposeType::emitAtomCall(OpBuilder &builder, Type copyAtomTyArg, Type srcMemTyArg, Type dstMemTyArg, Value atomVal, Value src, Value dst) const { - auto dstSSATy = fly::RegMem2SSAType(cast(dstMemTyArg)); + auto dstSSATy = fly::RegMem2SSAType(cast(dstMemTyArg), true); auto res = emitAtomCallSSA(builder, loc, dstSSATy, copyAtomTyArg, srcMemTyArg, Type{}, atomVal, src, Value{}); if (failed(res)) diff --git a/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp b/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp index 0e9dbb41..14553f3d 100644 --- 
a/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp +++ b/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp @@ -259,6 +259,13 @@ FailureOr MmaOpGFX1250_WMMAType::emitAtomCallSSA(OpBuilder &builder, Loca VectorType accTy = VectorType::get({accVecSize}, elemTyAcc); + if (a.getType() != abTyA) + a = LLVM::BitcastOp::create(builder, loc, abTyA, a); + if (b.getType() != abTyB) + b = LLVM::BitcastOp::create(builder, loc, abTyB, b); + if (c.getType() != accTy) + c = LLVM::BitcastOp::create(builder, loc, accTy, c); + #define DISPATCH_WMMA_SSA(M_, K_, PRED, OP, VARIANT) \ if (m == M_ && n == M_ && k == K_ && (PRED)) { \ return emitWmmaSSA(builder, loc, accTy, a, b, c); \ diff --git a/tests/mlir/Conversion/mma_atom.mlir b/tests/mlir/Conversion/mma_atom.mlir index fec2a782..35819497 100644 --- a/tests/mlir/Conversion/mma_atom.mlir +++ b/tests/mlir/Conversion/mma_atom.mlir @@ -35,3 +35,17 @@ func.func @test_gemm_from_tiled_mma_arg( fly.gemm(%tiled_mma, %d, %a, %b, %c) : (!fly.tiled_mma f32>>, <(1,4,1):(0,1,0)>>, !fly.memref, !fly.memref, !fly.memref, !fly.memref) -> () return } + +// CHECK-LABEL: @test_mma_atom_call_ssa_fp8 +// CHECK-SAME: (%[[A:.*]]: vector<8xi8>, %[[B:.*]]: vector<8xi8>, %[[C:.*]]: vector<4xf32>) +func.func @test_mma_atom_call_ssa_fp8( + %a: vector<8xi8>, + %b: vector<8xi8>, + %c: vector<4xf32>) -> vector<4xf32> { + %atom = fly.make_mma_atom : !fly.mma_atom f32>> + // CHECK: %[[A_CAST:.*]] = llvm.bitcast %[[A]] : vector<8xi8> to i64 + // CHECK: %[[B_CAST:.*]] = llvm.bitcast %[[B]] : vector<8xi8> to i64 + // CHECK: %[[RES:.*]] = rocdl.mfma.f32.16x16x32.fp8.fp8 %[[A_CAST]], %[[B_CAST]], %[[C]] + %res = fly.mma_atom_call_ssa(%atom, %a, %b, %c) : (!fly.mma_atom f32>>, vector<8xi8>, vector<8xi8>, vector<4xf32>) -> vector<4xf32> + return %res : vector<4xf32> +} diff --git a/tests/mlir/Transforms/promote_regmem_to_ssa.mlir b/tests/mlir/Transforms/promote_regmem_to_ssa.mlir index add21060..602b067b 100644 --- a/tests/mlir/Transforms/promote_regmem_to_ssa.mlir +++ b/tests/mlir/Transforms/promote_regmem_to_ssa.mlir @@ -76,11 +76,13 @@ gpu.module @promote_rmem_to_vector_ssa { // CHECK-LABEL: gpu.func @promote_fp8_mma_to_vector_ssa // CHECK-NOT: register // CHECK-NOT: fly.mma_atom_call( - // CHECK: %[[A_STATE:.*]] = vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [0], strides = [1]} : vector<8xf8E4M3FNUZ> into vector<8xf8E4M3FNUZ> - // CHECK: %[[B_STATE:.*]] = vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [0], strides = [1]} : vector<8xf8E4M3FNUZ> into vector<8xf8E4M3FNUZ> + // CHECK: %[[A_BC:.*]] = vector.bitcast %{{.*}} : vector<8xf8E4M3FNUZ> to vector<8xi8> + // CHECK: %[[A_STATE:.*]] = vector.insert_strided_slice %[[A_BC]], %{{.*}} {offsets = [0], strides = [1]} : vector<8xi8> into vector<8xi8> + // CHECK: %[[B_BC:.*]] = vector.bitcast %{{.*}} : vector<8xf8E4M3FNUZ> to vector<8xi8> + // CHECK: %[[B_STATE:.*]] = vector.insert_strided_slice %[[B_BC]], %{{.*}} {offsets = [0], strides = [1]} : vector<8xi8> into vector<8xi8> // CHECK: %[[ACC_INIT:.*]] = vector.insert_strided_slice %{{.*}}, %{{.*}} {offsets = [4], strides = [1]} : vector<4xf32> into vector<8xf32> - // CHECK: %[[A:.*]] = vector.extract_strided_slice %[[A_STATE]] {offsets = [0], sizes = [8], strides = [1]} : vector<8xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ> - // CHECK: %[[B:.*]] = vector.extract_strided_slice %[[B_STATE]] {offsets = [0], sizes = [8], strides = [1]} : vector<8xf8E4M3FNUZ> to vector<8xf8E4M3FNUZ> + // CHECK: %[[A:.*]] = vector.extract_strided_slice %[[A_STATE]] {offsets = [0], sizes = [8], strides = [1]} : vector<8xi8> to 
vector<8xi8> + // CHECK: %[[B:.*]] = vector.extract_strided_slice %[[B_STATE]] {offsets = [0], sizes = [8], strides = [1]} : vector<8xi8> to vector<8xi8> // CHECK: %[[C:.*]] = vector.extract_strided_slice %[[ACC_INIT]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf32> to vector<4xf32> // CHECK: %[[RES:.*]] = fly.mma_atom_call_ssa(%{{.*}}, %[[A]], %[[B]], %[[C]]) // CHECK-SAME: -> vector<4xf32> From c1ea6985c9648ed6d308524bf85c3254296f16de Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Tue, 21 Apr 2026 09:14:20 +0800 Subject: [PATCH 22/29] [ROCDL] Add CDNA4_MFMAScaleType (#417) * [ROCDL] Add CDNA4_MFMAScaleType --- include/flydsl/Dialect/FlyROCDL/IR/Atom.td | 6 +- include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td | 32 ++ lib/Bindings/Python/FlyROCDLExtension.cpp | 28 +- lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp | 80 ++--- lib/Dialect/FlyROCDL/CDNA4/MmaAtom.cpp | 295 ++++++++++++++++++ lib/Dialect/FlyROCDL/CMakeLists.txt | 1 + python/flydsl/expr/rocdl/cdna4.py | 22 +- python/flydsl/expr/rocdl/universal.py | 8 +- tests/mlir/Conversion/mma_atom_stateful.mlir | 128 ++++++++ 9 files changed, 535 insertions(+), 65 deletions(-) create mode 100644 lib/Dialect/FlyROCDL/CDNA4/MmaAtom.cpp create mode 100644 tests/mlir/Conversion/mma_atom_stateful.mlir diff --git a/include/flydsl/Dialect/FlyROCDL/IR/Atom.td b/include/flydsl/Dialect/FlyROCDL/IR/Atom.td index be9ecb38..77b8d347 100644 --- a/include/flydsl/Dialect/FlyROCDL/IR/Atom.td +++ b/include/flydsl/Dialect/FlyROCDL/IR/Atom.td @@ -7,8 +7,10 @@ include "flydsl/Dialect/FlyROCDL/IR/Dialect.td" def FlyROCDL_AtomStateField : I32EnumAttr<"AtomStateField", "", [ - I32EnumAttrCase<"Soffset", 0, "soffset">, - I32EnumAttrCase<"ImmOffset", 1, "imm_offset"> + I32EnumAttrCase<"Soffset", 0, "soffset">, + I32EnumAttrCase<"ImmOffset", 1, "imm_offset">, + I32EnumAttrCase<"ScaleA", 2, "scale_a">, + I32EnumAttrCase<"ScaleB", 3, "scale_b"> ]> { let genSpecializedAttr = 0; let cppNamespace = FlyROCDL_Dialect.cppNamespace; diff --git a/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td b/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td index 1113b9d2..29732d2e 100644 --- a/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td +++ b/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td @@ -33,7 +33,39 @@ def FlyROCDL_MmaOpCDNA3_MFMA : FlyROCL_MmaOp<"MmaOpCDNA3_MFMA", "cdna3.mfma", [] // MmaOp CDNA4 //===----------------------------------------------------------------------===// +def FlyROCDL_MmaOpCDNA4_MFMAScale : FlyROCL_StatefulMmaOp<"MmaOpCDNA4_MFMAScale", "cdna4.mfma_scale", []> { + let parameters = (ins + "int32_t":$m, + "int32_t":$n, + "int32_t":$k, + "Type":$elemTyA, + "Type":$elemTyB, + "Type":$elemTyAcc, + "int32_t":$opselA, + "int32_t":$opselB + ); + let assemblyFormat = [{ + `<` custom($m, $n, $k) `,` `(` $elemTyA `,` $elemTyB `)` `->` $elemTyAcc + `,` `opselA` `=` $opselA `,` `opselB` `=` $opselB `>` + }]; + let builders = [ + TypeBuilderWithInferredContext<(ins + "int32_t":$m, "int32_t":$n, "int32_t":$k, + "Type":$elemTyA, "Type":$elemTyB, "Type":$elemTyAcc), [{ + return $_get(elemTyA.getContext(), m, n, k, elemTyA, elemTyB, elemTyAcc, + /*opselA=*/0, /*opselB=*/0); + }]>, + TypeBuilderWithInferredContext<(ins + "int32_t":$m, "int32_t":$n, "int32_t":$k, + "Type":$elemTyA, "Type":$elemTyB, "Type":$elemTyAcc, + "int32_t":$opselA, "int32_t":$opselB), [{ + return $_get(elemTyA.getContext(), m, n, k, elemTyA, elemTyB, elemTyAcc, + opselA, opselB); + }]> + ]; + let genVerifyDecl = 1; +} //===----------------------------------------------------------------------===// // MmaOp GFX1250 — 
WMMA wave32 diff --git a/lib/Bindings/Python/FlyROCDLExtension.cpp b/lib/Bindings/Python/FlyROCDLExtension.cpp index b29391a5..fb4e886b 100644 --- a/lib/Bindings/Python/FlyROCDLExtension.cpp +++ b/lib/Bindings/Python/FlyROCDLExtension.cpp @@ -38,6 +38,27 @@ struct PyMmaOpCDNA3_MFMAType : PyConcreteType { } }; +struct PyMmaOpCDNA4_MFMAScaleType : PyConcreteType { + FLYDSL_REGISTER_TYPE_BINDING(MmaOpCDNA4_MFMAScaleType, "MmaOpCDNA4_MFMAScaleType"); + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](int32_t m, int32_t n, int32_t k, PyType &elemTyA, PyType &elemTyB, PyType &elemTyAcc, + int32_t opselA, int32_t opselB, DefaultingPyMlirContext context) { + return PyMmaOpCDNA4_MFMAScaleType( + context->getRef(), + wrap(MmaOpCDNA4_MFMAScaleType::get(m, n, k, unwrap(elemTyA), unwrap(elemTyB), + unwrap(elemTyAcc), opselA, opselB))); + }, + "m"_a, "n"_a, "k"_a, "elem_ty_a"_a, "elem_ty_b"_a, "elem_ty_acc"_a, "opsel_a"_a = 0, + "opsel_b"_a = 0, nb::kw_only(), "context"_a = nb::none(), + "Create a MmaOpCDNA4_MFMAScaleType with m, n, k dimensions, element types, " + "and optional opsel_a / opsel_b (compile-time lane index into the scale " + "vector, default 0)"); + } +}; + struct PyMmaOpGFX1250_WMMAType : PyConcreteType { FLYDSL_REGISTER_TYPE_BINDING(MmaOpGFX1250_WMMAType, "MmaOpGFX1250_WMMAType"); @@ -97,8 +118,8 @@ struct PyCopyOpCDNA3BufferAtomicType : PyConcreteTypeget()); - auto atomicOpAttr = ::mlir::fly::AtomicOpAttr::get( - ctx, static_cast<::mlir::fly::AtomicOp>(atomicOp)); + auto atomicOpAttr = + ::mlir::fly::AtomicOpAttr::get(ctx, static_cast<::mlir::fly::AtomicOp>(atomicOp)); return PyCopyOpCDNA3BufferAtomicType( context->getRef(), wrap(CopyOpCDNA3BufferAtomicType::get(atomicOpAttr, unwrap(valTypeObj)))); @@ -133,10 +154,13 @@ struct PyCopyOpCDNA4LdsReadTransposeType : PyConcreteType emi static bool isFP8(Type ty) { return isa(ty); } static bool isBF8(Type ty) { return isa(ty); } -static bool isF8(Type ty) { return isFP8(ty) || isBF8(ty); } -static Type getMfmaABType(MLIRContext *ctx, Type elemTy, int32_t k = 0) { +static Type getMfmaABType(MLIRContext *ctx, Type elemTy, int32_t mn, int32_t k = 0) { if (elemTy.isF32()) return Float32Type::get(ctx); if (elemTy.isF16()) - return VectorType::get({4}, Float16Type::get(ctx)); - if (elemTy.isBF16()) - return VectorType::get({(k >= 16) ? 
4 : 2}, IntegerType::get(ctx, 16)); + return VectorType::get({mn * k / 64}, Float16Type::get(ctx)); + if (elemTy.isBF16()) { + int vecSize = mn * k / 64; + Type elemTy; + if (vecSize == 8) { + elemTy = BFloat16Type::get(ctx); // CDNA4 version + } else { + elemTy = IntegerType::get(ctx, 16); + } + return VectorType::get({vecSize}, elemTy); + } if (elemTy.getIntOrFloatBitWidth() == 8) return IntegerType::get(ctx, 64); return nullptr; } -static int64_t getMfmaAccVecSize(int32_t m, int32_t k, Type elemTyA) { - if (elemTyA.isF32()) { - if (m == 32 && k == 1) - return 32; - if (m == 32 && k == 2) - return 16; - if (m == 16 && k == 1) - return 16; - if (m == 16 && k == 4) - return 4; - if (m == 4 && k == 1) - return 4; - } - if (elemTyA.isF16()) { - if (m == 32 && k == 4) - return 32; - if (m == 32 && k == 8) - return 16; - if (m == 16 && k == 4) - return 16; - if (m == 16 && k == 16) - return 4; - if (m == 4 && k == 4) - return 4; - } - if (elemTyA.isBF16()) { - if (m == 32 && k == 2) - return 32; - if (m == 32 && k == 4) - return 16; - if (m == 16 && k == 2) - return 16; - if (m == 16 && k == 8) - return 4; - if (m == 16 && k == 16) - return 4; - if (m == 4 && k == 2) - return 4; - } - if (isF8(elemTyA)) { - if (m == 16 && k == 32) - return 4; - if (m == 32 && k == 16) - return 16; - } +static int64_t getMfmaAccVecSize(int32_t m, int32_t n, Type elemTyA) { + if (m == 16 && n == 16) + return 4; + if (m == 32 && n == 32) + return 16; return 0; } @@ -171,12 +138,12 @@ FailureOr MmaOpCDNA3_MFMAType::emitAtomCallSSA(OpBuilder &builder, Locati Type elemTyB = getElemTyB(); MLIRContext *ctx = builder.getContext(); - Type abTyA = getMfmaABType(ctx, elemTyA, k); - Type abTyB = getMfmaABType(ctx, elemTyB, k); + Type abTyA = getMfmaABType(ctx, elemTyA, m, k); + Type abTyB = getMfmaABType(ctx, elemTyB, n, k); if (!abTyA || !abTyB) return failure(); - int64_t accVecSize = getMfmaAccVecSize(m, k, elemTyA); + int64_t accVecSize = getMfmaAccVecSize(m, n, elemTyA); if (accVecSize == 0) return failure(); @@ -235,17 +202,18 @@ LogicalResult MmaOpCDNA3_MFMAType::emitAtomCall(OpBuilder &builder, Location loc Value atomVal, Value dPtr, Value aPtr, Value bPtr, Value cPtr) const { int32_t m = getM(); + int32_t n = getN(); int32_t k = getK(); Type elemTyA = getElemTyA(); Type elemTyB = getElemTyB(); MLIRContext *ctx = builder.getContext(); - Type abTyA = getMfmaABType(ctx, elemTyA, k); - Type abTyB = getMfmaABType(ctx, elemTyB, k); + Type abTyA = getMfmaABType(ctx, elemTyA, m, k); + Type abTyB = getMfmaABType(ctx, elemTyB, n, k); if (!abTyA || !abTyB) return failure(); - int64_t accVecSize = getMfmaAccVecSize(m, k, elemTyA); + int64_t accVecSize = getMfmaAccVecSize(m, n, elemTyA); if (accVecSize == 0) return failure(); diff --git a/lib/Dialect/FlyROCDL/CDNA4/MmaAtom.cpp b/lib/Dialect/FlyROCDL/CDNA4/MmaAtom.cpp new file mode 100644 index 00000000..c96cca66 --- /dev/null +++ b/lib/Dialect/FlyROCDL/CDNA4/MmaAtom.cpp @@ -0,0 +1,295 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2025 FlyDSL Project Contributors + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinTypes.h" + +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/Fly/Utils/ThrValLayoutMacro.h.inc" +#include "flydsl/Dialect/FlyROCDL/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::fly; + +namespace cdna4 { + +LayoutAttr getThrValLayoutAB(MLIRContext *ctx, int32_t 
M, int32_t N, int32_t K, Type elemTy) { + auto getContext = [&]() { return ctx; }; + + int MN = M; + assert(M == N && "M and N must be equal"); + + int GroupK = 64 / MN; + int KPerThread = K / GroupK; + + return FxLayout(FxShape(FxThr(MN, GroupK), FxVal(KPerThread)), + FxStride(FxThr(1, MN * KPerThread), FxVal(MN))); +} + +LayoutAttr getThrValLayoutC(MLIRContext *ctx, int32_t M, int32_t N) { + auto getContext = [&]() { return ctx; }; + + int GroupM = 64 / N; + int ValM0 = 4; + int ValM1 = M / 4 / GroupM; + + return FxLayout(FxShape(FxThr(N, GroupM), FxVal(ValM0, ValM1)), + FxStride(FxThr(M, ValM0), FxVal(1, ValM0 * GroupM))); +} + +} // namespace cdna4 + +namespace mlir::fly_rocdl { + +//===----------------------------------------------------------------------===// +// MmaOpCDNA4_MFMAScaleType +//===----------------------------------------------------------------------===// + +std::optional MmaOpCDNA4_MFMAScaleType::getFieldIndex(AtomStateField field) { + switch (field) { + case AtomStateField::ScaleA: + return 0; + case AtomStateField::ScaleB: + return 1; + default: + return std::nullopt; + } +} + +Type MmaOpCDNA4_MFMAScaleType::getConvertedType(MLIRContext *ctx) const { + auto i32Ty = IntegerType::get(ctx, 32); + return LLVM::LLVMStructType::getLiteral(ctx, {i32Ty, i32Ty}); +} + +Value MmaOpCDNA4_MFMAScaleType::getDefaultState(OpBuilder &builder, Location loc) const { + auto structTy = cast(getConvertedType(builder.getContext())); + Value state = LLVM::UndefOp::create(builder, loc, structTy); + Value zero = arith::ConstantIntOp::create(builder, loc, 0, 32); + state = LLVM::InsertValueOp::create(builder, loc, state, zero, + ArrayRef{*getFieldIndex(AtomStateField::ScaleA)}); + state = LLVM::InsertValueOp::create(builder, loc, state, zero, + ArrayRef{*getFieldIndex(AtomStateField::ScaleB)}); + return state; +} + +Value MmaOpCDNA4_MFMAScaleType::setAtomState(OpBuilder &builder, Location loc, Value atomStruct, + Attribute fieldAttr, Value fieldValue) const { + auto fieldStr = dyn_cast(fieldAttr); + if (!fieldStr) + return nullptr; + auto field = symbolizeAtomStateField(fieldStr.getValue()); + if (!field) + return nullptr; + auto idx = getFieldIndex(*field); + if (!idx) + return nullptr; + Value scaleVal = fieldValue; + Type srcTy = scaleVal.getType(); + Type i32Ty = IntegerType::get(builder.getContext(), 32); + if (srcTy != i32Ty) { + auto bitWidthOf = [](Type t) -> unsigned { + if (auto vec = dyn_cast(t)) { + Type elt = vec.getElementType(); + if (!elt.isIntOrFloat()) + return 0; + return elt.getIntOrFloatBitWidth() * vec.getNumElements(); + } + if (auto intTy = dyn_cast(t)) + return intTy.getWidth(); + return 0; + }; + if (bitWidthOf(srcTy) != 32) + return nullptr; + scaleVal = LLVM::BitcastOp::create(builder, loc, i32Ty, scaleVal); + } + return LLVM::InsertValueOp::create(builder, loc, atomStruct, scaleVal, ArrayRef{*idx}); +} + +Attribute MmaOpCDNA4_MFMAScaleType::getThrLayout() const { return FxLayout(FxC(64), FxC(1)); } + +Attribute MmaOpCDNA4_MFMAScaleType::getShapeMNK() const { + return IntTupleAttr::get(ArrayAttr::get(getContext(), {FxC(getM()), FxC(getN()), FxC(getK())})); +} + +Type MmaOpCDNA4_MFMAScaleType::getValTypeA() const { return getElemTyA(); } +Type MmaOpCDNA4_MFMAScaleType::getValTypeB() const { return getElemTyB(); } +Type MmaOpCDNA4_MFMAScaleType::getValTypeC() const { return getElemTyAcc(); } +Type MmaOpCDNA4_MFMAScaleType::getValTypeD() const { return getElemTyAcc(); } + +Attribute MmaOpCDNA4_MFMAScaleType::getThrValLayoutA() const { + return 
cdna4::getThrValLayoutAB(getContext(), getM(), getN(), getK(), getElemTyA()); +} + +Attribute MmaOpCDNA4_MFMAScaleType::getThrValLayoutB() const { + return cdna4::getThrValLayoutAB(getContext(), getM(), getN(), getK(), getElemTyB()); +} + +Attribute MmaOpCDNA4_MFMAScaleType::getThrValLayoutC() const { + return cdna4::getThrValLayoutC(getContext(), getM(), getN()); +} + +static std::optional mfmaFloatTypeEncode(Type elemTy) { + if (isa(elemTy)) + return 0u; + if (isa(elemTy)) + return 1u; + if (isa(elemTy)) + return 2u; + if (isa(elemTy)) + return 3u; + if (isa(elemTy)) + return 4u; + return std::nullopt; +} + +static bool isSupportedScaledElemTy(Type ty) { + return isa(ty); +} + +LogicalResult MmaOpCDNA4_MFMAScaleType::verify(function_ref emitError, + int32_t m, int32_t n, int32_t k, Type elemTyA, + Type elemTyB, Type elemTyAcc, int32_t opselA, + int32_t opselB) { + if (!((m == 16 && n == 16 && k == 128) || (m == 32 && n == 32 && k == 64))) { + return emitError() << "unsupported MNK for CDNA4 MFMA_Scale: " << m << "x" << n << "x" << k + << " (expected 16x16x128 or 32x32x64)"; + } + if (!elemTyAcc.isF32()) + return emitError() << "elemTyAcc must be f32, got " << elemTyAcc; + if (!isSupportedScaledElemTy(elemTyA)) { + return emitError() << "elemTyA must be one of f8E4M3FN, f8E5M2, f6E2M3FN, " + "f6E3M2FN, f4E2M1FN, got " + << elemTyA; + } + if (!isSupportedScaledElemTy(elemTyB)) { + return emitError() << "elemTyB must be one of f8E4M3FN, f8E5M2, f6E2M3FN, " + "f6E3M2FN, f4E2M1FN, got " + << elemTyB; + } + if (opselA < 0 || opselA > 3) + return emitError() << "opselA must be in [0, 3], got " << opselA; + if (opselB < 0 || opselB > 3) + return emitError() << "opselB must be in [0, 3], got " << opselB; + return success(); +} + +static Type getScaledMfmaABType(MLIRContext *ctx, Type elemTy) { + Type i32Ty = IntegerType::get(ctx, 32); + if (isa(elemTy)) + return VectorType::get({8}, i32Ty); + if (isa(elemTy)) + return VectorType::get({6}, i32Ty); + if (isa(elemTy)) + return VectorType::get({4}, i32Ty); + return nullptr; +} + +static int64_t getScaledMfmaAccVecSize(int32_t m, int32_t n) { + if (m == 16 && n == 16) + return 4; + if (m == 32 && n == 32) + return 16; + return 0; +} + +FailureOr MmaOpCDNA4_MFMAScaleType::emitAtomCallSSA(OpBuilder &builder, Location loc, + Type resultTy, Type mmaAtomTyArg, + Type dTyArg, Type aTyArg, Type bTyArg, + Type cTyArg, Value atomVal, Value d, + Value a, Value b, Value c) const { + int32_t m = getM(); + int32_t n = getN(); + int32_t k = getK(); + Type elemTyA = getElemTyA(); + Type elemTyB = getElemTyB(); + MLIRContext *ctx = builder.getContext(); + + Type abTyA = getScaledMfmaABType(ctx, elemTyA); + Type abTyB = getScaledMfmaABType(ctx, elemTyB); + if (!abTyA || !abTyB) + return failure(); + + int64_t accVecSize = getScaledMfmaAccVecSize(m, n); + if (accVecSize == 0) + return failure(); + + std::optional aTypeCode = mfmaFloatTypeEncode(elemTyA); + std::optional bTypeCode = mfmaFloatTypeEncode(elemTyB); + if (!aTypeCode || !bTypeCode) + return failure(); + + Type accElemTy = getElemTyAcc(); + VectorType accTy = VectorType::get({accVecSize}, accElemTy); + + if (a.getType() != abTyA) + a = LLVM::BitcastOp::create(builder, loc, abTyA, a); + if (b.getType() != abTyB) + b = LLVM::BitcastOp::create(builder, loc, abTyB, b); + if (c.getType() != accTy) + c = LLVM::BitcastOp::create(builder, loc, accTy, c); + + Value scaleA = LLVM::ExtractValueOp::create( + builder, loc, atomVal, ArrayRef{*getFieldIndex(AtomStateField::ScaleA)}); + Value scaleB = 
LLVM::ExtractValueOp::create( + builder, loc, atomVal, ArrayRef{*getFieldIndex(AtomStateField::ScaleB)}); + + auto cbszAttr = builder.getI32IntegerAttr(*aTypeCode); + auto blgpAttr = builder.getI32IntegerAttr(*bTypeCode); + auto opselAAttr = builder.getI32IntegerAttr(getOpselA()); + auto opselBAttr = builder.getI32IntegerAttr(getOpselB()); + + if (m == 16 && n == 16 && k == 128) { + return ROCDL::mfma_scale_f32_16x16x128_f8f6f4::create(builder, loc, accTy, a, b, c, cbszAttr, + blgpAttr, opselAAttr, scaleA, opselBAttr, + scaleB) + .getResult(); + } + if (m == 32 && n == 32 && k == 64) { + return ROCDL::mfma_scale_f32_32x32x64_f8f6f4::create(builder, loc, accTy, a, b, c, cbszAttr, + blgpAttr, opselAAttr, scaleA, opselBAttr, + scaleB) + .getResult(); + } + + return failure(); +} + +LogicalResult MmaOpCDNA4_MFMAScaleType::emitAtomCall(OpBuilder &builder, Location loc, + Type mmaAtomTy, Type dMemTy, Type aMemTy, + Type bMemTy, Type cMemTy, Value atomVal, + Value dPtr, Value aPtr, Value bPtr, + Value cPtr) const { + int32_t m = getM(); + int32_t n = getN(); + Type elemTyA = getElemTyA(); + Type elemTyB = getElemTyB(); + MLIRContext *ctx = builder.getContext(); + + Type abTyA = getScaledMfmaABType(ctx, elemTyA); + Type abTyB = getScaledMfmaABType(ctx, elemTyB); + if (!abTyA || !abTyB) + return failure(); + + int64_t accVecSize = getScaledMfmaAccVecSize(m, n); + if (accVecSize == 0) + return failure(); + + Type accElemTy = getElemTyAcc(); + VectorType accTy = VectorType::get({accVecSize}, accElemTy); + + Value a = LLVM::LoadOp::create(builder, loc, abTyA, aPtr); + Value b = LLVM::LoadOp::create(builder, loc, abTyB, bPtr); + Value c = LLVM::LoadOp::create(builder, loc, accTy, cPtr); + auto res = emitAtomCallSSA(builder, loc, accTy, mmaAtomTy, Type{}, abTyA, abTyB, accTy, atomVal, + Value{}, a, b, c); + if (failed(res)) + return failure(); + LLVM::StoreOp::create(builder, loc, *res, dPtr); + return success(); +} + +} // namespace mlir::fly_rocdl diff --git a/lib/Dialect/FlyROCDL/CMakeLists.txt b/lib/Dialect/FlyROCDL/CMakeLists.txt index 9309b6b6..9a7be1b4 100644 --- a/lib/Dialect/FlyROCDL/CMakeLists.txt +++ b/lib/Dialect/FlyROCDL/CMakeLists.txt @@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRFlyROCDLDialect Dialect.cpp CDNA3/MmaAtom.cpp CDNA3/CopyAtom.cpp + CDNA4/MmaAtom.cpp CDNA4/CopyAtom.cpp GFX1250/MmaAtom.cpp diff --git a/python/flydsl/expr/rocdl/cdna4.py b/python/flydsl/expr/rocdl/cdna4.py index dc50d91f..9b691157 100644 --- a/python/flydsl/expr/rocdl/cdna4.py +++ b/python/flydsl/expr/rocdl/cdna4.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2025 FlyDSL Project Contributors -from ..._mlir.dialects.fly_rocdl import CopyOpCDNA4LdsReadTransposeType +from ..._mlir.dialects.fly_rocdl import CopyOpCDNA4LdsReadTransposeType, MmaOpCDNA4_MFMAScaleType +from ..._mlir.extras import types as T def LDSReadTrans(trans_granularity, bit_size): @@ -13,3 +14,22 @@ def LDSReadTrans(trans_granularity, bit_size): LDSReadTrans8_64b = lambda: CopyOpCDNA4LdsReadTransposeType.get(8, 64) LDSReadTrans6_96b = lambda: CopyOpCDNA4LdsReadTransposeType.get(6, 96) LDSReadTrans16_64b = lambda: CopyOpCDNA4LdsReadTransposeType.get(16, 64) + + +def MFMA_Scale(m, n, k, elem_ty_a, elem_ty_b=None, elem_ty_acc=None, *, opsel_a=0, opsel_b=0): + """Create a CDNA4 scaled MFMA atom (mfma.scale.f32.*.f8f6f4). 
+ + Current atom state: + - `scale_a` (`i32`), default zero + - `scale_b` (`i32`), default zero + """ + ty_a = elem_ty_a.ir_type if hasattr(elem_ty_a, "ir_type") else elem_ty_a + if elem_ty_b is None: + ty_b = ty_a + else: + ty_b = elem_ty_b.ir_type if hasattr(elem_ty_b, "ir_type") else elem_ty_b + if elem_ty_acc is None: + ty_acc = T.f32() + else: + ty_acc = elem_ty_acc.ir_type if hasattr(elem_ty_acc, "ir_type") else elem_ty_acc + return MmaOpCDNA4_MFMAScaleType.get(m, n, k, ty_a, ty_b, ty_acc, opsel_a=opsel_a, opsel_b=opsel_b) diff --git a/python/flydsl/expr/rocdl/universal.py b/python/flydsl/expr/rocdl/universal.py index c6f3ca86..b2ea260f 100644 --- a/python/flydsl/expr/rocdl/universal.py +++ b/python/flydsl/expr/rocdl/universal.py @@ -26,7 +26,7 @@ def BufferCopy(bit_size): """Create a CDNA3 buffer copy atom. Current atom state: - - `soffset` (`i32`) + - `soffset` (`i32`), default zero """ return CopyOpCDNA3BufferCopyType.get(bit_size) @@ -45,8 +45,8 @@ def BufferCopyLDS(bit_size): Only supports BufferDesc -> Shared address space direction. Current atom state: - - `soffset` (`i32`) - - `imm_offset` (`i32`) + - `soffset` (`i32`), default zero + - `imm_offset` (`i32`), default zero """ return CopyOpCDNA3BufferCopyLDSType.get(bit_size) @@ -60,7 +60,7 @@ def BufferAtomic(atomic_op, val_type): """Create a CDNA3 buffer atomic copy atom. Current atom state: - - `soffset` (`i32`) + - `soffset` (`i32`), default zero """ ty = val_type.ir_type if hasattr(val_type, "ir_type") else val_type return CopyOpCDNA3BufferAtomicType.get(int(atomic_op), ty) diff --git a/tests/mlir/Conversion/mma_atom_stateful.mlir b/tests/mlir/Conversion/mma_atom_stateful.mlir new file mode 100644 index 00000000..14c64c2f --- /dev/null +++ b/tests/mlir/Conversion/mma_atom_stateful.mlir @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2025 FlyDSL Project Contributors +// RUN: %fly-opt %s --fly-rewrite-func-signature --fly-canonicalize --fly-layout-lowering --convert-fly-to-rocdl | FileCheck %s + +// Stateful MmaAtom lowering tests (CDNA4 MFMA_Scale). +// fly.make_mma_atom -> default (0, 0) state +// fly.atom.set_value -> llvm.insertvalue at scale_a / scale_b field +// fly.mma_atom_call -> rocdl.mfma.scale.f32.*.f8f6f4 intrinsic +// Register memrefs are materialised inside the function body via +// fly.memref.alloca (register memrefs are not valid function arguments). + +// ----- + +// Stateful MMA atom type converts to !llvm.struct<(i32, i32)> +// CHECK-LABEL: @test_stateful_mma_scale_type +// CHECK-SAME: (%[[ATOM:.*]]: !llvm.struct<(i32, i32)>) +func.func @test_stateful_mma_scale_type( + %atom: !fly.mma_atom f32, opselA = 0, opselB = 0>>) { + return +} + +// ----- + +// make_mma_atom produces default state via getDefaultState (scaleA/scaleB = 0) +// and the atom feeds into mma_atom_call so it is not DCE'd. 
+ +// CHECK-LABEL: @test_make_mma_atom_default_scales +func.func @test_make_mma_atom_default_scales() { + %lay_ab = fly.static : !fly.layout<32:1> + %lay_cd = fly.static : !fly.layout<4:1> + + // CHECK: llvm.alloca %{{.*}} x f32 : (i64) -> !llvm.ptr<5> + // CHECK: llvm.alloca %{{.*}} x i8 : (i64) -> !llvm.ptr<5> + // CHECK: llvm.alloca %{{.*}} x i8 : (i64) -> !llvm.ptr<5> + // CHECK: llvm.alloca %{{.*}} x f32 : (i64) -> !llvm.ptr<5> + %d = fly.memref.alloca(%lay_cd) : (!fly.layout<4:1>) -> !fly.memref + %a = fly.memref.alloca(%lay_ab) : (!fly.layout<32:1>) -> !fly.memref + %b = fly.memref.alloca(%lay_ab) : (!fly.layout<32:1>) -> !fly.memref + %c = fly.memref.alloca(%lay_cd) : (!fly.layout<4:1>) -> !fly.memref + + // CHECK-DAG: %[[UNDEF:.*]] = llvm.mlir.undef : !llvm.struct<(i32, i32)> + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : i32 + // CHECK-DAG: %[[S1:.*]] = llvm.insertvalue %[[C0]], %[[UNDEF]][0] + // CHECK: llvm.insertvalue %[[C0]], %[[S1]][1] + %atom = fly.make_mma_atom : !fly.mma_atom f32, opselA = 0, opselB = 0>> + fly.mma_atom_call(%atom, %d, %a, %b, %c) : (!fly.mma_atom f32, opselA = 0, opselB = 0>>, !fly.memref, !fly.memref, !fly.memref, !fly.memref) -> () + return +} + +// ----- + +// End-to-end: set scales then mma_atom_call lowers to rocdl.mfma.scale.f32.16x16x128.f8f6f4 + +// CHECK-LABEL: @test_mma_scale_atom_call_16x16x128 +// CHECK-SAME: (%[[ATOM:.*]]: !llvm.struct<(i32, i32)>, %[[SA:.*]]: i32, %[[SB:.*]]: i32) +func.func @test_mma_scale_atom_call_16x16x128( + %atom: !fly.mma_atom f32, opselA = 0, opselB = 0>>, + %scale_a: i32, + %scale_b: i32) { + %lay_ab = fly.static : !fly.layout<32:1> + %lay_cd = fly.static : !fly.layout<4:1> + + // CHECK: %[[D:.*]] = llvm.alloca %{{.*}} x f32 : (i64) -> !llvm.ptr<5> + // CHECK: %[[A:.*]] = llvm.alloca %{{.*}} x i8 : (i64) -> !llvm.ptr<5> + // CHECK: %[[B:.*]] = llvm.alloca %{{.*}} x i8 : (i64) -> !llvm.ptr<5> + // CHECK: %[[C:.*]] = llvm.alloca %{{.*}} x f32 : (i64) -> !llvm.ptr<5> + %d = fly.memref.alloca(%lay_cd) : (!fly.layout<4:1>) -> !fly.memref + %a = fly.memref.alloca(%lay_ab) : (!fly.layout<32:1>) -> !fly.memref + %b = fly.memref.alloca(%lay_ab) : (!fly.layout<32:1>) -> !fly.memref + %c = fly.memref.alloca(%lay_cd) : (!fly.layout<4:1>) -> !fly.memref + + // CHECK: %[[A1:.*]] = llvm.insertvalue %[[SA]], %[[ATOM]][0] + %atom_a = fly.atom.set_value(%atom, "scale_a", %scale_a) : (!fly.mma_atom f32, opselA = 0, opselB = 0>>, i32) -> !fly.mma_atom f32, opselA = 0, opselB = 0>> + // CHECK: %[[A2:.*]] = llvm.insertvalue %[[SB]], %[[A1]][1] + %atom_ab = fly.atom.set_value(%atom_a, "scale_b", %scale_b) : (!fly.mma_atom f32, opselA = 0, opselB = 0>>, i32) -> !fly.mma_atom f32, opselA = 0, opselB = 0>> + + // CHECK-DAG: %[[A_VAL:.*]] = llvm.load %[[A]] : !llvm.ptr<5> -> vector<8xi32> + // CHECK-DAG: %[[B_VAL:.*]] = llvm.load %[[B]] : !llvm.ptr<5> -> vector<8xi32> + // CHECK-DAG: %[[C_VAL:.*]] = llvm.load %[[C]] : !llvm.ptr<5> -> vector<4xf32> + // CHECK-DAG: %[[SA_VAL:.*]] = llvm.extractvalue %[[A2]][0] + // CHECK-DAG: %[[SB_VAL:.*]] = llvm.extractvalue %[[A2]][1] + // CHECK: %[[RES:.*]] = rocdl.mfma.scale.f32.16x16x128.f8f6f4 %[[A_VAL]], %[[B_VAL]], %[[C_VAL]], 0, 0, 0, %[[SA_VAL]], 0, %[[SB_VAL]] + // CHECK: llvm.store %[[RES]], %[[D]] : vector<4xf32>, !llvm.ptr<5> + fly.mma_atom_call(%atom_ab, %d, %a, %b, %c) : (!fly.mma_atom f32, opselA = 0, opselB = 0>>, !fly.memref, !fly.memref, !fly.memref, !fly.memref) -> () + return +} + +// ----- + +// 32x32x64 shape with f4E2M1FN / f8E5M2 mixed operands -> cbsz=4 (fp4), blgp=1 (bf8/E5M2) + 
+// CHECK-LABEL: @test_mma_scale_atom_call_32x32x64_mixed +func.func @test_mma_scale_atom_call_32x32x64_mixed( + %atom: !fly.mma_atom f32, opselA = 0, opselB = 0>>) { + %lay_ab = fly.static : !fly.layout<32:1> + %lay_cd = fly.static : !fly.layout<16:1> + + %d = fly.memref.alloca(%lay_cd) : (!fly.layout<16:1>) -> !fly.memref + %a = fly.memref.alloca(%lay_ab) : (!fly.layout<32:1>) -> !fly.memref + %b = fly.memref.alloca(%lay_ab) : (!fly.layout<32:1>) -> !fly.memref + %c = fly.memref.alloca(%lay_cd) : (!fly.layout<16:1>) -> !fly.memref + + // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, 4, 1, 0, %{{.*}}, 0, %{{.*}} + fly.mma_atom_call(%atom, %d, %a, %b, %c) : (!fly.mma_atom f32, opselA = 0, opselB = 0>>, !fly.memref, !fly.memref, !fly.memref, !fly.memref) -> () + return +} + +// ----- + +// User-specified opselA / opselB are forwarded to the intrinsic as the two +// I32 attrs attached to scaleA / scaleB. opsel is part of the atom type — to +// change it at runtime, construct a new mma_atom with the desired opsel. + +// CHECK-LABEL: @test_mma_scale_atom_call_with_opsel +func.func @test_mma_scale_atom_call_with_opsel( + %atom: !fly.mma_atom f32, opselA = 2, opselB = 3>>) { + %lay_ab = fly.static : !fly.layout<32:1> + %lay_cd = fly.static : !fly.layout<4:1> + + %d = fly.memref.alloca(%lay_cd) : (!fly.layout<4:1>) -> !fly.memref + %a = fly.memref.alloca(%lay_ab) : (!fly.layout<32:1>) -> !fly.memref + %b = fly.memref.alloca(%lay_ab) : (!fly.layout<32:1>) -> !fly.memref + %c = fly.memref.alloca(%lay_cd) : (!fly.layout<4:1>) -> !fly.memref + + // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4 %{{.*}}, %{{.*}}, %{{.*}}, 0, 0, 2, %{{.*}}, 3, %{{.*}} + fly.mma_atom_call(%atom, %d, %a, %b, %c) : (!fly.mma_atom f32, opselA = 2, opselB = 3>>, !fly.memref, !fly.memref, !fly.memref, !fly.memref) -> () + return +} From 898c8d362310cd6ac4e51b7399f9d8f3d0e684c2 Mon Sep 17 00:00:00 2001 From: yanboshao Date: Tue, 21 Apr 2026 11:50:11 +0800 Subject: [PATCH 23/29] Adjust allreduce CI benchmark thresholds and replace error column with acc_res (#419) - Raise regression thresholds to 10us absolute / 15% relative - Replace aggregate 'error' field (always NaN) with 'acc_res' showing pass/failed Co-authored-by: root --- tests/kernels/compare_allreduce_benchmark.py | 12 ++++++------ tests/kernels/test_allreduce.py | 12 +++++++----- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/kernels/compare_allreduce_benchmark.py b/tests/kernels/compare_allreduce_benchmark.py index 2b0942cc..6f159b0c 100644 --- a/tests/kernels/compare_allreduce_benchmark.py +++ b/tests/kernels/compare_allreduce_benchmark.py @@ -5,14 +5,14 @@ python3 compare_benchmark.py Exit code 1 if any case regresses more than BOTH thresholds: - - relative increase > MAX_REGRESSION_PCT (default 10%) - - absolute increase > MIN_ABS_REGRESSION_US (default 5 us) + - relative increase > MAX_REGRESSION_PCT (default 15%) + - absolute increase > MIN_ABS_REGRESSION_US (default 10 us) """ import sys import pandas as pd -MAX_REGRESSION_PCT = 10.0 -MIN_ABS_REGRESSION_US = 5.0 +MAX_REGRESSION_PCT = 15.0 +MIN_ABS_REGRESSION_US = 10.0 def main(): @@ -76,8 +76,8 @@ def main(): print("\n=== Cases BROKEN in PR (work on main but fail on PR) ===") for shape, dtype in newly_broken: fail_count += 1 - err = pr_agg_indexed.loc[(shape, dtype)].get("error", "unknown") - print(f" {shape:>20s} {dtype:>4s} [BROKEN] error: {err}") + acc = pr_agg_indexed.loc[(shape, dtype)].get("acc_res", "unknown") + print(f" {shape:>20s} {dtype:>4s} [BROKEN] 
acc_res: {acc}") if fail_count > 0: print(f"\nFAILED: {fail_count} issue(s) detected.") diff --git a/tests/kernels/test_allreduce.py b/tests/kernels/test_allreduce.py index eb3d9269..80e44082 100644 --- a/tests/kernels/test_allreduce.py +++ b/tests/kernels/test_allreduce.py @@ -637,6 +637,10 @@ def run_all_tests( # Add aggregate row rank0 = rank_results[0] if rank_results else {} + has_failure = any( + r.get("error") or r.get("kernel_name") in ("skip", "error") + for r in rank_results + ) aggregate_result = { "rank": "aggregate", "shape": str(shape), @@ -651,7 +655,7 @@ def run_all_tests( "kernel_name": rank0.get("kernel_name", "unknown"), "num_iters": num_iters, "num_warmup": num_warmup, - "error": rank0.get("error"), + "acc_res": "failed" if has_failure else "pass", } all_results.append(aggregate_result) @@ -681,14 +685,12 @@ def run_all_tests( failed = [ r for r in all_results - if r.get("rank") == "aggregate" - and (r.get("kernel_name") in ("skip", "error") or r.get("error")) + if r.get("rank") == "aggregate" and r.get("acc_res") == "failed" ] if failed: print("\n✗ FAILED cases:") for r in failed: - reason = r.get("error") or r.get("kernel_name", "unknown") - print(f" {r['shape']} {r['dtype']} {r['mode']} → {reason}") + print(f" {r['shape']} {r['dtype']} {r['mode']} → {r.get('kernel_name', 'unknown')}") sys.exit(1) return df From 01399408a531942e55faa0eeaf93fe9012cca39e Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Tue, 21 Apr 2026 20:50:54 +0800 Subject: [PATCH 24/29] [Agent] New skill: add-target-atom-op (#423) * [Agent] New skill: add-target-atom-op * fix comments --- .claude/skills/add-target-atom-op/SKILL.md | 472 +++++++++++++++++++++ 1 file changed, 472 insertions(+) create mode 100644 .claude/skills/add-target-atom-op/SKILL.md diff --git a/.claude/skills/add-target-atom-op/SKILL.md b/.claude/skills/add-target-atom-op/SKILL.md new file mode 100644 index 00000000..915bac58 --- /dev/null +++ b/.claude/skills/add-target-atom-op/SKILL.md @@ -0,0 +1,472 @@ +--- +name: add-target-atom-op +description: > + Add a new target-specific Mma / Copy Op type to any FlyDSL backend + dialect (`lib/Dialect/Fly//` + + `include/flydsl/Dialect/Fly/IR/`). Explains the `MmaOp`-type / + `CopyOp`-type design (each type plugs into the generic + `!fly.mma_atom<...>` / `!fly.copy_atom<...>` wrapper through + `Fly_MmaOpTypeInterface` / `Fly_CopyOpTypeInterface`), the + stateful-vs-stateless variants (`Fly_StatefulOpTypeInterface`), and the + required `emitAtomCall` / `emitAtomCallSSA` lowering contract to the + backend dialect (LLVM/ROCDL/NVVM/SPIR-V/...). Use when adding a new + tensor-core / matrix instruction (MFMA, WMMA, HMMA, WGMMA, ...), a new + buffer / shared-memory / global copy atom, a new stateful copy (e.g. + per-atom offset or descriptor), or bringing up a new backend dialect + (`FlyPTX`, `FlyCPU`, etc.). The current reference implementation is + `FlyROCDL` with `CDNA3` MFMA, `CDNA3` BufferCopy, `CDNA4` + LDS-read-transpose, treat these as templates, not + prerequisites. Usage: /add-target-atom-op +allowed-tools: Read Edit Bash Grep Glob Agent +--- + +# Add a Target-Specific Mma / Copy Op to a FlyDSL Backend Dialect + +Step-by-step recipe for authoring a new `MmaOp*Type` or `CopyOp*Type` in a backend dialect +(`fly_rocdl`, or a future `fly_ptx` / ...), plus the **inherent design contract** every Op author +must understand before writing a single line of code. + +The examples throughout this skill draw from the `fly_rocdl` dialect (AMD ROCDL backend). 
The design +is deliberately backend-agnostic: the generic `!fly.mma_atom` / `!fly.copy_atom` wrappers and the +three type interfaces (`Fly_MmaOpTypeInterface`, `Fly_CopyOpTypeInterface`, +`Fly_StatefulOpTypeInterface`) live in the target-neutral `fly` dialect and know nothing about AMD, +NVIDIA, or other specifics. A new backend follows the exact same recipe — only the payload types, +the final intrinsic emission, and the directory prefix change. + +--- + +## 1. Inherent Design: How FlyDSL Atoms Work + +Internalize these five facts before adding any `Op`. They explain *why* the reference `CDNA3`, +`CDNA4`, `GFX1250` implementations look the way they do — and the same structure applies verbatim to +any new backend. + +### 1.1 Two-level type design: generic wrapper + target-specific payload + +There are **two kinds** of related types, and they live in different dialects: + +| Level | Dialect | Type (example) | Role | +|-------|---------|----------------|------| +| generic wrapper | `fly` | `!fly.mma_atom<...>`, `!fly.copy_atom<..., bits>` | Target-agnostic. Appears everywhere in kernel IR. | +| target payload | backend dialect (e.g. `fly_rocdl`) | `!fly_rocdl.cdna3.mfma<...>`, `!fly_rocdl.cdna3.buffer_copy<32>` | Knows *which* concrete instruction/intrinsic to emit. | + +The generic wrapper always holds a **payload** type as its first parameter. + +```mlir +// Using the ROCDL backend (the current reference): +!fly.mma_atom f32>> +!fly.copy_atom, 32> + +// A hypothetical NVIDIA backend would look like: +!fly.mma_atom f32>> +!fly.copy_atom, 128> +``` + +Every method you see on `MmaAtomType` / `CopyAtomType` is a trampoline: + +```cpp +// lib/Dialect/Fly/IR/FlyTypeDefs.cpp +Attribute MmaAtomType::getShapeMNK() const { + return cast(getMmaOp()).getShapeMNK(); +} +LogicalResult MmaAtomType::emitAtomCall(...) const { + return cast(getMmaOp()).emitAtomCall(...); +} +``` + +**Your job** when adding a new Op is to define the *payload* type and implement the interface +methods — the wrapper and the kernel-level ops (`fly.mma_atom_call`, `fly.copy_atom_call`, +`fly.make_mma_atom`, ...) work automatically. + +### 1.2 Three interfaces an Op type may implement + +`include/flydsl/Dialect/Fly/IR/FlyInterfaces.td`: + +| Interface | Required for... | Methods — **mandatory** / *optional* (see §1.3) | +|-----------------------------------|-----------------|----------------------| +| `Fly_MayStaticTypeInterface` | Stateless atoms (CopyOp with *no* mutable state; all MmaOps today) | **`isStatic`**, **`rebuildStaticValue`** | +| `Fly_CopyOpTypeInterface` | All CopyOps | **`getThrLayout`**, **`getThrBitLayoutSrc/Dst/Ref`**, **`emitAtomCall`** (mem + pred), *`emitAtomCallSSA`* (mem + pred — only if `fly-convert-atom-call-to-ssa-form` is in the pipeline) | +| `Fly_MmaOpTypeInterface` | All MmaOps | **`getThrLayout`**, **`getShapeMNK`**, **`getValTypeA/B/C/D`**, **`getThrValLayoutA/B/C`**, **`emitAtomCall`**, *`emitAtomCallSSA`* (only if SSA-promotion pass is active) | +| `Fly_StatefulOpTypeInterface` | Atoms that carry mutable per-call state (e.g. `soffset`, `imm_offset`) | **`getConvertedType`**, **`getDefaultState`**, **`setAtomState`** | + +Backend dialect could provide four convenience base classes that pre-declare the right interface +combinations. 
In the reference ROCDL backend (`include/flydsl/Dialect/FlyROCDL/IR/Dialect.td`) they +are: + +```tablegen +class FlyROCL_CopyOp // stateless CopyOp : MayStatic + CopyOp +class FlyROCL_StatefulCopyOp // stateful CopyOp : CopyOp + Stateful +class FlyROCL_MmaOp // stateless MmaOp : MayStatic + MmaOp +class FlyROCL_StatefulMmaOp // stateful MmaOp : MmaOp + Stateful +``` + +Mnemonic: **stateful => no `MayStaticTypeInterface`**; the mutable state *is* the dynamic component, +so the type is never "fully static" in the canonical-rebuild sense. + +### 1.3 `emitAtomCall` vs `emitAtomCallSSA` — only `emitAtomCall` is mandatory + +Two kernel-IR ops carry the atom invocation, and they correspond to the two interface methods: + +| Kernel Op | Operand form | Lowered via | Implementation status | +|--------------------------|-----------------------------------------------|-----------------------|-----------------------| +| `fly.copy_atom_call` | `src/dst : !fly.memref<...>` | `emitAtomCall` | **Required** | +| `fly.mma_atom_call` | `a/b/c/d : !fly.memref<...>` | `emitAtomCall` | **Required** | +| `fly.copy_atom_call_ssa` | `src/dst : SSA value or !fly.memref<..., addressSpace != Register>` | `emitAtomCallSSA` | **Optional** — only needed if `fly-convert-atom-call-to-ssa-form` appears in the pipeline | +| `fly.mma_atom_call_ssa` | `a/b/c : SSA value or !fly.memref<..., addressSpace != Register>` | `emitAtomCallSSA` | **Optional** (same condition) | + +**Default path (memref / `emitAtomCall`).** Every `fly.copy_atom_call` / `fly.mma_atom_call` in the +IR lowers through `emitAtomCall`. The Op receives the operand *pointers* into register memory +(`!fly.memref<..., register, layout>`), is expected to issue `llvm.load` / `llvm.store` itself to +read/write threads' registers, and emit the backend intrinsic in between. This is sufficient for the +full compile-to-binary pipeline — no SSA version required. + +**Optional path (SSA / `emitAtomCallSSA`).** A pipeline may insert the +`fly-convert-atom-call-to-ssa-form` pass (see +`lib/Dialect/Fly/Transforms/ConvertAtomCallToSSAForm.cpp`). That pass inspects every `AtomCall` and, +for operands whose `register`-address-space memref has a **coalescable** layout +(`isEligibleToPromote`: stride-1 or shape-1 after coalesce), rewrites them: + +1. `PtrLoadOp` pulls the whole register memref into a single SSA value of type + `RegMem2SSAType(memref)` — which is `elemTy` when the layout has cosize 1, or + `vector` otherwise (see `RegMem2SSAType` in `Fly/Utils/PointerUtils.cpp`). +2. The `AtomCall` is replaced with `AtomCallSSA`, taking those SSA values in place of pointers. +3. For output-producing cases, a `PtrStoreOp` writes the SSA result back to the original register + memref. + +At lowering time, `AtomCallSSA` dispatches to `emitAtomCallSSA` instead of `emitAtomCall`. The Op's +job there is **just the intrinsic + any required `LLVM::BitcastOp` between the SSA `vector<...>` and +the intrinsic's expected packed type** — no loads or stores because the SSA values already live in +registers. 
+ +**Concrete differences between the two methods:** + +| | `emitAtomCall` | `emitAtomCallSSA` | +|--------------------------|-------------------------------------------------|-------------------------------------------------| +| Operand kinds | `Value`s of type `!fly.memref<..., register>` (lowered to `!llvm.ptr`) | `Value`s of scalar / `vector` type | +| What the method does | `LLVM::LoadOp` to fetch operands → intrinsic → `LLVM::StoreOp` to write result | (optional bitcast to intrinsic's packed type) → intrinsic → return `Value` / `failure` | +| Return type | `LogicalResult` | `FailureOr` (the result SSA value, or `failure`) | +| Needs layout/cosize info | No — operand type already carries it | No — caller already packed operands into `vector` | +| Bitcast dance | Typically unnecessary (load yields the right type) | Often necessary (SSA vector width may not match intrinsic's expected operand width) | +| Backend intrinsic emitted | Same | Same | + +In practice every reference Op implements `emitAtomCall` as a thin shim over `emitAtomCallSSA` — +load operands, call `emitAtomCallSSA`, store the result. See `MmaOpCDNA3_MFMAType::emitAtomCall` in +`CDNA3/MmaAtom.cpp` for the canonical shim and `CopyOpCDNA3BufferAtomicType::emitAtomCall` in +`CDNA3/CopyAtom.cpp` for a CopyOp instance. **If your downstream pipeline never runs +`fly-convert-atom-call-to-ssa-form`, you may skip `emitAtomCallSSA` entirely and write a +self-contained `emitAtomCall`** — but the shim pattern is strictly better because it keeps the two +paths in sync for free. + +### 1.4 ThrVal layouts describe the per-thread register footprint + +Every MmaOp / CopyOp must publish layouts that describe *which thread holds which element* of the +tile. This is consumed by `TiledCopy` / `TiledMma` in the layout-lowering pass. + +| Method (MmaOp) | What it describes | +|------------------------|-------------------| +| `getThrLayout` | thread-count layout inside one thread group that issues the instruction (e.g. `(64):(1)` for an AMD wave64 MFMA, `(32):(1)` for AMD wave32 WMMA, `(1):(1)` for a single thread, `(128):(1)` for NVIDIA WGMMA issued by a warpgroup) | +| `getShapeMNK` | tuple `(M, N, K)` of the instruction tile | +| `getValTypeA/B/C/D` | per-operand element type | +| `getThrValLayoutA/B/C` | layout mapping `(thr, val)` → element coordinate in the **reference tile** (column-major `(M,K)` for A, `(N,K)` for B, `(M,N)` for C) | + +| Method (CopyOp) | What it describes | +|----------------------------|-------------------| +| `getThrLayout` | thread count participating in one atom call | +| `getThrBitLayoutSrc/Dst/Ref` | layout in **bit-granularity** — shape is `(num_threads, num_bits)` — one bit per leaf | + +The base `CopyAtomType::getThrValLayoutSrc()` then "recasts" the bit layout into a +`valBits`-granularity layout (see `CopyAtomType::getThrValLayout{Src,Dst,Ref}` in +`FlyTypeDefs.cpp`). This is why CopyOp types publish a **bit-layout** and MmaOp types publish a +**value-layout**: copies carry an extra `valBits` parameter on the wrapper, and one CopyOp type can +serve multiple element widths. + +Use the `FxLayout / FxShape / FxStride / FxThr / FxVal / FxC` macros from +`flydsl/Dialect/Fly/Utils/ThrValLayoutMacro.h.inc` — they're the auxiliary way to build these +`LayoutAttr`s. + +### 1.5 Critical checks for ThrVal / ThrBit layouts — read before writing any + +A wrong ThrVal/ThrBit layout is the #1 source of silent-wrong-result bugs in FlyDSL: the compiler +accepts it, the kernel runs, and the output is garbage. 
There are no good runtime diagnostics for +this. Before you commit any new `getThrValLayout*` / `getThrBitLayout*`, verify **every** rule below +on paper or in a scratch test. + +#### 1.5.1 Shape must be a top-level 2-tuple `((thr...), (val...))` + +Look at any existing example: `FxLayout(FxShape(FxThr(...), FxVal(...)), FxStride(FxThr(...), +FxVal(...)))`. The top-level shape has exactly **rank 2**: outer mode 0 is the thread axes, outer +mode 1 is the value axes. Each of these modes may itself be a nested tuple. + +This is not a style convention — it's load-bearing: `TiledOpUtils.h` unconditionally does +`shape.at(0)` / `shape.at(1)` / `stride.at(0)` / `stride.at(1)` to slice "thr" from "val". For +CopyOps, the same shape is passed to `layoutZippedDivide(tiledLayoutThrVal, atomTile)` where +`atomTile = (atomNumThr, atomNumVal)` is computed as the product of mode-0 and mode-1 (see +`detail::layoutTiledCopyThrValView` in `TiledOpUtils.h`). + +If you nest an extra level or flatten it to rank 1, the code compiles but silently reads +`thrShape=firstLeaf`, `valShape=secondLeaf`, and produces wrong tile divisions. + +#### 1.5.2 Shape-product invariants + +Let `|·|` denote "total number of elements". Then: + +| Op kind | Method | Must satisfy | +|----------|---------------------------|--------------| +| MmaOp | `getThrLayout` | `\|thr\|` == number of threads that cooperate on one instruction (e.g. 64 for AMD wave64 MFMA, 32 for AMD wave32 WMMA, 128 for NVIDIA SM90 WGMMA warpgroup) | +| MmaOp | `getThrValLayoutA` | `\|thr\| * \|val\|` == `M * K` | +| MmaOp | `getThrValLayoutB` | `\|thr\| * \|val\|` == `N * K` | +| MmaOp | `getThrValLayoutC` (and D) | `\|thr\| * \|val\|` == `M * N` | +| MmaOp | `\|thr\|` of ThrValLayout{A,B,C} | matches `\|thr\|` of `getThrLayout` | +| MmaOp | `\|val\|` of ThrValLayout | matches the thread's register vector width used in `emitAtomCallSSA` (e.g. `accVecSize` for C; `vecSize` of `abTyA` for A) | +| CopyOp | `getThrLayout` | `\|thr\|` == number of threads participating in one atom call (e.g. 1 for a per-thread load, 16 for AMD `ds_read_tr16_b64`) | +| CopyOp | `getThrBitLayoutSrc/Dst/Ref` | `\|val\|` == `bitSize` (the Op's `bitSize` parameter or per-atom constant). Shape is always `(|thr|, bitSize)`. | +| CopyOp | `\|thr\|` of ThrBitLayout{Src,Dst,Ref} | all three equal and equal to `\|thr\|` of `getThrLayout` | + +Violating any of these still compiles but yields undefined behavior. Thread-count mismatch is +especially insidious: a wave64 MFMA registered with `FxC(32)` (or a 32-thread NVIDIA warp MMA +registered with `FxC(16)`) will happily emit the intrinsic, but half the threads will compute on stale +registers. + +#### 1.5.3 Reference coordinate system is *column-major*, not row-major + +| Op | Operand | Reference tile | Column-major interpretation | +|----------|---------|----------------|-----------------------------| +| MmaOp | A | `(M, K)` | stride `(1, M)` is baseline | +| MmaOp | B | `(N, K)` | stride `(1, N)` is baseline | +| MmaOp | C, D | `(M, N)` | stride `(1, M)` is baseline | +| CopyOp | src/dst | `(M, N)` | stride `(1, M)` is baseline | + +#### 1.5.4 CopyOp bit-layout vs. value-layout — do not confuse them + +The interface publishes `getThrBitLayout*` (bit granularity); the `CopyAtomType` wrapper computes +`getThrValLayout*` by calling `layoutRecast(bitLayout, /*oldBits=*/1, /*newBits=*/valBits)` (see +`CopyAtomType::getThrValLayout{Src,Dst,Ref}` in `FlyTypeDefs.cpp`). 
+ +Consequences: +- A 32b buffer copy writes `FxShape(FxC(1), FxC(32))` for an f32 → the recast at `valBits=32` trivially + keeps it as `FxShape(FxC(1), FxC(1))` (1 f32 per thread). The same Op reused for a 16b copy of an + f16 pair gives `FxShape(FxC(1), FxC(2))` (2 f16 per thread), all automatically. +- If `bitSize` does not divide evenly by the downstream `valBits` (e.g. 96b / 32b f32 = 3 — fine; + 96b / 64b = 1.5 — broken), the `layoutRecast` path silently produces nonsense. It is programmer's + duty to find this mismatch behavior. + +### 1.6 Stateful atoms are lowered to an LLVM struct + +A stateful Op's state lives as an `!llvm.struct<(i32, i32, ...)>` at runtime. The three methods you +implement for `StatefulOpTypeInterface` wire this up: + +1. `getConvertedType(ctx)` — the concrete `!llvm.struct<...>` layout. +2. `getDefaultState(builder, loc)` — build an initial value (typically zero-initialized). +3. `setAtomState(builder, loc, struct, fieldAttr, fieldValue)` — field write. The `fieldAttr` is a + `StringAttr` that must be one of the `AtomStateField` enum mnemonics (`"soffset"`, + `"imm_offset"`). + +Then at lowering time the field is read back via `LLVM::ExtractValueOp` with the index returned by +your static `getFieldIndex(AtomStateField)` helper. See the `CopyOpCDNA3BufferCopyType` stateful +methods (`getFieldIndex` / `getConvertedType` / `getDefaultState` / `setAtomState`) in +`CDNA3/CopyAtom.cpp` for the full template. + +If your Op needs a new field kind (something other than `Soffset` or `ImmOffset`), extend the +`AtomStateField` enum in your backend's `Atom.td` — for the ROCDL reference backend that lives at +`include/flydsl/Dialect/FlyROCDL/IR/Atom.td` (this regenerates `AtomStateEnums.{h,cpp}.inc`). A new +backend should declare its own `AtomStateField` enum in its IR directory. + +--- + +## 2. The Files You Will Touch + +Below, `` stands for the backend dialect's name (e.g. `FlyROCDL` today; would be `FlyPTX`, +`FlySPIRV`, etc. for a new backend) and `` is the per-chip subdirectory (e.g. +`CDNA3`, `CDNA4`, `GFX1250` for ROCDL; would be `SM80`, `SM90`, `SM100` for a PTX port). + +| File | Purpose | +|------|---------| +| `include/flydsl/Dialect//IR/MmaAtom.td` or `CopyAtom.td` | TableGen declaration of the new type (`def _MmaOp_` / `def _CopyOp`) | +| `lib/Dialect///MmaAtom.cpp` or `CopyAtom.cpp` | Interface method implementations | +| `lib/Dialect//CMakeLists.txt` | Add the new `.cpp` to `MLIRDialect` | +| `lib/Bindings/Python/Extension.cpp` | Expose the type to Python so kernels can construct it | +| `python/flydsl/expr//.py` | DSL-level constructor wrappers (e.g. `MFMA(...)`, `WMMA(...)`, `BufferCopy(...)` for ROCDL; would be `WGMMA(...)`, `CpAsyncBulk(...)` etc. for PTX) | +| `include/flydsl/Dialect//IR/Atom.td` | (Only if adding a new `AtomStateField`) extend the enum | +| `tests/mlir/Conversion/.mlir` | A FileCheck test exercising the new lowering | + +Concrete instantiation for the reference ROCDL backend: +`include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td`, `lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp`, +`MLIRFlyROCDLDialect`, etc. + +--- + +## 3. Recipe: Add a Stateless MmaOp + +Examples use the ROCDL backend with a hypothetical `CDNA5_MFMA`. For other backends substitute +`FlyROCL_MmaOp` → your backend's base class, `CDNA` → your chip family (`SM80`, `SM90`, +`X86AVX512`, ...), `ROCDL::mfma_*` → your intrinsic ops (`NVVM::WgmmaMmaAsyncOp`, `vector.contract`, +...), `--convert-fly-to-rocdl` → your conversion pass. Everything else is identical. 
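Before diving into the individual steps, keep the end state in mind: once the payload type, Python bindings, and DSL wrapper exist, kernel authors only ever reach the new atom through the generic `fly` wrappers — the same path the existing CDNA3 atoms take in `kernels/preshuffle_gemm_v2.py`. A minimal Python sketch of that end state, assuming the hypothetical `MFMA_CDNA5` wrapper introduced in Step 7 below (the tile shape and permutation values are illustrative, not a real CDNA5 configuration):

```python
import flydsl.expr as fx
from flydsl.expr.typing import Float16

# Hypothetical CDNA5 wrapper from Step 7; everything downstream of make_mma_atom
# is the target-neutral fly API, identical to how the CDNA3 MFMA atoms are used today.
mma_atom = fx.make_mma_atom(fx.rocdl.MFMA_CDNA5(16, 16, 64, Float16))

# Spread the atom over 4 waves along N and permute K to match the atom's packing
# (values here are placeholders — derive them from your ThrVal layouts).
tiled_mma = fx.make_tiled_mma(
    mma_atom,
    fx.make_layout((1, 4, 1), (0, 1, 0)),
    fx.make_tile(None, None, fx.make_layout((8, 8), (1, 8))),
)
# fx.gemm(tiled_mma, frag_d, frag_a, frag_b, frag_c) then lowers through
# fly.mma_atom_call to the emitAtomCall / emitAtomCallSSA you implement below.
```
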
+ +**Step 1 — TableGen declaration** in `include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td`: + +```tablegen +def FlyROCDL_MmaOpCDNA5_MFMA : FlyROCL_MmaOp<"MmaOpCDNA5_MFMA", "cdna5.mfma", []> { + let parameters = (ins "int32_t":$m, "int32_t":$n, "int32_t":$k, + "Type":$elemTyA, "Type":$elemTyB, "Type":$elemTyAcc); + let assemblyFormat = "`<` custom($m, $n, $k) `,` " + "`(` $elemTyA `,` $elemTyB `)` `->` $elemTyAcc `>`"; + let builders = [TypeBuilderWithInferredContext<(ins ...), [{ + return $_get(elemTyA.getContext(), m, n, k, elemTyA, elemTyB, elemTyAcc); }]>]; + let genVerifyDecl = 1; +} +``` + +The `FlyROCL_MmaOp` base auto-adds `DeclareTypeInterfaceMethods` + +``. `MNKDimensionList` is the shared parser/printer pair +(`parseMNKDimensionList` / `printMNKDimensionList` in `lib/Dialect/Fly/IR/FlyDialect.cpp`) — reuse +it. + +**Step 2 — Interface methods** in `lib/Dialect/FlyROCDL/CDNA5/MmaAtom.cpp`. Clone +`lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp` as the template; the only non-boilerplate methods are: + +- `getThrLayout()` → `FxLayout(FxC(), FxC(1))` +- `getShapeMNK()` → `IntTupleAttr` of `(m, n, k)` +- `getValTypeA/B/C/D()` → operand element type (or packed vector type if the intrinsic expects it) +- `getThrValLayoutA/B/C()` → the layout that maps `(thr, val)` into the reference tile (A=(M,K), + B=(N,K), C=(M,N), all column-major). Satisfy every invariant in §1.5 before trusting it. + +**Step 3 — `verify`** (static, from `genVerifyDecl = 1`). Reject any `(m, n, k, elemTy)` tuple you +don't support, with a clear `emitError()` message; otherwise an invalid config silently hits +`return failure()` in `emitAtomCallSSA` with no diagnostic. + +**Step 4 — `emitAtomCallSSA`** (optional, only needed when `fly-convert-atom-call-to-ssa-form` is in +the pipeline; see §1.3). The only place you touch backend intrinsics. Pattern from +`MmaOpCDNA3_MFMAType::emitAtomCallSSA` in `CDNA3/MmaAtom.cpp`: derive the intrinsic's exact operand +types, `LLVM::BitcastOp` each SSA operand to match, then dispatch to lowered dialect ops. Find +intrinsic names in `llvm/include/llvm/IR/IntrinsicsAMDGPU.td` (ROCDL), `NVVMOps.td` (NVVM), or +`SPIRVOps.td` (SPIR-V). + +**Step 5 — `emitAtomCall`** (mandatory entry point; see §1.3). If Step 4 exists, this is a ~15-line +shim: + +```cpp +LogicalResult MmaOpCDNA5_MFMAType::emitAtomCall(OpBuilder &builder, Location loc, Type mmaAtomTy, + Type dMemTy, Type aMemTy, Type bMemTy, Type cMemTy, + Value atomVal, Value dPtr, Value aPtr, Value bPtr, Value cPtr) const { + // Derive abTyA, abTyB, accTy exactly as in emitAtomCallSSA. + Value a = LLVM::LoadOp::create(builder, loc, abTyA, aPtr); + Value b = LLVM::LoadOp::create(builder, loc, abTyB, bPtr); + Value c = LLVM::LoadOp::create(builder, loc, accTy, cPtr); + auto res = emitAtomCallSSA(builder, loc, accTy, mmaAtomTy, Type{}, + abTyA, abTyB, accTy, atomVal, Value{}, a, b, c); + if (failed(res)) return failure(); + LLVM::StoreOp::create(builder, loc, *res, dPtr); + return success(); +} +``` + +If you skipped Step 4, emit the intrinsic directly here. + +**Step 6 — CMake.** Add `CDNA5/MmaAtom.cpp` to the source list in +`lib/Dialect/FlyROCDL/CMakeLists.txt`. + +**Step 7 — Python bindings**. In `lib/Bindings/Python/FlyROCDLExtension.cpp`, add a +`PyMmaOpCDNA5_MFMAType : PyConcreteType<...>` following the existing `PyMmaOpCDNA3_MFMAType` +template, and register it in the `NB_MODULE(_mlirDialectsFlyROCDL, m)` block. Then add a thin +wrapper like `MFMA_CDNA5(m, n, k, elem, ...)` in `python/flydsl/expr/rocdl.py`. 
+ +**Step 8 — FileCheck test.** Clone `tests/mlir/Conversion/mma_atom.mlir` and swap the payload type: + +```mlir +// RUN: %fly-opt %s --fly-rewrite-func-signature --fly-canonicalize \ +// RUN: --fly-layout-lowering --convert-fly-to-rocdl | FileCheck %s +// CHECK: rocdl.mfma.f32.16x16x32f16 +``` + +**Step 9 — Build and verify.** `bash scripts/build.sh` rebuilds C++, bindings, and stubs. Run +FileCheck, then a 1-wave end-to-end Python kernels before trusting the layout. + +--- + +## 4. Recipe: Add a Stateful CopyOp + +Same skeleton as §3 but replacing `MayStaticTypeInterface` with `StatefulOpTypeInterface`. Clone +`CDNA3/CopyAtom.cpp` (`CopyOpCDNA3BufferCopy`) as the template. Stateful CopyOps model backend +concepts with per-call mutable state (AMD buffer descriptors, NVIDIA TMA descriptors, per-atom SM +offsets, ...). State always lowers to `!llvm.struct<...>` in rocdl backend, so only the field set +differs. + +**Step 1 — TableGen.** Pick `FlyROCL_StatefulCopyOp` as the base: + +```tablegen +def FlyROCDL_CopyOpCDNA5GlobalCopy + : FlyROCL_StatefulCopyOp<"CopyOpCDNA5GlobalCopy", "cdna5.global_copy", []> { + let parameters = (ins "int32_t":$bitSize); + let assemblyFormat = "`<` $bitSize `>`"; +} +``` + +**Step 2 — Stateful methods** (`getFieldIndex`, `getConvertedType`, `getDefaultState`, +`setAtomState`). Template directly from the stateful methods of `CopyOpCDNA3BufferCopyType` in +`CDNA3/CopyAtom.cpp`. Key points: + +- `getFieldIndex(AtomStateField)` is a static `switch` returning the struct field index. +- `getConvertedType(ctx)` returns the `LLVM::LLVMStructType::getLiteral(ctx, {i32, i32, ...})` + matching your state. +- `getDefaultState` builds an `UndefOp` then `InsertValueOp`s zero into each field. +- `setAtomState` must return `nullptr` on unrecognized fields (not fail-silently to `success`). + +**Step 3 — `getThrLayout` + `getThrBitLayoutSrc/Dst/Ref`** (all in **bit** granularity). For a +simple per-thread copy: `FxLayout(FxShape(FxC(1), FxC(getBitSize())), FxStride(FxC(1), FxC(1)))` for +Src/Dst/Ref. If Src ≠ Dst (e.g. LDS-read-transpose), Ref usually mirrors the register side — see +`CDNA4/CopyAtom.cpp`. All three layouts must satisfy the invariants in §1.5. + +**Step 4 — `emitAtomCallSSA`** (optional, see §1.3). Extract state fields with +`LLVM::ExtractValueOp`, then dispatch to the backend intrinsic. Pattern: the unpredicated +`CopyOpCDNA3BufferCopyType::emitAtomCallSSA` overload in `CDNA3/CopyAtom.cpp`. + +**Step 5 — Predicated SSA variant.** Wrap the unpredicated form in `scf::IfOp` — load side yields +`result` in `then` / old dst in `else`; store side uses a single-branch `scf.if`. Template: the +predicated `CopyOpCDNA3BufferCopyType::emitAtomCallSSA` overload (the one taking `Value pred`) in +`CDNA3/CopyAtom.cpp`. + +**Step 6 — memref-form `emitAtomCall` (mandatory + predicated).** `LLVM::LoadOp`/`StoreOp` shim +around the SSA form (or intrinsic dispatch directly if you skipped Steps 4-5). Template: the two +`CopyOpCDNA3BufferCopyType::emitAtomCall` overloads in `CDNA3/CopyAtom.cpp`. + +**Steps 7-9 — CMake / Python / test.** Identical to §3 Steps 6-8. Python wrapper follows +`CopyOpCDNA3BufferCopyType.get(bit_size)`. + +--- + +## 5. 
Adding a New `AtomStateField` (rare) + +If your stateful Op needs a field kind no existing Op uses (`"voffset"`, `"cpol"`, `"tensor_map"` +…), extend the enum in your backend's `Atom.td` (`include/flydsl/Dialect/FlyROCDL/IR/Atom.td` for +ROCDL): + +```tablegen +def FlyROCDL_AtomStateField : I32EnumAttr<"AtomStateField", "", [ + I32EnumAttrCase<"Soffset", 0, "soffset">, + I32EnumAttrCase<"ImmOffset", 1, "imm_offset">, + I32EnumAttrCase<"Voffset", 2, "voffset"> // <-- NEW +]> { let genSpecializedAttr = 0; let cppNamespace = FlyROCDL_Dialect.cppNamespace; } +``` + +Use it in your `getFieldIndex` switch. `fly.atom.set_value(%atom, "voffset", %val)` then works +automatically via `AtomSetValueOp`. + +--- + +### Recommended reading order + +Files (4) and (8) are target-neutral; the rest are ROCDL templates a new backend mirrors in its own +tree. + +1. `include/flydsl/Dialect/FlyROCDL/IR/Dialect.td` — base classes +2. `include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td` — MmaOp type decls +3. `include/flydsl/Dialect/FlyROCDL/IR/CopyAtom.td` — CopyOp type decls +4. `include/flydsl/Dialect/Fly/IR/FlyInterfaces.td` — **target-neutral** interface contracts +5. `lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp` — simplest stateless MmaOp +6. `lib/Dialect/FlyROCDL/CDNA3/CopyAtom.cpp` — all three CopyOp patterns +7. `lib/Dialect/Fly/IR/FlyTypeDefs.cpp` — **target-neutral** wrapper trampolines (see the + `CopyAtomType::*` and `MmaAtomType::*` method definitions) +8. `lib/Conversion/FlyToROCDL/FlyToROCDL.cpp` — `MakeCopyAtomOpLowering` / `MakeMmaAtomOpLowering` / + `AtomSetValueOpLowering` / `CopyAtomCallLowering` / `CopyAtomCallSSALowering` / + `MmaAtomCallLowering` / `MmaAtomCallSSALowering` (callers of your interface methods) + From 3f7b6b573f0a9fd6701c30fcf6c79999e54cf585 Mon Sep 17 00:00:00 2001 From: Felix Li Date: Wed, 22 Apr 2026 11:42:55 +0800 Subject: [PATCH 25/29] Port v2 gemm to main (#422) --- .../skills/flydsl-kernel-authoring/SKILL.md | 49 +- .claude/skills/port-to-layout-api/SKILL.md | 33 +- kernels/preshuffle_gemm.py | 48 +- kernels/preshuffle_gemm_v2.py | 507 ++++++++++++++++++ lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp | 8 +- tests/kernels/bench_preshuffle_gemm_v2.py | 411 ++++++++++++++ 6 files changed, 1021 insertions(+), 35 deletions(-) create mode 100644 kernels/preshuffle_gemm_v2.py create mode 100644 tests/kernels/bench_preshuffle_gemm_v2.py diff --git a/.claude/skills/flydsl-kernel-authoring/SKILL.md b/.claude/skills/flydsl-kernel-authoring/SKILL.md index 1cbb2498..de6be2a4 100644 --- a/.claude/skills/flydsl-kernel-authoring/SKILL.md +++ b/.claude/skills/flydsl-kernel-authoring/SKILL.md @@ -333,27 +333,52 @@ result = arith.select(cond, true_val, false_val) is_less = arith.cmpf(a, b, predicate="olt") # ordered less-than ``` -### Vector Arithmetic (IMPORTANT) -All arith ops (`addf`, `mulf`, `negf`, `maximumf`, `cmpf`, `select`) work on **both scalars and vectors**. -To broadcast a scalar to a vector, use `arith.constant_vector`: +### Internal Types: Vector and Numeric (PREFERRED) + +Use FlyDSL's internal typed system instead of raw MLIR ops. The `Vector` class wraps `vector` with operator overloading and type-safe methods. 
```python -from flydsl._mlir.ir import VectorType +from flydsl.expr.typing import Vector as Vec, Float32, Float16, BFloat16 + +# Wrap raw vector values +acc = Vec(frag_C.load()) # vector → Vector with * / + operators + +# Indexing (replaces vector.extract) +val = acc[idx] # returns Float32 scalar + +# Bitcast (replaces vector.bitcast) +v_f32 = Vec(raw_vec).bitcast(Float32) # vector → vector -# Create a splat constant vector (e.g., all 2.0) -vec_type = VectorType.get([vec_width], fx.T.f32()) -scale_vec = arith.constant_vector(2.0, vec_type) +# Type conversion (replaces arith.trunc_f / arith.ext_f) +bf16_val = f32_val.to(BFloat16) # f32 → bf16 -# Now use it with vector ops -vA = fx.memref_load_vec(rA) # load vec from register -vC = arith.mulf(vA, scale_vec) # element-wise scale +# Arithmetic — use Python operators, not arith.mulf/addf +result = (val * scale_a) * scale_b # auto-dispatches to mulf + +# Splat constant vector +zeros = Vec.filled(N, 0.0, Float32) + +# Index cast — use fx.Int32 instead of arith.index_cast +idx = fx.Int32(gpu.block_id("x") * tile_m) ``` +**Prefer internal types over raw ops:** +| Raw MLIR op | Internal type equivalent | +|-------------|------------------------| +| `vector.extract(v, static_position=[i], ...)` | `Vec(v)[i]` | +| `vector.bitcast(target_ty, v)` | `Vec(v).bitcast(Float32)` | +| `arith.trunc_f(ty, v)` | `v.to(BFloat16)` | +| `arith.mulf(a, b)` | `a * b` | +| `arith.addf(a, b)` | `a + b` | +| `arith.index_cast(T.i32, v)` | `fx.Int32(v)` | + +Still use `arith.constant_vector` for splat and `vector.from_elements` for building vectors from scalars (no Vector equivalent yet). + ### Arith Ops Availability Table | Operation | Function | Works on Vectors | Notes | |-----------|----------|-----------------|-------| -| Add | `arith.addf(a, b)` | Yes | | -| Multiply | `arith.mulf(a, b)` | Yes | | +| Add | `a + b` or `arith.addf(a, b)` | Yes | | +| Multiply | `a * b` or `arith.mulf(a, b)` | Yes | | | Negate | `arith.negf(a)` | Yes | | | Max | `arith.maximumf(a, b)` | Yes | Good for ReLU | | Compare | `arith.cmpf(a, b, pred)` | Yes | Returns i1/vec | diff --git a/.claude/skills/port-to-layout-api/SKILL.md b/.claude/skills/port-to-layout-api/SKILL.md index c9a4955d..9038e95d 100644 --- a/.claude/skills/port-to-layout-api/SKILL.md +++ b/.claude/skills/port-to-layout-api/SKILL.md @@ -31,7 +31,7 @@ Read the kernel and classify each buffer_load/buffer_store: | Pattern | Layout API Port | Example | |---------|----------------|---------| | Contiguous vec load along innermost dim | `make_buffer_tensor` + `BufferCopy128b` | Load 8xf16 from row | -| Scalar load (vec_width=1) of metadata | Keep as `buffer_ops.buffer_load` | Position/slot/mask loads | +| Scalar load (vec_width=1) | `make_buffer_tensor` + `BufferCopy32b`/`BufferCopy16b` | Scale/metadata loads | | Scattered store (non-contiguous layout) | Keep as `buffer_ops.buffer_store` | Non-flash value_cache | | Contiguous vec store along innermost dim | `make_buffer_tensor` + `BufferCopy` | Store 8xf16 to output | @@ -138,12 +138,33 @@ if is_valid: _store_vec(val, out_div, idx) ``` -### Step 6: Keep Scalar Accesses as buffer_ops +### Step 6: Scalar Loads via Layout API -Not everything should use the layout API. 
Keep `buffer_ops` for: -- Scalar metadata loads: `buffer_ops.buffer_load(rsrc, idx, vec_width=1, dtype=T.i32)` -- Scattered stores where elements are non-contiguous in memory -- Single-element stores (e.g., writing one scale value per block) +Scalar loads (vec_width=1) also work through the layout API: + +```python +buf = fx.rocdl.make_buffer_tensor(tensor, max_size=True) +copy_atom_s = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) # f32 scalar +scalar_reg_ty = fx.MemRefType.get(T.f32, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) +scalar_reg_lay = fx.make_layout(1, 1) +div = fx.logical_divide(buf, fx.make_layout(1, 1)) + +def load_scalar(index): + r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) + fx.copy_atom_call(copy_atom_s, fx.slice(div, (None, fx.Int32(index))), r) + return Vec(fx.memref_load_vec(r))[0] # extract scalar from vector<1xf32> +``` + +Scalar stores work the same way (reverse src/dst): +```python +def store_scalar(index, val): + r = fx.memref_alloca(scalar_reg_ty, scalar_reg_lay) + fx.memref_store_vec(Vec.filled(1, val, Float32), r) + fx.copy_atom_call(copy_atom_s, r, fx.slice(div, (None, fx.Int32(index)))) +``` + +Keep `buffer_ops` only for: +- Scattered stores where elements are truly non-contiguous in memory ### Step 7: Remove Dead Code diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index 5c88b9c3..557509f9 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -252,6 +252,7 @@ def compile_preshuffle_gemm_a8( _is_gfx950 = str(gpu_arch).startswith("gfx950") _is_gfx942 = str(gpu_arch).startswith("gfx942") + use_mfma_k32 = _is_gfx950 and is_f16_or_bf16 lds_stride_bytes = tile_k_bytes @@ -504,7 +505,7 @@ def _extract_b_packs(b16): b_i64x2 = vector.bitcast(T.i64x2, b16) b0_i64 = vector.extract(b_i64x2, static_position=[0], dynamic_position=[]) b1_i64 = vector.extract(b_i64x2, static_position=[1], dynamic_position=[]) - if not is_f16_or_bf16: + if not is_f16_or_bf16 or use_mfma_k32: return b0_i64, b1_i64 b0_v1 = vector.from_elements(T.vec(1, T.i64), [b0_i64]) b1_v1 = vector.from_elements(T.vec(1, T.i64), [b1_i64]) @@ -586,7 +587,7 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): a0_i64 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[]) a1_i64 = vector.extract(a_i64x2, static_position=[1], dynamic_position=[]) - if not is_f16_or_bf16: + if not is_f16_or_bf16 or use_mfma_k32: return a0_i64, a1_i64 a0_v1 = vector.from_elements(T.vec(1, T.i64), [a0_i64]) @@ -897,21 +898,33 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): return current_accs_list, scales_pf mfma_res_ty = T.i32x4 if is_int8 else T.f32x4 - if is_int8: - mfma_fn = mfma_i32_k32 - elif is_f16: - mfma_fn = rocdl.mfma_f32_16x16x16f16 - elif is_bf16: - mfma_fn = rocdl.mfma_f32_16x16x16bf16_1k + if use_mfma_k32: + mfma_fn_k32 = rocdl.mfma_f32_16x16x32_f16 if is_f16 else rocdl.mfma_f32_16x16x32_bf16 + + def i64x2_to_v8(lo, hi): + v2 = vector.from_elements(T.i64x2, [lo, hi]) + return vector.bitcast(T.f16x8 if is_f16 else T.bf16x8, v2) + + def mfma_k64_bytes(acc_in, a0, a1, b0, b1): + av = i64x2_to_v8(a0, a1) + bv = i64x2_to_v8(b0, b1) + return mfma_fn_k32(mfma_res_ty, [av, bv, acc_in, 0, 0, 0]) else: - mfma_fn = rocdl.mfma_f32_16x16x32_fp8_fp8 + if is_int8: + mfma_fn = mfma_i32_k32 + elif is_f16: + mfma_fn = rocdl.mfma_f32_16x16x16f16 + elif is_bf16: + mfma_fn = rocdl.mfma_f32_16x16x16bf16_1k + else: + mfma_fn = rocdl.mfma_f32_16x16x32_fp8_fp8 - def mfma_step(acc_in, a, b): - return mfma_fn(mfma_res_ty, [a, b, acc_in, 0, 0, 0]) + def 
mfma_step(acc_in, a, b): + return mfma_fn(mfma_res_ty, [a, b, acc_in, 0, 0, 0]) - def mfma_k64_bytes(acc_in, a0, a1, b0, b1): - acc_mid = mfma_step(acc_in, a0, b0) - return mfma_step(acc_mid, a1, b1) + def mfma_k64_bytes(acc_in, a0, a1, b0, b1): + acc_mid = mfma_step(acc_in, a0, b0) + return mfma_step(acc_mid, a1, b1) for ku in range_constexpr(k_unroll): b_packs0, b_packs1 = b_tile_in[ku] @@ -1154,7 +1167,12 @@ def _build_scheduler(numer: int, denom: int): rocdl.sched_dswr(1) else: mfma_group = num_acc_n - element_k_per_mfma = 128 if _is_gfx950 else 32 + if use_mfma_k32: + element_k_per_mfma = 32 + elif _is_gfx950: + element_k_per_mfma = 128 + else: + element_k_per_mfma = 32 num_mfma_per_tile_k = tile_k // element_k_per_mfma mfma_total = num_mfma_per_tile_k * m_repeat * mfma_group num_ds_load = num_a_lds_load diff --git a/kernels/preshuffle_gemm_v2.py b/kernels/preshuffle_gemm_v2.py new file mode 100644 index 00000000..14e2cbec --- /dev/null +++ b/kernels/preshuffle_gemm_v2.py @@ -0,0 +1,507 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Preshuffle GEMM kernel — Layout API version. + +Supports f16, bf16, fp8 via layout API (fx.copy + fx.gemm). +Uses scf.for tile loop with ping-pong double buffer (2-stage B). +Includes hot_loop_scheduler from the old pipeline for instruction scheduling. +""" + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl.expr import vector, gpu, rocdl, range_constexpr +from flydsl.expr.typing import T, Float16, Float32, BFloat16, Float8E4M3FNUZ, Float8E4M3FN, Int8, Vector as Vec +from flydsl.compiler.kernel_function import CompilationContext +from flydsl.runtime.device import get_rocm_arch +from flydsl._mlir import ir +from kernels.preshuffle_gemm import _get_preload +from typing import Optional + + +def compile_preshuffle_gemm_v2( + *, + N: int, + K: int, + tile_m: int, + tile_n: int, + tile_k: int, + in_dtype: str = "fp8", + out_dtype: str = "bf16", + waves_per_eu: Optional[int] = None, + enable_scheduler: bool = True, +): + """Compile preshuffle GEMM using the layout API. + + Supports in_dtype: fp8, fp16, bf16. + Returns a JitFunction: fn(C, A, B, scale_a, scale_b, M, N, stream). 
+ """ + if in_dtype not in ("fp8", "fp16", "bf16"): + raise ValueError(f"in_dtype must be fp8/fp16/bf16, got {in_dtype!r}") + + is_fp8 = in_dtype == "fp8" + is_f16 = in_dtype == "fp16" + is_bf16 = in_dtype == "bf16" + is_f16_or_bf16 = is_f16 or is_bf16 + out_is_bf16 = out_dtype == "bf16" + elem_bytes = 1 if is_fp8 else 2 + + gpu_arch = get_rocm_arch() + is_gfx942 = str(gpu_arch).startswith("gfx942") + is_gfx950 = str(gpu_arch).startswith("gfx950") + # TODO: enable when CDNA4 MFMA_Scale works through layout API (fly.mma_atom_call) + use_mfma_scale_128 = False # is_fp8 and is_gfx950 + use_mfma_k32 = is_f16_or_bf16 and is_gfx950 + if use_mfma_scale_128: + if tile_k % 128 != 0: + raise ValueError(f"tile_k must be divisible by 128 for gfx950 fp8, got {tile_k}") + + if is_f16: + layout_elem = Float16 + elif is_bf16: + layout_elem = BFloat16 + elif is_gfx950: + layout_elem = Float8E4M3FN + else: + layout_elem = Float8E4M3FNUZ + + out_elem_cls = BFloat16 if out_is_bf16 else Float16 + + # Tile geometry + # k_perm groups atoms: 32 for f16/bf16 K=16 (2 atoms), 32 for K=32 (1 atom), + # 128 for gfx950 fp8 (1×K=128), 64 for gfx942 fp8 (2×K=32) + tile_K_perm = 128 if use_mfma_scale_128 else (64 if is_fp8 else 32) + k_iters = tile_k // tile_K_perm + num_tiles = K // tile_k + m_repeat = tile_m // 16 + num_waves = 4 + n_per_wave = tile_n // num_waves + num_acc_n = n_per_wave // 16 + n_accs = m_repeat * num_acc_n + acc_size = n_accs * 4 + + # LDS: ping + pong + smem_bytes = tile_m * tile_k * elem_bytes * 2 + + total_threads = 256 + a_load_bytes = 16 + bytes_per_thread_a = (tile_m * tile_k * elem_bytes) // total_threads + num_a_loads = bytes_per_thread_a // a_load_bytes + num_b_loads = (tile_n * tile_k * elem_bytes) // total_threads // 16 + num_ds_load = (tile_m * tile_k * elem_bytes) // 64 // 16 # A LDS reads per wave + num_gmem_loads = num_a_loads + num_b_loads + if is_fp8 and is_gfx950: + dsrd_preload, dvmem_preload = _get_preload(tile_m, tile_n, tile_k) + else: + dsrd_preload, dvmem_preload = (0, 0) + + # ── Kernel ──────────────────────────────────────────────────────── + @flyc.kernel + def kernel_gemm( + arg_c: fx.Tensor, + arg_a: fx.Tensor, + arg_b: fx.Tensor, + arg_scale_a: fx.Tensor, + arg_scale_b: fx.Tensor, + i32_m: fx.Int32, + i32_n: fx.Int32, + tiled_mma: fx.TiledMma, + tiled_copy_g2s: fx.TiledCopy, + ): + tid = fx.thread_idx.x + bid_x, bid_y, _ = fx.block_idx + + gA = fx.rocdl.make_buffer_tensor(arg_a) + gB = fx.rocdl.make_buffer_tensor(arg_b) + gC = fx.rocdl.make_buffer_tensor(arg_c) + + tA = fx.flat_divide(gA, fx.make_tile(tile_m, tile_k))[None, None, bid_x, None] + tB = fx.flat_divide(gB, fx.make_tile(tile_n, tile_k))[None, None, bid_y, None] + tC = fx.flat_divide(gC, fx.make_tile(tile_m, tile_n))[None, None, bid_x, bid_y] + + # Copy atoms: 128b for all dtypes (matches old path's buffer_load_dwordx4 / ds_read_b128) + mma_copy = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), layout_elem) + mma_uni = fx.make_copy_atom(fx.UniversalCopy128b(), layout_elem) + buf_copy_g2s = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), layout_elem) + uni_copy_g2s = fx.make_copy_atom(fx.UniversalCopy128b(), layout_elem) + + # Per-thread slices + thr_mma = tiled_mma.thr_slice(tid) + thr_g2s = tiled_copy_g2s.get_slice(tid) + thr_s2r = fx.make_tiled_copy_A(mma_copy, tiled_mma).get_slice(tid) + thr_g2r_B = fx.make_tiled_copy_B(mma_copy, tiled_mma).get_slice(tid) + + # LDS: XOR swizzle for f16/bf16 to avoid bank conflicts, identity for fp8 + smem_ptr = fx.recast_iter( + fx.PointerType.get(layout_elem.ir_type, 
fx.AddressSpace.Shared, 512), + fx.get_dyn_shared(), + ) + if is_fp8: + sA = fx.make_view(smem_ptr, + fx.make_ordered_layout((tile_m, tile_k, 2), (1, 0, 2))) + else: + swz = fx.SwizzleType.get(3, 3, 3) + sA = fx.make_view(smem_ptr, fx.make_composed_layout( + fx.static(swz), + fx.make_ordered_layout((tile_m, tile_k, 2), (1, 0, 2)), + )) + + # Partitions + pA_g = thr_g2s.partition_S(tA) + pA_s = thr_g2s.partition_D(sA) + pA_s2r = thr_s2r.partition_S(sA) + pB_g = thr_g2r_B.partition_S(tB) + + # Fragments — 2 separate B fragments (split double buffer for VGPR lifetime) + frag_copy_A = fx.make_fragment_like(pA_s[None, None, None, 0]) + frag_A = thr_mma.make_fragment_A(sA[None, None, 0]) + frag_B_single_layout = thr_mma.partition_B(tB).layout(None, None, None, 0) + frag_B_0 = fx.make_fragment_like(frag_B_single_layout, layout_elem.ir_type) + frag_B_1 = fx.make_fragment_like(frag_B_single_layout, layout_elem.ir_type) + frag_B_stages = [frag_B_0, frag_B_1] + frag_C = thr_mma.make_fragment_C(tC) + frag_A_retile = thr_s2r.retile(frag_A) + frag_B_0_retile = thr_g2r_B.retile(frag_B_0) + frag_B_1_retile = thr_g2r_B.retile(frag_B_1) + frag_B_retile_stages = [frag_B_0_retile, frag_B_1_retile] + buf_copy_out = fx.make_copy_atom(fx.rocdl.BufferCopy16b(), out_elem_cls) + thr_r2g_C = fx.make_tiled_copy_C(buf_copy_out, tiled_mma).get_slice(tid) + pC_g = thr_r2g_C.partition_S(tC) + frag_C_out = fx.make_fragment_like(frag_C, out_elem_cls.ir_type) + frag_C_retile = thr_r2g_C.retile(frag_C_out) + + # ── Scheduling hints (ported from old pipeline) ─────────── + def build_scheduler(numer: int, denom: int): + if denom <= 0: + return [] + if numer <= 0: + return [0] * denom + out = [] + prev = 0 + for i in range_constexpr(denom): + cur = ((i + 1) * numer + (denom - 1)) // denom + out.append(cur - prev) + prev = cur + return out + + def hot_loop_scheduler(): + mfma_group = num_acc_n + + if is_gfx942: + mfma_total = (k_iters * 2) * m_repeat * mfma_group + mfma_per_iter = 2 * mfma_group + sche_iters = 0 if mfma_per_iter == 0 else (mfma_total // mfma_per_iter) + + rocdl.sched_dsrd(2) + rocdl.sched_mfma(1) + if tile_m == 16: + rocdl.sched_vmem(1) + rocdl.sched_mfma(1) + if tile_m == 16: + rocdl.sched_vmem(1) + + if num_acc_n < 4: + rocdl.sched_dsrd(1) + rocdl.sched_mfma(1) + if tile_m == 16: + rocdl.sched_vmem(1) + rocdl.sched_dsrd(1) + rocdl.sched_mfma(1) + if tile_m == 16: + rocdl.sched_vmem(1) + rocdl.sched_mfma(1) + + dswr_tail = num_a_loads + dstr_advance = 2 + if dswr_tail > sche_iters: + dswr_tail = sche_iters + dswr_start = max(sche_iters - dswr_tail - dstr_advance, 0) + + for sche_i in range_constexpr(sche_iters): + rocdl.sched_vmem(1) + rocdl.sched_mfma(mfma_group) + rocdl.sched_dsrd(1) + rocdl.sched_mfma(mfma_group) + if sche_i >= dswr_start - 1: + rocdl.sched_dswr(1) + else: + # gfx950 path: distribute vmem/dsrd across MFMA slots + if use_mfma_k32: + element_k_per_mfma = 32 + elif is_fp8: + element_k_per_mfma = 128 # mfma_scale_f32_16x16x128 + else: + element_k_per_mfma = 16 + num_mfma_per_tile_k = tile_k // element_k_per_mfma + mfma_total = num_mfma_per_tile_k * m_repeat * mfma_group + dswr_tail = num_a_loads + dstr_advance = 2 + if dswr_tail > mfma_total: + dswr_tail = mfma_total + dsrd_preload_eff = min(int(dsrd_preload), num_ds_load) + dvmem_preload_eff = min(int(dvmem_preload), num_gmem_loads) + vmem_remaining = num_gmem_loads - dvmem_preload_eff + dsrd_remaining = num_ds_load - dsrd_preload_eff + if vmem_remaining > 0 and vmem_remaining < mfma_total: + vmem_schedule = (build_scheduler(vmem_remaining, 
vmem_remaining) + + [0] * (mfma_total - vmem_remaining)) + else: + vmem_schedule = build_scheduler(vmem_remaining, mfma_total) + dsrd_schedule = build_scheduler(dsrd_remaining, mfma_total) + dswr_start = max(mfma_total - dswr_tail - dstr_advance, 0) + last_dsrd_mfma_idx = -1 + for sched_idx in range_constexpr(mfma_total): + if dsrd_schedule[sched_idx]: + last_dsrd_mfma_idx = sched_idx + dswr_start = max(dswr_start, last_dsrd_mfma_idx + 1) + idx_ds_read = dsrd_preload_eff + idx_gmem_load = dvmem_preload_eff + idx_ds_write = 0 + if dvmem_preload_eff: + rocdl.sched_vmem(dvmem_preload_eff) + if dsrd_preload_eff: + rocdl.sched_dsrd(dsrd_preload_eff) + for mfma_idx in range_constexpr(mfma_total): + rocdl.sched_mfma(1) + n_dsrd = dsrd_schedule[mfma_idx] + if n_dsrd and (idx_ds_read < num_ds_load): + if idx_ds_read + n_dsrd > num_ds_load: + n_dsrd = num_ds_load - idx_ds_read + if n_dsrd: + rocdl.sched_dsrd(n_dsrd) + idx_ds_read += n_dsrd + n_vmem = vmem_schedule[mfma_idx] + if n_vmem and (idx_gmem_load < num_gmem_loads): + if idx_gmem_load + n_vmem > num_gmem_loads: + n_vmem = num_gmem_loads - idx_gmem_load + if n_vmem: + rocdl.sched_vmem(n_vmem) + idx_gmem_load += n_vmem + if (idx_ds_write < dswr_tail) and (mfma_idx >= dswr_start): + rocdl.sched_dswr(1) + idx_ds_write += 1 + if idx_ds_write < num_a_loads: + rocdl.sched_dswr(num_a_loads - idx_ds_write) + + rocdl.sched_barrier(0) + + # ── Pipeline stage (double-buffered B via split fragments) ─ + def pipeline_stage(read_stage, next_k_val=None, read_next=True): + write_stage = read_stage ^ 1 + cur_frag_B = frag_B_stages[read_stage] + # 1. Prefetch next A tile (global → register) + if read_next and next_k_val is not None: + fx.copy(buf_copy_g2s, pA_g[None, None, None, next_k_val], frag_copy_A) + # 2. Load next B tile (before compute — matches v1 pipeline order, + # all vmem available for scheduler interleaving with MFMAs) + if read_next and next_k_val is not None: + fx.copy(mma_copy, pB_g[None, None, None, next_k_val], + frag_B_retile_stages[write_stage]) + # 3. Compute: A from LDS + MFMA with current B + for ki in range_constexpr(k_iters): + fx.copy(mma_uni, pA_s2r[None, None, ki, read_stage], + frag_A_retile[None, None, ki]) + # K=128 or K=32 (1 atom): frag K dim is flat k_iters → coord = ki + # K=16 gfx942 (2 atoms): frag K dim is (atoms, k_iters) → coord = (None, ki) + k_coord = ki if (use_mfma_scale_128 or use_mfma_k32) else (None, ki) + fx.gemm(tiled_mma, frag_C, + frag_A[None, None, k_coord], + cur_frag_B[None, None, k_coord], + frag_C) + # 4. 
Write A tile to LDS + barrier + fx.copy(uni_copy_g2s, frag_copy_A, pA_s[None, None, None, write_stage]) + if enable_scheduler: + hot_loop_scheduler() + gpu.barrier() + + # ── Prologue ────────────────────────────────────────────── + fx.copy(buf_copy_g2s, pA_g[None, None, None, 0], frag_copy_A) + fx.copy(mma_copy, pB_g[None, None, None, 0], frag_B_retile_stages[0]) + frag_C.store(Vec.filled(acc_size, 0.0, Float32)) + fx.copy(uni_copy_g2s, frag_copy_A, pA_s[None, None, None, 0]) + gpu.barrier() + rocdl.sched_barrier(0) + + # ── Main tile loop (scf.for with ping-pong) ────────────── + if num_tiles == 1: + pipeline_stage(read_stage=0, read_next=False) + elif num_tiles == 2: + pipeline_stage(read_stage=0, next_k_val=fx.Int32(1)) + pipeline_stage(read_stage=1, read_next=False) + else: + loop_start = fx.Index(0) + loop_end = fx.Index((num_tiles - 2) // 2) + loop_step = fx.Index(1) + # Loop-carried values: + # bf16/f16: acc + B stage 0 (B alloca types don't match for SROA) + # fp8: acc only (B alloca has uniform i64 types → SROA promotes it) + acc_init = frag_C.load() + if is_fp8: + for iv, state in range(loop_start, loop_end, loop_step, + init=[acc_init]): + frag_C.store(state[0]) + k_base = fx.Int32(iv * 2) + pipeline_stage(read_stage=0, next_k_val=k_base + fx.Int32(1)) + pipeline_stage(read_stage=1, next_k_val=k_base + fx.Int32(2)) + results = yield [frag_C.load()] + frag_C.store(results) + else: + b0_init = frag_B_stages[0].load() + for iv, state in range(loop_start, loop_end, loop_step, + init=[acc_init, b0_init]): + frag_C.store(state[0]) + frag_B_stages[0].store(state[1]) + k_base = fx.Int32(iv * 2) + pipeline_stage(read_stage=0, next_k_val=k_base + fx.Int32(1)) + pipeline_stage(read_stage=1, next_k_val=k_base + fx.Int32(2)) + results = yield [frag_C.load(), frag_B_stages[0].load()] + frag_C.store(results[0]) + frag_B_stages[0].store(results[1]) + pipeline_stage(read_stage=0, next_k_val=fx.Int32(num_tiles - 1)) + pipeline_stage(read_stage=1, read_next=False) + + # ── Epilogue ───────────────────────────────────────────── + if is_fp8: + # FP8: inline scale multiply via layout API buffer loads + # Accumulator layout: [mi*num_acc_n*4 + ni*4 + ii] + # scale_a depends on row (mi, ii), scale_b depends on col (ni) + bx_m = gpu.block_id("x") * tile_m + by_n = gpu.block_id("y") * tile_n + wave_id = gpu.thread_id("x") // 64 + lane_id = gpu.thread_id("x") % 64 + lane_div_16 = lane_id // 16 + lane_mod_16 = lane_id % 16 + n_tile_base = wave_id * n_per_wave + + # Scale buffer tensors + scalar copy atom + scale_a_buf = fx.rocdl.make_buffer_tensor(arg_scale_a, max_size=True) + scale_b_buf = fx.rocdl.make_buffer_tensor(arg_scale_b, max_size=True) + scale_copy = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) + scale_reg_ty = fx.MemRefType.get(T.f32, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) + scale_reg_lay = fx.make_layout(1, 1) + scale_a_div = fx.logical_divide(scale_a_buf, fx.make_layout(1, 1)) + scale_b_div = fx.logical_divide(scale_b_buf, fx.make_layout(1, 1)) + + def load_scale(div_tensor, index): + r = fx.memref_alloca(scale_reg_ty, scale_reg_lay) + fx.copy_atom_call(scale_copy, fx.slice(div_tensor, (None, fx.Int32(index))), r) + return Vec(fx.memref_load_vec(r))[0] + + # Load per-column scales: 1 scalar per N-block + s_b_vals = [load_scale(scale_b_div, by_n + n_tile_base + ni * 16 + lane_mod_16) + for ni in range_constexpr(num_acc_n)] + # Load per-row scales: 1 scalar per row per thread + s_a_vals = [[load_scale(scale_a_div, bx_m + mi * 16 + lane_div_16 * 4 + ii) + for ii in 
range_constexpr(4)] + for mi in range_constexpr(m_repeat)] + + # Build scaled accumulator inline + acc_vec = Vec(frag_C.load()) + scaled_elems = [] + for mi in range_constexpr(m_repeat): + for ni in range_constexpr(num_acc_n): + for ii in range_constexpr(4): + idx = mi * num_acc_n * 4 + ni * 4 + ii + val = acc_vec[idx] + s_a = s_a_vals[mi][ii] + scaled_val = (val * s_a) * s_b_vals[ni] + scaled_elems.append(scaled_val.to(out_elem_cls)) + + out_vec = vector.from_elements( + T.vec(acc_size, out_elem_cls.ir_type), scaled_elems) + frag_C_out.store(out_vec) + fx.copy(buf_copy_out, frag_C_retile, pC_g) + else: + # f16/bf16: truncate + vectorized fx.copy + frag_C_out.store(Vec(frag_C.load()).to(out_elem_cls)) + fx.copy(buf_copy_out, frag_C_retile, pC_g) + + # ── Host launcher ───────────────────────────────────────────── + @flyc.jit + def launch_gemm( + arg_c: fx.Tensor, + arg_a: fx.Tensor, + arg_b: fx.Tensor, + arg_scale_a: fx.Tensor, + arg_scale_b: fx.Tensor, + i32_m: fx.Int32, + i32_n: fx.Int32, + stream: fx.Stream, + ): + ctx = CompilationContext.get_current() + + # MMA atom — layout_elem carries the dtype (Float16/BFloat16/Float8E4M3FN/etc) + if use_mfma_k32: + mma_atom = fx.make_mma_atom(fx.rocdl.MFMA(16, 16, 32, layout_elem)) + k_perm = fx.make_layout((8, 4), (1, 8)) + elif is_f16_or_bf16: + mma_atom = fx.make_mma_atom(fx.rocdl.MFMA(16, 16, 16, layout_elem)) + k_perm = fx.make_layout((4, 4, 2), (1, 8, 4)) + elif use_mfma_scale_128: + mma_atom = fx.make_mma_atom(fx.rocdl.cdna4.MFMA_Scale(16, 16, 128, layout_elem)) + k_perm = fx.make_layout((32, 4), (1, 32)) + else: + mma_atom = fx.make_mma_atom(fx.rocdl.MFMA(16, 16, 32, layout_elem)) + k_perm = fx.make_layout((8, 4, 2), (1, 16, 8)) + + tiled_mma = fx.make_tiled_mma( + mma_atom, fx.make_layout((1, 4, 1), (0, 1, 0)), + fx.make_tile(None, None, k_perm), + ) + + # G2S tiled copy + val_per_thr = a_load_bytes // elem_bytes + thrs_k = tile_k // val_per_thr + thrs_m = total_threads // thrs_k + tiled_copy_g2s = fx.make_tiled_copy( + fx.make_copy_atom(fx.UniversalCopy128b(), layout_elem), + fx.make_layout( + ((thrs_k, thrs_m), (1, val_per_thr)), + ((thrs_m * val_per_thr, 1), (1, thrs_m)), + ), + fx.make_tile(thrs_m, tile_k), + ) + + # Preshuffle B layout (2D hierarchical) + kp_bytes = 16 + kp_elems = kp_bytes if elem_bytes == 1 else kp_bytes // elem_bytes + k_bytes_b = K * elem_bytes + n0 = N // 16 + k0 = k_bytes_b // 64 + s_nlane = kp_elems + s_klane = 16 * s_nlane + s_k0 = 4 * s_klane + s_n0 = k0 * s_k0 + preshuffle_B = fx.Tensor(fx.make_view( + fx.get_iter(arg_b), + fx.make_layout(((16, n0), (kp_elems, 4, k0)), + ((s_nlane, s_n0), (1, s_klane, s_k0))), + )) + + # Reshape A and C to 2D + M_max = 65536 + arg_a_2d = fx.Tensor(fx.make_view( + fx.get_iter(arg_a), fx.make_layout((M_max, K), (K, 1)), + )) + arg_c_2d = fx.Tensor(fx.make_view( + fx.get_iter(arg_c), fx.make_layout((M_max, N), (N, 1)), + )) + + gx = (i32_m + (tile_m - 1)) // tile_m + gy = i32_n // tile_n + + launcher = kernel_gemm( + arg_c_2d, arg_a_2d, preshuffle_B, + arg_scale_a, arg_scale_b, i32_m, i32_n, + tiled_mma, tiled_copy_g2s, + ) + if waves_per_eu is not None and int(waves_per_eu) >= 1: + for op in ctx.gpu_module_body.operations: + if hasattr(op, 'attributes') and op.OPERATION_NAME == "gpu.func": + op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get( + T.i32, int(waves_per_eu)) + launcher.launch( + grid=(gx, gy, 1), block=(256, 1, 1), smem=smem_bytes, stream=stream, + ) + + return launch_gemm diff --git a/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp 
b/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp index 13d2a9de..baca1f29 100644 --- a/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp +++ b/lib/Dialect/FlyROCDL/CDNA3/MmaAtom.cpp @@ -84,7 +84,7 @@ LogicalResult MmaOpCDNA3_MFMAType::verify(function_ref emi auto isValidElemType = [](Type ty) { return ty.isF16() || ty.isBF16() || ty.isF32() || isa(ty) || - isa(ty); + isa(ty) || isa(ty); }; if (!isValidElemType(elemTyA)) { return emitError() << "elemTyA must be f16, bf16, f32, f8E4M3FNUZ, f8E5M2FNUZ, got " << elemTyA; @@ -95,7 +95,7 @@ LogicalResult MmaOpCDNA3_MFMAType::verify(function_ref emi return success(); } -static bool isFP8(Type ty) { return isa(ty); } +static bool isFP8(Type ty) { return isa(ty) || isa(ty); } static bool isBF8(Type ty) { return isa(ty); } static Type getMfmaABType(MLIRContext *ctx, Type elemTy, int32_t mn, int32_t k = 0) { @@ -175,6 +175,8 @@ FailureOr MmaOpCDNA3_MFMAType::emitAtomCallSSA(OpBuilder &builder, Locati DISPATCH_MFMA_SSA(4, 4, elemTyA.isF16(), mfma_f32_4x4x4f16) DISPATCH_MFMA_SSA(32, 8, elemTyA.isF16(), mfma_f32_32x32x8f16) DISPATCH_MFMA_SSA(16, 16, elemTyA.isF16(), mfma_f32_16x16x16f16) + DISPATCH_MFMA_SSA(16, 32, elemTyA.isF16(), mfma_f32_16x16x32_f16) + DISPATCH_MFMA_SSA(32, 16, elemTyA.isF16(), mfma_f32_32x32x16_f16) DISPATCH_MFMA_SSA(32, 2, elemTyA.isBF16(), mfma_f32_32x32x2bf16) DISPATCH_MFMA_SSA(16, 2, elemTyA.isBF16(), mfma_f32_16x16x2bf16) @@ -182,6 +184,8 @@ FailureOr MmaOpCDNA3_MFMAType::emitAtomCallSSA(OpBuilder &builder, Locati DISPATCH_MFMA_SSA(32, 4, elemTyA.isBF16(), mfma_f32_32x32x4bf16) DISPATCH_MFMA_SSA(16, 8, elemTyA.isBF16(), mfma_f32_16x16x8bf16) DISPATCH_MFMA_SSA(16, 16, elemTyA.isBF16(), mfma_f32_16x16x16bf16_1k) + DISPATCH_MFMA_SSA(16, 32, elemTyA.isBF16(), mfma_f32_16x16x32_bf16) + DISPATCH_MFMA_SSA(32, 16, elemTyA.isBF16(), mfma_f32_32x32x16_bf16) DISPATCH_MFMA_SSA(16, 32, isFP8(elemTyA) && isFP8(elemTyB), mfma_f32_16x16x32_fp8_fp8) DISPATCH_MFMA_SSA(16, 32, isFP8(elemTyA) && isBF8(elemTyB), mfma_f32_16x16x32_fp8_bf8) diff --git a/tests/kernels/bench_preshuffle_gemm_v2.py b/tests/kernels/bench_preshuffle_gemm_v2.py new file mode 100644 index 00000000..8ace8cdc --- /dev/null +++ b/tests/kernels/bench_preshuffle_gemm_v2.py @@ -0,0 +1,411 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Benchmark: preshuffle_gemm_v2 (layout API) vs old preshuffle_gemm. 
+ +Usage: + # Run all default configs + PYTHONPATH=./ python tests/kernels/bench_preshuffle_gemm_v2.py + + # Specific dtype + PYTHONPATH=./ python tests/kernels/bench_preshuffle_gemm_v2.py --dtype fp16 + PYTHONPATH=./ python tests/kernels/bench_preshuffle_gemm_v2.py --dtype bf16 + PYTHONPATH=./ python tests/kernels/bench_preshuffle_gemm_v2.py --dtype fp8 + + # Custom shape + PYTHONPATH=./ python tests/kernels/bench_preshuffle_gemm_v2.py --dtype fp16 -M 5120 -N 5120 -K 8192 --tile_m 64 --tile_n 128 --tile_k 64 + + # All tiles sweep for a given shape + PYTHONPATH=./ python tests/kernels/bench_preshuffle_gemm_v2.py --dtype fp16 -M 128 -N 5120 -K 8192 --sweep +""" + +import os +import sys +import argparse + +os.environ.setdefault("FLYDSL_RUNTIME_ENABLE_CACHE", "0") + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +import torch +import flydsl.compiler as flyc +from kernels.preshuffle_gemm_v2 import compile_preshuffle_gemm_v2 +from kernels.preshuffle_gemm import compile_preshuffle_gemm_a8 +from tests.utils import pertoken_quant, shuffle_weight +from flydsl.runtime.device import get_rocm_arch + +ARCH = str(get_rocm_arch()) +DTYPE_FP8 = torch.float8_e4m3fn if "gfx95" in ARCH else torch.float8_e4m3fnuz +DEVICE = torch.device("cuda") + + +def _bench_kernel(compiled_fn, args, warmup=5, iters=20): + for _ in range(warmup): + compiled_fn(*args) + torch.cuda.synchronize() + s = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + s.record() + for _ in range(iters): + compiled_fn(*args) + e.record() + torch.cuda.synchronize() + return s.elapsed_time(e) * 1000 / iters # us + + +def _make_data(M, N, K, in_dtype): + is_fp = in_dtype in ("fp16", "bf16") + if is_fp: + torch_dtype = torch.float16 if in_dtype == "fp16" else torch.bfloat16 + a = torch.rand(M, K, device=DEVICE, dtype=torch_dtype) + b_raw = torch.rand(N, K, device=DEVICE, dtype=torch_dtype) + sa = sb = torch.empty(0, device=DEVICE, dtype=torch.float32) + ref = a.float() @ b_raw.float().T + else: + a_f = torch.rand(M, K, device=DEVICE, dtype=torch.float32) + b_f = torch.rand(N, K, device=DEVICE, dtype=torch.float32) + a, sa = pertoken_quant(a_f, quant_dtype=DTYPE_FP8) + b_raw, sb = pertoken_quant(b_f, quant_dtype=DTYPE_FP8) + ref = (a.float() * sa.view(-1, 1)) @ (b_raw.float() * sb.view(-1, 1)).T + b_shuf = shuffle_weight(b_raw, layout=(16, 16)) + return a, b_raw, b_shuf, sa, sb, ref + + +def _as_i8(t): + return t.view(torch.int8) if "float8" in str(t.dtype) else t + + +def _make_args(c, a, b_shuf, sa, sb, M, N, *, include_bias=False): + args = [ + c.view(-1), + _as_i8(a.view(-1)), + _as_i8(b_shuf.view(-1)), + sa.view(-1) if sa.numel() > 0 else sa, + sb.view(-1) if sb.numel() > 0 else sb, + ] + if include_bias: + args.append(torch.empty(0, device=c.device, dtype=c.dtype)) + args.extend([M, N, torch.cuda.current_stream()]) + return tuple(args) + + +def compile_one(M, N, K, tile_m, tile_n, tile_k, in_dtype, out_dtype="bf16", + waves_per_eu=None, enable_scheduler=True, maxnreg=None, + opt_level=None): + """Compile v2 and old kernels, return compilation status.""" + elem_bytes = 1 if in_dtype in ("fp8",) else 2 + smem = tile_m * tile_k * elem_bytes * 2 + if smem > 65536: + return None + + torch_out_dtype = torch.bfloat16 if out_dtype == "bf16" else torch.float16 + a, b_raw, b_shuf, sa, sb, ref = _make_data(M, N, K, in_dtype) + c = torch.zeros(M, N, device=DEVICE, dtype=torch_out_dtype) + args = _make_args(c, a, b_shuf, sa, 
sb, M, N) + + hints = {} + if maxnreg: + hints["maxnreg"] = maxnreg + if opt_level is not None: + hints["opt_level"] = opt_level + + import time + t0 = time.time() + fn_v2 = compile_preshuffle_gemm_v2( + N=N, K=K, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, + in_dtype=in_dtype, out_dtype=out_dtype, + waves_per_eu=waves_per_eu, enable_scheduler=enable_scheduler, + ) + compiled_v2 = flyc.compile[hints](fn_v2, *args) if hints else flyc.compile(fn_v2, *args) + t_v2 = time.time() - t0 + + t0 = time.time() + fn_old = compile_preshuffle_gemm_a8( + N=N, K=K, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, + in_dtype=in_dtype, + out_dtype=out_dtype, + ) + args_old = _make_args(c, a, b_shuf, sa, sb, M, N, include_bias=True) + compiled_old = flyc.compile(fn_old, *args_old) + t_old = time.time() - t0 + + tile_str = f"{tile_m}x{tile_n}x{tile_k}" + print(f" {tile_str:>14s} v2: {t_v2:.1f}s old: {t_old:.1f}s [OK]") + return {"tile": tile_str} + + +def bench_one(M, N, K, tile_m, tile_n, tile_k, in_dtype, out_dtype="bf16", + warmup=5, iters=20, check_correctness=True, waves_per_eu=None, + enable_scheduler=True, maxnreg=None, opt_level=None, + llvm_opts=None): + elem_bytes = 1 if in_dtype in ("fp8",) else 2 + smem = tile_m * tile_k * elem_bytes * 2 + if smem > 65536: + return None # LDS overflow + + torch_out_dtype = torch.bfloat16 if out_dtype == "bf16" else torch.float16 + a, b_raw, b_shuf, sa, sb, ref = _make_data(M, N, K, in_dtype) + + # ── v2 (layout API) ────────────────────────────────────────── + hints = {} + if maxnreg: + hints["maxnreg"] = maxnreg + if opt_level is not None: + hints["opt_level"] = opt_level + if llvm_opts: + hints["llvm_options"] = llvm_opts + fn_v2 = compile_preshuffle_gemm_v2( + N=N, K=K, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, + in_dtype=in_dtype, out_dtype=out_dtype, + waves_per_eu=waves_per_eu, enable_scheduler=enable_scheduler, + ) + c_v2 = torch.zeros(M, N, device=DEVICE, dtype=torch_out_dtype) + args_v2 = _make_args(c_v2, a, b_shuf, sa, sb, M, N) + compiled_v2 = flyc.compile[hints](fn_v2, *args_v2) if hints else flyc.compile(fn_v2, *args_v2) + us_v2 = _bench_kernel(compiled_v2, args_v2, warmup=warmup, iters=iters) + + # ── old path ────────────────────────────────────────────────── + us_old = 0.0 + compiled_old = None + try: + fn_old = compile_preshuffle_gemm_a8( + N=N, K=K, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, + in_dtype=in_dtype, + out_dtype=out_dtype, + ) + c_old = torch.zeros(M, N, device=DEVICE, dtype=torch_out_dtype) + args_old = _make_args(c_old, a, b_shuf, sa, sb, M, N, include_bias=True) + compiled_old = flyc.compile(fn_old, *args_old) + us_old = _bench_kernel(compiled_old, args_old, warmup=warmup, iters=iters) + except (ValueError, RuntimeError) as e: + print(f" (old kernel unsupported: {e})") + c_old = torch.zeros(M, N, device=DEVICE, dtype=torch_out_dtype) + args_old = _make_args(c_old, a, b_shuf, sa, sb, M, N, include_bias=True) + + flops = 2 * M * N * K + tflops_v2 = flops / (us_v2 / 1e6) / 1e12 + tflops_old = flops / (us_old / 1e6) / 1e12 if us_old > 0 else 0.0 + ratio = tflops_v2 / tflops_old * 100 if tflops_old > 0 else float("inf") + + # correctness + err_v2 = err_old = None + if check_correctness: + compiled_v2(*args_v2) + if compiled_old: + compiled_old(*args_old) + torch.cuda.synchronize() + err_v2 = ((c_v2.float() - ref).abs() / (ref.abs() + 1e-6)).mean().item() + if compiled_old: + err_old = ((c_old.float() - ref).abs() / (ref.abs() + 1e-6)).mean().item() + + return { + "M": M, "N": N, "K": K, + "tile": f"{tile_m}x{tile_n}x{tile_k}", + 
"k_iters": tile_k // 32, + "us_v2": us_v2, "us_old": us_old, + "tflops_v2": tflops_v2, "tflops_old": tflops_old, + "ratio": ratio, + "err_v2": err_v2, "err_old": err_old, + } + + +# ── Default benchmark configs ───────────────────────────────────── + +DEFAULT_CONFIGS = { + "fp16": [ + # (M, N, K, tile_m, tile_n, tile_k) + (128, 5120, 8192, 64, 128, 64), # k=2, best tile + (128, 5120, 8192, 128, 128, 64), # k=2 + (128, 5120, 8192, 64, 128, 128), # k=4 + (32, 5120, 8192, 32, 64, 512), # k=16, stress test + (5120, 5120, 8192, 64, 128, 64), # large, compute-bound + (5120, 5120, 8192, 64, 128, 128), # large, k=4 + ], + "bf16": [ + (128, 5120, 8192, 64, 128, 64), + (128, 5120, 8192, 64, 128, 128), + (5120, 5120, 8192, 64, 128, 64), + ], + "fp8": [ + (16, 5120, 8192, 16, 64, 256), + (128, 5120, 8192, 64, 128, 128), + (128, 5120, 8192, 64, 256, 256), + (5120, 5120, 8320, 64, 256, 128), + (5120, 5120, 8320, 64, 256, 256), + ], +} + +SWEEP_TILES = [ + (32, 64, 64), + (32, 64, 128), + (32, 64, 256), + (32, 64, 512), + (64, 128, 64), + (64, 128, 128), + (64, 128, 256), + (64, 256, 64), + (64, 256, 128), + (128, 128, 64), + (128, 256, 64), +] + + +def print_results(results, in_dtype): + print() + hdr = f"{'tile':>14s} {'k':>2s} | {'v2 us':>8s} {'v2 TF':>7s} | {'old us':>8s} {'old TF':>7s} | {'ratio':>6s}" + if results and results[0].get("err_v2") is not None: + hdr += f" | {'err_v2':>7s} {'err_old':>7s}" + print(f" {in_dtype.upper()} (M={results[0]['M']}, N={results[0]['N']}, K={results[0]['K']})") + print(f" {hdr}") + print(f" {'-' * len(hdr)}") + for r in results: + old_us_str = f"{r['us_old']:>8.1f}" if r['us_old'] > 0 else f"{'n/a':>8s}" + old_tf_str = f"{r['tflops_old']:>7.1f}" if r['tflops_old'] > 0 else f"{'n/a':>7s}" + ratio_str = f"{r['ratio']:>5.1f}%" if r['ratio'] != float("inf") else f"{'v2only':>6s}" + line = (f" {r['tile']:>14s} {r['k_iters']:>2d} | " + f"{r['us_v2']:>8.1f} {r['tflops_v2']:>7.1f} | " + f"{old_us_str} {old_tf_str} | " + f"{ratio_str}") + if r.get("err_v2") is not None: + ev2 = f"{r['err_v2']:>7.4f}" if r['err_v2'] is not None else f"{'n/a':>7s}" + eo = f"{r['err_old']:>7.4f}" if r['err_old'] is not None else f"{'n/a':>7s}" + line += f" | {ev2} {eo}" + print(line) + + +def main(): + parser = argparse.ArgumentParser(description="Benchmark preshuffle_gemm v2 vs old") + parser.add_argument("--dtype", type=str, default=None, + choices=["fp16", "bf16", "fp8"], + help="Data type to benchmark (default: all)") + parser.add_argument("-M", type=int, default=None) + parser.add_argument("-N", type=int, default=None) + parser.add_argument("-K", type=int, default=None) + parser.add_argument("--tile_m", type=int, default=None) + parser.add_argument("--tile_n", type=int, default=None) + parser.add_argument("--tile_k", type=int, default=None) + parser.add_argument("--sweep", action="store_true", + help="Sweep all tile configs for given M/N/K") + parser.add_argument("--iters", type=int, default=20) + parser.add_argument("--warmup", type=int, default=5) + parser.add_argument("--no-check", action="store_true", + help="Skip correctness check") + parser.add_argument("--waves_per_eu", type=int, default=None, + help="Set waves_per_eu hint for v2 kernel (e.g. 3)") + parser.add_argument("--compile_only", action="store_true", + help="Compile only (no benchmark). 
Use with FLYDSL_DUMP_IR=1 for VGPR analysis") + parser.add_argument("--no_scheduler", action="store_true", + help="Disable hot_loop_scheduler in v2 kernel") + parser.add_argument("--maxnreg", type=int, default=None, + help="Set max VGPR count for v2 kernel (e.g. 168)") + parser.add_argument("--opt_level", type=int, default=None, + help="LLVM optimization level for v2 kernel (default: 2)") + parser.add_argument("--no_post_misched", action="store_true", + help="Disable LLVM post-RA machine scheduling") + parser.add_argument("--lsr_drop", action="store_true", + help="Set lsr-drop-solution=True") + args = parser.parse_args() + + llvm_opts = {} + if args.no_post_misched: + llvm_opts["enable-post-misched"] = False + if args.lsr_drop: + llvm_opts["lsr-drop-solution"] = True + if not llvm_opts: + llvm_opts = None + + dtypes = [args.dtype] if args.dtype else ["fp16", "bf16", "fp8"] + + print(f"GPU: {ARCH}") + wpe_str = f", waves_per_eu={args.waves_per_eu}" if args.waves_per_eu else "" + print(f"Pipeline comparison: v2 (layout API{wpe_str}) vs old (manual)") + print("=" * 78) + + for dt in dtypes: + # Compile-only mode + if args.compile_only: + if args.M and args.tile_m: + M, N, K = args.M, args.N or 5120, args.K or 8192 + compile_one(M, N, K, args.tile_m, args.tile_n or 128, args.tile_k or 64, + dt, waves_per_eu=args.waves_per_eu, + enable_scheduler=not args.no_scheduler, + maxnreg=args.maxnreg, opt_level=args.opt_level) + else: + configs = DEFAULT_CONFIGS.get(dt, []) + for M, N, K, tm, tn, tk in configs: + compile_one(M, N, K, tm, tn, tk, dt, waves_per_eu=args.waves_per_eu, + enable_scheduler=not args.no_scheduler, + maxnreg=args.maxnreg, opt_level=args.opt_level) + continue + + # Custom single config + if args.M and args.tile_m and not args.sweep: + M, N, K = args.M, args.N or 5120, args.K or 8192 + r = bench_one(M, N, K, args.tile_m, args.tile_n or 128, args.tile_k or 64, + dt, warmup=args.warmup, iters=args.iters, + check_correctness=not args.no_check, + waves_per_eu=args.waves_per_eu, + enable_scheduler=not args.no_scheduler, + maxnreg=args.maxnreg, opt_level=args.opt_level, + llvm_opts=llvm_opts) + if r: + print_results([r], dt) + continue + + # Sweep mode + if args.sweep: + M = args.M or 128 + N = args.N or 5120 + K = args.K or 8192 + results = [] + for tm, tn, tk in SWEEP_TILES: + if tm > M: + continue + r = bench_one(M, N, K, tm, tn, tk, dt, + warmup=args.warmup, iters=args.iters, + check_correctness=not args.no_check, + waves_per_eu=args.waves_per_eu, + enable_scheduler=not args.no_scheduler, + maxnreg=args.maxnreg, opt_level=args.opt_level, + llvm_opts=llvm_opts) + if r: + results.append(r) + if results: + print_results(results, dt) + continue + + # Default configs + configs = DEFAULT_CONFIGS.get(dt, []) + if not configs: + continue + # Group by (M, N, K) + groups = {} + for M, N, K, tm, tn, tk in configs: + key = (M, N, K) + groups.setdefault(key, []).append((tm, tn, tk)) + + for (M, N, K), tiles in groups.items(): + results = [] + for tm, tn, tk in tiles: + r = bench_one(M, N, K, tm, tn, tk, dt, + warmup=args.warmup, iters=args.iters, + check_correctness=not args.no_check, + waves_per_eu=args.waves_per_eu, + enable_scheduler=not args.no_scheduler, + maxnreg=args.maxnreg, opt_level=args.opt_level, + llvm_opts=llvm_opts) + if r: + results.append(r) + if results: + print_results(results, dt) + + print() + print("=" * 78) + print("Done.") + + +if __name__ == "__main__": + main() From ff6df13a91219d6cc97b322073a3cf64eab32600 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Thu, 23 Apr 2026 
15:03:59 +0800 Subject: [PATCH 26/29] [OPT] Add fly-int-swizzle-simplify pass (#427) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [OPT] Add fly-int-swizzle-simplify pass Add an algebraic simplification pass that recognizes the canonical three-instruction swizzle sequence (andi/shrui/xori) emitted by `applySwizzle` and peels period-aligned addends out of the swizzle: swizzle(base + d) → swizzle(base) + d when d % period == 0 --------- Co-authored-by: Claude Opus 4 --- include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td | 2 + .../flydsl/Dialect/Fly/Transforms/Passes.td | 25 ++ include/flydsl/Dialect/Fly/Utils/IntUtils.h | 2 + lib/Dialect/Fly/CMakeLists.txt | 3 +- .../Fly/Transforms/IntSwizzleSimplify.cpp | 247 ++++++++++++++++++ lib/Dialect/Fly/Utils/IntUtils.cpp | 6 + python/flydsl/compiler/backends/rocm.py | 2 + .../mlir/Transforms/int_swizzle_simplify.mlir | 122 +++++++++ 8 files changed, 408 insertions(+), 1 deletion(-) create mode 100644 lib/Dialect/Fly/Transforms/IntSwizzleSimplify.cpp create mode 100644 tests/mlir/Transforms/int_swizzle_simplify.mlir diff --git a/include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td b/include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td index 2f71ad5e..521f4cf9 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td +++ b/include/flydsl/Dialect/Fly/IR/FlyAttrDefs.td @@ -156,6 +156,8 @@ def Fly_SwizzleAttr : Fly_Attr<"Swizzle", "swizzle", []> { let extraClassDeclaration = [{ bool isTrivialSwizzle() const; static SwizzleAttr getTrivialSwizzle(MLIRContext *context); + + int32_t period() const { return 1 << (getMask() + getBase() + getShift()); } }]; } diff --git a/include/flydsl/Dialect/Fly/Transforms/Passes.td b/include/flydsl/Dialect/Fly/Transforms/Passes.td index c3b60aca..44a34f56 100644 --- a/include/flydsl/Dialect/Fly/Transforms/Passes.td +++ b/include/flydsl/Dialect/Fly/Transforms/Passes.td @@ -62,6 +62,31 @@ def FlyConvertAtomCallToSSAFormPass : Pass<"fly-convert-atom-call-to-ssa-form"> ]; } +def FlyIntSwizzleSimplifyPass : Pass<"fly-int-swizzle-simplify"> { + let summary = "Algebraically simplify swizzle-shaped arith sequences"; + let description = [{ + Recognize the canonical swizzle sequence emitted by `applySwizzle`, + + %m = arith.constant ((1< { let summary = "Promote register memory to vector SSA values"; let description = [{ diff --git a/include/flydsl/Dialect/Fly/Utils/IntUtils.h b/include/flydsl/Dialect/Fly/Utils/IntUtils.h index fcda7518..f6597a90 100644 --- a/include/flydsl/Dialect/Fly/Utils/IntUtils.h +++ b/include/flydsl/Dialect/Fly/Utils/IntUtils.h @@ -56,6 +56,8 @@ IntAttr intCeilDiv(IntAttr lhs, IntAttr rhs); IntAttr intShapeDiv(IntAttr lhs, IntAttr rhs); IntAttr intApplySwizzle(IntAttr v, SwizzleAttr swizzle); +bool isDivisibleBy(IntAttr attr, int32_t divisor); + //===----------------------------------------------------------------------===// // BasisAttr operations //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Fly/CMakeLists.txt b/lib/Dialect/Fly/CMakeLists.txt index a1d2b89b..ace175fb 100644 --- a/lib/Dialect/Fly/CMakeLists.txt +++ b/lib/Dialect/Fly/CMakeLists.txt @@ -13,7 +13,8 @@ add_mlir_dialect_library(MLIRFlyDialect Transforms/RewriteFuncSignature.cpp Transforms/ConvertAtomCallToSSAForm.cpp Transforms/PromoteRegMemToVectorSSA.cpp - + Transforms/IntSwizzleSimplify.cpp + DEPENDS MLIRFlyIncGen FlyTransformPassIncGen diff --git a/lib/Dialect/Fly/Transforms/IntSwizzleSimplify.cpp b/lib/Dialect/Fly/Transforms/IntSwizzleSimplify.cpp new file 
mode 100644 index 00000000..0bbf6f8f --- /dev/null +++ b/lib/Dialect/Fly/Transforms/IntSwizzleSimplify.cpp @@ -0,0 +1,247 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2025 FlyDSL Project Contributors + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "llvm/ADT/TypeSwitch.h" + +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" + +#include +#include + +using namespace mlir; +using namespace mlir::fly; + +namespace mlir { +namespace fly { +#define GEN_PASS_DEF_FLYINTSWIZZLESIMPLIFYPASS +#include "flydsl/Dialect/Fly/Transforms/Passes.h.inc" +} // namespace fly +} // namespace mlir + +namespace { + +//===----------------------------------------------------------------------===// +// Divisibility lattice over SSA values. +// +// For each value we take max() of two evidences: +// * forward, from the defining arith op (gcd / sat-mul / shl); +// * backward, from any Make*Op user that carries a static +// IntAttr.divisibility annotation on its result IntTupleAttr. +//===----------------------------------------------------------------------===// + +class Divisibility { +public: + int get(Value v); + +private: + static int satMul(int a, int b) { + if (a == 0 || b == 0) + return 0; + return (a > INT32_MAX / b) ? INT32_MAX : a * b; + } + + int fromDef(Value v); + int fromUses(Value v); + static int fromMakeOpUser(Operation *user, unsigned operandIdx); + + DenseMap cache; + llvm::SmallPtrSet visiting; +}; + +int Divisibility::get(Value v) { + if (auto c = getConstantIntValue(v)) { + int64_t abs_c = *c == INT64_MIN ? INT64_MAX : std::abs(*c); + return abs_c == 0 ? INT32_MAX : (int)std::min(abs_c, (int64_t)INT32_MAX); + } + if (auto it = cache.find(v); it != cache.end()) + return it->second; + if (!visiting.insert(v).second) + return 1; // conservative on cycles + int d = std::max(fromDef(v), fromUses(v)); + visiting.erase(v); + return cache[v] = d; +} + +int Divisibility::fromDef(Value v) { + Operation *def = v.getDefiningOp(); + if (!def) + return 1; + return llvm::TypeSwitch(def) + .Case([&](auto op) { return std::gcd(get(op.getLhs()), get(op.getRhs())); }) + .Case([&](auto op) { return satMul(get(op.getLhs()), get(op.getRhs())); }) + .Case([&](auto op) { + int lhs = get(op.getLhs()); + auto k = getConstantIntValue(op.getRhs()); + return (k && *k >= 0 && *k < 31) ? satMul(lhs, 1 << *k) : lhs; + }) + .Case( + [&](auto op) { return std::gcd(get(op.getTrueValue()), get(op.getFalseValue())); }) + .Default(1); +} + +int Divisibility::fromUses(Value v) { + int best = 1; + for (OpOperand &use : v.getUses()) + best = std::max(best, fromMakeOpUser(use.getOwner(), use.getOperandNumber())); + return best; +} + +int Divisibility::fromMakeOpUser(Operation *user, unsigned operandIdx) { + auto resTy = llvm::TypeSwitch(user) + .Case( + [](auto op) { return cast(op.getType()); }) + .Default(IntTupleType()); + if (!resTy) + return 1; + + // DFS the IntTupleAttr to find the operandIdx-th dynamic leaf. + unsigned cursor = 0; + int result = 1; + std::function visit = [&](IntTupleAttr a) -> bool { + if (a.isLeaf()) { + auto leaf = a.extractIntFromLeaf(); + if (!leaf || leaf.getStaticFlag()) + return false; + if (cursor++ == operandIdx) { + int32_t d = leaf.getDivisibility(); + result = d > 0 ? 
(int)d : 1; + return true; + } + return false; + } + auto arr = dyn_cast(a.getValue()); + if (!arr) + return false; + for (int i = 0, e = arr.size(); i < e; ++i) + if (visit(a.at(i))) + return true; + return false; + }; + visit(resTy.getAttr()); + return result; +} + +//===------------------------------------------------------------------------------------------===// +// Peel period-aligned summands out of the swizzle's input: +// +// swizzle(base + Σ aᵢ) → swizzle(base) + Σ aᵢ +// +// when each aᵢ ≡ 0 (mod period). When `base` drops out entirely the swizzle of 0 folds to 0. +//===------------------------------------------------------------------------------------------===// + +struct PeelResult { + Value base; + SmallVector offsets; +}; + +PeelResult peelByPeriod(Value v, int period, Divisibility &div) { + if (period <= 0) + return {v, {}}; + auto period_multiple = [&](Value x) { return div.get(x) % period == 0; }; + + if (auto add = v.getDefiningOp()) { + if (period_multiple(add.getRhs())) { + auto rec = peelByPeriod(add.getLhs(), period, div); + rec.offsets.push_back(add.getRhs()); + return rec; + } + if (period_multiple(add.getLhs())) { + auto rec = peelByPeriod(add.getRhs(), period, div); + rec.offsets.push_back(add.getLhs()); + return rec; + } + return {v, {}}; + } + if (period_multiple(v)) + return {/*base=*/Value(), {v}}; + return {v, {}}; +} + +//===----------------------------------------------------------------------===// +// Pass. +//===----------------------------------------------------------------------===// + +class FlyIntSwizzleSimplifyPass + : public mlir::fly::impl::FlyIntSwizzleSimplifyPassBase { +public: + using mlir::fly::impl::FlyIntSwizzleSimplifyPassBase< + FlyIntSwizzleSimplifyPass>::FlyIntSwizzleSimplifyPassBase; + + void runOnOperation() override { + auto module = getOperation(); + module->walk([&](gpu::GPUFuncOp fn) { simplifyInFunc(fn); }); + } + + void simplifyInFunc(gpu::GPUFuncOp fn) { + Divisibility div; + + // Collect candidate xori ops (with their period) up-front so we can rewrite + // without invalidating the walk. 
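+ // Added worked example: for the SwizzleType(M=3, B=3, S=3) case exercised in
+ // tests/mlir/Transforms/int_swizzle_simplify.mlir the matcher below sees
+ // mask = 0b111000000 (448) and shift = 3, giving K = countr_zero(mask) = 6,
+ // M = popcount(mask) = 3 and period = 1 << (M + K) = 512, the same value
+ // SwizzleAttr::period() returns for that attribute.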
+ SmallVector> candidates; + + fn.walk([&](arith::XOrIOp xori) { + auto shrui = xori.getRhs().getDefiningOp(); + if (!shrui) + return; + auto andi = shrui.getLhs().getDefiningOp(); + if (!andi) + return; + Value x = xori.getLhs(); + if (andi.getLhs() != x) + return; + auto maskC = getConstantIntValue(andi.getRhs()); + auto shiftC = getConstantIntValue(shrui.getRhs()); + if (!maskC || !shiftC) + return; + uint64_t mask = (uint64_t)*maskC; + if (mask == 0) + return; + uint64_t low = mask & -mask; + uint64_t plus = mask + low; + if ((mask & plus) != 0) + return; + int K = llvm::countr_zero(mask); + int M = llvm::popcount(mask); + if ((int)*shiftC > K) + return; + candidates.push_back({xori, 1 << (M + K)}); + }); + + for (auto [xori, period] : candidates) { + Value input = xori.getLhs(); + auto peeled = peelByPeriod(input, period, div); + if (peeled.offsets.empty()) + continue; + + OpBuilder b(xori); + Location loc = xori.getLoc(); + Type ty = xori.getType(); + + auto shrui = cast(xori.getRhs().getDefiningOp()); + auto andi = cast(shrui.getLhs().getDefiningOp()); + Value mask = andi.getRhs(); + Value shiftC = shrui.getRhs(); + + Value cur; + if (peeled.base) { + Value masked = arith::AndIOp::create(b, loc, peeled.base, mask); + Value shifted = arith::ShRUIOp::create(b, loc, masked, shiftC); + cur = arith::XOrIOp::create(b, loc, peeled.base, shifted); + } else { + cur = arith::ConstantIntOp::create(b, loc, ty, 0); + } + for (Value v : peeled.offsets) + cur = arith::AddIOp::create(b, loc, cur, v); + + xori.replaceAllUsesWith(cur); + xori.erase(); + } + } +}; + +} // namespace diff --git a/lib/Dialect/Fly/Utils/IntUtils.cpp b/lib/Dialect/Fly/Utils/IntUtils.cpp index 96dc452d..1a7ae023 100644 --- a/lib/Dialect/Fly/Utils/IntUtils.cpp +++ b/lib/Dialect/Fly/Utils/IntUtils.cpp @@ -250,6 +250,12 @@ IntAttr intApplySwizzle(IntAttr v, SwizzleAttr swizzle) { utils::divisibilityApplySwizzle(v.getDivisibility(), swizzle)); } +bool isDivisibleBy(IntAttr attr, int32_t divisor) { + if (attr.isStatic()) + return attr.getValue() % divisor == 0; + return attr.getDivisibility() % divisor == 0; +} + BasisAttr operator*(BasisAttr lhs, IntAttr rhs) { return BasisAttr::get(lhs.getValue() * rhs, lhs.getModes()); } diff --git a/python/flydsl/compiler/backends/rocm.py b/python/flydsl/compiler/backends/rocm.py index 052e0314..b76b41ea 100644 --- a/python/flydsl/compiler/backends/rocm.py +++ b/python/flydsl/compiler/backends/rocm.py @@ -65,6 +65,8 @@ def pipeline_fragments(self, *, compile_hints: dict) -> List[str]: "fly-rewrite-func-signature", "fly-canonicalize", "fly-layout-lowering", + "fly-int-swizzle-simplify", + "canonicalize", "fly-convert-atom-call-to-ssa-form", "fly-promote-regmem-to-vectorssa", "convert-fly-to-rocdl", diff --git a/tests/mlir/Transforms/int_swizzle_simplify.mlir b/tests/mlir/Transforms/int_swizzle_simplify.mlir new file mode 100644 index 00000000..60f13d6d --- /dev/null +++ b/tests/mlir/Transforms/int_swizzle_simplify.mlir @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2025 FlyDSL Project Contributors +// RUN: %fly-opt %s --fly-int-swizzle-simplify | FileCheck %s + +// SwizzleType(M=3, B=3, S=3): mask = 0b111000000 = 448, shift = 3, +// period P = 2^(3+3+3) = 512. + +// ----------------------------------------------------------------------------- +// (1) Plain swizzle sequence whose input has no peelable structure: untouched. 
+// ----------------------------------------------------------------------------- +// CHECK-LABEL: gpu.func @no_peel +// CHECK: %[[M:.+]] = arith.constant 448 : i32 +// CHECK: %[[S:.+]] = arith.constant 3 : i32 +// CHECK: %[[A:.+]] = arith.andi %arg0, %[[M]] +// CHECK: %[[H:.+]] = arith.shrui %[[A]], %[[S]] +// CHECK: arith.xori %arg0, %[[H]] +// CHECK-NOT: arith.addi +gpu.module @t1 { + gpu.func @no_peel(%x: i32) { + %m = arith.constant 448 : i32 + %s = arith.constant 3 : i32 + %a = arith.andi %x, %m : i32 + %h = arith.shrui %a, %s : i32 + %r = arith.xori %x, %h : i32 + %t = fly.make_int_tuple(%r) : (i32) -> !fly.int_tuple + gpu.return + } +} + +// ----------------------------------------------------------------------------- +// (2) Period-aligned constant addend is peeled out: +// swizzle(x + 512) → swizzle(x) + 512. +// ----------------------------------------------------------------------------- +// CHECK-LABEL: gpu.func @peel_const_addend +// CHECK: %[[D:.+]] = arith.constant 512 : i32 +// CHECK: %[[M:.+]] = arith.constant 448 : i32 +// CHECK: %[[S:.+]] = arith.constant 3 : i32 +// CHECK: %[[AND:.+]] = arith.andi %arg0, %[[M]] +// CHECK: %[[SHR:.+]] = arith.shrui %[[AND]], %[[S]] +// CHECK: %[[SW:.+]] = arith.xori %arg0, %[[SHR]] +// CHECK: arith.addi %[[SW]], %[[D]] +gpu.module @t2 { + gpu.func @peel_const_addend(%x: i32) { + %d = arith.constant 512 : i32 + %y = arith.addi %x, %d : i32 + %m = arith.constant 448 : i32 + %s = arith.constant 3 : i32 + %a = arith.andi %y, %m : i32 + %h = arith.shrui %a, %s : i32 + %r = arith.xori %y, %h : i32 + %t = fly.make_int_tuple(%r) : (i32) -> !fly.int_tuple + gpu.return + } +} + +// ----------------------------------------------------------------------------- +// (3) Whole input divisible by period: swizzle(d) → 0; result is 0 + d = d. +// ----------------------------------------------------------------------------- +// CHECK-LABEL: gpu.func @peel_to_zero +// CHECK: %[[D:.+]] = arith.constant 1024 : i32 +// CHECK: %[[ZERO:.+]] = arith.constant 0 : i32 +// CHECK: arith.addi %[[ZERO]], %[[D]] +// CHECK-NOT: arith.xori +gpu.module @t3 { + gpu.func @peel_to_zero(%x: i32) { + %d = arith.constant 1024 : i32 + %m = arith.constant 448 : i32 + %s = arith.constant 3 : i32 + %a = arith.andi %d, %m : i32 + %h = arith.shrui %a, %s : i32 + %r = arith.xori %d, %h : i32 + %t = fly.make_int_tuple(%r) : (i32) -> !fly.int_tuple + gpu.return + } +} + +// ----------------------------------------------------------------------------- +// (4) Non-period-aligned addend stays in place (no rewrite). +// ----------------------------------------------------------------------------- +// CHECK-LABEL: gpu.func @no_peel_misaligned +// CHECK: %[[D:.+]] = arith.constant 7 : i32 +// CHECK: arith.addi %arg0, %[[D]] +// CHECK: arith.xori +// CHECK-NOT: arith.addi +gpu.module @t4 { + gpu.func @no_peel_misaligned(%x: i32) { + %d = arith.constant 7 : i32 + %y = arith.addi %x, %d : i32 + %m = arith.constant 448 : i32 + %s = arith.constant 3 : i32 + %a = arith.andi %y, %m : i32 + %h = arith.shrui %a, %s : i32 + %r = arith.xori %y, %h : i32 + %t = fly.make_int_tuple(%r) : (i32) -> !fly.int_tuple + gpu.return + } +} + +// ----------------------------------------------------------------------------- +// (5) i64 type works the same. 
+// ----------------------------------------------------------------------------- +// CHECK-LABEL: gpu.func @peel_i64 +// CHECK: %[[D:.+]] = arith.constant 512 : i64 +// CHECK: %[[M:.+]] = arith.constant 448 : i64 +// CHECK: %[[S:.+]] = arith.constant 3 : i64 +// CHECK: %[[AND:.+]] = arith.andi %arg0, %[[M]] +// CHECK: %[[SHR:.+]] = arith.shrui %[[AND]], %[[S]] +// CHECK: %[[SW:.+]] = arith.xori %arg0, %[[SHR]] +// CHECK: arith.addi %[[SW]], %[[D]] +gpu.module @t5 { + gpu.func @peel_i64(%x: i64) { + %d = arith.constant 512 : i64 + %y = arith.addi %x, %d : i64 + %m = arith.constant 448 : i64 + %s = arith.constant 3 : i64 + %a = arith.andi %y, %m : i64 + %h = arith.shrui %a, %s : i64 + %r = arith.xori %y, %h : i64 + %t = fly.make_int_tuple(%r) : (i64) -> !fly.int_tuple + gpu.return + } +} From 25510d1b513802542f6d51f80cb71d585be85f11 Mon Sep 17 00:00:00 2001 From: ruanjm Date: Thu, 23 Apr 2026 16:51:05 +0800 Subject: [PATCH 27/29] Implement MLA decode fwd nh128 a8w8 kernel with FlyDSL. (#403) --- kernels/mla_fwd_decode.py | 175 ++ kernels/mla_fwd_decode_m16x8_fp8_fp8.py | 2292 +++++++++++++++++++++++ tests/kernels/test_mla_decode.py | 312 +++ 3 files changed, 2779 insertions(+) create mode 100644 kernels/mla_fwd_decode.py create mode 100644 kernels/mla_fwd_decode_m16x8_fp8_fp8.py create mode 100644 tests/kernels/test_mla_decode.py diff --git a/kernels/mla_fwd_decode.py b/kernels/mla_fwd_decode.py new file mode 100644 index 00000000..1e01ab5a --- /dev/null +++ b/kernels/mla_fwd_decode.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + +"""FlyDSL MLA decode launcher. Uses aiter for device queries.""" + +import functools +import re +import shutil +import subprocess + +import torch + + +def _gcn_arch_base(arch_name: str) -> str: + """Strip target features (':sramecc+:xnack-') from a gcnArchName.""" + return arch_name.split(":", 1)[0] + + +@functools.lru_cache(maxsize=None) +def _get_lds_size_per_cu(arch: str) -> int: + """Return the LDS (shared memory) size per CU in bytes for ``arch``. + + Cached per arch so a mixed-GPU process (or one that switches devices) + gets the right LDS budget for the active device — not whichever GPU + rocminfo happens to list first. Caller must pass the current device's + base gcnArchName (e.g. ``"gfx942"``). + + Parses the GROUP segment pool size from ``rocminfo`` output, picking + the first GPU agent whose name matches ``arch``. + """ + rocminfo = shutil.which("rocminfo") + if rocminfo is None: + raise RuntimeError("rocminfo not found on PATH") + result = subprocess.run( + [rocminfo], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True + ) + agents = re.split(r"Agent\s*\d+", result.stdout) + for agent in agents: + if "Device Type" not in agent or agent.find("GPU") == -1: + continue + # Match this agent's Name (e.g. "gfx942") against the requested arch. 
+ name_m = re.search(r"^\s*Name:\s*(\S+)", agent, re.MULTILINE) + if not name_m or name_m.group(1) != arch: + continue + lines = agent.split("\n") + for i, line in enumerate(lines): + if re.search(r"Segment\s*:\s*GROUP", line) and i + 1 < len(lines): + m = re.search(r"Size\s*:\s*(\d+)", lines[i + 1]) + if m: + return int(m.group(1)) * 1024 # KB -> bytes + raise RuntimeError( + f"No GPU GROUP segment found in rocminfo output for arch {arch!r}" + ) + + +def _is_fp8(dtype: torch.dtype) -> bool: + return dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz) + + +def flydsl_mla_fwd_decode( + query: torch.Tensor, # [num_seqs, num_heads, head_size] + kv_buffer: torch.Tensor, # [num_page, page_size, num_kv_heads, head_size] + kv_page_indices: torch.Tensor, + work_indptr: torch.Tensor, + work_info_set: torch.Tensor, + final_output: torch.Tensor, # [num_seqs, num_heads, v_head_dim] + split_output: torch.Tensor, # [num_partial_slots, 1, num_heads, v_head_dim] + split_lse: torch.Tensor, # [num_partial_slots, 1, num_heads, 1] + softmax_scale: float, +) -> None: + """Launch the FlyDSL MLA decode forward kernel.""" + num_heads = query.size(1) + q_dtype = query.dtype + kv_dtype = kv_buffer.dtype + + if num_heads == 128 and _is_fp8(q_dtype) and _is_fp8(kv_dtype): + from .mla_fwd_decode_m16x8_fp8_fp8 import ( + OCCUPANCY, + QK_HEAD_DIM, + V_HEAD_DIM, + launch_mla_fwd_decode_m16x8_fp8_fp8, + ) + + # ── shape validation ── + assert query.ndim == 3, ( + f"query: expected 3D [num_seqs, num_heads, qk_head_dim], got shape {list(query.shape)}" + ) + assert query.size(2) == QK_HEAD_DIM, ( + f"query: head_dim={query.size(2)}, expected {QK_HEAD_DIM}" + ) + assert kv_buffer.ndim == 4, ( + f"kv_buffer: expected 4D [num_page, page_size, num_kv_heads, qk_head_dim], " + f"got shape {list(kv_buffer.shape)}" + ) + assert kv_buffer.size(1) * kv_buffer.size(2) == 1, ( + f"kv_buffer: page_size*num_kv_heads must be 1, " + f"got page_size={kv_buffer.size(1)}, num_kv_heads={kv_buffer.size(2)}" + ) + assert kv_buffer.size(3) == QK_HEAD_DIM, ( + f"kv_buffer: head_dim={kv_buffer.size(3)}, expected {QK_HEAD_DIM}" + ) + num_seqs = query.size(0) + assert final_output.shape == (num_seqs, num_heads, V_HEAD_DIM), ( + f"final_output: expected shape [{num_seqs}, {num_heads}, {V_HEAD_DIM}], " + f"got {list(final_output.shape)}" + ) + num_partial = split_output.size(0) + assert split_output.ndim == 4 and split_output.shape[1:] == (1, num_heads, V_HEAD_DIM), ( + f"split_output: expected [N, 1, {num_heads}, {V_HEAD_DIM}], " + f"got {list(split_output.shape)}" + ) + assert split_lse.ndim == 4 and split_lse.shape[1:] == (1, num_heads, 1), ( + f"split_lse: expected [N, 1, {num_heads}, 1], got {list(split_lse.shape)}" + ) + assert split_lse.size(0) == num_partial, ( + f"split_lse batch dim ({split_lse.size(0)}) != split_output batch dim ({num_partial})" + ) + dev = query.device + for name, t in [("kv_buffer", kv_buffer), ("kv_page_indices", kv_page_indices), + ("work_indptr", work_indptr), ("work_info_set", work_info_set), + ("final_output", final_output), ("split_output", split_output), + ("split_lse", split_lse)]: + assert t.device == dev, f"{name}: expected device {dev}, got {t.device}" + + # Output tensors must be contiguous: reshape() on a non-contiguous + # output would silently materialize a copy, the kernel would write + # into the copy, and the caller's original tensor would never be + # updated. Use view() after asserting contiguity so any layout + # mismatch fails loudly here instead. 
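+ # Added example: a caller passing e.g. final_output.transpose(0, 1) fails the
+ # check below with a clear message; swapping view() for reshape() would instead
+ # silently materialize a copy and the kernel's writes would be lost.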
+ for name, t in [("final_output", final_output), + ("split_output", split_output), + ("split_lse", split_lse)]: + assert t.is_contiguous(), ( + f"{name}: must be contiguous (stride={list(t.stride())}, " + f"shape={list(t.shape)}); reshape() would silently copy and " + f"the kernel's writes would not be visible to the caller" + ) + + num_pages = kv_buffer.size(0) + + query_flat = query.reshape(num_seqs * num_heads, QK_HEAD_DIM) + kv_flat = kv_buffer.reshape(num_pages, QK_HEAD_DIM) + final_flat = final_output.view(num_seqs * num_heads, V_HEAD_DIM) + split_o_flat = split_output.view(num_partial * num_heads, V_HEAD_DIM) + split_lse_flat = split_lse.view(num_partial * num_heads) + + work_indptr_flat = work_indptr.contiguous() + work_info_flat = work_info_set.contiguous().view(-1) + kv_idx_flat = kv_page_indices.contiguous() + + from aiter.jit.utils.chip_info import get_cu_num + + num_cus = get_cu_num() + arch = _gcn_arch_base(torch.cuda.get_device_properties(dev).gcnArchName) + lds_size = _get_lds_size_per_cu(arch) // OCCUPANCY + + launch_mla_fwd_decode_m16x8_fp8_fp8( + query_flat, + kv_flat, + kv_idx_flat, + work_indptr_flat, + work_info_flat, + final_flat, + split_o_flat, + split_lse_flat, + softmax_scale, + num_cus, + lds_size, + stream=torch.cuda.current_stream(), + ) + else: + raise NotImplementedError( + f"flydsl_mla_fwd_decode: unsupported num_heads={num_heads}, " + f"q_dtype={q_dtype}, kv_dtype={kv_dtype}" + ) diff --git a/kernels/mla_fwd_decode_m16x8_fp8_fp8.py b/kernels/mla_fwd_decode_m16x8_fp8_fp8.py new file mode 100644 index 00000000..dbec8746 --- /dev/null +++ b/kernels/mla_fwd_decode_m16x8_fp8_fp8.py @@ -0,0 +1,2292 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + +"""FlyDSL MLA decode kernel (nhead=128, fp8 Q, fp8 KV, bf16 output). + +Transplanted from csrc/kernels/mla/hk/mi3xx_v32_fwd_decode_h128_fp8_fp8.cuh. + +Architecture: 8 warps / 512 threads, persistent-thread dispatch. +Per work item: load Q -> iterate KV tiles (BLOCK_N=32) -> QK GEMM (nope+rope) +-> online softmax -> PV GEMM -> output (final bf16 or split f32 + LSE). + +NOTE: Do NOT use ``from __future__ import annotations`` here -- it breaks +``fx.Constexpr`` detection in the FlyDSL AST rewriter. 
+""" + +import flydsl.compiler as flyc +import flydsl.expr as fx + +from flydsl._mlir import ir +from flydsl._mlir.dialects import llvm, scf +from flydsl._mlir.dialects import math as _math +from flydsl._mlir.dialects import memref as _memref +from flydsl._mlir.dialects import gpu as _mlir_gpu +from flydsl._mlir.dialects._arith_enum_gen import CmpIPredicate + +from flydsl.expr import arith, vector, gpu, buffer_ops, rocdl +from flydsl.expr import range_constexpr +from flydsl.expr.arith import _to_raw as _raw +from flydsl.expr.typing import T + +from flydsl.compiler.kernel_function import CompilationContext +from flydsl.utils.smem_allocator import SmemAllocator +from flydsl.runtime.device import get_rocm_arch as get_hip_arch + + +# --------------------------------------------------------------------------- +# Compile-time constants (mirroring HkMlaDecodeFwdTraits) +# --------------------------------------------------------------------------- +NUM_QO_HEADS: int = 128 +NUM_KV_HEADS: int = 1 +KV_LORA_RANK: int = 512 +QK_NOPE_HEAD_DIM: int = KV_LORA_RANK # 512 +QK_ROPE_HEAD_DIM: int = 64 +QK_HEAD_DIM: int = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM # 576 +V_HEAD_DIM: int = KV_LORA_RANK # 512 +PAGE_SIZE: int = 1 +NUM_WARPS: int = 8 +WARP_SIZE: int = 64 +NUM_THREADS: int = NUM_WARPS * WARP_SIZE # 512 +BLOCK_M: int = 128 # == NUM_QO_HEADS +BLOCK_N: int = 32 +BLOCK_K: int = 32 +TILE_M: int = BLOCK_M // NUM_WARPS # 16 +OCCUPANCY: int = 1 + +SIZE_MLA_WORK_INFO_IN_DW: int = 8 +LOG2E: float = 1.4426950408889634 + +# --------------------------------------------------------------------------- +# KvManagerV2 LDS layout constants +# --------------------------------------------------------------------------- +# KV tile: 32 rows x 576 cols (fp8), split into 9 blocks of 64 cols each. +# Each block: 8 sub-blocks (one per warp) of 4 rows x 64 cols + 2 DW padding. 
+KV_NUM_COLS: int = 64 +KV_NUM_BLOCKS: int = QK_HEAD_DIM // KV_NUM_COLS # 576 / 64 = 9 +KV_ROWS_PER_SUB: int = BLOCK_N // NUM_WARPS # 32 / 8 = 4 +KV_BYTES_PER_ROW: int = KV_NUM_COLS # 64 * 1 (fp8) +KV_PAD_DW: int = 2 +KV_SUB_BYTES: int = KV_ROWS_PER_SUB * KV_BYTES_PER_ROW + KV_PAD_DW * 4 # 264 +KV_NUM_SUBS: int = BLOCK_N // KV_ROWS_PER_SUB # 8 +KV_BLOCK_BYTES: int = KV_SUB_BYTES * KV_NUM_SUBS # 2112 +SZ_LDS_KV: int = KV_BLOCK_BYTES * KV_NUM_BLOCKS # 2112 * 9 = 19008 + +# --------------------------------------------------------------------------- +# VtManagerV1 LDS layout constants +# --------------------------------------------------------------------------- +VT_ROWS_PER_THR: int = 4 +VT_COLS_PER_THR: int = 8 +VT_ELEMS_PER_BLK: int = VT_ROWS_PER_THR * VT_COLS_PER_THR # 32 +VT_BLKS_PER_ROW: int = V_HEAD_DIM // VT_COLS_PER_THR # 64 +VT_BLKS_PER_ROW_PAD: int = VT_BLKS_PER_ROW + 2 # 66 +VT_NUM_SUB_BLKS: int = 8 +SZ_LDS_VT: int = VT_NUM_SUB_BLKS * ( + (BLOCK_N // VT_NUM_SUB_BLKS) * V_HEAD_DIM + 16 * 4 +) # 8 * (4*512 + 64) = 16896 + +# --------------------------------------------------------------------------- +# QManagerV3 LDS layout constants (per-warp staging for VRAM->LDS->GPR) +# --------------------------------------------------------------------------- +Q_ELEM_PER_ROW: int = 64 +Q_ELEM_PER_COL: int = 16 +Q_PAD_BYTES_PER_2ROWS: int = 8 # 2 DW +Q_BYTES_PER_2ROWS: int = Q_ELEM_PER_ROW * 2 + Q_PAD_BYTES_PER_2ROWS # 136 +SZ_LDS_Q_PER_WARP: int = Q_ELEM_PER_COL // 2 * Q_BYTES_PER_2ROWS # 1088 +SZ_LDS_Q: int = NUM_WARPS * SZ_LDS_Q_PER_WARP # 8704 + +# --------------------------------------------------------------------------- +# OManager16bitsV2 (bf16 output via LDS reshape) +# --------------------------------------------------------------------------- +O16_NUM_ROWS: int = 16 +O16_NUM_COLS: int = 32 +O16_PAD_ELEM_PER_2ROWS: int = 4 # padded 2-row stride in bf16 elements +O16_ELEM_PER_PAD_2ROWS: int = 2 * O16_NUM_COLS + O16_PAD_ELEM_PER_2ROWS # 68 +O16_LDS_PER_WARP: int = (O16_NUM_ROWS // 2) * O16_ELEM_PER_PAD_2ROWS * 2 # 1088 +SZ_LDS_O16: int = NUM_WARPS * O16_LDS_PER_WARP # 8704 (reuses p_lds_kv region) + +# --------------------------------------------------------------------------- +# OManager32bitsV2 (f32 split output via LDS reshape) +# --------------------------------------------------------------------------- +O32_NUM_ROWS: int = 16 +O32_NUM_COLS: int = 32 +O32_PAD_ELEM_PER_ROW: int = 4 +O32_ELEM_PER_PAD_ROW: int = O32_NUM_COLS + O32_PAD_ELEM_PER_ROW # 36 +O32_LDS_PER_WARP: int = O32_NUM_ROWS * O32_ELEM_PER_PAD_ROW * 4 # 2304 +SZ_LDS_O32: int = NUM_WARPS * O32_LDS_PER_WARP # 18432 + +# Overall LDS layout (byte offsets): +# [0, SZ_LDS_VT) = Vt staging buffer +# [SZ_LDS_VT, SZ_LDS_VT + SZ_LDS_Q) = Q staging buffer +# [SZ_LDS_VT + SZ_LDS_Q, +SZ_LDS_KV) = KV double-buffer 0 +# [SZ_LDS_VT + SZ_LDS_Q + SZ_LDS_KV, +SZ_LDS_KV) = KV double-buffer 1 +# Output reuses the KV double-buffer 0 region. 
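+# Added sanity check (assumes the usual 64 KiB LDS per workgroup on gfx94x):
+# 16896 + 8704 + 2 * 19008 = 63616 bytes, so the plan above fits at OCCUPANCY = 1.
+assert SZ_LDS_VT + SZ_LDS_Q + 2 * SZ_LDS_KV <= 64 * 1024, "LDS plan exceeds 64 KiB"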
+P_LDS_VT: int = 0 +P_LDS_Q: int = SZ_LDS_VT # 16896 +P_LDS_KV_0: int = P_LDS_Q + SZ_LDS_Q # 25600 +P_LDS_KV_1: int = P_LDS_KV_0 + SZ_LDS_KV # 44608 +TOTAL_LDS_BYTES: int = P_LDS_KV_1 + SZ_LDS_KV # 63616 + +assert ( + max(SZ_LDS_O16, SZ_LDS_O32) <= SZ_LDS_KV +), "Output LDS must fit in one KV buffer region" + +# --------------------------------------------------------------------------- +# MFMA tile constants +# --------------------------------------------------------------------------- +MFMA_M: int = 16 +MFMA_N: int = 16 +MFMA_K: int = 32 # mfma_f32_16x16x32_fp8_fp8 +MFMA_ELEM_PER_THR: int = MFMA_M * MFMA_K // WARP_SIZE # 8 + +# Number of QK sub-tile iterations +NUM_NOPE_ITERS: int = QK_NOPE_HEAD_DIM // (MFMA_K * 2) # 512/64 = 8 +NUM_ROPE_ITERS: int = QK_ROPE_HEAD_DIM // (MFMA_K * 2) # 64/64 = 1 +NUM_PV_ITERS: int = V_HEAD_DIM // (MFMA_N * 2) # 512/32 = 16 + + +# --------------------------------------------------------------------------- +# Utility helpers (ported from FlyDSL/kernels/mla_decode_fp8.py) +# --------------------------------------------------------------------------- + + +def _encode_waitcnt(vmcnt=63, expcnt=7, lgkmcnt=63): + """Encode s_waitcnt bitfield for CDNA3 (gfx94x).""" + vm_lo = vmcnt & 0xF + vm_hi = (vmcnt >> 4) & 0x3 + return vm_lo | (expcnt << 4) | (lgkmcnt << 8) | (vm_hi << 14) + + +def _barrier(vmcnt=63, lgkmcnt=63): + """Emit s_waitcnt + s_barrier via inline asm.""" + parts = [] + needs_waitcnt = vmcnt < 63 or lgkmcnt < 63 + if needs_waitcnt: + wc = [] + if vmcnt < 63: + wc.append(f"vmcnt({vmcnt})") + if lgkmcnt < 63: + wc.append(f"lgkmcnt({lgkmcnt})") + parts.append("s_waitcnt " + " ".join(wc)) + parts.append("s_barrier") + llvm.InlineAsmOp( + res=None, + operands_=[], + asm_string="\n".join(parts), + constraints="", + has_side_effects=True, + is_align_stack=False, + ) + + +_LDS_PTR_TYPE = None + + +def _inttoptr_lds(i64_val): + """Convert i64 scalar to !llvm.ptr<3> (LDS pointer).""" + global _LDS_PTR_TYPE + if _LDS_PTR_TYPE is None: + _LDS_PTR_TYPE = ir.Type.parse("!llvm.ptr<3>") + return llvm.inttoptr(_LDS_PTR_TYPE, i64_val) + + +def _get_element_ptr(base_ptr, byte_offset=None, static_byte_offset=0, elem_type=None): + """GEP-based pointer arithmetic.""" + _GEP_DYN = -(2**31) + raw_ptr = _raw(base_ptr) if not isinstance(base_ptr, ir.Value) else base_ptr + if elem_type is None: + elem_type = T.i8 + + if byte_offset is None: + return llvm.GEPOp( + raw_ptr.type, + raw_ptr, + [], + [int(static_byte_offset)], + elem_type, + None, + ).result + elif isinstance(byte_offset, int): + return llvm.GEPOp( + raw_ptr.type, + raw_ptr, + [], + [int(byte_offset) + int(static_byte_offset)], + elem_type, + None, + ).result + else: + offset_val = ( + _raw(byte_offset) if not isinstance(byte_offset, ir.Value) else byte_offset + ) + if isinstance(offset_val.type, ir.IndexType): + offset_val = arith.index_cast(T.i64, offset_val) + if static_byte_offset != 0: + static_attr = ir.IntegerAttr.get(offset_val.type, int(static_byte_offset)) + static_const = arith.ConstantOp(offset_val.type, static_attr).result + offset_val = _raw(arith.ArithValue(offset_val) + arith.ArithValue(static_const)) + return llvm.GEPOp( + raw_ptr.type, + raw_ptr, + [offset_val], + [_GEP_DYN], + elem_type, + None, + ).result + + +def _lds_load(byte_addr_index, vec_type, static_byte_offset=0): + """LDS load via raw llvm.LoadOp on an LDS pointer (addr space 3).""" + raw_addr = ( + _raw(byte_addr_index) + if not isinstance(byte_addr_index, ir.Value) + else byte_addr_index + ) + addr_i64 = arith.index_cast(T.i64, 
raw_addr) + lds_ptr = _inttoptr_lds(addr_i64) + if static_byte_offset != 0: + lds_ptr = _get_element_ptr(lds_ptr, static_byte_offset=static_byte_offset) + return llvm.LoadOp(vec_type, lds_ptr, alignment=16, nontemporal=True).result + + +def _lds_load_volatile(base_i32, vec_type, byte_offset=0): + """Volatile LDS load forcing ds_read_b64/b32 with immediate offset. + + Unlike _lds_load, uses volatile to prevent LLVM from merging adjacent + loads into ds_read2 variants (which have limited 8-bit offsets). + LLVM still tracks these as LDS loads for lgkmcnt. + Input: base_i32 must be an i32 ir.Value (LDS byte address). + """ + addr_i64 = _raw(arith.ArithValue(base_i32).extui(T.i64)) + lds_ptr = _inttoptr_lds(addr_i64) + if byte_offset != 0: + lds_ptr = _get_element_ptr(lds_ptr, static_byte_offset=byte_offset) + return llvm.LoadOp(vec_type, lds_ptr, alignment=8, volatile_=True).result + + +def _index_cast_to_i32(value): + """Cast index/ArithValue to i32. No-op if already i32.""" + raw = _raw(value) if not isinstance(value, ir.Value) else value + if raw.type == T.i32: + return raw + return arith.index_cast(T.i32, raw) + + +def _fast_exp2(val): + """Bare v_exp_f32 via rocdl.exp2 -- no range reduction.""" + return rocdl.exp2(T.f32, _raw(val)) + + +def _to_mlir(val, index=False): + """Convert Python int/float, ArithValue, or ir.Value to raw MLIR Value.""" + if isinstance(val, int): + return _raw(arith.constant(val, index=index)) + if isinstance(val, float): + return _raw(arith.constant(val)) + if isinstance(val, ir.Value): + return val + return _raw(val) + + +# --------------------------------------------------------------------------- +# Kernel +# --------------------------------------------------------------------------- +@flyc.kernel(known_block_size=[NUM_THREADS, 1, 1]) +def kn_mla_fwd_decode_m16x8_fp8_fp8( + # --- inputs --- + query: fx.Tensor, # [num_seqs * num_heads, qk_head_dim] (fp8) + kv_buffer: fx.Tensor, # [num_pages, qk_head_dim] (fp8) + kv_page_indices: fx.Tensor, # [num_page_used] (i32) + # --- metadata --- + work_indptr: fx.Tensor, # [num_workers + 1] (i32) + work_info_set: fx.Tensor, # [num_work_items * 8] (i32) + # --- outputs --- + final_output: fx.Tensor, # [num_seqs * num_heads, v_head_dim] (bf16) + split_output: fx.Tensor, # [num_partial_slots * num_heads, v_head_dim] (f32) + split_lse: fx.Tensor, # [num_partial_slots * num_heads] (f32) + # --- parameters --- + softmax_scale: fx.Float32, +): + """MLA decode forward kernel (nhead=128, fp8/fp8 -> bf16). + + Persistent-thread kernel: each workgroup picks up work items + from ``work_indptr`` / ``work_info_set`` and processes them sequentially. 
+ """ + _STUB_EARLY_RETURN = False # Set True to skip all kernel body for testing launch + if _STUB_EARLY_RETURN: + return + + # ---- Types ---- + fm_fast = arith.FastMathFlags.fast + # fastmath without ninf: safe for operations that may encounter -inf + # (boundary masking sets OOB attention scores to -inf) + fm_no_inf = ( + arith.FastMathFlags.nnan + | arith.FastMathFlags.nsz + | arith.FastMathFlags.arcp + | arith.FastMathFlags.contract + | arith.FastMathFlags.afn + | arith.FastMathFlags.reassoc + ) + + def _mfma_fp8(result_type, operands, **kw): + return rocdl.mfma_f32_16x16x32_fp8_fp8(result_type, operands, **kw) + + # ---- LDS setup ---- + arch = get_hip_arch() + lds_allocator = SmemAllocator(None, arch=arch) + lds_allocator.ptr = TOTAL_LDS_BYTES # reserve LDS bytes + + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + lds_allocator.finalize() + + lds_buffer = lds_allocator.get_base() + lds_base_idx = _memref.ExtractAlignedPointerAsIndexOp(lds_buffer).result + + # ---- V^T transpose perm constants ---- + c_perm0 = arith.constant(0x05010400, type=T.i32) + c_perm1 = arith.constant(0x07030602, type=T.i32) + c_perm2 = arith.constant(0x05040100, type=T.i32) + c_perm3 = arith.constant(0x07060302, type=T.i32) + + def _vt_perm(src_hi, src_lo, sel): + return llvm.call_intrinsic( + T.i32, + "llvm.amdgcn.perm", + [src_hi, src_lo, sel], + [], + [], + ) + + # ---- Constants ---- + c_neg_inf_f32 = arith.constant(float("-inf"), type=T.f32) + c_zero_f32 = arith.constant(0.0, type=T.f32) + c_one_f32 = arith.constant(1.0, type=T.f32) + c_zero_i32 = arith.constant(0, type=T.i32) + c_zero_v4f32 = arith.constant_vector(0.0, T.f32x4) + c_log2e = arith.constant(LOG2E, type=T.f32) + c_inv_log2e = arith.constant(1.0 / LOG2E, type=T.f32) + c_dword_sz = arith.constant(4, type=T.i32) + c_aux_zero = arith.constant(0, type=T.i32) + + # ---- Buffer resources ---- + query_rsrc = buffer_ops.create_buffer_resource(query) + kv_rsrc = buffer_ops.create_buffer_resource(kv_buffer) + kv_page_indices_rsrc = buffer_ops.create_buffer_resource(kv_page_indices) + work_indptr_rsrc = buffer_ops.create_buffer_resource(work_indptr) + work_info_set_rsrc = buffer_ops.create_buffer_resource(work_info_set) + final_output_rsrc = buffer_ops.create_buffer_resource(final_output) + split_output_rsrc = buffer_ops.create_buffer_resource(split_output) + split_lse_rsrc = buffer_ops.create_buffer_resource(split_lse) + + # ---- Thread indices ---- + worker_idx = gpu.block_idx.x + tid = gpu.thread_id("x") + warp_idx = tid / arith.index(WARP_SIZE) + lane_idx = tid % arith.index(WARP_SIZE) + warp_idx_i32 = rocdl.readfirstlane(T.i32, _raw(_index_cast_to_i32(warp_idx))) + lane_idx_i32 = _index_cast_to_i32(lane_idx) + + # ---- Work range ---- + worker_idx_i32 = _index_cast_to_i32(worker_idx) + work_range = buffer_ops.buffer_load( + work_indptr_rsrc, worker_idx_i32, vec_width=2, dtype=T.i32 + ) + work_start_i32 = rocdl.readfirstlane(T.i32, _raw(vector.extract(work_range, [0]))) + work_end_i32 = rocdl.readfirstlane(T.i32, _raw(vector.extract(work_range, [1]))) + work_start_idx = arith.index_cast(T.index, work_start_i32) + work_end_idx = arith.index_cast(T.index, work_end_i32) + + # ---- KvManagerV2 thread-to-data mapping ---- + # Each warp takes 4 rows: warp w -> rows {w*2, w*2+1, w*2+16, w*2+17} + # lane mapping: (lane/32)*16 + (lane/16)%2 + warp*2 + kv_ld_row_base = ( + lane_idx / arith.index(32) * arith.index(16) + + (lane_idx / arith.index(16)) % arith.index(2) + + warp_idx * arith.index(2) + ) + 
kv_ld_row_base_i32 = _index_cast_to_i32(kv_ld_row_base) + kv_ld_col_base_i32 = _index_cast_to_i32( + (lane_idx % arith.index(16)) * arith.index(4) + ) + + # ---- Helper: resolve KV page index -> physical row ---- + def _get_kv_ld_row(kv_tile_start_i32, kv_tile_end_i32, check_boundary): + """Resolve physical KV row for this thread's assigned row. + + For OOB rows (row >= kv_end), returns -1 WITHOUT issuing a + buffer_load -- avoids reading garbage from kv_page_indices. + """ + row_idx_i32 = _raw(kv_ld_row_base_i32 + kv_tile_start_i32) + if check_boundary: + neg_one = _raw(arith.constant(-1, type=T.i32)) + if_op = scf.IfOp( + arith.cmpi( + CmpIPredicate.slt, row_idx_i32, _raw(kv_tile_end_i32) + ), + [T.i32], + has_else=True, + ) + with ir.InsertionPoint(if_op.regions[0].blocks[0]): + # In-bounds: do the buffer_load + phys_row_ib = buffer_ops.buffer_load( + kv_page_indices_rsrc, row_idx_i32, vec_width=1, dtype=T.i32 + ) + scf.YieldOp([_raw(phys_row_ib)]) + with ir.InsertionPoint(if_op.regions[1].blocks[0]): + # OOB: return -1 + scf.YieldOp([neg_one]) + return if_op.results[0] + else: + phys_row = buffer_ops.buffer_load( + kv_page_indices_rsrc, row_idx_i32, vec_width=1, dtype=T.i32 + ) + return _raw(phys_row) + + # ---- Helper: async_load_k_tile (VRAM->LDS via buffer_load_dword_lds) ---- + def _async_load_k_tile( + p_lds_kv_warp_i32, row_i32, col_base_i32, block_idx_const, check_boundary=False + ): + """Load one 32x64 block of KV data from VRAM to LDS. + + block_idx_const: Python int [0..8], which 64-col block. + """ + lds_warp_offset = block_idx_const * KV_BLOCK_BYTES + # p_lds_kv_warp points to warp's sub-block start. + # Actual LDS target: p_lds_kv_warp + block*KV_BLOCK_BYTES - block*64 + _lds_adj = arith.constant( + lds_warp_offset - block_idx_const * KV_NUM_COLS, type=T.i32 + ) + lds_base_i32 = _raw(arith.ArithValue(p_lds_kv_warp_i32) + _lds_adj) + + if check_boundary: + neg_one = _raw(arith.constant(-1, type=T.i32)) + is_oob = arith.cmpi(CmpIPredicate.eq, _raw(row_i32), neg_one) + # For OOB: write zero to LDS + if_op = scf.IfOp(is_oob, [], has_else=True) + with ir.InsertionPoint(if_op.regions[0].blocks[0]): + # Write zero via ds_write_b32 at lane's position + zero_u32 = _raw(arith.constant(0, type=T.i32)) + lane_offset = _raw(lane_idx_i32 * arith.constant(4, type=T.i32)) + lds_addr_zero = _raw( + arith.ArithValue(lds_base_i32) + + arith.ArithValue( + _raw(arith.constant(block_idx_const * KV_NUM_COLS, type=T.i32)) + + arith.ArithValue(lane_offset) + ) + ) + lds_addr_i64 = _raw(arith.ArithValue(lds_addr_zero).extui(T.i64)) + lds_ptr = _inttoptr_lds(lds_addr_i64) + llvm.StoreOp(zero_u32, lds_ptr, alignment=4) + scf.YieldOp([]) + with ir.InsertionPoint(if_op.regions[1].blocks[0]): + # Normal load + voff = _raw( + arith.ArithValue(_raw(row_i32)) + * arith.constant(QK_HEAD_DIM, type=T.i32) + + col_base_i32 + ) + col_off = arith.constant(block_idx_const * KV_NUM_COLS, type=T.i32) + lds_ptr_i64 = _raw(arith.ArithValue(lds_base_i32).extui(T.i64)) + lds_ptr = _inttoptr_lds(lds_ptr_i64) + rocdl.raw_ptr_buffer_load_lds( + kv_rsrc, + lds_ptr, + _raw(c_dword_sz), + voff, + _raw(c_aux_zero), + _raw(col_off), + _raw(c_aux_zero), + ) + scf.YieldOp([]) + else: + voff = _raw( + arith.ArithValue(_raw(row_i32)) + * arith.constant(QK_HEAD_DIM, type=T.i32) + + col_base_i32 + ) + col_off = arith.constant(block_idx_const * KV_NUM_COLS, type=T.i32) + lds_ptr_i64 = _raw(arith.ArithValue(lds_base_i32).extui(T.i64)) + lds_ptr = _inttoptr_lds(lds_ptr_i64) + rocdl.raw_ptr_buffer_load_lds( + kv_rsrc, + lds_ptr, + 
_raw(c_dword_sz), + voff, + _raw(c_aux_zero), + _raw(col_off), + _raw(c_aux_zero), + ) + + def _async_load_kv_all( + p_lds_kv_warp_i32, row_i32, col_base_i32, check_boundary=False + ): + """Load all 9 blocks of a KV tile.""" + for blk in range_constexpr(KV_NUM_BLOCKS): + _async_load_k_tile( + p_lds_kv_warp_i32, + row_i32, + col_base_i32, + blk, + check_boundary=check_boundary, + ) + + # ---- Inline-asm prefetch: fully opaque to LLVM waitcnt analysis ---- + def _prefetch_k_tile_asm( + p_lds_kv_warp_i32, + row_i32, + col_base_i32, + block_idx_const, + check_boundary=True, + ): + """Prefetch one KV block via inline asm buffer_load_dword lds. + + Uses inline asm for BOTH the normal load AND the OOB zero-write + so LLVM sees no LDS operations and won't insert spurious + s_waitcnt vmcnt(0) before subsequent ds_read ops. + + check_boundary: controls OOB row==-1 check. + - False (Python): skips check entirely -- caller guarantees valid row. + - True (Python): always emits scf.IfOp(row==-1). + - ir.Value (i1): emits scf.IfOp(check_boundary AND row==-1), + allowing runtime bypass. + """ + lds_warp_offset = block_idx_const * KV_BLOCK_BYTES + _lds_adj2 = arith.constant( + lds_warp_offset - block_idx_const * KV_NUM_COLS, type=T.i32 + ) + lds_base_i32 = _raw(arith.ArithValue(p_lds_kv_warp_i32) + _lds_adj2) + + def _emit_normal_load(): + voff = _raw( + arith.ArithValue(_raw(row_i32)) + * arith.constant(QK_HEAD_DIM, type=T.i32) + + col_base_i32 + ) + col_off_imm = block_idx_const * KV_NUM_COLS + asm_str = ( + "s_mov_b32 m0, $0\n" + "s_nop 0\n" + f"buffer_load_dword $1, $2, 0 offen offset:{col_off_imm} lds" + ) + llvm.InlineAsmOp( + res=None, + operands_=[lds_base_i32, voff, _raw(kv_rsrc)], + asm_string=asm_str, + constraints="s,v,s", + has_side_effects=True, + is_align_stack=False, + ) + + if check_boundary is False: + _emit_normal_load() + else: + # Build OOB condition: row == -1 + neg_one = _raw(arith.constant(-1, type=T.i32)) + is_oob = arith.cmpi(CmpIPredicate.eq, _raw(row_i32), neg_one) + # If check_boundary is a runtime i1, AND it in + if check_boundary is not True: + is_oob = _raw(arith.ArithValue(check_boundary) & arith.ArithValue(is_oob)) + + if_op = scf.IfOp(is_oob, [], has_else=True) + with ir.InsertionPoint(if_op.regions[0].blocks[0]): + # OOB: write zero to LDS via inline asm ds_write_b32 + lane_offset = _raw(lane_idx_i32 * arith.constant(4, type=T.i32)) + lds_zero_addr = _raw( + arith.ArithValue(lds_base_i32) + + arith.constant(block_idx_const * KV_NUM_COLS, type=T.i32) + + arith.ArithValue(lane_offset) + ) + llvm.InlineAsmOp( + res=None, + operands_=[lds_zero_addr, _raw(arith.constant(0, type=T.i32))], + asm_string="ds_write_b32 $0, $1", + constraints="v,v", + has_side_effects=True, + is_align_stack=False, + ) + scf.YieldOp([]) + with ir.InsertionPoint(if_op.regions[1].blocks[0]): + _emit_normal_load() + scf.YieldOp([]) + + # ---- K LDS lane base pointer (computed once, shared across all K loads) ---- + # Per-lane dynamic part of the K LDS address, stored as an LDS pointer. + # All K loads use this as base + GEP(fixed_offset), so LLVM can fold + # the fixed_offset into ds_read's 16-bit immediate offset field. 
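+ # Added example: lane 5 has _k_row_in_mfma = 5 -> _k_row_phy = (5/2)*4 + 5%2 = 9
+ # and _k_col_in_lane = 0, so _k_lds_lane_offset = 2*KV_SUB_BYTES + 1*KV_BYTES_PER_ROW
+ # = 2*264 + 64 = 592 bytes; lanes 16/32/48 share row_phy 0 and take columns
+ # 8/16/24 (MFMA_ELEM_PER_THR apart).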
+ _k_row_in_mfma = lane_idx % arith.index(MFMA_M) + _k_row_phy = (_k_row_in_mfma / arith.index(2)) * arith.index( + 4 + ) + _k_row_in_mfma % arith.index(2) + _k_col_in_lane = (lane_idx / arith.index(MFMA_M)) * arith.index(MFMA_ELEM_PER_THR) + _k_lds_lane_offset = ( + (_k_row_phy / arith.index(4)) * arith.index(KV_SUB_BYTES) + + (_k_row_phy % arith.index(4)) * arith.index(KV_BYTES_PER_ROW) + + (_k_col_in_lane % arith.index(KV_NUM_COLS)) + ) + + # ---- Vt LDS lane base offset (computed once, shared across all Vt loads) ---- + _vt_row_blk = lane_idx / arith.index(16) + _vt_col_blk = (lane_idx % arith.index(16)) / arith.index(VT_COLS_PER_THR) + _vt_row_inblk = lane_idx % arith.index(VT_ROWS_PER_THR) + _vt_col_inblk = ( + (lane_idx % arith.index(8)) / arith.index(VT_ROWS_PER_THR) + ) * arith.index(VT_ROWS_PER_THR) + _vt_block_offset = ( + _vt_row_blk * arith.index(VT_BLKS_PER_ROW_PAD) + _vt_col_blk + ) * arith.index(VT_ELEMS_PER_BLK) + _vt_inblock_offset = _vt_row_inblk * arith.index(VT_COLS_PER_THR) + _vt_col_inblk + _vt_lds_lane_offset = _vt_block_offset + _vt_inblock_offset + + # ---- Helper: load K sub-tile from LDS (16x32 for MFMA) ---- + def _load_k_from_lds(k_base_i32, row_offset, col_offset): + """Read 16x32 K sub-tile from LDS -> i64 for MFMA. + + row_offset: 0 or 16 (which half of BLOCK_N=32) + col_offset: column offset in elements (multiple of 32) + + KvManagerV2 LDS address formula: + row_phy = (row/2)*4 + (row%2) where row = lane_idx % 16 + p = p_lds_kv + (row_phy/4)*KV_SUB_BYTES + (row_phy%4)*KV_BYTES_PER_ROW + + (col%64)*sizeof(kv_t) + (col/64)*KV_BLOCK_BYTES + fixed_offset = (row_offset/16)*2*KV_BYTES_PER_ROW + + (col_offset%64)*sizeof(kv_t) + + (col_offset/64)*KV_BLOCK_BYTES + + NOTE: The fixed_offset is passed via static_byte_offset so LLVM + can potentially fold it into ds_read's immediate. Currently LLVM + lowers this to ds_read2_b64 due to inttoptr; a proper fix needs + FlyDSL infrastructure changes to emit ds_read_b64 with large offsets. + """ + # Fixed part: compile-time constant byte offset + fixed_offset = ( + (row_offset // 16) * 2 * KV_BYTES_PER_ROW + + (col_offset % KV_NUM_COLS) + + (col_offset // KV_NUM_COLS) * KV_BLOCK_BYTES + ) + + # ds_read_b64 with immediate offset (volatile prevents ds_read2 merge) + data = _lds_load_volatile(k_base_i32, T.i64, byte_offset=fixed_offset) + return data + + # ---- Helper: load V from KV LDS (un-transposed) ---- + def _load_v_from_lds(p_lds_kv_base_idx, warp_idx_val, lane_idx_val): + """Load un-transposed V: each warp reads 16x128 region. + + KvManagerV2::load_v_to_gpr pattern: + row = (warp%2)*16 + lane/16*4 + row_phy = ((row%16)/2)*4 + 2*(row/16) + (row%2) + col = (lane%16)*8 + (warp/2)*128 + Returns 8 i32 values. 
+ """ + row = (warp_idx_val % arith.index(2)) * arith.index(16) + ( + lane_idx_val / arith.index(16) + ) * arith.index(4) + row_mod16 = row % arith.index(16) + row_phy = ( + (row_mod16 / arith.index(2)) * arith.index(4) + + arith.index(2) * (row / arith.index(16)) + + row % arith.index(2) + ) + col = (lane_idx_val % arith.index(16)) * arith.index(8) + ( + warp_idx_val / arith.index(2) + ) * arith.index(128) + + lds_v_offset = ( + (row_phy / arith.index(4)) * arith.index(KV_SUB_BYTES) + + (row_phy % arith.index(4)) * arith.index(KV_BYTES_PER_ROW) + + (col / arith.index(KV_NUM_COLS)) * arith.index(KV_BLOCK_BYTES) + + (col % arith.index(KV_NUM_COLS)) + ) + + lds_addr = p_lds_kv_base_idx + lds_v_offset + + # 4 x ds_read_b64: load 8 dwords at strides matching KvManagerV2 + v_vals = [] + for pass_idx in range_constexpr(4): + if pass_idx == 0: + off = 0 + elif pass_idx == 1: + off = KV_BYTES_PER_ROW + elif pass_idx == 2: + off = KV_SUB_BYTES + else: + off = KV_SUB_BYTES + KV_BYTES_PER_ROW + data = _lds_load( + lds_addr, + T.i32x2, + static_byte_offset=off, + ) + v_vals.append( + vector.extract(data, static_position=[0], dynamic_position=[]) + ) + v_vals.append( + vector.extract(data, static_position=[1], dynamic_position=[]) + ) + return v_vals # 8 i32 values + + # ---- Helper: transpose V in-register ---- + def _transpose_v(v8): + """12x v_perm_b32 to transpose 4x8 fp8 block. + + Ported from VtManagerV1::transpose_v. + Input: v8[0..7] in row-major 4x8 layout + Output: v8[0..7] in transposed layout for Vt storage + """ + # Phase 1: perm_0 (c_perm0=0x05010400) and perm_3 (c_perm1=0x07030602) + t0_0 = _vt_perm(v8[2], v8[0], c_perm0) + t2_0 = _vt_perm(v8[2], v8[0], c_perm1) + t0_1 = _vt_perm(v8[3], v8[1], c_perm0) + t2_1 = _vt_perm(v8[3], v8[1], c_perm1) + + t1_0 = _vt_perm(v8[6], v8[4], c_perm0) + t3_0 = _vt_perm(v8[6], v8[4], c_perm1) + t1_1 = _vt_perm(v8[7], v8[5], c_perm0) + t3_1 = _vt_perm(v8[7], v8[5], c_perm1) + + # Phase 2: perm_1 (c_perm2=0x05040100) and perm_2 (c_perm3=0x07060302) + # Output order: r0_0, r0_1, r1_0, r1_1, r2_0, r2_1, r3_0, r3_1 + r = [None] * 8 + r[0] = _vt_perm(t1_0, t0_0, c_perm2) # r0_0 + r[1] = _vt_perm(t1_1, t0_1, c_perm2) # r0_1 + r[2] = _vt_perm(t1_0, t0_0, c_perm3) # r1_0 + r[3] = _vt_perm(t1_1, t0_1, c_perm3) # r1_1 + r[4] = _vt_perm(t3_0, t2_0, c_perm2) # r2_0 + r[5] = _vt_perm(t3_1, t2_1, c_perm2) # r2_1 + r[6] = _vt_perm(t3_0, t2_0, c_perm3) # r3_0 + r[7] = _vt_perm(t3_1, t2_1, c_perm3) # r3_1 + return r + + # ---- Helper: store transposed V to Vt LDS ---- + def _store_vt_to_lds(vt_lds_base_idx, warp_idx_val, lane_idx_val, vt8): + """VtManagerV1::store_transposed_v_to_lds. + + 4x8 block-wise row-major layout, no padding between rows/cols. 
+ row_blk = (warp%2)*4 + lane/16 + col_blk = (lane%16) + (warp/2)*16 + block_offset = (row_blk * VT_BLKS_PER_ROW_PAD + col_blk) * VT_ELEMS_PER_BLK + """ + row_blk = (warp_idx_val % arith.index(2)) * arith.index( + 4 + ) + lane_idx_val / arith.index(16) + col_blk = (lane_idx_val % arith.index(16)) + ( + warp_idx_val / arith.index(2) + ) * arith.index(16) + block_offset = ( + row_blk * arith.index(VT_BLKS_PER_ROW_PAD) + col_blk + ) * arith.index(VT_ELEMS_PER_BLK) + lds_vt_addr = vt_lds_base_idx + block_offset + + # ds_write_b128 x 2 (4 dwords each = 32 fp8) + lo_packed = vector.from_elements(T.i32x4, vt8[0:4]) + lo_i8 = vector.bitcast(T.i8x16, lo_packed) + vector.store(lo_i8, lds_buffer, [_raw(lds_vt_addr)]) + + hi_packed = vector.from_elements(T.i32x4, vt8[4:8]) + hi_i8 = vector.bitcast(T.i8x16, hi_packed) + vector.store(hi_i8, lds_buffer, [_raw(lds_vt_addr + arith.index(16))]) + + # ---- Helper: load transposed V from Vt LDS ---- + def _load_vt_from_lds(vt_base_i32, col_offset): + """VtManagerV1::load_transposed_v_to_gpr. + + Each warp reads 32x16 block from Vt LDS. Returns 2 i32 via ds_read_b32. + vt_base_i32: i32 LDS byte address with lane offset pre-baked. + col_offset: Python int, multiple of 16, in [0, 512). + + Lane offset pre-computed in _vt_lds_lane_offset (top level). + Only col_offset contributes a fixed immediate offset here. + offset_tl_bl = 4 * VT_BLKS_PER_ROW_PAD * VT_ELEMS_PER_BLK = 8448 + """ + fixed_col_blk = col_offset // VT_COLS_PER_THR + fixed_block_offset = fixed_col_blk * VT_ELEMS_PER_BLK + offset_tl_bl = 4 * VT_BLKS_PER_ROW_PAD * VT_ELEMS_PER_BLK # 8448 + + # ds_read_b32 x 2 with immediate offsets (volatile prevents ds_read2 merge) + v0 = _lds_load_volatile(vt_base_i32, T.i32, byte_offset=fixed_block_offset) + v1 = _lds_load_volatile( + vt_base_i32, T.i32, byte_offset=fixed_block_offset + offset_tl_bl + ) + return v0, v1 + + # ---- Helper: warp reduce (butterfly XOR) ---- + def _shfl_xor_f32(val_f32, offset_i32, width_i32): + """XOR shuffle for f32 via bitcast to i32 and back.""" + # Bitcast f32 -> i32 + val_i32 = _raw(arith.ArithValue(val_f32).bitcast(T.i32)) + # Shuffle as i32 + peer_i32 = _mlir_gpu.ShuffleOp( + val_i32, offset_i32, width_i32, _mlir_gpu.ShuffleMode.XOR + ).shuffleResult + # Bitcast i32 -> f32 + return _raw(arith.ArithValue(peer_i32).bitcast(T.f32)) + + def _warp_reduce_max_16(val): + """Butterfly max reduce across MFMA column groups. + + HK: reduce_range=64, stop_stride=15 -> strides [32, 16]. + This reduces across the 4 column groups (each owning 4 K positions) + while keeping each row (Q head) independent. + """ + w = _to_mlir(val) + width = _raw(arith.constant(WARP_SIZE, type=T.i32)) + for sh in [32, 16]: + offset = _raw(arith.constant(sh, type=T.i32)) + peer = _shfl_xor_f32(w, offset, width) + w = arith.MaximumFOp(w, peer, fastmath=fm_no_inf).result + return w + + def _warp_reduce_add_16(val): + """Butterfly sum reduce across MFMA column groups.""" + w = _to_mlir(val) + width = _raw(arith.constant(WARP_SIZE, type=T.i32)) + for sh in [32, 16]: + offset = _raw(arith.constant(sh, type=T.i32)) + peer = _shfl_xor_f32(w, offset, width) + w = arith.ArithValue(w).addf(peer, fastmath=fm_fast) + return w + + # ---- Helper: Q loading (QManagerV3) ---- + def _load_q_to_regs(qo_start_i32): + """Load Q from VRAM to registers via LDS staging. + + QManagerV3: each warp loads 16x64 per pass, 9 passes total. + VRAM -> LDS (ds_write_b128), then LDS -> register (ds_read_b64). 
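+ The 9 passes each cover 16 rows x 64 columns of Q; passes 0-7 hold the
+ NoPE portion and pass 8 the RoPE portion of the QK head dim.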
+ Returns (q_nope_regs, q_rope_regs): + q_nope_regs: list of 16 v2i64 (16 sub-tiles x 32 cols each) + q_rope_regs: list of 2 v2i64 (2 sub-tiles x 32 cols each) + """ + p_lds_q_warp = ( + lds_base_idx + + arith.index(P_LDS_Q) + + warp_idx * arith.index(SZ_LDS_Q_PER_WARP) + ) + + # VRAM addressing: row = lane/4, col = (lane%4)*16 + # s_offset = warp * 16 * QK_HEAD_DIM * sizeof(fp8) + # v_offset = (row * QK_HEAD_DIM + col) * sizeof(fp8) + s_offset_i32 = _raw( + warp_idx_i32 * arith.constant(16 * QK_HEAD_DIM, type=T.i32) + ) + # Add qo_start offset: qo_start * NUM_QO_HEADS * QK_HEAD_DIM + q_base_offset = _raw( + arith.ArithValue(_raw(qo_start_i32)) + * arith.constant(NUM_QO_HEADS * QK_HEAD_DIM, type=T.i32) + ) + s_offset_i32 = _raw(arith.ArithValue(s_offset_i32) + arith.ArithValue(q_base_offset)) + + row = lane_idx / arith.index(4) + col = (lane_idx % arith.index(4)) * arith.index(16) + v_offset_i32 = _index_cast_to_i32(row * arith.index(QK_HEAD_DIM) + col) + + # LDS store layout (QManagerV3): + # row_st = lane/4, col_st = (lane%4)*16 + # v_offset_st = (row_st/2)*Q_BYTES_PER_2ROWS + ((row_st%2)*64 + col_st) + row_st = lane_idx / arith.index(4) + col_st = (lane_idx % arith.index(4)) * arith.index(16) + lds_st_offset = ( + (row_st / arith.index(2)) * arith.index(Q_BYTES_PER_2ROWS) + + (row_st % arith.index(2)) * arith.index(Q_ELEM_PER_ROW) + + col_st + ) + + # LDS read layout (MFMA-compatible): + # row_ld = lane%16, col_ld = (lane/16)*8 + # v_offset_ld = (row_ld/2)*Q_BYTES_PER_2ROWS + ((row_ld%2)*64 + col_ld) + row_ld = lane_idx % arith.index(16) + col_ld = (lane_idx / arith.index(16)) * arith.index(8) + lds_ld_offset = ( + (row_ld / arith.index(2)) * arith.index(Q_BYTES_PER_2ROWS) + + (row_ld % arith.index(2)) * arith.index(Q_ELEM_PER_ROW) + + col_ld + ) + + q_regs = [] # Will hold 18 v2i64 = 16 nope + 2 rope + + # Fold s_offset and per-pass ioffset into voffset so that soffset=0. + # LLVM ISel only extracts immediate offsets when soffset is literal 0. + # v_offset_i32 is in bytes; buffer_load auto-scales by element_bytes + # (i32 = 4), so divide by 4. s_offset_i32 is also in bytes. + voff_dw = _raw( + (arith.ArithValue(_raw(v_offset_i32)) + arith.ArithValue(s_offset_i32)) + // arith.constant(4, type=T.i32) + ) + + # Pre-compute LDS pointers (constant across passes) + lds_st_addr = p_lds_q_warp + lds_st_offset + lds_st_i64 = arith.index_cast(T.i64, lds_st_addr) + lds_st_ptr = _inttoptr_lds(_raw(lds_st_i64)) + lds_rd_addr = p_lds_q_warp + lds_ld_offset + + def _q_buf_load(pass_idx): + voff_pass = _raw( + arith.ArithValue(voff_dw) + + arith.constant(pass_idx * Q_ELEM_PER_ROW // 4, type=T.i32) + ) + return buffer_ops.buffer_load( + query_rsrc, + voff_pass, + vec_width=4, + dtype=T.i32, + ) + + def _shuffle_q_through_lds(q_vram_data): + """LDS write (ds_write_b128) + barrier + LDS read (2x ds_read_b64).""" + rocdl.s_waitcnt(_encode_waitcnt(lgkmcnt=0)) + llvm.StoreOp(_raw(q_vram_data), lds_st_ptr, alignment=16) + rocdl.s_waitcnt(_encode_waitcnt(lgkmcnt=0)) + q0 = _lds_load(lds_rd_addr, T.i64, static_byte_offset=0) + q1 = _lds_load(lds_rd_addr, T.i64, static_byte_offset=MFMA_K) + return (q0, q1) + + # 3-deep pipeline: keep 2 buffer_loads in flight while shuffling + # the completed one through LDS (matches HK QManagerV3). 
+ # Before loop: issue passes 0, 1 + # Iteration i: wait(1), issue pass i+2, shuffle pass i + # Last 2 iters: wait(0), shuffle (no new issue) + loads = [None, None, None] + loads[0] = _q_buf_load(0) + loads[1] = _q_buf_load(1) + + for i in range_constexpr(9): + slot = i % 3 + issue_pass = i + 2 + + if issue_pass < 9: + rocdl.s_waitcnt(_encode_waitcnt(vmcnt=1)) + loads[issue_pass % 3] = _q_buf_load(issue_pass) + else: + rocdl.s_waitcnt(_encode_waitcnt(vmcnt=0)) + + q_regs.append(_shuffle_q_through_lds(loads[slot])) + + # Split into nope (passes 0-7 -> 16 sub-tiles) and rope (pass 8 -> 2 sub-tiles) + q_nope_packs = [] + for i in range_constexpr(8): + q_nope_packs.append(q_regs[i][0]) # sub-tile 0 + q_nope_packs.append(q_regs[i][1]) # sub-tile 1 + q_rope_packs = [q_regs[8][0], q_regs[8][1]] + return q_nope_packs, q_rope_packs + + # ---- Helper: softmax scale + boundary masking ---- + def _softmax_scale_p(p_vals, col_0_start_i32, kv_end_i32, check_boundary): + """Scale p_vals by softmax_scale, mask OOB to -inf. + + check_boundary: False (skip), True (always mask), or ir.Value i1 + (runtime: mask only when True at runtime). + """ + result = [None] * 8 + for i in range_constexpr(8): + result[i] = arith.MulFOp( + _raw(p_vals[i]), _raw(softmax_scale), fastmath=fm_fast + ).result + + if check_boundary is not False: + for i in range_constexpr(8): + # Position of this element: col_0_start + (i//4)*16 + (i%4) + sub_offset = (i // 4) * 16 + (i % 4) + pos_i32 = _raw( + arith.ArithValue(_raw(col_0_start_i32)) + + arith.constant(sub_offset, type=T.i32) + ) + is_oob = arith.cmpi( + CmpIPredicate.sge, pos_i32, _raw(kv_end_i32) + ) + # If check_boundary is a runtime i1, AND it in + if check_boundary is not True: + is_oob = _raw(arith.ArithValue(check_boundary) & arith.ArithValue(is_oob)) + result[i] = _raw(arith.select(is_oob, _raw(c_neg_inf_f32), result[i])) + return result + + # ---- Helper: online softmax ---- + def _softmax( + p_vals, + row_max_old, + row_sum_e_old, + is_first_iter, + kv_tile_start_i32, + kv_end_i32, + check_boundary, + ): + """Online softmax: scale -> max -> exp2 -> sum -> rescale. 
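+
+ Per-row state follows the standard online-softmax recurrence (mirroring
+ the code below):
+     new_max   = max(old_max, max_i p_i)
+     rescale   = exp2((old_max - new_max) * log2e)
+     new_sum_e = rescale * old_sum_e + sum_i exp2((p_i - new_max) * log2e)
+ On the first iteration old_max/old_sum_e are ignored and rescale is 1.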
+ + p_vals: 8 f32 attention scores for this thread + Returns: (p_exp_vals, row_max_new, row_sum_e_new, rescale) + """ + # Column index for this thread's first element + col_0_idx = lane_idx / arith.index(16) + col_0_start_i32 = _raw( + arith.ArithValue(_index_cast_to_i32(col_0_idx * arith.index(4))) + + kv_tile_start_i32 + ) + + # Scale and mask + scaled = _softmax_scale_p(p_vals, col_0_start_i32, kv_end_i32, check_boundary) + + # Local max of 8 values + local_max = scaled[0] + for i in range_constexpr(1, 8): + local_max = arith.MaximumFOp( + local_max, _raw(scaled[i]), fastmath=fm_no_inf + ).result + + # Warp reduce max (within 16-lane groups) + local_max = _warp_reduce_max_16(local_max) + + # New row max + if is_first_iter: + new_row_max = local_max + rescale = _raw(c_one_f32) + else: + new_row_max = arith.MaximumFOp( + local_max, _raw(row_max_old), fastmath=fm_no_inf + ).result + # rescale = exp2((old_max - new_max) * log2e) + diff = arith.SubFOp( + _raw(row_max_old), new_row_max, fastmath=fm_no_inf + ).result + rescale_arg = arith.MulFOp( + diff, _raw(c_log2e), fastmath=fm_no_inf + ).result + rescale = _fast_exp2(rescale_arg) + + # exp(p - max) for each value, and sum + p_exp_vals = [None] * 8 + local_sum = _raw(c_zero_f32) + for i in range_constexpr(8): + # exp2((p[i] - new_max) * log2e) + diff = arith.SubFOp( + _raw(scaled[i]), new_row_max, fastmath=fm_no_inf + ).result + exp_arg = arith.MulFOp(diff, _raw(c_log2e), fastmath=fm_no_inf).result + p_exp_vals[i] = _fast_exp2(exp_arg) + local_sum = arith.AddFOp( + local_sum, p_exp_vals[i], fastmath=fm_fast + ).result + + # Warp reduce sum + local_sum = _warp_reduce_add_16(local_sum) + + # Update row_sum_e + if is_first_iter: + row_sum_e_new = local_sum + else: + row_sum_e_new = arith.AddFOp( + arith.MulFOp( + rescale, _raw(row_sum_e_old), fastmath=fm_fast + ).result, + local_sum, + fastmath=fm_fast, + ).result + + return p_exp_vals, new_row_max, row_sum_e_new, rescale + + # ---- Helper: pack P from f32 to fp8 ---- + def _pack_p_to_fp8(p_exp_vals): + """Pack 8 f32 -> 2 i32 (4x cvt_pk_fp8_f32) -> 1 i64 for MFMA.""" + w0 = rocdl.cvt_pk_fp8_f32( + T.i32, _raw(p_exp_vals[0]), _raw(p_exp_vals[1]), c_zero_i32, 0 + ) + w0 = rocdl.cvt_pk_fp8_f32( + T.i32, _raw(p_exp_vals[2]), _raw(p_exp_vals[3]), w0, 1 + ) + w1 = rocdl.cvt_pk_fp8_f32( + T.i32, _raw(p_exp_vals[4]), _raw(p_exp_vals[5]), c_zero_i32, 0 + ) + w1 = rocdl.cvt_pk_fp8_f32( + T.i32, _raw(p_exp_vals[6]), _raw(p_exp_vals[7]), w1, 1 + ) + w0_i64 = arith.ArithValue(w0).extui(T.i64) + w1_i64 = arith.ArithValue(w1).extui(T.i64) + c32_i64 = arith.constant(32, type=T.i64) + w1_shifted = w1_i64 << c32_i64 + p_pack = _raw(w0_i64 | w1_shifted) + return p_pack + + # ---- Helper: rescale oaccu ---- + def _rescale_oaccu(oaccu, rescale): + """Multiply all oaccu accumulators by rescale factor. + Descending s_setprio 3->0 across 4 groups of 8 muls.""" + rescale_vec = vector.broadcast(T.f32x4, rescale) + result = [None] * len(oaccu) + for group in range_constexpr(4): + rocdl.s_setprio(3 - group) + for j in range_constexpr(8): + i = group * 8 + j + result[i] = arith.MulFOp( + _raw(oaccu[i]), _raw(rescale_vec), fastmath=fm_fast + ).result + return result + + # ---- Helper: process one KV tile (GEMM1 + softmax + V + GEMM2) ---- + # Interleaves async prefetch of the NEXT tile's KV data + # into the GEMM1 NoPE loop (1 block per iteration, 9 total). 
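+ # Callers double-buffer the KV LDS region: tile i is read from one of
+ # P_LDS_KV_0 / P_LDS_KV_1 (chosen by tile parity) while the prefetch
+ # issued here fills the other buffer for tile i+1.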
+ def _process_tile_gemm1( + p_lds_kv_base, + kv_tile_start_i32, + kv_end_i32, + q_nope, + q_rope, + row_max_in, + row_sum_e_in, + is_first_iter, + check_boundary, + p_lds_kv_next_warp_i32=None, + row_kv_ld_next=None, + kv_ld_col_base_i32_arg=None, + check_boundary_next=True, + # 2-ahead row resolution (match HK's row_kv_ld_next_next pattern) + nn_resolve_start=None, + nn_resolve_end=None, + do_resolve_nn=None, + ): + """Process one KV tile: QK GEMM -> softmax -> V transpose -> pack P. + + GEMM2 (PV accumulation) is NOT included -- call _gemm2_with_rescale + after the branch merge to keep oaccu out of phi nodes. + + Returns (row_max, row_sum_e, p_pack, rescale). + """ + # ---- K base VGPR (baked-in lane offset) ---- + k_base_i32 = _raw( + arith.ArithValue(_index_cast_to_i32(p_lds_kv_base)) + + arith.ArithValue(_index_cast_to_i32(_k_lds_lane_offset)) + ) + + do_prefetch = p_lds_kv_next_warp_i32 is not None + + def _maybe_prefetch(block_idx): + """Issue prefetch (OOB check controlled by check_boundary_next).""" + if not do_prefetch: + return + _prefetch_k_tile_asm( + p_lds_kv_next_warp_i32, + row_kv_ld_next, + kv_ld_col_base_i32_arg, + block_idx, + check_boundary=check_boundary_next, + ) + + # ---- Prefetch block 0 of next tile (inline asm, opaque to LLVM) ---- + _maybe_prefetch(0) + + # ---- GEMM1: QK attention scores ---- + p_comp = [_raw(c_zero_v4f32), _raw(c_zero_v4f32)] + + for nope_pair in range_constexpr(NUM_NOPE_ITERS): + tile_0 = nope_pair * 2 + tile_1 = nope_pair * 2 + 1 + + k0_lo = _load_k_from_lds(k_base_i32, 0, tile_0 * BLOCK_K) + k0_hi = _load_k_from_lds(k_base_i32, 16, tile_0 * BLOCK_K) + k1_lo = _load_k_from_lds(k_base_i32, 0, tile_1 * BLOCK_K) + k1_hi = _load_k_from_lds(k_base_i32, 16, tile_1 * BLOCK_K) + + # Prefetch block nope_pair+1 of next tile (inline asm) + _maybe_prefetch(nope_pair + 1) + + rocdl.sched_barrier(0) + rocdl.s_waitcnt(_encode_waitcnt(lgkmcnt=2)) + + q_0 = q_nope[tile_0] + q_1 = q_nope[tile_1] + + if nope_pair == 0: + p_comp[0] = _mfma_fp8( + T.f32x4, [k0_lo, q_0, _raw(c_zero_v4f32), 0, 0, 0] + ) + p_comp[1] = _mfma_fp8( + T.f32x4, [k0_hi, q_0, _raw(c_zero_v4f32), 0, 0, 0] + ) + rocdl.s_setprio(15) + else: + p_comp[0] = _mfma_fp8(T.f32x4, [k0_lo, q_0, p_comp[0], 0, 0, 0]) + p_comp[1] = _mfma_fp8(T.f32x4, [k0_hi, q_0, p_comp[1], 0, 0, 0]) + + rocdl.s_waitcnt(_encode_waitcnt(lgkmcnt=0)) + + p_comp[0] = _mfma_fp8(T.f32x4, [k1_lo, q_1, p_comp[0], 0, 0, 0]) + p_comp[1] = _mfma_fp8(T.f32x4, [k1_hi, q_1, p_comp[1], 0, 0, 0]) + + for rope_pair in range_constexpr(NUM_ROPE_ITERS): + tile_0 = rope_pair * 2 + tile_1 = rope_pair * 2 + 1 + + k0_lo = _load_k_from_lds(k_base_i32, 0, (tile_0 + 16) * BLOCK_K) + k0_hi = _load_k_from_lds(k_base_i32, 16, (tile_0 + 16) * BLOCK_K) + k1_lo = _load_k_from_lds(k_base_i32, 0, (tile_1 + 16) * BLOCK_K) + k1_hi = _load_k_from_lds(k_base_i32, 16, (tile_1 + 16) * BLOCK_K) + + rocdl.sched_barrier(0) + rocdl.s_waitcnt(_encode_waitcnt(lgkmcnt=2)) + + p_comp[0] = _mfma_fp8(T.f32x4, [k0_lo, q_rope[tile_0], p_comp[0], 0, 0, 0]) + p_comp[1] = _mfma_fp8(T.f32x4, [k0_hi, q_rope[tile_0], p_comp[1], 0, 0, 0]) + + rocdl.s_waitcnt(_encode_waitcnt(lgkmcnt=0)) + + p_comp[0] = _mfma_fp8(T.f32x4, [k1_lo, q_rope[tile_1], p_comp[0], 0, 0, 0]) + p_comp[1] = _mfma_fp8(T.f32x4, [k1_hi, q_rope[tile_1], p_comp[1], 0, 0, 0]) + + rocdl.s_setprio(14) + + # ---- Extract p_comp values for softmax ---- + p_vals = [] + for sub in range_constexpr(2): + for ii in range_constexpr(4): + p_vals.append( + vector.extract( + p_comp[sub], static_position=[ii], 
dynamic_position=[] + ) + ) + + # ---- Load V from KV LDS ---- + v8_raw = _load_v_from_lds(p_lds_kv_base, warp_idx, lane_idx) + rocdl.s_waitcnt(_encode_waitcnt(lgkmcnt=0)) + rocdl.sched_barrier(0) + + # ---- Resolve row for tile+2 (2-ahead, matches HK line 407-426) ---- + # The buffer_load has softmax+V-transpose+GEMM2+barrier to complete. + if do_resolve_nn is not None: + neg_one_nn = _raw(arith.constant(-1, type=T.i32)) + if_nn = scf.IfOp(do_resolve_nn, [T.i32], has_else=True) + with ir.InsertionPoint(if_nn.regions[0].blocks[0]): + row_nn_resolved = _get_kv_ld_row(nn_resolve_start, nn_resolve_end, True) + scf.YieldOp([row_nn_resolved]) + with ir.InsertionPoint(if_nn.regions[1].blocks[0]): + scf.YieldOp([neg_one_nn]) + row_kv_ld_nn = if_nn.results[0] + else: + row_kv_ld_nn = _raw(arith.constant(-1, type=T.i32)) + + # ---- Softmax ---- + p_exp_vals, row_max_new, row_sum_e_new, rescale = _softmax( + p_vals, + row_max_in, + row_sum_e_in, + is_first_iter, + kv_tile_start_i32, + kv_end_i32, + check_boundary, + ) + + # ---- Transpose V and store to Vt LDS ---- + vt8 = _transpose_v(v8_raw) + vt_lds_base = lds_base_idx + arith.index(P_LDS_VT) + _store_vt_to_lds(vt_lds_base, warp_idx, lane_idx, vt8) + + # ---- Pack P to fp8 ---- + p_pack = _pack_p_to_fp8(p_exp_vals) + + return row_max_new, row_sum_e_new, p_pack, rescale, row_kv_ld_nn + + def _gemm2_core(p_pack, oaccu, vt_base_i32): + """GEMM2 PV accumulation loop (shared by first-iter and rescale paths). + + Matches HK interleaving: 8x ds_read_b32 burst (2 PV iters), + lgkmcnt(4) -> 2 MFMA, lgkmcnt(0) -> 2 MFMA. + """ + c32_i64_pv = _raw(arith.constant(32, type=T.i64)) + rocdl.s_setprio(15) + for pv_pair in range_constexpr(NUM_PV_ITERS // 2): + # Load 8 values: vt for 2 consecutive PV iterations + iter_a = pv_pair * 2 + iter_b = pv_pair * 2 + 1 + col_a0 = iter_a * MFMA_N * 2 + col_a1 = col_a0 + MFMA_N + col_b0 = iter_b * MFMA_N * 2 + col_b1 = col_b0 + MFMA_N + + # 8x ds_read_b32 burst + vta0_lo, vta0_hi = _load_vt_from_lds(vt_base_i32, col_a0) + vta1_lo, vta1_hi = _load_vt_from_lds(vt_base_i32, col_a1) + vtb0_lo, vtb0_hi = _load_vt_from_lds(vt_base_i32, col_b0) + vtb1_lo, vtb1_hi = _load_vt_from_lds(vt_base_i32, col_b1) + + rocdl.sched_barrier(0) + rocdl.s_waitcnt(_encode_waitcnt(lgkmcnt=4)) + + # MFMA pair A (first PV iter) + kv_mfma_a0 = _raw( + arith.ArithValue(vta0_lo).extui(T.i64) + | (arith.ArithValue(vta0_hi).extui(T.i64) << c32_i64_pv) + ) + oaccu[iter_a * 2] = _mfma_fp8( + T.f32x4, [kv_mfma_a0, p_pack, oaccu[iter_a * 2], 0, 0, 0] + ) + + kv_mfma_a1 = _raw( + arith.ArithValue(vta1_lo).extui(T.i64) + | (arith.ArithValue(vta1_hi).extui(T.i64) << c32_i64_pv) + ) + oaccu[iter_a * 2 + 1] = _mfma_fp8( + T.f32x4, [kv_mfma_a1, p_pack, oaccu[iter_a * 2 + 1], 0, 0, 0] + ) + rocdl.sched_barrier(0) + rocdl.s_waitcnt(_encode_waitcnt(lgkmcnt=0)) + + # MFMA pair B (second PV iter) + kv_mfma_b0 = _raw( + arith.ArithValue(vtb0_lo).extui(T.i64) + | (arith.ArithValue(vtb0_hi).extui(T.i64) << c32_i64_pv) + ) + oaccu[iter_b * 2] = _mfma_fp8( + T.f32x4, [kv_mfma_b0, p_pack, oaccu[iter_b * 2], 0, 0, 0] + ) + + kv_mfma_b1 = _raw( + arith.ArithValue(vtb1_lo).extui(T.i64) + | (arith.ArithValue(vtb1_hi).extui(T.i64) << c32_i64_pv) + ) + oaccu[iter_b * 2 + 1] = _mfma_fp8( + T.f32x4, [kv_mfma_b1, p_pack, oaccu[iter_b * 2 + 1], 0, 0, 0] + ) + rocdl.sched_barrier(0) + + if pv_pair < NUM_PV_ITERS // 2 - 1: + rocdl.s_nop(1) + + rocdl.s_setprio(0) + return oaccu + + def _gemm2_first_iter(p_pack, vt_base_i32): + """GEMM2 for first iteration: C=0 (hardcoded), no rescale. 
+ + The MFMA C input is literal c_zero_v4f32, so LLVM doesn't need + oaccu registers live -- results go to fresh registers. + """ + _barrier(lgkmcnt=0) + rocdl.sched_barrier(0) + oaccu = [_raw(c_zero_v4f32)] * (NUM_PV_ITERS * 2) + return _gemm2_core(p_pack, oaccu, vt_base_i32) + + def _gemm2_with_rescale(p_pack, rescale, oaccu_in, vt_base_i32): + """Rescale oaccu, barrier, then GEMM2 PV accumulation. + + This runs after the branch merge so oaccu never enters phi nodes. + """ + oaccu = _rescale_oaccu(oaccu_in, rescale) + _barrier(lgkmcnt=0) + rocdl.sched_barrier(0) + return _gemm2_core(p_pack, oaccu, vt_base_i32) + + def _pack_f32x4_to_bf16_2dw(acc_val): + """Convert f32x4 accumulator to 2 packed bf16 dwords.""" + bf16_vals = arith.trunc_f(T.bf16x4, acc_val) + i16_vals = _raw(vector.bitcast(T.i16x4, bf16_vals)) + i16_0 = vector.extract(i16_vals, static_position=[0], dynamic_position=[]) + i16_1 = vector.extract(i16_vals, static_position=[1], dynamic_position=[]) + i16_2 = vector.extract(i16_vals, static_position=[2], dynamic_position=[]) + i16_3 = vector.extract(i16_vals, static_position=[3], dynamic_position=[]) + c16 = arith.constant(16, type=T.i32) + lo_0 = arith.ArithValue(i16_0).extui(T.i32) + hi_0 = arith.ArithValue(i16_1).extui(T.i32) + dw0 = _raw(lo_0 | (hi_0 << c16)) + lo_1 = arith.ArithValue(i16_2).extui(T.i32) + hi_1 = arith.ArithValue(i16_3).extui(T.i32) + dw1 = _raw(lo_1 | (hi_1 << c16)) + return dw0, dw1 + + # ---- Pre-compute LDS reshape addresses (computed once, reused per store) ---- + # bf16 LDS write address (MFMA layout): row_st = lane%16, col_st = (lane/16)*4 + _li = arith.ArithValue(_raw(lane_idx_i32)) + _c2 = arith.constant(2, type=T.i32) + _c4 = arith.constant(4, type=T.i32) + _c8 = arith.constant(8, type=T.i32) + _c16_i = arith.constant(16, type=T.i32) + + _o16_row_st = arith.RemUIOp(_raw(_li), _raw(_c16_i)).result + _o16_col_st = _raw( + arith.ArithValue(arith.DivUIOp(_raw(_li), _raw(_c16_i)).result) * _c4 + ) + # get_v_offset_lds(r,c) = ((r/2)*68 + (r%2)*32 + c) * 2 [bytes] + _o16_st_r = arith.ArithValue(_o16_row_st) + _o16_st_offset = _raw( + ( + arith.ArithValue(arith.DivUIOp(_o16_row_st, _raw(_c2)).result) + * arith.constant(O16_ELEM_PER_PAD_2ROWS, type=T.i32) + + arith.ArithValue(arith.RemUIOp(_o16_row_st, _raw(_c2)).result) + * arith.constant(O16_NUM_COLS, type=T.i32) + + arith.ArithValue(_o16_col_st) + ) + * _c2 # sizeof(bf16) + ) + + # bf16 LDS read address (coalesced layout): row_ld = lane/4, col_ld = (lane%4)*8 + _o16_row_ld = arith.DivUIOp(_raw(_li), _raw(_c4)).result + _o16_col_ld = _raw( + arith.ArithValue(arith.RemUIOp(_raw(_li), _raw(_c4)).result) * _c8 + ) + _o16_rd_offset = _raw( + ( + arith.ArithValue(arith.DivUIOp(_o16_row_ld, _raw(_c2)).result) + * arith.constant(O16_ELEM_PER_PAD_2ROWS, type=T.i32) + + arith.ArithValue(arith.RemUIOp(_o16_row_ld, _raw(_c2)).result) + * arith.constant(O16_NUM_COLS, type=T.i32) + + arith.ArithValue(_o16_col_ld) + ) + * _c2 # sizeof(bf16) + ) + + # f32 LDS write address: same row_st/col_st but different padding + # get_v_offset_lds(r,c) = (r * 36 + c) * 4 [bytes] + _o32_st_offset = _raw( + ( + arith.ArithValue(_o16_row_st) # reuse: lane%16 + * arith.constant(O32_ELEM_PER_PAD_ROW, type=T.i32) + + arith.ArithValue(_o16_col_st) # reuse: (lane/16)*4 + ) + * _c4 # sizeof(f32) + ) + + # f32 LDS read address: row_ld = lane/8, col_ld = (lane%8)*4 + _o32_row_ld = arith.DivUIOp(_raw(_li), _raw(_c8)).result + _o32_col_ld = _raw( + arith.ArithValue(arith.RemUIOp(_raw(_li), _raw(_c8)).result) * _c4 + ) + _o32_rd_offset = _raw( + 
( + arith.ArithValue(_o32_row_ld) + * arith.constant(O32_ELEM_PER_PAD_ROW, type=T.i32) + + arith.ArithValue(_o32_col_ld) + ) + * _c4 # sizeof(f32) + ) + + def _store_oaccu_pair_bf16(oaccu_a, oaccu_b, tile_idx, p_lds_o_i32, row_base_i32): + """Store 2 oaccu groups (1 PV iter) as bf16 via LDS reshape. + + Matches HK OManager16bitsV2: writes MFMA-layout data to LDS, + reads back in row-major coalesced layout, then buffer_store_dwordx4. + """ + # Per-warp LDS base + lds_warp = _raw( + arith.ArithValue(p_lds_o_i32) + + warp_idx_i32 * arith.constant(O16_LDS_PER_WARP, type=T.i32) + ) + lds_st_addr = _raw(arith.ArithValue(lds_warp) + arith.ArithValue(_o16_st_offset)) + + # LDS write: 2 sub-blocks -> 2x ds_write_b64 + for sub, acc_val in enumerate([oaccu_a, oaccu_b]): + dw0, dw1 = _pack_f32x4_to_bf16_2dw(acc_val) + vec_2dw = vector.from_elements(T.i32x2, [dw0, dw1]) + sub_offset = sub * O16_NUM_COLS # 0 or 32 bytes (16 bf16 cols x 2 bytes) + st_addr_sub = _raw( + arith.ArithValue(lds_st_addr) + arith.constant(sub_offset, type=T.i32) + ) + st_i64 = _raw(arith.ArithValue(st_addr_sub).extui(T.i64)) + st_ptr = _inttoptr_lds(st_i64) + llvm.StoreOp( + vec_2dw, + st_ptr, + alignment=8, + volatile_=True, + ) # ds_write_b64 + + rocdl.s_waitcnt(_encode_waitcnt(lgkmcnt=0)) + + # LDS read: ds_read_b128 (4 dwords = 8 bf16 in coalesced layout) + lds_rd_addr = _raw(arith.ArithValue(lds_warp) + arith.ArithValue(_o16_rd_offset)) + rd_i64 = _raw(arith.ArithValue(lds_rd_addr).extui(T.i64)) + rd_ptr = _inttoptr_lds(rd_i64) + data = llvm.LoadOp(T.i32x4, rd_ptr, alignment=16).result # ds_read_b128 + + rocdl.s_waitcnt(_encode_waitcnt(lgkmcnt=0)) + + # Coalesced VRAM store: buffer_store_dwordx4 + # row = row_ld + row_base, col = col_ld + tile_idx * MFMA_N * 2 + col_offset_i32 = arith.constant(tile_idx * MFMA_N * 2, type=T.i32) + row_vram = arith.ArithValue(row_base_i32) + arith.ArithValue(_o16_row_ld) + col_vram = arith.ArithValue(_o16_col_ld) + col_offset_i32 + vram_offset = _raw( + (row_vram * arith.constant(V_HEAD_DIM, type=T.i32) + col_vram) + * arith.constant(2, type=T.i32) # sizeof(bf16) + ) + buffer_ops.buffer_store( + data, + final_output_rsrc, + vram_offset, + offset_is_bytes=True, + ) + + def _store_oaccu_pair_split(oaccu_a, oaccu_b, tile_idx, p_lds_o_i32, row_base_i32): + """Store 2 oaccu groups (1 PV iter) as f32 via LDS reshape. + + Matches HK OManager32bitsV2: writes MFMA-layout f32 data to LDS, + reads back in row-major coalesced layout, then buffer_store_dwordx4. + 16 rows need 2 rounds (8 rows each) because 64 lanes / 8 lanes-per-row = 8. 
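+ Round 1 reads O32_LD_DELTA = 8 * O32_ELEM_PER_PAD_ROW * 4 bytes past
+ round 0, i.e. 8 padded f32 rows further into the warp's LDS slice.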
+ """ + # Per-warp LDS base + lds_warp = _raw( + arith.ArithValue(p_lds_o_i32) + + warp_idx_i32 * arith.constant(O32_LDS_PER_WARP, type=T.i32) + ) + lds_st_addr = _raw(arith.ArithValue(lds_warp) + arith.ArithValue(_o32_st_offset)) + + col_offset_i32 = _raw(arith.constant(tile_idx * MFMA_N * 2, type=T.i32)) + O32_LD_DELTA = 8 * O32_ELEM_PER_PAD_ROW * 4 # 1152 bytes between round 0/1 + + # LDS write: 2 sub-blocks -> 2x ds_write_b128 + rocdl.s_waitcnt(_encode_waitcnt(vmcnt=0)) # HK pattern: drain prior stores + for sub, acc_val in enumerate([oaccu_a, oaccu_b]): + sub_offset = sub * O32_NUM_COLS // 2 * 4 # 0 or 64 bytes + st_addr_sub = _raw( + arith.ArithValue(lds_st_addr) + arith.constant(sub_offset, type=T.i32) + ) + st_i64 = _raw(arith.ArithValue(st_addr_sub).extui(T.i64)) + st_ptr = _inttoptr_lds(st_i64) + llvm.StoreOp(acc_val, st_ptr, alignment=16) # ds_write_b128 + + rocdl.s_waitcnt(_encode_waitcnt(lgkmcnt=0)) + + # LDS read: 2x ds_read_b128 (round 0 = rows 0-7, round 1 = rows 8-15) + lds_rd_addr = _raw(arith.ArithValue(lds_warp) + arith.ArithValue(_o32_rd_offset)) + rd_i64 = _raw(arith.ArithValue(lds_rd_addr).extui(T.i64)) + rd_ptr = _inttoptr_lds(rd_i64) + data_0 = llvm.LoadOp(T.f32x4, rd_ptr, alignment=16).result # rows 0-7 + rd_ptr_1 = _get_element_ptr(rd_ptr, static_byte_offset=O32_LD_DELTA) + data_1 = llvm.LoadOp(T.f32x4, rd_ptr_1, alignment=16).result # rows 8-15 + + rocdl.s_waitcnt(_encode_waitcnt(lgkmcnt=0)) + + # 2x coalesced VRAM store + # Round 0: row = row_ld_base(0..7) + row_base + row_vram_0 = arith.ArithValue(row_base_i32) + arith.ArithValue(_o32_row_ld) + col_vram = arith.ArithValue(_o32_col_ld) + arith.ArithValue(col_offset_i32) + vram_off_0 = _raw( + (row_vram_0 * arith.constant(V_HEAD_DIM, type=T.i32) + col_vram) + * arith.constant(4, type=T.i32) # sizeof(f32) + ) + # Bitcast f32x4 -> i32x4 for buffer_store + data_0_i32 = _raw(vector.bitcast(T.i32x4, data_0)) + buffer_ops.buffer_store( + data_0_i32, + split_output_rsrc, + vram_off_0, + offset_is_bytes=True, + ) + + # Round 1: row = row_ld_base + 8 + row_base + row_vram_1 = row_vram_0 + arith.constant(8, type=T.i32) + vram_off_1 = _raw( + (row_vram_1 * arith.constant(V_HEAD_DIM, type=T.i32) + col_vram) + * arith.constant(4, type=T.i32) + ) + data_1_i32 = _raw(vector.bitcast(T.i32x4, data_1)) + buffer_ops.buffer_store( + data_1_i32, + split_output_rsrc, + vram_off_1, + offset_is_bytes=True, + ) + + def _gemm2_last_with_store( + p_pack, + rescale, + oaccu_in, + vt_base_i32, + reci_sum, + is_split, + p_lds_o_i32, + row_base_i32, + is_first_iter_flag, + ): + """Last-tile GEMM2: interleave rescale + MFMA + normalize + store. + + Matches HK's kIsLastIter pattern. For each of 8 PV pairs: + 1. Rescale 2 oaccu groups (skip if first iter) + 2. Load Vt from LDS (4x ds_read) + 3. 2 MFMAs (accumulate or init) + 4. Multiply by reci_sum + 5. 
Store immediately (bf16 or f32 split) + """ + rescale_vec = vector.broadcast(T.f32x4, rescale) + reci_vec = vector.broadcast(T.f32x4, reci_sum) + c32_i64_pv = _raw(arith.constant(32, type=T.i64)) + + _barrier(lgkmcnt=0) + rocdl.sched_barrier(0) + rocdl.s_setprio(15) + for pv_pair in range_constexpr(NUM_PV_ITERS // 2): + iter_a = pv_pair * 2 + iter_b = pv_pair * 2 + 1 + col_a0 = iter_a * MFMA_N * 2 + col_a1 = col_a0 + MFMA_N + col_b0 = iter_b * MFMA_N * 2 + col_b1 = col_b0 + MFMA_N + + # Rescale 4 oaccu groups for this pair (skip if first iter) + if not is_first_iter_flag: + oaccu_in[iter_a * 2] = arith.MulFOp( + _raw(oaccu_in[iter_a * 2]), + _raw(rescale_vec), + fastmath=fm_fast, + ).result + oaccu_in[iter_a * 2 + 1] = arith.MulFOp( + _raw(oaccu_in[iter_a * 2 + 1]), + _raw(rescale_vec), + fastmath=fm_fast, + ).result + oaccu_in[iter_b * 2] = arith.MulFOp( + _raw(oaccu_in[iter_b * 2]), + _raw(rescale_vec), + fastmath=fm_fast, + ).result + oaccu_in[iter_b * 2 + 1] = arith.MulFOp( + _raw(oaccu_in[iter_b * 2 + 1]), + _raw(rescale_vec), + fastmath=fm_fast, + ).result + + # 8x ds_read_b32 burst + vta0_lo, vta0_hi = _load_vt_from_lds(vt_base_i32, col_a0) + vta1_lo, vta1_hi = _load_vt_from_lds(vt_base_i32, col_a1) + vtb0_lo, vtb0_hi = _load_vt_from_lds(vt_base_i32, col_b0) + vtb1_lo, vtb1_hi = _load_vt_from_lds(vt_base_i32, col_b1) + + rocdl.sched_barrier(0) + rocdl.s_waitcnt(_encode_waitcnt(lgkmcnt=4)) + + # MFMA pair A + kv_mfma_a0 = _raw( + arith.ArithValue(vta0_lo).extui(T.i64) + | (arith.ArithValue(vta0_hi).extui(T.i64) << c32_i64_pv) + ) + oaccu_in[iter_a * 2] = _mfma_fp8( + T.f32x4, [kv_mfma_a0, p_pack, oaccu_in[iter_a * 2], 0, 0, 0] + ) + + kv_mfma_a1 = _raw( + arith.ArithValue(vta1_lo).extui(T.i64) + | (arith.ArithValue(vta1_hi).extui(T.i64) << c32_i64_pv) + ) + oaccu_in[iter_a * 2 + 1] = _mfma_fp8( + T.f32x4, [kv_mfma_a1, p_pack, oaccu_in[iter_a * 2 + 1], 0, 0, 0] + ) + rocdl.sched_barrier(0) + rocdl.s_waitcnt(_encode_waitcnt(lgkmcnt=0)) + + # MFMA pair B + kv_mfma_b0 = _raw( + arith.ArithValue(vtb0_lo).extui(T.i64) + | (arith.ArithValue(vtb0_hi).extui(T.i64) << c32_i64_pv) + ) + oaccu_in[iter_b * 2] = _mfma_fp8( + T.f32x4, [kv_mfma_b0, p_pack, oaccu_in[iter_b * 2], 0, 0, 0] + ) + + kv_mfma_b1 = _raw( + arith.ArithValue(vtb1_lo).extui(T.i64) + | (arith.ArithValue(vtb1_hi).extui(T.i64) << c32_i64_pv) + ) + oaccu_in[iter_b * 2 + 1] = _mfma_fp8( + T.f32x4, [kv_mfma_b1, p_pack, oaccu_in[iter_b * 2 + 1], 0, 0, 0] + ) + rocdl.sched_barrier(0) + + # Normalize by reci_sum + oaccu_in[iter_a * 2] = arith.MulFOp( + oaccu_in[iter_a * 2], _raw(reci_vec), fastmath=fm_fast + ).result + oaccu_in[iter_a * 2 + 1] = arith.MulFOp( + oaccu_in[iter_a * 2 + 1], _raw(reci_vec), fastmath=fm_fast + ).result + oaccu_in[iter_b * 2] = arith.MulFOp( + oaccu_in[iter_b * 2], _raw(reci_vec), fastmath=fm_fast + ).result + oaccu_in[iter_b * 2 + 1] = arith.MulFOp( + oaccu_in[iter_b * 2 + 1], _raw(reci_vec), fastmath=fm_fast + ).result + + # Store immediately via LDS reshape (coalesced) + if is_split: + _store_oaccu_pair_split( + oaccu_in[iter_a * 2], + oaccu_in[iter_a * 2 + 1], + iter_a, + p_lds_o_i32, + row_base_i32, + ) + _store_oaccu_pair_split( + oaccu_in[iter_b * 2], + oaccu_in[iter_b * 2 + 1], + iter_b, + p_lds_o_i32, + row_base_i32, + ) + else: + _store_oaccu_pair_bf16( + oaccu_in[iter_a * 2], + oaccu_in[iter_a * 2 + 1], + iter_a, + p_lds_o_i32, + row_base_i32, + ) + _store_oaccu_pair_bf16( + oaccu_in[iter_b * 2], + oaccu_in[iter_b * 2 + 1], + iter_b, + p_lds_o_i32, + row_base_i32, + ) + + rocdl.s_setprio(0) + + # 
================================================================== + # KV LDS buffer pointers -- computed once, persist across work items + # ================================================================== + p_lds_kv_0_base = lds_base_idx + arith.index(P_LDS_KV_0) + p_lds_kv_1_base = lds_base_idx + arith.index(P_LDS_KV_1) + + kv_warp_offset_i32 = _raw( + warp_idx_i32 * arith.constant(KV_SUB_BYTES, type=T.i32) + ) + + p_lds_kv_0_warp_i32 = _raw( + arith.ArithValue(_index_cast_to_i32(p_lds_kv_0_base)) + + arith.ArithValue(kv_warp_offset_i32) + ) + p_lds_kv_1_warp_i32 = _raw( + arith.ArithValue(_index_cast_to_i32(p_lds_kv_1_base)) + + arith.ArithValue(kv_warp_offset_i32) + ) + + # Vt base pointer (invariant across all tiles -- only depends on + # lds_base_idx + P_LDS_VT + lane offset). Computed once here so + # _gemm2_with_rescale can use it outside branches. + vt_base_i32 = _raw( + arith.ArithValue(_index_cast_to_i32(lds_base_idx + arith.index(P_LDS_VT))) + + arith.ArithValue(_index_cast_to_i32(_vt_lds_lane_offset)) + ) + + # ================================================================== + # Main kernel body: persistent-thread work loop + # ================================================================== + for work_idx in range(work_start_idx, work_end_idx): + # Load MlaWorkInfo + wi_base_i32 = _index_cast_to_i32(work_idx * SIZE_MLA_WORK_INFO_IN_DW) + wi_dw1_4 = buffer_ops.buffer_load( + work_info_set_rsrc, + arith.addi(wi_base_i32, arith.constant(1, type=T.i32)), + vec_width=4, + dtype=T.i32, + ) + wi_dw5 = buffer_ops.buffer_load( + work_info_set_rsrc, + arith.addi(wi_base_i32, arith.constant(5, type=T.i32)), + vec_width=1, + dtype=T.i32, + ) + partial_qo_loc = rocdl.readfirstlane(T.i32, _raw(vector.extract(wi_dw1_4, [0]))) + qo_start = rocdl.readfirstlane(T.i32, _raw(vector.extract(wi_dw1_4, [1]))) + qo_end = rocdl.readfirstlane(T.i32, _raw(vector.extract(wi_dw1_4, [2]))) + kv_start = rocdl.readfirstlane(T.i32, _raw(vector.extract(wi_dw1_4, [3]))) + kv_end = rocdl.readfirstlane(T.i32, _raw(wi_dw5)) + kv_len = arith.subi(kv_end, kv_start) + + # ---- KV tile iteration ---- + # Initialize softmax state + row_max = _raw(c_neg_inf_f32) + row_sum_e = _raw(c_zero_f32) + oaccu = [_raw(c_zero_v4f32)] * (NUM_PV_ITERS * 2) + + # Compute number of tiles + c_block_n = arith.constant(BLOCK_N, type=T.i32) + c_block_n_m1 = arith.constant(BLOCK_N - 1, type=T.i32) + num_tiles = arith.divui(arith.addi(kv_len, c_block_n_m1), c_block_n) + num_tiles_idx = arith.index_cast(T.index, num_tiles) + + # --- Pre-compute boundary flags --- + first_tile_needs_boundary = arith.cmpi( + CmpIPredicate.slt, _raw(kv_len), _raw(c_block_n), + ) + has_multi_tiles = arith.cmpi( + CmpIPredicate.sgt, _raw(kv_len), _raw(c_block_n), + ) + # last_tile_partial: (kv_len & (BLOCK_N-1)) != 0 + last_tile_partial = arith.cmpi( + CmpIPredicate.ne, + _raw(arith.ArithValue(_raw(kv_len)) & c_block_n_m1), + _raw(arith.constant(0, type=T.i32)), + ) + + # --- First tile: resolve KV row (branched on boundary) --- + if_setup = scf.IfOp(first_tile_needs_boundary, [T.i32], has_else=True) + with ir.InsertionPoint(if_setup.regions[0].blocks[0]): + row_cb = _get_kv_ld_row(kv_start, kv_end, True) + scf.YieldOp([row_cb]) + with ir.InsertionPoint(if_setup.regions[1].blocks[0]): + kv_first_end = _raw(kv_start + c_block_n) + row_nc = _get_kv_ld_row(kv_start, kv_first_end, False) + scf.YieldOp([row_nc]) + row_kv_ld_first = if_setup.results[0] + + # Load Q to GPR (independent of boundary check) + q_nope_packs, q_rope_packs = _load_q_to_regs(qo_start) + 
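+ # Q (nope + rope packs) now lives in registers for the rest of this
+ # work item; KV tiles stream through the two LDS buffers computed above.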
+ # Async load first tile KV to LDS (branched) + if_load = scf.IfOp(first_tile_needs_boundary, [], has_else=True) + with ir.InsertionPoint(if_load.regions[0].blocks[0]): + _async_load_kv_all( + p_lds_kv_0_warp_i32, + row_kv_ld_first, + kv_ld_col_base_i32, + check_boundary=True, + ) + scf.YieldOp([]) + with ir.InsertionPoint(if_load.regions[1].blocks[0]): + _async_load_kv_all( + p_lds_kv_0_warp_i32, + row_kv_ld_first, + kv_ld_col_base_i32, + check_boundary=False, + ) + scf.YieldOp([]) + + # --- Tile-1 row resolution (only meaningful for multi-tile) --- + # tile1_is_full: kv_start + 2*BN <= kv_end (equiv to num_tiles >= 3) + c_2bn = arith.constant(2 * BLOCK_N, type=T.i32) + kv_start_plus_bn = _raw(kv_start + c_block_n) + kv_start_plus_2bn = _raw(kv_start + c_2bn) + tile1_is_full = arith.cmpi( + CmpIPredicate.sle, kv_start_plus_2bn, _raw(kv_end), + ) + if_tile1_row = scf.IfOp(tile1_is_full, [T.i32], has_else=True) + with ir.InsertionPoint(if_tile1_row.regions[0].blocks[0]): + # Tile 1 is full -> no boundary check, tile_end = start+2*BN + row_t1_full = _get_kv_ld_row(kv_start_plus_bn, kv_start_plus_2bn, False) + scf.YieldOp([row_t1_full]) + with ir.InsertionPoint(if_tile1_row.regions[1].blocks[0]): + # Tile 1 may be partial -> boundary check, tile_end = kv_end + row_t1_partial = _get_kv_ld_row(kv_start_plus_bn, _raw(kv_end), True) + scf.YieldOp([row_t1_partial]) + row_kv_ld_tile1 = if_tile1_row.results[0] + + # check_boundary_next for first tile: True only when + # num_tiles==2 AND last_tile_partial (next tile is partial last) + # Equiv: !tile1_is_full AND last_tile_partial + # But simpler: cbn = !tile1_is_full (when num_tiles>=2, !tile1_is_full + # means num_tiles==2, and if num_tiles==2 and tile1 not full then + # last_tile_partial must be true). Actually just use: !tile1_is_full AND has_multi_tiles AND last_tile_partial. + # Simplest correct: HK uses (kv_1st_end + BN - 1) < kv_end -> !(kv_start+2*BN <= kv_end) -> !tile1_is_full + # Wait: HK condition for cbn=False is (kv_1st_end + BN - 1) < kv_end i.e. kv_start+2*BN-1 < kv_end + # i.e. kv_start+2*BN <= kv_end i.e. tile1_is_full. So cbn=False when tile1_is_full. + # cbn=True when !tile1_is_full. This is correct regardless of last_tile_partial because + # when num_tiles==2 and !tile1_is_full, the next tile IS the last and IS partial. + # !tile1_is_full: kv_start + 2*BN > kv_end (num_tiles == 2, next tile partial) + first_tile_cbn = arith.cmpi( + CmpIPredicate.sgt, kv_start_plus_2bn, _raw(kv_end), + ) + + # --- Process first tile --- + # 5 values through phi: rm, rse, p_pack, rescale, row_kv_ld_nn + first_result_types = [T.f32, T.f32, T.i64, T.f32, T.i32] + + # 2-ahead resolve params for first tile (tile 0 -> resolve tile 2) + # nn_start = kv_start + 2*BN, nn_end = kv_end (always boundary check) + # do_resolve: tile 2 exists iff kv_start + 2*BN < kv_end + do_resolve_nn_first = arith.cmpi( + CmpIPredicate.slt, kv_start_plus_2bn, _raw(kv_end), + ) + + # Branch on has_multi_tiles: multi-tile gets prefetch, single doesn't + if_first = scf.IfOp(has_multi_tiles, first_result_types, has_else=True) + with ir.InsertionPoint(if_first.regions[0].blocks[0]): + # Multi-tile: first tile is always full, prefetch tile 1. + # Sub-branch on first_tile_cbn for compile-time check_boundary_next. 
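+ # Each region passes a Python-level check_boundary_next constant, so the
+ # next-tile prefetch path is specialized per branch even though
+ # first_tile_cbn itself is only known at runtime.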
+ if_first_cbn = scf.IfOp(first_tile_cbn, first_result_types, has_else=True) + with ir.InsertionPoint(if_first_cbn.regions[0].blocks[0]): + # cbn=True: next tile needs boundary check (num_tiles==2, partial) + _barrier(vmcnt=0, lgkmcnt=0) + rocdl.sched_barrier(0) + rm1a, rse1a, pp1a, rs1a, nn1a = _process_tile_gemm1( + p_lds_kv_0_base, + kv_start, + kv_end, + q_nope_packs, + q_rope_packs, + row_max, + row_sum_e, + is_first_iter=True, + check_boundary=False, + p_lds_kv_next_warp_i32=p_lds_kv_1_warp_i32, + row_kv_ld_next=row_kv_ld_tile1, + kv_ld_col_base_i32_arg=kv_ld_col_base_i32, + check_boundary_next=True, + nn_resolve_start=kv_start_plus_2bn, + nn_resolve_end=_raw(kv_end), + do_resolve_nn=do_resolve_nn_first, + ) + y1a = [ + _raw(v) if not isinstance(v, ir.Value) else v + for v in [rm1a, rse1a, pp1a, rs1a, nn1a] + ] + scf.YieldOp(y1a) + with ir.InsertionPoint(if_first_cbn.regions[1].blocks[0]): + # cbn=False: next tile is full, no boundary check + _barrier(vmcnt=0, lgkmcnt=0) + rocdl.sched_barrier(0) + rm1b, rse1b, pp1b, rs1b, nn1b = _process_tile_gemm1( + p_lds_kv_0_base, + kv_start, + kv_end, + q_nope_packs, + q_rope_packs, + row_max, + row_sum_e, + is_first_iter=True, + check_boundary=False, + p_lds_kv_next_warp_i32=p_lds_kv_1_warp_i32, + row_kv_ld_next=row_kv_ld_tile1, + kv_ld_col_base_i32_arg=kv_ld_col_base_i32, + check_boundary_next=False, + nn_resolve_start=kv_start_plus_2bn, + nn_resolve_end=_raw(kv_end), + do_resolve_nn=do_resolve_nn_first, + ) + y1b = [ + _raw(v) if not isinstance(v, ir.Value) else v + for v in [rm1b, rse1b, pp1b, rs1b, nn1b] + ] + scf.YieldOp(y1b) + y1 = list(if_first_cbn.results) + scf.YieldOp(y1) + with ir.InsertionPoint(if_first.regions[1].blocks[0]): + # Single tile: no prefetch, no 2-ahead resolve + _barrier(vmcnt=0, lgkmcnt=0) + rocdl.sched_barrier(0) + rm2, rse2, pp2, rs2, nn2 = _process_tile_gemm1( + p_lds_kv_0_base, + kv_start, + kv_end, + q_nope_packs, + q_rope_packs, + row_max, + row_sum_e, + is_first_iter=True, + check_boundary=first_tile_needs_boundary, + ) + y2 = [ + _raw(v) if not isinstance(v, ir.Value) else v + for v in [rm2, rse2, pp2, rs2, nn2] + ] + scf.YieldOp(y2) + + row_max = if_first.results[0] + row_sum_e = if_first.results[1] + p_pack_first = if_first.results[2] + rescale_first = if_first.results[3] + row_kv_ld_nn_first = if_first.results[4] + + def _write_lse(pqo_loc_i32, rm, rse): + """Write LSE for split output (first 16 lanes per warp).""" + if arith.cmpi( + CmpIPredicate.ult, lane_idx_i32, arith.constant(16, type=T.i32) + ): + log2_sum = _math.log2(_raw(rse)) + ln_sum = arith.MulFOp( + log2_sum, _raw(c_inv_log2e), fastmath=fm_fast + ).result + lse = arith.AddFOp(_raw(rm), ln_sum, fastmath=fm_fast).result + row_idx_i32 = _raw( + lane_idx_i32 + + warp_idx_i32 * arith.constant(16, type=T.i32) + + arith.ArithValue(pqo_loc_i32) * arith.constant(NUM_QO_HEADS, type=T.i32) + ) + buffer_ops.buffer_store(lse, split_lse_rsrc, row_idx_i32) + + # LDS base for output reshape (reuse KV buffer 0 region) + p_lds_o_i32 = _index_cast_to_i32(p_lds_kv_0_base) + + def _do_last_gemm2_and_store( + pp, + rs, + oaccu_list, + rm, + rse, + is_first_iter_flag, + ): + """GEMM2 last tile with interleaved store + LSE write. + + Branches on partial_qo_loc to select bf16 vs f32 split output. 
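+
+ partial_qo_loc < 0  -> final bf16 output (final_output_rsrc)
+ partial_qo_loc >= 0 -> f32 split output plus LSE write at that partial
+                        location (split_output_rsrc / split_lse_rsrc)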
+ """ + reci = arith.DivFOp(_raw(c_one_f32), rse, fastmath=fm_fast).result + is_not_split = arith.cmpi( + CmpIPredicate.slt, + _raw(partial_qo_loc), + _raw(arith.constant(0, type=T.i32)), + ) + if_out = scf.IfOp(is_not_split, [], has_else=True) + with ir.InsertionPoint(if_out.regions[0].blocks[0]): + # bf16 final output: row_base = qo_start * NUM_QO_HEADS + warp*16 + rb_bf16 = _raw( + arith.ArithValue(_raw(qo_start)) * arith.constant(NUM_QO_HEADS, type=T.i32) + + warp_idx_i32 * arith.constant(16, type=T.i32) + ) + _gemm2_last_with_store( + pp, + rs, + list(oaccu_list), + vt_base_i32, + reci, + False, + p_lds_o_i32, + rb_bf16, + is_first_iter_flag, + ) + scf.YieldOp([]) + with ir.InsertionPoint(if_out.regions[1].blocks[0]): + # f32 split output: row_base = pqo_loc * NUM_QO_HEADS + warp*16 + rb_split = _raw( + arith.ArithValue(_raw(partial_qo_loc)) * arith.constant(NUM_QO_HEADS, type=T.i32) + + warp_idx_i32 * arith.constant(16, type=T.i32) + ) + _gemm2_last_with_store( + pp, + rs, + list(oaccu_list), + vt_base_i32, + reci, + True, + p_lds_o_i32, + rb_split, + is_first_iter_flag, + ) + _write_lse(_raw(partial_qo_loc), rm, rse) + scf.YieldOp([]) + + # ---- Multi-tile vs single-tile dispatch ---- + if_multi = scf.IfOp(has_multi_tiles, [], has_else=True) + with ir.InsertionPoint(if_multi.regions[0].blocks[0]): + # === Multi-tile path === + + # GEMM2 for first tile: C=0 hardcoded, no rescale needed + oaccu_mt = _gemm2_first_iter(p_pack_first, vt_base_i32) + + # --- Middle tiles [1, num_tiles-1) via scf.ForOp --- + c_one_idx = arith.index(1) + num_tiles_m1 = arith.subi(num_tiles, arith.constant(1, type=T.i32)) + num_tiles_m1_idx = arith.index_cast(T.index, num_tiles_m1) + num_tiles_m2 = arith.subi(num_tiles, arith.constant(2, type=T.i32)) + + init_args = [row_max, row_sum_e] + oaccu_mt + [row_kv_ld_nn_first] + init_args = [ + _raw(v) if not isinstance(v, ir.Value) else v for v in init_args + ] + + for_op = scf.ForOp( + _raw(c_one_idx), + _raw(num_tiles_m1_idx), + _raw(c_one_idx), + init_args, + ) + with ir.InsertionPoint(for_op.body): + tile_iv = for_op.induction_variable # index type + tile_iv_i32 = _index_cast_to_i32(tile_iv) + kv_tile_start_i32 = _raw( + kv_start + arith.ArithValue(tile_iv_i32) * c_block_n + ) + + # Unpack carried state + rm_carried = for_op.inner_iter_args[0] + rse_carried = for_op.inner_iter_args[1] + oaccu_carried = [ + for_op.inner_iter_args[2 + i] for i in range(NUM_PV_ITERS * 2) + ] + # 2-ahead: row resolved by previous iteration's _process_tile_gemm1 + row_kv_ld_next = for_op.inner_iter_args[2 + NUM_PV_ITERS * 2] + + # Buffer parity + tile_parity = _raw( + arith.ArithValue(tile_iv_i32) & arith.constant(1, type=T.i32) + ) + is_odd = arith.cmpi( + CmpIPredicate.ne, tile_parity, _raw(arith.constant(0, type=T.i32)), + ) + curr_base_idx = _raw(arith.select( + is_odd, _raw(p_lds_kv_1_base), _raw(p_lds_kv_0_base), + )) + next_warp = _raw(arith.select( + is_odd, p_lds_kv_0_warp_i32, p_lds_kv_1_warp_i32, + )) + + # check_boundary_next: True when tile_idx == num_tiles-2 AND last_tile_partial + is_second_to_last = arith.cmpi( + CmpIPredicate.eq, tile_iv_i32, _raw(num_tiles_m2), + ) + mid_cbn = _raw( + arith.ArithValue(is_second_to_last) & arith.ArithValue(last_tile_partial) + ) + + # 2-ahead resolve params for this iteration: + nn_start_mid = _raw( + arith.ArithValue(kv_tile_start_i32) + c_2bn + ) + do_resolve_nn_mid = arith.cmpi( + CmpIPredicate.slt, nn_start_mid, _raw(kv_end), + ) + + # Process tile: cb=False always, cbn is compile-time via sub-branch + mid_gemm1_types = [T.f32, 
T.f32, T.i64, T.f32, T.i32] + if_mid_tile = scf.IfOp(mid_cbn, mid_gemm1_types, has_else=True) + with ir.InsertionPoint(if_mid_tile.regions[0].blocks[0]): + # cbn=True: next tile needs boundary check + _barrier(vmcnt=0, lgkmcnt=0) + rocdl.sched_barrier(0) + rm_ma, rse_ma, pp_ma, rs_ma, nn_ma = _process_tile_gemm1( + curr_base_idx, + kv_tile_start_i32, + _raw(kv_end), + q_nope_packs, + q_rope_packs, + rm_carried, + rse_carried, + is_first_iter=False, + check_boundary=False, + p_lds_kv_next_warp_i32=next_warp, + row_kv_ld_next=row_kv_ld_next, + kv_ld_col_base_i32_arg=kv_ld_col_base_i32, + check_boundary_next=True, + nn_resolve_start=nn_start_mid, + nn_resolve_end=_raw(kv_end), + do_resolve_nn=do_resolve_nn_mid, + ) + y_ma = [ + _raw(v) if not isinstance(v, ir.Value) else v + for v in [rm_ma, rse_ma, pp_ma, rs_ma, nn_ma] + ] + scf.YieldOp(y_ma) + with ir.InsertionPoint(if_mid_tile.regions[1].blocks[0]): + # cbn=False: next tile is full, no boundary check + _barrier(vmcnt=0, lgkmcnt=0) + rocdl.sched_barrier(0) + rm_mb, rse_mb, pp_mb, rs_mb, nn_mb = _process_tile_gemm1( + curr_base_idx, + kv_tile_start_i32, + _raw(kv_end), + q_nope_packs, + q_rope_packs, + rm_carried, + rse_carried, + is_first_iter=False, + check_boundary=False, + p_lds_kv_next_warp_i32=next_warp, + row_kv_ld_next=row_kv_ld_next, + kv_ld_col_base_i32_arg=kv_ld_col_base_i32, + check_boundary_next=False, + nn_resolve_start=nn_start_mid, + nn_resolve_end=_raw(kv_end), + do_resolve_nn=do_resolve_nn_mid, + ) + y_mb = [ + _raw(v) if not isinstance(v, ir.Value) else v + for v in [rm_mb, rse_mb, pp_mb, rs_mb, nn_mb] + ] + scf.YieldOp(y_mb) + rm_m = if_mid_tile.results[0] + rse_m = if_mid_tile.results[1] + pp_m = if_mid_tile.results[2] + rs_m = if_mid_tile.results[3] + nn_m = if_mid_tile.results[4] + oa_m = _gemm2_with_rescale(pp_m, rs_m, oaccu_carried, vt_base_i32) + yield_vals = [rm_m, rse_m] + oa_m + [nn_m] + yield_vals = [ + _raw(v) if not isinstance(v, ir.Value) else v for v in yield_vals + ] + scf.YieldOp(yield_vals) + + # Unpack results from middle tiles loop + row_max_mt = for_op.results[0] + row_sum_e_mt = for_op.results[1] + oaccu_mt = [for_op.results[2 + i] for i in range(NUM_PV_ITERS * 2)] + + # --- Last tile: GEMM1 + interleaved GEMM2 store --- + last_tile_iv_i32 = _raw(num_tiles_m1) + kv_last_start = _raw( + kv_start + arith.ArithValue(last_tile_iv_i32) * c_block_n + ) + last_parity = _raw( + arith.ArithValue(last_tile_iv_i32) & arith.constant(1, type=T.i32) + ) + last_is_odd = arith.cmpi( + CmpIPredicate.ne, last_parity, _raw(arith.constant(0, type=T.i32)), + ) + last_curr_base = _raw(arith.select( + last_is_odd, _raw(p_lds_kv_1_base), _raw(p_lds_kv_0_base), + )) + + _barrier(vmcnt=0, lgkmcnt=0) + rocdl.sched_barrier(0) + rm_l, rse_l, pp_l, rs_l, _nn_l = _process_tile_gemm1( + last_curr_base, + kv_last_start, + _raw(kv_end), + q_nope_packs, + q_rope_packs, + row_max_mt, + row_sum_e_mt, + is_first_iter=False, + check_boundary=last_tile_partial, + ) + _do_last_gemm2_and_store( + pp_l, + rs_l, + oaccu_mt, + rm_l, + rse_l, + is_first_iter_flag=False, + ) + scf.YieldOp([]) + + with ir.InsertionPoint(if_multi.regions[1].blocks[0]): + # === Single tile path: GEMM2 with interleaved store === + oaccu_st = [_raw(c_zero_v4f32)] * (NUM_PV_ITERS * 2) + _do_last_gemm2_and_store( + p_pack_first, + rescale_first, + oaccu_st, + row_max, + row_sum_e, + is_first_iter_flag=True, + ) + scf.YieldOp([]) + + +# --------------------------------------------------------------------------- +# JIT launcher +# 
--------------------------------------------------------------------------- +@flyc.jit +def launch_mla_fwd_decode_m16x8_fp8_fp8( + query: fx.Tensor, + kv_buffer: fx.Tensor, + kv_page_indices: fx.Tensor, + work_indptr: fx.Tensor, + work_info_set: fx.Tensor, + final_output: fx.Tensor, + split_output: fx.Tensor, + split_lse: fx.Tensor, + softmax_scale: fx.Float32, + num_cus: fx.Constexpr, + lds_size: fx.Constexpr, + stream: fx.Stream = fx.Stream(None), +): + """JIT host function: configures grid/block and launches the kernel.""" + assert TOTAL_LDS_BYTES <= lds_size, ( + f"Kernel requires {TOTAL_LDS_BYTES} bytes LDS but CU budget is {lds_size}" + ) + kn_mla_fwd_decode_m16x8_fp8_fp8( + query, + kv_buffer, + kv_page_indices, + work_indptr, + work_info_set, + final_output, + split_output, + split_lse, + softmax_scale, + ).launch( + grid=(num_cus, 1, 1), + block=(NUM_THREADS, 1, 1), + smem=0, # LDS is statically allocated via SmemAllocator + stream=stream, + ) diff --git a/tests/kernels/test_mla_decode.py b/tests/kernels/test_mla_decode.py new file mode 100644 index 00000000..dd20fe6e --- /dev/null +++ b/tests/kernels/test_mla_decode.py @@ -0,0 +1,312 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + +""" +Simplified MLA decode test for FlyDSL kernel. + +Tests the FlyDSL MLA decode kernel (fp8 Q, fp8 KV, nhead=128, page_size=1) +using aiter for metadata generation and reduce. + +Usage: + cd /jruan/ws/FlyDSL + python tests/kernels/test_mla_decode.py -b 1 -c 128 + python tests/kernels/test_mla_decode.py -b 32 -c 8192 +""" + +import sys +import os +import argparse +import logging + +import torch +import pytest + +sys.path.insert(0, "build-fly/python_packages") +sys.path.insert(1, ".") +os.environ["FLYDSL_RUNTIME_ENABLE_CACHE"] = "1" +logging.basicConfig(level=logging.INFO, format="%(message)s") + +pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower] + +aiter = pytest.importorskip("aiter", reason="aiter is not installed, skipping MLA tests") +from aiter import dtypes +from aiter.ops.attention import ( + get_mla_metadata_info_v1, + get_mla_metadata_v1, + mla_reduce_v1, +) + +from flydsl.runtime.device import get_rocm_arch + +from tests.test_common import run_perftest, checkAllclose +from kernels.mla_fwd_decode import flydsl_mla_fwd_decode + +torch.set_default_device("cuda") + +logger = logging.getLogger("mla_decode_test") + +_GPU_ARCH = str(get_rocm_arch()) + +# ── Model constants (DeepSeek-V3 / R1) ────────────────────────── +KV_LORA_RANK = 512 +QK_NOPE_HEAD_DIM = 512 +QK_ROPE_HEAD_DIM = 64 +QK_HEAD_DIM = KV_LORA_RANK + QK_ROPE_HEAD_DIM # 576 +V_HEAD_DIM = 512 +NHEAD = 128 +NHEAD_KV = 1 +PAGE_SIZE = 1 + + +# ── Pure-PyTorch reference ────────────────────────────────────── + +def ref_masked_attention(query, key, value, scale, dtype, q_scale=None, kv_scale=None): + """Single-sequence MLA attention (no causal mask needed for decode_qlen=1).""" + s = scale + if q_scale is not None: + s *= q_scale + if kv_scale is not None: + s *= kv_scale + + attn_weights = torch.einsum("qhd,khd->hqk", query.float(), key.float()) * s + lse = attn_weights.logsumexp(dim=-1) + m = attn_weights.max(-1).values + attn_weights_exp = torch.exp(attn_weights - m.unsqueeze(-1)) + l = attn_weights_exp.sum(-1) + out = torch.einsum("hqk,khd->qhd", attn_weights_exp.float(), value.float()) + out = out / l.transpose(0, 1).unsqueeze(-1) + if kv_scale is not None: + out *= kv_scale + return out.to(dtype), lse + + +def torch_mla_extend(q, kvc_cache, qo_indptr, 
kv_indptr, kv_indices, + kv_last_page_lens, sm_scale, dtype): + """Pure-PyTorch paged MLA attention reference.""" + is_fp8_q = q.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz) + is_fp8_kvc = kvc_cache.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz) + + if is_fp8_q: + q = q.to(torch.float) + if is_fp8_kvc: + kvc_cache = kvc_cache.to(torch.float) + + num_page, page_size, nhead_kv, _ = kvc_cache.shape + qs = torch.tensor_split(q, qo_indptr.tolist()[1:]) + kvc = torch.index_select(kvc_cache, 0, kv_indices) + kvs = torch.tensor_split(kvc, kv_indptr.tolist()[1:]) + bs = qo_indptr.shape[0] - 1 + + os_list = [] + lses = [] + for i in range(bs): + cur_num_page = kvs[i].shape[0] + real_kv_seq_len = (cur_num_page - 1) * page_size + kv_last_page_lens.tolist()[i] + kvc_i = kvs[i].flatten(0, 1)[:real_kv_seq_len] + q_i = qs[i] + k_i = kvc_i + v_i = kvc_i[:, :, :KV_LORA_RANK] + o_i, lse_i = ref_masked_attention(q_i, k_i, v_i, sm_scale, dtype) + os_list.append(o_i) + lses.append(lse_i) + + o = torch.concat(os_list) + lse = torch.concat(lses, dim=1).transpose(0, 1) + return o, lse + + +# ── Test driver ───────────────────────────────────────────────── + +def run_single(batch_size, ctx_len, decode_qlen=1, max_split_per_batch=32): + nhead = NHEAD + nhead_kv = NHEAD_KV + page_size = PAGE_SIZE + fp8 = dtypes.fp8 + out_dtype = torch.bfloat16 + + kv_max_sz = 65536 * 32 + num_page = (kv_max_sz + page_size - 1) // page_size + + # ── Sequence metadata ── + seq_lens_kv = torch.full((batch_size,), ctx_len, dtype=torch.int) + kv_block_nums = torch.full((batch_size,), (ctx_len + page_size - 1) // page_size, + dtype=torch.int) + kv_last_page_lens = torch.ones(batch_size, dtype=torch.int) + if ctx_len % page_size != 0: + kv_last_page_lens.fill_(ctx_len % page_size) + + kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int) + kv_indptr[1:] = torch.cumsum(kv_block_nums, dim=0) + num_page = kv_indptr[-1].item() + kv_indices = torch.randperm(num_page, dtype=torch.int) + + seq_lens_qo = torch.full((batch_size,), decode_qlen, dtype=torch.int) + qo_indptr = torch.zeros(batch_size + 1, dtype=torch.int) + qo_indptr[1:] = torch.cumsum(seq_lens_qo, dim=0) + total_q = qo_indptr[-1].item() + max_seqlen_qo = decode_qlen + + # ── KV buffer and Q ── + kv_buffer = torch.randn((num_page, page_size, nhead_kv, QK_HEAD_DIM), + dtype=torch.bfloat16) + kv_buffer_fp8 = kv_buffer.to(fp8) + + q = torch.randn((total_q, nhead, QK_HEAD_DIM), dtype=torch.bfloat16) + q_fp8 = q.to(fp8) + + sm_scale = 1.0 / (QK_HEAD_DIM ** 0.5) + + # ── PyTorch reference (using fp8 data, cast to float internally) ── + out_ref, lse_ref = torch_mla_extend( + q_fp8, kv_buffer_fp8.view(num_page, page_size, nhead_kv, QK_HEAD_DIM), + qo_indptr, kv_indptr, kv_indices, kv_last_page_lens, + sm_scale, out_dtype, + ) + + # ── Limit splits for large nhead ── + gpu = torch.cuda.current_device() + cu_num = torch.cuda.get_device_properties(gpu).multi_processor_count + max_split_per_batch = min( + (cu_num + batch_size - 1) // batch_size, max_split_per_batch + ) + + # ── Metadata via aiter ── + ( + (work_meta_data_size, work_meta_data_type), + (work_indptr_size, work_indptr_type), + (work_info_set_size, work_info_set_type), + (reduce_indptr_size, reduce_indptr_type), + (reduce_final_map_size, reduce_final_map_type), + (reduce_partial_map_size, reduce_partial_map_type), + ) = get_mla_metadata_info_v1( + batch_size, max_seqlen_qo, nhead, fp8, fp8, + is_sparse=False, fast_mode=True, + num_kv_splits=max_split_per_batch, + intra_batch_mode=False, + ) + + work_meta_data = 
torch.empty(work_meta_data_size, dtype=work_meta_data_type, device="cuda") + work_indptr = torch.empty(work_indptr_size, dtype=work_indptr_type, device="cuda") + work_info_set = torch.empty(work_info_set_size, dtype=work_info_set_type, device="cuda") + reduce_indptr = torch.empty(reduce_indptr_size, dtype=reduce_indptr_type, device="cuda") + reduce_final_map = torch.empty(reduce_final_map_size, dtype=reduce_final_map_type, device="cuda") + reduce_partial_map = torch.empty(reduce_partial_map_size, dtype=reduce_partial_map_type, device="cuda") + + get_mla_metadata_v1( + qo_indptr, kv_indptr, kv_last_page_lens, + nhead // nhead_kv, nhead_kv, False, + work_meta_data, work_info_set, work_indptr, + reduce_indptr, reduce_final_map, reduce_partial_map, + kv_granularity=max(page_size, 16), + max_seqlen_qo=int(max_seqlen_qo), + uni_seqlen_qo=decode_qlen, + fast_mode=True, + max_split_per_batch=max_split_per_batch, + intra_batch_mode=False, + dtype_q=fp8, + dtype_kv=fp8, + ) + + # ── Allocate output / partial buffers ── + out_asm = torch.empty((total_q, nhead, V_HEAD_DIM), dtype=out_dtype).fill_(-1) + + logits = torch.empty( + (reduce_partial_map.size(0) * max_seqlen_qo, 1, nhead, V_HEAD_DIM), + dtype=torch.float32, device="cuda", + ) + attn_lse = torch.empty( + (reduce_partial_map.size(0) * max_seqlen_qo, 1, nhead, 1), + dtype=torch.float32, device="cuda", + ) + + # ── Launch FlyDSL kernel ── + def launch_kernel(): + flydsl_mla_fwd_decode( + q_fp8, + kv_buffer_fp8.view(num_page, page_size, nhead_kv, QK_HEAD_DIM), + kv_indices, + work_indptr, + work_info_set, + out_asm, + logits, + attn_lse, + sm_scale, + ) + mla_reduce_v1( + logits, attn_lse, + reduce_indptr, reduce_final_map, reduce_partial_map, + max_seqlen_qo, out_asm, None, + ) + + _, us = run_perftest(launch_kernel, num_iters=10, num_warmup=3) + torch.cuda.synchronize() + + # ── Verify ── + total_kv = seq_lens_kv.sum().item() + err = checkAllclose( + out_ref, out_asm, + msg=f"[b={batch_size} c={ctx_len}] golden vs flydsl: {us:>8.2f} us ... ", + ) + + # Cosine similarity check + x, y = out_ref.double(), out_asm.double() + cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) + + flops = decode_qlen * total_kv * nhead * (QK_HEAD_DIM + V_HEAD_DIM) * 2 + bw = ( + total_kv * nhead_kv * QK_HEAD_DIM * 1 # fp8 = 1 byte + + total_q * nhead * QK_HEAD_DIM * 1 + + total_q * nhead * V_HEAD_DIM * 2 # bf16 = 2 bytes + ) + + logger.info( + f" cos_diff={cos_diff:.2e} TFLOPS={flops / us / 1e6:.2f} " + f"TB/s={bw / us / 1e6:.2f} err_ratio={err:.2%}" + ) + assert cos_diff < 3e-2, f"cos_diff={cos_diff} exceeds threshold" + return err, us + + +# ── pytest ────────────────────────────────────────────────────── + +# On gfx950, AITER folds nh=128 + fp8/fp8 to a nh=16 work-info layout +# instead of generating the native nh=128 layout. This FlyDSL kernel only +# decodes the native nh=128 layout, so it cannot run against AITER's +# gfx950 metadata. +@pytest.mark.skipif( + _GPU_ARCH == "gfx950", + reason=( + "AITER metadata for nh=128 + fp8/fp8 on gfx950 uses the folded " + "nh=16 layout, which this FlyDSL MLA kernel does not support." 
+ ), +) +@pytest.mark.parametrize("batch_size,ctx_len", [ + (1, 128), + (4, 2048), + (32, 8192), +]) +def test_mla_decode(batch_size, ctx_len): + run_single(batch_size, ctx_len) + + +# ── CLI (local benchmarking) ──────────────────────────────────── + +def main(): + parser = argparse.ArgumentParser(description="FlyDSL MLA decode test") + parser.add_argument("-b", "--batch", type=int, nargs="*", default=[1, 32]) + parser.add_argument("-c", "--ctx_len", type=int, nargs="*", default=[128, 8192]) + parser.add_argument("-ms", "--max_splits", type=int, default=32) + args = parser.parse_args() + + for b in args.batch: + for c in args.ctx_len: + logger.info(f"\n{'='*60}") + logger.info(f"batch={b} ctx_len={c}") + logger.info(f"{'='*60}") + run_single(b, c, max_split_per_batch=args.max_splits) + + logger.info("\nAll tests passed.") + + +if __name__ == "__main__": + main() From 51917465b94cd474ce2780bf94a886a975ca2ac5 Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Thu, 23 Apr 2026 22:46:52 +0800 Subject: [PATCH 28/29] [Perf] Port mixed_moe kernel optimizations for stage1/stage2 (#388) * [Perf] Port aiter mixed_moe kernel optimizations for stage1/stage2 Port performance-critical optimizations from aiter's mixed_moe_gemm_2stage kernel body (both stage1 and stage2) into FlyDSL, along with supporting infrastructure changes. Key changes: - mixed_moe_gemm_2stage.py: Full kernel body replacement with aiter version featuring dual SmemAllocator (ping-pong), unified MFMA pipeline schedule, _barrier() for fine-grained waitcnt control, and new parameters (persist_m, fuse_fp4_quant, fuse_sort_scale, use_async_copy, sort_block_m, etc.) - layout_utils.py: New file ported from aiter for layout index arithmetic (crd2idx, idx2crd, _div_pow2, _mod_pow2) - silu_and_mul_fq.py: New file ported from aiter for split-K + fp4 quant after silu fusion - mfma_preshuffle_pipeline.py: Added k_major support, cache_modifier param, bitwise-AND optimization in swizzle_xor16, PreshuffleScaleLayout additions - kernels_common.py: Extracted shared _if_then context manager and validate_moe_dtypes helper - mfma_epilogues.py: Replaced local _if_then with shared import Performance (DeepSeek TP8 FP4, 7168x256, E=256, K=8): - Stage1 Decode t=1: 37.3 -> 26.2 us (-29.8%) - Stage1 Decode t=8: 45.0 -> 31.0 us (-31.1%) - Stage1 Prefill 8K: 561.8 -> 348.8 us (-37.9%) - Stage2 Prefill 8K reduce: 569.1 -> 534.8 us (-6.0%) - FP8 stage2 unchanged (within noise) * fix ci * merge a8w4 moe * add a8w4 bench --- kernels/kernels_common.py | 37 +- kernels/layout_utils.py | 181 ++ kernels/mfma_epilogues.py | 202 +- kernels/mfma_preshuffle_pipeline.py | 227 +- kernels/mixed_moe_gemm_2stage.py | 4705 ++++++++++++++++++--------- kernels/silu_and_mul_fq.py | 419 +++ scripts/run_benchmark.sh | 145 +- tests/kernels/test_moe_gemm.py | 301 +- tests/kernels/test_ref.py | 138 +- 9 files changed, 4582 insertions(+), 1773 deletions(-) create mode 100644 kernels/layout_utils.py create mode 100644 kernels/silu_and_mul_fq.py diff --git a/kernels/kernels_common.py b/kernels/kernels_common.py index 3af725a6..3d73bdae 100644 --- a/kernels/kernels_common.py +++ b/kernels/kernels_common.py @@ -7,13 +7,48 @@ but this module is intentionally small and MLIR-dialect facing. 
""" +from contextlib import contextmanager + from flydsl._mlir import ir from flydsl.expr.typing import T -from flydsl._mlir.dialects import arith as _std_arith, builtin, gpu as _gpu, llvm as _llvm +from flydsl._mlir.dialects import arith as _std_arith, builtin, gpu as _gpu, llvm as _llvm, scf as _scf from flydsl.expr import buffer_ops from flydsl.runtime.device import get_rocm_arch, is_rdna_arch +@contextmanager +def _if_then(if_op, scf=None): + """Context manager for SCF IfOp then-region across old/new Python APIs. + + Ensures the then block always ends with a YieldOp. + The optional *scf* parameter is accepted for backward compatibility + but ignored — the module-level import is used. + """ + with ir.InsertionPoint(if_op.then_block): + try: + yield if_op.then_block + finally: + blk = if_op.then_block + if (not blk.operations) or not isinstance(blk.operations[-1], _scf.YieldOp): + _scf.YieldOp([]) + + +_VALID_A_DTYPES = frozenset(("fp8", "fp16", "int8", "fp4")) +_VALID_B_DTYPES = frozenset(("fp8", "fp16", "int8", "int4", "fp4")) + + +def validate_moe_dtypes(a_dtype: str, b_dtype: str) -> None: + """Validate a_dtype/b_dtype strings for mixed MoE kernels.""" + if a_dtype not in _VALID_A_DTYPES: + raise ValueError( + f"a_dtype must be one of {tuple(sorted(_VALID_A_DTYPES))}, got {a_dtype!r}" + ) + if b_dtype not in _VALID_B_DTYPES: + raise ValueError( + f"b_dtype must be one of {tuple(sorted(_VALID_B_DTYPES))}, got {b_dtype!r}" + ) + + def dtype_to_elem_type(dtype_str: str): """Map a dtype string to its MLIR scalar type. diff --git a/kernels/layout_utils.py b/kernels/layout_utils.py new file mode 100644 index 00000000..350c9c48 --- /dev/null +++ b/kernels/layout_utils.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Layout helpers for GEMM kernels. + +Parses fly layout type strings (e.g. '(4,64):(64,1)') and computes +idx2crd / crd2idx with plain arith ops for static layouts. +Falls back to fly dialect ops for dynamic layouts. + +Optimisation: power-of-2 strides/shapes emit ``shrui`` / ``andi`` instead of +``divui`` / ``remui``, avoiding 10-15-cycle V_DIV sequences on CDNA GPUs. +""" + +import math as _math +import re +import builtins as _builtins + +import flydsl.expr as fx +from flydsl._mlir import ir +from flydsl.expr import arith +from flydsl.expr.arith import ArithValue +from flydsl.expr.typing import T + + +def _wrap(v): + """Wrap raw ir.Value in ArithValue for operator overloading compatibility.""" + if isinstance(v, ArithValue): + return v + if isinstance(v, ir.Value): + return ArithValue(v) + return v + + +def _is_pow2(n): + """Return True when *n* is a positive power of two.""" + return n > 0 and (n & (n - 1)) == 0 + + +def _div_pow2(val, divisor): + """Unsigned divide index *val* by a **compile-time** power-of-2 *divisor*. + + Emits ``arith.shrui`` (1 VALU cycle) instead of ``arith.divui`` + (10-15 VALU cycles on CDNA). + """ + shift = _math.log2(divisor) + assert shift == int(shift), f"{divisor} is not a power of 2" + return arith.shrui(val, arith.index(int(shift))) + + +def _mod_pow2(val, modulus): + """Unsigned remainder of index *val* by a **compile-time** power-of-2 *modulus*. + + Emits ``arith.andi`` (1 VALU cycle) instead of ``arith.remui``. + """ + return arith.andi(val, arith.index(modulus - 1)) + + +def _parse_dim(tok): + """Parse a single dimension token: '?' -> None, otherwise int.""" + tok = tok.strip() + return None if tok == "?" 
else int(tok) + + +def _parse_layout(ly): + """Parse '(s0,s1,...):(d0,d1,...)' -> (shapes, strides) as lists (None for '?').""" + ly_str = str(ly.type) if hasattr(ly, "type") else str(ly) + m = re.search(r"\(([^)]+)\):\(([^)]+)\)", ly_str) + if not m: + return None + shapes = [_parse_dim(s) for s in m.group(1).split(",")] + strides = [_parse_dim(s) for s in m.group(2).split(",")] + return shapes, strides + + +def _has_dynamic_strides(strides): + """Check if any stride is dynamic (None).""" + return any(s is None for s in strides) + + +def idx2crd(idx, layout): + """Decompose flat index into a list of coordinate values. + + For static layouts, computes coordinates with plain arith ops. + Power-of-2 strides/shapes use shift/mask instead of div/rem. + For dynamic layouts, falls back to fx.idx2crd + fx.get. + """ + parsed = _parse_layout(layout) + + if parsed is None or _has_dynamic_strides(parsed[1]): + result = fx.idx2crd(idx, layout) + ndims = len(parsed[1]) if parsed else 1 + return [_wrap(fx.get(result, i)) for i in range(ndims)] + + if hasattr(idx, "type") and str(idx.type) != "index": + idx = arith.index_cast(T.index, idx) + shapes, strides = parsed + ndims = len(strides) + + ordered = sorted( + [ + (i, s, sz) + for i, s, sz in _builtins.zip(range(ndims), strides, shapes) + if s != 0 + ], + key=lambda x: x[1], + reverse=True, + ) + coords = [None] * ndims + remaining = idx + for i, stride_val, size_val in ordered: + if stride_val == 1: + c = remaining + elif _is_pow2(stride_val): + c = _div_pow2(remaining, stride_val) + else: + c = remaining / arith.index(stride_val) + if size_val is not None: + if _is_pow2(size_val): + c = _mod_pow2(c, size_val) + else: + c = c % arith.index(size_val) + coords[i] = c + for i in range(ndims): + if coords[i] is None: + coords[i] = remaining + return coords + + +def crd2idx(crd, layout): + """Compute flat index from a coordinate tuple/list. + + For static layouts, computes with plain arith ops. + For dynamic layouts, falls back to fx.crd2idx with fx.make_coord. 
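+
+    Example (illustrative): for the static layout '(4,64):(64,1)',
+    crd2idx((i, j)) lowers to plain arith as ``i * 64 + j``, and the
+    inverse idx2crd(idx) takes the shift/mask fast path above and lowers
+    to ``((idx >> 6) & 3, idx & 63)``.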
+ """ + if not isinstance(crd, (list, tuple)): + crd = [crd] + parsed = _parse_layout(layout) + + if parsed is None or _has_dynamic_strides(parsed[1]): + crd_i32 = [] + for c in crd: + cv = c + if isinstance(cv, int): + cv = arith.constant(cv, T.i32) + crd_i32.append(cv) + continue + if isinstance(cv, ArithValue): + raw = cv.ir_value() if hasattr(cv, "ir_value") else cv + if isinstance(raw, ir.Value) and isinstance(raw.type, ir.IndexType): + cv = arith.index_cast(T.i32, raw) + else: + cv = raw + elif isinstance(cv, ir.Value) and isinstance(cv.type, ir.IndexType): + cv = arith.index_cast(T.i32, cv) + elif hasattr(cv, "ir_value"): + raw = cv.ir_value() + if isinstance(raw, ir.Value) and isinstance(raw.type, ir.IndexType): + cv = arith.index_cast(T.i32, raw) + else: + cv = raw + crd_i32.append(cv) + coord_val = fx.make_coord(*crd_i32) + result = fx.crd2idx(coord_val, layout) + scalar = fx.get_scalar(result) + if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType): + scalar = arith.index_cast(T.index, scalar) + return _wrap(scalar) + + _, strides = parsed + result = None + for coord_v, stride_v in _builtins.zip(crd, strides): + if stride_v == 0: + continue + term = coord_v if stride_v == 1 else coord_v * arith.index(stride_v) + result = term if result is None else result + term + return result if result is not None else arith.index(0) + + +def get(int_tuple, mode): + """Extract element at `mode` from a Python list/tuple.""" + return int_tuple[mode] diff --git a/kernels/mfma_epilogues.py b/kernels/mfma_epilogues.py index 89174fd0..cee72352 100644 --- a/kernels/mfma_epilogues.py +++ b/kernels/mfma_epilogues.py @@ -23,30 +23,24 @@ 3) remap threads into (MLane, NLane) = (8,32) and read half2 from LDS, then call `store_pair(...)` to emit the final global store/atomic. + When ``lds_out_split`` is provided, the epilogue runs in split-LDS mode: + waves are partitioned into two groups (group A uses ``lds_out``, group B + uses ``lds_out_split``), each handling half of the N dimension. + These helpers are intentionally *dialect-agnostic*: callers pass the dialect modules (`arith`, `vector`, `gpu`) and the `range_constexpr` iterator. """ from __future__ import annotations -from contextlib import contextmanager from typing import Callable from flydsl._mlir import ir import flydsl.expr as fx +from flydsl._mlir.dialects.arith import CmpIPredicate from flydsl.expr.typing import T - -@contextmanager -def _if_then(if_op, scf): - """Compat helper for SCF IfOp then-region across old/new Python APIs.""" - with ir.InsertionPoint(if_op.then_block): - try: - yield if_op.then_block - finally: - blk = if_op.then_block - if (not blk.operations) or not isinstance(blk.operations[-1], scf.YieldOp): - scf.YieldOp([]) +from kernels.kernels_common import _if_then def default_epilog( @@ -114,6 +108,12 @@ def c_shuffle_epilog( write_row_to_lds: Callable, precompute_row: Callable | None = None, store_pair: Callable, + # When LDS overflows, split lds_out across two buffers by wave-group. + # Pass the second buffer here; first buffer is `lds_out`. + lds_out_split=None, + # Row offset in lds_out for 8-wave mode (MLIR index value). + # Shifts both write and read LDS indices by lds_row_offset * tile_n elements. + lds_row_offset=None, ): """LDS CShuffle epilogue skeleton. 
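+
+    Split-LDS example (illustrative numbers): with block_size=256 (4 waves)
+    and tile_n=256, waves 0-1 stage their half of the C tile through
+    ``lds_out`` (tile columns [0, 128)) while waves 2-3 use
+    ``lds_out_split`` (tile columns [128, 256)); the same barriers
+    synchronise both groups.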
@@ -140,14 +140,173 @@ def c_shuffle_epilog( f"tile_n must be divisible by (CShuffleNLane*EVec) = {cshuffle_nlane*e_vec}, got tile_n={tile_n}" ) + # ===================== Split-LDS mode (early return) ===================== + # When lds_out_split is provided, waves are divided into two groups: + # Group A (waves 0..N/2-1) uses lds_out, columns [0, tile_n/2) + # Group B (waves N/2..N-1) uses lds_out_split, columns [tile_n/2, tile_n) + # Each group writes/reads independently; same barriers synchronise all waves. + if lds_out_split is not None: + if scf is None: + raise ValueError("scf module is required for split-LDS cshuffle") + + _half_n = int(tile_n) // 2 + _half_threads = int(block_size) // 2 + EVec = int(e_vec) + + CShuffleNLane_s = min(int(cshuffle_nlane), _half_n // EVec) + if _half_threads % CShuffleNLane_s != 0: + raise ValueError( + f"half_threads={_half_threads} not divisible by CShuffleNLane_split={CShuffleNLane_s}" + ) + CShuffleMLane_s = _half_threads // CShuffleNLane_s + if int(tile_m) % CShuffleMLane_s != 0: + raise ValueError( + f"tile_m={tile_m} not divisible by CShuffleMLane_split={CShuffleMLane_s}" + ) + m_reps_s = int(tile_m) // CShuffleMLane_s + n_reps_s = _half_n // (CShuffleNLane_s * EVec) + + _half_n_idx = arith.constant(_half_n, index=True) + _half_thr_idx = arith.constant(_half_threads, index=True) + _zero_idx = arith.constant(0, index=True) + + _is_group_b = arith.cmpi(CmpIPredicate.uge, tx, _half_thr_idx) + + # -- write phase (all waves, each to its group's LDS buffer) -- + n_tile_base_v = n_tile_base + col_base_local_a = n_tile_base_v + lane_mod_16 + col_base_local_b = col_base_local_a - _half_n_idx + + def _write_row_split(mi: int, ii: int, row_in_tile, row): + row_base_lds = row_in_tile * _half_n_idx + _if_g = scf.IfOp(_is_group_b) + with ir.InsertionPoint(_if_g.then_block): + write_row_to_lds( + mi=mi, + ii=ii, + row_in_tile=row_in_tile, + row=row, + row_base_lds=row_base_lds, + col_base_local=col_base_local_b, + num_acc_n=num_acc_n, + lds_out=lds_out_split, + ) + scf.YieldOp([]) + with ir.InsertionPoint(_if_g.else_block): + write_row_to_lds( + mi=mi, + ii=ii, + row_in_tile=row_in_tile, + row=row, + row_base_lds=row_base_lds, + col_base_local=col_base_local_a, + num_acc_n=num_acc_n, + lds_out=lds_out, + ) + scf.YieldOp([]) + + gpu.barrier() + default_epilog( + arith=arith, + range_constexpr=range_constexpr, + m_repeat=m_repeat, + lane_div_16=lane_div_16, + bx_m=bx_m, + body_row=_write_row_split, + ) + gpu.barrier() + + # -- read phase (each group reads from its own LDS buffer) -- + tx_local = tx - arith.select(_is_group_b, _half_thr_idx, _zero_idx) + c_nlane_s = arith.constant(CShuffleNLane_s, index=True) + m_lane_s = tx_local / c_nlane_s + n_lane_s = tx_local % c_nlane_s + c_evec = arith.constant(EVec, index=True) + + if frag_elem_type is None: + frag_elem_type = T.f16 + vec_frag = T.vec(EVec, frag_elem_type) + bx_m_v = bx_m + by_n_v = by_n + + _precomputed_rows_s = [] + for mr in range_constexpr(m_reps_s): + row_base_m = arith.constant(mr * CShuffleMLane_s, index=True) + row_local = row_base_m + m_lane_s + row = bx_m_v + row_local + row_ctx_raw = ( + precompute_row(row_local=row_local, row=row) + if precompute_row is not None + else None + ) + row_ctx = row_ctx_raw + row_pred = None + if ( + scf is not None + and row_ctx_raw is not None + and isinstance(row_ctx_raw, tuple) + and len(row_ctx_raw) == 2 + ): + row_ctx, row_pred = row_ctx_raw + _precomputed_rows_s.append((row_local, row, row_ctx, row_pred)) + + for mr in range_constexpr(m_reps_s): + row_local, 
row, row_ctx, row_pred = _precomputed_rows_s[mr] + + def _do_store_row_split(): + row_base_lds = row_local * _half_n_idx + for nr in range_constexpr(n_reps_s): + col_base_nr = arith.constant( + nr * (CShuffleNLane_s * EVec), index=True + ) + col_pair0_local = col_base_nr + (n_lane_s * c_evec) + lds_idx = row_base_lds + col_pair0_local + + _if_ld = scf.IfOp(_is_group_b, [vec_frag]) + with ir.InsertionPoint(_if_ld.then_block): + fb = vector.load_op(vec_frag, lds_out_split, [lds_idx]) + scf.YieldOp([fb]) + with ir.InsertionPoint(_if_ld.else_block): + fa = vector.load_op(vec_frag, lds_out, [lds_idx]) + scf.YieldOp([fa]) + frag = _if_ld.results[0] + + col_pair0 = col_pair0_local + arith.select( + _is_group_b, _half_n_idx, _zero_idx + ) + store_pair( + row_local=row_local, + row=row, + row_ctx=row_ctx, + col_pair0=col_pair0, + col_g0=by_n_v + col_pair0, + frag=frag, + ) + + if row_pred is not None: + _if_row = scf.IfOp(row_pred) + with _if_then(_if_row, scf): + _do_store_row_split() + else: + _do_store_row_split() + + return # split path complete + + # ===================== Standard (non-split) path below ===================== + # ---------------- Step 1: write C tile to LDS (row-major, fp16) ---------------- tile_n_idx = arith.constant(int(tile_n), index=True) n_tile_base_v = n_tile_base col_base_local = n_tile_base_v + lane_mod_16 # index within [0,tile_n) + _lds_row_base_offset = ( + lds_row_offset * tile_n_idx if lds_row_offset is not None else None + ) + def _write_row(mi: int, ii: int, row_in_tile, row): - # row_base_lds = row_in_tile * tile_n row_base_lds = row_in_tile * tile_n_idx + if _lds_row_base_offset is not None: + row_base_lds = row_base_lds + _lds_row_base_offset write_row_to_lds( mi=mi, ii=ii, @@ -192,13 +351,19 @@ def _write_row(mi: int, ii: int, row_in_tile, row): bx_m_v = bx_m by_n_v = by_n + # Batch-precompute all row contexts (sorted_idx loads) before the store loop. + # This issues all buffer_load instructions upfront so the compiler can pipeline + # them instead of serializing each load with s_waitcnt vmcnt(0). + _precomputed_rows = [] for mr in range_constexpr(m_reps_shuffle): row_base_m = arith.constant(mr * CShuffleMLane, index=True) row_local = row_base_m + m_lane row = bx_m_v + row_local row_ctx_raw = ( - precompute_row(row_local=row_local, row=row) if precompute_row is not None else None + precompute_row(row_local=row_local, row=row) + if precompute_row is not None + else None ) # Optional row-level predicate: if `precompute_row` returns `(ctx, pred_i1)` and `scf` @@ -213,8 +378,16 @@ def _write_row(mi: int, ii: int, row_in_tile, row): ): row_ctx, row_pred = row_ctx_raw + _precomputed_rows.append((row_local, row, row_ctx, row_pred)) + + # Now perform LDS reads and stores using the pre-fetched row contexts. 
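+    # (Illustrative: one row-context load, i.e. the sorted_idx fetch, is issued
+    # per mr above, so m_reps_shuffle loads are outstanding before the first
+    # LDS read in the store loop below.)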
+ for mr in range_constexpr(m_reps_shuffle): + row_local, row, row_ctx, row_pred = _precomputed_rows[mr] + def _do_store_row(): row_base_lds = row_local * tile_n_idx + if _lds_row_base_offset is not None: + row_base_lds = row_base_lds + _lds_row_base_offset for nr in range_constexpr(n_reps_shuffle): col_base_nr = arith.constant(nr * (CShuffleNLane * EVec), index=True) col_pair0 = col_base_nr + (n_lane * c_evec) # even col within tile @@ -307,4 +480,3 @@ def mfma_epilog( precompute_row=precompute_row, store_pair=store_pair, ) - diff --git a/kernels/mfma_preshuffle_pipeline.py b/kernels/mfma_preshuffle_pipeline.py index 41e67047..1ec312e5 100644 --- a/kernels/mfma_preshuffle_pipeline.py +++ b/kernels/mfma_preshuffle_pipeline.py @@ -28,9 +28,15 @@ def crd2idx(crd, layout): def swizzle_xor16(row, col, k_blocks16): """XOR-with-row swizzle on the K dimension at 16B granularity. - Computes: col XOR ((row % k_blocks16) * 16) + Computes: col XOR ((row & (k_blocks16 - 1)) * 16) + + k_blocks16 is always a power of 2 (tile_k_bytes / 16), so use + bitwise AND instead of remui to save ~10 VALU cycles on CDNA. """ - rem = row % k_blocks16 + from flydsl.expr import arith as _swz_arith + + mask = k_blocks16 - _swz_arith.index(1) + rem = _swz_arith.andi(row, mask) return col ^ (rem * 16) @@ -45,20 +51,39 @@ def split_row_major_2d(index, minor_extent): return index // minor_extent, index % minor_extent -def _buffer_load_vec(buffer_ops, vector, rsrc, idx, *, elem_type, vec_elems, elem_bytes, offset_in_bytes): +def _buffer_load_vec( + buffer_ops, + vector, + rsrc, + idx, + *, + elem_type, + vec_elems, + elem_bytes, + offset_in_bytes, + cache_modifier=0, +): """Load vec_elems elements via buffer_load dwordx[1,2,4] + bitcast.""" + from flydsl.expr import arith as _ld_arith + elem_size = int(elem_bytes) load_bytes = int(vec_elems) * elem_size vec_width = load_bytes // 4 if offset_in_bytes: - idx_i32 = idx // 4 + idx_i32 = _ld_arith.shrui(idx, _ld_arith.index(2)) elif elem_bytes == 2: - idx_i32 = (idx * 2) // 4 + idx_i32 = _ld_arith.shrui(idx, _ld_arith.index(1)) else: idx_i32 = idx - i32_val = buffer_ops.buffer_load(rsrc, idx_i32, vec_width=vec_width, dtype=T.i32) + i32_val = buffer_ops.buffer_load( + rsrc, + idx_i32, + vec_width=vec_width, + dtype=T.i32, + cache_modifier=cache_modifier, + ) if vec_width == 1: i32_vec = vector.from_elements(T.vec(1, T.i32), [i32_val]) else: @@ -66,59 +91,6 @@ def _buffer_load_vec(buffer_ops, vector, rsrc, idx, *, elem_type, vec_elems, ele return vector.bitcast(T.vec(int(vec_elems), elem_type), i32_vec) -@dataclass(frozen=True) -class PreshuffleBLayout: - """Container returned by `make_preshuffle_b_layout`.""" - - layout_b: object - kpack_bytes: int - - -def make_preshuffle_b_layout( - arith, - *, - c_n: ir.Value, - c_k: ir.Value, - kpack_bytes: int = 16, - elem_bytes: int = 1, -) -> PreshuffleBLayout: - """Build B layout matching aiter/CK preshuffle for A8 MFMA kernels.""" - if kpack_bytes not in (8, 16): - raise ValueError(f"kpack_bytes must be 8 or 16, got {kpack_bytes!r}") - - c16 = fx.Index(16) - c64 = fx.Index(64) - c4 = fx.Index(4) - c_kpack = fx.Index(kpack_bytes) - - if elem_bytes not in (1, 2): - raise ValueError(f"elem_bytes must be 1 or 2, got {elem_bytes!r}") - c_k_bytes = c_k * arith.constant(int(elem_bytes), index=True) - c_k0 = c_k_bytes // c64 - n0 = c_n // c16 - - c_kpack_elems = c_kpack if elem_bytes == 1 else (c_kpack // arith.constant(int(elem_bytes), index=True)) - - stride_nlane = c_kpack_elems - stride_klane = c16 * stride_nlane - stride_k0 = c4 * 
stride_klane - stride_n0 = c_k0 * stride_k0 - - # fly.make_shape requires i32/i64 for dynamic operands (not index). - # Convert dynamic index values to i32; use Python ints for static constants. - kpack_elems_static = kpack_bytes if elem_bytes == 1 else kpack_bytes // elem_bytes - n0_i32 = arith.index_cast(T.i32, n0) - c_k0_i32 = arith.index_cast(T.i32, c_k0) - stride_n0_i32 = arith.index_cast(T.i32, stride_n0) - stride_k0_i32 = arith.index_cast(T.i32, stride_k0) - stride_klane_i32 = arith.index_cast(T.i32, stride_klane) - stride_nlane_i32 = arith.index_cast(T.i32, stride_nlane) - - stride_b = (stride_n0_i32, stride_k0_i32, stride_klane_i32, stride_nlane_i32, 1) - layout_b = fx.make_layout((n0_i32, c_k0_i32, 4, 16, kpack_elems_static), stride_b) - return PreshuffleBLayout(layout_b=layout_b, kpack_bytes=kpack_bytes) - - @dataclass(frozen=True) class PreshuffleScaleLayout: """Container returned by `make_preshuffle_scale_layout`. @@ -129,10 +101,10 @@ class PreshuffleScaleLayout: idx = mni * stride_n0 + ku * stride_k0 + k_lane * stride_klane + n_lane """ - layout_scale: object # fly layout value (same as PreshuffleBLayout.layout_b) - stride_n0: object # index-typed MLIR value (dynamic) - stride_k0: object # index-typed MLIR value (= 64) - stride_klane: object # index-typed MLIR value (= 16) + layout_scale: object + stride_n0: object + stride_k0: object + stride_klane: object def make_preshuffle_scale_layout( @@ -150,14 +122,12 @@ def make_preshuffle_scale_layout( Layout shape: ``(c_mn1, c_k1, 4, 16)`` where ``c_mn1 = c_mn / 16 / mn_pack`` and ``c_k1 = (c_k / scale_block_size) / 4 / k_pack``. """ - c16 = arith.constant(16, index=True) - c4 = arith.constant(4, index=True) - c_mn_pack = arith.constant(mn_pack, index=True) - c_k_pack = arith.constant(k_pack, index=True) - c_k_scale = c_k / scale_block_size - - c_mn1 = c_mn / c16 / c_mn_pack - c_k1 = c_k_scale / c4 / c_k_pack + c16 = fx.Index(16) + c4 = fx.Index(4) + c_k_scale = c_k // fx.Index(scale_block_size) + + c_mn1 = (c_mn // c16) // fx.Index(mn_pack) + c_k1 = (c_k_scale // c4) // fx.Index(k_pack) if elem_bytes != mn_pack * k_pack: raise ValueError( f"elem_bytes of scale must be {mn_pack} * {k_pack}, got {elem_bytes!r}" @@ -167,7 +137,6 @@ def make_preshuffle_scale_layout( stride_k0 = c4 * stride_klane stride_n0 = c_k1 * stride_k0 - # Build fly layout (i32 strides for fx.make_layout). c_mn1_i32 = arith.index_cast(T.i32, c_mn1) c_k1_i32 = arith.index_cast(T.i32, c_k1) stride_n0_i32 = arith.index_cast(T.i32, stride_n0) @@ -187,6 +156,76 @@ def make_preshuffle_scale_layout( ) +@dataclass(frozen=True) +class PreshuffleBLayout: + """Container returned by `make_preshuffle_b_layout`.""" + + layout_b: object + kpack_bytes: int + + +def make_preshuffle_b_layout( + arith, + *, + c_n: ir.Value, + c_k: ir.Value, + kpack_bytes: int = 16, + elem_bytes: int = 1, + k_major: bool = False, +) -> PreshuffleBLayout: + """Build B layout matching aiter/CK preshuffle for A8 MFMA kernels. + + When *k_major* is True the block-level order is K-major (``k_blk`` outermost), + matching the ``(0,3,1,4,2,5)`` shuffle permutation. The default N-major + order (``k_major=False``) matches the legacy ``(0,1,3,4,2,5)`` permutation. 
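+
+    Example (illustrative): with ``c_n=256``, ``c_k=128``, ``elem_bytes=1``
+    and ``kpack_bytes=16``, the default N-major layout has shape
+    (16, 2, 4, 16, 16) and strides (2048, 1024, 256, 16, 1), while
+    ``k_major=True`` yields shape (16, 4, 2, 16, 16) and strides
+    (512, 8192, 256, 16, 1).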
+ """ + if kpack_bytes not in (8, 16): + raise ValueError(f"kpack_bytes must be 8 or 16, got {kpack_bytes!r}") + + c16 = fx.Index(16) + c_kpack = fx.Index(kpack_bytes) + + if elem_bytes not in (1, 2): + raise ValueError(f"elem_bytes must be 1 or 2, got {elem_bytes!r}") + c_k_bytes = c_k * arith.constant(int(elem_bytes), index=True) + n0 = c_n // c16 + + c_kpack_elems = c_kpack if elem_bytes == 1 else (c_kpack // arith.constant(int(elem_bytes), index=True)) + + stride_nlane = c_kpack_elems + + if k_major: + c32 = fx.Index(32) + c2 = fx.Index(2) + c_k0 = c_k_bytes // c32 + klane_dim = 2 + stride_klane = c16 * stride_nlane + stride_n0 = c2 * stride_klane + stride_k0 = n0 * stride_n0 + else: + c64 = fx.Index(64) + c4 = fx.Index(4) + c_k0 = c_k_bytes // c64 + klane_dim = 4 + stride_klane = c16 * stride_nlane + stride_k0 = c4 * stride_klane + stride_n0 = c_k0 * stride_k0 + + kpack_elems_static = kpack_bytes if elem_bytes == 1 else kpack_bytes // elem_bytes + n0_i32 = arith.index_cast(T.i32, n0) + c_k0_i32 = arith.index_cast(T.i32, c_k0) + stride_n0_i32 = arith.index_cast(T.i32, stride_n0) + stride_k0_i32 = arith.index_cast(T.i32, stride_k0) + stride_klane_i32 = arith.index_cast(T.i32, stride_klane) + stride_nlane_i32 = arith.index_cast(T.i32, stride_nlane) + + stride_b = (stride_n0_i32, stride_k0_i32, stride_klane_i32, stride_nlane_i32, 1) + layout_b = fx.make_layout( + (n0_i32, c_k0_i32, klane_dim, 16, kpack_elems_static), stride_b + ) + return PreshuffleBLayout(layout_b=layout_b, kpack_bytes=kpack_bytes) + + def _unpack_int4_to_int8_pair(packed32): """Split packed int4 dword into two int8 dwords (even/odd nibbles). @@ -292,8 +331,14 @@ def load_b_raw_w4a16( idx_bytes = idx_pack + k2_base b4 = _buffer_load_vec( - buffer_ops, vector, b_rsrc, idx_bytes, - elem_type=elem_type, vec_elems=4, elem_bytes=1, offset_in_bytes=True, + buffer_ops, + vector, + b_rsrc, + idx_bytes, + elem_type=elem_type, + vec_elems=4, + elem_bytes=1, + offset_in_bytes=True, ) packed32 = vector.extract( vector.bitcast(T.vec(1, T.i32), b4), @@ -425,8 +470,14 @@ def load_b_pack_k32( if unpack_int4: idx_bytes = idx_pack + k2_base b4 = _buffer_load_vec( - buffer_ops, vector, b_rsrc, idx_bytes, - elem_type=elem_type, vec_elems=4, elem_bytes=1, offset_in_bytes=True, + buffer_ops, + vector, + b_rsrc, + idx_bytes, + elem_type=elem_type, + vec_elems=4, + elem_bytes=1, + offset_in_bytes=True, ) packed32 = vector.extract( vector.bitcast(T.vec(1, T.i32), b4), @@ -438,8 +489,13 @@ def load_b_pack_k32( vec_elems = kpack_bytes // int(elem_bytes) b16 = _buffer_load_vec( - buffer_ops, vector, b_rsrc, idx_pack, - elem_type=elem_type, vec_elems=vec_elems, elem_bytes=elem_bytes, + buffer_ops, + vector, + b_rsrc, + idx_pack, + elem_type=elem_type, + vec_elems=vec_elems, + elem_bytes=elem_bytes, offset_in_bytes=(elem_bytes == 1), ) @@ -492,8 +548,13 @@ def buffer_copy_gmem16_dwordx4( if int(vec_elems) <= 0: raise ValueError(f"vec_elems must be > 0, got {vec_elems!r}") return _buffer_load_vec( - buffer_ops, vector, rsrc, idx_i32, - elem_type=elem_type, vec_elems=vec_elems, elem_bytes=elem_bytes, + buffer_ops, + vector, + rsrc, + idx_i32, + elem_type=elem_type, + vec_elems=vec_elems, + elem_bytes=elem_bytes, offset_in_bytes=False, ) @@ -625,14 +686,10 @@ def lds_load_pack_k32( "make_preshuffle_b_layout", "make_preshuffle_scale_layout", "load_b_pack_k32", - "load_b_raw_w4a16", - "unpack_b_w4a16", - "load_b_raw_w4a16_groupwise", - "unpack_b_w4a16_groupwise", - "extract_bf16_scale", "split_row_major_2d", "swizzle_xor16", "tile_chunk_coord_i32", + 
"unpack_b_w4a16", ] diff --git a/kernels/mixed_moe_gemm_2stage.py b/kernels/mixed_moe_gemm_2stage.py index eb23631a..039d31bc 100644 --- a/kernels/mixed_moe_gemm_2stage.py +++ b/kernels/mixed_moe_gemm_2stage.py @@ -3,69 +3,103 @@ """MoE GEMM stage1/stage2 kernel implementations (FlyDSL MFMA FP8/FP16/FP4). +This module contains the **kernel builder code** for: +- `moe_gemm1` (stage1, with silu/swiglu activation) +- `moe_gemm2` (stage2) + +It is extracted from `tests/kernels/test_moe_gemm.py` so that: +- `kernels/` holds the implementation +- `tests/` holds correctness/perf harnesses + +Mixed-precision support (a_dtype x b_dtype): +- fp8 x fp8, fp8 x fp4 (A8W4 on gfx950), fp4 x fp4, + fp16 x fp16, int8 x int4, ... + +A8W4 path is selected by `a_dtype='fp8', b_dtype='fp4'` plus +`gate_mode=GateMode.INTERLEAVE` + `a_scale_one=True` in stage1. """ -import functools import os -from contextlib import contextmanager +import functools +from enum import Enum import flydsl.compiler as flyc import flydsl.expr as fx from flydsl.compiler.kernel_function import CompilationContext -from flydsl.expr import arith, gpu, buffer_ops, vector, rocdl + from flydsl.expr import range_constexpr -from flydsl.expr.typing import T from flydsl.runtime.device import get_rocm_arch as get_hip_arch -from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr -try: - from flydsl.runtime.device import supports_bf16_global_atomics -except ImportError: - # Backward compatibility for runtime.device versions that only expose get_rocm_arch. - def supports_bf16_global_atomics(arch: str) -> bool: - return str(arch).startswith(("gfx94", "gfx95", "gfx12")) +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr from flydsl._mlir import ir -from flydsl._mlir.dialects import scf, memref, llvm +from flydsl.expr.typing import T + +from flydsl.expr import arith, gpu, buffer_ops, vector, rocdl, const_expr +from flydsl._mlir.dialects import llvm, scf, memref +from flydsl._mlir.dialects.arith import CmpIPredicate from kernels.mfma_preshuffle_pipeline import ( _buffer_load_vec, buffer_copy_gmem16_dwordx4, - crd2idx, - lds_row_major_idx, lds_store_16b_xor16, lds_store_8b_xor16, lds_store_4b_xor16, make_preshuffle_b_layout, make_preshuffle_scale_layout, - load_b_pack_k32, - split_row_major_2d, tile_chunk_coord_i32, swizzle_xor16, ) -from kernels.mfma_epilogues import c_shuffle_epilog, mfma_epilog - - -@contextmanager -def _if_then(if_op): - """Compat helper for SCF IfOp then-region across old/new Python APIs.""" - with ir.InsertionPoint(if_op.then_block): - try: - yield if_op.then_block - finally: - blk = if_op.then_block - if (not blk.operations) or not isinstance(blk.operations[-1], scf.YieldOp): - scf.YieldOp([]) +from kernels.mfma_epilogues import c_shuffle_epilog +from kernels.layout_utils import crd2idx, idx2crd, get as layout_get +from kernels.kernels_common import _if_then, validate_moe_dtypes + + +class GateMode(str, Enum): + """Gate/Up computation strategy for stage1 GEMM. + + SEPARATED: Two separate B-tile streams (gate + up), default mode. + MOCK_GATE_ONLY: Single B-tile stream over full [0, 2*inter_dim), simulates + gate-only by doubling grid X on top of SEPARATED layout. + Requires split-K (k_batch>1). NOT true gate-only. + GATE_ONLY: Reserved for future true gate-only implementation. + INTERLEAVE: Weight rows interleave gate/up (gate[0], up[0], gate[1], ...). + pack_N=2 routes even/odd N subtiles. NOT tied to split-K. 
+ """ + SEPARATED = "separated" + MOCK_GATE_ONLY = "mock_gate_only" + GATE_ONLY = "gate_only" + INTERLEAVE = "interleave" -def _w_elem_type(*, is_f4_b: bool, is_f16_b: bool): - """Return the packed weight element type used by preshuffled B tiles.""" - if is_f4_b: - return T.i8 - return T.f16 if is_f16_b else T.f8 +def _barrier(vmcnt=63, lgkmcnt=63): + """Emit s_waitcnt + s_barrier via inline asm. -@functools.lru_cache(maxsize=1024) + Bypasses LLVM SIInsertWaitcnts which would insert a conservative + s_waitcnt vmcnt(0) lgkmcnt(0) before every S_BARRIER MI. + """ + parts = [] + needs_waitcnt = vmcnt < 63 or lgkmcnt < 63 + if needs_waitcnt: + wc = [] + if vmcnt < 63: + wc.append(f"vmcnt({vmcnt})") + if lgkmcnt < 63: + wc.append(f"lgkmcnt({lgkmcnt})") + parts.append("s_waitcnt " + " ".join(wc)) + parts.append("s_barrier") + llvm.InlineAsmOp( + res=None, + operands_=[], + asm_string="\n".join(parts), + constraints="", + has_side_effects=True, + is_align_stack=False, + ) + + +@functools.lru_cache(maxsize=None) def compile_mixed_moe_gemm1( *, model_dim: int, @@ -75,7 +109,6 @@ def compile_mixed_moe_gemm1( tile_m: int, tile_n: int, tile_k: int, - # NOTE: aiter swap passes these for API symmetry; stage1 uses dynamic memrefs so they are ignored. doweight_stage1: bool, a_dtype: str = "fp8", b_dtype: str = "fp4", @@ -85,388 +118,610 @@ def compile_mixed_moe_gemm1( enable_bias: bool = False, model_dim_pad: int = 0, inter_dim_pad: int = 0, + persist_m: int = 1, + use_async_copy: bool = False, + waves_per_eu: int = 4, + k_batch: int = 1, + b_nt: int = 0, + gate_mode: GateMode = GateMode.SEPARATED, + a_scale_one: bool = False, + xcd_swizzle: int = 0, ): - """Compile stage1 kernel (`moe_gemm1`) and return the compiled executable. + """Compile stage1 kernel (gate+up with silu/swiglu). - a_dtype: - - "fp8": X is fp8 - - "fp16": X is fp16 (caller uses tile_k halved vs fp8 to match MFMA K halving) - - "int8": X is int8 - - "fp4": X is fp4 + GEMM: act(X @ W_gate.T, X @ W_up.T) -> [tokens*topk, inter_dim] + Direct store (no atomic). When k_batch>1 (split-K), each CTA + computes a K-slice and atomically adds gate/up partials. + Note: persist_m=1 (no persistence) is optimal for stage1 because K=model_dim + is large, so each CTA is already compute-heavy. persist_m>1 serializes M blocks + that the GPU can process in parallel. - b_dtype: - - "fp8": W is fp8 - - "fp16": W is fp16 (caller uses tile_k halved vs fp8 to match MFMA K halving) - - "int8": W is int8 - - "int4": W4A8 path: X is int8, W is packed int4 (2 values per byte) unpacked to int8 in-kernel - - "fp4": W is fp4 + gate_mode controls the gate/up computation strategy — see GateMode enum. 
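+
+    Padding note: ``model_dim`` and ``inter_dim`` passed here already include
+    any ``model_dim_pad`` / ``inter_dim_pad``; the pad only affects kernel
+    internal logic and grid computation (see the body comments below).
+
+    Example (illustrative call; the shapes follow the DeepSeek TP8 FP4
+    benchmark from the commit message, the tile sizes are arbitrary sample
+    values, and all other arguments keep their defaults):
+
+        exe = compile_mixed_moe_gemm1(
+            model_dim=7168, inter_dim=256, experts=256, topk=8,
+            tile_m=32, tile_n=128, tile_k=256,
+            doweight_stage1=False,
+            a_dtype="fp8", b_dtype="fp4",
+            gate_mode=GateMode.INTERLEAVE, a_scale_one=True,
+        )
+
+    ``exe`` is the compiled executable; repeated calls with the same
+    arguments return the lru_cached result.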
""" gpu_arch = get_hip_arch() - allocator = SmemAllocator(None, arch=gpu_arch) + allocator_pong = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem0") + allocator_ping = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem1") + _state = {} - if a_dtype not in ("fp8", "fp16", "int8", "fp4"): - raise ValueError( - f"a_dtype must be one of ('fp8','fp16','int8','fp4'), got {a_dtype!r}" - ) - if b_dtype not in ("fp8", "fp16", "int8", "int4", "fp4"): - raise ValueError( - f"b_dtype must be one of ('fp8','fp16','int8','int4','fp4'), got {b_dtype!r}" - ) + validate_moe_dtypes(a_dtype, b_dtype) is_f16_a = a_dtype == "fp16" is_f16_b = b_dtype == "fp16" - is_f16 = is_f16_a or is_f16_b - is_f8_a = a_dtype == "fp8" is_f4_a = a_dtype == "fp4" is_f4_b = b_dtype == "fp4" - pack_M = 2 - pack_N = 2 + sort_block_m = max(32, tile_m) + num_waves = min(4, tile_n // 32) + total_threads = num_waves * 64 + pack_M = 1 if tile_m < 32 else 2 + n_per_wave = tile_n // num_waves + pack_N = min(2, n_per_wave // 16) pack_K = 2 - + scale_mn_pack = 2 elem_bytes = 1 - a_elem_bytes = 2 if is_f16_a else 1 b_elem_bytes = 1 tile_k_bytes = int(tile_k) * int(a_elem_bytes) - a_elem_vec_pack = 2 if is_f4_a else 1 cbsz = 0 if is_f8_a else 4 blgp = 4 - # K64-byte micro-step: always 64 bytes per `ku`. For fp16, this is 32 elements (2xK16 MFMA). if (tile_k_bytes % 64) != 0: - raise ValueError( - f"tile_k_bytes must be divisible by 64, got tile_k_bytes={tile_k_bytes} " - f"(tile_k={tile_k}, elem_bytes={a_elem_bytes})" - ) - is_int4 = b_dtype == "int4" + raise ValueError(f"tile_k_bytes must be divisible by 64, got {tile_k_bytes}") + out_s = str(out_dtype).strip().lower() + out_is_f32 = out_s in ("f32", "fp32", "float") + out_is_bf16 = out_s in ("bf16", "bfloat16") + is_int4 = b_dtype == "int4" + is_int8 = False - def _x_lds_elem_type(): - return T.f16 if is_f16_a else T.f8 + def _x_elem_type(): + if is_f4_b: + return T.f8 if is_f8_a else T.i8 + return T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) - def _out_elem_type(): - return T.bf16 if out_dtype == "bf16" else T.f16 + def _w_elem_type(): + if is_f4_b: + return T.i8 + return T.f16 if is_f16_b else (T.i8 if is_int8 else T.f8) - def _out_lds_elem_type(): - return T.f32 + def out_elem(): + return T.f32 if out_is_f32 else (T.bf16 if out_is_bf16 else T.f16) + mock_gate_only = gate_mode is GateMode.MOCK_GATE_ONLY + gate_up_interleave = gate_mode is GateMode.INTERLEAVE + + # Padding semantics: model_dim and inter_dim INCLUDE padding. + # model_dim = model_dim_true + model_dim_pad (K direction) + # inter_dim = inter_dim_true + inter_dim_pad (N direction) + # Tensor sizes use the padded dimensions (inter_dim, model_dim). + # Padding only affects kernel internal logic and grid computation. 
+ _inter_dim_valid = inter_dim - inter_dim_pad + + # Split-K validation + _is_splitk = k_batch > 1 + if mock_gate_only and not _is_splitk: + raise ValueError("mock_gate_only requires k_batch > 1 (split-K)") + if _is_splitk: + _k_per_batch = model_dim // k_batch + assert ( + model_dim % k_batch == 0 + ), f"model_dim={model_dim} not divisible by k_batch={k_batch}" + assert ( + _k_per_batch % tile_k == 0 + ), f"K_per_batch={_k_per_batch} not divisible by tile_k={tile_k}" + + out_dtype = "bf16" + else: + _k_per_batch = model_dim + _k_dim = _k_per_batch - total_threads = 256 bytes_x_per_tile = int(tile_m) * int(tile_k) * int(a_elem_bytes) if bytes_x_per_tile % total_threads != 0: raise ValueError( - "tile_m*tile_k*elem_bytes must be divisible by " - f"{total_threads}: tile_m={tile_m}, tile_k={tile_k}, elem_bytes={a_elem_bytes}" + f"tile_m*tile_k*elem_bytes must be divisible by {total_threads}" ) bytes_per_thread_x = bytes_x_per_tile // total_threads - pad_k = 0 + + _use_lds128 = os.environ.get("FLIR_CK_LDS128", "1") in ( + "1", + "true", + "True", + "YES", + "yes", + ) + pad_k = 0 if _use_lds128 else 8 lds_stride = tile_k + pad_k + if use_cshuffle_epilog is None: - use_cshuffle_epilog = os.environ.get("FLYDSL_MOE_STAGE1_CSHUFFLE", "0") in ( + _use_cshuffle_epilog = os.environ.get("FLIR_MOE_STAGE1_CSHUFFLE", "1") in ( "1", "true", "True", "YES", "yes", ) - use_cshuffle_epilog = bool(use_cshuffle_epilog) - - epilog_tag = "cshuffle" if use_cshuffle_epilog else "direct" + else: + _use_cshuffle_epilog = bool(use_cshuffle_epilog) + + _need_fp4 = out_dtype == "fp4" + _need_fp8 = out_dtype == "fp8" + _need_quant = _need_fp4 or _need_fp8 + _need_sort = _need_quant + + if _need_quant: + _use_cshuffle_epilog = True + + _fp4q_tag = "_fp4q" if _need_fp4 else "" + _fp8q_tag = "_fp8q" if _need_fp8 else "" + _sort_tag = "_sort" if _need_sort else "" + _async_tag = "_async" if use_async_copy else "" + _sk_tag = f"_sk{k_batch}" if _is_splitk else "" + _go_tag = "_go" if mock_gate_only else "" + _gui_tag = "_gui" if gate_up_interleave else "" + _as1_tag = "_as1" if a_scale_one else "" + _xcd_tag = f"_xcd{xcd_swizzle}" if xcd_swizzle > 0 else "" module_name = ( - f"mfma_moe1_a{a_dtype}_w{b_dtype}_{epilog_tag}" - f"_t{tile_m}x{tile_n}x{tile_k}" - f"_abi35_ckstage1" + f"mfma_moe1_silu_mul_a{a_dtype}_w{b_dtype}_{out_s}" + f"_t{tile_m}x{tile_n}x{tile_k}_pm{persist_m}{_fp4q_tag}{_fp8q_tag}{_sort_tag}{_async_tag}{_sk_tag}{_go_tag}{_gui_tag}{_as1_tag}{_xcd_tag}_v32" ).replace("-", "_") - # -- LDS sizing (pure Python; no MLIR Context needed) --------------------- - # Reuse the same LDS bytes for both: - # - ping-pong X tiles (2 * tile_m * lds_stride * elem_bytes bytes) - # - optional CShuffle tile (stage1 uses 2xf16 vector store, sized in 4B pairs) - _use_cshuffle_epilog = bool(use_cshuffle_epilog) - lds_x_bytes = 2 * int(tile_m) * int(lds_stride) * int(a_elem_bytes) - lds_out_bytes = 4 * int(tile_m) * int(tile_n) if _use_cshuffle_epilog else 0 - lds_tid_bytes = int(tile_m) * 4 - lds_total_bytes = max(lds_x_bytes, lds_out_bytes) + lds_tid_bytes - lds_total_elems = lds_total_bytes if a_elem_bytes == 1 else (lds_total_bytes // 2) - lds_alloc_bytes = int(lds_total_elems) * int(a_elem_bytes) - lds_alloc_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = lds_alloc_offset + lds_alloc_bytes + # -- LDS sizing -- + _cshuffle_elem_bytes = 4 if _need_quant else (4 if out_is_f32 else 2) + _single_x_bytes = int(tile_m) * int(lds_stride) * int(a_elem_bytes) + lds_out_bytes = ( + _cshuffle_elem_bytes * int(tile_m) * 
int(tile_n) if _use_cshuffle_epilog else 0 + ) + lds_tid_bytes = int(tile_m) * 4 + _input_elems = _single_x_bytes if a_elem_bytes == 1 else (_single_x_bytes // 2) + + # Determine whether we need wave-group split for lds_out. + # Standard layout: pong = max(input, lds_out) + tid, ping = input. + # When this overflows, split lds_out into two halves across pong & ping. + _GLOBAL_ALIGN = 1024 + _std_pong = max(_single_x_bytes, lds_out_bytes) + lds_tid_bytes + _std_ping = _single_x_bytes + _std_pong_aligned = allocator_pong._align(_std_pong, 128) + _std_total = allocator_pong._align( + _std_pong_aligned, _GLOBAL_ALIGN + ) + allocator_pong._align(_std_ping, 128) + _lds_limit = {"gfx950": 163840, "gfx942": 65536}.get(gpu_arch, 0) + + _split_lds_out = ( + _lds_limit > 0 + and lds_out_bytes > 0 + and _std_total > _lds_limit + and num_waves >= 2 + ) - @flyc.kernel - def moe_gemm1( - arg_out: fx.Tensor, - arg_x: fx.Tensor, - arg_w: fx.Tensor, - arg_scale_x: fx.Tensor, - arg_scale_w: fx.Tensor, - arg_sorted_token_ids: fx.Tensor, - arg_expert_ids: fx.Tensor, - arg_sorted_weights: fx.Tensor, - arg_max_token_ids: fx.Tensor, - arg_bias: fx.Tensor, - i32_tokens_in: fx.Int32, - i32_inter_in: fx.Int32, - i32_k_in: fx.Int32, - i32_size_expert_ids_in: fx.Int32, - ): - tokens_in = arith.index_cast(T.index, i32_tokens_in) - k_in = arith.index_cast(T.index, i32_k_in) - size_expert_ids_in = arith.index_cast(T.index, i32_size_expert_ids_in) - tokens_i32_v = i32_tokens_in - k_i32_v = i32_k_in - x_elem = T.f16 if is_f16_a else T.f8 - vec4_f32 = T.vec(4, T.f32) - vec4_i32 = T.vec(4, T.i32) - vec1_f32 = T.vec(1, T.f32) - vec16_elems = 16 if a_elem_bytes == 1 else 8 - vec16_x = T.vec(vec16_elems, x_elem) - vec2_i64 = T.vec(2, T.i64) - - def silu(x): - # Align with CK's device fast path: - # emu = exp(-x) ~= exp2(log2e * (-x)) -> v_exp_f32 - # sig = rcp(1 + emu) -> v_rcp_f32 - # y = x * sig - t = x * (-1.4426950408889634) # -log2(e) - emu = rocdl.exp2(T.f32, t) - den = 1.0 + emu - sig = rocdl.rcp(T.f32, den) - return x * sig - - _arith_min = getattr(arith, "minimum", None) or getattr(arith, "minimumf") - _arith_max = getattr(arith, "maximum", None) or getattr(arith, "maximumf") - - def swiglu(gate, up, alpha=1.702, limit=7.0): - gate = _arith_min(gate, limit) - up = _arith_min(up, limit) - up = _arith_max(up, -limit) - - t = gate * alpha * (-1.4426950408889634) # -log2(e) - emu = rocdl.exp2(T.f32, t) - den = 1.0 + emu - sig = rocdl.rcp(T.f32, den) - return gate * sig * (up + 1.0) - - acc_init = arith.constant_vector(0.0, vec4_f32) - - # B preshuffle layout: match GEMM test helper exactly. 
- c_n_total = fx.Index(experts * (2 * inter_dim)) - kpack_bytes = 8 if is_int4 else 16 - b_layout = make_preshuffle_b_layout( - arith, - c_n=c_n_total, - c_k=k_in // pack_K, - kpack_bytes=kpack_bytes, - elem_bytes=b_elem_bytes, - ) - layout_b = b_layout.layout_b + if _split_lds_out: + _half_out_bytes = _cshuffle_elem_bytes * int(tile_m) * (int(tile_n) // 2) + _pong_buffer_bytes = max(_single_x_bytes, _half_out_bytes) + _ping_buffer_bytes = max(_single_x_bytes, _half_out_bytes) + else: + _pong_buffer_bytes = max(_single_x_bytes, lds_out_bytes) + _ping_buffer_bytes = _single_x_bytes - m_repeat = tile_m // 16 - k_unroll = tile_k_bytes // 128 # K64-byte micro-step + def x_lds_elem(): + return T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) + + lds_pong_offset = allocator_pong._align(allocator_pong.ptr, 16) + allocator_pong.ptr = lds_pong_offset + _pong_buffer_bytes + _lds_tid_offset_pong = allocator_pong._align(allocator_pong.ptr, 4) + allocator_pong.ptr = _lds_tid_offset_pong + lds_tid_bytes + + lds_ping_offset = allocator_ping._align(allocator_ping.ptr, 16) + allocator_ping.ptr = lds_ping_offset + _ping_buffer_bytes + + if waves_per_eu is not None and waves_per_eu >= 1: + _total_cu_lds = 160 * 1024 + _min_lds = _total_cu_lds // (waves_per_eu + 1) + 1 + _pong_sz = allocator_pong._align(allocator_pong.ptr, 128) + _ping_sz = allocator_ping._align(allocator_ping.ptr, 128) + _cur_lds = _pong_sz + _ping_sz + if _cur_lds < _min_lds: + allocator_ping.ptr += _min_lds - _cur_lds + + kpack_bytes = 8 if is_int4 else 16 + out_elem_bytes = 4 if out_is_f32 else 2 + + _e_vec_s1 = min(tile_n // 32, 8) + if _need_quant: + _e_vec_s1 = max(2, _e_vec_s1) + _num_threads_per_quant_blk_s1 = 32 // _e_vec_s1 + _shuffle_dists_s1 = [] + _sh_val = 1 + while _sh_val < _num_threads_per_quant_blk_s1: + _shuffle_dists_s1.append(_sh_val) + _sh_val *= 2 + _num_shuffle_steps_s1 = len(_shuffle_dists_s1) + + # ---- Unified pipeline schedule (outside @flyc.kernel) ---- + # Each scheduling phase is a dict: + # mfma: [(k_idx, mi_idx, ikxdl, imxdl, asv_idx), ...] + # a_reads: [(k, mi), ...] # A ds_read subtiles + # b_loads: [('gate'/'up', ku, ni), ...] # B VMEM loads + # has_scale: bool # A/B scale VMEM loads + _pipe_m_repeat = tile_m // 16 + _pipe_k_unroll = tile_k_bytes // 128 + _pipe_k_unroll_packed = _pipe_k_unroll // pack_K + _pipe_m_repeat_packed = _pipe_m_repeat // pack_M + _pipe_num_acc_n = n_per_wave // 16 + + # A ds_read groups: group by mi (same mi, all k values together) + _pipe_a_groups = [] + for _mi in range(_pipe_m_repeat): + _grp = [] + for _k in range(_pipe_k_unroll): + _grp.append((_k, _mi)) + if len(_grp) == 2: + _pipe_a_groups.append(_grp) + _grp = [] + if _grp: + _pipe_a_groups.append(_grp) + + # B VMEM loads: individual gate/up loads + _pipe_b_loads = [] + for ku in range(_pipe_k_unroll): + for ni in range(_pipe_num_acc_n): + _pipe_b_loads.append(("gate", ku, ni)) + if not mock_gate_only and not gate_up_interleave: + _pipe_b_loads.append(("up", ku, ni)) + + # MFMA order: B-major (fix B, cycle all A tiles before next B) + # Each entry: one (k, ni) pair; the compute function loops over all mi. + # This keeps B operands (from VMEM) fixed while cycling A (from LDS, no wait). 
+ _pipe_num_acc_n_packed = _pipe_num_acc_n // pack_N + _pipe_all_mfma = [] + for _ku128 in range(_pipe_k_unroll_packed): + for _ni_packed in range(_pipe_num_acc_n_packed): + for _ikxdl in range(pack_K): + for _inxdl in range(pack_N): + _k_idx = _ku128 * pack_K + _ikxdl + _ni_idx = _ni_packed * pack_N + _inxdl + _pipe_all_mfma.append((_k_idx, _ni_idx, _ikxdl, _inxdl, _ku128)) + + # Group MFMAs per scheduling phase (wider M -> more MFMAs per phase) + _pipe_mfma_per_phase = max(1, len(_pipe_all_mfma) // 4) + _pipe_n_phases = len(_pipe_all_mfma) // _pipe_mfma_per_phase + + # Build unified phase descriptors + _a_groups_per_phase = (len(_pipe_a_groups) + _pipe_n_phases - 1) // _pipe_n_phases + _pipe_phases = [] + _mfma_i = 0 + _a_i = 0 + for _p in range(_pipe_n_phases): + _a_reads = [] + for _ in range(_a_groups_per_phase): + if _a_i < len(_pipe_a_groups): + _a_reads.extend(_pipe_a_groups[_a_i]) + _a_i += 1 + _phase = { + "mfma": _pipe_all_mfma[_mfma_i : _mfma_i + _pipe_mfma_per_phase], + "a_reads": _a_reads, + "b_loads": [], + "has_scale": (_p == 0), + } + _mfma_i += _pipe_mfma_per_phase + _pipe_phases.append(_phase) + + # Distribute B loads evenly across phases 1..n-1 (phase 0 has scales) + _bi = 0 + for _p in range(1, _pipe_n_phases): + _rem_b = len(_pipe_b_loads) - _bi + _rem_p = _pipe_n_phases - _p + _n_b = (_rem_b + _rem_p - 1) // _rem_p if _rem_p > 0 else 0 + for _ in range(_n_b): + if _bi < len(_pipe_b_loads): + _pipe_phases[_p]["b_loads"].append(_pipe_b_loads[_bi]) + _bi += 1 + + # Extract flat lists for kernel access (avoids dict access in AST rewriter) + _pp_mfma = [p["mfma"] for p in _pipe_phases] + _pp_a_reads = [p["a_reads"] for p in _pipe_phases] + _pp_b_loads = [p["b_loads"] for p in _pipe_phases] + _pp_has_scale = [p["has_scale"] for p in _pipe_phases] + + fp4_ratio = 2 if a_dtype == "fp4" else 1 + gui_ratio = 1 if gate_up_interleave else 2 + _vmcnt_before_barrier = tile_m // 32 // fp4_ratio + tile_n // 32 * gui_ratio - # A scale is sorted/padded by MoE routing, so its M dimension follows - # the sorted row buffer (`blocks * tile_m`), not the raw token count. 
- sorted_rows = size_expert_ids_in * fx.Index(tile_m) - layout_a_scale = make_preshuffle_scale_layout( - arith, c_mn=sorted_rows, c_k=k_in - ) - layout_b_scale = make_preshuffle_scale_layout( - arith, c_mn=c_n_total, c_k=k_in - ) + if True: - shape_lds = fx.make_shape(tile_m, tile_k) - stride_lds = fx.make_stride(lds_stride, 1) - layout_lds = fx.make_layout(shape_lds, stride_lds) + @flyc.kernel + def moe_gemm1( + arg_out: fx.Tensor, + arg_x: fx.Tensor, + arg_w: fx.Tensor, + arg_scale_x: fx.Tensor, + arg_scale_w: fx.Tensor, + arg_sorted_token_ids: fx.Tensor, + arg_expert_ids: fx.Tensor, + arg_sorted_weights: fx.Tensor, + arg_num_valid_ids: fx.Tensor, + arg_bias: fx.Tensor, + arg_out_scale_sorted: fx.Tensor, + i32_tokens_in: fx.Int32, + i32_n_in: fx.Int32, + i32_k_in: fx.Int32, + i32_size_expert_ids_in: fx.Int32, + ): - tx = gpu.thread_id("x") - # Align with Aiter launch mapping (NSwizzle==false): - # - blockIdx.x -> N dimension (tile along inter_dim) - # - blockIdx.y -> expert-block id / M dimension (tile along sorted M) - by = gpu.block_id("x") # tile along inter_dim - bx = gpu.block_id("y") # tile along sorted M + tokens_in = arith.index_cast(ir.IndexType.get(), i32_tokens_in.ir_value()) + n_in = arith.index_cast(ir.IndexType.get(), i32_n_in.ir_value()) + k_in = arith.index_cast(ir.IndexType.get(), i32_k_in.ir_value()) + size_expert_ids_in = arith.index_cast( + ir.IndexType.get(), i32_size_expert_ids_in.ir_value() + ) - # Block validity: compute as early as possible so invalid blocks skip all buffer-resource - # setup, LDS pointer math, and gmem prefetch work. - bx_m = bx * fx.Index(tile_m) - by_n = by * fx.Index(tile_n) + x_elem = T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) + f32 = T.f32 + i32 = T.i32 + i64 = T.i64 + vec4_f32 = T.vec(4, f32) + vec16_elems = 16 if a_elem_bytes == 1 else 8 + vec16_x = T.vec(vec16_elems, x_elem) + vec2_i64 = T.vec(2, i64) - maxids_rsrc = buffer_ops.create_buffer_resource( - arg_max_token_ids, max_size=False, num_records_bytes=fx.Int32(4) - ) - max_token_id_i32 = buffer_ops.buffer_load( - maxids_rsrc, fx.Index(0), vec_width=1, dtype=T.i32 - ) + acc_init = arith.constant_vector(0.0, vec4_f32) - bias_rsrc = ( - buffer_ops.create_buffer_resource(arg_bias, max_size=False) - if enable_bias - else None - ) + # --- Stage1 dimension mapping --- + # X: [tokens, model_dim] -- M = sorted tokens, K = model_dim + # W: [E*2*inter_dim, model_dim] gate portion -- N = inter_dim + # Out: [tokens*topk, inter_dim] - bx_m_i32 = arith.index_cast(T.i32, bx_m) - blk_valid = arith.cmpi(arith.CmpIPredicate.ult, bx_m_i32, max_token_id_i32) - # Common constants/atoms (hoisted): keep IR small like GEMM. - # CK-style XOR16 swizzle parameter (constant, power-of-two in our configs). - k_blocks16 = fx.Index(tile_k_bytes // 16) - _if_blk = scf.IfOp(blk_valid) - with _if_then(_if_blk): - x_lds_elem = _x_lds_elem_type() - base_ptr = allocator.get_base() - lds_x_ptr = SmemPtr( - base_ptr, - lds_alloc_offset, - x_lds_elem, - shape=(lds_total_elems,), + # B preshuffle layout: [E*2*inter_dim, model_dim] + # Gate rows for expert e: [e*2*inter_dim, e*2*inter_dim + inter_dim) + c_n_total = arith.constant(experts * (2 * inter_dim), index=True) + b_layout = make_preshuffle_b_layout( + arith, + c_n=c_n_total, + c_k=k_in // pack_K, + kpack_bytes=kpack_bytes, + elem_bytes=b_elem_bytes, + # k_major=True, ) - lds_x = lds_x_ptr.get() - # Alias LDS bytes as fp16 for optional CShuffle epilogue. 
- _use_cshuffle_epilog = bool(use_cshuffle_epilog) + layout_b = b_layout.layout_b - _lds_out_elems = tile_m * tile_n if _use_cshuffle_epilog else 0 - lds_out = ( - SmemPtr( - base_ptr, - lds_x_ptr.byte_offset, - _out_lds_elem_type(), - shape=(_lds_out_elems,), - ).get() - if _use_cshuffle_epilog - else None + # A-scale: [sorted_size, K/32] -- pre-scattered by caller into sorted layout + # Same as stage2: indexed by sorted_row position, not by token_id. + sorted_m = size_expert_ids_in * arith.constant(sort_block_m, index=True) + layout_a_scale = make_preshuffle_scale_layout( + arith, c_mn=sorted_m, c_k=arith.constant(model_dim, index=True) ) + # B-scale: [E*2*inter_dim, K/32] + layout_b_scale = make_preshuffle_scale_layout( + arith, c_mn=c_n_total, c_k=arith.constant(model_dim, index=True) + ) + + _eff_lds_stride = lds_stride + _eff_tile_k_bytes = tile_k_bytes + if const_expr(use_async_copy and a_elem_vec_pack > 1): + _eff_lds_stride = lds_stride // a_elem_vec_pack + _eff_tile_k_bytes = tile_k_bytes // a_elem_vec_pack + + shape_lds = fx.make_shape(tile_m, _eff_lds_stride) + stride_lds = fx.make_stride(_eff_lds_stride, 1) + layout_lds = fx.make_layout(shape_lds, stride_lds) - # Use logical buffer sizes (descriptor num_records) so hardware OOB checking can be - # used directly (CK-style). This allows us to avoid `select`-based masking for - # invalid lanes and rely on the buffer instruction's built-in bounds behavior. - x_nbytes = ( - tokens_in - * (k_in // fx.Index(int(a_elem_vec_pack))) - * fx.Index(int(elem_bytes)) + tx = gpu.thread_id("x") + by = gpu.block_id("x") # tile along inter_dim (N) + bx_persist = gpu.block_id("y") # persistent WG index + + if xcd_swizzle > 0: + _NUM_XCDS_S1 = 8 + _c1_sw = arith.constant(1, index=True) + _c_tn_sw = arith.constant(tile_n, index=True) + _c_idp_sw = arith.constant(2 * inter_dim_pad, index=True) + if const_expr(mock_gate_only or gate_up_interleave): + _gx = (n_in - _c_idp_sw + _c_tn_sw - _c1_sw) / _c_tn_sw + else: + _c2_sw = arith.constant(2, index=True) + _gx = ( + (n_in - _c_idp_sw + _c2_sw * _c_tn_sw - _c1_sw) + / _c_tn_sw + / _c2_sw + ) + _c_pm_sw = arith.constant(persist_m, index=True) + _gy = (size_expert_ids_in + _c_pm_sw - _c1_sw) / _c_pm_sw + + _linear_id = bx_persist * _gx + by + _num_wgs = _gx * _gy + + _c_xcds = arith.constant(_NUM_XCDS_S1, index=True) + _wgs_per_xcd = _num_wgs / _c_xcds + _wgid = (_linear_id % _c_xcds) * _wgs_per_xcd + (_linear_id / _c_xcds) + + _WGM_S1 = xcd_swizzle + _c_wgm = arith.constant(_WGM_S1, index=True) + _num_wgid_in_group = _c_wgm * _gx + _group_id = _wgid / _num_wgid_in_group + _first_pid_m = _group_id * _c_wgm + _remaining_m = _gy - _first_pid_m + _cmp_m = arith.cmpi(CmpIPredicate.ult, _remaining_m, _c_wgm) + _group_size_m = arith.select(_cmp_m, _remaining_m, _c_wgm) + + _wgid_in_group = _wgid % _num_wgid_in_group + bx_persist = _first_pid_m + (_wgid_in_group % _group_size_m) + by = _wgid_in_group / _group_size_m + + by_n = by * arith.constant(tile_n, index=True) + + k_base_idx = arith.index(0) + if const_expr(_is_splitk): + bz = gpu.block_id("z") # K-batch id + k_base_idx = bz * arith.constant(_k_dim, index=True) + + k_blocks16 = arith.constant(_eff_tile_k_bytes // 16, index=True) + layout_tx_wave_lane = fx.make_layout((num_waves, 64), stride=(64, 1)) + layout_lane16 = fx.make_layout((4, 16), stride=(16, 1)) + + base_ptr_pong = allocator_pong.get_base() + base_ptr_ping = allocator_ping.get_base() + lds_x_pong = SmemPtr( + base_ptr_pong, lds_pong_offset, x_lds_elem(), shape=(_input_elems,) + ).get() + 
lds_x_ping = SmemPtr( + base_ptr_ping, lds_ping_offset, x_lds_elem(), shape=(_input_elems,) + ).get() + _lds_out_elem_type = ( + T.f32 if _need_quant else (T.bf16 if out_is_bf16 else T.f16) ) + if _split_lds_out and _use_cshuffle_epilog: + _half_out_elems = int(tile_m) * (int(tile_n) // 2) + lds_out = SmemPtr( + base_ptr_pong, + lds_pong_offset, + _lds_out_elem_type, + shape=(_half_out_elems,), + ).get() + lds_out_B = SmemPtr( + base_ptr_ping, + lds_ping_offset, + _lds_out_elem_type, + shape=(_half_out_elems,), + ).get() + else: + lds_out = ( + SmemPtr( + base_ptr_pong, + lds_pong_offset, + _lds_out_elem_type, + shape=(tile_m * tile_n,), + ).get() + if _use_cshuffle_epilog + else None + ) + lds_out_B = None + lds_tid = SmemPtr( + base_ptr_pong, _lds_tid_offset_pong, T.i32, shape=(tile_m,) + ).get() + + # Buffer resources + c_a_pack = arith.constant(int(a_elem_vec_pack), index=True) + c_elem_bytes = arith.constant(int(a_elem_bytes), index=True) + + # X: [tokens, model_dim] + x_nbytes_idx = (tokens_in * k_in * c_elem_bytes) / c_a_pack + x_nbytes_i32 = arith.index_cast(T.i32, x_nbytes_idx) x_rsrc = buffer_ops.create_buffer_resource( - arg_x, max_size=False, num_records_bytes=x_nbytes - ) - _w_n = fx.Index(experts * (2 * inter_dim)) - _w_nbytes = ( - _w_n - * (k_in // fx.Index(int(a_elem_vec_pack))) - * fx.Index(int(elem_bytes)) - ) - _w_nbytes_i32 = arith.index_cast(T.i32, _w_nbytes) - w_rsrc = buffer_ops.create_buffer_resource( - arg_w, max_size=False, num_records_bytes=_w_nbytes_i32 + arg_x, max_size=False, num_records_bytes=x_nbytes_i32 ) - # OUT: [tokens * topk * inter_dim] in f16/bf16 (2B each) or fp8 (1B each). - _out_elem_bytes = 1 if out_dtype == "fp8" else 2 - out_nbytes = tokens_in * arith.constant( - topk * inter_dim * _out_elem_bytes, index=True + w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False) + + # Out: [tokens*topk, inter_dim] + numids_rsrc = buffer_ops.create_buffer_resource( + arg_num_valid_ids, + max_size=False, + num_records_bytes=arith.constant(4, type=T.i32), ) - out_nbytes_i32 = arith.index_cast(T.i32, out_nbytes) - out_rsrc = buffer_ops.create_buffer_resource( - arg_out, max_size=False, num_records_bytes=out_nbytes_i32 + num_valid_i32 = buffer_ops.buffer_load( + numids_rsrc, arith.constant(0, index=True), vec_width=1, dtype=T.i32 ) - if is_f16_a: - sx_rsrc = None - else: - # A1 microscale: [sorted_rows, K/32] e8m0 bytes, packed as i32. - _c32 = fx.Index(32) - _kblk = k_in / _c32 - _sorted_rows = size_expert_ids_in * arith.constant( - tile_m, index=True - ) - sx_nbytes = _sorted_rows * _kblk - sx_nbytes_i32 = arith.index_cast(T.i32, sx_nbytes) + sx_rsrc = 1 + sw_rsrc = 1 + if const_expr(not (is_f16_a or a_scale_one)): + # A scale: [sorted_size, model_dim/32] pre-scattered by caller + c32 = arith.constant(32, index=True) + kblk = k_in / c32 + sx_nbytes_idx = sorted_m * kblk + sx_nbytes_i32 = arith.index_cast(T.i32, sx_nbytes_idx) sx_rsrc = buffer_ops.create_buffer_resource( arg_scale_x, max_size=False, num_records_bytes=sx_nbytes_i32 ) - if is_f16_b: - sw_rsrc = None - else: - # W1 microscale: [experts * 2 * inter_dim, K/32] e8m0 bytes. 
- _c32_w = fx.Index(32) - _kblk_w = k_in / _c32_w - _mn_w = fx.Index(experts * (2 * inter_dim)) - sw_nbytes = _mn_w * _kblk_w - sw_nbytes_i32 = arith.index_cast(T.i32, sw_nbytes) + if const_expr(not is_f16_b): + c32 = arith.constant(32, index=True) + kblk_w = k_in / c32 + mn_w = arith.constant(experts * (2 * inter_dim), index=True) + sw_nbytes_idx = mn_w * kblk_w + sw_nbytes_i32 = arith.index_cast(T.i32, sw_nbytes_idx) sw_rsrc = buffer_ops.create_buffer_resource( arg_scale_w, max_size=False, num_records_bytes=sw_nbytes_i32 ) - # sorted_token_ids / sorted_weights: [blocks*tile_m] (CK-style padded length) - sorted_nbytes = size_expert_ids_in * arith.constant( - tile_m * 4, index=True + sorted_nbytes_idx = size_expert_ids_in * arith.constant( + sort_block_m * 4, index=True ) - sorted_nbytes_i32 = arith.index_cast(T.i32, sorted_nbytes) + sorted_nbytes_i32 = arith.index_cast(T.i32, sorted_nbytes_idx) sorted_rsrc = buffer_ops.create_buffer_resource( arg_sorted_token_ids, max_size=False, num_records_bytes=sorted_nbytes_i32, ) - sorted_w_rsrc = ( - buffer_ops.create_buffer_resource( - arg_sorted_weights, - max_size=False, - num_records_bytes=sorted_nbytes_i32, - ) - if doweight_stage1 - else None + sorted_w_rsrc = buffer_ops.create_buffer_resource( + arg_sorted_weights, max_size=False, num_records_bytes=sorted_nbytes_i32 ) - # expert ids: [blocks] i32 -> bytes = size_expert_ids_in*4 - eid_nbytes_i32 = arith.index_cast( - T.i32, size_expert_ids_in * fx.Index(4) - ) + eid_nbytes_idx = size_expert_ids_in * arith.constant(4, index=True) + eid_nbytes_i32 = arith.index_cast(T.i32, eid_nbytes_idx) expert_rsrc = buffer_ops.create_buffer_resource( arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes_i32 ) + bias_rsrc = ( + buffer_ops.create_buffer_resource(arg_bias, max_size=False) + if enable_bias + else None + ) + + # Sorted-scale buffer resource for fused mxfp4 quantization + _sorted_scale_cols = inter_dim // 32 + _sorted_scale_cols_i32 = arith.constant(_sorted_scale_cols, type=T.i32) + sorted_scale_rsrc = None + if const_expr(_need_sort): + sorted_scale_rsrc = buffer_ops.create_buffer_resource( + arg_out_scale_sorted, max_size=False + ) - # Expert id for this M tile (keep address math in `index`) + # ---- persist_m loop (same pattern as stage2) ---- + _PERSIST_M = persist_m + _c0_p = arith.constant(0, index=True) + _c1_p = arith.constant(1, index=True) + _c_pm = arith.constant(_PERSIST_M, index=True) + _for_persist = scf.ForOp(_c0_p, _c_pm, _c1_p) + _for_ip = ir.InsertionPoint(_for_persist.body) + _for_ip.__enter__() + _mi_p = _for_persist.induction_variable + bx = bx_persist * _c_pm + _mi_p + bx_m = bx * arith.constant(sort_block_m, index=True) + + # Block validity + bx_m_i32 = arith.index_cast(T.i32, bx_m) + blk_valid = arith.cmpi(CmpIPredicate.ult, bx_m_i32, num_valid_i32) expert_i32 = buffer_ops.buffer_load( expert_rsrc, bx, vec_width=1, dtype=T.i32 ) + expert_idx = arith.index_cast(ir.IndexType.get(), expert_i32) exp_valid = arith.cmpi( - arith.CmpIPredicate.ult, expert_i32, fx.Int32(experts) - ) # todo fix - _ifexpert_of = scf.IfOp(exp_valid) - with _if_then(_ifexpert_of): - expert_idx = arith.index_cast(T.index, expert_i32) - inter2_idx = fx.Index(2 * inter_dim) - expert_off_idx = expert_idx * inter2_idx # index + CmpIPredicate.ult, expert_i32, arith.constant(experts, type=T.i32) + ) - bx_m = bx * fx.Index(tile_m) + def _moe_gemm1_body(): + # Gate expert offset: first inter_dim rows of each expert's 2*inter_dim block + expert_off_idx = expert_idx * arith.constant(2 * inter_dim, 
index=True) - # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- - # Keep a fixed 16B gmem->reg schedule (dwordx4) to match preshuffle_gemm_flyc.py. - if bytes_per_thread_x % 16 != 0: - raise ValueError( - f"bytes_per_thread_x ({bytes_per_thread_x}) must be divisible by 16" - ) + # X loading -- KEY DIFFERENCE from stage2: X row = token_id only x_load_bytes = 16 num_x_loads = bytes_per_thread_x // x_load_bytes - chunk_i32 = x_load_bytes // 4 # dwords per chunk (1/2/4) + chunk_i32 = x_load_bytes // 4 - # Work in dword units along K: K_dwords = (K_packed_bytes)/4. - # For fp4, 2 elements per byte, so divide by a_elem_vec_pack. - c_a_pack = fx.Index(int(a_elem_vec_pack)) c_k_div4 = ( - (k_in // c_a_pack) * fx.Index(int(elem_bytes)) - ) // arith.index(4) - c_k_div4_i32 = arith.index_cast(T.i32, c_k_div4) - tile_k_dwords = (int(tile_k) * int(elem_bytes)) // 4 + (k_in / c_a_pack) * arith.constant(int(a_elem_bytes), index=True) + ) / arith.index(4) + tile_k_dwords = (int(tile_k) * int(a_elem_bytes)) // ( + 4 * int(a_elem_vec_pack) + ) layout_x_tile_div4 = fx.make_layout( (tile_m, tile_k_dwords), stride=(tile_k_dwords, 1) ) - c_chunk_i32 = fx.Index(chunk_i32) + c_chunk_i32 = arith.constant(chunk_i32, index=True) tx_i32_base = tx * c_chunk_i32 - mask24 = fx.Int32(0xFFFFFF) - # Keep i32 constants available for epilogue index math. - topk_i32 = fx.Int32(topk) + topk_i32 = arith.constant(topk) + mask24 = arith.constant(0xFFFFFF) tokens_i32 = arith.index_cast(T.i32, tokens_in) def x_tile_chunk_coord_i32(i: int): @@ -479,46 +734,9 @@ def x_tile_chunk_coord_i32(i: int): chunk_i32=chunk_i32, ) - # CK-aligned: decode token once (per thread's M-slice) and build a base row offset. - x_row_base_div4 = [] - x_row_valid = [] - x_col_local_i32 = [] - x_row_local = [] - for i in range_constexpr(num_x_loads): - row_local, col_local_i32 = x_tile_chunk_coord_i32(i) - x_row_local.append(row_local) - x_col_local_i32.append(col_local_i32) - - sorted_row_i = bx_m + row_local - sorted_row_i32 = arith.index_cast(T.i32, sorted_row_i) - row_valid = arith.cmpi( - arith.CmpIPredicate.ult, sorted_row_i32, max_token_id_i32 - ) - sorted_row_safe = arith.select( - row_valid, sorted_row_i, fx.Index(0) - ) - fused_i = buffer_ops.buffer_load( - sorted_rsrc, sorted_row_safe, vec_width=1, dtype=T.i32 - ) - t_i32 = fused_i & mask24 - s_i32 = fused_i >> fx.Int32(24) - t_valid = arith.cmpi(arith.CmpIPredicate.ult, t_i32, tokens_i32) - s_valid = arith.cmpi(arith.CmpIPredicate.ult, s_i32, topk_i32) - ts_valid = row_valid & (t_valid & s_valid) - x_row_valid.append(ts_valid) - t_safe = arith.select(ts_valid, t_i32, fx.Int32(0)) - t_idx = arith.index_cast(T.index, t_safe) - x_row_base_div4.append(t_idx * c_k_div4) - - vec4_i32 = T.vec(4, T.i32) - def load_x(idx_i32): - """Load `x_load_bytes` bytes from X (gmem) into regs. - - For 16B, keep the fast dwordx4 path. For 8B/4B, use byte offsets. 
- """ idx_elem = ( - idx_i32 if elem_bytes == 1 else (idx_i32 * arith.index(2)) + idx_i32 if a_elem_bytes == 1 else (idx_i32 * arith.index(2)) ) return buffer_copy_gmem16_dwordx4( buffer_ops, @@ -529,277 +747,414 @@ def load_x(idx_i32): vec_elems=vec16_elems, ) - _zero_row_idx = fx.Index(0) + # Decode sorted token ids -- stage1: X row = token_id (not t*topk+s) + x_row_base_div4 = [] + x_col_local_i32 = [] + x_row_local = [] + # Also store token_id and slot_id for output indexing + + for i in range_constexpr(num_x_loads): + row_local, col_local_i32 = x_tile_chunk_coord_i32(i) + x_row_local.append(row_local) + x_col_local_i32.append(col_local_i32) + + sorted_row_i = bx_m + row_local + fused_i = buffer_ops.buffer_load( + sorted_rsrc, sorted_row_i, vec_width=1, dtype=T.i32 + ) + t_i32 = arith.andi(fused_i, mask24) + s_i32 = arith.shrui(fused_i, arith.constant(24)) + t_valid = arith.cmpi(CmpIPredicate.ult, t_i32, tokens_i32) + s_valid = arith.cmpi(CmpIPredicate.ult, s_i32, topk_i32) + ts_valid = arith.andi(t_valid, s_valid) + t_safe = arith.select(ts_valid, t_i32, arith.constant(0)) + + # KEY: X row base uses token_id only (not t*topk+s) + t_idx = arith.index_cast(ir.IndexType.get(), t_safe) + x_row_base_div4.append(t_idx * c_k_div4) def load_x_tile(base_k): - """Prefetch the per-thread X tile portion (gmem -> regs) for a given K base (in elements).""" base_k_div4 = ( - (base_k // c_a_pack) - * fx.Index(int(elem_bytes)) - ) // arith.index(4) - zero_x_i32 = arith.constant_vector(0, vec4_i32) + (base_k / c_a_pack) + * arith.constant(int(a_elem_bytes), index=True) + ) / arith.index(4) parts = [] for i in range_constexpr(num_x_loads): - safe_base = arith.select( - x_row_valid[i], x_row_base_div4[i], _zero_row_idx - ) - idx_i32 = safe_base + base_k_div4 + x_col_local_i32[i] + idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] x_vec = load_x(idx_i32) - x_i32 = vector.bitcast(vec4_i32, x_vec) - x_i32 = arith.select(x_row_valid[i], x_i32, zero_x_i32) - parts.append(x_i32) + parts.append(vector.bitcast(T.vec(4, i32), x_vec)) return parts - # tx -> wave/lane (GEMM-style decomposition). - wave_id, lane_id = split_row_major_2d(tx, fx.Index(64)) - lane_div_16, lane_mod_16 = split_row_major_2d( - lane_id, fx.Index(16) - ) - - # Match GEMM naming/pattern: row in LDS is lane_mod_16, and col base is lane_div_16*16B (KPackBytes=16). 
+ # Wave/lane decomposition (identical to stage2) + coord_wl = idx2crd(tx, layout_tx_wave_lane) + wave_id = layout_get(coord_wl, 0) + lane_id = layout_get(coord_wl, 1) + coord_l16 = idx2crd(lane_id, layout_lane16) + lane_div_16 = layout_get(coord_l16, 0) + lane_mod_16 = layout_get(coord_l16, 1) row_a_lds = lane_mod_16 - col_offset_base = lane_div_16 * fx.Index(16) + col_offset_base = lane_div_16 * arith.constant(16, index=True) - # Dynamic N tiling within block (same as existing kernels) - num_waves = 4 - n_per_wave = tile_n // num_waves num_acc_n = n_per_wave // 16 - c_n_per_wave = fx.Index(n_per_wave) - wave_mod_4 = wave_id % arith.index(4) - n_tile_base = wave_mod_4 * c_n_per_wave + c_n_per_wave = arith.constant(n_per_wave, index=True) + wave_n_id = wave_id % arith.constant(num_waves, index=True) + n_tile_base = wave_n_id * c_n_per_wave - # fp4 pack - k_unroll_packed = k_unroll // pack_K - m_repeat_packed = m_repeat // pack_M - num_acc_n_packed = num_acc_n // pack_N - - # Precompute gate/up B coordinates and output columns for CK-style stage1: - # weights/scales are laid out as [gate rows][up rows], so each output column - # pairs one gate row with the matching up row at +inter_dim. - col_g_list = [] - inter_idx = fx.Index(inter_dim) - out_block_base = by_n - out_wave_base = n_tile_base + # N-tile precompute for gate AND up weights gate_n_intra_list = [] gate_n_blk_list = [] up_n_intra_list = [] up_n_blk_list = [] + col_g_list = [] + c_n0_static = experts * (2 * inter_dim) // 16 + layout_n_blk_intra = fx.make_layout((c_n0_static, 16), stride=(16, 1)) + inter_idx = arith.constant(inter_dim, index=True) + for i in range_constexpr(num_acc_n): offset = i * 16 - c_offset = fx.Index(offset) - out_col = ( - out_block_base + out_wave_base + c_offset + lane_mod_16 - ) - col_g_list.append(out_col) + c_offset = arith.constant(offset, index=True) + if const_expr(not gate_up_interleave): + col_g = by_n + n_tile_base + c_offset + lane_mod_16 + col_g_list.append(col_g) - gate_row_w = expert_off_idx + out_col - gate_n_blk, gate_n_intra = split_row_major_2d( - gate_row_w, fx.Index(16) - ) - gate_n_blk_list.append(gate_n_blk) - gate_n_intra_list.append(gate_n_intra) + global_n = by_n + n_tile_base + c_offset + lane_mod_16 + # Gate/interleave: rows [expert_off, expert_off + 2*inter_dim) + gate_row_w = expert_off_idx + global_n + gate_coord = idx2crd(gate_row_w, layout_n_blk_intra) + gate_n_blk_list.append(layout_get(gate_coord, 0)) + gate_n_intra_list.append(layout_get(gate_coord, 1)) + if const_expr(not mock_gate_only and not gate_up_interleave): + up_row_w = gate_row_w + inter_idx + up_coord = idx2crd(up_row_w, layout_n_blk_intra) + up_n_blk_list.append(layout_get(up_coord, 0)) + up_n_intra_list.append(layout_get(up_coord, 1)) + + if const_expr(gate_up_interleave): + _gui_num_acc_n_out = num_acc_n // pack_N + for _gui_i in range_constexpr(_gui_num_acc_n_out): + _gui_offset = _gui_i * 16 + _gui_c_offset = arith.constant(_gui_offset, index=True) + _gui_col_g = ( + (by_n + n_tile_base) // arith.constant(2, index=True) + + _gui_c_offset + + lane_mod_16 + ) + col_g_list.append(_gui_col_g) - up_row_w = gate_row_w + inter_idx - up_n_blk, up_n_intra = split_row_major_2d( - up_row_w, fx.Index(16) - ) - up_n_blk_list.append(up_n_blk) - up_n_intra_list.append(up_n_intra) + m_repeat = tile_m // 16 + k_unroll = tile_k_bytes // 128 + k_unroll_packed = k_unroll // pack_K + m_repeat_packed = m_repeat // pack_M + num_acc_n_packed = num_acc_n // pack_N - # --- B Load Logic (K64) - shared layout with preshuffle GEMM --- + 
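Each expert owns `2 * inter_dim` consecutive weight rows, gate rows first and up rows after them, so one output column maps to a gate row and an up row `inter_dim` apart; the preshuffled B layout is then addressed through the (row // 16, row % 16) split computed above. A small sketch of that mapping (a hypothetical helper, not part of the kernel):

```python
def gate_up_rows(expert_idx: int, out_col: int, inter_dim: int):
    """Row indices in the [gate rows][up rows] per-expert weight layout."""
    expert_off = expert_idx * 2 * inter_dim
    gate_row = expert_off + out_col
    up_row = gate_row + inter_dim
    # Preshuffled B is indexed by 16-row blocks: (n_blk, n_intra).
    return (gate_row // 16, gate_row % 16), (up_row // 16, up_row % 16)
```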
_K_per_ku = tile_k // k_unroll + _pad_k_elems = ( + (model_dim_pad % tile_k) + if (not _is_splitk and model_dim_pad > 0) + else 0 + ) + _pad_ku_skip = _pad_k_elems // _K_per_ku + _tail_ku = k_unroll - _pad_ku_skip + _tail_ku_packed = ( + (_tail_ku + pack_K - 1) // pack_K if _pad_ku_skip > 0 else None + ) + + # B load for gate and up separately def load_b_packs_k64(base_k, ku: int, n_blk, n_intra): - # K64 micro-step = 2x K32 MFMA steps. Reuse the shared helper. - b0 = load_b_pack_k32( - buffer_ops, - arith, - vector, - arg_b=arg_w, - b_rsrc=w_rsrc, - layout_b=layout_b, - base_k=base_k, - ki_step=ku * 2, - n_blk=n_blk, - n_intra=n_intra, - lane_div_16=lane_div_16, - elem_type=_w_elem_type(is_f4_b=is_f4_b, is_f16_b=is_f16), - kpack_bytes=kpack_bytes, - elem_bytes=b_elem_bytes, - unpack_int4=bool(is_int4), + c64 = arith.constant(64, index=True) + base_k_bytes = base_k * arith.constant( + int(b_elem_bytes), index=True ) - b1 = load_b_pack_k32( + k0 = base_k_bytes // c64 + arith.constant(ku, index=True) + k1 = lane_div_16 + coord_pack = (n_blk, k0, k1, n_intra, arith.constant(0, index=True)) + idx_pack = crd2idx(coord_pack, layout_b) + vec_elems = kpack_bytes // int(b_elem_bytes) + b16 = _buffer_load_vec( buffer_ops, - arith, vector, - arg_b=arg_w, - b_rsrc=w_rsrc, - layout_b=layout_b, - base_k=base_k, - ki_step=ku * 2 + 1, - n_blk=n_blk, - n_intra=n_intra, - lane_div_16=lane_div_16, - elem_type=_w_elem_type(is_f4_b=is_f4_b, is_f16_b=is_f16), - kpack_bytes=kpack_bytes, + w_rsrc, + idx_pack, + elem_type=_w_elem_type(), + vec_elems=vec_elems, elem_bytes=b_elem_bytes, - unpack_int4=bool(is_int4), + offset_in_bytes=(b_elem_bytes == 1), + cache_modifier=b_nt, + ) + b_i64x2 = vector.bitcast(vec2_i64, b16) + b0 = vector.extract( + b_i64x2, static_position=[0], dynamic_position=[] + ) + b1 = vector.extract( + b_i64x2, static_position=[1], dynamic_position=[] ) return b0, b1 - def load_b_tile(base_k): + def load_b_tile(base_k, ku_limit=k_unroll): + """Load B tiles. Returns (gate_b_tile, up_b_tile). 
+ When mock_gate_only or gate_up_interleave, up_b_tile is None.""" gate_b_tile = [] - up_b_tile = [] - for ku in range_constexpr(k_unroll): - gate_packs0 = [] - gate_packs1 = [] - up_packs0 = [] - up_packs1 = [] + up_b_tile = ( + [] if (not mock_gate_only and not gate_up_interleave) else None + ) + for ku in range_constexpr(ku_limit): + g_packs0, g_packs1 = [], [] + u_packs0, u_packs1 = [], [] for ni in range_constexpr(num_acc_n): - gate_b0, gate_b1 = load_b_packs_k64( - base_k, - ku, - gate_n_blk_list[ni], - gate_n_intra_list[ni], - ) - up_b0, up_b1 = load_b_packs_k64( - base_k, - ku, - up_n_blk_list[ni], - up_n_intra_list[ni], + gb0, gb1 = load_b_packs_k64( + base_k, ku, gate_n_blk_list[ni], gate_n_intra_list[ni] ) - gate_packs0.append(gate_b0) - gate_packs1.append(gate_b1) - up_packs0.append(up_b0) - up_packs1.append(up_b1) - gate_b_tile.append((gate_packs0, gate_packs1)) - up_b_tile.append((up_packs0, up_packs1)) + g_packs0.append(gb0) + g_packs1.append(gb1) + if const_expr( + not mock_gate_only and not gate_up_interleave + ): + ub0, ub1 = load_b_packs_k64( + base_k, ku, up_n_blk_list[ni], up_n_intra_list[ni] + ) + u_packs0.append(ub0) + u_packs1.append(ub1) + gate_b_tile.append((g_packs0, g_packs1)) + if const_expr(not mock_gate_only and not gate_up_interleave): + up_b_tile.append((u_packs0, u_packs1)) return gate_b_tile, up_b_tile - def load_scale(arg_scale, rsrc, scale_info, ku, mni): - k_lane = lane_div_16 - n_lane = lane_mod_16 - # Direct arith crd2idx: idx = mni*stride_n0 + ku*stride_k0 + k_lane*stride_klane + n_lane - idx_pack = ( - mni * scale_info.stride_n0 - + ku * scale_info.stride_k0 - + k_lane * scale_info.stride_klane - + n_lane + # Pre-compute scale base element indices (K-loop invariant). + # idx = mni * stride_n0 + ku * stride_k0 + k_lane * stride_klane + n_lane + # Split into: base_elem = mni * stride_n0 + lane_elem (invariant) + # k_elem = ku * stride_k0 (per-iteration) + _scale_lane_elem = ( + lane_div_16 * layout_b_scale.stride_klane + lane_mod_16 + ) + + _gate_scale_bases = [] + _up_scale_bases = [] + for _ni in range_constexpr(num_acc_n_packed): + _col_base = ( + by_n + + n_tile_base + + arith.constant(_ni * 16 * pack_N, index=True) ) - s = buffer_ops.buffer_load( - rsrc, idx_pack, vec_width=1, dtype=T.i32 + _gate_mni = (expert_off_idx + _col_base) // arith.constant( + 32, index=True ) - return vector.from_elements(T.vec(1, T.i32), [s]) + _gate_scale_bases.append( + _gate_mni * layout_b_scale.stride_n0 + _scale_lane_elem + ) + if const_expr(not mock_gate_only and not gate_up_interleave): + _up_mni = ( + expert_off_idx + inter_idx + _col_base + ) // arith.constant(32, index=True) + _up_scale_bases.append( + _up_mni * layout_b_scale.stride_n0 + _scale_lane_elem + ) - def load_scale_masked(arg_scale, rsrc, scale_info, ku, mni, valid): - safe_mni = arith.select( - valid, mni, fx.Index(0) + if const_expr(not a_scale_one): + _a_scale_bases = [] + for _mi in range_constexpr(m_repeat_packed): + _a_mni = _mi + bx_m // scale_mn_pack // 16 + _a_scale_bases.append( + _a_mni * layout_a_scale.stride_n0 + _scale_lane_elem + ) + + _c16_idx = arith.constant(16, index=True) + _c2_idx = arith.constant(2, index=True) + _scale_mask_lo = arith.constant(0xFF, type=T.i32) + + _m_half_idx = arith.constant(0, type=T.i32) + _m_half_i32 = arith.constant(0, type=T.i32) + _scale_shift = arith.constant(0, type=T.i32) + _scale_shift_hi = arith.constant(0, type=T.i32) + _n_half_idx = arith.constant(0, type=T.i32) + _n_half_i32 = arith.constant(0, type=T.i32) + _bscale_shift = arith.constant(0, 
type=T.i32) + _bscale_shift_hi = arith.constant(0, type=T.i32) + if const_expr(pack_M < scale_mn_pack): + _m_half_idx = (bx_m // _c16_idx) % _c2_idx + _m_half_i32 = arith.index_cast(T.i32, _m_half_idx) + _scale_shift = _m_half_i32 * arith.constant(8, type=T.i32) + _scale_shift_hi = _scale_shift + arith.constant(16, type=T.i32) + + if const_expr(pack_N < scale_mn_pack): + _n_half_idx = (n_tile_base // _c16_idx) % _c2_idx + _n_half_i32 = arith.index_cast(T.i32, _n_half_idx) + _bscale_shift = _n_half_i32 * arith.constant(8, type=T.i32) + _bscale_shift_hi = _bscale_shift + arith.constant(16, type=T.i32) + + def _rearrange_a_scale(raw_i32): + """Rearrange scale bytes for pack_M=1: extract m_half's k0,k1 bytes.""" + if const_expr(pack_M >= scale_mn_pack): + return raw_i32 + b_k0 = arith.andi( + arith.shrui(raw_i32, _scale_shift), _scale_mask_lo ) - scale_i32 = vector.extract( - load_scale(arg_scale, rsrc, scale_info, ku, safe_mni), - static_position=[0], - dynamic_position=[], + b_k1 = arith.andi( + arith.shrui(raw_i32, _scale_shift_hi), _scale_mask_lo + ) + return arith.ori( + b_k0, arith.shli(b_k1, arith.constant(8, type=T.i32)) ) - scale_i32 = arith.select(valid, scale_i32, fx.Int32(0)) - return vector.from_elements(T.vec(1, T.i32), [scale_i32]) - def load_b_scale_tile(base_k): - gate_b_scale_tile = [] - up_b_scale_tile = [] - for ku in range_constexpr(k_unroll_packed): - for ni in range_constexpr(num_acc_n_packed): - col_offset = ni * 16 * pack_N - col_offset_idx = fx.Index(col_offset) - col_base = ( - out_block_base + out_wave_base + col_offset_idx - ) - col_valid = arith.cmpi( - arith.CmpIPredicate.ult, col_base, inter_idx - ) - gate_mni = ( - expert_off_idx + col_base - ) // fx.Index(32) - up_mni = ( - expert_off_idx + inter_idx + col_base - ) // fx.Index(32) - gate_scale_i32 = load_scale_masked( - arg_scale_w, - sw_rsrc, - layout_b_scale, - ku + base_k, - gate_mni, - col_valid, - ) - up_scale_i32 = load_scale_masked( - arg_scale_w, - sw_rsrc, - layout_b_scale, - ku + base_k, - up_mni, - col_valid, - ) - gate_b_scale_tile.append(gate_scale_i32) - up_b_scale_tile.append(up_scale_i32) - return gate_b_scale_tile, up_b_scale_tile + def _rearrange_b_scale(raw_i32): + """Rearrange scale bytes for pack_N=1: extract n_half's k0,k1 bytes.""" + if const_expr(pack_N >= scale_mn_pack): + return raw_i32 + b_k0 = arith.andi( + arith.shrui(raw_i32, _bscale_shift), _scale_mask_lo + ) + b_k1 = arith.andi( + arith.shrui(raw_i32, _bscale_shift_hi), _scale_mask_lo + ) + return arith.ori( + b_k0, arith.shli(b_k1, arith.constant(8, type=T.i32)) + ) + + if const_expr(a_scale_one): + _as1_const = arith.constant(0x7F7F7F7F, type=T.i32) + _as1_vec = vector.from_elements(T.vec(1, T.i32), [_as1_const]) - def load_a_scale_tile(base_k): + def prefetch_ab_scale_tile(base_k, ku_packed_limit=k_unroll_packed): a_scale_tile = [] - for ku in range_constexpr(k_unroll_packed): + gate_b_scale = [] + up_b_scale = ( + [] if (not mock_gate_only and not gate_up_interleave) else None + ) + for ku in range_constexpr(ku_packed_limit): + k_off = (ku + base_k) * layout_b_scale.stride_k0 for mi in range_constexpr(m_repeat_packed): - scale = load_scale( - arg_scale_x, - sx_rsrc, - layout_a_scale, - ku + base_k, - mi + bx_m // pack_M // 16, + if const_expr(a_scale_one): + a_scale_tile.append(_as1_vec) + else: + s = buffer_ops.buffer_load( + sx_rsrc, + _a_scale_bases[mi] + k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, + ) + s = _rearrange_a_scale(s) + a_scale_tile.append( + vector.from_elements(T.vec(1, T.i32), [s]) + ) + for ni 
in range_constexpr(num_acc_n_packed): + gs = buffer_ops.buffer_load( + sw_rsrc, + _gate_scale_bases[ni] + k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, ) - a_scale_tile.append(scale) - return a_scale_tile - - def prefetch_ab_scale_tile(base_k): - gate_bs, up_bs = load_b_scale_tile(base_k) - return [load_a_scale_tile(base_k), gate_bs, up_bs] + gs = _rearrange_b_scale(gs) + gate_b_scale.append( + vector.from_elements(T.vec(1, T.i32), [gs]) + ) + if const_expr( + not mock_gate_only and not gate_up_interleave + ): + us = buffer_ops.buffer_load( + sw_rsrc, + _up_scale_bases[ni] + k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, + ) + us = _rearrange_b_scale(us) + up_b_scale.append( + vector.from_elements(T.vec(1, T.i32), [us]) + ) + return [a_scale_tile, gate_b_scale, up_b_scale] - acc_gate = [acc_init] * (num_acc_n * m_repeat) - acc_up = [acc_init] * (num_acc_n * m_repeat) + _lds_base_zero = arith.index(0) - # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- - def store_x_tile_to_lds(vec_x_in_parts, lds_base): + def store_x_tile_to_lds(vec_x_in_parts, lds_buffer): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if x_load_bytes == 16: + if const_expr(x_load_bytes == 16): lds_store_16b_xor16( arith, vector, - lds_memref=lds_x, + lds_memref=lds_buffer, vec16_ty=vec16_x, layout_lds=layout_lds, row_local=row_local, col_local_i32=col_local_i32, tx_c4=arith.index(4), k_blocks16=k_blocks16, - lds_base=lds_base, + lds_base=_lds_base_zero, vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) - # --- A LDS load helper for K64 (load 16B once, extract 2x i64 halves) --- - def lds_load_packs_k64(curr_row_a_lds, col_base, lds_base): - # Swizzle in bytes, then convert to element offset for memref indexing. 
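The `_rearrange_*_scale` helpers above pick the two e8m0 bytes (k0, k1) that belong to the current 16-row half out of a dword that packs scales for both halves whenever `pack_M`/`pack_N` is smaller than `scale_mn_pack`; the `a_scale_one` fast path instead substitutes `0x7F7F7F7F`, four e8m0 bytes with biased exponent 127, i.e. a scale of 1.0. A scalar sketch of the byte shuffle, assuming the dword stores the two halves' k0 bytes in the low half-word and their k1 bytes in the high half-word:

```python
def rearrange_scale(raw_i32: int, half: int) -> int:
    """Extract this half's (k0, k1) e8m0 bytes and repack them into bits 0..15."""
    shift = half * 8                   # half in {0, 1}
    b_k0 = (raw_i32 >> shift) & 0xFF
    b_k1 = (raw_i32 >> (shift + 16)) & 0xFF
    return b_k0 | (b_k1 << 8)
```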
+ if const_expr(use_async_copy): + _dma_bytes = 16 + _wave_size = 64 + _eff_bytes_per_buffer = ( + int(tile_m) * int(_eff_lds_stride) * int(a_elem_bytes) + ) + _num_dma_loads = max( + 1, _eff_bytes_per_buffer // (total_threads * _dma_bytes) + ) + + def dma_x_tile_to_lds(base_k, lds_buffer): + c4_idx = arith.index(4) + base_k_div4 = ( + (base_k / c_a_pack) + * arith.constant(int(elem_bytes), index=True) + ) / arith.index(4) + + lds_ptr_i64 = None + for i in range_constexpr(_num_dma_loads): + row_local_i = x_row_local[i] + col_local_i32_i = x_col_local_i32[i] + col_local_sw = swizzle_xor16( + row_local_i, col_local_i32_i * c4_idx, k_blocks16 + ) + row_k_dw = x_row_base_div4[i] + base_k_div4 + global_byte_idx = row_k_dw * c4_idx + col_local_sw + global_offset = arith.index_cast(T.i32, global_byte_idx) + + if const_expr(i == 0): + lds_addr = memref.extract_aligned_pointer_as_index( + lds_buffer + ) + wave_id * arith.constant( + _wave_size * _dma_bytes, index=True + ) + lds_ptr_i64 = rocdl.readfirstlane( + T.i64, arith.index_cast(T.i64, lds_addr) + ) + else: + lds_ptr_i64 = lds_ptr_i64 + arith.constant( + total_threads * _dma_bytes, type=T.i64 + ) + + lds_ptr_type = ir.Type.parse("!llvm.ptr<3>") + lds_ptr = llvm.inttoptr(lds_ptr_type, lds_ptr_i64) + + rocdl.raw_ptr_buffer_load_lds( + x_rsrc, + lds_ptr, + arith.constant(_dma_bytes, type=T.i32), + global_offset, + arith.constant(0, type=T.i32), + arith.constant(0, type=T.i32), + arith.constant(0, type=T.i32), + ) + + def prefetch_x_to_lds(base_k, lds_buffer): + dma_x_tile_to_lds(base_k, lds_buffer) + + def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): col_base_swz_bytes = swizzle_xor16( curr_row_a_lds, col_base, k_blocks16 ) col_base_swz = ( col_base_swz_bytes if elem_bytes == 1 - else (col_base_swz_bytes // arith.index(2)) - ) - idx_a16 = lds_row_major_idx( - curr_row_a_lds, - col_base_swz, - fx.Index(lds_stride), - lds_base, + else (col_base_swz_bytes / arith.index(2)) ) - loaded_a16 = vector.load_op(vec16_x, lds_x, [idx_a16]) + idx_a16 = crd2idx([curr_row_a_lds, col_base_swz], layout_lds) + loaded_a16 = vector.load_op(vec16_x, lds_buffer, [idx_a16]) a_i64x2 = vector.bitcast(vec2_i64, loaded_a16) a0 = vector.extract( a_i64x2, static_position=[0], dynamic_position=[] @@ -809,522 +1164,1315 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_base): ) return a0, a1 - def compute_f8f6f4_tile( + def prefetch_full_a_from_lds(lds_buffer, ku_limit=k_unroll): + """Load entire A tile from LDS into registers before compute.""" + a_regs = [] + for k_idx in range_constexpr(ku_limit): + col_base = col_offset_base + (k_idx * 128) // a_elem_vec_pack + for mi_idx in range_constexpr(m_repeat): + mi_val = arith.constant(mi_idx * 16, index=True) + curr_row = row_a_lds + mi_val + a0, a1 = lds_load_packs_k64(curr_row, col_base, lds_buffer) + if const_expr(is_f8_a): + a2, a3 = lds_load_packs_k64( + curr_row, col_base + 64, lds_buffer + ) + a_regs.append((a0, a1, a2, a3)) + else: + a_regs.append((a0, a1)) + return a_regs + + # Compute tile: gate + up MFMA interleaved, same A data, different B data. + # Two accumulator sets; after all K tiles, acc = acc_gate + acc_up (f32 add). 
+ def compute_tile( acc_gate_in, acc_up_in, gate_b_tile_in, up_b_tile_in, - lds_base, - *, - a0_prefetch=None, + a_tile_regs, a_scale=None, gate_b_scale=None, up_b_scale=None, - prefetch_epilogue: bool = False, + *, + prefetch_epilogue=False, + ku_count=k_unroll, ): gate_list = list(acc_gate_in) - up_list = list(acc_up_in) - - # Re-sync all threads before consuming the current LDS tile. - gpu.barrier() - rocdl.sched_barrier(0) - + _single_b = mock_gate_only or gate_up_interleave + up_list = None if _single_b else list(acc_up_in) + mfma_res_ty = vec4_f32 epilogue_pf = None - if enable_bias and prefetch_epilogue: - gate_bias = [] - up_bias = [] - for ni in range_constexpr(num_acc_n): - global_n = by_n + n_tile_base + ni * 16 + lane_mod_16 - gate_offset = expert_off_idx + global_n - up_offset = expert_off_idx + global_n + inter_dim - gate_bias.append( - buffer_ops.buffer_load( - bias_rsrc, gate_offset, vec_width=1, dtype=T.f32 - ) - ) - up_bias.append( - buffer_ops.buffer_load( - bias_rsrc, up_offset, vec_width=1, dtype=T.f32 + bias_pf = None + if const_expr(prefetch_epilogue): + if const_expr(enable_bias): + bias_pf = [] + for ni in range_constexpr(num_acc_n): + if const_expr(gate_up_interleave): + _logical_col = ( + (by_n + n_tile_base) + // arith.constant(2, index=True) + + arith.constant((ni // 2) * 16, index=True) + + lane_mod_16 + ) + _up_off = ( + inter_idx + if (ni % 2 == 1) + else arith.constant(0, index=True) + ) + bias_offset = ( + expert_off_idx + _up_off + _logical_col + ) + else: + global_n = ( + by_n + + n_tile_base + + arith.constant(ni * 16, index=True) + + lane_mod_16 + ) + bias_offset = expert_off_idx + global_n + bias_pf.append( + buffer_ops.buffer_load( + bias_rsrc, bias_offset, vec_width=1, dtype=f32 + ) ) - ) - epilogue_pf = (gate_bias, up_bias) - - if (int(tile_k) % 128) != 0: - raise ValueError( - f"tile_k must be divisible by 128 for mfma_scale_x128, got tile_k={tile_k}" - ) + tw_pf = None + if const_expr(doweight_stage1): + tw_pf = [] + lane_div_16_mul4_pf = lane_div_16 * arith.index(4) + ii_idx_list_pf = [ + arith.constant(ii, index=True) for ii in range(4) + ] + for mi in range_constexpr(m_repeat): + mi_base_pf = arith.constant(mi * 16, index=True) + for ii in range_constexpr(4): + row_off_pf = ( + lane_div_16_mul4_pf + ii_idx_list_pf[ii] + ) + sorted_row_pf = bx_m + mi_base_pf + row_off_pf + tw_pf.append( + buffer_ops.buffer_load( + sorted_w_rsrc, + sorted_row_pf, + vec_width=1, + dtype=f32, + ) + ) + epilogue_pf = (None, tw_pf, bias_pf) - mfma_res_ty = T.f32x4 + c0_i64 = arith.constant(0, type=T.i64) vec4_i64 = T.vec(4, T.i64) vec8_i32 = T.vec(8, T.i32) - c0_i64 = arith.constant(0, type=T.i64) def pack_i64x4_to_i32x8(x0, x1, x2, x3): v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) return vector.bitcast(vec8_i32, v4) - # Gate and Up MFMA interleaved in the same inner loop. - for ku128 in range_constexpr(k_unroll_packed): - for mi in range_constexpr(m_repeat_packed): - a_scale_i32 = a_scale[ku128 * m_repeat_packed + mi] - a_scale_val = vector.extract( - a_scale_i32, + _eff_packed = (ku_count + pack_K - 1) // pack_K + # B-major: fix B (ni), cycle A (mi) -- B from VMEM stays + # in registers while A from LDS is repacked per mi. 
+ for ku128 in range_constexpr(_eff_packed): + for ni in range_constexpr(num_acc_n_packed): + gate_bs_i32 = gate_b_scale[ku128 * num_acc_n_packed + ni] + gate_bs_val = vector.extract( + gate_bs_i32, static_position=[0], dynamic_position=[], ) - for ni in range_constexpr(num_acc_n_packed): - gate_bs_i32 = gate_b_scale[ - ku128 * num_acc_n_packed + ni - ] - gate_bs_val = vector.extract( - gate_bs_i32, - static_position=[0], - dynamic_position=[], - ) - up_bs_i32 = up_b_scale[ - ku128 * num_acc_n_packed + ni - ] + if const_expr(not _single_b): + up_bs_i32 = up_b_scale[ku128 * num_acc_n_packed + ni] up_bs_val = vector.extract( - up_bs_i32, - static_position=[0], - dynamic_position=[], + up_bs_i32, static_position=[0], dynamic_position=[] ) - for ikxdl in range_constexpr(pack_K): - k_idx = ku128 * pack_K + ikxdl + for ikxdl in range_constexpr(pack_K): + k_idx = ku128 * pack_K + ikxdl + if k_idx < ku_count: gate_bp0, gate_bp1 = gate_b_tile_in[k_idx] - up_bp0, up_bp1 = up_b_tile_in[k_idx] - col_base = ( - col_offset_base - + (k_idx * 128) // a_elem_vec_pack - ) - for imxdl in range_constexpr(pack_M): - col_base0 = col_base - mi_idx = mi * pack_M + imxdl - mi_val = arith.constant( - mi_idx * 16, index=True - ) - curr_row_a_lds = row_a_lds + mi_val - - a0, a1 = lds_load_packs_k64( - curr_row_a_lds, col_base0, lds_base + if const_expr(not _single_b): + up_bp0, up_bp1 = up_b_tile_in[k_idx] + for inxdl in range_constexpr(pack_N): + ni_idx = ni * pack_N + inxdl + gb0 = gate_bp0[ni_idx] + gb1 = gate_bp1[ni_idx] + gb128 = pack_i64x4_to_i32x8( + gb0, gb1, c0_i64, c0_i64 ) - - if is_f8_a: - col_base1 = col_base + 64 - a2, a3 = lds_load_packs_k64( - curr_row_a_lds, col_base1, lds_base - ) - a128 = pack_i64x4_to_i32x8( - a0, a1, a2, a3 - ) - else: - a128 = pack_i64x4_to_i32x8( - a0, a1, c0_i64, c0_i64 - ) - - for inxdl in range_constexpr(pack_N): - ni_idx = ni * pack_N + inxdl - acc_idx = mi_idx * num_acc_n + ni_idx - - gb0 = gate_bp0[ni_idx] - gb1 = gate_bp1[ni_idx] - gb128 = pack_i64x4_to_i32x8( - gb0, gb1, c0_i64, c0_i64 - ) - + if const_expr(not _single_b): ub0 = up_bp0[ni_idx] ub1 = up_bp1[ni_idx] ub128 = pack_i64x4_to_i32x8( ub0, ub1, c0_i64, c0_i64 ) - - rocdl.sched_barrier(0) - gate_list[acc_idx] = ( - rocdl.mfma_scale_f32_16x16x128_f8f6f4( - mfma_res_ty, - [ - a128, - gb128, - gate_list[acc_idx], - cbsz, - blgp, - ikxdl * pack_M + imxdl, - a_scale_val, - ikxdl * pack_N + inxdl, - gate_bs_val, - ], - ) + for mi in range_constexpr(m_repeat_packed): + a_scale_i32 = a_scale[ + ku128 * m_repeat_packed + mi + ] + a_scale_val = vector.extract( + a_scale_i32, + static_position=[0], + dynamic_position=[], ) - rocdl.sched_barrier(0) - up_list[acc_idx] = ( - rocdl.mfma_scale_f32_16x16x128_f8f6f4( - mfma_res_ty, - [ - a128, - ub128, - up_list[acc_idx], - cbsz, - blgp, - ikxdl * pack_M + imxdl, - a_scale_val, - ikxdl * pack_N + inxdl, - up_bs_val, - ], + for imxdl in range_constexpr(pack_M): + mi_idx = mi * pack_M + imxdl + _a_reg_idx = k_idx * m_repeat + mi_idx + if const_expr(is_f8_a): + a0, a1, a2, a3 = a_tile_regs[ + _a_reg_idx + ] + a128 = pack_i64x4_to_i32x8( + a0, a1, a2, a3 + ) + else: + a0, a1 = a_tile_regs[_a_reg_idx] + a128 = pack_i64x4_to_i32x8( + a0, a1, c0_i64, c0_i64 + ) + acc_idx = mi_idx * num_acc_n + ni_idx + gate_list[acc_idx] = ( + rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + gb128, + gate_list[acc_idx], + cbsz, + blgp, + ikxdl * pack_M + imxdl, + a_scale_val, + ikxdl * pack_N + inxdl, + gate_bs_val, + ], + ) ) - ) + if const_expr(not _single_b): + up_list[acc_idx] = ( 
+ rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + ub128, + up_list[acc_idx], + cbsz, + blgp, + ikxdl * pack_M + imxdl, + a_scale_val, + ikxdl * pack_N + inxdl, + up_bs_val, + ], + ) + ) return gate_list, up_list, epilogue_pf - # ---------------- 2-stage pipeline (ping-pong LDS + B tile prefetch) ---------------- - lds_tile_elems = fx.Index(tile_m * lds_stride) - lds_base_cur = arith.index(0) - lds_base_nxt = lds_tile_elems + def load_a_subtile(k_idx, mi_idx, lds_buffer): + """Load a single A sub-tile from LDS (one ds_read).""" + col_base = col_offset_base + (k_idx * 128) // a_elem_vec_pack + mi_val = arith.constant(mi_idx * 16, index=True) + curr_row = row_a_lds + mi_val + a0, a1 = lds_load_packs_k64(curr_row, col_base, lds_buffer) + if const_expr(is_f8_a): + a2, a3 = lds_load_packs_k64(curr_row, col_base + 64, lds_buffer) + return (a0, a1, a2, a3) + else: + return (a0, a1) + + _single_b_pipe = mock_gate_only or gate_up_interleave + + def compute_bmajor_mfma_phase( + all_a_tiles, + gate_b_single, + up_b_single, + a_scale_vals, + gate_bs_val, + up_bs_val, + gate_list, + up_list, + k_idx, + ni_idx, + ikxdl, + inxdl, + ): + """B-major MFMA: fix one B (ni), cycle all A tiles (mi). - # Optional scheduler hints (copied from tuned GEMM); can be disabled via env. - rocdl.sched_barrier(0) + Packs B once and reuses across all mi iterations. + A tiles come from LDS (already available, no VMEM wait). - def hot_loop_scheduler(): - mfma_group = num_acc_n * 2 - # K64 micro-step: 2x K32 MFMA per gemm. - mfma_total = (k_unroll * 2) * m_repeat * mfma_group - mfma_per_iter = 2 * mfma_group - sche_iters = ( - 0 if mfma_per_iter == 0 else (mfma_total // mfma_per_iter) - ) - - # DS-read preload (CK default is 2); clamp to non-negative. - rocdl.sched_dsrd(2) - rocdl.sched_mfma(2) - rocdl.sched_dsrd(1) - rocdl.sched_mfma(1) - rocdl.sched_dsrd(1) - rocdl.sched_mfma(1) - - # DS-write hints near the end: match total X LDS-store micro-ops per thread. - dswr_tail = num_x_loads - if dswr_tail > sche_iters: - dswr_tail = sche_iters - dswr_start = sche_iters - dswr_tail - for sche_i in range_constexpr(sche_iters): - rocdl.sched_vmem(1) - rocdl.sched_mfma(mfma_group) - rocdl.sched_dsrd(1) - rocdl.sched_mfma(mfma_group) - if sche_i >= dswr_start - 1: - rocdl.sched_dswr(1) + all_a_tiles: flat list indexed by [k*m_repeat + mi]. + gate_b_single/up_b_single: (b0, b1) for one specific ni. + When _single_b_pipe (mock_gate_only or interleave), up_b_single is None. + a_scale_vals: list of A scale scalars indexed by mi_packed. 
+ """ + c0_i64 = arith.constant(0, type=T.i64) + vec4_i64 = T.vec(4, T.i64) + vec8_i32 = T.vec(8, T.i32) + + def _pack(x0, x1, x2, x3): + v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) + return vector.bitcast(vec8_i32, v4) + + mfma_res_ty = vec4_f32 + gb128 = _pack(gate_b_single[0], gate_b_single[1], c0_i64, c0_i64) + if const_expr(not _single_b_pipe): + ub128 = _pack(up_b_single[0], up_b_single[1], c0_i64, c0_i64) + + for mi_p in range_constexpr(m_repeat_packed): + a_scale_val = a_scale_vals[mi_p] + for imxdl in range_constexpr(pack_M): + mi_idx = mi_p * pack_M + imxdl + a_reg = all_a_tiles[k_idx * m_repeat + mi_idx] + + if const_expr(is_f8_a): + a128 = _pack(a_reg[0], a_reg[1], a_reg[2], a_reg[3]) + else: + a128 = _pack(a_reg[0], a_reg[1], c0_i64, c0_i64) + + acc_idx = mi_idx * num_acc_n + ni_idx + gate_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + gb128, + gate_list[acc_idx], + cbsz, + blgp, + ikxdl * pack_M + imxdl, + a_scale_val, + ikxdl * pack_N + inxdl, + gate_bs_val, + ], + ) + if const_expr(not _single_b_pipe): + up_list[acc_idx] = ( + rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + ub128, + up_list[acc_idx], + cbsz, + blgp, + ikxdl * pack_M + imxdl, + a_scale_val, + ikxdl * pack_N + inxdl, + up_bs_val, + ], + ) + ) + + def _interleaved_half( + lds_read, + lds_write, + next_k_dma_py, + next_k_load, + prev_a_tile, + prev_gate_w, + prev_up_w, + prev_a_scale, + prev_gate_bs, + prev_up_bs, + acc_gate, + acc_up, + ): + """One flatmm-style interleaved half-iteration (deep pipeline). + + Generalized for arbitrary m_repeat (block_m=32, 64, ...). + DMA targets lds_write (OTHER buffer) while ds_read uses + lds_read (already DMA'd in previous half). + + Interleaving schedule (per half): + Phase 0: scale VMEM + 2 ds_read(A) -> 4 MFMA(prev) + Phase 1..N: B VMEM(distributed) + 2 ds_read(A, if avail) -> 4 MFMA(prev) + Phase N+1..: remaining B VMEM -> 4 MFMA(prev) + """ + _abs_k = k_base_idx + arith.constant(next_k_load, index=True) + _bk = _abs_k // arith.constant(2, index=True) + _sk = _abs_k // arith.constant(pack_K * 128, index=True) + _k_off = _sk * layout_b_scale.stride_k0 + + rocdl.sched_barrier(0) + rocdl.s_waitcnt(_vmcnt_before_barrier) + _barrier() rocdl.sched_barrier(0) - # Prologue: prefetch tile0, store to LDS(cur), sync. 
- k0 = arith.index(0) - x_regs0 = load_x_tile(k0) - gate_w0, up_w0 = load_b_tile(k0) + # DMA A to OTHER buffer (for next half), non-blocking + _abs_k_dma = k_base_idx + arith.constant(next_k_dma_py, index=True) + if const_expr(use_async_copy and next_k_dma_py < int(_k_dim)): + prefetch_x_to_lds(_abs_k_dma, lds_write) + if const_expr(not use_async_copy): + _x_regs = load_x_tile(_abs_k_dma) + + # ---- Extract previous scale values ---- + _prev_asvs = [] + for _mi_p in range_constexpr(m_repeat_packed): + _prev_asvs.append( + vector.extract( + prev_a_scale[_mi_p], + static_position=[0], + dynamic_position=[], + ) + ) + _prev_gsv_list = [] + for _gs_ni in range_constexpr(num_acc_n_packed): + _prev_gsv_list.append( + vector.extract( + prev_gate_bs[_gs_ni], + static_position=[0], + dynamic_position=[], + ) + ) + if const_expr(not _single_b_pipe): + _prev_usv_list = [] + for _us_ni in range_constexpr(num_acc_n_packed): + _prev_usv_list.append( + vector.extract( + prev_up_bs[_us_ni], + static_position=[0], + dynamic_position=[], + ) + ) + # ---- Execute phases from unified schedule ---- + _a_all = {} + _b_gate_all = {} + _b_up_all = {} + + for _p in range_constexpr(_pipe_n_phases): + # Scale VMEM loads (phase 0 only) + if const_expr(_pp_has_scale[_p]): + _new_as_list = [] + for _mi_p in range_constexpr(m_repeat_packed): + if const_expr(a_scale_one): + _new_as_list.append(_as1_const) + else: + _raw_as = buffer_ops.buffer_load( + sx_rsrc, + _a_scale_bases[_mi_p] + _k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, + ) + _new_as_list.append(_rearrange_a_scale(_raw_as)) + _new_gs_list = [] + for _gs_ni in range_constexpr(num_acc_n_packed): + _gs_raw = buffer_ops.buffer_load( + sw_rsrc, + _gate_scale_bases[_gs_ni] + _k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, + ) + _new_gs_list.append(_rearrange_b_scale(_gs_raw)) + if const_expr(not _single_b_pipe): + _new_us_list = [] + for _us_ni in range_constexpr(num_acc_n_packed): + _us_raw = buffer_ops.buffer_load( + sw_rsrc, + _up_scale_bases[_us_ni] + _k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, + ) + _new_us_list.append(_rearrange_b_scale(_us_raw)) + + # B VMEM loads + for _b_j in range_constexpr(len(_pp_b_loads[_p])): + _b_type, _b_ku, _b_ni = _pp_b_loads[_p][_b_j] + if const_expr(_b_type == "gate"): + _b_gate_all[(_b_ku, _b_ni)] = load_b_packs_k64( + _bk, + _b_ku, + gate_n_blk_list[_b_ni], + gate_n_intra_list[_b_ni], + ) + else: + _b_up_all[(_b_ku, _b_ni)] = load_b_packs_k64( + _bk, + _b_ku, + up_n_blk_list[_b_ni], + up_n_intra_list[_b_ni], + ) + + # A ds_reads + rocdl.sched_barrier(0) + for _a_j in range_constexpr(len(_pp_a_reads[_p])): + _ak, _ami = _pp_a_reads[_p][_a_j] + _a_all[(_ak, _ami)] = load_a_subtile( + _ak, + _ami, + lds_read, + ) + rocdl.sched_barrier(0) + + # MFMAs on prev data + rocdl.s_setprio(1) + for _m_j in range_constexpr(len(_pp_mfma[_p])): + _k_idx, _ni_idx, _ikxdl, _inxdl, _ku128 = _pp_mfma[_p][_m_j] + _ni_packed_idx = _ni_idx // pack_N + _up_b_single = ( + ( + prev_up_w[_k_idx][0][_ni_idx], + prev_up_w[_k_idx][1][_ni_idx], + ) + if not _single_b_pipe + else None + ) + compute_bmajor_mfma_phase( + prev_a_tile, + ( + prev_gate_w[_k_idx][0][_ni_idx], + prev_gate_w[_k_idx][1][_ni_idx], + ), + _up_b_single, + _prev_asvs, + _prev_gsv_list[_ni_packed_idx], + ( + _prev_usv_list[_ni_packed_idx] + if not _single_b_pipe + else None + ), + acc_gate, + acc_up, + _k_idx, + _ni_idx, + _ikxdl, + _inxdl, + ) + rocdl.s_setprio(0) + rocdl.sched_barrier(0) + + # ---- Assemble loaded data for next half-iteration ---- + 
cur_a_tile = [] + for _k in range_constexpr(k_unroll): + for _mi in range_constexpr(m_repeat): + cur_a_tile.append(_a_all[(_k, _mi)]) + + cur_gate_w = [] + cur_up_w = None if _single_b_pipe else [] + for ku in range_constexpr(k_unroll): + g_packs0, g_packs1 = [], [] + u_packs0, u_packs1 = [], [] + for ni in range_constexpr(num_acc_n): + g = _b_gate_all[(ku, ni)] + g_packs0.append(g[0]) + g_packs1.append(g[1]) + if const_expr(not _single_b_pipe): + u = _b_up_all[(ku, ni)] + u_packs0.append(u[0]) + u_packs1.append(u[1]) + cur_gate_w.append((g_packs0, g_packs1)) + if const_expr(not _single_b_pipe): + cur_up_w.append((u_packs0, u_packs1)) + + cur_a_scale = [] + for _mi_p in range_constexpr(m_repeat_packed): + cur_a_scale.append( + vector.from_elements( + T.vec(1, T.i32), + [_new_as_list[_mi_p]], + ) + ) + cur_gate_bs = [] + for _gs_ni in range_constexpr(num_acc_n_packed): + cur_gate_bs.append( + vector.from_elements( + T.vec(1, T.i32), [_new_gs_list[_gs_ni]] + ) + ) + if const_expr(not _single_b_pipe): + cur_up_bs = [] + for _us_ni in range_constexpr(num_acc_n_packed): + cur_up_bs.append( + vector.from_elements( + T.vec(1, T.i32), [_new_us_list[_us_ni]] + ) + ) + else: + cur_up_bs = None + + if const_expr(not use_async_copy): + store_x_tile_to_lds(_x_regs, lds_write) + + return ( + cur_a_tile, + cur_gate_w, + cur_up_w, + cur_a_scale, + cur_gate_bs, + cur_up_bs, + acc_gate, + acc_up, + ) + + # Pipeline (split ping/pong allocators) + rocdl.sched_barrier(0) + + k0 = k_base_idx + if const_expr(use_async_copy): + prefetch_x_to_lds(k0, lds_x_pong) + else: + x_regs0 = load_x_tile(k0) + store_x_tile_to_lds(x_regs0, lds_x_pong) + rocdl.sched_barrier(0) + _k0_scale = k_base_idx // arith.constant(pack_K * 128, index=True) a_scale_pong, gate_bs_pong, up_bs_pong = prefetch_ab_scale_tile( - k0 // 2 + _k0_scale ) - store_x_tile_to_lds(x_regs0, lds_base_cur) - gpu.barrier() + _c_tile_m_idx = arith.constant(tile_m, index=True) + _tid_in_range = arith.cmpi(CmpIPredicate.ult, tx, _c_tile_m_idx) + _if_tid = scf.IfOp(_tid_in_range) + with ir.InsertionPoint(_if_tid.then_block): + _tid_row = bx_m + tx + _tid_val = buffer_ops.buffer_load( + sorted_rsrc, _tid_row, vec_width=1, dtype=T.i32 + ) + _tid_vec1 = vector.from_elements(T.vec(1, T.i32), [_tid_val]) + vector.store(_tid_vec1, lds_tid, [tx]) + scf.YieldOp([]) - # Loop-carried ping/pong state. 
- lds_base_pong = lds_base_cur - lds_base_ping = lds_base_nxt - gate_w_pong = gate_w0 - up_w_pong = up_w0 + acc_gate = [acc_init] * num_acc_n * m_repeat + acc_up = ( + [acc_init] * num_acc_n * m_repeat if not _single_b_pipe else None + ) - a0_prefetch_pong = None + _k1 = k_base_idx + arith.constant(tile_k, index=True) + rocdl.sched_barrier(0) + if const_expr(use_async_copy): + prefetch_x_to_lds(_k1, lds_x_ping) + else: + _x_regs_prime = load_x_tile(_k1) + store_x_tile_to_lds(_x_regs_prime, lds_x_ping) + + _k0_b = k_base_idx // arith.constant(2, index=True) + gate_w0, up_w0 = load_b_tile(_k0_b) + # Prime the deep pipeline: DMA K=tile_k -> ping (1 tile ahead) + if const_expr(use_async_copy): + rocdl.s_waitcnt(0) + gpu.barrier() + rocdl.sched_barrier(0) + a_tile_pong = prefetch_full_a_from_lds(lds_x_pong) - if os.environ.get("FLYDSL_STAGE1_EARLY_RETURN", "0") == "1": - return + rocdl.sched_barrier(0) + rocdl.s_waitcnt(6) - num_k_tiles_py = int(model_dim) // int(tile_k) + num_k_tiles_py = int(_k_dim) // int(tile_k) odd_k_tiles = (num_k_tiles_py % 2) == 1 tail_tiles = 1 if odd_k_tiles else 2 k_main2_py = (num_k_tiles_py - tail_tiles) * int(tile_k) - if k_main2_py < 0: + if const_expr(k_main2_py < 0): k_main2_py = 0 - _skip_compute = ( - os.environ.get("FLYDSL_STAGE1_SKIP_COMPUTE", "0") - == "1" - ) - if k_main2_py > 0: - for k_iv_py in range_constexpr(0, k_main2_py, tile_k * 2): - k_iv = k_iv_py - next_k1 = k_iv + tile_k - x_regs_ping = load_x_tile(next_k1) - gate_w_ping, up_w_ping = load_b_tile(next_k1 // 2) - a_scale_ping, gate_bs_ping, up_bs_ping = ( - prefetch_ab_scale_tile(next_k1 // pack_K // 128) - ) + gate_w_pong = gate_w0 + up_w_pong = up_w0 - if _skip_compute: - store_x_tile_to_lds(x_regs_ping, lds_base_ping) - gpu.barrier() - a0_prefetch_ping = None - next_k2 = k_iv + (tile_k * 2) - x_regs_pong = load_x_tile(next_k2) - gate_w_pong, up_w_pong = load_b_tile(next_k2 // 2) - a_scale_pong, gate_bs_pong, up_bs_pong = ( - prefetch_ab_scale_tile(next_k2 // pack_K // 128) - ) - store_x_tile_to_lds(x_regs_pong, lds_base_pong) - gpu.barrier() - a0_prefetch_pong = None - continue + rocdl.sched_barrier(0) - acc_gate, acc_up, _ = compute_f8f6f4_tile( + if const_expr(k_main2_py > 0): + for k_iv_py in range_constexpr(0, k_main2_py, tile_k * 2): + next_k_load_1 = k_iv_py + tile_k + next_k_load_2 = k_iv_py + tile_k * 2 + next_k_dma_1 = k_iv_py + tile_k * 2 + next_k_dma_2 = k_iv_py + tile_k * 3 + + # Half 1: read ping (DMA'd prev half), DMA->pong, MFMA(pong) + ( + a_tile_ping, + gate_w_ping, + up_w_ping, + a_scale_ping, + gate_bs_ping, + up_bs_ping, acc_gate, acc_up, + ) = _interleaved_half( + lds_x_ping, + lds_x_pong, + next_k_dma_1, + next_k_load_1, + a_tile_pong, gate_w_pong, up_w_pong, - lds_base_pong, - a0_prefetch=a0_prefetch_pong, - a_scale=a_scale_pong, - gate_b_scale=gate_bs_pong, - up_b_scale=up_bs_pong, - ) - a0_prefetch_pong = None - store_x_tile_to_lds(x_regs_ping, lds_base_ping) - gpu.barrier() - - a0_prefetch_ping = None - - next_k2 = k_iv + (tile_k * 2) - x_regs_pong = load_x_tile(next_k2) - gate_w_pong, up_w_pong = load_b_tile(next_k2 // 2) - a_scale_pong, gate_bs_pong, up_bs_pong = ( - prefetch_ab_scale_tile(next_k2 // pack_K // 128) + a_scale_pong, + gate_bs_pong, + up_bs_pong, + acc_gate, + acc_up, ) - acc_gate, acc_up, _ = compute_f8f6f4_tile( + # Half 2: read pong (DMA'd Half 1), DMA->ping, MFMA(ping) + ( + a_tile_pong, + gate_w_pong, + up_w_pong, + a_scale_pong, + gate_bs_pong, + up_bs_pong, acc_gate, acc_up, + ) = _interleaved_half( + lds_x_pong, + lds_x_ping, + next_k_dma_2, + 
next_k_load_2, + a_tile_ping, gate_w_ping, up_w_ping, - lds_base_ping, - a0_prefetch=a0_prefetch_ping, - a_scale=a_scale_ping, - gate_b_scale=gate_bs_ping, - up_b_scale=up_bs_ping, + a_scale_ping, + gate_bs_ping, + up_bs_ping, + acc_gate, + acc_up, ) - a0_prefetch_ping = None - store_x_tile_to_lds(x_regs_pong, lds_base_pong) - gpu.barrier() - - a0_prefetch_pong = None - if odd_k_tiles: - acc_gate, acc_up, epilogue_pf = compute_f8f6f4_tile( + # _wave_mod2_b = wave_id % arith.constant(2, index=True) + # _wave_odd = arith.cmpi( + # CmpIPredicate.eq, _wave_mod2_b, arith.constant(1, index=True) + # ) + # _if_wave_odd = scf.IfOp(_wave_odd) + # with ir.InsertionPoint(_if_wave_odd.then_block): + # # gpu.barrier() + # _barrier() + # scf.YieldOp([]) + + if const_expr(odd_k_tiles): + acc_gate, acc_up, epilogue_pf = compute_tile( acc_gate, acc_up, gate_w_pong, up_w_pong, - lds_base_pong, - a0_prefetch=a0_prefetch_pong, - a_scale=a_scale_pong, - gate_b_scale=gate_bs_pong, - up_b_scale=up_bs_pong, + a_tile_pong, + a_scale_pong, + gate_bs_pong, + up_bs_pong, prefetch_epilogue=True, + ku_count=_tail_ku if _pad_ku_skip > 0 else k_unroll, ) else: - k_tail1 = k_in - tile_k - x_regs_ping = load_x_tile(k_tail1) - gate_w_ping, up_w_ping = load_b_tile(k_tail1 // 2) - a_scale_ping, gate_bs_ping, up_bs_ping = prefetch_ab_scale_tile( - k_tail1 // pack_K // 128 - ) - - acc_gate, acc_up, _ = compute_f8f6f4_tile( + _k_tail_rel = arith.constant(_k_dim - tile_k, index=True) + k_tail1 = k_base_idx + _k_tail_rel + x_regs_ping = [] + if const_expr(use_async_copy): + prefetch_x_to_lds(k_tail1, lds_x_ping) + else: + x_regs_ping = load_x_tile(k_tail1) + if _pad_ku_skip > 0: + gate_w_ping, up_w_ping = load_b_tile( + k_tail1 // arith.constant(2, index=True), + ku_limit=_tail_ku, + ) + a_scale_ping, gate_bs_ping, up_bs_ping = prefetch_ab_scale_tile( + k_tail1 // arith.constant(pack_K * 128, index=True), + ku_packed_limit=_tail_ku_packed, + ) + else: + gate_w_ping, up_w_ping = load_b_tile( + k_tail1 // arith.constant(2, index=True) + ) + a_scale_ping, gate_bs_ping, up_bs_ping = prefetch_ab_scale_tile( + k_tail1 // arith.constant(pack_K * 128, index=True) + ) + acc_gate, acc_up, _ = compute_tile( acc_gate, acc_up, gate_w_pong, up_w_pong, - lds_base_pong, - a0_prefetch=a0_prefetch_pong, - a_scale=a_scale_pong, - gate_b_scale=gate_bs_pong, - up_b_scale=up_bs_pong, + a_tile_pong, + a_scale_pong, + gate_bs_pong, + up_bs_pong, ) - a0_prefetch_pong = None - store_x_tile_to_lds(x_regs_ping, lds_base_ping) - gpu.barrier() - - a0_prefetch_ping = None - - acc_gate, acc_up, epilogue_pf = compute_f8f6f4_tile( + if const_expr(not use_async_copy): + store_x_tile_to_lds(x_regs_ping, lds_x_ping) + rocdl.s_waitcnt(0) + _barrier() + if _pad_ku_skip > 0: + a_tile_ping = prefetch_full_a_from_lds( + lds_x_ping, ku_limit=_tail_ku + ) + else: + a_tile_ping = prefetch_full_a_from_lds(lds_x_ping) + acc_gate, acc_up, epilogue_pf = compute_tile( acc_gate, acc_up, gate_w_ping, up_w_ping, - lds_base_ping, - a0_prefetch=a0_prefetch_ping, - a_scale=a_scale_ping, - gate_b_scale=gate_bs_ping, - up_b_scale=up_bs_ping, + a_tile_ping, + a_scale_ping, + gate_bs_ping, + up_bs_ping, prefetch_epilogue=True, + ku_count=_tail_ku if _pad_ku_skip > 0 else k_unroll, ) - # Store epilogue to out[t, slot, inter] - topk_i32_v = topk_i32 - inter_i32_v = fx.Int32(inter_dim) - mask24_i32 = fx.Int32(0xFFFFFF) - - # Epilogue hoists to keep IR + Python build time small: - col_i32_list = [] - for ni in range_constexpr(num_acc_n): - col_i32_list.append(arith.index_cast(T.i32, 
col_g_list[ni])) - - _lane_div_16_mul4 = lane_div_16 * arith.index(4) - inter_i32_local = inter_i32_v - - # Optional: CK-style CShuffle epilogue for better global store coalescing. - # Uses EVec=4 (buffer store "x4" of fp16 elements). - _use_cshuffle_epilog = (out_dtype == "fp8") or bool( - use_cshuffle_epilog - ) - - _mask_even_i32 = fx.Int32(0xFFFFFFFE) - - if _use_cshuffle_epilog: - if lds_out is None: - raise RuntimeError( - "CShuffle epilogue enabled but lds_out is not allocated/aliased." + bias_pf = None + if const_expr(epilogue_pf is not None): + _, _, bias_pf = epilogue_pf + + # Activation helpers (f32 element-wise on vec4_f32) + def _silu_elem(g): + """silu(x) = x * sigmoid(x); HW fast path: exp2, rcp""" + neg_log2e = arith.constant(-1.4426950408889634, type=f32) + t = g * neg_log2e + emu = llvm.call_intrinsic(f32, "llvm.amdgcn.exp2.f32", [t], [], []) + one = arith.constant(1.0, type=f32) + den = one + emu + sig = llvm.call_intrinsic(f32, "llvm.amdgcn.rcp.f32", [den], [], []) + return g * sig + + def _silu_mul_vec4(gate_v4, up_v4): + """Element-wise silu(gate) * up on vec4_f32.""" + result_elems = [] + for ei in range_constexpr(4): + g = vector.extract( + gate_v4, static_position=[ei], dynamic_position=[] ) + u = vector.extract( + up_v4, static_position=[ei], dynamic_position=[] + ) + result_elems.append(_silu_elem(g) * u) + return vector.from_elements(vec4_f32, result_elems) - def write_row_to_lds( - *, - mi: int, - ii: int, - row_in_tile, - row, - row_base_lds, - col_base_local, - num_acc_n: int, - lds_out, - ): - # `row` is the sorted-row index (bx_m + row_in_tile). - row_i32 = arith.index_cast(T.i32, row) - row_valid0 = arith.cmpi( - arith.CmpIPredicate.ult, row_i32, max_token_id_i32 + def _swiglu_mul_vec4(gate_v4, up_v4): + """Element-wise swiglu(gate, up) on vec4_f32. + swiglu(g, u) = g * sigmoid(alpha * g) * (u + 1) + with clamping: gate <= limit, -limit <= up <= limit. + """ + result_elems = [] + _alpha = arith.constant(1.702, type=f32) + _limit = arith.constant(7.0, type=f32) + _neg_limit = arith.constant(-7.0, type=f32) + _one = arith.constant(1.0, type=f32) + _neg_log2e = arith.constant(-1.4426950408889634, type=f32) + for ei in range_constexpr(4): + g = vector.extract( + gate_v4, static_position=[ei], dynamic_position=[] ) - row_safe = arith.select( - row_valid0, row, fx.Index(0) + u = vector.extract( + up_v4, static_position=[ei], dynamic_position=[] ) - fused2 = buffer_ops.buffer_load( - sorted_rsrc, row_safe, vec_width=1, dtype=T.i32 + g = arith.minimumf(g, _limit) + u = arith.minimumf(u, _limit) + u = arith.maximumf(u, _neg_limit) + t = g * _alpha * _neg_log2e + emu = llvm.call_intrinsic( + f32, "llvm.amdgcn.exp2.f32", [t], [], [] ) - _t2 = fused2 & mask24_i32 + den = _one + emu + sig = llvm.call_intrinsic( + f32, "llvm.amdgcn.rcp.f32", [den], [], [] + ) + result_elems.append(g * sig * (u + _one)) + return vector.from_elements(vec4_f32, result_elems) - # Sorted weight aligned with `row` (matches aiter moe_sorting output). - if doweight_stage1: - tw = buffer_ops.buffer_load( - sorted_w_rsrc, row_safe, vec_width=1, dtype=T.f32 + def _act_vec4(gate_v4, up_v4): + """Dispatch activation based on `act` parameter.""" + if act == "swiglu": + return _swiglu_mul_vec4(gate_v4, up_v4) + else: + return _silu_mul_vec4(gate_v4, up_v4) + + # Add bias to raw GEMM accumulators before activation. + # bias layout: [E, 2*inter_dim] flat f32 (non-interleaved: gate then up). + # For gate_up_interleave, map physical column to logical bias offset. 
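The activation helpers above work element-wise on `vec4_f32` accumulators through the `exp2`/`rcp` hardware intrinsics; a scalar Python reference of the same math, with the clamping and `alpha = 1.702` used by `_swiglu_mul_vec4`:

```python
import math

def silu_mul(gate: float, up: float) -> float:
    """silu(gate) * up = gate * sigmoid(gate) * up."""
    return gate / (1.0 + math.exp(-gate)) * up

def swiglu_mul(gate: float, up: float,
               alpha: float = 1.702, limit: float = 7.0) -> float:
    """gate * sigmoid(alpha * gate) * (up + 1), with gate <= limit and |up| <= limit."""
    g = min(gate, limit)
    u = max(min(up, limit), -limit)
    return g / (1.0 + math.exp(-alpha * g)) * (u + 1.0)
```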
+ if const_expr(enable_bias and not _is_splitk): + if const_expr(bias_pf is not None): + _bias_gate_vals = bias_pf + else: + _bias_gate_vals = [] + for _ni in range_constexpr(num_acc_n): + if const_expr(gate_up_interleave): + _logical_col = ( + (by_n + n_tile_base) + // arith.constant(2, index=True) + + arith.constant((_ni // 2) * 16, index=True) + + lane_mod_16 + ) + _up_off = ( + inter_idx + if (_ni % 2 == 1) + else arith.constant(0, index=True) + ) + _bias_off = expert_off_idx + _up_off + _logical_col + else: + _bn = ( + by_n + + n_tile_base + + arith.constant(_ni * 16, index=True) + + lane_mod_16 + ) + _bias_off = expert_off_idx + _bn + _bias_gate_vals.append( + buffer_ops.buffer_load( + bias_rsrc, _bias_off, vec_width=1, dtype=f32 + ) ) - - for ni in range_constexpr(num_acc_n): - col_local = col_base_local + (ni * 16) - - acc_idx = mi * num_acc_n + ni - vg = vector.extract( - acc_gate[acc_idx], - static_position=[ii], - dynamic_position=[], + for _mi in range_constexpr(m_repeat): + for _ni in range_constexpr(num_acc_n): + _aidx = _mi * num_acc_n + _ni + _bsplat = vector.from_elements( + vec4_f32, [_bias_gate_vals[_ni]] * 4 ) - vu = vector.extract( - acc_up[acc_idx], - static_position=[ii], - dynamic_position=[], + acc_gate[_aidx] = arith.addf(acc_gate[_aidx], _bsplat) + + if const_expr(not (mock_gate_only or gate_up_interleave)): + _bias_up_vals = [] + for _ni in range_constexpr(num_acc_n): + _bn = ( + by_n + + n_tile_base + + arith.constant(_ni * 16, index=True) + + lane_mod_16 + ) + _bias_up_vals.append( + buffer_ops.buffer_load( + bias_rsrc, + expert_off_idx + inter_idx + _bn, + vec_width=1, + dtype=f32, + ) + ) + for _mi in range_constexpr(m_repeat): + for _ni in range_constexpr(num_acc_n): + _aidx = _mi * num_acc_n + _ni + _bsplat = vector.from_elements( + vec4_f32, [_bias_up_vals[_ni]] * 4 + ) + acc_up[_aidx] = arith.addf(acc_up[_aidx], _bsplat) + + if const_expr(gate_up_interleave and not _is_splitk): + _gui_out_n = num_acc_n // pack_N + acc = [None] * (_gui_out_n * m_repeat) + for _mi in range_constexpr(m_repeat): + for _ni in range_constexpr(_gui_out_n): + _g_idx = _mi * num_acc_n + _ni * pack_N + _u_idx = _g_idx + 1 + _out_idx = _mi * _gui_out_n + _ni + acc[_out_idx] = _act_vec4( + acc_gate[_g_idx], acc_gate[_u_idx] ) + elif const_expr(not _is_splitk): + acc = [None] * (int(num_acc_n) * int(m_repeat)) + for _mi in range_constexpr(m_repeat): + for _ni in range_constexpr(num_acc_n): + _aidx = _mi * num_acc_n + _ni + acc[_aidx] = _silu_mul_vec4(acc_gate[_aidx], acc_up[_aidx]) + + # ---- Epilogue: CShuffle + direct store (accumulate=False) ---- + # Output: out[(t*topk+s) * inter_dim + col] = silu(gate) * up + # For split-K: skip silu, output gate/up separately with atomic add + tw_pf = None + bias_pf = None + if const_expr(epilogue_pf is not None): + _, tw_pf, bias_pf = epilogue_pf - if enable_bias: - gate_bias_list, up_bias_list = epilogue_pf - vg = vg + gate_bias_list[ni] - vu = vu + up_bias_list[ni] + mask24_i32 = arith.constant(0xFFFFFF) + topk_i32_v = topk_i32 + tokens_i32_v = tokens_i32 - if act == "swiglu": - y = swiglu(vg, vu) - else: - y = silu(vg) * vu + from flydsl._mlir.dialects import fly as _fly - if doweight_stage1: - y = y * tw + _llvm_ptr_ty = ir.Type.parse("!llvm.ptr") + out_base_ptr = _fly.extract_aligned_pointer_as_index( + _llvm_ptr_ty, arg_out + ) + out_base_i64 = llvm.ptrtoint(T.i64, out_base_ptr) + out_base_idx = arith.index_cast(ir.IndexType.get(), out_base_i64) - lds_idx = row_base_lds + col_local - v1 = vector.from_elements(vec1_f32, [y]) - 
vector.store(v1, lds_out, [lds_idx], alignment=1) + if const_expr(lds_out is None): + raise RuntimeError("CShuffle epilogue requires lds_out") - def precompute_row(*, row_local, row): - row_i32 = arith.index_cast(T.i32, row) - row_valid0 = arith.cmpi( - arith.CmpIPredicate.ult, row_i32, max_token_id_i32 - ) - row_safe = arith.select( - row_valid0, row, fx.Index(0) - ) - fused2 = buffer_ops.buffer_load( - sorted_rsrc, row_safe, vec_width=1, dtype=T.i32 - ) - t2 = fused2 & mask24_i32 - s2 = fused2 >> 24 - t_valid = arith.cmpi(arith.CmpIPredicate.ult, t2, tokens_i32) - s_valid = arith.cmpi(arith.CmpIPredicate.ult, s2, topk_i32_v) - ts_valid = row_valid0 & (t_valid & s_valid) - t2_safe = arith.select(ts_valid, t2, fx.Int32(0)) - s2_safe = arith.select(ts_valid, s2, fx.Int32(0)) - idx0 = (t2_safe * topk_i32_v + s2_safe) * inter_i32_local - return idx0, ts_valid - - def store_pair( - *, row_local, row, row_ctx, col_pair0, col_g0, frag - ): - idx0 = row_ctx - col_i32 = arith.index_cast(T.i32, col_g0) - idx_out = idx0 + col_i32 - if out_dtype == "fp8": - frag = vector.bitcast(vec4_f32, frag) - frag0 = vector.extract( - frag, static_position=[0], dynamic_position=[] - ) - frag1 = vector.extract( - frag, static_position=[1], dynamic_position=[] + _apply_weight = doweight_stage1 and not _is_splitk + + def write_row_to_lds( + *, + mi: int, + ii: int, + row_in_tile, + row, + row_base_lds, + col_base_local, + num_acc_n: int, + lds_out, + ): + if const_expr(_apply_weight): + tw_idx = (mi * 4) + ii + if const_expr(tw_pf is not None): + tw = tw_pf[tw_idx] + else: + tw = buffer_ops.buffer_load( + sorted_w_rsrc, row, vec_width=1, dtype=f32 ) - frag2 = vector.extract( - frag, static_position=[2], dynamic_position=[] + for ni in range_constexpr(num_acc_n): + col_local = col_base_local + (ni * 16) + acc_idx = mi * num_acc_n + ni + v = vector.extract( + acc[acc_idx], static_position=[ii], dynamic_position=[] + ) + if const_expr(_apply_weight): + v = v * tw + if const_expr(_need_quant): + lds_idx = row_base_lds + col_local + vec1_f32 = T.vec(1, f32) + v1 = vector.from_elements(vec1_f32, [v]) + vector.store(v1, lds_out, [lds_idx], alignment=4) + else: + v_out = arith.trunc_f(out_elem(), v) + lds_idx = row_base_lds + col_local + vec1_out = T.vec(1, out_elem()) + v1 = vector.from_elements(vec1_out, [v_out]) + vector.store(v1, lds_out, [lds_idx], alignment=2) + + _out_row_stride = ( + inter_dim * 2 * out_elem_bytes + if _is_splitk + else ( + inter_dim // 2 + if _need_fp4 + else (inter_dim if _need_fp8 else inter_dim * out_elem_bytes) + ) + ) + + def precompute_row(*, row_local, row): + fused2 = memref.load(lds_tid, [row_local]) + row_i32 = arith.index_cast(T.i32, row) + row_valid0 = arith.cmpi(CmpIPredicate.ult, row_i32, num_valid_i32) + t = fused2 & mask24_i32 + s = fused2 >> 24 + t_ok = arith.cmpi(CmpIPredicate.ult, t, tokens_i32_v) + s_ok = arith.cmpi(CmpIPredicate.ult, s, topk_i32_v) + row_valid = arith.andi(row_valid0, arith.andi(t_ok, s_ok)) + t_idx = arith.index_cast(ir.IndexType.get(), t) + s_idx = arith.index_cast(ir.IndexType.get(), s) + ts_idx = t_idx * arith.constant(topk, index=True) + s_idx + row_byte_base = out_base_idx + ts_idx * arith.constant( + _out_row_stride, index=True + ) + return ((fused2, row_byte_base), row_valid) + + def _idx_to_llvm_ptr(idx_val, addr_space=1): + idx_v = idx_val._value if hasattr(idx_val, "_value") else idx_val + i64_v = arith.index_cast(T.i64, idx_v) + i64_raw = i64_v._value if hasattr(i64_v, "_value") else i64_v + ptr_ty = ir.Type.parse(f"!llvm.ptr<{addr_space}>") + 
return llvm.inttoptr(ptr_ty, i64_raw) + + _e_vec = _e_vec_s1 + _e_vec_sk = 2 + _cshuffle_nlane = min(32, tile_n // _e_vec) + _cshuffle_nlane_sk = min(32, tile_n // _e_vec_sk) + _num_threads_per_quant_blk = _num_threads_per_quant_blk_s1 + + _c0_i32 = arith.constant(0, type=T.i32) + _c1_i32 = arith.constant(1, type=T.i32) + _c2_i32 = arith.constant(2, type=T.i32) + _c3_i32 = arith.constant(3, type=T.i32) + _c4_i32 = arith.constant(4, type=T.i32) + _c5_i32 = arith.constant(5, type=T.i32) + _c7_i32 = arith.constant(7, type=T.i32) + _c15_i32 = arith.constant(15, type=T.i32) + _c21_i32 = arith.constant(21, type=T.i32) + _c23_i32 = arith.constant(23, type=T.i32) + _c28_i32 = arith.constant(28, type=T.i32) + _c31_i32 = arith.constant(31, type=T.i32) + _c32_i32 = arith.constant(32, type=T.i32) + _c64_i32 = arith.constant(64, type=T.i32) + _c126_i32 = arith.constant(126, type=T.i32) + _c127_i32 = arith.constant(127, type=T.i32) + _c254_i32 = arith.constant(254, type=T.i32) + _c256_i32 = arith.constant(256, type=T.i32) + _c0xFF_i32 = arith.constant(0xFF, type=T.i32) + _c0x200000_i32 = arith.constant(0x200000, type=T.i32) + _c0xFF800000_i32 = arith.constant(0xFF800000, type=T.i32) + _c0x400000_i32 = arith.constant(0x400000, type=T.i32) + _c0x7FFFFF_i32 = arith.constant(0x7FFFFF, type=T.i32) + _c0x80000000_i32 = arith.constant(0x80000000, type=T.i32) + _c0_f32 = arith.constant(0.0, type=T.f32) + + _c8_i32 = arith.constant(8, type=T.i32) + _fp_headroom = 2 if _need_fp4 else (8 if _need_fp8 else 0) + _c_headroom_i32 = arith.constant(_fp_headroom, type=T.i32) + + def _f32_to_e2m1(qx_f32): + """Convert a scaled f32 value to fp4 (e2m1) 4-bit integer.""" + qx = qx_f32.bitcast(T.i32) + s = qx & _c0x80000000_i32 + e = (qx >> _c23_i32) & _c0xFF_i32 + m = qx & _c0x7FFFFF_i32 + adj_exp = arith.maxsi(_c126_i32 - e, _c0_i32) + m_denorm = (_c0x400000_i32 | (m >> _c1_i32)) >> adj_exp + is_denorm = arith.cmpi(CmpIPredicate.ult, e, _c127_i32) + m = arith.select(is_denorm, m_denorm, m) + e = arith.maxsi(e - _c126_i32, _c0_i32) + combined = (e << _c2_i32) | (m >> _c21_i32) + rounded = (combined + _c1_i32) >> _c1_i32 + e2m1 = arith.minui(rounded, _c7_i32) + return (s >> _c28_i32) | e2m1 + + if const_expr(_need_sort): + _n32_sort = _sorted_scale_cols_i32 * _c32_i32 + + # Mutable slot for split-K N-offset (gate=0, up=inter_dim) + _sk_n_offset = [0] + + def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): + fused, row_byte_base = row_ctx + if const_expr(_need_quant and not _is_splitk): + frag_vals = [] + for i in range_constexpr(_e_vec): + frag_vals.append( + vector.extract( + frag, static_position=[i], dynamic_position=[] + ) ) - frag3 = vector.extract( - frag, static_position=[3], dynamic_position=[] + + local_max = _c0_f32 + for i in range_constexpr(_e_vec): + abs_v = llvm.call_intrinsic( + f32, "llvm.fabs.f32", [frag_vals[i]], [], [] ) + local_max = arith.maximumf(local_max, abs_v) + + for _si in range_constexpr(_num_shuffle_steps_s1): + off = arith.constant(_shuffle_dists_s1[_si], type=T.i32) + peer = local_max.shuffle_xor(off, _c64_i32) + local_max = arith.maximumf(local_max, peer) + + max_i32 = local_max.bitcast(T.i32) + max_rounded = (max_i32 + _c0x200000_i32) & _c0xFF800000_i32 + exp_field = max_rounded >> _c23_i32 + e8m0_biased = arith.maxsi(exp_field - _c_headroom_i32, _c0_i32) + + quant_exp = _c254_i32 - e8m0_biased + quant_scale = (quant_exp << _c23_i32).bitcast(T.f32) + + if const_expr(_need_fp4): + fp4_vals = [] + for i in range_constexpr(_e_vec): + scaled_v = frag_vals[i] * quant_scale + 
fp4_vals.append(_f32_to_e2m1(scaled_v)) + + packed_i32 = fp4_vals[0] | (fp4_vals[1] << _c4_i32) + for k in range_constexpr(1, _e_vec // 2): + byte_k = fp4_vals[2 * k] | ( + fp4_vals[2 * k + 1] << _c4_i32 + ) + packed_i32 = packed_i32 | ( + byte_k << arith.constant(k * 8, type=T.i32) + ) - out_fp8 = fx.Int32(0) - out_fp8 = rocdl.cvt_pk_fp8_f32( - src_a=arith.unwrap(frag0), - src_b=arith.unwrap(frag1), - old=arith.unwrap(out_fp8), - word_sel=0, - res=T.i32, + ptr_addr_idx = row_byte_base + col_g0 / arith.constant( + 2, index=True ) - out_fp8 = rocdl.cvt_pk_fp8_f32( - src_a=arith.unwrap(frag2), - src_b=arith.unwrap(frag3), - old=arith.unwrap(out_fp8), - word_sel=1, - res=T.i32, + out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) + _pack_bytes = _e_vec // 2 + if const_expr(_pack_bytes == 1): + store_val = arith.TruncIOp(T.i8, packed_i32) + store_raw = ( + store_val._value + if hasattr(store_val, "_value") + else store_val + ) + llvm.StoreOp( + store_raw, out_ptr_v, alignment=1, nontemporal=True + ) + elif const_expr(_pack_bytes == 2): + store_val = arith.TruncIOp(T.i16, packed_i32) + store_raw = ( + store_val._value + if hasattr(store_val, "_value") + else store_val + ) + llvm.StoreOp( + store_raw, out_ptr_v, alignment=2, nontemporal=True + ) + else: + packed_raw = ( + packed_i32._value + if hasattr(packed_i32, "_value") + else packed_i32 + ) + llvm.StoreOp( + packed_raw, out_ptr_v, alignment=4, nontemporal=True + ) + + elif const_expr(_need_fp8): + scaled_vals = [] + for i in range_constexpr(_e_vec): + scaled_vals.append(frag_vals[i] * quant_scale) + + ptr_addr_idx = row_byte_base + col_g0 + if const_expr(_e_vec <= 4): + packed_i32 = _c0_i32 + for _w in range_constexpr(_e_vec // 2): + packed_i32 = rocdl.cvt_pk_fp8_f32( + T.i32, + scaled_vals[2 * _w], + scaled_vals[2 * _w + 1], + packed_i32, + _w, + ) + out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) + if _e_vec == 2: + store_val = arith.TruncIOp(T.i16, packed_i32) + store_raw = ( + store_val._value + if hasattr(store_val, "_value") + else store_val + ) + llvm.StoreOp( + store_raw, + out_ptr_v, + alignment=2, + nontemporal=True, + ) + else: + packed_raw = ( + packed_i32._value + if hasattr(packed_i32, "_value") + else packed_i32 + ) + llvm.StoreOp( + packed_raw, + out_ptr_v, + alignment=4, + nontemporal=True, + ) + else: + for _wg in range_constexpr(_e_vec // 4): + _b = _wg * 4 + packed_w = _c0_i32 + packed_w = rocdl.cvt_pk_fp8_f32( + T.i32, + scaled_vals[_b], + scaled_vals[_b + 1], + packed_w, + 0, + ) + packed_w = rocdl.cvt_pk_fp8_f32( + T.i32, + scaled_vals[_b + 2], + scaled_vals[_b + 3], + packed_w, + 1, + ) + word_ptr = ptr_addr_idx + arith.constant( + _wg * 4, index=True + ) + out_ptr_v = _idx_to_llvm_ptr(word_ptr) + packed_raw = ( + packed_w._value + if hasattr(packed_w, "_value") + else packed_w + ) + llvm.StoreOp( + packed_raw, + out_ptr_v, + alignment=4, + nontemporal=True, + ) + + if const_expr(_need_sort): + col_g0_i32 = arith.index_cast(T.i32, col_g0) + is_scale_writer = arith.cmpi( + CmpIPredicate.eq, col_g0_i32 & _c31_i32, _c0_i32 ) - buffer_ops.buffer_store(out_fp8, out_rsrc, idx_out // 4) - else: - out_vec_ty = T.vec(4, _out_elem_type()) - out_vals = [] - for fi in range_constexpr(4): - frag_i = vector.extract( - frag, static_position=[fi], dynamic_position=[] + _if_scale = scf.IfOp(is_scale_writer) + with ir.InsertionPoint(_if_scale.then_block): + row_i32_s = arith.index_cast(T.i32, row) + col_s_i32 = col_g0_i32 >> _c5_i32 + d0 = row_i32_s >> _c5_i32 + d1 = (row_i32_s >> _c4_i32) & _c1_i32 + d2 = row_i32_s & _c15_i32 + d3 = col_s_i32 >> 
_c3_i32 + d4 = (col_s_i32 >> _c2_i32) & _c1_i32 + d5 = col_s_i32 & _c3_i32 + byte_off = ( + d0 * _n32_sort + + d3 * _c256_i32 + + d5 * _c64_i32 + + d2 * _c4_i32 + + d4 * _c2_i32 + + d1 ) - out_vals.append( - arith.trunc_f(_out_elem_type(), frag_i) + e8m0_i8 = arith.TruncIOp(T.i8, e8m0_biased) + buffer_ops.buffer_store( + e8m0_i8, + sorted_scale_rsrc, + byte_off, + offset_is_bytes=True, ) - out_vec = vector.from_elements(out_vec_ty, out_vals) - buffer_ops.buffer_store(out_vec, out_rsrc, idx_out) + scf.YieldOp([]) + elif const_expr(_is_splitk): + col_idx = col_g0 + arith.constant(_sk_n_offset[0], index=True) + byte_off_col = col_idx * arith.constant( + out_elem_bytes, index=True + ) + ptr_addr_idx = row_byte_base + byte_off_col + out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) + frag_v = frag._value if hasattr(frag, "_value") else frag + llvm.AtomicRMWOp( + llvm.AtomicBinOp.fadd, + out_ptr_v, + frag_v, + llvm.AtomicOrdering.monotonic, + syncscope="agent", + alignment=_e_vec_sk * out_elem_bytes, + ) + else: + col_idx = col_g0 + byte_off_col = col_idx * arith.constant( + out_elem_bytes, index=True + ) + ptr_addr_idx = row_byte_base + byte_off_col + out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) + frag_v = frag._value if hasattr(frag, "_value") else frag + llvm.StoreOp( + frag_v, + out_ptr_v, + alignment=_e_vec * out_elem_bytes, + nontemporal=True, + ) - mfma_epilog( - use_cshuffle=True, + _frag_elem = ( + ir.F32Type.get() + if _need_quant + else (ir.BF16Type.get() if out_is_bf16 else ir.F16Type.get()) + ) + + if const_expr(gate_up_interleave and not _is_splitk): + # gui without splitk: acc has activation applied, halved N + _gui_eff_n = _gui_out_n + _gui_tile_n = tile_n // 2 + _gui_cshuffle_nlane = min(32, _gui_tile_n // _e_vec) + _gui_by_n = by_n / arith.constant(2, index=True) + _gui_n_tile_base = n_tile_base / arith.constant(2, index=True) + c_shuffle_epilog( + arith=arith, + vector=vector, + gpu=gpu, + scf=scf, + range_constexpr=range_constexpr, + tile_m=tile_m, + tile_n=_gui_tile_n, + e_vec=_e_vec, + cshuffle_nlane=_gui_cshuffle_nlane, + block_size=total_threads, + m_repeat=m_repeat, + num_acc_n=_gui_eff_n, + tx=tx, + lane_div_16=lane_div_16, + lane_mod_16=lane_mod_16, + bx_m=bx_m, + by_n=_gui_by_n, + n_tile_base=_gui_n_tile_base, + lds_out=lds_out, + frag_elem_type=_frag_elem, + write_row_to_lds=write_row_to_lds, + precompute_row=precompute_row, + store_pair=store_pair, + ) + elif const_expr(mock_gate_only or (gate_up_interleave and _is_splitk)): + # mock_gate_only: single pass, by_n covers full [0, 2*inter_dim) + _eff_e_vec = _e_vec_sk + acc = acc_gate + c_shuffle_epilog( arith=arith, vector=vector, gpu=gpu, @@ -1332,7 +2480,9 @@ def store_pair( range_constexpr=range_constexpr, tile_m=tile_m, tile_n=tile_n, - e_vec=4, + e_vec=_eff_e_vec, + cshuffle_nlane=_cshuffle_nlane_sk, + block_size=total_threads, m_repeat=m_repeat, num_acc_n=num_acc_n, tx=tx, @@ -1342,85 +2492,118 @@ def store_pair( by_n=by_n, n_tile_base=n_tile_base, lds_out=lds_out, - frag_elem_type=T.f32, + frag_elem_type=_frag_elem, write_row_to_lds=write_row_to_lds, precompute_row=precompute_row, store_pair=store_pair, + lds_out_split=lds_out_B, ) - return - - def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): - # `row` is the sorted-row index (bx_m + row_in_tile). 
- row_i32 = arith.index_cast(T.i32, row) - row_valid0 = arith.cmpi( - arith.CmpIPredicate.ult, row_i32, max_token_id_i32 - ) - row_safe = arith.select( - row_valid0, row, fx.Index(0) - ) - fused2 = buffer_ops.buffer_load( - sorted_rsrc, row_safe, vec_width=1, dtype=T.i32 + elif const_expr(_is_splitk): + # Two-pass epilogue: gate then up, each with atomic add + _eff_e_vec = _e_vec_sk + + # Pass 1: gate + acc = acc_gate + _sk_n_offset[0] = 0 + c_shuffle_epilog( + arith=arith, + vector=vector, + gpu=gpu, + scf=scf, + range_constexpr=range_constexpr, + tile_m=tile_m, + tile_n=tile_n, + e_vec=_eff_e_vec, + cshuffle_nlane=_cshuffle_nlane_sk, + block_size=total_threads, + m_repeat=m_repeat, + num_acc_n=num_acc_n, + tx=tx, + lane_div_16=lane_div_16, + lane_mod_16=lane_mod_16, + bx_m=bx_m, + by_n=by_n, + n_tile_base=n_tile_base, + lds_out=lds_out, + frag_elem_type=_frag_elem, + write_row_to_lds=write_row_to_lds, + precompute_row=precompute_row, + store_pair=store_pair, + lds_out_split=lds_out_B, ) - t2 = fused2 & mask24_i32 - s2 = fused2 >> 24 - t_valid = arith.cmpi(arith.CmpIPredicate.ult, t2, tokens_i32) - s_valid = arith.cmpi(arith.CmpIPredicate.ult, s2, topk_i32_v) - ts_valid = row_valid0 & (t_valid & s_valid) - t2_safe = arith.select(ts_valid, t2, fx.Int32(0)) - s2_safe = arith.select(ts_valid, s2, fx.Int32(0)) - - # out linear index base = ((t*topk + s)*inter_dim) (invariant across ni) - idx0 = (t2_safe * topk_i32_v + s2_safe) * inter_i32_local - - # Sorted weight aligned with `row` (matches aiter moe_sorting output). - if doweight_stage1: - tw = buffer_ops.buffer_load( - sorted_w_rsrc, row_safe, vec_width=1, dtype=T.f32 - ) - _if_valid = scf.IfOp(ts_valid) - with _if_then(_if_valid): - for ni in range_constexpr(num_acc_n): - col_i32 = col_i32_list[ni] - acc_idx = mi * num_acc_n + ni - vg = vector.extract( - acc_gate[acc_idx], - static_position=[ii], - dynamic_position=[], - ) - vu = vector.extract( - acc_up[acc_idx], - static_position=[ii], - dynamic_position=[], - ) - if enable_bias: - gate_bias_list, up_bias_list = epilogue_pf - vg = vg + gate_bias_list[ni] - vu = vu + up_bias_list[ni] - - if act == "swiglu": - y = swiglu(vg, vu) - else: - y = silu(vg) * vu + gpu.barrier() - if doweight_stage1: - y = y * tw + # Pass 2: up + acc = acc_up + _sk_n_offset[0] = inter_dim + c_shuffle_epilog( + arith=arith, + vector=vector, + gpu=gpu, + scf=scf, + range_constexpr=range_constexpr, + tile_m=tile_m, + tile_n=tile_n, + e_vec=_eff_e_vec, + cshuffle_nlane=_cshuffle_nlane_sk, + block_size=total_threads, + m_repeat=m_repeat, + num_acc_n=num_acc_n, + tx=tx, + lane_div_16=lane_div_16, + lane_mod_16=lane_mod_16, + bx_m=bx_m, + by_n=by_n, + n_tile_base=n_tile_base, + lds_out=lds_out, + frag_elem_type=_frag_elem, + write_row_to_lds=write_row_to_lds, + precompute_row=precompute_row, + store_pair=store_pair, + lds_out_split=lds_out_B, + ) + else: + c_shuffle_epilog( + arith=arith, + vector=vector, + gpu=gpu, + scf=scf, + range_constexpr=range_constexpr, + tile_m=tile_m, + tile_n=tile_n, + e_vec=_e_vec, + cshuffle_nlane=_cshuffle_nlane, + block_size=total_threads, + m_repeat=m_repeat, + num_acc_n=num_acc_n, + tx=tx, + lane_div_16=lane_div_16, + lane_mod_16=lane_mod_16, + bx_m=bx_m, + by_n=by_n, + n_tile_base=n_tile_base, + lds_out=lds_out, + frag_elem_type=_frag_elem, + write_row_to_lds=write_row_to_lds, + precompute_row=precompute_row, + store_pair=store_pair, + lds_out_split=lds_out_B, + ) - y = arith.trunc_f(_out_elem_type(), y) - idx_out = idx0 + col_i32 - buffer_ops.buffer_store(y, out_rsrc, idx_out) + 
_if_blk = scf.IfOp(blk_valid) + with ir.InsertionPoint(_if_blk.then_block): + _ifexpert_of = scf.IfOp(exp_valid) + with ir.InsertionPoint(_ifexpert_of.then_block): + _moe_gemm1_body() + scf.YieldOp([]) + scf.YieldOp([]) - mfma_epilog( - use_cshuffle=False, - arith=arith, - range_constexpr=range_constexpr, - m_repeat=m_repeat, - lane_div_16=lane_div_16, - bx_m=bx_m, - body_row=_stage1_store_row, - ) + gpu.barrier() + scf.YieldOp([]) + _for_ip.__exit__(None, None, None) - # -- Host launcher (flyc.jit + .launch) -------------------------------- + # -- Host launcher -- _cache_tag = ( module_name, a_dtype, @@ -1435,6 +2618,13 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): model_dim_pad, inter_dim_pad, use_cshuffle_epilog, + persist_m, + use_async_copy, + waves_per_eu, + k_batch, + gate_mode, + a_scale_one, + xcd_swizzle, ) @flyc.jit @@ -1449,6 +2639,7 @@ def launch_mixed_moe_gemm1( arg_sorted_weights: fx.Tensor, arg_max_token_ids: fx.Tensor, arg_bias: fx.Tensor, + arg_out_scale_sorted: fx.Tensor, i32_tokens_in: fx.Int32, i32_inter_in: fx.Int32, i32_k_in: fx.Int32, @@ -1456,14 +2647,30 @@ def launch_mixed_moe_gemm1( stream: fx.Stream, ): _ = _cache_tag - allocator.finalized = False + allocator_pong.finalized = False + allocator_ping.finalized = False ctx = CompilationContext.get_current() with ir.InsertionPoint(ctx.gpu_module_body): - allocator.finalize() - - inter_in = arith.index_cast(T.index, i32_inter_in) - gx = inter_in // fx.Index(tile_n) - gy = arith.index_cast(T.index, i32_size_expert_ids_in) + allocator_pong.finalize() + allocator_ping.finalize() + + inter_in = arith.index_cast(ir.IndexType.get(), i32_inter_in.ir_value()) + tile_n_index = arith.constant(tile_n, index=True) + inter_dim_pad_total = arith.constant(2 * inter_dim_pad, index=True) + if const_expr(mock_gate_only or gate_up_interleave): + gx = (inter_in - inter_dim_pad_total + tile_n_index - 1) / tile_n_index + else: + gx = ( + (inter_in - inter_dim_pad_total + 2 * tile_n_index - 1) + / tile_n_index + / arith.constant(2, index=True) + ) + _c_pm_l = arith.constant(persist_m, index=True) + gy = ( + arith.index_cast(ir.IndexType.get(), i32_size_expert_ids_in.ir_value()) + + _c_pm_l + - arith.constant(1, index=True) + ) / _c_pm_l moe_gemm1( arg_out, @@ -1476,20 +2683,17 @@ def launch_mixed_moe_gemm1( arg_sorted_weights, arg_max_token_ids, arg_bias, + arg_out_scale_sorted, i32_tokens_in, i32_inter_in, i32_k_in, i32_size_expert_ids_in, - ).launch( - grid=(gx, gy, 1), - block=(256, 1, 1), - stream=stream, - ) + ).launch(grid=(gx, gy, k_batch), block=(total_threads, 1, 1), stream=stream) return launch_mixed_moe_gemm1 -@functools.lru_cache(maxsize=1024) +@functools.lru_cache(maxsize=None) def compile_mixed_moe_gemm2( *, model_dim: int, @@ -1512,9 +2716,17 @@ def compile_mixed_moe_gemm2( model_dim_pad: int = 0, inter_dim_pad: int = 0, persist_m: int = 4, + sort_block_m: int = 0, + b_nt: int = 2, + xcd_swizzle: int = 0, ): """Compile stage2 kernel (`moe_gemm2`) and return the compiled executable. + persist_m: + - > 0: legacy mode -- each CTA processes exactly persist_m consecutive M tiles. + - <= 0: **persistent mode** -- grid_y = cu_num (auto-detected), each CTA + round-robins over M tiles with stride cu_num. + a_dtype: - "fp8": A2 is fp8 - "fp16": A2 is fp16 (caller uses tile_k halved vs fp8 to match MFMA K halving) @@ -1534,19 +2746,24 @@ def compile_mixed_moe_gemm2( `use_cshuffle_epilog` controls whether we use the LDS CShuffle epilogue before global atomics (recommended for performance). 
- """ - gpu_arch = get_hip_arch() - allocator = SmemAllocator(None, arch=gpu_arch) - if a_dtype not in ("fp8", "fp16", "int8", "fp4"): - raise ValueError( - f"a_dtype must be one of ('fp8','fp16','int8','fp4'), got {a_dtype!r}" - ) - if b_dtype not in ("fp8", "fp16", "int8", "int4", "fp4"): + `sort_block_m` is the block_size used by moe_sorting / stage1. When 0 (default), + assumed equal to `tile_m`. When set, stage2 can use a different tile_m from + sorting/stage1. Requires sort_block_m % tile_m == 0. + """ + _sort_block_m = tile_m if sort_block_m <= 0 else sort_block_m + if _sort_block_m != tile_m and _sort_block_m % tile_m != 0: raise ValueError( - f"b_dtype must be one of ('fp8','fp16','int8','int4','fp4'), got {b_dtype!r}" + f"sort_block_m ({_sort_block_m}) must be a multiple of tile_m ({tile_m})" ) + gpu_arch = get_hip_arch() + allocator_pong = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem0") + allocator_ping = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem1") + _state = {} + + validate_moe_dtypes(a_dtype, b_dtype) + is_f16_a = a_dtype == "fp16" is_f16_b = b_dtype == "fp16" @@ -1554,9 +2771,13 @@ def compile_mixed_moe_gemm2( is_f4_a = a_dtype == "fp4" is_f4_b = b_dtype == "fp4" - pack_M = 2 - pack_N = 2 - pack_K = 2 + _scale_pack_m = 2 # physical mn_pack in preshuffle microscale layout + _scale_pack_n = 2 + _scale_pack_k = 2 # physical k_pack in preshuffle scale layout + pack_M = min(_scale_pack_m, tile_m // 16) + pack_N = min(_scale_pack_n, tile_n // 64) + _k_unroll_raw = (int(tile_k) * (2 if a_dtype == "fp16" else 1)) // 128 + pack_K = min(_scale_pack_k, _k_unroll_raw) elem_bytes = 1 @@ -1568,6 +2789,22 @@ def compile_mixed_moe_gemm2( cbsz = 0 if is_f8_a else 4 blgp = 4 + # ---- Static B preshuffle strides (compile-time) ---- + # All values below are Python ints computable at kernel-compile time. + # Using them in an explicit multiply-add replaces the fly dialect's + # dynamic ``crd2idx`` path which emits Barrett reduction for the + # non-power-of-2 ``n0 = experts*model_dim//16`` shape. + _b_kpack_bytes_s = 8 if (b_dtype == "int4") else 16 + _b_kpack_elems_s = _b_kpack_bytes_s // b_elem_bytes + _b_c_k_s = inter_dim // _scale_pack_k + _b_c_k0_s = (_b_c_k_s * b_elem_bytes) // 64 + _b_stride_nlane = _b_kpack_elems_s # 16 + _b_stride_klane = 16 * _b_stride_nlane # 256 + _b_stride_k0 = 4 * _b_stride_klane # 1024 + _b_stride_n0 = _b_c_k0_s * _b_stride_k0 # c_k0 * 1024 + assert model_dim % 16 == 0, "model_dim must be divisible by 16" + _expert_b_stride = (model_dim // 16) * _b_stride_n0 + # K64-byte micro-step: always 64 bytes per `ku`. For fp16, this is 32 elements (2xK16 MFMA). if (tile_k_bytes % 64) != 0: raise ValueError( @@ -1587,6 +2824,32 @@ def compile_mixed_moe_gemm2( "compile_moe_gemm2(accumulate=False) only supports out_dtype in {'f16','bf16'}" ) is_int4 = b_dtype == "int4" + # INT4 here means W4A8: A2 is int8, W is packed int4 and unpacked to int8 in-kernel. + is_int8 = False + + mfma_i32_k32 = None + if is_int8: + mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr( + rocdl, "mfma_i32_16x16x32_i8", None + ) + if mfma_i32_k32 is None: + raise AttributeError( + "INT8 K32 MFMA op not found: expected `rocdl.mfma_i32_16x16x32i8` " + "(or `rocdl.mfma_i32_16x16x32_i8`)." 
+ ) + + def _x_elem_type(): + if is_f4_b: + return T.f8 if is_f8_a else T.i8 + return T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) + + def _w_elem_type(): + if is_f4_b: + return T.i8 + return T.f16 if is_f16_b else (T.i8 if is_int8 else T.f8) + + def _scale_elem_type(): + return T.i32 total_threads = 256 bytes_x_per_tile = int(tile_m) * int(tile_k) * int(a_elem_bytes) @@ -1597,13 +2860,22 @@ def compile_mixed_moe_gemm2( ) bytes_per_thread_x = bytes_x_per_tile // total_threads - pad_k = 0 + _use_lds128 = os.environ.get("FLIR_CK_LDS128", "1") in ( + "1", + "true", + "True", + "YES", + "yes", + ) + pad_k = 0 if _use_lds128 else 8 lds_stride = tile_k + pad_k - # gfx950+ has buffer_atomic_pk_add_bf16; gfx942 uses global atomics via raw pointer. - _has_buffer_atomic_bf16 = str(gpu_arch).startswith(("gfx95", "gfx12")) - _needs_global_atomic_bf16 = out_is_bf16 and not _has_buffer_atomic_bf16 - if out_is_bf16 and not supports_bf16_global_atomics(gpu_arch): - raise ValueError(f"out_dtype='bf16' requires bf16 global atomics, got arch={gpu_arch!r}") + + if a_elem_vec_pack > 1: + _eff_lds_stride = lds_stride // a_elem_vec_pack + _eff_tile_k_bytes = tile_k_bytes // a_elem_vec_pack + else: + _eff_lds_stride = lds_stride + _eff_tile_k_bytes = tile_k_bytes if out_is_f32: # Match origin/dev_a16w4: f32 output uses scalar atomics and does NOT use the CShuffle epilogue. @@ -1616,7 +2888,7 @@ def compile_mixed_moe_gemm2( ) else: if use_cshuffle_epilog is None: - _use_cshuffle_epilog = os.environ.get("FLYDSL_MOE_STAGE2_CSHUFFLE", "1") in ( + _use_cshuffle_epilog = os.environ.get("FLIR_MOE_STAGE2_CSHUFFLE", "1") in ( "1", "true", "True", @@ -1627,7 +2899,7 @@ def compile_mixed_moe_gemm2( _use_cshuffle_epilog = bool(use_cshuffle_epilog) if not _use_cshuffle_epilog: raise ValueError( - "stage2 f16 output currently requires CShuffle epilogue (FLYDSL_MOE_STAGE2_CSHUFFLE=1)." + "stage2 f16 output currently requires CShuffle epilogue (FLIR_MOE_STAGE2_CSHUFFLE=1)." ) # NOTE: Keep this as a callable so we don't require an MLIR Context at Python-time. @@ -1638,32 +2910,49 @@ def out_elem(): # IMPORTANT: include tiling in the module name to avoid accidentally reusing a compiled # binary for a different (tile_m, tile_n, tile_k) configuration. # See stage1 note: include ABI tag to prevent binary reuse across signature changes. - # IMPORTANT: module name participates in FlyDSL's compile cache key. + # IMPORTANT: module name participates in the compiler cache key. # Dynamic-shape variant: safe to reuse across (tokens/sorted_size/size_expert_ids) at runtime. # Keep a distinct ABI tag so the compile cache never mixes with historical signatures. 
+ _persistent = persist_m <= 0 + if _persistent: + from aiter.jit.utils.chip_info import get_cu_num + + _cu_num = get_cu_num() + else: + _cu_num = 0 + _sbm_tag = "" if _sort_block_m == tile_m else f"_sbm{_sort_block_m}" + _pm_tag = f"_persist_cu{_cu_num}" if _persistent else f"_pm{persist_m}" + _xcd_tag = f"_xcd{xcd_swizzle}" if xcd_swizzle > 0 else "" module_name = ( f"mfma_moe2_a{a_dtype}_w{b_dtype}_{out_s}_{epilog_tag}" f"_t{tile_m}x{tile_n}x{tile_k}" - f"_vscale_fix3" + f"_vscale_fix3{_pm_tag}{_sbm_tag}{_xcd_tag}" ).replace("-", "_") # -- LDS sizing (pure Python; no MLIR Context needed) --------------------- - # Reuse a single allocation for both: - # - ping-pong A2 tiles (2 * tile_m * lds_stride * elem_bytes bytes) - # - epilogue CShuffle tile (tile_m * tile_n f16 -> 2 * tile_m * tile_n bytes) - lds_x_bytes = 2 * int(tile_m) * int(lds_stride) * int(a_elem_bytes) + # Ping-pong A2 tiles via separate allocators (like stage1). + _single_x_bytes = int(tile_m) * int(_eff_lds_stride) * int(a_elem_bytes) + _cshuffle_elem_bytes_s2 = 2 # f16/bf16 = 2 bytes lds_out_bytes = ( - 2 * int(tile_m) * int(tile_n) if _use_cshuffle_epilog else 0 - ) # f16 bytes + _cshuffle_elem_bytes_s2 * int(tile_m) * int(tile_n) + if _use_cshuffle_epilog + else 0 + ) lds_tid_bytes = int(tile_m) * 4 - lds_total_bytes = max(lds_x_bytes, lds_out_bytes) + lds_tid_bytes - lds_total_elems = lds_total_bytes if a_elem_bytes == 1 else (lds_total_bytes // 2) + _input_elems = _single_x_bytes if a_elem_bytes == 1 else (_single_x_bytes // 2) + + _pong_buffer_bytes = max(_single_x_bytes, lds_out_bytes) + _ping_buffer_bytes = _single_x_bytes def x_lds_elem(): - return T.f16 if is_f16_a else T.f8 + return T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) + + lds_pong_offset = allocator_pong._align(allocator_pong.ptr, 16) + allocator_pong.ptr = lds_pong_offset + _pong_buffer_bytes + _lds_tid_offset_pong = allocator_pong._align(allocator_pong.ptr, 4) + allocator_pong.ptr = _lds_tid_offset_pong + lds_tid_bytes - lds_alloc_bytes = int(lds_total_elems) * int(a_elem_bytes) - lds_alloc_offset = allocator._align(allocator.ptr, 16) - allocator.ptr = lds_alloc_offset + lds_alloc_bytes + lds_ping_offset = allocator_ping._align(allocator_ping.ptr, 16) + allocator_ping.ptr = lds_ping_offset + _ping_buffer_bytes if True: @@ -1684,43 +2973,49 @@ def moe_gemm2( i32_k_in: fx.Int32, i32_size_expert_ids_in: fx.Int32, ): - tokens_in = arith.index_cast(T.index, i32_tokens_in) - n_in = arith.index_cast(T.index, i32_n_in) - k_in = arith.index_cast(T.index, i32_k_in) - size_expert_ids_in = arith.index_cast(T.index, i32_size_expert_ids_in) - x_elem = T.f16 if is_f16_a else T.f8 - vec4_f32 = T.vec(4, T.f32) - vec4_i32 = T.vec(4, T.i32) + + tokens_in = arith.index_cast(ir.IndexType.get(), i32_tokens_in.ir_value()) + n_in = arith.index_cast(ir.IndexType.get(), i32_n_in.ir_value()) + k_in = arith.index_cast(ir.IndexType.get(), i32_k_in.ir_value()) + size_expert_ids_in = arith.index_cast( + ir.IndexType.get(), i32_size_expert_ids_in.ir_value() + ) + x_elem = T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) + f32 = T.f32 + i32 = T.i32 + i64 = T.i64 + vec4_f32 = T.vec(4, f32) + vec4_i32 = T.vec(4, i32) vec16_elems = 16 if a_elem_bytes == 1 else 8 vec8_elems = 8 if a_elem_bytes == 1 else 4 vec4_elems = 4 if a_elem_bytes == 1 else 2 vec16_x = T.vec(vec16_elems, x_elem) - vec2_i64 = T.vec(2, T.i64) + vec2_i64 = T.vec(2, i64) - acc_init = arith.constant_vector(0.0, vec4_f32) + acc_init = ( + arith.constant_vector(0, vec4_i32) + if is_int8 + else 
arith.constant_vector(0.0, vec4_f32) + ) # A2 layout (flatten token-slot -> M; use i32 for fly.make_shape). - topk_idx = fx.Index(topk) + topk_idx = arith.constant(topk, index=True) m_in = tokens_in * topk_idx - # fly.make_shape requires i32/i64, not index - m_i32_v = arith.index_cast(T.i32, m_in) - k_i32_v = i32_k_in # B preshuffle layout: [experts*model_dim, inter_dim] - c_n_total = fx.Index(experts * model_dim) + c_n_total = arith.constant(experts * model_dim, index=True) kpack_bytes = 8 if is_int4 else 16 - b_layout = make_preshuffle_b_layout( - arith, - c_n=c_n_total, - c_k=k_in // pack_K, - kpack_bytes=kpack_bytes, - elem_bytes=b_elem_bytes, - ) - layout_b = b_layout.layout_b + from .layout_utils import _div_pow2, _mod_pow2 + + def check_c_n_valid_gate(base_n): + return arith.cmpi(CmpIPredicate.ult, base_n, model_dim - model_dim_pad) + + def check_c_k_valid_gate(base_k): + return arith.cmpi(CmpIPredicate.ult, base_k, inter_dim - inter_dim_pad) # A&B's scale preshuffle layout # For fp4, k_in is already packed (inter_dim // a_elem_vec_pack), so we need original inter_dim - c_k_orig = fx.Index(inter_dim) + c_k_orig = arith.constant(inter_dim, index=True) layout_a_scale = make_preshuffle_scale_layout( arith, c_mn=m_in, c_k=c_k_orig ) @@ -1728,58 +3023,85 @@ def moe_gemm2( arith, c_mn=c_n_total, c_k=c_k_orig ) - shape_lds = fx.make_shape(tile_m, tile_k) - stride_lds = fx.make_stride(lds_stride, 1) + shape_lds = fx.make_shape(tile_m, _eff_lds_stride) + stride_lds = fx.make_stride(_eff_lds_stride, 1) layout_lds = fx.make_layout(shape_lds, stride_lds) tx = gpu.thread_id("x") - # Align with Aiter launch mapping: - # - blockIdx.x -> N dimension (tile along model_dim) - # - blockIdx.y -> expert-block id / M dimension (tile along sorted M) - by = gpu.block_id("x") # tile along model_dim - bx_persist = gpu.block_id("y") # tile along sorted M + by = gpu.block_id("x") # tile along model_dim (N-dim) + bx_persist = gpu.block_id("y") # persistent WG index (M-dim) + + if const_expr(xcd_swizzle > 0): + _NUM_XCDS_S = 8 + _c1_sw = arith.constant(1, index=True) + _c_tn_sw = arith.constant(tile_n, index=True) + _c_mdp_sw = arith.constant(model_dim_pad, index=True) + _gx = (n_in - _c_mdp_sw + _c_tn_sw - _c1_sw) / _c_tn_sw + if const_expr(_persistent): + _gy = arith.constant(_cu_num, index=True) + else: + _c_pm_sw = arith.constant(persist_m, index=True) + _gy = (size_expert_ids_in + _c_pm_sw - _c1_sw) / _c_pm_sw + + _linear_id = bx_persist * _gx + by + _num_wgs = _gx * _gy + + _c_xcds = arith.constant(_NUM_XCDS_S, index=True) + _wgs_per_xcd = _num_wgs / _c_xcds + _wgid = (_linear_id % _c_xcds) * _wgs_per_xcd + (_linear_id / _c_xcds) + + _WGM_S = xcd_swizzle + _c_wgm = arith.constant(_WGM_S, index=True) + _num_wgid_in_group = _c_wgm * _gx + _group_id = _wgid / _num_wgid_in_group + _first_pid_m = _group_id * _c_wgm + _remaining_m = _gy - _first_pid_m + _cmp_m = arith.cmpi(CmpIPredicate.ult, _remaining_m, _c_wgm) + _group_size_m = arith.select(_cmp_m, _remaining_m, _c_wgm) + + _wgid_in_group = _wgid % _num_wgid_in_group + bx_persist = _first_pid_m + (_wgid_in_group % _group_size_m) + by = _wgid_in_group / _group_size_m # XOR16 swizzle parameter (in bytes; constant, power-of-two in our configs). - k_blocks16 = fx.Index(tile_k_bytes // 16) - base_ptr = allocator.get_base() - lds_x_ptr = SmemPtr( - base_ptr, - lds_alloc_offset, - x_lds_elem(), - shape=(lds_total_elems,), - ) - lds_x = lds_x_ptr.get() - # Alias the same underlying LDS bytes as f16/bf16 for epilogue shuffle. 
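# A plain-Python model of the xcd_swizzle workgroup remap introduced earlier in
# this hunk (not emitted IR; NUM_XCDS=8 and wgm=xcd_swizzle are carried over from
# the code, and the same truncating division is assumed). Step 1 undoes the
# hardware's round-robin distribution of workgroup ids across XCDs so consecutive
# remapped ids land on one XCD; step 2 applies a wgm-wide grouped M/N ordering for
# B-weight L2 reuse.
def xcd_remap(bx, by, gx, gy, num_xcds=8, wgm=4):
    linear_id = bx * gx + by
    wgs_per_xcd = (gx * gy) // num_xcds
    wgid = (linear_id % num_xcds) * wgs_per_xcd + linear_id // num_xcds
    num_in_group = wgm * gx
    group_id = wgid // num_in_group
    first_pid_m = group_id * wgm
    group_size_m = min(gy - first_pid_m, wgm)      # last group may be short
    wgid_in_group = wgid % num_in_group
    return (first_pid_m + wgid_in_group % group_size_m,   # new bx_persist (M)
            wgid_in_group // group_size_m)                 # new by (N)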
+ k_blocks16 = arith.constant(_eff_tile_k_bytes // 16, index=True) + layout_tx_wave_lane = fx.make_layout((4, 64), stride=(64, 1)) + layout_lane16 = fx.make_layout((4, 16), stride=(16, 1)) + + base_ptr_pong = allocator_pong.get_base() + base_ptr_ping = allocator_ping.get_base() + lds_x_pong = SmemPtr( + base_ptr_pong, lds_pong_offset, x_lds_elem(), shape=(_input_elems,) + ).get() + lds_x_ping = SmemPtr( + base_ptr_ping, lds_ping_offset, x_lds_elem(), shape=(_input_elems,) + ).get() lds_out = ( SmemPtr( - base_ptr, - lds_x_ptr.byte_offset, + base_ptr_pong, + lds_pong_offset, (T.bf16 if out_is_bf16 else T.f16), shape=(tile_m * tile_n,), ).get() if _use_cshuffle_epilog else None ) - - # lds_tid: alias LDS after max(x, out) for sorted_idx preload - _lds_x_b = 2 * int(tile_m) * int(lds_stride) * int(a_elem_bytes) - _lds_out_b = 2 * int(tile_m) * int(tile_n) if _use_cshuffle_epilog else 0 - _lds_tid_off = max(_lds_x_b, _lds_out_b) lds_tid = SmemPtr( - base_ptr, lds_x_ptr.byte_offset + _lds_tid_off, T.i32, shape=(tile_m,) + base_ptr_pong, _lds_tid_offset_pong, T.i32, shape=(tile_m,) ).get() # Buffer resources. # For dynamic memrefs, `max_size=False` cannot infer the logical size from the memref *type*, # so we should pass `num_records_bytes` explicitly for stable hardware OOB behavior. - c_topk = fx.Index(topk) + c_topk = arith.constant(topk, index=True) # X(A2): buffer size in bytes, accounting for FP4 packing (2 elements per byte). # fp8/int8: 1 byte per element -> bytes = tokens*topk * K # fp4: 2 elements per byte -> bytes = tokens*topk * K / 2 - c_a_pack = fx.Index(int(a_elem_vec_pack)) - c_elem_bytes = fx.Index(int(a_elem_bytes)) - x_nbytes_idx = ((tokens_in * c_topk) * k_in * c_elem_bytes) // c_a_pack + c_elem_bytes = arith.constant(int(a_elem_bytes), index=True) + x_nbytes_idx = _div_pow2( + (tokens_in * c_topk) * k_in * c_elem_bytes, int(a_elem_vec_pack) + ) x_nbytes_i32 = arith.index_cast(T.i32, x_nbytes_idx) x_rsrc = buffer_ops.create_buffer_resource( arg_x, max_size=False, num_records_bytes=x_nbytes_i32 @@ -1790,14 +3112,14 @@ def moe_gemm2( # OUT: [tokens, model_dim] -> clamp to descriptor max (i32 bytes) to avoid overflow on huge tokens. out_elem_bytes = 4 if out_is_f32 else 2 out_nbytes_idx = ( - tokens_in * n_in * fx.Index(out_elem_bytes) + tokens_in * n_in * arith.constant(out_elem_bytes, index=True) ) - if not bool(accumulate): + if const_expr(not bool(accumulate)): out_nbytes_idx = ( tokens_in * arith.index(topk) * n_in - * fx.Index(out_elem_bytes) + * arith.constant(out_elem_bytes, index=True) ) out_nbytes_i32 = arith.index_cast(T.i32, out_nbytes_idx) out_rsrc = buffer_ops.create_buffer_resource( @@ -1808,22 +3130,26 @@ def moe_gemm2( numids_rsrc = buffer_ops.create_buffer_resource( arg_num_valid_ids, max_size=False, - num_records_bytes=fx.Int32(4), + num_records_bytes=arith.constant(4, type=T.i32), ) num_valid_i32 = buffer_ops.buffer_load( - numids_rsrc, fx.Index(0), vec_width=1, dtype=T.i32 + numids_rsrc, arith.constant(0, index=True), vec_width=1, dtype=T.i32 ) - num_valid_idx = arith.index_cast(T.index, num_valid_i32) + # num_valid_ids is a scalar (same value for all lanes) loaded into + # VGPR. Promote to SGPR so downstream buffer resource descriptors + # that use it for num_records stay in SGPRs, eliminating the + # expensive waterfall loop the compiler would otherwise emit. + num_valid_i32 = rocdl.ReadfirstlaneOp(T.i32, num_valid_i32).res + num_valid_idx = arith.index_cast(ir.IndexType.get(), num_valid_i32) # fp16 path ignores scales completely (implicit scale=1.0). 
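# A quick plain-Python sanity check (made-up sizes) of the byte counts passed as
# num_records_bytes to the descriptors below: fp4 activations pack two elements
# per byte, and e8m0 microscales cost one byte per 32 K-elements of a sorted row.
tokens, topk, K, sorted_rows = 64, 8, 4096, 512          # illustrative only
a2_fp4_bytes   = tokens * topk * K // 2                   # arg_x descriptor (fp4)
a2_scale_bytes = sorted_rows * (K // 32)                  # arg_scale_x descriptor (e8m0)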
- if is_f16_a: - sx_rsrc = None - else: - if is_f4_a: - # A2 microscale: packed i32 holding e8m0 bytes for [sorted_size, K/32]. - c32 = fx.Index(32) - kblk = k_in // c32 - # Total bytes = num_valid_ids * kblk. + sx_rsrc = 1 + sw_rsrc = 1 + if const_expr(not is_f16_a): + if const_expr(is_f4_a or is_f8_a): + # A2 microscale: e8m0 in sorted layout [sorted_size, K/32]. + # Caller must pre-scatter a2_scale via moe_mxfp4_sort. + kblk = _div_pow2(k_in, 32) sx_nbytes_idx = num_valid_idx * kblk sx_nbytes_i32 = arith.index_cast(T.i32, sx_nbytes_idx) sx_rsrc = buffer_ops.create_buffer_resource( @@ -1831,31 +3157,28 @@ def moe_gemm2( ) else: # scale_x (A2 scale): [tokens*topk] f32 -> bytes = tokens*topk*4 - sx_nbytes_idx = (tokens_in * c_topk) * fx.Index(4) + sx_nbytes_idx = (tokens_in * c_topk) * arith.constant(4, index=True) sx_nbytes_i32 = arith.index_cast(T.i32, sx_nbytes_idx) sx_rsrc = buffer_ops.create_buffer_resource( arg_scale_x, max_size=False, num_records_bytes=sx_nbytes_i32 ) - if is_f16_b: - sw_rsrc = None - else: + if const_expr(not is_f16_b): # Weight microscale buffer (packed i32 holding e8m0 bytes). # Use an exact descriptor size so hardware OOB checking works. - c32 = fx.Index(32) - kblk_w = k_in // c32 # K/32 - mn_w = fx.Index(experts * model_dim) + kblk_w = _div_pow2(k_in, 32) # K/32 + mn_w = arith.constant(experts * model_dim, index=True) sw_nbytes_idx = mn_w * kblk_w # bytes (e8m0) sw_nbytes_i32 = arith.index_cast(T.i32, sw_nbytes_idx) sw_rsrc = buffer_ops.create_buffer_resource( arg_scale_w, max_size=False, num_records_bytes=sw_nbytes_i32 ) - # sorted_token_ids / sorted_weights: [blocks*tile_m] (CK-style padded length) + # sorted_token_ids / sorted_weights: [blocks*tile_m] (padded length) sorted_nbytes_idx = ( size_expert_ids_in - * fx.Index(tile_m) - * fx.Index(4) + * arith.constant(tile_m, index=True) + * arith.constant(4, index=True) ) sorted_nbytes_i32 = arith.index_cast(T.i32, sorted_nbytes_idx) sorted_rsrc = buffer_ops.create_buffer_resource( @@ -1867,8 +3190,14 @@ def moe_gemm2( arg_sorted_weights, max_size=False, num_records_bytes=sorted_nbytes_i32 ) - # expert ids: [blocks] i32 -> bytes = size_expert_ids_in*4 - eid_nbytes_idx = size_expert_ids_in * fx.Index(4) + # expert ids: [sort_blocks] i32. + _c_sbm = arith.constant(_sort_block_m, index=True) + _c_tm = arith.constant(tile_m, index=True) + _c1 = arith.constant(1, index=True) + _sort_blocks_ub = _div_pow2( + size_expert_ids_in * _c_tm + _c_sbm - _c1, _sort_block_m + ) + eid_nbytes_idx = _sort_blocks_ub * arith.constant(4, index=True) eid_nbytes_i32 = arith.index_cast(T.i32, eid_nbytes_idx) expert_rsrc = buffer_ops.create_buffer_resource( arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes_i32 @@ -1879,50 +3208,117 @@ def moe_gemm2( else None ) - # ---- persist_m loop ---- - _PERSIST_M = persist_m - _c0_p = arith.index(0) - _c1_p = arith.index(1) - _c_pm = arith.index(_PERSIST_M) - _for_persist = scf.ForOp(_c0_p, _c_pm, _c1_p) + # ---- persist loop ---- + _c0_p = arith.constant(0, index=True) + _c1_p = arith.constant(1, index=True) + + if const_expr(_persistent): + # Expert-phase scheduling: contiguous M-tile dispatch. + # grid_y = cu_num, each CTA handles a contiguous chunk of M-tiles: + # [bx_persist * tiles_per_block, ..., (bx_persist+1) * tiles_per_block - 1] + # Adjacent blocks process adjacent M-tiles -> same expert -> B weight L2 reuse. 
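# A tiny standalone model (plain Python, illustrative numbers) of the
# contiguous-chunk schedule described in the comment above: each of cu_num
# persistent workgroups owns ceil(total_m_tiles / cu_num) consecutive M-tiles;
# out-of-range tiles are skipped at runtime by the blk_valid guard.
def tiles_for_block(bx_persist, num_valid, tile_m, cu_num):
    total_m_tiles = (num_valid + tile_m - 1) // tile_m
    tiles_per_block = (total_m_tiles + cu_num - 1) // cu_num
    first = bx_persist * tiles_per_block
    return [t for t in range(first, first + tiles_per_block) if t < total_m_tiles]

# e.g. 10 M-tiles over 4 CUs -> chunks [0,1,2], [3,4,5], [6,7,8], [9]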
+ _c_cu = arith.constant(_cu_num, index=True) + _c_tm_p = arith.constant(tile_m, index=True) + _num_valid_idx = arith.index_cast(ir.IndexType.get(), num_valid_i32) + _total_m_tiles = (_num_valid_idx + _c_tm_p - _c1_p) / _c_tm_p + _tiles_per_block = (_total_m_tiles + _c_cu - _c1_p) / _c_cu + _i1 = ir.IntegerType.get_signless(1) + _init_active = arith.constant(1, type=_i1) + _for_persist = scf.ForOp(_c0_p, _tiles_per_block, _c1_p, [_init_active]) + else: + # Legacy mode: fixed persist_m consecutive tiles. + _c_pm = arith.constant(persist_m, index=True) + _init_prev_expert = arith.constant(0, type=T.i32) + _init_prev_b_base = arith.constant(0, index=True) + _for_persist = scf.ForOp( + _c0_p, + _c_pm, + _c1_p, + [_init_prev_expert, _init_prev_b_base], + ) + _for_ip = ir.InsertionPoint(_for_persist.body) _for_ip.__enter__() _mi_p = _for_persist.induction_variable - bx = bx_persist * _c_pm + _mi_p - bx_m = bx * fx.Index(tile_m) + + if const_expr(_persistent): + _still_active = _for_persist.inner_iter_args[0] + bx = bx_persist * _tiles_per_block + _mi_p + else: + _prev_expert_i32 = _for_persist.inner_iter_args[0] + _prev_expert_b_base = _for_persist.inner_iter_args[1] + bx = bx_persist * arith.constant(persist_m, index=True) + _mi_p + + bx_m = bx * arith.constant(tile_m, index=True) # Early-exit guard: skip garbage expert blocks beyond `num_valid_ids`. bx_m_i32 = arith.index_cast(T.i32, bx_m) - blk_valid = arith.cmpi(arith.CmpIPredicate.ult, bx_m_i32, num_valid_i32) + blk_valid = arith.cmpi(CmpIPredicate.ult, bx_m_i32, num_valid_i32) + sort_blk = _div_pow2(bx_m, _sort_block_m) expert_i32 = buffer_ops.buffer_load( - expert_rsrc, bx, vec_width=1, dtype=T.i32 + expert_rsrc, sort_blk, vec_width=1, dtype=T.i32 ) - expert_idx = arith.index_cast(T.index, expert_i32) + expert_idx = arith.index_cast(ir.IndexType.get(), expert_i32) exp_valid = arith.cmpi( - arith.CmpIPredicate.ult, expert_i32, fx.Int32(experts) + CmpIPredicate.ult, expert_i32, arith.constant(experts, type=T.i32) + ) + + if const_expr(_persistent): + # Absolute B-base: no cross-iteration state needed. + _expert_b_base = expert_idx * arith.constant( + _expert_b_stride, index=True + ) + else: + # Legacy incremental B-base: delta = (cur - prev) * stride + _delta_expert = arith.subi(expert_i32, _prev_expert_i32) + _delta_expert_idx = arith.index_cast(ir.IndexType.get(), _delta_expert) + _delta_b = _delta_expert_idx * arith.constant( + _expert_b_stride, index=True + ) + _expert_b_base = _prev_expert_b_base + _delta_b + + # Early-exit: if the first row of this tile is a sentinel (all-padding tile), + # skip the entire GEMM. + _first_tok = buffer_ops.buffer_load( + sorted_rsrc, bx_m, vec_width=1, dtype=T.i32 + ) + _first_tid = arith.andi(_first_tok, arith.constant(0xFFFFFF, type=T.i32)) + _tokens_i32_guard = arith.index_cast(T.i32, tokens_in) + tile_has_tokens = arith.cmpi( + CmpIPredicate.ult, _first_tid, _tokens_i32_guard ) + # For tile_m < 32 (pack_M < _scale_pack_m): shift a_scale i32 so the + # correct bytes land at the op_sel positions we use. + if const_expr(pack_M < _scale_pack_m): + _m_off = _mod_pow2(_div_pow2(bx_m, 16), _scale_pack_m) + _m_scale_shift_i32 = arith.index_cast( + T.i32, _m_off * arith.constant(8, index=True) + ) + else: + _m_scale_shift_i32 = None + def _moe_gemm2_then_body(): # Expert id for this M tile. 
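# A small plain-Python check (illustrative stride value) of the legacy
# incremental B-base carried through the persist loop above: starting from
# (prev_expert=0, prev_base=0), adding (expert - prev_expert) * expert_stride
# each iteration preserves the invariant base == expert * expert_stride.
def next_b_base(prev_expert, prev_base, expert, expert_stride):
    return prev_base + (expert - prev_expert) * expert_stride

base = prev = 0
for e in (0, 0, 3, 3, 7):
    base = next_b_base(prev, base, e, 1024)
    prev = e
    assert base == e * 1024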
- n_idx = fx.Index(model_dim) + n_idx = arith.constant(model_dim, index=True) expert_off_idx = expert_idx * n_idx # index # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- # Prefer 16B buffer-load (dwordx4). If the per-thread byte count isn't divisible by # 16, fall back to 8B (dwordx2) or 4B (dword) loads. For fp16 we require 16B. - if is_f16_a: - if bytes_per_thread_x % 16 != 0: + if const_expr(is_f16_a): + if const_expr(bytes_per_thread_x % 16 != 0): raise ValueError( f"[fp16] bytes_per_thread_x ({bytes_per_thread_x}) must be divisible by 16" ) x_load_bytes = 16 else: - if bytes_per_thread_x % 16 == 0: + if const_expr(bytes_per_thread_x % 16 == 0): x_load_bytes = 16 - elif bytes_per_thread_x % 8 == 0: + elif const_expr(bytes_per_thread_x % 8 == 0): x_load_bytes = 8 - elif bytes_per_thread_x % 4 == 0: + elif const_expr(bytes_per_thread_x % 4 == 0): x_load_bytes = 4 else: raise ValueError( @@ -1930,25 +3326,24 @@ def _moe_gemm2_then_body(): ) num_x_loads = bytes_per_thread_x // x_load_bytes chunk_i32 = x_load_bytes // 4 # dwords per chunk (1/2/4) - vec4_i32 = T.vec(4, T.i32) - vec2_i32 = T.vec(2, T.i32) - vec1_i32 = T.vec(1, T.i32) + vec4_i32 = T.vec(4, i32) - c_k_div4 = ( - (k_in // c_a_pack) * fx.Index(int(a_elem_bytes)) - ) // arith.index(4) - c_k_div4_i32 = arith.index_cast(T.i32, c_k_div4) + c_k_div4 = _div_pow2( + _div_pow2(k_in, int(a_elem_vec_pack)) + * arith.constant(int(a_elem_bytes), index=True), + 4, + ) tile_k_dwords = (int(tile_k) * int(a_elem_bytes)) // ( 4 * int(a_elem_vec_pack) ) layout_x_tile_div4 = fx.make_layout( (tile_m, tile_k_dwords), stride=(tile_k_dwords, 1) ) - c_chunk_i32 = fx.Index(chunk_i32) + c_chunk_i32 = arith.constant(chunk_i32, index=True) tx_i32_base = tx * c_chunk_i32 - topk_i32 = fx.Int32(topk) - mask24 = fx.Int32(0xFFFFFF) + topk_i32 = arith.constant(topk) + mask24 = arith.constant(0xFFFFFF) # Sentinel clamp uses `tokens` as the upper bound: t_valid = (t < tokens). tokens_i32 = arith.index_cast(T.i32, tokens_in) @@ -1962,6 +3357,8 @@ def x_tile_chunk_coord_i32(i: int): chunk_i32=chunk_i32, ) + vec1_i32 = T.vec(1, i32) + vec2_i32 = T.vec(2, i32) x_load_vec_elems = ( x_load_bytes if a_elem_bytes == 1 else x_load_bytes // a_elem_bytes ) @@ -1971,7 +3368,7 @@ def load_x(idx_i32): For 16B, keep the fast dwordx4 path. For 8B/4B, use byte offsets. """ - if x_load_bytes == 16: + if const_expr(x_load_bytes == 16): idx_elem = ( idx_i32 if a_elem_bytes == 1 else (idx_i32 * arith.index(2)) ) @@ -2009,69 +3406,81 @@ def load_x(idx_i32): fused_i = buffer_ops.buffer_load( sorted_rsrc, sorted_row_i, vec_width=1, dtype=T.i32 ) - t_i32 = fused_i & mask24 - s_i32 = fused_i >> fx.Int32(24) - # Keep `blk_valid` only; remove per-row token validity checks. 
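# A minimal sketch of the sorted_token_ids encoding decoded around here
# (assumed format, matching the masks in the code: low 24 bits = token id,
# high 8 bits = top-k slot). Padding rows carry a sentinel token id >= tokens
# and are filtered by the `t < tokens` check.
def pack(token, slot):
    return (slot << 24) | (token & 0xFFFFFF)

def unpack(fused, topk):
    token, slot = fused & 0xFFFFFF, fused >> 24
    return token * topk + slot          # flattened (token, slot) row index

assert unpack(pack(5, 3), topk=8) == 5 * 8 + 3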
- - t_valid = arith.cmpi(arith.CmpIPredicate.ult, t_i32, tokens_i32) - s_valid = arith.cmpi(arith.CmpIPredicate.ult, s_i32, topk_i32) - ts_valid = t_valid & s_valid - t_safe = arith.select(ts_valid, t_i32, fx.Int32(0)) - s_safe = arith.select(ts_valid, s_i32, fx.Int32(0)) + t_i32 = arith.andi(fused_i, mask24) + s_i32 = arith.shrui(fused_i, arith.constant(24)) + + t_valid = arith.cmpi(CmpIPredicate.ult, t_i32, tokens_i32) + s_valid = arith.cmpi(CmpIPredicate.ult, s_i32, topk_i32) + ts_valid = arith.andi(t_valid, s_valid) + t_safe = arith.select(ts_valid, t_i32, arith.constant(0)) + s_safe = arith.select(ts_valid, s_i32, arith.constant(0)) row_ts_i32 = t_safe * topk_i32 + s_safe - row_ts_idx = arith.index_cast(T.index, row_ts_i32) + row_ts_idx = arith.index_cast(ir.IndexType.get(), row_ts_i32) - # Base row offset in dword units: row_ts_idx * (k_in/4) x_row_base_div4.append(row_ts_idx * c_k_div4) def load_x_tile(base_k): - base_k_div4 = ( - (base_k // c_a_pack) - * fx.Index(int(a_elem_bytes)) - ) // arith.index(4) + base_k_div4 = _div_pow2( + _div_pow2(base_k, int(a_elem_vec_pack)) + * arith.constant(int(a_elem_bytes), index=True), + 4, + ) parts = [] for i in range_constexpr(num_x_loads): idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] x_vec = load_x(idx_i32) - if x_load_bytes == 16: + if const_expr(x_load_bytes == 16): parts.append(vector.bitcast(vec4_i32, x_vec)) - elif x_load_bytes == 8: + elif const_expr(x_load_bytes == 8): parts.append(vector.bitcast(vec2_i32, x_vec)) else: parts.append(vector.bitcast(vec1_i32, x_vec)) return parts # tx -> wave/lane (GEMM-style decomposition). - wave_id, lane_id = split_row_major_2d(tx, fx.Index(64)) - lane_div_16, lane_mod_16 = split_row_major_2d( - lane_id, fx.Index(16) - ) + coord_wl = idx2crd(tx, layout_tx_wave_lane) + wave_id = layout_get(coord_wl, 0) + lane_id = layout_get(coord_wl, 1) + coord_l16 = idx2crd(lane_id, layout_lane16) + lane_div_16 = layout_get(coord_l16, 0) + lane_mod_16 = layout_get(coord_l16, 1) row_a_lds = lane_mod_16 - col_offset_base = lane_div_16 * fx.Index(16) + col_offset_base = lane_div_16 * arith.constant(16, index=True) # Dynamic N tiling within block. - by_n = by * fx.Index(tile_n) num_waves = 4 n_per_wave = tile_n // num_waves num_acc_n = n_per_wave // 16 - c_n_per_wave = fx.Index(n_per_wave) - wave_mod_4 = wave_id % fx.Index(4) + c_n_per_wave = arith.constant(n_per_wave, index=True) + wave_mod_4 = _mod_pow2(wave_id, 4) n_tile_base = wave_mod_4 * c_n_per_wave - # Precompute (n_blk, n_intra) for B, and col indices for output. 
- n_intra_list = [] - n_blk_list = [] + by_n = by * arith.constant(tile_n, index=True) + + if const_expr(pack_N < _scale_pack_n): + _global_n_base = expert_off_idx + by_n + n_tile_base + _n_off = _mod_pow2(_div_pow2(_global_n_base, 16), _scale_pack_n) + _n_scale_shift_i32 = arith.index_cast( + T.i32, _n_off * arith.constant(8, index=True) + ) + else: + _n_scale_shift_i32 = None + n_intra_list = [None] * num_acc_n + n_blk_list = [None] * num_acc_n + col_g_list = [None] * num_acc_n for i in range_constexpr(num_acc_n): offset = i * 16 - c_offset = fx.Index(offset) + col_g = by_n + n_tile_base + col_g = _div_pow2(col_g, 2) + offset + col_g = col_g + lane_mod_16 + col_g_list[i] = col_g + c_offset = arith.constant(offset, index=True) global_n = by_n + n_tile_base + c_offset + lane_mod_16 - row_w = expert_off_idx + global_n - n_blk, n_intra = split_row_major_2d(row_w, fx.Index(16)) - n_blk_list.append(n_blk) - n_intra_list.append(n_intra) + n_blk_list[i] = _div_pow2(global_n, 16) + n_intra_list[i] = _mod_pow2(global_n, 16) m_repeat = tile_m // 16 k_unroll = tile_k_bytes // 128 # K64-byte micro-step (2x MFMA) @@ -2081,24 +3490,37 @@ def load_x_tile(base_k): m_repeat_packed = m_repeat // pack_M num_acc_n_packed = num_acc_n // pack_N + _K_per_ku_s2 = tile_k // k_unroll + _pad_k_elems_s2 = (inter_dim_pad % tile_k) if inter_dim_pad > 0 else 0 + _pad_ku_skip_s2 = _pad_k_elems_s2 // _K_per_ku_s2 + _tail_ku_s2 = k_unroll - _pad_ku_skip_s2 + _tail_ku_packed_s2 = ( + (_tail_ku_s2 + pack_K - 1) // pack_K + if _pad_ku_skip_s2 > 0 + else None + ) + # --- B Load Logic (K64) - shared layout with preshuffle GEMM --- def load_b_packs_k64(base_k, ku: int, ni: int): """Load one K64-byte B micro-step: single 16B load, split into 2x i64.""" - c64 = fx.Index(64) base_k_bytes = base_k * arith.constant( int(b_elem_bytes), index=True ) - k0_base = base_k_bytes // c64 - k0 = k0_base + fx.Index(ku) + k0_base = _div_pow2(base_k_bytes, 64) + k0 = k0_base + arith.constant(ku, index=True) k1 = lane_div_16 - coord_pack = ( - n_blk_list[ni], - k0, - k1, - n_intra_list[ni], - fx.Index(0), + # Incremental B addressing: _expert_b_base carries the + # expert's preshuffle offset (updated via delta each + # persist_m iteration); local n_blk/n_intra contribute + # the per-lane within-tile offset. All strides are + # compile-time constants -> shift/mul, no Barrett. 
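# A plain-Python sketch of the flattened B index computed just below. All
# strides are compile-time constants (values shown for 1-byte elements and
# 16-byte k-packs, as set up earlier in this file), so the address is a pure
# shift/multiply-add chain with no dynamic crd2idx or Barrett division.
def b_index(expert_b_base, n_blk, k0, k1, n_intra,
            stride_n0, stride_k0=1024, stride_klane=256, stride_nlane=16):
    return (expert_b_base                 # expert's preshuffle base offset
            + n_blk * stride_n0           # next 16-row N block
            + k0 * stride_k0              # next 64-byte K step
            + k1 * stride_klane           # one of four 16-byte K packs (lane_div_16)
            + n_intra * stride_nlane)     # row within the 16-row N block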
+ idx_pack = ( + _expert_b_base + + n_blk_list[ni] * arith.constant(_b_stride_n0, index=True) + + k0 * arith.constant(_b_stride_k0, index=True) + + k1 * arith.constant(_b_stride_klane, index=True) + + n_intra_list[ni] * arith.constant(_b_stride_nlane, index=True) ) - idx_pack = crd2idx(coord_pack, layout_b) vec_elems = kpack_bytes // int(b_elem_bytes) b16 = _buffer_load_vec( @@ -2106,10 +3528,11 @@ def load_b_packs_k64(base_k, ku: int, ni: int): vector, w_rsrc, idx_pack, - elem_type=_w_elem_type(is_f4_b=is_f4_b, is_f16_b=is_f16_b), + elem_type=_w_elem_type(), vec_elems=vec_elems, elem_bytes=b_elem_bytes, offset_in_bytes=(b_elem_bytes == 1), + cache_modifier=b_nt, ) b_i64x2 = vector.bitcast(vec2_i64, b16) b0 = vector.extract( @@ -2120,9 +3543,38 @@ def load_b_packs_k64(base_k, ku: int, ni: int): ) return b0, b1 - def load_b_tile(base_k): + def load_b_tile(base_k, ku_limit=k_unroll): b_tile = [] - for ku in range_constexpr(k_unroll): + for ku in range_constexpr(ku_limit): + packs0 = [] + packs1 = [] + for ni in range_constexpr(num_acc_n): + b0, b1 = load_b_packs_k64(base_k, ku, ni) + packs0.append(b0) + packs1.append(b1) + b_tile.append((packs0, packs1)) + return b_tile + + _b_split_enabled = k_unroll >= 2 + _b_split_ku = k_unroll // 2 if _b_split_enabled else k_unroll + + def load_b_tile_lo(base_k): + """Load first half of B tile (ku < _b_split_ku).""" + b_tile = [] + for ku in range_constexpr(_b_split_ku): + packs0 = [] + packs1 = [] + for ni in range_constexpr(num_acc_n): + b0, b1 = load_b_packs_k64(base_k, ku, ni) + packs0.append(b0) + packs1.append(b1) + b_tile.append((packs0, packs1)) + return b_tile + + def load_b_tile_hi(base_k): + """Load second half of B tile (ku >= _b_split_ku).""" + b_tile = [] + for ku in range_constexpr(_b_split_ku, k_unroll): packs0 = [] packs1 = [] for ni in range_constexpr(num_acc_n): @@ -2145,9 +3597,20 @@ def load_scale(arg_scale, rsrc, scale_info, ku, mni): s = buffer_ops.buffer_load(rsrc, idx_pack, vec_width=1, dtype=T.i32) return vector.from_elements(T.vec(1, T.i32), [s]) - def load_b_scale_tile(base_k): + def _apply_k_shift(scale_vec, k_shift_bits): + if const_expr(k_shift_bits > 0): + val = vector.extract( + scale_vec, static_position=[0], dynamic_position=[] + ) + val = arith.shrui(val, arith.constant(k_shift_bits, type=T.i32)) + return vector.from_elements(T.vec(1, T.i32), [val]) + return scale_vec + + def load_b_scale_tile( + base_k, k_shift_bits=0, ku_packed_limit=k_unroll_packed + ): b_scale_tile = [] - for ku in range_constexpr(k_unroll_packed): + for ku in range_constexpr(ku_packed_limit): for ni in range_constexpr(num_acc_n_packed): scale = load_scale( arg_scale_w, @@ -2155,63 +3618,84 @@ def load_b_scale_tile(base_k): layout_b_scale, ku + base_k, ni - + (expert_off_idx + by_n + n_tile_base) // pack_N // 16, + + _div_pow2( + _div_pow2( + expert_off_idx + by_n + n_tile_base, + _scale_pack_n, + ), + 16, + ), ) + scale = _apply_k_shift(scale, k_shift_bits) b_scale_tile.append(scale) return b_scale_tile - def load_a_scale_tile(base_k): + def load_a_scale_tile( + base_k, k_shift_bits=0, ku_packed_limit=k_unroll_packed + ): a_scale_tile = [] - for ku in range_constexpr(k_unroll_packed): + for ku in range_constexpr(ku_packed_limit): for mi in range_constexpr(m_repeat_packed): scale = load_scale( arg_scale_x, sx_rsrc, layout_a_scale, ku + base_k, - mi + bx_m // pack_M // 16, + mi + _div_pow2(_div_pow2(bx_m, _scale_pack_m), 16), ) + scale = _apply_k_shift(scale, k_shift_bits) a_scale_tile.append(scale) return a_scale_tile - def 
prefetch_ab_scale_tile(base_k): - return [load_a_scale_tile(base_k), load_b_scale_tile(base_k)] + def prefetch_ab_scale_tile( + base_k, k_shift_bits=0, ku_packed_limit=k_unroll_packed + ): + return [ + load_a_scale_tile( + base_k, k_shift_bits, ku_packed_limit=ku_packed_limit + ), + load_b_scale_tile( + base_k, k_shift_bits, ku_packed_limit=ku_packed_limit + ), + ] vec8_x = T.vec(vec8_elems, x_elem) vec4_x_lds = T.vec(vec4_elems, x_elem) - # ---- Pipeline helpers: store X tile to LDS with ping-pong base ---- - def store_x_tile_to_lds(vec_x_in_parts, lds_base): + # ---- Pipeline helpers: store X tile to LDS (unused in DMA path) ---- + _lds_base_zero = arith.index(0) + + def store_x_tile_to_lds(vec_x_in_parts, lds_buffer): for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if x_load_bytes == 16: + if const_expr(x_load_bytes == 16): lds_store_16b_xor16( arith, vector, - lds_memref=lds_x, + lds_memref=lds_buffer, vec16_ty=vec16_x, layout_lds=layout_lds, row_local=row_local, col_local_i32=col_local_i32, tx_c4=arith.index(4), k_blocks16=k_blocks16, - lds_base=lds_base, + lds_base=_lds_base_zero, vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) - elif x_load_bytes == 8: + elif const_expr(x_load_bytes == 8): lds_store_8b_xor16( arith, vector, - lds_memref=lds_x, + lds_memref=lds_buffer, vec8_ty=vec8_x, layout_lds=layout_lds, row_local=row_local, col_local_i32=col_local_i32, tx_c4=arith.index(4), k_blocks16=k_blocks16, - lds_base=lds_base, + lds_base=_lds_base_zero, vec_part_i32x2=vec_x_in_parts[i], elem_bytes=elem_bytes, ) @@ -2219,36 +3703,30 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base): lds_store_4b_xor16( arith, vector, - lds_memref=lds_x, + lds_memref=lds_buffer, vec4_ty=vec4_x_lds, layout_lds=layout_lds, row_local=row_local, col_local_i32=col_local_i32, tx_c4=arith.index(4), k_blocks16=k_blocks16, - lds_base=lds_base, + lds_base=_lds_base_zero, vec_part_i32x1=vec_x_in_parts[i], elem_bytes=elem_bytes, ) # --- A LDS load helper for K64 (load 16B once, extract 2x i64 halves) --- - def lds_load_packs_k64(curr_row_a_lds, col_base, lds_base): - # Swizzle in bytes, then convert to element offset for memref indexing. 
+ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): col_base_swz_bytes = swizzle_xor16( curr_row_a_lds, col_base, k_blocks16 ) col_base_swz = ( col_base_swz_bytes if elem_bytes == 1 - else (col_base_swz_bytes // arith.index(2)) + else (col_base_swz_bytes / arith.index(2)) ) - idx_a16 = lds_row_major_idx( - curr_row_a_lds, - col_base_swz, - fx.Index(lds_stride), - lds_base, - ) - loaded_a16 = vector.load_op(vec16_x, lds_x, [idx_a16]) + idx_a16 = crd2idx([curr_row_a_lds, col_base_swz], layout_lds) + loaded_a16 = vector.load_op(vec16_x, lds_buffer, [idx_a16]) a_i64x2 = vector.bitcast(vec2_i64, loaded_a16) a0 = vector.extract( a_i64x2, static_position=[0], dynamic_position=[] @@ -2261,38 +3739,47 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_base): def compute_tile( acc_in, b_tile_in, - lds_base, + lds_buffer, a_scale=None, b_scale=None, *, prefetch_epilogue: bool = False, a0_prefetch=None, + a1_prefetch=None, + b_hi_loader=None, + ku_count=k_unroll, ): + if const_expr(b_hi_loader is not None): + b_tile_full = [None] * k_unroll + for i in range_constexpr(_b_split_ku): + b_tile_full[i] = b_tile_in[i] + else: + b_tile_full = b_tile_in acc_list = list(acc_in) - mfma_res_ty = vec4_f32 + mfma_res_ty = vec4_i32 if is_int8 else vec4_f32 epilogue_pf = None bias = None - if prefetch_epilogue: - if enable_bias: + if const_expr(prefetch_epilogue): + if const_expr(enable_bias): bias = [] for ni in range_constexpr(num_acc_n): global_n = by_n + n_tile_base + ni * 16 + lane_mod_16 bias_offset = expert_off_idx + global_n bias.append( buffer_ops.buffer_load( - bias_rsrc, bias_offset, vec_width=1, dtype=T.f32 + bias_rsrc, bias_offset, vec_width=1, dtype=f32 ) ) tw_pf = None - if doweight_stage2: + if const_expr(doweight_stage2): tw_pf = [] lane_div_16_mul4_pf = lane_div_16 * arith.index(4) ii_idx_list_pf = [ - fx.Index(ii) for ii in range(4) + arith.constant(ii, index=True) for ii in range(4) ] for mi in range_constexpr(m_repeat): - mi_base_pf = fx.Index(mi * 16) + mi_base_pf = arith.constant(mi * 16, index=True) for ii in range_constexpr(4): row_off_pf = ( lane_div_16_mul4_pf + ii_idx_list_pf[ii] @@ -2304,7 +3791,7 @@ def compute_tile( sorted_w_rsrc, sorted_row_pf, vec_width=1, - dtype=T.f32, + dtype=f32, ) ) epilogue_pf = (None, tw_pf, bias) @@ -2317,13 +3804,34 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) return vector.bitcast(vec8_i32, v4) - # fp4 path - for ku128 in range_constexpr(k_unroll_packed): + # fp4 path -- single k_idx loop [0, k_unroll). + # b_hi load is issued at the very start so all k_unroll + # MFMAs can overlap the VMEM latency. 
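The k_idx decomposition right below relies on pack_K being a power of two; a minimal host-side sketch of the shift/mask split (illustrative values, not kernel defaults):

# Sketch: shift/mask decomposition of the flattened k_idx, assuming pack_K is a power of two.
pack_K = 2
_pack_K_shift = (pack_K - 1).bit_length()   # 1
_pack_K_mask = pack_K - 1                   # 0b1
for k_idx in range(8):                      # k_unroll = 8 in this sketch
    ku128 = k_idx >> _pack_K_shift          # index into the packed scale tile
    ikxdl = k_idx & _pack_K_mask            # position inside one 128-wide K pack
    assert (ku128, ikxdl) == divmod(k_idx, pack_K)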
+ _pack_K_shift = (pack_K - 1).bit_length() + _pack_K_mask = pack_K - 1 + + if const_expr(b_hi_loader is not None): + _b_hi = b_hi_loader() + for _bhi_i in range_constexpr(len(_b_hi)): + b_tile_full[_b_split_ku + _bhi_i] = _b_hi[_bhi_i] + + for k_idx in range_constexpr(ku_count): + ku128 = k_idx >> _pack_K_shift + ikxdl = k_idx & _pack_K_mask + + b_packs0, b_packs1 = b_tile_full[k_idx] + + col_base = col_offset_base + (k_idx * 128) // a_elem_vec_pack + for mi in range_constexpr(m_repeat_packed): a_scale_i32 = a_scale[ku128 * m_repeat_packed + mi] a_scale_val = vector.extract( a_scale_i32, static_position=[0], dynamic_position=[] ) + if const_expr(_m_scale_shift_i32 is not None): + a_scale_val = arith.shrui( + a_scale_val, _m_scale_shift_i32 + ) for ni in range_constexpr(num_acc_n_packed): b_scale_i32 = b_scale[ku128 * num_acc_n_packed + ni] b_scale_val = vector.extract( @@ -2331,135 +3839,151 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): static_position=[0], dynamic_position=[], ) - for ikxdl in range_constexpr(pack_K): - k_idx = ku128 * pack_K + ikxdl - - b_packs0, b_packs1 = b_tile_in[k_idx] - - col_base = ( - col_offset_base - + (k_idx * 128) // a_elem_vec_pack + if const_expr(_n_scale_shift_i32 is not None): + b_scale_val = arith.shrui( + b_scale_val, _n_scale_shift_i32 ) - for imxdl in range_constexpr(pack_M): - col_base0 = col_base - mi_idx = mi * pack_M + imxdl - mi_val = fx.Index(mi_idx * 16) - curr_row_a_lds = row_a_lds + mi_val - - if ( - (a0_prefetch is not None) - and (k_idx == 0) - and (mi_idx == 0) - ): - a0, a1 = a0_prefetch - else: - a0, a1 = lds_load_packs_k64( - curr_row_a_lds, col_base0, lds_base - ) + for imxdl in range_constexpr(pack_M): + col_base0 = col_base + mi_idx = mi * pack_M + imxdl + mi_val = arith.constant(mi_idx * 16, index=True) + curr_row_a_lds = row_a_lds + mi_val + + if const_expr( + (a0_prefetch is not None) + and (k_idx == 0) + and (mi_idx == 0) + ): + a0, a1 = a0_prefetch + elif const_expr( + (a1_prefetch is not None) + and (k_idx == 1) + and (mi_idx == 0) + ): + a0, a1 = a1_prefetch + else: + a0, a1 = lds_load_packs_k64( + curr_row_a_lds, col_base0, lds_buffer + ) - if is_f8_a: - col_base1 = col_base + 64 - a2, a3 = lds_load_packs_k64( - curr_row_a_lds, col_base1, lds_base - ) - a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) - else: - a128 = pack_i64x4_to_i32x8( - a0, a1, c0_i64, c0_i64 - ) + if const_expr(is_f8_a): + col_base1 = col_base + 64 + a2, a3 = lds_load_packs_k64( + curr_row_a_lds, col_base1, lds_buffer + ) + a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + else: + a128 = pack_i64x4_to_i32x8( + a0, a1, c0_i64, c0_i64 + ) - for inxdl in range_constexpr(pack_N): - ni_idx = ni * pack_N + inxdl + for inxdl in range_constexpr(pack_N): + ni_idx = ni * pack_N + inxdl - b0 = b_packs0[ni_idx] - b1 = b_packs1[ni_idx] - b128 = pack_i64x4_to_i32x8( - b0, b1, c0_i64, c0_i64 - ) + b0 = b_packs0[ni_idx] + b1 = b_packs1[ni_idx] + b128 = pack_i64x4_to_i32x8( + b0, b1, c0_i64, c0_i64 + ) - acc_idx = mi_idx * num_acc_n + ni_idx - rocdl.sched_barrier(0) - acc_list[acc_idx] = ( - rocdl.mfma_scale_f32_16x16x128_f8f6f4( - mfma_res_ty, - [ - a128, - b128, - acc_list[acc_idx], - cbsz, - blgp, - ikxdl * pack_M + imxdl, - a_scale_val, - ikxdl * pack_N + inxdl, - b_scale_val, - ], - ) + acc_idx = mi_idx * num_acc_n + ni_idx + acc_list[acc_idx] = ( + rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + b128, + acc_list[acc_idx], + cbsz, + blgp, + ikxdl * _scale_pack_m + imxdl, + a_scale_val, + ikxdl * _scale_pack_n + inxdl, + b_scale_val, + ], ) + ) return 
acc_list, epilogue_pf # ---------------- 2-stage pipeline (ping-pong LDS + B tile prefetch) ---------------- - lds_tile_elems = fx.Index(tile_m * lds_stride) - lds_base_cur = arith.index(0) - lds_base_nxt = lds_tile_elems + # ---- Async DMA: GMEM -> LDS (bypasses VGPR, like stage1) ---- + _dma_bytes = 16 + _wave_size = 64 + _eff_bytes_per_buffer = ( + int(tile_m) * int(_eff_lds_stride) * int(a_elem_bytes) + ) + _num_dma_loads = max( + 1, _eff_bytes_per_buffer // (total_threads * _dma_bytes) + ) + + def dma_x_tile_to_lds(base_k, lds_buffer): + c4_idx = arith.index(4) + base_k_div4 = _div_pow2( + _div_pow2(base_k, int(a_elem_vec_pack)) + * arith.constant(int(a_elem_bytes), index=True), + 4, + ) + + lds_ptr_i64 = None + for i in range_constexpr(_num_dma_loads): + row_local_i = x_row_local[i] + col_local_i32_i = x_col_local_i32[i] + col_local_sw = swizzle_xor16( + row_local_i, col_local_i32_i * c4_idx, k_blocks16 + ) + row_k_dw = x_row_base_div4[i] + base_k_div4 + global_byte_idx = row_k_dw * c4_idx + col_local_sw + global_offset = arith.index_cast(T.i32, global_byte_idx) + + if const_expr(i == 0): + lds_addr = memref.extract_aligned_pointer_as_index( + lds_buffer + ) + wave_id * arith.constant( + _wave_size * _dma_bytes, index=True + ) + lds_ptr_i64 = rocdl.readfirstlane( + T.i64, arith.index_cast(T.i64, lds_addr) + ) + else: + lds_ptr_i64 = lds_ptr_i64 + arith.constant( + total_threads * _dma_bytes, type=T.i64 + ) + + lds_ptr_type = ir.Type.parse("!llvm.ptr<3>") + lds_ptr = llvm.inttoptr(lds_ptr_type, lds_ptr_i64) + + rocdl.raw_ptr_buffer_load_lds( + x_rsrc, + lds_ptr, + arith.constant(_dma_bytes, type=T.i32), + global_offset, + arith.constant(0, type=T.i32), + arith.constant(0, type=T.i32), + arith.constant(0, type=T.i32), + ) + + def prefetch_x_to_lds(base_k, lds_buffer): + dma_x_tile_to_lds(base_k, lds_buffer) rocdl.sched_barrier(0) def hot_loop_scheduler(): - # - MFMA group size per "slot": num_acc_n - # - Total MFMA per tile: (2*K32 per K64) * k_unroll * m_repeat * num_acc_n - # - We emit (mfma_group + dsrd + mfma_group) per scheduler iteration. - mfma_group = num_acc_n - mfma_total = (k_unroll * 2) * m_repeat * mfma_group - mfma_per_iter = 2 * mfma_group - sche_iters = ( - 0 if mfma_per_iter == 0 else (mfma_total // mfma_per_iter) - ) - - rocdl.sched_dsrd(2) - rocdl.sched_mfma(1) - if tile_m == 16: - rocdl.sched_vmem(1) - rocdl.sched_mfma(1) - if tile_m == 16: - rocdl.sched_vmem(1) - if num_acc_n < 4: - rocdl.sched_dsrd(1) - rocdl.sched_mfma(1) - if tile_m == 16: - rocdl.sched_vmem(1) - rocdl.sched_dsrd(1) - rocdl.sched_mfma(1) - if tile_m == 16: - rocdl.sched_vmem(1) - rocdl.sched_mfma(1) - - # DS-write hints near the end: match total A LDS-store micro-ops per thread. - dswr_tail = num_x_loads - if dswr_tail > sche_iters: - dswr_tail = sche_iters - dswr_start = sche_iters - dswr_tail - - for sche_i in range_constexpr(sche_iters): - rocdl.sched_vmem(1) - rocdl.sched_mfma(mfma_group) - rocdl.sched_dsrd(1) - rocdl.sched_mfma(mfma_group) - if sche_i >= dswr_start - 1: - rocdl.sched_dswr(1) - rocdl.sched_barrier(0) - # Prologue. 
- k0 = arith.index(0) - x_regs0 = load_x_tile(k0) - b_cur = load_b_tile(k0) - a_scale_pong, b_scale_pong = prefetch_ab_scale_tile(k0 // pack_K // 128) - store_x_tile_to_lds(x_regs0, lds_base_cur) + def _k_shift_bits(k_py): + if const_expr(pack_K >= _scale_pack_k): + return 0 + return ((k_py // 128) % _scale_pack_k) * _scale_pack_m * 8 + + def _k_base(k_py): + return k_py // _scale_pack_k // 128 + # Preload sorted_idx into lds_tid for epilogue precompute_row - _c_tile_m_idx = fx.Index(tile_m) - _tid_in_range = arith.cmpi(arith.CmpIPredicate.ult, tx, _c_tile_m_idx) + # (N-independent; placed before N-tile loop so it's done once per M-tile.) + _c_tile_m_idx = arith.constant(tile_m, index=True) + _tid_in_range = arith.cmpi(CmpIPredicate.ult, tx, _c_tile_m_idx) _if_tid = scf.IfOp(_tid_in_range) with ir.InsertionPoint(_if_tid.then_block): _tid_row = bx_m + tx @@ -2469,17 +3993,36 @@ def hot_loop_scheduler(): _tid_vec1 = vector.from_elements(T.vec(1, T.i32), [_tid_val]) vector.store(_tid_vec1, lds_tid, [tx]) scf.YieldOp([]) + + gpu.barrier() + + # Prologue -- B-first + async DMA X(0) -> pong. + k0 = arith.index(0) + if const_expr(_b_split_enabled): + b_cur = load_b_tile_lo(k0) + else: + b_cur = load_b_tile(k0) + a_scale_pong, b_scale_pong = prefetch_ab_scale_tile( + _k_base(0), _k_shift_bits(0) + ) + rocdl.sched_barrier(0) + prefetch_x_to_lds(k0, lds_x_pong) + rocdl.s_waitcnt(0) gpu.barrier() acc = [acc_init] * num_acc_n * m_repeat - lds_base_pong = lds_base_cur - lds_base_ping = lds_base_nxt - # Cross-tile A0 LDS prefetch (default-on): prefetch the first A-pack (K64) for the - # tile we are about to compute from LDS, to overlap with upcoming VMEM. + # Cross-tile A0+A1 LDS prefetch from pong buffer. a0_prefetch_pong = lds_load_packs_k64( - row_a_lds, col_offset_base, lds_base_pong + row_a_lds, col_offset_base, lds_x_pong ) + _a1_col_base = col_offset_base + 128 // a_elem_vec_pack + a1_prefetch_pong = ( + lds_load_packs_k64(row_a_lds, _a1_col_base, lds_x_pong) + if pack_K >= 2 + else None + ) + # Main loop: process K tiles in 2-tile ping-pong steps. # # IMPORTANT: for odd number of K tiles, leave **1** tail tile; for even, leave **2**. @@ -2489,121 +4032,211 @@ def hot_loop_scheduler(): odd_k_tiles = (num_k_tiles_py % 2) == 1 tail_tiles = 1 if odd_k_tiles else 2 k_main2_py = (num_k_tiles_py - tail_tiles) * int(tile_k) - if k_main2_py < 0: + if const_expr(k_main2_py < 0): k_main2_py = 0 - c2_tile_k = fx.Index(tile_k * 2) + c2_tile_k = arith.constant(tile_k * 2, index=True) b_pong = b_cur + k0_pong_bk = k0 + # Only emit the scf.for when there are actually iterations to run. # When k_main2_py == 0 the loop body is empty; emitting an scf.for # would create a region whose internal SSA values cannot be used # by the post-loop tail code. 
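A worked example of the tail arithmetic above, assuming a hypothetical K of 7168 with tile_k = 512 (numbers chosen for illustration only):

# Sketch of the odd/even tail split for K = 7168, tile_k = 512.
tile_k = 512
num_k_tiles = 7168 // tile_k                     # 14 tiles -> even
tail_tiles = 1 if num_k_tiles % 2 else 2         # -> 2
k_main2 = (num_k_tiles - tail_tiles) * tile_k    # -> 6144: ping-pong loop covers tiles 0..11
k_tail1 = (7168 + tile_k - 1) // tile_k * tile_k - tile_k   # -> 6656: base of the last tile
# At loop exit the pong buffer already holds the tile at 6144, so the even-K
# tail only has to stream the final tile at 6656 into the ping buffer.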
- if k_main2_py > 0: + def _make_b_hi_loader(base_k): + """Create a b_hi_loader callable for a given base_k.""" + return lambda _bk=base_k: load_b_tile_hi(_bk) + + if const_expr(k_main2_py > 0): for k_iv_py in range_constexpr(0, k_main2_py, tile_k * 2): - k_iv = k_iv_py + rocdl.sched_barrier(0) + k_iv = arith.index(k_iv_py) next_k1 = k_iv + tile_k - x_regs_ping = load_x_tile(next_k1) - b_ping = load_b_tile(next_k1 // 2) + next_k1_bk = next_k1 // 2 + # DMA X(next_k1) -> ping (non-blocking, overlaps with compute) + prefetch_x_to_lds(next_k1, lds_x_ping) + b_ping_lo = ( + load_b_tile_lo(next_k1_bk) + if _b_split_enabled + else load_b_tile(next_k1_bk) + ) a_scale_ping, b_scale_ping = prefetch_ab_scale_tile( - next_k1 // pack_K // 128 + _k_base(next_k1), _k_shift_bits(next_k1) ) acc, _ = compute_tile( acc, b_pong, - lds_base_pong, + lds_x_pong, a_scale_pong, b_scale_pong, a0_prefetch=a0_prefetch_pong, + a1_prefetch=a1_prefetch_pong, + b_hi_loader=( + _make_b_hi_loader(k0_pong_bk) + if _b_split_enabled + else None + ), ) - store_x_tile_to_lds(x_regs_ping, lds_base_ping) - # hot_loop_scheduler() + hot_loop_scheduler() + rocdl.s_waitcnt(0) gpu.barrier() # Cross-tile prefetch for the ping tile we are about to compute. a0_prefetch_ping = lds_load_packs_k64( - row_a_lds, col_offset_base, lds_base_ping + row_a_lds, col_offset_base, lds_x_ping + ) + a1_prefetch_ping = ( + lds_load_packs_k64(row_a_lds, _a1_col_base, lds_x_ping) + if pack_K >= 2 + else None ) next_k2 = k_iv + c2_tile_k - x_regs_pong = load_x_tile(next_k2) - b_pong = load_b_tile(next_k2 // 2) + next_k2_py = k_iv_py + tile_k * 2 + next_k2_bk = next_k2 // 2 + # DMA X(next_k2) -> pong (non-blocking, overlaps with compute) + prefetch_x_to_lds(next_k2, lds_x_pong) + b_pong = ( + load_b_tile_lo(next_k2_bk) + if _b_split_enabled + else load_b_tile(next_k2_bk) + ) a_scale_pong, b_scale_pong = prefetch_ab_scale_tile( - next_k2 // pack_K // 128 + _k_base(next_k2_py), _k_shift_bits(next_k2_py) ) acc, _ = compute_tile( acc, - b_ping, - lds_base_ping, + b_ping_lo, + lds_x_ping, a_scale_ping, b_scale_ping, a0_prefetch=a0_prefetch_ping, + a1_prefetch=a1_prefetch_ping, + b_hi_loader=( + _make_b_hi_loader(next_k1_bk) + if _b_split_enabled + else None + ), ) - store_x_tile_to_lds(x_regs_pong, lds_base_pong) - # hot_loop_scheduler() + k0_pong_bk = next_k2_bk + hot_loop_scheduler() gpu.barrier() # Cross-tile prefetch for the next pong tile. a0_prefetch_pong = lds_load_packs_k64( - row_a_lds, col_offset_base, lds_base_pong + row_a_lds, col_offset_base, lds_x_pong + ) + a1_prefetch_pong = ( + lds_load_packs_k64(row_a_lds, _a1_col_base, lds_x_pong) + if pack_K >= 2 + else None ) - if odd_k_tiles: - # Tail: single remaining tile (already in `b_cur` / `lds_base_pong`). + if const_expr(odd_k_tiles): + # Tail: single remaining tile (already in pong buffer). acc, epilogue_pf = compute_tile( acc, b_pong, - lds_base_pong, + lds_x_pong, a_scale_pong, b_scale_pong, a0_prefetch=a0_prefetch_pong, + a1_prefetch=a1_prefetch_pong, prefetch_epilogue=True, + b_hi_loader=( + _make_b_hi_loader(k0_pong_bk) if _b_split_enabled else None + ), + ku_count=_tail_ku_s2 if _pad_ku_skip_s2 > 0 else k_unroll, ) else: # Tail: 2 remaining tiles. 
k_tail1 = (k_in + tile_k - 1) // tile_k * tile_k - tile_k - x_regs_ping = load_x_tile(k_tail1) - b_ping = load_b_tile(k_tail1 // 2) - a_scale_ping, b_scale_ping = prefetch_ab_scale_tile( - k_tail1 // pack_K // 128 - ) + k_tail1_py = ( + int(inter_dim) + tile_k - 1 + ) // tile_k * tile_k - tile_k + k_tail1_bk = k_tail1 // 2 + # DMA tail X -> ping + prefetch_x_to_lds(k_tail1, lds_x_ping) + if const_expr(_pad_ku_skip_s2 > 0): + b_ping_lo = load_b_tile(k_tail1_bk, ku_limit=_tail_ku_s2) + a_scale_ping, b_scale_ping = prefetch_ab_scale_tile( + _k_base(k_tail1_py), + _k_shift_bits(k_tail1_py), + ku_packed_limit=_tail_ku_packed_s2, + ) + else: + b_ping_lo = ( + load_b_tile_lo(k_tail1_bk) + if _b_split_enabled + else load_b_tile(k_tail1_bk) + ) + a_scale_ping, b_scale_ping = prefetch_ab_scale_tile( + _k_base(k_tail1_py), _k_shift_bits(k_tail1_py) + ) acc, _ = compute_tile( acc, b_pong, - lds_base_pong, + lds_x_pong, a_scale_pong, b_scale_pong, a0_prefetch=a0_prefetch_pong, + a1_prefetch=a1_prefetch_pong, + b_hi_loader=( + _make_b_hi_loader(k0_pong_bk) if _b_split_enabled else None + ), ) - store_x_tile_to_lds(x_regs_ping, lds_base_ping) # hot_loop_scheduler() + rocdl.s_waitcnt(0) gpu.barrier() # Epilogue tile with sw prefetch. a0_prefetch_ping = lds_load_packs_k64( - row_a_lds, col_offset_base, lds_base_ping + row_a_lds, col_offset_base, lds_x_ping + ) + a1_prefetch_ping = ( + lds_load_packs_k64(row_a_lds, _a1_col_base, lds_x_ping) + if pack_K >= 2 and (_pad_ku_skip_s2 == 0 or _tail_ku_s2 >= 2) + else None ) acc, epilogue_pf = compute_tile( acc, - b_ping, - lds_base_ping, + b_ping_lo, + lds_x_ping, a_scale_ping, b_scale_ping, a0_prefetch=a0_prefetch_ping, + a1_prefetch=a1_prefetch_ping, prefetch_epilogue=True, + b_hi_loader=( + None + if _pad_ku_skip_s2 > 0 + else ( + _make_b_hi_loader(k_tail1_bk) + if _b_split_enabled + else None + ) + ), + ku_count=_tail_ku_s2 if _pad_ku_skip_s2 > 0 else k_unroll, ) # ---------------- Epilogue: LDS CShuffle + atomic half2 (x2) ---------------- # Reuse the shared helper so GEMM / MoE kernels share the exact same CShuffle skeleton. - model_i32 = fx.Int32(model_dim) - zero_i32 = fx.Int32(0) - c2_i32 = fx.Int32(2) - mask_even_i32 = fx.Int32(0xFFFFFFFE) + sw_pf = None + tw_pf = None + bias_pf = None + if const_expr(epilogue_pf is not None): + sw_pf, tw_pf, bias_pf = epilogue_pf + + mask24_i32 = arith.constant(0xFFFFFF) + topk_i32_v = topk_i32 + + zero_i32 = arith.constant(0) def atomic_add_f16x2(val_f16x2, byte_off_i32): rocdl.raw_ptr_buffer_atomic_fadd( @@ -2614,24 +4247,24 @@ def atomic_add_f16x2(val_f16x2, byte_off_i32): zero_i32, ) - sw_pf = None - tw_pf = None - bias_pf = None - if epilogue_pf is not None: - sw_pf, tw_pf, bias_pf = epilogue_pf - - mask24_i32 = fx.Int32(0xFFFFFF) - topk_i32_v = topk_i32 - # Weight scales for the N tile (col_g depends on lane/wave/by but not on (t,s)). - if lds_out is None: + if const_expr(lds_out is None): raise RuntimeError( - "FLYDSL_MOE_STAGE2_CSHUFFLE=1 but lds_out is not allocated/aliased." + "FLIR_MOE_STAGE2_CSHUFFLE=1 but lds_out is not allocated/aliased." ) - out_base_idx = None - if _needs_global_atomic_bf16: - out_base_idx = buffer_ops.extract_base_index(arg_out) + # Precompute the output base address (i64 index) for ALL paths. + # Both accumulate=True (global atomic) and accumulate=False (global store) + # need 64-bit addressing to avoid i32 offset overflow when + # tokens * model_dim * elem_bytes > INT32_MAX (~150K tokens for model_dim=7168). 
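A quick arithmetic check of the bound quoted in the comment above, assuming a 2-byte (fp16/bf16) output element:

# INT32_MAX / (model_dim * out_elem_bytes) is the largest row count addressable with i32 byte offsets.
INT32_MAX = 2**31 - 1
model_dim_example, out_elem_bytes_example = 7168, 2
max_rows_i32 = INT32_MAX // (model_dim_example * out_elem_bytes_example)   # 149_796 rows (~150K)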
+ from flydsl._mlir.dialects import fly as _fly + + _llvm_ptr_ty = ir.Type.parse("!llvm.ptr") + out_base_ptr = _fly.extract_aligned_pointer_as_index( + _llvm_ptr_ty, arg_out + ) + out_base_i64 = llvm.ptrtoint(T.i64, out_base_ptr) + out_base_idx = arith.index_cast(ir.IndexType.get(), out_base_i64) def write_row_to_lds( *, @@ -2644,13 +4277,27 @@ def write_row_to_lds( num_acc_n: int, lds_out, ): - if doweight_stage2: + # Match origin/dev_a16w4: rely on sentinel padded rows + hardware OOB behavior. + fused2 = buffer_ops.buffer_load( + sorted_rsrc, row, vec_width=1, dtype=T.i32 + ) + t2 = fused2 & mask24_i32 + s2 = fused2 >> 24 + + t_ok = arith.cmpi(CmpIPredicate.ult, t2, tokens_i32) + s_ok = arith.cmpi(CmpIPredicate.ult, s2, topk_i32_v) + ts_ok = arith.andi(t_ok, s_ok) + t2_safe = arith.select(ts_ok, t2, arith.constant(0)) + s2_safe = arith.select(ts_ok, s2, arith.constant(0)) + t2_safe * topk_i32_v + s2_safe + + if const_expr(doweight_stage2): tw_idx = (mi * 4) + ii - if tw_pf is not None: + if const_expr(tw_pf is not None): tw = tw_pf[tw_idx] else: tw = buffer_ops.buffer_load( - sorted_w_rsrc, row, vec_width=1, dtype=T.f32 + sorted_w_rsrc, row, vec_width=1, dtype=f32 ) for ni in range_constexpr(num_acc_n): @@ -2659,10 +4306,12 @@ def write_row_to_lds( v = vector.extract( acc[acc_idx], static_position=[ii], dynamic_position=[] ) - if enable_bias: + if const_expr(is_int8): + v = arith.sitofp(f32, v) + if const_expr(enable_bias): v = v + bias_pf[ni] - if doweight_stage2: + if const_expr(doweight_stage2): v = v * tw v_out = arith.trunc_f(out_elem(), v) @@ -2677,52 +4326,69 @@ def precompute_row(*, row_local, row): # to avoid extra VMEM round-trips in the epilogue. fused2 = memref.load(lds_tid, [row_local]) row_i32 = arith.index_cast(T.i32, row) - row_valid0 = arith.cmpi(arith.CmpIPredicate.ult, row_i32, num_valid_i32) + row_valid0 = arith.cmpi(CmpIPredicate.ult, row_i32, num_valid_i32) t = fused2 & mask24_i32 s = fused2 >> 24 - t_ok = arith.cmpi(arith.CmpIPredicate.ult, t, tokens_i32) - s_ok = arith.cmpi(arith.CmpIPredicate.ult, s, topk_i32_v) - row_valid = row_valid0 & (t_ok & s_ok) + t_ok = arith.cmpi(CmpIPredicate.ult, t, tokens_i32) + s_ok = arith.cmpi(CmpIPredicate.ult, s, topk_i32_v) + row_valid = arith.andi(row_valid0, arith.andi(t_ok, s_ok)) + t_idx = arith.index_cast(ir.IndexType.get(), t) + s_idx = arith.index_cast(ir.IndexType.get(), s) + ts_idx = t_idx * arith.constant(topk, index=True) + s_idx + if const_expr(accumulate): + row_byte_base = out_base_idx + t_idx * arith.constant( + model_dim * out_elem_bytes, index=True + ) + else: + row_byte_base = out_base_idx + ts_idx * arith.constant( + model_dim * out_elem_bytes, index=True + ) + return ((fused2, row_byte_base), row_valid) - return (fused2, row_valid) + def _idx_to_llvm_ptr(idx_val, addr_space=1): + """Convert an index-typed byte address to !llvm.ptr.""" + idx_v = idx_val._value if hasattr(idx_val, "_value") else idx_val + i64_v = arith.index_cast(T.i64, idx_v) + i64_raw = i64_v._value if hasattr(i64_v, "_value") else i64_v + ptr_ty = ir.Type.parse(f"!llvm.ptr<{addr_space}>") + return llvm.inttoptr(ptr_ty, i64_raw) def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): - fused = row_ctx - t = fused & mask24_i32 - s = fused >> 24 - idx0 = t * model_i32 - if not bool(accumulate): - ts = t * topk_i32_v + s - idx0 = ts * model_i32 - col_i32 = arith.index_cast(T.i32, col_g0) - idx_elem = idx0 + col_i32 - idx_elem_even = idx_elem & mask_even_i32 - if _needs_global_atomic_bf16: - if bool(accumulate): - byte_off = 
idx_elem_even * c2_i32 - byte_off_idx = arith.index_cast(T.index, byte_off) - ptr_addr_idx = out_base_idx + byte_off_idx - out_ptr = buffer_ops.create_llvm_ptr(ptr_addr_idx, address_space=1) - out_ptr_v = out_ptr._value if hasattr(out_ptr, "_value") else out_ptr - frag_v = frag._value if hasattr(frag, "_value") else frag - - llvm.AtomicRMWOp( - llvm.AtomicBinOp.fadd, - out_ptr_v, - frag_v, - llvm.AtomicOrdering.monotonic, - syncscope="agent", - alignment=4, - ) - else: - buffer_ops.buffer_store(frag, out_rsrc, idx_elem_even) + fused, row_byte_base = row_ctx + if const_expr(not bool(accumulate)): + # ---- 64-bit global store path (avoids i32 offset overflow) ---- + col_idx = col_g0 + byte_off_col = col_idx * arith.constant( + out_elem_bytes, index=True + ) + ptr_addr_idx = row_byte_base + byte_off_col + out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) + frag_v = frag._value if hasattr(frag, "_value") else frag + llvm.StoreOp( + frag_v, + out_ptr_v, + alignment=_e_vec * out_elem_bytes, + nontemporal=True, + ) else: - byte_off = idx_elem_even * c2_i32 - if bool(accumulate): - atomic_add_f16x2(frag, byte_off) - else: - buffer_ops.buffer_store(frag, out_rsrc, idx_elem_even) + # ---- accumulate=True: 64-bit global atomic path ---- + col_idx = col_g0 + byte_off_col = col_idx * arith.constant( + out_elem_bytes, index=True + ) + ptr_addr_idx = row_byte_base + byte_off_col + out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) + frag_v = frag._value if hasattr(frag, "_value") else frag + llvm.AtomicRMWOp( + llvm.AtomicBinOp.fadd, + out_ptr_v, + frag_v, + llvm.AtomicOrdering.monotonic, + syncscope="agent", + alignment=_e_vec * out_elem_bytes, + ) + _e_vec = 2 if accumulate else min(tile_n // 32, 8) c_shuffle_epilog( arith=arith, vector=vector, @@ -2731,7 +4397,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): range_constexpr=range_constexpr, tile_m=tile_m, tile_n=tile_n, - e_vec=2, + e_vec=_e_vec, m_repeat=m_repeat, num_acc_n=num_acc_n, tx=tx, @@ -2741,22 +4407,38 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): by_n=by_n, n_tile_base=n_tile_base, lds_out=lds_out, - frag_elem_type=(T.bf16 if out_is_bf16 else T.f16), + frag_elem_type=( + ir.BF16Type.get() if out_is_bf16 else ir.F16Type.get() + ), write_row_to_lds=write_row_to_lds, precompute_row=precompute_row, store_pair=store_pair, ) - _if_blk = scf.IfOp(blk_valid) - with ir.InsertionPoint(_if_blk.then_block): - _ifexpert_of = scf.IfOp(exp_valid) - with ir.InsertionPoint(_ifexpert_of.then_block): + _all_valid = arith.andi(blk_valid, arith.andi(exp_valid, tile_has_tokens)) + + if const_expr(_persistent): + # Short-circuit: contiguous tiles are monotonically increasing, + # so once bx_m >= num_valid_ids all remaining tiles are invalid. 
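A host-side analogue of this short-circuit, with illustrative tile numbers (the kernel carries the flag across iterations of the persistent loop):

# Sketch: once one tile fails the bx_m < num_valid_ids check, every later tile is skipped.
num_valid_tiles = 5
still_active = True
for bx_tile in (1, 3, 6, 9):                  # one wave's tiles, handed out in increasing order
    blk_valid = bx_tile < num_valid_tiles
    cur_active = still_active and blk_valid
    # the GEMM body runs only while cur_active is True (tiles 1 and 3 here)
    still_active = cur_active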
+ _cur_active = arith.andi(_still_active, blk_valid) + _do_gemm = arith.andi( + _cur_active, arith.andi(exp_valid, tile_has_tokens) + ) + _if_valid = scf.IfOp(_do_gemm) + with ir.InsertionPoint(_if_valid.then_block): _moe_gemm2_then_body() scf.YieldOp([]) - scf.YieldOp([]) - gpu.barrier() - scf.YieldOp([]) + gpu.barrier() + scf.YieldOp([_cur_active]) + else: + _if_valid = scf.IfOp(_all_valid) + with ir.InsertionPoint(_if_valid.then_block): + _moe_gemm2_then_body() + scf.YieldOp([]) + + gpu.barrier() + scf.YieldOp([expert_i32, _expert_b_base]) _for_ip.__exit__(None, None, None) # -- Host launcher (flyc.jit + .launch) -------------------------------- @@ -2775,6 +4457,9 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): inter_dim_pad, use_cshuffle_epilog, persist_m, + _sort_block_m, + _cu_num if _persistent else 0, + xcd_swizzle, ) @flyc.jit @@ -2795,19 +4480,29 @@ def launch_mixed_moe_gemm2( i32_size_expert_ids_in: fx.Int32, stream: fx.Stream, ): - allocator.finalized = False + _ = _cache_tag + allocator_pong.finalized = False + allocator_ping.finalized = False ctx = CompilationContext.get_current() with ir.InsertionPoint(ctx.gpu_module_body): - allocator.finalize() - - n_in = arith.index_cast(T.index, i32_n_in) - gx = n_in // fx.Index(tile_n) - _c_pm_l = fx.Index(persist_m) - gy = ( - arith.index_cast(T.index, i32_size_expert_ids_in) - + _c_pm_l - - fx.Index(1) - ) / _c_pm_l + allocator_pong.finalize() + allocator_ping.finalize() + + n_in = arith.index_cast(ir.IndexType.get(), i32_n_in.ir_value()) + _tile_n_idx = arith.constant(tile_n, index=True) + _model_dim_pad_idx = arith.constant(model_dim_pad, index=True) + gx = ( + n_in - _model_dim_pad_idx + _tile_n_idx - arith.constant(1, index=True) + ) / _tile_n_idx + if _persistent: + gy = arith.constant(_cu_num, index=True) + else: + _c_pm_l = arith.constant(persist_m, index=True) + gy = ( + arith.index_cast(ir.IndexType.get(), i32_size_expert_ids_in.ir_value()) + + _c_pm_l + - arith.constant(1, index=True) + ) / _c_pm_l moe_gemm2( arg_out, diff --git a/kernels/silu_and_mul_fq.py b/kernels/silu_and_mul_fq.py new file mode 100644 index 00000000..a802bf92 --- /dev/null +++ b/kernels/silu_and_mul_fq.py @@ -0,0 +1,419 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Fused silu_and_mul + optional quantization + sorted-scale write kernel (FlyDSL). + +Designed for split-K MOE stage1 post-processing: + + input : tmp_out (token_num * topk, inter_dim * 2) bf16 + sorted : sorted_token_ids (sorted_len,) i32 -- packed (token<<0 | slot<<24) + num_valid_ids (1,) i32 + output : out raw byte buffer + * quant_mode="fp4" -> FP4x2 packed, row stride = inter_dim//2 + * quant_mode="fp8" -> MXFP8 (e4m3fn) bytes, row stride = inter_dim + * quant_mode="none" -> bf16, row stride = inter_dim * 2 + out_scale_sorted raw byte buffer -- tiled E8M0 scale + (only written when quant_mode in {"fp4","fp8"}; same tiled layout + as ``moe_mxfp4_sort``) + +Grid: (num_sorted_rows, 1, 1) -- one workgroup per sorted row (including blockM padding). +Block: (BLOCK_THREADS, 1, 1) + +Each workgroup: + 1. Loads sorted_token_ids[bid] -> (token_id, slot_id) -> row = token_id * topk + slot_id + 2. If bid < num_valid_ids (valid row): + a. Reads gate/up depending on gui_layout: + gui_layout=False -> gate at col [0:inter_dim], up at col [inter_dim:2*inter_dim] + gui_layout=True -> block-interleaved per-16 (gate[0:16], up[0:16], gate[16:32], ...) + b. Computes silu(gate) * up in f32 + c. 
Per-1x32 MXFP4/MXFP8 quant -> writes packed data + E8M0 scale in tiled layout, + or (quant_mode="none") writes bf16 directly + 3. If bid >= num_valid_ids (blockM padding row): + a. quant_mode in {fp4,fp8}: writes zero E8M0 scale (keeps tiled layout consistent) + b. quant_mode=="none": no-op + +All arithmetic uses FlyDSL high-level APIs: +``fx_math.absf`` / ``rocdl.exp2`` / ``rocdl.rcp`` / ``ArithValue.shuffle_xor`` / +``rocdl.cvt_pk_fp8_f32`` / ``vector.bitcast`` / ``vector.truncf``. +""" + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl.expr import arith, vector, rocdl, range_constexpr +from flydsl.expr import buffer_ops, math as fx_math +from flydsl.expr.typing import T +from flydsl.expr.arith import ArithValue +from flydsl.compiler.kernel_function import CompilationContext + +from flydsl._mlir import ir + +from kernels.kernels_common import get_warp_size + +BLOCK_THREADS = 256 +WARP_SIZE = get_warp_size() + + +def _make_scale_tiled_layout(scale_cols_val): + """Build hierarchical 2-D layout for sorted E8M0 scale bytes. + + Uses flydsl's hierarchical shape to express the tiled decomposition:: + + row -> (row%16, (row//16)%2, row//32) + col -> (col%4, (col//4)%2, col//8) + + flydsl decomposes hierarchical shapes innermost-first (left-to-right), + so shape ``(16, 2, N)`` yields ``(idx%16, (idx//16)%2, idx//32)``. + + Strides: ``(4, 1, n32_sort, 64, 2, 256)`` where + ``n32_sort = scale_cols * 32``. + """ + n32_sort = scale_cols_val * 32 + return fx.make_layout( + ((16, 2, 32), (4, 2, 8)), + stride=((4, 1, n32_sort), (64, 2, arith.constant(256, type=T.i32))), + ) + + +def _scale_byte_offset(layout_scale, row, col32): + """Compute byte offset for one E8M0 scale element via layout algebra.""" + result = fx.crd2idx(fx.make_coord(row, col32), layout_scale) + scalar = fx.get_scalar(result) + if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType): + scalar = arith.index_cast(T.index, scalar) + return ArithValue(scalar) + + +def build_silu_and_mul_fq_module( + inter_dim: int, + topk: int, + quant_mode: str = "fp4", + gui_layout: bool = False, +): + """Return a JIT launcher for fused silu_and_mul + optional quant + scale sort. + + Parameters + ---------- + inter_dim : int + Output columns of stage1 (after activation). Input has ``inter_dim*2`` cols. + Must be divisible by 32 (MXFP4/MXFP8 block size). + topk : int + Number of expert slots per token. + quant_mode : str + One of ``"fp4"`` (default, MXFP4 e2m1 + e8m0 scale), + ``"fp8"`` (MXFP8 e4m3fn + e8m0 scale), + or ``"none"`` (bf16 passthrough, no scale written). + The old 2-argument call ``build_silu_and_mul_fq_module(inter_dim, topk)`` + keeps the original MXFP4 semantics. + gui_layout : bool + ``False`` (default): input is gate-up separated [gate_0:N | up_0:N]. + ``True``: input is block-interleaved per-16 + [gate_0:16, up_0:16, gate_16:32, ...]; requires ``VEC <= 16``. 
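        For example (illustrative numbers, not a kernel default): ``inter_dim=2048``
        gives ``VEC = max(ceil(2048 / 256), 2) = 8``, and with ``gui_layout=True``
        the logical activation column 40 maps into the interleaved input as::

            gate_col = (40 // 16) * 32 + (40 % 16)   # = 72
            up_col   = gate_col + 16                 # = 88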
+ """ + assert inter_dim % 32 == 0, f"inter_dim={inter_dim} must be divisible by 32" + if quant_mode not in ("fp4", "fp8", "none"): + raise ValueError( + f"quant_mode must be one of ('fp4','fp8','none'), got {quant_mode!r}" + ) + _need_fp4 = quant_mode == "fp4" + _need_fp8 = quant_mode == "fp8" + _need_quant = _need_fp4 or _need_fp8 + _fp_headroom = 2 if _need_fp4 else 8 # only used when _need_quant + + scale_cols = inter_dim // 32 + ELEMS_PER_THREAD = (inter_dim + BLOCK_THREADS - 1) // BLOCK_THREADS + VEC = max(ELEMS_PER_THREAD, 2) + if VEC % 2 != 0: + VEC += 1 + assert 32 % VEC == 0, f"VEC={VEC} must divide 32 evenly" + if gui_layout: + assert VEC <= 16, f"VEC={VEC} must be <=16 for gui (block-interleave) layout" + THREADS_PER_QUANT_BLK = 32 // VEC + SHUFFLE_DISTS = [] + d = 1 + while d < THREADS_PER_QUANT_BLK: + SHUFFLE_DISTS.append(d) + d *= 2 + + fp4_row_bytes = inter_dim // 2 + _fp4_pack_bytes = VEC // 2 + + @flyc.kernel + def silu_and_mul_fq_kernel( + x: fx.Tensor, + out_buf: fx.Tensor, + out_scale_sorted: fx.Tensor, + sorted_ids: fx.Tensor, + num_valid_ids: fx.Tensor, + token_num: fx.Int32, + ): + bid = fx.block_idx.x + tid = fx.thread_idx.x + + vec_f32_ty = T.vec(VEC, T.f32) + + # ── Layout API: buffer-backed tensor for structured input ──── + X_buf = fx.rocdl.make_buffer_tensor(x) + copy_atom_vec = fx.make_copy_atom( + fx.rocdl.BufferCopy(VEC * 16), 16 # bf16 = 16 bits + ) + vec_reg_ty = fx.MemRefType.get( + T.bf16, fx.LayoutType.get(VEC, 1), fx.AddressSpace.Register + ) + vec_reg_lay = fx.make_layout(VEC, 1) + + def load_vec(div_tensor, idx): + r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) + fx.copy_atom_call(copy_atom_vec, fx.slice(div_tensor, (None, idx)), r) + return fx.memref_load_vec(r) + + # ── Buffer resources for flat byte buffers and scalar loads ─── + out_rsrc = buffer_ops.create_buffer_resource(out_buf, max_size=True) + scale_rsrc = buffer_ops.create_buffer_resource(out_scale_sorted, max_size=True) + tid_rsrc = buffer_ops.create_buffer_resource(sorted_ids, max_size=True) + nv_rsrc = buffer_ops.create_buffer_resource(num_valid_ids, max_size=True) + + num_valid = buffer_ops.buffer_load(nv_rsrc, fx.Int32(0), vec_width=1, dtype=T.i32) + bid_i32 = ArithValue(bid) + + fused_tid_val = buffer_ops.buffer_load(tid_rsrc, bid_i32, vec_width=1, dtype=T.i32) + token_id = fused_tid_val & 0xFFFFFF + slot_id = fused_tid_val >> 24 + is_valid = (bid_i32 < num_valid) & (token_id < ArithValue(token_num)) & (slot_id < topk) + + layout_scale = _make_scale_tiled_layout( + ArithValue(arith.constant(scale_cols, type=T.i32)) + ) + + def _store_scale(scale_rsrc, layout_scale, bid_i32, col0, val_i8): + if (col0 & 31) == fx.Int32(0): + s_off = _scale_byte_offset(layout_scale, bid_i32, col0 >> 5) + buffer_ops.buffer_store( + val_i8, scale_rsrc, s_off, offset_is_bytes=True, + ) + + def _f32_to_e2m1(qx_f32): + """Convert a scaled f32 value to fp4 (e2m1) 4-bit integer.""" + qx = qx_f32.bitcast(T.i32) + s = qx & 0x80000000 + e = (qx >> 23) & 0xFF + m = qx & 0x7FFFFF + c0_i32 = arith.constant(0, type=T.i32) + c126 = arith.constant(126, type=T.i32) + adj_exp = arith.maxsi(c126 - e, c0_i32) + m_denorm = (0x400000 | (m >> 1)) >> adj_exp + m = (e < arith.constant(127, type=T.i32)).select(m_denorm, m) + e = arith.maxsi(e - c126, c0_i32) + rounded = ((e << 2) | (m >> 21)) + 1 >> 1 + e2m1 = arith.minui(rounded, arith.constant(7, type=T.i32)) + return (s >> 28) | e2m1 + + thread_id = ArithValue(tid) + COLS_PER_ITER = BLOCK_THREADS * VEC + c0_f32 = arith.constant(0.0, type=T.f32) + + for iter_idx in 
range_constexpr( + (inter_dim + COLS_PER_ITER - 1) // COLS_PER_ITER + ): + col0 = thread_id * VEC + iter_idx * COLS_PER_ITER + + if col0 < inter_dim: + if is_valid: + in_row = token_id * topk + slot_id + + row_x = fx.slice(X_buf, (in_row, None)) + row_div = fx.logical_divide(row_x, fx.make_layout(VEC, 1)) + tile_idx = tid + iter_idx * BLOCK_THREADS + + if gui_layout: + # Block-interleaved per 16: + # [gate_0:16, up_0:16, gate_16:32, up_16:32, ...] + # VEC <= 16 guarantees one VEC-wide load is entirely + # within a single gate-16 (or up-16) chunk. + # gate_col = (col0 // 16) * 32 + (col0 % 16) + # up_col = gate_col + 16 + block16 = col0 >> 4 + off16 = col0 & 15 + gate_col = block16 * 32 + off16 + up_col = gate_col + 16 + gate_tile_idx = gate_col // VEC + up_tile_idx = up_col // VEC + else: + gate_tile_idx = tile_idx + up_tile_idx = tile_idx + inter_dim // VEC + + gate_f32 = load_vec(row_div, gate_tile_idx).extf(vec_f32_ty) + up_f32 = load_vec(row_div, up_tile_idx).extf(vec_f32_ty) + + # ── SiLU(gate) * up ────────────────────────────── + neg_log2e = arith.constant(-1.4426950408889634, type=T.f32) + c1_f32 = arith.constant(1.0, type=T.f32) + act_vals = [] + for vi in range_constexpr(VEC): + g = vector.extract(gate_f32, static_position=[vi], dynamic_position=[]) + u = vector.extract(up_f32, static_position=[vi], dynamic_position=[]) + emu = ArithValue(rocdl.exp2(T.f32, g * neg_log2e)) + sig = ArithValue(rocdl.rcp(T.f32, c1_f32 + emu)) + act_vals.append(g * sig * u) + + if _need_quant: + # ── Per-32-block max for E8M0 scale ────────── + local_max = c0_f32 + for vi in range_constexpr(VEC): + local_max = local_max.maximumf(fx_math.absf(act_vals[vi])) + + for sh_dist in SHUFFLE_DISTS: + local_max = local_max.maximumf( + local_max.shuffle_xor(fx.Int32(sh_dist), fx.Int32(WARP_SIZE)) + ) + + # ── Compute e8m0 bias + quant_scale (fp4: h=2, fp8: h=8) ── + exp_field = ((local_max.bitcast(T.i32) + 0x200000) & 0xFF800000) >> 23 + e8m0_biased = arith.maxsi( + exp_field - arith.constant(_fp_headroom, type=T.i32), + arith.constant(0, type=T.i32), + ) + quant_scale = ( + (arith.constant(254, type=T.i32) - e8m0_biased) << 23 + ).bitcast(T.f32) + + if _need_fp4: + fp4_vals = [] + for vi in range_constexpr(VEC): + fp4_vals.append(_f32_to_e2m1(act_vals[vi] * quant_scale)) + + packed_i32 = fp4_vals[0] | (fp4_vals[1] << 4) + for k in range_constexpr(1, VEC // 2): + byte_k = fp4_vals[2 * k] | (fp4_vals[2 * k + 1] << 4) + packed_i32 = packed_i32 | (byte_k << (k * 8)) + + fp4_byte_off = in_row * fp4_row_bytes + (col0 >> 1) + _pack_type = {1: T.i8, 2: T.i16}.get(_fp4_pack_bytes, T.i32) + packed = ( + arith.trunci(_pack_type, packed_i32) + if _fp4_pack_bytes < 4 + else packed_i32 + ) + buffer_ops.buffer_store( + packed, out_rsrc, fp4_byte_off, offset_is_bytes=True, + ) + else: + # MXFP8 (e4m3fn) output, row stride = inter_dim bytes. + # Use rocdl.cvt_pk_fp8_f32 to pack 2 f32 -> 2 fp8 bytes + # per dword (word_sel selects low/high pair). 
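For reference, the per-1x32 E8M0 bias and quant scale computed a few lines above can be mirrored on the host; a minimal sketch assuming a finite, non-negative block maximum and the same headroom constants (2 for fp4, 8 for fp8):

import struct

def e8m0_scale_ref(block_max: float, headroom: int):
    """Host-side mirror of the in-kernel E8M0 bias / quant-scale bit math (sketch)."""
    bits = struct.unpack("<I", struct.pack("<f", block_max))[0]
    exp_field = ((bits + 0x200000) & 0xFF800000) >> 23
    e8m0_biased = max(exp_field - headroom, 0)
    quant_scale = struct.unpack("<f", struct.pack("<I", (254 - e8m0_biased) << 23))[0]
    return e8m0_biased, quant_scale

# e8m0_scale_ref(6.0, 2) == (127, 1.0): a block whose max already equals the
# e2m1 max (6.0) needs no rescaling before fp4 conversion.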
+ scaled = [act_vals[vi] * quant_scale for vi in range_constexpr(VEC)] + fp8_byte_off = in_row * inter_dim + col0 + if VEC <= 4: + packed_i32 = arith.constant(0, type=T.i32) + for w in range_constexpr(VEC // 2): + packed_i32 = rocdl.cvt_pk_fp8_f32( + T.i32, + scaled[2 * w], + scaled[2 * w + 1], + packed_i32, + w, + ) + if VEC == 2: + buffer_ops.buffer_store( + arith.trunci(T.i16, packed_i32), + out_rsrc, + fp8_byte_off, + offset_is_bytes=True, + ) + else: + buffer_ops.buffer_store( + packed_i32, + out_rsrc, + fp8_byte_off, + offset_is_bytes=True, + ) + else: + # VEC > 4: pack 4 f32 -> 1 dword per group + for wg in range_constexpr(VEC // 4): + base = wg * 4 + packed_w = arith.constant(0, type=T.i32) + packed_w = rocdl.cvt_pk_fp8_f32( + T.i32, + scaled[base], + scaled[base + 1], + packed_w, + 0, + ) + packed_w = rocdl.cvt_pk_fp8_f32( + T.i32, + scaled[base + 2], + scaled[base + 3], + packed_w, + 1, + ) + word_off = fp8_byte_off + wg * 4 + buffer_ops.buffer_store( + packed_w, + out_rsrc, + word_off, + offset_is_bytes=True, + ) + + _store_scale(scale_rsrc, layout_scale, bid_i32, col0, + arith.trunci(T.i8, e8m0_biased)) + else: + # quant_mode == "none": write bf16 out directly. + # out row stride = inter_dim * 2 bytes. + act_f32_vec = vector.from_elements(vec_f32_ty, act_vals) + act_f32_av = ArithValue(act_f32_vec) + act_bf16_vec = act_f32_av.truncf(T.vec(VEC, T.bf16)) + # Write as packed i32 (VEC/2 dwords). + vec_dw = VEC // 2 # each dword = 2 bf16 elems + if vec_dw >= 1: + act_i32 = vector.bitcast(T.vec(vec_dw, T.i32), act_bf16_vec) + bf16_byte_off = in_row * (inter_dim * 2) + col0 * 2 + if vec_dw == 1: + store_val = vector.extract( + act_i32, static_position=[0], dynamic_position=[] + ) + buffer_ops.buffer_store( + store_val, + out_rsrc, + bf16_byte_off, + offset_is_bytes=True, + ) + else: + buffer_ops.buffer_store( + act_i32, + out_rsrc, + bf16_byte_off, + offset_is_bytes=True, + ) + else: + # Invalid (padding) row. Only zero-fill the e8m0 scale + # when a scale tile exists (fp4/fp8); quant_mode="none" + # has no scale buffer to maintain. + if _need_quant: + _store_scale(scale_rsrc, layout_scale, bid_i32, col0, + arith.constant(0, type=T.i8)) + + @flyc.jit + def launch_silu_and_mul_fq( + x: fx.Tensor, + out_buf: fx.Tensor, + out_scale_sorted: fx.Tensor, + sorted_ids: fx.Tensor, + num_valid_ids: fx.Tensor, + token_num: fx.Int32, + num_sorted_rows: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + pass + + idx_rows = ArithValue(num_sorted_rows).index_cast(T.index) + launcher = silu_and_mul_fq_kernel( + x, out_buf, out_scale_sorted, sorted_ids, num_valid_ids, token_num + ) + launcher.launch( + grid=(idx_rows, 1, 1), + block=(BLOCK_THREADS, 1, 1), + stream=stream, + ) + + return launch_silu_and_mul_fq diff --git a/scripts/run_benchmark.sh b/scripts/run_benchmark.sh index c8da0150..10325084 100755 --- a/scripts/run_benchmark.sh +++ b/scripts/run_benchmark.sh @@ -136,6 +136,17 @@ MOE_W4A16_SHAPES=' 512,7168,256,384,8,16,128,128,128,256 ' +# MoE A8W4 shapes (FP8 activation + MX-FP4 weight, gfx950 only): same format as MOE_SHAPES. +# GPT-OSS inspired: model_dim=3072, inter_dim=3072, E=128, topk=4; sweep tokens from 512 to +# bracket memory- and compute-bound regimes. tile_m>=32 / tile_k>=256 are MX-FP4 layout requirements. 
+MOE_A8W4_SHAPES=' +512,3072,3072,128,4,32,128,256,256,256 +1024,3072,3072,128,4,32,128,256,256,256 +2048,3072,3072,128,4,32,128,256,256,256 +4096,3072,3072,128,4,32,128,256,256,256 +8192,3072,3072,128,4,32,128,256,256,256 +' + # Memory bound threshold (M or tokens <= threshold => memory bound) MEMORY_BOUND_THRESHOLD=512 @@ -192,14 +203,14 @@ print_bound_info() { # Print one-line perf row (like run_tests.sh style). _fmt_table_header() { # Use fixed widths and truncate long strings to keep columns aligned. - # (bash printf supports precision on %s: %-W.Ps) - printf "\n%-14.14s %-34.34s %-10.10s %10s %10s\n" "op" "shape" "dtype" "TB/s" "TFLOPS" - printf "%-14.14s %-34.34s %-10.10s %10s %10s\n" "--------------" "----------------------------------" "----------" "----------" "----------" + # op column is wide enough to host "moe__s2_atomic" / "_reduce" suffixes. + printf "\n%-22.22s %-34.34s %-10.10s %10s %10s\n" "op" "shape" "dtype" "TB/s" "TFLOPS" + printf "%-22.22s %-34.34s %-10.10s %10s %10s\n" "----------------------" "----------------------------------" "----------" "----------" "----------" } _emit_row() { op="$1"; shape="$2"; dtype="$3"; tbps="$4"; tflops="$5" - printf "%-14.14s %-34.34s %-10.10s %10s %10s\n" "${op}" "${shape}" "${dtype}" "${tbps}" "${tflops}" + printf "%-22.22s %-34.34s %-10.10s %10s %10s\n" "${op}" "${shape}" "${dtype}" "${tbps}" "${tflops}" } _normalize_op() { @@ -364,6 +375,51 @@ print(f"{op}\t{shape}\t{dtype}\t{fmt(tbps)}\t{fmt(tflops)}") PY } +_emit_moe_s2_rows() { + # Args: op_prefix shape log_path + # Extract separate atomic/reduce rows from MoE stage2 log lines. A line looks like: + # FlyDSL MoE stage2 [moe_gemm2] fp4 atomic | 7168x2048, ... | 1163.2 us, 1654.24 TFLOPS, 0.377 TB/s + # Emit two table rows (op_prefix_atomic, op_prefix_reduce). Falls back to single row + # tagged "mixed" if the log only has one mode (e.g., --gemm2_mode was overridden). + op_prefix="$1"; shape="$2"; log_path="$3" + python3 - "$op_prefix" "$shape" "$log_path" <<'PY' +import re, sys + +op_prefix, shape, path = sys.argv[1], sys.argv[2], sys.argv[3] +try: + with open(path, "r", errors="ignore") as f: + txt = f.read() +except Exception: + txt = "" + +pat = re.compile( + r"FlyDSL MoE stage2 \[[^]]+\]\s+(\S+)\s+(atomic|reduce)\b.*?" + r"([0-9.]+)\s*TFLOPS.*?([0-9.]+)\s*TB/s" +) +# keep last occurrence per mode +found = {} +for m in pat.finditer(txt): + dtype, mode = m.group(1), m.group(2) + found[mode] = (dtype, float(m.group(3)), float(m.group(4))) + +def fmt(x): + return "-" if x is None else f"{x:.3f}" + +# Always emit atomic row first (if any), then reduce row. +emitted = False +for mode in ("atomic", "reduce"): + if mode not in found: + continue + dtype, tflops, tbps = found[mode] + print(f"{op_prefix}_{mode}\t{shape}\t{dtype}\t{fmt(tbps)}\t{fmt(tflops)}") + emitted = True + +if not emitted: + # Nothing parsed — emit empty row so caller knows. 
+ print(f"{op_prefix}_atomic\t{shape}\t-\t-\t-") +PY +} + # ============================================================================ # Run Benchmarks # ============================================================================ @@ -671,12 +727,9 @@ if [ "${RUN_MOE}" -eq 1 ] && [ "${IS_CDNA}" = "true" ]; then _emit_row "moe_gemm1" "${shape_moe}" "${dt_s1}" "${tb_s1}" "${tf_s1}" fi - dt_s2="$(grep -Eo 'FlyDSL MoE stage2 \[[^]]+\] [^ ]+' "${log}" | tail -1 | awk '{print $NF}' || true)" - tf_s2="$(grep -Eo 'FlyDSL MoE stage2 .* ([0-9.]+) TFLOPS' "${log}" | tail -1 | awk '{print $(NF-1)}' || true)" - tb_s2="$(grep -Eo 'FlyDSL MoE stage2 .* ([0-9.]+) TB/s' "${log}" | tail -1 | awk '{print $(NF-1)}' || true)" - if [ -n "${dt_s2}" ] && [ -n "${tf_s2}" ] && [ -n "${tb_s2}" ]; then - _emit_row "moe_gemm2" "${shape_moe}" "${dt_s2}" "${tb_s2}" "${tf_s2}" - fi + _emit_moe_s2_rows "moe_gemm2" "${shape_moe}" "${log}" | while IFS="$(printf '\t')" read -r _op _sh _dt _tb _tf; do + _emit_row "${_op}" "${_sh}" "${_dt}" "${_tb}" "${_tf}" + done done # MoE FP4 (gfx950 only) @@ -720,12 +773,9 @@ if [ "${RUN_MOE}" -eq 1 ] && [ "${IS_CDNA}" = "true" ]; then _emit_row "moe_fp4_s1" "${shape_moe}" "${dt_s1}" "${tb_s1}" "${tf_s1}" fi - dt_s2="$(grep -Eo 'FlyDSL MoE stage2 \[[^]]+\] [^ ]+' "${log}" | tail -1 | awk '{print $NF}' || true)" - tf_s2="$(grep -Eo 'FlyDSL MoE stage2 .* ([0-9.]+) TFLOPS' "${log}" | tail -1 | awk '{print $(NF-1)}' || true)" - tb_s2="$(grep -Eo 'FlyDSL MoE stage2 .* ([0-9.]+) TB/s' "${log}" | tail -1 | awk '{print $(NF-1)}' || true)" - if [ -n "${dt_s2}" ] && [ -n "${tf_s2}" ] && [ -n "${tb_s2}" ]; then - _emit_row "moe_fp4_s2" "${shape_moe}" "${dt_s2}" "${tb_s2}" "${tf_s2}" - fi + _emit_moe_s2_rows "moe_fp4_s2" "${shape_moe}" "${log}" | while IFS="$(printf '\t')" read -r _op _sh _dt _tb _tf; do + _emit_row "${_op}" "${_sh}" "${_dt}" "${_tb}" "${_tf}" + done fi else # Skip gracefully on unsupported architectures @@ -780,11 +830,64 @@ if [ "${RUN_MOE}" -eq 1 ] && [ "${IS_CDNA}" = "true" ]; then _emit_row "moe_w4a16_s1" "${shape_moe}" "${dt_s1}" "${tb_s1}" "${tf_s1}" fi - dt_s2="$(grep -Eo 'FlyDSL MoE stage2 \[[^]]+\] [^ ]+' "${log}" | tail -1 | awk '{print $NF}' || true)" - tf_s2="$(grep -Eo 'FlyDSL MoE stage2 .* ([0-9.]+) TFLOPS' "${log}" | tail -1 | awk '{print $(NF-1)}' || true)" - tb_s2="$(grep -Eo 'FlyDSL MoE stage2 .* ([0-9.]+) TB/s' "${log}" | tail -1 | awk '{print $(NF-1)}' || true)" - if [ -n "${dt_s2}" ] && [ -n "${tf_s2}" ] && [ -n "${tb_s2}" ]; then - _emit_row "moe_w4a16_s2" "${shape_moe}" "${dt_s2}" "${tb_s2}" "${tf_s2}" + _emit_moe_s2_rows "moe_w4a16_s2" "${shape_moe}" "${log}" | while IFS="$(printf '\t')" read -r _op _sh _dt _tb _tf; do + _emit_row "${_op}" "${_sh}" "${_dt}" "${_tb}" "${_tf}" + done + done + + # MoE A8W4 — FP8 activation + MX-FP4 weight (gfx950 only). End-to-end 2-stage: + # stage1 a_dtype=fp8,b_dtype=fp4 -> silu(gate)*up fp16 -> MX-FP8 re-quant -> stage2. 
+ for shape in $MOE_A8W4_SHAPES; do + [ -z "$shape" ] && continue + oldIFS=$IFS + IFS=, + # shellcheck disable=SC2086 # intentional word-splitting on IFS=, + set -- $shape + IFS=$oldIFS + tokens=$1; model_dim=$2; inter_dim=$3; experts=$4; topk=$5; tile_m=$6; tile_n=$7; tile_k=$8; tile_n2=$9; tile_k2=${10} + dtype="a8w4" + shape_moe="t${tokens}-d${model_dim}x${inter_dim}-e${experts}k${topk}" + log="${BENCH_LOG_DIR}/moe_a8w4_t${tokens}_md${model_dim}_id${inter_dim}_e${experts}_k${topk}.log" + if python3 tests/kernels/test_moe_gemm.py \ + --in_dtype a8w4 \ + -dim "$model_dim,$inter_dim" \ + -t "$tokens" \ + -e "$experts" \ + -k "$topk" \ + --num_warmup 10 \ + --num_iters 100 \ + --tile_m "$tile_m" \ + --tile_n "$tile_n" \ + --tile_k "$tile_k" \ + --tile_n2 "$tile_n2" \ + --tile_k2 "$tile_k2" \ + --skip_ref false \ + --compare_aiter_ck false >"${log}" 2>&1; then + # CLI prints "Skipping a8w4: requires gfx950+" on unsupported archs. + if grep -q "requires gfx950\|Skipping a8w4" "${log}"; then + _emit_row "moe_a8w4" "${shape_moe}" "${dtype}" "skip" "skip" + else + SUCCESS_COUNT=$((SUCCESS_COUNT + 1)) + + dt_s1="$(grep -Eo 'FlyDSL MoE stage1\[[^]]+\]:' "${log}" | tail -1 | cut -d'[' -f2 | cut -d']' -f1 || true)" + tf_s1="$(grep -Eo 'FlyDSL MoE stage1\[[^]]+\]:.* ([0-9.]+) TFLOPS' "${log}" | tail -1 | awk '{print $(NF-1)}' || true)" + tb_s1="$(grep -Eo 'FlyDSL MoE stage1\[[^]]+\]:.* ([0-9.]+) TB/s' "${log}" | tail -1 | awk '{print $(NF-1)}' || true)" + if [ -n "${dt_s1}" ] && [ -n "${tf_s1}" ] && [ -n "${tb_s1}" ]; then + _emit_row "moe_a8w4_s1" "${shape_moe}" "${dt_s1}" "${tb_s1}" "${tf_s1}" + fi + + _emit_moe_s2_rows "moe_a8w4_s2" "${shape_moe}" "${log}" | while IFS="$(printf '\t')" read -r _op _sh _dt _tb _tf; do + _emit_row "${_op}" "${_sh}" "${_dt}" "${_tb}" "${_tf}" + done + fi + else + if grep -q "requires gfx950\|Skipping a8w4\|not supported" "${log}" 2>/dev/null; then + _emit_row "moe_a8w4" "${shape_moe}" "${dtype}" "skip" "skip" + else + FAIL_COUNT=$((FAIL_COUNT + 1)) + echo "moe a8w4 failed. Log: ${log}" >&2 + _show_fail_log "${log}" "moe_a8w4" + fi fi done fi diff --git a/tests/kernels/test_moe_gemm.py b/tests/kernels/test_moe_gemm.py index b979b3fe..5796ab75 100644 --- a/tests/kernels/test_moe_gemm.py +++ b/tests/kernels/test_moe_gemm.py @@ -412,9 +412,9 @@ def run_moe_stage1( # f"{' (even)' if even_dispatch else ' (random)'}" # ) - if in_dtype not in ("fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16", "fp4"): + if in_dtype not in ("fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16", "fp4", "a8w4"): raise ValueError( - f"in_dtype must be one of ('fp8','fp16','bf16','int8','int8smooth','int4','int4_bf16','fp4'), got {in_dtype!r}" + f"in_dtype must be one of ('fp8','fp16','bf16','int8','int8smooth','int4','int4_bf16','fp4','a8w4'), got {in_dtype!r}" ) is_int4 = in_dtype == "int4" is_int4_bf16 = in_dtype == "int4_bf16" # W4A16: bf16 activations, packed int4 weights @@ -422,18 +422,23 @@ def run_moe_stage1( is_int8 = in_dtype in ("int8", "int8smooth", "int4") is_int8smooth = in_dtype == "int8smooth" is_fp4 = in_dtype == "fp4" + is_a8w4 = in_dtype == "a8w4" # MX-FP8 activation + MX-FP4 weight + is_fp4_path = is_fp4 or is_a8w4 # shared weight/shuffle pipeline use_packed_int4 = is_int4 or is_int4_bf16 # Quantize inputs / weights. - if in_dtype == "fp4": + if is_fp4_path: from tests.kernels.utils import fp4_utils - x_fp4, x_scale_raw = _per_1x32_fp4_quant(x_fp32) + # x: MX-FP8 (e4m3fn, 1 B/elem) for a8w4; MX-FP4 (packed, 0.5 B/elem) for fp4. 
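Concrete per-row byte sizes for a hypothetical model_dim of 3072 (matching the benchmark shapes; purely illustrative):

# a8w4 activation row : 3072 * 1 B   = 3072 bytes (MX-FP8, e4m3fn)
# fp4  activation row : 3072 * 0.5 B = 1536 bytes (two e2m1 values per byte)
# E8M0 scale row      : 3072 // 32   =   96 bytes (same for both paths)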
+ if is_a8w4: + x_q, x_scale_raw = _per_1x32_mxfp8_quant(x_fp32) + else: + x_q, x_scale_raw = _per_1x32_fp4_quant(x_fp32) w1_flat_fp32 = w1_fp32.view(experts * (2 * inter_dim), model_dim) w1_fp4, w1_scale_raw = _per_1x32_fp4_quant(w1_flat_fp32) - x_q = x_fp4 # will be converted to uint8 below w1_q = w1_fp4 w2_q = None # not needed for stage1 - scale_x = x_scale_raw # raw e8m0 [tokens, K//32] + scale_x = x_scale_raw # raw e8m0 scale (K dim divided by 32) scale_w1 = w1_scale_raw # raw e8m0 [E*2N, K//32] elif in_dtype == "fp8": x_q, scale_x = pertoken_quant(x_fp32, quant_dtype=DTYPE_FP8) # [tokens,K], [tokens,1] @@ -511,8 +516,8 @@ def run_moe_stage1( scale_w1 = None # Preshuffle weights and prepare scale tensors. - if is_fp4: - # FP4: preshuffle via float4_e2m1fn_x2 view, scale as uint8. + if is_fp4_path: + # FP4 path (fp4 or a8w4): W1 is always MX-FP4; x is MX-FP4 (fp4) or MX-FP8 (a8w4). w1_shuffled = shuffle_weight(w1_q.view(torch.float4_e2m1fn_x2)) w_kernel = w1_shuffled.view(torch.uint8).contiguous() w1_q_flat = w1_q.view(experts * (2 * inter_dim), model_dim // 2).contiguous() @@ -527,6 +532,11 @@ def run_moe_stage1( token_num=tokens, block_size=tile_m, ).view(torch.uint8).contiguous() + # Kernel consumes raw bytes: fp4 packs 2 elems/byte (K//2 cols); a8w4 is 1 B/elem (K cols). + # Preserve the original dtype view for the torch reference path: mxfp4 detection in + # ``_detect_scale_kind`` needs the packed fp4 layout (shape[-1]*2 == scale*32), while + # mxfp8 detection needs dtype=float8_e4m3fn. The kernel itself always reads uint8. + x_q_ref = x_q x_q = x_q.view(torch.uint8).contiguous().view(tokens, -1) else: w1_shuffled = shuffle_weight(w1_q) @@ -571,7 +581,7 @@ def run_moe_stage1( else: out = torch.empty((tokens, topk, inter_dim), device=device, dtype=_out_torch_dtype) - if is_fp4: + if is_fp4_path: exe = compile_mixed_moe_gemm1( model_dim=model_dim, inter_dim=inter_dim, @@ -581,16 +591,20 @@ def run_moe_stage1( tile_n=tile_n, tile_k=tile_k, doweight_stage1=bool(doweight_stage1), - a_dtype="fp4", + a_dtype="fp8" if is_a8w4 else "fp4", b_dtype="fp4", out_dtype="f16", act="silu", ) bias_dummy = torch.empty((0,), device=device, dtype=torch.float32) + # Empty placeholder: stage1 writes a sorted E8M0 scale buffer only + # when gate_mode=INTERLEAVE + fp8/fp4 fused quant. For fp4-fp4 we + # still have to pass the arg slot (launcher signature is fixed). + out_scale_sorted_dummy = torch.empty((0,), device=device, dtype=torch.uint8) def _s1_args_fp4(o, x, w, sx, sw, st, eids, sw_sorted): return (o, x, w, sx, sw, st, eids, sw_sorted, - num_valid_ids, bias_dummy, + num_valid_ids, bias_dummy, out_scale_sorted_dummy, tokens, inter_dim * 2, model_dim, int(blocks), torch.cuda.current_stream()) @@ -667,7 +681,8 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): atol = 0.5 if (is_int4 or is_int4_bf16) else 0.25 assert verify_output(out.to(torch.float32), ref, rtol=rtol, atol=atol) else: - x_ref = x_q + # Use original-dtype view for fp4/a8w4 so `_detect_scale_kind` picks the right branch. 
+ x_ref = x_q_ref if is_fp4_path else x_q sx_ref = scale_x ref = torch_moe_gemm1( x_ref, w1_q_flat, sx_ref, scale_w1_flat, @@ -675,14 +690,14 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): inter_dim=inter_dim, doweight_stage1=doweight_stage1, group_size=group_size, scale_w1_groups=scale_w1_groups, ) - rtol = 0.5 if (is_int4 or is_int4_bf16 or is_fp4) else 0.25 - atol = 0.5 if (is_int4 or is_int4_bf16 or is_fp4) else 0.25 + rtol = 0.5 if (is_int4 or is_int4_bf16 or is_fp4_path) else 0.25 + atol = 0.5 if (is_int4 or is_int4_bf16 or is_fp4_path) else 0.25 assert verify_output( out.to(torch.float32), ref, rtol=rtol, atol=atol, - logits_diff_threshold=1 if is_fp4 else 2e-3, + logits_diff_threshold=1 if is_fp4_path else 2e-3, ) # Note: kernel launches full expert-block range; effective work is gated by num_valid_ids. @@ -692,19 +707,35 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): # Rough bytes-moved accounting (same spirit as GEMM tests: count each tensor once). # Only activated experts load weights/scales: E_active = min(E, tokens * topk). active_experts = min(experts, tokens * topk) - bytes_moved = 0 is_f16_or_bf16_s1 = is_int4_bf16 or in_dtype in ("bf16", "fp16") - x_elem_bytes = 2 if is_f16_or_bf16_s1 else 1 - bytes_moved += (tokens * topk if is_int8smooth else tokens) * model_dim * x_elem_bytes # x (bf16 for W4A16, else fp8/int8) - bytes_moved += (active_experts * (2 * inter_dim) * model_dim) // (2 if use_packed_int4 else 1) # w (packed for int4) - bytes_moved += tokens * topk * inter_dim * 2 # out fp16 (logical) - bytes_moved += ((tokens * topk if is_int8smooth else tokens) * 4) if not is_f16_or_bf16_s1 else 0 # scale_x f32 - if use_groupwise_scale: + # Per-element bits for X and W (W1 total cols = 2*inter_dim). + # fp4: x=4b, w=4b (both MX-FP4) + # a8w4: x=8b, w=4b (MX-FP8 act + MX-FP4 weight) + # f16/bf16/int4_bf16: x=16b + # int4/int8/fp8 etc.: x=8b, w packed to 4b when W4A*. + x_bits = 4 if is_fp4 else (16 if is_f16_or_bf16_s1 else 8) + w_bits = 4 if (is_fp4_path or use_packed_int4) else (16 if is_f16_or_bf16_s1 else 8) + x_rows = tokens * topk if is_int8smooth else tokens + x_elems = x_rows * model_dim + w_elems = active_experts * (2 * inter_dim) * model_dim + + bytes_moved = 0 + bytes_moved += (x_elems * x_bits) // 8 # x + bytes_moved += (w_elems * w_bits) // 8 # w1 + bytes_moved += tokens * topk * inter_dim * 2 # out fp16 (logical, post-silu) + + # scale bytes + if is_fp4_path: + # Per-1x32 E8M0 scale: 1 byte per 32 logical elements for both x and w1. + bytes_moved += x_elems // 32 + bytes_moved += w_elems // 32 + elif use_groupwise_scale: num_groups_s1 = model_dim // group_size _scale_bytes = 2 if scale_dtype == "bf16" else 4 bytes_moved += active_experts * num_groups_s1 * (2 * inter_dim) * _scale_bytes # groupwise scale elif not is_f16_or_bf16_s1: - bytes_moved += active_experts * (2 * inter_dim) * 4 # per-row scale_w f32 + bytes_moved += x_rows * 4 # scale_x f32 + bytes_moved += active_experts * (2 * inter_dim) * 4 # per-row scale_w f32 # Note: routing metadata (sorted_weights, sorted_token_ids, sorted_expert_ids) excluded # from bytes_moved — they are negligible vs weight/activation/scale tensors. 
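A standalone mirror of the a8w4 accounting above, evaluated for the smallest MOE_A8W4_SHAPES entry from run_benchmark.sh (sketch; the helper name is illustrative):

def a8w4_stage1_bytes(tokens, model_dim, inter_dim, experts, topk):
    active_experts = min(experts, tokens * topk)
    x_elems = tokens * model_dim                            # MX-FP8 activations, 8 bits/elem
    w_elems = active_experts * (2 * inter_dim) * model_dim  # MX-FP4 weights, 4 bits/elem
    return (
        x_elems                                             # x, 1 B/elem
        + w_elems // 2                                      # w1, 0.5 B/elem
        + tokens * topk * inter_dim * 2                     # fp16 output
        + x_elems // 32 + w_elems // 32                     # per-1x32 E8M0 scales
    )

# a8w4_stage1_bytes(512, 3072, 3072, 128, 4) -> ~1.30e9 bytes, dominated by W1.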
tbps = bytes_moved / 1e12 / (us / 1e6) @@ -944,9 +975,9 @@ def run_moe_stage2( # f"{' (even)' if even_dispatch else ' (random)'}" # ) - if in_dtype not in ("fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16", "fp4"): + if in_dtype not in ("fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16", "fp4", "a8w4"): raise ValueError( - f"in_dtype must be one of ('fp8','fp16','bf16','int8','int8smooth','int4','int4_bf16','fp4'), got {in_dtype!r}" + f"in_dtype must be one of ('fp8','fp16','bf16','int8','int8smooth','int4','int4_bf16','fp4','a8w4'), got {in_dtype!r}" ) is_int4 = in_dtype == "int4" is_int4_bf16 = in_dtype == "int4_bf16" # W4A16: bf16 activations, packed int4 weights @@ -954,6 +985,9 @@ def run_moe_stage2( is_int8 = in_dtype in ("int8", "int8smooth", "int4") is_int8smooth = in_dtype == "int8smooth" is_fp4 = in_dtype == "fp4" + is_a8w4 = in_dtype == "a8w4" # MX-FP8 activation + MX-FP4 weight + # Share the FP4 stage2 path (W2 shuffle / scale sort / mixed kernel). + is_fp4_path = is_fp4 or is_a8w4 use_packed_int4 = is_int4 or is_int4_bf16 # Quantize inputs / weights. @@ -991,13 +1025,14 @@ def run_moe_stage2( w1_q, scale_w1 = pertoken_quant(w1_fp32, quant_dtype=torch.int8, dtypeMax=7) w2_q, scale_w2 = pertoken_quant(w2_fp32, quant_dtype=torch.int8, dtypeMax=7) scale_x = None - elif in_dtype == "fp4": + elif in_dtype in ("fp4", "a8w4"): from tests.kernels.utils import fp4_utils if fp4_utils is None: pytest.skip("fp4_utils not available (triton not installed)") if "gfx95" not in ARCH: pytest.skip(f"FP4 MFMA requires gfx950+, got {ARCH}") - # FP4: quantize W2 only here; A2 is provided via a2_fp8_in from stage1 output + # FP4 / A8W4 share the MXFP4 W2 path; quantize W2 only here. A2 comes + # from `a2_fp8_in` (FP4 packed bytes for 'fp4', MX-FP8 e4m3fn for 'a8w4'). w2_flat_fp32 = w2_fp32.view(experts * model_dim, inter_dim) w2_fp4, w2_scale_raw = _per_1x32_fp4_quant(w2_flat_fp32) w2_q = w2_fp4 @@ -1032,23 +1067,26 @@ def run_moe_stage2( # Override per-row scale (kernel uses groupwise scale instead). scale_w2 = None - if is_fp4: - # FP4: preshuffle W2 and prepare scales + if is_fp4_path: + # FP4 / A8W4: preshuffle W2 and prepare MXFP4 weight scales w2_shuffled = shuffle_weight(w2_q.view(torch.float4_e2m1fn_x2)) w2_kernel = w2_shuffled.view(torch.uint8).contiguous() w2_scale_1d = fp4_utils.e8m0_shuffle(scale_w2).view(torch.uint8).contiguous() - # A2 input: from a2_fp8_in (stage1 output, already FP4 quantized) + # A2 input: provided by the caller. For 'fp4' it is packed MXFP4 + # [tokens*topk, inter_dim//2]; for 'a8w4' it is MX-FP8 e4m3fn bytes + # [tokens*topk, inter_dim]. The companion scale is per-1x32 E8M0 + # [tokens*topk, inter_dim//32] for both paths. if a2_fp8_in is not None and a2_scale_in is not None: - a2_q = a2_fp8_in # already FP4 [tokens*topk, inter_dim//2] - a2_scale_raw = a2_scale_in # raw e8m0 [tokens*topk, inter_dim//32] + a2_q = a2_fp8_in + a2_scale_raw = a2_scale_in a2_scale = a2_scale_raw else: raise RuntimeError( - "run_moe_stage2(in_dtype='fp4') requires a2_fp8_in and a2_scale_in " - "(FP4 A2 must be quantized from stage1 output)." + f"run_moe_stage2(in_dtype={in_dtype!r}) requires a2_fp8_in and a2_scale_in " + "(A2 must be quantized externally for the FP4/A8W4 path)." ) - # Sort A2 scale by MoE routing order + # Sort A2 scale by MoE routing order (dtype-agnostic on the E8M0 bytes). 
a2_scale_1d = fp4_utils.moe_mxfp4_sort( a2_scale_raw.view(tokens, topk, -1), sorted_ids=sorted_token_ids, @@ -1056,7 +1094,10 @@ def run_moe_stage2( token_num=tokens, block_size=tile_m, ).view(torch.uint8).contiguous() - a2_q = a2_q.view(torch.uint8).contiguous() + # The kernel consumes A2 as a flat byte stream. For 'fp4' that's the + # packed fp4x2 bytes; for 'a8w4' the e4m3fn bytes. Keep a2_q as its + # logical dtype for the reference path below. + a2_q_kernel = a2_q.view(torch.uint8).contiguous() w1_shuffled = None w2_shuffled_orig = None @@ -1152,8 +1193,9 @@ def run_moe_stage2( doweight_stage2 = not bool(doweight_stage1) - if is_fp4: + if is_fp4_path: fp4_accumulate = not bool(use_reduce) + a_dtype_kernel = "fp8" if is_a8w4 else "fp4" exe = compile_mixed_moe_gemm2( model_dim=model_dim, inter_dim=inter_dim, @@ -1163,7 +1205,7 @@ def run_moe_stage2( tile_n=tile_n, tile_k=tile_k, doweight_stage2=bool(doweight_stage2), - a_dtype="fp4", + a_dtype=a_dtype_kernel, b_dtype="fp4", out_dtype="f16", accumulate=fp4_accumulate, @@ -1178,7 +1220,7 @@ def _s2_args_fp4_interm(interm, x, w, sx, sw, st, eids, sw_sorted): _dummy_interm = torch.empty(tokens * topk, model_dim, device=device, dtype=torch.float16) compiled_exe = flyc.compile(exe, *_s2_args_fp4_interm( - _dummy_interm, a2_q.view(-1), w2_kernel.view(-1), + _dummy_interm, a2_q_kernel.view(-1), w2_kernel.view(-1), a2_scale_1d, w2_scale_1d, sorted_token_ids, sorted_expert_ids, sorted_weights_1d)) @@ -1196,7 +1238,7 @@ def _s2_args_fp4(o, x, w, sx, sw, st, eids, sw_sorted): torch.cuda.current_stream()) compiled_exe = flyc.compile(exe, *_s2_args_fp4( - out_perf, a2_q.view(-1), w2_kernel.view(-1), + out_perf, a2_q_kernel.view(-1), w2_kernel.view(-1), a2_scale_1d, w2_scale_1d, sorted_token_ids, sorted_expert_ids, sorted_weights_1d)) @@ -1219,7 +1261,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): ) is_reduce_exe = (getattr(exe, "mode", None) == MoeGemm2Mode.REDUCE) or bool(use_reduce) - if not is_fp4: + if not is_fp4_path: def _s2_args_atomic(o, x, w, sx, sw, st, eids, sw_sorted): return (o, x, w, sx, sw, st, eids, sw_sorted, num_valid_ids, tokens, model_dim, inter_dim, int(blocks), @@ -1253,13 +1295,18 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): else: exe(*_s2_args_atomic(o, x, w, sx, sw, st, eids, sw_sorted)) + # Flat byte view of A2 consumed by the kernel. For the FP4/A8W4 path we + # already built a dedicated uint8 `a2_q_kernel`; for the other paths the + # kernel accepts the original a2_q dtype directly. + a2_launch_buf = (a2_q_kernel if is_fp4_path else a2_q).view(-1) + # NOTE: stage2 uses atomic-add into `out`, so we cannot reuse the same output buffer # across perf iterations for correctness. Time into a dedicated buffer, then run # a single clean launch for correctness verification below. _, us = run_perftest( launch, out_perf, - a2_q.view(-1), + a2_launch_buf, w2_kernel.view(-1), a2_scale_1d, w2_scale_1d, @@ -1276,7 +1323,7 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): out.zero_() launch( out, - a2_q.view(-1), + a2_launch_buf, w2_kernel.view(-1), a2_scale_1d, w2_scale_1d, @@ -1303,26 +1350,62 @@ def launch(o, x, w, sx, sw, st, eids, sw_sorted): # Launches full expert-block range; effective work is gated by num_valid_ids. 
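# Illustrative only (editor's sketch, hypothetical helper): the A2 input
# contract described above for the FP4/A8W4 stage2 path, written as a tiny
# shape/dtype check. The kernel itself only ever sees the flat uint8 view
# built just above.
def check_a2_inputs(in_dtype, a2_q, a2_scale, tokens, topk, inter_dim):
    import torch
    rows = tokens * topk
    scale_cols = a2_scale.view(torch.uint8).reshape(rows, -1).shape[-1]
    assert scale_cols == inter_dim // 32           # per-1x32 E8M0, both paths
    cols = a2_q.view(torch.uint8).reshape(rows, -1).shape[-1]
    if in_dtype == "fp4":                          # packed MXFP4: 2 elems/byte
        assert cols == inter_dim // 2
    else:                                          # "a8w4": MX-FP8, 1 byte/elem
        assert a2_q.dtype == torch.float8_e4m3fn and cols == inter_dim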
flops = 2 * tokens * topk * model_dim * inter_dim - tflops = flops / (us / 1e6) / 1e12 + # Guard us=0: graph-captured replay at tiny shapes occasionally reports 0us + # which would propagate NaN/Inf into TFLOPS and TB/s; fall back to NaN so + # numeric columns render as "nan" instead of raising RuntimeWarning. + tflops = float("nan") if us <= 0 else flops / (us / 1e6) / 1e12 # Only activated experts load weights/scales: E_active = min(E, tokens * topk). active_experts = min(experts, tokens * topk) - bytes_moved = 0 - a2_elem_bytes = 2 if in_dtype in ("int4_bf16", "bf16", "fp16") else 1 # bf16/fp16 activations - bytes_moved += tokens * topk * inter_dim * a2_elem_bytes # a2 (logical) - bytes_moved += (active_experts * model_dim * inter_dim) // (2 if w_is_int4 else 1) # w2 (packed for int4) - bytes_moved += tokens * model_dim * (2 if out_torch_dtype in (torch.float16, torch.bfloat16) else 4) # out - is_f16_or_bf16_s2 = is_int4_bf16 or in_dtype in ("bf16", "fp16") - bytes_moved += (tokens * topk * 4) if not is_f16_or_bf16_s2 else 0 # a2_scale f32 (None for bf16) + + # One row per in_dtype: (a_bits, w_bits, a_scale_mode, w_scale_mode). + # a/w *_bits*: element width in bits for a2 / w2 bytes moved. + # a_scale_mode: "none" | "per_token_f32" | "per_block_e8m0_32" + # w_scale_mode: "none" | "per_row_f32" | "per_block_e8m0_32" | "groupwise" + # 'groupwise' is the int4_bf16 + group_size>0 path; it is handled as a + # one-off override below because its W2 scale element bytes is configurable. + _BYTES_SPEC = { + "fp8": ( 8, 8, "per_token_f32", "per_row_f32"), + "fp16": (16, 16, "none", "none"), + "bf16": (16, 16, "none", "none"), + "int8": ( 8, 8, "per_token_f32", "per_row_f32"), + "int8smooth": ( 8, 8, "per_token_f32", "per_row_f32"), + "int4": ( 8, 4, "per_token_f32", "per_row_f32"), + "int4_bf16": (16, 4, "none", "per_row_f32"), + "fp4": ( 4, 4, "per_block_e8m0_32", "per_block_e8m0_32"), + "a8w4": ( 8, 4, "per_block_e8m0_32", "per_block_e8m0_32"), + } + a_bits, w_bits, a_scale_mode, w_scale_mode = _BYTES_SPEC[in_dtype] if use_groupwise_scale: + w_scale_mode = "groupwise" + + a2_elems = tokens * topk * inter_dim + w2_elems = active_experts * model_dim * inter_dim + + bytes_moved = 0 + bytes_moved += (a2_elems * a_bits) // 8 # a2 + bytes_moved += (w2_elems * w_bits) // 8 # w2 + bytes_moved += tokens * model_dim * out_torch_dtype.itemsize # out + + # a2 scale + if a_scale_mode == "per_token_f32": + bytes_moved += tokens * topk * 4 + elif a_scale_mode == "per_block_e8m0_32": + bytes_moved += a2_elems // 32 # 1 uint8 per 32 elems + + # w2 scale + if w_scale_mode == "per_row_f32": + bytes_moved += active_experts * model_dim * 4 + elif w_scale_mode == "per_block_e8m0_32": + bytes_moved += w2_elems // 32 + elif w_scale_mode == "groupwise": num_groups_s2 = inter_dim // group_size - _scale_bytes = 2 if scale_dtype == "bf16" else 4 - bytes_moved += active_experts * num_groups_s2 * model_dim * _scale_bytes # groupwise scale - elif not is_f16_or_bf16_s2: - bytes_moved += active_experts * model_dim * 4 # per-row scale_w f32 + scale_elem_bytes = 2 if scale_dtype == "bf16" else 4 + bytes_moved += active_experts * num_groups_s2 * model_dim * scale_elem_bytes + # Note: routing metadata (sorted_weights, sorted_token_ids, sorted_expert_ids) excluded # from bytes_moved — they are negligible vs weight/activation/scale tensors. 
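# Illustrative only (editor's sketch, hypothetical helper): how the scale-mode
# strings in _BYTES_SPEC above translate to bytes. "groupwise" is omitted here
# because its element size depends on scale_dtype and is handled as a one-off
# override in the code above.
def scale_bytes(mode, elems, rows):
    # rows: tokens*topk for the a2 side, active_experts*model_dim for the w2 side
    return {
        "none": 0,
        "per_token_f32": rows * 4,
        "per_row_f32": rows * 4,
        "per_block_e8m0_32": elems // 32,   # one uint8 per 32 logical elements
    }[mode]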
- tbps = bytes_moved / 1e12 / (us / 1e6) + tbps = float("nan") if us <= 0 else bytes_moved / 1e12 / (us / 1e6) print( f"FlyDSL MoE stage2 [{kernel_name}] {in_dtype} {'reduce' if use_reduce else 'atomic'} | " f"{model_dim}x{inter_dim}, E={experts}, K={topk}, M_eff={tokens*topk} | " @@ -1487,20 +1570,20 @@ def test_moe_gemm_2stage( pytest.skip("reduce mode does not support out_dtype='f32' (compile_moe_gemm2(accumulate=False) forbids it).") if group_size > 0 and in_dtype != "int4_bf16": pytest.skip("groupwise scale only applies to int4_bf16 (W4A16)") - if in_dtype == "fp4": + if in_dtype in ("fp4", "a8w4"): if bool(use_valid_mask): - pytest.skip("FP4 does not support valid_mask") + pytest.skip(f"{in_dtype} does not support valid_mask") if out_s not in ("f16", "fp16", "half"): - pytest.skip("FP4 only supports f16 output") + pytest.skip(f"{in_dtype} only supports f16 output") if group_size > 0: - pytest.skip("FP4 does not support groupwise scale") - # FP4 requires K >= 256 and tile_k >= 256 (scale layout constraint) + pytest.skip(f"{in_dtype} does not support groupwise scale") + # Per-1x32 scale layout requires K >= 256 and tile_k >= 256 on both stages. if model_dim < 256 or tile_k1 < 256: - pytest.skip(f"FP4 requires model_dim >= 256 and tile_k >= 256, got {model_dim}, {tile_k1}") + pytest.skip(f"{in_dtype} requires model_dim >= 256 and tile_k >= 256, got {model_dim}, {tile_k1}") if inter_dim < 256 or tile_k2 < 256: - pytest.skip(f"FP4 stage2 requires inter_dim >= 256 and tile_k2 >= 256, got {inter_dim}, {tile_k2}") + pytest.skip(f"{in_dtype} stage2 requires inter_dim >= 256 and tile_k2 >= 256, got {inter_dim}, {tile_k2}") if tile_m < 32 or tile_m % 32 != 0: - pytest.skip(f"FP4 requires tile_m % 32 == 0 and tile_m >= 32, got {tile_m}") + pytest.skip(f"{in_dtype} requires tile_m % 32 == 0 and tile_m >= 32, got {tile_m}") device = torch.device("cuda") # torch.manual_seed(int(seed)) @@ -1568,12 +1651,14 @@ def test_moe_gemm_2stage( test_graph=test_graph, ) - if in_dtype == "fp4": - # Quantize stage1 output to FP4 for stage2 input - out1_fp32 = out1_fp16.to(torch.float32) - a2_fp4, a2_scale_raw = _per_1x32_fp4_quant(out1_fp32.view(tokens * topk, inter_dim)) - a2_q = a2_fp4 # [tokens*topk, inter_dim//2] as float4_e2m1fn_x2 - a2_scale = a2_scale_raw # raw e8m0 [tokens*topk, inter_dim//32], will be sorted in run_moe_stage2 + if in_dtype in ("fp4", "a8w4"): + # Re-quantize stage1 output for stage2 input: + # fp4 -> MX-FP4 (0.5 B/elem, packed) + # a8w4 -> MX-FP8 e4m3fn (1 B/elem) + # run_moe_stage2 sorts the raw E8M0 scale [tokens*topk, inter_dim//32] internally. + out1_fp32 = out1_fp16.to(torch.float32).view(tokens * topk, inter_dim) + quantize_a2 = _per_1x32_mxfp8_quant if in_dtype == "a8w4" else _per_1x32_fp4_quant + a2_q, a2_scale = quantize_a2(out1_fp32) elif w_fp4_kernel: a2_q = out1_fp16.to(torch.float32) a2_scale = None @@ -1664,6 +1749,28 @@ def _per_1x32_fp4_quant(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: return y_fp4, scale +def _per_1x32_mxfp8_quant(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize a tensor to MX-FP8 (e4m3fn) with per-1x32 E8M0 block scaling. + + Mirrors `_per_1x32_fp4_quant` for the A8W4 path: the activation is kept at + 1 byte/element (no packing), and each 32-element K block gets its own + E8M0 scale stored as uint8. Returns + (x_q [..., K] fp8_e4m3fn, scale_e8m0 [..., K//32] uint8). 
+ """ + from tests.kernels.utils import fp4_utils + + fp8_max = float(torch.finfo(torch.float8_e4m3fn).max) + shape_orig = x.shape + x_flat = x.contiguous().view(-1, 32).float() + amax = torch.amax(torch.abs(x_flat), dim=-1).clamp_min(1e-30) + scale_e8m0 = fp4_utils.f32_to_e8m0(amax / fp8_max) + scale_f32 = fp4_utils.e8m0_to_f32(scale_e8m0).clamp_min(1e-30) + x_q = (x_flat / scale_f32.view(-1, 1)).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn) + x_q = x_q.view(shape_orig).contiguous() + scale_bytes = scale_e8m0.view(*shape_orig[:-1], shape_orig[-1] // 32).view(torch.uint8).contiguous() + return x_q, scale_bytes + + # Test Helpers for MoE GEMM2 Mode Comparison def _make_reduce_mode_compile_fn(use_flydsl_reduce: bool = True, use_valid_mask: bool = False, scale_dtype: str = "f32"): @@ -1855,6 +1962,7 @@ def test_moe_gemm_w4a16_groupwise_scale(scale_dtype): @pytest.mark.parametrize("in_dtype", [ "fp8", pytest.param("fp4", marks=pytest.mark.skipif("gfx95" not in ARCH, reason="FP4 requires gfx950+")), + pytest.param("a8w4", marks=pytest.mark.skipif("gfx95" not in ARCH, reason="A8W4 requires gfx950+")), ]) def test_moe_stage2_standalone( tokens: int, @@ -1877,19 +1985,19 @@ def test_moe_stage2_standalone( 1. Atomic mode: direct accumulation with atomics 2. Reduce mode (torch): GEMM2 + torch.sum reduction 3. Reduce mode (FlyDSL): GEMM2 + FlyDSL reduce kernel - For FP4: atomic mode + torch reduce mode. + For FP4 / A8W4: atomic mode + torch reduce mode (MXFP4 weight path). """ - is_fp4 = in_dtype == "fp4" - if is_fp4: + is_fp4_path = in_dtype in ("fp4", "a8w4") + if is_fp4_path: from tests.kernels.utils import fp4_utils if fp4_utils is None: pytest.skip("FP4 dependencies not available (triton/mixed_moe_gemm not installed)") if "gfx95" not in ARCH: - pytest.skip(f"FP4 requires gfx950+, got {ARCH}") + pytest.skip(f"{in_dtype} requires gfx950+, got {ARCH}") if inter_dim < 256 or tile_k < 256: - pytest.skip(f"FP4 requires inter_dim >= 256 and tile_k >= 256, got {inter_dim}, {tile_k}") + pytest.skip(f"{in_dtype} requires inter_dim >= 256 and tile_k >= 256, got {inter_dim}, {tile_k}") if tile_m < 32 or tile_m % 32 != 0: - pytest.skip(f"FP4 requires tile_m % 32 == 0 and tile_m >= 32, got {tile_m}") + pytest.skip(f"{in_dtype} requires tile_m % 32 == 0 and tile_m >= 32, got {tile_m}") # Common args common_args = dict( @@ -1911,26 +2019,20 @@ def test_moe_stage2_standalone( skip_ref=False, ) - if is_fp4: - # FP4 requires a2_fp8_in / a2_scale_in (can't build from torch reference) + if is_fp4_path: + # FP4 / A8W4 require pre-quantized a2 (stage1 output); torch reference + # cannot synthesize FP4/FP8 packed activations, so we build them here. + _A2_QUANTIZERS = { + "fp4": _per_1x32_fp4_quant, + "a8w4": _per_1x32_mxfp8_quant, + } device = torch.device("cuda") torch.manual_seed(seed) a2_fp32 = torch.randn((tokens * topk, inter_dim), device=device, dtype=torch.float32) * 0.2 - a2_fp4, a2_scale = _per_1x32_fp4_quant(a2_fp32) - # FP4 supports atomic mode and torch-based reduce mode. 
- run_moe_stage2( - **common_args, - a2_fp8_in=a2_fp4, - a2_scale_in=a2_scale, - kernel_name="moe_gemm2_atomic_fp4", - ) - run_moe_stage2( - **common_args, - a2_fp8_in=a2_fp4, - a2_scale_in=a2_scale, - use_reduce=True, - kernel_name="moe_gemm2_reduce_torch_fp4", - ) + a2_q, a2_scale = _A2_QUANTIZERS[in_dtype](a2_fp32) + fp4_args = dict(common_args, a2_fp8_in=a2_q, a2_scale_in=a2_scale) + run_moe_stage2(**fp4_args, kernel_name=f"moe_gemm2_atomic_{in_dtype}") + run_moe_stage2(**fp4_args, use_reduce=True, kernel_name=f"moe_gemm2_reduce_torch_{in_dtype}") return # Run baseline stage2 (atomic accumulation) @@ -1990,12 +2092,13 @@ def _str2tuple_dim(v: str) -> Tuple[int, int]: "--in_dtype", type=str, default="fp8", - choices=["fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16", "fp4", "all"], - help="Kernel input dtype: fp8 / fp16 / int8 / int8smooth / int4 / int4_bf16 / fp4 / all (default: all). " + choices=["fp8", "fp16", "bf16", "int8", "int8smooth", "int4", "int4_bf16", "fp4", "a8w4", "all"], + help="Kernel input dtype: fp8 / fp16 / int8 / int8smooth / int4 / int4_bf16 / fp4 / a8w4 / all (default: all). " "int8smooth expands X to [tokens*topk, K] with per-(token,slot) scales. " "int4 means W4A8: A int8, W packed int4. " "int4_bf16 means W4A16: A bf16, W packed int4. " - "fp4 means A4W4: both activation and weight are FP4 (uses mixed_moe_gemm kernel).", + "fp4 means A4W4: both activation and weight are FP4 (uses mixed_moe_gemm kernel). " + "a8w4 means FP8 activation + MX-FP4 weight (per-1x32 E8M0 block scales on both sides; gfx950+).", ) parser.add_argument("-d", "--dtype", type=str, default="fp32", choices=["fp32", "fp16", "bf16"], help="Input init dtype (currently data is quantized to FP8 per-token; init dtype mainly affects RNG range).") parser.add_argument("-dim", type=_str2tuple_dim, default=(6144, 4096), help="Model dimension: model_dim,inter_dim (e.g. -dim 6144,4096)") @@ -2115,10 +2218,10 @@ def run_one(dt: str, use_reduce: bool): # Expand "all" to all supported dtypes. in_dtypes = args.in_dtype.split(",") if "all" in in_dtypes: - in_dtypes = ["fp8", "fp16", "bf16", "int8", "int4", "int4_bf16", "fp4"] + in_dtypes = ["fp8", "fp16", "bf16", "int8", "int4", "int4_bf16", "fp4", "a8w4"] for dt in in_dtypes: - if dt == "fp4" and "gfx95" not in ARCH: - print(f"Skipping FP4: requires gfx950+, got {ARCH}") + if dt in ("fp4", "a8w4") and "gfx95" not in ARCH: + print(f"Skipping {dt}: requires gfx950+, got {ARCH}") continue for use_reduce in reduce_flags: run_one(dt, use_reduce) diff --git a/tests/kernels/test_ref.py b/tests/kernels/test_ref.py index 36dbe4c3..1b4384c2 100644 --- a/tests/kernels/test_ref.py +++ b/tests/kernels/test_ref.py @@ -5,6 +5,41 @@ import torch.nn.functional as F +# Scale-kind "enum" used by the ref helpers below. Each kind determines +# how we dequantize (x, scale) to fp32 and how to recover the logical K +# dimension from x.shape[-1]. +# "mxfp4": packed fp4 bytes (1 byte = 2 fp4 values); logical_K == shape[-1] * 2 +# scale is uint8 E8M0 block scale, one byte per 32 logical elements. +# "mxfp8": raw fp8_e4m3fn bytes; logical_K == shape[-1] +# scale is uint8 E8M0 block scale, one byte per 32 logical elements. +# "scalar": plain fp/int tensor + per-token / per-row fp scale (may be None); +# logical_K == shape[-1]. 'scale is None' is the no-quant case. +_FP4_DTYPES = (torch.uint8, torch.float4_e2m1fn_x2) + + +def _detect_scale_kind(x: torch.Tensor, scale: torch.Tensor | None) -> str: + """Classify the (x, scale) pair into one of {"mxfp4", "mxfp8", "scalar"}. 
+ + The classifier is intentionally dtype-driven so callers can pass *either* + packed FP4 bytes or real torch.float4 tensors for MX-FP4, while still + distinguishing MX-FP8 (one fp8 byte per element) unambiguously. + """ + if scale is None or scale.dtype != torch.uint8: + return "scalar" + # Both block-scale kinds use K_blocks == logical_K // 32. + if x.dtype in _FP4_DTYPES and x.shape[-1] * 2 == scale.shape[-1] * 32: + return "mxfp4" + if x.dtype == torch.float8_e4m3fn and x.shape[-1] == scale.shape[-1] * 32: + return "mxfp8" + return "scalar" + + +def _logical_k(x: torch.Tensor, kind: str) -> int: + """Return the logical K dim (unpacked) given the detected scale kind.""" + k = int(x.shape[-1]) + return k * 2 if kind == "mxfp4" else k + + def _dequant_mxfp4_per_1x32(x_fp4: torch.Tensor, scale_e8m0: torch.Tensor) -> torch.Tensor: """Dequantize packed MXFP4 with per-1x32 e8m0 block scales to fp32.""" from tests.kernels.utils import fp4_utils @@ -20,6 +55,42 @@ def _dequant_mxfp4_per_1x32(x_fp4: torch.Tensor, scale_e8m0: torch.Tensor) -> to return x_f32.view(*x_u8.shape[:-1], logical_k) +def _dequant_mxfp8_per_1x32(x_fp8: torch.Tensor, scale_e8m0: torch.Tensor) -> torch.Tensor: + """Dequantize MX-FP8 (e4m3fn) activations with per-1x32 e8m0 block scales to fp32. + + Mirrors the MFMA-time semantics of the A8W4 stage1/stage2 kernels: the kernel + reads the raw fp8 byte, casts to fp32, and multiplies by the 8-bit exponent + scale (E8M0 = 2^(byte-127)) for every 32-element K block. + """ + from tests.kernels.utils import fp4_utils + + logical_k = int(x_fp8.shape[-1]) + if logical_k % 32 != 0: + raise ValueError(f"MX-FP8 logical K must be divisible by 32, got {logical_k}") + + x_f32 = x_fp8.reshape(-1, logical_k).to(torch.float32) + scales_f32 = fp4_utils.e8m0_to_f32( + scale_e8m0.view(torch.uint8).reshape(-1, logical_k // 32) + ) + x_f32 = x_f32 * scales_f32.repeat_interleave(32, dim=1) + return x_f32.view(*x_fp8.shape[:-1], logical_k) + + +def _dequant(x: torch.Tensor, scale: torch.Tensor | None, kind: str) -> torch.Tensor: + """Unified fp32 dequantization driven by the detected ``kind``. + + - "mxfp4" / "mxfp8" delegate to the per-1x32 block-scale helpers above. + - "scalar" handles both the no-quant (scale is None) and broadcast-scale + paths (per-token / per-row fp scale). + """ + if kind == "mxfp4": + return _dequant_mxfp4_per_1x32(x, scale) + if kind == "mxfp8": + return _dequant_mxfp8_per_1x32(x, scale) + x_f32 = x.to(torch.float32) + return x_f32 if scale is None else x_f32 * scale + + def torch_moe_gemm1( x_q: torch.Tensor, w1_q_flat: torch.Tensor, @@ -40,34 +111,28 @@ def torch_moe_gemm1( Required when group_size > 0; ignored otherwise. """ topk = topk_ids.shape[1] - is_fp4 = ( - scale_x is not None - and scale_w1_flat is not None - and x_q.shape[-1] * 2 == scale_x.shape[-1] * 32 - and w1_q_flat.shape[-1] * 2 == scale_w1_flat.shape[-1] * 32 - ) + # Independent per-1x32 block-scale detection for x and w, so that mixed + # precisions such as A8W4 (fp8 activation + mxfp4 weight) can use the correct + # dequant per side. See ``_detect_scale_kind`` for the classification rules. 
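# Illustrative only (editor's note): why detection is now per side. For an
# A8W4-shaped pair of inputs the two sides classify differently (shapes only):
#   x:  fp8_e4m3fn [4, 256]  + uint8 scale [4, 8]   -> "mxfp8"  (256 == 8*32)
#   w1: packed fp4 [16, 128] + uint8 scale [16, 8]  -> "mxfp4"  (128*2 == 8*32)
# A plain per-token fp8 call keeps the "scalar" branch because its scale is f32,
# not uint8.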
+ x_kind = _detect_scale_kind(x_q, scale_x) + w_kind = _detect_scale_kind(w1_q_flat, scale_w1_flat) if x_q.dim() == 2: tokens = int(x_q.shape[0]) - model_dim = int(x_q.shape[1]) * 2 if is_fp4 else int(x_q.shape[1]) elif x_q.dim() == 3: - tokens, topk_x, model_dim_raw = x_q.shape - assert ( - int(topk_x) == int(topk) - ), f"x_q topk mismatch: x_q.shape={tuple(x_q.shape)}, topk={topk}" - model_dim = int(model_dim_raw) * 2 if is_fp4 else int(model_dim_raw) + tokens, topk_x, _ = x_q.shape + assert int(topk_x) == int(topk), \ + f"x_q topk mismatch: x_q.shape={tuple(x_q.shape)}, topk={topk}" else: raise ValueError(f"Unsupported x_q shape: {tuple(x_q.shape)}") + model_dim = _logical_k(x_q, x_kind) # Derive experts from weight shapes (topk_ids may not cover all experts when tokens are tiny). if w1_q_flat.dim() == 2: experts = int(w1_q_flat.shape[0] // (2 * inter_dim)) else: experts = int(w1_q_flat.shape[0]) - if is_fp4: - x = _dequant_mxfp4_per_1x32(x_q, scale_x) - else: - x = x_q.to(torch.float32) if scale_x is None else (x_q.to(torch.float32) * scale_x) + x = _dequant(x_q, scale_x, x_kind) if group_size > 0 and scale_w1_groups is not None: # Group-wise dequantization: w_dequant[e,n,k] = w_int[e,n,k] * scale[e, k//group_size, n] @@ -78,16 +143,7 @@ def torch_moe_gemm1( k_s, k_e = g * group_size, (g + 1) * group_size w1[:, :, k_s:k_e] *= scale_w1_groups[:, g, :].unsqueeze(-1) else: - # Per-row dequantization. - if is_fp4: - w1 = _dequant_mxfp4_per_1x32(w1_q_flat, scale_w1_flat).view(experts, 2 * inter_dim, model_dim) - else: - w1 = ( - w1_q_flat.to(torch.float32) - if scale_w1_flat is None - else (w1_q_flat.to(torch.float32) * scale_w1_flat) - ) - w1 = w1.view(experts, 2 * inter_dim, model_dim) + w1 = _dequant(w1_q_flat, scale_w1_flat, w_kind).view(experts, 2 * inter_dim, model_dim) out = torch.zeros((tokens, topk, inter_dim), device="cuda", dtype=torch.float32) for e in range(experts): @@ -137,24 +193,21 @@ def torch_moe_gemm2( assert a2_q.is_cuda and w2_q.is_cuda tokens, topk = topk_ids.shape - is_fp4 = ( - scale_a2 is not None - and scale_w2 is not None - and a2_q.shape[-1] * 2 == scale_a2.shape[-1] * 32 - and w2_q.shape[-1] * 2 == scale_w2.shape[-1] * 32 - ) + # Independent per-1x32 block-scale detection for a2 and w2; see + # ``_detect_scale_kind`` for the classification rules. + a_kind = _detect_scale_kind(a2_q, scale_a2) + w_kind = _detect_scale_kind(w2_q, scale_w2) - if is_fp4: - inter_dim = int(a2_q.shape[-1]) * 2 + inter_dim = _logical_k(a2_q, a_kind) + if a_kind in ("mxfp4", "mxfp8"): if a2_q.dim() == 2: a2_q = a2_q.view(tokens, topk, -1) scale_a2 = scale_a2.view(tokens, topk, -1) elif a2_q.dim() != 3: - raise ValueError(f"Unsupported FP4 a2 shape: {tuple(a2_q.shape)}") + raise ValueError(f"Unsupported {a_kind} a2 shape: {tuple(a2_q.shape)}") else: if a2_q.dim() != 3: raise ValueError(f"Unsupported a2_q shape: {tuple(a2_q.shape)}") - _, _, inter_dim = a2_q.shape # Derive experts from weight shapes (topk_ids may not cover all experts when tokens are tiny). if w2_q.dim() == 3: @@ -162,11 +215,7 @@ def torch_moe_gemm2( else: experts = int(w2_q.shape[0] // model_dim) - # Dequantize inputs. 
- if is_fp4: - a2 = _dequant_mxfp4_per_1x32(a2_q, scale_a2) - else: - a2 = a2_q.to(torch.float32) if scale_a2 is None else (a2_q.to(torch.float32) * scale_a2) + a2 = _dequant(a2_q, scale_a2, a_kind) if group_size > 0 and scale_w2_groups is not None: # Group-wise dequantization: w_dequant[e,n,k] = w_int[e,n,k] * scale[e, k//group_size, n] @@ -177,12 +226,7 @@ def torch_moe_gemm2( k_s, k_e = g * group_size, (g + 1) * group_size w2[:, :, k_s:k_e] *= scale_w2_groups[:, g, :].unsqueeze(-1) else: - # Per-row dequantization. - if is_fp4: - w2 = _dequant_mxfp4_per_1x32(w2_q, scale_w2).view(experts, model_dim, inter_dim) - else: - w2 = w2_q.to(torch.float32) if scale_w2 is None else (w2_q.to(torch.float32) * scale_w2) - w2 = w2.view(experts, model_dim, inter_dim) + w2 = _dequant(w2_q, scale_w2, w_kind).view(experts, model_dim, inter_dim) out = torch.zeros((tokens, model_dim), device="cuda", dtype=torch.float32) for e in range(experts): From 19d6f6e13394008090a06c023097ee090b62d710 Mon Sep 17 00:00:00 2001 From: amd-weisun Date: Thu, 23 Apr 2026 15:50:00 +0100 Subject: [PATCH 29/29] improve fused_rope kernel (#416) * improve fused rope kernel * address comments --- kernels/fused_rope_cache_kernel.py | 658 ++++++++++--------- tests/kernels/test_fused_rope_cache.py | 844 +++++++++++++++++-------- 2 files changed, 965 insertions(+), 537 deletions(-) diff --git a/kernels/fused_rope_cache_kernel.py b/kernels/fused_rope_cache_kernel.py index d9258487..903f7be6 100644 --- a/kernels/fused_rope_cache_kernel.py +++ b/kernels/fused_rope_cache_kernel.py @@ -3,14 +3,28 @@ """Fused RoPE + KV Cache kernel builder using the @flyc.kernel API. -Fuses 3 operations into two kernel launches: - Kernel 1 (Q RoPE): Q → rotate → Q_out - Kernel 2 (K+V cache): K → rotate → K_out + key_cache; V → value_cache +Fuses 3 operations into a **single kernel launch**: + Q -> RoPE rotation -> Q_out + K -> RoPE rotation -> K_out + key_cache + V -> value_cache + +Grid: (max(QH, KH), T, 1) -- shared blocks for Q and K + block_idx.x = head_idx in [0, max(QH, KH)) + block_idx.y = token_idx + + Each block conditionally does Q work (if head_idx < QH) and/or K work + (if head_idx < KH). For GQA (QH >> KH) blocks beyond KH only do Q; + for MQA-like configs where KH <= QH every block does both. + + Cos/sin are loaded ONCE per block (before branching) and shared by both + the Q and K paths, saving buffer descriptor SGPRs. Input shapes: Q: [T, QH, D], K: [T, KH, D], V: [T, KH, D] - CosCache/SinCache: [max_pos, D//2] (must be 2-D contiguous) - Positions: [T] int32, SlotMapping: [T] int32 + CosCache/SinCache: [max_pos, D//2] if reuse_freqs_front_part else [max_pos, D] + Positions/SlotMapping: + - pos_dtype="i32": [T] int32 + - pos_dtype="i64": [T] int64, accessed via stride-2 int32 indexing (.view(int32)) KV cache layouts: flash_layout=True: @@ -20,21 +34,20 @@ KeyCache: [num_blocks, KH, D//x, block_size, x] (x=16, x-packed) ValueCache: [num_blocks, KH, D, block_size] (dim-major) - """ import flydsl.compiler as flyc import flydsl.expr as fx -from flydsl.expr import range_constexpr +from flydsl.expr import arith, vector, buffer_ops, range_constexpr +from flydsl.expr.arith import ArithValue from flydsl.expr.typing import T -from flydsl.expr.numeric import Numeric -from flydsl.expr.vector import full -from kernels.kernels_common import dtype_to_elem_type +from kernels.kernels_common import get_warp_size -WARP_SIZE = 64 -VEC_WIDTH = 8 +# WARP_SIZE is 32 on RDNA (wave32: gfx10xx/gfx11xx/gfx12xx) and 64 on CDNA (wave64: gfx9xx). 
+# All derived values (VEC_WIDTH, vecs_per_half, BLOCK_THREADS) flow from this automatically. +WARP_SIZE = get_warp_size() def build_fused_rope_cache_module( @@ -46,23 +59,10 @@ def build_fused_rope_cache_module( is_neox: bool = True, flash_layout: bool = True, dtype_str: str = "bf16", + apply_scale: bool = False, + reuse_freqs_front_part: bool = True, + pos_dtype: str = "i32", ): - """Build fused RoPE + KV cache kernel. - - Args: - head_dim: dimension per attention head - rotary_dim: dimensions to rotate (== head_dim for full rotation) - num_q_heads: query heads per rank - num_kv_heads: KV heads per rank - block_size: paged attention block size - is_neox: True for NeoX-style rotation - flash_layout: True for [num_blocks, block_size, KH, D] cache layout - dtype_str: element dtype ("bf16" or "f16") - - Returns: - launch_fn(Q, K, V, Positions, CosCache, SinCache, SlotMapping, - KeyCache, ValueCache, Q_out, K_out, num_tokens, stream) - """ if rotary_dim == -1: rotary_dim = head_dim if not is_neox: @@ -71,284 +71,373 @@ def build_fused_rope_cache_module( raise NotImplementedError("Partial rotation not yet supported") if dtype_str not in ("bf16", "f16"): raise ValueError( - f"dtype_str must be 'bf16' or 'f16', got {dtype_str!r} " - f"(f32 is not supported: kernel uses 2-byte elem_bytes and vec8 vectorization)" + f"dtype_str must be 'bf16' or 'f16', got {dtype_str!r}" ) half_dim = rotary_dim // 2 - vecs_per_half = half_dim // VEC_WIDTH # number of VEC_WIDTH-wide vectors covering half_dim - vecs_per_head = head_dim // VEC_WIDTH # number of VEC_WIDTH-wide vectors covering head_dim - x_size = 16 # x-packing factor for non-flash key_cache - # Validate vectorization and layout assumptions to avoid silent truncation. + # VEC_WIDTH: elements per thread. Use ceil division so vecs_per_head never + # exceeds WARP_SIZE for the fixed one-thread-per-vector mapping below. + # For D=64: VEC_WIDTH=1 -> vecs_per_head=64 (full wavefront, 16-bit loads). + # For D=96: VEC_WIDTH=2 -> vecs_per_head=48 (fits within one wavefront). + # For D=128: VEC_WIDTH=2 -> vecs_per_head=64 (32-bit loads, unchanged). + VEC_WIDTH = max(1, (head_dim + WARP_SIZE - 1) // WARP_SIZE) + + vecs_per_half = half_dim // VEC_WIDTH + vecs_per_head = head_dim // VEC_WIDTH + x_size = 16 + + # elem_bits for copy atom (bf16/f16 = 16 bits) + elem_bits = 16 + # Copy atom bits: VEC_WIDTH * elem_bits + copy_bits = VEC_WIDTH * elem_bits # e.g. 
2*16=32 for VEC_WIDTH=2 + if head_dim % VEC_WIDTH != 0: - raise ValueError( - f"head_dim must be a multiple of VEC_WIDTH ({VEC_WIDTH}), " - f"got head_dim={head_dim}" - ) + raise ValueError(f"head_dim must be a multiple of VEC_WIDTH ({VEC_WIDTH}), got {head_dim}") if rotary_dim % 2 != 0: - raise ValueError( - f"rotary_dim must be even so that half_dim=rotary_dim//2 is integral, " - f"got rotary_dim={rotary_dim}" - ) + raise ValueError(f"rotary_dim must be even, got {rotary_dim}") if half_dim % VEC_WIDTH != 0: - raise ValueError( - f"half_dim (rotary_dim//2) must be a multiple of VEC_WIDTH " - f"({VEC_WIDTH}), got half_dim={half_dim} (rotary_dim={rotary_dim})" - ) + raise ValueError(f"half_dim must be a multiple of VEC_WIDTH ({VEC_WIDTH}), got {half_dim}") if not flash_layout and head_dim % x_size != 0: - raise ValueError( - f"With flash_layout=False, head_dim must be a multiple of the " - f"key_cache packing factor x_size ({x_size}), got head_dim={head_dim}" - ) - if vecs_per_head > WARP_SIZE: - max_head_dim = WARP_SIZE * VEC_WIDTH - raise ValueError( - f"Unsupported head_dim={head_dim}: with WARP_SIZE={WARP_SIZE} and " - f"VEC_WIDTH={VEC_WIDTH}, head_dim must satisfy " - f"head_dim <= {max_head_dim} to avoid incomplete coverage " - f"(got vecs_per_head={vecs_per_head} > WARP_SIZE)" - ) + raise ValueError(f"head_dim must be a multiple of x_size ({x_size}), got {head_dim}") + BLOCK_THREADS = WARP_SIZE + num_q_heads_val = num_q_heads + num_kv_heads_val = num_kv_heads + max_heads = max(num_q_heads, num_kv_heads) - # ----- Kernel 1: Q RoPE ----- - # Grid: (T * QH, 1, 1), one program per (token, q_head) - # Each program: vecs_per_head threads process head_dim elements @flyc.kernel - def q_rope_kernel( - Q: fx.Tensor, # [T, QH, D] - Positions: fx.Tensor, # [T] int32 - CosCache: fx.Tensor, # [max_pos, half_dim] - SinCache: fx.Tensor, # [max_pos, half_dim] - Q_out: fx.Tensor, # [T, QH, D] + def fused_qk_rope_reshape_and_cache( + Q: fx.Tensor, + K: fx.Tensor, + V: fx.Tensor, + Positions: fx.Tensor, + CosCache: fx.Tensor, + SinCache: fx.Tensor, + SlotMapping: fx.Tensor, + KeyCache: fx.Tensor, + ValueCache: fx.Tensor, + Q_out: fx.Tensor, + K_out: fx.Tensor, + KScale: fx.Tensor, + VScale: fx.Tensor, ): - pid = fx.block_idx.x # program id: 0..T*QH-1 - tid = fx.thread_idx.x # 0..63 - - elem_type = dtype_to_elem_type(dtype_str) - elem_bits = 16 # bf16/f16 only + head_idx = fx.block_idx.x + pid_t = fx.block_idx.y + tid = fx.thread_idx.x - # Buffer-backed tensors via layout API - Q_buf = fx.rocdl.make_buffer_tensor(Q) - Qo_buf = fx.rocdl.make_buffer_tensor(Q_out) - Cos_buf = fx.rocdl.make_buffer_tensor(CosCache) - Sin_buf = fx.rocdl.make_buffer_tensor(SinCache) - Pos_buf = fx.rocdl.make_buffer_tensor(Positions) + elem_type = T.bf16 if dtype_str == "bf16" else T.f16 - copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) + # --- Layout API setup --- + copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy(copy_bits), elem_bits) vec_reg_ty = fx.MemRefType.get( elem_type, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register ) + # Single layout used for both register alloca and logical_divide (same shape). vec_reg_lay = fx.make_layout(VEC_WIDTH, 1) + vec_div_lay = vec_reg_lay - copy_atom_i32 = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) - i32_reg_ty = fx.MemRefType.get(T.i32, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) - i32_reg_lay = fx.make_layout(1, 1) + # f32 scalar copy atom for KScale/VScale loads (1 x f32 = 32 bits). 
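# Illustrative only (editor's sketch): the VEC_WIDTH / copy-width arithmetic
# above, evaluated for the head sizes mentioned in the comments on a wave64
# (CDNA) part. The numbers follow directly from the expressions in this file.
_warp = 64
for _d in (64, 96, 128):
    _vw = max(1, (_d + _warp - 1) // _warp)
    print(_d, _vw, _d // _vw, _vw * 16)
# head_dim=64  -> VEC_WIDTH=1, vecs_per_head=64, copy_bits=16
# head_dim=96  -> VEC_WIDTH=2, vecs_per_head=48, copy_bits=32
# head_dim=128 -> VEC_WIDTH=2, vecs_per_head=64, copy_bits=32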
+ f32_copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) + f32_reg_ty = fx.MemRefType.get(T.f32, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) + f32_reg_lay = fx.make_layout(1, 1) - def load_scalar_i32(buf_tensor, elem_offset): - """Scalar i32 load using soffset for dynamic indexing.""" - div = fx.logical_divide(buf_tensor, fx.make_layout(1, 1)) - base_view = fx.slice(div, (None, fx.Int32(0))) - atom = copy_atom_i32.set_value("soffset", elem_offset) - r = fx.memref_alloca(i32_reg_ty, i32_reg_lay) - fx.copy_atom_call(atom, base_view, r) - return fx.memref_load_vec(r)[0] - - def load_vec(div_tensor, idx): + # Helper: load a VEC_WIDTH vector from a divided 1D tensor at given index + def load_vec(div_tensor, idx, atom=None): r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) - fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) + fx.copy_atom_call(atom or copy_atom, fx.slice(div_tensor, (None, idx)), r) return fx.memref_load_vec(r) - def store_vec(val, div_tensor, idx): + # Helper: store a VEC_WIDTH vector to a divided 1D tensor at given index + def store_vec(val, div_tensor, idx, atom=None): r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) fx.memref_store_vec(val, r) - fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx))) + fx.copy_atom_call(atom or copy_atom, r, fx.slice(div_tensor, (None, idx))) + + # Helper: get the rotary-pair element via ds_bpermute (LDS cross-lane shuffle). + # For NeoX RoPE, the pair of thread tid is tid XOR vecs_per_half. + # ds_bpermute: thread tid reads the VGPR value held by thread (pair_byte_addr/4). + # pair_byte_addr = (tid XOR vecs_per_half) * 4. + # Handles VEC_WIDTH=1 (vector<1xbf16/f16>, 16-bit) and VEC_WIDTH=2 (vector<2xbf16/f16>, 32-bit). + def ds_bpermute_pair(vec_val, pair_byte_addr): + """Return the copy of vec_val held by the rotary-pair thread, via ds_bpermute.""" + if VEC_WIDTH == 1: + # vector<1xf16/bf16> → extract scalar → bitcast to i16 → zero-extend i32 + elem_val = vector.extract(vec_val, static_position=[0], dynamic_position=[]) + i16_val = ArithValue(elem_val).bitcast(T.i16) + i32_val = ArithValue(i16_val).extui(T.i32) + # Cross-lane shuffle: get pair thread's 32-bit VGPR (pair elem in low 16 bits) + peer_i32 = fx.rocdl.ds_bpermute(T.i32, pair_byte_addr, i32_val) + # Truncate back to i16, bitcast to elem_type, reconstruct vector<1xelem_type> + peer_i16 = ArithValue(peer_i32).trunci(T.i16) + peer_elem = ArithValue(peer_i16).bitcast(elem_type) + return vector.from_elements(T.vec(1, elem_type), [peer_elem]) + else: + # VEC_WIDTH>=2: VEC_WIDTH bf16/f16 elements → n_i32 x i32, one ds_bpermute per chunk. + # VEC_WIDTH=2 → n_i32=1 (32 bits); VEC_WIDTH=4 → n_i32=2 (64 bits), etc. 
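# Illustrative only (editor's sketch): the NeoX pair-lane mapping used by the
# ds_bpermute pairing in this kernel. With vecs_per_half = 32 (e.g. D=128,
# VEC_WIDTH=2), lane t exchanges with lane t ^ 32; the address handed to
# ds_bpermute is 4 * (t ^ 32) because it addresses whole 4-byte VGPRs.
_vph = 32
for _t in (0, 1, 31, 32, 63):
    print(_t, _t ^ _vph, 4 * (_t ^ _vph))
# 0<->32, 1<->33, 31<->63, 32<->0, 63<->31: the XOR is symmetric, so lanes in
# both halves find their rotary partner without a branch.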
+ n_i32 = VEC_WIDTH // 2 + v_i32 = vector.bitcast(T.vec(n_i32, T.i32), vec_val) + peer_chunks = [] + for ci in range_constexpr(n_i32): + chunk = vector.extract(v_i32, static_position=[ci], dynamic_position=[]) + peer_chunks.append(fx.rocdl.ds_bpermute(T.i32, pair_byte_addr, chunk)) + peer_v_i32 = vector.from_elements(T.vec(n_i32, T.i32), peer_chunks) + return vector.bitcast(T.vec(VEC_WIDTH, elem_type), peer_v_i32) if tid < fx.Int32(vecs_per_head): - pid_t = pid // num_q_heads - pid_hq = pid % num_q_heads - - pos_val = load_scalar_i32(Pos_buf, pid_t) + # --- Load position (scalar i32) --- + pos_rsrc = buffer_ops.create_buffer_resource(Positions, max_size=True) + if pos_dtype == "i64": + pos_elem_off = ArithValue(pid_t) * 2 + else: + pos_elem_off = pid_t + pos_val = buffer_ops.buffer_load(pos_rsrc, pos_elem_off, vec_width=1, dtype=T.i32) - # Q[pid_t, pid_hq, :] tiled by VEC_WIDTH - q_row = fx.slice(Q_buf, (pid_t, fx.Int32(pid_hq), None)) - q_div = fx.logical_divide(q_row, fx.make_layout(VEC_WIDTH, 1)) - - # Q_out[pid_t, pid_hq, :] tiled by VEC_WIDTH - qo_row = fx.slice(Qo_buf, (pid_t, fx.Int32(pid_hq), None)) - qo_div = fx.logical_divide(qo_row, fx.make_layout(VEC_WIDTH, 1)) - - # cos/sin[pos_val, :] tiled by VEC_WIDTH - cos_row = fx.slice(Cos_buf, (pos_val, None)) - cos_div = fx.logical_divide(cos_row, fx.make_layout(VEC_WIDTH, 1)) - sin_row = fx.slice(Sin_buf, (pos_val, None)) - sin_div = fx.logical_divide(sin_row, fx.make_layout(VEC_WIDTH, 1)) - - # NeoX rotation: pair with opposite half is_first_half = tid < fx.Int32(vecs_per_half) - pair_tid = is_first_half.select(tid + vecs_per_half, tid - vecs_per_half) - cos_vec_idx = tid % vecs_per_half - - qk_e = load_vec(q_div, tid) - cos_e = load_vec(cos_div, cos_vec_idx) - sin_e = load_vec(sin_div, cos_vec_idx) - pair_e = load_vec(q_div, pair_tid) - - qk_cos = qk_e * cos_e - pair_sin = pair_e * sin_e - sin_term = is_first_half.select(-pair_sin, pair_sin) - rot_e = qk_cos + sin_term - - store_vec(rot_e, qo_div, tid) - - # ----- Kernel 2: K RoPE + KV cache write ----- - # Grid: (T * KH, 1, 1), one program per (token, kv_head) - # Each program: vecs_per_head threads process head_dim elements - @flyc.kernel - def k_cache_kernel( - K: fx.Tensor, # [T, KH, D] - V: fx.Tensor, # [T, KH, D] - Positions: fx.Tensor, # [T] int32 - CosCache: fx.Tensor, # [max_pos, half_dim] - SinCache: fx.Tensor, # [max_pos, half_dim] - SlotMapping: fx.Tensor, # [T] int32 - KeyCache: fx.Tensor, # flash: [T_cache, BS, KH, D] - ValueCache: fx.Tensor, # flash: [T_cache, BS, KH, D] - K_out: fx.Tensor, # [T, KH, D] - ): - pid = fx.block_idx.x # program id: 0..T*KH-1 - tid = fx.thread_idx.x # 0..63 - - elem_type = dtype_to_elem_type(dtype_str) - elem_dtype = Numeric.from_ir_type(elem_type) - elem_bits = 16 # bf16/f16 only - - # Buffer-backed tensors via layout API - K_buf = fx.rocdl.make_buffer_tensor(K) - V_buf = fx.rocdl.make_buffer_tensor(V) - Ko_buf = fx.rocdl.make_buffer_tensor(K_out) - Cos_buf = fx.rocdl.make_buffer_tensor(CosCache) - Sin_buf = fx.rocdl.make_buffer_tensor(SinCache) - Pos_buf = fx.rocdl.make_buffer_tensor(Positions) - Slot_buf = fx.rocdl.make_buffer_tensor(SlotMapping) - KC_buf = fx.rocdl.make_buffer_tensor(KeyCache) - VC_buf = fx.rocdl.make_buffer_tensor(ValueCache) - - copy_atom = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), elem_bits) - vec_reg_ty = fx.MemRefType.get( - elem_type, fx.LayoutType.get(VEC_WIDTH, 1), fx.AddressSpace.Register - ) - vec_reg_lay = fx.make_layout(VEC_WIDTH, 1) - - copy_atom_i32 = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), 32) - 
i32_reg_ty = fx.MemRefType.get(T.i32, fx.LayoutType.get(1, 1), fx.AddressSpace.Register) - i32_reg_lay = fx.make_layout(1, 1) - - if not flash_layout: - copy_atom_elem = fx.make_copy_atom(fx.rocdl.BufferCopy16b(), elem_bits) - elem_reg_ty = fx.MemRefType.get( - elem_type, fx.LayoutType.get(1, 1), fx.AddressSpace.Register - ) - elem_reg_lay = fx.make_layout(1, 1) - - def load_scalar_i32(buf_tensor, elem_offset): - """Scalar i32 load using soffset for dynamic indexing.""" - div = fx.logical_divide(buf_tensor, fx.make_layout(1, 1)) - base_view = fx.slice(div, (None, fx.Int32(0))) - atom = copy_atom_i32.set_value("soffset", elem_offset) - r = fx.memref_alloca(i32_reg_ty, i32_reg_lay) - fx.copy_atom_call(atom, base_view, r) - return fx.memref_load_vec(r)[0] - - def load_vec(div_tensor, idx): - r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) - fx.copy_atom_call(copy_atom, fx.slice(div_tensor, (None, idx)), r) - return fx.memref_load_vec(r) - - def store_vec(val, div_tensor, idx): - r = fx.memref_alloca(vec_reg_ty, vec_reg_lay) - fx.memref_store_vec(val, r) - fx.copy_atom_call(copy_atom, r, fx.slice(div_tensor, (None, idx))) - - def store_scalar(val, div_tensor, idx): - r = fx.memref_alloca(elem_reg_ty, elem_reg_lay) - ts = full(1, elem_dtype(val), elem_dtype) - fx.memref_store_vec(ts, r) - fx.copy_atom_call(copy_atom_elem, r, fx.slice(div_tensor, (None, idx))) - - if tid < fx.Int32(vecs_per_head): - pid_t = pid // num_kv_heads - pid_hk = pid % num_kv_heads + cos_vec_idx = tid % vecs_per_half if reuse_freqs_front_part else tid - pos_val = load_scalar_i32(Pos_buf, pid_t) + # Pair lane for ds_bpermute: tid XOR vecs_per_half (symmetric, works for both halves). + # pair_byte_addr = pair_lane * 4 (ds_bpermute address unit is bytes, VGPR = 4 bytes). + pair_lane = ArithValue(tid) ^ fx.Int32(vecs_per_half) + pair_byte_addr = pair_lane * fx.Int32(4) - # K[pid_t, pid_hk, :] tiled by VEC_WIDTH - k_row = fx.slice(K_buf, (pid_t, fx.Int32(pid_hk), None)) - k_div = fx.logical_divide(k_row, fx.make_layout(VEC_WIDTH, 1)) - - # K_out[pid_t, pid_hk, :] tiled by VEC_WIDTH - ko_row = fx.slice(Ko_buf, (pid_t, fx.Int32(pid_hk), None)) - ko_div = fx.logical_divide(ko_row, fx.make_layout(VEC_WIDTH, 1)) - - # cos/sin[pos_val, :] tiled by VEC_WIDTH + # --- Shared cos/sin (loaded once, used by both Q and K) --- + Cos_buf = fx.rocdl.make_buffer_tensor(CosCache) + Sin_buf = fx.rocdl.make_buffer_tensor(SinCache) cos_row = fx.slice(Cos_buf, (pos_val, None)) - cos_div = fx.logical_divide(cos_row, fx.make_layout(VEC_WIDTH, 1)) sin_row = fx.slice(Sin_buf, (pos_val, None)) - sin_div = fx.logical_divide(sin_row, fx.make_layout(VEC_WIDTH, 1)) - - # NeoX rotation - is_first_half = tid < fx.Int32(vecs_per_half) - pair_tid = is_first_half.select(tid + vecs_per_half, tid - vecs_per_half) - cos_vec_idx = tid % vecs_per_half - - qk_e = load_vec(k_div, tid) - cos_e = load_vec(cos_div, cos_vec_idx) - sin_e = load_vec(sin_div, cos_vec_idx) - pair_e = load_vec(k_div, pair_tid) - - qk_cos = qk_e * cos_e - pair_sin = pair_e * sin_e - sin_term = is_first_half.select(-pair_sin, pair_sin) - k_rot_e = qk_cos + sin_term - - store_vec(k_rot_e, ko_div, tid) - - # --- KV Cache write --- - slot_val = load_scalar_i32(Slot_buf, pid_t) - - if slot_val >= fx.Int32(0): - pid_t_slot = slot_val // block_size - pid_b = slot_val % block_size - - # Load V - v_row = fx.slice(V_buf, (pid_t, fx.Int32(pid_hk), None)) - v_div = fx.logical_divide(v_row, fx.make_layout(VEC_WIDTH, 1)) - v_e = load_vec(v_div, tid) - - if flash_layout: - # Flash: [num_blocks, block_size, KH, 
D] → 1D, tile by VEC_WIDTH - kc_row = fx.slice(KC_buf, (pid_t_slot, pid_b, fx.Int32(pid_hk), None)) - kc_div = fx.logical_divide(kc_row, fx.make_layout(VEC_WIDTH, 1)) - vc_row = fx.slice(VC_buf, (pid_t_slot, pid_b, fx.Int32(pid_hk), None)) - vc_div = fx.logical_divide(vc_row, fx.make_layout(VEC_WIDTH, 1)) - - store_vec(k_rot_e, kc_div, tid) - store_vec(v_e, vc_div, tid) + cos_div = fx.logical_divide(cos_row, vec_div_lay) + sin_div = fx.logical_divide(sin_row, vec_div_lay) + cos_e = load_vec(cos_div, cos_vec_idx) + sin_e = load_vec(sin_div, cos_vec_idx) + + # --- Q RoPE (head_idx < num_q_heads) --- + if head_idx < fx.Int32(num_q_heads_val): + Q_buf = fx.rocdl.make_buffer_tensor(Q) + Q_out_buf = fx.rocdl.make_buffer_tensor(Q_out) + + q_row = fx.slice(Q_buf, (pid_t, head_idx, None)) + q_div = fx.logical_divide(q_row, vec_div_lay) + qo_row = fx.slice(Q_out_buf, (pid_t, head_idx, None)) + qo_div = fx.logical_divide(qo_row, vec_div_lay) + + q_e_vec = load_vec(q_div, tid) + q_e = ArithValue(q_e_vec) + # Use ds_bpermute to get pair element via LDS cross-lane shuffle (no VMEM). + q_pair_e = ArithValue(ds_bpermute_pair(q_e_vec, pair_byte_addr)) + + q_cos = q_e * ArithValue(cos_e) + q_pair_sin = q_pair_e * ArithValue(sin_e) + q_sin_term = is_first_half.select(-q_pair_sin, q_pair_sin) + q_rot_e = q_cos + q_sin_term + + store_vec(q_rot_e.ir_value(), qo_div, tid) + + # --- K RoPE + KV cache (head_idx < num_kv_heads) --- + if head_idx < fx.Int32(num_kv_heads_val): + K_buf = fx.rocdl.make_buffer_tensor(K) + K_out_buf = fx.rocdl.make_buffer_tensor(K_out) + + k_row = fx.slice(K_buf, (pid_t, head_idx, None)) + k_div = fx.logical_divide(k_row, vec_div_lay) + ko_row = fx.slice(K_out_buf, (pid_t, head_idx, None)) + ko_div = fx.logical_divide(ko_row, vec_div_lay) + + k_e_vec = load_vec(k_div, tid) + k_e = ArithValue(k_e_vec) + # Use ds_bpermute to get pair element via LDS cross-lane shuffle (no VMEM). 
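# Illustrative only (editor's sketch, not part of the patch): a host-side NeoX
# reference matching the per-lane math used for Q above and for K below
# (first half: x*cos - pair*sin; second half: x*cos + pair*sin).
def _neox_rotate_ref(x, cos, sin):
    # x: [..., D]; cos/sin: [..., D//2] (reuse_freqs_front_part layout)
    import torch
    x1, x2 = x.float().chunk(2, dim=-1)
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)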
+ k_pair_e = ArithValue(ds_bpermute_pair(k_e_vec, pair_byte_addr)) + + k_cos = k_e * ArithValue(cos_e) + k_pair_sin = k_pair_e * ArithValue(sin_e) + k_sin_term = is_first_half.select(-k_pair_sin, k_pair_sin) + k_rot_e = k_cos + k_sin_term + + store_vec(k_rot_e.ir_value(), ko_div, tid) + # K_buf, K_out_buf now dead — 8 SGPRs freed + + # --- KV Cache write --- + slot_rsrc = buffer_ops.create_buffer_resource(SlotMapping, max_size=True) + if pos_dtype == "i64": + slot_elem_off = ArithValue(pid_t) * 2 else: - # Non-flash key_cache: [num_blocks, KH, D//x, BS, x] - dim_group = (tid * VEC_WIDTH) // x_size - sub_tile = tid % (x_size // VEC_WIDTH) - - kc_nf_row = fx.slice(KC_buf, (pid_t_slot, fx.Int32(pid_hk), dim_group, pid_b, None)) - kc_nf_div = fx.logical_divide(kc_nf_row, fx.make_layout(VEC_WIDTH, 1)) - store_vec(k_rot_e, kc_nf_div, sub_tile) - - # Non-flash value_cache: [num_blocks, KH, D, block_size] - for vi in range_constexpr(VEC_WIDTH): - v_scalar = v_e[vi] - d_idx = tid * VEC_WIDTH + vi - vc_row = fx.slice(VC_buf, (pid_t_slot, fx.Int32(pid_hk), d_idx, None)) - vc_div = fx.logical_divide(vc_row, fx.make_layout(1, 1)) - store_scalar(v_scalar, vc_div, pid_b) + slot_elem_off = pid_t + slot_val = buffer_ops.buffer_load(slot_rsrc, slot_elem_off, vec_width=1, dtype=T.i32) + + if slot_val >= fx.Int32(0): + pid_t_slot = ArithValue(slot_val) // block_size + pid_b = ArithValue(slot_val) % block_size + + # Load V via layout API (deferred here to minimize SGPR liveness) + V_buf = fx.rocdl.make_buffer_tensor(V) + v_row = fx.slice(V_buf, (pid_t, head_idx, None)) + v_div = fx.logical_divide(v_row, vec_div_lay) + v_e = load_vec(v_div, tid) + + if apply_scale: + # --- fp8 KV cache path (raw buffer_ops for fp8 intrinsics) --- + ks_buf = fx.rocdl.make_buffer_tensor(KScale) + vs_buf = fx.rocdl.make_buffer_tensor(VScale) + ks_div = fx.logical_divide(fx.slice(ks_buf, (None,)), f32_reg_lay) + vs_div = fx.logical_divide(fx.slice(vs_buf, (None,)), f32_reg_lay) + r_ks = fx.memref_alloca(f32_reg_ty, f32_reg_lay) + r_vs = fx.memref_alloca(f32_reg_ty, f32_reg_lay) + fx.copy_atom_call(f32_copy_atom, fx.slice(ks_div, (None, fx.Int32(0))), r_ks) + fx.copy_atom_call(f32_copy_atom, fx.slice(vs_div, (None, fx.Int32(0))), r_vs) + k_scale_val = vector.extract(fx.memref_load_vec(r_ks), static_position=[0], dynamic_position=[]) + v_scale_val = vector.extract(fx.memref_load_vec(r_vs), static_position=[0], dynamic_position=[]) + k_rcp = fx.rocdl.rcp(T.f32, k_scale_val) + v_rcp = fx.rocdl.rcp(T.f32, v_scale_val) + + k_scaled = [] + v_scaled = [] + for i in range_constexpr(VEC_WIDTH): + # Always use vector.extract; works for VEC_WIDTH=1 (vector<1xbf16>) + # and VEC_WIDTH>1 equally. 
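# Illustrative only (editor's sketch): what the fp8 scaling path here ends up
# storing, seen from the host. The kernel multiplies by rcp(scale) and
# saturates through cvt_pk_fp8_f32; a torch-side approximation of the cached
# values (fp8 dtype chosen per-arch, as in the tests) is:
def _fp8_cache_ref(x, scale, fp8_dtype):
    import torch
    fp8_max = torch.finfo(fp8_dtype).max
    return (x.float() / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
# rcp-based division in the kernel can differ from true division by an ulp, so
# comparisons against this reference should allow a small tolerance.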
+ ke = ArithValue(vector.extract(k_rot_e.ir_value(), static_position=[i], dynamic_position=[])).extf(T.f32) * k_rcp + ve = ArithValue(vector.extract(v_e, static_position=[i], dynamic_position=[])).extf(T.f32) * v_rcp + k_scaled.append(ke) + v_scaled.append(ve) + + # fp8 packing and store + kc_fp8_rsrc = buffer_ops.create_buffer_resource(KeyCache, max_size=True) + vc_fp8_rsrc = buffer_ops.create_buffer_resource(ValueCache, max_size=True) + + if VEC_WIDTH >= 4: + def pack_fp8(vals): + i32s = [] + for i in range_constexpr(VEC_WIDTH // 4): + lo = fx.rocdl.cvt_pk_fp8_f32( + T.i32, vals[i * 4], vals[i * 4 + 1], fx.Int32(0), False + ) + wd = fx.rocdl.cvt_pk_fp8_f32( + T.i32, vals[i * 4 + 2], vals[i * 4 + 3], lo, True + ) + i32s.append(wd) + return i32s + + k_fp8 = pack_fp8(k_scaled) + v_fp8 = pack_fp8(v_scaled) + + if flash_layout: + kc_byte_off = ( + pid_t_slot * (block_size * num_kv_heads * head_dim) + + pid_b * (num_kv_heads * head_dim) + + ArithValue(head_idx) * head_dim + + ArithValue(tid) * VEC_WIDTH + ) + kc_dw = kc_byte_off // fx.Int32(4) + for wi in range_constexpr(VEC_WIDTH // 4): + buffer_ops.buffer_store(k_fp8[wi], kc_fp8_rsrc, kc_dw + fx.Int32(wi)) + buffer_ops.buffer_store(v_fp8[wi], vc_fp8_rsrc, kc_dw + fx.Int32(wi)) + else: + dim_group = ArithValue(tid) * VEC_WIDTH // x_size + sub_off = ArithValue(tid) * VEC_WIDTH % x_size + kc_byte_off = ( + pid_t_slot * (num_kv_heads * (head_dim // x_size) * block_size * x_size) + + ArithValue(head_idx) * ((head_dim // x_size) * block_size * x_size) + + dim_group * (block_size * x_size) + + pid_b * x_size + + sub_off + ) + kc_dw = kc_byte_off // fx.Int32(4) + for wi in range_constexpr(VEC_WIDTH // 4): + buffer_ops.buffer_store(k_fp8[wi], kc_fp8_rsrc, kc_dw + fx.Int32(wi)) + + for vi in range_constexpr(VEC_WIDTH): + d_idx = ArithValue(tid) * VEC_WIDTH + vi + vc_byte_off = ( + pid_t_slot * (num_kv_heads * head_dim * block_size) + + ArithValue(head_idx) * (head_dim * block_size) + + d_idx * block_size + + pid_b + ) + i32_idx = vi // 4 + byte_in_i32 = vi % 4 + shifted = ArithValue(v_fp8[i32_idx]) >> (byte_in_i32 * 8) + fp8_byte = arith.trunci(T.i8, shifted) + buffer_ops.buffer_store(fp8_byte, vc_fp8_rsrc, vc_byte_off) + else: + # VEC_WIDTH < 4: store individual fp8 bytes + for vi in range_constexpr(VEC_WIDTH): + k_pk = fx.rocdl.cvt_pk_fp8_f32( + T.i32, k_scaled[vi], fx.Float32(0.0), fx.Int32(0), False + ) + v_pk = fx.rocdl.cvt_pk_fp8_f32( + T.i32, v_scaled[vi], fx.Float32(0.0), fx.Int32(0), False + ) + k_byte = arith.trunci(T.i8, k_pk) + v_byte = arith.trunci(T.i8, v_pk) + + d_idx = ArithValue(tid) * VEC_WIDTH + vi + + if flash_layout: + byte_off = ( + pid_t_slot * (block_size * num_kv_heads * head_dim) + + pid_b * (num_kv_heads * head_dim) + + ArithValue(head_idx) * head_dim + + d_idx + ) + buffer_ops.buffer_store(k_byte, kc_fp8_rsrc, byte_off) + buffer_ops.buffer_store(v_byte, vc_fp8_rsrc, byte_off) + else: + dim_grp = d_idx // x_size + sub_o = d_idx % x_size + kc_byte_off = ( + pid_t_slot * (num_kv_heads * (head_dim // x_size) * block_size * x_size) + + ArithValue(head_idx) * ((head_dim // x_size) * block_size * x_size) + + dim_grp * (block_size * x_size) + + pid_b * x_size + + sub_o + ) + buffer_ops.buffer_store(k_byte, kc_fp8_rsrc, kc_byte_off) + + vc_byte_off = ( + pid_t_slot * (num_kv_heads * head_dim * block_size) + + ArithValue(head_idx) * (head_dim * block_size) + + d_idx * block_size + + pid_b + ) + buffer_ops.buffer_store(v_byte, vc_fp8_rsrc, vc_byte_off) + else: + # --- bf16/f16 KV cache path --- + if flash_layout: + # Flash layout: 
contiguous [num_blocks, block_size, KH, D] + KC_buf = fx.rocdl.make_buffer_tensor(KeyCache) + VC_buf = fx.rocdl.make_buffer_tensor(ValueCache) + kc_row = fx.slice(KC_buf, (pid_t_slot, pid_b, head_idx, None)) + vc_row = fx.slice(VC_buf, (pid_t_slot, pid_b, head_idx, None)) + kc_div = fx.logical_divide(kc_row, vec_div_lay) + vc_div = fx.logical_divide(vc_row, vec_div_lay) + store_vec(k_rot_e.ir_value(), kc_div, tid) + store_vec(v_e, vc_div, tid) + else: + # Non-flash layout: scattered stores, keep raw buffer_ops + kc_rsrc = buffer_ops.create_buffer_resource(KeyCache, max_size=True) + vc_rsrc = buffer_ops.create_buffer_resource(ValueCache, max_size=True) + for vi in range_constexpr(VEC_WIDTH): + d_idx = ArithValue(tid) * VEC_WIDTH + vi + dim_grp = d_idx // x_size + sub_o = d_idx % x_size + kc_nf_off = ( + pid_t_slot * (num_kv_heads * (head_dim // x_size) * block_size * x_size) + + ArithValue(head_idx) * ((head_dim // x_size) * block_size * x_size) + + dim_grp * (block_size * x_size) + + pid_b * x_size + + sub_o + ) + k_elem = vector.extract(k_rot_e.ir_value(), static_position=[vi], dynamic_position=[]) + buffer_ops.buffer_store(k_elem, kc_rsrc, kc_nf_off) + + for vi in range_constexpr(VEC_WIDTH): + d_idx = ArithValue(tid) * VEC_WIDTH + vi + vc_nf_off = ( + pid_t_slot * (num_kv_heads * head_dim * block_size) + + ArithValue(head_idx) * (head_dim * block_size) + + d_idx * block_size + + pid_b + ) + v_elem = vector.extract(v_e, static_position=[vi], dynamic_position=[]) + buffer_ops.buffer_store(v_elem, vc_rsrc, vc_nf_off) @flyc.jit def launch_fused_rope_cache( @@ -364,25 +453,16 @@ def launch_fused_rope_cache( Q_out: fx.Tensor, K_out: fx.Tensor, num_tokens: fx.Int32, + KScale: fx.Tensor, + VScale: fx.Tensor, stream: fx.Stream = fx.Stream(None), ): - # Kernel 1: Q RoPE - n_q = num_tokens * num_q_heads - q_launcher = q_rope_kernel(Q, Positions, CosCache, SinCache, Q_out) - q_launcher.launch( - grid=(n_q, 1, 1), - block=(BLOCK_THREADS, 1, 1), - stream=stream, - ) - - # Kernel 2: K RoPE + KV cache write - n_k = num_tokens * num_kv_heads - k_launcher = k_cache_kernel( - K, V, Positions, CosCache, SinCache, SlotMapping, - KeyCache, ValueCache, K_out, + launcher = fused_qk_rope_reshape_and_cache( + Q, K, V, Positions, CosCache, SinCache, SlotMapping, + KeyCache, ValueCache, Q_out, K_out, KScale, VScale, ) - k_launcher.launch( - grid=(n_k, 1, 1), + launcher.launch( + grid=(max_heads, num_tokens, 1), block=(BLOCK_THREADS, 1, 1), stream=stream, ) diff --git a/tests/kernels/test_fused_rope_cache.py b/tests/kernels/test_fused_rope_cache.py index 5e0cee71..62b1ef91 100644 --- a/tests/kernels/test_fused_rope_cache.py +++ b/tests/kernels/test_fused_rope_cache.py @@ -2,12 +2,36 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2025 FlyDSL Project Contributors -"""Fused RoPE + KV Cache kernel test. -Tests correctness of the fused kernel against PyTorch reference. -Supports both flash and non-flash KV cache layouts. - -Usage: +"""Fused RoPE + KV Cache kernel correctness tests. + +Calls ``build_fused_rope_cache_module`` directly (no aiter wrapper) and +validates Q/K rotation outputs and KV cache writes against a pure-PyTorch +reference. When AITER is available (installed or reachable via AITER_REPO), +results are also cross-checked against the Triton reference implementation. 
+ +Test dimensions covered +----------------------- +Model configs (QH, KH, D): + - Llama-8B TP1: QH=32, KH=8, D=128 + - Llama-8B TP8: QH=4, KH=1, D=128 + - Llama-70B TP1: QH=64, KH=8, D=128 + - Llama-70B TP8: QH=8, KH=1, D=128 + - Llama-405B TP1: QH=128,KH=8, D=128 + - Llama-405B TP8: QH=16, KH=1, D=128 + - Qwen3-72B TP1: QH=64, KH=4, D=128 + - Qwen3-72B TP8: QH=8, KH=1, D=128 + - GPT-OSS TP1: QH=64, KH=8, D=64 + - GPT-OSS TP8: QH=8, KH=1, D=64 + +Token counts: T=1 (decode), T=32, T=128 (prefill) +KV cache layouts: flash_layout=True / False +Scale: apply_scale=True (fp8 cache) / False (bf16/f16 cache) +Position dtype: i32 / i64 (i64 uses .view(i32) stride-2 indexing) +Cos/sin dim: reuse_freqs_front_part=True (half-dim) / False (full-dim) + +Usage +----- # Fast CI — correctness only (GPT-OSS 120B TP=8, 10 tests): PYTHONPATH=./ pytest tests/kernels/test_fused_rope_cache.py -v -s @@ -25,339 +49,663 @@ """ import os -import logging +import sys -import torch import pytest +import torch +from flydsl.runtime.device import get_rocm_arch as _get_rocm_arch from kernels.fused_rope_cache_kernel import build_fused_rope_cache_module -logging.basicConfig(level=logging.INFO) - -# Cache compiled kernels to avoid redundant JIT compilation across parametrized tests. -# Key: (head_dim, num_q_heads, num_kv_heads, block_size, flash_layout, dtype_str) -_launch_fn_cache: dict = {} - +# --------------------------------------------------------------------------- +# Skip if no GPU +# --------------------------------------------------------------------------- +if not torch.cuda.is_available(): + pytest.skip("CUDA/ROCm not available.", allow_module_level=True) -def _get_launch_fn(head_dim, num_q_heads, num_kv_heads, block_size, flash_layout, dtype_str="bf16"): - key = (head_dim, num_q_heads, num_kv_heads, block_size, flash_layout, dtype_str) - if key not in _launch_fn_cache: - _launch_fn_cache[key] = build_fused_rope_cache_module( - head_dim=head_dim, num_q_heads=num_q_heads, num_kv_heads=num_kv_heads, - block_size=block_size, is_neox=True, flash_layout=flash_layout, dtype_str=dtype_str, - ) - return _launch_fn_cache[key] +# --------------------------------------------------------------------------- +# Optional AITER Triton cross-check +# AITER_REPO env var: path to aiter repo root (added to sys.path if set). +# Falls back to installed aiter package. 
+# --------------------------------------------------------------------------- +_AITER_REPO = os.environ.get("AITER_REPO", "") +if _AITER_REPO and _AITER_REPO not in sys.path: + sys.path.insert(0, _AITER_REPO) try: - from tests.kernels.benchmark_common import bench_gpu_us_torch, maybe_enable_aiter - HAS_BENCH = True + from aiter.ops.triton.fusions.fused_kv_cache import fused_qk_rope_reshape_and_cache as _aiter_rope + + HAS_AITER = True except ImportError: - try: - from benchmark_common import bench_gpu_us_torch, maybe_enable_aiter - HAS_BENCH = True - except ImportError: - HAS_BENCH = False + HAS_AITER = False -if not torch.cuda.is_available(): - pytest.skip("CUDA/ROCm not available.", allow_module_level=True) +def _bench_gpu_us(fn, warmup: int = 20, iters: int = 200) -> float: + """Measure GPU kernel time via CUDA events (true device time, no Python-loop overhead).""" + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + for _ in range(warmup): + fn() + torch.cuda.synchronize() + start.record() + for _ in range(iters): + fn() + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end) * 1e3 / iters # ms → µs + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- BLOCK_SIZE = 16 + +# fp8 dtype: gfx95x (MI350/MI355X) uses e4m3fn; gfx94x (MI300X) uses e4m3fnuz. +# RDNA (gfx10xx/gfx11xx/gfx12xx): fp8 KV cache is a CDNA production feature; +# cvt_pk_fp8_f32 produces a different bit encoding on RDNA, so fp8 cache tests +# are skipped there to avoid false failures from dtype mismatches. +_ARCH = str(_get_rocm_arch()) +_IS_RDNA = not _ARCH.startswith("gfx9") +FP8_DTYPE = torch.float8_e4m3fn if "gfx95" in _ARCH else torch.float8_e4m3fnuz MAX_POS = 8192 +X_SIZE = 16 # x-pack factor in non-flash key cache layout + +# Default atol per dtype +_ATOL = {"bf16": 1e-2, "f16": 5e-3} + +# --------------------------------------------------------------------------- +# Kernel compilation cache +# Keyed by all build-time parameters so each unique config compiles once. 
+# --------------------------------------------------------------------------- +_kernel_cache: dict = {} + + +def _get_launch_fn( + head_dim: int, + num_q_heads: int, + num_kv_heads: int, + block_size: int, + flash_layout: bool, + dtype_str: str, + apply_scale: bool, + reuse_freqs_front_part: bool, + pos_dtype: str, +): + key = (head_dim, num_q_heads, num_kv_heads, block_size, + flash_layout, dtype_str, apply_scale, reuse_freqs_front_part, pos_dtype) + if key not in _kernel_cache: + _kernel_cache[key] = build_fused_rope_cache_module( + head_dim=head_dim, + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + block_size=block_size, + is_neox=True, + flash_layout=flash_layout, + dtype_str=dtype_str, + apply_scale=apply_scale, + reuse_freqs_front_part=reuse_freqs_front_part, + pos_dtype=pos_dtype, + ) + return _kernel_cache[key] + + +# --------------------------------------------------------------------------- +# Reference implementation +# --------------------------------------------------------------------------- -# Model configs: (head_dim, total_q_heads, total_kv_heads) -MODEL_CONFIGS = { - "GPT-OSS-120B": (64, 64, 8), - "Qwen3-235B-MoE": (64, 64, 4), - "Llama-3.1-8B": (128, 32, 8), - "Llama-3.1-70B": (128, 64, 8), - "Qwen3-72B": (128, 64, 8), - "Llama-3.1-405B": (128, 128, 8), -} - -# Default: GPT-OSS 120B TP=8 (fast CI) -HEAD_DIM = 64 -NUM_Q_HEADS = 8 -NUM_KV_HEADS = 1 - - -def fused_rope_cache_ref(q, k, v, cos_cache, sin_cache, positions, slot_mapping, - key_cache, value_cache, block_size, flash_layout=True): - """PyTorch reference for fused RoPE + KV cache. - - Computes rotation in native dtype (bf16/f16) to match AITER/Triton - and FlyDSL precision. Each multiply truncates to native dtype before - the subsequent add/subtract, matching GPU hardware behavior. +def _rope_ref(q, k, v, cos_cache, sin_cache, positions, slot_mapping, + key_cache, value_cache, block_size, flash_layout, + reuse_freqs_front_part): + """Pure-PyTorch NeoX RoPE + KV cache reference. + + Operates in native dtype (bf16/f16) to match GPU hardware rounding. + Half-dim cos/sin are broadcast over the full head as [cos, cos] / [sin, sin]. 
""" - half_dim = cos_cache.shape[-1] + half_dim = cos_cache.shape[-1] # D//2 when reuse_freqs=True, else D dtype = q.dtype - cos = cos_cache[positions.long()].unsqueeze(1).to(dtype) + + # Index into cos/sin cache by position + cos = cos_cache[positions.long()].unsqueeze(1).to(dtype) # [T, 1, cols] sin = sin_cache[positions.long()].unsqueeze(1).to(dtype) - q1, q2 = q[..., :half_dim], q[..., half_dim:] - q_out = torch.cat([q1 * cos - q2 * sin, q2 * cos + q1 * sin], dim=-1) + # Expand half-dim to full-dim if reuse_freqs_front_part=True + if reuse_freqs_front_part: + # cos/sin shape: [T, 1, D//2] → replicate to [T, 1, D] + cos = torch.cat([cos, cos], dim=-1) + sin = torch.cat([sin, sin], dim=-1) + + # NeoX rotation: q_out = [q1*cos - q2*sin, q2*cos + q1*sin] + head_dim = q.shape[-1] + q1, q2 = q[..., :head_dim // 2], q[..., head_dim // 2:] + k1, k2 = k[..., :head_dim // 2], k[..., head_dim // 2:] - k1, k2 = k[..., :half_dim], k[..., half_dim:] - k_out = torch.cat([k1 * cos - k2 * sin, k2 * cos + k1 * sin], dim=-1) + q_out = torch.cat([q1 * cos[..., :head_dim // 2] - q2 * sin[..., :head_dim // 2], + q2 * cos[..., head_dim // 2:] + q1 * sin[..., head_dim // 2:]], dim=-1) + k_out = torch.cat([k1 * cos[..., :head_dim // 2] - k2 * sin[..., :head_dim // 2], + k2 * cos[..., head_dim // 2:] + k1 * sin[..., head_dim // 2:]], dim=-1) key_cache_out = key_cache.clone() value_cache_out = value_cache.clone() - slots_cpu = slot_mapping.cpu().tolist() - for i, slot in enumerate(slots_cpu): - if slot >= 0: - bi = slot // block_size - bp = slot % block_size - if flash_layout: - key_cache_out[bi, bp] = k_out[i] - value_cache_out[bi, bp] = v[i] - else: - # key_cache: [num_blocks, KH, D//x, block_size, x] - x = 16 - k_row = k_out[i] # [KH, D] - key_cache_out[bi, :, :, bp, :] = k_row.view( - k_row.shape[0], k_row.shape[1] // x, x) - # value_cache: [num_blocks, KH, D, block_size] - value_cache_out[bi, :, :, bp] = v[i] - return q_out, k_out, key_cache_out, value_cache_out + for i, slot in enumerate(slot_mapping.cpu().tolist()): + if slot < 0: + continue + bi = slot // block_size + bp = slot % block_size + if flash_layout: + key_cache_out[bi, bp] = k_out[i] + value_cache_out[bi, bp] = v[i] + else: + # key_cache: [num_blocks, KH, D//x, block_size, x] + k_row = k_out[i] # [KH, D] + key_cache_out[bi, :, :, bp, :] = k_row.view(k_row.shape[0], k_row.shape[1] // X_SIZE, X_SIZE) + # value_cache: [num_blocks, KH, D, block_size] + value_cache_out[bi, :, :, bp] = v[i] + return q_out, k_out, key_cache_out, value_cache_out -def run_fused_test(num_tokens, head_dim=HEAD_DIM, num_q_heads=NUM_Q_HEADS, - num_kv_heads=NUM_KV_HEADS, block_size=BLOCK_SIZE, - max_pos=MAX_POS, flash_layout=True, negative_slots=False, - dtype_str="bf16"): - """Run fused RoPE + KV cache kernel test. - Args: - negative_slots: If True, set odd-indexed slots to -1 to exercise - the slot < 0 (skip KV cache write) path. - dtype_str: Element dtype ("bf16" or "f16"). 
+# --------------------------------------------------------------------------- +# Core test runner +# --------------------------------------------------------------------------- + +def run_test( + num_tokens: int, + head_dim: int, + num_q_heads: int, + num_kv_heads: int, + flash_layout: bool = True, + dtype_str: str = "bf16", + apply_scale: bool = False, + reuse_freqs_front_part: bool = True, + pos_dtype: str = "i32", + negative_slots: bool = False, + block_size: int = BLOCK_SIZE, + max_pos: int = MAX_POS, + bench: bool = False, +): + """Build kernel, run it, and compare against reference (and AITER if available). + + Returns (passed, max_errors_dict). + When bench=True or FLYDSL_BENCH=1, also prints FlyDSL vs AITER timing. """ device = torch.device("cuda") torch_dtype = torch.bfloat16 if dtype_str == "bf16" else torch.float16 - num_blocks = max(32, (num_tokens + block_size - 1) // block_size + 1) - rotary_dim = head_dim # full rotation - - layout_name = "flash" if flash_layout else "non-flash" - print(f"[fused_rope_cache] M={num_tokens}, BS={block_size}, " - f"QH={num_q_heads}, KH={num_kv_heads}, D={head_dim}, layout={layout_name}, dtype={dtype_str}") - - launch_fn = _get_launch_fn(head_dim, num_q_heads, num_kv_heads, block_size, flash_layout, dtype_str) + num_blocks = max(32, (num_tokens + block_size - 1) // block_size + 4) + half_dim = head_dim // 2 + cos_sin_cols = half_dim if reuse_freqs_front_part else head_dim + + launch_fn = _get_launch_fn( + head_dim=head_dim, + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + block_size=block_size, + flash_layout=flash_layout, + dtype_str=dtype_str, + apply_scale=apply_scale, + reuse_freqs_front_part=reuse_freqs_front_part, + pos_dtype=pos_dtype, + ) torch.manual_seed(42) q = torch.randn(num_tokens, num_q_heads, head_dim, device=device, dtype=torch_dtype) k = torch.randn(num_tokens, num_kv_heads, head_dim, device=device, dtype=torch_dtype) v = torch.randn(num_tokens, num_kv_heads, head_dim, device=device, dtype=torch_dtype) - cos_cache = torch.randn(max_pos, rotary_dim // 2, device=device, dtype=torch_dtype) - sin_cache = torch.randn(max_pos, rotary_dim // 2, device=device, dtype=torch_dtype) - positions = torch.randint(0, max_pos, (num_tokens,), device=device, dtype=torch.int32) + cos_cache = torch.randn(max_pos, cos_sin_cols, device=device, dtype=torch_dtype) + sin_cache = torch.randn(max_pos, cos_sin_cols, device=device, dtype=torch_dtype) + + # Positions: i32 or i64 (i64 stored as int64 but kernel reads via stride-2 i32 view) + positions_i32 = torch.randint(0, max_pos, (num_tokens,), device=device, dtype=torch.int32) + if pos_dtype == "i64": + # The kernel expects positions as int64 tensor but reads each element + # as two consecutive i32 words, taking only the low word (little-endian). 
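# Plain-torch illustration of the stride-2 trick described in the comment above:
# on little-endian hardware an int64 tensor viewed as int32 exposes each
# element's low word at the even indices, so positions below 2**31 read back intact.
import torch

pos64 = torch.tensor([5, 300, 8191], dtype=torch.int64)
pos32 = pos64.view(torch.int32)                 # twice as many i32 words
assert pos32[0::2].tolist() == [5, 300, 8191]   # low words == original values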
+ positions_tensor = positions_i32.to(torch.int64) + else: + positions_tensor = positions_i32 + slot_mapping = torch.arange(num_tokens, device=device, dtype=torch.int32) if negative_slots: - # Set odd-indexed slots to -1 so their KV cache writes are skipped slot_mapping[1::2] = -1 - x_size = 16 + if pos_dtype == "i64": + slot_mapping_tensor = slot_mapping.to(torch.int64) + else: + slot_mapping_tensor = slot_mapping + if flash_layout: key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim, device=device, dtype=torch_dtype) value_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim, device=device, dtype=torch_dtype) else: - key_cache = torch.zeros(num_blocks, num_kv_heads, head_dim // x_size, block_size, x_size, + key_cache = torch.zeros(num_blocks, num_kv_heads, head_dim // X_SIZE, block_size, X_SIZE, device=device, dtype=torch_dtype) value_cache = torch.zeros(num_blocks, num_kv_heads, head_dim, block_size, device=device, dtype=torch_dtype) + if apply_scale: + # fp8 cache: allocate as fp8 type for storage, but kernel uses raw buffer_ops. + # Scales must be 1-D tensors (FlyDSL requires at least one dimension). + kc_fp8 = torch.zeros_like(key_cache).to(FP8_DTYPE) + vc_fp8 = torch.zeros_like(value_cache).to(FP8_DTYPE) + kv_scale = 0.1 # round-trip friendly: maps bf16 range into fp8 range + k_scale = torch.tensor([kv_scale], dtype=torch.float32, device=device) + v_scale = torch.tensor([kv_scale], dtype=torch.float32, device=device) + else: + kc_fp8 = key_cache + vc_fp8 = value_cache + k_scale = torch.ones(1, dtype=torch.float32, device=device) + v_scale = torch.ones(1, dtype=torch.float32, device=device) + q_out = torch.empty_like(q) k_out = torch.empty_like(k) - # Reference - q_ref, k_ref, kc_ref, vc_ref = fused_rope_cache_ref( - q, k, v, cos_cache, sin_cache, positions, slot_mapping, - key_cache.clone(), value_cache.clone(), block_size, flash_layout=flash_layout, - ) - - # Launch FlyDSL kernel — correctness run stream = torch.cuda.current_stream() - launch_fn(q, k, v, positions, cos_cache, sin_cache, slot_mapping, - key_cache, value_cache, q_out, k_out, num_tokens, stream=stream) + launch_fn( + q, k, v, + positions_tensor, cos_cache, sin_cache, + slot_mapping_tensor, + kc_fp8, vc_fp8, + q_out, k_out, + num_tokens, + k_scale, v_scale, + stream=stream, + ) torch.cuda.synchronize() - # Perf measurement — opt-in via FLYDSL_BENCH=1 to avoid slowing CI - run_bench = HAS_BENCH and os.environ.get("FLYDSL_BENCH", "0") == "1" - if run_bench: - def run_flydsl(): - launch_fn(q, k, v, positions, cos_cache, sin_cache, slot_mapping, - key_cache, value_cache, q_out, k_out, num_tokens, stream=stream) - us = bench_gpu_us_torch(run_flydsl, warmup=10, iters=100) - - # Compute bandwidth - total_bytes = (q.nelement() + k.nelement() + v.nelement()) * 2 * 2 # read+write bf16 - total_bytes += cos_cache[0:1].nelement() * 2 * 2 * num_tokens # cos+sin per token - bw_gbs = total_bytes / (us * 1e-6) / 1e9 if us > 0 else 0 - print(f" [flyc] {us:.1f} us, BW: {bw_gbs:.2f} GB/s") - else: - us = 0.0 + # Reference (bf16/f16 path only — fp8 correctness checked separately) + q_ref, k_ref, kc_ref, vc_ref = _rope_ref( + q, k, v, + cos_cache, sin_cache, + positions_i32, # always i32 for reference indexing + slot_mapping, # always i32 + key_cache.clone(), value_cache.clone(), + block_size, + flash_layout=flash_layout, + reuse_freqs_front_part=reuse_freqs_front_part, + ) - # Verify — dtype-specific tolerance (bf16 eps ~0.0078, f16 eps ~0.001) - atol = 1e-2 if dtype_str == "bf16" else 5e-3 + atol = 
_ATOL[dtype_str] q_err = (q_out.float() - q_ref.float()).abs().max().item() k_err = (k_out.float() - k_ref.float()).abs().max().item() - # Compare full KV cache tensors (same layout for ref and kernel) - kc_err = (key_cache.float() - kc_ref.float()).abs().max().item() - vc_err = (value_cache.float() - vc_ref.float()).abs().max().item() - - print(f" q_err={q_err:.6f}, k_err={k_err:.6f}, kc_err={kc_err:.6f}, vc_err={vc_err:.6f}") - - # Optional AITER comparison (requires FLYDSL_BENCH=1) - # Skip when negative_slots: AITER may leave k_out uninitialized for skipped - # tokens (output_zeros=False), making the cross-check meaningless. - if run_bench and not negative_slots and maybe_enable_aiter(): - try: - from aiter.ops.triton.fusions.fused_kv_cache import fused_qk_rope_reshape_and_cache - except ImportError: - try: - from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache - except ImportError: - fused_qk_rope_reshape_and_cache = None - - if fused_qk_rope_reshape_and_cache is not None: - cos_4d = cos_cache.unsqueeze(1).unsqueeze(1) - sin_4d = sin_cache.unsqueeze(1).unsqueeze(1) - pos_i64 = positions.to(torch.int64) - slots_i64 = slot_mapping.to(torch.int64) - kc_aiter = torch.zeros_like(key_cache) - vc_aiter = torch.zeros_like(value_cache) - qo_aiter = torch.empty_like(q) - ko_aiter = torch.empty_like(k) - # Pre-clone inputs so clone overhead is NOT in timed region - q_aiter = q.clone() - k_aiter = k.clone() - v_aiter = v.clone() - ks = torch.tensor([1.0], device=device, dtype=torch.float32) - vs = torch.tensor([1.0], device=device, dtype=torch.float32) - - def launch_aiter(): - fused_qk_rope_reshape_and_cache( - q_aiter, k_aiter, v_aiter, kc_aiter, vc_aiter, - slots_i64, pos_i64, cos_4d, sin_4d, ks, vs, + if not apply_scale: + kc_err = (kc_fp8.float() - kc_ref.float()).abs().max().item() + vc_err = (vc_fp8.float() - vc_ref.float()).abs().max().item() + passed = q_err < atol and k_err < atol and kc_err < atol and vc_err < atol + errs = {"q": q_err, "k": k_err, "kc": kc_err, "vc": vc_err} + else: + # fp8: dequantize the written cache with per-tensor scales and compare + # against the bf16 reference. This catches packing/indexing/scaling bugs + # and validates that negative-slot entries are left unchanged (the reference + # preserves zeros for skipped slots, so kc_ref/vc_ref already encodes that). + kc_deq = kc_fp8.to(torch.float32) * k_scale.float() + vc_deq = vc_fp8.to(torch.float32) * v_scale.float() + kc_err = (kc_deq - kc_ref.float()).abs().max().item() + vc_err = (vc_deq - vc_ref.float()).abs().max().item() + # fp8 e4m3 quantization error bound: 0.5 * (binade step at the max stored value). + # For a value v stored with scale s, the fp8 input is v/s. The binade step at + # x is x * 2^(1-mbits) = x/4 for e4m3 (3 mantissa bits). Dequant error ≤ + # 0.5 * (v/s)/4 * s = v/8. So max error ≤ max(|kc_ref|) / 8. 
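# Standalone check of the tolerance derived above (assumes a torch build with
# float8_e4m3fn support). Under round-to-nearest, the e4m3 error for normal
# values is roughly |x|/16 of the stored value, so the |v|/8 tolerance leaves headroom.
import torch

s = 0.1
v = torch.linspace(-4.0, 4.0, steps=801)
deq = (v / s).to(torch.float8_e4m3fn).to(torch.float32) * s
assert bool(((deq - v).abs() <= v.abs() / 8 + 1e-3).all())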
+ kc_max = kc_ref.float().abs().max().item() + vc_max = vc_ref.float().abs().max().item() + kc_atol = max(1e-3, kc_max / 8.0) + vc_atol = max(1e-3, vc_max / 8.0) + passed = q_err < atol and k_err < atol and kc_err < kc_atol and vc_err < vc_atol + errs = {"q": q_err, "k": k_err, "kc": kc_err, "vc": vc_err, + "kc_atol": kc_atol, "vc_atol": vc_atol} + + do_bench = bench or os.environ.get("FLYDSL_BENCH", "0") == "1" + + # AITER cross-check (and optional benchmark) + if HAS_AITER and not negative_slots and not apply_scale: + # AITER Triton wrapper expects int64 slots/positions and 4D cos/sin + slots_i64 = slot_mapping.to(torch.int64) + pos_i64 = positions_i32.to(torch.int64) + cos_4d = cos_cache.unsqueeze(1).unsqueeze(1) # [max_pos, 1, 1, cols] + sin_4d = sin_cache.unsqueeze(1).unsqueeze(1) + + kc_aiter = key_cache.clone().zero_() + vc_aiter = value_cache.clone().zero_() + q_aiter = torch.empty_like(q) + k_aiter = torch.empty_like(k) + + _aiter_rope( + q, k, v, kc_aiter, vc_aiter, + slots_i64, pos_i64, cos_4d, sin_4d, + k_scale, v_scale, + is_neox=True, flash_layout=flash_layout, + apply_scale=False, offs=None, + q_out=q_aiter, k_out=k_aiter, output_zeros=False, + ) + torch.cuda.synchronize() + + q_vs_aiter = (q_out.float() - q_aiter.float()).abs().max().item() + k_vs_aiter = (k_out.float() - k_aiter.float()).abs().max().item() + kc_vs_aiter = (kc_fp8.float() - kc_aiter.float()).abs().max().item() + vc_vs_aiter = (vc_fp8.float() - vc_aiter.float()).abs().max().item() + errs["aiter_q"] = q_vs_aiter + errs["aiter_k"] = k_vs_aiter + errs["aiter_kc"] = kc_vs_aiter + errs["aiter_vc"] = vc_vs_aiter + + if do_bench: + def _run_fly(): + launch_fn( + q, k, v, positions_tensor, cos_cache, sin_cache, + slot_mapping_tensor, kc_fp8, vc_fp8, q_out, k_out, + num_tokens, k_scale, v_scale, + stream=torch.cuda.current_stream(), + ) + + def _run_aiter(): + _aiter_rope( + q, k, v, kc_aiter, vc_aiter, + slots_i64, pos_i64, cos_4d, sin_4d, + k_scale, v_scale, is_neox=True, flash_layout=flash_layout, - apply_scale=False, q_out=qo_aiter, k_out=ko_aiter, - output_zeros=False, + apply_scale=False, offs=None, + q_out=q_aiter, k_out=k_aiter, output_zeros=False, ) - aiter_us = bench_gpu_us_torch(launch_aiter, warmup=10, iters=100) - speedup = aiter_us / us if us > 0 else 0 + fly_us = _bench_gpu_us(_run_fly) + aiter_us = _bench_gpu_us(_run_aiter) + speedup = aiter_us / fly_us if fly_us > 0 else 0.0 + errs["fly_us"] = fly_us + errs["aiter_us"] = aiter_us + errs["speedup"] = speedup + + return passed, errs + + +# =========================================================================== +# Category 1: Core decode configs (T=1) — fast CI gate +# =========================================================================== + +@pytest.mark.parametrize("num_q_heads,num_kv_heads,head_dim", [ + (32, 8, 128), # Llama-8B TP1 + (4, 1, 128), # Llama-8B TP8 + (64, 8, 128), # Llama-70B TP1 + (8, 1, 128), # Llama-70B TP8 + (64, 8, 64), # GPT-OSS TP1 + (8, 1, 64), # GPT-OSS TP8 +], ids=[ + "Llama8B-TP1", "Llama8B-TP8", + "Llama70B-TP1", "Llama70B-TP8", + "GPTOSS-TP1", "GPTOSS-TP8", +]) +def test_decode_flash(num_q_heads, num_kv_heads, head_dim): + """T=1 decode, flash layout, bf16 — core correctness gate.""" + passed, errs = run_test( + num_tokens=1, head_dim=head_dim, + num_q_heads=num_q_heads, num_kv_heads=num_kv_heads, + flash_layout=True, dtype_str="bf16", + ) + assert passed, f"FAILED: {errs}" + - # Cross-validate: AITER vs FlyDSL (looser tolerance — two independent - # GPU implementations may differ in operation ordering/rounding) - 
cross_atol = 1e-2 - torch.cuda.synchronize() - q_cross_err = (qo_aiter.float() - q_out.float()).abs().max().item() - k_cross_err = (ko_aiter.float() - k_out.float()).abs().max().item() - cross_ok = q_cross_err < cross_atol and k_cross_err < cross_atol - cross_status = "MATCH" if cross_ok else "MISMATCH" - print(f" [aiter] {aiter_us:.1f} us → FlyDSL/AITER: {speedup:.2f}x " - f"(cross-check: {cross_status}, Q={q_cross_err:.2e}, K={k_cross_err:.2e})") +# =========================================================================== +# Category 2: Flash layout, all token sizes, bf16 +# =========================================================================== - ok = q_err < atol and k_err < atol and kc_err < atol and vc_err < atol - return ok, q_err, k_err, kc_err, vc_err +@pytest.mark.parametrize("num_tokens", [1, 32, 128]) +@pytest.mark.parametrize("num_q_heads,num_kv_heads,head_dim", [ + (32, 8, 128), # Llama-8B TP1 + (4, 1, 128), # Llama-8B TP8 + (64, 8, 128), # Llama-70B TP1 + (8, 1, 128), # Llama-70B TP8 + (128, 8, 128), # Llama-405B TP1 + (16, 1, 128), # Llama-405B TP8 + (64, 4, 128), # Qwen3-72B TP1 + (8, 1, 128), # Qwen3-72B TP8 (same shape as Llama-70B TP8) + (64, 8, 64), # GPT-OSS TP1 + (8, 1, 64), # GPT-OSS TP8 +], ids=[ + "Llama8B-TP1", "Llama8B-TP8", + "Llama70B-TP1", "Llama70B-TP8", + "Llama405B-TP1", "Llama405B-TP8", + "Qwen72B-TP1", "Qwen72B-TP8", + "GPTOSS-TP1", "GPTOSS-TP8", +]) +def test_flash_bf16(num_tokens, num_q_heads, num_kv_heads, head_dim): + """Flash layout, bf16, all supported model configs and token sizes.""" + passed, errs = run_test( + num_tokens=num_tokens, head_dim=head_dim, + num_q_heads=num_q_heads, num_kv_heads=num_kv_heads, + flash_layout=True, dtype_str="bf16", + ) + assert passed, f"FAILED (T={num_tokens}): {errs}" -# --- Default tests: GPT-OSS 120B TP=8 (fast CI) --- +# =========================================================================== +# Category 3: Non-flash layout +# =========================================================================== -@pytest.mark.parametrize("num_tokens", [1, 4, 16, 32, 128]) -def test_fused_rope_cache_flash(num_tokens): - ok, q_err, k_err, kc_err, vc_err = run_fused_test(num_tokens, flash_layout=True) - assert ok, f"FAILED: q={q_err:.2e} k={k_err:.2e} kc={kc_err:.2e} vc={vc_err:.2e}" +@pytest.mark.parametrize("num_tokens", [1, 32, 128]) +@pytest.mark.parametrize("num_q_heads,num_kv_heads,head_dim", [ + (32, 8, 128), # Llama-8B TP1 + (4, 1, 128), # Llama-8B TP8 + (8, 1, 128), # Llama-70B TP8 + (8, 1, 64), # GPT-OSS TP8 +], ids=["Llama8B-TP1", "Llama8B-TP8", "Llama70B-TP8", "GPTOSS-TP8"]) +def test_nonflash_bf16(num_tokens, num_q_heads, num_kv_heads, head_dim): + """Non-flash (ATOM-default) layout, bf16.""" + passed, errs = run_test( + num_tokens=num_tokens, head_dim=head_dim, + num_q_heads=num_q_heads, num_kv_heads=num_kv_heads, + flash_layout=False, dtype_str="bf16", + ) + assert passed, f"FAILED (T={num_tokens}): {errs}" -@pytest.mark.parametrize("num_tokens", [1, 4, 16, 32, 128]) -def test_fused_rope_cache_nonflash(num_tokens): - ok, q_err, k_err, kc_err, vc_err = run_fused_test(num_tokens, flash_layout=False) - assert ok, f"FAILED: q={q_err:.2e} k={k_err:.2e} kc={kc_err:.2e} vc={vc_err:.2e}" +# =========================================================================== +# Category 4: f16 dtype +# =========================================================================== +@pytest.mark.parametrize("num_tokens", [1, 32]) +@pytest.mark.parametrize("flash_layout", [True, False], ids=["flash", "nonflash"]) +def 
test_f16(num_tokens, flash_layout): + """f16 dtype — Llama-8B TP8 representative config.""" + passed, errs = run_test( + num_tokens=num_tokens, head_dim=128, + num_q_heads=4, num_kv_heads=1, + flash_layout=flash_layout, dtype_str="f16", + ) + assert passed, f"FAILED (T={num_tokens} flash={flash_layout}): {errs}" -# --- f16 tests --- -@pytest.mark.parametrize("num_tokens", [1, 4, 32]) +# =========================================================================== +# Category 5: pos_dtype — i32 vs i64 (stride-2 indexing) +# =========================================================================== + +@pytest.mark.parametrize("pos_dtype", ["i32", "i64"]) +@pytest.mark.parametrize("num_tokens", [1, 32, 128]) +def test_pos_dtype(pos_dtype, num_tokens): + """Position tensor dtype: i32 direct vs i64 stride-2 view. + + The i64 path reads each int64 as two consecutive i32 words (low word only), + which is the same physical value on little-endian AMD GPUs. + """ + passed, errs = run_test( + num_tokens=num_tokens, head_dim=128, + num_q_heads=8, num_kv_heads=1, + flash_layout=True, dtype_str="bf16", + pos_dtype=pos_dtype, + ) + assert passed, f"FAILED (pos_dtype={pos_dtype} T={num_tokens}): {errs}" + + +# =========================================================================== +# Category 6: reuse_freqs_front_part — half-dim vs full-dim cos/sin +# =========================================================================== + +@pytest.mark.parametrize("reuse_freqs_front_part", [True, False], + ids=["half_dim", "full_dim"]) +@pytest.mark.parametrize("num_tokens", [1, 32]) @pytest.mark.parametrize("flash_layout", [True, False], ids=["flash", "nonflash"]) -def test_fused_rope_cache_f16(num_tokens, flash_layout): - ok, q_err, k_err, kc_err, vc_err = run_fused_test( - num_tokens, flash_layout=flash_layout, dtype_str="f16", +def test_reuse_freqs(reuse_freqs_front_part, num_tokens, flash_layout): + """Cos/sin shape: half-dim [max_pos, D//2] vs full-dim [max_pos, D].""" + passed, errs = run_test( + num_tokens=num_tokens, head_dim=128, + num_q_heads=8, num_kv_heads=1, + flash_layout=flash_layout, dtype_str="bf16", + reuse_freqs_front_part=reuse_freqs_front_part, ) - assert ok, f"FAILED: q={q_err:.2e} k={k_err:.2e} kc={kc_err:.2e} vc={vc_err:.2e}" + assert passed, f"FAILED (reuse={reuse_freqs_front_part} T={num_tokens}): {errs}" -# --- Negative slot tests: ensure slot < 0 skips KV cache write --- +# =========================================================================== +# Category 7: Negative slots (slot < 0 skips KV cache write) +# =========================================================================== @pytest.mark.parametrize("num_tokens", [4, 32]) @pytest.mark.parametrize("flash_layout", [True, False], ids=["flash", "nonflash"]) -def test_fused_rope_cache_negative_slots(num_tokens, flash_layout): - ok, q_err, k_err, kc_err, vc_err = run_fused_test( - num_tokens, flash_layout=flash_layout, negative_slots=True, +def test_negative_slots(num_tokens, flash_layout): + """Odd-indexed slots set to -1; those KV cache positions must remain zero.""" + passed, errs = run_test( + num_tokens=num_tokens, head_dim=128, + num_q_heads=8, num_kv_heads=1, + flash_layout=flash_layout, dtype_str="bf16", + negative_slots=True, ) - assert ok, f"FAILED: q={q_err:.2e} k={k_err:.2e} kc={kc_err:.2e} vc={vc_err:.2e}" + assert passed, f"FAILED (T={num_tokens} flash={flash_layout}): {errs}" -# --- Multi-model tests (opt-in via FLYDSL_ALL_MODELS=1) --- +# =========================================================================== 
+# Category 8: fp8 KV cache (apply_scale=True) — finite-value sanity check +# =========================================================================== -_MULTI_MODEL_CASES = [] -for _model, (_hd, _total_qh, _total_kh) in MODEL_CONFIGS.items(): - for _tp in [1, 8]: - _qh = _total_qh // _tp - _kh = max(1, _total_kh // _tp) - if _qh >= 1: - _MULTI_MODEL_CASES.append( - pytest.param(_model, _hd, _qh, _kh, id=f"{_model}-TP{_tp}") - ) +@pytest.mark.skipif(_IS_RDNA, reason="fp8 KV cache is a CDNA production feature; " + "cvt_pk_fp8_f32 bit encoding differs on RDNA (gfx10xx/gfx11xx/gfx12xx)") +@pytest.mark.parametrize("num_tokens", [1, 32]) +@pytest.mark.parametrize("flash_layout", [True, False], ids=["flash", "nonflash"]) +@pytest.mark.parametrize("num_q_heads,num_kv_heads,head_dim", [ + (8, 1, 64), # GPT-OSS TP8 + (8, 1, 128), # Llama TP8 +], ids=["GPTOSS-TP8", "Llama-TP8"]) +def test_fp8_cache(num_tokens, flash_layout, num_q_heads, num_kv_heads, head_dim): + """fp8 KV cache path: Q/K rotation correct, cache values finite.""" + passed, errs = run_test( + num_tokens=num_tokens, head_dim=head_dim, + num_q_heads=num_q_heads, num_kv_heads=num_kv_heads, + flash_layout=flash_layout, dtype_str="bf16", + apply_scale=True, + ) + assert passed, f"FAILED (T={num_tokens} flash={flash_layout}): {errs}" + + +# =========================================================================== +# Category 9: Cross-parameter sweep (opt-in via FLYDSL_ALL_MODELS=1) +# =========================================================================== +_ALL_CONFIGS = [ + ("Llama8B-TP1", 32, 8, 128), + ("Llama8B-TP8", 4, 1, 128), + ("Llama70B-TP1", 64, 8, 128), + ("Llama70B-TP8", 8, 1, 128), + ("Llama405B-TP1", 128, 8, 128), + ("Llama405B-TP8", 16, 1, 128), + ("Qwen72B-TP1", 64, 4, 128), + ("Qwen72B-TP8", 8, 1, 128), + ("GPTOSS-TP1", 64, 8, 64), + ("GPTOSS-TP8", 8, 1, 64), +] -@pytest.mark.parametrize("model,head_dim,num_q_heads,num_kv_heads", _MULTI_MODEL_CASES) + +@pytest.mark.parametrize("model,num_q_heads,num_kv_heads,head_dim", + _ALL_CONFIGS, ids=[c[0] for c in _ALL_CONFIGS]) @pytest.mark.parametrize("num_tokens", [1, 32, 128]) @pytest.mark.parametrize("flash_layout", [True, False], ids=["flash", "nonflash"]) +@pytest.mark.parametrize("reuse_freqs_front_part", [True, False], ids=["half_cos", "full_cos"]) +@pytest.mark.parametrize("pos_dtype", ["i32", "i64"]) @pytest.mark.skipif(os.environ.get("FLYDSL_ALL_MODELS", "0") != "1", - reason="Multi-model sweep skipped; set FLYDSL_ALL_MODELS=1 to run") -def test_fused_rope_cache_multi_model(model, head_dim, num_q_heads, num_kv_heads, - num_tokens, flash_layout): - ok, q_err, k_err, kc_err, vc_err = run_fused_test( - num_tokens, head_dim=head_dim, + reason="Full sweep skipped; set FLYDSL_ALL_MODELS=1 to run") +def test_full_sweep(model, num_q_heads, num_kv_heads, head_dim, + num_tokens, flash_layout, reuse_freqs_front_part, pos_dtype): + """Cross-parameter correctness sweep over all models × layouts × dtypes × pos_dtype.""" + passed, errs = run_test( + num_tokens=num_tokens, head_dim=head_dim, num_q_heads=num_q_heads, num_kv_heads=num_kv_heads, - flash_layout=flash_layout, + flash_layout=flash_layout, dtype_str="bf16", + reuse_freqs_front_part=reuse_freqs_front_part, + pos_dtype=pos_dtype, ) - assert ok, f"FAILED ({model}): q={q_err:.2e} k={k_err:.2e} kc={kc_err:.2e} vc={vc_err:.2e}" + assert passed, (f"FAILED ({model} T={num_tokens} flash={flash_layout} " + f"reuse={reuse_freqs_front_part} pos={pos_dtype}): {errs}") +# 
=========================================================================== +# CLI entry point +# =========================================================================== + if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser() - parser.add_argument("--all-models", action="store_true", - help="Test all model configs (default: GPT-OSS-120B TP=8 only)") - args = parser.parse_args() - - configs = [] - if args.all_models: - for model, (hd, total_qh, total_kh) in MODEL_CONFIGS.items(): - for tp in [1, 8]: - qh = total_qh // tp - kh = max(1, total_kh // tp) - if qh >= 1: - configs.append((model, tp, hd, qh, kh)) + all_models = "--all-models" in sys.argv + do_bench = os.environ.get("FLYDSL_BENCH", "0") == "1" + + print(f"GPU: {torch.cuda.get_device_name(0)}") + print(f"AITER cross-check: {'enabled' if HAS_AITER else 'disabled (set AITER_REPO or install aiter)'}") + print(f"Benchmark: {'enabled' if do_bench else 'disabled (set FLYDSL_BENCH=1)'}") + failures = 0 + + def _run(label, **kwargs): + global failures + passed, errs = run_test(**kwargs, bench=do_bench) + status = "PASS" if passed else "FAIL" + bench_str = "" + if "fly_us" in errs: + bench_str = (f" FlyDSL={errs['fly_us']:.1f}us " + f"AITER={errs['aiter_us']:.1f}us " + f"speedup={errs['speedup']:.2f}x") + kc_str = f" kc={errs['kc']:.4f}" if "kc" in errs else "" + vc_str = f" vc={errs['vc']:.4f}" if "vc" in errs else "" + print(f" [{status}] {label}: q={errs['q']:.4f} k={errs['k']:.4f}{kc_str}{vc_str}{bench_str}") + if not passed: + failures += 1 + + configs = _ALL_CONFIGS if all_models else [ + ("GPTOSS-TP8", 8, 1, 64), + ] + + print("\n=== Category 1: decode (T=1) flash bf16 ===") + for name, qh, kh, hd in configs: + _run(f"{name} T=1", num_tokens=1, head_dim=hd, + num_q_heads=qh, num_kv_heads=kh, flash_layout=True) + + print("\n=== Category 2: prefill (T=32, T=128) flash bf16 ===") + for name, qh, kh, hd in configs: + for T in [32, 128]: + _run(f"{name} T={T}", num_tokens=T, head_dim=hd, + num_q_heads=qh, num_kv_heads=kh, flash_layout=True) + + print("\n=== Category 3: non-flash bf16 ===") + nonflash = _ALL_CONFIGS if all_models else [("Llama8B-TP8", 4, 1, 128), ("GPTOSS-TP8", 8, 1, 64)] + for name, qh, kh, hd in nonflash: + for T in [1, 32]: + _run(f"{name} T={T}", num_tokens=T, head_dim=hd, + num_q_heads=qh, num_kv_heads=kh, flash_layout=False) + + print("\n=== Category 5: pos_dtype i32 vs i64 ===") + for pos_dtype in ["i32", "i64"]: + for T in [1, 32, 128]: + _run(f"pos_dtype={pos_dtype} T={T}", num_tokens=T, head_dim=128, + num_q_heads=8, num_kv_heads=1, flash_layout=True, pos_dtype=pos_dtype) + + print("\n=== Category 6: reuse_freqs_front_part ===") + for reuse in [True, False]: + for flash in [True, False]: + _run(f"reuse={reuse} flash={flash}", num_tokens=32, head_dim=128, + num_q_heads=8, num_kv_heads=1, flash_layout=flash, + reuse_freqs_front_part=reuse) + + print("\n=== Category 8: fp8 cache ===") + for flash in [True, False]: + for T in [1, 32]: + _run(f"fp8 flash={flash} T={T}", num_tokens=T, head_dim=128, + num_q_heads=8, num_kv_heads=1, flash_layout=flash, apply_scale=True) + + print(f"\n{'='*60}") + if failures == 0: + print("ALL TESTS PASSED") + sys.exit(0) else: - configs = [("GPT-OSS-120B", 8, HEAD_DIM, NUM_Q_HEADS, NUM_KV_HEADS)] - - for model, tp, hd, qh, kh in configs: - print(f"\n{'='*60}") - print(f"{model} TP={tp}: QH={qh}, KH={kh}, D={hd}") - print(f"{'='*60}") - for flash_layout in [True, False]: - layout = "flash" if flash_layout else "non-flash" - for m in [1, 4, 32, 128]: - ok, 
q_err, k_err, kc_err, vc_err = run_fused_test( - m, head_dim=hd, num_q_heads=qh, num_kv_heads=kh, - flash_layout=flash_layout, - ) - status = "PASS" if ok else "FAIL" - print(f" [{status}] {layout:>9s} M={m:>4d} " - f"q={q_err:.2e} k={k_err:.2e} kc={kc_err:.2e} vc={vc_err:.2e}") - print("\nDone.") + print(f"{failures} TESTS FAILED") + sys.exit(1)
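# Typical standalone invocations of the CLI entry point above (the aiter path
# below is a placeholder):
#
#   # Quick run with the default GPT-OSS TP8 shapes:
#   PYTHONPATH=./ python tests/kernels/test_fused_rope_cache.py
#
#   # Full model sweep with AITER cross-check and benchmarking:
#   FLYDSL_BENCH=1 AITER_REPO=/path/to/aiter PYTHONPATH=./ \
#       python tests/kernels/test_fused_rope_cache.py --all-models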