fix(eagle): avoid mutating drafter sequence lengths#318
Conversation
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: cee1c77f54
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| draft_first_step_reduce = forward_mode.is_decode() | ||
|
|
||
| if draft_first_step_reduce and self.attn_backend.support_kv_cache_prewrite: | ||
| # Trim seq_lens by rejected-draft count so the sliced decode | ||
| # query does not attend to dead positions. | ||
| correction = (self.spec_num_tokens - draft_input.accept_lengths).to( | ||
| self.draft_seq_lens_buf.dtype | ||
| ) | ||
| self.draft_seq_lens_buf[:bs].sub_(correction) | ||
|
|
||
| ctx = ForwardContext( |
There was a problem hiding this comment.
Restore accepted-length seq_lens for draft catch-up
When speculative decode rejects any candidate (accept_lengths < spec_num_tokens) and the drafter backend uses KV prewrite, draft_seq_lens_buf is copied from input_buffers.seq_lens_buf, which still includes the full verification width (valid_cache_lengths + spec_num_tokens). Without subtracting the rejected-token tail before _run_first_step, the reduced first draft query attends to dead rejected positions in the draft KV cache; the later multi-step path already recomputes cache_start from valid_cache_lengths + accept_lengths, so this mismatch only corrupts the first drafted token after partial rejection.
Useful? React with 👍 / 👎.
|
Thank you for the fix. I overlooked this case before. The correction was originally added for Llama Eagle3's decode-kernel direct dispatch optimization (Q-slice + decode kernel on the prewrite-decode catch-up step): since the decode kernel has no causal mask, the seq_lens correction is required there to prevent the catch-up query from attending to the previous multi-step's rejected positions. Qwen3.5 NextN and DeepSeek NextN don't use this optimization, but they currently enter the same branch, which is where the bug originates. I think this PR change alone can currently fix Qwen nextN, but it may cause minor errors in Llama Eagle3. |
Yeah, i think it does make sense. Then can delete the decode-kernel direct dispatch branch in if ctx.draft_first_step_reduce:
# KV already written via fused_set_kv_buffer_arg above; slice Q
# to one query per request and route attn as decode.
q_rope = q_rope.index_select(0, ctx.gather_ids)
attn_output = ctx.attn_backend.forward(
q_rope,
None,
None,
self.attn,
out_cache_loc,
ctx.token_to_kv_pool,
ForwardMode.DECODE,
ctx.bs,
save_kv_cache=False,
)
else:
attn_output = self.attn(
q_rope,
None,
None,
save_kv_cache=False,
ctx=ctx,
out_cache_loc=out_cache_loc,
)
And move this part to outside of if-else for post attn trim. |
Investigating the root cause, let's hold this PR for now. |
|
@XucSh will do a quick fix in a separate PR. And then work on refactoring. |

Summary
Test Plan