diff --git a/tests/core/scheduler/fixed_steps_scheduler_test.cpp b/tests/core/scheduler/fixed_steps_scheduler_test.cpp index 7acb9f2f15..829b6251f1 100644 --- a/tests/core/scheduler/fixed_steps_scheduler_test.cpp +++ b/tests/core/scheduler/fixed_steps_scheduler_test.cpp @@ -15,15 +15,21 @@ limitations under the License. #include "fixed_steps_scheduler.h" +#include #include #include #include +#include +#include +#include +#include #include "common/global_flags.h" #include "continuous_scheduler.h" #include "distributed_runtime/engine.h" #include "framework/request/rec_type.h" +#include "util/scope_guard.h" namespace xllm { @@ -71,7 +77,16 @@ class FakeEngine : public Engine { fake_block_manager_ = std::make_unique(opt, 1); } ForwardOutput step(std::vector& batch) override { - (void)batch; + { + std::lock_guard lock(mutex_); + ++step_calls_; + step_batch_sizes_.emplace_back( + batch.empty() ? 0 : batch[0].get_sequences().size()); + if (step_calls_ >= notify_after_step_calls_ && + !step_calls_ready_.HasBeenNotified()) { + step_calls_ready_.Notify(); + } + } return ForwardOutput(); } void update_last_step_result(std::vector& batch) override { @@ -94,9 +109,33 @@ class FakeEngine : public Engine { } bool init() override { return true; } + int32_t step_calls() const { + std::lock_guard lock(mutex_); + return step_calls_; + } + + std::vector step_batch_sizes() const { + std::lock_guard lock(mutex_); + return step_batch_sizes_; + } + + void notify_after_step_calls(int32_t step_calls) { + std::lock_guard lock(mutex_); + notify_after_step_calls_ = step_calls; + } + + bool wait_for_step_calls(absl::Duration timeout) const { + return step_calls_ready_.WaitForNotificationWithTimeout(timeout); + } + private: std::unique_ptr fake_tokenizer_; std::unique_ptr fake_block_manager_; + mutable std::mutex mutex_; + int32_t step_calls_ = 0; + int32_t notify_after_step_calls_ = std::numeric_limits::max(); + std::vector step_batch_sizes_; + mutable absl::Notification step_calls_ready_; }; ContinuousScheduler::Options CreateOptions( @@ -228,4 +267,92 @@ TEST(FixedStepsSchedulerTest, StepCompletesWithRequest) { EXPECT_NO_THROW(scheduler.step(absl::Milliseconds(500))); } +TEST(FixedStepsSchedulerTest, + OneRecXAttentionMultiConcurrencyPrepareBatchKeepsOneRequest) { + const bool old_enable_prefix_cache = FLAGS_enable_prefix_cache; + const int32_t old_max_decode_rounds = FLAGS_max_decode_rounds; + const double old_prefill_threshold = + FLAGS_prefill_scheduling_memory_usage_threshold; + SCOPE_GUARD([&] { + FLAGS_enable_prefix_cache = old_enable_prefix_cache; + FLAGS_max_decode_rounds = old_max_decode_rounds; + FLAGS_prefill_scheduling_memory_usage_threshold = old_prefill_threshold; + }); + + FLAGS_enable_prefix_cache = false; + FLAGS_max_decode_rounds = 2; + FLAGS_prefill_scheduling_memory_usage_threshold = 1.0; + + std::unique_ptr engine = + std::make_unique(/*num_blocks=*/64, /*block_size=*/128); + ContinuousScheduler::Options opt = CreateOptions( + /*max_tokens_per_batch=*/10000, + /*max_seqs_per_batch=*/256, + /*dp_size=*/1, + /*enable_schedule_overlap=*/false, + /*rec_worker_max_concurrency=*/2); + FixedStepsScheduler scheduler(engine.get(), opt); + std::vector> requests = GenRequests( + /*prompt_lens=*/{1, 1}, /*max_tokens=*/{1, 1}, RecType::kOneRec); + for (std::shared_ptr& req : requests) { + scheduler.add_request(req); + } + + ContinuousScheduler* base = &scheduler; + std::vector batches = base->prepare_batch_test(); + ASSERT_FALSE(batches.empty()); + EXPECT_EQ(batches[0].get_sequences().size(), 1u); + EXPECT_EQ(base->get_running_requests().size(), 1u); + EXPECT_EQ(base->get_waiting_requests().size(), 1u); +} + +TEST(FixedStepsSchedulerTest, + OneRecXAttentionMultiConcurrencySchedulesSeparateSteps) { + const bool old_enable_prefix_cache = FLAGS_enable_prefix_cache; + const int32_t old_max_decode_rounds = FLAGS_max_decode_rounds; + const double old_prefill_threshold = + FLAGS_prefill_scheduling_memory_usage_threshold; + SCOPE_GUARD([&] { + FLAGS_enable_prefix_cache = old_enable_prefix_cache; + FLAGS_max_decode_rounds = old_max_decode_rounds; + FLAGS_prefill_scheduling_memory_usage_threshold = old_prefill_threshold; + }); + + FLAGS_enable_prefix_cache = false; + FLAGS_max_decode_rounds = 2; + FLAGS_prefill_scheduling_memory_usage_threshold = 1.0; + + std::unique_ptr engine = + std::make_unique(/*num_blocks=*/3, /*block_size=*/128); + engine->notify_after_step_calls(/*step_calls=*/2); + { + ContinuousScheduler::Options opt = CreateOptions( + /*max_tokens_per_batch=*/10000, + /*max_seqs_per_batch=*/256, + /*dp_size=*/1, + /*enable_schedule_overlap=*/false, + /*rec_worker_max_concurrency=*/2); + FixedStepsScheduler scheduler(engine.get(), opt); + + std::vector> requests = + GenRequests(/*prompt_lens=*/{1, 1, 1}, + /*max_tokens=*/{1, 1, 1}, + RecType::kOneRec); + for (std::shared_ptr& req : requests) { + scheduler.add_request(req); + } + + scheduler.step(absl::ZeroDuration()); + scheduler.step(absl::ZeroDuration()); + + EXPECT_TRUE(engine->wait_for_step_calls(absl::Milliseconds(500))); + EXPECT_EQ(engine->step_calls(), 2); + const std::vector step_batch_sizes = engine->step_batch_sizes(); + ASSERT_EQ(step_batch_sizes.size(), 2u); + EXPECT_EQ(step_batch_sizes[0], 1u); + EXPECT_EQ(step_batch_sizes[1], 1u); + EXPECT_EQ(scheduler.get_waiting_requests().size(), 1u); + } +} + } // namespace xllm diff --git a/third_party/xllm_atb_layers b/third_party/xllm_atb_layers index 49e2cf8923..3555e21b9f 160000 --- a/third_party/xllm_atb_layers +++ b/third_party/xllm_atb_layers @@ -1 +1 @@ -Subproject commit 49e2cf892339d58a09dce5120cd0cc94f5f9d0b0 +Subproject commit 3555e21b9f213e412f8fb0d106e1c3f3083490e0 diff --git a/xllm/core/distributed_runtime/rec_engine.cpp b/xllm/core/distributed_runtime/rec_engine.cpp index 2dc36cad76..6b07648204 100644 --- a/xllm/core/distributed_runtime/rec_engine.cpp +++ b/xllm/core/distributed_runtime/rec_engine.cpp @@ -42,6 +42,14 @@ namespace xllm { namespace { constexpr int64_t kMinimalOneRecMetadataKVBlocks = 2; +constexpr const char* kOnerecXAttentionLockTimingEnv = + "XLLM_DEBUG_ONEREC_XATTN_LOCK_TIMING"; + +bool enable_onerec_xattention_lock_timing() { + static const bool enable_lock_timing = + util::get_bool_env(kOnerecXAttentionLockTimingEnv, false); + return enable_lock_timing; +} } // namespace @@ -813,7 +821,11 @@ RecEngine::OneRecXAttentionEnginePipeline::OneRecXAttentionEnginePipeline( int64_t RecEngine::OneRecXAttentionEnginePipeline::minimal_kv_cache_blocks() const { - return kMinimalOneRecMetadataKVBlocks; + const int64_t max_concurrency = + std::max(1, engine_.options_.rec_worker_max_concurrency()); + // BlockManagerImpl reserves block 0 for padding, so xattention needs one + // allocatable metadata block per concurrent scheduler step plus padding. + return std::max(kMinimalOneRecMetadataKVBlocks, max_concurrency + 1); } ForwardOutput RecEngine::OneRecXAttentionEnginePipeline::step( @@ -824,7 +836,12 @@ ForwardOutput RecEngine::OneRecXAttentionEnginePipeline::step( Timer timer; auto forward_inputs = engine_.workers_[0]->prepare_inputs(batches[0]); - COUNTER_ADD(prepare_input_latency_microseconds, timer.elapsed_microseconds()); + const double prepare_inputs_us = timer.elapsed_microseconds(); + COUNTER_ADD(prepare_input_latency_microseconds, prepare_inputs_us); + if (enable_onerec_xattention_lock_timing()) { + LOG(INFO) << "OneRec xattention engine host timing, " + << "stage=prepare_inputs, elapsed_us=" << prepare_inputs_us; + } if (!forward_inputs.token_ids.defined()) { return {}; @@ -832,8 +849,12 @@ ForwardOutput RecEngine::OneRecXAttentionEnginePipeline::step( timer.reset(); const auto& output = get_model_output(forward_inputs); - COUNTER_ADD(rec_first_token_latency_microseconds, - timer.elapsed_microseconds()); + const double model_output_us = timer.elapsed_microseconds(); + COUNTER_ADD(rec_first_token_latency_microseconds, model_output_us); + if (enable_onerec_xattention_lock_timing()) { + LOG(INFO) << "OneRec xattention engine host timing, " + << "stage=get_model_output, elapsed_us=" << model_output_us; + } timer.reset(); if (output.beam_sequence_group.defined() && @@ -842,7 +863,12 @@ ForwardOutput RecEngine::OneRecXAttentionEnginePipeline::step( } else { batches[0].process_sample_output(output.sample_output, false); } - COUNTER_ADD(rec_sampling_latency_microseconds, timer.elapsed_microseconds()); + const double process_output_us = timer.elapsed_microseconds(); + COUNTER_ADD(rec_sampling_latency_microseconds, process_output_us); + if (enable_onerec_xattention_lock_timing()) { + LOG(INFO) << "OneRec xattention engine host timing, " + << "stage=process_output, elapsed_us=" << process_output_us; + } batches[0].finish(); return output; @@ -854,6 +880,7 @@ ForwardOutput RecEngine::OneRecXAttentionEnginePipeline::get_model_output( util::get_bool_env("XLLM_DEBUG_ONEREC_ENGINE_TRACE", false); const bool trace_stage_timing = util::get_bool_env("XLLM_DEBUG_ONEREC_XATTN_STAGE_TIMING", false); + const bool trace_lock_timing = enable_onerec_xattention_lock_timing(); Timer engine_timer; auto log_engine_stage = [&](const char* stage_name, const torch::Tensor& tensor = torch::Tensor()) { @@ -865,10 +892,10 @@ ForwardOutput RecEngine::OneRecXAttentionEnginePipeline::get_model_output( << (tensor.defined() ? tensor.sizes() : c10::IntArrayRef{}); }; auto log_engine_timing = [&](const char* stage_name) { - if (!trace_stage_timing) { + if (!trace_stage_timing && !trace_lock_timing) { return; } - LOG(INFO) << "OneRec xattention engine timing, stage=" << stage_name + LOG(INFO) << "OneRec xattention engine host timing, stage=" << stage_name << ", elapsed_us=" << engine_timer.elapsed_microseconds(); engine_timer.reset(); }; diff --git a/xllm/core/framework/batch/onerec_batch_input_builder.cpp b/xllm/core/framework/batch/onerec_batch_input_builder.cpp index a96de9706e..6a3256e7fe 100644 --- a/xllm/core/framework/batch/onerec_batch_input_builder.cpp +++ b/xllm/core/framework/batch/onerec_batch_input_builder.cpp @@ -30,12 +30,11 @@ limitations under the License. namespace xllm { -// Use Meyers' Singleton pattern to avoid static initialization order fiasco -// This ensures the cache is initialized on first use, after all dependencies -// (like PyTorch runtime) are properly initialized. +// Keep the cache per scheduler thread. OneRec xattention multi-stream can build +// multiple inputs concurrently, and CacheData is mutated during construction. OneRecBatchInputBuilder::HighPerformanceCache& OneRecBatchInputBuilder::get_perf_cache() { - static HighPerformanceCache cache; + thread_local HighPerformanceCache cache; cache.ensure_tensors_initialized(); return cache; } @@ -59,7 +58,7 @@ OneRecBatchInputBuilder::OneRecBatchInputBuilder( args_(args), batch_forward_type_(batch_forward_type), thread_pool_(thread_pool) { - // Get references to function-local statics (safe initialization) + // Reset only this thread's reusable scratch vectors. auto& perf_cache = get_perf_cache(); perf_cache.memory_pool.reset(); } @@ -67,7 +66,7 @@ OneRecBatchInputBuilder::OneRecBatchInputBuilder( ForwardInput OneRecBatchInputBuilder::build_rec_forward_input( uint32_t num_decoding_tokens, uint32_t min_decoding_batch_size) { - // Get reference to function-local static cache (safe initialization) + // Get this thread's reusable scratch cache. auto& perf_cache = get_perf_cache(); // ========== Global constant cache ========== diff --git a/xllm/core/kernels/npu/xllm_ops/beam_search_rec.cpp b/xllm/core/kernels/npu/xllm_ops/beam_search_rec.cpp index 62622044bc..de15678fd1 100644 --- a/xllm/core/kernels/npu/xllm_ops/beam_search_rec.cpp +++ b/xllm/core/kernels/npu/xllm_ops/beam_search_rec.cpp @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -178,7 +179,8 @@ void beam_search_rec(const torch::Tensor& logprobs, aclTensor* out_beam_count_prefix_sums_ids = nullptr; aclTensor* out_sequence_ids = nullptr; int32_t device_id = logprobs.device().index(); - aclrtStream stream = c10_npu::getCurrentNPUStream(device_id).stream(); + c10_npu::NPUStream npu_stream = c10_npu::getCurrentNPUStream(device_id); + aclrtStream stream = npu_stream.stream(); create_acltensor(&logprobs_ids, logprobs); create_acltensor(&top_tokens_ids, top_tokens); @@ -209,17 +211,18 @@ void beam_search_rec(const torch::Tensor& logprobs, &executor), "beam_search_rec: failed to get workspace size for REC beam search"); void* workspace_addr = nullptr; + torch::Tensor workspace_tensor; if (workspace_size > 0) { - CHECK_ACL_SUCCESS( - aclrtMalloc(&workspace_addr, workspace_size, ACL_MEM_MALLOC_HUGE_FIRST), - "beam_search_rec: failed to allocate workspace for REC beam search"); + workspace_tensor = torch::empty( + {static_cast(workspace_size)}, + torch::TensorOptions().dtype(torch::kUInt8).device(logprobs.device())); + workspace_addr = workspace_tensor.data_ptr(); + c10_npu::NPUCachingAllocator::recordStream( + workspace_tensor.storage().data_ptr(), npu_stream); } CHECK_ACL_SUCCESS( aclnnBeamSearchGroup(workspace_addr, workspace_size, executor, stream), "beam_search_rec: failed to execute REC beam search"); - CHECK_ACL_SUCCESS( - aclrtSynchronizeStream(stream), - "beam_search_rec: failed to synchronize stream for REC beam search"); aclDestroyTensor(logprobs_ids); aclDestroyTensor(top_tokens_ids); aclDestroyTensor(top_logprobs_ids); @@ -229,11 +232,6 @@ void beam_search_rec(const torch::Tensor& logprobs, aclDestroyTensor(out_log_probs_ids); aclDestroyTensor(out_beam_count_prefix_sums_ids); aclDestroyTensor(out_sequence_ids); - if (workspace_size > 0) { - CHECK_ACL_SUCCESS( - aclrtFree(workspace_addr), - "beam_search_rec: failed to free workspace for REC beam search"); - } } void beam_search_rec(const torch::Tensor& logprobs, diff --git a/xllm/core/kernels/npu/xllm_ops/select_unshared_kv.cpp b/xllm/core/kernels/npu/xllm_ops/select_unshared_kv.cpp index d2ebdf52fc..b73faac421 100644 --- a/xllm/core/kernels/npu/xllm_ops/select_unshared_kv.cpp +++ b/xllm/core/kernels/npu/xllm_ops/select_unshared_kv.cpp @@ -16,6 +16,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -90,7 +91,8 @@ void select_unshared_kv(const torch::Tensor& beam_index, create_acltensor(&block_table_ids, block_table); int32_t device_id = beam_index.device().index(); - aclrtStream stream = c10_npu::getCurrentNPUStream(device_id).stream(); + c10_npu::NPUStream npu_stream = c10_npu::getCurrentNPUStream(device_id); + aclrtStream stream = npu_stream.stream(); uint64_t workspace_size = 0; aclOpExecutor* executor = nullptr; @@ -109,24 +111,23 @@ void select_unshared_kv(const torch::Tensor& beam_index, &executor), "select_unshared_kv: failed to get workspace size"); void* workspace_addr = nullptr; + torch::Tensor workspace_tensor; if (workspace_size > 0) { - CHECK_ACL_SUCCESS( - aclrtMalloc(&workspace_addr, workspace_size, ACL_MEM_MALLOC_HUGE_FIRST), - "select_unshared_kv: failed to allocate workspace"); + workspace_tensor = torch::empty({static_cast(workspace_size)}, + torch::TensorOptions() + .dtype(torch::kUInt8) + .device(beam_index.device())); + workspace_addr = workspace_tensor.data_ptr(); + c10_npu::NPUCachingAllocator::recordStream( + workspace_tensor.storage().data_ptr(), npu_stream); } CHECK_ACL_SUCCESS( aclnnSelectUnsharedKV(workspace_addr, workspace_size, executor, stream), "select_unshared_kv: failed to reorder caches"); - CHECK_ACL_SUCCESS(aclrtSynchronizeStream(stream), - "select_unshared_kv: failed to synchronize stream"); aclDestroyTensor(beam_index_ids); aclDestroyTensor(group_offset_ids); aclDestroyTensor(block_table_ids); aclDestroyTensorList(x_key_block_list_ids); aclDestroyTensorList(x_value_block_list_ids); - if (workspace_size > 0) { - CHECK_ACL_SUCCESS(aclrtFree(workspace_addr), - "select_unshared_kv: failed to free workspace"); - } } } // namespace xllm::kernel::npu diff --git a/xllm/core/runtime/rec_worker_impl.cpp b/xllm/core/runtime/rec_worker_impl.cpp index 62bd6126d7..2928c7122c 100644 --- a/xllm/core/runtime/rec_worker_impl.cpp +++ b/xllm/core/runtime/rec_worker_impl.cpp @@ -50,12 +50,22 @@ limitations under the License. #include "framework/state_dict/rec_vocab_dict.h" #include "models/model_registry.h" #include "util/env_var.h" +#include "util/scope_guard.h" #include "util/timer.h" namespace xllm { namespace { +constexpr const char* kOnerecXAttentionSerializeHostStagesEnv = + "XLLM_ONEREC_XATTN_SERIALIZE_HOST_STAGES"; +constexpr const char* kOnerecXAttentionSerializePrepareEnv = + "XLLM_ONEREC_XATTN_SERIALIZE_PREPARE"; +constexpr const char* kOnerecXAttentionSerializeModelForwardEnv = + "XLLM_ONEREC_XATTN_SERIALIZE_MODEL_FORWARD"; +constexpr const char* kOnerecXAttentionLockTimingEnv = + "XLLM_DEBUG_ONEREC_XATTN_LOCK_TIMING"; + RecVocabDict* get_onerec_vocab_dict(const std::string& model_weights_path) { if (model_weights_path.empty()) { return nullptr; @@ -85,6 +95,58 @@ bool enable_onerec_xattention_stage_timing() { return util::get_bool_env("XLLM_DEBUG_ONEREC_XATTN_STAGE_TIMING", false); } +bool enable_onerec_xattention_lock_timing() { + static const bool enable_lock_timing = + util::get_bool_env(kOnerecXAttentionLockTimingEnv, false); + return enable_lock_timing; +} + +bool enable_rec_pipeline_concurrency_debug() { + return util::get_bool_env("XLLM_DEBUG_REC_PIPELINE_CONCURRENCY", false); +} + +bool serialize_onerec_xattention_host_stages() { + static const bool serialize_host_stages = + util::get_bool_env(kOnerecXAttentionSerializeHostStagesEnv, true); + return serialize_host_stages; +} + +bool serialize_onerec_xattention_prepare() { + // Keep serialization enabled by default. The host-stages switch remains a + // bulk default, while the per-stage switches isolate lock-off experiments. + static const bool serialize_prepare = + util::get_bool_env(kOnerecXAttentionSerializePrepareEnv, + serialize_onerec_xattention_host_stages()); + return serialize_prepare; +} + +bool serialize_onerec_xattention_model_forward() { + static const bool serialize_model_forward = + util::get_bool_env(kOnerecXAttentionSerializeModelForwardEnv, + serialize_onerec_xattention_host_stages()); + return serialize_model_forward; +} + +#if defined(USE_NPU) +bool should_serialize_onerec_xattention_prepare(int64_t max_concurrency) { + return max_concurrency > 1 && serialize_onerec_xattention_prepare(); +} + +bool should_serialize_onerec_xattention_model_forward(int64_t max_concurrency) { + return max_concurrency > 1 && serialize_onerec_xattention_model_forward(); +} +#else +bool should_serialize_onerec_xattention_prepare(int64_t max_concurrency) { + UNUSED_PARAMETER(max_concurrency); + return false; +} + +bool should_serialize_onerec_xattention_model_forward(int64_t max_concurrency) { + UNUSED_PARAMETER(max_concurrency); + return false; +} +#endif + #if defined(USE_NPU) torch::Tensor int32_vector_to_device_tensor(const std::vector& values, const torch::Device& device) { @@ -1068,6 +1130,40 @@ ForwardInput RecWorkerImpl::OneRecXAttentionWorkPipeline::prepare_inputs( void RecWorkerImpl::OneRecXAttentionWorkPipeline::prepare_work_before_execute( const ForwardInput& inputs, ForwardInput& processed_inputs) { + std::unique_lock prepare_lock( + runtime_.worker.onerec_xattention_prepare_mutex_, std::defer_lock); + const bool serialize_prepare = should_serialize_onerec_xattention_prepare( + runtime_.worker.options_.rec_worker_max_concurrency()); + const bool trace_lock_timing = enable_onerec_xattention_lock_timing(); + double prepare_lock_wait_us = 0.0; + std::optional prepare_lock_wait_timer; + if (trace_lock_timing) { + prepare_lock_wait_timer.emplace(); + } + if (serialize_prepare) { + prepare_lock.lock(); + if (prepare_lock_wait_timer.has_value()) { + prepare_lock_wait_us = prepare_lock_wait_timer->elapsed_microseconds(); + } + } + std::optional prepare_lock_hold_timer; + if (trace_lock_timing) { + prepare_lock_hold_timer.emplace(); + } + SCOPE_GUARD([&] { + if (!trace_lock_timing) { + return; + } + const double prepare_elapsed_us = + prepare_lock_hold_timer->elapsed_microseconds(); + LOG(INFO) << "OneRec xattention lock timing, stage=prepare, serialize=" + << serialize_prepare << ", wait_us=" << prepare_lock_wait_us + << ", elapsed_us=" << prepare_elapsed_us << ", lock_hold_us=" + << (serialize_prepare ? prepare_elapsed_us : 0.0) + << ", max_concurrency=" + << runtime_.worker.options_.rec_worker_max_concurrency(); + }); + const bool trace_stage_timing = enable_onerec_xattention_stage_timing(); Timer prepare_timer; auto log_prepare_timing = [&](const char* stage_name) { @@ -1079,7 +1175,20 @@ void RecWorkerImpl::OneRecXAttentionWorkPipeline::prepare_work_before_execute( << ", elapsed_us=" << prepare_timer.elapsed_microseconds(); prepare_timer.reset(); }; + std::optional prepare_host_timer; + if (trace_lock_timing) { + prepare_host_timer.emplace(); + } + auto log_prepare_host_timing = [&](const char* stage_name) { + if (!trace_lock_timing) { + return; + } + LOG(INFO) << "OneRec xattention prepare host timing, stage=" << stage_name + << ", elapsed_us=" << prepare_host_timer->elapsed_microseconds(); + prepare_host_timer->reset(); + }; RecWorkPipeline::prepare_work_before_execute(inputs, processed_inputs); + log_prepare_host_timing("base_to_device"); log_prepare_timing("base_to_device"); #if defined(USE_NPU) @@ -1164,6 +1273,7 @@ void RecWorkerImpl::OneRecXAttentionWorkPipeline::prepare_work_before_execute( torch::TensorOptions() .dtype(torch::kInt32) .device(runtime_.worker.device())); + log_prepare_host_timing("cache_prepare"); log_prepare_timing("cache_prepare"); const int32_t beam_width = @@ -1192,18 +1302,21 @@ void RecWorkerImpl::OneRecXAttentionWorkPipeline::prepare_work_before_execute( } if (!onerec_params.decoder_context_embedding.defined()) { + log_prepare_host_timing("metadata_prepare"); log_prepare_timing("metadata_prepare"); return; } if (onerec_params.decoder_context_embedding.scalar_type() == runtime_.worker.dtype()) { + log_prepare_host_timing("metadata_prepare"); log_prepare_timing("metadata_prepare"); return; } onerec_params.decoder_context_embedding = onerec_params.decoder_context_embedding.to(runtime_.worker.dtype()); + log_prepare_host_timing("metadata_prepare"); log_prepare_timing("metadata_prepare"); } @@ -1387,6 +1500,51 @@ std::optional RecWorkerImpl::OneRecXAttentionWorkPipeline::step( }; #endif + auto forward_model = [&](const torch::Tensor& tokens, + const torch::Tensor& positions, + const ModelInputParams& params) { + std::unique_lock forward_lock( + runtime_.worker.onerec_xattention_model_forward_mutex_, + std::defer_lock); + const bool serialize_model_forward = + should_serialize_onerec_xattention_model_forward( + runtime_.worker.options_.rec_worker_max_concurrency()); + const bool trace_lock_timing = enable_onerec_xattention_lock_timing(); + double forward_lock_wait_us = 0.0; + std::optional forward_lock_wait_timer; + if (trace_lock_timing) { + forward_lock_wait_timer.emplace(); + } + if (serialize_model_forward) { + forward_lock.lock(); + if (forward_lock_wait_timer.has_value()) { + forward_lock_wait_us = + forward_lock_wait_timer->elapsed_microseconds(); + } + } + std::optional forward_lock_hold_timer; + if (trace_lock_timing) { + forward_lock_hold_timer.emplace(); + } + auto model_output = runtime_.executor->forward( + tokens, positions, runtime_.worker.kv_caches_, params); + if (trace_lock_timing) { + const double forward_elapsed_us = + forward_lock_hold_timer->elapsed_microseconds(); + LOG(INFO) << "OneRec xattention lock timing, stage=model_forward, " + << "serialize=" << serialize_model_forward + << ", wait_us=" << forward_lock_wait_us + << ", elapsed_us=" << forward_elapsed_us << ", lock_hold_us=" + << (serialize_model_forward ? forward_elapsed_us : 0.0) + << ", max_concurrency=" + << runtime_.worker.options_.rec_worker_max_concurrency() + << ", round=" << get_onerec_decode_round(*round_params) + << ", rec_stage=" + << static_cast(round_params->rec_stage); + } + return model_output; + }; + torch::Tensor hidden_states; if (round_params->rec_stage == OneRecModelInputParams::RecStage::PREFILL) { if (!round_params->is_first_prefill) { @@ -1400,11 +1558,8 @@ std::optional RecWorkerImpl::OneRecXAttentionWorkPipeline::step( decoder_onerec_params.is_encoder_forward = false; decoder_onerec_params.has_encoder_output = round_params->has_encoder_output; - auto model_output = - runtime_.executor->forward(mutable_input.token_ids, - mutable_input.positions, - runtime_.worker.kv_caches_, - decoder_params); + auto model_output = forward_model( + mutable_input.token_ids, mutable_input.positions, decoder_params); #if defined(USE_NPU) validate_selected_token_idxes_stage("decoder_forward"); #endif @@ -1436,11 +1591,8 @@ std::optional RecWorkerImpl::OneRecXAttentionWorkPipeline::step( encoder_tokens = round_params->encoder_token_ids; } - auto encoder_output = - runtime_.executor->forward(encoder_tokens, - round_params->encoder_positions, - runtime_.worker.kv_caches_, - encoder_params); + auto encoder_output = forward_model( + encoder_tokens, round_params->encoder_positions, encoder_params); #if defined(USE_NPU) validate_selected_token_idxes_stage("encoder_forward"); #endif @@ -1451,11 +1603,8 @@ std::optional RecWorkerImpl::OneRecXAttentionWorkPipeline::step( decoder_onerec_params.is_encoder_forward = false; decoder_onerec_params.has_encoder_output = encoder_output.hidden_states.defined(); - auto model_output = - runtime_.executor->forward(mutable_input.token_ids, - mutable_input.positions, - runtime_.worker.kv_caches_, - decoder_params); + auto model_output = forward_model( + mutable_input.token_ids, mutable_input.positions, decoder_params); #if defined(USE_NPU) validate_selected_token_idxes_stage("decoder_forward"); #endif @@ -1472,10 +1621,8 @@ std::optional RecWorkerImpl::OneRecXAttentionWorkPipeline::step( decoder_onerec_params.is_encoder_forward = false; decoder_onerec_params.has_encoder_output = round_params->has_encoder_output; - auto model_output = runtime_.executor->forward(mutable_input.token_ids, - mutable_input.positions, - runtime_.worker.kv_caches_, - decoder_params); + auto model_output = forward_model( + mutable_input.token_ids, mutable_input.positions, decoder_params); #if defined(USE_NPU) validate_selected_token_idxes_stage("decode_forward"); #endif @@ -2851,7 +2998,14 @@ bool RecWorkerImpl::init_model(ModelContext& context) { << "Unsupported rec model_type: " << model_type; // Create concurrent pipeline (not base class pipeline) - auto pipeline_type = get_rec_pipeline_type(rec_model_kind_); + pipeline_type_ = get_rec_pipeline_type(rec_model_kind_); +#if defined(USE_NPU) + CHECK(!(pipeline_type_ == RecPipelineType::kOneRecXAttentionPipeline && + (FLAGS_enable_graph || options_.enable_graph()) && + options_.rec_worker_max_concurrency() > 1)) + << "NPU OneRec xattention multi-stream does not support graph mode yet. " + << "Disable graph mode or set rec_worker_max_concurrency=1."; +#endif // Reserve space for model instances work_pipelines_.reserve(options_.rec_worker_max_concurrency()); @@ -2885,7 +3039,7 @@ bool RecWorkerImpl::init_model(ModelContext& context) { runtime.model.get(), runtime.worker.device()); } - work_pipelines_.emplace_back(create_pipeline(pipeline_type, runtime)); + work_pipelines_.emplace_back(create_pipeline(pipeline_type_, runtime)); index_queue_.enqueue(i); } @@ -2902,7 +3056,8 @@ bool RecWorkerImpl::init_model(ModelContext& context) { } LOG(INFO) << "Created " << work_pipelines_.size() - << " pipelines for concurrent execution"; + << " pipelines for concurrent execution, pipeline_type=" + << static_cast(pipeline_type_); return true; } @@ -3007,14 +3162,34 @@ folly::SemiFuture> RecWorkerImpl::step_async( size_t index; index_queue_.wait_dequeue(index); auto future = promise.getSemiFuture(); + if (enable_rec_pipeline_concurrency_debug()) { + LOG(INFO) << "RecWorkerImpl leased pipeline, index=" << index + << ", max_concurrency=" << options_.rec_worker_max_concurrency() + << ", pipeline_type=" << static_cast(pipeline_type_); + } // Use schedule() to assign tasks, letting ThreadPool automatically select // idle threads The logic for allocating instance_id happens when the task // executes (see lambda below) step_threadpool_->schedule_with_tid( - [this, &input, index, promise = std::move(promise)]() mutable { + [this, input, index, promise = std::move(promise)]() mutable { auto stream_guard = work_pipelines_[index]->runtime().stream->set_stream_guard(); + SCOPE_GUARD([this, index] { index_queue_.enqueue(index); }); + if (enable_rec_pipeline_concurrency_debug()) { + LOG(INFO) << "RecWorkerImpl running pipeline, index=" << index + << ", stream=" << *work_pipelines_[index]->runtime().stream; + } + +#if defined(USE_NPU) + aclrtStream current_stream = + c10_npu::getCurrentNPUStream(device_.index()).stream(); + atb::Context* atb_context = const_cast( + work_pipelines_[index]->runtime().context->get_atb_context()); + CHECK(atb_context != nullptr) + << "ATB context is null for pipeline " << index; + atb_context->SetExecuteStream(current_stream); +#endif ForwardInput input_on_device; work_pipelines_[index]->prepare_work_before_execute(input, @@ -3027,8 +3202,6 @@ folly::SemiFuture> RecWorkerImpl::step_async( const auto output = work_pipelines_[index]->step(input_on_device); promise.setValue(output); - - index_queue_.enqueue(index); }, index); diff --git a/xllm/core/runtime/rec_worker_impl.h b/xllm/core/runtime/rec_worker_impl.h index 247f2fec41..d38264807f 100644 --- a/xllm/core/runtime/rec_worker_impl.h +++ b/xllm/core/runtime/rec_worker_impl.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include #include #include @@ -373,9 +374,13 @@ class RecWorkerImpl : public LLMWorkerImpl { std::vector> work_pipelines_; RecModelKind rec_model_kind_ = RecModelKind::kNone; + RecPipelineType pipeline_type_ = RecPipelineType::kLlmRecDefault; std::unique_ptr step_threadpool_; + std::mutex onerec_xattention_prepare_mutex_; + std::mutex onerec_xattention_model_forward_mutex_; + moodycamel::BlockingConcurrentQueue index_queue_; }; diff --git a/xllm/core/scheduler/fixed_steps_scheduler.cpp b/xllm/core/scheduler/fixed_steps_scheduler.cpp index f58b472755..16b8dadcb2 100644 --- a/xllm/core/scheduler/fixed_steps_scheduler.cpp +++ b/xllm/core/scheduler/fixed_steps_scheduler.cpp @@ -22,10 +22,10 @@ limitations under the License. #include #include -#include #include #include #include +#include #include "common/metrics.h" #include "common/types.h" @@ -36,10 +36,27 @@ limitations under the License. #include "framework/request/rec_type.h" #include "framework/request/request.h" #include "framework/request/sequence.h" +#include "util/env_var.h" #include "util/rec_model_utils.h" +#include "util/scope_guard.h" +#include "util/timer.h" namespace xllm { +namespace { + +bool enable_rec_pipeline_concurrency_debug() { + return util::get_bool_env("XLLM_DEBUG_REC_PIPELINE_CONCURRENCY", false); +} + +bool enable_onerec_xattention_lock_timing() { + static const bool enable_lock_timing = + util::get_bool_env("XLLM_DEBUG_ONEREC_XATTN_LOCK_TIMING", false); + return enable_lock_timing; +} + +} // namespace + FixedStepsScheduler::FixedStepsScheduler(Engine* engine, const Options& options) : ContinuousScheduler(engine, options), step_semaphore_( @@ -79,10 +96,23 @@ void FixedStepsScheduler::handle_prefill_requests( bool blocks_exhausted = false; const bool requires_kv_cache = scheduler_pipeline_ && scheduler_pipeline_->requires_kv_cache(); + const size_t max_prefill_requests_per_step = + scheduler_pipeline_ != nullptr + ? std::max( + 1, scheduler_pipeline_->max_prefill_requests_per_step(*this)) + : std::numeric_limits::max(); + const auto can_schedule_under_kv_threshold = [this]() { + if (scheduler_pipeline_ == nullptr || + !scheduler_pipeline_->respects_prefill_memory_threshold(*this)) { + return true; + } + return kv_cache_manager_->kv_cache_utilization() < + FLAGS_prefill_scheduling_memory_usage_threshold; + }; while (!waiting_priority_queue_.empty() && remaining_seq_budget > 0 && remaining_token_budget > 0 && - kv_cache_manager_->kv_cache_utilization() < - FLAGS_prefill_scheduling_memory_usage_threshold) { + running_requests_.size() < max_prefill_requests_per_step && + can_schedule_under_kv_threshold()) { std::shared_ptr request(waiting_priority_queue_.top()); if (request->finished() || request->cancelled()) { if (requires_kv_cache) { @@ -168,8 +198,13 @@ void FixedStepsScheduler::handle_prefill_requests( prefill_sequences_budget.end()); } + const bool resource_temporarily_exhausted = + blocks_exhausted || !can_schedule_under_kv_threshold(); + const bool should_defer_resource_exhausted = + resource_temporarily_exhausted && scheduler_pipeline_ != nullptr && + scheduler_pipeline_->should_defer_resource_exhausted(*this); if (running_sequences_.empty() && !waiting_priority_queue_.empty() && - running_queue_->empty()) { + !should_defer_resource_exhausted && running_queue_->empty()) { LOG(ERROR) << "Request prompt is too long, no enough budget/memory to schedule " "a single sequence."; @@ -338,7 +373,18 @@ ScheduleResult FixedStepsScheduler::schedule_request( void FixedStepsScheduler::step(const absl::Duration& timeout) { if (!options_.enable_schedule_overlap()) { // get a new batch of requests + const bool trace_lock_timing = enable_onerec_xattention_lock_timing(); + const bool trace_multistream_timing = + trace_lock_timing && options_.rec_worker_max_concurrency() > 1; + std::optional schedule_timer; + if (trace_multistream_timing) { + schedule_timer.emplace(); + } ScheduleResult result = schedule_request(timeout); + double schedule_request_us = 0.0; + if (schedule_timer.has_value()) { + schedule_request_us = schedule_timer->elapsed_microseconds(); + } bool all_empty = std::all_of(result.batches.begin(), result.batches.end(), @@ -346,16 +392,48 @@ void FixedStepsScheduler::step(const absl::Duration& timeout) { if (all_empty) { return; } + const size_t batch_request_count = + result.batches.empty() ? 0 : result.batches[0].get_sequences().size(); + if (trace_multistream_timing) { + LOG(INFO) << "OneRec xattention scheduler host timing, " + << "stage=schedule_request, elapsed_us=" << schedule_request_us + << ", batch_request_count=" << batch_request_count + << ", waiting_requests=" << waiting_priority_queue_.size() + << ", in_flight_steps=" + << in_flight_steps_.load(std::memory_order_acquire); + } // Submit task to thread pool for asynchronous execution // After engine_->step() completes, process finished/cancelled requests auto function = [this, + trace_lock_timing, + batch_request_count, batches = std::move(result.batches), requests = std::move(result.requests), sequences = std::move(result.sequences)]() mutable { + SCOPE_GUARD([this] { + if (options_.rec_worker_max_concurrency() > 1) { + in_flight_steps_.fetch_sub(1, std::memory_order_acq_rel); + step_semaphore_.release(); + } + }); + std::optional engine_step_timer; + if (trace_lock_timing && options_.rec_worker_max_concurrency() > 1) { + engine_step_timer.emplace(); + } engine_->step(batches); + if (trace_lock_timing && options_.rec_worker_max_concurrency() > 1) { + LOG(INFO) << "OneRec xattention scheduler host timing, " + << "stage=engine_step, elapsed_us=" + << engine_step_timer->elapsed_microseconds() + << ", batch_request_count=" << batch_request_count; + } // After step completes, check and process finished/cancelled requests + std::optional response_timer; + if (trace_lock_timing && options_.rec_worker_max_concurrency() > 1) { + response_timer.emplace(); + } std::vector> finished_requests; for (auto& request : requests) { if (request) { @@ -371,14 +449,41 @@ void FixedStepsScheduler::step(const absl::Duration& timeout) { if (!finished_requests.empty()) { response_processor_->process_completed_requests(finished_requests); } - - if (options_.rec_worker_max_concurrency() > 1) { - step_semaphore_.release(); + if (trace_lock_timing && options_.rec_worker_max_concurrency() > 1) { + LOG(INFO) << "OneRec xattention scheduler host timing, " + << "stage=process_finished_requests, elapsed_us=" + << response_timer->elapsed_microseconds() + << ", finished_requests=" << finished_requests.size(); } }; if (options_.rec_worker_max_concurrency() > 1) { + std::optional semaphore_timer; + if (trace_lock_timing) { + semaphore_timer.emplace(); + } step_semaphore_.acquire(); + double semaphore_wait_us = 0.0; + if (semaphore_timer.has_value()) { + semaphore_wait_us = semaphore_timer->elapsed_microseconds(); + } + const size_t previous_in_flight = + in_flight_steps_.fetch_add(1, std::memory_order_acq_rel); + if (trace_lock_timing) { + LOG(INFO) << "OneRec xattention scheduler host timing, " + << "stage=semaphore_acquire, elapsed_us=" << semaphore_wait_us + << ", previous_in_flight=" << previous_in_flight + << ", max_concurrency=" + << options_.rec_worker_max_concurrency(); + } + if (enable_rec_pipeline_concurrency_debug()) { + LOG(INFO) << "FixedStepsScheduler submit async step, " + << "scheduler_concurrency=" + << options_.rec_worker_max_concurrency() + << ", previous_in_flight=" << previous_in_flight + << ", batch_request_count=" << batch_request_count + << ", waiting_requests=" << waiting_priority_queue_.size(); + } step_threadpool_->schedule(function); } else { function(); @@ -438,6 +543,14 @@ FixedStepsScheduler::OneRecXAttentionSchedulerPipeline::create_batches( scheduler.kv_cache_manager_->get_swap_block_transfer_infos()); } +size_t FixedStepsScheduler::OneRecXAttentionSchedulerPipeline:: + max_prefill_requests_per_step(const FixedStepsScheduler& scheduler) const { + if (scheduler.options_.rec_worker_max_concurrency() > 1) { + return 1; + } + return std::numeric_limits::max(); +} + bool FixedStepsScheduler::OneRecXAttentionSchedulerPipeline::allocate_kv_cache( KVCacheManager* kv_cache_manager, Sequence* sequence) { @@ -458,6 +571,19 @@ bool FixedStepsScheduler::OneRecXAttentionSchedulerPipeline::allocate_kv_cache( num_tokens + max_generated_tokens); } +bool FixedStepsScheduler::OneRecXAttentionSchedulerPipeline:: + respects_prefill_memory_threshold( + const FixedStepsScheduler& scheduler) const { + return scheduler.options_.rec_worker_max_concurrency() <= 1; +} + +bool FixedStepsScheduler::OneRecXAttentionSchedulerPipeline:: + should_defer_resource_exhausted( + const FixedStepsScheduler& scheduler) const { + return scheduler.options_.rec_worker_max_concurrency() > 1 && + scheduler.in_flight_steps_.load(std::memory_order_acquire) > 0; +} + std::vector FixedStepsScheduler::RecMultiRoundSchedulerPipeline::create_batches( FixedStepsScheduler& scheduler, diff --git a/xllm/core/scheduler/fixed_steps_scheduler.h b/xllm/core/scheduler/fixed_steps_scheduler.h index 2af25d51f1..a153c7b6d6 100644 --- a/xllm/core/scheduler/fixed_steps_scheduler.h +++ b/xllm/core/scheduler/fixed_steps_scheduler.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -65,6 +66,18 @@ class FixedStepsScheduler final : public ContinuousScheduler { virtual std::vector create_batches(FixedStepsScheduler& scheduler, BatchFactory* batch_factory) = 0; virtual bool requires_kv_cache() const = 0; + virtual size_t max_prefill_requests_per_step( + const FixedStepsScheduler& /*scheduler*/) const { + return std::numeric_limits::max(); + } + virtual bool respects_prefill_memory_threshold( + const FixedStepsScheduler& /*scheduler*/) const { + return true; + } + virtual bool should_defer_resource_exhausted( + const FixedStepsScheduler& /*scheduler*/) const { + return false; + } // Allocate KV cache for sequence, implemented by each pipeline virtual bool allocate_kv_cache(KVCacheManager* kv_cache_manager, Sequence* sequence) = 0; @@ -95,6 +108,12 @@ class FixedStepsScheduler final : public ContinuousScheduler { std::vector create_batches(FixedStepsScheduler& scheduler, BatchFactory* batch_factory) override; bool requires_kv_cache() const override { return true; } + size_t max_prefill_requests_per_step( + const FixedStepsScheduler& scheduler) const override; + bool respects_prefill_memory_threshold( + const FixedStepsScheduler& scheduler) const override; + bool should_defer_resource_exhausted( + const FixedStepsScheduler& scheduler) const override; bool allocate_kv_cache(KVCacheManager* kv_cache_manager, Sequence* sequence) override; }; @@ -133,6 +152,8 @@ class FixedStepsScheduler final : public ContinuousScheduler { // Semaphore to control concurrent execution of step() std::counting_semaphore<10000> step_semaphore_; + + std::atomic in_flight_steps_{0}; }; } // namespace xllm