From 300dd2811a70625f66c2d512be976effb45bb856 Mon Sep 17 00:00:00 2001 From: limenxin Date: Fri, 8 May 2026 14:18:02 +0800 Subject: [PATCH] refactor: remove xattention one-stage decode path. --- tests/core/layers/cuda/xattention_test.cpp | 143 +++++++++---- xllm/c_api/default.h | 1 - xllm/c_api/internal/rec.cpp | 4 +- xllm/c_api/types.h | 3 - xllm/core/common/global_flags.h | 2 - xllm/core/framework/config/rec_config.cpp | 8 - xllm/core/framework/config/rec_config.h | 3 - .../common/attention_metadata_builder.cpp | 71 ++++--- xllm/core/layers/cuda/xattention.cpp | 76 +------ xllm/core/layers/cuda/xattention.h | 4 - .../core/runtime/cuda_graph_executor_impl.cpp | 4 +- xllm/core/runtime/rec_worker_impl.cpp | 192 ++---------------- xllm/core/runtime/rec_worker_impl.h | 34 +--- 13 files changed, 164 insertions(+), 381 deletions(-) diff --git a/tests/core/layers/cuda/xattention_test.cpp b/tests/core/layers/cuda/xattention_test.cpp index ea92e1cc73..e10bfcd2cd 100644 --- a/tests/core/layers/cuda/xattention_test.cpp +++ b/tests/core/layers/cuda/xattention_test.cpp @@ -24,7 +24,6 @@ limitations under the License. #include #include -#include "core/framework/config/rec_config.h" #include "core/framework/config/scheduler_config.h" #include "framework/kv_cache/kv_cache.h" #include "layers/cuda/flashinfer_workspace.h" @@ -40,16 +39,94 @@ struct DecodeTestInput { torch::Tensor value; }; +constexpr int32_t kTestBatchSize = 4; +constexpr int32_t kTestBeamWidth = 128; +constexpr int32_t kTestNumHeads = 16; +constexpr int32_t kTestNumKvHeads = 8; +constexpr int32_t kTestHeadDim = 128; +constexpr int32_t kTestSharedSeqLen = 300; +constexpr int32_t kTestMaxDecodeStep = 2; +constexpr int32_t kTestCurrentStep = 1; + +torch::Tensor build_reference_output(const DecodeTestInput& input) { + const int64_t total_beam = input.query.size(0); + const int64_t shared_len = kTestSharedSeqLen; + const int64_t seq_len = shared_len + kTestCurrentStep + 1; + const int64_t group_size = kTestNumHeads / kTestNumKvHeads; + + auto float_opts = torch::TensorOptions() + .dtype(torch::kFloat32) + .device(input.query.device()); + torch::Tensor query = + input.query.view({total_beam, kTestNumHeads, kTestHeadDim}) + .to(torch::kFloat32); + torch::Tensor key = input.attn_metadata.full_k_cache.clone(); + torch::Tensor value = input.attn_metadata.full_v_cache.clone(); + torch::Tensor current_key = + input.key.view({total_beam, kTestNumKvHeads, kTestHeadDim}); + torch::Tensor current_value = + input.value.view({total_beam, kTestNumKvHeads, kTestHeadDim}); + torch::Tensor output = + torch::empty({total_beam, kTestNumHeads * kTestHeadDim}, float_opts); + + const float scale = 1.0f / std::sqrt(static_cast(kTestHeadDim)); + for (int64_t beam_idx = 0; beam_idx < total_beam; ++beam_idx) { + const int64_t slot_id = + shared_len + beam_idx * kTestMaxDecodeStep + kTestCurrentStep; + key.select(0, slot_id).copy_(current_key.select(0, beam_idx)); + value.select(0, slot_id).copy_(current_value.select(0, beam_idx)); + } + torch::Tensor shared_key = key.slice(0, 0, shared_len); + torch::Tensor shared_value = value.slice(0, 0, shared_len); + + for (int64_t beam_idx = 0; beam_idx < total_beam; ++beam_idx) { + torch::Tensor beam_key = + torch::empty({seq_len, kTestNumKvHeads, kTestHeadDim}, float_opts); + torch::Tensor beam_value = + torch::empty({seq_len, kTestNumKvHeads, kTestHeadDim}, float_opts); + + beam_key.slice(/*dim=*/0, /*start=*/0, /*end=*/shared_len) + .copy_(shared_key); + beam_value.slice(/*dim=*/0, /*start=*/0, /*end=*/shared_len) + .copy_(shared_value); + + const int64_t beam_base = shared_len + beam_idx * kTestMaxDecodeStep; + beam_key.slice(/*dim=*/0, /*start=*/shared_len, /*end=*/seq_len) + .copy_(key.slice(0, beam_base, beam_base + kTestCurrentStep + 1)); + beam_value.slice(/*dim=*/0, /*start=*/shared_len, /*end=*/seq_len) + .copy_(value.slice(0, beam_base, beam_base + kTestCurrentStep + 1)); + + torch::Tensor key_rep = + beam_key.unsqueeze(2) + .expand({seq_len, kTestNumKvHeads, group_size, kTestHeadDim}) + .reshape({seq_len, kTestNumHeads, kTestHeadDim}); + torch::Tensor value_rep = + beam_value.unsqueeze(2) + .expand({seq_len, kTestNumKvHeads, group_size, kTestHeadDim}) + .reshape({seq_len, kTestNumHeads, kTestHeadDim}); + + torch::Tensor scores = + torch::einsum("hd,shd->hs", {query.select(0, beam_idx), key_rep}) * + scale; + torch::Tensor attn = torch::softmax(scores, /*dim=*/-1); + torch::Tensor beam_output = torch::einsum("hs,shd->hd", {attn, value_rep}) + .reshape({kTestNumHeads * kTestHeadDim}); + output.select(0, beam_idx).copy_(beam_output); + } + + return output.to(input.query.scalar_type()); +} + class XAttentionDecodeCompareTest : public ::testing::Test { protected: - static constexpr int32_t kBatchSize = 4; - static constexpr int32_t kBeamWidth = 128; - static constexpr int32_t kNumHeads = 16; - static constexpr int32_t kNumKvHeads = 8; - static constexpr int32_t kHeadDim = 128; - static constexpr int32_t kSharedSeqLen = 300; - static constexpr int32_t kMaxDecodeStep = 2; - static constexpr int32_t kCurrentStep = 1; + static constexpr int32_t kBatchSize = kTestBatchSize; + static constexpr int32_t kBeamWidth = kTestBeamWidth; + static constexpr int32_t kNumHeads = kTestNumHeads; + static constexpr int32_t kNumKvHeads = kTestNumKvHeads; + static constexpr int32_t kHeadDim = kTestHeadDim; + static constexpr int32_t kSharedSeqLen = kTestSharedSeqLen; + static constexpr int32_t kMaxDecodeStep = kTestMaxDecodeStep; + static constexpr int32_t kCurrentStep = kTestCurrentStep; void SetUp() override { if (!torch::cuda::is_available()) { @@ -199,8 +276,7 @@ class XAttentionDecodeCompareTest : public ::testing::Test { return input; } - torch::Tensor run_decode_once(DecodeTestInput& input, bool enable_two_stage) { - RecConfig::get_instance().enable_xattention_one_stage(!enable_two_stage); + torch::Tensor run_two_stage_decode_once(DecodeTestInput& input) { SchedulerConfig::get_instance().max_tokens_per_batch(kSharedSeqLen); XAttentionImpl attention( @@ -224,33 +300,30 @@ class XAttentionDecodeCompareTest : public ::testing::Test { return std::get<0>(result).clone(); } - void compare_single_and_two_stage(torch::ScalarType dtype, - double atol, - double rtol) { + void compare_two_stage_with_reference(torch::ScalarType dtype, + double atol, + double rtol) { constexpr int64_t kSeed = 20260303; torch::manual_seed(kSeed); torch::cuda::manual_seed_all(kSeed); - auto single_input = create_decode_test_input(dtype); + auto input = create_decode_test_input(dtype); torch::manual_seed(kSeed); torch::cuda::manual_seed_all(kSeed); - auto two_stage_input = create_decode_test_input(dtype); - - two_stage_input.query.copy_(single_input.query); - two_stage_input.key.copy_(single_input.key); - two_stage_input.value.copy_(single_input.value); + auto reference_input = create_decode_test_input(dtype); + reference_input.query.copy_(input.query); + reference_input.key.copy_(input.key); + reference_input.value.copy_(input.value); - auto single_output = - run_decode_once(single_input, /*enable_two_stage=*/false); - auto two_stage_output = - run_decode_once(two_stage_input, /*enable_two_stage=*/true); + auto reference_output = build_reference_output(reference_input); + auto two_stage_output = run_two_stage_decode_once(input); auto abs_diff = - (single_output - two_stage_output).abs().to(torch::kFloat32); + (reference_output - two_stage_output).abs().to(torch::kFloat32); const double max_abs_diff = abs_diff.max().item(); const double mean_abs_diff = abs_diff.mean().item(); - EXPECT_TRUE(torch::allclose(single_output, two_stage_output, rtol, atol)) - << "single-stage and two-stage decode outputs mismatch: " + EXPECT_TRUE(torch::allclose(reference_output, two_stage_output, rtol, atol)) + << "reference and two-stage decode outputs mismatch: " << "max_abs_diff=" << max_abs_diff << ", mean_abs_diff=" << mean_abs_diff << ", atol=" << atol << ", rtol=" << rtol; @@ -263,16 +336,16 @@ class XAttentionDecodeCompareTest : public ::testing::Test { torch::Device device_{torch::kCPU}; }; -TEST_F(XAttentionDecodeCompareTest, SingleVsTwoStageFp16) { - compare_single_and_two_stage(torch::kFloat16, - /*atol=*/2e-3, - /*rtol=*/2e-3); +TEST_F(XAttentionDecodeCompareTest, TwoStageFp16) { + compare_two_stage_with_reference(torch::kFloat16, + /*atol=*/2e-3, + /*rtol=*/2e-3); } -TEST_F(XAttentionDecodeCompareTest, SingleVsTwoStageBf16) { - compare_single_and_two_stage(torch::kBFloat16, - /*atol=*/2e-2, - /*rtol=*/2e-2); +TEST_F(XAttentionDecodeCompareTest, TwoStageBf16) { + compare_two_stage_with_reference(torch::kBFloat16, + /*atol=*/2e-2, + /*rtol=*/2e-2); } } // namespace diff --git a/xllm/c_api/default.h b/xllm/c_api/default.h index c2b6b05e27..4ea91df856 100644 --- a/xllm/c_api/default.h +++ b/xllm/c_api/default.h @@ -94,7 +94,6 @@ const XLLM_InitOptions XLLM_INIT_REC_OPTIONS_DEFAULT = { .enable_graph = true, .enable_rec_fast_sampler = true, .enable_prefill_piecewise_graph = true, - .enable_xattention_one_stage = false, .enable_graph_mode_decode_no_padding = true, .enable_block_copy_kernel = false, .enable_topk_sorted = false, diff --git a/xllm/c_api/internal/rec.cpp b/xllm/c_api/internal/rec.cpp index c880fcbc1a..7d56061532 100644 --- a/xllm/c_api/internal/rec.cpp +++ b/xllm/c_api/internal/rec.cpp @@ -59,7 +59,6 @@ const char* get_rec_pipeline_name(xllm::RecPipelineType pipeline_type) { void reset_pipeline_runtime_toggles() { xllm::RecConfig::get_instance() .enable_rec_fast_sampler(false) - .enable_xattention_one_stage(false) .enable_rec_prefill_only(false) .enable_constrained_decoding(false); xllm::ExecutionConfig::get_instance() @@ -70,8 +69,7 @@ void reset_pipeline_runtime_toggles() { void apply_multi_round_pipeline_toggles() { xllm::RecConfig::get_instance() - .enable_rec_fast_sampler(true) - .enable_xattention_one_stage(false); + .enable_rec_fast_sampler(true); xllm::ExecutionConfig::get_instance() .enable_prefill_piecewise_graph(true) .enable_graph_mode_decode_no_padding(true); diff --git a/xllm/c_api/types.h b/xllm/c_api/types.h index cdb609e8e4..67ec4194a0 100644 --- a/xllm/c_api/types.h +++ b/xllm/c_api/types.h @@ -70,9 +70,6 @@ typedef struct XLLM_CAPI_EXPORT XLLM_InitOptions { /** Whether to enable prefill piecewise graph for REC */ bool enable_prefill_piecewise_graph; - /** Whether to enable xattention one-stage execution for REC */ - bool enable_xattention_one_stage; - /** Whether to enable graph-mode decode without padding for REC */ bool enable_graph_mode_decode_no_padding; diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 0a2e23e47a..b5619f103b 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -273,8 +273,6 @@ DECLARE_bool(enable_rec_prefill_only); DECLARE_bool(output_rec_logprobs); -DECLARE_bool(enable_xattention_one_stage); - DECLARE_int32(max_decode_rounds); DECLARE_bool(enable_constrained_decoding); diff --git a/xllm/core/framework/config/rec_config.cpp b/xllm/core/framework/config/rec_config.cpp index 6a30e4a769..d749bfb34c 100644 --- a/xllm/core/framework/config/rec_config.cpp +++ b/xllm/core/framework/config/rec_config.cpp @@ -28,11 +28,6 @@ DEFINE_bool(enable_rec_prefill_only, "Enable rec prefill-only mode (no decoder self-attention blocks " "allocation)."); -DEFINE_bool(enable_xattention_one_stage, - false, - "Whether to force xattention one-stage decode for rec " - "multi-round mode."); - DEFINE_int32(max_decode_rounds, 0, "Maximum number of decode rounds for multi-step decoding. " @@ -85,7 +80,6 @@ namespace xllm { void RecConfig::from_flags() { enable_rec_fast_sampler(FLAGS_enable_rec_fast_sampler) .enable_rec_prefill_only(FLAGS_enable_rec_prefill_only) - .enable_xattention_one_stage(FLAGS_enable_xattention_one_stage) .max_decode_rounds(FLAGS_max_decode_rounds) .enable_constrained_decoding(FLAGS_enable_constrained_decoding) .output_rec_logprobs(FLAGS_output_rec_logprobs) @@ -103,8 +97,6 @@ void RecConfig::from_json(const JsonReader& json) { json.value_or("enable_rec_fast_sampler", enable_rec_fast_sampler())) .enable_rec_prefill_only(json.value_or("enable_rec_prefill_only", enable_rec_prefill_only())) - .enable_xattention_one_stage(json.value_or( - "enable_xattention_one_stage", enable_xattention_one_stage())) .max_decode_rounds( json.value_or("max_decode_rounds", max_decode_rounds())) .enable_constrained_decoding(json.value_or( diff --git a/xllm/core/framework/config/rec_config.h b/xllm/core/framework/config/rec_config.h index 0ec01d749f..4c2ac3cbec 100644 --- a/xllm/core/framework/config/rec_config.h +++ b/xllm/core/framework/config/rec_config.h @@ -40,7 +40,6 @@ class RecConfig final { "REC OPTIONS", {"enable_rec_fast_sampler", "enable_rec_prefill_only", - "enable_xattention_one_stage", "max_decode_rounds", "enable_constrained_decoding", "output_rec_logprobs", @@ -58,8 +57,6 @@ class RecConfig final { PROPERTY(bool, enable_rec_prefill_only) = false; - PROPERTY(bool, enable_xattention_one_stage) = false; - PROPERTY(int32_t, max_decode_rounds) = 0; PROPERTY(bool, enable_constrained_decoding) = false; diff --git a/xllm/core/layers/common/attention_metadata_builder.cpp b/xllm/core/layers/common/attention_metadata_builder.cpp index 19d7c52e5e..1beb0756db 100644 --- a/xllm/core/layers/common/attention_metadata_builder.cpp +++ b/xllm/core/layers/common/attention_metadata_builder.cpp @@ -22,7 +22,6 @@ limitations under the License. #include "attention_metadata.h" #include "core/common/global_flags.h" #include "core/framework/config/execution_config.h" -#include "core/framework/config/rec_config.h" #include "framework/model/model_args.h" #include "framework/model/model_input_params.h" @@ -166,44 +165,42 @@ AttentionMetadata build_attention_metadata( attn_metadata.step_tensor = llmrec_params.current_round_tensor; } - if (!::xllm::RecConfig::get_instance().enable_xattention_one_stage()) { #if defined(USE_CUDA) || defined(USE_MUSA) - attn_metadata.xattention_two_stage_decode_cache.emplace( - XAttentionTwoStageDecodeCache{}); - auto& cache = attn_metadata.xattention_two_stage_decode_cache.value(); - - cache.shared_lse = llmrec_params.two_stage_shared_lse; - cache.shared_o = llmrec_params.two_stage_shared_o; - cache.unshared_lse = llmrec_params.two_stage_unshared_lse; - cache.unshared_o = llmrec_params.two_stage_unshared_o; - cache.q_cu_seq_lens_shared = llmrec_params.two_stage_q_cu_seq_lens_shared; - cache.qo_indptr_expanded = llmrec_params.two_stage_qo_indptr_expanded; - cache.paged_kv_indptr_expanded = - llmrec_params.two_stage_paged_kv_indptr_expanded; - cache.paged_kv_indices_expanded = - llmrec_params.two_stage_paged_kv_indices_expanded; - cache.paged_kv_last_page_len_expanded = - llmrec_params.two_stage_paged_kv_last_page_len_expanded; - - if (cache.q_cu_seq_lens_shared.defined()) { - cache.cached_batch_size = - static_cast(cache.q_cu_seq_lens_shared.numel()) - 1; - } - cache.cached_beam_size = llmrec_params.beam_width; - if (!llmrec_params.unshared_k_caches.empty()) { - cache.cached_max_decode_step = - static_cast(llmrec_params.unshared_k_caches[0].size(2)); - } - if (cache.shared_o.defined() && cache.shared_o.dim() == 3) { - cache.cached_num_heads = static_cast(cache.shared_o.size(1)); - cache.cached_head_size = static_cast(cache.shared_o.size(2)); - } - if (llmrec_params.current_round_tensor.defined() && - llmrec_params.current_round_tensor.numel() > 0) { - cache.cached_step = llmrec_params.current_round_tensor.item(); - } -#endif + attn_metadata.xattention_two_stage_decode_cache.emplace( + XAttentionTwoStageDecodeCache{}); + auto& cache = attn_metadata.xattention_two_stage_decode_cache.value(); + + cache.shared_lse = llmrec_params.two_stage_shared_lse; + cache.shared_o = llmrec_params.two_stage_shared_o; + cache.unshared_lse = llmrec_params.two_stage_unshared_lse; + cache.unshared_o = llmrec_params.two_stage_unshared_o; + cache.q_cu_seq_lens_shared = llmrec_params.two_stage_q_cu_seq_lens_shared; + cache.qo_indptr_expanded = llmrec_params.two_stage_qo_indptr_expanded; + cache.paged_kv_indptr_expanded = + llmrec_params.two_stage_paged_kv_indptr_expanded; + cache.paged_kv_indices_expanded = + llmrec_params.two_stage_paged_kv_indices_expanded; + cache.paged_kv_last_page_len_expanded = + llmrec_params.two_stage_paged_kv_last_page_len_expanded; + + if (cache.q_cu_seq_lens_shared.defined()) { + cache.cached_batch_size = + static_cast(cache.q_cu_seq_lens_shared.numel()) - 1; + } + cache.cached_beam_size = llmrec_params.beam_width; + if (!llmrec_params.unshared_k_caches.empty()) { + cache.cached_max_decode_step = + static_cast(llmrec_params.unshared_k_caches[0].size(2)); } + if (cache.shared_o.defined() && cache.shared_o.dim() == 3) { + cache.cached_num_heads = static_cast(cache.shared_o.size(1)); + cache.cached_head_size = static_cast(cache.shared_o.size(2)); + } + if (llmrec_params.current_round_tensor.defined() && + llmrec_params.current_round_tensor.numel() > 0) { + cache.cached_step = llmrec_params.current_round_tensor.item(); + } +#endif } return attn_metadata; diff --git a/xllm/core/layers/cuda/xattention.cpp b/xllm/core/layers/cuda/xattention.cpp index 767a999285..d902d85a4c 100644 --- a/xllm/core/layers/cuda/xattention.cpp +++ b/xllm/core/layers/cuda/xattention.cpp @@ -18,7 +18,6 @@ limitations under the License. #include #include "core/common/global_flags.h" -#include "core/framework/config/rec_config.h" #include "core/framework/config/scheduler_config.h" #include "core/platform/device.h" #include "flashinfer_planinfo.h" @@ -44,73 +43,6 @@ XAttentionImpl::XAttentionImpl(int64_t num_heads, num_kv_heads, sliding_window) {} -void XAttentionImpl::run_single_stage_decode( - const AttentionMetadata& attn_metadata, - const torch::Tensor& key, - torch::Tensor& query, - torch::Tensor& output) { - torch::Tensor full_k_cache = attn_metadata.full_k_cache.unsqueeze(1); - torch::Tensor full_v_cache = attn_metadata.full_v_cache.unsqueeze(1); - - if (attn_metadata.enable_cuda_graph) { - CHECK(attn_metadata.plan_info->plan_info.defined()) - << "plan_info plan_info should not be null when enable_cuda_graph is " - "true"; - VLOG(kGraphExecutorLogVerboseLevel) - << "no need to update plan_info for CUDA graph"; - } else { - std::string backend = xllm::kernel::cuda::determine_attention_backend( - /*pos_encoding_mode=*/0, - /*use_fp16_qk_reduction=*/false, - /*use_custom_mask=*/false); - flashinfer::update_decode_plan_info( - attn_metadata.plan_info, - backend, - attn_metadata, - query.scalar_type(), - key.scalar_type(), - output.scalar_type(), - head_size_, - head_size_, - num_heads_, - num_kv_heads_, - /*block_size=*/full_k_cache.size(1), - /*window_size_left=*/sliding_window_, - /*enable_cuda_graph=*/false, - /*use_tensor_core=*/decode_use_tensor_core_); - } - - std::optional unshared_lse = std::nullopt; - - torch::Tensor float_workspace_buffer = - flashinfer::FlashinferWorkspace::get_instance() - .get_float_workspace_buffer(); - torch::Tensor int_workspace_buffer = - flashinfer::FlashinferWorkspace::get_instance() - .get_int_workspace_buffer(); - torch::Tensor page_locked_int_workspace_buffer = - flashinfer::FlashinferWorkspace::get_instance() - .get_page_locked_int_workspace_buffer(); - - xllm::kernel::cuda::batch_decode(attn_metadata.plan_info->uri, - attn_metadata.plan_info->plan_info, - float_workspace_buffer, - int_workspace_buffer, - page_locked_int_workspace_buffer, - query, - full_k_cache, - full_v_cache, - attn_metadata.paged_kv_indptr, - attn_metadata.paged_kv_indices, - attn_metadata.paged_kv_last_page_len, - sliding_window_, - scale_, - output, - unshared_lse, - decode_use_tensor_core_, - attn_metadata.qo_indptr); -} - void XAttentionImpl::run_two_stage_decode( const AttentionMetadata& attn_metadata, torch::Tensor& query, @@ -127,6 +59,8 @@ void XAttentionImpl::run_two_stage_decode( CHECK_EQ(total_beam % batch_size, 0) << "total_beam must be divisible by batch_size"; + CHECK(attn_metadata.xattention_two_stage_decode_cache.has_value()) + << "two-stage decode cache must be initialized."; const auto& cache = attn_metadata.xattention_two_stage_decode_cache.value(); CHECK(cache.shared_lse.defined() && cache.shared_o.defined() && cache.unshared_lse.defined() && cache.unshared_o.defined()) @@ -427,11 +361,7 @@ void XAttentionImpl::decoder_forward(const AttentionMetadata& attn_metadata, attn_metadata.unshared_k_cache, attn_metadata.unshared_v_cache, attn_metadata.step_tensor); - if (::xllm::RecConfig::get_instance().enable_xattention_one_stage()) { - run_single_stage_decode(attn_metadata, key, query, output); - } else { - run_two_stage_decode(attn_metadata, query, output); - } + run_two_stage_decode(attn_metadata, query, output); } } // namespace layer diff --git a/xllm/core/layers/cuda/xattention.h b/xllm/core/layers/cuda/xattention.h index 40e78c8693..fafc15c52b 100644 --- a/xllm/core/layers/cuda/xattention.h +++ b/xllm/core/layers/cuda/xattention.h @@ -57,10 +57,6 @@ class XAttentionImpl : public BaseAttentionImpl { torch::Tensor& key, torch::Tensor& value, torch::Tensor& output); - void run_single_stage_decode(const AttentionMetadata& attn_metadata, - const torch::Tensor& key, - torch::Tensor& query, - torch::Tensor& output); void run_two_stage_decode(const AttentionMetadata& attn_metadata, torch::Tensor& query, diff --git a/xllm/core/runtime/cuda_graph_executor_impl.cpp b/xllm/core/runtime/cuda_graph_executor_impl.cpp index 3d05c6146f..bafcfa0807 100644 --- a/xllm/core/runtime/cuda_graph_executor_impl.cpp +++ b/xllm/core/runtime/cuda_graph_executor_impl.cpp @@ -509,9 +509,7 @@ std::optional CudaGraphPersistentParam::update( const bool is_decode_with_llmrec = params.meta.batch_forward_type.is_decode() && params.has_llmrec_params(); - const bool use_two_stage_decode = - !::xllm::RecConfig::get_instance().enable_xattention_one_stage() && - is_decode_with_llmrec; + const bool use_two_stage_decode = is_decode_with_llmrec; const int32_t head_dim = args_.head_dim(); const int64_t tp_size = options_.world_size() / std::max(options_.dp_size(), 1); diff --git a/xllm/core/runtime/rec_worker_impl.cpp b/xllm/core/runtime/rec_worker_impl.cpp index 71b33c1555..817d0e9e2f 100644 --- a/xllm/core/runtime/rec_worker_impl.cpp +++ b/xllm/core/runtime/rec_worker_impl.cpp @@ -40,7 +40,6 @@ limitations under the License. #include "kernels/cuda/xattention/xattention_ops_api.h" #include "layers/cuda/flashinfer_workspace.h" #include "layers/cuda/xattention_workspace.h" -#include "platform/cuda/device_capture_lock.h" #endif #if defined(USE_NPU) #include "kernels/npu/npu_ops_api.h" @@ -1916,7 +1915,6 @@ RecWorkerImpl::LlmRecMultiRoundPipeline::LlmRecMultiRoundPipeline( : 0; beam_width_ = runtime_.worker.options_.beam_width(); - full_kv_cache_offsets_ = std::make_unique(this); allocate_kv_caches_related(); } @@ -1993,10 +1991,6 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline::allocate_kv_caches_related() { cached_current_round_tensor_ = torch::zeros({1}, int_options); cached_beam_width_tensor_ = torch::zeros({1}, int_options); - if (::xllm::RecConfig::get_instance().enable_xattention_one_stage()) { - return; - } - const int64_t num_heads = runtime_.context->get_model_args().n_heads(); const int64_t max_total_beam = static_cast(max_seqs_per_batch_) * beam_width_; @@ -2417,10 +2411,6 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline::prepare_two_stage_round_input( // TODO: implement prepare_two_stage_round_input for NPU #elif defined(USE_CUDA) auto& llm_rec_params = input.input_params.mutable_llmrec_params(); - CHECK_EQ(::xllm::RecConfig::get_instance().enable_xattention_one_stage(), - false) - << "prepare_two_stage_round_input should only be called when " - "two-stage decode is enabled"; input.input_params.attention.device.paged_kv_indices = torch::Tensor(); input.input_params.attention.device.paged_kv_indptr = torch::Tensor(); @@ -2566,19 +2556,9 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline::prepare_input_for_current_round( const torch::Tensor& top_tokens, const BeamSearchTensors& beam_tensors) { #if defined(USE_CUDA) - if (::xllm::RecConfig::get_instance().enable_xattention_one_stage()) { - input.input_params.attention.device.paged_kv_indices = - results.paged_kv_indices; - input.input_params.attention.device.paged_kv_indptr = - results.paged_kv_indptr; - input.input_params.attention.device.paged_kv_last_page_len = - results.paged_kv_last_page_len; - input.input_params.meta.num_sequences = - input.input_params.attention.device.paged_kv_last_page_len.numel(); - } else { - prepare_two_stage_round_input(input, round, top_tokens, beam_tensors); - return; - } + static_cast(results); + prepare_two_stage_round_input(input, round, top_tokens, beam_tensors); + return; #endif // previous_step corresponds to the decode step that produced tokens for // this round. @@ -2620,106 +2600,14 @@ RecWorkerImpl::LlmRecMultiRoundPipeline::compute_next_round_input_async( folly::Promise promise; auto future = promise.getSemiFuture(); -#if defined(USE_CUDA) - if (::xllm::RecConfig::get_instance().enable_xattention_one_stage()) { - // Capture necessary data for async computation - auto full_kv_offsets = full_kv_cache_offsets_->full_kv_offsets; - auto full_kv_mask = full_kv_cache_offsets_->full_kv_mask; - auto full_kv_indices = full_kv_cache_offsets_->full_kv_indices; - auto unshared_full_kv_offsets = full_kv_cache_offsets_->unshared_offsets; - auto real_max_decode_step_ids = full_kv_cache_offsets_->max_decode_step_ids; - uint32_t unshared_kv_begin_offset = max_tokens_per_batch_; - - // Launch async computation in thread pool (can overlap with GPU execution) - threadpool_.schedule([=, this, promise = std::move(promise)]() mutable { - auto device = runtime_.worker.device(); - auto int32_device_options = - torch::TensorOptions().dtype(torch::kInt32).device(device); - // Protect CUDA graph capture from conflicting GPU work submitted on - // prepare_stream_ while capture is in progress. Use shared lock to allow - // multiple prepare operations to run concurrently, but prevent conflicts - // with capture operations. This mirrors the NPU DeviceCaptureLock usage - // in WorkerImpl::prepare_work_before_execute. - std::optional> lock_guard; - if (runtime_.worker.options_.enable_graph()) { - auto& replay_lock = - ::xllm::cuda::DeviceCaptureLock::get_instance().get_read_lock( - runtime_.worker.device_.index()); - lock_guard.emplace(replay_lock); - } - - c10::StreamGuard streamGuard = - runtime_.worker.prepare_stream_->set_stream_guard(); - auto shared_kv_offsets = full_kv_offsets.slice(2, 0, max_token_per_req_) - .slice(0, 0, batch_size); - - auto shared_kv_lens_each_batch = torch::diff(kv_seq_lens); - - auto shared_kv_lens_each_batch_broadcast = - shared_kv_lens_each_batch.unsqueeze(1).unsqueeze(1); - - auto shared_mask = - full_kv_mask.slice(2, 0, max_token_per_req_).slice(0, 0, batch_size); - - shared_mask.copy_(shared_kv_offsets < - shared_kv_lens_each_batch_broadcast); - - auto kv_lens_batch_offsets = kv_seq_lens.slice(0, 0, -1); - - auto kv_lens_batch_offsets_broadcast = - kv_lens_batch_offsets.unsqueeze(1).unsqueeze(1); - - auto shared_kv_indices = full_kv_indices.slice(2, 0, max_token_per_req_) - .slice(0, 0, batch_size); - - shared_kv_indices.copy_(kv_lens_batch_offsets_broadcast + - shared_kv_offsets); - - auto unshared_kv_offsets = - unshared_full_kv_offsets.slice(0, 0, batch_size); - int32_t unshared_kv_len = beam_width * max_decode_step; - auto unshared_kv_indices = - full_kv_indices - .slice( - 2, max_token_per_req_, max_token_per_req_ + unshared_kv_len) - .slice(0, 0, batch_size); - unshared_kv_indices.copy_(unshared_kv_offsets + unshared_kv_begin_offset); - - auto unshared_mask = - full_kv_mask - .slice( - 2, max_token_per_req_, max_token_per_req_ + unshared_kv_len) - .slice(0, 0, batch_size); - auto real_max_decode_step_ids_slice = - real_max_decode_step_ids.slice(0, 0, batch_size); - unshared_mask.copy_(real_max_decode_step_ids_slice <= current_step); - - unshared_kv_len = current_step + 1; - - auto batch_beam_shared_kv_lens = - (shared_kv_lens_each_batch.unsqueeze(1).expand({-1, beam_width}) + - unshared_kv_len) - .flatten(); - auto cumsum_result = torch::cumsum(batch_beam_shared_kv_lens, 0); - auto paged_kv_indptr = - torch::cat({torch::zeros({1}, int32_device_options), - cumsum_result.to(int32_device_options)}, - 0); - auto paged_kv_indices = full_kv_indices.masked_select(full_kv_mask); - auto paged_kv_last_page_len = - torch::ones({batch_size * beam_width}, int32_device_options); - runtime_.worker.prepare_stream_->synchronize(); - - NextRoundInputResults results; - results.paged_kv_indices = paged_kv_indices; - results.paged_kv_indptr = paged_kv_indptr; - results.paged_kv_last_page_len = paged_kv_last_page_len; - promise.setValue(results); - }); - } else { - promise.setValue(NextRoundInputResults{}); - } -#endif + static_cast(kv_seq_lens); + static_cast(current_step); + static_cast(batch_size); + static_cast(beam_width); + static_cast(max_decode_step); + // CUDA now uses the two-stage path only, so the future is just a ready + // placeholder that keeps the round-preparation flow unchanged. + promise.setValue(NextRoundInputResults{}); return future; } @@ -2746,7 +2634,8 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline:: next_round_async_result.reset(); } - // Phase B: schedule async computation for the next round, if any. + // Phase B: prepare the next-round placeholder so the call flow stays + // consistent across backends. if (round < total_rounds - 1) { next_round_async_result = compute_next_round_input_async( input.input_params.attention.device.kv_seq_lens, @@ -2757,67 +2646,12 @@ void RecWorkerImpl::LlmRecMultiRoundPipeline:: } } -RecWorkerImpl::LlmRecMultiRoundPipeline::FullKvCacheOffsets::FullKvCacheOffsets( - LlmRecMultiRoundPipeline* multi_round_pipeline) { -#if defined(USE_NPU) -// TODO: implement FullKvCacheOffsets for NPU -#elif defined(USE_CUDA) - auto device = multi_round_pipeline->runtime().worker.device(); - auto int32_device_options = - torch::TensorOptions().dtype(torch::kInt32).device(device); - int32_t max_decode_step = get_rec_multi_round_decode_rounds() - 1; - full_kv_offsets = - torch::arange(0, - multi_round_pipeline->max_token_per_req_ + max_decode_step, - int32_device_options) - .unsqueeze(0) - .expand({multi_round_pipeline->max_seqs_per_batch_, -1}) - .unsqueeze(1) - .expand({-1, multi_round_pipeline->beam_width_, -1}); - full_kv_mask = - torch::zeros({multi_round_pipeline->max_seqs_per_batch_, - multi_round_pipeline->beam_width_, - multi_round_pipeline->max_token_per_req_ + max_decode_step}, - int32_device_options) - .to(torch::kBool); - full_kv_indices = torch::zeros_like(full_kv_offsets); - - auto batch_ids = - torch::arange( - 0, multi_round_pipeline->max_seqs_per_batch_, int32_device_options) - .unsqueeze(1) - .unsqueeze(2) - .expand({-1, multi_round_pipeline->beam_width_, max_decode_step}) * - (multi_round_pipeline->beam_width_ * max_decode_step); - - auto beams_ids = - torch::arange(0, multi_round_pipeline->beam_width_, int32_device_options) - .unsqueeze(0) - .unsqueeze(2) - .expand({multi_round_pipeline->max_seqs_per_batch_, - -1, - max_decode_step}) * - max_decode_step; - - max_decode_step_ids = torch::arange(0, max_decode_step, int32_device_options) - .unsqueeze(0) - .unsqueeze(1) - .expand({multi_round_pipeline->max_seqs_per_batch_, - multi_round_pipeline->beam_width_, - -1}); - unshared_offsets = batch_ids + beams_ids + max_decode_step_ids; -#endif -} - // ============================================================ // RecWorkerImpl Implementation // ============================================================ void RecWorkerImpl::initialize_xattention_workspace() { #if defined(USE_CUDA) - if (::xllm::RecConfig::get_instance().enable_xattention_one_stage()) { - return; - } ::xllm::layer::xattention::XAttentionWorkspace::get_instance().initialize( device_); #endif diff --git a/xllm/core/runtime/rec_worker_impl.h b/xllm/core/runtime/rec_worker_impl.h index 247f2fec41..6948eab8f2 100644 --- a/xllm/core/runtime/rec_worker_impl.h +++ b/xllm/core/runtime/rec_worker_impl.h @@ -261,18 +261,10 @@ class RecWorkerImpl : public LLMWorkerImpl { const BeamSearchTensors& beam_tensors, ForwardOutput& output); - // Structure to hold async computation results for next round input - struct NextRoundInputResults { -#if defined(USE_NPU) -// TODO: implement NextRoundInputResults for NPU -#elif defined(USE_CUDA) - torch::Tensor paged_kv_indices; - torch::Tensor paged_kv_indptr; - torch::Tensor paged_kv_last_page_len; -#endif - }; + // Placeholder for backend-specific round handoff state. + struct NextRoundInputResults {}; - // Compute next round input asynchronously (can overlap with GPU execution) + // Keep the round-transition flow uniform across backends. folly::SemiFuture compute_next_round_input_async( const torch::Tensor& kv_seq_lens, int32_t current_step, @@ -280,7 +272,7 @@ class RecWorkerImpl : public LLMWorkerImpl { int32_t beam_width, int32_t max_decode_step); - // Apply async result to prepare decode input for current round + // Prepare decode input for the current round. void prepare_input_for_current_round(ForwardInput& input, const NextRoundInputResults& results, int32_t round, @@ -315,21 +307,6 @@ class RecWorkerImpl : public LLMWorkerImpl { void prepare_kv_caches_related_for_input(const ForwardInput& inputs, ForwardInput& processed_inputs); - struct FullKvCacheOffsets { - explicit FullKvCacheOffsets( - LlmRecMultiRoundPipeline* multi_round_pipeline); -#if defined(USE_NPU) -// TODO: implement FullKvCacheOffsets for NPU -#elif defined(USE_CUDA) - torch::Tensor full_kv_offsets; - torch::Tensor full_kv_mask; - torch::Tensor full_kv_indices; - torch::Tensor unshared_offsets; - torch::Tensor max_decode_step_ids; -#endif - }; - std::unique_ptr full_kv_cache_offsets_; - std::vector cached_full_k_caches_; std::vector cached_full_v_caches_; torch::Tensor cached_naive_block_table_; @@ -347,9 +324,6 @@ class RecWorkerImpl : public LLMWorkerImpl { std::unique_ptr rec_sampler_; - // for async scheduler - ThreadPool threadpool_; - int32_t max_seqs_per_batch_; int32_t max_tokens_per_batch_; int32_t max_token_per_req_;