-
Notifications
You must be signed in to change notification settings - Fork 380
[Quantization] Fused Triton kernel for NVFP4 FP8 scale sweep search #1387
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,166 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Fused Triton kernel for the NVFP4 weight-MSE FP8 scale sweep. | ||
|
|
||
| Replaces the 126-iteration Python sweep in :class:`NVFP4MSECalibrator` with a single | ||
| kernel that, for each NVFP4 block, evaluates all 126 valid FP8 E4M3 scale candidates | ||
| and emits the per-block ``best_amax`` directly. | ||
|
|
||
| The 126 candidates are constructed as ``valid_fp8_e4m3_value / 448`` (see | ||
| :func:`fp8_scale_candidates`). For these specific candidates, the FP8 round-trip on | ||
| the per-block scale is the identity, so the kernel can use | ||
| ``scale = candidate * global_amax / 6.0`` without an explicit FP8 cast — making it | ||
| runnable on any CUDA GPU with Triton (no ``tl.float8e4nv`` requirement). | ||
|
|
||
| Tile shape (``BLOCKS_PER_PROGRAM``) and ``num_warps`` are autotuned per ``N_BLOCKS``. | ||
| """ | ||
|
|
||
| import torch | ||
| import triton | ||
| import triton.language as tl | ||
|
|
||
| from .nvfp4_quant import fp4_round_magnitude | ||
|
|
||
| __all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"] | ||
|
|
||
|
|
||
| def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor: | ||
| """Return the 126 valid finite positive FP8 E4M3 scale candidates / 448.""" | ||
| uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) | ||
| fp8_values = uint8_values.view(torch.float8_e4m3fn).float() | ||
| valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) | ||
| return fp8_values[valid_mask] / 448.0 | ||
|
|
||
|
|
||
# Autotune search space for the sweep kernel, keyed on N_BLOCKS.
# Selected from a (BLOCKS_PER_PROGRAM, num_warps) sweep on B300:
#   BPP=16,nw=2: 6.06 ms   BPP=32,nw=4: 6.06 ms   BPP=64,nw=8: 5.08 ms
# The smaller-tile entries cover cases where N_BLOCKS is small enough that BPP=64
# would underfill the SMs.
_FP8_SWEEP_AUTOTUNE_CONFIGS = [
    triton.Config({"BLOCKS_PER_PROGRAM": 16}, num_warps=2),
    triton.Config({"BLOCKS_PER_PROGRAM": 32}, num_warps=4),
    triton.Config({"BLOCKS_PER_PROGRAM": 64}, num_warps=8),
]
|
|
||
|
|
||
# Fused sweep kernel: each program owns BLOCKS_PER_PROGRAM NVFP4 blocks, loads
# their weights once, and evaluates all NUM_CANDIDATES FP8 scale candidates in
# registers, keeping a running per-block argmin of the squared error.
@triton.autotune(configs=_FP8_SWEEP_AUTOTUNE_CONFIGS, key=["N_BLOCKS"])
@triton.jit
def _fp8_scale_sweep_kernel(
    x_ptr,  # [N_BLOCKS * BLOCK_SIZE], any float dtype (loaded as fp32)
    candidates_ptr,  # [NUM_CANDIDATES] fp32
    global_amax_ptr,  # scalar fp32
    best_amax_ptr,  # [N_BLOCKS] fp32 output
    N_BLOCKS,
    BLOCK_SIZE: tl.constexpr,
    NUM_CANDIDATES: tl.constexpr,
    BLOCKS_PER_PROGRAM: tl.constexpr,
):
    # Map this program to its tile of NVFP4 blocks; the tail program may be
    # partially masked when N_BLOCKS is not a multiple of BLOCKS_PER_PROGRAM.
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCKS_PER_PROGRAM
    block_idx = block_start + tl.arange(0, BLOCKS_PER_PROGRAM)
    block_mask = block_idx < N_BLOCKS

    # Load weights for this tile and pre-compute their absolute values once.
    # The squared error is sign-invariant since FP4 quant preserves sign:
    #   (w - w_q)^2 = (|w| - |w_q|)^2 = (|w| - q_mag * scale)^2
    # so we never need ``w`` itself again, dropping a tl.where + negation per element.
    elem_offs = block_idx[:, None] * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)[None, :]
    elem_mask = block_mask[:, None]  # masked rows load 0.0 and contribute no loss signal
    w_abs = tl.abs(tl.load(x_ptr + elem_offs, mask=elem_mask, other=0.0).to(tl.float32))

    global_amax = tl.load(global_amax_ptr).to(tl.float32)

    # Running argmin state: best squared error and winning candidate index per block.
    best_loss = tl.full([BLOCKS_PER_PROGRAM], float("inf"), dtype=tl.float32)
    best_idx = tl.zeros([BLOCKS_PER_PROGRAM], dtype=tl.int32)

    # Loop over the 126 FP8 candidates (compile-time unrolled).
    # Scales are guaranteed positive and finite (constructed from a positive candidate
    # times nonneg global_amax), so the degenerate-scale guard from nvfp4_scalar_quant is
    # unnecessary apart from the global_amax == 0 case handled below.
    for k in tl.static_range(NUM_CANDIDATES):
        c = tl.load(candidates_ptr + k).to(tl.float32)
        scale = c * global_amax / 6.0
        # Avoid divide-by-zero when global_amax == 0; the resulting err == w_abs² is
        # the same for every candidate, so any best_idx is fine.
        scale_safe = tl.where(scale == 0.0, 1.0, scale)
        q_mag = fp4_round_magnitude(w_abs / scale_safe)
        diff = w_abs - q_mag * scale
        loss = tl.sum(diff * diff, axis=1)  # [BLOCKS_PER_PROGRAM]
        is_better = loss < best_loss
        best_loss = tl.where(is_better, loss, best_loss)
        best_idx = tl.where(is_better, k, best_idx)

    # Map each block's winning candidate index back to its amax = global_amax * c[best].
    best_c = tl.load(candidates_ptr + best_idx, mask=block_mask, other=0.0).to(tl.float32)
    best_amax = global_amax * best_c
    tl.store(best_amax_ptr + block_idx, best_amax, mask=block_mask)
|
|
||
|
|
||
| def nvfp4_fp8_scale_sweep( | ||
| x: torch.Tensor, | ||
| global_amax: torch.Tensor, | ||
| block_size: int = 16, | ||
| candidates: torch.Tensor | None = None, | ||
| ) -> torch.Tensor: | ||
| """Find the per-block FP8 scale that minimizes NVFP4 quantization MSE. | ||
|
|
||
| Equivalent to the 126-step sweep in :class:`NVFP4MSECalibrator`, but fused into | ||
| a single Triton kernel: every block's weight elements are loaded once, all 126 | ||
| candidates are evaluated in registers, and the running argmin is kept inline. | ||
|
|
||
| Args: | ||
| x: Weight tensor on CUDA. Total element count must be divisible by | ||
| ``block_size``; layout is treated as a flat ``[N_BLOCKS, BLOCK_SIZE]``. | ||
| global_amax: Scalar FP32 global amax (``= reduce_amax(per_block_amax)``). | ||
| block_size: NVFP4 block size (typically 16). | ||
| candidates: Optional precomputed candidate tensor of shape ``[126]`` (must | ||
| be the FP8 E4M3 valid values divided by 448). Built lazily if omitted. | ||
|
|
||
| Returns: | ||
| ``best_amax`` of shape ``[N_BLOCKS]``, fp32, on the same device as ``x``. | ||
| """ | ||
| if not x.is_cuda: | ||
| raise ValueError("nvfp4_fp8_scale_sweep requires a CUDA tensor.") | ||
| if not isinstance(block_size, int) or block_size <= 0: | ||
| raise ValueError(f"block_size must be a positive int, got {block_size!r}.") | ||
| if x.numel() % block_size != 0: | ||
| raise ValueError(f"x.numel() ({x.numel()}) is not divisible by block_size ({block_size}).") | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
|
|
||
| if candidates is None: | ||
| candidates = fp8_scale_candidates(x.device) | ||
| candidates = candidates.contiguous().to(device=x.device, dtype=torch.float32) | ||
| if candidates.ndim != 1 or candidates.numel() == 0: | ||
| raise ValueError( | ||
| f"candidates must be a non-empty 1-D tensor; got shape {tuple(candidates.shape)}." | ||
| ) | ||
|
|
||
| n_blocks = x.numel() // block_size | ||
| x_flat = x.contiguous().view(-1) | ||
| global_amax_f32 = global_amax.detach().to(device=x.device, dtype=torch.float32).reshape(1) | ||
| best_amax = torch.empty(n_blocks, dtype=torch.float32, device=x.device) | ||
|
|
||
| grid = lambda meta: (triton.cdiv(n_blocks, meta["BLOCKS_PER_PROGRAM"]),) | ||
| with torch.cuda.device(x.device): | ||
| _fp8_scale_sweep_kernel[grid]( | ||
| x_flat, | ||
| candidates, | ||
| global_amax_f32, | ||
| best_amax, | ||
| n_blocks, | ||
| BLOCK_SIZE=block_size, | ||
| NUM_CANDIDATES=int(candidates.numel()), | ||
| ) | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| return best_amax | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,7 +24,7 @@ | |
| from .. import utils as quant_utils | ||
| from .calibrator import _Calibrator | ||
|
|
||
| __all__ = ["MseCalibrator", "NVFP4MSECalibrator"] | ||
| __all__ = ["MseCalibrator", "NVFP4MSECalibrator", "TritonNVFP4MSECalibrator"] | ||
|
|
||
|
|
||
| class MseCalibrator(_Calibrator): | ||
|
|
@@ -192,9 +192,100 @@ def _compute_candidate_amax(self, candidates: torch.Tensor) -> torch.Tensor: | |
| return torch.ones_like(self._initial_amax) * self._global_amax * candidates | ||
|
|
||
| def _generate_candidates(self, device: torch.device) -> torch.Tensor: | ||
| """Generate 126 valid FP8 E4M3 scale candidates.""" | ||
| """Generate 126 valid FP8 E4M3 scale candidates. | ||
|
|
||
| Kept in sync with ``fp8_scale_candidates`` in | ||
| ``modelopt.torch.kernels.quantization.gemm.nvfp4_fp8_sweep`` — the FP8 E4M3 | ||
| spec is fixed, and the parity test exercises both paths against each other. | ||
| """ | ||
| uint8_values = torch.arange(0, 128, dtype=torch.uint8, device=device) | ||
| fp8_values = uint8_values.view(torch.float8_e4m3fn).float() | ||
| valid_mask = torch.isfinite(fp8_values) & (fp8_values > 0) | ||
| fp8_values = fp8_values[valid_mask] | ||
| return fp8_values / 448.0 | ||
|
|
||
|
|
||
class TritonNVFP4MSECalibrator(NVFP4MSECalibrator):
    """Triton-fused FP8 scale sweep calibrator for NVFP4 weight MSE.

    Numerically equivalent to :class:`NVFP4MSECalibrator` but evaluates all 126
    candidates in a single fused Triton kernel — one weight read instead of 126.

    Limitation: a single ``collect()`` call is supported per ``compute_amax`` cycle.
    This matches the static weight-MSE flow (``mse_calibrate``'s weight loop), where
    the calibrator is collected once per weight and immediately consumed. For
    activation calibration (multiple ``collect`` calls), use :class:`NVFP4MSECalibrator`.
    Call :meth:`reset` to free internal state and re-enable :meth:`collect`.
    """

    def __init__(
        self,
        amax: torch.Tensor,
        global_amax: torch.Tensor,
        axis: int | tuple | list | None = None,
        quant_func: Callable | None = None,
        error_func: Callable | None = None,
    ):
        """Initialize the Triton-fused NVFP4 MSE calibrator.

        See :class:`NVFP4MSECalibrator`. ``quant_func``/``error_func`` are unused by
        the kernel path but accepted for API parity. Tile shape and ``num_warps`` are
        autotuned by the kernel per ``N_BLOCKS``.
        """
        super().__init__(
            amax=amax,
            global_amax=global_amax,
            axis=axis,
            quant_func=quant_func,
            error_func=error_func,
        )
        # Stash shape metadata so collect() can keep working after reset() releases
        # the (potentially large) _initial_amax buffer.
        self._initial_amax_shape = tuple(amax.shape)
        self._initial_amax_dtype = amax.dtype
        self._n_blocks = int(amax.numel())
        self._best_amax: torch.Tensor | None = None

    @torch.no_grad()
    def collect(self, x: torch.Tensor):
        """Run the fused FP8 sweep kernel and store the resulting per-block amax.

        Raises:
            RuntimeError: If a result from a previous ``collect`` is still pending.
            ValueError: If ``x`` is not 2-D or its block count disagrees with the
                initial amax.
        """
        from modelopt.torch.kernels.quantization.gemm import nvfp4_fp8_scale_sweep

        if self._best_amax is not None:
            raise RuntimeError(
                "TritonNVFP4MSECalibrator.collect() is one-shot; call reset() to "
                "discard the previous result before collecting again."
            )

        x = x.detach()
        # The weight quantizer reshapes its input to [n_blocks, block_size] before
        # calling collect (see TensorQuantizer._process_for_blockquant). Validate with
        # an explicit raise rather than ``assert`` so the check survives ``python -O``.
        if x.ndim != 2:
            raise ValueError(
                f"Expected x to be [n_blocks, block_size]; got shape {tuple(x.shape)}."
            )
        block_size = x.shape[-1]
        n_blocks = x.numel() // block_size
        if n_blocks != self._n_blocks:
            raise ValueError(
                f"initial amax.numel() ({self._n_blocks}) does not match the number "
                f"of NVFP4 blocks in x ({n_blocks})."
            )

        best_amax_flat = nvfp4_fp8_scale_sweep(
            x,
            self._global_amax,
            block_size=block_size,
        )
        # Match the original shape/dtype of the initial amax so downstream
        # load_calib_amax behaves identically to the reference path.
        self._best_amax = best_amax_flat.reshape(self._initial_amax_shape).to(
            self._initial_amax_dtype
        )

    @torch.no_grad()
    def compute_amax(self, verbose: bool = False):
        """Return the per-block amax computed during ``collect`` (None if not collected)."""
        return self._best_amax

    def reset(self):
        """Reset the stored best amax. Subsequent ``collect`` calls are allowed."""
        self._best_amax = None
        super().reset()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor duplication: this function reproduces the same logic as
NVFP4MSECalibrator._generate_candidates()incalib/mse.py. Consider having one call the other (or extracting a shared utility) so the candidate generation stays in sync if the candidate set ever changes.