From 1dc8d2a2afa38daa55e6a70acb6a855bc5e8d0a8 Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Thu, 26 Feb 2026 19:44:25 +0000 Subject: [PATCH 1/7] prototype Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- modelopt/torch/quantization/model_quant.py | 87 +++++++++++++++++++++- 1 file changed, 86 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index 77743b41c2..caae93e315 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -30,7 +30,7 @@ from modelopt.torch.opt.utils import forward_with_reshard from modelopt.torch.quantization.config import QuantizeConfig from modelopt.torch.quantization.conversion import set_quantizer_by_cfg -from modelopt.torch.utils import atomic_print +from modelopt.torch.utils import atomic_print, print_rank_0 from .algorithms import AutoQuantizeGradientSearcher, AutoQuantizeKLDivSearcher, QuantRecipe from .config import QuantizeAlgoCfgType @@ -41,6 +41,7 @@ __all__ = [ "auto_quantize", "calibrate", + "compute_quantization_mse", "disable_quantizer", "enable_quantizer", "fold_weight", @@ -535,3 +536,87 @@ def fold_weight(model: nn.Module, keep_attrs: bool = False): for name, module in model.named_modules(): if isinstance(module, QuantModule): module.fold_weight(keep_attrs) + +@torch.no_grad() +def compute_quantization_mse( + model: nn.Module, + forward_loop: ForwardLoop, + wildcards: str | Callable | list[str | Callable] = "*", +) -> dict[str, float]: + """Compute the mean-squared quantization error for selected quantizers. + + Runs ``forward_loop`` through the model while recording, for every matching + :class:`TensorQuantizer`, the MSE between the original float tensor and + its fake-quantized (Q→DQ) counterpart. Values are averaged over all + calibration batches. + + Args: + model: A quantized model (output of :func:`quantize`). + forward_loop: Callable that takes ``model`` and runs data through it. + wildcards: One or more fnmatch glob patterns (or callable filters) + matched against :class:`TensorQuantizer` module names in + ``model.named_modules()``. Follows the same convention as + ``quant_cfg`` wildcard keys. Defaults to ``"*"`` (all quantizers). + + Returns: + A dict mapping each matched quantizer's fully-qualified name to its + mean MSE (float). Quantizers that are disabled or not in fake-quant + mode are skipped and absent from the output. + + Example:: + + mse = mtq.compute_quantization_mse( + model, + forward_loop, + wildcards=["*k_bmm_quantizer", "*v_bmm_quantizer"], + ) + for name, err in sorted(mse.items()): + print(f"{name}: {err:.4e}") + """ + if isinstance(wildcards, (str, Callable)): + wildcards = [wildcards] + + def _matches(name: str) -> bool: + return any( + fnmatch.fnmatch(name, w) if isinstance(w, str) else w(name) + for w in wildcards + ) + + accumulators: dict[str, dict] = {} # name -> {"sum": float, "count": int} + hooks = [] + + for name, module in model.named_modules(): + if not isinstance(module, TensorQuantizer): + continue + if not _matches(name): + continue + if not (module._if_quant and module._fake_quant): + print_rank_0( + f"[compute_quantization_mse] Skipping {name}: " + f"_if_quant={module._if_quant}, _fake_quant={module._fake_quant}" + ) + continue + + accumulators[name] = {"sum": 0.0, "count": 0} + + def _make_hook(acc): + def hook(mod, inp, out): + original = inp[0].detach().float() + quantized = out.detach().float() + acc["sum"] += torch.mean((original - quantized) ** 2).item() + acc["count"] += 1 + + return hook + + hooks.append(module.register_forward_hook(_make_hook(accumulators[name]))) + + forward_loop(model) + + for h in hooks: + h.remove() + + return { + name: acc["sum"] / acc["count"] + for name, acc in accumulators.items() + if acc["count"] > 0 + } From ab58ddade257ada96c087a304b5cd842b64b0b7d Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Thu, 26 Feb 2026 19:47:14 +0000 Subject: [PATCH 2/7] add test Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- .../test_compute_quantization_mse.py | 143 ++++++++++++++++++ 1 file changed, 143 insertions(+) create mode 100644 tests/unit/torch/quantization/test_compute_quantization_mse.py diff --git a/tests/unit/torch/quantization/test_compute_quantization_mse.py b/tests/unit/torch/quantization/test_compute_quantization_mse.py new file mode 100644 index 0000000000..24ab98b429 --- /dev/null +++ b/tests/unit/torch/quantization/test_compute_quantization_mse.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for mtq.compute_quantization_mse().""" + +import pytest +import torch +from _test_utils.torch.quantization.models import SimpleLinear + +import modelopt.torch.quantization as mtq +from modelopt.torch.quantization.nn import TensorQuantizer + +INT8_CFG = { + "quant_cfg": { + "*weight_quantizer": {"num_bits": 8, "axis": 0}, + "*input_quantizer": {"num_bits": 8, "axis": None}, + }, + "algorithm": "max", +} + + +def _make_quantized_model(): + model = SimpleLinear() + calib_data = [model.get_input() for _ in range(4)] + + def forward_loop(m): + for batch in calib_data: + m(batch) + + mtq.quantize(model, INT8_CFG, forward_loop) + return model, forward_loop + + +class TestComputeQuantizationMse: + def test_returns_nonnegative_values(self): + """MSE values must be >= 0 for all quantizers.""" + model, forward_loop = _make_quantized_model() + mse = mtq.compute_quantization_mse(model, forward_loop) + assert len(mse) > 0 + assert all(v >= 0.0 for v in mse.values()) + + def test_wildcard_star_covers_all_enabled_fake_quant(self): + """Default wildcard '*' should return an entry for every enabled fake-quant quantizer.""" + model, forward_loop = _make_quantized_model() + mse = mtq.compute_quantization_mse(model, forward_loop, wildcards="*") + + expected_names = { + name + for name, module in model.named_modules() + if isinstance(module, TensorQuantizer) and module._if_quant and module._fake_quant + } + assert set(mse.keys()) == expected_names + + def test_wildcard_filters_by_suffix(self): + """A suffix pattern should restrict results to matching quantizer names.""" + model, forward_loop = _make_quantized_model() + mse = mtq.compute_quantization_mse(model, forward_loop, wildcards="*weight_quantizer") + assert len(mse) > 0 + assert all("weight_quantizer" in k for k in mse.keys()) + # No input quantizers should appear + assert not any("input_quantizer" in k for k in mse.keys()) + + def test_list_of_wildcards(self): + """A list of patterns should return the union of matched quantizers.""" + model, forward_loop = _make_quantized_model() + mse_weight = mtq.compute_quantization_mse(model, forward_loop, wildcards="*weight_quantizer") + mse_input = mtq.compute_quantization_mse(model, forward_loop, wildcards="*input_quantizer") + mse_both = mtq.compute_quantization_mse( + model, forward_loop, wildcards=["*weight_quantizer", "*input_quantizer"] + ) + assert set(mse_both.keys()) == set(mse_weight.keys()) | set(mse_input.keys()) + + def test_callable_filter(self): + """A callable wildcard should select quantizers by arbitrary predicate.""" + model, forward_loop = _make_quantized_model() + # Pick only quantizers belonging to the first linear layer (net.0) + mse = mtq.compute_quantization_mse( + model, forward_loop, wildcards=lambda n: "net.0" in n + ) + assert len(mse) > 0 + assert all("net.0" in k for k in mse.keys()) + + def test_disabled_quantizer_absent_from_result(self): + """A quantizer disabled after calibration must not appear in the output.""" + model, forward_loop = _make_quantized_model() + + # Disable one quantizer and record its name + disabled_name = None + for name, module in model.named_modules(): + if isinstance(module, TensorQuantizer) and module._if_quant and module._fake_quant: + module.disable() + disabled_name = name + break + + assert disabled_name is not None, "No enabled quantizer found to disable" + + mse = mtq.compute_quantization_mse(model, forward_loop) + assert disabled_name not in mse + + def test_no_matching_wildcard_returns_empty_dict(self): + """A pattern that matches nothing should return an empty dict.""" + model, forward_loop = _make_quantized_model() + mse = mtq.compute_quantization_mse( + model, forward_loop, wildcards="*nonexistent_quantizer_xyz*" + ) + assert mse == {} + + def test_does_not_modify_model_parameters(self): + """Running MSE measurement must leave model weights unchanged.""" + model, forward_loop = _make_quantized_model() + params_before = {k: v.clone() for k, v in model.named_parameters()} + mtq.compute_quantization_mse(model, forward_loop) + for k, v in model.named_parameters(): + assert torch.equal(v, params_before[k]), f"Parameter {k} was modified" + + def test_hooks_removed_after_call(self): + """All forward hooks registered during the call must be cleaned up.""" + model, forward_loop = _make_quantized_model() + + hooks_before = sum( + len(m._forward_hooks) + for m in model.modules() + if isinstance(m, TensorQuantizer) + ) + mtq.compute_quantization_mse(model, forward_loop) + hooks_after = sum( + len(m._forward_hooks) + for m in model.modules() + if isinstance(m, TensorQuantizer) + ) + assert hooks_after == hooks_before From 905307037de40b18457f257b6003104c76332b3f Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Wed, 4 Mar 2026 21:55:36 +0000 Subject: [PATCH 3/7] fix code format Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- CHANGELOG.rst | 1 + modelopt/torch/quantization/model_quant.py | 11 +++------ .../test_compute_quantization_mse.py | 23 ++++++++----------- 3 files changed, 13 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 4b3ee96fb0..6b47e00e61 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -17,6 +17,7 @@ NVIDIA Model Optimizer Changelog - Add support for rotating the input before quantization for RHT. - Add support for advanced weight scale search for NVFP4 quantization and its export path. - Enable PTQ workflow for Qwen3.5 MoE models. +- Add :meth:`compute_quantization_mse ` API to measure per-quantizer mean-squared quantization error, with flexible wildcard and callable filtering. **Misc** diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index caae93e315..e83d6c6154 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -573,14 +573,11 @@ def compute_quantization_mse( for name, err in sorted(mse.items()): print(f"{name}: {err:.4e}") """ - if isinstance(wildcards, (str, Callable)): + if not isinstance(wildcards, list): wildcards = [wildcards] def _matches(name: str) -> bool: - return any( - fnmatch.fnmatch(name, w) if isinstance(w, str) else w(name) - for w in wildcards - ) + return any(fnmatch.fnmatch(name, w) if isinstance(w, str) else w(name) for w in wildcards) accumulators: dict[str, dict] = {} # name -> {"sum": float, "count": int} hooks = [] @@ -616,7 +613,5 @@ def hook(mod, inp, out): h.remove() return { - name: acc["sum"] / acc["count"] - for name, acc in accumulators.items() - if acc["count"] > 0 + name: acc["sum"] / acc["count"] for name, acc in accumulators.items() if acc["count"] > 0 } diff --git a/tests/unit/torch/quantization/test_compute_quantization_mse.py b/tests/unit/torch/quantization/test_compute_quantization_mse.py index 24ab98b429..18e7ad8f25 100644 --- a/tests/unit/torch/quantization/test_compute_quantization_mse.py +++ b/tests/unit/torch/quantization/test_compute_quantization_mse.py @@ -15,7 +15,6 @@ """Unit tests for mtq.compute_quantization_mse().""" -import pytest import torch from _test_utils.torch.quantization.models import SimpleLinear @@ -68,14 +67,16 @@ def test_wildcard_filters_by_suffix(self): model, forward_loop = _make_quantized_model() mse = mtq.compute_quantization_mse(model, forward_loop, wildcards="*weight_quantizer") assert len(mse) > 0 - assert all("weight_quantizer" in k for k in mse.keys()) + assert all("weight_quantizer" in k for k in mse) # No input quantizers should appear - assert not any("input_quantizer" in k for k in mse.keys()) + assert not any("input_quantizer" in k for k in mse) def test_list_of_wildcards(self): """A list of patterns should return the union of matched quantizers.""" model, forward_loop = _make_quantized_model() - mse_weight = mtq.compute_quantization_mse(model, forward_loop, wildcards="*weight_quantizer") + mse_weight = mtq.compute_quantization_mse( + model, forward_loop, wildcards="*weight_quantizer" + ) mse_input = mtq.compute_quantization_mse(model, forward_loop, wildcards="*input_quantizer") mse_both = mtq.compute_quantization_mse( model, forward_loop, wildcards=["*weight_quantizer", "*input_quantizer"] @@ -86,11 +87,9 @@ def test_callable_filter(self): """A callable wildcard should select quantizers by arbitrary predicate.""" model, forward_loop = _make_quantized_model() # Pick only quantizers belonging to the first linear layer (net.0) - mse = mtq.compute_quantization_mse( - model, forward_loop, wildcards=lambda n: "net.0" in n - ) + mse = mtq.compute_quantization_mse(model, forward_loop, wildcards=lambda n: "net.0" in n) assert len(mse) > 0 - assert all("net.0" in k for k in mse.keys()) + assert all("net.0" in k for k in mse) def test_disabled_quantizer_absent_from_result(self): """A quantizer disabled after calibration must not appear in the output.""" @@ -130,14 +129,10 @@ def test_hooks_removed_after_call(self): model, forward_loop = _make_quantized_model() hooks_before = sum( - len(m._forward_hooks) - for m in model.modules() - if isinstance(m, TensorQuantizer) + len(m._forward_hooks) for m in model.modules() if isinstance(m, TensorQuantizer) ) mtq.compute_quantization_mse(model, forward_loop) hooks_after = sum( - len(m._forward_hooks) - for m in model.modules() - if isinstance(m, TensorQuantizer) + len(m._forward_hooks) for m in model.modules() if isinstance(m, TensorQuantizer) ) assert hooks_after == hooks_before From 287218be8ae5edea8595e1034cfdf57478fbf612 Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Wed, 4 Mar 2026 23:07:44 +0000 Subject: [PATCH 4/7] fix format Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- modelopt/torch/quantization/model_quant.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index e83d6c6154..147b81e444 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -537,6 +537,7 @@ def fold_weight(model: nn.Module, keep_attrs: bool = False): if isinstance(module, QuantModule): module.fold_weight(keep_attrs) + @torch.no_grad() def compute_quantization_mse( model: nn.Module, From 10ea8befcf9c4bd54cbb2b452cef644e2a91ec96 Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Wed, 4 Mar 2026 23:44:30 +0000 Subject: [PATCH 5/7] address comments Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- modelopt/torch/quantization/model_quant.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index 147b81e444..c747a2abf6 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -30,7 +30,7 @@ from modelopt.torch.opt.utils import forward_with_reshard from modelopt.torch.quantization.config import QuantizeConfig from modelopt.torch.quantization.conversion import set_quantizer_by_cfg -from modelopt.torch.utils import atomic_print, print_rank_0 +from modelopt.torch.utils import atomic_print from .algorithms import AutoQuantizeGradientSearcher, AutoQuantizeKLDivSearcher, QuantRecipe from .config import QuantizeAlgoCfgType @@ -588,13 +588,6 @@ def _matches(name: str) -> bool: continue if not _matches(name): continue - if not (module._if_quant and module._fake_quant): - print_rank_0( - f"[compute_quantization_mse] Skipping {name}: " - f"_if_quant={module._if_quant}, _fake_quant={module._fake_quant}" - ) - continue - accumulators[name] = {"sum": 0.0, "count": 0} def _make_hook(acc): @@ -608,10 +601,11 @@ def hook(mod, inp, out): hooks.append(module.register_forward_hook(_make_hook(accumulators[name]))) - forward_loop(model) - - for h in hooks: - h.remove() + try: + forward_loop(model) + finally: + for h in hooks: + h.remove() return { name: acc["sum"] / acc["count"] for name, acc in accumulators.items() if acc["count"] > 0 From 99d0c14f4f85fd551f8d8ca88f916cf0e52189ee Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Thu, 5 Mar 2026 06:15:56 +0000 Subject: [PATCH 6/7] fix Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- modelopt/torch/quantization/model_quant.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index c747a2abf6..615942645d 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -588,6 +588,8 @@ def _matches(name: str) -> bool: continue if not _matches(name): continue + if not (module._if_quant and module._fake_quant): + continue accumulators[name] = {"sum": 0.0, "count": 0} def _make_hook(acc): From 3f418c25dc825a0b1f15cbe915e5f18a9d16f07a Mon Sep 17 00:00:00 2001 From: weimingc <17592131+meenchen@users.noreply.github.com> Date: Thu, 5 Mar 2026 17:48:53 +0000 Subject: [PATCH 7/7] fix test Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com> --- modelopt/torch/quantization/model_quant.py | 2 +- .../unit/torch/quantization/test_compute_quantization_mse.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index 615942645d..1a18d1817d 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -588,7 +588,7 @@ def _matches(name: str) -> bool: continue if not _matches(name): continue - if not (module._if_quant and module._fake_quant): + if not (module._if_quant and module._fake_quant) or module._disabled: continue accumulators[name] = {"sum": 0.0, "count": 0} diff --git a/tests/unit/torch/quantization/test_compute_quantization_mse.py b/tests/unit/torch/quantization/test_compute_quantization_mse.py index 18e7ad8f25..9a9a81a611 100644 --- a/tests/unit/torch/quantization/test_compute_quantization_mse.py +++ b/tests/unit/torch/quantization/test_compute_quantization_mse.py @@ -58,7 +58,10 @@ def test_wildcard_star_covers_all_enabled_fake_quant(self): expected_names = { name for name, module in model.named_modules() - if isinstance(module, TensorQuantizer) and module._if_quant and module._fake_quant + if isinstance(module, TensorQuantizer) + and module._if_quant + and module._fake_quant + and not module._disabled } assert set(mse.keys()) == expected_names