From 8fac4b1cc8689bcddd7816889f63c34ef2121232 Mon Sep 17 00:00:00 2001 From: ruixiangw Date: Sun, 14 Dec 2025 18:12:33 +0000 Subject: [PATCH 01/21] feat: add EAGLE3 speculative decoding support EAGLE3 is an encoder-decoder based speculative decoding method: - Extracts features from target model at specific layers - Uses feature fusion layer to compress target features - Generates draft tokens with single-layer decoder - Maps draft vocabulary to target vocabulary via d2t tensor Key changes: - Add LLM_ARCH_EAGLE3 architecture - Add EAGLE3 encoder/decoder graph (src/models/eagle3.cpp) - Add feature extraction from target model layers - Add g_embeddings handling for decoder input - Add GGML_TENSOR_FLAG_SYNC for GPU synchronization - Add --eagle3 flag for speculative-simple example - Add EAGLE3 model conversion in convert_hf_to_gguf.py --- common/arg.cpp | 7 + common/common.h | 2 + common/speculative.cpp | 199 +++++++++++++++++ common/speculative.h | 7 + convert_hf_to_gguf.py | 120 +++++++++- .../speculative-simple/speculative-simple.cpp | 145 ++++++++++-- ggml/include/ggml.h | 2 + ggml/src/ggml-backend.cpp | 14 ++ ggml/src/ggml.c | 4 + gguf-py/gguf/constants.py | 29 +++ include/llama.h | 24 ++ src/CMakeLists.txt | 1 + src/llama-arch.cpp | 32 +++ src/llama-arch.h | 7 + src/llama-context.cpp | 208 +++++++++++++++++- src/llama-context.h | 12 + src/llama-cparams.h | 1 + src/llama-graph.cpp | 1 + src/llama-graph.h | 26 +++ src/llama-hparams.h | 7 + src/llama-model.cpp | 87 ++++++++ src/llama-model.h | 10 + src/models/eagle3.cpp | 187 ++++++++++++++++ src/models/llama.cpp | 10 + src/models/models.h | 8 + 25 files changed, 1119 insertions(+), 31 deletions(-) create mode 100644 src/models/eagle3.cpp diff --git a/common/arg.cpp b/common/arg.cpp index aaa7b92a2e91..de8f0355db13 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3007,6 +3007,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.p_min = std::stof(value); } 
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_DRAFT_P_MIN")); + add_opt(common_arg( + {"--eagle3"}, + "use EAGLE3 speculative decoding with the draft model", + [](common_params & params) { + params.speculative.eagle3 = true; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); add_opt(common_arg( {"-cd", "--ctx-size-draft"}, "N", string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx), diff --git a/common/common.h b/common/common.h index 4edb74b7066c..7ba288f188fe 100644 --- a/common/common.h +++ b/common/common.h @@ -241,6 +241,8 @@ struct common_params_speculative { int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) float p_split = 0.1f; // speculative decoding split probability float p_min = 0.75f; // minimum speculative decoding probability (greedy) + + bool eagle3 = false; // use EAGLE3 speculative decoding std::vector> replacements; // main to speculative model replacements std::vector tensor_buft_overrides; diff --git a/common/speculative.cpp b/common/speculative.cpp index 1e12383ae6b6..058e75b79615 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -22,6 +22,11 @@ struct common_speculative { llama_tokens prompt_dft; bool vocab_dft_compatible = true; // whether retokenization is needed std::map tgt_dft_replacements = {}; + + // EAGLE3 specific + struct llama_context * eagle3_encoder = nullptr; + struct llama_context * eagle3_decoder = nullptr; + int32_t eagle3_n_past = 0; // number of verified positions in decoder KV cache }; struct common_speculative * common_speculative_init( @@ -74,6 +79,35 @@ struct common_speculative * common_speculative_init( return result; } +struct common_speculative * common_speculative_init_eagle3( + struct llama_context * ctx_tgt, + struct llama_context * ctx_encoder, + struct llama_context * ctx_decoder) { + + auto * result = new 
common_speculative { + /* .ctx_tgt = */ ctx_tgt, + /* .ctx_dft = */ nullptr, // Not used for EAGLE3 + /* .smpl = */ nullptr, + /* .batch = */ llama_batch_init(llama_n_batch(ctx_decoder), 0, 1), + /* .prompt_dft = */ {}, + /* .vocab_dft_compatible = */ true, // EAGLE3 uses same vocab + /* .tgt_dft_replacements = */ {}, + /* .eagle3_encoder = */ ctx_encoder, + /* .eagle3_decoder = */ ctx_decoder, + }; + + // Initialize sampler for EAGLE3 decoder + { + common_params_sampling params; + params.no_perf = false; + params.top_k = 10; // set 1 for greedy sampling (argmax) to match vLLM's default behavior but >1 always gets higher acceptance rate for eagle3 + params.samplers = { COMMON_SAMPLER_TYPE_TOP_K }; + result->smpl = common_sampler_init(llama_get_model(ctx_decoder), params); + } + + return result; +} + void common_speculative_free(struct common_speculative * spec) { if (spec == nullptr) { return; @@ -81,6 +115,14 @@ void common_speculative_free(struct common_speculative * spec) { common_sampler_free(spec->smpl); + // EAGLE3 cleanup + if (spec->eagle3_encoder) { + llama_free(spec->eagle3_encoder); + } + if (spec->eagle3_decoder) { + llama_free(spec->eagle3_decoder); + } + llama_batch_free(spec->batch); delete spec; @@ -181,12 +223,169 @@ static std::string replace_to_tgt( return result; } +// EAGLE3 Draft Generation with KV Cache Reuse +// +// ============================================================================ +// EXAMPLE: Two rounds of speculative decoding +// ============================================================================ +// +// ROUND 1 (Initial): +// Prompt: [t0, t1, t2, t3, t4], target generates t5 +// prompt_tgt = [t0, t1, t2, t3, t4], id_last = t5 (GENERATED) +// n = 5, n_past = 0, n_new = 5 +// +// Step 1: Encoder +// features: [f0, f1, f2, f3, f4] → g_embeddings: [g0, g1, g2, g3, g4] +// +// Step 2: Decoder batch (positions 0-4) +// tokens: [t1, t2, t3, t4, t5] ← prompt[1:] + id_last +// g_embd: [g0, g1, g2, g3, g4] +// positions: [0, 1, 
2, 3, 4 ] +// → KV cache: [0, 1, 2, 3, 4] +// → sample d1 from logits[4] +// +// Step 3: Autoregressive (positions 5, 6, ...) +// pos 5: token=d1, g_embd=prenorm[4] → KV cache: [0,1,2,3,4,5] → d2 +// pos 6: token=d2, g_embd=prenorm → KV cache: [0,1,2,3,4,5,6] → d3 +// +// Output: [d1, d2, d3] +// Update: n_past = 5 (verified positions from batch decode) +// +// ROUND 2 (assuming d1 accepted, d2/d3 rejected): +// prompt_tgt = [t0, t1, t2, t3, t4, t5, d1], id_last = t6 (new target output) +// n = 7, n_past = 5, n_new = 2 +// +// Step 1: Clear KV cache [5, inf) - remove draft positions +// KV cache: [0, 1, 2, 3, 4] (reuse from round 1!) +// +// Step 2: Encoder (only new tokens) +// features: [f5, f6] → g_embeddings: [g5, g6] +// +// Step 3: Decoder batch (only new positions 5-6) +// tokens: [d1, t6] (prompt_tgt[6], id_last) +// g_embd: [g5, g6] +// positions: [5, 6 ] +// → KV cache: [0,1,2,3,4] + [5,6] = [0,1,2,3,4,5,6] +// → sample d1' from logits[1] (last position in batch) +// +// Step 4: Autoregressive... +// +// ============================================================================ +// +// Key insight: Decoder KV cache stores K/V computed from (tok_embd + g_embd). +// For verified positions, both tok_embd and g_embd are fixed (encoder output), +// so KV cache can be reused. Draft positions use prenorm as g_embd, which +// differs from encoder output, so they must be cleared and recomputed. 
+// +static llama_tokens gen_eagle3_draft( + struct common_speculative * spec, + struct common_speculative_params params, + const llama_tokens & prompt_tgt, + llama_token id_last) { + + auto * ctx_tgt = spec->ctx_tgt; + auto * ctx_encoder = spec->eagle3_encoder; + auto * ctx_decoder = spec->eagle3_decoder; + auto * smpl = spec->smpl; + auto & batch = spec->batch; + + const int n_embd = llama_model_n_embd(llama_get_model(ctx_encoder)); + const int n = (int)prompt_tgt.size(); + const int n_new = n - spec->eagle3_n_past; + + GGML_ASSERT(n >= 1 && "prompt_tgt is empty"); + GGML_ASSERT(n_new >= 1 && "must have at least 1 new token"); + + // Clear draft positions from decoder KV cache [n_past, inf) + llama_memory_seq_rm(llama_get_memory(ctx_decoder), 0, spec->eagle3_n_past, -1); + + // Encoder: features → g_embeddings + const float * features = llama_get_eagle3_target_features(ctx_tgt); + GGML_ASSERT(features && "no target features"); + + llama_batch enc_batch = { + /*.n_tokens =*/ n_new, + /*.token =*/ nullptr, + /*.embd =*/ const_cast(features), + /*.pos =*/ nullptr, + /*.n_seq_id =*/ nullptr, + /*.seq_id =*/ nullptr, + /*.logits =*/ nullptr, + }; + GGML_ASSERT(llama_encode(ctx_encoder, enc_batch) == 0); + + const float * g_embd = llama_get_embeddings(ctx_encoder); + GGML_ASSERT(g_embd && "encoder output failed"); + + // Decoder batch: process new tokens with KV cache reuse + llama_set_eagle3_g_embeddings(ctx_decoder, g_embd, n_embd, n_new); + + common_batch_clear(batch); + for (int i = 0; i < n_new; i++) { + const int pos = spec->eagle3_n_past + i; + const llama_token tok = (pos < n - 1) ? 
prompt_tgt[pos + 1] : id_last; + common_batch_add(batch, tok, pos, {0}, true); + } + + GGML_ASSERT(llama_decode(ctx_decoder, batch) == 0); + + spec->eagle3_n_past = n; // update verified positions + + // Sample draft tokens + llama_tokens result; + common_sampler_reset(smpl); + + // Sample and check probability (consistent with standard speculative decoding) + auto sample_and_check = [&](int idx) -> bool { + common_sampler_sample(smpl, ctx_decoder, idx); + + const auto * cur_p = common_sampler_get_candidates(smpl, true); + const llama_token id = cur_p->data[0].id; + + common_sampler_accept(smpl, id, true); + result.push_back(id); + + return cur_p->data[0].p >= params.p_min; + }; + + // First draft token from batch decode + if (!sample_and_check(n_new - 1)) { + return result; + } + + // Autoregressive: use prenorm as g_embd (-1 = last output) + const float * prenorm = llama_get_embeddings_ith(ctx_decoder, -1); + + for (int i = 1; i < params.n_draft; i++) { + GGML_ASSERT(prenorm && "prenorm failed"); + llama_set_eagle3_g_embeddings(ctx_decoder, prenorm, n_embd, 1); + + common_batch_clear(batch); + common_batch_add(batch, result.back(), n - 1 + i, {0}, true); + GGML_ASSERT(llama_decode(ctx_decoder, batch) == 0); + + prenorm = llama_get_embeddings_ith(ctx_decoder, -1); + + if (!sample_and_check(0)) { + break; + } + } + + return result; +} llama_tokens common_speculative_gen_draft( struct common_speculative * spec, struct common_speculative_params params, const llama_tokens & prompt_tgt_main_model, // specified in target model vocab llama_token id_last) { + + // EAGLE3 path + if (spec->eagle3_encoder && spec->eagle3_decoder) { + return gen_eagle3_draft(spec, params, prompt_tgt_main_model, id_last); + } + + // Standard draft model path auto & batch = spec->batch; auto & ctx_tgt = spec->ctx_tgt; auto & ctx_dft = spec->ctx_dft; diff --git a/common/speculative.h b/common/speculative.h index e69d7aaa1eb0..feef3c768fa2 100644 --- a/common/speculative.h +++ 
b/common/speculative.h @@ -17,6 +17,13 @@ struct common_speculative * common_speculative_init( struct llama_context * ctx_dft ); +// EAGLE3: Initialize speculative decoding with EAGLE3 encoder and decoder contexts +struct common_speculative * common_speculative_init_eagle3( + struct llama_context * ctx_tgt, + struct llama_context * ctx_encoder, + struct llama_context * ctx_decoder +); + void common_speculative_free(struct common_speculative * spec); bool common_speculative_are_compatible( diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3f861f2a6a53..0f29fbd3fed0 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -97,6 +97,7 @@ class ModelBase: metadata_override: Path | None dir_model_card: Path remote_hf_model_id: str | None + target_model_dir: Path | None # subclasses should define this! model_arch: gguf.MODEL_ARCH @@ -116,7 +117,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None, disable_mistral_community_chat_template: bool = False, - sentence_transformers_dense_modules: bool = False): + sentence_transformers_dense_modules: bool = False, target_model_dir: Path | None = None): if type(self) is ModelBase or \ type(self) is TextModel or \ type(self) is MmprojModel: @@ -135,6 +136,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, self.dry_run = dry_run self.remote_hf_model_id = remote_hf_model_id self.sentence_transformers_dense_modules = sentence_transformers_dense_modules + self.target_model_dir = target_model_dir self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams self.rope_parameters = self.hparams.get("rope_parameters", self.hparams.get("rope_scaling")) or {} self.model_tensors = 
self.index_tensors(remote_hf_model_id=remote_hf_model_id) @@ -2373,7 +2375,55 @@ def __init__(self, *args, **kwargs): if self.hf_arch == "VLlama3ForCausalLM": self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32) + # detect EAGLE-3 llama checkpoint + if "draft_vocab_size" in self.hparams and self.hparams["num_hidden_layers"] == 1: + self.is_eagle3 = True + self.model_arch = gguf.MODEL_ARCH.EAGLE3 + logger.info("Detected EAGLE-3 draft model, switching to EAGLE3 architecture") + # Re-initialize tensor_map with EAGLE3 architecture + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + # Update gguf_writer architecture + self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch] + self.gguf_writer.add_architecture() + if not hasattr(self, 'target_model_dir') or not self.target_model_dir: + raise ValueError( + "EAGLE3 model requires --target-model-dir to be specified. " + "Please provide the path to the target model directory to read config.json" + ) + # Read both EAGLE3 raw config and target model config + with open(self.dir_model / "config.json", 'r', encoding='utf-8') as f: + eagle3_raw_config = json.load(f) + with open(self.target_model_dir / "config.json", 'r', encoding='utf-8') as f: + target_config = json.load(f) + + # EAGLE3 extract_layers + target_num_layers = target_config["num_hidden_layers"] + extract_layers = [2, target_num_layers // 2, target_num_layers - 3] + logger.info(f"EAGLE3: extract_layers = {extract_layers} (target model has {target_num_layers} layers)") + self.gguf_writer.add_array(f"{self.gguf_writer.arch}.extract_layers", extract_layers) + + # EAGLE3 target_hidden_size: prefer EAGLE3 config, fallback to target config + if "target_hidden_size" in eagle3_raw_config and eagle3_raw_config["target_hidden_size"] is not None: + target_hidden_size = eagle3_raw_config["target_hidden_size"] + logger.info(f"EAGLE3: target_hidden_size = {target_hidden_size} (from EAGLE3 config)") + else: + 
target_hidden_size = target_config["hidden_size"] + logger.info(f"EAGLE3: target_hidden_size = {target_hidden_size} (from target model config)") + self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.target_hidden_size", target_hidden_size) + def set_vocab(self): + # For EAGLE-3 models, use tokenizer from target model if provided + if hasattr(self, 'is_eagle3') and self.is_eagle3: + if self.target_model_dir is None: + raise ValueError( + "EAGLE-3 draft model requires --target-model-dir to be specified. " + "Please provide the path to the target model directory containing the tokenizer." + ) + logger.info(f"EAGLE-3: Using tokenizer from target model: {self.target_model_dir}") + # Temporarily swap dir_model to load tokenizer from target model + original_dir_model = self.dir_model + self.dir_model = self.target_model_dir + if self.is_mistral_format: return self._set_vocab_mistral() @@ -2391,6 +2441,10 @@ def set_vocab(self): # Llama 3 self._set_vocab_gpt2() + # Restore original dir_model for EAGLE-3 + if hasattr(self, 'is_eagle3') and self.is_eagle3: + self.dir_model = original_dir_model + # Apply to CodeLlama only (and ignore for Llama 3 with a vocab size of 128256) if self.hparams.get("vocab_size", 32000) == 32016: special_vocab = gguf.SpecialVocab( @@ -2435,7 +2489,45 @@ def permute(weights: Tensor, n_head: int, n_head_kv: int | None): _experts: list[dict[str, Tensor]] | None = None + def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]: + tensors = super().index_tensors(remote_hf_model_id) + # EAGLE-3 detection: check hparams directly (before self.is_eagle3 is set) + if "draft_vocab_size" in self.hparams and self.hparams["num_hidden_layers"] == 1: + logger.info("EAGLE-3: Renaming midlayer.* to model.layers.0.*") + new_tensors = {} + # EAGLE-3: rename midlayer.* to model.layers.0.* for compatibility with llama model + for name, gen in tensors.items(): + if name.startswith("midlayer."): + new_name = "model.layers.0." 
+ name[len("midlayer."):] + new_tensors[new_name] = gen + else: + new_tensors[name] = gen + return new_tensors + else: + return tensors + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + + # Eagle-3 llama checkpoint special handling + if hasattr(self, 'is_eagle3') and self.is_eagle3: + # Eagle-3 llama checkpoint special weights handling + # fc.weight: feature fusion layer + if name == "fc.weight": + return [(name, data_torch)] + # d2t: draft to target vocabulary mapping + elif name == "d2t": + # Skip parent class processing (store for manual handling in prepare_tensors) + if not hasattr(self, '_eagle3_int_tensors'): + self._eagle3_int_tensors = {} + self._eagle3_int_tensors[name] = data_torch + return [] + # t2d: target to draft vocabulary mapping (not used, skip completely) + elif name == "t2d": + return [] + # hidden_norm: EAGLE-3 specific layer normalization + elif name == "model.layers.0.hidden_norm.weight": + return [("blk.0.hidden_norm.weight", data_torch)] + n_head = self.find_hparam(["n_heads", "num_attention_heads"]) n_kv_head = self.find_hparam(["n_kv_heads", "num_key_value_heads"]) @@ -2538,8 +2630,26 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) def prepare_tensors(self): + # EAGLE-3: collect original dtypes BEFORE parent class converts them to F32 + eagle3_original_dtypes = {} + if hasattr(self, 'is_eagle3') and self.is_eagle3: + for name, data_torch in self.get_tensors(): + if name == "d2t": + eagle3_original_dtypes[name] = data_torch.dtype + super().prepare_tensors() + if hasattr(self, 'is_eagle3') and self.is_eagle3 and hasattr(self, '_eagle3_int_tensors'): + for name, data_torch in self._eagle3_int_tensors.items(): + old_dtype = eagle3_original_dtypes.get(name, data_torch.dtype) + # Keep as int64 to match original torch tensor dtype + data = 
data_torch.to(torch.int64).numpy() + data_qtype = gguf.GGMLQuantizationType.I64 + + shape_str = f"{{{', '.join(str(n) for n in reversed(data.shape))}}}" + logger.info(f"{name + ',':<30} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype) + if self._experts is not None: # flatten `list[dict[str, Tensor]]` into `list[str]` experts = [k for d in self._experts for k in d.keys()] @@ -10125,6 +10235,7 @@ class LazyTorchTensor(gguf.LazyBase): torch.float16: np.float16, torch.float32: np.float32, torch.uint8: np.uint8, + torch.int64: np.int64, } # only used when byteswapping data. Only correct size is needed @@ -10285,6 +10396,10 @@ def parse_args() -> argparse.Namespace: "--no-tensor-first-split", action="store_true", help="do not add tensors to the first split (disabled by default)" ) + parser.add_argument( + "--target-model-dir", type=str, default=None, + help="directory containing target model tokenizer (for EAGLE-3 draft models that don't have their own tokenizer)", + ) parser.add_argument( "--metadata", type=Path, help="Specify the path for an authorship metadata override file" @@ -10457,7 +10572,8 @@ def main() -> None: split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run, small_first_shard=args.no_tensor_first_split, remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template, - sentence_transformers_dense_modules=args.sentence_transformers_dense_modules + sentence_transformers_dense_modules=args.sentence_transformers_dense_modules, + target_model_dir=Path(args.target_model_dir) if args.target_model_dir else None ) if args.vocab_only: diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 8141052a2276..3b65f3c5b107 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -4,6 +4,7 @@ 
#include "speculative.h" #include "log.h" #include "llama.h" +#include "chat.h" #include #include @@ -34,16 +35,42 @@ int main(int argc, char ** argv) { llama_numa_init(params.numa); llama_model * model_tgt = NULL; - //llama_model * model_dft = NULL; + llama_model * model_dft = NULL; llama_context * ctx_tgt = NULL; llama_context * ctx_dft = NULL; - // load the target model - auto llama_init_tgt = common_init_from_params(params); + // EAGLE3 specific contexts + llama_context * ctx_encoder = NULL; + llama_context * ctx_decoder = NULL; + + // For EAGLE3: load both draft model and target model + if (params.speculative.eagle3) { + llama_model_params dft_mp = llama_model_default_params(); + dft_mp.n_gpu_layers = params.speculative.n_gpu_layers; + model_dft = llama_model_load_from_file(params.speculative.model.path.c_str(), dft_mp); + if (!model_dft) { + LOG_ERR("failed to load EAGLE3 draft model\n"); + return 1; + } - model_tgt = llama_init_tgt->model(); - ctx_tgt = llama_init_tgt->context(); + llama_model_params tgt_mp = llama_model_default_params(); + tgt_mp.n_gpu_layers = params.n_gpu_layers; + model_tgt = llama_model_load_from_file(params.model.path.c_str(), tgt_mp); + if (!model_tgt) { + LOG_ERR("failed to load target model\n"); + return 1; + } + + llama_context_params tcp = common_context_params_to_llama(params); + tcp.eagle3_model = model_dft; // Enable feature extraction + ctx_tgt = llama_init_from_model(model_tgt, tcp); + } else { + // Standard load the target model + auto llama_init_tgt = common_init_from_params(params); + model_tgt = llama_init_tgt->model(); + ctx_tgt = llama_init_tgt->context(); + } const llama_vocab * vocab = llama_model_get_vocab(model_tgt); @@ -61,18 +88,57 @@ int main(int argc, char ** argv) { params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; params.tensor_buft_overrides = params.speculative.tensor_buft_overrides; - auto llama_init_dft = common_init_from_params(params); + if (params.speculative.eagle3) { + 
// EAGLE3: create encoder and decoder contexts + llama_context_params enc_params = common_context_params_to_llama(params); + enc_params.embeddings = true; + ctx_encoder = llama_init_from_model(model_dft, enc_params); + if (!ctx_encoder) { + LOG_ERR("failed to create EAGLE3 encoder context\n"); + return 1; + } - //model_dft = llama_init_dft->model(); - ctx_dft = llama_init_dft->context(); + llama_context_params dec_params = common_context_params_to_llama(params); + dec_params.target_model = model_tgt; + dec_params.embeddings = true; + ctx_decoder = llama_init_from_model(model_dft, dec_params); + if (!ctx_decoder) { + LOG_ERR("failed to create EAGLE3 decoder context\n"); + return 1; + } + } else { + // Standard: load draft model context + auto llama_init_dft = common_init_from_params(params); + model_dft = llama_init_dft->model(); + ctx_dft = llama_init_dft->context(); + + if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) { + LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str()); + } + } - if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) { - LOG_INF("the draft model '%s' is not compatible with the target model '%s'. 
tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str()); + // Apply chat template for EAGLE3 if available which can increase the acceptance rate + std::string prompt = params.prompt; + if (params.speculative.eagle3) { + auto chat_templates = common_chat_templates_init(model_tgt, params.chat_template); + if (common_chat_templates_was_explicit(chat_templates.get())) { + std::vector chat_msgs; + common_chat_msg user_msg; + user_msg.role = "user"; + user_msg.content = params.prompt; + chat_msgs.push_back(user_msg); + + common_chat_templates_inputs inputs; + inputs.messages = chat_msgs; + inputs.add_generation_prompt = true; + prompt = common_chat_templates_apply(chat_templates.get(), inputs).prompt; + LOG_INF("%s: EAGLE3 chat template applied\n", __func__); + } } // Tokenize the prompt std::vector inp; - inp = common_tokenize(ctx_tgt, params.prompt, true, true); + inp = common_tokenize(ctx_tgt, prompt, true, true); if (llama_n_ctx(ctx_tgt) < (uint32_t) inp.size()) { LOG_ERR("%s: the prompt exceeds the context size (%d tokens, ctx %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt)); @@ -115,26 +181,52 @@ int main(int argc, char ** argv) { struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling); // eval the prompt - llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1)); + llama_token id_last; + llama_tokens prompt_tgt; + int n_past; - // note: keep the last token separate! 
- llama_token id_last = inp.back(); + if (params.speculative.eagle3) { + // Target model decodes full prompt and sample first token and intermediate features are extracted + llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size())); - // all tokens currently in the target context - llama_tokens prompt_tgt(inp.begin(), inp.end() - 1); - prompt_tgt.reserve(llama_n_ctx(ctx_tgt)); + id_last = common_sampler_sample(smpl, ctx_tgt, -1); + common_sampler_accept(smpl, id_last, true); + LOG("%s", common_token_to_piece(ctx_tgt, id_last).c_str()); + n_predict++; - int n_past = inp.size() - 1; + // all tokens currently in the target context + prompt_tgt.assign(inp.begin(), inp.end()); + prompt_tgt.reserve(llama_n_ctx(ctx_tgt)); + + n_past = inp.size(); + } else { + llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1)); + + // note: keep the last token separate! + id_last = inp.back(); + + // all tokens currently in the target context + prompt_tgt.assign(inp.begin(), inp.end() - 1); + prompt_tgt.reserve(llama_n_ctx(ctx_tgt)); + + n_past = inp.size() - 1; + } // init the speculator struct common_speculative_params params_spec; params_spec.n_draft = n_draft; - params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft; params_spec.p_min = p_min; - struct common_speculative * spec = common_speculative_init(ctx_tgt, ctx_dft); - for (auto &pair : params.speculative.replacements) { - common_speculative_add_replacement_tgt_dft(spec, pair.first.c_str(), pair.second.c_str()); + struct common_speculative * spec = NULL; + + if (params.speculative.eagle3) { + spec = common_speculative_init_eagle3(ctx_tgt, ctx_encoder, ctx_decoder); + } else { + params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft; + spec = common_speculative_init(ctx_tgt, ctx_dft); + for (auto &pair : params.speculative.replacements) { + common_speculative_add_replacement_tgt_dft(spec, pair.first.c_str(), pair.second.c_str()); + } } llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); 
@@ -249,7 +341,14 @@ int main(int argc, char ** argv) { LOG_INF("\n"); LOG_INF("draft:\n\n"); - llama_perf_context_print(ctx_dft); + if (ctx_dft) { + llama_perf_context_print(ctx_dft); + } else if (ctx_encoder && ctx_decoder) { + LOG_INF(" Eagle3 Draft encoder:\n"); + llama_perf_context_print(ctx_encoder); + LOG_INF("\nEagle3 Draft decoder:\n"); + llama_perf_context_print(ctx_decoder); + } LOG_INF("\n"); LOG_INF("target:\n\n"); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 686da3dbd107..fa73e8216b84 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -629,6 +629,7 @@ extern "C" { GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) + GGML_TENSOR_FLAG_SYNC = 16, // ...forces a new split/sync point in the scheduler (e.g. for EAGLE3 decoder) }; enum ggml_tri_type { @@ -853,6 +854,7 @@ extern "C" { GGML_API void ggml_set_output(struct ggml_tensor * tensor); GGML_API void ggml_set_param(struct ggml_tensor * tensor); GGML_API void ggml_set_loss(struct ggml_tensor * tensor); + GGML_API void ggml_set_sync(struct ggml_tensor * tensor); // force sync point in scheduler // // operations on tensors with backpropagation diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 08681f35e3f9..8e30d48ccc05 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1202,6 +1202,11 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra } } + // check if this node requires a sync point (e.g. 
for EAGLE3 parallel path fix) + if (node->flags & GGML_TENSOR_FLAG_SYNC) { + need_new_split = true; + } + if (node_backend_id != cur_backend_id || need_new_split) { split->i_end = i; i_split++; @@ -1576,6 +1581,15 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s if (ec != GGML_STATUS_SUCCESS) { return ec; } + + // If any node in this split has SYNC flag, synchronize after compute + // This ensures the sync node is complete before next split (e.g. for EAGLE3 parallel path sync fix) + for (int j = 0; j < split->graph.n_nodes; j++) { + if (split->graph.nodes[j]->flags & GGML_TENSOR_FLAG_SYNC) { + ggml_backend_synchronize(split_backend); + break; + } + } } else { // similar to ggml_backend_compare_graph_backend for (int j0 = 0; j0 < split->graph.n_nodes; j0++) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index f0913cd35967..4625c3bd7707 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7451,6 +7451,10 @@ void ggml_set_loss(struct ggml_tensor * tensor) { tensor->flags |= GGML_TENSOR_FLAG_LOSS; } +void ggml_set_sync(struct ggml_tensor * tensor) { + tensor->flags |= GGML_TENSOR_FLAG_SYNC; +} + //////////////////////////////////////////////////////////////////////////////// void ggml_quantize_init(enum ggml_type type) { diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 2b8489c591b3..7d9d9b103b62 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -147,6 +147,8 @@ class LLM: EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input" DENSE_FEAT_IN_SIZE = "{arch}.{dense}_feat_in" DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out" + EAGLE3_EXTRACT_LAYERS = "{arch}.extract_layers" + EAGLE3_TARGET_HIDDEN_SIZE = "{arch}.target_hidden_size" class Attention: HEAD_COUNT = "{arch}.attention.head_count" @@ -446,6 +448,7 @@ class MODEL_ARCH(IntEnum): RND1 = auto() PANGU_EMBED = auto() MISTRAL3 = auto() + EAGLE3 = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -710,6 +713,10 @@ class 
MODEL_TENSOR(IntEnum): NEXTN_HNORM = auto() NEXTN_SHARED_HEAD_HEAD = auto() NEXTN_SHARED_HEAD_NORM = auto() + # EAGLE3 specific tensors + EAGLE3_FC = auto() # feature fusion layer + EAGLE3_HIDDEN_NORM = auto() # hidden normalization + EAGLE3_D2T = auto() # draft to target vocabulary mapping MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -820,6 +827,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.RND1: "rnd1", MODEL_ARCH.PANGU_EMBED: "pangu-embedded", MODEL_ARCH.MISTRAL3: "mistral3", + MODEL_ARCH.EAGLE3: "eagle3", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -1082,6 +1090,9 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.NEXTN_HNORM: "blk.{bid}.nextn.hnorm", MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD: "blk.{bid}.nextn.shared_head_head", MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM: "blk.{bid}.nextn.shared_head_norm", + MODEL_TENSOR.EAGLE3_FC: "fc", + MODEL_TENSOR.EAGLE3_HIDDEN_NORM: "blk.{bid}.hidden_norm", + MODEL_TENSOR.EAGLE3_D2T: "d2t", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -3094,6 +3105,24 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.EAGLE3: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.EAGLE3_FC, + MODEL_TENSOR.EAGLE3_HIDDEN_NORM, + MODEL_TENSOR.EAGLE3_D2T, + ], # TODO } diff --git a/include/llama.h b/include/llama.h index b52eaacfa7e8..c502b9ad0eac 100644 --- a/include/llama.h +++ b/include/llama.h @@ -363,6 +363,13 @@ extern "C" { bool kv_unified; // use a unified buffer across the input sequences when computing the attention // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix // ref: https://github.com/ggml-org/llama.cpp/pull/14363 
+ + // EAGLE3 extraction configuration + // When eagle3_model is set, layer extraction is automatically enabled + const struct llama_model * eagle3_model; // EAGLE3 model to read extract_layers configuration from + // If non-NULL, enables automatic feature extraction + const struct llama_model * target_model; // reference to target model + // only used to share embedding layer with eagle3 model }; // model quantization parameters @@ -846,6 +853,23 @@ extern "C" { llama_seq_id dest_seq_id, llama_state_seq_flags flags); + // + // EAGLE3 draft model support + // + + // Get pointer to target model features extracted for EAGLE3 encoder + // Returns NULL if no features are available + // Format: [3*n_embd, n_tokens] - use model.hparams.n_embd and batch.n_tokens for dimensions + LLAMA_API const float * llama_get_eagle3_target_features(struct llama_context * ctx); + + // Set g_embeddings from EAGLE3 encoder output for decoder input + // g_embd: pointer to encoder output embeddings + LLAMA_API void llama_set_eagle3_g_embeddings( + struct llama_context * ctx, + const float * g_embd, + int32_t n_embd, + int32_t n_tokens); + // // Decoding // diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4192af7c0c3b..4ffbc49a8027 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -58,6 +58,7 @@ add_library(llama models/deepseek2.cpp models/dots1.cpp models/dream.cpp + models/eagle3.cpp models/ernie4-5-moe.cpp models/ernie4-5.cpp models/exaone.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 64ad1b77690a..b8370c29553a 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -112,6 +112,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, + { LLM_ARCH_EAGLE3, "eagle3" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -245,6 +246,9 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" }, + { 
LLM_KV_EAGLE3_EXTRACT_LAYERS, "%s.extract_layers" }, + { LLM_KV_EAGLE3_TARGET_HIDDEN_SIZE, "%s.target_hidden_size" }, + { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" }, // sentence-transformers dense modules feature dims { LLM_KV_DENSE_2_FEAT_IN, "%s.dense_2_feat_in" }, @@ -2540,6 +2544,30 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_EAGLE3, + { + // Token embeddings (optional - Llama 3.3 70B EAGLE3 has its own, Llama 3.1 8B EAGLE3 uses target model's) + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, // Optional - only if EAGLE3 config has rope_scaling + // Single decoder layer (blk.0) + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + // EAGLE-3 specific layers + { LLM_TENSOR_EAGLE3_HIDDEN_NORM, "blk.%d.hidden_norm" }, + { LLM_TENSOR_EAGLE3_FC, "fc" }, + { LLM_TENSOR_EAGLE3_D2T, "d2t" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2742,6 +2770,10 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + // EAGLE-3 tensors + {LLM_TENSOR_EAGLE3_FC, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_EAGLE3_HIDDEN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_EAGLE3_D2T, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff 
--git a/src/llama-arch.h b/src/llama-arch.h index e113180024d4..0aa7dd80d751 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -117,6 +117,7 @@ enum llm_arch { LLM_ARCH_PANGU_EMBED, LLM_ARCH_MISTRAL3, LLM_ARCH_UNKNOWN, + LLM_ARCH_EAGLE3, }; enum llm_kv { @@ -287,6 +288,9 @@ enum llm_kv { LLM_KV_CLASSIFIER_OUTPUT_LABELS, + LLM_KV_EAGLE3_EXTRACT_LAYERS, + LLM_KV_EAGLE3_TARGET_HIDDEN_SIZE, + LLM_KV_SHORTCONV_L_CACHE, LLM_KV_XIELU_ALPHA_N, @@ -492,6 +496,9 @@ enum llm_tensor { LLM_TENSOR_NEXTN_HNORM, LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + LLM_TENSOR_EAGLE3_FC, // eagle3: feature fusion layer + LLM_TENSOR_EAGLE3_HIDDEN_NORM, // eagle3: additional normalization layer + LLM_TENSOR_EAGLE3_D2T, // eagle3: draft to target vocabulary mapping }; enum llm_tensor_layer { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 2a17e44ecdff..ea6dfaea3c94 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -135,6 +135,7 @@ llama_context::llama_context( cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; + cparams.eagle3_extract_enabled = (params.eagle3_model != nullptr); // auto-enable if eagle3_model is provided { const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE"); @@ -333,6 +334,30 @@ llama_context::llama_context( cross.v_embd.clear(); + // Initialize EAGLE3 feature extraction configuration + if (cparams.eagle3_extract_enabled) { + // Feature extraction layers configuration must come from EAGLE3 model + if (!params.eagle3_model) { + LLAMA_LOG_ERROR("%s: EAGLE3 extraction enabled but eagle3_model not provided\n", __func__); + throw std::runtime_error("EAGLE3 extraction requires eagle3_model parameter"); + } + + const auto & eagle3_hparams = params.eagle3_model->hparams; + // Copy feature extraction layer indices from EAGLE3 model's hparams + eagle3.extract_layer_indices.assign( + eagle3_hparams.eagle3_extract_layers.begin(), + 
eagle3_hparams.eagle3_extract_layers.end() + ); + + // Allocate tensors array for extraction + eagle3.extract_tensors.resize(eagle3.extract_layer_indices.size(), nullptr); + + LLAMA_LOG_INFO("%s: EAGLE3 extraction enabled for layers [%d, %d, %d]\n", __func__, + eagle3.extract_layer_indices[0], + eagle3.extract_layer_indices[1], + eagle3.extract_layer_indices[2]); + } + // avoid reserving graphs with zero outputs - assume one output per sequence n_outputs = n_seqs; @@ -832,6 +857,14 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll //const auto t_start_us = ggml_time_us(); res->set_inputs(&ubatch); + + // EAGLE3: Fill g_embeddings for decoder input + if (model.arch == LLM_ARCH_EAGLE3 && gtype == LLM_GRAPH_TYPE_DECODER && !eagle3.g_embeddings.empty()) { + ggml_tensor * g_embd = ggml_graph_get_tensor(gf, "inp_g_embeddings"); + if (g_embd) { + ggml_backend_tensor_set(g_embd, eagle3.g_embeddings.data(), 0, ggml_nbytes(g_embd)); + } + } //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); } @@ -843,6 +876,11 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll return nullptr; } + // EAGLE3: Extract intermediate layer features after graph execution + if (cparams.eagle3_extract_enabled && !eagle3.extract_tensors.empty()) { + extract_eagle3_features(ubatch); + } + ret = GGML_STATUS_SUCCESS; return res; @@ -858,7 +896,8 @@ int llama_context::encode(const llama_batch & batch_inp) { const auto & hparams = model.hparams; - const int64_t n_embd = hparams.n_embd_inp(); + // EAGLE3: use 3*target_hidden_size for concatenated features input + const int64_t n_embd = (model.arch == LLM_ARCH_EAGLE3 && batch_inp.embd) ? 
3 * hparams.eagle3_target_hidden_size : hparams.n_embd; const int64_t n_vocab = model.vocab.n_tokens(); // note: during encode, we always pass the full sequence starting from pos = 0 @@ -941,8 +980,15 @@ int llama_context::encode(const llama_batch & batch_inp) { // extract token embeddings GGML_ASSERT(embd != nullptr); - GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float)); + if (model.arch == LLM_ARCH_EAGLE3) { + // g_embeddings are stored temporarily in embd buffer + const int64_t out_embd = hparams.n_embd; + GGML_ASSERT(n_tokens * out_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens * out_embd * sizeof(float)); + } else { + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float)); + } } break; case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_CLS: @@ -1181,7 +1227,8 @@ int llama_context::decode(const llama_batch & batch_inp) { auto * t_logits = res->get_logits(); auto * t_embd = cparams.embeddings ? 
res->get_embd() : nullptr; - if (t_embd && res->get_embd_pooled()) { + // For EAGLE3, don't override t_embd with t_embd_pooled - we need the prenorm value during eagle3 decoder autoregressive generation + if (t_embd && res->get_embd_pooled() && model.arch != LLM_ARCH_EAGLE3) { t_embd = res->get_embd_pooled(); } @@ -1196,7 +1243,39 @@ int llama_context::decode(const llama_batch & batch_inp) { if (n_outputs) { GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); - ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); + + // EAGLE3: Map draft vocab to target vocab + if (model.arch == LLM_ARCH_EAGLE3 && model.d2t) { + static thread_local std::vector eagle3_d2t_map; + static thread_local std::vector eagle3_draft_logits; + + const int64_t draft_vocab_size = t_logits->ne[0]; + const uint32_t last_idx = n_outputs - 1; + + // Load d2t mapping once (on first call) + if (eagle3_d2t_map.empty()) { + eagle3_d2t_map.resize(model.d2t->ne[0]); + ggml_backend_tensor_get(model.d2t, eagle3_d2t_map.data(), 0, eagle3_d2t_map.size() * sizeof(int64_t)); + } + + // Read only the last token's draft logits + eagle3_draft_logits.resize(draft_vocab_size); + const size_t last_offset = last_idx * draft_vocab_size * sizeof(float); + ggml_backend_tensor_get(t_logits, eagle3_draft_logits.data(), last_offset, draft_vocab_size * sizeof(float)); + + + // Map only the last token's draft logits to target vocab + float * last_logits_out = logits_out + last_idx * n_vocab; + std::fill(last_logits_out, last_logits_out + n_vocab, -std::numeric_limits::infinity()); + + for (int64_t j = 0; j < draft_vocab_size; j++) { + const int64_t target_id = j + eagle3_d2t_map[j]; + GGML_ASSERT(target_id >= 0 && target_id < n_vocab); + last_logits_out[target_id] = eagle3_draft_logits[j]; + } + } else { + ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, 
n_outputs*n_vocab*sizeof(float)); + } } } @@ -1455,7 +1534,16 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u auto * res = gf_res_reserve.get(); - const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); + // EAGLE3: auto-detect encoder (embeddings+no target_model) or decoder (has target_model) + llm_graph_type gtype = LLM_GRAPH_TYPE_DEFAULT; + if (model.arch == LLM_ARCH_EAGLE3) { + if (cparams.embeddings && model.target_tok_embd == nullptr) { + gtype = LLM_GRAPH_TYPE_ENCODER; + } else if (model.target_tok_embd != nullptr) { + gtype = LLM_GRAPH_TYPE_DECODER; + } + } + const auto gparams = graph_params(res, ubatch, mctx, gtype); res->reset(); @@ -1491,6 +1579,7 @@ llm_graph_params llama_context::graph_params( /*.loras =*/ &loras, /*.mctx =*/ mctx, /*.cross =*/ &cross, + /*.eagle3 =*/ &eagle3, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), /*.res =*/ res, @@ -1534,6 +1623,27 @@ llm_graph_cb llama_context::graph_get_cb() const { ggml_set_name(cur, name); } + // EAGLE3: Extract intermediate layer features if this is an extraction point + if (cparams.eagle3_extract_enabled) { + static constexpr const char * prefix = "eagle3_extract_"; + static constexpr size_t prefix_len = 15; // strlen("eagle3_extract_") + + if (strncmp(name, prefix, prefix_len) == 0) { + // Parse the extraction index from the name (e.g., "eagle3_extract_0" -> 0) + size_t extract_idx = 0; + if (sscanf(name + prefix_len, "%zu", &extract_idx) == 1 && extract_idx < eagle3.extract_tensors.size()) { + // Mark as output tensor to ensure proper backend assignment + ggml_set_output(cur); + // Store this tensor reference for post-execution extraction + eagle3.extract_tensors[extract_idx] = cur; + LLAMA_LOG_DEBUG("%s: EAGLE3 stored tensor reference for extraction: " + "index=%zu, layer=%d, target_layer=%d, tensor=%s\n", + __func__, extract_idx, il, + eagle3.extract_layer_indices[extract_idx], name); + } + } + } + if (!cparams.offload_kqv) { if 
(strcmp(name, "kqv_merged_cont") == 0) { // all nodes between the KV store and the attention output are run on the CPU @@ -1559,6 +1669,54 @@ llm_graph_cb llama_context::graph_get_cb() const { }; } +void llama_context::extract_eagle3_features(const llama_ubatch & ubatch) { + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_embd = model.hparams.n_embd; + const size_t n_layers = eagle3.extract_tensors.size(); + + // Allocate storage for concatenated features + const int64_t n_embd_concat = n_embd * n_layers; + eagle3.target_features.resize(n_embd_concat * n_tokens); + + // Temporary buffer to hold layer features before transposing + static thread_local std::vector temp_layer_features; + temp_layer_features.resize(n_embd * n_tokens); + + LLAMA_LOG_DEBUG("%s: Start to extract EAGLE3 features: %zu layers, %lld tokens, %lld embd\n", + __func__, n_layers, (long long)n_tokens, (long long)n_embd); + + // Extract each layer's features and interleave into token-major layout + for (size_t layer_idx = 0; layer_idx < n_layers; ++layer_idx) { + ggml_tensor * tensor = eagle3.extract_tensors[layer_idx]; + GGML_ASSERT(tensor != nullptr && "EAGLE3 extraction tensor is null"); + + // Get the backend where this tensor is stored + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), tensor); + GGML_ASSERT(backend != nullptr && "EAGLE3 tensor has no backend"); + + // Verify tensor shape: should be [n_embd, n_tokens] + GGML_ASSERT(tensor->ne[0] == n_embd && tensor->ne[1] == n_tokens && + "EAGLE3 extraction tensor has unexpected shape"); + + // Get layer features to temp buffer + const size_t size_bytes = n_embd * n_tokens * sizeof(float); + ggml_backend_tensor_get_async(backend, tensor, temp_layer_features.data(), 0, size_bytes); + ggml_backend_sched_synchronize(sched.get()); + + // Then copy to correct position in target_features + // target_features layout: [token_0_all_layers, token_1_all_layers, ...] 
+ // Each token has [layer_0_embd, layer_1_embd, layer_2_embd] + for (int64_t token_idx = 0; token_idx < n_tokens; ++token_idx) { + // Source: temp_layer_features[token_idx * n_embd ... (token_idx + 1) * n_embd - 1] + const float * src = temp_layer_features.data() + token_idx * n_embd; + // Dest: target_features[token_idx * n_embd_concat + layer_idx * n_embd] + float * dest = eagle3.target_features.data() + token_idx * n_embd_concat + layer_idx * n_embd; + std::memcpy(dest, src, n_embd * sizeof(float)); + } + } + +} + // // state save/load // @@ -2354,6 +2512,8 @@ llama_context_params llama_context_default_params() { /*.op_offload =*/ true, /*.swa_full =*/ true, /*.kv_unified =*/ false, + /*.eagle3_model =*/ nullptr, + /*.target_model =*/ nullptr, }; return result; @@ -2367,6 +2527,12 @@ llama_context * llama_init_from_model( return nullptr; } + // Auto-setup for EAGLE3: set target embedding if target_model is provided + if (model->arch == LLM_ARCH_EAGLE3 && params.target_model) { + model->target_tok_embd = params.target_model->tok_embd; + LLAMA_LOG_INFO("%s: EAGLE3 auto-setup: using target model's embedding layer\n", __func__); + } + if (params.n_batch == 0 && params.n_ubatch == 0) { LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__); return nullptr; @@ -3016,3 +3182,33 @@ void llama_opt_epoch( callback_train, callback_eval); } + +// +// EAGLE3 member functions +// + +const float * llama_context::get_eagle3_target_features() const { + GGML_ASSERT(!eagle3.target_features.empty() && "EAGLE3 target features not extracted - call llama_encode() on target model first"); + return eagle3.target_features.data(); +} + +void llama_context::set_eagle3_g_embeddings(const float * g_embd, int32_t n_embd, int32_t n_tokens) { + GGML_ASSERT(g_embd != nullptr && "g_embeddings cannot be null"); + GGML_ASSERT(n_embd > 0 && n_tokens > 0 && "invalid dimensions"); + + const size_t size = n_embd * n_tokens; + eagle3.g_embeddings.resize(size); + 
std::memcpy(eagle3.g_embeddings.data(), g_embd, size * sizeof(float)); +} + +// +// C API wrappers +// + +const float * llama_get_eagle3_target_features(llama_context * ctx) { + return ctx->get_eagle3_target_features(); +} + +void llama_set_eagle3_g_embeddings(llama_context * ctx, const float * g_embd, int32_t n_embd, int32_t n_tokens) { + ctx->set_eagle3_g_embeddings(g_embd, n_embd, n_tokens); +} diff --git a/src/llama-context.h b/src/llama-context.h index cd26eafe1894..1528d3f03e77 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -208,6 +208,12 @@ struct llama_context { // reserve a graph with a dummy ubatch of the specified size ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false); + // EAGLE3: Get pointer to target model features extracted for EAGLE3 encoder + const float * get_eagle3_target_features() const; + + // EAGLE3: Set g_embeddings from encoder output for decoder input + void set_eagle3_g_embeddings(const float * g_embd, int32_t n_embd, int32_t n_tokens); + private: llm_graph_params graph_params( llm_graph_result * res, @@ -217,6 +223,9 @@ struct llama_context { llm_graph_cb graph_get_cb() const; + // EAGLE3: Extract intermediate layer features from target model + void extract_eagle3_features(const llama_ubatch & ubatch); + // TODO: read/write lora adapters and cvec size_t state_write_data(llama_io_write_i & io); size_t state_read_data (llama_io_read_i & io); @@ -235,6 +244,9 @@ struct llama_context { llama_adapter_loras loras; llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably + + mutable llama_eagle3 eagle3; // EAGLE3 draft model support - stores features from target model + // mutable because it's modified during graph building (const function) std::unique_ptr memory; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index fcef8fa97603..456c06e9a91c 100644 --- a/src/llama-cparams.h +++ 
b/src/llama-cparams.h @@ -34,6 +34,7 @@ struct llama_cparams { bool warmup; bool op_offload; bool kv_unified; + bool eagle3_extract_enabled; // enable layer extraction for EAGLE3 speculative decoding enum llama_pooling_type pooling_type; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 8909bbfb95e5..2b21a1d6590c 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -590,6 +590,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : loras (params.loras), mctx (params.mctx), cross (params.cross), + eagle3 (params.eagle3), cb_func (params.cb), res (params.res), ctx0 (res->get_ctx()), diff --git a/src/llama-graph.h b/src/llama-graph.h index e9d387bd7c5d..f93a8584400b 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -70,6 +70,30 @@ struct llama_cross { std::vector> seq_ids_enc; }; +// EAGLE3 support - stores intermediate features from target model +struct llama_eagle3 { + // Configuration: which layers to extract from target model + std::vector extract_layer_indices; + + // Extracted features from target model (for encoder input) + // Concatenated [layer_l, layer_m, layer_h] embeddings + // Shape: [n_layers * n_embd, n_tokens] where n_layers = extract_layer_indices.size() + std::vector target_features; + + // Encoder output (for decoder input) + std::vector g_embeddings; + + // Tensor references for feature extraction from target model + std::vector extract_tensors; + + // Clear all stored data + void clear() { + target_features.clear(); + g_embeddings.clear(); + extract_tensors.clear(); + } +}; + struct llm_graph_params; // @@ -416,6 +440,7 @@ struct llm_graph_params { const llama_adapter_loras * loras; const llama_memory_context_i * mctx; const llama_cross * cross; + llama_eagle3 * eagle3; // non-const: we write extracted features here uint32_t n_outputs; @@ -579,6 +604,7 @@ struct llm_graph_context { const llama_adapter_loras * loras; const llama_memory_context_i * mctx; const llama_cross * cross; + llama_eagle3 * 
eagle3; // non-const: we write extracted features here const llm_graph_cb & cb_func; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index a467c64a14e0..d4337aea376a 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -188,6 +188,13 @@ struct llama_hparams { // qwen3vl deepstack uint32_t n_deepstack_layers = 0; + // EAGLE3 draft model - layer indices to extract from target model + // e.g., for 32-layer target: [2, 16, 29] (low, middle, high) + std::array eagle3_extract_layers = {0, 0, 0}; + + // EAGLE3 draft model - target model hidden size + uint32_t eagle3_target_hidden_size = 0; + // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 28f06b4e6154..acbdb5d9612d 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2230,6 +2230,28 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_EAGLE3: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // EAGLE3 layer extraction configuration + // Use array (has template instantiation), then copy first 3 elements + std::array extract_layers_tmp = {}; + if (!ml.get_key_or_arr(LLM_KV_EAGLE3_EXTRACT_LAYERS, extract_layers_tmp, 3, false)) { + throw std::runtime_error("EAGLE3 model requires 'extract_layers' in GGUF metadata"); + } + std::copy_n(extract_layers_tmp.begin(), 3, hparams.eagle3_extract_layers.begin()); + LLAMA_LOG_INFO("%s: EAGLE3 extract_layers = [%d, %d, %d]\n", __func__, + hparams.eagle3_extract_layers[0], + hparams.eagle3_extract_layers[1], + hparams.eagle3_extract_layers[2]); + + // EAGLE3 target model hidden size + ml.get_key(LLM_KV_EAGLE3_TARGET_HIDDEN_SIZE, hparams.eagle3_target_hidden_size); + LLAMA_LOG_INFO("%s: EAGLE3 target_hidden_size = %u (draft n_embd = %u)\n", __func__, + hparams.eagle3_target_hidden_size, 
hparams.n_embd); + + type = LLM_TYPE_UNKNOWN; + } break; case LLM_ARCH_COGVLM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -6408,6 +6430,62 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); } } break; + case LLM_ARCH_EAGLE3: + { + const int64_t n_embd_target_features = 3 * hparams.eagle3_target_hidden_size; + const int64_t n_embd_attn_input = 2 * n_embd; + + // Get vocab size from the d2t tensor in the GGUF file + // d2t: draft to target mapping (size = draft_vocab_size) + const struct ggml_tensor * d2t_meta = ml.get_tensor_meta("d2t"); + if (!d2t_meta) { + throw std::runtime_error("EAGLE3 model requires 'd2t' tensor but it was not found in the model file"); + } + const int64_t n_draft_vocab = d2t_meta->ne[0]; + + // Feature fusion layer: projects 3 target layers to draft hidden size + fc = create_tensor(tn(LLM_TENSOR_EAGLE3_FC, "weight"), {n_embd_target_features, n_embd}, 0); + + // Draft to target vocabulary mapping tensor + d2t = create_tensor(tn(LLM_TENSOR_EAGLE3_D2T), {n_draft_vocab}, 0); + + // Output layer (uses draft vocab size) + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_draft_vocab}, 0); + + // Token embeddings (optional - Llama 3.3 70B EAGLE3 has its own) + const struct ggml_tensor * tok_embd_meta = ml.get_tensor_meta(tn(LLM_TENSOR_TOKEN_EMBD, "weight").str().c_str()); + if (tok_embd_meta) { + const int64_t n_target_vocab = tok_embd_meta->ne[1]; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_target_vocab}, 0); + LLAMA_LOG_INFO("%s: EAGLE3 using its own token_embd (vocab = %lld)\n", __func__, (long long)n_target_vocab); + } + + // Single decoder layer + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // input_layernorm: applied to token embeddings + layer.attn_norm 
= create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // Attention takes input_embeds_normed + fused_target_normed as input + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd_attn_input, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd_attn_input, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd_attn_input, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // EAGLE-3 specific: hidden_norm applied to fused target features + layer.eagle3_hidden_norm = create_tensor(tn(LLM_TENSOR_EAGLE3_HIDDEN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // rope_freqs for llama3 rope scaling (optional - only if EAGLE3 config has rope_scaling) + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED); + } + } break; case LLM_ARCH_COGVLM: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -7564,6 +7642,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_EAGLE3: + { + if (params.gtype == LLM_GRAPH_TYPE_ENCODER) { + llm = std::make_unique(*this, params); + } else { + llm = std::make_unique(*this, params); + } + } break; case LLM_ARCH_COGVLM: { llm = std::make_unique(*this, params); @@ -7749,6 +7835,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: case LLM_ARCH_MISTRAL3: + case 
LLM_ARCH_EAGLE3: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 diff --git a/src/llama-model.h b/src/llama-model.h index f8342cf2cb13..2d72fc78f5c9 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -404,6 +404,9 @@ struct llama_layer { struct ggml_tensor * ffn_act_beta = nullptr; struct ggml_tensor * ffn_act_eps = nullptr; + // eagle3 + struct ggml_tensor * eagle3_hidden_norm = nullptr; + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; @@ -453,6 +456,13 @@ struct llama_model { struct ggml_tensor * per_layer_model_proj = nullptr; struct ggml_tensor * per_layer_proj_norm = nullptr; + // eagle3 + struct ggml_tensor * fc = nullptr; // feature fusion layer + struct ggml_tensor * d2t = nullptr; // draft to target vocabulary mapping + // Reference to target model's embedding layer + // This allows EAGLE3 to use target model's embeddings without copying + struct ggml_tensor * target_tok_embd = nullptr; + std::vector layers; //Dense linear projections for SentenceTransformers models like embeddinggemma diff --git a/src/models/eagle3.cpp b/src/models/eagle3.cpp new file mode 100644 index 000000000000..8987a0c5816a --- /dev/null +++ b/src/models/eagle3.cpp @@ -0,0 +1,187 @@ +#include "models.h" + +// EAGLE3 Encoder: processes target model features through feature fusion layer +// Input: target_features e.g. [12288, n_tokens] from target model layers low, middle, high +// Output: g_embeddings e.g. 
[4096, n_tokens] stored in context +llm_build_eagle3_encode::llm_build_eagle3_encode(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + + const int64_t n_embd_target_features = 3 * hparams.eagle3_target_hidden_size; + + ggml_tensor * cur; + + // Input: Target model features (3 layers concatenated: low, mid, high) + // Data will be provided via ubatch->embd in encode_eagle3_features() + auto inp_target = std::make_unique(); + inp_target->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_target_features, n_tokens); + ggml_set_input(inp_target->embd); + ggml_tensor * target_features = inp_target->embd; + res->add_input(std::move(inp_target)); + cb(target_features, "inp_target_features", -1); + + // Feature fusion layer + ggml_tensor * fused_target = build_lora_mm(model.fc, target_features); + cb(fused_target, "fc_out", -1); + + // Output: g_embeddings e.g. [4096, n_tokens] + cur = fused_target; + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); +} + +// EAGLE3 Decoder: processes draft tokens using g_embeddings from encoder +// Input: draft tokens + g_embeddings from encoder +// Output: draft logits +llm_build_eagle3_decode::llm_build_eagle3_decode(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_layer == 1); // EAGLE-3 has only one decoder layer + + ggml_tensor * cur; + ggml_tensor * inpL; + + // EAGLE3 Decoder receives: + // 1. Token embeddings (e.g.from EAGLE3's own tok_embd for Llama 3.3 70B, or target model for Llama 3.1 8B) + // 2. g_embeddings from encoder + // Choose token_embd_eagle3: prefer EAGLE3's own if available (Llama 3.3 70B), else use target's (Llama 3.1 8B) + ggml_tensor * token_embd_eagle3 = (model.tok_embd != nullptr) ? 
model.tok_embd : model.target_tok_embd; + GGML_ASSERT(token_embd_eagle3 != nullptr && "EAGLE3 decoder requires token embeddings (own or from target model)"); + ggml_tensor * input_embeds = build_inp_embd(token_embd_eagle3); + cb(input_embeds, "token_embd_eagle3", -1); + ggml_tensor * g_embeddings = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + ggml_set_input(g_embeddings); + ggml_set_name(g_embeddings, "inp_g_embeddings"); + cb(g_embeddings, "inp_g_embeddings", -1); + + // Store raw g_embeddings as residual + ggml_tensor * residual = g_embeddings; + + // Apply input_layernorm to the token embeddings + ggml_tensor * input_embeds_normed = build_norm(input_embeds, + model.layers[0].attn_norm, NULL, + LLM_NORM_RMS, 0); + cb(input_embeds_normed, "input_layernorm", -1); + + // Force a sync point between the two parallel RMS_NORM paths + // This prevents buffer reuse issues on GPU (EAGLE3 GPU fix) + ggml_set_sync(input_embeds_normed); + + // Apply hidden_norm to g_embeddings + ggml_tensor * g_embeddings_normed = build_norm(g_embeddings, + model.layers[0].eagle3_hidden_norm, NULL, + LLM_NORM_RMS, -1); + cb(g_embeddings_normed, "g_embeddings_normed", -1); + + // Concatenate normalized input_embeds and normalized g_embeddings + cur = ggml_concat(ctx0, input_embeds_normed, g_embeddings_normed, 0); + cb(cur, "concat_embeds_g", -1); + + inpL = cur; + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + const float kq_scale = 1.0f/sqrtf(float(n_embd_head)); + + // Single decoder layer (il = 0) + const int il = 0; + { + // inpL is the concatenated input (normalized input_embeds + normalized g_embeddings) + ggml_tensor * inpSA = inpL; + + // Self-attention with concatenated input + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, inpL); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, inpL); + 
cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, inpL); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // rope freq factors, returns nullptr if not available + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // RoPE + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur_rope", il); + cb(Kcur, "Kcur_rope", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + + if (inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + residual = ggml_get_rows(ctx0, residual, inp_out_ids); + } + + // Add residual and update it + ggml_tensor * attn_with_residual = ggml_add(ctx0, cur, residual); + cb(attn_with_residual, "attn_with_residual", il); + + // Update residual + residual = attn_with_residual; + + // Apply FFN norm to the sum + ggml_tensor * ffn_inp = build_norm(attn_with_residual, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(ffn_inp, "post_attn_norm", il); + + cur = ffn_inp; + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + inpL = cur; + } + + cur = inpL; + + // Output norm with residual + ggml_tensor * final_with_residual = ggml_add(ctx0, cur, residual); + cb(final_with_residual, "eagle3_prenorm", -1); + + // Output 
prenorm state (for next token's g_embeddings in autoregressive generation) + ggml_set_output(final_with_residual); + res->t_embd = final_with_residual; + + cur = build_norm(final_with_residual, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + // lm_head - projects to draft vocabulary + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} \ No newline at end of file diff --git a/src/models/llama.cpp b/src/models/llama.cpp index ab7fd5d05086..e695ae2c6334 100644 --- a/src/models/llama.cpp +++ b/src/models/llama.cpp @@ -23,6 +23,16 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_para for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; + // EAGLE3: Extract intermediate layer features from target model at layer INPUT + if (eagle3 && cparams.eagle3_extract_enabled && !eagle3->extract_layer_indices.empty()) { + static const char * eagle3_extract_names[] = {"eagle3_extract_0", "eagle3_extract_1", "eagle3_extract_2"}; + for (size_t i = 0; i < eagle3->extract_layer_indices.size() && i < 3; ++i) { + if (eagle3->extract_layer_indices[i] == il) { + cb(inpL, eagle3_extract_names[i], il); + break; + } + } + } // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, diff --git a/src/models/models.h b/src/models/models.h index 6494f5450181..419b88002bc4 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -150,6 +150,14 @@ struct llm_build_dream : public llm_graph_context { llm_build_dream(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_eagle3_encode : public llm_graph_context { + llm_build_eagle3_encode(const llama_model & model, const llm_graph_params & params); +}; + +struct llm_build_eagle3_decode : public llm_graph_context { + llm_build_eagle3_decode(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_ernie4_5 : public 
llm_graph_context { llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params); }; From ac5667dcc6ea7d820c468e83a6e52bf646e63f71 Mon Sep 17 00:00:00 2001 From: ruixiangw Date: Tue, 16 Dec 2025 16:53:28 +0000 Subject: [PATCH 02/21] fix eagle3 logits sync bug & remove ggml_set_sync() --- ggml/include/ggml.h | 2 -- ggml/src/ggml-backend.cpp | 14 -------------- ggml/src/ggml.c | 4 ---- src/llama-context.cpp | 3 ++- src/models/eagle3.cpp | 4 ---- 5 files changed, 2 insertions(+), 25 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index fa73e8216b84..686da3dbd107 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -629,7 +629,6 @@ extern "C" { GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) - GGML_TENSOR_FLAG_SYNC = 16, // ...forces a new split/sync point in the scheduler (e.g. for EAGLE3 decoder) }; enum ggml_tri_type { @@ -854,7 +853,6 @@ extern "C" { GGML_API void ggml_set_output(struct ggml_tensor * tensor); GGML_API void ggml_set_param(struct ggml_tensor * tensor); GGML_API void ggml_set_loss(struct ggml_tensor * tensor); - GGML_API void ggml_set_sync(struct ggml_tensor * tensor); // force sync point in scheduler // // operations on tensors with backpropagation diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 8e30d48ccc05..08681f35e3f9 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1202,11 +1202,6 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra } } - // check if this node requires a sync point (e.g. 
for EAGLE3 parallel path fix) - if (node->flags & GGML_TENSOR_FLAG_SYNC) { - need_new_split = true; - } - if (node_backend_id != cur_backend_id || need_new_split) { split->i_end = i; i_split++; @@ -1581,15 +1576,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s if (ec != GGML_STATUS_SUCCESS) { return ec; } - - // If any node in this split has SYNC flag, synchronize after compute - // This ensures the sync node is complete before next split (e.g. for EAGLE3 parallel path sync fix) - for (int j = 0; j < split->graph.n_nodes; j++) { - if (split->graph.nodes[j]->flags & GGML_TENSOR_FLAG_SYNC) { - ggml_backend_synchronize(split_backend); - break; - } - } } else { // similar to ggml_backend_compare_graph_backend for (int j0 = 0; j0 < split->graph.n_nodes; j0++) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 4625c3bd7707..f0913cd35967 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7451,10 +7451,6 @@ void ggml_set_loss(struct ggml_tensor * tensor) { tensor->flags |= GGML_TENSOR_FLAG_LOSS; } -void ggml_set_sync(struct ggml_tensor * tensor) { - tensor->flags |= GGML_TENSOR_FLAG_SYNC; -} - //////////////////////////////////////////////////////////////////////////////// void ggml_quantize_init(enum ggml_type type) { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ea6dfaea3c94..3506edd92bca 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1261,7 +1261,8 @@ int llama_context::decode(const llama_batch & batch_inp) { // Read only the last token's draft logits eagle3_draft_logits.resize(draft_vocab_size); const size_t last_offset = last_idx * draft_vocab_size * sizeof(float); - ggml_backend_tensor_get(t_logits, eagle3_draft_logits.data(), last_offset, draft_vocab_size * sizeof(float)); + ggml_backend_tensor_get_async(backend_res, t_logits, eagle3_draft_logits.data(), last_offset, draft_vocab_size * sizeof(float)); + synchronize(); // Map only the last token's draft logits to target vocab diff 
--git a/src/models/eagle3.cpp b/src/models/eagle3.cpp index 8987a0c5816a..dea887bdd396 100644 --- a/src/models/eagle3.cpp +++ b/src/models/eagle3.cpp @@ -63,10 +63,6 @@ llm_build_eagle3_decode::llm_build_eagle3_decode(const llama_model & model, cons LLM_NORM_RMS, 0); cb(input_embeds_normed, "input_layernorm", -1); - // Force a sync point between the two parallel RMS_NORM paths - // This prevents buffer reuse issues on GPU (EAGLE3 GPU fix) - ggml_set_sync(input_embeds_normed); - // Apply hidden_norm to g_embeddings ggml_tensor * g_embeddings_normed = build_norm(g_embeddings, model.layers[0].eagle3_hidden_norm, NULL, From 5a79c1900f9ed31be400b827424890b774be5dfb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 17 Dec 2025 15:49:03 +0200 Subject: [PATCH 03/21] eagle3 : improve naming --- common/speculative.cpp | 2 +- include/llama.h | 4 +- src/llama-context.cpp | 42 ++++----- src/llama-graph.h | 8 +- src/models/eagle3.cpp | 208 ++++++++++++++++++++--------------------- src/models/models.h | 2 + 6 files changed, 134 insertions(+), 132 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 058e75b79615..4f97d464ddb5 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -307,7 +307,7 @@ static llama_tokens gen_eagle3_draft( /*.n_tokens =*/ n_new, /*.token =*/ nullptr, /*.embd =*/ const_cast(features), - /*.pos =*/ nullptr, + /*.pos =*/ nullptr, /*.n_seq_id =*/ nullptr, /*.seq_id =*/ nullptr, /*.logits =*/ nullptr, diff --git a/include/llama.h b/include/llama.h index caded2c83e43..7d26eebcbd8d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -364,7 +364,7 @@ extern "C" { bool kv_unified; // use a unified buffer across the input sequences when computing the attention // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix // ref: https://github.com/ggml-org/llama.cpp/pull/14363 - + // EAGLE3 extraction configuration // When eagle3_model is set, layer extraction is 
automatically enabled const struct llama_model * eagle3_model; // EAGLE3 model to read extract_layers configuration from @@ -876,7 +876,7 @@ extern "C" { // Returns NULL if no features are available // Format: [3*n_embd, n_tokens] - use model.hparams.n_embd and batch.n_tokens for dimensions LLAMA_API const float * llama_get_eagle3_target_features(struct llama_context * ctx); - + // Set g_embeddings from EAGLE3 encoder output for decoder input // g_embd: pointer to encoder output embeddings LLAMA_API void llama_set_eagle3_g_embeddings( diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a921112df2d2..b4f6bb5b997d 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -353,7 +353,7 @@ llama_context::llama_context( // Allocate tensors array for extraction eagle3.extract_tensors.resize(eagle3.extract_layer_indices.size(), nullptr); - + LLAMA_LOG_INFO("%s: EAGLE3 extraction enabled for layers [%d, %d, %d]\n", __func__, eagle3.extract_layer_indices[0], eagle3.extract_layer_indices[1], @@ -879,7 +879,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll //const auto t_start_us = ggml_time_us(); res->set_inputs(&ubatch); - + // EAGLE3: Fill g_embeddings for decoder input if (model.arch == LLM_ARCH_EAGLE3 && gtype == LLM_GRAPH_TYPE_DECODER && !eagle3.g_embeddings.empty()) { ggml_tensor * g_embd = ggml_graph_get_tensor(gf, "inp_g_embeddings"); @@ -1265,32 +1265,32 @@ int llama_context::decode(const llama_batch & batch_inp) { if (n_outputs) { GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); - + // EAGLE3: Map draft vocab to target vocab if (model.arch == LLM_ARCH_EAGLE3 && model.d2t) { static thread_local std::vector eagle3_d2t_map; static thread_local std::vector eagle3_draft_logits; - + const int64_t draft_vocab_size = t_logits->ne[0]; const uint32_t last_idx = n_outputs - 1; - + // Load d2t mapping once (on first call) if 
(eagle3_d2t_map.empty()) { eagle3_d2t_map.resize(model.d2t->ne[0]); ggml_backend_tensor_get(model.d2t, eagle3_d2t_map.data(), 0, eagle3_d2t_map.size() * sizeof(int64_t)); } - + // Read only the last token's draft logits eagle3_draft_logits.resize(draft_vocab_size); const size_t last_offset = last_idx * draft_vocab_size * sizeof(float); ggml_backend_tensor_get_async(backend_res, t_logits, eagle3_draft_logits.data(), last_offset, draft_vocab_size * sizeof(float)); synchronize(); - - + + // Map only the last token's draft logits to target vocab float * last_logits_out = logits_out + last_idx * n_vocab; std::fill(last_logits_out, last_logits_out + n_vocab, -std::numeric_limits::infinity()); - + for (int64_t j = 0; j < draft_vocab_size; j++) { const int64_t target_id = j + eagle3_d2t_map[j]; GGML_ASSERT(target_id >= 0 && target_id < n_vocab); @@ -1656,7 +1656,7 @@ llm_graph_cb llama_context::graph_get_cb() const { if (cparams.eagle3_extract_enabled) { static constexpr const char * prefix = "eagle3_extract_"; static constexpr size_t prefix_len = 15; // strlen("eagle3_extract_") - + if (strncmp(name, prefix, prefix_len) == 0) { // Parse the extraction index from the name (e.g., "eagle3_extract_0" -> 0) size_t extract_idx = 0; @@ -1667,7 +1667,7 @@ llm_graph_cb llama_context::graph_get_cb() const { eagle3.extract_tensors[extract_idx] = cur; LLAMA_LOG_DEBUG("%s: EAGLE3 stored tensor reference for extraction: " "index=%zu, layer=%d, target_layer=%d, tensor=%s\n", - __func__, extract_idx, il, + __func__, extract_idx, il, eagle3.extract_layer_indices[extract_idx], name); } } @@ -1702,36 +1702,36 @@ void llama_context::extract_eagle3_features(const llama_ubatch & ubatch) { const int64_t n_tokens = ubatch.n_tokens; const int64_t n_embd = model.hparams.n_embd; const size_t n_layers = eagle3.extract_tensors.size(); - + // Allocate storage for concatenated features const int64_t n_embd_concat = n_embd * n_layers; eagle3.target_features.resize(n_embd_concat * n_tokens); - + // 
Temporary buffer to hold layer features before transposing static thread_local std::vector temp_layer_features; temp_layer_features.resize(n_embd * n_tokens); - + LLAMA_LOG_DEBUG("%s: Start to extract EAGLE3 features: %zu layers, %lld tokens, %lld embd\n", __func__, n_layers, (long long)n_tokens, (long long)n_embd); - + // Extract each layer's features and interleave into token-major layout for (size_t layer_idx = 0; layer_idx < n_layers; ++layer_idx) { ggml_tensor * tensor = eagle3.extract_tensors[layer_idx]; GGML_ASSERT(tensor != nullptr && "EAGLE3 extraction tensor is null"); - + // Get the backend where this tensor is stored ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), tensor); GGML_ASSERT(backend != nullptr && "EAGLE3 tensor has no backend"); - + // Verify tensor shape: should be [n_embd, n_tokens] GGML_ASSERT(tensor->ne[0] == n_embd && tensor->ne[1] == n_tokens && "EAGLE3 extraction tensor has unexpected shape"); - + // Get layer features to temp buffer const size_t size_bytes = n_embd * n_tokens * sizeof(float); ggml_backend_tensor_get_async(backend, tensor, temp_layer_features.data(), 0, size_bytes); ggml_backend_sched_synchronize(sched.get()); - + // Then copy to correct position in target_features // target_features layout: [token_0_all_layers, token_1_all_layers, ...] 
// Each token has [layer_0_embd, layer_1_embd, layer_2_embd] @@ -1743,7 +1743,7 @@ void llama_context::extract_eagle3_features(const llama_ubatch & ubatch) { std::memcpy(dest, src, n_embd * sizeof(float)); } } - + } // @@ -3235,7 +3235,7 @@ const float * llama_context::get_eagle3_target_features() const { void llama_context::set_eagle3_g_embeddings(const float * g_embd, int32_t n_embd, int32_t n_tokens) { GGML_ASSERT(g_embd != nullptr && "g_embeddings cannot be null"); GGML_ASSERT(n_embd > 0 && n_tokens > 0 && "invalid dimensions"); - + const size_t size = n_embd * n_tokens; eagle3.g_embeddings.resize(size); std::memcpy(eagle3.g_embeddings.data(), g_embd, size * sizeof(float)); diff --git a/src/llama-graph.h b/src/llama-graph.h index 617ea154c345..69df6b1f4e3d 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -74,18 +74,18 @@ struct llama_cross { struct llama_eagle3 { // Configuration: which layers to extract from target model std::vector extract_layer_indices; - + // Extracted features from target model (for encoder input) // Concatenated [layer_l, layer_m, layer_h] embeddings // Shape: [n_layers * n_embd, n_tokens] where n_layers = extract_layer_indices.size() std::vector target_features; - + // Encoder output (for decoder input) std::vector g_embeddings; - + // Tensor references for feature extraction from target model std::vector extract_tensors; - + // Clear all stored data void clear() { target_features.clear(); diff --git a/src/models/eagle3.cpp b/src/models/eagle3.cpp index dea887bdd396..629d89d32707 100644 --- a/src/models/eagle3.cpp +++ b/src/models/eagle3.cpp @@ -1,103 +1,109 @@ #include "models.h" +ggml_tensor * llm_build_eagle3_encode::build_inp_embd() const { + const int64_t n_embd_target_features = 3 * hparams.eagle3_target_hidden_size; + + ggml_tensor * cur = nullptr; + + // Input: Target model features (3 layers concatenated: low, mid, high) + // Data will be provided via ubatch->embd in encode_eagle3_features() + auto inp_target = 
std::make_unique(); + inp_target->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_target_features, n_tokens); + ggml_set_input(inp_target->embd); + + cur = inp_target->embd; + cb(cur, "inp_embd", -1); + + res->add_input(std::move(inp_target)); + + return cur; +} + // EAGLE3 Encoder: processes target model features through feature fusion layer // Input: target_features e.g. [12288, n_tokens] from target model layers low, middle, high // Output: g_embeddings e.g. [4096, n_tokens] stored in context llm_build_eagle3_encode::llm_build_eagle3_encode(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + ggml_tensor * cur = nullptr; - const int64_t n_embd_target_features = 3 * hparams.eagle3_target_hidden_size; - - ggml_tensor * cur; + cur = build_inp_embd(); - // Input: Target model features (3 layers concatenated: low, mid, high) - // Data will be provided via ubatch->embd in encode_eagle3_features() - auto inp_target = std::make_unique(); - inp_target->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_target_features, n_tokens); - ggml_set_input(inp_target->embd); - ggml_tensor * target_features = inp_target->embd; - res->add_input(std::move(inp_target)); - cb(target_features, "inp_target_features", -1); + // Feature fusion layer + cur = build_lora_mm(model.fc, cur); + cb(cur, "fc_out", -1); - // Feature fusion layer - ggml_tensor * fused_target = build_lora_mm(model.fc, target_features); - cb(fused_target, "fc_out", -1); + // Output: g_embeddings e.g. [4096, n_tokens] + res->t_embd = cur; - // Output: g_embeddings e.g. 
[4096, n_tokens] - cur = fused_target; - res->t_embd = cur; - - ggml_build_forward_expand(gf, cur); + ggml_build_forward_expand(gf, cur); } // EAGLE3 Decoder: processes draft tokens using g_embeddings from encoder // Input: draft tokens + g_embeddings from encoder // Output: draft logits llm_build_eagle3_decode::llm_build_eagle3_decode(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_layer == 1); // EAGLE-3 has only one decoder layer - - ggml_tensor * cur; - ggml_tensor * inpL; - - // EAGLE3 Decoder receives: - // 1. Token embeddings (e.g.from EAGLE3's own tok_embd for Llama 3.3 70B, or target model for Llama 3.1 8B) - // 2. g_embeddings from encoder - // Choose token_embd_eagle3: prefer EAGLE3's own if available (Llama 3.3 70B), else use target's (Llama 3.1 8B) - ggml_tensor * token_embd_eagle3 = (model.tok_embd != nullptr) ? 
model.tok_embd : model.target_tok_embd; - GGML_ASSERT(token_embd_eagle3 != nullptr && "EAGLE3 decoder requires token embeddings (own or from target model)"); - ggml_tensor * input_embeds = build_inp_embd(token_embd_eagle3); - cb(input_embeds, "token_embd_eagle3", -1); - ggml_tensor * g_embeddings = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); - ggml_set_input(g_embeddings); - ggml_set_name(g_embeddings, "inp_g_embeddings"); - cb(g_embeddings, "inp_g_embeddings", -1); - - // Store raw g_embeddings as residual - ggml_tensor * residual = g_embeddings; + const int64_t n_embd_head = hparams.n_embd_head_v; - // Apply input_layernorm to the token embeddings - ggml_tensor * input_embeds_normed = build_norm(input_embeds, - model.layers[0].attn_norm, NULL, - LLM_NORM_RMS, 0); - cb(input_embeds_normed, "input_layernorm", -1); - - // Apply hidden_norm to g_embeddings - ggml_tensor * g_embeddings_normed = build_norm(g_embeddings, - model.layers[0].eagle3_hidden_norm, NULL, - LLM_NORM_RMS, -1); - cb(g_embeddings_normed, "g_embeddings_normed", -1); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_layer == 1); // EAGLE-3 has only one decoder layer - // Concatenate normalized input_embeds and normalized g_embeddings - cur = ggml_concat(ctx0, input_embeds_normed, g_embeddings_normed, 0); - cb(cur, "concat_embeds_g", -1); - - inpL = cur; + ggml_tensor * cur; + ggml_tensor * inpL; - // inp_pos - contains the positions - ggml_tensor * inp_pos = build_inp_pos(); + // EAGLE3 Decoder receives: + // 1. Token embeddings (e.g.from EAGLE3's own tok_embd for Llama 3.3 70B, or target model for Llama 3.1 8B) + // 2. g_embeddings from encoder + // Choose token_embd_eagle3: prefer EAGLE3's own if available (Llama 3.3 70B), else use target's (Llama 3.1 8B) + ggml_tensor * token_embd_eagle3 = (model.tok_embd != nullptr) ? 
model.tok_embd : model.target_tok_embd; + GGML_ASSERT(token_embd_eagle3 != nullptr && "EAGLE3 decoder requires token embeddings (own or from target model)"); + ggml_tensor * inp_embd = build_inp_embd(token_embd_eagle3); + cb(inp_embd, "inp_embd", -1); - auto * inp_attn = build_attn_inp_kv(); + // TODO: refactor into llm_graph_input + ggml_tensor * inp_g = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + ggml_set_input(inp_g); + cb(inp_g, "inp_g_embeddings", -1); // TODO: do not change the name! refactor into llm_graph_input - ggml_tensor * inp_out_ids = build_inp_out_ids(); + inpL = inp_g; - const float kq_scale = 1.0f/sqrtf(float(n_embd_head)); + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); - // Single decoder layer (il = 0) - const int il = 0; - { - // inpL is the concatenated input (normalized input_embeds + normalized g_embeddings) + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = 1.0f/sqrtf(float(n_embd_head)); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + // Single decoder layer (il = 0) + const int il = 0; + { + // inpL is the concatenated input (normalized inp_embd + normalized inp_g) ggml_tensor * inpSA = inpL; + // Apply input_layernorm to the token embeddings + ggml_tensor * embd_norm = build_norm(inp_embd, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(embd_norm, "embd_norm", il); + + // Apply hidden_norm to inp_g + ggml_tensor * g_norm = build_norm(inp_g, + model.layers[il].eagle3_hidden_norm, NULL, + LLM_NORM_RMS, -1); + cb(g_norm, "g_norm", il); + + // Concatenate normalized inp_embd and normalized inp_g + cur = ggml_concat(ctx0, embd_norm, g_norm, il); + cb(cur, "concat_embd", il); + // Self-attention with concatenated input - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, inpL); + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, inpL); + ggml_tensor * Kcur = 
build_lora_mm(model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, inpL); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); cb(Vcur, "Vcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); @@ -127,25 +133,19 @@ llm_build_eagle3_decode::llm_build_eagle3_decode(const llama_model & model, cons Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); if (inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - residual = ggml_get_rows(ctx0, residual, inp_out_ids); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } // Add residual and update it - ggml_tensor * attn_with_residual = ggml_add(ctx0, cur, residual); - cb(attn_with_residual, "attn_with_residual", il); - - // Update residual - residual = attn_with_residual; - + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + // Apply FFN norm to the sum - ggml_tensor * ffn_inp = build_norm(attn_with_residual, + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); - cb(ffn_inp, "post_attn_norm", il); - - cur = ffn_inp; + cb(cur, "post_attn_norm", il); cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, @@ -154,30 +154,30 @@ llm_build_eagle3_decode::llm_build_eagle3_decode(const llama_model & model, cons NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); - + + // Output norm with residual + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "eagle3_prenorm", il); + inpL = cur; - } + } - cur = inpL; + cur = inpL; - // Output norm with residual - ggml_tensor * final_with_residual = ggml_add(ctx0, cur, residual); - cb(final_with_residual, "eagle3_prenorm", -1); - - // Output prenorm state (for next token's g_embeddings in autoregressive generation) - ggml_set_output(final_with_residual); - res->t_embd = final_with_residual; - - cur = 
build_norm(final_with_residual, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - cb(cur, "result_norm", -1); + // Output prenorm state (for next token's g_embeddings in autoregressive generation) + ggml_set_output(cur); + res->t_embd = cur; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); - // lm_head - projects to draft vocabulary - cur = build_lora_mm(model.output, cur); + // lm_head - projects to draft vocabulary + cur = build_lora_mm(model.output, cur); - cb(cur, "result_output", -1); - res->t_logits = cur; + cb(cur, "result_output", -1); + res->t_logits = cur; - ggml_build_forward_expand(gf, cur); -} \ No newline at end of file + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/models.h b/src/models/models.h index 653c962d1916..a6d1a2fccf2c 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -152,6 +152,8 @@ struct llm_build_dream : public llm_graph_context { struct llm_build_eagle3_encode : public llm_graph_context { llm_build_eagle3_encode(const llama_model & model, const llm_graph_params & params); +private: + ggml_tensor * build_inp_embd() const; }; struct llm_build_eagle3_decode : public llm_graph_context { From c0d99e65d2d27f44df7f16e98dc7f28b6fe832cb Mon Sep 17 00:00:00 2001 From: ruixiangw Date: Thu, 8 Jan 2026 23:49:06 +0000 Subject: [PATCH 04/21] add eagle3 support for Qwen3 series models --- convert_hf_to_gguf.py | 9 +++++---- src/models/qwen3.cpp | 11 +++++++++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a9e17ee1faf2..7ef9ffb27b01 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2395,6 +2395,7 @@ def prepare_tensors(self): "VLlama3ForCausalLM", "LlavaForConditionalGeneration", "VoxtralForConditionalGeneration", + "LlamaForCausalLMEagle3", "LlamaModel") class LlamaModel(TextModel): model_arch = gguf.MODEL_ARCH.LLAMA @@ -2477,10 +2478,6 @@ def set_vocab(self): # Llama 3 
self._set_vocab_gpt2() - # Restore original dir_model for EAGLE-3 - if hasattr(self, 'is_eagle3') and self.is_eagle3: - self.dir_model = original_dir_model - # Apply to CodeLlama only (and ignore for Llama 3 with a vocab size of 128256) if self.hparams.get("vocab_size", 32000) == 32016: special_vocab = gguf.SpecialVocab( @@ -2504,6 +2501,10 @@ def set_vocab(self): if self.hparams.get("vocab_size", 32000) == 49152: self.gguf_writer.add_add_bos_token(False) + # Restore original dir_model for EAGLE-3 + if hasattr(self, 'is_eagle3') and self.is_eagle3: + self.dir_model = original_dir_model + def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams diff --git a/src/models/qwen3.cpp b/src/models/qwen3.cpp index a5cfffa53149..c1f34624c03e 100644 --- a/src/models/qwen3.cpp +++ b/src/models/qwen3.cpp @@ -21,6 +21,17 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; + // EAGLE3: Extract intermediate layer features from target model at layer INPUT + if (eagle3 && cparams.eagle3_extract_enabled && !eagle3->extract_layer_indices.empty()) { + static const char * eagle3_extract_names[] = {"eagle3_extract_0", "eagle3_extract_1", "eagle3_extract_2"}; + for (size_t i = 0; i < eagle3->extract_layer_indices.size() && i < 3; ++i) { + if (eagle3->extract_layer_indices[i] == il) { + cb(inpL, eagle3_extract_names[i], il); + break; + } + } + } + // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, From 71ba283a6573b3735fa07c39d6e5f8cdeb9a34ab Mon Sep 17 00:00:00 2001 From: ruixiangw Date: Fri, 9 Jan 2026 11:54:28 +0000 Subject: [PATCH 05/21] add eagle3 support for Qwen3 MoE models --- src/models/qwen3moe.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/models/qwen3moe.cpp b/src/models/qwen3moe.cpp index 888534fb3474..c0b6ff5df971 100644 --- a/src/models/qwen3moe.cpp +++ b/src/models/qwen3moe.cpp @@ -21,6 +21,17 @@ 
llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; + // EAGLE3: Extract intermediate layer features from target model at layer INPUT + if (eagle3 && cparams.eagle3_extract_enabled && !eagle3->extract_layer_indices.empty()) { + static const char * eagle3_extract_names[] = {"eagle3_extract_0", "eagle3_extract_1", "eagle3_extract_2"}; + for (size_t i = 0; i < eagle3->extract_layer_indices.size() && i < 3; ++i) { + if (eagle3->extract_layer_indices[i] == il) { + cb(inpL, eagle3_extract_names[i], il); + break; + } + } + } + // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, From 3da288d78dc68005502481c50cb8bb3d482a6127 Mon Sep 17 00:00:00 2001 From: ruixiangw Date: Sat, 10 Jan 2026 14:09:50 +0000 Subject: [PATCH 06/21] eagle3: load lm_head from target model if not in draft model when convert GGUF --- convert_hf_to_gguf.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7ef9ffb27b01..52140107fb5b 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2638,6 +2638,17 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + # EAGLE3: If no lm_head in draft model, load from target model + if hasattr(self, 'is_eagle3') and self.is_eagle3 and "lm_head.weight" not in self.model_tensors: + from safetensors import safe_open + for sf_file in self.target_model_dir.glob("*.safetensors"): + with safe_open(sf_file, framework="pt") as f: + if "lm_head.weight" in f.keys(): + lm_head = f.get_tensor("lm_head.weight") + logger.info(f"EAGLE3: No lm_head in draft model, loaded lm_head from {sf_file.name}, shape = {lm_head.shape}") + yield ("output.weight", lm_head) + break + if rope_params := self.rope_parameters.get("full_attention", self.rope_parameters): if 
rope_params.get("rope_type", '').lower() == "llama3": base = rope_params.get("rope_theta", 10000.0) From 13a9f31de3c4112c65693db3ed3e08223a069365 Mon Sep 17 00:00:00 2001 From: ruixiangw Date: Sat, 10 Jan 2026 18:30:19 +0000 Subject: [PATCH 07/21] eagle3: make d2t mapping optional --- src/llama-model.cpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index f4e22bdda8df..287bfe7f142b 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -6464,20 +6464,22 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_embd_target_features = 3 * hparams.eagle3_target_hidden_size; const int64_t n_embd_attn_input = 2 * n_embd; - // Get vocab size from the d2t tensor in the GGUF file - // d2t: draft to target mapping (size = draft_vocab_size) + // Get vocab size from the d2t tensor in the GGUF file (optional - only needed if EAGLE3 has different vocab_size than target) + // d2t: draft to target vocabulary mapping + int64_t n_draft_vocab = n_vocab; // Default: same as target vocab const struct ggml_tensor * d2t_meta = ml.get_tensor_meta("d2t"); - if (!d2t_meta) { - throw std::runtime_error("EAGLE3 model requires 'd2t' tensor but it was not found in the model file"); + if (d2t_meta) { + n_draft_vocab = d2t_meta->ne[0]; // update draft vocab size + d2t = create_tensor(tn(LLM_TENSOR_EAGLE3_D2T), {n_draft_vocab}, 0); + LLAMA_LOG_INFO("%s: EAGLE3 using d2t mapping (draft_vocab_size = %lld)\n", __func__, (long long)n_draft_vocab); + } else { + d2t = nullptr; // no d2t, use default vocab size + LLAMA_LOG_INFO("%s: EAGLE3 without d2t - sharing same vocab_size with target (vocab_size = %lld)\n", __func__, (long long)n_draft_vocab); } - const int64_t n_draft_vocab = d2t_meta->ne[0]; // Feature fusion layer: projects 3 target layers to draft hidden size fc = create_tensor(tn(LLM_TENSOR_EAGLE3_FC, "weight"), {n_embd_target_features, n_embd}, 0); - // Draft to target vocabulary 
mapping tensor - d2t = create_tensor(tn(LLM_TENSOR_EAGLE3_D2T), {n_draft_vocab}, 0); - // Output layer (uses draft vocab size) output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_draft_vocab}, 0); From 75883cde73fbbd0792cd578cb572ab4382d7b8c3 Mon Sep 17 00:00:00 2001 From: ruixiangw Date: Sat, 10 Jan 2026 18:33:41 +0000 Subject: [PATCH 08/21] eagle3: add support for gpt-oss-120B eagle3 --- src/models/openai-moe-iswa.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/models/openai-moe-iswa.cpp b/src/models/openai-moe-iswa.cpp index 96596709eec5..08cc41f3c11b 100644 --- a/src/models/openai-moe-iswa.cpp +++ b/src/models/openai-moe-iswa.cpp @@ -16,6 +16,17 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; + // EAGLE3: Extract intermediate layer features from target model at layer INPUT + if (eagle3 && cparams.eagle3_extract_enabled && !eagle3->extract_layer_indices.empty()) { + static const char * eagle3_extract_names[] = {"eagle3_extract_0", "eagle3_extract_1", "eagle3_extract_2"}; + for (size_t i = 0; i < eagle3->extract_layer_indices.size() && i < 3; ++i) { + if (eagle3->extract_layer_indices[i] == il) { + cb(inpL, eagle3_extract_names[i], il); + break; + } + } + } + // norm cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, From 7b78bfa9845f3de31e634809a4fdbaf10000bc29 Mon Sep 17 00:00:00 2001 From: ruixiangw Date: Fri, 16 Jan 2026 00:54:14 +0000 Subject: [PATCH 09/21] eagle3: add support for RedHatAI eagle3 speculator series models --- convert_hf_to_gguf.py | 17 ++++++++++++++++- gguf-py/gguf/constants.py | 1 + src/llama-arch.cpp | 5 +++-- src/llama-arch.h | 1 + src/llama-hparams.h | 3 +++ src/llama-model.cpp | 9 ++++++++- src/models/eagle3.cpp | 9 ++++++--- 7 files changed, 38 insertions(+), 7 deletions(-) diff --git a/convert_hf_to_gguf.py
b/convert_hf_to_gguf.py index 52140107fb5b..2babd7f9f083 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2396,6 +2396,8 @@ def prepare_tensors(self): "LlavaForConditionalGeneration", "VoxtralForConditionalGeneration", "LlamaForCausalLMEagle3", + "Eagle3Speculator", + "Eagle3DraftModel", "LlamaModel") class LlamaModel(TextModel): model_arch = gguf.MODEL_ARCH.LLAMA @@ -2445,6 +2447,11 @@ def __init__(self, *args, **kwargs): logger.info(f"EAGLE3: target_hidden_size = {target_hidden_size} (from target model config)") self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.target_hidden_size", target_hidden_size) + # Eagle3Speculator norm_before_residual specific handling + norm_before_residual = eagle3_raw_config.get("norm_before_residual", False) + logger.info(f"EAGLE3: norm_before_residual = {norm_before_residual} (from EAGLE3 config)") + self.gguf_writer.add_bool(f"{self.gguf_writer.arch}.norm_before_residual", norm_before_residual) + def set_vocab(self): # For EAGLE-3 models, use tokenizer from target model if provided if hasattr(self, 'is_eagle3') and self.is_eagle3: @@ -2528,15 +2535,23 @@ def permute(weights: Tensor, n_head: int, n_head_kv: int | None): def index_tensors(self, remote_hf_model_id: str | None = None) -> dict[str, Callable[[], Tensor]]: tensors = super().index_tensors(remote_hf_model_id) + + # Handle Eagle3Speculator nested config + if "transformer_layer_config" in self.hparams: + self.hparams = {**self.hparams, **self.hparams["transformer_layer_config"]} + # EAGLE-3 detection: check hparams directly (before self.is_eagle3 is set) if "draft_vocab_size" in self.hparams and self.hparams["num_hidden_layers"] == 1: - logger.info("EAGLE-3: Renaming midlayer.* to model.layers.0.*") + logger.info("EAGLE-3: Renaming midlayer.* or layers.0.* to model.layers.0.*") new_tensors = {} # EAGLE-3: rename midlayer.* to model.layers.0.* for compatibility with llama model for name, gen in tensors.items(): if name.startswith("midlayer."): new_name = 
"model.layers.0." + name[len("midlayer."):] new_tensors[new_name] = gen + elif name.startswith("layers.0."): # layers.0.* -> model.layers.0.* (Eagle3Speculator format) + new_name = "model." + name + new_tensors[new_name] = gen else: new_tensors[name] = gen return new_tensors diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index b1160ca26d8a..2ae5094619da 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -149,6 +149,7 @@ class LLM: DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out" EAGLE3_EXTRACT_LAYERS = "{arch}.extract_layers" EAGLE3_TARGET_HIDDEN_SIZE = "{arch}.target_hidden_size" + EAGLE3_NORM_BEFORE_RESIDUAL = "{arch}.norm_before_residual" class Attention: HEAD_COUNT = "{arch}.attention.head_count" diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 4caa5f77aee3..8304c6361551 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -248,8 +248,9 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" }, - { LLM_KV_EAGLE3_EXTRACT_LAYERS, "%s.extract_layers" }, - { LLM_KV_EAGLE3_TARGET_HIDDEN_SIZE, "%s.target_hidden_size" }, + { LLM_KV_EAGLE3_EXTRACT_LAYERS, "%s.extract_layers" }, + { LLM_KV_EAGLE3_TARGET_HIDDEN_SIZE, "%s.target_hidden_size" }, + { LLM_KV_EAGLE3_NORM_BEFORE_RESIDUAL, "%s.norm_before_residual" }, { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" }, // sentence-transformers dense modules feature dims diff --git a/src/llama-arch.h b/src/llama-arch.h index 3e731b5005ba..36cad138a869 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -292,6 +292,7 @@ enum llm_kv { LLM_KV_EAGLE3_EXTRACT_LAYERS, LLM_KV_EAGLE3_TARGET_HIDDEN_SIZE, + LLM_KV_EAGLE3_NORM_BEFORE_RESIDUAL, LLM_KV_SHORTCONV_L_CACHE, diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 9272c728e316..f8ed7f364c12 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -196,6 +196,9 @@ struct llama_hparams { // EAGLE3 draft model - target model hidden size uint32_t 
eagle3_target_hidden_size = 0; + // EAGLE3 draft model - apply hidden_norm before storing residual + bool eagle3_norm_before_residual = false; + // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 287bfe7f142b..4879376aefac 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2260,7 +2260,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EAGLE3_TARGET_HIDDEN_SIZE, hparams.eagle3_target_hidden_size); LLAMA_LOG_INFO("%s: EAGLE3 target_hidden_size = %u (draft n_embd = %u)\n", __func__, hparams.eagle3_target_hidden_size, hparams.n_embd); - + + // EAGLE3 norm_before_residual (optional, default false) + // compatible with Readhat eagle3 speculator model + ml.get_key(LLM_KV_EAGLE3_NORM_BEFORE_RESIDUAL, hparams.eagle3_norm_before_residual, false); + if (hparams.eagle3_norm_before_residual) { + LLAMA_LOG_INFO("%s: EAGLE3 norm_before_residual = true\n", __func__); + } + type = LLM_TYPE_UNKNOWN; } break; case LLM_ARCH_COGVLM: diff --git a/src/models/eagle3.cpp b/src/models/eagle3.cpp index 629d89d32707..4f9410b3602e 100644 --- a/src/models/eagle3.cpp +++ b/src/models/eagle3.cpp @@ -77,9 +77,6 @@ llm_build_eagle3_decode::llm_build_eagle3_decode(const llama_model & model, cons // Single decoder layer (il = 0) const int il = 0; { - // inpL is the concatenated input (normalized inp_embd + normalized inp_g) - ggml_tensor * inpSA = inpL; - // Apply input_layernorm to the token embeddings ggml_tensor * embd_norm = build_norm(inp_embd, model.layers[il].attn_norm, NULL, @@ -92,6 +89,12 @@ llm_build_eagle3_decode::llm_build_eagle3_decode(const llama_model & model, cons LLM_NORM_RMS, -1); cb(g_norm, "g_norm", il); + // norm_before_residual: determines what goes into the residual connection (compatible with Readhat eagle3 speculator model) + // - false (default): use raw inp_g 
for residual + // - true: use normalized g_norm for residual + // inpL is the concatenated input (normalized inp_embd + normalized inp_g) + ggml_tensor * inpSA = hparams.eagle3_norm_before_residual ? g_norm : inpL; + // Concatenate normalized inp_embd and normalized inp_g cur = ggml_concat(ctx0, embd_norm, g_norm, il); cb(cur, "concat_embd", il); From b3537924efa7552a5e30c64b48800a8d77abfc09 Mon Sep 17 00:00:00 2001 From: ruixiangw Date: Fri, 20 Feb 2026 17:54:08 +0000 Subject: [PATCH 10/21] eagle3: fix model convert issue --- convert_hf_to_gguf.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3dcaa0e79789..7ca4c0fe85f2 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2665,7 +2665,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # Eagle-3 llama checkpoint special weights handling # fc.weight: feature fusion layer if name == "fc.weight": - return [(name, data_torch)] + yield (name, data_torch) + return # d2t: draft to target vocabulary mapping elif name == "d2t": # Skip parent class processing (store for manual handling in prepare_tensors) @@ -2678,7 +2679,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [] # hidden_norm: EAGLE-3 specific layer normalization elif name == "model.layers.0.hidden_norm.weight": - return [("blk.0.hidden_norm.weight", data_torch)] + yield ("blk.0.hidden_norm.weight", data_torch) + return n_head = self.find_hparam(["n_heads", "num_attention_heads"]) n_kv_head = self.find_hparam(["n_kv_heads", "num_key_value_heads"]) From 9fea2434af1b0647ee424a5cc433892956b175c0 Mon Sep 17 00:00:00 2001 From: ruixiangw Date: Fri, 20 Feb 2026 18:05:49 +0000 Subject: [PATCH 11/21] eagle3: fix model convert code format --- convert_hf_to_gguf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 
7ca4c0fe85f2..71346c8b2eac 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2673,10 +2673,10 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if not hasattr(self, '_eagle3_int_tensors'): self._eagle3_int_tensors = {} self._eagle3_int_tensors[name] = data_torch - return [] + return # t2d: target to draft vocabulary mapping (not used, skip completely) elif name == "t2d": - return [] + return # hidden_norm: EAGLE-3 specific layer normalization elif name == "model.layers.0.hidden_norm.weight": yield ("blk.0.hidden_norm.weight", data_torch) From 07e2c9707cc9d4a6693e480b8b2619f2fad1c66a Mon Sep 17 00:00:00 2001 From: ruixiangw Date: Sat, 28 Feb 2026 00:33:54 +0000 Subject: [PATCH 12/21] eagle3: support --eagle3 in llama-cli --- common/arg.cpp | 2 +- tools/server/server-context.cpp | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/common/arg.cpp b/common/arg.cpp index 95236723171c..c2c0b4018b2c 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3352,7 +3352,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.speculative.eagle3 = true; } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE})); + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_CLI})); add_opt(common_arg( {"-cd", "--ctx-size-draft"}, "N", string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx), diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 0f2f3a45aaae..545fcce9be84 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -671,7 +671,18 @@ struct server_context_impl { } params_base.speculative.model_dft = model_dft.get(); + params_base.speculative.model_tgt = model; params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft); + + if (params_base.speculative.eagle3) { + // EAGLE3 current limitation: 
extracted target features are per-context; multiple slots would overwrite each other + if (params_base.n_parallel > 1) { + SRV_ERR("%s", "EAGLE3 speculative decoding is not supported with n_parallel > 1\n"); + return false; + } + llama_set_eagle3(ctx, model_dft.get()); + SRV_INF("%s", "EAGLE3 feature extraction enabled on target model\n"); + } } std::string & mmproj_path = params_base.mmproj.path; From 0724d66e5c85cebf84eba0be4e053872e13998ce Mon Sep 17 00:00:00 2001 From: Ruixiang Wang Date: Sat, 18 Apr 2026 23:58:32 +0000 Subject: [PATCH 13/21] dflash: first working POC --- common/arg.cpp | 7 + common/common.h | 2 + common/speculative.cpp | 154 +++++++++++++++- convert_hf_to_gguf.py | 41 +++++ .../speculative-simple/speculative-simple.cpp | 59 ++++++- gguf-py/gguf/constants.py | 27 ++- include/llama.h | 24 +++ src/llama-arch.cpp | 13 +- src/llama-arch.h | 7 + src/llama-context.cpp | 166 +++++++++++++++++- src/llama-context.h | 16 ++ src/llama-cparams.h | 1 + src/llama-graph.cpp | 1 + src/llama-graph.h | 16 ++ src/llama-hparams.h | 5 + src/llama-model-loader.cpp | 1 + src/llama-model.cpp | 73 ++++++++ src/llama-model.h | 4 + src/models/dflash.cpp | 161 +++++++++++++++++ src/models/models.h | 10 ++ src/models/openai-moe-iswa.cpp | 13 ++ src/models/qwen3.cpp | 14 ++ src/models/qwen35.cpp | 14 ++ 23 files changed, 816 insertions(+), 13 deletions(-) create mode 100644 src/models/dflash.cpp diff --git a/common/arg.cpp b/common/arg.cpp index 2724dc8a4df0..70d90f97347c 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3474,6 +3474,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.eagle3 = true; } ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_CLI})); + add_opt(common_arg( + {"--dflash"}, + "use DFlash speculative decoding with the draft model", + [](common_params & params) { + params.speculative.dflash = true; + } + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_CLI})); add_opt(common_arg( 
{"-cd", "--ctx-size-draft"}, "N", string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx), diff --git a/common/common.h b/common/common.h index 19bc4172d281..27c85d32bb4e 100644 --- a/common/common.h +++ b/common/common.h @@ -159,6 +159,7 @@ enum common_speculative_type { COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding COMMON_SPECULATIVE_TYPE_DRAFT, // draft model COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model + COMMON_SPECULATIVE_TYPE_DFLASH, // dflash draft model COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values @@ -328,6 +329,7 @@ struct common_params_speculative { llama_context_params cparams_dft; // these are the parameters for the draft llama_context bool eagle3 = false; // use EAGLE3 speculative decoding + bool dflash = false; // use DFlash speculative decoding int32_t n_ctx = 0; // draft context size int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) diff --git a/common/speculative.cpp b/common/speculative.cpp index 292566b6c65f..4980c03da62e 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -22,6 +22,7 @@ const std::vector common_speculative_types = { COMMON_SPECULATIVE_TYPE_NONE, COMMON_SPECULATIVE_TYPE_DRAFT, COMMON_SPECULATIVE_TYPE_EAGLE3, + COMMON_SPECULATIVE_TYPE_DFLASH, COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, @@ -33,6 +34,7 @@ const std::map common_speculative_typ {"none", COMMON_SPECULATIVE_TYPE_NONE}, {"draft", COMMON_SPECULATIVE_TYPE_DRAFT}, {"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3}, + {"dflash", COMMON_SPECULATIVE_TYPE_DFLASH}, {"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE}, {"ngram_map_k", 
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K}, {"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V}, @@ -708,6 +710,139 @@ struct common_speculative_state_eagle3 : public common_speculative_state { } }; +struct common_speculative_state_dflash : public common_speculative_state { + llama_context * ctx_tgt; + + common_sampler * smpl; + + llama_batch batch; + + struct llama_context * ctx_dft_enc = nullptr; + struct llama_context * ctx_dft_dec = nullptr; + + int32_t dflash_n_past = 0; + + // Host-side buffer: accumulated DFlash-encoded target features across all + // committed prompt+drafted tokens. Grows by `n_new * n_embd` floats per draft step + // and is fed to the DFlash decoder via llama_set_dflash_accumulated_target_ctx() + std::vector accumulated_ctx; + + common_speculative_state_dflash( + enum common_speculative_type type, + llama_context * ctx_tgt, + llama_context * ctx_dft_enc, + llama_context * ctx_dft_dec) + : common_speculative_state(type) + , ctx_tgt(ctx_tgt) + , ctx_dft_enc(ctx_dft_enc) + , ctx_dft_dec(ctx_dft_dec) + { + batch = llama_batch_init(llama_n_batch(ctx_dft_dec), 0, 1); + + common_params_sampling params; + params.no_perf = false; + params.top_k = 1; + params.samplers = { COMMON_SAMPLER_TYPE_TOP_K }; + smpl = common_sampler_init(llama_get_model(ctx_dft_dec), params); + } + + ~common_speculative_state_dflash() override { + llama_perf_context_print(ctx_dft_dec); + + if (ctx_dft_dec) { + llama_free(ctx_dft_dec); + } + + if (ctx_dft_enc) { + llama_free(ctx_dft_enc); + } + + common_sampler_free(smpl); + llama_batch_free(batch); + } + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + const int n_embd = llama_model_n_embd(llama_get_model(ctx_dft_dec)); + // block_size is bounded by the model's trained block_size (from GGUF metadata). 
+ const int model_block_size = llama_model_dflash_block_size(llama_get_model(ctx_dft_dec)); + const int block_size = std::min((int)params.n_max, model_block_size); + const int n = (int)prompt_tgt.size(); + const int n_new = n - dflash_n_past; + + GGML_ASSERT(n >= 1 && "prompt_tgt is empty"); + GGML_ASSERT(n_new >= 1 && "must have at least 1 new token"); + + // Step 1: Encode new accepted tokens' features + const float * features = llama_get_dflash_target_features(ctx_tgt); + + llama_batch enc_batch = { + /*.n_tokens =*/ n_new, + /*.token =*/ nullptr, + /*.embd =*/ const_cast(features), + /*.pos =*/ nullptr, + /*.n_seq_id =*/ nullptr, + /*.seq_id =*/ nullptr, + /*.logits =*/ nullptr, + }; + if (llama_encode(ctx_dft_enc, enc_batch) != 0) { + LOG_ERR("DFlash: encoder failed\n"); + return; + } + + const float * target_ctx_new = llama_get_embeddings(ctx_dft_enc); + GGML_ASSERT(target_ctx_new && "encoder output is null"); + + // Step 2: Append to accumulated target_ctx and set on decoder context (writes to cross.v_embd) + const size_t new_size = (size_t)n_embd * n_new; + accumulated_ctx.insert(accumulated_ctx.end(), target_ctx_new, target_ctx_new + new_size); + + const int n_ctx_total = (int)(accumulated_ctx.size() / n_embd); + llama_set_dflash_accumulated_target_ctx(ctx_dft_dec, accumulated_ctx.data(), n_embd, n_ctx_total); + + // Step 3: Decode noise block + const llama_token mask_token_id = llama_model_dflash_mask_token_id(llama_get_model(ctx_dft_dec)); + + common_batch_clear(batch); + for (int i = 0; i < block_size; i++) { + const llama_token tok = (i == 0) ? 
id_last : mask_token_id; + common_batch_add(batch, tok, i, {0}, true); + } + + if (llama_decode(ctx_dft_dec, batch) != 0) { + LOG_ERR("DFlash: noise decode failed\n"); + return; + } + + dflash_n_past = n; + + // Step 4: Sample draft tokens from positions 1..block_size-1 + result.clear(); + common_sampler_reset(smpl); + + for (int i = 1; i < block_size; i++) { + common_sampler_sample(smpl, ctx_dft_dec, i); + + const auto * cur_p = common_sampler_get_candidates(smpl, true); + const llama_token id = cur_p->data[0].id; + + common_sampler_accept(smpl, id, true); + result.push_back(id); + } + } + + void accept(uint16_t n_accepted) override { + GGML_UNUSED(n_accepted); + } +}; + // state of self-speculation (simple implementation, not ngram-map) struct common_speculative_state_ngram_simple : public common_speculative_state { common_ngram_simple_config config; @@ -1057,13 +1192,13 @@ common_speculative * common_speculative_init( llama_context * ctx_dft_dec = nullptr; if (params.model_dft) { - if (params.eagle3) { + if (params.eagle3 || params.dflash) { llama_context_params params_enc = params.cparams_dft; params_enc.target_model = nullptr; params_enc.embeddings = true; ctx_dft_enc = llama_init_from_model(params.model_dft, params_enc); if (!ctx_dft_enc) { - LOG_ERR("failed to create EAGLE3 encoder context\n"); + LOG_ERR("failed to create %s draft model encoder context\n", params.eagle3 ? "EAGLE3" : "DFlash"); return nullptr; } @@ -1072,13 +1207,13 @@ common_speculative * common_speculative_init( params_dec.embeddings = true; ctx_dft_dec = llama_init_from_model(params.model_dft, params_dec); if (!ctx_dft_dec) { - LOG_ERR("failed to create EAGLE3 decoder context\n"); + LOG_ERR("failed to create %s draft model decoder context\n", params.eagle3 ? 
"EAGLE3" : "DFlash"); return nullptr; } } else { ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft); if (ctx_dft == nullptr) { - LOG_ERR("%s", "failed to create draft context\n"); + LOG_ERR("failed to create draft model context\n"); return nullptr; } } @@ -1089,6 +1224,7 @@ common_speculative * common_speculative_init( { bool has_draft = !params.mparams_dft.path.empty(); bool has_draft_eagle3 = params.eagle3; + bool has_draft_dflash = params.dflash; bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE); bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE); @@ -1131,6 +1267,8 @@ common_speculative * common_speculative_init( if (has_draft) { if (has_draft_eagle3) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params)); + } else if (has_draft_dflash) { + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DFLASH, params)); } else { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params)); } @@ -1163,6 +1301,14 @@ common_speculative * common_speculative_init( )); break; } + case COMMON_SPECULATIVE_TYPE_DFLASH: { + impls.push_back(std::make_unique(config.type, + /* .ctx_tgt = */ ctx_tgt, + /* .ctx_dft_enc = */ ctx_dft_enc, + /* .ctx_dft_dec = */ ctx_dft_dec + )); + break; + } case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { common_ngram_map ngram_map = get_common_ngram_map(config); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 53b11d0171cb..6a5ac25d945d 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4887,6 +4887,47 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register("DFlashDraftModel") +class DFlashModel(Qwen3Model): + model_arch = gguf.MODEL_ARCH.DFLASH + + def set_vocab(self): + if self.target_model_dir is None: + raise ValueError( + "DFlash draft model requires --target-model-dir to be 
specified. " + "Please provide the path to the target model directory containing the tokenizer." + ) + logger.info(f"DFLASH: Using tokenizer from target model: {self.target_model_dir}") + original_dir = self.dir_model + self.dir_model = self.target_model_dir + super().set_vocab() + self.dir_model = original_dir + + def set_gguf_parameters(self): + super().set_gguf_parameters() + block_size = self.hparams.get("block_size", 16) + self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.block_size", block_size) + dflash_config = self.hparams.get("dflash_config", {}) + target_layer_ids = dflash_config.get("target_layer_ids", []) + if target_layer_ids: + extract_layer_ids = [i + 1 for i in target_layer_ids] + self.gguf_writer.add_array(f"{self.gguf_writer.arch}.target_layer_ids", extract_layer_ids) + mask_token_id = dflash_config.get("mask_token_id", None) + if mask_token_id is not None: + self.gguf_writer.add_uint32(f"{self.gguf_writer.arch}.mask_token_id", mask_token_id) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name == "fc.weight": + yield (name, data_torch) + return + if name == "hidden_norm.weight": + yield ("hidden_norm.weight", data_torch) + return + if not name.startswith("model."): + name = "model." 
+ name + yield from super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("Qwen3MoeForCausalLM") class Qwen3MoeModel(Qwen2MoeModel): model_arch = gguf.MODEL_ARCH.QWEN3MOE diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index a8c297131077..003cc217f493 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -111,11 +111,14 @@ int main(int argc, char ** argv) { if (params.speculative.eagle3) { llama_set_eagle3(ctx_tgt, model_dft.get()); } + if (params.speculative.dflash) { + llama_set_dflash(ctx_tgt, model_dft.get()); + } } - // Apply chat template for EAGLE3 if available which can increase the acceptance rate + // Apply chat template for EAGLE3 / DFlash if available which can increase the acceptance rate std::string prompt = params.prompt; - if (params.speculative.eagle3) { + if (params.speculative.eagle3 || params.speculative.dflash) { auto chat_templates = common_chat_templates_init(model_tgt, params.chat_template); if (common_chat_templates_was_explicit(chat_templates.get())) { std::vector chat_msgs; @@ -127,8 +130,15 @@ int main(int argc, char ** argv) { common_chat_templates_inputs inputs; inputs.messages = chat_msgs; inputs.add_generation_prompt = true; + // Disable thinking mode can improve accept rate + if (const char * nt = std::getenv("LLAMA_SPEC_NO_THINK"); nt && std::string(nt) != "0") { + // Qwen3 / 3.5 + inputs.enable_thinking = false; + // gpt-oss + inputs.chat_template_kwargs["reasoning_effort"] = "\"low\""; + } prompt = common_chat_templates_apply(chat_templates.get(), inputs).prompt; - LOG_INF("%s: EAGLE3 chat template applied\n", __func__); + LOG_INF("%s: %s chat template applied\n", __func__, params.speculative.eagle3 ? 
"EAGLE3" : "DFlash"); } } @@ -177,7 +187,7 @@ int main(int argc, char ** argv) { int n_past; // TODO: simplify - if (params.speculative.eagle3) { + if (params.speculative.eagle3 || params.speculative.dflash) { // Target model decodes full prompt and sample first token and intermediate features are extracted llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size())); @@ -222,6 +232,21 @@ int main(int argc, char ** argv) { const auto t_dec_start = ggml_time_us(); + // Hybrid targets (e.g. Qwen3.5) have recurrent layers that cannot be partially rolled back via seq_rm. + // For them, snapshot the target state before verification and, on rejection, restore it and replay only the accepted tokens to ensure correctness. + // This is inefficient because the target model may run twice, but it is required by the current llama.cpp design. NOTE: the restore+replay branch further below is currently commented out, so seq_rm is still what runs after verification. + const bool use_state_snapshot = params.speculative.dflash && llama_model_is_hybrid(model_tgt); + if (params.speculative.dflash) { + LOG_INF("%s: DFlash target=%s, using %s rollback path\n", __func__, + llama_model_is_hybrid(model_tgt) ? "hybrid" : "pure-attention", + use_state_snapshot ? 
"snapshot+restore" : "seq_rm"); + } + std::vector state_snap; + if (use_state_snapshot) { + const size_t sz = llama_state_seq_get_size(ctx_tgt, 0); + state_snap.resize(sz); + } + while (true) { // generate or reuse draft tokens // @@ -269,6 +294,17 @@ int main(int argc, char ** argv) { GGML_ASSERT(n_draft > 0); + // snapshot target state for potential rollback (hybrid/recurrent targets only) + const int n_past_before = n_past; + const llama_token id_last_saved = id_last; + if (use_state_snapshot) { + const size_t sz = llama_state_seq_get_size(ctx_tgt, 0); + if (sz > state_snap.size()) { + state_snap.resize(sz); + } + llama_state_seq_get_data(ctx_tgt, state_snap.data(), sz, 0); + } + // always have a token to evaluate from before - id_last common_batch_clear(batch_tgt); common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true); @@ -367,6 +403,21 @@ int main(int argc, char ** argv) { draft.clear(); { + // const bool had_rejection = ids.size() < draft.size() + 1; + + // if (use_state_snapshot && had_rejection) { + // // Restore snapshot and replay the committed prefix (id_last + accepted drafts) so target state exactly + // LOG_DBG("DFlash rollback: restore target state and replay %zu tokens\n", ids.size()); + // llama_state_seq_set_data(ctx_tgt, state_snap.data(), state_snap.size(), 0); + // common_batch_clear(batch_tgt); + // common_batch_add(batch_tgt, id_last_saved, n_past_before, { 0 }, true); + // for (size_t i = 0; i + 1 < ids.size(); ++i) { + // common_batch_add(batch_tgt, ids[i], n_past_before + 1 + i, { 0 }, true); + // } + // if (batch_tgt.n_tokens > 0) { + // llama_decode(ctx_tgt, batch_tgt); + // } + // } else { LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index fd114d3ffa0c..c3b3cb37fae2 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -492,8 +492,9 @@ class 
MODEL_ARCH(IntEnum): RND1 = auto() PANGU_EMBED = auto() MISTRAL3 = auto() - EAGLE3 = auto() MISTRAL4 = auto() + EAGLE3 = auto() + DFLASH = auto() PADDLEOCR = auto() MIMO2 = auto() STEP35 = auto() @@ -852,6 +853,9 @@ class MODEL_TENSOR(IntEnum): EAGLE3_FC = auto() # feature fusion layer EAGLE3_HIDDEN_NORM = auto() # hidden normalization EAGLE3_D2T = auto() # draft to target vocabulary mapping + # DFlash + DFLASH_FC = auto() # feature fusion layer + DFLASH_HIDDEN_NORM = auto() # hidden normalization # lfm2 audio A_ENC_NORM_CONV = auto() A_ENC_LINEAR_POS = auto() @@ -984,8 +988,9 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.RND1: "rnd1", MODEL_ARCH.PANGU_EMBED: "pangu-embedded", MODEL_ARCH.MISTRAL3: "mistral3", - MODEL_ARCH.EAGLE3: "eagle3", MODEL_ARCH.MISTRAL4: "mistral4", + MODEL_ARCH.EAGLE3: "eagle3", + MODEL_ARCH.DFLASH: "dflash", MODEL_ARCH.PADDLEOCR: "paddleocr", MODEL_ARCH.MIMO2: "mimo2", MODEL_ARCH.STEP35: "step35", @@ -1352,6 +1357,8 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.EAGLE3_FC: "fc", MODEL_TENSOR.EAGLE3_HIDDEN_NORM: "blk.{bid}.hidden_norm", MODEL_TENSOR.EAGLE3_D2T: "d2t", + MODEL_TENSOR.DFLASH_FC: "fc", + MODEL_TENSOR.DFLASH_HIDDEN_NORM: "hidden_norm", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -3772,6 +3779,22 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.EAGLE3_HIDDEN_NORM, MODEL_TENSOR.EAGLE3_D2T, ], + MODEL_ARCH.DFLASH: [ + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.DFLASH_FC, + MODEL_TENSOR.DFLASH_HIDDEN_NORM, + ], MODEL_ARCH.MISTRAL4: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/include/llama.h b/include/llama.h index a630da73b866..fc629fd5c55a 100644 --- a/include/llama.h +++ b/include/llama.h @@ -558,6 +558,12 @@ extern "C" { LLAMA_API 
int32_t llama_model_n_head_kv (const struct llama_model * model); LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model); + // DFlash draft model: block size used as number of draft tokens + LLAMA_API int32_t llama_model_dflash_block_size(const struct llama_model * model); + + // DFlash draft model: mask token id used as filler in the noise block + LLAMA_API int32_t llama_model_dflash_mask_token_id(const struct llama_model * model); + // Get the model's RoPE frequency scaling factor LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); @@ -914,6 +920,24 @@ extern "C" { int32_t n_embd, int32_t n_tokens); + // + // DFlash draft model support (similar to EAGLE3) + // + + // Enable DFlash target feature extraction on the target context + LLAMA_API void llama_set_dflash( + struct llama_context * ctx, + const struct llama_model * model); + + LLAMA_API const float * llama_get_dflash_target_features(struct llama_context * ctx); + + // Set accumulated target_ctx for DFlash decoder + LLAMA_API void llama_set_dflash_accumulated_target_ctx( + struct llama_context * ctx, + const float * data, + int32_t n_embd, + int32_t n_tokens); + // // Decoding // diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 2861887e8d3a..3ab9dd4d505e 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -126,8 +126,9 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, - { LLM_ARCH_EAGLE3, "eagle3" }, { LLM_ARCH_MISTRAL4, "mistral4" }, + { LLM_ARCH_EAGLE3, "eagle3" }, + { LLM_ARCH_DFLASH, "dflash" }, { LLM_ARCH_PADDLEOCR, "paddleocr" }, { LLM_ARCH_MIMO2, "mimo2" }, { LLM_ARCH_STEP35, "step35" }, @@ -289,6 +290,10 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EAGLE3_TARGET_HIDDEN_SIZE, "%s.target_hidden_size" }, { LLM_KV_EAGLE3_NORM_BEFORE_RESIDUAL, "%s.norm_before_residual" }, + { LLM_KV_DFLASH_TARGET_LAYER_IDS, "%s.target_layer_ids" }, + { 
LLM_KV_DFLASH_BLOCK_SIZE, "%s.block_size" }, + { LLM_KV_DFLASH_MASK_TOKEN_ID, "%s.mask_token_id" }, + { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" }, // sentence-transformers dense modules feature dims { LLM_KV_DENSE_2_FEAT_IN, "%s.dense_2_feat_in" }, @@ -556,6 +561,9 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_EAGLE3_HIDDEN_NORM, "blk.%d.hidden_norm" }, { LLM_TENSOR_EAGLE3_FC, "fc" }, { LLM_TENSOR_EAGLE3_D2T, "d2t" }, + // DFlash specific layers + { LLM_TENSOR_DFLASH_FC, "fc" }, + { LLM_TENSOR_DFLASH_HIDDEN_NORM, "hidden_norm" }, }; // declare information about the model weight tensors: @@ -780,6 +788,9 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_EAGLE3_FC, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_EAGLE3_HIDDEN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_EAGLE3_D2T, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, + // DFlash tensors + {LLM_TENSOR_DFLASH_FC, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DFLASH_HIDDEN_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-arch.h b/src/llama-arch.h index 06674d8c9bce..29f260eed525 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -139,6 +139,7 @@ enum llm_arch { LLM_ARCH_KIMI_LINEAR, LLM_ARCH_UNKNOWN, LLM_ARCH_EAGLE3, + LLM_ARCH_DFLASH, }; enum llm_kv { @@ -331,6 +332,10 @@ enum llm_kv { LLM_KV_EAGLE3_TARGET_HIDDEN_SIZE, LLM_KV_EAGLE3_NORM_BEFORE_RESIDUAL, + LLM_KV_DFLASH_TARGET_LAYER_IDS, + LLM_KV_DFLASH_BLOCK_SIZE, + LLM_KV_DFLASH_MASK_TOKEN_ID, + LLM_KV_SHORTCONV_L_CACHE, LLM_KV_XIELU_ALPHA_N, @@ -562,6 +567,8 @@ enum llm_tensor { LLM_TENSOR_EAGLE3_FC, // eagle3: feature fusion layer LLM_TENSOR_EAGLE3_HIDDEN_NORM, // eagle3: additional normalization layer LLM_TENSOR_EAGLE3_D2T, // eagle3: draft to target vocabulary mapping + LLM_TENSOR_DFLASH_FC, + LLM_TENSOR_DFLASH_HIDDEN_NORM, }; enum llm_tensor_layer { diff --git 
a/src/llama-context.cpp b/src/llama-context.cpp index d364d927d981..e904db066fd8 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -166,6 +166,7 @@ llama_context::llama_context( cparams.kv_unified = params.kv_unified; cparams.eagle3_extract_enabled = false; + cparams.dflash_extract_enabled = false; // initialized later cparams.pipeline_parallel = false; @@ -347,6 +348,15 @@ llama_context::llama_context( if (cparams.pipeline_parallel) { LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__); } + // temp fix: DFlash encoder/decoder share one model_dft, keep the role on the context + dflash_decoder_ctx = model.arch == LLM_ARCH_DFLASH && params.target_model != nullptr; + // DFlash decoder: pre-fill cross with reservation size so build_inp_cross_embd + // uses cparams.n_ctx instead of hparams.n_ctx_train (which can cause OOM) + if (dflash_decoder_ctx) { + cross.n_embd = hparams.n_embd; + cross.n_enc = cparams.n_ctx; + cross.v_embd.resize(cross.n_embd * cross.n_enc, 0.0f); + } sched_reserve(); @@ -1196,7 +1206,52 @@ void llama_context::set_eagle3(const llama_model * model) { eagle3.extract_layer_indices[2]); } +void llama_context::set_dflash(const llama_model * model) { + cparams.dflash_extract_enabled = !!model; + if (!cparams.dflash_extract_enabled) { + return; + } + + sched_need_reserve = true; + + const auto & dflash_hparams = model->hparams; + + dflash.extract_layer_indices.assign( + dflash_hparams.dflash_target_layer_ids.begin(), + dflash_hparams.dflash_target_layer_ids.end() + ); + + dflash.extract_tensors.resize(dflash.extract_layer_indices.size(), nullptr); + + LLAMA_LOG_INFO("%s: DFlash extraction enabled for layers [%d, %d, %d, %d, %d]\n", __func__, + dflash.extract_layer_indices[0], + dflash.extract_layer_indices[1], + dflash.extract_layer_indices[2], + dflash.extract_layer_indices[3], + dflash.extract_layer_indices[4]); +} + +const float * llama_context::get_dflash_target_features() const { + 
GGML_ASSERT(!dflash.target_features.empty() && "DFlash target features not extracted"); + return dflash.target_features.data(); +} + +void llama_context::set_dflash_accumulated_target_ctx(const float * data, int32_t n_embd, int32_t n_tokens) { + GGML_ASSERT(data != nullptr); + const size_t size = (size_t)n_embd * n_tokens; + // Store in cross struct (reusing T5 style cross-attention for accumulated target features fed to the DFlash decoder) + cross.n_embd = n_embd; + cross.n_enc = n_tokens; + cross.v_embd.resize(size); + std::memcpy(cross.v_embd.data(), data, size * sizeof(float)); +} + llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { + // DFlash decoder runs through encode path due to no kv-cache but it needs decoder graph type + if (model.arch == LLM_ARCH_DFLASH && dflash_decoder_ctx && gtype == LLM_GRAPH_TYPE_ENCODER) { + gtype = LLM_GRAPH_TYPE_DECODER; + } + if (mctx && !mctx->apply()) { LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); ret = GGML_STATUS_FAILED; @@ -1261,6 +1316,22 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } } + // temp fix DFlash: Fill position tensor for decoder + if (model.arch == LLM_ARCH_DFLASH && gtype == LLM_GRAPH_TYPE_DECODER && !cross.v_embd.empty()) { + const int64_t n_ctx = cross.n_enc; + const int64_t n_noise = ubatch.n_tokens; + const int64_t n_total = n_ctx + n_noise; + + ggml_tensor * pos_full = ggml_graph_get_tensor(gf, "inp_pos_full"); + if (pos_full) { + std::vector pos_data(n_total); + for (int64_t i = 0; i < n_total; ++i) { + pos_data[i] = (int32_t)i; + } + ggml_backend_tensor_set(pos_full, pos_data.data(), 0, n_total * sizeof(int32_t)); + } + } + //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); } @@ -1276,6 +1347,10 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll 
extract_eagle3_features(ubatch); } + if (cparams.dflash_extract_enabled && !dflash.extract_tensors.empty()) { + extract_dflash_features(ubatch); + } + ret = GGML_STATUS_SUCCESS; return res; @@ -1291,8 +1366,15 @@ int llama_context::encode(const llama_batch & batch_inp) { const auto & hparams = model.hparams; - // EAGLE3: use 3*target_hidden_size for concatenated features input - const int64_t n_embd = (model.arch == LLM_ARCH_EAGLE3 && batch_inp.embd) ? 3 * hparams.eagle3_target_hidden_size : hparams.n_embd; + // EAGLE3/DFlash: use concatenated features size from target for draft encoder input + int64_t n_embd = hparams.n_embd; + if (batch_inp.embd) { + if (model.arch == LLM_ARCH_EAGLE3) { + n_embd = 3 * hparams.eagle3_target_hidden_size; + } else if (model.arch == LLM_ARCH_DFLASH) { + n_embd = (int64_t) hparams.dflash_target_layer_ids.size() * hparams.n_embd; + } + } const int64_t n_vocab = model.vocab.n_tokens(); // note: during encode, we always pass the full sequence starting from pos = 0 @@ -2213,6 +2295,13 @@ ggml_cgraph * llama_context::graph_reserve( gtype = LLM_GRAPH_TYPE_DECODER; } } + if (model.arch == LLM_ARCH_DFLASH) { + if (cparams.embeddings && !dflash_decoder_ctx) { + gtype = LLM_GRAPH_TYPE_ENCODER; + } else if (dflash_decoder_ctx) { + gtype = LLM_GRAPH_TYPE_DECODER; + } + } const auto gparams = graph_params(res, ubatch, mctx, gtype); res->reset(); @@ -2255,6 +2344,7 @@ llm_graph_params llama_context::graph_params( /*.mctx =*/ mctx, /*.cross =*/ &cross, /*.eagle3 =*/ &eagle3, + /*.dflash =*/ &dflash, /*.samplers =*/ sampling.samplers, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), @@ -2320,6 +2410,20 @@ llm_graph_cb llama_context::graph_get_cb() const { } } + // DFlash: Extract intermediate layer features if this is an extraction point + if (cparams.dflash_extract_enabled) { + static constexpr const char * prefix = "dflash_extract_"; + static constexpr size_t prefix_len = 15; + + if (strncmp(name, prefix, prefix_len) == 0) { + size_t 
extract_idx = 0; + if (sscanf(name + prefix_len, "%zu", &extract_idx) == 1 && extract_idx < dflash.extract_tensors.size()) { + ggml_set_output(cur); + dflash.extract_tensors[extract_idx] = cur; + } + } + } + // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends // FIXME: fix in ggml_backend_sched const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer; @@ -2386,6 +2490,42 @@ void llama_context::extract_eagle3_features(const llama_ubatch & ubatch) { } +void llama_context::extract_dflash_features(const llama_ubatch & ubatch) { + const int64_t n_tokens = ubatch.n_tokens; + const int64_t n_embd = model.hparams.n_embd; + const size_t n_layers = dflash.extract_tensors.size(); + + const int64_t n_embd_concat = n_embd * n_layers; + dflash.target_features.resize(n_embd_concat * n_tokens); + + static thread_local std::vector temp_layer_features; + temp_layer_features.resize(n_embd * n_tokens); + + LLAMA_LOG_DEBUG("%s: Start to extract DFlash features: %zu layers, %lld tokens, %lld embd\n", + __func__, n_layers, (long long)n_tokens, (long long)n_embd); + + for (size_t layer_idx = 0; layer_idx < n_layers; ++layer_idx) { + ggml_tensor * tensor = dflash.extract_tensors[layer_idx]; + GGML_ASSERT(tensor != nullptr && "DFlash extraction tensor is null"); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), tensor); + GGML_ASSERT(backend != nullptr && "DFlash tensor has no backend"); + + GGML_ASSERT(tensor->ne[0] == n_embd && tensor->ne[1] == n_tokens && + "DFlash extraction tensor has unexpected shape"); + + const size_t size_bytes = n_embd * n_tokens * sizeof(float); + ggml_backend_tensor_get_async(backend, tensor, temp_layer_features.data(), 0, size_bytes); + ggml_backend_sched_synchronize(sched.get()); + + for (int64_t token_idx = 0; token_idx < n_tokens; ++token_idx) { + const float * src = temp_layer_features.data() + token_idx * n_embd; + float * dest = 
dflash.target_features.data() + token_idx * n_embd_concat + layer_idx * n_embd; + std::memcpy(dest, src, n_embd * sizeof(float)); + } + } +} + // // state save/load // @@ -3100,6 +3240,13 @@ llama_context * llama_init_from_model( LLAMA_LOG_INFO("%s: EAGLE3 auto-setup: using target model's embedding layer\n", __func__); } + // Auto-setup for DFlash: set target embedding + lm_head if target_model is provided + if (model->arch == LLM_ARCH_DFLASH && params.target_model) { + model->target_tok_embd = params.target_model->tok_embd; + model->target_output = params.target_model->output; + LLAMA_LOG_INFO("%s: DFlash auto-setup: using target model's embedding + lm_head layers\n", __func__); + } + if (params.n_batch == 0 && params.n_ubatch == 0) { LLAMA_LOG_ERROR("%s: n_batch and n_ubatch cannot both be zero\n", __func__); return nullptr; @@ -3391,6 +3538,12 @@ void llama_set_eagle3( ctx->set_eagle3(model); } +void llama_set_dflash( + llama_context * ctx, + const llama_model * model) { + ctx->set_dflash(model); +} + // // memory // @@ -3733,6 +3886,15 @@ void llama_set_eagle3_g_embeddings(llama_context * ctx, const float * g_embd, in ctx->set_eagle3_g_embeddings(g_embd, n_embd, n_tokens); } +const float * llama_get_dflash_target_features(llama_context * ctx) { + return ctx->get_dflash_target_features(); +} + +void llama_set_dflash_accumulated_target_ctx(llama_context * ctx, const float * data, int32_t n_embd, int32_t n_tokens) { + ctx->set_dflash_accumulated_target_ctx(data, n_embd, n_tokens); +} + + // // ext // diff --git a/src/llama-context.h b/src/llama-context.h index 7959c7709a36..86f0d81c0ccf 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -108,6 +108,7 @@ struct llama_context { // TODO: tmp void set_eagle3(const llama_model * model); + void set_dflash(const llama_model * model); // process a single ubatch with a specific graph type // if memory_context is provided, it will be applied first to the context's memory @@ -231,6 +232,12 @@ struct llama_context 
{ // EAGLE3: Set g_embeddings from encoder output for decoder input void set_eagle3_g_embeddings(const float * g_embd, int32_t n_embd, int32_t n_tokens); + // DFlash: Get pointer to target model features extracted for DFlash encoder + const float * get_dflash_target_features() const; + + // DFlash: Set accumulated target_ctx from encoder output for decoder input + void set_dflash_accumulated_target_ctx(const float * data, int32_t n_embd, int32_t n_tokens); + bool set_sampler(llama_seq_id seq_id, llama_sampler * sampler); private: @@ -245,6 +252,9 @@ struct llama_context { // EAGLE3: Extract intermediate layer features from target model void extract_eagle3_features(const llama_ubatch & ubatch); + // DFlash: Extract intermediate layer features from target model + void extract_dflash_features(const llama_ubatch & ubatch); + // TODO: read/write lora adapters and cvec size_t state_write_data(llama_io_write_i & io); size_t state_read_data (llama_io_read_i & io); @@ -268,6 +278,12 @@ struct llama_context { mutable llama_eagle3 eagle3; // EAGLE3 draft model support - stores features from target model // mutable because it's modified during graph building (const function) + mutable llama_dflash dflash; + + // temp fix: avoid DFlash encoder/decoder mis-detection. They share one model_dft, + // so shared model fields cannot safely identify the decoder (caused OOM). 
+ bool dflash_decoder_ctx = false; + std::unique_ptr memory; // decode output (2-dimensional array: [n_outputs][n_vocab]) diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 48ab113bacbd..906bfbe36c12 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -39,6 +39,7 @@ struct llama_cparams { bool op_offload; bool kv_unified; bool eagle3_extract_enabled; // enable layer extraction for EAGLE3 speculative decoding + bool dflash_extract_enabled; // enable layer extraction for DFlash speculative decoding bool pipeline_parallel; enum llama_pooling_type pooling_type; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 1ff22fb9b20d..9fabd242e766 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -947,6 +947,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : mctx (params.mctx), cross (params.cross), eagle3 (params.eagle3), + dflash (params.dflash), samplers (params.samplers), cb_func (params.cb), res (params.res), diff --git a/src/llama-graph.h b/src/llama-graph.h index b56077e9c509..1925a275d8a3 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -97,6 +97,20 @@ struct llama_eagle3 { } }; +// DFlash intermediate results struct (similar to Eagle3) +struct llama_dflash { + std::vector extract_layer_indices; + + std::vector target_features; + + std::vector extract_tensors; + + void clear() { + target_features.clear(); + extract_tensors.clear(); + } +}; + struct llm_graph_params; // @@ -569,6 +583,7 @@ struct llm_graph_params { const llama_memory_context_i * mctx; const llama_cross * cross; llama_eagle3 * eagle3; // non-const: we write extracted features here + llama_dflash * dflash; std::map samplers; @@ -784,6 +799,7 @@ struct llm_graph_context { const llama_memory_context_i * mctx; const llama_cross * cross; llama_eagle3 * eagle3; // non-const: we write extracted features here + llama_dflash * dflash; std::map samplers; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index fd12a597d016..fdd5a03bea80 
100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -220,6 +220,11 @@ struct llama_hparams { // EAGLE3 draft model - apply hidden_norm before storing residual bool eagle3_norm_before_residual = false; + // DFlash draft model + std::array dflash_target_layer_ids = {}; + uint32_t dflash_block_size = 16; + uint32_t dflash_mask_token_id = 0; + // gemma4 per-layer embedding uint32_t n_embd_per_layer = 0; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 4e65a45a50d8..832fc990c895 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -503,6 +503,7 @@ namespace GGUFMeta { // TODO: this is not very clever - figure out something better template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); // store DFlash 5 layer ids template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index d3b3a1560b45..5336ea8ee48b 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2785,6 +2785,28 @@ void llama_model::load_hparams(llama_model_loader & ml) { LLAMA_LOG_INFO("%s: EAGLE3 norm_before_residual = true\n", __func__); } + type = LLM_TYPE_UNKNOWN; + } break; + case LLM_ARCH_DFLASH: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + ml.get_key(LLM_KV_DFLASH_BLOCK_SIZE, hparams.dflash_block_size, false); + ml.get_key(LLM_KV_DFLASH_MASK_TOKEN_ID, hparams.dflash_mask_token_id, false); + + if (!ml.get_key_or_arr(LLM_KV_DFLASH_TARGET_LAYER_IDS, hparams.dflash_target_layer_ids, 5, false)) { + throw std::runtime_error("DFlash model requires 'target_layer_ids' in GGUF metadata"); + } + 
LLAMA_LOG_INFO("%s: DFlash extract_layers = [%d, %d, %d, %d, %d]\n", __func__, + hparams.dflash_target_layer_ids[0], + hparams.dflash_target_layer_ids[1], + hparams.dflash_target_layer_ids[2], + hparams.dflash_target_layer_ids[3], + hparams.dflash_target_layer_ids[4]); + + LLAMA_LOG_INFO("%s: DFlash block_size = %u, mask_token_id = %u\n", + __func__, hparams.dflash_block_size, hparams.dflash_mask_token_id); + type = LLM_TYPE_UNKNOWN; } break; case LLM_ARCH_COGVLM: @@ -7341,6 +7363,39 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED); } } break; + case LLM_ARCH_DFLASH: + { + const int64_t n_target_layer_ids = (int64_t)hparams.dflash_target_layer_ids.size(); + const int64_t n_embd_target_features = n_target_layer_ids * n_embd; + + fc = create_tensor(tn(LLM_TENSOR_DFLASH_FC, "weight"), {n_embd_target_features, n_embd}, 0); + dflash_hidden_norm = create_tensor(tn(LLM_TENSOR_DFLASH_HIDDEN_NORM, "weight"), {n_embd}, 0); + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd_head_k * n_head}, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa}, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa}, TENSOR_NOT_REQUIRED); + layer.bo = 
create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; case LLM_ARCH_KIMI_LINEAR: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -8528,6 +8583,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_LLADA: case LLM_ARCH_LLADA_MOE: case LLM_ARCH_RND1: + case LLM_ARCH_DFLASH: // current DFlash decoder doesn't support KV-cache due to cross_attn + self_attn (no mask) { res = nullptr; } break; @@ -9126,6 +9182,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } } break; + case LLM_ARCH_DFLASH: + { + if (params.gtype == LLM_GRAPH_TYPE_ENCODER) { + llm = std::make_unique(*this, params); + } else { + llm = std::make_unique(*this, params); + } + } break; case LLM_ARCH_COGVLM: { llm = std::make_unique(*this, params); @@ -9256,6 +9320,14 @@ int32_t llama_model_n_swa(const llama_model * model) { return model->hparams.n_swa; } +int32_t llama_model_dflash_block_size(const llama_model * model) { + return (int32_t) model->hparams.dflash_block_size; +} + +int32_t llama_model_dflash_mask_token_id(const llama_model * model) { + return (int32_t) model->hparams.dflash_mask_token_id; +} + uint32_t llama_model_n_cls_out(const struct llama_model * model) { return model->hparams.n_cls_out; } @@ -9370,6 +9442,7 @@ llama_rope_type 
llama_model_rope_type(const llama_model * model) { case LLM_ARCH_QWEN2MOE: case LLM_ARCH_QWEN3: case LLM_ARCH_QWEN3MOE: + case LLM_ARCH_DFLASH: case LLM_ARCH_LLADA_MOE: case LLM_ARCH_RND1: case LLM_ARCH_OLMO2: diff --git a/src/llama-model.h b/src/llama-model.h index 1ad1c2af7c4a..199cc45ca496 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -560,6 +560,10 @@ struct llama_model { // This allows EAGLE3 to use target model's embeddings without copying struct ggml_tensor * target_tok_embd = nullptr; + // dflash + struct ggml_tensor * dflash_hidden_norm = nullptr; + struct ggml_tensor * target_output = nullptr; // reference to target model's lm_head + std::vector layers; //Dense linear projections for SentenceTransformers models like embeddinggemma diff --git a/src/models/dflash.cpp b/src/models/dflash.cpp new file mode 100644 index 000000000000..82a13ccc8888 --- /dev/null +++ b/src/models/dflash.cpp @@ -0,0 +1,161 @@ +#include "models.h" + +ggml_tensor * llm_build_dflash_encode::build_inp_embd() const { + const int64_t n_target_layer_ids = (int64_t) hparams.dflash_target_layer_ids.size(); + const int64_t n_embd_target_features = n_target_layer_ids * n_embd; + + auto inp_target = std::make_unique(n_embd_target_features); + inp_target->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_target_features, n_tokens); + ggml_set_input(inp_target->embd); + + ggml_tensor * cur = inp_target->embd; + cb(cur, "inp_embd", -1); + + res->add_input(std::move(inp_target)); + + return cur; +} + +llm_build_dflash_encode::llm_build_dflash_encode(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + ggml_tensor * cur = build_inp_embd(); + + cur = build_lora_mm(model.fc, cur); + cb(cur, "fc_out", -1); + + cur = build_norm(cur, model.dflash_hidden_norm, NULL, LLM_NORM_RMS, -1); + cb(cur, "hidden_norm_out", -1); + + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); +} + +llm_build_dflash_decode::llm_build_dflash_decode(const 
llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + // Noise tokens [MASK] + GGML_ASSERT(model.target_tok_embd != nullptr && "DFlash decoder requires target model's tok_embd"); + ggml_tensor * noise_embd = build_inp_embd(model.target_tok_embd); + cb(noise_embd, "inp_noise_embd", -1); + + // Target context via llama_cross (filled from accumulated_target_ctx), graph rebuilds every step + ggml_tensor * target_ctx = build_inp_cross_embd(); + const int64_t n_ctx = target_ctx->ne[1]; + + ggml_tensor * inpL = noise_embd; + + const int64_t n_tokens_kv = n_ctx + n_tokens; + + // Position tensor covering target_ctx + noise + ggml_tensor * inp_pos_full = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens_kv); + ggml_set_input(inp_pos_full); + cb(inp_pos_full, "inp_pos_full", -1); + + // Q positions: last n_tokens entries (noise only) + ggml_tensor * inp_pos_q = ggml_view_1d(ctx0, inp_pos_full, n_tokens, + n_ctx * ggml_element_size(inp_pos_full)); + + const float kq_scale = 1.0f/sqrtf(float(n_embd_head)); + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers[il]; + + ggml_tensor * noise_norm = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il); + cb(noise_norm, "noise_norm", il); + + // Q from noise only + ggml_tensor * Qcur = build_lora_mm(layer.wq, noise_norm); + if (layer.bq) { Qcur = ggml_add(ctx0, Qcur, layer.bq); } + cb(Qcur, "Qcur", il); + + // K = concat(k_proj(target_ctx), k_proj(noise)) + ggml_tensor * K_tgt = build_lora_mm(layer.wk, target_ctx); + ggml_tensor * K_noise = build_lora_mm(layer.wk, noise_norm); + if (layer.bk) { + K_tgt = ggml_add(ctx0, K_tgt, layer.bk); + K_noise = ggml_add(ctx0, K_noise, layer.bk); + } + ggml_tensor * Kcur = ggml_concat(ctx0, K_tgt, K_noise, 1); + cb(Kcur, "Kcur", il); + + // V = concat(v_proj(target_ctx), v_proj(noise)) + ggml_tensor * V_tgt = 
build_lora_mm(layer.wv, target_ctx); + ggml_tensor * V_noise = build_lora_mm(layer.wv, noise_norm); + if (layer.bv) { + V_tgt = ggml_add(ctx0, V_tgt, layer.bv); + V_noise = ggml_add(ctx0, V_noise, layer.bv); + } + ggml_tensor * Vcur = ggml_concat(ctx0, V_tgt, V_noise, 1); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens_kv); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens_kv); + + Qcur = build_norm(Qcur, layer.attn_q_norm, NULL, LLM_NORM_RMS, il); + Kcur = build_norm(Kcur, layer.attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + + // RoPE: K uses full positions [0..n_ctx+n_tokens-1], Q uses last n_tokens + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos_full, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur_rope", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos_q, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur_rope", il); + + // Full attention (no causal mask) + ggml_build_forward_expand(gf, Qcur); + ggml_build_forward_expand(gf, Kcur); + ggml_build_forward_expand(gf, Vcur); + + ggml_tensor * cur = build_attn_mha(Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "kqv_out", il); + + cur = build_lora_mm(layer.wo, cur); + if (layer.bo) { cur = ggml_add(ctx0, cur, layer.bo); } + cur = ggml_add(ctx0, cur, inpL); + cb(cur, "attn_res", il); + + ggml_tensor * ffn_inp = cur; + cur = build_norm(cur, layer.ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + layer.ffn_up, NULL, NULL, + layer.ffn_gate, NULL, NULL, + layer.ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + 
cb(cur, "l_out", il); + + inpL = cur; + } + + ggml_tensor * cur = inpL; + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + res->t_embd = cur; + + if (model.target_output) { + cur = build_lora_mm(model.target_output, cur); + cb(cur, "result_output", -1); + res->t_logits = cur; + } + + ggml_build_forward_expand(gf, cur); +} \ No newline at end of file diff --git a/src/models/models.h b/src/models/models.h index 612ab73bd3d6..062e6ff621d2 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -212,6 +212,16 @@ struct llm_build_eagle3_decode : public llm_graph_context { llm_build_eagle3_decode(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_dflash_encode : public llm_graph_context { + llm_build_dflash_encode(const llama_model & model, const llm_graph_params & params); +private: + ggml_tensor * build_inp_embd() const; +}; + +struct llm_build_dflash_decode : public llm_graph_context { + llm_build_dflash_decode(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_ernie4_5 : public llm_graph_context { llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params); }; diff --git a/src/models/openai-moe-iswa.cpp b/src/models/openai-moe-iswa.cpp index 10b86f255d29..db7492a68364 100644 --- a/src/models/openai-moe-iswa.cpp +++ b/src/models/openai-moe-iswa.cpp @@ -30,6 +30,19 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, } } + // DFlash: Extract intermediate layer features from target model at layer INPUT + if (dflash && cparams.dflash_extract_enabled && !dflash->extract_layer_indices.empty()) { + static const char * dflash_extract_names[] = { + "dflash_extract_0", "dflash_extract_1", "dflash_extract_2", + "dflash_extract_3", "dflash_extract_4" + }; + for (size_t i = 0; i < dflash->extract_layer_indices.size() && i < 5; ++i) { + if (dflash->extract_layer_indices[i] == il) { + cb(inpL, 
dflash_extract_names[i], il); + break; + } + } + } // norm cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, diff --git a/src/models/qwen3.cpp b/src/models/qwen3.cpp index 88f239187197..fa8c39402268 100644 --- a/src/models/qwen3.cpp +++ b/src/models/qwen3.cpp @@ -32,6 +32,20 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para } } + // DFlash: Extract intermediate layer features from target model at layer INPUT + if (dflash && cparams.dflash_extract_enabled && !dflash->extract_layer_indices.empty()) { + static const char * dflash_extract_names[] = { + "dflash_extract_0", "dflash_extract_1", "dflash_extract_2", + "dflash_extract_3", "dflash_extract_4" + }; + for (size_t i = 0; i < dflash->extract_layer_indices.size() && i < 5; ++i) { + if (dflash->extract_layer_indices[i] == il) { + cb(inpL, dflash_extract_names[i], il); + break; + } + } + } + // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index 87790f08e4ee..19d3d95619d0 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -26,6 +26,20 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; + // DFlash: Extract intermediate layer features from target model + if (dflash && cparams.dflash_extract_enabled && !dflash->extract_layer_indices.empty()) { + static const char * dflash_extract_names[] = { + "dflash_extract_0", "dflash_extract_1", "dflash_extract_2", + "dflash_extract_3", "dflash_extract_4" + }; + for (size_t i = 0; i < dflash->extract_layer_indices.size() && i < 5; ++i) { + if (dflash->extract_layer_indices[i] == il) { + cb(inpL, dflash_extract_names[i], il); + break; + } + } + } + cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); From 85a0089e60ff26bf47158471208847a22f6eb3e0 Mon Sep 17 00:00:00 2001 From: Ruixiang Wang Date: Sun, 19 
Apr 2026 15:05:02 +0000 Subject: [PATCH 14/21] dflash: add support for qwen3.5/3.6 moe models --- src/models/qwen35moe.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index 7dc6a23c7518..b367bdecf365 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -26,6 +26,20 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; + // DFlash: Extract intermediate layer features from target model + if (dflash && cparams.dflash_extract_enabled && !dflash->extract_layer_indices.empty()) { + static const char * dflash_extract_names[] = { + "dflash_extract_0", "dflash_extract_1", "dflash_extract_2", + "dflash_extract_3", "dflash_extract_4" + }; + for (size_t i = 0; i < dflash->extract_layer_indices.size() && i < 5; ++i) { + if (dflash->extract_layer_indices[i] == il) { + cb(inpL, dflash_extract_names[i], il); + break; + } + } + } + cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); From e344c4a71736e1cdaa25e590a109f694dfb8119f Mon Sep 17 00:00:00 2001 From: Ruixiang Wang Date: Fri, 24 Apr 2026 16:57:28 +0000 Subject: [PATCH 15/21] dflash: remove redundant logic & correct bias naming --- .../speculative-simple/speculative-simple.cpp | 41 ------------------- src/llama-model.cpp | 8 ++-- src/models/dflash.cpp | 16 ++++---- 3 files changed, 12 insertions(+), 53 deletions(-) diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 003cc217f493..804a16623a41 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -232,21 +232,6 @@ int main(int argc, char ** argv) { const auto t_dec_start = ggml_time_us(); - // Hybrid targets (e.g. Qwen3.5) have recurrent layers that cannot be partially rolled back via seq_rm.
- // For them, snapshot the target state before verify and, on rejection, restore it and replay only the accepted tokens to ensure correctness - // This is not efficient because the target model may run twice, but it is required in current llama.cpp design - const bool use_state_snapshot = params.speculative.dflash && llama_model_is_hybrid(model_tgt); - if (params.speculative.dflash) { - LOG_INF("%s: DFlash target=%s, using %s rollback path\n", __func__, - llama_model_is_hybrid(model_tgt) ? "hybrid" : "pure-attention", - use_state_snapshot ? "snapshot+restore" : "seq_rm"); - } - std::vector state_snap; - if (use_state_snapshot) { - const size_t sz = llama_state_seq_get_size(ctx_tgt, 0); - state_snap.resize(sz); - } - while (true) { // generate or reuse draft tokens // @@ -294,17 +279,6 @@ int main(int argc, char ** argv) { GGML_ASSERT(n_draft > 0); - // snapshot target state for potential rollback (hybrid/recurrent targets only) - const int n_past_before = n_past; - const llama_token id_last_saved = id_last; - if (use_state_snapshot) { - const size_t sz = llama_state_seq_get_size(ctx_tgt, 0); - if (sz > state_snap.size()) { - state_snap.resize(sz); - } - llama_state_seq_get_data(ctx_tgt, state_snap.data(), sz, 0); - } - // always have a token to evaluate from before - id_last common_batch_clear(batch_tgt); common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true); @@ -403,21 +377,6 @@ int main(int argc, char ** argv) { draft.clear(); { - // const bool had_rejection = ids.size() < draft.size() + 1; - - // if (use_state_snapshot && had_rejection) { - // // Restore snapshot and replay the committed prefix (id_last + accepted drafts) so target state exactly - // LOG_DBG("DFlash rollback: restore target state and replay %zu tokens\n", ids.size()); - // llama_state_seq_set_data(ctx_tgt, state_snap.data(), state_snap.size(), 0); - // common_batch_clear(batch_tgt); - // common_batch_add(batch_tgt, id_last_saved, n_past_before, { 0 }, true); - // for (size_t i = 0; i + 1 
< ids.size(); ++i) { - // common_batch_add(batch_tgt, ids[i], n_past_before + 1 + i, { 0 }, true); - // } - // if (batch_tgt.n_tokens > 0) { - // llama_decode(ctx_tgt, batch_tgt); - // } - // } else { LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 5336ea8ee48b..47668954f59c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -7382,10 +7382,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd_head_k * n_head}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd_head_k * n_head}, TENSOR_NOT_REQUIRED); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa}, TENSOR_NOT_REQUIRED); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); diff --git a/src/models/dflash.cpp b/src/models/dflash.cpp index 82a13ccc8888..0adba127eabf 100644 --- a/src/models/dflash.cpp +++ b/src/models/dflash.cpp @@ -67,15 +67,15 @@ llm_build_dflash_decode::llm_build_dflash_decode(const 
llama_model & model, cons // Q from noise only ggml_tensor * Qcur = build_lora_mm(layer.wq, noise_norm); - if (layer.bq) { Qcur = ggml_add(ctx0, Qcur, layer.bq); } + if (layer.wq_b) { Qcur = ggml_add(ctx0, Qcur, layer.wq_b); } cb(Qcur, "Qcur", il); // K = concat(k_proj(target_ctx), k_proj(noise)) ggml_tensor * K_tgt = build_lora_mm(layer.wk, target_ctx); ggml_tensor * K_noise = build_lora_mm(layer.wk, noise_norm); - if (layer.bk) { - K_tgt = ggml_add(ctx0, K_tgt, layer.bk); - K_noise = ggml_add(ctx0, K_noise, layer.bk); + if (layer.wk_b) { + K_tgt = ggml_add(ctx0, K_tgt, layer.wk_b); + K_noise = ggml_add(ctx0, K_noise, layer.wk_b); } ggml_tensor * Kcur = ggml_concat(ctx0, K_tgt, K_noise, 1); cb(Kcur, "Kcur", il); @@ -83,9 +83,9 @@ llm_build_dflash_decode::llm_build_dflash_decode(const llama_model & model, cons // V = concat(v_proj(target_ctx), v_proj(noise)) ggml_tensor * V_tgt = build_lora_mm(layer.wv, target_ctx); ggml_tensor * V_noise = build_lora_mm(layer.wv, noise_norm); - if (layer.bv) { - V_tgt = ggml_add(ctx0, V_tgt, layer.bv); - V_noise = ggml_add(ctx0, V_noise, layer.bv); + if (layer.wv_b) { + V_tgt = ggml_add(ctx0, V_tgt, layer.wv_b); + V_noise = ggml_add(ctx0, V_noise, layer.wv_b); } ggml_tensor * Vcur = ggml_concat(ctx0, V_tgt, V_noise, 1); cb(Vcur, "Vcur", il); @@ -123,7 +123,7 @@ llm_build_dflash_decode::llm_build_dflash_decode(const llama_model & model, cons cb(cur, "kqv_out", il); cur = build_lora_mm(layer.wo, cur); - if (layer.bo) { cur = ggml_add(ctx0, cur, layer.bo); } + if (layer.wo_b) { cur = ggml_add(ctx0, cur, layer.wo_b); } cur = ggml_add(ctx0, cur, inpL); cb(cur, "attn_res", il); From 67cb0d507080e42cc012ac0bdb8f09622f64455b Mon Sep 17 00:00:00 2001 From: Ruixiang Wang Date: Mon, 27 Apr 2026 11:57:22 +0000 Subject: [PATCH 16/21] dflash: enable llama-cli & llama-server with np=1 --- common/arg.cpp | 2 +- tools/server/server-context.cpp | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/common/arg.cpp 
b/common/arg.cpp index 70d90f97347c..03596ced4d86 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3480,7 +3480,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.speculative.dflash = true; } - ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_CLI})); + ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-cd", "--ctx-size-draft"}, "N", string_format("size of the prompt context for the draft model (default: %d, 0 = loaded from model)", params.speculative.n_ctx), diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 65bc356ad825..c835dd8a44c2 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -813,6 +813,16 @@ struct server_context_impl { llama_set_eagle3(ctx, model_dft.get()); SRV_INF("%s", "EAGLE3 feature extraction enabled on target model\n"); } + + if (params_base.speculative.dflash) { + // DFlash current limitation: extracted target features are per-context; multiple slots would overwrite each other + if (params_base.n_parallel > 1) { + SRV_ERR("%s", "DFlash speculative decoding is not supported with n_parallel > 1\n"); + return false; + } + llama_set_dflash(ctx, model_dft.get()); + SRV_INF("%s", "DFlash feature extraction enabled on target model\n"); + } } std::string & mmproj_path = params_base.mmproj.path; From 10508e7408cfd946b6e9547bf402626a80a19597 Mon Sep 17 00:00:00 2001 From: Aleksandr Nikolich Date: Fri, 12 Jun 2026 18:11:16 +0200 Subject: [PATCH 17/21] convert: map the Qwen3.5-4B multimodal tokenizer hash to the qwen35 pre-tokenizer Qwen/Qwen3.5-4B is Qwen3_5ForConditionalGeneration (multimodal). Without this mapping neither the target nor the DFlash drafter converts to GGUF. 
--- convert_hf_to_gguf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 6a5ac25d945d..750eb3110289 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1537,6 +1537,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4": # ref: https://huggingface.co/Qwen/Qwen3.5-9B-Instruct res = "qwen35" + if chkhsh == "1444df51289cfa8063b96f0e62b1125440111bc79a52003ea14b6eac7016fd5f": + # ref: https://huggingface.co/Qwen/Qwen3.5-4B (multimodal text tokenizer; same qwen35 split regex) + res = "qwen35" if chkhsh == "b4b8ca1f9769494fbd956ebc4c249de6131fb277a4a3345a7a92c7dd7a55808d": # ref: https://huggingface.co/jdopensource/JoyAI-LLM-Flash res = "joyai-llm" From 6bbbeacd1f72fdcc3f95944ae0677cd1cab50ab9 Mon Sep 17 00:00:00 2001 From: Aleksandr Nikolich Date: Fri, 12 Jun 2026 18:11:16 +0200 Subject: [PATCH 18/21] dflash: make speculative decoding work and fast on the Qwen3.5-4B hybrid The DFlash drafter targets a Gated-DeltaNet hybrid (recurrent + attention). The recurrent state can't be partial-rolled-back, so a naive verify is slower than plain generation. This brings it to lossless speedup: - recurrent-state rewind via a per-token GDN state trace + on-device promote of the accepted state (llama_dflash_promote_state) instead of a ~50 MiB host checkpoint and re-decode per round - graph reuse: fixed-capacity device-resident target-context cache, encoder folded into the decoder graph, padding mask over a bucketed context - on-device greedy verify: drafter block argmax + target argmax for the greedy verify (llama_set_out_argmax), one host sync per round - optional GPU sampling verify (temperature; top-k/top-p behind LLAMA_SPEC_GPU_SAMPLE) - CUDA graphs opt-in on Volta (GGML_CUDA_GRAPHS_VOLTA) and a stable sched uid Lossless. ~1.7x on V100/Q8 single-stream, scaling with the draft block on high-acceptance (reasoning) workloads. 
--- bw_full.sh | 43 ++ common/speculative.cpp | 72 +-- common/speculative.h | 3 + .../speculative-simple/speculative-simple.cpp | 328 +++++++++- ggml/include/ggml.h | 15 + ggml/src/ggml-cpu/ops.cpp | 6 + ggml/src/ggml-cuda/gated_delta_net.cu | 31 +- ggml/src/ggml-cuda/ggml-cuda.cu | 20 +- ggml/src/ggml.c | 27 + h100_bench.sh | 30 + h100_full.sh | 44 ++ include/llama.h | 83 +++ src/llama-context.cpp | 579 +++++++++++++++++- src/llama-context.h | 57 ++ src/llama-cparams.h | 4 + src/llama-graph.cpp | 102 ++- src/llama-graph.h | 46 ++ src/models/delta-net-base.cpp | 11 +- src/models/dflash.cpp | 81 ++- src/models/models.h | 4 + src/models/qwen35.cpp | 32 + tools/server/server-context.cpp | 47 +- 22 files changed, 1585 insertions(+), 80 deletions(-) create mode 100644 bw_full.sh create mode 100644 h100_bench.sh create mode 100644 h100_full.sh diff --git a/bw_full.sh b/bw_full.sh new file mode 100644 index 000000000000..f516319f40a1 --- /dev/null +++ b/bw_full.sh @@ -0,0 +1,43 @@ +#!/bin/bash +# Self-contained RTX PRO 6000 Blackwell (sm_120) DFlash verification. Runs inside the pod. 
+set -e +echo "=== GPU ==="; nvidia-smi --query-gpu=name,compute_cap,memory.total,driver_version --format=csv,noheader +export DEBIAN_FRONTEND=noninteractive +apt-get update -qq && apt-get install -y -qq cmake build-essential git python3-pip >/dev/null 2>&1 || true + +cd /workspace 2>/dev/null || cd /root +[ -d llama.cpp ] || git clone -q -b work-qwen35-dflash https://github.com/AlexWortega/llama.cpp.git +cd llama.cpp +pip install -q numpy sentencepiece transformers safetensors gguf protobuf hf_transfer 2>/dev/null || true + +mkdir -p models hf +export HF_HUB_ENABLE_HF_TRANSFER=1 +echo "=== download HF models ===" +python3 -c "from huggingface_hub import snapshot_download as s; s('Qwen/Qwen3.5-4B', local_dir='hf/tgt'); s('z-lab/Qwen3.5-4B-DFlash', local_dir='hf/dft')" 2>&1 | tail -1 + +echo "=== convert ===" +python3 convert_hf_to_gguf.py hf/tgt --outfile models/tgt-f16.gguf --outtype f16 >/tmp/cv1.log 2>&1 && echo tgt-ok || { echo TGT_FAIL; tail -8 /tmp/cv1.log; exit 1; } +python3 convert_hf_to_gguf.py hf/dft --outfile models/Qwen3.5-4B-DFlash-f16.gguf --outtype f16 >/tmp/cv2.log 2>&1 && echo dft-ok || { echo DFT_FAIL; tail -8 /tmp/cv2.log; exit 1; } + +echo "=== build (Blackwell sm_120) ===" +cmake -B build -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=120 -DLLAMA_CURL=OFF >/tmp/cm.log 2>&1 +cmake --build build --target llama-speculative-simple llama-quantize llama-cli -j $(nproc) >/tmp/build.log 2>&1 && echo BUILT || { echo BUILDFAIL; tail -20 /tmp/build.log; exit 1; } + +./build/bin/llama-quantize models/tgt-f16.gguf models/Qwen3.5-4B-Q8_0.gguf Q8_0 >/dev/null 2>&1 && echo quantized +M="-m models/Qwen3.5-4B-Q8_0.gguf -md models/Qwen3.5-4B-DFlash-f16.gguf --dflash -ngl 99 -ngld 99 -p Tell-me-about-the-water-cycle-in-detail. -n 200 -c 2048 --draft-max 5 --temp 0 --top-k 1 --samplers top_k" +BIN=./build/bin/llama-speculative-simple + +echo "=== AR baseline ===" +./build/bin/llama-cli -m models/Qwen3.5-4B-Q8_0.gguf -ngl 99 -p "Tell me about the water cycle in detail." 
-n 200 -c 2048 --temp 0 -no-cnv 2>/tmp/ar.err >/dev/null || true +tr "\r" "\n" < /tmp/ar.err | grep -oE "[0-9.]+ tokens per second" | tail -1 + +echo "=== DFlash full stack (trace+gpuverify+async) ===" +LLAMA_SPEC_TRACE=1 LLAMA_SPEC_GPU_VERIFY=1 LLAMA_SPEC_ASYNC=1 $BIN $M >/tmp/df.txt 2>/tmp/df.err || true +tr "\r" "\n" < /tmp/df.err | grep -oE "speed: +[0-9.]+|accept += +[0-9.]+%" | tail -2 + +echo "=== DFlash + Blackwell CUDA graphs (sm_120 >= Ampere: engage by default) ===" +LLAMA_SPEC_TRACE=1 LLAMA_SPEC_GPU_VERIFY=1 $BIN $M >/tmp/dfg.txt 2>/tmp/dfg.err || true +tr "\r" "\n" < /tmp/dfg.err | grep -oE "speed: +[0-9.]+" | tail -1 + +echo "=== sample ==="; tail -c 160 /tmp/df.txt +echo "=== DONE ===" diff --git a/common/speculative.cpp b/common/speculative.cpp index 4980c03da62e..378776179b34 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -780,32 +780,12 @@ struct common_speculative_state_dflash : public common_speculative_state { GGML_ASSERT(n >= 1 && "prompt_tgt is empty"); GGML_ASSERT(n_new >= 1 && "must have at least 1 new token"); - // Step 1: Encode new accepted tokens' features + // Steps 1+2 folded into the decoder graph: stage the NEW tokens' raw target features; + // the decoder encodes them (fc+norm) in-graph and appends into the device-resident + // context cache. No separate encoder pass, no host round trip of encoded features. 
+ GGML_UNUSED(n_embd); const float * features = llama_get_dflash_target_features(ctx_tgt); - - llama_batch enc_batch = { - /*.n_tokens =*/ n_new, - /*.token =*/ nullptr, - /*.embd =*/ const_cast(features), - /*.pos =*/ nullptr, - /*.n_seq_id =*/ nullptr, - /*.seq_id =*/ nullptr, - /*.logits =*/ nullptr, - }; - if (llama_encode(ctx_dft_enc, enc_batch) != 0) { - LOG_ERR("DFlash: encoder failed\n"); - return; - } - - const float * target_ctx_new = llama_get_embeddings(ctx_dft_enc); - GGML_ASSERT(target_ctx_new && "encoder output is null"); - - // Step 2: Append to accumulated target_ctx and set on decoder context (writes to cross.v_embd) - const size_t new_size = (size_t)n_embd * n_new; - accumulated_ctx.insert(accumulated_ctx.end(), target_ctx_new, target_ctx_new + new_size); - - const int n_ctx_total = (int)(accumulated_ctx.size() / n_embd); - llama_set_dflash_accumulated_target_ctx(ctx_dft_dec, accumulated_ctx.data(), n_embd, n_ctx_total); + llama_dflash_append_features(ctx_dft_dec, features, n_new, n); // Step 3: Decode noise block const llama_token mask_token_id = llama_model_dflash_mask_token_id(llama_get_model(ctx_dft_dec)); @@ -813,7 +793,9 @@ struct common_speculative_state_dflash : public common_speculative_state { common_batch_clear(batch); for (int i = 0; i < block_size; i++) { const llama_token tok = (i == 0) ? 
id_last : mask_token_id; - common_batch_add(batch, tok, i, {0}, true); + // logits=false: the greedy draft tokens come from the on-device argmax below, so the + // n_vocab x block logits host copy (~5 MB/round at vocab 248k) is skipped entirely + common_batch_add(batch, tok, i, {0}, false); } if (llama_decode(ctx_dft_dec, batch) != 0) { @@ -823,18 +805,26 @@ struct common_speculative_state_dflash : public common_speculative_state { dflash_n_past = n; - // Step 4: Sample draft tokens from positions 1..block_size-1 + // Step 4: greedy top-1 draft tokens from the decoder's on-device argmax (the DFlash decode + // graph appends a GGML argmax node over the block logits). The drafted tokens are verified + // by the target regardless, so this cannot affect correctness, only draft latency. result.clear(); - common_sampler_reset(smpl); - for (int i = 1; i < block_size; i++) { - common_sampler_sample(smpl, ctx_dft_dec, i); - - const auto * cur_p = common_sampler_get_candidates(smpl, true); - const llama_token id = cur_p->data[0].id; + // async feed (LLAMA_SPEC_ASYNC=1): hand the draft tokens to the target device-to-device + // and return placeholders - the host reads the actual values only after the verify decode + // (one synchronization per round instead of two; the GPU queues draft+verify back-to-back) + static const bool async_feed = getenv("LLAMA_SPEC_ASYNC") != nullptr && + std::string(getenv("LLAMA_SPEC_ASYNC")) != "0"; + if (async_feed && llama_dflash_feed_draft_tokens(ctx_tgt, ctx_dft_dec, block_size - 1)) { + result.assign(block_size - 1, 0); // placeholders; refilled by the caller post-verify + return; + } - common_sampler_accept(smpl, id, true); - result.push_back(id); + int32_t n_am = 0; + const int32_t * am = llama_get_dflash_argmax(ctx_dft_dec, &n_am); + GGML_ASSERT(am != nullptr && n_am >= block_size && "DFlash decoder did not produce argmax"); + for (int i = 1; i < block_size; i++) { + result.push_back((llama_token) am[i]); } } @@ -1125,6 +1115,18 @@ struct 
common_speculative { common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) }; +llama_context * common_speculative_get_dflash_decoder(common_speculative * spec) { + if (spec == nullptr) { + return nullptr; + } + for (auto & impl : spec->impls) { + if (impl->type == COMMON_SPECULATIVE_TYPE_DFLASH) { + return static_cast(impl.get())->ctx_dft_dec; + } + } + return nullptr; +} + static common_ngram_map get_common_ngram_map(const common_speculative_config & config) { uint16_t size_key = config.params.ngram_size_n; uint16_t size_value = config.params.ngram_size_m; diff --git a/common/speculative.h b/common/speculative.h index bca78d32b5b3..9097f0e3d7db 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -33,6 +33,9 @@ llama_tokens common_speculative_draft( // informs the speculative decoder that n_accepted tokens were accepted by the target model void common_speculative_accept(common_speculative * spec, uint16_t n_accepted); +// the DFlash decoder context, if a DFlash implementation is active (nullptr otherwise) +llama_context * common_speculative_get_dflash_decoder(common_speculative * spec); + // print statistics about the speculative decoding void common_speculative_print_stats(const common_speculative * spec); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 804a16623a41..8e226b99575b 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -13,6 +13,90 @@ #include #include #include +#include +#include +#include + +// Build the filtered candidate distribution for one verify row from its top-K (token, logit) +// pairs: apply temperature, softmax, top-k and top-p, renormalize. Fills `cand` with the kept +// (token, prob) pairs sorted by descending prob. 
This is the exact target sampler restricted to +// the top-K candidates - the top-p nucleus is a subset of the top-K for any realistic params. +static void spec_build_candidates( + const int32_t * ids, const float * logits, int32_t k, int32_t n_vocab, + float temp, int32_t top_k, float top_p, + std::vector> & cand) { + // store RAW (un-temped) logits; the sampler order matches llama.cpp's "top_k;top_p;temp" chain: + // top_k and top_p operate on the pre-temperature logits/probs, temperature is applied last. + cand.clear(); + cand.reserve(k); + for (int32_t j = 0; j < k; ++j) { + // skip padding slots: the logits vocab can be padded beyond the real vocabulary + if (ids[j] < 0 || ids[j] >= n_vocab) { continue; } + cand.emplace_back((llama_token) ids[j], logits[j]); // raw logit + } + std::sort(cand.begin(), cand.end(), + [](const auto & a, const auto & b) { return a.second > b.second; }); + // 1) top-k cut (logit order, temperature-invariant) + if (top_k > 0 && (int32_t) cand.size() > top_k) { + cand.resize(top_k); + } + // 2) top-p cut on the temperature-1 softmax (the nucleus is defined pre-temperature) + if (top_p < 1.0f && !cand.empty()) { + const float maxl = cand[0].second; + double sum = 0.0; + std::vector p(cand.size()); + for (size_t j = 0; j < cand.size(); ++j) { p[j] = std::exp(cand[j].second - maxl); sum += p[j]; } + double cum = 0.0; + size_t keep = cand.size(); + for (size_t j = 0; j < cand.size(); ++j) { + cum += p[j] / sum; + if (cum >= top_p) { keep = j + 1; break; } + } + cand.resize(keep); + } + // 3) temperature, then final softmax over the kept candidates + const float inv_t = 1.0f / (temp > 0.0f ? temp : 1.0f); + const float maxl = cand.empty() ? 
0.0f : cand[0].second * inv_t; + double z = 0.0; + for (auto & c : cand) { c.second = (float) std::exp(c.second * inv_t - maxl); z += c.second; } + if (z > 0.0) { for (auto & c : cand) { c.second = (float) (c.second / z); } } +} + +// Sample one token from a single device-resident verify logits row, applying temperature and, +// for the residual case, excluding the rejected draft token. Returns the sampled token id. +// This is the host side of the sampling speculative verify: only ONE logits row is fetched per +// block (on the first rejection, or for the bonus when everything is accepted) instead of the +// whole n_vocab x block matrix. The temperature distribution is reproduced exactly, so the +// output is lossless to the target's sampling distribution. +static llama_token spec_sample_row( + llama_context * ctx, int32_t row, llama_token exclude, float temp, int32_t n_vocab, + std::vector & buf, std::mt19937 & rng) { + buf.resize(n_vocab); + if (!llama_dflash_fetch_logits_row(ctx, row, buf.data(), n_vocab)) { + return 0; + } + const float inv_t = 1.0f / (temp > 0.0f ? temp : 1.0f); + float maxl = -INFINITY; + for (int32_t v = 0; v < n_vocab; ++v) { + if (v == exclude) { continue; } + buf[v] *= inv_t; + if (buf[v] > maxl) { maxl = buf[v]; } + } + double sum = 0.0; + for (int32_t v = 0; v < n_vocab; ++v) { + if (v == exclude) { buf[v] = 0.0f; continue; } + buf[v] = std::exp(buf[v] - maxl); + sum += buf[v]; + } + // inverse-CDF categorical sample + std::uniform_real_distribution u01(0.0, 1.0); + double r = u01(rng) * sum; + for (int32_t v = 0; v < n_vocab; ++v) { + r -= buf[v]; + if (r <= 0.0) { return (llama_token) v; } + } + return (llama_token) (n_vocab - 1); +} struct spec_checkpoint { int64_t n_tokens = 0; @@ -57,6 +141,13 @@ int main(int argc, char ** argv) { llama_context * ctx_tgt = NULL; + // DFlash/EAGLE3 on a hybrid/recurrent target can't partial-seq-rm, so speculative decoding + // checkpoints the target state on every step. 
Reserve a 2nd sequence slot so that checkpoint + // can live ON-DEVICE (seq_cp) instead of a ~50 MiB GPU<->host round-trip per step. + if (params.speculative.dflash || params.speculative.eagle3) { + params.n_parallel = std::max(params.n_parallel, 2); + } + // load the target model auto llama_init_tgt = common_init_from_params(params); @@ -67,8 +158,17 @@ int main(int argc, char ** argv) { const auto ctx_seq_rm = common_context_can_seq_rm(ctx_tgt); const bool use_ckpt = (ctx_seq_rm == COMMON_CONTEXT_SEQ_RM_TYPE_FULL); + // when the context has a spare sequence slot, keep the speculative checkpoint on-device by + // copying the active sequence into a scratch sequence (seq_cp) instead of serializing ~50 MiB + // to host and back every step. Profiling showed the host round-trip was ~22% of decode time. + const llama_seq_id SEQ_CKPT = 1; + // LLAMA_SPEC_NO_SEQCP forces the host-checkpoint path (for validating the on-device path). + const bool use_seq_cp = use_ckpt && llama_n_seq_max(ctx_tgt) > SEQ_CKPT + && !(getenv("LLAMA_SPEC_NO_SEQCP") && std::string(getenv("LLAMA_SPEC_NO_SEQCP")) != "0"); + if (use_ckpt) { - LOG_INF("speculative decoding will use checkpoints (context does not support partial sequence removal)\n"); + LOG_INF("speculative decoding will use checkpoints (%s)\n", + use_seq_cp ? "on-device seq_cp" : "host state round-trip"); } const llama_vocab * vocab = llama_model_get_vocab(model_tgt); @@ -116,6 +216,58 @@ int main(int argc, char ** argv) { } } + // DFlash recurrent rewind (LLAMA_SPEC_TRACE=1): trace per-token recurrent states during the + // verify decode so a partial acceptance promotes the state at the accepted position instead of + // checkpoint-restore + re-decode of the accepted tokens (the "commit-forward"). 
+ const bool use_state_trace = params.speculative.dflash && use_ckpt && + getenv("LLAMA_SPEC_TRACE") && std::string(getenv("LLAMA_SPEC_TRACE")) != "0"; + if (use_state_trace) { + llama_set_dflash_state_trace(ctx_tgt, params.speculative.n_max + 1); + LOG_INF("DFlash recurrent state trace enabled (promote instead of re-decode)\n"); + } + + // GPU greedy verify (LLAMA_SPEC_GPU_VERIFY=1, GREEDY SAMPLING ONLY): the target emits an + // on-device argmax of the verify-block logits and the host logits copy (n_vocab x block + // floats per round) is skipped; acceptance compares block_size ints. + const bool use_gpu_verify = params.speculative.dflash && + getenv("LLAMA_SPEC_GPU_VERIFY") && std::string(getenv("LLAMA_SPEC_GPU_VERIFY")) != "0"; + if (use_gpu_verify) { + LOG_INF("GPU greedy verify enabled (target logits stay on-device)\n"); + } + + // async draft feed (LLAMA_SPEC_ASYNC=1, requires GPU verify): draft tokens go to the verify + // batch device-to-device; one host synchronization per round instead of two + const bool use_async_feed = use_gpu_verify && + getenv("LLAMA_SPEC_ASYNC") && std::string(getenv("LLAMA_SPEC_ASYNC")) != "0"; + if (use_async_feed) { + LOG_INF("async draft feed enabled (single sync per round)\n"); + } + + // sampling speculative verify (LLAMA_SPEC_GPU_SAMPLE=1, temperature > 0): the target emits the + // temp-softmax probability of each draft token on-device; the host does rejection sampling on + // those probs and fetches a single logits row for the residual/bonus sample. Lossless to the + // target's temperature distribution. SGLang's tree_speculative_sampling_target_only, ported. 
+ const float spec_temp = params.sampling.temp; + const int32_t spec_top_k = params.sampling.top_k; + const float spec_top_p = params.sampling.top_p; + // top-k/top-p would emit top-K candidates (cap 256), but the on-device top-K path does not yet + // match the host sampler's acceptance closely enough to beat it - keep it experimental behind + // LLAMA_SPEC_GPU_SAMPLE_TOPK. By default GPU sampling verify is temperature-only (where it is a + // clear win); any top-k/top-p config falls back to the (correct, faster) host sampler path. + const bool spec_filtered = (spec_top_k > 0) || (spec_top_p < 1.0f); + const bool allow_topk = getenv("LLAMA_SPEC_GPU_SAMPLE_TOPK") && + std::string(getenv("LLAMA_SPEC_GPU_SAMPLE_TOPK")) != "0"; + const int32_t spec_topk_cap = (spec_filtered && allow_topk) ? std::max(256, spec_top_k) : 0; + const bool use_gpu_sample = params.speculative.dflash && spec_temp > 0.0f && !use_async_feed && + (!spec_filtered || allow_topk) && + getenv("LLAMA_SPEC_GPU_SAMPLE") && std::string(getenv("LLAMA_SPEC_GPU_SAMPLE")) != "0"; + std::mt19937 spec_rng((uint32_t) (params.sampling.seed == LLAMA_DEFAULT_SEED ? 
0xC0FFEE : params.sampling.seed)); + std::vector spec_logits_buf; + const int32_t spec_n_vocab = llama_vocab_n_tokens(vocab); + if (use_gpu_sample) { + LOG_INF("GPU sampling verify enabled (temp=%.2f, residual on-device prob + 1-row fetch)\n", spec_temp); + } + // Apply chat template for EAGLE3 / DFlash if available which can increase the acceptance rate std::string prompt = params.prompt; if (params.speculative.eagle3 || params.speculative.dflash) { @@ -193,6 +345,15 @@ int main(int argc, char ** argv) { id_last = common_sampler_sample(smpl.get(), ctx_tgt, -1); common_sampler_accept(smpl.get(), id_last, true); + + // from now on the verify loop only needs the on-device argmax (the initial sample above + // still consumed host logits, so the flag is enabled only after it) + if (use_gpu_verify) { + llama_set_out_argmax(ctx_tgt, true); + } + if (use_gpu_sample) { + llama_set_out_spec_sample(ctx_tgt, true, spec_temp, spec_topk_cap); + } LOG("%s", common_token_to_piece(ctx_tgt, id_last).c_str()); n_predict++; @@ -259,7 +420,19 @@ int main(int argc, char ** argv) { // save a checkpoint of the target context before evaluating the draft // this allows us to restore the state if partial draft acceptance occurs - if (!draft.empty() && use_ckpt) { + if (!draft.empty() && use_state_trace) { + // recurrent rewind: no checkpoint needed - the verify decode traces per-token + // states and a partial acceptance promotes the right one (see below) + spec_ckpt.n_tokens = (int64_t) prompt_tgt.size(); + } else if (!draft.empty() && use_seq_cp) { + // on-device checkpoint: copy the active sequence (0) into the scratch sequence. + // The subsequent draft decode on seq 0 advances its state; seq SEQ_CKPT keeps the + // pre-draft state (recurrent state is copy-on-write), so we can restore from it. 
+ auto * mem = llama_get_memory(ctx_tgt); + llama_memory_seq_rm(mem, SEQ_CKPT, -1, -1); + llama_memory_seq_cp(mem, 0, SEQ_CKPT, -1, -1); + spec_ckpt.n_tokens = (int64_t) prompt_tgt.size(); + } else if (!draft.empty() && use_ckpt) { const size_t ckpt_size = llama_state_seq_get_size_ext(ctx_tgt, 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); spec_ckpt.data.resize(ckpt_size); @@ -272,7 +445,8 @@ int main(int argc, char ** argv) { } } else { // we have a previous (partial) draft to reuse from checkpoint restoration - if (use_ckpt) { + // (for the on-device path the checkpoint lives in seq SEQ_CKPT, not in spec_ckpt.data) + if (use_ckpt && !use_seq_cp) { GGML_ASSERT(!spec_ckpt.empty()); } } @@ -297,6 +471,11 @@ int main(int argc, char ** argv) { //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str()); llama_decode(ctx_tgt, batch_tgt); + + // debug: validate the state-trace mechanics (trace[last] must equal the live cell) + if (use_state_trace && getenv("LLAMA_DFLASH_DEBUG")) { + llama_dflash_trace_check(ctx_tgt, batch_tgt.n_tokens); + } } // only save the sampler sampler state if we use checkpoints @@ -312,7 +491,111 @@ int main(int argc, char ** argv) { // available logits from the batch and sample the next token until we run out of logits or the sampler // disagrees with the draft // - auto ids = common_sampler_sample_and_accept_n(smpl.get(), ctx_tgt, draft); + llama_tokens ids; + if (use_gpu_sample) { + // sampling speculative verify (lossless to the target temperature distribution): + // the DFlash drafter proposes greedily (q = delta), so accept draft d_i with prob + // p_i(d_i) (the target temp-softmax prob, computed on-device); on the first rejection + // sample the replacement from the residual p_k with d_k removed; if all accepted, the + // bonus is sampled from the last position's full target distribution. 
+ std::uniform_real_distribution u01(0.0f, 1.0f); + + // sample a token from a filtered candidate distribution, optionally excluding one token + auto sample_cand = [&](const std::vector> & cand, + llama_token exclude) -> llama_token { + double z = 0.0; + for (const auto & c : cand) { if (c.first != exclude) { z += c.second; } } + if (z <= 0.0) { return cand.empty() ? 0 : cand[0].first; } + double r = u01(spec_rng) * z; + for (const auto & c : cand) { + if (c.first == exclude) { continue; } + r -= c.second; + if (r <= 0.0) { return c.first; } + } + return cand.back().first; + }; + + size_t i = 0; + bool rejected = false; + + if (spec_topk_cap > 0) { + // top-k/top-p verify: per row, rebuild the filtered candidate distribution on the + // host from the on-device top-K, then run the standard speculative rejection test. + int32_t n_rows = 0, kk = 0; + const float * tvals = nullptr; + const int32_t * tidx = llama_get_dflash_topk(ctx_tgt, &n_rows, &kk, &tvals); + GGML_ASSERT(tidx != nullptr && n_rows >= (int32_t) draft.size() + 1 && "spec topk missing"); + std::vector> cand; + for (; i < draft.size(); ++i) { + spec_build_candidates(tidx + (size_t) i * kk, tvals + (size_t) i * kk, kk, spec_n_vocab, + spec_temp, spec_top_k, spec_top_p, cand); + float p = 0.0f; + for (const auto & c : cand) { if (c.first == draft[i]) { p = c.second; break; } } + if (u01(spec_rng) < p) { + ids.push_back(draft[i]); // accept + } else { + ids.push_back(sample_cand(cand, draft[i])); // residual (d_i removed) + rejected = true; + break; + } + } + if (!rejected) { + spec_build_candidates(tidx + draft.size() * kk, tvals + draft.size() * kk, kk, spec_n_vocab, + spec_temp, spec_top_k, spec_top_p, cand); + ids.push_back(sample_cand(cand, -1)); // bonus + } + } else { + // temperature-only verify: accept d_i with prob p_i(d_i) from the on-device gather + int32_t n_pd = 0; + const float * pd = llama_get_dflash_pdraft(ctx_tgt, &n_pd); + GGML_ASSERT(pd != nullptr && n_pd >= (int32_t) draft.size() + 1 && 
"spec pdraft missing"); + for (; i < draft.size(); ++i) { + if (u01(spec_rng) < pd[i]) { + ids.push_back(draft[i]); + } else { + ids.push_back(spec_sample_row(ctx_tgt, (int32_t) i, draft[i], spec_temp, + spec_n_vocab, spec_logits_buf, spec_rng)); + rejected = true; + break; + } + } + if (!rejected) { + ids.push_back(spec_sample_row(ctx_tgt, (int32_t) draft.size(), -1, spec_temp, + spec_n_vocab, spec_logits_buf, spec_rng)); + } + } + } else if (use_gpu_verify) { + // greedy accept from the on-device argmax: identical semantics to + // common_sampler_sample_and_accept_n with a greedy sampler (token at each position up + // to and including the first mismatch; bonus token if everything matched) + int32_t n_am = 0; + const int32_t * am = llama_get_dflash_argmax(ctx_tgt, &n_am); + GGML_ASSERT(am != nullptr && n_am >= (int32_t) draft.size() + 1 && "target argmax missing"); + + // async feed mode: the draft vector holds placeholders (the real tokens never touched + // the host before the verify); refill it now from the drafter's extracted argmax. + // by this point the target sync above has fenced all earlier drafter work too. 
+ if (use_async_feed && !draft.empty()) { + int32_t n_dam = 0; + const int32_t * dam = llama_get_dflash_argmax(common_speculative_get_dflash_decoder(spec), &n_dam); + GGML_ASSERT(dam != nullptr && n_dam >= (int32_t) draft.size() + 1 && "drafter argmax missing"); + for (size_t i = 0; i < draft.size(); ++i) { + draft[i] = (llama_token) dam[i + 1]; + } + } + size_t i = 0; + for (; i < draft.size(); ++i) { + ids.push_back((llama_token) am[i]); + if (draft[i] != (llama_token) am[i]) { + break; + } + } + if (i == draft.size()) { + ids.push_back((llama_token) am[i]); + } + } else { + ids = common_sampler_sample_and_accept_n(smpl.get(), ctx_tgt, draft); + } //LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str()); @@ -321,15 +604,36 @@ int main(int argc, char ** argv) { // check for partial draft acceptance: // if the context doesn't support partial sequence removal, restore the checkpoint // and make the accepted tokens the new partial draft for the next iteration - if (use_ckpt && ids.size() - 1 < draft.size()) { + if (use_state_trace && ids.size() - 1 < draft.size()) { + // recurrent rewind: promote the traced state at the accepted position. the verify batch + // was [id_last @ P, draft0 @ P+1, ...] with P == prompt_tgt.size(); accepting `acc` + // drafts means the state after batch token `acc` is the correct one (trace slot `acc`), + // ending at position P + acc. then fall through to the normal commit path - the + // loop-tail llama_memory_seq_rm(0, n_past, -1) truncates the attention KV of the + // rejected tail and now succeeds because the recurrent cell pos was rewound. 
+ const int32_t acc = (int32_t) ids.size() - 1; + const llama_pos pos_last = (llama_pos) prompt_tgt.size() + acc; + + if (!llama_dflash_promote_state(ctx_tgt, acc, pos_last, 0)) { + LOG_ERR("%s: DFlash state promote failed (idx=%d)\n", __func__, acc); + return 1; + } + // fall through to the commit path below + } else if (use_ckpt && ids.size() - 1 < draft.size()) { LOG_DBG("partial acceptance: %zu < %zu, restoring checkpoint\n", ids.size() - 1, draft.size()); draft = std::move(ids); - const size_t n = llama_state_seq_set_data_ext(ctx_tgt, spec_ckpt.data.data(), spec_ckpt.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - GGML_ASSERT(n == spec_ckpt.size()); + if (use_seq_cp) { + auto * mem = llama_get_memory(ctx_tgt); + llama_memory_seq_rm(mem, 0, -1, -1); // drop the speculative advance on seq 0 + llama_memory_seq_cp(mem, SEQ_CKPT, 0, -1, -1); // restore the pre-draft state from scratch seq + } else { + const size_t n = llama_state_seq_set_data_ext(ctx_tgt, spec_ckpt.data.data(), spec_ckpt.size(), 0, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + GGML_ASSERT(n == spec_ckpt.size()); - llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, spec_ckpt.n_tokens, -1); + llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, spec_ckpt.n_tokens, -1); + } prompt_tgt.resize(spec_ckpt.n_tokens); smpl = std::move(smpl_save); @@ -379,7 +683,13 @@ int main(int argc, char ** argv) { { LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); - llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1); + const bool rm_ok = llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1); + if (!rm_ok && use_state_trace) { + // in trace mode this MUST succeed (the recurrent cell pos was rewound by promote); + // a failure means the rejected verify KV is still in the attention cache -> corruption + LOG_ERR("%s: post-accept seq_rm(0, %d, -1) FAILED in trace mode\n", __func__, n_past); + return 1; + } } if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) { diff --git 
a/ggml/include/ggml.h b/ggml/include/ggml.h index 703e37831361..add7b725fd02 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2539,6 +2539,21 @@ extern "C" { struct ggml_tensor * beta, struct ggml_tensor * state); + // same as ggml_gated_delta_net, but additionally stores the recurrent state after EVERY token + // into `trace` (F32, contiguous, S_v*S_v*H*n_tokens elements; the state after token t lands at + // offset t*S_v*S_v*H, same transposed layout as the final state). requires n_seqs == 1. + // used for speculative decoding on recurrent models: on a partial draft acceptance the traced + // state at the accepted position is promoted into the cache instead of re-decoding (rewind). + GGML_API struct ggml_tensor * ggml_gated_delta_net_trace( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state, + struct ggml_tensor * trace); + // custom operators typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index a9bc21da6f0f..a0f0cdb998b7 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10551,6 +10551,12 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( attn_data[j] = sum * scale; } + // optional per-token state trace (speculative-decoding rewind); n_seqs == 1 enforced upstream + if (dst->src[6]) { + float * tr = (float *) dst->src[6]->data + ((int64_t) t * H + iv1) * S_v * S_v; + memcpy(tr, s_out, S_v * S_v * sizeof(float)); + } + attn_data += S_v * H; // advance to next token } } diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 6b44bec73174..22727fd91dbf 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -9,6 +9,7 @@ gated_delta_net_cuda(const float * 
q, const float * beta, const float * curr_state, float * dst, + float * trace, // optional per-token state trace (n_seqs==1), may be nullptr int64_t H, int64_t n_tokens, int64_t n_seqs, @@ -134,6 +135,17 @@ gated_delta_net_cuda(const float * q, } } + // per-token state trace for speculative-decoding rewind (same transposed layout as the + // final state writeback below). near-free: the state already lives in registers here. + if (trace != nullptr) { + float * tr = trace + ((int64_t) t * H + h_idx) * S_v * S_v; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + tr[col * S_v + i] = s_shard[r]; + } + } + attn_data += S_v * H; } @@ -149,7 +161,7 @@ template static void launch_gated_delta_net( const float * q_d, const float * k_d, const float * v_d, const float * g_d, const float * b_d, const float * s_d, - float * dst_d, + float * dst_d, float * trace_d, int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs, int64_t sq1, int64_t sq2, int64_t sq3, int64_t sv1, int64_t sv2, int64_t sv3, @@ -170,26 +182,26 @@ static void launch_gated_delta_net( switch (S_v) { case 16: gated_delta_net_cuda<16, KDA><<>>( - q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, trace_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; case 32: gated_delta_net_cuda<32, KDA><<>>( - q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, trace_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; case 64: { gated_delta_net_cuda<64, KDA><<>>( - q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, trace_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; } case 128: { gated_delta_net_cuda<128, KDA><<>>( - q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, trace_d, H, n_tokens, 
n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; @@ -237,6 +249,11 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * const float * s_d = (const float *) src_state->data; float * dst_d = (float *) dst->data; + // optional per-token state trace (speculative-decoding rewind); src[6] is a persistent tensor + ggml_tensor * src_trace = dst->src[6]; + float * trace_d = src_trace ? (float *) src_trace->data : nullptr; + GGML_ASSERT(src_trace == nullptr || n_seqs == 1); + GGML_ASSERT(ggml_is_contiguous_rows(src_q)); GGML_ASSERT(ggml_is_contiguous_rows(src_k)); GGML_ASSERT(ggml_is_contiguous_rows(src_v)); @@ -262,11 +279,11 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * cudaStream_t stream = ctx.stream(); if (kda) { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, trace_d, S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1, rq3, scale, stream); } else { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, trace_d, S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1, rq3, scale, stream); } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 1c2c3b4ac693..1f121cd7fdbe 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4203,7 +4203,12 @@ static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, co ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); if (graph->graph == nullptr) { - if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) { + // CUDA graphs are gated to Ampere+ as a performance heuristic from the original PR, but + // Volta supports graph capture fine. 
GGML_CUDA_GRAPHS_VOLTA opts in on Volta; the graph + // SIZE limit is applied at the compute call site (see ggml_cuda_graphs_volta_max_nodes). + static const bool allow_volta = getenv("GGML_CUDA_GRAPHS_VOLTA") != nullptr; + const int cc_min = allow_volta ? GGML_CUDA_CC_VOLTA : GGML_CUDA_CC_AMPERE; + if (ggml_cuda_info().devices[cuda_ctx->device].cc < cc_min) { if (!graph->disable_due_to_gpu_arch) { GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); } @@ -4229,8 +4234,19 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cuda_graph_set_enabled(cuda_ctx, graph_key); + // on Volta, only SMALL graphs win with CUDA graphs (measured: the per-call node-property + // comparison makes large ~1800-node graphs a net loss, while a stable ~150-node speculative + // drafter graph benefits). GGML_CUDA_GRAPHS_VOLTA= caps the eligible node count (1 = no cap). + static const int volta_max_nodes = [] { + const char * e = getenv("GGML_CUDA_GRAPHS_VOLTA"); + return e ? 
atoi(e) : 0; + }(); + const bool volta_size_ok = + ggml_cuda_info().devices[cuda_ctx->device].cc >= GGML_CUDA_CC_AMPERE || + volta_max_nodes == 1 || cgraph->n_nodes <= volta_max_nodes; + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); - if (graph->is_enabled()) { + if (graph->is_enabled() && volta_size_ok) { const bool graph_compatible = ggml_cuda_graph_check_compability(cgraph); if (graph_compatible) { const bool properties_changed = ggml_cuda_graph_update_required(cuda_ctx, cgraph); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 54d3eae3e4da..c90ae3a521b2 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6213,6 +6213,33 @@ struct ggml_tensor * ggml_gated_delta_net( return result; } +struct ggml_tensor * ggml_gated_delta_net_trace( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state, + struct ggml_tensor * trace) { + const int64_t S_v = v->ne[0]; + const int64_t H = v->ne[1]; + const int64_t n_tokens = v->ne[2]; + const int64_t n_seqs = v->ne[3]; + + GGML_ASSERT(n_seqs == 1 && "gated_delta_net trace requires n_seqs == 1"); + GGML_ASSERT(trace->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(trace)); + GGML_ASSERT(ggml_nelements(trace) >= S_v * S_v * H * n_tokens); + + struct ggml_tensor * result = ggml_gated_delta_net(ctx, q, k, v, g, beta, state); + + // per-token state trace written by the kernel directly into this (persistent) tensor + result->src[6] = trace; + + return result; +} + //////////////////////////////////////////////////////////////////////////////// struct ggml_hash_set ggml_hash_set_new(size_t size) { diff --git a/h100_bench.sh b/h100_bench.sh new file mode 100644 index 000000000000..d197553ec887 --- /dev/null +++ b/h100_bench.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Runs INSIDE the H100 pod. 
Expects: /work/llama.cpp (patched tree), /work/models/*.gguf +set -e +cd /work/llama.cpp +echo "=== GPU ==="; nvidia-smi --query-gpu=name,compute_cap,memory.total --format=csv,noheader +apt-get update -qq && DEBIAN_FRONTEND=noninteractive apt-get install -y -qq cmake build-essential libcurl4-openssl-dev python3 >/dev/null 2>&1 || true + +# Hopper = sm_90; build CUDA arch 90 +cmake -B build-h100 -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=90 -DLLAMA_CURL=ON >/dev/null 2>&1 +cmake --build build-h100 --target llama-speculative-simple -j $(nproc) >/dev/null 2>&1 && echo BUILT || { echo BUILDFAIL; exit 1; } +BIN=./build-h100/bin/llama-speculative-simple +A="-m /work/models/Qwen3.5-4B-Q8_0.gguf -md /work/models/Qwen3.5-4B-DFlash-f16.gguf --dflash -ngl 99 -ngld 99 -p Tell-me-about-the-water-cycle-in-detail. -n 200 -c 2048 --draft-max 5 --temp 0 --top-k 1 --samplers top_k" + +echo "=== AR baseline ===" +./build-h100/bin/llama-cli -m /work/models/Qwen3.5-4B-Q8_0.gguf -ngl 99 -p "Tell me about the water cycle in detail." 
-n 200 -c 2048 --temp 0 -no-cnv 2>/tmp/ar.err >/dev/null || true +tr "\r" "\n" < /tmp/ar.err | grep -oE "eval time.*per token|[0-9.]+ tokens per second" | tail -2 + +echo "=== DFlash full stack (trace+gpuverify+async) ===" +LLAMA_SPEC_TRACE=1 LLAMA_SPEC_GPU_VERIFY=1 LLAMA_SPEC_ASYNC=1 $BIN $A >/tmp/df.txt 2>/tmp/df.err +tr "\r" "\n" < /tmp/df.err | grep -oE "speed: +[0-9.]+|accept += +[0-9.]+%" | tail -2 + +echo "=== DFlash + CUDA graphs (Hopper: NOT arch-gated, should engage by default) ===" +LLAMA_SPEC_TRACE=1 LLAMA_SPEC_GPU_VERIFY=1 $BIN $A >/tmp/dfg.txt 2>/tmp/dfg.err +tr "\r" "\n" < /tmp/dfg.err | grep -oE "speed: +[0-9.]+" | tail -1 + +echo "=== lossless gate (vs trace-off control) ===" +$BIN $A >/tmp/ctl.txt 2>/dev/null || true +diff -q /tmp/ctl.txt /tmp/df.txt >/dev/null && echo "df coherent" || echo "df MISMATCH vs trace-off control" +tail -c 140 /tmp/df.txt +echo "=== DONE ===" diff --git a/h100_full.sh b/h100_full.sh new file mode 100644 index 000000000000..c777d8c39af0 --- /dev/null +++ b/h100_full.sh @@ -0,0 +1,44 @@ +#!/bin/bash +# Self-contained H100 (Hopper sm_90) DFlash verification. Runs inside the pod.
+set -e +echo "=== GPU ==="; nvidia-smi --query-gpu=name,compute_cap,memory.total,driver_version --format=csv,noheader +export DEBIAN_FRONTEND=noninteractive +apt-get update -qq && apt-get install -y -qq cmake build-essential git libcurl4-openssl-dev python3-pip >/dev/null 2>&1 || true + +cd /workspace +[ -d llama.cpp ] || git clone -q -b work-qwen35-dflash https://github.com/AlexWortega/llama.cpp.git +cd llama.cpp +pip install -q -r requirements/requirements-convert_hf_to_gguf.txt 2>/dev/null || pip install -q numpy sentencepiece transformers safetensors gguf protobuf 2>/dev/null + +mkdir -p models +export HF_HUB_ENABLE_HF_TRANSFER=1; pip install -q hf_transfer 2>/dev/null || true +echo "=== download HF models ===" +python3 -c "from huggingface_hub import snapshot_download as s; s('Qwen/Qwen3.5-4B', local_dir='hf/tgt'); s('z-lab/Qwen3.5-4B-DFlash', local_dir='hf/dft')" 2>&1 | tail -1 + +echo "=== convert target -> f16 (quantized to Q8_0 after the build) ===" +python3 convert_hf_to_gguf.py hf/tgt --outfile models/tgt-f16.gguf --outtype f16 >/tmp/cv1.log 2>&1 && echo tgt-converted || { echo TGT_CONVERT_FAIL; tail -5 /tmp/cv1.log; } +echo "=== convert drafter -> f16 ===" +python3 convert_hf_to_gguf.py hf/dft --outfile models/Qwen3.5-4B-DFlash-f16.gguf --outtype f16 >/tmp/cv2.log 2>&1 && echo dft-converted || { echo DFT_CONVERT_FAIL; tail -5 /tmp/cv2.log; } + +echo "=== build (Hopper sm_90) ===" +cmake -B build -DGGML_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=90 -DLLAMA_CURL=OFF >/tmp/cm.log 2>&1 +cmake --build build --target llama-speculative-simple llama-quantize llama-cli -j $(nproc) >/tmp/build.log 2>&1 && echo BUILT || { echo BUILDFAIL; tail -15 /tmp/build.log; exit 1; } + +./build/bin/llama-quantize models/tgt-f16.gguf models/Qwen3.5-4B-Q8_0.gguf Q8_0 >/dev/null 2>&1 && echo quantized || { echo QUANTIZE_FAIL; exit 1; } +M="-m models/Qwen3.5-4B-Q8_0.gguf -md models/Qwen3.5-4B-DFlash-f16.gguf --dflash -ngl 99 -ngld 99 -p Tell-me-about-the-water-cycle-in-detail.
-n 200 -c 2048 --draft-max 5 --temp 0 --top-k 1 --samplers top_k" +BIN=./build/bin/llama-speculative-simple + +echo "=== AR baseline ===" +./build/bin/llama-cli -m models/Qwen3.5-4B-Q8_0.gguf -ngl 99 -p "Tell me about the water cycle in detail." -n 200 -c 2048 --temp 0 -no-cnv 2>/tmp/ar.err >/dev/null || true +tr "\r" "\n" < /tmp/ar.err | grep -oE "[0-9.]+ tokens per second|eval time =.*" | tail -2 + +echo "=== DFlash full stack (trace + gpu-verify + async) ===" +LLAMA_SPEC_TRACE=1 LLAMA_SPEC_GPU_VERIFY=1 LLAMA_SPEC_ASYNC=1 $BIN $M >/tmp/df.txt 2>/tmp/df.err || true +tr "\r" "\n" < /tmp/df.err | grep -oE "speed: +[0-9.]+|accept += +[0-9.]+%" | tail -2 + +echo "=== DFlash + Hopper CUDA graphs (NOT arch-gated on sm_90 - should engage by default) ===" +LLAMA_SPEC_TRACE=1 LLAMA_SPEC_GPU_VERIFY=1 $BIN $M >/tmp/dfg.txt 2>/tmp/dfg.err || true +tr "\r" "\n" < /tmp/dfg.err | grep -oE "speed: +[0-9.]+" | tail -1 + +echo "=== sample (coherence) ==="; tail -c 160 /tmp/df.txt +echo "=== DONE ===" diff --git a/include/llama.h b/include/llama.h index fc629fd5c55a..92400d5bc210 100644 --- a/include/llama.h +++ b/include/llama.h @@ -938,6 +938,89 @@ extern "C" { int32_t n_embd, int32_t n_tokens); + // DFlash recurrent rewind (staging): enable per-token recurrent state tracing during multi-token + // (verify) decodes on a hybrid target, up to n_max tokens per decode + LLAMA_API void llama_set_dflash_state_trace( + struct llama_context * ctx, + int32_t n_max); + + // promote the traced state at token index `idx` of the last verify decode into the live + // recurrent state of seq 0, marking it as ending at position `pos_last`. after this, a partial + // llama_memory_seq_rm(seq 0, pos_last+1, -1) succeeds on the hybrid memory and no re-decode of + // the accepted tokens is needed. 
returns false if tracing is not enabled or idx is out of range + LLAMA_API bool llama_dflash_promote_state( + struct llama_context * ctx, + int32_t idx, + llama_pos pos_last, + llama_seq_id seq_id); + + // debug: bitwise-compare the last traced state slot against the live recurrent cell + LLAMA_API bool llama_dflash_trace_check( + struct llama_context * ctx, + int32_t n_batch_tokens); + + // greedy argmax of the DFlash drafter's last decoded block, computed on-device + // (avoids the n_vocab x block logits host copy). returns nullptr if not produced; + // n_out receives the number of entries (= block tokens) + LLAMA_API const int32_t * llama_get_dflash_argmax( + struct llama_context * ctx, + int32_t * n_out); + + // emit on-device argmax of the output logits and skip the host logits copy (greedy + // speculative verify: only the per-position argmax is needed to accept/reject drafts). + // read the result via llama_get_dflash_argmax + LLAMA_API void llama_set_out_argmax( + struct llama_context * ctx, + bool value); + + // sampling speculative verify: emit on-device temp-softmax probability of each draft token + // (the next verify-batch token at each position), with the temperature baked into the graph. + // The host does the cheap rejection test on these probs and fetches a single logits row for + // the residual/bonus sample, instead of downloading the whole n_vocab x block logits matrix. 
+ // temp baked into the in-graph softmax; topk>0 emits top-K candidate logits per row instead + // of the temperature-only per-draft-token probability (enables on-device top-k/top-p verify) + LLAMA_API void llama_set_out_spec_sample( + struct llama_context * ctx, + bool value, + float temp, + int32_t topk); + + // per-draft-token temp-softmax probabilities from the last decode (n_out = output rows) + LLAMA_API const float * llama_get_dflash_pdraft( + struct llama_context * ctx, + int32_t * n_out); + + // per-row top-K candidate token ids (row-major [n_rows][k]); their logits via *vals + LLAMA_API const int32_t * llama_get_dflash_topk( + struct llama_context * ctx, + int32_t * n_rows, + int32_t * k, + const float ** vals); + + // fetch a single row of the device-resident verify logits (residual/bonus sampling on reject) + LLAMA_API bool llama_dflash_fetch_logits_row( + struct llama_context * ctx, + int32_t row, + float * out, + int32_t n_vocab); + + // DFlash decoder: stage the NEW tokens' raw target features; the decoder graph encodes them + // (fc+norm) and appends into the device-resident context cache - replaces the separate + // encoder llama_encode + llama_set_dflash_accumulated_target_ctx round trip per draft + LLAMA_API void llama_dflash_append_features( + struct llama_context * ctx, + const float * feat, + int32_t n_new, + int32_t n_total); + + // async draft feed: hand the drafter's on-device argmax tokens to the target context + // device-to-device (the verify batch is then submitted with placeholder tokens which get + // patched on-device) - removes the host synchronization between draft and verify + LLAMA_API bool llama_dflash_feed_draft_tokens( + struct llama_context * ctx_tgt, + struct llama_context * ctx_dft, + int32_t n); + // // Decoding // diff --git a/src/llama-context.cpp b/src/llama-context.cpp index e904db066fd8..ebf6deda4057 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -6,6 +6,8 @@ #include "llama-batch.h" #include 
"llama-io.h" #include "llama-memory.h" +#include "llama-memory-hybrid.h" +#include "llama-memory-recurrent.h" #include "llama-mmap.h" #include "llama-model.h" #include "llama-ext.h" @@ -350,12 +352,32 @@ llama_context::llama_context( } // temp fix: DFlash encoder/decoder share one model_dft, keep the role on the context dflash_decoder_ctx = model.arch == LLM_ARCH_DFLASH && params.target_model != nullptr; - // DFlash decoder: pre-fill cross with reservation size so build_inp_cross_embd - // uses cparams.n_ctx instead of hparams.n_ctx_train (which can cause OOM) + // DFlash decoder: device-resident encoded-context cache. The encoder (fc+norm) is folded + // into the decoder graph and appends encoded rows here via set_rows, so there is no + // separate encoder llama_encode round trip per draft. Capacity is capped (the decoder ctx + // often inherits a huge n_ctx); sequences beyond the cap are not supported by this path. if (dflash_decoder_ctx) { + const int64_t dflash_cap = std::min(cparams.n_ctx, 8192); cross.n_embd = hparams.n_embd; - cross.n_enc = cparams.n_ctx; - cross.v_embd.resize(cross.n_embd * cross.n_enc, 0.0f); + cross.n_enc = 256; // first bucket; grows by bucketing in the append API + + ggml_init_params ip = { + /*.mem_size =*/ ggml_tensor_overhead() * 2, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + dflash_cross_ctx.reset(ggml_init(ip)); + // +1 scratch row: the padded rows of the fixed-size append land there + dflash.cross_dev = ggml_new_tensor_2d(dflash_cross_ctx.get(), GGML_TYPE_F32, + hparams.n_embd, dflash_cap + 1); + ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_layer(0)); + dflash_cross_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(dflash_cross_ctx.get(), buft)); + GGML_ASSERT(dflash_cross_buf && "failed to allocate DFlash cross cache"); + dflash.cross_cap = (int32_t) dflash_cap; + + LLAMA_LOG_INFO("%s: DFlash device cross cache: %lld rows, %.1f MiB\n", __func__, + (long long) dflash_cap, + 
ggml_backend_buffer_get_size(dflash_cross_buf.get()) / 1024.0 / 1024.0); } sched_reserve(); @@ -437,6 +459,16 @@ void llama_context::sched_reserve() { LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); + // DFlash: the drafter's fused self+cross attention uses a custom additive F32 mask (to keep the + // target-context buffer at fixed capacity for graph reuse), which requires the eager soft_max + // path. Disable flash attention (and skip the auto-FA probe, which would otherwise build the + // masked graph with flash on and assert on the F32 mask) for DFlash contexts. + if (model.arch == LLM_ARCH_DFLASH && cparams.flash_attn) { + cparams.flash_attn = false; + cparams.auto_fa = false; + LLAMA_LOG_INFO("%s: DFlash - Flash Attention disabled (custom masked attention)\n", __func__); + } + // resolve automatic Flash Attention use if (cparams.auto_fa) { auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); @@ -1214,6 +1246,16 @@ void llama_context::set_dflash(const llama_model * model) { sched_need_reserve = true; + // device staging for the async draft feed (drafter argmax -> verify batch, no host sync) + if (dflash.draft_feed == nullptr) { + ggml_init_params ip = { ggml_tensor_overhead() * 2, NULL, true }; + dflash_feed_ctx.reset(ggml_init(ip)); + dflash.draft_feed = ggml_new_tensor_1d(dflash_feed_ctx.get(), GGML_TYPE_I32, 32); + ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(this->model.dev_layer(0)); + dflash_feed_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(dflash_feed_ctx.get(), buft)); + GGML_ASSERT(dflash_feed_buf && "failed to allocate DFlash draft-feed staging"); + } + const auto & dflash_hparams = model->hparams; dflash.extract_layer_indices.assign( @@ -1236,14 +1278,344 @@ const float * llama_context::get_dflash_target_features() const { return dflash.target_features.data(); } +bool llama_context::dflash_feed_draft_tokens(llama_context * dft, int32_t n) { + 
// hand the drafter's argmax tokens [rows 1..n] to this (target) context device-to-device: + // the target stream waits on the drafter stream via an event, then copies on its own stream. + // the host never reads the draft tokens before the verify decode is submitted. + ggml_tensor * src_am = dft->dflash.last_argmax_t; + if (src_am == nullptr || dflash.draft_feed == nullptr || n < 1 || n + 1 > src_am->ne[0] || + n > (int32_t) dflash.draft_feed->ne[0]) { + return false; + } + + ggml_backend_t be_dft = ggml_backend_sched_get_tensor_backend(dft->sched.get(), src_am); + GGML_ASSERT(be_dft != nullptr); + + // backend owning the staging buffer (this context's device backend) + ggml_backend_t be_tgt = nullptr; + for (auto & b : backends) { + if (ggml_backend_get_device(b.get()) == model.dev_layer(0)) { + be_tgt = b.get(); + break; + } + } + GGML_ASSERT(be_tgt != nullptr); + + if (dflash_feed_event == nullptr) { + dflash_feed_event = ggml_backend_event_new(ggml_backend_get_device(be_dft)); + GGML_ASSERT(dflash_feed_event != nullptr); + } + ggml_backend_event_record(dflash_feed_event, be_dft); + ggml_backend_event_wait(be_tgt, dflash_feed_event); + + ggml_init_params ip = { ggml_tensor_overhead() * 4, NULL, true }; + ggml_context_ptr vc { ggml_init(ip) }; + ggml_tensor * src = ggml_view_1d(vc.get(), src_am, n, 1 * sizeof(int32_t)); // skip row 0 (id_last) + ggml_tensor * dst = ggml_view_1d(vc.get(), dflash.draft_feed, n, 0); + ggml_backend_tensor_copy_async(be_tgt, be_tgt, src, dst); + + dflash.draft_feed_n = n; + return true; +} + +void llama_context::dflash_append_features(const float * feat, int32_t n_new, int32_t n_total) { + GGML_ASSERT(dflash.cross_dev != nullptr && "DFlash device cross cache not initialized"); + GGML_ASSERT(feat != nullptr && n_new >= 1 && n_new <= 256 && n_total >= n_new); + GGML_ASSERT(n_total <= dflash.cross_cap && "sequence exceeds the DFlash cross cache capacity"); + + const auto & hparams = model.hparams; + const size_t n_feat = 
hparams.dflash_target_layer_ids.size() * hparams.n_embd; + + dflash.feat_staging.assign(feat, feat + n_feat * n_new); + dflash.feat_n = n_new; + dflash.feat_pos0 = n_total - n_new; + dflash.feat_bucket = n_new <= 8 ? 8 : 256; // graph rebuilds when the bucket changes (prompt round) + + // bucketed mask/position sizing, same scheme as the legacy host-mediated path + const int64_t BUCKET = 256; + cross.n_embd = hparams.n_embd; + cross.n_enc = ((int64_t) n_total + BUCKET - 1) / BUCKET * BUCKET; + cross.n_enc_valid = n_total; +} + +void llama_context::set_dflash_state_trace(int32_t n_max) { + GGML_ASSERT(n_max > 0); + GGML_ASSERT(!dflash_trace_ctx && "DFlash state trace already initialized"); + + const auto & hparams = model.hparams; + const uint32_t n_layer = hparams.n_layer; + + auto * mh = dynamic_cast(memory.get()); + GGML_ASSERT(mh != nullptr && "DFlash state trace requires a hybrid (recurrent+attention) memory"); + auto * mr = mh->get_mem_recr(); + + // allocate the trace tensors on the same buffer type the recurrent cells live on + ggml_backend_buffer_type_t buft = nullptr; + for (uint32_t il = 0; il < n_layer; ++il) { + if (hparams.is_recurrent(il)) { + GGML_ASSERT(il < mr->s_l.size() && mr->s_l[il] != nullptr); + buft = ggml_backend_buffer_get_type(mr->s_l[il]->buffer); + break; + } + } + GGML_ASSERT(buft != nullptr && "DFlash state trace: no recurrent layers found"); + + ggml_init_params ip = { + /*.mem_size =*/ ggml_tensor_overhead() * 2 * n_layer, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + dflash_trace_ctx.reset(ggml_init(ip)); + + dflash.trace_s.assign(n_layer, nullptr); + dflash.trace_r.assign(n_layer, nullptr); + + for (uint32_t il = 0; il < n_layer; ++il) { + if (!hparams.is_recurrent(il)) { + continue; + } + dflash.trace_s[il] = ggml_new_tensor_2d(dflash_trace_ctx.get(), GGML_TYPE_F32, hparams.n_embd_s(), n_max); + dflash.trace_r[il] = ggml_new_tensor_2d(dflash_trace_ctx.get(), GGML_TYPE_F32, hparams.n_embd_r(), n_max); + } + + 
dflash_trace_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(dflash_trace_ctx.get(), buft)); + GGML_ASSERT(dflash_trace_buf && "failed to allocate DFlash state-trace buffer"); + + dflash.trace_n_max = n_max; + sched_need_reserve = true; + + LLAMA_LOG_INFO("%s: DFlash recurrent state trace enabled: n_max = %d, size = %.2f MiB\n", + __func__, n_max, ggml_backend_buffer_get_size(dflash_trace_buf.get()) / 1024.0 / 1024.0); +} + +void llama_context::set_out_argmax(bool value) { + if (cparams.out_argmax != value) { + cparams.out_argmax = value; + sched_need_reserve = true; // the graph gains/loses the argmax node + } +} + +void llama_context::set_out_spec_sample(bool value, float temp, int32_t topk) { + if (cparams.out_spec_sample != value || cparams.spec_topk != topk) { + cparams.out_spec_sample = value; + cparams.spec_topk = topk; + sched_need_reserve = true; // the graph gains/loses the softmax/gather or top-k nodes + } + cparams.spec_temp = temp > 0.0f ? temp : 1.0f; +} + +const int32_t * llama_context::get_dflash_topk(int32_t * n_rows, int32_t * k, const float ** vals) { + synchronize(); + if (dflash_topk_k == 0 || dflash_topk_idx_out.empty()) { + if (n_rows) { *n_rows = 0; } + if (k) { *k = 0; } + if (vals) { *vals = nullptr; } + return nullptr; + } + const int32_t rows = (int32_t) (dflash_topk_idx_out.size() / dflash_topk_k); + if (n_rows) { *n_rows = rows; } + if (k) { *k = dflash_topk_k; } + if (vals) { *vals = dflash_topk_val_out.data(); } + return dflash_topk_idx_out.data(); +} + +const float * llama_context::get_dflash_pdraft(int32_t * n_out) { + synchronize(); + if (n_out != nullptr) { + *n_out = (int32_t) dflash_pdraft_out.size(); + } + return dflash_pdraft_out.empty() ? 
nullptr : dflash_pdraft_out.data(); +} + +bool llama_context::dflash_fetch_logits_row(int32_t row, float * out, int32_t n_vocab) { + if (dflash_logits_dev == nullptr || out == nullptr || + row < 0 || row >= dflash_logits_dev->ne[1] || n_vocab != dflash_logits_dev->ne[0]) { + return false; + } + ggml_backend_tensor_get(dflash_logits_dev, out, + (size_t) row * dflash_logits_dev->nb[1], (size_t) n_vocab * sizeof(float)); + return true; +} + +const int32_t * llama_context::get_dflash_argmax(int32_t * n_out) { + synchronize(); // the extraction is async; flush before exposing the data + if (n_out != nullptr) { + *n_out = (int32_t) dflash_argmax_out.size(); + } + return dflash_argmax_out.empty() ? nullptr : dflash_argmax_out.data(); +} + +bool llama_context::dflash_trace_check(int32_t n_batch_tokens) { + // debug: the trace slot of the LAST batch token must be bitwise identical to the live cell + // (the kernel writes both from the same registers; conv comes from the same source view) + if (!dflash_trace_buf || n_batch_tokens < 2 || n_batch_tokens > dflash.trace_n_max) { + return false; + } + auto * mh = dynamic_cast(memory.get()); + if (!mh) { return false; } + auto * mr = mh->get_mem_recr(); + const int32_t cell = mr->cells[0].tail; + if (cell < 0) { return false; } + + const auto & hparams = model.hparams; + const int32_t slot = n_batch_tokens - 1; + bool ok = true; + + std::vector a, b; + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + if ((size_t) il >= dflash.trace_s.size() || dflash.trace_s[il] == nullptr) { continue; } + + const size_t s_bytes = hparams.n_embd_s() * sizeof(float); + a.resize(s_bytes); b.resize(s_bytes); + ggml_backend_tensor_get(dflash.trace_s[il], a.data(), (size_t) slot * dflash.trace_s[il]->nb[1], s_bytes); + ggml_backend_tensor_get(mr->s_l[il], b.data(), (size_t) cell * mr->s_l[il]->nb[1], s_bytes); + const bool s_eq = memcmp(a.data(), b.data(), s_bytes) == 0; + + const size_t r_bytes = hparams.n_embd_r() * sizeof(float); + 
a.resize(r_bytes); b.resize(r_bytes); + ggml_backend_tensor_get(dflash.trace_r[il], a.data(), (size_t) slot * dflash.trace_r[il]->nb[1], r_bytes); + ggml_backend_tensor_get(mr->r_l[il], b.data(), (size_t) cell * mr->r_l[il]->nb[1], r_bytes); + const bool r_eq = memcmp(a.data(), b.data(), r_bytes) == 0; + + if (!s_eq || !r_eq) { + LLAMA_LOG_ERROR("%s: layer %u MISMATCH: ssm=%s conv=%s (slot=%d cell=%d)\n", + __func__, il, s_eq ? "ok" : "DIFF", r_eq ? "ok" : "DIFF", slot, cell); + ok = false; + } + } + if (ok) { + LLAMA_LOG_INFO("%s: trace[last=%d] == live cell for all recurrent layers (bitwise)\n", __func__, slot); + } + return ok; +} + +bool llama_context::dflash_promote_state(int32_t idx, llama_pos pos_last, llama_seq_id seq_id) { + if (!dflash_trace_buf || idx < 0 || idx >= dflash.trace_n_max) { + return false; + } + + auto * mh = dynamic_cast(memory.get()); + if (mh == nullptr) { + return false; + } + auto * mr = mh->get_mem_recr(); + + const int32_t cell = mr->cells[seq_id].tail; // physical cell holding the sequence's state + if (cell < 0) { + return false; + } + + const auto & hparams = model.hparams; + + // copy the traced per-token state at slot `idx` into the live cell with device-side async + // copies on the owning backend's stream, then synchronize. the explicit synchronize is + // load-bearing: an unsynchronized copy races with the next decode reading the state (this + // exact race silently corrupted the state when the copies went through an async sched graph). 
+ ggml_init_params ip = { + /*.mem_size =*/ ggml_tensor_overhead() * 8 * hparams.n_layer, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr cg { ggml_init(ip) }; + + ggml_backend_t be = nullptr; + + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + if ((size_t) il >= dflash.trace_s.size() || dflash.trace_s[il] == nullptr) { + continue; + } + + ggml_tensor * s_l = mr->s_l[il]; + ggml_tensor * r_l = mr->r_l[il]; + + if (be == nullptr) { + be = ggml_backend_sched_get_tensor_backend(sched.get(), s_l); + if (be == nullptr) { + LLAMA_LOG_ERROR("%s: no backend for the recurrent state\n", __func__); + return false; + } + } + + ggml_tensor * src_s = ggml_view_1d(cg.get(), dflash.trace_s[il], hparams.n_embd_s(), + (size_t) idx * dflash.trace_s[il]->nb[1]); + ggml_tensor * dst_s = ggml_view_1d(cg.get(), s_l, hparams.n_embd_s(), + (size_t) cell * s_l->nb[1]); + ggml_backend_tensor_copy_async(be, be, src_s, dst_s); + + ggml_tensor * src_r = ggml_view_1d(cg.get(), dflash.trace_r[il], hparams.n_embd_r(), + (size_t) idx * dflash.trace_r[il]->nb[1]); + ggml_tensor * dst_r = ggml_view_1d(cg.get(), r_l, hparams.n_embd_r(), + (size_t) cell * r_l->nb[1]); + ggml_backend_tensor_copy_async(be, be, src_r, dst_r); + } + + if (be != nullptr) { + ggml_backend_synchronize(be); + } + + // debug: verify the copies actually landed (cell must now equal the trace slot bitwise) + static const bool debug_verify = getenv("LLAMA_DFLASH_DEBUG") != nullptr; + if (debug_verify) { + std::vector a, b; + bool ok = true; + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + if ((size_t) il >= dflash.trace_s.size() || dflash.trace_s[il] == nullptr) { continue; } + const size_t s_bytes = hparams.n_embd_s() * sizeof(float); + a.resize(s_bytes); b.resize(s_bytes); + ggml_backend_tensor_get(dflash.trace_s[il], a.data(), (size_t) idx * dflash.trace_s[il]->nb[1], s_bytes); + ggml_backend_tensor_get(mr->s_l[il], b.data(), (size_t) cell * mr->s_l[il]->nb[1], s_bytes); + if 
(memcmp(a.data(), b.data(), s_bytes) != 0) { + LLAMA_LOG_ERROR("%s: PROMOTE COPY FAILED layer %u (ssm)\n", __func__, il); + ok = false; + } + } + if (ok) { + LLAMA_LOG_INFO("%s: promote copy verified (idx=%d cell=%d)\n", __func__, idx, cell); + } + } + + // the cell now holds the state as of pos_last; fix the metadata so the subsequent partial + // llama_memory_seq_rm(seq 0, pos_last+1, -1) succeeds on the hybrid memory + static const bool debug = getenv("LLAMA_DFLASH_DEBUG") != nullptr; + if (debug) { + LLAMA_LOG_INFO("%s: idx=%d cell=%d pos %d -> %d\n", + __func__, idx, cell, (int) mr->cells[cell].pos, (int) pos_last); + } + mr->cells[cell].pos = pos_last; + + return true; +} + void llama_context::set_dflash_accumulated_target_ctx(const float * data, int32_t n_embd, int32_t n_tokens) { GGML_ASSERT(data != nullptr); - const size_t size = (size_t)n_embd * n_tokens; - // Store in cross struct (reusing T5 style cross-attention for accumulated target features fed to the DFlash decoder) - cross.n_embd = n_embd; - cross.n_enc = n_tokens; - cross.v_embd.resize(size); - std::memcpy(cross.v_embd.data(), data, size * sizeof(float)); + // Round the target-context length up to a fixed BUCKET so the DFlash decoder graph keeps a + // constant shape across most speculative rounds and can be reused (previously n_enc grew with + // the accumulated context every round -> graphs reused = 0 -> a graph rebuild + sched reserve + // per step, which erased the speculative speedup on hybrid/recurrent targets). The graph is now + // rebuilt only when the context crosses a bucket boundary. Rows [0, n_tokens) are valid; the + // padding rows up to the bucket are zeroed and masked out in the decoder (see the DFlash block + // in process_ubatch + src/models/dflash.cpp dflash_kq_mask). + // + // The context is APPEND-ONLY across speculative rounds: rows [0, prev_valid) are unchanged, so + // both the host mirror update here and the device upload in set_input are delta-only. 
+ const int64_t BUCKET = 256; + const int64_t capacity = ((int64_t(n_tokens) + BUCKET - 1) / BUCKET) * BUCKET; + GGML_ASSERT(n_tokens >= 1 && "DFlash accumulated target context must be non-empty"); + + const bool same_buf = cross.n_embd == n_embd && cross.n_enc == capacity && + (int64_t) cross.v_embd.size() == (int64_t) n_embd * capacity; + const int64_t prev = same_buf && n_tokens >= cross.n_enc_valid ? cross.n_enc_valid : 0; + + if (!same_buf) { + cross.v_embd.assign((size_t) n_embd * capacity, 0.0f); + } + cross.n_embd = n_embd; + cross.n_enc = capacity; // bucketed -> stable shape within a bucket + cross.n_enc_valid = n_tokens; // real rows + cross.n_enc_appended = prev; // rows below this are unchanged (delta-upload hint) + std::memcpy(cross.v_embd.data() + (size_t) prev * n_embd, + data + (size_t) prev * n_embd, + (size_t) n_embd * (n_tokens - prev) * sizeof(float)); } llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { @@ -1308,6 +1680,24 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll // FIXME this call causes a crash if any model inputs were not used in the graph and were therefore not allocated res->set_inputs(&ubatch); + // async draft feed (target side): the verify batch was submitted with placeholder draft + // tokens; patch inp_tokens rows [1..n] device-to-device from the staged drafter argmax + // (the host never reads the draft tokens before the verify - no inter-model sync) + if (dflash.draft_feed_n > 0 && res->t_inp_tokens != nullptr) { + const int32_t n = dflash.draft_feed_n; + GGML_ASSERT(res->t_inp_tokens->ne[0] >= n + 1); + ggml_backend_t be = ggml_backend_sched_get_tensor_backend(sched.get(), res->t_inp_tokens); + GGML_ASSERT(be != nullptr); + + ggml_init_params ip = { ggml_tensor_overhead() * 4, NULL, true }; + ggml_context_ptr vc { ggml_init(ip) }; + ggml_tensor * src = ggml_view_1d(vc.get(), 
dflash.draft_feed, n, 0); + ggml_tensor * dst = ggml_view_1d(vc.get(), res->t_inp_tokens, n, 1 * sizeof(int32_t)); + ggml_backend_tensor_copy_async(be, be, src, dst); + + dflash.draft_feed_n = 0; + } + // EAGLE3: Fill g_embeddings for decoder input if (model.arch == LLM_ARCH_EAGLE3 && gtype == LLM_GRAPH_TYPE_DECODER && !eagle3.g_embeddings.empty()) { ggml_tensor * g_embd = ggml_graph_get_tensor(gf, "inp_g_embeddings"); @@ -1316,20 +1706,81 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } } - // temp fix DFlash: Fill position tensor for decoder - if (model.arch == LLM_ARCH_DFLASH && gtype == LLM_GRAPH_TYPE_DECODER && !cross.v_embd.empty()) { - const int64_t n_ctx = cross.n_enc; + // sampling speculative verify: fill the flat gather index (i*n_vocab + draft_token[i]) so + // the graph can read each draft token's temp-softmax probability. The draft token at output + // row i is the verify batch's next input token (ubatch.token[i+1]); the last row has no + // successor and is left pointing at token 0 (its pdraft is unused - bonus is sampled from a + // fetched logits row). + if (cparams.out_spec_sample) { + // the flat gather indexes the logits tensor's own vocab stride (which may be padded + // beyond the model vocab), so use t_logits->ne[0], not model.vocab.n_tokens() + const int64_t n_vocab = res->t_logits != nullptr ? res->t_logits->ne[0] : model.vocab.n_tokens(); + if (ggml_tensor * gidx = ggml_graph_get_tensor(gf, "spec_gather_idx")) { + const int64_t n_out = gidx->ne[0]; + std::vector idx(n_out); + for (int64_t i = 0; i < n_out; ++i) { + const int64_t tok = (i + 1 < (int64_t) ubatch.n_tokens) ? ubatch.token[i + 1] : 0; + idx[i] = (int32_t) (i * n_vocab + tok); + } + ggml_backend_tensor_set(gidx, idx.data(), 0, n_out * sizeof(int32_t)); + } + } + + // temp fix DFlash: fill the decoder position tensor + the padding mask. 
+ // The cross (target) context is a fixed-capacity buffer of n_enc rows of which only the + // first n_enc_valid are real; the noise block follows the *real* context, so noise RoPE + // positions are n_enc_valid + j (not n_enc + j), and the padding rows [n_enc_valid, n_enc) + // are masked out of the noise->context attention. + if (model.arch == LLM_ARCH_DFLASH && gtype == LLM_GRAPH_TYPE_DECODER && + (!cross.v_embd.empty() || dflash.cross_dev != nullptr)) { + const int64_t n_ctx = cross.n_enc; // fixed capacity + const int64_t n_valid = cross.n_enc_valid; // real target rows const int64_t n_noise = ubatch.n_tokens; const int64_t n_total = n_ctx + n_noise; + // device cross cache: upload the NEW raw features + their destination row indices + // (padded entries are routed to the scratch row) + ggml_tensor * feat_t = ggml_graph_get_tensor(gf, "dflash_feat_new"); + ggml_tensor * idx_t = ggml_graph_get_tensor(gf, "dflash_feat_idx"); + if (feat_t != nullptr && idx_t != nullptr) { + const int64_t cap_rows = feat_t->ne[1]; + const int32_t n_new = dflash.feat_n; + if (n_new > 0) { + ggml_backend_tensor_set(feat_t, dflash.feat_staging.data(), 0, (size_t) n_new * feat_t->nb[1]); + } + std::vector ids(cap_rows); + for (int64_t i = 0; i < cap_rows; ++i) { + ids[i] = i < n_new ? 
(int64_t) dflash.feat_pos0 + i : (int64_t) dflash.cross_cap; + } + ggml_backend_tensor_set(idx_t, ids.data(), 0, cap_rows * sizeof(int64_t)); + dflash.feat_n = 0; // consumed + } + ggml_tensor * pos_full = ggml_graph_get_tensor(gf, "inp_pos_full"); if (pos_full) { std::vector pos_data(n_total); - for (int64_t i = 0; i < n_total; ++i) { - pos_data[i] = (int32_t)i; + for (int64_t i = 0; i < n_ctx; ++i) { + pos_data[i] = (int32_t) i; // target slots (real rows get their true pos) + } + for (int64_t j = 0; j < n_noise; ++j) { + pos_data[n_ctx + j] = (int32_t) (n_valid + j); // noise block continues after real context } ggml_backend_tensor_set(pos_full, pos_data.data(), 0, n_total * sizeof(int32_t)); } + + ggml_tensor * kq_mask = ggml_graph_get_tensor(gf, "dflash_kq_mask"); + if (kq_mask) { + // additive mask [n_total, n_q]; 0 = visible, -inf = masked (padding target rows) + const int64_t n_q = kq_mask->ne[1]; + std::vector mask_data((size_t) n_total * n_q, 0.0f); + for (int64_t q = 0; q < n_q; ++q) { + for (int64_t kv = n_valid; kv < n_ctx; ++kv) { + mask_data[(size_t) q * n_total + kv] = -INFINITY; // mask padding target rows + } + // real target [0, n_valid) and all noise rows [n_ctx, n_total) stay visible + } + ggml_backend_tensor_set(kq_mask, mask_data.data(), 0, ggml_nbytes(kq_mask)); + } } //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); @@ -1351,6 +1802,56 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll extract_dflash_features(ubatch); } + // DFlash drafter: pull the on-device greedy argmax of the block logits (a few ints + // instead of the full n_vocab x block logits host copy) + if (ggml_tensor * t_am = res->get_argmax()) { + const int64_t n = t_am->ne[0]; + dflash_argmax_out.resize(n); + ggml_backend_t backend_am = ggml_backend_sched_get_tensor_backend(sched.get(), t_am); + GGML_ASSERT(backend_am != nullptr); + ggml_backend_tensor_get_async(backend_am, t_am, 
dflash_argmax_out.data(), 0, n * sizeof(int32_t)); + dflash.last_argmax_t = t_am; // for the async device-to-device draft feed + } else { + dflash_argmax_out.clear(); + dflash.last_argmax_t = nullptr; + } + + // sampling speculative verify: pull the per-draft-token temp-softmax probabilities + if (ggml_tensor * t_pd = res->get_spec_pdraft()) { + const int64_t n = t_pd->ne[0]; + dflash_pdraft_out.resize(n); + ggml_backend_t be_pd = ggml_backend_sched_get_tensor_backend(sched.get(), t_pd); + GGML_ASSERT(be_pd != nullptr); + ggml_backend_tensor_get_async(be_pd, t_pd, dflash_pdraft_out.data(), 0, n * sizeof(float)); + dflash_logits_dev = res->t_logits; // kept on-device for the residual/bonus row fetch + } else { + dflash_pdraft_out.clear(); + dflash_logits_dev = nullptr; + } + + // top-k/top-p verify: pull the per-row top-K candidate ids + their logits + if (ggml_tensor * t_ti = res->get_spec_topk_idx()) { + ggml_tensor * t_tv = res->get_spec_topk_val(); + const int64_t n = ggml_nelements(t_ti); + dflash_topk_idx_out.resize(n); + dflash_topk_val_out.resize(n); + dflash_topk_k = (int32_t) t_ti->ne[0]; + const int32_t nvocab = res->t_logits != nullptr ? 
(int32_t) res->t_logits->ne[0] : 0; + ggml_backend_t be_tv = ggml_backend_sched_get_tensor_backend(sched.get(), t_tv); + GGML_ASSERT(be_tv != nullptr); + // argsort_top_k returns FLAT indices (row*n_vocab + token); fetch synchronously and recover + // the per-row token id as idx % n_vocab (robust to row indexing) + ggml_backend_tensor_get(t_ti, dflash_topk_idx_out.data(), 0, n * sizeof(int32_t)); + if (nvocab > 0) { + for (int64_t m = 0; m < n; ++m) { dflash_topk_idx_out[m] %= nvocab; } + } + ggml_backend_tensor_get_async(be_tv, t_tv, dflash_topk_val_out.data(), 0, n * sizeof(float)); + } else { + dflash_topk_idx_out.clear(); + dflash_topk_val_out.clear(); + dflash_topk_k = 0; + } + ret = GGML_STATUS_SUCCESS; return res; @@ -1866,8 +2367,8 @@ int llama_context::decode(const llama_batch & batch_inp) { t_embd = res->get_embd_pooled(); } - // extract logits - if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) { + // extract logits (skipped in the greedy-verify path: only the on-device argmax is read) + if (logits.data && t_logits && n_outputs > 0 && !cparams.out_argmax && needs_raw_logits(ubatch, sampling.samplers)) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(logits.data != nullptr); @@ -3894,6 +4395,50 @@ void llama_set_dflash_accumulated_target_ctx(llama_context * ctx, const float * ctx->set_dflash_accumulated_target_ctx(data, n_embd, n_tokens); } +void llama_set_dflash_state_trace(llama_context * ctx, int32_t n_max) { + ctx->set_dflash_state_trace(n_max); +} + +bool llama_dflash_promote_state(llama_context * ctx, int32_t idx, llama_pos pos_last, llama_seq_id seq_id) { + return ctx->dflash_promote_state(idx, pos_last, seq_id); +} + +bool llama_dflash_trace_check(llama_context * ctx, int32_t n_batch_tokens) { + return ctx->dflash_trace_check(n_batch_tokens); +} + +const int32_t * llama_get_dflash_argmax(llama_context * ctx, 
int32_t * n_out) { + return ctx->get_dflash_argmax(n_out); +} + +void llama_set_out_argmax(llama_context * ctx, bool value) { + ctx->set_out_argmax(value); +} + +void llama_set_out_spec_sample(llama_context * ctx, bool value, float temp, int32_t topk) { + ctx->set_out_spec_sample(value, temp, topk); +} + +const float * llama_get_dflash_pdraft(llama_context * ctx, int32_t * n_out) { + return ctx->get_dflash_pdraft(n_out); +} + +const int32_t * llama_get_dflash_topk(llama_context * ctx, int32_t * n_rows, int32_t * k, const float ** vals) { + return ctx->get_dflash_topk(n_rows, k, vals); +} + +bool llama_dflash_fetch_logits_row(llama_context * ctx, int32_t row, float * out, int32_t n_vocab) { + return ctx->dflash_fetch_logits_row(row, out, n_vocab); +} + +void llama_dflash_append_features(llama_context * ctx, const float * feat, int32_t n_new, int32_t n_total) { + ctx->dflash_append_features(feat, n_new, n_total); +} + +bool llama_dflash_feed_draft_tokens(llama_context * ctx_tgt, llama_context * ctx_dft, int32_t n) { + return ctx_tgt->dflash_feed_draft_tokens(ctx_dft, n); +} + // // ext diff --git a/src/llama-context.h b/src/llama-context.h index 86f0d81c0ccf..8587d1f8edb8 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -110,6 +110,37 @@ struct llama_context { void set_eagle3(const llama_model * model); void set_dflash(const llama_model * model); + // DFlash recurrent rewind (staging): allocate per-token state trace buffers (n_max tokens) and + // enable tracing during multi-token decodes on recurrent layers + void set_dflash_state_trace(int32_t n_max); + // promote the traced state at token index `idx` of the last verify decode into the live + // recurrent cell of seq 0 and mark it as ending at position `pos_last` + bool dflash_promote_state(int32_t idx, llama_pos pos_last, llama_seq_id seq_id = 0); + // debug: bitwise-compare the last traced slot against the live recurrent cell + bool dflash_trace_check(int32_t n_batch_tokens); + + // greedy argmax of 
the DFlash drafter's last decoded block (nullptr if not produced) + const int32_t * get_dflash_argmax(int32_t * n_out); + + // emit on-device argmax of the output logits and skip the host logits copy (greedy verify) + void set_out_argmax(bool value); + + // emit on-device sampling-verify data: temp baked in; topk>0 emits top-K candidates instead + void set_out_spec_sample(bool value, float temp, int32_t topk); + // per-draft-token temp-softmax probabilities from the last decode (nullptr if not produced) + const float * get_dflash_pdraft(int32_t * n_out); + // per-row top-K candidate token ids (+ logits via vals) from the last decode (row-major) + const int32_t * get_dflash_topk(int32_t * n_rows, int32_t * k, const float ** vals); + // fetch a single row of the device-resident verify logits (for residual/bonus sampling) + bool dflash_fetch_logits_row(int32_t row, float * out, int32_t n_vocab); + + // DFlash decoder: stage the NEW tokens' raw target features for the in-graph encoder fold + // (fc+norm + set_rows append into the device cross cache). n_total = committed context rows. 
+ void dflash_append_features(const float * feat, int32_t n_new, int32_t n_total); + + // async draft feed: hand the drafter's argmax tokens to this (target) context on-device + bool dflash_feed_draft_tokens(llama_context * dft, int32_t n); + // process a single ubatch with a specific graph type // if memory_context is provided, it will be applied first to the context's memory // ret contains the status of the graph computation @@ -280,6 +311,32 @@ struct llama_context { mutable llama_dflash dflash; + // ownership of the DFlash state-trace tensors (see llama_dflash::trace_s/trace_r) + ggml_context_ptr dflash_trace_ctx; + ggml_backend_buffer_ptr dflash_trace_buf; + + // ownership of the DFlash device cross cache (see llama_dflash::cross_dev) + ggml_context_ptr dflash_cross_ctx; + ggml_backend_buffer_ptr dflash_cross_buf; + + // ownership of the async draft-feed staging + the inter-stream event (see llama_dflash::draft_feed) + ggml_context_ptr dflash_feed_ctx; + ggml_backend_buffer_ptr dflash_feed_buf; + ggml_backend_event_t dflash_feed_event = nullptr; + + // on-device greedy argmax of the DFlash drafter's block logits (see t_argmax) + std::vector dflash_argmax_out; + + // sampling speculative verify: per-draft-token temp-softmax probs (see t_spec_pdraft) and the + // device-resident logits tensor kept for fetching a single residual/bonus row on demand + std::vector dflash_pdraft_out; + ggml_tensor * dflash_logits_dev = nullptr; + + // top-k/top-p verify: per-row top-K candidate ids + logits (row-major [n_out][K]) from last decode + std::vector dflash_topk_idx_out; + std::vector dflash_topk_val_out; + int32_t dflash_topk_k = 0; + // temp fix: avoid DFlash encoder/decoder mis-detection. They share one model_dft, // so shared model fields cannot safely identify the decoder (caused OOM). 
bool dflash_decoder_ctx = false; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 906bfbe36c12..95252f49b626 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -40,6 +40,10 @@ struct llama_cparams { bool kv_unified; bool eagle3_extract_enabled; // enable layer extraction for EAGLE3 speculative decoding bool dflash_extract_enabled; // enable layer extraction for DFlash speculative decoding + bool out_argmax; // emit on-device argmax of the output logits (greedy verify path) + bool out_spec_sample; // emit on-device temp-softmax prob of each draft token (sampling verify) + float spec_temp; // temperature baked into the in-graph softmax for out_spec_sample + int32_t spec_topk; // >0: emit top-K candidate logits per row (top-k/top-p verify) instead of pdraft bool pipeline_parallel; enum llama_pooling_type pooling_type; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 9fabd242e766..64d8ee28ccb5 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -339,8 +339,43 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { if (cross_embd && !cross->v_embd.empty()) { assert(cross_embd->type == GGML_TYPE_F32); - ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd)); + // append-only fast path (DFlash speculative rounds): rows [0, n_enc_appended) are + // unchanged since the last upload, so only the delta is transferred. a graph rebuild + // creates a fresh input (n_uploaded = -1) and triggers the full upload, which also + // initializes the zero padding of the fixed-capacity buffer. 
+ const int64_t row_bytes = cross->n_embd * ggml_element_size(cross_embd); + if (n_uploaded >= 0 && cross->n_enc_appended >= n_uploaded && + cross->n_enc == cross_embd->ne[1]) { + const int64_t first = n_uploaded; + const int64_t last = cross->n_enc_valid; + if (last > first) { + ggml_backend_tensor_set(cross_embd, + cross->v_embd.data() + (size_t) first * cross->n_embd, + (size_t) first * row_bytes, + (size_t) (last - first) * row_bytes); + } + } else { + ggml_backend_tensor_set(cross_embd, cross->v_embd.data(), 0, ggml_nbytes(cross_embd)); + } + n_uploaded = cross->n_enc_valid; + } +} + +bool llm_graph_input_cross_embd::can_reuse(const llm_graph_params & params) { + GGML_UNUSED(params); + + // The cross embeddings are re-uploaded every step in set_input(), so the graph can be reused as + // long as the cross tensor shape is unchanged. This is what makes DFlash block drafting cheap: + // the target-context buffer is bucketed to a fixed capacity, so within a bucket the decoder + // graph is identical across speculative rounds (previously this input always forced a rebuild). + if (!cross_embd || !cross) { + return false; } + + const int64_t n_embd = !cross->v_embd.empty() ? cross->n_embd : cross_embd->ne[0]; + const int64_t n_enc = !cross->v_embd.empty() ? 
cross->n_enc : cross_embd->ne[1]; + + return cross_embd->ne[0] == n_embd && cross_embd->ne[1] == n_enc; } static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { @@ -805,6 +840,10 @@ void llm_graph_result::reset() { t_logits = nullptr; t_embd = nullptr; t_embd_pooled = nullptr; + t_argmax = nullptr; + t_spec_pdraft = nullptr; + t_spec_topk_idx = nullptr; + t_spec_topk_val = nullptr; t_sampled.clear(); t_sampled_probs.clear(); t_sampled_logits.clear(); @@ -2811,6 +2850,67 @@ void llm_graph_context::build_pooling( } void llm_graph_context::build_sampling() const { + // on-device greedy argmax over ALL output logit rows (speculative greedy-verify path): + // the verifier only needs the per-position argmax to accept/reject draft tokens, so + // downloading n_outputs ints instead of n_outputs x n_vocab floats skips a multi-MB host + // copy per verify round. some graphs (e.g. the DFlash drafter) set t_argmax themselves. + if (cparams.out_argmax && res->t_logits != nullptr && res->t_argmax == nullptr) { + ggml_tensor * am = ggml_argmax(ctx0, res->t_logits); + cb(am, "result_argmax", -1); + res->t_argmax = am; + ggml_build_forward_expand(gf, am); + } + + // on-device temp-softmax probability of each draft token (sampling speculative verify): + // p_i(d_i) for every output row, gathered via a flat index (i*n_vocab + d_i) the context + // fills from the verify batch tokens. The host then does the cheap rejection test on these + // n_outputs floats and only fetches a single logits row for the residual/bonus sample, + // instead of downloading the whole n_vocab x block logits matrix. 
+ if (cparams.out_spec_sample && res->t_logits != nullptr && + res->t_spec_pdraft == nullptr && res->t_spec_topk_idx == nullptr) { + const int64_t n_vocab = res->t_logits->ne[0]; + const int64_t n_out = res->t_logits->ne[1]; + + if (cparams.spec_topk > 0) { + // top-k/top-p verify: emit the top-K candidate token ids + their raw logits per row. + // The host then applies the full sampler (temp/top-k/top-p) over those K candidates - + // the top-p nucleus is a subset of the top-K, so this is exact for realistic params. + const int64_t K = std::min(cparams.spec_topk, n_vocab); + + // argsort-based top-K (ggml_top_k's CUDA path returned bad indices for ~248k vocab). + // argsort_top_k here returns FLAT indices (row*n_vocab + token) into the [n_vocab,n_out] + // logits, so they directly index the flattened [1, n_vocab*n_out] view - no base needed. + // The host recovers the token id as idx % n_vocab. + ggml_tensor * idx = ggml_cont(ctx0, ggml_argsort_top_k(ctx0, res->t_logits, K)); // I32 [K, n_out], flat + + ggml_tensor * lflat = ggml_reshape_2d(ctx0, res->t_logits, 1, n_vocab * n_out); + ggml_tensor * vals = ggml_get_rows(ctx0, lflat, ggml_reshape_1d(ctx0, idx, K * n_out)); + vals = ggml_reshape_2d(ctx0, vals, K, n_out); // F32 [K, n_out] + + cb(idx, "result_spec_topk_idx", -1); + cb(vals, "result_spec_topk_val", -1); + res->t_spec_topk_idx = idx; + res->t_spec_topk_val = vals; + ggml_build_forward_expand(gf, idx); + ggml_build_forward_expand(gf, vals); + } else { + // temperature-only verify: emit p_i(d_i) directly via a flat gather (exact, no top-K) + ggml_tensor * scaled = ggml_scale(ctx0, res->t_logits, 1.0f / cparams.spec_temp); + ggml_tensor * probs = ggml_soft_max(ctx0, scaled); // [n_vocab, n_out] + + ggml_tensor * gidx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_out); + ggml_set_input(gidx); + ggml_set_name(gidx, "spec_gather_idx"); + + ggml_tensor * pflat = ggml_reshape_2d(ctx0, probs, 1, n_vocab * n_out); + ggml_tensor * pgather = ggml_get_rows(ctx0, pflat, 
gidx); // [1, n_out] + ggml_tensor * pdraft = ggml_reshape_1d(ctx0, pgather, n_out); + cb(pdraft, "result_spec_pdraft", -1); + res->t_spec_pdraft = pdraft; + ggml_build_forward_expand(gf, pdraft); + } + } + if (samplers.empty() || !res->t_logits) { return; } diff --git a/src/llama-graph.h b/src/llama-graph.h index 1925a275d8a3..9b6cc24d1f0e 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -66,6 +66,15 @@ struct llama_cross { int64_t n_embd = 0; int64_t n_enc = 0; + // DFlash: number of *valid* target-context rows inside the fixed-capacity (n_enc) buffer. + // Rows [n_enc_valid, n_enc) are zero padding and are masked out in the DFlash decoder, so the + // cross tensor keeps a constant shape across speculative rounds (enables graph reuse). + int64_t n_enc_valid = 0; + + // DFlash: the context is append-only across speculative rounds; rows [0, n_enc_appended) of + // v_embd are unchanged since the previous round, so set_input only uploads the delta + int64_t n_enc_appended = 0; + // embeddings data copied to host memory (tmp) std::vector v_embd; @@ -105,6 +114,29 @@ struct llama_dflash { std::vector extract_tensors; + // recurrent state trace (staging): per-token SSM/conv states captured during the verify decode + // of a hybrid target, so that on a partial draft acceptance the state at the accepted position + // is promoted instead of restore+re-decode. tensors live in a persistent context-owned buffer. + int32_t trace_n_max = 0; // max tokens traced per decode (0 = disabled) + std::vector trace_s; // per-layer [n_embd_s, trace_n_max] (recurrent layers only) + std::vector trace_r; // per-layer [n_embd_r, trace_n_max] (conv windows) + + // device-resident encoded-context cache for the DFlash decoder (encoder folded into the + // decoder graph): new target features are fc+norm'ed in-graph and appended into cross_dev + // via ggml_set_rows, eliminating the separate encoder llama_encode round trip per draft. 
+ ggml_tensor * cross_dev = nullptr; // [n_embd, cross_cap + 1] (last row = scratch for padding) + int32_t cross_cap = 0; // capacity in rows (0 = host-mediated legacy path) + std::vector feat_staging; // host staging of the NEW tokens' raw target features + int32_t feat_n = 0; // number of staged feature rows + int32_t feat_pos0 = 0; // destination row of the first staged feature + int32_t feat_bucket = 8; // padded feature rows in the graph (8 normally; 256 for the prompt round) + + // async draft feed: the drafter's argmax tokens are handed device-to-device into the verify + // batch, so there is no host synchronization between the draft and the verify submission + ggml_tensor * last_argmax_t = nullptr; // this context's argmax tensor from the last decode + ggml_tensor * draft_feed = nullptr; // (target ctx) device staging for fed draft tokens [I32] + int32_t draft_feed_n = 0; // pending fed rows to patch into inp_tokens rows [1..n] + void clear() { target_features.clear(); extract_tensors.clear(); @@ -292,6 +324,12 @@ class llm_graph_input_cross_embd : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + + // rows of cross->v_embd already uploaded to this input's device tensor (-1 = never uploaded; + // a graph rebuild creates a fresh input object, so the full upload happens automatically) + int64_t n_uploaded = -1; + ggml_tensor * cross_embd; // F32 [n_embd, n_outputs_enc] const llama_cross * cross; @@ -684,6 +722,10 @@ class llm_graph_result { ggml_tensor * get_logits() const { return t_logits; } ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } + ggml_tensor * get_argmax() const { return t_argmax; } + ggml_tensor * get_spec_pdraft() const { return t_spec_pdraft; } + ggml_tensor * get_spec_topk_idx() const { return t_spec_topk_idx; } + ggml_tensor * get_spec_topk_val() const { return t_spec_topk_val; } 
ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -712,6 +754,10 @@ class llm_graph_result { ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; + ggml_tensor * t_argmax = nullptr; // I32 [n_tokens] greedy tokens (DFlash drafter) + ggml_tensor * t_spec_pdraft = nullptr; // F32 [n_tokens] temp-softmax prob of each draft token + ggml_tensor * t_spec_topk_idx = nullptr; // I32 [K, n_tokens] top-K token ids per row (top-k/top-p verify) + ggml_tensor * t_spec_topk_val = nullptr; // F32 [K, n_tokens] their raw logits std::map t_sampled_logits; std::map t_candidates; diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index 6bc989c95099..cb78b4067e61 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -397,7 +397,14 @@ std::pair llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s); + ggml_tensor * result; + if (gdn_trace != nullptr) { + // per-token state trace requested (DFlash speculative rewind on recurrent targets) + result = ggml_gated_delta_net_trace(ctx0, q, k, v, g, b, s, gdn_trace); + gdn_trace = nullptr; // consumed + } else { + result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s); + } if (n_tokens == 1) { cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il); } else { @@ -434,6 +441,7 @@ std::pair llm_build_delta_net_base::build_delta_ne if (cparams.fused_gdn_ar) { return build_delta_net_fused(q, k, v, g, b, s, il); } + GGML_ASSERT(gdn_trace == nullptr && "GDN state trace requires the fused kernel path"); return build_delta_net_autoregressive(q, k, v, g, b, s, il); } @@ -441,5 +449,6 @@ std::pair llm_build_delta_net_base::build_delta_ne return build_delta_net_fused(q, k, v, 
g, b, s, il); } + GGML_ASSERT(gdn_trace == nullptr && "GDN state trace requires the fused kernel path"); return build_delta_net_chunking(q, k, v, g, b, s, il); } diff --git a/src/models/dflash.cpp b/src/models/dflash.cpp index 0adba127eabf..d397be4f2ffa 100644 --- a/src/models/dflash.cpp +++ b/src/models/dflash.cpp @@ -1,5 +1,26 @@ #include "models.h" +// graph-reuse guard for the device-cache decoder path: the feat/mask/position tensor shapes are +// baked from (cross->n_enc, dflash->feat_bucket) at build time; force a rebuild when either moves +class llm_graph_input_dflash_dev : public llm_graph_input_i { +public: + llm_graph_input_dflash_dev(const llama_cross * cross, const llama_dflash * df, + int64_t n_enc_built, int32_t bucket_built) + : cross(cross), df(df), n_enc_built(n_enc_built), bucket_built(bucket_built) {} + + void set_input(const llama_ubatch * ubatch) override { GGML_UNUSED(ubatch); } + + bool can_reuse(const llm_graph_params & params) override { + GGML_UNUSED(params); + return cross->n_enc == n_enc_built && df->feat_bucket == bucket_built; + } + + const llama_cross * cross; + const llama_dflash * df; + int64_t n_enc_built; + int32_t bucket_built; +}; + ggml_tensor * llm_build_dflash_encode::build_inp_embd() const { const int64_t n_target_layer_ids = (int64_t) hparams.dflash_target_layer_ids.size(); const int64_t n_embd_target_features = n_target_layer_ids * n_embd; @@ -40,9 +61,44 @@ llm_build_dflash_decode::llm_build_dflash_decode(const llama_model & model, cons ggml_tensor * noise_embd = build_inp_embd(model.target_tok_embd); cb(noise_embd, "inp_noise_embd", -1); - // Target context via llama_cross (filled from accumulated_target_ctx), graph rebuilds every step - ggml_tensor * target_ctx = build_inp_cross_embd(); - const int64_t n_ctx = target_ctx->ne[1]; + ggml_tensor * target_ctx = nullptr; + int64_t n_ctx = 0; + + if (dflash != nullptr && dflash->cross_dev != nullptr) { + // encoder folded into the decoder graph: the NEW tokens' raw target 
features arrive as an + // input, get fc+norm'ed here and appended into the persistent device cache (cross_dev) via + // set_rows; the attention context is a view of that cache. this removes the separate + // encoder llama_encode + the host round trip of the encoded features per draft round. + const int64_t n_feat = (int64_t) hparams.dflash_target_layer_ids.size() * n_embd; + // padded feature rows; extra rows land in the scratch row via the index input. bucketed + // (8 for normal rounds, 256 for the prompt round) so the fc GEMM does not waste flops on + // padding - the graph rebuilds once when the bucket changes + const int64_t n_new_max = dflash->feat_bucket; + + ggml_tensor * feat = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_feat, n_new_max); + ggml_set_input(feat); + cb(feat, "dflash_feat_new", -1); + + ggml_tensor * fidx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_new_max); + ggml_set_input(fidx); + cb(fidx, "dflash_feat_idx", -1); + + ggml_tensor * enc = build_lora_mm(model.fc, feat); + enc = build_norm(enc, model.dflash_hidden_norm, NULL, LLM_NORM_RMS, -1); + cb(enc, "dflash_enc_new", -1); + + ggml_tensor * cache = ggml_set_rows(ctx0, dflash->cross_dev, enc, fidx); + cb(cache, "dflash_cross_cache", -1); + + n_ctx = cross->n_enc; // bucketed valid+padding rows (padding masked out) + target_ctx = ggml_view_2d(ctx0, cache, n_embd, n_ctx, cache->nb[1], 0); + + res->add_input(std::make_unique(cross, dflash, n_ctx, (int32_t) n_new_max)); + } else { + // legacy host-mediated path: accumulated context uploaded via llama_cross + target_ctx = build_inp_cross_embd(); + n_ctx = target_ctx->ne[1]; + } ggml_tensor * inpL = noise_embd; @@ -57,6 +113,15 @@ llm_build_dflash_decode::llm_build_dflash_decode(const llama_model & model, cons ggml_tensor * inp_pos_q = ggml_view_1d(ctx0, inp_pos_full, n_tokens, n_ctx * ggml_element_size(inp_pos_full)); + // Additive attention mask over [target_ctx (n_ctx) ++ noise (n_tokens)] for the noise queries. 
+ // target_ctx is a fixed-capacity buffer so the graph shape stays constant across speculative + // rounds (enables graph reuse); the padding rows are masked out. Values are filled per round in + // llama_context::process_ubatch (named "dflash_kq_mask"). Requires the eager soft_max path, + // which is why flash attention is disabled for the DFlash decoder context. + ggml_tensor * dflash_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens_kv, n_tokens); + ggml_set_input(dflash_kq_mask); + cb(dflash_kq_mask, "dflash_kq_mask", -1); + const float kq_scale = 1.0f/sqrtf(float(n_embd_head)); for (int il = 0; il < n_layer; ++il) { @@ -119,7 +184,7 @@ llm_build_dflash_decode::llm_build_dflash_decode(const llama_model & model, cons ggml_build_forward_expand(gf, Kcur); ggml_build_forward_expand(gf, Vcur); - ggml_tensor * cur = build_attn_mha(Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, nullptr, kq_scale, il); + ggml_tensor * cur = build_attn_mha(Qcur, Kcur, Vcur, nullptr, dflash_kq_mask, nullptr, nullptr, kq_scale, il); cb(cur, "kqv_out", il); cur = build_lora_mm(layer.wo, cur); @@ -155,6 +220,14 @@ llm_build_dflash_decode::llm_build_dflash_decode(const llama_model & model, cons cur = build_lora_mm(model.target_output, cur); cb(cur, "result_output", -1); res->t_logits = cur; + + // GPU argmax over the block logits: the DFlash draft is greedy top-1, so downloading + // block_size ints instead of n_vocab x block_size floats (~5 MB/round at vocab 248k) + // removes the per-round logits host copy + CPU scan entirely + ggml_tensor * am = ggml_argmax(ctx0, cur); + cb(am, "result_argmax", -1); + res->t_argmax = am; + ggml_build_forward_expand(gf, am); } ggml_build_forward_expand(gf, cur); diff --git a/src/models/models.h b/src/models/models.h index 062e6ff621d2..e7efcc4823af 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -64,6 +64,10 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * b, ggml_tensor * s, int il); + + // optional 
per-token state trace target for the NEXT build_delta_net call (fused path only); + // set by the caller (e.g. qwen35 during a DFlash verify), consumed+reset by build_delta_net_fused + ggml_tensor * gdn_trace = nullptr; }; struct llm_build_rwkv6_base : public llm_graph_context { diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index 19d3d95619d0..484e0cd988b3 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -290,6 +290,33 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + // DFlash recurrent rewind (staging): during the speculative verify decode, capture the per-token + // conv windows and (below, via gdn_trace) the per-token SSM states, so that on a partial draft + // acceptance the state at the accepted position is PROMOTED instead of restore+re-decode. + const bool dflash_trace = dflash != nullptr && dflash->trace_n_max > 0 && + n_seqs == 1 && n_seq_tokens > 1 && n_seq_tokens <= dflash->trace_n_max && + cparams.fused_gdn_ch && + (size_t) il < dflash->trace_s.size() && dflash->trace_s[il] != nullptr; + + if (dflash_trace) { + // conv window after token t = rows [t+1 .. t+conv_kernel_size-1] of conv_input (the same + // view pattern as last_conv_states above). One clean in-bounds sub-view copy per token: + // an earlier single-overlapping-3D-view optimization sat exactly on the buffer boundary + // (data_size + offset == ggml_nbytes(conv_input)) and aborted in ggml_view_3d whenever + // conv_input had a different row count than (k-1)+n_seq_tokens (observed on sm_120 and in + // the server's prompt-chunk path). The per-token windows are always strictly in bounds. 
+ const int64_t conv_sz = (conv_kernel_size - 1) * conv_channels; // == hparams.n_embd_r() + for (int64_t t = 0; t < n_seq_tokens; ++t) { + ggml_tensor * win = ggml_view_3d(ctx0, conv_input, + conv_kernel_size - 1, conv_channels, 1, + conv_input->nb[1], conv_input->nb[2], + (t + 1) * ggml_element_size(conv_input)); + ggml_tensor * dst_t = ggml_view_1d(ctx0, dflash->trace_r[il], conv_sz, + (size_t) t * conv_sz * ggml_element_size(dflash->trace_r[il])); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, win, dst_t)); + } + } + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); cb(state, "state_predelta", il); @@ -350,6 +377,11 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); + if (dflash_trace) { + // per-token SSM state trace target, consumed by the fused GDN op (see ggml_gated_delta_net_trace) + gdn_trace = ggml_view_1d(ctx0, dflash->trace_s[il], hparams.n_embd_s() * n_seq_tokens, 0); + } + auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); ggml_tensor * output = attn_out.first; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index c835dd8a44c2..9d216109fe5c 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -92,6 +92,12 @@ struct server_slot { server_prompt_checkpoint spec_ckpt; common_speculative_ptr spec; + // DFlash recurrent rewind: when set, the target's per-token recurrent states are traced during + // the verify decode so a partial acceptance promotes the state at the accepted position on-device + // instead of the ~50 MiB host checkpoint round-trip + re-decode (see llama_dflash_promote_state) + bool spec_state_trace = false; + llama_pos spec_pos0 = 0; // base position of the current verify batch (rewind target = pos0 + accepted) + // TODO: move members that belong to the task 
(such as `generated_text`, `has_new_line`) to task_results_state // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837 std::unique_ptr task; @@ -363,7 +369,9 @@ struct server_slot { spec_draft.clear(); } - if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL) { + // the host checkpoint is only needed for the restore-based rewind; with the + // on-device state trace a partial acceptance promotes the state instead (no host copy) + if (!spec_draft.empty() && ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL && !spec_state_trace) { const auto n_tokens = prompt.tokens.size(); spec_ckpt = server_get_checkpoint(ctx, this->id, n_tokens); @@ -396,6 +404,7 @@ struct server_slot { } auto pos0 = prompt.tokens.pos_next(); + spec_pos0 = pos0; // base position of the verify batch (for the DFlash on-device rewind) common_batch_add(batch, sampled, pos0++, { this->id }, true); for (auto token : spec_draft) { @@ -929,6 +938,20 @@ struct server_context_impl { if (slot.spec) { SLT_INF(slot, "%s", "speculative decoding context initialized\n"); + + // DFlash on a hybrid/recurrent target: enable the recurrent state trace so a + // partial acceptance promotes the accepted-position state on-device instead of + // the host checkpoint round-trip. Only for the FULL-seq-rm regime (hybrid), and + // only at n_parallel == 1 (the per-context feature extraction limitation above). 
+ const bool trace_ok = params_base.speculative.dflash && + ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL && + params_base.n_parallel == 1; + if (trace_ok && + !(getenv("LLAMA_SPEC_NO_TRACE") && std::string(getenv("LLAMA_SPEC_NO_TRACE")) != "0")) { + llama_set_dflash_state_trace(slot.ctx, params_base.speculative.n_max + 1); + slot.spec_state_trace = true; + SLT_INF(slot, "%s", "DFlash recurrent state trace enabled (on-device rewind)\n"); + } } } @@ -2987,9 +3010,10 @@ struct server_context_impl { { const bool use_ckpt = slot.ctx_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL; - // only save the sampler sampler state if we use checkpoints + // only save the sampler state if we use the checkpoint-restore rewind; the + // on-device trace rewind commits forward and never rolls the sampler back common_sampler_ptr smpl_save; - if (use_ckpt) { + if (use_ckpt && !slot.spec_state_trace) { smpl_save.reset(common_sampler_clone(slot.smpl.get())); } @@ -3003,7 +3027,22 @@ struct server_context_impl { // check for partial draft acceptance if (accepted.size() < slot.spec_draft.size() + 1) { - if (use_ckpt) { + if (slot.spec_state_trace) { + // DFlash recurrent rewind: promote the traced state at the accepted + // position instead of restoring a host checkpoint. The verify batch was + // [sampled @ pos0, draft0 @ pos0+1, ...]; accepting `acc` drafts means the + // state after batch token `acc` (trace slot `acc`, ending at pos0 + acc) + // is correct. Then fall through to the normal commit path below - its + // llama_memory_seq_rm(pos) truncates the rejected attention KV tail and + // now succeeds because the recurrent cell pos was rewound. 
+ const int32_t acc = (int32_t) accepted.size() - 1; + const llama_pos pos_last = slot.spec_pos0 + acc; + + if (!llama_dflash_promote_state(slot.ctx, acc, pos_last, slot.id)) { + GGML_ABORT("%s: DFlash state promote failed (idx=%d)\n", __func__, acc); + } + // no checkpoint restore, no `continue` - fall through to commit + } else if (use_ckpt) { // partial acceptance is not supported by the context -> truncate the draft and restore the state slot.spec_draft = std::move(accepted); From e8bfef237bd048a4b413a17107b6ea42f21fa5db Mon Sep 17 00:00:00 2001 From: Aleksandr Nikolich Date: Fri, 12 Jun 2026 18:11:16 +0200 Subject: [PATCH 19/21] gdn: portable chunk-parallel Gated-DeltaNet verify path (opt-in) Decompose the GDN recurrence into a pure ggml-op graph (cumsum, exp, mul_mat, tri, solve_tri, diag, concat) so the verify can run on backends that lack a fused GDN kernel (WebGPU, Metal, Vulkan). Multi-chunk tiling keeps exp(cumsum(g)) in fp32 range; handles both the vector (KDA) and per-head scalar gate, and GQA. Validated bitwise against ggml_gated_delta_net on CPU and CUDA (tests/test-gdn-chunked). Opt-in via LLAMA_GDN_CHUNKED; the default path is unchanged. This is for portability: on CUDA the fused kernel is faster and the GDN scan is not the verify bottleneck. 
--- DESIGN.md | 320 +++++++++++++++++ FINDINGS.md | 117 +++++++ GDN_CHUNKED_BRINGUP.md | 170 +++++++++ ggml/src/ggml-backend.cpp | 58 +++- ggml/src/ggml-cuda/gated_delta_net.cu | 10 + ggml/src/ggml-cuda/gated_delta_net.cuh | 14 + ggml/src/ggml-cuda/gated_delta_net_chunked.cu | 325 ++++++++++++++++++ ggml/src/ggml-cuda/gdn_chunked_oracle.py | 67 ++++ src/models/delta-net-base.cpp | 108 ++++++ src/models/models.h | 21 ++ tests/test-gdn-chunked.cpp | 235 +++++++++++++ 11 files changed, 1444 insertions(+), 1 deletion(-) create mode 100644 DESIGN.md create mode 100644 FINDINGS.md create mode 100644 GDN_CHUNKED_BRINGUP.md create mode 100644 ggml/src/ggml-cuda/gated_delta_net_chunked.cu create mode 100644 ggml/src/ggml-cuda/gdn_chunked_oracle.py create mode 100644 tests/test-gdn-chunked.cpp diff --git a/DESIGN.md b/DESIGN.md new file mode 100644 index 000000000000..5061ac2c8a6c --- /dev/null +++ b/DESIGN.md @@ -0,0 +1,320 @@ +# Chunk-parallel Gated-DeltaNet (GDN) for DFlash speculative VERIFY + +Status: DESIGN + reviewable CUDA skeleton (no GPU available; not yet compiled/validated). +Scope: the **verify** path only — single sequence (`n_seqs == 1`), a block of `N` tokens +(N = draft-max + 1, up to ~16, design supports up to 32). Prefill and single-token decode keep +using the existing sequential kernel. + +Files: +- existing sequential kernel: `ggml/src/ggml-cuda/gated_delta_net.cu` +- CPU reference (exact math): `ggml/src/ggml-cpu/ops.cpp` + (`ggml_compute_forward_gated_delta_net_one_chunk`) +- new skeleton: `ggml/src/ggml-cuda/gated_delta_net_chunked.cu` +- dispatch hook: `ggml/src/ggml-cuda/ggml-cuda.cu` (`GGML_OP_GATED_DELTA_NET`, ~L2934) + +--- + +## 1. The exact recurrence (from the CPU reference) + +Per head, per sequence. Let `S_k = S_v = D` (the existing kernel assumes square state; head dim +`D in {16,32,64,128}`). The recurrent state is a `D x D` matrix `S` with `S[i][j]`, where `i` indexes +the **key** dimension and `j` indexes the **value** dimension. 
+ +> Storage detail: both the CPU ref and the CUDA kernel store `S` **transposed** as +> `M[j][i] = S[i][j]` so that "row j of M" (contiguous) is "column j of S". The math below is in the +> mathematical `S[i][j]` convention; the kernel maps it to the transposed layout. + +Inputs at token `t` (all f32): `q_t, k_t in R^D` (key dim), `v_t in R^D` (value dim), +`beta_t in R` (scalar), gate `g_t`: +- **scalar gate** (non-KDA): `g_t in R`, decay `a_t = exp(g_t)` applied to the whole state. +- **KDA / vector gate**: `g_t in R^D` indexed by the **key** dim `i`, decay `a_t[i] = exp(g_t[i])` + applied per key-row. + +The update (matching `ops.cpp` lines 10522-10552 exactly): + +``` +1. decay: S[i][j] <- a_t[i] * S[i][j] (a_t[i] = exp(g_t[i]); scalar: a_t[i]=exp(g_t) for all i) +2. kv: u_t[j] = sum_i S[i][j] * k_t[i] = (S^T k_t)[j] +3. delta: d_t[j] = (v_t[j] - u_t[j]) * beta_t +4. update: S[i][j] <- S[i][j] + k_t[i] * d_t[j] (rank-1: S += k_t d_t^T) +5. output: o_t[j] = scale * sum_i S[i][j] * q_t[i] = scale * (S^T q_t)[j] (scale = 1/sqrt(D)) +``` + +Substituting (3) into (4), with `S_{t-1}` the post-(prev-token) state and `S_t` the post-update +state, this is the **gated delta rule**: + +``` +S_t = diag(a_t) S_{t-1} + k_t ( beta_t (v_t - (diag(a_t) S_{t-1})^T k_t) )^T +o_t = scale * S_t^T q_t +``` + +Define the **effective value** (a.k.a. "new value" / pseudo-value in DeltaNet) so the update becomes +a plain (non-recursive-in-S) rank-1 add: + +``` +w_t = beta_t * k_t (D, key dim) <- write key +u_t = (diag(a_t) S_{t-1})^T k_t (D, value dim) <- what's already stored +d_t = beta_t * v_t - beta_t * u_t (D, value dim) <- delta value +S_t = diag(a_t) S_{t-1} + k_t d_t^T +``` + +`o_t` reads `S_t` (post-update, includes the current token — note step 5 runs *after* step 4). + +This is the per-token recurrence the existing CUDA kernel runs `N` times sequentially in the verify +block. Cost of verify ~ `N` sequential steps, each `O(D^2)` work. 
We want to cut the **sequential +depth** from `N` to `N/C`. + +--- + +## 2. Chunked derivation (chunked delta-rule / chunked linear attention) + +Reference: FLA `chunk_delta_rule` / `chunk_gated_delta_rule`; Yang et al. "Parallelizing Linear +Transformers with the Delta Rule over Sequence Length" (DeltaNet) and "Gated DeltaNet". + +Split the `N` tokens into chunks of size `C` (e.g. `C = 16`, so a 16-token verify block is **one +chunk**; a 32-token block is two). Index tokens within a chunk by `r = 0..C-1` (global token +`t = chunk_base + r`). Let `S_in` be the state entering the chunk (the carry). + +### 2.1 Cumulative gate products inside the chunk + +For the **KDA / vector gate**, the per-key-dim decay is multiplicative, so define inclusive cumulative +products along the chunk (per key dim `i`): + +``` +A_r[i] = prod_{s=0..r} a_s[i] (inclusive, decay applied up to and including token r) +``` + +with `A_{-1}[i] = 1`. The decay from "just after token s applied" to "the chunk boundary after token +C-1" is `A_{C-1}[i] / A_s[i]`. For the **scalar gate**, `a_s` is a scalar and `A_r` collapses to a +scalar per token — same formulas, broadcast over `i`. + +To keep the rank-1 writes commutable, **pre-scale** each token's write key into a common reference +frame (the chunk start). Define: + +``` +k~_r[i] = k_r[i] / A_r[i] (deflated write key — "undo" the decay it will accumulate) +q~_r[i] = q_r[i] * A_r[i] (inflated query — apply decay the carry would have gotten) +``` + +Intuition: a rank-1 contribution `k_s d_s^T` written at token `s` gets multiplicatively decayed by +`A_{r}[i]/A_s[i]` (key dim) by the time we read at token `r >= s`. Folding `1/A_s` into the key and +`A_r` into the query realizes that decay through a single elementwise scale per token, so the +intra-chunk interactions become plain matmuls. (This is exactly the FLA "secondary chunking" trick; +do the cumprod in log space — see section 5.) 
+ +### 2.2 Intra-chunk parallel form + +Stack the chunk into matrices (rows = tokens within the chunk): +`K, Q, V in R^{C x D}` (rows `k_r, q_r, v_r`), `K~, Q~` the deflated/inflated versions, `beta in R^C`. + +**(a) Carry read (contribution of `S_in` to every token's `u` and `o`):** + +The "already stored" value seen by token `r` from the *incoming* state is +`(diag(A_r) S_in)^T k_r = S_in^T (A_r (.) k_r)`. So with `Kbar_r = A_r (.) k_r` (the **inflated** read +key) stacked into `Kbar in R^{C x D}`: + +``` +U_carry = Kbar @ S_in in R^{C x D} (each row = u_r^carry, value dim) +``` + +The carry contribution to the output uses the *post-update* state, but since `S_in` is constant +within the chunk its output contribution is `O_carry = scale * (Qbar @ S_in)` with +`Qbar_r = A_r (.) q_r`. + +**(b) Intra-chunk token-token interactions (the delta-rule coupling):** + +Within the chunk, token `r`'s delta `d_r` depends on the writes of all earlier tokens `s < r` (and on +the carry). Build the **strictly-lower-triangular** decayed attention matrix between the later token's inflated read key `Kbar_r` and the earlier token's deflated write key `k~_s` (pairwise decay `A_r / A_s`): + +``` +T[r][s] = beta_r * ( Kbar_r . k~_s ) for s < r, else 0 in R^{C x C} +``` + +The delta-rule "un-mixing" is the classic `(I + tril(T,-1))^{-1}` solve (forward substitution over the +chunk, `C` sequential micro-steps but only on a `C x C` system, cheap and in shared memory). Let + +``` +W = (I + strict_tril(T))^{-1} in R^{C x C} +Dmat = W @ ( beta (.) (V - U_carry) ) in R^{C x D} (rows = d_r, the resolved deltas) +``` + +(`beta (.) V` is row-scaling `V` by `beta_r`; `U_carry` from (a).) `Dmat` rows are exactly the +per-token delta values `d_r` consistent with the sequential recurrence — now computed by two matmuls + +one small triangular solve instead of `C` rank-1 steps. + +**(c) Per-token output:** + +``` +O = O_carry + scale * tril( Q~ @ K~^T ) @ Dmat in R^{C x D} +``` +The `tril(Q~ K~^T)` term sums the intra-chunk writes that token `r` should see.
The output reads the +**post-update** state, so the current token's own write (`s == r`) must be included — use the +lower-triangle **including** the diagonal for this output term, while the `T` solve in (b) stays +**strictly** lower. Mapping the exact diagonal handling is the one subtlety to nail against the +reference (see section 6 validation). + +### 2.3 Inter-chunk state carry (the only sequential part) + +After the chunk, the new boundary state: + +``` +S_out = diag(A_{C-1}) S_in + Kw^T @ Dmat + = diag(A_{C-1}) S_in + sum_r ((A_{C-1}/A_r) (.) k_r) d_r^T +``` +with `Kw_r = (A_{C-1}/A_r) (.) k_r` stacked into `Kw in R^{C x D}` (each write key carried forward to +the chunk end). This is one `D x D` update per chunk. + +**Sequential depth = number of chunks = ceil(N/C).** With `N <= C` (verify block <= 16 and `C = 16`) +the whole verify is **a single chunk**: zero inter-chunk recurrence, everything is matmuls + one +`C x C` triangular solve. That is the win. + +--- + +## 3. CUDA kernel structure (`gated_delta_net_chunked_cuda`) + +One CUDA **block per (head, sequence)** — for verify `sequence` is fixed (n_seqs==1), so grid is +`(H, 1, 1)`. Each block owns the chunk's `C x D` tiles and the `D x D` carry state in shared memory. + +Tiling (for the verify regime: `C <= 32`, `D in {16,32,64,128}`): +- Shared mem holds: `S` (`D x D` f32), `K,Q,V,K~,Q~,Kbar,Qbar` chunk tiles (`C x D` each), `T`/`W` + (`C x C`), `Dmat` (`C x D`), `A` cumprods (`C x D` for KDA, `C` for scalar). For `D=128, C=16` that + is `128*128*4 = 64KB` for `S` alone — at the edge of the 48-96KB smem budget, so for `D=128` either + keep `S` in registers (sharded across the warp like the existing kernel) or cap `C` smaller / use + the host-decomposition fallback (3.2). For `D <= 64` everything fits comfortably. +- Threads: a 2D thread block, `D` lanes x `num_warps` (mirror the existing + `block_dims(min(warp,D), num_warps)`). 
Matmuls are done cooperatively; the `C x C` triangular solve + is done by a single warp (C <= 32 fits one warp) via forward substitution. + +Phases inside the kernel (single chunk; the multi-chunk loop wraps phases 2-6): +1. **Load + gate cumprod.** Load `g`, compute `a_r = exp(g_r)`, inclusive cumprod `A_r` along the + chunk **in f32 / log space** (Hillis-Steele scan across `C`). Build `k~,q~,Kbar,Qbar,Kw`. +2. **U_carry = Kbar @ S_in**, **O_carry = scale . (Qbar @ S_in)** — two `C x D . D x D` matmuls. +3. **T = strict_tril(beta (.) (Kbar K~^T))** — a `C x D . D x C` matmul, mask to strict lower. +4. **Solve `(I+T) Dmat = beta (.) (V - U_carry)`:** forward-substitution to get `Dmat = (I+T)^{-1} (beta (.) (V - U_carry))` + (C sequential micro-steps on the small `C x C` system, one warp). +5. **Output:** `O = O_carry + scale . tril(Q~ K~^T) @ Dmat`; write `O` rows to `attn_data` (same + `[S_v.H]`-strided layout as the sequential kernel) and, if `trace != nullptr`, materialize the + per-token state trace (see 3.3). +6. **Carry:** `S_out = diag(A_{C-1}) S_in + Kw^T @ Dmat`; write back transposed `M[j][i]`. + +### 3.1 Dispatch hook + +In `ggml/src/ggml-cuda/ggml-cuda.cu`, `GGML_OP_GATED_DELTA_NET` (~L2934) currently calls +`ggml_cuda_op_gated_delta_net`. Add inside that op (in `gated_delta_net.cu`'s +`ggml_cuda_op_gated_delta_net`) a guarded fast path: + +``` +if (n_seqs == 1 && n_tokens >= GDN_CHUNK_MIN && n_tokens <= GDN_CHUNK_MAX && S_v <= GDN_CHUNK_DMAX) + launch_gated_delta_net_chunked(...); // new path (verify block) +else + launch_gated_delta_net(...); // existing sequential path (prefill / single decode) +``` + +`GDN_CHUNK_MIN` ~ 2 (no point for a single token), `GDN_CHUNK_MAX` ~ 32, `GDN_CHUNK_DMAX` initially +64 (raise to 128 once the smem/register strategy for `D=128` is validated). The trace output and the +final-state writeback use the **same** dst layout (`[attn_scores | new_states]`) and the same +transposed state convention, so nothing downstream changes.
+ +### 3.2 Host-decomposition fallback + +If a single monolithic kernel is too much for a first cut, the same math maps onto existing ggml CUDA +ops as a host-side graph (per head, single chunk): `ggml_mul_mat` for `Kbar@S`, `Q~K~^T`, `Kw^T@Dmat`; +elementwise muls for gating; a tiny custom kernel only for the `C x C` triangular solve. Slower than +the fused kernel (extra global-memory round trips) but a correctness oracle and a quick path to a +working verify. The skeleton notes this decomposition. + +### 3.3 Trace compatibility (DFlash rewind) + +DFlash needs the **per-token** state `S_t` for partial-acceptance rewind (`src[6]` trace). The chunked +kernel does not naturally produce per-token `S_t` (it jumps chunk->chunk). Two options: +- **(preferred)** After computing `Dmat`, materialize `S_r = diag(A_r) S_in + Kw(->r)^T @ Dmat[0..r]` + for each `r` via a small cumulative pass (the prefix of the chunk update) and write the trace rows. + Costs `C` light steps but they're independent across `r` (can be a parallel segmented scan). +- **(fallback)** When `trace != nullptr`, route to the sequential kernel (it already writes the trace + near-free). Verify still benefits whenever the harness doesn't request a trace; partial-accept paths + pay the sequential cost. Start here, then implement the prefix-trace. + +--- + +## 4. Expected speedup + +- Sequential kernel verify cost ~ `N` sequential GDN steps (latency-bound: each step's rank-1 update + + two reductions depends on the previous). For the 24 GDN layers this is the dominant verify cost + and is why single-stream speedup caps at ~1.5-1.7x. +- Chunked: sequential **depth** drops to `ceil(N/C)`. For `N <= 16, C = 16` -> **depth 1**. The + remaining work is matmuls (`C x D x D`) + one `C x C` solve, which are throughput-bound and overlap + well; on a modern GPU the `C x D x D` matmuls for `C,D <= 128` are far below peak and hide behind + issue latency. 
+- Net: verify GDN latency goes from `O(N)` serial to `O(N/C)` serial + parallel intra-chunk. This is + the same structural change that lets SGLang reach ~3x — we expect the single-stream cap to move from + ~1.5-1.7x toward the ~2.5-3x regime, gated by how much of end-to-end time is GDN verify vs the + full-attention layers and sampling. +- The arithmetic *work* slightly increases (the `(I+T)^{-1}` solve + extra matmuls), but it converts + serial-dependent work into parallel work, which is the right trade for a latency-bound verify. + +--- + +## 5. Numerical-stability concerns + +- **fp32 accumulation everywhere** — mirrors the top-level CLAUDE.md lesson (fp16 mean-pool overflowed + to +/-inf on attention-sink channels and poisoned every downstream lstsq). All cumprods, matmul + accumulators, the `C x C` solve, and the state must accumulate in **f32** (the sequential kernel is + already all-f32; keep parity). Never down-cast the state or the gate products to fp16. +- **Cumulative gate product underflow/overflow.** `A_r[i] = prod a_s[i]` with `a_s = exp(g_s)`. Over a + chunk of 16 tokens with strongly negative `g` (heavy decay), `A_{C-1}` can underflow and the + deflated key `k~_r = k_r / A_r` can blow up — the classic instability the FLA chunked kernels guard. + Mitigations: (a) keep `C` modest (16) so the product spans few tokens; (b) work in **log space** for + the cumulative gate (`L_r[i] = sum_{s<=r} g_s[i]`, then `A_r = exp(L_r)`, and form ratios as + `exp(L_r - L_s)` rather than dividing two exponentials) — this is the numerically safe way to get + `A_r/A_s` and `A_{C-1}/A_r` without ever materializing a tiny denominator; (c) f32 throughout. +- **Triangular solve conditioning.** `(I + strict_tril(T))` is unit-lower-triangular, so always + invertible and forward-substitution is stable; just accumulate in f32. 
+- **Diagonal/self-token bookkeeping** is the main *correctness* (not stability) risk — the output reads + the **post-update** state, so the current token's own write must be included. Validate against the + reference (section 6) rather than reasoning it through once. + +--- + +## 6. Step-by-step plan to production-correct + validation + +1. **CPU oracle first.** Implement the chunked math as a second CPU function next to + `ggml_compute_forward_gated_delta_net_one_chunk` (or a standalone test harness) and assert it + reproduces the sequential CPU reference (f32, `|delta| < 1e-4` per element) on random + `q,k,v,g,beta,S_in` for both scalar and KDA gates, for `D in {16,32,64,128}` and + `C in {1,2,4,8,16}`. This nails the diagonal/self-token and the gate-ratio direction *before* any + CUDA. +2. **Single-chunk CUDA kernel** for `n_seqs==1`, `N==C`, `D <= 64`. Compare its `attn` output and + `S_out` against the sequential CUDA kernel on the same inputs (host-side max-abs diff `< 1e-3` f32). +3. **Multi-chunk loop** (`N` = a few chunks); re-check the inter-chunk carry matches sequential. +4. **Trace path** (3.3 preferred): verify the per-token trace rows equal the sequential kernel's trace + element-for-element (this is what DFlash rewind reads — must match exactly). +5. **D=128 strategy**: pick register-sharded `S` (like the existing kernel) or smem; re-validate. +6. **Wire dispatch** behind the `n_seqs==1 && N in [MIN,MAX] && D <= DMAX` guard; keep the sequential + path as the default so prefill/decode are untouched. Add an env/define kill-switch + (`GGML_CUDA_GDN_CHUNKED=0`) to fall back at runtime during bring-up. +7. **End-to-end**: run the Qwen3.5-4B DFlash verify on a real prompt, confirm accepted-token sequences + are identical to the sequential-verify build (greedy + fixed seed), then measure tok/s. Validation = + *identical accepted tokens* + improved verify latency. 
+ +Validation harness lives alongside the existing GDN tests (search `test-backend-ops` / +`gated_delta_net` test cases); add a chunked-vs-sequential equivalence case there. + +--- +## CORRECTION (validated by gdn_chunked_oracle.py, bitwise vs sequential, max err ~1e-13) + +The pairwise inter-token decay (s -> r) is **A_r / A_s**, NOT 1/(A_r·A_s). So the deflation must be +ASYMMETRIC: the LATER token r carries A_r (Kbar/Qbar = A⊙k, A⊙q), the EARLIER token s carries 1/A_s +(Ktil = k/A). The dot Kbar_r·Ktil_s = sum_i k_r k_s · A_r/A_s (bounded ≤1 for s compatible. + +2. **DFlash-specific nodes are capturable.** `ggml_set_rows(cross_dev,...)` (`dflash.cpp` L90), the + per-token conv/state trace `ggml_cpy` nodes (`qwen35.cpp` L301-317, L380-396), `ggml_argmax` + (`dflash.cpp` L227), and the top-k `argsort` verify path all map to normal CUDA ops with **no host + stream sync** (checked `set-rows.cu` / `argmax.cu` / `argsort.cu` — only `cudaMemcpyAsync` D2D, which + is capturable). => none disables capture. + +3. **Stable destinations / offsets.** `cross_dev`, `trace_s[il]`, `trace_r[il]` are persistent tensors + (allocated once in `dflash_cross_ctx` / `dflash_trace_buf`), so the trace/set_rows dst ptrs are + constant. The recurrent state write offset `kv_head * n_embd_s` (`qwen35.cpp` L395) is constant for a + single sequence (`get_head()` fixed for seq 0). => node props stable round-to-round. + +4. **The graph key is stable.** The CUDA graph is keyed by `cgraph->nodes[0]`. The verify ubatch is a + **constant** `block_size` tokens (the drafter always emits `block_size-1` drafts: + `speculative.cpp result.assign(block_size-1,0)`; `speculative-simple.cpp` L457-469). So + `llm_graph_result::can_reuse` holds (constant `n_tokens`/`n_outputs`/`cross`/samplers; recurrent + `head`/`rs_z` constant), and `llm_graph_result::reset()` reuses `buf_compute_meta` in place (same + `.data()` => tensors placement-allocated at the same offsets). 
=> `nodes[0]` is the same pointer + across rounds, even when a rebuild happens. + +5. **Double-buffering is a non-issue here.** `cur_copy` only flips in `ggml_backend_sched_alloc_graph` + (skipped on the reuse path), and a single-GPU DFlash target runs `pipeline_parallel=false => + n_copies=1`, so input-copy pointers don't alternate. + +Conclusion: on Ampere+ the verify graph already captures and stays warm (the `cuda_graph` object keyed +by the stable `nodes[0]` keeps `node_props` across rounds; eviction is 10 s, rounds are ms apart). The +warmup does NOT reset for the verify on an identical graph. + +## Root cause of the residual per-round CPU cost (and the whole-model -6%) + +`ggml_graph_view` zeroes the uid; the tail of `ggml_backend_sched_split_graph` then assigns a fresh +monotonic uid per split every call. The CUDA backend's fast-path +(`if (cgraph->uid != 0 && cgraph->uid == graph->uid) return false;`) can therefore only skip the +property walk when the *higher-level* graph reuse keeps `split_graph` from running at all. Any reuse +miss re-runs `split_graph`, bumps the uid, and forces the full walk. On the ~1800-node whole-model graph +that is the measured ~-6%; on the hundreds-of-nodes verify graph it is smaller but non-zero. + +## The existing fix, and what this change adds + +Existing at HEAD (`ggml/src/ggml-backend.cpp`): +- `struct ggml_backend_sched_split` carries `prev_uid` / `prev_sig`. +- The uid loop computes a per-slot topology signature; if it matches the previous round's, it reuses + `prev_uid` instead of minting a fresh one. `GGML_SCHED_STABLE_UID=0` opts out (on by default). +- Grown `splits` slots are zeroed after `realloc` so `prev_uid`/`prev_sig` start clean. + +This change (hardening only): +- The signature was `backend_id + n_nodes + nodes[0] + nodes[n-1]` (endpoints only). A "same count + + same endpoints but different middle" collision would let the backend reuse a **stale captured graph** + (a silent correctness bug). 
Strengthened it to also fold in a **strided sample of up to ~16 interior + node pointers**, making such a collision effectively impossible while staying O(1)-ish per split. +- Updated the in-code comment to match. + +Why safe: the uid is a pure optimization hint. A matching uid only skips a walk that would have found no +change anyway (signature matched on stable placement-allocated pointers); any mismatch falls back to the +full walk + recapture. The fast-path's `node_props.size() == n_nodes` assert holds because `n_nodes` is +in the signature. + +## Files changed by this investigation + +- `ggml/src/ggml-backend.cpp` — `ggml_backend_sched_split_graph()`: strengthen the per-split topology + signature (strided interior-node sampling); comment fix. No struct/ABI change beyond what HEAD already + had; no CUDA file touched. + +## What to validate on GPU (remote Ampere+ box; V100 needs `GGML_CUDA_GRAPHS_VOLTA`) + +1. Build with CUDA. On V100 also pass `GGML_CUDA_GRAPHS_VOLTA=n` (n >= verify node count, or 1). +2. Run `speculative-simple --spec dflash` for draft-max in {8, 12, 16}, comparing tokens/sec with vs + without the uid stabilization (`GGML_SCHED_STABLE_UID=0` disables). Expect the fix to remove the + per-round split walk -> higher t/s, and to make the larger draft blocks (12/16) viable toward the + SGLang accept_len ~6.6 regime (combined with the DESIGN.md chunked GDN kernel). +3. Debug build of `ggml-cuda.cu` (`-DCMAKE_BUILD_TYPE=Debug`): confirm + `GGML_LOG_DEBUG("CUDA Graph id %zu reused\n", ...)` fires every steady-state verify round. +4. Correctness: greedy verify output must be **token-identical** with and without + `GGML_SCHED_STABLE_UID` (a divergence would mean a signature collision — not expected after the + hardening). +5. Cross-check whole-model decode (no spec): the same stabilization should turn the prior ~-6% CUDA-graph + regression neutral/positive.
+ +GPU validation command (example): + +``` +GGML_SCHED_STABLE_UID=0 ./build/bin/llama-speculative-simple \ + -m target.gguf -md draft.gguf --spec dflash --draft-max 16 --draft-min 1 -n 256 -p "" +GGML_SCHED_STABLE_UID=1 ./build/bin/llama-speculative-simple \ + -m target.gguf -md draft.gguf --spec dflash --draft-max 16 --draft-min 1 -n 256 -p "" +``` diff --git a/GDN_CHUNKED_BRINGUP.md b/GDN_CHUNKED_BRINGUP.md new file mode 100644 index 000000000000..1a4f63afea1b --- /dev/null +++ b/GDN_CHUNKED_BRINGUP.md @@ -0,0 +1,170 @@ +# Chunked GDN verify — bring-up recipe (math validated, ggml-op decomposition) + +The chunked math is bitwise-correct (gdn_chunked_oracle.py vs sequential, err ~1e-13). It needs NO +hand-written CUDA kernel: ggml has cumsum + tri + solve_tri + mul_mat on BOTH CPU and CUDA. Build the +verify GDN (n_seqs==1, N=draft_max+1 ≤ ~16, single chunk) as a ggml subgraph; validate on the CPU +backend locally (build/bin/libggml*.dylib already built — write a standalone test linking it), then +it runs on CUDA for free. + +## Inputs (from ggml_gated_delta_net): q,k,v [S_v,H,N,1]; g [S_v,H,N,1] (kda); beta [1,H,N,1]; +## state S0 [S_v,S_v,H,1]. Output [S_v*H, N + S_v] (cols 0..N-1 = attn, cols N..N+S_v-1 = new state). + +## Op recipe (per the VALIDATED oracle; A_r[i]=prod_{s<=r} a_s, a=exp(g)): +1. Permute q,k,v,g to per-head token-matrices Xp [S_v, N, H] (ne0=dim i, ne1=token r, ne2=head). +2. A: put tokens on ne0 -> g2 [N, S_v, H]; L = ggml_cumsum(g2) over ne0 (tokens); A = exp(L); + permute A back to [S_v, N, H]. (cumsum is ne0-only, hence the shuffle.) +3. Kbar = A⊙kp ; Qbar = A⊙qp ; Ktil = kp ⊙ exp(-L_perm) (= kp/A, but form via exp(-L) to stay fp32-safe). +4. U_carry = mul_mat(S0[i,j,H], Kbar[i,r,H]) -> [j, r, H] (contracts i). O_carry = scale·mul_mat(S0, Qbar). +5. KK = mul_mat(Ktil[i,s,H], Kbar[i,r,H]) -> [s, r, H]; want T[r,s]=beta_r·KK[s,r]. + Tfull = beta(broadcast over s) ⊙ transpose(KK to [r,s,H]); T = ggml_tri(Tfull, LOWER strict). +6. 
rhs = beta ⊙ (vp_as[r,j] - U_carry[j,r]^T) -> shape [j? r?]; keep as [N, S_v, H] (r on ne0) for solve. + Dmat = ggml_solve_tri(A=I+T [N,N,H] unit-lower, B=rhs, left=true, lower=true, uni=true) -> [N, S_v, H]. +7. QK = ggml_tri(transpose(mul_mat(Ktil,Qbar)) to [r,s,H], LOWER_DIAG incl diagonal); + O = O_carry + scale·mul_mat(QK[s,r,H]?, Dmat[s,j,H]) -> [r,j,H] (contracts s). +8. S_out[i,j,H] = (A_end[i] ⊙ S0[i,j]) + mul_mat(Kw[i? ], Dmat) ; Kw_r = (A_end/A_r)⊙k_r. +9. Reassemble into the [S_v*H, N+S_v] output layout (attn cols = O reshaped, state cols = S_out). + +## Validation (do BEFORE wiring into the model): +- Standalone CPU test: random q,k,v,g,beta,S0; run ggml_gated_delta_net (sequential, reference) and + build_gated_delta_net_chunked; ggml_backend_cpu; compare max|diff| < 1e-4. Iterate layout bugs here + (mul_mat is a^T b contracting ne0; transposes/permutes are where bugs hide). +- Then bitwise-vs-sequential for the TRACE rows too (DFlash rewind needs per-token state; the chunked + path gives only the final S_out + per-token O, NOT per-token state -> for the rewind we still need + per-token states. EITHER also emit per-token states (S after each token = O_carry-style partial), OR + keep trace on the sequential path. RESOLVE THIS before shipping: the rewind/promote depends on the + per-token trace; chunked must reproduce it or the verify can't use trace+promote.) + +## Then: wire as the verify fast path (n_seqs==1, N small) behind a flag; bench draft-max 8/12/16 on +## Blackwell/H100; expect verify cost ~flat in N -> larger blocks affordable -> accept_len ~6 -> ~2.5-3x. + +## OPEN RISK (important): the DFlash rewind (trace+promote) needs per-TOKEN states. The chunked form +## naturally yields only the chunk-final state. Per-token states within the chunk can be recovered +## (S_t = decay(S0,t) + intra-chunk updates up to t) but that's extra work; OR run chunked for speed +## and the sequential trace only when a partial-accept rewind is actually needed. 
Must be designed. + +--- +## STATUS: chunked GDN ggml-graph VALIDATED (commit 9c1f082b8). Integration plan below. + +tests/test-gdn-chunked.cpp: chunked-vs-sequential ALL PASS (N=1..16, S_v=64/128, fp32). The graph is +portable (cumsum/tri/solve_tri/mul_mat on CPU+CUDA; the path that also extends Metal/Vulkan/WebGPU). + +### Wiring (next): +1. Lift build_chunked() from the test into a reusable builder, e.g. build_gated_delta_net_chunked() + in src/models/ (or a ggml helper), returning the SAME [S_v*H, N+S_v] output as ggml_gated_delta_net. +2. In src/models/qwen35.cpp build_layer_attn_linear: when (n_seqs==1 && n_seq_tokens>1 && verify), + call the chunked builder instead of ggml_gated_delta_net. Keep the fused sequential op for prefill + and single-token decode (chunked wins only for a multi-token block). +3. Gate behind a flag (e.g. cparams.gdn_chunked or env) so it's opt-in until GPU-validated. + +### TRACE / rewind (the one real design decision): +DFlash rewind (trace+promote) needs the per-TOKEN state S_t for the accepted position. The chunked +path yields only S_out (after all N). Resolution: keep the chunked path for the fast verify FORWARD +(attn + S_out); on a PARTIAL accept (acc accept_len ~6 -> + toward 2.5-3x (the SGLang regime). This is the payoff. + +### NUMERICAL: fp32 err grows mildly with N,D (1e-5 at N16/D128). On GPU keep fp32 accumulation in the +matmuls/solve. If a stronger-decay model underflows k/A, switch to the log-space ratio form (DESIGN.md). + +--- +## REASSEMBLY into the [S_v*H, N+S_v] combined output (the drop-in detail, derived + checked) +ggml_gated_delta_net's data = [attn region (S_v*H*N) | state region (S_v*S_v*H)], flat: + attn[h,t,j] at (t*H + h)*S_v + j -> column-major [S_v*H, N], row=h*S_v+j, col=t + state[h,i,j] at attn_elems + h*S_v*S_v + j*S_v + i -> EXACTLY a [S_v(i),S_v(j),H] contiguous tensor +So the chunked builder's cs ([i,j,H], already contiguous) IS the state region as-is. 
For attn: + O is [t, j, H] -> permute(2,0,1,3) -> [j, H, t] -> cont -> reshape [S_v*H, N]. +Combined = reshape( ggml_concat( reshape(O_attn,1D[S_v*H*N]), reshape(cs,1D[S_v*S_v*H]), dim0 ), + [S_v*H, N+S_v] ). Drop-in for ggml_gated_delta_net's result. + +## GQA: q/k have num_k_heads (ssm_n_group), v/g/beta/state have num_v_heads (ssm_dt_rank). In the +## builder: ggml_repeat q/k from Hk to H (interleaved h%Hk) BEFORE the per-head ops. VALIDATED. + +## DONE this session: chunked GDN ggml builder VALIDATED on CPU bitwise vs sequential, incl GQA +## (tests/test-gdn-chunked.cpp, ALL PASS N=1..16, S_v=64/128, H_v=4, H_k=1/2). It is portable +## (CUDA/Metal/Vulkan/WebGPU via the op kernels). REMAINING: (1) lift build_chunked into +## delta-net-base.cpp + reassembly above, gate by env, fall back to sequential when gdn_trace!=null; +## (2) GPU build + bitwise vs sequential + draft-max 8/12/16 speed (verify ~flat in N -> ~3x); +## (3) trace/rewind: compute S_acc on partial-accept from kept A/Ktil/Dmat (no per-token buffer). + +--- +## GPU BRING-UP RESULTS (eva01 V100, Qwen3.5-4B-Q8_0, 24 GDN + 8 attn) — DECISIVE, lever #3 verdict + +Ran the wired chunked path on CUDA. Three findings settle whether chunked GDN is a speedup lever: + +1. **Qwen3.5 GDN uses a SCALAR gate, not the vector (KDA) gate.** Gate diagnostic at the dispatch + site: `g->ne[0]=1, S_v=128`. The builder + tests/test-gdn-chunked.cpp were written for the + per-channel VECTOR gate (`g->ne[0]==S_v`), so the original `g->ne[0]==S_v` trigger NEVER fired on + Qwen3.5 — every "chunked vs seq" comparison before this was a silent no-op (identical numbers). + Fix: generalized the builder to accept the scalar gate (A=[1,N,H] broadcasts across S_v; AendB + width from A->ne[0]; full-size tensor first in every A-multiply so [1,N,H] repeats into [S_v,N,H]). + Now fires: `[GDN-CHUNKED] active: N=16 S_v=128 H=32 Hk=16 gate=scalar`, assert-free on CUDA. + +2. 
**Chunked is SLOWER than the fused CUDA kernel** (llama-bench pp, t/s, r=8): + | N (pp) | fused (off) | chunked (on) | + |--------|-------------:|-------------:| + | 8 | 514 | 413 | + | 16 | 830 | 672 | + | 32 | 1373 | 1129 | + | 64 | 1507 | 1343 | + CUDA already ships a native `GGML_OP_GATED_DELTA_NET` kernel (ggml-cuda/gated_delta_net.cu) that + the default multi-token path uses; the decomposed ggml graph (cumsum/exp/diag/tri/solve_tri + + several mul_mat) cannot beat it. **Worse, GDN is not the verify bottleneck at all**: baseline pp + throughput is identical whether the GDN op is the fused kernel or not — the 4B transformer's + attention+MLP matmuls dominate at N<=64, the GDN scan over <=16 tokens is negligible. + +3. **Single-chunk is numerically valid only for SMALL N.** The builder treats the whole block as ONE + chunk, so `Ainv = exp(-cumsum(g))` overflows for long sequences. A chat-template-inflated prefill + (~30-40 tok) already produces `?????` garbage under LLAMA_GDN_CHUNKED=1. It is correct only for a + true small verify block (<=16, as the CPU test covers); using it for prefill needs real multi-chunk + tiling (chunk into 16-64 blocks) which is NOT implemented. + +### VERDICT (lever #3): chunked GDN is a PORTABILITY artifact, NOT a speed lever. +- On CUDA it loses to the fused kernel and GDN isn't the bottleneck -> zero speedup toward SGLang's ~3x. +- Its only real use is backends WITHOUT a fused GDN kernel (WebGPU/Metal/Vulkan) AND only for small + blocks. For the speedup goal, the real lever is #2 (CUDA graphs for cheap large-block verify), + not GDN chunking. +- Code kept opt-in behind LLAMA_GDN_CHUNKED (default OFF, gdn_trace==null only) so normal serving is + untouched. Scalar-gate model-path correctness at small N is UNVERIFIED (llama-cli is confounded by + template prefill length + single-chunk overflow); do not claim it correct without a unit harness. 
+ +--- +## PORTABLE + CORRECT (multi-chunk tiling) — what it took to make chunked GDN actually work on CUDA + +Goal: a pure-ggml chunked GDN that is CORRECT on every backend (verify path for fused-kernel-less +backends: WebGPU/Metal/Vulkan). Two deliverables: a unit test proving correctness, and multi-chunk +tiling for sequences longer than one block. Both done; the road there found three real issues. + +1. **Scalar gate proven correct (CPU + CUDA).** tests/test-gdn-chunked.cpp now covers the scalar + (Gated DeltaNet, g->ne[0]==1) gate as well as the vector (KDA) gate, GQA, and runs on a SELECTABLE + backend (GDN_BACKEND=CUDA). Match vs ggml_gated_delta_net is ~1e-8 (fp32) on both CPU and CUDA. + +2. **Multi-chunk tiling.** build_delta_net_chunked tiles the N tokens into blocks of C, threads the + recurrent state forward, concats the per-block attn. For a verify block (N<=C) the loop runs once + = the original single-chunk path. Capped at N<=128 in the dispatch gate (LLAMA_GDN_CHUNKED_MAXN): + the loop unrolls ceil(N/C) subgraphs per layer, so a long prefill (the 512-token ubatch reserve) + would explode the static graph (GGML_ASSERT(obj_new) — ctx pool exhausted) -> fall through to the + fused op there. On CUDA the fused kernel is faster for long prefill anyway (see verdict above). + +3. **The test's random inputs hid TWO bugs that only the real model exposed:** + - **Unnormalized keys diverge.** Random k with ||k||^2 ~ S_v*sc^2 >> 2 violates the delta-rule + stability bound beta*||k||^2 < 2, so the TRUE recurrence diverges and ref vs chunked blow up in + the unstable directions for long N (looked like a tiling bug: error grew ~30%/token). Fix: + l2-normalize q,k in the test (delta-net normalizes them) -> stable, machine-eps match. 
+ - **Deflation precision sets a small max chunk size.** A=exp(+/-cumsum(g)) has wide dynamic range; + Ktil=k*exp(-cumsum), Kbar=k*exp(+cumsum) individually span many orders of magnitude even though + their product is bounded, so the KK matmul loses fp32 precision when a chunk is too long. The + test used mild gates (g~-0.2) and passed at C=16/32; the REAL model has strong-decay heads and + **garbles at C>=16, is clean at C<=8** (verified: greedy output identical to the fused baseline + at C=4 and C=8; "??????" at C=16/24). **Deployed default C=8.** The test now uses C=8 to match. + +### Net: chunked GDN is CORRECT on CUDA (greedy output bit-identical to the fused kernel across +short/medium/long prompts) and portable. Confirmed NOT a speedup on CUDA (the fused kernel is faster +and GDN isn't the bottleneck); its role is the verify path on backends without a fused GDN kernel. +Validation harness: tests/test-gdn-chunked.cpp (GDN_BACKEND=CPU|CUDA), all PASS on both. diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index d9f8aaec52fd..981851a68d11 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -769,6 +769,16 @@ struct ggml_backend_sched_split { int n_inputs; // graph view of this split struct ggml_cgraph graph; + + // stable-uid bookkeeping (see ggml_backend_sched_split_graph): when a re-split produces a + // byte-identical split (same backend, node count, and head/tail node pointers) we reuse the + // previous uid instead of minting a fresh one. A stable uid lets graph-capturing backends + // (CUDA graphs) hit their uid fast-path and skip the per-node property walk every round, which + // is the dominant CPU overhead for a stable speculative-verify graph. This is purely a hint: + // a matching uid only lets the backend skip a walk that would have found no change anyway, and + // a non-matching uid always falls back to the full walk, so correctness is unaffected. 
+ uint64_t prev_uid; // uid assigned on the previous split_graph for this slot (0 = none) + uint64_t prev_sig; // topology signature of the previous split for this slot }; struct ggml_backend_sched { @@ -1304,10 +1314,14 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra split->i_end = i; i_split++; if (i_split >= sched->splits_capacity) { + const int prev_capacity = sched->splits_capacity; sched->splits_capacity *= 2; sched->splits = (ggml_backend_sched_split *) realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split)); GGML_ASSERT(sched->splits != NULL); + // zero the newly grown slots so the stable-uid prev_uid/prev_sig start clean + memset(&sched->splits[prev_capacity], 0, + (sched->splits_capacity - prev_capacity) * sizeof(struct ggml_backend_sched_split)); } split = &sched->splits[i_split]; split->backend_id = node_backend_id; @@ -1481,8 +1495,50 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra } // set ids for all splits + // + // stable-uid optimization: if this split slot is byte-identical to the previous split_graph + // (same backend, node count, and a strided sample of node pointers - sufficient because nodes + // are placement-allocated at stable offsets in the reused compute buffer, so an unchanged + // topology re-materializes the exact same tensor pointers), reuse the previous uid. This lets + // graph-capturing backends (CUDA graphs) take their uid fast-path and skip the O(n_nodes) + // property walk + warmup churn every round - the dominant per-round CPU cost for a stable + // speculative-verify graph that re-splits because the higher-level graph-reuse check missed. + // Opt out with GGML_SCHED_STABLE_UID=0. Hint only: a reused uid merely skips a walk that would + // have found no change; any mismatch falls back to the full walk, so correctness is unaffected. 
+ static const bool stable_uid = [] { + const char * e = getenv("GGML_SCHED_STABLE_UID"); + return e == nullptr || atoi(e) != 0; // on by default + }(); for (int i = 0; i < sched->n_splits; ++i) { - sched->splits[i].graph.uid = ggml_graph_next_uid(); + struct ggml_backend_sched_split * split = &sched->splits[i]; + + uint64_t sig = 0; + if (stable_uid && split->graph.n_nodes > 0) { + // cheap topology signature: backend + node count + a strided sample of node pointers + // (head, tail, and up to ~16 interior nodes). Nodes are placement-allocated at stable + // offsets in the reused compute buffer, so an unchanged topology re-materializes the + // exact same pointers; ANY topology change shifts the count and/or these offsets. The + // interior sampling makes a same-count/same-endpoints-but-different-middle collision + // (which would let the backend reuse a stale captured graph) effectively impossible. + const int n = split->graph.n_nodes; + const int step = n > 16 ? n / 16 : 1; + sig = (uint64_t) (uint32_t) split->backend_id; + sig = sig * 1099511628211ull + (uint64_t) n; + for (int k = 0; k < n; k += step) { + sig = sig * 1099511628211ull + (uint64_t) (uintptr_t) split->graph.nodes[k]; + } + sig = sig * 1099511628211ull + (uint64_t) (uintptr_t) split->graph.nodes[n - 1]; + } + + if (stable_uid && sig != 0 && sig == split->prev_sig && split->prev_uid != 0) { + // identical to last time - keep the previous uid so the backend graph fast-path fires + split->graph.uid = split->prev_uid; + } else { + split->graph.uid = ggml_graph_next_uid(); + } + + split->prev_uid = split->graph.uid; + split->prev_sig = sig; } } diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 22727fd91dbf..d1ca702a808c 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -278,6 +278,16 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * cudaStream_t stream = ctx.stream(); + // 
[TAG_GDN_CHUNKED] DFlash verify fast path: chunk-parallel GDN (see DESIGN.md + gated_delta_net_chunked.cu). + // When implemented & validated, route the single-sequence verify block here to cut sequential depth N -> ceil(N/C): + // if (n_seqs == 1 && n_tokens >= GDN_CHUNK_MIN && n_tokens <= GDN_CHUNK_MAX + // && S_v <= GDN_CHUNK_DMAX && trace_d == nullptr) { + // if (kda) launch_gated_delta_net_chunked(q_d,k_d,v_d,g_d,b_d,s_d,dst_d,trace_d, S_v,H,n_tokens, sq1,sq2,sv1,sv2,sb1,sb2, scale,stream); + // else launch_gated_delta_net_chunked(q_d,k_d,v_d,g_d,b_d,s_d,dst_d,trace_d, S_v,H,n_tokens, sq1,sq2,sv1,sv2,sb1,sb2, scale,stream); + // return; + // } + // (left disabled until the skeleton is made compilable + bitwise-validated against this sequential kernel.) + if (kda) { launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, trace_d, S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, diff --git a/ggml/src/ggml-cuda/gated_delta_net.cuh b/ggml/src/ggml-cuda/gated_delta_net.cuh index 7375e81c0c36..f71dd68d5908 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cuh +++ b/ggml/src/ggml-cuda/gated_delta_net.cuh @@ -2,3 +2,17 @@ #include "ggml.h" void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +// Chunk-parallel GDN forward for the DFlash speculative VERIFY path (single sequence). +// Cuts the verify's sequential depth from N to ceil(N/C). See DESIGN.md (repo root) +// and gated_delta_net_chunked.cu. SKELETON — not yet wired into the dispatcher. 
+template <bool KDA>
+void launch_gated_delta_net_chunked(
+        const float * q_d, const float * k_d, const float * v_d,
+        const float * g_d, const float * b_d, const float * s_d,
+        float * dst_d, float * trace_d,
+        int64_t S_v, int64_t H, int64_t n_tokens,
+        int64_t sq1, int64_t sq2,
+        int64_t sv1, int64_t sv2,
+        int64_t sb1, int64_t sb2,
+        float scale, cudaStream_t stream);
diff --git a/ggml/src/ggml-cuda/gated_delta_net_chunked.cu b/ggml/src/ggml-cuda/gated_delta_net_chunked.cu
new file mode 100644
index 000000000000..74fbc15fe266
--- /dev/null
+++ b/ggml/src/ggml-cuda/gated_delta_net_chunked.cu
@@ -0,0 +1,325 @@
+#include "gated_delta_net.cuh"
+
+// NOTE: SKELETON ONLY - not yet compilable (pseudo-helpers/TODOs). Guarded out of the
+// build so the branch compiles; see DESIGN.md. Remove this #if 0 once the kernel is
+// brought up per the validation plan.
+#if 0
+
+// =============================================================================
+// Chunk-parallel Gated-DeltaNet forward for the DFlash speculative VERIFY path.
+//
+// SCOPE: single sequence (n_seqs == 1), a block of N tokens (N = draft_max + 1,
+//        up to ~32). Prefill and single-token decode keep using the existing
+//        SEQUENTIAL kernel in gated_delta_net.cu. This file is the CHUNK-PARALLEL
+//        variant that cuts the verify's sequential depth from N to ceil(N/C).
+//
+// STATUS: reviewable SKELETON. The structure, tiling, and math->thread mapping
+//         are concrete; the cooperative matmul/scan/solve bodies are sketched
+//         with pseudo-helpers (cg_* / smem tiles) and are NOT yet a compilable,
+//         numerically-verified kernel. See DESIGN.md (repo root), sections 2-6,
+//         for the derivation and the bring-up/validation plan.
+// +// MATH (mathematical S[i][j] convention; i = key dim, j = value dim, D = head dim): +// per token t, with decay a_t[i] = exp(g_t[i]) (KDA) or exp(g_t) (scalar): +// decay: S[i][j] <- a_t[i] * S[i][j] +// kv: u_t[j] = sum_i S[i][j] * k_t[i] +// delta: d_t[j] = (v_t[j] - u_t[j]) * beta_t +// update: S[i][j] += k_t[i] * d_t[j] +// out: o_t[j] = scale * sum_i S[i][j] * q_t[i] scale = 1/sqrt(D) +// +// chunked (chunk size C, S_in = state entering the chunk; see DESIGN.md 2.2): +// A_r[i] = prod_{s<=r} a_s[i] (inclusive cumprod, log space) +// Kbar_r = A_r (.) k_r Qbar_r = A_r (.) q_r (carry-read keys/queries) +// k~_r = k_r / A_r q~_r = q_r * A_r (deflated/inflated) +// Kw_r = (A_{C-1}/A_r) (.) k_r (carry-forward write key) +// U_carry = Kbar @ S_in [C x D] +// T = strict_tril( beta (.) (Kbar K~^T) ) [C x C] (T[r][s] = beta_r * Kbar_r . k~_s, i.e. decay A_r/A_s) +// Dmat = (I + T)^{-1} @ ( beta (.) (V - U_carry) ) [C x D] (fwd-subst solve) +// O = scale*(Qbar @ S_in) + scale*tril(Q~ K~^T) @ Dmat [C x D] +// S_out = diag(A_{C-1}) S_in + Kw^T @ Dmat [D x D] +// +// STATE STORAGE: same transposed layout as the sequential kernel and CPU ref: +// M[j*D + i] = S[i][j] (row j of M is column j of S, contiguous). +// +// IMPORTANT (see top-level CLAUDE.md fp16 pooling lesson): ALL accumulation here +// is f32. Gate cumprods are done in LOG space (sums of g) and ratios formed as +// exp(L_r - L_s) so we never divide by a tiny exp() denominator.
+// ============================================================================= + +// ---- tunables (mirror DESIGN.md 3.1) --------------------------------------- +#define GDN_CHUNK_C 16 // chunk size; a <=16-token verify block is ONE chunk +#define GDN_CHUNK_MIN 2 // below this, sequential decode wins +#define GDN_CHUNK_MAX 32 // largest verify block we accept on this path +#define GDN_CHUNK_DMAX 64 // start with D<=64; D=128 needs the register-shard variant + +// Pseudo-helpers used in the sketch (to be replaced with real cooperative impls): +// smem_matmul_AB(out, A, B, M, K, N) : out[MxN] = A[MxK] @ B[KxN], f32 accum, block-cooperative +// smem_matmul_ABt(out, A, B, M, K, N): out[MxN] = A[MxK] @ B[NxK]^T +// tri_mask_strict(M, C) : zero the upper triangle incl. diagonal +// tri_mask_incl(M, C) : zero the strict upper triangle (keep diagonal) +// warp_fwd_subst(W_or_inplace, T, C) : solve (I + strict_tril(T)) X = RHS by forward substitution + +// ----------------------------------------------------------------------------- +// One CUDA block per (head, sequence). For verify n_seqs==1 -> grid (H,1,1). +// Template on S_v (=D) and KDA exactly like the sequential kernel so the same +// dispatch switch can pick the instantiation. 
+// ----------------------------------------------------------------------------- +template +__global__ void gated_delta_net_chunked_cuda( + const float * q, // [D, H, T] (key dim, head, token) strides sq* + const float * k, // [D, H, T] + const float * v, // [D, H, T] (value dim, head, token) + const float * g, // [1|D, H, T] gate (scalar or KDA vector over key dim) + const float * beta, // [1, H, T] + const float * curr_state, // [D, D, H] incoming state S_in (transposed M[j][i]) + float * dst, // [attn_scores | new_states] (same layout as sequential kernel) + float * trace, // optional per-token state trace (n_seqs==1), may be nullptr + int64_t H, + int64_t n_tokens, // N (the verify block length) + int64_t sq1, int64_t sq2, // q/k strides (floats): sq1 over head, sq2 over token + int64_t sv1, int64_t sv2, // v strides + int64_t sb1, int64_t sb2, // beta/g base strides + float scale) { + + const int h_idx = blockIdx.x; // head this block owns + const int lane = threadIdx.x; // 0..D-1 (value/key column) + const int warp = threadIdx.y; // 0..num_warps-1 + + // ---- shared-memory tiles (DESIGN.md 3, "Tiling") ------------------------ + // For D<=64, C=16 these fit in <=48KB. For D=128 use the register-shard + // variant for S (like the sequential kernel) instead of smem S. 
+ __shared__ float s_S [D][D]; // incoming/outgoing state (M[j][i] = S[i][j]) + __shared__ float s_K [GDN_CHUNK_C][D]; // raw chunk tiles + __shared__ float s_Q [GDN_CHUNK_C][D]; + __shared__ float s_V [GDN_CHUNK_C][D]; + __shared__ float s_L [GDN_CHUNK_C][D]; // cumulative LOG gate L_r[i] = sum_{s<=r} g_s[i] (KDA) + // (scalar gate: column 0 used, broadcast over i) + __shared__ float s_beta [GDN_CHUNK_C]; + __shared__ float s_Ucar [GDN_CHUNK_C][D]; // U_carry + __shared__ float s_T [GDN_CHUNK_C][GDN_CHUNK_C]; // intra-chunk coupling / solve workspace + __shared__ float s_Dmat [GDN_CHUNK_C][D]; // resolved per-token deltas d_r + __shared__ float s_O [GDN_CHUNK_C][D]; // outputs + + const int C = (int) (n_tokens < GDN_CHUNK_C ? n_tokens : GDN_CHUNK_C); + + // base pointers for this (head) — n_seqs==1 so sequence offset is 0 + const float * q_h = q + h_idx * sq1; + const float * k_h = k + h_idx * sq1; + const float * v_h = v + h_idx * sv1; + const float * gb_base = (const float *) nullptr; // gate/beta offset computed per token below + const int64_t gb_h = h_idx * sb1; + + float * attn_data = dst + h_idx * D; // [.. + token*D*H], value rows + const int64_t attn_score_elems = (int64_t) D * H * n_tokens; // n_seqs==1 + float * state_out = dst + attn_score_elems + (int64_t) h_idx * D * D; + + // ========================================================================= + // Outer loop over chunks. Sequential DEPTH = ceil(N/C). For N<=C this runs once. + // S_in for chunk 0 is curr_state; for later chunks it's the previous S_out. 
+ // ========================================================================= + // load S_in (transposed) into s_S + for (int j = warp; j < D; j += blockDim.y) { + s_S[j][lane] = curr_state[(int64_t) (h_idx * D + j) * D + lane]; + } + __syncthreads(); + + for (int chunk_base = 0; chunk_base < n_tokens; chunk_base += GDN_CHUNK_C) { + const int cc = (int) min((int64_t) GDN_CHUNK_C, n_tokens - chunk_base); + + // -- Phase 1: load chunk tiles + cumulative LOG gate (Hillis-Steele scan) -- + // Load k_r, q_r, v_r, beta_r, g_r for r=0..cc-1 into smem. Then prefix-sum + // g over r (log space) -> s_L[r][i] = sum_{s<=r} g_s[i]. scalar gate: + // s_L[r][0] = sum_{s<=r} g_s, broadcast at use sites. + for (int r = warp; r < cc; r += blockDim.y) { + const int t = chunk_base + r; + s_K[r][lane] = k_h[t * sq2 + lane]; + s_Q[r][lane] = q_h[t * sq2 + lane]; + s_V[r][lane] = v_h[t * sv2 + lane]; + const int64_t gb = gb_h + (int64_t) t * sb2; + if (lane == 0) s_beta[r] = beta[gb]; + // KDA gate is a length-D vector over key dim; scalar gate is length 1 + s_L[r][lane] = KDA ? g[gb * D + lane] : (lane == 0 ? 
g[gb] : 0.0f); + } + __syncthreads(); + // inclusive prefix sum of s_L along r (one warp marches r; cheap, C<=32) + // gdn_prefix_sum_logspace(s_L, cc, D, KDA); // <- TODO real scan + __syncthreads(); + + // Convenience: A_r[i] = exp(s_L[r][i]) + // A_last[i] = exp(s_L[cc-1][i]) + // ratio(r,i) = exp(s_L[cc-1][i] - s_L[r][i]) (= A_last/A_r, carry-forward) + // Build the derived keys/queries on the fly inside the matmuls below to + // avoid extra smem; shown here named for clarity: + // Kbar_r[i] = exp(s_L[r][i]) * s_K[r][i] + // Qbar_r[i] = exp(s_L[r][i]) * s_Q[r][i] + // k~_r[i] = exp(-s_L[r][i]) * s_K[r][i] + // q~_r[i] = exp(+s_L[r][i]) * s_Q[r][i] (== Qbar; same for query) + // Kw_r[i] = ratio(r,i) * s_K[r][i] + + // -- Phase 2: U_carry = Kbar @ S_in ; O_carry = scale*(Qbar @ S_in) -- + // s_Ucar[r][j] = sum_i Kbar_r[i] * S[i][j] + // = sum_i (exp(L[r][i]) * s_K[r][i]) * s_S[j][i] (M transposed!) + for (int r = warp; r < cc; r += blockDim.y) { + float acc = 0.0f; + for (int i = 0; i < D; ++i) { + const float Kbar = expf(s_L[r][i]) * s_K[r][i]; + acc += Kbar * s_S[lane][i]; // s_S[j][i], here j=lane (value dim column) + } + s_Ucar[r][lane] = acc; + // O_carry accumulates into s_O below (Qbar @ S_in == same with q) + float oacc = 0.0f; + for (int i = 0; i < D; ++i) { + const float Qbar = expf(s_L[r][i]) * s_Q[r][i]; + oacc += Qbar * s_S[lane][i]; + } + s_O[r][lane] = scale * oacc; // O_carry; intra-chunk term added in Phase 5 + } + __syncthreads(); + + // -- Phase 3: T[r][s] = beta_r * (k~_r . k~_s), strict lower triangle -- + // k~_r[i] = exp(-L[r][i]) * s_K[r][i]. Build [C x C], mask s>=r to 0. 
+ for (int r = warp; r < cc; r += blockDim.y) { + // each lane handles a subset of columns s + for (int s = lane; s < cc; s += blockDim.x) { + if (s < r) { + float dot = 0.0f; + for (int i = 0; i < D; ++i) { + const float kr = expf(-s_L[r][i]) * s_K[r][i]; + const float ks = expf(-s_L[s][i]) * s_K[s][i]; + dot += kr * ks; + } + s_T[r][s] = s_beta[r] * dot; + } else { + s_T[r][s] = 0.0f; // strict lower only + } + } + } + __syncthreads(); + + // -- Phase 4: solve Dmat = (I + strict_tril(T))^{-1} @ RHS, RHS = beta(.)(V - U_carry) -- + // Unit-lower-triangular -> forward substitution, C sequential micro-steps + // on the C x C system (one warp). RHS lives in s_Dmat initially. + for (int r = warp; r < cc; r += blockDim.y) { + s_Dmat[r][lane] = s_beta[r] * (s_V[r][lane] - s_Ucar[r][lane]); // RHS row r + } + __syncthreads(); + // forward substitution: for r = 0..cc-1: Dmat[r] -= sum_{sr)^T @ Dmat[0..r] per r + // and write each into trace[(t*H + h)*D*D ...]. Sketched as a prefix pass; + // omitted in this skeleton (start by routing trace!=nullptr to sequential). + (void) trace; + } + + // -- final state writeback (transposed layout, same as sequential kernel) -- + for (int j = warp; j < D; j += blockDim.y) { + state_out[(int64_t) j * D + lane] = s_S[j][lane]; + } +} + +// ----------------------------------------------------------------------------- +// Host launcher. Mirrors launch_gated_delta_net in gated_delta_net.cu but +// uses a (H,1,1) grid with one block per head. 
Called from the guarded fast path +// in ggml_cuda_op_gated_delta_net (see DESIGN.md 3.1): +// +// if (n_seqs == 1 && n_tokens >= GDN_CHUNK_MIN && n_tokens <= GDN_CHUNK_MAX +// && S_v <= GDN_CHUNK_DMAX && trace_d == nullptr) +// launch_gated_delta_net_chunked(...); +// else +// launch_gated_delta_net(...); // existing sequential kernel +// ----------------------------------------------------------------------------- +template +void launch_gated_delta_net_chunked( + const float * q_d, const float * k_d, const float * v_d, + const float * g_d, const float * b_d, const float * s_d, + float * dst_d, float * trace_d, + int64_t S_v, int64_t H, int64_t n_tokens, + int64_t sq1, int64_t sq2, + int64_t sv1, int64_t sv2, + int64_t sb1, int64_t sb2, + float scale, cudaStream_t stream) { + const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; + const int num_warps = 4; + dim3 grid_dims((unsigned) H, 1, 1); + dim3 block_dims((unsigned) (warp_size <= S_v ? warp_size : S_v), num_warps, 1); + + switch (S_v) { + case 16: + gated_delta_net_chunked_cuda<16, KDA><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, trace_d, H, n_tokens, + sq1, sq2, sv1, sv2, sb1, sb2, scale); + break; + case 32: + gated_delta_net_chunked_cuda<32, KDA><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, trace_d, H, n_tokens, + sq1, sq2, sv1, sv2, sb1, sb2, scale); + break; + case 64: + gated_delta_net_chunked_cuda<64, KDA><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, trace_d, H, n_tokens, + sq1, sq2, sv1, sv2, sb1, sb2, scale); + break; + // case 128: needs the register-shard S variant (smem S is 64KB); see DESIGN.md 3. 
+ default: + GGML_ABORT("gated_delta_net_chunked: unsupported S_v (use sequential path)"); + break; + } +} + +// explicit instantiations so the dispatcher in gated_delta_net.cu can link them +template void launch_gated_delta_net_chunked(const float*,const float*,const float*,const float*,const float*,const float*,float*,float*,int64_t,int64_t,int64_t,int64_t,int64_t,int64_t,int64_t,int64_t,int64_t,float,cudaStream_t); +template void launch_gated_delta_net_chunked(const float*,const float*,const float*,const float*,const float*,const float*,float*,float*,int64_t,int64_t,int64_t,int64_t,int64_t,int64_t,int64_t,int64_t,int64_t,float,cudaStream_t); + +#endif // skeleton guard diff --git a/ggml/src/ggml-cuda/gdn_chunked_oracle.py b/ggml/src/ggml-cuda/gdn_chunked_oracle.py new file mode 100644 index 000000000000..e80481ab7645 --- /dev/null +++ b/ggml/src/ggml-cuda/gdn_chunked_oracle.py @@ -0,0 +1,67 @@ +import numpy as np +np.random.seed(0) + +# Gated DeltaNet reference (matches ggml CPU ops.cpp gated_delta_net_one_chunk). +# State S is D x D, S[i,j] (i=key dim, j=value dim). Per token (vector gate / KDA): +# S <- diag(a_t) @ S (a_t[i] = exp(g_t[i])) [decay rows by a] +# u_t[j] = sum_i S[i,j] k_t[i] [readout on DECAYED state] +# delta_t[j] = beta_t (v_t[j] - u_t[j]) +# S[i,j] += k_t[i] delta_t[j] [rank-1 update] +# o_t[j] = scale * sum_i S[i,j] q_t[i] [POST-update readout] +def sequential(q, k, v, g, beta, S0, scale): + N, D = q.shape + S = S0.astype(np.float64).copy() + O = np.zeros((N, D)) + for t in range(N): + a = np.exp(g[t]) # (D,) decay per key-dim i + S = (a[:, None]) * S # decay rows + u = S.T @ k[t] # (D,) over j + delta = beta[t] * (v[t] - u) # (D,) + S = S + np.outer(k[t], delta) # rank-1 + O[t] = scale * (S.T @ q[t]) # post-update + return O, S + +# Chunked (single chunk = whole block), per agent B's design. Inclusive cumulative +# decay A_r[i] = prod_{s<=r} a_s[i]. Deflate by A to factor the decay out. 
+def chunked(q, k, v, g, beta, S0, scale): + N, D = q.shape + S0 = S0.astype(np.float64) + a = np.exp(g.astype(np.float64)) # (N,D) + A = np.cumprod(a, axis=0) # (N,D) inclusive cumulative decay + Kbar = A * k # (N,D) "later token" (carries A_r) + Qbar = A * q + Ktil = k / A # (N,D) "earlier token" (carries 1/A_s) + # carry from incoming state + U_carry = Kbar @ S0 # (N,D) over j + O_carry = scale * (Qbar @ S0) + # pairwise decay s->r is A_r/A_s => Kbar_r . Ktil_s (bounded for s forward substitution) + Dmat = np.linalg.solve(np.eye(N) + T, rhs) + # intra-chunk output: lower-tri (incl diagonal) of (Qbar Ktil^T) + QK = np.tril(Qbar @ Ktil.T, k=0) # (N,N) + O = O_carry + scale * (QK @ Dmat) + # carry-out state: S_out = diag(A_{N-1}) S0 + Kw^T @ Dmat, Kw_r = (A_{N-1}/A_r) k_r + Aend = A[-1] # (D,) + Kw = (Aend[None, :] / A) * k # (N,D) + S_out = Aend[:, None] * S0 + Kw.T @ Dmat + return O, S_out + +for trial in range(5): + N = np.random.randint(2, 17) # block up to 16 + D = 64 + q = np.random.randn(N, D)*0.5 + k = np.random.randn(N, D)*0.5 + v = np.random.randn(N, D)*0.5 + g = -np.abs(np.random.randn(N, D))*0.1 # gates: log-decay <=0 (a<=1) + beta = np.random.rand(N) + S0 = np.random.randn(D, D)*0.3 + scale = 1.0/np.sqrt(D) + Os, Ss = sequential(q,k,v,g,beta,S0,scale) + Oc, Sc = chunked(q,k,v,g,beta,S0,scale) + eO = np.abs(Os-Oc).max() + eS = np.abs(Ss-Sc).max() + print(f"trial {trial}: N={N} D={D} | max|dO|={eO:.2e} max|dS|={eS:.2e} | {'OK' if max(eO,eS)<1e-9 else 'MISMATCH'}") diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index cb78b4067e61..dcbe0149b743 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -397,6 +397,20 @@ std::pair llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + // chunk-parallel GDN for a multi-token block (DFlash 
verify): portable (pure ggml ops) verify of a + // small block. Opt-in via LLAMA_GDN_CHUNKED=1. Only when no per-token trace is requested (the + // rewind needs per-token state, see GDN_CHUNKED_BRINGUP.md), n_seqs==1, vector/scalar gate, N>1. + // Capped to small N: the tiling unrolls ceil(N/16) subgraphs per layer, so a long prefill (the + // 512-token ubatch reserve) would explode the static graph -> fall through to the fused op there. + static const bool gdn_chunked = getenv("LLAMA_GDN_CHUNKED") && + std::string(getenv("LLAMA_GDN_CHUNKED")) != "0"; + static const int64_t gdn_chunked_maxn = + getenv("LLAMA_GDN_CHUNKED_MAXN") ? atoll(getenv("LLAMA_GDN_CHUNKED_MAXN")) : 128; + if (gdn_chunked && gdn_trace == nullptr && n_seqs == 1 && n_tokens > 1 && n_tokens <= gdn_chunked_maxn + && (g->ne[0] == S_v || g->ne[0] == 1)) { + return build_delta_net_chunked(q, k, v, g, b, s, il); + } + ggml_tensor * result; if (gdn_trace != nullptr) { // per-token state trace requested (DFlash speculative rewind on recurrent targets) @@ -427,6 +441,100 @@ std::pair llm_build_delta_net_base::build_delta_ne return {output, new_state}; } +// One chunk of the chunk-parallel GDN. Returns attn O [token,S_v,H] and carry-out state S_out +// [i,j,H] (both 3D, pre-reshape) so the tiling wrapper can concat tokens and thread state. Math +// validated bitwise vs ggml_gated_delta_net in tests/test-gdn-chunked.cpp (vector+scalar gate, GQA). 
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_one_chunk(
+        ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
+        ggml_tensor * g, ggml_tensor * b, ggml_tensor * s) {
+    const int64_t S_v = v->ne[0];
+    const int64_t H   = v->ne[1];
+    const int64_t N   = v->ne[2];
+    const int64_t Hk  = q->ne[1];
+    const float scale = 1.0f/sqrtf((float)S_v);
+
+    auto toDNH = [&](ggml_tensor * x){ return ggml_cont(ctx0, ggml_permute(ctx0, x, 0,2,1,3)); }; // [S,Hx,N]->[S,N,Hx]
+    ggml_tensor * qp = toDNH(q), * kp = toDNH(k), * vp = toDNH(v), * gp = toDNH(g);
+    if (Hk != H) { // GQA: broadcast q/k heads Hk->H (interleaved)
+        ggml_tensor * tgt = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, S_v, N, H);
+        qp = ggml_repeat(ctx0, qp, tgt);
+        kp = ggml_repeat(ctx0, kp, tgt);
+    }
+    ggml_tensor * gN   = ggml_cont(ctx0, ggml_permute(ctx0, gp, 1,0,2,3)); // [N,S_v,H]
+    ggml_tensor * Lp   = ggml_cont(ctx0, ggml_permute(ctx0, ggml_cumsum(ctx0, gN), 1,0,2,3)); // [S_v,N,H]
+    ggml_tensor * A    = ggml_exp(ctx0, Lp);
+    ggml_tensor * Ainv = ggml_exp(ctx0, ggml_neg(ctx0, Lp));
+    // full-size tensor first so a scalar gate's A=[1,N,H] broadcasts into kp/qp=[S_v,N,H]
+    ggml_tensor * Kbar = ggml_mul(ctx0, kp, A);
+    ggml_tensor * Qbar = ggml_mul(ctx0, qp, A);
+    ggml_tensor * Ktil = ggml_mul(ctx0, kp, Ainv);
+    ggml_tensor * betaNH = ggml_cont(ctx0, ggml_permute(ctx0, b, 0,2,1,3)); // [1,N,H]
+    ggml_tensor * Ucar = ggml_mul_mat(ctx0, s, Kbar); // [j,r,H]
+    ggml_tensor * Ocar = ggml_scale(ctx0, ggml_mul_mat(ctx0, s, Qbar), scale);
+    ggml_tensor * rhs  = ggml_mul(ctx0, ggml_sub(ctx0, vp, Ucar), betaNH); // [S_v(j),N(r),H]
+    ggml_tensor * KK   = ggml_mul_mat(ctx0, Ktil, Kbar); // [s,r,H]
+    ggml_tensor * Tlo  = ggml_tri(ctx0, ggml_mul(ctx0, KK, betaNH), GGML_TRI_TYPE_LOWER);
+    // identity [N,N] (solve_tri needs an explicit unit diagonal): ones via exp(0*beta), then diag
+    ggml_tensor * b0   = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_view_2d(ctx0, betaNH, 1, N, betaNH->nb[1], 0)));
+    ggml_tensor * ones = ggml_exp(ctx0, ggml_scale(ctx0, b0, 0.0f)); // [N,1] all-ones
+    ggml_tensor * Imat = ggml_diag(ctx0, ones); // [N,N]
+    ggml_tensor * ITm  = ggml_add(ctx0, Tlo, Imat); // I+T, bcast over H
+    ggml_tensor * Dmat = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_solve_tri(ctx0, ITm, rhs, true, true, false))); // [N(token),S_v(j),H]
+    ggml_tensor * QKlo = ggml_tri(ctx0, ggml_mul_mat(ctx0, Ktil, Qbar), GGML_TRI_TYPE_LOWER_DIAG);
+    ggml_tensor * intra= ggml_scale(ctx0, ggml_mul_mat(ctx0, QKlo, Dmat), scale); // [t,j,H]
+    ggml_tensor * O    = ggml_add(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, Ocar)), intra); // [t,j,H]
+    ggml_tensor * AendB= ggml_cont(ctx0, ggml_view_3d(ctx0, A, A->ne[0], 1, H, A->nb[1], A->nb[2], (N-1)*A->nb[1]));
+    ggml_tensor * S0dec= ggml_mul(ctx0, s, AendB);
+    ggml_tensor * Kw   = ggml_mul(ctx0, kp, ggml_mul(ctx0, Ainv, AendB));
+    ggml_tensor * upd  = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, Kw)), Dmat);
+    ggml_tensor * S_out= ggml_add(ctx0, S0dec, upd); // [i,j,H]
+    return {O, S_out};
+}
+
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_chunked(
+        ggml_tensor * q, ggml_tensor * k, ggml_tensor * v,
+        ggml_tensor * g, ggml_tensor * b, ggml_tensor * s, int il) {
+    // Chunk-parallel GDN. Pure ggml ops -> portable (CUDA/Metal/Vulkan/WebGPU) for backends without a
+    // fused GDN kernel. The N tokens are tiled into blocks of C, carrying the recurrent state forward,
+    // so Ainv=exp(-cumsum(g)) stays bounded per block: single-chunk over a long prefill overflows fp32.
+    // For a verify block (N<=C) the loop runs once -> identical to the original single-chunk path.
+    const int64_t S_v = v->ne[0];
+    const int64_t H   = v->ne[1];
+    const int64_t N   = v->ne[2];
+    const int64_t Hk  = q->ne[1];
+    // Gate may be per-channel (KDA vector, g->ne[0]==S_v) or per-head scalar (Gated DeltaNet, ==1).
+    GGML_ASSERT((g->ne[0] == S_v || g->ne[0] == 1) && "chunked GDN gate must be vector(S_v) or scalar(1)");
+    GGML_ASSERT(v->ne[3] == 1 && "chunked GDN path is n_seqs==1 only");
+
+    // Block size. The deflation A=exp(+/-cumsum(g)) has a wide dynamic range; strong-decay heads
+    // overflow fp32 precision when a chunk is too long (empirically garbles at C>=16 on Qwen3.5,
+    // clean at C<=8). 8 is the safe default; override with LLAMA_GDN_CHUNK_SIZE.
+    int64_t C = 8;
+    if (const char * e = getenv("LLAMA_GDN_CHUNK_SIZE")) { int64_t c = atoll(e); if (c >= 1) C = c; }
+
+    if (getenv("LLAMA_GDN_CHUNKED_VERBOSE")) {
+        static bool once = false;
+        if (!once) { once = true; fprintf(stderr, "[GDN-CHUNKED] active: N=%lld C=%lld chunks=%lld S_v=%lld H=%lld Hk=%lld gate=%s\n",
+            (long long)N, (long long)C, (long long)((N+C-1)/C), (long long)S_v, (long long)H, (long long)Hk, g->ne[0]==1?"scalar":"vector"); }
+    }
+
+    ggml_tensor * S = s; // carried recurrent state [S_v,S_v,H,1]
+    ggml_tensor * O_full = nullptr; // attn output, concatenated over tokens [token,S_v,H]
+    for (int64_t start = 0; start < N; start += C) {
+        const int64_t cn = std::min(C, N - start);
+        auto slc = [&](ggml_tensor * x){
+            return ggml_view_4d(ctx0, x, x->ne[0], x->ne[1], cn, 1, x->nb[1], x->nb[2], x->nb[3], start*x->nb[2]);
+        };
+        auto oc = build_delta_net_one_chunk(slc(q), slc(k), slc(v), slc(g), slc(b), S);
+        O_full = O_full ? ggml_concat(ctx0, O_full, oc.first, 0) : oc.first; // concat tokens on ne0
+        S = ggml_reshape_4d(ctx0, oc.second, S_v, S_v, H, 1);
+    }
+    ggml_tensor * output = ggml_reshape_4d(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, O_full, 2,0,1,3)), S_v, H, N, 1);
+    ggml_tensor * new_state = S;
+    cb(output, LLAMA_TENSOR_NAME_FGDN_CH, il);
+    return {output, new_state};
+}
+
 std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net(
         ggml_tensor * q,
         ggml_tensor * k,
diff --git a/src/models/models.h b/src/models/models.h
index e7efcc4823af..097fc6063940 100644
--- a/src/models/models.h
+++ b/src/models/models.h
@@ -55,6 +55,27 @@ struct llm_build_delta_net_base : public llm_graph_context {
             ggml_tensor * s,
             int il);
 
+    // chunk-parallel GDN for a multi-token block (DFlash verify): chunked delta-rule built from
+    // ggml ops (cumsum/tri/solve_tri/mul_mat) - cheap verify of an N-token block + portable.
+    // Returns {output [S_v,H_v,N,1], new_state [S_v,S_v,H_v,1]}. n_seqs==1; vector (KDA) or scalar gate.
+    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_chunked(
+            ggml_tensor * q,
+            ggml_tensor * k,
+            ggml_tensor * v,
+            ggml_tensor * g,
+            ggml_tensor * b,
+            ggml_tensor * s,
+            int il);
+
+    // one block of the tiled chunk-parallel GDN; returns attn [token,S_v,H] and state [i,j,H]
+    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_one_chunk(
+            ggml_tensor * q,
+            ggml_tensor * k,
+            ggml_tensor * v,
+            ggml_tensor * g,
+            ggml_tensor * b,
+            ggml_tensor * s);
+
     // choose one of two implementations above based on the number of tokens
     std::pair<ggml_tensor *, ggml_tensor *> build_delta_net(
             ggml_tensor * q,
diff --git a/tests/test-gdn-chunked.cpp b/tests/test-gdn-chunked.cpp
new file mode 100644
index 000000000000..e64080b0b6eb
--- /dev/null
+++ b/tests/test-gdn-chunked.cpp
@@ -0,0 +1,235 @@
+// Standalone CPU validation: chunked Gated-DeltaNet (ggml-op decomposition) vs the reference
+// sequential ggml_gated_delta_net. Build: see the compile cmd at the bottom of this file.
+// Goal: bitwise-ish match (max|diff| < 1e-3) so the chunked path can replace the sequential +// GDN kernel on the DFlash verify (block of N tokens), affording larger blocks -> ~3x. +#include "ggml.h" +#include "ggml-cpu.h" +#include "ggml-backend.h" +#include +#include +#include +#include +#include +#include + +// Leaf inputs are created in a no_alloc context (so the graph can run on CUDA); their host data is +// staged here and uploaded with ggml_backend_tensor_set after the backend allocates the graph. +struct Pending { ggml_tensor * t; std::vector data; }; +static std::vector g_pending; + +static ggml_tensor * rnd(ggml_context * c, int64_t a,int64_t b,int64_t d,int64_t e, std::mt19937 & g, float sc, float bias=0.f){ + ggml_tensor * t = ggml_new_tensor_4d(c, GGML_TYPE_F32, a,b,d,e); + std::vector h(ggml_nelements(t)); + std::normal_distribution N(0,1); + for (auto & x : h) x = N(g)*sc + bias; + g_pending.push_back({t, std::move(h)}); + return t; +} + +// L2-normalize each ne0 vector to unit norm (delta-net normalizes q/k). Without this, random keys +// have ||k||^2 ~ S_v*sc^2 >> 2, so beta*||k||^2 violates the delta-rule stability bound and the TRUE +// recurrence diverges -> ref and chunked blow up in the unstable directions for long sequences. +static void l2norm_rows(ggml_tensor * t){ + for (auto & pd : g_pending) if (pd.t == t) { + const int64_t S = t->ne[0]; const int64_t rows = (int64_t)pd.data.size()/S; + for (int64_t r=0;r0 ? 
(float)(1.0/sqrt(n)) : 0.f; for(int64_t i=0;ine[0], H = v->ne[1], N = v->ne[2]; + const int64_t Hk = q->ne[1]; // GQA: q/k have Hk heads, broadcast (interleaved iv%Hk) to H v-heads + const float scale = 1.0f/sqrtf((float)S_v); + // reorg [S_v,Hx,N] -> [S_v,N,Hx] + auto toDNH = [&](ggml_tensor * x){ return ggml_cont(c, ggml_permute(c, x, 0,2,1,3)); }; + ggml_tensor * qp = toDNH(q), * kp = toDNH(k), * vp = toDNH(v), * gp = toDNH(g); // q/k:[S_v,N,Hk] v/g:[S_v,N,H] + if (Hk != H) { // expand q/k heads Hk->H, interleaved (ggml_repeat tiles ne2: h -> h%Hk) + ggml_tensor * tgt = ggml_new_tensor_3d(c, GGML_TYPE_F32, S_v, N, H); + qp = ggml_repeat(c, qp, tgt); + kp = ggml_repeat(c, kp, tgt); + } + // A_r[i] = prod_{s<=r} exp(g): cumsum over tokens. cumsum is ne0-only -> put tokens on ne0. + ggml_tensor * gN = ggml_cont(c, ggml_permute(c, gp, 1,0,2,3)); // [N,S_v,H] + ggml_tensor * L = ggml_cumsum(c, gN); // [N,S_v,H] cumulative log-decay + ggml_tensor * Lp = ggml_cont(c, ggml_permute(c, L, 1,0,2,3)); // [S_v,N,H] + ggml_tensor * A = ggml_exp(c, Lp); // [S_v|1, N, H] + ggml_tensor * Ainv = ggml_exp(c, ggml_neg(c, Lp)); // 1/A + // full-size tensor first so a scalar gate's A=[1,N,H] broadcasts into kp/qp=[S_v,N,H] + ggml_tensor * Kbar = ggml_mul(c, kp, A); + ggml_tensor * Qbar = ggml_mul(c, qp, A); + ggml_tensor * Ktil = ggml_mul(c, kp, Ainv); + // beta -> [1,N,H] (broadcast over dim) + ggml_tensor * betaNH = ggml_cont(c, ggml_permute(c, beta, 0,2,1,3)); // [1,N,H] + // U_carry[j,r] = sum_i Kbar[i,r] S0[i,j] ; O_carry = scale Qbar . S0 + ggml_tensor * Ucar = ggml_mul_mat(c, S0, Kbar); // [j, r, H] (contract i=ne0) + ggml_tensor * Ocar = ggml_scale(c, ggml_mul_mat(c, S0, Qbar), scale); // [j,r,H] + // rhs[j,r] = beta_r (v[j,r] - Ucar[j,r]) (vp is [S_v(j),N(r),H]) + ggml_tensor * rhs = ggml_mul(c, ggml_sub(c, vp, Ucar), betaNH); // [S_v(j),N(r),H] + // KK[s,r] = sum_i Ktil[i,s] Kbar[i,r] = Kbar_r . Ktil_s. 
solve_tri wants A[ne0=s, ne1=r] + // = (I+T)[r,s], so KK is already in the right orientation (no transpose). beta_r over ne1=r. + ggml_tensor * KK = ggml_mul_mat(c, Ktil, Kbar); // [s,r,H] + ggml_tensor * Tfull= ggml_mul(c, KK, betaNH); // [s,r,H] * beta_r(ne1) + ggml_tensor * Tlo = ggml_tri(c, Tfull, GGML_TRI_TYPE_LOWER); // keep s works with no_alloc + // / CUDA): all-ones [N,1] via exp(0*beta) then ggml_diag -> [N,N], broadcast over H. + ggml_tensor * b0 = ggml_cont(c, ggml_transpose(c, ggml_view_2d(c, betaNH, 1, N, betaNH->nb[1], 0))); + ggml_tensor * ones = ggml_exp(c, ggml_scale(c, b0, 0.0f)); // [N,1] all-ones + ggml_tensor * Imat = ggml_diag(c, ones); // [N,N] + ggml_tensor * IT = ggml_add(c, Tlo, Imat); // [s,r,H] = (I+T) in solve orientation + // Dmat = (I+T)^-1 rhs. b = rhs [S_v(j), N(r), H] (ne1=N matches A->ne1). result [S_v(j),N(r),H]. + ggml_tensor * Dmat = ggml_solve_tri(c, IT, rhs, true, true, false); // [S_v(j), N(token), H] + Dmat = ggml_cont(c, ggml_transpose(c, Dmat)); // -> [N(token), S_v(j), H] + // O = O_carry + scale * tril(Qbar.Ktil^T incl-diag) @ Dmat. QK[s,r]=Qbar_r.Ktil_s. + ggml_tensor * QK = ggml_mul_mat(c, Ktil, Qbar); // [s,r,H] + ggml_tensor * QKlo = ggml_tri(c, QK, GGML_TRI_TYPE_LOWER_DIAG); // keep s<=r (incl diagonal) + // intra[r,j] = sum_s QKlo[s,r] Dmat[s,j]. 
mul_mat(QKlo[s,r,H], Dmat[token=s,j,H]) -> [r,j,H] + ggml_tensor * intra = ggml_mul_mat(c, QKlo, Dmat); // contract s=ne0 -> [r,j,H] + intra = ggml_scale(c, intra, scale); // [r,j,H] + // O_carry is [j,r,H]; intra is [r,j,H] -> transpose O_carry + ggml_tensor * OcarT = ggml_cont(c, ggml_transpose(c, Ocar)); // [r,j,H] + ggml_tensor * O = ggml_add(c, OcarT, intra); // [r(N),j(S_v),H] attn per token + *out_attn = O; + // S_out[i,j] = A_end[i] S0[i,j] + sum_r Kw[i,r] Dmat[r,j], Kw_r=(A_end/A_r) k_r + ggml_tensor * Aend = ggml_view_3d(c, A, A->ne[0], 1, H, A->nb[1], A->nb[2], (N-1)*A->nb[1]); // [S_v|1,1,H] + ggml_tensor * AendB = ggml_cont(c, Aend); + ggml_tensor * S0dec = ggml_mul(c, S0, AendB); // [i,j,H] * A_end[i](ne0,bcast over j) + ggml_tensor * Kw = ggml_mul(c, kp, ggml_mul(c, Ainv, AendB)); // (A_end/A_r) k_r [S_v(i),N(r),H] + // Kw^T @ Dmat: sum_r Kw[i,r] Dmat[r,j]. mul_mat contracts ne0 -> need Kw[r,i] and Dmat[r,j] + ggml_tensor * KwT = ggml_cont(c, ggml_transpose(c, Kw)); // [N(r),S_v(i),H] + ggml_tensor * upd = ggml_mul_mat(c, KwT, Dmat); // contract r -> [i,j,H] + *out_state = ggml_add(c, S0dec, upd); // [i,j,H] +} + +// Multi-chunk tiling: split the N tokens into blocks of C, run build_chunked per block carrying the +// recurrent state forward. Bounds Ainv=exp(-cumsum(g)) to <=C tokens -> numerically stable for long +// prefill (single-chunk overflows). ceil(N/C) chunks, unrolled at graph-build time (N is static). 
+static void build_chunked_tiled(ggml_context * c, ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, + ggml_tensor * g, ggml_tensor * beta, ggml_tensor * S0, int64_t C, + ggml_tensor ** out_attn, ggml_tensor ** out_state) { + const int64_t N = v->ne[2]; + ggml_tensor * S = S0; + ggml_tensor * attn_full = nullptr; + for (int64_t start = 0; start < N; start += C) { + const int64_t cn = std::min(C, N - start); + auto slice = [&](ggml_tensor * x){ + return ggml_view_4d(c, x, x->ne[0], x->ne[1], cn, 1, x->nb[1], x->nb[2], x->nb[3], start*x->nb[2]); + }; + ggml_tensor *ac=nullptr,*sc=nullptr; + build_chunked(c, slice(q), slice(k), slice(v), slice(g), slice(beta), S, &ac, &sc); + attn_full = attn_full ? ggml_concat(c, attn_full, ac, 0) : ac; // O is [token, S_v, H]; concat tokens on ne0 + S = ggml_reshape_4d(c, sc, S0->ne[0], S0->ne[0], S0->ne[2], 1); // match model: thread 4D state + } + *out_attn = attn_full; + *out_state = S; +} + +static int run_case(int64_t S_v, int64_t H, int64_t N, std::mt19937 & rng, int64_t Hk=-1, bool scalar_gate=false, int64_t chunk=0){ + if (Hk < 0) Hk = H; + g_pending.clear(); + size_t mem = 64ull*1024*1024; // metadata only; tensor data lives in the backend buffer (no_alloc) + ggml_init_params ip{mem, nullptr, true}; + ggml_context * c = ggml_init(ip); + ggml_tensor * q = rnd(c,S_v,Hk,N,1,rng,0.5f); + ggml_tensor * k = rnd(c,S_v,Hk,N,1,rng,0.5f); + l2norm_rows(q); l2norm_rows(k); // delta-net normalizes q,k -> stable recurrence + ggml_tensor * v = rnd(c,S_v,H,N,1,rng,0.5f); + // gate: vector (KDA, [S_v,H,N]) or per-head scalar (Gated DeltaNet, [1,H,N]) + ggml_tensor * g = rnd(c, scalar_gate ? 
1 : S_v, H, N, 1, rng, 0.1f, -0.2f); // log-decay <0 + g = ggml_neg(c, ggml_abs(c, g)); // ensure <=0 -> a<=1 + ggml_tensor * beta = rnd(c,1,H,N,1,rng,0.0f,0.5f); + ggml_tensor * S0 = rnd(c,S_v,S_v,H,1,rng,0.3f); + + ggml_tensor * ref = ggml_gated_delta_net(c, q,k,v,g,beta,S0); // [S_v*H, N+S_v] + ggml_tensor *ca=nullptr,*cs=nullptr; + if (chunk > 0) build_chunked_tiled(c,q,k,v,g,beta,S0,chunk,&ca,&cs); + else build_chunked(c,q,k,v,g,beta,S0,&ca,&cs); + // replicate the model's EXACT output op: permute concatenated O [t,j,H] -> [S_v,H,N,1] + cont + ca = ggml_reshape_4d(c, ggml_cont(c, ggml_permute(c, ca, 2,0,1,3)), S_v, H, N, 1); + cs = ggml_reshape_4d(c, cs, S_v, S_v, H, 1); + + ggml_cgraph * gf = ggml_new_graph_custom(c, 8192, false); + ggml_build_forward_expand(gf, ref); + ggml_build_forward_expand(gf, ca); + ggml_build_forward_expand(gf, cs); + ggml_backend_t be = make_backend(); + ggml_gallocr_t galloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(be)); + ggml_gallocr_alloc_graph(galloc, gf); + for (auto & pd : g_pending) ggml_backend_tensor_set(pd.t, pd.data.data(), 0, pd.data.size()*sizeof(float)); + ggml_backend_graph_compute(be, gf); + + // read outputs back to host (works for CPU and CUDA) + std::vector hr(ggml_nelements(ref)), hca(ggml_nelements(ca)), hcs(ggml_nelements(cs)); + ggml_backend_tensor_get(ref, hr.data(), 0, hr.size()*sizeof(float)); + ggml_backend_tensor_get(ca, hca.data(),0, hca.size()*sizeof(float)); + ggml_backend_tensor_get(cs, hcs.data(),0, hcs.size()*sizeof(float)); + const int64_t ca_n0=ca->ne[0], ca_n1=ca->ne[1], cs_n0=cs->ne[0], cs_n1=cs->ne[1]; + // reference attn: cols 0..N-1 of [S_v*H, N+S_v]; per (head h, token t): ref[ h*S_v + j , t ] + auto refAttn = [&](int h,int t,int j){ return hr[(int64_t)t*(S_v*H) + h*S_v + j]; }; + auto refState= [&](int h,int i,int j){ return hr[(int64_t)S_v*H*N + (int64_t)h*S_v*S_v + (int64_t)j*S_v + i]; }; + // chunked attn O[r,j,H]; state [i,j,H] + // ca is now model layout [S_v,H,N,1]: 
element[j,h,t] = j + h*S_v + t*S_v*H + (void)ca_n0;(void)ca_n1; + auto caV=[&](int h,int t,int j){ return hca[ (int64_t)t*S_v*H + (int64_t)h*S_v + j ]; }; + auto csV=[&](int h,int i,int j){ return hcs[ (int64_t)h*cs_n0*cs_n1 + (int64_t)j*cs_n0 + i ]; }; + double mA=0, mS=0; + for(int h=0;h0) snprintf(ch,sizeof ch,"C=%lld",(long long)chunk); else snprintf(ch,sizeof ch,"single"); + printf("S_v=%lld H=%lld N=%lld Hk=%lld gate=%-6s %-7s: max|dAttn|=%.2e max|dState|=%.2e -> %s\n", + (long long)S_v,(long long)H,(long long)N,(long long)Hk, scalar_gate?"scalar":"vector", ch, mA, mS, ok?"PASS":"FAIL"); + ggml_gallocr_free(galloc); + ggml_backend_free(be); + ggml_free(c); + return ok; +} + +int main(){ + std::mt19937 rng(0); + int all = 1; + { ggml_backend_t b = make_backend(); printf("backend: %s\n", ggml_backend_name(b)); ggml_backend_free(b); } + printf("== vector (KDA) gate ==\n"); + for (int64_t S_v : {64, 128}) for (int64_t N : {1, 2, 5, 8, 12, 16}) all &= run_case(S_v, 4, N, rng); + // GQA: H_v=4, H_k=2 and H_k=1 (q/k broadcast interleaved) + for (int64_t N : {1, 5, 16}) { all &= run_case(64, 4, N, rng, 2); all &= run_case(128, 4, N, rng, 1); } + printf("== scalar (Gated DeltaNet) gate -- Qwen3.5 ==\n"); + for (int64_t S_v : {64, 128}) for (int64_t N : {1, 2, 5, 8, 12, 16}) all &= run_case(S_v, 4, N, rng, -1, true); + // GQA + scalar gate (Qwen3.5 is GQA: H_v=32, H_k=16 -> ratio 2) + for (int64_t N : {1, 5, 16}) { all &= run_case(64, 4, N, rng, 2, true); all &= run_case(128, 4, N, rng, 1, true); } + printf("== multi-chunk tiling (long sequences; single-chunk overflows) ==\n"); + // N far beyond a verify block; C=8 chunks carry state. C must stay small: the deflation + // A=exp(+/-cumsum(g)) has wide dynamic range, so strong-decay heads lose fp32 precision when a + // chunk is too long (the model garbles at C>=16 on Qwen3.5; the deployed default is C=8). 
+ for (int64_t N : {32, 64, 128, 200}) { + all &= run_case(128, 4, N, rng, -1, false, 8); // vector + all &= run_case(128, 4, N, rng, -1, true, 8); // scalar (Qwen3.5) + } + all &= run_case(128, 4, 128, rng, 1, true, 8); // GQA + scalar + tiled + all &= run_case(64, 4, 96, rng, 2, false, 8); // GQA + vector + tiled + // sanity: tiling with C>=N must equal the single-chunk path + all &= run_case(128, 4, 12, rng, -1, true, 64); + printf("%s\n", all ? "ALL PASS" : "SOME FAIL"); + return all ? 0 : 1; +} From e654a71ab66ee4bbed3669edea455df5d7040191 Mon Sep 17 00:00:00 2001 From: Aleksandr Nikolich Date: Fri, 12 Jun 2026 18:11:16 +0200 Subject: [PATCH 20/21] server: fix DFlash speculative decoding and add GPU greedy verify The server DFlash path was wired but crashed on every request, because the server processes a prompt in several ubatches while speculative-simple does it in one decode: - index the target features by absolute position and accumulate across ubatches (a chunked prompt previously left the first draft reading stale features), and read the [n_total-n_new, n_total) slice in the drafter - reset dflash_n_past per request in begin() (it carried over between requests) - set the view buffers in dflash_promote_state so the trace/promote copy also runs on the CPU backend (was a CUDA-only path, asserted on a null buffer) Also add GPU greedy verify: for a pure-greedy request the target emits an on-device argmax of the verify block and the host skips the per-block logits download + CPU sampler. Enabled only after the first token is sampled from logits, reset per request; non-greedy requests fall back to the host sampler. Lossless (byte-identical to the host-verify path). ~2.0x -> 2.4x on reasoning. 
--- common/speculative.cpp | 6 +++ src/llama-context.cpp | 31 +++++++++++-- tools/server/server-context.cpp | 78 ++++++++++++++++++++++++++++++++- 3 files changed, 109 insertions(+), 6 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 378776179b34..1bc4afb7d2fb 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -763,6 +763,12 @@ struct common_speculative_state_dflash : public common_speculative_state { void begin(const llama_tokens & prompt) override { GGML_UNUSED(prompt); + // New sequence (server: new request on the slot): the target features are re-extracted from + // position 0 and the DFlash device cross cache is rewritten from there, so reset the running + // count. Without this, dflash_n_past carries over from the previous request and the first + // draft computes n_new = n - dflash_n_past < 1 -> GGML_ASSERT(n_new >= 1) abort. + dflash_n_past = 0; + accumulated_ctx.clear(); } void draft( diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ebf6deda4057..afc97401411c 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1326,9 +1326,12 @@ void llama_context::dflash_append_features(const float * feat, int32_t n_new, in const auto & hparams = model.hparams; const size_t n_feat = hparams.dflash_target_layer_ids.size() * hparams.n_embd; - dflash.feat_staging.assign(feat, feat + n_feat * n_new); + // `feat` is the position-indexed target-feature buffer (see extract_dflash_features); the new + // tokens to encode are the [n_total - n_new, n_total) slice, not the first n_new entries. + const int32_t feat_pos0 = n_total - n_new; + dflash.feat_staging.assign(feat + (size_t) feat_pos0 * n_feat, feat + (size_t) n_total * n_feat); dflash.feat_n = n_new; - dflash.feat_pos0 = n_total - n_new; + dflash.feat_pos0 = feat_pos0; dflash.feat_bucket = n_new <= 8 ? 
8 : 256; // graph rebuilds when the bucket changes (prompt round) // bucketed mask/position sizing, same scheme as the legacy host-mediated path @@ -1536,16 +1539,23 @@ bool llama_context::dflash_promote_state(int32_t idx, llama_pos pos_last, llama_ } } + // views created in a no_alloc context don't carry a buffer; set it to the parent's so the + // backend copy can resolve buffer_is_host (the CPU backend asserts on a null buffer; the CUDA + // path happened to skip the check). Required for the trace/promote path to run on CPU. ggml_tensor * src_s = ggml_view_1d(cg.get(), dflash.trace_s[il], hparams.n_embd_s(), (size_t) idx * dflash.trace_s[il]->nb[1]); ggml_tensor * dst_s = ggml_view_1d(cg.get(), s_l, hparams.n_embd_s(), (size_t) cell * s_l->nb[1]); + src_s->buffer = dflash.trace_s[il]->buffer; + dst_s->buffer = s_l->buffer; ggml_backend_tensor_copy_async(be, be, src_s, dst_s); ggml_tensor * src_r = ggml_view_1d(cg.get(), dflash.trace_r[il], hparams.n_embd_r(), (size_t) idx * dflash.trace_r[il]->nb[1]); ggml_tensor * dst_r = ggml_view_1d(cg.get(), r_l, hparams.n_embd_r(), (size_t) cell * r_l->nb[1]); + src_r->buffer = dflash.trace_r[il]->buffer; + dst_r->buffer = r_l->buffer; ggml_backend_tensor_copy_async(be, be, src_r, dst_r); } @@ -2997,7 +3007,19 @@ void llama_context::extract_dflash_features(const llama_ubatch & ubatch) { const size_t n_layers = dflash.extract_tensors.size(); const int64_t n_embd_concat = n_embd * n_layers; - dflash.target_features.resize(n_embd_concat * n_tokens); + // Index the per-token features by their ABSOLUTE position, accumulating across ubatches. The + // draft (dflash_append_features) reads the [n_total - n_new, n_total) slice, so a prompt processed + // in multiple ubatches (the server chunks it) and a partial-accept verify block both land at the + // right positions. 
The old resize(n_tokens) overwrote the buffer with only the last ubatch, so a + // chunked prompt left the first draft reading garbage -> argmax -1 -> "invalid token -1" decode fail. + llama_pos pos_max = -1; + for (int64_t i = 0; i < n_tokens; ++i) { + pos_max = std::max(pos_max, ubatch.pos[i]); + } + const size_t need = (size_t)(pos_max + 1) * n_embd_concat; + if (dflash.target_features.size() < need) { + dflash.target_features.resize(need); + } static thread_local std::vector temp_layer_features; temp_layer_features.resize(n_embd * n_tokens); @@ -3020,8 +3042,9 @@ void llama_context::extract_dflash_features(const llama_ubatch & ubatch) { ggml_backend_sched_synchronize(sched.get()); for (int64_t token_idx = 0; token_idx < n_tokens; ++token_idx) { + const llama_pos pos = ubatch.pos[token_idx]; const float * src = temp_layer_features.data() + token_idx * n_embd; - float * dest = dflash.target_features.data() + token_idx * n_embd_concat + layer_idx * n_embd; + float * dest = dflash.target_features.data() + (size_t) pos * n_embd_concat + layer_idx * n_embd; std::memcpy(dest, src, n_embd * sizeof(float)); } } diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 9d216109fe5c..07df6475c1e8 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -96,6 +96,11 @@ struct server_slot { // the verify decode so a partial acceptance promotes the state at the accepted position on-device // instead of the ~50 MiB host checkpoint round-trip + re-decode (see llama_dflash_promote_state) bool spec_state_trace = false; + // DFlash GPU greedy verify: when the request samples a raw argmax (pure greedy, no penalties/ + // grammar/logit-bias/n_probs), the target decode emits an on-device argmax of the verify block + // and the host skips the ~n_vocab x block logits download + CPU sampler. Lossless for greedy. 
+ bool spec_gpu_verify = false; + bool spec_argmax_active = false; // out_argmax currently enabled on slot.ctx (after the first sample) llama_pos spec_pos0 = 0; // base position of the current verify batch (rewind target = pos0 + accepted) // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state @@ -1366,6 +1371,30 @@ struct server_context_impl { llama_set_sampler(ctx, slot.id, nullptr); } + // DFlash GPU greedy verify: when this request samples a raw argmax (pure greedy: temp<=0, + // no penalties / DRY / grammar / logit-bias / n_probs), turn on the target's on-device + // argmax so the verify reads block_size+1 ints instead of downloading block_size+1 x n_vocab + // logits and running the host sampler. Lossless for greedy; falls back to the host path + // otherwise. Toggled per request (out_argmax change triggers a one-off graph reserve). + { + const auto & sp = task.params.sampling; + const bool pure_greedy = + sp.temp <= 0.0f && sp.penalty_repeat == 1.0f && sp.penalty_freq == 0.0f && + sp.penalty_present == 0.0f && sp.dry_multiplier == 0.0f && sp.grammar.empty() && + sp.logit_bias.empty() && sp.n_probs == 0; + slot.spec_gpu_verify = params_base.speculative.dflash && params_base.n_parallel == 1 && + pure_greedy && + !(getenv("LLAMA_SPEC_NO_GPU_VERIFY") && std::string(getenv("LLAMA_SPEC_NO_GPU_VERIFY")) != "0"); + // out_argmax is turned ON only after the first token is sampled from logits (like + // speculative-simple): the prompt's first sample needs raw logits, the verify loop + // afterwards reads the on-device argmax. Reset here for a fresh request on the slot. 
+ llama_set_out_argmax(slot.ctx, false); + slot.spec_argmax_active = false; + if (slot.spec_gpu_verify) { + SLT_INF(slot, "%s", "DFlash GPU greedy verify enabled (on-device argmax)\n"); + } + } + SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str()); } else { slot.smpl.reset(); @@ -2956,7 +2985,22 @@ struct server_context_impl { const int tok_idx = slot.i_batch - i; - llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx, tok_idx); + llama_token id; + if (slot.spec_gpu_verify && slot.spec_argmax_active) { + // out_argmax already on (a rare empty-draft round after the first token): read the + // target's on-device argmax for this output row instead of the (unavailable) logits + int32_t n_am = 0; + const int32_t * am = llama_get_dflash_argmax(slot.ctx, &n_am); + GGML_ASSERT(am != nullptr && tok_idx < n_am && "DFlash target argmax missing"); + id = (llama_token) am[tok_idx]; + } else { + id = common_sampler_sample(slot.smpl.get(), slot.ctx, tok_idx); + // first sample done from logits -> enable on-device argmax for the verify loop + if (slot.spec_gpu_verify && !slot.spec_argmax_active) { + llama_set_out_argmax(slot.ctx, true); + slot.spec_argmax_active = true; + } + } slot.i_batch = -1; @@ -3018,7 +3062,37 @@ struct server_context_impl { } GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); - auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft); + llama_tokens accepted; + if (slot.spec_gpu_verify) { + // greedy accept from the target's on-device argmax (DFlash sets output_all, so + // the argmax row == the verify token's batch index in spec_i_batch). Same + // semantics as common_sampler_sample_and_accept_n with a greedy sampler: take + // the target token at each position up to and including the first mismatch, plus + // a bonus token if every draft matched. Skips the per-block logits download. 
+ int32_t n_am = 0; + const int32_t * am = llama_get_dflash_argmax(slot.ctx, &n_am); + GGML_ASSERT(am != nullptr && "DFlash target argmax missing"); + size_t k = 0; + for (; k < slot.spec_draft.size(); ++k) { + const int32_t row = slot.spec_i_batch[k]; + GGML_ASSERT(row < n_am); + const llama_token t = (llama_token) am[row]; + accepted.push_back(t); + common_sampler_accept(slot.smpl.get(), t, true); + if (slot.spec_draft[k] != t) { + break; + } + } + if (k == slot.spec_draft.size()) { // all drafts matched -> bonus token + const int32_t row = slot.spec_i_batch[k]; + GGML_ASSERT(row < n_am); + const llama_token t = (llama_token) am[row]; + accepted.push_back(t); + common_sampler_accept(slot.smpl.get(), t, true); + } + } else { + accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx, slot.spec_i_batch, slot.spec_draft); + } slot.spec_i_batch.clear(); SLT_DBG(slot, "%s: n_draft=%zu, accepted=%zu\n", __func__, slot.spec_draft.size(), accepted.size()); From bf644809aaa125d8f7380820bae13104fa52d13e Mon Sep 17 00:00:00 2001 From: Aleksandr Nikolich Date: Tue, 23 Jun 2026 15:01:51 +0200 Subject: [PATCH 21/21] qwen35moe: add nextn/MTP predict-layer support - Read nextn_predict_layers from GGUF KV; set n_layer_kv_from_start - Fix recurrent_layer_arr: MTP blocks (i >= main_layers) are always full-attention (not recurrent), only main blocks use the 4-cycle rule - Use main_layers count (not n_layer) for LLM_TYPE detection - Register all blk.N.nextn.* and blk.N.attn_*/ffn_* tensors for nextn blocks with TENSOR_SKIP (loaded but unused in main forward pass, consistent with GLM4_MOE pattern) - Add NEXTN tensor names to QWEN35MOE arch in gguf/constants.py Companion file mtp-SIQ-1-35B.f16.gguf (block 40, from Qwen3.6-35B-A3B) uploaded to AlexWortega/SIQ-1-35B; merge with merge_mtp_gguf.py. 
Co-Authored-By: Claude Sonnet 4.6 --- gguf-py/gguf/constants.py | 9 +++++- src/llama-model.cpp | 64 ++++++++++++++++++++++++++++----------- 2 files changed, 54 insertions(+), 19 deletions(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c3b3cb37fae2..f9f56df57b26 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -2053,6 +2053,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_POST_NORM, MODEL_TENSOR.ATTN_GATE, MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.FFN_NORM, MODEL_TENSOR.FFN_GATE_INP, MODEL_TENSOR.FFN_GATE_INP_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, @@ -2068,7 +2069,13 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_NORM, MODEL_TENSOR.SSM_BETA, MODEL_TENSOR.SSM_ALPHA, - MODEL_TENSOR.SSM_OUT + MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, ], MODEL_ARCH.PLAMO: [ MODEL_TENSOR.TOKEN_EMBD, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 47668954f59c..940c82b593a3 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2896,16 +2896,28 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - // Mark recurrent layers (linear attention layers) + // MTP/nextn support + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + if (hparams.nextn_predict_layers > 0) { + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + } + + // Mark recurrent layers (linear attention layers); MTP layers are always full-attention { uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + const uint32_t main_layers = hparams.n_layer 
- hparams.nextn_predict_layers; for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + if (i < main_layers) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } else { + hparams.recurrent_layer_arr[i] = false; // MTP blocks use full attention + } } } - switch (hparams.n_layer) { + switch (hparams.n_layer - hparams.nextn_predict_layers) { case 40: type = LLM_TYPE_35B_A3B; break; case 48: type = LLM_TYPE_122B_A10B; break; case 60: type = LLM_TYPE_397B_A17B; break; @@ -7701,22 +7713,28 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t conv_dim = key_dim * 2 + value_dim; for (int i = 0; i < n_layer; ++i) { + int flags = 0; + const bool is_nextn = hparams.nextn_predict_layers > 0 && + static_cast(i) >= n_layer - hparams.nextn_predict_layers; + if (is_nextn) { + flags |= TENSOR_SKIP; // MTP block tensors are preserved but not used in main forward pass + } + auto & layer = layers[i]; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); if (!hparams.is_recurrent(i)) { - // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + // Full attention layers (and MTP blocks which are always full-attention) + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags); // Q/K normalization for attention layers 
- layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, flags); } else { // Linear attention (gated delta net) specific tensors - // Create tensors with calculated dimensions layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); @@ -7728,17 +7746,27 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); } - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, flags); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, flags); // Shared experts const int64_t n_ff_shexp = hparams.n_ff_shexp ? 
hparams.n_ff_shexp : n_ff; - layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, flags); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, flags); + + // Nextn/MTP tensors - load for the last nextn_predict_layers blocks + if (is_nextn) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + } } } break; case LLM_ARCH_QWEN35: