Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 98 additions & 39 deletions metal_infer/infer.m
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,44 @@ static void cpu_dequant_matvec(
}
}

// CPU fallback for the 8-bit affine-quantized matvec: out = dequant(W) * x.
//
// Each uint32 word of W carries four 8-bit codes, least-significant byte
// first. Every run of group_size consecutive inputs within a row shares one
// bf16 scale and one bf16 bias; a code q contributes (q * scale + bias) * x[i].
//
//   W       packed codes, in_dim / 4 words per output row
//   scales  bf16 scales, in_dim / group_size entries per output row
//   biases  bf16 biases, same layout as scales
//   x       input vector
//   out     receives out_dim results (overwritten)
//
// All three per-row counts come from truncating integer division, so in_dim
// is expected to be a multiple of group_size and group_size a multiple of 4.
static void cpu_dequant_matvec_8bit(
    const uint32_t *W, const uint16_t *scales, const uint16_t *biases,
    const float *x, float *out,
    int out_dim, int in_dim, int group_size
) {
    const int groups_per_row  = in_dim / group_size;
    const int words_per_group = group_size / 4;
    const int words_per_row   = in_dim / 4;

    for (int r = 0; r < out_dim; r++) {
        const uint32_t *row_words  = W + r * words_per_row;
        const uint16_t *row_scales = scales + r * groups_per_row;
        const uint16_t *row_biases = biases + r * groups_per_row;
        float dot = 0.0f;

        for (int grp = 0; grp < groups_per_row; grp++) {
            const float scale = bf16_to_f32(row_scales[grp]);
            const float bias  = bf16_to_f32(row_biases[grp]);
            const uint32_t *grp_words = row_words + grp * words_per_group;
            const float *grp_x = x + grp * group_size;

            for (int w = 0; w < words_per_group; w++) {
                const uint32_t word = grp_words[w];
                const float *xv = grp_x + w * 4;

                // Peel codes off low byte first so terms are accumulated in
                // the same order as their matching inputs.
                for (int lane = 0; lane < 4; lane++) {
                    const uint32_t code = (word >> (8 * lane)) & 0xFF;
                    dot += ((float)code * scale + bias) * xv[lane];
                }
            }
        }
        out[r] = dot;
    }
}

// RMS normalization: out = x * w / rms(x)
static void cpu_rms_norm(const float *x, const uint16_t *w_bf16, float *out, int dim, float eps) {
float sum_sq = 0.0f;
Expand Down Expand Up @@ -905,6 +943,7 @@ static void cpu_conv1d_step(
id<MTLComputePipelineState> matvec_v5; // LUT dequant variant
id<MTLComputePipelineState> matvec_fast; // for in_dim > 4096
id<MTLComputePipelineState> matvec_2bit; // 2-bit expert dequant kernel
id<MTLComputePipelineState> matvec_8bit; // 8-bit gate dequant kernel
id<MTLComputePipelineState> rms_norm_sum;
id<MTLComputePipelineState> rms_norm_apply;
id<MTLComputePipelineState> rms_norm_apply_bf16;
Expand Down Expand Up @@ -1045,6 +1084,7 @@ static void cpu_conv1d_step(
ctx->matvec_v5 = makePipe(@"dequant_matvec_4bit_v5"); // LUT variant (no uint→float conversions)
ctx->matvec_fast = makePipe(@"dequant_matvec_4bit_fast");
ctx->matvec_2bit = makePipe(@"dequant_matvec_2bit");
ctx->matvec_8bit = makePipe(@"dequant_matvec_8bit");
ctx->rms_norm_sum = makePipe(@"rms_norm_sum_sq");
ctx->rms_norm_apply = makePipe(@"rms_norm_apply");
ctx->rms_norm_apply_bf16 = makePipe(@"rms_norm_apply_bf16");
Expand Down Expand Up @@ -1349,6 +1389,7 @@ static void fast_dequant_matvec(
uint32_t in_dim;
uint32_t group_size;
int batch_slot; // which batch_out[slot] to use for GPU output
int bits; // quantization bits: 4 (default) or 8 (gate weights)
} BatchMatvecSpec;

// Run N matmuls in a single command buffer. All share the same input vector.
Expand All @@ -1372,8 +1413,12 @@ static void gpu_batch_matvec(
id<MTLBuffer> o_buf = ctx->batch_out[s->batch_slot];

id<MTLComputeCommandEncoder> enc = [cmdbuf computeCommandEncoder];
int use_v3 = (s->in_dim <= 4096);
[enc setComputePipelineState: use_v3 ? ctx->matvec_v3 : ctx->matvec_fast];
if (s->bits == 8) {
[enc setComputePipelineState: ctx->matvec_8bit];
} else {
int use_v3 = (s->in_dim <= 4096);
[enc setComputePipelineState: use_v3 ? ctx->matvec_v3 : ctx->matvec_fast];
}
[enc setBuffer:ctx->wf_buf offset:w_off atIndex:0];
[enc setBuffer:ctx->wf_buf offset:s_off atIndex:1];
[enc setBuffer:ctx->wf_buf offset:b_off atIndex:2];
Expand All @@ -1383,7 +1428,7 @@ static void gpu_batch_matvec(
[enc setBytes:&s->in_dim length:4 atIndex:6];
[enc setBytes:&s->group_size length:4 atIndex:7];

if (use_v3) {
if (s->bits == 8 || s->in_dim <= 4096) {
uint32_t num_tgs = (s->out_dim + 7) / 8;
[enc dispatchThreadgroups:MTLSizeMake(num_tgs, 1, 1)
threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
Expand Down Expand Up @@ -1426,8 +1471,12 @@ static void gpu_encode_batch_matvec(
id<MTLBuffer> o_buf = ctx->batch_out[s->batch_slot];

id<MTLComputeCommandEncoder> enc = [cmdbuf computeCommandEncoder];
int use_v3 = (s->in_dim <= 4096);
[enc setComputePipelineState: use_v3 ? ctx->matvec_v3 : ctx->matvec_fast];
if (s->bits == 8) {
[enc setComputePipelineState: ctx->matvec_8bit];
} else {
int use_v3 = (s->in_dim <= 4096);
[enc setComputePipelineState: use_v3 ? ctx->matvec_v3 : ctx->matvec_fast];
}
[enc setBuffer:ctx->wf_buf offset:w_off atIndex:0];
[enc setBuffer:ctx->wf_buf offset:s_off atIndex:1];
[enc setBuffer:ctx->wf_buf offset:b_off atIndex:2];
Expand All @@ -1437,7 +1486,7 @@ static void gpu_encode_batch_matvec(
[enc setBytes:&s->in_dim length:4 atIndex:6];
[enc setBytes:&s->group_size length:4 atIndex:7];

if (use_v3) {
if (s->bits == 8 || s->in_dim <= 4096) {
uint32_t num_tgs = (s->out_dim + 7) / 8;
[enc dispatchThreadgroups:MTLSizeMake(num_tgs, 1, 1)
threadsPerThreadgroup:MTLSizeMake(256, 1, 1)];
Expand Down Expand Up @@ -1886,8 +1935,13 @@ static void fast_batch_matvec(
} else {
for (int i = 0; i < num_specs; i++) {
BatchMatvecSpec *s = &specs[i];
cpu_dequant_matvec(s->W, s->scales, s->biases, x, s->out_cpu,
s->out_dim, s->in_dim, s->group_size);
if (s->bits == 8) {
cpu_dequant_matvec_8bit(s->W, s->scales, s->biases, x, s->out_cpu,
s->out_dim, s->in_dim, s->group_size);
} else {
cpu_dequant_matvec(s->W, s->scales, s->biases, x, s->out_cpu,
s->out_dim, s->in_dim, s->group_size);
}
}
}
}
Expand Down Expand Up @@ -2188,9 +2242,9 @@ static void full_attention_forward(
// Batch Q/K/V into one command buffer (3 dispatches, 1 commit)
if (qw && qs && qb && kw && ks && kb && vw && vs && vb) {
BatchMatvecSpec qkv_specs[3] = {
{ qw, qs, qb, q_proj_out, (uint32_t)q_proj_dim, HIDDEN_DIM, GROUP_SIZE, 0 },
{ kw, ks, kb, k, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 1 },
{ vw, vs, vb, v, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 2 },
{ qw, qs, qb, q_proj_out, (uint32_t)q_proj_dim, HIDDEN_DIM, GROUP_SIZE, 0, 4 },
{ kw, ks, kb, k, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 1, 4 },
{ vw, vs, vb, v, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 2, 4 },
};
fast_batch_matvec(normed, HIDDEN_DIM, qkv_specs, 3);
}
Expand Down Expand Up @@ -2446,10 +2500,10 @@ static void linear_attention_forward(
if (qkv_w && qkv_s && qkv_b && z_w && z_s && z_b &&
b_w && b_s && b_b && a_w && a_s && a_b) {
BatchMatvecSpec la_specs[4] = {
{ qkv_w, qkv_s, qkv_b, qkv, (uint32_t)qkv_dim, HIDDEN_DIM, GROUP_SIZE, 0 },
{ z_w, z_s, z_b, z, (uint32_t)z_dim, HIDDEN_DIM, GROUP_SIZE, 1 },
{ b_w, b_s, b_b, beta, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 2 },
{ a_w, a_s, a_b, alpha, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 3 },
{ qkv_w, qkv_s, qkv_b, qkv, (uint32_t)qkv_dim, HIDDEN_DIM, GROUP_SIZE, 0, 4 },
{ z_w, z_s, z_b, z, (uint32_t)z_dim, HIDDEN_DIM, GROUP_SIZE, 1, 4 },
{ b_w, b_s, b_b, beta, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 2, 4 },
{ a_w, a_s, a_b, alpha, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 3, 4 },
};
fast_batch_matvec(normed, HIDDEN_DIM, la_specs, 4);
}
Expand Down Expand Up @@ -2699,10 +2753,10 @@ static void moe_forward(
if (gate_w && gate_s && gate_b && sgw && sgs && sgb &&
suw && sus && sub && seg_w && seg_s && seg_b) {
BatchMatvecSpec moe_specs[4] = {
{ gate_w, gate_s, gate_b, gate_scores, (uint32_t)NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE, 0 },
{ sgw, sgs, sgb, shared_gate, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 1 },
{ suw, sus, sub, shared_up, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 2 },
{ seg_w, seg_s, seg_b, &shared_gate_score, 1, HIDDEN_DIM, GROUP_SIZE, 3 },
{ gate_w, gate_s, gate_b, gate_scores, (uint32_t)NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE, 0, 8 },
{ sgw, sgs, sgb, shared_gate, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 1, 4 },
{ suw, sus, sub, shared_up, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 2, 4 },
{ seg_w, seg_s, seg_b, &shared_gate_score, 1, HIDDEN_DIM, GROUP_SIZE, 3, 8 },
};
fast_batch_matvec(h_post, HIDDEN_DIM, moe_specs, 4);
}
Expand Down Expand Up @@ -4035,9 +4089,9 @@ static void fused_layer_forward(

if (lc->q_w && lc->q_s && lc->q_b && lc->k_w && lc->k_s && lc->k_b &&
lc->v_w && lc->v_s && lc->v_b) {
attn_specs[0] = (BatchMatvecSpec){ lc->q_w, lc->q_s, lc->q_b, q_proj_out, (uint32_t)q_proj_dim, HIDDEN_DIM, GROUP_SIZE, 0 };
attn_specs[1] = (BatchMatvecSpec){ lc->k_w, lc->k_s, lc->k_b, k_out, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 1 };
attn_specs[2] = (BatchMatvecSpec){ lc->v_w, lc->v_s, lc->v_b, v_out, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 2 };
attn_specs[0] = (BatchMatvecSpec){ lc->q_w, lc->q_s, lc->q_b, q_proj_out, (uint32_t)q_proj_dim, HIDDEN_DIM, GROUP_SIZE, 0, 4 };
attn_specs[1] = (BatchMatvecSpec){ lc->k_w, lc->k_s, lc->k_b, k_out, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 1, 4 };
attn_specs[2] = (BatchMatvecSpec){ lc->v_w, lc->v_s, lc->v_b, v_out, (uint32_t)kv_dim, HIDDEN_DIM, GROUP_SIZE, 2, 4 };
num_attn_specs = 3;
}
} else {
Expand All @@ -4051,10 +4105,10 @@ static void fused_layer_forward(

if (lc->qkv_w && lc->qkv_s && lc->qkv_b && lc->z_w && lc->z_s && lc->z_b &&
lc->b_w && lc->b_s && lc->b_b && lc->a_w && lc->a_s && lc->a_b) {
attn_specs[0] = (BatchMatvecSpec){ lc->qkv_w, lc->qkv_s, lc->qkv_b, qkv_out, (uint32_t)qkv_dim, HIDDEN_DIM, GROUP_SIZE, 0 };
attn_specs[1] = (BatchMatvecSpec){ lc->z_w, lc->z_s, lc->z_b, z_out, (uint32_t)z_dim, HIDDEN_DIM, GROUP_SIZE, 1 };
attn_specs[2] = (BatchMatvecSpec){ lc->b_w, lc->b_s, lc->b_b, beta_out, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 2 };
attn_specs[3] = (BatchMatvecSpec){ lc->a_w, lc->a_s, lc->a_b, alpha_out, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 3 };
attn_specs[0] = (BatchMatvecSpec){ lc->qkv_w, lc->qkv_s, lc->qkv_b, qkv_out, (uint32_t)qkv_dim, HIDDEN_DIM, GROUP_SIZE, 0, 4 };
attn_specs[1] = (BatchMatvecSpec){ lc->z_w, lc->z_s, lc->z_b, z_out, (uint32_t)z_dim, HIDDEN_DIM, GROUP_SIZE, 1, 4 };
attn_specs[2] = (BatchMatvecSpec){ lc->b_w, lc->b_s, lc->b_b, beta_out, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 2, 4 };
attn_specs[3] = (BatchMatvecSpec){ lc->a_w, lc->a_s, lc->a_b, alpha_out, (uint32_t)LINEAR_NUM_V_HEADS, HIDDEN_DIM, GROUP_SIZE, 3, 4 };
num_attn_specs = 4;
}
}
Expand Down Expand Up @@ -4339,8 +4393,13 @@ static void fused_layer_forward(
} else {
for (int i = 0; i < num_attn_specs; i++) {
BatchMatvecSpec *s = &attn_specs[i];
cpu_dequant_matvec(s->W, s->scales, s->biases, normed, s->out_cpu,
s->out_dim, s->in_dim, s->group_size);
if (s->bits == 8) {
cpu_dequant_matvec_8bit(s->W, s->scales, s->biases, normed, s->out_cpu,
s->out_dim, s->in_dim, s->group_size);
} else {
cpu_dequant_matvec(s->W, s->scales, s->biases, normed, s->out_cpu,
s->out_dim, s->in_dim, s->group_size);
}
}
}
if (g_timing_enabled) { t1 = now_ms(); g_timing.cmd1_submit += t1 - t0; }
Expand Down Expand Up @@ -4378,9 +4437,9 @@ static void fused_layer_forward(
memset(spec_scores, 0, NUM_EXPERTS * sizeof(float));

// Gate projection matvec on pre-attention normed input (CPU, ~0.1ms for 512x4096)
cpu_dequant_matvec(lc->gate_w, lc->gate_s, lc->gate_b,
normed, spec_scores,
NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE);
cpu_dequant_matvec_8bit(lc->gate_w, lc->gate_s, lc->gate_b,
normed, spec_scores,
NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE);
cpu_softmax(spec_scores, NUM_EXPERTS);

int spec_K = (K > MAX_K) ? MAX_K : K;
Expand Down Expand Up @@ -4968,10 +5027,10 @@ static void fused_layer_forward(

// ---- Enc 5-8: routing + shared expert projections (read buf_input) ----
BatchMatvecSpec moe_specs[4] = {
{ gate_w, gate_s, gate_b, gate_scores, (uint32_t)NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE, 0 },
{ sgw, sgs, sgb, shared_gate, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 1 },
{ suw, sus, sub, shared_up, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 2 },
{ seg_w, seg_s, seg_b, &shared_gate_score, 1, HIDDEN_DIM, GROUP_SIZE, 3 },
{ gate_w, gate_s, gate_b, gate_scores, (uint32_t)NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE, 0, 8 },
{ sgw, sgs, sgb, shared_gate, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 1, 4 },
{ suw, sus, sub, shared_up, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 2, 4 },
{ seg_w, seg_s, seg_b, &shared_gate_score, 1, HIDDEN_DIM, GROUP_SIZE, 3, 8 },
};
// buf_input already contains h_post from Enc 4 output -- no memcpy needed
gpu_encode_batch_matvec(g_metal, cmd_fused, moe_specs, 4);
Expand Down Expand Up @@ -5017,10 +5076,10 @@ static void fused_layer_forward(
// Routing + shared expert batch
if (have_moe_weights) {
BatchMatvecSpec moe_specs[4] = {
{ gate_w, gate_s, gate_b, gate_scores, (uint32_t)NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE, 0 },
{ sgw, sgs, sgb, shared_gate, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 1 },
{ suw, sus, sub, shared_up, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 2 },
{ seg_w, seg_s, seg_b, &shared_gate_score, 1, HIDDEN_DIM, GROUP_SIZE, 3 },
{ gate_w, gate_s, gate_b, gate_scores, (uint32_t)NUM_EXPERTS, HIDDEN_DIM, GROUP_SIZE, 0, 8 },
{ sgw, sgs, sgb, shared_gate, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 1, 4 },
{ suw, sus, sub, shared_up, (uint32_t)SHARED_INTERMEDIATE, HIDDEN_DIM, GROUP_SIZE, 2, 4 },
{ seg_w, seg_s, seg_b, &shared_gate_score, 1, HIDDEN_DIM, GROUP_SIZE, 3, 8 },
};
fast_batch_matvec(h_post, HIDDEN_DIM, moe_specs, 4);
}
Expand Down
68 changes: 68 additions & 0 deletions metal_infer/shaders.metal
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,74 @@ kernel void dequant_matvec_2bit(
}


// ============================================================================
// Kernel 1e: 8-bit dequantized matrix-vector multiply (FMA-optimized)
// ============================================================================
// Same structure as dequant_matvec_4bit_v3 but for 8-bit quantization:
//   - 4 values per uint32 (vs 8 for 4-bit, 16 for 2-bit)
//   - packed_cols = in_dim / 4
//   - Extract bytes: (packed >> 0) & 0xFF, >> 8, >> 16, >> 24
//   - FMA-optimized: (byte * scale + bias) * x = fma(byte, scale*x, bias*x)
// Used for gate routing weights and shared_expert_gate in Qwen3.5-397B.
//
// Each SIMD group produces one output row: its lanes stride over the packed
// words of that row and the partial sums are combined with simd_sum.
//
// The input vector is staged in threadgroup memory only when it fits in the
// 4096-float staging array. For larger in_dim the kernel reads x straight from
// device memory instead of writing past the end of that array; the host
// selects this kernel for every 8-bit spec regardless of in_dim.
//
// NOTE(review): the staging loop strides by 256, i.e. it assumes 256 threads
// per threadgroup, and ROWS_PER_TG is presumably 256 / 32 = 8 -- confirm
// against the dispatch code and the ROWS_PER_TG definition.

kernel void dequant_matvec_8bit(
    device const uint32_t* W_packed   [[buffer(0)]],  // [out_dim, in_dim/4]
    device const uint16_t* scales     [[buffer(1)]],  // [out_dim, num_groups] bf16
    device const uint16_t* biases     [[buffer(2)]],  // [out_dim, num_groups] bf16
    device const float*    x          [[buffer(3)]],  // [in_dim]
    device float*          out        [[buffer(4)]],  // [out_dim]
    constant uint&         out_dim    [[buffer(5)]],
    constant uint&         in_dim     [[buffer(6)]],
    constant uint&         group_size [[buffer(7)]],
    uint tgid       [[threadgroup_position_in_grid]],
    uint lid        [[thread_position_in_threadgroup]],
    uint simd_lane  [[thread_index_in_simdgroup]],
    uint simd_group [[simdgroup_index_in_threadgroup]]
) {
    uint row = tgid * ROWS_PER_TG + simd_group;
    uint packed_cols = in_dim / 4;            // 4 values per uint32 for 8-bit
    uint num_groups = in_dim / group_size;
    uint packed_per_group = group_size / 4;   // packed words sharing one scale/bias

    threadgroup float x_shared[4096];

    // in_dim comes from a constant buffer, so every thread takes the same
    // branch here and the barrier below stays uniform.
    const bool x_fits = (in_dim <= 4096);
    if (x_fits) {
        for (uint i = lid; i < in_dim; i += 256) {
            x_shared[i] = x[i];
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Out-of-range rows leave only after the barrier so that every thread in
    // the threadgroup reaches it.
    if (row >= out_dim) return;

    device const uint32_t* w_row = W_packed + row * packed_cols;
    device const uint16_t* s_row = scales + row * num_groups;
    device const uint16_t* b_row = biases + row * num_groups;

    float acc = 0.0f;

    for (uint col = simd_lane; col < packed_cols; col += 32) {
        uint g = col / packed_per_group;
        float scale = bf16_to_f32(s_row[g]);
        float bias  = bf16_to_f32(b_row[g]);

        uint32_t packed = w_row[col];
        uint x_base = col * 4;

        // Fetch the four inputs matching this packed word from whichever
        // copy of x is valid for this in_dim.
        float x0, x1, x2, x3;
        if (x_fits) {
            x0 = x_shared[x_base + 0];
            x1 = x_shared[x_base + 1];
            x2 = x_shared[x_base + 2];
            x3 = x_shared[x_base + 3];
        } else {
            x0 = x[x_base + 0];
            x1 = x[x_base + 1];
            x2 = x[x_base + 2];
            x3 = x[x_base + 3];
        }

        float sx0 = scale * x0;  float bx0 = bias * x0;
        float sx1 = scale * x1;  float bx1 = bias * x1;
        float sx2 = scale * x2;  float bx2 = bias * x2;
        float sx3 = scale * x3;  float bx3 = bias * x3;

        acc += fma(float((packed >>  0) & 0xFF), sx0, bx0);
        acc += fma(float((packed >>  8) & 0xFF), sx1, bx1);
        acc += fma(float((packed >> 16) & 0xFF), sx2, bx2);
        acc += fma(float((packed >> 24) & 0xFF), sx3, bx3);
    }

    // Combine the per-lane partial sums; lane 0 writes the row result.
    float sum = simd_sum(acc);
    if (simd_lane == 0) {
        out[row] = sum;
    }
}


// ============================================================================
// Kernel 1d: FULLY OPTIMIZED with uint4 vector loads
// ============================================================================
Expand Down