From f96a5b962e66569b0d1d5362c3a1a1617fb90af3 Mon Sep 17 00:00:00 2001 From: fredericmiesegaes Date: Mon, 23 Mar 2026 19:09:38 +0100 Subject: [PATCH] fix: 8-bit dequant for MLX mixed-precision gate quantization MLX 4-bit models quantize routing gates (mlp.gate, mlp.shared_expert_gate) at 8-bit precision, specified per-tensor in config.json. The inference engine treated all tensors as 4-bit, extracting 8 nibbles per uint32 from data that actually packs 4 bytes per uint32. This corrupts routing scores, selecting wrong experts and producing nonsensical output. Changes: - Add dequant_matvec_8bit Metal kernel (4 bytes/uint32, FMA-optimized) - Add cpu_dequant_matvec_8bit CPU fallback - Add BatchMatvecSpec.bits field for per-tensor bit-width dispatch - Mark gate and shared_expert_gate as 8-bit in all dispatch sites Fixes #10 Co-Authored-By: Claude Opus 4.6 (1M context) --- metal_infer/infer.m | 137 +++++++++++++++++++++++++++----------- metal_infer/shaders.metal | 68 +++++++++++++++++++ 2 files changed, 166 insertions(+), 39 deletions(-) diff --git a/metal_infer/infer.m b/metal_infer/infer.m index 5d2a946f..058e12fd 100644 --- a/metal_infer/infer.m +++ b/metal_infer/infer.m @@ -734,6 +734,44 @@ static void cpu_dequant_matvec( } } +// 8-bit dequantized matrix-vector multiply (CPU fallback) +// W is stored as packed uint32 (4 x 8-bit values per uint32) +// scales/biases are bfloat16 per group +static void cpu_dequant_matvec_8bit( + const uint32_t *W, const uint16_t *scales, const uint16_t *biases, + const float *x, float *out, + int out_dim, int in_dim, int group_size +) { + int num_groups = in_dim / group_size; + int packed_per_group = group_size / 4; + int packed_cols = in_dim / 4; + + for (int row = 0; row < out_dim; row++) { + float acc = 0.0f; + const uint32_t *w_row = W + row * packed_cols; + const uint16_t *s_row = scales + row * num_groups; + const uint16_t *b_row = biases + row * num_groups; + + for (int g = 0; g < num_groups; g++) { + float scale = bf16_to_f32(s_row[g]); + float bias = bf16_to_f32(b_row[g]); + int base_packed = g * packed_per_group; + int base_x = g * group_size; + + for (int p = 0; p < packed_per_group; p++) { + uint32_t packed = w_row[base_packed + p]; + int x_base = base_x + p * 4; + + acc += ((float)((packed >> 0) & 0xFF) * scale + bias) * x[x_base + 0]; + acc += ((float)((packed >> 8) & 0xFF) * scale + bias) * x[x_base + 1]; + acc += ((float)((packed >> 16) & 0xFF) * scale + bias) * x[x_base + 2]; + acc += ((float)((packed >> 24) & 0xFF) * scale + bias) * x[x_base + 3]; + } + } + out[row] = acc; + } +} + // RMS normalization: out = x * w / rms(x) static void cpu_rms_norm(const float *x, const uint16_t *w_bf16, float *out, int dim, float eps) { float sum_sq = 0.0f; @@ -905,6 +943,7 @@ static void cpu_conv1d_step( id matvec_v5; // LUT dequant variant id matvec_fast; // for in_dim > 4096 id matvec_2bit; // 2-bit expert dequant kernel + id matvec_8bit; // 8-bit gate dequant kernel id rms_norm_sum; id rms_norm_apply; id rms_norm_apply_bf16; @@ -1045,6 +1084,7 @@ static void cpu_conv1d_step( ctx->matvec_v5 = makePipe(@"dequant_matvec_4bit_v5"); // LUT variant (no uint→float conversions) ctx->matvec_fast = makePipe(@"dequant_matvec_4bit_fast"); ctx->matvec_2bit = makePipe(@"dequant_matvec_2bit"); + ctx->matvec_8bit = makePipe(@"dequant_matvec_8bit"); ctx->rms_norm_sum = makePipe(@"rms_norm_sum_sq"); ctx->rms_norm_apply = makePipe(@"rms_norm_apply"); ctx->rms_norm_apply_bf16 = makePipe(@"rms_norm_apply_bf16"); @@ -1349,6 +1389,7 @@ static void fast_dequant_matvec( uint32_t in_dim; uint32_t group_size; int batch_slot; // which batch_out[slot] to use for GPU output + int bits; // quantization bits: 4 (default) or 8 (gate weights) } BatchMatvecSpec; // Run N matmuls in a single command buffer. All share the same input vector. @@ -1372,8 +1413,12 @@ static void gpu_batch_matvec( id o_buf = ctx->batch_out[s->batch_slot]; id enc = [cmdbuf computeCommandEncoder]; - int use_v3 = (s->in_dim <= 4096); - [enc setComputePipelineState: use_v3 ? ctx->matvec_v3 : ctx->matvec_fast]; + if (s->bits == 8) { + [enc setComputePipelineState: ctx->matvec_8bit]; + } else { + int use_v3 = (s->in_dim <= 4096); + [enc setComputePipelineState: use_v3 ? ctx->matvec_v3 : ctx->matvec_fast]; + } [enc setBuffer:ctx->wf_buf offset:w_off atIndex:0]; [enc setBuffer:ctx->wf_buf offset:s_off atIndex:1]; [enc setBuffer:ctx->wf_buf offset:b_off atIndex:2]; @@ -1383,7 +1428,7 @@ static void gpu_batch_matvec( [enc setBytes:&s->in_dim length:4 atIndex:6]; [enc setBytes:&s->group_size length:4 atIndex:7]; - if (use_v3) { + if (s->bits == 8 || s->in_dim <= 4096) { uint32_t num_tgs = (s->out_dim + 7) / 8; [enc dispatchThreadgroups:MTLSizeMake(num_tgs, 1, 1) threadsPerThreadgroup:MTLSizeMake(256, 1, 1)]; @@ -1426,8 +1471,12 @@ static void gpu_encode_batch_matvec( id o_buf = ctx->batch_out[s->batch_slot]; id enc = [cmdbuf computeCommandEncoder]; - int use_v3 = (s->in_dim <= 4096); - [enc setComputePipelineState: use_v3 ? ctx->matvec_v3 : ctx->matvec_fast]; + if (s->bits == 8) { + [enc setComputePipelineState: ctx->matvec_8bit]; + } else { + int use_v3 = (s->in_dim <= 4096); + [enc setComputePipelineState: use_v3 ? ctx->matvec_v3 : ctx->matvec_fast]; + } [enc setBuffer:ctx->wf_buf offset:w_off atIndex:0]; [enc setBuffer:ctx->wf_buf offset:s_off atIndex:1]; [enc setBuffer:ctx->wf_buf offset:b_off atIndex:2]; @@ -1437,7 +1486,7 @@ static void gpu_encode_batch_matvec( [enc setBytes:&s->in_dim length:4 atIndex:6]; [enc setBytes:&s->group_size length:4 atIndex:7]; - if (use_v3) { + if (s->bits == 8 || s->in_dim <= 4096) { uint32_t num_tgs = (s->out_dim + 7) / 8; [enc dispatchThreadgroups:MTLSizeMake(num_tgs, 1, 1) threadsPerThreadgroup:MTLSizeMake(256, 1, 1)]; @@ -1886,8 +1935,13 @@ static void fast_batch_matvec( } else { for (int i = 0; i < num_specs; i++) { BatchMatvecSpec *s = &specs[i]; - cpu_dequant_matvec(s->W, s->scales, s->biases, x, s->out_cpu, - s->out_dim, s->in_dim, s->group_size); + if (s->bits == 8) { + cpu_dequant_matvec_8bit(s->W, s->scales, s->biases, x, s->out_cpu, + s->out_dim, s->in_dim, s->group_size); + } else { + cpu_dequant_matvec(s->W, s->scales, s->biases, x, s->out_cpu, + s->out_dim, s->in_dim, s->group_size); + } } } } @@ -2188,9 +2242,9 @@ static void full_attention_forward( // Batch Q/K/V into one command buffer (3 dispatches, 1 commit) if (qw && qs && qb && kw && ks && kb && vw && vs && vb) { BatchMatvecSpec qkv_specs[3] = { - { qw, qs, qb, q_proj_out, (uint32_t)q_proj_dim, HIDDEN_DIM, GROUP_SIZE, 0 }, - { kw, ks, kb, k, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 1 }, - { vw, vs, vb, v, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 2 }, + { qw, qs, qb, q_proj_out, (uint32_t)q_proj_dim, HIDDEN_DIM, GROUP_SIZE, 0, 4 }, + { kw, ks, kb, k, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 1, 4 }, + { vw, vs, vb, v, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 2, 4 }, }; fast_batch_matvec(normed, HIDDEN_DIM, qkv_specs, 3); } @@ -2446,10 +2500,10 @@ static void linear_attention_forward( if (qkv_w && qkv_s && qkv_b && z_w && z_s && z_b && b_w && b_s && b_b && a_w && a_s && a_b) { BatchMatvecSpec la_specs[4] = { - { qkv_w, qkv_s, qkv_b, qkv, (uint32_t)qkv_dim, HIDDEN_DIM, GROUP_SIZE, 0 }, - { z_w, z_s, z_b, z, (uint32_t)z_dim, HIDDEN_DIM, GROUP_SIZE, 1 }, - { b_w, b_s, b_b, beta, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 2 }, - { a_w, a_s, a_b, alpha, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 3 }, + { qkv_w, qkv_s, qkv_b, qkv, (uint32_t)qkv_dim, HIDDEN_DIM, GROUP_SIZE, 0, 4 }, + { z_w, z_s, z_b, z, (uint32_t)z_dim, HIDDEN_DIM, GROUP_SIZE, 1, 4 }, + { b_w, b_s, b_b, beta, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 2, 4 }, + { a_w, a_s, a_b, alpha, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 3, 4 }, }; fast_batch_matvec(normed, HIDDEN_DIM, la_specs, 4); } @@ -2699,10 +2753,10 @@ static void moe_forward( if (gate_w && gate_s && gate_b && sgw && sgs && sgb && suw && sus && sub && seg_w && seg_s && seg_b) { BatchMatvecSpec moe_specs[4] = { - { gate_w, gate_s, gate_b, gate_scores, (uint32_t)NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE, 0 }, - { sgw, sgs, sgb, shared_gate, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 1 }, - { suw, sus, sub, shared_up, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 2 }, - { seg_w, seg_s, seg_b, &shared_gate_score, 1, HIDDEN_DIM, GROUP_SIZE, 3 }, + { gate_w, gate_s, gate_b, gate_scores, (uint32_t)NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE, 0, 8 }, + { sgw, sgs, sgb, shared_gate, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 1, 4 }, + { suw, sus, sub, shared_up, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 2, 4 }, + { seg_w, seg_s, seg_b, &shared_gate_score, 1, HIDDEN_DIM, GROUP_SIZE, 3, 8 }, }; fast_batch_matvec(h_post, HIDDEN_DIM, moe_specs, 4); } @@ -4035,9 +4089,9 @@ static void fused_layer_forward( if (lc->q_w && lc->q_s && lc->q_b && lc->k_w && lc->k_s && lc->k_b && lc->v_w && lc->v_s && lc->v_b) { - attn_specs[0] = (BatchMatvecSpec){ lc->q_w, lc->q_s, lc->q_b, q_proj_out, (uint32_t)q_proj_dim, HIDDEN_DIM, GROUP_SIZE, 0 }; - attn_specs[1] = (BatchMatvecSpec){ lc->k_w, lc->k_s, lc->k_b, k_out, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 1 }; - attn_specs[2] = (BatchMatvecSpec){ lc->v_w, lc->v_s, lc->v_b, v_out, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 2 }; + attn_specs[0] = (BatchMatvecSpec){ lc->q_w, lc->q_s, lc->q_b, q_proj_out, (uint32_t)q_proj_dim, HIDDEN_DIM, GROUP_SIZE, 0, 4 }; + attn_specs[1] = (BatchMatvecSpec){ lc->k_w, lc->k_s, lc->k_b, k_out, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 1, 4 }; + attn_specs[2] = (BatchMatvecSpec){ lc->v_w, lc->v_s, lc->v_b, v_out, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 2, 4 }; num_attn_specs = 3; } } else { @@ -4051,10 +4105,10 @@ static void fused_layer_forward( if (lc->qkv_w && lc->qkv_s && lc->qkv_b && lc->z_w && lc->z_s && lc->z_b && lc->b_w && lc->b_s && lc->b_b && lc->a_w && lc->a_s && lc->a_b) { - attn_specs[0] = (BatchMatvecSpec){ lc->qkv_w, lc->qkv_s, lc->qkv_b, qkv_out, (uint32_t)qkv_dim, HIDDEN_DIM, GROUP_SIZE, 0 }; - attn_specs[1] = (BatchMatvecSpec){ lc->z_w, lc->z_s, lc->z_b, z_out, (uint32_t)z_dim, HIDDEN_DIM, GROUP_SIZE, 1 }; - attn_specs[2] = (BatchMatvecSpec){ lc->b_w, lc->b_s, lc->b_b, beta_out, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 2 }; - attn_specs[3] = (BatchMatvecSpec){ lc->a_w, lc->a_s, lc->a_b, alpha_out, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 3 }; + attn_specs[0] = (BatchMatvecSpec){ lc->qkv_w, lc->qkv_s, lc->qkv_b, qkv_out, (uint32_t)qkv_dim, HIDDEN_DIM, GROUP_SIZE, 0, 4 }; + attn_specs[1] = (BatchMatvecSpec){ lc->z_w, lc->z_s, lc->z_b, z_out, (uint32_t)z_dim, HIDDEN_DIM, GROUP_SIZE, 1, 4 }; + attn_specs[2] = (BatchMatvecSpec){ lc->b_w, lc->b_s, lc->b_b, beta_out, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 2, 4 }; + attn_specs[3] = (BatchMatvecSpec){ lc->a_w, lc->a_s, lc->a_b, alpha_out, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 3, 4 }; num_attn_specs = 4; } } @@ -4339,8 +4393,13 @@ static void fused_layer_forward( } else { for (int i = 0; i < num_attn_specs; i++) { BatchMatvecSpec *s = &attn_specs[i]; - cpu_dequant_matvec(s->W, s->scales, s->biases, normed, s->out_cpu, - s->out_dim, s->in_dim, s->group_size); + if (s->bits == 8) { + cpu_dequant_matvec_8bit(s->W, s->scales, s->biases, normed, s->out_cpu, + s->out_dim, s->in_dim, s->group_size); + } else { + cpu_dequant_matvec(s->W, s->scales, s->biases, normed, s->out_cpu, + s->out_dim, s->in_dim, s->group_size); + } } } if (g_timing_enabled) { t1 = now_ms(); g_timing.cmd1_submit += t1 - t0; } @@ -4378,9 +4437,9 @@ static void fused_layer_forward( memset(spec_scores, 0, NUM_EXPERTS * sizeof(float)); // Gate projection matvec on pre-attention normed input (CPU, ~0.1ms for 512x4096) - cpu_dequant_matvec(lc->gate_w, lc->gate_s, lc->gate_b, - normed, spec_scores, - NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE); + cpu_dequant_matvec_8bit(lc->gate_w, lc->gate_s, lc->gate_b, + normed, spec_scores, + NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE); cpu_softmax(spec_scores, NUM_EXPERTS); int spec_K = (K > MAX_K) ? MAX_K : K; @@ -4968,10 +5027,10 @@ static void fused_layer_forward( // ---- Enc 5-8: routing + shared expert projections (read buf_input) ---- BatchMatvecSpec moe_specs[4] = { - { gate_w, gate_s, gate_b, gate_scores, (uint32_t)NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE, 0 }, - { sgw, sgs, sgb, shared_gate, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 1 }, - { suw, sus, sub, shared_up, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 2 }, - { seg_w, seg_s, seg_b, &shared_gate_score, 1, HIDDEN_DIM, GROUP_SIZE, 3 }, + { gate_w, gate_s, gate_b, gate_scores, (uint32_t)NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE, 0, 8 }, + { sgw, sgs, sgb, shared_gate, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 1, 4 }, + { suw, sus, sub, shared_up, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 2, 4 }, + { seg_w, seg_s, seg_b, &shared_gate_score, 1, HIDDEN_DIM, GROUP_SIZE, 3, 8 }, }; // buf_input already contains h_post from Enc 4 output -- no memcpy needed gpu_encode_batch_matvec(g_metal, cmd_fused, moe_specs, 4); @@ -5017,10 +5076,10 @@ static void fused_layer_forward( // Routing + shared expert batch if (have_moe_weights) { BatchMatvecSpec moe_specs[4] = { - { gate_w, gate_s, gate_b, gate_scores, (uint32_t)NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE, 0 }, - { sgw, sgs, sgb, shared_gate, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 1 }, - { suw, sus, sub, shared_up, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 2 }, - { seg_w, seg_s, seg_b, &shared_gate_score, 1, HIDDEN_DIM, GROUP_SIZE, 3 }, + { gate_w, gate_s, gate_b, gate_scores, (uint32_t)NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE, 0, 8 }, + { sgw, sgs, sgb, shared_gate, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 1, 4 }, + { suw, sus, sub, shared_up, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 2, 4 }, + { seg_w, seg_s, seg_b, &shared_gate_score, 1, HIDDEN_DIM, GROUP_SIZE, 3, 8 }, }; fast_batch_matvec(h_post, HIDDEN_DIM, moe_specs, 4); } diff --git a/metal_infer/shaders.metal b/metal_infer/shaders.metal index 80a3be6b..a15e41b2 100644 --- a/metal_infer/shaders.metal +++ b/metal_infer/shaders.metal @@ -492,6 +492,74 @@ kernel void dequant_matvec_2bit( } +// ============================================================================ +// Kernel 1e: 8-bit dequantized matrix-vector multiply (FMA-optimized) +// ============================================================================ +// Same structure as dequant_matvec_4bit_v3 but for 8-bit quantization: +// - 4 values per uint32 (vs 8 for 4-bit, 16 for 2-bit) +// - packed_cols = in_dim / 4 +// - Extract bytes: (packed >> 0) & 0xFF, >> 8, >> 16, >> 24 +// - FMA-optimized: (byte * scale + bias) * x = fma(byte, scale*x, bias*x) +// Used for gate routing weights and shared_expert_gate in Qwen3.5-397B. + +kernel void dequant_matvec_8bit( + device const uint32_t* W_packed [[buffer(0)]], // [out_dim, in_dim/4] + device const uint16_t* scales [[buffer(1)]], // [out_dim, num_groups] bf16 + device const uint16_t* biases [[buffer(2)]], // [out_dim, num_groups] bf16 + device const float* x [[buffer(3)]], // [in_dim] + device float* out [[buffer(4)]], // [out_dim] + constant uint& out_dim [[buffer(5)]], + constant uint& in_dim [[buffer(6)]], + constant uint& group_size [[buffer(7)]], + uint tgid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint simd_lane [[thread_index_in_simdgroup]], + uint simd_group [[simdgroup_index_in_threadgroup]] +) { + uint row = tgid * ROWS_PER_TG + simd_group; + uint packed_cols = in_dim / 4; // 4 values per uint32 for 8-bit + uint num_groups = in_dim / group_size; + + threadgroup float x_shared[4096]; + for (uint i = lid; i < in_dim; i += 256) { + x_shared[i] = x[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (row >= out_dim) return; + + device const uint32_t* w_row = W_packed + row * packed_cols; + device const uint16_t* s_row = scales + row * num_groups; + device const uint16_t* b_row = biases + row * num_groups; + + float acc = 0.0f; + + for (uint col = simd_lane; col < packed_cols; col += 32) { + // group_size/4 packed words per group + uint g = col / (group_size / 4); + float scale = bf16_to_f32(s_row[g]); + float bias = bf16_to_f32(b_row[g]); + + uint32_t packed = w_row[col]; + uint x_base = col * 4; + + float sx0 = scale * x_shared[x_base + 0]; float bx0 = bias * x_shared[x_base + 0]; + float sx1 = scale * x_shared[x_base + 1]; float bx1 = bias * x_shared[x_base + 1]; + float sx2 = scale * x_shared[x_base + 2]; float bx2 = bias * x_shared[x_base + 2]; + float sx3 = scale * x_shared[x_base + 3]; float bx3 = bias * x_shared[x_base + 3]; + + acc += fma(float((packed >> 0) & 0xFF), sx0, bx0); + acc += fma(float((packed >> 8) & 0xFF), sx1, bx1); + acc += fma(float((packed >> 16) & 0xFF), sx2, bx2); + acc += fma(float((packed >> 24) & 0xFF), sx3, bx3); + } + + float sum = simd_sum(acc); + if (simd_lane == 0) { + out[row] = sum; + } +} + + // ============================================================================ // Kernel 1d: FULLY OPTIMIZED with uint4 vector loads // ============================================================================