diff --git a/third_party/xllm_ops b/third_party/xllm_ops index 79eb46cb95..409d1dbb3b 160000 --- a/third_party/xllm_ops +++ b/third_party/xllm_ops @@ -1 +1 @@ -Subproject commit 79eb46cb951346acc12bfbd8b9b9170bc8a83db9 +Subproject commit 409d1dbb3b2c5c04b280fbe9e6c95a861f2f603a diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h index 683617fb3e..1a72e380d5 100644 --- a/xllm/core/framework/model/model_input_params.h +++ b/xllm/core/framework/model/model_input_params.h @@ -420,6 +420,12 @@ struct ModelInputParams { safe_to(graph_buffer.attn_mask, device, true); params.graph_buffer.tiling_data = safe_to(graph_buffer.tiling_data, device, true); + params.graph_buffer.xfa_q_cu_seq_lens = + safe_to(graph_buffer.xfa_q_cu_seq_lens, device, true); + params.graph_buffer.xfa_extra_tiling = + safe_to(graph_buffer.xfa_extra_tiling, device, true); + params.graph_buffer.xfa_attn_mask = + safe_to(graph_buffer.xfa_attn_mask, device, true); // params for flashinfer params.paged_kv_indptr = safe_to(paged_kv_indptr, device); @@ -666,6 +672,9 @@ struct ModelInputParams { struct GraphBuffer { torch::Tensor attn_mask; torch::Tensor tiling_data; + torch::Tensor xfa_q_cu_seq_lens; + torch::Tensor xfa_extra_tiling; + torch::Tensor xfa_attn_mask; bool use_expanded_decode_for_spec_verify_attention = false; torch::Tensor expanded_kv_seq_lens; torch::Tensor expanded_block_tables; diff --git a/xllm/core/kernels/npu/xllm_ops/CMakeLists.txt b/xllm/core/kernels/npu/xllm_ops/CMakeLists.txt index f8e9cda1bb..f4d17d4cc7 100644 --- a/xllm/core/kernels/npu/xllm_ops/CMakeLists.txt +++ b/xllm/core/kernels/npu/xllm_ops/CMakeLists.txt @@ -13,6 +13,7 @@ cc_library( beam_search.cpp select_unshared_kv.cpp beam_search_rec.cpp + x_flash_attention_infer.cpp DEPS atb torch_npu diff --git a/xllm/core/kernels/npu/xllm_ops/beam_search_rec.cpp b/xllm/core/kernels/npu/xllm_ops/beam_search_rec.cpp index cec49c2142..2768d08339 100644 --- a/xllm/core/kernels/npu/xllm_ops/beam_search_rec.cpp +++ b/xllm/core/kernels/npu/xllm_ops/beam_search_rec.cpp @@ -83,6 +83,9 @@ void beam_search_rec(const torch::Tensor& logprobs, create_acltensor(&out_beam_count_prefix_sums_ids, out_beam_count_prefix_sums); create_acltensor(&out_sequence_ids, out_sequence); + CHECK_GT(top_tokens.dim(), 0) + << "beam_search_rec: top_tokens must have at least one dimension"; + const int64_t top_k = top_tokens.size(-1); uint64_t workspace_size = 0; aclOpExecutor* executor = nullptr; CHECK_ACL_SUCCESS( @@ -91,6 +94,7 @@ void beam_search_rec(const torch::Tensor& logprobs, top_logprobs_ids, sequence_group_ids, current_step, + top_k, out_token_ids_ids, out_token_index_ids, out_log_probs_ids, diff --git a/xllm/core/kernels/npu/xllm_ops/x_flash_attention_infer.cpp b/xllm/core/kernels/npu/xllm_ops/x_flash_attention_infer.cpp new file mode 100644 index 0000000000..8b9ee57fa7 --- /dev/null +++ b/xllm/core/kernels/npu/xllm_ops/x_flash_attention_infer.cpp @@ -0,0 +1,359 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "aclnn_x_flash_attention_infer.h" +#include "core/kernels/npu/aclnn/pytorch_npu_helper.hpp" +#include "core/kernels/npu/utils.h" +#include "xllm_ops_api.h" + +namespace { + +constexpr uint32_t kMaxExtraInfoNodes = 25; +constexpr uint32_t kMaxKvStackLen = 512; + +struct CoreNode { + uint32_t start_b_idx = 0; + uint32_t start_n1_idx = 0; + uint32_t start_s2_idx = 0; + uint32_t end_b_idx = 0; + uint32_t end_n1_idx = 0; + uint32_t end_s2_idx = 0; + uint64_t first_split_kv_task_lse_offset = 0; + uint64_t first_split_kv_task_o_offset = 0; +}; + +struct SplitNode { + uint32_t batch_idx = 0; + uint32_t head_start_idx = 0; + uint32_t head_end_idx = 0; + uint32_t q_start_idx = 0; + uint32_t q_end_idx = 0; + uint32_t split_num = 0; + uint64_t lse_task_offset = 0; + uint64_t o_task_offset = 0; +}; + +struct SplitKvExtraInfo { + CoreNode core_info[kMaxExtraInfoNodes]; + SplitNode split_info[kMaxExtraInfoNodes]; + uint32_t total_split_node_num = 0; +}; + +torch::Tensor make_extra_tiling(const torch::Tensor& actual_q_lens, + const torch::Tensor& actual_kv_lens, + int64_t q_head, + int64_t kv_head, + int64_t block_size, + int64_t head_size) { + CHECK_EQ(actual_q_lens.dim(), 1) << "actual_q_lens must be 1-D"; + CHECK_EQ(actual_kv_lens.dim(), 1) << "actual_kv_lens must be 1-D"; + CHECK_EQ(actual_q_lens.numel(), actual_kv_lens.numel()) + << "actual_q_lens and actual_kv_lens must have the same length"; + CHECK_GT(actual_q_lens.numel(), 0) << "batch size must be positive"; + CHECK_GT(q_head, 0) << "q_head must be positive"; + CHECK_GT(kv_head, 0) << "kv_head must be positive"; + CHECK_EQ(q_head % kv_head, 0) << "q_head must be divisible by kv_head"; + CHECK_GT(block_size, 0) << "block_size must be positive"; + CHECK_GT(head_size, 0) << "head_size must be positive"; + + auto q_lens_cpu = + actual_q_lens.to(torch::kCPU, torch::kInt32, /*non_blocking=*/false) + .contiguous(); + auto kv_lens_cpu = + actual_kv_lens.to(torch::kCPU, torch::kInt32, /*non_blocking=*/false) + .contiguous(); + const auto* q_lens = q_lens_cpu.data_ptr(); + const auto* kv_lens = kv_lens_cpu.data_ptr(); + const int64_t batch = actual_q_lens.numel(); + + SplitKvExtraInfo extra_info{}; + for (uint32_t i = 0; i < kMaxExtraInfoNodes; ++i) { + extra_info.core_info[i].start_b_idx = std::numeric_limits::max(); + } + + const uint32_t block_stack_num = + std::max(1, kMaxKvStackLen / static_cast(block_size)); + const uint32_t group_size = static_cast(q_head / kv_head); + uint32_t core_idx = 0; + uint32_t split_node_idx = 0; + uint64_t lse_offset = 0; + uint64_t o_offset = 0; + + for (int64_t batch_idx = 0; batch_idx < batch; ++batch_idx) { + const uint32_t q_end = static_cast(q_lens[batch_idx]); + const uint32_t q_start = + batch_idx == 0 ? 0 : static_cast(q_lens[batch_idx - 1]); + const uint32_t q_len = q_end - q_start; + const uint32_t kv_seq_len = static_cast(kv_lens[batch_idx]); + const uint32_t kv_blocks = + (kv_seq_len + static_cast(block_size) - 1) / + static_cast(block_size); + const uint32_t s2_blocks = + (kv_blocks + block_stack_num - 1) / block_stack_num; + + if (s2_blocks <= 1) { + CHECK_LT(core_idx, kMaxExtraInfoNodes) + << "x_flash_attention_infer extra_tiling core_info overflow"; + auto& core = extra_info.core_info[core_idx++]; + core.start_b_idx = static_cast(batch_idx); + core.start_n1_idx = 0; + core.start_s2_idx = 0; + core.end_b_idx = static_cast(batch_idx); + core.end_n1_idx = static_cast(kv_head - 1); + core.end_s2_idx = s2_blocks; + continue; + } + + // For long-KV requests, keep each KV head on one core and scan all S2 + // blocks inside that core. This avoids the split-KV combine path, whose + // fixed coreInfo/splitInfo capacity is too small for long Qwen3.5 chunks. + for (uint32_t kv_head_idx = 0; kv_head_idx < static_cast(kv_head); + ++kv_head_idx) { + CHECK_LT(core_idx, kMaxExtraInfoNodes) + << "x_flash_attention_infer extra_tiling core_info overflow"; + auto& core = extra_info.core_info[core_idx++]; + core.start_b_idx = static_cast(batch_idx); + core.start_n1_idx = kv_head_idx; + core.start_s2_idx = 0; + core.end_b_idx = static_cast(batch_idx); + core.end_n1_idx = kv_head_idx; + core.end_s2_idx = s2_blocks; + } + } + extra_info.total_split_node_num = split_node_idx; + + const int64_t int32_count = + (sizeof(SplitKvExtraInfo) + sizeof(int32_t) - 1) / sizeof(int32_t); + auto extra_tiling_cpu = + torch::empty({int32_count}, torch::TensorOptions().dtype(torch::kInt32)); + std::memcpy(extra_tiling_cpu.data_ptr(), + &extra_info, + sizeof(SplitKvExtraInfo)); + return extra_tiling_cpu.to(actual_q_lens.device(), /*non_blocking=*/false); +} + +torch::Tensor make_extra_tiling_from_host_lens( + const std::vector& actual_q_lens, + const std::vector& actual_kv_lens, + int64_t q_head, + int64_t kv_head, + int64_t block_size, + int64_t head_size) { + CHECK_EQ(actual_q_lens.size(), actual_kv_lens.size()) + << "actual_q_lens and actual_kv_lens must have the same length"; + CHECK_GT(actual_q_lens.size(), 0) << "batch size must be positive"; + CHECK_GT(q_head, 0) << "q_head must be positive"; + CHECK_GT(kv_head, 0) << "kv_head must be positive"; + CHECK_EQ(q_head % kv_head, 0) << "q_head must be divisible by kv_head"; + CHECK_GT(block_size, 0) << "block_size must be positive"; + CHECK_GT(head_size, 0) << "head_size must be positive"; + + SplitKvExtraInfo extra_info{}; + for (uint32_t i = 0; i < kMaxExtraInfoNodes; ++i) { + extra_info.core_info[i].start_b_idx = std::numeric_limits::max(); + } + + const uint32_t block_stack_num = + std::max(1, kMaxKvStackLen / static_cast(block_size)); + uint32_t core_idx = 0; + uint32_t split_node_idx = 0; + + for (int64_t batch_idx = 0; + batch_idx < static_cast(actual_q_lens.size()); + ++batch_idx) { + const uint32_t kv_seq_len = + static_cast(actual_kv_lens[batch_idx]); + const uint32_t kv_blocks = + (kv_seq_len + static_cast(block_size) - 1) / + static_cast(block_size); + const uint32_t s2_blocks = + (kv_blocks + block_stack_num - 1) / block_stack_num; + + if (s2_blocks <= 1) { + CHECK_LT(core_idx, kMaxExtraInfoNodes) + << "x_flash_attention_infer extra_tiling core_info overflow"; + auto& core = extra_info.core_info[core_idx++]; + core.start_b_idx = static_cast(batch_idx); + core.start_n1_idx = 0; + core.start_s2_idx = 0; + core.end_b_idx = static_cast(batch_idx); + core.end_n1_idx = static_cast(kv_head - 1); + core.end_s2_idx = s2_blocks; + continue; + } + + for (uint32_t kv_head_idx = 0; kv_head_idx < static_cast(kv_head); + ++kv_head_idx) { + CHECK_LT(core_idx, kMaxExtraInfoNodes) + << "x_flash_attention_infer extra_tiling core_info overflow"; + auto& core = extra_info.core_info[core_idx++]; + core.start_b_idx = static_cast(batch_idx); + core.start_n1_idx = kv_head_idx; + core.start_s2_idx = 0; + core.end_b_idx = static_cast(batch_idx); + core.end_n1_idx = kv_head_idx; + core.end_s2_idx = s2_blocks; + } + } + extra_info.total_split_node_num = split_node_idx; + + const int64_t int32_count = + (sizeof(SplitKvExtraInfo) + sizeof(int32_t) - 1) / sizeof(int32_t); + auto extra_tiling_cpu = + torch::empty({int32_count}, torch::TensorOptions().dtype(torch::kInt32)); + std::memcpy(extra_tiling_cpu.data_ptr(), + &extra_info, + sizeof(SplitKvExtraInfo)); + return extra_tiling_cpu; +} + +} // namespace + +namespace xllm::kernel::npu { + +torch::Tensor x_flash_attention_infer(const torch::Tensor& query, + const torch::Tensor& key_cache, + const torch::Tensor& value_cache, + const torch::Tensor& mask, + const torch::Tensor& block_table, + const torch::Tensor& actual_q_lens, + const torch::Tensor& actual_kv_lens, + int64_t q_head, + int64_t kv_head, + double scale, + const std::string& layout) { + check_tensor(query, "query", "x_flash_attention_infer"); + check_tensor(key_cache, "key_cache", "x_flash_attention_infer"); + check_tensor(value_cache, "value_cache", "x_flash_attention_infer"); + check_tensor(mask, "mask", "x_flash_attention_infer"); + check_tensor(block_table, "block_table", "x_flash_attention_infer"); + check_tensor(actual_q_lens, "actual_q_lens", "x_flash_attention_infer"); + check_tensor(actual_kv_lens, "actual_kv_lens", "x_flash_attention_infer"); + CHECK_EQ(query.dim(), 3) << "query must be [tokens, q_heads, head_size]"; + CHECK_EQ(key_cache.dim(), 4) + << "key_cache must be [blocks, block_size, kv_heads, head_size]"; + CHECK_EQ(value_cache.dim(), 4) + << "value_cache must be [blocks, block_size, kv_heads, head_size]"; + CHECK_EQ(actual_q_lens.scalar_type(), torch::kInt32); + CHECK_EQ(actual_kv_lens.scalar_type(), torch::kInt32); + + torch::Tensor extra_tiling = make_extra_tiling(actual_q_lens, + actual_kv_lens, + q_head, + kv_head, + key_cache.size(1), + key_cache.size(3)); + return x_flash_attention_infer_with_extra_tiling(query, + key_cache, + value_cache, + mask, + block_table, + actual_q_lens, + actual_kv_lens, + extra_tiling, + q_head, + kv_head, + scale, + layout); +} + +torch::Tensor x_flash_attention_infer_with_extra_tiling( + const torch::Tensor& query, + const torch::Tensor& key_cache, + const torch::Tensor& value_cache, + const torch::Tensor& mask, + const torch::Tensor& block_table, + const torch::Tensor& actual_q_lens, + const torch::Tensor& actual_kv_lens, + const torch::Tensor& extra_tiling, + int64_t q_head, + int64_t kv_head, + double scale, + const std::string& layout) { + check_tensor(query, "query", "x_flash_attention_infer"); + check_tensor(key_cache, "key_cache", "x_flash_attention_infer"); + check_tensor(value_cache, "value_cache", "x_flash_attention_infer"); + check_tensor(mask, "mask", "x_flash_attention_infer"); + check_tensor(block_table, "block_table", "x_flash_attention_infer"); + check_tensor(actual_q_lens, "actual_q_lens", "x_flash_attention_infer"); + check_tensor(actual_kv_lens, "actual_kv_lens", "x_flash_attention_infer"); + check_tensor(extra_tiling, "extra_tiling", "x_flash_attention_infer"); + CHECK_EQ(query.dim(), 3) << "query must be [tokens, q_heads, head_size]"; + CHECK_EQ(key_cache.dim(), 4) + << "key_cache must be [blocks, block_size, kv_heads, head_size]"; + CHECK_EQ(value_cache.dim(), 4) + << "value_cache must be [blocks, block_size, kv_heads, head_size]"; + CHECK_EQ(actual_q_lens.scalar_type(), torch::kInt32); + CHECK_EQ(actual_kv_lens.scalar_type(), torch::kInt32); + CHECK_EQ(extra_tiling.scalar_type(), torch::kInt32); + CHECK_GE(extra_tiling.numel(), x_flash_attention_extra_tiling_int32_count()); + + std::string layout_attr = layout; + char* layout_attr_ptr = const_cast(layout_attr.c_str()); + torch::Tensor output = torch::empty(query.sizes(), query.options()); + + EXEC_NPU_CMD(aclnnXFlashAttentionInfer, + query, + key_cache, + value_cache, + mask, + block_table, + actual_q_lens, + actual_kv_lens, + extra_tiling, + layout_attr_ptr, + q_head, + kv_head, + scale, + output); + return output; +} + +int64_t x_flash_attention_extra_tiling_int32_count() { + return (sizeof(SplitKvExtraInfo) + sizeof(int32_t) - 1) / sizeof(int32_t); +} + +int64_t x_flash_attention_decode_group_size(int64_t kv_head) { + CHECK_GT(kv_head, 0); + return kMaxExtraInfoNodes / kv_head; +} + +void update_x_flash_attention_extra_tiling( + const std::vector& actual_q_lens, + const std::vector& actual_kv_lens, + int64_t q_head, + int64_t kv_head, + int64_t block_size, + int64_t head_size, + torch::Tensor& extra_tiling) { + auto next_extra_tiling = make_extra_tiling_from_host_lens( + actual_q_lens, actual_kv_lens, q_head, kv_head, block_size, head_size); + CHECK_GE(extra_tiling.numel(), next_extra_tiling.numel()) + << "x_flash_attention_infer persistent extra_tiling is too small"; + extra_tiling.slice(/*dim=*/0, /*start=*/0, /*end=*/next_extra_tiling.numel()) + .copy_(next_extra_tiling, /*non_blocking=*/false); +} + +} // namespace xllm::kernel::npu diff --git a/xllm/core/kernels/npu/xllm_ops/xllm_ops_api.h b/xllm/core/kernels/npu/xllm_ops/xllm_ops_api.h index 4bb6527b40..ae07eb61a5 100644 --- a/xllm/core/kernels/npu/xllm_ops/xllm_ops_api.h +++ b/xllm/core/kernels/npu/xllm_ops/xllm_ops_api.h @@ -53,4 +53,43 @@ void select_unshared_kv(const torch::Tensor& beam_index, int64_t decode_step, int64_t beam_size, int64_t layer_num); + +torch::Tensor x_flash_attention_infer(const torch::Tensor& query, + const torch::Tensor& key_cache, + const torch::Tensor& value_cache, + const torch::Tensor& mask, + const torch::Tensor& block_table, + const torch::Tensor& actual_q_lens, + const torch::Tensor& actual_kv_lens, + int64_t q_head, + int64_t kv_head, + double scale, + const std::string& layout = "TND"); + +torch::Tensor x_flash_attention_infer_with_extra_tiling( + const torch::Tensor& query, + const torch::Tensor& key_cache, + const torch::Tensor& value_cache, + const torch::Tensor& mask, + const torch::Tensor& block_table, + const torch::Tensor& actual_q_lens, + const torch::Tensor& actual_kv_lens, + const torch::Tensor& extra_tiling, + int64_t q_head, + int64_t kv_head, + double scale, + const std::string& layout = "TND"); + +int64_t x_flash_attention_extra_tiling_int32_count(); + +int64_t x_flash_attention_decode_group_size(int64_t kv_head); + +void update_x_flash_attention_extra_tiling( + const std::vector& actual_q_lens, + const std::vector& actual_kv_lens, + int64_t q_head, + int64_t kv_head, + int64_t block_size, + int64_t head_size, + torch::Tensor& extra_tiling); } // namespace xllm::kernel::npu diff --git a/xllm/core/layers/common/attention_metadata.h b/xllm/core/layers/common/attention_metadata.h index 85e1d027ab..4c8fc61b44 100644 --- a/xllm/core/layers/common/attention_metadata.h +++ b/xllm/core/layers/common/attention_metadata.h @@ -162,6 +162,11 @@ struct AttentionMetadata { // For ACL graph execution - fixed-address device tiling data for // CustomPagedAttention replay. torch::Tensor paged_attention_tiling_data; + // For ACL graph execution with x_flash_attention_infer. The tensor address is + // fixed across replay and its contents are updated before graph replay. + torch::Tensor xfa_q_cu_seq_lens; + torch::Tensor xfa_extra_tiling; + torch::Tensor xfa_attn_mask; // Pre-computed attention mask for npu_fused_infer_attention. torch::Tensor fia_attn_mask; // Host vectors for npu_fused_infer_attention (kernel requires host memory). diff --git a/xllm/core/layers/common/attention_metadata_builder.cpp b/xllm/core/layers/common/attention_metadata_builder.cpp index 42a89e092c..216d67430e 100644 --- a/xllm/core/layers/common/attention_metadata_builder.cpp +++ b/xllm/core/layers/common/attention_metadata_builder.cpp @@ -89,6 +89,9 @@ AttentionMetadata build_attention_metadata( params.graph_buffer.expanded_block_tables; attn_metadata.expanded_paged_attention_tiling_data = params.graph_buffer.expanded_tiling_data; + attn_metadata.xfa_q_cu_seq_lens = params.graph_buffer.xfa_q_cu_seq_lens; + attn_metadata.xfa_extra_tiling = params.graph_buffer.xfa_extra_tiling; + attn_metadata.xfa_attn_mask = params.graph_buffer.xfa_attn_mask; if (!params.graph_buffer.expanded_kv_seq_lens_vec.empty()) { attn_metadata.expanded_kv_seq_lens_host = torch::tensor( params.graph_buffer.expanded_kv_seq_lens_vec, torch::kInt); @@ -106,6 +109,9 @@ AttentionMetadata build_attention_metadata( if (use_acl_graph) { // ACL graph mode: use CustomPagedAttention with tiling_data on device attn_metadata.paged_attention_tiling_data = params.graph_buffer.tiling_data; + attn_metadata.xfa_q_cu_seq_lens = params.graph_buffer.xfa_q_cu_seq_lens; + attn_metadata.xfa_extra_tiling = params.graph_buffer.xfa_extra_tiling; + attn_metadata.xfa_attn_mask = params.graph_buffer.xfa_attn_mask; } // Provide capture-time host seq_lens for NPU kernels. ACL graph replay must // rely on fixed-address device inputs such as tiling_data, not mutable host @@ -167,14 +173,17 @@ AttentionMetadata build_attention_metadata( params.batch_forward_type.is_chunked_prefill(); attn_metadata.is_prefill = params.batch_forward_type.is_prefill(); - // enable_mla is for DeepSeekv32 on mlu device +#if defined(USE_NPU) + attn_metadata.block_table = params.block_tables; +#else if (!attn_metadata.is_prefill || enable_mla) { attn_metadata.block_table = params.block_tables; -#if !defined(USE_NPU) && !defined(USE_CUDA) +#if !defined(USE_CUDA) attn_metadata.kv_seq_lens = torch::diff(params.kv_seq_lens); // kv seqlens attn_metadata.q_seq_lens = torch::diff(params.q_seq_lens); // q seqlens #endif } +#endif #if defined(USE_NPU) // NPU path uses per-sequence lengths (not cumulative), so no diff. // Ensure per-sequence lengths are available for NPU kernels in all phases. diff --git a/xllm/core/layers/npu_torch/attention.cpp b/xllm/core/layers/npu_torch/attention.cpp index 89ba076ea8..1b2ed98010 100644 --- a/xllm/core/layers/npu_torch/attention.cpp +++ b/xllm/core/layers/npu_torch/attention.cpp @@ -15,12 +15,403 @@ limitations under the License. #include "attention.h" +#include +#include +#include +#include +#include +#include +#include + #include "kernels/npu/npu_ops_api.h" +#include "kernels/npu/xllm_ops/xllm_ops_api.h" #include "kernels/ops_api.h" namespace xllm { namespace layer { +namespace { + +constexpr int64_t kFiaSplitFuseMaskSize = 2048; +constexpr int64_t kXfaMaxQkRowsPerCall = 128; +constexpr int64_t kXfaMaxExtraInfoNodes = 24; +constexpr int64_t kXfaMaxKvStackLen = 512; + +struct XfaQuerySlice { + int64_t seq_idx = 0; + int64_t q_start = 0; + int64_t q_len = 0; + int64_t kv_len = 0; + int64_t core_count = 1; +}; + +torch::Tensor int32_tensor_on_device(const std::vector& values, + const torch::Device& device) { + return torch::tensor(values, torch::TensorOptions().dtype(torch::kInt32)) + .to(device); +} + +torch::Tensor int64_tensor_on_device(const std::vector& values, + const torch::Device& device) { + return torch::tensor(values, torch::TensorOptions().dtype(torch::kInt64)) + .to(device); +} + +int64_t div_up(int64_t value, int64_t divisor) { + return (value + divisor - 1) / divisor; +} + +int64_t xfa_core_count_for_slice(int64_t kv_len, + int64_t kv_head, + int64_t block_size) { + const int64_t block_stack_num = + std::max(1, kXfaMaxKvStackLen / block_size); + const int64_t kv_blocks = div_up(kv_len, block_size); + const int64_t s2_blocks = div_up(kv_blocks, block_stack_num); + return s2_blocks <= 1 ? 1 : kv_head; +} + +torch::Tensor get_fia_split_fuse_attn_mask(const torch::Tensor& query) { + static std::mutex mutex; + static std::unordered_map mask_cache; + + const std::string cache_key = query.device().str(); + std::lock_guard lock(mutex); + auto it = mask_cache.find(cache_key); + if (it != mask_cache.end() && it->second.defined()) { + return it->second; + } + + auto cpu_options = torch::TensorOptions().dtype(torch::kFloat32); + auto mask = + torch::triu(torch::ones({kFiaSplitFuseMaskSize, kFiaSplitFuseMaskSize}, + cpu_options), + 1) + .to(torch::kInt8) + .to(query.device()) + .contiguous(); + mask_cache[cache_key] = mask; + return mask; +} + +bool env_flag_enabled(const char* name) { + const char* value = std::getenv(name); + return value != nullptr && std::string(value) == "1"; +} + +bool use_x_flash_decode() { + return !env_flag_enabled("XLLM_DISABLE_XFLASH_DECODE"); +} + +bool use_x_flash_prefill() { + return !env_flag_enabled("XLLM_DISABLE_XFLASH_PREFILL"); +} + +bool use_x_flash_prefill_grouping() { + return !env_flag_enabled("XLLM_DISABLE_XFLASH_PREFILL_GROUPING"); +} + +bool allow_long_kv_x_flash_prefill_grouping() { + return env_flag_enabled("XLLM_ENABLE_XFLASH_LONG_PREFILL_GROUPING"); +} + +bool can_group_x_flash_prefill_slice(const XfaQuerySlice& slice) { + // Long-KV multi-core grouping must stay behind an explicit validation flag. + return use_x_flash_prefill_grouping() && + (slice.core_count == 1 || allow_long_kv_x_flash_prefill_grouping()); +} + +torch::Tensor x_flash_attention_infer_with_query_split( + const torch::Tensor& query, + const torch::Tensor& key_cache, + const torch::Tensor& value_cache, + const torch::Tensor& mask, + const torch::Tensor& block_table, + const torch::Tensor& q_seq_lens, + const torch::Tensor& kv_seq_lens, + const torch::Tensor& q_seq_lens_host, + const torch::Tensor& kv_seq_lens_host, + int64_t q_head, + int64_t kv_head, + double scale) { + CHECK_GT(kv_head, 0); + CHECK_EQ(q_head % kv_head, 0); + const int64_t group_size = q_head / kv_head; + const int64_t max_q_len_per_call = + std::max(1, kXfaMaxQkRowsPerCall / group_size); + + auto q_seq_lens_cpu = + q_seq_lens_host.defined() + ? q_seq_lens_host + .to(torch::kCPU, torch::kInt32, /*non_blocking=*/false) + .contiguous() + : q_seq_lens.to(torch::kCPU, torch::kInt32, /*non_blocking=*/false) + .contiguous(); + auto kv_seq_lens_cpu = + kv_seq_lens_host.defined() + ? kv_seq_lens_host + .to(torch::kCPU, torch::kInt32, /*non_blocking=*/false) + .contiguous() + : kv_seq_lens.to(torch::kCPU, torch::kInt32, /*non_blocking=*/false) + .contiguous(); + CHECK_EQ(q_seq_lens_cpu.numel(), kv_seq_lens_cpu.numel()); + const auto* q_seq_lens_ptr = q_seq_lens_cpu.data_ptr(); + const auto* kv_seq_lens_ptr = kv_seq_lens_cpu.data_ptr(); + const int64_t batch = q_seq_lens_cpu.numel(); + + auto output = torch::empty_like(query); + int64_t q_offset = 0; + std::vector slices; + const auto device = query.device(); + for (int64_t seq_idx = 0; seq_idx < batch; ++seq_idx) { + const int64_t q_len = q_seq_lens_ptr[seq_idx]; + const int64_t kv_len = kv_seq_lens_ptr[seq_idx]; + const int64_t past_kv_len = kv_len - q_len; + CHECK_GE(past_kv_len, 0); + + for (int64_t q_start = 0; q_start < q_len; q_start += max_q_len_per_call) { + const int64_t sub_q_len = + std::min(max_q_len_per_call, q_len - q_start); + const int64_t sub_kv_len = past_kv_len + q_start + sub_q_len; + const int64_t core_count = + xfa_core_count_for_slice(sub_kv_len, kv_head, key_cache.size(1)); + CHECK_LE(core_count, kXfaMaxExtraInfoNodes) + << "x_flash_attention_infer query slice needs too many cores"; + slices.push_back( + {seq_idx, q_offset + q_start, sub_q_len, sub_kv_len, core_count}); + } + q_offset += q_len; + } + + std::vector consumed(slices.size(), false); + auto has_seq_in_group = [&](const std::vector& group_indices, + int64_t seq_idx) { + return std::any_of( + group_indices.begin(), group_indices.end(), [&](size_t idx) { + return slices[idx].seq_idx == seq_idx; + }); + }; + auto run_group = [&](const std::vector& group_indices) { + CHECK(!group_indices.empty()); + int64_t core_count = 0; + std::vector query_parts; + std::vector group_q_lens; + std::vector group_kv_lens; + std::vector block_indices; + query_parts.reserve(group_indices.size()); + group_q_lens.reserve(group_indices.size()); + group_kv_lens.reserve(group_indices.size()); + block_indices.reserve(group_indices.size()); + int32_t q_cu_len = 0; + for (size_t idx : group_indices) { + const auto& slice = slices[idx]; + core_count += slice.core_count; + query_parts.push_back(query.narrow(0, slice.q_start, slice.q_len)); + q_cu_len += static_cast(slice.q_len); + group_q_lens.push_back(q_cu_len); + group_kv_lens.push_back(static_cast(slice.kv_len)); + block_indices.push_back(slice.seq_idx); + } + auto group_query = torch::cat(query_parts, /*dim=*/0).contiguous(); + + auto group_block_table = + block_table + .index_select(0, int64_tensor_on_device(block_indices, device)) + .contiguous(); + auto group_q_lens_tensor = int32_tensor_on_device(group_q_lens, device); + auto group_kv_lens_tensor = int32_tensor_on_device(group_kv_lens, device); + + auto group_output = + xllm::kernel::npu::x_flash_attention_infer(group_query, + key_cache, + value_cache, + mask, + group_block_table, + group_q_lens_tensor, + group_kv_lens_tensor, + q_head, + kv_head, + scale, + "TND"); + int64_t group_output_offset = 0; + for (size_t idx : group_indices) { + const auto& slice = slices[idx]; + output.narrow(0, slice.q_start, slice.q_len) + .copy_(group_output.narrow(0, group_output_offset, slice.q_len)); + group_output_offset += slice.q_len; + } + }; + + for (size_t group_start = 0; group_start < slices.size();) { + while (group_start < slices.size() && consumed[group_start]) { + ++group_start; + } + if (group_start >= slices.size()) { + break; + } + + std::vector group_indices{group_start}; + int64_t core_count = slices[group_start].core_count; + const int64_t group_q_len = slices[group_start].q_len; + const int64_t group_kv_len = slices[group_start].kv_len; + const int64_t group_core_count = slices[group_start].core_count; + const bool groupable = can_group_x_flash_prefill_slice(slices[group_start]); + + if (groupable && group_core_count > 1) { + // Long-KV multi-core grouping is only safe across distinct requests. + for (size_t idx = group_start + 1; idx < slices.size(); ++idx) { + if (consumed[idx] || slices[idx].q_len != group_q_len || + slices[idx].kv_len != group_kv_len || + slices[idx].core_count != group_core_count || + has_seq_in_group(group_indices, slices[idx].seq_idx)) { + continue; + } + if (core_count + slices[idx].core_count > kXfaMaxExtraInfoNodes) { + break; + } + core_count += slices[idx].core_count; + group_indices.push_back(idx); + } + } else { + size_t group_end = group_start + 1; + while (groupable && group_end < slices.size() && !consumed[group_end] && + slices[group_end].q_len == group_q_len && + slices[group_end].core_count == group_core_count && + core_count + slices[group_end].core_count <= + kXfaMaxExtraInfoNodes) { + core_count += slices[group_end].core_count; + group_indices.push_back(group_end); + ++group_end; + } + } + + run_group(group_indices); + for (size_t idx : group_indices) { + consumed[idx] = true; + } + ++group_start; + } + return output; +} + +torch::Tensor x_flash_attention_infer_decode_with_batch_split( + const torch::Tensor& query, + const torch::Tensor& key_cache, + const torch::Tensor& value_cache, + const torch::Tensor& mask, + const torch::Tensor& block_table, + const torch::Tensor& kv_seq_lens, + int64_t q_head, + int64_t kv_head, + double scale) { + CHECK_EQ(query.dim(), 3) + << "decode query must be [batch, q_heads, head_size]"; + auto kv_seq_lens_cpu = + kv_seq_lens.to(torch::kCPU, torch::kInt32, /*non_blocking=*/false) + .contiguous(); + CHECK_EQ(query.size(0), kv_seq_lens_cpu.numel()); + const auto* kv_seq_lens_ptr = kv_seq_lens_cpu.data_ptr(); + const int64_t batch = query.size(0); + const auto device = query.device(); + auto output = torch::empty_like(query); + + for (int64_t group_start = 0; group_start < batch;) { + int64_t group_end = group_start; + int64_t core_count = 0; + while (group_end < batch) { + const int64_t kv_len = kv_seq_lens_ptr[group_end]; + const int64_t seq_core_count = + xfa_core_count_for_slice(kv_len, kv_head, key_cache.size(1)); + CHECK_LE(seq_core_count, kXfaMaxExtraInfoNodes) + << "x_flash_attention_infer decode request needs too many cores"; + if (group_end > group_start && + core_count + seq_core_count > kXfaMaxExtraInfoNodes) { + break; + } + core_count += seq_core_count; + ++group_end; + } + + const int64_t group_batch = group_end - group_start; + std::vector group_q_lens; + std::vector group_kv_lens; + group_q_lens.reserve(group_batch); + group_kv_lens.reserve(group_batch); + int32_t q_cu_len = 0; + for (int64_t idx = group_start; idx < group_end; ++idx) { + ++q_cu_len; + group_q_lens.push_back(q_cu_len); + group_kv_lens.push_back(kv_seq_lens_ptr[idx]); + } + + auto group_query = query.narrow(0, group_start, group_batch).contiguous(); + auto group_block_table = + block_table.narrow(0, group_start, group_batch).contiguous(); + auto group_q_lens_tensor = int32_tensor_on_device(group_q_lens, device); + auto group_kv_lens_tensor = int32_tensor_on_device(group_kv_lens, device); + + auto group_output = + xllm::kernel::npu::x_flash_attention_infer(group_query, + key_cache, + value_cache, + mask, + group_block_table, + group_q_lens_tensor, + group_kv_lens_tensor, + q_head, + kv_head, + scale, + "TND"); + output.narrow(0, group_start, group_batch).copy_(group_output); + group_start = group_end; + } + return output; +} + +torch::Tensor x_flash_attention_infer_decode_for_graph( + const torch::Tensor& query, + const torch::Tensor& key_cache, + const torch::Tensor& value_cache, + const torch::Tensor& mask, + const torch::Tensor& block_table, + const torch::Tensor& q_cu_seq_lens, + const torch::Tensor& kv_seq_lens, + const torch::Tensor& extra_tiling, + int64_t q_head, + int64_t kv_head, + double scale) { + CHECK_EQ(query.dim(), 3) + << "decode query must be [batch, q_heads, head_size]"; + CHECK(q_cu_seq_lens.defined()) + << "graph decode requires device q cumulative seq lens"; + CHECK(kv_seq_lens.defined()) << "graph decode requires device kv seq lens"; + CHECK(extra_tiling.defined()) << "graph decode requires fixed extra tiling"; + auto actual_q_lens = q_cu_seq_lens; + if (actual_q_lens.numel() == query.size(0) + 1) { + actual_q_lens = + actual_q_lens.narrow(/*dim=*/0, /*start=*/1, /*length=*/query.size(0)); + } + CHECK_EQ(actual_q_lens.numel(), query.size(0)); + CHECK_EQ(kv_seq_lens.numel(), query.size(0)); + CHECK_EQ(actual_q_lens.scalar_type(), torch::kInt32); + CHECK_EQ(kv_seq_lens.scalar_type(), torch::kInt32); + return xllm::kernel::npu::x_flash_attention_infer_with_extra_tiling( + query, + key_cache, + value_cache, + mask, + block_table, + actual_q_lens, + kv_seq_lens, + extra_tiling, + q_head, + kv_head, + scale, + "TND"); +} + +} // namespace + AttentionImpl::AttentionImpl(int64_t num_heads, int64_t head_size, float scale, @@ -87,50 +478,61 @@ void AttentionImpl::prefill_forward(torch::Tensor& query, query = query.view({-1, num_heads_, head_size_}); output = output.view({-1, num_heads_, head_size_}); - if (attn_metadata.is_prefill) { - key = key.view({-1, num_kv_heads_, head_size_}); - value = value.view({-1, num_kv_heads_, head_size_}); + auto run_fia_prefill = [&]() { + if (attn_metadata.is_prefill) { + auto key_view = key.view({-1, num_kv_heads_, head_size_}); + auto value_view = value.view({-1, num_kv_heads_, head_size_}); + auto fallback_output = torch::empty_like(query); + xllm::kernel::npu::batch_prefill( + query, + key_view, + value_view, + attn_metadata.fia_attn_mask.defined() + ? attn_metadata.fia_attn_mask + : get_fia_split_fuse_attn_mask(query), + attn_metadata.q_seq_lens, + scale_, + fallback_output); + return fallback_output; + } - auto fia_result = xllm::kernel::npu::npu_fused_infer_attention( + auto fallback_output = torch::empty_like(query); + xllm::kernel::npu::batch_chunked_paged_prefill( query, - key, - value, - attn_metadata.fia_attn_mask.defined() - ? std::make_optional(attn_metadata.fia_attn_mask) - : std::nullopt, - std::nullopt, - attn_metadata.q_cu_seq_lens_host_vec, - attn_metadata.kv_cu_seq_lens_host_vec, - num_heads_, - num_kv_heads_, + k_cache, + v_cache.value(), scale_, - 0, - 3, - "TND"); - output.copy_(std::get<0>(fia_result).view_as(output)); - } else if (attn_metadata.is_chunked_prefill) { - auto k = k_cache.view({k_cache.size(0), k_cache.size(1), -1}); - auto v = v_cache.value().view( - {v_cache.value().size(0), v_cache.value().size(1), -1}); - auto fia_result = xllm::kernel::npu::npu_fused_infer_attention( - query, - k, - v, + attn_metadata.block_table, + attn_metadata.kv_seq_lens, attn_metadata.fia_attn_mask.defined() - ? std::make_optional(attn_metadata.fia_attn_mask) - : std::nullopt, - attn_metadata.block_table.defined() - ? std::make_optional(attn_metadata.block_table) - : std::nullopt, - attn_metadata.q_cu_seq_lens_host_vec, - attn_metadata.kv_seq_lens_host_vec, + ? attn_metadata.fia_attn_mask + : get_fia_split_fuse_attn_mask(query), + attn_metadata.q_seq_lens, + fallback_output); + return fallback_output; + }; + + if (attn_metadata.is_prefill || attn_metadata.is_chunked_prefill) { + if (!use_x_flash_prefill()) { + auto fia_result = run_fia_prefill(); + output.copy_(fia_result.view_as(output)); + return; + } + + auto xfa_result = x_flash_attention_infer_with_query_split( + query, + k_cache, + v_cache.value(), + get_fia_split_fuse_attn_mask(query), + attn_metadata.block_table, + attn_metadata.q_seq_lens, + attn_metadata.kv_seq_lens, + attn_metadata.q_seq_lens_host, + attn_metadata.kv_seq_lens_host, num_heads_, num_kv_heads_, - scale_, - k_cache.size(1), - 3, - "TND"); - output.copy_(std::get<0>(fia_result).view_as(output)); + scale_); + output.copy_(xfa_result.view_as(output)); } } @@ -148,11 +550,15 @@ void AttentionImpl::decoder_forward(torch::Tensor& query, if (attn_metadata.use_expanded_decode_for_spec_verify_attention) { block_table = attn_metadata.expanded_block_table; tiling_data = attn_metadata.expanded_paged_attention_tiling_data; - if (attn_metadata.expanded_kv_seq_lens_host.defined()) { + if (tiling_data.defined()) { + kv_seq_lens = attn_metadata.expanded_kv_seq_lens; + } else if (attn_metadata.expanded_kv_seq_lens_host.defined()) { kv_seq_lens = attn_metadata.expanded_kv_seq_lens_host; } else { kv_seq_lens = attn_metadata.expanded_kv_seq_lens; } + } else if (tiling_data.defined()) { + kv_seq_lens = attn_metadata.kv_seq_lens; } else if (attn_metadata.kv_seq_lens_host.defined()) { kv_seq_lens = attn_metadata.kv_seq_lens_host; } else { @@ -160,6 +566,40 @@ void AttentionImpl::decoder_forward(torch::Tensor& query, kv_seq_lens = attn_metadata.kv_seq_lens; } + if (v_cache.has_value() && use_x_flash_decode()) { + auto decode_query = query.view({-1, num_heads_, head_size_}); + torch::Tensor xfa_result; + if (tiling_data.defined()) { + xfa_result = x_flash_attention_infer_decode_for_graph( + decode_query, + k_cache, + v_cache.value(), + attn_metadata.xfa_attn_mask.defined() + ? attn_metadata.xfa_attn_mask + : get_fia_split_fuse_attn_mask(decode_query), + block_table, + attn_metadata.xfa_q_cu_seq_lens, + kv_seq_lens, + attn_metadata.xfa_extra_tiling, + num_heads_, + num_kv_heads_, + scale_); + } else { + xfa_result = x_flash_attention_infer_decode_with_batch_split( + decode_query, + k_cache, + v_cache.value(), + get_fia_split_fuse_attn_mask(decode_query), + block_table, + kv_seq_lens, + num_heads_, + num_kv_heads_, + scale_); + } + output.copy_(xfa_result.view_as(output)); + return; + } + if (tiling_data.defined()) { // Use CustomPagedAttention for ACL graph mode to avoid .to(kCPU) operations diff --git a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp index 61e2a04c35..07e7057209 100644 --- a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp +++ b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp @@ -524,23 +524,64 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward( if (!use_spec_verify && is_any_prefill) { torch::IntArrayRef num_accepted_tokens_opt; - std::vector linear_state_indices_vec( - input_params.linear_state_ids.begin(), - input_params.linear_state_ids.end()); torch::Tensor conv_input = reshape_qkvz_unpad(attn_metadata, mixed_qkv); - mixed_qkv = xllm::kernel::causal_conv1d( - conv_input, - conv_weight, - conv_cache, - std::optional(), // bias (no bias for qwen3) - torch::IntArrayRef(input_params.query_start_loc), - torch::IntArrayRef(linear_state_indices_vec), - torch::IntArrayRef(input_params.has_initial_state), - num_accepted_tokens_opt, - 1, // activation_mode - -1, // pad_slot_id - 0 // run mode 0:fn, 1:update - ); + if (attn_metadata.is_chunked_prefill && batch_size > 1) { + CHECK_GE(attn_metadata.q_seq_lens_vec.size(), + static_cast(batch_size)) + << "q_seq_lens_vec must be populated for Qwen3.5 chunked conv."; + CHECK_EQ(input_params.linear_state_ids.size(), + static_cast(batch_size)) + << "linear_state_ids must be sequence-scoped for Qwen3.5 conv."; + CHECK_EQ(input_params.has_initial_state.size(), + static_cast(batch_size)) + << "has_initial_state must be sequence-scoped for Qwen3.5 conv."; + std::vector conv_outputs; + conv_outputs.reserve(batch_size); + int64_t offset = 0; + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + const int64_t valid_len = attn_metadata.q_seq_lens_vec[batch_idx]; + CHECK_GT(valid_len, 0) << "Qwen3.5 conv sequence length must be > 0"; + auto conv_slice = conv_input.narrow(0, offset, valid_len).contiguous(); + offset += valid_len; + std::vector query_start_loc = {0, valid_len}; + std::vector linear_state_indices = { + input_params.linear_state_ids[batch_idx]}; + std::vector has_initial_state = { + input_params.has_initial_state[batch_idx]}; + conv_outputs.emplace_back(xllm::kernel::causal_conv1d( + conv_slice, + conv_weight, + conv_cache, + std::optional(), // bias (no bias for qwen3) + torch::IntArrayRef(query_start_loc), + torch::IntArrayRef(linear_state_indices), + torch::IntArrayRef(has_initial_state), + num_accepted_tokens_opt, + 1, // activation_mode + -1, // pad_slot_id + 0 // run mode 0:fn, 1:update + )); + } + CHECK_EQ(offset, conv_input.size(0)); + mixed_qkv = torch::cat(conv_outputs, 0).contiguous(); + } else { + std::vector linear_state_indices_vec( + input_params.linear_state_ids.begin(), + input_params.linear_state_ids.end()); + mixed_qkv = xllm::kernel::causal_conv1d( + conv_input, + conv_weight, + conv_cache, + std::optional(), // bias (no bias for qwen3) + torch::IntArrayRef(input_params.query_start_loc), + torch::IntArrayRef(linear_state_indices_vec), + torch::IntArrayRef(input_params.has_initial_state), + num_accepted_tokens_opt, + 1, // activation_mode + -1, // pad_slot_id + 0 // run mode 0:fn, 1:update + ); + } mixed_qkv = reshape_qkvz_with_pad(attn_metadata, mixed_qkv); mixed_qkv = mixed_qkv.transpose(1, 2); @@ -586,7 +627,7 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward( conv_weight.transpose(0, 1).contiguous(); xllm::kernel::CausalConv1dUpdateParams conv1d_params; conv1d_params.x = mixed_qkv.reshape({-1, mixed_qkv.size(-1)}); - conv1d_params.conv_state = conv_cache; + conv1d_params.conv_state = conv_cache.transpose(1, 2); conv1d_params.weight = conv_weight_for_update; conv1d_params.conv_state_indices = logical_state_indices; conv1d_params.block_idx_last_scheduled_token = @@ -668,76 +709,205 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward( CHECK_GE(attn_metadata.q_seq_lens_vec.size(), static_cast(batch_size)) << "q_seq_lens_vec must be populated for Qwen3.5 prefill."; - std::vector packed_q; - std::vector packed_k; - std::vector packed_v; - std::vector packed_g; - std::vector packed_beta; - packed_q.reserve(batch_size); - packed_k.reserve(batch_size); - packed_v.reserve(batch_size); - packed_g.reserve(batch_size); - packed_beta.reserve(batch_size); - for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { - const int64_t valid_len = attn_metadata.q_seq_lens_vec[batch_idx]; - packed_q.emplace_back(processed_q[batch_idx].narrow(0, 0, valid_len)); - packed_k.emplace_back(processed_k[batch_idx].narrow(0, 0, valid_len)); - packed_v.emplace_back(processed_v[batch_idx].narrow(0, 0, valid_len)); - packed_g.emplace_back(g[batch_idx].narrow(0, 0, valid_len)); - packed_beta.emplace_back(beta[batch_idx].narrow(0, 0, valid_len)); - } - torch::Tensor packed_processed_q = torch::cat(packed_q, 0).unsqueeze(0); - torch::Tensor packed_processed_k = torch::cat(packed_k, 0).unsqueeze(0); - torch::Tensor packed_processed_v = torch::cat(packed_v, 0).unsqueeze(0); - torch::Tensor packed_g_tensor = torch::cat(packed_g, 0).unsqueeze(0); - torch::Tensor packed_beta_tensor = torch::cat(packed_beta, 0).unsqueeze(0); - - xllm::kernel::ChunkGatedDeltaRuleParams chunk_gated_delta_params; - chunk_gated_delta_params.q = packed_processed_q; - chunk_gated_delta_params.k = packed_processed_k; - chunk_gated_delta_params.v = packed_processed_v; - chunk_gated_delta_params.g = packed_g_tensor; - chunk_gated_delta_params.beta = packed_beta_tensor; - // Get initial state from ssm_cache for sequences with previous state - // Shape: [batch_size, num_heads, head_k_dim, head_v_dim] - torch::Tensor initial_state_tensor = - torch::index_select(ssm_cache, 0, linear_state_base_indices); - CHECK_EQ(input_params.has_initial_state.size(), - input_params.linear_state_ids.size()) - << "has_initial_state must be sequence-scoped."; - for (size_t i = 0; i < input_params.has_initial_state.size(); ++i) { - if (input_params.has_initial_state[i] == 0) { - initial_state_tensor.select(0, static_cast(i)).fill_(0.0); + if (attn_metadata.is_chunked_prefill && fla_ssm_state_layout) { + double scale = 1.0 / std::sqrt(static_cast(processed_q.size(-1))); + std::vector packed_q; + std::vector packed_k; + std::vector packed_v; + std::vector packed_a; + std::vector packed_b; + packed_q.reserve(batch_size); + packed_k.reserve(batch_size); + packed_v.reserve(batch_size); + packed_a.reserve(batch_size); + packed_b.reserve(batch_size); + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + const int64_t valid_len = attn_metadata.q_seq_lens_vec[batch_idx]; + CHECK_GT(valid_len, 0) + << "Qwen3.5 chunked GDN sequence length must be > 0"; + packed_q.emplace_back(processed_q[batch_idx].narrow(0, 0, valid_len)); + packed_k.emplace_back(processed_k[batch_idx].narrow(0, 0, valid_len)); + packed_v.emplace_back(processed_v[batch_idx].narrow(0, 0, valid_len)); + packed_a.emplace_back(a[batch_idx].narrow(0, 0, valid_len)); + packed_b.emplace_back(b[batch_idx].narrow(0, 0, valid_len)); } - } + torch::Tensor packed_q_tensor = torch::cat(packed_q, 0).unsqueeze(0); + torch::Tensor packed_k_tensor = torch::cat(packed_k, 0).unsqueeze(0); + torch::Tensor packed_v_tensor = torch::cat(packed_v, 0).unsqueeze(0); + torch::Tensor packed_a_tensor = torch::cat(packed_a, 0).unsqueeze(0); + torch::Tensor packed_b_tensor = torch::cat(packed_b, 0).unsqueeze(0); - if (!fla_ssm_state_layout && attn_metadata.is_chunked_prefill) { - initial_state_tensor = - initial_state_tensor.transpose(-1, -2).contiguous(); - } + CHECK_EQ(input_params.has_initial_state.size(), + input_params.linear_state_ids.size()) + << "has_initial_state must be sequence-scoped."; + std::vector zero_state_indices; + zero_state_indices.reserve(input_params.has_initial_state.size()); + for (size_t i = 0; i < input_params.has_initial_state.size(); ++i) { + if (input_params.has_initial_state[i] == 0) { + zero_state_indices.emplace_back( + linear_state_base_indices.narrow(0, static_cast(i), 1)); + } + } + if (!zero_state_indices.empty()) { + auto indices = torch::cat(zero_state_indices, 0).contiguous(); + auto zero_state = torch::zeros({indices.size(0), + ssm_cache.size(1), + ssm_cache.size(2), + ssm_cache.size(3)}, + ssm_cache.options()); + ssm_cache.index_put_({indices}, zero_state); + } - chunk_gated_delta_params.initial_state = initial_state_tensor; - chunk_gated_delta_params.output_final_state = true; - chunk_gated_delta_params.cu_seqlens = attn_metadata.q_cu_seq_lens; - chunk_gated_delta_params.head_first = false; - chunk_gated_delta_params.use_qk_l2norm_in_kernel = true; - torch::Tensor packed_core_attn_out; - std::tie(packed_core_attn_out, last_recurrent_state) = - xllm::kernel::chunk_gated_delta_rule(chunk_gated_delta_params); - core_attn_out = torch::zeros_like(processed_v); - int64_t packed_offset = 0; - for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { - const int64_t valid_len = attn_metadata.q_seq_lens_vec[batch_idx]; - core_attn_out[batch_idx] - .narrow(0, 0, valid_len) - .copy_(packed_core_attn_out[0].narrow(0, packed_offset, valid_len)); - packed_offset += valid_len; + xllm::kernel::FusedSigmoidGatingDeltaRuleUpdateParams params; + params.A_log = A_log_.contiguous(); + params.a = packed_a_tensor.contiguous(); + params.dt_bias = dt_bias_.contiguous(); + params.q = packed_q_tensor.contiguous(); + params.k = packed_k_tensor.contiguous(); + params.v = packed_v_tensor.contiguous(); + params.b = packed_b_tensor.contiguous(); + params.initial_state_source = ssm_cache; + params.initial_state_indices = linear_state_base_indices.contiguous(); + params.cu_seqlens = attn_metadata.q_cu_seq_lens.contiguous(); + params.scale = static_cast(scale); + params.use_qk_l2norm_in_kernel = true; + params.softplus_beta = 1.0f; + params.softplus_threshold = 20.0f; + torch::Tensor packed_core_attn_out = + xllm::kernel::fused_sigmoid_gating_delta_rule_update(params); + core_attn_out = torch::zeros_like(processed_v); + int64_t packed_offset = 0; + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + const int64_t valid_len = attn_metadata.q_seq_lens_vec[batch_idx]; + core_attn_out[batch_idx] + .narrow(0, 0, valid_len) + .copy_(packed_core_attn_out[0].narrow(0, packed_offset, valid_len)); + packed_offset += valid_len; + } + } else if (attn_metadata.is_chunked_prefill && batch_size > 1) { + CHECK_EQ(input_params.has_initial_state.size(), + input_params.linear_state_ids.size()) + << "has_initial_state must be sequence-scoped."; + core_attn_out = torch::zeros_like(processed_v); + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + const int64_t valid_len = attn_metadata.q_seq_lens_vec[batch_idx]; + CHECK_GT(valid_len, 0) + << "Qwen3.5 chunked GDN sequence length must be > 0"; + torch::Tensor seq_state_index = + linear_state_base_indices.narrow(0, batch_idx, 1).contiguous(); + torch::Tensor initial_state_tensor = + torch::index_select(ssm_cache, 0, seq_state_index); + if (input_params.has_initial_state[batch_idx] == 0) { + initial_state_tensor.fill_(0.0); + } + if (!fla_ssm_state_layout) { + initial_state_tensor = + initial_state_tensor.transpose(-1, -2).contiguous(); + } + + xllm::kernel::ChunkGatedDeltaRuleParams chunk_gated_delta_params; + chunk_gated_delta_params.q = + processed_q[batch_idx].narrow(0, 0, valid_len).unsqueeze(0); + chunk_gated_delta_params.k = + processed_k[batch_idx].narrow(0, 0, valid_len).unsqueeze(0); + chunk_gated_delta_params.v = + processed_v[batch_idx].narrow(0, 0, valid_len).unsqueeze(0); + chunk_gated_delta_params.g = + g[batch_idx].narrow(0, 0, valid_len).unsqueeze(0); + chunk_gated_delta_params.beta = + beta[batch_idx].narrow(0, 0, valid_len).unsqueeze(0); + chunk_gated_delta_params.initial_state = initial_state_tensor; + chunk_gated_delta_params.output_final_state = true; + chunk_gated_delta_params.cu_seqlens = torch::tensor( + std::vector{0, static_cast(valid_len)}, + torch::TensorOptions().dtype(torch::kInt32).device(device)); + chunk_gated_delta_params.head_first = false; + chunk_gated_delta_params.use_qk_l2norm_in_kernel = true; + torch::Tensor seq_core_attn_out; + std::tie(seq_core_attn_out, last_recurrent_state) = + xllm::kernel::chunk_gated_delta_rule(chunk_gated_delta_params); + core_attn_out[batch_idx] + .narrow(0, 0, valid_len) + .copy_(seq_core_attn_out[0].narrow(0, 0, valid_len)); + torch::Tensor state_to_store = + fla_ssm_state_layout ? last_recurrent_state + : last_recurrent_state.transpose(-1, -2); + ssm_cache.index_put_({seq_state_index}, + state_to_store.to(ssm_cache.dtype())); + } + } else { + std::vector packed_q; + std::vector packed_k; + std::vector packed_v; + std::vector packed_g; + std::vector packed_beta; + packed_q.reserve(batch_size); + packed_k.reserve(batch_size); + packed_v.reserve(batch_size); + packed_g.reserve(batch_size); + packed_beta.reserve(batch_size); + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + const int64_t valid_len = attn_metadata.q_seq_lens_vec[batch_idx]; + packed_q.emplace_back(processed_q[batch_idx].narrow(0, 0, valid_len)); + packed_k.emplace_back(processed_k[batch_idx].narrow(0, 0, valid_len)); + packed_v.emplace_back(processed_v[batch_idx].narrow(0, 0, valid_len)); + packed_g.emplace_back(g[batch_idx].narrow(0, 0, valid_len)); + packed_beta.emplace_back(beta[batch_idx].narrow(0, 0, valid_len)); + } + torch::Tensor packed_processed_q = torch::cat(packed_q, 0).unsqueeze(0); + torch::Tensor packed_processed_k = torch::cat(packed_k, 0).unsqueeze(0); + torch::Tensor packed_processed_v = torch::cat(packed_v, 0).unsqueeze(0); + torch::Tensor packed_g_tensor = torch::cat(packed_g, 0).unsqueeze(0); + torch::Tensor packed_beta_tensor = + torch::cat(packed_beta, 0).unsqueeze(0); + + xllm::kernel::ChunkGatedDeltaRuleParams chunk_gated_delta_params; + chunk_gated_delta_params.q = packed_processed_q; + chunk_gated_delta_params.k = packed_processed_k; + chunk_gated_delta_params.v = packed_processed_v; + chunk_gated_delta_params.g = packed_g_tensor; + chunk_gated_delta_params.beta = packed_beta_tensor; + // Get initial state from ssm_cache for sequences with previous state + // Shape: [batch_size, num_heads, head_k_dim, head_v_dim] + torch::Tensor initial_state_tensor = + torch::index_select(ssm_cache, 0, linear_state_base_indices); + CHECK_EQ(input_params.has_initial_state.size(), + input_params.linear_state_ids.size()) + << "has_initial_state must be sequence-scoped."; + for (size_t i = 0; i < input_params.has_initial_state.size(); ++i) { + if (input_params.has_initial_state[i] == 0) { + initial_state_tensor.select(0, static_cast(i)).fill_(0.0); + } + } + + if (!fla_ssm_state_layout && attn_metadata.is_chunked_prefill) { + initial_state_tensor = + initial_state_tensor.transpose(-1, -2).contiguous(); + } + + chunk_gated_delta_params.initial_state = initial_state_tensor; + chunk_gated_delta_params.output_final_state = true; + chunk_gated_delta_params.cu_seqlens = attn_metadata.q_cu_seq_lens; + chunk_gated_delta_params.head_first = false; + chunk_gated_delta_params.use_qk_l2norm_in_kernel = true; + torch::Tensor packed_core_attn_out; + std::tie(packed_core_attn_out, last_recurrent_state) = + xllm::kernel::chunk_gated_delta_rule(chunk_gated_delta_params); + core_attn_out = torch::zeros_like(processed_v); + int64_t packed_offset = 0; + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + const int64_t valid_len = attn_metadata.q_seq_lens_vec[batch_idx]; + core_attn_out[batch_idx] + .narrow(0, 0, valid_len) + .copy_(packed_core_attn_out[0].narrow(0, packed_offset, valid_len)); + packed_offset += valid_len; + } + torch::Tensor state_to_store = + fla_ssm_state_layout ? last_recurrent_state + : last_recurrent_state.transpose(-1, -2); + ssm_cache.index_put_({linear_state_base_indices}, + state_to_store.to(ssm_cache.dtype())); } - torch::Tensor state_to_store = fla_ssm_state_layout - ? last_recurrent_state - : last_recurrent_state.transpose(-1, -2); - ssm_cache.index_put_({linear_state_base_indices}, - state_to_store.to(ssm_cache.dtype())); } else if (checkpoint_stride > 1) { auto ssm_state = torch::index_select(ssm_cache, 0, linear_state_base_indices); diff --git a/xllm/core/runtime/acl_graph_executor_impl.cpp b/xllm/core/runtime/acl_graph_executor_impl.cpp index 43b2c7aca8..04885c6e1a 100644 --- a/xllm/core/runtime/acl_graph_executor_impl.cpp +++ b/xllm/core/runtime/acl_graph_executor_impl.cpp @@ -34,6 +34,7 @@ limitations under the License. #endif #include "core/common/global_flags.h" #include "core/common/metrics.h" +#include "core/kernels/npu/xllm_ops/xllm_ops_api.h" #include "core/util/utils.h" #include "platform/npu/device_capture_lock.h" @@ -51,6 +52,7 @@ namespace xllm::npu { namespace { constexpr uint64_t kSpecVerifyGraphKeyMask = 1ull << 63; constexpr uint64_t kSpecVerifyQMaxSeqLenShift = 32; +constexpr int64_t kFiaSplitFuseMaskSize = 2048; std::pair find_attention_plan_kv_cache( const std::vector& kv_caches) { @@ -170,6 +172,18 @@ GraphPersistentParam::GraphPersistentParam(const ModelArgs& args, persistent_mask_ = torch::zeros({max_graph_tokens, max_seq_len}, torch::dtype(dtype).device(device)); } + xfa_extra_tiling_ = torch::zeros( + {xllm::kernel::npu::x_flash_attention_extra_tiling_int32_count()}, + torch::dtype(torch::kInt32).device(device)); + xfa_attn_mask_ = + torch::triu(torch::ones({kFiaSplitFuseMaskSize, kFiaSplitFuseMaskSize}, + torch::TensorOptions() + .dtype(torch::kFloat32) + .device(torch::kCPU)), + 1) + .to(torch::kInt8) + .to(device) + .contiguous(); // Do not need to create ATB context and custom paged attention operation if (args_.head_dim() == 0) { @@ -447,9 +461,10 @@ std::optional GraphPersistentParam::update( .slice(/*dim=*/0, /*start=*/0, /*end=*/embedding_tokens) .copy_(embedding, /*non_blocking=*/true); } - // Update q_cu_seq_lens only if params.q_cu_seq_lens is defined + // Preserve the historical q_cu_seq_lens semantics for Qwen3.5 GDN/conv + // update kernels. x_flash_attention graph replay uses a separate + // query-start buffer below. if (params.q_cu_seq_lens.defined()) { - // Lazy initialization: if q_cu_seq_lens_ is not initialized, initialize it if (q_cu_seq_lens_.numel() == 0) { const int64_t max_seqs_per_batch = get_decode_graph_capacity(options_); q_cu_seq_lens_ = torch::zeros({max_seqs_per_batch + 1}, @@ -464,8 +479,6 @@ std::optional GraphPersistentParam::update( CHECK_GE(params.q_cu_seq_lens.numel(), required_q_cu_seq_lens) << "q_cu_seq_lens does not have enough entries for ACL graph execution"; if (use_qwen3_5_query_start_loc && !input_has_leading_zero) { - // Normal Qwen3.5 decode input carries cumsum without the leading zero, - // while update kernels expect query_start_loc-style [0, cumsum...]. q_cu_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/1).zero_(); q_cu_seq_lens_ .slice(/*dim=*/0, /*start=*/1, /*end=*/actual_batch_size + 1) @@ -528,6 +541,48 @@ std::optional GraphPersistentParam::update( expanded_kv_seq_lens_vec = update_expanded_spec_decode_attention( params, actual_num_tokens, padded_num_tokens, actual_batch_size); } + if (xfa_q_cu_seq_lens_.numel() == 0) { + const int64_t max_graph_tokens = get_decode_graph_capacity(options_); + xfa_q_cu_seq_lens_ = torch::zeros( + {max_graph_tokens + 1}, torch::dtype(torch::kInt).device(device_)); + } + xfa_q_cu_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/1).zero_(); + int32_t xfa_q_offset = 0; + std::vector xfa_q_cu_seq_lens_vec; + if (use_expanded_spec_decode_attention) { + xfa_q_cu_seq_lens_vec.reserve(padded_num_tokens); + for (uint32_t i = 0; i < padded_num_tokens; ++i) { + ++xfa_q_offset; + xfa_q_cu_seq_lens_vec.emplace_back(xfa_q_offset); + } + } else { + xfa_q_cu_seq_lens_vec.reserve(padded_batch_size); + for (int64_t i = 0; i < padded_batch_size; ++i) { + xfa_q_offset += padded_q_seq_lens_vec[i]; + xfa_q_cu_seq_lens_vec.emplace_back(xfa_q_offset); + } + } + xfa_q_cu_seq_lens_ + .slice(/*dim=*/0, + /*start=*/1, + /*end=*/static_cast(xfa_q_cu_seq_lens_vec.size()) + 1) + .copy_(torch::tensor(xfa_q_cu_seq_lens_vec, torch::kInt).to(device_), + /*non_blocking=*/true); + + ModelInputParams xfa_params = params; + if (use_expanded_spec_decode_attention) { + xfa_params.num_sequences = padded_num_tokens; + xfa_params.kv_seq_lens_vec = expanded_kv_seq_lens_vec; + xfa_params.q_seq_lens_vec = std::vector(padded_num_tokens, 1); + } else { + xfa_params.num_sequences = padded_batch_size; + xfa_params.kv_seq_lens_vec = padded_kv_seq_lens_vec; + xfa_params.q_seq_lens_vec = padded_q_seq_lens_vec; + } + if (k_cache.defined() && k_cache.numel() > 0) { + update_x_flash_attention_extra_tiling( + xfa_params, xfa_params.num_sequences, k_cache); + } if (tiling_data_.numel() > 0) { aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); @@ -591,6 +646,8 @@ std::optional GraphPersistentParam::update( persistent_mask(padded_num_tokens); } params_for_capture->graph_buffer.tiling_data = tiling_data(); + params_for_capture->graph_buffer.xfa_extra_tiling = xfa_extra_tiling(); + params_for_capture->graph_buffer.xfa_attn_mask = xfa_attn_mask(); // Set persistent embedding if available if (params.input_embedding.defined()) { params_for_capture->input_embedding = @@ -612,15 +669,19 @@ std::optional GraphPersistentParam::update( params_for_capture->graph_buffer.expanded_tiling_data = tiling_data(); params_for_capture->graph_buffer.expanded_kv_seq_lens_vec = expanded_kv_seq_lens_vec; + params_for_capture->graph_buffer.xfa_q_cu_seq_lens = + xfa_q_cu_seq_lens().slice(/*dim=*/0, + /*start=*/0, + /*end=*/padded_num_tokens + 1); + } else { + params_for_capture->graph_buffer.xfa_q_cu_seq_lens = + xfa_q_cu_seq_lens().slice(/*dim=*/0, + /*start=*/0, + /*end=*/padded_batch_size + 1); } - // Set q_cu_seq_lens if available - if (params.q_cu_seq_lens.defined()) { - const bool use_qwen3_5_query_start_loc = - is_qwen3_5_model_type(args_.model_type()); + if (q_cu_seq_lens_.defined() && q_cu_seq_lens_.numel() > 0) { params_for_capture->q_cu_seq_lens = q_cu_seq_lens_.slice( - /*dim=*/0, - /*start=*/0, - /*end=*/padded_batch_size + (use_qwen3_5_query_start_loc ? 1 : 0)); + /*dim=*/0, /*start=*/0, /*end=*/padded_batch_size + 1); } return params_for_capture; @@ -965,6 +1026,41 @@ void GraphPersistentParam::plan_paged_attention_tiling( CHECK_EQ(acl_status, ACL_SUCCESS) << "Failed to copy tiling buffer to device"; } +void GraphPersistentParam::update_x_flash_attention_extra_tiling( + const ModelInputParams& input_params, + uint32_t padded_batch_size, + const torch::Tensor& k_cache) { + CHECK(k_cache.defined() && k_cache.dim() == 4) + << "x_flash_attention graph extra tiling needs valid k_cache"; + const int dp_local_tp_size = options_.world_size() / options_.dp_size(); + const int64_t q_head = args_.n_heads() / dp_local_tp_size; + const int64_t kv_head = std::max( + 1, args_.n_kv_heads().value_or(args_.n_heads()) / dp_local_tp_size); + std::vector q_lens; + std::vector kv_lens; + q_lens.reserve(padded_batch_size); + kv_lens.reserve(padded_batch_size); + for (uint32_t i = 0; i < padded_batch_size; ++i) { + CHECK_LT(i, input_params.q_seq_lens_vec.size()) + << "q_seq_lens_vec is shorter than padded_batch_size"; + CHECK_LT(i, input_params.kv_seq_lens_vec.size()) + << "kv_seq_lens_vec is shorter than padded_batch_size"; + const int32_t q_len = input_params.q_seq_lens_vec[i]; + const int32_t kv_len = input_params.kv_seq_lens_vec[i]; + CHECK_EQ(q_len, 1) + << "x_flash_attention graph decode path expects q_len == 1"; + q_lens.emplace_back(static_cast(i + 1)); + kv_lens.emplace_back(kv_len); + } + xllm::kernel::npu::update_x_flash_attention_extra_tiling(q_lens, + kv_lens, + q_head, + kv_head, + k_cache.size(1), + k_cache.size(3), + xfa_extra_tiling_); +} + void GraphPersistentParam::update_attention_mask( const ModelInputParams& input_params) { torch::Dtype dtype = util::parse_dtype(args_.dtype(), device_); diff --git a/xllm/core/runtime/acl_graph_executor_impl.h b/xllm/core/runtime/acl_graph_executor_impl.h index deaca700bc..4bb9177a3f 100644 --- a/xllm/core/runtime/acl_graph_executor_impl.h +++ b/xllm/core/runtime/acl_graph_executor_impl.h @@ -118,6 +118,9 @@ class GraphPersistentParam { return persistent_mask_; } const torch::Tensor& tiling_data() const { return tiling_data_; } + const torch::Tensor& xfa_q_cu_seq_lens() const { return xfa_q_cu_seq_lens_; } + const torch::Tensor& xfa_extra_tiling() const { return xfa_extra_tiling_; } + const torch::Tensor& xfa_attn_mask() const { return xfa_attn_mask_; } torch::Tensor hidden_states(uint32_t actual_tokens = 0) const { if (actual_tokens > 0) { return hidden_states_.slice( @@ -201,6 +204,11 @@ class GraphPersistentParam { const ModelInputParams& input_params, aclrtStream stream); + void update_x_flash_attention_extra_tiling( + const ModelInputParams& input_params, + uint32_t padded_batch_size, + const torch::Tensor& k_cache); + std::vector update_expanded_spec_decode_attention( const ModelInputParams& input_params, uint32_t actual_num_tokens, @@ -229,6 +237,7 @@ class GraphPersistentParam { // for deepseekv3.2 torch::Tensor q_cu_seq_lens_; + torch::Tensor xfa_q_cu_seq_lens_; // for mtp model torch::Tensor persistent_embedding_; @@ -248,6 +257,8 @@ class GraphPersistentParam { // Persistent paged attention tiling tensor on device torch::Tensor tiling_data_; + torch::Tensor xfa_extra_tiling_; + torch::Tensor xfa_attn_mask_; // Cached attention parameters int32_t num_head_;