Skip to content
16 changes: 12 additions & 4 deletions xllm/core/distributed_runtime/llm_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,9 @@ KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
// all swa-related cache size from cache_size_in_bytes, then compute
// c4_count / c128_count (c4_count = 32 * c128_count).
// cache_size_in_bytes is already the full available device memory.
if (util::is_deepseek_v4_model_type(args_.model_type())) {
if (util::is_taget_model_type(args_.model_type(),
/*target=*/"deepseek_v4",
/*match_mtp=*/true)) {
const int64_t max_seqs =
static_cast<int64_t>(std::max(options_.max_seqs_per_batch(), 1));
const int32_t block_size = options_.block_size();
Expand Down Expand Up @@ -792,7 +794,9 @@ bool LLMEngine::allocate_kv_cache(const KVCacheCapacity& kv_cache_cap) {

// init kv cache for each worker
const KVCacheShape kv_cache_shape(kv_cache_cap, args_, dp_local_tp_size_);
if (util::is_deepseek_v4_model_type(args_.model_type())) {
if (util::is_taget_model_type(args_.model_type(),
/*target=*/"deepseek_v4",
/*match_mtp=*/true)) {
LOG(INFO) << "Initializing DSV4 kv cache with shape: [swa_count="
<< kv_cache_cap.swa_count()
<< ", c4_count=" << kv_cache_cap.c4_count()
Expand All @@ -817,7 +821,9 @@ bool LLMEngine::allocate_kv_cache(const KVCacheCapacity& kv_cache_cap) {
.slot_size(kv_cache_cap.slot_size())
.model_id(options_.model_id())
.max_seqs_per_batch(options_.max_seqs_per_batch());
if (util::is_deepseek_v4_model_type(args_.model_type())) {
if (util::is_taget_model_type(args_.model_type(),
/*target=*/"deepseek_v4",
/*match_mtp=*/true)) {
constexpr uint32_t kManagerTypeBlockManagerImpl = 0;
constexpr uint32_t kManagerTypeSlidingWindowBlockManager = 1;

Expand Down Expand Up @@ -1435,7 +1441,9 @@ std::vector<RawForwardInput> LLMEngine::prepare_inputs(
args_, threadpool_.get(), cp_size_)));
dp_global_token_nums[dp_rank] =
batched_inputs[dp_rank].flatten_tokens_vec.size();
if (util::is_deepseek_v4_model_type(args_.model_type())) {
if (util::is_taget_model_type(args_.model_type(),
/*target=*/"deepseek_v4",
/*match_mtp=*/true)) {
const int64_t actual_scheduled_tokens = static_cast<int64_t>(
batched_inputs[dp_rank].flatten_tokens_vec.size());
const int64_t max_tokens_per_batch =
Expand Down
4 changes: 3 additions & 1 deletion xllm/core/distributed_runtime/master.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ Master::Master(const Options& options, EngineType type)
std::filesystem::path(options_.model_path()).lexically_normal();
if (options_.enable_prefix_cache() && options_.backend() == "llm") {
const std::string model_type = util::get_model_type(model_path);
if (util::is_deepseek_v4_model_type(model_type)) {
if (util::is_taget_model_type(model_type,
/*target=*/"deepseek_v4",
/*match_mtp=*/true)) {
LOG(WARNING) << model_type << " does not support prefix cache with "
"CompositeBlockManager yet, fallback to "
"enable_prefix_cache=false";
Expand Down
4 changes: 3 additions & 1 deletion xllm/core/framework/kv_cache/kv_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,9 @@ void allocate_kv_caches(std::vector<KVCache>& kv_caches,
const int64_t num_layers = create_options.num_layers();
kv_caches.reserve(num_layers);

if (util::is_deepseek_v4_model_type(create_options.model_type())) {
if (util::is_taget_model_type(create_options.model_type(),
/*target=*/"deepseek_v4",
/*match_mtp=*/true)) {
std::vector<int32_t> layer_compress_ratios;
layer_compress_ratios.reserve(static_cast<size_t>(num_layers));
std::map<int32_t, std::string> ratio_shape_summaries;
Expand Down
4 changes: 3 additions & 1 deletion xllm/core/framework/kv_cache/kv_cache_shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ KVCacheShape::KVCacheShape(const KVCacheCapacity& kv_cache_cap,
CHECK_GT(world_size, 0) << "world_size must be positive.";
CHECK_GT(kv_cache_cap.block_size(), 0) << "block_size must be positive.";

if (util::is_deepseek_v4_model_type(model_args.model_type())) {
if (util::is_taget_model_type(model_args.model_type(),
/*target=*/"deepseek_v4",
/*match_mtp=*/true)) {
key_cache_shape_ = std::vector<int64_t>{kv_cache_cap.swa_count(),
kv_cache_cap.c4_count(),
kv_cache_cap.c128_count()};
Expand Down
4 changes: 3 additions & 1 deletion xllm/core/layers/npu_torch/fused_moe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,9 @@ FusedMoEImpl::FusedMoEImpl(const ModelArgs& model_args,
n_shared_experts_(model_args.n_shared_experts()),
is_gated_(moe_args.is_gated),
skip_gate_load_(moe_args.skip_gate_load),
is_deepseek_v4_(util::is_deepseek_v4_model_type(model_args.model_type())),
is_deepseek_v4_(util::is_taget_model_type(model_args.model_type(),
/*target=*/"deepseek_v4",
/*match_mtp=*/true)),
renormalize_(model_args.norm_topk_prob() ? 1 : 0),
hidden_act_(model_args.hidden_act()),
scoring_func_(model_args.scoring_func()),
Expand Down
69 changes: 38 additions & 31 deletions xllm/core/runtime/acl_graph_executor_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,14 +418,13 @@ std::optional<ModelInputParams> GraphPersistentParam::update(
bool return_capture_params) {
CHECK_GT(padded_num_tokens, 0)
<< "padded_num_tokens must be > 0 when return_capture_params is true";
const uint32_t actual_num_tokens = tokens.size(0);
const int64_t actual_num_tokens = tokens.size(0);
int64_t actual_batch_size = infer_actual_batch_size(params);
if (params.batch_forward_type.is_decode()) {
const int64_t decode_tokens =
std::max<int64_t>(options_.num_decoding_tokens(), 1);
actual_batch_size = actual_num_tokens / decode_tokens;
}

// Copy data from input parameters to persistent graph tensors
if (actual_num_tokens > 0) {
persistent_tokens_.slice(/*dim=*/0, /*start=*/0, /*end=*/actual_num_tokens)
Expand Down Expand Up @@ -461,10 +460,10 @@ std::optional<ModelInputParams> GraphPersistentParam::update(
}
}
int64_t q_copy_len = 0;
if (actual_batch_size > 0 && params.q_seq_lens.defined() &&
if (actual_num_tokens > 0 && params.q_seq_lens.defined() &&
params.q_seq_lens.dim() >= 1 && params.q_seq_lens.numel() > 0) {
q_copy_len =
std::min<int64_t>(actual_batch_size, params.q_seq_lens.size(0));
std::min<int64_t>(actual_num_tokens, params.q_seq_lens.size(0));
if (q_copy_len > 0) {
q_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/q_copy_len)
.copy_(params.q_seq_lens.slice(/*dim=*/0,
Expand All @@ -474,10 +473,10 @@ std::optional<ModelInputParams> GraphPersistentParam::update(
}
}
int64_t kv_copy_len = 0;
if (actual_batch_size > 0 && params.kv_seq_lens.defined() &&
if (actual_num_tokens > 0 && params.kv_seq_lens.defined() &&
params.kv_seq_lens.dim() >= 1 && params.kv_seq_lens.numel() > 0) {
kv_copy_len =
std::min<int64_t>(actual_batch_size, params.kv_seq_lens.size(0));
std::min<int64_t>(actual_num_tokens, params.kv_seq_lens.size(0));
if (kv_copy_len > 0) {
kv_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/kv_copy_len)
.copy_(params.kv_seq_lens.slice(/*dim=*/0,
Expand All @@ -488,19 +487,18 @@ std::optional<ModelInputParams> GraphPersistentParam::update(
}
// Keep padded decode slots valid for empty/local-short DP shards.
// These tensors are consumed by ATB setup alongside *_seq_lens_vec.
const int64_t padded_batch_size = static_cast<int64_t>(padded_num_tokens);
if (q_copy_len < padded_batch_size) {
if (q_copy_len < static_cast<int64_t>(padded_num_tokens)) {
q_seq_lens_
.slice(/*dim=*/0,
/*start=*/q_copy_len,
/*end=*/padded_batch_size)
/*end=*/static_cast<int64_t>(padded_num_tokens))
.fill_(1);
}
if (kv_copy_len < padded_batch_size) {
if (kv_copy_len < static_cast<int64_t>(padded_num_tokens)) {
kv_seq_lens_
.slice(/*dim=*/0,
/*start=*/kv_copy_len,
/*end=*/padded_batch_size)
/*end=*/static_cast<int64_t>(padded_num_tokens))
.fill_(1);
}

Expand Down Expand Up @@ -551,10 +549,10 @@ std::optional<ModelInputParams> GraphPersistentParam::update(
// zeroed in the active bucket slice.
int64_t block_rows_to_copy = 0;
int64_t actual_block_table_len = 0;
if (actual_batch_size > 0 && params.block_tables.defined() &&
if (actual_num_tokens > 0 && params.block_tables.defined() &&
params.block_tables.dim() >= 2 && params.block_tables.numel() > 0) {
block_rows_to_copy =
std::min<int64_t>(actual_batch_size, params.block_tables.size(0));
std::min<int64_t>(actual_num_tokens, params.block_tables.size(0));
actual_block_table_len = params.block_tables.size(1);
if (block_rows_to_copy > 0 && actual_block_table_len > 0) {
auto slice_persistent_block_tables =
Expand All @@ -577,9 +575,10 @@ std::optional<ModelInputParams> GraphPersistentParam::update(
/*end=*/persistent_block_tables_.size(1))
.zero_();
}
if (actual_batch_size < padded_batch_size) {
zero_tensor_tail(
persistent_block_tables_, actual_batch_size, padded_batch_size);
if (actual_num_tokens < static_cast<int64_t>(padded_num_tokens)) {
zero_tensor_tail(persistent_block_tables_,
actual_num_tokens,
static_cast<int64_t>(padded_num_tokens));
}

// Update persistent embedding from input_embedding if available
Expand Down Expand Up @@ -611,18 +610,19 @@ std::optional<ModelInputParams> GraphPersistentParam::update(
const int64_t q_cu_size = (has_q_cu && params.q_cu_seq_lens.numel() > 0)
? params.q_cu_seq_lens.size(0)
: 0;
const int64_t q_cu_copy_len = std::min<int64_t>(actual_batch_size, q_cu_size);
const int64_t q_cu_copy_len = std::min<int64_t>(actual_num_tokens, q_cu_size);
if (q_cu_copy_len > 0) {
q_cu_seq_lens_.slice(/*dim=*/0, /*start=*/0, /*end=*/q_cu_copy_len)
.copy_(params.q_cu_seq_lens.slice(/*dim=*/0,
/*start=*/0,
/*end=*/q_cu_copy_len),
/*non_blocking=*/true);
}
if (padded_batch_size > q_cu_copy_len) {
auto tail_q_seq_lens = q_seq_lens_.slice(/*dim=*/0,
/*start=*/q_cu_copy_len,
/*end=*/padded_batch_size);
if (static_cast<int64_t>(padded_num_tokens) > q_cu_copy_len) {
auto tail_q_seq_lens = q_seq_lens_.slice(
/*dim=*/0,
/*start=*/q_cu_copy_len,
/*end=*/static_cast<int64_t>(padded_num_tokens));
auto tail_cu = torch::cumsum(tail_q_seq_lens, /*dim=*/0);
if (q_cu_copy_len > 0) {
auto last_prefix = q_cu_seq_lens_.slice(/*dim=*/0,
Expand All @@ -631,7 +631,9 @@ std::optional<ModelInputParams> GraphPersistentParam::update(
tail_cu = tail_cu + last_prefix;
}
q_cu_seq_lens_
.slice(/*dim=*/0, /*start=*/q_cu_copy_len, /*end=*/padded_batch_size)
.slice(/*dim=*/0,
/*start=*/q_cu_copy_len,
/*end=*/static_cast<int64_t>(padded_num_tokens))
.copy_(tail_cu, /*non_blocking=*/true);
}

Expand Down Expand Up @@ -660,20 +662,24 @@ std::optional<ModelInputParams> GraphPersistentParam::update(
if (return_capture_params) {
std::optional<ModelInputParams> params_for_capture =
std::make_optional<ModelInputParams>(params);
// Set persistent buffers in params_for_capture

params_for_capture->actual_num_sequences =
static_cast<int32_t>(actual_batch_size);
static_cast<int32_t>(actual_num_tokens);
params_for_capture->kv_seq_lens = kv_seq_lens(padded_num_tokens);
params_for_capture->q_seq_lens = q_seq_lens(padded_num_tokens);

// Copy real seq_lens rows [0, actual_num_tokens), fill padding
// rows [actual_num_tokens, padded_num_tokens) with 1 so that the
// graph-captured forward sees valid (non-zero) seq_lens everywhere.
params_for_capture->kv_seq_lens_vec.resize(padded_num_tokens);
params_for_capture->q_seq_lens_vec.resize(padded_num_tokens);
// Copy actual values from original params
for (int i = 0; i < actual_batch_size; i++) {
for (int64_t i = 0; i < actual_num_tokens; ++i) {
params_for_capture->kv_seq_lens_vec[i] = params.kv_seq_lens_vec[i];
params_for_capture->q_seq_lens_vec[i] = params.q_seq_lens_vec[i];
}
// Fill padded positions with default values
for (int i = actual_batch_size; i < padded_num_tokens; i++) {
for (int64_t i = actual_num_tokens;
i < static_cast<int64_t>(padded_num_tokens);
++i) {
params_for_capture->kv_seq_lens_vec[i] = 1;
params_for_capture->q_seq_lens_vec[i] = 1;
}
Expand Down Expand Up @@ -716,9 +722,10 @@ std::optional<ModelInputParams> GraphPersistentParam::update(
// Set q_cu_seq_lens if available
if (params.q_cu_seq_lens.defined()) {
params_for_capture->q_cu_seq_lens =
q_cu_seq_lens_.slice(/*dim=*/0,
/*start=*/0,
/*end=*/padded_batch_size);
q_cu_seq_lens_.slice(
/*dim=*/0,
/*start=*/0,
/*end=*/static_cast<int64_t>(padded_num_tokens));
}

// Replace dp/cp ep padding with slices of persistent buffers so that
Expand Down
24 changes: 19 additions & 5 deletions xllm/core/runtime/params_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,23 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input,
input_params.batch_forward_type =
BatchForwardType(pb_forward_input->batch_forward_type());
input_params.num_sequences = pb_forward_input->num_sequences();
if (!block_tables_vec.empty()) {
CHECK_EQ(block_tables_vec.size(),
static_cast<size_t>(input_params.num_sequences))
<< "block_tables_vec size (" << block_tables_vec.size()
<< ") must match num_sequences (" << input_params.num_sequences << ")";
} else {
// Support for multi_block_tables.
for (int32_t m = 0; m < pb_forward_input->multi_block_tables_vec().size();
++m) {
const auto& mgr = pb_forward_input->multi_block_tables_vec()[m];
CHECK_EQ(static_cast<size_t>(mgr.block_tables().size()),
static_cast<size_t>(input_params.num_sequences))
<< "multi_block_tables[" << m << "] size ("
<< mgr.block_tables().size() << ") must match num_sequences ("
<< input_params.num_sequences << ")";
}
}
input_params.kv_max_seq_len = pb_forward_input->max_seq_len();
input_params.q_max_seq_len = pb_forward_input->q_max_seq_len();
input_params.kv_seq_lens = torch::tensor(seq_lens, tensor_options);
Expand Down Expand Up @@ -252,12 +269,9 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input,
std::move(create_2d_tensor(block_tables_vec, torch::kInt));

// multi block manager support for DeepSeek V4
for (int m = 0; m < pb_forward_input->multi_block_tables_vec().size(); ++m) {
for (int32_t m = 0; m < pb_forward_input->multi_block_tables_vec().size();
++m) {
const auto& mgr = pb_forward_input->multi_block_tables_vec()[m];
CHECK_EQ(static_cast<size_t>(mgr.block_tables().size()),
static_cast<size_t>(input_params.num_sequences))
<< "multi_block_tables[" << m << "] size (" << mgr.block_tables().size()
<< ") must match num_sequences (" << input_params.num_sequences << ")";
std::vector<std::vector<int32_t>> mgr_tables;
mgr_tables.reserve(mgr.block_tables().size());
for (int s = 0; s < mgr.block_tables().size(); ++s) {
Expand Down
33 changes: 12 additions & 21 deletions xllm/core/util/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,27 +131,18 @@ inline bool is_mla_model_type(std::string_view model_type) {
return mla_model_type_set().contains(std::string(model_type));
}

// Reports whether the substring "mtp" occurs anywhere inside `model_type`.
// The search is case-sensitive and position-independent.
inline bool has_mtp_model_type_marker(std::string_view model_type) {
  constexpr std::string_view kMtpMarker = "mtp";
  const auto marker_pos = model_type.find(kMtpMarker);
  return marker_pos != std::string_view::npos;
}

// Reports whether `model_type` begins with `target_model_type`.
// An empty target is a prefix of every model type.
inline bool starts_with_model_type(std::string_view model_type,
                                   std::string_view target_model_type) {
  // substr() clamps the requested count to the characters available, so a
  // model_type shorter than the target yields a shorter view that can never
  // compare equal to the target.
  const std::string_view leading =
      model_type.substr(0, target_model_type.size());
  return leading == target_model_type;
}

// Reports whether `model_type` both begins with `target_model_type` and
// carries the MTP marker, as decided by starts_with_model_type() and
// has_mtp_model_type_marker() respectively.
inline bool is_target_mtp_model_type(std::string_view model_type,
                                     std::string_view target_model_type) {
  // Guard on the prefix first; the marker is only consulted for model types
  // that belong to the target family.
  if (!starts_with_model_type(model_type, target_model_type)) {
    return false;
  }
  return has_mtp_model_type_marker(model_type);
}

inline bool is_deepseek_v4_model_type(std::string_view model_type) {
constexpr std::string_view kTargetModelType = "deepseek_v4";
return model_type == kTargetModelType ||
is_target_mtp_model_type(model_type, kTargetModelType);
// Reports whether `model_type` names the `target` model.
//
// Returns true when `model_type` equals `target` exactly. When `match_mtp` is
// true, it additionally returns true when `model_type` equals `target`
// followed by the literal suffix "_mtp". Every other value returns false.
//
// NOTE(review): the identifier is spelled "taget" (presumably "target"); the
// spelling is kept because call sites reference this exact name.
inline bool is_taget_model_type(std::string_view model_type,
                                std::string_view target,
                                bool match_mtp) {
  constexpr std::string_view kMtpSuffix = "_mtp";

  const bool exact_match = (model_type == target);
  if (exact_match || !match_mtp) {
    return exact_match;
  }

  // Compare "<target>_mtp" piecewise on views rather than materializing the
  // concatenated string.
  const bool length_matches =
      model_type.size() == target.size() + kMtpSuffix.size();
  return length_matches && model_type.substr(0, target.size()) == target &&
         model_type.substr(target.size()) == kMtpSuffix;
}

inline std::string get_model_name(
Expand Down
Loading