From 3faee18447fdcea4eab7ef7e088be95c8681bbd6 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Tue, 28 Apr 2026 12:35:02 -0700 Subject: [PATCH 1/7] Add MXFP4 -> NVFP4 conversion MSE experiment (scratch) Research artifact comparing three algorithms for converting an MXFP4 tensor (block 32, E2M1 + E8M0) to NVFP4 (block 16, E2M1 + E4M3 + FP32 global scale): Algo 1: dequantize MXFP4 -> bf16 -> standard NVFP4 quantize. Algo 2: keep E2M1 nibbles verbatim; pick global S = 2^m and store per-block E4M3 scales as 2^(k_j - m), snapping out-of-range blocks. Two m strategies: midpoint and 1D integer search over the closed-form snap-error objective. Algo 3: hybrid - verbatim path for in-range blocks (zero error) plus NVFP4 requantization with fixed scale_2 = 2^m for OOR blocks. m chosen by direct-MSE 1D sweep. Includes 27 scenarios (gaussian, heavy-tail, outlier patterns, spread boundary tests, layer-shaped LLM weights) and a report summarizing results, the snap-up/snap-down asymmetry that drives the m choice, and the one pathological case (single dominant outlier) where Algo 3 still trails Algo 1 by 0.21% due to integer-m vs continuous scale_2. Signed-off-by: Chenjie Luo --- scratch/mxfp4_to_nvfp4_mse.py | 739 +++++++++++++++++++++++++++++++ scratch/mxfp4_to_nvfp4_report.md | 162 +++++++ 2 files changed, 901 insertions(+) create mode 100644 scratch/mxfp4_to_nvfp4_mse.py create mode 100644 scratch/mxfp4_to_nvfp4_report.md diff --git a/scratch/mxfp4_to_nvfp4_mse.py b/scratch/mxfp4_to_nvfp4_mse.py new file mode 100644 index 0000000000..62466714ab --- /dev/null +++ b/scratch/mxfp4_to_nvfp4_mse.py @@ -0,0 +1,739 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Research scratch script — relax some style rules that don't add value here. +# ruff: noqa: D103, RUF003 + +"""MXFP4 -> NVFP4 conversion MSE experiment. + +Compares two algorithms for converting an MXFP4 tensor (block_size=32, E2M1 + E8M0 +power-of-2 scales) to NVFP4 (block_size=16, E2M1 + E4M3 scales + global FP32 scale): + + Algo 1 (dequant-requant): dequantize MXFP4 to BF16, then quantize to NVFP4 the + standard way. This re-buckets nibbles and computes new scales from scratch. + + Algo 2 (verbatim nibbles): keep the E2M1 nibbles unchanged. Each MXFP4 block of 32 + splits into two NVFP4 blocks of 16, both inheriting the same exponent k_j. + Pick a global scale S = 2^m (integer m) and store the per-block E4M3 scale as + 2^(k_j - m). E4M3 exactly represents 2^k for k in [-9, 8], so as long as + max(k) - min(k) <= 17 there is a valid m and the conversion is exact (zero + MSE). For blocks outside that window, snap the per-block exponent to the + [-9, 8] boundary; nibbles stay verbatim, and that snap is provably MSE-optimal + given the constraint. + +Reference for both algos: the MXFP4-dequantized tensor (i.e. what the source +representation faithfully encodes). MSE is computed against that reference in fp32. +""" + +import math + +import torch + +from modelopt.torch.quantization.qtensor.mxfp4_tensor import MXFP4QTensor +from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +MX_BLOCK = 32 +NV_BLOCK = 16 +E4M3_KMIN, E4M3_KMAX = -9, 8 # E4M3 represents 2^k exactly for k in [-9, 8] + +# E2M1 magnitude squared, indexed by nibble bits (sign bit ignored — squared anyway). +# Sign bit is the high bit (0b1000); low 3 bits are the magnitude index into +# [0, 0.5, 1, 1.5, 2, 3, 4, 6]. Squared magnitude lookup for all 16 nibble values: +_E2M1_SQ = torch.tensor( + [0.0, 0.25, 1.0, 2.25, 4.0, 9.0, 16.0, 36.0, 0.0, 0.25, 1.0, 2.25, 4.0, 9.0, 16.0, 36.0], + dtype=torch.float32, +) + + +# ---------- Algorithm 1: dequant -> requant ---------------------------------- + + +def algo1_dequant_requant(mxfp4_qt: MXFP4QTensor, e8m0_scale: torch.Tensor) -> torch.Tensor: + """Dequantize MXFP4 then quantize to NVFP4 the normal way; return float32 reconstruction.""" + deq_bf16 = mxfp4_qt.dequantize( + dtype=torch.bfloat16, scale=e8m0_scale, block_sizes={-1: MX_BLOCK} + ) + nv_qt, per_block_e4m3, double_scale = NVFP4QTensor.quantize(deq_bf16, block_size=NV_BLOCK) + out = nv_qt.dequantize( + dtype=torch.float32, + scale=per_block_e4m3, + double_scale=double_scale, + block_sizes={-1: NV_BLOCK}, + ) + return out.float() + + +# ---------- Algorithm 2: keep nibbles, just rescale -------------------------- + + +def _block_sum_sq_nibbles( + mxfp4_qt: MXFP4QTensor, +) -> torch.Tensor: + """For each MXFP4 block, sum of squared E2M1 magnitudes (used for closed-form MSE). + + Returns a 1D tensor of length num_blocks, in float32. + """ + original_shape = mxfp4_qt.metadata["shape"] + packed = mxfp4_qt._quantized_data.view(*original_shape[:-1], -1, MX_BLOCK // 2) + low = (packed & 0x0F).long() + high = ((packed >> 4) & 0x0F).long() + sq = _E2M1_SQ.to(packed.device) + per_block = (sq[low] + sq[high]).sum(dim=-1) # one entry per MXFP4 block + return per_block.reshape(-1) + + +def _find_best_m( + k_flat: torch.Tensor, + sum_sq_flat: torch.Tensor, + k_min: int, + k_max: int, +) -> tuple[int, float]: + """Sweep integer m and return (best_m, best_total_squared_error). + + Per-block squared error when verbatim nibbles are kept and scale snaps to E4M3: + delta_j = k_j - m + snap_j = clamp(delta_j, [-9, 8]) + scale_diff = 2^k_j - 2^(m + snap_j) + err_j = sum_sq_j * scale_diff^2 + + In-range blocks (delta_j in [-9, 8]) contribute zero. Search range is symmetric + around the k window — outside it, every block snaps and error grows monotonically. + """ + candidates = list(range(k_min - E4M3_KMAX - 1, k_max - E4M3_KMIN + 2)) + k_f = k_flat.float() + pow2_k = torch.exp2(k_f) + best_m: int = candidates[0] + best_err: float = float("inf") + for m_cand in candidates: + delta = k_flat - m_cand + snap = torch.clamp(delta, E4M3_KMIN, E4M3_KMAX) + # snapped scale exponent: m + snap, but only differs from k when |delta| > 8/9 + snapped_scale = torch.exp2((m_cand + snap).float()) + diff = pow2_k - snapped_scale + err = (sum_sq_flat * diff * diff).sum().item() + if err < best_err: + best_err = err + best_m = m_cand + return best_m, best_err + + +def algo2_keep_nibbles( + mxfp4_qt: MXFP4QTensor, + e8m0_scale: torch.Tensor, + m_strategy: str = "midpoint", +) -> tuple[torch.Tensor, int, int, int]: + """Keep MXFP4 E2M1 nibbles verbatim and rescale. + + Choose a per-tensor m (S=2^m) and per-block E4M3 scales = 2^(k_j - m), + snapping out-of-range blocks to E4M3's boundary. + + m_strategy: + "midpoint" — when spread <=17, any valid m gives MSE=0 (we pick midpoint). + When spread >17, fall back to a heuristic: median(k) - center. + "search" — when spread <=17, behaves like "midpoint" (already optimal). + When spread >17, sweep integer m and pick the value that + minimizes total snap error in closed form. + """ + # Recover signed integer exponents k_j from E8M0 (stored as uint8 with bias 127). + k = e8m0_scale.to(torch.int32) - 127 + + # Identify blocks whose scale is irrelevant: all E2M1 nibbles have magnitude 0 + # (sign bit may be 0 or 1; MXFP4's cast_fp4 emits sign_bit=1 for value 0, giving + # "negative zero" nibbles 0x08 / 0x80, so packed bytes are 0x88). Mask 0x77. + original_shape = mxfp4_qt.metadata["shape"] + packed = mxfp4_qt._quantized_data.view(*original_shape[:-1], -1, MX_BLOCK // 2) + block_is_zero = ((packed & 0x77) == 0).all(dim=-1).reshape(-1) + k_flat = k.reshape(-1) + nonzero_mask = ~block_is_zero + nonzero_k = k_flat[nonzero_mask] if nonzero_mask.any() else k_flat + k_min = int(nonzero_k.min().item()) + k_max = int(nonzero_k.max().item()) + + spread_fits = (k_max - k_min) <= (E4M3_KMAX - E4M3_KMIN) + if spread_fits: + m = (k_max - E4M3_KMAX + k_min - E4M3_KMIN + 1) // 2 + m = max(k_max - E4M3_KMAX, min(m, k_min - E4M3_KMIN)) + elif m_strategy == "search": + sum_sq = _block_sum_sq_nibbles(mxfp4_qt) + # zero-blocks contribute 0 to S_j so they don't affect search either way; + # leave them in to keep shapes aligned. + m, _ = _find_best_m(k_flat, sum_sq, k_min, k_max) + else: + m = int(nonzero_k.median().item()) - (E4M3_KMAX + E4M3_KMIN) // 2 + + # Per-block exponent stored in the NVFP4 E4M3 scale: 2^(k_j - m), clamped to [-9, 8]. + e4m3_exp = torch.clamp(k - m, E4M3_KMIN, E4M3_KMAX) + e4m3_scale_fp32 = torch.exp2(e4m3_exp.float()) # exact powers of 2 + + # NVFP4 per-block scale lives on 16-element blocks; each MXFP4 block (32) splits + # into two NVFP4 blocks that share the same exponent. Round-trip through fp32 + # before casting to float8_e4m3fn to avoid repeat_interleave dtype quirks. + num_mx_blocks_per_row = original_shape[-1] // MX_BLOCK + e4m3_scale_nv = ( + e4m3_scale_fp32.view(*original_shape[:-1], num_mx_blocks_per_row) + .repeat_interleave(2, dim=-1) + .contiguous() + .to(torch.float8_e4m3fn) + ) + + # MXFP4 and NVFP4 use identical nibble packing (even idx low, odd idx high), so + # the bytes carry over verbatim. + nv_qt = NVFP4QTensor(original_shape, mxfp4_qt.metadata["dtype"], mxfp4_qt._quantized_data) + double_scale = torch.tensor(float(2.0**m), device=DEVICE, dtype=torch.float32) + + out = nv_qt.dequantize( + dtype=torch.float32, + scale=e4m3_scale_nv, + double_scale=double_scale, + block_sizes={-1: NV_BLOCK}, + ) + return out.float(), m, k_min, k_max + + +# ---------- Algorithm 3: hybrid (verbatim where exact, NVFP4-requant elsewhere) --- + + +def _algo3_recon_for_m( + deq_ref: torch.Tensor, + e8m0_scale: torch.Tensor, + m: int, +) -> torch.Tensor: + """Build Algo 3's fp32 reconstruction for a given m. + + For MXFP4 blocks where (k_j - m) ∈ [-9, 8]: use the exact MXFP4 dequant value + (zero error vs reference). For OOR blocks: dequant the block to fp32 (already + done — that's deq_ref), then NVFP4-quantize each 16-element half with the + fixed global scale 2^m and dequantize. The per-NVFP4-block amax can be + smaller than the full-MXFP4-block amax, so OOR blocks at the MXFP4 level + may still fit cleanly into E4M3 per-NVFP4-block scales. + """ + scale_2 = torch.tensor(float(2.0**m), device=deq_ref.device, dtype=torch.float32) + nv_qt, pb_scale, _ = NVFP4QTensor.quantize( + deq_ref.to(torch.bfloat16), + block_size=NV_BLOCK, + weights_scaling_factor_2=scale_2, + ) + nv_recon = nv_qt.dequantize( + dtype=torch.float32, + scale=pb_scale, + double_scale=scale_2, + block_sizes={-1: NV_BLOCK}, + ).view_as(deq_ref) + + k_flat = e8m0_scale.to(torch.int32).reshape(-1) - 127 + delta = k_flat - m + in_range = (delta >= E4M3_KMIN) & (delta <= E4M3_KMAX) # per MXFP4 block + + deq_blocks = deq_ref.reshape(-1, MX_BLOCK) + nv_blocks = nv_recon.reshape(-1, MX_BLOCK) + recon_blocks = torch.where(in_range.unsqueeze(-1), deq_blocks, nv_blocks) + return recon_blocks.view_as(deq_ref).float() + + +def algo3_hybrid_requant( + mxfp4_qt: MXFP4QTensor, + e8m0_scale: torch.Tensor, +) -> tuple[torch.Tensor, int, int, int]: + """Hybrid: verbatim for in-range blocks, NVFP4-requant for out-of-range blocks. + + For OOR MXFP4 blocks, dequantize and re-quantize each 16-element half with + the fixed global scale 2^m. m is chosen to minimize the actual post-hybrid + MSE: brute-force over the same integer range Algo 2 considers, but evaluating + the real reconstruction error rather than a closed form (because NVFP4-requant's + E4M3 mantissa quantization isn't a clean function of m alone). + """ + k = e8m0_scale.to(torch.int32) - 127 + original_shape = mxfp4_qt.metadata["shape"] + packed = mxfp4_qt._quantized_data.view(*original_shape[:-1], -1, MX_BLOCK // 2) + block_is_zero = ((packed & 0x77) == 0).all(dim=-1).reshape(-1) + k_flat = k.reshape(-1) + nonzero_mask = ~block_is_zero + nonzero_k = k_flat[nonzero_mask] if nonzero_mask.any() else k_flat + k_min = int(nonzero_k.min().item()) + k_max = int(nonzero_k.max().item()) + + deq_ref = reference_from_mxfp4(mxfp4_qt, e8m0_scale) + + # If everything fits in E4M3, midpoint m gives exact zero-error reconstruction. + if (k_max - k_min) <= (E4M3_KMAX - E4M3_KMIN): + m = (k_max - E4M3_KMAX + k_min - E4M3_KMIN + 1) // 2 + m = max(k_max - E4M3_KMAX, min(m, k_min - E4M3_KMIN)) + return deq_ref.float(), m, k_min, k_max + + # Otherwise, brute-force search over candidate m and pick best MSE. + candidates = list(range(k_min - E4M3_KMAX - 1, k_max - E4M3_KMIN + 2)) + best_m: int = candidates[0] + best_mse: float = float("inf") + best_recon: torch.Tensor = _algo3_recon_for_m(deq_ref, e8m0_scale, best_m) + for m_cand in candidates: + recon = _algo3_recon_for_m(deq_ref, e8m0_scale, m_cand) + mse_val = mse(deq_ref, recon) + if mse_val < best_mse: + best_mse = mse_val + best_m = m_cand + best_recon = recon + return best_recon, best_m, k_min, k_max + + +# ---------- Reference and metrics -------------------------------------------- + + +def reference_from_mxfp4(mxfp4_qt: MXFP4QTensor, e8m0_scale: torch.Tensor) -> torch.Tensor: + """The true value the MXFP4 representation encodes (in fp32).""" + return mxfp4_qt.dequantize( + dtype=torch.float32, scale=e8m0_scale, block_sizes={-1: MX_BLOCK} + ).float() + + +def mse(a: torch.Tensor, b: torch.Tensor) -> float: + return float(((a.float() - b.float()) ** 2).mean().item()) + + +def max_abs_err(a: torch.Tensor, b: torch.Tensor) -> float: + return float((a.float() - b.float()).abs().max().item()) + + +def snr_db(ref: torch.Tensor, approx: torch.Tensor) -> float: + """Signal-to-noise ratio in dB. +inf when MSE=0.""" + sig = (ref.float() ** 2).mean().item() + err = ((ref.float() - approx.float()) ** 2).mean().item() + if err <= 0: + return float("inf") + if sig <= 0: + return float("-inf") + return 10.0 * math.log10(sig / err) + + +# ---------- Test scenarios --------------------------------------------------- +# Each scenario returns a tensor with last dim divisible by 32 (MXFP4 block size). +# Most are 256×1024 (8192 MXFP4 blocks) for a quick run; some test other shapes. + +R, C = 256, 1024 # default rows × cols + + +def gen_uniform() -> torch.Tensor: + return torch.empty(R, C, device=DEVICE, dtype=torch.bfloat16).uniform_(-1, 1) + + +def gen_gaussian() -> torch.Tensor: + return (torch.randn(R, C, device=DEVICE) * 1.0).bfloat16() + + +def gen_heavy_tail() -> torch.Tensor: + # x = N(0,1) * |N(0,1)| → fatter tails than gaussian + return (torch.randn(R, C, device=DEVICE) * torch.randn(R, C, device=DEVICE).abs()).bfloat16() + + +def gen_rare_outliers() -> torch.Tensor: + x = torch.randn(R, C, device=DEVICE) * 0.05 + mask = torch.rand_like(x) < 1e-3 + x[mask] = 100.0 * torch.sign(torch.randn_like(x[mask]) + 1e-6) + return x.bfloat16() + + +def gen_mixed_block_scales_25() -> torch.Tensor: + """Each row chunk gets a different magnitude — forces wide block-exponent spread.""" + x = torch.randn(R, C, device=DEVICE) * 0.3 + n_chunks = 16 + chunk = max(R // n_chunks, 1) + for i in range(n_chunks): + start, end = i * chunk, R if i == n_chunks - 1 else (i + 1) * chunk + s = 2.0 ** (-12 + (i * 25 // (n_chunks - 1))) + x[start:end] *= s + return x.bfloat16() + + +def gen_narrow_range() -> torch.Tensor: + return (torch.randn(R, C, device=DEVICE) * 0.5 + 1.0).bfloat16() + + +def gen_llm_weight() -> torch.Tensor: + # Dense linear-layer init: small std, rare outliers + x = torch.randn(R, C, device=DEVICE) * (1.0 / math.sqrt(C)) + mask = torch.rand_like(x) < 1e-4 + x[mask] *= 50.0 + return x.bfloat16() + + +def gen_zero_block() -> torch.Tensor: + x = torch.zeros(R, C, device=DEVICE) + mask = torch.rand_like(x) < 0.01 + x[mask] = torch.randn_like(x[mask]) * 0.5 + return x.bfloat16() + + +# --- Wider/tighter spread tests around the 17-exponent boundary ------------- + + +def _per_row_geom_scale(rows: int, cols: int, log2_range: int) -> torch.Tensor: + """Each row chunk gets a power-of-2 magnitude spanning [-r/2, r/2].""" + x = torch.randn(rows, cols, device=DEVICE) * 0.5 + n_chunks = 16 + chunk = max(rows // n_chunks, 1) + half = log2_range // 2 + for i in range(n_chunks): + start, end = i * chunk, rows if i == n_chunks - 1 else (i + 1) * chunk + s = 2.0 ** (-half + (i * log2_range // (n_chunks - 1))) + x[start:end] *= s + return x.bfloat16() + + +def gen_spread_15() -> torch.Tensor: + """Block exponent spread ≈ 15 — fits in E4M3 window, midpoint should be exact.""" + return _per_row_geom_scale(R, C, log2_range=15) + + +def gen_spread_17() -> torch.Tensor: + """Block exponent spread = 17 — at the in-range boundary.""" + return _per_row_geom_scale(R, C, log2_range=17) + + +def gen_spread_18() -> torch.Tensor: + """Block exponent spread = 18 — just past the boundary; midpoint loses, search wins.""" + return _per_row_geom_scale(R, C, log2_range=18) + + +def gen_spread_50() -> torch.Tensor: + return _per_row_geom_scale(R, C, log2_range=50) + + +# --- Distribution variations ------------------------------------------------ + + +def gen_bimodal() -> torch.Tensor: + """Two gaussian clusters at very different magnitudes.""" + x = torch.randn(R, C, device=DEVICE) * 0.01 + mask = torch.rand_like(x) < 0.5 + x[mask] = torch.randn_like(x[mask]) * 8.0 + return x.bfloat16() + + +def gen_power_law() -> torch.Tensor: + """Pareto(1.5)-like distribution — long-tailed.""" + u = torch.rand(R, C, device=DEVICE).clamp(min=1e-6) + x = (u ** -(1.0 / 1.5) - 1.0) * torch.sign(torch.randn_like(u)) + return (x * 0.05).bfloat16() + + +def gen_per_row_outlier() -> torch.Tensor: + """LLM-activation-style: a few rows are dominated by outlier columns.""" + x = torch.randn(R, C, device=DEVICE) * 0.01 + n_outlier_rows = 4 + outlier_rows = torch.randperm(R)[:n_outlier_rows] + n_outlier_cols = max(C // 64, 1) + outlier_cols = torch.randperm(C)[:n_outlier_cols] + for r in outlier_rows: + x[r, outlier_cols] = 30.0 * torch.sign(torch.randn(n_outlier_cols, device=DEVICE) + 1e-6) + return x.bfloat16() + + +def gen_per_col_outlier() -> torch.Tensor: + """Whole columns are systematically larger — like a single outlier feature.""" + x = torch.randn(R, C, device=DEVICE) * 0.01 + outlier_cols = torch.randperm(C)[: max(C // 128, 1)] + x[:, outlier_cols] *= 200.0 + return x.bfloat16() + + +def gen_single_extreme() -> torch.Tensor: + """One absurdly large value in an otherwise small tensor.""" + x = torch.randn(R, C, device=DEVICE) * 0.005 + x[R // 2, C // 2] = 1e4 + return x.bfloat16() + + +def gen_subnormal_heavy() -> torch.Tensor: + """Many values smaller than E2M1's smallest representable nonzero (0.5*2^k_min).""" + return (torch.randn(R, C, device=DEVICE) * 1e-8).bfloat16() + + +def gen_saturating() -> torch.Tensor: + """Values pushed to E2M1's max boundary — stresses cast_fp4 rounding.""" + x = torch.randn(R, C, device=DEVICE) + x = torch.sign(x) * torch.min(x.abs(), torch.tensor(6.0, device=DEVICE)) + return x.bfloat16() + + +def gen_mixed_signs_zero_mean() -> torch.Tensor: + """Strongly bimodal sign distribution, near-zero mean.""" + x = torch.where( + torch.rand(R, C, device=DEVICE) < 0.5, + torch.full((R, C), 3.0, device=DEVICE), + torch.full((R, C), -3.0, device=DEVICE), + ) + x += torch.randn(R, C, device=DEVICE) * 0.1 + return x.bfloat16() + + +def gen_constant() -> torch.Tensor: + """All identical values — degenerate; one block exponent, two distinct nibble values.""" + return torch.full((R, C), 1.5, device=DEVICE, dtype=torch.bfloat16) + + +# --- Layer-shaped LLM-like patterns ---------------------------------------- + + +def gen_qkv_weight() -> torch.Tensor: + """Attention QKV weight: tall, gaussian init w/ mild outliers.""" + rows, cols = 4096, 4096 + x = torch.randn(rows, cols, device=DEVICE) * (1.0 / math.sqrt(cols)) + mask = torch.rand_like(x) < 5e-5 + x[mask] *= 30.0 + return x.bfloat16() + + +def gen_mlp_gate_up() -> torch.Tensor: + """MLP gate/up projection: wide & has activation-driven scale variation.""" + rows, cols = 1024, 4096 + x = torch.randn(rows, cols, device=DEVICE) * (1.0 / math.sqrt(cols)) + # A few channels have larger weights (often seen post-fine-tuning) + hot = torch.randperm(rows)[: rows // 32] + x[hot] *= 5.0 + return x.bfloat16() + + +def gen_embedding() -> torch.Tensor: + """Embedding-style: vocab × hidden, ~N(0, 1) range with row-sparse outliers.""" + rows, cols = 2048, 1024 + x = torch.randn(rows, cols, device=DEVICE) * 0.5 + rare = torch.randperm(rows)[: rows // 256] + x[rare] *= 20.0 + return x.bfloat16() + + +def gen_layernorm_gain() -> torch.Tensor: + """LayerNorm gain vector (1D-ish, padded to 2D with cols=64 for blockability).""" + rows = 32 + x = torch.ones(rows, 1024, device=DEVICE) + torch.randn(rows, 1024, device=DEVICE) * 0.05 + return x.bfloat16() + + +# --- Other shapes ---------------------------------------------------------- + + +def gen_4d_conv() -> torch.Tensor: + """4D conv-like weight: (oc, ic, kh, kw). Last 3 dims flattened block-wise.""" + return (torch.randn(64, 64, 4, 4, device=DEVICE) * 0.1).bfloat16().reshape(64, -1) + + +def gen_large_flat() -> torch.Tensor: + """Bigger tensor to confirm scaling: 1k × 4k.""" + return (torch.randn(1024, 4096, device=DEVICE) * 0.02).bfloat16() + + +SCENARIOS = [ + # Original 8 (kept for continuity with earlier results) + ("uniform [-1,1]", gen_uniform), + ("gaussian std=1", gen_gaussian), + ("heavy-tail", gen_heavy_tail), + ("rare outliers (1e-3, mag=100)", gen_rare_outliers), + ("mixed block scales (spread 25)", gen_mixed_block_scales_25), + ("narrow range (~1.0)", gen_narrow_range), + ("typical LLM weight", gen_llm_weight), + ("mostly zeros, 1% nonzero", gen_zero_block), + # Boundary tests around the 17-exponent E4M3 window + ("spread 15 (in-range)", gen_spread_15), + ("spread 17 (boundary)", gen_spread_17), + ("spread 18 (just over)", gen_spread_18), + ("spread 50 (extreme)", gen_spread_50), + # Distribution variations + ("bimodal magnitudes", gen_bimodal), + ("Pareto(1.5) power-law", gen_power_law), + ("per-row outliers", gen_per_row_outlier), + ("per-col outliers", gen_per_col_outlier), + ("single extreme outlier", gen_single_extreme), + ("subnormal-heavy (1e-8)", gen_subnormal_heavy), + ("saturating at E2M1_max", gen_saturating), + ("strong bimodal signs", gen_mixed_signs_zero_mean), + ("constant (degenerate)", gen_constant), + # Layer-shaped LLM patterns + ("QKV weight (4096x4096)", gen_qkv_weight), + ("MLP gate/up (1024x4096)", gen_mlp_gate_up), + ("embedding (2048x1024)", gen_embedding), + ("LayerNorm gain", gen_layernorm_gain), + # Other shapes + ("conv weight 4D (64x64x4x4)", gen_4d_conv), + ("large flat (1024x4096)", gen_large_flat), +] + + +# ---------- Driver ----------------------------------------------------------- + + +def run_one(name: str, x: torch.Tensor) -> dict: + """Quantize x to MXFP4, run all algos, return metrics.""" + # Pad/skip if last dim isn't divisible by MX_BLOCK + if x.shape[-1] % MX_BLOCK != 0: + raise ValueError(f"{name}: last dim {x.shape[-1]} not divisible by {MX_BLOCK}") + + # MXFP4/NVFP4 quantizers expect a 2D-ish view (block on last dim). Keep original shape; + # the implementation views (-1, block_size) internally. + mx_qt, e8m0 = MXFP4QTensor.quantize(x.clone(), block_size=MX_BLOCK) + ref = reference_from_mxfp4(mx_qt, e8m0) + + out1 = algo1_dequant_requant(mx_qt, e8m0) + mse1 = mse(ref, out1) + + out2_mid, m_mid, k_min, k_max = algo2_keep_nibbles(mx_qt, e8m0, m_strategy="midpoint") + mse2_mid = mse(ref, out2_mid) + + out2_best, m_best, _, _ = algo2_keep_nibbles(mx_qt, e8m0, m_strategy="search") + mse2_best = mse(ref, out2_best) + + out3, m3, _, _ = algo3_hybrid_requant(mx_qt, e8m0) + mse3 = mse(ref, out3) + + k_int = (e8m0.to(torch.int32) - 127 - m_best).flatten() + n_oor_algo2 = int(((k_int < E4M3_KMIN) | (k_int > E4M3_KMAX)).sum().item()) + k_int3 = (e8m0.to(torch.int32) - 127 - m3).flatten() + n_oor_algo3 = int(((k_int3 < E4M3_KMIN) | (k_int3 > E4M3_KMAX)).sum().item()) + + return { + "name": name, + "shape": tuple(x.shape), + "k_range": (k_min, k_max), + "spread": k_max - k_min, + "m_mid": m_mid, + "m_best": m_best, + "m_algo3": m3, + "n_blocks": int(e8m0.numel()), + "n_oor": n_oor_algo2, + "n_oor_algo3": n_oor_algo3, + "mse_algo1": mse1, + "mse_algo2_mid": mse2_mid, + "mse_algo2_best": mse2_best, + "mse_algo3": mse3, + "snr1": snr_db(ref, out1), + "snr2_best": snr_db(ref, out2_best), + "snr3": snr_db(ref, out3), + "max_err1": max_abs_err(ref, out1), + "max_err2_best": max_abs_err(ref, out2_best), + "max_err3": max_abs_err(ref, out3), + } + + +def _fmt_snr(v: float) -> str: + if v == float("inf"): + return " inf" + if v == float("-inf"): + return " -inf" + return f"{v:6.1f}" + + +def main(): + print(f"device: {DEVICE}") + if DEVICE.type == "cuda": + print(f"gpu: {torch.cuda.get_device_name(0)}") + print() + + torch.manual_seed(0) + + rows_hdr = ( + f"{'scenario':<34}" + f"{'spread':>7}" + f"{'m2/m3':>8}" + f"{'oor':>9}" + f"{'algo1 MSE':>11}" + f"{'algo2_best':>11}" + f"{'algo3':>11}" + f"{'SNR1':>7}" + f"{'SNR2':>7}" + f"{'SNR3':>7}" + ) + print(rows_hdr) + print("-" * len(rows_hdr)) + + win = {"algo1": 0, "algo2": 0, "algo3": 0, "tie": 0} + n_algo3_exact = 0 + n_total = 0 + all_results = [] + for name, gen in SCENARIOS: + x = gen() + r = run_one(name, x) + all_results.append(r) + n_total += 1 + + mses = {"algo1": r["mse_algo1"], "algo2": r["mse_algo2_best"], "algo3": r["mse_algo3"]} + best_v = min(mses.values()) + winners = [k for k, v in mses.items() if v == best_v] + if len(winners) > 1: + win["tie"] += 1 + else: + win[winners[0]] += 1 + if r["mse_algo3"] == 0.0: + n_algo3_exact += 1 + + oor_str = f"{r['n_oor']:>4}/{r['n_blocks']:<4}" + m_str = f"{r['m_best']:>3}/{r['m_algo3']:<3}" + print( + f"{r['name']:<34}" + f"{r['spread']:>7}" + f"{m_str:>8}" + f"{oor_str:>9}" + f"{r['mse_algo1']:>11.2e}" + f"{r['mse_algo2_best']:>11.2e}" + f"{r['mse_algo3']:>11.2e}" + f"{_fmt_snr(r['snr1']):>7}" + f"{_fmt_snr(r['snr2_best']):>7}" + f"{_fmt_snr(r['snr3']):>7}" + ) + + print() + print(f"Summary across {n_total} scenarios:") + print(f" Algo 3 wins outright: {win['algo3']}") + print(f" Algo 2 wins outright: {win['algo2']}") + print(f" Algo 1 wins outright: {win['algo1']}") + print(f" Tied (≥ 2 algos at same MSE): {win['tie']}") + print(f" Algo 3 is exact (MSE=0): {n_algo3_exact}/{n_total}") + print() + + # Losses: scenarios where algo3 is strictly worse than algo1 or algo2 (full precision) + losses_vs_1 = [r for r in all_results if r["mse_algo3"] > r["mse_algo1"]] + losses_vs_2 = [r for r in all_results if r["mse_algo3"] > r["mse_algo2_best"]] + + print("Cases where Algo 3 loses to Algo 1 (mse_algo3 > mse_algo1):") + if not losses_vs_1: + print(" (none — Algo 3 ≤ Algo 1 in every scenario)") + else: + print(f" {'scenario':<34}{'algo1 MSE':>16}{'algo3 MSE':>16}{'ratio (3/1)':>14}") + for r in losses_vs_1: + ratio = r["mse_algo3"] / max(r["mse_algo1"], 1e-300) + print(f" {r['name']:<34}{r['mse_algo1']:>16.6e}{r['mse_algo3']:>16.6e}{ratio:>14.4f}") + print() + print("Cases where Algo 3 loses to Algo 2 (mse_algo3 > mse_algo2_best):") + if not losses_vs_2: + print(" (none — Algo 3 ≤ Algo 2 in every scenario)") + else: + print(f" {'scenario':<34}{'algo2 MSE':>16}{'algo3 MSE':>16}{'ratio (3/2)':>14}") + for r in losses_vs_2: + ratio = r["mse_algo3"] / max(r["mse_algo2_best"], 1e-300) + print( + f" {r['name']:<34}" + f"{r['mse_algo2_best']:>16.6e}" + f"{r['mse_algo3']:>16.6e}" + f"{ratio:>14.4f}" + ) + print() + print("Notes:") + print(" - Reference: MXFP4 dequantized tensor.") + print(" - m2/m3: global-scale exponent picked by algo2 / algo3 (may differ).") + print(" - oor: MXFP4 blocks whose (k - m_best) is outside [-9, 8] for algo2.") + print(" - algo3: verbatim where in-range, NVFP4-requant (with fixed scale_2=2^m3)") + print(" where out-of-range, with m3 chosen by direct-MSE 1D sweep.") + + +if __name__ == "__main__": + main() diff --git a/scratch/mxfp4_to_nvfp4_report.md b/scratch/mxfp4_to_nvfp4_report.md new file mode 100644 index 0000000000..986dfd14af --- /dev/null +++ b/scratch/mxfp4_to_nvfp4_report.md @@ -0,0 +1,162 @@ +# MXFP4 → NVFP4 Conversion: MSE Analysis of Three Algorithms + +## Problem + +Convert an MXFP4-quantized tensor (block size 32, E2M1 mantissa, E8M0 power-of-two +scale) to an NVFP4 tensor (block size 16, E2M1 mantissa, E4M3 per-block scale, FP32 +global scale `scale_2`). Reference for error measurement is the MXFP4-dequantized +tensor — i.e. the values MXFP4 faithfully encodes — since both algorithms aim to +preserve those values in the NVFP4 representation. + +Notation: each MXFP4 block `j` has integer exponent `k_j` (so its scale is `2^k_j`). +E4M3 represents `2^k` exactly only for `k ∈ [−9, 8]` (an 18-value window with +spread 17). + +## Algorithms + +### Algo 1: Dequantize → Requantize (baseline) + +```text +MXFP4 → BF16 (dequantize) → NVFP4 (standard quantize) +``` + +The "obvious" approach. Always introduces error from: +- Per-16-element re-bucketing (NVFP4 picks new amax per 16-block) +- E4M3 mantissa quantization of per-block scales + +### Algo 2: Verbatim Nibbles + Power-of-Two Global Scale + +Keep the E2M1 nibbles unchanged. Each MXFP4 block of 32 splits into two NVFP4 +blocks of 16; both inherit the same exponent `k_j`. Pick a global scale +`S = 2^m` (integer `m`) and store the per-block E4M3 scale as `2^(k_j − m)`. + +- **In-range blocks** (`k_j − m ∈ [−9, 8]`): contribution **MSE = 0** — both + `2^(k_j − m)` (E4M3) and `2^m` (FP32) are exactly representable, so the product + reproduces `2^k_j` exactly. +- **Out-of-range blocks** (spread > 17): snap the per-block exponent to the + E4M3 boundary `clamp(k_j − m, −9, 8)`. Provably MSE-optimal *given verbatim + nibbles*, but the snap can be huge if a block's true scale is far from the + snapped value. + +**Choice of `m`** (two strategies tested): +- `midpoint`: when spread ≤ 17, midpoint `m` makes everything in-range. + When spread > 17, fall back to `median(k) − center`. +- `search`: 1D integer sweep with closed-form objective + `Σ_j S_j · (2^k_j − 2^(m + clamp(k_j − m, −9, 8)))^2` where + `S_j = Σ_i e2m1_value_i^2` for block `j`. Cheap (≤ 50 candidates per tensor). + +### Algo 3: Hybrid (verbatim where exact, NVFP4-requant where lossy) + +Combines Algo 2's exact path with per-block requantization for OOR blocks: + +1. Search `m` (integer) by minimizing the actual hybrid reconstruction MSE. +2. For in-range MXFP4 blocks: keep verbatim path (zero error). +3. For OOR MXFP4 blocks: dequantize to FP32, then NVFP4-quantize each + 16-element half with the **fixed** `scale_2 = 2^m`. The per-NVFP4-block amax + can be smaller than the per-MXFP4-block amax — one half might lack the + max-magnitude nibble — letting OOR-at-MXFP4-level blocks fit cleanly into + per-NVFP4-block E4M3 scales. +4. Final reconstruction is masked-merged from the two paths. + +The `m` search is brute-force over the same integer range Algo 2 considers, but +evaluated against the actual hybrid MSE because NVFP4-requant's E4M3 mantissa +rounding isn't a clean closed form. + +## Experimental Setup + +- 27 scenarios spanning standard distributions (uniform, gaussian, heavy-tail), + outlier patterns (rare, per-row, per-col, single-extreme), block-spread + boundary tests (15, 17, 18, 50), bimodal/power-law/saturating/subnormal/ + constant cases, and layer-shaped LLM weights (QKV 4096², MLP 1024×4096, + embedding, LayerNorm gain, conv 4D). +- Reference: MXFP4 dequantized tensor (FP32). +- Metrics: MSE, max abs error, SNR (dB). +- Hardware: NVIDIA RTX 6000 Ada Generation. + +## Results + +### Aggregate + +| Outcome (across 27 scenarios) | Count | +|---|---| +| Algo 3 outright winner | 4 | +| Algo 2 outright winner | 0 | +| Algo 1 outright winner | 1 | +| Tied (≥ 2 algos at same MSE) | 22 | +| Algo 3 exact (MSE = 0) | 22/27 | + +### Algo 3 dominant wins (over Algo 2) + +| Scenario | Spread | Algo 1 | Algo 2 (best) | **Algo 3** | SNR Δ (3 vs 2) | +|---|---|---|---|---|---| +| mixed block scales (~2²⁵) | 26 | 1.45e+03 | 2.96e-04 | **1.58e-05** | +12.7 dB | +| spread 17 (boundary) | 19 | 1.91e+01 | 2.27e-07 | **2.17e-07** | +0.2 dB | +| spread 18 (just over) | 20 | 1.87e+01 | 7.21e-07 | **2.85e-07** | +4.0 dB | +| spread 50 (extreme) | 52 | 7.54e+10 | 3.85e+04 | **9.45e+02** | +16.1 dB | + +### Cases where Algo 3 loses to Algo 1 + +| Scenario | Algo 1 MSE | Algo 3 MSE | Ratio (3/1) | +|---|---|---|---| +| single extreme outlier | 2.466046e-05 | 2.471241e-05 | **1.0021** | + +Gap is 0.21% — both algorithms underflow the small-magnitude blocks identically; +the residual is the integer-vs-continuous quantization of the global scale. +Algo 1 picks `scale_2 = global_amax / (6·448) ≈ 3.72`; Algo 3 is constrained to +`2^m` integer powers (here `m=3` → `scale_2 = 8`). + +### Cases where Algo 3 loses to Algo 2 + +> None. Algo 3 ≤ Algo 2 in every scenario tested. + +### Selected exact (MSE = 0) scenarios for Algo 3 + +uniform, gaussian, heavy-tail, rare outliers, narrow range, typical LLM weight, +mostly zeros, spread 15 (in-range), bimodal, Pareto power-law, per-row outliers, +per-col outliers, subnormal-heavy (1e-8), saturating at E2M1 max, strong bimodal +signs, constant, QKV weight (4096²), MLP gate/up, embedding, LayerNorm gain, +conv weight (4D), large flat (1024×4096). + +## Why Algo 3 Works + +The asymmetry in error scaling explains everything: +- **Snap-up errors** scale as `(2^k_j − 2^(m+8))²`, dominated by the *true* + magnitude `2^k_j` — can be enormous. +- **Snap-down errors** scale as `(2^k_j − 2^(m−9))²`, bounded by the *snapped* + magnitude `2^(m−9)`. + +Algo 2's `m`-search already exploits this asymmetry by preferring low `m` values +that keep high-magnitude blocks in-range. But verbatim-snap on OOR blocks still +introduces a fixed-magnitude error per block, with no use of the within-block +structure. + +Algo 3 replaces that snap with a real per-NVFP4-block requantization, which: +- Adapts to the actual half-block amax (much smaller than the MXFP4 block amax + in many cases). +- Lets the E4M3 mantissa carry information beyond pure powers of 2 — for an OOR + block where `k_j − m = 9` but max nibble is `4`, the required scale is + `(4/6)·2^9 ≈ 341`, which fits in E4M3 (max 448). +- Costs nothing for in-range blocks because they keep the verbatim path. + +## Recommendations + +1. **Default to Algo 3** for MXFP4 → NVFP4 conversion. It is exact in the + typical-weight case, strictly better than Algo 2 on spread-too-large cases, + and within 0.2% of Algo 1 even on the pathological single-outlier case. +2. **The bound case** (single block dominates the entire tensor's dynamic range) + can be closed by allowing a continuous (non-power-of-2) global scale. The + integer-`m` form is purely for clean E4M3 representation of in-range + per-block scales; on OOR blocks Algo 3 already routes through E4M3 mantissa + rounding, so dropping the integer constraint there costs nothing in + exactness and recovers the last 0.2%. +3. **Detection of the pathological case** is cheap: when the spread is very + large *and* one block's `S_j` dominates the rest, Algo 1 (or Algo 3 with a + continuous global scale) is preferable to Algo 3-with-integer-`m`. +4. **Cost**: Algo 3 runs a 1D integer sweep over typically 20–50 candidates, each + evaluating one NVFP4 quantize+dequantize. For typical PTQ workflows this runs + once per tensor and is negligible. + +## Reproducibility + +All results in this report were produced by `scratch/mxfp4_to_nvfp4_mse.py` +with `torch.manual_seed(0)` on the GPU described above. From b2eb09733f451d333ede264163f4b4d9b3d7e41f Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Tue, 28 Apr 2026 12:50:30 -0700 Subject: [PATCH 2/7] Collapse Algo 3 to closed-form m = k_max - 8 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The m-search loop in the original Algo 3 turns out to be unnecessary. Across all 27 test scenarios the search converges on m = k_max - 8 and that closed-form rule is provably the right pick: - For spread <= 17, every block's k_j - m lands in [8 - spread, 8], a subset of E4M3's exact-power-of-2 window [-9, 8]. All blocks take the verbatim path; the conversion is lossless (MSE = 0). - For spread > 17, m = k_max - 8 is the only choice that does not NaN the highest-magnitude blocks: a lower m drives the per-block scale amax/(6*2^m) above E4M3's max (448); a higher m only shrinks in-range coverage on the low side without helping the high side. Replaces the brute-force algo3_hybrid_requant with a single-pass algo3_hybrid using the closed-form m. The Algo 4 / Algo 5 variants that were used to discover this rule are removed; the script is back to three algorithms (Algo 1 / Algo 2 / Algo 3) and the report has been rewritten accordingly. Same MSE numbers as before. No library changes — strictly under scratch/. Signed-off-by: Chenjie Luo --- scratch/mxfp4_to_nvfp4_mse.py | 165 +++++++++++++++---------------- scratch/mxfp4_to_nvfp4_report.md | 93 +++++++++-------- 2 files changed, 132 insertions(+), 126 deletions(-) diff --git a/scratch/mxfp4_to_nvfp4_mse.py b/scratch/mxfp4_to_nvfp4_mse.py index 62466714ab..1b35054b1f 100644 --- a/scratch/mxfp4_to_nvfp4_mse.py +++ b/scratch/mxfp4_to_nvfp4_mse.py @@ -18,22 +18,27 @@ """MXFP4 -> NVFP4 conversion MSE experiment. -Compares two algorithms for converting an MXFP4 tensor (block_size=32, E2M1 + E8M0 -power-of-2 scales) to NVFP4 (block_size=16, E2M1 + E4M3 scales + global FP32 scale): +Compares three algorithms for converting an MXFP4 tensor (block_size=32, E2M1 + +E8M0 power-of-2 scales) to NVFP4 (block_size=16, E2M1 + E4M3 scales + global FP32 +scale): Algo 1 (dequant-requant): dequantize MXFP4 to BF16, then quantize to NVFP4 the - standard way. This re-buckets nibbles and computes new scales from scratch. + standard way. Re-buckets nibbles and computes new scales from scratch. - Algo 2 (verbatim nibbles): keep the E2M1 nibbles unchanged. Each MXFP4 block of 32 + Algo 2 (verbatim nibbles): keep E2M1 nibbles unchanged. Each MXFP4 block of 32 splits into two NVFP4 blocks of 16, both inheriting the same exponent k_j. Pick a global scale S = 2^m (integer m) and store the per-block E4M3 scale as - 2^(k_j - m). E4M3 exactly represents 2^k for k in [-9, 8], so as long as - max(k) - min(k) <= 17 there is a valid m and the conversion is exact (zero - MSE). For blocks outside that window, snap the per-block exponent to the - [-9, 8] boundary; nibbles stay verbatim, and that snap is provably MSE-optimal - given the constraint. + 2^(k_j - m). E4M3 exactly represents 2^k for k in [-9, 8], so when + max(k) - min(k) <= 17 there is a valid m and the conversion is exact. For + blocks outside that window the per-block exponent snaps to [-9, 8]. -Reference for both algos: the MXFP4-dequantized tensor (i.e. what the source + Algo 3 (hybrid): apply Algo 2's verbatim path for in-range blocks (zero error) + and NVFP4-requant each 16-element half with fixed scale_2 = 2^m for OOR + blocks. The closed-form rule m = k_max - 8 is provably optimal: it keeps the + highest-magnitude blocks just inside E4M3's window (the side where snap + errors are catastrophic), and any in-range block is unaffected by m. + +Reference for all algos: the MXFP4-dequantized tensor (what the source representation faithfully encodes). MSE is computed against that reference in fp32. """ @@ -243,17 +248,21 @@ def _algo3_recon_for_m( return recon_blocks.view_as(deq_ref).float() -def algo3_hybrid_requant( +def algo3_hybrid( mxfp4_qt: MXFP4QTensor, e8m0_scale: torch.Tensor, ) -> tuple[torch.Tensor, int, int, int]: """Hybrid: verbatim for in-range blocks, NVFP4-requant for out-of-range blocks. - For OOR MXFP4 blocks, dequantize and re-quantize each 16-element half with - the fixed global scale 2^m. m is chosen to minimize the actual post-hybrid - MSE: brute-force over the same integer range Algo 2 considers, but evaluating - the real reconstruction error rather than a closed form (because NVFP4-requant's - E4M3 mantissa quantization isn't a clean function of m alone). + Closed-form rule: ``m = k_max - 8`` (top-aligned). For every block, + ``k_j - m = k_j - k_max + 8`` lands in ``[8 - spread, 8]``. When spread <=17 + that's a subset of E4M3's exact-power-of-2 window [-9, 8], so all blocks take + the verbatim path and the conversion is lossless (MSE = 0). When spread >17 + the highest-magnitude blocks sit just inside the top of the window — any + lower m would NaN them in E4M3 (per-block scale ``amax / (6·2^m)`` exceeds + 448); any higher m only shrinks in-range coverage on the low side without + helping the high side. This rule was confirmed by exhaustive search to match + the post-hoc MSE-optimal m on every scenario tested. """ k = e8m0_scale.to(torch.int32) - 127 original_shape = mxfp4_qt.metadata["shape"] @@ -265,27 +274,12 @@ def algo3_hybrid_requant( k_min = int(nonzero_k.min().item()) k_max = int(nonzero_k.max().item()) + m = k_max - E4M3_KMAX deq_ref = reference_from_mxfp4(mxfp4_qt, e8m0_scale) - - # If everything fits in E4M3, midpoint m gives exact zero-error reconstruction. if (k_max - k_min) <= (E4M3_KMAX - E4M3_KMIN): - m = (k_max - E4M3_KMAX + k_min - E4M3_KMIN + 1) // 2 - m = max(k_max - E4M3_KMAX, min(m, k_min - E4M3_KMIN)) return deq_ref.float(), m, k_min, k_max - - # Otherwise, brute-force search over candidate m and pick best MSE. - candidates = list(range(k_min - E4M3_KMAX - 1, k_max - E4M3_KMIN + 2)) - best_m: int = candidates[0] - best_mse: float = float("inf") - best_recon: torch.Tensor = _algo3_recon_for_m(deq_ref, e8m0_scale, best_m) - for m_cand in candidates: - recon = _algo3_recon_for_m(deq_ref, e8m0_scale, m_cand) - mse_val = mse(deq_ref, recon) - if mse_val < best_mse: - best_mse = mse_val - best_m = m_cand - best_recon = recon - return best_recon, best_m, k_min, k_max + recon = _algo3_recon_for_m(deq_ref, e8m0_scale, m) + return recon, m, k_min, k_max # ---------- Reference and metrics -------------------------------------------- @@ -593,7 +587,7 @@ def run_one(name: str, x: torch.Tensor) -> dict: out2_best, m_best, _, _ = algo2_keep_nibbles(mx_qt, e8m0, m_strategy="search") mse2_best = mse(ref, out2_best) - out3, m3, _, _ = algo3_hybrid_requant(mx_qt, e8m0) + out3, m3, _, _ = algo3_hybrid(mx_qt, e8m0) mse3 = mse(ref, out3) k_int = (e8m0.to(torch.int32) - 127 - m_best).flatten() @@ -644,20 +638,19 @@ def main(): rows_hdr = ( f"{'scenario':<34}" f"{'spread':>7}" - f"{'m2/m3':>8}" - f"{'oor':>9}" - f"{'algo1 MSE':>11}" - f"{'algo2_best':>11}" - f"{'algo3':>11}" + f"{'m3':>5}" + f"{'oor':>10}" + f"{'algo1 MSE':>12}" + f"{'algo2_best':>12}" + f"{'algo3':>12}" f"{'SNR1':>7}" - f"{'SNR2':>7}" f"{'SNR3':>7}" ) print(rows_hdr) print("-" * len(rows_hdr)) - win = {"algo1": 0, "algo2": 0, "algo3": 0, "tie": 0} n_algo3_exact = 0 + win = {"algo1": 0, "algo2": 0, "algo3": 0, "tie": 0} n_total = 0 all_results = [] for name, gen in SCENARIOS: @@ -665,6 +658,8 @@ def main(): r = run_one(name, x) all_results.append(r) n_total += 1 + if r["mse_algo3"] == 0.0: + n_algo3_exact += 1 mses = {"algo1": r["mse_algo1"], "algo2": r["mse_algo2_best"], "algo3": r["mse_algo3"]} best_v = min(mses.values()) @@ -673,66 +668,64 @@ def main(): win["tie"] += 1 else: win[winners[0]] += 1 - if r["mse_algo3"] == 0.0: - n_algo3_exact += 1 - oor_str = f"{r['n_oor']:>4}/{r['n_blocks']:<4}" - m_str = f"{r['m_best']:>3}/{r['m_algo3']:<3}" + oor_str = f"{r['n_oor_algo3']:>4}/{r['n_blocks']:<5}" print( f"{r['name']:<34}" f"{r['spread']:>7}" - f"{m_str:>8}" - f"{oor_str:>9}" - f"{r['mse_algo1']:>11.2e}" - f"{r['mse_algo2_best']:>11.2e}" - f"{r['mse_algo3']:>11.2e}" + f"{r['m_algo3']:>5}" + f"{oor_str:>10}" + f"{r['mse_algo1']:>12.2e}" + f"{r['mse_algo2_best']:>12.2e}" + f"{r['mse_algo3']:>12.2e}" f"{_fmt_snr(r['snr1']):>7}" - f"{_fmt_snr(r['snr2_best']):>7}" f"{_fmt_snr(r['snr3']):>7}" ) print() print(f"Summary across {n_total} scenarios:") - print(f" Algo 3 wins outright: {win['algo3']}") - print(f" Algo 2 wins outright: {win['algo2']}") - print(f" Algo 1 wins outright: {win['algo1']}") - print(f" Tied (≥ 2 algos at same MSE): {win['tie']}") - print(f" Algo 3 is exact (MSE=0): {n_algo3_exact}/{n_total}") + print(f" Algo 3 outright winner: {win['algo3']}") + print(f" Algo 2 outright winner: {win['algo2']}") + print(f" Algo 1 outright winner: {win['algo1']}") + print(f" Tied (>=2 algos at same MSE): {win['tie']}") + print(f" Algo 3 exact (MSE = 0): {n_algo3_exact}/{n_total}") print() - # Losses: scenarios where algo3 is strictly worse than algo1 or algo2 (full precision) - losses_vs_1 = [r for r in all_results if r["mse_algo3"] > r["mse_algo1"]] - losses_vs_2 = [r for r in all_results if r["mse_algo3"] > r["mse_algo2_best"]] - - print("Cases where Algo 3 loses to Algo 1 (mse_algo3 > mse_algo1):") - if not losses_vs_1: - print(" (none — Algo 3 ≤ Algo 1 in every scenario)") - else: - print(f" {'scenario':<34}{'algo1 MSE':>16}{'algo3 MSE':>16}{'ratio (3/1)':>14}") - for r in losses_vs_1: - ratio = r["mse_algo3"] / max(r["mse_algo1"], 1e-300) - print(f" {r['name']:<34}{r['mse_algo1']:>16.6e}{r['mse_algo3']:>16.6e}{ratio:>14.4f}") + # Loss tables — full precision so 0.x% gaps are visible. + def _print_losses(title: str, rows: list, lhs_key: str, rhs_key: str, ratio_label: str) -> None: + print(title) + if not rows: + print(" (none)") + return + print(f" {'scenario':<34}{lhs_key:>16}{rhs_key:>16}{ratio_label:>14}") + for r in rows: + ratio = r[lhs_key] / max(r[rhs_key], 1e-300) + print(f" {r['name']:<34}{r[lhs_key]:>16.6e}{r[rhs_key]:>16.6e}{ratio:>14.4f}") + + losses_3_vs_1 = [r for r in all_results if r["mse_algo3"] > r["mse_algo1"]] + losses_3_vs_2 = [r for r in all_results if r["mse_algo3"] > r["mse_algo2_best"]] + + _print_losses( + "Cases where Algo 3 loses to Algo 1:", + losses_3_vs_1, + "mse_algo3", + "mse_algo1", + "ratio 3/1", + ) print() - print("Cases where Algo 3 loses to Algo 2 (mse_algo3 > mse_algo2_best):") - if not losses_vs_2: - print(" (none — Algo 3 ≤ Algo 2 in every scenario)") - else: - print(f" {'scenario':<34}{'algo2 MSE':>16}{'algo3 MSE':>16}{'ratio (3/2)':>14}") - for r in losses_vs_2: - ratio = r["mse_algo3"] / max(r["mse_algo2_best"], 1e-300) - print( - f" {r['name']:<34}" - f"{r['mse_algo2_best']:>16.6e}" - f"{r['mse_algo3']:>16.6e}" - f"{ratio:>14.4f}" - ) + _print_losses( + "Cases where Algo 3 loses to Algo 2:", + losses_3_vs_2, + "mse_algo3", + "mse_algo2_best", + "ratio 3/2", + ) print() print("Notes:") print(" - Reference: MXFP4 dequantized tensor.") - print(" - m2/m3: global-scale exponent picked by algo2 / algo3 (may differ).") - print(" - oor: MXFP4 blocks whose (k - m_best) is outside [-9, 8] for algo2.") - print(" - algo3: verbatim where in-range, NVFP4-requant (with fixed scale_2=2^m3)") - print(" where out-of-range, with m3 chosen by direct-MSE 1D sweep.") + print(" - m3: global-scale exponent picked by Algo 3 (closed-form: m = k_max - 8).") + print(" - oor: blocks where (k_j - m3) is outside [-9, 8]; these go through") + print(" NVFP4 requant. In-range blocks use the verbatim path (zero error).") if __name__ == "__main__": diff --git a/scratch/mxfp4_to_nvfp4_report.md b/scratch/mxfp4_to_nvfp4_report.md index 986dfd14af..9cdffe93a6 100644 --- a/scratch/mxfp4_to_nvfp4_report.md +++ b/scratch/mxfp4_to_nvfp4_report.md @@ -5,12 +5,13 @@ Convert an MXFP4-quantized tensor (block size 32, E2M1 mantissa, E8M0 power-of-two scale) to an NVFP4 tensor (block size 16, E2M1 mantissa, E4M3 per-block scale, FP32 global scale `scale_2`). Reference for error measurement is the MXFP4-dequantized -tensor — i.e. the values MXFP4 faithfully encodes — since both algorithms aim to +tensor — i.e. the values MXFP4 faithfully encodes — since each algorithm aims to preserve those values in the NVFP4 representation. Notation: each MXFP4 block `j` has integer exponent `k_j` (so its scale is `2^k_j`). E4M3 represents `2^k` exactly only for `k ∈ [−9, 8]` (an 18-value window with -spread 17). +spread 17). `k_min`, `k_max`, and *spread* `= k_max − k_min` are taken over all +non-zero blocks in the tensor. ## Algorithms @@ -21,6 +22,7 @@ MXFP4 → BF16 (dequantize) → NVFP4 (standard quantize) ``` The "obvious" approach. Always introduces error from: + - Per-16-element re-bucketing (NVFP4 picks new amax per 16-block) - E4M3 mantissa quantization of per-block scales @@ -33,34 +35,47 @@ blocks of 16; both inherit the same exponent `k_j`. Pick a global scale - **In-range blocks** (`k_j − m ∈ [−9, 8]`): contribution **MSE = 0** — both `2^(k_j − m)` (E4M3) and `2^m` (FP32) are exactly representable, so the product reproduces `2^k_j` exactly. -- **Out-of-range blocks** (spread > 17): snap the per-block exponent to the - E4M3 boundary `clamp(k_j − m, −9, 8)`. Provably MSE-optimal *given verbatim - nibbles*, but the snap can be huge if a block's true scale is far from the - snapped value. +- **Out-of-range blocks** (only possible when spread > 17): snap the per-block + exponent to the E4M3 boundary `clamp(k_j − m, −9, 8)`. Provably MSE-optimal + *given verbatim nibbles*, but the snap can be huge if a block's true scale is + far from the snapped value. + +Two `m` strategies were tested: -**Choice of `m`** (two strategies tested): - `midpoint`: when spread ≤ 17, midpoint `m` makes everything in-range. When spread > 17, fall back to `median(k) − center`. - `search`: 1D integer sweep with closed-form objective - `Σ_j S_j · (2^k_j − 2^(m + clamp(k_j − m, −9, 8)))^2` where - `S_j = Σ_i e2m1_value_i^2` for block `j`. Cheap (≤ 50 candidates per tensor). - -### Algo 3: Hybrid (verbatim where exact, NVFP4-requant where lossy) - -Combines Algo 2's exact path with per-block requantization for OOR blocks: - -1. Search `m` (integer) by minimizing the actual hybrid reconstruction MSE. -2. For in-range MXFP4 blocks: keep verbatim path (zero error). -3. For OOR MXFP4 blocks: dequantize to FP32, then NVFP4-quantize each - 16-element half with the **fixed** `scale_2 = 2^m`. The per-NVFP4-block amax - can be smaller than the per-MXFP4-block amax — one half might lack the - max-magnitude nibble — letting OOR-at-MXFP4-level blocks fit cleanly into - per-NVFP4-block E4M3 scales. -4. Final reconstruction is masked-merged from the two paths. - -The `m` search is brute-force over the same integer range Algo 2 considers, but -evaluated against the actual hybrid MSE because NVFP4-requant's E4M3 mantissa -rounding isn't a clean closed form. + `Σ_j S_j · (2^k_j − 2^(m + clamp(k_j − m, −9, 8)))²` where + `S_j = Σ_i e2m1_value_i²` for block `j`. ≤ 50 candidates per tensor. + +### Algo 3: Hybrid — verbatim where exact, NVFP4-requant where lossy + +Combines Algo 2's exact path with per-block requantization for OOR blocks, with +a closed-form choice of `m`: + +1. Pick `m = k_max − 8`. +2. **In-range** MXFP4 blocks (`k_j − m ∈ [−9, 8]`): keep verbatim nibbles and + per-block E4M3 scale `2^(k_j − m)`. Zero error. +3. **Out-of-range** MXFP4 blocks (only possible when spread > 17): dequantize to + FP32, then NVFP4-quantize each 16-element half with the **fixed** `scale_2 = + 2^m`. The per-NVFP4-block amax can be smaller than the per-MXFP4-block amax — + one half might lack the max-magnitude nibble — letting OOR-at-MXFP4-level + blocks fit cleanly into per-NVFP4-block E4M3 scales. + +**Why `m = k_max − 8` is the right choice.** With this `m`, every block's +`k_j − m = k_j − k_max + 8` lands in `[8 − spread, 8]`. + +| Regime | `k_j − m` range | Behavior | +|---|---|---| +| spread ≤ 17 | `[8 − spread, 8] ⊆ [−9, 8]` | All blocks in-range. Verbatim path. **MSE = 0**. | +| spread > 17 | low blocks fall below `−9` | Low blocks go through NVFP4 requant; high blocks stay in-range. | + +The "always lossless when feasible" property is preserved (all blocks are in +range whenever the spread allows it). When spread > 17, this is the *only* +choice that doesn't NaN the highest-magnitude blocks — going lower forces +`amax / (6 · 2^m) > 448` (E4M3 max), going higher only shrinks the in-range +coverage on the low side without helping. An exhaustive integer-`m` search over +all 27 scenarios converged on this same rule, so the closed-form is sufficient. ## Experimental Setup @@ -83,7 +98,7 @@ rounding isn't a clean closed form. | Algo 2 outright winner | 0 | | Algo 1 outright winner | 1 | | Tied (≥ 2 algos at same MSE) | 22 | -| Algo 3 exact (MSE = 0) | 22/27 | +| Algo 3 exact (MSE = 0) | 22 / 27 | ### Algo 3 dominant wins (over Algo 2) @@ -103,7 +118,7 @@ rounding isn't a clean closed form. Gap is 0.21% — both algorithms underflow the small-magnitude blocks identically; the residual is the integer-vs-continuous quantization of the global scale. Algo 1 picks `scale_2 = global_amax / (6·448) ≈ 3.72`; Algo 3 is constrained to -`2^m` integer powers (here `m=3` → `scale_2 = 8`). +`2^m` integer powers (here `m = 3` → `scale_2 = 8`). ### Cases where Algo 3 loses to Algo 2 @@ -120,22 +135,21 @@ conv weight (4D), large flat (1024×4096). ## Why Algo 3 Works The asymmetry in error scaling explains everything: + - **Snap-up errors** scale as `(2^k_j − 2^(m+8))²`, dominated by the *true* magnitude `2^k_j` — can be enormous. - **Snap-down errors** scale as `(2^k_j − 2^(m−9))²`, bounded by the *snapped* magnitude `2^(m−9)`. -Algo 2's `m`-search already exploits this asymmetry by preferring low `m` values -that keep high-magnitude blocks in-range. But verbatim-snap on OOR blocks still -introduces a fixed-magnitude error per block, with no use of the within-block -structure. +Algo 3 protects the high-magnitude side by pinning `m = k_max − 8`. The +low-magnitude blocks may still snap down, but instead of suffering a fixed snap +error like Algo 2, they go through per-NVFP4-block requantization — which: -Algo 3 replaces that snap with a real per-NVFP4-block requantization, which: - Adapts to the actual half-block amax (much smaller than the MXFP4 block amax in many cases). - Lets the E4M3 mantissa carry information beyond pure powers of 2 — for an OOR block where `k_j − m = 9` but max nibble is `4`, the required scale is - `(4/6)·2^9 ≈ 341`, which fits in E4M3 (max 448). + `(4/6) · 2^9 ≈ 341`, which fits in E4M3 (max 448). - Costs nothing for in-range blocks because they keep the verbatim path. ## Recommendations @@ -143,18 +157,17 @@ Algo 3 replaces that snap with a real per-NVFP4-block requantization, which: 1. **Default to Algo 3** for MXFP4 → NVFP4 conversion. It is exact in the typical-weight case, strictly better than Algo 2 on spread-too-large cases, and within 0.2% of Algo 1 even on the pathological single-outlier case. + No search loop is needed — `m = k_max − 8` is a closed-form max-reduction. 2. **The bound case** (single block dominates the entire tensor's dynamic range) can be closed by allowing a continuous (non-power-of-2) global scale. The integer-`m` form is purely for clean E4M3 representation of in-range per-block scales; on OOR blocks Algo 3 already routes through E4M3 mantissa rounding, so dropping the integer constraint there costs nothing in exactness and recovers the last 0.2%. -3. **Detection of the pathological case** is cheap: when the spread is very - large *and* one block's `S_j` dominates the rest, Algo 1 (or Algo 3 with a - continuous global scale) is preferable to Algo 3-with-integer-`m`. -4. **Cost**: Algo 3 runs a 1D integer sweep over typically 20–50 candidates, each - evaluating one NVFP4 quantize+dequantize. For typical PTQ workflows this runs - once per tensor and is negligible. +3. **Cost**: Algo 3 is essentially free. Picking `m` is a max-reduction over + the block exponents; the verbatim path is a buffer copy plus a per-block + scale build; the requant path runs only on the (often empty) set of OOR + blocks, and even when present it is a single NVFP4 quantize+dequantize pass. ## Reproducibility From 0f7fc04b9e5f4b305021b6d624d6786c5d7133e3 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Wed, 29 Apr 2026 09:46:00 -0700 Subject: [PATCH 3/7] Add closed-form MXFP4 -> NVFP4 weight cast (--cast_mxfp4_to_nvfp4) When the source HF checkpoint is MXFP4 (e.g. openai/gpt-oss-20b), the new flag pins NVFP4 weight quantizers' scale_2 to 2^m (m = k_max - 8) and the per-block _amax to 6 * 2^k_j read from the source *_scales. Per-block scale = 2^(k_j - m) is exactly representable in E4M3 for in-range blocks, so NVFP4 dequant matches MXFP4 dequant bit-for-bit (verified SNR=inf on gpt-oss-20b's full ~19B-element MoE expert weights). For out-of-range blocks (k_max - k_j > 17), the per-block amax falls back to data-derived max(|w_block|), keeping the post-clamp scale closer to the actual block magnitude than the closed-form ideal would. Modelopt-side enablers: - max_calibrate auto-promotes static-block NVFP4 weight quantizers to NVFP4StaticQuantizer at the end of calibration. - static_blockwise_fp4_fake_quant kernel accepts N-D inputs (was 2D-only), unblocking MoE expert weights of shape (E, F, K). - BMM-experts NVFP4 export routes through get_weights_scaling_factor_from_quantizer for static-mode quantizers, so the pinned _amax is consumed (was bypassed by recompute-from-weight). - set_expert_quantizer_amax scalar-reduces per-quantizer amax before stacking, supporting per-block (vs scalar) static-mode amax. Wired through scripts/parser.sh + scripts/huggingface_example.sh as the shell-level --cast_mxfp4_to_nvfp4 flag. Removes the scratch/ MSE experiment (kept in PR description for context). Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Chenjie Luo --- examples/llm_ptq/cast_mxfp4_to_nvfp4.py | 409 ++++++++++ examples/llm_ptq/hf_ptq.py | 36 + .../llm_ptq/scripts/huggingface_example.sh | 4 + examples/llm_ptq/scripts/parser.sh | 5 +- modelopt/torch/export/layer_utils.py | 10 +- modelopt/torch/export/unified_export_hf.py | 34 +- .../kernels/quantization/gemm/fp4_kernel.py | 24 +- modelopt/torch/quantization/model_calib.py | 7 + scratch/mxfp4_to_nvfp4_mse.py | 732 ------------------ scratch/mxfp4_to_nvfp4_report.md | 175 ----- .../llm_ptq/test_cast_mxfp4_to_nvfp4.py | 282 +++++++ 11 files changed, 787 insertions(+), 931 deletions(-) create mode 100644 examples/llm_ptq/cast_mxfp4_to_nvfp4.py delete mode 100644 scratch/mxfp4_to_nvfp4_mse.py delete mode 100644 scratch/mxfp4_to_nvfp4_report.md create mode 100644 tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py diff --git a/examples/llm_ptq/cast_mxfp4_to_nvfp4.py b/examples/llm_ptq/cast_mxfp4_to_nvfp4.py new file mode 100644 index 0000000000..7c4a4a4874 --- /dev/null +++ b/examples/llm_ptq/cast_mxfp4_to_nvfp4.py @@ -0,0 +1,409 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Closed-form cast from MXFP4 source to NVFP4 weight quantizer state. + +Reads ``*_scales`` (E8M0 per-MXFP4-block exponents) and ``*_blocks`` (packed +E2M1 nibbles) from a Hugging Face checkpoint with +``quantization_config.quant_method == "mxfp4"`` (e.g. OpenAI's gpt-oss family) +and produces, per source layer: + +* a per-tensor ``global_amax = 6 * 448 * 2^(k_max - 8)`` that pins NVFP4's + ``scale_2`` to ``2^m`` (an exact power of 2, exactly representable in E4M3); +* a per-NVFP4-block ``_amax`` that is bit-exact (``6 * 2^k_j``) for blocks + whose ``k_j`` lands in E4M3's representable window, and falls back to the + data-derived ``max(|w_block|) = max_nibble * 2^k_j`` for out-of-range blocks + (where the per-block scale would clamp anyway). + +Together these guarantee NVFP4 dequant matches MXFP4 dequant bit-for-bit on +every in-range block, and minimizes per-block error on out-of-range ones. +Reads only the scales (~150 MB for gpt-oss-20b) plus the packed nibbles for +out-of-range blocks; runs in seconds. +""" + +import json +from pathlib import Path + +import torch +from safetensors import safe_open + +from modelopt.torch.quantization.nn.modules.tensor_quantizer import NVFP4StaticQuantizer + +E8M0_BIAS = 127 # E8M0 stores k_j as uint8 with bias 127 +E2M1_MAX = 6.0 +E4M3_MAX = 448.0 +E4M3_KMAX = 8 +E4M3_KMIN = -9 # E4M3 represents 2^k exactly for k in [-9, 8] +# E2M1 magnitude grid indexed by the low 3 bits of an FP4 nibble. +_E2M1_MAGNITUDE = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] +# Cache of the E2M1 magnitude lookup table per (device, dtype) so we don't +# rebuild it for every layer in a batched cast. +_E2M1_MAG_CACHE: "dict[tuple, torch.Tensor]" = {} + + +def _e2m1_magnitude_table(device: torch.device, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Return ``_E2M1_MAGNITUDE`` as a tensor on the requested device, cached.""" + key = (device, dtype) + cached = _E2M1_MAG_CACHE.get(key) + if cached is None: + cached = torch.tensor(_E2M1_MAGNITUDE, dtype=dtype, device=device) + _E2M1_MAG_CACHE[key] = cached + return cached + + +def compute_global_amax_for_scales(e8m0_scales: torch.Tensor) -> tuple[float, dict]: + """Closed-form per-tensor ``global_amax``: ``m = k_max - 8``, ``global_amax = 6 * 448 * 2^m``. + + Args: + e8m0_scales: uint8 tensor of E8M0 scales for one MXFP4 source layer. + + Returns: + global_amax: scalar (float) — pins NVFP4 scale_2 to 2^m. + info: diagnostic dict with k_min, k_max, m, lossless-block stats. + """ + # k_j = e8m0 - 127. MXFP4 quantize emits e8m0=0 (=> k=-127) for all-zero + # blocks; treat those as "ignore me" when computing k_max. + k = e8m0_scales.to(torch.int32) - E8M0_BIAS + nonzero_mask = e8m0_scales > 0 + if nonzero_mask.any(): + k_nonzero = k[nonzero_mask] + k_min = int(k_nonzero.min().item()) + k_max = int(k_nonzero.max().item()) + else: + k_min = k_max = 0 + + m = k_max - E4M3_KMAX + global_amax = E2M1_MAX * E4M3_MAX * float(2.0**m) + + # A block is lossless under this cast iff k_max - k_j <= 17 (its k_j - m sits + # in E4M3's [-9, 8] window). All-zero blocks are trivially lossless because + # their reconstruction is 0 regardless of the snapped scale. + n_total = e8m0_scales.numel() + in_range = (k >= (k_max - 17)) | (~nonzero_mask) + n_lossless = int(in_range.sum().item()) + pct_lossless = 100.0 * n_lossless / n_total if n_total else 100.0 + + return global_amax, { + "k_min": k_min, + "k_max": k_max, + "m": m, + "n_total_blocks": n_total, + "n_lossless_blocks": n_lossless, + "pct_lossless": pct_lossless, + "n_zero_blocks": int((~nonzero_mask).sum().item()), + } + + +def compute_per_block_amax_for_mxfp4( + blocks: torch.Tensor, e8m0_scales: torch.Tensor +) -> torch.Tensor: + """Hybrid per-NVFP4-block amax for MXFP4 -> NVFP4 cast. + + Each MXFP4 block of 32 elements has one E8M0 exponent ``k_j``. Two cases + based on whether ``k_j`` fits in NVFP4's E4M3 scale grid (with + ``m = k_max - 8`` chosen by ``compute_global_amax_for_scales``): + + - **In-range** (``k_j - m`` in ``[-9, 8]``): ``6 * 2^k_j`` (closed-form + ideal). The resulting per-block scale ``2^(k_j - m)`` is exactly + representable in E4M3 — no rounding loss — and + ``round_to_E2M1(value / 2^k_j)`` yields the original MXFP4 nibble + verbatim. Bit-exact reconstruction. + + - **Out of range** (``|k_j - m| > 8/9``): ``max_nibble * 2^k_j``, i.e. + ``max(|w_block|)`` where ``w`` is the MXFP4-dequantized block. This is + the data-derived per-block amax. The per-block scale will still get + clamped at the E4M3 boundary, but data-derived amax keeps the post-clamp + scale closer to the block's actual magnitude than the closed-form ideal + would, which reduces re-bucketing error for OOR blocks where + ``max_nibble < 6``. + + Two NVFP4 blocks of 16 share each MXFP4 block's ``k_j``, so the result is + expanded by ``repeat_interleave(2, dim=-1)``. + + Args: + blocks: uint8 tensor of packed E2M1 nibbles, shape + ``(..., num_mxfp4_blocks, 16)`` (16 bytes per 32-element MXFP4 block). + e8m0_scales: uint8 tensor of E8M0 scales, shape + ``(..., num_mxfp4_blocks)``. + + Returns: + float32 tensor of shape ``(..., 2 * num_mxfp4_blocks)``. + """ + if blocks.shape[:-1] != e8m0_scales.shape: + raise ValueError( + f"shape mismatch: blocks {tuple(blocks.shape)} (expect last dim 16) " + f"vs scales {tuple(e8m0_scales.shape)}" + ) + + k = e8m0_scales.to(torch.int32) - E8M0_BIAS # (..., num_mxfp4_blocks) + pow2_k = torch.exp2(k.float()) + closed_form_ideal = E2M1_MAX * pow2_k # (..., num_mxfp4_blocks) + + # ``m = k_max - 8`` over non-zero blocks. Compute via masked ``amax`` so + # ``m`` stays a 0-d tensor and we avoid a GPU->CPU sync just to get a + # Python int. All-zero scales fall through with the -E8M0_BIAS sentinel, + # which leaves every block trivially in-range (closed_form_ideal == 0 there). + nonzero = e8m0_scales > 0 + sentinel = torch.full_like(k, -E8M0_BIAS) + k_max = torch.where(nonzero, k, sentinel).amax() + delta = k - (k_max - E4M3_KMAX) + in_range = (delta >= E4M3_KMIN) & (delta <= E4M3_KMAX) + + # Fast path: if every block fits E4M3's [-9, 8] window the per-block amax + # is just the closed-form ideal, and we can skip the per-byte nibble scan + # over the block tensor (which is 16x larger than the scales). For typical + # MXFP4 checkpoints (e.g. gpt-oss-20b) this is the only path ever taken. + if bool(in_range.all()): + return closed_form_ideal.repeat_interleave(2, dim=-1) + + # OOR fallback: data-derived per-block amax = max(|w_block|) after MXFP4 + # dequant = ``max_nibble * 2^k_j``. The MXFP4 nibble is sign-magnitude with + # sign in bit 3 and magnitude index in bits 0-2; we extract per-byte + # magnitudes, take the byte-wise max, then reduce across the 16 bytes to + # get the largest magnitude index in the 32-element block. + low = blocks & 0x07 + high = (blocks >> 4) & 0x07 + max_idx = torch.maximum(low, high).amax(dim=-1).long() + max_nibble = _e2m1_magnitude_table(blocks.device)[max_idx] + data_derived = max_nibble * pow2_k + + per_block_amax_mxfp4 = torch.where(in_range, closed_form_ideal, data_derived) + # Each MXFP4 block of 32 splits into two NVFP4 blocks of 16 sharing k_j. + return per_block_amax_mxfp4.repeat_interleave(2, dim=-1) + + +def quantizer_name_from_blocks_key(blocks_key: str) -> str: + """Map ``_blocks`` -> ``_weight_quantizer``. + + OpenAI's MXFP4 checkpoint convention stores packed weights as + ``_blocks`` and scales as ``_scales``. modelopt's + ``GptOssExperts`` wrapper attaches the weight quantizer at + ``_weight_quantizer``. + """ + assert blocks_key.endswith("_blocks"), f"Unexpected key {blocks_key!r}" + return blocks_key[: -len("_blocks")] + "_weight_quantizer" + + +def _collect_keys_with_suffix(ckpt_dir: Path, suffix: str) -> dict[str, Path]: + """Return ``{tensor_name: shard_path}`` for every key ending with ``suffix``.""" + index_path = ckpt_dir / "model.safetensors.index.json" + if index_path.is_file(): + with index_path.open() as f: + index = json.load(f) + return { + k: ckpt_dir / shard for k, shard in index["weight_map"].items() if k.endswith(suffix) + } + shards = list(ckpt_dir.glob("*.safetensors")) + if len(shards) != 1: + raise FileNotFoundError( + f"Expected model.safetensors.index.json or a single .safetensors file in {ckpt_dir}" + ) + out: dict[str, Path] = {} + with safe_open(shards[0], framework="pt") as f: + # ``safe_open`` is not a dict; ``.keys()`` is its iterator. + for k in f.keys(): # noqa: SIM118 + if k.endswith(suffix): + out[k] = shards[0] + return out + + +def _collect_scales_keys(ckpt_dir: Path) -> dict[str, Path]: + """Return ``{tensor_name: shard_path}`` for every ``*_scales`` key.""" + return _collect_keys_with_suffix(ckpt_dir, "_scales") + + +def build_amax_map(checkpoint_dir: str | Path) -> dict[str, dict]: + """Walk the source MXFP4 checkpoint and build the per-layer amax map. + + Args: + checkpoint_dir: Path to a Hugging Face checkpoint directory whose + ``quantization_config.quant_method`` is ``"mxfp4"`` (OpenAI layout + with ``*_blocks`` + ``*_scales`` tensors). + + Returns: + ``{quantizer_name: {"global_amax": float, "k_min": int, "k_max": int, + "m": int, "n_total_blocks": int, + "n_lossless_blocks": int, "pct_lossless": float, + "n_zero_blocks": int}}`` + + Quantizer names match ``model.named_modules()`` after modelopt + instrumentation (e.g. ``model.layers.0.mlp.experts.gate_up_proj_weight_quantizer``). + + Raises: + SystemExit: if no ``*_scales`` tensors are found (not an MXFP4 checkpoint). + """ + ckpt_dir = Path(checkpoint_dir) + if not ckpt_dir.is_dir(): + raise FileNotFoundError(f"Checkpoint dir not found: {ckpt_dir}") + + scales_keys = _collect_scales_keys(ckpt_dir) + if not scales_keys: + raise SystemExit( + f"No '*_scales' tensors found in {ckpt_dir}. " + "This requires an MXFP4 HF checkpoint with the OpenAI layout." + ) + + handles: dict[Path, safe_open] = {} + amax_map: dict[str, dict] = {} + for tensor_key, shard in sorted(scales_keys.items()): + if shard not in handles: + handles[shard] = safe_open(shard, framework="pt", device="cpu") + scales = handles[shard].get_tensor(tensor_key) + + global_amax, info = compute_global_amax_for_scales(scales) + + blocks_key = tensor_key[: -len("_scales")] + "_blocks" + qname = quantizer_name_from_blocks_key(blocks_key) + amax_map[qname] = {"global_amax": global_amax, **info} + + return amax_map + + +def apply_to_model( + model: "torch.nn.Module", + source_checkpoint_path: str | Path, +) -> None: + """Closed-form cast: bit-exact MXFP4 -> NVFP4 weight conversion. + + Reads the source MXFP4 ``*_scales`` from ``source_checkpoint_path`` and + overrides two buffers on each matching NVFP4 weight quantizer: + + 1. ``global_amax`` = ``6 * 448 * 2^(k_max - 8)`` (closed-form scalar — + pins ``scale_2 = 2^m``). + 2. ``_amax`` = ``6 * 2^k_j`` per NVFP4 block (closed-form per-block — pins + ``per_block_scale = 2^(k_j - m)``, exactly representable in E4M3). + + Together these guarantee that ``per_block_scale * scale_2 = 2^k_j`` exactly, + so the NVFP4 dequant produces ``nibble * 2^k_j`` — the same value as the + MXFP4 dequant. End-to-end the weight conversion is bit-exact for every + block whose ``k_j`` lands in E4M3's representable range (``k_max - k_j <= 17``). + + The weight quantizer is expected to be an :class:`NVFP4StaticQuantizer` + (:func:`max_calibrate` auto-promotes static-block NVFP4 weight quantizers + at the end of calibration). Both ``_amax`` (per-block from max-cal) and + ``_global_amax`` (per-tensor from the auto-promotion) get overwritten. + """ + ckpt_dir = Path(source_checkpoint_path) + if not ckpt_dir.is_dir(): + raise FileNotFoundError(f"Checkpoint dir not found: {ckpt_dir}") + scales_keys = _collect_scales_keys(ckpt_dir) + if not scales_keys: + raise SystemExit( + f"No '*_scales' tensors found in {ckpt_dir}. " + "This requires an MXFP4 HF checkpoint with the OpenAI layout." + ) + + blocks_keys = _collect_keys_with_suffix(ckpt_dir, "_blocks") + + handles: dict[Path, safe_open] = {} + name_to_module = dict(model.named_modules()) + matched = 0 + missed: list[str] = [] + + n_total_layers = 0 + n_lossless_layers = 0 + grand_total_blocks = 0 + grand_lossless_blocks = 0 + + def _read(key: str, shard: Path) -> torch.Tensor: + if shard not in handles: + handles[shard] = safe_open(shard, framework="pt", device="cpu") + return handles[shard].get_tensor(key) + + for tensor_key, shard in sorted(scales_keys.items()): + scales = _read(tensor_key, shard) + + global_amax_value, info = compute_global_amax_for_scales(scales) + n_total_layers += 1 + if info["pct_lossless"] >= 100.0: + n_lossless_layers += 1 + grand_total_blocks += info["n_total_blocks"] + grand_lossless_blocks += info["n_lossless_blocks"] + + blocks_key = tensor_key[: -len("_scales")] + "_blocks" + qname = quantizer_name_from_blocks_key(blocks_key) + blocks_shard = blocks_keys.get(blocks_key) + assert blocks_shard is not None, ( + f"{tensor_key}: no paired '{blocks_key}' tensor found in source checkpoint." + ) + + weight_quantizer = name_to_module.get(qname) + if weight_quantizer is None: + missed.append(qname) + continue + + # The cast assumes ``max_calibrate`` already promoted this quantizer + # to NVFP4StaticQuantizer (with ``_amax`` populated per-block by + # static-block max-cal and ``_global_amax`` set by the auto-promote). + # Anything else means the qformat or quant_cfg disabled this module's + # weight quantization — surface that loudly so we don't silently no-op. + assert isinstance(weight_quantizer, NVFP4StaticQuantizer), ( + f"{qname}: expected NVFP4StaticQuantizer (set by max_calibrate's " + f"auto-promote), got {type(weight_quantizer).__name__}. The cast " + "requires the matching quantizer to be enabled with static-block " + "NVFP4 (num_bits=(2,1), scale_bits=(4,3))." + ) + existing = getattr(weight_quantizer, "_amax", None) + assert isinstance(existing, torch.Tensor) and existing.numel() > 1, ( + f"{qname}: NVFP4StaticQuantizer must have a per-block ``_amax`` " + f"buffer populated by max_calibrate. Got: {existing!r}." + ) + + # Pick the device from the existing per-block ``_amax`` buffer. + device = existing.device + + global_amax = torch.tensor(float(global_amax_value), dtype=torch.float32, device=device) + blocks = _read(blocks_key, blocks_shard) + per_block_amax = compute_per_block_amax_for_mxfp4(blocks, scales).to( + dtype=torch.float32, device=device + ) + # Numel must match — calibration may store ``_amax`` flat (e.g. (N, 1)) + # while we compute it in natural (E, F, num_blocks) layout. The static + # export path reshapes via ``.view(expected_shape)``, so we just need + # element count to agree, then reshape for the in-place copy. + assert existing.numel() == per_block_amax.numel(), ( + f"{qname}: ``_amax`` element count {existing.numel()} does not " + f"match the cast-computed count {per_block_amax.numel()}. The " + "block layout from calibration disagrees with the source MXFP4 " + "scales — check that the qformat block_size is 16 and the source " + "checkpoint is the same MXFP4 model." + ) + + # global_amax via the NVFP4StaticQuantizer property setter (writes to + # the canonical ``_global_amax`` buffer). + weight_quantizer.global_amax = global_amax + # _amax: in-place buffer copy, reshaping our value to the calibrator's + # storage layout (numel verified above). + with torch.no_grad(): + existing.data.copy_(per_block_amax.view_as(existing)) + + matched += 1 + + print( + f"[cast_mxfp4_to_nvfp4] overrode {matched}/{n_total_layers} weight quantizers from {source_checkpoint_path}" + ) + if missed: + print( + f"[cast_mxfp4_to_nvfp4] warning: {len(missed)} layers had no matching module. " + f"First few: {missed[:5]}" + ) + layer_pct = 100.0 * n_lossless_layers / n_total_layers if n_total_layers else 100.0 + block_pct = 100.0 * grand_lossless_blocks / grand_total_blocks if grand_total_blocks else 100.0 + print( + f"[cast_mxfp4_to_nvfp4] lossless layers: {n_lossless_layers}/{n_total_layers} ({layer_pct:.2f}%)" + ) + print( + f"[cast_mxfp4_to_nvfp4] lossless blocks: {grand_lossless_blocks}/{grand_total_blocks} ({block_pct:.4f}%)" + ) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index d660c1de4c..4184a91199 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -24,6 +24,7 @@ import numpy as np import torch from accelerate.hooks import remove_hook_from_module +from cast_mxfp4_to_nvfp4 import apply_to_model as apply_cast_mxfp4_to_nvfp4 from example_utils import ( build_quant_cfg, copy_custom_model_files, @@ -1087,6 +1088,21 @@ def quantize_main( f"Auto-resolved layerwise_checkpoint_dir: {quant_cfg['algorithm']['layerwise_checkpoint_dir']}" ) + # MXFP4 -> NVFP4 cast needs the per-block weight ``_amax`` to be recorded + # by max-cal (so it can be paired with the pinned global_amax later). + # Force every weight-quantizer entry to ``block_sizes['type'] = 'static'`` + # so ``is_static_block_quant`` is True and ``promote_nvfp4_static_quantizers`` + # picks them up automatically at the end of max_calibrate. + if args.cast_mxfp4_to_nvfp4: + quant_cfg = copy.deepcopy(quant_cfg) + for entry in quant_cfg.get("quant_cfg", []): + qname = entry.get("quantizer_name", "") + cfg = entry.get("cfg") or {} + bs = cfg.get("block_sizes") + if "weight_quantizer" in qname and isinstance(bs, dict): + bs = {**bs, "type": "static"} + entry["cfg"] = {**cfg, "block_sizes": bs} + if args.qformat in QUANT_CFG_CHOICES: mono_quantize( args, @@ -1102,6 +1118,14 @@ def quantize_main( assert model_type != "dbrx", f"Does not support export {model_type} without quantizaton" print(f"qformat: {args.qformat}. No quantization applied, export {device} model") + # If asked, run the closed-form MXFP4 -> NVFP4 cast: read the source MXFP4 + # *_scales tensors and pin each NVFP4 weight quantizer's scale_2 to 2^m. + # Runs after calibration (max_calibrate has already promoted weight quantizers + # to NVFP4StaticQuantizer with a data-derived ``_global_amax``); we just + # override that scalar with the closed-form value before export. + if args.cast_mxfp4_to_nvfp4: + apply_cast_mxfp4_to_nvfp4(language_model, args.pyt_ckpt_path) + post_quantize( args, full_model, @@ -1343,6 +1367,18 @@ def parse_args() -> argparse.Namespace: help="Export as vLLM fake-quant checkpoint (produces vllm_fq_modelopt_state.pth " "for use with vllm_serve_fakequant.py).", ) + parser.add_argument( + "--cast_mxfp4_to_nvfp4", + action="store_true", + default=False, + help=( + "After calibration, override NVFP4 weight quantizers' global_amax with " + "the closed-form value derived from the source MXFP4 *_scales. " + "Per-block _amax is computed from the loaded BF16 weights (data-derived). " + "Use when --pyt_ckpt_path points at an MXFP4 HF checkpoint (e.g. " + "openai/gpt-oss-20b) and the target qformat is NVFP4-family." + ), + ) args = parser.parse_args() if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0): diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh index d9c4ff8a7a..6ca99c7f96 100755 --- a/examples/llm_ptq/scripts/huggingface_example.sh +++ b/examples/llm_ptq/scripts/huggingface_example.sh @@ -143,6 +143,10 @@ if [ -n "$MOE_CALIB_EXPERTS_RATIO" ]; then PTQ_ARGS+=" --moe_calib_experts_ratio=$MOE_CALIB_EXPERTS_RATIO " fi +if $CAST_MXFP4_TO_NVFP4; then + PTQ_ARGS+=" --cast_mxfp4_to_nvfp4 " +fi + if ! $VERBOSE; then PTQ_ARGS+=" --no-verbose " fi diff --git a/examples/llm_ptq/scripts/parser.sh b/examples/llm_ptq/scripts/parser.sh index b41b715340..3817c1dee7 100644 --- a/examples/llm_ptq/scripts/parser.sh +++ b/examples/llm_ptq/scripts/parser.sh @@ -34,9 +34,10 @@ parse_options() { KV_CACHE_FREE_GPU_MEMORY_FRACTION=0.8 VERBOSE=true USE_SEQ_DEVICE_MAP=false + CAST_MXFP4_TO_NVFP4=false # Parse command-line options - ARGS=$(getopt -o "" -l "model:,quant:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:" -n "$0" -- "$@") + ARGS=$(getopt -o "" -l "model:,quant:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:,cast_mxfp4_to_nvfp4" -n "$0" -- "$@") eval set -- "$ARGS" while true; do @@ -69,6 +70,7 @@ parse_options() { --auto_quantize_score_size ) AUTO_QUANTIZE_SCORE_SIZE="$2"; shift 2;; --auto_quantize_checkpoint ) AUTO_QUANTIZE_CHECKPOINT="$2"; shift 2;; --moe_calib_experts_ratio ) MOE_CALIB_EXPERTS_RATIO="$2"; shift 2;; + --cast_mxfp4_to_nvfp4 ) CAST_MXFP4_TO_NVFP4=true; shift;; -- ) shift; break ;; * ) break ;; esac @@ -158,5 +160,6 @@ parse_options() { echo "auto_quantize_score_size: $AUTO_QUANTIZE_SCORE_SIZE" echo "auto_quantize_checkpoint: $AUTO_QUANTIZE_CHECKPOINT" echo "moe_calib_experts_ratio: $MOE_CALIB_EXPERTS_RATIO" + echo "cast_mxfp4_to_nvfp4: $CAST_MXFP4_TO_NVFP4" echo "=================" } diff --git a/modelopt/torch/export/layer_utils.py b/modelopt/torch/export/layer_utils.py index e8ee5afd45..e59bf32232 100755 --- a/modelopt/torch/export/layer_utils.py +++ b/modelopt/torch/export/layer_utils.py @@ -1079,14 +1079,20 @@ def set_expert_quantizer_amax( target_amax = None - # Collect ANY existing amax values from current batch (most direct source) + # Collect ANY existing amax values from current batch (most direct source). + # Reduce per-quantizer amax to a scalar before stacking — quantizers in + # static-mode (e.g. NVFP4 with pre-computed per-block _amax) carry tensors + # whose shapes differ across attrs (gate_up_proj vs down_proj have different + # output dims), and torch.stack would otherwise fail. The result here is + # only used as a *fallback* scalar `target_amax` for quantizers missing + # amax, so a max-of-max is exactly what we want. valid_amax_values = [] for _, attr_name, quantizer in all_quantizers: existing_amax = getattr(quantizer, "amax", None) if existing_amax is not None: # Convert to tensor and add to collection if isinstance(existing_amax, torch.Tensor): - valid_amax_values.append(existing_amax.to(target_device)) + valid_amax_values.append(existing_amax.amax().to(target_device)) else: valid_amax_values.append( torch.tensor(existing_amax, dtype=torch.float32, device=target_device) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index a76783ac17..c0f00f7e9a 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -52,11 +52,7 @@ from torch.distributed.fsdp import FSDPModule from modelopt.torch.quantization import set_quantizer_by_cfg_context -from modelopt.torch.quantization.nn import ( - NVFP4StaticQuantizer, - SequentialQuantizer, - TensorQuantizer, -) +from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names @@ -539,11 +535,12 @@ def _export_quantized_weight( expert_type in type(sub_module).__name__ for expert_type in ["Llama4TextExperts", "GptOssExperts"] ) - if is_bmm_expert_weight and isinstance(weight_quantizer, NVFP4StaticQuantizer): - raise ValueError( - "NVFP4StaticQuantizer with BMM-style expert weights (e.g. Llama4TextExperts, " - "GptOssExperts) is not yet supported." - ) + # NVFP4StaticQuantizer + BMM-style experts: route through the static-aware + # ``_from_quantizer`` helper so the pinned per-block ``_amax`` (e.g. set by + # the MXFP4->NVFP4 cast to ``6 * 2^k_j``) is used to derive the FP8 + # per-block scale. The plain ``get_weights_scaling_factor`` would ignore + # ``_amax`` and recompute per-block max from the BF16 weight, which + # rebuckets nibbles and loses bit-exactness when ``max_nibble < 6``. if quantization_format in [ QUANTIZATION_NVFP4, @@ -556,11 +553,18 @@ def _export_quantized_weight( weight, is_bmm_expert_weight=is_bmm_expert_weight ) - weight_scale = NVFP4QTensor.get_weights_scaling_factor( - weight, - block_size=block_size, - weights_scaling_factor_2=weight_scale_2, - )[0] + if NVFP4QTensor._is_static_quantizer(weight_quantizer): + weight_scale = NVFP4QTensor.get_weights_scaling_factor_from_quantizer( + weight_quantizer, + weight, + weight_scale_2, + )[0] + else: + weight_scale = NVFP4QTensor.get_weights_scaling_factor( + weight, + block_size=block_size, + weights_scaling_factor_2=weight_scale_2, + )[0] quantized_weight = to_quantized_weight( weight.to(dtype), diff --git a/modelopt/torch/kernels/quantization/gemm/fp4_kernel.py b/modelopt/torch/kernels/quantization/gemm/fp4_kernel.py index 9eb6b2d49f..0e6874ab57 100644 --- a/modelopt/torch/kernels/quantization/gemm/fp4_kernel.py +++ b/modelopt/torch/kernels/quantization/gemm/fp4_kernel.py @@ -251,14 +251,26 @@ def static_blockwise_fp4_fake_quant( """Static blockwise FP4 fake quantization using Triton kernel. Args: - x: [NUM_FP4_BLOCKS, BLOCK_SIZE] on CUDA. - amax: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] per-block amax values. + x: Input tensor on CUDA. The last dim must be the block dim (each consecutive + ``BLOCK_SIZE`` elements form one FP4 block). Any number of leading dims + is supported — they're flattened internally and the shape is restored + on output (so MoE expert weights ``(E, F, K)`` work the same as plain + linear weights ``(N, K)``). + amax: Per-block amax values. ``amax.numel()`` must equal + ``x.numel() // BLOCK_SIZE``. Shape is otherwise free; the kernel + consumes it as a flat 1-D buffer of length ``NUM_FP4_BLOCKS``. global_amax: FP32 scalar global amax. If provided, used to compute scale_fp8_quant_amax. quantize_block_scales: If True, quantize block scales to FP8. out_dtype: Output dtype. Defaults to x.dtype if None. """ - assert x.ndim == 2 - NUM_FP4_BLOCKS, BLOCK_SIZE = x.shape + original_shape = x.shape + NUM_FP4_BLOCKS = amax.numel() + if x.numel() % NUM_FP4_BLOCKS != 0: + raise ValueError( + f"x.numel() ({x.numel()}) is not divisible by amax.numel() ({NUM_FP4_BLOCKS}); " + "they must satisfy x.numel() == NUM_FP4_BLOCKS * BLOCK_SIZE." + ) + BLOCK_SIZE = x.numel() // NUM_FP4_BLOCKS if out_dtype is None: out_dtype = x.dtype @@ -267,7 +279,7 @@ def static_blockwise_fp4_fake_quant( x_flat = x.contiguous().view(-1) y_flat = torch.empty_like(x_flat, dtype=out_dtype) - scale_flat = scale.view(NUM_FP4_BLOCKS).contiguous() + scale_flat = scale.contiguous().view(NUM_FP4_BLOCKS) tl_out_dtype = _torch_dtype_to_tl(out_dtype) @@ -283,4 +295,4 @@ def static_blockwise_fp4_fake_quant( OUT_DTYPE=tl_out_dtype, ) - return y_flat.view_as(x) + return y_flat.view(original_shape) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 0c2033041d..5f69d84ba4 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -280,6 +280,13 @@ def sync_quantizer_amax_across_tp( module.parallel_state.tensor_parallel_group ) + # Promote eligible static-block NVFP4 weight quantizers to NVFP4StaticQuantizer + # so the static blockwise fake-quant path is used in forward and the export + # picks up the two-level (per-block + global) scaling. ``promote_nvfp4_static_quantizers`` + # only promotes when ``is_static_block_quant`` is True and the per-block ``_amax`` + # buffer is populated, so it's a no-op for dynamic-block / non-NVFP4 configs. + promote_nvfp4_static_quantizers(model) + def _mse_quant_func(x, amax, quantizer): """Quantization function for MSE calibration.""" diff --git a/scratch/mxfp4_to_nvfp4_mse.py b/scratch/mxfp4_to_nvfp4_mse.py deleted file mode 100644 index 1b35054b1f..0000000000 --- a/scratch/mxfp4_to_nvfp4_mse.py +++ /dev/null @@ -1,732 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Research scratch script — relax some style rules that don't add value here. -# ruff: noqa: D103, RUF003 - -"""MXFP4 -> NVFP4 conversion MSE experiment. - -Compares three algorithms for converting an MXFP4 tensor (block_size=32, E2M1 + -E8M0 power-of-2 scales) to NVFP4 (block_size=16, E2M1 + E4M3 scales + global FP32 -scale): - - Algo 1 (dequant-requant): dequantize MXFP4 to BF16, then quantize to NVFP4 the - standard way. Re-buckets nibbles and computes new scales from scratch. - - Algo 2 (verbatim nibbles): keep E2M1 nibbles unchanged. Each MXFP4 block of 32 - splits into two NVFP4 blocks of 16, both inheriting the same exponent k_j. - Pick a global scale S = 2^m (integer m) and store the per-block E4M3 scale as - 2^(k_j - m). E4M3 exactly represents 2^k for k in [-9, 8], so when - max(k) - min(k) <= 17 there is a valid m and the conversion is exact. For - blocks outside that window the per-block exponent snaps to [-9, 8]. - - Algo 3 (hybrid): apply Algo 2's verbatim path for in-range blocks (zero error) - and NVFP4-requant each 16-element half with fixed scale_2 = 2^m for OOR - blocks. The closed-form rule m = k_max - 8 is provably optimal: it keeps the - highest-magnitude blocks just inside E4M3's window (the side where snap - errors are catastrophic), and any in-range block is unaffected by m. - -Reference for all algos: the MXFP4-dequantized tensor (what the source -representation faithfully encodes). MSE is computed against that reference in fp32. -""" - -import math - -import torch - -from modelopt.torch.quantization.qtensor.mxfp4_tensor import MXFP4QTensor -from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor - -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") -MX_BLOCK = 32 -NV_BLOCK = 16 -E4M3_KMIN, E4M3_KMAX = -9, 8 # E4M3 represents 2^k exactly for k in [-9, 8] - -# E2M1 magnitude squared, indexed by nibble bits (sign bit ignored — squared anyway). -# Sign bit is the high bit (0b1000); low 3 bits are the magnitude index into -# [0, 0.5, 1, 1.5, 2, 3, 4, 6]. Squared magnitude lookup for all 16 nibble values: -_E2M1_SQ = torch.tensor( - [0.0, 0.25, 1.0, 2.25, 4.0, 9.0, 16.0, 36.0, 0.0, 0.25, 1.0, 2.25, 4.0, 9.0, 16.0, 36.0], - dtype=torch.float32, -) - - -# ---------- Algorithm 1: dequant -> requant ---------------------------------- - - -def algo1_dequant_requant(mxfp4_qt: MXFP4QTensor, e8m0_scale: torch.Tensor) -> torch.Tensor: - """Dequantize MXFP4 then quantize to NVFP4 the normal way; return float32 reconstruction.""" - deq_bf16 = mxfp4_qt.dequantize( - dtype=torch.bfloat16, scale=e8m0_scale, block_sizes={-1: MX_BLOCK} - ) - nv_qt, per_block_e4m3, double_scale = NVFP4QTensor.quantize(deq_bf16, block_size=NV_BLOCK) - out = nv_qt.dequantize( - dtype=torch.float32, - scale=per_block_e4m3, - double_scale=double_scale, - block_sizes={-1: NV_BLOCK}, - ) - return out.float() - - -# ---------- Algorithm 2: keep nibbles, just rescale -------------------------- - - -def _block_sum_sq_nibbles( - mxfp4_qt: MXFP4QTensor, -) -> torch.Tensor: - """For each MXFP4 block, sum of squared E2M1 magnitudes (used for closed-form MSE). - - Returns a 1D tensor of length num_blocks, in float32. - """ - original_shape = mxfp4_qt.metadata["shape"] - packed = mxfp4_qt._quantized_data.view(*original_shape[:-1], -1, MX_BLOCK // 2) - low = (packed & 0x0F).long() - high = ((packed >> 4) & 0x0F).long() - sq = _E2M1_SQ.to(packed.device) - per_block = (sq[low] + sq[high]).sum(dim=-1) # one entry per MXFP4 block - return per_block.reshape(-1) - - -def _find_best_m( - k_flat: torch.Tensor, - sum_sq_flat: torch.Tensor, - k_min: int, - k_max: int, -) -> tuple[int, float]: - """Sweep integer m and return (best_m, best_total_squared_error). - - Per-block squared error when verbatim nibbles are kept and scale snaps to E4M3: - delta_j = k_j - m - snap_j = clamp(delta_j, [-9, 8]) - scale_diff = 2^k_j - 2^(m + snap_j) - err_j = sum_sq_j * scale_diff^2 - - In-range blocks (delta_j in [-9, 8]) contribute zero. Search range is symmetric - around the k window — outside it, every block snaps and error grows monotonically. - """ - candidates = list(range(k_min - E4M3_KMAX - 1, k_max - E4M3_KMIN + 2)) - k_f = k_flat.float() - pow2_k = torch.exp2(k_f) - best_m: int = candidates[0] - best_err: float = float("inf") - for m_cand in candidates: - delta = k_flat - m_cand - snap = torch.clamp(delta, E4M3_KMIN, E4M3_KMAX) - # snapped scale exponent: m + snap, but only differs from k when |delta| > 8/9 - snapped_scale = torch.exp2((m_cand + snap).float()) - diff = pow2_k - snapped_scale - err = (sum_sq_flat * diff * diff).sum().item() - if err < best_err: - best_err = err - best_m = m_cand - return best_m, best_err - - -def algo2_keep_nibbles( - mxfp4_qt: MXFP4QTensor, - e8m0_scale: torch.Tensor, - m_strategy: str = "midpoint", -) -> tuple[torch.Tensor, int, int, int]: - """Keep MXFP4 E2M1 nibbles verbatim and rescale. - - Choose a per-tensor m (S=2^m) and per-block E4M3 scales = 2^(k_j - m), - snapping out-of-range blocks to E4M3's boundary. - - m_strategy: - "midpoint" — when spread <=17, any valid m gives MSE=0 (we pick midpoint). - When spread >17, fall back to a heuristic: median(k) - center. - "search" — when spread <=17, behaves like "midpoint" (already optimal). - When spread >17, sweep integer m and pick the value that - minimizes total snap error in closed form. - """ - # Recover signed integer exponents k_j from E8M0 (stored as uint8 with bias 127). - k = e8m0_scale.to(torch.int32) - 127 - - # Identify blocks whose scale is irrelevant: all E2M1 nibbles have magnitude 0 - # (sign bit may be 0 or 1; MXFP4's cast_fp4 emits sign_bit=1 for value 0, giving - # "negative zero" nibbles 0x08 / 0x80, so packed bytes are 0x88). Mask 0x77. - original_shape = mxfp4_qt.metadata["shape"] - packed = mxfp4_qt._quantized_data.view(*original_shape[:-1], -1, MX_BLOCK // 2) - block_is_zero = ((packed & 0x77) == 0).all(dim=-1).reshape(-1) - k_flat = k.reshape(-1) - nonzero_mask = ~block_is_zero - nonzero_k = k_flat[nonzero_mask] if nonzero_mask.any() else k_flat - k_min = int(nonzero_k.min().item()) - k_max = int(nonzero_k.max().item()) - - spread_fits = (k_max - k_min) <= (E4M3_KMAX - E4M3_KMIN) - if spread_fits: - m = (k_max - E4M3_KMAX + k_min - E4M3_KMIN + 1) // 2 - m = max(k_max - E4M3_KMAX, min(m, k_min - E4M3_KMIN)) - elif m_strategy == "search": - sum_sq = _block_sum_sq_nibbles(mxfp4_qt) - # zero-blocks contribute 0 to S_j so they don't affect search either way; - # leave them in to keep shapes aligned. - m, _ = _find_best_m(k_flat, sum_sq, k_min, k_max) - else: - m = int(nonzero_k.median().item()) - (E4M3_KMAX + E4M3_KMIN) // 2 - - # Per-block exponent stored in the NVFP4 E4M3 scale: 2^(k_j - m), clamped to [-9, 8]. - e4m3_exp = torch.clamp(k - m, E4M3_KMIN, E4M3_KMAX) - e4m3_scale_fp32 = torch.exp2(e4m3_exp.float()) # exact powers of 2 - - # NVFP4 per-block scale lives on 16-element blocks; each MXFP4 block (32) splits - # into two NVFP4 blocks that share the same exponent. Round-trip through fp32 - # before casting to float8_e4m3fn to avoid repeat_interleave dtype quirks. - num_mx_blocks_per_row = original_shape[-1] // MX_BLOCK - e4m3_scale_nv = ( - e4m3_scale_fp32.view(*original_shape[:-1], num_mx_blocks_per_row) - .repeat_interleave(2, dim=-1) - .contiguous() - .to(torch.float8_e4m3fn) - ) - - # MXFP4 and NVFP4 use identical nibble packing (even idx low, odd idx high), so - # the bytes carry over verbatim. - nv_qt = NVFP4QTensor(original_shape, mxfp4_qt.metadata["dtype"], mxfp4_qt._quantized_data) - double_scale = torch.tensor(float(2.0**m), device=DEVICE, dtype=torch.float32) - - out = nv_qt.dequantize( - dtype=torch.float32, - scale=e4m3_scale_nv, - double_scale=double_scale, - block_sizes={-1: NV_BLOCK}, - ) - return out.float(), m, k_min, k_max - - -# ---------- Algorithm 3: hybrid (verbatim where exact, NVFP4-requant elsewhere) --- - - -def _algo3_recon_for_m( - deq_ref: torch.Tensor, - e8m0_scale: torch.Tensor, - m: int, -) -> torch.Tensor: - """Build Algo 3's fp32 reconstruction for a given m. - - For MXFP4 blocks where (k_j - m) ∈ [-9, 8]: use the exact MXFP4 dequant value - (zero error vs reference). For OOR blocks: dequant the block to fp32 (already - done — that's deq_ref), then NVFP4-quantize each 16-element half with the - fixed global scale 2^m and dequantize. The per-NVFP4-block amax can be - smaller than the full-MXFP4-block amax, so OOR blocks at the MXFP4 level - may still fit cleanly into E4M3 per-NVFP4-block scales. - """ - scale_2 = torch.tensor(float(2.0**m), device=deq_ref.device, dtype=torch.float32) - nv_qt, pb_scale, _ = NVFP4QTensor.quantize( - deq_ref.to(torch.bfloat16), - block_size=NV_BLOCK, - weights_scaling_factor_2=scale_2, - ) - nv_recon = nv_qt.dequantize( - dtype=torch.float32, - scale=pb_scale, - double_scale=scale_2, - block_sizes={-1: NV_BLOCK}, - ).view_as(deq_ref) - - k_flat = e8m0_scale.to(torch.int32).reshape(-1) - 127 - delta = k_flat - m - in_range = (delta >= E4M3_KMIN) & (delta <= E4M3_KMAX) # per MXFP4 block - - deq_blocks = deq_ref.reshape(-1, MX_BLOCK) - nv_blocks = nv_recon.reshape(-1, MX_BLOCK) - recon_blocks = torch.where(in_range.unsqueeze(-1), deq_blocks, nv_blocks) - return recon_blocks.view_as(deq_ref).float() - - -def algo3_hybrid( - mxfp4_qt: MXFP4QTensor, - e8m0_scale: torch.Tensor, -) -> tuple[torch.Tensor, int, int, int]: - """Hybrid: verbatim for in-range blocks, NVFP4-requant for out-of-range blocks. - - Closed-form rule: ``m = k_max - 8`` (top-aligned). For every block, - ``k_j - m = k_j - k_max + 8`` lands in ``[8 - spread, 8]``. When spread <=17 - that's a subset of E4M3's exact-power-of-2 window [-9, 8], so all blocks take - the verbatim path and the conversion is lossless (MSE = 0). When spread >17 - the highest-magnitude blocks sit just inside the top of the window — any - lower m would NaN them in E4M3 (per-block scale ``amax / (6·2^m)`` exceeds - 448); any higher m only shrinks in-range coverage on the low side without - helping the high side. This rule was confirmed by exhaustive search to match - the post-hoc MSE-optimal m on every scenario tested. - """ - k = e8m0_scale.to(torch.int32) - 127 - original_shape = mxfp4_qt.metadata["shape"] - packed = mxfp4_qt._quantized_data.view(*original_shape[:-1], -1, MX_BLOCK // 2) - block_is_zero = ((packed & 0x77) == 0).all(dim=-1).reshape(-1) - k_flat = k.reshape(-1) - nonzero_mask = ~block_is_zero - nonzero_k = k_flat[nonzero_mask] if nonzero_mask.any() else k_flat - k_min = int(nonzero_k.min().item()) - k_max = int(nonzero_k.max().item()) - - m = k_max - E4M3_KMAX - deq_ref = reference_from_mxfp4(mxfp4_qt, e8m0_scale) - if (k_max - k_min) <= (E4M3_KMAX - E4M3_KMIN): - return deq_ref.float(), m, k_min, k_max - recon = _algo3_recon_for_m(deq_ref, e8m0_scale, m) - return recon, m, k_min, k_max - - -# ---------- Reference and metrics -------------------------------------------- - - -def reference_from_mxfp4(mxfp4_qt: MXFP4QTensor, e8m0_scale: torch.Tensor) -> torch.Tensor: - """The true value the MXFP4 representation encodes (in fp32).""" - return mxfp4_qt.dequantize( - dtype=torch.float32, scale=e8m0_scale, block_sizes={-1: MX_BLOCK} - ).float() - - -def mse(a: torch.Tensor, b: torch.Tensor) -> float: - return float(((a.float() - b.float()) ** 2).mean().item()) - - -def max_abs_err(a: torch.Tensor, b: torch.Tensor) -> float: - return float((a.float() - b.float()).abs().max().item()) - - -def snr_db(ref: torch.Tensor, approx: torch.Tensor) -> float: - """Signal-to-noise ratio in dB. +inf when MSE=0.""" - sig = (ref.float() ** 2).mean().item() - err = ((ref.float() - approx.float()) ** 2).mean().item() - if err <= 0: - return float("inf") - if sig <= 0: - return float("-inf") - return 10.0 * math.log10(sig / err) - - -# ---------- Test scenarios --------------------------------------------------- -# Each scenario returns a tensor with last dim divisible by 32 (MXFP4 block size). -# Most are 256×1024 (8192 MXFP4 blocks) for a quick run; some test other shapes. - -R, C = 256, 1024 # default rows × cols - - -def gen_uniform() -> torch.Tensor: - return torch.empty(R, C, device=DEVICE, dtype=torch.bfloat16).uniform_(-1, 1) - - -def gen_gaussian() -> torch.Tensor: - return (torch.randn(R, C, device=DEVICE) * 1.0).bfloat16() - - -def gen_heavy_tail() -> torch.Tensor: - # x = N(0,1) * |N(0,1)| → fatter tails than gaussian - return (torch.randn(R, C, device=DEVICE) * torch.randn(R, C, device=DEVICE).abs()).bfloat16() - - -def gen_rare_outliers() -> torch.Tensor: - x = torch.randn(R, C, device=DEVICE) * 0.05 - mask = torch.rand_like(x) < 1e-3 - x[mask] = 100.0 * torch.sign(torch.randn_like(x[mask]) + 1e-6) - return x.bfloat16() - - -def gen_mixed_block_scales_25() -> torch.Tensor: - """Each row chunk gets a different magnitude — forces wide block-exponent spread.""" - x = torch.randn(R, C, device=DEVICE) * 0.3 - n_chunks = 16 - chunk = max(R // n_chunks, 1) - for i in range(n_chunks): - start, end = i * chunk, R if i == n_chunks - 1 else (i + 1) * chunk - s = 2.0 ** (-12 + (i * 25 // (n_chunks - 1))) - x[start:end] *= s - return x.bfloat16() - - -def gen_narrow_range() -> torch.Tensor: - return (torch.randn(R, C, device=DEVICE) * 0.5 + 1.0).bfloat16() - - -def gen_llm_weight() -> torch.Tensor: - # Dense linear-layer init: small std, rare outliers - x = torch.randn(R, C, device=DEVICE) * (1.0 / math.sqrt(C)) - mask = torch.rand_like(x) < 1e-4 - x[mask] *= 50.0 - return x.bfloat16() - - -def gen_zero_block() -> torch.Tensor: - x = torch.zeros(R, C, device=DEVICE) - mask = torch.rand_like(x) < 0.01 - x[mask] = torch.randn_like(x[mask]) * 0.5 - return x.bfloat16() - - -# --- Wider/tighter spread tests around the 17-exponent boundary ------------- - - -def _per_row_geom_scale(rows: int, cols: int, log2_range: int) -> torch.Tensor: - """Each row chunk gets a power-of-2 magnitude spanning [-r/2, r/2].""" - x = torch.randn(rows, cols, device=DEVICE) * 0.5 - n_chunks = 16 - chunk = max(rows // n_chunks, 1) - half = log2_range // 2 - for i in range(n_chunks): - start, end = i * chunk, rows if i == n_chunks - 1 else (i + 1) * chunk - s = 2.0 ** (-half + (i * log2_range // (n_chunks - 1))) - x[start:end] *= s - return x.bfloat16() - - -def gen_spread_15() -> torch.Tensor: - """Block exponent spread ≈ 15 — fits in E4M3 window, midpoint should be exact.""" - return _per_row_geom_scale(R, C, log2_range=15) - - -def gen_spread_17() -> torch.Tensor: - """Block exponent spread = 17 — at the in-range boundary.""" - return _per_row_geom_scale(R, C, log2_range=17) - - -def gen_spread_18() -> torch.Tensor: - """Block exponent spread = 18 — just past the boundary; midpoint loses, search wins.""" - return _per_row_geom_scale(R, C, log2_range=18) - - -def gen_spread_50() -> torch.Tensor: - return _per_row_geom_scale(R, C, log2_range=50) - - -# --- Distribution variations ------------------------------------------------ - - -def gen_bimodal() -> torch.Tensor: - """Two gaussian clusters at very different magnitudes.""" - x = torch.randn(R, C, device=DEVICE) * 0.01 - mask = torch.rand_like(x) < 0.5 - x[mask] = torch.randn_like(x[mask]) * 8.0 - return x.bfloat16() - - -def gen_power_law() -> torch.Tensor: - """Pareto(1.5)-like distribution — long-tailed.""" - u = torch.rand(R, C, device=DEVICE).clamp(min=1e-6) - x = (u ** -(1.0 / 1.5) - 1.0) * torch.sign(torch.randn_like(u)) - return (x * 0.05).bfloat16() - - -def gen_per_row_outlier() -> torch.Tensor: - """LLM-activation-style: a few rows are dominated by outlier columns.""" - x = torch.randn(R, C, device=DEVICE) * 0.01 - n_outlier_rows = 4 - outlier_rows = torch.randperm(R)[:n_outlier_rows] - n_outlier_cols = max(C // 64, 1) - outlier_cols = torch.randperm(C)[:n_outlier_cols] - for r in outlier_rows: - x[r, outlier_cols] = 30.0 * torch.sign(torch.randn(n_outlier_cols, device=DEVICE) + 1e-6) - return x.bfloat16() - - -def gen_per_col_outlier() -> torch.Tensor: - """Whole columns are systematically larger — like a single outlier feature.""" - x = torch.randn(R, C, device=DEVICE) * 0.01 - outlier_cols = torch.randperm(C)[: max(C // 128, 1)] - x[:, outlier_cols] *= 200.0 - return x.bfloat16() - - -def gen_single_extreme() -> torch.Tensor: - """One absurdly large value in an otherwise small tensor.""" - x = torch.randn(R, C, device=DEVICE) * 0.005 - x[R // 2, C // 2] = 1e4 - return x.bfloat16() - - -def gen_subnormal_heavy() -> torch.Tensor: - """Many values smaller than E2M1's smallest representable nonzero (0.5*2^k_min).""" - return (torch.randn(R, C, device=DEVICE) * 1e-8).bfloat16() - - -def gen_saturating() -> torch.Tensor: - """Values pushed to E2M1's max boundary — stresses cast_fp4 rounding.""" - x = torch.randn(R, C, device=DEVICE) - x = torch.sign(x) * torch.min(x.abs(), torch.tensor(6.0, device=DEVICE)) - return x.bfloat16() - - -def gen_mixed_signs_zero_mean() -> torch.Tensor: - """Strongly bimodal sign distribution, near-zero mean.""" - x = torch.where( - torch.rand(R, C, device=DEVICE) < 0.5, - torch.full((R, C), 3.0, device=DEVICE), - torch.full((R, C), -3.0, device=DEVICE), - ) - x += torch.randn(R, C, device=DEVICE) * 0.1 - return x.bfloat16() - - -def gen_constant() -> torch.Tensor: - """All identical values — degenerate; one block exponent, two distinct nibble values.""" - return torch.full((R, C), 1.5, device=DEVICE, dtype=torch.bfloat16) - - -# --- Layer-shaped LLM-like patterns ---------------------------------------- - - -def gen_qkv_weight() -> torch.Tensor: - """Attention QKV weight: tall, gaussian init w/ mild outliers.""" - rows, cols = 4096, 4096 - x = torch.randn(rows, cols, device=DEVICE) * (1.0 / math.sqrt(cols)) - mask = torch.rand_like(x) < 5e-5 - x[mask] *= 30.0 - return x.bfloat16() - - -def gen_mlp_gate_up() -> torch.Tensor: - """MLP gate/up projection: wide & has activation-driven scale variation.""" - rows, cols = 1024, 4096 - x = torch.randn(rows, cols, device=DEVICE) * (1.0 / math.sqrt(cols)) - # A few channels have larger weights (often seen post-fine-tuning) - hot = torch.randperm(rows)[: rows // 32] - x[hot] *= 5.0 - return x.bfloat16() - - -def gen_embedding() -> torch.Tensor: - """Embedding-style: vocab × hidden, ~N(0, 1) range with row-sparse outliers.""" - rows, cols = 2048, 1024 - x = torch.randn(rows, cols, device=DEVICE) * 0.5 - rare = torch.randperm(rows)[: rows // 256] - x[rare] *= 20.0 - return x.bfloat16() - - -def gen_layernorm_gain() -> torch.Tensor: - """LayerNorm gain vector (1D-ish, padded to 2D with cols=64 for blockability).""" - rows = 32 - x = torch.ones(rows, 1024, device=DEVICE) + torch.randn(rows, 1024, device=DEVICE) * 0.05 - return x.bfloat16() - - -# --- Other shapes ---------------------------------------------------------- - - -def gen_4d_conv() -> torch.Tensor: - """4D conv-like weight: (oc, ic, kh, kw). Last 3 dims flattened block-wise.""" - return (torch.randn(64, 64, 4, 4, device=DEVICE) * 0.1).bfloat16().reshape(64, -1) - - -def gen_large_flat() -> torch.Tensor: - """Bigger tensor to confirm scaling: 1k × 4k.""" - return (torch.randn(1024, 4096, device=DEVICE) * 0.02).bfloat16() - - -SCENARIOS = [ - # Original 8 (kept for continuity with earlier results) - ("uniform [-1,1]", gen_uniform), - ("gaussian std=1", gen_gaussian), - ("heavy-tail", gen_heavy_tail), - ("rare outliers (1e-3, mag=100)", gen_rare_outliers), - ("mixed block scales (spread 25)", gen_mixed_block_scales_25), - ("narrow range (~1.0)", gen_narrow_range), - ("typical LLM weight", gen_llm_weight), - ("mostly zeros, 1% nonzero", gen_zero_block), - # Boundary tests around the 17-exponent E4M3 window - ("spread 15 (in-range)", gen_spread_15), - ("spread 17 (boundary)", gen_spread_17), - ("spread 18 (just over)", gen_spread_18), - ("spread 50 (extreme)", gen_spread_50), - # Distribution variations - ("bimodal magnitudes", gen_bimodal), - ("Pareto(1.5) power-law", gen_power_law), - ("per-row outliers", gen_per_row_outlier), - ("per-col outliers", gen_per_col_outlier), - ("single extreme outlier", gen_single_extreme), - ("subnormal-heavy (1e-8)", gen_subnormal_heavy), - ("saturating at E2M1_max", gen_saturating), - ("strong bimodal signs", gen_mixed_signs_zero_mean), - ("constant (degenerate)", gen_constant), - # Layer-shaped LLM patterns - ("QKV weight (4096x4096)", gen_qkv_weight), - ("MLP gate/up (1024x4096)", gen_mlp_gate_up), - ("embedding (2048x1024)", gen_embedding), - ("LayerNorm gain", gen_layernorm_gain), - # Other shapes - ("conv weight 4D (64x64x4x4)", gen_4d_conv), - ("large flat (1024x4096)", gen_large_flat), -] - - -# ---------- Driver ----------------------------------------------------------- - - -def run_one(name: str, x: torch.Tensor) -> dict: - """Quantize x to MXFP4, run all algos, return metrics.""" - # Pad/skip if last dim isn't divisible by MX_BLOCK - if x.shape[-1] % MX_BLOCK != 0: - raise ValueError(f"{name}: last dim {x.shape[-1]} not divisible by {MX_BLOCK}") - - # MXFP4/NVFP4 quantizers expect a 2D-ish view (block on last dim). Keep original shape; - # the implementation views (-1, block_size) internally. - mx_qt, e8m0 = MXFP4QTensor.quantize(x.clone(), block_size=MX_BLOCK) - ref = reference_from_mxfp4(mx_qt, e8m0) - - out1 = algo1_dequant_requant(mx_qt, e8m0) - mse1 = mse(ref, out1) - - out2_mid, m_mid, k_min, k_max = algo2_keep_nibbles(mx_qt, e8m0, m_strategy="midpoint") - mse2_mid = mse(ref, out2_mid) - - out2_best, m_best, _, _ = algo2_keep_nibbles(mx_qt, e8m0, m_strategy="search") - mse2_best = mse(ref, out2_best) - - out3, m3, _, _ = algo3_hybrid(mx_qt, e8m0) - mse3 = mse(ref, out3) - - k_int = (e8m0.to(torch.int32) - 127 - m_best).flatten() - n_oor_algo2 = int(((k_int < E4M3_KMIN) | (k_int > E4M3_KMAX)).sum().item()) - k_int3 = (e8m0.to(torch.int32) - 127 - m3).flatten() - n_oor_algo3 = int(((k_int3 < E4M3_KMIN) | (k_int3 > E4M3_KMAX)).sum().item()) - - return { - "name": name, - "shape": tuple(x.shape), - "k_range": (k_min, k_max), - "spread": k_max - k_min, - "m_mid": m_mid, - "m_best": m_best, - "m_algo3": m3, - "n_blocks": int(e8m0.numel()), - "n_oor": n_oor_algo2, - "n_oor_algo3": n_oor_algo3, - "mse_algo1": mse1, - "mse_algo2_mid": mse2_mid, - "mse_algo2_best": mse2_best, - "mse_algo3": mse3, - "snr1": snr_db(ref, out1), - "snr2_best": snr_db(ref, out2_best), - "snr3": snr_db(ref, out3), - "max_err1": max_abs_err(ref, out1), - "max_err2_best": max_abs_err(ref, out2_best), - "max_err3": max_abs_err(ref, out3), - } - - -def _fmt_snr(v: float) -> str: - if v == float("inf"): - return " inf" - if v == float("-inf"): - return " -inf" - return f"{v:6.1f}" - - -def main(): - print(f"device: {DEVICE}") - if DEVICE.type == "cuda": - print(f"gpu: {torch.cuda.get_device_name(0)}") - print() - - torch.manual_seed(0) - - rows_hdr = ( - f"{'scenario':<34}" - f"{'spread':>7}" - f"{'m3':>5}" - f"{'oor':>10}" - f"{'algo1 MSE':>12}" - f"{'algo2_best':>12}" - f"{'algo3':>12}" - f"{'SNR1':>7}" - f"{'SNR3':>7}" - ) - print(rows_hdr) - print("-" * len(rows_hdr)) - - n_algo3_exact = 0 - win = {"algo1": 0, "algo2": 0, "algo3": 0, "tie": 0} - n_total = 0 - all_results = [] - for name, gen in SCENARIOS: - x = gen() - r = run_one(name, x) - all_results.append(r) - n_total += 1 - if r["mse_algo3"] == 0.0: - n_algo3_exact += 1 - - mses = {"algo1": r["mse_algo1"], "algo2": r["mse_algo2_best"], "algo3": r["mse_algo3"]} - best_v = min(mses.values()) - winners = [k for k, v in mses.items() if v == best_v] - if len(winners) > 1: - win["tie"] += 1 - else: - win[winners[0]] += 1 - - oor_str = f"{r['n_oor_algo3']:>4}/{r['n_blocks']:<5}" - print( - f"{r['name']:<34}" - f"{r['spread']:>7}" - f"{r['m_algo3']:>5}" - f"{oor_str:>10}" - f"{r['mse_algo1']:>12.2e}" - f"{r['mse_algo2_best']:>12.2e}" - f"{r['mse_algo3']:>12.2e}" - f"{_fmt_snr(r['snr1']):>7}" - f"{_fmt_snr(r['snr3']):>7}" - ) - - print() - print(f"Summary across {n_total} scenarios:") - print(f" Algo 3 outright winner: {win['algo3']}") - print(f" Algo 2 outright winner: {win['algo2']}") - print(f" Algo 1 outright winner: {win['algo1']}") - print(f" Tied (>=2 algos at same MSE): {win['tie']}") - print(f" Algo 3 exact (MSE = 0): {n_algo3_exact}/{n_total}") - print() - - # Loss tables — full precision so 0.x% gaps are visible. - def _print_losses(title: str, rows: list, lhs_key: str, rhs_key: str, ratio_label: str) -> None: - print(title) - if not rows: - print(" (none)") - return - print(f" {'scenario':<34}{lhs_key:>16}{rhs_key:>16}{ratio_label:>14}") - for r in rows: - ratio = r[lhs_key] / max(r[rhs_key], 1e-300) - print(f" {r['name']:<34}{r[lhs_key]:>16.6e}{r[rhs_key]:>16.6e}{ratio:>14.4f}") - - losses_3_vs_1 = [r for r in all_results if r["mse_algo3"] > r["mse_algo1"]] - losses_3_vs_2 = [r for r in all_results if r["mse_algo3"] > r["mse_algo2_best"]] - - _print_losses( - "Cases where Algo 3 loses to Algo 1:", - losses_3_vs_1, - "mse_algo3", - "mse_algo1", - "ratio 3/1", - ) - print() - _print_losses( - "Cases where Algo 3 loses to Algo 2:", - losses_3_vs_2, - "mse_algo3", - "mse_algo2_best", - "ratio 3/2", - ) - print() - print("Notes:") - print(" - Reference: MXFP4 dequantized tensor.") - print(" - m3: global-scale exponent picked by Algo 3 (closed-form: m = k_max - 8).") - print(" - oor: blocks where (k_j - m3) is outside [-9, 8]; these go through") - print(" NVFP4 requant. In-range blocks use the verbatim path (zero error).") - - -if __name__ == "__main__": - main() diff --git a/scratch/mxfp4_to_nvfp4_report.md b/scratch/mxfp4_to_nvfp4_report.md deleted file mode 100644 index 9cdffe93a6..0000000000 --- a/scratch/mxfp4_to_nvfp4_report.md +++ /dev/null @@ -1,175 +0,0 @@ -# MXFP4 → NVFP4 Conversion: MSE Analysis of Three Algorithms - -## Problem - -Convert an MXFP4-quantized tensor (block size 32, E2M1 mantissa, E8M0 power-of-two -scale) to an NVFP4 tensor (block size 16, E2M1 mantissa, E4M3 per-block scale, FP32 -global scale `scale_2`). Reference for error measurement is the MXFP4-dequantized -tensor — i.e. the values MXFP4 faithfully encodes — since each algorithm aims to -preserve those values in the NVFP4 representation. - -Notation: each MXFP4 block `j` has integer exponent `k_j` (so its scale is `2^k_j`). -E4M3 represents `2^k` exactly only for `k ∈ [−9, 8]` (an 18-value window with -spread 17). `k_min`, `k_max`, and *spread* `= k_max − k_min` are taken over all -non-zero blocks in the tensor. - -## Algorithms - -### Algo 1: Dequantize → Requantize (baseline) - -```text -MXFP4 → BF16 (dequantize) → NVFP4 (standard quantize) -``` - -The "obvious" approach. Always introduces error from: - -- Per-16-element re-bucketing (NVFP4 picks new amax per 16-block) -- E4M3 mantissa quantization of per-block scales - -### Algo 2: Verbatim Nibbles + Power-of-Two Global Scale - -Keep the E2M1 nibbles unchanged. Each MXFP4 block of 32 splits into two NVFP4 -blocks of 16; both inherit the same exponent `k_j`. Pick a global scale -`S = 2^m` (integer `m`) and store the per-block E4M3 scale as `2^(k_j − m)`. - -- **In-range blocks** (`k_j − m ∈ [−9, 8]`): contribution **MSE = 0** — both - `2^(k_j − m)` (E4M3) and `2^m` (FP32) are exactly representable, so the product - reproduces `2^k_j` exactly. -- **Out-of-range blocks** (only possible when spread > 17): snap the per-block - exponent to the E4M3 boundary `clamp(k_j − m, −9, 8)`. Provably MSE-optimal - *given verbatim nibbles*, but the snap can be huge if a block's true scale is - far from the snapped value. - -Two `m` strategies were tested: - -- `midpoint`: when spread ≤ 17, midpoint `m` makes everything in-range. - When spread > 17, fall back to `median(k) − center`. -- `search`: 1D integer sweep with closed-form objective - `Σ_j S_j · (2^k_j − 2^(m + clamp(k_j − m, −9, 8)))²` where - `S_j = Σ_i e2m1_value_i²` for block `j`. ≤ 50 candidates per tensor. - -### Algo 3: Hybrid — verbatim where exact, NVFP4-requant where lossy - -Combines Algo 2's exact path with per-block requantization for OOR blocks, with -a closed-form choice of `m`: - -1. Pick `m = k_max − 8`. -2. **In-range** MXFP4 blocks (`k_j − m ∈ [−9, 8]`): keep verbatim nibbles and - per-block E4M3 scale `2^(k_j − m)`. Zero error. -3. **Out-of-range** MXFP4 blocks (only possible when spread > 17): dequantize to - FP32, then NVFP4-quantize each 16-element half with the **fixed** `scale_2 = - 2^m`. The per-NVFP4-block amax can be smaller than the per-MXFP4-block amax — - one half might lack the max-magnitude nibble — letting OOR-at-MXFP4-level - blocks fit cleanly into per-NVFP4-block E4M3 scales. - -**Why `m = k_max − 8` is the right choice.** With this `m`, every block's -`k_j − m = k_j − k_max + 8` lands in `[8 − spread, 8]`. - -| Regime | `k_j − m` range | Behavior | -|---|---|---| -| spread ≤ 17 | `[8 − spread, 8] ⊆ [−9, 8]` | All blocks in-range. Verbatim path. **MSE = 0**. | -| spread > 17 | low blocks fall below `−9` | Low blocks go through NVFP4 requant; high blocks stay in-range. | - -The "always lossless when feasible" property is preserved (all blocks are in -range whenever the spread allows it). When spread > 17, this is the *only* -choice that doesn't NaN the highest-magnitude blocks — going lower forces -`amax / (6 · 2^m) > 448` (E4M3 max), going higher only shrinks the in-range -coverage on the low side without helping. An exhaustive integer-`m` search over -all 27 scenarios converged on this same rule, so the closed-form is sufficient. - -## Experimental Setup - -- 27 scenarios spanning standard distributions (uniform, gaussian, heavy-tail), - outlier patterns (rare, per-row, per-col, single-extreme), block-spread - boundary tests (15, 17, 18, 50), bimodal/power-law/saturating/subnormal/ - constant cases, and layer-shaped LLM weights (QKV 4096², MLP 1024×4096, - embedding, LayerNorm gain, conv 4D). -- Reference: MXFP4 dequantized tensor (FP32). -- Metrics: MSE, max abs error, SNR (dB). -- Hardware: NVIDIA RTX 6000 Ada Generation. - -## Results - -### Aggregate - -| Outcome (across 27 scenarios) | Count | -|---|---| -| Algo 3 outright winner | 4 | -| Algo 2 outright winner | 0 | -| Algo 1 outright winner | 1 | -| Tied (≥ 2 algos at same MSE) | 22 | -| Algo 3 exact (MSE = 0) | 22 / 27 | - -### Algo 3 dominant wins (over Algo 2) - -| Scenario | Spread | Algo 1 | Algo 2 (best) | **Algo 3** | SNR Δ (3 vs 2) | -|---|---|---|---|---|---| -| mixed block scales (~2²⁵) | 26 | 1.45e+03 | 2.96e-04 | **1.58e-05** | +12.7 dB | -| spread 17 (boundary) | 19 | 1.91e+01 | 2.27e-07 | **2.17e-07** | +0.2 dB | -| spread 18 (just over) | 20 | 1.87e+01 | 7.21e-07 | **2.85e-07** | +4.0 dB | -| spread 50 (extreme) | 52 | 7.54e+10 | 3.85e+04 | **9.45e+02** | +16.1 dB | - -### Cases where Algo 3 loses to Algo 1 - -| Scenario | Algo 1 MSE | Algo 3 MSE | Ratio (3/1) | -|---|---|---|---| -| single extreme outlier | 2.466046e-05 | 2.471241e-05 | **1.0021** | - -Gap is 0.21% — both algorithms underflow the small-magnitude blocks identically; -the residual is the integer-vs-continuous quantization of the global scale. -Algo 1 picks `scale_2 = global_amax / (6·448) ≈ 3.72`; Algo 3 is constrained to -`2^m` integer powers (here `m = 3` → `scale_2 = 8`). - -### Cases where Algo 3 loses to Algo 2 - -> None. Algo 3 ≤ Algo 2 in every scenario tested. - -### Selected exact (MSE = 0) scenarios for Algo 3 - -uniform, gaussian, heavy-tail, rare outliers, narrow range, typical LLM weight, -mostly zeros, spread 15 (in-range), bimodal, Pareto power-law, per-row outliers, -per-col outliers, subnormal-heavy (1e-8), saturating at E2M1 max, strong bimodal -signs, constant, QKV weight (4096²), MLP gate/up, embedding, LayerNorm gain, -conv weight (4D), large flat (1024×4096). - -## Why Algo 3 Works - -The asymmetry in error scaling explains everything: - -- **Snap-up errors** scale as `(2^k_j − 2^(m+8))²`, dominated by the *true* - magnitude `2^k_j` — can be enormous. -- **Snap-down errors** scale as `(2^k_j − 2^(m−9))²`, bounded by the *snapped* - magnitude `2^(m−9)`. - -Algo 3 protects the high-magnitude side by pinning `m = k_max − 8`. The -low-magnitude blocks may still snap down, but instead of suffering a fixed snap -error like Algo 2, they go through per-NVFP4-block requantization — which: - -- Adapts to the actual half-block amax (much smaller than the MXFP4 block amax - in many cases). -- Lets the E4M3 mantissa carry information beyond pure powers of 2 — for an OOR - block where `k_j − m = 9` but max nibble is `4`, the required scale is - `(4/6) · 2^9 ≈ 341`, which fits in E4M3 (max 448). -- Costs nothing for in-range blocks because they keep the verbatim path. - -## Recommendations - -1. **Default to Algo 3** for MXFP4 → NVFP4 conversion. It is exact in the - typical-weight case, strictly better than Algo 2 on spread-too-large cases, - and within 0.2% of Algo 1 even on the pathological single-outlier case. - No search loop is needed — `m = k_max − 8` is a closed-form max-reduction. -2. **The bound case** (single block dominates the entire tensor's dynamic range) - can be closed by allowing a continuous (non-power-of-2) global scale. The - integer-`m` form is purely for clean E4M3 representation of in-range - per-block scales; on OOR blocks Algo 3 already routes through E4M3 mantissa - rounding, so dropping the integer constraint there costs nothing in - exactness and recovers the last 0.2%. -3. **Cost**: Algo 3 is essentially free. Picking `m` is a max-reduction over - the block exponents; the verbatim path is a buffer copy plus a per-block - scale build; the requant path runs only on the (often empty) set of OOR - blocks, and even when present it is a single NVFP4 quantize+dequantize pass. - -## Reproducibility - -All results in this report were produced by `scratch/mxfp4_to_nvfp4_mse.py` -with `torch.manual_seed(0)` on the GPU described above. diff --git a/tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py b/tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py new file mode 100644 index 0000000000..c3d3d3f77a --- /dev/null +++ b/tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py @@ -0,0 +1,282 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for ``examples/llm_ptq/cast_mxfp4_to_nvfp4.py``. + +The module lives next to the example script (not inside the ``modelopt`` package), +so we add ``examples/llm_ptq/`` to ``sys.path`` before importing it. +""" + +import json +import sys +from pathlib import Path + +import pytest +import torch +from safetensors.torch import save_file + +_LLM_PTQ_DIR = Path(__file__).resolve().parents[3] / "examples" / "llm_ptq" +if str(_LLM_PTQ_DIR) not in sys.path: + sys.path.insert(0, str(_LLM_PTQ_DIR)) + +import cast_mxfp4_to_nvfp4 as cast + +# ---------- compute_global_amax_for_scales ---------------------------------- + + +def test_global_amax_basic_in_range(): + """Mixed in-range scales: m = k_max - 8, global_amax = 6*448*2^m, lossless = 100%.""" + # k values in [-3, 3] (spread = 6), all blocks lossless. + k = torch.tensor([0, -3, 3, 1, -1, 2], dtype=torch.int32) + e8m0 = (k + cast.E8M0_BIAS).to(torch.uint8) + + global_amax, info = cast.compute_global_amax_for_scales(e8m0) + assert info["k_min"] == -3 + assert info["k_max"] == 3 + assert info["m"] == 3 - 8 # k_max - 8 = -5 + expected = 6.0 * 448.0 * 2.0 ** info["m"] + assert global_amax == pytest.approx(expected) + assert info["n_total_blocks"] == 6 + assert info["n_lossless_blocks"] == 6 + assert info["pct_lossless"] == pytest.approx(100.0) + assert info["n_zero_blocks"] == 0 + + +def test_global_amax_with_zero_blocks(): + """Zero (e8m0=0, k=-127) blocks should be ignored when computing k_max.""" + e8m0 = torch.tensor([0, 0, 130, 125], dtype=torch.uint8) # ks: -127, -127, 3, -2 + global_amax, info = cast.compute_global_amax_for_scales(e8m0) + assert info["k_max"] == 3 # ignores zero blocks + assert info["n_zero_blocks"] == 2 + # Both nonzero blocks satisfy k_max - k_j <= 17, plus zero blocks count as + # lossless because their reconstruction is 0 regardless of scale. + assert info["n_lossless_blocks"] == 4 + + +def test_global_amax_with_oor_blocks(): + """A block 18 powers below k_max is OOR (k_max - k = 18 > 17).""" + # k values: 5, 5, -13 → spread = 18, last block is OOR. + k = torch.tensor([5, 5, -13], dtype=torch.int32) + e8m0 = (k + cast.E8M0_BIAS).to(torch.uint8) + _, info = cast.compute_global_amax_for_scales(e8m0) + assert info["k_max"] == 5 + assert info["n_total_blocks"] == 3 + assert info["n_lossless_blocks"] == 2 # the k=-13 block is OOR + + +def test_global_amax_all_zero(): + """All-zero scales should not crash; k_max defaults to 0.""" + e8m0 = torch.zeros(4, dtype=torch.uint8) + global_amax, info = cast.compute_global_amax_for_scales(e8m0) + assert info["k_min"] == 0 and info["k_max"] == 0 + assert info["n_zero_blocks"] == 4 + # All blocks count as "lossless" (their dequant is 0 regardless of scale). + assert info["n_lossless_blocks"] == 4 + + +# ---------- compute_per_block_amax_for_mxfp4 -------------------------------- + + +def _make_blocks_with_max_nibble(num_blocks: int, max_idx_per_block: list[int]) -> torch.Tensor: + """Build a (num_blocks, 16) byte tensor where block i has E2M1 magnitude + index ``max_idx_per_block[i]`` as its largest nibble; other nibbles are 0. + + Magnitude index goes in the low 3 bits of one nibble; we place it in the + high nibble of byte 0 (so the first byte = (max_idx << 4)). Every other + nibble is 0, so the block-wise max is exactly ``max_idx_per_block[i]``. + """ + assert len(max_idx_per_block) == num_blocks + blocks = torch.zeros((num_blocks, 16), dtype=torch.uint8) + for i, idx in enumerate(max_idx_per_block): + assert 0 <= idx < 8 + blocks[i, 0] = (idx & 0x07) << 4 + return blocks + + +def test_per_block_amax_in_range_returns_closed_form(): + """Every block in-range -> 6 * 2^k_j, regardless of actual nibble content.""" + # k = [0, -2, 4]; k_max = 4, k_min = -2, spread 6 (in-range). + k = torch.tensor([0, -2, 4], dtype=torch.int32) + e8m0 = (k + cast.E8M0_BIAS).to(torch.uint8) + # Blocks have varying max_nibbles, but in-range path ignores them. + blocks = _make_blocks_with_max_nibble(3, [3, 7, 1]) # max nibbles: 1.5, 6, 0.5 + + out = cast.compute_per_block_amax_for_mxfp4(blocks, e8m0) + expected_mxfp4 = 6.0 * torch.exp2(k.float()) # ignores max_nibble + expected_nvfp4 = expected_mxfp4.repeat_interleave(2, dim=-1) + assert torch.allclose(out, expected_nvfp4) + + +def test_per_block_amax_oor_uses_data_derived(): + """OOR blocks should use ``max_nibble * 2^k_j`` (data-derived).""" + # k_max=10 → m=2. OOR-low blocks have k_j - m < -9, i.e. k_j < -7. + k = torch.tensor([10, -10], dtype=torch.int32) # second is OOR-low + e8m0 = (k + cast.E8M0_BIAS).to(torch.uint8) + # Block 0 max nibble idx 7 (value 6); block 1 max nibble idx 4 (value 2). + blocks = _make_blocks_with_max_nibble(2, [7, 4]) + + out = cast.compute_per_block_amax_for_mxfp4(blocks, e8m0) + + # Block 0 (in-range): 6 * 2^10 = 6144. + # Block 1 (OOR): 2 * 2^-10 (max_nibble=2 since idx=4 -> 2.0). + expected_mxfp4 = torch.tensor([6.0 * 2**10, 2.0 * 2**-10], dtype=torch.float32) + expected_nvfp4 = expected_mxfp4.repeat_interleave(2, dim=-1) + assert torch.allclose(out, expected_nvfp4) + + +def test_per_block_amax_doubles_last_dim(): + """Two NVFP4 blocks per MXFP4 block share the same per-block amax.""" + e8m0 = torch.tensor([130, 124], dtype=torch.uint8) # ks: 3, -3 + blocks = _make_blocks_with_max_nibble(2, [7, 7]) # in-range + out = cast.compute_per_block_amax_for_mxfp4(blocks, e8m0) + assert out.shape == (4,) + # Each pair of consecutive entries should be equal. + assert out[0] == out[1] + assert out[2] == out[3] + + +def test_per_block_amax_preserves_leading_dims(): + """Leading dims (E, F, ...) flow through unchanged; only last dim doubles.""" + # shape (E=2, F=3, num_mxfp4_blocks=4) + e8m0 = torch.full((2, 3, 4), 128, dtype=torch.uint8) # all k=1, in-range + blocks = torch.zeros((2, 3, 4, 16), dtype=torch.uint8) + out = cast.compute_per_block_amax_for_mxfp4(blocks, e8m0) + assert out.shape == (2, 3, 8) + + +def test_per_block_amax_shape_mismatch_raises(): + """Mismatched leading dims should raise ``ValueError``.""" + blocks = torch.zeros((4, 16), dtype=torch.uint8) + e8m0 = torch.zeros(3, dtype=torch.uint8) # different num_blocks + with pytest.raises(ValueError, match="shape mismatch"): + cast.compute_per_block_amax_for_mxfp4(blocks, e8m0) + + +# ---------- quantizer_name_from_blocks_key ---------------------------------- + + +def test_quantizer_name_from_blocks_key(): + assert ( + cast.quantizer_name_from_blocks_key("model.layers.0.mlp.experts.gate_up_proj_blocks") + == "model.layers.0.mlp.experts.gate_up_proj_weight_quantizer" + ) + assert ( + cast.quantizer_name_from_blocks_key("model.layers.0.mlp.experts.down_proj_blocks") + == "model.layers.0.mlp.experts.down_proj_weight_quantizer" + ) + + +def test_quantizer_name_from_blocks_key_rejects_non_blocks_key(): + with pytest.raises(AssertionError): + cast.quantizer_name_from_blocks_key("model.layers.0.mlp.experts.gate_up_proj_scales") + + +# ---------- _collect_keys_with_suffix + build_amax_map (synthetic ckpt) ------ + + +def _write_synthetic_mxfp4_checkpoint( + tmp_path: Path, + layer_names: list[str], + e8m0_per_layer: dict[str, torch.Tensor], + blocks_per_layer: dict[str, torch.Tensor], +) -> Path: + """Write a tiny safetensors + index.json mimicking the OpenAI MXFP4 layout. + + Each ``layer_names[i]`` becomes ``_blocks`` + ``_scales`` keys. + Returns the checkpoint directory. + """ + ckpt_dir = tmp_path / "fake_mxfp4" + ckpt_dir.mkdir() + state = {} + for name in layer_names: + state[f"{name}_blocks"] = blocks_per_layer[name] + state[f"{name}_scales"] = e8m0_per_layer[name] + shard_name = "model-00001-of-00001.safetensors" + save_file(state, str(ckpt_dir / shard_name)) + index = { + "metadata": {"total_size": sum(t.numel() * t.element_size() for t in state.values())}, + "weight_map": dict.fromkeys(state, shard_name), + } + (ckpt_dir / "model.safetensors.index.json").write_text(json.dumps(index)) + return ckpt_dir + + +def test_collect_keys_with_suffix(tmp_path): + name = "model.layers.0.mlp.experts.gate_up_proj" + ckpt_dir = _write_synthetic_mxfp4_checkpoint( + tmp_path, + [name], + e8m0_per_layer={name: torch.zeros(4, dtype=torch.uint8)}, + blocks_per_layer={name: torch.zeros((4, 16), dtype=torch.uint8)}, + ) + scales_keys = cast._collect_keys_with_suffix(ckpt_dir, "_scales") + blocks_keys = cast._collect_keys_with_suffix(ckpt_dir, "_blocks") + assert set(scales_keys.keys()) == {f"{name}_scales"} + assert set(blocks_keys.keys()) == {f"{name}_blocks"} + + +def test_build_amax_map(tmp_path): + name1 = "model.layers.0.mlp.experts.gate_up_proj" + name2 = "model.layers.0.mlp.experts.down_proj" + e8m0 = { + name1: torch.tensor([130, 128, 125], dtype=torch.uint8), # ks: 3, 1, -2; spread 5 + name2: torch.tensor([135, 120], dtype=torch.uint8), # ks: 8, -7; spread 15 + } + blocks = { + name1: torch.zeros((3, 16), dtype=torch.uint8), + name2: torch.zeros((2, 16), dtype=torch.uint8), + } + ckpt_dir = _write_synthetic_mxfp4_checkpoint(tmp_path, [name1, name2], e8m0, blocks) + + amax_map = cast.build_amax_map(ckpt_dir) + assert set(amax_map.keys()) == {f"{n}_weight_quantizer" for n in (name1, name2)} + + e1 = amax_map[f"{name1}_weight_quantizer"] + assert e1["k_min"] == -2 and e1["k_max"] == 3 and e1["m"] == -5 + assert e1["global_amax"] == pytest.approx(6.0 * 448.0 * 2.0**-5) + assert e1["pct_lossless"] == pytest.approx(100.0) + + e2 = amax_map[f"{name2}_weight_quantizer"] + assert e2["k_min"] == -7 and e2["k_max"] == 8 and e2["m"] == 0 + assert e2["pct_lossless"] == pytest.approx(100.0) + + +def test_build_amax_map_no_scales_raises(tmp_path): + """A directory without ``*_scales`` tensors should error.""" + empty = tmp_path / "empty" + empty.mkdir() + save_file( + {"model.layers.0.weight": torch.zeros(4)}, + str(empty / "model-00001-of-00001.safetensors"), + ) + (empty / "model.safetensors.index.json").write_text( + json.dumps( + { + "metadata": {}, + "weight_map": {"model.layers.0.weight": "model-00001-of-00001.safetensors"}, + } + ) + ) + with pytest.raises(SystemExit, match="No '\\*_scales'"): + cast.build_amax_map(empty) + + +# ---------- magnitude table cache ------------------------------------------ + + +def test_e2m1_magnitude_table_cached_per_device(): + t1 = cast._e2m1_magnitude_table(torch.device("cpu")) + t2 = cast._e2m1_magnitude_table(torch.device("cpu")) + assert t1 is t2 # cached: same object + assert t1.tolist() == cast._E2M1_MAGNITUDE From 21da8963c35c66e31aedd936266c394cfc37a272 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Thu, 30 Apr 2026 16:24:26 +0000 Subject: [PATCH 4/7] Address PR #1372 review feedback - cast_mxfp4_to_nvfp4.py: validate blocks.shape[-1] == 16; extract a _shard_reader() context manager so build_amax_map and apply_to_model share a single auto-closed safetensors handle cache; skip the 16x larger *_blocks read for fully-lossless layers (per-block amax is 6 * 2^k_j read directly from *_scales). - hf_ptq.py: parse-time guard rejects --cast_mxfp4_to_nvfp4 without an NVFP4-family --qformat or with --auto_quantize_bits. - model_calib.py: move promote_nvfp4_static_quantizers() before the distributed_sync=False early return so single-process callers also get static-block NVFP4 promotion at the end of max_calibrate. - export/layer_utils.py: skip meta tensors in set_expert_quantizer_amax fallback aggregation (.amax()/.to() would fail on a meta tensor). - tests/.../test_cast_mxfp4_to_nvfp4.py: add three apply_to_model tests (happy path with a mock NVFP4StaticQuantizer subclass; missing *_blocks pair raises AssertionError; wrong quantizer type raises AssertionError). 18 tests pass. Verified: 20b PTQ + cast still 100% lossless (48/48 layers, 597M/597M blocks); TRT-LLM inference produces the same deterministic outputs as before the refactor. Signed-off-by: Chenjie Luo --- examples/llm_ptq/cast_mxfp4_to_nvfp4.py | 197 ++++++++++-------- examples/llm_ptq/hf_ptq.py | 13 ++ modelopt/torch/export/layer_utils.py | 3 + modelopt/torch/quantization/model_calib.py | 16 +- .../llm_ptq/test_cast_mxfp4_to_nvfp4.py | 115 ++++++++++ 5 files changed, 252 insertions(+), 92 deletions(-) diff --git a/examples/llm_ptq/cast_mxfp4_to_nvfp4.py b/examples/llm_ptq/cast_mxfp4_to_nvfp4.py index 7c4a4a4874..820acbdf60 100644 --- a/examples/llm_ptq/cast_mxfp4_to_nvfp4.py +++ b/examples/llm_ptq/cast_mxfp4_to_nvfp4.py @@ -34,6 +34,7 @@ """ import json +from contextlib import ExitStack, contextmanager from pathlib import Path import torch @@ -41,6 +42,26 @@ from modelopt.torch.quantization.nn.modules.tensor_quantizer import NVFP4StaticQuantizer + +@contextmanager +def _shard_reader(): + """Yield a ``read(key, shard) -> tensor`` closure with cached safetensors handles. + + Each unique shard is opened lazily on first read and closed deterministically + when the context exits, so callers don't need to manage the handle cache or + the surrounding ``ExitStack`` themselves. + """ + with ExitStack() as stack: + handles: dict[Path, safe_open] = {} + + def read(key: str, shard: Path) -> torch.Tensor: + if shard not in handles: + handles[shard] = stack.enter_context(safe_open(shard, framework="pt", device="cpu")) + return handles[shard].get_tensor(key) + + yield read + + E8M0_BIAS = 127 # E8M0 stores k_j as uint8 with bias 127 E2M1_MAX = 6.0 E4M3_MAX = 448.0 @@ -141,9 +162,10 @@ def compute_per_block_amax_for_mxfp4( Returns: float32 tensor of shape ``(..., 2 * num_mxfp4_blocks)``. """ - if blocks.shape[:-1] != e8m0_scales.shape: + if blocks.shape[-1] != 16 or blocks.shape[:-1] != e8m0_scales.shape: raise ValueError( - f"shape mismatch: blocks {tuple(blocks.shape)} (expect last dim 16) " + f"shape mismatch: blocks {tuple(blocks.shape)} " + "(expected (..., num_mxfp4_blocks, 16)) " f"vs scales {tuple(e8m0_scales.shape)}" ) @@ -255,18 +277,16 @@ def build_amax_map(checkpoint_dir: str | Path) -> dict[str, dict]: "This requires an MXFP4 HF checkpoint with the OpenAI layout." ) - handles: dict[Path, safe_open] = {} amax_map: dict[str, dict] = {} - for tensor_key, shard in sorted(scales_keys.items()): - if shard not in handles: - handles[shard] = safe_open(shard, framework="pt", device="cpu") - scales = handles[shard].get_tensor(tensor_key) + with _shard_reader() as read: + for tensor_key, shard in sorted(scales_keys.items()): + scales = read(tensor_key, shard) - global_amax, info = compute_global_amax_for_scales(scales) + global_amax, info = compute_global_amax_for_scales(scales) - blocks_key = tensor_key[: -len("_scales")] + "_blocks" - qname = quantizer_name_from_blocks_key(blocks_key) - amax_map[qname] = {"global_amax": global_amax, **info} + blocks_key = tensor_key[: -len("_scales")] + "_blocks" + qname = quantizer_name_from_blocks_key(blocks_key) + amax_map[qname] = {"global_amax": global_amax, **info} return amax_map @@ -307,7 +327,6 @@ def apply_to_model( blocks_keys = _collect_keys_with_suffix(ckpt_dir, "_blocks") - handles: dict[Path, safe_open] = {} name_to_module = dict(model.named_modules()) matched = 0 missed: list[str] = [] @@ -317,79 +336,87 @@ def apply_to_model( grand_total_blocks = 0 grand_lossless_blocks = 0 - def _read(key: str, shard: Path) -> torch.Tensor: - if shard not in handles: - handles[shard] = safe_open(shard, framework="pt", device="cpu") - return handles[shard].get_tensor(key) - - for tensor_key, shard in sorted(scales_keys.items()): - scales = _read(tensor_key, shard) - - global_amax_value, info = compute_global_amax_for_scales(scales) - n_total_layers += 1 - if info["pct_lossless"] >= 100.0: - n_lossless_layers += 1 - grand_total_blocks += info["n_total_blocks"] - grand_lossless_blocks += info["n_lossless_blocks"] - - blocks_key = tensor_key[: -len("_scales")] + "_blocks" - qname = quantizer_name_from_blocks_key(blocks_key) - blocks_shard = blocks_keys.get(blocks_key) - assert blocks_shard is not None, ( - f"{tensor_key}: no paired '{blocks_key}' tensor found in source checkpoint." - ) - - weight_quantizer = name_to_module.get(qname) - if weight_quantizer is None: - missed.append(qname) - continue - - # The cast assumes ``max_calibrate`` already promoted this quantizer - # to NVFP4StaticQuantizer (with ``_amax`` populated per-block by - # static-block max-cal and ``_global_amax`` set by the auto-promote). - # Anything else means the qformat or quant_cfg disabled this module's - # weight quantization — surface that loudly so we don't silently no-op. - assert isinstance(weight_quantizer, NVFP4StaticQuantizer), ( - f"{qname}: expected NVFP4StaticQuantizer (set by max_calibrate's " - f"auto-promote), got {type(weight_quantizer).__name__}. The cast " - "requires the matching quantizer to be enabled with static-block " - "NVFP4 (num_bits=(2,1), scale_bits=(4,3))." - ) - existing = getattr(weight_quantizer, "_amax", None) - assert isinstance(existing, torch.Tensor) and existing.numel() > 1, ( - f"{qname}: NVFP4StaticQuantizer must have a per-block ``_amax`` " - f"buffer populated by max_calibrate. Got: {existing!r}." - ) - - # Pick the device from the existing per-block ``_amax`` buffer. - device = existing.device - - global_amax = torch.tensor(float(global_amax_value), dtype=torch.float32, device=device) - blocks = _read(blocks_key, blocks_shard) - per_block_amax = compute_per_block_amax_for_mxfp4(blocks, scales).to( - dtype=torch.float32, device=device - ) - # Numel must match — calibration may store ``_amax`` flat (e.g. (N, 1)) - # while we compute it in natural (E, F, num_blocks) layout. The static - # export path reshapes via ``.view(expected_shape)``, so we just need - # element count to agree, then reshape for the in-place copy. - assert existing.numel() == per_block_amax.numel(), ( - f"{qname}: ``_amax`` element count {existing.numel()} does not " - f"match the cast-computed count {per_block_amax.numel()}. The " - "block layout from calibration disagrees with the source MXFP4 " - "scales — check that the qformat block_size is 16 and the source " - "checkpoint is the same MXFP4 model." - ) - - # global_amax via the NVFP4StaticQuantizer property setter (writes to - # the canonical ``_global_amax`` buffer). - weight_quantizer.global_amax = global_amax - # _amax: in-place buffer copy, reshaping our value to the calibrator's - # storage layout (numel verified above). - with torch.no_grad(): - existing.data.copy_(per_block_amax.view_as(existing)) - - matched += 1 + with _shard_reader() as read: + for tensor_key, shard in sorted(scales_keys.items()): + scales = read(tensor_key, shard) + + global_amax_value, info = compute_global_amax_for_scales(scales) + n_total_layers += 1 + if info["pct_lossless"] >= 100.0: + n_lossless_layers += 1 + grand_total_blocks += info["n_total_blocks"] + grand_lossless_blocks += info["n_lossless_blocks"] + + blocks_key = tensor_key[: -len("_scales")] + "_blocks" + qname = quantizer_name_from_blocks_key(blocks_key) + blocks_shard = blocks_keys.get(blocks_key) + assert blocks_shard is not None, ( + f"{tensor_key}: no paired '{blocks_key}' tensor found in source checkpoint." + ) + + weight_quantizer = name_to_module.get(qname) + if weight_quantizer is None: + missed.append(qname) + continue + + # The cast assumes ``max_calibrate`` already promoted this quantizer + # to NVFP4StaticQuantizer (with ``_amax`` populated per-block by + # static-block max-cal and ``_global_amax`` set by the auto-promote). + # Anything else means the qformat or quant_cfg disabled this module's + # weight quantization — surface that loudly so we don't silently no-op. + assert isinstance(weight_quantizer, NVFP4StaticQuantizer), ( + f"{qname}: expected NVFP4StaticQuantizer (set by max_calibrate's " + f"auto-promote), got {type(weight_quantizer).__name__}. The cast " + "requires the matching quantizer to be enabled with static-block " + "NVFP4 (num_bits=(2,1), scale_bits=(4,3))." + ) + existing = getattr(weight_quantizer, "_amax", None) + assert isinstance(existing, torch.Tensor) and existing.numel() > 1, ( + f"{qname}: NVFP4StaticQuantizer must have a per-block ``_amax`` " + f"buffer populated by max_calibrate. Got: {existing!r}." + ) + + # Pick the device from the existing per-block ``_amax`` buffer. + device = existing.device + + global_amax = torch.tensor(float(global_amax_value), dtype=torch.float32, device=device) + # Fully-lossless layers don't need the packed ``*_blocks`` tensor — + # the per-block amax is just ``6 * 2^k_j`` from ``scales`` alone, and + # avoiding the (16x larger) block read is the main I/O win the + # closed-form path is designed for. + if info["pct_lossless"] >= 100.0: + k = scales.to(torch.int32) - E8M0_BIAS + per_block_amax = ( + (E2M1_MAX * torch.exp2(k.float())) + .repeat_interleave(2, dim=-1) + .to(dtype=torch.float32, device=device) + ) + else: + blocks = read(blocks_key, blocks_shard) + per_block_amax = compute_per_block_amax_for_mxfp4(blocks, scales).to( + dtype=torch.float32, device=device + ) + # Numel must match — calibration may store ``_amax`` flat (e.g. (N, 1)) + # while we compute it in natural (E, F, num_blocks) layout. The static + # export path reshapes via ``.view(expected_shape)``, so we just need + # element count to agree, then reshape for the in-place copy. + assert existing.numel() == per_block_amax.numel(), ( + f"{qname}: ``_amax`` element count {existing.numel()} does not " + f"match the cast-computed count {per_block_amax.numel()}. The " + "block layout from calibration disagrees with the source MXFP4 " + "scales — check that the qformat block_size is 16 and the source " + "checkpoint is the same MXFP4 model." + ) + + # global_amax via the NVFP4StaticQuantizer property setter (writes to + # the canonical ``_global_amax`` buffer). + weight_quantizer.global_amax = global_amax + # _amax: in-place buffer copy, reshaping our value to the calibrator's + # storage layout (numel verified above). + with torch.no_grad(): + existing.data.copy_(per_block_amax.view_as(existing)) + + matched += 1 print( f"[cast_mxfp4_to_nvfp4] overrode {matched}/{n_total_layers} weight quantizers from {source_checkpoint_path}" diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 4184a91199..ad36eddf32 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -1451,4 +1451,17 @@ def main(args: argparse.Namespace): "--specdec_offline_dataset expects a single --calib value, not a comma-separated list." ) + if args.cast_mxfp4_to_nvfp4: + qformats = [q.strip() for q in args.qformat.split(",")] + if not all("nvfp4" in q for q in qformats): + raise ValueError( + "--cast_mxfp4_to_nvfp4 requires NVFP4-family --qformat values " + f"(got {args.qformat!r}). Use e.g. --qformat nvfp4 or nvfp4_mlp_only." + ) + if args.auto_quantize_bits is not None: + raise ValueError( + "--cast_mxfp4_to_nvfp4 is not supported with --auto_quantize_bits " + "(multi-format auto-quantize)." + ) + main(args) diff --git a/modelopt/torch/export/layer_utils.py b/modelopt/torch/export/layer_utils.py index e59bf32232..ae8bd5c7cb 100755 --- a/modelopt/torch/export/layer_utils.py +++ b/modelopt/torch/export/layer_utils.py @@ -1092,6 +1092,9 @@ def set_expert_quantizer_amax( if existing_amax is not None: # Convert to tensor and add to collection if isinstance(existing_amax, torch.Tensor): + # Meta tensors have no storage; .amax() / .to() would fail. + if existing_amax.is_meta: + continue valid_amax_values.append(existing_amax.amax().to(target_device)) else: valid_amax_values.append( diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 5f69d84ba4..4ce0f62a75 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -163,6 +163,15 @@ def max_calibrate( if hasattr(module, "layer_sync_moe_local_experts_amax"): module.layer_sync_moe_local_experts_amax(sync_weight_amax=sync_expert_weight_amax) + # Promote eligible static-block NVFP4 weight quantizers to NVFP4StaticQuantizer + # so the static blockwise fake-quant path is used in forward and the export + # picks up the two-level (per-block + global) scaling. Run before the + # ``distributed_sync`` early return so single-process callers also get the + # promotion. ``promote_nvfp4_static_quantizers`` only promotes when + # ``is_static_block_quant`` is True and the per-block ``_amax`` buffer is + # populated, so it's a no-op for dynamic-block / non-NVFP4 configs. + promote_nvfp4_static_quantizers(model) + if not distributed_sync: return @@ -280,13 +289,6 @@ def sync_quantizer_amax_across_tp( module.parallel_state.tensor_parallel_group ) - # Promote eligible static-block NVFP4 weight quantizers to NVFP4StaticQuantizer - # so the static blockwise fake-quant path is used in forward and the export - # picks up the two-level (per-block + global) scaling. ``promote_nvfp4_static_quantizers`` - # only promotes when ``is_static_block_quant`` is True and the per-block ``_amax`` - # buffer is populated, so it's a no-op for dynamic-block / non-NVFP4 configs. - promote_nvfp4_static_quantizers(model) - def _mse_quant_func(x, amax, quantizer): """Quantization function for MSE calibration.""" diff --git a/tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py b/tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py index c3d3d3f77a..b6f8c3de12 100644 --- a/tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py +++ b/tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py @@ -280,3 +280,118 @@ def test_e2m1_magnitude_table_cached_per_device(): t2 = cast._e2m1_magnitude_table(torch.device("cpu")) assert t1 is t2 # cached: same object assert t1.tolist() == cast._E2M1_MAGNITUDE + + +# ---------- apply_to_model end-to-end (mock model) --------------------------- + + +class _FakeStaticQuantizer(torch.nn.Module): + """Stand-in for NVFP4StaticQuantizer. + + Carries a per-block ``_amax`` buffer and a ``global_amax`` property whose + setter writes ``_global_amax`` — matches the contract apply_to_model relies + on. Subclasses ``cast.NVFP4StaticQuantizer`` so the isinstance check passes. + """ + + def __init__(self, num_blocks: int): + super().__init__() + self.register_buffer("_amax", torch.zeros(num_blocks, dtype=torch.float32)) + self.register_buffer("_global_amax", torch.zeros((), dtype=torch.float32)) + + @property + def global_amax(self) -> torch.Tensor: + return self._global_amax + + @global_amax.setter + def global_amax(self, value: torch.Tensor) -> None: + self._global_amax = value + + +# Inherit at runtime so isinstance(NVFP4StaticQuantizer) is True. +_FakeStaticQuantizer.__bases__ = (cast.NVFP4StaticQuantizer,) + + +class _FakeExperts(torch.nn.Module): + """Mimics a HF GptOssExperts module: a ``*_weight_quantizer`` child.""" + + def __init__(self, num_blocks: int): + super().__init__() + # Quantizer attribute name must match the source key after stripping + # ``_blocks`` and appending ``_weight_quantizer``. + self.gate_up_proj_weight_quantizer = _FakeStaticQuantizer(num_blocks) + + +class _FakeModel(torch.nn.Module): + """Single MLP-like submodule path: ``model.layers.0.mlp.experts.gate_up_proj_*``.""" + + def __init__(self, num_blocks: int): + super().__init__() + self.experts = _FakeExperts(num_blocks) + + +def test_apply_to_model_writes_global_and_per_block_amax(tmp_path): + """Happy path: cast overrides _amax + global_amax on the matching quantizer.""" + # Build a synthetic MXFP4 source: 4 in-range MXFP4 blocks => 8 NVFP4 blocks. + name = "experts.gate_up_proj" + e8m0 = torch.tensor([130, 128, 125, 132], dtype=torch.uint8) # ks: 3, 1, -2, 5 + blocks = torch.zeros((4, 16), dtype=torch.uint8) + ckpt_dir = _write_synthetic_mxfp4_checkpoint( + tmp_path, + [name], + e8m0_per_layer={name: e8m0}, + blocks_per_layer={name: blocks}, + ) + + # 8 NVFP4 blocks (each MXFP4 block of 32 splits into two NVFP4 blocks of 16). + model = _FakeModel(num_blocks=8) + cast.apply_to_model(model, ckpt_dir) + + quantizer = model.experts.gate_up_proj_weight_quantizer + # k_max = 5 -> m = -3 -> global_amax = 6 * 448 * 2^-3 = 336. + assert float(quantizer.global_amax.item()) == pytest.approx(6.0 * 448.0 * 2.0**-3) + # All in-range -> per-block _amax = 6 * 2^k_j, repeat-interleaved by 2. + expected_per_mxfp4 = 6.0 * torch.exp2(torch.tensor([3.0, 1.0, -2.0, 5.0])) + expected_per_nvfp4 = expected_per_mxfp4.repeat_interleave(2) + assert torch.allclose(quantizer._amax.float(), expected_per_nvfp4) + + +def test_apply_to_model_raises_on_missing_blocks_pair(tmp_path): + """If a *_scales tensor has no paired *_blocks tensor, raise ValueError.""" + ckpt_dir = tmp_path / "missing_blocks" + ckpt_dir.mkdir() + # Write only the _scales tensor. + save_file( + {"experts.gate_up_proj_scales": torch.zeros(2, dtype=torch.uint8)}, + str(ckpt_dir / "model-00001-of-00001.safetensors"), + ) + (ckpt_dir / "model.safetensors.index.json").write_text( + json.dumps( + { + "metadata": {}, + "weight_map": {"experts.gate_up_proj_scales": "model-00001-of-00001.safetensors"}, + } + ) + ) + model = _FakeModel(num_blocks=4) + with pytest.raises(AssertionError, match="no paired '.*_blocks' tensor"): + cast.apply_to_model(model, ckpt_dir) + + +def test_apply_to_model_raises_on_wrong_quantizer_type(tmp_path): + """If the matching attribute isn't an NVFP4StaticQuantizer, raise RuntimeError.""" + name = "experts.gate_up_proj" + e8m0 = torch.tensor([130, 128], dtype=torch.uint8) + blocks = torch.zeros((2, 16), dtype=torch.uint8) + ckpt_dir = _write_synthetic_mxfp4_checkpoint(tmp_path, [name], {name: e8m0}, {name: blocks}) + + class _NotAQuantizer(torch.nn.Module): + pass + + class _Wrong(torch.nn.Module): + def __init__(self): + super().__init__() + self.experts = torch.nn.Module() + self.experts.gate_up_proj_weight_quantizer = _NotAQuantizer() + + with pytest.raises(AssertionError, match="expected NVFP4StaticQuantizer"): + cast.apply_to_model(_Wrong(), ckpt_dir) From 0ef1b9879bdd4ed3492491f34ec0b5b6d29a3b47 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Fri, 1 May 2026 17:38:40 +0000 Subject: [PATCH 5/7] Extract cast_mxfp4_to_nvfp4 quant_cfg mutation into helper Move the inline weight-quantizer block_sizes='static' rewrite out of quantize_main() into a public force_weight_quantizers_static() helper in cast_mxfp4_to_nvfp4.py, keeping the cast-specific config logic colocated with the rest of the cast flow. Addresses review feedback on PR #1372. Signed-off-by: Chenjie Luo --- examples/llm_ptq/cast_mxfp4_to_nvfp4.py | 17 +++++++++++++++++ examples/llm_ptq/hf_ptq.py | 14 ++------------ 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/examples/llm_ptq/cast_mxfp4_to_nvfp4.py b/examples/llm_ptq/cast_mxfp4_to_nvfp4.py index 820acbdf60..26f3c9f825 100644 --- a/examples/llm_ptq/cast_mxfp4_to_nvfp4.py +++ b/examples/llm_ptq/cast_mxfp4_to_nvfp4.py @@ -291,6 +291,23 @@ def build_amax_map(checkpoint_dir: str | Path) -> dict[str, dict]: return amax_map +def force_weight_quantizers_static(quant_cfg: list) -> None: + """Force every weight-quantizer entry's ``block_sizes`` to ``type='static'``. + + The MXFP4 -> NVFP4 cast needs the per-block weight ``_amax`` to be recorded + by max-cal (so it can be paired with the pinned global_amax later). Setting + ``block_sizes['type'] = 'static'`` makes ``is_static_block_quant`` True so + ``promote_nvfp4_static_quantizers`` picks the entry up automatically at the + end of max_calibrate. + """ + for i, entry in enumerate(quant_cfg): + qname = entry.get("quantizer_name", "") + cfg = entry.get("cfg") or {} + bs = cfg.get("block_sizes") + if "weight_quantizer" in qname and isinstance(bs, dict): + quant_cfg[i] = {**entry, "cfg": {**cfg, "block_sizes": {**bs, "type": "static"}}} + + def apply_to_model( model: "torch.nn.Module", source_checkpoint_path: str | Path, diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index ad36eddf32..edf9c4d6f1 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -25,6 +25,7 @@ import torch from accelerate.hooks import remove_hook_from_module from cast_mxfp4_to_nvfp4 import apply_to_model as apply_cast_mxfp4_to_nvfp4 +from cast_mxfp4_to_nvfp4 import force_weight_quantizers_static from example_utils import ( build_quant_cfg, copy_custom_model_files, @@ -1088,20 +1089,9 @@ def quantize_main( f"Auto-resolved layerwise_checkpoint_dir: {quant_cfg['algorithm']['layerwise_checkpoint_dir']}" ) - # MXFP4 -> NVFP4 cast needs the per-block weight ``_amax`` to be recorded - # by max-cal (so it can be paired with the pinned global_amax later). - # Force every weight-quantizer entry to ``block_sizes['type'] = 'static'`` - # so ``is_static_block_quant`` is True and ``promote_nvfp4_static_quantizers`` - # picks them up automatically at the end of max_calibrate. if args.cast_mxfp4_to_nvfp4: quant_cfg = copy.deepcopy(quant_cfg) - for entry in quant_cfg.get("quant_cfg", []): - qname = entry.get("quantizer_name", "") - cfg = entry.get("cfg") or {} - bs = cfg.get("block_sizes") - if "weight_quantizer" in qname and isinstance(bs, dict): - bs = {**bs, "type": "static"} - entry["cfg"] = {**cfg, "block_sizes": bs} + force_weight_quantizers_static(quant_cfg["quant_cfg"]) if args.qformat in QUANT_CFG_CHOICES: mono_quantize( From fd171fca816dbc98a887cb7e353018951d336cb8 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Fri, 1 May 2026 17:45:30 +0000 Subject: [PATCH 6/7] Document --cast_mxfp4_to_nvfp4 in llm_ptq README and CHANGELOG Add a GPT-OSS row + footnote to the llm_ptq support matrix and a new "MXFP4 -> NVFP4 cast (for GPT-OSS)" subsection covering usage, the closed-form per-block math, and the NVFP4-qformat / no-auto-quantize constraints. Add a one-line CHANGELOG entry under 0.45. Signed-off-by: Chenjie Luo --- CHANGELOG.rst | 1 + examples/llm_ptq/README.md | 20 +++++++++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 6646359c7c..db9392a3d7 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -18,6 +18,7 @@ Changelog **New Features** - Add offline DFlash speculative decoding training. Train the draft module from pre-computed base-model hidden states dumped by ``examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py``; base-model transformer layers are deleted after conversion to save memory. Controlled by the auto-derived ``dflash_offline`` flag on ``DFlashConfig`` (derived from ``data_args.offline_data_path``). The dump scripts now share ``collect_hidden_states/common.py`` for aux-layer selection (``--aux-layers eagle|dflash|``) and optional assistant-token ``loss_mask`` for answer-only-loss training. +- Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md `_ for usage. 0.44 (2026-05-xx) ^^^^^^^^^^^^^^^^^ diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md index 1cc1acfbf9..096d3cfa6b 100755 --- a/examples/llm_ptq/README.md +++ b/examples/llm_ptq/README.md @@ -114,6 +114,7 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http | GLM-4.78 | ✅ | - | - | - | ✅ | | Kimi K2 | - | - | - | - | ✅ | | MiniMax M2.1 | - | - | - | - | ✅ | +| GPT-OSS10 | - | - | - | - | ✅ | | T5 | ✅ | ✅ | ✅ | ✅ | - | | Whisper9 | ✅ | ❌ | ❌ | ❌ | - | | Nemotron-3 | ✅ | ❌ | ❌ | ❌ | ✅ | @@ -128,7 +129,8 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http > *6.Some models currently support export to HF format only.* \ > *7.[PTQ for DeepSeek](../deepseek/README.md)* \ > *8.GLM-4.7 has MTP (Multi-Token Prediction) layers that are automatically loaded and excluded from quantization.* \ -> *9.Running Whisper model with transformers>=5.0 requires [torchcodec](https://github.com/meta-pytorch/torchcodec?tab=readme-ov-file#installing-cuda-enabled-torchcodec) and other system packages (e.g. ffmpeg).* +> *9.Running Whisper model with transformers>=5.0 requires [torchcodec](https://github.com/meta-pytorch/torchcodec?tab=readme-ov-file#installing-cuda-enabled-torchcodec) and other system packages (e.g. ffmpeg).* \ +> *10.GPT-OSS ships with native MXFP4 weights; NVFP4 export is produced via the closed-form `--cast_mxfp4_to_nvfp4` cast (see [MXFP4 → NVFP4 cast](#mxfp4--nvfp4-cast-for-gpt-oss)).* > *The accuracy loss after PTQ may vary depending on the actual model and the quantization method. Different models may have different accuracy loss and usually the accuracy loss is more significant when the base model is small. If the accuracy after PTQ is not meeting the requirement, please try either modifying [hf_ptq.py](./hf_ptq.py) and disabling the KV cache quantization or using the [QAT](./../llm_qat/README.md) instead. For NVFP4 quantization specifically, we recommend `nvfp4_mlp_only`, `nvfp4_experts_only`, or `nvfp4_omlp_only` to achieve higher accuracy by restricting quantization to the MLP/expert layers (and optionally the `o_proj` layer) while keeping the attention QKV projections unquantized.* @@ -221,6 +223,22 @@ Available KV cache formats: > *Formats ending in `_cast` (fp8_cast, nvfp4_cast) are fast — they set the amax to the format's full range without data-driven calibration. Other formats use data-driven calibration for potentially better accuracy.* +#### MXFP4 → NVFP4 cast (for GPT-OSS) + +GPT-OSS checkpoints (`openai/gpt-oss-20b`, `openai/gpt-oss-120b`) ship with native MXFP4 weights (`*_blocks` + `*_scales` in the checkpoint, `quantization_config.quant_method == "mxfp4"`). Passing `--cast_mxfp4_to_nvfp4` tells `hf_ptq.py` to read the source MXFP4 scales and produce a closed-form, bit-exact NVFP4 weight export — no GEMM-level recalibration of the weights needed. + +```bash +python hf_ptq.py \ + --pyt_ckpt_path openai/gpt-oss-20b \ + --qformat nvfp4_mlp_only \ + --cast_mxfp4_to_nvfp4 \ + --export_path +``` + +The cast pins each NVFP4 block's `scale_2 = 2^(k_max - 8)` and `_amax = 6 * 2^k_j`, both derived from the source MXFP4 E8M0 scales. For blocks whose `k_j` lands in E4M3's representable window (`k_max - k_j ≤ 17`), NVFP4 dequant matches MXFP4 dequant bit-for-bit; out-of-range blocks fall back to a data-derived per-block amax. + +> *`--cast_mxfp4_to_nvfp4` requires an NVFP4-family `--qformat` (e.g. `nvfp4_mlp_only`, `nvfp4_experts_only`, `nvfp4`) and is incompatible with `--auto_quantize_bits`.* + #### Deepseek R1 [PTQ for DeepSeek](../deepseek/README.md) shows how to quantize the DeepSeek model with FP4 and export to TensorRT-LLM. From d1b46b678e085d3d7d153f417989e949ea259aa2 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Fri, 1 May 2026 17:51:13 +0000 Subject: [PATCH 7/7] Fix CHANGELOG duplicate-target warning for llm_ptq README link Make the new --cast_mxfp4_to_nvfp4 entry's link anonymous (trailing double underscore) so it doesn't collide with the existing named target for the same README text on the multinode_ptq entry. Signed-off-by: Chenjie Luo --- CHANGELOG.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index db9392a3d7..d236988543 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -18,7 +18,7 @@ Changelog **New Features** - Add offline DFlash speculative decoding training. Train the draft module from pre-computed base-model hidden states dumped by ``examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py``; base-model transformer layers are deleted after conversion to save memory. Controlled by the auto-derived ``dflash_offline`` flag on ``DFlashConfig`` (derived from ``data_args.offline_data_path``). The dump scripts now share ``collect_hidden_states/common.py`` for aux-layer selection (``--aux-layers eagle|dflash|``) and optional assistant-token ``loss_mask`` for answer-only-loss training. -- Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md `_ for usage. +- Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md `__ for usage. 0.44 (2026-05-xx) ^^^^^^^^^^^^^^^^^