Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ enum common_speculative_type {
COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, // standalone draft model speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, // Eagle3 speculative decoding
COMMON_SPECULATIVE_TYPE_DRAFT_MTP, // Multi-token prediction
COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH, // DFlash speculative decoding
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding based on n-grams
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
Expand Down Expand Up @@ -377,7 +378,7 @@ struct common_params_speculative {

uint32_t need_n_rs_seq() const {
bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) {
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3;
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3 || t == COMMON_SPECULATIVE_TYPE_DRAFT_DFLASH;
});

return needs_rs_seq ? draft.n_max : 0u;
Expand Down
303 changes: 302 additions & 1 deletion common/speculative.cpp

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions conversion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
"DeepseekV2ForCausalLM": "deepseek",
"DeepseekV3ForCausalLM": "deepseek",
"DeepseekV32ForCausalLM": "deepseek",
"DFlashDraftModel": "qwen",
"DistilBertForMaskedLM": "bert",
"DistilBertForSequenceClassification": "bert",
"DistilBertModel": "bert",
Expand Down
52 changes: 52 additions & 0 deletions conversion/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,3 +625,55 @@ class Qwen3_5TextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReor
@ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM")
class Qwen3_5MoeTextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase):
model_arch = gguf.MODEL_ARCH.QWEN35MOE


@ModelBase.register("DFlashDraftModel")
class DFlashModel(Qwen3Model):
model_arch = gguf.MODEL_ARCH.DFLASH

def set_vocab(self):
if self.target_model_dir is None:
raise ValueError(
"DFlash draft model requires --target-model-dir to be specified. "
"Please provide the path to the target model directory containing the tokenizer."
)
logger.info(f"DFlash: Using tokenizer from target model: {self.target_model_dir}")
original_dir = self.dir_model
self.dir_model = self.target_model_dir
super().set_vocab()
self.dir_model = original_dir

def set_gguf_parameters(self):
super().set_gguf_parameters()

block_size = self.hparams.get("block_size", 16)
self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.block_size", block_size)
dflash_config = self.hparams.get("dflash_config", {})

target_layer_ids = dflash_config.get("target_layer_ids", [])
if target_layer_ids:
extract_layer_ids = [i + 1 for i in target_layer_ids]
self.gguf_writer.add_array(f"{self.gguf_writer.arch}.target_layers", extract_layer_ids)
Comment on lines +649 to +656

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add proper keys and methods for these please!


mask_token_id = dflash_config.get("mask_token_id", None)
if mask_token_id is not None:
self.gguf_writer.add_mask_token_id(mask_token_id)
Comment on lines +658 to +660

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
mask_token_id = dflash_config.get("mask_token_id", None)
if mask_token_id is not None:
self.gguf_writer.add_mask_token_id(mask_token_id)
mask_token_id = dflash_config.get("mask_token_id", None)
if mask_token_id is not None:
self.hparams["mask_token_id"] = mask_token_id

I'm not sure of the purpose of separating the token id like this, but this would have gotten overridden by SpecialVocab later on if there already was a mask_token_id in the config.


use_sliding_window = self.hparams.get("use_sliding_window", False)
sliding_window = self.hparams.get("sliding_window")
layer_types = self.hparams.get("layer_types")
if use_sliding_window and sliding_window and layer_types:
is_swa = [lt == "sliding_attention" for lt in layer_types]
self.gguf_writer.add_sliding_window(sliding_window)
self.gguf_writer.add_sliding_window_pattern(is_swa)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name == "fc.weight":
yield (name, data_torch)
return
if name == "hidden_norm.weight":
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ENC_OUTPUT_NORM), data_torch)
return
if not name.startswith("model."):
name = "model." + name
Comment on lines +677 to +678

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This belongs in filter_tensors and the two above should have gotten renamed and properly mapped in tensor_mapping.

yield from super().modify_tensors(data_torch, name, bid)
29 changes: 28 additions & 1 deletion docs/speculative.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,32 @@ Supported EAGLE-3 draft models include:

For the full and up-to-date list of supported models, see #18039.

### DFlash (`draft-dflash`)

DFlash produces an entire block of draft tokens in a single forward pass (block diffusion) and
injects the target model's hidden states into the draft model's attention, instead of drafting one
token at a time. This keeps the draft model small while making drafting GPU-friendly. Unlike EAGLE-3
(a single-layer autoregressive draft), the DFlash draft uses several transformer layers but emits a
whole block per draft step.

The draft is a small block-diffusion model trained for a specific target (for example
`z-lab/Qwen3-4B-DFlash` for `Qwen/Qwen3-4B`). Convert it with `--target-model-dir` so it inherits the
target's tokenizer and token embeddings:

```bash
python convert_hf_to_gguf.py z-lab/Qwen3-4B-DFlash \
--target-model-dir Qwen/Qwen3-4B --outtype bf16 --outfile Qwen3-4B-DFlash.gguf

llama-server -m Qwen3-4B.gguf -md Qwen3-4B-DFlash.gguf \
--spec-type draft-dflash --spec-draft-n-max 15 -fa on --jinja
```

`--spec-draft-n-max` is clamped to the draft model's trained block size.

See:

- #22105

### n-gram Cache (`ngram-cache`)

An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences.
Expand Down Expand Up @@ -147,7 +173,7 @@ If a draft model is combined with a draftless decoding the draftless decoding ha
### General Speculative Parameters

```
--spec-type [none|draft-simple|draft-eagle3|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
--spec-type [none|draft-simple|draft-eagle3|draft-dflash|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
comma-separated list of types of speculative decoding to use
(default: none)
(env: LLAMA_ARG_SPEC_TYPE)
Expand Down Expand Up @@ -287,6 +313,7 @@ Specifies a comma-separated list of speculative decoding types to use.
| `none` | No speculative decoding (default) |
| `draft-simple` | Use a simple draft model for speculation |
| `draft-eagle3` | Use an EAGLE-3 draft model that reads the target's hidden states |
| `draft-dflash` | Use a DFlash block-diffusion draft model that emits a block per step |
| `draft-mtp` | Use Multi Token Prediction (MTP) heads from the main model |
| `ngram-cache` | Use n-gram cache lookup |
| `ngram-simple` | Use simple n-gram pattern matching |
Expand Down
18 changes: 18 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ class MODEL_ARCH(IntEnum):
PANGU_EMBED = auto()
MISTRAL3 = auto()
EAGLE3 = auto()
DFLASH = auto()
MISTRAL4 = auto()
PADDLEOCR = auto()
MIMO2 = auto()
Expand Down Expand Up @@ -1074,6 +1075,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
MODEL_ARCH.MISTRAL3: "mistral3",
MODEL_ARCH.EAGLE3: "eagle3",
MODEL_ARCH.DFLASH: "dflash",
MODEL_ARCH.MISTRAL4: "mistral4",
MODEL_ARCH.PADDLEOCR: "paddleocr",
MODEL_ARCH.MIMO2: "mimo2",
Expand Down Expand Up @@ -4086,6 +4088,22 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FC,
MODEL_TENSOR.D2T,
],
MODEL_ARCH.DFLASH: [
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FC,
MODEL_TENSOR.ENC_OUTPUT_NORM,
],
MODEL_ARCH.MISTRAL4: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
Expand Down
1 change: 1 addition & 0 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
{ LLM_ARCH_MISTRAL3, "mistral3" },
{ LLM_ARCH_EAGLE3, "eagle3" },
{ LLM_ARCH_DFLASH, "dflash" },
{ LLM_ARCH_MISTRAL4, "mistral4" },
{ LLM_ARCH_PADDLEOCR, "paddleocr" },
{ LLM_ARCH_MIMO2, "mimo2" },
Expand Down
1 change: 1 addition & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ enum llm_arch {
LLM_ARCH_TALKIE,
LLM_ARCH_MELLUM,
LLM_ARCH_EAGLE3,
LLM_ARCH_DFLASH,
LLM_ARCH_UNKNOWN,
};

Expand Down
4 changes: 2 additions & 2 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ llama_context::llama_context(
cparams.ctx_other = params.ctx_other;
}

if (model.arch == LLM_ARCH_EAGLE3) {
if (model.arch == LLM_ARCH_EAGLE3 || model.arch == LLM_ARCH_DFLASH) {
if (model.tok_embd == nullptr || model.output == nullptr) {
if (params.ctx_other == nullptr) {
throw std::runtime_error("EAGLE3 requires ctx_other to be set (this warning is normal during memory fitting)");
throw std::runtime_error(model.arch_name() + " requires ctx_other to be set (this warning is normal during memory fitting)");
}
cparams.ctx_other = params.ctx_other;
}
Expand Down
7 changes: 6 additions & 1 deletion src/llama-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,11 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
mctx->set_input_k_idxs(self_k_idxs, ubatch);
mctx->set_input_v_idxs(self_v_idxs, ubatch);

mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
// the mask is left unallocated when the graph only stores K/V without attending
// (e.g. DFlash's KV-injection pass)
if (self_kq_mask && self_kq_mask->buffer) {
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}

if (self_k_rot) {
mctx->set_input_k_rot(self_k_rot);
Expand Down Expand Up @@ -904,6 +908,7 @@ void llm_graph_result::reset() {
t_logits = nullptr;
t_embd = nullptr;
t_embd_pooled = nullptr;
t_h_nextn = nullptr;

t_layer_inp.resize(LLAMA_MAX_LAYERS);
std::fill(t_layer_inp.begin(), t_layer_inp.end(), nullptr);
Expand Down
6 changes: 5 additions & 1 deletion src/llama-model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params
return new llama_model_mistral3(params);
case LLM_ARCH_EAGLE3:
return new llama_model_eagle3(params);
case LLM_ARCH_DFLASH:
return new llama_model_dflash(params);
case LLM_ARCH_MIMO2:
return new llama_model_mimo2(params);
case LLM_ARCH_KIMI_LINEAR:
Expand Down Expand Up @@ -2493,6 +2495,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_STEP35:
case LLM_ARCH_TALKIE:
case LLM_ARCH_MELLUM:
case LLM_ARCH_DFLASH:
return LLAMA_ROPE_TYPE_NEOX;

case LLM_ARCH_QWEN2VL:
Expand Down Expand Up @@ -2616,7 +2619,8 @@ bool llama_model_has_encoder(const llama_model * model) {
switch (model->arch) {
case LLM_ARCH_T5:
case LLM_ARCH_T5ENCODER:
case LLM_ARCH_EAGLE3: return true;
case LLM_ARCH_EAGLE3:
case LLM_ARCH_DFLASH: return true;
default: return false;
}
}
Expand Down
Loading
Loading