Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 108 additions & 35 deletions tests/core/layers/cuda/xattention_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ limitations under the License.
#include <memory>
#include <vector>

#include "core/framework/config/rec_config.h"
#include "core/framework/config/scheduler_config.h"
#include "framework/kv_cache/kv_cache.h"
#include "layers/cuda/flashinfer_workspace.h"
Expand All @@ -40,16 +39,94 @@ struct DecodeTestInput {
torch::Tensor value;
};

// Problem dimensions for the xattention decode test. They live at namespace
// scope so that the free-standing reference implementation and the test
// fixture read the same values.

// Batch size and beam width of the test case.
// NOTE(review): presumably total beams == kTestBatchSize * kTestBeamWidth;
// confirm against create_decode_test_input.
constexpr int32_t kTestBatchSize = 4;
constexpr int32_t kTestBeamWidth = 128;
// Number of query heads and key/value heads; the reference implementation
// repeats each KV head kTestNumHeads / kTestNumKvHeads times.
constexpr int32_t kTestNumHeads = 16;
constexpr int32_t kTestNumKvHeads = 8;
// Size of one attention head; also determines the softmax scale.
constexpr int32_t kTestHeadDim = 128;
// Number of leading cache rows the reference treats as the shared prefix.
constexpr int32_t kTestSharedSeqLen = 300;
// Number of cache rows reserved per beam after the shared prefix.
constexpr int32_t kTestMaxDecodeStep = 2;
// Zero-based decode step under test: the new key/value is written at this
// offset inside a beam's rows, and rows [0, kTestCurrentStep] are attended.
constexpr int32_t kTestCurrentStep = 1;

// Computes the expected decode output with a straightforward per-beam
// attention loop evaluated in fp32.
//
// The caches in `input.attn_metadata` are cloned, so `input` is not modified.
// Rows [0, kTestSharedSeqLen) of the cloned caches are used as the prefix
// shared by every beam; beam `b` owns the kTestMaxDecodeStep rows starting at
// kTestSharedSeqLen + b * kTestMaxDecodeStep.
//
// Returns a tensor of shape {num_beams, kTestNumHeads * kTestHeadDim} cast
// back to the scalar type of `input.query`.
torch::Tensor build_reference_output(const DecodeTestInput& input) {
  const int64_t num_beams = input.query.size(0);
  const int64_t prefix_len = kTestSharedSeqLen;
  const int64_t visible_steps = kTestCurrentStep + 1;
  const int64_t context_len = prefix_len + visible_steps;
  const int64_t heads_per_kv = kTestNumHeads / kTestNumKvHeads;
  const float softmax_scale =
      1.0f / std::sqrt(static_cast<float>(kTestHeadDim));

  const auto fp32_opts = torch::TensorOptions()
                             .dtype(torch::kFloat32)
                             .device(input.query.device());

  // Private copies of the caches: the current step's key/value is written into
  // them below without touching the caller's tensors.
  torch::Tensor k_cache = input.attn_metadata.full_k_cache.clone();
  torch::Tensor v_cache = input.attn_metadata.full_v_cache.clone();
  torch::Tensor step_k =
      input.key.view({num_beams, kTestNumKvHeads, kTestHeadDim});
  torch::Tensor step_v =
      input.value.view({num_beams, kTestNumKvHeads, kTestHeadDim});

  // Place each beam's freshly produced key/value into its current-step row.
  for (int64_t b = 0; b < num_beams; ++b) {
    const int64_t write_row =
        prefix_len + b * kTestMaxDecodeStep + kTestCurrentStep;
    k_cache.select(0, write_row).copy_(step_k.select(0, b));
    v_cache.select(0, write_row).copy_(step_v.select(0, b));
  }

  // Assembles the fp32 context of one beam: the shared prefix followed by the
  // beam's own rows up to and including the current step.
  auto gather_context = [&](const torch::Tensor& cache,
                            int64_t beam_first_row) {
    torch::Tensor context = torch::empty(
        {context_len, kTestNumKvHeads, kTestHeadDim}, fp32_opts);
    context.slice(/*dim=*/0, /*start=*/0, /*end=*/prefix_len)
        .copy_(cache.slice(0, 0, prefix_len));
    context.slice(/*dim=*/0, /*start=*/prefix_len, /*end=*/context_len)
        .copy_(cache.slice(0, beam_first_row, beam_first_row + visible_steps));
    return context;
  };

  // Repeats every KV head `heads_per_kv` times so the head axis lines up with
  // the query heads.
  auto repeat_kv_heads = [&](const torch::Tensor& context) {
    return context.unsqueeze(2)
        .expand({context_len, kTestNumKvHeads, heads_per_kv, kTestHeadDim})
        .reshape({context_len, kTestNumHeads, kTestHeadDim});
  };

  torch::Tensor q_fp32 =
      input.query.view({num_beams, kTestNumHeads, kTestHeadDim})
          .to(torch::kFloat32);
  torch::Tensor result =
      torch::empty({num_beams, kTestNumHeads * kTestHeadDim}, fp32_opts);

  for (int64_t b = 0; b < num_beams; ++b) {
    const int64_t beam_first_row = prefix_len + b * kTestMaxDecodeStep;
    torch::Tensor k_heads =
        repeat_kv_heads(gather_context(k_cache, beam_first_row));
    torch::Tensor v_heads =
        repeat_kv_heads(gather_context(v_cache, beam_first_row));

    // Scaled dot-product attention of this beam's single query token.
    torch::Tensor logits =
        torch::einsum("hd,shd->hs", {q_fp32.select(0, b), k_heads}) *
        softmax_scale;
    torch::Tensor probs = torch::softmax(logits, /*dim=*/-1);
    torch::Tensor beam_out = torch::einsum("hs,shd->hd", {probs, v_heads})
                                 .reshape({kTestNumHeads * kTestHeadDim});
    result.select(0, b).copy_(beam_out);
  }

  return result.to(input.query.scalar_type());
}

class XAttentionDecodeCompareTest : public ::testing::Test {
protected:
static constexpr int32_t kBatchSize = 4;
static constexpr int32_t kBeamWidth = 128;
static constexpr int32_t kNumHeads = 16;
static constexpr int32_t kNumKvHeads = 8;
static constexpr int32_t kHeadDim = 128;
static constexpr int32_t kSharedSeqLen = 300;
static constexpr int32_t kMaxDecodeStep = 2;
static constexpr int32_t kCurrentStep = 1;
static constexpr int32_t kBatchSize = kTestBatchSize;
static constexpr int32_t kBeamWidth = kTestBeamWidth;
static constexpr int32_t kNumHeads = kTestNumHeads;
static constexpr int32_t kNumKvHeads = kTestNumKvHeads;
static constexpr int32_t kHeadDim = kTestHeadDim;
static constexpr int32_t kSharedSeqLen = kTestSharedSeqLen;
static constexpr int32_t kMaxDecodeStep = kTestMaxDecodeStep;
static constexpr int32_t kCurrentStep = kTestCurrentStep;

void SetUp() override {
if (!torch::cuda::is_available()) {
Expand Down Expand Up @@ -199,8 +276,7 @@ class XAttentionDecodeCompareTest : public ::testing::Test {
return input;
}

torch::Tensor run_decode_once(DecodeTestInput& input, bool enable_two_stage) {
RecConfig::get_instance().enable_xattention_one_stage(!enable_two_stage);
torch::Tensor run_two_stage_decode_once(DecodeTestInput& input) {
SchedulerConfig::get_instance().max_tokens_per_batch(kSharedSeqLen);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The max_tokens_per_batch configuration is set to kSharedSeqLen (300), but the test setup in create_decode_test_input (lines 237-241) defines kv_cu_seq_lens reaching up to kBatchSize * kSharedSeqLen (1200). In the two-stage decode implementation (xattention.cpp), max_tokens_per_batch is used as the boundary between shared and unshared KV cache. This mismatch causes the shared prompts to overlap with the unshared beam-specific cache area, which will lead to incorrect attention results or memory corruption during testing.


XAttentionImpl attention(
Expand All @@ -224,33 +300,30 @@ class XAttentionDecodeCompareTest : public ::testing::Test {
return std::get<0>(result).clone();
}

void compare_single_and_two_stage(torch::ScalarType dtype,
double atol,
double rtol) {
void compare_two_stage_with_reference(torch::ScalarType dtype,
double atol,
double rtol) {
constexpr int64_t kSeed = 20260303;
torch::manual_seed(kSeed);
torch::cuda::manual_seed_all(kSeed);
auto single_input = create_decode_test_input(dtype);
auto input = create_decode_test_input(dtype);
torch::manual_seed(kSeed);
torch::cuda::manual_seed_all(kSeed);
auto two_stage_input = create_decode_test_input(dtype);

two_stage_input.query.copy_(single_input.query);
two_stage_input.key.copy_(single_input.key);
two_stage_input.value.copy_(single_input.value);
auto reference_input = create_decode_test_input(dtype);
reference_input.query.copy_(input.query);
reference_input.key.copy_(input.key);
reference_input.value.copy_(input.value);

auto single_output =
run_decode_once(single_input, /*enable_two_stage=*/false);
auto two_stage_output =
run_decode_once(two_stage_input, /*enable_two_stage=*/true);
auto reference_output = build_reference_output(reference_input);
auto two_stage_output = run_two_stage_decode_once(input);

auto abs_diff =
(single_output - two_stage_output).abs().to(torch::kFloat32);
(reference_output - two_stage_output).abs().to(torch::kFloat32);
const double max_abs_diff = abs_diff.max().item<double>();
const double mean_abs_diff = abs_diff.mean().item<double>();

EXPECT_TRUE(torch::allclose(single_output, two_stage_output, rtol, atol))
<< "single-stage and two-stage decode outputs mismatch: "
EXPECT_TRUE(torch::allclose(reference_output, two_stage_output, rtol, atol))
<< "reference and two-stage decode outputs mismatch: "
<< "max_abs_diff=" << max_abs_diff
<< ", mean_abs_diff=" << mean_abs_diff << ", atol=" << atol
<< ", rtol=" << rtol;
Expand All @@ -263,16 +336,16 @@ class XAttentionDecodeCompareTest : public ::testing::Test {
torch::Device device_{torch::kCPU};
};

TEST_F(XAttentionDecodeCompareTest, SingleVsTwoStageFp16) {
compare_single_and_two_stage(torch::kFloat16,
/*atol=*/2e-3,
/*rtol=*/2e-3);
TEST_F(XAttentionDecodeCompareTest, TwoStageFp16) {
compare_two_stage_with_reference(torch::kFloat16,
/*atol=*/2e-3,
/*rtol=*/2e-3);
}

TEST_F(XAttentionDecodeCompareTest, SingleVsTwoStageBf16) {
compare_single_and_two_stage(torch::kBFloat16,
/*atol=*/2e-2,
/*rtol=*/2e-2);
TEST_F(XAttentionDecodeCompareTest, TwoStageBf16) {
compare_two_stage_with_reference(torch::kBFloat16,
/*atol=*/2e-2,
/*rtol=*/2e-2);
}

} // namespace
Expand Down
1 change: 0 additions & 1 deletion xllm/c_api/default.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ const XLLM_InitOptions XLLM_INIT_REC_OPTIONS_DEFAULT = {
.enable_graph = true,
.enable_rec_fast_sampler = true,
.enable_prefill_piecewise_graph = true,
.enable_xattention_one_stage = false,
.enable_graph_mode_decode_no_padding = true,
.enable_block_copy_kernel = false,
.enable_topk_sorted = false,
Expand Down
4 changes: 1 addition & 3 deletions xllm/c_api/internal/rec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ const char* get_rec_pipeline_name(xllm::RecPipelineType pipeline_type) {
void reset_pipeline_runtime_toggles() {
xllm::RecConfig::get_instance()
.enable_rec_fast_sampler(false)
.enable_xattention_one_stage(false)
.enable_rec_prefill_only(false)
.enable_constrained_decoding(false);
xllm::ExecutionConfig::get_instance()
Expand All @@ -70,8 +69,7 @@ void reset_pipeline_runtime_toggles() {

void apply_multi_round_pipeline_toggles() {
xllm::RecConfig::get_instance()
.enable_rec_fast_sampler(true)
.enable_xattention_one_stage(false);
.enable_rec_fast_sampler(true);
xllm::ExecutionConfig::get_instance()
.enable_prefill_piecewise_graph(true)
.enable_graph_mode_decode_no_padding(true);
Expand Down
3 changes: 0 additions & 3 deletions xllm/c_api/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@ typedef struct XLLM_CAPI_EXPORT XLLM_InitOptions {
/** Whether to enable prefill piecewise graph for REC */
bool enable_prefill_piecewise_graph;

/** Whether to enable xattention one-stage execution for REC */
bool enable_xattention_one_stage;

/** Whether to enable graph-mode decode without padding for REC */
bool enable_graph_mode_decode_no_padding;

Expand Down
2 changes: 0 additions & 2 deletions xllm/core/common/global_flags.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,6 @@ DECLARE_bool(enable_rec_prefill_only);

DECLARE_bool(output_rec_logprobs);

DECLARE_bool(enable_xattention_one_stage);

DECLARE_int32(max_decode_rounds);

DECLARE_bool(enable_constrained_decoding);
Expand Down
8 changes: 0 additions & 8 deletions xllm/core/framework/config/rec_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,6 @@ DEFINE_bool(enable_rec_prefill_only,
"Enable rec prefill-only mode (no decoder self-attention blocks "
"allocation).");

DEFINE_bool(enable_xattention_one_stage,
false,
"Whether to force xattention one-stage decode for rec "
"multi-round mode.");

DEFINE_int32(max_decode_rounds,
0,
"Maximum number of decode rounds for multi-step decoding. "
Expand Down Expand Up @@ -85,7 +80,6 @@ namespace xllm {
void RecConfig::from_flags() {
enable_rec_fast_sampler(FLAGS_enable_rec_fast_sampler)
.enable_rec_prefill_only(FLAGS_enable_rec_prefill_only)
.enable_xattention_one_stage(FLAGS_enable_xattention_one_stage)
.max_decode_rounds(FLAGS_max_decode_rounds)
.enable_constrained_decoding(FLAGS_enable_constrained_decoding)
.output_rec_logprobs(FLAGS_output_rec_logprobs)
Expand All @@ -103,8 +97,6 @@ void RecConfig::from_json(const JsonReader& json) {
json.value_or<bool>("enable_rec_fast_sampler", enable_rec_fast_sampler()))
.enable_rec_prefill_only(json.value_or<bool>("enable_rec_prefill_only",
enable_rec_prefill_only()))
.enable_xattention_one_stage(json.value_or<bool>(
"enable_xattention_one_stage", enable_xattention_one_stage()))
.max_decode_rounds(
json.value_or<int32_t>("max_decode_rounds", max_decode_rounds()))
.enable_constrained_decoding(json.value_or<bool>(
Expand Down
3 changes: 0 additions & 3 deletions xllm/core/framework/config/rec_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ class RecConfig final {
"REC OPTIONS",
{"enable_rec_fast_sampler",
"enable_rec_prefill_only",
"enable_xattention_one_stage",
"max_decode_rounds",
"enable_constrained_decoding",
"output_rec_logprobs",
Expand All @@ -58,8 +57,6 @@ class RecConfig final {

PROPERTY(bool, enable_rec_prefill_only) = false;

PROPERTY(bool, enable_xattention_one_stage) = false;

PROPERTY(int32_t, max_decode_rounds) = 0;

PROPERTY(bool, enable_constrained_decoding) = false;
Expand Down
71 changes: 34 additions & 37 deletions xllm/core/layers/common/attention_metadata_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ limitations under the License.
#include "attention_metadata.h"
#include "core/common/global_flags.h"
#include "core/framework/config/execution_config.h"
#include "core/framework/config/rec_config.h"
#include "framework/model/model_args.h"
#include "framework/model/model_input_params.h"

Expand Down Expand Up @@ -166,44 +165,42 @@ AttentionMetadata build_attention_metadata(
attn_metadata.step_tensor = llmrec_params.current_round_tensor;
}

if (!::xllm::RecConfig::get_instance().enable_xattention_one_stage()) {
#if defined(USE_CUDA) || defined(USE_MUSA)
attn_metadata.xattention_two_stage_decode_cache.emplace(
XAttentionTwoStageDecodeCache{});
auto& cache = attn_metadata.xattention_two_stage_decode_cache.value();

cache.shared_lse = llmrec_params.two_stage_shared_lse;
cache.shared_o = llmrec_params.two_stage_shared_o;
cache.unshared_lse = llmrec_params.two_stage_unshared_lse;
cache.unshared_o = llmrec_params.two_stage_unshared_o;
cache.q_cu_seq_lens_shared = llmrec_params.two_stage_q_cu_seq_lens_shared;
cache.qo_indptr_expanded = llmrec_params.two_stage_qo_indptr_expanded;
cache.paged_kv_indptr_expanded =
llmrec_params.two_stage_paged_kv_indptr_expanded;
cache.paged_kv_indices_expanded =
llmrec_params.two_stage_paged_kv_indices_expanded;
cache.paged_kv_last_page_len_expanded =
llmrec_params.two_stage_paged_kv_last_page_len_expanded;

if (cache.q_cu_seq_lens_shared.defined()) {
cache.cached_batch_size =
static_cast<int32_t>(cache.q_cu_seq_lens_shared.numel()) - 1;
}
cache.cached_beam_size = llmrec_params.beam_width;
if (!llmrec_params.unshared_k_caches.empty()) {
cache.cached_max_decode_step =
static_cast<int32_t>(llmrec_params.unshared_k_caches[0].size(2));
}
if (cache.shared_o.defined() && cache.shared_o.dim() == 3) {
cache.cached_num_heads = static_cast<int32_t>(cache.shared_o.size(1));
cache.cached_head_size = static_cast<int32_t>(cache.shared_o.size(2));
}
if (llmrec_params.current_round_tensor.defined() &&
llmrec_params.current_round_tensor.numel() > 0) {
cache.cached_step = llmrec_params.current_round_tensor.item<int32_t>();
}
#endif
attn_metadata.xattention_two_stage_decode_cache.emplace(
XAttentionTwoStageDecodeCache{});
auto& cache = attn_metadata.xattention_two_stage_decode_cache.value();

cache.shared_lse = llmrec_params.two_stage_shared_lse;
cache.shared_o = llmrec_params.two_stage_shared_o;
cache.unshared_lse = llmrec_params.two_stage_unshared_lse;
cache.unshared_o = llmrec_params.two_stage_unshared_o;
cache.q_cu_seq_lens_shared = llmrec_params.two_stage_q_cu_seq_lens_shared;
cache.qo_indptr_expanded = llmrec_params.two_stage_qo_indptr_expanded;
cache.paged_kv_indptr_expanded =
llmrec_params.two_stage_paged_kv_indptr_expanded;
cache.paged_kv_indices_expanded =
llmrec_params.two_stage_paged_kv_indices_expanded;
cache.paged_kv_last_page_len_expanded =
llmrec_params.two_stage_paged_kv_last_page_len_expanded;

if (cache.q_cu_seq_lens_shared.defined()) {
cache.cached_batch_size =
static_cast<int32_t>(cache.q_cu_seq_lens_shared.numel()) - 1;
}
cache.cached_beam_size = llmrec_params.beam_width;
if (!llmrec_params.unshared_k_caches.empty()) {
cache.cached_max_decode_step =
static_cast<int32_t>(llmrec_params.unshared_k_caches[0].size(2));
}
if (cache.shared_o.defined() && cache.shared_o.dim() == 3) {
cache.cached_num_heads = static_cast<int32_t>(cache.shared_o.size(1));
cache.cached_head_size = static_cast<int32_t>(cache.shared_o.size(2));
}
if (llmrec_params.current_round_tensor.defined() &&
llmrec_params.current_round_tensor.numel() > 0) {
cache.cached_step = llmrec_params.current_round_tensor.item<int32_t>();
}
#endif
}

return attn_metadata;
Expand Down
Loading
Loading