From 6bca67e07295d698ae436443e1375183b0af4d09 Mon Sep 17 00:00:00 2001
From: weimingc <17592131+meenchen@users.noreply.github.com>
Date: Fri, 1 May 2026 15:20:08 -0700
Subject: [PATCH] autoquantize for VLM

Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
---
 examples/llm_ptq/hf_ptq.py                | 154 ++++++++++++++++++++--
 modelopt/torch/export/quant_utils.py      |  16 ++-
 modelopt/torch/quantization/algorithms.py | 136 ++++++++++++++++---
 3 files changed, 272 insertions(+), 34 deletions(-)
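The two recipes registered in hf_ptq.py below are weight-only: nvfp4_w4a16 block-quantizes weights to NVFP4 and leaves activations at 16 bits, while fp8_w8a16 uses per-tensor FP8 weights. As a back-of-envelope sketch (my arithmetic, not part of the patch; it ignores NVFP4's second-level per-tensor scale), the effective weight width of nvfp4_w4a16 works out to:

    # num_bits=(2, 1) is FP4 (E2M1); scale_bits=(4, 3) is an 8-bit E4M3
    # block scale shared by block_sizes[-1] = 16 weights.
    weight_bits = 4
    scale_bits = 8
    block_size = 16
    print(weight_bits + scale_bits / block_size)  # 4.5 effective bits per weight
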
diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index d660c1de4c8..08934412866 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -15,6 +15,7 @@
 
 import argparse
 import copy
+import os
 import random
 import time
 import warnings
@@ -137,6 +138,43 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
 mto.enable_huggingface_checkpointing()
 
 
+NVFP4_W4A16_CFG = {
+    "quant_cfg": [
+        {"quantizer_name": "*", "enable": False},
+        {
+            "quantizer_name": "*weight_quantizer",
+            "cfg": {
+                "num_bits": (2, 1),
+                "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
+            },
+        },
+        {"quantizer_name": "*input_quantizer", "enable": False},
+        *_default_disabled_quantizer_cfg,
+    ],
+    "algorithm": "max",
+}
+
+FP8_W8A16_CFG = {
+    "quant_cfg": [
+        {"quantizer_name": "*", "enable": False},
+        {
+            "quantizer_name": "*weight_quantizer",
+            "cfg": {"num_bits": (4, 3), "axis": None},
+        },
+        {"quantizer_name": "*input_quantizer", "enable": False},
+        *_default_disabled_quantizer_cfg,
+    ],
+    "algorithm": "max",
+}
+
+QUANT_CFG_CHOICES.update(
+    {
+        "nvfp4_w4a16": NVFP4_W4A16_CFG,
+        "fp8_w8a16": FP8_W8A16_CFG,
+    }
+)
+
+
 def extract_and_prepare_language_model_from_vl(full_model):
     """Extract language model from VL model and disable quantization for non-language components.
@@ -326,6 +364,8 @@ def auto_quantize(
             "nvfp4_omlp_only",
             "nvfp4_local_hessian",
             "mxfp8",
+            "nvfp4_w4a16",
+            "fp8_w8a16",
         ]
         for qformat in qformat_list
     ), "One or more quantization formats provided are not supported for unified checkpoint export"
@@ -348,6 +388,38 @@ def forward_step(model, batch):
             f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
         )
 
+    # Let AutoQuantize search lm_head, but exclude modules that vLLM either
+    # constructs as BF16-only paths today or dispatches through fused kernels
+    # with known-unsafe quantized paths.
+    disabled_layers = [
+        entry["quantizer_name"]
+        for entry in _default_disabled_quantizer_cfg
+        if "parent_class" not in entry and entry["quantizer_name"] != "*lm_head*"
+    ]
+    enable_linear_attn_big3 = os.environ.get("MODELOPT_AUTOQ_ENABLE_LINEAR_ATTN_BIG3") == "1"
+    enable_shared_expert = os.environ.get("MODELOPT_AUTOQ_ENABLE_SHARED_EXPERT") == "1"
+    autoq_extra_disabled = [
+        "*shared_expert_gate*",
+        "*linear_attn.in_proj_a*",
+        "*linear_attn.in_proj_b*",
+    ]
+    if not enable_shared_expert:
+        autoq_extra_disabled.append("*mlp.shared_expert*")
+    if not enable_linear_attn_big3:
+        autoq_extra_disabled.extend(
+            [
+                "*linear_attn.in_proj_qkv*",
+                "*linear_attn.in_proj_z*",
+                "*linear_attn.out_proj*",
+            ]
+        )
+    for pat in autoq_extra_disabled:
+        if pat not in disabled_layers:
+            disabled_layers.append(pat)
+    if is_multimodal_model(language_model):
+        for pat in ("*visual*", "*mtp*", "*vision_tower*"):
+            if pat not in disabled_layers:
+                disabled_layers.append(pat)
+
     language_model, _ = mtq.auto_quantize(
         language_model,
         constraints={"effective_bits": args.auto_quantize_bits},
@@ -362,12 +434,7 @@ def forward_step(model, batch):
             len(calib_dataloader), max(auto_quantize_score_size // args.batch_size, 1)
         ),
         verbose=True,
-        # Disable all default disabled layers such as lm_head, mlp.gate, router etc.
-        disabled_layers=[
-            entry["quantizer_name"]
-            for entry in _default_disabled_quantizer_cfg
-            if "parent_class" not in entry
-        ],
+        disabled_layers=disabled_layers,
         method=auto_quantize_method,
         checkpoint=auto_quantize_checkpoint,
     )
@@ -507,12 +574,26 @@ def load_model(args: argparse.Namespace):
         ]
 
     # We only quantize the language model for VLMs other than the type supported above.
-    extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(
-        full_model
-    )
-    if extracted_lm is not None:
-        language_model = extracted_lm
-        model_type = extracted_model_type
+    # For AutoQuantize, skip the eager visual-disable side effect: it
+    # registers ``modelopt`` state on each visual sibling, which
+    # ``mtq.auto_quantize → apply_mode → is_converted`` then trips on
+    # ("Model has multiple modelopt states!"). AutoQuantize handles
+    # visual/mtp via ``disabled_layers`` patterns instead, so the
+    # extraction is unnecessary for that path.
+    #
+    # For ``--recipe`` mode on a VLM, lm_head sits on the OUTER
+    # CausalLM. Recipe rules can't see it via the inner language
+    # backbone, so we keep ``language_model = full_model`` here and
+    # let ``quantize_main`` strip visual/mtp siblings around
+    # ``mtq.quantize`` (so registration/calibration stays fast and
+    # batch_size auto-detect doesn't collapse to 1).
+    if args.auto_quantize_bits is None and args.recipe is None:
+        extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(
+            full_model
+        )
+        if extracted_lm is not None:
+            language_model = extracted_lm
+            model_type = extracted_model_type
 
     tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code)
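The exclusion patterns assembled in auto_quantize above are shell-style wildcards matched against dotted module names. A minimal, self-contained sketch of how they behave, assuming fnmatch-style matching (the module names are made up):

    from fnmatch import fnmatch

    disabled = ["*visual*", "*mtp*", "*linear_attn.in_proj_a*"]
    names = [
        "model.visual.blocks.0.mlp.fc1",         # visual tower -> excluded
        "model.layers.3.linear_attn.in_proj_a",  # unsafe fused shard -> excluded
        "model.layers.3.self_attn.q_proj",       # ordinary linear -> searched
    ]
    for name in names:
        skip = any(fnmatch(name, pat) for pat in disabled)
        print(name, "->", "skip" if skip else "search")
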
@@ -628,6 +709,32 @@ def mono_quantize(
         else None,
     )
 
+    # When ``--recipe`` is given on a VLM we keep ``language_model =
+    # full_model`` (so recipe rules can match lm_head), but ``mtq.quantize``
+    # would then walk and register quantizers on every Linear in the
+    # visual encoder + MTP head. Strip those siblings here and graft them
+    # back right after quantization.
+    stripped_vlm_modules: dict[str, torch.nn.Module] = {}
+    if args.recipe is not None and language_model is full_model:
+        for path in ("model.visual", "mtp"):
+            parts = path.split(".")
+            parent = full_model
+            ok = True
+            for p in parts[:-1]:
+                if not hasattr(parent, p):
+                    ok = False
+                    break
+                parent = getattr(parent, p)
+            if ok and hasattr(parent, parts[-1]):
+                mod = getattr(parent, parts[-1])
+                if mod is not None and isinstance(mod, torch.nn.Module):
+                    stripped_vlm_modules[path] = mod
+                    setattr(parent, parts[-1], None)
+    if stripped_vlm_modules:
+        print(
+            "[recipe] stripped VLM siblings before mtq.quantize: "
+            + ", ".join(stripped_vlm_modules.keys())
+        )
+
     if calibration_only:
         language_model = mtq.calibrate(
             language_model, quant_cfg["algorithm"], forward_loop=calibrate_loop
@@ -635,6 +742,19 @@
     else:
         language_model = mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop)
 
+    # Restore stripped VLM siblings so export sees the full model.
+    for path, mod in stripped_vlm_modules.items():
+        parts = path.split(".")
+        parent = full_model
+        for p in parts[:-1]:
+            parent = getattr(parent, p)
+        setattr(parent, parts[-1], mod)
+    if stripped_vlm_modules:
+        print(
+            "[recipe] restored VLM siblings after mtq.quantize: "
+            + ", ".join(stripped_vlm_modules.keys())
+        )
+
     # For VL models, update full_model to use the quantized language model
     if is_nemotron_vl_model:
         language_model_lineage = get_language_model_from_vl(full_model)
@@ -1018,10 +1138,18 @@ def quantize_main(
             "Auto quantization needs multiple quantization format."
         )
 
+        # For VL models, autoquant must walk submodules of the OUTER CausalLM
+        # (which carries lm_head and the LM-head forward path) — otherwise
+        # lm_head and any sibling-of-language_model modules are silently
+        # invisible to the search. ``forward_step`` also needs the outer model
+        # to produce ``CausalLMOutputWithPast`` (for ``.loss`` / ``.logits``).
+        # Visual tower and MTP siblings are auto-excluded inside
+        # ``auto_quantize()`` via *visual* / *mtp* / *vision_tower* patterns.
         auto_quantize(
             args,
-            language_model,
+            full_model,
             calib_dataloader,
+            auto_quantize_method=args.auto_quantize_method,
         )
 
     else:
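The strip/restore pair above is plain attribute surgery on the module tree. A toy version with stand-in modules (names mirror the patch, shapes are placeholders; the real code walks "model.visual" and "mtp" on the loaded checkpoint):

    import torch.nn as nn

    full = nn.Module()
    full.model = nn.Module()
    full.model.visual = nn.Linear(4, 4)  # stand-in visual tower
    full.mtp = nn.Linear(4, 4)           # stand-in MTP head

    stripped = {}
    for path in ("model.visual", "mtp"):
        parent = full
        *head, leaf = path.split(".")
        for p in head:
            parent = getattr(parent, p)
        stripped[path] = getattr(parent, leaf)
        setattr(parent, leaf, None)  # now invisible to mtq.quantize's walk

    # ... quantize the language model here ...

    for path, mod in stripped.items():  # graft back so export sees everything
        parent = full
        *head, leaf = path.split(".")
        for p in head:
            parent = getattr(parent, p)
        setattr(parent, leaf, mod)
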
diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index 76f304a478a..7d83deef834 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -1354,9 +1354,23 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False
     """Preprocess the quantized linears that we plan to fuse.
 
     Use resmooth_only for MOE experts as each individual expert is not fused.
+
+    When the modules carry mismatched quantization formats (most often because
+    AutoQuantize picked different formats for linears that share an input but
+    were not coalesced into one search group), they cannot be fused. In that
+    case, skip the fusion so each linear exports independently with its own
+    format, instead of asserting.
     """
     quantization_format_list = [get_quantization_format(module) for module in modules]
-    assert all_items_same(quantization_format_list), "Modules have different quantization formats"
+    if not all_items_same(quantization_format_list):
+        warn(
+            "preprocess_linear_fusion: modules in this fusion group have mixed "
+            f"quantization formats {quantization_format_list}. Skipping fusion; "
+            "each linear will export with its own format. Common cause: "
+            "AutoQuantize assigned different formats to fusion-mate linears.",
+            stacklevel=2,
+        )
+        return
 
     # Activation
     if hasattr(modules[0], "input_quantizer"):
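The guard added above boils down to one uniformity check over the per-module formats. A sketch with plain strings standing in for the formats (all_items_same is redefined here so the snippet is self-contained; the real helper lives in modelopt's utils):

    def all_items_same(items):
        return all(item == items[0] for item in items)

    fmts = ["NVFP4", "FP8", "NVFP4"]  # fusion mates with mixed formats
    if not all_items_same(fmts):
        print(f"mixed formats {fmts}: skip fusion, export each linear on its own")
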
diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py
index f1db2df9e84..08fb67aa6e0 100644
--- a/modelopt/torch/quantization/algorithms.py
+++ b/modelopt/torch/quantization/algorithms.py
@@ -46,6 +46,77 @@
 from .utils import is_quantized_linear
 
 
+def _is_fused_experts_module(module: nn.Module) -> bool:
+    """Return True if ``module`` is a quantized fused-MoE-experts container.
+
+    These modules expose plural ``*_input_quantizer`` and ``*_weight_quantizers``
+    (an ``nn.ModuleList`` of per-expert quantizers) instead of the singular
+    ``input_quantizer`` / ``weight_quantizer`` attrs found on standard
+    ``nn.Linear``-derived QuantModules. AutoQuantize hparam discovery and cost
+    accounting need to recognize this layout to enumerate fused experts as
+    search dimensions.
+    """
+    # Late import to avoid a circular import at module load time.
+    try:
+        from .plugins.huggingface import _QuantFusedExperts
+    except ImportError:
+        return False
+    return isinstance(module, _QuantFusedExperts)
+
+
+# Quantizer attribute names that participate in AutoQuantize snapshot/restore.
+_STD_QUANTIZER_ATTRS = ("input_quantizer", "weight_quantizer", "output_quantizer")
+_FUSED_EXPERTS_QUANTIZER_ATTRS = (
+    "gate_up_proj_input_quantizer",
+    "gate_up_proj_weight_quantizers",
+    "down_proj_input_quantizer",
+    "down_proj_weight_quantizers",
+)
+
+
+def _get_quantizer_attrs(module: nn.Module) -> tuple[str, ...]:
+    """Return the quantizer attribute names that AutoQuantize must snapshot/restore.
+
+    For fused MoE experts, this returns the four plural quantizer attrs (two
+    shared input quantizers plus two ``ModuleList``s of per-expert weight
+    quantizers). For standard Linear-derived QuantModules, it returns the
+    canonical trio.
+    """
+    if _is_fused_experts_module(module):
+        return _FUSED_EXPERTS_QUANTIZER_ATTRS
+    return _STD_QUANTIZER_ATTRS
+
+
+def _make_fresh_quantizer_for_attr(module: nn.Module, attr_name: str) -> nn.Module:
+    """Return a fresh, default quantizer suitable to overwrite ``attr_name`` on ``module``.
+
+    For ModuleList attrs (per-expert quantizers on fused-experts modules), the
+    returned ModuleList preserves the original list length so per-expert
+    enumeration stays consistent across recipes.
+    """
+    current = getattr(module, attr_name, None)
+    if isinstance(current, nn.ModuleList):
+        return nn.ModuleList(TensorQuantizer() for _ in range(len(current)))
+    return TensorQuantizer()
+
+
+def _get_module_weight_numel(module: nn.Module) -> int:
+    """Return the total parameter count of a module's quantizable weights.
+
+    Standard QuantLinear modules have a single ``weight`` parameter. Fused
+    experts modules instead have two 3-D fused parameters (``gate_up_proj``
+    and ``down_proj``); both contribute to the cost accounting.
+    """
+    if _is_fused_experts_module(module):
+        total = 0
+        for attr in ("gate_up_proj", "down_proj"):
+            param = getattr(module, attr, None)
+            if param is not None:
+                total += param.numel()
+        return total
+    weight = getattr(module, "weight", None)
+    return weight.numel() if weight is not None else 0
+
+
 def estimate_quant_compression(quant_cfg: QuantizeConfig) -> float:
     """Estimate the compression ratio of a quantization configuration."""
@@ -218,26 +289,26 @@ def __init__(
         # This is a hack; We dont want to make the input_quantizer, weight_quantizer, output_quantizer
         # a dynamic attribute for backward compatibility with the model_calib.py
         # TODO: Make input_quantizer, weight_quantizer, output_quantizer a dynamic attribute and get rid of this hack
+        # NOTE: For fused-experts modules, the relevant attrs are plural
+        # (``*_input_quantizer`` + ``*_weight_quantizers`` ModuleList) — see
+        # ``_get_quantizer_attrs``. Both layouts share the same snapshot dict
+        # shape so ``active.setter`` swaps the right child modules.
         self._all_quantizer_choices = {quant_recipe: {} for quant_recipe in self.choices}
         quant_recipe: QuantRecipe
         for quant_recipe in self.choices:
             for quant_module in self.quant_modules:
-                for quantizer_attr_name in [
-                    "input_quantizer",
-                    "weight_quantizer",
-                    "output_quantizer",
-                ]:
-                    setattr(quant_module, quantizer_attr_name, TensorQuantizer())
+                attr_names = _get_quantizer_attrs(quant_module)
+                for attr_name in attr_names:
+                    setattr(
+                        quant_module,
+                        attr_name,
+                        _make_fresh_quantizer_for_attr(quant_module, attr_name),
+                    )
                 set_quantizer_by_cfg(quant_module, quant_recipe.config.quant_cfg)
                 self._all_quantizer_choices[quant_recipe][quant_module] = {
-                    quantizer_attr_name: getattr(quant_module, quantizer_attr_name)
-                    for quantizer_attr_name in [
-                        "input_quantizer",
-                        "weight_quantizer",
-                        "output_quantizer",
-                    ]
+                    attr_name: getattr(quant_module, attr_name) for attr_name in attr_names
                 }
 
         self.active = self.original
@@ -344,6 +415,20 @@ def attrs(self) -> list[str]:
         return ["name", *super().attrs]
 
 
+_LINEAR_ATTN_QKVZ_RE = re.compile(r"^(.*?\.linear_attn)\.(?:in_proj_qkv|in_proj_z)$")
+_LINEAR_ATTN_BA_RE = re.compile(r"^(.*?\.linear_attn)\.(?:in_proj_a|in_proj_b)$")
+
+
+def _linear_attn_qkvz_group_key(_model, name: str) -> str | None:
+    m = _LINEAR_ATTN_QKVZ_RE.match(name)
+    return f"{m.group(1)}/qkvz" if m else None
+
+
+def _linear_attn_ba_group_key(_model, name: str) -> str | None:
+    m = _LINEAR_ATTN_BA_RE.match(name)
+    return f"{m.group(1)}/ba" if m else None
+
+
 class _AutoQuantizeBaseSearcher(BaseSearcher, ABC):
     """Base searcher for AutoQuantize algorithm."""
 
@@ -365,6 +450,13 @@ class _AutoQuantizeBaseSearcher(BaseSearcher, ABC):
         r"^(.*?)\.(gate_proj|up_proj)$",  # gate_proj, up_proj for llama like models
         r"^(.*?)\.(\d+\.(w1|w2|w3))$",  # mixtral experts
         r"^(.*?)\.((w1_linear|w2_linear|w3_linear)\.\d+)$",  # dbrx experts
+        # Qwen3.5/3.6 hybrid linear_attn: vLLM fuses (in_proj_qkv, in_proj_z)
+        # into ``in_proj_qkvz`` and (in_proj_a, in_proj_b) into ``in_proj_ba``
+        # and requires fused shards to share quant_algo. Two callables (not one
+        # regex) so qkv+z and a+b produce DIFFERENT group keys; each pair
+        # stays with its own fusion partner.
+        _linear_attn_qkvz_group_key,
+        _linear_attn_ba_group_key,
     ]
 
     score_module_rules = []
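The point of using two callables is that qkv+z and a+b hash to distinct group keys within the same linear_attn block, so each pair stays with its own fusion partner. A self-contained check using the same regexes (module names are illustrative):

    import re

    QKVZ = re.compile(r"^(.*?\.linear_attn)\.(?:in_proj_qkv|in_proj_z)$")
    BA = re.compile(r"^(.*?\.linear_attn)\.(?:in_proj_a|in_proj_b)$")

    def key(name):
        if (m := QKVZ.match(name)):
            return f"{m.group(1)}/qkvz"
        if (m := BA.match(name)):
            return f"{m.group(1)}/ba"
        return None

    for n in ("model.layers.0.linear_attn.in_proj_qkv",
              "model.layers.0.linear_attn.in_proj_z",
              "model.layers.0.linear_attn.in_proj_a",
              "model.layers.0.linear_attn.in_proj_b"):
        print(n, "->", key(n))
    # qkv and z share ".../linear_attn/qkvz"; a and b share ".../linear_attn/ba"
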
@@ -410,9 +502,15 @@ def load_search_checkpoint(self) -> bool:
 
     @staticmethod
     def _is_auto_quantize_module(module):
-        return (
-            is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)
-        ) and isinstance(module, QuantModule)
+        if (is_quantized_linear(module) or isinstance(module, QuantLinearConvBase)) and isinstance(
+            module, QuantModule
+        ):
+            return True
+        # Fused MoE experts: a single ``QuantModule`` that owns N per-expert
+        # weight quantizers in an ``nn.ModuleList`` plus shared input quantizers.
+        # All N experts in a layer share one search dimension (one recipe per
+        # fused module).
+        return _is_fused_experts_module(module) and isinstance(module, QuantModule)
 
     @staticmethod
     def _get_search_recipes(quantization_formats):
@@ -712,11 +810,9 @@ def _print_recipe_summary(best_recipe, total_cost, total_weight_size, prefix="Au
     @staticmethod
     def _get_total_weight_size(modules):
         return sum(
-            (
-                module.weight.numel()
-                if _AutoQuantizeBaseSearcher._is_auto_quantize_module(module)
-                else 0
-            )
+            _get_module_weight_numel(module)
+            if _AutoQuantizeBaseSearcher._is_auto_quantize_module(module)
+            else 0
            for module in modules
         )
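For the fused-experts branch of the weight accounting above, a rough sanity check with made-up shapes (8 experts, hidden size 64, intermediate size 128; the attribute names mirror the patch):

    import torch

    gate_up_proj = torch.empty(8, 64, 2 * 128)  # fused gate+up, per expert
    down_proj = torch.empty(8, 128, 64)
    total = gate_up_proj.numel() + down_proj.numel()
    print(total)  # 196608 weights counted toward the effective-bits constraint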