From e4b8fe1cc001cbc5cbc49cf9b3b885660ffd17f2 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Tue, 6 Jan 2026 19:26:35 +0800 Subject: [PATCH 01/14] issue/889 - added interface definitions --- include/infinicore/ops.hpp | 4 + include/infinicore/ops/embedding.hpp | 7 + include/infinicore/ops/flash_attention.hpp | 16 +++ include/infinicore/ops/kv_caching.hpp | 28 ++++ .../infinicore/ops/random_sample_batched.hpp | 20 +++ include/infiniop.h | 4 + include/infiniop/ops/embedding.h | 25 ++++ include/infiniop/ops/flash_attention.h | 34 +++++ include/infiniop/ops/kv_caching.h | 31 +++++ include/infiniop/ops/random_sample.h | 6 - include/infiniop/ops/random_sample_batched.h | 34 +++++ src/infinicore/ops/embedding/embedding.cc | 84 +++--------- .../ops/embedding/embedding_infiniop.cc | 49 +++++++ .../ops/flash_attention/flash_attention.cc | 29 ++++ .../flash_attention_infiniop.cc | 51 +++++++ src/infinicore/ops/kv_caching/kv_caching.cc | 47 +++++++ .../ops/kv_caching/kv_caching_infiniop.cc | 59 ++++++++ .../random_sample_batched.cc | 54 ++++++++ .../random_sample_batched_infiniop.cc | 63 +++++++++ src/infiniop/ops/embedding/operator.cc | 89 ++++++++++++ src/infiniop/ops/flash_attention/operator.cc | 121 +++++++++++++++++ src/infiniop/ops/kv_caching/operator.cc | 121 +++++++++++++++++ .../ops/random_sample_batched/operator.cc | 128 ++++++++++++++++++ 23 files changed, 1035 insertions(+), 69 deletions(-) create mode 100644 include/infinicore/ops/flash_attention.hpp create mode 100644 include/infinicore/ops/kv_caching.hpp create mode 100644 include/infinicore/ops/random_sample_batched.hpp create mode 100644 include/infiniop/ops/embedding.h create mode 100644 include/infiniop/ops/flash_attention.h create mode 100644 include/infiniop/ops/kv_caching.h create mode 100644 include/infiniop/ops/random_sample_batched.h create mode 100644 src/infinicore/ops/embedding/embedding_infiniop.cc create mode 100644 src/infinicore/ops/flash_attention/flash_attention.cc create mode 100644 
src/infinicore/ops/flash_attention/flash_attention_infiniop.cc create mode 100644 src/infinicore/ops/kv_caching/kv_caching.cc create mode 100644 src/infinicore/ops/kv_caching/kv_caching_infiniop.cc create mode 100644 src/infinicore/ops/random_sample_batched/random_sample_batched.cc create mode 100644 src/infinicore/ops/random_sample_batched/random_sample_batched_infiniop.cc create mode 100644 src/infiniop/ops/embedding/operator.cc create mode 100644 src/infiniop/ops/flash_attention/operator.cc create mode 100644 src/infiniop/ops/kv_caching/operator.cc create mode 100644 src/infiniop/ops/random_sample_batched/operator.cc diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index a7249ec9d..3fb47d383 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -4,12 +4,16 @@ #include "ops/add_rms_norm.hpp" #include "ops/attention.hpp" #include "ops/causal_softmax.hpp" +#include "ops/embedding.hpp" +#include "ops/flash_attention.hpp" +#include "ops/kv_caching.hpp" #include "ops/matmul.hpp" #include "ops/ones.hpp" #include "ops/paged_attention.hpp" #include "ops/paged_attention_prefill.hpp" #include "ops/paged_caching.hpp" #include "ops/random_sample.hpp" +#include "ops/random_sample_batched.hpp" #include "ops/rearrange.hpp" #include "ops/rms_norm.hpp" #include "ops/rope.hpp" diff --git a/include/infinicore/ops/embedding.hpp b/include/infinicore/ops/embedding.hpp index 4fd9991c4..6be997134 100644 --- a/include/infinicore/ops/embedding.hpp +++ b/include/infinicore/ops/embedding.hpp @@ -4,6 +4,13 @@ namespace infinicore::op { +class Embedding { +public: + using schema = void (*)(Tensor, Tensor, Tensor); + static void execute(Tensor out, Tensor input, Tensor weight); + static common::OpDispatcher &dispatcher(); +}; + Tensor embedding(Tensor input, Tensor weight); void embedding_(Tensor out, Tensor input, Tensor weight); } // namespace infinicore::op diff --git a/include/infinicore/ops/flash_attention.hpp 
b/include/infinicore/ops/flash_attention.hpp new file mode 100644 index 000000000..957255192 --- /dev/null +++ b/include/infinicore/ops/flash_attention.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class FlashAttention { +public: + using schema = void (*)(Tensor, Tensor, Tensor, Tensor, float, bool); + static void execute(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal); + static common::OpDispatcher &dispatcher(); +}; + +Tensor flash_attention(Tensor q, Tensor k, Tensor v, float scale, bool is_causal); +void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal); +} // namespace infinicore::op diff --git a/include/infinicore/ops/kv_caching.hpp b/include/infinicore/ops/kv_caching.hpp new file mode 100644 index 000000000..e4b6f514c --- /dev/null +++ b/include/infinicore/ops/kv_caching.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class KVCaching { +public: + using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor); + static void execute(Tensor k_cache, + Tensor v_cache, + Tensor k, + Tensor v, + Tensor past_kv_lengths); + static common::OpDispatcher &dispatcher(); +}; + +Tensor kv_caching(Tensor k_cache, + Tensor v_cache, + Tensor k, + Tensor v, + Tensor past_kv_lengths); +void kv_caching_(Tensor k_cache, + Tensor v_cache, + Tensor k, + Tensor v, + Tensor past_kv_lengths); +} // namespace infinicore::op diff --git a/include/infinicore/ops/random_sample_batched.hpp b/include/infinicore/ops/random_sample_batched.hpp new file mode 100644 index 000000000..8906bc12b --- /dev/null +++ b/include/infinicore/ops/random_sample_batched.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +class RandomSampleBatched { +public: + using schema = void (*)(Tensor, Tensor, const float *, const float *, const int *, 
const float *, int); + static void execute(Tensor result, Tensor probs, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size); + static common::OpDispatcher &dispatcher(); +}; + +// Out-of-place API +Tensor random_sample_batched(Tensor logits, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size); +// In-place API +void random_sample_batched_(Tensor indices, Tensor logits, const float *random_val, const float *topp, const int *topk, const float *temperature, int batch_size); + +} // namespace infinicore::op diff --git a/include/infiniop.h b/include/infiniop.h index c0a09fcb4..ca42e1509 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -9,8 +9,11 @@ #include "infiniop/ops/clip.h" #include "infiniop/ops/conv.h" #include "infiniop/ops/dequantize_awq.h" +#include "infiniop/ops/embedding.h" +#include "infiniop/ops/flash_attention.h" #include "infiniop/ops/gelu.h" #include "infiniop/ops/gemm.h" +#include "infiniop/ops/kv_caching.h" #include "infiniop/ops/layer_norm.h" #include "infiniop/ops/logsoftmax.h" #include "infiniop/ops/lp_norm.h" @@ -20,6 +23,7 @@ #include "infiniop/ops/paged_attention_prefill.h" #include "infiniop/ops/paged_caching.h" #include "infiniop/ops/random_sample.h" +#include "infiniop/ops/random_sample_batched.h" #include "infiniop/ops/rearrange.h" #include "infiniop/ops/relu.h" #include "infiniop/ops/rms_norm.h" diff --git a/include/infiniop/ops/embedding.h b/include/infiniop/ops/embedding.h new file mode 100644 index 000000000..cd1df3a73 --- /dev/null +++ b/include/infiniop/ops/embedding.h @@ -0,0 +1,25 @@ +#ifndef __INFINIOP_EMBEDDING_API_H__ +#define __INFINIOP_EMBEDDING_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopEmbeddingDescriptor_t; + +__C __export infiniStatus_t infiniopCreateEmbeddingDescriptor( + infiniopHandle_t handle, + infiniopEmbeddingDescriptor_t *desc_ptr, + 
infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t weight_desc); + +__C __export infiniStatus_t infiniopEmbedding( + infiniopEmbeddingDescriptor_t desc, + void *output, + const void *input, + const void *weight, + void *stream); + +__C __export infiniStatus_t infiniopDestroyEmbeddingDescriptor( + infiniopEmbeddingDescriptor_t desc); + +#endif diff --git a/include/infiniop/ops/flash_attention.h b/include/infiniop/ops/flash_attention.h new file mode 100644 index 000000000..06c3ff47c --- /dev/null +++ b/include/infiniop/ops/flash_attention.h @@ -0,0 +1,34 @@ +#ifndef __INFINIOP_FLASH_ATTENTION_API_H__ +#define __INFINIOP_FLASH_ATTENTION_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopFlashAttentionDescriptor_t; + +__C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor( + infiniopHandle_t handle, + infiniopFlashAttentionDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + float scale, + char is_causal); + +__C __export infiniStatus_t infiniopGetFlashAttentionWorkspaceSize( + infiniopFlashAttentionDescriptor_t desc, + size_t *size); + +__C __export infiniStatus_t infiniopFlashAttention( + infiniopFlashAttentionDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k, + const void *v, + void *stream); + +__C __export infiniStatus_t infiniopDestroyFlashAttentionDescriptor( + infiniopFlashAttentionDescriptor_t desc); +#endif diff --git a/include/infiniop/ops/kv_caching.h b/include/infiniop/ops/kv_caching.h new file mode 100644 index 000000000..e6efa48b3 --- /dev/null +++ b/include/infiniop/ops/kv_caching.h @@ -0,0 +1,31 @@ +#ifndef __INFINIOP_KV_CACHING_API_H__ +#define __INFINIOP_KV_CACHING_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor 
*infiniopKVCachingDescriptor_t; + +__C __export infiniStatus_t infiniopCreateKVCachingDescriptor( + infiniopHandle_t handle, + infiniopKVCachingDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t k_cache, + infiniopTensorDescriptor_t v_cache, + infiniopTensorDescriptor_t k, + infiniopTensorDescriptor_t v, + infiniopTensorDescriptor_t past_kv_lengths); + +__C __export infiniStatus_t infiniopGetKVCachingWorkspaceSize(infiniopKVCachingDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopKVCaching(infiniopKVCachingDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *k_cache, + void *v_cache, + const void *k, + const void *v, + const void *past_kv_lengths, + void *stream); + +__C __export infiniStatus_t infiniopDestroyKVCachingDescriptor(infiniopKVCachingDescriptor_t desc); + +#endif diff --git a/include/infiniop/ops/random_sample.h b/include/infiniop/ops/random_sample.h index 1c242d7ba..bb2b15959 100644 --- a/include/infiniop/ops/random_sample.h +++ b/include/infiniop/ops/random_sample.h @@ -15,12 +15,6 @@ __C __export infiniStatus_t infiniopGetRandomSampleWorkspaceSize( infiniopRandomSampleDescriptor_t desc, size_t *size); -__C __export infiniStatus_t infiniopCreateRandomSampleBatchDescriptor( - infiniopHandle_t handle, - infiniopRandomSampleDescriptor_t *desc_ptr, - infiniopTensorDescriptor_t result, - infiniopTensorDescriptor_t probs); - __C __export infiniStatus_t infiniopRandomSample( infiniopRandomSampleDescriptor_t desc, void *workspace, diff --git a/include/infiniop/ops/random_sample_batched.h b/include/infiniop/ops/random_sample_batched.h new file mode 100644 index 000000000..4512e7dcb --- /dev/null +++ b/include/infiniop/ops/random_sample_batched.h @@ -0,0 +1,34 @@ +#ifndef __INFINIOP_RANDOM_SAMPLE_BATCHED_API_H__ +#define __INFINIOP_RANDOM_SAMPLE_BATCHED_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopRandomSampleBatchedDescriptor_t; + +__C __export infiniStatus_t 
infiniopCreateRandomSampleBatchedDescriptor( + infiniopHandle_t handle, + infiniopRandomSampleBatchedDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t result, + infiniopTensorDescriptor_t probs); + +__C __export infiniStatus_t infiniopGetRandomSampleBatchedWorkspaceSize( + infiniopRandomSampleBatchedDescriptor_t desc, + size_t *size); + +__C __export infiniStatus_t infiniopRandomSampleBatched( + infiniopRandomSampleBatchedDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *result, + const void *probs, + const float *random_val, + const float *topp, + const int *topk, + const float *temperature, + int batch_size, + void *stream); + +__C __export infiniStatus_t infiniopDestroyRandomSampleBatchedDescriptor( + infiniopRandomSampleBatchedDescriptor_t desc); + +#endif diff --git a/src/infinicore/ops/embedding/embedding.cc b/src/infinicore/ops/embedding/embedding.cc index f1add0c97..96f19803c 100644 --- a/src/infinicore/ops/embedding/embedding.cc +++ b/src/infinicore/ops/embedding/embedding.cc @@ -1,15 +1,34 @@ #include "infinicore/ops/embedding.hpp" +#include "../../utils.hpp" #include "infinicore/context/context.hpp" #include +#include namespace infinicore::op { +common::OpDispatcher &Embedding::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +} + +void Embedding::execute(Tensor out, Tensor input, Tensor weight) { + // Check that all tensors are on the same device + // This is critical: if input is on CPU while out/weight are on GPU, + // passing CPU pointer to CUDA kernel will cause memory access errors + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, input, weight); + + // Set device context + infinicore::context::setDevice(out->device()); + + // Use dispatcher to lookup kernel (infiniop implementation) + dispatcher().lookup(out->device().getType())(out, input, weight); +} + Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the indices to extract Tensor weight // Weight: Embedding matrix of 
floating point type with shape (V, embedding_dim), where V = maximum index + 1 ) { auto input_shape = input->shape(); auto weight_shape = weight->shape(); - // auto vocab_size = weight_shape[0]; auto embedding_dim = weight_shape[1]; // Assign memory to out variables @@ -22,68 +41,7 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i } void embedding_(Tensor out, Tensor input, Tensor weight) { - assert(infinicore::DataType::I64 == input->dtype() || (infinicore::DataType::I32 == input->dtype())); - assert(infinicore::Device::Type::CPU == input->device().getType()); - - auto input_shape = input->shape(); - auto weight_shape = weight->shape(); - auto embedding_dim = weight_shape[1]; - - // Calculate the number of token - Size counts = 1; - for (auto &v : input_shape) { - counts *= v; - } - - // the bytes of one token - const Size bytes = dsize(weight->dtype()) * embedding_dim; - auto *weight_ptr = weight->data(); - auto *out_ptr = out->data(); - - // copies - if (weight->device().getType() == Device::Type::CPU) { - if (infinicore::DataType::I64 == input->dtype()) { - const int64_t *input_arr = reinterpret_cast(input->data()); - for (Size i = 0; i < counts; ++i) { - int64_t idx = input_arr[i]; - assert((idx >= 0) && (idx < weight_shape[0])); - std::memcpy(out_ptr + i * bytes, - weight_ptr + idx * bytes, - bytes); - } - } else if (infinicore::DataType::I32 == input->dtype()) { - const int32_t *input_arr = reinterpret_cast(input->data()); - - for (Size i = 0; i < counts; ++i) { - int32_t idx = input_arr[i]; - assert((idx >= 0) && (idx < weight_shape[0])); - std::memcpy(out_ptr + i * bytes, - weight_ptr + idx * bytes, - bytes); - } - } - - } else { - if (infinicore::DataType::I64 == input->dtype()) { - const int64_t *input_arr = reinterpret_cast(input->data()); - for (Size i = 0; i < counts; ++i) { - int64_t idx = input_arr[i]; - assert((idx >= 0) && (idx < weight_shape[0])); - context::memcpyD2D(out_ptr + i * bytes, - weight_ptr + idx * 
bytes, - bytes); - } - } else if (infinicore::DataType::I32 == input->dtype()) { - const int32_t *input_arr = reinterpret_cast(input->data()); - for (Size i = 0; i < counts; ++i) { - int32_t idx = input_arr[i]; - assert((idx >= 0) && (idx < weight_shape[0])); - context::memcpyD2D(out_ptr + i * bytes, - weight_ptr + idx * bytes, - bytes); - } - } - } + Embedding::execute(out, input, weight); } } // namespace infinicore::op diff --git a/src/infinicore/ops/embedding/embedding_infiniop.cc b/src/infinicore/ops/embedding/embedding_infiniop.cc new file mode 100644 index 000000000..dfbbb2f71 --- /dev/null +++ b/src/infinicore/ops/embedding/embedding_infiniop.cc @@ -0,0 +1,49 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/embedding.hpp" +#include + +namespace infinicore::op::embedding_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopEmbeddingDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyEmbeddingDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor out, Tensor input, Tensor weight) { + size_t seed = hash_combine(out, input, weight); + + auto device = context::getDevice(); + auto &cache = caches.getCache(device); + + auto desc_opt = cache.get(seed); + infiniopEmbeddingDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateEmbeddingDescriptor( + context::getInfiniopHandle(device), &desc, + out->desc(), input->desc(), weight->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + INFINICORE_CHECK_ERROR(infiniopEmbedding( + desc, + out->data(), + input->data(), + weight->data(), + context::getStream())); +} + +static bool registered = []() { + Embedding::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::embedding_impl::infiniop diff --git 
a/src/infinicore/ops/flash_attention/flash_attention.cc b/src/infinicore/ops/flash_attention/flash_attention.cc new file mode 100644 index 000000000..97db6de79 --- /dev/null +++ b/src/infinicore/ops/flash_attention/flash_attention.cc @@ -0,0 +1,29 @@ +#include "infinicore/ops/flash_attention.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &FlashAttention::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void FlashAttention::execute(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v); + infinicore::context::setDevice(out->device()); + dispatcher().lookup(out->device().getType())( + out, q, k, v, scale, is_causal); +} + +Tensor flash_attention(Tensor q, Tensor k, Tensor v, float scale, bool is_causal) { + Shape shape = q->shape(); + auto out = Tensor::empty(shape, q->dtype(), q->device()); + flash_attention_(out, q, k, v, scale, is_causal); + return out; +} + +void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal) { + FlashAttention::execute(out, q, k, v, scale, is_causal); +} +} // namespace infinicore::op diff --git a/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc b/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc new file mode 100644 index 000000000..e0a91e681 --- /dev/null +++ b/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc @@ -0,0 +1,51 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/flash_attention.hpp" +#include + +namespace infinicore::op::flash_attention_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopFlashAttentionDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyFlashAttentionDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor out, 
Tensor q, Tensor k, Tensor v, float scale, bool is_causal) { + size_t seed = hash_combine(out, q, k, v, scale, is_causal); + + auto device = context::getDevice(); + auto &cache = caches.getCache(device); + + auto desc_opt = cache.get(seed); + infiniopFlashAttentionDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateFlashAttentionDescriptor( + context::getInfiniopHandle(device), &desc, + out->desc(), q->desc(), k->desc(), v->desc(), + scale, static_cast(is_causal))); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetFlashAttentionWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopFlashAttention( + desc, workspace->data(), workspace_size, + out->data(), q->data(), k->data(), v->data(), context::getStream())); +} + +static bool registered = []() { + FlashAttention::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::flash_attention_impl::infiniop diff --git a/src/infinicore/ops/kv_caching/kv_caching.cc b/src/infinicore/ops/kv_caching/kv_caching.cc new file mode 100644 index 000000000..bed3a4566 --- /dev/null +++ b/src/infinicore/ops/kv_caching/kv_caching.cc @@ -0,0 +1,47 @@ +#include "infinicore/ops/kv_caching.hpp" + +#include "../../utils.hpp" + +#include + +namespace infinicore::op { + +common::OpDispatcher &KVCaching::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void KVCaching::execute(Tensor k_cache, + Tensor v_cache, + Tensor k, + Tensor v, + Tensor past_kv_lengths) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k_cache, v_cache, k, v, past_kv_lengths); + infinicore::context::setDevice(k_cache->device()); + auto device_type = k_cache->device().getType(); + auto func = dispatcher().lookup(device_type); + + if (func == nullptr) { + throw std::runtime_error("No KVCaching 
implementation found for device type: " + std::to_string(static_cast(device_type))); + } + + func(k_cache, v_cache, k, v, past_kv_lengths); +} + +Tensor kv_caching(Tensor k_cache, + Tensor v_cache, + Tensor k, + Tensor v, + Tensor past_kv_lengths) { + KVCaching::execute(k_cache, v_cache, k, v, past_kv_lengths); + return k_cache; // or v_cache, depending on the intended use +} + +void kv_caching_(Tensor k_cache, + Tensor v_cache, + Tensor k, + Tensor v, + Tensor past_kv_lengths) { + KVCaching::execute(k_cache, v_cache, k, v, past_kv_lengths); +} +} // namespace infinicore::op diff --git a/src/infinicore/ops/kv_caching/kv_caching_infiniop.cc b/src/infinicore/ops/kv_caching/kv_caching_infiniop.cc new file mode 100644 index 000000000..37d5e1fa3 --- /dev/null +++ b/src/infinicore/ops/kv_caching/kv_caching_infiniop.cc @@ -0,0 +1,59 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/kv_caching.hpp" +#include + +namespace infinicore::op::kv_caching_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopKVCachingDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyKVCachingDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor k_cache, + Tensor v_cache, + Tensor k, + Tensor v, + Tensor past_kv_lengths) { + size_t seed = hash_combine(k_cache, v_cache, k, v, past_kv_lengths); + + auto device = context::getDevice(); + auto &cache = caches.getCache(device); + + auto desc_opt = cache.get(seed); + infiniopKVCachingDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateKVCachingDescriptor( + context::getInfiniopHandle(device), &desc, + k_cache->desc(), v_cache->desc(), + k->desc(), v->desc(), + past_kv_lengths->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetKVCachingWorkspaceSize(desc, 
&workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopKVCaching( + desc, workspace->data(), workspace_size, + k_cache->data(), v_cache->data(), + k->data(), v->data(), + past_kv_lengths->data(), + context::getStream())); +} + +static bool registered = []() { + KVCaching::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::kv_caching_impl::infiniop diff --git a/src/infinicore/ops/random_sample_batched/random_sample_batched.cc b/src/infinicore/ops/random_sample_batched/random_sample_batched.cc new file mode 100644 index 000000000..a02635f66 --- /dev/null +++ b/src/infinicore/ops/random_sample_batched/random_sample_batched.cc @@ -0,0 +1,54 @@ +#include "infinicore/ops/random_sample_batched.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &RandomSampleBatched::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void RandomSampleBatched::execute( + Tensor result, + Tensor probs, + const float *random_val, + const float *topp, + const int *topk, + const float *temperature, + int batch_size) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(result, probs); + infinicore::context::setDevice(result->device()); + auto device_type = result->device().getType(); + auto func = dispatcher().lookup(device_type); + + if (func == nullptr) { + throw std::runtime_error("No RandomSampleBatched implementation found for device type: " + std::to_string(static_cast(device_type))); + } + + func(result, probs, random_val, topp, topk, temperature, batch_size); +} + +Tensor random_sample_batched( + Tensor logits, + const float *random_val, + const float *topp, + const int *topk, + const float *temperature, + int batch_size) { + Shape shape = logits->shape(); + auto result = Tensor::empty(shape, DataType::I32, logits->device()); + random_sample_batched_(result, logits, random_val, topp, topk, temperature, 
batch_size); + return result; +} +void random_sample_batched_( + Tensor result, + Tensor logits, + const float *random_val, + const float *topp, + const int *topk, + const float *temperature, + int batch_size) { + RandomSampleBatched::execute(result, logits, random_val, topp, topk, temperature, batch_size); +} +} // namespace infinicore::op diff --git a/src/infinicore/ops/random_sample_batched/random_sample_batched_infiniop.cc b/src/infinicore/ops/random_sample_batched/random_sample_batched_infiniop.cc new file mode 100644 index 000000000..2916c0b2a --- /dev/null +++ b/src/infinicore/ops/random_sample_batched/random_sample_batched_infiniop.cc @@ -0,0 +1,63 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/random_sample_batched.hpp" +#include + +namespace infinicore::op::random_sample_batched_impl::infiniop_backend { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopRandomSampleBatchedDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyRandomSampleBatchedDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate( + Tensor result, + Tensor probs, + const float *random_val, + const float *topp, + const int *topk, + const float *temperature, + int batch_size) { + size_t seed = hash_combine(result, probs, batch_size); + + auto device = context::getDevice(); + auto &cache = caches.getCache(device); + + auto desc_opt = cache.get(seed); + infiniopRandomSampleBatchedDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateRandomSampleBatchedDescriptor( + context::getInfiniopHandle(device), &desc, + result->desc(), probs->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetRandomSampleBatchedWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + 
INFINICORE_CHECK_ERROR(infiniopRandomSampleBatched( + desc, + workspace->data(), workspace_size, + result->data(), probs->data(), + random_val, topp, topk, temperature, + batch_size, + context::getStream())); +} + +} // namespace infinicore::op::random_sample_batched_impl::infiniop_backend + +namespace infinicore::op { +static bool registered = []() { + RandomSampleBatched::dispatcher().registerAll(&random_sample_batched_impl::infiniop_backend::calculate, false); + return true; +}(); +} // namespace infinicore::op diff --git a/src/infiniop/ops/embedding/operator.cc b/src/infiniop/ops/embedding/operator.cc new file mode 100644 index 000000000..0bf7864c9 --- /dev/null +++ b/src/infiniop/ops/embedding/operator.cc @@ -0,0 +1,89 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/embedding.h" + +#ifdef ENABLE_CPU_API +// #include "cpu/embedding_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) +// #include "nvidia/embedding_nvidia.cuh" +#endif + +__C infiniStatus_t infiniopCreateEmbeddingDescriptor( + infiniopHandle_t handle, + infiniopEmbeddingDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t weight_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::embedding::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + output_desc, \ + input_desc, \ + weight_desc) + + switch (handle->device) { + +#ifdef ENABLE_CPU_API + // CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + // CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopEmbedding( + infiniopEmbeddingDescriptor_t desc, + void *output, + const void *input, + 
const void *weight, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(output, input, weight, stream) + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + // CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + // CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t infiniopDestroyEmbeddingDescriptor(infiniopEmbeddingDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + // DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + // DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + } + +#undef DELETE + + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} diff --git a/src/infiniop/ops/flash_attention/operator.cc b/src/infiniop/ops/flash_attention/operator.cc new file mode 100644 index 000000000..f8699f15b --- /dev/null +++ b/src/infiniop/ops/flash_attention/operator.cc @@ -0,0 +1,121 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/flash_attention.h" + +#ifdef ENABLE_CPU_API +// #include "cpu/flash_attention_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) +// #include "nvidia/flash_attention_nvidia.cuh" +#endif + +__C infiniStatus_t infiniopCreateFlashAttentionDescriptor( + infiniopHandle_t handle, + infiniopFlashAttentionDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + float scale, + char 
is_causal) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::flash_attention::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + out_desc, \ + q_desc, \ + k_desc, \ + v_desc, \ + scale, \ + is_causal); + + switch (handle->device) { + +#ifdef ENABLE_CPU_API + // CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + // CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CREATE +} + +__C infiniStatus_t infiniopGetFlashAttentionWorkspaceSize( + infiniopFlashAttentionDescriptor_t desc, + size_t *size) { + +#define GET_SIZE(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc) \ + ->get_workspace_size(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + // GET_SIZE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + // GET_SIZE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef GET_SIZE +} + +__C infiniStatus_t infiniopFlashAttention( + infiniopFlashAttentionDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k, + const void *v, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, out, q, k, v, stream); + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + // CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + // CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t 
infiniopDestroyFlashAttentionDescriptor( + infiniopFlashAttentionDescriptor_t desc) { + +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + // DESTROY(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + // DESTROY(INFINI_DEVICE_NVIDIA, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef DESTROY +} diff --git a/src/infiniop/ops/kv_caching/operator.cc b/src/infiniop/ops/kv_caching/operator.cc new file mode 100644 index 000000000..65b27a414 --- /dev/null +++ b/src/infiniop/ops/kv_caching/operator.cc @@ -0,0 +1,121 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/kv_caching.h" + +#ifdef ENABLE_CPU_API +// #include "cpu/kv_caching_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) +// #include "nvidia/kv_caching_nvidia.cuh" +#endif + +__C infiniStatus_t infiniopCreateKVCachingDescriptor( + infiniopHandle_t handle, + infiniopKVCachingDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t k_cache, + infiniopTensorDescriptor_t v_cache, + infiniopTensorDescriptor_t k, + infiniopTensorDescriptor_t v, + infiniopTensorDescriptor_t past_kv_lengths) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::kv_caching::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + k_cache, \ + v_cache, \ + k, \ + v, \ + past_kv_lengths) + + switch (handle->device) { + +#ifdef ENABLE_CPU_API + // CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + // CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE 
+} + +__C infiniStatus_t infiniopGetKVCachingWorkspaceSize( + infiniopKVCachingDescriptor_t desc, + size_t *size) { + +#define GET_SIZE(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc) \ + ->get_workspace_size(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + // GET_SIZE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + // GET_SIZE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef GET_SIZE +} + +__C infiniStatus_t infiniopKVCaching( + infiniopKVCachingDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *k_cache, + void *v_cache, + const void *k, + const void *v, + const void *past_kv_lengths, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, k_cache, v_cache, k, v, past_kv_lengths, stream) + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + // CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + // CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t infiniopDestroyKVCachingDescriptor( + infiniopKVCachingDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + // DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + // DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef DELETE +} diff --git 
a/src/infiniop/ops/random_sample_batched/operator.cc b/src/infiniop/ops/random_sample_batched/operator.cc new file mode 100644 index 000000000..d0047ad53 --- /dev/null +++ b/src/infiniop/ops/random_sample_batched/operator.cc @@ -0,0 +1,128 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/random_sample_batched.h" + +#ifdef ENABLE_CPU_API +// #include "cpu/random_sample_batched_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) +// #include "nvidia/random_sample_batched_nvidia.cuh" +#endif + +__C infiniStatus_t infiniopCreateRandomSampleBatchedDescriptor( + infiniopHandle_t handle, + infiniopRandomSampleBatchedDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t result, + infiniopTensorDescriptor_t probs) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::random_sample::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + result, \ + probs) + + switch (handle->device) { + +#ifdef ENABLE_CPU_API + // CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + // CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopGetRandomSampleBatchedWorkspaceSize( + infiniopRandomSampleBatchedDescriptor_t desc, + size_t *size) { + +#define GET_SIZE(CASE, NAMESPACE) \ + case CASE: \ + using Ptr = const op::random_sample::NAMESPACE::Descriptor *; \ + *size = reinterpret_cast(desc)->minWorkspaceSize(); \ + } \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + // GET_SIZE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + // GET_SIZE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + 
default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef GET_SIZE +} + +__C infiniStatus_t infiniopRandomSampleBatched( + infiniopRandomSampleBatchedDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *result, + const void *probs, + const float *random_val, + const float *topp, + const int *topk, + const float *temperature, + int batch_size, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, \ + result, probs, \ + random_val, \ + topp, topk, temperature, \ + batch_size, \ + stream) + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + // CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + // CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t infiniopDestroyRandomSampleBatchedDescriptor( + infiniopRandomSampleBatchedDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + // DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) + // DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} From f935452d98bb7a007846123dbdd5c092fdb13ab9 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 25 Aug 2025 20:04:27 +0800 Subject: [PATCH 02/14] issue/402 - convenient ninetoothed util --- src/infiniop/ninetoothed/utils.h | 75 +++++++++++++++++++++ src/infiniop/ops/relu/metax/relu_metax.maca | 21 ++---- src/infiniop/ops/relu/nvidia/relu_nvidia.cu | 21 ++---- 3 files changed, 85 insertions(+), 32 
deletions(-) create mode 100644 src/infiniop/ninetoothed/utils.h diff --git a/src/infiniop/ninetoothed/utils.h b/src/infiniop/ninetoothed/utils.h new file mode 100644 index 000000000..1b7d1fe3a --- /dev/null +++ b/src/infiniop/ninetoothed/utils.h @@ -0,0 +1,75 @@ +#ifndef __NINETOOTHED_UTILS__ +#define __NINETOOTHED_UTILS__ + +#include +#include +#include +#include + +namespace ninetoothed { + +template +class Tensor { +public: + using Data = decltype(NineToothedTensor::data); + + using Size = std::remove_pointer_t; + + using Stride = std::remove_pointer_t; + + template + Tensor(const void *data, Shape shape, Strides strides) : data_{data}, shape_{shape}, strides_{strides}, ndim_{shape_.size()} {} + + Tensor(const void *data, std::initializer_list shape, std::initializer_list strides) : Tensor{data, decltype(shape_){shape}, decltype(strides_){strides}} {} + + Tensor(const void *data, const Size *shape, const Stride *strides, Size ndim) : data_{data}, shape_{shape, shape + ndim}, strides_{strides, strides + ndim}, ndim_{shape_.size()} {} + + Tensor(const T value) : value_{value}, data_{&value_}, ndim_{0} {} + + operator NineToothedTensor() { return {const_cast(data_), shape_.data(), strides_.data()}; } + + template + Tensor expand(const Shape &sizes) const { + auto new_ndim{sizes.size()}; + + decltype(shape_) shape(new_ndim, 1); + decltype(strides_) strides(new_ndim, 0); + + auto num_new_dims{new_ndim - ndim_}; + + for (auto dim{decltype(ndim_){0}}; dim < ndim_; ++dim) { + shape[dim + num_new_dims] = shape_[dim]; + strides[dim + num_new_dims] = strides_[dim]; + } + + for (auto dim{decltype(new_ndim){0}}; dim < new_ndim; ++dim) { + if (sizes[dim] == std::numeric_limits>::max() || shape[dim] != 1) { + continue; + } + + shape[dim] = sizes[dim]; + strides[dim] = 0; + } + + return {data_, shape, strides}; + } + + Tensor expand_as(const Tensor &other) const { + return expand(other.shape_); + } + +private: + const void *data_{nullptr}; + + std::vector shape_; + + 
std::vector strides_; + + Size ndim_{0}; + + T value_{0}; +}; + +} // namespace ninetoothed + +#endif diff --git a/src/infiniop/ops/relu/metax/relu_metax.maca b/src/infiniop/ops/relu/metax/relu_metax.maca index 900fce9e0..2c5104bdd 100644 --- a/src/infiniop/ops/relu/metax/relu_metax.maca +++ b/src/infiniop/ops/relu/metax/relu_metax.maca @@ -2,6 +2,7 @@ #include "../../../../../build/ninetoothed/relu.h" #include "../../../devices/metax/metax_common.h" +#include "../../../ninetoothed/utils.h" #include "relu_metax.h" namespace op::relu::metax { @@ -42,22 +43,10 @@ infiniStatus_t Descriptor::calculate( } const auto &ndim{_info.getNdim()}; - const auto &x_shape_{_info.getInputShape(0)}; - const auto &x_strides_{_info.getInputStrides(0)}; - std::vector x_shape_vec{x_shape_, x_shape_ + ndim}; - std::vector x_strides_vec{x_strides_, x_strides_ + ndim}; - auto x_data{const_cast(inputs[0])}; - auto x_shape{x_shape_vec.data()}; - auto x_strides{x_strides_vec.data()}; - const NineToothedTensor x{x_data, x_shape, x_strides}; - const auto &y_shape_{_info.getOutputShape()}; - const auto &y_strides_{_info.getOutputStrides()}; - std::vector y_shape_vec{y_shape_, y_shape_ + ndim}; - std::vector y_strides_vec{y_strides_, y_strides_ + ndim}; - auto y_data{output}; - auto y_shape{y_shape_vec.data()}; - auto y_strides{y_strides_vec.data()}; - const NineToothedTensor y{y_data, y_shape, y_strides}; + + auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}}; + auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}}; + constexpr auto block_size{1024}; switch (_dtype) { diff --git a/src/infiniop/ops/relu/nvidia/relu_nvidia.cu b/src/infiniop/ops/relu/nvidia/relu_nvidia.cu index 22b85a401..a3c79fb52 100644 --- a/src/infiniop/ops/relu/nvidia/relu_nvidia.cu +++ b/src/infiniop/ops/relu/nvidia/relu_nvidia.cu @@ -1,5 +1,6 @@ #ifdef ENABLE_NINETOOTHED #include "../../../../../build/ninetoothed/relu.h" +#include 
"../../../ninetoothed/utils.h" #endif #include "../../../devices/nvidia/nvidia_common.cuh" #include "../../../elementwise/nvidia/elementwise_nvidia.cuh" @@ -45,22 +46,10 @@ infiniStatus_t Descriptor::calculate( } #ifdef ENABLE_NINETOOTHED const auto &ndim{_info.getNdim()}; - const auto &x_shape_{_info.getInputShape(0)}; - const auto &x_strides_{_info.getInputStrides(0)}; - std::vector x_shape_vec{x_shape_, x_shape_ + ndim}; - std::vector x_strides_vec{x_strides_, x_strides_ + ndim}; - auto x_data{const_cast(inputs[0])}; - auto x_shape{x_shape_vec.data()}; - auto x_strides{x_strides_vec.data()}; - const NineToothedTensor x{x_data, x_shape, x_strides}; - const auto &y_shape_{_info.getOutputShape()}; - const auto &y_strides_{_info.getOutputStrides()}; - std::vector y_shape_vec{y_shape_, y_shape_ + ndim}; - std::vector y_strides_vec{y_strides_, y_strides_ + ndim}; - auto y_data{output}; - auto y_shape{y_shape_vec.data()}; - auto y_strides{y_strides_vec.data()}; - const NineToothedTensor y{y_data, y_shape, y_strides}; + + auto x{ninetoothed::Tensor{inputs[0], _info.getInputShape(0), _info.getInputStrides(0), ndim}}; + auto y{ninetoothed::Tensor{output, _info.getOutputShape(), _info.getOutputStrides(), ndim}}; + constexpr auto block_size{1024}; switch (_dtype) { From 3bd75ca22e179772b64d6e6bc0e6d1937471bf98 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 13 Jan 2026 12:39:00 +0000 Subject: [PATCH 03/14] issue/919 - Add a NineToothed implementation of `scaled_dot_product_attention` --- python/infinicore/nn/functional/__init__.py | 2 + .../scaled_dot_product_attention.py | 28 ++++ src/infinicore/pybind11/ops.hpp | 2 + .../pybind11/ops/flash_attention.hpp | 21 +++ .../ops/flash_attention/ninetoothed/build.py | 35 +++++ .../flash_attention/ninetoothed/descriptor.h | 133 ++++++++++++++++++ src/infiniop/ops/flash_attention/operator.cc | 20 +++ .../ops/scaled_dot_product_attention.py | 18 ++- 8 files changed, 249 insertions(+), 10 deletions(-) create mode 100644 
python/infinicore/nn/functional/scaled_dot_product_attention.py create mode 100644 src/infinicore/pybind11/ops/flash_attention.hpp create mode 100644 src/infiniop/ops/flash_attention/ninetoothed/build.py create mode 100644 src/infiniop/ops/flash_attention/ninetoothed/descriptor.h diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py index 255079790..f8c7d6ef0 100644 --- a/python/infinicore/nn/functional/__init__.py +++ b/python/infinicore/nn/functional/__init__.py @@ -4,6 +4,7 @@ from .random_sample import random_sample from .rms_norm import rms_norm from .rope import RopeAlgo, rope +from .scaled_dot_product_attention import scaled_dot_product_attention from .silu import silu from .swiglu import swiglu @@ -11,6 +12,7 @@ "causal_softmax", "random_sample", "rms_norm", + "scaled_dot_product_attention", "silu", "swiglu", "linear", diff --git a/python/infinicore/nn/functional/scaled_dot_product_attention.py b/python/infinicore/nn/functional/scaled_dot_product_attention.py new file mode 100644 index 000000000..d89f484fe --- /dev/null +++ b/python/infinicore/nn/functional/scaled_dot_product_attention.py @@ -0,0 +1,28 @@ +import math + +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def scaled_dot_product_attention( + query, + key, + value, + attn_mask=None, + dropout_p=0, + is_causal=False, + scale=None, + enable_gqa=False, +): + assert attn_mask is None and dropout_p == 0 and not enable_gqa + + emb_dim = query.shape[-1] + + if scale is None: + scale = 1 / math.sqrt(emb_dim) + + return Tensor( + _infinicore.flash_attention( + query._underlying, key._underlying, value._underlying, scale, is_causal + ) + ) diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 3d6ebe79a..e2d5aa00b 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -7,6 +7,7 @@ #include "ops/attention.hpp" #include "ops/causal_softmax.hpp" #include 
"ops/embedding.hpp" +#include "ops/flash_attention.hpp" #include "ops/linear.hpp" #include "ops/matmul.hpp" #include "ops/mul.hpp" @@ -29,6 +30,7 @@ inline void bind(py::module &m) { bind_add_rms_norm(m); bind_attention(m); bind_causal_softmax(m); + bind_flash_attention(m); bind_random_sample(m); bind_linear(m); bind_matmul(m); diff --git a/src/infinicore/pybind11/ops/flash_attention.hpp b/src/infinicore/pybind11/ops/flash_attention.hpp new file mode 100644 index 000000000..09ec91980 --- /dev/null +++ b/src/infinicore/pybind11/ops/flash_attention.hpp @@ -0,0 +1,21 @@ +#pragma once + +#include + +#include "infinicore/ops/flash_attention.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_flash_attention(py::module &m) { + m.def("flash_attention", + &op::flash_attention, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("scale"), + py::arg("is_causal")); +} + +} // namespace infinicore::ops diff --git a/src/infiniop/ops/flash_attention/ninetoothed/build.py b/src/infiniop/ops/flash_attention/ninetoothed/build.py new file mode 100644 index 000000000..dfcce6910 --- /dev/null +++ b/src/infiniop/ops/flash_attention/ninetoothed/build.py @@ -0,0 +1,35 @@ +import ninetoothed +from ntops.kernels import scaled_dot_product_attention +from ntops.kernels.scaled_dot_product_attention import CausalVariant + +import infiniop.ninetoothed.build + + +def build(): + with_kv_cache_values = (0,) + emb_dim_values = (16, 32, 64, 128, 256) + is_causal_values = (0, 1) + with_attn_mask_values = (0,) + causal_variant_values = (CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT) + dtype_values = (ninetoothed.float16, ninetoothed.float32) + block_size_m_values = (64,) + block_size_n_values = (64,) + + constexpr_param_grid = { + "with_kv_cache": with_kv_cache_values, + "emb_dim": emb_dim_values, + "is_causal": is_causal_values, + "with_attn_mask": with_attn_mask_values, + "causal_variant": causal_variant_values, + "dtype": dtype_values, + "block_size_m": 
block_size_m_values, + "block_size_n": block_size_n_values, + } + + infiniop.ninetoothed.build.build( + scaled_dot_product_attention.premake, + constexpr_param_grid, + caller="cuda", + op_name="flash_attention", + output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH, + ) diff --git a/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h b/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h new file mode 100644 index 000000000..697891d3d --- /dev/null +++ b/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h @@ -0,0 +1,133 @@ +#ifndef __FLASH_ATTENTION_DESCRIPTOR_H__ +#define __FLASH_ATTENTION_DESCRIPTOR_H__ + +#include "../../../handle.h" +#include "../../../operator.h" +#include "../../../tensor.h" + +#include "../../../../../build/ninetoothed/flash_attention.h" +#include "../../../ninetoothed/utils.h" + +namespace op::flash_attention::ninetoothed { + +class Descriptor final : public InfiniopDescriptor { +public: + Descriptor(infiniopHandle_t handle, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + double scale, + char is_causal) : InfiniopDescriptor{handle->device, handle->device_id}, + _query_shape{q_desc->shape()}, + _query_strides{q_desc->strides()}, + _key_shape{k_desc->shape()}, + _key_strides{k_desc->strides()}, + _value_shape{v_desc->shape()}, + _value_strides{v_desc->strides()}, + _output_strides{out_desc->strides()}, + _dtype{q_desc->dtype()}, + _scale{scale}, + _is_causal{is_causal} {} + + ~Descriptor() = default; + + size_t get_workspace_size() const { + return 0; + } + + infiniStatus_t calculate(void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k, + const void *v, + void *stream) const { + uint64_t empty_shape[4]; + int64_t empty_strides[4]; + + auto query{::ninetoothed::Tensor{q, _query_shape, _query_strides}}; + auto key{::ninetoothed::Tensor{k, _key_shape, _key_strides}}; + auto 
value{::ninetoothed::Tensor{v, _value_shape, _value_strides}}; + + NineToothedTensor attn_mask{nullptr, empty_shape, empty_strides}; + NineToothedTensor is_causal; + NineToothedTensor scale{const_cast(&_scale), nullptr, nullptr}; + auto output{::ninetoothed::Tensor{out, _query_shape, _output_strides}}; + NineToothedTensor with_attn_mask; + NineToothedTensor causal_variant; + + const auto with_kv_cache_{0}; + const auto emb_dim_{_query_shape[3]}; + const auto is_causal_{_is_causal}; + const auto with_attn_mask_{0}; + const auto causal_variant_{1}; + const auto dtype_{_dtype}; + + constexpr auto block_size_m_{64}; + constexpr auto block_size_n_{64}; + + launch_flash_attention(stream, + query, + key, + value, + attn_mask, + is_causal, + scale, + output, + with_attn_mask, + causal_variant, + with_kv_cache_, + emb_dim_, + is_causal_, + with_attn_mask_, + causal_variant_, + dtype_, + block_size_m_, + block_size_n_); + + return INFINI_STATUS_SUCCESS; + } + + static infiniStatus_t create(infiniopHandle_t handle, + Descriptor **desc, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + double scale, + char is_causal) { + *desc = new Descriptor{handle, out_desc, q_desc, k_desc, v_desc, scale, is_causal}; + + return INFINI_STATUS_SUCCESS; + } + +private: + using Size = ::ninetoothed::Tensor<>::Size; + + using Stride = ::ninetoothed::Tensor<>::Stride; + + std::vector _query_shape; + + std::vector _query_strides; + + std::vector _key_shape; + + std::vector _key_strides; + + std::vector _value_shape; + + std::vector _value_strides; + + std::vector _output_strides; + + infiniDtype_t _dtype; + + double _scale; + + char _is_causal; +}; + +} // namespace op::flash_attention::ninetoothed + +#endif // __FLASH_ATTENTION_DESCRIPTOR_H__ diff --git a/src/infiniop/ops/flash_attention/operator.cc b/src/infiniop/ops/flash_attention/operator.cc index f8699f15b..e907d3c41 100644 --- 
a/src/infiniop/ops/flash_attention/operator.cc +++ b/src/infiniop/ops/flash_attention/operator.cc @@ -6,8 +6,12 @@ // #include "cpu/flash_attention_cpu.h" #endif #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) +#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API) +#include "ninetoothed/descriptor.h" +#else // #include "nvidia/flash_attention_nvidia.cuh" #endif +#endif __C infiniStatus_t infiniopCreateFlashAttentionDescriptor( infiniopHandle_t handle, @@ -37,7 +41,11 @@ __C infiniStatus_t infiniopCreateFlashAttentionDescriptor( // CREATE(INFINI_DEVICE_CPU, cpu); #endif #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) +#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API) + CREATE(INFINI_DEVICE_NVIDIA, ninetoothed); +#else // CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -60,7 +68,11 @@ __C infiniStatus_t infiniopGetFlashAttentionWorkspaceSize( // GET_SIZE(INFINI_DEVICE_CPU, cpu); #endif #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) +#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API) + GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed); +#else // GET_SIZE(INFINI_DEVICE_NVIDIA, nvidia); +#endif #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -90,7 +102,11 @@ __C infiniStatus_t infiniopFlashAttention( // CALCULATE(INFINI_DEVICE_CPU, cpu); #endif #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) +#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API) + CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed); +#else // CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -112,7 +128,11 @@ __C infiniStatus_t 
infiniopDestroyFlashAttentionDescriptor( // DESTROY(INFINI_DEVICE_CPU, cpu); #endif #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) +#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API) + DESTROY(INFINI_DEVICE_NVIDIA, ninetoothed); +#else // DESTROY(INFINI_DEVICE_NVIDIA, nvidia); +#endif #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/test/infinicore/ops/scaled_dot_product_attention.py b/test/infinicore/ops/scaled_dot_product_attention.py index 218420d72..644fb6f99 100644 --- a/test/infinicore/ops/scaled_dot_product_attention.py +++ b/test/infinicore/ops/scaled_dot_product_attention.py @@ -11,17 +11,16 @@ # q/k/v typically have shape (..., seq_len, head_dim) or (batch, seq_len, num_heads, head_dim) _TEST_CASES_DATA = [ - ((2, 8, 16), (2, 8, 16), (2, 8, 16), None, 0.0, False), - ((1, 4, 32), (1, 4, 32), (1, 4, 32), None, 0.0, False), - ((2, 6, 12), (2, 6, 12), (2, 6, 12), None, 0.0, True), - ((3, 8, 8), (3, 8, 8), (3, 8, 8), None, 0.0, False), - ((2, 4, 16), (2, 4, 16), (2, 4, 16), None, 0.0, True), - ((1, 2, 64), (1, 2, 64), (1, 2, 64), None, 0.0, False), + ((1, 1, 2, 16), (1, 1, 2, 16), (1, 1, 2, 16), None, 0.0, False), + ((1, 2, 8, 16), (1, 2, 8, 16), (1, 2, 8, 16), None, 0.0, False), + ((1, 1, 4, 32), (1, 1, 4, 32), (1, 1, 4, 32), None, 0.0, False), + ((1, 2, 4, 16), (1, 2, 4, 16), (1, 2, 4, 16), None, 0.0, True), + ((1, 1, 2, 64), (1, 1, 2, 64), (1, 1, 2, 64), None, 0.0, False), ] _TOLERANCE_MAP = { infinicore.float16: {"atol": 1e-2, "rtol": 1e-2}, - infinicore.float32: {"atol": 1e-4, "rtol": 1e-4}, + infinicore.float32: {"atol": 1e-3, "rtol": 1e-3}, } _TENSOR_DTYPES = [infinicore.float16, infinicore.float32] @@ -68,9 +67,8 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.nn.functional.scaled_dot_product_attention(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation 
(operator not yet available).""" - # return infinicore.nn.functional.scaled_dot_product_attention(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + return infinicore.nn.functional.scaled_dot_product_attention(*args, **kwargs) def main(): From 588c04c6aeb429f4d28c3722768a672b5ff90848 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang <45955067+voltjia@users.noreply.github.com> Date: Wed, 14 Jan 2026 18:54:30 +0800 Subject: [PATCH 04/14] issue/925: Speed up `scripts/build_ntops.py` and `src/infiniop/ninetoothed/build.py` with `concurrent.futures` (#926) --- scripts/build_ntops.py | 28 +++++++---- src/infiniop/ninetoothed/build.py | 77 +++++++++++++++++++------------ 2 files changed, 68 insertions(+), 37 deletions(-) diff --git a/scripts/build_ntops.py b/scripts/build_ntops.py index 1499b6bf8..e1397e56d 100644 --- a/scripts/build_ntops.py +++ b/scripts/build_ntops.py @@ -1,3 +1,4 @@ +import concurrent.futures import importlib import pathlib @@ -11,16 +12,27 @@ def _find_and_build_ops(): ops_path = SRC_DIR_PATH / "infiniop" / "ops" - for op_dir in ops_path.iterdir(): - ninetoothed_path = op_dir / "ninetoothed" + with concurrent.futures.ProcessPoolExecutor() as executor: + futures = [] - if ninetoothed_path.is_dir(): - module_path = ninetoothed_path / "build" - relative_path = module_path.relative_to(SRC_DIR_PATH) - import_name = ".".join(relative_path.parts) - module = importlib.import_module(import_name) + for op_dir in ops_path.iterdir(): + ninetoothed_path = op_dir / "ninetoothed" - module.build() + if not ninetoothed_path.is_dir(): + continue + + futures.append(executor.submit(_build, ninetoothed_path)) + + concurrent.futures.as_completed(futures) + + +def _build(ninetoothed_path): + module_path = ninetoothed_path / "build" + relative_path = module_path.relative_to(SRC_DIR_PATH) + import_name = ".".join(relative_path.parts) + module = importlib.import_module(import_name) + + module.build() if __name__ == "__main__": diff --git 
a/src/infiniop/ninetoothed/build.py b/src/infiniop/ninetoothed/build.py index aea421b7f..153e6b9f5 100644 --- a/src/infiniop/ninetoothed/build.py +++ b/src/infiniop/ninetoothed/build.py @@ -1,3 +1,4 @@ +import concurrent.futures import functools import inspect import itertools @@ -16,40 +17,28 @@ def build(premake, constexpr_param_grid, caller, op_name, output_dir): headers = [] all_param_names = [] + combinations = [] launches = [] - for combination in _generate_param_value_combinations(constexpr_param_grid): - arrangement, application, tensors = premake(**combination) + with concurrent.futures.ProcessPoolExecutor() as executor: + futures = [] - for param_name, param_value in combination.items(): - if isinstance(param_value, str): - combination[param_name] = ( - f"INFINI_DTYPE_{combination[param_name].replace('fp', 'F').upper()}" - ) + for combination in tuple( + _generate_param_value_combinations(constexpr_param_grid) + ): + future = executor.submit( + _make, premake, combination, caller, op_name, output_dir + ) - combination = {f"{name}_": value for name, value in combination.items()} + futures.append(future) - kernel_name = f"{op_name}_{_generate_suffix(combination.values())}" + for future in concurrent.futures.as_completed(futures): + header, param_names, combination, launch = future.result() - ninetoothed.make( - arrangement, - application, - tensors, - caller=caller, - kernel_name=kernel_name, - output_dir=output_dir, - ) - - header = output_dir / f"{kernel_name}.h" - param_names = ("stream",) + tuple( - inspect.signature(application).parameters.keys() - ) - launch = f""" if ({_generate_condition(combination)}) - return launch_{kernel_name}({", ".join(param_names)});""" - - headers.append(header) - all_param_names.append(param_names) - launches.append(launch) + headers.append(header) + all_param_names.append(param_names) + combinations.append(combination) + launches.append(launch) includes = "\n".join(f'#include "{header}"' for header in headers) @@ -64,7 
+53,7 @@ def build(premake, constexpr_param_grid, caller, op_name, output_dir): "NineToothedStream", ] + ["NineToothedTensor" for _ in range(len(param_names) - 1)] - for param_name in combination: + for param_name in functools.reduce(lambda x, y: x | y, combinations, {}): param_names.append(param_name) param_types.append("int") @@ -97,6 +86,36 @@ def build(premake, constexpr_param_grid, caller, op_name, output_dir): (BUILD_DIRECTORY_PATH / header_file_name).write_text(header_content) +def _make(premake, combination, caller, op_name, output_dir): + arrangement, application, tensors = premake(**combination) + + for param_name, param_value in combination.items(): + if isinstance(param_value, str): + combination[param_name] = ( + f"INFINI_DTYPE_{combination[param_name].replace('fp', 'F').upper()}" + ) + + combination = {f"{name}_": value for name, value in combination.items()} + + kernel_name = f"{op_name}_{_generate_suffix(combination.values())}" + + ninetoothed.make( + arrangement, + application, + tensors, + caller=caller, + kernel_name=kernel_name, + output_dir=output_dir, + ) + + header = output_dir / f"{kernel_name}.h" + param_names = ("stream",) + tuple(inspect.signature(application).parameters.keys()) + launch = f""" if ({_generate_condition(combination)}) + return launch_{kernel_name}({", ".join(param_names)});""" + + return header, param_names, combination, launch + + def _generate_condition(combination): return " && ".join(f"{param} == {value}" for param, value in combination.items()) From bb1329ae56d3149b87acb4056d9d980cf460d6d2 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Fri, 16 Jan 2026 14:09:19 +0800 Subject: [PATCH 05/14] issue/889 - revert embedding modifications --- include/infinicore/ops.hpp | 1 - include/infinicore/ops/embedding.hpp | 7 -- include/infiniop.h | 1 - include/infiniop/ops/embedding.h | 25 ------ src/infinicore/ops/embedding/embedding.cc | 84 ++++++++++++----- .../ops/embedding/embedding_infiniop.cc | 49 ---------- 
src/infiniop/ops/embedding/operator.cc | 89 ------------------- 7 files changed, 63 insertions(+), 193 deletions(-) delete mode 100644 include/infiniop/ops/embedding.h delete mode 100644 src/infinicore/ops/embedding/embedding_infiniop.cc delete mode 100644 src/infiniop/ops/embedding/operator.cc diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index 3fb47d383..a156a8176 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -4,7 +4,6 @@ #include "ops/add_rms_norm.hpp" #include "ops/attention.hpp" #include "ops/causal_softmax.hpp" -#include "ops/embedding.hpp" #include "ops/flash_attention.hpp" #include "ops/kv_caching.hpp" #include "ops/matmul.hpp" diff --git a/include/infinicore/ops/embedding.hpp b/include/infinicore/ops/embedding.hpp index 6be997134..4fd9991c4 100644 --- a/include/infinicore/ops/embedding.hpp +++ b/include/infinicore/ops/embedding.hpp @@ -4,13 +4,6 @@ namespace infinicore::op { -class Embedding { -public: - using schema = void (*)(Tensor, Tensor, Tensor); - static void execute(Tensor out, Tensor input, Tensor weight); - static common::OpDispatcher &dispatcher(); -}; - Tensor embedding(Tensor input, Tensor weight); void embedding_(Tensor out, Tensor input, Tensor weight); } // namespace infinicore::op diff --git a/include/infiniop.h b/include/infiniop.h index ca42e1509..246180e65 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -9,7 +9,6 @@ #include "infiniop/ops/clip.h" #include "infiniop/ops/conv.h" #include "infiniop/ops/dequantize_awq.h" -#include "infiniop/ops/embedding.h" #include "infiniop/ops/flash_attention.h" #include "infiniop/ops/gelu.h" #include "infiniop/ops/gemm.h" diff --git a/include/infiniop/ops/embedding.h b/include/infiniop/ops/embedding.h deleted file mode 100644 index cd1df3a73..000000000 --- a/include/infiniop/ops/embedding.h +++ /dev/null @@ -1,25 +0,0 @@ -#ifndef __INFINIOP_EMBEDDING_API_H__ -#define __INFINIOP_EMBEDDING_API_H__ - -#include "../operator_descriptor.h" - 
-typedef struct InfiniopDescriptor *infiniopEmbeddingDescriptor_t; - -__C __export infiniStatus_t infiniopCreateEmbeddingDescriptor( - infiniopHandle_t handle, - infiniopEmbeddingDescriptor_t *desc_ptr, - infiniopTensorDescriptor_t output_desc, - infiniopTensorDescriptor_t input_desc, - infiniopTensorDescriptor_t weight_desc); - -__C __export infiniStatus_t infiniopEmbedding( - infiniopEmbeddingDescriptor_t desc, - void *output, - const void *input, - const void *weight, - void *stream); - -__C __export infiniStatus_t infiniopDestroyEmbeddingDescriptor( - infiniopEmbeddingDescriptor_t desc); - -#endif diff --git a/src/infinicore/ops/embedding/embedding.cc b/src/infinicore/ops/embedding/embedding.cc index 96f19803c..f1add0c97 100644 --- a/src/infinicore/ops/embedding/embedding.cc +++ b/src/infinicore/ops/embedding/embedding.cc @@ -1,34 +1,15 @@ #include "infinicore/ops/embedding.hpp" -#include "../../utils.hpp" #include "infinicore/context/context.hpp" #include -#include namespace infinicore::op { -common::OpDispatcher &Embedding::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -} - -void Embedding::execute(Tensor out, Tensor input, Tensor weight) { - // Check that all tensors are on the same device - // This is critical: if input is on CPU while out/weight are on GPU, - // passing CPU pointer to CUDA kernel will cause memory access errors - INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, input, weight); - - // Set device context - infinicore::context::setDevice(out->device()); - - // Use dispatcher to lookup kernel (infiniop implementation) - dispatcher().lookup(out->device().getType())(out, input, weight); -} - Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the indices to extract Tensor weight // Weight: Embedding matrix of floating point type with shape (V, embedding_dim), where V = maximum index + 1 ) { auto input_shape = input->shape(); auto weight_shape = weight->shape(); + // auto vocab_size = 
weight_shape[0]; auto embedding_dim = weight_shape[1]; // Assign memory to out variables @@ -41,7 +22,68 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i } void embedding_(Tensor out, Tensor input, Tensor weight) { - Embedding::execute(out, input, weight); + assert(infinicore::DataType::I64 == input->dtype() || (infinicore::DataType::I32 == input->dtype())); + assert(infinicore::Device::Type::CPU == input->device().getType()); + + auto input_shape = input->shape(); + auto weight_shape = weight->shape(); + auto embedding_dim = weight_shape[1]; + + // Calculate the number of token + Size counts = 1; + for (auto &v : input_shape) { + counts *= v; + } + + // the bytes of one token + const Size bytes = dsize(weight->dtype()) * embedding_dim; + auto *weight_ptr = weight->data(); + auto *out_ptr = out->data(); + + // copies + if (weight->device().getType() == Device::Type::CPU) { + if (infinicore::DataType::I64 == input->dtype()) { + const int64_t *input_arr = reinterpret_cast(input->data()); + for (Size i = 0; i < counts; ++i) { + int64_t idx = input_arr[i]; + assert((idx >= 0) && (idx < weight_shape[0])); + std::memcpy(out_ptr + i * bytes, + weight_ptr + idx * bytes, + bytes); + } + } else if (infinicore::DataType::I32 == input->dtype()) { + const int32_t *input_arr = reinterpret_cast(input->data()); + + for (Size i = 0; i < counts; ++i) { + int32_t idx = input_arr[i]; + assert((idx >= 0) && (idx < weight_shape[0])); + std::memcpy(out_ptr + i * bytes, + weight_ptr + idx * bytes, + bytes); + } + } + + } else { + if (infinicore::DataType::I64 == input->dtype()) { + const int64_t *input_arr = reinterpret_cast(input->data()); + for (Size i = 0; i < counts; ++i) { + int64_t idx = input_arr[i]; + assert((idx >= 0) && (idx < weight_shape[0])); + context::memcpyD2D(out_ptr + i * bytes, + weight_ptr + idx * bytes, + bytes); + } + } else if (infinicore::DataType::I32 == input->dtype()) { + const int32_t *input_arr = 
reinterpret_cast(input->data()); + for (Size i = 0; i < counts; ++i) { + int32_t idx = input_arr[i]; + assert((idx >= 0) && (idx < weight_shape[0])); + context::memcpyD2D(out_ptr + i * bytes, + weight_ptr + idx * bytes, + bytes); + } + } + } } } // namespace infinicore::op diff --git a/src/infinicore/ops/embedding/embedding_infiniop.cc b/src/infinicore/ops/embedding/embedding_infiniop.cc deleted file mode 100644 index dfbbb2f71..000000000 --- a/src/infinicore/ops/embedding/embedding_infiniop.cc +++ /dev/null @@ -1,49 +0,0 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" -#include "infinicore/ops/embedding.hpp" -#include - -namespace infinicore::op::embedding_impl::infiniop { - -thread_local common::OpCache caches( - 100, // capacity - [](infiniopEmbeddingDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyEmbeddingDescriptor(desc)); - desc = nullptr; - } - }); - -void calculate(Tensor out, Tensor input, Tensor weight) { - size_t seed = hash_combine(out, input, weight); - - auto device = context::getDevice(); - auto &cache = caches.getCache(device); - - auto desc_opt = cache.get(seed); - infiniopEmbeddingDescriptor_t desc = nullptr; - - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateEmbeddingDescriptor( - context::getInfiniopHandle(device), &desc, - out->desc(), input->desc(), weight->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } - - INFINICORE_CHECK_ERROR(infiniopEmbedding( - desc, - out->data(), - input->data(), - weight->data(), - context::getStream())); -} - -static bool registered = []() { - Embedding::dispatcher().registerAll(&calculate, false); - return true; -}(); - -} // namespace infinicore::op::embedding_impl::infiniop diff --git a/src/infiniop/ops/embedding/operator.cc b/src/infiniop/ops/embedding/operator.cc deleted file mode 100644 index 0bf7864c9..000000000 --- a/src/infiniop/ops/embedding/operator.cc +++ /dev/null @@ 
-1,89 +0,0 @@ -#include "../../operator.h" -#include "../../handle.h" -#include "infiniop/ops/embedding.h" - -#ifdef ENABLE_CPU_API -// #include "cpu/embedding_cpu.h" -#endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) -// #include "nvidia/embedding_nvidia.cuh" -#endif - -__C infiniStatus_t infiniopCreateEmbeddingDescriptor( - infiniopHandle_t handle, - infiniopEmbeddingDescriptor_t *desc_ptr, - infiniopTensorDescriptor_t output_desc, - infiniopTensorDescriptor_t input_desc, - infiniopTensorDescriptor_t weight_desc) { - -#define CREATE(CASE, NAMESPACE) \ - case CASE: \ - return op::embedding::NAMESPACE::Descriptor::create( \ - handle, \ - reinterpret_cast(desc_ptr), \ - output_desc, \ - input_desc, \ - weight_desc) - - switch (handle->device) { - -#ifdef ENABLE_CPU_API - // CREATE(INFINI_DEVICE_CPU, cpu); -#endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) - // CREATE(INFINI_DEVICE_NVIDIA, nvidia); -#endif - default: - return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; - } - -#undef CREATE -} - -__C infiniStatus_t infiniopEmbedding( - infiniopEmbeddingDescriptor_t desc, - void *output, - const void *input, - const void *weight, - void *stream) { - -#define CALCULATE(CASE, NAMESPACE) \ - case CASE: \ - return reinterpret_cast(desc) \ - ->calculate(output, input, weight, stream) - - switch (desc->device_type) { - -#ifdef ENABLE_CPU_API - // CALCULATE(INFINI_DEVICE_CPU, cpu); -#endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) - // CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); -#endif - default: - return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; - } - -#undef CALCULATE -} - -__C infiniStatus_t infiniopDestroyEmbeddingDescriptor(infiniopEmbeddingDescriptor_t desc) { - -#define DELETE(CASE, NAMESPACE) \ - case CASE: \ - delete reinterpret_cast(desc); \ - 
return INFINI_STATUS_SUCCESS; - - switch (desc->device_type) { -#ifdef ENABLE_CPU_API - // DELETE(INFINI_DEVICE_CPU, cpu); -#endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) - // DELETE(INFINI_DEVICE_NVIDIA, nvidia); -#endif - } - -#undef DELETE - - return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; -} From c1fb26b5bb798b0278ae59549f9ebff718a59762 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Fri, 16 Jan 2026 15:53:57 +0800 Subject: [PATCH 06/14] issue/940 - check build result and implicitly require build.py --- scripts/build_ntops.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/scripts/build_ntops.py b/scripts/build_ntops.py index e1397e56d..601249615 100644 --- a/scripts/build_ntops.py +++ b/scripts/build_ntops.py @@ -21,9 +21,14 @@ def _find_and_build_ops(): if not ninetoothed_path.is_dir(): continue + build_file = ninetoothed_path / "build.py" + if not build_file.exists(): + continue + futures.append(executor.submit(_build, ninetoothed_path)) - concurrent.futures.as_completed(futures) + for future in concurrent.futures.as_completed(futures): + future.result() def _build(ninetoothed_path): From e7dcc3cf786965ea7a0203c7ce32e4b95a74597a Mon Sep 17 00:00:00 2001 From: wooway777 Date: Thu, 15 Jan 2026 15:38:53 +0800 Subject: [PATCH 07/14] issue/931 - ninetoothed swiglu --- src/infiniop/ops/swiglu/ninetoothed/build.py | 29 +++++++ src/infiniop/ops/swiglu/ninetoothed/swiglu.h | 82 +++++++++++++++++++ src/infiniop/ops/swiglu/ninetoothed/swiglu.py | 22 +++++ src/infiniop/ops/swiglu/operator.cc | 56 +++++++++++++ 4 files changed, 189 insertions(+) create mode 100644 src/infiniop/ops/swiglu/ninetoothed/build.py create mode 100644 src/infiniop/ops/swiglu/ninetoothed/swiglu.h create mode 100644 src/infiniop/ops/swiglu/ninetoothed/swiglu.py diff --git a/src/infiniop/ops/swiglu/ninetoothed/build.py b/src/infiniop/ops/swiglu/ninetoothed/build.py new file mode 100644 index 
000000000..fa4af6db2 --- /dev/null +++ b/src/infiniop/ops/swiglu/ninetoothed/build.py @@ -0,0 +1,29 @@ +import ninetoothed +from . import swiglu + +import infiniop.ninetoothed.build + + +def build(): + MAX_NDIM = 5 + + ndim_values = range(1, MAX_NDIM + 1) + dtype_values = ( + ninetoothed.float16, + ninetoothed.bfloat16, + ninetoothed.float32, + ) + + constexpr_param_grid = { + "ndim": ndim_values, + "dtype": dtype_values, + "block_size": (1024,), + } + + infiniop.ninetoothed.build.build( + swiglu.premake, + constexpr_param_grid, + caller="cuda", + op_name="swiglu", + output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH, + ) diff --git a/src/infiniop/ops/swiglu/ninetoothed/swiglu.h b/src/infiniop/ops/swiglu/ninetoothed/swiglu.h new file mode 100644 index 000000000..4aa2fa70e --- /dev/null +++ b/src/infiniop/ops/swiglu/ninetoothed/swiglu.h @@ -0,0 +1,82 @@ +#ifndef SWIGLU_H +#define SWIGLU_H + +#include "../../../handle.h" +#include "../../../operator.h" +#include "../../../tensor.h" + +#include "../../../../../build/ninetoothed/swiglu.h" +#include "../../../ninetoothed/utils.h" + +namespace op::swiglu::ninetoothed { +class Descriptor final : public InfiniopDescriptor { + +public: + Descriptor( + infiniopHandle_t handle, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) : InfiniopDescriptor{handle->device, handle->device_id}, + out_shape_{out_desc->shape()}, + out_strides_{out_desc->strides()}, + up_shape_{input_desc_vec[0]->shape()}, + up_strides_{input_desc_vec[0]->strides()}, + gate_shape_{input_desc_vec[1]->shape()}, + gate_strides_{input_desc_vec[1]->strides()}, + dtype_{out_desc->dtype()} {} + + ~Descriptor() = default; + + size_t workspaceSize() const { + return 0; + } + + static infiniStatus_t create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + *desc_ptr = new Descriptor(handle, out_desc, input_desc_vec); + return INFINI_STATUS_SUCCESS; + } + + 
infiniStatus_t calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + auto out_nt{::ninetoothed::Tensor(output, out_shape_, out_strides_)}; + auto up_nt{::ninetoothed::Tensor(inputs[0], up_shape_, up_strides_)}; + auto gate_nt{::ninetoothed::Tensor(inputs[1], gate_shape_, gate_strides_)}; + + if (launch_swiglu(stream, + out_nt, + up_nt, + gate_nt, + out_shape_.size(), + dtype_, + 1024)) { + return INFINI_STATUS_NOT_IMPLEMENTED; + } + + return INFINI_STATUS_SUCCESS; + } + +private: + using Size = ::ninetoothed::Tensor<>::Size; + using Stride = ::ninetoothed::Tensor<>::Stride; + + std::vector out_shape_; + std::vector out_strides_; + + std::vector up_shape_; + std::vector up_strides_; + + std::vector gate_shape_; + std::vector gate_strides_; + + infiniDtype_t dtype_; +}; +} // namespace op::swiglu::ninetoothed + +#endif // SWIGLU_H diff --git a/src/infiniop/ops/swiglu/ninetoothed/swiglu.py b/src/infiniop/ops/swiglu/ninetoothed/swiglu.py new file mode 100644 index 000000000..62074a84b --- /dev/null +++ b/src/infiniop/ops/swiglu/ninetoothed/swiglu.py @@ -0,0 +1,22 @@ +import functools + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.element_wise import arrangement + + +def application(output, up, gate): + output = ntl.sigmoid(ntl.cast(gate, ntl.float32)) * gate * up # noqa: F841 + + +def premake(ndim, dtype=None, block_size=None): + arrangement_ = functools.partial(arrangement, block_size=block_size) + + tensors = ( + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + ) + + return arrangement_, application, tensors diff --git a/src/infiniop/ops/swiglu/operator.cc b/src/infiniop/ops/swiglu/operator.cc index 9d8e6406a..b3fabba32 100644 --- a/src/infiniop/ops/swiglu/operator.cc +++ b/src/infiniop/ops/swiglu/operator.cc @@ -6,14 +6,22 @@ #include "cpu/swiglu_cpu.h" #endif #if defined(ENABLE_NVIDIA_API) || 
defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) +#if defined(ENABLE_NINETOOTHED) +#include "ninetoothed/swiglu.h" +#else #include "nvidia/swiglu_nvidia.cuh" #endif +#endif #ifdef ENABLE_KUNLUN_API #include "kunlun/swiglu_kunlun.h" #endif #ifdef ENABLE_METAX_API +#if defined(ENABLE_NINETOOTHED) +#include "ninetoothed/swiglu.h" +#else #include "metax/swiglu_metax.h" #endif +#endif #ifdef ENABLE_CAMBRICON_API #include "bang/swiglu_bang.h" #endif @@ -46,11 +54,19 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor( CREATE(INFINI_DEVICE_CPU, cpu); #endif #ifdef ENABLE_NVIDIA_API +#ifdef ENABLE_NINETOOTHED + CREATE(INFINI_DEVICE_NVIDIA, ninetoothed); +#else CREATE(INFINI_DEVICE_NVIDIA, nvidia); #endif +#endif #ifdef ENABLE_ILUVATAR_API +#ifdef ENABLE_NINETOOTHED + CREATE(INFINI_DEVICE_ILUVATAR, ninetoothed); +#else CREATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#endif #ifdef ENABLE_QY_API CREATE(INFINI_DEVICE_QY, nvidia); #endif @@ -61,8 +77,12 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor( CREATE(INFINI_DEVICE_KUNLUN, kunlun); #endif #ifdef ENABLE_METAX_API +#ifdef ENABLE_NINETOOTHED + CREATE(INFINI_DEVICE_METAX, ninetoothed); +#else CREATE(INFINI_DEVICE_METAX, metax); #endif +#endif #ifdef ENABLE_CAMBRICON_API CREATE(INFINI_DEVICE_CAMBRICON, bang); #endif @@ -92,11 +112,19 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des GET(INFINI_DEVICE_CPU, cpu); #endif #ifdef ENABLE_NVIDIA_API +#ifdef ENABLE_NINETOOTHED + GET(INFINI_DEVICE_NVIDIA, ninetoothed); +#else GET(INFINI_DEVICE_NVIDIA, nvidia); #endif +#endif #ifdef ENABLE_ILUVATAR_API +#ifdef ENABLE_NINETOOTHED + GET(INFINI_DEVICE_ILUVATAR, ninetoothed); +#else GET(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#endif #ifdef ENABLE_QY_API GET(INFINI_DEVICE_QY, nvidia); #endif @@ -107,8 +135,12 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des GET(INFINI_DEVICE_KUNLUN, kunlun); #endif #ifdef ENABLE_METAX_API +#ifdef 
ENABLE_NINETOOTHED + GET(INFINI_DEVICE_METAX, ninetoothed); +#else GET(INFINI_DEVICE_METAX, metax); #endif +#endif #ifdef ENABLE_CAMBRICON_API GET(INFINI_DEVICE_CAMBRICON, bang); #endif @@ -145,11 +177,19 @@ __C infiniStatus_t infiniopSwiGLU( CALCULATE(INFINI_DEVICE_CPU, cpu); #endif #ifdef ENABLE_NVIDIA_API +#ifdef ENABLE_NINETOOTHED + CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed); +#else CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); #endif +#endif #ifdef ENABLE_ILUVATAR_API +#ifdef ENABLE_NINETOOTHED + CALCULATE(INFINI_DEVICE_ILUVATAR, ninetoothed); +#else CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#endif #ifdef ENABLE_QY_API CALCULATE(INFINI_DEVICE_QY, nvidia); #endif @@ -160,8 +200,12 @@ __C infiniStatus_t infiniopSwiGLU( CALCULATE(INFINI_DEVICE_KUNLUN, kunlun); #endif #ifdef ENABLE_METAX_API +#ifdef ENABLE_NINETOOTHED + CALCULATE(INFINI_DEVICE_METAX, ninetoothed); +#else CALCULATE(INFINI_DEVICE_METAX, metax); #endif +#endif #ifdef ENABLE_CAMBRICON_API CALCULATE(INFINI_DEVICE_CAMBRICON, bang); #endif @@ -193,11 +237,19 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) { DELETE(INFINI_DEVICE_CPU, cpu); #endif #ifdef ENABLE_NVIDIA_API +#ifdef ENABLE_NINETOOTHED + DELETE(INFINI_DEVICE_NVIDIA, ninetoothed); +#else DELETE(INFINI_DEVICE_NVIDIA, nvidia); #endif +#endif #ifdef ENABLE_ILUVATAR_API +#ifdef ENABLE_NINETOOTHED + DELETE(INFINI_DEVICE_ILUVATAR, ninetoothed); +#else DELETE(INFINI_DEVICE_ILUVATAR, nvidia); #endif +#endif #ifdef ENABLE_QY_API DELETE(INFINI_DEVICE_QY, nvidia); #endif @@ -208,8 +260,12 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) { DELETE(INFINI_DEVICE_KUNLUN, kunlun); #endif #ifdef ENABLE_METAX_API +#ifdef ENABLE_NINETOOTHED + DELETE(INFINI_DEVICE_METAX, ninetoothed); +#else DELETE(INFINI_DEVICE_METAX, metax); #endif +#endif #ifdef ENABLE_CAMBRICON_API DELETE(INFINI_DEVICE_CAMBRICON, bang); #endif From fbb4f3c0f7bf187c73c237d4d8fedc83fba1a5c8 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Wed, 14 Jan 
2026 13:01:26 +0800 Subject: [PATCH 08/14] issue/923 - ninetoothed kv_caching --- include/infinicore/ops/kv_caching.hpp | 26 +--- python/infinicore/__init__.py | 2 + python/infinicore/ops/kv_caching.py | 13 ++ src/infinicore/ops/kv_caching/kv_caching.cc | 57 ++++---- .../ops/kv_caching/kv_caching_infiniop.cc | 87 ++++++------ src/infinicore/pybind11/ops.hpp | 8 +- src/infinicore/pybind11/ops/kv_caching.hpp | 32 +++++ .../flash_attention/ninetoothed/descriptor.h | 38 ++--- .../ops/kv_caching/ninetoothed/build.py | 27 ++++ .../ops/kv_caching/ninetoothed/kv_caching.h | 101 +++++++++++++ .../ops/kv_caching/ninetoothed/kv_caching.py | 66 +++++++++ src/infiniop/ops/kv_caching/operator.cc | 62 +++++--- test/infinicore/framework/base.py | 13 +- test/infinicore/ops/kv_caching.py | 134 ++++++++++++++++++ 14 files changed, 529 insertions(+), 137 deletions(-) create mode 100644 python/infinicore/ops/kv_caching.py create mode 100644 src/infinicore/pybind11/ops/kv_caching.hpp create mode 100644 src/infiniop/ops/kv_caching/ninetoothed/build.py create mode 100644 src/infiniop/ops/kv_caching/ninetoothed/kv_caching.h create mode 100644 src/infiniop/ops/kv_caching/ninetoothed/kv_caching.py create mode 100644 test/infinicore/ops/kv_caching.py diff --git a/include/infinicore/ops/kv_caching.hpp b/include/infinicore/ops/kv_caching.hpp index e4b6f514c..3a70c2824 100644 --- a/include/infinicore/ops/kv_caching.hpp +++ b/include/infinicore/ops/kv_caching.hpp @@ -1,28 +1,16 @@ -#pragma +#pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" namespace infinicore::op { -class KVCaching { -public: - using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor); - static void execute(Tensor k_cache, - Tensor v_cache, - Tensor k, - Tensor v, - Tensor past_kv_lengths); - static common::OpDispatcher &dispatcher(); -}; -Tensor kv_caching(Tensor k_cache, - Tensor v_cache, - Tensor k, - Tensor v, - Tensor past_kv_lengths); +INFINICORE_GRAPH_OP_CLASS(KVCaching, 
Tensor, Tensor, const Tensor &, const Tensor &, const Tensor &); + void kv_caching_(Tensor k_cache, Tensor v_cache, - Tensor k, - Tensor v, - Tensor past_kv_lengths); + const Tensor &k, + const Tensor &v, + const Tensor &past_kv_lengths); } // namespace infinicore::op diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index c6b01d5aa..845bbcc0a 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -45,6 +45,7 @@ from infinicore.ops.add import add from infinicore.ops.add_rms_norm import add_rms_norm, add_rms_norm_ from infinicore.ops.attention import attention +from infinicore.ops.kv_caching import kv_caching from infinicore.ops.matmul import matmul from infinicore.ops.mul import mul from infinicore.ops.narrow import narrow @@ -115,6 +116,7 @@ "add_rms_norm", "add_rms_norm_", "attention", + "kv_caching", "matmul", "mul", "narrow", diff --git a/python/infinicore/ops/kv_caching.py b/python/infinicore/ops/kv_caching.py new file mode 100644 index 000000000..b34f2346e --- /dev/null +++ b/python/infinicore/ops/kv_caching.py @@ -0,0 +1,13 @@ +from infinicore.lib import _infinicore + + +def kv_caching(k_cache, v_cache, k, v, past_kv_lengths): + _infinicore.kv_caching_( + k_cache._underlying, + v_cache._underlying, + k._underlying, + v._underlying, + past_kv_lengths._underlying, + ) + + return k_cache, v_cache diff --git a/src/infinicore/ops/kv_caching/kv_caching.cc b/src/infinicore/ops/kv_caching/kv_caching.cc index bed3a4566..0110f7973 100644 --- a/src/infinicore/ops/kv_caching/kv_caching.cc +++ b/src/infinicore/ops/kv_caching/kv_caching.cc @@ -2,46 +2,41 @@ #include "../../utils.hpp" -#include - namespace infinicore::op { +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(KVCaching); -common::OpDispatcher &KVCaching::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; - -void KVCaching::execute(Tensor k_cache, - Tensor v_cache, - Tensor k, - Tensor v, - Tensor past_kv_lengths) { 
+KVCaching::KVCaching(Tensor k_cache, + Tensor v_cache, + const Tensor &k, + const Tensor &v, + const Tensor &past_kv_lengths) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k_cache, v_cache, k, v, past_kv_lengths); - infinicore::context::setDevice(k_cache->device()); - auto device_type = k_cache->device().getType(); - auto func = dispatcher().lookup(device_type); - - if (func == nullptr) { - throw std::runtime_error("No KVCaching implementation found for device type: " + std::to_string(static_cast(device_type))); - } - - func(k_cache, v_cache, k, v, past_kv_lengths); + INFINICORE_GRAPH_OP_DISPATCH(k_cache->device().getType(), + k_cache, + v_cache, + k, + v, + past_kv_lengths); } -Tensor kv_caching(Tensor k_cache, - Tensor v_cache, - Tensor k, - Tensor v, - Tensor past_kv_lengths) { - KVCaching::execute(k_cache, v_cache, k, v, past_kv_lengths); - return k_cache; // or v_cache, depending on the intended use +void KVCaching::execute(Tensor k_cache, + Tensor v_cache, + const Tensor &k, + const Tensor &v, + const Tensor &past_kv_lengths) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(KVCaching, + k_cache, + v_cache, + k, + v, + past_kv_lengths); } void kv_caching_(Tensor k_cache, Tensor v_cache, - Tensor k, - Tensor v, - Tensor past_kv_lengths) { + const Tensor &k, + const Tensor &v, + const Tensor &past_kv_lengths) { KVCaching::execute(k_cache, v_cache, k, v, past_kv_lengths); } } // namespace infinicore::op diff --git a/src/infinicore/ops/kv_caching/kv_caching_infiniop.cc b/src/infinicore/ops/kv_caching/kv_caching_infiniop.cc index 37d5e1fa3..53ea5f0ae 100644 --- a/src/infinicore/ops/kv_caching/kv_caching_infiniop.cc +++ b/src/infinicore/ops/kv_caching/kv_caching_infiniop.cc @@ -1,59 +1,60 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" +#include "../infiniop_impl.hpp" #include "infinicore/ops/kv_caching.hpp" -#include namespace infinicore::op::kv_caching_impl::infiniop { -thread_local common::OpCache caches( - 100, 
// capacity - [](infiniopKVCachingDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyKVCachingDescriptor(desc)); - desc = nullptr; - } - }); - -void calculate(Tensor k_cache, - Tensor v_cache, - Tensor k, - Tensor v, - Tensor past_kv_lengths) { +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, KVCaching, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, k_cache, v_cache, k, v, past_kv_lengths; +}; + +void *plan(Tensor k_cache, + Tensor v_cache, + const Tensor &k, + const Tensor &v, + const Tensor &past_kv_lengths) { size_t seed = hash_combine(k_cache, v_cache, k, v, past_kv_lengths); - auto device = context::getDevice(); - auto &cache = caches.getCache(device); + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, KVCaching, + seed, k_cache->desc(), v_cache->desc(), + k->desc(), v->desc(), past_kv_lengths->desc()); + + INFINIOP_WORKSPACE_TENSOR(workspace, KVCaching, descriptor); - auto desc_opt = cache.get(seed); - infiniopKVCachingDescriptor_t desc = nullptr; + auto planned = new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(k_cache), + graph::GraphTensor(v_cache), + graph::GraphTensor(k), + graph::GraphTensor(v), + graph::GraphTensor(past_kv_lengths)}; - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateKVCachingDescriptor( - context::getInfiniopHandle(device), &desc, - k_cache->desc(), v_cache->desc(), - k->desc(), v->desc(), - past_kv_lengths->desc())); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + return planned; +} - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetKVCachingWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopKVCaching( - desc, workspace->data(), workspace_size, - k_cache->data(), v_cache->data(), - k->data(), v->data(), 
- past_kv_lengths->data(), + planned->descriptor->desc, + nullptr, 0, + planned->k_cache->data(), + planned->v_cache->data(), + planned->k->data(), + planned->v->data(), + planned->past_kv_lengths->data(), context::getStream())); } -static bool registered = []() { - KVCaching::dispatcher().registerAll(&calculate, false); - return true; -}(); +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(KVCaching, &plan, &run, cleanup); } // namespace infinicore::op::kv_caching_impl::infiniop diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index e2d5aa00b..c3b781050 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -8,6 +8,7 @@ #include "ops/causal_softmax.hpp" #include "ops/embedding.hpp" #include "ops/flash_attention.hpp" +#include "ops/kv_caching.hpp" #include "ops/linear.hpp" #include "ops/matmul.hpp" #include "ops/mul.hpp" @@ -30,20 +31,21 @@ inline void bind(py::module &m) { bind_add_rms_norm(m); bind_attention(m); bind_causal_softmax(m); + bind_embedding(m); bind_flash_attention(m); - bind_random_sample(m); + bind_kv_caching(m); bind_linear(m); bind_matmul(m); bind_mul(m); bind_paged_attention(m); bind_paged_attention_prefill(m); bind_paged_caching(m); + bind_random_sample(m); bind_rearrange(m); bind_rms_norm(m); + bind_rope(m); bind_silu(m); bind_swiglu(m); - bind_rope(m); - bind_embedding(m); } } // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/kv_caching.hpp b/src/infinicore/pybind11/ops/kv_caching.hpp new file mode 100644 index 000000000..2864312b2 --- /dev/null +++ b/src/infinicore/pybind11/ops/kv_caching.hpp @@ -0,0 +1,32 @@ +#pragma once + +#include + +#include "infinicore/ops/kv_caching.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_kv_caching(py::module &m) { + m.def("kv_caching_", + &op::kv_caching_, + py::arg("k_cache"), + 
py::arg("v_cache"), + py::arg("k"), + py::arg("v"), + py::arg("past_kv_lengths"), + R"doc(In-place Key-Value Caching. + +Updates the KV cache in-place with new key and value tensors. + +Args: + k_cache: Key cache tensor to update in-place + v_cache: Value cache tensor to update in-place + k: New key tensor to append + v: New value tensor to append + past_kv_lengths: Tensor containing current sequence lengths for each batch +)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h b/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h index 697891d3d..09257d16d 100644 --- a/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h +++ b/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h @@ -67,24 +67,26 @@ class Descriptor final : public InfiniopDescriptor { constexpr auto block_size_m_{64}; constexpr auto block_size_n_{64}; - launch_flash_attention(stream, - query, - key, - value, - attn_mask, - is_causal, - scale, - output, - with_attn_mask, - causal_variant, - with_kv_cache_, - emb_dim_, - is_causal_, - with_attn_mask_, - causal_variant_, - dtype_, - block_size_m_, - block_size_n_); + if (launch_flash_attention(stream, + query, + key, + value, + attn_mask, + is_causal, + scale, + output, + with_attn_mask, + causal_variant, + with_kv_cache_, + emb_dim_, + is_causal_, + with_attn_mask_, + causal_variant_, + dtype_, + block_size_m_, + block_size_n_)) { + return INFINI_STATUS_NOT_IMPLEMENTED; + } return INFINI_STATUS_SUCCESS; } diff --git a/src/infiniop/ops/kv_caching/ninetoothed/build.py b/src/infiniop/ops/kv_caching/ninetoothed/build.py new file mode 100644 index 000000000..03481c86b --- /dev/null +++ b/src/infiniop/ops/kv_caching/ninetoothed/build.py @@ -0,0 +1,27 @@ +import ninetoothed +from . 
import kv_caching + +import infiniop.ninetoothed.build + + +def build(): + dtype_values = ( + ninetoothed.float16, + ninetoothed.bfloat16, + ninetoothed.float32, + ) + + constexpr_param_grid = { + "emb_dim": (1, 16, 32, 64, 128, 256), + "dtype": dtype_values, + "block_size_m": (64,), + "block_size_n": (64,), + } + + infiniop.ninetoothed.build.build( + kv_caching.premake, + constexpr_param_grid, + caller="cuda", + op_name="kv_caching", + output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH, + ) diff --git a/src/infiniop/ops/kv_caching/ninetoothed/kv_caching.h b/src/infiniop/ops/kv_caching/ninetoothed/kv_caching.h new file mode 100644 index 000000000..43388f58d --- /dev/null +++ b/src/infiniop/ops/kv_caching/ninetoothed/kv_caching.h @@ -0,0 +1,101 @@ +#ifndef KV_CACHING_H +#define KV_CACHING_H + +#include "../../../handle.h" +#include "../../../operator.h" +#include "../../../tensor.h" + +#include "../../../../../build/ninetoothed/kv_caching.h" +#include "../../../ninetoothed/utils.h" + +namespace op::kv_caching::ninetoothed { +class Descriptor final : public InfiniopDescriptor { + +public: + Descriptor( + infiniopHandle_t handle, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t past_kv_lengths_desc) : InfiniopDescriptor{handle->device, handle->device_id}, + k_cache_shape_{k_cache_desc->shape()}, + k_cache_strides_{k_cache_desc->strides()}, + v_cache_shape_{v_cache_desc->shape()}, + v_cache_strides_{v_cache_desc->strides()}, + k_shape_{k_desc->shape()}, + k_strides_{k_desc->strides()}, + v_shape_{v_desc->shape()}, + v_strides_{v_desc->strides()}, + past_kv_lengths_shape_{past_kv_lengths_desc->shape()}, + past_kv_lengths_strides_{past_kv_lengths_desc->strides()}, + dtype_{k_desc->dtype()} {} + + ~Descriptor() = default; + + size_t get_workspace_size() const { return 0; }; + + static infiniStatus_t create( + infiniopHandle_t 
handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t k_cache, + infiniopTensorDescriptor_t v_cache, + infiniopTensorDescriptor_t k, + infiniopTensorDescriptor_t v, + infiniopTensorDescriptor_t past_kv_lengths) { + *desc_ptr = new Descriptor{handle, k_cache, v_cache, k, v, past_kv_lengths}; + return INFINI_STATUS_SUCCESS; + } + + infiniStatus_t calculate( + void *workspace, size_t workspace_size, + void *k_cache, + void *v_cache, + const void *k, + const void *v, + const void *past_kv_lengths, + void *stream) const { + auto k_cache_nt{::ninetoothed::Tensor{k_cache, k_cache_shape_, k_cache_strides_}}; + auto v_cache_nt{::ninetoothed::Tensor{v_cache, v_cache_shape_, v_cache_strides_}}; + auto k_nt{::ninetoothed::Tensor{k, k_shape_, k_strides_}}; + auto v_nt{::ninetoothed::Tensor{v, v_shape_, v_strides_}}; + auto past_kv_lengths_nt{::ninetoothed::Tensor{past_kv_lengths, past_kv_lengths_shape_, past_kv_lengths_strides_}}; + + if (launch_kv_caching(stream, + k_cache_nt, + v_cache_nt, + k_nt, + v_nt, + past_kv_lengths_nt, + k_shape_[3], + dtype_, + 64, 64)) { + return INFINI_STATUS_NOT_IMPLEMENTED; + } + + return INFINI_STATUS_SUCCESS; + } + +private: + using Size = ::ninetoothed::Tensor<>::Size; + using Stride = ::ninetoothed::Tensor<>::Stride; + + std::vector k_cache_shape_; + std::vector k_cache_strides_; + + std::vector v_cache_shape_; + std::vector v_cache_strides_; + + std::vector k_shape_; + std::vector k_strides_; + std::vector v_shape_; + std::vector v_strides_; + + std::vector past_kv_lengths_shape_; + std::vector past_kv_lengths_strides_; + + infiniDtype_t dtype_; +}; +} // namespace op::kv_caching::ninetoothed + +#endif // KV_CACHING_H diff --git a/src/infiniop/ops/kv_caching/ninetoothed/kv_caching.py b/src/infiniop/ops/kv_caching/ninetoothed/kv_caching.py new file mode 100644 index 000000000..dfc5088e9 --- /dev/null +++ b/src/infiniop/ops/kv_caching/ninetoothed/kv_caching.py @@ -0,0 +1,66 @@ +import functools +import ninetoothed +from ninetoothed 
import Tensor + + +def arrangement( + k_cache, + v_cache, + k, + v, + past_lengths, + block_size_m=ninetoothed.block_size(), + block_size_n=ninetoothed.block_size(), +): + k_cache_arranged = k_cache.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1)) + v_cache_arranged = v_cache.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1)) + + k_arranged = k.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1)) + v_arranged = v.tile((1, block_size_m, 1, -1)).tile((1, 1, -1, 1)) + + past_lengths_arranged = ( + past_lengths.tile((1,)) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .unsqueeze(4) + .expand((-1, *k_arranged.shape)) + ) + + return ( + k_cache_arranged, + v_cache_arranged, + k_arranged, + v_arranged, + past_lengths_arranged, + ) + + +def application(k_cache, v_cache, k, v, past_lengths): + pos = past_lengths + + for i in range(k.shape[-2]): + k_cache[0, 0, pos + i, 0] = k[0, 0, i, 0] + v_cache[0, 0, pos + i, 0] = v[0, 0, i, 0] + + +def premake(emb_dim=None, dtype=None, block_size_m=None, block_size_n=None): + arrangement_ = functools.partial( + arrangement, block_size_m=block_size_m, block_size_n=block_size_n + ) + + shape_options = (None, None, None, {"constexpr": True, "upper_bound": 256}) + + tensors = ( + Tensor(4, dtype=dtype, shape_options=shape_options), + Tensor(4, dtype=dtype, shape_options=shape_options), + Tensor(4, dtype=dtype, shape_options=shape_options), + Tensor(4, dtype=dtype, shape_options=shape_options), + Tensor(1, dtype=ninetoothed.int64), + ) + + if emb_dim is not None: + for tensor in tensors: + tensor.shape = tensor.shape[:-1] + (emb_dim,) + + return arrangement_, application, tensors diff --git a/src/infiniop/ops/kv_caching/operator.cc b/src/infiniop/ops/kv_caching/operator.cc index 65b27a414..34bdf9a99 100644 --- a/src/infiniop/ops/kv_caching/operator.cc +++ b/src/infiniop/ops/kv_caching/operator.cc @@ -2,11 +2,10 @@ #include "../../handle.h" #include "infiniop/ops/kv_caching.h" -#ifdef ENABLE_CPU_API -// #include "cpu/kv_caching_cpu.h" +#if 
defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_METAX_API) || defined(ENABLE_MOORE_API) +#include "ninetoothed/kv_caching.h" #endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) -// #include "nvidia/kv_caching_nvidia.cuh" #endif __C infiniStatus_t infiniopCreateKVCachingDescriptor( @@ -31,12 +30,18 @@ __C infiniStatus_t infiniopCreateKVCachingDescriptor( switch (handle->device) { -#ifdef ENABLE_CPU_API - // CREATE(INFINI_DEVICE_CPU, cpu); +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) + CREATE(INFINI_DEVICE_NVIDIA, ninetoothed); #endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) - // CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#if defined(ENABLE_ILUVATAR_API) + CREATE(INFINI_DEVICE_ILUVATAR, ninetoothed); #endif +#if defined(ENABLE_METAX_API) + CREATE(INFINI_DEVICE_METAX, ninetoothed); +#endif +#endif + default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } @@ -55,11 +60,17 @@ __C infiniStatus_t infiniopGetKVCachingWorkspaceSize( return INFINI_STATUS_SUCCESS; switch (desc->device_type) { -#ifdef ENABLE_CPU_API - // GET_SIZE(INFINI_DEVICE_CPU, cpu); + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) + GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed); +#endif +#if defined(ENABLE_ILUVATAR_API) + GET_SIZE(INFINI_DEVICE_ILUVATAR, ninetoothed); +#endif +#if defined(ENABLE_METAX_API) + GET_SIZE(INFINI_DEVICE_METAX, ninetoothed); #endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) - // GET_SIZE(INFINI_DEVICE_NVIDIA, nvidia); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -86,11 +97,16 @@ __C infiniStatus_t infiniopKVCaching( switch (desc->device_type) { -#ifdef ENABLE_CPU_API - // CALCULATE(INFINI_DEVICE_CPU, cpu); +#if defined(ENABLE_NINETOOTHED) 
+#if defined(ENABLE_NVIDIA_API) + CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed); +#endif +#if defined(ENABLE_ILUVATAR_API) + CALCULATE(INFINI_DEVICE_ILUVATAR, ninetoothed); +#endif +#if defined(ENABLE_METAX_API) + CALCULATE(INFINI_DEVICE_METAX, ninetoothed); #endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) - // CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -108,11 +124,17 @@ __C infiniStatus_t infiniopDestroyKVCachingDescriptor( return INFINI_STATUS_SUCCESS; switch (desc->device_type) { -#ifdef ENABLE_CPU_API - // DELETE(INFINI_DEVICE_CPU, cpu); + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) + DELETE(INFINI_DEVICE_NVIDIA, ninetoothed); +#endif +#if defined(ENABLE_ILUVATAR_API) + DELETE(INFINI_DEVICE_ILUVATAR, ninetoothed); +#endif +#if defined(ENABLE_METAX_API) + DELETE(INFINI_DEVICE_METAX, ninetoothed); #endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) - // DELETE(INFINI_DEVICE_NVIDIA, nvidia); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/test/infinicore/framework/base.py b/test/infinicore/framework/base.py index 87222b299..80dcb3eb1 100644 --- a/test/infinicore/framework/base.py +++ b/test/infinicore/framework/base.py @@ -342,7 +342,10 @@ def prepare_infinicore_inputs_and_kwargs(self, inputs, kwargs, comparison_target for i, inp in enumerate(inputs): if isinstance(inp, torch.Tensor): # Clone only if this input will be used for comparison - if comparison_target == i: + if comparison_target == i or ( + isinstance(comparison_target, (list, tuple)) + and i in comparison_target + ): cloned_inp = clone_torch_tensor(inp) infini_tensor = infinicore_tensor_from_torch(cloned_inp) cloned_tensors.append(cloned_inp) @@ -508,7 +511,9 @@ def run_test(self, device, test_case, config): # Handle multiple outputs 
comparison # Determine what to compare based on comparison_target - if comparison_target is None: + if comparison_target is None or isinstance( + comparison_target, (list, tuple) + ): # Compare return values (out-of-place multiple outputs) torch_comparison = torch_result infini_comparison = infini_result @@ -573,7 +578,9 @@ def run_test(self, device, test_case, config): # ========================================================================== else: # Determine comparison targets for single output - if comparison_target is None: + if comparison_target is None or isinstance( + comparison_target, (list, tuple) + ): # Compare return values (out-of-place) torch_comparison = torch_result infini_comparison = infini_result diff --git a/test/infinicore/ops/kv_caching.py b/test/infinicore/ops/kv_caching.py new file mode 100644 index 000000000..4ca857586 --- /dev/null +++ b/test/infinicore/ops/kv_caching.py @@ -0,0 +1,134 @@ +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +import infinicore +from framework import ( + BaseOperatorTest, + TensorSpec, + TensorInitializer, + TestCase, + GenericTestRunner, + is_broadcast, +) + +# ============================================================================== +# Operator-specific configuration +# ============================================================================== + +# Test cases format: (shape (bs, nkvh, seq_len, hd), strides) +_TEST_CASES_DATA = [ + ((1, 1, 8, 1), None), + ((1, 8, 32, 32), None), + ((8, 8, 64, 32), None), + ((1, 32, 8, 64), (32768, 1024, 64, 1)), + ((4, 8, 32, 16), (65536, 8192, 256, 16)), + ((8, 16, 64, 128), (8388608, 524288, 8192, 1)), +] + +# Tolerance configuration +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 0, "rtol": 0}, + infinicore.bfloat16: {"atol": 0, "rtol": 0}, + infinicore.float32: {"atol": 0, "rtol": 0}, +} + +# Data types to test +_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] + + +def 
parse_test_cases(): + test_cases = [] + + for data in _TEST_CASES_DATA: + import random + + cache_shape = data[0] + kv_shape = ( + cache_shape[0], + cache_shape[1], + random.randint(1, cache_shape[2]), + cache_shape[3], + ) + past_shape = (cache_shape[0],) + + strides = data[1] + + past_length = random.randint(0, cache_shape[2] - kv_shape[2]) + + for dtype in _TENSOR_DTYPES: + tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 0}) + + cache_spec = TensorSpec.from_tensor(cache_shape, strides, dtype) + kv_spec = TensorSpec.from_tensor(kv_shape, None, dtype) + + past_kv_lengths_spec = TensorSpec.from_tensor( + past_shape, + None, + infinicore.int64, + init_mode=TensorInitializer.RANDINT, + low=past_length, + high=past_length + 1, + ) + + test_cases.append( + TestCase( + inputs=[ + cache_spec, + cache_spec, + kv_spec, + kv_spec, + past_kv_lengths_spec, + ], + kwargs={}, + output_spec=None, + comparison_target=[0, 1], + tolerance=tolerance, + description=f"KV Caching", + ) + ) + + return test_cases + + +def torch_kv_caching(k_cache, v_cache, k, v, past_kv_lengths): + batch_size, num_kv_heads, _, head_dim = k_cache.shape + seq_len = k.shape[2] + + for b in range(batch_size): + past_len = past_kv_lengths[b].item() + for h in range(num_kv_heads): + k_cache[b, h, past_len : past_len + seq_len, :] = k[b, h, :, :] + v_cache[b, h, past_len : past_len + seq_len, :] = v[b, h, :, :] + + return k_cache, v_cache + + +def infinicore_kv_caching(k_cache, v_cache, k, v, past_kv_lengths): + infinicore.kv_caching(k_cache, v_cache, k, v, past_kv_lengths) + return k_cache, v_cache + + +class OpTest(BaseOperatorTest): + def __init__(self): + super().__init__("KV Caching") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + return torch_kv_caching(*args, **kwargs) + + def infinicore_operator(self, *args, **kwargs): + return infinicore_kv_caching(*args, **kwargs) + + +def main(): + test_runner = GenericTestRunner(OpTest) + 
test_runner.run_and_exit() + + +if __name__ == "__main__": + main() From 22c3aeb301a6492979ca30978449b81404e1609b Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 20 Jan 2026 02:34:40 +0000 Subject: [PATCH 09/14] issue/948 - modify flash attn interface --- include/infinicore/ops/flash_attention.hpp | 12 +-- include/infiniop/ops/flash_attention.h | 1 + .../scaled_dot_product_attention.py | 7 +- .../ops/flash_attention/flash_attention.cc | 24 +++--- .../flash_attention_infiniop.cc | 79 ++++++++++--------- .../pybind11/ops/flash_attention.hpp | 1 + .../flash_attention/ninetoothed/descriptor.h | 9 ++- src/infiniop/ops/flash_attention/operator.cc | 2 + 8 files changed, 75 insertions(+), 60 deletions(-) diff --git a/include/infinicore/ops/flash_attention.hpp b/include/infinicore/ops/flash_attention.hpp index 957255192..33084aa45 100644 --- a/include/infinicore/ops/flash_attention.hpp +++ b/include/infinicore/ops/flash_attention.hpp @@ -4,13 +4,9 @@ #include "common/op.hpp" namespace infinicore::op { -class FlashAttention { -public: - using schema = void (*)(Tensor, Tensor, Tensor, Tensor, float, bool); - static void execute(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal); - static common::OpDispatcher &dispatcher(); -}; -Tensor flash_attention(Tensor q, Tensor k, Tensor v, float scale, bool is_causal); -void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal); +INFINICORE_GRAPH_OP_CLASS(FlashAttention, Tensor, Tensor, Tensor, Tensor, std::size_t, float, bool); + +Tensor flash_attention(Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal); +void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal); } // namespace infinicore::op diff --git a/include/infiniop/ops/flash_attention.h b/include/infiniop/ops/flash_attention.h index 06c3ff47c..2bcb9fe77 100644 --- a/include/infiniop/ops/flash_attention.h +++ 
b/include/infiniop/ops/flash_attention.h @@ -12,6 +12,7 @@ __C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor( infiniopTensorDescriptor_t q_desc, infiniopTensorDescriptor_t k_desc, infiniopTensorDescriptor_t v_desc, + std::size_t total_kv_len, float scale, char is_causal); diff --git a/python/infinicore/nn/functional/scaled_dot_product_attention.py b/python/infinicore/nn/functional/scaled_dot_product_attention.py index d89f484fe..cc43e890f 100644 --- a/python/infinicore/nn/functional/scaled_dot_product_attention.py +++ b/python/infinicore/nn/functional/scaled_dot_product_attention.py @@ -23,6 +23,11 @@ def scaled_dot_product_attention( return Tensor( _infinicore.flash_attention( - query._underlying, key._underlying, value._underlying, scale, is_causal + query._underlying, + key._underlying, + value._underlying, + key.shape[-2], + scale, + is_causal, ) ) diff --git a/src/infinicore/ops/flash_attention/flash_attention.cc b/src/infinicore/ops/flash_attention/flash_attention.cc index 97db6de79..92a854710 100644 --- a/src/infinicore/ops/flash_attention/flash_attention.cc +++ b/src/infinicore/ops/flash_attention/flash_attention.cc @@ -4,26 +4,26 @@ namespace infinicore::op { -common::OpDispatcher &FlashAttention::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(FlashAttention); -void FlashAttention::execute(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal) { +FlashAttention::FlashAttention(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v); - infinicore::context::setDevice(out->device()); - dispatcher().lookup(out->device().getType())( - out, q, k, v, scale, is_causal); + INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), + out, q, k, v, total_kv_len, scale, is_causal); } -Tensor flash_attention(Tensor q, Tensor k, Tensor v, float scale, bool is_causal) { +void 
FlashAttention::execute(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(FlashAttention, out, q, k, v, total_kv_len, scale, is_causal); +} + +Tensor flash_attention(Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) { Shape shape = q->shape(); auto out = Tensor::empty(shape, q->dtype(), q->device()); - flash_attention_(out, q, k, v, scale, is_causal); + flash_attention_(out, q, k, v, total_kv_len, scale, is_causal); return out; } -void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal) { - FlashAttention::execute(out, q, k, v, scale, is_causal); +void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) { + FlashAttention::execute(out, q, k, v, total_kv_len, scale, is_causal); } } // namespace infinicore::op diff --git a/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc b/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc index e0a91e681..b714744f0 100644 --- a/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc +++ b/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc @@ -1,4 +1,5 @@ #include "../../utils.hpp" +#include "../infiniop_impl.hpp" #include "infinicore/common/hash.hpp" #include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/flash_attention.hpp" @@ -6,46 +7,50 @@ namespace infinicore::op::flash_attention_impl::infiniop { -thread_local common::OpCache caches( - 100, // capacity - [](infiniopFlashAttentionDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyFlashAttentionDescriptor(desc)); - desc = nullptr; - } - }); - -void calculate(Tensor out, Tensor q, Tensor k, Tensor v, float scale, bool is_causal) { - size_t seed = hash_combine(out, q, k, v, scale, is_causal); - - auto device = context::getDevice(); - auto &cache = caches.getCache(device); - - 
auto desc_opt = cache.get(seed); - infiniopFlashAttentionDescriptor_t desc = nullptr; - - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateFlashAttentionDescriptor( - context::getInfiniopHandle(device), &desc, - out->desc(), q->desc(), k->desc(), v->desc(), - scale, static_cast(is_causal))); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } - - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetFlashAttentionWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, FlashAttention, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, out, q, k, v; + std::size_t total_kv_len; + float scale; + bool is_causal; +}; + +void *plan(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) { + size_t seed = hash_combine(out, q, k, v, total_kv_len, scale, is_causal); + + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, FlashAttention, + seed, out->desc(), q->desc(), k->desc(), v->desc(), total_kv_len, scale, is_causal); + + INFINIOP_WORKSPACE_TENSOR(workspace, FlashAttention, descriptor); + + auto planned = new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(out), + graph::GraphTensor(q), + graph::GraphTensor(k), + graph::GraphTensor(v), + total_kv_len, scale, is_causal}; + + return planned; +} + +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopFlashAttention( - desc, workspace->data(), workspace_size, - out->data(), q->data(), k->data(), v->data(), context::getStream())); + planned->descriptor->desc, planned->workspace->data(), planned->workspace->numel(), + planned->out->data(), planned->q->data(), planned->k->data(), planned->v->data(), context::getStream())); +} + +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + 
*planned_meta_ptr = nullptr; } -static bool registered = []() { - FlashAttention::dispatcher().registerAll(&calculate, false); - return true; -}(); +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(FlashAttention, &plan, &run, &cleanup); } // namespace infinicore::op::flash_attention_impl::infiniop diff --git a/src/infinicore/pybind11/ops/flash_attention.hpp b/src/infinicore/pybind11/ops/flash_attention.hpp index 09ec91980..6e3766796 100644 --- a/src/infinicore/pybind11/ops/flash_attention.hpp +++ b/src/infinicore/pybind11/ops/flash_attention.hpp @@ -14,6 +14,7 @@ inline void bind_flash_attention(py::module &m) { py::arg("q"), py::arg("k"), py::arg("v"), + py::arg("total_kv_len"), py::arg("scale"), py::arg("is_causal")); } diff --git a/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h b/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h index 09257d16d..f39d9d045 100644 --- a/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h +++ b/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h @@ -17,6 +17,7 @@ class Descriptor final : public InfiniopDescriptor { infiniopTensorDescriptor_t q_desc, infiniopTensorDescriptor_t k_desc, infiniopTensorDescriptor_t v_desc, + std::size_t total_kv_len, double scale, char is_causal) : InfiniopDescriptor{handle->device, handle->device_id}, _query_shape{q_desc->shape()}, @@ -28,7 +29,10 @@ class Descriptor final : public InfiniopDescriptor { _output_strides{out_desc->strides()}, _dtype{q_desc->dtype()}, _scale{scale}, - _is_causal{is_causal} {} + _is_causal{is_causal} { + _key_shape[_key_shape.size() - 2] = total_kv_len; + _value_shape[_key_shape.size() - 2] = total_kv_len; + } ~Descriptor() = default; @@ -97,9 +101,10 @@ class Descriptor final : public InfiniopDescriptor { infiniopTensorDescriptor_t q_desc, infiniopTensorDescriptor_t k_desc, infiniopTensorDescriptor_t v_desc, + std::size_t total_kv_len, double scale, char is_causal) { - *desc = new Descriptor{handle, out_desc, q_desc, k_desc, v_desc, scale, is_causal}; 
+ *desc = new Descriptor{handle, out_desc, q_desc, k_desc, v_desc, total_kv_len, scale, is_causal}; return INFINI_STATUS_SUCCESS; } diff --git a/src/infiniop/ops/flash_attention/operator.cc b/src/infiniop/ops/flash_attention/operator.cc index e907d3c41..6ce530fd3 100644 --- a/src/infiniop/ops/flash_attention/operator.cc +++ b/src/infiniop/ops/flash_attention/operator.cc @@ -20,6 +20,7 @@ __C infiniStatus_t infiniopCreateFlashAttentionDescriptor( infiniopTensorDescriptor_t q_desc, infiniopTensorDescriptor_t k_desc, infiniopTensorDescriptor_t v_desc, + std::size_t total_kv_len, float scale, char is_causal) { @@ -32,6 +33,7 @@ __C infiniStatus_t infiniopCreateFlashAttentionDescriptor( q_desc, \ k_desc, \ v_desc, \ + total_kv_len, \ scale, \ is_causal); From 4ba239d97a8c3fc576e225f8c303c5a0ea04f686 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Wed, 21 Jan 2026 19:55:55 +0800 Subject: [PATCH 10/14] issue/889 - changed size_t to tensor in flash attn interface --- include/infinicore/ops/flash_attention.hpp | 6 +- include/infiniop/ops/flash_attention.h | 3 +- python/infinicore/nn/functional/__init__.py | 8 +- .../nn/functional/flash_attention.py | 34 +++ .../scaled_dot_product_attention.py | 2 + .../ops/flash_attention/flash_attention.cc | 10 +- .../flash_attention_infiniop.cc | 11 +- .../ops/flash_attention/ninetoothed/build.py | 8 +- .../flash_attention/ninetoothed/descriptor.h | 15 +- .../ninetoothed/flash_attention.py | 281 ++++++++++++++++++ src/infiniop/ops/flash_attention/operator.cc | 49 +-- test/infinicore/ops/flash_attention.py | 115 +++++++ 12 files changed, 483 insertions(+), 59 deletions(-) create mode 100644 python/infinicore/nn/functional/flash_attention.py create mode 100644 src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py create mode 100644 test/infinicore/ops/flash_attention.py diff --git a/include/infinicore/ops/flash_attention.hpp b/include/infinicore/ops/flash_attention.hpp index 33084aa45..24e33cfb6 100644 --- 
a/include/infinicore/ops/flash_attention.hpp +++ b/include/infinicore/ops/flash_attention.hpp @@ -5,8 +5,8 @@ namespace infinicore::op { -INFINICORE_GRAPH_OP_CLASS(FlashAttention, Tensor, Tensor, Tensor, Tensor, std::size_t, float, bool); +INFINICORE_GRAPH_OP_CLASS(FlashAttention, Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, float, bool); -Tensor flash_attention(Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal); -void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal); +Tensor flash_attention(const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal); +void flash_attention_(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal); } // namespace infinicore::op diff --git a/include/infiniop/ops/flash_attention.h b/include/infiniop/ops/flash_attention.h index 2bcb9fe77..5ea71335b 100644 --- a/include/infiniop/ops/flash_attention.h +++ b/include/infiniop/ops/flash_attention.h @@ -12,7 +12,7 @@ __C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor( infiniopTensorDescriptor_t q_desc, infiniopTensorDescriptor_t k_desc, infiniopTensorDescriptor_t v_desc, - std::size_t total_kv_len, + infiniopTensorDescriptor_t total_kv_len, float scale, char is_causal); @@ -28,6 +28,7 @@ __C __export infiniStatus_t infiniopFlashAttention( const void *q, const void *k, const void *v, + const void *total_kv_len, void *stream); __C __export infiniStatus_t infiniopDestroyFlashAttentionDescriptor( diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py index f8c7d6ef0..d34490365 100644 --- a/python/infinicore/nn/functional/__init__.py +++ b/python/infinicore/nn/functional/__init__.py @@ -1,5 +1,6 @@ from .causal_softmax import causal_softmax from .embedding import embedding +from .flash_attention 
import flash_attention from .linear import linear from .random_sample import random_sample from .rms_norm import rms_norm @@ -10,13 +11,14 @@ __all__ = [ "causal_softmax", + "embedding", + "flash_attention", + "linear", "random_sample", "rms_norm", + "rope", "scaled_dot_product_attention", "silu", "swiglu", - "linear", - "embedding", - "rope", "RopeAlgo", ] diff --git a/python/infinicore/nn/functional/flash_attention.py b/python/infinicore/nn/functional/flash_attention.py new file mode 100644 index 000000000..8f42e865f --- /dev/null +++ b/python/infinicore/nn/functional/flash_attention.py @@ -0,0 +1,34 @@ +import math + +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def flash_attention( + query, + key, + value, + total_kv_len, + attn_mask=None, + dropout_p=0, + is_causal=False, + scale=None, + enable_gqa=False, +): + assert attn_mask is None and dropout_p == 0 and not enable_gqa + + emb_dim = query.shape[-1] + + if scale is None: + scale = 1 / math.sqrt(emb_dim) + + return Tensor( + _infinicore.flash_attention( + query._underlying, + key._underlying, + value._underlying, + total_kv_len._underlying, + scale, + is_causal, + ) + ) diff --git a/python/infinicore/nn/functional/scaled_dot_product_attention.py b/python/infinicore/nn/functional/scaled_dot_product_attention.py index cc43e890f..0b780e562 100644 --- a/python/infinicore/nn/functional/scaled_dot_product_attention.py +++ b/python/infinicore/nn/functional/scaled_dot_product_attention.py @@ -14,6 +14,8 @@ def scaled_dot_product_attention( scale=None, enable_gqa=False, ): + raise NotImplementedError("Scaled Dot Product Attention is not yet supported.") + assert attn_mask is None and dropout_p == 0 and not enable_gqa emb_dim = query.shape[-1] diff --git a/src/infinicore/ops/flash_attention/flash_attention.cc b/src/infinicore/ops/flash_attention/flash_attention.cc index 92a854710..21cd56010 100644 --- a/src/infinicore/ops/flash_attention/flash_attention.cc +++ 
b/src/infinicore/ops/flash_attention/flash_attention.cc @@ -6,24 +6,26 @@ namespace infinicore::op { INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(FlashAttention); -FlashAttention::FlashAttention(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) { +FlashAttention::FlashAttention(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k, v); INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), out, q, k, v, total_kv_len, scale, is_causal); } -void FlashAttention::execute(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) { +void FlashAttention::execute(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) { INFINICORE_GRAPH_OP_RECORD_OR_RUN(FlashAttention, out, q, k, v, total_kv_len, scale, is_causal); } -Tensor flash_attention(Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) { +Tensor flash_attention(const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) { Shape shape = q->shape(); + int idx = shape.size() - 1; + shape[idx] = v->shape()[idx]; auto out = Tensor::empty(shape, q->dtype(), q->device()); flash_attention_(out, q, k, v, total_kv_len, scale, is_causal); return out; } -void flash_attention_(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) { +void flash_attention_(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) { FlashAttention::execute(out, q, k, v, total_kv_len, scale, is_causal); } } // namespace infinicore::op diff --git a/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc b/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc index b714744f0..f5207f0ee 100644 --- 
a/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc +++ b/src/infinicore/ops/flash_attention/flash_attention_infiniop.cc @@ -11,18 +11,17 @@ INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, FlashAttention, 100); struct PlannedMeta { std::shared_ptr descriptor; - graph::GraphTensor workspace, out, q, k, v; - std::size_t total_kv_len; + graph::GraphTensor workspace, out, q, k, v, total_kv_len; float scale; bool is_causal; }; -void *plan(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, float scale, bool is_causal) { +void *plan(Tensor out, const Tensor &q, const Tensor &k, const Tensor &v, const Tensor &total_kv_len, float scale, bool is_causal) { size_t seed = hash_combine(out, q, k, v, total_kv_len, scale, is_causal); INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( Descriptor, descriptor, FlashAttention, - seed, out->desc(), q->desc(), k->desc(), v->desc(), total_kv_len, scale, is_causal); + seed, out->desc(), q->desc(), k->desc(), v->desc(), total_kv_len->desc(), scale, is_causal); INFINIOP_WORKSPACE_TENSOR(workspace, FlashAttention, descriptor); @@ -33,7 +32,7 @@ void *plan(Tensor out, Tensor q, Tensor k, Tensor v, std::size_t total_kv_len, f graph::GraphTensor(q), graph::GraphTensor(k), graph::GraphTensor(v), - total_kv_len, scale, is_causal}; + graph::GraphTensor(total_kv_len), scale, is_causal}; return planned; } @@ -43,7 +42,7 @@ void run(void *planned_meta) { INFINICORE_CHECK_ERROR(infiniopFlashAttention( planned->descriptor->desc, planned->workspace->data(), planned->workspace->numel(), - planned->out->data(), planned->q->data(), planned->k->data(), planned->v->data(), context::getStream())); + planned->out->data(), planned->q->data(), planned->k->data(), planned->v->data(), planned->total_kv_len->data(), context::getStream())); } void cleanup(void **planned_meta_ptr) { diff --git a/src/infiniop/ops/flash_attention/ninetoothed/build.py b/src/infiniop/ops/flash_attention/ninetoothed/build.py index dfcce6910..fc36eec6a 100644 --- 
a/src/infiniop/ops/flash_attention/ninetoothed/build.py +++ b/src/infiniop/ops/flash_attention/ninetoothed/build.py @@ -1,6 +1,6 @@ import ninetoothed -from ntops.kernels import scaled_dot_product_attention -from ntops.kernels.scaled_dot_product_attention import CausalVariant +from . import flash_attention +from .flash_attention import CausalVariant import infiniop.ninetoothed.build @@ -11,7 +11,7 @@ def build(): is_causal_values = (0, 1) with_attn_mask_values = (0,) causal_variant_values = (CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT) - dtype_values = (ninetoothed.float16, ninetoothed.float32) + dtype_values = (ninetoothed.float16, ninetoothed.bfloat16, ninetoothed.float32) block_size_m_values = (64,) block_size_n_values = (64,) @@ -27,7 +27,7 @@ def build(): } infiniop.ninetoothed.build.build( - scaled_dot_product_attention.premake, + flash_attention.premake, constexpr_param_grid, caller="cuda", op_name="flash_attention", diff --git a/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h b/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h index f39d9d045..d47a347e1 100644 --- a/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h +++ b/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h @@ -17,7 +17,7 @@ class Descriptor final : public InfiniopDescriptor { infiniopTensorDescriptor_t q_desc, infiniopTensorDescriptor_t k_desc, infiniopTensorDescriptor_t v_desc, - std::size_t total_kv_len, + infiniopTensorDescriptor_t total_kv_len, double scale, char is_causal) : InfiniopDescriptor{handle->device, handle->device_id}, _query_shape{q_desc->shape()}, @@ -26,12 +26,12 @@ class Descriptor final : public InfiniopDescriptor { _key_strides{k_desc->strides()}, _value_shape{v_desc->shape()}, _value_strides{v_desc->strides()}, + _total_kv_shape{total_kv_len->shape()}, + _total_kv_strides{total_kv_len->strides()}, _output_strides{out_desc->strides()}, _dtype{q_desc->dtype()}, _scale{scale}, _is_causal{is_causal} { - _key_shape[_key_shape.size() - 
2] = total_kv_len; - _value_shape[_key_shape.size() - 2] = total_kv_len; } ~Descriptor() = default; @@ -46,6 +46,7 @@ class Descriptor final : public InfiniopDescriptor { const void *q, const void *k, const void *v, + const void *total_kv_len, void *stream) const { uint64_t empty_shape[4]; int64_t empty_strides[4]; @@ -53,6 +54,7 @@ class Descriptor final : public InfiniopDescriptor { auto query{::ninetoothed::Tensor{q, _query_shape, _query_strides}}; auto key{::ninetoothed::Tensor{k, _key_shape, _key_strides}}; auto value{::ninetoothed::Tensor{v, _value_shape, _value_strides}}; + auto total_kv_length{::ninetoothed::Tensor{total_kv_len, _total_kv_shape, _total_kv_strides}}; NineToothedTensor attn_mask{nullptr, empty_shape, empty_strides}; NineToothedTensor is_causal; @@ -75,6 +77,7 @@ class Descriptor final : public InfiniopDescriptor { query, key, value, + total_kv_length, attn_mask, is_causal, scale, @@ -101,7 +104,7 @@ class Descriptor final : public InfiniopDescriptor { infiniopTensorDescriptor_t q_desc, infiniopTensorDescriptor_t k_desc, infiniopTensorDescriptor_t v_desc, - std::size_t total_kv_len, + infiniopTensorDescriptor_t total_kv_len, double scale, char is_causal) { *desc = new Descriptor{handle, out_desc, q_desc, k_desc, v_desc, total_kv_len, scale, is_causal}; @@ -126,6 +129,10 @@ class Descriptor final : public InfiniopDescriptor { std::vector _value_strides; + std::vector _total_kv_shape; + + std::vector _total_kv_strides; + std::vector _output_strides; infiniDtype_t _dtype; diff --git a/src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py b/src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py new file mode 100644 index 000000000..965408408 --- /dev/null +++ b/src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py @@ -0,0 +1,281 @@ +import enum +import functools + +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Tensor + +BLOCK_SIZE_M = ninetoothed.block_size() +BLOCK_SIZE_N = 
ninetoothed.block_size() + + +class CausalVariant(enum.IntEnum): + """Please refer to ``_.""" + + UPPER_LEFT = enum.auto() + + LOWER_RIGHT = enum.auto() + + +def arrangement( + query, + key, + value, + total_kv_len, + present_key, + present_value, + present_key_slot, + present_value_slot, + attn_mask, + is_causal, + scale, + output, + with_attn_mask, + causal_variant, + with_kv_cache, + block_size_m=None, + block_size_n=None, +): + def arrange_query_or_output(input): + arranged = input.tile((1, 1, block_size_m, -1)).tile( + (1, query.shape[-3] // key.shape[-3], 1, 1) + ) + arranged.dtype = arranged.dtype.squeeze((0, 2, 3)) + arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1)) + + return arranged + + def arrange_key_or_value(input): + arranged = ( + input.tile((1, 1, block_size_n, -1)) + .tile((1, 1, -1, -1)) + .expand((-1, -1, query_arranged.shape[-2], -1)) + ) + arranged.dtype = arranged.dtype.squeeze((0, 1, 3)) + arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1)) + + return arranged + + def arrange_total_kv_len(input, shape): + arranged = input.tile((1,)) + arranged = arranged.unsqueeze(1).unsqueeze(2).unsqueeze(3).expand(shape) + return arranged + + def arrange_present_key_or_present_value(input): + arranged = input.tile((1, 1, block_size_m, block_size_n)) + arranged.dtype = arranged.dtype.squeeze((0, 1)) + + return arranged + + def arrange_attn_mask(input): + arranged = input.tile((1, 1, block_size_m, block_size_n)).tile((1, 1, 1, -1)) + arranged.dtype = arranged.dtype.squeeze((0, 1, 2)) + arranged.dtype.dtype = arranged.dtype.dtype.squeeze((0, 1)) + + return arranged + + if block_size_m is None: + block_size_m = BLOCK_SIZE_M + + if block_size_n is None: + block_size_n = BLOCK_SIZE_N + + query_arranged = arrange_query_or_output(query) + key_arranged = arrange_key_or_value(key) + value_arranged = arrange_key_or_value(value) + total_kv_len_arranged = arrange_total_kv_len(total_kv_len, query_arranged.shape) + present_key_arranged = 
arrange_present_key_or_present_value(present_key) + present_value_arranged = arrange_present_key_or_present_value(present_value) + present_key_slot_arranged = arrange_present_key_or_present_value(present_key_slot) + present_value_slot_arranged = arrange_present_key_or_present_value( + present_value_slot + ) + attn_mask_arranged = arrange_attn_mask(attn_mask) + is_causal_arranged = is_causal + scale_arranged = scale + output_arranged = arrange_query_or_output(output) + with_attn_mask_arranged = with_attn_mask + causal_variant_arranged = causal_variant + + if with_kv_cache: + return ( + query_arranged, + key_arranged, + value_arranged, + total_kv_len_arranged, + present_key_arranged, + present_value_arranged, + present_key_slot_arranged, + present_value_slot_arranged, + attn_mask_arranged, + is_causal_arranged, + scale_arranged, + output_arranged, + with_attn_mask_arranged, + causal_variant_arranged, + ) + + return ( + query_arranged, + key_arranged, + value_arranged, + total_kv_len_arranged, + attn_mask_arranged, + is_causal_arranged, + scale_arranged, + output_arranged, + with_attn_mask_arranged, + causal_variant_arranged, + ) + + +def application_with_kv_cache( + query, + key, + value, + total_kv_len, + present_key, + present_value, + present_key_slot, + present_value_slot, + attn_mask, + is_causal, + scale, + output, + with_attn_mask, + causal_variant, +): + present_key_slot = present_key # noqa: F841 + present_value_slot = present_value # noqa: F841 + + application_without_kv_cache( + query, + key, + value, + total_kv_len, + attn_mask, + is_causal, + scale, + output, + with_attn_mask, + causal_variant, + ) + + +def application_without_kv_cache( + query, + key, + value, + total_kv_len, + attn_mask, + is_causal, + scale, + output, + with_attn_mask, + causal_variant, +): + actual_kv_len = total_kv_len[0] + + for i in range(query.shape[0]): + query_i = (1.4426950408889634 * scale * query[i]).to(query[i].dtype) + + acc = ntl.zeros((query_i.shape[-2], 
query_i.shape[-1]), dtype=ntl.float32) + lse = ntl.full((query_i.shape[-2],), 1, dtype=ntl.float32) + max = ntl.full((query_i.shape[-2],), float("-inf"), dtype=ntl.float32) + + for j in range(min(key.shape[0], actual_kv_len)): + + qk = ntl.dot(query_i, ntl.trans(key[j])) + + key_pos = key[j].offsets(-2) + qk = ntl.where(key_pos < actual_kv_len, qk, float("-inf")) + + if with_attn_mask: + qk += attn_mask[j] + + if is_causal: + query_pos = query[i].offsets(-2) + + if causal_variant == 2: + mask = ( + query_pos[:, None] + actual_kv_len - query.source.shape[-2] + >= key_pos[None, :] + ) + else: + mask = query_pos[:, None] >= key_pos[None, :] + + qk = ntl.where(mask, qk, float("-inf")) + + next_max = ntl.maximum(max, ntl.max(qk, 1)) + stable_qk = ntl.exp2(qk - next_max[:, None]) + + alpha = ntl.exp2(max - next_max) + acc = acc * alpha[:, None] + ntl.dot(stable_qk.to(value[i].dtype), value[j]) + max = next_max + lse = lse * alpha + ntl.sum(stable_qk, 1) + + acc /= lse[:, None] + output[i] = acc # noqa: F841 + + +def premake( + with_kv_cache, + emb_dim=None, + is_causal=None, + with_attn_mask=None, + causal_variant=None, + dtype=None, + block_size_m=None, + block_size_n=None, +): + arrangement_ = functools.partial( + arrangement, + with_kv_cache=with_kv_cache, + block_size_m=block_size_m, + block_size_n=block_size_n, + ) + + query, key, value, attn_mask, output = ( + Tensor( + 4, + dtype=dtype, + shape_options=(None, None, None, {"constexpr": True, "upper_bound": 128}), + ) + for _ in range(5) + ) + total_kv_len = Tensor(1, dtype=ninetoothed.int32) + present_key, present_value, present_key_slot, present_value_slot = ( + Tensor(4, dtype=dtype) for _ in range(4) + ) + scale = Tensor(0, dtype=ninetoothed.float64) + is_causal = Tensor(0, constexpr=True, value=is_causal) + with_attn_mask = Tensor(0, constexpr=True, value=with_attn_mask) + causal_variant = Tensor(0, constexpr=True, value=causal_variant) + + if emb_dim is not None: + for tensor in (query, key, value, attn_mask, 
output): + tensor.shape = tensor.shape[:-1] + (emb_dim,) + + if with_kv_cache: + application = application_with_kv_cache + else: + application = application_without_kv_cache + + tensors = ( + query, + key, + value, + total_kv_len, + present_key, + present_value, + present_key_slot, + present_value_slot, + attn_mask, + is_causal, + scale, + output, + with_attn_mask, + causal_variant, + ) + + return arrangement_, application, tensors diff --git a/src/infiniop/ops/flash_attention/operator.cc b/src/infiniop/ops/flash_attention/operator.cc index 6ce530fd3..ddccf9836 100644 --- a/src/infiniop/ops/flash_attention/operator.cc +++ b/src/infiniop/ops/flash_attention/operator.cc @@ -5,11 +5,9 @@ #ifdef ENABLE_CPU_API // #include "cpu/flash_attention_cpu.h" #endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) -#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API) +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_METAX_API) #include "ninetoothed/descriptor.h" -#else -// #include "nvidia/flash_attention_nvidia.cuh" #endif #endif @@ -20,7 +18,7 @@ __C infiniStatus_t infiniopCreateFlashAttentionDescriptor( infiniopTensorDescriptor_t q_desc, infiniopTensorDescriptor_t k_desc, infiniopTensorDescriptor_t v_desc, - std::size_t total_kv_len, + infiniopTensorDescriptor_t total_kv_len, float scale, char is_causal) { @@ -39,14 +37,9 @@ __C infiniStatus_t infiniopCreateFlashAttentionDescriptor( switch (handle->device) { -#ifdef ENABLE_CPU_API - // CREATE(INFINI_DEVICE_CPU, cpu); -#endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) -#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API) +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) CREATE(INFINI_DEVICE_NVIDIA, ninetoothed); -#else - // CREATE(INFINI_DEVICE_NVIDIA, nvidia); #endif #endif default: @@ 
-66,14 +59,10 @@ __C infiniStatus_t infiniopGetFlashAttentionWorkspaceSize( return INFINI_STATUS_SUCCESS; switch (desc->device_type) { -#ifdef ENABLE_CPU_API - // GET_SIZE(INFINI_DEVICE_CPU, cpu); -#endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) -#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API) + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) GET_SIZE(INFINI_DEVICE_NVIDIA, ninetoothed); -#else - // GET_SIZE(INFINI_DEVICE_NVIDIA, nvidia); #endif #endif default: @@ -91,23 +80,19 @@ __C infiniStatus_t infiniopFlashAttention( const void *q, const void *k, const void *v, + const void *total_kv_len, void *stream) { #define CALCULATE(CASE, NAMESPACE) \ case CASE: \ return reinterpret_cast(desc) \ - ->calculate(workspace, workspace_size, out, q, k, v, stream); + ->calculate(workspace, workspace_size, out, q, k, v, total_kv_len, stream); switch (desc->device_type) { -#ifdef ENABLE_CPU_API - // CALCULATE(INFINI_DEVICE_CPU, cpu); -#endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) -#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API) +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) CALCULATE(INFINI_DEVICE_NVIDIA, ninetoothed); -#else - // CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); #endif #endif default: @@ -126,14 +111,10 @@ __C infiniStatus_t infiniopDestroyFlashAttentionDescriptor( return INFINI_STATUS_SUCCESS; switch (desc->device_type) { -#ifdef ENABLE_CPU_API - // DESTROY(INFINI_DEVICE_CPU, cpu); -#endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) -#if defined(ENABLE_NINETOOTHED) && defined(ENABLE_NVIDIA_API) + +#if defined(ENABLE_NINETOOTHED) +#if defined(ENABLE_NVIDIA_API) DESTROY(INFINI_DEVICE_NVIDIA, ninetoothed); -#else - // DESTROY(INFINI_DEVICE_NVIDIA, nvidia); #endif #endif default: 
diff --git a/test/infinicore/ops/flash_attention.py b/test/infinicore/ops/flash_attention.py new file mode 100644 index 000000000..2d4b09599 --- /dev/null +++ b/test/infinicore/ops/flash_attention.py @@ -0,0 +1,115 @@ +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +import infinicore +from framework import ( + BaseOperatorTest, + TensorSpec, + TensorInitializer, + TestCase, + GenericTestRunner, +) + +# Test cases format: (q_shape, k_shape, v_shape, attn_mask_or_None, dropout_p, is_causal) +# q/k/v typically have shape (..., seq_len, head_dim) or (batch, seq_len, num_heads, head_dim) + +_TEST_CASES_DATA = [ + ((1, 1, 2, 16), (1, 1, 8, 16), (1, 1, 8, 16), None, 0.0, False), + ((1, 2, 128, 16), (1, 2, 256, 16), (1, 2, 256, 16), None, 0.0, False), + ((1, 1, 4, 32), (1, 1, 32, 32), (1, 1, 32, 32), None, 0.0, True), + ((1, 8, 256, 16), (1, 8, 512, 16), (1, 8, 512, 16), None, 0.0, True), + ((1, 8, 4, 16), (1, 8, 64, 16), (1, 8, 64, 16), None, 0.0, False), + ((8, 28, 256, 128), (8, 28, 512, 128), (8, 28, 512, 128), None, 0.0, True), +] + +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-2, "rtol": 1e-2}, + infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2}, + infinicore.float32: {"atol": 1e-3, "rtol": 1e-3}, +} +_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] + + +def parse_test_cases(): + import random + + cases = [] + for q_shape, k_shape, v_shape, attn_mask, dropout_p, is_causal in _TEST_CASES_DATA: + for dtype in _TENSOR_DTYPES: + tol = _TOLERANCE_MAP[dtype] + q_spec = TensorSpec.from_tensor(q_shape, None, dtype) + k_spec = TensorSpec.from_tensor(k_shape, None, dtype) + v_spec = TensorSpec.from_tensor(v_shape, None, dtype) + + len_shape = (q_shape[0],) + total_len = random.randint(1, k_shape[2]) + total_kv_len_spec = TensorSpec.from_tensor( + len_shape, + None, + infinicore.int64, + init_mode=TensorInitializer.RANDINT, + low=total_len, + high=total_len + 1, + ) + + kwargs = { 
+ "attn_mask": attn_mask, + "dropout_p": dropout_p, + "is_causal": is_causal, + } + # remove None keys + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + cases.append( + TestCase( + inputs=[q_spec, k_spec, v_spec, total_kv_len_spec, total_len], + kwargs=kwargs, + output_spec=None, + comparison_target=None, + tolerance=tol, + description="Flash Attention", + ) + ) + + return cases + + +def torch_flash_attn(q, k, v, total_kv_len, cheat, **kwargs): + k_slice = k[:, :, :cheat, :] + v_slice = v[:, :, :cheat, :] + return torch.nn.functional.scaled_dot_product_attention( + q, k_slice, v_slice, **kwargs + ) + + +def infini_flash_attn(q, k, v, total_kv_len, cheat, **kwargs): + return infinicore.nn.functional.flash_attention(q, k, v, total_kv_len, **kwargs) + + +class OpTest(BaseOperatorTest): + """ScaledDotProductAttention operator test with simplified implementation""" + + def __init__(self): + super().__init__("ScaledDotProductAttention") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, **kwargs): + return torch_flash_attn(*args, **kwargs) + + def infinicore_operator(self, *args, **kwargs): + return infini_flash_attn(*args, **kwargs) + + +def main(): + """Main entry point""" + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() From 79d142f8290ddc760f2763691153a020e7ff7be9 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Thu, 22 Jan 2026 09:51:30 +0800 Subject: [PATCH 11/14] issue/889 - cxflags to cxxflags in compile scripts --- xmake.lua | 12 ++++++------ xmake/ascend.lua | 6 +++--- xmake/bang.lua | 6 +++--- xmake/cpu.lua | 10 +++++----- xmake/hygon.lua | 8 ++++---- xmake/iluvatar.lua | 8 ++++---- xmake/kunlun.lua | 6 +++--- xmake/metax.lua | 8 ++++---- xmake/moore.lua | 6 +++--- xmake/nvidia.lua | 4 ++-- xmake/qy.lua | 4 ++-- xmake/test.lua | 2 +- 12 files changed, 40 insertions(+), 40 deletions(-) diff --git a/xmake.lua b/xmake.lua index d5a4ba7f7..a493b6a46 100644 
--- a/xmake.lua +++ b/xmake.lua @@ -19,7 +19,7 @@ end if is_plat("windows") then set_runtimes("MD") add_ldflags("/utf-8", {force = true}) - add_cxflags("/utf-8", {force = true}) + add_cxxflags("/utf-8", {force = true}) end -- CPU @@ -218,14 +218,14 @@ target("infini-utils") set_warnings("all", "error") if is_plat("windows") then - add_cxflags("/wd4068") + add_cxxflags("/wd4068") if has_config("omp") then - add_cxflags("/openmp") + add_cxxflags("/openmp") end else - add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cxxflags("-fPIC", "-Wno-unknown-pragmas") if has_config("omp") then - add_cxflags("-fopenmp") + add_cxxflags("-fopenmp") add_ldflags("-fopenmp", {force = true}) end end @@ -269,7 +269,7 @@ target("infinirt") end set_languages("cxx17") if not is_plat("windows") then - add_cxflags("-fPIC") + add_cxxflags("-fPIC") end set_installdir(os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")) add_files("src/infinirt/*.cc") diff --git a/xmake/ascend.lua b/xmake/ascend.lua index 6a28979b4..5d9e12b67 100644 --- a/xmake/ascend.lua +++ b/xmake/ascend.lua @@ -43,7 +43,7 @@ target("infiniop-ascend") add_deps("infini-utils") on_install(function (target) end) - add_cxflags("-lstdc++ -fPIC") + add_cxxflags("-lstdc++ -fPIC") set_warnings("all", "error") set_languages("cxx17") @@ -61,7 +61,7 @@ target("infinirt-ascend") add_deps("infini-utils") -- Add files add_files("$(projectdir)/src/infinirt/ascend/*.cc") - add_cxflags("-lstdc++ -Wall -Werror -fPIC") + add_cxxflags("-lstdc++ -Wall -Werror -fPIC") target_end() target("infiniccl-ascend") @@ -75,6 +75,6 @@ target("infiniccl-ascend") add_includedirs(ASCEND_HOME .. 
"/include/hccl") add_links("libhccl.so") add_files("../src/infiniccl/ascend/*.cc") - add_cxflags("-lstdc++ -fPIC") + add_cxxflags("-lstdc++ -fPIC") end target_end() diff --git a/xmake/bang.lua b/xmake/bang.lua index d2195acd5..e841874ec 100644 --- a/xmake/bang.lua +++ b/xmake/bang.lua @@ -40,7 +40,7 @@ target("infiniop-cambricon") add_deps("infini-utils") on_install(function (target) end) - add_cxflags("-lstdc++ -fPIC") + add_cxxflags("-lstdc++ -fPIC") set_warnings("all", "error") set_languages("cxx17") @@ -58,7 +58,7 @@ target("infinirt-cambricon") on_install(function (target) end) -- Add include dirs add_files("../src/infinirt/bang/*.cc") - add_cxflags("-lstdc++ -Wall -Werror -fPIC") + add_cxxflags("-lstdc++ -Wall -Werror -fPIC") target_end() target("infiniccl-cambricon") @@ -88,7 +88,7 @@ target("infiniccl-cambricon") add_runenvs("LD_LIBRARY_PATH", NEUWARE_HOME .. "/lib64") add_files("../src/infiniccl/cambricon/*.cc") - add_cxflags("-fPIC") + add_cxxflags("-fPIC") add_ldflags("-fPIC") else print("[Warning] CNCL is currently only supported on Linux") diff --git a/xmake/cpu.lua b/xmake/cpu.lua index 22dc8f8e7..a042243e0 100644 --- a/xmake/cpu.lua +++ b/xmake/cpu.lua @@ -6,14 +6,14 @@ target("infiniop-cpu") set_warnings("all", "error") if is_plat("windows") then - add_cxflags("/wd4068") + add_cxxflags("/wd4068") if has_config("omp") then - add_cxflags("/openmp") + add_cxxflags("/openmp") end else - add_cxflags("-fPIC", "-Wno-unknown-pragmas") + add_cxxflags("-fPIC", "-Wno-unknown-pragmas") if has_config("omp") then - add_cxflags("-fopenmp") + add_cxxflags("-fopenmp") add_ldflags("-fopenmp") end end @@ -31,7 +31,7 @@ target("infinirt-cpu") set_warnings("all", "error") if not is_plat("windows") then - add_cxflags("-fPIC") + add_cxxflags("-fPIC") end set_languages("cxx17") diff --git a/xmake/hygon.lua b/xmake/hygon.lua index ed4b91f0e..22a226b0f 100644 --- a/xmake/hygon.lua +++ b/xmake/hygon.lua @@ -59,7 +59,7 @@ target("infiniop-hygon") 
add_cuflags("-Wno-return-type", {force = true}) -- 抑制return语句警告 add_cuflags("-fPIC", "-std=c++17", {force = true}) add_culdflags("-fPIC") - add_cxflags("-fPIC") + add_cxxflags("-fPIC") -- 添加海光DCU特定的编译标志 add_cuflags("-arch=gfx906", "-arch=gfx926", "-arch=gfx928", "-arch=gfx936") @@ -76,7 +76,7 @@ target("infiniop-hygon") add_files("../src/infiniop/ops/swiglu/nvidia/*.cu") if has_config("ninetoothed") then - add_files("../build/ninetoothed/*.c", {cxflags = {"-Wno-return-type"}}) + add_files("../build/ninetoothed/*.c", {cxxflags = {"-Wno-return-type"}}) end target_end() @@ -104,7 +104,7 @@ target("infinirt-hygon") add_cuflags("-Wno-return-type", {force = true}) -- 抑制return语句警告 add_cuflags("-fPIC", "-std=c++17", {force = true}) add_culdflags("-fPIC") - add_cxflags("-fPIC") + add_cxxflags("-fPIC") -- 添加海光DCU特定的编译标志 add_cuflags("-arch=gfx906", "-arch=gfx926", "-arch=gfx928", "-arch=gfx936") @@ -137,7 +137,7 @@ target("infiniccl-hygon") add_cuflags("-Wno-return-type", {force = true}) -- 抑制return语句警告 add_cuflags("-fPIC", "-std=c++17", {force = true}) add_culdflags("-fPIC") - add_cxflags("-fPIC") + add_cxxflags("-fPIC") -- 添加海光DCU特定的编译标志 add_cuflags("-arch=gfx906", "-arch=gfx926", "-arch=gfx928", "-arch=gfx936") diff --git a/xmake/iluvatar.lua b/xmake/iluvatar.lua index 35ccf2154..90e39983e 100644 --- a/xmake/iluvatar.lua +++ b/xmake/iluvatar.lua @@ -45,7 +45,7 @@ target("infiniop-iluvatar") add_cuflags("-Wno-error=unused-private-field") add_cuflags("-fPIC", "-x", "ivcore", "-std=c++17", {force = true}) add_culdflags("-fPIC") - add_cxflags("-fPIC") + add_cxxflags("-fPIC") -- set_languages("cxx17") 天数似乎不能用这个配置 add_files("../src/infiniop/devices/nvidia/*.cu", "../src/infiniop/ops/*/nvidia/*.cu") @@ -54,7 +54,7 @@ target("infiniop-iluvatar") add_files("../src/infiniop/ops/dequantize_awq/iluvatar/*.cu") if has_config("ninetoothed") then - add_files("../build/ninetoothed/*.c", {cxflags = {"-Wno-return-type"}}) + add_files("../build/ninetoothed/*.c", {cxxflags = 
{"-Wno-return-type"}}) end target_end() @@ -72,7 +72,7 @@ target("infinirt-iluvatar") set_warnings("all", "error") add_cuflags("-fPIC", "-x", "ivcore", "-std=c++17", {force = true}) add_culdflags("-fPIC") - add_cxflags("-fPIC") + add_cxxflags("-fPIC") -- set_languages("cxx17") 天数似乎不能用这个配置 add_files("../src/infinirt/cuda/*.cu") @@ -93,7 +93,7 @@ target("infiniccl-iluvatar") set_warnings("all", "error") add_cuflags("-fPIC", "-x", "ivcore", "-std=c++17", {force = true}) add_culdflags("-fPIC") - add_cxflags("-fPIC") + add_cxxflags("-fPIC") local nccl_root = os.getenv("NCCL_ROOT") if nccl_root then diff --git a/xmake/kunlun.lua b/xmake/kunlun.lua index 185082b3c..ed82ac808 100644 --- a/xmake/kunlun.lua +++ b/xmake/kunlun.lua @@ -74,7 +74,7 @@ target("infiniop-kunlun") add_deps("infini-utils") on_install(function (target) end) - add_cxflags("-lstdc++ -fPIC -Wno-error=unused-function") + add_cxxflags("-lstdc++ -fPIC -Wno-error=unused-function") set_warnings("all", "error") set_languages("cxx17") @@ -101,7 +101,7 @@ target("infinirt-kunlun") on_install(function (target) end) -- Add include dirs add_files("$(projectdir)/src/infinirt/kunlun/*.cc") - add_cxflags("-lstdc++ -Wall -Werror -fPIC") + add_cxxflags("-lstdc++ -Wall -Werror -fPIC") target_end() target("infiniccl-kunlun") @@ -116,6 +116,6 @@ target("infiniccl-kunlun") add_linkdirs(path.join(XCCL_DIR, "so")) add_links("bkcl") add_files("$(projectdir)/src/infiniccl/kunlun/*.cc") - add_cxflags("-lstdc++ -fPIC") + add_cxxflags("-lstdc++ -fPIC") end target_end() diff --git a/xmake/metax.lua b/xmake/metax.lua index 5561b45db..121a21e78 100644 --- a/xmake/metax.lua +++ b/xmake/metax.lua @@ -47,12 +47,12 @@ target("infiniop-metax") on_install(function (target) end) set_languages("cxx17") set_warnings("all", "error") - add_cxflags("-lstdc++", "-fPIC", "-Wno-defaulted-function-deleted", "-Wno-strict-aliasing", {force = true}) + add_cxxflags("-lstdc++", "-fPIC", "-Wno-defaulted-function-deleted", "-Wno-strict-aliasing", {force = 
true}) add_files("../src/infiniop/devices/metax/*.cc", "../src/infiniop/ops/*/metax/*.cc") add_files("../src/infiniop/ops/*/metax/*.maca", {rule = "maca"}) if has_config("ninetoothed") then - add_files("../build/ninetoothed/*.c", {cxflags = {"-include stdlib.h", "-Wno-return-type"}}) + add_files("../build/ninetoothed/*.c", {cxxflags = {"-include stdlib.h", "-Wno-return-type"}}) end target_end() @@ -62,7 +62,7 @@ target("infinirt-metax") on_install(function (target) end) add_deps("infini-utils") set_warnings("all", "error") - add_cxflags("-lstdc++ -fPIC") + add_cxxflags("-lstdc++ -fPIC") add_files("../src/infinirt/metax/*.cc") target_end() @@ -72,7 +72,7 @@ target("infiniccl-metax") on_install(function (target) end) set_warnings("all", "error") if not is_plat("windows") then - add_cxflags("-fPIC") + add_cxxflags("-fPIC") end if has_config("ccl") then if has_config("use-mc") then diff --git a/xmake/moore.lua b/xmake/moore.lua index 25eddf522..5ad5387b6 100644 --- a/xmake/moore.lua +++ b/xmake/moore.lua @@ -41,7 +41,7 @@ target("infiniop-moore") on_install(function (target) end) set_languages("cxx17") set_warnings("all", "error") - add_cxflags("-lstdc++", "-fPIC", "-Wno-comment") + add_cxxflags("-lstdc++", "-fPIC", "-Wno-comment") add_files("../src/infiniop/devices/moore/*.cc") add_files("../src/infiniop/ops/*/moore/*.mu", {rule = "mu"}) @@ -55,7 +55,7 @@ target("infinirt-moore") on_install(function (target) end) add_deps("infini-utils") set_warnings("all", "error") - add_cxflags("-lstdc++", "-fPIC") + add_cxxflags("-lstdc++", "-fPIC") add_files("../src/infinirt/moore/*.cc") target_end() @@ -65,7 +65,7 @@ target("infiniccl-moore") on_install(function (target) end) set_warnings("all", "error") if not is_plat("windows") then - add_cxflags("-fPIC") + add_cxxflags("-fPIC") end if has_config("ccl") then add_links("libmccl.so") diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua index 75086b8a1..934ac967c 100644 --- a/xmake/nvidia.lua +++ b/xmake/nvidia.lua @@ -92,7 +92,7 @@ 
target("infinirt-nvidia") else add_cuflags("-Xcompiler=-fPIC") add_culdflags("-Xcompiler=-fPIC") - add_cxflags("-fPIC") + add_cxxflags("-fPIC") end set_languages("cxx17") @@ -111,7 +111,7 @@ target("infiniccl-nvidia") if not is_plat("windows") then add_cuflags("-Xcompiler=-fPIC") add_culdflags("-Xcompiler=-fPIC") - add_cxflags("-fPIC") + add_cxxflags("-fPIC") local nccl_root = os.getenv("NCCL_ROOT") if nccl_root then diff --git a/xmake/qy.lua b/xmake/qy.lua index ecef359a8..896135d13 100644 --- a/xmake/qy.lua +++ b/xmake/qy.lua @@ -116,7 +116,7 @@ target("infinirt-qy") else add_cuflags("-Xcompiler=-fPIC") add_culdflags("-Xcompiler=-fPIC") - add_cxflags("-fPIC") + add_cxxflags("-fPIC") end set_languages("cxx17") @@ -132,7 +132,7 @@ target("infiniccl-qy") if not is_plat("windows") then add_cuflags("-Xcompiler=-fPIC") add_culdflags("-Xcompiler=-fPIC") - add_cxflags("-fPIC") + add_cxxflags("-fPIC") local nccl_root = os.getenv("NCCL_ROOT") if nccl_root then diff --git a/xmake/test.lua b/xmake/test.lua index 002083e1d..56dca6e5f 100644 --- a/xmake/test.lua +++ b/xmake/test.lua @@ -24,7 +24,7 @@ target("infiniop-test") add_links("infiniop", "infinirt") if has_config("omp") then - add_cxflags("-fopenmp") + add_cxxflags("-fopenmp") add_ldflags("-fopenmp") end From 4201ea720bd7ae9885ff5a798fbaa700d3c54a2b Mon Sep 17 00:00:00 2001 From: wooway777 Date: Thu, 22 Jan 2026 09:49:41 +0000 Subject: [PATCH 12/14] issue/889 - optimize flash attention performance from default setup --- src/infiniop/ops/flash_attention/ninetoothed/build.py | 2 +- src/infiniop/ops/flash_attention/ninetoothed/descriptor.h | 4 ++-- .../ops/flash_attention/ninetoothed/flash_attention.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/infiniop/ops/flash_attention/ninetoothed/build.py b/src/infiniop/ops/flash_attention/ninetoothed/build.py index fc36eec6a..4455e1ea6 100644 --- a/src/infiniop/ops/flash_attention/ninetoothed/build.py +++ 
b/src/infiniop/ops/flash_attention/ninetoothed/build.py @@ -12,7 +12,7 @@ def build(): with_attn_mask_values = (0,) causal_variant_values = (CausalVariant.UPPER_LEFT, CausalVariant.LOWER_RIGHT) dtype_values = (ninetoothed.float16, ninetoothed.bfloat16, ninetoothed.float32) - block_size_m_values = (64,) + block_size_m_values = (256,) block_size_n_values = (64,) constexpr_param_grid = { diff --git a/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h b/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h index d47a347e1..0a6e9c1f8 100644 --- a/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h +++ b/src/infiniop/ops/flash_attention/ninetoothed/descriptor.h @@ -67,10 +67,10 @@ class Descriptor final : public InfiniopDescriptor { const auto emb_dim_{_query_shape[3]}; const auto is_causal_{_is_causal}; const auto with_attn_mask_{0}; - const auto causal_variant_{1}; + const auto causal_variant_{2}; const auto dtype_{_dtype}; - constexpr auto block_size_m_{64}; + constexpr auto block_size_m_{256}; constexpr auto block_size_n_{64}; if (launch_flash_attention(stream, diff --git a/src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py b/src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py index 965408408..66598fe60 100644 --- a/src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py +++ b/src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py @@ -183,7 +183,7 @@ def application_without_kv_cache( lse = ntl.full((query_i.shape[-2],), 1, dtype=ntl.float32) max = ntl.full((query_i.shape[-2],), float("-inf"), dtype=ntl.float32) - for j in range(min(key.shape[0], actual_kv_len)): + for j in range(key.shape[0]): qk = ntl.dot(query_i, ntl.trans(key[j])) @@ -196,7 +196,7 @@ def application_without_kv_cache( if is_causal: query_pos = query[i].offsets(-2) - if causal_variant == 2: + if causal_variant == 2: # CausalVariant.LOWER_RIGHT: mask = ( query_pos[:, None] + actual_kv_len - query.source.shape[-2] >= key_pos[None, :] From 
aeb0e1f41bd54e3a300134481bbad204f09aba03 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Thu, 22 Jan 2026 20:28:03 +0800 Subject: [PATCH 13/14] issue/889 - optimize flash attention performance from kernel --- src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py b/src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py index 66598fe60..22d63ae4a 100644 --- a/src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py +++ b/src/infiniop/ops/flash_attention/ninetoothed/flash_attention.py @@ -183,7 +183,7 @@ def application_without_kv_cache( lse = ntl.full((query_i.shape[-2],), 1, dtype=ntl.float32) max = ntl.full((query_i.shape[-2],), float("-inf"), dtype=ntl.float32) - for j in range(key.shape[0]): + for j in range(-(-actual_kv_len // key.dtype.shape[0])): qk = ntl.dot(query_i, ntl.trans(key[j])) From ed7d8430a50cfe0776b182c24c0f5ef9d626030b Mon Sep 17 00:00:00 2001 From: wooway777 Date: Fri, 23 Jan 2026 21:11:59 +0800 Subject: [PATCH 14/14] issue/889 - add cxflags for fpic --- xmake.lua | 2 ++ xmake/ascend.lua | 3 +++ xmake/bang.lua | 3 +++ xmake/cpu.lua | 2 ++ xmake/hygon.lua | 3 +++ xmake/iluvatar.lua | 3 +++ xmake/kunlun.lua | 3 +++ xmake/metax.lua | 3 +++ xmake/moore.lua | 3 +++ xmake/nvidia.lua | 3 +++ xmake/qy.lua | 3 +++ 11 files changed, 31 insertions(+) diff --git a/xmake.lua b/xmake.lua index a493b6a46..6070711a6 100644 --- a/xmake.lua +++ b/xmake.lua @@ -223,6 +223,7 @@ target("infini-utils") add_cxxflags("/openmp") end else + add_cxflags("-fPIC", "-Wno-unknown-pragmas") add_cxxflags("-fPIC", "-Wno-unknown-pragmas") if has_config("omp") then add_cxxflags("-fopenmp") @@ -269,6 +270,7 @@ target("infinirt") end set_languages("cxx17") if not is_plat("windows") then + add_cxflags("-fPIC") add_cxxflags("-fPIC") end set_installdir(os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") 
.. "/.infini")) diff --git a/xmake/ascend.lua b/xmake/ascend.lua index 5d9e12b67..e51626d1d 100644 --- a/xmake/ascend.lua +++ b/xmake/ascend.lua @@ -43,6 +43,7 @@ target("infiniop-ascend") add_deps("infini-utils") on_install(function (target) end) + add_cxflags("-lstdc++ -fPIC") add_cxxflags("-lstdc++ -fPIC") set_warnings("all", "error") @@ -61,6 +62,7 @@ target("infinirt-ascend") add_deps("infini-utils") -- Add files add_files("$(projectdir)/src/infinirt/ascend/*.cc") + add_cxflags("-lstdc++ -Wall -Werror -fPIC") add_cxxflags("-lstdc++ -Wall -Werror -fPIC") target_end() @@ -75,6 +77,7 @@ target("infiniccl-ascend") add_includedirs(ASCEND_HOME .. "/include/hccl") add_links("libhccl.so") add_files("../src/infiniccl/ascend/*.cc") + add_cxflags("-lstdc++ -fPIC") add_cxxflags("-lstdc++ -fPIC") end target_end() diff --git a/xmake/bang.lua b/xmake/bang.lua index e841874ec..ffa85ef6d 100644 --- a/xmake/bang.lua +++ b/xmake/bang.lua @@ -40,6 +40,7 @@ target("infiniop-cambricon") add_deps("infini-utils") on_install(function (target) end) + add_cxflags("-lstdc++ -fPIC") add_cxxflags("-lstdc++ -fPIC") set_warnings("all", "error") @@ -58,6 +59,7 @@ target("infinirt-cambricon") on_install(function (target) end) -- Add include dirs add_files("../src/infinirt/bang/*.cc") + add_cxflags("-lstdc++ -Wall -Werror -fPIC") add_cxxflags("-lstdc++ -Wall -Werror -fPIC") target_end() @@ -88,6 +90,7 @@ target("infiniccl-cambricon") add_runenvs("LD_LIBRARY_PATH", NEUWARE_HOME .. 
"/lib64") add_files("../src/infiniccl/cambricon/*.cc") + add_cxflags("-fPIC") add_cxxflags("-fPIC") add_ldflags("-fPIC") else diff --git a/xmake/cpu.lua b/xmake/cpu.lua index a042243e0..e192fbbbd 100644 --- a/xmake/cpu.lua +++ b/xmake/cpu.lua @@ -11,6 +11,7 @@ target("infiniop-cpu") add_cxxflags("/openmp") end else + add_cxflags("-fPIC", "-Wno-unknown-pragmas") add_cxxflags("-fPIC", "-Wno-unknown-pragmas") if has_config("omp") then add_cxxflags("-fopenmp") @@ -31,6 +32,7 @@ target("infinirt-cpu") set_warnings("all", "error") if not is_plat("windows") then + add_cxflags("-fPIC") add_cxxflags("-fPIC") end diff --git a/xmake/hygon.lua b/xmake/hygon.lua index 22a226b0f..05d3e8356 100644 --- a/xmake/hygon.lua +++ b/xmake/hygon.lua @@ -59,6 +59,7 @@ target("infiniop-hygon") add_cuflags("-Wno-return-type", {force = true}) -- 抑制return语句警告 add_cuflags("-fPIC", "-std=c++17", {force = true}) add_culdflags("-fPIC") + add_cxflags("-fPIC") add_cxxflags("-fPIC") -- 添加海光DCU特定的编译标志 @@ -104,6 +105,7 @@ target("infinirt-hygon") add_cuflags("-Wno-return-type", {force = true}) -- 抑制return语句警告 add_cuflags("-fPIC", "-std=c++17", {force = true}) add_culdflags("-fPIC") + add_cxflags("-fPIC") add_cxxflags("-fPIC") -- 添加海光DCU特定的编译标志 @@ -137,6 +139,7 @@ target("infiniccl-hygon") add_cuflags("-Wno-return-type", {force = true}) -- 抑制return语句警告 add_cuflags("-fPIC", "-std=c++17", {force = true}) add_culdflags("-fPIC") + add_cxflags("-fPIC") add_cxxflags("-fPIC") -- 添加海光DCU特定的编译标志 diff --git a/xmake/iluvatar.lua b/xmake/iluvatar.lua index 90e39983e..b4ba792fa 100644 --- a/xmake/iluvatar.lua +++ b/xmake/iluvatar.lua @@ -45,6 +45,7 @@ target("infiniop-iluvatar") add_cuflags("-Wno-error=unused-private-field") add_cuflags("-fPIC", "-x", "ivcore", "-std=c++17", {force = true}) add_culdflags("-fPIC") + add_cxflags("-fPIC") add_cxxflags("-fPIC") -- set_languages("cxx17") 天数似乎不能用这个配置 @@ -72,6 +73,7 @@ target("infinirt-iluvatar") set_warnings("all", "error") add_cuflags("-fPIC", "-x", "ivcore", 
"-std=c++17", {force = true}) add_culdflags("-fPIC") + add_cxflags("-fPIC") add_cxxflags("-fPIC") -- set_languages("cxx17") 天数似乎不能用这个配置 @@ -93,6 +95,7 @@ target("infiniccl-iluvatar") set_warnings("all", "error") add_cuflags("-fPIC", "-x", "ivcore", "-std=c++17", {force = true}) add_culdflags("-fPIC") + add_cxflags("-fPIC") add_cxxflags("-fPIC") local nccl_root = os.getenv("NCCL_ROOT") diff --git a/xmake/kunlun.lua b/xmake/kunlun.lua index ed82ac808..84ba14082 100644 --- a/xmake/kunlun.lua +++ b/xmake/kunlun.lua @@ -74,6 +74,7 @@ target("infiniop-kunlun") add_deps("infini-utils") on_install(function (target) end) + add_cxflags("-lstdc++ -fPIC -Wno-error=unused-function") add_cxxflags("-lstdc++ -fPIC -Wno-error=unused-function") set_warnings("all", "error") @@ -101,6 +102,7 @@ target("infinirt-kunlun") on_install(function (target) end) -- Add include dirs add_files("$(projectdir)/src/infinirt/kunlun/*.cc") + add_cxflags("-lstdc++ -Wall -Werror -fPIC") add_cxxflags("-lstdc++ -Wall -Werror -fPIC") target_end() @@ -116,6 +118,7 @@ target("infiniccl-kunlun") add_linkdirs(path.join(XCCL_DIR, "so")) add_links("bkcl") add_files("$(projectdir)/src/infiniccl/kunlun/*.cc") + add_cxflags("-lstdc++ -fPIC") add_cxxflags("-lstdc++ -fPIC") end target_end() diff --git a/xmake/metax.lua b/xmake/metax.lua index 121a21e78..91672abe0 100644 --- a/xmake/metax.lua +++ b/xmake/metax.lua @@ -47,6 +47,7 @@ target("infiniop-metax") on_install(function (target) end) set_languages("cxx17") set_warnings("all", "error") + add_cxflags("-lstdc++", "-fPIC", "-Wno-defaulted-function-deleted", "-Wno-strict-aliasing", {force = true}) add_cxxflags("-lstdc++", "-fPIC", "-Wno-defaulted-function-deleted", "-Wno-strict-aliasing", {force = true}) add_files("../src/infiniop/devices/metax/*.cc", "../src/infiniop/ops/*/metax/*.cc") add_files("../src/infiniop/ops/*/metax/*.maca", {rule = "maca"}) @@ -62,6 +63,7 @@ target("infinirt-metax") on_install(function (target) end) add_deps("infini-utils") 
set_warnings("all", "error") + add_cxflags("-lstdc++ -fPIC") add_cxxflags("-lstdc++ -fPIC") add_files("../src/infinirt/metax/*.cc") target_end() @@ -72,6 +74,7 @@ target("infiniccl-metax") on_install(function (target) end) set_warnings("all", "error") if not is_plat("windows") then + add_cxflags("-fPIC") add_cxxflags("-fPIC") end if has_config("ccl") then diff --git a/xmake/moore.lua b/xmake/moore.lua index 5ad5387b6..fdcad9564 100644 --- a/xmake/moore.lua +++ b/xmake/moore.lua @@ -41,6 +41,7 @@ target("infiniop-moore") on_install(function (target) end) set_languages("cxx17") set_warnings("all", "error") + add_cxflags("-lstdc++", "-fPIC", "-Wno-comment") add_cxxflags("-lstdc++", "-fPIC", "-Wno-comment") add_files("../src/infiniop/devices/moore/*.cc") add_files("../src/infiniop/ops/*/moore/*.mu", {rule = "mu"}) @@ -55,6 +56,7 @@ target("infinirt-moore") on_install(function (target) end) add_deps("infini-utils") set_warnings("all", "error") + add_cxflags("-lstdc++", "-fPIC") add_cxxflags("-lstdc++", "-fPIC") add_files("../src/infinirt/moore/*.cc") target_end() @@ -65,6 +67,7 @@ target("infiniccl-moore") on_install(function (target) end) set_warnings("all", "error") if not is_plat("windows") then + add_cxflags("-fPIC") add_cxxflags("-fPIC") end if has_config("ccl") then diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua index 934ac967c..5752dfefe 100644 --- a/xmake/nvidia.lua +++ b/xmake/nvidia.lua @@ -48,6 +48,7 @@ target("infiniop-nvidia") add_cuflags("-Xcompiler=-fPIC") add_cuflags("--extended-lambda") add_culdflags("-Xcompiler=-fPIC") + add_cxflags("-fPIC") add_cxxflags("-fPIC") add_cflags("-fPIC") add_cuflags("--expt-relaxed-constexpr") @@ -92,6 +93,7 @@ target("infinirt-nvidia") else add_cuflags("-Xcompiler=-fPIC") add_culdflags("-Xcompiler=-fPIC") + add_cxflags("-fPIC") add_cxxflags("-fPIC") end @@ -111,6 +113,7 @@ target("infiniccl-nvidia") if not is_plat("windows") then add_cuflags("-Xcompiler=-fPIC") add_culdflags("-Xcompiler=-fPIC") + add_cxflags("-fPIC") 
add_cxxflags("-fPIC") local nccl_root = os.getenv("NCCL_ROOT") diff --git a/xmake/qy.lua b/xmake/qy.lua index 896135d13..1defe8763 100644 --- a/xmake/qy.lua +++ b/xmake/qy.lua @@ -88,6 +88,7 @@ target("infiniop-qy") add_cuflags("-Xcompiler=-fPIC") add_cuflags("--extended-lambda") add_culdflags("-Xcompiler=-fPIC") + add_cxflags("-fPIC") add_cxxflags("-fPIC") add_cuflags("--expt-relaxed-constexpr") if CUDNN_ROOT ~= nil then @@ -116,6 +117,7 @@ target("infinirt-qy") else add_cuflags("-Xcompiler=-fPIC") add_culdflags("-Xcompiler=-fPIC") + add_cxflags("-fPIC") add_cxxflags("-fPIC") end @@ -132,6 +134,7 @@ target("infiniccl-qy") if not is_plat("windows") then add_cuflags("-Xcompiler=-fPIC") add_culdflags("-Xcompiler=-fPIC") + add_cxflags("-fPIC") add_cxxflags("-fPIC") local nccl_root = os.getenv("NCCL_ROOT")