From e25c6144edd022fcc9817152eaa84962edb082f0 Mon Sep 17 00:00:00 2001 From: panxuanyu Date: Wed, 20 May 2026 17:22:02 +0800 Subject: [PATCH 01/10] refactor: minimize changes. --- xllm/core/distributed_runtime/llm_engine.cpp | 6 +- xllm/core/distributed_runtime/master.cpp | 2 +- xllm/core/framework/kv_cache/kv_cache.cpp | 3 +- .../framework/kv_cache/kv_cache_shape.cpp | 3 +- xllm/core/layers/npu_torch/fused_moe.cpp | 3 +- xllm/core/runtime/params_utils.cpp | 4 ++ xllm/core/util/utils.h | 23 ------- xllm/models/llm/deepseek_v4_mtp.h | 69 +++++++++---------- xllm/models/llm/mtp_model_base.h | 46 +++---------- 9 files changed, 57 insertions(+), 102 deletions(-) diff --git a/xllm/core/distributed_runtime/llm_engine.cpp b/xllm/core/distributed_runtime/llm_engine.cpp index 56ffc4db43..3f0ce1ae33 100644 --- a/xllm/core/distributed_runtime/llm_engine.cpp +++ b/xllm/core/distributed_runtime/llm_engine.cpp @@ -575,7 +575,8 @@ KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() { // all swa-related cache size from cache_size_in_bytes, then compute // c4_count / c128_count (c4_count = 32 * c128_count). // cache_size_in_bytes is already the full available device memory. 
- if (util::is_deepseek_v4_model_type(args_.model_type())) { + if (args_.model_type() == "deepseek_v4" || + args_.model_type() == "deepseek_v4_mtp") { const int64_t max_seqs = static_cast(std::max(options_.max_seqs_per_batch(), 1)); const int32_t block_size = options_.block_size(); @@ -792,7 +793,8 @@ bool LLMEngine::allocate_kv_cache(const KVCacheCapacity& kv_cache_cap) { // init kv cache for each worker const KVCacheShape kv_cache_shape(kv_cache_cap, args_, dp_local_tp_size_); - if (util::is_deepseek_v4_model_type(args_.model_type())) { + if (args_.model_type() == "deepseek_v4" || + args_.model_type() == "deepseek_v4_mtp") { LOG(INFO) << "Initializing DSV4 kv cache with shape: [swa_count=" << kv_cache_cap.swa_count() << ", c4_count=" << kv_cache_cap.c4_count() diff --git a/xllm/core/distributed_runtime/master.cpp b/xllm/core/distributed_runtime/master.cpp index dbb0d6775a..54f7c78682 100644 --- a/xllm/core/distributed_runtime/master.cpp +++ b/xllm/core/distributed_runtime/master.cpp @@ -146,7 +146,7 @@ Master::Master(const Options& options, EngineType type) std::filesystem::path(options_.model_path()).lexically_normal(); if (options_.enable_prefix_cache() && options_.backend() == "llm") { const std::string model_type = util::get_model_type(model_path); - if (util::is_deepseek_v4_model_type(model_type)) { + if (model_type == "deepseek_v4" || model_type == "deepseek_v4_mtp") { LOG(WARNING) << model_type << " does not support prefix cache with " "CompositeBlockManager yet, fallback to " "enable_prefix_cache=false"; diff --git a/xllm/core/framework/kv_cache/kv_cache.cpp b/xllm/core/framework/kv_cache/kv_cache.cpp index 0dd585569e..c1173ee41e 100644 --- a/xllm/core/framework/kv_cache/kv_cache.cpp +++ b/xllm/core/framework/kv_cache/kv_cache.cpp @@ -414,7 +414,8 @@ void allocate_kv_caches(std::vector& kv_caches, const int64_t num_layers = create_options.num_layers(); kv_caches.reserve(num_layers); - if (util::is_deepseek_v4_model_type(create_options.model_type())) { + 
if (create_options.model_type() == "deepseek_v4" || + create_options.model_type() == "deepseek_v4_mtp") { std::vector layer_compress_ratios; layer_compress_ratios.reserve(static_cast(num_layers)); std::map ratio_shape_summaries; diff --git a/xllm/core/framework/kv_cache/kv_cache_shape.cpp b/xllm/core/framework/kv_cache/kv_cache_shape.cpp index 8f7501f030..5181580cc6 100644 --- a/xllm/core/framework/kv_cache/kv_cache_shape.cpp +++ b/xllm/core/framework/kv_cache/kv_cache_shape.cpp @@ -66,7 +66,8 @@ KVCacheShape::KVCacheShape(const KVCacheCapacity& kv_cache_cap, CHECK_GT(world_size, 0) << "world_size must be positive."; CHECK_GT(kv_cache_cap.block_size(), 0) << "block_size must be positive."; - if (util::is_deepseek_v4_model_type(model_args.model_type())) { + if (model_args.model_type() == "deepseek_v4" || + model_args.model_type() == "deepseek_v4_mtp") { key_cache_shape_ = std::vector{kv_cache_cap.swa_count(), kv_cache_cap.c4_count(), kv_cache_cap.c128_count()}; diff --git a/xllm/core/layers/npu_torch/fused_moe.cpp b/xllm/core/layers/npu_torch/fused_moe.cpp index 1658f67a37..78b2af2b24 100644 --- a/xllm/core/layers/npu_torch/fused_moe.cpp +++ b/xllm/core/layers/npu_torch/fused_moe.cpp @@ -338,7 +338,8 @@ FusedMoEImpl::FusedMoEImpl(const ModelArgs& model_args, n_shared_experts_(model_args.n_shared_experts()), is_gated_(moe_args.is_gated), skip_gate_load_(moe_args.skip_gate_load), - is_deepseek_v4_(util::is_deepseek_v4_model_type(model_args.model_type())), + is_deepseek_v4_(model_args.model_type() == "deepseek_v4" || + model_args.model_type() == "deepseek_v4_mtp"), renormalize_(model_args.norm_topk_prob() ? 
1 : 0), hidden_act_(model_args.hidden_act()), scoring_func_(model_args.scoring_func()), diff --git a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp index 1cc0605a49..642cadeb27 100644 --- a/xllm/core/runtime/params_utils.cpp +++ b/xllm/core/runtime/params_utils.cpp @@ -224,6 +224,10 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, input_params.batch_forward_type = BatchForwardType(pb_forward_input->batch_forward_type()); input_params.num_sequences = pb_forward_input->num_sequences(); + CHECK_EQ(block_tables_vec.size(), + static_cast(input_params.num_sequences)) + << "block_tables_vec size (" << block_tables_vec.size() + << ") must match num_sequences (" << input_params.num_sequences << ")"; input_params.kv_max_seq_len = pb_forward_input->max_seq_len(); input_params.q_max_seq_len = pb_forward_input->q_max_seq_len(); input_params.kv_seq_lens = torch::tensor(seq_lens, tensor_options); diff --git a/xllm/core/util/utils.h b/xllm/core/util/utils.h index 81043698b7..4e0ae72692 100644 --- a/xllm/core/util/utils.h +++ b/xllm/core/util/utils.h @@ -131,29 +131,6 @@ inline bool is_mla_model_type(std::string_view model_type) { return mla_model_type_set().contains(std::string(model_type)); } -inline bool has_mtp_model_type_marker(std::string_view model_type) { - return model_type.find("mtp") != std::string_view::npos; -} - -inline bool starts_with_model_type(std::string_view model_type, - std::string_view target_model_type) { - return model_type.size() >= target_model_type.size() && - model_type.compare( - 0, target_model_type.size(), target_model_type) == 0; -} - -inline bool is_target_mtp_model_type(std::string_view model_type, - std::string_view target_model_type) { - return starts_with_model_type(model_type, target_model_type) && - has_mtp_model_type_marker(model_type); -} - -inline bool is_deepseek_v4_model_type(std::string_view model_type) { - constexpr std::string_view kTargetModelType = "deepseek_v4"; - return 
model_type == kTargetModelType || - is_target_mtp_model_type(model_type, kTargetModelType); -} - inline std::string get_model_name( const std::filesystem::path& normalized_model_path) { std::string model_name; diff --git a/xllm/models/llm/deepseek_v4_mtp.h b/xllm/models/llm/deepseek_v4_mtp.h index e76c43a94f..dc37d52921 100644 --- a/xllm/models/llm/deepseek_v4_mtp.h +++ b/xllm/models/llm/deepseek_v4_mtp.h @@ -60,8 +60,7 @@ class DeepseekV4MultiTokenPredictorLayerImpl torch::Tensor positions, layer::AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params, - torch::Tensor tokens) { + const ModelInputParams& input_params) { ModelInputParams modified_input_params = input_params; modified_input_params.input_embedding = previous_hidden_states; std::optional residual; @@ -71,30 +70,27 @@ class DeepseekV4MultiTokenPredictorLayerImpl positions, attn_metadata, kv_cache, - modified_input_params, - tokens); + modified_input_params); } }; TORCH_MODULE(DeepseekV4MultiTokenPredictorLayer); class DeepseekV4MtpModelImpl final : public torch::nn::Module { public: - explicit DeepseekV4MtpModelImpl(const ModelContext& context) { - auto model_args = context.get_model_args(); + explicit DeepseekV4MtpModelImpl(const ModelContext& context) + : model_args_(context.get_model_args()) { auto options = context.get_tensor_options(); auto parallel_args = context.get_parallel_args(); - model_args_ = &model_args; - - CHECK_GT(model_args.n_layers(), 0) + CHECK_GT(model_args_.n_layers(), 0) << "deepseek_v4_mtp requires n_layers > 0"; - CHECK_GE(model_args.num_nextn_predict_layers(), 0) + CHECK_GE(model_args_.num_nextn_predict_layers(), 0) << "deepseek_v4_mtp requires num_nextn_predict_layers >= 0"; - const int32_t mtp_n_layers = model_args.n_layers(); + const int32_t mtp_n_layers = model_args_.n_layers(); - num_heads_ = model_args.n_heads(); - head_dim_ = model_args.o_lora_rank() + model_args.qk_rope_head_dim(); + num_heads_ = model_args_.n_heads(); + head_dim_ = 
model_args_.o_lora_rank() + model_args_.qk_rope_head_dim(); dp_local_tp_size_ = std::max(parallel_args.world_size() / std::max(parallel_args.dp_size(), 1), @@ -104,31 +100,31 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { "size. n_heads=" << num_heads_ << ", local_tp_size=" << dp_local_tp_size_; tp_num_heads_ = num_heads_ / dp_local_tp_size_; - window_size_ = model_args.window_size(); - index_n_heads_ = model_args.index_n_heads(); - index_head_dim_ = model_args.index_head_dim(); - index_topk_ = model_args.index_topk(); - norm_eps_ = static_cast(model_args.rms_norm_eps()); - - const int64_t rope_head_dim = model_args.rope_head_dim(); - const int64_t max_pos = model_args.max_position_embeddings(); + window_size_ = model_args_.window_size(); + index_n_heads_ = model_args_.index_n_heads(); + index_head_dim_ = model_args_.index_head_dim(); + index_topk_ = model_args_.index_topk(); + norm_eps_ = static_cast(model_args_.rms_norm_eps()); + + const int64_t rope_head_dim = model_args_.rope_head_dim(); + const int64_t max_pos = model_args_.max_position_embeddings(); if (rope_head_dim > 0 && max_pos > 0) { const int64_t original_max_pos = - model_args.rope_scaling_original_max_position_embeddings() > 0 - ? model_args.rope_scaling_original_max_position_embeddings() + model_args_.rope_scaling_original_max_position_embeddings() > 0 + ? 
model_args_.rope_scaling_original_max_position_embeddings() : max_pos; dsa_rotary_embedding_ = std::make_shared( /*rotary_dim=*/rope_head_dim, /*max_position_embeddings=*/max_pos, /*interleaved=*/true, - /*rope_theta=*/model_args.rope_theta(), - /*compress_rope_theta=*/model_args.compress_rope_theta(), - /*scaling_factor=*/model_args.factor(), + /*rope_theta=*/model_args_.rope_theta(), + /*compress_rope_theta=*/model_args_.compress_rope_theta(), + /*scaling_factor=*/model_args_.factor(), /*extrapolation_factor=*/1.0f, - /*beta_fast=*/model_args.beta_fast(), - /*beta_slow=*/model_args.beta_slow(), - /*attn_factor=*/model_args.rope_scaling_attn_factor(), + /*beta_fast=*/model_args_.beta_fast(), + /*beta_slow=*/model_args_.beta_slow(), + /*attn_factor=*/model_args_.rope_scaling_attn_factor(), /*mscale=*/1.0f, /*mscale_all_dim=*/1.0f, /*original_max_position_embeddings=*/original_max_pos, @@ -136,14 +132,14 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { dsa_cos_sin_ = dsa_rotary_embedding_->get_cos_sin_cache("default"); } - if (model_args.index_head_dim() > 0) { + if (model_args_.index_head_dim() > 0) { int64_t hadamard_dim_padded = - deepseek_v4_next_power_of_two(model_args.index_head_dim()); + deepseek_v4_next_power_of_two(model_args_.index_head_dim()); dsa_hadamard_ = deepseek_v4_create_hadamard_matrix( hadamard_dim_padded, options.dtype().toScalarType(), options.device()); } - deepseek_v4_build_cache_specs(model_args, caches_info_, group_infos_); + deepseek_v4_build_cache_specs(model_args_, caches_info_, group_infos_); mtp_layers_.reserve(mtp_n_layers); for (int32_t i = 0; i < mtp_n_layers; ++i) { @@ -256,8 +252,7 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { positions, *modified_input_params.attn_metadata, kv_caches[i], - modified_input_params, - tokens); + modified_input_params); } auto [output, _] = final_norm_(hidden_states, std::nullopt); @@ -648,8 +643,8 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module 
{ const int32_t layer_compress_ratio = deepseek_v4_normalize_compress_ratio( (layer_id < - static_cast(model_args_->compress_ratios().size())) - ? model_args_->compress_ratios()[static_cast(layer_id)] + static_cast(model_args_.compress_ratios().size())) + ? model_args_.compress_ratios()[static_cast(layer_id)] : 1); if (layer_compress_ratio == 4 && dsa.c4_cos.defined()) { @@ -706,7 +701,7 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { int64_t index_topk_ = 512; double norm_eps_ = 1e-6; - const ModelArgs* model_args_ = nullptr; + ModelArgs model_args_; layer::RMSNorm final_norm_{nullptr}; layer::WordEmbedding embed_tokens_{nullptr}; diff --git a/xllm/models/llm/mtp_model_base.h b/xllm/models/llm/mtp_model_base.h index 7535b26462..b97fa37d27 100644 --- a/xllm/models/llm/mtp_model_base.h +++ b/xllm/models/llm/mtp_model_base.h @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include "core/framework/state_dict/utils.h" @@ -32,7 +31,7 @@ namespace xllm { enum class MtpProjectionType { CONCAT_EH_PROJ, ADD_EH_PROJ }; inline bool is_deepseek_v4_mtp_model(const ModelArgs& model_args) { - return util::is_target_mtp_model_type(model_args.model_type(), "deepseek_v4"); + return model_args.model_type() == "deepseek_v4_mtp"; } inline MtpProjectionType get_mtp_projection_type(const ModelArgs& model_args) { @@ -80,13 +79,8 @@ class MtpDecoderLayerImplBase : public torch::nn::Module { /*QuantArgs=*/QuantArgs(), options)); } - const int32_t decoder_layer_index = is_deepseek_v4_mtp_model(model_args_) - ? 
std::min( layer_index, - model_args_.n_layers() - 1) - : layer_index; mtp_block_ = register_module("mtp_block", - DecoderLayerType(context, decoder_layer_index)); + DecoderLayerType(context, layer_index)); if (is_deepseek_v4_mtp_model(model_args_)) { const int64_t hc_mult = model_args_.hc_mult(); @@ -106,9 +100,7 @@ class MtpDecoderLayerImplBase : public torch::nn::Module { torch::Tensor positions, const layer::AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params, - const std::optional& input_ids = - std::nullopt) { + const ModelInputParams& input_params) { // Layer norm on token inputs auto enorm_out = std::get<0>(enorm_(embed)); @@ -138,29 +130,12 @@ class MtpDecoderLayerImplBase : public torch::nn::Module { } // Pass through mtp block - if constexpr (std::is_invocable_v&, - torch::Tensor&, - const layer::AttentionMetadata&, - KVCache&, - const ModelInputParams&, - const std::optional&>) { - hidden_states = mtp_block_(hidden_states, - residual, - positions, - attn_metadata, - kv_cache, - input_params, - input_ids); - } else { - hidden_states = mtp_block_(hidden_states, - residual, - positions, - attn_metadata, - kv_cache, - input_params); - } + hidden_states = mtp_block_(hidden_states, + residual, + positions, + attn_metadata, + kv_cache, + input_params); if (is_deepseek_v4_mtp_model(model_args_)) { auto x_float = hidden_states.to(torch::kFloat32); @@ -311,8 +286,7 @@ class MtpModelImplBase : public torch::nn::Module { positions, attn_metadata, kv_caches[i], - modified_input_params, - tokens); + modified_input_params); if (!modified_input_params.record_layer(static_cast(i), hidden_states.device())) { return ModelOutput(); From 6c8852988b152442895db4728f51ba88add64f73 Mon Sep 17 00:00:00 2001 From: panxuanyu Date: Wed, 20 May 2026 17:38:47 +0800 Subject: [PATCH 02/10] refactor: minimize changes/1. 
--- xllm/core/runtime/params_utils.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp index 642cadeb27..21a0940632 100644 --- a/xllm/core/runtime/params_utils.cpp +++ b/xllm/core/runtime/params_utils.cpp @@ -224,10 +224,12 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, input_params.batch_forward_type = BatchForwardType(pb_forward_input->batch_forward_type()); input_params.num_sequences = pb_forward_input->num_sequences(); - CHECK_EQ(block_tables_vec.size(), - static_cast(input_params.num_sequences)) - << "block_tables_vec size (" << block_tables_vec.size() - << ") must match num_sequences (" << input_params.num_sequences << ")"; + if (!block_tables_vec.empty()) { + CHECK_EQ(block_tables_vec.size(), + static_cast(input_params.num_sequences)) + << "block_tables_vec size (" << block_tables_vec.size() + << ") must match num_sequences (" << input_params.num_sequences << ")"; + } input_params.kv_max_seq_len = pb_forward_input->max_seq_len(); input_params.q_max_seq_len = pb_forward_input->q_max_seq_len(); input_params.kv_seq_lens = torch::tensor(seq_lens, tensor_options); From 5e90d8c829b5ca480834822eaec7da87485e697c Mon Sep 17 00:00:00 2001 From: panxuanyu Date: Thu, 21 May 2026 14:43:41 +0800 Subject: [PATCH 03/10] feat: support npu_graph for dsv4_mtp in 0521 new style. --- xllm/models/llm/deepseek_v4_mtp.h | 342 +++++++++++++++++++++++++++++- 1 file changed, 334 insertions(+), 8 deletions(-) diff --git a/xllm/models/llm/deepseek_v4_mtp.h b/xllm/models/llm/deepseek_v4_mtp.h index dc37d52921..541cf77883 100644 --- a/xllm/models/llm/deepseek_v4_mtp.h +++ b/xllm/models/llm/deepseek_v4_mtp.h @@ -36,6 +36,7 @@ limitations under the License. 
#include "core/kernels/ops_api.h" #include "core/layers/common/attention_metadata_builder.h" #include "core/layers/common/dsa_metadata.h" +#include "core/layers/common/dsa_metadata_builder.h" #include "core/layers/common/linear.h" #include "core/layers/common/rms_norm.h" #include "core/layers/common/word_embedding.h" @@ -81,6 +82,7 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { : model_args_(context.get_model_args()) { auto options = context.get_tensor_options(); auto parallel_args = context.get_parallel_args(); + device_ = options.device(); CHECK_GT(model_args_.n_layers(), 0) << "deepseek_v4_mtp requires n_layers > 0"; @@ -211,9 +213,11 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { const ModelInputParams& input_params) { torch::NoGradGuard no_grad; - if (tokens.numel() == 0) { - tokens = torch::tensor({0}).to(torch::kInt32).to(tokens.device()); - positions = torch::tensor({0}).to(torch::kInt32).to(tokens.device()); + if (!tokens.defined() || tokens.numel() == 0) { + tokens = torch::tensor( + {0}, torch::TensorOptions().dtype(torch::kInt32).device(device_)); + positions = torch::tensor( + {0}, torch::TensorOptions().dtype(torch::kInt32).device(device_)); } const torch::Device runtime_device = tokens.device(); @@ -224,8 +228,20 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { torch::Tensor hidden_states = embed_tokens_(tokens); - tokens = maybe_to_device(tokens, runtime_device); - positions = maybe_to_device(positions, runtime_device); + const bool acl_graph_forward = deepseek_v4_uses_acl_graph(input_params); + if (acl_graph_forward) { + CHECK(tokens.defined() && tokens.device() == runtime_device) + << "[DeepseekV4Mtp] ACL graph requires tokens on the runtime device"; + CHECK(positions.defined() && positions.device() == runtime_device) + << "[DeepseekV4Mtp] ACL graph requires positions on the runtime device"; + CHECK(input_params.new_cache_slots.defined()) + << "[DeepseekV4Mtp] ACL graph requires persistent 
new_cache_slots"; + CHECK(input_params.block_tables.defined()) + << "[DeepseekV4Mtp] ACL graph requires persistent block_tables"; + } else { + tokens = maybe_to_device(tokens, runtime_device); + positions = maybe_to_device(positions, runtime_device); + } auto mask = (positions == 0); if (mask.any().item()) { @@ -234,11 +250,30 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { } auto modified_input_params = input_params; + if (acl_graph_forward) { + normalize_graph_metadata_input_params(modified_input_params); + } auto& dp_token_nums = modified_input_params.dp_global_token_nums; std::replace(dp_token_nums.begin(), dp_token_nums.end(), 0, 1); - modified_input_params.attn_metadata = - build_attention_metadata_for_forward(positions, modified_input_params); + if (!modified_input_params.attn_metadata) { + CHECK(!acl_graph_forward) + << "[DeepseekV4Mtp] ACL graph requires prebuilt attention metadata"; + modified_input_params.attn_metadata = + build_attention_metadata_for_forward(positions, modified_input_params); + } + + if (modified_input_params.attn_metadata && + modified_input_params.attn_metadata->dsa_metadata) { + auto& dsa = *(modified_input_params.attn_metadata->dsa_metadata); + const bool graph_metadata_ready = acl_graph_forward && + dsa.packed_metadata_buffer.defined() && + dsa.start_pos.defined(); + if (graph_metadata_ready) { + build_rope(dsa, runtime_device); + build_precomputed_metadata(dsa); + } + } CHECK_GE(static_cast(kv_caches.size()), static_cast(mtp_layers_.size())) @@ -259,7 +294,281 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { return ModelOutput(output, std::nullopt); } - private: + bool requires_graph_forward_metadata() { return true; } + + std::unique_ptr + create_graph_forward_metadata_state() { + return std::make_unique(); + } + + void prepare_graph_forward_metadata(ModelGraphMetadataState* state, + const torch::Tensor& positions, + ModelInputParams& input_params) { + CHECK(state != nullptr) + << 
"[DeepseekV4Mtp] graph metadata state must be initialized"; + auto* deepseek_v4_state = dynamic_cast(state); + CHECK(deepseek_v4_state != nullptr) + << "[DeepseekV4Mtp] received incompatible graph metadata state"; + + auto modified_input_params = input_params; + if (modified_input_params.actual_num_sequences == 0) { + fill_empty_dp_rank_input_params(modified_input_params); + } + normalize_graph_metadata_input_params(modified_input_params); + auto& dp_token_nums = modified_input_params.dp_global_token_nums; + std::replace(dp_token_nums.begin(), dp_token_nums.end(), 0, 1); + + auto attn_metadata = std::make_shared( + layer::DSAMetadataBuilder::build(modified_input_params, + positions, + dsa_cos_sin_, + caches_info_, + group_infos_)); + if (attn_metadata->dsa_metadata) { + auto& dsa = *attn_metadata->dsa_metadata; + if (dsa_hadamard_.defined()) { + dsa.hadamard = dsa_hadamard_; + } + copy_to_graph_packed_metadata_buffer( + dsa, deepseek_v4_state->dsa_metadata_persistent, positions.device()); + if (dsa.actual_seq_lengths_kv.defined() && dsa.seq_lens_q.defined()) { + dsa.start_pos = + (dsa.actual_seq_lengths_kv - dsa.seq_lens_q).to(torch::kInt32); + } + } + input_params.attn_metadata = persist_graph_attention_metadata( + *deepseek_v4_state, std::move(attn_metadata)); + CHECK(input_params.attn_metadata) + << "[DeepseekV4Mtp] ACL graph requires DSA metadata"; + } + + private: + static bool tensor_aliases_storage(const torch::Tensor& lhs, + const torch::Tensor& rhs) { + return lhs.defined() && rhs.defined() && lhs.data_ptr() == rhs.data_ptr() && + lhs.sizes() == rhs.sizes() && lhs.strides() == rhs.strides(); + } + + static torch::Tensor copy_to_persistent_tensor(const torch::Tensor& src, + torch::Tensor& dst) { + if (!src.defined()) { + return src; + } + if (!dst.defined()) { + dst = torch::empty_like(src); + } else { + CHECK_EQ(dst.scalar_type(), src.scalar_type()) + << "[DeepseekV4Mtp] graph metadata tensor dtype changed"; + CHECK_EQ(dst.device(), src.device()) + << 
"[DeepseekV4Mtp] graph metadata tensor device changed"; + if (dst.sizes() != src.sizes()) { + bool can_copy_into_capacity = dst.dim() == src.dim() && src.dim() > 0 && + src.size(0) <= dst.size(0); + for (int64_t dim = 1; can_copy_into_capacity && dim < src.dim(); + ++dim) { + can_copy_into_capacity = dst.size(dim) == src.size(dim); + } + CHECK(can_copy_into_capacity) + << "[DeepseekV4Mtp] graph metadata tensor size changed from " + << dst.sizes() << " to " << src.sizes(); + dst.zero_(); + dst.slice(/*dim=*/0, /*start=*/0, /*end=*/src.size(0)) + .copy_(src, /*non_blocking=*/true); + return dst; + } + } + if (!tensor_aliases_storage(src, dst)) { + dst.copy_(src, /*non_blocking=*/true); + } + return dst; + } + + static void copy_to_graph_packed_metadata_buffer( + layer::DSAMetadata& dsa, + DeepseekV4GraphMetadataState::DSAMetadataPersistent& persistent, + const torch::Device& runtime_device) { +#if defined(USE_NPU) + if (runtime_device.is_cpu() || + runtime_device.type() != c10::DeviceType::PrivateUse1) { + deepseek_v4_move_dsa_metadata_to_device(dsa, runtime_device); + return; + } + + std::vector specs; + deepseek_v4_collect_cpu_metadata_tensors(dsa, runtime_device, specs); + const size_t total_bytes = deepseek_v4_layout_packed_tensor_specs(specs); + if (total_bytes == 0) { + return; + } + + if (!persistent.packed_metadata_host_buffer.defined() || + persistent.packed_metadata_host_buffer.scalar_type() != torch::kUInt8 || + persistent.packed_metadata_host_buffer.device() != torch::kCPU || + persistent.packed_metadata_host_buffer.numel() < + static_cast(total_bytes)) { + persistent.packed_metadata_host_buffer = + torch::empty({static_cast(total_bytes)}, + torch::TensorOptions() + .dtype(torch::kUInt8) + .device(torch::kCPU) + .pinned_memory(true)); + } + auto host_buffer = persistent.packed_metadata_host_buffer.slice( + /*dim=*/0, /*start=*/0, /*end=*/static_cast(total_bytes)); + deepseek_v4_fill_packed_host_buffer(specs, host_buffer); + auto device_options = + 
torch::TensorOptions().dtype(torch::kUInt8).device(runtime_device); + if (!persistent.packed_metadata_buffer.defined()) { + persistent.packed_metadata_buffer = + torch::empty({static_cast(total_bytes)}, device_options); + } else { + CHECK_EQ(persistent.packed_metadata_host_buffer.scalar_type(), + torch::kUInt8) + << "[DeepseekV4Mtp] graph host packed metadata dtype changed"; + CHECK_EQ(persistent.packed_metadata_host_buffer.device(), torch::kCPU) + << "[DeepseekV4Mtp] graph host packed metadata device changed"; + CHECK_GE(persistent.packed_metadata_host_buffer.numel(), + static_cast(total_bytes)) + << "[DeepseekV4Mtp] graph host packed metadata exceeds persistent " + "capacity: required=" + << total_bytes + << ", capacity=" << persistent.packed_metadata_host_buffer.numel(); + CHECK_EQ(persistent.packed_metadata_buffer.scalar_type(), torch::kUInt8) + << "[DeepseekV4Mtp] graph packed metadata dtype changed"; + CHECK_EQ(persistent.packed_metadata_buffer.device(), runtime_device) + << "[DeepseekV4Mtp] graph packed metadata device changed"; + CHECK_GE(persistent.packed_metadata_buffer.numel(), + static_cast(total_bytes)) + << "[DeepseekV4Mtp] graph packed metadata exceeds persistent capacity: " + << "required=" << total_bytes + << ", capacity=" << persistent.packed_metadata_buffer.numel(); + } + + persistent.packed_metadata_buffer + .slice(/*dim=*/0, + /*start=*/0, + /*end=*/static_cast(total_bytes)) + .copy_(host_buffer, /*non_blocking=*/true); + dsa.packed_metadata_buffer = persistent.packed_metadata_buffer.slice( + /*dim=*/0, /*start=*/0, /*end=*/static_cast(total_bytes)); + deepseek_v4_bind_packed_tensor_views(specs, dsa.packed_metadata_buffer); +#else + (void)persistent; + deepseek_v4_move_dsa_metadata_to_device(dsa, runtime_device); +#endif + } + + static int64_t infer_actual_batch_size(const ModelInputParams& params) { + if (params.actual_num_sequences > 0) { + return params.actual_num_sequences; + } + if (!params.kv_seq_lens_vec.empty()) { + return 
static_cast(params.kv_seq_lens_vec.size()); + } + if (!params.q_seq_lens_vec.empty()) { + return static_cast(params.q_seq_lens_vec.size()); + } + if (params.kv_seq_lens.defined() && params.kv_seq_lens.dim() >= 1) { + return params.kv_seq_lens.size(0); + } + if (params.q_seq_lens.defined() && params.q_seq_lens.dim() >= 1) { + return params.q_seq_lens.size(0); + } + if (params.block_tables.defined() && params.block_tables.dim() >= 2) { + return params.block_tables.size(0); + } + for (const auto& block_table : params.multi_block_tables) { + if (block_table.defined() && block_table.dim() >= 2) { + return block_table.size(0); + } + } + return 0; + } + + void normalize_graph_metadata_input_params(ModelInputParams& params) const { + const int64_t actual_batch_size = + std::max(infer_actual_batch_size(params), 0); + int64_t metadata_batch_size = params.actual_num_sequences; + if (params.enable_graph) { + metadata_batch_size = + std::max(metadata_batch_size, params.num_sequences); + } + if (metadata_batch_size <= 0) { + metadata_batch_size = 1; + } + + auto trim_lens_vec = [metadata_batch_size, + actual_batch_size](std::vector& lens) { + if (lens.empty()) { + lens.assign(static_cast(metadata_batch_size), 0); + } else if (static_cast(lens.size()) < metadata_batch_size) { + lens.resize(static_cast(metadata_batch_size), 0); + } else { + lens.resize(static_cast(metadata_batch_size)); + } + for (int64_t i = 0; i < static_cast(lens.size()); ++i) { + if (i < actual_batch_size) { + lens[static_cast(i)] = std::max(lens[i], 1); + } else { + lens[static_cast(i)] = 0; + } + } + }; + + trim_lens_vec(params.kv_seq_lens_vec); + trim_lens_vec(params.q_seq_lens_vec); + params.num_sequences = static_cast(metadata_batch_size); + params.actual_num_sequences = static_cast(metadata_batch_size); + } + + std::shared_ptr persist_graph_attention_metadata( + DeepseekV4GraphMetadataState& state, + std::shared_ptr metadata) const { + if (!metadata || !metadata->dsa_metadata) { + return metadata; + } + 
+ auto& dsa = *metadata->dsa_metadata; + auto& persistent = state.dsa_metadata_persistent; + + dsa.attn_mask = + copy_to_persistent_tensor(dsa.attn_mask, persistent.attn_mask); + dsa.start_pos = + copy_to_persistent_tensor(dsa.start_pos, persistent.start_pos); + + return metadata; + } + + void fill_empty_dp_rank_input_params(ModelInputParams& params) const { + auto cpu_int_options = torch::TensorOptions() + .dtype(torch::kInt32) + .device(torch::kCPU) + .pinned_memory(true); + params.num_sequences = 1; + params.kv_max_seq_len = std::max(params.kv_max_seq_len, 1); + params.q_max_seq_len = std::max(params.q_max_seq_len, 1); + params.kv_seq_lens_vec = {1}; + params.q_seq_lens_vec = {1}; + params.kv_seq_lens = torch::tensor(params.kv_seq_lens_vec, cpu_int_options); + params.q_seq_lens = torch::tensor(params.q_seq_lens_vec, cpu_int_options); + params.q_cu_seq_lens = torch::tensor({1}, cpu_int_options); + params.kv_cache_tokens_nums = torch::tensor({1}, cpu_int_options); + params.kv_cache_tokens_nums_host = {1}; + params.new_cache_slots = torch::tensor({0}, cpu_int_options); + params.block_tables = torch::zeros({1, 1}, cpu_int_options); + + if (!params.multi_block_tables.empty()) { + return; + } + + const int32_t manager_num = static_cast(group_infos_.size()); + params.multi_block_tables.reserve(manager_num); + for (int32_t manager_id = 0; manager_id < manager_num; ++manager_id) { + params.multi_block_tables.push_back( + torch::zeros({1, 1}, cpu_int_options)); + } + } + std::shared_ptr build_attention_metadata_for_forward(const torch::Tensor& positions, const ModelInputParams& input_params) { @@ -702,6 +1011,7 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { double norm_eps_ = 1e-6; ModelArgs model_args_; + torch::Device device_{torch::kCPU}; layer::RMSNorm final_norm_{nullptr}; layer::WordEmbedding embed_tokens_{nullptr}; @@ -725,6 +1035,22 @@ class DeepseekV4MtpForCausalLMImpl final } model_->verify_loaded_weights(prefix); } + + bool 
requires_graph_forward_metadata() { + return this->model_->requires_graph_forward_metadata(); + } + + std::unique_ptr + create_graph_forward_metadata_state() { + return this->model_->create_graph_forward_metadata_state(); + } + + void prepare_graph_forward_metadata(ModelGraphMetadataState* state, + const torch::Tensor& positions, + ModelInputParams& input_params) { + this->model_->prepare_graph_forward_metadata( + state, positions, input_params); + } }; TORCH_MODULE(DeepseekV4MtpForCausalLM); From ef82fefb71860c5190d4d007c66ba55f6c9b74d1 Mon Sep 17 00:00:00 2001 From: panxuanyu Date: Thu, 21 May 2026 17:05:22 +0800 Subject: [PATCH 04/10] bugfix: fix incorrect graph padding mask applied on mtp tokens. --- xllm/core/runtime/acl_graph_executor_impl.cpp | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/xllm/core/runtime/acl_graph_executor_impl.cpp b/xllm/core/runtime/acl_graph_executor_impl.cpp index 1e049cb47e..0cc13c4000 100644 --- a/xllm/core/runtime/acl_graph_executor_impl.cpp +++ b/xllm/core/runtime/acl_graph_executor_impl.cpp @@ -425,6 +425,10 @@ std::optional GraphPersistentParam::update( std::max(options_.num_decoding_tokens(), 1); actual_batch_size = actual_num_tokens / decode_tokens; } + const int64_t padded_batch_size = static_cast(padded_num_tokens); + const int64_t actual_metadata_rows = std::min( + padded_batch_size, + std::max(actual_batch_size, params.num_sequences)); // Copy data from input parameters to persistent graph tensors if (actual_num_tokens > 0) { @@ -461,10 +465,10 @@ std::optional GraphPersistentParam::update( } } int64_t q_copy_len = 0; - if (actual_batch_size > 0 && params.q_seq_lens.defined() && + if (actual_metadata_rows > 0 && params.q_seq_lens.defined() && params.q_seq_lens.dim() >= 1 && params.q_seq_lens.numel() > 0) { q_copy_len = - std::min(actual_batch_size, params.q_seq_lens.size(0)); + std::min(actual_metadata_rows, params.q_seq_lens.size(0)); if (q_copy_len > 0) { 
q_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/q_copy_len) .copy_(params.q_seq_lens.slice(/*dim=*/0, @@ -474,10 +478,10 @@ std::optional GraphPersistentParam::update( } } int64_t kv_copy_len = 0; - if (actual_batch_size > 0 && params.kv_seq_lens.defined() && + if (actual_metadata_rows > 0 && params.kv_seq_lens.defined() && params.kv_seq_lens.dim() >= 1 && params.kv_seq_lens.numel() > 0) { kv_copy_len = - std::min(actual_batch_size, params.kv_seq_lens.size(0)); + std::min(actual_metadata_rows, params.kv_seq_lens.size(0)); if (kv_copy_len > 0) { kv_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/kv_copy_len) .copy_(params.kv_seq_lens.slice(/*dim=*/0, @@ -488,7 +492,6 @@ std::optional GraphPersistentParam::update( } // Keep padded decode slots valid for empty/local-short DP shards. // These tensors are consumed by ATB setup alongside *_seq_lens_vec. - const int64_t padded_batch_size = static_cast(padded_num_tokens); if (q_copy_len < padded_batch_size) { q_seq_lens_ .slice(/*dim=*/0, @@ -551,10 +554,10 @@ std::optional GraphPersistentParam::update( // zeroed in the active bucket slice. 
int64_t block_rows_to_copy = 0; int64_t actual_block_table_len = 0; - if (actual_batch_size > 0 && params.block_tables.defined() && + if (actual_metadata_rows > 0 && params.block_tables.defined() && params.block_tables.dim() >= 2 && params.block_tables.numel() > 0) { block_rows_to_copy = - std::min(actual_batch_size, params.block_tables.size(0)); + std::min(actual_metadata_rows, params.block_tables.size(0)); actual_block_table_len = params.block_tables.size(1); if (block_rows_to_copy > 0 && actual_block_table_len > 0) { auto slice_persistent_block_tables = @@ -577,9 +580,9 @@ std::optional GraphPersistentParam::update( /*end=*/persistent_block_tables_.size(1)) .zero_(); } - if (actual_batch_size < padded_batch_size) { + if (actual_metadata_rows < padded_batch_size) { zero_tensor_tail( - persistent_block_tables_, actual_batch_size, padded_batch_size); + persistent_block_tables_, actual_metadata_rows, padded_batch_size); } // Update persistent embedding from input_embedding if available @@ -611,7 +614,8 @@ std::optional GraphPersistentParam::update( const int64_t q_cu_size = (has_q_cu && params.q_cu_seq_lens.numel() > 0) ? 
params.q_cu_seq_lens.size(0) : 0; - const int64_t q_cu_copy_len = std::min(actual_batch_size, q_cu_size); + const int64_t q_cu_copy_len = + std::min(actual_metadata_rows, q_cu_size); if (q_cu_copy_len > 0) { q_cu_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/q_cu_copy_len) .copy_(params.q_cu_seq_lens.slice(/*dim=*/0, @@ -667,13 +671,11 @@ std::optional GraphPersistentParam::update( params_for_capture->q_seq_lens = q_seq_lens(padded_num_tokens); params_for_capture->kv_seq_lens_vec.resize(padded_num_tokens); params_for_capture->q_seq_lens_vec.resize(padded_num_tokens); - // Copy actual values from original params - for (int i = 0; i < actual_batch_size; i++) { + for (int64_t i = 0; i < actual_metadata_rows; ++i) { params_for_capture->kv_seq_lens_vec[i] = params.kv_seq_lens_vec[i]; params_for_capture->q_seq_lens_vec[i] = params.q_seq_lens_vec[i]; } - // Fill padded positions with default values - for (int i = actual_batch_size; i < padded_num_tokens; i++) { + for (int64_t i = actual_metadata_rows; i < padded_batch_size; ++i) { params_for_capture->kv_seq_lens_vec[i] = 1; params_for_capture->q_seq_lens_vec[i] = 1; } From e2df3d90f9186ecab45969bcf4da6cb60fcc3de8 Mon Sep 17 00:00:00 2001 From: panxuanyu Date: Thu, 21 May 2026 22:37:45 +0800 Subject: [PATCH 05/10] bugfix: fix incorrect graph padding mask applied on mtp tokens/1. 
--- xllm/core/runtime/acl_graph_executor_impl.cpp | 3 +- xllm/models/llm/deepseek_v4.h | 48 ++++--------------- xllm/models/llm/deepseek_v4_mtp.h | 48 ++++--------------- 3 files changed, 20 insertions(+), 79 deletions(-) diff --git a/xllm/core/runtime/acl_graph_executor_impl.cpp b/xllm/core/runtime/acl_graph_executor_impl.cpp index 0cc13c4000..694f561840 100644 --- a/xllm/core/runtime/acl_graph_executor_impl.cpp +++ b/xllm/core/runtime/acl_graph_executor_impl.cpp @@ -425,6 +425,7 @@ std::optional GraphPersistentParam::update( std::max(options_.num_decoding_tokens(), 1); actual_batch_size = actual_num_tokens / decode_tokens; } + // MTP model support: actual_metadata_rows is not always equal to actual_batch_size const int64_t padded_batch_size = static_cast(padded_num_tokens); const int64_t actual_metadata_rows = std::min( padded_batch_size, @@ -666,7 +667,7 @@ std::optional GraphPersistentParam::update( std::make_optional(params); // Set persistent buffers in params_for_capture params_for_capture->actual_num_sequences = - static_cast(actual_batch_size); + static_cast(actual_metadata_rows); params_for_capture->kv_seq_lens = kv_seq_lens(padded_num_tokens); params_for_capture->q_seq_lens = q_seq_lens(padded_num_tokens); params_for_capture->kv_seq_lens_vec.resize(padded_num_tokens); diff --git a/xllm/models/llm/deepseek_v4.h b/xllm/models/llm/deepseek_v4.h index 712921765d..044e786e22 100644 --- a/xllm/models/llm/deepseek_v4.h +++ b/xllm/models/llm/deepseek_v4.h @@ -958,37 +958,10 @@ class DeepseekV4ModelImpl #endif } - static int64_t infer_actual_batch_size(const ModelInputParams& params) { - if (params.actual_num_sequences > 0) { - return params.actual_num_sequences; - } - if (!params.kv_seq_lens_vec.empty()) { - return static_cast(params.kv_seq_lens_vec.size()); - } - if (!params.q_seq_lens_vec.empty()) { - return static_cast(params.q_seq_lens_vec.size()); - } - if (params.kv_seq_lens.defined() && params.kv_seq_lens.dim() >= 1) { - return params.kv_seq_lens.size(0); 
- } - if (params.q_seq_lens.defined() && params.q_seq_lens.dim() >= 1) { - return params.q_seq_lens.size(0); - } - if (params.block_tables.defined() && params.block_tables.dim() >= 2) { - return params.block_tables.size(0); - } - for (const auto& block_table : params.multi_block_tables) { - if (block_table.defined() && block_table.dim() >= 2) { - return block_table.size(0); - } - } - return 0; - } - void normalize_graph_metadata_input_params(ModelInputParams& params) const { - const int64_t actual_batch_size = - std::max(infer_actual_batch_size(params), 0); - int64_t metadata_batch_size = params.actual_num_sequences; + int64_t actual_metadata_rows = std::max(params.actual_num_sequences, + 0); + int64_t metadata_batch_size = actual_metadata_rows; if (params.enable_graph) { metadata_batch_size = std::max(metadata_batch_size, params.num_sequences); @@ -996,9 +969,11 @@ class DeepseekV4ModelImpl if (metadata_batch_size <= 0) { metadata_batch_size = 1; } + actual_metadata_rows = std::min(actual_metadata_rows, + metadata_batch_size); auto trim_lens_vec = [metadata_batch_size, - actual_batch_size](std::vector& lens) { + actual_metadata_rows](std::vector& lens) { if (lens.empty()) { lens.assign(static_cast(metadata_batch_size), 0); } else if (static_cast(lens.size()) < metadata_batch_size) { @@ -1006,13 +981,7 @@ class DeepseekV4ModelImpl } else { lens.resize(static_cast(metadata_batch_size)); } - for (int64_t i = 0; i < static_cast(lens.size()); ++i) { - if (i < actual_batch_size) { - lens[static_cast(i)] = std::max(lens[i], 1); - } else { - lens[static_cast(i)] = 0; - } - } + std::fill(lens.begin() + actual_metadata_rows, lens.end(), 0); }; // Graph forward tensors are padded to the decode bucket. 
Build metadata @@ -1020,7 +989,7 @@ class DeepseekV4ModelImpl trim_lens_vec(params.kv_seq_lens_vec); trim_lens_vec(params.q_seq_lens_vec); params.num_sequences = static_cast(metadata_batch_size); - params.actual_num_sequences = static_cast(metadata_batch_size); + params.actual_num_sequences = static_cast(actual_metadata_rows); } std::shared_ptr @@ -1065,6 +1034,7 @@ class DeepseekV4ModelImpl .device(torch::kCPU) .pinned_memory(true); params.num_sequences = 1; + params.actual_num_sequences = 1; params.kv_max_seq_len = std::max(params.kv_max_seq_len, 1); params.q_max_seq_len = std::max(params.q_max_seq_len, 1); params.kv_seq_lens_vec = {1}; diff --git a/xllm/models/llm/deepseek_v4_mtp.h b/xllm/models/llm/deepseek_v4_mtp.h index 541cf77883..5eb652ee8f 100644 --- a/xllm/models/llm/deepseek_v4_mtp.h +++ b/xllm/models/llm/deepseek_v4_mtp.h @@ -458,37 +458,10 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { #endif } - static int64_t infer_actual_batch_size(const ModelInputParams& params) { - if (params.actual_num_sequences > 0) { - return params.actual_num_sequences; - } - if (!params.kv_seq_lens_vec.empty()) { - return static_cast(params.kv_seq_lens_vec.size()); - } - if (!params.q_seq_lens_vec.empty()) { - return static_cast(params.q_seq_lens_vec.size()); - } - if (params.kv_seq_lens.defined() && params.kv_seq_lens.dim() >= 1) { - return params.kv_seq_lens.size(0); - } - if (params.q_seq_lens.defined() && params.q_seq_lens.dim() >= 1) { - return params.q_seq_lens.size(0); - } - if (params.block_tables.defined() && params.block_tables.dim() >= 2) { - return params.block_tables.size(0); - } - for (const auto& block_table : params.multi_block_tables) { - if (block_table.defined() && block_table.dim() >= 2) { - return block_table.size(0); - } - } - return 0; - } - void normalize_graph_metadata_input_params(ModelInputParams& params) const { - const int64_t actual_batch_size = - std::max(infer_actual_batch_size(params), 0); - int64_t metadata_batch_size = 
params.actual_num_sequences; + int64_t actual_metadata_rows = std::max(params.actual_num_sequences, + 0); + int64_t metadata_batch_size = actual_metadata_rows; if (params.enable_graph) { metadata_batch_size = std::max(metadata_batch_size, params.num_sequences); @@ -496,9 +469,11 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { if (metadata_batch_size <= 0) { metadata_batch_size = 1; } + actual_metadata_rows = std::min(actual_metadata_rows, + metadata_batch_size); auto trim_lens_vec = [metadata_batch_size, - actual_batch_size](std::vector& lens) { + actual_metadata_rows](std::vector& lens) { if (lens.empty()) { lens.assign(static_cast(metadata_batch_size), 0); } else if (static_cast(lens.size()) < metadata_batch_size) { @@ -506,19 +481,13 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { } else { lens.resize(static_cast(metadata_batch_size)); } - for (int64_t i = 0; i < static_cast(lens.size()); ++i) { - if (i < actual_batch_size) { - lens[static_cast(i)] = std::max(lens[i], 1); - } else { - lens[static_cast(i)] = 0; - } - } + std::fill(lens.begin() + actual_metadata_rows, lens.end(), 0); }; trim_lens_vec(params.kv_seq_lens_vec); trim_lens_vec(params.q_seq_lens_vec); params.num_sequences = static_cast(metadata_batch_size); - params.actual_num_sequences = static_cast(metadata_batch_size); + params.actual_num_sequences = static_cast(actual_metadata_rows); } std::shared_ptr persist_graph_attention_metadata( @@ -545,6 +514,7 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { .device(torch::kCPU) .pinned_memory(true); params.num_sequences = 1; + params.actual_num_sequences = 1; params.kv_max_seq_len = std::max(params.kv_max_seq_len, 1); params.q_max_seq_len = std::max(params.q_max_seq_len, 1); params.kv_seq_lens_vec = {1}; From 3b10c3d0d1ae5d5d8a197aa3925c4a9dada5934c Mon Sep 17 00:00:00 2001 From: panxuanyu Date: Fri, 22 May 2026 08:59:35 +0800 Subject: [PATCH 06/10] bugfix: fix incorrect graph padding mask applied on 
mtp tokens/2. --- xllm/core/runtime/acl_graph_executor_impl.cpp | 15 ++++++- xllm/models/llm/deepseek_v4.h | 45 +++++++++++++------ xllm/models/llm/deepseek_v4_mtp.h | 34 +++++++++----- 3 files changed, 66 insertions(+), 28 deletions(-) diff --git a/xllm/core/runtime/acl_graph_executor_impl.cpp b/xllm/core/runtime/acl_graph_executor_impl.cpp index 694f561840..0aa7a32518 100644 --- a/xllm/core/runtime/acl_graph_executor_impl.cpp +++ b/xllm/core/runtime/acl_graph_executor_impl.cpp @@ -425,7 +425,10 @@ std::optional GraphPersistentParam::update( std::max(options_.num_decoding_tokens(), 1); actual_batch_size = actual_num_tokens / decode_tokens; } - // MTP model support: actual_metadata_rows is not always equal to actual_batch_size + // actual_metadata_rows: number of metadata rows carrying real token data. + // Normal decode: actual_batch_size (= num_requests). + // MTP validate: params.num_sequences (= num_requests * (1 + num_spec_tokens)), + // which is larger than actual_batch_size. const int64_t padded_batch_size = static_cast(padded_num_tokens); const int64_t actual_metadata_rows = std::min( padded_batch_size, @@ -665,11 +668,19 @@ std::optional GraphPersistentParam::update( if (return_capture_params) { std::optional params_for_capture = std::make_optional(params); - // Set persistent buffers in params_for_capture + + // actual_num_sequences tells the model's normalize function how many + // metadata rows carry real data. For MTP validate this can be larger + // than actual_batch_size because one request expands to + // (1 + num_speculative_tokens) token rows. 
params_for_capture->actual_num_sequences = static_cast(actual_metadata_rows); params_for_capture->kv_seq_lens = kv_seq_lens(padded_num_tokens); params_for_capture->q_seq_lens = q_seq_lens(padded_num_tokens); + + // Copy real seq_lens rows [0, actual_metadata_rows), fill padding + // rows [actual_metadata_rows, padded_batch_size) with 1 so that the + // graph-captured forward sees valid (non-zero) seq_lens everywhere. params_for_capture->kv_seq_lens_vec.resize(padded_num_tokens); params_for_capture->q_seq_lens_vec.resize(padded_num_tokens); for (int64_t i = 0; i < actual_metadata_rows; ++i) { diff --git a/xllm/models/llm/deepseek_v4.h b/xllm/models/llm/deepseek_v4.h index 044e786e22..7dc363b5ab 100644 --- a/xllm/models/llm/deepseek_v4.h +++ b/xllm/models/llm/deepseek_v4.h @@ -958,37 +958,54 @@ class DeepseekV4ModelImpl #endif } + // Normalize metadata vectors for graph mode (bucket-padded) forward. + // + // Key concepts: + // actual_metadata_rows -- rows that carry real token data. + // For normal decode: = num_requests. + // For MTP validate: = num_requests * (1 + num_spec_tokens). + // padded_metadata_rows -- total rows after bucket padding. + // >= actual_metadata_rows, rounded up to decode bucket. + // Rows in [actual_metadata_rows, padded_metadata_rows) are bucket padding + // and must be zeroed so DSAMetadataBuilder treats them as inactive. + // + // Example (normal decode, 2 requests, bucket size 4): + // actual_metadata_rows = 2, padded_metadata_rows = 4 + // kv_seq_lens_vec = [23, 24, 0, 0] + // + // Example (MTP validate, 1 request, num_speculative_tokens=1, bucket 4): + // actual_metadata_rows = 1*(1+1) = 2, padded_metadata_rows = 4 + // kv_seq_lens_vec = [23, 24, 0, 0] + // Both rows are real; num_speculative_tokens is from speculative config. 
void normalize_graph_metadata_input_params(ModelInputParams& params) const { int64_t actual_metadata_rows = std::max(params.actual_num_sequences, 0); - int64_t metadata_batch_size = actual_metadata_rows; + int64_t padded_metadata_rows = actual_metadata_rows; if (params.enable_graph) { - metadata_batch_size = - std::max(metadata_batch_size, params.num_sequences); + padded_metadata_rows = + std::max(padded_metadata_rows, params.num_sequences); } - if (metadata_batch_size <= 0) { - metadata_batch_size = 1; + if (padded_metadata_rows <= 0) { + padded_metadata_rows = 1; } actual_metadata_rows = std::min(actual_metadata_rows, - metadata_batch_size); + padded_metadata_rows); - auto trim_lens_vec = [metadata_batch_size, + auto trim_lens_vec = [padded_metadata_rows, actual_metadata_rows](std::vector& lens) { if (lens.empty()) { - lens.assign(static_cast(metadata_batch_size), 0); - } else if (static_cast(lens.size()) < metadata_batch_size) { - lens.resize(static_cast(metadata_batch_size), 0); + lens.assign(static_cast(padded_metadata_rows), 0); + } else if (static_cast(lens.size()) < padded_metadata_rows) { + lens.resize(static_cast(padded_metadata_rows), 0); } else { - lens.resize(static_cast(metadata_batch_size)); + lens.resize(static_cast(padded_metadata_rows)); } std::fill(lens.begin() + actual_metadata_rows, lens.end(), 0); }; - // Graph forward tensors are padded to the decode bucket. Build metadata - // for the same padded row count so compressor/attention inputs agree. 
trim_lens_vec(params.kv_seq_lens_vec); trim_lens_vec(params.q_seq_lens_vec); - params.num_sequences = static_cast(metadata_batch_size); + params.num_sequences = static_cast(padded_metadata_rows); params.actual_num_sequences = static_cast(actual_metadata_rows); } diff --git a/xllm/models/llm/deepseek_v4_mtp.h b/xllm/models/llm/deepseek_v4_mtp.h index 5eb652ee8f..6a66c28d03 100644 --- a/xllm/models/llm/deepseek_v4_mtp.h +++ b/xllm/models/llm/deepseek_v4_mtp.h @@ -458,35 +458,45 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { #endif } + // Normalize metadata vectors for graph mode (bucket-padded) forward. + // + // Key concepts: + // actual_metadata_rows -- rows that carry real token data. + // For normal decode: = num_requests. + // For MTP validate: = num_requests * (1 + num_spec_tokens). + // padded_metadata_rows -- total rows after bucket padding. + // >= actual_metadata_rows, rounded up to decode bucket. + // Rows in [actual_metadata_rows, padded_metadata_rows) are bucket padding + // and must be zeroed so DSAMetadataBuilder treats them as inactive. 
void normalize_graph_metadata_input_params(ModelInputParams& params) const { int64_t actual_metadata_rows = std::max(params.actual_num_sequences, 0); - int64_t metadata_batch_size = actual_metadata_rows; + int64_t padded_metadata_rows = actual_metadata_rows; if (params.enable_graph) { - metadata_batch_size = - std::max(metadata_batch_size, params.num_sequences); + padded_metadata_rows = + std::max(padded_metadata_rows, params.num_sequences); } - if (metadata_batch_size <= 0) { - metadata_batch_size = 1; + if (padded_metadata_rows <= 0) { + padded_metadata_rows = 1; } actual_metadata_rows = std::min(actual_metadata_rows, - metadata_batch_size); + padded_metadata_rows); - auto trim_lens_vec = [metadata_batch_size, + auto trim_lens_vec = [padded_metadata_rows, actual_metadata_rows](std::vector& lens) { if (lens.empty()) { - lens.assign(static_cast(metadata_batch_size), 0); - } else if (static_cast(lens.size()) < metadata_batch_size) { - lens.resize(static_cast(metadata_batch_size), 0); + lens.assign(static_cast(padded_metadata_rows), 0); + } else if (static_cast(lens.size()) < padded_metadata_rows) { + lens.resize(static_cast(padded_metadata_rows), 0); } else { - lens.resize(static_cast(metadata_batch_size)); + lens.resize(static_cast(padded_metadata_rows)); } std::fill(lens.begin() + actual_metadata_rows, lens.end(), 0); }; trim_lens_vec(params.kv_seq_lens_vec); trim_lens_vec(params.q_seq_lens_vec); - params.num_sequences = static_cast(metadata_batch_size); + params.num_sequences = static_cast(padded_metadata_rows); params.actual_num_sequences = static_cast(actual_metadata_rows); } From e46cebd82215567f9102601772330299f2a1af71 Mon Sep 17 00:00:00 2001 From: panxuanyu Date: Fri, 22 May 2026 09:40:47 +0800 Subject: [PATCH 07/10] bugfix: remove remained ds_v4 detect --- xllm/core/distributed_runtime/llm_engine.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xllm/core/distributed_runtime/llm_engine.cpp 
b/xllm/core/distributed_runtime/llm_engine.cpp index 3f0ce1ae33..f9118f1c98 100644 --- a/xllm/core/distributed_runtime/llm_engine.cpp +++ b/xllm/core/distributed_runtime/llm_engine.cpp @@ -1437,7 +1437,8 @@ std::vector LLMEngine::prepare_inputs( args_, threadpool_.get(), cp_size_))); dp_global_token_nums[dp_rank] = batched_inputs[dp_rank].flatten_tokens_vec.size(); - if (util::is_deepseek_v4_model_type(args_.model_type())) { + if (args_.model_type() == "deepseek_v4" || + args_.model_type() == "deepseek_v4_mtp") { const int64_t actual_scheduled_tokens = static_cast( batched_inputs[dp_rank].flatten_tokens_vec.size()); const int64_t max_tokens_per_batch = From 0196fed18c6b840c333ac9582cec9e8823db466e Mon Sep 17 00:00:00 2001 From: panxuanyu Date: Fri, 22 May 2026 10:46:36 +0800 Subject: [PATCH 08/10] refactor: remove redundant declarations of graph funcs and replace metadata_rows to num_sequeces. --- xllm/core/runtime/acl_graph_executor_impl.cpp | 54 ++--- xllm/core/runtime/params_utils.cpp | 18 +- xllm/models/llm/deepseek_v4.h | 12 +- xllm/models/llm/deepseek_v4_mtp.h | 190 +----------------- 4 files changed, 57 insertions(+), 217 deletions(-) diff --git a/xllm/core/runtime/acl_graph_executor_impl.cpp b/xllm/core/runtime/acl_graph_executor_impl.cpp index 0aa7a32518..e04f5019ea 100644 --- a/xllm/core/runtime/acl_graph_executor_impl.cpp +++ b/xllm/core/runtime/acl_graph_executor_impl.cpp @@ -425,13 +425,13 @@ std::optional GraphPersistentParam::update( std::max(options_.num_decoding_tokens(), 1); actual_batch_size = actual_num_tokens / decode_tokens; } - // actual_metadata_rows: number of metadata rows carrying real token data. + // actual_num_sequences: number of sequence rows carrying real token data. // Normal decode: actual_batch_size (= num_requests). // MTP validate: params.num_sequences (= num_requests * (1 + num_spec_tokens)), // which is larger than actual_batch_size. 
- const int64_t padded_batch_size = static_cast(padded_num_tokens); - const int64_t actual_metadata_rows = std::min( - padded_batch_size, + const int64_t padded_num_sequences = static_cast(padded_num_tokens); + const int64_t actual_num_sequences = std::min( + padded_num_sequences, std::max(actual_batch_size, params.num_sequences)); // Copy data from input parameters to persistent graph tensors @@ -469,10 +469,10 @@ std::optional GraphPersistentParam::update( } } int64_t q_copy_len = 0; - if (actual_metadata_rows > 0 && params.q_seq_lens.defined() && + if (actual_num_sequences > 0 && params.q_seq_lens.defined() && params.q_seq_lens.dim() >= 1 && params.q_seq_lens.numel() > 0) { q_copy_len = - std::min(actual_metadata_rows, params.q_seq_lens.size(0)); + std::min(actual_num_sequences, params.q_seq_lens.size(0)); if (q_copy_len > 0) { q_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/q_copy_len) .copy_(params.q_seq_lens.slice(/*dim=*/0, @@ -482,10 +482,10 @@ std::optional GraphPersistentParam::update( } } int64_t kv_copy_len = 0; - if (actual_metadata_rows > 0 && params.kv_seq_lens.defined() && + if (actual_num_sequences > 0 && params.kv_seq_lens.defined() && params.kv_seq_lens.dim() >= 1 && params.kv_seq_lens.numel() > 0) { kv_copy_len = - std::min(actual_metadata_rows, params.kv_seq_lens.size(0)); + std::min(actual_num_sequences, params.kv_seq_lens.size(0)); if (kv_copy_len > 0) { kv_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/kv_copy_len) .copy_(params.kv_seq_lens.slice(/*dim=*/0, @@ -496,18 +496,18 @@ std::optional GraphPersistentParam::update( } // Keep padded decode slots valid for empty/local-short DP shards. // These tensors are consumed by ATB setup alongside *_seq_lens_vec. 
- if (q_copy_len < padded_batch_size) { + if (q_copy_len < padded_num_sequences) { q_seq_lens_ .slice(/*dim=*/0, /*start=*/q_copy_len, - /*end=*/padded_batch_size) + /*end=*/padded_num_sequences) .fill_(1); } - if (kv_copy_len < padded_batch_size) { + if (kv_copy_len < padded_num_sequences) { kv_seq_lens_ .slice(/*dim=*/0, /*start=*/kv_copy_len, - /*end=*/padded_batch_size) + /*end=*/padded_num_sequences) .fill_(1); } @@ -558,10 +558,10 @@ std::optional GraphPersistentParam::update( // zeroed in the active bucket slice. int64_t block_rows_to_copy = 0; int64_t actual_block_table_len = 0; - if (actual_metadata_rows > 0 && params.block_tables.defined() && + if (actual_num_sequences > 0 && params.block_tables.defined() && params.block_tables.dim() >= 2 && params.block_tables.numel() > 0) { block_rows_to_copy = - std::min(actual_metadata_rows, params.block_tables.size(0)); + std::min(actual_num_sequences, params.block_tables.size(0)); actual_block_table_len = params.block_tables.size(1); if (block_rows_to_copy > 0 && actual_block_table_len > 0) { auto slice_persistent_block_tables = @@ -584,9 +584,9 @@ std::optional GraphPersistentParam::update( /*end=*/persistent_block_tables_.size(1)) .zero_(); } - if (actual_metadata_rows < padded_batch_size) { + if (actual_num_sequences < padded_num_sequences) { zero_tensor_tail( - persistent_block_tables_, actual_metadata_rows, padded_batch_size); + persistent_block_tables_, actual_num_sequences, padded_num_sequences); } // Update persistent embedding from input_embedding if available @@ -619,7 +619,7 @@ std::optional GraphPersistentParam::update( ? 
params.q_cu_seq_lens.size(0) : 0; const int64_t q_cu_copy_len = - std::min(actual_metadata_rows, q_cu_size); + std::min(actual_num_sequences, q_cu_size); if (q_cu_copy_len > 0) { q_cu_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/q_cu_copy_len) .copy_(params.q_cu_seq_lens.slice(/*dim=*/0, @@ -627,10 +627,10 @@ std::optional GraphPersistentParam::update( /*end=*/q_cu_copy_len), /*non_blocking=*/true); } - if (padded_batch_size > q_cu_copy_len) { + if (padded_num_sequences > q_cu_copy_len) { auto tail_q_seq_lens = q_seq_lens_.slice(/*dim=*/0, /*start=*/q_cu_copy_len, - /*end=*/padded_batch_size); + /*end=*/padded_num_sequences); auto tail_cu = torch::cumsum(tail_q_seq_lens, /*dim=*/0); if (q_cu_copy_len > 0) { auto last_prefix = q_cu_seq_lens_.slice(/*dim=*/0, @@ -639,7 +639,7 @@ std::optional GraphPersistentParam::update( tail_cu = tail_cu + last_prefix; } q_cu_seq_lens_ - .slice(/*dim=*/0, /*start=*/q_cu_copy_len, /*end=*/padded_batch_size) + .slice(/*dim=*/0, /*start=*/q_cu_copy_len, /*end=*/padded_num_sequences) .copy_(tail_cu, /*non_blocking=*/true); } @@ -670,24 +670,24 @@ std::optional GraphPersistentParam::update( std::make_optional(params); // actual_num_sequences tells the model's normalize function how many - // metadata rows carry real data. For MTP validate this can be larger + // sequence rows carry real data. For MTP validate this can be larger // than actual_batch_size because one request expands to // (1 + num_speculative_tokens) token rows. 
params_for_capture->actual_num_sequences = - static_cast(actual_metadata_rows); + static_cast(actual_num_sequences); params_for_capture->kv_seq_lens = kv_seq_lens(padded_num_tokens); params_for_capture->q_seq_lens = q_seq_lens(padded_num_tokens); - // Copy real seq_lens rows [0, actual_metadata_rows), fill padding - // rows [actual_metadata_rows, padded_batch_size) with 1 so that the + // Copy real seq_lens rows [0, actual_num_sequences), fill padding + // rows [actual_num_sequences, padded_num_sequences) with 1 so that the // graph-captured forward sees valid (non-zero) seq_lens everywhere. params_for_capture->kv_seq_lens_vec.resize(padded_num_tokens); params_for_capture->q_seq_lens_vec.resize(padded_num_tokens); - for (int64_t i = 0; i < actual_metadata_rows; ++i) { + for (int64_t i = 0; i < actual_num_sequences; ++i) { params_for_capture->kv_seq_lens_vec[i] = params.kv_seq_lens_vec[i]; params_for_capture->q_seq_lens_vec[i] = params.q_seq_lens_vec[i]; } - for (int64_t i = actual_metadata_rows; i < padded_batch_size; ++i) { + for (int64_t i = actual_num_sequences; i < padded_num_sequences; ++i) { params_for_capture->kv_seq_lens_vec[i] = 1; params_for_capture->q_seq_lens_vec[i] = 1; } @@ -732,7 +732,7 @@ std::optional GraphPersistentParam::update( params_for_capture->q_cu_seq_lens = q_cu_seq_lens_.slice(/*dim=*/0, /*start=*/0, - /*end=*/padded_batch_size); + /*end=*/padded_num_sequences); } // Replace dp/cp ep padding with slices of persistent buffers so that diff --git a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp index 21a0940632..66ba6e8e36 100644 --- a/xllm/core/runtime/params_utils.cpp +++ b/xllm/core/runtime/params_utils.cpp @@ -229,6 +229,17 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, static_cast(input_params.num_sequences)) << "block_tables_vec size (" << block_tables_vec.size() << ") must match num_sequences (" << input_params.num_sequences << ")"; + } else { + // support for mutli_block_tables + 
for (int32_t m = 0; m < pb_forward_input->multi_block_tables_vec().size(); + ++m) { + const auto& mgr = pb_forward_input->multi_block_tables_vec()[m]; + CHECK_EQ(static_cast(mgr.block_tables().size()), + static_cast(input_params.num_sequences)) + << "multi_block_tables[" << m << "] size (" + << mgr.block_tables().size() << ") must match num_sequences (" + << input_params.num_sequences << ")"; + } } input_params.kv_max_seq_len = pb_forward_input->max_seq_len(); input_params.q_max_seq_len = pb_forward_input->q_max_seq_len(); @@ -258,12 +269,9 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, std::move(create_2d_tensor(block_tables_vec, torch::kInt)); // multi block manager support for DeepSeek V4 - for (int m = 0; m < pb_forward_input->multi_block_tables_vec().size(); ++m) { + for (int32_t m = 0; m < pb_forward_input->multi_block_tables_vec().size(); + ++m) { const auto& mgr = pb_forward_input->multi_block_tables_vec()[m]; - CHECK_EQ(static_cast(mgr.block_tables().size()), - static_cast(input_params.num_sequences)) - << "multi_block_tables[" << m << "] size (" << mgr.block_tables().size() - << ") must match num_sequences (" << input_params.num_sequences << ")"; std::vector> mgr_tables; mgr_tables.reserve(mgr.block_tables().size()); for (int s = 0; s < mgr.block_tables().size(); ++s) { diff --git a/xllm/models/llm/deepseek_v4.h b/xllm/models/llm/deepseek_v4.h index 7dc363b5ab..58d9437758 100644 --- a/xllm/models/llm/deepseek_v4.h +++ b/xllm/models/llm/deepseek_v4.h @@ -840,6 +840,8 @@ class DeepseekV4ModelImpl return tensor_max_or_zero(fallback_tensor); } + public: + // TODO: Common Funcs for both dsv4/dsv4_mtp. Suggests move shared DeepSeek V4 graph metadata utilities out of DeepseekV4ModelImpl. 
static bool tensor_aliases_storage(const torch::Tensor& lhs, const torch::Tensor& rhs) { return lhs.defined() && rhs.defined() && lhs.data_ptr() == rhs.data_ptr() && @@ -977,7 +979,7 @@ class DeepseekV4ModelImpl // actual_metadata_rows = 1*(1+1) = 2, padded_metadata_rows = 4 // kv_seq_lens_vec = [23, 24, 0, 0] // Both rows are real; num_speculative_tokens is from speculative config. - void normalize_graph_metadata_input_params(ModelInputParams& params) const { + static void normalize_graph_metadata_input_params(ModelInputParams& params) { int64_t actual_metadata_rows = std::max(params.actual_num_sequences, 0); int64_t padded_metadata_rows = actual_metadata_rows; @@ -1025,9 +1027,10 @@ class DeepseekV4ModelImpl return attn_metadata; } - std::shared_ptr persist_graph_attention_metadata( + static std::shared_ptr + persist_graph_attention_metadata( DeepseekV4GraphMetadataState& state, - std::shared_ptr metadata) const { + std::shared_ptr metadata) { if (!metadata || !metadata->dsa_metadata) { return metadata; } @@ -1043,6 +1046,7 @@ class DeepseekV4ModelImpl return metadata; } + private: void fill_empty_dp_rank_input_params( ModelInputParams& params, const std::vector* kv_caches = nullptr) const { @@ -1077,7 +1081,7 @@ class DeepseekV4ModelImpl kv_caches->front()); } block_num = std::max(block_num, 1); - params.multi_block_tables.push_back( + params.multi_block_tables.emplace_back( torch::zeros({1, block_num}, cpu_int_options)); } } diff --git a/xllm/models/llm/deepseek_v4_mtp.h b/xllm/models/llm/deepseek_v4_mtp.h index 6a66c28d03..43115643cb 100644 --- a/xllm/models/llm/deepseek_v4_mtp.h +++ b/xllm/models/llm/deepseek_v4_mtp.h @@ -251,7 +251,8 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { auto modified_input_params = input_params; if (acl_graph_forward) { - normalize_graph_metadata_input_params(modified_input_params); + DeepseekV4ModelImpl::normalize_graph_metadata_input_params( + modified_input_params); } auto& dp_token_nums = 
modified_input_params.dp_global_token_nums; std::replace(dp_token_nums.begin(), dp_token_nums.end(), 0, 1); @@ -314,7 +315,8 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { if (modified_input_params.actual_num_sequences == 0) { fill_empty_dp_rank_input_params(modified_input_params); } - normalize_graph_metadata_input_params(modified_input_params); + DeepseekV4ModelImpl::normalize_graph_metadata_input_params( + modified_input_params); auto& dp_token_nums = modified_input_params.dp_global_token_nums; std::replace(dp_token_nums.begin(), dp_token_nums.end(), 0, 1); @@ -329,195 +331,21 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { if (dsa_hadamard_.defined()) { dsa.hadamard = dsa_hadamard_; } - copy_to_graph_packed_metadata_buffer( + DeepseekV4ModelImpl::copy_to_graph_packed_metadata_buffer( dsa, deepseek_v4_state->dsa_metadata_persistent, positions.device()); if (dsa.actual_seq_lengths_kv.defined() && dsa.seq_lens_q.defined()) { dsa.start_pos = (dsa.actual_seq_lengths_kv - dsa.seq_lens_q).to(torch::kInt32); } } - input_params.attn_metadata = persist_graph_attention_metadata( - *deepseek_v4_state, std::move(attn_metadata)); + input_params.attn_metadata = + DeepseekV4ModelImpl::persist_graph_attention_metadata( + *deepseek_v4_state, std::move(attn_metadata)); CHECK(input_params.attn_metadata) << "[DeepseekV4Mtp] ACL graph requires DSA metadata"; } private: - static bool tensor_aliases_storage(const torch::Tensor& lhs, - const torch::Tensor& rhs) { - return lhs.defined() && rhs.defined() && lhs.data_ptr() == rhs.data_ptr() && - lhs.sizes() == rhs.sizes() && lhs.strides() == rhs.strides(); - } - - static torch::Tensor copy_to_persistent_tensor(const torch::Tensor& src, - torch::Tensor& dst) { - if (!src.defined()) { - return src; - } - if (!dst.defined()) { - dst = torch::empty_like(src); - } else { - CHECK_EQ(dst.scalar_type(), src.scalar_type()) - << "[DeepseekV4Mtp] graph metadata tensor dtype changed"; - CHECK_EQ(dst.device(), 
src.device()) - << "[DeepseekV4Mtp] graph metadata tensor device changed"; - if (dst.sizes() != src.sizes()) { - bool can_copy_into_capacity = dst.dim() == src.dim() && src.dim() > 0 && - src.size(0) <= dst.size(0); - for (int64_t dim = 1; can_copy_into_capacity && dim < src.dim(); - ++dim) { - can_copy_into_capacity = dst.size(dim) == src.size(dim); - } - CHECK(can_copy_into_capacity) - << "[DeepseekV4Mtp] graph metadata tensor size changed from " - << dst.sizes() << " to " << src.sizes(); - dst.zero_(); - dst.slice(/*dim=*/0, /*start=*/0, /*end=*/src.size(0)) - .copy_(src, /*non_blocking=*/true); - return dst; - } - } - if (!tensor_aliases_storage(src, dst)) { - dst.copy_(src, /*non_blocking=*/true); - } - return dst; - } - - static void copy_to_graph_packed_metadata_buffer( - layer::DSAMetadata& dsa, - DeepseekV4GraphMetadataState::DSAMetadataPersistent& persistent, - const torch::Device& runtime_device) { -#if defined(USE_NPU) - if (runtime_device.is_cpu() || - runtime_device.type() != c10::DeviceType::PrivateUse1) { - deepseek_v4_move_dsa_metadata_to_device(dsa, runtime_device); - return; - } - - std::vector specs; - deepseek_v4_collect_cpu_metadata_tensors(dsa, runtime_device, specs); - const size_t total_bytes = deepseek_v4_layout_packed_tensor_specs(specs); - if (total_bytes == 0) { - return; - } - - if (!persistent.packed_metadata_host_buffer.defined() || - persistent.packed_metadata_host_buffer.scalar_type() != torch::kUInt8 || - persistent.packed_metadata_host_buffer.device() != torch::kCPU || - persistent.packed_metadata_host_buffer.numel() < - static_cast(total_bytes)) { - persistent.packed_metadata_host_buffer = - torch::empty({static_cast(total_bytes)}, - torch::TensorOptions() - .dtype(torch::kUInt8) - .device(torch::kCPU) - .pinned_memory(true)); - } - auto host_buffer = persistent.packed_metadata_host_buffer.slice( - /*dim=*/0, /*start=*/0, /*end=*/static_cast(total_bytes)); - deepseek_v4_fill_packed_host_buffer(specs, host_buffer); - auto 
device_options = - torch::TensorOptions().dtype(torch::kUInt8).device(runtime_device); - if (!persistent.packed_metadata_buffer.defined()) { - persistent.packed_metadata_buffer = - torch::empty({static_cast(total_bytes)}, device_options); - } else { - CHECK_EQ(persistent.packed_metadata_host_buffer.scalar_type(), - torch::kUInt8) - << "[DeepseekV4Mtp] graph host packed metadata dtype changed"; - CHECK_EQ(persistent.packed_metadata_host_buffer.device(), torch::kCPU) - << "[DeepseekV4Mtp] graph host packed metadata device changed"; - CHECK_GE(persistent.packed_metadata_host_buffer.numel(), - static_cast(total_bytes)) - << "[DeepseekV4Mtp] graph host packed metadata exceeds persistent " - "capacity: required=" - << total_bytes - << ", capacity=" << persistent.packed_metadata_host_buffer.numel(); - CHECK_EQ(persistent.packed_metadata_buffer.scalar_type(), torch::kUInt8) - << "[DeepseekV4Mtp] graph packed metadata dtype changed"; - CHECK_EQ(persistent.packed_metadata_buffer.device(), runtime_device) - << "[DeepseekV4Mtp] graph packed metadata device changed"; - CHECK_GE(persistent.packed_metadata_buffer.numel(), - static_cast(total_bytes)) - << "[DeepseekV4Mtp] graph packed metadata exceeds persistent capacity: " - << "required=" << total_bytes - << ", capacity=" << persistent.packed_metadata_buffer.numel(); - } - - persistent.packed_metadata_buffer - .slice(/*dim=*/0, - /*start=*/0, - /*end=*/static_cast(total_bytes)) - .copy_(host_buffer, /*non_blocking=*/true); - dsa.packed_metadata_buffer = persistent.packed_metadata_buffer.slice( - /*dim=*/0, /*start=*/0, /*end=*/static_cast(total_bytes)); - deepseek_v4_bind_packed_tensor_views(specs, dsa.packed_metadata_buffer); -#else - (void)persistent; - deepseek_v4_move_dsa_metadata_to_device(dsa, runtime_device); -#endif - } - - // Normalize metadata vectors for graph mode (bucket-padded) forward. - // - // Key concepts: - // actual_metadata_rows -- rows that carry real token data. - // For normal decode: = num_requests. 
- // For MTP validate: = num_requests * (1 + num_spec_tokens). - // padded_metadata_rows -- total rows after bucket padding. - // >= actual_metadata_rows, rounded up to decode bucket. - // Rows in [actual_metadata_rows, padded_metadata_rows) are bucket padding - // and must be zeroed so DSAMetadataBuilder treats them as inactive. - void normalize_graph_metadata_input_params(ModelInputParams& params) const { - int64_t actual_metadata_rows = std::max(params.actual_num_sequences, - 0); - int64_t padded_metadata_rows = actual_metadata_rows; - if (params.enable_graph) { - padded_metadata_rows = - std::max(padded_metadata_rows, params.num_sequences); - } - if (padded_metadata_rows <= 0) { - padded_metadata_rows = 1; - } - actual_metadata_rows = std::min(actual_metadata_rows, - padded_metadata_rows); - - auto trim_lens_vec = [padded_metadata_rows, - actual_metadata_rows](std::vector& lens) { - if (lens.empty()) { - lens.assign(static_cast(padded_metadata_rows), 0); - } else if (static_cast(lens.size()) < padded_metadata_rows) { - lens.resize(static_cast(padded_metadata_rows), 0); - } else { - lens.resize(static_cast(padded_metadata_rows)); - } - std::fill(lens.begin() + actual_metadata_rows, lens.end(), 0); - }; - - trim_lens_vec(params.kv_seq_lens_vec); - trim_lens_vec(params.q_seq_lens_vec); - params.num_sequences = static_cast(padded_metadata_rows); - params.actual_num_sequences = static_cast(actual_metadata_rows); - } - - std::shared_ptr persist_graph_attention_metadata( - DeepseekV4GraphMetadataState& state, - std::shared_ptr metadata) const { - if (!metadata || !metadata->dsa_metadata) { - return metadata; - } - - auto& dsa = *metadata->dsa_metadata; - auto& persistent = state.dsa_metadata_persistent; - - dsa.attn_mask = - copy_to_persistent_tensor(dsa.attn_mask, persistent.attn_mask); - dsa.start_pos = - copy_to_persistent_tensor(dsa.start_pos, persistent.start_pos); - - return metadata; - } - void fill_empty_dp_rank_input_params(ModelInputParams& params) const { 
auto cpu_int_options = torch::TensorOptions() .dtype(torch::kInt32) @@ -544,7 +372,7 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { const int32_t manager_num = static_cast(group_infos_.size()); params.multi_block_tables.reserve(manager_num); for (int32_t manager_id = 0; manager_id < manager_num; ++manager_id) { - params.multi_block_tables.push_back( + params.multi_block_tables.emplace_back( torch::zeros({1, 1}, cpu_int_options)); } } From 7393b9589380f81cb1e7d06a88c19b0ac9fdbd17 Mon Sep 17 00:00:00 2001 From: panxuanyu Date: Fri, 22 May 2026 16:50:39 +0800 Subject: [PATCH 09/10] bugfix: fix MTP DP dummy data logic --- xllm/core/runtime/acl_graph_executor_impl.cpp | 77 +++++++++---------- xllm/models/llm/deepseek_v4_mtp.h | 23 +++++- 2 files changed, 55 insertions(+), 45 deletions(-) diff --git a/xllm/core/runtime/acl_graph_executor_impl.cpp b/xllm/core/runtime/acl_graph_executor_impl.cpp index e04f5019ea..66f900d639 100644 --- a/xllm/core/runtime/acl_graph_executor_impl.cpp +++ b/xllm/core/runtime/acl_graph_executor_impl.cpp @@ -418,22 +418,13 @@ std::optional GraphPersistentParam::update( bool return_capture_params) { CHECK_GT(padded_num_tokens, 0) << "padded_num_tokens must be > 0 when return_capture_params is true"; - const uint32_t actual_num_tokens = tokens.size(0); + const int64_t actual_num_tokens = tokens.size(0); int64_t actual_batch_size = infer_actual_batch_size(params); if (params.batch_forward_type.is_decode()) { const int64_t decode_tokens = std::max(options_.num_decoding_tokens(), 1); actual_batch_size = actual_num_tokens / decode_tokens; } - // actual_num_sequences: number of sequence rows carrying real token data. - // Normal decode: actual_batch_size (= num_requests). - // MTP validate: params.num_sequences (= num_requests * (1 + num_spec_tokens)), - // which is larger than actual_batch_size. 
- const int64_t padded_num_sequences = static_cast(padded_num_tokens); - const int64_t actual_num_sequences = std::min( - padded_num_sequences, - std::max(actual_batch_size, params.num_sequences)); - // Copy data from input parameters to persistent graph tensors if (actual_num_tokens > 0) { persistent_tokens_.slice(/*dim=*/0, /*start=*/0, /*end=*/actual_num_tokens) @@ -469,10 +460,10 @@ std::optional GraphPersistentParam::update( } } int64_t q_copy_len = 0; - if (actual_num_sequences > 0 && params.q_seq_lens.defined() && + if (actual_num_tokens > 0 && params.q_seq_lens.defined() && params.q_seq_lens.dim() >= 1 && params.q_seq_lens.numel() > 0) { q_copy_len = - std::min(actual_num_sequences, params.q_seq_lens.size(0)); + std::min(actual_num_tokens, params.q_seq_lens.size(0)); if (q_copy_len > 0) { q_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/q_copy_len) .copy_(params.q_seq_lens.slice(/*dim=*/0, @@ -482,10 +473,10 @@ std::optional GraphPersistentParam::update( } } int64_t kv_copy_len = 0; - if (actual_num_sequences > 0 && params.kv_seq_lens.defined() && + if (actual_num_tokens > 0 && params.kv_seq_lens.defined() && params.kv_seq_lens.dim() >= 1 && params.kv_seq_lens.numel() > 0) { kv_copy_len = - std::min(actual_num_sequences, params.kv_seq_lens.size(0)); + std::min(actual_num_tokens, params.kv_seq_lens.size(0)); if (kv_copy_len > 0) { kv_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/kv_copy_len) .copy_(params.kv_seq_lens.slice(/*dim=*/0, @@ -496,18 +487,18 @@ std::optional GraphPersistentParam::update( } // Keep padded decode slots valid for empty/local-short DP shards. // These tensors are consumed by ATB setup alongside *_seq_lens_vec. 
- if (q_copy_len < padded_num_sequences) { + if (q_copy_len < static_cast(padded_num_tokens)) { q_seq_lens_ .slice(/*dim=*/0, /*start=*/q_copy_len, - /*end=*/padded_num_sequences) + /*end=*/static_cast(padded_num_tokens)) .fill_(1); } - if (kv_copy_len < padded_num_sequences) { + if (kv_copy_len < static_cast(padded_num_tokens)) { kv_seq_lens_ .slice(/*dim=*/0, /*start=*/kv_copy_len, - /*end=*/padded_num_sequences) + /*end=*/static_cast(padded_num_tokens)) .fill_(1); } @@ -558,10 +549,10 @@ std::optional GraphPersistentParam::update( // zeroed in the active bucket slice. int64_t block_rows_to_copy = 0; int64_t actual_block_table_len = 0; - if (actual_num_sequences > 0 && params.block_tables.defined() && + if (actual_num_tokens > 0 && params.block_tables.defined() && params.block_tables.dim() >= 2 && params.block_tables.numel() > 0) { block_rows_to_copy = - std::min(actual_num_sequences, params.block_tables.size(0)); + std::min(actual_num_tokens, params.block_tables.size(0)); actual_block_table_len = params.block_tables.size(1); if (block_rows_to_copy > 0 && actual_block_table_len > 0) { auto slice_persistent_block_tables = @@ -584,9 +575,10 @@ std::optional GraphPersistentParam::update( /*end=*/persistent_block_tables_.size(1)) .zero_(); } - if (actual_num_sequences < padded_num_sequences) { - zero_tensor_tail( - persistent_block_tables_, actual_num_sequences, padded_num_sequences); + if (actual_num_tokens < static_cast(padded_num_tokens)) { + zero_tensor_tail(persistent_block_tables_, + actual_num_tokens, + static_cast(padded_num_tokens)); } // Update persistent embedding from input_embedding if available @@ -618,8 +610,7 @@ std::optional GraphPersistentParam::update( const int64_t q_cu_size = (has_q_cu && params.q_cu_seq_lens.numel() > 0) ? 
params.q_cu_seq_lens.size(0) : 0; - const int64_t q_cu_copy_len = - std::min(actual_num_sequences, q_cu_size); + const int64_t q_cu_copy_len = std::min(actual_num_tokens, q_cu_size); if (q_cu_copy_len > 0) { q_cu_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/q_cu_copy_len) .copy_(params.q_cu_seq_lens.slice(/*dim=*/0, @@ -627,10 +618,11 @@ std::optional GraphPersistentParam::update( /*end=*/q_cu_copy_len), /*non_blocking=*/true); } - if (padded_num_sequences > q_cu_copy_len) { - auto tail_q_seq_lens = q_seq_lens_.slice(/*dim=*/0, - /*start=*/q_cu_copy_len, - /*end=*/padded_num_sequences); + if (static_cast(padded_num_tokens) > q_cu_copy_len) { + auto tail_q_seq_lens = q_seq_lens_.slice( + /*dim=*/0, + /*start=*/q_cu_copy_len, + /*end=*/static_cast(padded_num_tokens)); auto tail_cu = torch::cumsum(tail_q_seq_lens, /*dim=*/0); if (q_cu_copy_len > 0) { auto last_prefix = q_cu_seq_lens_.slice(/*dim=*/0, @@ -639,7 +631,9 @@ std::optional GraphPersistentParam::update( tail_cu = tail_cu + last_prefix; } q_cu_seq_lens_ - .slice(/*dim=*/0, /*start=*/q_cu_copy_len, /*end=*/padded_num_sequences) + .slice(/*dim=*/0, + /*start=*/q_cu_copy_len, + /*end=*/static_cast(padded_num_tokens)) .copy_(tail_cu, /*non_blocking=*/true); } @@ -669,25 +663,23 @@ std::optional GraphPersistentParam::update( std::optional params_for_capture = std::make_optional(params); - // actual_num_sequences tells the model's normalize function how many - // sequence rows carry real data. For MTP validate this can be larger - // than actual_batch_size because one request expands to - // (1 + num_speculative_tokens) token rows. 
params_for_capture->actual_num_sequences = - static_cast(actual_num_sequences); + static_cast(actual_num_tokens); params_for_capture->kv_seq_lens = kv_seq_lens(padded_num_tokens); params_for_capture->q_seq_lens = q_seq_lens(padded_num_tokens); - // Copy real seq_lens rows [0, actual_num_sequences), fill padding - // rows [actual_num_sequences, padded_num_sequences) with 1 so that the + // Copy real seq_lens rows [0, actual_num_tokens), fill padding + // rows [actual_num_tokens, padded_num_tokens) with 1 so that the // graph-captured forward sees valid (non-zero) seq_lens everywhere. params_for_capture->kv_seq_lens_vec.resize(padded_num_tokens); params_for_capture->q_seq_lens_vec.resize(padded_num_tokens); - for (int64_t i = 0; i < actual_num_sequences; ++i) { + for (int64_t i = 0; i < actual_num_tokens; ++i) { params_for_capture->kv_seq_lens_vec[i] = params.kv_seq_lens_vec[i]; params_for_capture->q_seq_lens_vec[i] = params.q_seq_lens_vec[i]; } - for (int64_t i = actual_num_sequences; i < padded_num_sequences; ++i) { + for (int64_t i = actual_num_tokens; + i < static_cast(padded_num_tokens); + ++i) { params_for_capture->kv_seq_lens_vec[i] = 1; params_for_capture->q_seq_lens_vec[i] = 1; } @@ -730,9 +722,10 @@ std::optional GraphPersistentParam::update( // Set q_cu_seq_lens if available if (params.q_cu_seq_lens.defined()) { params_for_capture->q_cu_seq_lens = - q_cu_seq_lens_.slice(/*dim=*/0, - /*start=*/0, - /*end=*/padded_num_sequences); + q_cu_seq_lens_.slice( + /*dim=*/0, + /*start=*/0, + /*end=*/static_cast(padded_num_tokens)); } // Replace dp/cp ep padding with slices of persistent buffers so that diff --git a/xllm/models/llm/deepseek_v4_mtp.h b/xllm/models/llm/deepseek_v4_mtp.h index 43115643cb..c08e3ea14e 100644 --- a/xllm/models/llm/deepseek_v4_mtp.h +++ b/xllm/models/llm/deepseek_v4_mtp.h @@ -213,7 +213,9 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { const ModelInputParams& input_params) { torch::NoGradGuard no_grad; - if 
(!tokens.defined() || tokens.numel() == 0) { + const bool is_empty_dp_rank = !tokens.defined() || tokens.numel() == 0; + + if (is_empty_dp_rank) { tokens = torch::tensor( {0}, torch::TensorOptions().dtype(torch::kInt32).device(device_)); positions = torch::tensor( @@ -222,7 +224,12 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { const torch::Device runtime_device = tokens.device(); - torch::Tensor previous_hidden_states = input_params.input_embedding; + auto modified_input_params = input_params; + if (is_empty_dp_rank) { + fill_empty_dp_rank_input_params(modified_input_params); + } + + torch::Tensor previous_hidden_states = modified_input_params.input_embedding; CHECK(previous_hidden_states.defined()) << "input_params.input_embedding must be defined for MTP model"; @@ -249,7 +256,6 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { torch::zeros_like(hidden_states.index({mask}))); } - auto modified_input_params = input_params; if (acl_graph_forward) { DeepseekV4ModelImpl::normalize_graph_metadata_input_params( modified_input_params); @@ -315,6 +321,7 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { if (modified_input_params.actual_num_sequences == 0) { fill_empty_dp_rank_input_params(modified_input_params); } + DeepseekV4ModelImpl::normalize_graph_metadata_input_params( modified_input_params); auto& dp_token_nums = modified_input_params.dp_global_token_nums; @@ -365,6 +372,16 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { params.new_cache_slots = torch::tensor({0}, cpu_int_options); params.block_tables = torch::zeros({1, 1}, cpu_int_options); + if (!params.input_embedding.defined()) { + auto options = torch::TensorOptions() + .dtype(torch::kBFloat16) + .device(device_); + params.input_embedding = torch::zeros( + {static_cast(params.num_sequences), + model_args_.hidden_size()}, + options); + } + if (!params.multi_block_tables.empty()) { return; } From c9ceb10f8fc579e2f459f59804a5ab3188fb0cfb Mon 
Sep 17 00:00:00 2001 From: panxuanyu Date: Sat, 23 May 2026 14:53:04 +0800 Subject: [PATCH 10/10] refactor: add util::is_target_model_type --- xllm/core/distributed_runtime/llm_engine.cpp | 19 ++++++++++++------- xllm/core/distributed_runtime/master.cpp | 4 +++- xllm/core/framework/kv_cache/kv_cache.cpp | 5 +++-- .../framework/kv_cache/kv_cache_shape.cpp | 5 +++-- xllm/core/layers/npu_torch/fused_moe.cpp | 5 +++-- xllm/core/util/utils.h | 14 ++++++++++++++ xllm/models/llm/mtp_model_base.h | 4 +++- 7 files changed, 41 insertions(+), 15 deletions(-) diff --git a/xllm/core/distributed_runtime/llm_engine.cpp b/xllm/core/distributed_runtime/llm_engine.cpp index f9118f1c98..709a30685c 100644 --- a/xllm/core/distributed_runtime/llm_engine.cpp +++ b/xllm/core/distributed_runtime/llm_engine.cpp @@ -575,8 +575,9 @@ KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() { // all swa-related cache size from cache_size_in_bytes, then compute // c4_count / c128_count (c4_count = 32 * c128_count). // cache_size_in_bytes is already the full available device memory. 
- if (args_.model_type() == "deepseek_v4" || - args_.model_type() == "deepseek_v4_mtp") { + if (util::is_taget_model_type(args_.model_type(), + /*target=*/"deepseek_v4", + /*match_mtp=*/true)) { const int64_t max_seqs = static_cast(std::max(options_.max_seqs_per_batch(), 1)); const int32_t block_size = options_.block_size(); @@ -793,8 +794,9 @@ bool LLMEngine::allocate_kv_cache(const KVCacheCapacity& kv_cache_cap) { // init kv cache for each worker const KVCacheShape kv_cache_shape(kv_cache_cap, args_, dp_local_tp_size_); - if (args_.model_type() == "deepseek_v4" || - args_.model_type() == "deepseek_v4_mtp") { + if (util::is_taget_model_type(args_.model_type(), + /*target=*/"deepseek_v4", + /*match_mtp=*/true)) { LOG(INFO) << "Initializing DSV4 kv cache with shape: [swa_count=" << kv_cache_cap.swa_count() << ", c4_count=" << kv_cache_cap.c4_count() @@ -819,7 +821,9 @@ bool LLMEngine::allocate_kv_cache(const KVCacheCapacity& kv_cache_cap) { .slot_size(kv_cache_cap.slot_size()) .model_id(options_.model_id()) .max_seqs_per_batch(options_.max_seqs_per_batch()); - if (util::is_deepseek_v4_model_type(args_.model_type())) { + if (util::is_taget_model_type(args_.model_type(), + /*target=*/"deepseek_v4", + /*match_mtp=*/true)) { constexpr uint32_t kManagerTypeBlockManagerImpl = 0; constexpr uint32_t kManagerTypeSlidingWindowBlockManager = 1; @@ -1437,8 +1441,9 @@ std::vector LLMEngine::prepare_inputs( args_, threadpool_.get(), cp_size_))); dp_global_token_nums[dp_rank] = batched_inputs[dp_rank].flatten_tokens_vec.size(); - if (args_.model_type() == "deepseek_v4" || - args_.model_type() == "deepseek_v4_mtp") { + if (util::is_taget_model_type(args_.model_type(), + /*target=*/"deepseek_v4", + /*match_mtp=*/true)) { const int64_t actual_scheduled_tokens = static_cast( batched_inputs[dp_rank].flatten_tokens_vec.size()); const int64_t max_tokens_per_batch = diff --git a/xllm/core/distributed_runtime/master.cpp b/xllm/core/distributed_runtime/master.cpp index 
54f7c78682..3a1a3ee1d0 100644 --- a/xllm/core/distributed_runtime/master.cpp +++ b/xllm/core/distributed_runtime/master.cpp @@ -146,7 +146,9 @@ Master::Master(const Options& options, EngineType type) std::filesystem::path(options_.model_path()).lexically_normal(); if (options_.enable_prefix_cache() && options_.backend() == "llm") { const std::string model_type = util::get_model_type(model_path); - if (model_type == "deepseek_v4" || model_type == "deepseek_v4_mtp") { + if (util::is_taget_model_type(model_type, + /*target=*/"deepseek_v4", + /*match_mtp=*/true)) { LOG(WARNING) << model_type << " does not support prefix cache with " "CompositeBlockManager yet, fallback to " "enable_prefix_cache=false"; diff --git a/xllm/core/framework/kv_cache/kv_cache.cpp b/xllm/core/framework/kv_cache/kv_cache.cpp index c1173ee41e..5565fddbc1 100644 --- a/xllm/core/framework/kv_cache/kv_cache.cpp +++ b/xllm/core/framework/kv_cache/kv_cache.cpp @@ -414,8 +414,9 @@ void allocate_kv_caches(std::vector& kv_caches, const int64_t num_layers = create_options.num_layers(); kv_caches.reserve(num_layers); - if (create_options.model_type() == "deepseek_v4" || - create_options.model_type() == "deepseek_v4_mtp") { + if (util::is_taget_model_type(create_options.model_type(), + /*target=*/"deepseek_v4", + /*match_mtp=*/true)) { std::vector layer_compress_ratios; layer_compress_ratios.reserve(static_cast(num_layers)); std::map ratio_shape_summaries; diff --git a/xllm/core/framework/kv_cache/kv_cache_shape.cpp b/xllm/core/framework/kv_cache/kv_cache_shape.cpp index 5181580cc6..9a97882611 100644 --- a/xllm/core/framework/kv_cache/kv_cache_shape.cpp +++ b/xllm/core/framework/kv_cache/kv_cache_shape.cpp @@ -66,8 +66,9 @@ KVCacheShape::KVCacheShape(const KVCacheCapacity& kv_cache_cap, CHECK_GT(world_size, 0) << "world_size must be positive."; CHECK_GT(kv_cache_cap.block_size(), 0) << "block_size must be positive."; - if (model_args.model_type() == "deepseek_v4" || - model_args.model_type() == 
"deepseek_v4_mtp") { + if (util::is_taget_model_type(model_args.model_type(), + /*target=*/"deepseek_v4", + /*match_mtp=*/true)) { key_cache_shape_ = std::vector{kv_cache_cap.swa_count(), kv_cache_cap.c4_count(), kv_cache_cap.c128_count()}; diff --git a/xllm/core/layers/npu_torch/fused_moe.cpp b/xllm/core/layers/npu_torch/fused_moe.cpp index 78b2af2b24..850aac504b 100644 --- a/xllm/core/layers/npu_torch/fused_moe.cpp +++ b/xllm/core/layers/npu_torch/fused_moe.cpp @@ -338,8 +338,9 @@ FusedMoEImpl::FusedMoEImpl(const ModelArgs& model_args, n_shared_experts_(model_args.n_shared_experts()), is_gated_(moe_args.is_gated), skip_gate_load_(moe_args.skip_gate_load), - is_deepseek_v4_(model_args.model_type() == "deepseek_v4" || - model_args.model_type() == "deepseek_v4_mtp"), + is_deepseek_v4_(util::is_taget_model_type(model_args.model_type(), + /*target=*/"deepseek_v4", + /*match_mtp=*/true)), renormalize_(model_args.norm_topk_prob() ? 1 : 0), hidden_act_(model_args.hidden_act()), scoring_func_(model_args.scoring_func()), diff --git a/xllm/core/util/utils.h b/xllm/core/util/utils.h index 4e0ae72692..98a62dc031 100644 --- a/xllm/core/util/utils.h +++ b/xllm/core/util/utils.h @@ -131,6 +131,20 @@ inline bool is_mla_model_type(std::string_view model_type) { return mla_model_type_set().contains(std::string(model_type)); } +inline bool is_taget_model_type(std::string_view model_type, + std::string_view target, + bool match_mtp) { + if (model_type == target) { + return true; + } + if (!match_mtp) { + return false; + } + std::string target_mtp(target); + target_mtp += "_mtp"; + return model_type == target_mtp; +} + inline std::string get_model_name( const std::filesystem::path& normalized_model_path) { std::string model_name; diff --git a/xllm/models/llm/mtp_model_base.h b/xllm/models/llm/mtp_model_base.h index b97fa37d27..b47f9c4865 100644 --- a/xllm/models/llm/mtp_model_base.h +++ b/xllm/models/llm/mtp_model_base.h @@ -31,7 +31,9 @@ namespace xllm { enum class MtpProjectionType 
{ CONCAT_EH_PROJ, ADD_EH_PROJ }; inline bool is_deepseek_v4_mtp_model(const ModelArgs& model_args) { - return model_args.model_type() == "deepseek_v4_mtp"; + return util::is_taget_model_type(model_args.model_type(), + /*target=*/"deepseek_v4_mtp", + /*match_mtp=*/false); } inline MtpProjectionType get_mtp_projection_type(const ModelArgs& model_args) {