refactor: remove xattention one-stage decode path.#1504
Conversation
There was a problem hiding this comment.
Code Review
This pull request removes the one-stage decode execution path for XAttention, standardizing the implementation on the two-stage approach. The changes include the removal of the enable_xattention_one_stage configuration flag from the C API and core framework, the deletion of the run_single_stage_decode method, and the simplification of the REC worker's asynchronous input preparation logic. Feedback for this PR highlights a critical issue in the updated unit tests where the max_tokens_per_batch configuration is inconsistent with the sequence lengths used in the test input, potentially leading to memory corruption or incorrect attention results during validation.
| torch::Tensor run_decode_once(DecodeTestInput& input, bool enable_two_stage) { | ||
| RecConfig::get_instance().enable_xattention_one_stage(!enable_two_stage); | ||
| torch::Tensor run_two_stage_decode_once(DecodeTestInput& input) { | ||
| SchedulerConfig::get_instance().max_tokens_per_batch(kSharedSeqLen); |
There was a problem hiding this comment.
The max_tokens_per_batch configuration is set to kSharedSeqLen (300), but the test setup in create_decode_test_input (lines 237-241) defines kv_cu_seq_lens reaching up to kBatchSize * kSharedSeqLen (1200). In the two-stage decode implementation (xattention.cpp), max_tokens_per_batch is used as the boundary between shared and unshared KV cache. This mismatch causes the shared prompts to overlap with the unshared beam-specific cache area, which will lead to incorrect attention results or memory corruption during testing.
No description provided.