From ee7a139fc637019598a1de919c76ac99356828ce Mon Sep 17 00:00:00 2001 From: default Date: Mon, 11 May 2026 19:20:31 +0800 Subject: [PATCH 1/7] feat(npu): use FIA for qwen3.5 prefill attention --- xllm/core/kernels/npu/CMakeLists.txt | 1 + .../kernels/npu/npu_fused_infer_attention.cpp | 199 ++++++++++++++++++ xllm/core/kernels/npu/npu_ops_api.h | 18 ++ .../common/attention_metadata_builder.cpp | 1 + xllm/core/layers/npu_torch/attention.cpp | 106 ++++++++-- 5 files changed, 311 insertions(+), 14 deletions(-) create mode 100644 xllm/core/kernels/npu/npu_fused_infer_attention.cpp diff --git a/xllm/core/kernels/npu/CMakeLists.txt b/xllm/core/kernels/npu/CMakeLists.txt index 1d2f41f974..3054fdf314 100644 --- a/xllm/core/kernels/npu/CMakeLists.txt +++ b/xllm/core/kernels/npu/CMakeLists.txt @@ -16,6 +16,7 @@ cc_library( fused_layernorm.cpp matmul.cpp npu_causal_conv1d.cpp + npu_fused_infer_attention.cpp npu_gemma_rms_norm.cpp npu_grouped_matmul.cpp npu_moe_gating_topk_softmax.cpp diff --git a/xllm/core/kernels/npu/npu_fused_infer_attention.cpp b/xllm/core/kernels/npu/npu_fused_infer_attention.cpp new file mode 100644 index 0000000000..69b701bf8b --- /dev/null +++ b/xllm/core/kernels/npu/npu_fused_infer_attention.cpp @@ -0,0 +1,199 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "core/kernels/npu/aclnn/pytorch_npu_helper.hpp" +#include "core/kernels/npu/npu_ops_api.h" +#include "core/kernels/npu/utils.h" + +namespace { + +constexpr int64_t kSwaIntMax = 2147483647; + +torch::Tensor infer_attention_output( + const torch::Tensor& query, + const torch::Tensor& value, + const std::optional& block_table, + int64_t num_heads, + const std::string& input_layout) { + if (input_layout == "TND" || input_layout == "NTD") { + int64_t value_dim = query.size(-1); + if (!block_table.has_value() && value.dim() >= 3) { + value_dim = value.size(-1); + } + return torch::empty({query.size(0), num_heads, value_dim}, query.options()); + } + + if (input_layout == "BSH") { + return torch::empty_like(query); + } + + if (input_layout == "BNSD") { + int64_t value_dim = query.size(-1); + if (!block_table.has_value() && value.dim() >= 4) { + value_dim = value.size(-1); + } + return torch::empty( + {query.size(0), query.size(1), query.size(2), value_dim}, + query.options()); + } + + LOG(FATAL) << "Unsupported FIA input_layout: " << input_layout; + return torch::Tensor(); +} + +torch::Tensor infer_softmax_lse(const torch::Tensor& query, + int64_t num_heads, + const std::string& input_layout, + bool softmax_lse_flag) { + auto options = query.options().dtype(torch::kFloat32); + if (!softmax_lse_flag) { + return torch::empty({0}, options); + } + + if (input_layout == "TND" || input_layout == "NTD") { + return torch::empty({query.size(0), num_heads, 1}, options); + } + + if (input_layout == "BSH") { + return torch::empty({query.size(0), num_heads, query.size(1), 1}, options); + } + + if (input_layout == "BNSD") { + return torch::empty({query.size(0), query.size(1), query.size(2), 1}, + options); + } + + LOG(FATAL) << "Unsupported FIA input_layout: " << input_layout; + return torch::Tensor(); +} + +c10::optional to_c10_optional_tensor( + const std::optional& tensor_opt) { 
+ if (tensor_opt.has_value() && tensor_opt.value().defined()) { + return tensor_opt.value(); + } + return c10::nullopt; +} + +} // namespace + +namespace xllm::kernel::npu { + +std::tuple npu_fused_infer_attention( + const torch::Tensor& query, + const torch::Tensor& key, + const torch::Tensor& value, + const std::optional& atten_mask, + const std::optional& block_table, + const std::vector& actual_seq_lengths, + const std::vector& actual_seq_lengths_kv, + int64_t num_heads, + int64_t num_key_value_heads, + double scale, + int64_t block_size, + int64_t sparse_mode, + const std::string& input_layout, + bool softmax_lse_flag) { + check_tensor(query, "query", "npu_fused_infer_attention"); + check_tensor(key, "key", "npu_fused_infer_attention"); + check_tensor(value, "value", "npu_fused_infer_attention"); + CHECK_GT(num_heads, 0) << "num_heads must be positive"; + CHECK(!actual_seq_lengths.empty()) << "actual_seq_lengths must not be empty"; + CHECK(!actual_seq_lengths_kv.empty()) + << "actual_seq_lengths_kv must not be empty"; + + torch::Tensor output = infer_attention_output( + query, value, block_table, num_heads, input_layout); + torch::Tensor softmax_lse = + infer_softmax_lse(query, num_heads, input_layout, softmax_lse_flag); + + std::vector key_tensors_vec{key}; + std::vector value_tensors_vec{value}; + at::TensorList key_tensors(key_tensors_vec); + at::TensorList value_tensors(value_tensors_vec); + + c10::optional none_tensor = c10::nullopt; + c10::optional atten_mask_tensor = + to_c10_optional_tensor(atten_mask); + c10::optional block_table_tensor = + to_c10_optional_tensor(block_table); + + at::IntArrayRef actual_seq_lengths_ref(actual_seq_lengths); + at::IntArrayRef actual_seq_lengths_kv_ref(actual_seq_lengths_kv); + c10::optional actual_seq_lengths_opt = + actual_seq_lengths_ref; + c10::optional actual_seq_lengths_kv_opt = + actual_seq_lengths_kv_ref; + c10::optional none_int_array = c10::nullopt; + + std::string layout = input_layout; + char* 
input_layout_ptr = const_cast(layout.c_str()); + int64_t pre_tokens = kSwaIntMax; + int64_t next_tokens = 0; + int64_t inner_precise = 0; + int64_t antiquant_mode = 0; + int64_t key_antiquant_mode = 0; + int64_t value_antiquant_mode = 0; + + EXEC_NPU_CMD(aclnnFusedInferAttentionScoreV3, + query, + key_tensors, + value_tensors, + none_tensor, // pse_shift + atten_mask_tensor, + actual_seq_lengths_opt, + actual_seq_lengths_kv_opt, + none_tensor, // dequant_scale1 + none_tensor, // quant_scale1 + none_tensor, // dequant_scale2 + none_tensor, // quant_scale2 + none_tensor, // quant_offset2 + none_tensor, // antiquant_scale + none_tensor, // antiquant_offset + block_table_tensor, + none_tensor, // query_padding_size + none_tensor, // kv_padding_size + none_tensor, // key_antiquant_scale + none_tensor, // key_antiquant_offset + none_tensor, // value_antiquant_scale + none_tensor, // value_antiquant_offset + none_tensor, // key_shared_prefix + none_tensor, // value_shared_prefix + none_int_array, // actual_shared_prefix_len + none_tensor, // query_rope + none_tensor, // key_rope + none_tensor, // key_rope_antiquant_scale + num_heads, + scale, + pre_tokens, + next_tokens, + input_layout_ptr, + num_key_value_heads, + sparse_mode, + inner_precise, + block_size, + antiquant_mode, + softmax_lse_flag, + key_antiquant_mode, + value_antiquant_mode, + output, + softmax_lse); + + return {output, softmax_lse}; +} + +} // namespace xllm::kernel::npu diff --git a/xllm/core/kernels/npu/npu_ops_api.h b/xllm/core/kernels/npu/npu_ops_api.h index f29e44ac04..1de0fe25ad 100644 --- a/xllm/core/kernels/npu/npu_ops_api.h +++ b/xllm/core/kernels/npu/npu_ops_api.h @@ -17,7 +17,9 @@ limitations under the License. 
#include #include +#include #include +#include #include "custom_functions_npu/atb_common.h" @@ -45,6 +47,22 @@ void batch_decode(const torch::Tensor& query, const torch::Tensor& seq_lens, torch::Tensor& output); +std::tuple npu_fused_infer_attention( + const torch::Tensor& query, + const torch::Tensor& key, + const torch::Tensor& value, + const std::optional& atten_mask, + const std::optional& block_table, + const std::vector& actual_seq_lengths, + const std::vector& actual_seq_lengths_kv, + int64_t num_heads, + int64_t num_key_value_heads, + double scale, + int64_t block_size, + int64_t sparse_mode, + const std::string& input_layout, + bool softmax_lse_flag = false); + // Custom batch decode for ACL graph execution // This variant uses CustomPagedAttention to avoid .to(kCPU) operations // that break ACL graph capture diff --git a/xllm/core/layers/common/attention_metadata_builder.cpp b/xllm/core/layers/common/attention_metadata_builder.cpp index fc74993ab3..8e65c0b81c 100644 --- a/xllm/core/layers/common/attention_metadata_builder.cpp +++ b/xllm/core/layers/common/attention_metadata_builder.cpp @@ -99,6 +99,7 @@ AttentionMetadata build_attention_metadata( params.batch_forward_type.is_mixed() || params.batch_forward_type.is_chunked_prefill(); attn_metadata.is_prefill = params.batch_forward_type.is_prefill(); + if (!attn_metadata.is_prefill || enable_mla) { attn_metadata.block_table = params.block_tables; #if !defined(USE_NPU) && !defined(USE_CUDA) diff --git a/xllm/core/layers/npu_torch/attention.cpp b/xllm/core/layers/npu_torch/attention.cpp index eb2b7c6c2b..95c9ff5a54 100644 --- a/xllm/core/layers/npu_torch/attention.cpp +++ b/xllm/core/layers/npu_torch/attention.cpp @@ -15,6 +15,12 @@ limitations under the License. 
#include "attention.h" +#include +#include +#include +#include +#include + #include "kernels/npu/npu_ops_api.h" #include "kernels/ops_api.h" @@ -22,6 +28,55 @@ DECLARE_bool(enable_chunked_prefill); namespace xllm { namespace layer { +namespace { + +constexpr int64_t kFiaSplitFuseMaskSize = 2048; + +std::vector cumulative_lengths(const std::vector& seq_lens) { + std::vector cu_lens; + cu_lens.reserve(seq_lens.size()); + int64_t total = 0; + for (int32_t seq_len : seq_lens) { + total += seq_len; + cu_lens.emplace_back(total); + } + return cu_lens; +} + +std::vector to_i64_vector(const std::vector& values) { + std::vector out; + out.reserve(values.size()); + for (int32_t value : values) { + out.emplace_back(value); + } + return out; +} + +torch::Tensor get_fia_split_fuse_attn_mask(const torch::Tensor& query) { + static std::mutex mutex; + static std::unordered_map mask_cache; + + const std::string cache_key = query.device().str(); + std::lock_guard lock(mutex); + auto it = mask_cache.find(cache_key); + if (it != mask_cache.end() && it->second.defined()) { + return it->second; + } + + auto cpu_options = torch::TensorOptions().dtype(torch::kFloat32); + auto mask = + torch::triu(torch::ones({kFiaSplitFuseMaskSize, kFiaSplitFuseMaskSize}, + cpu_options), + 1) + .to(torch::kInt8) + .to(query.device()) + .contiguous(); + mask_cache[cache_key] = mask; + return mask; +} + +} // namespace + AttentionImpl::AttentionImpl(int64_t num_heads, int64_t head_size, float scale, @@ -90,21 +145,44 @@ void AttentionImpl::prefill_forward(torch::Tensor& query, key = key.view({-1, num_kv_heads_, head_size_}); value = value.view({-1, num_kv_heads_, head_size_}); - xllm::kernel::npu::batch_prefill(query, - key, - value, - attn_metadata.attn_mask, - attn_metadata.kv_seq_lens_host, - scale_, - output); + auto fia_result = xllm::kernel::npu::npu_fused_infer_attention( + query, + key, + value, + get_fia_split_fuse_attn_mask(query), + std::nullopt, + 
cumulative_lengths(attn_metadata.q_seq_lens_vec), + cumulative_lengths(attn_metadata.kv_seq_lens_vec), + num_heads_, + num_kv_heads_, + scale_, + 0, + 3, + "TND"); + output.copy_(std::get<0>(fia_result).view_as(output)); } else if (attn_metadata.is_chunked_prefill) { - xllm::kernel::npu::batch_prefill(query, - k_cache, - v_cache.value(), - attn_metadata.attn_mask, - attn_metadata.kv_seq_lens_host, - scale_, - output); + auto q_seq_lens = cumulative_lengths(attn_metadata.q_seq_lens_vec); + auto kv_seq_lens = to_i64_vector(attn_metadata.kv_seq_lens_vec); + auto k = k_cache.view({k_cache.size(0), k_cache.size(1), -1}); + auto v = v_cache.value().view( + {v_cache.value().size(0), v_cache.value().size(1), -1}); + auto fia_result = xllm::kernel::npu::npu_fused_infer_attention( + query, + k, + v, + get_fia_split_fuse_attn_mask(query), + attn_metadata.block_table.defined() + ? std::make_optional(attn_metadata.block_table) + : std::nullopt, + q_seq_lens, + kv_seq_lens, + num_heads_, + num_kv_heads_, + scale_, + k_cache.size(1), + 3, + "TND"); + output.copy_(std::get<0>(fia_result).view_as(output)); } } From b448a699ff2da654e63ea500a9912fa08a85f93a Mon Sep 17 00:00:00 2001 From: default Date: Wed, 13 May 2026 14:51:56 +0800 Subject: [PATCH 2/7] Support x flash attention prefill on NPU --- third_party/xllm_ops | 2 +- xllm/core/kernels/npu/xllm_ops/CMakeLists.txt | 1 + .../npu/xllm_ops/x_flash_attention_infer.cpp | 214 ++++++++++++++++++ xllm/core/kernels/npu/xllm_ops/xllm_ops_api.h | 12 + .../common/attention_metadata_builder.cpp | 6 +- xllm/core/layers/npu_torch/attention.cpp | 144 ++++++++---- .../npu_torch/qwen3_gated_delta_net_base.cpp | 17 +- 7 files changed, 345 insertions(+), 51 deletions(-) create mode 100644 xllm/core/kernels/npu/xllm_ops/x_flash_attention_infer.cpp diff --git a/third_party/xllm_ops b/third_party/xllm_ops index 79eb46cb95..0b26ff7f7a 160000 --- a/third_party/xllm_ops +++ b/third_party/xllm_ops @@ -1 +1 @@ -Subproject commit 
79eb46cb951346acc12bfbd8b9b9170bc8a83db9 +Subproject commit 0b26ff7f7aaed1593f8e4cfbbacd37be74c18df2 diff --git a/xllm/core/kernels/npu/xllm_ops/CMakeLists.txt b/xllm/core/kernels/npu/xllm_ops/CMakeLists.txt index f8e9cda1bb..f4d17d4cc7 100644 --- a/xllm/core/kernels/npu/xllm_ops/CMakeLists.txt +++ b/xllm/core/kernels/npu/xllm_ops/CMakeLists.txt @@ -13,6 +13,7 @@ cc_library( beam_search.cpp select_unshared_kv.cpp beam_search_rec.cpp + x_flash_attention_infer.cpp DEPS atb torch_npu diff --git a/xllm/core/kernels/npu/xllm_ops/x_flash_attention_infer.cpp b/xllm/core/kernels/npu/xllm_ops/x_flash_attention_infer.cpp new file mode 100644 index 0000000000..b7e3d105aa --- /dev/null +++ b/xllm/core/kernels/npu/xllm_ops/x_flash_attention_infer.cpp @@ -0,0 +1,214 @@ +/* Copyright 2026 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "aclnn_x_flash_attention_infer.h" +#include "core/kernels/npu/aclnn/pytorch_npu_helper.hpp" +#include "core/kernels/npu/utils.h" +#include "xllm_ops_api.h" + +namespace { + +constexpr uint32_t kMaxExtraInfoNodes = 25; +constexpr uint32_t kMaxKvStackLen = 512; + +struct CoreNode { + uint32_t start_b_idx = 0; + uint32_t start_n1_idx = 0; + uint32_t start_s2_idx = 0; + uint32_t end_b_idx = 0; + uint32_t end_n1_idx = 0; + uint32_t end_s2_idx = 0; + uint64_t first_split_kv_task_lse_offset = 0; + uint64_t first_split_kv_task_o_offset = 0; +}; + +struct SplitNode { + uint32_t batch_idx = 0; + uint32_t head_start_idx = 0; + uint32_t head_end_idx = 0; + uint32_t q_start_idx = 0; + uint32_t q_end_idx = 0; + uint32_t split_num = 0; + uint64_t lse_task_offset = 0; + uint64_t o_task_offset = 0; +}; + +struct SplitKvExtraInfo { + CoreNode core_info[kMaxExtraInfoNodes]; + SplitNode split_info[kMaxExtraInfoNodes]; + uint32_t total_split_node_num = 0; +}; + +torch::Tensor make_extra_tiling(const torch::Tensor& actual_q_lens, + const torch::Tensor& actual_kv_lens, + int64_t q_head, + int64_t kv_head, + int64_t block_size, + int64_t head_size) { + CHECK_EQ(actual_q_lens.dim(), 1) << "actual_q_lens must be 1-D"; + CHECK_EQ(actual_kv_lens.dim(), 1) << "actual_kv_lens must be 1-D"; + CHECK_EQ(actual_q_lens.numel(), actual_kv_lens.numel()) + << "actual_q_lens and actual_kv_lens must have the same length"; + CHECK_GT(actual_q_lens.numel(), 0) << "batch size must be positive"; + CHECK_GT(q_head, 0) << "q_head must be positive"; + CHECK_GT(kv_head, 0) << "kv_head must be positive"; + CHECK_EQ(q_head % kv_head, 0) << "q_head must be divisible by kv_head"; + CHECK_GT(block_size, 0) << "block_size must be positive"; + CHECK_GT(head_size, 0) << "head_size must be positive"; + + auto q_lens_cpu = + 
actual_q_lens.to(torch::kCPU, torch::kInt32, /*non_blocking=*/false) + .contiguous(); + auto kv_lens_cpu = + actual_kv_lens.to(torch::kCPU, torch::kInt32, /*non_blocking=*/false) + .contiguous(); + const auto* q_lens = q_lens_cpu.data_ptr(); + const auto* kv_lens = kv_lens_cpu.data_ptr(); + const int64_t batch = actual_q_lens.numel(); + + SplitKvExtraInfo extra_info; + for (uint32_t i = 0; i < kMaxExtraInfoNodes; ++i) { + extra_info.core_info[i].start_b_idx = std::numeric_limits::max(); + } + + const uint32_t block_stack_num = + std::max(1, kMaxKvStackLen / static_cast(block_size)); + const uint32_t group_size = static_cast(q_head / kv_head); + uint32_t core_idx = 0; + uint32_t split_node_idx = 0; + uint64_t lse_offset = 0; + uint64_t o_offset = 0; + + for (int64_t batch_idx = 0; batch_idx < batch; ++batch_idx) { + const uint32_t q_end = static_cast(q_lens[batch_idx]); + const uint32_t q_start = + batch_idx == 0 ? 0 : static_cast(q_lens[batch_idx - 1]); + const uint32_t q_len = q_end - q_start; + const uint32_t kv_seq_len = static_cast(kv_lens[batch_idx]); + const uint32_t kv_blocks = + (kv_seq_len + static_cast(block_size) - 1) / + static_cast(block_size); + const uint32_t s2_blocks = + (kv_blocks + block_stack_num - 1) / block_stack_num; + + if (s2_blocks <= 1) { + CHECK_LT(core_idx, kMaxExtraInfoNodes) + << "x_flash_attention_infer extra_tiling core_info overflow"; + auto& core = extra_info.core_info[core_idx++]; + core.start_b_idx = static_cast(batch_idx); + core.start_n1_idx = 0; + core.start_s2_idx = 0; + core.end_b_idx = static_cast(batch_idx); + core.end_n1_idx = static_cast(kv_head - 1); + core.end_s2_idx = s2_blocks; + continue; + } + + // For long-KV requests, keep each KV head on one core and scan all S2 + // blocks inside that core. This avoids the split-KV combine path, whose + // fixed coreInfo/splitInfo capacity is too small for long Qwen3.5 chunks. 
+ for (uint32_t kv_head_idx = 0; kv_head_idx < static_cast(kv_head); + ++kv_head_idx) { + CHECK_LT(core_idx, kMaxExtraInfoNodes) + << "x_flash_attention_infer extra_tiling core_info overflow"; + auto& core = extra_info.core_info[core_idx++]; + core.start_b_idx = static_cast(batch_idx); + core.start_n1_idx = kv_head_idx; + core.start_s2_idx = 0; + core.end_b_idx = static_cast(batch_idx); + core.end_n1_idx = kv_head_idx; + core.end_s2_idx = s2_blocks; + } + } + extra_info.total_split_node_num = split_node_idx; + + const int64_t int32_count = + (sizeof(SplitKvExtraInfo) + sizeof(int32_t) - 1) / sizeof(int32_t); + auto extra_tiling_cpu = + torch::empty({int32_count}, torch::TensorOptions().dtype(torch::kInt32)); + std::memcpy(extra_tiling_cpu.data_ptr(), + &extra_info, + sizeof(SplitKvExtraInfo)); + return extra_tiling_cpu.to(actual_q_lens.device(), /*non_blocking=*/false); +} + +} // namespace + +namespace xllm::kernel::npu { + +torch::Tensor x_flash_attention_infer(const torch::Tensor& query, + const torch::Tensor& key_cache, + const torch::Tensor& value_cache, + const torch::Tensor& mask, + const torch::Tensor& block_table, + const torch::Tensor& actual_q_lens, + const torch::Tensor& actual_kv_lens, + int64_t q_head, + int64_t kv_head, + double scale, + const std::string& layout) { + check_tensor(query, "query", "x_flash_attention_infer"); + check_tensor(key_cache, "key_cache", "x_flash_attention_infer"); + check_tensor(value_cache, "value_cache", "x_flash_attention_infer"); + check_tensor(mask, "mask", "x_flash_attention_infer"); + check_tensor(block_table, "block_table", "x_flash_attention_infer"); + check_tensor(actual_q_lens, "actual_q_lens", "x_flash_attention_infer"); + check_tensor(actual_kv_lens, "actual_kv_lens", "x_flash_attention_infer"); + CHECK_EQ(query.dim(), 3) << "query must be [tokens, q_heads, head_size]"; + CHECK_EQ(key_cache.dim(), 4) + << "key_cache must be [blocks, block_size, kv_heads, head_size]"; + CHECK_EQ(value_cache.dim(), 4) + << 
"value_cache must be [blocks, block_size, kv_heads, head_size]"; + CHECK_EQ(actual_q_lens.scalar_type(), torch::kInt32); + CHECK_EQ(actual_kv_lens.scalar_type(), torch::kInt32); + + torch::Tensor output = torch::empty_like(query); + torch::Tensor extra_tiling = make_extra_tiling(actual_q_lens, + actual_kv_lens, + q_head, + kv_head, + key_cache.size(1), + key_cache.size(3)); + std::string layout_attr = layout; + char* layout_attr_ptr = const_cast(layout_attr.c_str()); + + EXEC_NPU_CMD(aclnnXFlashAttentionInfer, + query, + key_cache, + value_cache, + mask, + block_table, + actual_q_lens, + actual_kv_lens, + extra_tiling, + layout_attr_ptr, + q_head, + kv_head, + scale, + output); + return output; +} + +} // namespace xllm::kernel::npu diff --git a/xllm/core/kernels/npu/xllm_ops/xllm_ops_api.h b/xllm/core/kernels/npu/xllm_ops/xllm_ops_api.h index 4bb6527b40..00ab5a9f4c 100644 --- a/xllm/core/kernels/npu/xllm_ops/xllm_ops_api.h +++ b/xllm/core/kernels/npu/xllm_ops/xllm_ops_api.h @@ -53,4 +53,16 @@ void select_unshared_kv(const torch::Tensor& beam_index, int64_t decode_step, int64_t beam_size, int64_t layer_num); + +torch::Tensor x_flash_attention_infer(const torch::Tensor& query, + const torch::Tensor& key_cache, + const torch::Tensor& value_cache, + const torch::Tensor& mask, + const torch::Tensor& block_table, + const torch::Tensor& actual_q_lens, + const torch::Tensor& actual_kv_lens, + int64_t q_head, + int64_t kv_head, + double scale, + const std::string& layout = "TND"); } // namespace xllm::kernel::npu diff --git a/xllm/core/layers/common/attention_metadata_builder.cpp b/xllm/core/layers/common/attention_metadata_builder.cpp index 8e65c0b81c..3e30f38675 100644 --- a/xllm/core/layers/common/attention_metadata_builder.cpp +++ b/xllm/core/layers/common/attention_metadata_builder.cpp @@ -100,13 +100,17 @@ AttentionMetadata build_attention_metadata( params.batch_forward_type.is_chunked_prefill(); attn_metadata.is_prefill = params.batch_forward_type.is_prefill(); +#if 
defined(USE_NPU) + attn_metadata.block_table = params.block_tables; +#else if (!attn_metadata.is_prefill || enable_mla) { attn_metadata.block_table = params.block_tables; -#if !defined(USE_NPU) && !defined(USE_CUDA) +#if !defined(USE_CUDA) attn_metadata.kv_seq_lens = torch::diff(params.kv_seq_lens); // kv seqlens attn_metadata.q_seq_lens = torch::diff(params.q_seq_lens); // q seqlens #endif } +#endif #if defined(USE_NPU) // NPU path uses per-sequence lengths (not cumulative), so no diff. // Ensure per-sequence lengths are available for NPU kernels in all phases. diff --git a/xllm/core/layers/npu_torch/attention.cpp b/xllm/core/layers/npu_torch/attention.cpp index 95c9ff5a54..0a10cfa925 100644 --- a/xllm/core/layers/npu_torch/attention.cpp +++ b/xllm/core/layers/npu_torch/attention.cpp @@ -22,6 +22,7 @@ limitations under the License. #include #include "kernels/npu/npu_ops_api.h" +#include "kernels/npu/xllm_ops/xllm_ops_api.h" #include "kernels/ops_api.h" DECLARE_bool(enable_chunked_prefill); @@ -31,25 +32,31 @@ namespace layer { namespace { constexpr int64_t kFiaSplitFuseMaskSize = 2048; +constexpr int64_t kXfaMaxQkRowsPerCall = 128; -std::vector cumulative_lengths(const std::vector& seq_lens) { +torch::Tensor cumulative_lengths_tensor(const torch::Tensor& seq_lens) { + return torch::cumsum(seq_lens, 0).to(torch::kInt32); +} + +std::vector cumulative_lengths(const torch::Tensor& seq_lens) { + auto seq_lens_cpu = + seq_lens.to(torch::kCPU, torch::kInt32, /*non_blocking=*/false) + .contiguous(); + const auto* seq_lens_ptr = seq_lens_cpu.data_ptr(); std::vector cu_lens; - cu_lens.reserve(seq_lens.size()); + cu_lens.reserve(seq_lens_cpu.numel()); int64_t total = 0; - for (int32_t seq_len : seq_lens) { - total += seq_len; + for (int64_t i = 0; i < seq_lens_cpu.numel(); ++i) { + total += seq_lens_ptr[i]; cu_lens.emplace_back(total); } return cu_lens; } -std::vector to_i64_vector(const std::vector& values) { - std::vector out; - out.reserve(values.size()); - for (int32_t 
value : values) { - out.emplace_back(value); - } - return out; +torch::Tensor int32_tensor_on_device(const std::vector& values, + const torch::Device& device) { + return torch::tensor(values, torch::TensorOptions().dtype(torch::kInt32)) + .to(device); } torch::Tensor get_fia_split_fuse_attn_mask(const torch::Tensor& query) { @@ -75,6 +82,73 @@ torch::Tensor get_fia_split_fuse_attn_mask(const torch::Tensor& query) { return mask; } +torch::Tensor x_flash_attention_infer_with_query_split( + const torch::Tensor& query, + const torch::Tensor& key_cache, + const torch::Tensor& value_cache, + const torch::Tensor& mask, + const torch::Tensor& block_table, + const torch::Tensor& q_seq_lens, + const torch::Tensor& kv_seq_lens, + int64_t q_head, + int64_t kv_head, + double scale) { + CHECK_GT(kv_head, 0); + CHECK_EQ(q_head % kv_head, 0); + const int64_t group_size = q_head / kv_head; + const int64_t max_q_len_per_call = + std::max(1, kXfaMaxQkRowsPerCall / group_size); + + auto q_seq_lens_cpu = + q_seq_lens.to(torch::kCPU, torch::kInt32, /*non_blocking=*/false) + .contiguous(); + auto kv_seq_lens_cpu = + kv_seq_lens.to(torch::kCPU, torch::kInt32, /*non_blocking=*/false) + .contiguous(); + CHECK_EQ(q_seq_lens_cpu.numel(), kv_seq_lens_cpu.numel()); + const auto* q_seq_lens_ptr = q_seq_lens_cpu.data_ptr(); + const auto* kv_seq_lens_ptr = kv_seq_lens_cpu.data_ptr(); + const int64_t batch = q_seq_lens_cpu.numel(); + + auto output = torch::empty_like(query); + int64_t q_offset = 0; + const auto device = query.device(); + for (int64_t seq_idx = 0; seq_idx < batch; ++seq_idx) { + const int64_t q_len = q_seq_lens_ptr[seq_idx]; + const int64_t kv_len = kv_seq_lens_ptr[seq_idx]; + const int64_t past_kv_len = kv_len - q_len; + CHECK_GE(past_kv_len, 0); + + for (int64_t q_start = 0; q_start < q_len; q_start += max_q_len_per_call) { + const int64_t sub_q_len = + std::min(max_q_len_per_call, q_len - q_start); + auto sub_query = + query.narrow(0, q_offset + q_start, sub_q_len).contiguous(); 
+ auto sub_block_table = block_table.narrow(0, seq_idx, 1).contiguous(); + auto sub_q_lens = + int32_tensor_on_device({static_cast(sub_q_len)}, device); + auto sub_kv_lens = int32_tensor_on_device( + {static_cast(past_kv_len + q_start + sub_q_len)}, device); + + auto sub_output = + xllm::kernel::npu::x_flash_attention_infer(sub_query, + key_cache, + value_cache, + mask, + sub_block_table, + sub_q_lens, + sub_kv_lens, + q_head, + kv_head, + scale, + "TND"); + output.narrow(0, q_offset + q_start, sub_q_len).copy_(sub_output); + } + q_offset += q_len; + } + return output; +} + } // namespace AttentionImpl::AttentionImpl(int64_t num_heads, @@ -142,47 +216,31 @@ void AttentionImpl::prefill_forward(torch::Tensor& query, output = output.view({-1, num_heads_, head_size_}); if (attn_metadata.is_prefill) { - key = key.view({-1, num_kv_heads_, head_size_}); - value = value.view({-1, num_kv_heads_, head_size_}); - - auto fia_result = xllm::kernel::npu::npu_fused_infer_attention( + auto xfa_result = x_flash_attention_infer_with_query_split( query, - key, - value, + k_cache, + v_cache.value(), get_fia_split_fuse_attn_mask(query), - std::nullopt, - cumulative_lengths(attn_metadata.q_seq_lens_vec), - cumulative_lengths(attn_metadata.kv_seq_lens_vec), + attn_metadata.block_table, + attn_metadata.q_seq_lens, + attn_metadata.kv_seq_lens, num_heads_, num_kv_heads_, - scale_, - 0, - 3, - "TND"); - output.copy_(std::get<0>(fia_result).view_as(output)); + scale_); + output.copy_(xfa_result.view_as(output)); } else if (attn_metadata.is_chunked_prefill) { - auto q_seq_lens = cumulative_lengths(attn_metadata.q_seq_lens_vec); - auto kv_seq_lens = to_i64_vector(attn_metadata.kv_seq_lens_vec); - auto k = k_cache.view({k_cache.size(0), k_cache.size(1), -1}); - auto v = v_cache.value().view( - {v_cache.value().size(0), v_cache.value().size(1), -1}); - auto fia_result = xllm::kernel::npu::npu_fused_infer_attention( + auto xfa_result = x_flash_attention_infer_with_query_split( query, - k, - v, + 
k_cache, + v_cache.value(), get_fia_split_fuse_attn_mask(query), - attn_metadata.block_table.defined() - ? std::make_optional(attn_metadata.block_table) - : std::nullopt, - q_seq_lens, - kv_seq_lens, + attn_metadata.block_table, + attn_metadata.q_seq_lens, + attn_metadata.kv_seq_lens, num_heads_, num_kv_heads_, - scale_, - k_cache.size(1), - 3, - "TND"); - output.copy_(std::get<0>(fia_result).view_as(output)); + scale_); + output.copy_(xfa_result.view_as(output)); } } diff --git a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp index ddfcc40eb2..855f6f788f 100644 --- a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp +++ b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp @@ -328,11 +328,14 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward( const AttentionMetadata& attn_metadata, KVCache& kv_cache, const ModelInputParams& input_params) { + const bool is_prefill_like = + attn_metadata.is_prefill || attn_metadata.is_chunked_prefill; + torch::Tensor qkvz_flat; torch::Tensor ba_flat; int64_t batch_size = 0; int64_t seq_len = 0; - if (attn_metadata.is_prefill) { + if (is_prefill_like) { std::tie(qkvz_flat, ba_flat) = project_flat_inputs(hidden_states); batch_size = 1; seq_len = qkvz_flat.size(0); @@ -355,7 +358,7 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward( std::tie(mixed_qkv, z, b, a) = xllm::kernel::fused_qkvzba_split_reshape_cat(fused_params); - if (!attn_metadata.is_prefill) { + if (!is_prefill_like) { mixed_qkv = mixed_qkv.view({batch_size, seq_len, mixed_qkv.size(-1)}); } z = z.view({batch_size, seq_len, num_v_heads_ / tp_size_, head_v_dim_}); @@ -369,7 +372,7 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward( torch::Tensor linear_state_indices = get_linear_state_indices(input_params, mixed_qkv.device()); - if (attn_metadata.is_prefill) { + if (is_prefill_like) { std::vector linear_state_indices_vec( input_params.linear_state_ids.begin(), 
+ input_params.linear_state_ids.end()); @@ -403,7 +406,7 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward( mixed_qkv = mixed_qkv.view({batch_size, -1, mixed_qkv.size(-1)}).contiguous(); mixed_qkv = mixed_qkv.transpose(1, 2); // Compute gated delta net decay and beta terms. - if (attn_metadata.is_prefill) { + if (is_prefill_like) { xllm::kernel::FusedGdnGatingParams gdn_params; gdn_params.A_log = A_log_; gdn_params.a = a.contiguous().view({-1, a.size(-1)}); @@ -426,7 +429,7 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward( } auto [processed_q, processed_k, processed_v] = process_mixed_qkv(mixed_qkv); // Apply chunked or recurrent gated-delta attention and update caches. - if (attn_metadata.is_prefill) { + if (is_prefill_like) { xllm::kernel::ChunkGatedDeltaRuleParams chunk_gated_delta_params; chunk_gated_delta_params.q = processed_q; chunk_gated_delta_params.k = processed_k; @@ -436,7 +439,9 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward( // Get initial state from ssm_cache for sequences with previous state // Shape: [batch_size, num_heads, head_k_dim, head_v_dim] torch::Tensor initial_state_tensor = - torch::index_select(ssm_cache, 0, linear_state_indices); + torch::index_select(ssm_cache, 0, linear_state_indices) + .transpose(-1, -2) + .contiguous(); CHECK_EQ(input_params.has_initial_state.size(), input_params.linear_state_ids.size()) << "has_initial_state must be sequence-scoped."; From 39f5b84c412deba12c054ea57de01f134a984d6f Mon Sep 17 00:00:00 2001 From: sinle4cat Date: Thu, 14 May 2026 10:56:15 +0800 Subject: [PATCH 3/7] bugfix: x flash attention long kv x mte exception --- third_party/xllm_ops | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xllm_ops b/third_party/xllm_ops index 0b26ff7f7a..c8b7e7aba1 160000 --- a/third_party/xllm_ops +++ b/third_party/xllm_ops @@ -1 +1 @@ -Subproject commit 0b26ff7f7aaed1593f8e4cfbbacd37be74c18df2 +Subproject commit c8b7e7aba16dff0e1884cc60df118d24e56df5b2 From
b02e42f71e046c0a23c4f8d6c4f693bd2c1a6705 Mon Sep 17 00:00:00 2001 From: Sinle4Cat Date: Fri, 15 May 2026 14:39:32 +0800 Subject: [PATCH 4/7] bugfix: route chunked prefill to x flash attention --- third_party/xllm_ops | 2 +- xllm/core/kernels/npu/xllm_ops/beam_search_rec.cpp | 4 ++++ xllm/core/layers/npu_torch/attention.cpp | 13 +------------ 3 files changed, 6 insertions(+), 13 deletions(-) diff --git a/third_party/xllm_ops b/third_party/xllm_ops index c8b7e7aba1..bae5abc680 160000 --- a/third_party/xllm_ops +++ b/third_party/xllm_ops @@ -1 +1 @@ -Subproject commit c8b7e7aba16dff0e1884cc60df118d24e56df5b2 +Subproject commit bae5abc680989199f723a774f5843534a81fb7e4 diff --git a/xllm/core/kernels/npu/xllm_ops/beam_search_rec.cpp b/xllm/core/kernels/npu/xllm_ops/beam_search_rec.cpp index cec49c2142..2768d08339 100644 --- a/xllm/core/kernels/npu/xllm_ops/beam_search_rec.cpp +++ b/xllm/core/kernels/npu/xllm_ops/beam_search_rec.cpp @@ -83,6 +83,9 @@ void beam_search_rec(const torch::Tensor& logprobs, create_acltensor(&out_beam_count_prefix_sums_ids, out_beam_count_prefix_sums); create_acltensor(&out_sequence_ids, out_sequence); + CHECK_GT(top_tokens.dim(), 0) + << "beam_search_rec: top_tokens must have at least one dimension"; + const int64_t top_k = top_tokens.size(-1); uint64_t workspace_size = 0; aclOpExecutor* executor = nullptr; CHECK_ACL_SUCCESS( @@ -91,6 +94,7 @@ void beam_search_rec(const torch::Tensor& logprobs, top_logprobs_ids, sequence_group_ids, current_step, + top_k, out_token_ids_ids, out_token_index_ids, out_log_probs_ids, diff --git a/xllm/core/layers/npu_torch/attention.cpp b/xllm/core/layers/npu_torch/attention.cpp index 51d529628e..b3d7930b4d 100644 --- a/xllm/core/layers/npu_torch/attention.cpp +++ b/xllm/core/layers/npu_torch/attention.cpp @@ -198,7 +198,7 @@ void AttentionImpl::prefill_forward(torch::Tensor& query, query = query.view({-1, num_heads_, head_size_}); output = output.view({-1, num_heads_, head_size_}); - if (attn_metadata.is_prefill) { 
+ if (attn_metadata.is_prefill || attn_metadata.is_chunked_prefill) { auto xfa_result = x_flash_attention_infer_with_query_split( query, k_cache, @@ -211,17 +211,6 @@ void AttentionImpl::prefill_forward(torch::Tensor& query, num_kv_heads_, scale_); output.copy_(xfa_result.view_as(output)); - } else if (attn_metadata.is_chunked_prefill) { - xllm::kernel::npu::batch_chunked_paged_prefill( - query, - k_cache, - v_cache.value(), - scale_, - attn_metadata.block_table, - attn_metadata.kv_seq_lens_host, - attn_metadata.attn_mask, - attn_metadata.q_seq_lens_host, - output); } } From fa77f62c8a59a27f1ccad096e6bf2c7bdba483ec Mon Sep 17 00:00:00 2001 From: Sinle4Cat Date: Sat, 16 May 2026 09:59:07 +0800 Subject: [PATCH 5/7] bugfix: improve x flash attention long kv stability --- xllm/core/layers/npu_torch/attention.cpp | 130 ++++++++++++++---- .../npu_torch/qwen3_gated_delta_net_base.cpp | 2 +- 2 files changed, 106 insertions(+), 26 deletions(-) diff --git a/xllm/core/layers/npu_torch/attention.cpp b/xllm/core/layers/npu_torch/attention.cpp index b3d7930b4d..1b81d4ae2a 100644 --- a/xllm/core/layers/npu_torch/attention.cpp +++ b/xllm/core/layers/npu_torch/attention.cpp @@ -33,6 +33,16 @@ namespace { constexpr int64_t kFiaSplitFuseMaskSize = 2048; constexpr int64_t kXfaMaxQkRowsPerCall = 128; +constexpr int64_t kXfaMaxExtraInfoNodes = 24; +constexpr int64_t kXfaMaxKvStackLen = 512; + +struct XfaQuerySlice { + int64_t seq_idx = 0; + int64_t q_start = 0; + int64_t q_len = 0; + int64_t kv_len = 0; + int64_t core_count = 1; +}; torch::Tensor int32_tensor_on_device(const std::vector& values, const torch::Device& device) { @@ -40,6 +50,26 @@ torch::Tensor int32_tensor_on_device(const std::vector& values, .to(device); } +torch::Tensor int64_tensor_on_device(const std::vector& values, + const torch::Device& device) { + return torch::tensor(values, torch::TensorOptions().dtype(torch::kInt64)) + .to(device); +} + +int64_t div_up(int64_t value, int64_t divisor) { + return (value + divisor 
- 1) / divisor; +} + +int64_t xfa_core_count_for_slice(int64_t kv_len, + int64_t kv_head, + int64_t block_size) { + const int64_t block_stack_num = + std::max(1, kXfaMaxKvStackLen / block_size); + const int64_t kv_blocks = div_up(kv_len, block_size); + const int64_t s2_blocks = div_up(kv_blocks, block_stack_num); + return s2_blocks <= 1 ? 1 : kv_head; +} + torch::Tensor get_fia_split_fuse_attn_mask(const torch::Tensor& query) { static std::mutex mutex; static std::unordered_map mask_cache; @@ -71,6 +101,8 @@ torch::Tensor x_flash_attention_infer_with_query_split( const torch::Tensor& block_table, const torch::Tensor& q_seq_lens, const torch::Tensor& kv_seq_lens, + const torch::Tensor& q_seq_lens_host, + const torch::Tensor& kv_seq_lens_host, int64_t q_head, int64_t kv_head, double scale) { @@ -81,11 +113,19 @@ torch::Tensor x_flash_attention_infer_with_query_split( std::max(1, kXfaMaxQkRowsPerCall / group_size); auto q_seq_lens_cpu = - q_seq_lens.to(torch::kCPU, torch::kInt32, /*non_blocking=*/false) - .contiguous(); + q_seq_lens_host.defined() + ? q_seq_lens_host + .to(torch::kCPU, torch::kInt32, /*non_blocking=*/false) + .contiguous() + : q_seq_lens.to(torch::kCPU, torch::kInt32, /*non_blocking=*/false) + .contiguous(); auto kv_seq_lens_cpu = - kv_seq_lens.to(torch::kCPU, torch::kInt32, /*non_blocking=*/false) - .contiguous(); + kv_seq_lens_host.defined() + ? 
kv_seq_lens_host + .to(torch::kCPU, torch::kInt32, /*non_blocking=*/false) + .contiguous() + : kv_seq_lens.to(torch::kCPU, torch::kInt32, /*non_blocking=*/false) + .contiguous(); CHECK_EQ(q_seq_lens_cpu.numel(), kv_seq_lens_cpu.numel()); const auto* q_seq_lens_ptr = q_seq_lens_cpu.data_ptr(); const auto* kv_seq_lens_ptr = kv_seq_lens_cpu.data_ptr(); @@ -93,6 +133,7 @@ torch::Tensor x_flash_attention_infer_with_query_split( auto output = torch::empty_like(query); int64_t q_offset = 0; + std::vector slices; const auto device = query.device(); for (int64_t seq_idx = 0; seq_idx < batch; ++seq_idx) { const int64_t q_len = q_seq_lens_ptr[seq_idx]; @@ -103,30 +144,67 @@ torch::Tensor x_flash_attention_infer_with_query_split( for (int64_t q_start = 0; q_start < q_len; q_start += max_q_len_per_call) { const int64_t sub_q_len = std::min(max_q_len_per_call, q_len - q_start); - auto sub_query = - query.narrow(0, q_offset + q_start, sub_q_len).contiguous(); - auto sub_block_table = block_table.narrow(0, seq_idx, 1).contiguous(); - auto sub_q_lens = - int32_tensor_on_device({static_cast(sub_q_len)}, device); - auto sub_kv_lens = int32_tensor_on_device( - {static_cast(past_kv_len + q_start + sub_q_len)}, device); - - auto sub_output = - xllm::kernel::npu::x_flash_attention_infer(sub_query, - key_cache, - value_cache, - mask, - sub_block_table, - sub_q_lens, - sub_kv_lens, - q_head, - kv_head, - scale, - "TND"); - output.narrow(0, q_offset + q_start, sub_q_len).copy_(sub_output); + const int64_t sub_kv_len = past_kv_len + q_start + sub_q_len; + const int64_t core_count = + xfa_core_count_for_slice(sub_kv_len, kv_head, key_cache.size(1)); + CHECK_LE(core_count, kXfaMaxExtraInfoNodes) + << "x_flash_attention_infer query slice needs too many cores"; + slices.push_back( + {seq_idx, q_offset + q_start, sub_q_len, sub_kv_len, core_count}); } q_offset += q_len; } + + for (size_t group_start = 0; group_start < slices.size();) { + size_t group_end = group_start; + int64_t core_count = 0; + 
while (group_end < slices.size() && + core_count + slices[group_end].core_count <= kXfaMaxExtraInfoNodes) { + core_count += slices[group_end].core_count; + ++group_end; + } + + const int64_t q_start = slices[group_start].q_start; + const int64_t q_end = + slices[group_end - 1].q_start + slices[group_end - 1].q_len; + auto group_query = query.narrow(0, q_start, q_end - q_start).contiguous(); + + std::vector group_q_lens; + std::vector group_kv_lens; + std::vector block_indices; + group_q_lens.reserve(group_end - group_start); + group_kv_lens.reserve(group_end - group_start); + block_indices.reserve(group_end - group_start); + int32_t q_cu_len = 0; + for (size_t idx = group_start; idx < group_end; ++idx) { + q_cu_len += static_cast(slices[idx].q_len); + group_q_lens.push_back(q_cu_len); + group_kv_lens.push_back(static_cast(slices[idx].kv_len)); + block_indices.push_back(slices[idx].seq_idx); + } + + auto group_block_table = + block_table + .index_select(0, int64_tensor_on_device(block_indices, device)) + .contiguous(); + auto group_q_lens_tensor = int32_tensor_on_device(group_q_lens, device); + auto group_kv_lens_tensor = int32_tensor_on_device(group_kv_lens, device); + + auto group_output = + xllm::kernel::npu::x_flash_attention_infer(group_query, + key_cache, + value_cache, + mask, + group_block_table, + group_q_lens_tensor, + group_kv_lens_tensor, + q_head, + kv_head, + scale, + "TND"); + output.narrow(0, q_start, q_end - q_start).copy_(group_output); + group_start = group_end; + } return output; } @@ -207,6 +285,8 @@ void AttentionImpl::prefill_forward(torch::Tensor& query, attn_metadata.block_table, attn_metadata.q_seq_lens, attn_metadata.kv_seq_lens, + attn_metadata.q_seq_lens_host, + attn_metadata.kv_seq_lens_host, num_heads_, num_kv_heads_, scale_); diff --git a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp index a7a7b1ab9d..d2b0cc0517 100644 --- 
a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp +++ b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp @@ -586,7 +586,7 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward( conv_weight.transpose(0, 1).contiguous(); xllm::kernel::CausalConv1dUpdateParams conv1d_params; conv1d_params.x = mixed_qkv.reshape({-1, mixed_qkv.size(-1)}); - conv1d_params.conv_state = conv_cache; + conv1d_params.conv_state = conv_cache.transpose(1, 2); conv1d_params.weight = conv_weight_for_update; conv1d_params.conv_state_indices = logical_state_indices; conv1d_params.block_idx_last_scheduled_token = From 6d7cdfb5e88e529896ef8c85a7da19b2c217695f Mon Sep 17 00:00:00 2001 From: Sinle4Cat Date: Wed, 20 May 2026 18:30:16 +0800 Subject: [PATCH 6/7] bugfix: fix long kv xflash prefill grouping --- .../core/framework/model/model_input_params.h | 9 + .../npu/xllm_ops/x_flash_attention_infer.cpp | 149 ++++++- xllm/core/kernels/npu/xllm_ops/xllm_ops_api.h | 27 ++ xllm/core/layers/common/attention_metadata.h | 5 + .../common/attention_metadata_builder.cpp | 6 + xllm/core/layers/npu_torch/attention.cpp | 248 ++++++++++-- .../npu_torch/qwen3_gated_delta_net_base.cpp | 382 ++++++++++++++---- xllm/core/runtime/acl_graph_executor_impl.cpp | 118 +++++- xllm/core/runtime/acl_graph_executor_impl.h | 11 + 9 files changed, 826 insertions(+), 129 deletions(-) diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h index 683617fb3e..1a72e380d5 100644 --- a/xllm/core/framework/model/model_input_params.h +++ b/xllm/core/framework/model/model_input_params.h @@ -420,6 +420,12 @@ struct ModelInputParams { safe_to(graph_buffer.attn_mask, device, true); params.graph_buffer.tiling_data = safe_to(graph_buffer.tiling_data, device, true); + params.graph_buffer.xfa_q_cu_seq_lens = + safe_to(graph_buffer.xfa_q_cu_seq_lens, device, true); + params.graph_buffer.xfa_extra_tiling = + safe_to(graph_buffer.xfa_extra_tiling, device, true); + 
params.graph_buffer.xfa_attn_mask = + safe_to(graph_buffer.xfa_attn_mask, device, true); // params for flashinfer params.paged_kv_indptr = safe_to(paged_kv_indptr, device); @@ -666,6 +672,9 @@ struct ModelInputParams { struct GraphBuffer { torch::Tensor attn_mask; torch::Tensor tiling_data; + torch::Tensor xfa_q_cu_seq_lens; + torch::Tensor xfa_extra_tiling; + torch::Tensor xfa_attn_mask; bool use_expanded_decode_for_spec_verify_attention = false; torch::Tensor expanded_kv_seq_lens; torch::Tensor expanded_block_tables; diff --git a/xllm/core/kernels/npu/xllm_ops/x_flash_attention_infer.cpp b/xllm/core/kernels/npu/xllm_ops/x_flash_attention_infer.cpp index b7e3d105aa..8b9ee57fa7 100644 --- a/xllm/core/kernels/npu/xllm_ops/x_flash_attention_infer.cpp +++ b/xllm/core/kernels/npu/xllm_ops/x_flash_attention_infer.cpp @@ -88,7 +88,7 @@ torch::Tensor make_extra_tiling(const torch::Tensor& actual_q_lens, const auto* kv_lens = kv_lens_cpu.data_ptr(); const int64_t batch = actual_q_lens.numel(); - SplitKvExtraInfo extra_info; + SplitKvExtraInfo extra_info{}; for (uint32_t i = 0; i < kMaxExtraInfoNodes; ++i) { extra_info.core_info[i].start_b_idx = std::numeric_limits::max(); } @@ -154,6 +154,81 @@ torch::Tensor make_extra_tiling(const torch::Tensor& actual_q_lens, return extra_tiling_cpu.to(actual_q_lens.device(), /*non_blocking=*/false); } +torch::Tensor make_extra_tiling_from_host_lens( + const std::vector& actual_q_lens, + const std::vector& actual_kv_lens, + int64_t q_head, + int64_t kv_head, + int64_t block_size, + int64_t head_size) { + CHECK_EQ(actual_q_lens.size(), actual_kv_lens.size()) + << "actual_q_lens and actual_kv_lens must have the same length"; + CHECK_GT(actual_q_lens.size(), 0) << "batch size must be positive"; + CHECK_GT(q_head, 0) << "q_head must be positive"; + CHECK_GT(kv_head, 0) << "kv_head must be positive"; + CHECK_EQ(q_head % kv_head, 0) << "q_head must be divisible by kv_head"; + CHECK_GT(block_size, 0) << "block_size must be positive"; + 
CHECK_GT(head_size, 0) << "head_size must be positive"; + + SplitKvExtraInfo extra_info{}; + for (uint32_t i = 0; i < kMaxExtraInfoNodes; ++i) { + extra_info.core_info[i].start_b_idx = std::numeric_limits::max(); + } + + const uint32_t block_stack_num = + std::max(1, kMaxKvStackLen / static_cast(block_size)); + uint32_t core_idx = 0; + uint32_t split_node_idx = 0; + + for (int64_t batch_idx = 0; + batch_idx < static_cast(actual_q_lens.size()); + ++batch_idx) { + const uint32_t kv_seq_len = + static_cast(actual_kv_lens[batch_idx]); + const uint32_t kv_blocks = + (kv_seq_len + static_cast(block_size) - 1) / + static_cast(block_size); + const uint32_t s2_blocks = + (kv_blocks + block_stack_num - 1) / block_stack_num; + + if (s2_blocks <= 1) { + CHECK_LT(core_idx, kMaxExtraInfoNodes) + << "x_flash_attention_infer extra_tiling core_info overflow"; + auto& core = extra_info.core_info[core_idx++]; + core.start_b_idx = static_cast(batch_idx); + core.start_n1_idx = 0; + core.start_s2_idx = 0; + core.end_b_idx = static_cast(batch_idx); + core.end_n1_idx = static_cast(kv_head - 1); + core.end_s2_idx = s2_blocks; + continue; + } + + for (uint32_t kv_head_idx = 0; kv_head_idx < static_cast(kv_head); + ++kv_head_idx) { + CHECK_LT(core_idx, kMaxExtraInfoNodes) + << "x_flash_attention_infer extra_tiling core_info overflow"; + auto& core = extra_info.core_info[core_idx++]; + core.start_b_idx = static_cast(batch_idx); + core.start_n1_idx = kv_head_idx; + core.start_s2_idx = 0; + core.end_b_idx = static_cast(batch_idx); + core.end_n1_idx = kv_head_idx; + core.end_s2_idx = s2_blocks; + } + } + extra_info.total_split_node_num = split_node_idx; + + const int64_t int32_count = + (sizeof(SplitKvExtraInfo) + sizeof(int32_t) - 1) / sizeof(int32_t); + auto extra_tiling_cpu = + torch::empty({int32_count}, torch::TensorOptions().dtype(torch::kInt32)); + std::memcpy(extra_tiling_cpu.data_ptr(), + &extra_info, + sizeof(SplitKvExtraInfo)); + return extra_tiling_cpu; +} + } // namespace namespace 
xllm::kernel::npu { @@ -184,15 +259,60 @@ torch::Tensor x_flash_attention_infer(const torch::Tensor& query, CHECK_EQ(actual_q_lens.scalar_type(), torch::kInt32); CHECK_EQ(actual_kv_lens.scalar_type(), torch::kInt32); - torch::Tensor output = torch::empty_like(query); torch::Tensor extra_tiling = make_extra_tiling(actual_q_lens, actual_kv_lens, q_head, kv_head, key_cache.size(1), key_cache.size(3)); + return x_flash_attention_infer_with_extra_tiling(query, + key_cache, + value_cache, + mask, + block_table, + actual_q_lens, + actual_kv_lens, + extra_tiling, + q_head, + kv_head, + scale, + layout); +} + +torch::Tensor x_flash_attention_infer_with_extra_tiling( + const torch::Tensor& query, + const torch::Tensor& key_cache, + const torch::Tensor& value_cache, + const torch::Tensor& mask, + const torch::Tensor& block_table, + const torch::Tensor& actual_q_lens, + const torch::Tensor& actual_kv_lens, + const torch::Tensor& extra_tiling, + int64_t q_head, + int64_t kv_head, + double scale, + const std::string& layout) { + check_tensor(query, "query", "x_flash_attention_infer"); + check_tensor(key_cache, "key_cache", "x_flash_attention_infer"); + check_tensor(value_cache, "value_cache", "x_flash_attention_infer"); + check_tensor(mask, "mask", "x_flash_attention_infer"); + check_tensor(block_table, "block_table", "x_flash_attention_infer"); + check_tensor(actual_q_lens, "actual_q_lens", "x_flash_attention_infer"); + check_tensor(actual_kv_lens, "actual_kv_lens", "x_flash_attention_infer"); + check_tensor(extra_tiling, "extra_tiling", "x_flash_attention_infer"); + CHECK_EQ(query.dim(), 3) << "query must be [tokens, q_heads, head_size]"; + CHECK_EQ(key_cache.dim(), 4) + << "key_cache must be [blocks, block_size, kv_heads, head_size]"; + CHECK_EQ(value_cache.dim(), 4) + << "value_cache must be [blocks, block_size, kv_heads, head_size]"; + CHECK_EQ(actual_q_lens.scalar_type(), torch::kInt32); + CHECK_EQ(actual_kv_lens.scalar_type(), torch::kInt32); + 
CHECK_EQ(extra_tiling.scalar_type(), torch::kInt32); + CHECK_GE(extra_tiling.numel(), x_flash_attention_extra_tiling_int32_count()); + std::string layout_attr = layout; char* layout_attr_ptr = const_cast(layout_attr.c_str()); + torch::Tensor output = torch::empty(query.sizes(), query.options()); EXEC_NPU_CMD(aclnnXFlashAttentionInfer, query, @@ -211,4 +331,29 @@ torch::Tensor x_flash_attention_infer(const torch::Tensor& query, return output; } +int64_t x_flash_attention_extra_tiling_int32_count() { + return (sizeof(SplitKvExtraInfo) + sizeof(int32_t) - 1) / sizeof(int32_t); +} + +int64_t x_flash_attention_decode_group_size(int64_t kv_head) { + CHECK_GT(kv_head, 0); + return kMaxExtraInfoNodes / kv_head; +} + +void update_x_flash_attention_extra_tiling( + const std::vector& actual_q_lens, + const std::vector& actual_kv_lens, + int64_t q_head, + int64_t kv_head, + int64_t block_size, + int64_t head_size, + torch::Tensor& extra_tiling) { + auto next_extra_tiling = make_extra_tiling_from_host_lens( + actual_q_lens, actual_kv_lens, q_head, kv_head, block_size, head_size); + CHECK_GE(extra_tiling.numel(), next_extra_tiling.numel()) + << "x_flash_attention_infer persistent extra_tiling is too small"; + extra_tiling.slice(/*dim=*/0, /*start=*/0, /*end=*/next_extra_tiling.numel()) + .copy_(next_extra_tiling, /*non_blocking=*/false); +} + } // namespace xllm::kernel::npu diff --git a/xllm/core/kernels/npu/xllm_ops/xllm_ops_api.h b/xllm/core/kernels/npu/xllm_ops/xllm_ops_api.h index 00ab5a9f4c..ae07eb61a5 100644 --- a/xllm/core/kernels/npu/xllm_ops/xllm_ops_api.h +++ b/xllm/core/kernels/npu/xllm_ops/xllm_ops_api.h @@ -65,4 +65,31 @@ torch::Tensor x_flash_attention_infer(const torch::Tensor& query, int64_t kv_head, double scale, const std::string& layout = "TND"); + +torch::Tensor x_flash_attention_infer_with_extra_tiling( + const torch::Tensor& query, + const torch::Tensor& key_cache, + const torch::Tensor& value_cache, + const torch::Tensor& mask, + const torch::Tensor& 
block_table, + const torch::Tensor& actual_q_lens, + const torch::Tensor& actual_kv_lens, + const torch::Tensor& extra_tiling, + int64_t q_head, + int64_t kv_head, + double scale, + const std::string& layout = "TND"); + +int64_t x_flash_attention_extra_tiling_int32_count(); + +int64_t x_flash_attention_decode_group_size(int64_t kv_head); + +void update_x_flash_attention_extra_tiling( + const std::vector& actual_q_lens, + const std::vector& actual_kv_lens, + int64_t q_head, + int64_t kv_head, + int64_t block_size, + int64_t head_size, + torch::Tensor& extra_tiling); } // namespace xllm::kernel::npu diff --git a/xllm/core/layers/common/attention_metadata.h b/xllm/core/layers/common/attention_metadata.h index 85e1d027ab..4c8fc61b44 100644 --- a/xllm/core/layers/common/attention_metadata.h +++ b/xllm/core/layers/common/attention_metadata.h @@ -162,6 +162,11 @@ struct AttentionMetadata { // For ACL graph execution - fixed-address device tiling data for // CustomPagedAttention replay. torch::Tensor paged_attention_tiling_data; + // For ACL graph execution with x_flash_attention_infer. The tensor address is + // fixed across replay and its contents are updated before graph replay. + torch::Tensor xfa_q_cu_seq_lens; + torch::Tensor xfa_extra_tiling; + torch::Tensor xfa_attn_mask; // Pre-computed attention mask for npu_fused_infer_attention. torch::Tensor fia_attn_mask; // Host vectors for npu_fused_infer_attention (kernel requires host memory). 
diff --git a/xllm/core/layers/common/attention_metadata_builder.cpp b/xllm/core/layers/common/attention_metadata_builder.cpp index 7b1b9aabb4..216d67430e 100644 --- a/xllm/core/layers/common/attention_metadata_builder.cpp +++ b/xllm/core/layers/common/attention_metadata_builder.cpp @@ -89,6 +89,9 @@ AttentionMetadata build_attention_metadata( params.graph_buffer.expanded_block_tables; attn_metadata.expanded_paged_attention_tiling_data = params.graph_buffer.expanded_tiling_data; + attn_metadata.xfa_q_cu_seq_lens = params.graph_buffer.xfa_q_cu_seq_lens; + attn_metadata.xfa_extra_tiling = params.graph_buffer.xfa_extra_tiling; + attn_metadata.xfa_attn_mask = params.graph_buffer.xfa_attn_mask; if (!params.graph_buffer.expanded_kv_seq_lens_vec.empty()) { attn_metadata.expanded_kv_seq_lens_host = torch::tensor( params.graph_buffer.expanded_kv_seq_lens_vec, torch::kInt); @@ -106,6 +109,9 @@ AttentionMetadata build_attention_metadata( if (use_acl_graph) { // ACL graph mode: use CustomPagedAttention with tiling_data on device attn_metadata.paged_attention_tiling_data = params.graph_buffer.tiling_data; + attn_metadata.xfa_q_cu_seq_lens = params.graph_buffer.xfa_q_cu_seq_lens; + attn_metadata.xfa_extra_tiling = params.graph_buffer.xfa_extra_tiling; + attn_metadata.xfa_attn_mask = params.graph_buffer.xfa_attn_mask; } // Provide capture-time host seq_lens for NPU kernels. 
ACL graph replay must // rely on fixed-address device inputs such as tiling_data, not mutable host diff --git a/xllm/core/layers/npu_torch/attention.cpp b/xllm/core/layers/npu_torch/attention.cpp index 32eabe6190..1b2ed98010 100644 --- a/xllm/core/layers/npu_torch/attention.cpp +++ b/xllm/core/layers/npu_torch/attention.cpp @@ -103,6 +103,24 @@ bool use_x_flash_decode() { return !env_flag_enabled("XLLM_DISABLE_XFLASH_DECODE"); } +bool use_x_flash_prefill() { + return !env_flag_enabled("XLLM_DISABLE_XFLASH_PREFILL"); +} + +bool use_x_flash_prefill_grouping() { + return !env_flag_enabled("XLLM_DISABLE_XFLASH_PREFILL_GROUPING"); +} + +bool allow_long_kv_x_flash_prefill_grouping() { + return env_flag_enabled("XLLM_ENABLE_XFLASH_LONG_PREFILL_GROUPING"); +} + +bool can_group_x_flash_prefill_slice(const XfaQuerySlice& slice) { + // Long-KV multi-core grouping must stay behind an explicit validation flag. + return use_x_flash_prefill_grouping() && + (slice.core_count == 1 || allow_long_kv_x_flash_prefill_grouping()); +} + torch::Tensor x_flash_attention_infer_with_query_split( const torch::Tensor& query, const torch::Tensor& key_cache, @@ -165,33 +183,36 @@ torch::Tensor x_flash_attention_infer_with_query_split( q_offset += q_len; } - for (size_t group_start = 0; group_start < slices.size();) { - size_t group_end = group_start; + std::vector consumed(slices.size(), false); + auto has_seq_in_group = [&](const std::vector& group_indices, + int64_t seq_idx) { + return std::any_of( + group_indices.begin(), group_indices.end(), [&](size_t idx) { + return slices[idx].seq_idx == seq_idx; + }); + }; + auto run_group = [&](const std::vector& group_indices) { + CHECK(!group_indices.empty()); int64_t core_count = 0; - while (group_end < slices.size() && - core_count + slices[group_end].core_count <= kXfaMaxExtraInfoNodes) { - core_count += slices[group_end].core_count; - ++group_end; - } - - const int64_t q_start = slices[group_start].q_start; - const int64_t q_end = - 
slices[group_end - 1].q_start + slices[group_end - 1].q_len; - auto group_query = query.narrow(0, q_start, q_end - q_start).contiguous(); - + std::vector query_parts; std::vector group_q_lens; std::vector group_kv_lens; std::vector block_indices; - group_q_lens.reserve(group_end - group_start); - group_kv_lens.reserve(group_end - group_start); - block_indices.reserve(group_end - group_start); + query_parts.reserve(group_indices.size()); + group_q_lens.reserve(group_indices.size()); + group_kv_lens.reserve(group_indices.size()); + block_indices.reserve(group_indices.size()); int32_t q_cu_len = 0; - for (size_t idx = group_start; idx < group_end; ++idx) { - q_cu_len += static_cast(slices[idx].q_len); + for (size_t idx : group_indices) { + const auto& slice = slices[idx]; + core_count += slice.core_count; + query_parts.push_back(query.narrow(0, slice.q_start, slice.q_len)); + q_cu_len += static_cast(slice.q_len); group_q_lens.push_back(q_cu_len); - group_kv_lens.push_back(static_cast(slices[idx].kv_len)); - block_indices.push_back(slices[idx].seq_idx); + group_kv_lens.push_back(static_cast(slice.kv_len)); + block_indices.push_back(slice.seq_idx); } + auto group_query = torch::cat(query_parts, /*dim=*/0).contiguous(); auto group_block_table = block_table @@ -212,8 +233,63 @@ torch::Tensor x_flash_attention_infer_with_query_split( kv_head, scale, "TND"); - output.narrow(0, q_start, q_end - q_start).copy_(group_output); - group_start = group_end; + int64_t group_output_offset = 0; + for (size_t idx : group_indices) { + const auto& slice = slices[idx]; + output.narrow(0, slice.q_start, slice.q_len) + .copy_(group_output.narrow(0, group_output_offset, slice.q_len)); + group_output_offset += slice.q_len; + } + }; + + for (size_t group_start = 0; group_start < slices.size();) { + while (group_start < slices.size() && consumed[group_start]) { + ++group_start; + } + if (group_start >= slices.size()) { + break; + } + + std::vector group_indices{group_start}; + int64_t 
core_count = slices[group_start].core_count; + const int64_t group_q_len = slices[group_start].q_len; + const int64_t group_kv_len = slices[group_start].kv_len; + const int64_t group_core_count = slices[group_start].core_count; + const bool groupable = can_group_x_flash_prefill_slice(slices[group_start]); + + if (groupable && group_core_count > 1) { + // Long-KV multi-core grouping is only safe across distinct requests. + for (size_t idx = group_start + 1; idx < slices.size(); ++idx) { + if (consumed[idx] || slices[idx].q_len != group_q_len || + slices[idx].kv_len != group_kv_len || + slices[idx].core_count != group_core_count || + has_seq_in_group(group_indices, slices[idx].seq_idx)) { + continue; + } + if (core_count + slices[idx].core_count > kXfaMaxExtraInfoNodes) { + break; + } + core_count += slices[idx].core_count; + group_indices.push_back(idx); + } + } else { + size_t group_end = group_start + 1; + while (groupable && group_end < slices.size() && !consumed[group_end] && + slices[group_end].q_len == group_q_len && + slices[group_end].core_count == group_core_count && + core_count + slices[group_end].core_count <= + kXfaMaxExtraInfoNodes) { + core_count += slices[group_end].core_count; + group_indices.push_back(group_end); + ++group_end; + } + } + + run_group(group_indices); + for (size_t idx : group_indices) { + consumed[idx] = true; + } + ++group_start; } return output; } @@ -292,6 +368,48 @@ torch::Tensor x_flash_attention_infer_decode_with_batch_split( return output; } +torch::Tensor x_flash_attention_infer_decode_for_graph( + const torch::Tensor& query, + const torch::Tensor& key_cache, + const torch::Tensor& value_cache, + const torch::Tensor& mask, + const torch::Tensor& block_table, + const torch::Tensor& q_cu_seq_lens, + const torch::Tensor& kv_seq_lens, + const torch::Tensor& extra_tiling, + int64_t q_head, + int64_t kv_head, + double scale) { + CHECK_EQ(query.dim(), 3) + << "decode query must be [batch, q_heads, head_size]"; + 
CHECK(q_cu_seq_lens.defined()) + << "graph decode requires device q cumulative seq lens"; + CHECK(kv_seq_lens.defined()) << "graph decode requires device kv seq lens"; + CHECK(extra_tiling.defined()) << "graph decode requires fixed extra tiling"; + auto actual_q_lens = q_cu_seq_lens; + if (actual_q_lens.numel() == query.size(0) + 1) { + actual_q_lens = + actual_q_lens.narrow(/*dim=*/0, /*start=*/1, /*length=*/query.size(0)); + } + CHECK_EQ(actual_q_lens.numel(), query.size(0)); + CHECK_EQ(kv_seq_lens.numel(), query.size(0)); + CHECK_EQ(actual_q_lens.scalar_type(), torch::kInt32); + CHECK_EQ(kv_seq_lens.scalar_type(), torch::kInt32); + return xllm::kernel::npu::x_flash_attention_infer_with_extra_tiling( + query, + key_cache, + value_cache, + mask, + block_table, + actual_q_lens, + kv_seq_lens, + extra_tiling, + q_head, + kv_head, + scale, + "TND"); +} + } // namespace AttentionImpl::AttentionImpl(int64_t num_heads, @@ -360,7 +478,47 @@ void AttentionImpl::prefill_forward(torch::Tensor& query, query = query.view({-1, num_heads_, head_size_}); output = output.view({-1, num_heads_, head_size_}); + auto run_fia_prefill = [&]() { + if (attn_metadata.is_prefill) { + auto key_view = key.view({-1, num_kv_heads_, head_size_}); + auto value_view = value.view({-1, num_kv_heads_, head_size_}); + auto fallback_output = torch::empty_like(query); + xllm::kernel::npu::batch_prefill( + query, + key_view, + value_view, + attn_metadata.fia_attn_mask.defined() + ? attn_metadata.fia_attn_mask + : get_fia_split_fuse_attn_mask(query), + attn_metadata.q_seq_lens, + scale_, + fallback_output); + return fallback_output; + } + + auto fallback_output = torch::empty_like(query); + xllm::kernel::npu::batch_chunked_paged_prefill( + query, + k_cache, + v_cache.value(), + scale_, + attn_metadata.block_table, + attn_metadata.kv_seq_lens, + attn_metadata.fia_attn_mask.defined() + ? 
attn_metadata.fia_attn_mask + : get_fia_split_fuse_attn_mask(query), + attn_metadata.q_seq_lens, + fallback_output); + return fallback_output; + }; + if (attn_metadata.is_prefill || attn_metadata.is_chunked_prefill) { + if (!use_x_flash_prefill()) { + auto fia_result = run_fia_prefill(); + output.copy_(fia_result.view_as(output)); + return; + } + auto xfa_result = x_flash_attention_infer_with_query_split( query, k_cache, @@ -392,11 +550,15 @@ void AttentionImpl::decoder_forward(torch::Tensor& query, if (attn_metadata.use_expanded_decode_for_spec_verify_attention) { block_table = attn_metadata.expanded_block_table; tiling_data = attn_metadata.expanded_paged_attention_tiling_data; - if (attn_metadata.expanded_kv_seq_lens_host.defined()) { + if (tiling_data.defined()) { + kv_seq_lens = attn_metadata.expanded_kv_seq_lens; + } else if (attn_metadata.expanded_kv_seq_lens_host.defined()) { kv_seq_lens = attn_metadata.expanded_kv_seq_lens_host; } else { kv_seq_lens = attn_metadata.expanded_kv_seq_lens; } + } else if (tiling_data.defined()) { + kv_seq_lens = attn_metadata.kv_seq_lens; } else if (attn_metadata.kv_seq_lens_host.defined()) { kv_seq_lens = attn_metadata.kv_seq_lens_host; } else { @@ -404,18 +566,36 @@ void AttentionImpl::decoder_forward(torch::Tensor& query, kv_seq_lens = attn_metadata.kv_seq_lens; } - if (!tiling_data.defined() && v_cache.has_value() && use_x_flash_decode()) { + if (v_cache.has_value() && use_x_flash_decode()) { auto decode_query = query.view({-1, num_heads_, head_size_}); - auto xfa_result = x_flash_attention_infer_decode_with_batch_split( - decode_query, - k_cache, - v_cache.value(), - get_fia_split_fuse_attn_mask(decode_query), - block_table, - kv_seq_lens, - num_heads_, - num_kv_heads_, - scale_); + torch::Tensor xfa_result; + if (tiling_data.defined()) { + xfa_result = x_flash_attention_infer_decode_for_graph( + decode_query, + k_cache, + v_cache.value(), + attn_metadata.xfa_attn_mask.defined() + ? 
attn_metadata.xfa_attn_mask + : get_fia_split_fuse_attn_mask(decode_query), + block_table, + attn_metadata.xfa_q_cu_seq_lens, + kv_seq_lens, + attn_metadata.xfa_extra_tiling, + num_heads_, + num_kv_heads_, + scale_); + } else { + xfa_result = x_flash_attention_infer_decode_with_batch_split( + decode_query, + k_cache, + v_cache.value(), + get_fia_split_fuse_attn_mask(decode_query), + block_table, + kv_seq_lens, + num_heads_, + num_kv_heads_, + scale_); + } output.copy_(xfa_result.view_as(output)); return; } diff --git a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp index 11992062f1..9bff8fe78d 100644 --- a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp +++ b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp @@ -15,6 +15,9 @@ limitations under the License. #include #include +#include +#include +#include #include #include "xllm/core/kernels/ops_api.h" @@ -23,6 +26,30 @@ namespace xllm { namespace layer { namespace { +bool qwen35_gdn_debug_enabled() { + const char* value = std::getenv("XLLM_DEBUG_QWEN35_GDN"); + if (value == nullptr) { + return false; + } + std::string flag(value); + return flag == "1" || flag == "true" || flag == "TRUE" || flag == "on" || + flag == "ON"; +} + +template +void append_vector(std::ostringstream& oss, + const char* name, + const std::vector& values) { + oss << " " << name << "=["; + for (size_t i = 0; i < values.size(); ++i) { + if (i != 0) { + oss << ","; + } + oss << values[i]; + } + oss << "]"; +} + torch::Tensor l2norm(const torch::Tensor& x, int64_t dim, double eps = 1e-6) { auto norm = torch::sqrt(torch::sum(torch::square(x), dim, true) + eps); return x / norm; @@ -521,26 +548,88 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward( const bool use_spec_verify = input_params.is_spec_verify; bool is_any_prefill = attn_metadata.is_prefill || attn_metadata.is_chunked_prefill; + if (qwen35_gdn_debug_enabled() && + (attn_metadata.is_chunked_prefill 
|| batch_size > 1)) { + static std::atomic log_count{0}; + const int current = log_count.fetch_add(1); + if (current < 160) { + std::ostringstream oss; + oss << "qwen35_gdn meta idx=" << current + << " is_prefill=" << attn_metadata.is_prefill + << " is_chunked_prefill=" << attn_metadata.is_chunked_prefill + << " is_spec_verify=" << use_spec_verify + << " batch_size=" << batch_size << " seq_len=" << seq_len + << " max_query_len=" << attn_metadata.max_query_len + << " checkpoint_stride=" << checkpoint_stride; + append_vector(oss, "q_lens", attn_metadata.q_seq_lens_vec); + append_vector(oss, "kv_lens", attn_metadata.kv_seq_lens_vec); + append_vector(oss, "query_start_loc", input_params.query_start_loc); + append_vector(oss, "linear_state_ids", input_params.linear_state_ids); + append_vector(oss, "has_initial_state", input_params.has_initial_state); + LOG(INFO) << oss.str(); + } + } if (!use_spec_verify && is_any_prefill) { torch::IntArrayRef num_accepted_tokens_opt; - std::vector linear_state_indices_vec( - input_params.linear_state_ids.begin(), - input_params.linear_state_ids.end()); torch::Tensor conv_input = reshape_qkvz_unpad(attn_metadata, mixed_qkv); - mixed_qkv = xllm::kernel::causal_conv1d( - conv_input, - conv_weight, - conv_cache, - std::optional(), // bias (no bias for qwen3) - torch::IntArrayRef(input_params.query_start_loc), - torch::IntArrayRef(linear_state_indices_vec), - torch::IntArrayRef(input_params.has_initial_state), - num_accepted_tokens_opt, - 1, // activation_mode - -1, // pad_slot_id - 0 // run mode 0:fn, 1:update - ); + if (attn_metadata.is_chunked_prefill && batch_size > 1) { + CHECK_GE(attn_metadata.q_seq_lens_vec.size(), + static_cast(batch_size)) + << "q_seq_lens_vec must be populated for Qwen3.5 chunked conv."; + CHECK_EQ(input_params.linear_state_ids.size(), + static_cast(batch_size)) + << "linear_state_ids must be sequence-scoped for Qwen3.5 conv."; + CHECK_EQ(input_params.has_initial_state.size(), + static_cast(batch_size)) + << 
"has_initial_state must be sequence-scoped for Qwen3.5 conv."; + std::vector conv_outputs; + conv_outputs.reserve(batch_size); + int64_t offset = 0; + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + const int64_t valid_len = attn_metadata.q_seq_lens_vec[batch_idx]; + CHECK_GT(valid_len, 0) << "Qwen3.5 conv sequence length must be > 0"; + auto conv_slice = conv_input.narrow(0, offset, valid_len).contiguous(); + offset += valid_len; + std::vector query_start_loc = {0, valid_len}; + std::vector linear_state_indices = { + input_params.linear_state_ids[batch_idx]}; + std::vector has_initial_state = { + input_params.has_initial_state[batch_idx]}; + conv_outputs.emplace_back(xllm::kernel::causal_conv1d( + conv_slice, + conv_weight, + conv_cache, + std::optional(), // bias (no bias for qwen3) + torch::IntArrayRef(query_start_loc), + torch::IntArrayRef(linear_state_indices), + torch::IntArrayRef(has_initial_state), + num_accepted_tokens_opt, + 1, // activation_mode + -1, // pad_slot_id + 0 // run mode 0:fn, 1:update + )); + } + CHECK_EQ(offset, conv_input.size(0)); + mixed_qkv = torch::cat(conv_outputs, 0).contiguous(); + } else { + std::vector linear_state_indices_vec( + input_params.linear_state_ids.begin(), + input_params.linear_state_ids.end()); + mixed_qkv = xllm::kernel::causal_conv1d( + conv_input, + conv_weight, + conv_cache, + std::optional(), // bias (no bias for qwen3) + torch::IntArrayRef(input_params.query_start_loc), + torch::IntArrayRef(linear_state_indices_vec), + torch::IntArrayRef(input_params.has_initial_state), + num_accepted_tokens_opt, + 1, // activation_mode + -1, // pad_slot_id + 0 // run mode 0:fn, 1:update + ); + } mixed_qkv = reshape_qkvz_with_pad(attn_metadata, mixed_qkv); mixed_qkv = mixed_qkv.transpose(1, 2); @@ -668,76 +757,205 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward( CHECK_GE(attn_metadata.q_seq_lens_vec.size(), static_cast(batch_size)) << "q_seq_lens_vec must be populated for Qwen3.5 prefill."; - std::vector 
packed_q; - std::vector packed_k; - std::vector packed_v; - std::vector packed_g; - std::vector packed_beta; - packed_q.reserve(batch_size); - packed_k.reserve(batch_size); - packed_v.reserve(batch_size); - packed_g.reserve(batch_size); - packed_beta.reserve(batch_size); - for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { - const int64_t valid_len = attn_metadata.q_seq_lens_vec[batch_idx]; - packed_q.emplace_back(processed_q[batch_idx].narrow(0, 0, valid_len)); - packed_k.emplace_back(processed_k[batch_idx].narrow(0, 0, valid_len)); - packed_v.emplace_back(processed_v[batch_idx].narrow(0, 0, valid_len)); - packed_g.emplace_back(g[batch_idx].narrow(0, 0, valid_len)); - packed_beta.emplace_back(beta[batch_idx].narrow(0, 0, valid_len)); - } - torch::Tensor packed_processed_q = torch::cat(packed_q, 0).unsqueeze(0); - torch::Tensor packed_processed_k = torch::cat(packed_k, 0).unsqueeze(0); - torch::Tensor packed_processed_v = torch::cat(packed_v, 0).unsqueeze(0); - torch::Tensor packed_g_tensor = torch::cat(packed_g, 0).unsqueeze(0); - torch::Tensor packed_beta_tensor = torch::cat(packed_beta, 0).unsqueeze(0); - - xllm::kernel::ChunkGatedDeltaRuleParams chunk_gated_delta_params; - chunk_gated_delta_params.q = packed_processed_q; - chunk_gated_delta_params.k = packed_processed_k; - chunk_gated_delta_params.v = packed_processed_v; - chunk_gated_delta_params.g = packed_g_tensor; - chunk_gated_delta_params.beta = packed_beta_tensor; - // Get initial state from ssm_cache for sequences with previous state - // Shape: [batch_size, num_heads, head_k_dim, head_v_dim] - torch::Tensor initial_state_tensor = - torch::index_select(ssm_cache, 0, linear_state_base_indices); - CHECK_EQ(input_params.has_initial_state.size(), - input_params.linear_state_ids.size()) - << "has_initial_state must be sequence-scoped."; - for (size_t i = 0; i < input_params.has_initial_state.size(); ++i) { - if (input_params.has_initial_state[i] == 0) { - initial_state_tensor.select(0, 
static_cast(i)).fill_(0.0); + if (attn_metadata.is_chunked_prefill && fla_ssm_state_layout) { + double scale = 1.0 / std::sqrt(static_cast(processed_q.size(-1))); + std::vector packed_q; + std::vector packed_k; + std::vector packed_v; + std::vector packed_a; + std::vector packed_b; + packed_q.reserve(batch_size); + packed_k.reserve(batch_size); + packed_v.reserve(batch_size); + packed_a.reserve(batch_size); + packed_b.reserve(batch_size); + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + const int64_t valid_len = attn_metadata.q_seq_lens_vec[batch_idx]; + CHECK_GT(valid_len, 0) + << "Qwen3.5 chunked GDN sequence length must be > 0"; + packed_q.emplace_back(processed_q[batch_idx].narrow(0, 0, valid_len)); + packed_k.emplace_back(processed_k[batch_idx].narrow(0, 0, valid_len)); + packed_v.emplace_back(processed_v[batch_idx].narrow(0, 0, valid_len)); + packed_a.emplace_back(a[batch_idx].narrow(0, 0, valid_len)); + packed_b.emplace_back(b[batch_idx].narrow(0, 0, valid_len)); } - } + torch::Tensor packed_q_tensor = torch::cat(packed_q, 0).unsqueeze(0); + torch::Tensor packed_k_tensor = torch::cat(packed_k, 0).unsqueeze(0); + torch::Tensor packed_v_tensor = torch::cat(packed_v, 0).unsqueeze(0); + torch::Tensor packed_a_tensor = torch::cat(packed_a, 0).unsqueeze(0); + torch::Tensor packed_b_tensor = torch::cat(packed_b, 0).unsqueeze(0); - if (!fla_ssm_state_layout && attn_metadata.is_chunked_prefill) { - initial_state_tensor = - initial_state_tensor.transpose(-1, -2).contiguous(); - } + CHECK_EQ(input_params.has_initial_state.size(), + input_params.linear_state_ids.size()) + << "has_initial_state must be sequence-scoped."; + std::vector zero_state_indices; + zero_state_indices.reserve(input_params.has_initial_state.size()); + for (size_t i = 0; i < input_params.has_initial_state.size(); ++i) { + if (input_params.has_initial_state[i] == 0) { + zero_state_indices.emplace_back( + linear_state_base_indices.narrow(0, static_cast(i), 1)); + } + } + if 
(!zero_state_indices.empty()) { + auto indices = torch::cat(zero_state_indices, 0).contiguous(); + auto zero_state = torch::zeros({indices.size(0), + ssm_cache.size(1), + ssm_cache.size(2), + ssm_cache.size(3)}, + ssm_cache.options()); + ssm_cache.index_put_({indices}, zero_state); + } + + xllm::kernel::FusedSigmoidGatingDeltaRuleUpdateParams params; + params.A_log = A_log_.contiguous(); + params.a = packed_a_tensor.contiguous(); + params.dt_bias = dt_bias_.contiguous(); + params.q = packed_q_tensor.contiguous(); + params.k = packed_k_tensor.contiguous(); + params.v = packed_v_tensor.contiguous(); + params.b = packed_b_tensor.contiguous(); + params.initial_state_source = ssm_cache; + params.initial_state_indices = linear_state_base_indices.contiguous(); + params.cu_seqlens = attn_metadata.q_cu_seq_lens.contiguous(); + params.scale = static_cast(scale); + params.use_qk_l2norm_in_kernel = true; + params.softplus_beta = 1.0f; + params.softplus_threshold = 20.0f; + torch::Tensor packed_core_attn_out = + xllm::kernel::fused_sigmoid_gating_delta_rule_update(params); + core_attn_out = torch::zeros_like(processed_v); + int64_t packed_offset = 0; + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + const int64_t valid_len = attn_metadata.q_seq_lens_vec[batch_idx]; + core_attn_out[batch_idx] + .narrow(0, 0, valid_len) + .copy_(packed_core_attn_out[0].narrow(0, packed_offset, valid_len)); + packed_offset += valid_len; + } + } else if (attn_metadata.is_chunked_prefill && batch_size > 1) { + CHECK_EQ(input_params.has_initial_state.size(), + input_params.linear_state_ids.size()) + << "has_initial_state must be sequence-scoped."; + core_attn_out = torch::zeros_like(processed_v); + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + const int64_t valid_len = attn_metadata.q_seq_lens_vec[batch_idx]; + CHECK_GT(valid_len, 0) + << "Qwen3.5 chunked GDN sequence length must be > 0"; + torch::Tensor seq_state_index = + 
linear_state_base_indices.narrow(0, batch_idx, 1).contiguous(); + torch::Tensor initial_state_tensor = + torch::index_select(ssm_cache, 0, seq_state_index); + if (input_params.has_initial_state[batch_idx] == 0) { + initial_state_tensor.fill_(0.0); + } + if (!fla_ssm_state_layout) { + initial_state_tensor = + initial_state_tensor.transpose(-1, -2).contiguous(); + } - chunk_gated_delta_params.initial_state = initial_state_tensor; - chunk_gated_delta_params.output_final_state = true; - chunk_gated_delta_params.cu_seqlens = attn_metadata.q_cu_seq_lens; - chunk_gated_delta_params.head_first = false; - chunk_gated_delta_params.use_qk_l2norm_in_kernel = true; - torch::Tensor packed_core_attn_out; - std::tie(packed_core_attn_out, last_recurrent_state) = - xllm::kernel::chunk_gated_delta_rule(chunk_gated_delta_params); - core_attn_out = torch::zeros_like(processed_v); - int64_t packed_offset = 0; - for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { - const int64_t valid_len = attn_metadata.q_seq_lens_vec[batch_idx]; - core_attn_out[batch_idx] - .narrow(0, 0, valid_len) - .copy_(packed_core_attn_out[0].narrow(0, packed_offset, valid_len)); - packed_offset += valid_len; + xllm::kernel::ChunkGatedDeltaRuleParams chunk_gated_delta_params; + chunk_gated_delta_params.q = + processed_q[batch_idx].narrow(0, 0, valid_len).unsqueeze(0); + chunk_gated_delta_params.k = + processed_k[batch_idx].narrow(0, 0, valid_len).unsqueeze(0); + chunk_gated_delta_params.v = + processed_v[batch_idx].narrow(0, 0, valid_len).unsqueeze(0); + chunk_gated_delta_params.g = + g[batch_idx].narrow(0, 0, valid_len).unsqueeze(0); + chunk_gated_delta_params.beta = + beta[batch_idx].narrow(0, 0, valid_len).unsqueeze(0); + chunk_gated_delta_params.initial_state = initial_state_tensor; + chunk_gated_delta_params.output_final_state = true; + chunk_gated_delta_params.cu_seqlens = torch::tensor( + std::vector{0, static_cast(valid_len)}, + torch::TensorOptions().dtype(torch::kInt32).device(device)); + 
chunk_gated_delta_params.head_first = false; + chunk_gated_delta_params.use_qk_l2norm_in_kernel = true; + torch::Tensor seq_core_attn_out; + std::tie(seq_core_attn_out, last_recurrent_state) = + xllm::kernel::chunk_gated_delta_rule(chunk_gated_delta_params); + core_attn_out[batch_idx] + .narrow(0, 0, valid_len) + .copy_(seq_core_attn_out[0].narrow(0, 0, valid_len)); + torch::Tensor state_to_store = + fla_ssm_state_layout ? last_recurrent_state + : last_recurrent_state.transpose(-1, -2); + ssm_cache.index_put_({seq_state_index}, + state_to_store.to(ssm_cache.dtype())); + } + } else { + std::vector packed_q; + std::vector packed_k; + std::vector packed_v; + std::vector packed_g; + std::vector packed_beta; + packed_q.reserve(batch_size); + packed_k.reserve(batch_size); + packed_v.reserve(batch_size); + packed_g.reserve(batch_size); + packed_beta.reserve(batch_size); + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + const int64_t valid_len = attn_metadata.q_seq_lens_vec[batch_idx]; + packed_q.emplace_back(processed_q[batch_idx].narrow(0, 0, valid_len)); + packed_k.emplace_back(processed_k[batch_idx].narrow(0, 0, valid_len)); + packed_v.emplace_back(processed_v[batch_idx].narrow(0, 0, valid_len)); + packed_g.emplace_back(g[batch_idx].narrow(0, 0, valid_len)); + packed_beta.emplace_back(beta[batch_idx].narrow(0, 0, valid_len)); + } + torch::Tensor packed_processed_q = torch::cat(packed_q, 0).unsqueeze(0); + torch::Tensor packed_processed_k = torch::cat(packed_k, 0).unsqueeze(0); + torch::Tensor packed_processed_v = torch::cat(packed_v, 0).unsqueeze(0); + torch::Tensor packed_g_tensor = torch::cat(packed_g, 0).unsqueeze(0); + torch::Tensor packed_beta_tensor = + torch::cat(packed_beta, 0).unsqueeze(0); + + xllm::kernel::ChunkGatedDeltaRuleParams chunk_gated_delta_params; + chunk_gated_delta_params.q = packed_processed_q; + chunk_gated_delta_params.k = packed_processed_k; + chunk_gated_delta_params.v = packed_processed_v; + chunk_gated_delta_params.g = 
packed_g_tensor; + chunk_gated_delta_params.beta = packed_beta_tensor; + // Get initial state from ssm_cache for sequences with previous state + // Shape: [batch_size, num_heads, head_k_dim, head_v_dim] + torch::Tensor initial_state_tensor = + torch::index_select(ssm_cache, 0, linear_state_base_indices); + CHECK_EQ(input_params.has_initial_state.size(), + input_params.linear_state_ids.size()) + << "has_initial_state must be sequence-scoped."; + for (size_t i = 0; i < input_params.has_initial_state.size(); ++i) { + if (input_params.has_initial_state[i] == 0) { + initial_state_tensor.select(0, static_cast(i)).fill_(0.0); + } + } + + if (!fla_ssm_state_layout && attn_metadata.is_chunked_prefill) { + initial_state_tensor = + initial_state_tensor.transpose(-1, -2).contiguous(); + } + + chunk_gated_delta_params.initial_state = initial_state_tensor; + chunk_gated_delta_params.output_final_state = true; + chunk_gated_delta_params.cu_seqlens = attn_metadata.q_cu_seq_lens; + chunk_gated_delta_params.head_first = false; + chunk_gated_delta_params.use_qk_l2norm_in_kernel = true; + torch::Tensor packed_core_attn_out; + std::tie(packed_core_attn_out, last_recurrent_state) = + xllm::kernel::chunk_gated_delta_rule(chunk_gated_delta_params); + core_attn_out = torch::zeros_like(processed_v); + int64_t packed_offset = 0; + for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + const int64_t valid_len = attn_metadata.q_seq_lens_vec[batch_idx]; + core_attn_out[batch_idx] + .narrow(0, 0, valid_len) + .copy_(packed_core_attn_out[0].narrow(0, packed_offset, valid_len)); + packed_offset += valid_len; + } + torch::Tensor state_to_store = + fla_ssm_state_layout ? last_recurrent_state + : last_recurrent_state.transpose(-1, -2); + ssm_cache.index_put_({linear_state_base_indices}, + state_to_store.to(ssm_cache.dtype())); } - torch::Tensor state_to_store = fla_ssm_state_layout - ? 
last_recurrent_state - : last_recurrent_state.transpose(-1, -2); - ssm_cache.index_put_({linear_state_base_indices}, - state_to_store.to(ssm_cache.dtype())); } else if (checkpoint_stride > 1) { auto ssm_state = torch::index_select(ssm_cache, 0, linear_state_base_indices); diff --git a/xllm/core/runtime/acl_graph_executor_impl.cpp b/xllm/core/runtime/acl_graph_executor_impl.cpp index 43b2c7aca8..04885c6e1a 100644 --- a/xllm/core/runtime/acl_graph_executor_impl.cpp +++ b/xllm/core/runtime/acl_graph_executor_impl.cpp @@ -34,6 +34,7 @@ limitations under the License. #endif #include "core/common/global_flags.h" #include "core/common/metrics.h" +#include "core/kernels/npu/xllm_ops/xllm_ops_api.h" #include "core/util/utils.h" #include "platform/npu/device_capture_lock.h" @@ -51,6 +52,7 @@ namespace xllm::npu { namespace { constexpr uint64_t kSpecVerifyGraphKeyMask = 1ull << 63; constexpr uint64_t kSpecVerifyQMaxSeqLenShift = 32; +constexpr int64_t kFiaSplitFuseMaskSize = 2048; std::pair find_attention_plan_kv_cache( const std::vector& kv_caches) { @@ -170,6 +172,18 @@ GraphPersistentParam::GraphPersistentParam(const ModelArgs& args, persistent_mask_ = torch::zeros({max_graph_tokens, max_seq_len}, torch::dtype(dtype).device(device)); } + xfa_extra_tiling_ = torch::zeros( + {xllm::kernel::npu::x_flash_attention_extra_tiling_int32_count()}, + torch::dtype(torch::kInt32).device(device)); + xfa_attn_mask_ = + torch::triu(torch::ones({kFiaSplitFuseMaskSize, kFiaSplitFuseMaskSize}, + torch::TensorOptions() + .dtype(torch::kFloat32) + .device(torch::kCPU)), + 1) + .to(torch::kInt8) + .to(device) + .contiguous(); // Do not need to create ATB context and custom paged attention operation if (args_.head_dim() == 0) { @@ -447,9 +461,10 @@ std::optional GraphPersistentParam::update( .slice(/*dim=*/0, /*start=*/0, /*end=*/embedding_tokens) .copy_(embedding, /*non_blocking=*/true); } - // Update q_cu_seq_lens only if params.q_cu_seq_lens is defined + // Preserve the historical 
q_cu_seq_lens semantics for Qwen3.5 GDN/conv + // update kernels. x_flash_attention graph replay uses a separate + // query-start buffer below. if (params.q_cu_seq_lens.defined()) { - // Lazy initialization: if q_cu_seq_lens_ is not initialized, initialize it if (q_cu_seq_lens_.numel() == 0) { const int64_t max_seqs_per_batch = get_decode_graph_capacity(options_); q_cu_seq_lens_ = torch::zeros({max_seqs_per_batch + 1}, @@ -464,8 +479,6 @@ std::optional GraphPersistentParam::update( CHECK_GE(params.q_cu_seq_lens.numel(), required_q_cu_seq_lens) << "q_cu_seq_lens does not have enough entries for ACL graph execution"; if (use_qwen3_5_query_start_loc && !input_has_leading_zero) { - // Normal Qwen3.5 decode input carries cumsum without the leading zero, - // while update kernels expect query_start_loc-style [0, cumsum...]. q_cu_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/1).zero_(); q_cu_seq_lens_ .slice(/*dim=*/0, /*start=*/1, /*end=*/actual_batch_size + 1) @@ -528,6 +541,48 @@ std::optional GraphPersistentParam::update( expanded_kv_seq_lens_vec = update_expanded_spec_decode_attention( params, actual_num_tokens, padded_num_tokens, actual_batch_size); } + if (xfa_q_cu_seq_lens_.numel() == 0) { + const int64_t max_graph_tokens = get_decode_graph_capacity(options_); + xfa_q_cu_seq_lens_ = torch::zeros( + {max_graph_tokens + 1}, torch::dtype(torch::kInt).device(device_)); + } + xfa_q_cu_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/1).zero_(); + int32_t xfa_q_offset = 0; + std::vector xfa_q_cu_seq_lens_vec; + if (use_expanded_spec_decode_attention) { + xfa_q_cu_seq_lens_vec.reserve(padded_num_tokens); + for (uint32_t i = 0; i < padded_num_tokens; ++i) { + ++xfa_q_offset; + xfa_q_cu_seq_lens_vec.emplace_back(xfa_q_offset); + } + } else { + xfa_q_cu_seq_lens_vec.reserve(padded_batch_size); + for (int64_t i = 0; i < padded_batch_size; ++i) { + xfa_q_offset += padded_q_seq_lens_vec[i]; + xfa_q_cu_seq_lens_vec.emplace_back(xfa_q_offset); + } + } + xfa_q_cu_seq_lens_ + 
.slice(/*dim=*/0, + /*start=*/1, + /*end=*/static_cast(xfa_q_cu_seq_lens_vec.size()) + 1) + .copy_(torch::tensor(xfa_q_cu_seq_lens_vec, torch::kInt).to(device_), + /*non_blocking=*/true); + + ModelInputParams xfa_params = params; + if (use_expanded_spec_decode_attention) { + xfa_params.num_sequences = padded_num_tokens; + xfa_params.kv_seq_lens_vec = expanded_kv_seq_lens_vec; + xfa_params.q_seq_lens_vec = std::vector(padded_num_tokens, 1); + } else { + xfa_params.num_sequences = padded_batch_size; + xfa_params.kv_seq_lens_vec = padded_kv_seq_lens_vec; + xfa_params.q_seq_lens_vec = padded_q_seq_lens_vec; + } + if (k_cache.defined() && k_cache.numel() > 0) { + update_x_flash_attention_extra_tiling( + xfa_params, xfa_params.num_sequences, k_cache); + } if (tiling_data_.numel() > 0) { aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); @@ -591,6 +646,8 @@ std::optional GraphPersistentParam::update( persistent_mask(padded_num_tokens); } params_for_capture->graph_buffer.tiling_data = tiling_data(); + params_for_capture->graph_buffer.xfa_extra_tiling = xfa_extra_tiling(); + params_for_capture->graph_buffer.xfa_attn_mask = xfa_attn_mask(); // Set persistent embedding if available if (params.input_embedding.defined()) { params_for_capture->input_embedding = @@ -612,15 +669,19 @@ std::optional GraphPersistentParam::update( params_for_capture->graph_buffer.expanded_tiling_data = tiling_data(); params_for_capture->graph_buffer.expanded_kv_seq_lens_vec = expanded_kv_seq_lens_vec; + params_for_capture->graph_buffer.xfa_q_cu_seq_lens = + xfa_q_cu_seq_lens().slice(/*dim=*/0, + /*start=*/0, + /*end=*/padded_num_tokens + 1); + } else { + params_for_capture->graph_buffer.xfa_q_cu_seq_lens = + xfa_q_cu_seq_lens().slice(/*dim=*/0, + /*start=*/0, + /*end=*/padded_batch_size + 1); } - // Set q_cu_seq_lens if available - if (params.q_cu_seq_lens.defined()) { - const bool use_qwen3_5_query_start_loc = - is_qwen3_5_model_type(args_.model_type()); + if (q_cu_seq_lens_.defined() && 
q_cu_seq_lens_.numel() > 0) { params_for_capture->q_cu_seq_lens = q_cu_seq_lens_.slice( - /*dim=*/0, - /*start=*/0, - /*end=*/padded_batch_size + (use_qwen3_5_query_start_loc ? 1 : 0)); + /*dim=*/0, /*start=*/0, /*end=*/padded_batch_size + 1); } return params_for_capture; @@ -965,6 +1026,41 @@ void GraphPersistentParam::plan_paged_attention_tiling( CHECK_EQ(acl_status, ACL_SUCCESS) << "Failed to copy tiling buffer to device"; } +void GraphPersistentParam::update_x_flash_attention_extra_tiling( + const ModelInputParams& input_params, + uint32_t padded_batch_size, + const torch::Tensor& k_cache) { + CHECK(k_cache.defined() && k_cache.dim() == 4) + << "x_flash_attention graph extra tiling needs valid k_cache"; + const int dp_local_tp_size = options_.world_size() / options_.dp_size(); + const int64_t q_head = args_.n_heads() / dp_local_tp_size; + const int64_t kv_head = std::max( + 1, args_.n_kv_heads().value_or(args_.n_heads()) / dp_local_tp_size); + std::vector q_lens; + std::vector kv_lens; + q_lens.reserve(padded_batch_size); + kv_lens.reserve(padded_batch_size); + for (uint32_t i = 0; i < padded_batch_size; ++i) { + CHECK_LT(i, input_params.q_seq_lens_vec.size()) + << "q_seq_lens_vec is shorter than padded_batch_size"; + CHECK_LT(i, input_params.kv_seq_lens_vec.size()) + << "kv_seq_lens_vec is shorter than padded_batch_size"; + const int32_t q_len = input_params.q_seq_lens_vec[i]; + const int32_t kv_len = input_params.kv_seq_lens_vec[i]; + CHECK_EQ(q_len, 1) + << "x_flash_attention graph decode path expects q_len == 1"; + q_lens.emplace_back(static_cast(i + 1)); + kv_lens.emplace_back(kv_len); + } + xllm::kernel::npu::update_x_flash_attention_extra_tiling(q_lens, + kv_lens, + q_head, + kv_head, + k_cache.size(1), + k_cache.size(3), + xfa_extra_tiling_); +} + void GraphPersistentParam::update_attention_mask( const ModelInputParams& input_params) { torch::Dtype dtype = util::parse_dtype(args_.dtype(), device_); diff --git 
a/xllm/core/runtime/acl_graph_executor_impl.h b/xllm/core/runtime/acl_graph_executor_impl.h index deaca700bc..4bb9177a3f 100644 --- a/xllm/core/runtime/acl_graph_executor_impl.h +++ b/xllm/core/runtime/acl_graph_executor_impl.h @@ -118,6 +118,9 @@ class GraphPersistentParam { return persistent_mask_; } const torch::Tensor& tiling_data() const { return tiling_data_; } + const torch::Tensor& xfa_q_cu_seq_lens() const { return xfa_q_cu_seq_lens_; } + const torch::Tensor& xfa_extra_tiling() const { return xfa_extra_tiling_; } + const torch::Tensor& xfa_attn_mask() const { return xfa_attn_mask_; } torch::Tensor hidden_states(uint32_t actual_tokens = 0) const { if (actual_tokens > 0) { return hidden_states_.slice( @@ -201,6 +204,11 @@ class GraphPersistentParam { const ModelInputParams& input_params, aclrtStream stream); + void update_x_flash_attention_extra_tiling( + const ModelInputParams& input_params, + uint32_t padded_batch_size, + const torch::Tensor& k_cache); + std::vector update_expanded_spec_decode_attention( const ModelInputParams& input_params, uint32_t actual_num_tokens, @@ -229,6 +237,7 @@ class GraphPersistentParam { // for deepseekv3.2 torch::Tensor q_cu_seq_lens_; + torch::Tensor xfa_q_cu_seq_lens_; // for mtp model torch::Tensor persistent_embedding_; @@ -248,6 +257,8 @@ class GraphPersistentParam { // Persistent paged attention tiling tensor on device torch::Tensor tiling_data_; + torch::Tensor xfa_extra_tiling_; + torch::Tensor xfa_attn_mask_; // Cached attention parameters int32_t num_head_; From 71925307cba9e74f52a1fb7c0e6f1d8a07dda508 Mon Sep 17 00:00:00 2001 From: Sinle4Cat Date: Thu, 21 May 2026 09:26:48 +0800 Subject: [PATCH 7/7] chore: remove qwen35 gdn debug logging --- .../npu_torch/qwen3_gated_delta_net_base.cpp | 48 ------------------- 1 file changed, 48 deletions(-) diff --git a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp index 9bff8fe78d..07e7057209 100644 --- 
a/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp +++ b/xllm/core/layers/npu_torch/qwen3_gated_delta_net_base.cpp @@ -15,9 +15,6 @@ limitations under the License. #include #include -#include -#include -#include #include #include "xllm/core/kernels/ops_api.h" @@ -26,30 +23,6 @@ namespace xllm { namespace layer { namespace { -bool qwen35_gdn_debug_enabled() { - const char* value = std::getenv("XLLM_DEBUG_QWEN35_GDN"); - if (value == nullptr) { - return false; - } - std::string flag(value); - return flag == "1" || flag == "true" || flag == "TRUE" || flag == "on" || - flag == "ON"; -} - -template -void append_vector(std::ostringstream& oss, - const char* name, - const std::vector& values) { - oss << " " << name << "=["; - for (size_t i = 0; i < values.size(); ++i) { - if (i != 0) { - oss << ","; - } - oss << values[i]; - } - oss << "]"; -} - torch::Tensor l2norm(const torch::Tensor& x, int64_t dim, double eps = 1e-6) { auto norm = torch::sqrt(torch::sum(torch::square(x), dim, true) + eps); return x / norm; @@ -548,27 +521,6 @@ torch::Tensor Qwen3GatedDeltaNetBaseImpl::forward( const bool use_spec_verify = input_params.is_spec_verify; bool is_any_prefill = attn_metadata.is_prefill || attn_metadata.is_chunked_prefill; - if (qwen35_gdn_debug_enabled() && - (attn_metadata.is_chunked_prefill || batch_size > 1)) { - static std::atomic log_count{0}; - const int current = log_count.fetch_add(1); - if (current < 160) { - std::ostringstream oss; - oss << "qwen35_gdn meta idx=" << current - << " is_prefill=" << attn_metadata.is_prefill - << " is_chunked_prefill=" << attn_metadata.is_chunked_prefill - << " is_spec_verify=" << use_spec_verify - << " batch_size=" << batch_size << " seq_len=" << seq_len - << " max_query_len=" << attn_metadata.max_query_len - << " checkpoint_stride=" << checkpoint_stride; - append_vector(oss, "q_lens", attn_metadata.q_seq_lens_vec); - append_vector(oss, "kv_lens", attn_metadata.kv_seq_lens_vec); - append_vector(oss, "query_start_loc", 
input_params.query_start_loc); - append_vector(oss, "linear_state_ids", input_params.linear_state_ids); - append_vector(oss, "has_initial_state", input_params.has_initial_state); - LOG(INFO) << oss.str(); - } - } if (!use_spec_verify && is_any_prefill) { torch::IntArrayRef num_accepted_tokens_opt;