diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 4a5f6288854..9ba1d19ebbb 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -31,12 +31,10 @@ import argparse import os -from dataclasses import dataclass, field -from typing import Literal +from dataclasses import dataclass import torch import transformers -from accelerate import ParallelismConfig from eagle_utils import ( EagleTrainerWithAccLog, EagleTrainingPlot, @@ -44,200 +42,75 @@ make_speculative_data_module, patch_ring_attention_for_ttt, ) -from omegaconf import OmegaConf +from rich.pretty import pprint from transformers.trainer_utils import get_last_checkpoint import modelopt.torch.opt as mto import modelopt.torch.speculative as mtsp -from modelopt.torch.speculative.config import DFlashConfig, EagleConfig +from modelopt.recipe import load_recipe +from modelopt.recipe.config import ModelOptDFlashRecipe, ModelOptEagleRecipe, ModelOptMedusaRecipe from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading from modelopt.torch.utils import print_rank_0 +from modelopt.torch.utils.distributed import is_master torch.manual_seed(0) mto.enable_huggingface_checkpointing() @dataclass -class ModelArguments: - model_name_or_path: str | None = field( - default="TinyLlama/TinyLlama-1.1B-Chat-v1.0", - metadata={"help": "HuggingFace model ID or local path to the base model."}, - ) - use_fake_base_for_offline: bool = field( - default=False, - metadata={ - "help": "Load model architecture without real base weights. Offline training only." 
- }, - ) - trust_remote_code: bool = field( - default=False, metadata={"help": "Trust remote code when loading model."} - ) - - -@dataclass -class DataArguments: - data_path: str = field( - default=None, - metadata={"help": "Path to the online training data."}, - ) - offline_data_path: str = field( - default=None, - metadata={ - "help": "Path to offline training data directory (.pt files). This argument enables offline mode.", - }, - ) - lazy_preprocess: bool = True - draft_vocab_cache: str | None = field( - default=None, - metadata={"help": "Path to draft vocabulary cache file."}, - ) - chat_template: str = field( - default=None, - metadata={ - "help": "Jinja chat template with {% generation %} tags for answer_only_loss. " - "If not set, the tokenizer's built-in template is used (must already have generation tags)." - }, - ) - vlm_img_dir: str = field(default=None, metadata={"help": "Path to the VLM image directory."}) - vlm_processor: str = field(default=None, metadata={"help": "Path to the VLM processor."}) - sample_size: int = field( - default=-1, - metadata={"help": "Number of samples to use for training. Use -1 to use all samples."}, - ) +class HfTrainingArguments(transformers.TrainingArguments): + """HF-compatible TrainingArguments with our speculative-decoding extensions. - def __post_init__(self): - if self.sample_size == 0 or self.sample_size < -1: - raise ValueError("sample_size must be -1 (use all samples) or a positive integer") - - -@dataclass -class TrainingArguments(transformers.TrainingArguments): - training_seq_len: int = field( - default=2048, - metadata={ - "help": ( - "Training sequence length. Sequences will be right padded or truncated to this length." 
- ) - }, - ) - mode: Literal["eagle3", "medusa", "dflash"] = "eagle3" - estimate_ar: bool = field( - default=False, metadata={"help": "Whether to estimate AR using training accuracy to log."} - ) - ar_validate_steps: int = field(default=1000, metadata={"help": "AR validation interval."}) - answer_only_loss: bool = field( - default=False, - metadata={ - "help": "Mask loss on non-assistant tokens. Requires a chat_template with generation tags." - }, - ) - cp_size: int = field(default=1, metadata={"help": "Context parallelism size."}) - dp_shard_size: int | None = field( - default=None, - metadata={"help": "Data parallelism shard size. None = auto (total_gpu / cp_size)."}, - ) + Used only to build the ``transformers.Trainer``-compatible object at runtime via + ``HfTrainingArguments(**recipe.training.model_dump())``. Field set MUST stay in sync + with :class:`modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments`. + """ - -@dataclass -class MedusaArguments: - medusa_num_heads: int | None = field(default=1) - medusa_num_layers: int | None = field(default=1) + training_seq_len: int = 2048 + estimate_ar: bool = False + ar_validate_steps: int = 1000 + answer_only_loss: bool = False + cp_size: int = 1 + dp_shard_size: int | None = None def _parse_cli() -> tuple[str, list[str]]: - """Parse --config (required) from argv; return remaining args as config overrides. + """Parse --config (required) from argv; return remaining args as dotlist overrides. - Extra arguments use OmegaConf dotlist syntax, e.g. + Extra positional args use dotlist syntax, e.g. ``model.model_name_or_path=meta-llama/Llama-3.2-1B training.output_dir=ckpts/test``. """ p = argparse.ArgumentParser(add_help=False) - p.add_argument("--config", required=True, help="Path to the YAML config file.") + p.add_argument( + "--config", + required=True, + help=( + "Path to a modelopt speculative-decoding recipe YAML " + "(speculative_eagle / speculative_dflash / speculative_medusa)." 
+ ), + ) args, overrides = p.parse_known_args() return args.config, overrides -def _load_config(config_path: str, overrides: list[str] = ()) -> tuple[dict, dict, dict]: - """Load training config from a YAML file with sections: model, data, training, eagle/dflash. - - *overrides* are OmegaConf dotlist entries (e.g. ``["model.model_name_or_path=xxx"]``) - applied on top of the YAML. - - Returns: - hf_cfg: Flat dict from model/data/training sections, for HfArgumentParser.parse_dict() - eagle_cfg: Eagle section dict (EagleConfig fields), passed directly to mtsp.convert() - dflash_cfg: DFlash section dict (DFlashConfig fields), passed directly to mtsp.convert() - """ - merged = OmegaConf.load(config_path) - if overrides: - merged = OmegaConf.merge(merged, OmegaConf.from_dotlist(list(overrides))) - cfg = OmegaConf.to_container(merged, resolve=True) - - # Eagle/DFlash sections map directly to config fields — no field enumeration needed. - eagle_cfg = cfg.get("eagle", {}) - dflash_cfg = cfg.get("dflash", {}) - - hf_cfg = { - **cfg.get("model", {}), - **cfg.get("data", {}), - **cfg.get("training", {}), - } - - if hf_cfg.get("dp_shard_size") is None: - cp_size = hf_cfg.get("cp_size", 1) - # Use WORLD_SIZE (total GPUs across all nodes) when available, else local GPU count. 
- world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count())) - hf_cfg["dp_shard_size"] = world_size // cp_size - - return hf_cfg, eagle_cfg, dflash_cfg - - def train(): config_path, overrides = _parse_cli() - hf_cfg, eagle_cfg, dflash_cfg = _load_config(config_path, overrides) + recipe = load_recipe(config_path, overrides=overrides) - parser = transformers.HfArgumentParser( - ( - ModelArguments, - DataArguments, - TrainingArguments, - MedusaArguments, - ) - ) - model_args, data_args, training_args, medusa_args = parser.parse_dict( - hf_cfg, allow_extra_keys=True - ) + # Pydantic-typed sections flow straight through as *_args; only TrainingArguments is + # reconstructed as an HF dataclass so it can be handed to transformers.Trainer. + training_args = HfTrainingArguments(**recipe.training.model_dump()) - if not data_args.data_path and not data_args.offline_data_path: + if not recipe.data.data_path and not recipe.data.offline_data_path: raise ValueError( "Either data.data_path or data.offline_data_path must be set in the config." ) - if training_args.cp_size > 1 or training_args.dp_shard_size > 1: - # Auto-compute dp_replicate_size so that - # dp_replicate_size * dp_shard_size * cp_size == world_size. - # Note: torch.cuda.device_count() returns per-node GPU count, not world_size. - # WORLD_SIZE (set by torchrun/accelerate) gives the correct multi-node total. 
- world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count())) - parallel_size = training_args.dp_shard_size * training_args.cp_size - if world_size % parallel_size != 0: - raise ValueError( - f"world_size ({world_size}) must be divisible by " - f"dp_shard_size ({training_args.dp_shard_size}) * cp_size ({training_args.cp_size}) " - f"= {parallel_size}" - ) - dp_replicate_size = world_size // parallel_size - training_args.parallelism_config = ParallelismConfig( - cp_size=training_args.cp_size, - dp_shard_size=training_args.dp_shard_size, - dp_replicate_size=dp_replicate_size, - ) if training_args.cp_size > 1: patch_ring_attention_for_ttt() # Specific patch to accelerate 1.12.0. Removable after move to 1.13.0 training_args.parallelism_config.sp_backend = None - print_rank_0( - f"arguments: {model_args}, {training_args}, {medusa_args}, " - f"eagle_cfg={eagle_cfg}, dflash_cfg={dflash_cfg}" - ) + if is_master(): + pprint(recipe) # Detect checkpoint to resume from last_checkpoint = ( @@ -250,80 +123,53 @@ def train(): checkpoint = training_args.resume_from_checkpoint or last_checkpoint - use_offline_training = data_args.offline_data_path is not None + use_offline_training = recipe.data.offline_data_path is not None if checkpoint: with patch_transformers5_params_loading(): model = load_vlm_or_llm( - checkpoint, dtype="auto", trust_remote_code=model_args.trust_remote_code + checkpoint, dtype="auto", trust_remote_code=recipe.model.trust_remote_code ) tokenizer = transformers.AutoTokenizer.from_pretrained( - checkpoint, trust_remote_code=model_args.trust_remote_code + checkpoint, trust_remote_code=recipe.model.trust_remote_code ) else: # To avoid OOM for large models, we load and convert model on CPU first. # Model will be moved to GPU during HF trainer.init(). - if use_offline_training: - # Load config first to preserve original num_hidden_layers before - # load_vlm_or_llm may reduce layers for offline space savings. 
- model_config = transformers.AutoConfig.from_pretrained( - model_args.model_name_or_path, - trust_remote_code=model_args.trust_remote_code, - ) model = load_vlm_or_llm( - model_args.model_name_or_path, - use_fake_base=model_args.use_fake_base_for_offline, + recipe.model.model_name_or_path, + use_fake_base=recipe.model.use_fake_base_for_offline, use_offline_training=use_offline_training, dtype="auto", device_map="cpu", - trust_remote_code=model_args.trust_remote_code, + trust_remote_code=recipe.model.trust_remote_code, ) - if use_offline_training: - # When doing offline training, we need to set num_hidden_layers - # since we override it when loading the model for space savings. - # Some models (e.g. Kimi-K2.5) use non-standard config attributes, - # so fall back to the model's own config if the attribute is missing. - model.config.num_orig_hidden_layers = getattr( - model_config, "num_hidden_layers", model.config.num_hidden_layers - ) - if hasattr(model.config, "layer_types"): - del ( - model.config.layer_types - ) # remove layer_types to avoid mismatch with the modified model tokenizer = transformers.AutoTokenizer.from_pretrained( - model_args.model_name_or_path, + recipe.model.model_name_or_path, model_max_length=training_args.training_seq_len, - trust_remote_code=model_args.trust_remote_code, + trust_remote_code=recipe.model.trust_remote_code, ) - if training_args.mode == "medusa": - config = { - "medusa_num_heads": medusa_args.medusa_num_heads, - "medusa_num_layers": medusa_args.medusa_num_layers, - } - mtsp.convert(model, [("medusa", config)]) - elif training_args.mode == "eagle3": - # Validate and rewrite eagle config fields - eagle_cfg = EagleConfig.model_validate( - eagle_cfg, - context={"training_args": training_args, "data_args": data_args}, - ).model_dump() + if isinstance(recipe, ModelOptMedusaRecipe): + medusa_cfg: dict = recipe.medusa.model_dump() + mtsp.convert(model, [("medusa", medusa_cfg)]) + elif isinstance(recipe, ModelOptEagleRecipe): + 
eagle_cfg: dict = recipe.eagle.model_dump() mtsp.convert(model, [("eagle", eagle_cfg)]) - - # Load draft vocab cache if the draft model uses a compressed vocabulary - if model.eagle_config.draft_vocab_size < model.eagle_config.vocab_size: - if not os.path.isfile(data_args.draft_vocab_cache): - raise FileNotFoundError( - f"Draft vocab cache provided but not found: {data_args.draft_vocab_cache}" - ) - model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True) - print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.") - elif training_args.mode == "dflash": - dflash_cfg = DFlashConfig.model_validate( - dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args} - ).model_dump() + # Load draft vocab cache + mtsp.plugins.HFEagleModel.load_draft_vocab_cache(model, recipe.data.draft_vocab_cache) + elif isinstance(recipe, ModelOptDFlashRecipe): + # Fall back to tokenizer.mask_token_id when not set in the recipe; require one of the two. + if recipe.dflash.dflash_mask_token_id is None: + recipe.dflash.dflash_mask_token_id = getattr(tokenizer, "mask_token_id", None) + if recipe.dflash.dflash_mask_token_id is None: + raise ValueError( + "dflash.dflash_mask_token_id is required: set it in the recipe YAML " + "or use a tokenizer that defines mask_token_id." + ) + dflash_cfg: dict = recipe.dflash.model_dump() mtsp.convert(model, [("dflash", dflash_cfg)]) else: - raise Exception(f"{training_args.mode} is not supported!") + raise ValueError(f"Unsupported speculative recipe type: {type(recipe).__name__}") # Move any remaining CPU buffers to CUDA so DDP (NCCL-only) can broadcast # them. 
We iterate named_buffers and reassign via the owning module to @@ -340,18 +186,22 @@ def train(): setattr(mod, parts[-1], buf.to(_target_dev)) print_rank_0("Loading dataset...") - is_dflash = training_args.mode == "dflash" - if training_args.mode in ("eagle3", "dflash"): + is_dflash = isinstance(recipe, ModelOptDFlashRecipe) + if isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe)): data_module = make_speculative_data_module( tokenizer, - data_args, + recipe.data, train_len=training_args.training_seq_len, answer_only_loss=training_args.answer_only_loss, shift_labels=not is_dflash, ) callbacks = [EagleTrainingPlot(training_args.ar_validate_steps, training_args.estimate_ar)] - if eagle_cfg.get("eagle_base_lora") and eagle_cfg.get("eagle_base_lora_warmup_steps", 0) > 0: + if ( + isinstance(recipe, ModelOptEagleRecipe) + and eagle_cfg.get("eagle_base_lora") + and eagle_cfg.get("eagle_base_lora_warmup_steps", 0) > 0 + ): callbacks.append(LoRAWarmupCallback(eagle_cfg["eagle_base_lora_warmup_steps"])) trainer = EagleTrainerWithAccLog( diff --git a/modelopt/recipe/config.py b/modelopt/recipe/config.py index cc9276c0ff3..16476c4a915 100644 --- a/modelopt/recipe/config.py +++ b/modelopt/recipe/config.py @@ -17,18 +17,28 @@ from __future__ import annotations +import warnings from enum import Enum -from pydantic import field_validator +from pydantic import field_validator, model_validator from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField from modelopt.torch.quantization.config import QuantizeConfig +from modelopt.torch.speculative.config import DFlashConfig, EagleConfig, MedusaConfig +from modelopt.torch.speculative.plugins.hf_training_args import DataArguments as SpecDataArgs +from modelopt.torch.speculative.plugins.hf_training_args import ModelArguments as SpecModelArgs +from modelopt.torch.speculative.plugins.hf_training_args import ( + TrainingArguments as SpecTrainingArgs, +) class RecipeType(str, Enum): """List of recipe types.""" PTQ = 
"ptq" + SPECULATIVE_EAGLE = "speculative_eagle" + SPECULATIVE_DFLASH = "speculative_dflash" + SPECULATIVE_MEDUSA = "speculative_medusa" # QAT = "qat" # Not implemented yet, will be added in the future. @@ -72,3 +82,94 @@ class ModelOptPTQRecipe(ModelOptRecipeBase): description="PTQ config containing quant_cfg and algorithm.", validate_default=True, ) + + +class ModelOptSpeculativeRecipeBase(ModelOptRecipeBase): + """Base class for speculative-decoding recipes. + + Unlike PTQ, speculative-decoding is a training-time optimization: the draft head is trained + with HF Trainer. We therefore bundle ``model`` / ``data`` / ``training`` sections into the + recipe so a single YAML is the full experiment spec. Each section is a typed Pydantic model + (see :mod:`modelopt.torch.speculative.plugins.hf_training_args`) so field typos and bad + values are caught at recipe-load time; HF trainer fields pass through + ``TrainingArguments`` via ``extra='allow'``. + """ + + model: SpecModelArgs = ModeloptField( + default=SpecModelArgs(), + title="HF model args", + description="ModelArguments for the base HF model to train a draft head against.", + validate_default=True, + ) + data: SpecDataArgs = ModeloptField( + default=SpecDataArgs(), + title="HF data args", + description="DataArguments for the training/offline dataset.", + validate_default=True, + ) + training: SpecTrainingArgs = ModeloptField( + default=SpecTrainingArgs(), + title="HF training args", + description="Speculative-decoding extensions; HF trainer fields flow through as extras.", + validate_default=True, + ) + + +class ModelOptEagleRecipe(ModelOptSpeculativeRecipeBase): + """Our config class for EAGLE speculative decoding recipes.""" + + recipe_type: RecipeType = RecipeType.SPECULATIVE_EAGLE + + eagle: EagleConfig = ModeloptField( + default=EagleConfig(), + title="EAGLE config", + description="EAGLE speculative decoding configuration.", + validate_default=True, + ) + + @model_validator(mode="after") + def 
_derive_eagle_offline(self) -> ModelOptEagleRecipe: + self.eagle.eagle_offline = self.data.offline_data_path is not None + return self + + @model_validator(mode="after") + def _warn_rope_vs_training_seq_len(self) -> ModelOptEagleRecipe: + orig_max_pos = self.eagle.eagle_export_rope_scaling.get("original_max_position_embeddings") + if orig_max_pos is not None and orig_max_pos != self.training.training_seq_len: + warnings.warn( + f"eagle.eagle_export_rope_scaling.original_max_position_embeddings ({orig_max_pos}) " + f"differs from training.training_seq_len ({self.training.training_seq_len}). " + f"This may affect long-context inference quality." + ) + return self + + +class ModelOptDFlashRecipe(ModelOptSpeculativeRecipeBase): + """Our config class for DFlash speculative decoding recipes.""" + + recipe_type: RecipeType = RecipeType.SPECULATIVE_DFLASH + + dflash: DFlashConfig = ModeloptField( + default=DFlashConfig(), + title="DFlash config", + description="DFlash speculative decoding configuration.", + validate_default=True, + ) + + @model_validator(mode="after") + def _derive_dflash_offline(self) -> ModelOptDFlashRecipe: + self.dflash.dflash_offline = self.data.offline_data_path is not None + return self + + +class ModelOptMedusaRecipe(ModelOptSpeculativeRecipeBase): + """Our config class for Medusa speculative decoding recipes.""" + + recipe_type: RecipeType = RecipeType.SPECULATIVE_MEDUSA + + medusa: MedusaConfig = ModeloptField( + default=MedusaConfig(), + title="Medusa config", + description="Medusa speculative decoding configuration.", + validate_default=True, + ) diff --git a/modelopt/recipe/loader.py b/modelopt/recipe/loader.py index 3a9c66fb22d..a4cce1bb687 100644 --- a/modelopt/recipe/loader.py +++ b/modelopt/recipe/loader.py @@ -21,8 +21,17 @@ from importlib.abc import Traversable from pathlib import Path +from omegaconf import OmegaConf + from ._config_loader import BUILTIN_RECIPES_LIB, load_config -from .config import ModelOptPTQRecipe, 
ModelOptRecipeBase, RecipeType +from .config import ( + ModelOptDFlashRecipe, + ModelOptEagleRecipe, + ModelOptMedusaRecipe, + ModelOptPTQRecipe, + ModelOptRecipeBase, + RecipeType, +) __all__ = ["load_config", "load_recipe"] @@ -49,17 +58,29 @@ def _resolve_recipe_path(recipe_path: str | Path | Traversable) -> Path | Traver return recipe_path -def load_recipe(recipe_path: str | Path | Traversable) -> ModelOptRecipeBase: - """Load a recipe from a YAML file or directory. +def load_recipe( + recipe_path: str | Path | Traversable, + overrides: list[str] | None = None, +) -> ModelOptRecipeBase: + """Load a recipe from a YAML file or directory, with optional CLI-style overrides. ``recipe_path`` can be: - * A ``.yml`` / ``.yaml`` file with ``metadata`` and ``quantize`` sections. - The suffix may be omitted and will be probed automatically. - * A directory containing ``recipe.yml`` (metadata) and ``quantize.yml``. + * A ``.yml`` / ``.yaml`` file with ``metadata`` and one of ``quantize`` (PTQ), + ``eagle`` (EAGLE speculative decoding), ``dflash`` (DFlash speculative + decoding) or ``medusa`` (Medusa speculative decoding) sections. The suffix + may be omitted and will be probed automatically. + * A directory containing ``recipe.yml`` (metadata) plus ``quantize.yml`` — + **PTQ recipes only**. Speculative-decoding recipes are always single YAML files. The path may be relative to the built-in recipes library or an absolute / relative filesystem path. + + ``overrides`` is an optional list of ``key.path=value`` dotlist entries applied + on top of the YAML before Pydantic validation. Values are parsed by + ``OmegaConf.from_dotlist`` so they get proper types (``foo.bar=true`` → bool, ``foo=1`` + → int, ``foo=[1,2]`` → list, etc.). Only supported when *recipe_path* is a + single YAML file. 
""" resolved = _resolve_recipe_path(recipe_path) @@ -72,21 +93,43 @@ def load_recipe(recipe_path: str | Path | Traversable) -> ModelOptRecipeBase: print(f"[load_recipe] loading: {_display}") if resolved.is_file(): - return _load_recipe_from_file(resolved) + return _load_recipe_from_file(resolved, overrides=overrides) if resolved.is_dir(): + if overrides: + raise ValueError( + "overrides are not supported for directory-format recipes; " + "use the single-YAML-file form instead." + ) return _load_recipe_from_dir(resolved) raise ValueError(f"Recipe path {recipe_path!r} is not a valid YAML file or directory.") -def _load_recipe_from_file(recipe_file: Path | Traversable) -> ModelOptRecipeBase: - """Load a recipe from a YAML file. +def _apply_dotlist(data: dict, overrides: list[str]) -> dict: + """Merge ``a.b.c=value`` command line overrides on top of ``data`` via OmegaConf.""" + for entry in overrides: + if "=" not in entry: + raise ValueError(f"Invalid override (missing '='): {entry!r}") + merged = OmegaConf.merge( + OmegaConf.create(data), + OmegaConf.from_dotlist(list(overrides)), + ) + return OmegaConf.to_container(merged, resolve=True) + + +def _load_recipe_from_file( + recipe_file: Path | Traversable, + overrides: list[str] | None = None, +) -> ModelOptRecipeBase: + """Load a recipe from a YAML file, optionally applying dotlist overrides. The file must contain a ``metadata`` section with at least ``recipe_type``, - plus a ``quant_cfg`` mapping and an optional ``algorithm`` for PTQ recipes. + plus the algorithm-specific section (``quantize`` / ``eagle`` / ``dflash`` / ``medusa``). 
""" data = load_config(recipe_file) + if overrides: + data = _apply_dotlist(data, overrides) metadata = data.get("metadata", {}) recipe_type = metadata.get("recipe_type") @@ -101,6 +144,36 @@ def _load_recipe_from_file(recipe_file: Path | Traversable) -> ModelOptRecipeBas description=metadata.get("description", "PTQ recipe."), quantize=data["quantize"], ) + if recipe_type == RecipeType.SPECULATIVE_EAGLE: + if "eagle" not in data: + raise ValueError(f"EAGLE recipe file {recipe_file} must contain 'eagle'.") + return ModelOptEagleRecipe( + description=metadata.get("description", "EAGLE speculative decoding recipe."), + model=data.get("model") or {}, + data=data.get("data") or {}, + training=data.get("training") or {}, + eagle=data["eagle"], + ) + if recipe_type == RecipeType.SPECULATIVE_DFLASH: + if "dflash" not in data: + raise ValueError(f"DFlash recipe file {recipe_file} must contain 'dflash'.") + return ModelOptDFlashRecipe( + description=metadata.get("description", "DFlash speculative decoding recipe."), + model=data.get("model") or {}, + data=data.get("data") or {}, + training=data.get("training") or {}, + dflash=data["dflash"], + ) + if recipe_type == RecipeType.SPECULATIVE_MEDUSA: + if "medusa" not in data: + raise ValueError(f"Medusa recipe file {recipe_file} must contain 'medusa'.") + return ModelOptMedusaRecipe( + description=metadata.get("description", "Medusa speculative decoding recipe."), + model=data.get("model") or {}, + data=data.get("data") or {}, + training=data.get("training") or {}, + medusa=data["medusa"], + ) raise ValueError(f"Unsupported recipe type: {recipe_type!r}") diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index d82d471d01a..230ed4d9134 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -15,11 +15,9 @@ """Configurations for speculative decoding modes.""" -import warnings from copy import deepcopy -from typing import Any -from pydantic import 
ValidationInfo, model_validator +from pydantic import model_validator from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField @@ -71,7 +69,7 @@ class DFlashConfig(ModeloptBaseConfig): default=False, description=( "Whether to use detached DFlash (offline training from pre-computed hidden states). " - "Auto-derived from data_args.offline_data_path during validation — not user-configurable." + "Derived by ModelOptDFlashRecipe from data.offline_data_path; not user-configurable." ), ) @@ -103,10 +101,12 @@ class DFlashConfig(ModeloptBaseConfig): default=True, description="Whether to report eval accuracy." ) - dflash_mask_token_id: int = ModeloptField( + dflash_mask_token_id: int | None = ModeloptField( default=None, - description="Token ID used for masked (unknown) positions. " - "Set explicitly or auto-detected from tokenizer.mask_token_id in main.py.", + description=( + "Token ID used for masked (unknown) positions. Set explicitly in the recipe YAML, " + "or left unset to fall back to ``tokenizer.mask_token_id`` at training time." + ), ) dflash_architecture_config: dict = ModeloptField( @@ -118,43 +118,6 @@ class DFlashConfig(ModeloptBaseConfig): description="Whether to use torch.compile on DFlash forward/loss methods.", ) - @model_validator(mode="before") - @classmethod - def _derive_dflash_offline(cls, data: Any, info: ValidationInfo) -> Any: - """Derive ``dflash_offline`` from ``data_args.offline_data_path``. - - This field is auto-derived, not user-configurable: when context provides - ``data_args``, the derived value overrides any user-supplied value. 
- """ - ctx = info.context if info.context else {} - data_args = ctx.get("data_args") - if data_args is not None and isinstance(data, dict): - data["dflash_offline"] = getattr(data_args, "offline_data_path", None) is not None - return data - - @model_validator(mode="before") - @classmethod - def _resolve_mask_token_id(cls, data: Any, info: ValidationInfo) -> Any: - """Auto-detect ``dflash_mask_token_id`` from tokenizer when provided in context.""" - if not isinstance(data, dict) or data.get("dflash_mask_token_id") is not None: - return data - ctx = info.context if info.context else {} - tokenizer = ctx.get("tokenizer") - if tokenizer is not None and getattr(tokenizer, "mask_token_id", None) is not None: - data["dflash_mask_token_id"] = tokenizer.mask_token_id - return data - - @model_validator(mode="after") - def _check_mask_token_id(self) -> "DFlashConfig": - """Validate that mask_token_id is set after all resolution attempts.""" - if self.dflash_mask_token_id is None: - raise ValueError( - "dflash_mask_token_id is required. Set it in the config YAML " - "(dflash.dflash_mask_token_id=TOKEN_ID) or ensure the tokenizer " - "has a mask_token_id attribute." - ) - return self - class MedusaConfig(ModeloptBaseConfig): """Medusa config.""" @@ -174,7 +137,11 @@ class EagleConfig(ModeloptBaseConfig): """Eagle config.""" eagle_offline: bool = ModeloptField( - default=False, description=("Whether to use detached Eagle.") + default=False, + description=( + "Whether to use detached Eagle. Derived by ModelOptEagleRecipe from " + "data.offline_data_path; not user-configurable." 
+ ), ) eagle_hidden_state_distillation: bool = ModeloptField( @@ -292,16 +259,6 @@ class EagleConfig(ModeloptBaseConfig): ), ) - @model_validator(mode="before") - @classmethod - def _derive_eagle_offline(cls, data: Any, info: ValidationInfo) -> Any: - """Derive ``eagle_offline`` from ``data_args.offline_data_path`` when provided in context.""" - ctx = info.context if info.context else {} - data_args = ctx.get("data_args") - if data_args is not None and isinstance(data, dict): - data["eagle_offline"] = data_args.offline_data_path is not None - return data - @model_validator(mode="after") def _check_rope_scaling_consistency(self) -> "EagleConfig": if not self.eagle_export_rope_scaling: @@ -315,18 +272,3 @@ def _check_rope_scaling_consistency(self) -> "EagleConfig": f"training rope_type is 'default' (no scaling)." ) return self - - @model_validator(mode="after") - def _warn_rope_vs_training_seq_len(self, info: ValidationInfo) -> "EagleConfig": - ctx = info.context if info.context else {} - training_args = ctx.get("training_args") - if training_args is None: - return self - orig_max_pos = self.eagle_export_rope_scaling.get("original_max_position_embeddings") - if orig_max_pos is not None and orig_max_pos != training_args.training_seq_len: - warnings.warn( - f"eagle_export_rope_scaling.original_max_position_embeddings ({orig_max_pos}) " - f"differs from training_seq_len ({training_args.training_seq_len}). " - f"This may affect long-context inference quality." 
- ) - return self diff --git a/modelopt/torch/speculative/plugins/hf_eagle.py b/modelopt/torch/speculative/plugins/hf_eagle.py index f2040d9d960..c54ff039bc2 100644 --- a/modelopt/torch/speculative/plugins/hf_eagle.py +++ b/modelopt/torch/speculative/plugins/hf_eagle.py @@ -17,6 +17,7 @@ import contextlib import copy +import os from typing import Any import torch @@ -25,6 +26,8 @@ from transformers.models.llama.modeling_llama import LlamaDecoderLayer from transformers.utils import ModelOutput +from modelopt.torch.utils import print_rank_0 + from ...export.plugins.hf_spec_export import EagleExporter, SpeculativeDecodingExporter from ..eagle.conversion import EagleDMRegistry from ..eagle.eagle_model import EagleModel @@ -88,7 +91,7 @@ def _nvtx_range(self, name): return nvtx.range(name) except Exception as e: - print(f"Failed to create NVTX range {name}: {e}") + print_rank_0(f"Failed to create NVTX range {name}: {e}") return contextlib.nullcontext() def _find_base_model_parts(self): @@ -105,7 +108,7 @@ def _find_base_model_parts(self): try: submodule = self.get_submodule(path) assert isinstance(submodule, torch.nn.Module) - print(f"Found {name} at {path}") + print_rank_0(f"Found {name} at {path}") found_submodule = True setattr(self, name, path) break @@ -128,7 +131,7 @@ def _activate_torch_compile(self): try: setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs)) except Exception: # noqa: PERF203 - print(f"Disabling torch.compile for {name} due to compilation error.") + print_rank_0(f"Disabling torch.compile for {name} due to compilation error.") def get_dummy_inputs(self) -> dict: """Construct dummy inputs for export forward pass.""" @@ -250,6 +253,16 @@ def _preservation_loss( ) return -loss.sum(dim=-1).mean() * self.eagle_base_lora_preservation_loss_weight + @staticmethod + def load_draft_vocab_cache(model, d2t_path: str) -> None: + """Load the draft vocab cache from the given path.""" + if d2t_path is None or 
model.eagle_config.draft_vocab_size >= model.eagle_config.vocab_size: + return + if not os.path.isfile(d2t_path): + raise FileNotFoundError(f"Draft vocab cache provided but not found: {d2t_path}") + model.eagle_module.d2t = torch.load(d2t_path, weights_only=True) + print_rank_0(f"Loaded draft vocab cache from {d2t_path}.") + def modify( self, config, diff --git a/modelopt/torch/speculative/plugins/hf_training_args.py b/modelopt/torch/speculative/plugins/hf_training_args.py new file mode 100644 index 00000000000..126de5a2a93 --- /dev/null +++ b/modelopt/torch/speculative/plugins/hf_training_args.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pydantic schemas for HF-trainer-based speculative-decoding experiments. + +These are the typed section models used inside speculative-decoding recipes +(:class:`modelopt.recipe.config.ModelOptEagleRecipe` / +:class:`modelopt.recipe.config.ModelOptDFlashRecipe`). They mirror the HF dataclasses used +by :mod:`examples/speculative_decoding/main.py` so that recipe YAMLs are Pydantic-validated +at load time. + +The module intentionally does NOT import ``transformers`` — it is pure Pydantic schema. 
+``transformers.TrainingArguments`` is extended separately in the example script to stay +compatible with HF's ``Trainer`` API; the ``TrainingArguments`` model here only declares the +seven speculative-decoding extension fields plus ``extra='allow'`` so HF trainer fields +(learning_rate, num_train_epochs, ...) flow through untouched. + +``TrainingArguments`` does read ``WORLD_SIZE`` and ``torch.cuda.device_count()`` at validation +time to auto-fill ``dp_shard_size`` and derive a ``parallelism_config`` (accelerate's +``ParallelismConfig``) when the run is actually distributed. ``torch`` and ``accelerate`` are +imported lazily from within the validator so importing this module stays cheap and +``accelerate`` only becomes a hard requirement when ``cp_size>1`` or ``dp_shard_size>1``. +""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, ConfigDict, field_validator, model_validator + + +class ModelArguments(BaseModel): + """Arguments for loading the base HF model.""" + + model_config = ConfigDict(extra="forbid", protected_namespaces=()) + + model_name_or_path: str | None = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + use_fake_base_for_offline: bool = False + trust_remote_code: bool = False + + +class DataArguments(BaseModel): + """Arguments for the training dataset.""" + + model_config = ConfigDict(extra="forbid") + + data_path: str | None = None + offline_data_path: str | None = None + lazy_preprocess: bool = True + draft_vocab_cache: str | None = None + chat_template: str | None = None + vlm_img_dir: str | None = None + vlm_processor: str | None = None + sample_size: int = -1 + + @field_validator("sample_size") + @classmethod + def _check_sample_size(cls, v: int) -> int: + if v == 0 or v < -1: + raise ValueError("sample_size must be -1 (use all samples) or a positive integer") + return v + + +class TrainingArguments(BaseModel): + """Speculative-decoding extensions on top of ``transformers.TrainingArguments``. 
+ + HF trainer fields (``learning_rate``, ``num_train_epochs``, ...) flow through as extras + via ``extra='allow'`` — they're re-validated later when the dict is passed to + ``HfTrainingArguments(**recipe.training.model_dump())`` in main.py. + """ + + model_config = ConfigDict(extra="allow") + + training_seq_len: int = 2048 + estimate_ar: bool = False + ar_validate_steps: int = 1000 + answer_only_loss: bool = False + cp_size: int = 1 + dp_shard_size: int | None = None + # Derived at validation time from cp_size/dp_shard_size/WORLD_SIZE; typed as Any so this + # module doesn't need to import accelerate.ParallelismConfig just to annotate the field. + parallelism_config: Any = None + + @model_validator(mode="after") + def _fill_parallelism(self) -> TrainingArguments: + # Read WORLD_SIZE (set by torchrun/accelerate, multi-node aware); fall back to the + # local GPU count for single-process runs. + import os + + import torch + + world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count())) + if self.dp_shard_size is None: + self.dp_shard_size = world_size // self.cp_size + + # Build a ParallelismConfig only when actually running distributed — matches the + # previous main.py guard and avoids requiring accelerate on single-GPU dev boxes. + if self.cp_size > 1 or self.dp_shard_size > 1: + parallel_size = self.dp_shard_size * self.cp_size + if world_size % parallel_size != 0: + raise ValueError( + f"world_size ({world_size}) must be divisible by " + f"dp_shard_size ({self.dp_shard_size}) * cp_size ({self.cp_size}) " + f"= {parallel_size}" + ) + try: + from accelerate import ParallelismConfig + except ImportError as e: + raise ImportError( + "cp_size>1 or dp_shard_size>1 requires `accelerate` for ParallelismConfig. " + "Install it via `pip install accelerate`." 
+ ) from e + self.parallelism_config = ParallelismConfig( + cp_size=self.cp_size, + dp_shard_size=self.dp_shard_size, + dp_replicate_size=world_size // parallel_size, + ) + return self diff --git a/modelopt_recipes/general/speculative_decoding/dflash.yaml b/modelopt_recipes/general/speculative_decoding/dflash.yaml index 3d43e0fe1d4..a38b24d05d6 100644 --- a/modelopt_recipes/general/speculative_decoding/dflash.yaml +++ b/modelopt_recipes/general/speculative_decoding/dflash.yaml @@ -1,4 +1,9 @@ -# Base config for DFlash training. Override fields via OmegaConf dotlist on the CLI. +# Base config for DFlash training. A full modelopt recipe; override fields via +# OmegaConf dotlist on the CLI (e.g. `model.model_name_or_path=...`). + +metadata: + recipe_type: speculative_dflash + description: DFlash training recipe (model/data/training/dflash bundled). # maps to ModelArguments (main.py) model: @@ -18,7 +23,6 @@ data: # maps to TrainingArguments (main.py) training: # --- commonly modified --- - mode: dflash output_dir: num_train_epochs: 10 per_device_train_batch_size: 1 diff --git a/modelopt_recipes/general/speculative_decoding/eagle3.yaml b/modelopt_recipes/general/speculative_decoding/eagle3.yaml index a1b7ff77708..78767ad1ebb 100644 --- a/modelopt_recipes/general/speculative_decoding/eagle3.yaml +++ b/modelopt_recipes/general/speculative_decoding/eagle3.yaml @@ -1,4 +1,9 @@ -# Base config for EAGLE3 training. Override fields via OmegaConf dotlist on the CLI. +# Base config for EAGLE3 training. A full modelopt recipe; override fields via +# OmegaConf dotlist on the CLI (e.g. `model.model_name_or_path=...`). + +metadata: + recipe_type: speculative_eagle + description: EAGLE3 training recipe (model/data/training/eagle bundled). 
# maps to ModelArguments (main.py) model: @@ -17,7 +22,6 @@ data: # maps to TrainingArguments (main.py) training: # --- commonly modified --- - mode: eagle3 output_dir: num_train_epochs: 1 per_device_train_batch_size: 1 diff --git a/tests/unit/recipe/test_loader.py b/tests/unit/recipe/test_loader.py index 6926d89a5d2..0eb91a9590e 100644 --- a/tests/unit/recipe/test_loader.py +++ b/tests/unit/recipe/test_loader.py @@ -19,8 +19,13 @@ import pytest -from modelopt.recipe.config import ModelOptPTQRecipe, RecipeType -from modelopt.recipe.loader import load_config, load_recipe +from modelopt.recipe.config import ( + ModelOptDFlashRecipe, + ModelOptEagleRecipe, + ModelOptPTQRecipe, + RecipeType, +) +from modelopt.recipe.loader import _apply_dotlist, load_config, load_recipe # --------------------------------------------------------------------------- # Static YAML fixtures @@ -191,6 +196,223 @@ def test_load_recipe_dir_missing_quantize_raises(tmp_path): load_recipe(tmp_path) +# --------------------------------------------------------------------------- +# load_recipe — EAGLE speculative decoding +# --------------------------------------------------------------------------- + + +def test_load_recipe_eagle_builtin(): + """load_recipe loads the built-in EAGLE recipe and returns a ModelOptEagleRecipe.""" + recipe = load_recipe("general/speculative_decoding/eagle3") + assert recipe.recipe_type == RecipeType.SPECULATIVE_EAGLE + assert isinstance(recipe, ModelOptEagleRecipe) + assert recipe.eagle.eagle_decoder_type == "llama" + assert recipe.eagle.eagle_ttt_steps == 3 + # Full-pipeline recipe also carries typed HF trainer sections. 
+ assert recipe.training.training_seq_len == 2048 + + +def test_load_recipe_eagle_missing_section_raises(tmp_path): + """load_recipe raises ValueError when 'eagle' is absent for a SPECULATIVE_EAGLE recipe.""" + bad = tmp_path / "bad.yml" + bad.write_text("metadata:\n recipe_type: speculative_eagle\n") + with pytest.raises(ValueError, match="eagle"): + load_recipe(bad) + + +def test_load_recipe_eagle_field_validation_raises(tmp_path): + """Invalid EAGLE field values must fail Pydantic validation at load time.""" + bad = tmp_path / "bad.yml" + bad.write_text( + "metadata:\n recipe_type: speculative_eagle\neagle:\n eagle_ttt_steps: not_an_int\n" + ) + with pytest.raises(Exception): # pydantic.ValidationError + load_recipe(bad) + + +# --------------------------------------------------------------------------- +# load_recipe — DFlash speculative decoding +# --------------------------------------------------------------------------- + + +def test_load_recipe_dflash_builtin(): + """load_recipe loads the built-in DFlash recipe and returns a ModelOptDFlashRecipe.""" + recipe = load_recipe("general/speculative_decoding/dflash") + assert recipe.recipe_type == RecipeType.SPECULATIVE_DFLASH + assert isinstance(recipe, ModelOptDFlashRecipe) + assert recipe.dflash.dflash_block_size == 8 + assert recipe.dflash.dflash_num_anchors == 512 + # Full-pipeline recipe also carries typed HF trainer sections. 
+ assert recipe.training.training_seq_len == 4096 + + +def test_load_recipe_dflash_missing_section_raises(tmp_path): + """load_recipe raises ValueError when 'dflash' is absent for a SPECULATIVE_DFLASH recipe.""" + bad = tmp_path / "bad.yml" + bad.write_text("metadata:\n recipe_type: speculative_dflash\n") + with pytest.raises(ValueError, match="dflash"): + load_recipe(bad) + + +def test_load_recipe_eagle_with_training_sections(tmp_path): + """load_recipe populates typed HF trainer sections from all four YAML segments.""" + recipe_path = tmp_path / "eagle.yml" + recipe_path.write_text( + "metadata:\n recipe_type: speculative_eagle\n" + "model:\n model_name_or_path: TinyLlama/TinyLlama-1.1B-Chat-v1.0\n" + "data:\n data_path: train.jsonl\n" + "training:\n output_dir: ckpts/test\n" + "eagle:\n eagle_decoder_type: llama\n eagle_ttt_steps: 2\n" + ) + recipe = load_recipe(recipe_path) + assert isinstance(recipe, ModelOptEagleRecipe) + assert recipe.model.model_name_or_path == "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + assert recipe.data.data_path == "train.jsonl" + # output_dir is an HF-trainer extra; flows through extras. 
+ assert recipe.training.model_dump()["output_dir"] == "ckpts/test" + assert recipe.eagle.eagle_ttt_steps == 2 + + +def test_typed_model_section_rejects_unknown_field(tmp_path): + """model section has extra='forbid'; unknown keys raise ValidationError at load time.""" + recipe_path = tmp_path / "bad.yml" + recipe_path.write_text( + "metadata:\n recipe_type: speculative_eagle\n" + "model:\n typo_name: oops\n" + "eagle:\n eagle_decoder_type: llama\n" + ) + with pytest.raises(Exception): # pydantic.ValidationError + load_recipe(recipe_path) + + +def test_typed_training_section_accepts_hf_extras(tmp_path): + """training section has extra='allow'; HF trainer fields flow through without validation.""" + recipe_path = tmp_path / "eagle.yml" + recipe_path.write_text( + "metadata:\n recipe_type: speculative_eagle\n" + "training:\n" + " num_train_epochs: 3\n" # HF field — accepted as extra + " learning_rate: 1.0e-4\n" # HF field — accepted as extra + " training_seq_len: 4096\n" # our extension field — validated + "eagle:\n eagle_decoder_type: llama\n" + ) + recipe = load_recipe(recipe_path) + assert isinstance(recipe, ModelOptEagleRecipe) + assert recipe.training.training_seq_len == 4096 + dumped = recipe.training.model_dump() + assert dumped["num_train_epochs"] == 3 + assert dumped["learning_rate"] == 1e-4 + + +# --------------------------------------------------------------------------- +# CLI-style dotlist overrides +# --------------------------------------------------------------------------- + + +def test_apply_dotlist_flat(): + """_apply_dotlist sets a top-level key and parses the value with yaml.safe_load.""" + result = _apply_dotlist({"a": 1}, ["b=2"]) + assert result == {"a": 1, "b": 2} + + +def test_apply_dotlist_nested_overwrite(): + """_apply_dotlist overwrites a nested key without mutating input.""" + original = {"model": {"trust_remote_code": False}} + result = _apply_dotlist(original, ["model.trust_remote_code=true"]) + assert 
result["model"]["trust_remote_code"] is True + assert original["model"]["trust_remote_code"] is False # input untouched + + +def test_apply_dotlist_creates_missing_path(): + """_apply_dotlist creates intermediate dicts when the path doesn't exist.""" + result = _apply_dotlist({}, ["a.b.c=42"]) + assert result == {"a": {"b": {"c": 42}}} + + +def test_apply_dotlist_parses_typed_values(): + """_apply_dotlist preserves yaml.safe_load's type inference.""" + result = _apply_dotlist( + {}, + [ + "int_v=7", + "float_v=1.5", + "bool_v=true", + "null_v=null", + "list_v=[1, 2, 3]", + "str_v=hello", + ], + ) + assert result == { + "int_v": 7, + "float_v": 1.5, + "bool_v": True, + "null_v": None, + "list_v": [1, 2, 3], + "str_v": "hello", + } + + +def test_apply_dotlist_scientific_notation(): + """OmegaConf parses ``1e-4`` as float natively (unlike yaml.safe_load in YAML 1.1 mode).""" + result = _apply_dotlist({}, ["lr=5e-5", "decay=1e-10", "still_str=hello"]) + assert result["lr"] == 5e-5 and isinstance(result["lr"], float) + assert result["decay"] == 1e-10 and isinstance(result["decay"], float) + assert result["still_str"] == "hello" # non-numeric strings stay as strings + + +def test_apply_dotlist_malformed_raises(): + """_apply_dotlist rejects entries missing the '=' separator.""" + with pytest.raises(ValueError, match="missing '='"): + _apply_dotlist({}, ["foo_no_equals"]) + + +def test_load_recipe_with_overrides(tmp_path): + """load_recipe(path, overrides=...) 
merges dotlist entries before Pydantic validation.""" + recipe_path = tmp_path / "recipe.yml" + recipe_path.write_text( + "metadata:\n recipe_type: speculative_eagle\n" + "model:\n trust_remote_code: false\n" + "eagle:\n eagle_ttt_steps: 3\n" + ) + recipe = load_recipe( + recipe_path, + overrides=["model.trust_remote_code=true", "eagle.eagle_ttt_steps=7"], + ) + assert isinstance(recipe, ModelOptEagleRecipe) + assert recipe.model.trust_remote_code is True + assert recipe.eagle.eagle_ttt_steps == 7 + + +def test_load_recipe_overrides_rejected_for_dir(tmp_path): + """Overrides are not allowed for directory-format recipes.""" + (tmp_path / "recipe.yml").write_text("metadata:\n recipe_type: ptq\n") + (tmp_path / "quantize.yml").write_text("algorithm: max\nquant_cfg: []\n") + with pytest.raises(ValueError, match="directory-format"): + load_recipe(tmp_path, overrides=["quantize.algorithm=gptq"]) + + +def test_typed_data_sample_size_validator(tmp_path): + """DataArguments rejects sample_size=0 via field_validator.""" + recipe_path = tmp_path / "bad.yml" + recipe_path.write_text( + "metadata:\n recipe_type: speculative_eagle\n" + "data:\n sample_size: 0\n" + "eagle:\n eagle_decoder_type: llama\n" + ) + with pytest.raises(Exception, match="sample_size"): # pydantic.ValidationError + load_recipe(recipe_path) + + +def test_load_recipe_dflash_field_validation_raises(tmp_path): + """Invalid DFlash field values must fail Pydantic validation at load time.""" + bad = tmp_path / "bad.yml" + bad.write_text( + "metadata:\n recipe_type: speculative_dflash\ndflash:\n dflash_block_size: not_an_int\n" + ) + with pytest.raises(Exception): # pydantic.ValidationError + load_recipe(bad) + + # --------------------------------------------------------------------------- # YAML recipe consistency — built-in general/ptq files match config.py dicts # ---------------------------------------------------------------------------