[Quantization] Fused Triton kernel for NVFP4 FP8 scale sweep search#1387

Open
cjluo-nv wants to merge 3 commits into main from chenjiel/nvfp4-fp8-sweep-triton

Conversation


@cjluo-nv cjluo-nv commented May 4, 2026

Summary

  • Replaces the 126-iteration Python sweep in NVFP4MSECalibrator with a fused Triton kernel that loads each NVFP4 block once, evaluates all 126 valid FP8 E4M3 scale candidates in registers, and emits the per-block best_amax directly.
  • Triton-backed TritonNVFP4MSECalibrator is the default for mse_calibrate(..., fp8_scale_sweep=True). Set MODELOPT_NVFP4_TRITON_SWEEP=0 to fall back to the reference for debugging or numerics comparison.
  • Microbenchmark: ~42x speedup on a B300 over the reference (NVFP4MSECalibrator) on a representative LLM weight (8192x4096, ~2M NVFP4 blocks): 176.68 ms -> 4.23 ms.
  • End-to-end on Qwen3-8B PTQ (hf_ptq.py --qformat nvfp4_mse, calib=128): ~6.7x faster mtq.quantize with identical global weight-MSE as the reference.

End-to-end Qwen3-8B PTQ (B300)

Note: ran on Qwen3-8B instead of Qwen3.5-9B because the docker's transformers (4.57.3) doesn't yet recognize the qwen3_5 (multimodal) architecture and the model dir doesn't ship modeling_*.py for trust_remote_code. Qwen3-8B is the same family / similar size and gives a representative comparison.

Settings: --calib_size 128 --calib_seq 512, default nvfp4_mse config (NVFP4_W4A4_WEIGHT_MSE_FP8_SWEEP_CFG). The first nvfp4 run was discarded as a warm-up for HF weight loading.

qformat                               mtq.quantize time   global weight MSE vs FP16 orig   notes
nvfp4 (no MSE search)                 3.46 s              6.363e-6                         max-calibration baseline
nvfp4_mse (reference, slow)           43.42 s             4.788e-6                         MODELOPT_NVFP4_TRITON_SWEEP=0
nvfp4_mse (Triton, fast, this PR)     6.51 s              4.788e-6                         default in this PR
  • The Triton path's quantized weights produce bit-identical global weight MSE to the reference (4.788e-6 vs 4.788e-6), validating the kernel on a real model.
  • Triton nvfp4_mse is 6.67x faster than the reference nvfp4_mse, and adds only ~3 s over the no-MSE baseline (vs ~40 s for the reference) — making the MSE-search option nearly free in practice.
  • The MSE search itself reduces weight quantization error by ~25% vs plain max-calibration (6.363e-6 → 4.788e-6).

Per-layer MSE CSVs are produced by tools/debugger/compare_mse_qwen.py (one row per Linear weight) for closer inspection if needed.

Microbenchmark (B300, 8192x4096 weight, ~2M NVFP4 blocks)

reference NVFP4MSECalibrator:   176.68 ms
triton  TritonNVFP4MSECalibrator: 4.23 ms
speedup: 41.8x

Why this works

Each candidate is constructed as valid_fp8_e4m3_value / 448. With block_amax = global_amax * candidate, the FP8 round-trip on the per-block scale block_amax / 6 (using global_amax / 6 as the FP8 amax) is the identity — so the kernel can compute scale = candidate * global_amax / 6.0 inline and skip the FP8 cast. This keeps the kernel runnable on any CUDA + Triton (no tl.float8e4nv requirement).
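The candidate set can be reproduced by enumerating FP8 E4M3 bit patterns. A minimal sketch, assuming the E4M3FN encoding (4 exponent bits, 3 mantissa bits, bias 7, pattern 0x7F is NaN, no infinities); this is an illustration of the candidate set, not the ModelOpt implementation:

```python
# Enumerate the 126 finite positive FP8 E4M3FN values by bit pattern, then
# divide by the format's max (448) to get the sweep candidates in (0, 1].
def fp8_e4m3_finite_positive():
    vals = []
    for bits in range(0x01, 0x7F):  # positive non-zero patterns; 0x7F is NaN
        e = bits >> 3
        m = bits & 0x7
        if e == 0:
            vals.append((m / 8) * 2.0 ** (1 - 7))      # subnormals
        else:
            vals.append((1 + m / 8) * 2.0 ** (e - 7))  # normals
    return vals

values = fp8_e4m3_finite_positive()          # 126 values, max 448.0
candidates = [v / 448.0 for v in values]     # candidate = fp8_value / 448
```

Because each candidate is an exactly representable FP8 value over 448, casting candidate * global_amax back through FP8 (with global_amax / 6 as the FP8 amax) changes nothing, which is the round-trip identity the kernel exploits.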

Because every candidate's per-block scale is just a rescaling of the same input block, all 126 candidates can be evaluated against a single [BLOCKS_PER_PROGRAM, BLOCK_SIZE] tile held in registers — replacing 126 weight-bandwidth passes with 1.
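The per-block computation the kernel fuses can be sketched in pure Python. This is a reference of the sweep's logic only: the FP4 E2M1 magnitude levels are the standard {0, 0.5, 1, 1.5, 2, 3, 4, 6}, but the nearest-level rounding here is simplified and may tie-break differently from the kernel's nvfp4_scalar_quant:

```python
# Reference for the per-block sweep: for each candidate, fake-quantize the
# block at scale = candidate * global_amax / 6 and keep the argmin-MSE amax.
FP4_LEVELS = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def fp4_fake_quant(w, scale):
    mag = min(abs(w) / scale, 6.0)                    # clamp to the FP4 max
    q = min(FP4_LEVELS, key=lambda l: abs(l - mag))   # nearest-level rounding
    return (q if w >= 0 else -q) * scale

def best_block_amax(block, global_amax, candidates):
    best_mse, best_amax = float("inf"), None
    for c in candidates:
        scale = c * global_amax / 6.0                 # inline scale, no FP8 cast
        mse = sum((w - fp4_fake_quant(w, scale)) ** 2 for w in block) / len(block)
        if mse < best_mse:
            best_mse, best_amax = mse, c * global_amax
    return best_amax
```

For a block whose values all sit exactly on FP4 levels at block_amax = 6.0, such as [1.0, -2.0, 3.0, 6.0] with global_amax = 6.0, candidate 1.0 gives zero MSE and wins the sweep.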

Two follow-on optimizations close the gap to the compute ceiling:

  • @triton.autotune over (BLOCKS_PER_PROGRAM, num_warps) — the original hand-picked default (BPP=4, num_warps=4) left ~4x on the table; the best B300 config is BPP=64, num_warps=8.
  • Drop the sign-handling tl.where: FP4 quant preserves sign, so (w - w_q)^2 == (|w| - |w_q|)^2 and the kernel works on |w| throughout (one fewer where + negation per element per candidate).
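The second optimization rests on sign preservation: if w_q always carries the sign of w, the signed and magnitude-only residuals are equal in magnitude. A quick numeric check using a toy nearest-level FP4 quantizer (illustrative, not the kernel's exact rounding):

```python
# Check that (w - w_q)^2 == (|w| - |w_q|)^2 when quantization preserves sign:
# w_q = sign(w) * q(|w|), so w - w_q = sign(w) * (|w| - q(|w|)).
FP4_LEVELS = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quant_mag(mag):
    return min(FP4_LEVELS, key=lambda l: abs(l - min(mag, 6.0)))

for w in [-5.3, -0.7, 0.0, 0.26, 4.9]:
    w_q = quant_mag(abs(w)) * (1 if w >= 0 else -1)   # sign-preserving quant
    signed = (w - w_q) ** 2
    magnitude_only = (abs(w) - quant_mag(abs(w))) ** 2
    assert abs(signed - magnitude_only) < 1e-12
```

This is why the kernel can operate on |w| throughout and drop one tl.where plus a negation per element per candidate.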

Files

  • modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py — new kernel + nvfp4_fp8_scale_sweep wrapper, autotuned.
  • modelopt/torch/kernels/quantization/gemm/__init__.py — wire-in.
  • modelopt/torch/quantization/calib/mse.py — new TritonNVFP4MSECalibrator(NVFP4MSECalibrator).
  • modelopt/torch/quantization/model_calib.py — opt-out env var (MODELOPT_NVFP4_TRITON_SWEEP=0); Triton path is default.
  • tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py — 16 GPU tests covering parity (across seeds, block counts, dtypes), input validation, output round-trip, reset, and a wall-clock speedup report.

Numerics

Bit-identical to the reference for typical block counts ({4, 64, 1024} blocks, 3 seeds, fp32/fp16/bf16 — 14/15 microbenchmark tests bit-exact).

On multi-million-block weights an occasional adjacent-candidate tie-break can differ at the fp32-noise level (observed 2 / 2,097,152 blocks in the speedup test, per-block MSE within ~1e-7 relative). The reference's CUDA fake_e4m3fy and the Triton inline math have slightly different op ordering, which lets nearly-tied candidates flip. The speedup test asserts the worst per-block MSE gap is < 1e-5 relative on differing blocks — both choices are valid argmins; the resulting quantized weights are equally good. The Qwen3-8B end-to-end run confirms this: aggregate weight MSE matches the reference exactly at the displayed precision.

Test plan

  • pytest tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py -v — 16/16 pass on B300
  • pytest tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py -v — existing NVFP4 tests still pass (9/9)
  • End-to-end PTQ on Qwen3-8B with --qformat nvfp4_mse: 6.67x speedup, identical weight MSE to reference

🤖 Generated with Claude Code

Replaces the 126-iteration Python sweep in NVFP4MSECalibrator with a single
fused Triton kernel that loads each NVFP4 block once, evaluates all 126 valid
FP8 E4M3 scale candidates in registers, and emits the per-block best amax
directly. For our specific candidate set (FP8 representable values / 448) the
FP8 round-trip on the per-block scale is the identity, so the kernel uses
`scale = candidate * global_amax / 6.0` and runs on any CUDA + Triton.

Triton-backed calibrator is on by default for `mse_calibrate(... fp8_scale_sweep=True)`;
set `MODELOPT_NVFP4_TRITON_SWEEP=0` to fall back to the reference for debugging.

Measured ~7.4x speedup on a B300 over the reference NVFP4MSECalibrator
(8192x4096 weight, ~2M NVFP4 blocks: 176.67 ms -> 23.81 ms). Bit-identical to
the reference for typical block counts; on multi-million-block weights an
occasional adjacent-candidate tie-break can differ at the fp32-noise level
(observed 2 / 2,097,152 blocks; per-block MSE within 1e-7 relative).

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
@cjluo-nv cjluo-nv requested review from a team as code owners May 4, 2026 20:59

coderabbitai Bot commented May 4, 2026

📝 Walkthrough

Adds a Triton-fused NVFP4 FP8 scale-sweep kernel, a Triton-backed NVFP4 MSE calibrator, re-exports the kernel symbols, makes the calibrator selectable via an env var, and adds GPU tests for parity, dtype coverage, validation, state semantics, and performance.

Changes

Triton Kernel-Based NVFP4 FP8 Scale Sweep

  • Data Shape / Candidates (modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py): adds fp8_scale_candidates(device), producing the 126 finite positive FP8 E4M3 scale candidates divided by 448.
  • Autotune & Kernel (nvfp4_fp8_sweep.py): adds _FP8_SWEEP_AUTOTUNE_CONFIGS and the Triton JIT _fp8_scale_sweep_kernel that loads per-tile
  • Public Wrapper / Validation (nvfp4_fp8_sweep.py): adds the nvfp4_fp8_scale_sweep(...) wrapper: CUDA checks, block_size/divisibility checks, lazy candidate generation/validation, casting to fp32 on device, flattening to blocks, allocation of best_amax, and the autotuned kernel launch.
  • Package Export (modelopt/torch/kernels/quantization/gemm/__init__.py): conditionally re-exports all public names from .nvfp4_fp8_sweep when the CUDA and Triton imports succeed.
  • Calibrator Implementation (modelopt/torch/quantization/calib/mse.py): adds TritonNVFP4MSECalibrator (subclass of NVFP4MSECalibrator): caches the initial amax shape/dtype/block count; collect() runs nvfp4_fp8_scale_sweep (validates the block count, enforces one shot before reset()) and stores _best_amax; compute_amax() returns it; reset() clears it and delegates to the base class.
  • Integration Wiring (modelopt/torch/quantization/model_calib.py): imports TritonNVFP4MSECalibrator, reads the MODELOPT_NVFP4_TRITON_SWEEP env var (default enabled), and selects the Triton vs reference NVFP4 calibrator when fp8_scale_sweep is active.
  • Tests / Validation & Benchmarks (tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py): adds GPU tests and helpers: parity across random seeds and block counts, dtype coverage (float32/float16/bfloat16), fake-quant output equivalence, calibrator single-collect/reset semantics, input validation error cases, and a performance benchmark comparing reference vs Triton calibrators with MSE-tolerant tie handling.
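The one-shot collect / compute_amax / reset contract of the calibrator can be sketched with a minimal model. This is a hypothetical illustration of the contract only, not the ModelOpt class (the real calibrator stores tensors and launches the Triton kernel):

```python
class OneShotSweepCalibrator:
    """Hypothetical sketch of the one-shot collect/compute_amax/reset contract."""

    def __init__(self, initial_amax):
        # Stash metadata up front so reset() leaves the calibrator reusable.
        self._n_blocks = len(initial_amax)
        self._best_amax = None

    def collect(self, x, sweep_fn):
        # sweep_fn stands in for nvfp4_fp8_scale_sweep; collect is one-shot.
        if self._best_amax is not None:
            raise RuntimeError("collect() already ran; call reset() first")
        if len(x) != self._n_blocks:
            raise ValueError(f"expected {self._n_blocks} blocks, got {len(x)}")
        self._best_amax = sweep_fn(x)

    def compute_amax(self):
        return self._best_amax

    def reset(self):
        self._best_amax = None
```

A second collect() without an intervening reset() raises; after reset() the calibrator is fully reusable, which is the behavior test_reset_allows_recollect verifies.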

Sequence Diagram

sequenceDiagram
    participant User as User
    participant Calibrator as TritonNVFP4MSECalibrator
    participant Wrapper as nvfp4_fp8_scale_sweep
    participant Kernel as _fp8_scale_sweep_kernel

    User->>Calibrator: collect(x)
    Calibrator->>Wrapper: nvfp4_fp8_scale_sweep(x, global_amax, block_size, candidates)
    Wrapper->>Wrapper: validate inputs, generate/cast candidates, flatten to blocks
    Wrapper->>Kernel: launch _fp8_scale_sweep_kernel (autotuned)
    Kernel->>Kernel: load per-tile |w|, for each candidate (static loop): quantize magnitude, compute MSE, track argmin
    Kernel-->>Wrapper: return best_amax per block
    Wrapper-->>Calibrator: best_amax
    Calibrator->>Calibrator: store _best_amax
    User->>Calibrator: compute_amax()
    Calibrator-->>User: return _best_amax

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes


Important

Pre-merge checks failed

Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 error, 1 warning)

  • Security Anti-Patterns: ❌ Error. The pull request uses assert for runtime input validation and lacks checks for division by zero, violating security guidelines. Resolution: replace assert with explicit ValueError checks and add validation that block_size > 0 before division.
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 66.67%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
✅ Passed checks (4 passed)
  • Title check: ✅ Passed. The title '[Quantization] Fused Triton kernel for NVFP4 FP8 scale sweep search' directly and clearly summarizes the main change.
  • Linked Issues check: ✅ Passed (skipped; no linked issues found for this pull request).
  • Out of Scope Changes check: ✅ Passed (skipped; no linked issues found for this pull request).
  • Description check: ✅ Passed (skipped; CodeRabbit's high-level summary is enabled).


@cjluo-nv cjluo-nv requested review from Fridah-nv and realAsma May 4, 2026 21:00

@meenchen meenchen left a comment


Clean, well-structured Triton kernel for speeding up the NVFP4 FP8 scale sweep. The implementation correctly reuses the existing nvfp4_scalar_quant JIT function from nvfp4_quant.py, the math insight about the FP8 round-trip identity for the candidate set is sound, and test coverage is solid (15 GPU tests covering parity, dtypes, round-trip, reset, and speedup).

A few points for the owner:

  1. Unchecked test plan items: The PR body has two unchecked items — H100/A100 run and end-to-end PTQ on a 70B model. Per project norms, these should be completed before merge.

  2. Minor code duplication: fp8_scale_candidates() in nvfp4_fp8_sweep.py duplicates NVFP4MSECalibrator._generate_candidates(). Consider having one call the other (or extracting a shared utility) to keep the candidate generation logic in one place.

  3. local_hessian_calibrate not using the Triton path: This function still uses NVFP4MSECalibrator directly (not TritonNVFP4MSECalibrator), which is correct since it needs a custom error_func. Worth adding a comment there noting that the Triton path doesn't support custom error functions, so someone doesn't "helpfully" switch it later.

  4. collect assumes x.shape[-1] is block_size: This works for the current MSE weight calibration flow where the tensor is pre-reshaped, but could be fragile if the calibrator is used in a different context. A brief assert or docstring note would help.

    block_size = x.shape[-1]
    n_blocks = x.numel() // block_size
    if self._initial_amax.numel() != n_blocks:
        raise ValueError(

Nit: block_size = x.shape[-1] assumes the input tensor has already been reshaped to [n_blocks, block_size]. This is true for the current mse_calibrate weight flow, but could silently produce wrong results if someone uses this calibrator with a differently-shaped tensor. Consider adding a brief assertion or docstring note, e.g.:

assert x.ndim == 2, "Expected x to be [n_blocks, block_size] from the weight quantizer reshape"



    def fp8_scale_candidates(device: torch.device | str = "cpu") -> torch.Tensor:
        """Return the 126 valid finite positive FP8 E4M3 scale candidates / 448."""

Minor duplication: this function reproduces the same logic as NVFP4MSECalibrator._generate_candidates() in calib/mse.py. Consider having one call the other (or extracting a shared utility) so the candidate generation stays in sync if the candidate set ever changes.

    # Replace calibrator with NVFP4MSECalibrator
    module._calibrator = NVFP4MSECalibrator(
    # Replace calibrator with the fused Triton sweep kernel by default
    # (single-shot collect, ~7-20x faster for the weight-MSE phase).

The env var check os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0" is evaluated on every weight quantizer in the loop. Since it won't change mid-loop, consider hoisting it above the loop for clarity and minor efficiency.
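The suggested hoist is a few lines. A sketch, assuming the env var name from this PR; the function name and the string return values are illustrative, not the ModelOpt API:

```python
import os

# Resolve the calibrator choice once, outside the per-quantizer loop:
# the env var cannot change mid-loop.
use_triton_sweep = os.environ.get("MODELOPT_NVFP4_TRITON_SWEEP", "1") != "0"

def pick_calibrator_name(fp8_scale_sweep: bool) -> str:
    # Hypothetical selector mirroring the logic described in the walkthrough.
    if fp8_scale_sweep and use_triton_sweep:
        return "TritonNVFP4MSECalibrator"
    return "NVFP4MSECalibrator"
```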


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py (1)

72-74: 💤 Low value

Consider hoisting candidate loads outside the loop.

The candidate value is loaded inside tl.static_range, which means 126 separate scalar loads per program invocation. Since candidates_ptr points to shared read-only data, you could load all candidates into a register vector once before the loop for better memory efficiency.

That said, tl.static_range unrolls at compile time and Triton's compiler may already optimize repeated scalar loads. This is a minor optimization suggestion.


Inline comments:
In `@modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py`:
- Around line 117-119: Replace the assert-based CUDA check with consistent
exception raising: in the nvfp4_fp8_scale_sweep (or the function containing the
lines checking x.is_cuda and block_size), change the assert x.is_cuda line to
raise a ValueError (or a module-specific custom exception) with a clear message
like "nvfp4_fp8_scale_sweep requires a CUDA tensor" so both the CUDA check and
the block_size divisibility check use the same error style and won't be removed
by python -O; keep the existing block_size check and message for x.numel()
as-is.



📥 Commits

Reviewing files that changed from the base of the PR and between 1d21ab9 and 4fbb181.

📒 Files selected for processing (5)
  • modelopt/torch/kernels/quantization/gemm/__init__.py
  • modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
  • modelopt/torch/quantization/calib/mse.py
  • modelopt/torch/quantization/model_calib.py
  • tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py

…ner loop

Two follow-on optimizations to the fused FP8 scale sweep kernel:

1. @triton.autotune over (BLOCKS_PER_PROGRAM, num_warps): a hand-sweep on B300
   showed the previous default (BPP=4, num_warps=4) at 23.7 ms left ~4x on the
   table — best config (BPP=64, num_warps=8) lands at ~5 ms. Three configs are
   included to cover small/medium/large N_BLOCKS without flooding compile time.

2. Drop the sign-handling tl.where: since FP4 quantization preserves sign,
   (w - w_q)^2 == (|w| - |w_q|)^2, so the kernel works on |w| throughout and
   skips one tl.where + negation per element per candidate.

Result on the same 8192x4096 weight (~2M blocks) on B300:
  reference NVFP4MSECalibrator:   176.68 ms
  triton  TritonNVFP4MSECalibrator: 4.23 ms
  speedup: 41.8x  (was 7.4x)

This is ~1.2x above the rough pure-compute floor (~240 GF / 67 TF/s ~= 3.6 ms),
so the kernel is now near saturation and further wins would need an algorithmic
change (candidate pruning, etc.).

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>

github-actions Bot commented May 4, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1387/

Built to branch gh-pages at 2026-05-04 22:09 UTC.
Preview will be ready when the GitHub Pages deployment is complete.


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3


Inline comments:
In `@modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py`:
- Around line 139-158: The code does not validate a custom candidates tensor
before launching _fp8_scale_sweep_kernel, allowing empty or malformed inputs
which make NUM_CANDIDATES zero and cause out-of-bounds reads (e.g.,
candidates_ptr[0]); fix by validating candidates returned or passed into the
function: if candidates is not None ensure it's a 1-D tensor, non-empty, finite,
positive, and has the expected length (ideally 126) or otherwise raise a clear
error; if candidates is None keep using fp8_scale_candidates(x.device) as
before; perform this validation before calling candidates.contiguous().to(...)
and before computing NUM_CANDIDATES/launching _fp8_scale_sweep_kernel so the
kernel never receives an invalid candidate tensor.
- Around line 112-137: In nvfp4_fp8_scale_sweep validate the public parameter
block_size before using it: add an explicit check that block_size is a positive
integer (e.g. if not isinstance(block_size, int) or block_size <= 0: raise
ValueError(...)) at the start of the function, and only afterwards perform
operations that use it (such as x.numel() % block_size) so you avoid
ZeroDivisionError for 0 and invalid negative values that would otherwise produce
bad kernel launches.

In `@modelopt/torch/quantization/calib/mse.py`:
- Around line 243-247: The reset() implementation in TritonNVFP4MSECalibrator
currently calls MseCalibrator.reset(), which clears _initial_amax and makes the
next collect() dereference None; update TritonNVFP4MSECalibrator.reset() to
either preserve _initial_amax (do not delete or set _initial_amax to its prior
tensor) or reinitialize it to a valid tensor/zero-sized tensor so collect() can
safely call self._initial_amax.numel(), and ensure _best_amax is also reset
consistently; apply the same fix to the other reset override referenced around
the second occurrence (similar block at lines ~274-277) so both reset overrides
in TritonNVFP4MSECalibrator maintain the contract expected by collect().


📥 Commits

Reviewing files that changed from the base of the PR and between 4fbb181 and 6040607.

📒 Files selected for processing (2)
  • modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
  • modelopt/torch/quantization/calib/mse.py


codecov Bot commented May 4, 2026

Codecov Report

❌ Patch coverage is 72.61905% with 23 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.86%. Comparing base (acfab41) to head (bd4fc3a).
⚠️ Report is 4 commits behind head on main.

Files with missing lines                                    Patch %   Missing lines
...torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py       56.60%    23
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1387      +/-   ##
==========================================
- Coverage   76.90%   76.86%   -0.04%     
==========================================
  Files         471      472       +1     
  Lines       50562    50660      +98     
==========================================
+ Hits        38886    38942      +56     
- Misses      11676    11718      +42     
Flag Coverage Δ
examples 41.53% <26.19%> (+0.87%) ⬆️
gpu 59.69% <72.61%> (-0.64%) ⬇️
unit 52.72% <14.28%> (-0.08%) ⬇️


Addresses review comments on PR #1387:

- TritonNVFP4MSECalibrator.reset() now leaves the calibrator reusable: shape /
  dtype / n_blocks of the initial amax are stashed in __init__, so collect()
  no longer depends on _initial_amax surviving reset(). Adds an x.ndim==2
  assertion in collect() since the weight quantizer always reshapes upstream.
- nvfp4_fp8_scale_sweep validates inputs cleanly instead of using assert
  (which is stripped by python -O): rejects non-CUDA tensors, non-positive
  block_size, and empty / non-1D candidates with ValueError. Skips the
  per-element finite/positive check on candidates since it would scan a 126-
  entry tensor on every kernel call.
- mse_calibrate hoists the MODELOPT_NVFP4_TRITON_SWEEP env-var lookup out of
  the per-quantizer loop and resolves to the calibrator class once.
- Updates test_reset_allows_recollect to verify the new reuse contract; adds
  test_input_validation covering the new ValueErrors.

The duplicate fp8_scale_candidates implementation in the kernel file and
NVFP4MSECalibrator._generate_candidates() is left in place: deduplicating
would force the reference path to import from the kernel module, which is
gated behind Triton availability. The FP8 E4M3 spec is fixed and the parity
test exercises both paths against each other.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>

@meenchen meenchen left a comment


All critical issues from the previous review have been addressed:

  • ✅ assert x.is_cuda replaced with raise ValueError
  • ✅ block_size <= 0 validation added
  • ✅ Custom candidates tensor validation added
  • ✅ reset() reusability fixed (metadata stashed in __init__, tested in test_reset_allows_recollect)
  • ✅ Env var hoisted above the loop
  • ✅ x.ndim == 2 assertion added in collect()
  • ✅ Candidate duplication addressed with a sync docstring in _generate_candidates

The kernel implementation is clean, the math insight about the FP8 round-trip identity is sound, and test coverage is solid (15 GPU tests covering parity, dtypes, round-trip, reset, input validation, and speedup).

Remaining concern: The PR body still has two unchecked test plan items — H100/A100 validation and end-to-end PTQ on a 70B model. Per project norms, these should be completed before merge. Nudging for human sign-off on that.


@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1



📥 Commits

Reviewing files that changed from the base of the PR and between 6040607 and bd4fc3a.

📒 Files selected for processing (4)
  • modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
  • modelopt/torch/quantization/calib/mse.py
  • modelopt/torch/quantization/model_calib.py
  • tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py
✅ Files skipped from review due to trivial changes (2)
  • modelopt/torch/kernels/quantization/gemm/nvfp4_fp8_sweep.py
  • modelopt/torch/quantization/model_calib.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py

Comment on lines +260 to +266
    x = x.detach()
    # The weight quantizer reshapes its input to [n_blocks, block_size] before
    # calling collect (see TensorQuantizer._process_for_blockquant).
    assert x.ndim == 2, f"Expected x to be [n_blocks, block_size]; got shape {tuple(x.shape)}."
    block_size = x.shape[-1]
    n_blocks = x.numel() // block_size
    if n_blocks != self._n_blocks:

⚠️ Potential issue | 🟠 Major | ⚡ Quick win



Replace assert with explicit input validation in collect().

Line 263 uses assert for runtime input checks, which can be stripped with optimized Python, and n_blocks = x.numel() // block_size can throw ZeroDivisionError when x.shape[-1] == 0. Use explicit ValueError checks for shape and block_size before division.

Suggested patch
         x = x.detach()
         # The weight quantizer reshapes its input to [n_blocks, block_size] before
         # calling collect (see TensorQuantizer._process_for_blockquant).
-        assert x.ndim == 2, f"Expected x to be [n_blocks, block_size]; got shape {tuple(x.shape)}."
+        if x.ndim != 2:
+            raise ValueError(
+                f"Expected x to be [n_blocks, block_size]; got shape {tuple(x.shape)}."
+            )
         block_size = x.shape[-1]
+        if block_size <= 0:
+            raise ValueError(f"Expected positive block_size in x.shape[-1], got {block_size}.")
         n_blocks = x.numel() // block_size

This is a GPU kernel calibration path (nvfp4_fp8_scale_sweep). Per SECURITY.md: "Apply defensive input validation to prevent crashes/resource exhaustion… validate … shapes/sizes … early."



@meenchen meenchen left a comment


All critical issues from previous reviews have been addressed:

  • assert x.is_cuda replaced with raise ValueError
  • block_size <= 0 validation added before use
  • Custom candidates tensor validation (non-empty, 1-D) added
  • reset() reusability fixed — shape/dtype/n_blocks metadata stashed in __init__, verified by test_reset_allows_recollect
  • Env var read hoisted above the weight quantizer loop
  • x.ndim == 2 shape check added in collect()
  • Candidate duplication addressed with sync docstring

The kernel implementation is correct — the FP8 round-trip identity insight is sound, input validation is thorough, and test coverage is solid (16 GPU tests covering parity across seeds/dtypes/block counts, round-trip, reset, input validation, and speedup reporting). License headers match the canonical LICENSE_HEADER.

