7 changes: 6 additions & 1 deletion examples/llm_ptq/hf_ptq.py
@@ -797,6 +797,11 @@ def pre_quantize(
preview_input_ids = next(iter(calib_dataloader))[
"input_features" if model_type == "whisper" else "input_ids"
][0:1]
# Strip leading padding tokens so the preview input shows real content
if model_type not in ("whisper",) and tokenizer is not None and tokenizer.pad_token_id is not None:
Collaborator (bot comment):

Minor: if preview_input_ids has no non-pad tokens (e.g. all tokens are padding), first_non_pad will be empty and first_non_pad[0] will error. The first_non_pad.numel() > 0 check correctly guards this — just confirming it's intentional that the original (all-padding) input is preserved in that edge case.
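For illustration, a minimal repro of the edge case (pad_token_id = 0 is only an assumption for this sketch):

import torch

pad_token_id = 0
all_pad = torch.tensor([[0, 0, 0, 0]])  # every token is padding
first_non_pad = (all_pad[0] != pad_token_id).nonzero(as_tuple=True)[0]
print(first_non_pad.numel())  # 0, so first_non_pad[0] would raise IndexError
# The numel() > 0 guard above keeps the original all-padding input unchanged in this case.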

first_non_pad = (preview_input_ids[0] != tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
if first_non_pad.numel() > 0:
preview_input_ids = preview_input_ids[:, first_non_pad[0]:]

# Generate preview before quantization
if args.skip_generate:
@@ -897,7 +902,7 @@ def input_decode(input_ids):
if processor is not None and isinstance(processor, WhisperProcessor):
return first_text_speech_dataset
elif tokenizer is not None:
return tokenizer.batch_decode(input_ids)
return tokenizer.batch_decode(input_ids, skip_special_tokens=True)
else:
raise ValueError("The processor or tokenizer must be set")

106 changes: 97 additions & 9 deletions modelopt/torch/export/moe_utils.py
@@ -59,9 +59,26 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
# 2-3. Split + export each per-expert projection.
fused_dim0 = gate_up.shape[1] # 2 * expert_dim

def _safe_cpu_amax(quantizer_src: nn.Module) -> torch.Tensor | None:
"""Extract _amax to CPU float32, surfacing and clearing any pending CUDA error first."""
amax = getattr(quantizer_src, "_amax", None)
if amax is None or not isinstance(amax, torch.Tensor):
return None
try:
if amax.is_cuda:
torch.cuda.synchronize(amax.device)
return amax.detach().cpu().float()
except Exception:
return None
Collaborator (bot comment):

Bare except Exception silently swallows all errors and returns None. While defensive coding for corrupt CUDA tensors is reasonable, this could mask unrelated bugs. Consider catching a narrower set of exceptions (e.g. RuntimeError) or at minimum logging a warning when the fallback is triggered:

except RuntimeError:
    warnings.warn(f"Failed to extract _amax to CPU for {quantizer_src}, using fallback")
    return None


for idx in range(n):
expert = nn.Module()

# Extract amaxes to CPU before deepcopy: cloning a corrupt CUDA _amax tensor
# (e.g. from an under-calibrated expert) triggers an async CUDA error.
gu_amax_cpu = _safe_cpu_amax(module.gate_up_proj_weight_quantizers[idx])
down_amax_cpu = _safe_cpu_amax(module.down_proj_weight_quantizers[idx])

projections = [
("gate_proj", gate_up[idx, :expert_dim, :], 0, fused_dim0, True),
("up_proj", gate_up[idx, expert_dim:, :], expert_dim, fused_dim0, True),
@@ -76,8 +93,17 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
)
i_quantizer = gate_up_input_q if is_gate_up else down_input_q

# gate/up share a weight quantizer — clone so each gets independent amax.
w_quantizer = copy.deepcopy(w_quantizer_src) if is_gate_up else w_quantizer_src
# gate/up share a quantizer — deepcopy with _amax nulled to avoid cloning
# the corrupt CUDA tensor, then inject the pre-extracted CPU amax.
if is_gate_up:
_saved_amax = getattr(w_quantizer_src, "_amax", None)
w_quantizer_src._amax = None
Collaborator (bot comment):

For down_proj (non-gate_up), w_quantizer = w_quantizer_src — this is the original quantizer, not a copy. Then w_quantizer._amax = down_amax_cpu mutates the original quantizer's _amax. This is fine if the module is only exported once, but is potentially surprising. A comment noting this is intentional mutation of the original would help.
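If avoiding the mutation entirely is preferred, a sketch of the same save/deepcopy/restore pattern applied to the down_proj branch (names taken from this hunk; only a suggestion, not the PR's implementation):

else:
    # Hypothetical alternative: copy instead of mutating the original quantizer.
    _saved_amax = getattr(w_quantizer_src, "_amax", None)
    w_quantizer_src._amax = None
    try:
        w_quantizer = copy.deepcopy(w_quantizer_src)
    finally:
        w_quantizer_src._amax = _saved_amax
    w_quantizer._amax = down_amax_cpu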

w_quantizer = copy.deepcopy(w_quantizer_src)
w_quantizer_src._amax = _saved_amax
w_quantizer._amax = gu_amax_cpu
Comment on lines +98 to +103
Contributor:

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Protect temporary _amax mutation with try/finally.

If copy.deepcopy() raises, _amax is left as None on the source quantizer. Wrap restore in finally to avoid state corruption on failure.

Proposed fix
             if is_gate_up:
                 _saved_amax = getattr(w_quantizer_src, "_amax", None)
-                w_quantizer_src._amax = None
-                w_quantizer = copy.deepcopy(w_quantizer_src)
-                w_quantizer_src._amax = _saved_amax
+                w_quantizer_src._amax = None
+                try:
+                    w_quantizer = copy.deepcopy(w_quantizer_src)
+                finally:
+                    w_quantizer_src._amax = _saved_amax
                 w_quantizer._amax = gu_amax_cpu
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/moe_utils.py` around lines 98 - 103: the temporary mutation of w_quantizer_src._amax before calling copy.deepcopy may leave the source quantizer with _amax == None if deepcopy raises; change the code around copy.deepcopy(w_quantizer_src) to save _saved_amax, set w_quantizer_src._amax = None, then perform the deepcopy inside a try block and restore w_quantizer_src._amax = _saved_amax in a finally block; after the deepcopy, set w_quantizer._amax = gu_amax_cpu as before so the source state is always restored even on exceptions.

else:
w_quantizer = w_quantizer_src
w_quantizer._amax = down_amax_cpu

# For per-channel amax (dim >= 1), proportionally slice dim-0
# to match the split weight.
@@ -86,12 +112,14 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
and w_quantizer._amax is not None
and w_quantizer._amax.dim() >= 1
):
amax = w_quantizer._amax
amax = w_quantizer._amax # CPU float32
amax_dim0 = amax.shape[0]
if fused_total % amax_dim0 == 0:
if amax_dim0 % fused_total == 0:
slice_start = fused_start * amax_dim0 // fused_total
slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total
w_quantizer.amax = amax[slice_start:slice_end].contiguous()
# Bypass amax.setter (which forbids shape changes); w_quantizer is a
# deepcopy for gate/up so mutating it is safe.
w_quantizer._amax = amax[slice_start:slice_end].contiguous()
else:
warnings.warn(
f"Expert {idx} {proj_name}: fused amax dim0 ({amax_dim0}) does not "
@@ -100,20 +128,68 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
stacklevel=2,
)

# If the weight quantizer was never calibrated, compute amax from weights.
# Patch invalid per-block amax entries (NaN/inf/negative/zero/too-small/too-large)
Collaborator (bot comment):

Bug: _MIN_VALID_AMAX = 1e-4 is below the FP8 E4M3FN minimum subnormal (2^-9 ≈ 0.00195). Values between 1e-4 and ~0.00195 will pass this validity check but will still underflow to 0 when cast to FP8 E4M3FN. Consider using 2e-3 (which you already use for clamping) or 2**-9 as the minimum valid threshold for consistency with the nvfp4_tensor.py fix.
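A quick check of the underflow behaviour (a sketch; the exact values assume PyTorch's default round-to-nearest cast for float8_e4m3fn):

import torch

x = torch.tensor([5e-4, 2e-3])
# 5e-4 passes the 1e-4 floor but is below half of the E4M3FN minimum subnormal
# (2**-10 ≈ 9.77e-4), so it is expected to round to 0; 2e-3 maps to roughly 0.00195.
print(x.to(torch.float8_e4m3fn).float())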

# with weight-derived fallback values.
_MIN_VALID_AMAX = 1e-4
_MAX_VALID_AMAX = 1e6
if (
hasattr(w_quantizer, "_amax")
and w_quantizer._amax is not None
and w_quantizer._amax.numel() > 1
):
amax_cpu = w_quantizer._amax
invalid_mask = ~(
torch.isfinite(amax_cpu)
& (amax_cpu >= _MIN_VALID_AMAX)
& (amax_cpu <= _MAX_VALID_AMAX)
)
if invalid_mask.any():
per_block_fallback = (
weight_slice.detach()
Collaborator (bot comment):

Hardcoded block_size=16 here and at line 173. If the quantizer's actual block size is different, the reshape will produce an incorrect shape. Consider extracting the block size from the weight quantizer (e.g. w_quantizer.block_sizes.get(-1, 16)) rather than hardcoding.
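A sketch of deriving the block size from the quantizer instead of hardcoding it (the dict layout of block_sizes is an assumption based on the recipe configs in this PR):

_block_size = 16
block_sizes = getattr(w_quantizer, "block_sizes", None)
if isinstance(block_sizes, dict):
    # -1 keys the innermost dimension, matching the "-1: 16" entries in the recipe YAML.
    _block_size = block_sizes.get(-1, _block_size)
# ... then reshape(-1, _block_size) as in the current code.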

.reshape(-1, 16)
.abs()
.amax(dim=1, keepdim=True)
.cpu()
.float()
.clamp(min=2e-3)
.reshape(amax_cpu.shape)
)
amax_cpu[invalid_mask] = per_block_fallback[invalid_mask]
w_quantizer._amax = amax_cpu

# For uncalibrated experts (amax missing or invalid scalar), fall back to
# per-block amax from weights so the static export path can reshape it correctly.
if (
hasattr(w_quantizer, "is_enabled")
and w_quantizer.is_enabled
and (
not hasattr(w_quantizer, "_amax")
or w_quantizer._amax is None
or torch.all(w_quantizer._amax == 0)
or (
w_quantizer._amax.numel() == 1
and not (
torch.isfinite(w_quantizer._amax)
and w_quantizer._amax >= _MIN_VALID_AMAX
and w_quantizer._amax <= _MAX_VALID_AMAX
)
)
)
):
w_quantizer.amax = weight_slice.abs().amax().to(torch.float32)
_block_size = 16
fallback_per_block = (
weight_slice.detach()
.reshape(-1, _block_size)
.abs()
.amax(dim=1, keepdim=True)
.cpu()
.float()
.clamp(min=2e-3)
.reshape(*weight_slice.shape[:-1], weight_slice.shape[-1] // _block_size)
)
w_quantizer._amax = fallback_per_block
warnings.warn(
f"Expert {idx} {proj_name} weight quantizer was not calibrated "
f"(amax missing or zero). Using weight-derived amax as fallback. "
f"(amax missing or zero). Using weight-derived per-block amax as fallback. "
f"Consider using more calibration data to activate all experts.",
stacklevel=2,
)
@@ -123,6 +199,18 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
wrapper.weight_quantizer = w_quantizer
wrapper.input_quantizer = i_quantizer

# Set global_amax to route to the static NVFP4 export path (reads per-block _amax).
# Always recompute from the current (possibly patched) _amax — a stale zero
# global_amax causes division-by-zero in the per-block scale formula.
wq = wrapper.weight_quantizer
if (
hasattr(wq, "_amax")
and wq._amax is not None
and wq._amax.numel() > 1
):
wq._amax = wq._amax.to(weight_slice.device)
wq.global_amax = wq._amax.float().amax().clamp(min=2e-3)

_export_quantized_weight(wrapper, dtype)

proj = nn.Module()
18 changes: 17 additions & 1 deletion modelopt/torch/quantization/model_calib.py
@@ -421,6 +421,18 @@ def mse_calibrate(
if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled:
if getattr(weight_quantizer, "_calibrator", None) is not None:
weight_quantizers.append((parent_module, weight_name, weight_quantizer))
# _QuantFusedExperts stores per-expert weight quantizers as nn.ModuleList named
# {param_name}_weight_quantizers (plural). Detect this pattern and enqueue each
# per-expert quantizer individually.
for param_name, _ in parent_module.named_parameters(recurse=False):
qlist = getattr(parent_module, f"{param_name}_weight_quantizers", None)
if not isinstance(qlist, nn.ModuleList):
Collaborator (bot comment):

The pattern named_parameters(recurse=False) + checking for f"{param_name}_weight_quantizers" works for the current _QuantFusedExperts layout, but is fairly fragile. If other modules happen to have a parameter and a same-named ModuleList with _weight_quantizers suffix, they'd be picked up too. Consider adding a type check (e.g. checking if parent_module is a _QuantFusedExperts instance) or at least a comment noting the assumption.
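One possible tightening, sketched without assuming an import path for the class (the name-based check is only illustrative):

if type(parent_module).__name__ == "_QuantFusedExperts":
    for param_name, _ in parent_module.named_parameters(recurse=False):
        qlist = getattr(parent_module, f"{param_name}_weight_quantizers", None)
        if isinstance(qlist, nn.ModuleList):
            ...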

continue
for expert_idx, wq in enumerate(qlist):
if isinstance(wq, TensorQuantizer) and wq.is_enabled:
if getattr(wq, "_calibrator", None) is not None:
weight_quantizers.append((parent_module, (param_name, expert_idx), wq))

seen_modules.add(parent_module)

# Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
@@ -432,7 +444,11 @@
weight_quantizer.disable_quant()
weight_quantizer.enable_calib()
with enable_weight_access_and_writeback(parent_module, model, name_to_module):
weight = getattr(parent_module, weight_name)
if isinstance(weight_name, tuple):
param_name, expert_idx = weight_name
weight = getattr(parent_module, param_name)[expert_idx]
else:
weight = getattr(parent_module, weight_name)
weight_quantizer(weight)

# IMMEDIATELY compute amax and reset calibrator to free memory
1 change: 1 addition & 0 deletions modelopt/torch/quantization/model_quant.py
@@ -595,6 +595,7 @@ def print_quant_summary(model: nn.Module, output_dir: str | None = None):
lines.append(f"{len(lines)} TensorQuantizers found in model")

if output_dir:
os.makedirs(output_dir, exist_ok=True)
path = os.path.join(output_dir, ".quant_summary.txt")
with open(path, "w", encoding="utf-8") as f:
f.write("\n".join(lines) + "\n")
4 changes: 2 additions & 2 deletions modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -1112,7 +1112,7 @@ def forward(self, inputs):

return outputs

def _short_amax(self, fmt=".4f"):
def _short_amax(self, fmt=".2e"):
"""Short description of amax.

Returns:
@@ -1130,7 +1130,7 @@ def _short_amax(self, fmt=".4f"):
return "meta"
return self._short_tensor(self._amax, fmt)

def _short_tensor(self, tensor: torch.Tensor, fmt=".4f"):
def _short_tensor(self, tensor: torch.Tensor, fmt=".2e"):
"""Short description of tensor."""
if tensor.numel() == 1:
return f"{tensor.item():{fmt}}"
13 changes: 10 additions & 3 deletions modelopt/torch/quantization/qtensor/nvfp4_tensor.py
@@ -124,9 +124,10 @@ def get_weights_scaling_factor_from_quantizer(

# Quantize scales to FP8
if not keep_high_precision:
per_block_scale = (per_block_scale * 448.0 / per_block_scale_max).to(
torch.float8_e4m3fn
)
_FP8_E4M3FN_MIN = 2**-9 # 0.001953125 — smallest positive subnormal
per_block_scale = (per_block_scale * 448.0 / per_block_scale_max).clamp(
min=_FP8_E4M3FN_MIN
).to(torch.float8_e4m3fn)
return per_block_scale, weights_scaling_factor_2
else:
# Dynamic path: compute from weight tensor
@@ -167,6 +168,12 @@ def get_weights_scaling_factor(
per_block_scale[per_block_scale == 0] = 1.0
# Convert to torch.float8_e4m3fn
if not keep_high_precision:
# Clamp to the minimum positive FP8 E4M3FN subnormal (~0.00195 = 2^-9) before
# casting. Without this, blocks whose scale falls below the FP8 representable
# range silently underflow to 0, causing those blocks to produce zero output at
# inference even when the weights are non-trivial.
_FP8_E4M3FN_MIN = 2**-9 # 0.001953125 — smallest positive subnormal
per_block_scale = per_block_scale.clamp(min=_FP8_E4M3FN_MIN)
per_block_scale = per_block_scale.to(torch.float8_e4m3fn)
return per_block_scale, weights_scaling_factor_2

130 changes: 130 additions & 0 deletions modelopt_recipes/general/ptq/nvfp4_experts_only_mse.yaml
@@ -0,0 +1,130 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Collaborator (bot comment):

Copyright year is 2024 but the project's LICENSE_HEADER specifies 2026. New files should use the current year from the canonical header.

# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

metadata:
recipe_type: ptq
description: >
NVFP4 W4A4 for MoE routed experts only. Static weight scales via MSE + FP8 scale sweep;
dynamic activation scales. Supports sequential experts (nn.Linear-based) and fused experts
(_QuantFusedExperts, HF transformers 5.0+ 3D nn.Parameter style).
quantize:
algorithm:
method: mse
fp8_scale_sweep: true
layerwise: false
quant_cfg:
# ── Disable everything first ─────────────────────────────────────────────
- quantizer_name: '*'
enable: false

# ── Sequential experts (nn.Linear per expert) ────────────────────────────
- quantizer_name: '*mlp.experts*weight_quantizer'
enable: true
cfg:
block_sizes:
-1: 16
type: static
scale_bits: e4m3
num_bits: e2m1
- quantizer_name: '*mlp.experts*input_quantizer'
enable: true
cfg:
block_sizes:
-1: 16
type: dynamic
scale_bits: e4m3
num_bits: e2m1

# ── Sequential experts: Mixtral / block_sparse_moe style ────────────────
- quantizer_name: '*block_sparse_moe*weight_quantizer'
enable: true
cfg:
block_sizes:
-1: 16
type: static
scale_bits: e4m3
num_bits: e2m1
- quantizer_name: '*block_sparse_moe*input_quantizer'
enable: true
cfg:
block_sizes:
-1: 16
type: dynamic
scale_bits: e4m3
num_bits: e2m1

# ── Fused experts (_QuantFusedExperts, HF transformers 5.0+ 3D nn.Parameter style) ──
- quantizer_name: '*gate_up_proj_weight_quantizers*'
enable: true
cfg:
block_sizes:
-1: 16
type: static
scale_bits: e4m3
num_bits: e2m1
- quantizer_name: '*gate_up_proj_input_quantizer*'
enable: true
cfg:
block_sizes:
-1: 16
type: dynamic
scale_bits: e4m3
num_bits: e2m1
- quantizer_name: '*down_proj_weight_quantizers*'
enable: true
cfg:
block_sizes:
-1: 16
type: static
scale_bits: e4m3
num_bits: e2m1
- quantizer_name: '*down_proj_input_quantizer*'
enable: true
cfg:
block_sizes:
-1: 16
type: dynamic
scale_bits: e4m3
num_bits: e2m1

# ── Exclusions: shared experts, attention, routers, lm_head ─────────────
- quantizer_name: '*block_sparse_moe.gate*'
enable: false
- quantizer_name: '*linear_attn.conv1d*'
enable: false
- quantizer_name: '*lm_head*'
enable: false
- quantizer_name: '*mlp.gate.*'
enable: false
- quantizer_name: '*mlp.shared_expert*'
enable: false
- quantizer_name: '*mlp.shared_expert_gate.*'
enable: false
- quantizer_name: '*router*'
enable: false
- quantizer_name: 'output.*'
enable: false
- parent_class: 'nn.BatchNorm1d'
quantizer_name: '*'
enable: false
- parent_class: 'nn.BatchNorm2d'
quantizer_name: '*'
enable: false
- parent_class: 'nn.BatchNorm3d'
quantizer_name: '*'
enable: false
- parent_class: 'nn.LeakyReLU'
quantizer_name: '*'
enable: false