diff --git a/docs/configuration/server.md b/docs/configuration/server.md index d16dc1017..82989d84d 100644 --- a/docs/configuration/server.md +++ b/docs/configuration/server.md @@ -120,7 +120,7 @@ the values accepted by the bundled `tokenspeed-smg` package. | Parameter | Purpose | | --- | --- | | `--speculative-config` | JSON speculative decoding configuration. | -| `--speculative-algorithm` | Speculative algorithm, such as `EAGLE3` or `MTP`. | +| `--speculative-algorithm` | Speculative algorithm, such as `EAGLE3`, `MTP`, or `DFLASH`. | | `--speculative-draft-model-path` | Draft model path or repo ID. | | `--speculative-draft-model-quantization` | Draft model quantization. Defaults to `unquant`. | | `--speculative-num-steps` | Number of draft model steps. Defaults to `3`. | diff --git a/docs/recipes/models.md b/docs/recipes/models.md index 85169494c..53ce96bad 100644 --- a/docs/recipes/models.md +++ b/docs/recipes/models.md @@ -33,6 +33,38 @@ tokenspeed serve nvidia/Kimi-K2.5-NVFP4 \ For K2.6, keep the same parameter shape and change the checkpoint and parser only if the model card requires a different value. +To enable a compatible DFlash draft model, keep the target launch shape and add +the draft model path plus DFlash speculative decoding options: + +```bash +tokenspeed serve nvidia/Kimi-K2.6-NVFP4 \ + --served-model-name kimi-k2.6 \ + --trust-remote-code \ + --max-model-len 262144 \ + --kv-cache-dtype fp8 \ + --quantization nvfp4 \ + --tensor-parallel-size 4 \ + --enable-expert-parallel \ + --chunked-prefill-size 8192 \ + --max-num-seqs 256 \ + --attention-backend tokenspeed_mla \ + --moe-backend flashinfer_trtllm \ + --reasoning-parser kimi_k25 \ + --tool-call-parser kimik2 \ + --speculative-algorithm DFLASH \ + --speculative-draft-model-path /path/to/kimi-k2.6-dflash \ + --speculative-num-draft-tokens 8 \ + --speculative-num-steps 7 \ + --drafter-attention-backend fa4 \ + --host 0.0.0.0 \ + --port 8000 +``` + +Known limitation: native TokenSpeed DFlash currently uses full-history draft +attention. It does not yet expose an equivalent of SGLang's +`--speculative-dflash-draft-window-size`; add such a flag before relying on +bounded draft attention for long-context deployments. + ## Qwen3 Dense / Qwen3 30B-A3B Qwen2, dense Qwen3, and Qwen3 MoE checkpoints use different architecture names. diff --git a/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py b/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py index 9fa151db1..826e333d8 100644 --- a/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py +++ b/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py @@ -349,6 +349,44 @@ def _capture_one(self, bs: int): grammar_backend=self.grammar_backend, ) + # Spec-decode capture runs a synthetic multi-token decode. Keep the + # dummy cache lengths internally consistent with that token count so + # attention warmup does not read an impossible q_len > seq_len state. + tokens_per_req = self.max_tokens_per_req + self.input_buffers.seq_lens_buf[:bs].fill_(tokens_per_req) + self.input_buffers.input_lengths_buf[:bs].fill_(tokens_per_req) + + # Capture block tables point at synthetic per-request pages. Write the + # dummy KV tokens into those same slots so attention warmup/capture reads + # initialized keys instead of the reserved padding slot. + page_size = self.input_buffers.page_size + pages_per_req = (tokens_per_req + page_size - 1) // page_size + token_offsets = torch.arange( + tokens_per_req, dtype=torch.int32, device=self.device + ) + request_offsets = ( + torch.arange(bs, dtype=torch.int32, device=self.device).unsqueeze(1) + * pages_per_req + * page_size + ) + self.input_buffers.out_cache_loc_buf[: bs * tokens_per_req].copy_( + (request_offsets + token_offsets).reshape(-1) + ) + + # Some fused decode kernels may read full page vectors during capture + # even when seq_lens bounds the logical context. Clear the synthetic + # pages so any padding read is deterministic zero, not allocator noise. + capture_slots = bs * pages_per_req * page_size + for pool in (self.token_to_kv_pool, self.draft_token_to_kv_pool): + if pool is None or not hasattr(pool, "kv_buffer"): + continue + for layer_buf in pool.kv_buffer: + if isinstance(layer_buf, (tuple, list)): + for sub_buf in layer_buf: + sub_buf[:capture_slots].zero_() + else: + layer_buf[:capture_slots].zero_() + self._init_capture_metadata(bs) def run_once(): @@ -379,6 +417,11 @@ def run_once(): torch.cuda.synchronize() dist.barrier() + # Warmups can switch a backend back to eager metadata objects. Restore + # the graph-backed metadata immediately before capture so replay-time + # metadata refreshes update the same tensors recorded by the graph. + self._init_capture_metadata(bs) + # Fill sampler buffers OUTSIDE the capture so RNG ops aren't recorded. if self.sampling_backend is not None: self.sampling_backend.prepare_capture( diff --git a/python/tokenspeed/runtime/execution/drafter/dflash.py b/python/tokenspeed/runtime/execution/drafter/dflash.py new file mode 100644 index 000000000..2d9b82a84 --- /dev/null +++ b/python/tokenspeed/runtime/execution/drafter/dflash.py @@ -0,0 +1,506 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from tokenspeed.runtime.distributed.comm_ops import all_gather_into_tensor +from tokenspeed.runtime.execution.cache_loc_kernel import compute_out_cache_loc_uniform +from tokenspeed.runtime.execution.context import ForwardContext +from tokenspeed.runtime.execution.drafter.base import BaseDrafter +from tokenspeed.runtime.execution.forward_batch_info import ( + CaptureHiddenMode, + ForwardMode, +) +from tokenspeed.runtime.layers.logits_processor import LogitsMetadata +from tokenspeed.runtime.utils import get_colorful_logger +from tokenspeed.runtime.utils.env import get_global_server_args +from tokenspeed.runtime.utils.nvtx import nvtx_range + +if TYPE_CHECKING: + from tokenspeed.runtime.execution.input_buffer import InputBuffers + from tokenspeed.runtime.execution.model_runner import ModelRunner + from tokenspeed.runtime.execution.runtime_states import RuntimeStates + from tokenspeed.runtime.layers.logits_processor import LogitsProcessorOutput + +logger = get_colorful_logger(__name__) + + +class DFlash(BaseDrafter): + """DFlash block drafter backed by a native TokenSpeed draft model.""" + + def __init__( + self, + spec_num_tokens: int, + spec_num_steps: int, + page_size: int, + draft_model_runner: ModelRunner | None = None, + req_to_page: torch.Tensor | None = None, + attn_backend=None, + token_to_kv_pool=None, + runtime_states: RuntimeStates | None = None, + input_buffers: InputBuffers | None = None, + vocab_size: int | None = None, + ) -> None: + super().__init__( + spec_num_tokens=spec_num_tokens, + spec_num_steps=spec_num_steps, + draft_model_runner=draft_model_runner, + runtime_states=runtime_states, + input_buffers=input_buffers, + page_size=page_size, + req_to_page=req_to_page, + attn_backend=attn_backend, + token_to_kv_pool=token_to_kv_pool, + vocab_size=vocab_size, + ) + server_args = get_global_server_args() + if not server_args.speculative_draft_model_path: + raise ValueError("DFLASH requires --speculative-draft-model-path.") + + if draft_model_runner is None: + raise ValueError("Native DFLASH requires a draft model runner.") + self.device = torch.device(draft_model_runner.device) + self.model = draft_model_runner.model + + cfg = self.model.config + dflash_cfg = getattr(cfg, "dflash_config", {}) or {} + self.target_layer_ids = [int(x) for x in dflash_cfg.get("target_layer_ids", [])] + if not self.target_layer_ids: + raise ValueError( + "DFLASH draft config must define dflash_config.target_layer_ids." + ) + if "mask_token_id" not in dflash_cfg: + raise ValueError( + "DFLASH draft config must define dflash_config.mask_token_id." + ) + self.mask_token_id = int(dflash_cfg["mask_token_id"]) + self.block_size = int(getattr(cfg, "block_size", spec_num_tokens)) + if self.block_size != int(spec_num_tokens): + logger.warning( + "DFLASH block size mismatch: checkpoint block_size=%s, " + "runtime speculative_num_draft_tokens=%s.", + self.block_size, + spec_num_tokens, + ) + self.hidden_size = int(getattr(cfg, "hidden_size")) + self.idle_forward_steps = 1 + self._init_native_buffers() + self._greedy_gathered_max: torch.Tensor | None = None + self._greedy_gathered_ids: torch.Tensor | None = None + self._greedy_gather_cap = 0 + + def _init_native_buffers(self) -> None: + if self.input_buffers is None: + raise ValueError("Native DFLASH requires input buffers.") + if self.req_to_page is None: + raise ValueError("Native DFLASH requires req_to_page.") + if self.attn_backend is None or self.token_to_kv_pool is None: + raise ValueError("Native DFLASH requires draft attention components.") + + max_bs = self.input_buffers.max_bs + self.draft_seq_lens_buf = torch.zeros_like(self.input_buffers.seq_lens_buf) + self.draft_out_cache_loc_buf = torch.empty( + (max_bs * self.spec_num_tokens,), + dtype=torch.int32, + device=self.device, + ) + self.draft_input_lengths_buf = torch.full( + (max_bs,), + self.spec_num_tokens, + dtype=torch.int32, + device=self.device, + ) + self.draft_extend_seq_lens_cpu = torch.full( + (max_bs,), + self.spec_num_tokens, + dtype=torch.int32, + pin_memory=True, + ) + self.block_offsets = torch.arange( + self.spec_num_tokens, dtype=torch.int64, device=self.device + ) + self.block_ids_buf = torch.empty( + (max_bs, self.spec_num_tokens), dtype=torch.int32, device=self.device + ) + self.block_positions_buf = torch.empty( + (max_bs, self.spec_num_tokens), dtype=torch.int64, device=self.device + ) + + def bind_target_model(self, target_model) -> None: + language_model = getattr(target_model, "language_model", target_model) + self.target_model = target_model + self.target_language_model = language_model + self.embed_tokens = target_model.get_input_embeddings() + self.lm_head = target_model.lm_head + self.logits_processor = language_model.logits_processor + + def _greedy_sample_from_vocab_parallel_head( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + if not hasattr(self.lm_head, "weight") or not hasattr( + self.lm_head, "shard_indices" + ): + metadata = LogitsMetadata(forward_mode=ForwardMode.DECODE) + logits = self.logits_processor._get_logits( + hidden_states, self.lm_head, metadata + ) + return torch.argmax(logits, dim=-1).to(torch.int32) + + shard = self.lm_head.shard_indices + weight = self.lm_head.weight + hidden_states = hidden_states.to(weight.dtype) + + num_org = int(shard.num_org_elements) + num_org_padded = int(shard.num_org_elements_padded) + num_added = int(shard.num_added_elements) + org_vocab_start = int(shard.org_vocab_start_index) + added_vocab_start = int(shard.added_vocab_start_index) + + chunk_len = int(hidden_states.shape[0]) + if num_org > 0: + base_logits = torch.matmul(hidden_states, weight[:num_org].T) + local_max, local_arg = torch.max(base_logits, dim=-1) + else: + local_max = torch.full( + (chunk_len,), + torch.finfo(weight.dtype).min, + dtype=weight.dtype, + device=hidden_states.device, + ) + local_arg = torch.zeros( + (chunk_len,), dtype=torch.int64, device=hidden_states.device + ) + + if num_added > 0: + added_start = num_org_padded + added_end = num_org_padded + num_added + added_weight = weight[added_start:added_end] + added_logits = torch.matmul(hidden_states, added_weight.T) + added_max, added_arg = torch.max(added_logits, dim=-1) + use_added = added_max > local_max + local_max = torch.where(use_added, added_max, local_max) + local_arg = torch.where( + use_added, + added_arg.to(local_arg.dtype) + num_org_padded, + local_arg, + ) + + if num_added == 0: + global_ids = local_arg + org_vocab_start + else: + global_ids = torch.empty( + (chunk_len,), dtype=torch.int64, device=hidden_states.device + ) + is_base = local_arg < num_org + global_ids[is_base] = org_vocab_start + local_arg[is_base] + global_ids[~is_base] = added_vocab_start + ( + local_arg[~is_base] - num_org_padded + ) + + tp_size = int(self.logits_processor.tp_size) + if tp_size == 1: + return global_ids.to(torch.int32) + + needed = tp_size * chunk_len + if ( + self._greedy_gather_cap < needed + or self._greedy_gathered_max is None + or self._greedy_gathered_ids is None + or self._greedy_gathered_max.dtype != local_max.dtype + or self._greedy_gathered_max.device != hidden_states.device + ): + self._greedy_gathered_max = torch.empty( + (needed,), dtype=local_max.dtype, device=hidden_states.device + ) + self._greedy_gathered_ids = torch.empty( + (needed,), dtype=global_ids.dtype, device=hidden_states.device + ) + self._greedy_gather_cap = needed + + gathered_max = self._greedy_gathered_max[:needed] + gathered_ids = self._greedy_gathered_ids[:needed] + all_gather_into_tensor( + gathered_max, + local_max.contiguous(), + self.logits_processor.tp_rank, + self.logits_processor.tp_group, + ) + all_gather_into_tensor( + gathered_ids, + global_ids.contiguous(), + self.logits_processor.tp_rank, + self.logits_processor.tp_group, + ) + + gathered_max = gathered_max.view(tp_size, chunk_len) + gathered_ids = gathered_ids.view(tp_size, chunk_len) + best_rank = torch.argmax(gathered_max, dim=0).unsqueeze(0) + return torch.gather(gathered_ids, 0, best_rank).view(-1).to(torch.int32) + + @nvtx_range("dflash_update_native_cache", color="purple") + def _update_native_cache_from_target( + self, + base_ctx: ForwardContext, + logits_output: LogitsProcessorOutput, + accept_lengths: torch.Tensor, + ) -> None: + hidden = logits_output.hidden_states + if hidden is None: + raise RuntimeError("DFLASH requires target hidden states.") + if hidden.shape[0] != base_ctx.input_num_tokens: + raise RuntimeError( + "DFLASH hidden-state/token mismatch: " + f"hidden_tokens={hidden.shape[0]}, input_tokens={base_ctx.input_num_tokens}." + ) + + bs = base_ctx.bs + lengths = self.input_buffers.input_lengths_buf[:bs].to(torch.int64) + req_pool_indices = self.input_buffers.req_pool_indices_buf[:bs] + positions = self.input_buffers.positions_buf[: base_ctx.input_num_tokens] + cache_locs = self.input_buffers.out_cache_loc_buf[: base_ctx.input_num_tokens] + + if ( + base_ctx.num_extends == 0 + and torch.cuda.is_available() + and torch.cuda.is_current_stream_capturing() + ): + old_lens = self.runtime_states.valid_cache_lengths.index_select( + 0, req_pool_indices + ) + self.draft_seq_lens_buf[:bs].copy_( + old_lens.to(torch.int32) + accept_lengths[:bs].to(torch.int32) + ) + self._write_native_cache(hidden, positions, cache_locs) + return + + hidden_chunks = torch.split(hidden, lengths.detach().cpu().tolist(), dim=0) + pos_chunks = torch.split(positions, lengths.detach().cpu().tolist(), dim=0) + loc_chunks = torch.split(cache_locs, lengths.detach().cpu().tolist(), dim=0) + + selected_hidden = [] + selected_positions = [] + selected_cache_locs = [] + new_seq_lens = torch.empty((bs,), dtype=torch.int32, device=self.device) + + for row, (chunk, pos_chunk, loc_chunk) in enumerate( + zip(hidden_chunks, pos_chunks, loc_chunks, strict=True) + ): + if row < base_ctx.num_extends: + take = int(chunk.shape[0]) + else: + take = int(accept_lengths[row].item()) + if take <= 0: + pool_idx = req_pool_indices[row] + new_seq_lens[row] = self.runtime_states.valid_cache_lengths[pool_idx] + continue + + chunk = chunk[:take].contiguous() + pos_chunk = pos_chunk[:take].contiguous() + loc_chunk = loc_chunk[:take].contiguous() + selected_hidden.append(chunk) + selected_positions.append(pos_chunk) + selected_cache_locs.append(loc_chunk) + new_seq_lens[row] = (pos_chunk[-1] + 1).to(torch.int32) + + self.draft_seq_lens_buf[:bs].copy_(new_seq_lens) + if not selected_hidden: + return + + target_hidden = torch.cat(selected_hidden, dim=0) + target_positions = torch.cat(selected_positions, dim=0) + target_cache_locs = torch.cat(selected_cache_locs, dim=0) + self._write_native_cache(target_hidden, target_positions, target_cache_locs) + + def _write_native_cache( + self, + target_hidden: torch.Tensor, + target_positions: torch.Tensor, + target_cache_locs: torch.Tensor, + ) -> None: + target_hidden = target_hidden.to( + device=self.device, + dtype=self.draft_model_runner.model.fc.weight.dtype, + ) + expected_width = int(self.draft_model_runner.model.fc.in_features) + actual_width = int(target_hidden.shape[-1]) + if actual_width != expected_width: + raise RuntimeError( + "DFLASH captured hidden width mismatch: " + f"expected {expected_width}, got {actual_width}. " + "Check dflash_config.target_layer_ids against the target model." + ) + with torch.inference_mode(): + ctx_hidden = self.draft_model_runner.model.project_target_hidden( + target_hidden + ) + for layer in self.draft_model_runner.model.layers: + attn = layer.self_attn + k, v = attn.kv_proj_only(ctx_hidden) + k = attn.apply_k_norm(k) + k = attn.apply_k_rope(target_positions, k) + k = k.view(-1, attn.num_kv_heads, attn.head_dim) + v = v.view(-1, attn.num_kv_heads, attn.head_dim) + self.token_to_kv_pool.set_kv_buffer( + attn.attn, + target_cache_locs, + k, + v, + attn.attn.k_scale, + attn.attn.v_scale, + ) + + @staticmethod + def _current_tokens_from_output( + output_tokens: torch.Tensor, + accept_lengths: torch.Tensor, + num_extends: int, + spec_num_tokens: int, + ) -> torch.Tensor: + bs = accept_lengths.shape[0] + current = torch.empty((bs,), dtype=torch.int32, device=output_tokens.device) + if num_extends > 0: + current[:num_extends] = output_tokens[:num_extends] + num_decodes = bs - num_extends + if num_decodes > 0: + offsets = ( + torch.arange( + num_decodes, dtype=torch.int64, device=output_tokens.device + ) + * spec_num_tokens + - 1 + + num_extends + ) + current[num_extends:] = output_tokens[ + offsets + accept_lengths[num_extends:].to(torch.int64) + ] + return current + + def get_candidates(self, base_ctx: ForwardContext) -> torch.Tensor | None: + num_extends = base_ctx.num_extends + num_decodes = base_ctx.bs - num_extends + if num_decodes == 0: + return None + num_decode_tokens = num_decodes * self.spec_num_tokens + num_prefill_tokens = base_ctx.input_num_tokens - num_decode_tokens + return self.input_buffers.input_ids_buf[ + num_prefill_tokens : base_ctx.input_num_tokens + ].reshape(num_decodes, self.spec_num_tokens) + + def draft(self, current_tokens: torch.Tensor) -> torch.Tensor: + return self._draft_native(current_tokens) + + @nvtx_range("dflash_native_draft", color="purple") + def _draft_native(self, current_tokens: torch.Tensor) -> torch.Tensor: + bs = current_tokens.shape[0] + req_pool_indices = self.input_buffers.req_pool_indices_buf[:bs] + prefix_lens = self.draft_seq_lens_buf[:bs].clone() + seq_lens_after = self.draft_seq_lens_buf[:bs] + seq_lens_after.copy_(prefix_lens + int(self.spec_num_tokens)) + + block_ids = self.block_ids_buf[:bs] + block_ids.fill_(int(self.mask_token_id)) + block_ids[:, 0].copy_(current_tokens.to(torch.int32)) + block_positions = self.block_positions_buf[:bs] + block_positions.copy_( + prefix_lens.to(torch.int64).unsqueeze(1) + self.block_offsets + ) + + cache_locs = self.draft_out_cache_loc_buf[: bs * self.spec_num_tokens] + compute_out_cache_loc_uniform( + out_cache_loc_ptr=cache_locs, + req_pool_indices=req_pool_indices, + uniform_input_length=self.spec_num_tokens, + cache_start=prefix_lens, + req_to_pages=self.req_to_page, + page_size=self.page_size, + ) + + if not (torch.cuda.is_available() and torch.cuda.is_current_stream_capturing()): + self.attn_backend.init_forward_metadata( + bs=bs, + num_extends=0, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens_after, + req_to_page=self.req_to_page, + forward_mode=ForwardMode.DECODE, + extend_seq_lens_cpu=self.draft_extend_seq_lens_cpu[:bs], + ) + + ctx = ForwardContext( + attn_backend=self.attn_backend, + token_to_kv_pool=self.token_to_kv_pool, + req_to_page=self.req_to_page, + bs=bs, + num_extends=0, + input_num_tokens=bs * self.spec_num_tokens, + forward_mode=ForwardMode.DECODE, + capture_hidden_mode=CaptureHiddenMode.FULL, + ) + + flat_ids = block_ids.reshape(-1) + input_embeds = self.embed_tokens(flat_ids) + with torch.inference_mode(): + logits_output = self.draft_model_runner.forward( + ctx=ctx, + input_ids=flat_ids, + positions=block_positions.reshape(-1), + out_cache_loc=cache_locs, + input_lengths=self.draft_input_lengths_buf[:bs], + captured_hidden_states=None, + input_embeds=input_embeds, + ) + + draft_hidden = logits_output.hidden_states + if draft_hidden is None: + raise RuntimeError( + "Native DFLASH draft model did not return hidden states." + ) + draft_hidden = draft_hidden.view(bs, self.spec_num_tokens, self.hidden_size) + + next_tokens = torch.empty( + (bs, self.spec_num_tokens), dtype=torch.int32, device=self.device + ) + next_tokens[:, 0] = current_tokens.to(torch.int32) + sampled = self._greedy_sample_from_vocab_parallel_head( + draft_hidden[:, 1:, :].reshape(-1, self.hidden_size) + ) + next_tokens[:, 1:] = sampled.view(bs, self.spec_num_tokens - 1) + return next_tokens + + @nvtx_range("drafter:dflash", color="purple") + def run( + self, + base_ctx: ForwardContext, + logits_output: LogitsProcessorOutput, + output_tokens: torch.Tensor, + accept_lengths: torch.Tensor, + ) -> torch.Tensor: + if not hasattr(self, "target_model"): + raise RuntimeError("DFLASH drafter is not bound to a target model.") + self._update_native_cache_from_target(base_ctx, logits_output, accept_lengths) + current_tokens = self._current_tokens_from_output( + output_tokens, + accept_lengths, + base_ctx.num_extends, + self.spec_num_tokens, + ) + return self.draft(current_tokens) diff --git a/python/tokenspeed/runtime/execution/model_executor.py b/python/tokenspeed/runtime/execution/model_executor.py index e583e8712..42d19adc9 100644 --- a/python/tokenspeed/runtime/execution/model_executor.py +++ b/python/tokenspeed/runtime/execution/model_executor.py @@ -35,6 +35,7 @@ from tokenspeed.runtime.execution.cache_loc_kernel import update_block_table from tokenspeed.runtime.execution.context import ForwardContext from tokenspeed.runtime.execution.cuda_graph_wrapper import CudaGraphWrapper +from tokenspeed.runtime.execution.drafter.dflash import DFlash from tokenspeed.runtime.execution.drafter.eagle import Eagle from tokenspeed.runtime.execution.forward_batch_info import ( CaptureHiddenMode, @@ -61,7 +62,7 @@ logger = get_colorful_logger(__name__) -_DRAFTER_MAPPING = {"EAGLE3": Eagle, "MTP": Eagle} +_DRAFTER_MAPPING = {"EAGLE3": Eagle, "MTP": Eagle, "DFLASH": DFlash} @dataclass @@ -232,12 +233,24 @@ def __init__( token_to_kv_pool=draft_token_to_kv_pool, vocab_size=config.vocab_size, ) - embed, head = self.model_runner.model.get_embed_and_head() - draft_model_runner.model.set_embed_and_head(embed, head) + if hasattr(self.drafter, "bind_target_model"): + self.drafter.bind_target_model(self.model_runner.model) + if config.spec_algo in ("EAGLE3", "MTP"): + embed, head = self.model_runner.model.get_embed_and_head() + draft_model_runner.model.set_embed_and_head(embed, head) if config.spec_algo in ("EAGLE3",) and hasattr( self.model_runner.model, "set_eagle3_layers_to_capture" ): self.model_runner.model.set_eagle3_layers_to_capture() + if config.spec_algo == "DFLASH": + if not hasattr(self.model_runner.model, "set_dflash_layers_to_capture"): + raise ValueError( + "DFLASH requires the target model to support " + "set_dflash_layers_to_capture." + ) + self.model_runner.model.set_dflash_layers_to_capture( + self.drafter.target_layer_ids + ) else: self.drafter = None @@ -894,7 +907,10 @@ def execute_idle_forward( global_bs=global_bs, all_decode_or_idle=all_decode_or_idle, ) - for _ in range(self.drafter.spec_num_steps): + idle_forward_steps = getattr( + self.drafter, "idle_forward_steps", self.drafter.spec_num_steps + ) + for _ in range(idle_forward_steps or 0): self.drafter.draft_model_runner.forward( draft_ctx, input_ids=empty, diff --git a/python/tokenspeed/runtime/execution/model_runner.py b/python/tokenspeed/runtime/execution/model_runner.py index 90141c4a6..71ab8af5a 100644 --- a/python/tokenspeed/runtime/execution/model_runner.py +++ b/python/tokenspeed/runtime/execution/model_runner.py @@ -124,6 +124,7 @@ def forward( seq_lens: torch.Tensor | None = None, extend_prefix_lens: torch.Tensor | None = None, captured_hidden_states: torch.Tensor | None = None, + input_embeds: torch.Tensor | None = None, multimodal_context: MultimodalForwardContext | None = None, ) -> LogitsProcessorOutput: kwargs = {} @@ -137,6 +138,8 @@ def forward( kwargs["get_embedding"] = True if captured_hidden_states is not None: kwargs["captured_hidden_states"] = captured_hidden_states + if input_embeds is not None: + kwargs["input_embeds"] = input_embeds if multimodal_context is not None: kwargs["multimodal_context"] = multimodal_context diff --git a/python/tokenspeed/runtime/layers/attention/backends/flash_attention.py b/python/tokenspeed/runtime/layers/attention/backends/flash_attention.py index a9d48a9d0..87bbc66fe 100644 --- a/python/tokenspeed/runtime/layers/attention/backends/flash_attention.py +++ b/python/tokenspeed/runtime/layers/attention/backends/flash_attention.py @@ -729,6 +729,7 @@ def forward_extend( ) use_cascade_attn = is_target_verify and self.topk > 1 and not is_swa + non_causal = bool(getattr(layer, "non_causal", False)) # Only pass ``ver`` when talking to a non-default FlashAttention # interface version. @@ -782,7 +783,7 @@ def forward_extend( cu_seqlens_k_new=None, max_seqlen_q=max_seqlen_q, softmax_scale=layer.scaling, - causal=not use_cascade_attn, + causal=not (use_cascade_attn or non_causal), window_size=window_size, softcap=layer.logit_cap, k_descale=k_descale, @@ -1189,7 +1190,7 @@ def forward_decode( cu_seqlens_k_new=None, max_seqlen_q=max_seqlen_q, softmax_scale=layer.scaling, - causal=not use_cascade_attn, + causal=not (use_cascade_attn or non_causal), softcap=layer.logit_cap, k_descale=k_descale, v_descale=v_descale, diff --git a/python/tokenspeed/runtime/layers/attention/configs/mha.py b/python/tokenspeed/runtime/layers/attention/configs/mha.py index 22fcb7f42..37fad45fc 100644 --- a/python/tokenspeed/runtime/layers/attention/configs/mha.py +++ b/python/tokenspeed/runtime/layers/attention/configs/mha.py @@ -45,6 +45,10 @@ def generate( speculative_num_steps=server_args.speculative_num_steps, speculative_num_draft_tokens=server_args.speculative_num_draft_tokens, ) + kv_cache_dtype = server_args.kv_cache_dtype + if is_draft and server_args.speculative_algorithm == "DFLASH": + kv_cache_dtype = "bfloat16" + return cls( device=server_args.device, context_len=model_config.context_len, @@ -58,7 +62,7 @@ def generate( head_dim=model_config.head_dim, attn_tp_size=server_args.attn_tp_size or server_args.mapping.attn.tp_size, dtype=model_config.dtype, - kv_cache_dtype=resolve_dtype(server_args.kv_cache_dtype), + kv_cache_dtype=resolve_dtype(kv_cache_dtype), page_size=server_args.block_size, max_bs=server_args.max_num_seqs // (server_args.data_parallel_size or server_args.mapping.attn.dp_size), diff --git a/python/tokenspeed/runtime/models/deepseek_v3.py b/python/tokenspeed/runtime/models/deepseek_v3.py index 9c8b5e634..7b3ee5b1c 100644 --- a/python/tokenspeed/runtime/models/deepseek_v3.py +++ b/python/tokenspeed/runtime/models/deepseek_v3.py @@ -1352,6 +1352,22 @@ def set_eagle3_layers_to_capture(self, layer_ids: list[int] | None = None): else: self.model.layers_to_capture = {val + 1 for val in layer_ids} + def set_dflash_layers_to_capture(self, layer_ids: list[int]): + # DFlash checkpoints name 0-indexed target layer outputs. The capture + # check runs before layer i, so capture at i + 1 for layer i's output. + num_layers = len(self.model.layers) + if len(set(layer_ids)) != len(layer_ids): + raise ValueError("DFLASH target_layer_ids must be unique.") + + invalid = [val for val in layer_ids if val < 0 or val + 1 >= num_layers] + if invalid: + raise ValueError( + "DFLASH target_layer_ids must map to capturable target layer " + f"outputs. Got invalid ids {invalid}; valid range is " + f"[0, {num_layers - 2}] for {num_layers} target layers." + ) + self.model.layers_to_capture = {val + 1 for val in layer_ids} + def get_param(self, params_dict, name): if name in params_dict: return params_dict[name] diff --git a/python/tokenspeed/runtime/models/dflash.py b/python/tokenspeed/runtime/models/dflash.py new file mode 100644 index 000000000..993dfc970 --- /dev/null +++ b/python/tokenspeed/runtime/models/dflash.py @@ -0,0 +1,435 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import annotations + +from collections.abc import Iterable +from typing import Optional + +import torch +from torch import nn + +from tokenspeed.runtime.distributed.comm_ops import all_reduce +from tokenspeed.runtime.distributed.mapping import Mapping +from tokenspeed.runtime.execution.context import ForwardContext +from tokenspeed.runtime.layers.activation import SiluAndMul +from tokenspeed.runtime.layers.layernorm import RMSNorm +from tokenspeed.runtime.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from tokenspeed.runtime.layers.logits_processor import LogitsProcessorOutput +from tokenspeed.runtime.layers.paged_attention import PagedAttention +from tokenspeed.runtime.layers.quantization.base_config import QuantizationConfig +from tokenspeed.runtime.layers.rotary_embedding import get_rope +from tokenspeed.runtime.model_loader.weight_utils import default_weight_loader +from tokenspeed.runtime.utils import add_prefix +from tokenspeed.runtime.utils.env import global_server_args_dict + + +class DFlashAttention(nn.Module): + def __init__( + self, + config, + mapping: Mapping, + layer_id: int, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.mapping = mapping + self.hidden_size = int(config.hidden_size) + self.tp_rank = self.mapping.attn.tp_rank + self.tp_size = self.mapping.attn.tp_size + self.total_num_heads = int(config.num_attention_heads) + self.total_num_kv_heads = int( + getattr(config, "num_key_value_heads", self.total_num_heads) + ) + assert self.total_num_heads % self.tp_size == 0 + if self.total_num_kv_heads >= self.tp_size: + assert self.total_num_kv_heads % self.tp_size == 0 + else: + assert self.tp_size % self.total_num_kv_heads == 0 + + self.num_heads = self.total_num_heads // self.tp_size + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) + self.head_dim = int( + getattr(config, "head_dim", self.hidden_size // self.total_num_heads) + ) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=bool(getattr(config, "attention_bias", False)), + quant_config=quant_config, + prefix=add_prefix("qkv_proj", prefix), + tp_rank=self.mapping.attn.tp_rank, + tp_size=self.mapping.attn.tp_size, + tp_group=self.mapping.attn.tp_group, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=bool(getattr(config, "attention_bias", False)), + quant_config=quant_config, + prefix=add_prefix("o_proj", prefix), + reduce_results=False, + tp_rank=self.mapping.attn.tp_rank, + tp_size=self.mapping.attn.tp_size, + tp_group=self.mapping.attn.tp_group, + ) + eps = float(getattr(config, "rms_norm_eps", 1e-6)) + self.q_norm = RMSNorm(self.head_dim, eps=eps) + self.k_norm = RMSNorm(self.head_dim, eps=eps) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=int(getattr(config, "max_position_embeddings", 32768)), + base=float(getattr(config, "rope_theta", 1000000)), + rope_scaling=getattr(config, "rope_scaling", None), + ) + + # The FA4 MHA extend selector currently has no sliding-window kernel + # for this draft shape. Use full attention for draft proposals; target + # verification remains authoritative for accepted tokens. + sliding_window = -1 + self.attn = PagedAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + sliding_window_size=sliding_window, + ) + self.attn.non_causal = True + + def _apply_qk_norm( + self, q: torch.Tensor, k: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + q = q.reshape(-1, self.head_dim) + k = k.reshape(-1, self.head_dim) + q = self.q_norm(q).view(-1, self.q_size) + k = self.k_norm(k).view(-1, self.kv_size) + return q, k + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ctx: ForwardContext, + out_cache_loc: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self._apply_qk_norm(q, k) + q, k = self.rotary_emb(positions, q, k) + k_cache = k.view(-1, self.num_kv_heads, self.head_dim) + v_cache = v.view(-1, self.num_kv_heads, self.head_dim) + ctx.token_to_kv_pool.set_kv_buffer( + self.attn, + out_cache_loc, + k_cache, + v_cache, + self.attn.k_scale, + self.attn.v_scale, + ) + attn_output = self.attn( + q, + None, + None, + ctx, + out_cache_loc, + save_kv_cache=False, + ) + if len(attn_output.size()) == 3: + attn_output = attn_output.reshape(attn_output.shape[0], -1) + output, _ = self.o_proj(attn_output) + return output + + def kv_proj_only( + self, hidden_states: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + qkv, _ = self.qkv_proj(hidden_states) + _, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + return k, v + + def apply_k_norm(self, k: torch.Tensor) -> torch.Tensor: + k_shape = k.shape + return self.k_norm(k.reshape(-1, self.head_dim)).view(k_shape) + + def apply_k_rope(self, positions: torch.Tensor, k: torch.Tensor) -> torch.Tensor: + dummy_q = k.new_empty(k.shape) + _, k = self.rotary_emb(positions, dummy_q, k) + return k + + +class DFlashMLP(nn.Module): + def __init__( + self, + config, + mapping: Mapping, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + hidden_size = int(config.hidden_size) + intermediate_size = int(config.intermediate_size) + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), + tp_rank=mapping.dense.tp_rank, + tp_size=mapping.dense.tp_size, + tp_group=mapping.dense.tp_group, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), + reduce_results=False, + tp_rank=mapping.dense.tp_rank, + tp_size=mapping.dense.tp_size, + tp_group=mapping.dense.tp_group, + ) + if getattr(config, "hidden_act", "silu") != "silu": + raise ValueError("DFlash only supports silu activation.") + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DFlashDecoderLayer(nn.Module): + def __init__( + self, + config, + mapping: Mapping, + layer_id: int, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + hidden_size = int(config.hidden_size) + eps = float(getattr(config, "rms_norm_eps", 1e-6)) + self.mapping = mapping + self.input_layernorm = RMSNorm(hidden_size, eps=eps) + self.self_attn = DFlashAttention( + config=config, + mapping=mapping, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + ) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=eps) + self.mlp = DFlashMLP( + config=config, + mapping=mapping, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ctx: ForwardContext, + out_cache_loc: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + if ctx.forward_mode.is_idle(): + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + elif ( + ctx.input_num_tokens > global_server_args_dict["comm_fusion_max_num_tokens"] + ): + hidden_states = all_reduce( + hidden_states, self.mapping.dense.tp_rank, self.mapping.dense.tp_group + ) + hidden_states, residual = self.input_layernorm(hidden_states, residual) + else: + hidden_states, residual, *_ = ( + self.input_layernorm.forward_with_allreduce_fusion( + self.mapping.dense.tp_rank, + self.mapping.dense.tp_group, + hidden_states, + residual, + ) + ) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ctx=ctx, + out_cache_loc=out_cache_loc, + ) + + if ctx.input_num_tokens > global_server_args_dict["comm_fusion_max_num_tokens"]: + hidden_states = all_reduce( + hidden_states, self.mapping.attn.tp_rank, self.mapping.attn.tp_group + ) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + else: + hidden_states, residual, *_ = ( + self.post_attention_layernorm.forward_with_allreduce_fusion( + self.mapping.attn.tp_rank, + self.mapping.attn.tp_group, + hidden_states, + residual, + ) + ) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class DFlashDraftModel(nn.Module): + def __init__( + self, + config, + mapping: Mapping, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.mapping = mapping + eps = float(getattr(config, "rms_norm_eps", 1e-6)) + self.layers = nn.ModuleList( + [ + DFlashDecoderLayer( + config=config, + mapping=mapping, + layer_id=i, + quant_config=quant_config, + prefix=add_prefix(f"layers.{i}", prefix), + ) + for i in range(int(config.num_hidden_layers)) + ] + ) + self.norm = RMSNorm(int(config.hidden_size), eps=eps) + target_layer_ids = (getattr(config, "dflash_config", {}) or {}).get( + "target_layer_ids", [] + ) + self.num_context_features = len(target_layer_ids) + self.fc = nn.Linear( + self.num_context_features * int(config.hidden_size), + int(config.hidden_size), + bias=False, + ) + self.hidden_norm = RMSNorm(int(config.hidden_size), eps=eps) + self.block_size = int(getattr(config, "block_size", 8)) + + def project_target_hidden(self, target_hidden: torch.Tensor) -> torch.Tensor: + return self.hidden_norm(self.fc(target_hidden)) + + @torch.no_grad() + def forward( + self, + ctx: ForwardContext, + input_ids: torch.Tensor, + positions: torch.Tensor, + out_cache_loc: torch.Tensor, + input_lengths: torch.Tensor, + input_embeds: torch.Tensor | None = None, + **kwargs, + ) -> LogitsProcessorOutput: + if input_embeds is None: + if not ctx.forward_mode.is_idle(): + raise ValueError("DFlashDraftModel requires input_embeds.") + hidden_states = self.fc.weight.new_empty((0, int(self.config.hidden_size))) + else: + hidden_states = input_embeds + residual = None + + for layer in self.layers: + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + ctx=ctx, + out_cache_loc=out_cache_loc, + residual=residual, + ) + + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) + + return LogitsProcessorOutput( + next_token_logits=None, hidden_states=hidden_states + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + + def resolve_name(name: str) -> str | None: + if name in params_dict: + return name + if name.startswith("model.") and name[len("model.") :] in params_dict: + return name[len("model.") :] + return None + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if f".{weight_name}." not in name: + continue + resolved = resolve_name(name.replace(weight_name, param_name)) + if resolved is None: + continue + param = params_dict[resolved] + param.weight_loader(param, loaded_weight, shard_id) + break + else: + resolved = resolve_name(name) + if resolved is None: + continue + param = params_dict[resolved] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = [DFlashDraftModel] diff --git a/python/tokenspeed/runtime/spec_decode/algorithm.py b/python/tokenspeed/runtime/spec_decode/algorithm.py index 7f7b10ba5..dfedfddaa 100755 --- a/python/tokenspeed/runtime/spec_decode/algorithm.py +++ b/python/tokenspeed/runtime/spec_decode/algorithm.py @@ -25,6 +25,7 @@ class SpeculativeAlgorithm(IntEnum): NONE = auto() EAGLE3 = auto() MTP = auto() + DFLASH = auto() def is_none(self) -> bool: return self == SpeculativeAlgorithm.NONE @@ -37,6 +38,7 @@ def from_string(name: str | None) -> "SpeculativeAlgorithm": name_map = { "EAGLE3": SpeculativeAlgorithm.EAGLE3, "MTP": SpeculativeAlgorithm.MTP, + "DFLASH": SpeculativeAlgorithm.DFLASH, None: SpeculativeAlgorithm.NONE, } if name is not None: diff --git a/python/tokenspeed/runtime/utils/hf_transformers_utils.py b/python/tokenspeed/runtime/utils/hf_transformers_utils.py index 69aa2af1e..52ca7aa49 100755 --- a/python/tokenspeed/runtime/utils/hf_transformers_utils.py +++ b/python/tokenspeed/runtime/utils/hf_transformers_utils.py @@ -236,6 +236,7 @@ def get_config( and config.architectures and "NextN" not in config.architectures[0] and "Eagle" not in config.architectures[0] + and "DFlash" not in config.architectures[0] ): if config.architectures[0] == "MiniMaxM2ForCausalLM": config.architectures[0] = "LlamaForCausalLMEagle3" diff --git a/python/tokenspeed/runtime/utils/server_args.py b/python/tokenspeed/runtime/utils/server_args.py index 5c21e23b4..d24333835 100755 --- a/python/tokenspeed/runtime/utils/server_args.py +++ b/python/tokenspeed/runtime/utils/server_args.py @@ -340,7 +340,13 @@ def resolve_config_aliases(self): num_speculative_tokens = config.get("num_speculative_tokens") if num_speculative_tokens is not None: - self.speculative_num_steps = int(num_speculative_tokens) + num_speculative_tokens = int(num_speculative_tokens) + if self.speculative_algorithm == "DFLASH": + if self.speculative_num_draft_tokens is None: + self.speculative_num_draft_tokens = num_speculative_tokens + self.speculative_num_steps = max(num_speculative_tokens - 1, 0) + else: + self.speculative_num_steps = num_speculative_tokens if self.speculative_num_draft_tokens is None: self.speculative_num_draft_tokens = self.speculative_num_steps + 1 @@ -518,6 +524,18 @@ def resolve_speculative_decoding(self): if self.speculative_draft_model_quantization == "unquant": self.speculative_draft_model_quantization = None + if self.speculative_algorithm == "DFLASH": + expected_steps = max(int(self.speculative_num_draft_tokens) - 1, 0) + if self.speculative_num_steps == ServerArgs.speculative_num_steps: + self.speculative_num_steps = expected_steps + elif self.speculative_num_steps != expected_steps: + raise ValueError( + "DFLASH requires speculative_num_steps to equal " + "speculative_num_draft_tokens - 1. " + f"Got {self.speculative_num_steps=} and " + f"{self.speculative_num_draft_tokens=}." + ) + if self.eagle3_layers_to_capture is not None: self.eagle3_layers_to_capture = [ int(x) for x in self.eagle3_layers_to_capture.split(",") @@ -1413,7 +1431,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--speculative-algorithm", type=str, - choices=["EAGLE3", "MTP"], + choices=["EAGLE3", "MTP", "DFLASH"], help="Speculative algorithm.", ) parser.add_argument(