154 changes: 141 additions & 13 deletions examples/llm_ptq/hf_ptq.py
@@ -15,6 +15,7 @@

import argparse
import copy
import os
import random
import time
import warnings
@@ -137,6 +138,43 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
mto.enable_huggingface_checkpointing()


NVFP4_W4A16_CFG = {
"quant_cfg": [
{"quantizer_name": "*", "enable": False},
{
"quantizer_name": "*weight_quantizer",
"cfg": {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
},
},
{"quantizer_name": "*input_quantizer", "enable": False},
*_default_disabled_quantizer_cfg,
],
"algorithm": "max",
}

FP8_W8A16_CFG = {
"quant_cfg": [
{"quantizer_name": "*", "enable": False},
{
"quantizer_name": "*weight_quantizer",
"cfg": {"num_bits": (4, 3), "axis": None},
},
{"quantizer_name": "*input_quantizer", "enable": False},
*_default_disabled_quantizer_cfg,
],
"algorithm": "max",
}
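# Reading these configs (in modelopt's float-format convention, as I understand it):
# num_bits=(2, 1) is the FP4 E2M1 weight format with 16-element blocks and dynamic
# FP8 E4M3 block scales (scale_bits=(4, 3)); num_bits=(4, 3) with axis=None is
# per-tensor FP8 E4M3. Input quantizers are disabled in both, so activations keep
# the model's original precision, hence the "w4a16" / "w8a16" names.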

QUANT_CFG_CHOICES.update(
{
"nvfp4_w4a16": NVFP4_W4A16_CFG,
"fp8_w8a16": FP8_W8A16_CFG,
}
)
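A minimal usage sketch, not part of the diff: once registered, the new entries plug into mtq.quantize the same way the existing QUANT_CFG_CHOICES formats do. The `model` object below is assumed to be an already-loaded Hugging Face model; since only weight quantizers are enabled and the algorithm is "max", a calibration forward loop should not be needed.

    import modelopt.torch.quantization as mtq

    # Weight-only NVFP4: 4-bit block-scaled weights, activations untouched.
    model = mtq.quantize(model, QUANT_CFG_CHOICES["nvfp4_w4a16"])

    # FP8 weight-only works the same way:
    # model = mtq.quantize(model, QUANT_CFG_CHOICES["fp8_w8a16"])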


def extract_and_prepare_language_model_from_vl(full_model):
"""Extract language model from VL model and disable quantization for non-language components.

@@ -326,6 +364,8 @@ def auto_quantize(
"nvfp4_omlp_only",
"nvfp4_local_hessian",
"mxfp8",
"nvfp4_w4a16",
"fp8_w8a16",
]
for qformat in qformat_list
), "One or more quantization formats provided are not supported for unified checkpoint export"
@@ -348,6 +388,38 @@ def forward_step(model, batch):
f"Invalid auto_quantize_method: {auto_quantize_method}. Must be 'gradient' or 'kl_div'"
)

    # Let AutoQuantize search lm_head, but exclude modules that vLLM either
    # constructs as BF16-only paths today or dispatches through fused kernels
    # known to be unsafe to quantize.
disabled_layers = [
entry["quantizer_name"]
for entry in _default_disabled_quantizer_cfg
if "parent_class" not in entry and entry["quantizer_name"] != "*lm_head*"
]
enable_linear_attn_big3 = os.environ.get("MODELOPT_AUTOQ_ENABLE_LINEAR_ATTN_BIG3") == "1"
enable_shared_expert = os.environ.get("MODELOPT_AUTOQ_ENABLE_SHARED_EXPERT") == "1"
autoq_extra_disabled = [
"*shared_expert_gate*",
"*linear_attn.in_proj_a*",
"*linear_attn.in_proj_b*",
]
if not enable_shared_expert:
autoq_extra_disabled.append("*mlp.shared_expert*")
if not enable_linear_attn_big3:
autoq_extra_disabled.extend(
[
"*linear_attn.in_proj_qkv*",
"*linear_attn.in_proj_z*",
"*linear_attn.out_proj*",
]
)
for pat in autoq_extra_disabled:
if pat not in disabled_layers:
disabled_layers.append(pat)
if is_multimodal_model(language_model):
for pat in ("*visual*", "*mtp*", "*vision_tower*"):
if pat not in disabled_layers:
disabled_layers.append(pat)
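For reference, a sketch (not part of the diff) of how the opt-in switches above would be flipped; the variable names are the ones read in the code above, and they must be set before hf_ptq.py reads them (e.g. exported in the launch environment):

    import os

    # Include the shared-expert MLP and the three large linear_attn projections
    # in the AutoQuantize search instead of leaving them excluded.
    os.environ["MODELOPT_AUTOQ_ENABLE_SHARED_EXPERT"] = "1"
    os.environ["MODELOPT_AUTOQ_ENABLE_LINEAR_ATTN_BIG3"] = "1"
    # With both unset (the default), those patterns stay in disabled_layers and
    # the corresponding modules keep their original precision.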

language_model, _ = mtq.auto_quantize(
language_model,
constraints={"effective_bits": args.auto_quantize_bits},
@@ -362,12 +434,7 @@ def forward_step(model, batch):
len(calib_dataloader), max(auto_quantize_score_size // args.batch_size, 1)
),
verbose=True,
# Disable all default disabled layers such as lm_head, mlp.gate, router etc.
disabled_layers=[
entry["quantizer_name"]
for entry in _default_disabled_quantizer_cfg
if "parent_class" not in entry
],
disabled_layers=disabled_layers,
method=auto_quantize_method,
checkpoint=auto_quantize_checkpoint,
)
@@ -507,12 +574,26 @@ def load_model(args: argparse.Namespace):
]

# We only quantize the language model for VLMs other than the type supported above.
extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(
full_model
)
if extracted_lm is not None:
language_model = extracted_lm
model_type = extracted_model_type
# For AutoQuantize, skip the eager visual-disable side-effect: it
# registers ``modelopt`` state on each visual sibling, which
# ``mtq.auto_quantize → apply_mode → is_converted`` then trips on
# ("Model has multiple modelopt states!"). AutoQuantize handles
# visual/mtp via ``disabled_layers`` patterns instead, so the
# extraction is unnecessary for that path.
#
# For ``--recipe`` mode on a VLM, lm_head sits on the OUTER
# CausalLM. Recipe rules can't see it via the inner language
# backbone, so we keep ``language_model = full_model`` here and
# let ``quantize_main`` strip visual/mtp siblings around
# ``mtq.quantize`` (so registration/calibration stays fast and
# batch_size auto-detect doesn't collapse to 1).
if args.auto_quantize_bits is None and args.recipe is None:
extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(
full_model
)
if extracted_lm is not None:
language_model = extracted_lm
model_type = extracted_model_type
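    # Module-layout sketch behind the comment above (names assumed, Qwen-VL-style):
    #   full_model               outer *ForCausalLM / *ForConditionalGeneration wrapper
    #   full_model.lm_head       only reachable by recipe rules from the outer model
    #   full_model.model         language backbone the extraction helper would return
    #   full_model.model.visual  vision tower; stripped around mtq.quantize in recipe mode
    #   full_model.mtp           MTP head; stripped the same way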

tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code)

@@ -628,13 +709,52 @@ def mono_quantize(
else None,
)

    # When ``--recipe`` is given on a VLM we keep ``language_model =
    # full_model`` (so recipe rules can match lm_head), but as-is
    # ``mtq.quantize`` would walk and register quantizers on every Linear in
    # the visual encoder and MTP head, so strip those siblings first and
    # restore them after quantization.
stripped_vlm_modules: dict[str, torch.nn.Module] = {}
if args.recipe is not None and language_model is full_model:
for path in ("model.visual", "mtp"):
parts = path.split(".")
parent = full_model
ok = True
for p in parts[:-1]:
if not hasattr(parent, p):
ok = False
break
parent = getattr(parent, p)
if ok and hasattr(parent, parts[-1]):
mod = getattr(parent, parts[-1])
if mod is not None and isinstance(mod, torch.nn.Module):
stripped_vlm_modules[path] = mod
setattr(parent, parts[-1], None)
if stripped_vlm_modules:
print(
"[recipe] stripped VLM siblings before mtq.quantize: "
+ ", ".join(stripped_vlm_modules.keys())
)

if calibration_only:
language_model = mtq.calibrate(
language_model, quant_cfg["algorithm"], forward_loop=calibrate_loop
)
else:
language_model = mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop)

# Restore stripped VLM siblings so export sees the full model.
for path, mod in stripped_vlm_modules.items():
parts = path.split(".")
parent = full_model
for p in parts[:-1]:
parent = getattr(parent, p)
setattr(parent, parts[-1], mod)
if stripped_vlm_modules:
print(
"[recipe] restored VLM siblings after mtq.quantize: "
+ ", ".join(stripped_vlm_modules.keys())
)
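Not part of the diff, but the strip/restore pair above could later be factored into a small context manager. A sketch, with the helper name and signature assumed:

    from contextlib import contextmanager

    import torch


    @contextmanager
    def detached_submodules(root: torch.nn.Module, paths=("model.visual", "mtp")):
        """Temporarily set the named submodules to None, restoring them on exit."""
        saved = {}
        for path in paths:
            parts = path.split(".")
            parent = root
            for p in parts[:-1]:
                parent = getattr(parent, p, None)
                if parent is None:
                    break
            leaf = getattr(parent, parts[-1], None) if parent is not None else None
            if isinstance(leaf, torch.nn.Module):
                saved[path] = leaf
                setattr(parent, parts[-1], None)
        try:
            yield saved
        finally:
            for path, mod in saved.items():
                parts = path.split(".")
                parent = root
                for p in parts[:-1]:
                    parent = getattr(parent, p)
                setattr(parent, parts[-1], mod)

Usage at this call site would then be roughly: with detached_submodules(full_model): language_model = mtq.quantize(language_model, quant_cfg, forward_loop=calibrate_loop).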

# For VL models, update full_model to use the quantized language model
if is_nemotron_vl_model:
language_model_lineage = get_language_model_from_vl(full_model)
@@ -1018,10 +1138,18 @@ def quantize_main(
"Auto quantization needs multiple quantization format."
)

# For VL models, autoquant must walk submodules of the OUTER CausalLM
# (which carries lm_head and the LM-head forward path) — otherwise
# lm_head and any sibling-of-language_model modules are silently
# invisible to the search. ``forward_step`` also needs the outer model
# to produce ``CausalLMOutputWithPast`` (for ``.loss`` / ``.logits``).
# Visual tower and MTP siblings are auto-excluded inside
# ``auto_quantize()`` via *visual* / *mtp* / *vision_tower* patterns.
auto_quantize(
args,
language_model,
full_model,
calib_dataloader,
auto_quantize_method=args.auto_quantize_method,
)
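For context, the forward_step contract the comment refers to is roughly the following; the body is an illustrative assumption, the real helper is defined earlier in hf_ptq.py:

    def forward_step(model, batch):
        # Outer CausalLM call: returns a CausalLMOutputWithPast whose .loss /
        # .logits are what the AutoQuantize scoring methods consume.
        return model(**batch)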

else:
16 changes: 15 additions & 1 deletion modelopt/torch/export/quant_utils.py
@@ -1354,9 +1354,23 @@ def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False
"""Preprocess the quantized linears that we plan to fuse.

Use resmooth_only for MOE experts as each individual expert is not fused.

    When the modules carry mismatched quantization formats (most often because
    AutoQuantize assigned different formats to layers that share an input but
    were not grouped into a single search group), they cannot be merged into a
    fused linear. In that case, skip the fusion so each linear exports
    independently with its own format, instead of asserting.
    """
quantization_format_list = [get_quantization_format(module) for module in modules]
assert all_items_same(quantization_format_list), "Modules have different quantization formats"
if not all_items_same(quantization_format_list):
warn(
"preprocess_linear_fusion: modules in this fusion group have mixed "
f"quantization formats {quantization_format_list}. Skipping fusion; "
"each linear will export with its own format. Common cause: "
"AutoQuantize assigned different formats to fusion-mate linears.",
stacklevel=2,
)
return

# Activation
if hasattr(modules[0], "input_quantizer"):
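To illustrate the new fallback (the module paths below are a typical LLaMA-style Hugging Face layout, not taken from this PR): a fusion group in which AutoQuantize assigned different formats now warns and skips fusion instead of raising.

    from modelopt.torch.export.quant_utils import preprocess_linear_fusion

    attn = model.model.layers[0].self_attn          # assumed HF decoder layout
    fusion_group = [attn.q_proj, attn.k_proj, attn.v_proj]

    # Previously: AssertionError("Modules have different quantization formats").
    # Now: a warning, and each projection exports independently with its own format.
    preprocess_linear_fusion(fusion_group)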