Skip to content

fix(deepseek-v4): close MTP acceptance gap#207

Merged
lightseek-bot merged 21 commits into
lightseekorg:mainfrom
Xiangyi1996:xiangyi/v4-mtp-gap-rebased
Jun 1, 2026
Merged

fix(deepseek-v4): close MTP acceptance gap#207
lightseek-bot merged 21 commits into
lightseekorg:mainfrom
Xiangyi1996:xiangyi/v4-mtp-gap-rebased

Conversation

@Xiangyi1996
Copy link
Copy Markdown
Contributor

@Xiangyi1996 Xiangyi1996 commented May 21, 2026

Summary

++This PR is a bit large because it contains the V4/R1 MTP stack plus the acceptance-gap fix and follow-up CI/review fixes.

This PR closes the DeepSeek V4 MTP acceptance gap between TokenSpeed and TRTLLM.

Root cause:

  • The remaining gap was not from compressed KV / CSA indexer cache.
  • It came from MTP draft decode using stale/incorrect V4 paged KV cache metadata.
  • V4 has multiple cache tables; the SWA compact table could be observed with the wrong request context during draft/target-verify transitions.

Fix:

  • Make target-verify/draft-extend forward modes explicit.
  • Refresh paged cache group metadata for MTP draft/target paths.
  • Carry V4 SWA/compressed KV/CSA metadata consistently through draft decode.
  • Keep target-verify logits/hidden states correctly for speculative decoding.
  • Add tests for V4 SWA slot sanitization / paged metadata behavior.

Validation

  • pre-commit run --all-files: passed
  • py_compile on touched runtime/test files: passed
  • Acceptance rerun after rebase:
    • Decoded Tok/Iter = 2.8447
    • Spec Accept Rate = 0.6485
    • In TRTLLM 2.8-2.9 range

@Xiangyi1996 Xiangyi1996 force-pushed the xiangyi/v4-mtp-gap-rebased branch from c285d95 to 63e22c5 Compare May 22, 2026 05:34
@Xiangyi1996 Xiangyi1996 marked this pull request as ready for review May 22, 2026 05:37
@Xiangyi1996 Xiangyi1996 requested a review from a team as a code owner May 22, 2026 05:37
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

draft_cache_cell_size = (
draft_attn_config.cache_cell_size()
* draft_model_config.num_attention_layers
)

P1 Badge Use V4 grouped draft cache size in page-budget profiling

When the draft model is also DeepSeek V4 (the new is_deepseek_v4_draft_model path), this branch still computes draft_cache_cell_size from draft_attn_config.cache_cell_size(), which is the generic MLA estimate and does not include V4 grouped caches (SWA/compressed/indexer/state). profile_deepseek_v4_max_num_pages then overestimates available KV pages for target+draft, so deployments can admit a token/page budget that exceeds real GPU memory and fail with OOM under load; this should use the V4-specific draft size (draft_profile_cache_cell_size / layout-based sizing) instead.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c01d2a08dd

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread python/tokenspeed/runtime/execution/forward_batch_info.py Outdated
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 1f8108eed5

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread python/tokenspeed/runtime/execution/drafter/eagle.py Outdated
Comment thread python/tokenspeed/runtime/execution/cuda_graph_wrapper.py Outdated
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 2494dc30ac

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread python/tokenspeed/runtime/execution/model_executor.py Outdated
Comment thread python/tokenspeed/runtime/execution/cuda_graph_wrapper.py
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: cb37d86925

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread python/tokenspeed/runtime/execution/cuda_graph_wrapper.py Outdated
@Xiangyi1996 Xiangyi1996 force-pushed the xiangyi/v4-mtp-gap-rebased branch from f77d47a to e4223a0 Compare May 22, 2026 07:20
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: e4223a006a

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread python/tokenspeed/runtime/models/deepseek_v4.py Outdated
@lightseek-bot
Copy link
Copy Markdown
Contributor

@Xiangyi1996 please fix the conflicts thanks!

@Xiangyi1996 Xiangyi1996 force-pushed the xiangyi/v4-mtp-gap-rebased branch from e4223a0 to c08927d Compare May 22, 2026 13:48
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c08927dc81

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread python/tokenspeed/runtime/execution/cuda_graph_wrapper.py Outdated
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 5abff4a8cd

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment thread python/tokenspeed/runtime/models/deepseek_v4.py Outdated
Comment thread python/tokenspeed/runtime/models/deepseek_v4.py Outdated
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 5fa924b9da

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread python/tokenspeed/runtime/models/deepseek_v4.py Outdated
Comment thread python/tokenspeed/runtime/engine/scheduler_utils.py
Comment thread python/tokenspeed/runtime/models/deepseek_v4_mtp.py
@Xiangyi1996 Xiangyi1996 force-pushed the xiangyi/v4-mtp-gap-rebased branch from f819dd2 to 5cc6bcd Compare May 26, 2026 06:28
@dongjiyingdjy dongjiyingdjy requested a review from yweng0828 May 27, 2026 07:12
@yweng0828 yweng0828 requested a review from syuoni May 27, 2026 07:38
@Xiangyi1996
Copy link
Copy Markdown
Contributor Author

This PR is a bit large because it contains the V4/R1 MTP stack plus the acceptance-gap fix and follow-up CI/review fixes.

The main risk areas are:

  1. ForwardMode changes: target_verify / draft_extend should only affect V4 paged-cache MTP paths.
  2. V4 paged cache metadata propagation: SWA / compressed KV / CSA indexer group tables and base offsets.
  3. Draft decode step handling: spec_step_idx and accepted-prefix advancement.
  4. Guards: prefix-cache and overlap scheduling are disabled for speculative V4 paged-cache groups.
  5. Non-V4 compatibility: existing draft/spec paths should continue using their previous metadata behavior.

Validation so far:

  • Acceptance recovered from ~2.10 to ~2.84 decoded tok/iter, back to TRTLLM 2.8-2.9 range.
  • Local b200 CI passed: ut-runtime-minimax-m2 / b200-2gpu and ut-runtime-1gpu / b200-1gpu.
  • Kimi b200-4gpu eval is being rerun once GPU/cache environment is available.

Copy link
Copy Markdown
Contributor

@dongjiyingdjy dongjiyingdjy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

23 files, +3096/-180. Enables MTP speculative decoding for V4, fixes acceptance gap vs TRTLLM (2.8447 tok/iter match). Core approach — separate TARGET_VERIFY/DRAFT_EXTEND modes with V4-specific metadata lifecycle — is sound.

Issues to address

1. _forward_flashmla_sparse per-token Python loop (perf)

The new sparse attention method iterates per-token in Python with GPU ops inside each iteration. If this is on the MTP draft hot path, it will dominate latency. Please confirm this is only a correctness fallback, or add a guard/warning.

2. Mixed batch disabled for ALL speculative models, not just V4 (regression risk)

enable_mixed_prefill_decode = (
    server_args.enable_mixed_batch and server_args.speculative_algorithm is None
)

This disables mixed batch whenever any speculative algorithm is active, including EAGLE3 on non-V4 models. If mixed batch previously worked with EAGLE3, this is a regression. Should this be V4-specific?

3. _select_decode_metadata fallback chain is fragile for CUDA graph replay

The method searches through forward_metadata, forward_decode_metadata, forward_prefill_metadata by matching num_tokens. If a stale graph replay happens to match the wrong metadata's token count, it silently reads wrong data. The RuntimeError fallback is fine for eager but catastrophic inside a captured graph.

4. _refresh_cuda_graph_packed_metadata allocates temporaries every replay

self._cuda_graph_query_start_loc[:bs+1].copy_(torch.arange(bs+1, ...) * tokens_per_req)
self._cuda_graph_token_to_req[:total_tokens].copy_(torch.arange(bs, ...).repeat_interleave(tokens_per_req))

torch.arange + repeat_interleave on every CUDA graph replay step. These could be precomputed once during capture and reused.

Suggestions

5. Repeated is_target_verify() or is_draft_extend() pattern (~15 occurrences)

Consider adding a ForwardMode.is_speculative() helper to reduce repetition and ensure consistent handling.

6. num_tokens fallback in init_forward_metadata uses stale spec config

When neither positions nor explicit num_tokens is provided for speculative modes, the code estimates from speculative_num_draft_tokens or speculative_num_steps set at init time. After partial accept the actual count could differ. Callers in CudaGraphWrapper do pass num_tokens explicitly, so this is a latent edge case, not a current bug.

Looks good

  • Test coverage is thorough (metadata shape transitions, non-V4 backend contracts, SWA slot sanitization)
  • cache_start.clone() correctly prevents drafter from modifying shared seq_lens_buf
  • SWA slot sanitization + CUDA kernel guard is defense-in-depth
  • _init_capture_metadata ordering (before warmup + before capture) is correct, though a comment explaining why it runs twice would help

@dongjiyingdjy
Copy link
Copy Markdown
Contributor

dongjiyingdjy commented May 28, 2026

Supplement: 6 new functions in this PR are never called — dead code. Suggest removing them:

  • _forward_flashmla_sparse and its 3 helpers (_slots_from_local_indices, _swa_slots_for_token, _compressed_slots_for_token)
  • _deep_gemm_fp8_linear
  • _prefill_chunk_token_offsets

If these are reserved for follow-up work, please don't include them in this PR — add them when they're actually wired in.

@Xiangyi1996
Copy link
Copy Markdown
Contributor Author

Thanks for the detailed review. I addressed the items one by one:

  1. _forward_flashmla_sparse per-token Python loop

Removed _forward_flashmla_sparse and its helper methods. The decode/speculative decode now always goes through forward_deepseek_v4_decode.

  1. Mixed batch disabled for all speculative models

Fixed. I replaced the global speculative_algorithm is None guard with should_enable_mixed_prefill_decode(). Mixed prefill/decode is now disabled only for speculative runs with V4 paged-cache groups; non-V4 speculative backends such as EAGLE3 keep the previous mixed-batch behavior.

  1. _select_decode_metadata fallback chain

_select_decode_metadata() no longer falls back to forward_prefill_metadata. It only returns matching decode metadata, otherwise None.

  1. CUDA graph replay temporary allocations

Fixed. _refresh_cuda_graph_packed_metadata() now reuses precomputed query_start and token_to_req , instead of creating torch.arange() / repeat_interleave() temporaries during replay.

  1. Repeated is_target_verify() or is_draft_extend()

Added ForwardMode.is_speculative() relpace is_target_verify() or is_draft_extend()

  1. num_tokens fallback using stale speculative config

Fixed. V4 speculative metadata setup now requires explicit num_tokens or positions

dead code:

Removed the unused _forward_flashmla_sparse helpers, _deep_gemm_fp8_linear, and _prefill_chunk_token_offsets. Chunked prefill now directly uses metadata.query_start_loc for token offsets.

…ken_id

R1-0528-NVFP4-v2 marks q_a_proj / kv_a_proj_with_mqa in exclude_modules
(stored as bf16 at logical shape), but DeepseekV3FusedQkvAProjWithMqa
allocates an NVFP4-packed buffer because the fused prefix is not in
exclude_modules. Detect component-level exclusion and pass through
quant_config=None to fall back to bf16.

Also add get_hot_token_id() returning None to DeepseekV3ForCausalLMNextN
to match the EAGLE3/MTP drafter contract (mirrors qwen3_5_nextn.py).

Signed-off-by: Xiangyi Zhang <xiangyiz@nvidia.com>
Signed-off-by: Xiangyi Zhang <xiangyiz@nvidia.com>
Signed-off-by: Xiangyi Zhang <xiangyiz@nvidia.com>
Signed-off-by: Xiangyi Zhang <xiangyiz@nvidia.com>
Signed-off-by: Xiangyi Zhang <xiangyiz@nvidia.com>
Signed-off-by: Xiangyi Zhang <xiangyiz@nvidia.com>
Signed-off-by: Xiangyi Zhang <xiangyiz@nvidia.com>
Signed-off-by: Xiangyi Zhang <xiangyiz@nvidia.com>
Signed-off-by: Xiangyi Zhang <xiangyiz@nvidia.com>
Signed-off-by: Xiangyi Zhang <xiangyiz@nvidia.com>
Signed-off-by: Xiangyi Zhang <xiangyiz@nvidia.com>
Signed-off-by: Xiangyi Zhang <xiangyiz@nvidia.com>
Signed-off-by: Xiangyi Zhang <xiangyiz@nvidia.com>
Signed-off-by: Xiangyi Zhang <xiangyiz@nvidia.com>
Signed-off-by: Xiangyi Zhang <xiangyiz@nvidia.com>
Signed-off-by: Xiangyi Zhang <xiangyiz@nvidia.com>
Mask graph-padded cache writes in the DeepSeek V4 paged-cache paths. CUDA graph replay can pad token rows with stale but in-range slot values, so capacity checks alone are not enough before cache inserts.

Plumb ForwardMetadata.is_valid_token through V4 cache metadata and mask invalid rows to -1 before SWA KV cache, compressed KV cache, compressor state, indexer state, and CSA indexer cache inserts. Also expand per-request request indices and validity masks to per-token form for packed TP draft decode.

Keep the earlier review follow-ups in this commit: preserve non-V4 draft-extend metadata, pass MTP spec_step_idx into draft models, and cover logits-processor gather fallback behavior.

Validated: focused pytest groups for slot mapping and V4 group mapping pass locally; DP=4 short E2E sanity 4/4 OK with no hard errors; TP=2 short E2E sanity 2/2 OK with no NCCL hang or hard errors.
Signed-off-by: Xiangyi Zhang <xiangyiz@nvidia.com>
Signed-off-by: Xiangyi Zhang <xiangyiz@nvidia.com>
@Xiangyi1996 Xiangyi1996 force-pushed the xiangyi/v4-mtp-gap-rebased branch from d7f0b0e to a3f8f62 Compare June 1, 2026 02:58
Signed-off-by: Xiangyi Zhang <xiangyiz@nvidia.com>
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: f47a91173c

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread python/tokenspeed/runtime/layers/attention/kv_cache/deepseek_v4.py
Comment thread python/tokenspeed/runtime/execution/cuda_graph_wrapper.py Outdated
Comment thread python/tokenspeed/runtime/execution/model_executor.py Outdated
Comment thread python/tokenspeed/runtime/engine/scheduler_utils.py Outdated
Signed-off-by: Xiangyi Zhang <xiangyiz@nvidia.com>
Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 01942d545a

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread python/tokenspeed/runtime/layers/attention/backends/deepseek_v4.py
Comment on lines +155 to +158
def __init__(self, **kwargs):
# Explicit __init__ prevents transformers from auto-generating one
# that skips Qwen3_5Config.__init__ (text/vision config setup).
super().__init__(**kwargs)
Copy link
Copy Markdown
Contributor Author

@Xiangyi1996 Xiangyi1996 Jun 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this Qwen3.5 fix is in a V4 MTP PR

This PR adds new attribute-access paths in ModelConfig.__init__ (via
_derive_num_attention_layers / is_deepseek_v4_nextn) to support V4 MTP.
Those new paths happen to exercise the Qwen3.5 config in a way that triggers
a latent bug on main introduced by transformers 5.6:

  • Qwen3_5MoeConfig had no explicit __init__
  • transformers 5.6 auto-generates one for subclasses, bypassing
    Qwen3_5Config.__init__
  • text_config is therefore left in an uninitialized/wrapped form
  • Subsequent __getattr__ forwards loop on text_config itself → RecursionError

Reproduced on origin/main (without this PR) using a minimal
self-referencing text_config setup; the same trace matches the CI
failure at qwen3_5_config.py:114-115.

@yweng0828 confirmed his other PR doesn't hit this because it doesn't go
through the new ModelConfig path.

Discussed with @dongjiyingdjy and @yweng0828: keeping the fix in this PR
since it's needed to unbreak CI here. Happy to cherry-pick to a standalone
follow-up if preferred.

CI evidence (pre-fix): https://github.com/lightseekorg/tokenspeed/actions/runs/26626283458/job/78483017501

@lightseek-bot lightseek-bot merged commit a8a2d0e into lightseekorg:main Jun 1, 2026
30 of 34 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants