From 1377465c6d5ddfcffcb9ae88daa0a5aa82474368 Mon Sep 17 00:00:00 2001
From: realAsma
Date: Fri, 24 Apr 2026 20:05:29 +0000
Subject: [PATCH 1/2] fix: layerwise calibration backward-compat, recipe split,
 batch-size guard

- config: accept legacy `use_sequential` via AliasChoices on `layerwise` so
  pre-#1251 PTQ checkpoints load; still serializes as `layerwise`
- recipes: split nvfp4_experts_only-fp8_kv into default (no layerwise) and
  _layerwise variants
- hf_ptq: auto batch-size detection not supported with layerwise; default to
  batch_size=1 in that case
- tests: cover alias accept, current-name accept, dump under current name, and
  extra='forbid' still rejecting unknowns

Signed-off-by: realAsma
---
 examples/llm_ptq/hf_ptq.py                   | 24 +++--
 modelopt/torch/quantization/config.py        |  3 +-
 .../ptq/nvfp4_experts_only-fp8_kv.yaml       |  4 +-
 .../nvfp4_experts_only-fp8_kv_layerwise.yaml | 94 +++++++++++++++++++
 .../quantization/test_config_validation.py   | 33 +++++++
 5 files changed, 148 insertions(+), 10 deletions(-)
 create mode 100644 modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv_layerwise.yaml

diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index d660c1de4c..1533db14ed 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -955,6 +955,18 @@ def quantize_main(
     default_pad_token,
     device: torch.device,
 ):
+    # Load the recipe up front so we can detect layerwise calibration before batch-size probing.
+    recipe = None
+    if args.recipe is not None and not args.auto_quantize_bits:
+        print(f"Use recipe {args.recipe} for quantization")
+        recipe = load_recipe(args.recipe)
+        assert isinstance(recipe, ModelOptPTQRecipe), (
+            f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}"
+        )
+
+    recipe_algorithm = recipe.quantize.model_dump().get("algorithm") if recipe else None
+    is_layerwise = isinstance(recipe_algorithm, dict) and recipe_algorithm.get("layerwise", False)
+
     if args.batch_size == 0:
         # For VL models with image-text calibration, skip automatic batch size detection
         # since get_max_batch_size can't handle multimodal inputs
@@ -968,6 +980,11 @@ def quantize_main(
                 "Offline speculative decoding calibration enabled. Using default batch_size=1 for calibration."
             )
             args.batch_size = 1
+        # Layerwise calibration processes one layer at a time; auto batch-size probing runs a
+        # full-model forward which defeats the point and can OOM on very large models.
+        elif is_layerwise:
+            print("Layerwise calibration enabled. Using default batch_size=1 for calibration.")
+            args.batch_size = 1
         else:
             # Calibration/sparsification will actually take much more memory than regular inference
             # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio
@@ -1027,12 +1044,7 @@ def quantize_main(
     else:
         # mono quantization
 
-        if args.recipe is not None:
-            print(f"Use recipe {args.recipe} for quantization")
-            recipe = load_recipe(args.recipe)
-            assert isinstance(recipe, ModelOptPTQRecipe), (
-                f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}"
-            )
+        if recipe is not None:
             quant_cfg = recipe.quantize.model_dump()
 
         else:
diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index 186ff1c7ed..552054f489 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -154,7 +154,7 @@
 import warnings
 from typing import Any, Literal, cast
 
-from pydantic import ValidationInfo, field_validator, model_validator
+from pydantic import AliasChoices, ValidationInfo, field_validator, model_validator
 from typing_extensions import Required, TypedDict
 
 from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
@@ -1219,6 +1219,7 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
 
     layerwise: bool = ModeloptField(
         default=False,
+        validation_alias=AliasChoices("layerwise", "use_sequential"),
         title="Enable layerwise (layer-by-layer) calibration.",
         description=(
             "If True, the calibration algorithm is applied layer by layer. "
diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml
index 220d062232..f3f62f0663 100644
--- a/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml
+++ b/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml
@@ -15,12 +15,10 @@
 
 metadata:
   recipe_type: ptq
-  description: NVFP4 static weight and dynamic activation for expert layers only (W4A4), FP8 KV cache, max layerwise calibration.
+  description: NVFP4 static weight and dynamic activation for expert layers only (W4A4), FP8 KV cache, max calibration.
 quantize:
   algorithm:
     method: max
-    # Max calibration is fast and does not typically need checkpointing.
-    layerwise: true
   quant_cfg:
     - quantizer_name: '*'
       enable: false
diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv_layerwise.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv_layerwise.yaml
new file mode 100644
index 0000000000..220d062232
--- /dev/null
+++ b/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv_layerwise.yaml
@@ -0,0 +1,94 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+metadata:
+  recipe_type: ptq
+  description: NVFP4 static weight and dynamic activation for expert layers only (W4A4), FP8 KV cache, max layerwise calibration.
+quantize:
+  algorithm:
+    method: max
+    # Max calibration is fast and does not typically need checkpointing.
+    layerwise: true
+  quant_cfg:
+    - quantizer_name: '*'
+      enable: false
+    - quantizer_name: '*mlp.experts*weight_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: dynamic
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*mlp.experts*input_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: dynamic
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*block_sparse_moe*weight_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: dynamic
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*block_sparse_moe*input_quantizer'
+      enable: true
+      cfg:
+        block_sizes:
+          -1: 16
+          type: dynamic
+          scale_bits: e4m3
+        num_bits: e2m1
+    - quantizer_name: '*[kv]_bmm_quantizer'
+      enable: true
+      cfg:
+        num_bits: e4m3
+    - quantizer_name: '*block_sparse_moe.gate*'
+      enable: false
+    - quantizer_name: '*linear_attn.conv1d*'
+      enable: false
+    - quantizer_name: '*lm_head*'
+      enable: false
+    - quantizer_name: '*mixer.conv1d*'
+      enable: false
+    - quantizer_name: '*mlp.gate.*'
+      enable: false
+    - quantizer_name: '*mlp.shared_expert_gate.*'
+      enable: false
+    - quantizer_name: '*output_layer*'
+      enable: false
+    - quantizer_name: '*proj_out.*'
+      enable: false
+    - quantizer_name: '*router*'
+      enable: false
+    - quantizer_name: 'output.*'
+      enable: false
+    - parent_class: 'nn.BatchNorm1d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.BatchNorm2d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.BatchNorm3d'
+      quantizer_name: '*'
+      enable: false
+    - parent_class: 'nn.LeakyReLU'
+      quantizer_name: '*'
+      enable: false
diff --git a/tests/unit/torch/quantization/test_config_validation.py b/tests/unit/torch/quantization/test_config_validation.py
index f5b1e576f5..84306dc511 100644
--- a/tests/unit/torch/quantization/test_config_validation.py
+++ b/tests/unit/torch/quantization/test_config_validation.py
@@ -25,6 +25,7 @@
     INT4_AWQ_CFG,
     NVFP4_DEFAULT_CFG,
     W4A8_AWQ_BETA_CFG,
+    MaxCalibConfig,
     QuantizeConfig,
     find_quant_cfg_entry_by_path,
     need_calibration,
@@ -525,3 +526,35 @@ def test_validate_quant_cfg_entries_accepts_valid_cfg(self):
             algorithm="max",
         )
         assert len(cfg.quant_cfg) == 2
+
+
+class TestLayerwiseUseSequentialAlias:
+    """`layerwise` accepts the legacy `use_sequential` name via validation_alias.
+
+    Old PTQ checkpoints serialized the field as `use_sequential` before #1251 renamed
+    it to `layerwise`. AliasChoices lets those checkpoints load without a migration
+    validator while still serializing under the current name.
+ """ + + def test_use_sequential_true_sets_layerwise(self): + cfg = MaxCalibConfig(use_sequential=True) + assert cfg.layerwise is True + + def test_use_sequential_false_sets_layerwise(self): + cfg = MaxCalibConfig(use_sequential=False) + assert cfg.layerwise is False + + def test_layerwise_name_still_accepted(self): + cfg = MaxCalibConfig(layerwise=True) + assert cfg.layerwise is True + + def test_serializes_under_current_name(self): + """Dump must use `layerwise`, not the legacy alias.""" + dumped = MaxCalibConfig(use_sequential=True).model_dump() + assert dumped["layerwise"] is True + assert "use_sequential" not in dumped + + def test_unknown_field_still_rejected(self): + """extra='forbid' must still reject unrelated unknown fields.""" + with pytest.raises(ValidationError): + MaxCalibConfig(not_a_real_field=True) From f78ac5046ffd884e8d900e5ce5b568dc57836f0f Mon Sep 17 00:00:00 2001 From: realAsma Date: Fri, 24 Apr 2026 22:06:33 +0000 Subject: [PATCH 2/2] fix: address review feedback on layerwise detection + header + input validation - examples/llm_ptq/hf_ptq.py: replace dict-inspection layerwise detection with a small recursive helper accepting ModelOptPTQRecipe directly, handling list-form QuantizeAlgoCfgType (per coderabbitai, jenchen13). - examples/llm_ptq/hf_ptq.py: convert recipe-type assert to explicit if/raise TypeError so validation is not stripped under python -O (per cjluo-nv). - modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv_layerwise.yaml: bump new-file copyright header to 2026 per LICENSE_HEADER (per cjluo-nv). Signed-off-by: realAsma --- examples/llm_ptq/hf_ptq.py | 17 ++++++++++++----- .../nvfp4_experts_only-fp8_kv_layerwise.yaml | 2 +- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 1533db14ed..8b732d510f 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -960,12 +960,19 @@ def quantize_main( if args.recipe is not None and not args.auto_quantize_bits: print(f"Use recipe {args.recipe} for quantization") recipe = load_recipe(args.recipe) - assert isinstance(recipe, ModelOptPTQRecipe), ( - f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}" - ) + if not isinstance(recipe, ModelOptPTQRecipe): + raise TypeError( + f"Expected PTQ recipe, but got {type(recipe).__name__} from {args.recipe}" + ) + + def _is_layerwise(obj): + if isinstance(obj, ModelOptPTQRecipe): + return _is_layerwise(obj.quantize.algorithm) + if isinstance(obj, list): + return any(_is_layerwise(a) for a in obj) + return bool(getattr(obj, "layerwise", False)) - recipe_algorithm = recipe.quantize.model_dump().get("algorithm") if recipe else None - is_layerwise = isinstance(recipe_algorithm, dict) and recipe_algorithm.get("layerwise", False) + is_layerwise = _is_layerwise(recipe) if args.batch_size == 0: # For VL models with image-text calibration, skip automatic batch size detection diff --git a/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv_layerwise.yaml b/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv_layerwise.yaml index 220d062232..62e46c3f66 100644 --- a/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv_layerwise.yaml +++ b/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv_layerwise.yaml @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
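
Aside: the alias round-trip the config change and tests rely on can be
reproduced with a minimal standalone pydantic v2 model. This is a sketch under
that assumption, not ModelOpt code; `CalibConfig` here is a stand-in for
MaxCalibConfig:

    from pydantic import AliasChoices, BaseModel, ConfigDict, Field

    class CalibConfig(BaseModel):
        # extra="forbid" mirrors the base-config behavior the last test
        # guards: unrelated unknown keys are still rejected.
        model_config = ConfigDict(extra="forbid")

        # Once validation_alias is set, the bare field name is accepted only
        # because it is listed among the choices; hence both names appear.
        layerwise: bool = Field(
            default=False,
            validation_alias=AliasChoices("layerwise", "use_sequential"),
        )

    assert CalibConfig(use_sequential=True).layerwise is True  # legacy checkpoint key
    assert CalibConfig(layerwise=True).layerwise is True       # current name
    assert "use_sequential" not in CalibConfig(use_sequential=True).model_dump()

Because only validation_alias is set (no serialization_alias), model_dump()
always emits the field name, so configs saved after this change carry the
current `layerwise` spelling.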