2 changes: 2 additions & 0 deletions CHANGELOG.rst
@@ -6,6 +6,8 @@ Changelog

**New Features**

- Add NVFP4 W4A16 weight-only quantization (``nvfp4_w4a16``): FP4 weights with group_size=16, BF16 activations, and no calibration forward pass required. Use ``mtq.NVFP4_W4A16_CFG`` or ``--qformat nvfp4_w4a16`` in ``hf_ptq.py`` (see the usage sketch below). Exported checkpoints can be served on vLLM after conversion to compressed-tensors format.
- Register ``nn.Embedding`` with ``QuantModuleRegistry`` (weight-only wrapper) and extend the unified HF exporter to pack quantized embedding weights. This enables NVFP4 quantization of ``lm_head`` and the input token embedding on hybrid SSM+Attention models such as Nemotron-H, where those two tables are a sizeable fraction of the parameters and leaving them in bf16 forfeits much of the compression. Use ``--recipe models/Nemotron-H/nvfp4_w4a16`` (see `modelopt_recipes/models/Nemotron-H/nvfp4_w4a16.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/models/Nemotron-H/nvfp4_w4a16.yaml>`_) to opt in. The ``--exclude_modules`` CLI flag in ``examples/llm_ptq/hf_ptq.py`` lets users selectively exclude individual modules from the recipe's coverage.
- Support the full Transformer Engine spec for Minitron pruning (``mcore_minitron``), removing the need for the custom ModelOpt spec. This does not change the pruning workflow, but it makes pruning slightly faster and may yield a slightly different pruned model due to different kernels and numerics.
- Add Puzzletron, a new algorithm for heterogeneous pruning of LLMs and VLMs. See `examples/puzzletron/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/puzzletron>`_ for more details.
- Add an iterator interface using ``CalibrationDataReader`` to the ONNX quantization workflow.
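A minimal usage sketch of the ``nvfp4_w4a16`` flow described above, assuming a standard Hugging Face causal LM. The model name and export directory are illustrative, and ``export_hf_checkpoint`` refers to the unified HF exporter this PR extends; treat its exact signature as an assumption.

import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint  # exporter touched by this PR; signature assumed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("nvidia/Nemotron-H-8B-Base-8K")  # illustrative

# Weight-only FP4 with group_size=16: activations stay BF16, so no
# calibration forward_loop is needed.
model = mtq.quantize(model, mtq.NVFP4_W4A16_CFG)

# Export a unified HF checkpoint; serve on vLLM after converting the result
# to compressed-tensors format.
export_hf_checkpoint(model, export_dir="/tmp/nvfp4_w4a16_ckpt")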
33 changes: 33 additions & 0 deletions examples/llm_ptq/example_utils.py
@@ -23,6 +23,7 @@
import shutil
import sys
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any

@@ -709,6 +710,38 @@ def has_pack_quantized_config(config)
    return model


@contextmanager
def normalized_generation_config_for_export(model):
    """Temporarily swap in a normalized generation_config for export.

    Some model cards ship a ``generation_config.json`` that sets sampling hyperparameters
    (``top_p``/``top_k``/``temperature``) without ``do_sample=True`` (e.g.
    ``NVIDIA-Nemotron-3-Nano-4B-BF16``). transformers 5.x strictly validates this on
    ``save_pretrained``, so the export step fails. We normalize by swapping in a deep copy
    with ``do_sample=True`` for the duration of the export and restoring the original
    afterwards — leaving ``model.generation_config`` unchanged so any ``.generate()`` calls
    outside the export window (e.g. the pre-/post-PTQ previews) remain deterministic.
    """
    original = getattr(model, "generation_config", None)
    normalized = None
    if original is not None and not getattr(original, "do_sample", False):
        has_sampling_hyperparam = (
            getattr(original, "top_p", None) not in (None, 1.0)
            or getattr(original, "top_k", None) not in (None, 0, 50)
            or getattr(original, "temperature", None) not in (None, 1.0)
        )
        if has_sampling_hyperparam:
            normalized = copy.deepcopy(original)
            normalized.do_sample = True
    try:
        if normalized is not None:
            model.generation_config = normalized
        yield
    finally:
        if normalized is not None:
            model.generation_config = original


def is_model_on_gpu(model) -> bool:
    """Returns True if the model is fully loaded on GPUs."""
    return all("cuda" in str(param.device) for param in model.parameters())
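A sketch of the intended call pattern for the context manager above. ``model`` stands for a loaded HF model whose ``generation_config`` sets sampling hyperparameters without ``do_sample=True``; the save path is illustrative and ``save_pretrained`` stands in for the full export step.

from example_utils import normalized_generation_config_for_export

with normalized_generation_config_for_export(model):
    # Inside the window a deep copy with do_sample=True is active, so
    # transformers 5.x's strict save-time validation passes.
    model.save_pretrained("/tmp/export")

# Outside the window the original generation_config is back in place, so
# model.generate(...) calls remain deterministic.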
24 changes: 23 additions & 1 deletion examples/llm_ptq/hf_ptq.py
@@ -35,6 +35,7 @@
    is_nemotron_vl,
    load_mtp_weights,
    needs_checkpoint_path_update,
    normalized_generation_config_for_export,
    resolve_checkpoint_dir,
    run_nemotron_vl_preview,
)
@@ -107,6 +108,7 @@ def _set_kv_cache_constant_amax(quant_cfg: list) -> None:
"int4_awq": mtq.INT4_AWQ_CFG,
"w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
"nvfp4_w4a16": mtq.NVFP4_W4A16_CFG,
"nvfp4_awq": mtq.NVFP4_AWQ_LITE_CFG,
"nvfp4_mse": mtq.NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG,
"fp8_pb_wo": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
@@ -331,6 +333,7 @@ def auto_quantize(
"nvfp4",
"nvfp4_awq",
"nvfp4_mse",
"nvfp4_w4a16",
"w4a8_awq",
"fp8_pb_wo",
"w4a8_mxfp4_fp8",
@@ -629,6 +632,12 @@ def mono_quantize(
        )  # Nemotron-Parse specific
        print("Quantization will only be applied to the decoder (text generation) component")

    # Model-specific quantization extensions (e.g. quantizing lm_head + input embedding for
    # Nemotron-H, where those tables are a large fraction of parameters and leaving them at
    # bf16 wastes most of the memory savings) are now expressed as recipes under
    # ``modelopt_recipes/models/<ModelName>/``. Pass ``--recipe models/<ModelName>/<flavor>``
    # (e.g. ``--recipe models/Nemotron-H/nvfp4_w4a16``) to opt in.

    if not model_is_already_quantized or calibration_only:
        # quantize the model

@@ -677,7 +686,14 @@ def export_quantized(
    default_padding_side,
    default_pad_token,
):
-    with torch.inference_mode():
+    # ``normalized_generation_config_for_export`` swaps ``model.generation_config`` with
+    # a deep-copied ``do_sample=True`` variant for the duration of the export so
+    # ``save_pretrained`` passes transformers 5.x's strict validation without affecting
+    # any ``.generate()`` callers outside the export window.
+    with (
+        torch.inference_mode(),
+        normalized_generation_config_for_export(full_model),
+    ):
        if model_type is None:
            print(f"Unknown model type {type(language_model).__name__}. Continue exporting...")
            model_type = f"unknown:{type(language_model).__name__}"
@@ -781,6 +797,12 @@ def export_quantized(
            extra_state_dict=mtp_state_dict,
        )

    if args.qformat == "nvfp4_w4a16":
        warnings.warn(
            "TensorRT-LLM and SGLang do not support this format. "
            "To serve on vLLM, convert the NVFP4 W4A16 checkpoint to compressed-tensors format."
        )

Review thread on the warning above:

- Reviewer: hi @ajrasane, should we point the users to how they can convert? do we have a helper in ModelOpt we should point them to?
- Contributor Author: @hychiang-git, are you planning to merge your conversion script to modelopt?

    # Restore default padding and export the tokenizer as well.
    if tokenizer is not None:
        tokenizer.padding_side = default_padding_side
10 changes: 8 additions & 2 deletions examples/llm_ptq/scripts/huggingface_example.sh
@@ -53,9 +53,9 @@ esac
IFS=","
for qformat in $QFORMAT; do
case $qformat in
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | nvfp4_mse | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_experts_only | nvfp4_mlp_only | nvfp4_omlp_only | nvfp4_svdquant | mxfp8 | nvfp4_local_hessian) ;;
fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | nvfp4_mse | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_experts_only | nvfp4_mlp_only | nvfp4_omlp_only | nvfp4_svdquant | mxfp8 | nvfp4_local_hessian | nvfp4_w4a16) ;;
*)
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, nvfp4_mse, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_experts_only, nvfp4_mlp_only, nvfp4_omlp_only, nvfp4_svdquant, mxfp8, nvfp4_local_hessian]" >&2
echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, nvfp4_mse, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_experts_only, nvfp4_mlp_only, nvfp4_omlp_only, nvfp4_svdquant, mxfp8, nvfp4_local_hessian, nvfp4_w4a16]" >&2
exit 1
;;
esac
@@ -199,6 +199,12 @@ if [[ $TASKS =~ "quant" ]] || [[ ! -d "$SAVE_PATH" ]] || [[ ! $(ls -A $SAVE_PATH
    exit 0
fi

if [ "$QFORMAT" = "nvfp4_w4a16" ]; then
    echo "nvfp4_w4a16 checkpoint exported to $SAVE_PATH"
    echo "To serve on vLLM, convert the checkpoint to compressed-tensors format first."
    exit 0
fi

if [[ "$QFORMAT" == *"nvfp4"* ]] || [[ "$KV_CACHE_QUANT" == *"nvfp4"* ]]; then
    cuda_major=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader -i 0 | cut -d. -f1)

17 changes: 17 additions & 0 deletions modelopt/torch/export/convert_hf_config.py
@@ -57,6 +57,11 @@ def _quant_algo_to_group_config(quant_algo: str, group_size: int | None = None)
        return {
            "weights": {"dynamic": False, "num_bits": 4, "type": "int", "group_size": gs},
        }
    elif quant_algo == "NVFP4_W4A16":
        gs = group_size or 16
        return {
            "weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": gs},
        }
    elif quant_algo in ("NVFP4_AWQ", "W4A8_AWQ"):
        gs = group_size or 128
        return {
@@ -183,6 +188,18 @@ def convert_hf_quant_config_format(input_config: dict[str, Any]) -> dict[str, An
"targets": ["Linear"],
}
new_config["config_groups"] = {"group_0": config_group_details}
elif quant_algo_value == "NVFP4_W4A16":
# Weight-only FP4. Embedding is included alongside Linear because
# ``NVFP4_W4A16_CFG`` targets ``["*"]`` with ``weight_only=True``, so any registered
# ``QuantEmbedding`` gets weight-quantized too. Compressed-tensors dispatches on the
# module's ``__class__.__name__``, so omitting ``Embedding`` would leave quantized
# embedding weights orphaned on the consumer side.
group_size = original_quantization_details.get("group_size", 16)
config_group_details = {
"weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": group_size},
"targets": ["Linear", "Embedding"],
}
new_config["config_groups"] = {"group_0": config_group_details}
elif quant_algo_value == "MIXED_PRECISION":
quantized_layers = original_quantization_details.get("quantized_layers", {})

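For reference, the ``config_groups`` entry that the ``NVFP4_W4A16`` branch above emits with the default group size of 16 (the values follow directly from the code):

config_groups = {
    "group_0": {
        "weights": {"dynamic": False, "num_bits": 4, "type": "float", "group_size": 16},
        "targets": ["Linear", "Embedding"],
    }
}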
1 change: 1 addition & 0 deletions modelopt/torch/export/model_config.py
@@ -39,6 +39,7 @@
QUANTIZATION_MXFP8 = "mxfp8"
QUANTIZATION_W4A8_MXFP4_FP8 = "w4a8_mxfp4_fp8"
QUANTIZATION_NVFP4_AWQ = "nvfp4_awq"
QUANTIZATION_NVFP4_W4A16 = "nvfp4_w4a16" # weight-only FP4
QUANTIZATION_FP8_PB_REAL = "fp8_pb_real"
QUANTIZATION_FP8_PB_WO = "fp8_pb_wo"
QUANTIZATION_FP8_PC_PT = "fp8_pc_pt"
13 changes: 13 additions & 0 deletions modelopt/torch/export/quant_utils.py
@@ -65,6 +65,7 @@
    QUANTIZATION_NVFP4,
    QUANTIZATION_NVFP4_AWQ,
    QUANTIZATION_NVFP4_SVDQUANT,
    QUANTIZATION_NVFP4_W4A16,
    QUANTIZATION_W4A8_AWQ,
    QUANTIZATION_W4A8_MXFP4_FP8,
    QUANTIZATION_W4A8_NVFP4_FP8,
@@ -358,6 +359,7 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
        QUANTIZATION_NVFP4,
        QUANTIZATION_NVFP4_AWQ,
        QUANTIZATION_NVFP4_SVDQUANT,
        QUANTIZATION_NVFP4_W4A16,
        QUANTIZATION_W4A8_NVFP4_FP8,
    ]:
        # Calibrate weight quantizer if amax is not set
@@ -402,6 +404,7 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
        QUANTIZATION_NVFP4,
        QUANTIZATION_NVFP4_AWQ,
        QUANTIZATION_NVFP4_SVDQUANT,
        QUANTIZATION_NVFP4_W4A16,
        QUANTIZATION_W4A8_NVFP4_FP8,
    ]:
        # Calibrate weight quantizer if amax is not set
@@ -636,6 +639,10 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
        return QUANTIZATION_NVFP4_AWQ
    if getattr(layer, "fused_with_prequant", False):
        return QUANTIZATION_NVFP4_AWQ
    # W4A16 weight-only: input_quantizer absent or disabled
    if input_quantizer is None or not input_quantizer.is_enabled:
        if scale_bits == (4, 3):
            return QUANTIZATION_NVFP4_W4A16
    assert input_quantizer is not None, (
        f"input_quantizer is None for {quantizer_attr_names}"
    )
@@ -803,6 +810,11 @@ def process_layer_quant_config(layer_config_dict):
"quant_algo": "NVFP4",
"group_size": block_size_value,
}
elif v == "nvfp4_w4a16":
layer_config = {
"quant_algo": "NVFP4_W4A16",
"group_size": block_size_value,
}
elif v == "nvfp4_awq":
layer_config = {
"quant_algo": "NVFP4_AWQ",
@@ -980,6 +992,7 @@ def to_quantized_weight(
    if quantization in [
        QUANTIZATION_NVFP4,
        QUANTIZATION_NVFP4_AWQ,
        QUANTIZATION_NVFP4_W4A16,
        QUANTIZATION_W4A8_NVFP4_FP8,
        QUANTIZATION_NVFP4_SVDQUANT,
    ]:
6 changes: 5 additions & 1 deletion modelopt/torch/export/unified_export_hf.py
@@ -69,6 +69,7 @@
from .layer_utils import (
    get_expert_linear_names,
    get_experts_list,
    is_embedding,
    is_layernorm,
    is_moe,
    is_quantlinear,
@@ -84,6 +85,7 @@
    QUANTIZATION_NVFP4,
    QUANTIZATION_NVFP4_AWQ,
    QUANTIZATION_NVFP4_SVDQUANT,
    QUANTIZATION_NVFP4_W4A16,
    QUANTIZATION_W4A8_AWQ,
    QUANTIZATION_W4A8_NVFP4_FP8,
)
@@ -520,6 +522,7 @@ def _export_quantized_weight(
        QUANTIZATION_NVFP4_AWQ,
        QUANTIZATION_NVFP4_SVDQUANT,
        QUANTIZATION_NVFP4,
        QUANTIZATION_NVFP4_W4A16,
        QUANTIZATION_W4A8_AWQ,
        QUANTIZATION_W4A8_NVFP4_FP8,
    ]:
@@ -548,6 +551,7 @@
        QUANTIZATION_NVFP4,
        QUANTIZATION_NVFP4_AWQ,
        QUANTIZATION_NVFP4_SVDQUANT,
        QUANTIZATION_NVFP4_W4A16,
    ]:
        # Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim)
        # for NVFP4 quantization functions that expect input_dim as the last dimension for block quantization
@@ -650,7 +654,7 @@ def _process_quantized_modules(
        # Skip QuantMoELinear - it's handled separately in _reconstruct_fused_moe_linear
        if type(sub_module).__name__ == "QuantMoELinear":
            continue
-        if is_quantlinear(sub_module):
+        if is_quantlinear(sub_module) or is_embedding(sub_module):
            try:
                with fsdp2_aware_weight_update(model, sub_module, reshard=False):
                    _export_quantized_weight(sub_module, dtype)
2 changes: 2 additions & 0 deletions modelopt/torch/quantization/config.py
@@ -794,6 +794,7 @@ def _nvfp4_selective_quant_cfg(
NVFP4_EXPERTS_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp.experts*", "*block_sparse_moe*"])
NVFP4_MLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*mlp*", "*block_sparse_moe*"])
NVFP4_OMLP_ONLY_CFG = _nvfp4_selective_quant_cfg(["*o_proj*", "*mlp*", "*block_sparse_moe*"])
NVFP4_W4A16_CFG = _nvfp4_selective_quant_cfg(["*"], weight_only=True)

# DO NOT ADD NEW CONFIGS HERE. If you want to add a new general recipe, add it to
# modelopt_recipes/general/ptq/ as a yaml file
@@ -828,6 +829,7 @@ def _nvfp4_selective_quant_cfg(
"NVFP4_MLP_ONLY_CFG",
"NVFP4_EXPERTS_ONLY_CFG",
"NVFP4_OMLP_ONLY_CFG",
"NVFP4_W4A16_CFG",
"MAMBA_MOE_NVFP4_CONSERVATIVE_CFG",
"MAMBA_MOE_NVFP4_AGGRESSIVE_CFG",
"MAMBA_MOE_FP8_CONSERVATIVE_CFG",
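``_nvfp4_selective_quant_cfg`` itself is outside this diff, so the exact expansion of ``NVFP4_W4A16_CFG`` is an assumption; the sketch below combines the pattern ``["*"]`` and ``weight_only=True`` from the line above with the NVFP4 constants visible elsewhere in the PR (FP4 ``num_bits=(2, 1)``, E4M3 ``scale_bits=(4, 3)``, block size 16).

# Approximate shape of the effective config; the dict layout is an assumption.
NVFP4_W4A16_CFG_SKETCH = {
    "quant_cfg": {
        # Weights: FP4 (E2M1) values in blocks of 16 along the last axis,
        # with per-block FP8 (E4M3) scales.
        "*weight_quantizer": {
            "num_bits": (2, 1),
            "block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
            "enable": True,
        },
        # weight_only=True: activations stay in BF16, so input quantizers are off.
        "*input_quantizer": {"enable": False},
    },
    "algorithm": "max",
}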
1 change: 1 addition & 0 deletions modelopt/torch/quantization/nn/__init__.py
@@ -18,6 +18,7 @@
from .modules.quant_activations import *
from .modules.quant_batchnorm import *
from .modules.quant_conv import *
from .modules.quant_embedding import *
from .modules.quant_instancenorm import *
from .modules.quant_linear import *
from .modules.quant_module import *
50 changes: 50 additions & 0 deletions modelopt/torch/quantization/nn/modules/quant_embedding.py
@@ -0,0 +1,50 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Quantized Embedding.

``nn.Embedding`` quantization is weight-only: only the lookup table (``weight``) is
fake-quantized. Embedding inputs are integer indices — their ``input_quantizer`` is
registered (so config entries like ``"*input_quantizer"`` can still target it) but is
disabled by default so integer tensors pass through untouched.
"""

import torch.nn as nn

from ... import tensor_quant
from .quant_module import QuantLinearConvBase, QuantModuleRegistry

__all__ = ["QuantEmbedding"]


@QuantModuleRegistry.register({nn.Embedding: "nn.Embedding"})
class _QuantEmbedding(QuantLinearConvBase):
    """Quantized wrapper for ``nn.Embedding``.

    Weight-only quantization. Input/output quantizers are created (so wildcard configs
    still resolve cleanly) but are disabled — an embedding's input is an index tensor.
    """

    default_quant_desc_weight = tensor_quant.QUANT_DESC_8BIT_LINEAR_WEIGHT_PER_ROW

    def _setup(self):
        super()._setup()
        # Embedding inputs are integer indices; never fake-quantize them.
        self.input_quantizer.disable()
        # output_quantizer is already disabled by QuantInputBase._setup().


# Alias to follow the naming convention of QuantLinear.
QuantEmbedding = _QuantEmbedding
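A toy sanity check for the wrapper above; the tiny model and the choice of ``NVFP4_W4A16_CFG`` are illustrative:

import torch.nn as nn

import modelopt.torch.quantization as mtq

model = nn.Sequential(nn.Embedding(1000, 64), nn.Linear(64, 8))
# Weight-only config: no calibration forward_loop is needed.
model = mtq.quantize(model, mtq.NVFP4_W4A16_CFG)

emb = model[0]
assert hasattr(emb, "weight_quantizer")   # nn.Embedding was converted in place
assert emb.weight_quantizer.is_enabled    # the lookup table is fake-quantized
assert not emb.input_quantizer.is_enabled # integer indices pass through untouched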