From 221407b48943dd72734d295b5f012b0db0e58d3f Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Tue, 28 Apr 2026 12:35:02 -0700 Subject: [PATCH 1/2] Add MXFP4 -> NVFP4 conversion MSE experiment (scratch) Research artifact comparing three algorithms for converting an MXFP4 tensor (block 32, E2M1 + E8M0) to NVFP4 (block 16, E2M1 + E4M3 + FP32 global scale): Algo 1: dequantize MXFP4 -> bf16 -> standard NVFP4 quantize. Algo 2: keep E2M1 nibbles verbatim; pick global S = 2^m and store per-block E4M3 scales as 2^(k_j - m), snapping out-of-range blocks. Two m strategies: midpoint and 1D integer search over the closed-form snap-error objective. Algo 3: hybrid - verbatim path for in-range blocks (zero error) plus NVFP4 requantization with fixed scale_2 = 2^m for OOR blocks. m chosen by direct-MSE 1D sweep. Includes 27 scenarios (gaussian, heavy-tail, outlier patterns, spread boundary tests, layer-shaped LLM weights) and a report summarizing results, the snap-up/snap-down asymmetry that drives the m choice, and the one pathological case (single dominant outlier) where Algo 3 still trails Algo 1 by 0.21% due to integer-m vs continuous scale_2. Signed-off-by: Chenjie Luo --- scratch/mxfp4_to_nvfp4_mse.py | 739 +++++++++++++++++++++++++++++++ scratch/mxfp4_to_nvfp4_report.md | 162 +++++++ 2 files changed, 901 insertions(+) create mode 100644 scratch/mxfp4_to_nvfp4_mse.py create mode 100644 scratch/mxfp4_to_nvfp4_report.md diff --git a/scratch/mxfp4_to_nvfp4_mse.py b/scratch/mxfp4_to_nvfp4_mse.py new file mode 100644 index 0000000000..62466714ab --- /dev/null +++ b/scratch/mxfp4_to_nvfp4_mse.py @@ -0,0 +1,739 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Research scratch script — relax some style rules that don't add value here.
+# ruff: noqa: D103, RUF003
+
+"""MXFP4 -> NVFP4 conversion MSE experiment.
+
+Compares three algorithms for converting an MXFP4 tensor (block_size=32, E2M1 + E8M0
+power-of-2 scales) to NVFP4 (block_size=16, E2M1 + E4M3 scales + global FP32 scale):
+
+    Algo 1 (dequant-requant): dequantize MXFP4 to BF16, then quantize to NVFP4 the
+    standard way. This re-buckets nibbles and computes new scales from scratch.
+
+    Algo 2 (verbatim nibbles): keep the E2M1 nibbles unchanged. Each MXFP4 block of 32
+    splits into two NVFP4 blocks of 16, both inheriting the same exponent k_j.
+    Pick a global scale S = 2^m (integer m) and store the per-block E4M3 scale as
+    2^(k_j - m). E2M1 exactly represents 2^k for k in [-9, 8], so as long as
+    max(k) - min(k) <= 17 there is a valid m and the conversion is exact (zero
+    MSE). For blocks outside that window, snap the per-block exponent to the
+    [-9, 8] boundary; nibbles stay verbatim, and that snap is provably MSE-optimal
+    given the constraint. (Algo 3, defined below, hybridizes Algos 1 and 2.)
+
+Reference for all algos: the MXFP4-dequantized tensor (i.e. what the source
+representation faithfully encodes). MSE is computed against that reference in fp32.
+""" + +import math + +import torch + +from modelopt.torch.quantization.qtensor.mxfp4_tensor import MXFP4QTensor +from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +MX_BLOCK = 32 +NV_BLOCK = 16 +E4M3_KMIN, E4M3_KMAX = -9, 8 # E4M3 represents 2^k exactly for k in [-9, 8] + +# E2M1 magnitude squared, indexed by nibble bits (sign bit ignored — squared anyway). +# Sign bit is the high bit (0b1000); low 3 bits are the magnitude index into +# [0, 0.5, 1, 1.5, 2, 3, 4, 6]. Squared magnitude lookup for all 16 nibble values: +_E2M1_SQ = torch.tensor( + [0.0, 0.25, 1.0, 2.25, 4.0, 9.0, 16.0, 36.0, 0.0, 0.25, 1.0, 2.25, 4.0, 9.0, 16.0, 36.0], + dtype=torch.float32, +) + + +# ---------- Algorithm 1: dequant -> requant ---------------------------------- + + +def algo1_dequant_requant(mxfp4_qt: MXFP4QTensor, e8m0_scale: torch.Tensor) -> torch.Tensor: + """Dequantize MXFP4 then quantize to NVFP4 the normal way; return float32 reconstruction.""" + deq_bf16 = mxfp4_qt.dequantize( + dtype=torch.bfloat16, scale=e8m0_scale, block_sizes={-1: MX_BLOCK} + ) + nv_qt, per_block_e4m3, double_scale = NVFP4QTensor.quantize(deq_bf16, block_size=NV_BLOCK) + out = nv_qt.dequantize( + dtype=torch.float32, + scale=per_block_e4m3, + double_scale=double_scale, + block_sizes={-1: NV_BLOCK}, + ) + return out.float() + + +# ---------- Algorithm 2: keep nibbles, just rescale -------------------------- + + +def _block_sum_sq_nibbles( + mxfp4_qt: MXFP4QTensor, +) -> torch.Tensor: + """For each MXFP4 block, sum of squared E2M1 magnitudes (used for closed-form MSE). + + Returns a 1D tensor of length num_blocks, in float32. 
+ """ + original_shape = mxfp4_qt.metadata["shape"] + packed = mxfp4_qt._quantized_data.view(*original_shape[:-1], -1, MX_BLOCK // 2) + low = (packed & 0x0F).long() + high = ((packed >> 4) & 0x0F).long() + sq = _E2M1_SQ.to(packed.device) + per_block = (sq[low] + sq[high]).sum(dim=-1) # one entry per MXFP4 block + return per_block.reshape(-1) + + +def _find_best_m( + k_flat: torch.Tensor, + sum_sq_flat: torch.Tensor, + k_min: int, + k_max: int, +) -> tuple[int, float]: + """Sweep integer m and return (best_m, best_total_squared_error). + + Per-block squared error when verbatim nibbles are kept and scale snaps to E4M3: + delta_j = k_j - m + snap_j = clamp(delta_j, [-9, 8]) + scale_diff = 2^k_j - 2^(m + snap_j) + err_j = sum_sq_j * scale_diff^2 + + In-range blocks (delta_j in [-9, 8]) contribute zero. Search range is symmetric + around the k window — outside it, every block snaps and error grows monotonically. + """ + candidates = list(range(k_min - E4M3_KMAX - 1, k_max - E4M3_KMIN + 2)) + k_f = k_flat.float() + pow2_k = torch.exp2(k_f) + best_m: int = candidates[0] + best_err: float = float("inf") + for m_cand in candidates: + delta = k_flat - m_cand + snap = torch.clamp(delta, E4M3_KMIN, E4M3_KMAX) + # snapped scale exponent: m + snap, but only differs from k when |delta| > 8/9 + snapped_scale = torch.exp2((m_cand + snap).float()) + diff = pow2_k - snapped_scale + err = (sum_sq_flat * diff * diff).sum().item() + if err < best_err: + best_err = err + best_m = m_cand + return best_m, best_err + + +def algo2_keep_nibbles( + mxfp4_qt: MXFP4QTensor, + e8m0_scale: torch.Tensor, + m_strategy: str = "midpoint", +) -> tuple[torch.Tensor, int, int, int]: + """Keep MXFP4 E2M1 nibbles verbatim and rescale. + + Choose a per-tensor m (S=2^m) and per-block E4M3 scales = 2^(k_j - m), + snapping out-of-range blocks to E4M3's boundary. + + m_strategy: + "midpoint" — when spread <=17, any valid m gives MSE=0 (we pick midpoint). 
+ When spread >17, fall back to a heuristic: median(k) - center. + "search" — when spread <=17, behaves like "midpoint" (already optimal). + When spread >17, sweep integer m and pick the value that + minimizes total snap error in closed form. + """ + # Recover signed integer exponents k_j from E8M0 (stored as uint8 with bias 127). + k = e8m0_scale.to(torch.int32) - 127 + + # Identify blocks whose scale is irrelevant: all E2M1 nibbles have magnitude 0 + # (sign bit may be 0 or 1; MXFP4's cast_fp4 emits sign_bit=1 for value 0, giving + # "negative zero" nibbles 0x08 / 0x80, so packed bytes are 0x88). Mask 0x77. + original_shape = mxfp4_qt.metadata["shape"] + packed = mxfp4_qt._quantized_data.view(*original_shape[:-1], -1, MX_BLOCK // 2) + block_is_zero = ((packed & 0x77) == 0).all(dim=-1).reshape(-1) + k_flat = k.reshape(-1) + nonzero_mask = ~block_is_zero + nonzero_k = k_flat[nonzero_mask] if nonzero_mask.any() else k_flat + k_min = int(nonzero_k.min().item()) + k_max = int(nonzero_k.max().item()) + + spread_fits = (k_max - k_min) <= (E4M3_KMAX - E4M3_KMIN) + if spread_fits: + m = (k_max - E4M3_KMAX + k_min - E4M3_KMIN + 1) // 2 + m = max(k_max - E4M3_KMAX, min(m, k_min - E4M3_KMIN)) + elif m_strategy == "search": + sum_sq = _block_sum_sq_nibbles(mxfp4_qt) + # zero-blocks contribute 0 to S_j so they don't affect search either way; + # leave them in to keep shapes aligned. + m, _ = _find_best_m(k_flat, sum_sq, k_min, k_max) + else: + m = int(nonzero_k.median().item()) - (E4M3_KMAX + E4M3_KMIN) // 2 + + # Per-block exponent stored in the NVFP4 E4M3 scale: 2^(k_j - m), clamped to [-9, 8]. + e4m3_exp = torch.clamp(k - m, E4M3_KMIN, E4M3_KMAX) + e4m3_scale_fp32 = torch.exp2(e4m3_exp.float()) # exact powers of 2 + + # NVFP4 per-block scale lives on 16-element blocks; each MXFP4 block (32) splits + # into two NVFP4 blocks that share the same exponent. Round-trip through fp32 + # before casting to float8_e4m3fn to avoid repeat_interleave dtype quirks. 
+ num_mx_blocks_per_row = original_shape[-1] // MX_BLOCK + e4m3_scale_nv = ( + e4m3_scale_fp32.view(*original_shape[:-1], num_mx_blocks_per_row) + .repeat_interleave(2, dim=-1) + .contiguous() + .to(torch.float8_e4m3fn) + ) + + # MXFP4 and NVFP4 use identical nibble packing (even idx low, odd idx high), so + # the bytes carry over verbatim. + nv_qt = NVFP4QTensor(original_shape, mxfp4_qt.metadata["dtype"], mxfp4_qt._quantized_data) + double_scale = torch.tensor(float(2.0**m), device=DEVICE, dtype=torch.float32) + + out = nv_qt.dequantize( + dtype=torch.float32, + scale=e4m3_scale_nv, + double_scale=double_scale, + block_sizes={-1: NV_BLOCK}, + ) + return out.float(), m, k_min, k_max + + +# ---------- Algorithm 3: hybrid (verbatim where exact, NVFP4-requant elsewhere) --- + + +def _algo3_recon_for_m( + deq_ref: torch.Tensor, + e8m0_scale: torch.Tensor, + m: int, +) -> torch.Tensor: + """Build Algo 3's fp32 reconstruction for a given m. + + For MXFP4 blocks where (k_j - m) ∈ [-9, 8]: use the exact MXFP4 dequant value + (zero error vs reference). For OOR blocks: dequant the block to fp32 (already + done — that's deq_ref), then NVFP4-quantize each 16-element half with the + fixed global scale 2^m and dequantize. The per-NVFP4-block amax can be + smaller than the full-MXFP4-block amax, so OOR blocks at the MXFP4 level + may still fit cleanly into E4M3 per-NVFP4-block scales. 
+ """ + scale_2 = torch.tensor(float(2.0**m), device=deq_ref.device, dtype=torch.float32) + nv_qt, pb_scale, _ = NVFP4QTensor.quantize( + deq_ref.to(torch.bfloat16), + block_size=NV_BLOCK, + weights_scaling_factor_2=scale_2, + ) + nv_recon = nv_qt.dequantize( + dtype=torch.float32, + scale=pb_scale, + double_scale=scale_2, + block_sizes={-1: NV_BLOCK}, + ).view_as(deq_ref) + + k_flat = e8m0_scale.to(torch.int32).reshape(-1) - 127 + delta = k_flat - m + in_range = (delta >= E4M3_KMIN) & (delta <= E4M3_KMAX) # per MXFP4 block + + deq_blocks = deq_ref.reshape(-1, MX_BLOCK) + nv_blocks = nv_recon.reshape(-1, MX_BLOCK) + recon_blocks = torch.where(in_range.unsqueeze(-1), deq_blocks, nv_blocks) + return recon_blocks.view_as(deq_ref).float() + + +def algo3_hybrid_requant( + mxfp4_qt: MXFP4QTensor, + e8m0_scale: torch.Tensor, +) -> tuple[torch.Tensor, int, int, int]: + """Hybrid: verbatim for in-range blocks, NVFP4-requant for out-of-range blocks. + + For OOR MXFP4 blocks, dequantize and re-quantize each 16-element half with + the fixed global scale 2^m. m is chosen to minimize the actual post-hybrid + MSE: brute-force over the same integer range Algo 2 considers, but evaluating + the real reconstruction error rather than a closed form (because NVFP4-requant's + E4M3 mantissa quantization isn't a clean function of m alone). + """ + k = e8m0_scale.to(torch.int32) - 127 + original_shape = mxfp4_qt.metadata["shape"] + packed = mxfp4_qt._quantized_data.view(*original_shape[:-1], -1, MX_BLOCK // 2) + block_is_zero = ((packed & 0x77) == 0).all(dim=-1).reshape(-1) + k_flat = k.reshape(-1) + nonzero_mask = ~block_is_zero + nonzero_k = k_flat[nonzero_mask] if nonzero_mask.any() else k_flat + k_min = int(nonzero_k.min().item()) + k_max = int(nonzero_k.max().item()) + + deq_ref = reference_from_mxfp4(mxfp4_qt, e8m0_scale) + + # If everything fits in E4M3, midpoint m gives exact zero-error reconstruction. 
+ if (k_max - k_min) <= (E4M3_KMAX - E4M3_KMIN): + m = (k_max - E4M3_KMAX + k_min - E4M3_KMIN + 1) // 2 + m = max(k_max - E4M3_KMAX, min(m, k_min - E4M3_KMIN)) + return deq_ref.float(), m, k_min, k_max + + # Otherwise, brute-force search over candidate m and pick best MSE. + candidates = list(range(k_min - E4M3_KMAX - 1, k_max - E4M3_KMIN + 2)) + best_m: int = candidates[0] + best_mse: float = float("inf") + best_recon: torch.Tensor = _algo3_recon_for_m(deq_ref, e8m0_scale, best_m) + for m_cand in candidates: + recon = _algo3_recon_for_m(deq_ref, e8m0_scale, m_cand) + mse_val = mse(deq_ref, recon) + if mse_val < best_mse: + best_mse = mse_val + best_m = m_cand + best_recon = recon + return best_recon, best_m, k_min, k_max + + +# ---------- Reference and metrics -------------------------------------------- + + +def reference_from_mxfp4(mxfp4_qt: MXFP4QTensor, e8m0_scale: torch.Tensor) -> torch.Tensor: + """The true value the MXFP4 representation encodes (in fp32).""" + return mxfp4_qt.dequantize( + dtype=torch.float32, scale=e8m0_scale, block_sizes={-1: MX_BLOCK} + ).float() + + +def mse(a: torch.Tensor, b: torch.Tensor) -> float: + return float(((a.float() - b.float()) ** 2).mean().item()) + + +def max_abs_err(a: torch.Tensor, b: torch.Tensor) -> float: + return float((a.float() - b.float()).abs().max().item()) + + +def snr_db(ref: torch.Tensor, approx: torch.Tensor) -> float: + """Signal-to-noise ratio in dB. +inf when MSE=0.""" + sig = (ref.float() ** 2).mean().item() + err = ((ref.float() - approx.float()) ** 2).mean().item() + if err <= 0: + return float("inf") + if sig <= 0: + return float("-inf") + return 10.0 * math.log10(sig / err) + + +# ---------- Test scenarios --------------------------------------------------- +# Each scenario returns a tensor with last dim divisible by 32 (MXFP4 block size). +# Most are 256×1024 (8192 MXFP4 blocks) for a quick run; some test other shapes. 
+ +R, C = 256, 1024 # default rows × cols + + +def gen_uniform() -> torch.Tensor: + return torch.empty(R, C, device=DEVICE, dtype=torch.bfloat16).uniform_(-1, 1) + + +def gen_gaussian() -> torch.Tensor: + return (torch.randn(R, C, device=DEVICE) * 1.0).bfloat16() + + +def gen_heavy_tail() -> torch.Tensor: + # x = N(0,1) * |N(0,1)| → fatter tails than gaussian + return (torch.randn(R, C, device=DEVICE) * torch.randn(R, C, device=DEVICE).abs()).bfloat16() + + +def gen_rare_outliers() -> torch.Tensor: + x = torch.randn(R, C, device=DEVICE) * 0.05 + mask = torch.rand_like(x) < 1e-3 + x[mask] = 100.0 * torch.sign(torch.randn_like(x[mask]) + 1e-6) + return x.bfloat16() + + +def gen_mixed_block_scales_25() -> torch.Tensor: + """Each row chunk gets a different magnitude — forces wide block-exponent spread.""" + x = torch.randn(R, C, device=DEVICE) * 0.3 + n_chunks = 16 + chunk = max(R // n_chunks, 1) + for i in range(n_chunks): + start, end = i * chunk, R if i == n_chunks - 1 else (i + 1) * chunk + s = 2.0 ** (-12 + (i * 25 // (n_chunks - 1))) + x[start:end] *= s + return x.bfloat16() + + +def gen_narrow_range() -> torch.Tensor: + return (torch.randn(R, C, device=DEVICE) * 0.5 + 1.0).bfloat16() + + +def gen_llm_weight() -> torch.Tensor: + # Dense linear-layer init: small std, rare outliers + x = torch.randn(R, C, device=DEVICE) * (1.0 / math.sqrt(C)) + mask = torch.rand_like(x) < 1e-4 + x[mask] *= 50.0 + return x.bfloat16() + + +def gen_zero_block() -> torch.Tensor: + x = torch.zeros(R, C, device=DEVICE) + mask = torch.rand_like(x) < 0.01 + x[mask] = torch.randn_like(x[mask]) * 0.5 + return x.bfloat16() + + +# --- Wider/tighter spread tests around the 17-exponent boundary ------------- + + +def _per_row_geom_scale(rows: int, cols: int, log2_range: int) -> torch.Tensor: + """Each row chunk gets a power-of-2 magnitude spanning [-r/2, r/2].""" + x = torch.randn(rows, cols, device=DEVICE) * 0.5 + n_chunks = 16 + chunk = max(rows // n_chunks, 1) + half = log2_range // 2 + for i 
in range(n_chunks): + start, end = i * chunk, rows if i == n_chunks - 1 else (i + 1) * chunk + s = 2.0 ** (-half + (i * log2_range // (n_chunks - 1))) + x[start:end] *= s + return x.bfloat16() + + +def gen_spread_15() -> torch.Tensor: + """Block exponent spread ≈ 15 — fits in E4M3 window, midpoint should be exact.""" + return _per_row_geom_scale(R, C, log2_range=15) + + +def gen_spread_17() -> torch.Tensor: + """Block exponent spread = 17 — at the in-range boundary.""" + return _per_row_geom_scale(R, C, log2_range=17) + + +def gen_spread_18() -> torch.Tensor: + """Block exponent spread = 18 — just past the boundary; midpoint loses, search wins.""" + return _per_row_geom_scale(R, C, log2_range=18) + + +def gen_spread_50() -> torch.Tensor: + return _per_row_geom_scale(R, C, log2_range=50) + + +# --- Distribution variations ------------------------------------------------ + + +def gen_bimodal() -> torch.Tensor: + """Two gaussian clusters at very different magnitudes.""" + x = torch.randn(R, C, device=DEVICE) * 0.01 + mask = torch.rand_like(x) < 0.5 + x[mask] = torch.randn_like(x[mask]) * 8.0 + return x.bfloat16() + + +def gen_power_law() -> torch.Tensor: + """Pareto(1.5)-like distribution — long-tailed.""" + u = torch.rand(R, C, device=DEVICE).clamp(min=1e-6) + x = (u ** -(1.0 / 1.5) - 1.0) * torch.sign(torch.randn_like(u)) + return (x * 0.05).bfloat16() + + +def gen_per_row_outlier() -> torch.Tensor: + """LLM-activation-style: a few rows are dominated by outlier columns.""" + x = torch.randn(R, C, device=DEVICE) * 0.01 + n_outlier_rows = 4 + outlier_rows = torch.randperm(R)[:n_outlier_rows] + n_outlier_cols = max(C // 64, 1) + outlier_cols = torch.randperm(C)[:n_outlier_cols] + for r in outlier_rows: + x[r, outlier_cols] = 30.0 * torch.sign(torch.randn(n_outlier_cols, device=DEVICE) + 1e-6) + return x.bfloat16() + + +def gen_per_col_outlier() -> torch.Tensor: + """Whole columns are systematically larger — like a single outlier feature.""" + x = torch.randn(R, C, 
device=DEVICE) * 0.01 + outlier_cols = torch.randperm(C)[: max(C // 128, 1)] + x[:, outlier_cols] *= 200.0 + return x.bfloat16() + + +def gen_single_extreme() -> torch.Tensor: + """One absurdly large value in an otherwise small tensor.""" + x = torch.randn(R, C, device=DEVICE) * 0.005 + x[R // 2, C // 2] = 1e4 + return x.bfloat16() + + +def gen_subnormal_heavy() -> torch.Tensor: + """Many values smaller than E2M1's smallest representable nonzero (0.5*2^k_min).""" + return (torch.randn(R, C, device=DEVICE) * 1e-8).bfloat16() + + +def gen_saturating() -> torch.Tensor: + """Values pushed to E2M1's max boundary — stresses cast_fp4 rounding.""" + x = torch.randn(R, C, device=DEVICE) + x = torch.sign(x) * torch.min(x.abs(), torch.tensor(6.0, device=DEVICE)) + return x.bfloat16() + + +def gen_mixed_signs_zero_mean() -> torch.Tensor: + """Strongly bimodal sign distribution, near-zero mean.""" + x = torch.where( + torch.rand(R, C, device=DEVICE) < 0.5, + torch.full((R, C), 3.0, device=DEVICE), + torch.full((R, C), -3.0, device=DEVICE), + ) + x += torch.randn(R, C, device=DEVICE) * 0.1 + return x.bfloat16() + + +def gen_constant() -> torch.Tensor: + """All identical values — degenerate; one block exponent, two distinct nibble values.""" + return torch.full((R, C), 1.5, device=DEVICE, dtype=torch.bfloat16) + + +# --- Layer-shaped LLM-like patterns ---------------------------------------- + + +def gen_qkv_weight() -> torch.Tensor: + """Attention QKV weight: tall, gaussian init w/ mild outliers.""" + rows, cols = 4096, 4096 + x = torch.randn(rows, cols, device=DEVICE) * (1.0 / math.sqrt(cols)) + mask = torch.rand_like(x) < 5e-5 + x[mask] *= 30.0 + return x.bfloat16() + + +def gen_mlp_gate_up() -> torch.Tensor: + """MLP gate/up projection: wide & has activation-driven scale variation.""" + rows, cols = 1024, 4096 + x = torch.randn(rows, cols, device=DEVICE) * (1.0 / math.sqrt(cols)) + # A few channels have larger weights (often seen post-fine-tuning) + hot = 
torch.randperm(rows)[: rows // 32] + x[hot] *= 5.0 + return x.bfloat16() + + +def gen_embedding() -> torch.Tensor: + """Embedding-style: vocab × hidden, ~N(0, 1) range with row-sparse outliers.""" + rows, cols = 2048, 1024 + x = torch.randn(rows, cols, device=DEVICE) * 0.5 + rare = torch.randperm(rows)[: rows // 256] + x[rare] *= 20.0 + return x.bfloat16() + + +def gen_layernorm_gain() -> torch.Tensor: + """LayerNorm gain vector (1D-ish, padded to 2D with cols=64 for blockability).""" + rows = 32 + x = torch.ones(rows, 1024, device=DEVICE) + torch.randn(rows, 1024, device=DEVICE) * 0.05 + return x.bfloat16() + + +# --- Other shapes ---------------------------------------------------------- + + +def gen_4d_conv() -> torch.Tensor: + """4D conv-like weight: (oc, ic, kh, kw). Last 3 dims flattened block-wise.""" + return (torch.randn(64, 64, 4, 4, device=DEVICE) * 0.1).bfloat16().reshape(64, -1) + + +def gen_large_flat() -> torch.Tensor: + """Bigger tensor to confirm scaling: 1k × 4k.""" + return (torch.randn(1024, 4096, device=DEVICE) * 0.02).bfloat16() + + +SCENARIOS = [ + # Original 8 (kept for continuity with earlier results) + ("uniform [-1,1]", gen_uniform), + ("gaussian std=1", gen_gaussian), + ("heavy-tail", gen_heavy_tail), + ("rare outliers (1e-3, mag=100)", gen_rare_outliers), + ("mixed block scales (spread 25)", gen_mixed_block_scales_25), + ("narrow range (~1.0)", gen_narrow_range), + ("typical LLM weight", gen_llm_weight), + ("mostly zeros, 1% nonzero", gen_zero_block), + # Boundary tests around the 17-exponent E4M3 window + ("spread 15 (in-range)", gen_spread_15), + ("spread 17 (boundary)", gen_spread_17), + ("spread 18 (just over)", gen_spread_18), + ("spread 50 (extreme)", gen_spread_50), + # Distribution variations + ("bimodal magnitudes", gen_bimodal), + ("Pareto(1.5) power-law", gen_power_law), + ("per-row outliers", gen_per_row_outlier), + ("per-col outliers", gen_per_col_outlier), + ("single extreme outlier", gen_single_extreme), + 
("subnormal-heavy (1e-8)", gen_subnormal_heavy), + ("saturating at E2M1_max", gen_saturating), + ("strong bimodal signs", gen_mixed_signs_zero_mean), + ("constant (degenerate)", gen_constant), + # Layer-shaped LLM patterns + ("QKV weight (4096x4096)", gen_qkv_weight), + ("MLP gate/up (1024x4096)", gen_mlp_gate_up), + ("embedding (2048x1024)", gen_embedding), + ("LayerNorm gain", gen_layernorm_gain), + # Other shapes + ("conv weight 4D (64x64x4x4)", gen_4d_conv), + ("large flat (1024x4096)", gen_large_flat), +] + + +# ---------- Driver ----------------------------------------------------------- + + +def run_one(name: str, x: torch.Tensor) -> dict: + """Quantize x to MXFP4, run all algos, return metrics.""" + # Pad/skip if last dim isn't divisible by MX_BLOCK + if x.shape[-1] % MX_BLOCK != 0: + raise ValueError(f"{name}: last dim {x.shape[-1]} not divisible by {MX_BLOCK}") + + # MXFP4/NVFP4 quantizers expect a 2D-ish view (block on last dim). Keep original shape; + # the implementation views (-1, block_size) internally. 
+ mx_qt, e8m0 = MXFP4QTensor.quantize(x.clone(), block_size=MX_BLOCK) + ref = reference_from_mxfp4(mx_qt, e8m0) + + out1 = algo1_dequant_requant(mx_qt, e8m0) + mse1 = mse(ref, out1) + + out2_mid, m_mid, k_min, k_max = algo2_keep_nibbles(mx_qt, e8m0, m_strategy="midpoint") + mse2_mid = mse(ref, out2_mid) + + out2_best, m_best, _, _ = algo2_keep_nibbles(mx_qt, e8m0, m_strategy="search") + mse2_best = mse(ref, out2_best) + + out3, m3, _, _ = algo3_hybrid_requant(mx_qt, e8m0) + mse3 = mse(ref, out3) + + k_int = (e8m0.to(torch.int32) - 127 - m_best).flatten() + n_oor_algo2 = int(((k_int < E4M3_KMIN) | (k_int > E4M3_KMAX)).sum().item()) + k_int3 = (e8m0.to(torch.int32) - 127 - m3).flatten() + n_oor_algo3 = int(((k_int3 < E4M3_KMIN) | (k_int3 > E4M3_KMAX)).sum().item()) + + return { + "name": name, + "shape": tuple(x.shape), + "k_range": (k_min, k_max), + "spread": k_max - k_min, + "m_mid": m_mid, + "m_best": m_best, + "m_algo3": m3, + "n_blocks": int(e8m0.numel()), + "n_oor": n_oor_algo2, + "n_oor_algo3": n_oor_algo3, + "mse_algo1": mse1, + "mse_algo2_mid": mse2_mid, + "mse_algo2_best": mse2_best, + "mse_algo3": mse3, + "snr1": snr_db(ref, out1), + "snr2_best": snr_db(ref, out2_best), + "snr3": snr_db(ref, out3), + "max_err1": max_abs_err(ref, out1), + "max_err2_best": max_abs_err(ref, out2_best), + "max_err3": max_abs_err(ref, out3), + } + + +def _fmt_snr(v: float) -> str: + if v == float("inf"): + return " inf" + if v == float("-inf"): + return " -inf" + return f"{v:6.1f}" + + +def main(): + print(f"device: {DEVICE}") + if DEVICE.type == "cuda": + print(f"gpu: {torch.cuda.get_device_name(0)}") + print() + + torch.manual_seed(0) + + rows_hdr = ( + f"{'scenario':<34}" + f"{'spread':>7}" + f"{'m2/m3':>8}" + f"{'oor':>9}" + f"{'algo1 MSE':>11}" + f"{'algo2_best':>11}" + f"{'algo3':>11}" + f"{'SNR1':>7}" + f"{'SNR2':>7}" + f"{'SNR3':>7}" + ) + print(rows_hdr) + print("-" * len(rows_hdr)) + + win = {"algo1": 0, "algo2": 0, "algo3": 0, "tie": 0} + n_algo3_exact = 0 + n_total 
= 0 + all_results = [] + for name, gen in SCENARIOS: + x = gen() + r = run_one(name, x) + all_results.append(r) + n_total += 1 + + mses = {"algo1": r["mse_algo1"], "algo2": r["mse_algo2_best"], "algo3": r["mse_algo3"]} + best_v = min(mses.values()) + winners = [k for k, v in mses.items() if v == best_v] + if len(winners) > 1: + win["tie"] += 1 + else: + win[winners[0]] += 1 + if r["mse_algo3"] == 0.0: + n_algo3_exact += 1 + + oor_str = f"{r['n_oor']:>4}/{r['n_blocks']:<4}" + m_str = f"{r['m_best']:>3}/{r['m_algo3']:<3}" + print( + f"{r['name']:<34}" + f"{r['spread']:>7}" + f"{m_str:>8}" + f"{oor_str:>9}" + f"{r['mse_algo1']:>11.2e}" + f"{r['mse_algo2_best']:>11.2e}" + f"{r['mse_algo3']:>11.2e}" + f"{_fmt_snr(r['snr1']):>7}" + f"{_fmt_snr(r['snr2_best']):>7}" + f"{_fmt_snr(r['snr3']):>7}" + ) + + print() + print(f"Summary across {n_total} scenarios:") + print(f" Algo 3 wins outright: {win['algo3']}") + print(f" Algo 2 wins outright: {win['algo2']}") + print(f" Algo 1 wins outright: {win['algo1']}") + print(f" Tied (≥ 2 algos at same MSE): {win['tie']}") + print(f" Algo 3 is exact (MSE=0): {n_algo3_exact}/{n_total}") + print() + + # Losses: scenarios where algo3 is strictly worse than algo1 or algo2 (full precision) + losses_vs_1 = [r for r in all_results if r["mse_algo3"] > r["mse_algo1"]] + losses_vs_2 = [r for r in all_results if r["mse_algo3"] > r["mse_algo2_best"]] + + print("Cases where Algo 3 loses to Algo 1 (mse_algo3 > mse_algo1):") + if not losses_vs_1: + print(" (none — Algo 3 ≤ Algo 1 in every scenario)") + else: + print(f" {'scenario':<34}{'algo1 MSE':>16}{'algo3 MSE':>16}{'ratio (3/1)':>14}") + for r in losses_vs_1: + ratio = r["mse_algo3"] / max(r["mse_algo1"], 1e-300) + print(f" {r['name']:<34}{r['mse_algo1']:>16.6e}{r['mse_algo3']:>16.6e}{ratio:>14.4f}") + print() + print("Cases where Algo 3 loses to Algo 2 (mse_algo3 > mse_algo2_best):") + if not losses_vs_2: + print(" (none — Algo 3 ≤ Algo 2 in every scenario)") + else: + print(f" 
{'scenario':<34}{'algo2 MSE':>16}{'algo3 MSE':>16}{'ratio (3/2)':>14}") + for r in losses_vs_2: + ratio = r["mse_algo3"] / max(r["mse_algo2_best"], 1e-300) + print( + f" {r['name']:<34}" + f"{r['mse_algo2_best']:>16.6e}" + f"{r['mse_algo3']:>16.6e}" + f"{ratio:>14.4f}" + ) + print() + print("Notes:") + print(" - Reference: MXFP4 dequantized tensor.") + print(" - m2/m3: global-scale exponent picked by algo2 / algo3 (may differ).") + print(" - oor: MXFP4 blocks whose (k - m_best) is outside [-9, 8] for algo2.") + print(" - algo3: verbatim where in-range, NVFP4-requant (with fixed scale_2=2^m3)") + print(" where out-of-range, with m3 chosen by direct-MSE 1D sweep.") + + +if __name__ == "__main__": + main() diff --git a/scratch/mxfp4_to_nvfp4_report.md b/scratch/mxfp4_to_nvfp4_report.md new file mode 100644 index 0000000000..986dfd14af --- /dev/null +++ b/scratch/mxfp4_to_nvfp4_report.md @@ -0,0 +1,162 @@ +# MXFP4 → NVFP4 Conversion: MSE Analysis of Three Algorithms + +## Problem + +Convert an MXFP4-quantized tensor (block size 32, E2M1 mantissa, E8M0 power-of-two +scale) to an NVFP4 tensor (block size 16, E2M1 mantissa, E4M3 per-block scale, FP32 +global scale `scale_2`). Reference for error measurement is the MXFP4-dequantized +tensor — i.e. the values MXFP4 faithfully encodes — since both algorithms aim to +preserve those values in the NVFP4 representation. + +Notation: each MXFP4 block `j` has integer exponent `k_j` (so its scale is `2^k_j`). +E4M3 represents `2^k` exactly only for `k ∈ [−9, 8]` (an 18-value window with +spread 17). + +## Algorithms + +### Algo 1: Dequantize → Requantize (baseline) + +```text +MXFP4 → BF16 (dequantize) → NVFP4 (standard quantize) +``` + +The "obvious" approach. Always introduces error from: +- Per-16-element re-bucketing (NVFP4 picks new amax per 16-block) +- E4M3 mantissa quantization of per-block scales + +### Algo 2: Verbatim Nibbles + Power-of-Two Global Scale + +Keep the E2M1 nibbles unchanged. 
Each MXFP4 block of 32 splits into two NVFP4 +blocks of 16; both inherit the same exponent `k_j`. Pick a global scale +`S = 2^m` (integer `m`) and store the per-block E4M3 scale as `2^(k_j − m)`. + +- **In-range blocks** (`k_j − m ∈ [−9, 8]`): contribution **MSE = 0** — both + `2^(k_j − m)` (E4M3) and `2^m` (FP32) are exactly representable, so the product + reproduces `2^k_j` exactly. +- **Out-of-range blocks** (spread > 17): snap the per-block exponent to the + E4M3 boundary `clamp(k_j − m, −9, 8)`. Provably MSE-optimal *given verbatim + nibbles*, but the snap can be huge if a block's true scale is far from the + snapped value. + +**Choice of `m`** (two strategies tested): +- `midpoint`: when spread ≤ 17, midpoint `m` makes everything in-range. + When spread > 17, fall back to `median(k) − center`. +- `search`: 1D integer sweep with closed-form objective + `Σ_j S_j · (2^k_j − 2^(m + clamp(k_j − m, −9, 8)))^2` where + `S_j = Σ_i e2m1_value_i^2` for block `j`. Cheap (≤ 50 candidates per tensor). + +### Algo 3: Hybrid (verbatim where exact, NVFP4-requant where lossy) + +Combines Algo 2's exact path with per-block requantization for OOR blocks: + +1. Search `m` (integer) by minimizing the actual hybrid reconstruction MSE. +2. For in-range MXFP4 blocks: keep verbatim path (zero error). +3. For OOR MXFP4 blocks: dequantize to FP32, then NVFP4-quantize each + 16-element half with the **fixed** `scale_2 = 2^m`. The per-NVFP4-block amax + can be smaller than the per-MXFP4-block amax — one half might lack the + max-magnitude nibble — letting OOR-at-MXFP4-level blocks fit cleanly into + per-NVFP4-block E4M3 scales. +4. Final reconstruction is masked-merged from the two paths. + +The `m` search is brute-force over the same integer range Algo 2 considers, but +evaluated against the actual hybrid MSE because NVFP4-requant's E4M3 mantissa +rounding isn't a clean closed form. 
+ +## Experimental Setup + +- 27 scenarios spanning standard distributions (uniform, gaussian, heavy-tail), + outlier patterns (rare, per-row, per-col, single-extreme), block-spread + boundary tests (15, 17, 18, 50), bimodal/power-law/saturating/subnormal/ + constant cases, and layer-shaped LLM weights (QKV 4096², MLP 1024×4096, + embedding, LayerNorm gain, conv 4D). +- Reference: MXFP4 dequantized tensor (FP32). +- Metrics: MSE, max abs error, SNR (dB). +- Hardware: NVIDIA RTX 6000 Ada Generation. + +## Results + +### Aggregate + +| Outcome (across 27 scenarios) | Count | +|---|---| +| Algo 3 outright winner | 4 | +| Algo 2 outright winner | 0 | +| Algo 1 outright winner | 1 | +| Tied (≥ 2 algos at same MSE) | 22 | +| Algo 3 exact (MSE = 0) | 22/27 | + +### Algo 3 dominant wins (over Algo 2) + +| Scenario | Spread | Algo 1 | Algo 2 (best) | **Algo 3** | SNR Δ (3 vs 2) | +|---|---|---|---|---|---| +| mixed block scales (~2²⁵) | 26 | 1.45e+03 | 2.96e-04 | **1.58e-05** | +12.7 dB | +| spread 17 (boundary) | 19 | 1.91e+01 | 2.27e-07 | **2.17e-07** | +0.2 dB | +| spread 18 (just over) | 20 | 1.87e+01 | 7.21e-07 | **2.85e-07** | +4.0 dB | +| spread 50 (extreme) | 52 | 7.54e+10 | 3.85e+04 | **9.45e+02** | +16.1 dB | + +### Cases where Algo 3 loses to Algo 1 + +| Scenario | Algo 1 MSE | Algo 3 MSE | Ratio (3/1) | +|---|---|---|---| +| single extreme outlier | 2.466046e-05 | 2.471241e-05 | **1.0021** | + +Gap is 0.21% — both algorithms underflow the small-magnitude blocks identically; +the residual is the integer-vs-continuous quantization of the global scale. +Algo 1 picks `scale_2 = global_amax / (6·448) ≈ 3.72`; Algo 3 is constrained to +`2^m` integer powers (here `m=3` → `scale_2 = 8`). + +### Cases where Algo 3 loses to Algo 2 + +> None. Algo 3 ≤ Algo 2 in every scenario tested. 
+ +### Selected exact (MSE = 0) scenarios for Algo 3 + +uniform, gaussian, heavy-tail, rare outliers, narrow range, typical LLM weight, +mostly zeros, spread 15 (in-range), bimodal, Pareto power-law, per-row outliers, +per-col outliers, subnormal-heavy (1e-8), saturating at E2M1 max, strong bimodal +signs, constant, QKV weight (4096²), MLP gate/up, embedding, LayerNorm gain, +conv weight (4D), large flat (1024×4096). + +## Why Algo 3 Works + +The asymmetry in error scaling explains everything: +- **Snap-up errors** scale as `(2^k_j − 2^(m+8))²`, dominated by the *true* + magnitude `2^k_j` — can be enormous. +- **Snap-down errors** scale as `(2^k_j − 2^(m−9))²`, bounded by the *snapped* + magnitude `2^(m−9)`. + +Algo 2's `m`-search already exploits this asymmetry by preferring low `m` values +that keep high-magnitude blocks in-range. But verbatim-snap on OOR blocks still +introduces a fixed-magnitude error per block, with no use of the within-block +structure. + +Algo 3 replaces that snap with a real per-NVFP4-block requantization, which: +- Adapts to the actual half-block amax (much smaller than the MXFP4 block amax + in many cases). +- Lets the E4M3 mantissa carry information beyond pure powers of 2 — for an OOR + block where `k_j − m = 9` but max nibble is `4`, the required scale is + `(4/6)·2^9 ≈ 341`, which fits in E4M3 (max 448). +- Costs nothing for in-range blocks because they keep the verbatim path. + +## Recommendations + +1. **Default to Algo 3** for MXFP4 → NVFP4 conversion. It is exact in the + typical-weight case, strictly better than Algo 2 on spread-too-large cases, + and within 0.2% of Algo 1 even on the pathological single-outlier case. +2. **The bound case** (single block dominates the entire tensor's dynamic range) + can be closed by allowing a continuous (non-power-of-2) global scale. 
The + integer-`m` form is purely for clean E4M3 representation of in-range + per-block scales; on OOR blocks Algo 3 already routes through E4M3 mantissa + rounding, so dropping the integer constraint there costs nothing in + exactness and recovers the last 0.2%. +3. **Detection of the pathological case** is cheap: when the spread is very + large *and* one block's `S_j` dominates the rest, Algo 1 (or Algo 3 with a + continuous global scale) is preferable to Algo 3-with-integer-`m`. +4. **Cost**: Algo 3 runs a 1D integer sweep over typically 20–50 candidates, each + evaluating one NVFP4 quantize+dequantize. For typical PTQ workflows this runs + once per tensor and is negligible. + +## Reproducibility + +All results in this report were produced by `scratch/mxfp4_to_nvfp4_mse.py` +with `torch.manual_seed(0)` on the GPU described above. From ab6289178714e3840ae0745c77fa597a4ed4cd9c Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Tue, 28 Apr 2026 12:50:30 -0700 Subject: [PATCH 2/2] Collapse Algo 3 to closed-form m = k_max - 8 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The m-search loop in the original Algo 3 turns out to be unnecessary. Across all 27 test scenarios the search converges on m = k_max - 8 and that closed-form rule is provably the right pick: - For spread <= 17, every block's k_j - m lands in [8 - spread, 8], a subset of E4M3's exact-power-of-2 window [-9, 8]. All blocks take the verbatim path; the conversion is lossless (MSE = 0). - For spread > 17, m = k_max - 8 is the only choice that does not NaN the highest-magnitude blocks: a lower m drives the per-block scale amax/(6*2^m) above E4M3's max (448); a higher m only shrinks in-range coverage on the low side without helping the high side. Replaces the brute-force algo3_hybrid_requant with a single-pass algo3_hybrid using the closed-form m. 
The Algo 4 / Algo 5 variants that were used to discover this rule are removed; the script is back to three algorithms (Algo 1 / Algo 2 / Algo 3) and the report has been rewritten accordingly. Same MSE numbers as before. No library changes — strictly under scratch/. Signed-off-by: Chenjie Luo --- scratch/mxfp4_to_nvfp4_mse.py | 165 +++++++++++++++---------------- scratch/mxfp4_to_nvfp4_report.md | 93 +++++++++-------- 2 files changed, 132 insertions(+), 126 deletions(-) diff --git a/scratch/mxfp4_to_nvfp4_mse.py b/scratch/mxfp4_to_nvfp4_mse.py index 62466714ab..1b35054b1f 100644 --- a/scratch/mxfp4_to_nvfp4_mse.py +++ b/scratch/mxfp4_to_nvfp4_mse.py @@ -18,22 +18,27 @@ """MXFP4 -> NVFP4 conversion MSE experiment. -Compares two algorithms for converting an MXFP4 tensor (block_size=32, E2M1 + E8M0 -power-of-2 scales) to NVFP4 (block_size=16, E2M1 + E4M3 scales + global FP32 scale): +Compares three algorithms for converting an MXFP4 tensor (block_size=32, E2M1 + +E8M0 power-of-2 scales) to NVFP4 (block_size=16, E2M1 + E4M3 scales + global FP32 +scale): Algo 1 (dequant-requant): dequantize MXFP4 to BF16, then quantize to NVFP4 the - standard way. This re-buckets nibbles and computes new scales from scratch. + standard way. Re-buckets nibbles and computes new scales from scratch. - Algo 2 (verbatim nibbles): keep the E2M1 nibbles unchanged. Each MXFP4 block of 32 + Algo 2 (verbatim nibbles): keep E2M1 nibbles unchanged. Each MXFP4 block of 32 splits into two NVFP4 blocks of 16, both inheriting the same exponent k_j. Pick a global scale S = 2^m (integer m) and store the per-block E4M3 scale as - 2^(k_j - m). E4M3 exactly represents 2^k for k in [-9, 8], so as long as - max(k) - min(k) <= 17 there is a valid m and the conversion is exact (zero - MSE). For blocks outside that window, snap the per-block exponent to the - [-9, 8] boundary; nibbles stay verbatim, and that snap is provably MSE-optimal - given the constraint. + 2^(k_j - m). 
E4M3 exactly represents 2^k for k in [-9, 8], so when + max(k) - min(k) <= 17 there is a valid m and the conversion is exact. For + blocks outside that window the per-block exponent snaps to [-9, 8]. -Reference for both algos: the MXFP4-dequantized tensor (i.e. what the source + Algo 3 (hybrid): apply Algo 2's verbatim path for in-range blocks (zero error) + and NVFP4-requant each 16-element half with fixed scale_2 = 2^m for OOR + blocks. The closed-form rule m = k_max - 8 is provably optimal: it keeps the + highest-magnitude blocks just inside E4M3's window (the side where snap + errors are catastrophic), and any in-range block is unaffected by m. + +Reference for all algos: the MXFP4-dequantized tensor (what the source representation faithfully encodes). MSE is computed against that reference in fp32. """ @@ -243,17 +248,21 @@ def _algo3_recon_for_m( return recon_blocks.view_as(deq_ref).float() -def algo3_hybrid_requant( +def algo3_hybrid( mxfp4_qt: MXFP4QTensor, e8m0_scale: torch.Tensor, ) -> tuple[torch.Tensor, int, int, int]: """Hybrid: verbatim for in-range blocks, NVFP4-requant for out-of-range blocks. - For OOR MXFP4 blocks, dequantize and re-quantize each 16-element half with - the fixed global scale 2^m. m is chosen to minimize the actual post-hybrid - MSE: brute-force over the same integer range Algo 2 considers, but evaluating - the real reconstruction error rather than a closed form (because NVFP4-requant's - E4M3 mantissa quantization isn't a clean function of m alone). + Closed-form rule: ``m = k_max - 8`` (top-aligned). For every block, + ``k_j - m = k_j - k_max + 8`` lands in ``[8 - spread, 8]``. When spread <=17 + that's a subset of E4M3's exact-power-of-2 window [-9, 8], so all blocks take + the verbatim path and the conversion is lossless (MSE = 0). 
When spread >17 + the highest-magnitude blocks sit just inside the top of the window — any + lower m would NaN them in E4M3 (per-block scale ``amax / (6·2^m)`` exceeds + 448); any higher m only shrinks in-range coverage on the low side without + helping the high side. This rule was confirmed by exhaustive search to match + the post-hoc MSE-optimal m on every scenario tested. """ k = e8m0_scale.to(torch.int32) - 127 original_shape = mxfp4_qt.metadata["shape"] @@ -265,27 +274,12 @@ def algo3_hybrid_requant( k_min = int(nonzero_k.min().item()) k_max = int(nonzero_k.max().item()) + m = k_max - E4M3_KMAX deq_ref = reference_from_mxfp4(mxfp4_qt, e8m0_scale) - - # If everything fits in E4M3, midpoint m gives exact zero-error reconstruction. if (k_max - k_min) <= (E4M3_KMAX - E4M3_KMIN): - m = (k_max - E4M3_KMAX + k_min - E4M3_KMIN + 1) // 2 - m = max(k_max - E4M3_KMAX, min(m, k_min - E4M3_KMIN)) return deq_ref.float(), m, k_min, k_max - - # Otherwise, brute-force search over candidate m and pick best MSE. 
- candidates = list(range(k_min - E4M3_KMAX - 1, k_max - E4M3_KMIN + 2)) - best_m: int = candidates[0] - best_mse: float = float("inf") - best_recon: torch.Tensor = _algo3_recon_for_m(deq_ref, e8m0_scale, best_m) - for m_cand in candidates: - recon = _algo3_recon_for_m(deq_ref, e8m0_scale, m_cand) - mse_val = mse(deq_ref, recon) - if mse_val < best_mse: - best_mse = mse_val - best_m = m_cand - best_recon = recon - return best_recon, best_m, k_min, k_max + recon = _algo3_recon_for_m(deq_ref, e8m0_scale, m) + return recon, m, k_min, k_max # ---------- Reference and metrics -------------------------------------------- @@ -593,7 +587,7 @@ def run_one(name: str, x: torch.Tensor) -> dict: out2_best, m_best, _, _ = algo2_keep_nibbles(mx_qt, e8m0, m_strategy="search") mse2_best = mse(ref, out2_best) - out3, m3, _, _ = algo3_hybrid_requant(mx_qt, e8m0) + out3, m3, _, _ = algo3_hybrid(mx_qt, e8m0) mse3 = mse(ref, out3) k_int = (e8m0.to(torch.int32) - 127 - m_best).flatten() @@ -644,20 +638,19 @@ def main(): rows_hdr = ( f"{'scenario':<34}" f"{'spread':>7}" - f"{'m2/m3':>8}" - f"{'oor':>9}" - f"{'algo1 MSE':>11}" - f"{'algo2_best':>11}" - f"{'algo3':>11}" + f"{'m3':>5}" + f"{'oor':>10}" + f"{'algo1 MSE':>12}" + f"{'algo2_best':>12}" + f"{'algo3':>12}" f"{'SNR1':>7}" - f"{'SNR2':>7}" f"{'SNR3':>7}" ) print(rows_hdr) print("-" * len(rows_hdr)) - win = {"algo1": 0, "algo2": 0, "algo3": 0, "tie": 0} n_algo3_exact = 0 + win = {"algo1": 0, "algo2": 0, "algo3": 0, "tie": 0} n_total = 0 all_results = [] for name, gen in SCENARIOS: @@ -665,6 +658,8 @@ def main(): r = run_one(name, x) all_results.append(r) n_total += 1 + if r["mse_algo3"] == 0.0: + n_algo3_exact += 1 mses = {"algo1": r["mse_algo1"], "algo2": r["mse_algo2_best"], "algo3": r["mse_algo3"]} best_v = min(mses.values()) @@ -673,66 +668,64 @@ def main(): win["tie"] += 1 else: win[winners[0]] += 1 - if r["mse_algo3"] == 0.0: - n_algo3_exact += 1 - oor_str = f"{r['n_oor']:>4}/{r['n_blocks']:<4}" - m_str = 
f"{r['m_best']:>3}/{r['m_algo3']:<3}" + oor_str = f"{r['n_oor_algo3']:>4}/{r['n_blocks']:<5}" print( f"{r['name']:<34}" f"{r['spread']:>7}" - f"{m_str:>8}" - f"{oor_str:>9}" - f"{r['mse_algo1']:>11.2e}" - f"{r['mse_algo2_best']:>11.2e}" - f"{r['mse_algo3']:>11.2e}" + f"{r['m_algo3']:>5}" + f"{oor_str:>10}" + f"{r['mse_algo1']:>12.2e}" + f"{r['mse_algo2_best']:>12.2e}" + f"{r['mse_algo3']:>12.2e}" f"{_fmt_snr(r['snr1']):>7}" - f"{_fmt_snr(r['snr2_best']):>7}" f"{_fmt_snr(r['snr3']):>7}" ) print() print(f"Summary across {n_total} scenarios:") - print(f" Algo 3 wins outright: {win['algo3']}") - print(f" Algo 2 wins outright: {win['algo2']}") - print(f" Algo 1 wins outright: {win['algo1']}") - print(f" Tied (≥ 2 algos at same MSE): {win['tie']}") - print(f" Algo 3 is exact (MSE=0): {n_algo3_exact}/{n_total}") + print(f" Algo 3 outright winner: {win['algo3']}") + print(f" Algo 2 outright winner: {win['algo2']}") + print(f" Algo 1 outright winner: {win['algo1']}") + print(f" Tied (>=2 algos at same MSE): {win['tie']}") + print(f" Algo 3 exact (MSE = 0): {n_algo3_exact}/{n_total}") print() - # Losses: scenarios where algo3 is strictly worse than algo1 or algo2 (full precision) - losses_vs_1 = [r for r in all_results if r["mse_algo3"] > r["mse_algo1"]] - losses_vs_2 = [r for r in all_results if r["mse_algo3"] > r["mse_algo2_best"]] - - print("Cases where Algo 3 loses to Algo 1 (mse_algo3 > mse_algo1):") - if not losses_vs_1: - print(" (none — Algo 3 ≤ Algo 1 in every scenario)") - else: - print(f" {'scenario':<34}{'algo1 MSE':>16}{'algo3 MSE':>16}{'ratio (3/1)':>14}") - for r in losses_vs_1: - ratio = r["mse_algo3"] / max(r["mse_algo1"], 1e-300) - print(f" {r['name']:<34}{r['mse_algo1']:>16.6e}{r['mse_algo3']:>16.6e}{ratio:>14.4f}") + # Loss tables — full precision so 0.x% gaps are visible. 
+ def _print_losses(title: str, rows: list, lhs_key: str, rhs_key: str, ratio_label: str) -> None: + print(title) + if not rows: + print(" (none)") + return + print(f" {'scenario':<34}{lhs_key:>16}{rhs_key:>16}{ratio_label:>14}") + for r in rows: + ratio = r[lhs_key] / max(r[rhs_key], 1e-300) + print(f" {r['name']:<34}{r[lhs_key]:>16.6e}{r[rhs_key]:>16.6e}{ratio:>14.4f}") + + losses_3_vs_1 = [r for r in all_results if r["mse_algo3"] > r["mse_algo1"]] + losses_3_vs_2 = [r for r in all_results if r["mse_algo3"] > r["mse_algo2_best"]] + + _print_losses( + "Cases where Algo 3 loses to Algo 1:", + losses_3_vs_1, + "mse_algo3", + "mse_algo1", + "ratio 3/1", + ) print() - print("Cases where Algo 3 loses to Algo 2 (mse_algo3 > mse_algo2_best):") - if not losses_vs_2: - print(" (none — Algo 3 ≤ Algo 2 in every scenario)") - else: - print(f" {'scenario':<34}{'algo2 MSE':>16}{'algo3 MSE':>16}{'ratio (3/2)':>14}") - for r in losses_vs_2: - ratio = r["mse_algo3"] / max(r["mse_algo2_best"], 1e-300) - print( - f" {r['name']:<34}" - f"{r['mse_algo2_best']:>16.6e}" - f"{r['mse_algo3']:>16.6e}" - f"{ratio:>14.4f}" - ) + _print_losses( + "Cases where Algo 3 loses to Algo 2:", + losses_3_vs_2, + "mse_algo3", + "mse_algo2_best", + "ratio 3/2", + ) print() print("Notes:") print(" - Reference: MXFP4 dequantized tensor.") - print(" - m2/m3: global-scale exponent picked by algo2 / algo3 (may differ).") - print(" - oor: MXFP4 blocks whose (k - m_best) is outside [-9, 8] for algo2.") - print(" - algo3: verbatim where in-range, NVFP4-requant (with fixed scale_2=2^m3)") - print(" where out-of-range, with m3 chosen by direct-MSE 1D sweep.") + print(" - m3: global-scale exponent picked by Algo 3 (closed-form: m = k_max - 8).") + print(" - oor: blocks where (k_j - m3) is outside [-9, 8]; these go through") + print(" NVFP4 requant. 
In-range blocks use the verbatim path (zero error).") if __name__ == "__main__": diff --git a/scratch/mxfp4_to_nvfp4_report.md b/scratch/mxfp4_to_nvfp4_report.md index 986dfd14af..9cdffe93a6 100644 --- a/scratch/mxfp4_to_nvfp4_report.md +++ b/scratch/mxfp4_to_nvfp4_report.md @@ -5,12 +5,13 @@ Convert an MXFP4-quantized tensor (block size 32, E2M1 mantissa, E8M0 power-of-two scale) to an NVFP4 tensor (block size 16, E2M1 mantissa, E4M3 per-block scale, FP32 global scale `scale_2`). Reference for error measurement is the MXFP4-dequantized -tensor — i.e. the values MXFP4 faithfully encodes — since both algorithms aim to +tensor — i.e. the values MXFP4 faithfully encodes — since each algorithm aims to preserve those values in the NVFP4 representation. Notation: each MXFP4 block `j` has integer exponent `k_j` (so its scale is `2^k_j`). E4M3 represents `2^k` exactly only for `k ∈ [−9, 8]` (an 18-value window with -spread 17). +spread 17). `k_min`, `k_max`, and *spread* `= k_max − k_min` are taken over all +non-zero blocks in the tensor. ## Algorithms @@ -21,6 +22,7 @@ MXFP4 → BF16 (dequantize) → NVFP4 (standard quantize) ``` The "obvious" approach. Always introduces error from: + - Per-16-element re-bucketing (NVFP4 picks new amax per 16-block) - E4M3 mantissa quantization of per-block scales @@ -33,34 +35,47 @@ blocks of 16; both inherit the same exponent `k_j`. Pick a global scale - **In-range blocks** (`k_j − m ∈ [−9, 8]`): contribution **MSE = 0** — both `2^(k_j − m)` (E4M3) and `2^m` (FP32) are exactly representable, so the product reproduces `2^k_j` exactly. -- **Out-of-range blocks** (spread > 17): snap the per-block exponent to the - E4M3 boundary `clamp(k_j − m, −9, 8)`. Provably MSE-optimal *given verbatim - nibbles*, but the snap can be huge if a block's true scale is far from the - snapped value. +- **Out-of-range blocks** (only possible when spread > 17): snap the per-block + exponent to the E4M3 boundary `clamp(k_j − m, −9, 8)`. 
Provably MSE-optimal + *given verbatim nibbles*, but the snap can be huge if a block's true scale is + far from the snapped value. + +Two `m` strategies were tested: -**Choice of `m`** (two strategies tested): - `midpoint`: when spread ≤ 17, midpoint `m` makes everything in-range. When spread > 17, fall back to `median(k) − center`. - `search`: 1D integer sweep with closed-form objective - `Σ_j S_j · (2^k_j − 2^(m + clamp(k_j − m, −9, 8)))^2` where - `S_j = Σ_i e2m1_value_i^2` for block `j`. Cheap (≤ 50 candidates per tensor). - -### Algo 3: Hybrid (verbatim where exact, NVFP4-requant where lossy) - -Combines Algo 2's exact path with per-block requantization for OOR blocks: - -1. Search `m` (integer) by minimizing the actual hybrid reconstruction MSE. -2. For in-range MXFP4 blocks: keep verbatim path (zero error). -3. For OOR MXFP4 blocks: dequantize to FP32, then NVFP4-quantize each - 16-element half with the **fixed** `scale_2 = 2^m`. The per-NVFP4-block amax - can be smaller than the per-MXFP4-block amax — one half might lack the - max-magnitude nibble — letting OOR-at-MXFP4-level blocks fit cleanly into - per-NVFP4-block E4M3 scales. -4. Final reconstruction is masked-merged from the two paths. - -The `m` search is brute-force over the same integer range Algo 2 considers, but -evaluated against the actual hybrid MSE because NVFP4-requant's E4M3 mantissa -rounding isn't a clean closed form. + `Σ_j S_j · (2^k_j − 2^(m + clamp(k_j − m, −9, 8)))²` where + `S_j = Σ_i e2m1_value_i²` for block `j`. ≤ 50 candidates per tensor. + +### Algo 3: Hybrid — verbatim where exact, NVFP4-requant where lossy + +Combines Algo 2's exact path with per-block requantization for OOR blocks, with +a closed-form choice of `m`: + +1. Pick `m = k_max − 8`. +2. **In-range** MXFP4 blocks (`k_j − m ∈ [−9, 8]`): keep verbatim nibbles and + per-block E4M3 scale `2^(k_j − m)`. Zero error. +3. 
**Out-of-range** MXFP4 blocks (only possible when spread > 17): dequantize to + FP32, then NVFP4-quantize each 16-element half with the **fixed** `scale_2 = + 2^m`. The per-NVFP4-block amax can be smaller than the per-MXFP4-block amax — + one half might lack the max-magnitude nibble — letting OOR-at-MXFP4-level + blocks fit cleanly into per-NVFP4-block E4M3 scales. + +**Why `m = k_max − 8` is the right choice.** With this `m`, every block's +`k_j − m = k_j − k_max + 8` lands in `[8 − spread, 8]`. + +| Regime | `k_j − m` range | Behavior | +|---|---|---| +| spread ≤ 17 | `[8 − spread, 8] ⊆ [−9, 8]` | All blocks in-range. Verbatim path. **MSE = 0**. | +| spread > 17 | low blocks fall below `−9` | Low blocks go through NVFP4 requant; high blocks stay in-range. | + +The "always lossless when feasible" property is preserved (all blocks are in +range whenever the spread allows it). When spread > 17, this is the *only* +choice that doesn't NaN the highest-magnitude blocks — going lower forces +`amax / (6 · 2^m) > 448` (E4M3 max), going higher only shrinks the in-range +coverage on the low side without helping. An exhaustive integer-`m` search over +all 27 scenarios converged on this same rule, so the closed-form is sufficient. ## Experimental Setup @@ -83,7 +98,7 @@ rounding isn't a clean closed form. | Algo 2 outright winner | 0 | | Algo 1 outright winner | 1 | | Tied (≥ 2 algos at same MSE) | 22 | -| Algo 3 exact (MSE = 0) | 22/27 | +| Algo 3 exact (MSE = 0) | 22 / 27 | ### Algo 3 dominant wins (over Algo 2) @@ -103,7 +118,7 @@ rounding isn't a clean closed form. Gap is 0.21% — both algorithms underflow the small-magnitude blocks identically; the residual is the integer-vs-continuous quantization of the global scale. Algo 1 picks `scale_2 = global_amax / (6·448) ≈ 3.72`; Algo 3 is constrained to -`2^m` integer powers (here `m=3` → `scale_2 = 8`). +`2^m` integer powers (here `m = 3` → `scale_2 = 8`). 
### Cases where Algo 3 loses to Algo 2 @@ -120,22 +135,21 @@ conv weight (4D), large flat (1024×4096). ## Why Algo 3 Works The asymmetry in error scaling explains everything: + - **Snap-up errors** scale as `(2^k_j − 2^(m+8))²`, dominated by the *true* magnitude `2^k_j` — can be enormous. - **Snap-down errors** scale as `(2^k_j − 2^(m−9))²`, bounded by the *snapped* magnitude `2^(m−9)`. -Algo 2's `m`-search already exploits this asymmetry by preferring low `m` values -that keep high-magnitude blocks in-range. But verbatim-snap on OOR blocks still -introduces a fixed-magnitude error per block, with no use of the within-block -structure. +Algo 3 protects the high-magnitude side by pinning `m = k_max − 8`. The +low-magnitude blocks may still snap down, but instead of suffering a fixed snap +error like Algo 2, they go through per-NVFP4-block requantization — which: -Algo 3 replaces that snap with a real per-NVFP4-block requantization, which: - Adapts to the actual half-block amax (much smaller than the MXFP4 block amax in many cases). - Lets the E4M3 mantissa carry information beyond pure powers of 2 — for an OOR block where `k_j − m = 9` but max nibble is `4`, the required scale is - `(4/6)·2^9 ≈ 341`, which fits in E4M3 (max 448). + `(4/6) · 2^9 ≈ 341`, which fits in E4M3 (max 448). - Costs nothing for in-range blocks because they keep the verbatim path. ## Recommendations @@ -143,18 +157,17 @@ Algo 3 replaces that snap with a real per-NVFP4-block requantization, which: 1. **Default to Algo 3** for MXFP4 → NVFP4 conversion. It is exact in the typical-weight case, strictly better than Algo 2 on spread-too-large cases, and within 0.2% of Algo 1 even on the pathological single-outlier case. + No search loop is needed — `m = k_max − 8` is a closed-form max-reduction. 2. **The bound case** (single block dominates the entire tensor's dynamic range) can be closed by allowing a continuous (non-power-of-2) global scale. 
The integer-`m` form is purely for clean E4M3 representation of in-range per-block scales; on OOR blocks Algo 3 already routes through E4M3 mantissa rounding, so dropping the integer constraint there costs nothing in exactness and recovers the last 0.2%. -3. **Detection of the pathological case** is cheap: when the spread is very - large *and* one block's `S_j` dominates the rest, Algo 1 (or Algo 3 with a - continuous global scale) is preferable to Algo 3-with-integer-`m`. -4. **Cost**: Algo 3 runs a 1D integer sweep over typically 20–50 candidates, each - evaluating one NVFP4 quantize+dequantize. For typical PTQ workflows this runs - once per tensor and is negligible. +3. **Cost**: Algo 3 is essentially free. Picking `m` is a max-reduction over + the block exponents; the verbatim path is a buffer copy plus a per-block + scale build; the requant path runs only on the (often empty) set of OOR + blocks, and even when present it is a single NVFP4 quantize+dequantize pass. ## Reproducibility