【WIP】feat: onerec xattn npu multistream. #1453

Draft
DragonFive wants to merge 6 commits into jd-opensource:main from DragonFive:feat/onerec-xattn-npu-multistream

Conversation

@DragonFive
Collaborator

No description provided.

OneRecBatchInputBuilder::HighPerformanceCache&
OneRecBatchInputBuilder::get_perf_cache() {
- static HighPerformanceCache cache;
+ thread_local HighPerformanceCache cache;
Collaborator Author

No change is needed here: multi-stream is not used in prefill-only mode, and this cache is only used in prefill-only mode.
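
For context on the change under discussion, the sketch below contrasts a function-local `static` object (one instance shared by every thread) with a `thread_local` one (a separate instance per thread). `CacheSketch`, `Seen`, and the helper functions are hypothetical stand-ins for illustration, not the real `HighPerformanceCache`.

```cpp
#include <thread>

// Hypothetical stand-in for a cache object; not the real HighPerformanceCache.
struct CacheSketch {
  int hits = 0;
};

// One instance for the whole process, shared by all threads.
CacheSketch& get_static_cache() {
  static CacheSketch cache;
  return cache;
}

// One instance per thread, created the first time that thread calls this.
CacheSketch& get_thread_local_cache() {
  thread_local CacheSketch cache;
  return cache;
}

// What a single worker thread observed after bumping both caches once.
struct Seen {
  int static_hits;
  int tl_hits;
};

// Starts a fresh thread, increments both caches there, and reports the values.
Seen bump_on_new_thread() {
  Seen seen{};
  std::thread worker([&seen] {
    seen.static_hits = ++get_static_cache().hits;
    seen.tl_hits = ++get_thread_local_cache().hits;
  });
  worker.join();  // join before reading `seen`, so there is no data race
  return seen;
}
```

Calling `bump_on_new_thread()` twice, the shared counter keeps growing across threads, while each new thread starts with its own zero-initialized `thread_local` copy. Whether the per-thread copy is needed depends on whether the cache is ever touched from more than one thread at once, which is the point raised above.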

{static_cast<int64_t>(workspace_size)},
torch::TensorOptions().dtype(torch::kUInt8).device(logprobs.device()));
workspace_addr = workspace_tensor.data_ptr();
c10_npu::NPUCachingAllocator::recordStream(
Collaborator Author

Why did you delete the aclrtSynchronizeStream call below and add only recordStream?

torch::TensorOptions().dtype(torch::kUInt8).device(logprobs.device()));
workspace_addr = workspace_tensor.data_ptr();
c10_npu::NPUCachingAllocator::recordStream(
workspace_tensor.storage().data_ptr(), npu_stream);
Collaborator Author

What do workspace_tensor and npu_stream mean in this recordStream call?

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces support for multi-concurrency in the OneRec xattention pipeline, enabling concurrent batch preparation and execution. Key changes include transitioning the performance cache to thread-local storage, implementing serialization for NPU-specific preparation and forward passes to ensure thread safety, and updating the scheduler to manage in-flight steps and resource exhaustion more effectively. Additionally, NPU kernel workspace management was improved by using the NPUCachingAllocator instead of manual allocation. Feedback includes several requests to annotate constant arguments with parameter names to comply with the repository style guide, a critical recommendation to capture asynchronous lambda parameters by value to avoid dangling references, and a suggestion to add a null check for the ATB context before use.

FLAGS_max_decode_rounds = 2;
FLAGS_prefill_scheduling_memory_usage_threshold = 1.0;

std::unique_ptr<FakeEngine> engine = std::make_unique<FakeEngine>(64, 128);
Contributor

high

Constant arguments in function or constructor calls should be annotated with a comment indicating the parameter name to improve readability, as per the repository style guide.

Suggested change
- std::unique_ptr<FakeEngine> engine = std::make_unique<FakeEngine>(64, 128);
+ std::unique_ptr<FakeEngine> engine = std::make_unique<FakeEngine>(/*num_blocks=*/64, /*block_size=*/128);
References
  1. Annotate constant arguments with a comment indicating the parameter name when calling functions or constructors.

/*rec_worker_max_concurrency=*/2);
FixedStepsScheduler scheduler(engine.get(), opt);
std::vector<std::shared_ptr<Request>> requests =
GenRequests({1, 1}, {1, 1}, RecType::kOneRec);
Contributor

high

Constant arguments in function calls should be annotated with a comment indicating the parameter name to improve readability.

Suggested change
- GenRequests({1, 1}, {1, 1}, RecType::kOneRec);
+ GenRequests(/*prompt_lens=*/{1, 1}, /*max_tokens=*/{1, 1}, RecType::kOneRec);
References
  1. Annotate constant arguments with a comment indicating the parameter name when calling functions or constructors.

FLAGS_max_decode_rounds = 2;
FLAGS_prefill_scheduling_memory_usage_threshold = 1.0;

std::unique_ptr<FakeEngine> engine = std::make_unique<FakeEngine>(3, 128);
Contributor

high

Constant arguments in function or constructor calls should be annotated with a comment indicating the parameter name.

Suggested change
- std::unique_ptr<FakeEngine> engine = std::make_unique<FakeEngine>(3, 128);
+ std::unique_ptr<FakeEngine> engine = std::make_unique<FakeEngine>(/*num_blocks=*/3, /*block_size=*/128);
References
  1. Annotate constant arguments with a comment indicating the parameter name when calling functions or constructors.

FLAGS_prefill_scheduling_memory_usage_threshold = 1.0;

std::unique_ptr<FakeEngine> engine = std::make_unique<FakeEngine>(3, 128);
engine->notify_after_step_calls(2);
Contributor

high

Constant arguments in function calls should be annotated with a comment indicating the parameter name.

Suggested change
- engine->notify_after_step_calls(2);
+ engine->notify_after_step_calls(/*step_calls=*/2);
References
  1. Annotate constant arguments with a comment indicating the parameter name when calling functions or constructors.

FixedStepsScheduler scheduler(engine.get(), opt);

std::vector<std::shared_ptr<Request>> requests =
GenRequests({1, 1, 1}, {1, 1, 1}, RecType::kOneRec);
Contributor

high

Constant arguments in function calls should be annotated with a comment indicating the parameter name.

Suggested change
- GenRequests({1, 1, 1}, {1, 1, 1}, RecType::kOneRec);
+ GenRequests(/*prompt_lens=*/{1, 1, 1}, /*max_tokens=*/{1, 1, 1}, RecType::kOneRec);
References
  1. Annotate constant arguments with a comment indicating the parameter name when calling functions or constructors.

Comment thread xllm/core/runtime/rec_worker_impl.cpp Outdated
@@ -3015,6 +3063,19 @@ folly::SemiFuture<std::optional<ForwardOutput>> RecWorkerImpl::step_async(
[this, &input, index, promise = std::move(promise)]() mutable {
Contributor

high

The input parameter is captured by reference in the asynchronous lambda. This is dangerous because step_async returns a SemiFuture, and if the caller does not block on the result immediately, the input object might be destroyed before the lambda executes, leading to a dangling reference. Capturing input by value is safer for asynchronous tasks.

Suggested change
- [this, &input, index, promise = std::move(promise)]() mutable {
+ [this, input, index, promise = std::move(promise)]() mutable {
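
To illustrate the lifetime concern in this comment, here is a minimal, thread-free sketch: a task is built from an input object, the input goes out of scope, and the task runs afterwards. `InputSketch`, `make_task_sketch`, and `run_after_input_destroyed` are hypothetical names for illustration and do not reflect the actual `step_async` signature or input type.

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for the real forward input type.
struct InputSketch {
  std::vector<int> token_ids;
};

// Builds a task that runs later. Capturing `input` by value copies it into
// the closure, so the task stays valid after the caller's object is gone.
// Capturing by reference ([&input]) would leave the closure holding a
// dangling reference in that situation.
std::function<std::size_t()> make_task_sketch(const InputSketch& input) {
  return [input]() { return input.token_ids.size(); };
}

// Simulates a caller whose input is destroyed before the task executes.
std::size_t run_after_input_destroyed() {
  std::function<std::size_t()> task;
  {
    InputSketch scoped{{1, 2, 3}};
    task = make_task_sketch(scoped);
  }  // `scoped` is destroyed here; the closure still owns its copy.
  return task();
}
```

The by-value capture costs one copy of the input per task. If that copy turns out to be expensive, moving the input into the closure or sharing it through a `std::shared_ptr` are possible alternatives; which one fits depends on how the caller uses the input afterwards.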

Comment on lines +3075 to +3077
atb::Context* atb_context = const_cast<atb::Context*>(
work_pipelines_[index]->runtime().context->get_atb_context());
atb_context->SetExecuteStream(current_stream);
Contributor

high

The atb_context obtained from the model context should be checked for nullptr before calling SetExecuteStream. Accessing a null pointer will cause a crash. Using a CHECK assertion is recommended here to ensure the context is valid for NPU execution.

        atb::Context* atb_context = const_cast<atb::Context*>(
            work_pipelines_[index]->runtime().context->get_atb_context());
        CHECK(atb_context != nullptr) << "ATB context is null for pipeline " << index;
        atb_context->SetExecuteStream(current_stream);
References
  1. Use CHECK (glog) instead of TORCH_CHECK for assertions.

torch::TensorOptions()
.dtype(torch::kUInt8)
.device(beam_index.device()));
workspace_addr = workspace_tensor.data_ptr();
Collaborator Author

It looks like you use the torch allocator for the workspace to avoid malloc and free. But with multi-stream, is it safe to call only recordStream without a sync?

}
#else
bool should_serialize_onerec_xattention_prepare(int64_t max_concurrency) {
UNUSED_PARAMETER(max_concurrency);
Collaborator Author

What does UNUSED_PARAMETER mean?
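
The macro's definition is not shown in this diff. Macros with this kind of name usually cast the argument to `void` so the compiler does not warn about an unused parameter. The sketch below assumes that definition; the function name and the `false` return value are also assumptions, since the body of the `#else` branch is not visible here.

```cpp
#include <cstdint>

// Assumed definition for illustration; the project's actual macro may differ.
#define UNUSED_PARAMETER(x) (void)(x)

// Sketch of a stub for builds where the feature is compiled out. The
// parameter is kept so the signature matches the real implementation, and
// the macro marks it as intentionally unused. Returning false here is an
// assumption, not taken from the diff.
bool should_serialize_prepare_sketch(int64_t max_concurrency) {
  UNUSED_PARAMETER(max_concurrency);
  return false;
}
```

Under this assumed definition the macro has no runtime effect; it only documents intent and silences `-Wunused-parameter`-style warnings.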

runtime_.worker.onerec_xattention_prepare_mutex_, std::defer_lock);
if (should_serialize_onerec_xattention_prepare(
runtime_.worker.options_.rec_worker_max_concurrency())) {
prepare_lock.lock();
Collaborator Author

Do you take a lock while preparing input for multi-stream? This may hurt performance.
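
For readers following the locking questions, here is a minimal sketch of the pattern visible in the diff: a `std::unique_lock` is constructed with `std::defer_lock`, so it does not lock on construction, and it is locked only when the concurrency setting is above 1. The mutex and function names are hypothetical stand-ins, and the threshold mirrors the `max_concurrency > 1` check shown for the forward path.

```cpp
#include <cstdint>
#include <mutex>

// Hypothetical stand-in for onerec_xattention_prepare_mutex_.
std::mutex prepare_mutex_sketch;

// Mirrors the `max_concurrency > 1` predicate shown in the diff.
bool should_serialize_sketch(int64_t max_concurrency) {
  return max_concurrency > 1;
}

// Runs the "prepare" step and returns whether the mutex was held during it.
bool run_prepare_sketch(int64_t max_concurrency) {
  // std::defer_lock associates the lock with the mutex without locking it.
  std::unique_lock<std::mutex> lock(prepare_mutex_sketch, std::defer_lock);
  if (should_serialize_sketch(max_concurrency)) {
    lock.lock();  // serialize only when several workers may run concurrently
  }
  // ... input preparation would happen here ...
  return lock.owns_lock();
  // The destructor unlocks only if the lock is owned.
}
```

With a single worker the mutex is never touched, so the single-stream path pays no locking cost. With more than one worker, every prepare call is serialized on the same mutex, which is the overhead this comment asks about; the sketch does not measure it.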


bool should_serialize_onerec_xattention_model_forward(int64_t max_concurrency) {
return max_concurrency > 1;
}
Collaborator Author

Why should forward and input preparation be serialized when max_concurrency > 1? Does this affect performance?

std::defer_lock);
if (should_serialize_onerec_xattention_model_forward(
runtime_.worker.options_.rec_worker_max_concurrency())) {
forward_lock.lock();
Collaborator Author

Why do you use a lock here? With it, the streams never seem to run at the same time. Does the GPU multi-stream path use this lock too?

aclrtStream current_stream =
c10_npu::getCurrentNPUStream(device_.index()).stream();
atb::Context* atb_context = const_cast<atb::Context*>(
work_pipelines_[index]->runtime().context->get_atb_context());
Collaborator Author

It looks like each device has only one stream, so where is the multi-stream?

Collaborator Author

Why did you remove it? It looks like it is used in many pipelines; will this affect the other pipelines?
