fixes for fused moe (qwen3.6, GLM5.1) + MSE calibration #1382
base: main
Changes from all commits
First file, fused-expert export (`_export_fused_experts`):
```diff
@@ -59,9 +59,26 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
     # 2-3. Split + export each per-expert projection.
     fused_dim0 = gate_up.shape[1]  # 2 * expert_dim

+    def _safe_cpu_amax(quantizer_src: nn.Module) -> torch.Tensor | None:
+        """Extract _amax to CPU float32, surfacing and clearing any pending CUDA error first."""
+        amax = getattr(quantizer_src, "_amax", None)
+        if amax is None or not isinstance(amax, torch.Tensor):
+            return None
+        try:
+            if amax.is_cuda:
+                torch.cuda.synchronize(amax.device)
+            return amax.detach().cpu().float()
+        except Exception:
+            return None
+
     for idx in range(n):
         expert = nn.Module()

+        # Extract amaxes to CPU before deepcopy: cloning a corrupt CUDA _amax tensor
+        # (e.g. from an under-calibrated expert) triggers an async CUDA error.
+        gu_amax_cpu = _safe_cpu_amax(module.gate_up_proj_weight_quantizers[idx])
+        down_amax_cpu = _safe_cpu_amax(module.down_proj_weight_quantizers[idx])
+
         projections = [
             ("gate_proj", gate_up[idx, :expert_dim, :], 0, fused_dim0, True),
             ("up_proj", gate_up[idx, expert_dim:, :], expert_dim, fused_dim0, True),
```

Collaborator (on the `except Exception:` line): Bare `except`; suggested:

```python
except RuntimeError:
    warnings.warn(f"Failed to extract _amax to CPU for {quantizer_src}, using fallback")
    return None
```
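For illustration, the synchronize-before-copy order that `_safe_cpu_amax` relies on can be exercised on its own. A minimal sketch with a hypothetical `DummyQuantizer` (not part of this PR) standing in for a TensorQuantizer:

```python
import torch
import torch.nn as nn


class DummyQuantizer(nn.Module):
    """Stand-in for a TensorQuantizer that only carries an _amax tensor."""

    def __init__(self, amax=None):
        super().__init__()
        self._amax = amax


def safe_cpu_amax(quantizer: nn.Module):
    # Same shape as the PR helper: synchronize first so any pending async CUDA
    # error surfaces (and is caught) here instead of inside a later deepcopy.
    amax = getattr(quantizer, "_amax", None)
    if amax is None or not isinstance(amax, torch.Tensor):
        return None
    try:
        if amax.is_cuda:
            torch.cuda.synchronize(amax.device)
        return amax.detach().cpu().float()
    except RuntimeError:
        return None


print(safe_cpu_amax(DummyQuantizer(torch.tensor([1.5, 0.25]))))  # tensor([1.5000, 0.2500])
print(safe_cpu_amax(DummyQuantizer()))                           # None
```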
```diff
@@ -76,8 +93,17 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
             )
             i_quantizer = gate_up_input_q if is_gate_up else down_input_q

-            # gate/up share a weight quantizer — clone so each gets independent amax.
-            w_quantizer = copy.deepcopy(w_quantizer_src) if is_gate_up else w_quantizer_src
+            # gate/up share a quantizer — deepcopy with _amax nulled to avoid cloning
+            # the corrupt CUDA tensor, then inject the pre-extracted CPU amax.
+            if is_gate_up:
+                _saved_amax = getattr(w_quantizer_src, "_amax", None)
+                w_quantizer_src._amax = None
+                w_quantizer = copy.deepcopy(w_quantizer_src)
+                w_quantizer_src._amax = _saved_amax
+                w_quantizer._amax = gu_amax_cpu
+            else:
+                w_quantizer = w_quantizer_src
+                w_quantizer._amax = down_amax_cpu

             # For per-channel amax (dim >= 1), proportionally slice dim-0
             # to match the split weight.
```

Contributor (on lines +98 to +103): Protect the temporary `_amax` swap. If `copy.deepcopy` raises, `w_quantizer_src._amax` is never restored. Proposed fix:

```diff
             if is_gate_up:
                 _saved_amax = getattr(w_quantizer_src, "_amax", None)
-                w_quantizer_src._amax = None
-                w_quantizer = copy.deepcopy(w_quantizer_src)
-                w_quantizer_src._amax = _saved_amax
+                w_quantizer_src._amax = None
+                try:
+                    w_quantizer = copy.deepcopy(w_quantizer_src)
+                finally:
+                    w_quantizer_src._amax = _saved_amax
                 w_quantizer._amax = gu_amax_cpu
```
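The restore-on-error concern generalizes. A hedged sketch of the same save/null/restore pattern as a reusable context manager (a generic helper, not something the PR adds):

```python
import contextlib


@contextlib.contextmanager
def attr_swapped(obj, name, temp_value):
    """Temporarily set obj.<name> to temp_value, restoring the old value even on error."""
    saved = getattr(obj, name, None)
    setattr(obj, name, temp_value)
    try:
        yield obj
    finally:
        setattr(obj, name, saved)


# Hypothetical usage mirroring the gate/up branch:
# with attr_swapped(w_quantizer_src, "_amax", None):
#     w_quantizer = copy.deepcopy(w_quantizer_src)
# w_quantizer._amax = gu_amax_cpu
```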
```diff
@@ -86,12 +112,14 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
                 and w_quantizer._amax is not None
                 and w_quantizer._amax.dim() >= 1
             ):
-                amax = w_quantizer._amax
+                amax = w_quantizer._amax  # CPU float32
                 amax_dim0 = amax.shape[0]
-                if fused_total % amax_dim0 == 0:
+                if amax_dim0 % fused_total == 0:
                     slice_start = fused_start * amax_dim0 // fused_total
                     slice_end = (fused_start + weight_slice.shape[0]) * amax_dim0 // fused_total
-                    w_quantizer.amax = amax[slice_start:slice_end].contiguous()
+                    # Bypass amax.setter (which forbids shape changes); w_quantizer is a
+                    # deepcopy for gate/up so mutating it is safe.
+                    w_quantizer._amax = amax[slice_start:slice_end].contiguous()
                 else:
                     warnings.warn(
                         f"Expert {idx} {proj_name}: fused amax dim0 ({amax_dim0}) does not "
```
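To make the proportional dim-0 slice concrete, a small worked example with made-up sizes (expert_dim = 4, so the fused gate_up dim-0 is 8, and one amax entry per fused output channel):

```python
import torch

expert_dim = 4
fused_total = 2 * expert_dim            # fused gate_up dim-0 (8)
amax = torch.arange(1.0, 9.0)           # one amax entry per fused output channel
amax_dim0 = amax.shape[0]               # 8, divisible by fused_total

for name, fused_start in [("gate_proj", 0), ("up_proj", expert_dim)]:
    rows = expert_dim                    # weight_slice.shape[0]
    slice_start = fused_start * amax_dim0 // fused_total
    slice_end = (fused_start + rows) * amax_dim0 // fused_total
    print(name, amax[slice_start:slice_end].tolist())
# gate_proj [1.0, 2.0, 3.0, 4.0]
# up_proj   [5.0, 6.0, 7.0, 8.0]
```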
```diff
@@ -100,20 +128,68 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
                         stacklevel=2,
                     )

-            # If the weight quantizer was never calibrated, compute amax from weights.
+            # Patch invalid per-block amax entries (NaN/inf/negative/zero/too-small/too-large)
+            # with weight-derived fallback values.
+            _MIN_VALID_AMAX = 1e-4
+            _MAX_VALID_AMAX = 1e6
+            if (
+                hasattr(w_quantizer, "_amax")
+                and w_quantizer._amax is not None
+                and w_quantizer._amax.numel() > 1
+            ):
+                amax_cpu = w_quantizer._amax
+                invalid_mask = ~(
+                    torch.isfinite(amax_cpu)
+                    & (amax_cpu >= _MIN_VALID_AMAX)
+                    & (amax_cpu <= _MAX_VALID_AMAX)
+                )
+                if invalid_mask.any():
+                    per_block_fallback = (
+                        weight_slice.detach()
+                        .reshape(-1, 16)
+                        .abs()
+                        .amax(dim=1, keepdim=True)
+                        .cpu()
+                        .float()
+                        .clamp(min=2e-3)
+                        .reshape(amax_cpu.shape)
+                    )
+                    amax_cpu[invalid_mask] = per_block_fallback[invalid_mask]
+                    w_quantizer._amax = amax_cpu
+
+            # For uncalibrated experts (amax missing or invalid scalar), fall back to
+            # per-block amax from weights so the static export path can reshape it correctly.
             if (
                 hasattr(w_quantizer, "is_enabled")
                 and w_quantizer.is_enabled
                 and (
                     not hasattr(w_quantizer, "_amax")
                     or w_quantizer._amax is None
                     or torch.all(w_quantizer._amax == 0)
+                    or (
+                        w_quantizer._amax.numel() == 1
+                        and not (
+                            torch.isfinite(w_quantizer._amax)
+                            and w_quantizer._amax >= _MIN_VALID_AMAX
+                            and w_quantizer._amax <= _MAX_VALID_AMAX
+                        )
+                    )
                 )
             ):
-                w_quantizer.amax = weight_slice.abs().amax().to(torch.float32)
+                _block_size = 16
+                fallback_per_block = (
+                    weight_slice.detach()
+                    .reshape(-1, _block_size)
+                    .abs()
+                    .amax(dim=1, keepdim=True)
+                    .cpu()
+                    .float()
+                    .clamp(min=2e-3)
+                    .reshape(*weight_slice.shape[:-1], weight_slice.shape[-1] // _block_size)
+                )
+                w_quantizer._amax = fallback_per_block
                 warnings.warn(
                     f"Expert {idx} {proj_name} weight quantizer was not calibrated "
-                    f"(amax missing or zero). Using weight-derived amax as fallback. "
+                    f"(amax missing or zero). Using weight-derived per-block amax as fallback. "
                     f"Consider using more calibration data to activate all experts.",
                     stacklevel=2,
                 )
```
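The weight-derived fallback is simply a per-block absolute maximum over groups of 16 along the last dimension (16 being the block size the recipe below configures). A standalone sketch with toy shapes:

```python
import torch

block_size = 16
weight_slice = torch.randn(8, 64)  # toy (out_features, in_features), in_features % 16 == 0

fallback_per_block = (
    weight_slice.detach()
    .reshape(-1, block_size)         # one row per (channel, block) pair
    .abs()
    .amax(dim=1, keepdim=True)       # max |w| within each 16-element block
    .cpu()
    .float()
    .clamp(min=2e-3)                 # avoid zero scales downstream
    .reshape(*weight_slice.shape[:-1], weight_slice.shape[-1] // block_size)
)
print(fallback_per_block.shape)  # torch.Size([8, 4])
```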
```diff
@@ -123,6 +199,18 @@ def _export_fused_experts(module: nn.Module, dtype: torch.dtype) -> None:
             wrapper.weight_quantizer = w_quantizer
             wrapper.input_quantizer = i_quantizer

+            # Set global_amax to route to the static NVFP4 export path (reads per-block _amax).
+            # Always recompute from the current (possibly patched) _amax — a stale zero
+            # global_amax causes division-by-zero in the per-block scale formula.
+            wq = wrapper.weight_quantizer
+            if (
+                hasattr(wq, "_amax")
+                and wq._amax is not None
+                and wq._amax.numel() > 1
+            ):
+                wq._amax = wq._amax.to(weight_slice.device)
+                wq.global_amax = wq._amax.float().amax().clamp(min=2e-3)
+
             _export_quantized_weight(wrapper, dtype)

             proj = nn.Module()
```
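Why a zero `global_amax` is fatal: assuming the usual two-level NVFP4 scaling (FP4 E2M1 max 6, FP8 E4M3 max 448; the exact formula lives in `_export_quantized_weight` and may differ in detail), the per-block scales are divided by a global scale derived from `global_amax`, so zero blows every block scale up to inf:

```python
import torch

E2M1_MAX = 6.0    # largest FP4 (E2M1) magnitude
E4M3_MAX = 448.0  # largest FP8 (E4M3) magnitude


def per_block_scales(per_block_amax: torch.Tensor, global_amax: torch.Tensor) -> torch.Tensor:
    # Global FP32 scale sized so the per-block scales fit into FP8.
    global_scale = global_amax.float() / (E2M1_MAX * E4M3_MAX)
    # Per-block scales stored in FP8; a zero global_scale blows these up to inf.
    return per_block_amax.float() / E2M1_MAX / global_scale


amax = torch.tensor([0.5, 1.0, 2.0])
print(per_block_scales(amax, amax.amax()))         # finite, <= 448
print(per_block_scales(amax, torch.tensor(0.0)))   # inf everywhere -> why global_amax is clamped
```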
Second file, MSE calibration (`mse_calibrate`):

```diff
@@ -421,6 +421,18 @@ def mse_calibrate(
             if isinstance(weight_quantizer, TensorQuantizer) and weight_quantizer.is_enabled:
                 if getattr(weight_quantizer, "_calibrator", None) is not None:
                     weight_quantizers.append((parent_module, weight_name, weight_quantizer))
+            # _QuantFusedExperts stores per-expert weight quantizers as nn.ModuleList named
+            # {param_name}_weight_quantizers (plural). Detect this pattern and enqueue each
+            # per-expert quantizer individually.
+            for param_name, _ in parent_module.named_parameters(recurse=False):
+                qlist = getattr(parent_module, f"{param_name}_weight_quantizers", None)
+                if not isinstance(qlist, nn.ModuleList):
+                    continue
+                for expert_idx, wq in enumerate(qlist):
+                    if isinstance(wq, TensorQuantizer) and wq.is_enabled:
+                        if getattr(wq, "_calibrator", None) is not None:
+                            weight_quantizers.append((parent_module, (param_name, expert_idx), wq))

             seen_modules.add(parent_module)

     # Step 3: Calibrate weight quantizers ONE AT A TIME with immediate amax computation
@@ -432,7 +444,11 @@ def mse_calibrate(
         weight_quantizer.disable_quant()
         weight_quantizer.enable_calib()
         with enable_weight_access_and_writeback(parent_module, model, name_to_module):
-            weight = getattr(parent_module, weight_name)
+            if isinstance(weight_name, tuple):
+                param_name, expert_idx = weight_name
+                weight = getattr(parent_module, param_name)[expert_idx]
+            else:
+                weight = getattr(parent_module, weight_name)
             weight_quantizer(weight)

         # IMMEDIATELY compute amax and reset calibrator to free memory
```
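For orientation, the `(param_name, expert_idx)` tuple key just selects one expert's 2-D slice out of the fused 3-D parameter before handing it to the calibrator. A toy sketch with made-up shapes (the real fused module also carries matching `*_weight_quantizers` lists):

```python
import torch
import torch.nn as nn

n_experts, expert_dim, hidden = 4, 8, 16

# Fused 3D parameter, HF transformers 5.0+ style: one 2D weight per expert stacked on dim 0.
parent_module = nn.Module()
parent_module.gate_up_proj = nn.Parameter(torch.randn(n_experts, 2 * expert_dim, hidden))

weight_name = ("gate_up_proj", 2)   # (param_name, expert_idx) tuple key from the queue
param_name, expert_idx = weight_name
weight = getattr(parent_module, param_name)[expert_idx]
print(weight.shape)  # torch.Size([16, 16]): the single expert slice the quantizer calibrates on
```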
Third file, a new NVFP4 MoE PTQ recipe added in full (`@@ -0,0 +1,130 @@`):

```yaml
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

metadata:
  recipe_type: ptq
  description: >
    NVFP4 W4A4 for MoE routed experts only. Static weight scales via MSE + FP8 scale sweep;
    dynamic activation scales. Supports sequential experts (nn.Linear-based) and fused experts
    (_QuantFusedExperts, HF transformers 5.0+ 3D nn.Parameter style).
quantize:
  algorithm:
    method: mse
    fp8_scale_sweep: true
    layerwise: false
  quant_cfg:
    # ── Disable everything first ─────────────────────────────────────────────
    - quantizer_name: '*'
      enable: false

    # ── Sequential experts (nn.Linear per expert) ────────────────────────────
    - quantizer_name: '*mlp.experts*weight_quantizer'
      enable: true
      cfg:
        block_sizes:
          -1: 16
          type: static
          scale_bits: e4m3
        num_bits: e2m1
    - quantizer_name: '*mlp.experts*input_quantizer'
      enable: true
      cfg:
        block_sizes:
          -1: 16
          type: dynamic
          scale_bits: e4m3
        num_bits: e2m1

    # ── Sequential experts: Mixtral / block_sparse_moe style ─────────────────
    - quantizer_name: '*block_sparse_moe*weight_quantizer'
      enable: true
      cfg:
        block_sizes:
          -1: 16
          type: static
          scale_bits: e4m3
        num_bits: e2m1
    - quantizer_name: '*block_sparse_moe*input_quantizer'
      enable: true
      cfg:
        block_sizes:
          -1: 16
          type: dynamic
          scale_bits: e4m3
        num_bits: e2m1

    # ── Fused experts (_QuantFusedExperts, HF transformers 5.0+ 3D nn.Parameter style) ──
    - quantizer_name: '*gate_up_proj_weight_quantizers*'
      enable: true
      cfg:
        block_sizes:
          -1: 16
          type: static
          scale_bits: e4m3
        num_bits: e2m1
    - quantizer_name: '*gate_up_proj_input_quantizer*'
      enable: true
      cfg:
        block_sizes:
          -1: 16
          type: dynamic
          scale_bits: e4m3
        num_bits: e2m1
    - quantizer_name: '*down_proj_weight_quantizers*'
      enable: true
      cfg:
        block_sizes:
          -1: 16
          type: static
          scale_bits: e4m3
        num_bits: e2m1
    - quantizer_name: '*down_proj_input_quantizer*'
      enable: true
      cfg:
        block_sizes:
          -1: 16
          type: dynamic
          scale_bits: e4m3
        num_bits: e2m1

    # ── Exclusions: shared experts, attention, routers, lm_head ──────────────
    - quantizer_name: '*block_sparse_moe.gate*'
      enable: false
    - quantizer_name: '*linear_attn.conv1d*'
      enable: false
    - quantizer_name: '*lm_head*'
      enable: false
    - quantizer_name: '*mlp.gate.*'
      enable: false
    - quantizer_name: '*mlp.shared_expert*'
      enable: false
    - quantizer_name: '*mlp.shared_expert_gate.*'
      enable: false
    - quantizer_name: '*router*'
      enable: false
    - quantizer_name: 'output.*'
      enable: false
    - parent_class: 'nn.BatchNorm1d'
      quantizer_name: '*'
      enable: false
    - parent_class: 'nn.BatchNorm2d'
      quantizer_name: '*'
      enable: false
    - parent_class: 'nn.BatchNorm3d'
      quantizer_name: '*'
      enable: false
```
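The recipe's `quantizer_name` entries are glob-style patterns matched against dotted quantizer names. A rough illustration using plain `fnmatch` and hypothetical names (the actual matcher and rule-precedence inside ModelOpt may differ):

```python
from fnmatch import fnmatch

names = [
    "model.layers.0.mlp.experts.3.down_proj.weight_quantizer",
    "model.layers.0.mlp.experts.gate_up_proj_weight_quantizers.3",
    "model.layers.0.mlp.shared_expert.up_proj.weight_quantizer",
    "lm_head.weight_quantizer",
]

patterns = [
    ("*mlp.experts*weight_quantizer", True),     # sequential experts -> enabled
    ("*gate_up_proj_weight_quantizers*", True),  # fused experts -> enabled
    ("*mlp.shared_expert*", False),              # exclusions -> disabled
    ("*lm_head*", False),
]

for name in names:
    hits = [(pat, enable) for pat, enable in patterns if fnmatch(name, pat)]
    print(name, "->", hits or "falls through to the '*' disable rule")
```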
Minor: if `preview_input_ids` has no non-pad tokens (e.g. all tokens are padding), `first_non_pad` will be empty and `first_non_pad[0]` will error. The `first_non_pad.numel() > 0` check correctly guards this — just confirming it's intentional that the original (all-padding) input is preserved in that edge case.