diff --git a/kernels/grouped_gemm_blockscale_contiguous.py b/kernels/grouped_gemm_blockscale_contiguous.py new file mode 100644 index 000000000..b91642a97 --- /dev/null +++ b/kernels/grouped_gemm_blockscale_contiguous.py @@ -0,0 +1,691 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Contiguous Grouped FP8 GEMM kernel with block scaling. + +API matching DeepGEMM's m_grouped_fp8_gemm_nt_contiguous: + - A: [M_total, K] FP8 - concatenated rows from all groups + - scale_a: [scale_k, M_total] - per-token, per-128K scales (transposed). + uint8 (E8M0) on gfx950 (HW scaling); FP32 on gfx942 (SW scaling). + - B: [num_groups, N, K] FP8 - one weight matrix per group + - scale_b: [num_groups, scale_n, scale_k] - per-block scales. + uint8 (E8M0) on gfx950; FP32 on gfx942. + - D: [M_total, N] BF16 - output + - grouped_layout: [M_total] INT32 - maps each row to group ID (-1 for padding) + +Block scaling granularity (matching DeepGEMM): + - A: (1, 128) - per-token, per-128-K-elements + - B: (128, 128) - per-128-N, per-128-K block +""" + +import functools + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl.compiler.kernel_function import CompilationContext +from flydsl.expr import arith, gpu, buffer_ops, vector, rocdl +from flydsl.expr import range_constexpr +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr + +from flydsl._mlir import ir +from flydsl._mlir.dialects import scf +from flydsl._mlir.dialects import math as math_dialect +from flydsl.expr.typing import T +from flydsl.expr.arith import ArithValue + +from kernels.mfma_epilogues import mfma_epilog +from kernels.mfma_preshuffle_pipeline import ( + crd2idx, + lds_store_16b_xor16, + load_b_pack_k32, + make_preshuffle_b_layout, + swizzle_xor16, + tile_chunk_coord_i32, +) + + +@functools.lru_cache(maxsize=128) +def compile_grouped_gemm_blockscale_contiguous( + *, + n: int, + k: int, + num_groups: int, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, + scale_block_k: int = 128, + scale_block_n: int = 128, + out_dtype: str = "bf16", + waves_per_eu: int | None = None, +): + """Compile grouped FP8 GEMM kernel and return the JIT launcher. + + Args: + n: N dimension (output columns per group) + k: K dimension (reduction dimension) + num_groups: Number of groups (experts) + tile_m: M tile size (default 128) + tile_n: N tile size (default 128) + tile_k: K tile size (default 128) + scale_block_k: K-dimension scale block size (default 128) + scale_block_n: N-dimension scale block size (default 128) + out_dtype: Output data type ("bf16" or "f16") + + Returns: + JIT launcher function. 
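+
+    Example (illustrative sketch; tensor preparation mirrors
+    tests/kernels/test_grouped_gemm_blockscale_contiguous.py, where FP8 tensors
+    are passed viewed as int8 and B is preshuffled with shuffle_weight; the
+    variable names here are placeholders, not part of this module):
+
+        launch = compile_grouped_gemm_blockscale_contiguous(n=512, k=512, num_groups=4)
+        launch(d.view(-1), a_fp8.view(-1).view(torch.int8),
+               b_shuffled.view(-1).view(torch.int8),
+               scale_a.view(-1), scale_b.view(-1), grouped_layout,
+               m_total, 512, 512, 4, torch.cuda.current_stream())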
+ """ + gpu_arch = get_hip_arch() + _is_gfx950 = str(gpu_arch).startswith("gfx95") + + allocator = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_grouped_gemm") + + # Validate parameters + if k % tile_k != 0: + raise ValueError(f"k ({k}) must be divisible by tile_k ({tile_k})") + if n % tile_n != 0: + raise ValueError(f"n ({n}) must be divisible by tile_n ({tile_n})") + if tile_k % scale_block_k != 0: + raise ValueError(f"tile_k ({tile_k}) must be divisible by scale_block_k ({scale_block_k})") + if tile_n % scale_block_n != 0: + raise ValueError(f"tile_n ({tile_n}) must be divisible by scale_block_n ({scale_block_n})") + + # Output type + if out_dtype not in ("bf16", "f16"): + raise ValueError(f"out_dtype must be 'bf16' or 'f16', got {out_dtype!r}") + out_mlir = lambda: T.bf16 if out_dtype == "bf16" else T.f16 + + # Compile-time constants + total_threads = 256 + elem_bytes = 1 # FP8 + num_k_tiles = k // tile_k + scale_k = k // scale_block_k + scale_n = n // scale_block_n + sb_per_tile = tile_k // scale_block_k # scale blocks per K-tile + k_unroll = tile_k // 64 # K64-byte micro-steps (for K32 MFMA pairs) + kpack_bytes = 16 # 16-byte packs for FP8 + + # LDS allocation: max of ping-pong A tiles and CShuffle epilogue output + lds_a_bytes = tile_m * tile_k * elem_bytes + lds_pingpong_bytes = 2 * lds_a_bytes + lds_out_bytes = tile_m * tile_n * 2 # bf16/f16 = 2 bytes per element + lds_total_bytes = max(lds_pingpong_bytes, lds_out_bytes) + lds_alloc_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = lds_alloc_offset + lds_total_bytes + lds_tile_elems = tile_m * tile_k # element offset between ping and pong + + # Module name for caching + module_name = ( + f"grouped_gemm_blockscale_contiguous_{out_dtype}" + f"_n{n}_k{k}_g{num_groups}" + f"_t{tile_m}x{tile_n}x{tile_k}" + f"_pingpong" + ).replace("-", "_") + + # Thread -> tile element mapping for A loads + tile_k_bytes = tile_k * elem_bytes + tile_k_dwords = tile_k_bytes // 4 + bytes_a_per_tile = tile_m * tile_k * elem_bytes + bytes_per_thread_a = bytes_a_per_tile // total_threads + a_load_bytes = 16 # 16-byte loads (dwordx4) + chunk_i32_a = a_load_bytes // 4 # 4 dwords per load + num_a_loads = bytes_per_thread_a // a_load_bytes + + @flyc.kernel(name=module_name) + def grouped_gemm_blockscale_contiguous_kernel( + arg_d: fx.Tensor, + arg_a: fx.Tensor, + arg_b: fx.Tensor, + arg_scale_a: fx.Tensor, + arg_scale_b: fx.Tensor, + arg_grouped_layout: fx.Tensor, + i32_m: fx.Int32, + i32_n: fx.Int32, + i32_k: fx.Int32, + i32_num_groups: fx.Int32, + ): + # Convert runtime parameters to index type + m_in = arith.index_cast(T.index, i32_m) + n_in = arith.index_cast(T.index, i32_n) + k_in = arith.index_cast(T.index, i32_k) + num_groups_in = arith.index_cast(T.index, i32_num_groups) + + # Thread and block IDs + tx = gpu.thread_id("x") + by = gpu.block_id("x") # N-block index + bx = gpu.block_id("y") # M-block index + + # Block positions + bx_m = bx * fx.Index(tile_m) + by_n = by * fx.Index(tile_n) + + # Wave/lane decomposition (256 threads = 4 waves x 64 lanes) + layout_wave_lane = fx.make_layout((4, 64), stride=(64, 1)) + coord_wave_lane = fx.idx2crd(tx, layout_wave_lane) + wave_id = fx.get(coord_wave_lane, 0) + lane_id = fx.get(coord_wave_lane, 1) + + # Lane decomposition for MFMA (lane_id -> lane_div_16, lane_mod_16) + layout_lane16 = fx.make_layout((4, 16), stride=(16, 1)) + coord_lane16 = fx.idx2crd(lane_id, layout_lane16) + lane_div_16 = fx.get(coord_lane16, 0) + lane_mod_16 = fx.get(coord_lane16, 1) + + # LDS setup: single memref for 
both ping-pong buffers + base_ptr = allocator.get_base() + lds_a = SmemPtr(base_ptr, lds_alloc_offset, T.f8, shape=(2 * tile_m * tile_k,)).get() + lds_stride = tile_k + layout_lds = fx.make_layout((tile_m, tile_k), stride=(lds_stride, 1)) + lds_base_pong = fx.Index(0) + lds_base_ping = fx.Index(lds_tile_elems) + + # CShuffle epilogue LDS (aliased from same base, bf16 element type) + lds_out = SmemPtr(base_ptr, lds_alloc_offset, out_mlir(), shape=(tile_m * tile_n,)).get() + + # Buffer resources + a_nbytes = m_in * k_in + a_rsrc = buffer_ops.create_buffer_resource( + arg_a, max_size=False, num_records_bytes=arith.index_cast(T.i64, a_nbytes) + ) + + b_nbytes = num_groups_in * n_in * k_in + b_rsrc = buffer_ops.create_buffer_resource( + arg_b, max_size=False, num_records_bytes=arith.index_cast(T.i64, b_nbytes) + ) + + d_nbytes = m_in * n_in * fx.Index(2) # bf16/f16 = 2 bytes + d_rsrc = buffer_ops.create_buffer_resource( + arg_d, max_size=False, num_records_bytes=arith.index_cast(T.i64, d_nbytes) + ) + + # Scale buffers — gfx950 HW E8M0 path consumes int8 (one byte/scale, + # pre-packed on host); gfx942 SW path consumes f32. + scale_byte_size = 1 if _is_gfx950 else 4 + + # scale_a: [scale_k, M] - transposed layout + sa_nbytes = fx.Index(scale_k) * m_in * fx.Index(scale_byte_size) + sa_rsrc = buffer_ops.create_buffer_resource( + arg_scale_a, max_size=False, num_records_bytes=arith.index_cast(T.i64, sa_nbytes) + ) + + # scale_b: [num_groups, scale_n, scale_k] + sb_nbytes = num_groups_in * fx.Index(scale_n * scale_k * scale_byte_size) + sb_rsrc = buffer_ops.create_buffer_resource( + arg_scale_b, max_size=False, num_records_bytes=arith.index_cast(T.i64, sb_nbytes) + ) + + # grouped_layout: [M] + gl_nbytes = m_in * fx.Index(4) + gl_rsrc = buffer_ops.create_buffer_resource( + arg_grouped_layout, max_size=False, num_records_bytes=arith.index_cast(T.i64, gl_nbytes) + ) + + # Load group ID for this M-block (use first row of tile) + group_id_i32 = buffer_ops.buffer_load(gl_rsrc, bx_m, vec_width=1, dtype=T.i32) + is_valid = arith.cmpi(arith.CmpIPredicate.sge, group_id_i32, fx.Int32(0)) + + # Early exit for invalid blocks (padding rows) + _if_valid = scf.IfOp(is_valid) + with ir.InsertionPoint(_if_valid.then_block): + group_idx = arith.index_cast(T.index, group_id_i32) + + # MFMA tiling constants + m_repeat = tile_m // 16 # 8 for tile_m=128 + num_waves = 4 + n_per_wave = tile_n // num_waves # 32 for tile_n=128 + num_acc_n = n_per_wave // 16 # 2 for n_per_wave=32 + + # Initialize accumulators (FP32) + acc_init = arith.constant_vector(0.0, T.f32x4) + num_accs = m_repeat * num_acc_n + accs = [acc_init] * num_accs + + # Wave's N-tile base + wave_mod_4 = wave_id % fx.Index(4) + n_tile_base = wave_mod_4 * fx.Index(n_per_wave) + + # Precompute N-block indices for scale_b + c_scale_block_n = fx.Index(scale_block_n) + c_scale_k = fx.Index(scale_k) + n_block_for_scale = [] + for ni in range_constexpr(num_acc_n): + col_base = by_n + n_tile_base + arith.index(ni * 16) + n_blk = col_base // c_scale_block_n + n_block_for_scale.append(n_blk) + + # B preshuffle layout: total N = num_groups * N (all groups concatenated) + c_n_total = num_groups_in * n_in + b_layout = make_preshuffle_b_layout( + arith, c_n=c_n_total, c_k=k_in, + kpack_bytes=kpack_bytes, elem_bytes=elem_bytes, + ) + layout_b = b_layout.layout_b + + # Decompose global N column into (n_blk, n_intra) for preshuffle layout + c_n0 = c_n_total // fx.Index(16) + c_n0_i32 = arith.index_cast(T.i32, c_n0) + layout_n_blk_intra = fx.make_layout((c_n0_i32, 16), 
stride=(16, 1)) + n_blk_list = [] + n_intra_list = [] + group_n_off = group_idx * n_in # N-offset for this group in concatenated B + for ni in range_constexpr(num_acc_n): + col_global = group_n_off + by_n + n_tile_base + arith.index(ni * 16) + lane_mod_16 + coord_ni = fx.idx2crd(col_global, layout_n_blk_intra) + n_blk_list.append(fx.get(coord_ni, 0)) + n_intra_list.append(fx.get(coord_ni, 1)) + + # A load mapping: thread -> (row, col_i32) in tile (dword-indexed K) + layout_a_tile_div4 = fx.make_layout( + (tile_m, tile_k_dwords), stride=(tile_k_dwords, 1) + ) + c_chunk_a = fx.Index(chunk_i32_a) + tx_i32_base = tx * c_chunk_a + _k_div4_factor = k_in // fx.Index(4) + k_blocks16 = arith.index(tile_k_bytes // 16) + c4_bytes = fx.Index(4) + + # Precompute per-load tile coordinates (row_local, col_local_i32) + a_row_local = [] + a_col_local_i32 = [] + for i in range_constexpr(num_a_loads): + row_local, col_local_i32 = tile_chunk_coord_i32( + arith, tx_i32_base=tx_i32_base, i=i, + total_threads=total_threads, + layout_tile_div4=layout_a_tile_div4, + chunk_i32=chunk_i32_a, + ) + a_row_local.append(row_local) + a_col_local_i32.append(col_local_i32) + + # ── Helper: prefetch A tile (gmem -> regs) ───────────────── + def prefetch_a_tile(k_tile_idx_py): + """Load A tile from global memory into VGPRs.""" + base_k_div4 = fx.Index(k_tile_idx_py * tile_k_dwords) + parts = [] + for i in range_constexpr(num_a_loads): + row_global = bx_m + a_row_local[i] + idx_i32 = row_global * _k_div4_factor + base_k_div4 + a_col_local_i32[i] + a_vec = buffer_ops.buffer_load(a_rsrc, idx_i32, vec_width=4, dtype=T.i32) + parts.append(vector.bitcast(T.i32x4, a_vec)) + return parts + + # ── Helper: store A regs to LDS with XOR swizzle ─────────── + def store_a_tile_to_lds(a_parts, lds_base): + """Write prefetched A tile from VGPRs into LDS with XOR16 swizzle.""" + for i in range_constexpr(num_a_loads): + lds_store_16b_xor16( + arith, vector, + lds_memref=lds_a, vec16_ty=T.f8x16, + layout_lds=layout_lds, + row_local=a_row_local[i], + col_local_i32=a_col_local_i32[i], + tx_c4=c4_bytes, k_blocks16=k_blocks16, + lds_base=lds_base, + vec_part_i32x4=a_parts[i], elem_bytes=elem_bytes, + ) + + # ── Helper: load A K64 pack from LDS with XOR swizzle ──────── + def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): + """Load 16B from LDS with XOR16 swizzle, return two i64 halves.""" + col_base_swz_bytes = swizzle_xor16(curr_row_a_lds, col_base_bytes, k_blocks16) + idx_a16 = crd2idx((curr_row_a_lds, col_base_swz_bytes), layout_lds) + lds_base + loaded_a16 = vector.load_op(T.vec(16, T.f8), lds_a, [idx_a16]) + a_i64x2 = vector.bitcast(T.i64x2, loaded_a16) + a0 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[]) + a1 = vector.extract(a_i64x2, static_position=[1], dynamic_position=[]) + return a0, a1 + + # ── Helper: load one B pack (K32 micro-step) ──────────────── + def load_b_pack(base_k, ki_step, ni): + return load_b_pack_k32( + buffer_ops, arith, vector, + arg_b=arg_b, b_rsrc=b_rsrc, + layout_b=layout_b, + base_k=base_k, ki_step=ki_step, + n_blk=n_blk_list[ni], + n_intra=n_intra_list[ni], + lane_div_16=lane_div_16, + elem_type=T.f8, + kpack_bytes=kpack_bytes, + elem_bytes=elem_bytes, + ) + + # ── Helper: prefetch entire B tile (gmem -> regs) ─────────── + def load_b_tile(base_k): + """Load all B packs for one K-tile. + + Returns list of length k_unroll, each entry is + (packs_half0[ni], packs_half1[ni]) for one K64 micro-step. 
+ """ + b_tile = [] + for ku in range_constexpr(k_unroll): + packs0 = [] + packs1 = [] + for ni in range_constexpr(num_acc_n): + ki0 = (ku * 2) + 0 + ki1 = (ku * 2) + 1 + b0 = load_b_pack(base_k, ki0, ni) + b1 = load_b_pack(base_k, ki1, ni) + packs0.append(b0) + packs1.append(b1) + b_tile.append((packs0, packs1)) + return b_tile + + # Base coordinates for A0 prefetch (mi=0, ku=0) + row_a_lds_base = lane_mod_16 # mi=0 + col_offset_base_bytes = lane_div_16 * fx.Index(16) # ku=0 + + # ── Pack helper for gfx950 K=128 MFMA ────────────────────── + mfma_res_ty = T.f32x4 + + def pack_i64x4_to_i32x8(x0, x1, x2, x3): + v4 = vector.from_elements(T.vec(4, T.i64), [x0, x1, x2, x3]) + return vector.bitcast(T.vec(8, T.i32), v4) + + # ── Scheduling hints (sched_group_barrier, MoE stage 2 pattern) ── + ku_per_sb = scale_block_k // 64 + rocdl.sched_barrier(0) + + def hot_loop_scheduler(): + mfma_group = num_acc_n + if _is_gfx950: + total_mfma = sb_per_tile * m_repeat * mfma_group + else: + total_mfma = k_unroll * m_repeat * mfma_group * 2 + rocdl.sched_group_barrier(rocdl.mask_dsrd, ku_per_sb * m_repeat, 0) + rocdl.sched_group_barrier(rocdl.mask_mfma, total_mfma, 1) + rocdl.sched_group_barrier(rocdl.mask_vmem_rd, num_a_loads, 2) + rocdl.sched_group_barrier(rocdl.mask_dswr, num_a_loads, 3) + rocdl.sched_barrier(0) + + # ── Helper: prefetch E8M0 scales for one K-tile (gfx950 HW path) ── + # Returns (sa_e8m0_pf, sb_e8m0_pf) — outer index = sb (sb_per_tile), + # inner = m_repeat / num_acc_n. Issued ahead of compute_tile so + # scale-load latency overlaps with prior tile's MFMA work + B-tile load. + def prefetch_scales(k_tile_idx_py): + if not _is_gfx950: + return None + sa_pf = [] + sb_pf = [] + sb_group_offset = group_idx * fx.Index(scale_n * scale_k) + for sb in range_constexpr(sb_per_tile): + kb = fx.Index(k_tile_idx_py * sb_per_tile + sb) + sa_base_pf = kb * m_in + + sa_sb = [] + for mi in range_constexpr(m_repeat): + sa_row = bx_m + arith.index(mi * 16) + lane_mod_16 + sa_idx = sa_base_pf + sa_row + sa_i8 = buffer_ops.buffer_load(sa_rsrc, sa_idx, vec_width=1, dtype=T.i8) + sa_e8m0 = ArithValue(sa_i8).extui(T.i32) + sa_sb.append(sa_e8m0) + sa_pf.append(sa_sb) + + sb_sb = [] + for ni in range_constexpr(num_acc_n): + sb_idx = sb_group_offset + n_block_for_scale[ni] * c_scale_k + kb + sb_i8 = buffer_ops.buffer_load(sb_rsrc, sb_idx, vec_width=1, dtype=T.i8) + sb_i32 = ArithValue(sb_i8).extui(T.i32) + sb_e8m0 = rocdl.readfirstlane(T.i32, sb_i32) + sb_sb.append(sb_e8m0) + sb_pf.append(sb_sb) + return (sa_pf, sb_pf) + + def compute_tile(accs_in, k_tile_idx_py, lds_base, b_tile_in, scales_pf, *, a0_prefetch=None): + """Compute MFMA tiles for one K-tile, return updated accumulators. + + scales_pf: result of prefetch_scales(k_tile_idx_py) for the gfx950 + HW path; None for the gfx942 SW path (which loads scales locally). + """ + current_accs = list(accs_in) + + for sb in range_constexpr(sb_per_tile): + kb = fx.Index(k_tile_idx_py * sb_per_tile + sb) + + # gfx942 SW path needs s_a_vecs (4 loads/mi, lane_div_16 layout) + # and s_b_vals as f32. gfx950 HW path uses prefetched E8M0 bytes. 
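+ # Either way, the math per 128-K scale block is acc += (A_fp8 @ B_fp8) * s_a * s_b:
+ # the SW path applies both scales in FP32 after the MFMA pair, while the HW
+ # path feeds the E8M0 exponent bytes directly into the scaled MFMA below.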
+ s_a_vecs = [] + s_b_vals = [] + if not _is_gfx950: + sa_base = kb * m_in + row_off_base = lane_div_16 * fx.Index(4) + for mi in range_constexpr(m_repeat): + s_a_row = [] + for ii in range_constexpr(4): + row_in_tile = arith.index(mi * 16) + row_off_base + fx.Index(ii) + row_global = bx_m + row_in_tile + sa_idx = sa_base + row_global + s_a_val = buffer_ops.buffer_load(sa_rsrc, sa_idx, vec_width=1, dtype=T.f32) + s_a_row.append(s_a_val) + s_a_vec4 = vector.from_elements(T.f32x4, s_a_row) + s_a_vecs.append(s_a_vec4) + + sb_group_offset = group_idx * fx.Index(scale_n * scale_k) + for ni in range_constexpr(num_acc_n): + sb_idx = sb_group_offset + n_block_for_scale[ni] * c_scale_k + kb + s_b_val = buffer_ops.buffer_load(sb_rsrc, sb_idx, vec_width=1, dtype=T.f32) + s_b_val = rocdl.readfirstlane(T.f32, s_b_val) + s_b_vals.append(s_b_val) + + # MFMA computation for this scale block + if _is_gfx950: + # gfx950: single K=128 MFMA with HW E8M0 scales prefetched + # ahead of compute_tile (scales_pf passed by caller). + sa_pf, sb_pf = scales_pf + sa_e8m0_list = sa_pf[sb] + sb_e8m0_list = sb_pf[sb] + + ku0 = sb * ku_per_sb + ku1 = ku0 + 1 + b0_packs0, b0_packs1 = b_tile_in[ku0] + b1_packs0, b1_packs1 = b_tile_in[ku1] + col_base0 = col_offset_base_bytes + fx.Index(ku0 * 64) + col_base1 = col_offset_base_bytes + fx.Index(ku1 * 64) + + for mi in range_constexpr(m_repeat): + curr_row_a_lds = lane_mod_16 + arith.index(mi * 16) + if a0_prefetch is not None and sb == 0 and mi == 0: + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base0, lds_base) + a2, a3 = lds_load_packs_k64(curr_row_a_lds, col_base1, lds_base) + a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + + for ni in range_constexpr(num_acc_n): + b128 = pack_i64x4_to_i32x8( + b0_packs0[ni], b0_packs1[ni], + b1_packs0[ni], b1_packs1[ni], + ) + acc_idx = mi * num_acc_n + ni + # Hardware scaling + direct accumulation + current_accs[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [a128, b128, current_accs[acc_idx], + 0, 0, 0, sa_e8m0_list[mi], 0, sb_e8m0_list[ni]], + ) + else: + # gfx942: two K=32 MFMAs per K64 micro-step + for ku_local in range_constexpr(ku_per_sb): + ku = sb * ku_per_sb + ku_local + k_offset_bytes = ku * 64 + b_packs0, b_packs1 = b_tile_in[ku] + + for mi in range_constexpr(m_repeat): + if a0_prefetch is not None and sb == 0 and ku_local == 0 and mi == 0: + a0, a1 = a0_prefetch + else: + row_a_lds = lane_mod_16 + arith.index(mi * 16) + col_a_base_bytes = lane_div_16 * fx.Index(16) + fx.Index(k_offset_bytes) + a0, a1 = lds_load_packs_k64(row_a_lds, col_a_base_bytes, lds_base) + + for ni in range_constexpr(num_acc_n): + acc_idx = mi * num_acc_n + ni + + mfma_fn = rocdl.mfma_f32_16x16x32_fp8_fp8 + mfma_mid = mfma_fn(T.f32x4, [a0, b_packs0[ni], acc_init, 0, 0, 0]) + mfma_result = mfma_fn(T.f32x4, [a1, b_packs1[ni], mfma_mid, 0, 0, 0]) + + # Apply scales: accum += mfma_result * scale_a * scale_b + s_a_v4 = s_a_vecs[mi] + s_b_bc = vector.broadcast(T.f32x4, s_b_vals[ni]) + scaled = ArithValue(mfma_result) * ArithValue(s_a_v4) + current_accs[acc_idx] = math_dialect.fma(scaled, s_b_bc, current_accs[acc_idx]) + + return current_accs + + # ===== Ping-pong K-loop ===== + # Prologue: prefetch first A tile into VGPRs, store to LDS, load B + scales + a_regs0 = prefetch_a_tile(0) + store_a_tile_to_lds(a_regs0, lds_base_pong) + b_tile_pong = load_b_tile(fx.Index(0)) + scales_pong_pf = prefetch_scales(0) + gpu.barrier() + + # Prefetch first A pack from pong (hides LDS latency behind upcoming VMEM) + 
a0_prefetch_pong = lds_load_packs_k64(row_a_lds_base, col_offset_base_bytes, lds_base_pong) + + for k_pair in range_constexpr(0, num_k_tiles, 2): + # Prefetch next scales BEFORE B-tile VMEM (per moe-2stage pattern: + # scale-load latency hides behind heavy B VMEM); then A+B regs. + if k_pair + 1 < num_k_tiles: + scales_ping_pf = prefetch_scales(k_pair + 1) + a_regs_ping = prefetch_a_tile(k_pair + 1) + b_tile_ping = load_b_tile(fx.Index((k_pair + 1) * tile_k)) + + # Compute current tile from pong LDS + accs = compute_tile(accs, k_pair, lds_base_pong, b_tile_pong, scales_pong_pf, + a0_prefetch=a0_prefetch_pong) + a0_prefetch_pong = None + + # Store next A to LDS (ds_write after compute, overlaps with trailing MFMAs) + if k_pair + 1 < num_k_tiles: + store_a_tile_to_lds(a_regs_ping, lds_base_ping) + hot_loop_scheduler() + gpu.barrier() + + if k_pair + 1 < num_k_tiles: + # Prefetch first A pack from ping + a0_prefetch_ping = lds_load_packs_k64( + row_a_lds_base, col_offset_base_bytes, lds_base_ping) + + # Prefetch next scales + A+B + if k_pair + 2 < num_k_tiles: + scales_pong_pf = prefetch_scales(k_pair + 2) + a_regs_pong = prefetch_a_tile(k_pair + 2) + b_tile_pong = load_b_tile(fx.Index((k_pair + 2) * tile_k)) + + # Compute current tile from ping LDS + accs = compute_tile(accs, k_pair + 1, lds_base_ping, b_tile_ping, scales_ping_pf, + a0_prefetch=a0_prefetch_ping) + a0_prefetch_ping = None + + # Store next A to LDS + if k_pair + 2 < num_k_tiles: + store_a_tile_to_lds(a_regs_pong, lds_base_pong) + hot_loop_scheduler() + gpu.barrier() + + # Prefetch first A pack from pong for next iteration + if k_pair + 2 < num_k_tiles: + a0_prefetch_pong = lds_load_packs_k64( + row_a_lds_base, col_offset_base_bytes, lds_base_pong) + + # ===== Epilogue: CShuffle vectorized stores ===== + c_n = n_in + vec1_out = T.vec(1, out_mlir()) + e_vec = 4 if (tile_n % (32 * 4)) == 0 else 2 + + def write_row_to_lds( + *, mi, ii, row_in_tile, row, + row_base_lds, col_base_local, num_acc_n, lds_out, + ): + for ni in range_constexpr(num_acc_n): + col_local = col_base_local + (ni * 16) + acc_idx = mi * num_acc_n + ni + acc = accs[acc_idx] + val = vector.extract(acc, static_position=[ii], dynamic_position=[]) + v_out = arith.trunc_f(out_mlir(), val) + lds_idx = row_base_lds + col_local + v1 = vector.from_elements(vec1_out, [v_out]) + vector.store(v1, lds_out, [lds_idx], alignment=2) + + def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): + idx_out = row * c_n + col_g0 + byte_off = idx_out * 2 + if e_vec == 4: + frag_i32x2 = vector.bitcast(T.vec(2, T.i32), frag) + buffer_ops.buffer_store( + frag_i32x2, d_rsrc, byte_off, offset_is_bytes=True + ) + else: + frag_i32x1 = vector.bitcast(T.vec(1, T.i32), frag) + frag_i32 = vector.extract( + frag_i32x1, static_position=[0], dynamic_position=[] + ) + buffer_ops.buffer_store( + frag_i32, d_rsrc, byte_off, offset_is_bytes=True + ) + + mfma_epilog( + use_cshuffle=True, + arith=arith, vector=vector, gpu=gpu, + range_constexpr=range_constexpr, + tile_m=tile_m, tile_n=tile_n, e_vec=e_vec, + m_repeat=m_repeat, num_acc_n=num_acc_n, + tx=tx, lane_div_16=lane_div_16, lane_mod_16=lane_mod_16, + bx_m=bx_m, by_n=by_n, n_tile_base=n_tile_base, + lds_out=lds_out, + frag_elem_type=out_mlir(), + write_row_to_lds=write_row_to_lds, + store_pair=store_pair, + ) + + scf.YieldOp([]) + + # ===== JIT Launcher ===== + @flyc.jit + def launch_grouped_gemm_blockscale_contiguous( + arg_d: fx.Tensor, + arg_a: fx.Tensor, + arg_b: fx.Tensor, + arg_scale_a: fx.Tensor, + arg_scale_b: fx.Tensor, + 
arg_grouped_layout: fx.Tensor, + i32_m: fx.Int32, + i32_n: fx.Int32, + i32_k: fx.Int32, + i32_num_groups: fx.Int32, + stream: fx.Stream, + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + # Grid dimensions + m_in = arith.index_cast(T.index, i32_m) + n_in = arith.index_cast(T.index, i32_n) + gx = n_in // fx.Index(tile_n) # N-blocks + gy = (m_in + fx.Index(tile_m - 1)) // fx.Index(tile_m) # M-blocks (ceil) + + launcher = grouped_gemm_blockscale_contiguous_kernel( + arg_d, + arg_a, + arg_b, + arg_scale_a, + arg_scale_b, + arg_grouped_layout, + i32_m, + i32_n, + i32_k, + i32_num_groups, + ) + if waves_per_eu is not None: + _wpe = int(waves_per_eu) + if _wpe >= 1: + for op in ctx.gpu_module_body.operations: + if hasattr(op, 'attributes') and op.OPERATION_NAME == "gpu.func": + op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get(T.i32, _wpe) + launcher.launch(grid=(gx, gy, 1), block=(total_threads, 1, 1), stream=stream) + + return launch_grouped_gemm_blockscale_contiguous diff --git a/kernels/grouped_gemm_blockscale_masked.py b/kernels/grouped_gemm_blockscale_masked.py new file mode 100644 index 000000000..ec759d040 --- /dev/null +++ b/kernels/grouped_gemm_blockscale_masked.py @@ -0,0 +1,697 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Masked Grouped FP8 GEMM kernel (M-grouped masked layout). + +API matching DeepGEMM's m_grouped_fp8_gemm_nt_masked: + - A: [G, expected_m, K] FP8 - padded activation tensor per group + - scale_a: [G, scale_k, expected_m] - per-token, per-128K scales (transposed). + uint8 (E8M0) on gfx950 (HW scaling); FP32 on gfx942 (SW scaling). + - B: [G, N, K] FP8 - one weight matrix per group + - scale_b: [G, scale_n, scale_k] - per-block scales. + uint8 (E8M0) on gfx950; FP32 on gfx942. + - D: [G, expected_m, N] BF16 - padded output tensor per group + - masked_m: [G] INT32 - tracks the actual number of valid tokens per group + - expected_m: INT32 - the padded capacity (max_m) for the M dimension + +Block scaling granularity (matching DeepGEMM's 1D2D configuration): + - A: (1, 128) - per-token, per-128-K-elements + - B: (128, 128) - per-128-N, per-128-K block +""" + +import functools + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl.compiler.kernel_function import CompilationContext +from flydsl.expr import arith, gpu, buffer_ops, vector, rocdl +from flydsl.expr import range_constexpr +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr + +from flydsl._mlir import ir +from flydsl._mlir.dialects import scf +from flydsl._mlir.dialects import math as math_dialect +from flydsl.expr.typing import T +from flydsl.expr.arith import ArithValue + +from kernels.mfma_epilogues import mfma_epilog +from kernels.mfma_preshuffle_pipeline import ( + crd2idx, + lds_store_16b_xor16, + load_b_pack_k32, + make_preshuffle_b_layout, + swizzle_xor16, + tile_chunk_coord_i32, +) + + +@functools.lru_cache(maxsize=128) +def compile_grouped_gemm_blockscale_masked( + *, + n: int, + k: int, + num_groups: int, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, + scale_block_k: int = 128, + scale_block_n: int = 128, + out_dtype: str = "bf16", + waves_per_eu: int | None = None, +): + """Compile masked grouped FP8 GEMM kernel and return the JIT launcher. 
+ + Args: + n: N dimension (output columns per group) + k: K dimension (reduction dimension) + num_groups: Number of groups (experts) + tile_m: M tile size (default 128) + tile_n: N tile size (default 128) + tile_k: K tile size (default 128) + scale_block_k: K-dimension scale block size (default 128) + scale_block_n: N-dimension scale block size (default 128) + out_dtype: Output data type ("bf16" or "f16") + + Returns: + JIT launcher function. + """ + gpu_arch = get_hip_arch() + _is_gfx950 = str(gpu_arch).startswith("gfx95") + + allocator = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_masked_grouped_gemm") + + # Validate parameters + if k % tile_k != 0: + raise ValueError(f"k ({k}) must be divisible by tile_k ({tile_k})") + if n % tile_n != 0: + raise ValueError(f"n ({n}) must be divisible by tile_n ({tile_n})") + if tile_k % scale_block_k != 0: + raise ValueError(f"tile_k ({tile_k}) must be divisible by scale_block_k ({scale_block_k})") + if tile_n % scale_block_n != 0: + raise ValueError(f"tile_n ({tile_n}) must be divisible by scale_block_n ({scale_block_n})") + + # Output type + if out_dtype not in ("bf16", "f16"): + raise ValueError(f"out_dtype must be 'bf16' or 'f16', got {out_dtype!r}") + out_mlir = lambda: T.bf16 if out_dtype == "bf16" else T.f16 + + # Compile-time constants + total_threads = 256 + elem_bytes = 1 # FP8 + num_k_tiles = k // tile_k + scale_k = k // scale_block_k + scale_n = n // scale_block_n + sb_per_tile = tile_k // scale_block_k # scale blocks per K-tile + k_unroll = tile_k // 64 # K64-byte micro-steps (for K32 MFMA pairs) + kpack_bytes = 16 # 16-byte packs for FP8 + + # LDS allocation: max of ping-pong A tiles and CShuffle epilogue output + lds_a_bytes = tile_m * tile_k * elem_bytes + lds_pingpong_bytes = 2 * lds_a_bytes + lds_out_bytes = tile_m * tile_n * 2 # bf16/f16 = 2 bytes per element + lds_total_bytes = max(lds_pingpong_bytes, lds_out_bytes) + lds_alloc_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = lds_alloc_offset + lds_total_bytes + lds_tile_elems = tile_m * tile_k # element offset between ping and pong + + # Module name for caching + module_name = ( + f"grouped_gemm_blockscale_masked_{out_dtype}" + f"_n{n}_k{k}_g{num_groups}" + f"_t{tile_m}x{tile_n}x{tile_k}" + f"_pingpong" + ).replace("-", "_") + + # Thread -> tile element mapping for A loads + tile_k_bytes = tile_k * elem_bytes + tile_k_dwords = tile_k_bytes // 4 + bytes_a_per_tile = tile_m * tile_k * elem_bytes + bytes_per_thread_a = bytes_a_per_tile // total_threads + a_load_bytes = 16 # 16-byte loads (dwordx4) + chunk_i32_a = a_load_bytes // 4 # 4 dwords per load + num_a_loads = bytes_per_thread_a // a_load_bytes + + @flyc.kernel(name=module_name) + def grouped_gemm_blockscale_masked_kernel( + arg_d: fx.Tensor, + arg_a: fx.Tensor, + arg_b: fx.Tensor, + arg_scale_a: fx.Tensor, + arg_scale_b: fx.Tensor, + arg_masked_m: fx.Tensor, + i32_expected_m: fx.Int32, + i32_n: fx.Int32, + i32_k: fx.Int32, + i32_num_groups: fx.Int32, + ): + # Convert runtime parameters to index type + # In the masked kernel, expected_m acts as our padded max capacity per group. 
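+ # masked_m[g] (loaded below) holds the number of valid rows in group g; any
+ # M-block whose first row lies at or beyond that count exits early.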
+ m_in = arith.index_cast(T.index, i32_expected_m) + n_in = arith.index_cast(T.index, i32_n) + k_in = arith.index_cast(T.index, i32_k) + num_groups_in = arith.index_cast(T.index, i32_num_groups) + + # Thread and 3D block IDs + tx = gpu.thread_id("x") + by = gpu.block_id("x") # N-block index + bx = gpu.block_id("y") # M-block index + bz = gpu.block_id("z") # Group ID index + group_idx = bz + + # Block positions + bx_m = bx * fx.Index(tile_m) + by_n = by * fx.Index(tile_n) + + # Wave/lane decomposition (256 threads = 4 waves x 64 lanes) + layout_wave_lane = fx.make_layout((4, 64), stride=(64, 1)) + coord_wave_lane = fx.idx2crd(tx, layout_wave_lane) + wave_id = fx.get(coord_wave_lane, 0) + lane_id = fx.get(coord_wave_lane, 1) + + # Lane decomposition for MFMA (lane_id -> lane_div_16, lane_mod_16) + layout_lane16 = fx.make_layout((4, 16), stride=(16, 1)) + coord_lane16 = fx.idx2crd(lane_id, layout_lane16) + lane_div_16 = fx.get(coord_lane16, 0) + lane_mod_16 = fx.get(coord_lane16, 1) + + # LDS setup: single memref for both ping-pong buffers + base_ptr = allocator.get_base() + lds_a = SmemPtr(base_ptr, lds_alloc_offset, T.f8, shape=(2 * tile_m * tile_k,)).get() + lds_stride = tile_k + layout_lds = fx.make_layout((tile_m, tile_k), stride=(lds_stride, 1)) + lds_base_pong = fx.Index(0) + lds_base_ping = fx.Index(lds_tile_elems) + + # CShuffle epilogue LDS (aliased from same base, bf16 element type) + lds_out = SmemPtr(base_ptr, lds_alloc_offset, out_mlir(), shape=(tile_m * tile_n,)).get() + + # Buffer resources + a_nbytes = num_groups_in * m_in * k_in + a_rsrc = buffer_ops.create_buffer_resource( + arg_a, max_size=False, num_records_bytes=arith.index_cast(T.i64, a_nbytes) + ) + + b_nbytes = num_groups_in * n_in * k_in + b_rsrc = buffer_ops.create_buffer_resource( + arg_b, max_size=False, num_records_bytes=arith.index_cast(T.i64, b_nbytes) + ) + + d_nbytes = num_groups_in * m_in * n_in * fx.Index(2) # bf16/f16 = 2 bytes + d_rsrc = buffer_ops.create_buffer_resource( + arg_d, max_size=False, num_records_bytes=arith.index_cast(T.i64, d_nbytes) + ) + + # Scale buffers — gfx950 HW E8M0 path consumes int8 (one byte/scale, + # pre-packed on host); gfx942 SW path consumes f32. 
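+ # (E8M0 is exponent-only: a byte value e encodes the power-of-two scale 2**(e - 127).)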
+ scale_byte_size = 1 if _is_gfx950 else 4 + + # scale_a: [G, scale_k, max_m] + sa_nbytes = num_groups_in * fx.Index(scale_k) * m_in * fx.Index(scale_byte_size) + sa_rsrc = buffer_ops.create_buffer_resource( + arg_scale_a, max_size=False, num_records_bytes=arith.index_cast(T.i64, sa_nbytes) + ) + + # scale_b: [G, scale_n, scale_k] + sb_nbytes = num_groups_in * fx.Index(scale_n * scale_k * scale_byte_size) + sb_rsrc = buffer_ops.create_buffer_resource( + arg_scale_b, max_size=False, num_records_bytes=arith.index_cast(T.i64, sb_nbytes) + ) + + # masked_m: [G] + mask_nbytes = num_groups_in * fx.Index(4) + mask_rsrc = buffer_ops.create_buffer_resource( + arg_masked_m, max_size=False, num_records_bytes=arith.index_cast(T.i64, mask_nbytes) + ) + + # Early exit for invalid blocks that fall entirely within the padded garbage + bx_m_i32 = arith.index_cast(T.i32, bx_m) + valid_m_i32 = buffer_ops.buffer_load(mask_rsrc, group_idx, vec_width=1, dtype=T.i32) + is_valid = arith.cmpi(arith.CmpIPredicate.ult, bx_m_i32, valid_m_i32) + + _if_valid = scf.IfOp(is_valid) + with ir.InsertionPoint(_if_valid.then_block): + + # MFMA tiling constants + m_repeat = tile_m // 16 # 8 for tile_m=128 + num_waves = 4 + n_per_wave = tile_n // num_waves # 32 for tile_n=128 + num_acc_n = n_per_wave // 16 # 2 for n_per_wave=32 + + # Initialize accumulators (FP32) + acc_init = arith.constant_vector(0.0, T.f32x4) + num_accs = m_repeat * num_acc_n + accs = [acc_init] * num_accs + + # Wave's N-tile base + wave_mod_4 = wave_id % fx.Index(4) + n_tile_base = wave_mod_4 * fx.Index(n_per_wave) + + # Precompute N-block indices for scale_b + c_scale_block_n = fx.Index(scale_block_n) + c_scale_k = fx.Index(scale_k) + n_block_for_scale = [] + for ni in range_constexpr(num_acc_n): + col_base = by_n + n_tile_base + arith.index(ni * 16) + n_blk = col_base // c_scale_block_n + n_block_for_scale.append(n_blk) + + # B preshuffle layout: total N = num_groups * N (all groups concatenated) + c_n_total = num_groups_in * n_in + b_layout = make_preshuffle_b_layout( + arith, c_n=c_n_total, c_k=k_in, + kpack_bytes=kpack_bytes, elem_bytes=elem_bytes, + ) + layout_b = b_layout.layout_b + + # Decompose global N column into (n_blk, n_intra) for preshuffle layout + c_n0 = c_n_total // fx.Index(16) + c_n0_i32 = arith.index_cast(T.i32, c_n0) + layout_n_blk_intra = fx.make_layout((c_n0_i32, 16), stride=(16, 1)) + n_blk_list = [] + n_intra_list = [] + group_n_off = group_idx * n_in # N-offset for this group in concatenated B + for ni in range_constexpr(num_acc_n): + col_global = group_n_off + by_n + n_tile_base + arith.index(ni * 16) + lane_mod_16 + coord_ni = fx.idx2crd(col_global, layout_n_blk_intra) + n_blk_list.append(fx.get(coord_ni, 0)) + n_intra_list.append(fx.get(coord_ni, 1)) + + # A load mapping: thread -> (row, col_i32) in tile (dword-indexed K) + layout_a_tile_div4 = fx.make_layout( + (tile_m, tile_k_dwords), stride=(tile_k_dwords, 1) + ) + c_chunk_a = fx.Index(chunk_i32_a) + tx_i32_base = tx * c_chunk_a + _k_div4_factor = k_in // fx.Index(4) + group_a_off_div4 = group_idx * m_in * _k_div4_factor # 3D A Offset + k_blocks16 = arith.index(tile_k_bytes // 16) + c4_bytes = fx.Index(4) + + # Precompute per-load tile coordinates (row_local, col_local_i32) + a_row_local = [] + a_col_local_i32 = [] + for i in range_constexpr(num_a_loads): + row_local, col_local_i32 = tile_chunk_coord_i32( + arith, tx_i32_base=tx_i32_base, i=i, + total_threads=total_threads, + layout_tile_div4=layout_a_tile_div4, + chunk_i32=chunk_i32_a, + ) + a_row_local.append(row_local) + 
a_col_local_i32.append(col_local_i32) + + # ── Helper: prefetch A tile (gmem -> regs) ───────────────── + def prefetch_a_tile(k_tile_idx_py): + """Load A tile from global memory into VGPRs.""" + base_k_div4 = fx.Index(k_tile_idx_py * tile_k_dwords) + parts = [] + for i in range_constexpr(num_a_loads): + row_global = bx_m + a_row_local[i] + idx_i32 = group_a_off_div4 + row_global * _k_div4_factor + base_k_div4 + a_col_local_i32[i] + a_vec = buffer_ops.buffer_load(a_rsrc, idx_i32, vec_width=4, dtype=T.i32) + parts.append(vector.bitcast(T.i32x4, a_vec)) + return parts + + # ── Helper: store A regs to LDS with XOR swizzle ─────────── + def store_a_tile_to_lds(a_parts, lds_base): + """Write prefetched A tile from VGPRs into LDS with XOR16 swizzle.""" + for i in range_constexpr(num_a_loads): + lds_store_16b_xor16( + arith, vector, + lds_memref=lds_a, vec16_ty=T.f8x16, + layout_lds=layout_lds, + row_local=a_row_local[i], + col_local_i32=a_col_local_i32[i], + tx_c4=c4_bytes, k_blocks16=k_blocks16, + lds_base=lds_base, + vec_part_i32x4=a_parts[i], elem_bytes=elem_bytes, + ) + + # ── Helper: load A K64 pack from LDS with XOR swizzle ──────── + def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): + """Load 16B from LDS with XOR16 swizzle, return two i64 halves.""" + col_base_swz_bytes = swizzle_xor16(curr_row_a_lds, col_base_bytes, k_blocks16) + idx_a16 = crd2idx((curr_row_a_lds, col_base_swz_bytes), layout_lds) + lds_base + loaded_a16 = vector.load_op(T.vec(16, T.f8), lds_a, [idx_a16]) + a_i64x2 = vector.bitcast(T.i64x2, loaded_a16) + a0 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[]) + a1 = vector.extract(a_i64x2, static_position=[1], dynamic_position=[]) + return a0, a1 + + # ── Helper: load one B pack (K32 micro-step) ──────────────── + def load_b_pack(base_k, ki_step, ni): + return load_b_pack_k32( + buffer_ops, arith, vector, + arg_b=arg_b, b_rsrc=b_rsrc, + layout_b=layout_b, + base_k=base_k, ki_step=ki_step, + n_blk=n_blk_list[ni], + n_intra=n_intra_list[ni], + lane_div_16=lane_div_16, + elem_type=T.f8, + kpack_bytes=kpack_bytes, + elem_bytes=elem_bytes, + ) + + # ── Helper: prefetch entire B tile (gmem -> regs) ─────────── + def load_b_tile(base_k): + """Load all B packs for one K-tile. + + Returns list of length k_unroll, each entry is + (packs_half0[ni], packs_half1[ni]) for one K64 micro-step. 
+ """ + b_tile = [] + for ku in range_constexpr(k_unroll): + packs0 = [] + packs1 = [] + for ni in range_constexpr(num_acc_n): + ki0 = (ku * 2) + 0 + ki1 = (ku * 2) + 1 + b0 = load_b_pack(base_k, ki0, ni) + b1 = load_b_pack(base_k, ki1, ni) + packs0.append(b0) + packs1.append(b1) + b_tile.append((packs0, packs1)) + return b_tile + + # Base coordinates for A0 prefetch (mi=0, ku=0) + row_a_lds_base = lane_mod_16 # mi=0 + col_offset_base_bytes = lane_div_16 * fx.Index(16) # ku=0 + + # ── Pack helper for gfx950 K=128 MFMA ────────────────────── + mfma_res_ty = T.f32x4 + + def pack_i64x4_to_i32x8(x0, x1, x2, x3): + v4 = vector.from_elements(T.vec(4, T.i64), [x0, x1, x2, x3]) + return vector.bitcast(T.vec(8, T.i32), v4) + + # ── Scheduling hints (sched_group_barrier, MoE stage 2 pattern) ── + ku_per_sb = scale_block_k // 64 + rocdl.sched_barrier(0) + + def hot_loop_scheduler(): + mfma_group = num_acc_n + if _is_gfx950: + total_mfma = sb_per_tile * m_repeat * mfma_group + else: + total_mfma = k_unroll * m_repeat * mfma_group * 2 + rocdl.sched_group_barrier(rocdl.mask_dsrd, ku_per_sb * m_repeat, 0) + rocdl.sched_group_barrier(rocdl.mask_mfma, total_mfma, 1) + rocdl.sched_group_barrier(rocdl.mask_vmem_rd, num_a_loads, 2) + rocdl.sched_group_barrier(rocdl.mask_dswr, num_a_loads, 3) + rocdl.sched_barrier(0) + + sa_group_off = group_idx * c_scale_k * m_in # 3D scale_a Offset + + # ── Helper: prefetch E8M0 scales for one K-tile (gfx950 HW path) ── + def prefetch_scales(k_tile_idx_py): + if not _is_gfx950: + return None + sa_pf = [] + sb_pf = [] + sb_group_offset = group_idx * fx.Index(scale_n * scale_k) + for sb in range_constexpr(sb_per_tile): + kb = fx.Index(k_tile_idx_py * sb_per_tile + sb) + sa_base_pf = sa_group_off + kb * m_in + + sa_sb = [] + for mi in range_constexpr(m_repeat): + sa_row = bx_m + arith.index(mi * 16) + lane_mod_16 + sa_idx = sa_base_pf + sa_row + sa_i8 = buffer_ops.buffer_load(sa_rsrc, sa_idx, vec_width=1, dtype=T.i8) + sa_e8m0 = ArithValue(sa_i8).extui(T.i32) + sa_sb.append(sa_e8m0) + sa_pf.append(sa_sb) + + sb_sb = [] + for ni in range_constexpr(num_acc_n): + sb_idx = sb_group_offset + n_block_for_scale[ni] * c_scale_k + kb + sb_i8 = buffer_ops.buffer_load(sb_rsrc, sb_idx, vec_width=1, dtype=T.i8) + sb_i32 = ArithValue(sb_i8).extui(T.i32) + sb_e8m0 = rocdl.readfirstlane(T.i32, sb_i32) + sb_sb.append(sb_e8m0) + sb_pf.append(sb_sb) + return (sa_pf, sb_pf) + + def compute_tile(accs_in, k_tile_idx_py, lds_base, b_tile_in, scales_pf, *, a0_prefetch=None): + """Compute MFMA tiles for one K-tile, return updated accumulators. + + scales_pf: result of prefetch_scales(k_tile_idx_py) for the gfx950 + HW path; None for the gfx942 SW path (which loads scales locally). + """ + current_accs = list(accs_in) + + for sb in range_constexpr(sb_per_tile): + kb = fx.Index(k_tile_idx_py * sb_per_tile + sb) + + # gfx942 SW path needs s_a_vecs (4 loads/mi, lane_div_16 layout) + # and s_b_vals as f32. gfx950 HW path uses prefetched E8M0 bytes. 
+ s_a_vecs = [] + s_b_vals = [] + if not _is_gfx950: + sa_base = sa_group_off + kb * m_in + row_off_base = lane_div_16 * fx.Index(4) + for mi in range_constexpr(m_repeat): + s_a_row = [] + for ii in range_constexpr(4): + row_in_tile = arith.index(mi * 16) + row_off_base + fx.Index(ii) + row_global = bx_m + row_in_tile + sa_idx = sa_base + row_global + s_a_val = buffer_ops.buffer_load(sa_rsrc, sa_idx, vec_width=1, dtype=T.f32) + s_a_row.append(s_a_val) + s_a_vec4 = vector.from_elements(T.f32x4, s_a_row) + s_a_vecs.append(s_a_vec4) + + sb_group_offset = group_idx * fx.Index(scale_n * scale_k) + for ni in range_constexpr(num_acc_n): + sb_idx = sb_group_offset + n_block_for_scale[ni] * c_scale_k + kb + s_b_val = buffer_ops.buffer_load(sb_rsrc, sb_idx, vec_width=1, dtype=T.f32) + s_b_val = rocdl.readfirstlane(T.f32, s_b_val) + s_b_vals.append(s_b_val) + + # MFMA computation for this scale block + if _is_gfx950: + # gfx950: single K=128 MFMA with HW E8M0 scales prefetched. + sa_pf, sb_pf = scales_pf + sa_e8m0_list = sa_pf[sb] + sb_e8m0_list = sb_pf[sb] + + ku0 = sb * ku_per_sb + ku1 = ku0 + 1 + b0_packs0, b0_packs1 = b_tile_in[ku0] + b1_packs0, b1_packs1 = b_tile_in[ku1] + col_base0 = col_offset_base_bytes + fx.Index(ku0 * 64) + col_base1 = col_offset_base_bytes + fx.Index(ku1 * 64) + + for mi in range_constexpr(m_repeat): + curr_row_a_lds = lane_mod_16 + arith.index(mi * 16) + if a0_prefetch is not None and sb == 0 and mi == 0: + a0, a1 = a0_prefetch + else: + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base0, lds_base) + a2, a3 = lds_load_packs_k64(curr_row_a_lds, col_base1, lds_base) + a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + + for ni in range_constexpr(num_acc_n): + b128 = pack_i64x4_to_i32x8( + b0_packs0[ni], b0_packs1[ni], + b1_packs0[ni], b1_packs1[ni], + ) + acc_idx = mi * num_acc_n + ni + # Hardware scaling + direct accumulation + current_accs[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [a128, b128, current_accs[acc_idx], + 0, 0, 0, sa_e8m0_list[mi], 0, sb_e8m0_list[ni]], + ) + else: + # gfx942: two K=32 MFMAs per K64 micro-step + for ku_local in range_constexpr(ku_per_sb): + ku = sb * ku_per_sb + ku_local + k_offset_bytes = ku * 64 + b_packs0, b_packs1 = b_tile_in[ku] + + for mi in range_constexpr(m_repeat): + if a0_prefetch is not None and sb == 0 and ku_local == 0 and mi == 0: + a0, a1 = a0_prefetch + else: + row_a_lds = lane_mod_16 + arith.index(mi * 16) + col_a_base_bytes = lane_div_16 * fx.Index(16) + fx.Index(k_offset_bytes) + a0, a1 = lds_load_packs_k64(row_a_lds, col_a_base_bytes, lds_base) + + for ni in range_constexpr(num_acc_n): + acc_idx = mi * num_acc_n + ni + + mfma_fn = rocdl.mfma_f32_16x16x32_fp8_fp8 + mfma_mid = mfma_fn(T.f32x4, [a0, b_packs0[ni], acc_init, 0, 0, 0]) + mfma_result = mfma_fn(T.f32x4, [a1, b_packs1[ni], mfma_mid, 0, 0, 0]) + + # Apply scales: accum += mfma_result * scale_a * scale_b + s_a_v4 = s_a_vecs[mi] + s_b_bc = vector.broadcast(T.f32x4, s_b_vals[ni]) + scaled = ArithValue(mfma_result) * ArithValue(s_a_v4) + current_accs[acc_idx] = math_dialect.fma(scaled, s_b_bc, current_accs[acc_idx]) + + return current_accs + + # ===== Ping-pong K-loop ===== + # Prologue: prefetch first A tile into VGPRs, store to LDS, load B + scales + a_regs0 = prefetch_a_tile(0) + store_a_tile_to_lds(a_regs0, lds_base_pong) + b_tile_pong = load_b_tile(fx.Index(0)) + scales_pong_pf = prefetch_scales(0) + gpu.barrier() + + # Prefetch first A pack from pong (hides LDS latency behind upcoming VMEM) + a0_prefetch_pong = 
lds_load_packs_k64(row_a_lds_base, col_offset_base_bytes, lds_base_pong) + + for k_pair in range_constexpr(0, num_k_tiles, 2): + # Prefetch next scales BEFORE B-tile VMEM (per moe-2stage pattern: + # scale-load latency hides behind heavy B VMEM); then A+B regs. + if k_pair + 1 < num_k_tiles: + scales_ping_pf = prefetch_scales(k_pair + 1) + a_regs_ping = prefetch_a_tile(k_pair + 1) + b_tile_ping = load_b_tile(fx.Index((k_pair + 1) * tile_k)) + + # Compute current tile from pong LDS + accs = compute_tile(accs, k_pair, lds_base_pong, b_tile_pong, scales_pong_pf, + a0_prefetch=a0_prefetch_pong) + a0_prefetch_pong = None + + # Store next A to LDS (ds_write after compute) + if k_pair + 1 < num_k_tiles: + store_a_tile_to_lds(a_regs_ping, lds_base_ping) + hot_loop_scheduler() + gpu.barrier() + + if k_pair + 1 < num_k_tiles: + # Prefetch first A pack from ping + a0_prefetch_ping = lds_load_packs_k64( + row_a_lds_base, col_offset_base_bytes, lds_base_ping) + + # Prefetch next scales + A+B + if k_pair + 2 < num_k_tiles: + scales_pong_pf = prefetch_scales(k_pair + 2) + a_regs_pong = prefetch_a_tile(k_pair + 2) + b_tile_pong = load_b_tile(fx.Index((k_pair + 2) * tile_k)) + + # Compute current tile from ping LDS + accs = compute_tile(accs, k_pair + 1, lds_base_ping, b_tile_ping, scales_ping_pf, + a0_prefetch=a0_prefetch_ping) + a0_prefetch_ping = None + + # Store next A to LDS + if k_pair + 2 < num_k_tiles: + store_a_tile_to_lds(a_regs_pong, lds_base_pong) + hot_loop_scheduler() + gpu.barrier() + + # Prefetch first A pack from pong for next iteration + if k_pair + 2 < num_k_tiles: + a0_prefetch_pong = lds_load_packs_k64( + row_a_lds_base, col_offset_base_bytes, lds_base_pong) + + # ===== Epilogue: CShuffle vectorized stores ===== + c_n = n_in + d_group_off = group_idx * m_in * n_in # 3D D Offset + vec1_out = T.vec(1, out_mlir()) + e_vec = 4 if (tile_n % (32 * 4)) == 0 else 2 + + def write_row_to_lds( + *, mi, ii, row_in_tile, row, + row_base_lds, col_base_local, num_acc_n, lds_out, + ): + for ni in range_constexpr(num_acc_n): + col_local = col_base_local + (ni * 16) + acc_idx = mi * num_acc_n + ni + acc = accs[acc_idx] + val = vector.extract(acc, static_position=[ii], dynamic_position=[]) + v_out = arith.trunc_f(out_mlir(), val) + lds_idx = row_base_lds + col_local + v1 = vector.from_elements(vec1_out, [v_out]) + vector.store(v1, lds_out, [lds_idx], alignment=2) + + def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): + idx_out = d_group_off + row * c_n + col_g0 + byte_off = idx_out * 2 + if e_vec == 4: + frag_i32x2 = vector.bitcast(T.vec(2, T.i32), frag) + buffer_ops.buffer_store( + frag_i32x2, d_rsrc, byte_off, offset_is_bytes=True + ) + else: + frag_i32x1 = vector.bitcast(T.vec(1, T.i32), frag) + frag_i32 = vector.extract( + frag_i32x1, static_position=[0], dynamic_position=[] + ) + buffer_ops.buffer_store( + frag_i32, d_rsrc, byte_off, offset_is_bytes=True + ) + + mfma_epilog( + use_cshuffle=True, + arith=arith, vector=vector, gpu=gpu, + range_constexpr=range_constexpr, + tile_m=tile_m, tile_n=tile_n, e_vec=e_vec, + m_repeat=m_repeat, num_acc_n=num_acc_n, + tx=tx, lane_div_16=lane_div_16, lane_mod_16=lane_mod_16, + bx_m=bx_m, by_n=by_n, n_tile_base=n_tile_base, + lds_out=lds_out, + frag_elem_type=out_mlir(), + write_row_to_lds=write_row_to_lds, + store_pair=store_pair, + ) + + scf.YieldOp([]) + + # ===== JIT Launcher ===== + @flyc.jit + def launch_grouped_gemm_blockscale_masked( + arg_d: fx.Tensor, + arg_a: fx.Tensor, + arg_b: fx.Tensor, + arg_scale_a: fx.Tensor, + arg_scale_b: 
fx.Tensor, + arg_masked_m: fx.Tensor, + i32_expected_m: fx.Int32, + i32_n: fx.Int32, + i32_k: fx.Int32, + i32_num_groups: fx.Int32, + stream: fx.Stream, + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + # Grid dimensions + max_m_in = arith.index_cast(T.index, i32_expected_m) + n_in = arith.index_cast(T.index, i32_n) + num_groups_in = arith.index_cast(T.index, i32_num_groups) + + gx = n_in // fx.Index(tile_n) # N-blocks + gy = (max_m_in + fx.Index(tile_m - 1)) // fx.Index(tile_m) # M-blocks (ceil) + gz = num_groups_in + + launcher = grouped_gemm_blockscale_masked_kernel( + arg_d, + arg_a, + arg_b, + arg_scale_a, + arg_scale_b, + arg_masked_m, + i32_expected_m, + i32_n, + i32_k, + i32_num_groups, + ) + if waves_per_eu is not None: + _wpe = int(waves_per_eu) + if _wpe >= 1: + for op in ctx.gpu_module_body.operations: + if hasattr(op, 'attributes') and op.OPERATION_NAME == "gpu.func": + op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get(T.i32, _wpe) + launcher.launch(grid=(gx, gy, gz), block=(total_threads, 1, 1), stream=stream) + + return launch_grouped_gemm_blockscale_masked \ No newline at end of file diff --git a/tests/arch_compat.py b/tests/arch_compat.py index fb6bab523..6d3792a1d 100644 --- a/tests/arch_compat.py +++ b/tests/arch_compat.py @@ -11,6 +11,8 @@ CDNA_ONLY_TESTS = frozenset({ "test_preshuffle_gemm.py", "test_blockscale_preshuffle_gemm.py", + "test_grouped_gemm_blockscale_contiguous.py", + "test_grouped_gemm_blockscale_masked.py", "test_moe_gemm.py", "test_moe_blockscale.py", "test_moe_reduce.py", diff --git a/tests/kernels/test_grouped_gemm_blockscale_contiguous.py b/tests/kernels/test_grouped_gemm_blockscale_contiguous.py new file mode 100644 index 000000000..060698563 --- /dev/null +++ b/tests/kernels/test_grouped_gemm_blockscale_contiguous.py @@ -0,0 +1,369 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Tests for Contiguous Grouped FP8 GEMM kernel (blockscale). + +Tests the contiguous grouped FP8 GEMM with block scaling, matching DeepGEMM's +m_grouped_fp8_gemm_nt_contiguous API. +""" + +import os +import sys +import logging + +import torch +import pytest + +pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower] + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +_PYTHON_CANDIDATES = [ + os.path.join(_REPO_ROOT, "build", "python_packages"), + _REPO_ROOT, +] +for _p in reversed(_PYTHON_CANDIDATES): + if os.path.isdir(_p) and _p not in sys.path: + sys.path.insert(0, _p) + +from kernels.grouped_gemm_blockscale_contiguous import compile_grouped_gemm_blockscale_contiguous +from flydsl.runtime.device import get_rocm_arch +from tests.test_common import run_perftest, verify_output +from tests.utils import shuffle_weight + +logging.basicConfig(level=logging.INFO) + +if not torch.cuda.is_available(): + pytest.skip("CUDA/ROCm not available. 
Skipping GPU tests.", allow_module_level=True) + +ARCH = get_rocm_arch() +# Use appropriate FP8 dtype for the architecture +DTYPE_FP8 = torch.float8_e4m3fn if "gfx95" in ARCH else torch.float8_e4m3fnuz +# gfx950 uses hardware E8M0 block scaling; quantization must round scales up to E8M0 +USE_UE8M0 = "gfx95" in ARCH + + +def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +def align(x: int, y: int) -> int: + return ceil_div(x, y) * y + + +def fp32_to_e8m0(scale: torch.Tensor) -> torch.Tensor: + """Round FP32 scale UP to E8M0 precision (ceiling on exponent). + + Rounding up is required so that x / scale_e8m0 <= fp8_max — truncation + would shrink the scale, causing FP8 saturation and a systematic bias on + every block. + """ + bits = scale.abs().float().view(torch.int32) + exp = ((bits >> 23) & 0xFF) + (bits & 0x7FFFFF).bool().int() + return (exp.clamp(1, 254) << 23).view(torch.float32) + + +def fp32_e8m0_to_byte(scale_e8m0_f32: torch.Tensor) -> torch.Tensor: + """Extract the E8M0 byte from a float that was previously rounded by + fp32_to_e8m0. Returns uint8. Use this when handing scales to the kernel + so dequantization applies exactly the same scale, bit for bit, that the kernel + will apply (the HW E8M0 path uses byte 0 of the i32 scale operand).""" + bits = scale_e8m0_f32.view(torch.int32) + return ((bits >> 23) & 0xFF).to(torch.uint8) + + +def quantize_to_fp8(x: torch.Tensor, scale_block_k: int = 128) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize tensor to FP8 with per-row, per-block scaling. + + Args: + x: Input tensor [M, K] + scale_block_k: K-dimension block size for scaling + + Returns: + (x_fp8, scale): FP8 tensor and scale factors [scale_k, M] + """ + M, K = x.shape + nblk_k = K // scale_block_k + + # Reshape to [M, nblk_k, scale_block_k] + x_blocks = x.view(M, nblk_k, scale_block_k) + + # Compute per-block max (for scale) + x_amax = x_blocks.abs().amax(dim=2).clamp(min=1e-12) + + fp8_max = torch.finfo(DTYPE_FP8).max + scale = x_amax / fp8_max + + # Round to E8M0 when hardware scaling is used (gfx950); keep as FP32 for the + # quantization divide, then pre-pack to uint8 below for kernel consumption. + if USE_UE8M0: + scale = fp32_to_e8m0(scale) + + # Quantize + x_scaled = x_blocks / scale.unsqueeze(2) + x_fp8 = x_scaled.to(DTYPE_FP8).view(M, K) + + # Pre-pack scale as uint8 (one E8M0 byte per scale) so the kernel can load + # 1 byte/scale and skip the in-kernel bitwise extract. + if USE_UE8M0: + scale = fp32_e8m0_to_byte(scale) + + # Transpose scale to [scale_k, M] to match the kernel's expected layout + scale = scale.T.contiguous() + + return x_fp8, scale + + +def quantize_b_to_fp8( + b: torch.Tensor, scale_block_n: int = 128, scale_block_k: int = 128 +) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize B tensor to FP8 with per-block scaling. 
+ + Args: + b: Input tensor [num_groups, N, K] + scale_block_n: N-dimension block size + scale_block_k: K-dimension block size + + Returns: + (b_fp8, scale_b): FP8 tensor and scale factors [num_groups, scale_n, scale_k] + """ + num_groups, N, K = b.shape + nblk_n = N // scale_block_n + nblk_k = K // scale_block_k + + # Reshape to [num_groups, nblk_n, scale_block_n, nblk_k, scale_block_k] + b_blocks = b.view(num_groups, nblk_n, scale_block_n, nblk_k, scale_block_k) + + # Compute per-block max + b_amax = b_blocks.abs().amax(dim=(2, 4)).clamp(min=1e-12) + + fp8_max = torch.finfo(DTYPE_FP8).max + scale = b_amax / fp8_max + + # Round to E8M0 when hardware scaling is used (gfx950); keep as FP32 for the + # quantization divide, then pre-pack to uint8 below for kernel consumption. + if USE_UE8M0: + scale = fp32_to_e8m0(scale) + + # Quantize + b_scaled = b_blocks / scale.view(num_groups, nblk_n, 1, nblk_k, 1) + b_fp8 = b_scaled.to(DTYPE_FP8).view(num_groups, N, K) + + if USE_UE8M0: + scale = fp32_e8m0_to_byte(scale) + + return b_fp8, scale + + +def generate_grouped_gemm_inputs( + num_groups: int, + m_per_group: int, + n: int, + k: int, + scale_block_k: int = 128, + scale_block_n: int = 128, + out_dtype: str = "bf16", + device: str = "cuda", +): + """Generate test inputs for grouped GEMM. + + Generates variable actual group sizes (unaligned), pads each group to + 128-row alignment, and marks padding rows with -1 in grouped_layout. + + Args: + num_groups: Number of groups + m_per_group: Approximate actual M rows per group (before alignment) + n: N dimension + k: K dimension + scale_block_k: K-dimension scale block size + scale_block_n: N-dimension scale block size + out_dtype: Output data type ("bf16" or "f16") + device: Device to create tensors on + + Returns: + Tuple of (a_fp8, scale_a, b_shuffled, scale_b, grouped_layout, d, ref_d, M) + """ + alignment = 128 # M-row alignment for contiguous layout + torch_out_dtype = torch.bfloat16 if out_dtype == "bf16" else torch.float16 + + # Generate variable actual group sizes, then align + actual_ms = [] + aligned_ms = [] + for _ in range(num_groups): + m_actual = int(m_per_group * (0.8 + 0.4 * torch.rand(1).item())) + m_actual = max(1, m_actual) # at least 1 row + m_aligned = align(m_actual, alignment) + actual_ms.append(m_actual) + aligned_ms.append(m_aligned) + M = sum(aligned_ms) + + # Create grouped_layout with -1 padding + grouped_layout = torch.empty(M, dtype=torch.int32, device=device) + start = 0 + for g, (m_actual, m_aligned) in enumerate(zip(actual_ms, aligned_ms)): + grouped_layout[start : start + m_actual] = g + grouped_layout[start + m_actual : start + m_aligned] = -1 + start += m_aligned + + # Generate random data + a_f32 = torch.randn(M, k, device=device, dtype=torch.float32) + b_f32 = torch.randn(num_groups, n, k, device=device, dtype=torch.float32) + + # Zero out padding rows in A + start = 0 + for m_actual, m_aligned in zip(actual_ms, aligned_ms): + a_f32[start + m_actual : start + m_aligned] = 0 + start += m_aligned + + # Reference output from original FP32 data BEFORE quantization + # (ref absorbs all quantization + scale errors). + # Per-group matmul. 
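+ # (Padding rows were zeroed in a_f32 above, so they contribute zero rows to
+ # ref_d; the test additionally zeroes them in both outputs before comparing.)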
+ ref_d = torch.zeros(M, n, dtype=torch.float32, device=device) + for g in range(num_groups): + mask = grouped_layout == g + if mask.any(): + ref_d[mask] = a_f32[mask] @ b_f32[g].T + ref_d = ref_d.to(torch_out_dtype) + + # Quantize to FP8 + a_fp8, scale_a = quantize_to_fp8(a_f32, scale_block_k) + b_fp8, scale_b = quantize_b_to_fp8(b_f32, scale_block_n, scale_block_k) + + # Output buffer + d = torch.zeros(M, n, dtype=torch_out_dtype, device=device) + + # Preshuffle B for kernel (applied per-group, batch dim folded automatically) + b_shuffled = shuffle_weight(b_fp8, layout=(16, 16)) + + return a_fp8, scale_a, b_shuffled, scale_b, grouped_layout, d, ref_d, M + + +def _as_i8(t: torch.Tensor) -> torch.Tensor: + """View FP8 tensor as int8 for kernel interface.""" + return t.view(torch.int8) + + +@pytest.mark.parametrize( + "num_groups,m_per_group,n,k", + [ + # Basic shapes + pytest.param(1, 128, 128, 128, id="1g-128m-128n-128k"), + pytest.param(2, 128, 128, 128, id="2g-128m-128n-128k"), + pytest.param(4, 128, 256, 256, id="4g-128m-256n-256k"), + # Unaligned M (produces -1 padding rows in grouped_layout) + pytest.param(2, 100, 128, 128, id="2g-100m-unaligned"), + pytest.param(4, 200, 256, 256, id="4g-200m-unaligned"), + # Larger shapes + pytest.param(8, 256, 512, 512, id="8g-256m-512n-512k", marks=pytest.mark.large_shape), + pytest.param(8, 256, 2048, 7168, id="8g-2048x7168-large", marks=pytest.mark.large_shape), + pytest.param(8, 256, 7168, 2304, id="8g-7168x2304-large", marks=pytest.mark.large_shape), + ], +) +@pytest.mark.parametrize("out_dtype", [ + pytest.param("bf16", id="bf16"), + pytest.param("f16", id="f16"), +]) +def test_grouped_fp8_gemm(num_groups, m_per_group, n, k, out_dtype, + *, tile_m=128, tile_n=128, tile_k=128, + bench_iters=20, bench_warmup=3): + """Verify grouped FP8 GEMM correctness against a PyTorch reference and + report throughput. 
Single function combining correctness + bench, matching + the convention of test_blockscale_preshuffle_gemm.py.""" + scale_block_k = 128 + scale_block_n = 128 + + # Generate inputs + a_fp8, scale_a, b_fp8, scale_b, grouped_layout, d, ref_d, M = generate_grouped_gemm_inputs( + num_groups, m_per_group, n, k, scale_block_k, scale_block_n, out_dtype=out_dtype, + ) + + # Compile kernel + launch_fn = compile_grouped_gemm_blockscale_contiguous( + n=n, + k=k, + num_groups=num_groups, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + scale_block_k=scale_block_k, + scale_block_n=scale_block_n, + out_dtype=out_dtype, + ) + + # Launch wrapper + stream = torch.cuda.current_stream() + + def launch_kernel(d, a, b, sa, sb, gl): + launch_fn(d, a, b, sa, sb, gl, M, n, k, num_groups, stream) + + bench_iters = max(2, int(bench_iters)) + bench_warmup = int(bench_warmup) + _, us = run_perftest( + launch_kernel, + d.contiguous().view(-1), + _as_i8(a_fp8.contiguous().view(-1)), + _as_i8(b_fp8.contiguous().view(-1)), + scale_a.contiguous().view(-1), + scale_b.contiguous().view(-1), + grouped_layout.contiguous(), + num_iters=bench_iters, + num_warmup=bench_warmup, + ) + torch.cuda.synchronize() + + # Zero out padding rows (group_id == -1) before comparison + padding_mask = grouped_layout.cpu() == -1 + c_out_f32 = d.to(torch.float32) + c_ref = ref_d.to(torch.float32) + c_out_f32[padding_mask] = 0.0 + c_ref[padding_mask] = 0.0 + + msg = f"num_groups={num_groups}, M={M}, N={n}, K={k}, out={out_dtype}" + passed = verify_output(c_out_f32, c_ref, rtol=1e-2, atol=1e-2, msg=msg, + logits_diff_threshold=1e-3) + assert passed, f"Correctness check failed for {msg}" + + flops = 2 * M * n * k + tflops = flops / (us / 1e6) / 1e12 + bytes_a = M * k # FP8 + bytes_b = num_groups * n * k # FP8 + bytes_d = M * n * 2 # BF16 + bytes_scales = (k // scale_block_k) * M * 4 + num_groups * (n // scale_block_n) * (k // scale_block_k) * 4 + total_bytes = bytes_a + bytes_b + bytes_d + bytes_scales + bandwidth_tbs = total_bytes / (us / 1e6) / 1e12 + + print(f"\n [{num_groups} groups, M={M}, N={n}, K={k}]") + print(f" Throughput: {us:.1f} us, {tflops:.2f} TFLOPS, BW: {bandwidth_tbs:.3f} TB/s") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Grouped FP8 GEMM benchmark", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--num_groups", type=int, default=4) + parser.add_argument("--m_per_group", type=int, default=0, + help="Approx M rows per group (0 = sweep [128, 256, 512, 1024])") + parser.add_argument("-N", type=int, default=512) + parser.add_argument("-K", type=int, default=512) + parser.add_argument("--tile_m", type=int, default=128) + parser.add_argument("--tile_n", type=int, default=128) + parser.add_argument("--tile_k", type=int, default=128) + parser.add_argument("--out_dtype", type=str, default="bf16", choices=["bf16", "f16"]) + parser.add_argument("--num_iters", type=int, default=100) + parser.add_argument("--num_warmup", type=int, default=5) + args = parser.parse_args() + + torch.set_default_device("cuda") + + m_list = [args.m_per_group] if args.m_per_group > 0 else [128, 256, 512, 1024] + + for m_per_group in m_list: + test_grouped_fp8_gemm(args.num_groups, m_per_group, args.N, args.K, + out_dtype=args.out_dtype, + tile_m=args.tile_m, tile_n=args.tile_n, + tile_k=args.tile_k, + bench_iters=args.num_iters, bench_warmup=args.num_warmup) diff --git a/tests/kernels/test_grouped_gemm_blockscale_masked.py 
b/tests/kernels/test_grouped_gemm_blockscale_masked.py
new file mode 100644
index 000000000..a01db194b
--- /dev/null
+++ b/tests/kernels/test_grouped_gemm_blockscale_masked.py
@@ -0,0 +1,353 @@
+#!/usr/bin/env python3
+
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) 2025 FlyDSL Project Contributors
+
+"""Tests for Masked Grouped FP8 GEMM kernel.
+
+Tests the masked grouped FP8 GEMM with block scaling, matching DeepGEMM's
+m_grouped_fp8_gemm_nt_masked API.
+"""
+
+import os
+import sys
+import logging
+
+import torch
+import pytest
+
+pytestmark = [pytest.mark.l2_device, pytest.mark.rocm_lower]
+
+_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+_PYTHON_CANDIDATES = [
+    os.path.join(_REPO_ROOT, "build", "python_packages"),
+    _REPO_ROOT,
+]
+for _p in reversed(_PYTHON_CANDIDATES):
+    if os.path.isdir(_p) and _p not in sys.path:
+        sys.path.insert(0, _p)
+
+from kernels.grouped_gemm_blockscale_masked import compile_grouped_gemm_blockscale_masked
+from flydsl.runtime.device import get_rocm_arch
+from tests.test_common import run_perftest, verify_output
+from tests.utils import shuffle_weight
+
+logging.basicConfig(level=logging.INFO)
+
+if not torch.cuda.is_available():
+    pytest.skip("CUDA/ROCm not available. Skipping GPU tests.", allow_module_level=True)
+
+ARCH = get_rocm_arch()
+# Use appropriate FP8 dtype for the architecture
+DTYPE_FP8 = torch.float8_e4m3fn if "gfx95" in ARCH else torch.float8_e4m3fnuz
+# gfx950 uses hardware E8M0 block scaling — quantization must use E8M0-rounded scales
+USE_UE8M0 = "gfx95" in ARCH
+
+
+def ceil_div(x: int, y: int) -> int:
+    return (x + y - 1) // y
+
+
+def align(x: int, y: int) -> int:
+    return ceil_div(x, y) * y
+
+
+def fp32_to_e8m0(scale: torch.Tensor) -> torch.Tensor:
+    """Round FP32 scale UP to E8M0 precision (ceiling on exponent).
+
+    Rounding up is required so that x / scale_e8m0 <= fp8_max — truncation
+    would shrink the scale, causing FP8 saturation and a systematic bias on
+    every block.
+    """
+    bits = scale.abs().float().view(torch.int32)
+    exp = ((bits >> 23) & 0xFF) + (bits & 0x7FFFFF).bool().int()
+    return (exp.clamp(1, 254) << 23).view(torch.float32)
+
+
+def fp32_e8m0_to_byte(scale_e8m0_f32: torch.Tensor) -> torch.Tensor:
+    """Extract the E8M0 byte from a float that was previously rounded by
+    fp32_to_e8m0. Returns uint8. Use this when handing scales to the kernel
+    so dequant uses the bit-exact scale the kernel will apply (the HW E8M0
+    path uses byte 0 of the i32 scale operand)."""
+    bits = scale_e8m0_f32.view(torch.int32)
+    return ((bits >> 23) & 0xFF).to(torch.uint8)
+
+
+def quantize_a_masked_to_fp8(x: torch.Tensor, scale_block_k: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
+    """Quantize padded 3D A tensor to FP8 with per-row, per-block scaling.
+
+    Args:
+        x: Input tensor [G, max_m, K]
+        scale_block_k: K-dimension block size for scaling
+
+    Returns:
+        (x_fp8, scale): FP8 tensor and scale factors [G, scale_k, max_m]
+    """
+    G, max_m, K = x.shape
+    nblk_k = K // scale_block_k
+
+    # Reshape to [G, max_m, nblk_k, scale_block_k]
+    x_blocks = x.view(G, max_m, nblk_k, scale_block_k)
+
+    # Compute per-block max (for scale)
+    x_amax = x_blocks.abs().amax(dim=-1).clamp(min=1e-12)
+
+    fp8_max = torch.finfo(DTYPE_FP8).max
+    scale = x_amax / fp8_max
+
+    # Round to E8M0 when hardware scaling is used (gfx950); keep as FP32 for the
+    # quantization divide, then pre-pack to uint8 below for kernel consumption.
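+    # Illustrative example (hypothetical values): a scale of 0.03 = 1.92 * 2**-6
+    # rounds UP to 2**-5 = 0.03125, so x / scale still stays within fp8_max.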
+ if USE_UE8M0: + scale = fp32_to_e8m0(scale) + + # Quantize + x_scaled = x_blocks / scale.unsqueeze(-1) + x_fp8 = x_scaled.to(DTYPE_FP8).view(G, max_m, K) + + # Pre-pack scale as uint8 (one E8M0 byte per scale) so the kernel can load + # 1 byte/scale and skip the in-kernel bitwise extract. + if USE_UE8M0: + scale = fp32_e8m0_to_byte(scale) + + # Transpose scale to [G, scale_k, max_m] to match kernel layout + scale = scale.transpose(1, 2).contiguous() + + return x_fp8, scale + + +def quantize_b_to_fp8( + b: torch.Tensor, scale_block_n: int = 128, scale_block_k: int = 128 +) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize B tensor to FP8 with per-block scaling. + + Args: + b: Input tensor [num_groups, N, K] + scale_block_n: N-dimension block size + scale_block_k: K-dimension block size + + Returns: + (b_fp8, scale_b): FP8 tensor and scale factors [num_groups, scale_n, scale_k] + """ + num_groups, N, K = b.shape + nblk_n = N // scale_block_n + nblk_k = K // scale_block_k + + # Reshape to [num_groups, nblk_n, scale_block_n, nblk_k, scale_block_k] + b_blocks = b.view(num_groups, nblk_n, scale_block_n, nblk_k, scale_block_k) + + # Compute per-block max + b_amax = b_blocks.abs().amax(dim=(2, 4)).clamp(min=1e-12) + + fp8_max = torch.finfo(DTYPE_FP8).max + scale = b_amax / fp8_max + + # Round to E8M0 when hardware scaling is used (gfx950); keep as FP32 for the + # quantization divide, then pre-pack to uint8 below for kernel consumption. + if USE_UE8M0: + scale = fp32_to_e8m0(scale) + + # Quantize + b_scaled = b_blocks / scale.view(num_groups, nblk_n, 1, nblk_k, 1) + b_fp8 = b_scaled.to(DTYPE_FP8).view(num_groups, N, K) + + if USE_UE8M0: + scale = fp32_e8m0_to_byte(scale) + + return b_fp8, scale + + +def generate_masked_grouped_gemm_inputs( + num_groups: int, + max_m: int, + expected_m_per_group: int, + n: int, + k: int, + scale_block_k: int = 128, + scale_block_n: int = 128, + out_dtype: str = "bf16", + device: str = "cuda", +): + """Generate test inputs for masked grouped GEMM. + + Args: + num_groups: Number of groups + max_m: Capacity padding dimension for M + expected_m_per_group: Average actual M rows per group + n: N dimension + k: K dimension + scale_block_k: K-dimension scale block size + scale_block_n: N-dimension scale block size + out_dtype: Output data type ("bf16" or "f16") + device: Device to create tensors on + + Returns: + Tuple of (a_fp8, scale_a, b_shuffled, scale_b, masked_m, d, ref_d) + """ + torch_out_dtype = torch.bfloat16 if out_dtype == "bf16" else torch.float16 + + # Generate valid length array + masked_m = torch.empty(num_groups, dtype=torch.int32, device=device) + for g in range(num_groups): + m_val = int(expected_m_per_group * (0.8 + 0.4 * torch.rand(1).item())) + m_val = min(m_val, max_m) # cap at max_m + masked_m[g] = m_val + + # Generate random padded data + a_f32 = torch.randn(num_groups, max_m, k, device=device, dtype=torch.float32) + b_f32 = torch.randn(num_groups, n, k, device=device, dtype=torch.float32) + + # Reference output from original FP32 data BEFORE quantization + # (ref absorbs all quantization + scale errors). + # Per-group matmul. 
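+    # Only the first masked_m[g] rows of each group are valid; rows beyond that
+    # stay zero in the reference, and the test zeroes the same rows in the
+    # kernel output before comparing.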
+ ref_d = torch.zeros(num_groups, max_m, n, dtype=torch.float32, device=device) + for g in range(num_groups): + m_actual = masked_m[g].item() + if m_actual > 0: + ref_d[g, :m_actual, :] = a_f32[g, :m_actual, :] @ b_f32[g].T + ref_d = ref_d.to(torch_out_dtype) + + # Quantize to FP8 + a_fp8, scale_a = quantize_a_masked_to_fp8(a_f32, scale_block_k) + b_fp8, scale_b = quantize_b_to_fp8(b_f32, scale_block_n, scale_block_k) + + # Output buffer + d = torch.zeros(num_groups, max_m, n, dtype=torch_out_dtype, device=device) + + # Preshuffle B for kernel (applied per-group, batch dim folded automatically) + b_shuffled = shuffle_weight(b_fp8, layout=(16, 16)) + + return a_fp8, scale_a, b_shuffled, scale_b, masked_m, d, ref_d + + +def _as_i8(t: torch.Tensor) -> torch.Tensor: + """View FP8 tensor as int8 for kernel interface.""" + return t.view(torch.int8) + + +@pytest.mark.parametrize( + "num_groups,max_m,expected_m,n,k", + [ + # Basic shapes + pytest.param(1, 256, 100, 128, 128, id="1g-256max-100m"), + pytest.param(4, 256, 150, 128, 128, id="4g-256max-150m"), + pytest.param(8, 512, 300, 256, 256, id="8g-512max-300m"), + # Small expected_m (many padding rows) + pytest.param(4, 512, 50, 128, 128, id="4g-512max-50m-sparse"), + # Larger shapes + pytest.param(8, 1024, 800, 512, 512, id="8g-1024max-800m", marks=pytest.mark.large_shape), + pytest.param(8, 512, 300, 2048, 7168, id="8g-2048x7168-large", marks=pytest.mark.large_shape), + pytest.param(8, 512, 300, 7168, 2304, id="8g-7168x2304-large", marks=pytest.mark.large_shape), + ], +) +@pytest.mark.parametrize("out_dtype", [ + pytest.param("bf16", id="bf16"), + pytest.param("f16", id="f16"), +]) +def test_masked_grouped_fp8_gemm(num_groups, max_m, expected_m, n, k, out_dtype, + *, tile_m=128, tile_n=128, tile_k=128, + bench_iters=20, bench_warmup=3): + """Verify masked grouped FP8 GEMM correctness against a PyTorch reference and + report throughput. 
Single function combining correctness + bench, matching + the convention of test_blockscale_preshuffle_gemm.py.""" + scale_block_k = 128 + scale_block_n = 128 + + # Generate inputs + a_fp8, scale_a, b_fp8, scale_b, masked_m, d, ref_d = generate_masked_grouped_gemm_inputs( + num_groups, max_m, expected_m, n, k, scale_block_k, scale_block_n, out_dtype=out_dtype, + ) + + # Compile kernel + launch_fn = compile_grouped_gemm_blockscale_masked( + n=n, + k=k, + num_groups=num_groups, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + scale_block_k=scale_block_k, + scale_block_n=scale_block_n, + out_dtype=out_dtype, + ) + + # Launch wrapper + stream = torch.cuda.current_stream() + + def launch_kernel(d, a, b, sa, sb, mask): + launch_fn(d, a, b, sa, sb, mask, max_m, n, k, num_groups, stream) + + bench_iters = max(2, int(bench_iters)) + bench_warmup = int(bench_warmup) + _, us = run_perftest( + launch_kernel, + d.contiguous().view(-1), + _as_i8(a_fp8.contiguous().view(-1)), + _as_i8(b_fp8.contiguous().view(-1)), + scale_a.contiguous().view(-1), + scale_b.contiguous().view(-1), + masked_m.contiguous(), + num_iters=bench_iters, + num_warmup=bench_warmup, + ) + torch.cuda.synchronize() + + # Verify correctness — zero out padding rows before comparison + # (kernel does not mask individual stores in the epilogue) + c_out_f32 = d.to(torch.float32) + c_ref = ref_d.to(torch.float32) + for g in range(num_groups): + m_val = masked_m[g].item() + c_out_f32[g, m_val:, :] = 0.0 + c_ref[g, m_val:, :] = 0.0 + + msg = f"num_groups={num_groups}, max_m={max_m}, N={n}, K={k}, out={out_dtype}" + passed = verify_output(c_out_f32, c_ref, rtol=1e-2, atol=1e-2, msg=msg, + logits_diff_threshold=1e-3) + assert passed, f"Correctness check failed for {msg}" + + # Effective FLOPs/BW based on ACTUAL valid tokens (padding mostly skipped by kernel). 
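+    # Note: the 4-byte-per-scale term below assumes FP32 scales; on gfx950 the
+    # scales are packed as uint8 (E8M0), so this slightly overstates traffic.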
+ valid_m_sum = masked_m.sum().item() + flops = 2 * valid_m_sum * n * k + tflops = flops / (us / 1e6) / 1e12 + bytes_a = valid_m_sum * k # FP8 + bytes_b = num_groups * n * k # FP8 + bytes_d = valid_m_sum * n * 2 # BF16 + bytes_scales = (k // scale_block_k) * valid_m_sum * 4 + num_groups * (n // scale_block_n) * (k // scale_block_k) * 4 + total_bytes = bytes_a + bytes_b + bytes_d + bytes_scales + bandwidth_tbs = total_bytes / (us / 1e6) / 1e12 + + print(f"\n [{num_groups} groups, max_m={max_m}, expected_m={expected_m}, N={n}, K={k}]") + print(f" Throughput: {us:.1f} us, {tflops:.2f} TFLOPS, BW: {bandwidth_tbs:.3f} TB/s (Effective)") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Masked Grouped FP8 GEMM benchmark", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("--num_groups", type=int, default=4) + parser.add_argument("--max_m", type=int, default=512) + parser.add_argument("--expected_m", type=int, default=0, + help="Approx valid M rows per group (0 = sweep [128, 256, 384])") + parser.add_argument("-N", type=int, default=512) + parser.add_argument("-K", type=int, default=512) + parser.add_argument("--tile_m", type=int, default=128) + parser.add_argument("--tile_n", type=int, default=128) + parser.add_argument("--tile_k", type=int, default=128) + parser.add_argument("--out_dtype", type=str, default="bf16", choices=["bf16", "f16"]) + parser.add_argument("--num_iters", type=int, default=100) + parser.add_argument("--num_warmup", type=int, default=5) + args = parser.parse_args() + + torch.set_default_device("cuda") + + m_list = [args.expected_m] if args.expected_m > 0 else [128, 256, 384] + + for expected_m in m_list: + test_masked_grouped_fp8_gemm(args.num_groups, args.max_m, expected_m, args.N, args.K, + out_dtype=args.out_dtype, + tile_m=args.tile_m, tile_n=args.tile_n, + tile_k=args.tile_k, + bench_iters=args.num_iters, bench_warmup=args.num_warmup) \ No newline at end of file