From 51641a73604631bc87ab5c5e02a0681d71ef96b6 Mon Sep 17 00:00:00 2001 From: Thiago Padilha Date: Fri, 3 Jul 2026 12:00:21 -0300 Subject: [PATCH 1/3] metal: add GGML_OP_LIGHTNING_INDEXER Metal kernel Implements the DeepSeek V4 lightning indexer on Metal GPU. Follows the CUDA vec kernel approach with 8 SIMD groups per threadgroup, each processing 8 KV vectors using simd_sum for per-head dot product reduction. Supports F32, F16, BF16 and quantized K types (Q4_0, Q4_1, Q5_0, Q5_1, Q8_0). Assisted-by: DeepSeek V4 Pro --- ggml/src/ggml-metal/ggml-metal-device.cpp | 23 ++ ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 37 +++ ggml/src/ggml-metal/ggml-metal-impl.h | 20 ++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 84 +++++++ ggml/src/ggml-metal/ggml-metal-ops.h | 1 + ggml/src/ggml-metal/ggml-metal.metal | 262 ++++++++++++++++++++++ 7 files changed, 428 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 0e1f1de4577d..f8fcb0331dfd 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -473,6 +473,29 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_me return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_lightning_indexer(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->op == GGML_OP_LIGHTNING_INDEXER); + + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + + GGML_ASSERT(src0->ne[0] == 128); // n_embd + GGML_ASSERT(src0->ne[1] == 64); // n_head + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_lightning_indexer_%s", ggml_type_name(src1->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + return res; +} + 
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index d465f31c083b..4db4ce812efa 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -124,6 +124,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_bl struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_lightning_indexer (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index a7cbc60ebe41..55b284d7c2e6 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1255,6 +1255,43 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return false; } return has_simdgroup_mm; // TODO: over-restricted for vec-kernels + case 
GGML_OP_LIGHTNING_INDEXER: + { + // DeepSeek V4 lightning indexer: n_embd=128, n_head=64 (the kernel reduces with simd_sum) + const int64_t n_embd = op->src[0]->ne[0]; + const int64_t n_head = op->src[0]->ne[1]; + + if (!has_simdgroup_reduction || n_embd != 128 || n_head != 64) { + return false; + } + + if (op->src[0]->type != GGML_TYPE_F32) { + return false; + } + if (op->src[2]->type != GGML_TYPE_F32) { + return false; + } + + switch (op->src[1]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + break; + case GGML_TYPE_BF16: + if (!has_bfloat) { + return false; + } + break; + default: + return false; + } + + return true; + } case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: return has_simdgroup_reduction; diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index ff74cafb5b79..e94efbf9d519 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -101,6 +101,7 @@ #define FC_SUM_ROWS 1400 #define FC_UPSCALE 1500 #define FC_GATED_DELTA_NET 1600 +#define FC_LIGHTNING_INDEXER 1700 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPSG 8 @@ -1172,4 +1173,23 @@ typedef struct { int64_t np; } ggml_metal_kargs_opt_step_sgd; +typedef struct { + int32_t n_kv; + int32_t n_head; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + float scale_embd; + float scale_heads; +} ggml_metal_kargs_lightning_indexer; + #endif // GGML_METAL_IMPL diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 18656b346f21..e0fea2490706 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -316,6 +316,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_cumsum(ctx, idx); } break; + 
case GGML_OP_LIGHTNING_INDEXER: + { + n_fuse = ggml_metal_op_lightning_indexer(ctx, idx); + } break; case GGML_OP_SOFT_MAX: { n_fuse = ggml_metal_op_soft_max(ctx, idx); @@ -1289,6 +1293,86 @@ int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_lightning_indexer(ggml_metal_op_t ctx, int idx) { + ggml_tensor * dst = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_ASSERT(dst->op == GGML_OP_LIGHTNING_INDEXER); + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src2->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_TERNARY_OP_LOCALS + + // input tensor rows must be contiguous + GGML_ASSERT(nb00 == ggml_type_size(src0->type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + GGML_ASSERT(nb20 == ggml_type_size(src2->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + const int n_embd = (int) src0->ne[0]; + const int n_head = (int) src0->ne[1]; + const int n_batch = (int) src0->ne[2]; + const int n_stream = (int) src0->ne[3]; + const int n_kv = (int) src1->ne[2]; + + const float scale_embd = ggml_get_op_params_f32(dst, 0); + const float scale_heads = ggml_get_op_params_f32(dst, 1); + + GGML_ASSERT(n_embd == 128); + GGML_ASSERT(n_head == 64); + + ggml_metal_kargs_lightning_indexer args = { + /*.n_kv =*/ n_kv, + /*.n_head =*/ n_head, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.scale_embd =*/ scale_embd, + /*.scale_heads=*/ scale_heads, + }; + + auto pipeline = 
ggml_metal_library_get_pipeline_lightning_indexer(lib, dst); + + constexpr int K_VECS_PER_SG = 8; + constexpr int N_SG_PER_TG = 8; + constexpr int K_VECS_PER_TG = K_VECS_PER_SG * N_SG_PER_TG; + + int num_kv_blocks = (n_kv + K_VECS_PER_TG - 1) / K_VECS_PER_TG; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(src0), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(src1), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(src2), 3); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(dst), 4); + + ggml_metal_encoder_dispatch_threadgroups(enc, num_kv_blocks, n_batch, n_stream, 32, N_SG_PER_TG, 1); + + return 1; +} + int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 36c61071b4fa..77ab97eebd13 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -54,6 +54,7 @@ int ggml_metal_op_cumsum (ggml_metal_op_t ctx, int idx); int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_diag (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_lightning_indexer (ggml_metal_op_t ctx, int idx); int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 25e78e100898..39bd35a9ed5a 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -10752,3 +10752,265 @@ kernel void kernel_count_equal( typedef decltype(kernel_count_equal) kernel_count_equal_t; template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t 
kernel_count_equal; + +// Lightning indexer kernel +// Follows the CUDA vec kernel approach: +// - Each threadgroup processes K_VECS_PER_TG K vectors +// - Each SIMD group processes K_VECS_PER_SG K vectors +// - Each thread (lane) handles one float4 (4 elements) of the 128-element vectors +// - n_embd=128, n_head=64 are hardcoded (matching DeepSeek V4) + +constexpr constant int LI_N_EMBD = 128; +constexpr constant int LI_N_HEAD = 64; +constexpr constant int LI_N_EMBD_4 = LI_N_EMBD / 4; // 32 float4s per vector +constexpr constant int LI_K_VECS_PER_SG = 8; +constexpr constant int LI_N_SG_PER_TG = 8; +constexpr constant int LI_K_VECS_PER_TG = LI_K_VECS_PER_SG * LI_N_SG_PER_TG; // 64 +constexpr constant int LI_N_HEAD_INNER = LI_N_HEAD / 4; // 16 + +// shared compute logic, after K has been loaded to float4 registers +// threadgroup memory pointers are passed in from the kernel entry point +void kernel_lightning_indexer_compute( + constant ggml_metal_kargs_lightning_indexer & args, + device const float * q_base, + device const float * w_base, + device float * dst, + thread float4 * k_reg, + threadgroup float * w_shared, + threadgroup float4 * q_shared, + threadgroup float * dst_shared, + int start_kv_block, + int i_batch, int i_stream, + ushort tiisg, ushort sgitg) { + + float score_k[LI_K_VECS_PER_SG] = { 0.0f }; + + for (int i_head_0 = 0; i_head_0 < LI_N_HEAD; i_head_0 += LI_N_HEAD_INNER) { + const int tid_tg = (int) (tiisg + sgitg * N_SIMDWIDTH); + if (tid_tg < LI_N_HEAD_INNER) { + w_shared[tid_tg] = w_base[i_head_0 + tid_tg]; + } + + const int n_q = LI_N_HEAD_INNER * LI_N_EMBD_4; + const int n_tg = LI_N_SG_PER_TG * N_SIMDWIDTH; + + for (int i_q = tid_tg; i_q < n_q; i_q += n_tg) { + const int i_head_inner = i_q / LI_N_EMBD_4; + const int i_head = i_head_0 + i_head_inner; + const int i_embd = i_q % LI_N_EMBD_4; + q_shared[i_head_inner * LI_N_EMBD_4 + i_embd] = ((device const float4 *) ((device const char *) q_base + i_head*args.nb01))[i_embd]; // real float4 load via the head stride; indexing the float pointer broadcast one scalar + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int 
i_head_inner = 0; i_head_inner < LI_N_HEAD_INNER; ++i_head_inner) { + const float w_val = w_shared[i_head_inner]; + float qk[LI_K_VECS_PER_SG] = { 0.0f }; + + const float4 q_vec = q_shared[i_head_inner * LI_N_EMBD_4 + tiisg]; + + for (int k = 0; k < LI_K_VECS_PER_SG; ++k) { + qk[k] += q_vec.x * k_reg[k].x; + qk[k] += q_vec.y * k_reg[k].y; + qk[k] += q_vec.z * k_reg[k].z; + qk[k] += q_vec.w * k_reg[k].w; + } + + for (int k = 0; k < LI_K_VECS_PER_SG; ++k) { + float sum = simd_sum(qk[k]); + + if (tiisg == 0) { + sum *= args.scale_embd; + sum = (sum > 0.0f) ? sum : 0.0f; + score_k[k] += sum * w_val; + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + if (tiisg == 0) { + for (int k = 0; k < LI_K_VECS_PER_SG; ++k) { + dst_shared[sgitg * LI_K_VECS_PER_SG + k] = score_k[k] * args.scale_heads; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + const int tid_tg = (int) (tiisg + sgitg * N_SIMDWIDTH); + if (tid_tg < LI_K_VECS_PER_TG) { + int i_kv = start_kv_block + tid_tg; + if (i_kv < args.n_kv) { + device float * dst_base = (device float *) ((device char *) dst + i_batch*args.nb1 + i_stream*args.nb3); + dst_base[i_kv] = dst_shared[tid_tg]; + } + } +} + +// kernel entry point for F32 K type +kernel void kernel_lightning_indexer_f32( + constant ggml_metal_kargs_lightning_indexer & args, + device const float * src0, + device const char * src1, + device const float * src2, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup float w_shared[LI_N_HEAD_INNER]; + threadgroup float4 q_shared[LI_N_HEAD_INNER * LI_N_EMBD_4]; + threadgroup float dst_shared[LI_K_VECS_PER_TG]; + + const int i_batch = (int) tgpig.y; + const int i_stream = (int) tgpig.z; + const int start_kv_block = (int) tgpig.x * LI_K_VECS_PER_TG; + const int start_kv = start_kv_block + (int) sgitg * LI_K_VECS_PER_SG; + + device const float * q_base = (device 
const float *) ((device const char *) src0 + i_batch*args.nb02 + i_stream*args.nb03); + device const float * w_base = (device const float *) ((device const char *) src2 + i_batch*args.nb21 + i_stream*args.nb23); + + float4 k_reg[LI_K_VECS_PER_SG]; + for (int k = 0; k < LI_K_VECS_PER_SG; ++k) { + int i_kv = start_kv + k; + if (i_kv < args.n_kv) { + device const float4 * k_base = (device const float4 *) ((device const char *) src1 + i_kv*args.nb12 + i_stream*args.nb13); + k_reg[k] = k_base[tiisg]; + } else { + k_reg[k] = float4(0); + } + } + kernel_lightning_indexer_compute(args, q_base, w_base, dst, k_reg, + w_shared, q_shared, dst_shared, start_kv_block, i_batch, i_stream, tiisg, sgitg); +} + +// kernel entry point for F16 K type +kernel void kernel_lightning_indexer_f16( + constant ggml_metal_kargs_lightning_indexer & args, + device const float * src0, + device const char * src1, + device const float * src2, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup float w_shared[LI_N_HEAD_INNER]; + threadgroup float4 q_shared[LI_N_HEAD_INNER * LI_N_EMBD_4]; + threadgroup float dst_shared[LI_K_VECS_PER_TG]; + + const int i_batch = (int) tgpig.y; + const int i_stream = (int) tgpig.z; + const int start_kv_block = (int) tgpig.x * LI_K_VECS_PER_TG; + const int start_kv = start_kv_block + (int) sgitg * LI_K_VECS_PER_SG; + + device const float * q_base = (device const float *) ((device const char *) src0 + i_batch*args.nb02 + i_stream*args.nb03); + device const float * w_base = (device const float *) ((device const char *) src2 + i_batch*args.nb21 + i_stream*args.nb23); + + float4 k_reg[LI_K_VECS_PER_SG]; + for (int k = 0; k < LI_K_VECS_PER_SG; ++k) { + int i_kv = start_kv + k; + if (i_kv < args.n_kv) { + device const half4 * k_base = (device const half4 *) ((device const char *) src1 + i_kv*args.nb12 + i_stream*args.nb13); + k_reg[k] = 
float4(k_base[tiisg]); + } else { + k_reg[k] = float4(0); + } + } + kernel_lightning_indexer_compute(args, q_base, w_base, dst, k_reg, + w_shared, q_shared, dst_shared, start_kv_block, i_batch, i_stream, tiisg, sgitg); +} + +#if defined(GGML_METAL_HAS_BF16) +// kernel entry point for BF16 K type +kernel void kernel_lightning_indexer_bf16( + constant ggml_metal_kargs_lightning_indexer & args, + device const float * src0, + device const char * src1, + device const float * src2, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup float w_shared[LI_N_HEAD_INNER]; + threadgroup float4 q_shared[LI_N_HEAD_INNER * LI_N_EMBD_4]; + threadgroup float dst_shared[LI_K_VECS_PER_TG]; + + const int i_batch = (int) tgpig.y; + const int i_stream = (int) tgpig.z; + const int start_kv_block = (int) tgpig.x * LI_K_VECS_PER_TG; + const int start_kv = start_kv_block + (int) sgitg * LI_K_VECS_PER_SG; + + device const float * q_base = (device const float *) ((device const char *) src0 + i_batch*args.nb02 + i_stream*args.nb03); + device const float * w_base = (device const float *) ((device const char *) src2 + i_batch*args.nb21 + i_stream*args.nb23); + + float4 k_reg[LI_K_VECS_PER_SG]; + for (int k = 0; k < LI_K_VECS_PER_SG; ++k) { + int i_kv = start_kv + k; + if (i_kv < args.n_kv) { + device const bfloat4 * k_base = (device const bfloat4 *) ((device const char *) src1 + i_kv*args.nb12 + i_stream*args.nb13); + k_reg[k] = float4(k_base[tiisg]); + } else { + k_reg[k] = float4(0); + } + } + kernel_lightning_indexer_compute(args, q_base, w_base, dst, k_reg, + w_shared, q_shared, dst_shared, start_kv_block, i_batch, i_stream, tiisg, sgitg); +} +#endif + +// quantized type kernels: template with function pointer for dequantize + +template +kernel void kernel_lightning_indexer_quantized( + constant ggml_metal_kargs_lightning_indexer & args, + device const float * src0, 
+ device const char * src1, + device const float * src2, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup float w_shared[LI_N_HEAD_INNER]; + threadgroup float4 q_shared[LI_N_HEAD_INNER * LI_N_EMBD_4]; + threadgroup float dst_shared[LI_K_VECS_PER_TG]; + + const int i_batch = (int) tgpig.y; + const int i_stream = (int) tgpig.z; + const int start_kv_block = (int) tgpig.x * LI_K_VECS_PER_TG; + const int start_kv = start_kv_block + (int) sgitg * LI_K_VECS_PER_SG; + + device const float * q_base = (device const float *) ((device const char *) src0 + i_batch*args.nb02 + i_stream*args.nb03); + device const float * w_base = (device const float *) ((device const char *) src2 + i_batch*args.nb21 + i_stream*args.nb23); + + // dequantize K to float4 registers + // n_embd=128, block_size=32 -> 4 blocks per K vector, 8 positions per block + constexpr int positions_per_block = 32 / 4; // 8 + const int il = (int) tiisg % positions_per_block; + const int block_idx = (int) tiisg / positions_per_block; + + float4 k_reg[LI_K_VECS_PER_SG]; + for (int k = 0; k < LI_K_VECS_PER_SG; ++k) { + int i_kv = start_kv + k; + if (i_kv < args.n_kv) { + device const block_t * k_block = (device const block_t *) ((device const char *) src1 + i_kv*args.nb12 + i_stream*args.nb13); + deq_t4(k_block + block_idx, (short) il, k_reg[k]); + } else { + k_reg[k] = float4(0); + } + } + + kernel_lightning_indexer_compute(args, q_base, w_base, dst, k_reg, + w_shared, q_shared, dst_shared, start_kv_block, i_batch, i_stream, tiisg, sgitg); +} + +typedef decltype(kernel_lightning_indexer_quantized) kernel_lightning_indexer_quantized_t; + +template [[host_name("kernel_lightning_indexer_q4_0")]] kernel kernel_lightning_indexer_quantized_t kernel_lightning_indexer_quantized; +template [[host_name("kernel_lightning_indexer_q4_1")]] kernel kernel_lightning_indexer_quantized_t 
kernel_lightning_indexer_quantized; +template [[host_name("kernel_lightning_indexer_q5_0")]] kernel kernel_lightning_indexer_quantized_t kernel_lightning_indexer_quantized; +template [[host_name("kernel_lightning_indexer_q5_1")]] kernel kernel_lightning_indexer_quantized_t kernel_lightning_indexer_quantized; +template [[host_name("kernel_lightning_indexer_q8_0")]] kernel kernel_lightning_indexer_quantized_t kernel_lightning_indexer_quantized; From 8b2483326042130773552509dc6ae04823791998 Mon Sep 17 00:00:00 2001 From: Thiago Padilha Date: Fri, 3 Jul 2026 12:36:41 -0300 Subject: [PATCH 2/3] metal: add DeepSeek V4 hierarchical connection Metal kernels Implements GGML_OP_DSV4_HC_COMB, GGML_OP_DSV4_HC_PRE, and GGML_OP_DSV4_HC_POST on Metal GPU. HC_PRE performs weighted sum over hc slices, HC_COMB computes Sinkhorn-normalized combination matrices, and HC_POST blends input with residuals using the combination weights. All kernels operate on F32 tensors. Assisted-by: DeepSeek V4 Pro --- ggml/src/ggml-metal/ggml-metal-device.cpp | 33 ++++ ggml/src/ggml-metal/ggml-metal-device.h | 3 + ggml/src/ggml-metal/ggml-metal-device.m | 16 ++ ggml/src/ggml-metal/ggml-metal-impl.h | 47 ++++++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 186 ++++++++++++++++++++++ ggml/src/ggml-metal/ggml-metal-ops.h | 3 + ggml/src/ggml-metal/ggml-metal.metal | 171 ++++++++++++++++++++ 7 files changed, 459 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index f8fcb0331dfd..8751d0f26e3a 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -496,6 +496,39 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_lightning_indexe return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_dsv4_hc_comb(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->op == GGML_OP_DSV4_HC_COMB); + + ggml_metal_pipeline_with_params res = 
ggml_metal_library_get_pipeline(lib, "kernel_dsv4_hc_comb"); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, "kernel_dsv4_hc_comb", "kernel_dsv4_hc_comb", nullptr); + } + + return res; +} + +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_dsv4_hc_pre(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->op == GGML_OP_DSV4_HC_PRE); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, "kernel_dsv4_hc_pre"); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, "kernel_dsv4_hc_pre", "kernel_dsv4_hc_pre", nullptr); + } + + return res; +} + +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_dsv4_hc_post(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_ASSERT(op->op == GGML_OP_DSV4_HC_POST); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, "kernel_dsv4_hc_post"); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, "kernel_dsv4_hc_post", "kernel_dsv4_hc_post", nullptr); + } + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 4db4ce812efa..dfe6fb241239 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -125,6 +125,9 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_ad struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_lightning_indexer 
(ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_dsv4_hc_comb (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_dsv4_hc_pre (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_dsv4_hc_post (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 55b284d7c2e6..ddf6156e7c01 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1292,6 +1292,22 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return true; } + case GGML_OP_DSV4_HC_COMB: + if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32 || op->src[2]->type != GGML_TYPE_F32) { + return false; + } + return true; + case GGML_OP_DSV4_HC_PRE: + if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32) { + return false; + } + return true; + case GGML_OP_DSV4_HC_POST: + if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32 || + op->src[2]->type != GGML_TYPE_F32 || op->src[3]->type != GGML_TYPE_F32) { + return false; + } + return true; case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: return has_simdgroup_reduction; diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 
e94efbf9d519..11c5c4f3c7cc 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -102,6 +102,9 @@ #define FC_UPSCALE 1500 #define FC_GATED_DELTA_NET 1600 #define FC_LIGHTNING_INDEXER 1700 +#define FC_DSV4_HC_COMB 1800 +#define FC_DSV4_HC_PRE 1900 +#define FC_DSV4_HC_POST 2000 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPSG 8 @@ -1192,4 +1195,48 @@ typedef struct { float scale_heads; } ggml_metal_kargs_lightning_indexer; +typedef struct { + uint64_t nd0; + uint64_t nd1; + uint64_t nd2; + uint64_t nm0; + uint64_t nm1; + uint64_t ns0; + uint64_t nb0; + int32_t n_tokens; + int32_t n_iter; + float eps; + int32_t pad; +} ggml_metal_kargs_dsv4_hc_comb; + +typedef struct { + uint64_t nbx0; + uint64_t nbx1; + uint64_t nbx2; + uint64_t nbw0; + uint64_t nbw1; + uint64_t nbd0; + uint64_t nbd1; + int32_t n_embd; + int32_t n_tokens; +} ggml_metal_kargs_dsv4_hc_pre; + +typedef struct { + uint64_t nbx0; + uint64_t nbx1; + uint64_t nbr0; + uint64_t nbr1; + uint64_t nbr2; + uint64_t nbp0; + uint64_t nbp1; + uint64_t nbc0; + uint64_t nbc1; + uint64_t nbc2; + uint64_t nbd0; + uint64_t nbd1; + uint64_t nbd2; + int32_t n_embd; + int32_t n_tokens; +} ggml_metal_kargs_dsv4_hc_post; + #endif // GGML_METAL_IMPL diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index e0fea2490706..41413e5be5f1 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -320,6 +320,18 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_lightning_indexer(ctx, idx); } break; + case GGML_OP_DSV4_HC_COMB: + { + n_fuse = ggml_metal_op_dsv4_hc_comb(ctx, idx); + } break; + case GGML_OP_DSV4_HC_PRE: + { + n_fuse = ggml_metal_op_dsv4_hc_pre(ctx, idx); + } break; + case GGML_OP_DSV4_HC_POST: + { + n_fuse = ggml_metal_op_dsv4_hc_post(ctx, idx); + } break; case GGML_OP_SOFT_MAX: { n_fuse = ggml_metal_op_soft_max(ctx, idx); @@ -1293,6 +1305,180 
@@ int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_dsv4_hc_comb(ggml_metal_op_t ctx, int idx) { + ggml_tensor * dst = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_ASSERT(dst->op == GGML_OP_DSV4_HC_COMB); + + const ggml_tensor * mixes = dst->src[0]; + const ggml_tensor * scale = dst->src[1]; + const ggml_tensor * base = dst->src[2]; + + GGML_ASSERT(mixes->type == GGML_TYPE_F32); + GGML_ASSERT(scale->type == GGML_TYPE_F32); + GGML_ASSERT(base->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_LOCALS(size_t, nbm, mixes, nb); + GGML_TENSOR_LOCALS(size_t, nbs, scale, nb); + GGML_TENSOR_LOCALS(size_t, nbb, base, nb); + GGML_TENSOR_LOCALS(size_t, nbd, dst, nb); + + const int32_t n_tokens = (int32_t) mixes->ne[1]; + const float eps = ggml_get_op_params_f32(dst, 0); + const int32_t n_iter = ggml_get_op_params_i32(dst, 1); + + ggml_metal_kargs_dsv4_hc_comb args = { + /*.nd0 =*/ nbd0, + /*.nd1 =*/ nbd1, + /*.nd2 =*/ nbd2, + /*.nm0 =*/ nbm0, + /*.nm1 =*/ nbm1, + /*.ns0 =*/ nbs0, + /*.nb0 =*/ nbb0, + /*.n_tokens=*/ n_tokens, + /*.n_iter =*/ n_iter, + /*.eps =*/ eps, + /*.pad =*/ 0, + }; + + auto pipeline = ggml_metal_library_get_pipeline_dsv4_hc_comb(lib, dst); + + const int block_size = 256; + const int grid_size = (n_tokens + block_size - 1) / block_size; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(mixes), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(scale), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(base), 3); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(dst), 4); + + ggml_metal_encoder_dispatch_threadgroups(enc, grid_size, 1, 1, block_size, 1, 1); + + return 1; +} + +int ggml_metal_op_dsv4_hc_pre(ggml_metal_op_t ctx, int idx) { + ggml_tensor * dst = 
ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_ASSERT(dst->op == GGML_OP_DSV4_HC_PRE); + + const ggml_tensor * x = dst->src[0]; + const ggml_tensor * weights = dst->src[1]; + + GGML_ASSERT(x->type == GGML_TYPE_F32); + GGML_ASSERT(weights->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_LOCALS(size_t, nbx, x, nb); + GGML_TENSOR_LOCALS(size_t, nbw, weights, nb); + GGML_TENSOR_LOCALS(size_t, nbd, dst, nb); + + const int32_t n_embd = (int32_t) x->ne[0]; + const int32_t n_tokens = (int32_t) x->ne[2]; + + ggml_metal_kargs_dsv4_hc_pre args = { + /*.nbx0 =*/ nbx0, + /*.nbx1 =*/ nbx1, + /*.nbx2 =*/ nbx2, + /*.nbw0 =*/ nbw0, + /*.nbw1 =*/ nbw1, + /*.nbd0 =*/ nbd0, + /*.nbd1 =*/ nbd1, + /*.n_embd =*/ n_embd, + /*.n_tokens=*/ n_tokens, + }; + + auto pipeline = ggml_metal_library_get_pipeline_dsv4_hc_pre(lib, dst); + + const int block_size = 256; + const int nr = n_embd * n_tokens; + const int grid_size = (nr + block_size - 1) / block_size; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(x), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(weights), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(dst), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, grid_size, 1, 1, block_size, 1, 1); + + return 1; +} + +int ggml_metal_op_dsv4_hc_post(ggml_metal_op_t ctx, int idx) { + ggml_tensor * dst = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_ASSERT(dst->op == GGML_OP_DSV4_HC_POST); + + const ggml_tensor * x = dst->src[0]; + const ggml_tensor * residual = dst->src[1]; + const ggml_tensor * post = dst->src[2]; + const ggml_tensor * comb = dst->src[3]; + + GGML_ASSERT(x->type == GGML_TYPE_F32); + GGML_ASSERT(residual->type == GGML_TYPE_F32); + GGML_ASSERT(post->type == 
GGML_TYPE_F32); + GGML_ASSERT(comb->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_LOCALS(size_t, nbx, x, nb); + GGML_TENSOR_LOCALS(size_t, nbr, residual, nb); + GGML_TENSOR_LOCALS(size_t, nbp, post, nb); + GGML_TENSOR_LOCALS(size_t, nbc, comb, nb); + GGML_TENSOR_LOCALS(size_t, nbd, dst, nb); + + const int32_t n_embd = (int32_t) x->ne[0]; + const int32_t n_tokens = (int32_t) x->ne[1]; + + ggml_metal_kargs_dsv4_hc_post args = { + /*.nbx0 =*/ nbx0, + /*.nbx1 =*/ nbx1, + /*.nbr0 =*/ nbr0, + /*.nbr1 =*/ nbr1, + /*.nbr2 =*/ nbr2, + /*.nbp0 =*/ nbp0, + /*.nbp1 =*/ nbp1, + /*.nbc0 =*/ nbc0, + /*.nbc1 =*/ nbc1, + /*.nbc2 =*/ nbc2, + /*.nbd0 =*/ nbd0, + /*.nbd1 =*/ nbd1, + /*.nbd2 =*/ nbd2, + /*.n_embd =*/ n_embd, + /*.n_tokens=*/ n_tokens, + }; + + auto pipeline = ggml_metal_library_get_pipeline_dsv4_hc_post(lib, dst); + + constexpr int hc = 4; + const int block_size = 256; + const int nr = n_embd * hc * n_tokens; + const int grid_size = (nr + block_size - 1) / block_size; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(x), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(residual), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(post), 3); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(comb), 4); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(dst), 5); + + ggml_metal_encoder_dispatch_threadgroups(enc, grid_size, 1, 1, block_size, 1, 1); + + return 1; +} + int ggml_metal_op_lightning_indexer(ggml_metal_op_t ctx, int idx) { ggml_tensor * dst = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 77ab97eebd13..c69be0057a64 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -55,6 +55,9 @@ int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx); int 
ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_diag (ggml_metal_op_t ctx, int idx); int ggml_metal_op_lightning_indexer (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_dsv4_hc_comb (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_dsv4_hc_pre (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_dsv4_hc_post (ggml_metal_op_t ctx, int idx); int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 39bd35a9ed5a..579a051805fa 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -11014,3 +11014,174 @@ template [[host_name("kernel_lightning_indexer_q4_1")]] kernel kernel_lightning_ template [[host_name("kernel_lightning_indexer_q5_0")]] kernel kernel_lightning_indexer_quantized_t kernel_lightning_indexer_quantized; template [[host_name("kernel_lightning_indexer_q5_1")]] kernel kernel_lightning_indexer_quantized_t kernel_lightning_indexer_quantized; template [[host_name("kernel_lightning_indexer_q8_0")]] kernel kernel_lightning_indexer_quantized_t kernel_lightning_indexer_quantized; + +// DeepSeek V4 hierarchical connection kernels + +// HC_PRE: weighted sum over hc slices +// x[n_embd, hc, n_tokens] * weights[hc, n_tokens] -> dst[n_embd, n_tokens] +kernel void kernel_dsv4_hc_pre( + constant ggml_metal_kargs_dsv4_hc_pre & args, + device const float * x, + device const float * weights, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]]) { + + const int ir = (int) tgpig.x * 256 + (int) tiitg; + const int nr = args.n_embd * args.n_tokens; + + if (ir >= nr) { + return; + } + + const int i0 = ir % args.n_embd; + const int it = ir / args.n_embd; + + constexpr int hc = 4; + + device const float * x_row = (device const float *) ((device 
const char *) x + it*args.nbx2); + device const float * w_row = (device const float *) ((device const char *) weights + it*args.nbw1); + + float sum = x_row[i0 + 0*args.nbx1/sizeof(float)] * w_row[0]; + for (int ih = 1; ih < hc; ++ih) { + sum += x_row[i0 + ih*args.nbx1/sizeof(float)] * w_row[ih*args.nbw0/sizeof(float)]; + } + + device float * dst_row = (device float *) ((device char *) dst + it*args.nbd1); + dst_row[i0] = sum; +} + +// HC_POST: residual blend +// dst[i_embd, idst, it] = x[i_embd, it] * post[idst, it] +// + sum_{isrc} residual[i_embd, isrc, it] * comb[idst, isrc, it] +kernel void kernel_dsv4_hc_post( + constant ggml_metal_kargs_dsv4_hc_post & args, + device const float * x, + device const float * residual, + device const float * post, + device const float * comb, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]]) { + + const int ir = (int) tgpig.x * 256 + (int) tiitg; + + constexpr int hc = 4; + const int nr = args.n_embd * hc * args.n_tokens; + + if (ir >= nr) { + return; + } + + const int i0 = ir % args.n_embd; + const int idst = (ir / args.n_embd) % hc; + const int it = ir / (args.n_embd * hc); + + const float xv = *(device const float *) ((device const char *) x + i0*args.nbx0 + it*args.nbx1); + const float pv = *(device const float *) ((device const char *) post + idst*args.nbp0 + it*args.nbp1); + + float sum = xv * pv; + for (int isrc = 0; isrc < hc; ++isrc) { + const float rv = *(device const float *) ((device const char *) residual + i0*args.nbr0 + isrc*args.nbr1 + it*args.nbr2); + const float cv = *(device const float *) ((device const char *) comb + idst*args.nbc0 + isrc*args.nbc1 + it*args.nbc2); + sum += rv * cv; + } + + *(device float *) ((device char *) dst + i0*args.nbd0 + idst*args.nbd1 + it*args.nbd2) = sum; +} + +// HC_COMB: Sinkhorn normalization of combination matrix +// mixes[24, n_tokens], scale[3], base[24] -> comb[4, 4, n_tokens] +kernel void kernel_dsv4_hc_comb( 
+ constant ggml_metal_kargs_dsv4_hc_comb & args, + device const float * mixes, + device const float * scale, + device const float * base, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]]) { + + const int it = (int) tgpig.x * 256 + (int) tiitg; + + if (it >= args.n_tokens) { + return; + } + + constexpr int hc = 4; + constexpr int comb_offset = 2 * hc; // 8 + + device const float * mixes_row = (device const float *) ((device const char *) mixes + it*args.nm1); + device const float * base_data = (device const float *) base; + + const float scale_comb = scale[2*args.ns0/sizeof(float)]; + + float comb[hc * hc]; + + // row softmax with scale + base affine + for (int isrc = 0; isrc < hc; ++isrc) { + float max = -INFINITY; + for (int idst = 0; idst < hc; ++idst) { + const int idx = idst + hc*isrc; + const int mix_idx = comb_offset + idx; + const float v = mixes_row[mix_idx*args.nm0/sizeof(float)] * scale_comb + base_data[mix_idx*args.nb0/sizeof(float)]; + comb[idx] = v; + max = fmax(max, v); + } + + float sum = 0.0f; + for (int idst = 0; idst < hc; ++idst) { + const int idx = idst + hc*isrc; + const float v = exp(comb[idx] - max); + comb[idx] = v; + sum += v; + } + + const float inv_sum = 1.0f / sum; + for (int idst = 0; idst < hc; ++idst) { + const int idx = idst + hc*isrc; + comb[idx] = comb[idx] * inv_sum + args.eps; + } + } + + // Sinkhorn iterations: normalize columns, then rows alternately + auto norm_cols = [&]() { + for (int idst = 0; idst < hc; ++idst) { + float sum = args.eps; + for (int isrc = 0; isrc < hc; ++isrc) { + sum += comb[idst + hc*isrc]; + } + const float inv_sum = 1.0f / sum; + for (int isrc = 0; isrc < hc; ++isrc) { + comb[idst + hc*isrc] *= inv_sum; + } + } + }; + + auto norm_rows = [&]() { + for (int isrc = 0; isrc < hc; ++isrc) { + float sum = args.eps; + for (int idst = 0; idst < hc; ++idst) { + sum += comb[idst + hc*isrc]; + } + const float inv_sum = 1.0f / sum; + for (int idst = 0; 
idst < hc; ++idst) { + comb[idst + hc*isrc] *= inv_sum; + } + } + }; + + norm_cols(); + for (int i = 1; i < args.n_iter; ++i) { + norm_rows(); + norm_cols(); + } + + // store output + device float * dst_row = (device float *) ((device char *) dst + it*args.nd2); + for (int isrc = 0; isrc < hc; ++isrc) { + for (int idst = 0; idst < hc; ++idst) { + const int idx = idst + hc*isrc; + *(device float *) ((device char *) dst + idst*args.nd0 + isrc*args.nd1 + it*args.nd2) = comb[idx]; + } + } +} From 32c763323db40e35f585beda52625f291c71b6cc Mon Sep 17 00:00:00 2001 From: Thiago Padilha Date: Fri, 3 Jul 2026 14:46:56 -0300 Subject: [PATCH 3/3] metal: fix DeepSeek V4 lightning indexer q loads The Metal lightning indexer assigned a scalar float expression to the float4 threadgroup q tile and derived the address from the packed embedding width. That ignored the q head stride and loaded the wrong q vector data, corrupting indexer scores for DeepSeek V4 on Metal. Load q tiles as strided float4 values, matching the CUDA path, and use the provided source and destination strides in HC_PRE instead of assuming contiguous row layout. Add a LIGHTNING_INDEXER backend-op case so the DSV4-shaped F32 path is checked against CPU. 
Assisted-by: Codex --- ggml/src/ggml-metal/ggml-metal.metal | 12 ++++---- tests/test-backend-ops.cpp | 42 ++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 579a051805fa..f2752a5fd988 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -10798,7 +10798,8 @@ void kernel_lightning_indexer_compute( const int i_head_inner = i_q / LI_N_EMBD_4; const int i_head = i_head_0 + i_head_inner; const int i_embd = i_q % LI_N_EMBD_4; - q_shared[i_head_inner * LI_N_EMBD_4 + i_embd] = q_base[i_head*LI_N_EMBD_4 + i_embd]; + q_shared[i_head_inner * LI_N_EMBD_4 + i_embd] = + *(device const float4 *) ((device const char *) q_base + i_head*args.nb01 + i_embd*sizeof(float4)); } threadgroup_barrier(mem_flags::mem_threadgroup); @@ -11039,16 +11040,15 @@ kernel void kernel_dsv4_hc_pre( constexpr int hc = 4; - device const float * x_row = (device const float *) ((device const char *) x + it*args.nbx2); device const float * w_row = (device const float *) ((device const char *) weights + it*args.nbw1); - float sum = x_row[i0 + 0*args.nbx1/sizeof(float)] * w_row[0]; + float sum = *(device const float *) ((device const char *) x + i0*args.nbx0 + it*args.nbx2) * w_row[0]; for (int ih = 1; ih < hc; ++ih) { - sum += x_row[i0 + ih*args.nbx1/sizeof(float)] * w_row[ih*args.nbw0/sizeof(float)]; + sum += *(device const float *) ((device const char *) x + i0*args.nbx0 + ih*args.nbx1 + it*args.nbx2) * + w_row[ih*args.nbw0/sizeof(float)]; } - device float * dst_row = (device float *) ((device char *) dst + it*args.nbd1); - dst_row[i0] = sum; + *(device float *) ((device char *) dst + i0*args.nbd0 + it*args.nbd1) = sum; } // HC_POST: residual blend diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 42123c6fecf1..f8aa4287e38e 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4204,6 +4204,46 @@ 
struct test_dsv4_hc_post : public test_case { } }; +struct test_lightning_indexer : public test_case { + const int64_t n_batch; + const int64_t n_kv; + const int64_t n_stream; + + std::string vars() override { + return VARS_TO_STR3(n_batch, n_kv, n_stream); + } + + double err(const float * a, const float * b, size_t n) override { // largest absolute element-wise difference between a and b + double max_abs = 0.0; + for (size_t i = 0; i < n; ++i) { + max_abs = std::max(max_abs, (double) fabsf(a[i] - b[i])); // widen to double: std::max requires both arguments to have the same type + } + return max_abs; + } + + double max_err() override { + return 1e-4; + } + + test_lightning_indexer(int64_t n_batch = 3, int64_t n_kv = 65, int64_t n_stream = 2) + : n_batch(n_batch), n_kv(n_kv), n_stream(n_stream) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 64, n_batch, n_stream); + ggml_set_name(q, "q"); + + ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 128, 1, n_kv, n_stream); + ggml_set_name(k, "k"); + + ggml_tensor * weights = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, n_batch, 1, n_stream); + ggml_set_name(weights, "weights"); + + ggml_tensor * out = ggml_lightning_indexer(ctx, q, k, weights, 1.0f/sqrtf(128.0f), 1.0f/sqrtf(64.0f)); + ggml_set_name(out, "out"); + return out; + } +}; + // GGML_OP_SSM_CONV struct test_ssm_conv : public test_case { @@ -8220,6 +8260,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_dsv4_hc_post(31, 17)); test_cases.emplace_back(new test_dsv4_hc_post(128, 257)); + test_cases.emplace_back(new test_lightning_indexer()); + // glu ops for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) { for (int v : {0, 1}) {