fix(trtllm-mla): make spec-decode CUDA graph capture causal #260

Open
mesaleh wants to merge 1 commit into lightseekorg:main from mesaleh:fix/trtllm-mla-cudagraph-capture

Conversation

@mesaleh mesaleh commented May 26, 2026

Summary

  • initialize spec-decode CUDA graph capture with internally consistent synthetic sequence lengths and cache slots
  • point TRT-LLM MLA capture block tables at the same synthetic per-request pages and zero those pages before capture
  • flatten multi-token TRT-LLM MLA decode into per-token decode queries so target verification and draft catch-up use causal sequence bounds per token
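
As an illustration of the third bullet, here is a minimal sketch of the per-token expansion in plain Python lists. The helper name `per_token_seq_lens` is hypothetical and not taken from the PR; it assumes, as the in-diff comment states, that each incoming sequence length points at the end of the speculative window.

```python
def per_token_seq_lens(seq_lens, q_len_per_req):
    """Expand end-of-window sequence lengths into one causal bound per query token.

    Hypothetical helper for illustration; the PR operates on tensors instead.
    """
    flat = []
    for end_len in seq_lens:
        # Bound seen by the first query token of this request.
        base = end_len - (q_len_per_req - 1)
        # Each later token may attend to one more position than the previous one.
        flat.extend(base + t for t in range(q_len_per_req))
    return flat


# Two requests, three query tokens each.
bounds = per_token_seq_lens([10, 7], q_len_per_req=3)
print(bounds)  # [8, 9, 10, 5, 6, 7]
```

Each flattened entry can then be treated as a single-token decode query whose sequence length already encodes the causal mask.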

Motivation

While bringing up Kimi K2.6 NVFP4 with EAGLE3 speculative decoding and --attention-backend trtllm_mla, CUDA graph startup could enter an impossible synthetic state: multi-token decode capture used grouped metadata that did not match the dummy KV/cache locations. This led to capture-time instability and downstream invalid draft ids.

This PR is intentionally independent of #217. It does not touch the EAGLE first-step reduce plumbing and merge-tests cleanly with pull/217/head.

Validation

  • python3 -m compileall -q python/tokenspeed/runtime/execution/cuda_graph_wrapper.py python/tokenspeed/runtime/layers/attention/backends/trtllm_mla.py
  • Local GB200 validation: Kimi K2.6 NVFP4 + EAGLE3 + trtllm_mla captured batch sizes [1, 2, 3, 4, 5, 6, 7, 8] and reached healthy readiness
  • OpenAI-compatible synthetic benchmark completed 10/10 requests with p50 TPOT 3.760 ms and mean decode 264.676 tok/s

@mesaleh mesaleh requested a review from a team as a code owner May 26, 2026 10:11

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 476a102cec

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

# Target verification receives seq_lens at the end of the
# speculative window. Convert to per-token bounds:
# [valid+1, valid+2, ..., valid+q_len].
base_lens = base_lens - (q_len_per_req - 1)

P1: Guard per-token seq_lens against padded decode rows

In the new non-draft multi-token path, base_lens is shifted by q_len_per_req - 1, but CUDA-graph replay pads synthetic rows with seq_lens=1 when bs is rounded up (see cuda_graph_wrapper.py padding path). For any padded replay with q_len_per_req > 1, this produces zero/negative seq_lens for the padded rows, which are then passed into trtllm_batch_decode_with_kv_cache_mla; that can cause invalid causal bounds or kernel misbehavior for non-captured batch sizes (e.g., bs=3 padded to 4).

Author

Fixed in d57fe3d. This was valid: CUDA graph replay can pad target decode rows with seq_len=1, and the non-draft multi-token path shifted that value backward by q_len_per_req - 1 before expanding per-token bounds. TRT-LLM MLA now clamps the shifted dummy-row base length to at least one before adding per-token offsets, so padded rows stay valid while real verification rows keep the intended causal sequence lengths. I also amended the PR commit metadata/signoff to use the GitHub noreply identity.
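
A minimal sketch of the clamp described above, using plain Python lists rather than the tensors in the actual backend. The function name `clamped_per_token_seq_lens` and the `min_len` parameter are hypothetical, chosen only to show why a padded row with seq_len=1 needs the floor.

```python
def clamped_per_token_seq_lens(seq_lens, q_len_per_req, min_len=1):
    """Per-token causal bounds with the shifted base floored at min_len.

    Hypothetical helper for illustration; not the PR's implementation.
    """
    flat = []
    for end_len in seq_lens:
        # Without max(), a padded row (end_len == 1) would shift to a
        # zero or negative base when q_len_per_req > 1.
        base = max(end_len - (q_len_per_req - 1), min_len)
        flat.extend(base + t for t in range(q_len_per_req))
    return flat


# bs=3 padded to 4: the last row is a synthetic padding row with seq_len=1.
rows = [12, 9, 15, 1]
print(clamped_per_token_seq_lens(rows, q_len_per_req=4))
# Real rows are unchanged; the padded row becomes [1, 2, 3, 4]
# instead of the unclamped [-2, -1, 0, 1].
```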


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: d57fe3d9bd

Comment on lines +368 to +372
* pages_per_req
* page_size
)
self.input_buffers.out_cache_loc_buf[: bs * tokens_per_req].copy_(
(request_offsets + token_offsets).reshape(-1)

P1: Bound synthetic capture offsets to allocated KV range

This synthetic out_cache_loc layout assumes there are at least bs * pages_per_req * page_size writable token slots, but capture sizes come from max_num_seqs while KV capacity can be much smaller (for example with low max_total_tokens or CI-small KV settings). In that case these computed locations exceed the backing KV buffer, and set_mla_kv_buffer later writes with out-of-range indices during warmup/capture, causing startup-time failures. The previous capture path stayed on the dummy slot and did not require this extra capacity, so this is a new regression for constrained-memory configurations.
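
The capacity concern can be checked with simple arithmetic. The sketch below assumes a per-request stride of `pages_per_req * page_size` and consecutive token offsets within each request; the full expression is not visible in the quoted diff, so both helper functions and the exact layout are assumptions made for illustration.

```python
def synthetic_out_cache_locs(bs, tokens_per_req, pages_per_req, page_size):
    """Private-page layout: each request starts at its own page-aligned offset.

    Assumed layout for illustration; the diff shows only part of the expression.
    """
    locs = []
    for req in range(bs):
        request_offset = req * pages_per_req * page_size
        locs.extend(request_offset + tok for tok in range(tokens_per_req))
    return locs


def fits_in_kv(locs, kv_token_slots):
    """True if every synthetic slot index is inside the KV buffer."""
    return max(locs) < kv_token_slots


locs = synthetic_out_cache_locs(bs=8, tokens_per_req=4, pages_per_req=2, page_size=64)
print(max(locs))               # 899
print(fits_in_kv(locs, 1024))  # True
print(fits_in_kv(locs, 512))   # False: a small KV buffer is overrun
```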

Author

Fixed in e7df291. This was valid: the synthetic private-page layout made capture require bs * pages_per_req padding slots even though constrained KV configurations may not have that much capacity. Capture now uses the reserved padding page for all synthetic rows, writes dummy KV only to that bounded page, zeroes that page before warmup/capture, and has TRT-LLM MLA capture block tables point at the same dummy page. Replay still refreshes real request block tables from req_to_page, so this only affects synthetic capture.
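
A rough sketch of the bounded layout described above, in plain Python. The function `dummy_page_layout`, the `dummy_page` index, and the wrap-around of token offsets inside the page are all assumptions for illustration; the comment only states that every synthetic row uses the single reserved padding page.

```python
def dummy_page_layout(bs, tokens_per_req, pages_per_req, page_size, dummy_page=0):
    """Point every synthetic row at one reserved page.

    Hypothetical helper; the modulo wrap is an assumption, not from the PR.
    """
    page_start = dummy_page * page_size
    # All requests write into the same page, so capacity needs do not grow with bs.
    out_cache_loc = [
        page_start + (tok % page_size)
        for _ in range(bs)
        for tok in range(tokens_per_req)
    ]
    # Capture-time block tables reference only the dummy page.
    block_table = [[dummy_page] * pages_per_req for _ in range(bs)]
    return out_cache_loc, block_table


locs, table = dummy_page_layout(bs=8, tokens_per_req=4, pages_per_req=2, page_size=64)
print(max(locs))  # 3
print(table[0])   # [0, 0]
```

With this shape the largest synthetic index stays below `page_size` for any batch size, which is the property the fix relies on for constrained KV configurations.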

Signed-off-by: Moustafa Saleh <8815169+mesaleh@users.noreply.github.com>
@mesaleh mesaleh force-pushed the fix/trtllm-mla-cudagraph-capture branch from d57fe3d to e7df291 Compare May 26, 2026 19:36