Skip to content

fix(eagle): avoid mutating drafter sequence lengths#318

Draft
XucSh wants to merge 1 commit into
mainfrom
Xuchun/fix
Draft

fix(eagle): avoid mutating drafter sequence lengths#318
XucSh wants to merge 1 commit into
mainfrom
Xuchun/fix

Conversation

@XucSh
Copy link
Copy Markdown
Contributor

@XucSh XucSh commented May 30, 2026

Summary

Test Plan

Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
@XucSh XucSh requested a review from a team as a code owner May 30, 2026 03:07
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: cee1c77f54

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines 223 to 225
draft_first_step_reduce = forward_mode.is_decode()

if draft_first_step_reduce and self.attn_backend.support_kv_cache_prewrite:
# Trim seq_lens by rejected-draft count so the sliced decode
# query does not attend to dead positions.
correction = (self.spec_num_tokens - draft_input.accept_lengths).to(
self.draft_seq_lens_buf.dtype
)
self.draft_seq_lens_buf[:bs].sub_(correction)

ctx = ForwardContext(
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Restore accepted-length seq_lens for draft catch-up

When speculative decode rejects any candidate (accept_lengths < spec_num_tokens) and the drafter backend uses KV prewrite, draft_seq_lens_buf is copied from input_buffers.seq_lens_buf, which still includes the full verification width (valid_cache_lengths + spec_num_tokens). Without subtracting the rejected-token tail before _run_first_step, the reduced first draft query attends to dead rejected positions in the draft KV cache; the later multi-step path already recomputes cache_start from valid_cache_lengths + accept_lengths, so this mismatch only corrupts the first drafted token after partial rejection.

Useful? React with 👍 / 👎.

@rjzhb
Copy link
Copy Markdown
Contributor

rjzhb commented May 30, 2026

Thank you for the fix. I overlooked this case before. The correction was originally added for Llama Eagle3's decode-kernel direct dispatch optimization (Q-slice + decode kernel on the prewrite-decode catch-up step): since the decode kernel has no causal mask, the seq_lens correction is required there to prevent the catch-up query from attending to the previous multi-step's rejected positions. Qwen3.5 NextN and DeepSeek NextN don't use this optimization, but they currently enter the same branch, which is where the bug originates. I think this PR change alone can currently fix Qwen nextN, but it may cause minor errors in Llama Eagle3.

@XucSh XucSh marked this pull request as draft May 30, 2026 04:39
@LorrinWWW
Copy link
Copy Markdown
Contributor

Maybe a tentative walkaround is to remove this code block (what this PR did), and also change llama_eagle to post-attn-trim to make sure there is no correctness issue in main. And we can bring it back when we have a clean solution to localize the trim. Wdyt? @XucSh @rjzhb

@rjzhb
Copy link
Copy Markdown
Contributor

rjzhb commented May 30, 2026

Maybe a tentative walkaround is to remove this code block (what this PR did), and also change llama_eagle to post-attn-trim to make sure there is no correctness issue in main. And we can bring it back when we have a clean solution to localize the trim. Wdyt? @XucSh @rjzhb

Yeah, i think it does make sense.

Then can delete the decode-kernel direct dispatch branch in python/tokenspeed/runtime/models/llama_eagle3.py:189-212, keeping only the else (standard PagedAttention call):

if ctx.draft_first_step_reduce:
    # KV already written via fused_set_kv_buffer_arg above; slice Q
    # to one query per request and route attn as decode.
    q_rope = q_rope.index_select(0, ctx.gather_ids)
    attn_output = ctx.attn_backend.forward(
        q_rope,
        None,
        None,
        self.attn,
        out_cache_loc,
        ctx.token_to_kv_pool,
        ForwardMode.DECODE,
        ctx.bs,
        save_kv_cache=False,
    )
else:
    attn_output = self.attn(
        q_rope,
        None,
        None,
        save_kv_cache=False,
        ctx=ctx,
        out_cache_loc=out_cache_loc,
    )
image

And move this part to outside of if-else for post attn trim.

@XucSh
Copy link
Copy Markdown
Contributor Author

XucSh commented May 30, 2026

Maybe a tentative walkaround is to remove this code block (what this PR did), and also change llama_eagle to post-attn-trim to make sure there is no correctness issue in main. And we can bring it back when we have a clean solution to localize the trim. Wdyt? @XucSh @rjzhb

Investigating the root cause, let's hold this PR for now.

@LorrinWWW
Copy link
Copy Markdown
Contributor

@XucSh will do a quick fix in a separate PR. And then work on refactoring.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants