Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 68 additions & 218 deletions examples/speculative_decoding/main.py

Large diffs are not rendered by default.

103 changes: 102 additions & 1 deletion modelopt/recipe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,28 @@

from __future__ import annotations

import warnings
from enum import Enum

from pydantic import field_validator
from pydantic import field_validator, model_validator

from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
from modelopt.torch.quantization.config import QuantizeConfig
from modelopt.torch.speculative.config import DFlashConfig, EagleConfig, MedusaConfig
from modelopt.torch.speculative.plugins.hf_training_args import DataArguments as SpecDataArgs
from modelopt.torch.speculative.plugins.hf_training_args import ModelArguments as SpecModelArgs
from modelopt.torch.speculative.plugins.hf_training_args import (
TrainingArguments as SpecTrainingArgs,
)
Comment on lines +27 to +32
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Lazy-load speculative/HF config imports instead of importing them at module import time.

Line 27-Line 32 hard-import speculative/HF modules in a core recipe module. This makes optional extras effectively mandatory for import paths that don’t use speculative recipes (e.g., PTQ-only usage). Please gate these with lazy/plugin loading and keep type-only references behind deferred annotations/TYPE_CHECKING.

As per coding guidelines: "Avoid hard imports of optional dependencies at module level; features should be gated by install extras ([onnx], [hf], [all]) and loaded lazily via import_plugin()".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/recipe/config.py` around lines 27 - 32, The module currently
performs hard imports of speculative/HF classes DFlashConfig, EagleConfig,
MedusaConfig and SpecDataArgs/SpecModelArgs/SpecTrainingArgs at import time;
change this to lazy/plugin loading by moving those imports behind
typing.TYPE_CHECKING for type annotations and using the project's
import_plugin() (or an equivalent lazy import helper) where the classes are
actually needed at runtime (e.g., in factory functions or recipe registration).
Keep type-only references via from typing import TYPE_CHECKING and string
annotations or if TYPE_CHECKING: import the speculative classes, and replace
direct module-level imports with calls to
import_plugin("modelopt.torch.speculative...") right before using
DFlashConfig/EagleConfig/MedusaConfig or the Spec* argument classes.



class RecipeType(str, Enum):
"""List of recipe types."""

PTQ = "ptq"
SPECULATIVE_EAGLE = "speculative_eagle"
SPECULATIVE_DFLASH = "speculative_dflash"
SPECULATIVE_MEDUSA = "speculative_medusa"
# QAT = "qat" # Not implemented yet, will be added in the future.


Expand Down Expand Up @@ -72,3 +82,94 @@ class ModelOptPTQRecipe(ModelOptRecipeBase):
description="PTQ config containing quant_cfg and algorithm.",
validate_default=True,
)


class ModelOptSpeculativeRecipeBase(ModelOptRecipeBase):
    """Base class for speculative-decoding recipes.

    Unlike PTQ, speculative-decoding is a training-time optimization: the draft head is trained
    with HF Trainer. We therefore bundle ``model`` / ``data`` / ``training`` sections into the
    recipe so a single YAML is the full experiment spec. Each section is a typed Pydantic model
    (see :mod:`modelopt.torch.speculative.plugins.hf_training_args`) so field typos and bad
    values are caught at recipe-load time; HF trainer fields pass through
    ``TrainingArguments`` via ``extra='allow'``.
    """

    # NOTE(review): the ``default=Spec*Args()`` expressions below are evaluated at
    # module import time, so any validators on those classes (e.g. an
    # ``@model_validator(mode="after")`` on TrainingArguments) run on import —
    # presumably inspecting the host (WORLD_SIZE etc.); verify. Switching to
    # ``default_factory=`` would defer instantiation, but requires ModeloptField
    # to accept ``default_factory`` — TODO confirm before changing.

    # HF model section: which base model the draft head is trained against.
    model: SpecModelArgs = ModeloptField(
        default=SpecModelArgs(),
        title="HF model args",
        description="ModelArguments for the base HF model to train a draft head against.",
        validate_default=True,
    )
    # HF data section: training and/or offline (pre-computed hidden states) dataset.
    data: SpecDataArgs = ModeloptField(
        default=SpecDataArgs(),
        title="HF data args",
        description="DataArguments for the training/offline dataset.",
        validate_default=True,
    )
    # HF training section: unknown keys flow through to HF TrainingArguments
    # because the model is declared with ``extra='allow'``.
    training: SpecTrainingArgs = ModeloptField(
        default=SpecTrainingArgs(),
        title="HF training args",
        description="Speculative-decoding extensions; HF trainer fields flow through as extras.",
        validate_default=True,
    )
Comment on lines +98 to +115
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

In Python and Pydantic v2, are expressions like `default=SpecTrainingArgs()` evaluated at class definition/import time, and is `default_factory=SpecTrainingArgs` the recommended way to avoid eager instantiation and validator side effects for nested model defaults?

💡 Result:

Yes—if you write default=SpecTrainingArgs (i.e., you pass an already-constructed instance as the field default), that instance is created eagerly at Python class definition/import time. In contrast, default_factory=SpecTrainingArgs (i.e., a callable) defers instantiation until Pydantic needs a default for a model instance. In Pydantic v2, the docs explicitly describe default_factory as a callable that is “called to generate a default value”, whereas default is a fixed default value. So using default_factory=SpecTrainingArgs is the recommended way to avoid eager instantiation for nested model defaults. About “validator side effects”/avoid validator-related surprises: Pydantic’s default_factory can be designed to run only when Pydantic is producing the default for the field. Also, Pydantic has shown in issues that the timing/context of default_factory invocation matters (e.g., bugs/confusing errors when validation of earlier fields fails while a data-taking default factory is being called). While those issues are about specific dependent-default scenarios, the general guidance still stands: avoid constructing nested models in the default= argument when you want Pydantic to control when instantiation happens. Practical pattern (recommended) - Prefer: sub: SubModel = Field(default_factory=SubModel) - Avoid: sub: SubModel = Field(default=SubModel) Example - If SpecTrainingArgs is a BaseModel (or otherwise Pydantic-validated type): use default_factory=SpecTrainingArgs so each outer model instance gets its own default created at instantiation time, not at import/class-definition time.

Citations:


🏁 Script executed:

# First, find and examine the SpecTrainingArgs implementation
find . -type f -name "*.py" | xargs grep -l "class SpecTrainingArgs" | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Read the config.py file around lines 97-114 and 122-153 to understand the full context
sed -n '90,160p' modelopt/recipe/config.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 2815


🏁 Script executed:

# Check if SpecTrainingArgs has __init__ or validators with side effects
grep -A 30 "class SpecTrainingArgs" $(find . -type f -name "*.py" -path "*/speculative/*")

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Search for SpecTrainingArgs definition more broadly
rg "class SpecTrainingArgs" --type py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Also check the import statement at the top of config.py
head -50 modelopt/recipe/config.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1977


🏁 Script executed:

# Find the TrainingArguments class in hf_training_args
find . -path "*speculative/plugins/hf_training_args*" -type f

Repository: NVIDIA/Model-Optimizer

Length of output: 123


🏁 Script executed:

# Once we find it, examine the TrainingArguments.__init__ or field validators
rg "_fill_parallelism" --type py -B 5 -A 10

Repository: NVIDIA/Model-Optimizer

Length of output: 1622


🏁 Script executed:

# Get the full _fill_parallelism method and surrounding context
sed -n '120,180p' modelopt/torch/speculative/plugins/hf_training_args.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 642


🏁 Script executed:

# Also check if there are any imports of accelerate in this file
grep -n "accelerate\|ParallelismConfig" modelopt/torch/speculative/plugins/hf_training_args.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1012


🏁 Script executed:

# Verify the full validator implementation
sed -n '95,130p' modelopt/torch/speculative/plugins/hf_training_args.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 1942


🏁 Script executed:

# Check if ModeloptField supports default_factory parameter
rg "def ModeloptField|class ModeloptField" --type py -A 20

Repository: NVIDIA/Model-Optimizer

Length of output: 1492


🏁 Script executed:

# Check the full ModeloptField implementation
cat -n modelopt/torch/opt/config.py | sed -n '1,100p'

Repository: NVIDIA/Model-Optimizer

Length of output: 4970


🏁 Script executed:

# Verify if default_factory can be passed through **kwargs to Field
rg "default_factory" modelopt/torch/opt/config.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check if there are any examples of default_factory usage in the codebase
rg "default_factory" --type py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 3012


🏁 Script executed:

# Check if ModeloptField is used anywhere with default_factory attempts
rg "ModeloptField.*default_factory" --type py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Also verify the assertion in ModeloptField more carefully
sed -n '51,55p' modelopt/torch/opt/config.py

Repository: NVIDIA/Model-Optimizer

Length of output: 348


🏁 Script executed:

# Check if using default_factory with ModeloptField would work given the assertion
python3 << 'EOF'
# Simulate what would happen if we try to use default_factory with ModeloptField
from pydantic_core import PydanticUndefined

# This is what ModeloptField does
def ModeloptField(default=PydanticUndefined, **kwargs):
    assert default is not PydanticUndefined, "A default value must be set for ModeloptField."
    return {"default": default, **kwargs}

# Try to pass default_factory
try:
    result = ModeloptField(default_factory=lambda: {})
    print("Success:", result)
except AssertionError as e:
    print("AssertionError:", e)
except TypeError as e:
    print("TypeError:", e)
EOF

Repository: NVIDIA/Model-Optimizer

Length of output: 129


🏁 Script executed:

# Check if there's a way to modify ModeloptField to support default_factory
grep -B 5 -A 15 "def ModeloptField" modelopt/torch/opt/config.py

Repository: NVIDIA/Model-Optimizer

Length of output: 928


🏁 Script executed:

# Verify if using direct Field instead of ModeloptField would work in the recipe file
rg "from pydantic import.*Field" modelopt/recipe/config.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check what imports are available at the top of config.py
head -30 modelopt/recipe/config.py | grep -E "^from|^import"

Repository: NVIDIA/Model-Optimizer

Length of output: 636


🏁 Script executed:

# Check if we need to modify ModeloptField to support default_factory
rg "ModeloptField" modelopt/torch/opt/config.py -B 3 -A 10

Repository: NVIDIA/Model-Optimizer

Length of output: 1159


Use lazy defaults for nested recipe sections to avoid import-time validator side effects.

SpecTrainingArgs(), SpecModelArgs(), and related config objects are instantiated when the module is imported, not when a recipe instance is created. Because SpecTrainingArgs has an @model_validator(mode="after") decorator on _fill_parallelism(), that validator runs immediately at import time, reading WORLD_SIZE and potentially importing accelerate on multi-GPU hosts. This makes importing modelopt.recipe.config trigger host inspection and can expose host-specific derived defaults.

The suggested fix requires updating ModeloptField (in modelopt/torch/opt/config.py) to support default_factory in addition to default. Currently, ModeloptField asserts that a default value must be provided, which prevents using default_factory=.... Either:

  1. Modify ModeloptField to accept default_factory as an alternative (passing it through to pydantic.Field), or
  2. Use pydantic.Field directly instead of ModeloptField for these nested config fields and adjust the ModeloptField assertion accordingly.

Once that foundation is in place, the diff changes shown (switching to default_factory=SpecTrainingArgs, etc.) can be applied.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/recipe/config.py` around lines 97 - 114, The nested config fields
(model, data, training) currently instantiate
SpecModelArgs/SpecDataArgs/SpecTrainingArgs at import time; update ModeloptField
to accept a default_factory alternative (or remove the assertion forcing default
so callers can use pydantic.Field) so you can switch those fields to lazy
defaults like default_factory=SpecTrainingArgs (or use
pydantic.Field(default_factory=...)) instead of default=SpecTrainingArgs();
specifically modify the ModeloptField implementation (the class/function named
ModeloptField) to accept and pass through default_factory to pydantic.Field (and
stop asserting a concrete default), then change the three fields using
ModeloptField to use default_factory=SpecModelArgs,
default_factory=SpecDataArgs, and default_factory=SpecTrainingArgs.



class ModelOptEagleRecipe(ModelOptSpeculativeRecipeBase):
    """Our config class for EAGLE speculative decoding recipes."""

    # Discriminator used by the loader to pick this recipe class.
    recipe_type: RecipeType = RecipeType.SPECULATIVE_EAGLE

    # Algorithm-specific section of the recipe YAML (``eagle:`` key).
    eagle: EagleConfig = ModeloptField(
        default=EagleConfig(),
        title="EAGLE config",
        description="EAGLE speculative decoding configuration.",
        validate_default=True,
    )

    @model_validator(mode="after")
    def _derive_eagle_offline(self) -> ModelOptEagleRecipe:
        """Derive ``eagle.eagle_offline`` from ``data.offline_data_path``.

        ``eagle_offline`` is not user-configurable: it is overwritten here on every
        validation pass, so offline training is enabled exactly when an offline
        data path is present in the ``data`` section.
        """
        self.eagle.eagle_offline = self.data.offline_data_path is not None
        return self

    @model_validator(mode="after")
    def _warn_rope_vs_training_seq_len(self) -> ModelOptEagleRecipe:
        """Warn (not fail) when export RoPE scaling disagrees with the training seq len."""
        # ``original_max_position_embeddings`` may legitimately be absent; only
        # compare when the user set it explicitly.
        orig_max_pos = self.eagle.eagle_export_rope_scaling.get("original_max_position_embeddings")
        if orig_max_pos is not None and orig_max_pos != self.training.training_seq_len:
            warnings.warn(
                f"eagle.eagle_export_rope_scaling.original_max_position_embeddings ({orig_max_pos}) "
                f"differs from training.training_seq_len ({self.training.training_seq_len}). "
                f"This may affect long-context inference quality."
            )
        return self


class ModelOptDFlashRecipe(ModelOptSpeculativeRecipeBase):
    """Our config class for DFlash speculative decoding recipes."""

    # Discriminator used by the loader to pick this recipe class.
    recipe_type: RecipeType = RecipeType.SPECULATIVE_DFLASH

    # Algorithm-specific section of the recipe YAML (``dflash:`` key).
    dflash: DFlashConfig = ModeloptField(
        default=DFlashConfig(),
        title="DFlash config",
        description="DFlash speculative decoding configuration.",
        validate_default=True,
    )

    @model_validator(mode="after")
    def _derive_dflash_offline(self) -> ModelOptDFlashRecipe:
        """Derive ``dflash.dflash_offline`` from ``data.offline_data_path``.

        Not user-configurable: overwritten on every validation pass so offline
        training is enabled exactly when an offline data path is present.
        """
        self.dflash.dflash_offline = self.data.offline_data_path is not None
        return self


class ModelOptMedusaRecipe(ModelOptSpeculativeRecipeBase):
    """Our config class for Medusa speculative decoding recipes."""

    # Discriminator used by the loader to pick this recipe class.
    recipe_type: RecipeType = RecipeType.SPECULATIVE_MEDUSA

    # Algorithm-specific section of the recipe YAML (``medusa:`` key).
    medusa: MedusaConfig = ModeloptField(
        default=MedusaConfig(),
        title="Medusa config",
        description="Medusa speculative decoding configuration.",
        validate_default=True,
    )
93 changes: 83 additions & 10 deletions modelopt/recipe/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,17 @@
from importlib.abc import Traversable
from pathlib import Path

from omegaconf import OmegaConf

from ._config_loader import BUILTIN_RECIPES_LIB, load_config
from .config import ModelOptPTQRecipe, ModelOptRecipeBase, RecipeType
from .config import (
ModelOptDFlashRecipe,
ModelOptEagleRecipe,
ModelOptMedusaRecipe,
ModelOptPTQRecipe,
ModelOptRecipeBase,
RecipeType,
)

__all__ = ["load_config", "load_recipe"]

Expand All @@ -49,17 +58,29 @@ def _resolve_recipe_path(recipe_path: str | Path | Traversable) -> Path | Traver
return recipe_path


def load_recipe(recipe_path: str | Path | Traversable) -> ModelOptRecipeBase:
"""Load a recipe from a YAML file or directory.
def load_recipe(
recipe_path: str | Path | Traversable,
overrides: list[str] | None = None,
) -> ModelOptRecipeBase:
"""Load a recipe from a YAML file or directory, with optional CLI-style overrides.

``recipe_path`` can be:

* A ``.yml`` / ``.yaml`` file with ``metadata`` and ``quantize`` sections.
The suffix may be omitted and will be probed automatically.
* A directory containing ``recipe.yml`` (metadata) and ``quantize.yml``.
* A ``.yml`` / ``.yaml`` file with ``metadata`` and one of ``quantize`` (PTQ),
``eagle`` (EAGLE speculative decoding), ``dflash`` (DFlash speculative
decoding) or ``medusa`` (Medusa speculative decoding) sections. The suffix
may be omitted and will be probed automatically.
* A directory containing ``recipe.yml`` (metadata) plus ``quantize.yml`` —
**PTQ recipes only**. Speculative-decoding recipes are always single YAML files.

The path may be relative to the built-in recipes library or an absolute /
relative filesystem path.

``overrides`` is an optional list of ``key.path=value`` dotlist entries applied
on top of the YAML before Pydantic validation. Values are parsed with
``yaml.safe_load`` so they get proper types (``foo.bar=true`` → bool, ``foo=1``
→ int, ``foo=[1,2]`` → list, etc.). Only supported when *recipe_path* is a
single YAML file.
"""
resolved = _resolve_recipe_path(recipe_path)

Expand All @@ -72,21 +93,43 @@ def load_recipe(recipe_path: str | Path | Traversable) -> ModelOptRecipeBase:
print(f"[load_recipe] loading: {_display}")

if resolved.is_file():
return _load_recipe_from_file(resolved)
return _load_recipe_from_file(resolved, overrides=overrides)

if resolved.is_dir():
if overrides:
raise ValueError(
"overrides are not supported for directory-format recipes; "
"use the single-YAML-file form instead."
)
return _load_recipe_from_dir(resolved)

raise ValueError(f"Recipe path {recipe_path!r} is not a valid YAML file or directory.")


def _load_recipe_from_file(recipe_file: Path | Traversable) -> ModelOptRecipeBase:
"""Load a recipe from a YAML file.
def _apply_dotlist(data: dict, overrides: list[str]) -> dict:
    """Merge ``a.b.c=value`` command line overrides on top of ``data`` via OmegaConf.

    Args:
        data: Parsed recipe YAML as a plain dict.
        overrides: Dotlist entries of the form ``key.path=value``. Values are typed
            by OmegaConf's YAML parsing (``true`` -> bool, ``1`` -> int, ...).

    Returns:
        A plain dict with the overrides merged on top of ``data``.

    Raises:
        ValueError: If an override entry does not contain ``'='``.
    """
    # Validate all entries up front so the user gets a clear error instead of an
    # OmegaConf parse failure mid-merge.
    for entry in overrides:
        if "=" not in entry:
            raise ValueError(f"Invalid override (missing '='): {entry!r}")
    merged = OmegaConf.merge(
        OmegaConf.create(data),
        OmegaConf.from_dotlist(list(overrides)),
    )
    # resolve=False: do NOT eagerly expand ``${...}`` interpolations (e.g.
    # ``${oc.env:...}``) in user-supplied YAML/overrides before Pydantic
    # validation — eager resolution would bake environment values (potentially
    # secrets) into the recipe object, which callers print/log downstream.
    return OmegaConf.to_container(merged, resolve=False)
Comment on lines +109 to +118
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== OmegaConf resolver registrations =="
rg -n -C2 'OmegaConf\.(register_new_resolver|register_resolver)' --type=py

echo
echo "== Eager resolution sites =="
rg -n -C2 'to_container\(.*resolve=True' --type=py

echo
echo "== Full recipe printing/logging sites =="
rg -n -C2 'pprint\(recipe\)|print_rank_0\(.*recipe|print\(.*recipe' --type=py

Repository: NVIDIA/Model-Optimizer

Length of output: 6579


🏁 Script executed:

# Check the load_recipe() function to understand the full context
cat -n modelopt/recipe/loader.py | head -130

Repository: NVIDIA/Model-Optimizer

Length of output: 5940


🏁 Script executed:

# Check examples/speculative_decoding/main.py to see the recipe handling
rg -B10 -A5 'pprint\(recipe\)' examples/speculative_decoding/main.py

Repository: NVIDIA/Model-Optimizer

Length of output: 686


🏁 Script executed:

# Check if register_hydra_resolvers is called during recipe loading
rg -n 'register_hydra_resolvers' --type=py

Repository: NVIDIA/Model-Optimizer

Length of output: 633


🏁 Script executed:

# Check how the result of _apply_dotlist is used
rg -A10 '_apply_dotlist' modelopt/recipe/loader.py | head -30

Repository: NVIDIA/Model-Optimizer

Length of output: 984


🏁 Script executed:

# Check the usage context to see if resolve=True is actually necessary
rg -B5 -A15 '_load_recipe_from_file' modelopt/recipe/loader.py | grep -A20 'def _load_recipe_from_file'

Repository: NVIDIA/Model-Optimizer

Length of output: 673


Don't eagerly resolve OmegaConf interpolations on untrusted recipe overrides.

OmegaConf.to_container(..., resolve=True) at line 118 evaluates ${...} expressions before Pydantic validation. User-supplied YAML and CLI overrides can contain ${oc.env:...} patterns that resolve to environment variables. Since examples/speculative_decoding/main.py calls pprint(recipe) on the loaded config, resolved secrets will leak into logs.

Change resolve=True to resolve=False. Pydantic handles unresolved OmegaConf dicts fine, and type conversion is already performed by YAML parsing.

Suggested fix
 def _apply_dotlist(data: dict, overrides: list[str]) -> dict:
@@
-    return OmegaConf.to_container(merged, resolve=True)
+    return OmegaConf.to_container(merged, resolve=False)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/recipe/loader.py` around lines 109 - 118, In _apply_dotlist, avoid
eagerly resolving OmegaConf interpolations by changing the
OmegaConf.to_container call so it does not resolve interpolations; specifically,
update the call in function _apply_dotlist to use resolve=False instead of
resolve=True when converting the merged OmegaConf to a plain dict so
user-supplied overrides (e.g., ${...} patterns) are preserved and not expanded
before Pydantic validation.



def _load_recipe_from_file(
recipe_file: Path | Traversable,
overrides: list[str] | None = None,
) -> ModelOptRecipeBase:
"""Load a recipe from a YAML file, optionally applying dotlist overrides.

The file must contain a ``metadata`` section with at least ``recipe_type``,
plus a ``quant_cfg`` mapping and an optional ``algorithm`` for PTQ recipes.
plus the algorithm-specific section (``quantize`` / ``eagle`` / ``dflash`` / ``medusa``).
"""
data = load_config(recipe_file)
if overrides:
data = _apply_dotlist(data, overrides)

metadata = data.get("metadata", {})
recipe_type = metadata.get("recipe_type")
Expand All @@ -101,6 +144,36 @@ def _load_recipe_from_file(recipe_file: Path | Traversable) -> ModelOptRecipeBas
description=metadata.get("description", "PTQ recipe."),
quantize=data["quantize"],
)
if recipe_type == RecipeType.SPECULATIVE_EAGLE:
if "eagle" not in data:
raise ValueError(f"EAGLE recipe file {recipe_file} must contain 'eagle'.")
return ModelOptEagleRecipe(
description=metadata.get("description", "EAGLE speculative decoding recipe."),
model=data.get("model") or {},
data=data.get("data") or {},
training=data.get("training") or {},
eagle=data["eagle"],
)
if recipe_type == RecipeType.SPECULATIVE_DFLASH:
if "dflash" not in data:
raise ValueError(f"DFlash recipe file {recipe_file} must contain 'dflash'.")
return ModelOptDFlashRecipe(
description=metadata.get("description", "DFlash speculative decoding recipe."),
model=data.get("model") or {},
data=data.get("data") or {},
training=data.get("training") or {},
dflash=data["dflash"],
)
if recipe_type == RecipeType.SPECULATIVE_MEDUSA:
if "medusa" not in data:
raise ValueError(f"Medusa recipe file {recipe_file} must contain 'medusa'.")
return ModelOptMedusaRecipe(
description=metadata.get("description", "Medusa speculative decoding recipe."),
model=data.get("model") or {},
data=data.get("data") or {},
training=data.get("training") or {},
medusa=data["medusa"],
)
raise ValueError(f"Unsupported recipe type: {recipe_type!r}")


Expand Down
82 changes: 12 additions & 70 deletions modelopt/torch/speculative/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@

"""Configurations for speculative decoding modes."""

import warnings
from copy import deepcopy
from typing import Any

from pydantic import ValidationInfo, model_validator
from pydantic import model_validator

from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField

Expand Down Expand Up @@ -71,7 +69,7 @@ class DFlashConfig(ModeloptBaseConfig):
default=False,
description=(
"Whether to use detached DFlash (offline training from pre-computed hidden states). "
"Auto-derived from data_args.offline_data_path during validation — not user-configurable."
"Derived by ModelOptDFlashRecipe from data.offline_data_path; not user-configurable."
),
)

Expand Down Expand Up @@ -103,10 +101,12 @@ class DFlashConfig(ModeloptBaseConfig):
default=True, description="Whether to report eval accuracy."
)

dflash_mask_token_id: int = ModeloptField(
dflash_mask_token_id: int | None = ModeloptField(
default=None,
description="Token ID used for masked (unknown) positions. "
"Set explicitly or auto-detected from tokenizer.mask_token_id in main.py.",
description=(
"Token ID used for masked (unknown) positions. Set explicitly in the recipe YAML, "
"or left unset to fall back to ``tokenizer.mask_token_id`` at training time."
),
)

dflash_architecture_config: dict = ModeloptField(
Expand All @@ -118,43 +118,6 @@ class DFlashConfig(ModeloptBaseConfig):
description="Whether to use torch.compile on DFlash forward/loss methods.",
)

@model_validator(mode="before")
@classmethod
def _derive_dflash_offline(cls, data: Any, info: ValidationInfo) -> Any:
"""Derive ``dflash_offline`` from ``data_args.offline_data_path``.

This field is auto-derived, not user-configurable: when context provides
``data_args``, the derived value overrides any user-supplied value.
"""
ctx = info.context if info.context else {}
data_args = ctx.get("data_args")
if data_args is not None and isinstance(data, dict):
data["dflash_offline"] = getattr(data_args, "offline_data_path", None) is not None
return data

@model_validator(mode="before")
@classmethod
def _resolve_mask_token_id(cls, data: Any, info: ValidationInfo) -> Any:
"""Auto-detect ``dflash_mask_token_id`` from tokenizer when provided in context."""
if not isinstance(data, dict) or data.get("dflash_mask_token_id") is not None:
return data
ctx = info.context if info.context else {}
tokenizer = ctx.get("tokenizer")
if tokenizer is not None and getattr(tokenizer, "mask_token_id", None) is not None:
data["dflash_mask_token_id"] = tokenizer.mask_token_id
return data

@model_validator(mode="after")
def _check_mask_token_id(self) -> "DFlashConfig":
"""Validate that mask_token_id is set after all resolution attempts."""
if self.dflash_mask_token_id is None:
raise ValueError(
"dflash_mask_token_id is required. Set it in the config YAML "
"(dflash.dflash_mask_token_id=TOKEN_ID) or ensure the tokenizer "
"has a mask_token_id attribute."
)
return self


class MedusaConfig(ModeloptBaseConfig):
"""Medusa config."""
Expand All @@ -174,7 +137,11 @@ class EagleConfig(ModeloptBaseConfig):
"""Eagle config."""

eagle_offline: bool = ModeloptField(
default=False, description=("Whether to use detached Eagle.")
default=False,
description=(
"Whether to use detached Eagle. Derived by ModelOptEagleRecipe from "
"data.offline_data_path; not user-configurable."
),
)

eagle_hidden_state_distillation: bool = ModeloptField(
Expand Down Expand Up @@ -292,16 +259,6 @@ class EagleConfig(ModeloptBaseConfig):
),
)

@model_validator(mode="before")
@classmethod
def _derive_eagle_offline(cls, data: Any, info: ValidationInfo) -> Any:
"""Derive ``eagle_offline`` from ``data_args.offline_data_path`` when provided in context."""
ctx = info.context if info.context else {}
data_args = ctx.get("data_args")
if data_args is not None and isinstance(data, dict):
data["eagle_offline"] = data_args.offline_data_path is not None
return data

@model_validator(mode="after")
def _check_rope_scaling_consistency(self) -> "EagleConfig":
if not self.eagle_export_rope_scaling:
Expand All @@ -315,18 +272,3 @@ def _check_rope_scaling_consistency(self) -> "EagleConfig":
f"training rope_type is 'default' (no scaling)."
)
return self

@model_validator(mode="after")
def _warn_rope_vs_training_seq_len(self, info: ValidationInfo) -> "EagleConfig":
ctx = info.context if info.context else {}
training_args = ctx.get("training_args")
if training_args is None:
return self
orig_max_pos = self.eagle_export_rope_scaling.get("original_max_position_embeddings")
if orig_max_pos is not None and orig_max_pos != training_args.training_seq_len:
warnings.warn(
f"eagle_export_rope_scaling.original_max_position_embeddings ({orig_max_pos}) "
f"differs from training_seq_len ({training_args.training_seq_len}). "
f"This may affect long-context inference quality."
)
return self
Loading
Loading