From c62a89dbf3960e989893c7ce9232d954596660be Mon Sep 17 00:00:00 2001 From: Yingxu Deng Date: Sat, 9 May 2026 20:24:45 +0800 Subject: [PATCH 1/2] bugfix: change single block manager blocks to max concurrent requests. (#1413) Co-authored-by: kangmeng3 --- xllm/core/framework/block/block_manager_pool.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xllm/core/framework/block/block_manager_pool.cpp b/xllm/core/framework/block/block_manager_pool.cpp index cf1734a05..4451f77f1 100644 --- a/xllm/core/framework/block/block_manager_pool.cpp +++ b/xllm/core/framework/block/block_manager_pool.cpp @@ -83,9 +83,7 @@ BlockManagerPool::BlockManagerPool(const Options& options, int32_t dp_size) // pool. Worker-side embedding and linear-state caches remain physically // separate and are addressed via transport fields. single_block_managers_.emplace_back(std::make_unique( - /*num_blocks=*/::xllm::SchedulerConfig::get_instance() - .max_seqs_per_batch() + - 2, + /*num_blocks=*/FLAGS_max_concurrent_requests + 2, /*resource_name=*/"single block", /*exhaustion_message=*/"No more single-block ids available")); } From 446c0574dc0a811d11437b2fa845c1117f87837d Mon Sep 17 00:00:00 2001 From: Joey Gao <1783198484@qq.com> Date: Mon, 11 May 2026 18:16:26 +0800 Subject: [PATCH 2/2] bugfix: fix high concurrency linear state overflow (#1422) Co-authored-by: pjgao --- tests/core/framework/kv_cache/kv_cache_estimation_test.cpp | 1 + xllm/core/distributed_runtime/llm_engine.cpp | 3 +++ xllm/core/framework/block/block_manager_pool.cpp | 4 +++- xllm/core/framework/config/scheduler_config.cpp | 7 +++++++ xllm/core/framework/config/scheduler_config.h | 3 +++ xllm/core/framework/kv_cache/kv_cache_estimation.cpp | 6 +++--- xllm/core/framework/kv_cache/kv_cache_estimation.h | 1 + 7 files changed, 21 insertions(+), 4 deletions(-) diff --git a/tests/core/framework/kv_cache/kv_cache_estimation_test.cpp b/tests/core/framework/kv_cache/kv_cache_estimation_test.cpp index 65032fab6..fdb87600e 100644 --- a/tests/core/framework/kv_cache/kv_cache_estimation_test.cpp +++ b/tests/core/framework/kv_cache/kv_cache_estimation_test.cpp @@ -37,6 +37,7 @@ KVCacheEstimateOptions make_estimate_options() { options.world_size = 1; options.n_local_kv_heads = 2; options.max_seqs_per_batch = 8; + options.max_concurrent_requests = 8; return options; } diff --git a/xllm/core/distributed_runtime/llm_engine.cpp b/xllm/core/distributed_runtime/llm_engine.cpp index 010c11bf3..e32d8b3e8 100644 --- a/xllm/core/distributed_runtime/llm_engine.cpp +++ b/xllm/core/distributed_runtime/llm_engine.cpp @@ -437,6 +437,8 @@ KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() { estimate_options.n_local_linear_v_heads = n_local_linear_v_heads_; estimate_options.max_seqs_per_batch = static_cast(options_.max_seqs_per_batch()); + estimate_options.max_concurrent_requests = + static_cast(::xllm::SchedulerConfig::get_instance().max_concurrent_requests()); estimate_options.is_draft_engine = options_.is_draft_engine(); estimate_options.enable_prefix_cache = ::xllm::KVCacheConfig::get_instance().enable_prefix_cache(); @@ -452,6 +454,7 @@ KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() { DeviceMonitor::get_instance().set_total_activation_memory(device.index()); } + return kv_cache_cap; } diff --git a/xllm/core/framework/block/block_manager_pool.cpp b/xllm/core/framework/block/block_manager_pool.cpp index 4451f77f1..0ff2329e9 100644 --- a/xllm/core/framework/block/block_manager_pool.cpp +++ b/xllm/core/framework/block/block_manager_pool.cpp @@ -83,7 +83,9 @@ BlockManagerPool::BlockManagerPool(const Options& options, int32_t dp_size) // pool. Worker-side embedding and linear-state caches remain physically // separate and are addressed via transport fields. single_block_managers_.emplace_back(std::make_unique( - /*num_blocks=*/FLAGS_max_concurrent_requests + 2, + /*num_blocks=*/::xllm::SchedulerConfig::get_instance() + .max_concurrent_requests() + + 2, /*resource_name=*/"single block", /*exhaustion_message=*/"No more single-block ids available")); } diff --git a/xllm/core/framework/config/scheduler_config.cpp b/xllm/core/framework/config/scheduler_config.cpp index 85c7dc9cd..43152b455 100644 --- a/xllm/core/framework/config/scheduler_config.cpp +++ b/xllm/core/framework/config/scheduler_config.cpp @@ -22,6 +22,10 @@ DEFINE_int32(max_tokens_per_batch, 10240, "Max number of tokens per batch."); DEFINE_int32(max_seqs_per_batch, 1024, "Max number of sequences per batch."); +DEFINE_int32(max_concurrent_requests, + 200, + "Max number of concurrent requests."); + DEFINE_bool(enable_schedule_overlap, false, "Whether to enable schedule overlap."); @@ -78,6 +82,7 @@ namespace xllm { void SchedulerConfig::from_flags() { max_tokens_per_batch(FLAGS_max_tokens_per_batch) .max_seqs_per_batch(FLAGS_max_seqs_per_batch) + .max_concurrent_requests(FLAGS_max_concurrent_requests) .enable_schedule_overlap(FLAGS_enable_schedule_overlap) .prefill_scheduling_memory_usage_threshold( FLAGS_prefill_scheduling_memory_usage_threshold) @@ -99,6 +104,8 @@ void SchedulerConfig::from_json(const JsonReader& json) { json.value_or("max_tokens_per_batch", max_tokens_per_batch())) .max_seqs_per_batch( json.value_or("max_seqs_per_batch", max_seqs_per_batch())) + .max_concurrent_requests(json.value_or( + "max_concurrent_requests", max_concurrent_requests())) .enable_schedule_overlap(json.value_or("enable_schedule_overlap", enable_schedule_overlap())) .prefill_scheduling_memory_usage_threshold( diff --git a/xllm/core/framework/config/scheduler_config.h b/xllm/core/framework/config/scheduler_config.h index cace4457f..e20af7495 100644 --- a/xllm/core/framework/config/scheduler_config.h +++ b/xllm/core/framework/config/scheduler_config.h @@ -41,6 +41,7 @@ class SchedulerConfig final { "SCHEDULER OPTIONS", {"max_tokens_per_batch", "max_seqs_per_batch", + "max_concurrent_requests", "enable_schedule_overlap", "prefill_scheduling_memory_usage_threshold", "enable_chunked_prefill", @@ -61,6 +62,8 @@ class SchedulerConfig final { PROPERTY(int32_t, max_seqs_per_batch) = 1024; + PROPERTY(int32_t, max_concurrent_requests) = 200; + PROPERTY(bool, enable_schedule_overlap) = false; PROPERTY(double, prefill_scheduling_memory_usage_threshold) = 0.95; diff --git a/xllm/core/framework/kv_cache/kv_cache_estimation.cpp b/xllm/core/framework/kv_cache/kv_cache_estimation.cpp index 362c2c4fd..507f0ca52 100644 --- a/xllm/core/framework/kv_cache/kv_cache_estimation.cpp +++ b/xllm/core/framework/kv_cache/kv_cache_estimation.cpp @@ -215,7 +215,7 @@ void init_dsv4_counts(const ModelArgs& model_args, void init_standard_counts(const ModelArgs& model_args, const KVCacheEstimateOptions& options, KVCacheCapacity* kv_cache_cap) { - kv_cache_cap->num_linear_state_blocks(options.max_seqs_per_batch + 2); + kv_cache_cap->num_linear_state_blocks(options.max_concurrent_requests + 2); for (int64_t layer_id = 0; layer_id < kv_cache_cap->n_layers(); ++layer_id) { if (is_full_attention_layer(model_args, layer_id)) { ++kv_cache_cap->num_full_attention_layers(); @@ -241,8 +241,8 @@ void init_standard_counts(const ModelArgs& model_args, kv_cache_cap->linear_cache_size_in_bytes()) << "failed to reserve linear state cache for linear-attention " "layers: " - << "max_seqs_per_batch (" << options.max_seqs_per_batch - << ") is too large. Please reduce max_seqs_per_batch to less than " + << "max_concurrent_requests (" << options.max_concurrent_requests + << ") is too large. Please reduce max_concurrent_requests to less than " << kv_cache_cap->cache_size_in_bytes() / (kv_cache_cap->num_linear_attention_layers() * kv_cache_cap->linear_slot_size()) - diff --git a/xllm/core/framework/kv_cache/kv_cache_estimation.h b/xllm/core/framework/kv_cache/kv_cache_estimation.h index 01ba7374b..18a12b9d9 100644 --- a/xllm/core/framework/kv_cache/kv_cache_estimation.h +++ b/xllm/core/framework/kv_cache/kv_cache_estimation.h @@ -36,6 +36,7 @@ struct KVCacheEstimateOptions { int64_t n_local_linear_k_heads = 0; int64_t n_local_linear_v_heads = 0; int64_t max_seqs_per_batch = 0; + int64_t max_concurrent_requests = 0; bool is_draft_engine = false; bool enable_prefix_cache = false; };