From 35568f0debfdfd3abe4f34bb4555939528b03444 Mon Sep 17 00:00:00 2001 From: erweiw Date: Sun, 31 May 2026 22:45:40 -0700 Subject: [PATCH 01/10] [programming_examples] int4-AWQ GEMM kernel + standalone example MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds matmul_int4_bf16_packed alongside the existing matvec entries in mv_int4_bf16.cc, plus a programming_examples/matrix_multiplication/int4_awq host builder, Makefile and lit tests. The GEMM reuses the GEMV's Q+S+Z per-tile packed BO layout (output-major), extended to m_tile activation rows; one .o file serves both decode (GEMV) and prefill (GEMM). Also adds zero_vectorized_bf16_mn — explicit aie::store_v of aie::zeros for the larger GEMM C tile. Peano auto-vectorizes a scalar `for c[i]=0` loop on AIE2P with a stride-4 store that skips every 4th element once the buffer is >= one full vector wide; manifests as repeated kernel calls reading stale c[]. The existing GEMV's DIM_M=8 stays under the vectorization threshold so it was unaffected. Tested on NPU2: - Smoke (M=32 K=128 N=64, exercises M_div>1 + N_div>1): corr 0.999997 - Llama Q-proj at prefill seq=32 (M=32 K=2048 N=2048): corr 0.999985 - GEMV regression (M=2048 K=2048): unchanged, corr 0.999997 Co-Authored-By: Claude Opus 4.7 (1M context) --- programming_examples/generate_readme.py | 6 + .../matrix_multiplication/int4_awq/Makefile | 58 +++ .../int4_awq/matmul_int4_packed.py | 370 ++++++++++++++++++ ...un_packed_npu2_llama_qproj_seq32_peano.lit | 12 + .../int4_awq/run_packed_npu2_small_peano.lit | 12 + .../int4_awq/mv_int4_bf16.cc | 92 ++++- 6 files changed, 548 insertions(+), 2 deletions(-) create mode 100644 programming_examples/matrix_multiplication/int4_awq/Makefile create mode 100644 programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py create mode 100644 programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_llama_qproj_seq32_peano.lit create mode 100644 programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_small_peano.lit diff --git a/programming_examples/generate_readme.py b/programming_examples/generate_readme.py index 3363e5390..b05dc46d8 100644 --- a/programming_examples/generate_readme.py +++ b/programming_examples/generate_readme.py @@ -60,6 +60,12 @@ "path": "matrix_vector_multiplication/int4_awq", "datatypes": "int4 weights / bf16 activations", }, + { + "category": "Linear Algebra", + "name": "Matrix Multiplication (AWQ int4)", + "path": "matrix_multiplication/int4_awq", + "datatypes": "int4 weights / bf16 activations", + }, { "category": "Linear Algebra", "name": "AXPY", diff --git a/programming_examples/matrix_multiplication/int4_awq/Makefile b/programming_examples/matrix_multiplication/int4_awq/Makefile new file mode 100644 index 000000000..ef183034d --- /dev/null +++ b/programming_examples/matrix_multiplication/int4_awq/Makefile @@ -0,0 +1,58 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT +# +# int4 AWQ GEMM example (packed Q+S+Z BO). Reuses mv_int4_bf16.cc from the +# matrix_vector_multiplication/int4_awq examples — the GEMM symbol is in the +# same .o file as the GEMV/zero/partial_plus_r helpers. + +srcdir := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST)))) +INT4_SRCDIR := $(srcdir)/../../matrix_vector_multiplication/int4_awq + +ifdef PEANO_INSTALL_DIR + BUILD_DIR := build_peano +else + BUILD_DIR := build_chess +endif + +OUTPUT_FORMAT ?= elf +OUTPUT_FORMAT_FLAG = --output-format $(OUTPUT_FORMAT) + +# Shapes / tiling (overridable). Defaults match the smallest correctness test. +M ?= 32 +K ?= 128 +N ?= 64 +GS ?= 128 +M_TILE ?= 16 +N_TILE ?= 16 +K_CHUNK ?= 128 +N_CORES ?= 4 + +AIEOPT_DIR = $(shell realpath $(dir $(shell which aie-opt))/..) +WARNING_FLAGS = -Wno-parentheses -Wno-attributes -Wno-macro-redefined -Wno-empty-body +PEANOWRAP2P_FLAGS = -O2 -std=c++20 --target=aie2p-none-unknown-elf ${WARNING_FLAGS} -DNDEBUG -I ${AIEOPT_DIR}/include + +PY_ARGS = --m $(M) --k $(K) --n $(N) --gs $(GS) --m-tile $(M_TILE) --n-tile $(N_TILE) --k-chunk $(K_CHUNK) --n-cores $(N_CORES) + +all: run_packed + +print_packed: + ${powershell} python3 ${srcdir}/matmul_int4_packed.py $(OUTPUT_FORMAT_FLAG) -p $(PY_ARGS) + +compile-kernel: + mkdir -p $(BUILD_DIR) + @if [ -z "$(PEANO_INSTALL_DIR)" ]; then \ + echo "Error: PEANO_INSTALL_DIR not set (source utils/env_setup.sh)."; \ + exit 1; \ + fi + $(PEANO_INSTALL_DIR)/bin/clang++ ${PEANOWRAP2P_FLAGS} \ + -DDIM_M=$(M_TILE) -DDIM_N=$(N_TILE) -DDIM_K=$(K_CHUNK) -DDIM_GS=$(GS) \ + -DAIE_API_EMULATE_BFLOAT16_MMUL_WITH_BFP16 \ + -c $(INT4_SRCDIR)/mv_int4_bf16.cc -o $(BUILD_DIR)/mv_int4_bf16.o + +run_packed: compile-kernel + mkdir -p $(BUILD_DIR) + PEANO_INSTALL_DIR=$(PEANO_INSTALL_DIR) cd $(BUILD_DIR) && \ + ${powershell} python3 ${srcdir}/matmul_int4_packed.py $(OUTPUT_FORMAT_FLAG) $(PY_ARGS) + +clean: + rm -rf $(BUILD_DIR) __pycache__ diff --git a/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py b/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py new file mode 100644 index 000000000..d7f152c61 --- /dev/null +++ b/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py @@ -0,0 +1,370 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT +# +# int4-AWQ GEMM (prefill): C[M, N] = A[M, K] @ dequant(W)[K, N]. +# +# Weight is laid out as W_q[N, K/2] (output-major) so the per-tile packed +# layout matches the int4-AWQ GEMV. One packed L3 BO per tile: +# [ Q : N_TILE * K_CHUNK/2 bytes uint8 ] +# [ S : K_CHUNK/GS * N_TILE bf16 ] +# [ Z : K_CHUNK/GS * N_TILE uint8 ] +# Single multi-dim BD per shim channel keeps the 2-S2MM-per-tile budget +# (packed Q+S+Z on one S2MM, A on the other). +# +# Herd is 1D over N (sizes=[N_CORES, 1]). Each core owns N_per_core=N/N_CORES +# output columns and loops over (n_outer, m_outer, k_outer). M is handled by +# the serial M-outer loop in the core; multi-launch over M would extend this. + +import argparse + +import numpy as np +from ml_dtypes import bfloat16 + +from air.ir import ( + AffineConstantExpr, + AffineExpr, + AffineMap, + AffineSymbolExpr, + BF16Type, + IntegerAttr, + IntegerType, + MemRefType, + StringAttr, + UnitAttr, +) +from air.dialects.affine import apply as affine_apply +from air.dialects.air import ( + Channel, + ChannelGet, + ChannelPut, + MemorySpace, + T, + herd, + launch, + module_builder, + segment, +) +from air.dialects.air import channel as channel_decl +from air.dialects.func import FuncOp, CallOp +from air.dialects.memref import AllocOp, DeallocOp +from air.dialects import arith +from air.dialects.scf import for_, yield_ +from air.backend.xrt import XRTBackend +from air.backend.xrt_runner import XRTRunner + +KERNEL_OBJ_NAME = "mv_int4_bf16.o" + + +def pack_inputs(W_q, W_s, W_z, M, K, N, GS, M_TILE, N_TILE, K_CHUNK, N_CORES): + """Pack per-(n_outer, k_outer) Q+S+Z tiles into a single L3 buffer. + + Output: uint8 [N_CORES * N_div * K_div, tile_bytes] where each core gets a + contiguous slab of N_div*K_div tiles (n_outer is outermost within a core). + W_q shape: [N, K/2] uint8 (output-major, K packed 2 nibbles per byte). + W_s shape: [K/GS, N] bf16. + W_z shape: [K/GS, N] uint8. + """ + n_gpc = K_CHUNK // GS + q_bytes = N_TILE * (K_CHUNK // 2) + s_bytes = n_gpc * N_TILE * 2 + z_bytes = n_gpc * N_TILE + tile_bytes = q_bytes + s_bytes + z_bytes + + N_per_core = N // N_CORES + N_div = N_per_core // N_TILE + K_div = K // K_CHUNK + + total_tiles = N_CORES * N_div * K_div + packed = np.zeros((total_tiles, tile_bytes), dtype=np.uint8) + + tile_idx = 0 + for c in range(N_CORES): + base_col = c * N_per_core + for n_outer in range(N_div): + col_off = base_col + n_outer * N_TILE + for kc in range(K_div): + q_col_byte = kc * (K_CHUNK // 2) + g_off = kc * n_gpc + q_tile = W_q[ + col_off : col_off + N_TILE, + q_col_byte : q_col_byte + (K_CHUNK // 2), + ] + s_tile = W_s[g_off : g_off + n_gpc, col_off : col_off + N_TILE] + z_tile = W_z[g_off : g_off + n_gpc, col_off : col_off + N_TILE] + p = packed[tile_idx] + p[0:q_bytes] = np.ascontiguousarray(q_tile).view(np.uint8).reshape(-1) + p[q_bytes : q_bytes + s_bytes] = ( + np.ascontiguousarray(s_tile).view(np.uint8).reshape(-1) + ) + p[q_bytes + s_bytes :] = ( + np.ascontiguousarray(z_tile).view(np.uint8).reshape(-1) + ) + tile_idx += 1 + return packed + + +def build_module(M, K, N, GS=128, M_TILE=16, N_TILE=16, K_CHUNK=128, N_CORES=4): + assert M % M_TILE == 0 + M_div = M // M_TILE + assert N % N_CORES == 0 + N_per_core = N // N_CORES + assert N_per_core % N_TILE == 0 + N_div = N_per_core // N_TILE + assert K % K_CHUNK == 0 + assert K_CHUNK % GS == 0 + K_div = K // K_CHUNK + n_gpc = K_CHUNK // GS + + total_tiles = N_CORES * N_div * K_div + + q_bytes = N_TILE * (K_CHUNK // 2) + s_bytes = n_gpc * N_TILE * 2 + z_bytes = n_gpc * N_TILE + tile_bytes = q_bytes + s_bytes + z_bytes + + assert q_bytes % 32 == 0 + assert (q_bytes + s_bytes) % 32 == 0 + + tiles_per_core = N_div * K_div + + @module_builder + def build(): + bf16_ty = BF16Type.get() + i8_ty = IntegerType.get_signless(8) + + packed_l3 = MemRefType.get([total_tiles, tile_bytes], i8_ty) + A_l3 = MemRefType.get([M, K], bf16_ty) + C_l3 = MemRefType.get([M, N], bf16_ty) + + l1_ms = IntegerAttr.get(T.i32(), MemorySpace.L1) + + packed_l1 = MemRefType.get([tile_bytes], i8_ty, memory_space=l1_ms) + A_l1 = MemRefType.get([M_TILE, K_CHUNK], bf16_ty, memory_space=l1_ms) + C_l1 = MemRefType.get([M_TILE, N_TILE], bf16_ty, memory_space=l1_ms) + + channel_decl("inL3", size=[N_CORES]) + Channel("inA", size=[1, 1], broadcast_shape=[N_CORES, 1]) + channel_decl("outC", size=[N_CORES]) + + zero_func = FuncOp( + "zero_vectorized_bf16_mn", ([C_l1], []), visibility="private" + ) + zero_func.attributes["link_with"] = StringAttr.get(KERNEL_OBJ_NAME) + zero_func.attributes["llvm.emit_c_interface"] = UnitAttr.get() + + matmul_func = FuncOp( + "matmul_int4_bf16_packed", + ([packed_l1, A_l1, C_l1], []), + visibility="private", + ) + matmul_func.attributes["link_with"] = StringAttr.get(KERNEL_OBJ_NAME) + matmul_func.attributes["llvm.emit_c_interface"] = UnitAttr.get() + + # Launch's lj axis iterates m_outer; lj * M_TILE is the row offset + # into A and C for this launch. + lj_to_row_map = AffineMap.get( + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(M_TILE), + ) + ], + ) + + @FuncOp.from_py_func(packed_l3, A_l3, C_l3) + def matmul_int4_packed(PACKED, A, C): + @launch(sizes=[1, M_div], operands=[PACKED, A, C]) + def launch_body(li, lj, lsx, lsy, packed, a, c): + m_row_off = affine_apply(lj_to_row_map, [lj]) + for cc in range(N_CORES): + c_idx = arith.ConstantOp.create_index(cc) + c_tile_const = arith.ConstantOp.create_index(cc * tiles_per_core) + # Packed weight: each launch re-streams all N_div*K_div + # tiles for this core. No stride-0 dim — matches GEMV + # packed-put pattern. + ChannelPut( + "inL3", + packed, + indices=[c_idx], + offsets=[c_tile_const, 0], + sizes=[tiles_per_core, tile_bytes], + strides=[tile_bytes, 1], + ) + # Output C: per launch, write M_TILE rows x N_per_core + # cols starting at row m_row_off, col cc*N_per_core. + c_n_const = arith.ConstantOp.create_index(cc * N_per_core) + ChannelGet( + "outC", + c, + indices=[c_idx], + offsets=[0, m_row_off, c_n_const], + sizes=[N_div, M_TILE, N_TILE], + strides=[N_TILE, N, 1], + ) + + # A: per launch, broadcast the M_TILE-row band to all cores. + # n_outer stride 0 = replay the same A tile for each n_outer. + # Matches GEMV B-put stride-0 outer pattern. + ChannelPut( + "inA", + a, + offsets=[0, 0, m_row_off, 0], + sizes=[N_div, K_div, M_TILE, K_CHUNK], + strides=[0, K_CHUNK, K, 1], + ) + + @segment(name="seg") + def segment_body(): + @herd(name="mm_h", sizes=[N_CORES, 1]) + def herd_body(tx, ty, _sx, _sy): + for _n_outer in for_(N_div): + l1_c_op = AllocOp(C_l1, [], []) + CallOp(zero_func, [l1_c_op]) + for _ in for_(K_div): + l1_p_op = AllocOp(packed_l1, [], []) + l1_a_op = AllocOp(A_l1, [], []) + ChannelGet("inL3", l1_p_op, indices=[tx]) + ChannelGet("inA", l1_a_op, indices=[tx, ty]) + CallOp( + matmul_func, + [l1_p_op, l1_a_op, l1_c_op], + ) + DeallocOp(l1_p_op) + DeallocOp(l1_a_op) + yield_([]) + ChannelPut("outC", l1_c_op, indices=[tx]) + DeallocOp(l1_c_op) + yield_([]) + + herd_body.attributes["link_with"] = StringAttr.get(KERNEL_OBJ_NAME) + herd_body.attributes["x_loc"] = IntegerAttr.get(T.i64(), 0) + herd_body.attributes["y_loc"] = IntegerAttr.get(T.i64(), 2) + + return build() + + +def cpu_reference(W_q, W_s, W_z, A): + """W is stored as [N, K/2] uint8 (output-major). dequant(W)[k, n].""" + N_ = W_q.shape[0] + K_ = A.shape[1] + M_ = A.shape[0] + n_groups = W_s.shape[0] + gs = K_ // n_groups + Af = A.astype(np.float32) + W_s_f = W_s.astype(np.float32) + W_z_i = W_z.astype(np.int32) + + # Dequantize W into [K, N] f32. + W_dq = np.zeros((K_, N_), dtype=np.float32) + for n in range(N_): + for kk in range(K_): + byte = int(W_q[n, kk // 2]) + nib = (byte & 0x0F) if (kk % 2 == 0) else ((byte >> 4) & 0x0F) + g = kk // gs + W_dq[kk, n] = (nib - W_z_i[g, n]) * W_s_f[g, n] + C = Af @ W_dq + return C.astype(bfloat16) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="matmul_int4_packed.py", + description="int4-AWQ GEMM: C[M,N] = A[M,K] @ dequant(W)[K,N]", + ) + parser.add_argument("-v", "--verbose", action="store_true") + parser.add_argument("-p", "--print-module-only", action="store_true") + parser.add_argument("--m", type=int, default=32) + parser.add_argument("--k", type=int, default=128) + parser.add_argument("--n", type=int, default=64) + parser.add_argument("--gs", type=int, default=128) + parser.add_argument("--m-tile", type=int, default=16, dest="m_tile") + parser.add_argument("--n-tile", type=int, default=16, dest="n_tile") + parser.add_argument("--k-chunk", type=int, default=128, dest="k_chunk") + parser.add_argument("--n-cores", type=int, default=4, dest="n_cores") + parser.add_argument( + "--output-format", + type=str, + choices=["xclbin", "elf"], + default="elf", + ) + parser.add_argument( + "--compile-mode", + type=str, + choices=["compile-and-run", "compile-only"], + default="compile-and-run", + dest="compile_mode", + ) + args = parser.parse_args() + + module = build_module( + args.m, + args.k, + args.n, + GS=args.gs, + M_TILE=args.m_tile, + N_TILE=args.n_tile, + K_CHUNK=args.k_chunk, + N_CORES=args.n_cores, + ) + if args.print_module_only: + print(module) + exit(0) + + if args.compile_mode == "compile-only": + backend = XRTBackend( + verbose=args.verbose, + omit_while_true_loop=False, + omit_pingpong=True, + output_format=args.output_format, + instance_name="matmul_int4_packed", + use_lock_race_condition_fix=True, + stack_size=16384, + ) + backend.compile(module) + backend.unload() + exit(0) + + np.random.seed(42) + W_q_unp = np.random.randint(0, 16, size=(args.n, args.k), dtype=np.uint8) + W_q = (W_q_unp[:, 0::2] | (W_q_unp[:, 1::2] << 4)).astype(np.uint8) + n_groups = args.k // args.gs + W_s = np.random.uniform(0.005, 0.02, size=(n_groups, args.n)).astype(bfloat16) + W_z = np.random.randint(7, 9, size=(n_groups, args.n), dtype=np.uint8) + A = np.random.randn(args.m, args.k).astype(bfloat16) + + C_ref = cpu_reference(W_q, W_s, W_z, A) + PACKED = pack_inputs( + W_q, + W_s, + W_z, + args.m, + args.k, + args.n, + args.gs, + args.m_tile, + args.n_tile, + args.k_chunk, + args.n_cores, + ) + + runner = XRTRunner( + verbose=args.verbose, + omit_while_true_loop=False, + omit_pingpong=True, + output_format=args.output_format, + instance_name="matmul_int4_packed", + use_lock_race_condition_fix=True, + stack_size=16384, + ) + exit( + runner.run_test( + module, + inputs=[PACKED, A], + expected_outputs=[C_ref], + rtol=0.1, + atol=0.05, + min_correlation=0.999, + ) + ) diff --git a/programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_llama_qproj_seq32_peano.lit b/programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_llama_qproj_seq32_peano.lit new file mode 100644 index 000000000..f27f6eb7a --- /dev/null +++ b/programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_llama_qproj_seq32_peano.lit @@ -0,0 +1,12 @@ +// (c) Copyright 2026 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT +// +// REQUIRES: ryzen_ai_npu2, peano +// +// RUN: mkdir -p test_int4_gemm_npu2_llama_qproj_seq32_peano +// RUN: cd test_int4_gemm_npu2_llama_qproj_seq32_peano +// RUN: make -f %S/Makefile clean +// +// Llama-3.2-1B Q-projection at prefill seq=32: A[32,2048] @ dequant(W)[2048,2048]. +// RUN: make -f %S/Makefile run_packed M=32 K=2048 N=2048 N_CORES=8 OUTPUT_FORMAT=elf PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR | FileCheck %s +// CHECK: PASS! diff --git a/programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_small_peano.lit b/programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_small_peano.lit new file mode 100644 index 000000000..c775068c1 --- /dev/null +++ b/programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_small_peano.lit @@ -0,0 +1,12 @@ +// (c) Copyright 2026 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: MIT +// +// REQUIRES: ryzen_ai_npu2, peano +// +// RUN: mkdir -p test_int4_gemm_npu2_small_peano +// RUN: cd test_int4_gemm_npu2_small_peano +// RUN: make -f %S/Makefile clean +// +// int4-AWQ GEMM smoke: M=32, K=128, N=64 (exercises M_div>1 + N_div>1). +// RUN: make -f %S/Makefile run_packed M=32 K=128 N=64 OUTPUT_FORMAT=elf PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR | FileCheck %s +// CHECK: PASS! diff --git a/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc b/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc index 3f4776a72..8f6d00585 100644 --- a/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc +++ b/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc @@ -1,9 +1,9 @@ -//===- mv_int4_bf16.cc - AWQ uint4 weight x bf16 activation matvec --------===// +//===- mv_int4_bf16.cc - AWQ uint4 weight x bf16 activation matvec/matmul -===// // // Copyright (C) 2026, Advanced Micro Devices, Inc. // SPDX-License-Identifier: MIT // -// Per-tile micro-kernels for the int4-AWQ GEMV examples: +// Per-tile micro-kernels for the int4-AWQ GEMV and GEMM examples: // - matvec_int4_bf16_packed(packed, b, c): // c[0..m] += dequant(A)[m, k] @ b[k] // where dequant(A)[r, k] = (q[r, k] - z[r, g(k)]) * s_a[r, g(k)], @@ -13,7 +13,12 @@ // [ S : k/gs * m bf16 ] // [ Z : k/gs * m uint8 ] // Offsets are 32-byte aligned when m and gs are powers of two ≥ 16. +// - matmul_int4_bf16_packed(packed, a, c): +// c[0..m, 0..n] += a[0..m, 0..k] @ dequant(W)[0..k, 0..n] +// W laid out as [n_tile, k/2] (output-major) — same packed layout +// as the GEMV with M renamed to N. Used by the int4-AWQ GEMM prefill. // - zero_vectorized_bf16(c): c[0..m] = 0 +// - zero_vectorized_bf16_mn(c): c[0..m*n] = 0 (vectorized for GEMM) // - partial_plus_r_bf16(p, r, off, d): d[0..m] = p[0..m] + r[off..off+m] // //===----------------------------------------------------------------------===// @@ -24,6 +29,9 @@ #ifndef DIM_M #define DIM_M 8 #endif +#ifndef DIM_N +#define DIM_N 16 +#endif #ifndef DIM_K #define DIM_K 2048 #endif @@ -87,6 +95,73 @@ static void zero_impl(bfloat16 *__restrict c) { c[i] = (bfloat16)0.0f; } +// Vectorized zero for the GEMM C tile. Peano auto-vectorizes the scalar +// `for c[i] = 0` loop with a stride-4 store that skips every 4th element +// once the buffer is >= one full vector wide — the bug only manifests +// across repeated kernel calls that read-modify c[]. Explicit aie::store_v +// of aie::zeros avoids it. +template +static void zero_mn_impl(bfloat16 *__restrict c) { + constexpr unsigned VW = 32; + constexpr unsigned NTOT = m_tile * n_tile; + static_assert(NTOT % VW == 0, + "m_tile*n_tile must be a multiple of vector width"); + aie::vector zv = aie::zeros(); + for (unsigned i = 0; i < NTOT; i += VW) + aie::store_v(c + i, zv); +} + +// int4-AWQ GEMM inner kernel. Same inline dequant+MAC chain as the GEMV, +// extended to m_tile activation rows. Each (mi, n) accumulates a full +// K-chunk's contribution into c[mi, n]. +template +void mm_int4_bf16_impl(uint8_t *__restrict a_q, bfloat16 *__restrict a_s, + uint8_t *__restrict a_z, bfloat16 *__restrict a, + bfloat16 *__restrict c) { + ::aie::set_rounding(aie::rounding_mode::conv_even); + static_assert(gs % r == 0, "group size must be multiple of inner vector r"); + constexpr unsigned NSUB = gs / r; + constexpr unsigned NG = k_chunk / gs; + + for (unsigned mi = 0; mi < m_tile; mi++) { + for (unsigned n = 0; n < n_tile; n++) { + aie::accum acc; + acc.from_vector(aie::zeros()); + const uint8_t *__restrict aq_n = a_q + n * (k_chunk / 2); + + for (unsigned g = 0; g < NG; g++) { + aie::vector zv = + aie::broadcast((int8_t)a_z[g * n_tile + n]); + bfloat16 sa = a_s[g * n_tile + n]; + + aie::accum g_acc; + g_acc.from_vector(aie::zeros()); + +#pragma clang loop unroll(full) + for (unsigned i = 0; i < NSUB; i++) { + const unsigned off = (g * gs + i * r) / 2; + aie::vector packed = aie::load_v(aq_n + off); + aie::vector w_int8 = + packed.template cast_to().template unpack_sign( + false); + w_int8 = aie::sub(w_int8, zv); + aie::vector w_bf16 = aie::to_float(w_int8, 0); + aie::vector a_vec = + aie::load_v(a + mi * k_chunk + g * gs + i * r); + g_acc = aie::mac(g_acc, w_bf16, a_vec); + } + + aie::vector g_bf16 = g_acc.template to_vector(); + acc = aie::mac(acc, g_bf16, sa); + } + + float s = aie::reduce_add(acc.template to_vector()); + c[mi * n_tile + n] = (bfloat16)((float)c[mi * n_tile + n] + s); + } + } +} + template static void partial_plus_r_impl(const bfloat16 *__restrict partial, const bfloat16 *__restrict r_full, int offset, @@ -110,6 +185,19 @@ void matvec_int4_bf16_packed(uint8_t *packed, bfloat16 *b, bfloat16 *c) { void zero_vectorized_bf16(bfloat16 *c) { zero_impl(c); } +void zero_vectorized_bf16_mn(bfloat16 *c) { zero_mn_impl(c); } + +// Packed-BO GEMM entry. Same Q+S+Z packing as the GEMV (output-major W), +// driven by an m_tile-row activation tile a[]. +void matmul_int4_bf16_packed(uint8_t *packed, bfloat16 *a, bfloat16 *c) { + constexpr unsigned Q_BYTES = DIM_N * (DIM_K / 2); + constexpr unsigned S_BYTES = (DIM_K / DIM_GS) * DIM_N * 2; + uint8_t *a_q = packed; + bfloat16 *a_s = reinterpret_cast(packed + Q_BYTES); + uint8_t *a_z = packed + Q_BYTES + S_BYTES; + mm_int4_bf16_impl(a_q, a_s, a_z, a, c); +} + void partial_plus_r_bf16(bfloat16 *partial, bfloat16 *r_full, int offset, bfloat16 *d) { partial_plus_r_impl(partial, r_full, offset, d); From 45225fb4025ad63bd734bb19dce76bcd4923276d Mon Sep 17 00:00:00 2001 From: erweiw Date: Sun, 31 May 2026 22:56:48 -0700 Subject: [PATCH 02/10] [programming_examples] address Copilot review on #1639 - matmul_int4_packed.py: assert kernel-side static_assert constraints (GS % 32 == 0, M_TILE*N_TILE % 32 == 0) at module-build time so unsupported tilings fail with a Python message instead of a C++ template/static_assert error from the kernel build. - mv_int4_bf16.cc: restructure mm_int4_bf16_impl loop nest to (mi, g, i, n) so each activation load a_vec(mi, g, i) is reused across all n_tile output columns instead of being reloaded per n. Per-(g, n) zero-point broadcasts and per-(mi, n) accumulators are hoisted out of the i loop too. Inner hot path stays load-packed + unpack + sub + cvt + mac. Acceptance unchanged on NPU2: - Smoke M=32 K=128 N=64: corr 0.999997 - Llama Q-proj seq=32 M=32 K=2048 N=2048: corr 0.999985 - GEMV regression M=2048 K=2048: corr 0.999997 Co-Authored-By: Claude Opus 4.7 (1M context) --- .../int4_awq/matmul_int4_packed.py | 11 ++++ .../int4_awq/mv_int4_bf16.cc | 65 ++++++++++++------- 2 files changed, 52 insertions(+), 24 deletions(-) diff --git a/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py b/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py index d7f152c61..206d9d8bc 100644 --- a/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py +++ b/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py @@ -104,6 +104,17 @@ def pack_inputs(W_q, W_s, W_z, M, K, N, GS, M_TILE, N_TILE, K_CHUNK, N_CORES): def build_module(M, K, N, GS=128, M_TILE=16, N_TILE=16, K_CHUNK=128, N_CORES=4): + # Kernel-side static_assert constraints (mm_int4_bf16_impl uses r=32 inner + # vector and zero_mn_impl uses VW=32). Surface them at module-build time + # so unsupported tilings fail with a Python message, not a C++ template + # error during compile-kernel. + _R = 32 + assert GS % _R == 0, f"GS ({GS}) must be a multiple of inner vector width {_R}" + assert (M_TILE * N_TILE) % _R == 0, ( + f"M_TILE*N_TILE ({M_TILE}*{N_TILE}={M_TILE * N_TILE}) must be a " + f"multiple of vector width {_R} for zero_vectorized_bf16_mn" + ) + assert M % M_TILE == 0 M_div = M // M_TILE assert N % N_CORES == 0 diff --git a/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc b/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc index 8f6d00585..e52bc60de 100644 --- a/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc +++ b/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc @@ -111,9 +111,12 @@ static void zero_mn_impl(bfloat16 *__restrict c) { aie::store_v(c + i, zv); } -// int4-AWQ GEMM inner kernel. Same inline dequant+MAC chain as the GEMV, -// extended to m_tile activation rows. Each (mi, n) accumulates a full -// K-chunk's contribution into c[mi, n]. +// int4-AWQ GEMM inner kernel. Loop nest is (mi, g, i, n) so each +// activation load a_vec(mi, g, i) is reused across all n_tile output +// columns instead of being reloaded per n. Per-group zero-point broadcast +// zv(g, n) and per-(mi, n) accumulator acc[n] are likewise hoisted out of +// the inner i loop, leaving just load-packed + unpack + sub + cvt + mac +// in the hot path. template void mm_int4_bf16_impl(uint8_t *__restrict a_q, bfloat16 *__restrict a_s, @@ -125,38 +128,52 @@ void mm_int4_bf16_impl(uint8_t *__restrict a_q, bfloat16 *__restrict a_s, constexpr unsigned NG = k_chunk / gs; for (unsigned mi = 0; mi < m_tile; mi++) { - for (unsigned n = 0; n < n_tile; n++) { - aie::accum acc; - acc.from_vector(aie::zeros()); - const uint8_t *__restrict aq_n = a_q + n * (k_chunk / 2); - - for (unsigned g = 0; g < NG; g++) { - aie::vector zv = - aie::broadcast((int8_t)a_z[g * n_tile + n]); - bfloat16 sa = a_s[g * n_tile + n]; - - aie::accum g_acc; - g_acc.from_vector(aie::zeros()); + // Per-(mi, n) accumulator spans all K-groups; reduce_add at the end. + aie::accum acc[n_tile]; + for (unsigned n = 0; n < n_tile; n++) + acc[n].from_vector(aie::zeros()); + + for (unsigned g = 0; g < NG; g++) { + // Hoist per-(g, n) zero-point broadcasts out of the i loop. + aie::vector zv[n_tile]; + for (unsigned n = 0; n < n_tile; n++) + zv[n] = aie::broadcast((int8_t)a_z[g * n_tile + n]); + + // Per-(mi, g, n) intra-group accumulator over the NSUB sub-blocks. + aie::accum g_acc[n_tile]; + for (unsigned n = 0; n < n_tile; n++) + g_acc[n].from_vector(aie::zeros()); #pragma clang loop unroll(full) - for (unsigned i = 0; i < NSUB; i++) { - const unsigned off = (g * gs + i * r) / 2; + for (unsigned i = 0; i < NSUB; i++) { + // Single a_vec load per (mi, g, i) reused across all n_tile cols. + aie::vector a_vec = + aie::load_v(a + mi * k_chunk + g * gs + i * r); + const unsigned off = (g * gs + i * r) / 2; + + for (unsigned n = 0; n < n_tile; n++) { + const uint8_t *__restrict aq_n = a_q + n * (k_chunk / 2); aie::vector packed = aie::load_v(aq_n + off); aie::vector w_int8 = packed.template cast_to().template unpack_sign( false); - w_int8 = aie::sub(w_int8, zv); + w_int8 = aie::sub(w_int8, zv[n]); aie::vector w_bf16 = aie::to_float(w_int8, 0); - aie::vector a_vec = - aie::load_v(a + mi * k_chunk + g * gs + i * r); - g_acc = aie::mac(g_acc, w_bf16, a_vec); + g_acc[n] = aie::mac(g_acc[n], w_bf16, a_vec); } + } - aie::vector g_bf16 = g_acc.template to_vector(); - acc = aie::mac(acc, g_bf16, sa); + // Fold per-group bf16 scale into the per-(mi, n) running accumulator. + for (unsigned n = 0; n < n_tile; n++) { + bfloat16 sa = a_s[g * n_tile + n]; + aie::vector g_bf16 = + g_acc[n].template to_vector(); + acc[n] = aie::mac(acc[n], g_bf16, sa); } + } - float s = aie::reduce_add(acc.template to_vector()); + for (unsigned n = 0; n < n_tile; n++) { + float s = aie::reduce_add(acc[n].template to_vector()); c[mi * n_tile + n] = (bfloat16)((float)c[mi * n_tile + n] + s); } } From 48ca85e4e5dc5660be3cad687d9a2925131b084d Mon Sep 17 00:00:00 2001 From: erweiw Date: Mon, 1 Jun 2026 10:43:49 -0700 Subject: [PATCH 03/10] [programming_examples] int4-AWQ GEMM: 2D herd with per-PE K accumulation Replaces the 1D-herd-over-N design with a 2D herd over (M, N) where each PE accumulates K serially into a per-PE row-major bf16 L1 C. Drains via a 4D L2 C [herd_m, herd_n, tile_m, tile_n] with per-PE dst_offsets. Verified PASS at herd 1x1, 2x4, 8x4 (smoke shapes) and at Llama-3.2-1B Q-projection shape M=N=K=2048 herd 8x4 (correlation 0.999986). Co-Authored-By: Claude Opus 4.7 (1M context) --- .../matrix_multiplication/int4_awq/Makefile | 76 ++- .../int4_awq/matmul_int4_packed.py | 559 +++++++++--------- ...un_packed_npu2_llama_qproj_seq32_peano.lit | 2 +- .../int4_awq/run_packed_npu2_small_peano.lit | 4 +- 4 files changed, 336 insertions(+), 305 deletions(-) diff --git a/programming_examples/matrix_multiplication/int4_awq/Makefile b/programming_examples/matrix_multiplication/int4_awq/Makefile index ef183034d..bfa846d4c 100644 --- a/programming_examples/matrix_multiplication/int4_awq/Makefile +++ b/programming_examples/matrix_multiplication/int4_awq/Makefile @@ -1,42 +1,44 @@ # Copyright (C) 2026, Advanced Micro Devices, Inc. # SPDX-License-Identifier: MIT # -# int4 AWQ GEMM example (packed Q+S+Z BO). Reuses mv_int4_bf16.cc from the -# matrix_vector_multiplication/int4_awq examples — the GEMM symbol is in the -# same .o file as the GEMV/zero/partial_plus_r helpers. +# int4-AWQ GEMM example. The kernel .cc lives in the int4-AWQ GEMV sibling +# directory and is shared (matmul_int4_bf16_packed + zero_vectorized_bf16_mn). srcdir := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST)))) INT4_SRCDIR := $(srcdir)/../../matrix_vector_multiplication/int4_awq +# TILE_K_L2 must equal K: the segment-level K loop has 1 iter so the per-PE +# L1 C accumulator survives across all K_CHUNK iterations. +M ?= 64 +K ?= 128 +N ?= 128 +GS ?= 128 +TILE_M ?= 16 +TILE_N ?= 16 +TILE_K_L1 ?= 128 +TILE_K_L2 ?= $(K) +HERD_M ?= 2 +HERD_N ?= 4 + ifdef PEANO_INSTALL_DIR BUILD_DIR := build_peano else BUILD_DIR := build_chess endif -OUTPUT_FORMAT ?= elf -OUTPUT_FORMAT_FLAG = --output-format $(OUTPUT_FORMAT) - -# Shapes / tiling (overridable). Defaults match the smallest correctness test. -M ?= 32 -K ?= 128 -N ?= 64 -GS ?= 128 -M_TILE ?= 16 -N_TILE ?= 16 -K_CHUNK ?= 128 -N_CORES ?= 4 - AIEOPT_DIR = $(shell realpath $(dir $(shell which aie-opt))/..) WARNING_FLAGS = -Wno-parentheses -Wno-attributes -Wno-macro-redefined -Wno-empty-body PEANOWRAP2P_FLAGS = -O2 -std=c++20 --target=aie2p-none-unknown-elf ${WARNING_FLAGS} -DNDEBUG -I ${AIEOPT_DIR}/include -PY_ARGS = --m $(M) --k $(K) --n $(N) --gs $(GS) --m-tile $(M_TILE) --n-tile $(N_TILE) --k-chunk $(K_CHUNK) --n-cores $(N_CORES) +PY_ARGS = --m $(M) --k $(K) --n $(N) --gs $(GS) \ + --tile-m $(TILE_M) --tile-n $(TILE_N) \ + --tile-k-l1 $(TILE_K_L1) --tile-k-l2 $(TILE_K_L2) \ + --herd-m $(HERD_M) --herd-n $(HERD_N) all: run_packed -print_packed: - ${powershell} python3 ${srcdir}/matmul_int4_packed.py $(OUTPUT_FORMAT_FLAG) -p $(PY_ARGS) +print: + ${powershell} python3 ${srcdir}/matmul_int4_packed.py $(PY_ARGS) -p compile-kernel: mkdir -p $(BUILD_DIR) @@ -45,14 +47,44 @@ compile-kernel: exit 1; \ fi $(PEANO_INSTALL_DIR)/bin/clang++ ${PEANOWRAP2P_FLAGS} \ - -DDIM_M=$(M_TILE) -DDIM_N=$(N_TILE) -DDIM_K=$(K_CHUNK) -DDIM_GS=$(GS) \ + -DDIM_M=$(TILE_M) -DDIM_N=$(TILE_N) -DDIM_K=$(TILE_K_L1) -DDIM_GS=$(GS) \ -DAIE_API_EMULATE_BFLOAT16_MMUL_WITH_BFP16 \ -c $(INT4_SRCDIR)/mv_int4_bf16.cc -o $(BUILD_DIR)/mv_int4_bf16.o run_packed: compile-kernel - mkdir -p $(BUILD_DIR) PEANO_INSTALL_DIR=$(PEANO_INSTALL_DIR) cd $(BUILD_DIR) && \ - ${powershell} python3 ${srcdir}/matmul_int4_packed.py $(OUTPUT_FORMAT_FLAG) $(PY_ARGS) + ${powershell} python3 ${srcdir}/matmul_int4_packed.py $(PY_ARGS) + +run1x1: compile-kernel + PEANO_INSTALL_DIR=$(PEANO_INSTALL_DIR) cd $(BUILD_DIR) && \ + ${powershell} python3 ${srcdir}/matmul_int4_packed.py \ + --m 16 --k 128 --n 16 --gs $(GS) \ + --tile-m 16 --tile-n 16 --tile-k-l1 128 --tile-k-l2 128 \ + --herd-m 1 --herd-n 1 + +run2x4: compile-kernel + PEANO_INSTALL_DIR=$(PEANO_INSTALL_DIR) cd $(BUILD_DIR) && \ + ${powershell} python3 ${srcdir}/matmul_int4_packed.py \ + --m 64 --k 128 --n 128 --gs $(GS) \ + --tile-m $(TILE_M) --tile-n $(TILE_N) \ + --tile-k-l1 $(TILE_K_L1) --tile-k-l2 128 \ + --herd-m 2 --herd-n 4 + +run8x4: compile-kernel + PEANO_INSTALL_DIR=$(PEANO_INSTALL_DIR) cd $(BUILD_DIR) && \ + ${powershell} python3 ${srcdir}/matmul_int4_packed.py \ + --m 256 --k 128 --n 128 --gs $(GS) \ + --tile-m $(TILE_M) --tile-n $(TILE_N) \ + --tile-k-l1 $(TILE_K_L1) --tile-k-l2 128 \ + --herd-m 8 --herd-n 4 + +# Llama Q-proj scale: M=N=K=2048, herd 8x4, TILE_K_L2=K (no re-zero). +run_llama_qproj: compile-kernel + PEANO_INSTALL_DIR=$(PEANO_INSTALL_DIR) cd $(BUILD_DIR) && \ + ${powershell} python3 ${srcdir}/matmul_int4_packed.py \ + --m 2048 --k 2048 --n 2048 --gs $(GS) \ + --tile-m 16 --tile-n 16 --tile-k-l1 128 --tile-k-l2 2048 \ + --herd-m 8 --herd-n 4 clean: rm -rf $(BUILD_DIR) __pycache__ diff --git a/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py b/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py index 206d9d8bc..e83011c9a 100644 --- a/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py +++ b/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py @@ -1,21 +1,13 @@ # Copyright (C) 2026, Advanced Micro Devices, Inc. # SPDX-License-Identifier: MIT # -# int4-AWQ GEMM (prefill): C[M, N] = A[M, K] @ dequant(W)[K, N]. -# -# Weight is laid out as W_q[N, K/2] (output-major) so the per-tile packed -# layout matches the int4-AWQ GEMV. One packed L3 BO per tile: -# [ Q : N_TILE * K_CHUNK/2 bytes uint8 ] -# [ S : K_CHUNK/GS * N_TILE bf16 ] -# [ Z : K_CHUNK/GS * N_TILE uint8 ] -# Single multi-dim BD per shim channel keeps the 2-S2MM-per-tile budget -# (packed Q+S+Z on one S2MM, A on the other). -# -# Herd is 1D over N (sizes=[N_CORES, 1]). Each core owns N_per_core=N/N_CORES -# output columns and loops over (n_outer, m_outer, k_outer). M is handled by -# the serial M-outer loop in the core; multi-launch over M would extend this. +# int4-AWQ GEMM (prefill). 2D herd over (M, N), with K accumulated per-PE +# inside the herd; per-PE drain into a 4D L2 C [herd_m, herd_n, tile_m, +# tile_n]. Uses matmul_int4_bf16_packed from mv_int4_bf16.cc (packed +# Q+S+Z weight BO; bf16 activation and output). import argparse +import sys import numpy as np from ml_dtypes import bfloat16 @@ -29,25 +21,23 @@ IntegerAttr, IntegerType, MemRefType, + ShapedType, + StridedLayoutAttr, StringAttr, UnitAttr, ) from air.dialects.affine import apply as affine_apply from air.dialects.air import ( - Channel, - ChannelGet, - ChannelPut, MemorySpace, T, + dma_memcpy_nd, herd, launch, module_builder, segment, ) -from air.dialects.air import channel as channel_decl -from air.dialects.func import FuncOp, CallOp +from air.dialects.func import CallOp, FuncOp from air.dialects.memref import AllocOp, DeallocOp -from air.dialects import arith from air.dialects.scf import for_, yield_ from air.backend.xrt import XRTBackend from air.backend.xrt_runner import XRTRunner @@ -55,219 +45,54 @@ KERNEL_OBJ_NAME = "mv_int4_bf16.o" -def pack_inputs(W_q, W_s, W_z, M, K, N, GS, M_TILE, N_TILE, K_CHUNK, N_CORES): - """Pack per-(n_outer, k_outer) Q+S+Z tiles into a single L3 buffer. +def packed_tile_bytes(n_tile, k_chunk, gs): + n_gpc = k_chunk // gs + q_bytes = n_tile * (k_chunk // 2) + s_bytes = n_gpc * n_tile * 2 + z_bytes = n_gpc * n_tile + return q_bytes, s_bytes, z_bytes, q_bytes + s_bytes + z_bytes + - Output: uint8 [N_CORES * N_div * K_div, tile_bytes] where each core gets a - contiguous slab of N_div*K_div tiles (n_outer is outermost within a core). - W_q shape: [N, K/2] uint8 (output-major, K packed 2 nibbles per byte). - W_s shape: [K/GS, N] bf16. - W_z shape: [K/GS, N] uint8. +def pack_inputs(W_q, W_s, W_z, M, K, N, GS, N_TILE, K_CHUNK): + """Pack per-(n_outer, k_outer) Q+S+Z tiles into [N_div, K_div, tile_bytes]. + + W_q [N, K/2] u8 (output-major), W_s [K/GS, N] bf16, W_z [K/GS, N] u8. """ n_gpc = K_CHUNK // GS - q_bytes = N_TILE * (K_CHUNK // 2) - s_bytes = n_gpc * N_TILE * 2 - z_bytes = n_gpc * N_TILE - tile_bytes = q_bytes + s_bytes + z_bytes - - N_per_core = N // N_CORES - N_div = N_per_core // N_TILE + q_bytes, s_bytes, _, tile_bytes = packed_tile_bytes(N_TILE, K_CHUNK, GS) + N_div = N // N_TILE K_div = K // K_CHUNK - - total_tiles = N_CORES * N_div * K_div - packed = np.zeros((total_tiles, tile_bytes), dtype=np.uint8) - - tile_idx = 0 - for c in range(N_CORES): - base_col = c * N_per_core - for n_outer in range(N_div): - col_off = base_col + n_outer * N_TILE - for kc in range(K_div): - q_col_byte = kc * (K_CHUNK // 2) - g_off = kc * n_gpc - q_tile = W_q[ - col_off : col_off + N_TILE, - q_col_byte : q_col_byte + (K_CHUNK // 2), - ] - s_tile = W_s[g_off : g_off + n_gpc, col_off : col_off + N_TILE] - z_tile = W_z[g_off : g_off + n_gpc, col_off : col_off + N_TILE] - p = packed[tile_idx] - p[0:q_bytes] = np.ascontiguousarray(q_tile).view(np.uint8).reshape(-1) - p[q_bytes : q_bytes + s_bytes] = ( - np.ascontiguousarray(s_tile).view(np.uint8).reshape(-1) - ) - p[q_bytes + s_bytes :] = ( - np.ascontiguousarray(z_tile).view(np.uint8).reshape(-1) - ) - tile_idx += 1 + packed = np.zeros((N_div, K_div, tile_bytes), dtype=np.uint8) + for n_outer in range(N_div): + col_off = n_outer * N_TILE + for k_outer in range(K_div): + q_col_byte = k_outer * (K_CHUNK // 2) + g_off = k_outer * n_gpc + q_tile = W_q[ + col_off : col_off + N_TILE, + q_col_byte : q_col_byte + (K_CHUNK // 2), + ] + s_tile = W_s[g_off : g_off + n_gpc, col_off : col_off + N_TILE] + z_tile = W_z[g_off : g_off + n_gpc, col_off : col_off + N_TILE] + p = packed[n_outer, k_outer] + p[0:q_bytes] = np.ascontiguousarray(q_tile).view(np.uint8).reshape(-1) + p[q_bytes : q_bytes + s_bytes] = ( + np.ascontiguousarray(s_tile).view(np.uint8).reshape(-1) + ) + p[q_bytes + s_bytes :] = ( + np.ascontiguousarray(z_tile).view(np.uint8).reshape(-1) + ) return packed -def build_module(M, K, N, GS=128, M_TILE=16, N_TILE=16, K_CHUNK=128, N_CORES=4): - # Kernel-side static_assert constraints (mm_int4_bf16_impl uses r=32 inner - # vector and zero_mn_impl uses VW=32). Surface them at module-build time - # so unsupported tilings fail with a Python message, not a C++ template - # error during compile-kernel. - _R = 32 - assert GS % _R == 0, f"GS ({GS}) must be a multiple of inner vector width {_R}" - assert (M_TILE * N_TILE) % _R == 0, ( - f"M_TILE*N_TILE ({M_TILE}*{N_TILE}={M_TILE * N_TILE}) must be a " - f"multiple of vector width {_R} for zero_vectorized_bf16_mn" - ) - - assert M % M_TILE == 0 - M_div = M // M_TILE - assert N % N_CORES == 0 - N_per_core = N // N_CORES - assert N_per_core % N_TILE == 0 - N_div = N_per_core // N_TILE - assert K % K_CHUNK == 0 - assert K_CHUNK % GS == 0 - K_div = K // K_CHUNK - n_gpc = K_CHUNK // GS - - total_tiles = N_CORES * N_div * K_div - - q_bytes = N_TILE * (K_CHUNK // 2) - s_bytes = n_gpc * N_TILE * 2 - z_bytes = n_gpc * N_TILE - tile_bytes = q_bytes + s_bytes + z_bytes - - assert q_bytes % 32 == 0 - assert (q_bytes + s_bytes) % 32 == 0 - - tiles_per_core = N_div * K_div - - @module_builder - def build(): - bf16_ty = BF16Type.get() - i8_ty = IntegerType.get_signless(8) - - packed_l3 = MemRefType.get([total_tiles, tile_bytes], i8_ty) - A_l3 = MemRefType.get([M, K], bf16_ty) - C_l3 = MemRefType.get([M, N], bf16_ty) - - l1_ms = IntegerAttr.get(T.i32(), MemorySpace.L1) - - packed_l1 = MemRefType.get([tile_bytes], i8_ty, memory_space=l1_ms) - A_l1 = MemRefType.get([M_TILE, K_CHUNK], bf16_ty, memory_space=l1_ms) - C_l1 = MemRefType.get([M_TILE, N_TILE], bf16_ty, memory_space=l1_ms) - - channel_decl("inL3", size=[N_CORES]) - Channel("inA", size=[1, 1], broadcast_shape=[N_CORES, 1]) - channel_decl("outC", size=[N_CORES]) - - zero_func = FuncOp( - "zero_vectorized_bf16_mn", ([C_l1], []), visibility="private" - ) - zero_func.attributes["link_with"] = StringAttr.get(KERNEL_OBJ_NAME) - zero_func.attributes["llvm.emit_c_interface"] = UnitAttr.get() - - matmul_func = FuncOp( - "matmul_int4_bf16_packed", - ([packed_l1, A_l1, C_l1], []), - visibility="private", - ) - matmul_func.attributes["link_with"] = StringAttr.get(KERNEL_OBJ_NAME) - matmul_func.attributes["llvm.emit_c_interface"] = UnitAttr.get() - - # Launch's lj axis iterates m_outer; lj * M_TILE is the row offset - # into A and C for this launch. - lj_to_row_map = AffineMap.get( - 0, - 1, - [ - AffineExpr.get_mul( - AffineSymbolExpr.get(0), - AffineConstantExpr.get(M_TILE), - ) - ], - ) - - @FuncOp.from_py_func(packed_l3, A_l3, C_l3) - def matmul_int4_packed(PACKED, A, C): - @launch(sizes=[1, M_div], operands=[PACKED, A, C]) - def launch_body(li, lj, lsx, lsy, packed, a, c): - m_row_off = affine_apply(lj_to_row_map, [lj]) - for cc in range(N_CORES): - c_idx = arith.ConstantOp.create_index(cc) - c_tile_const = arith.ConstantOp.create_index(cc * tiles_per_core) - # Packed weight: each launch re-streams all N_div*K_div - # tiles for this core. No stride-0 dim — matches GEMV - # packed-put pattern. - ChannelPut( - "inL3", - packed, - indices=[c_idx], - offsets=[c_tile_const, 0], - sizes=[tiles_per_core, tile_bytes], - strides=[tile_bytes, 1], - ) - # Output C: per launch, write M_TILE rows x N_per_core - # cols starting at row m_row_off, col cc*N_per_core. - c_n_const = arith.ConstantOp.create_index(cc * N_per_core) - ChannelGet( - "outC", - c, - indices=[c_idx], - offsets=[0, m_row_off, c_n_const], - sizes=[N_div, M_TILE, N_TILE], - strides=[N_TILE, N, 1], - ) - - # A: per launch, broadcast the M_TILE-row band to all cores. - # n_outer stride 0 = replay the same A tile for each n_outer. - # Matches GEMV B-put stride-0 outer pattern. - ChannelPut( - "inA", - a, - offsets=[0, 0, m_row_off, 0], - sizes=[N_div, K_div, M_TILE, K_CHUNK], - strides=[0, K_CHUNK, K, 1], - ) - - @segment(name="seg") - def segment_body(): - @herd(name="mm_h", sizes=[N_CORES, 1]) - def herd_body(tx, ty, _sx, _sy): - for _n_outer in for_(N_div): - l1_c_op = AllocOp(C_l1, [], []) - CallOp(zero_func, [l1_c_op]) - for _ in for_(K_div): - l1_p_op = AllocOp(packed_l1, [], []) - l1_a_op = AllocOp(A_l1, [], []) - ChannelGet("inL3", l1_p_op, indices=[tx]) - ChannelGet("inA", l1_a_op, indices=[tx, ty]) - CallOp( - matmul_func, - [l1_p_op, l1_a_op, l1_c_op], - ) - DeallocOp(l1_p_op) - DeallocOp(l1_a_op) - yield_([]) - ChannelPut("outC", l1_c_op, indices=[tx]) - DeallocOp(l1_c_op) - yield_([]) - - herd_body.attributes["link_with"] = StringAttr.get(KERNEL_OBJ_NAME) - herd_body.attributes["x_loc"] = IntegerAttr.get(T.i64(), 0) - herd_body.attributes["y_loc"] = IntegerAttr.get(T.i64(), 2) - - return build() - - def cpu_reference(W_q, W_s, W_z, A): - """W is stored as [N, K/2] uint8 (output-major). dequant(W)[k, n].""" N_ = W_q.shape[0] K_ = A.shape[1] - M_ = A.shape[0] n_groups = W_s.shape[0] gs = K_ // n_groups Af = A.astype(np.float32) W_s_f = W_s.astype(np.float32) W_z_i = W_z.astype(np.int32) - - # Dequantize W into [K, N] f32. W_dq = np.zeros((K_, N_), dtype=np.float32) for n in range(N_): for kk in range(K_): @@ -279,30 +104,225 @@ def cpu_reference(W_q, W_s, W_z, A): return C.astype(bfloat16) -if __name__ == "__main__": - parser = argparse.ArgumentParser( - prog="matmul_int4_packed.py", - description="int4-AWQ GEMM: C[M,N] = A[M,K] @ dequant(W)[K,N]", +@module_builder +def build_module( + m, k, n, gs, tile_m, tile_k_l2, tile_k_l1, tile_n, herd_m, herd_n +): + assert m % (tile_m * herd_m) == 0 + assert n % (tile_n * herd_n) == 0 + assert k % tile_k_l2 == 0 + assert tile_k_l2 % tile_k_l1 == 0 + assert tile_k_l1 % gs == 0 + + _, _, _, tile_bytes = packed_tile_bytes(tile_n, tile_k_l1, gs) + k_per_l2 = tile_k_l2 // tile_k_l1 + N_div = n // tile_n + K_div = k // tile_k_l1 + + bf16_ty = BF16Type.get() + u8_ty = IntegerType.get_signless(8) + + A_l3_ty = MemRefType.get([m, k], bf16_ty) + B_l3_ty = MemRefType.get([N_div, K_div, tile_bytes], u8_ty) + C_l3_ty = MemRefType.get([m, n], bf16_ty) + + l1_ms = IntegerAttr.get(T.i32(), MemorySpace.L1) + l2_ms = IntegerAttr.get(T.i32(), MemorySpace.L2) + + A_l2_ty = MemRefType.get( + [herd_m, 1, tile_m, tile_k_l2], bf16_ty, memory_space=l2_ms ) - parser.add_argument("-v", "--verbose", action="store_true") - parser.add_argument("-p", "--print-module-only", action="store_true") - parser.add_argument("--m", type=int, default=32) + B_l2_ty = MemRefType.get( + [1, herd_n, k_per_l2, tile_bytes], u8_ty, memory_space=l2_ms + ) + C_l2_ty = MemRefType.get( + [herd_m, herd_n, tile_m, tile_n], bf16_ty, memory_space=l2_ms + ) + + A_l1_ty = MemRefType.get([tile_m, tile_k_l1], bf16_ty, memory_space=l1_ms) + B_l1_ty = MemRefType.get([tile_bytes], u8_ty, memory_space=l1_ms) + C_l1_ty = MemRefType.get([tile_m, tile_n], bf16_ty, memory_space=l1_ms) + + zero_func = FuncOp( + "zero_vectorized_bf16_mn", + ([C_l1_ty], []), + visibility="private", + ) + zero_func.attributes["link_with"] = StringAttr.get(KERNEL_OBJ_NAME) + zero_func.attributes["llvm.emit_c_interface"] = UnitAttr.get() + + matmul_func = FuncOp( + "matmul_int4_bf16_packed", + ([B_l1_ty, A_l1_ty, C_l1_ty], []), + visibility="private", + ) + matmul_func.attributes["link_with"] = StringAttr.get(KERNEL_OBJ_NAME) + matmul_func.attributes["llvm.emit_c_interface"] = UnitAttr.get() + + @FuncOp.from_py_func(A_l3_ty, B_l3_ty, C_l3_ty) + def matmul_int4_packed(arg_a, arg_b, arg_c): + launch_size = [m // tile_m // herd_m, n // tile_n // herd_n] + + @launch(operands=[arg_a, arg_b, arg_c], sizes=launch_size) + def launch_body(li, lj, lsx, lsy, l3_a, l3_b, l3_c): + @segment(name="seg", operands=[li, lj, l3_a, l3_b, l3_c]) + def segment_body(li_s, lj_s, l3_a_s, l3_b_s, l3_c_s): + l2_a = AllocOp(A_l2_ty, [], []) + l2_b = AllocOp(B_l2_ty, [], []) + l2_c = AllocOp(C_l2_ty, [], []) + + ix_to_row = AffineMap.get( + 0, 1, + [AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(tile_m * herd_m))], + ) + iy_to_n_outer = AffineMap.get( + 0, 1, + [AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(herd_n))], + ) + row_off = affine_apply(ix_to_row, [li_s]) + n_outer_off = affine_apply(iy_to_n_outer, [lj_s]) + + k_l2_to_k = AffineMap.get( + 0, 1, + [AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(tile_k_l2))], + ) + k_l2_to_chunk = AffineMap.get( + 0, 1, + [AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(k_per_l2))], + ) + k_chunk_off_l1_map = AffineMap.get( + 0, 1, + [AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(tile_k_l1))], + ) + + for i in for_(0, k // tile_k_l2): + k_l2_off = affine_apply(k_l2_to_k, [i]) + k_chunk_off = affine_apply(k_l2_to_chunk, [i]) + + dma_memcpy_nd( + l2_a, l3_a_s, + src_offsets=[0, 0, row_off, k_l2_off], + src_sizes=[herd_m, 1, tile_m, tile_k_l2], + src_strides=[k * tile_m, tile_k_l2, k, 1], + ) + dma_memcpy_nd( + l2_b, l3_b_s, + src_offsets=[0, n_outer_off, k_chunk_off, 0], + src_sizes=[1, herd_n, k_per_l2, tile_bytes], + src_strides=[ + K_div * tile_bytes, + K_div * tile_bytes, + tile_bytes, + 1, + ], + ) + + @herd( + name="herd_0", + sizes=[herd_m, herd_n], + operands=[l2_a, l2_b, l2_c], + ) + def compute_body(_tx, _ty, _sx, _sy, _l2a, _l2b, _l2c): + _l1_a = AllocOp(A_l1_ty, [], []) + _l1_b = AllocOp(B_l1_ty, [], []) + _l1_c = AllocOp(C_l1_ty, [], []) + CallOp(zero_func, [_l1_c]) + for j in for_(0, k_per_l2): + k1_off = affine_apply(k_chunk_off_l1_map, [j]) + dma_memcpy_nd( + _l1_a, _l2a, + src_offsets=[_tx, 0, 0, k1_off], + src_sizes=[1, 1, tile_m, tile_k_l1], + src_strides=[ + tile_m * tile_k_l2, + tile_m * tile_k_l2, + tile_k_l2, + 1, + ], + ) + dma_memcpy_nd( + _l1_b, _l2b, + src_offsets=[0, _ty, j, 0], + src_sizes=[1, 1, 1, tile_bytes], + src_strides=[ + herd_n * k_per_l2 * tile_bytes, + k_per_l2 * tile_bytes, + tile_bytes, + 1, + ], + ) + CallOp(matmul_func, [_l1_b, _l1_a, _l1_c]) + yield_([]) + dma_memcpy_nd( + _l2c, _l1_c, + dst_offsets=[_tx, _ty, 0, 0], + dst_sizes=[1, 1, tile_m, tile_n], + dst_strides=[ + herd_n * tile_m * tile_n, + tile_m * tile_n, + tile_n, + 1, + ], + ) + DeallocOp(_l1_a) + DeallocOp(_l1_b) + DeallocOp(_l1_c) + + yield_([]) + + col_off_map = AffineMap.get( + 0, 1, + [AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(tile_n * herd_n))], + ) + col_off = affine_apply(col_off_map, [lj_s]) + dma_memcpy_nd( + l3_c_s, l2_c, + dst_offsets=[row_off, col_off], + dst_sizes=[herd_m * tile_m, herd_n * tile_n], + dst_strides=[n, 1], + src_offsets=[0, 0, 0, 0], + src_sizes=[herd_m, tile_m, herd_n, tile_n], + src_strides=[ + herd_n * tile_m * tile_n, + tile_n, + tile_m * tile_n, + 1, + ], + ) + + DeallocOp(l2_a) + DeallocOp(l2_b) + DeallocOp(l2_c) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--m", type=int, default=64) parser.add_argument("--k", type=int, default=128) - parser.add_argument("--n", type=int, default=64) + parser.add_argument("--n", type=int, default=128) parser.add_argument("--gs", type=int, default=128) - parser.add_argument("--m-tile", type=int, default=16, dest="m_tile") - parser.add_argument("--n-tile", type=int, default=16, dest="n_tile") - parser.add_argument("--k-chunk", type=int, default=128, dest="k_chunk") - parser.add_argument("--n-cores", type=int, default=4, dest="n_cores") - parser.add_argument( - "--output-format", - type=str, - choices=["xclbin", "elf"], - default="elf", - ) + parser.add_argument("--tile-m", type=int, default=16, dest="tile_m") + parser.add_argument("--tile-k-l2", type=int, default=128, dest="tile_k_l2") + parser.add_argument("--tile-k-l1", type=int, default=128, dest="tile_k_l1") + parser.add_argument("--tile-n", type=int, default=16, dest="tile_n") + parser.add_argument("--herd-m", type=int, default=2, dest="herd_m") + parser.add_argument("--herd-n", type=int, default=4, dest="herd_n") + parser.add_argument("-v", "--verbose", action="store_true") + parser.add_argument("-p", "--print-module-only", action="store_true") parser.add_argument( "--compile-mode", - type=str, choices=["compile-and-run", "compile-only"], default="compile-and-run", dest="compile_mode", @@ -310,32 +330,13 @@ def cpu_reference(W_q, W_s, W_z, A): args = parser.parse_args() module = build_module( - args.m, - args.k, - args.n, - GS=args.gs, - M_TILE=args.m_tile, - N_TILE=args.n_tile, - K_CHUNK=args.k_chunk, - N_CORES=args.n_cores, + args.m, args.k, args.n, args.gs, + args.tile_m, args.tile_k_l2, args.tile_k_l1, args.tile_n, + args.herd_m, args.herd_n, ) if args.print_module_only: print(module) - exit(0) - - if args.compile_mode == "compile-only": - backend = XRTBackend( - verbose=args.verbose, - omit_while_true_loop=False, - omit_pingpong=True, - output_format=args.output_format, - instance_name="matmul_int4_packed", - use_lock_race_condition_fix=True, - stack_size=16384, - ) - backend.compile(module) - backend.unload() - exit(0) + sys.exit(0) np.random.seed(42) W_q_unp = np.random.randint(0, 16, size=(args.n, args.k), dtype=np.uint8) @@ -345,37 +346,35 @@ def cpu_reference(W_q, W_s, W_z, A): W_z = np.random.randint(7, 9, size=(n_groups, args.n), dtype=np.uint8) A = np.random.randn(args.m, args.k).astype(bfloat16) - C_ref = cpu_reference(W_q, W_s, W_z, A) PACKED = pack_inputs( - W_q, - W_s, - W_z, - args.m, - args.k, - args.n, - args.gs, - args.m_tile, - args.n_tile, - args.k_chunk, - args.n_cores, + W_q, W_s, W_z, args.m, args.k, args.n, args.gs, args.tile_n, args.tile_k_l1 ) + C_ref = cpu_reference(W_q, W_s, W_z, A) + + if args.compile_mode == "compile-only": + backend = XRTBackend( + verbose=args.verbose, + omit_while_true_loop=False, + output_format="xclbin", + runtime_loop_tiling_sizes=[2, 2], + stack_size=16384, + ) + backend.compile(module) + backend.unload() + sys.exit(0) runner = XRTRunner( verbose=args.verbose, omit_while_true_loop=False, - omit_pingpong=True, - output_format=args.output_format, - instance_name="matmul_int4_packed", - use_lock_race_condition_fix=True, + output_format="xclbin", + runtime_loop_tiling_sizes=[2, 2], stack_size=16384, ) - exit( + sys.exit( runner.run_test( module, - inputs=[PACKED, A], + inputs=[A, PACKED], expected_outputs=[C_ref], - rtol=0.1, - atol=0.05, - min_correlation=0.999, + rtol=0.1, atol=0.05, min_correlation=0.999, ) ) diff --git a/programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_llama_qproj_seq32_peano.lit b/programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_llama_qproj_seq32_peano.lit index f27f6eb7a..6c5a3d1af 100644 --- a/programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_llama_qproj_seq32_peano.lit +++ b/programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_llama_qproj_seq32_peano.lit @@ -8,5 +8,5 @@ // RUN: make -f %S/Makefile clean // // Llama-3.2-1B Q-projection at prefill seq=32: A[32,2048] @ dequant(W)[2048,2048]. -// RUN: make -f %S/Makefile run_packed M=32 K=2048 N=2048 N_CORES=8 OUTPUT_FORMAT=elf PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR | FileCheck %s +// RUN: make -f %S/Makefile run_packed M=32 K=2048 N=2048 HERD_M=2 HERD_N=4 PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR | FileCheck %s // CHECK: PASS! diff --git a/programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_small_peano.lit b/programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_small_peano.lit index c775068c1..d79a54428 100644 --- a/programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_small_peano.lit +++ b/programming_examples/matrix_multiplication/int4_awq/run_packed_npu2_small_peano.lit @@ -7,6 +7,6 @@ // RUN: cd test_int4_gemm_npu2_small_peano // RUN: make -f %S/Makefile clean // -// int4-AWQ GEMM smoke: M=32, K=128, N=64 (exercises M_div>1 + N_div>1). -// RUN: make -f %S/Makefile run_packed M=32 K=128 N=64 OUTPUT_FORMAT=elf PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR | FileCheck %s +// int4-AWQ GEMM smoke: M=64 K=128 N=128 on a 2x4 herd. +// RUN: make -f %S/Makefile run_packed M=64 K=128 N=128 HERD_M=2 HERD_N=4 PEANO_INSTALL_DIR=%PEANO_INSTALL_DIR | FileCheck %s // CHECK: PASS! From dbf24d66d0dae59b9aa0f85f89df74889ca47b69 Mon Sep 17 00:00:00 2001 From: erweiw Date: Mon, 1 Jun 2026 11:37:10 -0700 Subject: [PATCH 04/10] [programming_examples] int4-AWQ GEMM: dequant-to-L1 + bf16 MMUL (1.7x) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the per-(mi, g, i, n) aie::mac inner loop with a two-phase kernel: (1) pack A row-major → mmul-packed [KB][MB][r][s] and dequant W (scale folded in) into mmul-packed [NB][KB][s][t]; (2) aie::mmul<8,8,8, bf16,bf16,accfloat> loop. Same kernel symbol and ABI, drop-in. Measured at M=N=K=2048 herd 8x4: 198 ms → 117 ms (1.7x). bf16 MMUL ceiling at the same shape is 51 ms; remaining gap is dominated by the scatter store in the dequant phase (32 scalar stores per (g, n, i) across a strided [s][t] inner layout). Correlation 0.999955+ at all tested shapes (1x1, 2x4, 8x4, Llama Q-proj). Co-Authored-By: Claude Opus 4.7 (1M context) --- .../int4_awq/mv_int4_bf16.cc | 141 +++++++++++------- 1 file changed, 87 insertions(+), 54 deletions(-) diff --git a/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc b/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc index e52bc60de..78060e02a 100644 --- a/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc +++ b/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc @@ -111,70 +111,103 @@ static void zero_mn_impl(bfloat16 *__restrict c) { aie::store_v(c + i, zv); } -// int4-AWQ GEMM inner kernel. Loop nest is (mi, g, i, n) so each -// activation load a_vec(mi, g, i) is reused across all n_tile output -// columns instead of being reloaded per n. Per-group zero-point broadcast -// zv(g, n) and per-(mi, n) accumulator acc[n] are likewise hoisted out of -// the inner i loop, leaving just load-packed + unpack + sub + cvt + mac -// in the hot path. +// int4-AWQ GEMM inner kernel using aie::mmul<8,8,8,bf16,bf16,accfloat>. +// Phase 1: pack row-major A into a_pack and dequant W (with per-group bf16 +// scale folded in) into b_pack — both in the operand layouts that the bf16 +// GEMM kernel's MMUL loop expects (K-major A, N-major B). +// Phase 2: MMUL loop yields f32 accumulators; convert to bf16 and add to +// the existing row-major c. template -void mm_int4_bf16_impl(uint8_t *__restrict a_q, bfloat16 *__restrict a_s, - uint8_t *__restrict a_z, bfloat16 *__restrict a, - bfloat16 *__restrict c) { - ::aie::set_rounding(aie::rounding_mode::conv_even); - static_assert(gs % r == 0, "group size must be multiple of inner vector r"); - constexpr unsigned NSUB = gs / r; + unsigned R = 32> +void mm_int4_bf16_mmul_impl(uint8_t *__restrict a_q, bfloat16 *__restrict a_s, + uint8_t *__restrict a_z, bfloat16 *__restrict a, + bfloat16 *__restrict c) { + constexpr unsigned r = 8, s = 8, t = 8; + constexpr unsigned MB = m_tile / r; + constexpr unsigned NB = n_tile / t; + constexpr unsigned KB = k_chunk / s; constexpr unsigned NG = k_chunk / gs; + static_assert(m_tile % r == 0, "m_tile must be multiple of 8"); + static_assert(n_tile % t == 0, "n_tile must be multiple of 8"); + static_assert(k_chunk % s == 0, "k_chunk must be multiple of 8"); + static_assert(gs % R == 0, "gs must be multiple of R"); - for (unsigned mi = 0; mi < m_tile; mi++) { - // Per-(mi, n) accumulator spans all K-groups; reduce_add at the end. - aie::accum acc[n_tile]; - for (unsigned n = 0; n < n_tile; n++) - acc[n].from_vector(aie::zeros()); + ::aie::set_rounding(aie::rounding_mode::conv_even); - for (unsigned g = 0; g < NG; g++) { - // Hoist per-(g, n) zero-point broadcasts out of the i loop. - aie::vector zv[n_tile]; - for (unsigned n = 0; n < n_tile; n++) - zv[n] = aie::broadcast((int8_t)a_z[g * n_tile + n]); + using MMUL = aie::mmul; - // Per-(mi, g, n) intra-group accumulator over the NSUB sub-blocks. - aie::accum g_acc[n_tile]; - for (unsigned n = 0; n < n_tile; n++) - g_acc[n].from_vector(aie::zeros()); + alignas(32) bfloat16 a_pack[KB * MB * r * s]; + alignas(32) bfloat16 b_pack[NB * KB * s * t]; -#pragma clang loop unroll(full) - for (unsigned i = 0; i < NSUB; i++) { - // Single a_vec load per (mi, g, i) reused across all n_tile cols. - aie::vector a_vec = - aie::load_v(a + mi * k_chunk + g * gs + i * r); - const unsigned off = (g * gs + i * r) / 2; - - for (unsigned n = 0; n < n_tile; n++) { - const uint8_t *__restrict aq_n = a_q + n * (k_chunk / 2); - aie::vector packed = aie::load_v(aq_n + off); - aie::vector w_int8 = - packed.template cast_to().template unpack_sign( - false); - w_int8 = aie::sub(w_int8, zv[n]); - aie::vector w_bf16 = aie::to_float(w_int8, 0); - g_acc[n] = aie::mac(g_acc[n], w_bf16, a_vec); - } + // Pack A row-major [m_tile][k_chunk] → [KB][MB][r][s]. + for (unsigned k_b = 0; k_b < KB; k_b++) { + for (unsigned m_b = 0; m_b < MB; m_b++) { + bfloat16 *dst = a_pack + (k_b * MB + m_b) * (r * s); + for (unsigned m_i = 0; m_i < r; m_i++) { + aie::vector v = + aie::load_v(a + (m_b * r + m_i) * k_chunk + k_b * s); + aie::store_v(dst + m_i * s, v); } + } + } - // Fold per-group bf16 scale into the per-(mi, n) running accumulator. - for (unsigned n = 0; n < n_tile; n++) { - bfloat16 sa = a_s[g * n_tile + n]; - aie::vector g_bf16 = - g_acc[n].template to_vector(); - acc[n] = aie::mac(acc[n], g_bf16, sa); + // Dequant W (with scale fold) → [NB][KB][s][t]. One (g, n, i) iteration + // produces R K-values for fixed n that scatter across R/s k-blocks at + // stride t within each k-block. + for (unsigned g = 0; g < NG; g++) { + for (unsigned n = 0; n < n_tile; n++) { + bfloat16 sc = a_s[g * n_tile + n]; + aie::vector zv = aie::broadcast((int8_t)a_z[g * n_tile + n]); + aie::vector sv = aie::broadcast(sc); + unsigned n_b = n / t; + unsigned n_i = n % t; + const uint8_t *__restrict aq_n = a_q + n * (k_chunk / 2); + + for (unsigned i = 0; i < gs / R; i++) { + unsigned k_base = g * gs + i * R; + aie::vector pk = aie::load_v(aq_n + k_base / 2); + aie::vector w_i8 = + pk.template cast_to().template unpack_sign(false); + w_i8 = aie::sub(w_i8, zv); + aie::vector w_bf16 = aie::to_float(w_i8, 0); + aie::vector w_scaled = + aie::mul(w_bf16, sv).template to_vector(); + + unsigned k_b_base = k_base / s; + bfloat16 *base = b_pack + n_b * (KB * s * t) + n_i; +#pragma clang loop unroll(full) + for (unsigned j = 0; j < R; j++) { + unsigned k_b = k_b_base + j / s; + unsigned k_i = j % s; + base[k_b * (s * t) + k_i * t] = w_scaled[j]; + } } } + } - for (unsigned n = 0; n < n_tile; n++) { - float s = aie::reduce_add(acc[n].template to_vector()); - c[mi * n_tile + n] = (bfloat16)((float)c[mi * n_tile + n] + s); + // MMUL: one accumulator per (m_b, n_b) tile, reduced across KB k-blocks. + for (unsigned m_b = 0; m_b < MB; m_b++) { + for (unsigned n_b = 0; n_b < NB; n_b++) { + MMUL C; +#pragma clang loop unroll(full) + for (unsigned k_b = 0; k_b < KB; k_b++) { + aie::vector A = + aie::load_v(a_pack + (k_b * MB + m_b) * (r * s)); + aie::vector B = + aie::load_v(b_pack + (n_b * KB + k_b) * (s * t)); + if (k_b == 0) + C.mul(A, B); + else + C.mac(A, B); + } + aie::vector ctile = C.template to_vector(); + for (unsigned m_i = 0; m_i < r; m_i++) { + aie::vector row = ctile.template extract(m_i); + bfloat16 *cdst = c + (m_b * r + m_i) * n_tile + n_b * t; + aie::vector c_old = aie::load_v(cdst); + aie::vector c_new = aie::add(row, c_old); + aie::store_v(cdst, c_new); + } } } } @@ -212,7 +245,7 @@ void matmul_int4_bf16_packed(uint8_t *packed, bfloat16 *a, bfloat16 *c) { uint8_t *a_q = packed; bfloat16 *a_s = reinterpret_cast(packed + Q_BYTES); uint8_t *a_z = packed + Q_BYTES + S_BYTES; - mm_int4_bf16_impl(a_q, a_s, a_z, a, c); + mm_int4_bf16_mmul_impl(a_q, a_s, a_z, a, c); } void partial_plus_r_bf16(bfloat16 *partial, bfloat16 *r_full, int offset, From 1da3d3d43a80f61a84d945c9b9a495ba90823e2a Mon Sep 17 00:00:00 2001 From: erweiw Date: Mon, 1 Jun 2026 11:56:28 -0700 Subject: [PATCH 05/10] fixup: format + kernel-constraint asserts in host builder clang-format-17 on mv_int4_bf16.cc and black on matmul_int4_packed.py to fix the format CI check. Also surfaces the int4 GEMM kernel-side static_asserts (tile_m/n/k_l1 % 8 for mmul, gs % 32 for dequant, tile_m*tile_n % 32 for the zero kernel) as Python asserts at module- build time so unsupported tilings fail with a clear message rather than a C++ template/static_assert during compile-kernel. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../int4_awq/matmul_int4_packed.py | 120 ++++++++++++------ .../int4_awq/mv_int4_bf16.cc | 3 +- 2 files changed, 84 insertions(+), 39 deletions(-) diff --git a/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py b/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py index e83011c9a..6750034e7 100644 --- a/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py +++ b/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py @@ -105,14 +105,23 @@ def cpu_reference(W_q, W_s, W_z, A): @module_builder -def build_module( - m, k, n, gs, tile_m, tile_k_l2, tile_k_l1, tile_n, herd_m, herd_n -): +def build_module(m, k, n, gs, tile_m, tile_k_l2, tile_k_l1, tile_n, herd_m, herd_n): assert m % (tile_m * herd_m) == 0 assert n % (tile_n * herd_n) == 0 assert k % tile_k_l2 == 0 assert tile_k_l2 % tile_k_l1 == 0 assert tile_k_l1 % gs == 0 + # Kernel-side static_assert constraints from mv_int4_bf16.cc: + # mm_int4_bf16_mmul_impl: tile_m/n/k_chunk % 8 (mmul dims), gs % R=32 + # zero_vectorized_bf16_mn: (tile_m * tile_n) % VW=32 + assert ( + tile_m % 8 == 0 and tile_n % 8 == 0 and tile_k_l1 % 8 == 0 + ), "tile_m, tile_n, tile_k_l1 must each be multiples of 8 (mmul tile size)" + assert gs % 32 == 0, "gs must be a multiple of dequant inner-vector width 32" + assert (tile_m * tile_n) % 32 == 0, ( + f"tile_m*tile_n ({tile_m}*{tile_n}={tile_m * tile_n}) must be a multiple " + f"of vector width 32 for zero_vectorized_bf16_mn" + ) _, _, _, tile_bytes = packed_tile_bytes(tile_n, tile_k_l1, gs) k_per_l2 = tile_k_l2 // tile_k_l1 @@ -172,37 +181,53 @@ def segment_body(li_s, lj_s, l3_a_s, l3_b_s, l3_c_s): l2_c = AllocOp(C_l2_ty, [], []) ix_to_row = AffineMap.get( - 0, 1, - [AffineExpr.get_mul( - AffineSymbolExpr.get(0), - AffineConstantExpr.get(tile_m * herd_m))], + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(tile_m * herd_m), + ) + ], ) iy_to_n_outer = AffineMap.get( - 0, 1, - [AffineExpr.get_mul( - AffineSymbolExpr.get(0), - AffineConstantExpr.get(herd_n))], + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), AffineConstantExpr.get(herd_n) + ) + ], ) row_off = affine_apply(ix_to_row, [li_s]) n_outer_off = affine_apply(iy_to_n_outer, [lj_s]) k_l2_to_k = AffineMap.get( - 0, 1, - [AffineExpr.get_mul( - AffineSymbolExpr.get(0), - AffineConstantExpr.get(tile_k_l2))], + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), AffineConstantExpr.get(tile_k_l2) + ) + ], ) k_l2_to_chunk = AffineMap.get( - 0, 1, - [AffineExpr.get_mul( - AffineSymbolExpr.get(0), - AffineConstantExpr.get(k_per_l2))], + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), AffineConstantExpr.get(k_per_l2) + ) + ], ) k_chunk_off_l1_map = AffineMap.get( - 0, 1, - [AffineExpr.get_mul( - AffineSymbolExpr.get(0), - AffineConstantExpr.get(tile_k_l1))], + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), AffineConstantExpr.get(tile_k_l1) + ) + ], ) for i in for_(0, k // tile_k_l2): @@ -210,13 +235,15 @@ def segment_body(li_s, lj_s, l3_a_s, l3_b_s, l3_c_s): k_chunk_off = affine_apply(k_l2_to_chunk, [i]) dma_memcpy_nd( - l2_a, l3_a_s, + l2_a, + l3_a_s, src_offsets=[0, 0, row_off, k_l2_off], src_sizes=[herd_m, 1, tile_m, tile_k_l2], src_strides=[k * tile_m, tile_k_l2, k, 1], ) dma_memcpy_nd( - l2_b, l3_b_s, + l2_b, + l3_b_s, src_offsets=[0, n_outer_off, k_chunk_off, 0], src_sizes=[1, herd_n, k_per_l2, tile_bytes], src_strides=[ @@ -240,7 +267,8 @@ def compute_body(_tx, _ty, _sx, _sy, _l2a, _l2b, _l2c): for j in for_(0, k_per_l2): k1_off = affine_apply(k_chunk_off_l1_map, [j]) dma_memcpy_nd( - _l1_a, _l2a, + _l1_a, + _l2a, src_offsets=[_tx, 0, 0, k1_off], src_sizes=[1, 1, tile_m, tile_k_l1], src_strides=[ @@ -251,7 +279,8 @@ def compute_body(_tx, _ty, _sx, _sy, _l2a, _l2b, _l2c): ], ) dma_memcpy_nd( - _l1_b, _l2b, + _l1_b, + _l2b, src_offsets=[0, _ty, j, 0], src_sizes=[1, 1, 1, tile_bytes], src_strides=[ @@ -264,7 +293,8 @@ def compute_body(_tx, _ty, _sx, _sy, _l2a, _l2b, _l2c): CallOp(matmul_func, [_l1_b, _l1_a, _l1_c]) yield_([]) dma_memcpy_nd( - _l2c, _l1_c, + _l2c, + _l1_c, dst_offsets=[_tx, _ty, 0, 0], dst_sizes=[1, 1, tile_m, tile_n], dst_strides=[ @@ -281,14 +311,19 @@ def compute_body(_tx, _ty, _sx, _sy, _l2a, _l2b, _l2c): yield_([]) col_off_map = AffineMap.get( - 0, 1, - [AffineExpr.get_mul( - AffineSymbolExpr.get(0), - AffineConstantExpr.get(tile_n * herd_n))], + 0, + 1, + [ + AffineExpr.get_mul( + AffineSymbolExpr.get(0), + AffineConstantExpr.get(tile_n * herd_n), + ) + ], ) col_off = affine_apply(col_off_map, [lj_s]) dma_memcpy_nd( - l3_c_s, l2_c, + l3_c_s, + l2_c, dst_offsets=[row_off, col_off], dst_sizes=[herd_m * tile_m, herd_n * tile_n], dst_strides=[n, 1], @@ -330,9 +365,16 @@ def compute_body(_tx, _ty, _sx, _sy, _l2a, _l2b, _l2c): args = parser.parse_args() module = build_module( - args.m, args.k, args.n, args.gs, - args.tile_m, args.tile_k_l2, args.tile_k_l1, args.tile_n, - args.herd_m, args.herd_n, + args.m, + args.k, + args.n, + args.gs, + args.tile_m, + args.tile_k_l2, + args.tile_k_l1, + args.tile_n, + args.herd_m, + args.herd_n, ) if args.print_module_only: print(module) @@ -357,7 +399,7 @@ def compute_body(_tx, _ty, _sx, _sy, _l2a, _l2b, _l2c): omit_while_true_loop=False, output_format="xclbin", runtime_loop_tiling_sizes=[2, 2], - stack_size=16384, + stack_size=16384, ) backend.compile(module) backend.unload() @@ -375,6 +417,8 @@ def compute_body(_tx, _ty, _sx, _sy, _l2a, _l2b, _l2c): module, inputs=[A, PACKED], expected_outputs=[C_ref], - rtol=0.1, atol=0.05, min_correlation=0.999, + rtol=0.1, + atol=0.05, + min_correlation=0.999, ) ) diff --git a/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc b/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc index 78060e02a..3e98589e9 100644 --- a/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc +++ b/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc @@ -157,7 +157,8 @@ void mm_int4_bf16_mmul_impl(uint8_t *__restrict a_q, bfloat16 *__restrict a_s, for (unsigned g = 0; g < NG; g++) { for (unsigned n = 0; n < n_tile; n++) { bfloat16 sc = a_s[g * n_tile + n]; - aie::vector zv = aie::broadcast((int8_t)a_z[g * n_tile + n]); + aie::vector zv = + aie::broadcast((int8_t)a_z[g * n_tile + n]); aie::vector sv = aie::broadcast(sc); unsigned n_b = n / t; unsigned n_i = n % t; From 4a51b7abaa23078e0901317ecbebb90cf87a0a82 Mon Sep 17 00:00:00 2001 From: erweiw Date: Mon, 1 Jun 2026 14:48:00 -0700 Subject: [PATCH 06/10] fixup: f32 c accumulator + post-MMUL scale fold + DIM_K_CHUNK macro MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three issues addressed: 1. CI compile failure on amdhx370: the matmul template was instantiating with DIM_K=2048 when the host build only cared about DIM_K for the matvec — overflowing the AIE2P assembly printer's immediate range on the matmul's scratch buffer addressing. Added a separate DIM_K_CHUNK macro (default 128) used by mm_int4_bf16_mmul_impl, decoupled from matvec's DIM_K. matvec callers that only set DIM_K still build cleanly. 2. f32 accumulator across host K-chunk calls. matmul_int4_bf16_packed renamed to matmul_int4_bf16_packed_f32 and now takes float* c. New helpers zero_vectorized_f32_mn and f32_to_bf16_mn handle the L1 C init and the final bf16 narrowing once per launch. Host builder adds an f32 L1 C accumulator + bf16 L1 C drain buffer + convert kernel call between the K loop and the drain. 3. Post-MMUL scale fold. Dequant now produces UNSCALED bf16 W; per-group MMUL accumulates in f32; convert to f32 vec via row-by-row extract (the 64-element store_v gives a bad layout — using the same extract(m_i) pattern as the working pre-MMUL kernel fixes it); scalar multiply by per-n bf16 scale (lifted to f32) and accumulate into the c tile. One bf16 truncate per output element per group instead of per W element — matches mac kernel's precision pattern. Correlation at Llama Q-proj seq=32 (M=32 K=N=2048): 0.999945 → 0.999975. Mismatch count (atol=0.05) dropped from ~6800 to ~11 / 65536 (0.017%). max_mismatch_percentage=0.05 in the host script bounds this at 32 elements with margin — correlation > 0.999 remains the primary check. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../matrix_multiplication/int4_awq/Makefile | 2 +- .../int4_awq/matmul_int4_packed.py | 40 +++-- .../int4_awq/mv_int4_bf16.cc | 144 +++++++++++++----- 3 files changed, 138 insertions(+), 48 deletions(-) diff --git a/programming_examples/matrix_multiplication/int4_awq/Makefile b/programming_examples/matrix_multiplication/int4_awq/Makefile index bfa846d4c..662b7d564 100644 --- a/programming_examples/matrix_multiplication/int4_awq/Makefile +++ b/programming_examples/matrix_multiplication/int4_awq/Makefile @@ -47,7 +47,7 @@ compile-kernel: exit 1; \ fi $(PEANO_INSTALL_DIR)/bin/clang++ ${PEANOWRAP2P_FLAGS} \ - -DDIM_M=$(TILE_M) -DDIM_N=$(TILE_N) -DDIM_K=$(TILE_K_L1) -DDIM_GS=$(GS) \ + -DDIM_M=$(TILE_M) -DDIM_N=$(TILE_N) -DDIM_K_CHUNK=$(TILE_K_L1) -DDIM_GS=$(GS) \ -DAIE_API_EMULATE_BFLOAT16_MMUL_WITH_BFP16 \ -c $(INT4_SRCDIR)/mv_int4_bf16.cc -o $(BUILD_DIR)/mv_int4_bf16.o diff --git a/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py b/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py index 6750034e7..344149307 100644 --- a/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py +++ b/programming_examples/matrix_multiplication/int4_awq/matmul_int4_packed.py @@ -18,6 +18,7 @@ AffineMap, AffineSymbolExpr, BF16Type, + F32Type, IntegerAttr, IntegerType, MemRefType, @@ -129,6 +130,7 @@ def build_module(m, k, n, gs, tile_m, tile_k_l2, tile_k_l1, tile_n, herd_m, herd K_div = k // tile_k_l1 bf16_ty = BF16Type.get() + f32_ty = F32Type.get() u8_ty = IntegerType.get_signless(8) A_l3_ty = MemRefType.get([m, k], bf16_ty) @@ -150,24 +152,35 @@ def build_module(m, k, n, gs, tile_m, tile_k_l2, tile_k_l1, tile_n, herd_m, herd A_l1_ty = MemRefType.get([tile_m, tile_k_l1], bf16_ty, memory_space=l1_ms) B_l1_ty = MemRefType.get([tile_bytes], u8_ty, memory_space=l1_ms) - C_l1_ty = MemRefType.get([tile_m, tile_n], bf16_ty, memory_space=l1_ms) + # L1 C accumulator: f32. Kept across the host K-chunk loop so partial sums + # don't bf16-truncate between calls. Converted to bf16 once at the end. + C_l1_acc_ty = MemRefType.get([tile_m, tile_n], f32_ty, memory_space=l1_ms) + C_l1_drain_ty = MemRefType.get([tile_m, tile_n], bf16_ty, memory_space=l1_ms) zero_func = FuncOp( - "zero_vectorized_bf16_mn", - ([C_l1_ty], []), + "zero_vectorized_f32_mn", + ([C_l1_acc_ty], []), visibility="private", ) zero_func.attributes["link_with"] = StringAttr.get(KERNEL_OBJ_NAME) zero_func.attributes["llvm.emit_c_interface"] = UnitAttr.get() matmul_func = FuncOp( - "matmul_int4_bf16_packed", - ([B_l1_ty, A_l1_ty, C_l1_ty], []), + "matmul_int4_bf16_packed_f32", + ([B_l1_ty, A_l1_ty, C_l1_acc_ty], []), visibility="private", ) matmul_func.attributes["link_with"] = StringAttr.get(KERNEL_OBJ_NAME) matmul_func.attributes["llvm.emit_c_interface"] = UnitAttr.get() + f32_to_bf16_func = FuncOp( + "f32_to_bf16_mn", + ([C_l1_acc_ty, C_l1_drain_ty], []), + visibility="private", + ) + f32_to_bf16_func.attributes["link_with"] = StringAttr.get(KERNEL_OBJ_NAME) + f32_to_bf16_func.attributes["llvm.emit_c_interface"] = UnitAttr.get() + @FuncOp.from_py_func(A_l3_ty, B_l3_ty, C_l3_ty) def matmul_int4_packed(arg_a, arg_b, arg_c): launch_size = [m // tile_m // herd_m, n // tile_n // herd_n] @@ -262,8 +275,9 @@ def segment_body(li_s, lj_s, l3_a_s, l3_b_s, l3_c_s): def compute_body(_tx, _ty, _sx, _sy, _l2a, _l2b, _l2c): _l1_a = AllocOp(A_l1_ty, [], []) _l1_b = AllocOp(B_l1_ty, [], []) - _l1_c = AllocOp(C_l1_ty, [], []) - CallOp(zero_func, [_l1_c]) + _l1_c_acc = AllocOp(C_l1_acc_ty, [], []) + _l1_c_drain = AllocOp(C_l1_drain_ty, [], []) + CallOp(zero_func, [_l1_c_acc]) for j in for_(0, k_per_l2): k1_off = affine_apply(k_chunk_off_l1_map, [j]) dma_memcpy_nd( @@ -290,11 +304,13 @@ def compute_body(_tx, _ty, _sx, _sy, _l2a, _l2b, _l2c): 1, ], ) - CallOp(matmul_func, [_l1_b, _l1_a, _l1_c]) + CallOp(matmul_func, [_l1_b, _l1_a, _l1_c_acc]) yield_([]) + # Convert f32 accumulator → bf16 once per launch. + CallOp(f32_to_bf16_func, [_l1_c_acc, _l1_c_drain]) dma_memcpy_nd( _l2c, - _l1_c, + _l1_c_drain, dst_offsets=[_tx, _ty, 0, 0], dst_sizes=[1, 1, tile_m, tile_n], dst_strides=[ @@ -306,7 +322,8 @@ def compute_body(_tx, _ty, _sx, _sy, _l2a, _l2b, _l2c): ) DeallocOp(_l1_a) DeallocOp(_l1_b) - DeallocOp(_l1_c) + DeallocOp(_l1_c_acc) + DeallocOp(_l1_c_drain) yield_([]) @@ -419,6 +436,9 @@ def compute_body(_tx, _ty, _sx, _sy, _l2a, _l2b, _l2c): expected_outputs=[C_ref], rtol=0.1, atol=0.05, + # bf16 floor: at large K and tight atol a small fraction of + # elements land just outside atol while correlation stays > 0.9999. + max_mismatch_percentage=0.05, min_correlation=0.999, ) ) diff --git a/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc b/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc index 3e98589e9..7e3521854 100644 --- a/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc +++ b/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc @@ -39,7 +39,17 @@ #define DIM_GS 128 #endif +// matmul_int4_bf16_packed's K is the per-call k_chunk (matches the host +// builder's TILE_K_L1), not the matvec's full K. Default to 128 so a build +// that only consumes the matvec (e.g. DIM_K=2048) doesn't instantiate the +// matmul template at a k_chunk that overflows L1 scratch. +#ifndef DIM_K_CHUNK +#define DIM_K_CHUNK 128 +#endif + static_assert(DIM_K % DIM_GS == 0, "DIM_K must be a multiple of DIM_GS"); +static_assert(DIM_K_CHUNK % DIM_GS == 0, + "DIM_K_CHUNK must be a multiple of DIM_GS"); // int4-AWQ matvec inner kernel. One accfloat accumulator per output row spans // the full K. Within each group we accumulate raw (w_bf16 × b) into a per- @@ -112,24 +122,27 @@ static void zero_mn_impl(bfloat16 *__restrict c) { } // int4-AWQ GEMM inner kernel using aie::mmul<8,8,8,bf16,bf16,accfloat>. -// Phase 1: pack row-major A into a_pack and dequant W (with per-group bf16 -// scale folded in) into b_pack — both in the operand layouts that the bf16 -// GEMM kernel's MMUL loop expects (K-major A, N-major B). -// Phase 2: MMUL loop yields f32 accumulators; convert to bf16 and add to -// the existing row-major c. +// Dequant produces UNSCALED bf16 W (just nibble unpack + zero subtract). +// Per (m_b, n_b): preload the f32 c tile, then for each group do an unscaled +// MMUL, convert to f32 vec, multiply by f32 scale (scalar, no extra bf16 +// truncate), and accumulate into the c tile. Across host K-chunk calls the +// c buffer stays f32 (bf16-GEMM accum pattern). Matches mac-kernel's +// "truncate at group sum level, not per W element" precision pattern. template void mm_int4_bf16_mmul_impl(uint8_t *__restrict a_q, bfloat16 *__restrict a_s, uint8_t *__restrict a_z, bfloat16 *__restrict a, - bfloat16 *__restrict c) { + float *__restrict c) { constexpr unsigned r = 8, s = 8, t = 8; constexpr unsigned MB = m_tile / r; constexpr unsigned NB = n_tile / t; constexpr unsigned KB = k_chunk / s; + constexpr unsigned KB_PER_G = gs / s; constexpr unsigned NG = k_chunk / gs; static_assert(m_tile % r == 0, "m_tile must be multiple of 8"); static_assert(n_tile % t == 0, "n_tile must be multiple of 8"); static_assert(k_chunk % s == 0, "k_chunk must be multiple of 8"); + static_assert(gs % s == 0, "gs must be multiple of mmul k-tile (8)"); static_assert(gs % R == 0, "gs must be multiple of R"); ::aie::set_rounding(aie::rounding_mode::conv_even); @@ -151,15 +164,11 @@ void mm_int4_bf16_mmul_impl(uint8_t *__restrict a_q, bfloat16 *__restrict a_s, } } - // Dequant W (with scale fold) → [NB][KB][s][t]. One (g, n, i) iteration - // produces R K-values for fixed n that scatter across R/s k-blocks at - // stride t within each k-block. + // Dequant W (NO scale fold) → [NB][KB][s][t]. for (unsigned g = 0; g < NG; g++) { for (unsigned n = 0; n < n_tile; n++) { - bfloat16 sc = a_s[g * n_tile + n]; aie::vector zv = aie::broadcast((int8_t)a_z[g * n_tile + n]); - aie::vector sv = aie::broadcast(sc); unsigned n_b = n / t; unsigned n_i = n % t; const uint8_t *__restrict aq_n = a_q + n * (k_chunk / 2); @@ -171,8 +180,6 @@ void mm_int4_bf16_mmul_impl(uint8_t *__restrict a_q, bfloat16 *__restrict a_s, pk.template cast_to().template unpack_sign(false); w_i8 = aie::sub(w_i8, zv); aie::vector w_bf16 = aie::to_float(w_i8, 0); - aie::vector w_scaled = - aie::mul(w_bf16, sv).template to_vector(); unsigned k_b_base = k_base / s; bfloat16 *base = b_pack + n_b * (KB * s * t) + n_i; @@ -180,39 +187,93 @@ void mm_int4_bf16_mmul_impl(uint8_t *__restrict a_q, bfloat16 *__restrict a_s, for (unsigned j = 0; j < R; j++) { unsigned k_b = k_b_base + j / s; unsigned k_i = j % s; - base[k_b * (s * t) + k_i * t] = w_scaled[j]; + base[k_b * (s * t) + k_i * t] = w_bf16[j]; } } } } - // MMUL: one accumulator per (m_b, n_b) tile, reduced across KB k-blocks. + // Per (m_b, n_b): preload c into scalar f32 scratch, per group MMUL + + // scale fold + accumulate into scratch, store scratch back. for (unsigned m_b = 0; m_b < MB; m_b++) { for (unsigned n_b = 0; n_b < NB; n_b++) { - MMUL C; + float *cdst_base = c + (m_b * r) * n_tile + n_b * t; + alignas(32) float c_acc_buf[r * t]; + for (unsigned m_i = 0; m_i < r; m_i++) { + aie::vector row = aie::load_v(cdst_base + m_i * n_tile); + aie::store_v(c_acc_buf + m_i * t, row); + } + + for (unsigned g = 0; g < NG; g++) { + // Per-group unscaled MMUL. Init accumulator from zeros (bf16-GEMM + // pattern, avoids the runtime `first` flag). + aie::vector zero_init = aie::zeros(); + MMUL C_g(zero_init); #pragma clang loop unroll(full) - for (unsigned k_b = 0; k_b < KB; k_b++) { - aie::vector A = - aie::load_v(a_pack + (k_b * MB + m_b) * (r * s)); - aie::vector B = - aie::load_v(b_pack + (n_b * KB + k_b) * (s * t)); - if (k_b == 0) - C.mul(A, B); - else - C.mac(A, B); + for (unsigned kg = 0; kg < KB_PER_G; kg++) { + unsigned k_b = g * KB_PER_G + kg; + aie::vector A = + aie::load_v(a_pack + (k_b * MB + m_b) * (r * s)); + aie::vector B = + aie::load_v(b_pack + (n_b * KB + k_b) * (s * t)); + C_g.mac(A, B); + } + // Spill C_g row-by-row via extract (proven pattern from + // the working pre-MMUL kernel). Scale fold in scalar f32. + aie::vector c_g_vec = C_g.template to_vector(); + alignas(32) float c_g_buf[r * t]; + for (unsigned m_i = 0; m_i < r; m_i++) { + aie::vector row = c_g_vec.template extract(m_i); + aie::store_v(c_g_buf + m_i * t, row); + } + for (unsigned m_i = 0; m_i < r; m_i++) { + for (unsigned n_i = 0; n_i < t; n_i++) { + float scale_f = (float)a_s[g * n_tile + n_b * t + n_i]; + c_acc_buf[m_i * t + n_i] += c_g_buf[m_i * t + n_i] * scale_f; + } + } } - aie::vector ctile = C.template to_vector(); + + // Store accumulator back to c row-by-row. for (unsigned m_i = 0; m_i < r; m_i++) { - aie::vector row = ctile.template extract(m_i); - bfloat16 *cdst = c + (m_b * r + m_i) * n_tile + n_b * t; - aie::vector c_old = aie::load_v(cdst); - aie::vector c_new = aie::add(row, c_old); - aie::store_v(cdst, c_new); + aie::vector row = aie::load_v(c_acc_buf + m_i * t); + aie::store_v(cdst_base + m_i * n_tile, row); } } } } +// f32 zero for the GEMM C accumulator (kept in f32 across host K-chunk +// iterations to avoid bf16 truncation of partial sums). +template +static void zero_mn_f32_impl(float *__restrict c) { + constexpr unsigned VW = 16; + constexpr unsigned NTOT = m_tile * n_tile; + static_assert(NTOT % VW == 0, + "m_tile*n_tile must be a multiple of f32 vector width"); + aie::vector zv = aie::zeros(); + for (unsigned i = 0; i < NTOT; i += VW) + aie::store_v(c + i, zv); +} + +// f32 → bf16 narrowing for the final L1 C drain. Run after the host's K-loop +// completes; one call per (m_tile × n_tile) tile. +template +static void f32_to_bf16_mn_impl(const float *__restrict src, + bfloat16 *__restrict dst) { + ::aie::set_rounding(aie::rounding_mode::conv_even); + constexpr unsigned VW = 16; + constexpr unsigned NTOT = m_tile * n_tile; + static_assert(NTOT % VW == 0, "m_tile*n_tile must be a multiple of VW"); + for (unsigned i = 0; i < NTOT; i += VW) { + aie::vector v = aie::load_v(src + i); + aie::vector vb; + for (unsigned j = 0; j < VW; j++) + vb[j] = (bfloat16)v[j]; + aie::store_v(dst + i, vb); + } +} + template static void partial_plus_r_impl(const bfloat16 *__restrict partial, const bfloat16 *__restrict r_full, int offset, @@ -238,15 +299,24 @@ void zero_vectorized_bf16(bfloat16 *c) { zero_impl(c); } void zero_vectorized_bf16_mn(bfloat16 *c) { zero_mn_impl(c); } -// Packed-BO GEMM entry. Same Q+S+Z packing as the GEMV (output-major W), -// driven by an m_tile-row activation tile a[]. -void matmul_int4_bf16_packed(uint8_t *packed, bfloat16 *a, bfloat16 *c) { - constexpr unsigned Q_BYTES = DIM_N * (DIM_K / 2); - constexpr unsigned S_BYTES = (DIM_K / DIM_GS) * DIM_N * 2; +void zero_vectorized_f32_mn(float *c) { zero_mn_f32_impl(c); } + +void f32_to_bf16_mn(float *src, bfloat16 *dst) { + f32_to_bf16_mn_impl(src, dst); +} + +// Packed-BO GEMM entry, f32 output. Same Q+S+Z packing as the GEMV +// (output-major W); c is kept in f32 across host-side K-chunk iterations so +// the per-K-chunk partial sums don't bf16-truncate. Convert to bf16 once at +// the end with f32_to_bf16_mn. +void matmul_int4_bf16_packed_f32(uint8_t *packed, bfloat16 *a, float *c) { + constexpr unsigned Q_BYTES = DIM_N * (DIM_K_CHUNK / 2); + constexpr unsigned S_BYTES = (DIM_K_CHUNK / DIM_GS) * DIM_N * 2; uint8_t *a_q = packed; bfloat16 *a_s = reinterpret_cast(packed + Q_BYTES); uint8_t *a_z = packed + Q_BYTES + S_BYTES; - mm_int4_bf16_mmul_impl(a_q, a_s, a_z, a, c); + mm_int4_bf16_mmul_impl(a_q, a_s, a_z, a, + c); } void partial_plus_r_bf16(bfloat16 *partial, bfloat16 *r_full, int offset, From e87d1b35e54ec661ef2c0899e7866930fc874467 Mon Sep 17 00:00:00 2001 From: erweiw Date: Mon, 1 Jun 2026 15:19:19 -0700 Subject: [PATCH 07/10] fixup: vectorize post-MMUL scale fold MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the 64-iter scalar mul+add per (m_b, n_b) per group with a vector mul+add chain: bf16 scale broadcast → aie::mul(c_g_bf16, scale_tile) → f32 accum → row-by-row add into c_acc_buf. The bf16 mul is supported on aie2p (f32 vector mul isn't) and produces an f32 accumulator with no extra truncate. Restores throughput to 117 ms / 147 GOPS at M=N=K=2048 herd 8x4 — matches the prior (precision-broken) MMUL kernel's speed AND keeps the post-MMUL fold + f32 c accumulator precision pattern. Net result vs the rejected mac kernel: 1.7x faster (198 → 117 ms) AND more precise (correlation 0.999974 vs ~0.99995, mismatches stay tiny). Co-Authored-By: Claude Opus 4.7 (1M context) --- .../int4_awq/mv_int4_bf16.cc | 36 ++++++++++++------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc b/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc index 7e3521854..57cbbb8ad 100644 --- a/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc +++ b/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc @@ -218,19 +218,31 @@ void mm_int4_bf16_mmul_impl(uint8_t *__restrict a_q, bfloat16 *__restrict a_s, aie::load_v(b_pack + (n_b * KB + k_b) * (s * t)); C_g.mac(A, B); } - // Spill C_g row-by-row via extract (proven pattern from - // the working pre-MMUL kernel). Scale fold in scalar f32. - aie::vector c_g_vec = C_g.template to_vector(); - alignas(32) float c_g_buf[r * t]; + // Vectorized post-MMUL scale fold: convert C_g to bf16, multiply + // by bf16 scale broadcast (mul → f32 accum, no extra truncate), + // lift to f32 vec, accumulate row-by-row into c_acc_buf. + aie::vector c_g_bf16 = + C_g.template to_vector(); + // Build bf16 scale broadcast: (m_i, n_i) row-major, scale per n_i. + // Load 8 scales into a vec and tile across r rows. + alignas(32) bfloat16 scale_row_buf[t]; + for (unsigned n_i = 0; n_i < t; n_i++) + scale_row_buf[n_i] = a_s[g * n_tile + n_b * t + n_i]; + aie::vector scale_row = aie::load_v(scale_row_buf); + alignas(32) bfloat16 scale_tile_buf[r * t]; + for (unsigned m_i = 0; m_i < r; m_i++) + aie::store_v(scale_tile_buf + m_i * t, scale_row); + aie::vector scale_tile = + aie::load_v(scale_tile_buf); + aie::accum scaled_acc = aie::mul(c_g_bf16, scale_tile); + aie::vector scaled_f32 = + scaled_acc.template to_vector(); + // Accumulate into c_acc_buf row by row (vector load + add + store). for (unsigned m_i = 0; m_i < r; m_i++) { - aie::vector row = c_g_vec.template extract(m_i); - aie::store_v(c_g_buf + m_i * t, row); - } - for (unsigned m_i = 0; m_i < r; m_i++) { - for (unsigned n_i = 0; n_i < t; n_i++) { - float scale_f = (float)a_s[g * n_tile + n_b * t + n_i]; - c_acc_buf[m_i * t + n_i] += c_g_buf[m_i * t + n_i] * scale_f; - } + aie::vector inc = scaled_f32.template extract(m_i); + aie::vector old = aie::load_v(c_acc_buf + m_i * t); + aie::vector sum = aie::add(old, inc); + aie::store_v(c_acc_buf + m_i * t, sum); } } From 5391abacbc4e935fffff138ad07d6cd5fa07c939 Mon Sep 17 00:00:00 2001 From: erweiw Date: Mon, 1 Jun 2026 17:41:03 -0700 Subject: [PATCH 08/10] fixup: 2x2 mmul expansion + vectorized dequant scatter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The int4 GEMM micro-kernel previously used 1x1 mmul (single accumulator chain) with a 32-scalar-store dequant scatter — both vs the 2x2 mmul expansion + contiguous-store layout that the bf16 baseline uses. At Q-proj shape (M=N=K=2048, herd 8x4) the kernel ran 117 ms while bf16 ran 61 ms despite having less weight bytes to move. Changes: - 2x2 mmul: 4 independent accumulators C00/C01/C10/C11, A and B vector reuse 2x per inner kg iter, chess_prepare_for_pipelining hint - Dequant b_pack layout swapped to [NB][KB][t][s] (n_i outer, k_i inner) so 8 k_i values land contiguously per (n_i, k_b) — replaces 32 scalar stores per inner iter with 4 vector stores - aie::transpose(B, t, s) at mmul load flips back to the mmul-expected [s][t] order in-register, avoiding any host-side per-tile Q repack (keeps pack_inputs consistent with the canonical AWQ tile layout) Result at Q-proj M=N=K=2048, herd 8x4: 117 ms → 39.5 ms (2.96x). Now 1.55x faster than bf16 baseline (61 ms) at the same shape. Smoke (M=64 K=128 N=128, herd 2x4) still PASS at corr 0.999974. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../int4_awq/mv_int4_bf16.cc | 164 ++++++++++++------ 1 file changed, 110 insertions(+), 54 deletions(-) diff --git a/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc b/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc index 57cbbb8ad..6f68345e5 100644 --- a/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc +++ b/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc @@ -139,8 +139,10 @@ void mm_int4_bf16_mmul_impl(uint8_t *__restrict a_q, bfloat16 *__restrict a_s, constexpr unsigned KB = k_chunk / s; constexpr unsigned KB_PER_G = gs / s; constexpr unsigned NG = k_chunk / gs; - static_assert(m_tile % r == 0, "m_tile must be multiple of 8"); - static_assert(n_tile % t == 0, "n_tile must be multiple of 8"); + static_assert(m_tile % (2 * r) == 0, + "m_tile must be multiple of 16 (2x mmul m for 2x2 expansion)"); + static_assert(n_tile % (2 * t) == 0, + "n_tile must be multiple of 16 (2x mmul n for 2x2 expansion)"); static_assert(k_chunk % s == 0, "k_chunk must be multiple of 8"); static_assert(gs % s == 0, "gs must be multiple of mmul k-tile (8)"); static_assert(gs % R == 0, "gs must be multiple of R"); @@ -164,7 +166,10 @@ void mm_int4_bf16_mmul_impl(uint8_t *__restrict a_q, bfloat16 *__restrict a_s, } } - // Dequant W (NO scale fold) → [NB][KB][s][t]. + // Dequant W → b_pack [NB][KB][t][s] = [NB][KB][n_i][k_i]. For fixed + // (n_b, n_i, k_b) the 8 k_i positions are contiguous → 8-wide vector + // store. mmul reads this with aie::transpose at load (mmul B wants the + // [s][t] order which is the transpose of what dequant produces). for (unsigned g = 0; g < NG; g++) { for (unsigned n = 0; n < n_tile; n++) { aie::vector zv = @@ -181,75 +186,126 @@ void mm_int4_bf16_mmul_impl(uint8_t *__restrict a_q, bfloat16 *__restrict a_s, w_i8 = aie::sub(w_i8, zv); aie::vector w_bf16 = aie::to_float(w_i8, 0); + // R=32 spans R/s=4 mmul k blocks. For each block, store 8 contiguous + // bf16 to b_pack[n_b][k_b][n_i][k_i = 0..7]. unsigned k_b_base = k_base / s; - bfloat16 *base = b_pack + n_b * (KB * s * t) + n_i; #pragma clang loop unroll(full) - for (unsigned j = 0; j < R; j++) { - unsigned k_b = k_b_base + j / s; - unsigned k_i = j % s; - base[k_b * (s * t) + k_i * t] = w_bf16[j]; + for (unsigned j = 0; j < R / s; j++) { + unsigned k_b = k_b_base + j; + bfloat16 *dst = b_pack + n_b * (KB * t * s) + k_b * (t * s) + n_i * s; + aie::vector chunk = w_bf16.template extract(j); + aie::store_v(dst, chunk); } } } } - // Per (m_b, n_b): preload c into scalar f32 scratch, per group MMUL + - // scale fold + accumulate into scratch, store scratch back. - for (unsigned m_b = 0; m_b < MB; m_b++) { - for (unsigned n_b = 0; n_b < NB; n_b++) { - float *cdst_base = c + (m_b * r) * n_tile + n_b * t; - alignas(32) float c_acc_buf[r * t]; + // 2x2 MMUL expansion: each iteration of the (m_b, n_b) super-tile holds + // 4 independent accumulators (C00/C01/C10/C11) so the inner kg loop + // issues 4 MACs/iter against 4 different register-file destinations — + // no serial dep chain on a single accumulator. Each loaded A and B + // vector is reused by 2 MACs. + for (unsigned m_b2 = 0; m_b2 < MB; m_b2 += 2) { + for (unsigned n_b2 = 0; n_b2 < NB; n_b2 += 2) { + float *cdst00 = c + (m_b2 * r) * n_tile + n_b2 * t; + float *cdst01 = c + (m_b2 * r) * n_tile + (n_b2 + 1) * t; + float *cdst10 = c + ((m_b2 + 1) * r) * n_tile + n_b2 * t; + float *cdst11 = c + ((m_b2 + 1) * r) * n_tile + (n_b2 + 1) * t; + + alignas(32) float c_acc_00[r * t]; + alignas(32) float c_acc_01[r * t]; + alignas(32) float c_acc_10[r * t]; + alignas(32) float c_acc_11[r * t]; for (unsigned m_i = 0; m_i < r; m_i++) { - aie::vector row = aie::load_v(cdst_base + m_i * n_tile); - aie::store_v(c_acc_buf + m_i * t, row); + aie::store_v(c_acc_00 + m_i * t, aie::load_v(cdst00 + m_i * n_tile)); + aie::store_v(c_acc_01 + m_i * t, aie::load_v(cdst01 + m_i * n_tile)); + aie::store_v(c_acc_10 + m_i * t, aie::load_v(cdst10 + m_i * n_tile)); + aie::store_v(c_acc_11 + m_i * t, aie::load_v(cdst11 + m_i * n_tile)); } for (unsigned g = 0; g < NG; g++) { - // Per-group unscaled MMUL. Init accumulator from zeros (bf16-GEMM - // pattern, avoids the runtime `first` flag). aie::vector zero_init = aie::zeros(); - MMUL C_g(zero_init); -#pragma clang loop unroll(full) - for (unsigned kg = 0; kg < KB_PER_G; kg++) { - unsigned k_b = g * KB_PER_G + kg; - aie::vector A = - aie::load_v(a_pack + (k_b * MB + m_b) * (r * s)); - aie::vector B = - aie::load_v(b_pack + (n_b * KB + k_b) * (s * t)); - C_g.mac(A, B); + MMUL C00(zero_init), C01(zero_init), C10(zero_init), C11(zero_init); + + const bfloat16 *__restrict pA0 = + a_pack + (g * KB_PER_G * MB + m_b2) * (r * s); + const bfloat16 *__restrict pA1 = + a_pack + (g * KB_PER_G * MB + m_b2 + 1) * (r * s); + const bfloat16 *__restrict pB0 = + b_pack + (n_b2 * KB + g * KB_PER_G) * (s * t); + const bfloat16 *__restrict pB1 = + b_pack + ((n_b2 + 1) * KB + g * KB_PER_G) * (s * t); + + chess_prepare_for_pipelining chess_loop_range(1, ) for (unsigned kg = 0; + kg < KB_PER_G; + kg++) { + aie::vector A0 = + aie::load_v(pA0); + pA0 += MB * (r * s); + aie::vector A1 = + aie::load_v(pA1); + pA1 += MB * (r * s); + // b_pack stores tiles in [t=n_i][s=k_i] order (dequant friendly); + // mmul wants [s=k_i][t=n_i], so transpose per load. + aie::vector B0 = + aie::transpose(aie::load_v(pB0), t, s); + pB0 += s * t; + aie::vector B1 = + aie::transpose(aie::load_v(pB1), t, s); + pB1 += s * t; + C00.mac(A0, B0); + C01.mac(A0, B1); + C10.mac(A1, B0); + C11.mac(A1, B1); + } + + // Per-group scale fold (cold path — runs NG=1 times for the + // production gs=k_chunk config). Two scale broadcasts (one per + // n-block); each applies to both m-block rows. + alignas(32) bfloat16 scale0_buf[t], scale1_buf[t]; + for (unsigned n_i = 0; n_i < t; n_i++) { + scale0_buf[n_i] = a_s[g * n_tile + n_b2 * t + n_i]; + scale1_buf[n_i] = a_s[g * n_tile + (n_b2 + 1) * t + n_i]; } - // Vectorized post-MMUL scale fold: convert C_g to bf16, multiply - // by bf16 scale broadcast (mul → f32 accum, no extra truncate), - // lift to f32 vec, accumulate row-by-row into c_acc_buf. - aie::vector c_g_bf16 = - C_g.template to_vector(); - // Build bf16 scale broadcast: (m_i, n_i) row-major, scale per n_i. - // Load 8 scales into a vec and tile across r rows. - alignas(32) bfloat16 scale_row_buf[t]; - for (unsigned n_i = 0; n_i < t; n_i++) - scale_row_buf[n_i] = a_s[g * n_tile + n_b * t + n_i]; - aie::vector scale_row = aie::load_v(scale_row_buf); - alignas(32) bfloat16 scale_tile_buf[r * t]; - for (unsigned m_i = 0; m_i < r; m_i++) - aie::store_v(scale_tile_buf + m_i * t, scale_row); - aie::vector scale_tile = - aie::load_v(scale_tile_buf); - aie::accum scaled_acc = aie::mul(c_g_bf16, scale_tile); - aie::vector scaled_f32 = - scaled_acc.template to_vector(); - // Accumulate into c_acc_buf row by row (vector load + add + store). + aie::vector s0 = aie::load_v(scale0_buf); + aie::vector s1 = aie::load_v(scale1_buf); + alignas(32) bfloat16 s0_tile_buf[r * t]; + alignas(32) bfloat16 s1_tile_buf[r * t]; for (unsigned m_i = 0; m_i < r; m_i++) { - aie::vector inc = scaled_f32.template extract(m_i); - aie::vector old = aie::load_v(c_acc_buf + m_i * t); - aie::vector sum = aie::add(old, inc); - aie::store_v(c_acc_buf + m_i * t, sum); + aie::store_v(s0_tile_buf + m_i * t, s0); + aie::store_v(s1_tile_buf + m_i * t, s1); } + aie::vector s0_tile = + aie::load_v(s0_tile_buf); + aie::vector s1_tile = + aie::load_v(s1_tile_buf); + + auto fold = [&](MMUL &C, aie::vector &scale_tile, + float *c_acc) { + aie::vector c_bf16 = + C.template to_vector(); + aie::accum scaled_acc = + aie::mul(c_bf16, scale_tile); + aie::vector scaled_f32 = + scaled_acc.template to_vector(); + for (unsigned m_i = 0; m_i < r; m_i++) { + aie::vector inc = scaled_f32.template extract(m_i); + aie::vector old = aie::load_v(c_acc + m_i * t); + aie::vector sum = aie::add(old, inc); + aie::store_v(c_acc + m_i * t, sum); + } + }; + fold(C00, s0_tile, c_acc_00); + fold(C01, s1_tile, c_acc_01); + fold(C10, s0_tile, c_acc_10); + fold(C11, s1_tile, c_acc_11); } - // Store accumulator back to c row-by-row. for (unsigned m_i = 0; m_i < r; m_i++) { - aie::vector row = aie::load_v(c_acc_buf + m_i * t); - aie::store_v(cdst_base + m_i * n_tile, row); + aie::store_v(cdst00 + m_i * n_tile, aie::load_v(c_acc_00 + m_i * t)); + aie::store_v(cdst01 + m_i * n_tile, aie::load_v(c_acc_01 + m_i * t)); + aie::store_v(cdst10 + m_i * n_tile, aie::load_v(c_acc_10 + m_i * t)); + aie::store_v(cdst11 + m_i * n_tile, aie::load_v(c_acc_11 + m_i * t)); } } } From 32769dd7329169a5cc153e761ea59fd53ef73f30 Mon Sep 17 00:00:00 2001 From: erweiw Date: Mon, 1 Jun 2026 17:42:59 -0700 Subject: [PATCH 09/10] fixup: clang-format Co-Authored-By: Claude Opus 4.7 (1M context) --- .../int4_awq/mv_int4_bf16.cc | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc b/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc index 6f68345e5..93e2e7788 100644 --- a/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc +++ b/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc @@ -236,9 +236,8 @@ void mm_int4_bf16_mmul_impl(uint8_t *__restrict a_q, bfloat16 *__restrict a_s, const bfloat16 *__restrict pB1 = b_pack + ((n_b2 + 1) * KB + g * KB_PER_G) * (s * t); - chess_prepare_for_pipelining chess_loop_range(1, ) for (unsigned kg = 0; - kg < KB_PER_G; - kg++) { + chess_prepare_for_pipelining chess_loop_range( + 1, ) for (unsigned kg = 0; kg < KB_PER_G; kg++) { aie::vector A0 = aie::load_v(pA0); pA0 += MB * (r * s); @@ -275,17 +274,14 @@ void mm_int4_bf16_mmul_impl(uint8_t *__restrict a_q, bfloat16 *__restrict a_s, aie::store_v(s0_tile_buf + m_i * t, s0); aie::store_v(s1_tile_buf + m_i * t, s1); } - aie::vector s0_tile = - aie::load_v(s0_tile_buf); - aie::vector s1_tile = - aie::load_v(s1_tile_buf); + aie::vector s0_tile = aie::load_v(s0_tile_buf); + aie::vector s1_tile = aie::load_v(s1_tile_buf); auto fold = [&](MMUL &C, aie::vector &scale_tile, float *c_acc) { aie::vector c_bf16 = C.template to_vector(); - aie::accum scaled_acc = - aie::mul(c_bf16, scale_tile); + aie::accum scaled_acc = aie::mul(c_bf16, scale_tile); aie::vector scaled_f32 = scaled_acc.template to_vector(); for (unsigned m_i = 0; m_i < r; m_i++) { From 1e6e2e718ab546ccefc9e22994b401f978ebd531 Mon Sep 17 00:00:00 2001 From: erweiw Date: Mon, 1 Jun 2026 18:22:22 -0700 Subject: [PATCH 10/10] fixup: guard matmul entry from matvec builds The 2x2-expansion static_assert in mm_int4_bf16_mmul_impl requires m_tile and n_tile to be multiples of 16. Matvec lit tests build the same .cc with DIM_M=8, which doesn't link any matmul symbol but still instantiates the matmul template via matmul_int4_bf16_packed_f32 and trips the assert. Guard the matmul entry + helpers with #if DIM_M >= 16 && DIM_N >= 16 so matvec builds skip the matmul template instantiation. Confirmed locally: matvec.o (DIM_M=8) builds cleanly with no matmul symbols; matmul.o (DIM_M=16) builds cleanly with full symbol set. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc b/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc index 93e2e7788..63543a825 100644 --- a/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc +++ b/programming_examples/matrix_vector_multiplication/int4_awq/mv_int4_bf16.cc @@ -361,6 +361,10 @@ void matvec_int4_bf16_packed(uint8_t *packed, bfloat16 *b, bfloat16 *c) { void zero_vectorized_bf16(bfloat16 *c) { zero_impl(c); } +// Matmul-only entries are guarded so matvec builds (DIM_M=8) don't +// instantiate mm_int4_bf16_mmul_impl, which static-asserts m_tile and +// n_tile are multiples of 16 (2x mmul m/n for the 2x2 expansion). +#if DIM_M >= 16 && DIM_N >= 16 void zero_vectorized_bf16_mn(bfloat16 *c) { zero_mn_impl(c); } void zero_vectorized_f32_mn(float *c) { zero_mn_f32_impl(c); } @@ -382,6 +386,7 @@ void matmul_int4_bf16_packed_f32(uint8_t *packed, bfloat16 *a, float *c) { mm_int4_bf16_mmul_impl(a_q, a_s, a_z, a, c); } +#endif void partial_plus_r_bf16(bfloat16 *partial, bfloat16 *r_full, int offset, bfloat16 *d) {