Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions xllm/core/runtime/acl_graph_executor_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,15 @@ limitations under the License.

#include <numeric>

#include "core/common/global_flags.h"
#include "core/framework/config/execution_config.h"
#include "core/framework/config/scheduler_config.h"
#include "core/framework/config/speculative_config.h"
#ifdef TORCH_HIGHER_THAN_PTA6
#include <torch_npu/csrc/framework/OpCommand.h>
#else
#include <torch_npu/csrc/aten/NPUNativeFunctions.h>
#include <torch_npu/csrc/framework/utils/OpPreparation.h>
#endif
#include "core/common/global_flags.h"
#include "core/common/metrics.h"
#include "core/util/utils.h"
#include "platform/npu/device_capture_lock.h"
Expand Down Expand Up @@ -69,10 +68,13 @@ std::pair<torch::Tensor, torch::Tensor> find_attention_plan_kv_cache(
// Returns the capacity used to size decode-graph persistent buffers.
//
// Aborts (via CHECK_GT) when options.num_decoding_tokens() is not positive.
//
// The result is max_seqs_per_batch() * num_decoding_tokens() only when all of
// the following hold:
//   - speculative decoding is enabled on `options`,
//   - `options` does not describe the draft engine, and
//   - the ATB spec kernel is disabled in SpeculativeConfig.
// In every other case the result is max_seqs_per_batch().
//
// NOTE(review): this block was reconstructed from a diff view in which the
// removed and added lines were interleaved; the pre-change early return on
// enable_atb_spec_kernel() and the old unconditional multiplied return were
// dropped in favour of the post-change logic. Confirm against the merged file.
int64_t get_decode_graph_capacity(const runtime::Options& options) {
  CHECK_GT(options.num_decoding_tokens(), 0)
      << "num_decoding_tokens must be > 0 for graph capacity";
  const bool use_atb_spec_kernel =
      ::xllm::SpeculativeConfig::get_instance().enable_atb_spec_kernel();
  // Only this combination takes the per-token (multiplied) capacity.
  if (options.enable_speculative_decode() && !options.is_draft_engine() &&
      !use_atb_spec_kernel) {
    return options.max_seqs_per_batch() * options.num_decoding_tokens();
  }
  // All other configurations take the per-sequence capacity.
  return options.max_seqs_per_batch();
}
} // namespace

Expand All @@ -96,9 +98,15 @@ GraphPersistentParam::GraphPersistentParam(const ModelArgs& args,
// Check if mRoPE is used (for VLM models like qwen2-vl)
use_mrope_ = !args.rope_scaling_mrope_section().empty();

// Use max_tokens_per_batch for first dimension size
// num_decode_tokens
const int64_t max_tokens_per_batch = options.max_tokens_per_batch();
// Use max_tokens_per_batch for general token-shaped persistent buffers.
// These buffers may participate in both decode replay and the Qwen3.5
// spec-verify graph path, so they keep the broader scheduler upper bound.
const int64_t max_tokens_per_batch =
::xllm::SchedulerConfig::get_instance().max_tokens_per_batch();
// Graph-mode token capacity is narrower than max_tokens_per_batch: ACL graph
// only serves decode / spec-verify batches, so the relevant row upper bound
// comes from decode graph capacity instead.
const int64_t max_graph_tokens = get_decode_graph_capacity(options);
Comment thread
RobbieLeung marked this conversation as resolved.
// num_sequences
const int64_t max_seqs_per_batch = get_decode_graph_capacity(options);
auto tensor_options = torch::TensorOptions().device(device);
Expand Down Expand Up @@ -146,9 +154,13 @@ GraphPersistentParam::GraphPersistentParam(const ModelArgs& args,
hidden_states_ = torch::zeros({max_tokens_per_batch, args.hidden_size()},
torch::dtype(dtype).device(device));

// Initialize persistent_mask_ if need_update_attn_mask is true
// Initialize persistent_mask_ only for model types that need to update an
// explicit attention mask in graph mode. Unlike generic token buffers, the
// mask is only consumed by decode / spec-verify graphs, so size it by graph
// token capacity instead of the much larger max_tokens_per_batch prefill
// budget.
if (need_update_attn_mask_) {
persistent_mask_ = torch::zeros({max_tokens_per_batch, max_seq_len},
persistent_mask_ = torch::zeros({max_graph_tokens, max_seq_len},
torch::dtype(dtype).device(device));
}

Expand Down Expand Up @@ -776,8 +788,9 @@ void GraphPersistentParam::update_attention_mask(

// Check if num_tokens is within bounds
CHECK_LE(num_tokens, persistent_mask_.size(0))
<< "num_tokens (" << num_tokens << ") exceeds max_tokens_per_batch ("
<< persistent_mask_.size(0) << ")";
<< "num_tokens (" << num_tokens
<< ") exceeds graph attention-mask capacity (" << persistent_mask_.size(0)
<< ")";

// Get slice for actual num_tokens (compatible with both chunked and
// non-chunked)
Expand Down
Loading