[Refactor] speculative decoding: use mto config subsystem #1328
`modelopt/recipe/config.py`:

```diff
@@ -17,18 +17,28 @@
 from __future__ import annotations

 import warnings
 from enum import Enum

-from pydantic import field_validator
+from pydantic import field_validator, model_validator

 from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
 from modelopt.torch.quantization.config import QuantizeConfig
+from modelopt.torch.speculative.config import DFlashConfig, EagleConfig, MedusaConfig
+from modelopt.torch.speculative.plugins.hf_training_args import DataArguments as SpecDataArgs
+from modelopt.torch.speculative.plugins.hf_training_args import ModelArguments as SpecModelArgs
+from modelopt.torch.speculative.plugins.hf_training_args import (
+    TrainingArguments as SpecTrainingArgs,
+)


 class RecipeType(str, Enum):
     """List of recipe types."""

     PTQ = "ptq"
+    SPECULATIVE_EAGLE = "speculative_eagle"
+    SPECULATIVE_DFLASH = "speculative_dflash"
+    SPECULATIVE_MEDUSA = "speculative_medusa"
     # QAT = "qat"  # Not implemented yet, will be added in the future.
```
```diff
@@ -72,3 +82,94 @@ class ModelOptPTQRecipe(ModelOptRecipeBase):
         description="PTQ config containing quant_cfg and algorithm.",
         validate_default=True,
     )
+
+
+class ModelOptSpeculativeRecipeBase(ModelOptRecipeBase):
+    """Base class for speculative-decoding recipes.
+
+    Unlike PTQ, speculative-decoding is a training-time optimization: the draft head is trained
+    with HF Trainer. We therefore bundle ``model`` / ``data`` / ``training`` sections into the
+    recipe so a single YAML is the full experiment spec. Each section is a typed Pydantic model
+    (see :mod:`modelopt.torch.speculative.plugins.hf_training_args`) so field typos and bad
+    values are caught at recipe-load time; HF trainer fields pass through
+    ``TrainingArguments`` via ``extra='allow'``.
+    """
+
+    model: SpecModelArgs = ModeloptField(
+        default=SpecModelArgs(),
+        title="HF model args",
+        description="ModelArguments for the base HF model to train a draft head against.",
+        validate_default=True,
+    )
+    data: SpecDataArgs = ModeloptField(
+        default=SpecDataArgs(),
+        title="HF data args",
+        description="DataArguments for the training/offline dataset.",
+        validate_default=True,
+    )
+    training: SpecTrainingArgs = ModeloptField(
+        default=SpecTrainingArgs(),
+        title="HF training args",
+        description="Speculative-decoding extensions; HF trainer fields flow through as extras.",
+        validate_default=True,
+    )
```
Comment on lines +98 to +115 (Contributor):

**Use lazy defaults for nested recipe sections to avoid import-time validator side effects.**

Passing an already-constructed instance as `default=SpecTrainingArgs()` builds that instance eagerly, at class-definition (i.e. import) time. Pydantic v2's `default_factory=SpecTrainingArgs` instead defers construction until a default is actually needed, which is the recommended pattern for nested model defaults.

Verification against the tree shows this matters here: `TrainingArguments` in `modelopt/torch/speculative/plugins/hf_training_args.py` carries a `_fill_parallelism` validator that touches `accelerate` / `ParallelismConfig`, so constructing the default eagerly runs that validator (and its imports) for every consumer of `modelopt/recipe/config.py`, including PTQ-only users.

However, `ModeloptField` in `modelopt/torch/opt/config.py` currently asserts that a `default` value is always provided ("A default value must be set for ModeloptField.") and does not accept `default_factory`; passing `default_factory` alone trips that assertion. The suggested fix therefore requires updating `ModeloptField` to support `default_factory` first. Once that foundation is in place, the diff changes shown (switching the `model` / `data` / `training` fields, and the per-recipe `eagle` / `dflash` / `medusa` fields, from `default=X()` to `default_factory=X`) become possible.
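The eager-vs-lazy distinction the reviewer describes is easy to demonstrate with the stdlib `dataclasses` module, whose `default` / `default_factory` split has the same semantics as Pydantic's (a sketch; `Sub` is a hypothetical stand-in for a nested args model like `SpecTrainingArgs`):

```python
from dataclasses import dataclass, field


class Sub:
    """Stand-in for a nested args model whose __init__ has side effects."""

    instances = 0

    def __init__(self):
        Sub.instances += 1


@dataclass
class EagerOuter:
    # Sub() runs here, while the class body is evaluated -- i.e. at import time.
    sub: Sub = Sub()


assert Sub.instances == 1  # one instance built before any EagerOuter exists


@dataclass
class LazyOuter:
    # default_factory defers construction until a default is actually needed.
    sub: Sub = field(default_factory=Sub)


assert Sub.instances == 1  # defining LazyOuter built nothing

LazyOuter()
assert Sub.instances == 2  # the factory ran on demand

EagerOuter()
assert Sub.instances == 2  # the shared eager default is reused, not rebuilt
```

Pydantic v2's `Field(default_factory=...)` behaves the same way, which is why the review asks for `ModeloptField` to grow `default_factory` support before the recipe fields can switch.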
```diff
+
+
+class ModelOptEagleRecipe(ModelOptSpeculativeRecipeBase):
+    """Our config class for EAGLE speculative decoding recipes."""
+
+    recipe_type: RecipeType = RecipeType.SPECULATIVE_EAGLE
+
+    eagle: EagleConfig = ModeloptField(
+        default=EagleConfig(),
+        title="EAGLE config",
+        description="EAGLE speculative decoding configuration.",
+        validate_default=True,
+    )
+
+    @model_validator(mode="after")
+    def _derive_eagle_offline(self) -> ModelOptEagleRecipe:
+        self.eagle.eagle_offline = self.data.offline_data_path is not None
+        return self
+
+    @model_validator(mode="after")
+    def _warn_rope_vs_training_seq_len(self) -> ModelOptEagleRecipe:
+        orig_max_pos = self.eagle.eagle_export_rope_scaling.get("original_max_position_embeddings")
+        if orig_max_pos is not None and orig_max_pos != self.training.training_seq_len:
+            warnings.warn(
+                f"eagle.eagle_export_rope_scaling.original_max_position_embeddings ({orig_max_pos}) "
+                f"differs from training.training_seq_len ({self.training.training_seq_len}). "
+                f"This may affect long-context inference quality."
+            )
+        return self
+
+
+class ModelOptDFlashRecipe(ModelOptSpeculativeRecipeBase):
+    """Our config class for DFlash speculative decoding recipes."""
+
+    recipe_type: RecipeType = RecipeType.SPECULATIVE_DFLASH
+
+    dflash: DFlashConfig = ModeloptField(
+        default=DFlashConfig(),
+        title="DFlash config",
+        description="DFlash speculative decoding configuration.",
+        validate_default=True,
+    )
+
+    @model_validator(mode="after")
+    def _derive_dflash_offline(self) -> ModelOptDFlashRecipe:
+        self.dflash.dflash_offline = self.data.offline_data_path is not None
+        return self
+
+
+class ModelOptMedusaRecipe(ModelOptSpeculativeRecipeBase):
+    """Our config class for Medusa speculative decoding recipes."""
+
+    recipe_type: RecipeType = RecipeType.SPECULATIVE_MEDUSA
+
+    medusa: MedusaConfig = ModeloptField(
+        default=MedusaConfig(),
+        title="Medusa config",
+        description="Medusa speculative decoding configuration.",
+        validate_default=True,
+    )
```
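Put together, the recipe classes above imply a single-file YAML layout along these lines (a hypothetical sketch: the section and field names follow the code in this diff, the concrete values are purely illustrative):

```yaml
metadata:
  recipe_type: speculative_eagle
  description: EAGLE draft-head training recipe

model: {}          # SpecModelArgs fields for the base HF model

data:
  offline_data_path: null   # a non-null path flips eagle.eagle_offline on

training:
  training_seq_len: 2048

eagle:
  eagle_export_rope_scaling:
    original_max_position_embeddings: 2048  # a mismatch vs training_seq_len warns
```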
`modelopt/recipe/loader.py`:

```diff
@@ -21,8 +21,17 @@
 from importlib.abc import Traversable
 from pathlib import Path

+from omegaconf import OmegaConf
+
 from ._config_loader import BUILTIN_RECIPES_LIB, load_config
-from .config import ModelOptPTQRecipe, ModelOptRecipeBase, RecipeType
+from .config import (
+    ModelOptDFlashRecipe,
+    ModelOptEagleRecipe,
+    ModelOptMedusaRecipe,
+    ModelOptPTQRecipe,
+    ModelOptRecipeBase,
+    RecipeType,
+)

 __all__ = ["load_config", "load_recipe"]
```

```diff
@@ -49,17 +58,29 @@ def _resolve_recipe_path(recipe_path: str | Path | Traversable) -> Path | Traver
     return recipe_path


-def load_recipe(recipe_path: str | Path | Traversable) -> ModelOptRecipeBase:
-    """Load a recipe from a YAML file or directory.
+def load_recipe(
+    recipe_path: str | Path | Traversable,
+    overrides: list[str] | None = None,
+) -> ModelOptRecipeBase:
+    """Load a recipe from a YAML file or directory, with optional CLI-style overrides.

     ``recipe_path`` can be:

-    * A ``.yml`` / ``.yaml`` file with ``metadata`` and ``quantize`` sections.
-      The suffix may be omitted and will be probed automatically.
-    * A directory containing ``recipe.yml`` (metadata) and ``quantize.yml``.
+    * A ``.yml`` / ``.yaml`` file with ``metadata`` and one of ``quantize`` (PTQ),
+      ``eagle`` (EAGLE speculative decoding), ``dflash`` (DFlash speculative
+      decoding) or ``medusa`` (Medusa speculative decoding) sections. The suffix
+      may be omitted and will be probed automatically.
+    * A directory containing ``recipe.yml`` (metadata) plus ``quantize.yml`` —
+      **PTQ recipes only**. Speculative-decoding recipes are always single YAML files.

     The path may be relative to the built-in recipes library or an absolute /
     relative filesystem path.
+
+    ``overrides`` is an optional list of ``key.path=value`` dotlist entries applied
+    on top of the YAML before Pydantic validation. Values are parsed with
+    ``yaml.safe_load`` so they get proper types (``foo.bar=true`` → bool, ``foo=1``
+    → int, ``foo=[1,2]`` → list, etc.). Only supported when *recipe_path* is a
+    single YAML file.
     """
     resolved = _resolve_recipe_path(recipe_path)
```

```diff
@@ -72,21 +93,43 @@ def load_recipe(recipe_path: str | Path | Traversable) -> ModelOptRecipeBase:
     print(f"[load_recipe] loading: {_display}")

     if resolved.is_file():
-        return _load_recipe_from_file(resolved)
+        return _load_recipe_from_file(resolved, overrides=overrides)

     if resolved.is_dir():
+        if overrides:
+            raise ValueError(
+                "overrides are not supported for directory-format recipes; "
+                "use the single-YAML-file form instead."
+            )
         return _load_recipe_from_dir(resolved)

     raise ValueError(f"Recipe path {recipe_path!r} is not a valid YAML file or directory.")


-def _load_recipe_from_file(recipe_file: Path | Traversable) -> ModelOptRecipeBase:
-    """Load a recipe from a YAML file.
+def _apply_dotlist(data: dict, overrides: list[str]) -> dict:
+    """Merge ``a.b.c=value`` command line overrides on top of ``data`` via OmegaConf."""
+    for entry in overrides:
+        if "=" not in entry:
+            raise ValueError(f"Invalid override (missing '='): {entry!r}")
+    merged = OmegaConf.merge(
+        OmegaConf.create(data),
+        OmegaConf.from_dotlist(list(overrides)),
+    )
+    return OmegaConf.to_container(merged, resolve=True)
```
Comment on lines +109 to +118 (Contributor):

**Don't eagerly resolve OmegaConf interpolations on untrusted recipe overrides.**

`OmegaConf.to_container(merged, resolve=True)` evaluates `${...}` interpolations in the merged config. The repository registers custom resolvers (see the `register_hydra_resolvers` call sites), and the dotlist entries are user-supplied CLI input, so eager resolution can trigger resolver side effects or rewrite values a user meant literally; loaded recipes are also printed in full (e.g. `pprint(recipe)` in `examples/speculative_decoding/main.py`), which would expose whatever the resolvers expand. Change `resolve=True` to `resolve=False` so interpolation strings pass through untouched.

Suggested fix:

```diff
 def _apply_dotlist(data: dict, overrides: list[str]) -> dict:
@@
-    return OmegaConf.to_container(merged, resolve=True)
+    return OmegaConf.to_container(merged, resolve=False)
```
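The dotlist-override semantics under discussion can be sketched in plain Python (a stdlib approximation: the real `_apply_dotlist` goes through OmegaConf, and the loader parses values with `yaml.safe_load` rather than `ast.literal_eval`):

```python
import ast
from copy import deepcopy


def apply_dotlist(data: dict, overrides: list[str]) -> dict:
    """Merge ``a.b.c=value`` overrides onto ``data``, creating nested dicts."""
    out = deepcopy(data)  # leave the caller's mapping untouched
    for entry in overrides:
        if "=" not in entry:
            raise ValueError(f"Invalid override (missing '='): {entry!r}")
        key, raw = entry.split("=", 1)
        try:
            # Rough stand-in for yaml.safe_load: "1" -> int, "[1, 2]" -> list.
            # (YAML would also map "true"/"false" to bools; literal_eval does not.)
            value = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            value = raw  # bare words stay strings
        node = out
        *parents, leaf = key.split(".")
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return out
```

For example, `apply_dotlist({"training": {"lr": 0.1}}, ["training.lr=0.01", "eagle.num_layers=2"])` deepens the mapping to `{"training": {"lr": 0.01}, "eagle": {"num_layers": 2}}`, mirroring how `OmegaConf.from_dotlist` plus `OmegaConf.merge` layer overrides on top of the YAML.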
```diff
+
+
+def _load_recipe_from_file(
+    recipe_file: Path | Traversable,
+    overrides: list[str] | None = None,
+) -> ModelOptRecipeBase:
+    """Load a recipe from a YAML file, optionally applying dotlist overrides.

     The file must contain a ``metadata`` section with at least ``recipe_type``,
-    plus a ``quant_cfg`` mapping and an optional ``algorithm`` for PTQ recipes.
+    plus the algorithm-specific section (``quantize`` / ``eagle`` / ``dflash`` / ``medusa``).
     """
     data = load_config(recipe_file)
+    if overrides:
+        data = _apply_dotlist(data, overrides)

     metadata = data.get("metadata", {})
     recipe_type = metadata.get("recipe_type")
```

```diff
@@ -101,6 +144,36 @@ def _load_recipe_from_file(recipe_file: Path | Traversable) -> ModelOptRecipeBas
         description=metadata.get("description", "PTQ recipe."),
         quantize=data["quantize"],
     )
+    if recipe_type == RecipeType.SPECULATIVE_EAGLE:
+        if "eagle" not in data:
+            raise ValueError(f"EAGLE recipe file {recipe_file} must contain 'eagle'.")
+        return ModelOptEagleRecipe(
+            description=metadata.get("description", "EAGLE speculative decoding recipe."),
+            model=data.get("model") or {},
+            data=data.get("data") or {},
+            training=data.get("training") or {},
+            eagle=data["eagle"],
+        )
+    if recipe_type == RecipeType.SPECULATIVE_DFLASH:
+        if "dflash" not in data:
+            raise ValueError(f"DFlash recipe file {recipe_file} must contain 'dflash'.")
+        return ModelOptDFlashRecipe(
+            description=metadata.get("description", "DFlash speculative decoding recipe."),
+            model=data.get("model") or {},
+            data=data.get("data") or {},
+            training=data.get("training") or {},
+            dflash=data["dflash"],
+        )
+    if recipe_type == RecipeType.SPECULATIVE_MEDUSA:
+        if "medusa" not in data:
+            raise ValueError(f"Medusa recipe file {recipe_file} must contain 'medusa'.")
+        return ModelOptMedusaRecipe(
+            description=metadata.get("description", "Medusa speculative decoding recipe."),
+            model=data.get("model") or {},
+            data=data.get("data") or {},
+            training=data.get("training") or {},
+            medusa=data["medusa"],
+        )
     raise ValueError(f"Unsupported recipe type: {recipe_type!r}")
```
Review comment (Contributor):

**Lazy-load speculative/HF config imports instead of importing them at module import time.**

Lines 27-32 hard-import speculative/HF modules in a core recipe module. This makes optional extras effectively mandatory for import paths that don't use speculative recipes (e.g., PTQ-only usage). Please gate these with lazy/plugin loading and keep type-only references behind deferred annotations / `TYPE_CHECKING`. As per coding guidelines: "Avoid hard imports of optional dependencies at module level; features should be gated by install extras (`[onnx]`, `[hf]`, `[all]`) and loaded lazily via `import_plugin()`".
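A minimal sketch of the requested gating, assuming the project's `import_plugin()` helper is not shown here; `importlib.import_module` inside the function plus a `TYPE_CHECKING` block gives the same effect (the module path and extra name below are illustrative, not the project's actual API):

```python
from __future__ import annotations

import importlib
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Type-only reference: erased at runtime, so importing this module never
    # pulls in the optional dependency. (Hypothetical module path.)
    from some_optional_pkg import HeavyConfig  # noqa: F401


def load_optional(module: str, attr: str, extra: str) -> Any:
    """Import an optional symbol only at the point of use."""
    try:
        mod = importlib.import_module(module)
    except ImportError as exc:
        raise ImportError(
            f"{module!r} is not installed; install the [{extra}] extra to use it."
        ) from exc
    return getattr(mod, attr)


# Resolved lazily, e.g. inside the recipe loader rather than at module import:
# EagleConfig = load_optional("modelopt.torch.speculative.config", "EagleConfig", "hf")
```

With this pattern, `import modelopt.recipe.config` succeeds in a PTQ-only environment, and the ImportError only surfaces when a speculative recipe is actually loaded.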