diff --git a/programming_examples/generate_readme.py b/programming_examples/generate_readme.py
index 3363e5390..b05dc46d8 100644
--- a/programming_examples/generate_readme.py
+++ b/programming_examples/generate_readme.py
@@ -60,6 +60,12 @@
         "path": "matrix_vector_multiplication/int4_awq",
         "datatypes": "int4 weights / bf16 activations",
     },
+    {
+        "category": "Linear Algebra",
+        "name": "Matrix Multiplication (AWQ int4)",
+        "path": "matrix_multiplication/int4_awq",
+        "datatypes": "int4 weights / bf16 activations",
+    },
     {
         "category": "Linear Algebra",
         "name": "AXPY",
diff --git a/programming_examples/matrix_multiplication/int4_awq/Makefile b/programming_examples/matrix_multiplication/int4_awq/Makefile
new file mode 100644
index 000000000..662b7d564
--- /dev/null
+++ b/programming_examples/matrix_multiplication/int4_awq/Makefile
@@ -0,0 +1,90 @@
+# Copyright (C) 2026, Advanced Micro Devices, Inc.
+# SPDX-License-Identifier: MIT
+#
+# int4-AWQ GEMM example. The kernel .cc is shared with the int4-AWQ GEMV sibling
+# directory (matmul_int4_bf16_packed_f32, zero_vectorized_f32_mn, f32_to_bf16_mn).
+
+srcdir := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST))))
+INT4_SRCDIR := $(srcdir)/../../matrix_vector_multiplication/int4_awq
+
+# TILE_K_L2 must equal K: the segment-level K loop has 1 iter so the per-PE
+# L1 C accumulator survives across all K_CHUNK iterations.
+M ?= 64
+K ?= 128
+N ?= 128
+GS ?= 128
+TILE_M ?= 16
+TILE_N ?= 16
+TILE_K_L1 ?= 128
+TILE_K_L2 ?= $(K)
+HERD_M ?= 2
+HERD_N ?= 4
+
+ifdef PEANO_INSTALL_DIR
+  BUILD_DIR := build_peano
+else
+  BUILD_DIR := build_chess
+endif
+
+AIEOPT_DIR = $(shell realpath $(dir $(shell which aie-opt))/..)
+WARNING_FLAGS = -Wno-parentheses -Wno-attributes -Wno-macro-redefined -Wno-empty-body
+PEANOWRAP2P_FLAGS = -O2 -std=c++20 --target=aie2p-none-unknown-elf ${WARNING_FLAGS} -DNDEBUG -I ${AIEOPT_DIR}/include
+
+PY_ARGS = --m $(M) --k $(K) --n $(N) --gs $(GS) \
+	--tile-m $(TILE_M) --tile-n $(TILE_N) \
+	--tile-k-l1 $(TILE_K_L1) --tile-k-l2 $(TILE_K_L2) \
+	--herd-m $(HERD_M) --herd-n $(HERD_N)
+
+all: run_packed
+
+print:
+	${powershell} python3 ${srcdir}/matmul_int4_packed.py $(PY_ARGS) -p
+
+compile-kernel:
+	mkdir -p $(BUILD_DIR)
+	@if [ -z "$(PEANO_INSTALL_DIR)" ]; then \
+		echo "Error: PEANO_INSTALL_DIR not set (source utils/env_setup.sh)."; \
+		exit 1; \
+	fi
+	$(PEANO_INSTALL_DIR)/bin/clang++ ${PEANOWRAP2P_FLAGS} \
+		-DDIM_M=$(TILE_M) -DDIM_N=$(TILE_N) -DDIM_K_CHUNK=$(TILE_K_L1) -DDIM_GS=$(GS) \
+		-DAIE_API_EMULATE_BFLOAT16_MMUL_WITH_BFP16 \
+		-c $(INT4_SRCDIR)/mv_int4_bf16.cc -o $(BUILD_DIR)/mv_int4_bf16.o
+
+run_packed: compile-kernel
+	cd $(BUILD_DIR) && \
+	PEANO_INSTALL_DIR=$(PEANO_INSTALL_DIR) ${powershell} python3 ${srcdir}/matmul_int4_packed.py $(PY_ARGS)
+
+run1x1: compile-kernel
+	cd $(BUILD_DIR) && \
+	PEANO_INSTALL_DIR=$(PEANO_INSTALL_DIR) ${powershell} python3 ${srcdir}/matmul_int4_packed.py \
+		--m 16 --k 128 --n 16 --gs $(GS) \
+		--tile-m 16 --tile-n 16 --tile-k-l1 128 --tile-k-l2 128 \
+		--herd-m 1 --herd-n 1
+
+run2x4: compile-kernel
+	cd $(BUILD_DIR) && \
+	PEANO_INSTALL_DIR=$(PEANO_INSTALL_DIR) ${powershell} python3 ${srcdir}/matmul_int4_packed.py \
+		--m 64 --k 128 --n 128 --gs $(GS) \
+		--tile-m $(TILE_M) --tile-n $(TILE_N) \
+		--tile-k-l1 $(TILE_K_L1) --tile-k-l2 128 \
+		--herd-m 2 --herd-n 4
+
+run8x4: compile-kernel
+	cd $(BUILD_DIR) && \
+	PEANO_INSTALL_DIR=$(PEANO_INSTALL_DIR) ${powershell} python3 ${srcdir}/matmul_int4_packed.py \
+		--m 256 --k 128 --n 128 --gs $(GS) \
+		--tile-m $(TILE_M) --tile-n $(TILE_N) \
+		--tile-k-l1 $(TILE_K_L1) --tile-k-l2 128 \
+		--herd-m 8 --herd-n 4
+
+# Llama Q-proj scale: M=N=K=2048, herd 8x4, TILE_K_L2=K (no re-zero).
+run_llama_qproj: compile-kernel
+	cd $(BUILD_DIR) && \
+	PEANO_INSTALL_DIR=$(PEANO_INSTALL_DIR) ${powershell} python3 ${srcdir}/matmul_int4_packed.py \
+		--m 2048 --k 2048 --n 2048 --gs $(GS) \
+		--tile-m 16 --tile-n 16 --tile-k-l1 128 --tile-k-l2 2048 \
+		--herd-m 8 --herd-n 4
+
+clean:
+	rm -rf $(BUILD_DIR) __pycache__
diff --git a/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py b/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py
new file mode 100644
index 000000000..344149307
--- /dev/null
+++ b/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py
@@ -0,0 +1,473 @@
+# Copyright (C) 2026, Advanced Micro Devices, Inc.
+# SPDX-License-Identifier: MIT
+#
+# int4-AWQ GEMM (prefill). 2D herd over (M, N), with K accumulated per-PE
+# inside the herd; per-PE drain into a 4D L2 C [herd_m, herd_n, tile_m,
+# tile_n]. Uses matmul_int4_bf16_packed_f32 from mv_int4_bf16.cc (packed
+# Q+S+Z weight BO; bf16 activation, f32 accumulation, bf16 output).
+
+import argparse
+import sys
+
+import numpy as np
+from ml_dtypes import bfloat16
+
+from air.ir import (
+    AffineConstantExpr,
+    AffineExpr,
+    AffineMap,
+    AffineSymbolExpr,
+    BF16Type,
+    F32Type,
+    IntegerAttr,
+    IntegerType,
+    MemRefType,
+    ShapedType,
+    StridedLayoutAttr,
+    StringAttr,
+    UnitAttr,
+)
+from air.dialects.affine import apply as affine_apply
+from air.dialects.air import (
+    MemorySpace,
+    T,
+    dma_memcpy_nd,
+    herd,
+    launch,
+    module_builder,
+    segment,
+)
+from air.dialects.func import CallOp, FuncOp
+from air.dialects.memref import AllocOp, DeallocOp
+from air.dialects.scf import for_, yield_
+from air.backend.xrt import XRTBackend
+from air.backend.xrt_runner import XRTRunner
+
+KERNEL_OBJ_NAME = "mv_int4_bf16.o"
+
+
+def packed_tile_bytes(n_tile, k_chunk, gs):
+    """Return (q_bytes, s_bytes, z_bytes, total) sizes of one packed weight tile.
+
+    Q stores two 4-bit weights per byte; S and Z store one 2-byte scale and one
+    1-byte zero point per (group, column), with k_chunk // gs groups per tile.
+    """
+    n_gpc = k_chunk // gs
+    q_bytes = n_tile * (k_chunk // 2)
+    s_bytes = n_gpc * n_tile * 2
+    z_bytes = n_gpc * n_tile
+    return q_bytes, s_bytes, z_bytes, q_bytes + s_bytes + z_bytes
+
+
+def pack_inputs(W_q, W_s, W_z, M, K, N, GS, N_TILE, K_CHUNK):
+    """Pack per-(n_outer, k_outer) Q+S+Z tiles into [N_div, K_div, tile_bytes].
+
+    W_q [N, K/2] u8 (output-major), W_s [K/GS, N] bf16, W_z [K/GS, N] u8.
+    Each tile holds its Q bytes, then its S bytes, then its Z bytes. M is
+    unused; it is kept so existing callers do not break.
+    """
+    n_gpc = K_CHUNK // GS
+    q_bytes, s_bytes, _, tile_bytes = packed_tile_bytes(N_TILE, K_CHUNK, GS)
+    N_div = N // N_TILE
+    K_div = K // K_CHUNK
+    packed = np.zeros((N_div, K_div, tile_bytes), dtype=np.uint8)
+    for n_outer in range(N_div):
+        col_off = n_outer * N_TILE
+        for k_outer in range(K_div):
+            q_col_byte = k_outer * (K_CHUNK // 2)
+            g_off = k_outer * n_gpc
+            q_tile = W_q[
+                col_off : col_off + N_TILE,
+                q_col_byte : q_col_byte + (K_CHUNK // 2),
+            ]
+            s_tile = W_s[g_off : g_off + n_gpc, col_off : col_off + N_TILE]
+            z_tile = W_z[g_off : g_off + n_gpc, col_off : col_off + N_TILE]
+            p = packed[n_outer, k_outer]
+            p[0:q_bytes] = np.ascontiguousarray(q_tile).view(np.uint8).reshape(-1)
+            p[q_bytes : q_bytes + s_bytes] = (
+                np.ascontiguousarray(s_tile).view(np.uint8).reshape(-1)
+            )
+            p[q_bytes + s_bytes :] = (
+                np.ascontiguousarray(z_tile).view(np.uint8).reshape(-1)
+            )
+    return packed
+
+
+def cpu_reference(W_q, W_s, W_z, A):
+    """Return A @ dequant(W) as bf16, accumulated in float32 on the host.
+
+    dequant(W)[k, n] = (nibble(k, n) - W_z[g, n]) * W_s[g, n] with g = k // gs;
+    even k reads the low nibble of W_q[n, k // 2], odd k the high nibble.
+    """
+    N_ = W_q.shape[0]
+    K_ = A.shape[1]
+    n_groups = W_s.shape[0]
+    gs = K_ // n_groups
+    Af = A.astype(np.float32)
+    W_s_f = W_s.astype(np.float32)
+    W_z_i = W_z.astype(np.int32)
+    # Unpack both nibbles of every byte at once instead of looping over all
+    # N*K elements in Python (millions of iterations at Llama scale).
+    nib = np.empty((N_, K_), dtype=np.int32)
+    nib[:, 0::2] = W_q & 0x0F
+    nib[:, 1::2] = (W_q >> 4) & 0x0F
+    g = np.arange(K_) // gs  # quantization group of every k
+    W_dq = ((nib.T - W_z_i[g, :]) * W_s_f[g, :]).astype(np.float32)
+    C = Af @ W_dq
+    return C.astype(bfloat16)
+
+
+@module_builder
+def build_module(m, k, n, gs, tile_m, tile_k_l2, tile_k_l1, tile_n, herd_m, herd_n):
+    """Build the AIR module computing C[m, n] = A[m, k] @ dequant(W)[k, n].
+
+    Each launch iteration produces a (herd_m * tile_m) x (herd_n * tile_n)
+    block of C. Every herd tile accumulates its tile_m x tile_n result in f32
+    over the K chunks of size tile_k_l1, then narrows it to bf16 and drains it.
+    """
+    assert m % (tile_m * herd_m) == 0
+    assert n % (tile_n * herd_n) == 0
+    # The herd zeroes its accumulator and overwrites the L2 C tile on every
+    # iteration of the segment-level K loop, so that loop must run exactly once.
+    assert tile_k_l2 == k, (
+        f"tile_k_l2 ({tile_k_l2}) must equal k ({k}): partial sums are not "
+        f"carried across segment-level K iterations"
+    )
+    assert tile_k_l2 % tile_k_l1 == 0
+    assert tile_k_l1 % gs == 0
+    # Kernel-side static_assert constraints from mv_int4_bf16.cc:
+    #   mm_int4_bf16_mmul_impl: tile_m/n % 16 (2x2 mmul expansion),
+    #                           k_chunk % 8 (mmul k), gs % R=32
+    #   zero_vectorized_bf16_mn: (tile_m * tile_n) % VW=32
+    assert (
+        tile_m % 16 == 0 and tile_n % 16 == 0
+    ), "tile_m and tile_n must each be multiples of 16 (2x2 mmul expansion)"
+    assert tile_k_l1 % 8 == 0, "tile_k_l1 must be a multiple of 8 (mmul k size)"
+    assert gs % 32 == 0, "gs must be a multiple of dequant inner-vector width 32"
+    assert (tile_m * tile_n) % 32 == 0, (
+        f"tile_m*tile_n ({tile_m}*{tile_n}={tile_m * tile_n}) must be a multiple "
+        f"of vector width 32 for zero_vectorized_bf16_mn"
+    )
+
+    _, _, _, tile_bytes = packed_tile_bytes(tile_n, tile_k_l1, gs)
+    k_per_l2 = tile_k_l2 // tile_k_l1
+    N_div = n // tile_n
+    K_div = k // tile_k_l1
+
+    bf16_ty = BF16Type.get()
+    f32_ty = F32Type.get()
+    u8_ty = IntegerType.get_signless(8)
+
+    A_l3_ty = MemRefType.get([m, k], bf16_ty)
+    B_l3_ty = MemRefType.get([N_div, K_div, tile_bytes], u8_ty)
+    C_l3_ty = MemRefType.get([m, n], bf16_ty)
+
+    l1_ms = IntegerAttr.get(T.i32(), MemorySpace.L1)
+    l2_ms = IntegerAttr.get(T.i32(), MemorySpace.L2)
+
+    A_l2_ty = MemRefType.get(
+        [herd_m, 1, tile_m, tile_k_l2], bf16_ty, memory_space=l2_ms
+    )
+    B_l2_ty = MemRefType.get(
+        [1, herd_n, k_per_l2, tile_bytes], u8_ty, memory_space=l2_ms
+    )
+    C_l2_ty = MemRefType.get(
+        [herd_m, herd_n, tile_m, tile_n], bf16_ty, memory_space=l2_ms
+    )
+
+    A_l1_ty = MemRefType.get([tile_m, tile_k_l1], bf16_ty, memory_space=l1_ms)
+    B_l1_ty = MemRefType.get([tile_bytes], u8_ty, memory_space=l1_ms)
+    # L1 C accumulator: f32. Kept across the host K-chunk loop so partial sums
+    # don't bf16-truncate between calls. Converted to bf16 once at the end.
+    C_l1_acc_ty = MemRefType.get([tile_m, tile_n], f32_ty, memory_space=l1_ms)
+    C_l1_drain_ty = MemRefType.get([tile_m, tile_n], bf16_ty, memory_space=l1_ms)
+
+    zero_func = FuncOp(
+        "zero_vectorized_f32_mn",
+        ([C_l1_acc_ty], []),
+        visibility="private",
+    )
+    zero_func.attributes["link_with"] = StringAttr.get(KERNEL_OBJ_NAME)
+    zero_func.attributes["llvm.emit_c_interface"] = UnitAttr.get()
+
+    matmul_func = FuncOp(
+        "matmul_int4_bf16_packed_f32",
+        ([B_l1_ty, A_l1_ty, C_l1_acc_ty], []),
+        visibility="private",
+    )
+    matmul_func.attributes["link_with"] = StringAttr.get(KERNEL_OBJ_NAME)
+    matmul_func.attributes["llvm.emit_c_interface"] = UnitAttr.get()
+
+    f32_to_bf16_func = FuncOp(
+        "f32_to_bf16_mn",
+        ([C_l1_acc_ty, C_l1_drain_ty], []),
+        visibility="private",
+    )
+    f32_to_bf16_func.attributes["link_with"] = StringAttr.get(KERNEL_OBJ_NAME)
+    f32_to_bf16_func.attributes["llvm.emit_c_interface"] = UnitAttr.get()
+
+    @FuncOp.from_py_func(A_l3_ty, B_l3_ty, C_l3_ty)
+    def matmul_int4_packed(arg_a, arg_b, arg_c):
+        launch_size = [m // tile_m // herd_m, n // tile_n // herd_n]
+
+        @launch(operands=[arg_a, arg_b, arg_c], sizes=launch_size)
+        def launch_body(li, lj, lsx, lsy, l3_a, l3_b, l3_c):
+            @segment(name="seg", operands=[li, lj, l3_a, l3_b, l3_c])
+            def segment_body(li_s, lj_s, l3_a_s, l3_b_s, l3_c_s):
+                l2_a = AllocOp(A_l2_ty, [], [])
+                l2_b = AllocOp(B_l2_ty, [], [])
+                l2_c = AllocOp(C_l2_ty, [], [])
+
+                ix_to_row = AffineMap.get(
+                    0,
+                    1,
+                    [
+                        AffineExpr.get_mul(
+                            AffineSymbolExpr.get(0),
+                            AffineConstantExpr.get(tile_m * herd_m),
+                        )
+                    ],
+                )
+                iy_to_n_outer = AffineMap.get(
+                    0,
+                    1,
+                    [
+                        AffineExpr.get_mul(
+                            AffineSymbolExpr.get(0), AffineConstantExpr.get(herd_n)
+                        )
+                    ],
+                )
+                row_off = affine_apply(ix_to_row, [li_s])
+                n_outer_off = affine_apply(iy_to_n_outer, [lj_s])
+
+                k_l2_to_k = AffineMap.get(
+                    0,
+                    1,
+                    [
+                        AffineExpr.get_mul(
+                            AffineSymbolExpr.get(0), AffineConstantExpr.get(tile_k_l2)
+                        )
+                    ],
+                )
+                k_l2_to_chunk = AffineMap.get(
+                    0,
+                    1,
+                    [
+                        AffineExpr.get_mul(
+                            AffineSymbolExpr.get(0), AffineConstantExpr.get(k_per_l2)
+                        )
+                    ],
+                )
+                k_chunk_off_l1_map = AffineMap.get(
+                    0,
+                    1,
+                    [
+                        AffineExpr.get_mul(
+                            AffineSymbolExpr.get(0), AffineConstantExpr.get(tile_k_l1)
+                        )
+                    ],
+                )
+
+                for i in for_(0, k // tile_k_l2):
+                    k_l2_off = affine_apply(k_l2_to_k, [i])
+                    k_chunk_off = affine_apply(k_l2_to_chunk, [i])
+
+                    dma_memcpy_nd(
+                        l2_a,
+                        l3_a_s,
+                        src_offsets=[0, 0, row_off, k_l2_off],
+                        src_sizes=[herd_m, 1, tile_m, tile_k_l2],
+                        src_strides=[k * tile_m, tile_k_l2, k, 1],
+                    )
+                    dma_memcpy_nd(
+                        l2_b,
+                        l3_b_s,
+                        src_offsets=[0, n_outer_off, k_chunk_off, 0],
+                        src_sizes=[1, herd_n, k_per_l2, tile_bytes],
+                        src_strides=[
+                            K_div * tile_bytes,
+                            K_div * tile_bytes,
+                            tile_bytes,
+                            1,
+                        ],
+                    )
+
+                    @herd(
+                        name="herd_0",
+                        sizes=[herd_m, herd_n],
+                        operands=[l2_a, l2_b, l2_c],
+                    )
+                    def compute_body(_tx, _ty, _sx, _sy, _l2a, _l2b, _l2c):
+                        _l1_a = AllocOp(A_l1_ty, [], [])
+                        _l1_b = AllocOp(B_l1_ty, [], [])
+                        _l1_c_acc = AllocOp(C_l1_acc_ty, [], [])
+                        _l1_c_drain = AllocOp(C_l1_drain_ty, [], [])
+                        CallOp(zero_func, [_l1_c_acc])
+                        for j in for_(0, k_per_l2):
+                            k1_off = affine_apply(k_chunk_off_l1_map, [j])
+                            dma_memcpy_nd(
+                                _l1_a,
+                                _l2a,
+                                src_offsets=[_tx, 0, 0, k1_off],
+                                src_sizes=[1, 1, tile_m, tile_k_l1],
+                                src_strides=[
+                                    tile_m * tile_k_l2,
+                                    tile_m * tile_k_l2,
+                                    tile_k_l2,
+                                    1,
+                                ],
+                            )
+                            dma_memcpy_nd(
+                                _l1_b,
+                                _l2b,
+                                src_offsets=[0, _ty, j, 0],
+                                src_sizes=[1, 1, 1, tile_bytes],
+                                src_strides=[
+                                    herd_n * k_per_l2 * tile_bytes,
+                                    k_per_l2 * tile_bytes,
+                                    tile_bytes,
+                                    1,
+                                ],
+                            )
+                            CallOp(matmul_func, [_l1_b, _l1_a, _l1_c_acc])
+                            yield_([])
+                        # Narrow the f32 accumulator to bf16 once all K chunks are in.
+                        CallOp(f32_to_bf16_func, [_l1_c_acc, _l1_c_drain])
+                        dma_memcpy_nd(
+                            _l2c,
+                            _l1_c_drain,
+                            dst_offsets=[_tx, _ty, 0, 0],
+                            dst_sizes=[1, 1, tile_m, tile_n],
+                            dst_strides=[
+                                herd_n * tile_m * tile_n,
+                                tile_m * tile_n,
+                                tile_n,
+                                1,
+                            ],
+                        )
+                        DeallocOp(_l1_a)
+                        DeallocOp(_l1_b)
+                        DeallocOp(_l1_c_acc)
+                        DeallocOp(_l1_c_drain)
+
+                    yield_([])
+
+                col_off_map = AffineMap.get(
+                    0,
+                    1,
+                    [
+                        AffineExpr.get_mul(
+                            AffineSymbolExpr.get(0),
+                            AffineConstantExpr.get(tile_n * herd_n),
+                        )
+                    ],
+                )
+                col_off = affine_apply(col_off_map, [lj_s])
+                dma_memcpy_nd(
+                    l3_c_s,
+                    l2_c,
+                    dst_offsets=[row_off, col_off],
+                    dst_sizes=[herd_m * tile_m, herd_n * tile_n],
+                    dst_strides=[n, 1],
+                    src_offsets=[0, 0, 0, 0],
+                    src_sizes=[herd_m, tile_m, herd_n, tile_n],
+                    src_strides=[
+                        herd_n * tile_m * tile_n,
+                        tile_n,
+                        tile_m * tile_n,
+                        1,
+                    ],
+                )
+
+                DeallocOp(l2_a)
+                DeallocOp(l2_b)
+                DeallocOp(l2_c)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--m", type=int, default=64)
+    parser.add_argument("--k", type=int, default=128)
+    parser.add_argument("--n", type=int, default=128)
+    parser.add_argument("--gs", type=int, default=128)
+    parser.add_argument("--tile-m", type=int, default=16, dest="tile_m")
+    # tile_k_l2 defaults to k: build_module requires the two to match.
+    parser.add_argument("--tile-k-l2", type=int, default=None, dest="tile_k_l2")
+    parser.add_argument("--tile-k-l1", type=int, default=128, dest="tile_k_l1")
+    parser.add_argument("--tile-n", type=int, default=16, dest="tile_n")
+    parser.add_argument("--herd-m", type=int, default=2, dest="herd_m")
+    parser.add_argument("--herd-n", type=int, default=4, dest="herd_n")
+    parser.add_argument("-v", "--verbose", action="store_true")
+    parser.add_argument("-p", "--print-module-only", action="store_true")
+    parser.add_argument(
+        "--compile-mode",
+        choices=["compile-and-run", "compile-only"],
+        default="compile-and-run",
+        dest="compile_mode",
+    )
+    args = parser.parse_args()
+    if args.tile_k_l2 is None:
+        args.tile_k_l2 = args.k
+
+    module = build_module(
+        args.m,
+        args.k,
+        args.n,
+        args.gs,
+        args.tile_m,
+        args.tile_k_l2,
+        args.tile_k_l1,
+        args.tile_n,
+        args.herd_m,
+        args.herd_n,
+    )
+    if args.print_module_only:
+        print(module)
+        sys.exit(0)
+
+    # compile-only never uses the test data, so exit before generating it.
+    if args.compile_mode == "compile-only":
+        backend = XRTBackend(
+            verbose=args.verbose,
+            omit_while_true_loop=False,
+            output_format="xclbin",
+            runtime_loop_tiling_sizes=[2, 2],
+            stack_size=16384,
+        )
+        backend.compile(module)
+        backend.unload()
+        sys.exit(0)
+
+    np.random.seed(42)
+    W_q_unp = np.random.randint(0, 16, size=(args.n, args.k), dtype=np.uint8)
+    W_q = (W_q_unp[:, 0::2] | (W_q_unp[:, 1::2] << 4)).astype(np.uint8)
+    n_groups = args.k // args.gs
+    W_s = np.random.uniform(0.005, 0.02, size=(n_groups, args.n)).astype(bfloat16)
+    W_z = np.random.randint(7, 9, size=(n_groups, args.n), dtype=np.uint8)
+    A = np.random.randn(args.m, args.k).astype(bfloat16)
+
+    PACKED = pack_inputs(
+        W_q, W_s, W_z, args.m, args.k, args.n, args.gs, args.tile_n, args.tile_k_l1
+    )
+    C_ref = cpu_reference(W_q, W_s, W_z, A)
+
+    runner = XRTRunner(
+        verbose=args.verbose,
+        omit_while_true_loop=False,
+        output_format="xclbin",
+        runtime_loop_tiling_sizes=[2, 2],
+        stack_size=16384,
+    )
+    sys.exit(
+        runner.run_test(
+            module,
+            inputs=[A, PACKED],
+            expected_outputs=[C_ref],
+            rtol=0.1,
+            atol=0.05,
+            # bf16 floor: at large K and tight atol a small fraction of
+            # elements land just outside atol while correlation stays > 0.9999.
+            max_mismatch_percentage=0.05,
+            min_correlation=0.999,
+        )
+    )
diff --git a/programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_llama_qproj_seq32_peano.lit b/programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_llama_qproj_seq32_peano.lit
new file mode 100644
index 000000000..6c5a3d1af
--- /dev/null
+++ b/programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_llama_qproj_seq32_peano.lit
@@ -0,0 +1,12 @@
+// (c) Copyright 2026 Advanced Micro Devices, Inc.
+// SPDX-License-Identifier: MIT +// +// REQUIRES: ryzen_ai_npu2, peano +// +// RUN: mkdir -p test_int4_gemm_npu2_llama_qproj_seq32_peano +// RUN: cd test_int4_gemm_npu2_llama_qproj_seq32_peano +// RUN: make -f %S/Makefile clean +// +// Llama-3.2-1B Q-projection at prefill seq=32: A[32,2048] @ dequant(W)[2048,2048]. +// RUN: make -f %S/Makefile run_packed M=32 K=2048 N=2048 HERD_M=2 HERD_N=4 PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR | FileCheck %s +// CHECK: PASS! diff --git a/programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_small_peano.lit b/programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_small_peano.lit new file mode 100644 index 000000000..d79a54428 --- /dev/null +++ b/programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_small_peano.lit @@ -0,0 +1,12 @@ +// (c) Copyright 2026 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT +// +// REQUIRES: ryzen_ai_npu2, peano +// +// RUN: mkdir -p test_int4_gemm_npu2_small_peano +// RUN: cd test_int4_gemm_npu2_small_peano +// RUN: make -f %S/Makefile clean +// +// int4-AWQ GEMM smoke: M=64 K=128 N=128 on a 2x4 herd. +// RUN: make -f %S/Makefile run_packed M=64 K=128 N=128 HERD_M=2 HERD_N=4 PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR | FileCheck %s +// CHECK: PASS! diff --git a/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc b/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc index 3f4776a72..63543a825 100644 --- a/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc +++ b/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc @@ -1,9 +1,9 @@ -//===- mv_int4_bf16.cc - AWQ uint4 weight x bf16 activation matvec --------===// +//===- mv_int4_bf16.cc - AWQ uint4 weight x bf16 activation matvec/matmul -===// // // Copyright (C) 2026, Advanced Micro Devices, Inc. 
// SPDX-License-Identifier: MIT // -// Per-tile micro-kernels for the int4-AWQ GEMV examples: +// Per-tile micro-kernels for the int4-AWQ GEMV and GEMM examples: // - matvec_int4_bf16_packed(packed, b, c): // c[0..m] += dequant(A)[m, k] @ b[k] // where dequant(A)[r, k] = (q[r, k] - z[r, g(k)]) * s_a[r, g(k)], @@ -13,7 +13,12 @@ // [ S : k/gs * m bf16 ] // [ Z : k/gs * m uint8 ] // Offsets are 32-byte aligned when m and gs are powers of two ≥ 16. +// - matmul_int4_bf16_packed_f32(packed, a, c): +// c[0..m, 0..n] += a[0..m, 0..k] @ dequant(W)[0..k, 0..n] (c is f32) +// W laid out as [n_tile, k/2] (output-major) — same packed layout +// as the GEMV with M renamed to N. Used by the int4-AWQ GEMM prefill. // - zero_vectorized_bf16(c): c[0..m] = 0 +// - zero_vectorized_bf16_mn(c): c[0..m*n] = 0 (vectorized for GEMM) // - partial_plus_r_bf16(p, r, off, d): d[0..m] = p[0..m] + r[off..off+m] // //===----------------------------------------------------------------------===// @@ -24,6 +29,9 @@ #ifndef DIM_M #define DIM_M 8 #endif +#ifndef DIM_N +#define DIM_N 16 +#endif #ifndef DIM_K #define DIM_K 2048 #endif @@ -31,7 +39,17 @@ #define DIM_GS 128 #endif +// matmul_int4_bf16_packed_f32's K is the per-call k_chunk (matches the host +// builder's TILE_K_L1), not the matvec's full K. Default to 128 so a build +// that only consumes the matvec (e.g. DIM_K=2048) doesn't instantiate the +// matmul template at a k_chunk that overflows L1 scratch. +#ifndef DIM_K_CHUNK +#define DIM_K_CHUNK 128 +#endif + static_assert(DIM_K % DIM_GS == 0, "DIM_K must be a multiple of DIM_GS"); +static_assert(DIM_K_CHUNK % DIM_GS == 0, + "DIM_K_CHUNK must be a multiple of DIM_GS"); // int4-AWQ matvec inner kernel. One accfloat accumulator per output row spans // the full K. Within each group we accumulate raw (w_bf16 × b) into a per- @@ -87,6 +105,239 @@ static void zero_impl(bfloat16 *__restrict c) { c[i] = (bfloat16)0.0f; } +// Vectorized zero for the GEMM C tile. 
Peano auto-vectorizes the scalar +// `for c[i] = 0` loop with a stride-4 store that skips every 4th element +// once the buffer is >= one full vector wide — the bug only manifests +// across repeated kernel calls that read-modify c[]. Explicit aie::store_v +// of aie::zeros avoids it. +template +static void zero_mn_impl(bfloat16 *__restrict c) { + constexpr unsigned VW = 32; + constexpr unsigned NTOT = m_tile * n_tile; + static_assert(NTOT % VW == 0, + "m_tile*n_tile must be a multiple of vector width"); + aie::vector zv = aie::zeros(); + for (unsigned i = 0; i < NTOT; i += VW) + aie::store_v(c + i, zv); +} + +// int4-AWQ GEMM inner kernel using aie::mmul<8,8,8,bf16,bf16,accfloat>. +// Dequant produces UNSCALED bf16 W (just nibble unpack + zero subtract). +// Per (m_b, n_b): preload the f32 c tile, then for each group do an unscaled +// MMUL, convert to f32 vec, multiply by f32 scale (scalar, no extra bf16 +// truncate), and accumulate into the c tile. Across host K-chunk calls the +// c buffer stays f32 (bf16-GEMM accum pattern). Matches mac-kernel's +// "truncate at group sum level, not per W element" precision pattern. 
+template +void mm_int4_bf16_mmul_impl(uint8_t *__restrict a_q, bfloat16 *__restrict a_s, + uint8_t *__restrict a_z, bfloat16 *__restrict a, + float *__restrict c) { + constexpr unsigned r = 8, s = 8, t = 8; + constexpr unsigned MB = m_tile / r; + constexpr unsigned NB = n_tile / t; + constexpr unsigned KB = k_chunk / s; + constexpr unsigned KB_PER_G = gs / s; + constexpr unsigned NG = k_chunk / gs; + static_assert(m_tile % (2 * r) == 0, + "m_tile must be multiple of 16 (2x mmul m for 2x2 expansion)"); + static_assert(n_tile % (2 * t) == 0, + "n_tile must be multiple of 16 (2x mmul n for 2x2 expansion)"); + static_assert(k_chunk % s == 0, "k_chunk must be multiple of 8"); + static_assert(gs % s == 0, "gs must be multiple of mmul k-tile (8)"); + static_assert(gs % R == 0, "gs must be multiple of R"); + + ::aie::set_rounding(aie::rounding_mode::conv_even); + + using MMUL = aie::mmul; + + alignas(32) bfloat16 a_pack[KB * MB * r * s]; + alignas(32) bfloat16 b_pack[NB * KB * s * t]; + + // Pack A row-major [m_tile][k_chunk] → [KB][MB][r][s]. + for (unsigned k_b = 0; k_b < KB; k_b++) { + for (unsigned m_b = 0; m_b < MB; m_b++) { + bfloat16 *dst = a_pack + (k_b * MB + m_b) * (r * s); + for (unsigned m_i = 0; m_i < r; m_i++) { + aie::vector v = + aie::load_v(a + (m_b * r + m_i) * k_chunk + k_b * s); + aie::store_v(dst + m_i * s, v); + } + } + } + + // Dequant W → b_pack [NB][KB][t][s] = [NB][KB][n_i][k_i]. For fixed + // (n_b, n_i, k_b) the 8 k_i positions are contiguous → 8-wide vector + // store. mmul reads this with aie::transpose at load (mmul B wants the + // [s][t] order which is the transpose of what dequant produces). 
+ for (unsigned g = 0; g < NG; g++) { + for (unsigned n = 0; n < n_tile; n++) { + aie::vector zv = + aie::broadcast((int8_t)a_z[g * n_tile + n]); + unsigned n_b = n / t; + unsigned n_i = n % t; + const uint8_t *__restrict aq_n = a_q + n * (k_chunk / 2); + + for (unsigned i = 0; i < gs / R; i++) { + unsigned k_base = g * gs + i * R; + aie::vector pk = aie::load_v(aq_n + k_base / 2); + aie::vector w_i8 = + pk.template cast_to().template unpack_sign(false); + w_i8 = aie::sub(w_i8, zv); + aie::vector w_bf16 = aie::to_float(w_i8, 0); + + // R=32 spans R/s=4 mmul k blocks. For each block, store 8 contiguous + // bf16 to b_pack[n_b][k_b][n_i][k_i = 0..7]. + unsigned k_b_base = k_base / s; +#pragma clang loop unroll(full) + for (unsigned j = 0; j < R / s; j++) { + unsigned k_b = k_b_base + j; + bfloat16 *dst = b_pack + n_b * (KB * t * s) + k_b * (t * s) + n_i * s; + aie::vector chunk = w_bf16.template extract(j); + aie::store_v(dst, chunk); + } + } + } + } + + // 2x2 MMUL expansion: each iteration of the (m_b, n_b) super-tile holds + // 4 independent accumulators (C00/C01/C10/C11) so the inner kg loop + // issues 4 MACs/iter against 4 different register-file destinations — + // no serial dep chain on a single accumulator. Each loaded A and B + // vector is reused by 2 MACs. 
+ for (unsigned m_b2 = 0; m_b2 < MB; m_b2 += 2) { + for (unsigned n_b2 = 0; n_b2 < NB; n_b2 += 2) { + float *cdst00 = c + (m_b2 * r) * n_tile + n_b2 * t; + float *cdst01 = c + (m_b2 * r) * n_tile + (n_b2 + 1) * t; + float *cdst10 = c + ((m_b2 + 1) * r) * n_tile + n_b2 * t; + float *cdst11 = c + ((m_b2 + 1) * r) * n_tile + (n_b2 + 1) * t; + + alignas(32) float c_acc_00[r * t]; + alignas(32) float c_acc_01[r * t]; + alignas(32) float c_acc_10[r * t]; + alignas(32) float c_acc_11[r * t]; + for (unsigned m_i = 0; m_i < r; m_i++) { + aie::store_v(c_acc_00 + m_i * t, aie::load_v(cdst00 + m_i * n_tile)); + aie::store_v(c_acc_01 + m_i * t, aie::load_v(cdst01 + m_i * n_tile)); + aie::store_v(c_acc_10 + m_i * t, aie::load_v(cdst10 + m_i * n_tile)); + aie::store_v(c_acc_11 + m_i * t, aie::load_v(cdst11 + m_i * n_tile)); + } + + for (unsigned g = 0; g < NG; g++) { + aie::vector zero_init = aie::zeros(); + MMUL C00(zero_init), C01(zero_init), C10(zero_init), C11(zero_init); + + const bfloat16 *__restrict pA0 = + a_pack + (g * KB_PER_G * MB + m_b2) * (r * s); + const bfloat16 *__restrict pA1 = + a_pack + (g * KB_PER_G * MB + m_b2 + 1) * (r * s); + const bfloat16 *__restrict pB0 = + b_pack + (n_b2 * KB + g * KB_PER_G) * (s * t); + const bfloat16 *__restrict pB1 = + b_pack + ((n_b2 + 1) * KB + g * KB_PER_G) * (s * t); + + chess_prepare_for_pipelining chess_loop_range( + 1, ) for (unsigned kg = 0; kg < KB_PER_G; kg++) { + aie::vector A0 = + aie::load_v(pA0); + pA0 += MB * (r * s); + aie::vector A1 = + aie::load_v(pA1); + pA1 += MB * (r * s); + // b_pack stores tiles in [t=n_i][s=k_i] order (dequant friendly); + // mmul wants [s=k_i][t=n_i], so transpose per load. 
+ aie::vector B0 = + aie::transpose(aie::load_v(pB0), t, s); + pB0 += s * t; + aie::vector B1 = + aie::transpose(aie::load_v(pB1), t, s); + pB1 += s * t; + C00.mac(A0, B0); + C01.mac(A0, B1); + C10.mac(A1, B0); + C11.mac(A1, B1); + } + + // Per-group scale fold (cold path — runs NG=1 times for the + // production gs=k_chunk config). Two scale broadcasts (one per + // n-block); each applies to both m-block rows. + alignas(32) bfloat16 scale0_buf[t], scale1_buf[t]; + for (unsigned n_i = 0; n_i < t; n_i++) { + scale0_buf[n_i] = a_s[g * n_tile + n_b2 * t + n_i]; + scale1_buf[n_i] = a_s[g * n_tile + (n_b2 + 1) * t + n_i]; + } + aie::vector s0 = aie::load_v(scale0_buf); + aie::vector s1 = aie::load_v(scale1_buf); + alignas(32) bfloat16 s0_tile_buf[r * t]; + alignas(32) bfloat16 s1_tile_buf[r * t]; + for (unsigned m_i = 0; m_i < r; m_i++) { + aie::store_v(s0_tile_buf + m_i * t, s0); + aie::store_v(s1_tile_buf + m_i * t, s1); + } + aie::vector s0_tile = aie::load_v(s0_tile_buf); + aie::vector s1_tile = aie::load_v(s1_tile_buf); + + auto fold = [&](MMUL &C, aie::vector &scale_tile, + float *c_acc) { + aie::vector c_bf16 = + C.template to_vector(); + aie::accum scaled_acc = aie::mul(c_bf16, scale_tile); + aie::vector scaled_f32 = + scaled_acc.template to_vector(); + for (unsigned m_i = 0; m_i < r; m_i++) { + aie::vector inc = scaled_f32.template extract(m_i); + aie::vector old = aie::load_v(c_acc + m_i * t); + aie::vector sum = aie::add(old, inc); + aie::store_v(c_acc + m_i * t, sum); + } + }; + fold(C00, s0_tile, c_acc_00); + fold(C01, s1_tile, c_acc_01); + fold(C10, s0_tile, c_acc_10); + fold(C11, s1_tile, c_acc_11); + } + + for (unsigned m_i = 0; m_i < r; m_i++) { + aie::store_v(cdst00 + m_i * n_tile, aie::load_v(c_acc_00 + m_i * t)); + aie::store_v(cdst01 + m_i * n_tile, aie::load_v(c_acc_01 + m_i * t)); + aie::store_v(cdst10 + m_i * n_tile, aie::load_v(c_acc_10 + m_i * t)); + aie::store_v(cdst11 + m_i * n_tile, aie::load_v(c_acc_11 + m_i * t)); + } + } + } +} + +// f32 
zero for the GEMM C accumulator (kept in f32 across host K-chunk +// iterations to avoid bf16 truncation of partial sums). +template +static void zero_mn_f32_impl(float *__restrict c) { + constexpr unsigned VW = 16; + constexpr unsigned NTOT = m_tile * n_tile; + static_assert(NTOT % VW == 0, + "m_tile*n_tile must be a multiple of f32 vector width"); + aie::vector zv = aie::zeros(); + for (unsigned i = 0; i < NTOT; i += VW) + aie::store_v(c + i, zv); +} + +// f32 → bf16 narrowing for the final L1 C drain. Run after the host's K-loop +// completes; one call per (m_tile × n_tile) tile. +template +static void f32_to_bf16_mn_impl(const float *__restrict src, + bfloat16 *__restrict dst) { + ::aie::set_rounding(aie::rounding_mode::conv_even); + constexpr unsigned VW = 16; + constexpr unsigned NTOT = m_tile * n_tile; + static_assert(NTOT % VW == 0, "m_tile*n_tile must be a multiple of VW"); + for (unsigned i = 0; i < NTOT; i += VW) { + aie::vector v = aie::load_v(src + i); + aie::vector vb; + for (unsigned j = 0; j < VW; j++) + vb[j] = (bfloat16)v[j]; + aie::store_v(dst + i, vb); + } +} + template static void partial_plus_r_impl(const bfloat16 *__restrict partial, const bfloat16 *__restrict r_full, int offset, @@ -110,6 +361,33 @@ void matvec_int4_bf16_packed(uint8_t *packed, bfloat16 *b, bfloat16 *c) { void zero_vectorized_bf16(bfloat16 *c) { zero_impl(c); } +// Matmul-only entries are guarded so matvec builds (DIM_M=8) don't +// instantiate mm_int4_bf16_mmul_impl, which static-asserts m_tile and +// n_tile are multiples of 16 (2x mmul m/n for the 2x2 expansion). +#if DIM_M >= 16 && DIM_N >= 16 +void zero_vectorized_bf16_mn(bfloat16 *c) { zero_mn_impl(c); } + +void zero_vectorized_f32_mn(float *c) { zero_mn_f32_impl(c); } + +void f32_to_bf16_mn(float *src, bfloat16 *dst) { + f32_to_bf16_mn_impl(src, dst); +} + +// Packed-BO GEMM entry, f32 output. 
Same Q+S+Z packing as the GEMV +// (output-major W); c is kept in f32 across host-side K-chunk iterations so +// the per-K-chunk partial sums don't bf16-truncate. Convert to bf16 once at +// the end with f32_to_bf16_mn. +void matmul_int4_bf16_packed_f32(uint8_t *packed, bfloat16 *a, float *c) { + constexpr unsigned Q_BYTES = DIM_N * (DIM_K_CHUNK / 2); + constexpr unsigned S_BYTES = (DIM_K_CHUNK / DIM_GS) * DIM_N * 2; + uint8_t *a_q = packed; + bfloat16 *a_s = reinterpret_cast(packed + Q_BYTES); + uint8_t *a_z = packed + Q_BYTES + S_BYTES; + mm_int4_bf16_mmul_impl(a_q, a_s, a_z, a, + c); +} +#endif + void partial_plus_r_bf16(bfloat16 *partial, bfloat16 *r_full, int offset, bfloat16 *d) { partial_plus_r_impl(partial, r_full, offset, d);