diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 47349527c7..aafe955dd0 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -19,6 +19,7 @@ NVIDIA Model Optimizer Changelog
 - Enable PTQ workflow for Qwen3.5 MoE models.
 - Add ``nvfp4_omlp_only`` quantization format for NVFP4 quantization. This is similar to ``nvfp4_mlp_only`` but also quantizes the output projection layer in attention.
 - ``pass_through_bwd`` in the quantization config is now default to True. Please set it to False if you want to use STE with zeroed outlier gradients for potentially better QAT accuracy.
+- Add :func:`compute_quantization_mse <modelopt.torch.quantization.model_quant.compute_quantization_mse>` API to measure per-quantizer mean-squared quantization error, with flexible wildcard and callable filtering.
 
 **Misc**
 
diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py
index 77743b41c2..1a18d1817d 100644
--- a/modelopt/torch/quantization/model_quant.py
+++ b/modelopt/torch/quantization/model_quant.py
@@ -41,6 +41,7 @@
 __all__ = [
     "auto_quantize",
     "calibrate",
+    "compute_quantization_mse",
     "disable_quantizer",
     "enable_quantizer",
     "fold_weight",
@@ -535,3 +536,88 @@ def fold_weight(model: nn.Module, keep_attrs: bool = False):
     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
             module.fold_weight(keep_attrs)
+
+
+@torch.no_grad()
+def compute_quantization_mse(
+    model: nn.Module,
+    forward_loop: ForwardLoop,
+    wildcards: str | Callable | list[str | Callable] = "*",
+) -> dict[str, float]:
+    """Compute the mean-squared quantization error for selected quantizers.
+
+    Runs ``forward_loop`` through the model while recording, for every matching
+    :class:`TensorQuantizer`, the MSE between the original float tensor and
+    its fake-quantized (Q→DQ) counterpart. Values are averaged over all
+    calibration batches.
+
+    Args:
+        model: A quantized model (output of :func:`quantize`).
+        forward_loop: Callable that takes ``model`` and runs data through it.
+        wildcards: One or more fnmatch glob patterns (or callable filters)
+            matched against :class:`TensorQuantizer` module names in
+            ``model.named_modules()``. Follows the same convention as
+            ``quant_cfg`` wildcard keys. Defaults to ``"*"`` (all quantizers).
+
+    Returns:
+        A dict mapping each matched quantizer's fully-qualified name to its
+        mean MSE (float). Quantizers that are disabled or not in fake-quant
+        mode are skipped and absent from the output.
+
+    Example::
+
+        mse = mtq.compute_quantization_mse(
+            model,
+            forward_loop,
+            wildcards=["*k_bmm_quantizer", "*v_bmm_quantizer"],
+        )
+        for name, err in sorted(mse.items()):
+            print(f"{name}: {err:.4e}")
+    """
+    # NOTE(review): no hunk in this patch adds ``import fnmatch`` at module
+    # level; keep it function-local so the patch is self-contained either way.
+    import fnmatch
+
+    if not isinstance(wildcards, list):
+        wildcards = [wildcards]
+
+    def _matches(name: str) -> bool:
+        # A string entry is an fnmatch glob; anything else is a predicate.
+        return any(fnmatch.fnmatch(name, w) if isinstance(w, str) else w(name) for w in wildcards)
+
+    def _make_hook(acc: dict) -> Callable:
+        # Defined once (outside the registration loop); each returned closure
+        # captures its own accumulator dict.
+        def hook(mod, inp, out):
+            original = inp[0].detach().float()
+            quantized = out.detach().float()
+            acc["sum"] += torch.mean((original - quantized) ** 2).item()
+            acc["count"] += 1
+
+        return hook
+
+    accumulators: dict[str, dict] = {}  # name -> {"sum": float, "count": int}
+    hooks = []
+
+    for name, module in model.named_modules():
+        if not isinstance(module, TensorQuantizer):
+            continue
+        if not _matches(name):
+            continue
+        # Only fake-quant (Q→DQ) quantizers that are enabled produce a
+        # meaningful input/output error; skip the rest entirely.
+        if not (module._if_quant and module._fake_quant) or module._disabled:
+            continue
+        accumulators[name] = {"sum": 0.0, "count": 0}
+        hooks.append(module.register_forward_hook(_make_hook(accumulators[name])))
+
+    try:
+        forward_loop(model)
+    finally:
+        # Always detach the hooks, even if the forward loop raises.
+        for h in hooks:
+            h.remove()
+
+    # Quantizers that never saw data (count == 0) are omitted rather than
+    # reported as NaN or zero.
+    return {
+        name: acc["sum"] / acc["count"] for name, acc in accumulators.items() if acc["count"] > 0
+    }
diff --git a/tests/unit/torch/quantization/test_compute_quantization_mse.py b/tests/unit/torch/quantization/test_compute_quantization_mse.py
new file mode 100644
index 0000000000..9a9a81a611
--- /dev/null
+++ b/tests/unit/torch/quantization/test_compute_quantization_mse.py
@@ -0,0 +1,141 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for mtq.compute_quantization_mse()."""
+
+import torch
+from _test_utils.torch.quantization.models import SimpleLinear
+
+import modelopt.torch.quantization as mtq
+from modelopt.torch.quantization.nn import TensorQuantizer
+
+INT8_CFG = {
+    "quant_cfg": {
+        "*weight_quantizer": {"num_bits": 8, "axis": 0},
+        "*input_quantizer": {"num_bits": 8, "axis": None},
+    },
+    "algorithm": "max",
+}
+
+
+def _make_quantized_model():
+    model = SimpleLinear()
+    calib_data = [model.get_input() for _ in range(4)]
+
+    def forward_loop(m):
+        for batch in calib_data:
+            m(batch)
+
+    mtq.quantize(model, INT8_CFG, forward_loop)
+    return model, forward_loop
+
+
+class TestComputeQuantizationMse:
+    def test_returns_nonnegative_values(self):
+        """Reported MSE values must be non-negative for every quantizer."""
+        model, forward_loop = _make_quantized_model()
+        mse = mtq.compute_quantization_mse(model, forward_loop)
+        assert len(mse) > 0
+        assert all(v >= 0.0 for v in mse.values())
+
+    def test_wildcard_star_covers_all_enabled_fake_quant(self):
+        """Default wildcard '*' yields an entry per enabled fake-quant quantizer."""
+        model, forward_loop = _make_quantized_model()
+        mse = mtq.compute_quantization_mse(model, forward_loop, wildcards="*")
+
+        expected_names = {
+            name
+            for name, module in model.named_modules()
+            if isinstance(module, TensorQuantizer)
+            and module._if_quant
+            and module._fake_quant
+            and not module._disabled
+        }
+        assert set(mse.keys()) == expected_names
+
+    def test_wildcard_filters_by_suffix(self):
+        """A suffix pattern should restrict results to matching quantizer names."""
+        model, forward_loop = _make_quantized_model()
+        mse = mtq.compute_quantization_mse(model, forward_loop, wildcards="*weight_quantizer")
+        assert len(mse) > 0
+        assert all("weight_quantizer" in k for k in mse)
+        # No input quantizers should appear
+        assert not any("input_quantizer" in k for k in mse)
+
+    def test_list_of_wildcards(self):
+        """A list of patterns returns the union of the individually matched sets."""
+        model, forward_loop = _make_quantized_model()
+        mse_weight = mtq.compute_quantization_mse(
+            model, forward_loop, wildcards="*weight_quantizer"
+        )
+        mse_input = mtq.compute_quantization_mse(model, forward_loop, wildcards="*input_quantizer")
+        mse_both = mtq.compute_quantization_mse(
+            model, forward_loop, wildcards=["*weight_quantizer", "*input_quantizer"]
+        )
+        assert set(mse_both.keys()) == set(mse_weight.keys()) | set(mse_input.keys())
+
+    def test_callable_filter(self):
+        """A callable wildcard should select quantizers by arbitrary predicate."""
+        model, forward_loop = _make_quantized_model()
+        # Pick only quantizers belonging to the first linear layer (net.0)
+        mse = mtq.compute_quantization_mse(model, forward_loop, wildcards=lambda n: "net.0" in n)
+        assert len(mse) > 0
+        assert all("net.0" in k for k in mse)
+
+    def test_disabled_quantizer_absent_from_result(self):
+        """A quantizer disabled after calibration must not appear in the output."""
+        model, forward_loop = _make_quantized_model()
+
+        # Disable one quantizer and record its name
+        disabled_name = None
+        for name, module in model.named_modules():
+            if isinstance(module, TensorQuantizer) and module._if_quant and module._fake_quant:
+                module.disable()
+                disabled_name = name
+                break
+
+        assert disabled_name is not None, "No enabled quantizer found to disable"
+
+        mse = mtq.compute_quantization_mse(model, forward_loop)
+        assert disabled_name not in mse
+
+    def test_no_matching_wildcard_returns_empty_dict(self):
+        """A pattern that matches nothing should return an empty dict."""
+        model, forward_loop = _make_quantized_model()
+        mse = mtq.compute_quantization_mse(
+            model, forward_loop, wildcards="*nonexistent_quantizer_xyz*"
+        )
+        assert mse == {}
+
+    def test_does_not_modify_model_parameters(self):
+        """Running MSE measurement must leave model weights unchanged."""
+        model, forward_loop = _make_quantized_model()
+        params_before = {k: v.clone() for k, v in model.named_parameters()}
+        mtq.compute_quantization_mse(model, forward_loop)
+        for k, v in model.named_parameters():
+            assert torch.equal(v, params_before[k]), f"Parameter {k} was modified"
+
+    def test_hooks_removed_after_call(self):
+        """All forward hooks registered during the call must be cleaned up."""
+        model, forward_loop = _make_quantized_model()
+
+        hooks_before = sum(
+            len(m._forward_hooks) for m in model.modules() if isinstance(m, TensorQuantizer)
+        )
+        mtq.compute_quantization_mse(model, forward_loop)
+        hooks_after = sum(
+            len(m._forward_hooks) for m in model.modules() if isinstance(m, TensorQuantizer)
+        )
+        assert hooks_after == hooks_before