fix(trtllm-mla): make spec-decode CUDA graph capture causal #260
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 476a102cec
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
    # Target verification receives seq_lens at the end of the
    # speculative window. Convert to per-token bounds:
    # [valid+1, valid+2, ..., valid+q_len].
    base_lens = base_lens - (q_len_per_req - 1)
Guard per-token seq_lens against padded decode rows
In the new non-draft multi-token path, base_lens is shifted by q_len_per_req - 1, but CUDA-graph replay pads synthetic rows with seq_lens=1 when bs is rounded up (see cuda_graph_wrapper.py padding path). For any padded replay with q_len_per_req > 1, this produces zero/negative seq_lens for the padded rows, which are then passed into trtllm_batch_decode_with_kv_cache_mla; that can cause invalid causal bounds or kernel misbehavior for non-captured batch sizes (e.g., bs=3 padded to 4).
Fixed in d57fe3d. This was valid: CUDA graph replay can pad target decode rows with seq_len=1, and the non-draft multi-token path shifted that value backward by q_len_per_req - 1 before expanding per-token bounds. TRTLLM MLA now clamps the shifted dummy-row base length to at least one before adding per-token offsets, so padded rows stay valid while real verification rows keep the intended causal sequence lengths. I also amended the PR commit metadata/signoff to use the GitHub noreply identity.
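The clamp described in this reply can be illustrated with a minimal sketch. It uses plain Python lists rather than tensors, and the function name `per_token_seq_lens` is hypothetical; it is not the code in `trtllm_mla.py`, only an illustration of shifting, clamping, and then expanding per-token bounds.

```python
def per_token_seq_lens(seq_lens, q_len_per_req):
    """Expand end-of-window seq_lens into one causal bound per query token.

    Hypothetical helper for illustration; not the PR's implementation.
    """
    bounds = []
    for end_len in seq_lens:
        # Shift back from the end of the speculative window to the
        # bound of the first query token (valid + 1).
        base = end_len - (q_len_per_req - 1)
        # Padded CUDA-graph rows arrive with seq_len=1, so the shift can
        # go to zero or below; clamp so every bound stays >= 1.
        base = max(base, 1)
        # Per-token bounds: base, base + 1, ..., base + q_len_per_req - 1.
        bounds.extend(base + offset for offset in range(q_len_per_req))
    return bounds


# One real row (seq_len=10) and one padded row (seq_len=1), four tokens each.
example_bounds = per_token_seq_lens([10, 1], q_len_per_req=4)
```

For the real row the bounds end at the original `seq_len`, while the padded row gets small positive bounds instead of zero or negative ones.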
476a102 to d57fe3d
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: d57fe3d9bd
        * pages_per_req
        * page_size
    )
    self.input_buffers.out_cache_loc_buf[: bs * tokens_per_req].copy_(
        (request_offsets + token_offsets).reshape(-1)
Bound synthetic capture offsets to allocated KV range
This synthetic out_cache_loc layout assumes there are at least bs * pages_per_req * page_size writable token slots, but capture sizes come from max_num_seqs while KV capacity can be much smaller (for example with low max_total_tokens or CI-small KV settings). In that case these computed locations exceed the backing KV buffer, and set_mla_kv_buffer later writes with out-of-range indices during warmup/capture, causing startup-time failures. The previous capture path stayed on the dummy slot and did not require this extra capacity, so this is a new regression for constrained-memory configurations.
Fixed in e7df291. This was valid: the synthetic private-page layout made capture require bs * pages_per_req padding slots even though constrained KV configurations may not have that much capacity. Capture now uses the reserved padding page for all synthetic rows, writes dummy KV only to that bounded page, zeroes that page before warmup/capture, and has TRTLLM MLA capture block tables point at the same dummy page. Replay still refreshes real request block tables from req_to_page, so this only affects synthetic capture.
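A rough sketch of why the two layouts differ in KV capacity requirements, using plain Python lists. Both function names, the `padding_page` parameter, and the wrap-within-page indexing are assumptions made for illustration; the actual code in `cuda_graph_wrapper.py` may compute these locations differently.

```python
def private_page_out_cache_loc(bs, tokens_per_req, pages_per_req, page_size):
    """Assumed shape of the earlier layout: each synthetic row owns its pages."""
    return [
        req * pages_per_req * page_size + token
        for req in range(bs)
        for token in range(tokens_per_req)
    ]


def dummy_page_out_cache_loc(bs, tokens_per_req, page_size, padding_page=0):
    """Assumed shape of the fix: every synthetic row writes into one reserved page."""
    page_start = padding_page * page_size
    # Wrap token positions inside the page so no index leaves it,
    # whatever bs or tokens_per_req happen to be.
    return [
        page_start + (token % page_size)
        for _ in range(bs)
        for token in range(tokens_per_req)
    ]


# Illustrative sizes only: 8 captured rows, 4 tokens each, 64-slot pages.
private = private_page_out_cache_loc(bs=8, tokens_per_req=4, pages_per_req=1, page_size=64)
dummy = dummy_page_out_cache_loc(bs=8, tokens_per_req=4, page_size=64)
```

Under these assumptions the private-page layout touches slots that grow with `bs * pages_per_req * page_size`, while the dummy-page layout never indexes past the single reserved page, so it cannot exceed a small KV buffer.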
Signed-off-by: Moustafa Saleh <8815169+mesaleh@users.noreply.github.com>
d57fe3d to e7df291
Summary
Motivation
While bringing up Kimi K2.6 NVFP4 with EAGLE3 speculative decoding and --attention-backend trtllm_mla, CUDA graph startup could enter an impossible synthetic state: multi-token decode capture used grouped metadata that did not match the dummy KV/cache locations. This led to capture-time instability and downstream invalid draft ids.

This PR is intentionally independent of #217. It does not touch the EAGLE first-step reduce plumbing and merge-tests cleanly with pull/217/head.

Validation
- python3 -m compileall -q python/tokenspeed/runtime/execution/cuda_graph_wrapper.py python/tokenspeed/runtime/layers/attention/backends/trtllm_mla.py
- trtllm_mla captured batch sizes [1, 2, 3, 4, 5, 6, 7, 8] and reached healthy readiness