diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 6646359c7c3..d2369885431 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -18,6 +18,7 @@ Changelog **New Features** - Add offline DFlash speculative decoding training. Train the draft module from pre-computed base-model hidden states dumped by ``examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py``; base-model transformer layers are deleted after conversion to save memory. Controlled by the auto-derived ``dflash_offline`` flag on ``DFlashConfig`` (derived from ``data_args.offline_data_path``). The dump scripts now share ``collect_hidden_states/common.py`` for aux-layer selection (``--aux-layers eagle|dflash|``) and optional assistant-token ``loss_mask`` for answer-only-loss training. +- Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md `__ for usage. 0.44 (2026-05-xx) ^^^^^^^^^^^^^^^^^ diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md index 1cc1acfbf9a..096d3cfa6b6 100755 --- a/examples/llm_ptq/README.md +++ b/examples/llm_ptq/README.md @@ -114,6 +114,7 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http | GLM-4.78 | ✅ | - | - | - | ✅ | | Kimi K2 | - | - | - | - | ✅ | | MiniMax M2.1 | - | - | - | - | ✅ | +| GPT-OSS10 | - | - | - | - | ✅ | | T5 | ✅ | ✅ | ✅ | ✅ | - | | Whisper9 | ✅ | ❌ | ❌ | ❌ | - | | Nemotron-3 | ✅ | ❌ | ❌ | ❌ | ✅ | @@ -128,7 +129,8 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http > *6.Some models currently support export to HF format only.* \ > *7.[PTQ for DeepSeek](../deepseek/README.md)* \ > *8.GLM-4.7 has MTP (Multi-Token Prediction) layers that are automatically loaded and excluded from quantization.* \ -> *9.Running Whisper model with transformers>=5.0 requires [torchcodec](https://github.com/meta-pytorch/torchcodec?tab=readme-ov-file#installing-cuda-enabled-torchcodec) and other system packages (e.g. ffmpeg).* +> *9.Running Whisper model with transformers>=5.0 requires [torchcodec](https://github.com/meta-pytorch/torchcodec?tab=readme-ov-file#installing-cuda-enabled-torchcodec) and other system packages (e.g. ffmpeg).* \ +> *10.GPT-OSS ships with native MXFP4 weights; NVFP4 export is produced via the closed-form `--cast_mxfp4_to_nvfp4` cast (see [MXFP4 → NVFP4 cast](#mxfp4--nvfp4-cast-for-gpt-oss)).* > *The accuracy loss after PTQ may vary depending on the actual model and the quantization method. Different models may have different accuracy loss and usually the accuracy loss is more significant when the base model is small. If the accuracy after PTQ is not meeting the requirement, please try either modifying [hf_ptq.py](./hf_ptq.py) and disabling the KV cache quantization or using the [QAT](./../llm_qat/README.md) instead. For NVFP4 quantization specifically, we recommend `nvfp4_mlp_only`, `nvfp4_experts_only`, or `nvfp4_omlp_only` to achieve higher accuracy by restricting quantization to the MLP/expert layers (and optionally the `o_proj` layer) while keeping the attention QKV projections unquantized.* @@ -221,6 +223,22 @@ Available KV cache formats: > *Formats ending in `_cast` (fp8_cast, nvfp4_cast) are fast — they set the amax to the format's full range without data-driven calibration. 
Other formats use data-driven calibration for potentially better accuracy.* +#### MXFP4 → NVFP4 cast (for GPT-OSS) + +GPT-OSS checkpoints (`openai/gpt-oss-20b`, `openai/gpt-oss-120b`) ship with native MXFP4 weights (`*_blocks` + `*_scales` in the checkpoint, `quantization_config.quant_method == "mxfp4"`). Passing `--cast_mxfp4_to_nvfp4` tells `hf_ptq.py` to read the source MXFP4 scales and produce a closed-form, bit-exact NVFP4 weight export — no GEMM-level recalibration of the weights needed. + +```bash +python hf_ptq.py \ + --pyt_ckpt_path openai/gpt-oss-20b \ + --qformat nvfp4_mlp_only \ + --cast_mxfp4_to_nvfp4 \ + --export_path +``` + +The cast pins each NVFP4 block's `scale_2 = 2^(k_max - 8)` and `_amax = 6 * 2^k_j`, both derived from the source MXFP4 E8M0 scales. For blocks whose `k_j` lands in E4M3's representable window (`k_max - k_j ≤ 17`), NVFP4 dequant matches MXFP4 dequant bit-for-bit; out-of-range blocks fall back to a data-derived per-block amax. + +> *`--cast_mxfp4_to_nvfp4` requires an NVFP4-family `--qformat` (e.g. `nvfp4_mlp_only`, `nvfp4_experts_only`, `nvfp4`) and is incompatible with `--auto_quantize_bits`.* + #### Deepseek R1 [PTQ for DeepSeek](../deepseek/README.md) shows how to quantize the DeepSeek model with FP4 and export to TensorRT-LLM. diff --git a/examples/llm_ptq/cast_mxfp4_to_nvfp4.py b/examples/llm_ptq/cast_mxfp4_to_nvfp4.py new file mode 100644 index 00000000000..26f3c9f8258 --- /dev/null +++ b/examples/llm_ptq/cast_mxfp4_to_nvfp4.py @@ -0,0 +1,453 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Closed-form cast from MXFP4 source to NVFP4 weight quantizer state. + +Reads ``*_scales`` (E8M0 per-MXFP4-block exponents) and ``*_blocks`` (packed +E2M1 nibbles) from a Hugging Face checkpoint with +``quantization_config.quant_method == "mxfp4"`` (e.g. OpenAI's gpt-oss family) +and produces, per source layer: + +* a per-tensor ``global_amax = 6 * 448 * 2^(k_max - 8)`` that pins NVFP4's + ``scale_2`` to ``2^m`` (an exact power of 2, exactly representable in E4M3); +* a per-NVFP4-block ``_amax`` that is bit-exact (``6 * 2^k_j``) for blocks + whose ``k_j`` lands in E4M3's representable window, and falls back to the + data-derived ``max(|w_block|) = max_nibble * 2^k_j`` for out-of-range blocks + (where the per-block scale would clamp anyway). + +Together these guarantee NVFP4 dequant matches MXFP4 dequant bit-for-bit on +every in-range block, and minimizes per-block error on out-of-range ones. +Reads only the scales (~150 MB for gpt-oss-20b) plus the packed nibbles for +out-of-range blocks; runs in seconds. 
+""" + +import json +from contextlib import ExitStack, contextmanager +from pathlib import Path + +import torch +from safetensors import safe_open + +from modelopt.torch.quantization.nn.modules.tensor_quantizer import NVFP4StaticQuantizer + + +@contextmanager +def _shard_reader(): + """Yield a ``read(key, shard) -> tensor`` closure with cached safetensors handles. + + Each unique shard is opened lazily on first read and closed deterministically + when the context exits, so callers don't need to manage the handle cache or + the surrounding ``ExitStack`` themselves. + """ + with ExitStack() as stack: + handles: dict[Path, safe_open] = {} + + def read(key: str, shard: Path) -> torch.Tensor: + if shard not in handles: + handles[shard] = stack.enter_context(safe_open(shard, framework="pt", device="cpu")) + return handles[shard].get_tensor(key) + + yield read + + +E8M0_BIAS = 127 # E8M0 stores k_j as uint8 with bias 127 +E2M1_MAX = 6.0 +E4M3_MAX = 448.0 +E4M3_KMAX = 8 +E4M3_KMIN = -9 # E4M3 represents 2^k exactly for k in [-9, 8] +# E2M1 magnitude grid indexed by the low 3 bits of an FP4 nibble. +_E2M1_MAGNITUDE = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] +# Cache of the E2M1 magnitude lookup table per (device, dtype) so we don't +# rebuild it for every layer in a batched cast. +_E2M1_MAG_CACHE: "dict[tuple, torch.Tensor]" = {} + + +def _e2m1_magnitude_table(device: torch.device, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Return ``_E2M1_MAGNITUDE`` as a tensor on the requested device, cached.""" + key = (device, dtype) + cached = _E2M1_MAG_CACHE.get(key) + if cached is None: + cached = torch.tensor(_E2M1_MAGNITUDE, dtype=dtype, device=device) + _E2M1_MAG_CACHE[key] = cached + return cached + + +def compute_global_amax_for_scales(e8m0_scales: torch.Tensor) -> tuple[float, dict]: + """Closed-form per-tensor ``global_amax``: ``m = k_max - 8``, ``global_amax = 6 * 448 * 2^m``. + + Args: + e8m0_scales: uint8 tensor of E8M0 scales for one MXFP4 source layer. + + Returns: + global_amax: scalar (float) — pins NVFP4 scale_2 to 2^m. + info: diagnostic dict with k_min, k_max, m, lossless-block stats. + """ + # k_j = e8m0 - 127. MXFP4 quantize emits e8m0=0 (=> k=-127) for all-zero + # blocks; treat those as "ignore me" when computing k_max. + k = e8m0_scales.to(torch.int32) - E8M0_BIAS + nonzero_mask = e8m0_scales > 0 + if nonzero_mask.any(): + k_nonzero = k[nonzero_mask] + k_min = int(k_nonzero.min().item()) + k_max = int(k_nonzero.max().item()) + else: + k_min = k_max = 0 + + m = k_max - E4M3_KMAX + global_amax = E2M1_MAX * E4M3_MAX * float(2.0**m) + + # A block is lossless under this cast iff k_max - k_j <= 17 (its k_j - m sits + # in E4M3's [-9, 8] window). All-zero blocks are trivially lossless because + # their reconstruction is 0 regardless of the snapped scale. + n_total = e8m0_scales.numel() + in_range = (k >= (k_max - 17)) | (~nonzero_mask) + n_lossless = int(in_range.sum().item()) + pct_lossless = 100.0 * n_lossless / n_total if n_total else 100.0 + + return global_amax, { + "k_min": k_min, + "k_max": k_max, + "m": m, + "n_total_blocks": n_total, + "n_lossless_blocks": n_lossless, + "pct_lossless": pct_lossless, + "n_zero_blocks": int((~nonzero_mask).sum().item()), + } + + +def compute_per_block_amax_for_mxfp4( + blocks: torch.Tensor, e8m0_scales: torch.Tensor +) -> torch.Tensor: + """Hybrid per-NVFP4-block amax for MXFP4 -> NVFP4 cast. + + Each MXFP4 block of 32 elements has one E8M0 exponent ``k_j``. 
Two cases + based on whether ``k_j`` fits in NVFP4's E4M3 scale grid (with + ``m = k_max - 8`` chosen by ``compute_global_amax_for_scales``): + + - **In-range** (``k_j - m`` in ``[-9, 8]``): ``6 * 2^k_j`` (closed-form + ideal). The resulting per-block scale ``2^(k_j - m)`` is exactly + representable in E4M3 — no rounding loss — and + ``round_to_E2M1(value / 2^k_j)`` yields the original MXFP4 nibble + verbatim. Bit-exact reconstruction. + + - **Out of range** (``|k_j - m| > 8/9``): ``max_nibble * 2^k_j``, i.e. + ``max(|w_block|)`` where ``w`` is the MXFP4-dequantized block. This is + the data-derived per-block amax. The per-block scale will still get + clamped at the E4M3 boundary, but data-derived amax keeps the post-clamp + scale closer to the block's actual magnitude than the closed-form ideal + would, which reduces re-bucketing error for OOR blocks where + ``max_nibble < 6``. + + Two NVFP4 blocks of 16 share each MXFP4 block's ``k_j``, so the result is + expanded by ``repeat_interleave(2, dim=-1)``. + + Args: + blocks: uint8 tensor of packed E2M1 nibbles, shape + ``(..., num_mxfp4_blocks, 16)`` (16 bytes per 32-element MXFP4 block). + e8m0_scales: uint8 tensor of E8M0 scales, shape + ``(..., num_mxfp4_blocks)``. + + Returns: + float32 tensor of shape ``(..., 2 * num_mxfp4_blocks)``. + """ + if blocks.shape[-1] != 16 or blocks.shape[:-1] != e8m0_scales.shape: + raise ValueError( + f"shape mismatch: blocks {tuple(blocks.shape)} " + "(expected (..., num_mxfp4_blocks, 16)) " + f"vs scales {tuple(e8m0_scales.shape)}" + ) + + k = e8m0_scales.to(torch.int32) - E8M0_BIAS # (..., num_mxfp4_blocks) + pow2_k = torch.exp2(k.float()) + closed_form_ideal = E2M1_MAX * pow2_k # (..., num_mxfp4_blocks) + + # ``m = k_max - 8`` over non-zero blocks. Compute via masked ``amax`` so + # ``m`` stays a 0-d tensor and we avoid a GPU->CPU sync just to get a + # Python int. All-zero scales fall through with the -E8M0_BIAS sentinel, + # which leaves every block trivially in-range (closed_form_ideal == 0 there). + nonzero = e8m0_scales > 0 + sentinel = torch.full_like(k, -E8M0_BIAS) + k_max = torch.where(nonzero, k, sentinel).amax() + delta = k - (k_max - E4M3_KMAX) + in_range = (delta >= E4M3_KMIN) & (delta <= E4M3_KMAX) + + # Fast path: if every block fits E4M3's [-9, 8] window the per-block amax + # is just the closed-form ideal, and we can skip the per-byte nibble scan + # over the block tensor (which is 16x larger than the scales). For typical + # MXFP4 checkpoints (e.g. gpt-oss-20b) this is the only path ever taken. + if bool(in_range.all()): + return closed_form_ideal.repeat_interleave(2, dim=-1) + + # OOR fallback: data-derived per-block amax = max(|w_block|) after MXFP4 + # dequant = ``max_nibble * 2^k_j``. The MXFP4 nibble is sign-magnitude with + # sign in bit 3 and magnitude index in bits 0-2; we extract per-byte + # magnitudes, take the byte-wise max, then reduce across the 16 bytes to + # get the largest magnitude index in the 32-element block. + low = blocks & 0x07 + high = (blocks >> 4) & 0x07 + max_idx = torch.maximum(low, high).amax(dim=-1).long() + max_nibble = _e2m1_magnitude_table(blocks.device)[max_idx] + data_derived = max_nibble * pow2_k + + per_block_amax_mxfp4 = torch.where(in_range, closed_form_ideal, data_derived) + # Each MXFP4 block of 32 splits into two NVFP4 blocks of 16 sharing k_j. + return per_block_amax_mxfp4.repeat_interleave(2, dim=-1) + + +def quantizer_name_from_blocks_key(blocks_key: str) -> str: + """Map ``_blocks`` -> ``_weight_quantizer``. 
+ + OpenAI's MXFP4 checkpoint convention stores packed weights as + ``_blocks`` and scales as ``_scales``. modelopt's + ``GptOssExperts`` wrapper attaches the weight quantizer at + ``_weight_quantizer``. + """ + assert blocks_key.endswith("_blocks"), f"Unexpected key {blocks_key!r}" + return blocks_key[: -len("_blocks")] + "_weight_quantizer" + + +def _collect_keys_with_suffix(ckpt_dir: Path, suffix: str) -> dict[str, Path]: + """Return ``{tensor_name: shard_path}`` for every key ending with ``suffix``.""" + index_path = ckpt_dir / "model.safetensors.index.json" + if index_path.is_file(): + with index_path.open() as f: + index = json.load(f) + return { + k: ckpt_dir / shard for k, shard in index["weight_map"].items() if k.endswith(suffix) + } + shards = list(ckpt_dir.glob("*.safetensors")) + if len(shards) != 1: + raise FileNotFoundError( + f"Expected model.safetensors.index.json or a single .safetensors file in {ckpt_dir}" + ) + out: dict[str, Path] = {} + with safe_open(shards[0], framework="pt") as f: + # ``safe_open`` is not a dict; ``.keys()`` is its iterator. + for k in f.keys(): # noqa: SIM118 + if k.endswith(suffix): + out[k] = shards[0] + return out + + +def _collect_scales_keys(ckpt_dir: Path) -> dict[str, Path]: + """Return ``{tensor_name: shard_path}`` for every ``*_scales`` key.""" + return _collect_keys_with_suffix(ckpt_dir, "_scales") + + +def build_amax_map(checkpoint_dir: str | Path) -> dict[str, dict]: + """Walk the source MXFP4 checkpoint and build the per-layer amax map. + + Args: + checkpoint_dir: Path to a Hugging Face checkpoint directory whose + ``quantization_config.quant_method`` is ``"mxfp4"`` (OpenAI layout + with ``*_blocks`` + ``*_scales`` tensors). + + Returns: + ``{quantizer_name: {"global_amax": float, "k_min": int, "k_max": int, + "m": int, "n_total_blocks": int, + "n_lossless_blocks": int, "pct_lossless": float, + "n_zero_blocks": int}}`` + + Quantizer names match ``model.named_modules()`` after modelopt + instrumentation (e.g. ``model.layers.0.mlp.experts.gate_up_proj_weight_quantizer``). + + Raises: + SystemExit: if no ``*_scales`` tensors are found (not an MXFP4 checkpoint). + """ + ckpt_dir = Path(checkpoint_dir) + if not ckpt_dir.is_dir(): + raise FileNotFoundError(f"Checkpoint dir not found: {ckpt_dir}") + + scales_keys = _collect_scales_keys(ckpt_dir) + if not scales_keys: + raise SystemExit( + f"No '*_scales' tensors found in {ckpt_dir}. " + "This requires an MXFP4 HF checkpoint with the OpenAI layout." + ) + + amax_map: dict[str, dict] = {} + with _shard_reader() as read: + for tensor_key, shard in sorted(scales_keys.items()): + scales = read(tensor_key, shard) + + global_amax, info = compute_global_amax_for_scales(scales) + + blocks_key = tensor_key[: -len("_scales")] + "_blocks" + qname = quantizer_name_from_blocks_key(blocks_key) + amax_map[qname] = {"global_amax": global_amax, **info} + + return amax_map + + +def force_weight_quantizers_static(quant_cfg: list) -> None: + """Force every weight-quantizer entry's ``block_sizes`` to ``type='static'``. + + The MXFP4 -> NVFP4 cast needs the per-block weight ``_amax`` to be recorded + by max-cal (so it can be paired with the pinned global_amax later). Setting + ``block_sizes['type'] = 'static'`` makes ``is_static_block_quant`` True so + ``promote_nvfp4_static_quantizers`` picks the entry up automatically at the + end of max_calibrate. 
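+
+    Illustrative sketch (hypothetical entry layout, for intuition only): an entry like
+    ``{"quantizer_name": "*weight_quantizer", "cfg": {"block_sizes": {-1: 16}}}`` is
+    rewritten in place to carry ``"block_sizes": {-1: 16, "type": "static"}``; entries
+    whose name lacks ``weight_quantizer`` or whose ``cfg`` has no ``block_sizes`` dict
+    are left untouched.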
+ """ + for i, entry in enumerate(quant_cfg): + qname = entry.get("quantizer_name", "") + cfg = entry.get("cfg") or {} + bs = cfg.get("block_sizes") + if "weight_quantizer" in qname and isinstance(bs, dict): + quant_cfg[i] = {**entry, "cfg": {**cfg, "block_sizes": {**bs, "type": "static"}}} + + +def apply_to_model( + model: "torch.nn.Module", + source_checkpoint_path: str | Path, +) -> None: + """Closed-form cast: bit-exact MXFP4 -> NVFP4 weight conversion. + + Reads the source MXFP4 ``*_scales`` from ``source_checkpoint_path`` and + overrides two buffers on each matching NVFP4 weight quantizer: + + 1. ``global_amax`` = ``6 * 448 * 2^(k_max - 8)`` (closed-form scalar — + pins ``scale_2 = 2^m``). + 2. ``_amax`` = ``6 * 2^k_j`` per NVFP4 block (closed-form per-block — pins + ``per_block_scale = 2^(k_j - m)``, exactly representable in E4M3). + + Together these guarantee that ``per_block_scale * scale_2 = 2^k_j`` exactly, + so the NVFP4 dequant produces ``nibble * 2^k_j`` — the same value as the + MXFP4 dequant. End-to-end the weight conversion is bit-exact for every + block whose ``k_j`` lands in E4M3's representable range (``k_max - k_j <= 17``). + + The weight quantizer is expected to be an :class:`NVFP4StaticQuantizer` + (:func:`max_calibrate` auto-promotes static-block NVFP4 weight quantizers + at the end of calibration). Both ``_amax`` (per-block from max-cal) and + ``_global_amax`` (per-tensor from the auto-promotion) get overwritten. + """ + ckpt_dir = Path(source_checkpoint_path) + if not ckpt_dir.is_dir(): + raise FileNotFoundError(f"Checkpoint dir not found: {ckpt_dir}") + scales_keys = _collect_scales_keys(ckpt_dir) + if not scales_keys: + raise SystemExit( + f"No '*_scales' tensors found in {ckpt_dir}. " + "This requires an MXFP4 HF checkpoint with the OpenAI layout." + ) + + blocks_keys = _collect_keys_with_suffix(ckpt_dir, "_blocks") + + name_to_module = dict(model.named_modules()) + matched = 0 + missed: list[str] = [] + + n_total_layers = 0 + n_lossless_layers = 0 + grand_total_blocks = 0 + grand_lossless_blocks = 0 + + with _shard_reader() as read: + for tensor_key, shard in sorted(scales_keys.items()): + scales = read(tensor_key, shard) + + global_amax_value, info = compute_global_amax_for_scales(scales) + n_total_layers += 1 + if info["pct_lossless"] >= 100.0: + n_lossless_layers += 1 + grand_total_blocks += info["n_total_blocks"] + grand_lossless_blocks += info["n_lossless_blocks"] + + blocks_key = tensor_key[: -len("_scales")] + "_blocks" + qname = quantizer_name_from_blocks_key(blocks_key) + blocks_shard = blocks_keys.get(blocks_key) + assert blocks_shard is not None, ( + f"{tensor_key}: no paired '{blocks_key}' tensor found in source checkpoint." + ) + + weight_quantizer = name_to_module.get(qname) + if weight_quantizer is None: + missed.append(qname) + continue + + # The cast assumes ``max_calibrate`` already promoted this quantizer + # to NVFP4StaticQuantizer (with ``_amax`` populated per-block by + # static-block max-cal and ``_global_amax`` set by the auto-promote). + # Anything else means the qformat or quant_cfg disabled this module's + # weight quantization — surface that loudly so we don't silently no-op. + assert isinstance(weight_quantizer, NVFP4StaticQuantizer), ( + f"{qname}: expected NVFP4StaticQuantizer (set by max_calibrate's " + f"auto-promote), got {type(weight_quantizer).__name__}. The cast " + "requires the matching quantizer to be enabled with static-block " + "NVFP4 (num_bits=(2,1), scale_bits=(4,3))." 
+ ) + existing = getattr(weight_quantizer, "_amax", None) + assert isinstance(existing, torch.Tensor) and existing.numel() > 1, ( + f"{qname}: NVFP4StaticQuantizer must have a per-block ``_amax`` " + f"buffer populated by max_calibrate. Got: {existing!r}." + ) + + # Pick the device from the existing per-block ``_amax`` buffer. + device = existing.device + + global_amax = torch.tensor(float(global_amax_value), dtype=torch.float32, device=device) + # Fully-lossless layers don't need the packed ``*_blocks`` tensor — + # the per-block amax is just ``6 * 2^k_j`` from ``scales`` alone, and + # avoiding the (16x larger) block read is the main I/O win the + # closed-form path is designed for. + if info["pct_lossless"] >= 100.0: + k = scales.to(torch.int32) - E8M0_BIAS + per_block_amax = ( + (E2M1_MAX * torch.exp2(k.float())) + .repeat_interleave(2, dim=-1) + .to(dtype=torch.float32, device=device) + ) + else: + blocks = read(blocks_key, blocks_shard) + per_block_amax = compute_per_block_amax_for_mxfp4(blocks, scales).to( + dtype=torch.float32, device=device + ) + # Numel must match — calibration may store ``_amax`` flat (e.g. (N, 1)) + # while we compute it in natural (E, F, num_blocks) layout. The static + # export path reshapes via ``.view(expected_shape)``, so we just need + # element count to agree, then reshape for the in-place copy. + assert existing.numel() == per_block_amax.numel(), ( + f"{qname}: ``_amax`` element count {existing.numel()} does not " + f"match the cast-computed count {per_block_amax.numel()}. The " + "block layout from calibration disagrees with the source MXFP4 " + "scales — check that the qformat block_size is 16 and the source " + "checkpoint is the same MXFP4 model." + ) + + # global_amax via the NVFP4StaticQuantizer property setter (writes to + # the canonical ``_global_amax`` buffer). + weight_quantizer.global_amax = global_amax + # _amax: in-place buffer copy, reshaping our value to the calibrator's + # storage layout (numel verified above). + with torch.no_grad(): + existing.data.copy_(per_block_amax.view_as(existing)) + + matched += 1 + + print( + f"[cast_mxfp4_to_nvfp4] overrode {matched}/{n_total_layers} weight quantizers from {source_checkpoint_path}" + ) + if missed: + print( + f"[cast_mxfp4_to_nvfp4] warning: {len(missed)} layers had no matching module. 
" + f"First few: {missed[:5]}" + ) + layer_pct = 100.0 * n_lossless_layers / n_total_layers if n_total_layers else 100.0 + block_pct = 100.0 * grand_lossless_blocks / grand_total_blocks if grand_total_blocks else 100.0 + print( + f"[cast_mxfp4_to_nvfp4] lossless layers: {n_lossless_layers}/{n_total_layers} ({layer_pct:.2f}%)" + ) + print( + f"[cast_mxfp4_to_nvfp4] lossless blocks: {grand_lossless_blocks}/{grand_total_blocks} ({block_pct:.4f}%)" + ) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index d660c1de4c8..edf9c4d6f19 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -24,6 +24,8 @@ import numpy as np import torch from accelerate.hooks import remove_hook_from_module +from cast_mxfp4_to_nvfp4 import apply_to_model as apply_cast_mxfp4_to_nvfp4 +from cast_mxfp4_to_nvfp4 import force_weight_quantizers_static from example_utils import ( build_quant_cfg, copy_custom_model_files, @@ -1087,6 +1089,10 @@ def quantize_main( f"Auto-resolved layerwise_checkpoint_dir: {quant_cfg['algorithm']['layerwise_checkpoint_dir']}" ) + if args.cast_mxfp4_to_nvfp4: + quant_cfg = copy.deepcopy(quant_cfg) + force_weight_quantizers_static(quant_cfg["quant_cfg"]) + if args.qformat in QUANT_CFG_CHOICES: mono_quantize( args, @@ -1102,6 +1108,14 @@ def quantize_main( assert model_type != "dbrx", f"Does not support export {model_type} without quantizaton" print(f"qformat: {args.qformat}. No quantization applied, export {device} model") + # If asked, run the closed-form MXFP4 -> NVFP4 cast: read the source MXFP4 + # *_scales tensors and pin each NVFP4 weight quantizer's scale_2 to 2^m. + # Runs after calibration (max_calibrate has already promoted weight quantizers + # to NVFP4StaticQuantizer with a data-derived ``_global_amax``); we just + # override that scalar with the closed-form value before export. + if args.cast_mxfp4_to_nvfp4: + apply_cast_mxfp4_to_nvfp4(language_model, args.pyt_ckpt_path) + post_quantize( args, full_model, @@ -1343,6 +1357,18 @@ def parse_args() -> argparse.Namespace: help="Export as vLLM fake-quant checkpoint (produces vllm_fq_modelopt_state.pth " "for use with vllm_serve_fakequant.py).", ) + parser.add_argument( + "--cast_mxfp4_to_nvfp4", + action="store_true", + default=False, + help=( + "After calibration, override NVFP4 weight quantizers' global_amax with " + "the closed-form value derived from the source MXFP4 *_scales. " + "Per-block _amax is computed from the loaded BF16 weights (data-derived). " + "Use when --pyt_ckpt_path points at an MXFP4 HF checkpoint (e.g. " + "openai/gpt-oss-20b) and the target qformat is NVFP4-family." + ), + ) args = parser.parse_args() if args.moe_calib_experts_ratio is not None and not (0.0 < args.moe_calib_experts_ratio <= 1.0): @@ -1415,4 +1441,17 @@ def main(args: argparse.Namespace): "--specdec_offline_dataset expects a single --calib value, not a comma-separated list." ) + if args.cast_mxfp4_to_nvfp4: + qformats = [q.strip() for q in args.qformat.split(",")] + if not all("nvfp4" in q for q in qformats): + raise ValueError( + "--cast_mxfp4_to_nvfp4 requires NVFP4-family --qformat values " + f"(got {args.qformat!r}). Use e.g. --qformat nvfp4 or nvfp4_mlp_only." + ) + if args.auto_quantize_bits is not None: + raise ValueError( + "--cast_mxfp4_to_nvfp4 is not supported with --auto_quantize_bits " + "(multi-format auto-quantize)." 
+ ) + main(args) diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh index d9c4ff8a7a0..6ca99c7f963 100755 --- a/examples/llm_ptq/scripts/huggingface_example.sh +++ b/examples/llm_ptq/scripts/huggingface_example.sh @@ -143,6 +143,10 @@ if [ -n "$MOE_CALIB_EXPERTS_RATIO" ]; then PTQ_ARGS+=" --moe_calib_experts_ratio=$MOE_CALIB_EXPERTS_RATIO " fi +if $CAST_MXFP4_TO_NVFP4; then + PTQ_ARGS+=" --cast_mxfp4_to_nvfp4 " +fi + if ! $VERBOSE; then PTQ_ARGS+=" --no-verbose " fi diff --git a/examples/llm_ptq/scripts/parser.sh b/examples/llm_ptq/scripts/parser.sh index b41b715340d..3817c1dee7c 100644 --- a/examples/llm_ptq/scripts/parser.sh +++ b/examples/llm_ptq/scripts/parser.sh @@ -34,9 +34,10 @@ parse_options() { KV_CACHE_FREE_GPU_MEMORY_FRACTION=0.8 VERBOSE=true USE_SEQ_DEVICE_MAP=false + CAST_MXFP4_TO_NVFP4=false # Parse command-line options - ARGS=$(getopt -o "" -l "model:,quant:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:" -n "$0" -- "$@") + ARGS=$(getopt -o "" -l "model:,quant:,kv_cache_quant:,tp:,pp:,sparsity:,awq_block_size:,calib:,calib_batch_size:,auto_quantize_bits:,output:,batch:,tasks:,lm_eval_tasks:,lm_eval_limit:,simple_eval_tasks:,trust_remote_code,use_seq_device_map,gpu_max_mem_percentage:,kv_cache_free_gpu_memory_fraction:,low_memory_mode,no-verbose,calib_dataset:,calib_seq:,auto_quantize_method:,auto_quantize_score_size:,auto_quantize_checkpoint:,moe_calib_experts_ratio:,cast_mxfp4_to_nvfp4" -n "$0" -- "$@") eval set -- "$ARGS" while true; do @@ -69,6 +70,7 @@ parse_options() { --auto_quantize_score_size ) AUTO_QUANTIZE_SCORE_SIZE="$2"; shift 2;; --auto_quantize_checkpoint ) AUTO_QUANTIZE_CHECKPOINT="$2"; shift 2;; --moe_calib_experts_ratio ) MOE_CALIB_EXPERTS_RATIO="$2"; shift 2;; + --cast_mxfp4_to_nvfp4 ) CAST_MXFP4_TO_NVFP4=true; shift;; -- ) shift; break ;; * ) break ;; esac @@ -158,5 +160,6 @@ parse_options() { echo "auto_quantize_score_size: $AUTO_QUANTIZE_SCORE_SIZE" echo "auto_quantize_checkpoint: $AUTO_QUANTIZE_CHECKPOINT" echo "moe_calib_experts_ratio: $MOE_CALIB_EXPERTS_RATIO" + echo "cast_mxfp4_to_nvfp4: $CAST_MXFP4_TO_NVFP4" echo "=================" } diff --git a/modelopt/torch/export/layer_utils.py b/modelopt/torch/export/layer_utils.py index e8ee5afd451..ae8bd5c7cbc 100755 --- a/modelopt/torch/export/layer_utils.py +++ b/modelopt/torch/export/layer_utils.py @@ -1079,14 +1079,23 @@ def set_expert_quantizer_amax( target_amax = None - # Collect ANY existing amax values from current batch (most direct source) + # Collect ANY existing amax values from current batch (most direct source). + # Reduce per-quantizer amax to a scalar before stacking — quantizers in + # static-mode (e.g. NVFP4 with pre-computed per-block _amax) carry tensors + # whose shapes differ across attrs (gate_up_proj vs down_proj have different + # output dims), and torch.stack would otherwise fail. The result here is + # only used as a *fallback* scalar `target_amax` for quantizers missing + # amax, so a max-of-max is exactly what we want. 
valid_amax_values = [] for _, attr_name, quantizer in all_quantizers: existing_amax = getattr(quantizer, "amax", None) if existing_amax is not None: # Convert to tensor and add to collection if isinstance(existing_amax, torch.Tensor): - valid_amax_values.append(existing_amax.to(target_device)) + # Meta tensors have no storage; .amax() / .to() would fail. + if existing_amax.is_meta: + continue + valid_amax_values.append(existing_amax.amax().to(target_device)) else: valid_amax_values.append( torch.tensor(existing_amax, dtype=torch.float32, device=target_device) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index a76783ac172..c0f00f7e9a1 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -52,11 +52,7 @@ from torch.distributed.fsdp import FSDPModule from modelopt.torch.quantization import set_quantizer_by_cfg_context -from modelopt.torch.quantization.nn import ( - NVFP4StaticQuantizer, - SequentialQuantizer, - TensorQuantizer, -) +from modelopt.torch.quantization.nn import SequentialQuantizer, TensorQuantizer from modelopt.torch.quantization.qtensor import MXFP8QTensor, NVFP4QTensor from modelopt.torch.quantization.utils import fsdp2_aware_weight_update, quantizer_attr_names @@ -539,11 +535,12 @@ def _export_quantized_weight( expert_type in type(sub_module).__name__ for expert_type in ["Llama4TextExperts", "GptOssExperts"] ) - if is_bmm_expert_weight and isinstance(weight_quantizer, NVFP4StaticQuantizer): - raise ValueError( - "NVFP4StaticQuantizer with BMM-style expert weights (e.g. Llama4TextExperts, " - "GptOssExperts) is not yet supported." - ) + # NVFP4StaticQuantizer + BMM-style experts: route through the static-aware + # ``_from_quantizer`` helper so the pinned per-block ``_amax`` (e.g. set by + # the MXFP4->NVFP4 cast to ``6 * 2^k_j``) is used to derive the FP8 + # per-block scale. The plain ``get_weights_scaling_factor`` would ignore + # ``_amax`` and recompute per-block max from the BF16 weight, which + # rebuckets nibbles and loses bit-exactness when ``max_nibble < 6``. if quantization_format in [ QUANTIZATION_NVFP4, @@ -556,11 +553,18 @@ def _export_quantized_weight( weight, is_bmm_expert_weight=is_bmm_expert_weight ) - weight_scale = NVFP4QTensor.get_weights_scaling_factor( - weight, - block_size=block_size, - weights_scaling_factor_2=weight_scale_2, - )[0] + if NVFP4QTensor._is_static_quantizer(weight_quantizer): + weight_scale = NVFP4QTensor.get_weights_scaling_factor_from_quantizer( + weight_quantizer, + weight, + weight_scale_2, + )[0] + else: + weight_scale = NVFP4QTensor.get_weights_scaling_factor( + weight, + block_size=block_size, + weights_scaling_factor_2=weight_scale_2, + )[0] quantized_weight = to_quantized_weight( weight.to(dtype), diff --git a/modelopt/torch/kernels/quantization/gemm/fp4_kernel.py b/modelopt/torch/kernels/quantization/gemm/fp4_kernel.py index 9eb6b2d49f5..0e6874ab575 100644 --- a/modelopt/torch/kernels/quantization/gemm/fp4_kernel.py +++ b/modelopt/torch/kernels/quantization/gemm/fp4_kernel.py @@ -251,14 +251,26 @@ def static_blockwise_fp4_fake_quant( """Static blockwise FP4 fake quantization using Triton kernel. Args: - x: [NUM_FP4_BLOCKS, BLOCK_SIZE] on CUDA. - amax: [NUM_FP4_BLOCKS] or [NUM_FP4_BLOCKS, 1] per-block amax values. + x: Input tensor on CUDA. The last dim must be the block dim (each consecutive + ``BLOCK_SIZE`` elements form one FP4 block). 
Any number of leading dims + is supported — they're flattened internally and the shape is restored + on output (so MoE expert weights ``(E, F, K)`` work the same as plain + linear weights ``(N, K)``). + amax: Per-block amax values. ``amax.numel()`` must equal + ``x.numel() // BLOCK_SIZE``. Shape is otherwise free; the kernel + consumes it as a flat 1-D buffer of length ``NUM_FP4_BLOCKS``. global_amax: FP32 scalar global amax. If provided, used to compute scale_fp8_quant_amax. quantize_block_scales: If True, quantize block scales to FP8. out_dtype: Output dtype. Defaults to x.dtype if None. """ - assert x.ndim == 2 - NUM_FP4_BLOCKS, BLOCK_SIZE = x.shape + original_shape = x.shape + NUM_FP4_BLOCKS = amax.numel() + if x.numel() % NUM_FP4_BLOCKS != 0: + raise ValueError( + f"x.numel() ({x.numel()}) is not divisible by amax.numel() ({NUM_FP4_BLOCKS}); " + "they must satisfy x.numel() == NUM_FP4_BLOCKS * BLOCK_SIZE." + ) + BLOCK_SIZE = x.numel() // NUM_FP4_BLOCKS if out_dtype is None: out_dtype = x.dtype @@ -267,7 +279,7 @@ def static_blockwise_fp4_fake_quant( x_flat = x.contiguous().view(-1) y_flat = torch.empty_like(x_flat, dtype=out_dtype) - scale_flat = scale.view(NUM_FP4_BLOCKS).contiguous() + scale_flat = scale.contiguous().view(NUM_FP4_BLOCKS) tl_out_dtype = _torch_dtype_to_tl(out_dtype) @@ -283,4 +295,4 @@ def static_blockwise_fp4_fake_quant( OUT_DTYPE=tl_out_dtype, ) - return y_flat.view_as(x) + return y_flat.view(original_shape) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 0c2033041d6..4ce0f62a75d 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -163,6 +163,15 @@ def max_calibrate( if hasattr(module, "layer_sync_moe_local_experts_amax"): module.layer_sync_moe_local_experts_amax(sync_weight_amax=sync_expert_weight_amax) + # Promote eligible static-block NVFP4 weight quantizers to NVFP4StaticQuantizer + # so the static blockwise fake-quant path is used in forward and the export + # picks up the two-level (per-block + global) scaling. Run before the + # ``distributed_sync`` early return so single-process callers also get the + # promotion. ``promote_nvfp4_static_quantizers`` only promotes when + # ``is_static_block_quant`` is True and the per-block ``_amax`` buffer is + # populated, so it's a no-op for dynamic-block / non-NVFP4 configs. + promote_nvfp4_static_quantizers(model) + if not distributed_sync: return diff --git a/tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py b/tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py new file mode 100644 index 00000000000..b6f8c3de123 --- /dev/null +++ b/tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py @@ -0,0 +1,397 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for ``examples/llm_ptq/cast_mxfp4_to_nvfp4.py``. 
+ +The module lives next to the example script (not inside the ``modelopt`` package), +so we add ``examples/llm_ptq/`` to ``sys.path`` before importing it. +""" + +import json +import sys +from pathlib import Path + +import pytest +import torch +from safetensors.torch import save_file + +_LLM_PTQ_DIR = Path(__file__).resolve().parents[3] / "examples" / "llm_ptq" +if str(_LLM_PTQ_DIR) not in sys.path: + sys.path.insert(0, str(_LLM_PTQ_DIR)) + +import cast_mxfp4_to_nvfp4 as cast + +# ---------- compute_global_amax_for_scales ---------------------------------- + + +def test_global_amax_basic_in_range(): + """Mixed in-range scales: m = k_max - 8, global_amax = 6*448*2^m, lossless = 100%.""" + # k values in [-3, 3] (spread = 6), all blocks lossless. + k = torch.tensor([0, -3, 3, 1, -1, 2], dtype=torch.int32) + e8m0 = (k + cast.E8M0_BIAS).to(torch.uint8) + + global_amax, info = cast.compute_global_amax_for_scales(e8m0) + assert info["k_min"] == -3 + assert info["k_max"] == 3 + assert info["m"] == 3 - 8 # k_max - 8 = -5 + expected = 6.0 * 448.0 * 2.0 ** info["m"] + assert global_amax == pytest.approx(expected) + assert info["n_total_blocks"] == 6 + assert info["n_lossless_blocks"] == 6 + assert info["pct_lossless"] == pytest.approx(100.0) + assert info["n_zero_blocks"] == 0 + + +def test_global_amax_with_zero_blocks(): + """Zero (e8m0=0, k=-127) blocks should be ignored when computing k_max.""" + e8m0 = torch.tensor([0, 0, 130, 125], dtype=torch.uint8) # ks: -127, -127, 3, -2 + global_amax, info = cast.compute_global_amax_for_scales(e8m0) + assert info["k_max"] == 3 # ignores zero blocks + assert info["n_zero_blocks"] == 2 + # Both nonzero blocks satisfy k_max - k_j <= 17, plus zero blocks count as + # lossless because their reconstruction is 0 regardless of scale. + assert info["n_lossless_blocks"] == 4 + + +def test_global_amax_with_oor_blocks(): + """A block 18 powers below k_max is OOR (k_max - k = 18 > 17).""" + # k values: 5, 5, -13 → spread = 18, last block is OOR. + k = torch.tensor([5, 5, -13], dtype=torch.int32) + e8m0 = (k + cast.E8M0_BIAS).to(torch.uint8) + _, info = cast.compute_global_amax_for_scales(e8m0) + assert info["k_max"] == 5 + assert info["n_total_blocks"] == 3 + assert info["n_lossless_blocks"] == 2 # the k=-13 block is OOR + + +def test_global_amax_all_zero(): + """All-zero scales should not crash; k_max defaults to 0.""" + e8m0 = torch.zeros(4, dtype=torch.uint8) + global_amax, info = cast.compute_global_amax_for_scales(e8m0) + assert info["k_min"] == 0 and info["k_max"] == 0 + assert info["n_zero_blocks"] == 4 + # All blocks count as "lossless" (their dequant is 0 regardless of scale). + assert info["n_lossless_blocks"] == 4 + + +# ---------- compute_per_block_amax_for_mxfp4 -------------------------------- + + +def _make_blocks_with_max_nibble(num_blocks: int, max_idx_per_block: list[int]) -> torch.Tensor: + """Build a (num_blocks, 16) byte tensor where block i has E2M1 magnitude + index ``max_idx_per_block[i]`` as its largest nibble; other nibbles are 0. + + Magnitude index goes in the low 3 bits of one nibble; we place it in the + high nibble of byte 0 (so the first byte = (max_idx << 4)). Every other + nibble is 0, so the block-wise max is exactly ``max_idx_per_block[i]``. 
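+
+    For example, ``_make_blocks_with_max_nibble(2, [7, 3])`` returns a ``(2, 16)``
+    uint8 tensor whose only non-zero bytes are ``blocks[0, 0] == 0x70`` and
+    ``blocks[1, 0] == 0x30``.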
+ """ + assert len(max_idx_per_block) == num_blocks + blocks = torch.zeros((num_blocks, 16), dtype=torch.uint8) + for i, idx in enumerate(max_idx_per_block): + assert 0 <= idx < 8 + blocks[i, 0] = (idx & 0x07) << 4 + return blocks + + +def test_per_block_amax_in_range_returns_closed_form(): + """Every block in-range -> 6 * 2^k_j, regardless of actual nibble content.""" + # k = [0, -2, 4]; k_max = 4, k_min = -2, spread 6 (in-range). + k = torch.tensor([0, -2, 4], dtype=torch.int32) + e8m0 = (k + cast.E8M0_BIAS).to(torch.uint8) + # Blocks have varying max_nibbles, but in-range path ignores them. + blocks = _make_blocks_with_max_nibble(3, [3, 7, 1]) # max nibbles: 1.5, 6, 0.5 + + out = cast.compute_per_block_amax_for_mxfp4(blocks, e8m0) + expected_mxfp4 = 6.0 * torch.exp2(k.float()) # ignores max_nibble + expected_nvfp4 = expected_mxfp4.repeat_interleave(2, dim=-1) + assert torch.allclose(out, expected_nvfp4) + + +def test_per_block_amax_oor_uses_data_derived(): + """OOR blocks should use ``max_nibble * 2^k_j`` (data-derived).""" + # k_max=10 → m=2. OOR-low blocks have k_j - m < -9, i.e. k_j < -7. + k = torch.tensor([10, -10], dtype=torch.int32) # second is OOR-low + e8m0 = (k + cast.E8M0_BIAS).to(torch.uint8) + # Block 0 max nibble idx 7 (value 6); block 1 max nibble idx 4 (value 2). + blocks = _make_blocks_with_max_nibble(2, [7, 4]) + + out = cast.compute_per_block_amax_for_mxfp4(blocks, e8m0) + + # Block 0 (in-range): 6 * 2^10 = 6144. + # Block 1 (OOR): 2 * 2^-10 (max_nibble=2 since idx=4 -> 2.0). + expected_mxfp4 = torch.tensor([6.0 * 2**10, 2.0 * 2**-10], dtype=torch.float32) + expected_nvfp4 = expected_mxfp4.repeat_interleave(2, dim=-1) + assert torch.allclose(out, expected_nvfp4) + + +def test_per_block_amax_doubles_last_dim(): + """Two NVFP4 blocks per MXFP4 block share the same per-block amax.""" + e8m0 = torch.tensor([130, 124], dtype=torch.uint8) # ks: 3, -3 + blocks = _make_blocks_with_max_nibble(2, [7, 7]) # in-range + out = cast.compute_per_block_amax_for_mxfp4(blocks, e8m0) + assert out.shape == (4,) + # Each pair of consecutive entries should be equal. + assert out[0] == out[1] + assert out[2] == out[3] + + +def test_per_block_amax_preserves_leading_dims(): + """Leading dims (E, F, ...) 
flow through unchanged; only last dim doubles.""" + # shape (E=2, F=3, num_mxfp4_blocks=4) + e8m0 = torch.full((2, 3, 4), 128, dtype=torch.uint8) # all k=1, in-range + blocks = torch.zeros((2, 3, 4, 16), dtype=torch.uint8) + out = cast.compute_per_block_amax_for_mxfp4(blocks, e8m0) + assert out.shape == (2, 3, 8) + + +def test_per_block_amax_shape_mismatch_raises(): + """Mismatched leading dims should raise ``ValueError``.""" + blocks = torch.zeros((4, 16), dtype=torch.uint8) + e8m0 = torch.zeros(3, dtype=torch.uint8) # different num_blocks + with pytest.raises(ValueError, match="shape mismatch"): + cast.compute_per_block_amax_for_mxfp4(blocks, e8m0) + + +# ---------- quantizer_name_from_blocks_key ---------------------------------- + + +def test_quantizer_name_from_blocks_key(): + assert ( + cast.quantizer_name_from_blocks_key("model.layers.0.mlp.experts.gate_up_proj_blocks") + == "model.layers.0.mlp.experts.gate_up_proj_weight_quantizer" + ) + assert ( + cast.quantizer_name_from_blocks_key("model.layers.0.mlp.experts.down_proj_blocks") + == "model.layers.0.mlp.experts.down_proj_weight_quantizer" + ) + + +def test_quantizer_name_from_blocks_key_rejects_non_blocks_key(): + with pytest.raises(AssertionError): + cast.quantizer_name_from_blocks_key("model.layers.0.mlp.experts.gate_up_proj_scales") + + +# ---------- _collect_keys_with_suffix + build_amax_map (synthetic ckpt) ------ + + +def _write_synthetic_mxfp4_checkpoint( + tmp_path: Path, + layer_names: list[str], + e8m0_per_layer: dict[str, torch.Tensor], + blocks_per_layer: dict[str, torch.Tensor], +) -> Path: + """Write a tiny safetensors + index.json mimicking the OpenAI MXFP4 layout. + + Each ``layer_names[i]`` becomes ``_blocks`` + ``_scales`` keys. + Returns the checkpoint directory. + """ + ckpt_dir = tmp_path / "fake_mxfp4" + ckpt_dir.mkdir() + state = {} + for name in layer_names: + state[f"{name}_blocks"] = blocks_per_layer[name] + state[f"{name}_scales"] = e8m0_per_layer[name] + shard_name = "model-00001-of-00001.safetensors" + save_file(state, str(ckpt_dir / shard_name)) + index = { + "metadata": {"total_size": sum(t.numel() * t.element_size() for t in state.values())}, + "weight_map": dict.fromkeys(state, shard_name), + } + (ckpt_dir / "model.safetensors.index.json").write_text(json.dumps(index)) + return ckpt_dir + + +def test_collect_keys_with_suffix(tmp_path): + name = "model.layers.0.mlp.experts.gate_up_proj" + ckpt_dir = _write_synthetic_mxfp4_checkpoint( + tmp_path, + [name], + e8m0_per_layer={name: torch.zeros(4, dtype=torch.uint8)}, + blocks_per_layer={name: torch.zeros((4, 16), dtype=torch.uint8)}, + ) + scales_keys = cast._collect_keys_with_suffix(ckpt_dir, "_scales") + blocks_keys = cast._collect_keys_with_suffix(ckpt_dir, "_blocks") + assert set(scales_keys.keys()) == {f"{name}_scales"} + assert set(blocks_keys.keys()) == {f"{name}_blocks"} + + +def test_build_amax_map(tmp_path): + name1 = "model.layers.0.mlp.experts.gate_up_proj" + name2 = "model.layers.0.mlp.experts.down_proj" + e8m0 = { + name1: torch.tensor([130, 128, 125], dtype=torch.uint8), # ks: 3, 1, -2; spread 5 + name2: torch.tensor([135, 120], dtype=torch.uint8), # ks: 8, -7; spread 15 + } + blocks = { + name1: torch.zeros((3, 16), dtype=torch.uint8), + name2: torch.zeros((2, 16), dtype=torch.uint8), + } + ckpt_dir = _write_synthetic_mxfp4_checkpoint(tmp_path, [name1, name2], e8m0, blocks) + + amax_map = cast.build_amax_map(ckpt_dir) + assert set(amax_map.keys()) == {f"{n}_weight_quantizer" for n in (name1, name2)} + + e1 = 
amax_map[f"{name1}_weight_quantizer"] + assert e1["k_min"] == -2 and e1["k_max"] == 3 and e1["m"] == -5 + assert e1["global_amax"] == pytest.approx(6.0 * 448.0 * 2.0**-5) + assert e1["pct_lossless"] == pytest.approx(100.0) + + e2 = amax_map[f"{name2}_weight_quantizer"] + assert e2["k_min"] == -7 and e2["k_max"] == 8 and e2["m"] == 0 + assert e2["pct_lossless"] == pytest.approx(100.0) + + +def test_build_amax_map_no_scales_raises(tmp_path): + """A directory without ``*_scales`` tensors should error.""" + empty = tmp_path / "empty" + empty.mkdir() + save_file( + {"model.layers.0.weight": torch.zeros(4)}, + str(empty / "model-00001-of-00001.safetensors"), + ) + (empty / "model.safetensors.index.json").write_text( + json.dumps( + { + "metadata": {}, + "weight_map": {"model.layers.0.weight": "model-00001-of-00001.safetensors"}, + } + ) + ) + with pytest.raises(SystemExit, match="No '\\*_scales'"): + cast.build_amax_map(empty) + + +# ---------- magnitude table cache ------------------------------------------ + + +def test_e2m1_magnitude_table_cached_per_device(): + t1 = cast._e2m1_magnitude_table(torch.device("cpu")) + t2 = cast._e2m1_magnitude_table(torch.device("cpu")) + assert t1 is t2 # cached: same object + assert t1.tolist() == cast._E2M1_MAGNITUDE + + +# ---------- apply_to_model end-to-end (mock model) --------------------------- + + +class _FakeStaticQuantizer(torch.nn.Module): + """Stand-in for NVFP4StaticQuantizer. + + Carries a per-block ``_amax`` buffer and a ``global_amax`` property whose + setter writes ``_global_amax`` — matches the contract apply_to_model relies + on. Subclasses ``cast.NVFP4StaticQuantizer`` so the isinstance check passes. + """ + + def __init__(self, num_blocks: int): + super().__init__() + self.register_buffer("_amax", torch.zeros(num_blocks, dtype=torch.float32)) + self.register_buffer("_global_amax", torch.zeros((), dtype=torch.float32)) + + @property + def global_amax(self) -> torch.Tensor: + return self._global_amax + + @global_amax.setter + def global_amax(self, value: torch.Tensor) -> None: + self._global_amax = value + + +# Inherit at runtime so isinstance(NVFP4StaticQuantizer) is True. +_FakeStaticQuantizer.__bases__ = (cast.NVFP4StaticQuantizer,) + + +class _FakeExperts(torch.nn.Module): + """Mimics a HF GptOssExperts module: a ``*_weight_quantizer`` child.""" + + def __init__(self, num_blocks: int): + super().__init__() + # Quantizer attribute name must match the source key after stripping + # ``_blocks`` and appending ``_weight_quantizer``. + self.gate_up_proj_weight_quantizer = _FakeStaticQuantizer(num_blocks) + + +class _FakeModel(torch.nn.Module): + """Single MLP-like submodule path: ``model.layers.0.mlp.experts.gate_up_proj_*``.""" + + def __init__(self, num_blocks: int): + super().__init__() + self.experts = _FakeExperts(num_blocks) + + +def test_apply_to_model_writes_global_and_per_block_amax(tmp_path): + """Happy path: cast overrides _amax + global_amax on the matching quantizer.""" + # Build a synthetic MXFP4 source: 4 in-range MXFP4 blocks => 8 NVFP4 blocks. + name = "experts.gate_up_proj" + e8m0 = torch.tensor([130, 128, 125, 132], dtype=torch.uint8) # ks: 3, 1, -2, 5 + blocks = torch.zeros((4, 16), dtype=torch.uint8) + ckpt_dir = _write_synthetic_mxfp4_checkpoint( + tmp_path, + [name], + e8m0_per_layer={name: e8m0}, + blocks_per_layer={name: blocks}, + ) + + # 8 NVFP4 blocks (each MXFP4 block of 32 splits into two NVFP4 blocks of 16). 
+    model = _FakeModel(num_blocks=8)
+    cast.apply_to_model(model, ckpt_dir)
+
+    quantizer = model.experts.gate_up_proj_weight_quantizer
+    # k_max = 5 -> m = -3 -> global_amax = 6 * 448 * 2^-3 = 336.
+    assert float(quantizer.global_amax.item()) == pytest.approx(6.0 * 448.0 * 2.0**-3)
+    # All in-range -> per-block _amax = 6 * 2^k_j, repeat-interleaved by 2.
+    expected_per_mxfp4 = 6.0 * torch.exp2(torch.tensor([3.0, 1.0, -2.0, 5.0]))
+    expected_per_nvfp4 = expected_per_mxfp4.repeat_interleave(2)
+    assert torch.allclose(quantizer._amax.float(), expected_per_nvfp4)
+
+
+def test_apply_to_model_raises_on_missing_blocks_pair(tmp_path):
+    """If a *_scales tensor has no paired *_blocks tensor, raise AssertionError."""
+    ckpt_dir = tmp_path / "missing_blocks"
+    ckpt_dir.mkdir()
+    # Write only the _scales tensor.
+    save_file(
+        {"experts.gate_up_proj_scales": torch.zeros(2, dtype=torch.uint8)},
+        str(ckpt_dir / "model-00001-of-00001.safetensors"),
+    )
+    (ckpt_dir / "model.safetensors.index.json").write_text(
+        json.dumps(
+            {
+                "metadata": {},
+                "weight_map": {"experts.gate_up_proj_scales": "model-00001-of-00001.safetensors"},
+            }
+        )
+    )
+    model = _FakeModel(num_blocks=4)
+    with pytest.raises(AssertionError, match="no paired '.*_blocks' tensor"):
+        cast.apply_to_model(model, ckpt_dir)
+
+
+def test_apply_to_model_raises_on_wrong_quantizer_type(tmp_path):
+    """If the matching attribute isn't an NVFP4StaticQuantizer, raise AssertionError."""
+    name = "experts.gate_up_proj"
+    e8m0 = torch.tensor([130, 128], dtype=torch.uint8)
+    blocks = torch.zeros((2, 16), dtype=torch.uint8)
+    ckpt_dir = _write_synthetic_mxfp4_checkpoint(tmp_path, [name], {name: e8m0}, {name: blocks})
+
+    class _NotAQuantizer(torch.nn.Module):
+        pass
+
+    class _Wrong(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.experts = torch.nn.Module()
+            self.experts.gate_up_proj_weight_quantizer = _NotAQuantizer()
+
+    with pytest.raises(AssertionError, match="expected NVFP4StaticQuantizer"):
+        cast.apply_to_model(_Wrong(), ckpt_dir)