From 3e05202bd55ff875fc215ddc19cf7b4d9c5bd34a Mon Sep 17 00:00:00 2001 From: liangzhiwei20 Date: Thu, 21 May 2026 17:47:42 +0800 Subject: [PATCH] bugfix: reduce acl graph memory overhead. --- xllm/core/runtime/acl_graph_executor_impl.cpp | 37 +++++++++++++------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/xllm/core/runtime/acl_graph_executor_impl.cpp b/xllm/core/runtime/acl_graph_executor_impl.cpp index 2b976a862..920c6dfde 100644 --- a/xllm/core/runtime/acl_graph_executor_impl.cpp +++ b/xllm/core/runtime/acl_graph_executor_impl.cpp @@ -24,8 +24,8 @@ limitations under the License. #include -#include "core/common/global_flags.h" #include "core/framework/config/execution_config.h" +#include "core/framework/config/scheduler_config.h" #include "core/framework/config/speculative_config.h" #ifdef TORCH_HIGHER_THAN_PTA6 #include @@ -33,7 +33,6 @@ limitations under the License. #include #include #endif -#include "core/common/global_flags.h" #include "core/common/metrics.h" #include "core/util/utils.h" #include "platform/npu/device_capture_lock.h" @@ -69,10 +68,13 @@ std::pair find_attention_plan_kv_cache( int64_t get_decode_graph_capacity(const runtime::Options& options) { CHECK_GT(options.num_decoding_tokens(), 0) << "num_decoding_tokens must be > 0 for graph capacity"; - if (::xllm::SpeculativeConfig::get_instance().enable_atb_spec_kernel()) { - return options.max_seqs_per_batch(); + const bool use_atb_spec_kernel = + ::xllm::SpeculativeConfig::get_instance().enable_atb_spec_kernel(); + if (options.enable_speculative_decode() && !options.is_draft_engine() && + !use_atb_spec_kernel) { + return options.max_seqs_per_batch() * options.num_decoding_tokens(); } - return options.max_seqs_per_batch() * options.num_decoding_tokens(); + return options.max_seqs_per_batch(); } } // namespace @@ -96,9 +98,15 @@ GraphPersistentParam::GraphPersistentParam(const ModelArgs& args, // Check if mRoPE is used (for VLM models like qwen2-vl) use_mrope_ = !args.rope_scaling_mrope_section().empty(); - // Use max_tokens_per_batch for first dimension size - // num_decode_tokens - const int64_t max_tokens_per_batch = options.max_tokens_per_batch(); + // Use max_tokens_per_batch for general token-shaped persistent buffers. + // These buffers may participate in both decode replay and the Qwen3.5 + // spec-verify graph path, so they keep the broader scheduler upper bound. + const int64_t max_tokens_per_batch = + ::xllm::SchedulerConfig::get_instance().max_tokens_per_batch(); + // Graph-mode token capacity is narrower than max_tokens_per_batch: ACL graph + // only serves decode / spec-verify batches, so the relevant row upper bound + // comes from decode graph capacity instead. + const int64_t max_graph_tokens = get_decode_graph_capacity(options); // num_sequences const int64_t max_seqs_per_batch = get_decode_graph_capacity(options); auto tensor_options = torch::TensorOptions().device(device); @@ -146,9 +154,13 @@ GraphPersistentParam::GraphPersistentParam(const ModelArgs& args, hidden_states_ = torch::zeros({max_tokens_per_batch, args.hidden_size()}, torch::dtype(dtype).device(device)); - // Initialize persistent_mask_ if need_update_attn_mask is true + // Initialize persistent_mask_ only for model types that need to update an + // explicit attention mask in graph mode. Unlike generic token buffers, the + // mask is only consumed by decode / spec-verify graphs, so size it by graph + // token capacity instead of the much larger max_tokens_per_batch prefill + // budget. if (need_update_attn_mask_) { - persistent_mask_ = torch::zeros({max_tokens_per_batch, max_seq_len}, + persistent_mask_ = torch::zeros({max_graph_tokens, max_seq_len}, torch::dtype(dtype).device(device)); } @@ -776,8 +788,9 @@ void GraphPersistentParam::update_attention_mask( // Check if num_tokens is within bounds CHECK_LE(num_tokens, persistent_mask_.size(0)) - << "num_tokens (" << num_tokens << ") exceeds max_tokens_per_batch (" - << persistent_mask_.size(0) << ")"; + << "num_tokens (" << num_tokens + << ") exceeds graph attention-mask capacity (" << persistent_mask_.size(0) + << ")"; // Get slice for actual num_tokens (compatible with both chunked and // non-chunked)