feat(spec-decode): add native DFlash support #263

Open
mesaleh wants to merge 1 commit into lightseekorg:main from mesaleh:feat/dflash-native-support

@mesaleh mesaleh commented May 26, 2026

Summary

  • Add DFLASH as a speculative decoding algorithm with a native TokenSpeed DFlash drafter and draft model.
  • Wire target hidden-state capture, non-causal draft attention, BF16 draft KV cache handling, and input_embeds forwarding for the draft runner.
  • Add Kimi K2.6 DFlash recipe docs and document the current full-history draft-attention limitation.

Notes

Tests

  • git diff --check HEAD~1..HEAD
  • python3 -m py_compile on changed Python files
  • Scanned changed files for internal names/paths before opening the PR

@mesaleh mesaleh requested a review from a team as a code owner May 26, 2026 16:52

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 19dbc99a81

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +375 to +376
if input_embeds is None:
    raise ValueError("DFlashDraftModel requires input_embeds.")

P1: Handle IDLE forwards without requiring input_embeds

DFlashDraftModel.forward unconditionally raises when input_embeds is None, but ModelExecutor.execute_idle_forward invokes the drafter model in ForwardMode.IDLE with empty tensors and no input_embeds to keep DP collectives aligned on idle ranks. With --speculative-algorithm DFLASH and data_parallel_size > 1, any rank that enters idle-forward will fail here instead of participating in the collective, which can stall or crash distributed decoding.

Author

Fixed in 7068958. DFlashDraftModel.forward now accepts the idle path without input_embeds by creating an empty hidden-state tensor and continuing through the draft model, so idle ranks can participate in the same forward path instead of raising before distributed alignment. Also amended the commit with a DCO signoff.
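
A minimal sketch of the guard described in this reply, using plain Python lists in place of tensors. `ForwardMode` and `resolve_draft_input` are hypothetical stand-ins for illustration, not the names or signatures used in the PR:

```python
from enum import Enum, auto
from typing import List, Optional


class ForwardMode(Enum):
    """Hypothetical stand-in for the runtime's forward-mode enum."""
    DECODE = auto()
    IDLE = auto()


def resolve_draft_input(
    input_embeds: Optional[List[List[float]]],
    forward_mode: ForwardMode,
) -> List[List[float]]:
    """Pick the hidden states that enter the draft layers."""
    if input_embeds is not None:
        return input_embeds
    if forward_mode is ForwardMode.IDLE:
        # Idle ranks carry zero tokens: an empty batch stands in for an
        # empty hidden-state tensor, so the forward path still runs.
        return []
    # Any non-idle forward without embeddings is still a caller error.
    raise ValueError("DFlashDraftModel requires input_embeds.")
```

The point of the design is that only the idle mode relaxes the requirement; every other mode keeps the original error.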

@mesaleh mesaleh force-pushed the feat/dflash-native-support branch from 19dbc99 to 7068958 Compare May 26, 2026 17:31

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 7068958425

)
)

hidden_states = self.self_attn(

P1: Skip draft attention in IDLE forwards

When ModelExecutor.execute_idle_forward sends ForwardMode.IDLE with empty tensors to the drafter model (used on idle DP ranks), DFlashDecoderLayer.forward still runs self.self_attn(...) unconditionally. Unlike the shared decoder path that bypasses attention in IDLE mode, this path can hit FlashAttention with uninitialized/stale metadata (idle-forward does not set draft attention metadata) and fail before matching required collectives, which can stall multi-rank speculative decode runs in eager/non-graph idle paths.

Author

Fixed in 7ef7841. DFlashDecoderLayer.forward now returns through the MLP-only path for ForwardMode.IDLE before touching self_attn or draft attention metadata, matching the shared decoder behavior and avoiding FlashAttention metadata access on idle ranks.

@mesaleh mesaleh force-pushed the feat/dflash-native-support branch from 7068958 to 7ef7841 Compare May 26, 2026 17:41

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 7ef7841320

Comment on lines +182 to +183
kv = F.linear(hidden_states.to(weight.dtype), weight, bias)
return kv.split([self.kv_size, self.kv_size], dim=-1)

P1: Route KV projection through quantized linear path

Calling F.linear directly here bypasses QKVParallelLinear's quantization logic, so FP8/quantized draft checkpoints (where qkv_proj carries scales/packing, e.g. in runtime/layers/dense/fp8.py) will compute incorrect K/V activations for cache writes. This only shows up when the draft model is quantized, but in that case it can corrupt DFlash proposal quality because _write_native_cache seeds every layer's KV from these values.

Author

Fixed in 9a280bb. DFlashAttention.kv_proj_only now always routes through QKVParallelLinear and splits the returned q/k/v tensor instead of calling F.linear directly, so quantized or packed draft weights use the layer quantization path for cache K/V projection. I also removed the now-unused functional import and reran the full PR-file checks before pushing.
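
To illustrate why bypassing the layer matters, here is a toy sketch with a hypothetical per-tensor scale. It is not the FP8 implementation in the repository; `ToyQuantLinear` and `raw_linear` are invented for this example. It only shows that multiplying by stored quantized values without the scale yields different activations than the layer's own forward:

```python
from typing import List


class ToyQuantLinear:
    """Hypothetical linear layer storing integer weights plus one scale."""

    def __init__(self, q_weight: List[List[int]], scale: float) -> None:
        self.q_weight = q_weight  # stored (quantized) values
        self.scale = scale        # dequantization factor

    def forward(self, x: List[float]) -> List[float]:
        # Layer path: dequantize each weight before the dot product.
        return [
            sum(self.scale * w * xi for w, xi in zip(row, x))
            for row in self.q_weight
        ]


def raw_linear(x: List[float], weight: List[List[int]]) -> List[float]:
    # Bypass path: treats stored values as real weights, ignoring the scale.
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


layer = ToyQuantLinear(q_weight=[[2, 4], [6, 8]], scale=0.5)
x = [1.0, 1.0]
correct = layer.forward(x)                # [3.0, 7.0]
bypassed = raw_linear(x, layer.q_weight)  # [6.0, 14.0]
```

In the real layer the mismatch would land in the K/V values written to the draft cache, which is why it surfaces as degraded proposal quality rather than an exception.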

@mesaleh mesaleh force-pushed the feat/dflash-native-support branch from 7ef7841 to 9a280bb Compare May 26, 2026 18:03

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 9a280bb551

Comment on lines +429 to +431
self.attn_backend.init_forward_metadata(
    bs=bs,
    num_extends=0,

P1: Avoid rebuilding draft metadata inside captured decode graph

DFlash._draft_native calls self.attn_backend.init_forward_metadata(...) on every draft step, but this function is also executed while _forward_step is CUDA-graph-captured (CudaGraphWrapper._capture_one captures the full forward path). In captured decode runs, replay does not re-execute this Python metadata construction, so the draft attention kernels can keep using capture-time page-table/sequence metadata instead of per-request replay metadata, which can corrupt speculative draft attention on graph-enabled decode.

Author

Fixed in 8b1d059. DFlash now skips the eager init_forward_metadata call while the CUDA graph is actively being captured, so the captured draft attention kernels use the graph-backed metadata tensors that CudaGraphWrapper refreshes at replay. I also restore capture metadata again after warmup and immediately before graph capture, because warmup runs can leave the backend pointing at eager metadata objects.
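
The capture/replay distinction can be modeled with a toy recorder. `ToyGraph` below is a hypothetical stand-in, not a CUDA or TokenSpeed API: it records only "kernels" (closures) and replays them, so any Python that ran during capture, such as building a fresh metadata object, is not re-executed at replay:

```python
from typing import Callable, Dict, List


class ToyGraph:
    """Hypothetical recorder: replays recorded kernels, not Python code."""

    def __init__(self) -> None:
        self.kernels: List[Callable[[], int]] = []

    def replay(self) -> List[int]:
        return [kernel() for kernel in self.kernels]


# Persistent buffer that a wrapper would refresh in place before each replay.
graph_metadata: Dict[str, int] = {"seq_len": 0}


def capture(rebuild_metadata: bool) -> ToyGraph:
    graph = ToyGraph()
    if rebuild_metadata:
        # Eager rebuild during capture: the kernel binds to this snapshot.
        metadata = {"seq_len": graph_metadata["seq_len"]}
    else:
        # Rebuild skipped: the kernel binds to the graph-backed buffer.
        metadata = graph_metadata
    graph.kernels.append(lambda: metadata["seq_len"])
    return graph


graph_metadata["seq_len"] = 8    # value present at capture time
stale = capture(rebuild_metadata=True)
fresh = capture(rebuild_metadata=False)
graph_metadata["seq_len"] = 32   # refreshed for a later replay
```

After the refresh, `stale.replay()` still reads the capture-time snapshot while `fresh.replay()` sees the updated buffer, which is the behavior the fix relies on.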

Comment on lines +245 to +247
if config.spec_algo == "DFLASH" and hasattr(
    self.model_runner.model, "set_dflash_layers_to_capture"
):

P2: Enforce DFLASH-compatible target layer capture at startup

The DFLASH setup only applies set_dflash_layers_to_capture when that method exists, and otherwise silently continues. In that case the target forward returns regular hidden states while the drafter still expects concatenated DFLASH context features, so project_target_hidden(self.fc(...)) can fail at runtime with shape mismatch on the first decode step. This should be validated eagerly for spec_algo == "DFLASH" instead of deferring to a late runtime failure.

Author

Fixed in 8b1d059. For spec_algo == DFLASH, ModelExecutor now validates that the target model implements set_dflash_layers_to_capture and raises at startup if it does not. Compatible targets still call set_dflash_layers_to_capture with the draft checkpoint target_layer_ids, so unsupported targets fail early instead of producing a first-decode shape mismatch.
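
A minimal sketch of the eager check described in this reply. The function name `configure_dflash_capture`, the choice of `ValueError`, and the toy target classes are assumptions for illustration; only `set_dflash_layers_to_capture` and `target_layer_ids` come from the PR:

```python
from typing import List, Optional, Sequence


def configure_dflash_capture(
    target_model: object, target_layer_ids: Sequence[int]
) -> None:
    """Fail at startup when the target cannot expose DFLASH context features."""
    hook = getattr(target_model, "set_dflash_layers_to_capture", None)
    if not callable(hook):
        # Raising here replaces a shape mismatch on the first decode step.
        raise ValueError(
            f"{type(target_model).__name__} does not implement "
            "set_dflash_layers_to_capture, which DFLASH requires."
        )
    hook(list(target_layer_ids))


class CompatibleTarget:
    """Toy target that records which layers it was asked to capture."""

    def __init__(self) -> None:
        self.captured: Optional[List[int]] = None

    def set_dflash_layers_to_capture(self, layer_ids: List[int]) -> None:
        self.captured = layer_ids


class PlainTarget:
    """Toy target without the capture hook."""
```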

@mesaleh mesaleh force-pushed the feat/dflash-native-support branch from 9a280bb to 8b1d059 Compare May 26, 2026 18:25

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 8b1d059c63

Comment on lines +344 to +345
ctx_hidden = self.draft_model_runner.model.project_target_hidden(
    target_hidden

P2: Validate captured-layer width before projecting DFLASH context

_write_native_cache assumes the target forward always returns concatenated hidden states for every dflash_config.target_layer_ids entry, then immediately feeds that tensor into fc with fixed input width (len(target_layer_ids) * hidden_size). If a draft checkpoint is paired with an incompatible target (for example, one or more configured layer IDs are out of range or not captured), the captured tensor is narrower and this path fails with a matmul shape error on the first decode step instead of failing fast at startup.

Author

Fixed in 52dde79. This was valid: the previous startup check only proved that the target exposed a DFlash capture hook, not that the draft checkpoint target_layer_ids would produce the width expected by the draft fc projection. DeepSeek/Kimi-family targets now validate duplicate and out-of-range DFLASH target_layer_ids during set_dflash_layers_to_capture, and _write_native_cache now raises a clear captured-hidden width mismatch before projection for any custom target that still returns the wrong width.
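
The two checks described in this reply can be sketched as follows. Both function names are hypothetical, and the valid range is assumed to be zero-based `[0, num_layers)`; the PR may index layers differently. The expected width follows the `len(target_layer_ids) * hidden_size` relation stated in the review comment:

```python
from typing import Sequence


def validate_target_layer_ids(layer_ids: Sequence[int], num_layers: int) -> None:
    """Startup check: reject duplicate or out-of-range capture layers."""
    if len(set(layer_ids)) != len(layer_ids):
        raise ValueError(f"Duplicate DFLASH target_layer_ids: {list(layer_ids)}")
    bad = [i for i in layer_ids if not 0 <= i < num_layers]
    if bad:
        raise ValueError(
            f"DFLASH target_layer_ids out of range [0, {num_layers}): {bad}"
        )


def check_captured_width(
    captured_width: int, layer_ids: Sequence[int], hidden_size: int
) -> None:
    """Pre-projection check: captured features must match the fc input width."""
    expected = len(layer_ids) * hidden_size
    if captured_width != expected:
        raise ValueError(
            f"Captured hidden width {captured_width} does not match "
            f"expected {expected} ({len(layer_ids)} layers x {hidden_size})."
        )
```

The first check catches configuration errors for known targets at startup; the second is a backstop for custom targets that pass startup but return a tensor of the wrong width.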

Signed-off-by: Moustafa Saleh <8815169+mesaleh@users.noreply.github.com>
@mesaleh mesaleh force-pushed the feat/dflash-native-support branch from 8b1d059 to 52dde79 Compare May 26, 2026 19:14

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 52dde79161

  max_seqlen_q=max_seqlen_q,
  softmax_scale=layer.scaling,
- causal=not use_cascade_attn,
+ causal=not (use_cascade_attn or non_causal),

P1: Define non_causal before MLA decode attention call

forward_decode now passes causal=not (use_cascade_attn or non_causal) in the MLA path, but non_causal is never initialized in this function. On models that run self.use_mla=True with single-token decode (q_len_per_req == 1, the normal decode path), this raises a NameError before attention executes, so decode fails regardless of speculative mode.
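
The failure mode is ordinary Python name resolution: a name that is read must be bound on every path that reaches it. A reduced sketch follows. Both function names are hypothetical, and how the real code should obtain `non_causal` is an assumption; only the final expression mirrors the diff:

```python
def causal_flag_buggy(use_cascade_attn: bool) -> bool:
    # `non_causal` is never bound in this scope. Because `or` short-circuits,
    # the lookup (and the NameError) happens only when use_cascade_attn is False.
    return not (use_cascade_attn or non_causal)


def causal_flag_fixed(use_cascade_attn: bool, non_causal: bool = False) -> bool:
    # Assumed fix: default to False so ordinary decode stays causal unless
    # a caller explicitly requests non-causal attention.
    return not (use_cascade_attn or non_causal)
```

With the default in place, `causal_flag_fixed(False)` keeps the pre-PR behavior (`causal=True`), while a draft path can pass `non_causal=True` to opt out.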

Comment on lines +269 to +271
if ctx.forward_mode.is_idle():
    hidden_states = self.mlp(hidden_states)
    return hidden_states, residual

P1: Keep IDLE draft path participating in TP collectives

The IDLE fast path returns immediately after self.mlp(hidden_states) and bypasses all all_reduce/fused allreduce-norm calls in the non-idle branch. In DFLASH runs with mixed active/idle DP replicas and default dense parallelism (dense_tp_size=world_size), active ranks still execute these draft-model TP collectives while idle ranks skip them, which can desynchronize NCCL collective counts and hang decoding.
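
The desynchronization can be illustrated by counting the collectives each rank would issue. This is a toy model with no real communication; the function names and the assumption of two all-reduces per layer (one after attention, one after the MLP) are hypothetical:

```python
from typing import List


def count_collectives(
    is_idle: bool, num_layers: int, idle_skips_collectives: bool
) -> int:
    """Count the all-reduce calls one rank issues during a draft forward."""
    calls = 0
    for _ in range(num_layers):
        if is_idle and idle_skips_collectives:
            continue  # MLP-only early return: no all-reduce issued
        calls += 2    # assumed: one all-reduce after attention, one after MLP
    return calls


def ranks_in_sync(
    idle_flags: List[bool], num_layers: int, idle_skips_collectives: bool
) -> bool:
    """All ranks must issue the same number of collectives to avoid a hang."""
    counts = {
        count_collectives(flag, num_layers, idle_skips_collectives)
        for flag in idle_flags
    }
    return len(counts) == 1
```

In this model, a mix of active and idle ranks stays in sync only when the idle path still issues its collectives; a group where every rank is idle is unaffected either way.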

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 52dde79161

language_model = getattr(target_model, "language_model", target_model)
self.target_model = target_model
self.target_language_model = language_model
self.embed_tokens = target_model.get_input_embeddings()

P1: Bind target embeddings with a model-agnostic fallback

The new DFLASH path unconditionally calls target_model.get_input_embeddings(), but this commit also adds set_dflash_layers_to_capture to DeepseekV3ForCausalLM, which does not implement get_input_embeddings(). In that valid DFLASH target configuration, startup fails with AttributeError during drafter binding before serving begins. Please use the same fallback pattern already used in Kimi wrappers (e.g., fall back to target_model.model.embed_tokens) so targets that expose DFLASH capture support can actually initialize.
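
One way the suggested fallback could look, as a sketch. The helper name `resolve_input_embeddings` and the error message are assumptions; the attribute names `get_input_embeddings` and `model.embed_tokens` are the ones mentioned in the comment above:

```python
def resolve_input_embeddings(target_model: object) -> object:
    """Return the target's embedding module, trying the accessor first."""
    getter = getattr(target_model, "get_input_embeddings", None)
    if callable(getter):
        return getter()
    # Fallback for targets that expose the embedding only as an attribute.
    inner = getattr(target_model, "model", None)
    embed = getattr(inner, "embed_tokens", None)
    if embed is None:
        raise AttributeError(
            f"{type(target_model).__name__} exposes neither "
            "get_input_embeddings() nor model.embed_tokens."
        )
    return embed
```

Trying the accessor first preserves the behavior for targets that already implement it, and the explicit error keeps unsupported targets failing at startup with a readable message.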

Comment on lines +483 to +484
sampled = self._greedy_sample_from_vocab_parallel_head(
    draft_hidden[:, 1:, :].reshape(-1, self.hidden_size)

P2: Reject one-token DFLASH blocks before drafting

DFLASH currently allows --speculative-num-draft-tokens 1 (the server args validation only enforces num_steps = num_draft_tokens - 1), but this draft step always runs vocab sampling on draft_hidden[:, 1:, :]. With a one-token block that slice is empty, so the downstream argmax reduction hits a zero-length tensor at runtime on the first decode step. Add an explicit validation that DFLASH draft-token blocks are at least 2 (or short-circuit sampling for the 1-token case).
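
A sketch of the explicit validation option. The function name and error text are hypothetical; the `num_steps = num_draft_tokens - 1` relation and the `[1:]` slice come from the comment above:

```python
def validate_dflash_block_size(num_draft_tokens: int) -> int:
    """Reject blocks too small to draft from; return the derived step count."""
    if num_draft_tokens < 2:
        raise ValueError(
            "DFLASH requires --speculative-num-draft-tokens >= 2, "
            f"got {num_draft_tokens}."
        )
    return num_draft_tokens - 1


# Why 1 is too small: positions 1.. of a one-token block are empty,
# so there is nothing left to sample from.
one_token_block = ["t0"]
positions_to_sample = one_token_block[1:]
```

Here `positions_to_sample` is an empty list, the list-level analogue of the zero-length tensor that the argmax reduction would receive.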
