Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 128 additions & 1 deletion tests/core/scheduler/fixed_steps_scheduler_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,21 @@ limitations under the License.

#include "fixed_steps_scheduler.h"

#include <absl/synchronization/notification.h>
#include <absl/time/time.h>
#include <gtest/gtest.h>

#include <algorithm>
#include <cstdint>
#include <limits>
#include <mutex>
#include <vector>

#include "common/global_flags.h"
#include "continuous_scheduler.h"
#include "distributed_runtime/engine.h"
#include "framework/request/rec_type.h"
#include "util/scope_guard.h"

namespace xllm {

Expand Down Expand Up @@ -71,7 +77,16 @@ class FakeEngine : public Engine {
fake_block_manager_ = std::make_unique<BlockManagerPool>(opt, 1);
}
ForwardOutput step(std::vector<Batch>& batch) override {
(void)batch;
{
std::lock_guard<std::mutex> lock(mutex_);
++step_calls_;
step_batch_sizes_.emplace_back(
batch.empty() ? 0 : batch[0].get_sequences().size());
if (step_calls_ >= notify_after_step_calls_ &&
!step_calls_ready_.HasBeenNotified()) {
step_calls_ready_.Notify();
}
}
return ForwardOutput();
}
void update_last_step_result(std::vector<Batch>& batch) override {
Expand All @@ -94,9 +109,33 @@ class FakeEngine : public Engine {
}
bool init() override { return true; }

int32_t step_calls() const {
std::lock_guard<std::mutex> lock(mutex_);
return step_calls_;
}

std::vector<size_t> step_batch_sizes() const {
std::lock_guard<std::mutex> lock(mutex_);
return step_batch_sizes_;
}

void notify_after_step_calls(int32_t step_calls) {
std::lock_guard<std::mutex> lock(mutex_);
notify_after_step_calls_ = step_calls;
}

bool wait_for_step_calls(absl::Duration timeout) const {
return step_calls_ready_.WaitForNotificationWithTimeout(timeout);
}

private:
std::unique_ptr<Tokenizer> fake_tokenizer_;
std::unique_ptr<BlockManagerPool> fake_block_manager_;
mutable std::mutex mutex_;
int32_t step_calls_ = 0;
int32_t notify_after_step_calls_ = std::numeric_limits<int32_t>::max();
std::vector<size_t> step_batch_sizes_;
mutable absl::Notification step_calls_ready_;
};

ContinuousScheduler::Options CreateOptions(
Expand Down Expand Up @@ -228,4 +267,92 @@ TEST(FixedStepsSchedulerTest, StepCompletesWithRequest) {
EXPECT_NO_THROW(scheduler.step(absl::Milliseconds(500)));
}

TEST(FixedStepsSchedulerTest,
OneRecXAttentionMultiConcurrencyPrepareBatchKeepsOneRequest) {
const bool old_enable_prefix_cache = FLAGS_enable_prefix_cache;
const int32_t old_max_decode_rounds = FLAGS_max_decode_rounds;
const double old_prefill_threshold =
FLAGS_prefill_scheduling_memory_usage_threshold;
SCOPE_GUARD([&] {
FLAGS_enable_prefix_cache = old_enable_prefix_cache;
FLAGS_max_decode_rounds = old_max_decode_rounds;
FLAGS_prefill_scheduling_memory_usage_threshold = old_prefill_threshold;
});

FLAGS_enable_prefix_cache = false;
FLAGS_max_decode_rounds = 2;
FLAGS_prefill_scheduling_memory_usage_threshold = 1.0;

std::unique_ptr<FakeEngine> engine =
std::make_unique<FakeEngine>(/*num_blocks=*/64, /*block_size=*/128);
ContinuousScheduler::Options opt = CreateOptions(
/*max_tokens_per_batch=*/10000,
/*max_seqs_per_batch=*/256,
/*dp_size=*/1,
/*enable_schedule_overlap=*/false,
/*rec_worker_max_concurrency=*/2);
FixedStepsScheduler scheduler(engine.get(), opt);
std::vector<std::shared_ptr<Request>> requests = GenRequests(
/*prompt_lens=*/{1, 1}, /*max_tokens=*/{1, 1}, RecType::kOneRec);
for (std::shared_ptr<Request>& req : requests) {
scheduler.add_request(req);
}

ContinuousScheduler* base = &scheduler;
std::vector<Batch> batches = base->prepare_batch_test();
ASSERT_FALSE(batches.empty());
EXPECT_EQ(batches[0].get_sequences().size(), 1u);
EXPECT_EQ(base->get_running_requests().size(), 1u);
EXPECT_EQ(base->get_waiting_requests().size(), 1u);
}

TEST(FixedStepsSchedulerTest,
OneRecXAttentionMultiConcurrencySchedulesSeparateSteps) {
const bool old_enable_prefix_cache = FLAGS_enable_prefix_cache;
const int32_t old_max_decode_rounds = FLAGS_max_decode_rounds;
const double old_prefill_threshold =
FLAGS_prefill_scheduling_memory_usage_threshold;
SCOPE_GUARD([&] {
FLAGS_enable_prefix_cache = old_enable_prefix_cache;
FLAGS_max_decode_rounds = old_max_decode_rounds;
FLAGS_prefill_scheduling_memory_usage_threshold = old_prefill_threshold;
});

FLAGS_enable_prefix_cache = false;
FLAGS_max_decode_rounds = 2;
FLAGS_prefill_scheduling_memory_usage_threshold = 1.0;

std::unique_ptr<FakeEngine> engine =
std::make_unique<FakeEngine>(/*num_blocks=*/3, /*block_size=*/128);
engine->notify_after_step_calls(/*step_calls=*/2);
{
ContinuousScheduler::Options opt = CreateOptions(
/*max_tokens_per_batch=*/10000,
/*max_seqs_per_batch=*/256,
/*dp_size=*/1,
/*enable_schedule_overlap=*/false,
/*rec_worker_max_concurrency=*/2);
FixedStepsScheduler scheduler(engine.get(), opt);

std::vector<std::shared_ptr<Request>> requests =
GenRequests(/*prompt_lens=*/{1, 1, 1},
/*max_tokens=*/{1, 1, 1},
RecType::kOneRec);
for (std::shared_ptr<Request>& req : requests) {
scheduler.add_request(req);
}

scheduler.step(absl::ZeroDuration());
scheduler.step(absl::ZeroDuration());

EXPECT_TRUE(engine->wait_for_step_calls(absl::Milliseconds(500)));
EXPECT_EQ(engine->step_calls(), 2);
const std::vector<size_t> step_batch_sizes = engine->step_batch_sizes();
ASSERT_EQ(step_batch_sizes.size(), 2u);
EXPECT_EQ(step_batch_sizes[0], 1u);
EXPECT_EQ(step_batch_sizes[1], 1u);
EXPECT_EQ(scheduler.get_waiting_requests().size(), 1u);
}
}

} // namespace xllm
2 changes: 1 addition & 1 deletion third_party/xllm_atb_layers
Submodule xllm_atb_layers updated from 49e2cf to 3555e2
41 changes: 34 additions & 7 deletions xllm/core/distributed_runtime/rec_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ namespace xllm {
namespace {

constexpr int64_t kMinimalOneRecMetadataKVBlocks = 2;
constexpr const char* kOnerecXAttentionLockTimingEnv =
"XLLM_DEBUG_ONEREC_XATTN_LOCK_TIMING";

bool enable_onerec_xattention_lock_timing() {
static const bool enable_lock_timing =
util::get_bool_env(kOnerecXAttentionLockTimingEnv, false);
return enable_lock_timing;
}

} // namespace

Expand Down Expand Up @@ -813,7 +821,11 @@ RecEngine::OneRecXAttentionEnginePipeline::OneRecXAttentionEnginePipeline(

int64_t RecEngine::OneRecXAttentionEnginePipeline::minimal_kv_cache_blocks()
const {
return kMinimalOneRecMetadataKVBlocks;
const int64_t max_concurrency =
std::max<int64_t>(1, engine_.options_.rec_worker_max_concurrency());
// BlockManagerImpl reserves block 0 for padding, so xattention needs one
// allocatable metadata block per concurrent scheduler step plus padding.
return std::max(kMinimalOneRecMetadataKVBlocks, max_concurrency + 1);
}

ForwardOutput RecEngine::OneRecXAttentionEnginePipeline::step(
Expand All @@ -824,16 +836,25 @@ ForwardOutput RecEngine::OneRecXAttentionEnginePipeline::step(

Timer timer;
auto forward_inputs = engine_.workers_[0]->prepare_inputs(batches[0]);
COUNTER_ADD(prepare_input_latency_microseconds, timer.elapsed_microseconds());
const double prepare_inputs_us = timer.elapsed_microseconds();
COUNTER_ADD(prepare_input_latency_microseconds, prepare_inputs_us);
if (enable_onerec_xattention_lock_timing()) {
LOG(INFO) << "OneRec xattention engine host timing, "
<< "stage=prepare_inputs, elapsed_us=" << prepare_inputs_us;
}

if (!forward_inputs.token_ids.defined()) {
return {};
}

timer.reset();
const auto& output = get_model_output(forward_inputs);
COUNTER_ADD(rec_first_token_latency_microseconds,
timer.elapsed_microseconds());
const double model_output_us = timer.elapsed_microseconds();
COUNTER_ADD(rec_first_token_latency_microseconds, model_output_us);
if (enable_onerec_xattention_lock_timing()) {
LOG(INFO) << "OneRec xattention engine host timing, "
<< "stage=get_model_output, elapsed_us=" << model_output_us;
}

timer.reset();
if (output.beam_sequence_group.defined() &&
Expand All @@ -842,7 +863,12 @@ ForwardOutput RecEngine::OneRecXAttentionEnginePipeline::step(
} else {
batches[0].process_sample_output(output.sample_output, false);
}
COUNTER_ADD(rec_sampling_latency_microseconds, timer.elapsed_microseconds());
const double process_output_us = timer.elapsed_microseconds();
COUNTER_ADD(rec_sampling_latency_microseconds, process_output_us);
if (enable_onerec_xattention_lock_timing()) {
LOG(INFO) << "OneRec xattention engine host timing, "
<< "stage=process_output, elapsed_us=" << process_output_us;
}

batches[0].finish();
return output;
Expand All @@ -854,6 +880,7 @@ ForwardOutput RecEngine::OneRecXAttentionEnginePipeline::get_model_output(
util::get_bool_env("XLLM_DEBUG_ONEREC_ENGINE_TRACE", false);
const bool trace_stage_timing =
util::get_bool_env("XLLM_DEBUG_ONEREC_XATTN_STAGE_TIMING", false);
const bool trace_lock_timing = enable_onerec_xattention_lock_timing();
Timer engine_timer;
auto log_engine_stage = [&](const char* stage_name,
const torch::Tensor& tensor = torch::Tensor()) {
Expand All @@ -865,10 +892,10 @@ ForwardOutput RecEngine::OneRecXAttentionEnginePipeline::get_model_output(
<< (tensor.defined() ? tensor.sizes() : c10::IntArrayRef{});
};
auto log_engine_timing = [&](const char* stage_name) {
if (!trace_stage_timing) {
if (!trace_stage_timing && !trace_lock_timing) {
return;
}
LOG(INFO) << "OneRec xattention engine timing, stage=" << stage_name
LOG(INFO) << "OneRec xattention engine host timing, stage=" << stage_name
<< ", elapsed_us=" << engine_timer.elapsed_microseconds();
engine_timer.reset();
};
Expand Down
11 changes: 5 additions & 6 deletions xllm/core/framework/batch/onerec_batch_input_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,11 @@ limitations under the License.

namespace xllm {

// Use Meyers' Singleton pattern to avoid static initialization order fiasco
// This ensures the cache is initialized on first use, after all dependencies
// (like PyTorch runtime) are properly initialized.
// Keep the cache per scheduler thread. OneRec xattention multi-stream can build
// multiple inputs concurrently, and CacheData is mutated during construction.
OneRecBatchInputBuilder::HighPerformanceCache&
OneRecBatchInputBuilder::get_perf_cache() {
static HighPerformanceCache cache;
thread_local HighPerformanceCache cache;
cache.ensure_tensors_initialized();
return cache;
}
Expand All @@ -59,15 +58,15 @@ OneRecBatchInputBuilder::OneRecBatchInputBuilder(
args_(args),
batch_forward_type_(batch_forward_type),
thread_pool_(thread_pool) {
// Get references to function-local statics (safe initialization)
// Reset only this thread's reusable scratch vectors.
auto& perf_cache = get_perf_cache();
perf_cache.memory_pool.reset();
}

ForwardInput OneRecBatchInputBuilder::build_rec_forward_input(
uint32_t num_decoding_tokens,
uint32_t min_decoding_batch_size) {
// Get reference to function-local static cache (safe initialization)
// Get this thread's reusable scratch cache.
auto& perf_cache = get_perf_cache();

// ========== Global constant cache ==========
Expand Down
22 changes: 10 additions & 12 deletions xllm/core/kernels/npu/xllm_ops/beam_search_rec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include <c10/core/Device.h>
#include <glog/logging.h>
#include <torch/torch.h>
#include <torch_npu/csrc/core/npu/NPUCachingAllocator.h>
#include <torch_npu/csrc/libs/init_npu.h>
#include <torch_npu/torch_npu.h>

Expand Down Expand Up @@ -178,7 +179,8 @@ void beam_search_rec(const torch::Tensor& logprobs,
aclTensor* out_beam_count_prefix_sums_ids = nullptr;
aclTensor* out_sequence_ids = nullptr;
int32_t device_id = logprobs.device().index();
aclrtStream stream = c10_npu::getCurrentNPUStream(device_id).stream();
c10_npu::NPUStream npu_stream = c10_npu::getCurrentNPUStream(device_id);
aclrtStream stream = npu_stream.stream();

create_acltensor(&logprobs_ids, logprobs);
create_acltensor(&top_tokens_ids, top_tokens);
Expand Down Expand Up @@ -209,17 +211,18 @@ void beam_search_rec(const torch::Tensor& logprobs,
&executor),
"beam_search_rec: failed to get workspace size for REC beam search");
void* workspace_addr = nullptr;
torch::Tensor workspace_tensor;
if (workspace_size > 0) {
CHECK_ACL_SUCCESS(
aclrtMalloc(&workspace_addr, workspace_size, ACL_MEM_MALLOC_HUGE_FIRST),
"beam_search_rec: failed to allocate workspace for REC beam search");
workspace_tensor = torch::empty(
{static_cast<int64_t>(workspace_size)},
torch::TensorOptions().dtype(torch::kUInt8).device(logprobs.device()));
workspace_addr = workspace_tensor.data_ptr();
c10_npu::NPUCachingAllocator::recordStream(
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why you delete below aclrtSynchronizeStream and only add recordStream?

workspace_tensor.storage().data_ptr(), npu_stream);
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's this means for workspace_tensor and npu_stream in recordstream?

}
CHECK_ACL_SUCCESS(
aclnnBeamSearchGroup(workspace_addr, workspace_size, executor, stream),
"beam_search_rec: failed to execute REC beam search");
CHECK_ACL_SUCCESS(
aclrtSynchronizeStream(stream),
"beam_search_rec: failed to synchronize stream for REC beam search");
aclDestroyTensor(logprobs_ids);
aclDestroyTensor(top_tokens_ids);
aclDestroyTensor(top_logprobs_ids);
Expand All @@ -229,11 +232,6 @@ void beam_search_rec(const torch::Tensor& logprobs,
aclDestroyTensor(out_log_probs_ids);
aclDestroyTensor(out_beam_count_prefix_sums_ids);
aclDestroyTensor(out_sequence_ids);
if (workspace_size > 0) {
CHECK_ACL_SUCCESS(
aclrtFree(workspace_addr),
"beam_search_rec: failed to free workspace for REC beam search");
}
}

void beam_search_rec(const torch::Tensor& logprobs,
Expand Down
Loading
Loading