Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ NVIDIA Model Optimizer Changelog
- ``pass_through_bwd`` in the quantization config now defaults to ``True``. Please set it to False if you want to use STE with zeroed outlier gradients for potentially better QAT accuracy.
- Add :meth:`compute_quantization_mse <modelopt.torch.quantization.model_quant.compute_quantization_mse>` API to measure per-quantizer mean-squared quantization error, with flexible wildcard and callable filtering.
- **AutoQDQ**: New tool for automated Q/DQ (Quantize/Dequantize) placement optimization for ONNX models. Uses TensorRT latency measurements to choose insertion schemes that minimize inference time. Discovers regions automatically, groups them by structural pattern, and tests multiple Q/DQ schemes per pattern. Supports INT8 and FP8 quantization, pattern cache for warm-start on similar models, checkpoint/resume, and importing patterns from an existing QDQ baseline. CLI: ``python -m modelopt.onnx.quantization.autotune``. See the AutoQDQ guide in the documentation.
- Add ``get_auto_quantize_config`` API to extract a flat quantization config from ``auto_quantize`` search results, enabling re-quantization at different effective bit targets without re-running calibration.
- Improve ``auto_quantize`` checkpoint/resume: calibration state is now saved and restored across runs, avoiding redundant calibration when resuming a search.
- Add NemotronH MoE expert support in ``auto_quantize`` grouping and scoring rules.

**Misc**

Expand Down
43 changes: 32 additions & 11 deletions modelopt/torch/opt/searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,33 +236,54 @@ def state_dict(self) -> SearchStateDict:
"""The state dictionary that can be stored/loaded."""
return {key: getattr(self, key) for key in self.default_state_dict}

def _get_checkpoint_path(self) -> str | None:
    """Resolve the checkpoint path for the current process.

    Returns:
        ``None`` when no checkpoint is configured. In a distributed run
        (``dist.is_initialized()``) the rank index is appended to the file
        stem so each rank gets its own file; otherwise the configured path
        is returned unchanged.
    """
    configured = self.config["checkpoint"]
    if configured is None:
        return None
    if not dist.is_initialized():
        return configured
    # Per-rank file: "<stem><rank><ext>" in the same directory as the configured path.
    directory, filename = os.path.split(configured)
    stem, extension = os.path.splitext(filename)
    return os.path.join(directory, f"{stem}{dist.rank()}{extension}")

def load_search_checkpoint(self, strict: bool = True) -> bool:
    """Load the searcher state from the configured checkpoint, if present.

    Args:
        strict: If True, require the checkpoint's keys to match
            ``self.state_dict()`` exactly. If False, keys missing from the
            checkpoint keep the defaults from ``self.default_state_dict``.

    Returns:
        True if a checkpoint was found and loaded, False otherwise.
    """
    checkpoint = self._get_checkpoint_path()
    if checkpoint is None:
        return False
    # Backward compat: fall back to the original single-file path
    if not os.path.exists(checkpoint):
        warn_rank_0(
            f"Per-rank checkpoint {checkpoint} not found, falling back to "
            f"{self.config['checkpoint']}. Ensure world size matches the original run."
        )
        checkpoint = self.config["checkpoint"]
    if not os.path.exists(checkpoint):
        warn_rank_0(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.")
        return False

    # iterate through state dict and load keys
    print_rank_0(f"Loading searcher state from {checkpoint}...")
    # Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input
    state_dict = torch.load(checkpoint, weights_only=False)
    if strict:
        assert state_dict.keys() == self.state_dict().keys(), "Keys in checkpoint don't match!"
    # Restore every known key; in non-strict mode a key absent from the
    # checkpoint falls back to its default value.
    for key, default_val in self.default_state_dict.items():
        setattr(self, key, state_dict.get(key, default_val))
    return True

def save_search_checkpoint(self, verbose=False) -> None:
"""Save function for search checkpoint."""
# check if save requirements are satisfied
checkpoint: str | None = self.config["checkpoint"]
if checkpoint is None or not dist.is_master():
checkpoint = self._get_checkpoint_path()
if checkpoint is None:
return

# save state dict
if dist.is_initialized():
warn_rank_0(
"torch.distributed is initialized. Please maintain the same parallelism "
"configuration (world size, TP, EP, etc.) across search save and restore sessions."
)

if verbose:
print(f"Saving searcher state to {checkpoint}...")
save_dirname, _ = os.path.split(checkpoint)
Expand Down
Loading
Loading