[Quantization] Fused Triton kernel for NVFP4 FP8 scale sweep search #1387
Conversation
Replaces the 126-iteration Python sweep in NVFP4MSECalibrator with a single fused Triton kernel that loads each NVFP4 block once, evaluates all 126 valid FP8 E4M3 scale candidates in registers, and emits the per-block best amax directly. For our specific candidate set (FP8 representable values / 448) the FP8 round-trip on the per-block scale is the identity, so the kernel uses `scale = candidate * global_amax / 6.0` and runs on any CUDA device with Triton. The Triton-backed calibrator is on by default for `mse_calibrate(... fp8_scale_sweep=True)`; set `MODELOPT_NVFP4_TRITON_SWEEP=0` to fall back to the reference for debugging. Measured ~7.4x speedup on a B300 over the reference NVFP4MSECalibrator (8192x4096 weight, ~2M NVFP4 blocks: 176.67 ms -> 23.81 ms). Bit-identical to the reference for typical block counts; on multi-million-block weights an occasional adjacent-candidate tie-break can differ at the fp32-noise level (observed 2 / 2,097,152 blocks; per-block MSE within 1e-7 relative).

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
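For orientation, a minimal PyTorch sketch of the unfused reference sweep this kernel replaces. This is a hedged illustration only: the function names, the nearest-value FP4 rounding helper, and the per-block mean-squared-error metric are assumptions, not the actual NVFP4MSECalibrator internals.

```python
import torch


def _fp4_quant_magnitude(w_abs: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in for NVFP4 quantization of |w|: snap |w| / scale onto the
    # positive E2M1 grid {0, 0.5, 1, 1.5, 2, 3, 4, 5, 6} and rescale.
    grid = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 5.0, 6.0], device=w_abs.device)
    x = (w_abs / scale).clamp(max=6.0)
    idx = (x.unsqueeze(-1) - grid).abs().argmin(dim=-1)
    return grid[idx] * scale


def reference_fp8_scale_sweep(x, global_amax, candidates):
    """Unfused sweep: one full pass over the weight per candidate (what the Triton kernel fuses away)."""
    w_abs = x.abs()  # x: [n_blocks, block_size]; candidates: the 126 values (fp8_e4m3_value / 448)
    best_mse = torch.full((x.shape[0],), float("inf"), device=x.device)
    best_amax = torch.zeros(x.shape[0], device=x.device)
    for c in candidates:
        scale = c * global_amax / 6.0                    # per-block scale for this candidate
        w_q = _fp4_quant_magnitude(w_abs, scale)
        mse = ((w_abs - w_q) ** 2).mean(dim=-1)          # per-block error
        better = mse < best_mse
        best_mse = torch.where(better, mse, best_mse)
        best_amax = torch.where(better, c * global_amax, best_amax)
    return best_amax
```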
📝 Walkthrough

Adds a Triton-fused NVFP4 FP8 scale-sweep kernel, a Triton-backed NVFP4 MSE calibrator, re-exports the kernel symbols, makes the calibrator selectable via an env var, and adds GPU tests for parity, dtype coverage, validation, state semantics, and performance.

Changes

Triton kernel-based NVFP4 FP8 scale sweep
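The re-export mentioned in the walkthrough presumably looks something like the following in the gemm package __init__.py; the exact symbol list is an assumption.

```python
from .nvfp4_fp8_sweep import fp8_scale_candidates, nvfp4_fp8_scale_sweep

__all__ = ["fp8_scale_candidates", "nvfp4_fp8_scale_sweep"]
```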
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User
    participant Calibrator as TritonNVFP4MSECalibrator
    participant Wrapper as nvfp4_fp8_scale_sweep
    participant Kernel as _fp8_scale_sweep_kernel
    User->>Calibrator: collect(x)
    Calibrator->>Wrapper: nvfp4_fp8_scale_sweep(x, global_amax, block_size, candidates)
    Wrapper->>Wrapper: validate inputs, generate/cast candidates, flatten to blocks
    Wrapper->>Kernel: launch _fp8_scale_sweep_kernel (autotuned)
    Kernel->>Kernel: load per-tile |w|, for each candidate (static loop): quantize magnitude, compute MSE, track argmin
    Kernel-->>Wrapper: return best_amax per block
    Wrapper-->>Calibrator: best_amax
    Calibrator->>Calibrator: store _best_amax
    User->>Calibrator: compute_amax()
    Calibrator-->>User: return _best_amax
```
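A rough usage sketch of the flow above (the constructor signature and tensor shapes here are assumptions for illustration, not the exact repo API):

```python
import torch

from modelopt.torch.quantization.calib.mse import TritonNVFP4MSECalibrator

x = torch.randn(2048, 16, device="cuda")       # [n_blocks, block_size], as the weight quantizer reshapes it
global_amax = x.abs().amax()

calib = TritonNVFP4MSECalibrator(global_amax)  # assumed constructor; the real one may take more arguments
calib.collect(x)                               # single fused Triton sweep over all 126 candidates
best_amax = calib.compute_amax()               # per-block best amax, shape [n_blocks]
```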
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Important: Pre-merge checks failed. Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 error, 1 warning)
✅ Passed checks (4 passed)
meenchen left a comment (bot review)
Clean, well-structured Triton kernel for speeding up the NVFP4 FP8 scale sweep. The implementation correctly reuses the existing nvfp4_scalar_quant JIT function from nvfp4_quant.py, the math insight about the FP8 round-trip identity for the candidate set is sound, and test coverage is solid (15 GPU tests covering parity, dtypes, round-trip, reset, and speedup).
A few points for the owner:
- Unchecked test plan items: The PR body has two unchecked items — H100/A100 run and end-to-end PTQ on a 70B model. Per project norms, these should be completed before merge.
- Minor code duplication: `fp8_scale_candidates()` in `nvfp4_fp8_sweep.py` duplicates `NVFP4MSECalibrator._generate_candidates()`. Consider having one call the other (or extracting a shared utility) to keep the candidate generation logic in one place.
- `local_hessian_calibrate` not using the Triton path: This function still uses `NVFP4MSECalibrator` directly (not `TritonNVFP4MSECalibrator`), which is correct since it needs a custom `error_func`. Worth adding a comment there noting that the Triton path doesn't support custom error functions, so someone doesn't "helpfully" switch it later.
- `collect` assumes `x.shape[-1]` is the block size: This works for the current MSE weight calibration flow where the tensor is pre-reshaped, but could be fragile if the calibrator is used in a different context. A brief assert or docstring note would help.
```python
block_size = x.shape[-1]
n_blocks = x.numel() // block_size
if self._initial_amax.numel() != n_blocks:
    raise ValueError(
```
Bot comment.
Nit: block_size = x.shape[-1] assumes the input tensor has already been reshaped to [n_blocks, block_size]. This is true for the current mse_calibrate weight flow, but could silently produce wrong results if someone uses this calibrator with a differently-shaped tensor. Consider adding a brief assertion or docstring note, e.g.:
```python
assert x.ndim == 2, "Expected x to be [n_blocks, block_size] from the weight quantizer reshape"
```
```python
def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor:
    """Return the 126 valid finite positive FP8 E4M3 scale candidates / 448."""
```
Bot comment.
Minor duplication: this function reproduces the same logic as NVFP4MSECalibrator._generate_candidates() in calib/mse.py. Consider having one call the other (or extracting a shared utility) so the candidate generation stays in sync if the candidate set ever changes.
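For reference, one way to enumerate that candidate set in plain PyTorch (a hedged sketch, not the repo's fp8_scale_candidates or _generate_candidates; it assumes torch.float8_e4m3fn is available):

```python
import torch


def fp8_e4m3_scale_candidates(device: str = "cpu") -> torch.Tensor:
    # Enumerate all 256 E4M3 bit patterns via a uint8 view, keep the finite,
    # strictly positive ones (126 values), and divide by the E4M3 max (448).
    bits = torch.arange(256, dtype=torch.uint8, device=device)
    vals = bits.view(torch.float8_e4m3fn).float()
    pos_finite = vals[torch.isfinite(vals) & (vals > 0)]
    return pos_finite.sort().values / 448.0
```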
```diff
-# Replace calibrator with NVFP4MSECalibrator
-module._calibrator = NVFP4MSECalibrator(
+# Replace calibrator with the fused Triton sweep kernel by default
+# (single-shot collect, ~7-20x faster for the weight-MSE phase).
```
Bot comment.
The env var check os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0" is evaluated on every weight quantizer in the loop. Since it won't change mid-loop, consider hoisting it above the loop for clarity and minor efficiency.
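The hoisting suggested above could look roughly like this (a hedged sketch; the helper name is illustrative):

```python
import os

from modelopt.torch.quantization.calib.mse import NVFP4MSECalibrator, TritonNVFP4MSECalibrator


def pick_sweep_calibrator_cls():
    # Read the opt-out flag once, before the per-quantizer loop, instead of per iteration.
    use_triton = os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0"
    return TritonNVFP4MSECalibrator if use_triton else NVFP4MSECalibrator
```

mse_calibrate would then call this once and instantiate the returned class for every weight quantizer, rather than re-reading the environment inside the loop.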
Actionable comments posted: 1
🧹 Nitpick comments (1)
modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py (1)
72-74: 💤 Low value. Consider hoisting candidate loads outside the loop.

The candidate value is loaded inside `tl.static_range`, which means 126 separate scalar loads per program invocation. Since `candidates_ptr` points to shared read-only data, you could load all candidates into a register vector once before the loop for better memory efficiency.

That said, `tl.static_range` unrolls at compile time and Triton's compiler may already optimize repeated scalar loads. This is a minor optimization suggestion.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py` around lines 72 - 74, The loop uses tl.static_range over NUM_CANDIDATES and calls tl.load(candidates_ptr + k) inside each iteration (producing many scalar loads); hoist these loads by reading the entire candidates_ptr into a local vector/array (e.g., candidates_arr) once before the tl.static_range loop and then use candidates_arr[k] (or the equivalent register lookup) inside the loop; update references to the temporary c to load from the prefilled candidates_arr instead of calling tl.load each iteration (keep names NUM_CANDIDATES, candidates_ptr, tl.static_range, and c to locate the code).
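A minimal sketch of that hoisting pattern (an illustrative kernel, not the repo's _fp8_scale_sweep_kernel); note that tl.arange needs a power-of-two length, so the 126 candidates get padded to 128 with a mask:

```python
import triton
import triton.language as tl


@triton.jit
def _hoisted_candidates_sketch(candidates_ptr, out_ptr, NUM_CANDIDATES: tl.constexpr):
    PAD: tl.constexpr = 128  # next power of two above 126
    idx = tl.arange(0, PAD)
    # One vectorized load into registers instead of NUM_CANDIDATES scalar loads inside the loop.
    cand_vec = tl.load(candidates_ptr + idx, mask=idx < NUM_CANDIDATES, other=0.0)
    acc = 0.0
    for k in tl.static_range(NUM_CANDIDATES):
        # Register lookup of candidate k (stand-in for the per-candidate MSE body).
        c = tl.sum(tl.where(idx == k, cand_vec, 0.0))
        acc += c
    tl.store(out_ptr, acc)
```

Whether this actually beats the compiler's own handling of the unrolled tl.static_range loads is an empirical question, as the comment above notes.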
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py`:
- Around line 117-119: Replace the assert-based CUDA check with consistent
exception raising: in the nvfp4_fp8_scale_sweep (or the function containing the
lines checking x.is_cuda and block_size), change the assert x.is_cuda line to
raise a ValueError (or a module-specific custom exception) with a clear message
like "nvfp4_fp8_scale_sweep requires a CUDA tensor" so both the CUDA check and
the block_size divisibility check use the same error style and won't be removed
by python -O; keep the existing block_size check and message for x.numel()
as-is.
---
Nitpick comments:
In `@modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py`:
- Around line 72-74: The loop uses tl.static_range over NUM_CANDIDATES and calls
tl.load(candidates_ptr + k) inside each iteration (producing many scalar loads);
hoist these loads by reading the entire candidates_ptr into a local vector/array
(e.g., candidates_arr) once before the tl.static_range loop and then use
candidates_arr[k] (or the equivalent register lookup) inside the loop; update
references to the temporary c to load from the prefilled candidates_arr instead
of calling tl.load each iteration (keep names NUM_CANDIDATES, candidates_ptr,
tl.static_range, and c to locate the code).
📒 Files selected for processing (5)
- modelopt/torch/kernels/quantization/gemm/__init__.py
- modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
- modelopt/torch/quantization/calib/mse.py
- modelopt/torch/quantization/model_calib.py
- tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py
…ner loop

Two follow-on optimizations to the fused FP8 scale sweep kernel:

1. @triton.autotune over (BLOCKS_PER_PROGRAM, num_warps): a hand-sweep on B300 showed the previous default (BPP=4, num_warps=4) at 23.7 ms left ~4x on the table — best config (BPP=64, num_warps=8) lands at ~5 ms. Three configs are included to cover small/medium/large N_BLOCKS without flooding compile time.

2. Drop the sign-handling tl.where: since FP4 quantization preserves sign, (w - w_q)^2 == (|w| - |w_q|)^2, so the kernel works on |w| throughout and skips one tl.where + negation per element per candidate.

Result on the same 8192x4096 weight (~2M blocks) on B300:
- reference NVFP4MSECalibrator: 176.68 ms
- triton TritonNVFP4MSECalibrator: 4.23 ms
- speedup: 41.8x (was 7.4x)

This is ~1.2x above the rough pure-compute floor (~240 GF / 67 TF/s ~= 3.6 ms), so the kernel is now near saturation and further wins would need an algorithmic change (candidate pruning, etc.).

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
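The sign-identity claim in point 2 is easy to sanity-check numerically (a hedged sketch; the coarse rounding below is only a stand-in for FP4, chosen because it preserves sign):

```python
import torch

w = torch.randn(4096)
# Any sign-preserving quantizer works for the identity; round |w| to a coarse grid as a stand-in.
w_q = torch.sign(w) * (w.abs() * 4).round() / 4
assert torch.allclose((w - w_q) ** 2, (w.abs() - w_q.abs()) ** 2)
```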
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py`:
- Around line 139-158: The code does not validate a custom candidates tensor
before launching _fp8_scale_sweep_kernel, allowing empty or malformed inputs
which make NUM_CANDIDATES zero and cause out-of-bounds reads (e.g.,
candidates_ptr[0]); fix by validating candidates returned or passed into the
function: if candidates is not None ensure it's a 1-D tensor, non-empty, finite,
positive, and has the expected length (ideally 126) or otherwise raise a clear
error; if candidates is None keep using fp8_scale_candidates(x.device) as
before; perform this validation before calling candidates.contiguous().to(...)
and before computing NUM_CANDIDATES/launching _fp8_scale_sweep_kernel so the
kernel never receives an invalid candidate tensor.
- Around line 112-137: In nvfp4_fp8_scale_sweep validate the public parameter
block_size before using it: add an explicit check that block_size is a positive
integer (e.g. if not isinstance(block_size, int) or block_size <= 0: raise
ValueError(...)) at the start of the function, and only afterwards perform
operations that use it (such as x.numel() % block_size) so you avoid
ZeroDivisionError for 0 and invalid negative values that would otherwise produce
bad kernel launches.
In `@modelopt/torch/quantization/calib/mse.py`:
- Around line 243-247: The reset() implementation in TritonNVFP4MSECalibrator
currently calls MseCalibrator.reset(), which clears _initial_amax and makes the
next collect() dereference None; update TritonNVFP4MSECalibrator.reset() to
either preserve _initial_amax (do not delete or set _initial_amax to its prior
tensor) or reinitialize it to a valid tensor/zero-sized tensor so collect() can
safely call self._initial_amax.numel(), and ensure _best_amax is also reset
consistently; apply the same fix to the other reset override referenced around
the second occurrence (similar block at lines ~274-277) so both reset overrides
in TritonNVFP4MSECalibrator maintain the contract expected by collect().
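The two nvfp4_fp8_sweep.py items above amount to roughly the following checks at the top of the wrapper (a hedged sketch; the helper name and error messages are illustrative):

```python
import torch


def _validate_sweep_inputs(x: torch.Tensor, block_size: int, candidates: torch.Tensor | None) -> None:
    if not x.is_cuda:
        raise ValueError("nvfp4_fp8_scale_sweep requires a CUDA tensor")
    if not isinstance(block_size, int) or block_size <= 0:
        raise ValueError(f"block_size must be a positive int, got {block_size!r}")
    if x.numel() % block_size != 0:
        raise ValueError(f"x.numel()={x.numel()} is not divisible by block_size={block_size}")
    if candidates is not None and (candidates.ndim != 1 or candidates.numel() == 0):
        raise ValueError("candidates must be a non-empty 1-D tensor")
```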
📒 Files selected for processing (2)
- modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
- modelopt/torch/quantization/calib/mse.py
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@           Coverage Diff            @@
##             main    #1387    +/-  ##
========================================
- Coverage   76.90%   76.86%   -0.04%
========================================
  Files         471      472       +1
  Lines       50562    50660      +98
========================================
+ Hits        38886    38942      +56
- Misses      11676    11718      +42
```
Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
Addresses review comments on PR #1387:

- TritonNVFP4MSECalibrator.reset() now leaves the calibrator reusable: shape / dtype / n_blocks of the initial amax are stashed in __init__, so collect() no longer depends on _initial_amax surviving reset(). Adds an x.ndim == 2 assertion in collect() since the weight quantizer always reshapes upstream.
- nvfp4_fp8_scale_sweep validates inputs cleanly instead of using assert (which is stripped by python -O): rejects non-CUDA tensors, non-positive block_size, and empty / non-1D candidates with ValueError. Skips the per-element finite/positive check on candidates since it would scan a 126-entry tensor on every kernel call.
- mse_calibrate hoists the MODELOPT_NVFP4_TRITON_SWEEP env-var lookup out of the per-quantizer loop and resolves to the calibrator class once.
- Updates test_reset_allows_recollect to verify the new reuse contract; adds test_input_validation covering the new ValueErrors.

The duplicate fp8_scale_candidates implementation in the kernel file and NVFP4MSECalibrator._generate_candidates() is left in place: deduplicating would force the reference path to import from the kernel module, which is gated behind Triton availability. The FP8 E4M3 spec is fixed and the parity test exercises both paths against each other.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
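A hedged sketch of the reset contract described in the first bullet (class and attribute names are illustrative, not the exact calib/mse.py code):

```python
import torch


class _ResetContractSketch:
    def __init__(self, initial_amax: torch.Tensor):
        # Stash shape / dtype / block-count metadata so collect() can still validate
        # its input after reset() clears the collected state.
        self._n_blocks = initial_amax.numel()
        self._amax_dtype = initial_amax.dtype
        self._amax_device = initial_amax.device
        self._best_amax: torch.Tensor | None = None

    def reset(self) -> None:
        # Only the collected result is cleared; the metadata above survives,
        # so a subsequent collect() works on the freshly reset calibrator.
        self._best_amax = None
```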
meenchen left a comment (bot review)
All critical issues from the previous review have been addressed:

- ✅ `assert x.is_cuda` → `raise ValueError`
- ✅ `block_size <= 0` validation added
- ✅ Custom `candidates` tensor validation added
- ✅ `reset()` reusability fixed (metadata stashed in `__init__`, tested in `test_reset_allows_recollect`)
- ✅ Env var hoisted above the loop
- ✅ `x.ndim == 2` assertion added in `collect()`
- ✅ Candidate duplication addressed with sync docstring in `_generate_candidates`
The kernel implementation is clean, the math insight about the FP8 round-trip identity is sound, and test coverage is solid (15 GPU tests covering parity, dtypes, round-trip, reset, input validation, and speedup).
Remaining concern: The PR body still has two unchecked test plan items — H100/A100 validation and end-to-end PTQ on a 70B model. Per project norms, these should be completed before merge. Nudging for human sign-off on that.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/quantization/calib/mse.py`:
- Around line 260-266: In collect(), replace the runtime assertion with explicit
input validation: check that the input tensor x has ndim == 2 and raise a
ValueError with a clear message if not; check block_size = x.shape[-1] is > 0
and raise ValueError if it is zero to avoid ZeroDivisionError; compute n_blocks
= x.numel() // block_size only after these checks and if n_blocks !=
self._n_blocks raise a ValueError describing the mismatch (referencing the
collect method and TensorQuantizer._process_for_blockquant behavior to explain
expected [n_blocks, block_size] shape).
📒 Files selected for processing (4)
- modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
- modelopt/torch/quantization/calib/mse.py
- modelopt/torch/quantization/model_calib.py
- tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py
✅ Files skipped from review due to trivial changes (2)
- modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
- modelopt/torch/quantization/model_calib.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py
```python
x = x.detach()
# The weight quantizer reshapes its input to [n_blocks, block_size] before
# calling collect (see TensorQuantizer._process_for_blockquant).
assert x.ndim == 2, f"Expected x to be [n_blocks, block_size]; got shape {tuple(x.shape)}."
block_size = x.shape[-1]
n_blocks = x.numel() // block_size
if n_blocks != self._n_blocks:
```
🧩 Analysis chain
🏁 Script executed:

```bash
#!/bin/bash
# Verify assertion-based runtime validation and block-size handling in this file.
rg -n -C3 'assert x\.ndim|block_size = x\.shape\[-1\]|x\.numel\(\) // block_size' modelopt/torch/quantization/calib/mse.py
```

Length of output: 612

🏁 Script executed:

```bash
#!/bin/bash
# Get broader context around the collect() method to understand input contracts
sed -n '240,290p' modelopt/torch/quantization/calib/mse.py
```

Length of output: 2215

🏁 Script executed:

```bash
#!/bin/bash
# Check the class definition and docstring
sed -n '200,260p' modelopt/torch/quantization/calib/mse.py
```

Length of output: 2630

🏁 Script executed:

```bash
#!/bin/bash
# Verify SPDX header and file structure
head -20 modelopt/torch/quantization/calib/mse.py
```

Length of output: 866

🏁 Script executed:

```bash
#!/bin/bash
# Search for other similar patterns in the file (assert for shape validation)
rg -n 'assert.*ndim|assert.*shape' modelopt/torch/quantization/calib/mse.py
```

Length of output: 170
Replace assert with explicit input validation in collect().
Line 263 uses assert for runtime input checks, which can be stripped with optimized Python, and n_blocks = x.numel() // block_size can throw ZeroDivisionError when x.shape[-1] == 0. Use explicit ValueError checks for shape and block_size before division.
Suggested patch:

```diff
 x = x.detach()
 # The weight quantizer reshapes its input to [n_blocks, block_size] before
 # calling collect (see TensorQuantizer._process_for_blockquant).
-assert x.ndim == 2, f"Expected x to be [n_blocks, block_size]; got shape {tuple(x.shape)}."
+if x.ndim != 2:
+    raise ValueError(
+        f"Expected x to be [n_blocks, block_size]; got shape {tuple(x.shape)}."
+    )
 block_size = x.shape[-1]
+if block_size <= 0:
+    raise ValueError(f"Expected positive block_size in x.shape[-1], got {block_size}.")
 n_blocks = x.numel() // block_size
```

This is a GPU kernel calibration path (nvfp4_fp8_scale_sweep). Per SECURITY.md: "Apply defensive input validation to prevent crashes/resource exhaustion… validate … shapes/sizes … early."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@modelopt/torch/quantization/calib/mse.py` around lines 260 - 266, In
collect(), replace the runtime assertion with explicit input validation: check
that the input tensor x has ndim == 2 and raise a ValueError with a clear
message if not; check block_size = x.shape[-1] is > 0 and raise ValueError if it
is zero to avoid ZeroDivisionError; compute n_blocks = x.numel() // block_size
only after these checks and if n_blocks != self._n_blocks raise a ValueError
describing the mismatch (referencing the collect method and
TensorQuantizer._process_for_blockquant behavior to explain expected [n_blocks,
block_size] shape).
meenchen left a comment (bot review)
All critical issues from previous reviews have been addressed:
- `assert x.is_cuda` replaced with `raise ValueError`
- `block_size <= 0` validation added before use
- Custom `candidates` tensor validation (non-empty, 1-D) added
- `reset()` reusability fixed — shape/dtype/n_blocks metadata stashed in `__init__`, verified by `test_reset_allows_recollect`
- Env var read hoisted above the weight quantizer loop
- `x.ndim == 2` shape check added in `collect()`
- Candidate duplication addressed with sync docstring
The kernel implementation is correct — the FP8 round-trip identity insight is sound, input validation is thorough, and test coverage is solid (16 GPU tests covering parity across seeds/dtypes/block counts, round-trip, reset, input validation, and speedup reporting). License headers match the canonical LICENSE_HEADER.
Summary

- Replaces the 126-iteration Python sweep in `NVFP4MSECalibrator` with a fused Triton kernel that loads each NVFP4 block once, evaluates all 126 valid FP8 E4M3 scale candidates in registers, and emits the per-block `best_amax` directly.
- `TritonNVFP4MSECalibrator` is the default for `mse_calibrate(..., fp8_scale_sweep=True)`. Set `MODELOPT_NVFP4_TRITON_SWEEP=0` to fall back to the reference for debugging or numerics comparison.
- ~42x speedup over the reference (`NVFP4MSECalibrator`) on a representative LLM weight (8192x4096, ~2M NVFP4 blocks): 176.68 ms -> 4.23 ms.
- End-to-end PTQ (`hf_ptq.py --qformat nvfp4_mse`, calib=128): ~6.7x faster `mtq.quantize` with identical global weight-MSE as the reference.

End-to-end Qwen3-8B PTQ (B300)

Settings: `--calib_size 128 --calib_seq 512`, default `nvfp4_mse` config (`NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG`). The first `nvfp4` run was discarded as a warm-up for HF weight loading.

Configurations measured (`mtq.quantize` time): `nvfp4` (no MSE search), `nvfp4_mse` (reference, slow; `MODELOPT_NVFP4_TRITON_SWEEP=0`), and `nvfp4_mse` (Triton, fast — this PR). The Triton `nvfp4_mse` is 6.67x faster than the reference `nvfp4_mse`, and adds only ~3 s over the no-MSE baseline (vs ~40 s for the reference) — making the MSE-search option nearly free in practice.

Per-layer MSE CSVs are produced by `tools/debugger/compare_mse_qwen.py` (one row per Linear weight) for closer inspection if needed.

Microbenchmark (B300, 8192x4096 weight, ~2M NVFP4 blocks)

Reference `NVFP4MSECalibrator`: 176.68 ms; Triton `TritonNVFP4MSECalibrator`: 4.23 ms (41.8x speedup).
Why this works

Each candidate is constructed as `valid_fp8_e4m3_value / 448`. With `block_amax = global_amax * candidate`, the FP8 round-trip on the per-block scale `block_amax / 6` (using `global_amax / 6` as the FP8 amax) is the identity — so the kernel can compute `scale = candidate * global_amax / 6.0` inline and skip the FP8 cast. This keeps the kernel runnable on any CUDA + Triton (no `tl.float8e4nv` requirement).

Because every candidate's per-block scale is just a rescaling of the same input block, all 126 candidates can be evaluated against a single `[BLOCKS_PER_PROGRAM, BLOCK_SIZE]` tile held in registers — replacing 126 weight-bandwidth passes with 1.
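A hedged numerical check of that round-trip identity (assumes torch.float8_e4m3fn; the convention of mapping the FP8 amax to 448 is an assumption about how the reference fake-quant is parameterized):

```python
import torch

# The 126 candidates: finite positive E4M3 values / 448.
vals = torch.arange(256, dtype=torch.uint8).view(torch.float8_e4m3fn).float()
cands = vals[torch.isfinite(vals) & (vals > 0)].sort().values / 448.0

global_amax = torch.tensor(3.7)
fp8_amax = global_amax / 6.0                   # amax used when FP8-quantizing the per-block scales
block_scale = cands * global_amax / 6.0        # per-block scale for every candidate

# FP8 round-trip of the per-block scale: quantize with fp8_amax -> 448, then dequantize.
q = (block_scale / fp8_amax * 448.0).to(torch.float8_e4m3fn)
roundtrip = q.float() * fp8_amax / 448.0
assert torch.allclose(roundtrip, block_scale)  # identity: candidates land exactly on the E4M3 grid
```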
Two follow-on optimizations close the gap to the compute ceiling:

- `@triton.autotune` over `(BLOCKS_PER_PROGRAM, num_warps)` — the original hand-picked default (BPP=4, num_warps=4) left ~4x on the table; the best B300 config is `BPP=64, num_warps=8`.
- Dropping the sign-handling `tl.where`: FP4 quant preserves sign, so `(w - w_q)^2 == (|w| - |w_q|)^2` and the kernel works on `|w|` throughout (one fewer where + negation per element per candidate).

Files

- `modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py` — new kernel + `nvfp4_fp8_scale_sweep` wrapper, autotuned.
- `modelopt/torch/kernels/quantization/gemm/__init__.py` — wire-in.
- `modelopt/torch/quantization/calib/mse.py` — new `TritonNVFP4MSECalibrator(NVFP4MSECalibrator)`.
- `modelopt/torch/quantization/model_calib.py` — opt-out env var (`MODELOPT_NVFP4_TRITON_SWEEP=0`); Triton path is default.
- `tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py` — 16 GPU tests covering parity (across seeds, block counts, dtypes), input validation, output round-trip, reset, and a wall-clock speedup report.

Numerics

Bit-identical to the reference for typical block counts (`{4, 64, 1024}` blocks, 3 seeds, fp32/fp16/bf16 — 14/15 microbenchmark tests bit-exact).

On multi-million-block weights an occasional adjacent-candidate tie-break can differ at the fp32-noise level (observed 2 / 2,097,152 blocks in the speedup test, per-block MSE within ~1e-7 relative). The reference's CUDA `fake_e4m3fy` and the Triton inline math have slightly different op ordering, which lets nearly-tied candidates flip. The speedup test asserts the worst per-block MSE gap is `< 1e-5` relative on differing blocks — both choices are valid argmins; the resulting quantized weights are equally good. The Qwen3-8B end-to-end run confirms this: aggregate weight MSE matches the reference exactly at the displayed precision.

Test plan

- [x] `pytest tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py -v` — 16/16 pass on B300
- [x] `pytest tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py -v` — existing NVFP4 tests still pass (9/9)
- [x] Qwen3-8B end-to-end PTQ with `--qformat nvfp4_mse`: 6.67x speedup, identical weight MSE to reference
- [ ] H100/A100 run
- [ ] End-to-end PTQ on a 70B model

🤖 Generated with Claude Code