
[Refactor] speculative decoding: use mto config subsystem#1328

Open
h-guo18 wants to merge 5 commits into main from haoguo/spec-mto-config

Conversation


@h-guo18 h-guo18 commented Apr 23, 2026

What does this PR do?

Type of change: new feature

Port the speculative-decoding example to ModelOpt's recipe/config subsystem: model / data / training / <algo> now load from a single YAML with Pydantic validation and OmegaConf dotlist overrides. Adds built-in eagle3 / dflash recipes, drops the redundant training.mode field (inferred from recipe class), and shrinks main.py by ~145 lines (−208 / +63).

JIRA: https://jirasw.nvidia.com/browse/OMNIML-3859

Usage

python main.py --config general/speculative_decoding/eagle3 \
    model.model_name_or_path=meta-llama/Llama-3.2-1B \
    data.data_path=train.jsonl \
    training.output_dir=ckpts/test
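The dotlist overrides in the command above update nested recipe fields before validation. The real loader applies them via OmegaConf (which also coerces types); the pure-Python sketch below, with a hypothetical `apply_dotlist` helper, only illustrates the "a.b.c=value" semantics:

```python
# Hypothetical helper illustrating dotlist-override semantics.
# The actual loader uses OmegaConf; this sketch shows how "a.b.c=value"
# updates a nested dict without mutating the original (one level deep).
def apply_dotlist(cfg, overrides):
    out = {k: (v.copy() if isinstance(v, dict) else v) for k, v in cfg.items()}
    for item in overrides:
        key, _, value = item.partition("=")
        node = out
        parts = key.split(".")
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value  # the real loader also coerces value types
    return out

recipe = {"model": {"model_name_or_path": "default-model"}, "training": {}}
updated = apply_dotlist(
    recipe,
    ["model.model_name_or_path=meta-llama/Llama-3.2-1B", "training.output_dir=ckpts/test"],
)
```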

Testing

  • pytest tests/unit/recipe/test_loader.py — new coverage for Eagle / DFlash YAML loading, dotlist overrides, and field-level validation.
  • Smoke-trained both built-in eagle3 and dflash recipes end-to-end.

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ❌ — main.py CLI switched to --config <recipe> (+ dotlist overrides); the old argparse flags are removed.
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A — no new deps (pydantic, omegaconf already in core).
  • Did you write any new necessary tests?: ✅ — tests/unit/recipe/test_loader.py.
  • Did you update Changelog?: ❌ — to be added.

Additional Information

Follow-up to the modelopt.recipe subsystem introduced for PTQ; this PR extends the same declarative-YAML pattern to speculative decoding (Eagle3 / DFlash / Medusa).

Summary by CodeRabbit

  • New Features

    • Added typed speculative-decoding recipe support for EAGLE, DFlash, and Medusa; CLI dotlist overrides supported for single-file recipes.
    • Trainer/config schema extended with speculative-training fields and draft-vocab cache loading for Eagle.
  • Bug Fixes

    • Offline training no longer mutates model configs; loader enforces required algorithm sections and prints recipe/config only on the primary process.
    • Reduced noisy per-rank logging by restricting status output to the primary process.
  • Tests

    • Expanded tests for recipe loading, dotlist overrides, validation strictness, and error cases.
  • Documentation

    • Recipe YAMLs updated with metadata and usage notes.


copy-pr-bot Bot commented Apr 23, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.


coderabbitai Bot commented Apr 23, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

Typed single-file speculative-decoding recipes and loader dotlist overrides are added; loader dispatches on metadata.recipe_type to create validated EAGLE/DFlash/Medusa recipe objects with typed model, data, and HF training sections. HF training args schema, eagle draft-vocab loader, and the example script were updated to consume the typed recipe.

Changes

  • Recipe Configuration (modelopt/recipe/config.py): Adds SPECULATIVE_EAGLE, SPECULATIVE_DFLASH, SPECULATIVE_MEDUSA and new recipe models: ModelOptSpeculativeRecipeBase, ModelOptEagleRecipe, ModelOptDFlashRecipe, ModelOptMedusaRecipe with typed model, data, training sections and algorithm-specific fields; offline flags set at recipe level and some cross-field validation added.
  • Recipe Loading (modelopt/recipe/loader.py): load_recipe(..., overrides=None) added; supports CLI-style dotlist overrides for single-file recipes (applied via OmegaConf), rejects overrides for directory-format recipes, and instantiates concrete speculative recipe classes based on metadata.recipe_type, erroring when required top-level algorithm sections are missing.
  • HF Training Args Schema (modelopt/torch/speculative/plugins/hf_training_args.py): New Pydantic-only schema module providing ModelArguments, DataArguments, TrainingArguments (permits HF Trainer extras) and speculative fields (training_seq_len, estimate_ar, ar_validate_steps, answer_only_loss, cp_size, dp_shard_size, parallelism_config) with runtime validation to infer dp_shard_size and optionally populate parallelism_config.
  • Speculative Torch Config Changes (modelopt/torch/speculative/config.py): Removes context-based auto-derivation validators for dflash_offline/eagle_offline, defers mask-token required checks when tokenizer absent, removes rope-vs-training-seq warning; field docs updated to indicate offline flags are recipe-derived.
  • HF EAGLE Plugin (modelopt/torch/speculative/plugins/hf_eagle.py): Replaced unqualified prints with print_rank_0 and added HFEagleModel.load_draft_vocab_cache(model, d2t_path: str) to validate and conditionally load a draft-vocab cache tensor into model.eagle_module.d2t.
  • Example Script (examples/speculative_decoding/main.py): Refactors script to load typed recipe via load_recipe() (with overrides), build HfTrainingArguments from recipe.training.model_dump(), route conversion by concrete recipe type (Medusa/Eagle/DFlash), remove prior mode-based dispatch and auto-parallelism mutation, and print recipe/config only on master.
  • Recipe YAMLs (modelopt_recipes/general/speculative_decoding/dflash.yaml, modelopt_recipes/general/speculative_decoding/eagle3.yaml): Converted to full modelopt recipes: add metadata with recipe_type and description, update header comments, and remove explicit training.mode fields.
  • Tests (tests/unit/recipe/test_loader.py): Expanded tests: _apply_dotlist unit tests (parsing, scientific notation, nested creation, immutability, error cases), load_recipe(..., overrides=...) behavior, positive tests for loading typed EAGLE/DFlash recipes, and negative tests for missing algorithm sections and invalid typed fields.
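The loader's dispatch on metadata.recipe_type described above can be sketched as follows. The class names and recipe_type strings here are illustrative stand-ins, not the actual ModelOpt classes or enum values:

```python
# Sketch of picking a typed recipe class from the parsed YAML's metadata,
# with an error when recipe_type is missing or unknown. Names are
# illustrative stand-ins for ModelOptEagleRecipe / ModelOptDFlashRecipe / etc.
class EagleRecipe(dict):
    pass

class DFlashRecipe(dict):
    pass

class MedusaRecipe(dict):
    pass

RECIPE_CLASSES = {
    "eagle": EagleRecipe,
    "dflash": DFlashRecipe,
    "medusa": MedusaRecipe,
}

def select_recipe_class(parsed):
    """Pick the concrete recipe class from parsed['metadata']['recipe_type']."""
    recipe_type = parsed.get("metadata", {}).get("recipe_type")
    if recipe_type not in RECIPE_CLASSES:
        raise ValueError(f"Unknown recipe_type: {recipe_type!r}")
    return RECIPE_CLASSES[recipe_type]

cls = select_recipe_class({"metadata": {"recipe_type": "eagle"}})
```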

Sequence Diagram(s)

sequenceDiagram
    actor User
    participant Main as Speculative Main Script
    participant Loader as load_recipe()
    participant Parser as YAML Parser
    participant Validator as Pydantic Validator

    User->>Main: run with recipe_path + overrides
    Main->>Loader: load_recipe(path, overrides)
    Loader->>Parser: parse YAML file
    Parser-->>Loader: parsed dict
    alt overrides provided
        Loader->>Loader: apply dotlist overrides (OmegaConf)
    end
    Loader->>Validator: select recipe class by metadata.recipe_type
    Validator->>Validator: validate model / data / training sections
    Validator-->>Loader: validated recipe instance
    Loader-->>Main: return typed recipe
    Main->>Main: construct HfTrainingArguments from recipe.training
    Main->>Main: route conversion/execution by recipe type (eagle/dflash/medusa)
sequenceDiagram
    participant Script as Main Script
    participant Recipe as Recipe Config
    participant HfArgs as HfTrainingArguments
    participant Trainer as HF Trainer

    Script->>Recipe: load_recipe(path)
    Recipe-->>Script: typed recipe
    Script->>HfArgs: HfTrainingArguments.from_dict(recipe.training.model_dump())
    HfArgs->>HfArgs: infer dp_shard_size from WORLD_SIZE (if needed)
    HfArgs->>HfArgs: validate speculative fields (training_seq_len, estimate_ar, ...)
    HfArgs-->>Script: validated hf args
    rect rgba(100,150,200,0.5)
        Script->>Trainer: perform recipe-specific conversion & training
    end
    Trainer-->>Script: complete

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks | ✅ 6
✅ Passed checks (6 passed)
  • Description Check (✅ Passed): Check skipped - CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): The title '[Refactor] speculative decoding: use mto config subsystem' accurately describes the main objective of the PR—refactoring speculative decoding to use the ModelOpt (mto) config subsystem instead of direct argparse/OmegaConf + HfArgumentParser.
  • Docstring Coverage (✅ Passed): Docstring coverage is 80.00%, which is sufficient. The required threshold is 80.00%.
  • Linked Issues Check (✅ Passed): Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes Check (✅ Passed): Check skipped because no linked issues were found for this pull request.
  • Security Anti-Patterns (✅ Passed): Changes introduce only torch.load calls with weights_only=True; no unsafe pickle, numpy.load with allow_pickle, hardcoded trust_remote_code, eval/exec, nosec bypasses, or non-permissive dependencies.



github-actions Bot commented Apr 23, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1328/

Built to branch gh-pages at 2026-04-30 02:21 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@h-guo18 h-guo18 changed the title mto config subsystem speculative decoding: use mto config subsystem Apr 23, 2026

codecov Bot commented Apr 23, 2026

Codecov Report

❌ Patch coverage is 90.90909% with 8 lines in your changes missing coverage. Please review.
✅ Project coverage is 74.64%. Comparing base (c796611) to head (69f3c40).
⚠️ Report is 8 commits behind head on main.

Files with missing lines:
  • ...lopt/torch/speculative/plugins/hf_training_args.py: patch 86.95%, 6 lines missing ⚠️
  • modelopt/recipe/loader.py: patch 90.90%, 2 lines missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1328      +/-   ##
==========================================
+ Coverage   74.60%   74.64%   +0.03%     
==========================================
  Files         467      468       +1     
  Lines       50176    50260      +84     
==========================================
+ Hits        37435    37517      +82     
- Misses      12741    12743       +2     
Flag Coverage Δ
examples 38.85% <72.72%> (+3.25%) ⬆️
gpu 58.55% <72.72%> (-0.48%) ⬇️
regression 15.20% <85.22%> (+0.44%) ⬆️
unit 52.40% <90.90%> (+0.05%) ⬆️

Flags with carried forward coverage won't be shown.


Comment thread examples/speculative_decoding/main.py
@h-guo18 h-guo18 self-assigned this Apr 23, 2026
@h-guo18 h-guo18 marked this pull request as ready for review April 23, 2026 22:57
@h-guo18 h-guo18 requested review from a team as code owners April 23, 2026 22:57
@h-guo18 h-guo18 requested review from sychen52 and yeyu-nvidia April 23, 2026 22:57
@h-guo18 h-guo18 changed the title speculative decoding: use mto config subsystem [Refactor] speculative decoding: use mto config subsystem Apr 23, 2026

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/speculative_decoding/main.py (1)

153-186: ⚠️ Potential issue | 🟠 Major

Medusa recipe skips data_module creation, causing an undefined variable error.

When recipe is ModelOptMedusaRecipe, the code at line 154 calls mtsp.convert() but then falls through to line 203-211 where data_module is only created for ModelOptEagleRecipe or ModelOptDFlashRecipe. This leaves data_module undefined, causing a NameError at line 226 when passed to EagleTrainerWithAccLog.

Proposed fix: Add Medusa to data_module creation or handle separately
     print_rank_0("Loading dataset...")
     is_dflash = isinstance(recipe, ModelOptDFlashRecipe)
-    if isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe)):
+    if isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe, ModelOptMedusaRecipe)):
         data_module = make_speculative_data_module(
             tokenizer,
             recipe.data,
             train_len=training_args.training_seq_len,
             answer_only_loss=training_args.answer_only_loss,
             shift_labels=not is_dflash,
         )
+    else:
+        raise ValueError(f"Unsupported speculative recipe type: {type(recipe).__name__}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 153 - 186, The Medusa
branch (when recipe is ModelOptMedusaRecipe) calls mtsp.convert but never
creates the data_module used later by EagleTrainerWithAccLog, causing a
NameError; update the ModelOptMedusaRecipe branch to either construct the same
data_module as done for ModelOptEagleRecipe/ModelOptDFlashRecipe (using
recipe.data and the existing data-module factory logic) or explicitly set
data_module to an appropriate value (or None) and handle that case before
calling EagleTrainerWithAccLog so data_module is always defined; refer to the
branches around ModelOptMedusaRecipe, ModelOptEagleRecipe, ModelOptDFlashRecipe,
mtsp.convert, recipe.data, and EagleTrainerWithAccLog to place the fix.
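The failure mode this comment describes is easy to reproduce in isolation. A minimal stand-in (names hypothetical, not the PR's code) shows why a branch that never binds data_module later raises:

```python
# Minimal repro of the bug class flagged above: a local variable bound only
# in some branches triggers UnboundLocalError (a subclass of NameError)
# when the unmatched branch reaches code that reads it.
def build_trainer_inputs(recipe_type):
    if recipe_type in ("eagle", "dflash"):
        data_module = {"train_dataset": [], "eval_dataset": []}
    # "medusa" falls through without binding data_module
    return data_module


inputs = build_trainer_inputs("eagle")
```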
🧹 Nitpick comments (1)
examples/speculative_decoding/main.py (1)

61-76: Field synchronization between HfTrainingArguments and SpecTrainingArgs should be explicit.

The docstring correctly notes these field sets "MUST stay in sync" with modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments. Consider adding a test or assertion to catch drift between these two definitions automatically.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 61 - 76, Add an automated
check to ensure HfTrainingArguments and the plugin TrainingArguments stay in
sync: write a small unit test or an import-time assertion that imports
HfTrainingArguments (examples.speculative_decoding.main.HfTrainingArguments) and
modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments, extracts
their dataclass/field names (and optionally defaults/types), and fails if any
field names or default values differ; place the check in tests (preferred) or at
module import to catch drift early.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/speculative/plugins/hf_training_args.py`:
- Around line 96-130: The _fill_parallelism validator currently sets world_size
using int(os.environ.get("WORLD_SIZE", torch.cuda.device_count())), which yields
0 on CPU-only machines and leads to dp_shard_size becoming 0; change that logic
to treat a missing WORLD_SIZE and a torch.cuda.device_count() of 0 as a
single-process run by using something like device_count =
torch.cuda.device_count(); world_size = int(os.environ.get("WORLD_SIZE", max(1,
device_count))) (or world_size = int(os.environ.get("WORLD_SIZE", device_count
or 1))) so world_size is at least 1, then compute dp_shard_size and do the
existing divisibility checks and ParallelismConfig creation as before to avoid
downstream surprises and division-by-zero-like behavior when cp_size or
dp_shard_size are used.

---

Outside diff comments:
In `@examples/speculative_decoding/main.py`:
- Around line 153-186: The Medusa branch (when recipe is ModelOptMedusaRecipe)
calls mtsp.convert but never creates the data_module used later by
EagleTrainerWithAccLog, causing a NameError; update the ModelOptMedusaRecipe
branch to either construct the same data_module as done for
ModelOptEagleRecipe/ModelOptDFlashRecipe (using recipe.data and the existing
data-module factory logic) or explicitly set data_module to an appropriate value
(or None) and handle that case before calling EagleTrainerWithAccLog so
data_module is always defined; refer to the branches around
ModelOptMedusaRecipe, ModelOptEagleRecipe, ModelOptDFlashRecipe, mtsp.convert,
recipe.data, and EagleTrainerWithAccLog to place the fix.

---

Nitpick comments:
In `@examples/speculative_decoding/main.py`:
- Around line 61-76: Add an automated check to ensure HfTrainingArguments and
the plugin TrainingArguments stay in sync: write a small unit test or an
import-time assertion that imports HfTrainingArguments
(examples.speculative_decoding.main.HfTrainingArguments) and
modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments, extracts
their dataclass/field names (and optionally defaults/types), and fails if any
field names or default values differ; place the check in tests (preferred) or at
module import to catch drift early.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 838ec6d9-91a7-46a6-a197-701d6268d76c

📥 Commits

Reviewing files that changed from the base of the PR and between c796611 and 69f3c40.

📒 Files selected for processing (7)
  • examples/speculative_decoding/main.py
  • modelopt/recipe/config.py
  • modelopt/recipe/loader.py
  • modelopt/torch/speculative/plugins/hf_training_args.py
  • modelopt_recipes/general/speculative_decoding/dflash.yaml
  • modelopt_recipes/general/speculative_decoding/eagle3.yaml
  • tests/unit/recipe/test_loader.py

Comment on lines +96 to +130
    @model_validator(mode="after")
    def _fill_parallelism(self) -> TrainingArguments:
        # Read WORLD_SIZE (set by torchrun/accelerate, multi-node aware); fall back to the
        # local GPU count for single-process runs.
        import os

        import torch

        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
        if self.dp_shard_size is None:
            self.dp_shard_size = world_size // self.cp_size

        # Build a ParallelismConfig only when actually running distributed — matches the
        # previous main.py guard and avoids requiring accelerate on single-GPU dev boxes.
        if self.cp_size > 1 or self.dp_shard_size > 1:
            parallel_size = self.dp_shard_size * self.cp_size
            if world_size % parallel_size != 0:
                raise ValueError(
                    f"world_size ({world_size}) must be divisible by "
                    f"dp_shard_size ({self.dp_shard_size}) * cp_size ({self.cp_size}) "
                    f"= {parallel_size}"
                )
            try:
                from accelerate import ParallelismConfig
            except ImportError as e:
                raise ImportError(
                    "cp_size>1 or dp_shard_size>1 requires `accelerate` for ParallelismConfig. "
                    "Install it via `pip install accelerate`."
                ) from e
            self.parallelism_config = ParallelismConfig(
                cp_size=self.cp_size,
                dp_shard_size=self.dp_shard_size,
                dp_replicate_size=world_size // parallel_size,
            )
        return self

⚠️ Potential issue | 🟡 Minor

Edge case: torch.cuda.device_count() returns 0 on CPU-only machines.

When running on a machine without CUDA (e.g., during local development or CI), torch.cuda.device_count() returns 0. Combined with a missing WORLD_SIZE env var, this results in world_size=0, which causes dp_shard_size = 0 // cp_size = 0. While this won't crash immediately, it may produce unexpected behavior downstream.

Consider handling this edge case explicitly:

Proposed fix: Default world_size to 1 when no GPUs detected
-        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
+        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count() or 1))
         if self.dp_shard_size is None:
             self.dp_shard_size = world_size // self.cp_size
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_training_args.py` around lines 96 -
130, The _fill_parallelism validator currently sets world_size using
int(os.environ.get("WORLD_SIZE", torch.cuda.device_count())), which yields 0 on
CPU-only machines and leads to dp_shard_size becoming 0; change that logic to
treat a missing WORLD_SIZE and a torch.cuda.device_count() of 0 as a
single-process run by using something like device_count =
torch.cuda.device_count(); world_size = int(os.environ.get("WORLD_SIZE", max(1,
device_count))) (or world_size = int(os.environ.get("WORLD_SIZE", device_count
or 1))) so world_size is at least 1, then compute dp_shard_size and do the
existing divisibility checks and ParallelismConfig creation as before to avoid
downstream surprises and division-by-zero-like behavior when cp_size or
dp_shard_size are used.
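The proposed fallback can be factored into a small pure function. The sketch below is parameterized so it runs without torch; the function name and signature are illustrative, not the PR's API:

```python
# Sketch of dp_shard_size inference with the CPU-only guard proposed above:
# a missing WORLD_SIZE plus zero visible GPUs is treated as a single-process
# run, so world_size is never 0. Parameterized to avoid importing torch.
def infer_dp_shard_size(cp_size, world_size_env=None, device_count=0):
    if cp_size < 1:
        raise ValueError(f"cp_size must be >= 1, got {cp_size}")
    world_size = int(world_size_env) if world_size_env is not None else max(1, device_count)
    dp_shard_size = world_size // cp_size
    if dp_shard_size < 1:
        raise ValueError(f"world_size ({world_size}) too small for cp_size ({cp_size})")
    return dp_shard_size
```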

@shengliangxu
Collaborator

As discussed in the meeting, let's strip off the config override part into a separate change.


import torch

world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
Collaborator


This logic is really beyond the model_validator scope; putting it somewhere like "init_distributed_env" would be better.

What do you think?

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
@h-guo18 h-guo18 force-pushed the haoguo/spec-mto-config branch from 69f3c40 to 3bf010b Compare April 30, 2026 00:46

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/recipe/config.py`:
- Around line 97-114: The nested config fields (model, data, training) currently
instantiate SpecModelArgs/SpecDataArgs/SpecTrainingArgs at import time; update
ModeloptField to accept a default_factory alternative (or remove the assertion
forcing default so callers can use pydantic.Field) so you can switch those
fields to lazy defaults like default_factory=SpecTrainingArgs (or use
pydantic.Field(default_factory=...)) instead of default=SpecTrainingArgs();
specifically modify the ModeloptField implementation (the class/function named
ModeloptField) to accept and pass through default_factory to pydantic.Field (and
stop asserting a concrete default), then change the three fields using
ModeloptField to use default_factory=SpecModelArgs,
default_factory=SpecDataArgs, and default_factory=SpecTrainingArgs.

In `@modelopt/recipe/loader.py`:
- Around line 109-118: In _apply_dotlist, avoid eagerly resolving OmegaConf
interpolations by changing the OmegaConf.to_container call so it does not
resolve interpolations; specifically, update the call in function _apply_dotlist
to use resolve=False instead of resolve=True when converting the merged
OmegaConf to a plain dict so user-supplied overrides (e.g., ${...} patterns) are
preserved and not expanded before Pydantic validation.

In `@modelopt/torch/speculative/plugins/hf_training_args.py`:
- Around line 90-106: The _fill_parallelism model_validator currently divides by
cp_size and can accept non-positive cp_size/dp_shard_size from user configs;
update _fill_parallelism (on TrainingArguments) to first validate that cp_size
is an int > 0 and that if dp_shard_size is not None it is an int >= 1 (and also
verify computed world_size is >=1), and raise a clear ValueError (so Pydantic
surfaces it) if these checks fail; only after these guards compute world_size
and set self.dp_shard_size = world_size // self.cp_size when dp_shard_size is
None to avoid ZeroDivisionError and invalid runtime state.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 50f5aa9c-3c49-42a0-be51-0ab116444297

📥 Commits

Reviewing files that changed from the base of the PR and between 69f3c40 and 3bf010b.

📒 Files selected for processing (7)
  • examples/speculative_decoding/main.py
  • modelopt/recipe/config.py
  • modelopt/recipe/loader.py
  • modelopt/torch/speculative/plugins/hf_training_args.py
  • modelopt_recipes/general/speculative_decoding/dflash.yaml
  • modelopt_recipes/general/speculative_decoding/eagle3.yaml
  • tests/unit/recipe/test_loader.py
✅ Files skipped from review due to trivial changes (1)
  • modelopt_recipes/general/speculative_decoding/eagle3.yaml
🚧 Files skipped from review as they are similar to previous changes (2)
  • modelopt_recipes/general/speculative_decoding/dflash.yaml
  • tests/unit/recipe/test_loader.py

Comment on lines +153 to +176
     if isinstance(recipe, ModelOptMedusaRecipe):
         mtsp.convert(model, [("medusa", recipe.medusa.model_dump())])
     elif isinstance(recipe, ModelOptEagleRecipe):
         # Validate and rewrite eagle config fields
         eagle_cfg = EagleConfig.model_validate(
-            eagle_cfg,
-            context={"training_args": training_args, "data_args": data_args},
+            recipe.eagle.model_dump(),
+            context={"training_args": training_args, "data_args": recipe.data},
         ).model_dump()
         mtsp.convert(model, [("eagle", eagle_cfg)])

         # Load draft vocab cache if the draft model uses a compressed vocabulary
         if model.eagle_config.draft_vocab_size < model.eagle_config.vocab_size:
-            if not os.path.isfile(data_args.draft_vocab_cache):
-                raise FileNotFoundError(
-                    f"Draft vocab cache provided but not found: {data_args.draft_vocab_cache}"
-                )
-            model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True)
-            print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.")
-    elif training_args.mode == "dflash":
+            d2t = recipe.data.draft_vocab_cache
+            if d2t is None or not os.path.isfile(d2t):
+                raise FileNotFoundError(f"Draft vocab cache provided but not found: {d2t}")
+            model.eagle_module.d2t = torch.load(d2t, weights_only=True)
+            print_rank_0(f"Loaded draft vocab cache from {d2t}.")
+    elif isinstance(recipe, ModelOptDFlashRecipe):
         dflash_cfg = DFlashConfig.model_validate(
-            dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args}
+            recipe.dflash.model_dump(), context={"tokenizer": tokenizer, "data_args": recipe.data}
         ).model_dump()
         mtsp.convert(model, [("dflash", dflash_cfg)])
     else:
-        raise Exception(f"{training_args.mode} is not supported!")
+        raise ValueError(f"Unsupported speculative recipe type: {type(recipe).__name__}")

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

ModelOptMedusaRecipe reaches trainer setup with no data_module.

This function now accepts Medusa recipes and converts the model for them, but data_module is only created for Eagle/DFlash. A Medusa run will therefore hit an UnboundLocalError at Line 216 when **data_module is expanded.

Either wire Medusa into dataset construction here as well, or reject Medusa earlier with an explicit error until the training path is implemented.

Also applies to: 193-216

Comment thread modelopt/recipe/config.py
Comment on lines +97 to +114
    model: SpecModelArgs = ModeloptField(
        default=SpecModelArgs(),
        title="HF model args",
        description="ModelArguments for the base HF model to train a draft head against.",
        validate_default=True,
    )
    data: SpecDataArgs = ModeloptField(
        default=SpecDataArgs(),
        title="HF data args",
        description="DataArguments for the training/offline dataset.",
        validate_default=True,
    )
    training: SpecTrainingArgs = ModeloptField(
        default=SpecTrainingArgs(),
        title="HF training args",
        description="Speculative-decoding extensions; HF trainer fields flow through as extras.",
        validate_default=True,
    )

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

In Python and Pydantic v2, are expressions like default=SpecTrainingArgs() evaluated at class definition/import time, and is default_factory=SpecTrainingArgs the recommended way to avoid eager instantiation and validator side effects for nested model defaults?

💡 Result:

Yes—if you write default=SpecTrainingArgs() (i.e., you pass an already-constructed instance as the field default), that instance is created eagerly at Python class definition/import time. In contrast, default_factory=SpecTrainingArgs (i.e., a callable) defers instantiation until Pydantic needs a default for a model instance. In Pydantic v2, the docs explicitly describe default_factory as a callable that is "called to generate a default value", whereas default is a fixed default value. So using default_factory=SpecTrainingArgs is the recommended way to avoid eager instantiation for nested model defaults.

About validator side effects: default_factory runs only when Pydantic is producing the default for the field. Pydantic issues have also shown that the timing/context of default_factory invocation matters (e.g., confusing errors when validation of earlier fields fails while a data-taking default factory is being called). While those issues concern specific dependent-default scenarios, the general guidance stands: avoid constructing nested models in the default= argument when you want Pydantic to control when instantiation happens.

Practical pattern (recommended):
  • Prefer: sub: SubModel = Field(default_factory=SubModel)
  • Avoid: sub: SubModel = Field(default=SubModel())

Example: if SpecTrainingArgs is a BaseModel (or otherwise Pydantic-validated type), use default_factory=SpecTrainingArgs so each outer model instance gets its own default created at instantiation time, not at import/class-definition time.
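The same eager-vs-lazy distinction can be demonstrated with stdlib dataclasses, whose field(default_factory=...) behaves analogously to Pydantic v2's. The stub class below is hypothetical and only records when instances get built:

```python
# Stdlib-only illustration of eager defaults vs default_factory.
# SpecTrainingArgsStub is a hypothetical stand-in that records constructions.
from dataclasses import dataclass, field

CONSTRUCTED = []

class SpecTrainingArgsStub:
    """Stand-in for SpecTrainingArgs; records when it gets built."""
    def __init__(self):
        CONSTRUCTED.append(id(self))

# With default_factory, nothing is constructed at class-definition time.
@dataclass
class Recipe:
    training: SpecTrainingArgsStub = field(default_factory=SpecTrainingArgsStub)

n_after_class_def = len(CONSTRUCTED)  # construction is deferred

r1, r2 = Recipe(), Recipe()
n_after_instances = len(CONSTRUCTED)  # one fresh default per instance
```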



🏁 Script executed:

# First, find and examine the SpecTrainingArgs implementation
find . -type f -name "*.py" | xargs grep -l "class SpecTrainingArgs" | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Read the config.py file around lines 97-114 and 122-153 to understand the full context
sed -n '90,160p' modelopt/recipe/config.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 2815


🏁 Script executed:

# Check if SpecTrainingArgs has __init__ or validators with side effects
grep -A 30 "class SpecTrainingArgs" $(find . -type f -name "*.py" -path "*/speculative/*")

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Search for SpecTrainingArgs definition more broadly
rg "class SpecTrainingArgs" --type py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Also check the import statement at the top of config.py
head -50 modelopt/recipe/config.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1977


🏁 Script executed:

# Find the TrainingArguments class in hf_training_args
find . -path "*speculative/plugins/hf_training_args*" -type f

Repository: NVIDIA/Model-Optimizer

Length of output: 123


🏁 Script executed:

# Once we find it, examine the TrainingArguments.__init__ or field validators
rg "_fill_parallelism" --type py -B 5 -A 10

Repository: NVIDIA/Model-Optimizer

Length of output: 1622


🏁 Script executed:

# Get the full _fill_parallelism method and surrounding context
sed -n '120,180p' modelopt/torch/speculative/plugins/hf_training_args.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 642


🏁 Script executed:

# Also check if there are any imports of accelerate in this file
grep -n "accelerate\|ParallelismConfig" modelopt/torch/speculative/plugins/hf_training_args.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1012


🏁 Script executed:

# Verify the full validator implementation
sed -n '95,130p' modelopt/torch/speculative/plugins/hf_training_args.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 1942


🏁 Script executed:

# Check if ModeloptField supports default_factory parameter
rg "def ModeloptField|class ModeloptField" --type py -A 20

Repository: NVIDIA/Model-Optimizer

Length of output: 1492


🏁 Script executed:

# Check the full ModeloptField implementation
cat -n modelopt/torch/opt/config.py | sed -n '1,100p'

Repository: NVIDIA/Model-Optimizer

Length of output: 4970


🏁 Script executed:

# Verify if default_factory can be passed through **kwargs to Field
rg "default_factory" modelopt/torch/opt/config.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check if there are any examples of default_factory usage in the codebase
rg "default_factory" --type py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 3012


🏁 Script executed:

# Check if ModeloptField is used anywhere with default_factory attempts
rg "ModeloptField.*default_factory" --type py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Also verify the assertion in ModeloptField more carefully
sed -n '51,55p' modelopt/torch/opt/config.py

Repository: NVIDIA/Model-Optimizer

Length of output: 348


🏁 Script executed:

# Check if using default_factory with ModeloptField would work given the assertion
python3 << 'EOF'
# Simulate what would happen if we try to use default_factory with ModeloptField
from pydantic_core import PydanticUndefined

# This is what ModeloptField does
def ModeloptField(default=PydanticUndefined, **kwargs):
    assert default is not PydanticUndefined, "A default value must be set for ModeloptField."
    return {"default": default, **kwargs}

# Try to pass default_factory
try:
    result = ModeloptField(default_factory=lambda: {})
    print("Success:", result)
except AssertionError as e:
    print("AssertionError:", e)
except TypeError as e:
    print("TypeError:", e)
EOF

Repository: NVIDIA/Model-Optimizer

Length of output: 129


🏁 Script executed:

# Check if there's a way to modify ModeloptField to support default_factory
grep -B 5 -A 15 "def ModeloptField" modelopt/torch/opt/config.py

Repository: NVIDIA/Model-Optimizer

Length of output: 928


🏁 Script executed:

# Verify if using direct Field instead of ModeloptField would work in the recipe file
rg "from pydantic import.*Field" modelopt/recipe/config.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check what imports are available at the top of config.py
head -30 modelopt/recipe/config.py | grep -E "^from|^import"

Repository: NVIDIA/Model-Optimizer

Length of output: 636


🏁 Script executed:

# Check if we need to modify ModeloptField to support default_factory
rg "ModeloptField" modelopt/torch/opt/config.py -B 3 -A 10

Repository: NVIDIA/Model-Optimizer

Length of output: 1159


Use lazy defaults for nested recipe sections to avoid import-time validator side effects.

SpecTrainingArgs(), SpecModelArgs(), and related config objects are instantiated when the module is imported, not when a recipe instance is created. Because SpecTrainingArgs has an @model_validator(mode="after") decorator on _fill_parallelism(), that validator runs immediately at import time, reading WORLD_SIZE and potentially importing accelerate on multi-GPU hosts. This makes importing modelopt.recipe.config trigger host inspection and can expose host-specific derived defaults.

The suggested fix requires updating ModeloptField (in modelopt/torch/opt/config.py) to support default_factory in addition to default. Currently, ModeloptField asserts that a default value must be provided, which prevents using default_factory=.... Either:

  1. Modify ModeloptField to accept default_factory as an alternative (passing it through to pydantic.Field), or
  2. Use pydantic.Field directly instead of ModeloptField for these nested config fields and adjust the ModeloptField assertion accordingly.

Once that foundation is in place, the diff changes shown (switching to default_factory=SpecTrainingArgs, etc.) can be applied.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/recipe/config.py` around lines 97 - 114, The nested config fields
(model, data, training) currently instantiate
SpecModelArgs/SpecDataArgs/SpecTrainingArgs at import time; update ModeloptField
to accept a default_factory alternative (or remove the assertion forcing default
so callers can use pydantic.Field) so you can switch those fields to lazy
defaults like default_factory=SpecTrainingArgs (or use
pydantic.Field(default_factory=...)) instead of default=SpecTrainingArgs();
specifically modify the ModeloptField implementation (the class/function named
ModeloptField) to accept and pass through default_factory to pydantic.Field (and
stop asserting a concrete default), then change the three fields using
ModeloptField to use default_factory=SpecModelArgs,
default_factory=SpecDataArgs, and default_factory=SpecTrainingArgs.

Comment thread modelopt/recipe/loader.py
Comment on lines +109 to +118
def _apply_dotlist(data: dict, overrides: list[str]) -> dict:
    """Merge ``a.b.c=value`` command line overrides on top of ``data`` via OmegaConf."""
    for entry in overrides:
        if "=" not in entry:
            raise ValueError(f"Invalid override (missing '='): {entry!r}")
    merged = OmegaConf.merge(
        OmegaConf.create(data),
        OmegaConf.from_dotlist(list(overrides)),
    )
    return OmegaConf.to_container(merged, resolve=True)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== OmegaConf resolver registrations =="
rg -n -C2 'OmegaConf\.(register_new_resolver|register_resolver)' --type=py

echo
echo "== Eager resolution sites =="
rg -n -C2 'to_container\(.*resolve=True' --type=py

echo
echo "== Full recipe printing/logging sites =="
rg -n -C2 'pprint\(recipe\)|print_rank_0\(.*recipe|print\(.*recipe' --type=py

Repository: NVIDIA/Model-Optimizer

Length of output: 6579


🏁 Script executed:

# Check the load_recipe() function to understand the full context
cat -n modelopt/recipe/loader.py | head -130

Repository: NVIDIA/Model-Optimizer

Length of output: 5940


🏁 Script executed:

# Check examples/speculative_decoding/main.py to see the recipe handling
rg -B10 -A5 'pprint\(recipe\)' examples/speculative_decoding/main.py

Repository: NVIDIA/Model-Optimizer

Length of output: 686


🏁 Script executed:

# Check if register_hydra_resolvers is called during recipe loading
rg -n 'register_hydra_resolvers' --type=py

Repository: NVIDIA/Model-Optimizer

Length of output: 633


🏁 Script executed:

# Check how the result of _apply_dotlist is used
rg -A10 '_apply_dotlist' modelopt/recipe/loader.py | head -30

Repository: NVIDIA/Model-Optimizer

Length of output: 984


🏁 Script executed:

# Check the usage context to see if resolve=True is actually necessary
rg -B5 -A15 '_load_recipe_from_file' modelopt/recipe/loader.py | grep -A20 'def _load_recipe_from_file'

Repository: NVIDIA/Model-Optimizer

Length of output: 673


Don't eagerly resolve OmegaConf interpolations on untrusted recipe overrides.

OmegaConf.to_container(..., resolve=True) at line 118 evaluates ${...} expressions before Pydantic validation. User-supplied YAML and CLI overrides can contain ${oc.env:...} patterns that resolve to environment variables. Since examples/speculative_decoding/main.py calls pprint(recipe) on the loaded config, resolved secrets will leak into logs.

Change resolve=True to resolve=False. Pydantic handles unresolved OmegaConf dicts fine, and type conversion is already performed by YAML parsing.

Suggested fix
 def _apply_dotlist(data: dict, overrides: list[str]) -> dict:
@@
-    return OmegaConf.to_container(merged, resolve=True)
+    return OmegaConf.to_container(merged, resolve=False)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/recipe/loader.py` around lines 109 - 118, In _apply_dotlist, avoid
eagerly resolving OmegaConf interpolations by changing the
OmegaConf.to_container call so it does not resolve interpolations; specifically,
update the call in function _apply_dotlist to use resolve=False instead of
resolve=True when converting the merged OmegaConf to a plain dict so
user-supplied overrides (e.g., ${...} patterns) are preserved and not expanded
before Pydantic validation.
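The leak mode can be reproduced without OmegaConf: a minimal re.sub-based resolver (a stand-in for OmegaConf's ${oc.env:...} interpolation, not its actual implementation) shows how eager resolution bakes an environment secret into the dict that pprint(recipe) would later log:

```python
import os
import re

os.environ["HF_TOKEN"] = "hf_secret123"  # pretend secret on the training host

override = {"model": {"token": "${oc.env:HF_TOKEN}"}}

def resolve(cfg: dict) -> dict:
    """Stand-in for OmegaConf.to_container(..., resolve=True): expand env refs."""
    out = {}
    for k, v in cfg.items():
        if isinstance(v, dict):
            out[k] = resolve(v)
        elif isinstance(v, str):
            out[k] = re.sub(r"\$\{oc\.env:(\w+)\}", lambda m: os.environ[m.group(1)], v)
        else:
            out[k] = v
    return out

eager = resolve(override)  # resolve=True: secret materialized into the config
lazy = override            # resolve=False: placeholder string preserved
assert eager["model"]["token"] == "hf_secret123"       # this is what gets logged
assert lazy["model"]["token"] == "${oc.env:HF_TOKEN}"  # this stays inert
```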

Comment on lines +90 to +106
    cp_size: int = 1
    dp_shard_size: int | None = None
    # Derived at validation time from cp_size/dp_shard_size/WORLD_SIZE; typed as Any so this
    # module doesn't need to import accelerate.ParallelismConfig just to annotate the field.
    parallelism_config: Any = None

    @model_validator(mode="after")
    def _fill_parallelism(self) -> TrainingArguments:
        # Read WORLD_SIZE (set by torchrun/accelerate, multi-node aware); fall back to the
        # local GPU count for single-process runs.
        import os

        import torch

        world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
        if self.dp_shard_size is None:
            self.dp_shard_size = world_size // self.cp_size

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate cp_size/dp_shard_size before deriving parallelism.

Right now training.cp_size=0 crashes with a ZeroDivisionError at Line 106, and training.dp_shard_size=0 or negative values can slip through as invalid runtime state instead of a clean Pydantic error. Since these fields are user-configurable via YAML/dotlist, they should be rejected up front.

Suggested fix
 class TrainingArguments(BaseModel):
@@
     cp_size: int = 1
     dp_shard_size: int | None = None
@@
+    @field_validator("cp_size")
+    @classmethod
+    def _check_cp_size(cls, v: int) -> int:
+        if v < 1:
+            raise ValueError("cp_size must be >= 1")
+        return v
+
+    @field_validator("dp_shard_size")
+    @classmethod
+    def _check_dp_shard_size(cls, v: int | None) -> int | None:
+        if v is not None and v < 1:
+            raise ValueError("dp_shard_size must be >= 1 when provided")
+        return v
+
     @model_validator(mode="after")
     def _fill_parallelism(self) -> TrainingArguments:

As per coding guidelines "Contributors must treat user-provided artifacts/configs as untrusted and avoid unsafe parsing/deserialization."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_training_args.py` around lines 90 -
106, The _fill_parallelism model_validator currently divides by cp_size and can
accept non-positive cp_size/dp_shard_size from user configs; update
_fill_parallelism (on TrainingArguments) to first validate that cp_size is an
int > 0 and that if dp_shard_size is not None it is an int >= 1 (and also verify
computed world_size is >=1), and raise a clear ValueError (so Pydantic surfaces
it) if these checks fail; only after these guards compute world_size and set
self.dp_shard_size = world_size // self.cp_size when dp_shard_size is None to
avoid ZeroDivisionError and invalid runtime state.

h-guo18 added 2 commits April 30, 2026 01:40
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
modelopt/recipe/config.py (1)

98-115: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Use default_factory for nested config defaults to avoid import-time instantiation.

Line 99, Line 105, Line 111 (and similarly Line 124, Line 153, Line 171) instantiate nested config objects at import time via default=...(). This is the same issue already raised earlier and is still unresolved. Prefer lazy construction with default_factory (which may require extending ModeloptField to pass through default_factory).

Suggested direction
-from pydantic import field_validator, model_validator
+from pydantic import Field, field_validator, model_validator

-    model: SpecModelArgs = ModeloptField(default=SpecModelArgs(), ...)
+    model: SpecModelArgs = Field(default_factory=SpecModelArgs, ...)

-    data: SpecDataArgs = ModeloptField(default=SpecDataArgs(), ...)
+    data: SpecDataArgs = Field(default_factory=SpecDataArgs, ...)

-    training: SpecTrainingArgs = ModeloptField(default=SpecTrainingArgs(), ...)
+    training: SpecTrainingArgs = Field(default_factory=SpecTrainingArgs, ...)

If ModeloptField must be retained, update ModeloptField to accept default_factory as an alternative to default.

#!/bin/bash
# Verify eager constructor defaults in recipe config and whether ModeloptField supports default_factory.
python - <<'PY'
import ast, pathlib
p = pathlib.Path("modelopt/recipe/config.py")
tree = ast.parse(p.read_text())
for n in ast.walk(tree):
    if isinstance(n, ast.Call) and isinstance(n.func, ast.Name) and n.func.id == "ModeloptField":
        for kw in n.keywords:
            if kw.arg == "default" and isinstance(kw.value, ast.Call):
                fn = kw.value.func
                name = getattr(fn, "id", getattr(fn, "attr", type(fn).__name__))
                print(f"Line {n.lineno}: eager default via {name}()")
PY

rg -n -C2 "def ModeloptField|default_factory|PydanticUndefined|assert .*default" modelopt/torch/opt/config.py

Also applies to: 123-124, 152-153, 170-171

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/recipe/config.py` around lines 98 - 115, The ModeloptField usages
for the nested config attributes model, data, and training (and other spots
noted) currently instantiate objects at import time via default=SpecModelArgs(),
default=SpecDataArgs(), default=SpecTrainingArgs(), which causes eager
construction; change these to use lazy construction by supporting and passing
default_factory=SpecModelArgs, default_factory=SpecDataArgs,
default_factory=SpecTrainingArgs instead of calling the constructors, and if
ModeloptField does not yet accept default_factory update the ModeloptField
implementation to accept a default_factory kwarg, store/forward it to the
underlying field construction logic (mirroring pydantic/dataclasses semantics),
and ensure validate_default behavior still works with the factory.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/recipe/config.py`:
- Around line 27-32: The module currently performs hard imports of
speculative/HF classes DFlashConfig, EagleConfig, MedusaConfig and
SpecDataArgs/SpecModelArgs/SpecTrainingArgs at import time; change this to
lazy/plugin loading by moving those imports behind typing.TYPE_CHECKING for type
annotations and using the project's import_plugin() (or an equivalent lazy
import helper) where the classes are actually needed at runtime (e.g., in
factory functions or recipe registration). Keep type-only references via from
typing import TYPE_CHECKING and string annotations or if TYPE_CHECKING: import
the speculative classes, and replace direct module-level imports with calls to
import_plugin("modelopt.torch.speculative...") right before using
DFlashConfig/EagleConfig/MedusaConfig or the Spec* argument classes.

---

Duplicate comments:
In `@modelopt/recipe/config.py`:
- Around line 98-115: The ModeloptField usages for the nested config attributes
model, data, and training (and other spots noted) currently instantiate objects
at import time via default=SpecModelArgs(), default=SpecDataArgs(),
default=SpecTrainingArgs(), which causes eager construction; change these to use
lazy construction by supporting and passing default_factory=SpecModelArgs,
default_factory=SpecDataArgs, default_factory=SpecTrainingArgs instead of
calling the constructors, and if ModeloptField does not yet accept
default_factory update the ModeloptField implementation to accept a
default_factory kwarg, store/forward it to the underlying field construction
logic (mirroring pydantic/dataclasses semantics), and ensure validate_default
behavior still works with the factory.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 143fdb4d-cd06-457c-b8b0-cbace1bf21cc

📥 Commits

Reviewing files that changed from the base of the PR and between 23094e4 and d3ffd94.

📒 Files selected for processing (1)
  • modelopt/recipe/config.py

Comment thread modelopt/recipe/config.py
Comment on lines +27 to +32
from modelopt.torch.speculative.config import DFlashConfig, EagleConfig, MedusaConfig
from modelopt.torch.speculative.plugins.hf_training_args import DataArguments as SpecDataArgs
from modelopt.torch.speculative.plugins.hf_training_args import ModelArguments as SpecModelArgs
from modelopt.torch.speculative.plugins.hf_training_args import (
TrainingArguments as SpecTrainingArgs,
)

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Lazy-load speculative/HF config imports instead of importing them at module import time.

Line 27-Line 32 hard-import speculative/HF modules in a core recipe module. This makes optional extras effectively mandatory for import paths that don’t use speculative recipes (e.g., PTQ-only usage). Please gate these with lazy/plugin loading and keep type-only references behind deferred annotations/TYPE_CHECKING.

As per coding guidelines: "Avoid hard imports of optional dependencies at module level; features should be gated by install extras ([onnx], [hf], [all]) and loaded lazily via import_plugin()".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/recipe/config.py` around lines 27 - 32, The module currently
performs hard imports of speculative/HF classes DFlashConfig, EagleConfig,
MedusaConfig and SpecDataArgs/SpecModelArgs/SpecTrainingArgs at import time;
change this to lazy/plugin loading by moving those imports behind
typing.TYPE_CHECKING for type annotations and using the project's
import_plugin() (or an equivalent lazy import helper) where the classes are
actually needed at runtime (e.g., in factory functions or recipe registration).
Keep type-only references via from typing import TYPE_CHECKING and string
annotations or if TYPE_CHECKING: import the speculative classes, and replace
direct module-level imports with calls to
import_plugin("modelopt.torch.speculative...") right before using
DFlashConfig/EagleConfig/MedusaConfig or the Spec* argument classes.
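The gating pattern that comment asks for looks roughly like this (names abbreviated; importlib.import_module stands in for the project's import_plugin helper, and stdlib fractions stands in for the optional speculative module):

```python
import importlib
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Type-only: never executed at runtime, so the optional extra stays optional.
    from fractions import Fraction  # stand-in for EagleConfig etc.

def make_default(numerator: int) -> "Fraction":
    """Import the heavy/optional module only when the factory actually runs."""
    mod = importlib.import_module("fractions")  # stand-in for import_plugin(...)
    return mod.Fraction(numerator, 2)

# Importing this module costs nothing; the optional dependency loads on first use.
assert str(make_default(3)) == "3/2"
```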

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

♻️ Duplicate comments (1)
examples/speculative_decoding/main.py (1)

186-209: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

data_module is undefined for Medusa recipes and will crash trainer construction.

At Line 186, data_module is created only for Eagle/DFlash. For ModelOptMedusaRecipe, Line 208 expands **data_module before assignment (UnboundLocalError).

Suggested fix
-    if isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe)):
+    if isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe, ModelOptMedusaRecipe)):
         data_module = make_speculative_data_module(
             tokenizer,
             recipe.data,
             train_len=training_args.training_seq_len,
             answer_only_loss=training_args.answer_only_loss,
             shift_labels=not is_dflash,
         )
+    else:
+        raise ValueError(f"Unsupported speculative recipe type for dataset setup: {type(recipe).__name__}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 186 - 209, The variable
data_module is only set inside the isinstance(recipe, (ModelOptEagleRecipe,
ModelOptDFlashRecipe)) branch but later always expanded into
EagleTrainerWithAccLog via **data_module, causing an UnboundLocalError for
ModelOptMedusaRecipe; fix by ensuring data_module is always defined before
trainer construction—either initialize data_module = {} prior to the conditional
or add an else branch that creates/returns the appropriate Medusa data module
(e.g., a make_medusa_data_module or equivalent) so that **data_module is safe
for all recipe types (references: recipe, ModelOptEagleRecipe,
ModelOptDFlashRecipe, ModelOptMedusaRecipe, make_speculative_data_module,
EagleTrainerWithAccLog, **data_module).
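The failure mode is ordinary Python scoping, reproducible in isolation (toy classes, not the real recipes):

```python
class EagleRecipe: ...
class MedusaRecipe: ...

def build_trainer_kwargs(recipe):
    if isinstance(recipe, EagleRecipe):
        data_module = {"train_dataset": "eagle-data"}  # only branch that assigns
    return {"model": "m", **data_module}  # crashes for any other recipe type

assert build_trainer_kwargs(EagleRecipe())["train_dataset"] == "eagle-data"
try:
    build_trainer_kwargs(MedusaRecipe())
except UnboundLocalError:
    print("Medusa path crashes before trainer construction")
```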
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/speculative_decoding/main.py`:
- Around line 196-201: The code references eagle_cfg when deciding to append
LoRAWarmupCallback but eagle_cfg may be undefined in the checkpoint-resume path;
initialize or load eagle_cfg before that conditional so the .get calls are safe.
Concretely, ensure eagle_cfg is defined (e.g., set eagle_cfg = {} or populate it
from the checkpoint metadata in the resume branch) prior to the callback gating
logic that checks isinstance(recipe, ModelOptEagleRecipe) and
eagle_cfg.get(...), so
LoRAWarmupCallback(eagle_cfg["eagle_base_lora_warmup_steps"]) only executes when
eagle_cfg exists and contains the key.
- Around line 109-113: The code mutates training_args.parallelism_config when
training_args.cp_size > 1 but HfTrainingArguments lacks that field; either add a
parallelism_config field and an initializer/validator to HfTrainingArguments
(similar to
modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments) so it is
always present, or guard the mutation by checking hasattr(training_args,
"parallelism_config") (or training_args.parallelism_config is not None) before
assigning sp_backend = None; update references to training_args.cp_size and
training_args.dp_shard_size accordingly so the behavior is consistent.

In `@modelopt/torch/speculative/plugins/hf_eagle.py`:
- Around line 257-264: load_draft_vocab_cache currently assigns torch.load
directly to model.eagle_module.d2t which can cause device/dtype/shape
mismatches; instead, load the saved tensor using map_location to the
model.eagle_module.d2t device, validate that the loaded tensor's dtype and numel
match model.eagle_module.d2t (and that model.eagle_config.draft_vocab_size <
model.eagle_config.vocab_size condition holds), then copy the data into the
existing buffer in-place (e.g., via copy_()) so the registered buffer keeps its
device and shape; update load_draft_vocab_cache and references to
model.eagle_module.d2t accordingly.

---

Duplicate comments:
In `@examples/speculative_decoding/main.py`:
- Around line 186-209: The variable data_module is only set inside the
isinstance(recipe, (ModelOptEagleRecipe, ModelOptDFlashRecipe)) branch but later
always expanded into EagleTrainerWithAccLog via **data_module, causing an
UnboundLocalError for ModelOptMedusaRecipe; fix by ensuring data_module is
always defined before trainer construction—either initialize data_module = {}
prior to the conditional or add an else branch that creates/returns the
appropriate Medusa data module (e.g., a make_medusa_data_module or equivalent)
so that **data_module is safe for all recipe types (references: recipe,
ModelOptEagleRecipe, ModelOptDFlashRecipe, ModelOptMedusaRecipe,
make_speculative_data_module, EagleTrainerWithAccLog, **data_module).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 16de0a2a-dbe7-453e-a197-0049f864b524

📥 Commits

Reviewing files that changed from the base of the PR and between d3ffd94 and a2afa7e.

📒 Files selected for processing (2)
  • examples/speculative_decoding/main.py
  • modelopt/torch/speculative/plugins/hf_eagle.py

Comment on lines 109 to +113
if training_args.cp_size > 1:
    patch_ring_attention_for_ttt()
    # Specific patch to accelerate 1.12.0. Removable after move to 1.13.0
    training_args.parallelism_config.sp_backend = None
print_rank_0(
    f"arguments: {model_args}, {training_args}, {medusa_args}, "
    f"eagle_cfg={eagle_cfg}, dflash_cfg={dflash_cfg}"
)
if is_master():

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify current code path has no local parallelism_config construction.
rg -n "parallelism_config|cp_size|dp_shard_size" examples/speculative_decoding/main.py

Repository: NVIDIA/Model-Optimizer

Length of output: 231


🏁 Script executed:

cat -n examples/speculative_decoding/main.py | head -150 | tail -90

Repository: NVIDIA/Model-Optimizer

Length of output: 4210


🏁 Script executed:

rg -n "parallelism_config" --type py -A 2 -B 2

Repository: NVIDIA/Model-Optimizer

Length of output: 2228


🏁 Script executed:

cat -n modelopt/torch/speculative/plugins/hf_training_args.py | sed -n '90,150p'

Repository: NVIDIA/Model-Optimizer

Length of output: 2256


🏁 Script executed:

cat -n examples/speculative_decoding/main.py | sed -n '61,110p'

Repository: NVIDIA/Model-Optimizer

Length of output: 2309


Add parallelism_config field to HfTrainingArguments or guard access before mutation.

At line 112, the code assigns training_args.parallelism_config.sp_backend = None, but parallelism_config is not defined on the HfTrainingArguments dataclass, which only has the cp_size and dp_shard_size fields. When cp_size > 1, this raises an AttributeError at runtime.

The docstring (lines 66–67) states the field set "MUST stay in sync with" modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments, which does have a parallelism_config field and a validator to initialize it. Either add the field and validator to the dataclass, or guard the attribute access with a check.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 109 - 113, The code
mutates training_args.parallelism_config when training_args.cp_size > 1 but
HfTrainingArguments lacks that field; either add a parallelism_config field and
an initializer/validator to HfTrainingArguments (similar to
modelopt.torch.speculative.plugins.hf_training_args.TrainingArguments) so it is
always present, or guard the mutation by checking hasattr(training_args,
"parallelism_config") (or training_args.parallelism_config is not None) before
assigning sp_backend = None; update references to training_args.cp_size and
training_args.dp_shard_size accordingly so the behavior is consistent.

Comment thread modelopt/torch/speculative/plugins/hf_eagle.py
Comment on lines +257 to +264
def load_draft_vocab_cache(model, d2t_path: str) -> None:
    """Load the draft vocab cache from the given path."""
    if d2t_path is None or model.eagle_config.draft_vocab_size >= model.eagle_config.vocab_size:
        return
    if not os.path.isfile(d2t_path):
        raise FileNotFoundError(f"Draft vocab cache provided but not found: {d2t_path}")
    model.eagle_module.d2t = torch.load(d2t_path, weights_only=True)
    print_rank_0(f"Loaded draft vocab cache from {d2t_path}.")

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

find . -name "hf_eagle.py" -type f

Repository: NVIDIA/Model-Optimizer

Length of output: 115


🏁 Script executed:

cat -n modelopt/torch/speculative/plugins/hf_eagle.py | sed -n '250,270p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1054


🏁 Script executed:

# Search for where d2t is registered or initialized
rg "d2t" modelopt/torch/speculative/plugins/hf_eagle.py -A 2 -B 2

Repository: NVIDIA/Model-Optimizer

Length of output: 1231


🏁 Script executed:

# Search for how d2t is used in logits remapping or elsewhere
rg "d2t" modelopt/torch/speculative/ -A 1 -B 1

Repository: NVIDIA/Model-Optimizer

Length of output: 4108


Load the draft vocab cache with proper device mapping and validation.

At line 263, assigning torch.load(...) directly to model.eagle_module.d2t risks device/dtype/shape mismatches that fail later during logits remapping. The loaded tensor defaults to CPU but the buffer is pre-registered on the model's device. Load with map_location, validate dtype and numel match expectations, and use copy_() to update the buffer in-place.

Suggested fix
 @staticmethod
 def load_draft_vocab_cache(model, d2t_path: str) -> None:
     """Load the draft vocab cache from the given path."""
     if d2t_path is None or model.eagle_config.draft_vocab_size >= model.eagle_config.vocab_size:
         return
     if not os.path.isfile(d2t_path):
         raise FileNotFoundError(f"Draft vocab cache provided but not found: {d2t_path}")
-    model.eagle_module.d2t = torch.load(d2t_path, weights_only=True)
+    target = model.eagle_module.d2t
+    loaded = torch.load(d2t_path, map_location=target.device, weights_only=True)
+    if loaded.dtype != target.dtype or loaded.numel() != target.numel():
+        raise ValueError(
+            f"Invalid draft vocab cache at {d2t_path}: "
+            f"expected dtype={target.dtype}, numel={target.numel()}, "
+            f"got dtype={loaded.dtype}, numel={loaded.numel()}"
+        )
+    target.copy_(loaded.view_as(target))
     print_rank_0(f"Loaded draft vocab cache from {d2t_path}.")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_eagle.py` around lines 257 - 264,
load_draft_vocab_cache currently assigns torch.load directly to
model.eagle_module.d2t which can cause device/dtype/shape mismatches; instead,
load the saved tensor using map_location to the model.eagle_module.d2t device,
validate that the loaded tensor's dtype and numel match model.eagle_module.d2t
(and that model.eagle_config.draft_vocab_size < model.eagle_config.vocab_size
condition holds), then copy the data into the existing buffer in-place (e.g.,
via copy_()) so the registered buffer keeps its device and shape; update
load_draft_vocab_cache and references to model.eagle_module.d2t accordingly.
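Why copy_() into the existing buffer matters can be shown with plain lists as tensor stand-ins (no torch needed for the sketch): rebinding the attribute orphans every other reference to the registered buffer, while an in-place copy updates them all and, for real tensors, also keeps the buffer's device and shape.

```python
class Module:
    def __init__(self):
        self.d2t = [0, 0, 0]  # stand-in for a registered tensor buffer

m = Module()
optimizer_view = m.d2t  # something else holds a reference to the buffer

# Rebinding (what `model.eagle_module.d2t = torch.load(...)` does):
m.d2t = [7, 8, 9]
assert optimizer_view == [0, 0, 0]  # the old reference is now stale

# In-place copy (what `target.copy_(loaded)` does), after validating the size:
m2 = Module()
view2 = m2.d2t
loaded = [7, 8, 9]
assert len(loaded) == len(m2.d2t)  # numel/dtype-style check before copying
m2.d2t[:] = loaded                 # mutate the existing buffer in place
assert view2 == [7, 8, 9]          # every holder sees the new data
assert view2 is m2.d2t             # same object, identity preserved
```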

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>