From 2a0c85283253ef477fe4d34e6f88fa5cf62b4f45 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Tue, 28 Apr 2026 11:31:11 -0700 Subject: [PATCH 01/15] add nemotron super 4 nvfp4 recipe Signed-off-by: Jennifer Chen --- .../super-nvfp4.yaml | 124 ++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml new file mode 100644 index 00000000000..bde26519bfb --- /dev/null +++ b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json: +# - MoE routed experts (mixer.experts..{up,down}_proj): NVFP4 W4A4 weight MSE, group_size 16 +# - MoE shared experts (mixer.shared_experts.{up,down}_proj): FP8 per-tensor +# - Mamba mixer linears (mixer.{in,out}_proj): FP8 per-tensor +# - KV cache: FP8 +# - Attention linears ({q,k,v}_proj): BF16 (not quantized) +# - MTP head, lm_head, output, mamba conv1d: BF16 (not quantized) +# - Latent MOE (fc1_latent_proj, fc2_latent_proj): BF16 (not quantized) +# - SSM cache: FP32 (can be set to FP16 in VLLM) +# +# Calibration: weight MSE with FP8-scale sweep over the 128 e4m3 scale values +# (NVFP4 weights use static block scales selected by MSE; FP8 per-tensor scales +# are also chosen via MSE search instead of plain amax). +metadata: + recipe_type: ptq + description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj, and attention o_proj/fc1_latent_proj/fc2_latent_proj + FP8 per-tensor; FP8 KV cache; lm_head/MTP/SSM stay BF16/FP16. Weight-MSE calibration with FP8 scale sweep. +quantize: + algorithm: + method: mse + fp8_scale_sweep: true + quant_cfg: + - quantizer_name: '*' + enable: false + + # MoE routed experts -> NVFP4 W4A4, block_size 16, e4m3 scale. + # Weight uses static block scales (chosen by MSE); activations stay dynamic. + - quantizer_name: '*mixer.experts.*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: static + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mixer.experts.*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + + # MoE shared experts -> FP8 per-tensor. + - quantizer_name: '*mixer.shared_experts.*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.shared_experts.*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # Mamba mixer linears -> FP8 per-tensor. + - quantizer_name: '*mixer.in_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.in_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # latent MOE down/up projections) -> FP8 per-tensor. + # NOTE: only 3 layers quantized latent MOE to FP8, layers 1, 3, 5 + - quantizer_name: '*mixer.fc1_latent_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.fc1_latent_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.fc2_latent_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.fc2_latent_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # KV cache -> FP8. + - quantizer_name: '*[kv]_bmm_quantizer' + enable: true + cfg: + num_bits: e4m3 + + # Stay BF16: lm_head, output projection, MoE routers/gates, MTP head. + # SSM state / mamba conv1d stay FP16. From 9f96df3181a5840fe9818de280bec27e59191ee4 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Tue, 28 Apr 2026 11:35:38 -0700 Subject: [PATCH 02/15] remove latent moe fp8 Signed-off-by: Jennifer Chen --- .../super-nvfp4.yaml | 23 ------------------- 1 file changed, 23 deletions(-) diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml index bde26519bfb..7a480e971f1 100644 --- a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml +++ b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml @@ -91,29 +91,6 @@ quantize: num_bits: e4m3 axis: - # latent MOE down/up projections) -> FP8 per-tensor. - # NOTE: only 3 layers quantized latent MOE to FP8, layers 1, 3, 5 - - quantizer_name: '*mixer.fc1_latent_proj*weight_quantizer' - enable: true - cfg: - num_bits: e4m3 - axis: - - quantizer_name: '*mixer.fc1_latent_proj*input_quantizer' - enable: true - cfg: - num_bits: e4m3 - axis: - - quantizer_name: '*mixer.fc2_latent_proj*weight_quantizer' - enable: true - cfg: - num_bits: e4m3 - axis: - - quantizer_name: '*mixer.fc2_latent_proj*input_quantizer' - enable: true - cfg: - num_bits: e4m3 - axis: - # KV cache -> FP8. - quantizer_name: '*[kv]_bmm_quantizer' enable: true From 9282cdb240f451bb82d296bf8762cf64f0e7fcee Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Tue, 28 Apr 2026 11:54:29 -0700 Subject: [PATCH 03/15] fix docstring Signed-off-by: Jennifer Chen --- .../Nemotron-3-Super-120B-A12B/super-nvfp4.yaml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml index 7a480e971f1..679ebf82a44 100644 --- a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml +++ b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,15 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json: +# Approximately mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json: # - MoE routed experts (mixer.experts..{up,down}_proj): NVFP4 W4A4 weight MSE, group_size 16 # - MoE shared experts (mixer.shared_experts.{up,down}_proj): FP8 per-tensor # - Mamba mixer linears (mixer.{in,out}_proj): FP8 per-tensor # - KV cache: FP8 # - Attention linears ({q,k,v}_proj): BF16 (not quantized) -# - MTP head, lm_head, output, mamba conv1d: BF16 (not quantized) -# - Latent MOE (fc1_latent_proj, fc2_latent_proj): BF16 (not quantized) -# - SSM cache: FP32 (can be set to FP16 in VLLM) +# - MTP head, lm_head, output, mamba conv1d: BF16 (not quantized) +# - Latent MOE (fc1_latent_proj, fc2_latent_proj): BF16 (not quantized) +# - SSM cache: FP32 (can be set to FP16 in VLLM) # # Calibration: weight MSE with FP8-scale sweep over the 128 e4m3 scale values # (NVFP4 weights use static block scales selected by MSE; FP8 per-tensor scales @@ -35,6 +35,8 @@ quantize: method: mse fp8_scale_sweep: true quant_cfg: + # Disable all layers by default so that these layers stay in their original precision: BF16/FP32: + # lm_head, output projection, MoE routers/gates, MTP head, SSM state, mamba conv1d. - quantizer_name: '*' enable: false @@ -97,5 +99,3 @@ quantize: cfg: num_bits: e4m3 - # Stay BF16: lm_head, output projection, MoE routers/gates, MTP head. - # SSM state / mamba conv1d stay FP16. From cfaf05504a301a8849bbb9d05aab89fdac08f849 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Wed, 29 Apr 2026 12:09:32 -0700 Subject: [PATCH 04/15] fix MSE moe calibration and add stride for fp8 scale sweep Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/calib/mse.py | 7 +++ modelopt/torch/quantization/config.py | 8 +++ modelopt/torch/quantization/model_calib.py | 68 ++++++++++++++-------- 3 files changed, 60 insertions(+), 23 deletions(-) diff --git a/modelopt/torch/quantization/calib/mse.py b/modelopt/torch/quantization/calib/mse.py index 1f439a7e778..a3f026e5f03 100644 --- a/modelopt/torch/quantization/calib/mse.py +++ b/modelopt/torch/quantization/calib/mse.py @@ -181,10 +181,12 @@ def __init__( axis: int | tuple | list | None = None, quant_func: Callable | None = None, error_func: Callable | None = None, + fp8_scale_sweep_stride: int = 1, ): """Initialize NVFP4 MSE calibrator with per-block and global amax.""" super().__init__(amax=amax, axis=axis, quant_func=quant_func, error_func=error_func) self._global_amax = global_amax + self._fp8_scale_sweep_stride = max(1, fp8_scale_sweep_stride or 1) def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor: if candidates.ndim != 0: # Called during final compute amax @@ -197,4 +199,9 @@ def _generate_candidates(self, device: torch.device) -> torch.Tensor: fp8_values = uint8_values.view(torch.float8_e4m3fn).float() valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) fp8_values = fp8_values[valid_mask] + if self._fp8_scale_sweep_stride > 1: + candidates = fp8_values[:: self._fp8_scale_sweep_stride] + if candidates[-1] != fp8_values[-1]: + candidates = torch.cat([candidates, fp8_values[-1:]]) + fp8_values = candidates return fp8_values / 448.0 diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 34e7f692ca0..6f8e31bfdd0 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1327,6 +1327,14 @@ class MseCalibConfig(QuantizeAlgorithmConfig): "start_multiplier, and stop_multiplier are ignored.", ) + fp8_scale_sweep_stride: int | None = ModeloptField( + default=1, + ge=1, + title="Stride for FP8 scale sweep candidates.", + description="Subsample every Nth valid FP8 E4M3 scale candidate when fp8_scale_sweep is True. " + "A value of 1 preserves the exhaustive sweep.", + ) + distributed_sync: bool | None = ModeloptField( default=True, title="Whether to sync the amax across the distributed processes.", diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 0c2033041d6..bed04104752 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -110,26 +110,44 @@ def _has_expert_parallelism(module: nn.Module) -> bool: return ps is not None and ps.expert_model_parallel_group.is_initialized() -def _check_moe_calibration_complete(quantizer, parallel_state): - """Raise error if MoE calibration is incomplete (some ranks have amax, others don't).""" +def _is_dynamic_block_quantizer(quantizer) -> bool: + block_sizes = getattr(quantizer, "block_sizes", None) + if isinstance(block_sizes, dict): + return block_sizes.get("type") == "dynamic" + return getattr(block_sizes, "type", None) == "dynamic" + + +def _iter_leaf_quantizers(quantizer): if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: - _check_moe_calibration_complete(_q, parallel_state) + yield from _iter_leaf_quantizers(_q) return - for group in [ - parallel_state.data_parallel_group, - parallel_state.expert_model_parallel_group, - parallel_state.tensor_parallel_group, - ]: - if not group.is_initialized(): + yield quantizer + + +def _check_moe_calibration_complete(quantizer, parallel_state): + """Raise error if MoE calibration is incomplete across distributed MoE ranks.""" + for leaf_quantizer in _iter_leaf_quantizers(quantizer): + if _is_dynamic_block_quantizer(leaf_quantizer): continue - has_amax = getattr(quantizer, "_amax", None) is not None - amax_states = DistributedProcessGroup.get_dist_syncd_obj(has_amax, group, lambda objs: objs) - if any(amax_states) and not all(amax_states): - raise RuntimeError( - "MoE calibration incomplete: some experts received no tokens during calibration. " - "Increase --calib-size to ensure all experts see calibration data." + + has_amax = getattr(leaf_quantizer, "_amax", None) is not None + for group in [ + parallel_state.data_parallel_group, + parallel_state.expert_model_parallel_group, + parallel_state.tensor_parallel_group, + ]: + if not group.is_initialized(): + continue + amax_states = DistributedProcessGroup.get_dist_syncd_obj( + has_amax, group, lambda objs: objs ) + if any(amax_states) and not all(amax_states): + raise RuntimeError( + "MoE calibration incomplete: some experts received no tokens during " + "calibration. Increase --calib-size to ensure all experts see calibration " + "data." + ) @torch.no_grad() @@ -175,13 +193,13 @@ def max_calibrate( def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): """Synchronize the amax across all ranks in the data parallel and expert parallel groups.""" - if isinstance(quantizer, SequentialQuantizer): - for _q in quantizer: - sync_quantizer_amax_across_dp_ep(_q, parallel_state) - return - if getattr(quantizer, "_amax", None) is not None: - quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) - quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) + for leaf_quantizer in _iter_leaf_quantizers(quantizer): + if _is_dynamic_block_quantizer(leaf_quantizer): + continue + leaf_quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) + leaf_quantizer.sync_amax_across_distributed_group( + parallel_state.expert_model_parallel_group + ) # TODO: create sync_bias_across_distributed_group # Step 2:Sync amax across data parallelism @@ -226,7 +244,7 @@ def sync_quantizer_amax_across_tp( ) # Skip amax sync for INT4 / W4A8 block quantization # Sync amax for NVFP4 (dynamic per-block, static per-tensor quantized scale) - if getattr(quantizer.block_sizes, "type", None) == "dynamic": + if _is_dynamic_block_quantizer(quantizer): return if quantizer.axis in axes_for_sync and quantizer.amax is not None: @@ -314,6 +332,7 @@ def mse_calibrate( start_multiplier: float = 0.25, stop_multiplier: float = 4.0, fp8_scale_sweep: bool = False, + fp8_scale_sweep_stride: int = 1, ): """Calibrate the model using MSE-based amax search. @@ -333,6 +352,8 @@ def mse_calibrate( for NVFP4 per-block quantization instead of using multipliers. This is specifically designed for optimizing the FP8-quantized per-block scales in NVFP4 format (default: False). + fp8_scale_sweep_stride: Subsample every Nth FP8 E4M3 candidate when + fp8_scale_sweep is enabled. A value of 1 preserves exhaustive sweep. See :class:`MseCalibConfig ` for details on the remaining arguments. @@ -388,6 +409,7 @@ def mse_calibrate( axis=module._calibrator._axis, global_amax=module.global_amax, quant_func=partial(_mse_quant_func, quantizer=module), + fp8_scale_sweep_stride=fp8_scale_sweep_stride, ) continue From f7961977d04e0695fc63873405452f84d0ea0dba Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Wed, 29 Apr 2026 12:10:59 -0700 Subject: [PATCH 05/15] fix MLM naming in recipe and add stride recipe Signed-off-by: Jennifer Chen --- .../super-nvfp4-fp8-sweep-stride4.yaml | 135 ++++++++++++++++++ .../super-nvfp4.yaml | 40 ++++++ 2 files changed, 175 insertions(+) create mode 100644 modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-fp8-sweep-stride4.yaml diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-fp8-sweep-stride4.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-fp8-sweep-stride4.yaml new file mode 100644 index 00000000000..2cbf38e50b5 --- /dev/null +++ b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-fp8-sweep-stride4.yaml @@ -0,0 +1,135 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json: +# - MoE routed experts: NVFP4 W4A4 weight MSE, group_size 16 +# HF names: mixer.experts..{up,down}_proj +# Megatron-Core names: mlp.experts.local_experts..linear_fc{1,2} +# - MoE shared experts: FP8 per-tensor +# HF names: mixer.shared_experts.{up,down}_proj +# Megatron-Core names: mlp.shared_experts.linear_fc{1,2} +# - Mamba mixer linears (mixer.{in,out}_proj): FP8 per-tensor +# - KV cache: FP8 +# - Attention linears ({q,k,v}_proj): BF16 (not quantized) +# - MTP head, lm_head, output, mamba conv1d: BF16 (not quantized) +# - Latent MOE (fc1_latent_proj, fc2_latent_proj): BF16 (not quantized) +# - SSM cache: FP32 (can be set to FP16 in VLLM) +# +# Calibration: weight MSE with a stride-4 FP8-scale sweep over the e4m3 scale +# values. This keeps the FP8 static-scale path but uses a coarser candidate set. +metadata: + recipe_type: ptq + description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj, and attention o_proj/fc1_latent_proj/fc2_latent_proj + FP8 per-tensor; FP8 KV cache; lm_head/MTP/SSM stay BF16/FP16. Weight-MSE calibration with stride-4 FP8 scale sweep. +quantize: + algorithm: + method: mse + fp8_scale_sweep: true + fp8_scale_sweep_stride: 4 + quant_cfg: + - quantizer_name: '*' + enable: false + + # MoE routed experts -> NVFP4 W4A4, block_size 16, e4m3 scale. + # Weight uses static block scales (chosen by MSE); activations stay dynamic. + # HF/export names: backbone.layers.*.mixer.experts.*.{up,down}_proj. + - quantizer_name: '*mixer.experts.*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: static + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mixer.experts.*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + # Megatron-Core/PTQ names: decoder.layers.*.mlp.experts.local_experts.*.linear_fc{1,2}. + - quantizer_name: '*mlp.experts*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: static + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mlp.experts*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + + # MoE shared experts -> FP8 per-tensor. + # HF/export names: backbone.layers.*.mixer.shared_experts.{up,down}_proj. + - quantizer_name: '*mixer.shared_experts.*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.shared_experts.*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + # Megatron-Core/PTQ names: decoder.layers.*.mlp.shared_experts.linear_fc{1,2}. + - quantizer_name: '*mlp.shared_experts*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mlp.shared_experts*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # Mamba mixer linears -> FP8 per-tensor. + - quantizer_name: '*mixer.in_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.in_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # KV cache -> FP8. + - quantizer_name: '*[kv]_bmm_quantizer' + enable: true + cfg: + num_bits: e4m3 + + # Stay BF16: lm_head, output projection, MoE routers/gates, MTP head. + # SSM state / mamba conv1d stay FP16. diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml index 679ebf82a44..549c3fa6caf 100644 --- a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml +++ b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml @@ -13,9 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +<<<<<<< Updated upstream # Approximately mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json: # - MoE routed experts (mixer.experts..{up,down}_proj): NVFP4 W4A4 weight MSE, group_size 16 # - MoE shared experts (mixer.shared_experts.{up,down}_proj): FP8 per-tensor +======= +# Mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json: +# - MoE routed experts: NVFP4 W4A4 weight MSE, group_size 16 +# HF names: mixer.experts..{up,down}_proj +# Megatron-Core names: mlp.experts.local_experts..linear_fc{1,2} +# - MoE shared experts: FP8 per-tensor +# HF names: mixer.shared_experts.{up,down}_proj +# Megatron-Core names: mlp.shared_experts.linear_fc{1,2} +>>>>>>> Stashed changes # - Mamba mixer linears (mixer.{in,out}_proj): FP8 per-tensor # - KV cache: FP8 # - Attention linears ({q,k,v}_proj): BF16 (not quantized) @@ -42,6 +52,7 @@ quantize: # MoE routed experts -> NVFP4 W4A4, block_size 16, e4m3 scale. # Weight uses static block scales (chosen by MSE); activations stay dynamic. + # HF/export names: backbone.layers.*.mixer.experts.*.{up,down}_proj. - quantizer_name: '*mixer.experts.*weight_quantizer' enable: true cfg: @@ -58,8 +69,26 @@ quantize: type: dynamic scale_bits: e4m3 num_bits: e2m1 + # Megatron-Core/PTQ names: decoder.layers.*.mlp.experts.local_experts.*.linear_fc{1,2}. + - quantizer_name: '*mlp.experts*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: static + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mlp.experts*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 # MoE shared experts -> FP8 per-tensor. + # HF/export names: backbone.layers.*.mixer.shared_experts.{up,down}_proj. - quantizer_name: '*mixer.shared_experts.*weight_quantizer' enable: true cfg: @@ -70,6 +99,17 @@ quantize: cfg: num_bits: e4m3 axis: + # Megatron-Core/PTQ names: decoder.layers.*.mlp.shared_experts.linear_fc{1,2}. + - quantizer_name: '*mlp.shared_experts*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mlp.shared_experts*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: # Mamba mixer linears -> FP8 per-tensor. - quantizer_name: '*mixer.in_proj*weight_quantizer' From 81d9d87a49d7a6671e9cd63b6d3d3149c0bf9d81 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Wed, 29 Apr 2026 13:04:32 -0700 Subject: [PATCH 06/15] fix merge conflict Signed-off-by: Jennifer Chen --- .../super-nvfp4-fp8-sweep-stride4.yaml | 2 +- .../models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml | 9 +-------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-fp8-sweep-stride4.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-fp8-sweep-stride4.yaml index 2cbf38e50b5..efd07f7762a 100644 --- a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-fp8-sweep-stride4.yaml +++ b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-fp8-sweep-stride4.yaml @@ -31,7 +31,7 @@ # values. This keeps the FP8 static-scale path but uses a coarser candidate set. metadata: recipe_type: ptq - description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj, and attention o_proj/fc1_latent_proj/fc2_latent_proj + description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj, and Latent MOE fc1_latent_proj/fc2_latent_proj FP8 per-tensor; FP8 KV cache; lm_head/MTP/SSM stay BF16/FP16. Weight-MSE calibration with stride-4 FP8 scale sweep. quantize: algorithm: diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml index 549c3fa6caf..5ab59b69cd4 100644 --- a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml +++ b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml @@ -13,11 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -<<<<<<< Updated upstream -# Approximately mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json: -# - MoE routed experts (mixer.experts..{up,down}_proj): NVFP4 W4A4 weight MSE, group_size 16 -# - MoE shared experts (mixer.shared_experts.{up,down}_proj): FP8 per-tensor -======= # Mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json: # - MoE routed experts: NVFP4 W4A4 weight MSE, group_size 16 # HF names: mixer.experts..{up,down}_proj @@ -25,7 +20,6 @@ # - MoE shared experts: FP8 per-tensor # HF names: mixer.shared_experts.{up,down}_proj # Megatron-Core names: mlp.shared_experts.linear_fc{1,2} ->>>>>>> Stashed changes # - Mamba mixer linears (mixer.{in,out}_proj): FP8 per-tensor # - KV cache: FP8 # - Attention linears ({q,k,v}_proj): BF16 (not quantized) @@ -38,7 +32,7 @@ # are also chosen via MSE search instead of plain amax). metadata: recipe_type: ptq - description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj, and attention o_proj/fc1_latent_proj/fc2_latent_proj + description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj, and Latent MOE fc1_latent_proj/fc2_latent_proj FP8 per-tensor; FP8 KV cache; lm_head/MTP/SSM stay BF16/FP16. Weight-MSE calibration with FP8 scale sweep. quantize: algorithm: @@ -138,4 +132,3 @@ quantize: enable: true cfg: num_bits: e4m3 - From acf6892fb80d32ae394e914ce6d8abea82f5689b Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Wed, 29 Apr 2026 13:35:23 -0700 Subject: [PATCH 07/15] amax recipe Signed-off-by: Jennifer Chen --- .../super-nvfp4-amax.yaml | 133 ++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-amax.yaml diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-amax.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-amax.yaml new file mode 100644 index 00000000000..170738458e6 --- /dev/null +++ b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-amax.yaml @@ -0,0 +1,133 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json: +# - MoE routed experts: NVFP4 W4A4 weight MSE, group_size 16 +# HF names: mixer.experts..{up,down}_proj +# Megatron-Core names: mlp.experts.local_experts..linear_fc{1,2} +# - MoE shared experts: FP8 per-tensor +# HF names: mixer.shared_experts.{up,down}_proj +# Megatron-Core names: mlp.shared_experts.linear_fc{1,2} +# - Mamba mixer linears (mixer.{in,out}_proj): FP8 per-tensor +# - KV cache: FP8 +# - Attention linears ({q,k,v}_proj): BF16 (not quantized) +# - MTP head, lm_head, output, mamba conv1d: BF16 (not quantized) +# - Latent MOE (fc1_latent_proj, fc2_latent_proj): BF16 (not quantized) +# - SSM cache: FP32 (can be set to FP16 in VLLM) +# +# Calibration: amax/max calibration comparison variant. This skips MSE weight +# scale search and uses max calibration for enabled quantizers. +metadata: + recipe_type: ptq + description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj, and Latent MOE fc1_latent_proj/fc2_latent_proj + FP8 per-tensor; FP8 KV cache; lm_head/MTP/SSM stay BF16/FP16. Amax calibration comparison variant. +quantize: + algorithm: + method: max + quant_cfg: + - quantizer_name: '*' + enable: false + + # MoE routed experts -> NVFP4 W4A4, block_size 16, e4m3 scale. + # Weight uses static block scales (chosen by MSE); activations stay dynamic. + # HF/export names: backbone.layers.*.mixer.experts.*.{up,down}_proj. + - quantizer_name: '*mixer.experts.*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: static + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mixer.experts.*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + # Megatron-Core/PTQ names: decoder.layers.*.mlp.experts.local_experts.*.linear_fc{1,2}. + - quantizer_name: '*mlp.experts*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: static + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mlp.experts*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + + # MoE shared experts -> FP8 per-tensor. + # HF/export names: backbone.layers.*.mixer.shared_experts.{up,down}_proj. + - quantizer_name: '*mixer.shared_experts.*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.shared_experts.*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + # Megatron-Core/PTQ names: decoder.layers.*.mlp.shared_experts.linear_fc{1,2}. + - quantizer_name: '*mlp.shared_experts*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mlp.shared_experts*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # Mamba mixer linears -> FP8 per-tensor. + - quantizer_name: '*mixer.in_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.in_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # KV cache -> FP8. + - quantizer_name: '*[kv]_bmm_quantizer' + enable: true + cfg: + num_bits: e4m3 + + # Stay BF16: lm_head, output projection, MoE routers/gates, MTP head. + # SSM state / mamba conv1d stay FP16. From 372820e8a3ee13795564dd414a922907c0a9d83b Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Wed, 29 Apr 2026 13:36:44 -0700 Subject: [PATCH 08/15] fix config Signed-off-by: Jennifer Chen --- .../{super-nvfp4-amax.yaml => super-nvfp4-max-calib.yaml} | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) rename modelopt_recipes/models/Nemotron-3-Super-120B-A12B/{super-nvfp4-amax.yaml => super-nvfp4-max-calib.yaml} (97%) diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-amax.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml similarity index 97% rename from modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-amax.yaml rename to modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml index 170738458e6..4f0f54ccb2f 100644 --- a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-amax.yaml +++ b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml @@ -14,7 +14,8 @@ # limitations under the License. # Mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json: -# - MoE routed experts: NVFP4 W4A4 weight MSE, group_size 16 +# but with ONE major difference: use max calibration instead of MSE +# - MoE routed experts: NVFP4 W4A4 weight, group_size 16 # HF names: mixer.experts..{up,down}_proj # Megatron-Core names: mlp.experts.local_experts..linear_fc{1,2} # - MoE shared experts: FP8 per-tensor From c635b8a7d944676e50c8efda9c2c7ccd0b6d12be Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Wed, 29 Apr 2026 18:36:44 -0700 Subject: [PATCH 09/15] add amax recipe Signed-off-by: Jennifer Chen --- .../super-nvfp4-amax.yaml | 134 ++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-amax.yaml diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-amax.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-amax.yaml new file mode 100644 index 00000000000..8ae6f686930 --- /dev/null +++ b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-amax.yaml @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json: +# but with ONE major difference: use max calibration instead of MSE. +# - MoE routed experts: NVFP4 W4A4 weight, group_size 16 +# HF names: mixer.experts..{up,down}_proj +# Megatron-Core names: mlp.experts.local_experts..linear_fc{1,2} +# - MoE shared experts: FP8 per-tensor +# HF names: mixer.shared_experts.{up,down}_proj +# Megatron-Core names: mlp.shared_experts.linear_fc{1,2} +# - Mamba mixer linears (mixer.{in,out}_proj): FP8 per-tensor +# - KV cache: FP8 +# - Attention linears ({q,k,v}_proj): BF16 (not quantized) +# - MTP head, lm_head, output, mamba conv1d: BF16 (not quantized) +# - Latent MOE (fc1_latent_proj, fc2_latent_proj): BF16 (not quantized) +# - SSM cache: FP32 (can be set to FP16 in VLLM) +# +# Calibration: amax/max calibration comparison variant. This skips MSE weight +# scale search and uses max calibration for enabled quantizers. +metadata: + recipe_type: ptq + description: Super NVFP4 mixed precision - sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj, and Latent MOE fc1_latent_proj/fc2_latent_proj + FP8 per-tensor; FP8 KV cache; lm_head/MTP/SSM stay BF16/FP16. Amax calibration comparison variant. +quantize: + algorithm: + method: max + quant_cfg: + - quantizer_name: '*' + enable: false + + # MoE routed experts -> NVFP4 W4A4, block_size 16, e4m3 scale. + # Amax/max comparison uses dynamic block scales for both weight and activation. + # HF/export names: backbone.layers.*.mixer.experts.*.{up,down}_proj. + - quantizer_name: '*mixer.experts.*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mixer.experts.*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + # Megatron-Core/PTQ names: decoder.layers.*.mlp.experts.local_experts.*.linear_fc{1,2}. + - quantizer_name: '*mlp.experts*weight_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + - quantizer_name: '*mlp.experts*input_quantizer' + enable: true + cfg: + block_sizes: + -1: 16 + type: dynamic + scale_bits: e4m3 + num_bits: e2m1 + + # MoE shared experts -> FP8 per-tensor. + # HF/export names: backbone.layers.*.mixer.shared_experts.{up,down}_proj. + - quantizer_name: '*mixer.shared_experts.*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.shared_experts.*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + # Megatron-Core/PTQ names: decoder.layers.*.mlp.shared_experts.linear_fc{1,2}. + - quantizer_name: '*mlp.shared_experts*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mlp.shared_experts*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # Mamba mixer linears -> FP8 per-tensor. + - quantizer_name: '*mixer.in_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.in_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*weight_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + - quantizer_name: '*mixer.out_proj*input_quantizer' + enable: true + cfg: + num_bits: e4m3 + axis: + + # KV cache -> FP8. + - quantizer_name: '*[kv]_bmm_quantizer' + enable: true + cfg: + num_bits: e4m3 + + # Stay BF16: lm_head, output projection, MoE routers/gates, MTP head. + # SSM state / mamba conv1d stay FP16. From a829722df13fcd47717dbcb7e74ab3362a9bb788 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Fri, 1 May 2026 12:29:12 -0700 Subject: [PATCH 10/15] mixed precision export for megatron Signed-off-by: Jennifer Chen --- .../torch/export/unified_export_megatron.py | 89 ++++++++++++++++--- .../export/test_unified_export_megatron.py | 59 +++++++++--- 2 files changed, 120 insertions(+), 28 deletions(-) diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index 23b8cfd1630..862e2031e27 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -61,6 +61,7 @@ get_weight_block_size, get_weight_scaling_factor, get_weight_scaling_factor_2, + process_layer_quant_config, to_quantized_weight, ) @@ -169,6 +170,7 @@ def __init__( self.all_rules = self._populate_rule_book() self.rules = self.all_rules[self.arch] self.exclude_modules = [] + self.layer_config_dict = {} if not hasattr(model, "_modelopt_state"): return @@ -324,22 +326,32 @@ def save_pretrained( print(f"Successfully loaded {len(mtp_state_dict)} MTP tensors") combined_exclude_modules = self._gather_exclude_modules() + combined_layer_config_dict = self._gather_layer_config_dict() if is_last_stage_main_rank and quantization is not None: - self._hf_quant_config = { + if combined_layer_config_dict: + quantization_config = process_layer_quant_config(combined_layer_config_dict) + quantization_config["exclude_modules"] = combined_exclude_modules + else: + quantization_config = { + "quant_algo": quantization, + "exclude_modules": combined_exclude_modules, + } + if quantization == "NVFP4": # update block size + quantization_config["group_size"] = 16 + + if hasattr(self, "kv_cache_dtype"): + quantization_config["kv_cache_quant_algo"] = self.kv_cache_dtype + + raw_hf_quant_config = { "producer": { "name": "modelopt", "version": __version__, }, - "quantization": { - "quant_algo": quantization, - "exclude_modules": combined_exclude_modules, - }, + "quantization": quantization_config, } - if quantization == "NVFP4": # update block size - self._hf_quant_config["quantization"]["group_size"] = 16 - if hasattr(self, "kv_cache_dtype"): - self._hf_quant_config["quantization"]["kv_cache_quant_algo"] = self.kv_cache_dtype + # Use one serving-facing config for both hf_quant_config.json and config.json. + self._hf_quant_config = convert_hf_quant_config_format(raw_hf_quant_config) with open(save_directory + "/hf_quant_config.json", "w") as f: json.dump(self._hf_quant_config, f, indent=4) @@ -359,10 +371,9 @@ def save_pretrained( # Newer versions of VLLM expect config.json with hf_quant_config config_json_file = save_directory + "/config.json" if self._hf_quant_config and os.path.exists(config_json_file): - converted_quant_config = convert_hf_quant_config_format(self._hf_quant_config) with open(config_json_file) as f: config_dict = json.load(f) - config_dict["quantization_config"] = converted_quant_config + config_dict["quantization_config"] = self._hf_quant_config with open(config_json_file, "w") as f: json.dump(config_dict, f, indent=4) @@ -803,9 +814,7 @@ def _get_quantized_state( name_to_value = {} qformat: str = self._get_quantization_format(module) if qformat is None and "norm" not in prefix: - # Add exclude layers for hf_quant_config. Note that if the prefix is not an empty - # string then it usually ends with "." which needs to be removed. - self.exclude_modules.append(prefix.removesuffix(".")) + self._record_excluded_module(prefix) block_size = get_weight_block_size(module) name_to_value = self._get_weight_bias(module, dtype, name_to_value) @@ -850,6 +859,27 @@ def _get_weight_scales(self, quantized_state: dict[str, Any], qformat: str): return weight_scale, weight_scale_2 + def _record_layer_quant_config(self, prefix: str, qformat: str | None, block_size: int): + """Record per-HF-layer quantization metadata for mixed precision exports.""" + if qformat in (None, QUANTIZATION_NONE): + return + + layer_name = prefix.removesuffix(".") + if "{" in layer_name or not layer_name: + return + + self.layer_config_dict[layer_name + ".quantization"] = qformat + self.layer_config_dict[layer_name + ".awq_block_size"] = block_size + + def _record_excluded_module(self, prefix: str): + """Record an unquantized HF module prefix for hf_quant_config.""" + layer_name = prefix.removesuffix(".") + if "{" in layer_name or not layer_name: + return + + if layer_name not in self.exclude_modules: + self.exclude_modules.append(layer_name) + def _name_remapping( self, module: torch.nn.Module | torch.Tensor, @@ -866,6 +896,7 @@ def _name_remapping( return name_to_value, qformat, block_size = self._get_quantized_state(module, dtype, prefix=prefix) + self._record_layer_quant_config(prefix, qformat, block_size) weight = name_to_value.pop("weight") weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat) @@ -906,6 +937,8 @@ def _gated_mlp_slicing( gate_proj_prefix = prefix + gate_proj_name + "." up_proj_prefix = prefix + up_proj_name + "." + self._record_layer_quant_config(gate_proj_prefix, qformat, block_size) + self._record_layer_quant_config(up_proj_prefix, qformat, block_size) ffn_hidden_size = module.config.ffn_hidden_size gate_proj_weight = weight[:ffn_hidden_size, :] @@ -986,6 +1019,7 @@ def _grouped_mlp_slicing(self, module, prefix, parallel_config=None): for expert_id in range(num_experts): expert_prefix = prefix.format(expert_id) + "." + self._record_layer_quant_config(expert_prefix, qformat, block_size) weight_key = f"weight{expert_id}" if weight_key not in state_dict: @@ -1030,6 +1064,18 @@ def _qkv_slicing( q_proj_prefix = prefix + q_proj_name + "." k_proj_prefix = prefix + k_proj_name + "." v_proj_prefix = prefix + v_proj_name + "." + self._record_layer_quant_config(q_proj_prefix, qformat, block_size) + self._record_layer_quant_config(k_proj_prefix, qformat, block_size) + self._record_layer_quant_config(v_proj_prefix, qformat, block_size) + if qformat in (None, QUANTIZATION_NONE): + # MCore stores Q/K/V as one fused linear_qkv module, but HF exports them + # as separate q_proj/k_proj/v_proj modules. Record the HF names so + # runtime quant configs do not miss excluded fused-QKV projections. + fused_prefix = prefix.removesuffix(".") + self.exclude_modules = [m for m in self.exclude_modules if m != fused_prefix] + self._record_excluded_module(q_proj_prefix) + self._record_excluded_module(k_proj_prefix) + self._record_excluded_module(v_proj_prefix) config = module.config hidden_size = config.hidden_size @@ -1179,6 +1225,7 @@ def _pack_name_remapping(self, module, prefix, layer_type=None): weight_scale_list.append(weight_scale) weight_scale_2_list.append(weight_scale_2) input_scale_list.append(input_scale) + self._record_layer_quant_config(prefix, qformat, block_size) merged_weight = torch.stack(weight_list, dim=0) @@ -1247,6 +1294,7 @@ def _pack_name_remapping_gpt_oss(self, module, prefix, layer_type=None): weight_scale_2_list.append(weight_scale_2) input_scale_list.append(input_scale) bias_list.append(bias) + self._record_layer_quant_config(prefix, qformat, block_size) merged_weight = torch.stack(weight_list, dim=0) @@ -1349,6 +1397,19 @@ def _gather_exclude_modules(self): combined_exclude_modules.update(modules) return sorted(combined_exclude_modules) + def _gather_layer_config_dict(self): + """Get per-layer quantization metadata from all ranks for hf_quant_config.""" + if not torch.distributed.is_initialized(): + return dict(sorted(self.layer_config_dict.items())) + + all_layer_config_dicts = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(all_layer_config_dicts, self.layer_config_dict) + combined_layer_config_dict = {} + for layer_config_dict in all_layer_config_dicts: + if layer_config_dict: + combined_layer_config_dict.update(layer_config_dict) + return dict(sorted(combined_layer_config_dict.items())) + def export_mcore_gpt_to_hf( model: torch.nn.Module, diff --git a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py index 3fac8269ccd..8dfbc0323c2 100644 --- a/tests/gpu_megatron/torch/export/test_unified_export_megatron.py +++ b/tests/gpu_megatron/torch/export/test_unified_export_megatron.py @@ -29,7 +29,7 @@ import modelopt.torch.quantization as mtq import modelopt.torch.speculative as mtsp -from modelopt.torch.export import KV_CACHE_FP8, export_mcore_gpt_to_hf, import_mcore_gpt_from_hf +from modelopt.torch.export import export_mcore_gpt_to_hf, import_mcore_gpt_from_hf from modelopt.torch.export.unified_export_megatron import GPTModelExporter from modelopt.torch.speculative.eagle.default_config import default_eagle_config from modelopt.torch.speculative.plugins.megatron_eagle import _DynamicEagleGPTModel @@ -42,15 +42,8 @@ def _verify_model_quant_config( """Verify config.json and hf_quant_config.json""" config_dict = json.load(open(export_dir / "config.json")) hf_quant_config_dict = json.load(open(export_dir / "hf_quant_config.json")) - # Make sure config.json and hf_quant_config.json are consistent - assert ( - config_dict["quantization_config"]["quant_algo"] - == hf_quant_config_dict["quantization"]["quant_algo"] - ) - assert ( - config_dict["quantization_config"]["ignore"] - == hf_quant_config_dict["quantization"]["exclude_modules"] - ) + # Make sure config.json and hf_quant_config.json use the same serving config. + assert config_dict["quantization_config"] == hf_quant_config_dict # Verify config.json if kv_cache_quant_cfg: @@ -58,17 +51,17 @@ def _verify_model_quant_config( # Verify hf_quant_config.json if quant_config: - quant_config_dict = hf_quant_config_dict["quantization"] + quant_config_dict = hf_quant_config_dict quant_type = quant_config_dict["quant_algo"] assert ( quant_type in quant_config ) # quant config str is subset of quant config e.g. NVFP4 -> NVFP4_DEFAULT_CFG - assert len(quant_config_dict["exclude_modules"]) > 1 # Dynamically added exclude modules + assert len(quant_config_dict["ignore"]) > 1 # Dynamically added exclude modules if quant_type == "NVFP4": - assert quant_config_dict["group_size"] == 16 + assert quant_config_dict["config_groups"]["group_0"]["weights"]["group_size"] == 16 if kv_cache_quant_cfg: - assert quant_config_dict["kv_cache_quant_algo"] == KV_CACHE_FP8 + assert quant_config_dict["kv_cache_scheme"]["num_bits"] == 8 def _test_unified_export_megatron( @@ -295,6 +288,44 @@ def test_qkv_slicing_gqa_tp2(dist_workers_size_2, tmp_path): dist_workers_size_2.run(partial(_test_qkv_slicing_gqa_tp2, tmp_path)) +def test_qkv_slicing_records_hf_excludes_for_unquantized_fused_qkv(): + """Unquantized fused MCore linear_qkv should become HF q/k/v excludes.""" + exporter = object.__new__(GPTModelExporter) + exporter.dtype = torch.bfloat16 + exporter.exclude_modules = ["backbone.layers.0.mixer"] + exporter.layer_config_dict = {} + exporter._state_dict = {} + + hidden_size = 8 + head_size = 4 + num_attention_heads = 2 + num_query_groups = 1 + qkv_dim = num_attention_heads + 2 * num_query_groups + weight = torch.arange(qkv_dim * head_size * hidden_size, dtype=torch.bfloat16).reshape( + qkv_dim * head_size, hidden_size + ) + + module = torch.nn.Module() + module.config = type( + "Config", + (), + { + "hidden_size": hidden_size, + "num_query_groups": num_query_groups, + "num_attention_heads": num_attention_heads, + "kv_channels": head_size, + }, + )() + exporter._get_quantized_state = lambda *args, **kwargs: ({"weight": weight}, None, 0) + + exporter._qkv_slicing(module, "backbone.layers.0.mixer.") + + assert "backbone.layers.0.mixer" not in exporter.exclude_modules + assert "backbone.layers.0.mixer.q_proj" in exporter.exclude_modules + assert "backbone.layers.0.mixer.k_proj" in exporter.exclude_modules + assert "backbone.layers.0.mixer.v_proj" in exporter.exclude_modules + + def _make_exporter_for_mtp(model_dir: Path) -> GPTModelExporter: """Create a minimal GPTModelExporter instance for testing _get_mtp_state_dict.""" exporter = object.__new__(GPTModelExporter) From 5e32bd1befe881f8e7a022d12a37e3d2164d9cc0 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Fri, 1 May 2026 12:39:05 -0700 Subject: [PATCH 11/15] fix mcore ckpt resume for static quantizers and MSE export Signed-off-by: Jennifer Chen --- .../torch/quantization/plugins/megatron.py | 8 ++++++-- .../torch/quantization/qtensor/nvfp4_tensor.py | 18 ++++++++++++++---- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 0b50fd937ae..b2347f10492 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -399,7 +399,9 @@ def _get_shard_axis_dict(self, state_dict): """ shard_axis_dict = {} for k in state_dict: - if "weight_quantizer." in k: + # Static NVFP4 _global_amax is a scalar shared by all TP ranks; only shard + # per-block/per-channel weight quantizer state. + if "weight_quantizer." in k and not k.endswith("._global_amax"): weight_quantizer_axis = self.get_submodule(k.rsplit(".", 1)[0]).axis if weight_quantizer_axis is not None: shard_axis_dict[k] = 0 @@ -427,7 +429,9 @@ def _get_shard_axis_dict(self, state_dict): """ shard_axis_dict = {} for k in state_dict: - if "weight_quantizer." in k: + # Static NVFP4 _global_amax is a scalar shared by all TP ranks; only shard + # per-block/per-channel weight quantizer state. + if "weight_quantizer." in k and not k.endswith("._global_amax"): weight_quantizer_axis = None if isinstance(self.weight_quantizer, TensorQuantizer): weight_quantizer_axis = self.weight_quantizer.axis diff --git a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py index fe30e283c2d..8cb0b66e45e 100644 --- a/modelopt/torch/quantization/qtensor/nvfp4_tensor.py +++ b/modelopt/torch/quantization/qtensor/nvfp4_tensor.py @@ -55,7 +55,16 @@ def get_e2m1_bounds(cls, device): @classmethod def _is_static_quantizer(cls, weight_quantizer) -> bool: """Check if the weight quantizer is a static NVFP4 quantizer with pre-computed amax.""" - return hasattr(weight_quantizer, "global_amax") and weight_quantizer.global_amax is not None + global_amax = cls._get_static_global_amax(weight_quantizer) + return global_amax is not None + + @classmethod + def _get_static_global_amax(cls, weight_quantizer): + """Return global amax from live or restored static NVFP4 quantizers.""" + global_amax = getattr(weight_quantizer, "global_amax", None) + if global_amax is None: + global_amax = getattr(weight_quantizer, "_global_amax", None) + return global_amax @classmethod def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer): @@ -70,8 +79,9 @@ def get_weights_scaling_factor_2_from_quantizer(cls, weight_quantizer): Returns: The global scaling factor as a float tensor. """ - if cls._is_static_quantizer(weight_quantizer): - return weight_quantizer.global_amax.float() / (6.0 * 448.0) + global_amax = cls._get_static_global_amax(weight_quantizer) + if global_amax is not None: + return global_amax.float() / (6.0 * 448.0) else: assert hasattr(weight_quantizer, "_amax"), ( "Weight quantizer does not have attribute amax" @@ -109,7 +119,7 @@ def get_weights_scaling_factor_from_quantizer( if cls._is_static_quantizer(weight_quantizer): # Static path: use pre-computed per-block amax values from quantizer - global_amax = weight_quantizer.global_amax.float() + global_amax = cls._get_static_global_amax(weight_quantizer).float() per_block_amax = weight_quantizer._amax.float() # Compute scales in float From ef374568ee2fbce5da3ceab6b4ad7a7b4184cefd Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Fri, 1 May 2026 13:51:49 -0700 Subject: [PATCH 12/15] remove duplicate recipe Signed-off-by: Jennifer Chen --- .../super-nvfp4-amax.yaml | 134 ------------------ 1 file changed, 134 deletions(-) delete mode 100644 modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-amax.yaml diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-amax.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-amax.yaml deleted file mode 100644 index 8ae6f686930..00000000000 --- a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-amax.yaml +++ /dev/null @@ -1,134 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json: -# but with ONE major difference: use max calibration instead of MSE. -# - MoE routed experts: NVFP4 W4A4 weight, group_size 16 -# HF names: mixer.experts..{up,down}_proj -# Megatron-Core names: mlp.experts.local_experts..linear_fc{1,2} -# - MoE shared experts: FP8 per-tensor -# HF names: mixer.shared_experts.{up,down}_proj -# Megatron-Core names: mlp.shared_experts.linear_fc{1,2} -# - Mamba mixer linears (mixer.{in,out}_proj): FP8 per-tensor -# - KV cache: FP8 -# - Attention linears ({q,k,v}_proj): BF16 (not quantized) -# - MTP head, lm_head, output, mamba conv1d: BF16 (not quantized) -# - Latent MOE (fc1_latent_proj, fc2_latent_proj): BF16 (not quantized) -# - SSM cache: FP32 (can be set to FP16 in VLLM) -# -# Calibration: amax/max calibration comparison variant. This skips MSE weight -# scale search and uses max calibration for enabled quantizers. -metadata: - recipe_type: ptq - description: Super NVFP4 mixed precision - sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj, and Latent MOE fc1_latent_proj/fc2_latent_proj - FP8 per-tensor; FP8 KV cache; lm_head/MTP/SSM stay BF16/FP16. Amax calibration comparison variant. -quantize: - algorithm: - method: max - quant_cfg: - - quantizer_name: '*' - enable: false - - # MoE routed experts -> NVFP4 W4A4, block_size 16, e4m3 scale. - # Amax/max comparison uses dynamic block scales for both weight and activation. - # HF/export names: backbone.layers.*.mixer.experts.*.{up,down}_proj. - - quantizer_name: '*mixer.experts.*weight_quantizer' - enable: true - cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 - - quantizer_name: '*mixer.experts.*input_quantizer' - enable: true - cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 - # Megatron-Core/PTQ names: decoder.layers.*.mlp.experts.local_experts.*.linear_fc{1,2}. - - quantizer_name: '*mlp.experts*weight_quantizer' - enable: true - cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 - - quantizer_name: '*mlp.experts*input_quantizer' - enable: true - cfg: - block_sizes: - -1: 16 - type: dynamic - scale_bits: e4m3 - num_bits: e2m1 - - # MoE shared experts -> FP8 per-tensor. - # HF/export names: backbone.layers.*.mixer.shared_experts.{up,down}_proj. - - quantizer_name: '*mixer.shared_experts.*weight_quantizer' - enable: true - cfg: - num_bits: e4m3 - axis: - - quantizer_name: '*mixer.shared_experts.*input_quantizer' - enable: true - cfg: - num_bits: e4m3 - axis: - # Megatron-Core/PTQ names: decoder.layers.*.mlp.shared_experts.linear_fc{1,2}. - - quantizer_name: '*mlp.shared_experts*weight_quantizer' - enable: true - cfg: - num_bits: e4m3 - axis: - - quantizer_name: '*mlp.shared_experts*input_quantizer' - enable: true - cfg: - num_bits: e4m3 - axis: - - # Mamba mixer linears -> FP8 per-tensor. - - quantizer_name: '*mixer.in_proj*weight_quantizer' - enable: true - cfg: - num_bits: e4m3 - axis: - - quantizer_name: '*mixer.in_proj*input_quantizer' - enable: true - cfg: - num_bits: e4m3 - axis: - - quantizer_name: '*mixer.out_proj*weight_quantizer' - enable: true - cfg: - num_bits: e4m3 - axis: - - quantizer_name: '*mixer.out_proj*input_quantizer' - enable: true - cfg: - num_bits: e4m3 - axis: - - # KV cache -> FP8. - - quantizer_name: '*[kv]_bmm_quantizer' - enable: true - cfg: - num_bits: e4m3 - - # Stay BF16: lm_head, output projection, MoE routers/gates, MTP head. - # SSM state / mamba conv1d stay FP16. From b5f4b56087edfabad44eb37172eab321dee349e5 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Fri, 1 May 2026 13:53:18 -0700 Subject: [PATCH 13/15] fix docstring Signed-off-by: Jennifer Chen --- .../super-nvfp4-fp8-sweep-stride4.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-fp8-sweep-stride4.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-fp8-sweep-stride4.yaml index efd07f7762a..b3adc1dca5c 100644 --- a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-fp8-sweep-stride4.yaml +++ b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-fp8-sweep-stride4.yaml @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json: +# Mirrors the published nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 hf_quant_config.json BUT adds a stride=4 for FP8 scale sweep, which is useful for large models to improve PTQ efficiency. # - MoE routed experts: NVFP4 W4A4 weight MSE, group_size 16 # HF names: mixer.experts..{up,down}_proj # Megatron-Core names: mlp.experts.local_experts..linear_fc{1,2} From b5c5331a85c73f067b28b8add02611a59c9629d6 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Fri, 1 May 2026 13:53:33 -0700 Subject: [PATCH 14/15] fix max calib recipe Signed-off-by: Jennifer Chen --- .../Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml index 4f0f54ccb2f..7c02aaea2aa 100644 --- a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml +++ b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml @@ -42,14 +42,14 @@ quantize: enable: false # MoE routed experts -> NVFP4 W4A4, block_size 16, e4m3 scale. - # Weight uses static block scales (chosen by MSE); activations stay dynamic. + # Max/amax calibration uses dynamic block scales for both weight and activation. # HF/export names: backbone.layers.*.mixer.experts.*.{up,down}_proj. - quantizer_name: '*mixer.experts.*weight_quantizer' enable: true cfg: block_sizes: -1: 16 - type: static + type: dynamic scale_bits: e4m3 num_bits: e2m1 - quantizer_name: '*mixer.experts.*input_quantizer' @@ -66,7 +66,7 @@ quantize: cfg: block_sizes: -1: 16 - type: static + type: dynamic scale_bits: e4m3 num_bits: e2m1 - quantizer_name: '*mlp.experts*input_quantizer' From 5de5541c2fcd9d5100b0673536e3efcf96cd6a15 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Fri, 1 May 2026 14:05:24 -0700 Subject: [PATCH 15/15] cleanup recipes Signed-off-by: Jennifer Chen --- .../super-nvfp4-fp8-sweep-stride4.yaml | 9 ++++----- .../super-nvfp4-max-calib.yaml | 13 +++++-------- .../Nemotron-3-Super-120B-A12B/super-nvfp4.yaml | 8 ++++---- 3 files changed, 13 insertions(+), 17 deletions(-) diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-fp8-sweep-stride4.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-fp8-sweep-stride4.yaml index b3adc1dca5c..5a95dea8993 100644 --- a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-fp8-sweep-stride4.yaml +++ b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-fp8-sweep-stride4.yaml @@ -31,14 +31,16 @@ # values. This keeps the FP8 static-scale path but uses a coarser candidate set. metadata: recipe_type: ptq - description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj, and Latent MOE fc1_latent_proj/fc2_latent_proj - FP8 per-tensor; FP8 KV cache; lm_head/MTP/SSM stay BF16/FP16. Weight-MSE calibration with stride-4 FP8 scale sweep. + description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj + FP8 per-tensor; FP8 KV cache; everything else(lm_head/MTP/Latent MOE) stay BF16. Weight-MSE calibration with stride-4 FP8 scale sweep. quantize: algorithm: method: mse fp8_scale_sweep: true fp8_scale_sweep_stride: 4 quant_cfg: + # Disable all layers by default so that these layers stay in original BF16 precision: + # lm_head, output projection, MoE routers/gates, Latent MOE, MTP head, mamba conv1d. - quantizer_name: '*' enable: false @@ -130,6 +132,3 @@ quantize: enable: true cfg: num_bits: e4m3 - - # Stay BF16: lm_head, output projection, MoE routers/gates, MTP head. - # SSM state / mamba conv1d stay FP16. diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml index 7c02aaea2aa..74ae05d2933 100644 --- a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml +++ b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4-max-calib.yaml @@ -28,21 +28,21 @@ # - Latent MOE (fc1_latent_proj, fc2_latent_proj): BF16 (not quantized) # - SSM cache: FP32 (can be set to FP16 in VLLM) # -# Calibration: amax/max calibration comparison variant. This skips MSE weight -# scale search and uses max calibration for enabled quantizers. +# Calibration: amax/max calibration comparison variant metadata: recipe_type: ptq - description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj, and Latent MOE fc1_latent_proj/fc2_latent_proj - FP8 per-tensor; FP8 KV cache; lm_head/MTP/SSM stay BF16/FP16. Amax calibration comparison variant. + description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj + FP8 per-tensor; FP8 KV cache; everything else(lm_head/MTP/Latent MOE) stay BF16. Amax calibration comparison variant. quantize: algorithm: method: max quant_cfg: + # Disable all layers by default so that these layers stay in original BF16 precision: + # lm_head, output projection, MoE routers/gates, Latent MOE, MTP head, mamba conv1d. - quantizer_name: '*' enable: false # MoE routed experts -> NVFP4 W4A4, block_size 16, e4m3 scale. - # Max/amax calibration uses dynamic block scales for both weight and activation. # HF/export names: backbone.layers.*.mixer.experts.*.{up,down}_proj. - quantizer_name: '*mixer.experts.*weight_quantizer' enable: true @@ -129,6 +129,3 @@ quantize: enable: true cfg: num_bits: e4m3 - - # Stay BF16: lm_head, output projection, MoE routers/gates, MTP head. - # SSM state / mamba conv1d stay FP16. diff --git a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml index 5ab59b69cd4..f6535cca812 100644 --- a/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml +++ b/modelopt_recipes/models/Nemotron-3-Super-120B-A12B/super-nvfp4.yaml @@ -32,15 +32,15 @@ # are also chosen via MSE search instead of plain amax). metadata: recipe_type: ptq - description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj, and Latent MOE fc1_latent_proj/fc2_latent_proj - FP8 per-tensor; FP8 KV cache; lm_head/MTP/SSM stay BF16/FP16. Weight-MSE calibration with FP8 scale sweep. + description: Super NVFP4 mixed precision — sparse MoE experts NVFP4 (W4A4, group_size 16); shared experts, mamba in/out_proj + FP8 per-tensor; FP8 KV cache; everything else(lm_head/MTP/latent MOE) stay BF16. Weight-MSE calibration with FP8 scale sweep. quantize: algorithm: method: mse fp8_scale_sweep: true quant_cfg: - # Disable all layers by default so that these layers stay in their original precision: BF16/FP32: - # lm_head, output projection, MoE routers/gates, MTP head, SSM state, mamba conv1d. + # Disable all layers by default so that these layers stay in original BF16 precision: + # lm_head, output projection, MoE routers/gates, Latent MOE, MTP head, mamba conv1d. - quantizer_name: '*' enable: false