Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions modelopt/torch/kernels/quantization/gemm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
# fp4_kernel works on any CUDA GPU with triton
from .fp4_kernel import *
from .fp8_kernel import *
from .nvfp4_fp8_sweep import *

# fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv)
if torch.cuda.get_device_capability() >= (8, 9):
Expand Down
166 changes: 166 additions & 0 deletions modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep.

Replaces the 126-iteration Python sweep in :class:`NVFP4MSECalibrator` with a single
kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates
and emits the per-block ``best_amax`` directly.

The 126 candidates are constructed as ``valid_fp8_e4m3_value / 448`` (see
:func:`fp8_scale_candidates`). For these specific candidates, the FP8 round-trip on
the per-block scale is the identity, so the kernel can use
``scale = candidate * global_amax / 6.0`` without an explicit FP8 cast — making it
runnable on any CUDA GPU with Triton (no ``tl.float8e4nv`` requirement).

Tile shape (``BLOCKS_PER_PROGRAM``) and ``num_warps`` are autotuned per ``N_BLOCKS``.
"""

import torch
import triton
import triton.language as tl

from .nvfp4_quant import fp4_round_magnitude

__all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"]


def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor:
"""Return the 126 valid finite positive FP8 E4M3 scale candidates / 448."""
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bot comment.

Minor duplication: this function reproduces the same logic as NVFP4MSECalibrator._generate_candidates() in calib/mse.py. Consider having one call the other (or extracting a shared utility) so the candidate generation stays in sync if the candidate set ever changes.

uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
return fp8_values[valid_mask] / 448.0


# Selected from a (BLOCKS_PER_PROGRAM, num_warps) sweep on B300:
# BPP=16,nw=2: 6.06 ms BPP=32,nw=4: 6.06 ms BPP=64,nw=8: 5.08 ms
# The smaller-tile entries cover cases where N_BLOCKS is small enough that BPP=64
# would underfill the SMs.
# triton.autotune picks among these per distinct N_BLOCKS value (the autotune
# key on the kernel), so small and large weights each get a suitable tile.
_FP8_SWEEP_AUTOTUNE_CONFIGS = [
    triton.Config({"BLOCKS_PER_PROGRAM": 16}, num_warps=2),
    triton.Config({"BLOCKS_PER_PROGRAM": 32}, num_warps=4),
    triton.Config({"BLOCKS_PER_PROGRAM": 64}, num_warps=8),
]


@triton.autotune(configs=_FP8_SWEEP_AUTOTUNE_CONFIGS, key=["N_BLOCKS"])
@triton.jit
def _fp8_scale_sweep_kernel(
    x_ptr,  # [N_BLOCKS * BLOCK_SIZE], any float dtype (loaded as fp32)
    candidates_ptr,  # [NUM_CANDIDATES] fp32
    global_amax_ptr,  # scalar fp32
    best_amax_ptr,  # [N_BLOCKS] fp32 output
    N_BLOCKS,
    BLOCK_SIZE: tl.constexpr,
    NUM_CANDIDATES: tl.constexpr,
    BLOCKS_PER_PROGRAM: tl.constexpr,
):
    # Each program instance handles BLOCKS_PER_PROGRAM consecutive NVFP4 blocks;
    # the mask covers the tail tile when N_BLOCKS is not a multiple of the tile.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCKS_PER_PROGRAM
    block_idx = block_start + tl.arange(0, BLOCKS_PER_PROGRAM)
    block_mask = block_idx < N_BLOCKS

    # Load weights for this tile and pre-compute their absolute values once.
    # The squared error is sign-invariant since FP4 quant preserves sign:
    # (w - w_q)^2 = (|w| - |w_q|)^2 = (|w| - q_mag * scale)^2
    # so we never need ``w`` itself again, dropping a tl.where + negation per element.
    elem_offs = block_idx[:, None] * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :]
    elem_mask = block_mask[:, None]
    w_abs = tl.abs(tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32))

    global_amax = tl.load(global_amax_ptr).to(tl.float32)

    # Running per-block argmin over candidates: loss starts at +inf so the
    # first candidate always wins, making best_idx well-defined.
    best_loss = tl.full([BLOCKS_PER_PROGRAM], float("inf"), dtype=tl.float32)
    best_idx = tl.zeros([BLOCKS_PER_PROGRAM], dtype=tl.int32)

    # Loop over the 126 FP8 candidates (compile-time unrolled).
    # Scales are guaranteed positive and finite (constructed from a positive candidate
    # times nonneg global_amax), so the degenerate-scale guard from nvfp4_scalar_quant is
    # unnecessary apart from the global_amax == 0 case handled below.
    for k in tl.static_range(NUM_CANDIDATES):
        c = tl.load(candidates_ptr + k).to(tl.float32)
        scale = c * global_amax / 6.0
        # Avoid divide-by-zero when global_amax == 0; the resulting err == w_abs² is
        # the same for every candidate, so any best_idx is fine.
        scale_safe = tl.where(scale == 0.0, 1.0, scale)
        q_mag = fp4_round_magnitude(w_abs / scale_safe)
        # NOTE: diff deliberately uses the unsubstituted ``scale`` (not scale_safe)
        # so the zero-amax case reduces to err = w_abs**2 for all candidates.
        diff = w_abs - q_mag * scale
        loss = tl.sum(diff * diff, axis=1)  # [BLOCKS_PER_PROGRAM]
        is_better = loss < best_loss
        best_loss = tl.where(is_better, loss, best_loss)
        best_idx = tl.where(is_better, k, best_idx)

    # Map each block's winning candidate index back to its amax = global_amax * c[best].
    best_c = tl.load(candidates_ptr + best_idx, mask=block_mask, other=0.0).to(tl.float32)
    best_amax = global_amax * best_c
    tl.store(best_amax_ptr + block_idx, best_amax, mask=block_mask)


def nvfp4_fp8_scale_sweep(
    x: torch.Tensor,
    global_amax: torch.Tensor,
    block_size: int = 16,
    candidates: torch.Tensor | None = None,
) -> torch.Tensor:
    """Find the per-block FP8 scale that minimizes NVFP4 quantization MSE.

    Equivalent to the 126-step sweep in :class:`NVFP4MSECalibrator`, but fused into
    a single Triton kernel: each block's weights are read once, every candidate is
    scored in registers, and the running argmin stays on-chip.

    Args:
        x: Weight tensor on CUDA. Total element count must be divisible by
            ``block_size``; layout is treated as a flat ``[N_BLOCKS, BLOCK_SIZE]``.
        global_amax: Scalar FP32 global amax (``= reduce_amax(per_block_amax)``).
        block_size: NVFP4 block size (typically 16).
        candidates: Optional precomputed candidate tensor of shape ``[126]`` (must
            be the FP8 E4M3 valid values divided by 448). Built lazily if omitted.

    Returns:
        ``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``.

    Raises:
        ValueError: If ``x`` is not on CUDA, ``block_size`` is not a positive int,
            ``x.numel()`` is not a multiple of ``block_size``, or ``candidates``
            is not a non-empty 1-D tensor.
    """
    # Validate inputs up front so kernel launch failures can't mask bad args.
    if not x.is_cuda:
        raise ValueError("nvfp4_fp8_scale_sweep requires a CUDA tensor.")
    if not isinstance(block_size, int) or block_size <= 0:
        raise ValueError(f"block_size must be a positive int, got {block_size!r}.")
    if x.numel() % block_size != 0:
        raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).")

    # Build (or normalize) the candidate table on the target device as fp32.
    if candidates is None:
        candidates = fp8_scale_candidates(x.device)
    candidates = candidates.contiguous().to(device=x.device, dtype=torch.float32)
    if candidates.ndim != 1 or candidates.numel() == 0:
        raise ValueError(
            f"candidates must be a non-empty 1-D tensor; got shape {tuple(candidates.shape)}."
        )

    num_blocks = x.numel() // block_size
    flat_weights = x.contiguous().view(-1)
    amax_scalar = global_amax.detach().to(device=x.device, dtype=torch.float32).reshape(1)
    best_amax = torch.empty(num_blocks, dtype=torch.float32, device=x.device)

    # One program per tile of BLOCKS_PER_PROGRAM blocks; the tile size is
    # resolved by the kernel's autotuner, hence the meta-dependent grid.
    def grid(meta):
        return (triton.cdiv(num_blocks, meta["BLOCKS_PER_PROGRAM"]),)

    with torch.cuda.device(x.device):
        _fp8_scale_sweep_kernel[grid](
            flat_weights,
            candidates,
            amax_scalar,
            best_amax,
            num_blocks,
            BLOCK_SIZE=block_size,
            NUM_CANDIDATES=int(candidates.numel()),
        )
    return best_amax
95 changes: 93 additions & 2 deletions modelopt/torch/quantization/calib/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .. import utils as quant_utils
from .calibrator import _Calibrator

__all__ = ["MseCalibrator", "NVFP4MSECalibrator"]
__all__ = ["MseCalibrator", "NVFP4MSECalibrator", "TritonNVFP4MSECalibrator"]


class MseCalibrator(_Calibrator):
Expand Down Expand Up @@ -192,9 +192,100 @@ def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor:
return torch.ones_like(self._initial_amax) * self._global_amax * candidates

def _generate_candidates(self, device: torch.device) -> torch.Tensor:
"""Generate 126 valid FP8 E4M3 scale candidates."""
"""Generate 126 valid FP8 E4M3 scale candidates.

Kept in sync with ``fp8_scale_candidates`` in
``modelopt.torch.kernels.quantization.gemm.nvfp4_fp8_sweep`` — the FP8 E4M3
spec is fixed, and the parity test exercises both paths against each other.
"""
uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device)
fp8_values = uint8_values.view(torch.float8_e4m3fn).float()
valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0)
fp8_values = fp8_values[valid_mask]
return fp8_values / 448.0


class TritonNVFP4MSECalibrator(NVFP4MSECalibrator):
    """Triton-fused FP8 scale sweep calibrator for NVFP4 weight MSE.

    Numerically equivalent to :class:`NVFP4MSECalibrator` but evaluates all 126
    candidates in a single fused Triton kernel — one weight read instead of 126.

    Limitation: a single ``collect()`` call is supported per ``compute_amax`` cycle.
    This matches the static weight-MSE flow (``mse_calibrate``'s weight loop), where
    the calibrator is collected once per weight and immediately consumed. For
    activation calibration (multiple ``collect`` calls), use :class:`NVFP4MSECalibrator`.
    Call :meth:`reset` to free internal state and re-enable :meth:`collect`.
    """

    def __init__(
        self,
        amax: torch.Tensor,
        global_amax: torch.Tensor,
        axis: int | tuple | list | None = None,
        quant_func: Callable | None = None,
        error_func: Callable | None = None,
    ):
        """Initialize the Triton-fused NVFP4 MSE calibrator.

        See :class:`NVFP4MSECalibrator`. ``quant_func``/``error_func`` are unused by
        the kernel path but accepted for API parity. Tile shape and ``num_warps`` are
        autotuned by the kernel per ``N_BLOCKS``.
        """
        super().__init__(
            amax=amax,
            global_amax=global_amax,
            axis=axis,
            quant_func=quant_func,
            error_func=error_func,
        )
        # Stash shape metadata so collect() can keep working after reset() releases
        # the (potentially large) _initial_amax buffer.
        self._initial_amax_shape = tuple(amax.shape)
        self._initial_amax_dtype = amax.dtype
        self._n_blocks = int(amax.numel())
        self._best_amax: torch.Tensor | None = None

    @torch.no_grad()
    def collect(self, x: torch.Tensor):
        """Run the fused FP8 sweep kernel and store the resulting per-block amax.

        Args:
            x: Weight tensor already reshaped to ``[n_blocks, block_size]`` by the
                weight quantizer (see ``TensorQuantizer._process_for_blockquant``).

        Raises:
            RuntimeError: If called twice without an intervening :meth:`reset`.
            ValueError: If ``x`` is not 2-D, has a zero-sized block dimension, or
                its block count does not match the initial amax.
        """
        from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep

        if self._best_amax is not None:
            raise RuntimeError(
                "TritonNVFP4MSECalibrator.collect() is one-shot; call reset() to "
                "discard the previous result before collecting again."
            )

        x = x.detach()
        # Explicit validation (not assert, which is stripped under ``python -O``):
        # the weight quantizer reshapes its input to [n_blocks, block_size] before
        # calling collect (see TensorQuantizer._process_for_blockquant).
        if x.ndim != 2:
            raise ValueError(
                f"Expected x to be [n_blocks, block_size]; got shape {tuple(x.shape)}."
            )
        block_size = x.shape[-1]
        # Guard the division below against a degenerate zero-width block dim.
        if block_size <= 0:
            raise ValueError(f"Expected positive block_size in x.shape[-1], got {block_size}.")
        n_blocks = x.numel() // block_size
        if n_blocks != self._n_blocks:
            raise ValueError(
                f"initial amax.numel() ({self._n_blocks}) does not match the number "
                f"of NVFP4 blocks in x ({n_blocks})."
            )

        best_amax_flat = nvfp4_fp8_scale_sweep(
            x,
            self._global_amax,
            block_size=block_size,
        )
        # Match the original shape/dtype of the initial amax so downstream
        # load_calib_amax behaves identically to the reference path.
        self._best_amax = best_amax_flat.reshape(self._initial_amax_shape).to(
            self._initial_amax_dtype
        )

    @torch.no_grad()
    def compute_amax(self, verbose: bool = False):
        """Return the per-block amax computed during ``collect`` (None if not collected)."""
        return self._best_amax

    def reset(self):
        """Reset the stored best amax. Subsequent ``collect`` calls are allowed."""
        self._best_amax = None
        super().reset()
11 changes: 8 additions & 3 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Calibration utilities."""

import math
import os
import time
import warnings
from collections.abc import Callable
Expand All @@ -37,7 +38,7 @@
from modelopt.torch.utils.distributed import DistributedProcessGroup, ParallelState
from modelopt.torch.utils.network import bind_forward_method, unpatch_forward_method

from .calib import MseCalibrator, NVFP4MSECalibrator, _Calibrator
from .calib import MseCalibrator, NVFP4MSECalibrator, TritonNVFP4MSECalibrator, _Calibrator
from .conversion import create_and_replace_svdquant_linear_on_the_fly, set_quantizer_by_cfg_context
from .nn import NVFP4StaticQuantizer, QuantModule, SequentialQuantizer, TensorQuantizer
from .utils import (
Expand Down Expand Up @@ -354,6 +355,11 @@ def mse_calibrate(
weight_quantizers = []
seen_modules = set()

# Triton-fused FP8 sweep is on by default for NVFP4 static quant; set
# MODELOPT_NVFP4_TRITON_SWEEP=0 to fall back to the reference for debugging.
use_triton_fp8_sweep = os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0"
nvfp4_calibrator_cls = TritonNVFP4MSECalibrator if use_triton_fp8_sweep else NVFP4MSECalibrator

for name, module in list(model.named_modules()):
if isinstance(module, TensorQuantizer) and not module._disabled:
if module._calibrator is not None and not module._dynamic and hasattr(module, "_amax"):
Expand Down Expand Up @@ -391,8 +397,7 @@ def mse_calibrate(
continue

if fp8_scale_sweep and is_nvfp4_static:
# Replace calibrator with NVFP4MSECalibrator
module._calibrator = NVFP4MSECalibrator(
module._calibrator = nvfp4_calibrator_cls(
amax=initial_amax,
axis=module._calibrator._axis,
global_amax=module.global_amax,
Expand Down
Loading
Loading