feat(spec-decode): add native DFlash support#263
Conversation
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 19dbc99a81
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| if input_embeds is None: | ||
| raise ValueError("DFlashDraftModel requires input_embeds.") |
There was a problem hiding this comment.
Handle IDLE forwards without requiring input_embeds
DFlashDraftModel.forward unconditionally raises when input_embeds is None, but ModelExecutor.execute_idle_forward invokes the drafter model in ForwardMode.IDLE with empty tensors and no input_embeds to keep DP collectives aligned on idle ranks. With --speculative-algorithm DFLASH and data_parallel_size > 1, any rank that enters idle-forward will fail here instead of participating in the collective, which can stall or crash distributed decoding.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Fixed in 7068958. DFlashDraftModel.forward now accepts the idle path without input_embeds by creating an empty hidden-state tensor and continuing through the draft model, so idle ranks can participate in the same forward path instead of raising before distributed alignment. Also amended the commit with a DCO signoff.
19dbc99 to
7068958
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 7068958425
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| ) | ||
| ) | ||
|
|
||
| hidden_states = self.self_attn( |
There was a problem hiding this comment.
Skip draft attention in IDLE forwards
When ModelExecutor.execute_idle_forward sends ForwardMode.IDLE with empty tensors to the drafter model (used on idle DP ranks), DFlashDecoderLayer.forward still runs self.self_attn(...) unconditionally. Unlike the shared decoder path that bypasses attention in IDLE mode, this path can hit FlashAttention with uninitialized/stale metadata (idle-forward does not set draft attention metadata) and fail before matching required collectives, which can stall multi-rank speculative decode runs in eager/non-graph idle paths.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Fixed in 7ef7841. DFlashDecoderLayer.forward now returns through the MLP-only path for ForwardMode.IDLE before touching self_attn or draft attention metadata, matching the shared decoder behavior and avoiding FlashAttention metadata access on idle ranks.
7068958 to
7ef7841
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 7ef7841320
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| kv = F.linear(hidden_states.to(weight.dtype), weight, bias) | ||
| return kv.split([self.kv_size, self.kv_size], dim=-1) |
There was a problem hiding this comment.
Route KV projection through quantized linear path
Calling F.linear directly here bypasses QKVParallelLinear's quantization logic, so FP8/quantized draft checkpoints (where qkv_proj carries scales/packing, e.g. in runtime/layers/dense/fp8.py) will compute incorrect K/V activations for cache writes. This only shows up when the draft model is quantized, but in that case it can corrupt DFlash proposal quality because _write_native_cache seeds every layer's KV from these values.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Fixed in 9a280bb. DFlashAttention.kv_proj_only now always routes through QKVParallelLinear and splits the returned q/k/v tensor instead of calling F.linear directly, so quantized or packed draft weights use the layer quantization path for cache K/V projection. I also removed the now-unused functional import and reran the full PR-file checks before pushing.
7ef7841 to
9a280bb
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 9a280bb551
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| self.attn_backend.init_forward_metadata( | ||
| bs=bs, | ||
| num_extends=0, |
There was a problem hiding this comment.
Avoid rebuilding draft metadata inside captured decode graph
DFlash._draft_native calls self.attn_backend.init_forward_metadata(...) on every draft step, but this function is also executed while _forward_step is CUDA-graph-captured (CudaGraphWrapper._capture_one captures the full forward path). In captured decode runs, replay does not re-execute this Python metadata construction, so the draft attention kernels can keep using capture-time page-table/sequence metadata instead of per-request replay metadata, which can corrupt speculative draft attention on graph-enabled decode.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Fixed in 8b1d059. DFlash now skips the eager init_forward_metadata call while the CUDA graph is actively being captured, so the captured draft attention kernels use the graph-backed metadata tensors that CudaGraphWrapper refreshes at replay. I also restore capture metadata again after warmup and immediately before graph capture, because warmup runs can leave the backend pointing at eager metadata objects.
| if config.spec_algo == "DFLASH" and hasattr( | ||
| self.model_runner.model, "set_dflash_layers_to_capture" | ||
| ): |
There was a problem hiding this comment.
Enforce DFLASH-compatible target layer capture at startup
The DFLASH setup only applies set_dflash_layers_to_capture when that method exists, and otherwise silently continues. In that case the target forward returns regular hidden states while the drafter still expects concatenated DFLASH context features, so project_target_hidden(self.fc(...)) can fail at runtime with shape mismatch on the first decode step. This should be validated eagerly for spec_algo == "DFLASH" instead of deferring to a late runtime failure.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Fixed in 8b1d059. For spec_algo == DFLASH, ModelExecutor now validates that the target model implements set_dflash_layers_to_capture and raises at startup if it does not. Compatible targets still call set_dflash_layers_to_capture with the draft checkpoint target_layer_ids, so unsupported targets fail early instead of producing a first-decode shape mismatch.
9a280bb to
8b1d059
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 8b1d059c63
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| ctx_hidden = self.draft_model_runner.model.project_target_hidden( | ||
| target_hidden |
There was a problem hiding this comment.
Validate captured-layer width before projecting DFLASH context
_write_native_cache assumes the target forward always returns concatenated hidden states for every dflash_config.target_layer_ids entry, then immediately feeds that tensor into fc with fixed input width (len(target_layer_ids) * hidden_size). If a draft checkpoint is paired with an incompatible target (for example, one or more configured layer IDs are out of range or not captured), the captured tensor is narrower and this path fails with a matmul shape error on the first decode step instead of failing fast at startup.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Fixed in 52dde79. This was valid: the previous startup check only proved that the target exposed a DFlash capture hook, not that the draft checkpoint target_layer_ids would produce the width expected by the draft fc projection. DeepSeek/Kimi-family targets now validate duplicate and out-of-range DFLASH target_layer_ids during set_dflash_layers_to_capture, and _write_native_cache now raises a clear captured-hidden width mismatch before projection for any custom target that still returns the wrong width.
Signed-off-by: Moustafa Saleh <8815169+mesaleh@users.noreply.github.com>
8b1d059 to
52dde79
Compare
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 52dde79161
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| max_seqlen_q=max_seqlen_q, | ||
| softmax_scale=layer.scaling, | ||
| causal=not use_cascade_attn, | ||
| causal=not (use_cascade_attn or non_causal), |
There was a problem hiding this comment.
Define non_causal before MLA decode attention call
forward_decode now passes causal=not (use_cascade_attn or non_causal) in the MLA path, but non_causal is never initialized in this function. On models that run self.use_mla=True with single-token decode (q_len_per_req == 1, the normal decode path), this raises a NameError before attention executes, so decode fails regardless of speculative mode.
Useful? React with 👍 / 👎.
| if ctx.forward_mode.is_idle(): | ||
| hidden_states = self.mlp(hidden_states) | ||
| return hidden_states, residual |
There was a problem hiding this comment.
Keep IDLE draft path participating in TP collectives
The IDLE fast path returns immediately after self.mlp(hidden_states) and bypasses all all_reduce/fused allreduce-norm calls in the non-idle branch. In DFLASH runs with mixed active/idle DP replicas and default dense parallelism (dense_tp_size=world_size), active ranks still execute these draft-model TP collectives while idle ranks skip them, which can desynchronize NCCL collective counts and hang decoding.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 52dde79161
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
| language_model = getattr(target_model, "language_model", target_model) | ||
| self.target_model = target_model | ||
| self.target_language_model = language_model | ||
| self.embed_tokens = target_model.get_input_embeddings() |
There was a problem hiding this comment.
Bind target embeddings with a model-agnostic fallback
The new DFLASH path unconditionally calls target_model.get_input_embeddings(), but this commit also adds set_dflash_layers_to_capture to DeepseekV3ForCausalLM, which does not implement get_input_embeddings(). In that valid DFLASH target configuration, startup fails with AttributeError during drafter binding before serving begins. Please use the same fallback pattern already used in Kimi wrappers (e.g., fall back to target_model.model.embed_tokens) so targets that expose DFLASH capture support can actually initialize.
Useful? React with 👍 / 👎.
| sampled = self._greedy_sample_from_vocab_parallel_head( | ||
| draft_hidden[:, 1:, :].reshape(-1, self.hidden_size) |
There was a problem hiding this comment.
Reject one-token DFLASH blocks before drafting
DFLASH currently allows --speculative-num-draft-tokens 1 (the server args validation only enforces num_steps = num_draft_tokens - 1), but this draft step always runs vocab sampling on draft_hidden[:, 1:, :]. With a one-token block that slice is empty, so the downstream argmax reduction hits a zero-length tensor at runtime on the first decode step. Add an explicit validation that DFLASH draft-token blocks are at least 2 (or short-circuit sampling for the 1-token case).
Useful? React with 👍 / 👎.
Summary
DFLASHas a speculative decoding algorithm with a native TokenSpeed DFlash drafter and draft model.input_embedsforwarding for the draft runner.Notes
--speculative-dflash-draft-window-sizeyet.Tests
git diff --check HEAD~1..HEADpython3 -m py_compileon changed Python files