【WIP】feat: onerec xattn npu multistream.#1453
| OneRecBatchInputBuilder::HighPerformanceCache& | ||
| OneRecBatchInputBuilder::get_perf_cache() { | ||
| static HighPerformanceCache cache; | ||
| thread_local HighPerformanceCache cache; |
No change is needed here: we do not use multi-stream in prefill-only mode, and this cache is only used in prefill-only mode.
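For context on what the `static` to `thread_local` change would alter, here is a minimal sketch. `PerfCache` and both getter names are hypothetical stand-ins, not the real `HighPerformanceCache` or `get_perf_cache()`. It only shows that a function-local `static` is one object shared by all threads, while a function-local `thread_local` is a separate object per thread.

```cpp
#include <cassert>
#include <thread>
#include <vector>

// Hypothetical stand-in for HighPerformanceCache (the real type is not shown here).
struct PerfCache {
  std::vector<int> scratch;
};

// Function-local static: a single instance shared by every thread.
PerfCache& shared_cache() {
  static PerfCache cache;
  return cache;
}

// Function-local thread_local: each thread lazily gets its own instance.
PerfCache& per_thread_cache() {
  thread_local PerfCache cache;
  return cache;
}

// Returns true if a second thread gets the same object as the calling thread.
bool second_thread_sees_same_object(PerfCache& (*getter)()) {
  assert(getter != nullptr);
  PerfCache* mine = &getter();
  bool same = false;
  // Compare inside the worker, while both objects are still alive.
  std::thread worker([&] { same = (&getter() == mine); });
  worker.join();
  return same;
}
```

If only one thread ever calls the getter, as the comment above says is the case for prefill-only mode, the two variants behave the same and the `static` version avoids one cache per thread.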
| {static_cast<int64_t>(workspace_size)}, | ||
| torch::TensorOptions().dtype(torch::kUInt8).device(logprobs.device())); | ||
| workspace_addr = workspace_tensor.data_ptr(); | ||
| c10_npu::NPUCachingAllocator::recordStream( |
Why did you delete the aclrtSynchronizeStream call below and only add recordStream?
| torch::TensorOptions().dtype(torch::kUInt8).device(logprobs.device())); | ||
| workspace_addr = workspace_tensor.data_ptr(); | ||
| c10_npu::NPUCachingAllocator::recordStream( | ||
| workspace_tensor.storage().data_ptr(), npu_stream); |
What does passing workspace_tensor and npu_stream to recordStream mean here?
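As background for this question, here is a toy model of the idea behind `recordStream`. It assumes the NPU caching allocator follows the same contract as PyTorch's CUDA caching allocator: a block used on a stream other than its allocation stream must not be handed out again until that stream has finished the work queued on it. Every name below (`ToyAllocator`, `Block`, `stream_caught_up`) is hypothetical and none of it is torch_npu code.

```cpp
#include <cassert>
#include <set>

// Toy model only; these are not torch_npu types.
using StreamId = int;

struct Block {
  StreamId alloc_stream;       // stream the block was allocated on
  std::set<StreamId> used_on;  // other streams registered via record_stream
};

class ToyAllocator {
 public:
  // Tell the allocator that `stream` also has pending work touching `block`.
  void record_stream(Block& block, StreamId stream) {
    if (stream != block.alloc_stream) {
      block.used_on.insert(stream);
    }
  }

  // Models the point at which `stream` has finished the work it had queued.
  void stream_caught_up(Block& block, StreamId stream) {
    block.used_on.erase(stream);
  }

  // After the owning tensor is freed, the block may be reused only when no
  // recorded stream still has outstanding work on it.
  bool reusable_after_free(const Block& block) const {
    return block.used_on.empty();
  }
};
```

In this model the host never blocks: freeing the tensor just defers reuse of its memory. A stream synchronize would instead stall the host until the kernel finishes, which is the difference the question about the deleted `aclrtSynchronizeStream` is getting at.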
Code Review
This pull request introduces support for multi-concurrency in the OneRec xattention pipeline, enabling concurrent batch preparation and execution. Key changes include transitioning the performance cache to thread-local storage, implementing serialization for NPU-specific preparation and forward passes to ensure thread safety, and updating the scheduler to manage in-flight steps and resource exhaustion more effectively. Additionally, NPU kernel workspace management was improved by using the NPUCachingAllocator instead of manual allocation. Feedback includes several requests to annotate constant arguments with parameter names to comply with the repository style guide, a critical recommendation to capture asynchronous lambda parameters by value to avoid dangling references, and a suggestion to add a null check for the ATB context before use.
| FLAGS_max_decode_rounds = 2; | ||
| FLAGS_prefill_scheduling_memory_usage_threshold = 1.0; | ||
|
|
||
| std::unique_ptr<FakeEngine> engine = std::make_unique<FakeEngine>(64, 128); |
Constant arguments in function or constructor calls should be annotated with a comment indicating the parameter name to improve readability, as per the repository style guide.
| std::unique_ptr<FakeEngine> engine = std::make_unique<FakeEngine>(64, 128); | |
| std::unique_ptr<FakeEngine> engine = std::make_unique<FakeEngine>(/*num_blocks=*/64, /*block_size=*/128); |
References
- Annotate constant arguments with a comment indicating the parameter name when calling functions or constructors. (link)
| /*rec_worker_max_concurrency=*/2); | ||
| FixedStepsScheduler scheduler(engine.get(), opt); | ||
| std::vector<std::shared_ptr<Request>> requests = | ||
| GenRequests({1, 1}, {1, 1}, RecType::kOneRec); |
Constant arguments in function calls should be annotated with a comment indicating the parameter name to improve readability.
| GenRequests({1, 1}, {1, 1}, RecType::kOneRec); | |
| GenRequests(/*prompt_lens=*/{1, 1}, /*max_tokens=*/{1, 1}, RecType::kOneRec); |
References
- Annotate constant arguments with a comment indicating the parameter name when calling functions or constructors. (link)
| FLAGS_max_decode_rounds = 2; | ||
| FLAGS_prefill_scheduling_memory_usage_threshold = 1.0; | ||
|
|
||
| std::unique_ptr<FakeEngine> engine = std::make_unique<FakeEngine>(3, 128); |
Constant arguments in function or constructor calls should be annotated with a comment indicating the parameter name.
| std::unique_ptr<FakeEngine> engine = std::make_unique<FakeEngine>(3, 128); | |
| std::unique_ptr<FakeEngine> engine = std::make_unique<FakeEngine>(/*num_blocks=*/3, /*block_size=*/128); |
References
- Annotate constant arguments with a comment indicating the parameter name when calling functions or constructors. (link)
| FLAGS_prefill_scheduling_memory_usage_threshold = 1.0; | ||
|
|
||
| std::unique_ptr<FakeEngine> engine = std::make_unique<FakeEngine>(3, 128); | ||
| engine->notify_after_step_calls(2); |
Constant arguments in function calls should be annotated with a comment indicating the parameter name.
| engine->notify_after_step_calls(2); | |
| engine->notify_after_step_calls(/*step_calls=*/2); |
References
- Annotate constant arguments with a comment indicating the parameter name when calling functions or constructors. (link)
| FixedStepsScheduler scheduler(engine.get(), opt); | ||
|
|
||
| std::vector<std::shared_ptr<Request>> requests = | ||
| GenRequests({1, 1, 1}, {1, 1, 1}, RecType::kOneRec); |
Constant arguments in function calls should be annotated with a comment indicating the parameter name.
| GenRequests({1, 1, 1}, {1, 1, 1}, RecType::kOneRec); | |
| GenRequests(/*prompt_lens=*/{1, 1, 1}, /*max_tokens=*/{1, 1, 1}, RecType::kOneRec); |
References
- Annotate constant arguments with a comment indicating the parameter name when calling functions or constructors. (link)
| @@ -3015,6 +3063,19 @@ folly::SemiFuture<std::optional<ForwardOutput>> RecWorkerImpl::step_async( | |||
| [this, &input, index, promise = std::move(promise)]() mutable { | |||
The input parameter is captured by reference in the asynchronous lambda. This is dangerous because step_async returns a SemiFuture, and if the caller does not block on the result immediately, the input object might be destroyed before the lambda executes, leading to a dangling reference. Capturing input by value is safer for asynchronous tasks.
| [this, &input, index, promise = std::move(promise)]() mutable { | |
| [this, input, index, promise = std::move(promise)]() mutable { |
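The lifetime issue can be reproduced without folly. The sketch below uses `std::async` and `std::future` as stand-ins for the worker's thread pool and `folly::SemiFuture`; `StepInput` and both function names are hypothetical. It shows the safe variant: the lambda captures `input` by value, so the task keeps working after the caller's object is gone.

```cpp
#include <cassert>
#include <cstddef>
#include <future>
#include <string>

// Hypothetical stand-in for the real forward input type.
struct StepInput {
  std::string payload;
};

// Captures `input` by value: the closure owns a copy, so the task does not
// depend on the caller keeping the original alive.
std::future<std::size_t> step_async_by_value(const StepInput& input) {
  auto result = std::async(std::launch::async,
                           [input]() { return input.payload.size(); });
  assert(result.valid());
  return result;
}

// The local object is destroyed when this function returns, possibly before
// the task has run. With a by-reference capture this would dangle.
std::future<std::size_t> launch_with_short_lived_input() {
  StepInput local{"abc"};
  return step_async_by_value(local);
}
```

If copying the input is too expensive, moving it into the closure or holding it in a `std::shared_ptr` gives the same lifetime guarantee without the copy.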
| atb::Context* atb_context = const_cast<atb::Context*>( | ||
| work_pipelines_[index]->runtime().context->get_atb_context()); | ||
| atb_context->SetExecuteStream(current_stream); |
The atb_context obtained from the model context should be checked for nullptr before calling SetExecuteStream. Accessing a null pointer will cause a crash. Using a CHECK assertion is recommended here to ensure the context is valid for NPU execution.
| atb::Context* atb_context = const_cast<atb::Context*>( |
| work_pipelines_[index]->runtime().context->get_atb_context()); |
| CHECK(atb_context != nullptr) << "ATB context is null for pipeline " << index; |
| atb_context->SetExecuteStream(current_stream); |
References
- Use CHECK (glog) instead of TORCH_CHECK for assertions.
| torch::TensorOptions() | ||
| .dtype(torch::kUInt8) | ||
| .device(beam_index.device())); | ||
| workspace_addr = workspace_tensor.data_ptr(); |
It looks like you use the torch allocator for the workspace to avoid malloc and free. With multi-stream, though, is it OK to drop the sync and rely only on recordStream?
| } | ||
| #else | ||
| bool should_serialize_onerec_xattention_prepare(int64_t max_concurrency) { | ||
| UNUSED_PARAMETER(max_concurrency); |
What does UNUSED_PARAMETER mean?
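Macros with this name usually exist only to silence unused-parameter warnings. The definition below is an assumption, since the project's real macro is not shown in this thread. The stub function and its `false` return value are also hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Assumed definition; the project's real macro may differ.
#ifndef UNUSED_PARAMETER
#define UNUSED_PARAMETER(x) (void)(x)
#endif

// Hypothetical stub for a build where the parameter is not needed.
bool should_serialize_prepare_stub(std::int64_t max_concurrency) {
  // Evaluates the name and discards it: no runtime effect, but the compiler
  // no longer warns that the parameter is unused.
  UNUSED_PARAMETER(max_concurrency);
  return false;  // assumed return value for the non-NPU branch
}
```

In C++17 and later, marking the parameter `[[maybe_unused]]` has the same effect without a macro.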
| runtime_.worker.onerec_xattention_prepare_mutex_, std::defer_lock); | ||
| if (should_serialize_onerec_xattention_prepare( | ||
| runtime_.worker.options_.rec_worker_max_concurrency())) { | ||
| prepare_lock.lock(); |
Do you take a lock in prepare-input for multi-stream? Could this hurt performance?
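The locking pattern in the diff can be sketched with the standard library alone. `should_serialize` mirrors the `max_concurrency > 1` predicate shown in this PR; `run_prepare` is a hypothetical helper that returns whether the lock was taken, so the behavior is easy to check.

```cpp
#include <cassert>
#include <cstdint>
#include <mutex>

// Mirrors the predicate in the diff: serialize only when more than one step
// can be in flight.
bool should_serialize(std::int64_t max_concurrency) {
  return max_concurrency > 1;
}

// Hypothetical helper: runs the "prepare" section, locking only if needed.
bool run_prepare(std::mutex& mutex, std::int64_t max_concurrency) {
  // defer_lock associates the lock with the mutex without acquiring it.
  std::unique_lock<std::mutex> lock(mutex, std::defer_lock);
  assert(!lock.owns_lock());
  if (should_serialize(max_concurrency)) {
    lock.lock();
  }
  // ... work that must not overlap across threads would go here ...
  return lock.owns_lock();
  // The destructor unlocks only if the lock was actually taken.
}
```

With `max_concurrency == 1` the mutex is never touched, so the single-stream path pays no locking cost. With higher concurrency every caller serializes on this section, which is the performance concern raised in the comment above.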
|
|
||
| bool should_serialize_onerec_xattention_model_forward(int64_t max_concurrency) { | ||
| return max_concurrency > 1; | ||
| } |
Why should forward and prepare-input be serialized when max_concurrency > 1? Does this affect performance?
| std::defer_lock); | ||
| if (should_serialize_onerec_xattention_model_forward( | ||
| runtime_.worker.options_.rec_worker_max_concurrency())) { | ||
| forward_lock.lock(); |
Why do you use a lock here? It looks like the streams never actually run at the same time. Does the GPU multi-stream path use this lock too?
| aclrtStream current_stream = | ||
| c10_npu::getCurrentNPUStream(device_.index()).stream(); | ||
| atb::Context* atb_context = const_cast<atb::Context*>( | ||
| work_pipelines_[index]->runtime().context->get_atb_context()); |
It looks like each device has only one stream, so where is the multi-stream?
Why did you remove it? It looks like it is used in many pipelines; will this affect the other pipelines?
No description provided.