From dc7c932c0833d48b936b69da3e82aab720013343 Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Wed, 10 Jun 2026 11:56:17 -0700 Subject: [PATCH 1/2] Add Q2_0 quantization: type definition and CPU backend --- ggml/include/ggml.h | 4 +- ggml/src/ggml-common.h | 10 ++++ ggml/src/ggml-cpu/arch-fallback.h | 7 +++ ggml/src/ggml-cpu/arch/arm/quants.c | 74 ++++++++++++++++++++++++++++ ggml/src/ggml-cpu/ggml-cpu.c | 6 +++ ggml/src/ggml-cpu/ops.cpp | 7 +++ ggml/src/ggml-cpu/quants.c | 51 +++++++++++++++++++ ggml/src/ggml-cpu/quants.h | 3 ++ ggml/src/ggml-quants.c | 76 +++++++++++++++++++++++++++++ ggml/src/ggml-quants.h | 3 ++ ggml/src/ggml.c | 10 ++++ gguf-py/gguf/constants.py | 3 ++ include/llama.h | 1 + src/llama-model-loader.cpp | 2 + src/llama-quant.cpp | 4 +- tests/test-quantize-fns.cpp | 3 +- tools/quantize/quantize.cpp | 1 + 17 files changed, 262 insertions(+), 3 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index d6807b6dd47a..ac133665d978 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -429,7 +429,8 @@ extern "C" { GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block) GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale) GGML_TYPE_Q1_0 = 41, - GGML_TYPE_COUNT = 42, + GGML_TYPE_Q2_0 = 42, + GGML_TYPE_COUNT = 43, }; // precision @@ -473,6 +474,7 @@ extern "C" { GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors GGML_FTYPE_MOSTLY_NVFP4 = 26, // except 1d tensors GGML_FTYPE_MOSTLY_Q1_0 = 27, // except 1d tensors + GGML_FTYPE_MOSTLY_Q2_0 = 28, // except 1d tensors }; // available tensor operations: diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index f05683b44cd9..29028b32a2fd 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -96,6 +96,9 @@ typedef sycl::half2 ggml_half2; #define QI1_0 (QK1_0 / 32) #define QR1_0 1 +#define QI2_0 (QK2_0 / 32) +#define QR2_0 1 + #define QI4_0 (QK4_0 / (4 * QR4_0)) #define QR4_0 2 @@ -181,6 +184,13 @@ typedef struct { } block_q1_0; static_assert(sizeof(block_q1_0) == 
sizeof(ggml_half) + QK1_0 / 8, "wrong q1_0 block size/padding"); +#define QK2_0 64 +typedef struct { + ggml_half d; // delta (scale) + uint8_t qs[QK2_0 / 4]; // 2 bits per element +} block_q2_0; +static_assert(sizeof(block_q2_0) == sizeof(ggml_half) + QK2_0 / 4, "wrong q2_0 block size/padding"); + #define QK4_0 32 typedef struct { ggml_half d; // delta diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index b0391a67c88d..7aeacfdd5b28 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -17,6 +17,7 @@ #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 #define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 +#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -83,6 +84,7 @@ #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 @@ -114,6 +116,7 @@ #define quantize_row_q8_K_generic quantize_row_q8_K #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 #define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 +#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K @@ -163,6 +166,7 @@ #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic 
ggml_vec_dot_nvfp4_q8_0 #define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 +#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 @@ -203,6 +207,7 @@ #elif defined(__riscv) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 @@ -244,6 +249,7 @@ #define quantize_row_q8_K_generic quantize_row_q8_K #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 #define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 +#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -308,6 +314,7 @@ #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 #define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 +#define ggml_vec_dot_q2_0_q8_0_generic ggml_vec_dot_q2_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index fe6213329708..9faa4a014193 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -219,6 +219,80 @@ void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } +void ggml_vec_dot_q2_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int 
nrc) { + const int qk = QK2_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q2_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + float sumf = 0.0f; + +#if defined(__ARM_NEON) + // Replicate pattern: each byte repeated 4 times + static const uint8_t tbl_idx_lo[16] = {0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3}; + static const uint8_t tbl_idx_hi[16] = {4,4,4,4, 5,5,5,5, 6,6,6,6, 7,7,7,7}; + // Right-shift amounts: 0,2,4,6 repeated for each group of 4 + static const int8_t shift_vals[16] = {0,-2,-4,-6, 0,-2,-4,-6, 0,-2,-4,-6, 0,-2,-4,-6}; + + const uint8x16_t idx_lo = vld1q_u8(tbl_idx_lo); + const uint8x16_t idx_hi = vld1q_u8(tbl_idx_hi); + const int8x16_t shifts = vld1q_s8(shift_vals); + const uint8x16_t mask2 = vdupq_n_u8(0x03); + const int8x16_t one = vdupq_n_s8(1); + + float32x4_t sumv = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i++) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[i].d); + + // group 64: one Q2_0 block (64 weights) maps to two Q8_0 blocks (2 * 32 = 64) + for (int k = 0; k < 2; k++) { + const block_q8_0 * GGML_RESTRICT yb = &y[i * 2 + k]; + const float d1 = GGML_CPU_FP16_TO_FP32(yb->d); + + // Load 8 bytes of packed 2-bit values + const uint8x8_t raw = vld1_u8(&x[i].qs[k * 8]); + const uint8x16_t raw16 = vcombine_u8(raw, raw); + + // First 16 elements: replicate bytes 0-3, shift, mask, subtract 1 + uint8x16_t bytes0 = vqtbl1q_u8(raw16, idx_lo); + int8x16_t qv0 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshlq_u8(bytes0, shifts), mask2)), + one); + + // Second 16 elements: replicate bytes 4-7, shift, mask, subtract 1 + uint8x16_t bytes1 = vqtbl1q_u8(raw16, idx_hi); + int8x16_t qv1 = vsubq_s8( + vreinterpretq_s8_u8(vandq_u8(vshlq_u8(bytes1, shifts), mask2)), + one); + + // Load Q8_0 values and dot product + const int8x16_t y0 = vld1q_s8(yb->qs); + const int8x16_t y1 = vld1q_s8(yb->qs + 16); + + int32x4_t p0 = 
ggml_vdotq_s32(vdupq_n_s32(0), qv0, y0); + int32x4_t p1 = ggml_vdotq_s32(p0, qv1, y1); + + sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(p1), d0 * d1); + } + } + + sumf = vaddvq_f32(sumv); +#else + ggml_vec_dot_q2_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); + return; +#endif + + *s = sumf; +} void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index eb8341c9aecc..15c31fa01ec3 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -227,6 +227,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, }, + [GGML_TYPE_Q2_0] = { + .from_float = quantize_row_q2_0, + .vec_dot = ggml_vec_dot_q2_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + }, [GGML_TYPE_Q4_0] = { .from_float = quantize_row_q4_0, .vec_dot = ggml_vec_dot_q4_0_q8_0, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 74611dce7f1a..6ab3fd24c3e4 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -665,6 +665,7 @@ void ggml_compute_forward_add( ggml_compute_forward_add_non_quantized(params, dst); } break; case GGML_TYPE_Q1_0: + case GGML_TYPE_Q2_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -1115,6 +1116,7 @@ void ggml_compute_forward_add1( } } break; case GGML_TYPE_Q1_0: + case GGML_TYPE_Q2_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -1245,6 +1247,7 @@ void ggml_compute_forward_acc( case GGML_TYPE_F16: case GGML_TYPE_BF16: case GGML_TYPE_Q1_0: + case GGML_TYPE_Q2_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -4415,6 +4418,7 @@ void ggml_compute_forward_out_prod( switch (src0->type) { case GGML_TYPE_Q1_0: + case GGML_TYPE_Q2_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ 
-4691,6 +4695,7 @@ void ggml_compute_forward_set( case GGML_TYPE_F16: case GGML_TYPE_BF16: case GGML_TYPE_Q1_0: + case GGML_TYPE_Q2_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -4915,6 +4920,7 @@ void ggml_compute_forward_get_rows( switch (src0->type) { case GGML_TYPE_Q1_0: + case GGML_TYPE_Q2_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -5641,6 +5647,7 @@ void ggml_compute_forward_clamp( } break; case GGML_TYPE_BF16: case GGML_TYPE_Q1_0: + case GGML_TYPE_Q2_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index e5f9a4083f9c..5e36459f8cbc 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -26,6 +26,10 @@ void quantize_row_q1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in quantize_row_q1_0_ref(x, y, k); } +void quantize_row_q2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_q2_0_ref(x, y, k); +} + void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { quantize_row_q4_0_ref(x, y, k); } @@ -170,6 +174,53 @@ void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c *s = sumf; } +void ggml_vec_dot_q2_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK2_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q2_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + float sumf = 0.0f; + + for (int i = 0; i < nb; i++) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[i].d); + + float sumi = 0.0f; + + // group 64: one Q2_0 block (64 weights) maps to two Q8_0 blocks (2 * 32 = 64) + for (int k = 0; k < 2; k++) { + const block_q8_0 * GGML_RESTRICT yb = &y[i * 2 + k]; + const float d1 = 
GGML_CPU_FP16_TO_FP32(yb->d); + int sumi_block = 0; + + const uint8_t * GGML_RESTRICT qs = &x[i].qs[k * 8]; + const int8_t * GGML_RESTRICT qy = yb->qs; + + for (int b = 0; b < 8; ++b) { + const uint8_t byte = qs[b]; + // Extract 4 two-bit values, map {0,1,2,3} -> {-1,0,1,2} + sumi_block += ((int)((byte >> 0) & 3) - 1) * qy[b*4 + 0]; + sumi_block += ((int)((byte >> 2) & 3) - 1) * qy[b*4 + 1]; + sumi_block += ((int)((byte >> 4) & 3) - 1) * qy[b*4 + 2]; + sumi_block += ((int)((byte >> 6) & 3) - 1) * qy[b*4 + 3]; + } + + sumi += d1 * sumi_block; + } + + sumf += d0 * sumi; + } + + *s = sumf; +} void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index d4bc87a1c052..93ea7eeffe5b 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -13,6 +13,7 @@ extern "C" { // Quantization void quantize_row_q1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -38,6 +39,7 @@ void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, // Dot product void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q2_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * 
GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -71,6 +73,7 @@ void quantize_row_q8_0_generic(const float * GGML_RESTRICT x, void * GGML_RESTRI void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q2_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 15d231f70c0d..1ebc50a763f1 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -71,6 +71,44 @@ void quantize_row_q1_0_ref(const float * GGML_RESTRICT x, block_q1_0 * GGML_REST } } +void quantize_row_q2_0_ref(const float * GGML_RESTRICT x, block_q2_0 * GGML_RESTRICT y, int64_t k) { + static const int qk = QK2_0; + + assert(k % qk == 0); + + 
const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + // Compute scale as max absolute value in the block + float amax = 0.0f; + for (int j = 0; j < qk; j++) { + const float a = fabsf(x[i*qk + j]); + if (a > amax) amax = a; + } + const float d = amax; + const float id = d > 0.0f ? 1.0f / d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + // Clear quant bytes + for (int j = 0; j < qk / 4; ++j) { + y[i].qs[j] = 0; + } + + // Encode 2-bit values: round(w/d) clamped to [-1, 2], then add 1 + // 00 (-1) = -scale, 01 (0) = 0, 10 (+1) = +scale, 11 (+2) = 2*scale + for (int j = 0; j < qk; ++j) { + const float w = x[i*qk + j]; + int q = (int)roundf(w * id) + 1; + if (q < 0) q = 0; + if (q > 3) q = 3; + const int byte_index = j / 4; + const int bit_offset = (j % 4) * 2; + y[i].qs[byte_index] |= ((uint8_t)q << bit_offset); + } + } +} + // reference implementation for deterministic creation of model files void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) { static const int qk = QK4_0; @@ -398,6 +436,26 @@ void dequantize_row_q1_0(const block_q1_0 * GGML_RESTRICT x, float * GGML_RESTRI } } +void dequantize_row_q2_0(const block_q2_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + static const int qk = QK2_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (int j = 0; j < qk; ++j) { + const int byte_index = j / 4; + const int bit_offset = (j % 4) * 2; + const uint8_t q = (x[i].qs[byte_index] >> bit_offset) & 0x03; + // 00=-1, 01=0, 10=+1, 11=+2 + y[i*qk + j] = ((int)q - 1) * d; + } + } +} + void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { static const int qk = QK4_0; @@ -2052,6 +2110,20 @@ size_t quantize_q1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, return nrow * row_size; } +size_t quantize_q2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, 
int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + quantize_row_q2_0_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_Q2_0, n_per_row); + } + size_t row_size = ggml_row_size(GGML_TYPE_Q2_0, n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q2_0_ref(src, (block_q2_0*)qrow, n_per_row); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { @@ -5461,6 +5533,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_q1_0, data, nb); } break; + case GGML_TYPE_Q2_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q2_0, data, nb); + } break; case GGML_TYPE_Q4_0: { VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb); diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index d56c86da8909..75188f1af180 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -15,6 +15,7 @@ extern "C" { // Quantization GGML_API void quantize_row_q1_0_ref(const float * GGML_RESTRICT x, block_q1_0 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_q2_0_ref(const float * GGML_RESTRICT x, block_q2_0 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); @@ -43,6 +44,7 @@ GGML_API void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_ // Dequantization GGML_API void dequantize_row_q1_0(const block_q1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_q2_0(const 
block_q2_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -93,6 +95,7 @@ GGML_API size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTR GGML_API size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b43016c87d21..3d682dcb2af1 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -674,6 +674,14 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .to_float = (ggml_to_float_t) dequantize_row_q1_0, .from_float_ref = (ggml_from_float_t) quantize_row_q1_0_ref, }, + [GGML_TYPE_Q2_0] = { + .type_name = "q2_0", + .blck_size = QK2_0, + .type_size = sizeof(block_q2_0), + .is_quantized = true, + .to_float = 
(ggml_to_float_t) dequantize_row_q2_0, + .from_float_ref = (ggml_from_float_t) quantize_row_q2_0_ref, + }, [GGML_TYPE_Q4_0] = { .type_name = "q4_0", .blck_size = QK4_0, @@ -1410,6 +1418,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break; case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break; case GGML_FTYPE_MOSTLY_Q1_0: wtype = GGML_TYPE_Q1_0; break; + case GGML_FTYPE_MOSTLY_Q2_0: wtype = GGML_TYPE_Q2_0; break; case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; @@ -7732,6 +7741,7 @@ size_t ggml_quantize_chunk( switch (type) { case GGML_TYPE_Q1_0: result = quantize_q1_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q2_0: result = quantize_q2_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0: result = quantize_q4_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_1: result = quantize_q4_1 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q5_0: result = quantize_q5_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 584594097346..2ebdd52bd12e 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -4356,6 +4356,7 @@ class GGMLQuantizationType(IntEnum): MXFP4 = 39 NVFP4 = 40 Q1_0 = 41 + Q2_0 = 42 class ExpertGatingFuncType(IntEnum): @@ -4410,6 +4411,7 @@ class LlamaFileType(IntEnum): MOSTLY_MXFP4_MOE = 38 # except 1d tensors MOSTLY_NVFP4 = 39 # except 1d tensors MOSTLY_Q1_0 = 40 # except 1d tensors + MOSTLY_Q2_0 = 41 # except 1d tensors GUESSED = 1024 # not specified in the model file @@ -4535,6 +4537,7 @@ class 
VisionProjectorType: GGMLQuantizationType.MXFP4: (32, 1 + 16), GGMLQuantizationType.NVFP4: (64, 4 + 32), GGMLQuantizationType.Q1_0: (128, 2 + 16), + GGMLQuantizationType.Q2_0: (64, 2 + 16), } diff --git a/include/llama.h b/include/llama.h index 27e480674282..4ea072e8d11b 100644 --- a/include/llama.h +++ b/include/llama.h @@ -155,6 +155,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38, // except 1d tensors LLAMA_FTYPE_MOSTLY_NVFP4 = 39, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q1_0 = 40, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q2_0 = 41, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 0d1cf3cc33bb..b211950740d8 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -37,6 +37,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_F16: return "F16"; case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; case LLAMA_FTYPE_MOSTLY_Q1_0: return "Q1_0"; + case LLAMA_FTYPE_MOSTLY_Q2_0: return "Q2_0"; case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; @@ -761,6 +762,7 @@ llama_model_loader::llama_model_loader( case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; case GGML_TYPE_NVFP4: ftype = LLAMA_FTYPE_MOSTLY_NVFP4; break; case GGML_TYPE_Q1_0: ftype = LLAMA_FTYPE_MOSTLY_Q1_0; break; + case GGML_TYPE_Q2_0: ftype = LLAMA_FTYPE_MOSTLY_Q2_0; break; default: { LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index cf92ce4bb8b7..140974dc36ac 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -380,6 +380,7 @@ static ggml_type tensor_type_fallback(quantize_state_impl & qs, const ggml_tenso case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_S: // types on the right: block size 32 case GGML_TYPE_IQ4_XS: return_type = GGML_TYPE_IQ4_NL; break; + 
case GGML_TYPE_Q2_0: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_TQ1_0: @@ -480,7 +481,7 @@ static ggml_type llama_tensor_get_type_impl(quantize_state_impl & qs, ggml_type else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { new_type = GGML_TYPE_IQ3_S; } - else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) { + else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0 || ftype == LLAMA_FTYPE_MOSTLY_Q2_0) { new_type = GGML_TYPE_Q4_K; } } @@ -800,6 +801,7 @@ ggml_type llama_ftype_get_default_type(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_BF16: return GGML_TYPE_BF16; case LLAMA_FTYPE_ALL_F32: return GGML_TYPE_F32; case LLAMA_FTYPE_MOSTLY_Q1_0: return GGML_TYPE_Q1_0; + case LLAMA_FTYPE_MOSTLY_Q2_0: return GGML_TYPE_Q2_0; case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return GGML_TYPE_MXFP4; diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index a05fab50421f..b79e3d193b7a 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -150,6 +150,7 @@ int main(int argc, char * argv[]) { type == GGML_TYPE_Q1_0 ? MAX_QUANTIZATION_TOTAL_ERROR_BINARY : type == GGML_TYPE_TQ1_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY : type == GGML_TYPE_TQ2_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY : + type == GGML_TYPE_Q2_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY : type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS : type == GGML_TYPE_IQ2_S ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS : type == GGML_TYPE_Q3_K ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS : @@ -175,7 +176,7 @@ int main(int argc, char * argv[]) { ? MAX_DOT_PRODUCT_ERROR_LOWBIT : type == GGML_TYPE_Q1_0 ? MAX_DOT_PRODUCT_ERROR_BINARY - : type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0 + : type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0 || type == GGML_TYPE_Q2_0 ? MAX_DOT_PRODUCT_ERROR_TERNARY : type == GGML_TYPE_NVFP4 ? 
MAX_DOT_PRODUCT_ERROR_FP4 diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp index 840eefc2f5ac..15ef64c4b0ed 100644 --- a/tools/quantize/quantize.cpp +++ b/tools/quantize/quantize.cpp @@ -33,6 +33,7 @@ struct quant_option { static const std::vector QUANT_OPTIONS = { { "Q1_0", LLAMA_FTYPE_MOSTLY_Q1_0, " 1.125 bpw quantization", }, + { "Q2_0", LLAMA_FTYPE_MOSTLY_Q2_0, " 2.25 bpw quantization (group 64)", }, { "Q4_0", LLAMA_FTYPE_MOSTLY_Q4_0, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, { "Q4_1", LLAMA_FTYPE_MOSTLY_Q4_1, " 4.78G, +0.4511 ppl @ Llama-3-8B", }, { "MXFP4_MOE",LLAMA_FTYPE_MOSTLY_MXFP4_MOE," MXFP4 MoE", }, From 7a4a89a72e1a3685deacc8d856fdcfb383b2252f Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Wed, 10 Jun 2026 00:55:59 -0700 Subject: [PATCH 2/2] Q2_0 group 64: Metal backend --- ggml/src/ggml-metal/ggml-metal-device.cpp | 10 ++ ggml/src/ggml-metal/ggml-metal-device.m | 2 + ggml/src/ggml-metal/ggml-metal-impl.h | 3 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 1 + ggml/src/ggml-metal/ggml-metal.metal | 202 ++++++++++++++++++++++ 5 files changed, 218 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 4f4f073cb614..7a2dcbae2054 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -787,6 +787,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta nsg = N_SG_Q1_0; nr0 = N_R0_Q1_0; } break; + case GGML_TYPE_Q2_0: + { + nsg = N_SG_Q2_0; + nr0 = N_R0_Q2_0; + } break; case GGML_TYPE_Q4_0: { nsg = N_SG_Q4_0; @@ -1011,6 +1016,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m nsg = N_SG_Q1_0; nr0 = N_R0_Q1_0; } break; + case GGML_TYPE_Q2_0: + { + nsg = N_SG_Q2_0; + nr0 = N_R0_Q2_0; + } break; case GGML_TYPE_Q4_0: { nsg = N_SG_Q4_0; diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 05d7f43051ba..9065bfb8c384 
100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1259,6 +1259,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_TYPE_BF16: case GGML_TYPE_Q8_0: case GGML_TYPE_Q1_0: + case GGML_TYPE_Q2_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -1286,6 +1287,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return false; } case GGML_TYPE_Q1_0: + case GGML_TYPE_Q2_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index ff74cafb5b79..89188fef29b1 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -24,6 +24,9 @@ #define N_R0_Q1_0 8 #define N_SG_Q1_0 2 +#define N_R0_Q2_0 8 +#define N_SG_Q2_0 2 + #define N_R0_Q4_0 4 #define N_SG_Q4_0 2 diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index e2ce56e9e28b..b860f164b8c2 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2068,6 +2068,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16 || op->src[0]->type == GGML_TYPE_Q1_0 || + op->src[0]->type == GGML_TYPE_Q2_0 || op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || op->src[0]->type == GGML_TYPE_Q5_0 || diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 0aea68455fba..d809f014fc25 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -168,6 +168,39 @@ void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & r reg = (type4) reg_f; } +template +void dequantize_q2_0(device const block_q2_0 * xb, short il, thread type4x4 & reg) { + device const uint8_t * qs = xb->qs; + const float d = 
xb->d; + + const int byte_offset = il * 4; // il selects a group of 16 weights: il*16 elements = il*4 bytes (4 elements per byte) + float4x4 reg_f; + + for (int i = 0; i < 4; i++) { + const uint8_t b = qs[byte_offset + i]; + reg_f[i][0] = ((float)((b >> 0) & 3) - 1.0f) * d; + reg_f[i][1] = ((float)((b >> 2) & 3) - 1.0f) * d; + reg_f[i][2] = ((float)((b >> 4) & 3) - 1.0f) * d; + reg_f[i][3] = ((float)((b >> 6) & 3) - 1.0f) * d; + } + + reg = (type4x4) reg_f; +} + +template +void dequantize_q2_0_t4(device const block_q2_0 * xb, short il, thread type4 & reg) { + const float d = xb->d; + const uint8_t b = xb->qs[il]; + + float4 reg_f; + reg_f[0] = ((float)((b >> 0) & 3) - 1.0f) * d; + reg_f[1] = ((float)((b >> 2) & 3) - 1.0f) * d; + reg_f[2] = ((float)((b >> 4) & 3) - 1.0f) * d; + reg_f[3] = ((float)((b >> 6) & 3) - 1.0f) * d; + + reg = (type4) reg_f; +} + template void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 1); @@ -219,6 +252,27 @@ void quantize_q1_0(device const float * src, device block_q1_0 & dst) { } } +void quantize_q2_0(device const float * src, device block_q2_0 & dst) { + float amax = 0.0f; + for (int j = 0; j < QK2_0; j++) { + float a = fabs(src[j]); + if (a > amax) amax = a; + } + const float d = amax; + dst.d = d; + + const float id = d > 0.0f ? 
1.0f / d : 0.0f; + + for (int j = 0; j < QK2_0 / 4; j++) { + dst.qs[j] = 0; + } + for (int j = 0; j < QK2_0; j++) { + int q = (int)round(src[j] * id) + 1; + q = max(0, min(3, q)); + dst.qs[j / 4] |= (q << (2 * (j % 4))); + } +} + void quantize_q4_0(device const float * src, device block_q4_0 & dst) { #pragma METAL fp math_mode(safe) float amax = 0.0f; // absolute max @@ -3284,6 +3338,60 @@ inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thre return qb_curr->d * (2.0f * acc - sumy); } +// Q2_0 partial dot product over 16 weights: returns d * (Σ(q_raw[i] * yl[i]) - sumy), where sumy is SUM(yl[i]) +// q_raw are unsigned 2-bit values {0,1,2,3} packed 4 per byte (low bits first); decoded weight = (q_raw - 1) * d +// il is the element offset of the 16-weight sub-block inside the block, so its first byte is qs[il/4] +// bit-decomposition: q_raw = low_bit + 2*high_bit +// Σ(q_raw[i] * yl[i]) = sum_lo + 2*sum_hi, with sum_lo/sum_hi = SUM of yl[i] whose low/high bit is set +// sum_lo/sum_hi are built from select()-based conditional adds, so the only multiplies are by 2 and by d +inline float block_q_n_dot_y(device const block_q2_0 * qb_curr, float sumy, thread float * yl, int il) { + device const uint8_t * qs = qb_curr->qs + (il / 4); + const uint8_t b0 = qs[0]; + const uint8_t b1 = qs[1]; + const uint8_t b2 = qs[2]; + const uint8_t b3 = qs[3]; + + // Accumulate yl[i] where the low bit of weight i is set (bits 0,2,4,6 of each byte) + float acc_lo = 0.0f; + acc_lo += select(0.0f, yl[ 0], bool(b0 & 0x01)); + acc_lo += select(0.0f, yl[ 1], bool(b0 & 0x04)); + acc_lo += select(0.0f, yl[ 2], bool(b0 & 0x10)); + acc_lo += select(0.0f, yl[ 3], bool(b0 & 0x40)); + acc_lo += select(0.0f, yl[ 4], bool(b1 & 0x01)); + acc_lo += select(0.0f, yl[ 5], bool(b1 & 0x04)); + acc_lo += select(0.0f, yl[ 6], bool(b1 & 0x10)); + acc_lo += select(0.0f, yl[ 7], bool(b1 & 0x40)); + acc_lo += select(0.0f, yl[ 8], bool(b2 & 0x01)); + acc_lo += select(0.0f, yl[ 9], bool(b2 & 0x04)); + acc_lo += select(0.0f, yl[10], bool(b2 & 0x10)); + acc_lo += select(0.0f, yl[11], bool(b2 & 0x40)); + acc_lo += select(0.0f, yl[12], bool(b3 & 0x01)); + acc_lo += select(0.0f, yl[13], bool(b3 & 0x04)); + acc_lo += select(0.0f, yl[14], bool(b3 & 0x10)); + 
acc_lo += select(0.0f, yl[15], bool(b3 & 0x40)); + + // Accumulate yl[i] where the high bit of weight i is set (bits 1,3,5,7 of each byte) + float acc_hi = 0.0f; + acc_hi += select(0.0f, yl[ 0], bool(b0 & 0x02)); + acc_hi += select(0.0f, yl[ 1], bool(b0 & 0x08)); + acc_hi += select(0.0f, yl[ 2], bool(b0 & 0x20)); + acc_hi += select(0.0f, yl[ 3], bool(b0 & 0x80)); + acc_hi += select(0.0f, yl[ 4], bool(b1 & 0x02)); + acc_hi += select(0.0f, yl[ 5], bool(b1 & 0x08)); + acc_hi += select(0.0f, yl[ 6], bool(b1 & 0x20)); + acc_hi += select(0.0f, yl[ 7], bool(b1 & 0x80)); + acc_hi += select(0.0f, yl[ 8], bool(b2 & 0x02)); + acc_hi += select(0.0f, yl[ 9], bool(b2 & 0x08)); + acc_hi += select(0.0f, yl[10], bool(b2 & 0x20)); + acc_hi += select(0.0f, yl[11], bool(b2 & 0x80)); + acc_hi += select(0.0f, yl[12], bool(b3 & 0x02)); + acc_hi += select(0.0f, yl[13], bool(b3 & 0x08)); + acc_hi += select(0.0f, yl[14], bool(b3 & 0x20)); + acc_hi += select(0.0f, yl[15], bool(b3 & 0x80)); + + return qb_curr->d * (acc_lo + 2.0f * acc_hi - sumy); +} + // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) // il indicates where the q4 quants begin (0 or QK4_0/4) // we assume that the yl's have been multiplied with the appropriate scale factor @@ -3587,6 +3695,86 @@ kernel void kernel_mul_mv_q1_0_f32( kernel_mul_mv_q1_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } +template +void kernel_mul_mv_q2_0_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const short NSG = FC_mul_mv_nsg; + + const int nb = args.ne00/QK2_0; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * NSG + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13; + + device const float * y = 
(device const float *) (src1 + offset1); + + device const block_q2_0 * ax[nr0]; + for (int row = 0; row < nr0; ++row) { + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + ax[row] = (device const block_q2_0 *) ((device char *) src0 + offset0); + } + + float yl[16]; + float sumf[nr0] = {0.f}; + + // QK2_0 = 64 weights per block = 4 sub-blocks of 16: ix = first block handled by this lane, il = element offset of its sub-block + const short ix = (tiisg/4); + const short il = (tiisg%4)*16; + + device const float * yb = y + ix*QK2_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/4) { + float sumy = 0.f; + + FOR_UNROLL (short i = 0; i < 16; i++) { + yl[i] = yb[i]; + sumy += yb[i]; + } + + FOR_UNROLL (short row = 0; row < nr0; row++) { + sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il); + } + + yb += QK2_0 * (N_SIMDWIDTH/4); + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr0; ++row) { + const float tot = simd_sum(sumf[row]); + + if (tiisg == 0 && first_row + row < args.ne01) { + dst_f32[first_row + row] = tot; + } + } +} + +[[host_name("kernel_mul_mv_q2_0_f32")]] +kernel void kernel_mul_mv_q2_0_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_q2_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + kernel void kernel_mul_mv_q4_0_f32( constant ggml_metal_kargs_mul_mv & args, device const char * src0, @@ -3984,6 +4172,11 @@ template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]] kernel mul_mv_ext_q4 template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>; template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]] 
kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>; +template [[host_name("kernel_mul_mv_ext_q2_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q2_0, 64, dequantize_q2_0_t4>; +template [[host_name("kernel_mul_mv_ext_q2_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q2_0, 64, dequantize_q2_0_t4>; +template [[host_name("kernel_mul_mv_ext_q2_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q2_0, 64, dequantize_q2_0_t4>; +template [[host_name("kernel_mul_mv_ext_q2_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q2_0, 64, dequantize_q2_0_t4>; + template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>; template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>; template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>; @@ -7452,6 +7645,7 @@ typedef decltype(kernel_cpy_f32_q) cpy_f_q_ template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q2_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; @@ -7497,6 +7691,7 @@ kernel void kernel_cpy_q_f32( typedef decltype(kernel_cpy_q_f32) cpy_q_f_t; template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q2_0_f32")]] kernel cpy_q_f_t 
kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; @@ -7504,6 +7699,7 @@ template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32< template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q2_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; @@ -10145,6 +10341,7 @@ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_ro typedef decltype(kernel_get_rows_q) get_rows_q_t; template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_q2_0")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q; @@ -10208,6 +10405,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm; #endif template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_0_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_0_f32")]] 
kernel mul_mm_t kernel_mul_mm; @@ -10232,6 +10430,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q2_0_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm; @@ -10265,6 +10464,7 @@ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_m template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id; #endif template [[host_name("kernel_mul_mm_id_q1_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q2_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; @@ -10289,6 +10489,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q1_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q2_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; template 
[[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; @@ -10444,6 +10645,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4 template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q2_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;