[Refactor] speculative decoding: use mto config subsystem #1328
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the repository settings.
📝 Walkthrough

Typed single-file speculative-decoding recipes and loader dotlist overrides are added; the loader dispatches on `metadata.recipe_type` to select the recipe class.

Sequence Diagram(s)

```mermaid
sequenceDiagram
actor User
participant Main as Speculative Main Script
participant Loader as load_recipe()
participant Parser as YAML Parser
participant Validator as Pydantic Validator
User->>Main: run with recipe_path + overrides
Main->>Loader: load_recipe(path, overrides)
Loader->>Parser: parse YAML file
Parser-->>Loader: parsed dict
alt overrides provided
Loader->>Loader: apply dotlist overrides (OmegaConf)
end
Loader->>Validator: select recipe class by metadata.recipe_type
Validator->>Validator: validate model / data / training sections
Validator-->>Loader: validated recipe instance
Loader-->>Main: return typed recipe
Main->>Main: construct HfTrainingArguments from recipe.training
Main->>Main: route conversion/execution by recipe type (eagle/dflash/medusa)
```

```mermaid
sequenceDiagram
participant Script as Main Script
participant Recipe as Recipe Config
participant HfArgs as HfTrainingArguments
participant Trainer as HF Trainer
Script->>Recipe: load_recipe(path)
Recipe-->>Script: typed recipe
Script->>HfArgs: HfTrainingArguments.from_dict(recipe.training.model_dump())
HfArgs->>HfArgs: infer dp_shard_size from WORLD_SIZE (if needed)
HfArgs->>HfArgs: validate speculative fields (training_seq_len, estimate_ar, ...)
HfArgs-->>Script: validated hf args
rect rgba(100,150,200,0.5)
Script->>Trainer: perform recipe-specific conversion & training
end
    Trainer-->>Script: complete
```
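Below is a minimal sketch of this flow, assuming the loader API described in this PR (`load_recipe` taking a recipe path plus dotlist overrides); exact signatures and return types are assumptions based on the walkthrough, not the merged code:

```python
# Sketch of the recipe-loading flow above; names follow the walkthrough and the
# PR description, but the exact signatures are assumptions.
from modelopt.recipe.loader import load_recipe

# Dotlist overrides mirror the CLI usage shown in the PR description.
recipe = load_recipe(
    "general/speculative_decoding/eagle3",
    overrides=[
        "model.model_name_or_path=meta-llama/Llama-3.2-1B",
        "data.data_path=train.jsonl",
        "training.output_dir=ckpts/test",
    ],
)

# The loader returns a typed recipe (e.g., ModelOptEagleRecipe); training args
# are then rebuilt from its validated `training` section.
print(type(recipe).__name__)
print(recipe.training.model_dump())
```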
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks: ✅ Passed checks (6 passed)
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@ Coverage Diff @@
## main #1328 +/- ##
==========================================
+ Coverage 74.60% 74.64% +0.03%
==========================================
Files 467 468 +1
Lines 50176 50260 +84
==========================================
+ Hits 37435 37517 +82
- Misses 12741 12743 +2
```

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/speculative_decoding/main.py (1)
153-186: ⚠️ Potential issue | 🟠 Major

Medusa recipe skips `data_module` creation, causing an undefined variable error.

When `recipe` is `ModelOptMedusaRecipe`, the code at line 154 calls `mtsp.convert()` but then falls through to lines 203-211 where `data_module` is only created for `ModelOptEagleRecipe` or `ModelOptDFlashRecipe`. This leaves `data_module` undefined, causing a `NameError` at line 226 when passed to `EagleTrainerWithAccLog`.

Proposed fix: Add Medusa to data_module creation or handle separately
```diff
     print_rank_0("Loading dataset...")
     is_dflash = isinstance(recipe, ModelOptDFlashRecipe)
-    if isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe)):
+    if isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe, ModelOptMedusaRecipe)):
         data_module = make_speculative_data_module(
             tokenizer,
             recipe.data,
             train_len=training_args.training_seq_len,
             answer_only_loss=training_args.answer_only_loss,
             shift_labels=not is_dflash,
         )
+    else:
+        raise ValueError(f"Unsupported speculative recipe type: {type(recipe).__name__}")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/main.py` around lines 153 - 186, The Medusa branch (when recipe is ModelOptMedusaRecipe) calls mtsp.convert but never creates the data_module used later by EagleTrainerWithAccLog, causing a NameError; update the ModelOptMedusaRecipe branch to either construct the same data_module as done for ModelOptEagleRecipe/ModelOptDFlashRecipe (using recipe.data and the existing data-module factory logic) or explicitly set data_module to an appropriate value (or None) and handle that case before calling EagleTrainerWithAccLog so data_module is always defined; refer to the branches around ModelOptMedusaRecipe, ModelOptEagleRecipe, ModelOptDFlashRecipe, mtsp.convert, recipe.data, and EagleTrainerWithAccLog to place the fix.
🧹 Nitpick comments (1)
examples/speculative_decoding/main.py (1)
61-76: Field synchronization between `HfTrainingArguments` and `SpecTrainingArgs` should be explicit.

The docstring correctly notes these field sets "MUST stay in sync" with `modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments`. Consider adding a test or assertion to catch drift between these two definitions automatically.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/main.py` around lines 61 - 76, Add an automated check to ensure HfTrainingArguments and the plugin TrainingArguments stay in sync: write a small unit test or an import-time assertion that imports HfTrainingArguments (examples.speculative_decoding.main.HfTrainingArguments) and modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments, extracts their dataclass/field names (and optionally defaults/types), and fails if any field names or default values differ; place the check in tests (preferred) or at module import to catch drift early.
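A minimal sketch of such a drift test, assuming both classes expose their declared fields either as Pydantic v2 models (`model_fields`) or as dataclasses; the import paths follow the review comment and are not verified against the final layout:

```python
# Sketch of a field-drift test between the example's HfTrainingArguments and
# the plugin's TrainingArguments. Import paths follow the review comment.
import dataclasses

from examples.speculative_decoding.main import HfTrainingArguments
from modelopt.torch.speculative.plugins.hf_training_args import TrainingArguments


def _field_names(cls) -> set[str]:
    # Pydantic v2 models expose declared fields via `model_fields`;
    # dataclasses expose them via dataclasses.fields().
    if hasattr(cls, "model_fields"):
        return set(cls.model_fields)
    return {f.name for f in dataclasses.fields(cls)}


def test_speculative_training_args_stay_in_sync():
    example_fields = _field_names(HfTrainingArguments)
    plugin_fields = _field_names(TrainingArguments)
    # Fail loudly on any drift, naming the offending fields.
    assert example_fields == plugin_fields, (
        f"Field drift: only in example: {sorted(example_fields - plugin_fields)}; "
        f"only in plugin: {sorted(plugin_fields - example_fields)}"
    )
```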
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/speculative/plugins/hf_training_args.py`:
- Around line 96-130: The _fill_parallelism validator currently sets world_size
using int(os.environ.get("WORLD_SIZE", torch.cuda.device_count())), which yields
0 on CPU-only machines and leads to dp_shard_size becoming 0; change that logic
to treat a missing WORLD_SIZE and a torch.cuda.device_count() of 0 as a
single-process run by using something like device_count =
torch.cuda.device_count(); world_size = int(os.environ.get("WORLD_SIZE", max(1,
device_count))) (or world_size = int(os.environ.get("WORLD_SIZE", device_count
or 1))) so world_size is at least 1, then compute dp_shard_size and do the
existing divisibility checks and ParallelismConfig creation as before to avoid
downstream surprises and division-by-zero-like behavior when cp_size or
dp_shard_size are used.
---
Outside diff comments:
In `@examples/speculative_decoding/main.py`:
- Around line 153-186: The Medusa branch (when recipe is ModelOptMedusaRecipe)
calls mtsp.convert but never creates the data_module used later by
EagleTrainerWithAccLog, causing a NameError; update the ModelOptMedusaRecipe
branch to either construct the same data_module as done for
ModelOptEagleRecipe/ModelOptDFlashRecipe (using recipe.data and the existing
data-module factory logic) or explicitly set data_module to an appropriate value
(or None) and handle that case before calling EagleTrainerWithAccLog so
data_module is always defined; refer to the branches around
ModelOptMedusaRecipe, ModelOptEagleRecipe, ModelOptDFlashRecipe, mtsp.convert,
recipe.data, and EagleTrainerWithAccLog to place the fix.
---
Nitpick comments:
In `@examples/speculative_decoding/main.py`:
- Around line 61-76: Add an automated check to ensure HfTrainingArguments and
the plugin TrainingArguments stay in sync: write a small unit test or an
import-time assertion that imports HfTrainingArguments
(examples.speculative_decoding.main.HfTrainingArguments) and
modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments, extracts
their dataclass/field names (and optionally defaults/types), and fails if any
field names or default values differ; place the check in tests (preferred) or at
module import to catch drift early.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 838ec6d9-91a7-46a6-a197-701d6268d76c
📒 Files selected for processing (7)
- examples/speculative_decoding/main.py
- modelopt/recipe/config.py
- modelopt/recipe/loader.py
- modelopt/torch/speculative/plugins/hf_training_args.py
- modelopt_recipes/general/speculative_decoding/dflash.yaml
- modelopt_recipes/general/speculative_decoding/eagle3.yaml
- tests/unit/recipe/test_loader.py
```python
    @model_validator(mode="after")
    def _fill_parallelism(self) -> TrainingArguments:
        # Read WORLD_SIZE (set by torchrun/accelerate, multi-node aware); fall back to the
        # local GPU count for single-process runs.
        import os

        import torch

        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
        if self.dp_shard_size is None:
            self.dp_shard_size = world_size // self.cp_size

        # Build a ParallelismConfig only when actually running distributed — matches the
        # previous main.py guard and avoids requiring accelerate on single-GPU dev boxes.
        if self.cp_size > 1 or self.dp_shard_size > 1:
            parallel_size = self.dp_shard_size * self.cp_size
            if world_size % parallel_size != 0:
                raise ValueError(
                    f"world_size ({world_size}) must be divisible by "
                    f"dp_shard_size ({self.dp_shard_size}) * cp_size ({self.cp_size}) "
                    f"= {parallel_size}"
                )
            try:
                from accelerate import ParallelismConfig
            except ImportError as e:
                raise ImportError(
                    "cp_size>1 or dp_shard_size>1 requires `accelerate` for ParallelismConfig. "
                    "Install it via `pip install accelerate`."
                ) from e
            self.parallelism_config = ParallelismConfig(
                cp_size=self.cp_size,
                dp_shard_size=self.dp_shard_size,
                dp_replicate_size=world_size // parallel_size,
            )
        return self
```
Edge case: torch.cuda.device_count() returns 0 on CPU-only machines.
When running on a machine without CUDA (e.g., during local development or CI), torch.cuda.device_count() returns 0. Combined with a missing WORLD_SIZE env var, this results in world_size=0, which causes dp_shard_size = 0 // cp_size = 0. While this won't crash immediately, it may produce unexpected behavior downstream.
Consider handling this edge case explicitly:
Proposed fix: Default world_size to 1 when no GPUs detected

```diff
-        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
+        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count() or 1))
         if self.dp_shard_size is None:
             self.dp_shard_size = world_size // self.cp_size
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/speculative/plugins/hf_training_args.py` around lines 96 -
130, The _fill_parallelism validator currently sets world_size using
int(os.environ.get("WORLD_SIZE", torch.cuda.device_count())), which yields 0 on
CPU-only machines and leads to dp_shard_size becoming 0; change that logic to
treat a missing WORLD_SIZE and a torch.cuda.device_count() of 0 as a
single-process run by using something like device_count =
torch.cuda.device_count(); world_size = int(os.environ.get("WORLD_SIZE", max(1,
device_count))) (or world_size = int(os.environ.get("WORLD_SIZE", device_count
or 1))) so world_size is at least 1, then compute dp_shard_size and do the
existing divisibility checks and ParallelismConfig creation as before to avoid
downstream surprises and division-by-zero-like behavior when cp_size or
dp_shard_size are used.
As discussed in the meeting, let's strip off the config override part into a separate change.
```python
import torch

world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
```
This logic really is beyond the model_validator's scope; putting it somewhere like an "init_distributed_env" helper would be better.
What do you think?
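For reference, a minimal sketch of what such a helper could look like (the name comes from the comment above; the body mirrors the validator's logic, including the CPU-only fallback discussed earlier, and is not the merged implementation):

```python
import os

import torch


def init_distributed_env(cp_size: int, dp_shard_size: int | None) -> tuple[int, int]:
    """Resolve world_size/dp_shard_size outside the Pydantic validator (sketch)."""
    # `or 1` covers CPU-only machines where torch.cuda.device_count() == 0.
    world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count() or 1))
    if dp_shard_size is None:
        dp_shard_size = world_size // cp_size
    return world_size, dp_shard_size
```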
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Force-pushed 69f3c40 to 3bf010b
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/recipe/config.py`:
- Around line 97-114: The nested config fields (model, data, training) currently
instantiate SpecModelArgs/SpecDataArgs/SpecTrainingArgs at import time; update
ModeloptField to accept a default_factory alternative (or remove the assertion
forcing default so callers can use pydantic.Field) so you can switch those
fields to lazy defaults like default_factory=SpecTrainingArgs (or use
pydantic.Field(default_factory=...)) instead of default=SpecTrainingArgs();
specifically modify the ModeloptField implementation (the class/function named
ModeloptField) to accept and pass through default_factory to pydantic.Field (and
stop asserting a concrete default), then change the three fields using
ModeloptField to use default_factory=SpecModelArgs,
default_factory=SpecDataArgs, and default_factory=SpecTrainingArgs.
In `@modelopt/recipe/loader.py`:
- Around line 109-118: In _apply_dotlist, avoid eagerly resolving OmegaConf
interpolations by changing the OmegaConf.to_container call so it does not
resolve interpolations; specifically, update the call in function _apply_dotlist
to use resolve=False instead of resolve=True when converting the merged
OmegaConf to a plain dict so user-supplied overrides (e.g., ${...} patterns) are
preserved and not expanded before Pydantic validation.
In `@modelopt/torch/speculative/plugins/hf_training_args.py`:
- Around line 90-106: The _fill_parallelism model_validator currently divides by
cp_size and can accept non-positive cp_size/dp_shard_size from user configs;
update _fill_parallelism (on TrainingArguments) to first validate that cp_size
is an int > 0 and that if dp_shard_size is not None it is an int >= 1 (and also
verify computed world_size is >=1), and raise a clear ValueError (so Pydantic
surfaces it) if these checks fail; only after these guards compute world_size
and set self.dp_shard_size = world_size // self.cp_size when dp_shard_size is
None to avoid ZeroDivisionError and invalid runtime state.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 50f5aa9c-3c49-42a0-be51-0ab116444297
📒 Files selected for processing (7)
- examples/speculative_decoding/main.py
- modelopt/recipe/config.py
- modelopt/recipe/loader.py
- modelopt/torch/speculative/plugins/hf_training_args.py
- modelopt_recipes/general/speculative_decoding/dflash.yaml
- modelopt_recipes/general/speculative_decoding/eagle3.yaml
- tests/unit/recipe/test_loader.py
✅ Files skipped from review due to trivial changes (1)
- modelopt_recipes/general/speculative_decoding/eagle3.yaml
🚧 Files skipped from review as they are similar to previous changes (2)
- modelopt_recipes/general/speculative_decoding/dflash.yaml
- tests/unit/recipe/test_loader.py
```diff
     if isinstance(recipe, ModelOptMedusaRecipe):
         mtsp.convert(model, [("medusa", recipe.medusa.model_dump())])
     elif isinstance(recipe, ModelOptEagleRecipe):
         # Validate and rewrite eagle config fields
         eagle_cfg = EagleConfig.model_validate(
-            eagle_cfg,
-            context={"training_args": training_args, "data_args": data_args},
+            recipe.eagle.model_dump(),
+            context={"training_args": training_args, "data_args": recipe.data},
         ).model_dump()
         mtsp.convert(model, [("eagle", eagle_cfg)])

         # Load draft vocab cache if the draft model uses a compressed vocabulary
         if model.eagle_config.draft_vocab_size < model.eagle_config.vocab_size:
-            if not os.path.isfile(data_args.draft_vocab_cache):
-                raise FileNotFoundError(
-                    f"Draft vocab cache provided but not found: {data_args.draft_vocab_cache}"
-                )
-            model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True)
-            print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.")
-    elif training_args.mode == "dflash":
+            d2t = recipe.data.draft_vocab_cache
+            if d2t is None or not os.path.isfile(d2t):
+                raise FileNotFoundError(f"Draft vocab cache provided but not found: {d2t}")
+            model.eagle_module.d2t = torch.load(d2t, weights_only=True)
+            print_rank_0(f"Loaded draft vocab cache from {d2t}.")
+    elif isinstance(recipe, ModelOptDFlashRecipe):
         dflash_cfg = DFlashConfig.model_validate(
-            dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args}
+            recipe.dflash.model_dump(), context={"tokenizer": tokenizer, "data_args": recipe.data}
         ).model_dump()
         mtsp.convert(model, [("dflash", dflash_cfg)])
     else:
-        raise Exception(f"{training_args.mode} is not supported!")
+        raise ValueError(f"Unsupported speculative recipe type: {type(recipe).__name__}")
```
ModelOptMedusaRecipe reaches trainer setup with no data_module.
This function now accepts Medusa recipes and converts the model for them, but data_module is only created for Eagle/DFlash. A Medusa run will therefore hit an UnboundLocalError at Line 216 when **data_module is expanded.
Either wire Medusa into dataset construction here as well, or reject Medusa earlier with an explicit error until the training path is implemented.
Also applies to: 193-216
```python
    model: SpecModelArgs = ModeloptField(
        default=SpecModelArgs(),
        title="HF model args",
        description="ModelArguments for the base HF model to train a draft head against.",
        validate_default=True,
    )
    data: SpecDataArgs = ModeloptField(
        default=SpecDataArgs(),
        title="HF data args",
        description="DataArguments for the training/offline dataset.",
        validate_default=True,
    )
    training: SpecTrainingArgs = ModeloptField(
        default=SpecTrainingArgs(),
        title="HF training args",
        description="Speculative-decoding extensions; HF trainer fields flow through as extras.",
        validate_default=True,
    )
```
🧩 Analysis chain
🌐 Web query:
In Python and Pydantic v2, are expressions like `default=SpecTrainingArgs()` evaluated at class definition/import time, and is `default_factory=SpecTrainingArgs` the recommended way to avoid eager instantiation and validator side effects for nested model defaults?
💡 Result:
Yes: if you write `default=SpecTrainingArgs()` (i.e., you pass an already-constructed instance as the field default), that instance is created eagerly at Python class definition/import time. In contrast, `default_factory=SpecTrainingArgs` (i.e., a callable) defers instantiation until Pydantic needs a default for a model instance. In Pydantic v2, the docs explicitly describe `default_factory` as a callable that is "called to generate a default value", whereas `default` is a fixed default value. So using `default_factory=SpecTrainingArgs` is the recommended way to avoid eager instantiation for nested model defaults.

About "validator side effects": a `default_factory` runs only when Pydantic is producing the default for the field. Pydantic issues have also shown that the timing/context of `default_factory` invocation matters (e.g., confusing errors when validation of earlier fields fails while a data-taking default factory is being called). While those issues concern specific dependent-default scenarios, the general guidance stands: avoid constructing nested models in the `default=` argument when you want Pydantic to control when instantiation happens.

Practical pattern (recommended):

- Prefer: `sub: SubModel = Field(default_factory=SubModel)`
- Avoid: `sub: SubModel = Field(default=SubModel())`

Example: if `SpecTrainingArgs` is a `BaseModel` (or otherwise Pydantic-validated type), use `default_factory=SpecTrainingArgs` so each outer model instance gets its own default created at instantiation time, not at import/class-definition time.
Citations:
- 1: https://pydantic.dev/docs/validation/latest/concepts/fields/
- 2: https://docs.pydantic.dev/2.6/concepts/fields/
- 3: Default values are persisting in v2 pydantic/pydantic#5672
- 4: default_factory is called with missing validated_data if validation of prior field fails pydantic/pydantic#11358
🏁 Script executed:
# First, find and examine the SpecTrainingArgs implementation
find . -type f -name "*.py" | xargs grep -l "class SpecTrainingArgs" | head -5

Repository: NVIDIA/Model-Optimizer
Length of output: 48
🏁 Script executed:
# Read the config.py file around lines 97-114 and 122-153 to understand the full context
sed -n '90,160p' modelopt/recipe/config.py | cat -n

Repository: NVIDIA/Model-Optimizer
Length of output: 2815
🏁 Script executed:
# Check if SpecTrainingArgs has __init__ or validators with side effects
grep -A 30 "class SpecTrainingArgs" $(find . -type f -name "*.py" -path "*/speculative/*")

Repository: NVIDIA/Model-Optimizer
Length of output: 48
🏁 Script executed:
# Search for SpecTrainingArgs definition more broadly
rg "class SpecTrainingArgs" --type pyRepository: NVIDIA/Model-Optimizer
Length of output: 48
🏁 Script executed:
# Also check the import statement at the top of config.py
head -50 modelopt/recipe/config.py

Repository: NVIDIA/Model-Optimizer
Length of output: 1977
🏁 Script executed:
# Find the TrainingArguments class in hf_training_args
find . -path "*speculative/plugins/hf_training_args*" -type f

Repository: NVIDIA/Model-Optimizer
Length of output: 123
🏁 Script executed:
# Once we find it, examine the TrainingArguments.__init__ or field validators
rg "_fill_parallelism" --type py -B 5 -A 10Repository: NVIDIA/Model-Optimizer
Length of output: 1622
🏁 Script executed:
# Get the full _fill_parallelism method and surrounding context
sed -n '120,180p' modelopt/torch/speculative/plugins/hf_training_args.py | cat -n

Repository: NVIDIA/Model-Optimizer
Length of output: 642
🏁 Script executed:
# Also check if there are any imports of accelerate in this file
grep -n "accelerate\|ParallelismConfig" modelopt/torch/speculative/plugins/hf_training_args.pyRepository: NVIDIA/Model-Optimizer
Length of output: 1012
🏁 Script executed:
# Verify the full validator implementation
sed -n '95,130p' modelopt/torch/speculative/plugins/hf_training_args.py | cat -n

Repository: NVIDIA/Model-Optimizer
Length of output: 1942
🏁 Script executed:
# Check if ModeloptField supports default_factory parameter
rg "def ModeloptField|class ModeloptField" --type py -A 20Repository: NVIDIA/Model-Optimizer
Length of output: 1492
🏁 Script executed:
# Check the full ModeloptField implementation
cat -n modelopt/torch/opt/config.py | sed -n '1,100p'

Repository: NVIDIA/Model-Optimizer
Length of output: 4970
🏁 Script executed:
# Verify if default_factory can be passed through **kwargs to Field
rg "default_factory" modelopt/torch/opt/config.pyRepository: NVIDIA/Model-Optimizer
Length of output: 48
🏁 Script executed:
# Check if there are any examples of default_factory usage in the codebase
rg "default_factory" --type py | head -20Repository: NVIDIA/Model-Optimizer
Length of output: 3012
🏁 Script executed:
# Check if ModeloptField is used anywhere with default_factory attempts
rg "ModeloptField.*default_factory" --type pyRepository: NVIDIA/Model-Optimizer
Length of output: 48
🏁 Script executed:
# Also verify the assertion in ModeloptField more carefully
sed -n '51,55p' modelopt/torch/opt/config.py

Repository: NVIDIA/Model-Optimizer
Length of output: 348
🏁 Script executed:
# Check if using default_factory with ModeloptField would work given the assertion
python3 << 'EOF'
# Simulate what would happen if we try to use default_factory with ModeloptField
from pydantic_core import PydanticUndefined
# This is what ModeloptField does
def ModeloptField(default=PydanticUndefined, **kwargs):
assert default is not PydanticUndefined, "A default value must be set for ModeloptField."
return {"default": default, **kwargs}
# Try to pass default_factory
try:
result = ModeloptField(default_factory=lambda: {})
print("Success:", result)
except AssertionError as e:
print("AssertionError:", e)
except TypeError as e:
print("TypeError:", e)
EOF

Repository: NVIDIA/Model-Optimizer
Length of output: 129
🏁 Script executed:
# Check if there's a way to modify ModeloptField to support default_factory
grep -B 5 -A 15 "def ModeloptField" modelopt/torch/opt/config.py

Repository: NVIDIA/Model-Optimizer
Length of output: 928
🏁 Script executed:
# Verify if using direct Field instead of ModeloptField would work in the recipe file
rg "from pydantic import.*Field" modelopt/recipe/config.pyRepository: NVIDIA/Model-Optimizer
Length of output: 48
🏁 Script executed:
# Check what imports are available at the top of config.py
head -30 modelopt/recipe/config.py | grep -E "^from|^import"

Repository: NVIDIA/Model-Optimizer
Length of output: 636
🏁 Script executed:
# Check if we need to modify ModeloptField to support default_factory
rg "ModeloptField" modelopt/torch/opt/config.py -B 3 -A 10Repository: NVIDIA/Model-Optimizer
Length of output: 1159
Use lazy defaults for nested recipe sections to avoid import-time validator side effects.
SpecTrainingArgs(), SpecModelArgs(), and related config objects are instantiated when the module is imported, not when a recipe instance is created. Because SpecTrainingArgs has an @model_validator(mode="after") decorator on _fill_parallelism(), that validator runs immediately at import time, reading WORLD_SIZE and potentially importing accelerate on multi-GPU hosts. This makes importing modelopt.recipe.config trigger host inspection and can expose host-specific derived defaults.
The suggested fix requires updating ModeloptField (in modelopt/torch/opt/config.py) to support default_factory in addition to default. Currently, ModeloptField asserts that a default value must be provided, which prevents using default_factory=.... Either:
- Modify `ModeloptField` to accept `default_factory` as an alternative (passing it through to `pydantic.Field`), or
- Use `pydantic.Field` directly instead of `ModeloptField` for these nested config fields and adjust the `ModeloptField` assertion accordingly.
Once that foundation is in place, the diff changes shown (switching to default_factory=SpecTrainingArgs, etc.) can be applied.
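A minimal sketch of the `ModeloptField` extension, assuming the current implementation is a thin wrapper that asserts on `default` and forwards everything to `pydantic.Field` (as the verification scripts above suggest):

```python
# Sketch only: extends ModeloptField to accept default_factory as an
# alternative to default, forwarding both to pydantic.Field.
from pydantic import Field
from pydantic_core import PydanticUndefined


def ModeloptField(default=PydanticUndefined, default_factory=None, **kwargs):
    """Pydantic Field wrapper requiring either a default or a default_factory."""
    if default_factory is not None:
        assert default is PydanticUndefined, "Provide default or default_factory, not both."
        return Field(default_factory=default_factory, **kwargs)
    assert default is not PydanticUndefined, "A default value must be set for ModeloptField."
    return Field(default=default, **kwargs)
```

With that in place, the nested fields could read, e.g., `training: SpecTrainingArgs = ModeloptField(default_factory=SpecTrainingArgs, ...)`, deferring construction (and the `_fill_parallelism` validator) until a recipe instance is actually created.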
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/recipe/config.py` around lines 97 - 114, The nested config fields
(model, data, training) currently instantiate
SpecModelArgs/SpecDataArgs/SpecTrainingArgs at import time; update ModeloptField
to accept a default_factory alternative (or remove the assertion forcing default
so callers can use pydantic.Field) so you can switch those fields to lazy
defaults like default_factory=SpecTrainingArgs (or use
pydantic.Field(default_factory=...)) instead of default=SpecTrainingArgs();
specifically modify the ModeloptField implementation (the class/function named
ModeloptField) to accept and pass through default_factory to pydantic.Field (and
stop asserting a concrete default), then change the three fields using
ModeloptField to use default_factory=SpecModelArgs,
default_factory=SpecDataArgs, and default_factory=SpecTrainingArgs.
```python
def _apply_dotlist(data: dict, overrides: list[str]) -> dict:
    """Merge ``a.b.c=value`` command line overrides on top of ``data`` via OmegaConf."""
    for entry in overrides:
        if "=" not in entry:
            raise ValueError(f"Invalid override (missing '='): {entry!r}")
    merged = OmegaConf.merge(
        OmegaConf.create(data),
        OmegaConf.from_dotlist(list(overrides)),
    )
    return OmegaConf.to_container(merged, resolve=True)
```
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== OmegaConf resolver registrations =="
rg -n -C2 'OmegaConf\.(register_new_resolver|register_resolver)' --type=py
echo
echo "== Eager resolution sites =="
rg -n -C2 'to_container\(.*resolve=True' --type=py
echo
echo "== Full recipe printing/logging sites =="
rg -n -C2 'pprint\(recipe\)|print_rank_0\(.*recipe|print\(.*recipe' --type=py

Repository: NVIDIA/Model-Optimizer
Length of output: 6579
🏁 Script executed:
# Check the load_recipe() function to understand the full context
cat -n modelopt/recipe/loader.py | head -130

Repository: NVIDIA/Model-Optimizer
Length of output: 5940
🏁 Script executed:
# Check examples/speculative_decoding/main.py to see the recipe handling
rg -B10 -A5 'pprint\(recipe\)' examples/speculative_decoding/main.py

Repository: NVIDIA/Model-Optimizer
Length of output: 686
🏁 Script executed:
# Check if register_hydra_resolvers is called during recipe loading
rg -n 'register_hydra_resolvers' --type=py

Repository: NVIDIA/Model-Optimizer
Length of output: 633
🏁 Script executed:
# Check how the result of _apply_dotlist is used
rg -A10 '_apply_dotlist' modelopt/recipe/loader.py | head -30

Repository: NVIDIA/Model-Optimizer
Length of output: 984
🏁 Script executed:
# Check the usage context to see if resolve=True is actually necessary
rg -B5 -A15 '_load_recipe_from_file' modelopt/recipe/loader.py | grep -A20 'def _load_recipe_from_file'

Repository: NVIDIA/Model-Optimizer
Length of output: 673
Don't eagerly resolve OmegaConf interpolations on untrusted recipe overrides.
OmegaConf.to_container(..., resolve=True) at line 118 evaluates ${...} expressions before Pydantic validation. User-supplied YAML and CLI overrides can contain ${oc.env:...} patterns that resolve to environment variables. Since examples/speculative_decoding/main.py calls pprint(recipe) on the loaded config, resolved secrets will leak into logs.
Change resolve=True to resolve=False. Pydantic handles unresolved OmegaConf dicts fine, and type conversion is already performed by YAML parsing.
Suggested fix

```diff
 def _apply_dotlist(data: dict, overrides: list[str]) -> dict:
 @@
-    return OmegaConf.to_container(merged, resolve=True)
+    return OmegaConf.to_container(merged, resolve=False)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/recipe/loader.py` around lines 109 - 118, In _apply_dotlist, avoid
eagerly resolving OmegaConf interpolations by changing the
OmegaConf.to_container call so it does not resolve interpolations; specifically,
update the call in function _apply_dotlist to use resolve=False instead of
resolve=True when converting the merged OmegaConf to a plain dict so
user-supplied overrides (e.g., ${...} patterns) are preserved and not expanded
before Pydantic validation.
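To make the difference concrete, here is a small standalone OmegaConf illustration of the behavior the comment describes (not ModelOpt code; the env var name is made up):

```python
import os

from omegaconf import OmegaConf

os.environ["MY_SECRET"] = "hunter2"  # hypothetical secret in the environment
cfg = OmegaConf.from_dotlist(["training.token=${oc.env:MY_SECRET}"])

# resolve=True expands the interpolation into the plain dict (and into any log
# that later prints the config)...
print(OmegaConf.to_container(cfg, resolve=True))
# -> {'training': {'token': 'hunter2'}}

# ...while resolve=False keeps the raw interpolation string.
print(OmegaConf.to_container(cfg, resolve=False))
# -> {'training': {'token': '${oc.env:MY_SECRET}'}}
```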
```python
    cp_size: int = 1
    dp_shard_size: int | None = None
    # Derived at validation time from cp_size/dp_shard_size/WORLD_SIZE; typed as Any so this
    # module doesn't need to import accelerate.ParallelismConfig just to annotate the field.
    parallelism_config: Any = None

    @model_validator(mode="after")
    def _fill_parallelism(self) -> TrainingArguments:
        # Read WORLD_SIZE (set by torchrun/accelerate, multi-node aware); fall back to the
        # local GPU count for single-process runs.
        import os

        import torch

        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
        if self.dp_shard_size is None:
            self.dp_shard_size = world_size // self.cp_size
```
Validate cp_size/dp_shard_size before deriving parallelism.
Right now training.cp_size=0 crashes with a ZeroDivisionError at Line 106, and training.dp_shard_size=0 or negative values can slip through as invalid runtime state instead of a clean Pydantic error. Since these fields are user-configurable via YAML/dotlist, they should be rejected up front.
Suggested fix

```diff
 class TrainingArguments(BaseModel):
 @@
     cp_size: int = 1
     dp_shard_size: int | None = None
 @@
+    @field_validator("cp_size")
+    @classmethod
+    def _check_cp_size(cls, v: int) -> int:
+        if v < 1:
+            raise ValueError("cp_size must be >= 1")
+        return v
+
+    @field_validator("dp_shard_size")
+    @classmethod
+    def _check_dp_shard_size(cls, v: int | None) -> int | None:
+        if v is not None and v < 1:
+            raise ValueError("dp_shard_size must be >= 1 when provided")
+        return v
+
     @model_validator(mode="after")
     def _fill_parallelism(self) -> TrainingArguments:
```

As per coding guidelines: "Contributors must treat user-provided artifacts/configs as untrusted and avoid unsafe parsing/deserialization."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/speculative/plugins/hf_training_args.py` around lines 90 -
106, The _fill_parallelism model_validator currently divides by cp_size and can
accept non-positive cp_size/dp_shard_size from user configs; update
_fill_parallelism (on TrainingArguments) to first validate that cp_size is an
int > 0 and that if dp_shard_size is not None it is an int >= 1 (and also verify
computed world_size is >=1), and raise a clear ValueError (so Pydantic surfaces
it) if these checks fail; only after these guards compute world_size and set
self.dp_shard_size = world_size // self.cp_size when dp_shard_size is None to
avoid ZeroDivisionError and invalid runtime state.
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Actionable comments posted: 1
♻️ Duplicate comments (1)
modelopt/recipe/config.py (1)
98-115: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Use `default_factory` for nested config defaults to avoid import-time instantiation.

Line 99, Line 105, Line 111 (and similarly Line 124, Line 153, Line 171) instantiate nested config objects at import time via `default=...()`. This is the same issue already raised earlier and is still unresolved. Prefer lazy construction with `default_factory` (which may require extending `ModeloptField` to pass through `default_factory`).

Suggested direction

```diff
-from pydantic import field_validator, model_validator
+from pydantic import Field, field_validator, model_validator

-    model: SpecModelArgs = ModeloptField(default=SpecModelArgs(), ...)
+    model: SpecModelArgs = Field(default_factory=SpecModelArgs, ...)
-    data: SpecDataArgs = ModeloptField(default=SpecDataArgs(), ...)
+    data: SpecDataArgs = Field(default_factory=SpecDataArgs, ...)
-    training: SpecTrainingArgs = ModeloptField(default=SpecTrainingArgs(), ...)
+    training: SpecTrainingArgs = Field(default_factory=SpecTrainingArgs, ...)
```

If `ModeloptField` must be retained, update `ModeloptField` to accept `default_factory` as an alternative to `default`.

```bash
#!/bin/bash
# Verify eager constructor defaults in recipe config and whether ModeloptField supports default_factory.
python - <<'PY'
import ast, pathlib
p = pathlib.Path("modelopt/recipe/config.py")
tree = ast.parse(p.read_text())
for n in ast.walk(tree):
    if isinstance(n, ast.Call) and isinstance(n.func, ast.Name) and n.func.id == "ModeloptField":
        for kw in n.keywords:
            if kw.arg == "default" and isinstance(kw.value, ast.Call):
                fn = kw.value.func
                name = getattr(fn, "id", getattr(fn, "attr", type(fn).__name__))
                print(f"Line {n.lineno}: eager default via {name}()")
PY
rg -n -C2 "def ModeloptField|default_factory|PydanticUndefined|assert .*default" modelopt/torch/opt/config.py
```
ModeloptFieldmust be retained, updateModeloptFieldto acceptdefault_factoryas an alternative todefault.#!/bin/bash # Verify eager constructor defaults in recipe config and whether ModeloptField supports default_factory. python - <<'PY' import ast, pathlib p = pathlib.Path("modelopt/recipe/config.py") tree = ast.parse(p.read_text()) for n in ast.walk(tree): if isinstance(n, ast.Call) and isinstance(n.func, ast.Name) and n.func.id == "ModeloptField": for kw in n.keywords: if kw.arg == "default" and isinstance(kw.value, ast.Call): fn = kw.value.func name = getattr(fn, "id", getattr(fn, "attr", type(fn).__name__)) print(f"Line {n.lineno}: eager default via {name}()") PY rg -n -C2 "def ModeloptField|default_factory|PydanticUndefined|assert .*default" modelopt/torch/opt/config.pyAlso applies to: 123-124, 152-153, 170-171
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/recipe/config.py` around lines 98 - 115, The ModeloptField usages for the nested config attributes model, data, and training (and other spots noted) currently instantiate objects at import time via default=SpecModelArgs(), default=SpecDataArgs(), default=SpecTrainingArgs(), which causes eager construction; change these to use lazy construction by supporting and passing default_factory=SpecModelArgs, default_factory=SpecDataArgs, default_factory=SpecTrainingArgs instead of calling the constructors, and if ModeloptField does not yet accept default_factory update the ModeloptField implementation to accept a default_factory kwarg, store/forward it to the underlying field construction logic (mirroring pydantic/dataclasses semantics), and ensure validate_default behavior still works with the factory.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/recipe/config.py`:
- Around line 27-32: The module currently performs hard imports of
speculative/HF classes DFlashConfig, EagleConfig, MedusaConfig and
SpecDataArgs/SpecModelArgs/SpecTrainingArgs at import time; change this to
lazy/plugin loading by moving those imports behind typing.TYPE_CHECKING for type
annotations and using the project's import_plugin() (or an equivalent lazy
import helper) where the classes are actually needed at runtime (e.g., in
factory functions or recipe registration). Keep type-only references via from
typing import TYPE_CHECKING and string annotations or if TYPE_CHECKING: import
the speculative classes, and replace direct module-level imports with calls to
import_plugin("modelopt.torch.speculative...") right before using
DFlashConfig/EagleConfig/MedusaConfig or the Spec* argument classes.
---
Duplicate comments:
In `@modelopt/recipe/config.py`:
- Around line 98-115: The ModeloptField usages for the nested config attributes
model, data, and training (and other spots noted) currently instantiate objects
at import time via default=SpecModelArgs(), default=SpecDataArgs(),
default=SpecTrainingArgs(), which causes eager construction; change these to use
lazy construction by supporting and passing default_factory=SpecModelArgs,
default_factory=SpecDataArgs, default_factory=SpecTrainingArgs instead of
calling the constructors, and if ModeloptField does not yet accept
default_factory update the ModeloptField implementation to accept a
default_factory kwarg, store/forward it to the underlying field construction
logic (mirroring pydantic/dataclasses semantics), and ensure validate_default
behavior still works with the factory.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 143fdb4d-cd06-457c-b8b0-cbace1bf21cc
📒 Files selected for processing (1)
modelopt/recipe/config.py
```python
from modelopt.torch.speculative.config import DFlashConfig, EagleConfig, MedusaConfig
from modelopt.torch.speculative.plugins.hf_training_args import DataArguments as SpecDataArgs
from modelopt.torch.speculative.plugins.hf_training_args import ModelArguments as SpecModelArgs
from modelopt.torch.speculative.plugins.hf_training_args import (
    TrainingArguments as SpecTrainingArgs,
)
```
Lazy-load speculative/HF config imports instead of importing them at module import time.
Line 27-Line 32 hard-import speculative/HF modules in a core recipe module. This makes optional extras effectively mandatory for import paths that don’t use speculative recipes (e.g., PTQ-only usage). Please gate these with lazy/plugin loading and keep type-only references behind deferred annotations/TYPE_CHECKING.
As per coding guidelines: "Avoid hard imports of optional dependencies at module level; features should be gated by install extras ([onnx], [hf], [all]) and loaded lazily via import_plugin()".
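A minimal sketch of the suggested gating (deferred imports are shown with plain function-local imports; the project's `import_plugin()` helper mentioned in the guideline would slot in the same place):

```python
# Sketch only: keeps modelopt.recipe.config importable without the
# speculative/HF extras by deferring the heavy imports.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Type-only imports; never executed at runtime.
    from modelopt.torch.speculative.config import DFlashConfig, EagleConfig, MedusaConfig


def _speculative_configs():
    # Imported only when a speculative recipe is actually used.
    from modelopt.torch.speculative.config import (
        DFlashConfig,
        EagleConfig,
        MedusaConfig,
    )

    return DFlashConfig, EagleConfig, MedusaConfig
```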
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/recipe/config.py` around lines 27 - 32, The module currently
performs hard imports of speculative/HF classes DFlashConfig, EagleConfig,
MedusaConfig and SpecDataArgs/SpecModelArgs/SpecTrainingArgs at import time;
change this to lazy/plugin loading by moving those imports behind
typing.TYPE_CHECKING for type annotations and using the project's
import_plugin() (or an equivalent lazy import helper) where the classes are
actually needed at runtime (e.g., in factory functions or recipe registration).
Keep type-only references via from typing import TYPE_CHECKING and string
annotations or if TYPE_CHECKING: import the speculative classes, and replace
direct module-level imports with calls to
import_plugin("modelopt.torch.speculative...") right before using
DFlashConfig/EagleConfig/MedusaConfig or the Spec* argument classes.
Actionable comments posted: 3
♻️ Duplicate comments (1)
examples/speculative_decoding/main.py (1)
186-209: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

`data_module` is undefined for Medusa recipes and will crash trainer construction.

At Line 186, `data_module` is created only for Eagle/DFlash. For `ModelOptMedusaRecipe`, Line 208 expands `**data_module` before assignment (`UnboundLocalError`).

Suggested fix

```diff
-    if isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe)):
+    if isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe, ModelOptMedusaRecipe)):
         data_module = make_speculative_data_module(
             tokenizer,
             recipe.data,
             train_len=training_args.training_seq_len,
             answer_only_loss=training_args.answer_only_loss,
             shift_labels=not is_dflash,
         )
+    else:
+        raise ValueError(f"Unsupported speculative recipe type for dataset setup: {type(recipe).__name__}")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/main.py` around lines 186 - 209, The variable data_module is only set inside the isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe)) branch but later always expanded into EagleTrainerWithAccLog via **data_module, causing an UnboundLocalError for ModelOptMedusaRecipe; fix by ensuring data_module is always defined before trainer construction—either initialize data_module = {} prior to the conditional or add an else branch that creates/returns the appropriate Medusa data module (e.g., a make_medusa_data_module or equivalent) so that **data_module is safe for all recipe types (references: recipe, ModelOptEagleRecipe, ModelOptDFlashRecipe, ModelOptMedusaRecipe, make_speculative_data_module, EagleTrainerWithAccLog, **data_module).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/speculative_decoding/main.py`:
- Around line 196-201: The code references eagle_cfg when deciding to append
LoRAWarmupCallback but eagle_cfg may be undefined in the checkpoint-resume path;
initialize or load eagle_cfg before that conditional so the .get calls are safe.
Concretely, ensure eagle_cfg is defined (e.g., set eagle_cfg = {} or populate it
from the checkpoint metadata in the resume branch) prior to the callback gating
logic that checks isinstance(recipe, ModelOptEagleRecipe) and
eagle_cfg.get(...), so
LoRAWarmupCallback(eagle_cfg["eagle_base_lora_warmup_steps"]) only executes when
eagle_cfg exists and contains the key.
- Around line 109-113: The code mutates training_args.parallelism_config when
training_args.cp_size > 1 but HfTrainingArguments lacks that field; either add a
parallelism_config field and an initializer/validator to HfTrainingArguments
(similar to
modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments) so it is
always present, or guard the mutation by checking hasattr(training_args,
"parallelism_config") (or training_args.parallelism_config is not None) before
assigning sp_backend = None; update references to training_args.cp_size and
training_args.dp_shard_size accordingly so the behavior is consistent.
In `@modelopt/torch/speculative/plugins/hf_eagle.py`:
- Around line 257-264: load_draft_vocab_cache currently assigns torch.load
directly to model.eagle_module.d2t which can cause device/dtype/shape
mismatches; instead, load the saved tensor using map_location to the
model.eagle_module.d2t device, validate that the loaded tensor's dtype and numel
match model.eagle_module.d2t (and that model.eagle_config.draft_vocab_size <
model.eagle_config.vocab_size condition holds), then copy the data into the
existing buffer in-place (e.g., via copy_()) so the registered buffer keeps its
device and shape; update load_draft_vocab_cache and references to
model.eagle_module.d2t accordingly.
---
Duplicate comments:
In `@examples/speculative_decoding/main.py`:
- Around line 186-209: The variable data_module is only set inside the
isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe)) branch but later
always expanded into EagleTrainerWithAccLog via **data_module, causing an
UnboundLocalError for ModelOptMedusaRecipe; fix by ensuring data_module is
always defined before trainer construction—either initialize data_module = {}
prior to the conditional or add an else branch that creates/returns the
appropriate Medusa data module (e.g., a make_medusa_data_module or equivalent)
so that **data_module is safe for all recipe types (references: recipe,
ModelOptEagleRecipe, ModelOptDFlashRecipe, ModelOptMedusaRecipe,
make_speculative_data_module, EagleTrainerWithAccLog, **data_module).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 16de0a2a-dbe7-453e-a197-0049f864b524
📒 Files selected for processing (2)
- examples/speculative_decoding/main.py
- modelopt/torch/speculative/plugins/hf_eagle.py
```python
    if training_args.cp_size > 1:
        patch_ring_attention_for_ttt()
        # Specific patch to accelerate 1.12.0. Removable after move to 1.13.0
        training_args.parallelism_config.sp_backend = None
    print_rank_0(
        f"arguments: {model_args}, {training_args}, {medusa_args}, "
        f"eagle_cfg={eagle_cfg}, dflash_cfg={dflash_cfg}"
    )
    if is_master():
```
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify current code path has no local parallelism_config construction.
rg -n "parallelism_config|cp_size|dp_shard_size" examples/speculative_decoding/main.pyRepository: NVIDIA/Model-Optimizer
Length of output: 231
🏁 Script executed:
cat -n examples/speculative_decoding/main.py | head -150 | tail -90

Repository: NVIDIA/Model-Optimizer
Length of output: 4210
🏁 Script executed:
rg -n "parallelism_config" --type py -A 2 -B 2Repository: NVIDIA/Model-Optimizer
Length of output: 2228
🏁 Script executed:
cat -n modelopt/torch/speculative/plugins/hf_training_args.py | sed -n '90,150p'

Repository: NVIDIA/Model-Optimizer
Length of output: 2256
🏁 Script executed:
cat -n examples/speculative_decoding/main.py | sed -n '61,110p'

Repository: NVIDIA/Model-Optimizer
Length of output: 2309
Add parallelism_config field to HfTrainingArguments or guard access before mutation.
At line 112, the code accesses training_args.parallelism_config.sp_backend = None, but parallelism_config is not defined in the HfTrainingArguments dataclass. The dataclass only has cp_size and dp_shard_size fields. When cp_size > 1, this will raise an AttributeError at runtime.
The docstring (lines 66–67) states the field set "MUST stay in sync with" modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments, which does have a parallelism_config field and a validator to initialize it. Either add the field and validator to the dataclass, or guard the attribute access with a check.
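A minimal sketch of the guard option (the `getattr` check is the illustrative part; field and function names are taken from the snippet above):

```python
if training_args.cp_size > 1:
    patch_ring_attention_for_ttt()
    # Only touch parallelism_config when the args object actually carries one;
    # the example dataclass may not define the field.
    if getattr(training_args, "parallelism_config", None) is not None:
        # Specific patch to accelerate 1.12.0; removable after moving to 1.13.0.
        training_args.parallelism_config.sp_backend = None
```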
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/speculative_decoding/main.py` around lines 109 - 113, The code
mutates training_args.parallelism_config when training_args.cp_size > 1 but
HfTrainingArguments lacks that field; either add a parallelism_config field and
an initializer/validator to HfTrainingArguments (similar to
modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments) so it is
always present, or guard the mutation by checking hasattr(training_args,
"parallelism_config") (or training_args.parallelism_config is not None) before
assigning sp_backend = None; update references to training_args.cp_size and
training_args.dp_shard_size accordingly so the behavior is consistent.
```python
    def load_draft_vocab_cache(model, d2t_path: str) -> None:
        """Load the draft vocab cache from the given path."""
        if d2t_path is None or model.eagle_config.draft_vocab_size >= model.eagle_config.vocab_size:
            return
        if not os.path.isfile(d2t_path):
            raise FileNotFoundError(f"Draft vocab cache provided but not found: {d2t_path}")
        model.eagle_module.d2t = torch.load(d2t_path, weights_only=True)
        print_rank_0(f"Loaded draft vocab cache from {d2t_path}.")
```
🧩 Analysis chain
🏁 Script executed:
find . -name "hf_eagle.py" -type f

Repository: NVIDIA/Model-Optimizer
Length of output: 115
🏁 Script executed:
cat -n modelopt/torch/speculative/plugins/hf_eagle.py | sed -n '250,270p'

Repository: NVIDIA/Model-Optimizer
Length of output: 1054
🏁 Script executed:
# Search for where d2t is registered or initialized
rg "d2t" modelopt/torch/speculative/plugins/hf_eagle.py -A 2 -B 2Repository: NVIDIA/Model-Optimizer
Length of output: 1231
🏁 Script executed:
# Search for how d2t is used in logits remapping or elsewhere
rg "d2t" modelopt/torch/speculative/ -A 1 -B 1Repository: NVIDIA/Model-Optimizer
Length of output: 4108
Load the draft vocab cache with proper device mapping and validation.
At line 263, assigning torch.load(...) directly to model.eagle_module.d2t risks device/dtype/shape mismatches that fail later during logits remapping. The loaded tensor defaults to CPU but the buffer is pre-registered on the model's device. Load with map_location, validate dtype and numel match expectations, and use copy_() to update the buffer in-place.
Suggested fix
```diff
 @staticmethod
 def load_draft_vocab_cache(model, d2t_path: str) -> None:
     """Load the draft vocab cache from the given path."""
     if d2t_path is None or model.eagle_config.draft_vocab_size >= model.eagle_config.vocab_size:
         return
     if not os.path.isfile(d2t_path):
         raise FileNotFoundError(f"Draft vocab cache provided but not found: {d2t_path}")
-    model.eagle_module.d2t = torch.load(d2t_path, weights_only=True)
+    target = model.eagle_module.d2t
+    loaded = torch.load(d2t_path, map_location=target.device, weights_only=True)
+    if loaded.dtype != target.dtype or loaded.numel() != target.numel():
+        raise ValueError(
+            f"Invalid draft vocab cache at {d2t_path}: "
+            f"expected dtype={target.dtype}, numel={target.numel()}, "
+            f"got dtype={loaded.dtype}, numel={loaded.numel()}"
+        )
+    target.copy_(loaded.view_as(target))
     print_rank_0(f"Loaded draft vocab cache from {d2t_path}.")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/speculative/plugins/hf_eagle.py` around lines 257 - 264,
load_draft_vocab_cache currently assigns torch.load directly to
model.eagle_module.d2t which can cause device/dtype/shape mismatches; instead,
load the saved tensor using map_location to the model.eagle_module.d2t device,
validate that the loaded tensor's dtype and numel match model.eagle_module.d2t
(and that model.eagle_config.draft_vocab_size < model.eagle_config.vocab_size
condition holds), then copy the data into the existing buffer in-place (e.g.,
via copy_()) so the registered buffer keeps its device and shape; update
load_draft_vocab_cache and references to model.eagle_module.d2t accordingly.
What does this PR do?
Type of change: new feature
Port the speculative-decoding example to ModelOpt's recipe/config subsystem:
`model`/`data`/`training`/`<algo>` now load from a single YAML with Pydantic validation and OmegaConf dotlist overrides. Adds built-in `eagle3`/`dflash` recipes, drops the redundant `training.mode` field (inferred from recipe class), and shrinks `main.py` by ~145 lines (−208 / +63).

JIRA: https://jirasw.nvidia.com/browse/OMNIML-3859
Usage
```bash
python main.py --config general/speculative_decoding/eagle3 \
    model.model_name_or_path=meta-llama/Llama-3.2-1B \
    data.data_path=train.jsonl \
    training.output_dir=ckpts/test
```

Testing
- `pytest tests/unit/recipe/test_loader.py`: new coverage for Eagle / DFlash YAML loading, dotlist overrides, and field-level validation.
- `eagle3` and `dflash` recipes end-to-end.

Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- `main.py` CLI switched to `--config <recipe>` (+ dotlist overrides); the old argparse flags are removed.
- CONTRIBUTING.md: N/A, no new deps (`pydantic`, `omegaconf` already in core).
- `tests/unit/recipe/test_loader.py`.

Additional Information
Follow-up to the `modelopt.recipe` subsystem introduced for PTQ; this PR extends the same declarative-YAML pattern to speculative decoding (Eagle3 / DFlash / Medusa).

Summary by CodeRabbit
- New Features
- Bug Fixes
- Tests
- Documentation