diff --git a/xllm/core/distributed_runtime/llm_engine.cpp b/xllm/core/distributed_runtime/llm_engine.cpp index 56ffc4db4..709a30685 100644 --- a/xllm/core/distributed_runtime/llm_engine.cpp +++ b/xllm/core/distributed_runtime/llm_engine.cpp @@ -575,7 +575,9 @@ KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() { // all swa-related cache size from cache_size_in_bytes, then compute // c4_count / c128_count (c4_count = 32 * c128_count). // cache_size_in_bytes is already the full available device memory. - if (util::is_deepseek_v4_model_type(args_.model_type())) { + if (util::is_taget_model_type(args_.model_type(), + /*target=*/"deepseek_v4", + /*match_mtp=*/true)) { const int64_t max_seqs = static_cast(std::max(options_.max_seqs_per_batch(), 1)); const int32_t block_size = options_.block_size(); @@ -792,7 +794,9 @@ bool LLMEngine::allocate_kv_cache(const KVCacheCapacity& kv_cache_cap) { // init kv cache for each worker const KVCacheShape kv_cache_shape(kv_cache_cap, args_, dp_local_tp_size_); - if (util::is_deepseek_v4_model_type(args_.model_type())) { + if (util::is_taget_model_type(args_.model_type(), + /*target=*/"deepseek_v4", + /*match_mtp=*/true)) { LOG(INFO) << "Initializing DSV4 kv cache with shape: [swa_count=" << kv_cache_cap.swa_count() << ", c4_count=" << kv_cache_cap.c4_count() @@ -817,7 +821,9 @@ bool LLMEngine::allocate_kv_cache(const KVCacheCapacity& kv_cache_cap) { .slot_size(kv_cache_cap.slot_size()) .model_id(options_.model_id()) .max_seqs_per_batch(options_.max_seqs_per_batch()); - if (util::is_deepseek_v4_model_type(args_.model_type())) { + if (util::is_taget_model_type(args_.model_type(), + /*target=*/"deepseek_v4", + /*match_mtp=*/true)) { constexpr uint32_t kManagerTypeBlockManagerImpl = 0; constexpr uint32_t kManagerTypeSlidingWindowBlockManager = 1; @@ -1435,7 +1441,9 @@ std::vector LLMEngine::prepare_inputs( args_, threadpool_.get(), cp_size_))); dp_global_token_nums[dp_rank] = batched_inputs[dp_rank].flatten_tokens_vec.size(); - if 
(util::is_deepseek_v4_model_type(args_.model_type())) { + if (util::is_taget_model_type(args_.model_type(), + /*target=*/"deepseek_v4", + /*match_mtp=*/true)) { const int64_t actual_scheduled_tokens = static_cast( batched_inputs[dp_rank].flatten_tokens_vec.size()); const int64_t max_tokens_per_batch = diff --git a/xllm/core/distributed_runtime/master.cpp b/xllm/core/distributed_runtime/master.cpp index dbb0d6775..3a1a3ee1d 100644 --- a/xllm/core/distributed_runtime/master.cpp +++ b/xllm/core/distributed_runtime/master.cpp @@ -146,7 +146,9 @@ Master::Master(const Options& options, EngineType type) std::filesystem::path(options_.model_path()).lexically_normal(); if (options_.enable_prefix_cache() && options_.backend() == "llm") { const std::string model_type = util::get_model_type(model_path); - if (util::is_deepseek_v4_model_type(model_type)) { + if (util::is_taget_model_type(model_type, + /*target=*/"deepseek_v4", + /*match_mtp=*/true)) { LOG(WARNING) << model_type << " does not support prefix cache with " "CompositeBlockManager yet, fallback to " "enable_prefix_cache=false"; diff --git a/xllm/core/framework/kv_cache/kv_cache.cpp b/xllm/core/framework/kv_cache/kv_cache.cpp index 0dd585569..5565fddbc 100644 --- a/xllm/core/framework/kv_cache/kv_cache.cpp +++ b/xllm/core/framework/kv_cache/kv_cache.cpp @@ -414,7 +414,9 @@ void allocate_kv_caches(std::vector& kv_caches, const int64_t num_layers = create_options.num_layers(); kv_caches.reserve(num_layers); - if (util::is_deepseek_v4_model_type(create_options.model_type())) { + if (util::is_taget_model_type(create_options.model_type(), + /*target=*/"deepseek_v4", + /*match_mtp=*/true)) { std::vector layer_compress_ratios; layer_compress_ratios.reserve(static_cast(num_layers)); std::map ratio_shape_summaries; diff --git a/xllm/core/framework/kv_cache/kv_cache_shape.cpp b/xllm/core/framework/kv_cache/kv_cache_shape.cpp index 8f7501f03..9a9788261 100644 --- a/xllm/core/framework/kv_cache/kv_cache_shape.cpp +++ 
b/xllm/core/framework/kv_cache/kv_cache_shape.cpp @@ -66,7 +66,9 @@ KVCacheShape::KVCacheShape(const KVCacheCapacity& kv_cache_cap, CHECK_GT(world_size, 0) << "world_size must be positive."; CHECK_GT(kv_cache_cap.block_size(), 0) << "block_size must be positive."; - if (util::is_deepseek_v4_model_type(model_args.model_type())) { + if (util::is_taget_model_type(model_args.model_type(), + /*target=*/"deepseek_v4", + /*match_mtp=*/true)) { key_cache_shape_ = std::vector{kv_cache_cap.swa_count(), kv_cache_cap.c4_count(), kv_cache_cap.c128_count()}; diff --git a/xllm/core/layers/npu_torch/fused_moe.cpp b/xllm/core/layers/npu_torch/fused_moe.cpp index 1658f67a3..850aac504 100644 --- a/xllm/core/layers/npu_torch/fused_moe.cpp +++ b/xllm/core/layers/npu_torch/fused_moe.cpp @@ -338,7 +338,9 @@ FusedMoEImpl::FusedMoEImpl(const ModelArgs& model_args, n_shared_experts_(model_args.n_shared_experts()), is_gated_(moe_args.is_gated), skip_gate_load_(moe_args.skip_gate_load), - is_deepseek_v4_(util::is_deepseek_v4_model_type(model_args.model_type())), + is_deepseek_v4_(util::is_taget_model_type(model_args.model_type(), + /*target=*/"deepseek_v4", + /*match_mtp=*/true)), renormalize_(model_args.norm_topk_prob() ? 
1 : 0), hidden_act_(model_args.hidden_act()), scoring_func_(model_args.scoring_func()), diff --git a/xllm/core/runtime/acl_graph_executor_impl.cpp b/xllm/core/runtime/acl_graph_executor_impl.cpp index 1e049cb47..66f900d63 100644 --- a/xllm/core/runtime/acl_graph_executor_impl.cpp +++ b/xllm/core/runtime/acl_graph_executor_impl.cpp @@ -418,14 +418,13 @@ std::optional GraphPersistentParam::update( bool return_capture_params) { CHECK_GT(padded_num_tokens, 0) << "padded_num_tokens must be > 0 when return_capture_params is true"; - const uint32_t actual_num_tokens = tokens.size(0); + const int64_t actual_num_tokens = tokens.size(0); int64_t actual_batch_size = infer_actual_batch_size(params); if (params.batch_forward_type.is_decode()) { const int64_t decode_tokens = std::max(options_.num_decoding_tokens(), 1); actual_batch_size = actual_num_tokens / decode_tokens; } - // Copy data from input parameters to persistent graph tensors if (actual_num_tokens > 0) { persistent_tokens_.slice(/*dim=*/0, /*start=*/0, /*end=*/actual_num_tokens) @@ -461,10 +460,10 @@ std::optional GraphPersistentParam::update( } } int64_t q_copy_len = 0; - if (actual_batch_size > 0 && params.q_seq_lens.defined() && + if (actual_num_tokens > 0 && params.q_seq_lens.defined() && params.q_seq_lens.dim() >= 1 && params.q_seq_lens.numel() > 0) { q_copy_len = - std::min(actual_batch_size, params.q_seq_lens.size(0)); + std::min(actual_num_tokens, params.q_seq_lens.size(0)); if (q_copy_len > 0) { q_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/q_copy_len) .copy_(params.q_seq_lens.slice(/*dim=*/0, @@ -474,10 +473,10 @@ std::optional GraphPersistentParam::update( } } int64_t kv_copy_len = 0; - if (actual_batch_size > 0 && params.kv_seq_lens.defined() && + if (actual_num_tokens > 0 && params.kv_seq_lens.defined() && params.kv_seq_lens.dim() >= 1 && params.kv_seq_lens.numel() > 0) { kv_copy_len = - std::min(actual_batch_size, params.kv_seq_lens.size(0)); + std::min(actual_num_tokens, 
params.kv_seq_lens.size(0)); if (kv_copy_len > 0) { kv_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/kv_copy_len) .copy_(params.kv_seq_lens.slice(/*dim=*/0, @@ -488,19 +487,18 @@ std::optional GraphPersistentParam::update( } // Keep padded decode slots valid for empty/local-short DP shards. // These tensors are consumed by ATB setup alongside *_seq_lens_vec. - const int64_t padded_batch_size = static_cast(padded_num_tokens); - if (q_copy_len < padded_batch_size) { + if (q_copy_len < static_cast(padded_num_tokens)) { q_seq_lens_ .slice(/*dim=*/0, /*start=*/q_copy_len, - /*end=*/padded_batch_size) + /*end=*/static_cast(padded_num_tokens)) .fill_(1); } - if (kv_copy_len < padded_batch_size) { + if (kv_copy_len < static_cast(padded_num_tokens)) { kv_seq_lens_ .slice(/*dim=*/0, /*start=*/kv_copy_len, - /*end=*/padded_batch_size) + /*end=*/static_cast(padded_num_tokens)) .fill_(1); } @@ -551,10 +549,10 @@ std::optional GraphPersistentParam::update( // zeroed in the active bucket slice. int64_t block_rows_to_copy = 0; int64_t actual_block_table_len = 0; - if (actual_batch_size > 0 && params.block_tables.defined() && + if (actual_num_tokens > 0 && params.block_tables.defined() && params.block_tables.dim() >= 2 && params.block_tables.numel() > 0) { block_rows_to_copy = - std::min(actual_batch_size, params.block_tables.size(0)); + std::min(actual_num_tokens, params.block_tables.size(0)); actual_block_table_len = params.block_tables.size(1); if (block_rows_to_copy > 0 && actual_block_table_len > 0) { auto slice_persistent_block_tables = @@ -577,9 +575,10 @@ std::optional GraphPersistentParam::update( /*end=*/persistent_block_tables_.size(1)) .zero_(); } - if (actual_batch_size < padded_batch_size) { - zero_tensor_tail( - persistent_block_tables_, actual_batch_size, padded_batch_size); + if (actual_num_tokens < static_cast(padded_num_tokens)) { + zero_tensor_tail(persistent_block_tables_, + actual_num_tokens, + static_cast(padded_num_tokens)); } // Update persistent 
embedding from input_embedding if available @@ -611,7 +610,7 @@ std::optional GraphPersistentParam::update( const int64_t q_cu_size = (has_q_cu && params.q_cu_seq_lens.numel() > 0) ? params.q_cu_seq_lens.size(0) : 0; - const int64_t q_cu_copy_len = std::min(actual_batch_size, q_cu_size); + const int64_t q_cu_copy_len = std::min(actual_num_tokens, q_cu_size); if (q_cu_copy_len > 0) { q_cu_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/q_cu_copy_len) .copy_(params.q_cu_seq_lens.slice(/*dim=*/0, @@ -619,10 +618,11 @@ std::optional GraphPersistentParam::update( /*end=*/q_cu_copy_len), /*non_blocking=*/true); } - if (padded_batch_size > q_cu_copy_len) { - auto tail_q_seq_lens = q_seq_lens_.slice(/*dim=*/0, - /*start=*/q_cu_copy_len, - /*end=*/padded_batch_size); + if (static_cast(padded_num_tokens) > q_cu_copy_len) { + auto tail_q_seq_lens = q_seq_lens_.slice( + /*dim=*/0, + /*start=*/q_cu_copy_len, + /*end=*/static_cast(padded_num_tokens)); auto tail_cu = torch::cumsum(tail_q_seq_lens, /*dim=*/0); if (q_cu_copy_len > 0) { auto last_prefix = q_cu_seq_lens_.slice(/*dim=*/0, @@ -631,7 +631,9 @@ std::optional GraphPersistentParam::update( tail_cu = tail_cu + last_prefix; } q_cu_seq_lens_ - .slice(/*dim=*/0, /*start=*/q_cu_copy_len, /*end=*/padded_batch_size) + .slice(/*dim=*/0, + /*start=*/q_cu_copy_len, + /*end=*/static_cast(padded_num_tokens)) .copy_(tail_cu, /*non_blocking=*/true); } @@ -660,20 +662,24 @@ std::optional GraphPersistentParam::update( if (return_capture_params) { std::optional params_for_capture = std::make_optional(params); - // Set persistent buffers in params_for_capture + params_for_capture->actual_num_sequences = - static_cast(actual_batch_size); + static_cast(actual_num_tokens); params_for_capture->kv_seq_lens = kv_seq_lens(padded_num_tokens); params_for_capture->q_seq_lens = q_seq_lens(padded_num_tokens); + + // Copy real seq_lens rows [0, actual_num_tokens), fill padding + // rows [actual_num_tokens, padded_num_tokens) with 1 so that the + // 
graph-captured forward sees valid (non-zero) seq_lens everywhere. params_for_capture->kv_seq_lens_vec.resize(padded_num_tokens); params_for_capture->q_seq_lens_vec.resize(padded_num_tokens); - // Copy actual values from original params - for (int i = 0; i < actual_batch_size; i++) { + for (int64_t i = 0; i < actual_num_tokens; ++i) { params_for_capture->kv_seq_lens_vec[i] = params.kv_seq_lens_vec[i]; params_for_capture->q_seq_lens_vec[i] = params.q_seq_lens_vec[i]; } - // Fill padded positions with default values - for (int i = actual_batch_size; i < padded_num_tokens; i++) { + for (int64_t i = actual_num_tokens; + i < static_cast(padded_num_tokens); + ++i) { params_for_capture->kv_seq_lens_vec[i] = 1; params_for_capture->q_seq_lens_vec[i] = 1; } @@ -716,9 +722,10 @@ std::optional GraphPersistentParam::update( // Set q_cu_seq_lens if available if (params.q_cu_seq_lens.defined()) { params_for_capture->q_cu_seq_lens = - q_cu_seq_lens_.slice(/*dim=*/0, - /*start=*/0, - /*end=*/padded_batch_size); + q_cu_seq_lens_.slice( + /*dim=*/0, + /*start=*/0, + /*end=*/static_cast(padded_num_tokens)); } // Replace dp/cp ep padding with slices of persistent buffers so that diff --git a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp index 1cc0605a4..66ba6e8e3 100644 --- a/xllm/core/runtime/params_utils.cpp +++ b/xllm/core/runtime/params_utils.cpp @@ -224,6 +224,23 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, input_params.batch_forward_type = BatchForwardType(pb_forward_input->batch_forward_type()); input_params.num_sequences = pb_forward_input->num_sequences(); + if (!block_tables_vec.empty()) { + CHECK_EQ(block_tables_vec.size(), + static_cast(input_params.num_sequences)) + << "block_tables_vec size (" << block_tables_vec.size() + << ") must match num_sequences (" << input_params.num_sequences << ")"; + } else { + // support for multi_block_tables + for (int32_t m = 0; m < pb_forward_input->multi_block_tables_vec().size(); + 
++m) { + const auto& mgr = pb_forward_input->multi_block_tables_vec()[m]; + CHECK_EQ(static_cast(mgr.block_tables().size()), + static_cast(input_params.num_sequences)) + << "multi_block_tables[" << m << "] size (" + << mgr.block_tables().size() << ") must match num_sequences (" + << input_params.num_sequences << ")"; + } + } input_params.kv_max_seq_len = pb_forward_input->max_seq_len(); input_params.q_max_seq_len = pb_forward_input->q_max_seq_len(); input_params.kv_seq_lens = torch::tensor(seq_lens, tensor_options); @@ -252,12 +269,9 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, std::move(create_2d_tensor(block_tables_vec, torch::kInt)); // multi block manager support for DeepSeek V4 - for (int m = 0; m < pb_forward_input->multi_block_tables_vec().size(); ++m) { + for (int32_t m = 0; m < pb_forward_input->multi_block_tables_vec().size(); + ++m) { const auto& mgr = pb_forward_input->multi_block_tables_vec()[m]; - CHECK_EQ(static_cast(mgr.block_tables().size()), - static_cast(input_params.num_sequences)) - << "multi_block_tables[" << m << "] size (" << mgr.block_tables().size() - << ") must match num_sequences (" << input_params.num_sequences << ")"; std::vector> mgr_tables; mgr_tables.reserve(mgr.block_tables().size()); for (int s = 0; s < mgr.block_tables().size(); ++s) { diff --git a/xllm/core/util/utils.h b/xllm/core/util/utils.h index 81043698b..98a62dc03 100644 --- a/xllm/core/util/utils.h +++ b/xllm/core/util/utils.h @@ -131,27 +131,18 @@ inline bool is_mla_model_type(std::string_view model_type) { return mla_model_type_set().contains(std::string(model_type)); } -inline bool has_mtp_model_type_marker(std::string_view model_type) { - return model_type.find("mtp") != std::string_view::npos; -} - -inline bool starts_with_model_type(std::string_view model_type, - std::string_view target_model_type) { - return model_type.size() >= target_model_type.size() && - model_type.compare( - 0, target_model_type.size(), target_model_type) == 0; 
-} - -inline bool is_target_mtp_model_type(std::string_view model_type, - std::string_view target_model_type) { - return starts_with_model_type(model_type, target_model_type) && - has_mtp_model_type_marker(model_type); -} - -inline bool is_deepseek_v4_model_type(std::string_view model_type) { - constexpr std::string_view kTargetModelType = "deepseek_v4"; - return model_type == kTargetModelType || - is_target_mtp_model_type(model_type, kTargetModelType); +inline bool is_taget_model_type(std::string_view model_type, + std::string_view target, + bool match_mtp) { + if (model_type == target) { + return true; + } + if (!match_mtp) { + return false; + } + std::string target_mtp(target); + target_mtp += "_mtp"; + return model_type == target_mtp; } inline std::string get_model_name( diff --git a/xllm/models/llm/deepseek_v4.h b/xllm/models/llm/deepseek_v4.h index 712921765..58d943775 100644 --- a/xllm/models/llm/deepseek_v4.h +++ b/xllm/models/llm/deepseek_v4.h @@ -840,6 +840,8 @@ class DeepseekV4ModelImpl return tensor_max_or_zero(fallback_tensor); } + public: + // TODO: Common functions for both dsv4/dsv4_mtp. Consider moving the shared DeepSeek V4 graph metadata utilities out of DeepseekV4ModelImpl. 
static bool tensor_aliases_storage(const torch::Tensor& lhs, const torch::Tensor& rhs) { return lhs.defined() && rhs.defined() && lhs.data_ptr() == rhs.data_ptr() && @@ -958,69 +960,55 @@ class DeepseekV4ModelImpl #endif } - static int64_t infer_actual_batch_size(const ModelInputParams& params) { - if (params.actual_num_sequences > 0) { - return params.actual_num_sequences; - } - if (!params.kv_seq_lens_vec.empty()) { - return static_cast(params.kv_seq_lens_vec.size()); - } - if (!params.q_seq_lens_vec.empty()) { - return static_cast(params.q_seq_lens_vec.size()); - } - if (params.kv_seq_lens.defined() && params.kv_seq_lens.dim() >= 1) { - return params.kv_seq_lens.size(0); - } - if (params.q_seq_lens.defined() && params.q_seq_lens.dim() >= 1) { - return params.q_seq_lens.size(0); - } - if (params.block_tables.defined() && params.block_tables.dim() >= 2) { - return params.block_tables.size(0); - } - for (const auto& block_table : params.multi_block_tables) { - if (block_table.defined() && block_table.dim() >= 2) { - return block_table.size(0); - } - } - return 0; - } - - void normalize_graph_metadata_input_params(ModelInputParams& params) const { - const int64_t actual_batch_size = - std::max(infer_actual_batch_size(params), 0); - int64_t metadata_batch_size = params.actual_num_sequences; + // Normalize metadata vectors for graph mode (bucket-padded) forward. + // + // Key concepts: + // actual_metadata_rows -- rows that carry real token data. + // For normal decode: = num_requests. + // For MTP validate: = num_requests * (1 + num_spec_tokens). + // padded_metadata_rows -- total rows after bucket padding. + // >= actual_metadata_rows, rounded up to decode bucket. + // Rows in [actual_metadata_rows, padded_metadata_rows) are bucket padding + // and must be zeroed so DSAMetadataBuilder treats them as inactive. 
+ // + // Example (normal decode, 2 requests, bucket size 4): + // actual_metadata_rows = 2, padded_metadata_rows = 4 + // kv_seq_lens_vec = [23, 24, 0, 0] + // + // Example (MTP validate, 1 request, num_speculative_tokens=1, bucket 4): + // actual_metadata_rows = 1*(1+1) = 2, padded_metadata_rows = 4 + // kv_seq_lens_vec = [23, 24, 0, 0] + // Both rows are real; num_speculative_tokens is from speculative config. + static void normalize_graph_metadata_input_params(ModelInputParams& params) { + int64_t actual_metadata_rows = std::max(params.actual_num_sequences, + 0); + int64_t padded_metadata_rows = actual_metadata_rows; if (params.enable_graph) { - metadata_batch_size = - std::max(metadata_batch_size, params.num_sequences); + padded_metadata_rows = + std::max(padded_metadata_rows, params.num_sequences); } - if (metadata_batch_size <= 0) { - metadata_batch_size = 1; + if (padded_metadata_rows <= 0) { + padded_metadata_rows = 1; } + actual_metadata_rows = std::min(actual_metadata_rows, + padded_metadata_rows); - auto trim_lens_vec = [metadata_batch_size, - actual_batch_size](std::vector& lens) { + auto trim_lens_vec = [padded_metadata_rows, + actual_metadata_rows](std::vector& lens) { if (lens.empty()) { - lens.assign(static_cast(metadata_batch_size), 0); - } else if (static_cast(lens.size()) < metadata_batch_size) { - lens.resize(static_cast(metadata_batch_size), 0); + lens.assign(static_cast(padded_metadata_rows), 0); + } else if (static_cast(lens.size()) < padded_metadata_rows) { + lens.resize(static_cast(padded_metadata_rows), 0); } else { - lens.resize(static_cast(metadata_batch_size)); - } - for (int64_t i = 0; i < static_cast(lens.size()); ++i) { - if (i < actual_batch_size) { - lens[static_cast(i)] = std::max(lens[i], 1); - } else { - lens[static_cast(i)] = 0; - } + lens.resize(static_cast(padded_metadata_rows)); } + std::fill(lens.begin() + actual_metadata_rows, lens.end(), 0); }; - // Graph forward tensors are padded to the decode bucket. 
Build metadata - // for the same padded row count so compressor/attention inputs agree. trim_lens_vec(params.kv_seq_lens_vec); trim_lens_vec(params.q_seq_lens_vec); - params.num_sequences = static_cast(metadata_batch_size); - params.actual_num_sequences = static_cast(metadata_batch_size); + params.num_sequences = static_cast(padded_metadata_rows); + params.actual_num_sequences = static_cast(actual_metadata_rows); } std::shared_ptr @@ -1039,9 +1027,10 @@ class DeepseekV4ModelImpl return attn_metadata; } - std::shared_ptr persist_graph_attention_metadata( + static std::shared_ptr + persist_graph_attention_metadata( DeepseekV4GraphMetadataState& state, - std::shared_ptr metadata) const { + std::shared_ptr metadata) { if (!metadata || !metadata->dsa_metadata) { return metadata; } @@ -1057,6 +1046,7 @@ class DeepseekV4ModelImpl return metadata; } + private: void fill_empty_dp_rank_input_params( ModelInputParams& params, const std::vector* kv_caches = nullptr) const { @@ -1065,6 +1055,7 @@ class DeepseekV4ModelImpl .device(torch::kCPU) .pinned_memory(true); params.num_sequences = 1; + params.actual_num_sequences = 1; params.kv_max_seq_len = std::max(params.kv_max_seq_len, 1); params.q_max_seq_len = std::max(params.q_max_seq_len, 1); params.kv_seq_lens_vec = {1}; @@ -1090,7 +1081,7 @@ class DeepseekV4ModelImpl kv_caches->front()); } block_num = std::max(block_num, 1); - params.multi_block_tables.push_back( + params.multi_block_tables.emplace_back( torch::zeros({1, block_num}, cpu_int_options)); } } diff --git a/xllm/models/llm/deepseek_v4_mtp.h b/xllm/models/llm/deepseek_v4_mtp.h index e76c43a94..c08e3ea14 100644 --- a/xllm/models/llm/deepseek_v4_mtp.h +++ b/xllm/models/llm/deepseek_v4_mtp.h @@ -36,6 +36,7 @@ limitations under the License. 
#include "core/kernels/ops_api.h" #include "core/layers/common/attention_metadata_builder.h" #include "core/layers/common/dsa_metadata.h" +#include "core/layers/common/dsa_metadata_builder.h" #include "core/layers/common/linear.h" #include "core/layers/common/rms_norm.h" #include "core/layers/common/word_embedding.h" @@ -60,8 +61,7 @@ class DeepseekV4MultiTokenPredictorLayerImpl torch::Tensor positions, layer::AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params, - torch::Tensor tokens) { + const ModelInputParams& input_params) { ModelInputParams modified_input_params = input_params; modified_input_params.input_embedding = previous_hidden_states; std::optional residual; @@ -71,30 +71,28 @@ class DeepseekV4MultiTokenPredictorLayerImpl positions, attn_metadata, kv_cache, - modified_input_params, - tokens); + modified_input_params); } }; TORCH_MODULE(DeepseekV4MultiTokenPredictorLayer); class DeepseekV4MtpModelImpl final : public torch::nn::Module { public: - explicit DeepseekV4MtpModelImpl(const ModelContext& context) { - auto model_args = context.get_model_args(); + explicit DeepseekV4MtpModelImpl(const ModelContext& context) + : model_args_(context.get_model_args()) { auto options = context.get_tensor_options(); auto parallel_args = context.get_parallel_args(); + device_ = options.device(); - model_args_ = &model_args; - - CHECK_GT(model_args.n_layers(), 0) + CHECK_GT(model_args_.n_layers(), 0) << "deepseek_v4_mtp requires n_layers > 0"; - CHECK_GE(model_args.num_nextn_predict_layers(), 0) + CHECK_GE(model_args_.num_nextn_predict_layers(), 0) << "deepseek_v4_mtp requires num_nextn_predict_layers >= 0"; - const int32_t mtp_n_layers = model_args.n_layers(); + const int32_t mtp_n_layers = model_args_.n_layers(); - num_heads_ = model_args.n_heads(); - head_dim_ = model_args.o_lora_rank() + model_args.qk_rope_head_dim(); + num_heads_ = model_args_.n_heads(); + head_dim_ = model_args_.o_lora_rank() + model_args_.qk_rope_head_dim(); 
dp_local_tp_size_ = std::max(parallel_args.world_size() / std::max(parallel_args.dp_size(), 1), @@ -104,31 +102,31 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { "size. n_heads=" << num_heads_ << ", local_tp_size=" << dp_local_tp_size_; tp_num_heads_ = num_heads_ / dp_local_tp_size_; - window_size_ = model_args.window_size(); - index_n_heads_ = model_args.index_n_heads(); - index_head_dim_ = model_args.index_head_dim(); - index_topk_ = model_args.index_topk(); - norm_eps_ = static_cast(model_args.rms_norm_eps()); - - const int64_t rope_head_dim = model_args.rope_head_dim(); - const int64_t max_pos = model_args.max_position_embeddings(); + window_size_ = model_args_.window_size(); + index_n_heads_ = model_args_.index_n_heads(); + index_head_dim_ = model_args_.index_head_dim(); + index_topk_ = model_args_.index_topk(); + norm_eps_ = static_cast(model_args_.rms_norm_eps()); + + const int64_t rope_head_dim = model_args_.rope_head_dim(); + const int64_t max_pos = model_args_.max_position_embeddings(); if (rope_head_dim > 0 && max_pos > 0) { const int64_t original_max_pos = - model_args.rope_scaling_original_max_position_embeddings() > 0 - ? model_args.rope_scaling_original_max_position_embeddings() + model_args_.rope_scaling_original_max_position_embeddings() > 0 + ? 
model_args_.rope_scaling_original_max_position_embeddings() : max_pos; dsa_rotary_embedding_ = std::make_shared( /*rotary_dim=*/rope_head_dim, /*max_position_embeddings=*/max_pos, /*interleaved=*/true, - /*rope_theta=*/model_args.rope_theta(), - /*compress_rope_theta=*/model_args.compress_rope_theta(), - /*scaling_factor=*/model_args.factor(), + /*rope_theta=*/model_args_.rope_theta(), + /*compress_rope_theta=*/model_args_.compress_rope_theta(), + /*scaling_factor=*/model_args_.factor(), /*extrapolation_factor=*/1.0f, - /*beta_fast=*/model_args.beta_fast(), - /*beta_slow=*/model_args.beta_slow(), - /*attn_factor=*/model_args.rope_scaling_attn_factor(), + /*beta_fast=*/model_args_.beta_fast(), + /*beta_slow=*/model_args_.beta_slow(), + /*attn_factor=*/model_args_.rope_scaling_attn_factor(), /*mscale=*/1.0f, /*mscale_all_dim=*/1.0f, /*original_max_position_embeddings=*/original_max_pos, @@ -136,14 +134,14 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { dsa_cos_sin_ = dsa_rotary_embedding_->get_cos_sin_cache("default"); } - if (model_args.index_head_dim() > 0) { + if (model_args_.index_head_dim() > 0) { int64_t hadamard_dim_padded = - deepseek_v4_next_power_of_two(model_args.index_head_dim()); + deepseek_v4_next_power_of_two(model_args_.index_head_dim()); dsa_hadamard_ = deepseek_v4_create_hadamard_matrix( hadamard_dim_padded, options.dtype().toScalarType(), options.device()); } - deepseek_v4_build_cache_specs(model_args, caches_info_, group_infos_); + deepseek_v4_build_cache_specs(model_args_, caches_info_, group_infos_); mtp_layers_.reserve(mtp_n_layers); for (int32_t i = 0; i < mtp_n_layers; ++i) { @@ -215,21 +213,42 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { const ModelInputParams& input_params) { torch::NoGradGuard no_grad; - if (tokens.numel() == 0) { - tokens = torch::tensor({0}).to(torch::kInt32).to(tokens.device()); - positions = torch::tensor({0}).to(torch::kInt32).to(tokens.device()); + const bool is_empty_dp_rank 
= !tokens.defined() || tokens.numel() == 0; + + if (is_empty_dp_rank) { + tokens = torch::tensor( + {0}, torch::TensorOptions().dtype(torch::kInt32).device(device_)); + positions = torch::tensor( + {0}, torch::TensorOptions().dtype(torch::kInt32).device(device_)); } const torch::Device runtime_device = tokens.device(); - torch::Tensor previous_hidden_states = input_params.input_embedding; + auto modified_input_params = input_params; + if (is_empty_dp_rank) { + fill_empty_dp_rank_input_params(modified_input_params); + } + + torch::Tensor previous_hidden_states = modified_input_params.input_embedding; CHECK(previous_hidden_states.defined()) << "input_params.input_embedding must be defined for MTP model"; torch::Tensor hidden_states = embed_tokens_(tokens); - tokens = maybe_to_device(tokens, runtime_device); - positions = maybe_to_device(positions, runtime_device); + const bool acl_graph_forward = deepseek_v4_uses_acl_graph(input_params); + if (acl_graph_forward) { + CHECK(tokens.defined() && tokens.device() == runtime_device) + << "[DeepseekV4Mtp] ACL graph requires tokens on the runtime device"; + CHECK(positions.defined() && positions.device() == runtime_device) + << "[DeepseekV4Mtp] ACL graph requires positions on the runtime device"; + CHECK(input_params.new_cache_slots.defined()) + << "[DeepseekV4Mtp] ACL graph requires persistent new_cache_slots"; + CHECK(input_params.block_tables.defined()) + << "[DeepseekV4Mtp] ACL graph requires persistent block_tables"; + } else { + tokens = maybe_to_device(tokens, runtime_device); + positions = maybe_to_device(positions, runtime_device); + } auto mask = (positions == 0); if (mask.any().item()) { @@ -237,12 +256,31 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { torch::zeros_like(hidden_states.index({mask}))); } - auto modified_input_params = input_params; + if (acl_graph_forward) { + DeepseekV4ModelImpl::normalize_graph_metadata_input_params( + modified_input_params); + } auto& dp_token_nums = 
modified_input_params.dp_global_token_nums; std::replace(dp_token_nums.begin(), dp_token_nums.end(), 0, 1); - modified_input_params.attn_metadata = - build_attention_metadata_for_forward(positions, modified_input_params); + if (!modified_input_params.attn_metadata) { + CHECK(!acl_graph_forward) + << "[DeepseekV4Mtp] ACL graph requires prebuilt attention metadata"; + modified_input_params.attn_metadata = + build_attention_metadata_for_forward(positions, modified_input_params); + } + + if (modified_input_params.attn_metadata && + modified_input_params.attn_metadata->dsa_metadata) { + auto& dsa = *(modified_input_params.attn_metadata->dsa_metadata); + const bool graph_metadata_ready = acl_graph_forward && + dsa.packed_metadata_buffer.defined() && + dsa.start_pos.defined(); + if (graph_metadata_ready) { + build_rope(dsa, runtime_device); + build_precomputed_metadata(dsa); + } + } CHECK_GE(static_cast(kv_caches.size()), static_cast(mtp_layers_.size())) @@ -256,15 +294,106 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { positions, *modified_input_params.attn_metadata, kv_caches[i], - modified_input_params, - tokens); + modified_input_params); } auto [output, _] = final_norm_(hidden_states, std::nullopt); return ModelOutput(output, std::nullopt); } - private: + bool requires_graph_forward_metadata() { return true; } + + std::unique_ptr + create_graph_forward_metadata_state() { + return std::make_unique(); + } + + void prepare_graph_forward_metadata(ModelGraphMetadataState* state, + const torch::Tensor& positions, + ModelInputParams& input_params) { + CHECK(state != nullptr) + << "[DeepseekV4Mtp] graph metadata state must be initialized"; + auto* deepseek_v4_state = dynamic_cast(state); + CHECK(deepseek_v4_state != nullptr) + << "[DeepseekV4Mtp] received incompatible graph metadata state"; + + auto modified_input_params = input_params; + if (modified_input_params.actual_num_sequences == 0) { + fill_empty_dp_rank_input_params(modified_input_params); + } + 
+ DeepseekV4ModelImpl::normalize_graph_metadata_input_params( + modified_input_params); + auto& dp_token_nums = modified_input_params.dp_global_token_nums; + std::replace(dp_token_nums.begin(), dp_token_nums.end(), 0, 1); + + auto attn_metadata = std::make_shared( + layer::DSAMetadataBuilder::build(modified_input_params, + positions, + dsa_cos_sin_, + caches_info_, + group_infos_)); + if (attn_metadata->dsa_metadata) { + auto& dsa = *attn_metadata->dsa_metadata; + if (dsa_hadamard_.defined()) { + dsa.hadamard = dsa_hadamard_; + } + DeepseekV4ModelImpl::copy_to_graph_packed_metadata_buffer( + dsa, deepseek_v4_state->dsa_metadata_persistent, positions.device()); + if (dsa.actual_seq_lengths_kv.defined() && dsa.seq_lens_q.defined()) { + dsa.start_pos = + (dsa.actual_seq_lengths_kv - dsa.seq_lens_q).to(torch::kInt32); + } + } + input_params.attn_metadata = + DeepseekV4ModelImpl::persist_graph_attention_metadata( + *deepseek_v4_state, std::move(attn_metadata)); + CHECK(input_params.attn_metadata) + << "[DeepseekV4Mtp] ACL graph requires DSA metadata"; + } + + private: + void fill_empty_dp_rank_input_params(ModelInputParams& params) const { + auto cpu_int_options = torch::TensorOptions() + .dtype(torch::kInt32) + .device(torch::kCPU) + .pinned_memory(true); + params.num_sequences = 1; + params.actual_num_sequences = 1; + params.kv_max_seq_len = std::max(params.kv_max_seq_len, 1); + params.q_max_seq_len = std::max(params.q_max_seq_len, 1); + params.kv_seq_lens_vec = {1}; + params.q_seq_lens_vec = {1}; + params.kv_seq_lens = torch::tensor(params.kv_seq_lens_vec, cpu_int_options); + params.q_seq_lens = torch::tensor(params.q_seq_lens_vec, cpu_int_options); + params.q_cu_seq_lens = torch::tensor({1}, cpu_int_options); + params.kv_cache_tokens_nums = torch::tensor({1}, cpu_int_options); + params.kv_cache_tokens_nums_host = {1}; + params.new_cache_slots = torch::tensor({0}, cpu_int_options); + params.block_tables = torch::zeros({1, 1}, cpu_int_options); + + if 
(!params.input_embedding.defined()) { + auto options = torch::TensorOptions() + .dtype(torch::kBFloat16) + .device(device_); + params.input_embedding = torch::zeros( + {static_cast(params.num_sequences), + model_args_.hidden_size()}, + options); + } + + if (!params.multi_block_tables.empty()) { + return; + } + + const int32_t manager_num = static_cast(group_infos_.size()); + params.multi_block_tables.reserve(manager_num); + for (int32_t manager_id = 0; manager_id < manager_num; ++manager_id) { + params.multi_block_tables.emplace_back( + torch::zeros({1, 1}, cpu_int_options)); + } + } + std::shared_ptr build_attention_metadata_for_forward(const torch::Tensor& positions, const ModelInputParams& input_params) { @@ -648,8 +777,8 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { const int32_t layer_compress_ratio = deepseek_v4_normalize_compress_ratio( (layer_id < - static_cast(model_args_->compress_ratios().size())) - ? model_args_->compress_ratios()[static_cast(layer_id)] + static_cast(model_args_.compress_ratios().size())) + ? 
model_args_.compress_ratios()[static_cast(layer_id)] : 1); if (layer_compress_ratio == 4 && dsa.c4_cos.defined()) { @@ -706,7 +835,8 @@ class DeepseekV4MtpModelImpl final : public torch::nn::Module { int64_t index_topk_ = 512; double norm_eps_ = 1e-6; - const ModelArgs* model_args_ = nullptr; + ModelArgs model_args_; + torch::Device device_{torch::kCPU}; layer::RMSNorm final_norm_{nullptr}; layer::WordEmbedding embed_tokens_{nullptr}; @@ -730,6 +860,22 @@ class DeepseekV4MtpForCausalLMImpl final } model_->verify_loaded_weights(prefix); } + + bool requires_graph_forward_metadata() { + return this->model_->requires_graph_forward_metadata(); + } + + std::unique_ptr + create_graph_forward_metadata_state() { + return this->model_->create_graph_forward_metadata_state(); + } + + void prepare_graph_forward_metadata(ModelGraphMetadataState* state, + const torch::Tensor& positions, + ModelInputParams& input_params) { + this->model_->prepare_graph_forward_metadata( + state, positions, input_params); + } }; TORCH_MODULE(DeepseekV4MtpForCausalLM); diff --git a/xllm/models/llm/mtp_model_base.h b/xllm/models/llm/mtp_model_base.h index 7535b2646..b47f9c486 100644 --- a/xllm/models/llm/mtp_model_base.h +++ b/xllm/models/llm/mtp_model_base.h @@ -20,7 +20,6 @@ limitations under the License. 
#include #include #include -#include #include #include "core/framework/state_dict/utils.h" @@ -32,7 +31,9 @@ namespace xllm { enum class MtpProjectionType { CONCAT_EH_PROJ, ADD_EH_PROJ }; inline bool is_deepseek_v4_mtp_model(const ModelArgs& model_args) { - return util::is_target_mtp_model_type(model_args.model_type(), "deepseek_v4"); + return util::is_taget_model_type(model_args.model_type(), + /*target=*/"deepseek_v4_mtp", + /*match_mtp=*/false); } inline MtpProjectionType get_mtp_projection_type(const ModelArgs& model_args) { @@ -80,13 +81,8 @@ class MtpDecoderLayerImplBase : public torch::nn::Module { /*QuantArgs=*/QuantArgs(), options)); } - const int32_t decoder_layer_index = is_deepseek_v4_mtp_model(model_args_) - ? std::min( - layer_index, - model_args_.n_layers() - 1) - : layer_index; mtp_block_ = register_module("mtp_block", - DecoderLayerType(context, decoder_layer_index)); + DecoderLayerType(context, layer_index)); if (is_deepseek_v4_mtp_model(model_args_)) { const int64_t hc_mult = model_args_.hc_mult(); @@ -106,9 +102,7 @@ class MtpDecoderLayerImplBase : public torch::nn::Module { torch::Tensor positions, const layer::AttentionMetadata& attn_metadata, KVCache& kv_cache, - const ModelInputParams& input_params, - const std::optional& input_ids = - std::nullopt) { + const ModelInputParams& input_params) { // Layer norm on token inputs auto enorm_out = std::get<0>(enorm_(embed)); @@ -138,29 +132,12 @@ class MtpDecoderLayerImplBase : public torch::nn::Module { } // Pass through mtp block - if constexpr (std::is_invocable_v&, - torch::Tensor&, - const layer::AttentionMetadata&, - KVCache&, - const ModelInputParams&, - const std::optional&>) { - hidden_states = mtp_block_(hidden_states, - residual, - positions, - attn_metadata, - kv_cache, - input_params, - input_ids); - } else { - hidden_states = mtp_block_(hidden_states, - residual, - positions, - attn_metadata, - kv_cache, - input_params); - } + hidden_states = mtp_block_(hidden_states, + residual, + 
positions, + attn_metadata, + kv_cache, + input_params); if (is_deepseek_v4_mtp_model(model_args_)) { auto x_float = hidden_states.to(torch::kFloat32); @@ -311,8 +288,7 @@ class MtpModelImplBase : public torch::nn::Module { positions, attn_metadata, kv_caches[i], - modified_input_params, - tokens); + modified_input_params); if (!modified_input_params.record_layer(static_cast(i), hidden_states.device())) { return ModelOutput();