Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,62 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_me
return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_lightning_indexer(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->op == GGML_OP_LIGHTNING_INDEXER);

const ggml_tensor * src0 = op->src[0];
const ggml_tensor * src1 = op->src[1];

GGML_ASSERT(src0->ne[0] == 128); // n_embd
GGML_ASSERT(src0->ne[1] == 64); // n_head

char base[256];
char name[256];

snprintf(base, 256, "kernel_lightning_indexer_%s", ggml_type_name(src1->type));
snprintf(name, 256, "%s", base);

ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}

return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_dsv4_hc_comb(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->op == GGML_OP_DSV4_HC_COMB);

ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, "kernel_dsv4_hc_comb");
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, "kernel_dsv4_hc_comb", "kernel_dsv4_hc_comb", nullptr);
}

return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_dsv4_hc_pre(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->op == GGML_OP_DSV4_HC_PRE);

ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, "kernel_dsv4_hc_pre");
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, "kernel_dsv4_hc_pre", "kernel_dsv4_hc_pre", nullptr);
}

return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_dsv4_hc_post(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->op == GGML_OP_DSV4_HC_POST);

ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, "kernel_dsv4_hc_post");
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, "kernel_dsv4_hc_post", "kernel_dsv4_hc_post", nullptr);
}

return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
Expand Down
4 changes: 4 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_bl
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_lightning_indexer (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_dsv4_hc_comb (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_dsv4_hc_pre (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_dsv4_hc_post (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
Expand Down
53 changes: 53 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -1255,6 +1255,59 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return false;
}
return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
case GGML_OP_LIGHTNING_INDEXER:
{
// DeepSeek V4 lightning indexer: n_embd=128, n_head=64
const int64_t n_embd = op->src[0]->ne[0];
const int64_t n_head = op->src[0]->ne[1];

if (n_embd != 128 || n_head != 64) {
return false;
}

if (op->src[0]->type != GGML_TYPE_F32) {
return false;
}
if (op->src[2]->type != GGML_TYPE_F32) {
return false;
}

switch (op->src[1]->type) {
case GGML_TYPE_F32:
case GGML_TYPE_F16:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
break;
case GGML_TYPE_BF16:
if (!has_bfloat) {
return false;
}
break;
default:
return false;
}

return true;
}
case GGML_OP_DSV4_HC_COMB:
if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32 || op->src[2]->type != GGML_TYPE_F32) {
return false;
}
return true;
case GGML_OP_DSV4_HC_PRE:
if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32) {
return false;
}
return true;
case GGML_OP_DSV4_HC_POST:
if (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type != GGML_TYPE_F32 ||
op->src[2]->type != GGML_TYPE_F32 || op->src[3]->type != GGML_TYPE_F32) {
return false;
}
return true;
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
return has_simdgroup_reduction;
Expand Down
67 changes: 67 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@
#define FC_SUM_ROWS 1400
#define FC_UPSCALE 1500
#define FC_GATED_DELTA_NET 1600
#define FC_LIGHTNING_INDEXER 1700
#define FC_DSV4_HC_COMB 1800
#define FC_DSV4_HC_PRE 1900
#define FC_DSV4_HC_POST 2000

// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPSG 8
Expand Down Expand Up @@ -1172,4 +1176,67 @@ typedef struct {
int64_t np;
} ggml_metal_kargs_opt_step_sgd;

typedef struct {
int32_t n_kv;
int32_t n_head;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
float scale_embd;
float scale_heads;
} ggml_metal_kargs_lightning_indexer;

typedef struct {
uint64_t nd0;
uint64_t nd1;
uint64_t nd2;
uint64_t nm0;
uint64_t nm1;
uint64_t ns0;
uint64_t nb0;
int32_t n_tokens;
int32_t n_iter;
float eps;
int32_t pad;
} ggml_metal_kargs_dsv4_hc_comb;

typedef struct {
uint64_t nbx0;
uint64_t nbx1;
uint64_t nbx2;
uint64_t nbw0;
uint64_t nbw1;
uint64_t nbd0;
uint64_t nbd1;
int32_t n_embd;
int32_t n_tokens;
} ggml_metal_kargs_dsv4_hc_pre;

typedef struct {
uint64_t nbx0;
uint64_t nbx1;
uint64_t nbr0;
uint64_t nbr1;
uint64_t nbr2;
uint64_t nbp0;
uint64_t nbp1;
uint64_t nbc0;
uint64_t nbc1;
uint64_t nbc2;
uint64_t nbd0;
uint64_t nbd1;
uint64_t nbd2;
int32_t n_embd;
int32_t n_tokens;
} ggml_metal_kargs_dsv4_hc_post;

#endif // GGML_METAL_IMPL
Loading