fixes for fused moe (qwen3.6, GLM5.1) + MSE calibration #1382

Open
Fridah-nv wants to merge 1 commit into main from fridah/fused-moe-MSE-fix

Conversation


@Fridah-nv Fridah-nv commented May 2, 2026

What does this PR do?

Type of change: Bug fix

Fixes several issues with NVFP4 MSE calibration and export for fused MoE expert modules (_QuantFusedExperts — used by Qwen3.6, GLM-5.1, and other HF transformers 5.0+ models that store expert weights as 3-D nn.Parameters).

  • Bug 1 — MSE weight calibration runs 0 iterations for fused experts (model_calib.py)

The weight-quantizer discovery loop in mse_calibrate used the singular attribute name gate_up_proj_weight_quantizer to look up quantizers, but _QuantFusedExperts stores them in a plural nn.ModuleList named gate_up_proj_weight_quantizers. All 20,480 expert quantizers were silently skipped, resulting in "MSE weight calibration: 0it" and no MSE-optimized scales.

Fix: add a second pass that detects plural {param}_weight_quantizers ModuleLists and enqueues each per-expert quantizer with a (param_name, expert_idx) tuple; step 3 unpacks the tuple to extract the per-expert weight slice.
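
A minimal sketch of the added discovery pass, assuming the _QuantFusedExperts layout described above (helper name and structure are illustrative, not the exact model_calib.py code):

    import torch.nn as nn

    def collect_fused_expert_quantizers(parent_module: nn.Module):
        """Collect (quantizer, (param_name, expert_idx)) pairs for fused-expert weights."""
        work_items = []
        for param_name, _ in parent_module.named_parameters(recurse=False):
            # Fused experts keep one quantizer per expert in a plural ModuleList,
            # e.g. gate_up_proj_weight_quantizers, not gate_up_proj_weight_quantizer.
            qlist = getattr(parent_module, f"{param_name}_weight_quantizers", None)
            if not isinstance(qlist, nn.ModuleList):
                continue
            for expert_idx, quantizer in enumerate(qlist):
                # The (param_name, expert_idx) tuple lets the calibration step slice the
                # matching per-expert weight: getattr(parent_module, param_name)[expert_idx].
                work_items.append((quantizer, (param_name, expert_idx)))
        return work_items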

  • Bug 2 — Zero weight scales in exported checkpoint (nvfp4_tensor.py)

Per-block weight scales can silently underflow to 0 when cast to FP8 E4M3FN. The existing scale == 0 guard only catches exact float32 zeros; values in (0, 2^-9) pass through and become 0 after the FP8 cast. This affects both the dynamic recompute path (get_weights_scaling_factor) and the static calibrated path (get_weights_scaling_factor_from_quantizer).

Fix: clamp per-block scales to 2^-9 (smallest positive FP8 E4M3FN subnormal) before the FP8 cast in both paths.
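
A minimal sketch of the clamp, with an assumed constant name (the actual change lives in the two scale paths of nvfp4_tensor.py):

    import torch

    FP8_E4M3_MIN_SUBNORMAL = 2.0 ** -9  # smallest positive float8_e4m3fn value (~0.00195)

    def cast_block_scales_to_fp8(per_block_scale: torch.Tensor) -> torch.Tensor:
        # Scales in (0, 2^-9) would round to 0 in float8_e4m3fn; clamping first keeps
        # every block's scale nonzero after the cast.
        return per_block_scale.clamp(min=FP8_E4M3_MIN_SUBNORMAL).to(torch.float8_e4m3fn)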

  • Bug 3 — Zero/corrupt amax for uncalibrated experts at export (moe_utils.py)

Experts that receive no tokens during calibration have _amax = 0 or uninitialized values. The existing scalar fallback used 1e-4 which itself underflows to 0 in FP8 E4M3FN (1e-4 < 2^-9 ≈ 0.00195). Additionally, the per-block fallback tensor had shape (H*W, 1) instead of (H, W), causing a shape mismatch that silently bypassed the fallback and fell through to the bad scalar. Finally, a stale zero global_amax from an uncalibrated expert was not recomputed, causing division-by-zero in the FP8 scale formula.

Fix: reshape the per-block fallback correctly; raise the clamp floor to 2e-3; always recompute global_amax from the current (possibly patched) per-block _amax.
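
A minimal sketch of the patched export-time fallback, assuming a 16-element NVFP4 block along the last weight dimension (names are illustrative, not the verbatim moe_utils.py code):

    import torch

    AMAX_FLOOR = 2e-3  # above the FP8 E4M3FN minimum subnormal of 2^-9 ≈ 0.00195
    BLOCK_SIZE = 16    # NVFP4 per-block group size assumed here

    def patch_expert_amax(weight_slice: torch.Tensor, amax: torch.Tensor):
        out_features, in_features = weight_slice.shape
        # Correctly shaped per-block fallback: (out_features, in_features // BLOCK_SIZE),
        # instead of the old flattened (N, 1) tensor that triggered the shape mismatch.
        fallback = (
            weight_slice.detach()
            .abs()
            .reshape(out_features, in_features // BLOCK_SIZE, BLOCK_SIZE)
            .amax(dim=-1)
        )
        invalid = ~torch.isfinite(amax) | (amax < AMAX_FLOOR)
        patched = torch.where(invalid, fallback, amax).clamp(min=AMAX_FLOOR)
        # Recompute global_amax from the patched per-block values so a stale zero from
        # an uncalibrated expert cannot reach the FP8 scale formula's denominator.
        return patched, patched.max()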

Additional fixes:

  • moe_utils.py: safe CPU extraction of _amax before deepcopy to avoid async CUDA errors from corrupt bfloat16 amax storage on under-calibrated experts.
  • model_quant.py: print_quant_summary now calls os.makedirs(output_dir, exist_ok=True) before writing .quant_summary.txt, preventing a FileNotFoundError when the export directory doesn't exist yet.
  • tensor_quantizer.py: change default format in _short_amax / _short_tensor from ".4f" to ".2e" so small amax values (e.g. 3.5e-7) display as 3.50e-07 instead of 0.0000.
  • hf_ptq.py: strip leading pad tokens from the preview input and add skip_special_tokens=True to input_decode, fixing degenerate pre/post-PTQ output on models that use EOS as the pad token (e.g. Qwen3).
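
A minimal sketch of the pad-stripping behavior from the last item (function and variable names are illustrative, not the exact hf_ptq.py code):

    import torch

    def strip_leading_pad(preview_input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
        """Drop leading pad tokens from a (1, seq_len) preview batch so the decoded prompt
        shows real content on tokenizers that reuse EOS as pad (e.g. Qwen3)."""
        non_pad = (preview_input_ids[0] != pad_token_id).nonzero(as_tuple=True)[0]
        if non_pad.numel() == 0:
            return preview_input_ids  # all-padding input: keep it unchanged
        return preview_input_ids[:, int(non_pad[0]):]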

Usage

 # Quantize Qwen3.6-35B-A3B (or any compatible fused-expert MoE) with the new recipe:
 python examples/llm_ptq/hf_ptq.py \
     --pyt_ckpt_path /path/to/Qwen3.6-35B-A3B \
     --recipe modelopt_recipes/general/ptq/nvfp4_experts_only_mse.yaml \
     --export_path /path/to/output \
     --calib_size 512 --calib_seq 2048

Testing

Validated on Qwen3.6-35B-A3B (8× B200):

  • 21,740 quantizers inserted; 20,480/20,480 MSE weight calibrations completed (~11 min)
  • 0 / 2,013,265,920 zero weight_scale entries in the exported checkpoint (3 shards)
  • Pre- and post-PTQ generation produce coherent, semantically consistent output

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features

    • Added a new PTQ recipe for NVFP4 quantization of MoE routed experts with MSE-based calibration and FP8 scale sweep.
  • Bug Fixes

    • Improved robustness of MoE expert export by safely handling CUDA tensors and validating quantization scales.
    • Fixed FP8 scale underflow to zero during quantization.
    • Auto-creates output directories for quantization summary files.
    • Enhanced preview input handling for tokenization accuracy.
  • Improvements

    • Extended model calibration to support fused expert quantizers.

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
@Fridah-nv Fridah-nv requested review from a team as code owners May 2, 2026 00:14
@Fridah-nv Fridah-nv requested review from Edwardf0t1 and sychen52 May 2, 2026 00:14

coderabbitai Bot commented May 2, 2026

📝 Walkthrough

This PR improves mixture-of-experts (MoE) quantization safety and calibration, introduces per-expert weight calibration support, adds a new NVFP4 experts-only PTQ recipe, and refines LLM PTQ preview input handling. It includes FP8 scale underflow prevention, safer CUDA tensor handling during MoE export, per-expert calibration discovery, and display formatting updates.

Changes

MoE Quantization Infrastructure & Experts-Only PTQ Recipe

  • FP8 Scale Safety (modelopt/torch/quantization/qtensor/nvfp4_tensor.py): FP8 E4M3FN scales are now clamped to the minimum subnormal (2^-9) before casting in both the static and dynamic quantizer paths, preventing underflow to zero.
  • Quantizer Display (modelopt/torch/quantization/nn/modules/tensor_quantizer.py): _short_amax and _short_tensor format strings switched from .4f to .2e for scientific-notation display.
  • Per-Expert Calibration Support (modelopt/torch/quantization/model_calib.py): mse_calibrate now detects _QuantFusedExperts per-expert weight quantizers stored in nn.ModuleList and enqueues them with (param_name, expert_idx) tuple identifiers for individual calibration.
  • MoE Export Safety (modelopt/torch/export/moe_utils.py): introduces a _safe_cpu_amax helper that extracts quantizer _amax to CPU before deepcopy, avoiding CUDA state corruption; adds per-block _amax validation with a weight-derived fallback for invalid entries; recomputes global_amax from patched per-block values; and ensures _amax is on the weight device before NVFP4 export.
  • Output Directory Handling (modelopt/torch/quantization/model_quant.py): print_quant_summary now creates output_dir if it does not exist before writing the summary file.
  • Recipe Configuration (modelopt_recipes/general/ptq/nvfp4_experts_only_mse.yaml): new PTQ recipe configuring NVFP4 W4A4 MSE-based quantization for sequential, block-sparse, and fused MoE experts with FP8 scale sweep enabled; disables gates, routers, shared experts, attention, and non-expert layers.

LLM PTQ Preview Input Handling

  • Preview Input Preparation (examples/llm_ptq/hf_ptq.py): pre_quantize now strips leading pad tokens from preview_input_ids for non-Whisper models when the tokenizer and pad_token_id are available; post_quantize decoding now passes skip_special_tokens=True when decoding preview outputs.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 inconclusive)

  • Title check (❓ Inconclusive): the title refers to key aspects of the PR (fused MoE, Qwen3.6, GLM5.1, MSE calibration) but is incomplete: it lists affected models and mentions fixes without clearly identifying the primary change or issue being fixed. Resolution: consider a more specific title, e.g. 'Fix MSE calibration and export for fused MoE experts in Qwen3.6/GLM5.1' or 'Fix per-block scale underflow and calibration discovery for fused MoE'.

✅ Passed checks (5 passed)

  • Description Check (✅ Passed): check skipped because CodeRabbit's high-level summary is enabled.
  • Docstring Coverage (✅ Passed): docstring coverage is 91.67%, above the required threshold of 80.00%.
  • Linked Issues check (✅ Passed): check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check (✅ Passed): check skipped because no linked issues were found for this pull request.
  • Security Anti-Patterns (✅ Passed): no security anti-patterns from SECURITY.md guidelines found; all dangerous operations are properly safeguarded or absent.



github-actions Bot commented May 2, 2026

PR Preview Action v1.8.1

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1382/

Built to branch gh-pages at 2026-05-02 00:18 UTC.
Preview will be ready when the GitHub Pages deployment is complete.


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1


📥 Commits

Reviewing files that changed from the base of the PR and between 9d2e608 and 35dad9a.

📒 Files selected for processing (7)
  • examples/llm_ptq/hf_ptq.py
  • modelopt/torch/export/moe_utils.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/model_quant.py
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py
  • modelopt/torch/quantization/qtensor/nvfp4_tensor.py
  • modelopt_recipes/general/ptq/nvfp4_experts_only_mse.yaml

Comment on lines +98 to +103
if is_gate_up:
    _saved_amax = getattr(w_quantizer_src, "_amax", None)
    w_quantizer_src._amax = None
    w_quantizer = copy.deepcopy(w_quantizer_src)
    w_quantizer_src._amax = _saved_amax
    w_quantizer._amax = gu_amax_cpu

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Protect temporary _amax mutation with try/finally.

If copy.deepcopy() raises, _amax is left as None on the source quantizer. Wrap restore in finally to avoid state corruption on failure.

Proposed fix
             if is_gate_up:
                 _saved_amax = getattr(w_quantizer_src, "_amax", None)
-                w_quantizer_src._amax = None
-                w_quantizer = copy.deepcopy(w_quantizer_src)
-                w_quantizer_src._amax = _saved_amax
+                w_quantizer_src._amax = None
+                try:
+                    w_quantizer = copy.deepcopy(w_quantizer_src)
+                finally:
+                    w_quantizer_src._amax = _saved_amax
                 w_quantizer._amax = gu_amax_cpu

@Fridah-nv Fridah-nv requested a review from cjluo-nv May 2, 2026 00:21

@cjluo-nv cjluo-nv left a comment


Bot review — DM the bot to share feedback.

This PR fixes several real bugs in the fused MoE quantization pipeline (MSE calibration discovery, FP8 scale underflow, uncalibrated expert export). The fixes are well-described in the PR body and address genuine correctness issues. However, there are several concerns:

  1. Missing unit tests (critical): No tests are added for any of the bug fixes. The existing test_fused_experts.py covers registration/conversion/basic export but doesn't exercise MSE calibration for fused experts, FP8 scale clamping, or the invalid-amax patching logic. Given the complexity of the moe_utils.py changes and the project's known pattern of missing tests, this is a blocking concern.

  2. Threshold inconsistency: _MIN_VALID_AMAX = 1e-4 is below FP8 E4M3FN minimum (2^-9 ≈ 0.00195), meaning values between 1e-4 and 2e-3 pass the validity check but could still underflow.

  3. Hardcoded block_size=16: The fallback per-block amax computation in moe_utils.py hardcodes 16. If the actual block size differs, the shape will be wrong.

  4. Copyright year: New YAML file has Copyright (c) 2024 but LICENSE_HEADER says 2026.

)

# If the weight quantizer was never calibrated, compute amax from weights.
# Patch invalid per-block amax entries (NaN/inf/negative/zero/too-small/too-large)

Bot comment.

Bug: _MIN_VALID_AMAX = 1e-4 is below the FP8 E4M3FN minimum subnormal (2^-9 ≈ 0.00195). Values between 1e-4 and ~0.00195 will pass this validity check but will still underflow to 0 when cast to FP8 E4M3FN. Consider using 2e-3 (which you already use for clamping) or 2**-9 as the minimum valid threshold for consistency with the nvfp4_tensor.py fix.

)
if invalid_mask.any():
    per_block_fallback = (
        weight_slice.detach()

Bot comment.

Hardcoded block_size=16 here and at line 173. If the quantizer's actual block size is different, the reshape will produce an incorrect shape. Consider extracting the block size from the weight quantizer (e.g. w_quantizer.block_sizes.get(-1, 16)) rather than hardcoding.
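
A hedged sketch of that suggestion (the block_sizes attribute and its layout are assumed from this comment, not checked against the quantizer API):

    block_size = (getattr(w_quantizer, "block_sizes", None) or {}).get(-1, 16)
    per_block_fallback = (
        weight_slice.detach().abs().reshape(*weight_slice.shape[:-1], -1, block_size).amax(dim=-1)
    )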

        torch.cuda.synchronize(amax.device)
        return amax.detach().cpu().float()
    except Exception:
        return None

Bot comment.

Bare except Exception silently swallows all errors and returns None. While defensive coding for corrupt CUDA tensors is reasonable, this could mask unrelated bugs. Consider catching a narrower set of exceptions (e.g. RuntimeError) or at minimum logging a warning when the fallback is triggered:

except RuntimeError:
    warnings.warn(f"Failed to extract _amax to CPU for {quantizer_src}, using fallback")
    return None

# the corrupt CUDA tensor, then inject the pre-extracted CPU amax.
if is_gate_up:
    _saved_amax = getattr(w_quantizer_src, "_amax", None)
    w_quantizer_src._amax = None

Bot comment.

For down_proj (non-gate_up), w_quantizer = w_quantizer_src — this is the original quantizer, not a copy. Then w_quantizer._amax = down_amax_cpu mutates the original quantizer's _amax. This is fine if the module is only exported once, but is potentially surprising. A comment noting this is intentional mutation of the original would help.

@@ -0,0 +1,130 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

Bot comment.

Copyright year is 2024 but the project's LICENSE_HEADER specifies 2026. New files should use the current year from the canonical header.

# per-expert quantizer individually.
for param_name, _ in parent_module.named_parameters(recurse=False):
    qlist = getattr(parent_module, f"{param_name}_weight_quantizers", None)
    if not isinstance(qlist, nn.ModuleList):

Bot comment.

The pattern named_parameters(recurse=False) + checking for f"{param_name}_weight_quantizers" works for the current _QuantFusedExperts layout, but is fairly fragile. If other modules happen to have a parameter and a same-named ModuleList with _weight_quantizers suffix, they'd be picked up too. Consider adding a type check (e.g. checking if parent_module is a _QuantFusedExperts instance) or at least a comment noting the assumption.

"input_features" if model_type == "whisper" else "input_ids"
][0:1]
# Strip leading padding tokens so the preview input shows real content
if model_type not in ("whisper",) and tokenizer is not None and tokenizer.pad_token_id is not None:

Bot comment.

Minor: if preview_input_ids has no non-pad tokens (e.g. all tokens are padding), first_non_pad will be empty and first_non_pad[0] will error. The first_non_pad.numel() > 0 check correctly guards this — just confirming it's intentional that the original (all-padding) input is preserved in that edge case.
