From 31b3bca71560cf670f58fbfbd670a2f161d41082 Mon Sep 17 00:00:00 2001 From: realAsma Date: Mon, 9 Mar 2026 20:08:54 +0000 Subject: [PATCH 1/3] Auto Quantize improvements and bug fixes for large sparse MoEs - Add get_auto_quantize_config API to extract quant config from search results - Save/restore calibration state in auto_quantize checkpoints - Add NemotronH MoE expert support in auto_quantize grouping/scoring - Fix SequentialQuantizer scope, use F.kl_div for numerical stability - Fix mypy errors and clean up tests Co-Authored-By: Claude Opus 4.6 Signed-off-by: realAsma --- CHANGELOG.rst | 3 + modelopt/torch/opt/searcher.py | 39 +++- modelopt/torch/quantization/algorithms.py | 221 +++++++++++++----- modelopt/torch/quantization/model_calib.py | 26 +-- modelopt/torch/quantization/model_quant.py | 75 +++++- .../torch/quantization/plugins/huggingface.py | 14 +- .../quantization/src/tensor_quant_gpu_fp8.cu | 10 +- modelopt/torch/quantization/tensor_quant.py | 8 +- modelopt/torch/quantization/utils.py | 7 +- .../quantization/test_tensor_quant_cuda.py | 13 ++ .../unit/torch/quantization/test_autoquant.py | 105 ++++++++- .../torch/quantization/test_quantize_cpu.py | 16 ++ 12 files changed, 440 insertions(+), 97 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index ae94ef2ab3..f3430012fa 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -21,6 +21,9 @@ NVIDIA Model Optimizer Changelog - ``pass_through_bwd`` in the quantization config is now default to True. Please set it to False if you want to use STE with zeroed outlier gradients for potentially better QAT accuracy. - Add :meth:`compute_quantization_mse ` API to measure per-quantizer mean-squared quantization error, with flexible wildcard and callable filtering. - **AutoQDQ**: New tool for automated Q/DQ (Quantize/Dequantize) placement optimization for ONNX models. Uses TensorRT latency measurements to choose insertion schemes that minimize inference time. Discovers regions automatically, groups them by structural pattern, and tests multiple Q/DQ schemes per pattern. Supports INT8 and FP8 quantization, pattern cache for warm-start on similar models, checkpoint/resume, and importing patterns from an existing QDQ baseline. CLI: ``python -m modelopt.onnx.quantization.autotune``. See the AutoQDQ guide in the documentation. +- Add ``get_auto_quantize_config`` API to extract a flat quantization config from ``auto_quantize`` search results, enabling re-quantization at different effective bit targets without re-running calibration. +- Improve ``auto_quantize`` checkpoint/resume: calibration state is now saved and restored across runs, avoiding redundant calibration when resuming a search. +- Add NemotronH MoE expert support in ``auto_quantize`` grouping and scoring rules. 
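A minimal usage sketch of the ``get_auto_quantize_config`` workflow described by the bullets above; it mirrors the docstring example added to ``model_quant.py`` later in this patch. The ``mtq`` alias, the ``calibrate_loop`` callable, and the elided ``auto_quantize`` arguments are placeholders, not part of the change itself:

.. code-block:: python

    import modelopt.torch.quantization as mtq

    # Run the search once; search_state records per-layer formats, scores, and costs.
    model, search_state = mtq.auto_quantize(model, constraints={"effective_bits": 4.8}, ...)

    # Re-solve for a different effective-bits target without re-running calibration or scoring.
    config = mtq.get_auto_quantize_config(search_state, constraints={"effective_bits": 5.5})

    # Apply the resulting flat config; the returned config defaults to algorithm="max".
    model = mtq.quantize(model, config, forward_loop=calibrate_loop)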
**Misc** diff --git a/modelopt/torch/opt/searcher.py b/modelopt/torch/opt/searcher.py index ab3930c207..49f43ef0b7 100644 --- a/modelopt/torch/opt/searcher.py +++ b/modelopt/torch/opt/searcher.py @@ -236,33 +236,50 @@ def state_dict(self) -> SearchStateDict: """The state dictionary that can be stored/loaded.""" return {key: getattr(self, key) for key in self.default_state_dict} - def load_search_checkpoint(self) -> bool: + def _get_checkpoint_path(self) -> str | None: + """Get per-rank checkpoint path when distributed, otherwise the original path.""" + checkpoint = self.config["checkpoint"] + if checkpoint is None: + return None + if dist.is_initialized(): + dirname, basename = os.path.split(checkpoint) + name, ext = os.path.splitext(basename) + return os.path.join(dirname, f"{name}{dist.rank()}{ext}") + return checkpoint + + def load_search_checkpoint(self, strict=True) -> bool: """Load function for search checkpoint returning indicator whether checkpoint was loaded.""" - # check if checkpoint exists - checkpoint: str | None = self.config["checkpoint"] + checkpoint = self._get_checkpoint_path() if checkpoint is None: return False + # Backward compat: fall back to the original single-file path + if not os.path.exists(checkpoint): + checkpoint = self.config["checkpoint"] if not os.path.exists(checkpoint): warn_rank_0(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.") return False - # iterate through state dict and load keys print_rank_0(f"Loading searcher state from {checkpoint}...") # Security NOTE: weights_only=False is used here on ModelOpt-generated ckpt, not on untrusted user input state_dict = torch.load(checkpoint, weights_only=False) - assert state_dict.keys() == self.state_dict().keys(), "Keys in checkpoint don't match!" - for key, state in state_dict.items(): - setattr(self, key, state) + if strict: + assert state_dict.keys() == self.state_dict().keys(), "Keys in checkpoint don't match!" + for key, default_val in self.default_state_dict.items(): + setattr(self, key, state_dict.get(key, default_val)) return True def save_search_checkpoint(self, verbose=False) -> None: """Save function for search checkpoint.""" - # check if save requirements are satisfied - checkpoint: str | None = self.config["checkpoint"] - if checkpoint is None or not dist.is_master(): + checkpoint = self._get_checkpoint_path() + if checkpoint is None: return - # save state dict + if dist.is_initialized(): + warn_rank_0( + "torch.distributed is initialized. Please maintain the same parallelism " + "configuration (world size, TP, EP, etc.) across search save and restore sessions." 
+ ) + if verbose: print(f"Saving searcher state to {checkpoint}...") save_dirname, _ = os.path.split(checkpoint) diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index 7c5d5e6dbb..94075ce6b9 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -28,6 +28,7 @@ import regex as re import torch import torch.nn as nn +import torch.nn.functional as F from tqdm import tqdm from modelopt.torch.opt.conversion import ModeloptStateManager @@ -160,7 +161,7 @@ def fold_pqs_to_weights(model): model_calib._ENABLE_FOLDING_PQS_TO_WEIGHTS = True for name, module in model.named_modules(): if is_quantized_linear(module): - with SequentialQuantizer.convert_to_single_quantizer(model): + with SequentialQuantizer.convert_to_single_quantizer(module): if module.weight_quantizer.pre_quant_scale is not None: weight_pqs = module.weight_quantizer.pre_quant_scale delattr(module.weight_quantizer, "_pre_quant_scale") @@ -184,12 +185,14 @@ def __init__( quant_modules: list[nn.Module] | None = None, score_modules: list[nn.Module] | None = None, name: str | None = None, + quant_module_names: list[str] | None = None, ) -> None: """Initializes Hparam with original value and choices.""" choices = sorted({*(choices if choices else []), QuantRecipe(quant_cfg=None)}) super().__init__(choices, original=choices[0]) self.name = name + self.quant_module_names = quant_module_names or [] self.quant_modules = list(set(quant_modules or [])) self.score_modules = list(set(score_modules or self.quant_modules)) @@ -333,11 +336,14 @@ class _AutoQuantizeBaseSearcher(BaseSearcher, ABC): candidate_stats: dict[str, dict[str, list[float]]] best: dict[str, Any] + quantizer_states: dict + method_name: str = None quant_grouping_rules = [ r"^(.*?)\.(q_proj|k_proj|v_proj)$", # q_proj, k_proj, v_proj for llama like models # gate_proj, up_proj, down_proj for Qwen3 like MoE models r"^(.*?\.mlp\.experts)\.\d+\.(gate_proj|up_proj|down_proj)$", + r"^(.*?\.mixer\.experts)\.\d+\.(up_proj|down_proj)$", # NemotronH MoE experts r"^(.*?)\.(gate_proj|up_proj)$", # gate_proj, up_proj for llama like models r"^(.*?)\.(\d+\.(w1|w2|w3))$", # mixtral experts r"^(.*?)\.((w1_linear|w2_linear|w3_linear)\.\d+)$", # dbrx experts @@ -363,7 +369,9 @@ def default_search_config(self): def default_state_dict(self) -> SearchStateDict: """Get the default state dict for AutoQuantize.""" return { + "method": self.method_name, "candidate_stats": defaultdict(dict), + "quantizer_states": {}, "best": {"recipe": {}, "constraints": {}, "score": float("inf"), "is_satisfied": False}, } @@ -379,6 +387,9 @@ def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig: ) return config + def load_search_checkpoint(self) -> bool: + return super().load_search_checkpoint(strict=False) + @staticmethod def _is_auto_quantize_module(module): return ( @@ -517,12 +528,13 @@ def insert_hparams_after_merge_rules(self, model, quant_recipes, disabled_layers disabled = any(disabled for _, _, disabled, _ in module_info_list) score_modules = [score_module for _, _, _, score_module in module_info_list] - quant_recipes = None if disabled else quant_recipes + _quant_recipes = None if disabled else quant_recipes hparam = QuantRecipeHparam( - quant_recipes, + _quant_recipes, quant_modules=quant_modules, score_modules=score_modules, name=str(group_key), + quant_module_names=[name for _, name, _, _ in module_info_list], ) for module in quant_modules: @@ -570,6 +582,7 @@ def 
initialize_candidate_stats(self): self.candidate_stats[name]["formats"] = formats self.candidate_stats[name]["scores"] = scores self.candidate_stats[name]["costs"] = costs + self.candidate_stats[name]["module_names"] = hparam.quant_module_names def _run_func(self, func, num_iters=1, desc=""): for i, data in tqdm( @@ -584,7 +597,17 @@ def before_search(self): # Import here to avoid circular import from modelopt.torch.quantization.model_quant import calibrate + from .conversion import restore_quantizer_state, update_quantize_metadata + from .utils import get_quantizer_state_dict, set_quantizer_state_dict + super().before_search() + restored_method = getattr(self, "method", None) + if self.candidate_stats and restored_method not in (None, self.method_name): + raise ValueError( + f"Checkpoint method '{restored_method}' does not match current method " + f"'{self.method_name}'. Use a different checkpoint path." + ) + self.method = self.method_name search_recipes = self._get_search_recipes(self.config["quantization_formats"]) self._verify_constraint(search_recipes) @@ -595,10 +618,27 @@ def before_search(self): QuantRecipe.disable_folding_pqs_to_weights() # Iterate over the search recipes and calibrate the quantizers for each recipe + calibrated_new = False for recipe in search_recipes: if recipe == QuantRecipe(quant_cfg=None): # No-quant format continue + for name, hparam in named_hparams(self.model, configurable=True): + if not isinstance(hparam, QuantRecipeHparam): + continue + hparam.active = recipe + + if recipe in self.quantizer_states: + saved = self.quantizer_states[recipe] + # config is unused by restore_quantizer_state + restore_quantizer_state( + self.model, QuantizeConfig(), {"quantizer_state": saved["metadata"]} + ) + set_quantizer_state_dict(self.model, saved["state_dict"]) + if self.config["verbose"]: + print_rank_0(f"AutoQuantize: Restored calibration for {recipe}") + continue + # Lets reduce the number of calibration steps for AWQ since it takes longer num_calib_steps = ( self.config["num_calib_steps"] @@ -613,12 +653,6 @@ def forward_loop(model): desc=f"Calibrating for {recipe}", ) - for name, hparam in named_hparams(self.model, configurable=True): - if not isinstance(hparam, QuantRecipeHparam): - continue - hparam.active = recipe - - # Now calibrate the quantizers for the recipe calibrate( self.model, algorithm=recipe.config.algorithm, @@ -628,6 +662,17 @@ def forward_loop(model): # across layers, lets not save this new mode in the modelopt state. # TODO: This is a hack. We need to create a mode for auto_quantize to handle this in a clean way. 
ModeloptStateManager(self.model).state_dict().pop() + metadata: dict = {} + # config is unused by update_quantize_metadata + update_quantize_metadata(self.model, QuantizeConfig(), metadata) + self.quantizer_states[recipe] = { + "metadata": metadata["quantizer_state"], + "state_dict": get_quantizer_state_dict(self.model), + } + calibrated_new = True + + if calibrated_new: + self.save_search_checkpoint(verbose=self.config["verbose"]) if self.candidate_stats: if self.config["verbose"]: @@ -636,9 +681,16 @@ def forward_loop(model): self.estimate_sensitivity_scores() self.initialize_candidate_stats() - # Save checkpoint after successful score estimation self.save_search_checkpoint(verbose=self.config["verbose"]) + @staticmethod + def _print_recipe_summary(best_recipe, total_cost, total_weight_size, prefix="AutoQuantize"): + for name, recipe in best_recipe.items(): + print_rank_0(f"{prefix} best recipe for {name.replace('.quant_recipe', '')}: {recipe}") + effective_bits = (total_cost / total_weight_size) * 16 + print_rank_0(f"{prefix} effective bits: {effective_bits:.2f}") + return effective_bits + @staticmethod def _get_total_weight_size(modules): return sum( @@ -695,16 +747,13 @@ def run_search(self): get_hparam(self.model, name).active = best_format best_constraints += best_hparam_recipe_info["costs"] best_scores += best_hparam_recipe_info["scores"] - if verbose: - print_rank_0( - f"AutoQuantize best recipe for {name.replace('.quant_recipe', '')}: {best_recipe[name]}" - ) - effective_bits_from_search = (best_constraints / total_weight_size) * 16 if verbose: - print_rank_0( - f"AutoQuantize effective bits from search: {effective_bits_from_search: .2f}" + effective_bits_from_search = self._print_recipe_summary( + best_recipe, best_constraints, total_weight_size ) + else: + effective_bits_from_search = (best_constraints / total_weight_size) * 16 self.best["recipe"] = best_recipe self.best["constraints"] = {"effective_bits": effective_bits_from_search} @@ -753,9 +802,12 @@ class AutoQuantizeGradientSearcher(_AutoQuantizeBaseSearcher): can be estimated together at a single point (e.g., the MLP output level). 
""" + method_name = "gradient" + score_module_rules = [ # Use MLP layer output for gate_proj, up_proj, down_proj for Qwen3 like MoE models (local and shared experts) r"^(.*?\.mlp)\.experts\.\d+\.(gate_proj|up_proj|down_proj)$", + r"^(.*?\.mixer)\.experts\.\d+\.(up_proj|down_proj)$", # NemotronH MoE experts r"^(.*?)\.(\d+\.(w1|w2|w3))$", # mixtral experts r"^(.*?)\.((w1_linear|w2_linear|w3_linear)\.\d+)$", # dbrx experts ] @@ -1025,9 +1077,7 @@ def run_search_with_stats(self, max_weight_size, verbose=False): # TODO: Enable torch compile for this function # Currently modelopt.onnx is breaking this -def _get_softmax_dist( - logits: torch.Tensor, tp_group, return_log_prob: bool = False -) -> torch.Tensor: +def _get_log_softmax_dist(logits: torch.Tensor, tp_group) -> torch.Tensor: # TODO: test this dtype = logits.dtype max_logits = torch.amax(logits, dim=-1, keepdim=True) @@ -1035,46 +1085,23 @@ def _get_softmax_dist( logits = (logits - max_logits).float() sum_exp_logits = torch.exp(torch.logsumexp(logits, dim=-1, keepdim=True)) torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=tp_group) - logits = logits - torch.log(sum_exp_logits) - if return_log_prob: - return logits.to(dtype) - else: - return torch.exp(logits).to(dtype) + return (logits - torch.log(sum_exp_logits)).to(dtype) -def _get_softmax(logits: torch.Tensor, return_log_prob: bool = False) -> torch.Tensor: - # TODO: do we need to do log_softmax in float32? - # log_softmax is supposed to be numerically stable implementation - log_prob = torch.log_softmax(logits.float(), dim=-1) - if return_log_prob: - return log_prob - else: - return torch.exp(log_prob) - - -def _get_p_log_q(p: torch.Tensor, log_q: torch.Tensor) -> torch.Tensor: - return torch.sum(p * log_q).float() - - -def _get_prob_from_logits( - logits: torch.Tensor, return_log_prob: bool = False, lm_head: nn.Module = None -) -> torch.Tensor: +def _get_log_prob(logits: torch.Tensor, lm_head: nn.Module = None) -> torch.Tensor: parallel_state: ParallelState | None = ( getattr(lm_head, "parallel_state", None) if lm_head is not None else None ) if parallel_state is not None and parallel_state.tensor_parallel_group.is_initialized(): - return _get_softmax_dist( - logits, parallel_state.tensor_parallel_group.group, return_log_prob - ) - return _get_softmax(logits, return_log_prob) + return _get_log_softmax_dist(logits, parallel_state.tensor_parallel_group.group) + return torch.log_softmax(logits.float(), dim=-1) def _get_kl_div_loss( - prob_unquant: torch.Tensor, logits_quant: torch.Tensor, lm_head: nn.Module = None + log_prob_unquant: torch.Tensor, logits_quant: torch.Tensor, lm_head: nn.Module = None ) -> torch.Tensor: - log_prob_quant = _get_prob_from_logits(logits_quant, return_log_prob=True, lm_head=lm_head) - # We dont need to calculate the full kl div loss here, just get - p*log_q - return -_get_p_log_q(prob_unquant, log_prob_quant) + log_prob_quant = _get_log_prob(logits_quant, lm_head=lm_head) + return F.kl_div(log_prob_quant, log_prob_unquant, reduction="sum", log_target=True) def _get_lm_head(model: nn.Module) -> nn.Module: @@ -1090,6 +1117,8 @@ def _get_lm_head(model: nn.Module) -> nn.Module: class AutoQuantizeKLDivSearcher(_AutoQuantizeBaseSearcher): """A searcher for AutoQuantize algorithm that uses KL-Divergence loss based score estimation.""" + method_name = "kl_div" + @property def default_search_config(self): """Get the default config for the searcher.""" @@ -1141,11 +1170,7 @@ def set_to_unquantized(): ): set_to_unquantized() 
logits_unquant = self.config["forward_step"](self.model, data) - prob_unquant = _get_prob_from_logits( - logits_unquant, - return_log_prob=False, - lm_head=_get_lm_head(self.model), - ) + log_prob_unquant = _get_log_prob(logits_unquant, lm_head=_get_lm_head(self.model)) for name, hparam in tqdm( list(named_hparams(self.model, configurable=True)), desc="Evaluating hparams" @@ -1157,7 +1182,9 @@ def set_to_unquantized(): continue hparam.active = recipe logits_quant = self.config["forward_step"](self.model, data) - score = _get_kl_div_loss(prob_unquant, logits_quant, _get_lm_head(self.model)) + score = _get_kl_div_loss( + log_prob_unquant, logits_quant, _get_lm_head(self.model) + ) if hparam._importance_dict[recipe][hparam.score_modules[0]] is None: hparam._importance_dict[recipe][hparam.score_modules[0]] = score else: @@ -1252,3 +1279,85 @@ def run_search_with_stats(self, max_weight_size, verbose=False): # Backward compatibility alias (defaults to gradient-based searcher) AutoQuantizeSearcher = AutoQuantizeGradientSearcher + + +def get_auto_quantize_config(search_state, constraints=None, verbose=False): + """Build a flat quant config dict from auto_quantize search_state. + + Re-solves for ``constraints`` if provided, otherwise uses the best recipe from the search. + + Args: + search_state: The state dict returned by :func:`auto_quantize`. + constraints: Optional dict with ``effective_bits`` key to re-solve for a new target. + verbose: If True, prints the per-layer recipe assignments. + + Returns: + A config dict suitable for :func:`quantize`. + """ + if constraints is not None: + best_recipe = _resolve_best_recipe(search_state, constraints, verbose=verbose) + else: + best_recipe = search_state["best"]["recipe"] + + quant_cfg: dict[str, Any] = {"*": {"enable": False}} + for hparam_name, recipe in best_recipe.items(): + if recipe == QuantRecipe(quant_cfg=None): + continue + module_names = search_state["candidate_stats"][hparam_name]["module_names"] + for module_name in module_names: + for quantizer_attr in ("input_quantizer", "weight_quantizer"): + matched_cfg = _match_quantizer_cfg(recipe.config.quant_cfg, quantizer_attr) + if matched_cfg is not None: + quant_cfg[f"{module_name}.{quantizer_attr}"] = matched_cfg + + def _cfg_to_dict(v): + if isinstance(v, mtq_config.QuantizerAttributeConfig): + return { + "enable": v.enable, + "num_bits": v.num_bits, + **v.model_dump(exclude_defaults=True), + } + if isinstance(v, list): + return [_cfg_to_dict(c) for c in v] + return v + + quant_cfg = {k: _cfg_to_dict(v) for k, v in quant_cfg.items()} + return {"quant_cfg": quant_cfg, "algorithm": "max"} + + +def _resolve_best_recipe(search_state, constraints, verbose=False): + effective_bits = constraints["effective_bits"] + compression = effective_bits / 16.0 + candidate_stats = search_state["candidate_stats"] + total_weight_size = sum(s["costs"][-1] for s in candidate_stats.values()) + max_weight_size = total_weight_size * compression + method = search_state["method"] + + if method == "gradient": + searcher = AutoQuantizeGradientSearcher() + elif method == "kl_div": + searcher = AutoQuantizeKLDivSearcher() + else: + raise ValueError( + f"Unknown autoquant search method: {method!r}. Expected 'gradient' or 'kl_div'." 
+ ) + + searcher.candidate_stats = candidate_stats + best_recipe_info, _ = searcher.run_search_with_stats(max_weight_size, verbose=verbose) + + best_recipe = {name: info["format"] for name, info in best_recipe_info.items()} + if verbose: + total_cost = sum(info["costs"] for info in best_recipe_info.values()) + _AutoQuantizeBaseSearcher._print_recipe_summary( + best_recipe, total_cost, total_weight_size, prefix="get_auto_quantize_config" + ) + + return best_recipe + + +def _match_quantizer_cfg(quant_cfg, quantizer_attr): + matched = None + for pattern, cfg in quant_cfg.items(): + if fnmatch.fnmatch(quantizer_attr, pattern): + matched = cfg + return matched diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 70f036a8d6..5618fa413f 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -66,12 +66,13 @@ def weight_only_quantize(model: nn.Module): """Just quantize the weights of the model.""" + name_to_module = dict(model.named_modules()) seen_modules = set() - for name, module in model.named_modules(): + for module in name_to_module.values(): if module in seen_modules: continue for weight_name in weight_attr_names(module): - with enable_weight_access_and_writeback(module, model): + with enable_weight_access_and_writeback(module, model, name_to_module): weight_quantizer = getattr( module, quantizer_attr_names(weight_name).weight_quantizer ) @@ -348,12 +349,6 @@ def mse_calibrate( ) continue - if fp8_scale_sweep and not is_nvfp4_static: - warnings.warn( - f"fp8_scale_sweep is enabled but quantizer '{name}' is not NVFP4 static " - "block quantization. fp8_scale_sweep will be ignored for this quantizer." - ) - # Create MSE calibrator with quant_func module._calibrator = MseCalibrator( amax=initial_amax, @@ -365,7 +360,8 @@ def mse_calibrate( ) # Identify weight quantizers by checking if they have corresponding weight parameters - for name, parent_module in model.named_modules(): + name_to_module = dict(model.named_modules()) + for parent_module in name_to_module.values(): if parent_module in seen_modules: continue for weight_name in weight_attr_names(parent_module): @@ -384,7 +380,7 @@ def mse_calibrate( # Enable calibration mode for the weight quantizer weight_quantizer.disable_quant() weight_quantizer.enable_calib() - with enable_weight_access_and_writeback(parent_module, model): + with enable_weight_access_and_writeback(parent_module, model, name_to_module): weight = getattr(parent_module, weight_name) weight_quantizer(weight) @@ -901,8 +897,9 @@ def postprocess(module): scale_a = scale_a.clamp(min=1e-4, max=1e4) apply_pre_quant_scale_and_smooth(module, scale_a) + name_to_module = dict(model.named_modules()) smoothed_modules = 0 - for name, module in model.named_modules(): + for name, module in name_to_module.items(): if is_quantized_linear(module): if not hasattr(module.input_quantizer, "_amax"): warnings.warn(f"{name} is not calibrated, skip smoothing") @@ -918,7 +915,7 @@ def postprocess(module): f"Error: {name} has only one channel to smooth" ) - with enable_weight_access_and_writeback(module, model): + with enable_weight_access_and_writeback(module, model, name_to_module): postprocess(module) smoothed_modules += 1 @@ -1508,9 +1505,10 @@ def postprocess(module, name): create_and_replace_svdquant_linear_on_the_fly(model=model) awq(model, forward_loop, "awq_lite", **kwargs) - for name, module in model.named_modules(): + name_to_module = dict(model.named_modules()) + for name, module in 
name_to_module.items(): if is_quantized_linear(module) and module.weight_quantizer.is_enabled: - with enable_weight_access_and_writeback(module, model): + with enable_weight_access_and_writeback(module, model, name_to_module): postprocess(module, name) max_calibrate(model, forward_loop) diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index ff3ae567be..4aa1ff46b4 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -34,10 +34,12 @@ from modelopt.torch.utils import atomic_print from .algorithms import AutoQuantizeGradientSearcher, AutoQuantizeKLDivSearcher, QuantRecipe +from .algorithms import get_auto_quantize_config as _get_auto_quantize_config from .config import QuantizeAlgoCfgType from .conversion import set_quantizer_attribute from .mode import QuantizeModeRegistry, get_modelike_from_algo_cfg from .nn import QuantModule, TensorQuantizer +from .utils import is_quantized __all__ = [ "auto_quantize", @@ -46,6 +48,7 @@ "disable_quantizer", "enable_quantizer", "fold_weight", + "get_auto_quantize_config", "postprocess_amax", "print_quant_summary", "quantize", @@ -147,6 +150,9 @@ def quantize( performs calibration as specified by ``quant_cfg``. ``forward_loop`` is used to forward data through the model and gather statistics for calibration. + If the model is already quantized, the provided ``config`` is applied to the existing + quantizers and calibration is run. + Args: model: A pytorch model config: A dictionary or an instance of @@ -230,7 +236,12 @@ def forward_loop(model) -> None: Returns: A pytorch model which has been quantized and calibrated. """ - model = apply_mode(model, mode=[("quantize", config)], registry=QuantizeModeRegistry) + if not is_quantized(model): + model = apply_mode(model, mode=[("quantize", dict(config))], registry=QuantizeModeRegistry) + else: + # Already quantized, so lets apply the quant_cfg from the config + quant_cfg = QuantizeConfig(**dict(config)).quant_cfg + set_quantizer_by_cfg(model, quant_cfg) return calibrate(model, config.get("algorithm"), forward_loop=forward_loop) @@ -240,6 +251,19 @@ def forward_loop(model) -> None: # This way wecan limit the granularity of quantization search. For example, # - limit the quantization format search to decoder block level (instead of each linear layer level) # - Same format for all self attention layers of a model etc. + +_AUTO_QUANTIZE_SUPPORTED_ALGORITHMS = { + None, + "max", + "mse", + "local_hessian", + "smoothquant", + "awq_lite", + "awq_full", + "awq_clip", +} + + def auto_quantize( model: nn.Module, constraints: dict[str, float | str] = {"effective_bits": 4.8}, @@ -467,6 +491,16 @@ def forward_backward_step(model, batch) -> None: assert len(processed_quantization_formats) > 0, "`quantization_formats` should not be empty" + for quant_cfg, name in processed_quantization_formats: + algo = QuantRecipe(quant_cfg, name=name).config.algorithm + algo_method = algo["method"] if isinstance(algo, dict) else algo + if algo_method not in _AUTO_QUANTIZE_SUPPORTED_ALGORITHMS: + raise ValueError( + f"Algorithm '{algo_method}' in '{name}' is not supported by auto_quantize yet. " + "Please run auto_quantize with 'max' or 'mse' calibration and use " + "get_auto_quantize_config() to obtain a config for mtq.quantize()." 
+ ) + # Select the appropriate searcher based on method if method == "gradient": searcher = AutoQuantizeGradientSearcher() @@ -499,6 +533,45 @@ def forward_backward_step(model, batch) -> None: return model, searcher.state_dict() +def get_auto_quantize_config(search_state, constraints=None, verbose=False): + """Build a flat quant config from auto_quantize search_state. + + Re-solves for ``constraints`` if provided, otherwise uses the stored best recipe. + + Args: + search_state: The state dict returned by :func:`auto_quantize`. + constraints: Optional dict, e.g. ``{"effective_bits": 5.5}``, to re-solve for a + different target without re-running calibration or scoring. + verbose: If True, prints the per-layer recipe assignments. + + Returns: + A config dict suitable for :func:`quantize`. + + Example: + + .. code-block:: python + + model, search_state = mtq.auto_quantize(model, ...) + + # Re-solve for a different effective_bits target (cheap, no GPU needed) + config = mtq.get_auto_quantize_config(search_state, {"effective_bits": 5.5}) + + # Or use the original result + config = mtq.get_auto_quantize_config(search_state) + + # [Optional] Customize algorithm if needed + config["algorithm"] = {"method": "gptq_lite", "sequential": True} + + # Reuse on the same model (e.g. run a longer calibration pass) + model = mtq.quantize(model, config, forward_loop=calibrate_loop) + + # Or apply the same/customized config on a fresh model instance + # fresh_model = load_model(...) + # fresh_model = mtq.quantize(fresh_model, config, forward_loop=calibrate_loop) + """ + return _get_auto_quantize_config(search_state, constraints, verbose=verbose) + + def disable_quantizer(model: nn.Module, wildcard_or_filter_func: str | Callable): """Disable quantizer by wildcard or filter function.""" set_quantizer_attribute(model, wildcard_or_filter_func, {"enable": False}) diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 318646c394..85d653d182 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -1201,11 +1201,15 @@ def setup_model_for_gradient_checkpointing(model: nn.Module): "Disable gradient checkpointing after AutoQuantize if this is not desired!" ) model.gradient_checkpointing_enable({"use_reentrant": True}) - model.train() # Model needs to be in training mode to enable gradient checkpointing - # Set all dropout layers to eval mode for deterministic auto-quantize scores - for name, module in model.named_modules(): - if isinstance(model, torch.nn.Dropout): - module.eval() + for m in model.modules(): + if hasattr(m, "gradient_checkpointing"): + m.train() # Make sure the module is in training mode to enable gradient checkpointing + else: + # Eval mode for non-checkpointed modules to avoid fused kernels + # that bypass linear layers. E.g. in nemotron-h, the Mamba layer's + # training path uses a fused kernel that takes out_proj weights + # directly, skipping the linear module's forward (and thus quantization). 
+ m.eval() except Exception as e: warnings.warn( f"AutoQuantize: Error enabling gradient checkpointing for huggingface model due to: {e}, " diff --git a/modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu b/modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu index 38a66562f7..9db141f0cb 100644 --- a/modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu +++ b/modelopt/torch/quantization/src/tensor_quant_gpu_fp8.cu @@ -71,7 +71,10 @@ at::Tensor fake_e4m3fy_with_axis(at::Tensor inputs, at::Tensor amax, int axis) { int axis_size = inputs.size(axis); int outer_size = inputs.stride(axis); - auto scale = 448.f / amax; + const float epsilon = 1.0f / (1 << 24); + auto zero_amax_mask = amax <= epsilon; + auto safe_amax = at::where(zero_amax_mask, torch::ones_like(amax), amax); + auto scale = 448.f / safe_amax; auto inv_scale = 1.f / scale; auto stream = c10::cuda::getCurrentCUDAStream(); @@ -88,7 +91,10 @@ at::Tensor fake_e4m3fy(at::Tensor inputs, at::Tensor amax) { inputs = inputs.contiguous(); amax = amax.view(-1).to(at::kFloat); size_t numel = inputs.numel(); - at::Tensor scale = 448.f / amax; + const float epsilon = 1.0f / (1 << 24); + auto zero_amax_mask = amax <= epsilon; + auto safe_amax = at::where(zero_amax_mask, torch::ones_like(amax), amax); + at::Tensor scale = 448.f / safe_amax; auto inv_scale = 1.f / scale; auto outputs = torch::empty_like(inputs); auto stream = c10::cuda::getCurrentCUDAStream(); diff --git a/modelopt/torch/quantization/tensor_quant.py b/modelopt/torch/quantization/tensor_quant.py index d9b5839716..16b9d32997 100644 --- a/modelopt/torch/quantization/tensor_quant.py +++ b/modelopt/torch/quantization/tensor_quant.py @@ -45,9 +45,13 @@ def _fp8_eager(x, amax=None): dtype = x.dtype if amax is not None: - scale = 448.0 / (amax.to(torch.float32)) + amax = amax.to(torch.float32) + epsilon = 1.0 / (1 << 24) + zero_amax_mask = amax <= epsilon + safe_amax = torch.where(zero_amax_mask, torch.ones_like(amax), amax) + scale = 448.0 / safe_amax scale_inv = 1 / scale - x = x.to(torch.float32) * scale + x = (x.to(torch.float32) * scale).clamp(min=-448.0, max=448.0) x = x.to(torch.float8_e4m3fn) if amax is not None: x = x.to(torch.float32) * scale_inv diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index 3c0d5e4344..e8a0749b51 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -487,8 +487,11 @@ def enable_weight_access_and_writeback(module, root_model, name_to_module: dict Args: module: The module to access weights for. root_model: The root model containing the module. - name_to_module: Optional pre-computed dict mapping names to modules (for performance). - If not provided, will be computed on-the-fly. + name_to_module: Pre-computed ``dict(root_model.named_modules())``. Without this, + every call iterates ``root_model.named_modules()`` internally, leading to O(N^2) + total cost when called in a loop. This causes significant CPU overhead on large + models, particularly Sparse MoE architectures where each expert is typically + implemented as its own module. 
""" if _get_enclosing_fsdp_module(module, root_model, name_to_module) is not None: context = fsdp2_weight_access_and_writeback_context(module, root_model) diff --git a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py index e84b1a49ad..1a28d229f4 100644 --- a/tests/gpu/torch/quantization/test_tensor_quant_cuda.py +++ b/tests/gpu/torch/quantization/test_tensor_quant_cuda.py @@ -145,6 +145,19 @@ def test_e4m3_per_channel(self, axis): xq_test = tensor_quant.scaled_e4m3(x, amax, None, 4, 3) assert torch.allclose(xq_test, xq_ref) + @pytest.mark.parametrize("device", ["cuda", "cpu"]) + def test_zero_amax_is_finite(self, device): + x = torch.randn(4, 4, device=device, dtype=torch.float32) + amax = torch.zeros((1,), device=device) + xq = tensor_quant.scaled_e4m3(x, amax, None, 4, 3) + assert torch.isfinite(xq).all() + + def test_zero_amax_per_channel_is_finite(self): + x = torch.randn(2, 3, 4, device="cuda", dtype=torch.float32) + amax = torch.tensor([1.0, 0.0, 1.0], device="cuda").view(1, 3, 1) + xq = tensor_quant.scaled_e4m3(x, amax, None, 4, 3) + assert torch.isfinite(xq).all() + class Testfp4: @pytest.mark.skipif(get_cuda_ext_mx() is None, reason="cuda_ext_mx is not available") diff --git a/tests/unit/torch/quantization/test_autoquant.py b/tests/unit/torch/quantization/test_autoquant.py index 1a5cfee324..c0f049174e 100644 --- a/tests/unit/torch/quantization/test_autoquant.py +++ b/tests/unit/torch/quantization/test_autoquant.py @@ -208,6 +208,27 @@ def loss_func(output): assert not best_model.mlp.input_quantizer.is_enabled +def test_auto_quantize_disabled_layers_no_poison(): + """disabled_layers must only affect the matched layers, not all subsequent layer groups.""" + model = TransformerBlock() + + best_model, _ = mtq.auto_quantize( + model, + constraints={"effective_bits": 5.0}, + quantization_formats=[mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG], + data_loader=[model.get_input() for _ in range(2)], + forward_step=lambda model, batch: model(batch), + loss_func=lambda output, data: output.sum(), + disabled_layers=["*mlp*"], + num_calib_steps=2, + num_score_steps=2, + ) + + assert not best_model.mlp.input_quantizer.is_enabled + hparam = best_model.attn.q_proj.get_hparam("quant_recipe") + assert QuantRecipe(mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG) in hparam.choices + + INT4INT8_AWQ_CFG = { "quant_cfg": { "*weight_quantizer": [ @@ -260,12 +281,12 @@ def _test_data_parallel_auto_quantize(rank, size): lambda a: a[0], ) - print(f"rank {rank} search_history: {search_history}") - if search_history != search_history_rank0: - print(f"rank {rank} search_history_rank0: {search_history_rank0}") + # quantizer_states contains tensors which can't be compared with == + sh = {k: v for k, v in search_history.items() if k != "quantizer_states"} + sh0 = {k: v for k, v in search_history_rank0.items() if k != "quantizer_states"} # Assert that the costs, scores and searched recipes are the same across all ranks - assert search_history == search_history_rank0 + assert sh == sh0 assert search_history["best"]["is_satisfied"] @@ -390,6 +411,12 @@ def test_auto_quantize_checkpoint_resume(method, tmp_path, capsys): "Expected restore message when resuming from checkpoint" ) + # Verify method is correctly persisted in checkpoint and state dicts + saved = torch.load(checkpoint_path, weights_only=False) + assert saved["method"] == method + assert state_dict_1["method"] == method + assert state_dict_2["method"] == method + # Results should be identical when 
using same constraint assert state_dict_1["candidate_stats"] == state_dict_2["candidate_stats"] assert state_dict_1["best"]["recipe"] == state_dict_2["best"]["recipe"] @@ -397,3 +424,73 @@ def test_auto_quantize_checkpoint_resume(method, tmp_path, capsys): pytest.approx(state_dict_1["best"]["constraints"]["effective_bits"]) == state_dict_2["best"]["constraints"]["effective_bits"] ) + + # Verify calibration was also restored from checkpoint + assert "Restored calibration for" in captured.out + + # Verify quantizer_states is saved in checkpoint + assert "quantizer_states" in saved + assert len(saved["quantizer_states"]) > 0 + for recipe_state in saved["quantizer_states"].values(): + assert "metadata" in recipe_state + assert "state_dict" in recipe_state + + # Verify resumed model produces identical quantizer_states + assert state_dict_1["quantizer_states"].keys() == state_dict_2["quantizer_states"].keys() + for recipe in state_dict_1["quantizer_states"]: + s1 = state_dict_1["quantizer_states"][recipe] + s2 = state_dict_2["quantizer_states"][recipe] + # Verify metadata (quantizer properties + tensor shape/dtype info) match per quantizer + assert s1["metadata"].keys() == s2["metadata"].keys() + for qname in s1["metadata"]: + assert s1["metadata"][qname] == s2["metadata"][qname], ( + f"Metadata mismatch for {qname} in {recipe}" + ) + # Verify actual tensor values match per quantizer + assert s1["state_dict"].keys() == s2["state_dict"].keys() + for qname in s1["state_dict"]: + for buf_name in s1["state_dict"][qname]: + torch.testing.assert_close( + s1["state_dict"][qname][buf_name], s2["state_dict"][qname][buf_name] + ) + + +@pytest.mark.parametrize("method", ["gradient", "kl_div"]) +def test_get_auto_quantize_config(method): + model = TransformerBlock() + + _, search_state = mtq.auto_quantize( + model, + constraints={"effective_bits": 6.0}, + quantization_formats=[mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG], + data_loader=[model.get_input() for _ in range(4)], + forward_step=lambda model, batch: model(batch), + loss_func=lambda output, data: output.sum(), + num_calib_steps=2, + num_score_steps=2, + method=method, + ) + + # Verify search_state has method and module_names + assert search_state["method"] == method + for stats in search_state["candidate_stats"].values(): + assert "module_names" in stats + assert len(stats["module_names"]) > 0 + + # Use stored best recipe + config = mtq.get_auto_quantize_config(search_state) + assert "quant_cfg" in config + assert config["quant_cfg"]["*"] == {"enable": False} + assert config["algorithm"] == "max" + + # Re-solve with different constraints + config_resoled = mtq.get_auto_quantize_config( + search_state, constraints={"effective_bits": 12.0} + ) + assert "quant_cfg" in config_resoled + + # Apply config to a fresh model + fresh_model = TransformerBlock() + fresh_model = mtq.quantize(fresh_model, config, forward_loop=lambda m: m(model.get_input())) + output = fresh_model(model.get_input()) + assert output is not None diff --git a/tests/unit/torch/quantization/test_quantize_cpu.py b/tests/unit/torch/quantization/test_quantize_cpu.py index 43233b3239..641eafd2ff 100644 --- a/tests/unit/torch/quantization/test_quantize_cpu.py +++ b/tests/unit/torch/quantization/test_quantize_cpu.py @@ -267,3 +267,19 @@ def test_block_sizes_axis_model(): if hasattr(module, "weight_quantizer"): assert name_ref == name assert torch.allclose(module_ref.weight_quantizer.amax, module.weight_quantizer.amax) + + +def test_quantize_twice(): + """Test that calling 
mtq.quantize twice on the same model works.""" + model = SimpleLinear() + inputs = model.get_input() + + def forward_loop(model): + return model(inputs) + + model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop=forward_loop) + out1 = model(inputs) + model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop=forward_loop) + out2 = model(inputs) + + assert torch.allclose(out1, out2), "Re-quantization with same config should be idempotent" From c2cb8ec2b41154d255488645973982a217648cb4 Mon Sep 17 00:00:00 2001 From: realAsma Date: Tue, 10 Mar 2026 18:11:40 +0000 Subject: [PATCH 2/3] Address PR review feedback: type annotation, checkpoint warning, comment Co-Authored-By: Claude Opus 4.6 Signed-off-by: realAsma --- modelopt/torch/opt/searcher.py | 4 ++++ modelopt/torch/quantization/algorithms.py | 3 ++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/opt/searcher.py b/modelopt/torch/opt/searcher.py index 49f43ef0b7..7714ca614d 100644 --- a/modelopt/torch/opt/searcher.py +++ b/modelopt/torch/opt/searcher.py @@ -254,6 +254,10 @@ def load_search_checkpoint(self, strict=True) -> bool: return False # Backward compat: fall back to the original single-file path if not os.path.exists(checkpoint): + warn_rank_0( + f"Per-rank checkpoint {checkpoint} not found, falling back to " + f"{self.config['checkpoint']}. Ensure world size matches the original run." + ) checkpoint = self.config["checkpoint"] if not os.path.exists(checkpoint): warn_rank_0(f"Checkpoint {checkpoint} does not exist! Initializing from scratch.") diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index 94075ce6b9..5377040451 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -337,7 +337,7 @@ class _AutoQuantizeBaseSearcher(BaseSearcher, ABC): candidate_stats: dict[str, dict[str, list[float]]] best: dict[str, Any] quantizer_states: dict - method_name: str = None + method_name: str | None = None quant_grouping_rules = [ r"^(.*?)\.(q_proj|k_proj|v_proj)$", # q_proj, k_proj, v_proj for llama like models @@ -1356,6 +1356,7 @@ def _resolve_best_recipe(search_state, constraints, verbose=False): def _match_quantizer_cfg(quant_cfg, quantizer_attr): + # Last-match-wins to mirror set_quantizer_by_cfg behavior matched = None for pattern, cfg in quant_cfg.items(): if fnmatch.fnmatch(quantizer_attr, pattern): From 67b2dc2149a7b60cc03b26ade0c71271a725137c Mon Sep 17 00:00:00 2001 From: realAsma Date: Tue, 10 Mar 2026 20:05:35 +0000 Subject: [PATCH 3/3] Add warning in get_auto_quantize_config about algorithm='max' default Co-Authored-By: Claude Opus 4.6 Signed-off-by: realAsma --- modelopt/torch/quantization/algorithms.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index 5377040451..339e9d0bb9 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -1322,6 +1322,11 @@ def _cfg_to_dict(v): return v quant_cfg = {k: _cfg_to_dict(v) for k, v in quant_cfg.items()} + warnings.warn( + "get_auto_quantize_config: returned config uses algorithm='max'. " + "Per-recipe calibration algorithms (e.g. smoothquant, awq) are not preserved. " + "Update config['algorithm'] if a different calibration algorithm is needed (e.g. 'gptq')." + ) return {"quant_cfg": quant_cfg, "algorithm": "max"}
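For reference, a small standalone sketch of the per-rank checkpoint naming introduced by ``_get_checkpoint_path`` in ``searcher.py``, since the save/restore semantics span the first two patches above. Only the path logic is reproduced; the helper name ``per_rank_checkpoint_path`` and the example paths are illustrative, not part of the patch:

.. code-block:: python

    import os

    def per_rank_checkpoint_path(checkpoint: str, rank: int) -> str:
        """Mirror of _get_checkpoint_path: insert the rank number before the file extension."""
        dirname, basename = os.path.split(checkpoint)
        name, ext = os.path.splitext(basename)
        return os.path.join(dirname, f"{name}{rank}{ext}")

    # With torch.distributed initialized, every rank saves and restores its own file,
    # so the parallelism configuration must match between the save and restore runs:
    assert per_rank_checkpoint_path("/tmp/search.pth", 0) == "/tmp/search0.pth"
    assert per_rank_checkpoint_path("/tmp/search.pth", 1) == "/tmp/search1.pth"
    # Without torch.distributed, the original single-file path ("/tmp/search.pth") is used,
    # which is also the backward-compat fallback when a per-rank file is missing.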