Add closed-form MXFP4 -> NVFP4 weight cast (--cast_mxfp4_to_nvfp4) #1372
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
No actionable comments were generated in the recent review. 🎉
📝 Walkthrough
Adds an MXFP4→NVFP4 casting utility and CLI flag that derive global and per-block amax from Hugging Face safetensors and apply them in-place to NVFP4 static weight quantizers; integrates promotion/flow changes so static-block NVFP4 quantizers are handled during calibration/export.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant HF as HF_checkpoint
    participant Shard as Shard_reader
    participant Builder as Amax_builder
    participant Model as Model_instance
    participant Quant as NVFP4_quantizer
    HF->>Shard: Read `*_scales` and optional `*_blocks`
    Shard->>Builder: Provide per-layer scales & blocks
    Builder->>Builder: Build/cached E2M1 magnitude table
    Builder->>Builder: Compute k_min/k_max → global_amax
    Builder->>Builder: Compute per-block _amax (closed-form or nibble-scan)
    Builder->>Model: Map checkpoint keys → quantizer module names
    Model->>Quant: Locate NVFP4StaticQuantizer instances
    Quant->>Quant: Set `weight_quantizer.global_amax`
    Quant->>Quant: Copy per-block `_amax` into buffer (torch.no_grad)
    Quant-->>Model: Confirm updates / diagnostics
```
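For orientation, here is a minimal sketch of the per-layer read the walkthrough describes — decoding the E8M0 `*_scales` bytes into per-block amax values. The helper name is illustrative (the PR's real entry points are `build_amax_map` / `apply_to_model` in `examples/llm_ptq/cast_mxfp4_to_nvfp4.py`), and the expansion from 32-wide MXFP4 blocks to the quantizer's NVFP4 block layout is not shown:

```python
import torch
from safetensors import safe_open


def per_block_amax_from_scales(shard_path: str, scales_key: str) -> torch.Tensor:
    """Decode one layer's MXFP4 E8M0 scales into per-block amax values."""
    with safe_open(shard_path, framework="pt") as f:
        e8m0 = f.get_tensor(scales_key)   # uint8, one exponent byte per MXFP4 block
    k = e8m0.to(torch.int32) - 127        # E8M0 encodes 2^(byte - 127)
    # E2M1's largest magnitude is 6, so a block scaled by 2^k_j has amax 6 * 2^k_j.
    return 6.0 * torch.exp2(k.to(torch.float32))
```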
Actionable comments posted: 7
🧹 Nitpick comments (2)
examples/llm_ptq/hf_ptq.py (1)
1126-1128: Add fail-fast validation for cast mode compatibility
Please reject `--cast_mxfp4_to_nvfp4` unless `--qformat` is NVFP4-family (and preferably disallow with multi-format auto-quantize). Right now invalid combinations can proceed and fail late.
Suggested guard

```diff
     args = parser.parse_args()
+    if args.cast_mxfp4_to_nvfp4:
+        qformats = [q.strip() for q in args.qformat.split(",")]
+        if not all("nvfp4" in q for q in qformats):
+            parser.error("--cast_mxfp4_to_nvfp4 requires NVFP4-family --qformat values.")
+        if args.auto_quantize_bits is not None and len(qformats) > 1:
+            parser.error(
+                "--cast_mxfp4_to_nvfp4 is not supported with multi-format auto_quantize."
+            )
```

Also applies to: 1370-1381
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/llm_ptq/hf_ptq.py` around lines 1126 - 1128, Add a fail-fast guard before calling apply_cast_mxfp4_to_nvfp4: validate that args.cast_mxfp4_to_nvfp4 is only true when args.qformat is in the NVFP4 family (check the exact qformat string values your code uses) and reject/raise an error (or exit) if it’s used together with any multi-format auto-quantize option (the flag/variable controlling multi-format auto-quantize in your parser) to prevent late failures; update the two call sites that invoke apply_cast_mxfp4_to_nvfp4 (the one using args.cast_mxfp4_to_nvfp4 and args.pyt_ckpt_path and the similar block around the 1370-1381 region) to perform this validation first and emit a clear message about allowed combinations.examples/llm_ptq/cast_mxfp4_to_nvfp4.py (1)
39-42: Lazy-load the optional HF dependencies.
Importing `safetensors` and the quantizer class at module load time makes this helper fail to import unless the full extra set is already installed. Please move these imports into the code paths that actually need them, or gate them through the repo's plugin-loading pattern.
As per coding guidelines, "Avoid hard imports of optional dependencies at module level; features should be gated by install extras (`[onnx]`, `[hf]`, `[all]`) and loaded lazily via `import_plugin()`."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/llm_ptq/cast_mxfp4_to_nvfp4.py` around lines 39 - 42, The module-level imports of safetensors and NVFP4StaticQuantizer cause hard dependency loading; move these imports into the function(s) that actually use them (e.g., inside the routine that opens the safetensors file or performs quantization) or replace them with the repo's plugin loader (import_plugin) so they are lazily loaded; specifically, remove "from safetensors import safe_open" and "from modelopt.torch.quantization.nn.modules.tensor_quantizer import NVFP4StaticQuantizer" from the top-level and import safe_open and NVFP4StaticQuantizer within the function that calls them (or use import_plugin('safetensors').safe_open / import_plugin('modelopt.torch.quantization').NVFP4StaticQuantizer) and add a clear error message when the import fails.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/llm_ptq/cast_mxfp4_to_nvfp4.py`:
- Around line 144-148: The shape check currently only compares leading dims
(blocks.shape[:-1] vs e8m0_scales.shape) and misses verifying the trailing block
size; update the validation to also require blocks.shape[-1] == 16 (and raise
the same ValueError if not) so malformed tensors like (..., N, 15) or (..., N,
32) are rejected; keep the existing error message (referencing blocks and
e8m0_scales) and perform this check where the current if comparing
blocks.shape[:-1] and e8m0_scales.shape appears.
- Around line 328-370: The code still unconditionally loads the packed blocks
tensor via _read and calls compute_per_block_amax_for_mxfp4 even when
compute_global_amax_for_scales reported info["pct_lossless"] >= 100.0; change
the logic in the loop (around compute_global_amax_for_scales, qname,
weight_quantizer checks) to short-circuit the I/O path for fully-lossless
layers: if info["pct_lossless"] >= 100.0, construct global_amax from
global_amax_value on the existing per-block buffer device and reuse the existing
per-block `_amax` (e.g., existing.to(dtype=torch.float32, device=device)) as
per_block_amax, and skip calling _read(blocks_key, ...) and
compute_per_block_amax_for_mxfp4; keep the NVFP4StaticQuantizer/assert checks
and only avoid the expensive block read when pct_lossless==100.0.
- Around line 338-382: Several assertions validating untrusted checkpoint/model
data (the check that blocks_shard is not None, the type-check of
weight_quantizer, the presence/shape check of weight_quantizer._amax, and the
element-count check comparing existing.numel() to per_block_amax.numel()) must
be replaced with explicit exception handling; update the checks around
blocks_shard, weight_quantizer (qname → NVFP4StaticQuantizer), existing/_amax,
and element count to raise specific exceptions (e.g., ValueError or
RuntimeError) with clear diagnostic messages including qname, expected vs actual
types/values, and any relevant shapes/numel counts, rather than using assert so
validation still runs under python -O and treats all checkpoint artifacts as
untrusted.
- Around line 258-263: The handle cache (handles: dict and safe_open(...)
usages) is never closed, leaking file descriptors/mmaps; wrap the cache creation
and all safe_open acquisitions in a context-managed ExitStack (or equivalent) so
each safe_open call is entered via stack.enter_context(...) and all handles are
closed deterministically when the function returns—apply this change in both the
function that iterates sorted(scales_keys.items()) (where scales =
handles[shard].get_tensor(tensor_key)) and in apply_to_model() around its main
loop, ensuring the ExitStack is created at the start of the function and closed
on all return paths so file handles are always released.
In `@modelopt/torch/export/layer_utils.py`:
- Around line 1094-1096: The collection of valid amax values should skip meta
tensors to avoid calling .to(target_device) on tensors without storage; modify
the block that appends to valid_amax_values (where existing_amax is checked) to
first check if existing_amax is a meta tensor (existing_amax.is_meta) and only
call existing_amax.amax().to(target_device) for non-meta tensors, otherwise
ignore/skip that existing_amax so meta tensors are not included in the fallback
aggregation.
In `@modelopt/torch/kernels/quantization/gemm/fp4_kernel.py`:
- Around line 267-274: The code assumes amax has elements but will divide by
zero if amax.numel() == 0; add an explicit guard before computing NUM_FP4_BLOCKS
and BLOCK_SIZE: check if amax.numel() == 0 and raise a clear ValueError (or
handle it) indicating amax is empty; then compute NUM_FP4_BLOCKS = amax.numel(),
verify x.numel() % NUM_FP4_BLOCKS == 0 as before, and compute BLOCK_SIZE =
x.numel() // NUM_FP4_BLOCKS (references: amax, NUM_FP4_BLOCKS, x, BLOCK_SIZE).
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 283-289: The promotion call to promote_nvfp4_static_quantizers is
placed after an early return in max_calibrate so the branch taken when
distributed_sync=False never gets NVFP4StaticQuantizer promotion; move the
promote_nvfp4_static_quantizers(model) invocation so it runs before the early
return in max_calibrate (or invoke it in both the distributed_sync=True and
distributed_sync=False branches) so that promotion always occurs regardless of
the distributed_sync flag.
---
Nitpick comments:
In `@examples/llm_ptq/cast_mxfp4_to_nvfp4.py`:
- Around line 39-42: The module-level imports of safetensors and
NVFP4StaticQuantizer cause hard dependency loading; move these imports into the
function(s) that actually use them (e.g., inside the routine that opens the
safetensors file or performs quantization) or replace them with the repo's
plugin loader (import_plugin) so they are lazily loaded; specifically, remove
"from safetensors import safe_open" and "from
modelopt.torch.quantization.nn.modules.tensor_quantizer import
NVFP4StaticQuantizer" from the top-level and import safe_open and
NVFP4StaticQuantizer within the function that calls them (or use
import_plugin('safetensors').safe_open /
import_plugin('modelopt.torch.quantization').NVFP4StaticQuantizer) and add a
clear error message when the import fails.
In `@examples/llm_ptq/hf_ptq.py`:
- Around line 1126-1128: Add a fail-fast guard before calling
apply_cast_mxfp4_to_nvfp4: validate that args.cast_mxfp4_to_nvfp4 is only true
when args.qformat is in the NVFP4 family (check the exact qformat string values
your code uses) and reject/raise an error (or exit) if it’s used together with
any multi-format auto-quantize option (the flag/variable controlling
multi-format auto-quantize in your parser) to prevent late failures; update the
two call sites that invoke apply_cast_mxfp4_to_nvfp4 (the one using
args.cast_mxfp4_to_nvfp4 and args.pyt_ckpt_path and the similar block around the
1370-1381 region) to perform this validation first and emit a clear message
about allowed combinations.
```python
assert blocks_shard is not None, (
    f"{tensor_key}: no paired '{blocks_key}' tensor found in source checkpoint."
)

weight_quantizer = name_to_module.get(qname)
if weight_quantizer is None:
    missed.append(qname)
    continue

# The cast assumes ``max_calibrate`` already promoted this quantizer
# to NVFP4StaticQuantizer (with ``_amax`` populated per-block by
# static-block max-cal and ``_global_amax`` set by the auto-promote).
# Anything else means the qformat or quant_cfg disabled this module's
# weight quantization — surface that loudly so we don't silently no-op.
assert isinstance(weight_quantizer, NVFP4StaticQuantizer), (
    f"{qname}: expected NVFP4StaticQuantizer (set by max_calibrate's "
    f"auto-promote), got {type(weight_quantizer).__name__}. The cast "
    "requires the matching quantizer to be enabled with static-block "
    "NVFP4 (num_bits=(2,1), scale_bits=(4,3))."
)
existing = getattr(weight_quantizer, "_amax", None)
assert isinstance(existing, torch.Tensor) and existing.numel() > 1, (
    f"{qname}: NVFP4StaticQuantizer must have a per-block ``_amax`` "
    f"buffer populated by max_calibrate. Got: {existing!r}."
)

# Pick the device from the existing per-block ``_amax`` buffer.
device = existing.device

global_amax = torch.tensor(float(global_amax_value), dtype=torch.float32, device=device)
blocks = _read(blocks_key, blocks_shard)
per_block_amax = compute_per_block_amax_for_mxfp4(blocks, scales).to(
    dtype=torch.float32, device=device
)
# Numel must match — calibration may store ``_amax`` flat (e.g. (N, 1))
# while we compute it in natural (E, F, num_blocks) layout. The static
# export path reshapes via ``.view(expected_shape)``, so we just need
# element count to agree, then reshape for the in-place copy.
assert existing.numel() == per_block_amax.numel(), (
    f"{qname}: ``_amax`` element count {existing.numel()} does not "
    f"match the cast-computed count {per_block_amax.numel()}. The "
    "block layout from calibration disagrees with the source MXFP4 "
    "scales — check that the qformat block_size is 16 and the source "
    "checkpoint is the same MXFP4 model."
)
```
❓ Verification inconclusive
Script executed:

```bash
cat -n examples/llm_ptq/cast_mxfp4_to_nvfp4.py | sed -n '330,390p'
```

Repository: NVIDIA/Model-Optimizer
Exit code: 0
stdout:
330 if info["pct_lossless"] >= 100.0:
331 n_lossless_layers += 1
332 grand_total_blocks += info["n_total_blocks"]
333 grand_lossless_blocks += info["n_lossless_blocks"]
334
335 blocks_key = tensor_key[: -len("_scales")] + "_blocks"
336 qname = quantizer_name_from_blocks_key(blocks_key)
337 blocks_shard = blocks_keys.get(blocks_key)
338 assert blocks_shard is not None, (
339 f"{tensor_key}: no paired '{blocks_key}' tensor found in source checkpoint."
340 )
341
342 weight_quantizer = name_to_module.get(qname)
343 if weight_quantizer is None:
344 missed.append(qname)
345 continue
346
347 # The cast assumes ``max_calibrate`` already promoted this quantizer
348 # to NVFP4StaticQuantizer (with ``_amax`` populated per-block by
349 # static-block max-cal and ``_global_amax`` set by the auto-promote).
350 # Anything else means the qformat or quant_cfg disabled this module's
351 # weight quantization — surface that loudly so we don't silently no-op.
352 assert isinstance(weight_quantizer, NVFP4StaticQuantizer), (
353 f"{qname}: expected NVFP4StaticQuantizer (set by max_calibrate's "
354 f"auto-promote), got {type(weight_quantizer).__name__}. The cast "
355 "requires the matching quantizer to be enabled with static-block "
356 "NVFP4 (num_bits=(2,1), scale_bits=(4,3))."
357 )
358 existing = getattr(weight_quantizer, "_amax", None)
359 assert isinstance(existing, torch.Tensor) and existing.numel() > 1, (
360 f"{qname}: NVFP4StaticQuantizer must have a per-block ``_amax`` "
361 f"buffer populated by max_calibrate. Got: {existing!r}."
362 )
363
364 # Pick the device from the existing per-block ``_amax`` buffer.
365 device = existing.device
366
367 global_amax = torch.tensor(float(global_amax_value), dtype=torch.float32, device=device)
368 blocks = _read(blocks_key, blocks_shard)
369 per_block_amax = compute_per_block_amax_for_mxfp4(blocks, scales).to(
370 dtype=torch.float32, device=device
371 )
372 # Numel must match — calibration may store ``_amax`` flat (e.g. (N, 1))
373 # while we compute it in natural (E, F, num_blocks) layout. The static
374 # export path reshapes via ``.view(expected_shape)``, so we just need
375 # element count to agree, then reshape for the in-place copy.
376 assert existing.numel() == per_block_amax.numel(), (
377 f"{qname}: ``_amax`` element count {existing.numel()} does not "
378 f"match the cast-computed count {per_block_amax.numel()}. The "
379 "block layout from calibration disagrees with the source MXFP4 "
380 "scales — check that the qformat block_size is 16 and the source "
381 "checkpoint is the same MXFP4 model."
382 )
383
384 # global_amax via the NVFP4StaticQuantizer property setter (writes to
385 # the canonical ``_global_amax`` buffer).
386 weight_quantizer.global_amax = global_amax
387 # _amax: in-place buffer copy, reshaping our value to the calibrator's
388 # storage layout (numel verified above).
389 with torch.no_grad():
390 existing.data.copy_(per_block_amax.view_as(existing))
Script executed:

```bash
# Check if this is in examples/ and verify guideline applicability
ls -la examples/llm_ptq/cast_mxfp4_to_nvfp4.py
```

Repository: NVIDIA/Model-Optimizer
Exit code: 0
stdout:
-rw-r--r-- 1 jailuser jailuser 18001 Apr 29 16:46 examples/llm_ptq/cast_mxfp4_to_nvfp4.py
Script executed:

```bash
# Verify the assert pattern used in this section
rg -n "assert.*blocks_shard|assert.*isinstance.*NVFP4StaticQuantizer|assert.*existing.*numel" examples/llm_ptq/cast_mxfp4_to_nvfp4.py -A 2
```

Repository: NVIDIA/Model-Optimizer
Exit code: 0
stdout:
338: assert blocks_shard is not None, (
339- f"{tensor_key}: no paired '{blocks_key}' tensor found in source checkpoint."
340- )
--
352: assert isinstance(weight_quantizer, NVFP4StaticQuantizer), (
353- f"{qname}: expected NVFP4StaticQuantizer (set by max_calibrate's "
354- f"auto-promote), got {type(weight_quantizer).__name__}. The cast "
--
359: assert isinstance(existing, torch.Tensor) and existing.numel() > 1, (
360- f"{qname}: NVFP4StaticQuantizer must have a per-block ``_amax`` "
361- f"buffer populated by max_calibrate. Got: {existing!r}."
--
376: assert existing.numel() == per_block_amax.numel(), (
377- f"{qname}: ``_amax`` element count {existing.numel()} does not "
378- f"match the cast-computed count {per_block_amax.numel()}. The "
Replace assert statements with explicit exception handling for checkpoint and model validation.
The four assertions at lines 338, 352, 359, and 376 validate untrusted checkpoint contents (blocks_shard, quantizer type, _amax buffer). Under python -O, these asserts vanish entirely, leaving no validation. Raise concrete exceptions (e.g., ValueError, RuntimeError) instead.
Per SECURITY.md: when implementing the MXFP4→NVFP4 cast flow and checkpoint loading, treat all model/weight artifacts as untrusted and prefer safe, unconditional validation.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/llm_ptq/cast_mxfp4_to_nvfp4.py` around lines 338 - 382, Several
assertions validating untrusted checkpoint/model data (the check that
blocks_shard is not None, the type-check of weight_quantizer, the presence/shape
check of weight_quantizer._amax, and the element-count check comparing
existing.numel() to per_block_amax.numel()) must be replaced with explicit
exception handling; update the checks around blocks_shard, weight_quantizer
(qname → NVFP4StaticQuantizer), existing/_amax, and element count to raise
specific exceptions (e.g., ValueError or RuntimeError) with clear diagnostic
messages including qname, expected vs actual types/values, and any relevant
shapes/numel counts, rather than using assert so validation still runs under
python -O and treats all checkpoint artifacts as untrusted.
```python
NUM_FP4_BLOCKS = amax.numel()
if x.numel() % NUM_FP4_BLOCKS != 0:
    raise ValueError(
        f"x.numel() ({x.numel()}) is not divisible by amax.numel() ({NUM_FP4_BLOCKS}); "
        "they must satisfy x.numel() == NUM_FP4_BLOCKS * BLOCK_SIZE."
    )
BLOCK_SIZE = x.numel() // NUM_FP4_BLOCKS
```
Guard against empty amax before modulo/division
If amax.numel() == 0, Line 268/Line 273 will hit division-by-zero. Add an explicit precheck.
Suggested fix

```diff
 original_shape = x.shape
 NUM_FP4_BLOCKS = amax.numel()
+if NUM_FP4_BLOCKS == 0:
+    raise ValueError("amax must contain at least one block.")
 if x.numel() % NUM_FP4_BLOCKS != 0:
     raise ValueError(
         f"x.numel() ({x.numel()}) is not divisible by amax.numel() ({NUM_FP4_BLOCKS}); "
         "they must satisfy x.numel() == NUM_FP4_BLOCKS * BLOCK_SIZE."
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/kernels/quantization/gemm/fp4_kernel.py` around lines 267 -
274, The code assumes amax has elements but will divide by zero if amax.numel()
== 0; add an explicit guard before computing NUM_FP4_BLOCKS and BLOCK_SIZE:
check if amax.numel() == 0 and raise a clear ValueError (or handle it)
indicating amax is empty; then compute NUM_FP4_BLOCKS = amax.numel(), verify
x.numel() % NUM_FP4_BLOCKS == 0 as before, and compute BLOCK_SIZE = x.numel() //
NUM_FP4_BLOCKS (references: amax, NUM_FP4_BLOCKS, x, BLOCK_SIZE).
Codecov Report
❌ Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff             @@
##             main    #1372      +/-   ##
===========================================
+ Coverage   66.36%   76.96%   +10.60%
===========================================
  Files         471      471
  Lines       50504    50510        +6
===========================================
+ Hits        33515    38875     +5360
+ Misses      16989    11635     -5354
```

Flags with carried forward coverage won't be shown. Click here to find out more.
☔ View full report in Codecov by Sentry.
meenchen left a comment
Bot review — DM the bot to share feedback.
Well-documented PR with solid math and good test coverage for the core conversion logic. The supporting fixes across the codebase (kernel flexibility, export path, expert amax stacking) are small and correct. A few issues to address:
- Resource leak: `safe_open` handles are stored in dicts but never closed in both `build_amax_map` and `apply_to_model`.
- Missing test for `apply_to_model`: The most complex function in the module has no unit tests. The test suite covers the math functions and checkpoint reading well, but the model-interaction path (the function users actually call) is untested.
- Minor: `quantizer_name_from_blocks_key` uses `assert` for input validation on user-provided checkpoint data — this should be a proper exception since malformed checkpoint keys are a runtime data condition, not a programming invariant.
| """Walk the source MXFP4 checkpoint and build the per-layer amax map. | ||
|
|
||
| Args: | ||
| checkpoint_dir: Path to a Hugging Face checkpoint directory whose |
Bot comment.
Resource leak: safe_open handles in handles dict are never closed. Consider using a try/finally block or closing them at the end:
```python
try:
    # ... loop body ...
finally:
    for h in handles.values():
        # safe_open objects are context managers
        pass  # they get GC'd, but explicit cleanup is better
```

Same issue exists in apply_to_model at the other handles dict (line ~273).
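A minimal sketch of the context-managed cache the later revision adopts (the follow-up commits call it `_shard_reader()`; this version is illustrative rather than the PR's exact code):

```python
import contextlib
from pathlib import Path

from safetensors import safe_open


@contextlib.contextmanager
def shard_reader():
    """Yield a get(shard, key) helper; every opened shard handle is closed on exit."""
    handles: dict[Path, object] = {}
    with contextlib.ExitStack() as stack:
        def get(shard: Path, key: str):
            if shard not in handles:
                # ExitStack owns the handle, so it is released on return or exception.
                handles[shard] = stack.enter_context(safe_open(shard, framework="pt"))
            return handles[shard].get_tensor(key)

        yield get
```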
```python
in_range = (delta >= E4M3_KMIN) & (delta <= E4M3_KMAX)

# Fast path: if every block fits E4M3's [-9, 8] window the per-block amax
# is just the closed-form ideal, and we can skip the per-byte nibble scan
```
Bot comment.
This assert validates user-provided checkpoint data (the key suffix). If someone passes a malformed key, this would be silently skipped under python -O. Since this is input validation rather than a programming invariant, consider using a ValueError instead:

```python
if not blocks_key.endswith("_blocks"):
    raise ValueError(f"Unexpected key {blocks_key!r}")
```

```python
t1 = cast._e2m1_magnitude_table(torch.device("cpu"))
t2 = cast._e2m1_magnitude_table(torch.device("cpu"))
assert t1 is t2  # cached: same object
assert t1.tolist() == cast._E2M1_MAGNITUDE
```
Bot comment.
Missing test coverage for apply_to_model: The most complex and user-facing function in the module is not tested. It has significant logic around matching quantizer names to model modules, asserting quantizer types, reshaping amax buffers, and writing to quantizer state. Consider adding at least one test with a mock model (e.g., a simple nn.Module with a mock NVFP4StaticQuantizer child) to exercise the happy path and verify that global_amax and _amax are written correctly.
```python
def apply_to_model(
    model: "torch.nn.Module",
```
Bot comment.
Same concern as quantizer_name_from_blocks_key: this assert validates runtime data (whether a paired _blocks tensor exists in the checkpoint). A missing pair is a data error, not a programming bug. Consider ValueError or a descriptive RuntimeError.
Research artifact comparing three algorithms for converting an MXFP4
tensor (block 32, E2M1 + E8M0) to NVFP4 (block 16, E2M1 + E4M3 + FP32
global scale):
Algo 1: dequantize MXFP4 -> bf16 -> standard NVFP4 quantize.
Algo 2: keep E2M1 nibbles verbatim; pick global S = 2^m and store
per-block E4M3 scales as 2^(k_j - m), snapping out-of-range
blocks. Two m strategies: midpoint and 1D integer search over
the closed-form snap-error objective.
Algo 3: hybrid - verbatim path for in-range blocks (zero error) plus
NVFP4 requantization with fixed scale_2 = 2^m for OOR blocks.
m chosen by direct-MSE 1D sweep.
Includes 27 scenarios (gaussian, heavy-tail, outlier patterns, spread
boundary tests, layer-shaped LLM weights) and a report summarizing
results, the snap-up/snap-down asymmetry that drives the m choice, and
the one pathological case (single dominant outlier) where Algo 3 still
trails Algo 1 by 0.21% due to integer-m vs continuous scale_2.
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
The m-search loop in the original Algo 3 turns out to be unnecessary.
Across all 27 test scenarios the search converges on m = k_max - 8 and
that closed-form rule is provably the right pick:
- For spread <= 17, every block's k_j - m lands in [8 - spread, 8],
a subset of E4M3's exact-power-of-2 window [-9, 8]. All blocks take
the verbatim path; the conversion is lossless (MSE = 0).
- For spread > 17, m = k_max - 8 is the only choice that does not
NaN the highest-magnitude blocks: a lower m drives the per-block
scale amax/(6*2^m) above E4M3's max (448); a higher m only shrinks
in-range coverage on the low side without helping the high side.
Replaces the brute-force algo3_hybrid_requant with a single-pass
algo3_hybrid using the closed-form m. The Algo 4 / Algo 5 variants
that were used to discover this rule are removed; the script is back
to three algorithms (Algo 1 / Algo 2 / Algo 3) and the report has been
rewritten accordingly.
Same MSE numbers as before. No library changes — strictly under
scratch/.
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
When the source HF checkpoint is MXFP4 (e.g. openai/gpt-oss-20b), the new flag pins NVFP4 weight quantizers' scale_2 to 2^m (m = k_max - 8) and the per-block _amax to 6 * 2^k_j read from the source *_scales. Per-block scale = 2^(k_j - m) is exactly representable in E4M3 for in-range blocks, so NVFP4 dequant matches MXFP4 dequant bit-for-bit (verified SNR=inf on gpt-oss-20b's full ~19B-element MoE expert weights). For out-of-range blocks (k_max - k_j > 17), the per-block amax falls back to data-derived max(|w_block|), keeping the post-clamp scale closer to the actual block magnitude than the closed-form ideal would.

Modelopt-side enablers:
- max_calibrate auto-promotes static-block NVFP4 weight quantizers to NVFP4StaticQuantizer at the end of calibration.
- static_blockwise_fp4_fake_quant kernel accepts N-D inputs (was 2D-only), unblocking MoE expert weights of shape (E, F, K).
- BMM-experts NVFP4 export routes through get_weights_scaling_factor_from_quantizer for static-mode quantizers, so the pinned _amax is consumed (was bypassed by recompute-from-weight).
- set_expert_quantizer_amax scalar-reduces per-quantizer amax before stacking, supporting per-block (vs scalar) static-mode amax.

Wired through scripts/parser.sh + scripts/huggingface_example.sh as the shell-level --cast_mxfp4_to_nvfp4 flag.

Removes the scratch/ MSE experiment (kept in PR description for context).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
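A small numeric sketch of the closed-form rule this commit describes (the function name and example exponents are illustrative, not the PR's API):

```python
import numpy as np


def closed_form_nvfp4_state(k: np.ndarray):
    """k holds per-block MXFP4 exponents decoded from the E8M0 *_scales (k_j = byte - 127)."""
    k_max = int(k.max())
    m = k_max - 8                                          # global scale_2 = 2^m
    per_block_amax = 6.0 * np.exp2(k.astype(np.float64))   # _amax[j] = 6 * 2^k_j
    # The verbatim (lossless) path needs 2^(k_j - m) to land in E4M3's exact
    # power-of-two window [2^-9, 2^8], i.e. k_max - k_j <= 17.
    lossless = (k_max - k) <= 17
    return m, per_block_amax, lossless


m, amax, ok = closed_form_nvfp4_state(np.array([-2, 0, 3, 5]))
print(m, amax, ok.all())  # -3, [1.5, 6.0, 48.0, 192.0], True
```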
- cast_mxfp4_to_nvfp4.py: validate blocks.shape[-1] == 16; extract a _shard_reader() context manager so build_amax_map and apply_to_model share a single auto-closed safetensors handle cache; skip the 16x larger *_blocks read for fully-lossless layers (per-block amax is 6 * 2^k_j read directly from *_scales).
- hf_ptq.py: parse-time guard rejects --cast_mxfp4_to_nvfp4 without an NVFP4-family --qformat or with --auto_quantize_bits.
- model_calib.py: move promote_nvfp4_static_quantizers() before the distributed_sync=False early return so single-process callers also get static-block NVFP4 promotion at the end of max_calibrate.
- export/layer_utils.py: skip meta tensors in set_expert_quantizer_amax fallback aggregation (.amax()/.to() would fail on a meta tensor).
- tests/.../test_cast_mxfp4_to_nvfp4.py: add three apply_to_model tests (happy path with a mock NVFP4StaticQuantizer subclass; missing *_blocks pair raises AssertionError; wrong quantizer type raises AssertionError). 18 tests pass.

Verified: 20b PTQ + cast still 100% lossless (48/48 layers, 597M/597M blocks); TRT-LLM inference produces the same deterministic outputs as before the refactor.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Force-pushed 785b43f to 21da896 (Compare)
meenchen left a comment
Bot review — DM the bot to share feedback.
Re-review: All critical comments from the previous review have been addressed.
Addressed:
- Resource leak: `safe_open` handles now managed via `_shard_reader()` context manager with `ExitStack` — deterministic cleanup on all paths.
- NVFP4StaticQuantizer promotion: moved before the `distributed_sync` early return in `max_calibrate`, so single-process callers also get promotion.
- Block trailing dim validation: now checks `blocks.shape[-1] != 16`.
- I/O skip for lossless layers: fully-lossless layers compute per-block amax from scales alone, avoiding the 16x larger block read.
- Meta tensor guard: `existing_amax.is_meta` check added in `set_expert_quantizer_amax`.
- `apply_to_model` test coverage: 3 tests added (happy path, missing blocks pair, wrong quantizer type).
- Fail-fast validation: `--cast_mxfp4_to_nvfp4` now rejects non-NVFP4 qformats and `auto_quantize_bits`.
Design review: This is a focused conversion utility, not a new architectural subsystem. The 6-directory span comes from small targeted changes across existing subsystems (kernel N-D input support, export static-quantizer routing, calibration promotion timing) plus the new conversion script in examples/. The design is appropriate and well-scoped.
Code quality: Well-documented with thorough docstrings explaining the math. Test suite covers core math functions, edge cases (zero blocks, OOR blocks), checkpoint I/O, and the model-interaction path. The ~950 lines are cohesive and can't be reasonably split further. License headers match the canonical LICENSE_HEADER.
Complex PR: spans 6 directories (≥ 5). Looping in a human for approval.
Actionable comments posted: 2
♻️ Duplicate comments (1)
modelopt/torch/kernels/quantization/gemm/fp4_kernel.py (1)
267-273: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Guard empty `amax` before modulo/division.
Line 268 can hit division-by-zero when `amax.numel() == 0`, before your validation error path runs.
Suggested fix

```diff
 original_shape = x.shape
 NUM_FP4_BLOCKS = amax.numel()
+if NUM_FP4_BLOCKS == 0:
+    raise ValueError("amax must contain at least one block.")
 if x.numel() % NUM_FP4_BLOCKS != 0:
     raise ValueError(
         f"x.numel() ({x.numel()}) is not divisible by amax.numel() ({NUM_FP4_BLOCKS}); "
         "they must satisfy x.numel() == NUM_FP4_BLOCKS * BLOCK_SIZE."
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/kernels/quantization/gemm/fp4_kernel.py` around lines 267 - 273, The code risks a division-by-zero when NUM_FP4_BLOCKS = amax.numel() is 0; update the checks in the block using NUM_FP4_BLOCKS/amax.numel() (the variables referenced) to first guard against an empty amax by raising a clear ValueError if NUM_FP4_BLOCKS == 0, then perform the divisibility check (x.numel() % NUM_FP4_BLOCKS) and finally compute BLOCK_SIZE = x.numel() // NUM_FP4_BLOCKS; ensure the error messages reference the offending sizes (amax.numel() and x.numel()) for clarity.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/llm_ptq/cast_mxfp4_to_nvfp4.py`:
- Around line 223-240: The code currently trusts entries from
index["weight_map"] and constructs paths via ckpt_dir / shard, which allows
path-traversal or absolute paths; fix by normalizing and restricting each shard
path: for each shard string from index["weight_map"] compute a resolved path
(e.g., shard_path = (ckpt_dir / shard).resolve(strict=False)) and compare it
against the resolved checkpoint root (root = ckpt_dir.resolve()); if shard_path
is not inside root (use Path.is_relative_to or os.path.commonpath(root,
shard_path) != str(root)) then reject or raise an exception, otherwise use that
normalized shard_path when building the returned dict (apply same check before
passing a shard to safe_open or when assigning out[k] = shard_path for keys
iterated via safe_open); this prevents absolute, ../, or symlink escapes from
loading files outside the checkpoint directory.
In `@examples/llm_ptq/hf_ptq.py`:
- Around line 1454-1465: The validation for args.cast_mxfp4_to_nvfp4 currently
only checks args.qformat; update it to derive the effective qformat(s) when
args.recipe is provided (use the same recipe parsing/processing logic used
earlier where the recipe is loaded into the quantization config) and validate
those effective qformat values instead of only args.qformat; specifically, in
the block guarding args.cast_mxfp4_to_nvfp4, consult args.recipe / the
recipe-derived quant config (the same variable or function used around the
recipe handling at lines ~1031–1038) to produce the effective qformats and then
enforce that all contain "nvfp4" and that auto_quantize_bits is not set for
multi-format auto-quantize.
---
Duplicate comments:
In `@modelopt/torch/kernels/quantization/gemm/fp4_kernel.py`:
- Around line 267-273: The code risks a division-by-zero when NUM_FP4_BLOCKS =
amax.numel() is 0; update the checks in the block using
NUM_FP4_BLOCKS/amax.numel() (the variables referenced) to first guard against an
empty amax by raising a clear ValueError if NUM_FP4_BLOCKS == 0, then perform
the divisibility check (x.numel() % NUM_FP4_BLOCKS) and finally compute
BLOCK_SIZE = x.numel() // NUM_FP4_BLOCKS; ensure the error messages reference
the offending sizes (amax.numel() and x.numel()) for clarity.
```python
index_path = ckpt_dir / "model.safetensors.index.json"
if index_path.is_file():
    with index_path.open() as f:
        index = json.load(f)
    return {
        k: ckpt_dir / shard for k, shard in index["weight_map"].items() if k.endswith(suffix)
    }
shards = list(ckpt_dir.glob("*.safetensors"))
if len(shards) != 1:
    raise FileNotFoundError(
        f"Expected model.safetensors.index.json or a single .safetensors file in {ckpt_dir}"
    )
out: dict[str, Path] = {}
with safe_open(shards[0], framework="pt") as f:
    # ``safe_open`` is not a dict; ``.keys()`` is its iterator.
    for k in f.keys():  # noqa: SIM118
        if k.endswith(suffix):
            out[k] = shards[0]
```
Normalize shard paths before opening checkpoint files.
model.safetensors.index.json is user-controlled input, but this code trusts each weight_map entry and joins it straight onto ckpt_dir. A crafted index can point at ../..., an absolute path, or a symlinked .safetensors file outside the checkpoint root, which turns checkpoint loading into a path-traversal file read.
Suggested fix
def _collect_keys_with_suffix(ckpt_dir: Path, suffix: str) -> dict[str, Path]:
"""Return ``{tensor_name: shard_path}`` for every key ending with ``suffix``."""
+ ckpt_root = ckpt_dir.resolve()
+
+ def _checked_shard_path(path: Path) -> Path:
+ candidate = path.resolve()
+ try:
+ candidate.relative_to(ckpt_root)
+ except ValueError as exc:
+ raise ValueError(f"Shard path escapes checkpoint dir: {path}") from exc
+ return candidate
+
index_path = ckpt_dir / "model.safetensors.index.json"
if index_path.is_file():
with index_path.open() as f:
index = json.load(f)
return {
- k: ckpt_dir / shard for k, shard in index["weight_map"].items() if k.endswith(suffix)
+ k: _checked_shard_path(ckpt_dir / shard)
+ for k, shard in index["weight_map"].items()
+ if k.endswith(suffix)
}
- shards = list(ckpt_dir.glob("*.safetensors"))
+    shards = [_checked_shard_path(shard) for shard in ckpt_dir.glob("*.safetensors")]

As per coding guidelines, "ModelOpt loads user-provided artifacts (models/weights/configs/calibration). Treat checkpoint/weight inputs as untrusted and apply safe deserialization practices."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/llm_ptq/cast_mxfp4_to_nvfp4.py` around lines 223 - 240, The code
currently trusts entries from index["weight_map"] and constructs paths via
ckpt_dir / shard, which allows path-traversal or absolute paths; fix by
normalizing and restricting each shard path: for each shard string from
index["weight_map"] compute a resolved path (e.g., shard_path = (ckpt_dir /
shard).resolve(strict=False)) and compare it against the resolved checkpoint
root (root = ckpt_dir.resolve()); if shard_path is not inside root (use
Path.is_relative_to or os.path.commonpath(root, shard_path) != str(root)) then
reject or raise an exception, otherwise use that normalized shard_path when
building the returned dict (apply same check before passing a shard to safe_open
or when assigning out[k] = shard_path for keys iterated via safe_open); this
prevents absolute, ../, or symlink escapes from loading files outside the
checkpoint directory.
```python
if args.cast_mxfp4_to_nvfp4:
    qformats = [q.strip() for q in args.qformat.split(",")]
    if not all("nvfp4" in q for q in qformats):
        raise ValueError(
            "--cast_mxfp4_to_nvfp4 requires NVFP4-family --qformat values "
            f"(got {args.qformat!r}). Use e.g. --qformat nvfp4 or nvfp4_mlp_only."
        )
    if args.auto_quantize_bits is not None:
        raise ValueError(
            "--cast_mxfp4_to_nvfp4 is not supported with --auto_quantize_bits "
            "(multi-format auto-quantize)."
        )
```
--cast_mxfp4_to_nvfp4 validation should account for --recipe.
Line 1456 validates only args.qformat, but when --recipe is set (Lines 1031-1038), the effective quantization config comes from the recipe. This can incorrectly reject valid NVFP4 recipe runs.
Suggested fix
if args.cast_mxfp4_to_nvfp4:
- qformats = [q.strip() for q in args.qformat.split(",")]
- if not all("nvfp4" in q for q in qformats):
- raise ValueError(
- "--cast_mxfp4_to_nvfp4 requires NVFP4-family --qformat values "
- f"(got {args.qformat!r}). Use e.g. --qformat nvfp4 or nvfp4_mlp_only."
- )
+ if args.recipe is None:
+ qformats = [q.strip() for q in args.qformat.split(",")]
+ if not all("nvfp4" in q for q in qformats):
+ raise ValueError(
+ "--cast_mxfp4_to_nvfp4 requires NVFP4-family --qformat values "
+ f"(got {args.qformat!r}). Use e.g. --qformat nvfp4 or nvfp4_mlp_only."
+ )
+ elif "nvfp4" not in args.recipe.lower():
+ raise ValueError(
+ "--cast_mxfp4_to_nvfp4 requires an NVFP4 PTQ recipe."
+ )
if args.auto_quantize_bits is not None:
raise ValueError(
"--cast_mxfp4_to_nvfp4 is not supported with --auto_quantize_bits "
"(multi-format auto-quantize)."
            )

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/llm_ptq/hf_ptq.py` around lines 1454 - 1465, The validation for
args.cast_mxfp4_to_nvfp4 currently only checks args.qformat; update it to derive
the effective qformat(s) when args.recipe is provided (use the same recipe
parsing/processing logic used earlier where the recipe is loaded into the
quantization config) and validate those effective qformat values instead of only
args.qformat; specifically, in the block guarding args.cast_mxfp4_to_nvfp4,
consult args.recipe / the recipe-derived quant config (the same variable or
function used around the recipe handling at lines ~1031–1038) to produce the
effective qformats and then enforce that all contain "nvfp4" and that
auto_quantize_bits is not set for multi-format auto-quantize.
Is this correct? Out of range blocks will be
Update: I understand now that this computation is using E4M3 subnormal values.
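A quick way to confirm that reading (assuming PyTorch ≥ 2.1 with the `float8_e4m3fn` dtype): every power of two from 2^-9 up to 2^8 round-trips exactly through E4M3 — the 2^-7…2^-9 values as subnormals — which is exactly the window the spread ≤ 17 rule relies on.

```python
import torch

for e in range(-9, 9):  # exponents -9 .. 8
    x = torch.tensor(2.0**e, dtype=torch.float32)
    roundtrip = x.to(torch.float8_e4m3fn).to(torch.float32)
    assert roundtrip.item() == 2.0**e, (e, roundtrip.item())
print("2^-9 .. 2^8 are all exactly representable in E4M3 (the low end via subnormals)")
```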
```bash
fi

if $CAST_MXFP4_TO_NVFP4; then
    PTQ_ARGS+=" --cast_mxfp4_to_nvfp4 "
```
Should we make this a default option? I am thinking of detecting MXFP4 quant config on the fly, and just applying the casting by default.
So far we only have gpt-oss and DSV4 (I will update another code path in the deepseek example). I think it's ok to have this flag for now. Also each model may have mxfp4 implemented differently.
```python
# promotion. ``promote_nvfp4_static_quantizers`` only promotes when
# ``is_static_block_quant`` is True and the per-block ``_amax`` buffer is
# populated, so it's a no-op for dynamic-block / non-NVFP4 configs.
promote_nvfp4_static_quantizers(model)
```
does this break GPTQ/MSE support?
It should not. It just says: if the weight quantizer is static instead of dynamic fp4, let's make it static fp4 weight quantizer type instead of the parent tensor quantizer.
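To illustrate the reply above, a concept-only sketch of what such a promotion step might look like (this is not ModelOpt's actual `promote_nvfp4_static_quantizers` implementation; the in-place re-typing mechanism here is an assumption for illustration):

```python
import torch


def promote_static_nvfp4(model: torch.nn.Module, static_cls: type) -> None:
    """Re-type static block-wise NVFP4 weight quantizers that already carry a per-block _amax."""
    for _, module in model.named_modules():
        if (
            getattr(module, "is_static_block_quant", False)
            and isinstance(getattr(module, "_amax", None), torch.Tensor)
            and not isinstance(module, static_cls)
        ):
            # Hypothetical promotion: swap the class in place so the static
            # NVFP4 code path is used; dynamic quantizers and GPTQ/MSE
            # calibrators are left untouched.
            module.__class__ = static_cls
```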
Move the inline weight-quantizer block_sizes='static' rewrite out of quantize_main() into a public force_weight_quantizers_static() helper in cast_mxfp4_to_nvfp4.py, keeping the cast-specific config logic colocated with the rest of the cast flow. Addresses review feedback on PR #1372. Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Add a GPT-OSS row + footnote to the llm_ptq support matrix and a new "MXFP4 -> NVFP4 cast (for GPT-OSS)" subsection covering usage, the closed-form per-block math, and the NVFP4-qformat / no-auto-quantize constraints. Add a one-line CHANGELOG entry under 0.45. Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Actionable comments posted: 1
♻️ Duplicate comments (3)
examples/llm_ptq/hf_ptq.py (1)
1444-1455: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
`--cast_mxfp4_to_nvfp4` validation does not account for `--recipe`.
When `--recipe` is set (lines 1032-1038), the effective quantization config comes from the recipe, not `args.qformat`. The current validation only checks `args.qformat`, which could incorrectly reject valid NVFP4 recipe runs or accept invalid non-NVFP4 recipes.
Suggested fix
if args.cast_mxfp4_to_nvfp4: - qformats = [q.strip() for q in args.qformat.split(",")] - if not all("nvfp4" in q for q in qformats): - raise ValueError( - "--cast_mxfp4_to_nvfp4 requires NVFP4-family --qformat values " - f"(got {args.qformat!r}). Use e.g. --qformat nvfp4 or nvfp4_mlp_only." - ) + if args.recipe is None: + qformats = [q.strip() for q in args.qformat.split(",")] + if not all("nvfp4" in q for q in qformats): + raise ValueError( + "--cast_mxfp4_to_nvfp4 requires NVFP4-family --qformat values " + f"(got {args.qformat!r}). Use e.g. --qformat nvfp4 or nvfp4_mlp_only." + ) + elif "nvfp4" not in args.recipe.lower(): + raise ValueError( + f"--cast_mxfp4_to_nvfp4 requires an NVFP4-family recipe (got {args.recipe!r})." + ) if args.auto_quantize_bits is not None: raise ValueError( "--cast_mxfp4_to_nvfp4 is not supported with --auto_quantize_bits " "(multi-format auto-quantize)." )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/llm_ptq/hf_ptq.py` around lines 1444 - 1455, The validation block for args.cast_mxfp4_to_nvfp4 currently inspects only args.qformat; change it to derive the effective qformat(s) the code will use when a recipe is provided (i.e., reuse the same recipe→quantization-config logic used earlier when applying args.recipe to produce the effective quant config) and validate against that effective list instead of raw args.qformat, then keep the existing auto_quantize_bits incompatibility check but based on the effective qformats; ensure you still raise the same ValueError messages (substituting the effective qformat string where appropriate) and reference args.cast_mxfp4_to_nvfp4, args.recipe, args.qformat and args.auto_quantize_bits in the updated logic.examples/llm_ptq/cast_mxfp4_to_nvfp4.py (2)
221-241: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win
Path traversal vulnerability: shard paths from index.json are not validated.
The `weight_map` entries in `model.safetensors.index.json` are user-controlled input. A crafted index could contain `../...`, absolute paths, or symlinks pointing outside the checkpoint directory, turning this into an arbitrary file read.
As per coding guidelines: "When adding/adjusting checkpoint casting/export tooling, treat all model/weight artifacts as untrusted: validate sizes/types to avoid crashes/DoS from malformed inputs."
Suggested fix
def _collect_keys_with_suffix(ckpt_dir: Path, suffix: str) -> dict[str, Path]: """Return ``{tensor_name: shard_path}`` for every key ending with ``suffix``.""" + ckpt_root = ckpt_dir.resolve() + + def _safe_shard_path(shard_name: str) -> Path: + candidate = (ckpt_dir / shard_name).resolve() + if not candidate.is_relative_to(ckpt_root): + raise ValueError(f"Shard path escapes checkpoint dir: {shard_name}") + return candidate + index_path = ckpt_dir / "model.safetensors.index.json" if index_path.is_file(): with index_path.open() as f: index = json.load(f) return { - k: ckpt_dir / shard for k, shard in index["weight_map"].items() if k.endswith(suffix) + k: _safe_shard_path(shard) + for k, shard in index["weight_map"].items() + if k.endswith(suffix) } shards = list(ckpt_dir.glob("*.safetensors"))🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/llm_ptq/cast_mxfp4_to_nvfp4.py` around lines 221 - 241, The weight_map entries read in _collect_keys_with_suffix are untrusted and may contain path traversal or absolute paths; after composing shard paths from index entries, resolve each candidate path (e.g., resolved = (ckpt_dir / shard).resolve()) and verify it is a regular file and resides inside the checkpoint directory (e.g., ensure resolved.is_file() and resolved.is_relative_to(ckpt_dir.resolve()) or compare parents), rejecting/raising if not; also reject absolute paths that escape ckpt_dir and avoid following symlinks that point outside by checking resolved path ownership, and only add validated paths to the returned dict (do similar validation for any branch that reads safe_open on a shard).
370-426: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Replace `assert` statements with explicit exceptions for checkpoint/model validation.
Lines 370, 384, 391, and 420 use `assert` to validate checkpoint contents and model state. These checks validate untrusted data and will be stripped under `python -O`, leaving no validation.
Suggested fix
blocks_shard = blocks_keys.get(blocks_key) - assert blocks_shard is not None, ( - f"{tensor_key}: no paired '{blocks_key}' tensor found in source checkpoint." - ) + if blocks_shard is None: + raise ValueError( + f"{tensor_key}: no paired '{blocks_key}' tensor found in source checkpoint." + ) weight_quantizer = name_to_module.get(qname) if weight_quantizer is None: missed.append(qname) continue - assert isinstance(weight_quantizer, NVFP4StaticQuantizer), ( - f"{qname}: expected NVFP4StaticQuantizer (set by max_calibrate's " - f"auto-promote), got {type(weight_quantizer).__name__}. The cast " - "requires the matching quantizer to be enabled with static-block " - "NVFP4 (num_bits=(2,1), scale_bits=(4,3))." - ) + if not isinstance(weight_quantizer, NVFP4StaticQuantizer): + raise TypeError( + f"{qname}: expected NVFP4StaticQuantizer (set by max_calibrate's " + f"auto-promote), got {type(weight_quantizer).__name__}. The cast " + "requires the matching quantizer to be enabled with static-block " + "NVFP4 (num_bits=(2,1), scale_bits=(4,3))." + ) existing = getattr(weight_quantizer, "_amax", None) - assert isinstance(existing, torch.Tensor) and existing.numel() > 1, ( - f"{qname}: NVFP4StaticQuantizer must have a per-block ``_amax`` " - f"buffer populated by max_calibrate. Got: {existing!r}." - ) + if not isinstance(existing, torch.Tensor) or existing.numel() <= 1: + raise ValueError( + f"{qname}: NVFP4StaticQuantizer must have a per-block ``_amax`` " + f"buffer populated by max_calibrate. Got: {existing!r}." + ) ... - assert existing.numel() == per_block_amax.numel(), ( - f"{qname}: ``_amax`` element count {existing.numel()} does not " - f"match the cast-computed count {per_block_amax.numel()}. The " - "block layout from calibration disagrees with the source MXFP4 " - "scales — check that the qformat block_size is 16 and the source " - "checkpoint is the same MXFP4 model." - ) + if existing.numel() != per_block_amax.numel(): + raise ValueError( + f"{qname}: ``_amax`` element count {existing.numel()} does not " + f"match the cast-computed count {per_block_amax.numel()}. The " + "block layout from calibration disagrees with the source MXFP4 " + "scales — check that the qformat block_size is 16 and the source " + "checkpoint is the same MXFP4 model." + )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/llm_ptq/cast_mxfp4_to_nvfp4.py` around lines 370 - 426, Replace the four runtime `assert` checks that validate untrusted checkpoint/model data with explicit exceptions so they are not stripped under python -O: where the code currently does `assert blocks_shard is not None, ...` raise a ValueError (or RuntimeError) with the same message referencing `tensor_key`/`blocks_key`; where it checks `if weight_quantizer is None: missed.append(qname); continue` keep that logic but replace the subsequent `assert isinstance(weight_quantizer, NVFP4StaticQuantizer), ...` with a TypeError (or ValueError) using `qname` and the `NVFP4StaticQuantizer` type name in the message; replace `assert isinstance(existing, torch.Tensor) and existing.numel() > 1, ...` with a ValueError mentioning `qname` and the `existing` repr; and replace the final `assert existing.numel() == per_block_amax.numel(), ...` with a ValueError that mentions `qname`, `existing.numel()`, and `per_block_amax.numel()` so validation always runs regardless of optimization flags.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/llm_ptq/cast_mxfp4_to_nvfp4.py`:
- Around line 209-218: Replace the assert in quantizer_name_from_blocks_key with
explicit input validation: in function quantizer_name_from_blocks_key, check if
blocks_key.endswith("_blocks") and if not raise a ValueError (including the
original message f"Unexpected key {blocks_key!r}"), then return the transformed
string as before; this ensures validation remains active under python -O and
clearly signals bad user input.
---
Duplicate comments:
In `@examples/llm_ptq/cast_mxfp4_to_nvfp4.py`:
- Around line 221-241: The weight_map entries read in _collect_keys_with_suffix
are untrusted and may contain path traversal or absolute paths; after composing
shard paths from index entries, resolve each candidate path (e.g., resolved =
(ckpt_dir / shard).resolve()) and verify it is a regular file and resides inside
the checkpoint directory (e.g., ensure resolved.is_file() and
resolved.is_relative_to(ckpt_dir.resolve()) or compare parents),
rejecting/raising if not; also reject absolute paths that escape ckpt_dir and
avoid following symlinks that point outside by checking resolved path ownership,
and only add validated paths to the returned dict (do similar validation for any
branch that reads safe_open on a shard).
- Around line 370-426: Replace the four runtime `assert` checks that validate
untrusted checkpoint/model data with explicit exceptions so they are not
stripped under python -O: where the code currently does `assert blocks_shard is
not None, ...` raise a ValueError (or RuntimeError) with the same message
referencing `tensor_key`/`blocks_key`; where it checks `if weight_quantizer is
None: missed.append(qname); continue` keep that logic but replace the subsequent
`assert isinstance(weight_quantizer, NVFP4StaticQuantizer), ...` with a
TypeError (or ValueError) using `qname` and the `NVFP4StaticQuantizer` type name
in the message; replace `assert isinstance(existing, torch.Tensor) and
existing.numel() > 1, ...` with a ValueError mentioning `qname` and the
`existing` repr; and replace the final `assert existing.numel() ==
per_block_amax.numel(), ...` with a ValueError that mentions `qname`,
`existing.numel()`, and `per_block_amax.numel()` so validation always runs
regardless of optimization flags.
In `@examples/llm_ptq/hf_ptq.py`:
- Around line 1444-1455: The validation block for args.cast_mxfp4_to_nvfp4
currently inspects only args.qformat; change it to derive the effective
qformat(s) the code will use when a recipe is provided (i.e., reuse the same
recipe→quantization-config logic used earlier when applying args.recipe to
produce the effective quant config) and validate against that effective list
instead of raw args.qformat, then keep the existing auto_quantize_bits
incompatibility check but based on the effective qformats; ensure you still
raise the same ValueError messages (substituting the effective qformat string
where appropriate) and reference args.cast_mxfp4_to_nvfp4, args.recipe,
args.qformat and args.auto_quantize_bits in the updated logic.
```python
def quantizer_name_from_blocks_key(blocks_key: str) -> str:
    """Map ``<base>_blocks`` -> ``<base>_weight_quantizer``.

    OpenAI's MXFP4 checkpoint convention stores packed weights as
    ``<name>_blocks`` and scales as ``<name>_scales``. modelopt's
    ``GptOssExperts`` wrapper attaches the weight quantizer at
    ``<name>_weight_quantizer``.
    """
    assert blocks_key.endswith("_blocks"), f"Unexpected key {blocks_key!r}"
    return blocks_key[: -len("_blocks")] + "_weight_quantizer"
```
Replace assert with explicit exception for input validation.
This validates user-provided checkpoint key data, not a programming invariant. Under python -O the assert is disabled, leaving no validation.
Suggested fix
def quantizer_name_from_blocks_key(blocks_key: str) -> str:
- assert blocks_key.endswith("_blocks"), f"Unexpected key {blocks_key!r}"
+ if not blocks_key.endswith("_blocks"):
+ raise ValueError(f"Unexpected key {blocks_key!r}: expected '*_blocks' suffix")
     return blocks_key[: -len("_blocks")] + "_weight_quantizer"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/llm_ptq/cast_mxfp4_to_nvfp4.py` around lines 209 - 218, Replace the
assert in quantizer_name_from_blocks_key with explicit input validation: in
function quantizer_name_from_blocks_key, check if blocks_key.endswith("_blocks")
and if not raise a ValueError (including the original message f"Unexpected key
{blocks_key!r}"), then return the transformed string as before; this ensures
validation remains active under python -O and clearly signals bad user input.
Make the new --cast_mxfp4_to_nvfp4 entry's link anonymous (trailing double underscore) so it doesn't collide with the existing named target for the same README text on the multinode_ptq entry. Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
Summary
- New `--cast_mxfp4_to_nvfp4` flag in `hf_ptq.py` (and `huggingface_example.sh`) that converts an MXFP4 source checkpoint (e.g. `openai/gpt-oss-20b`) into an NVFP4 export with bit-exact weight reconstruction for the in-range blocks.
- The cast pins `scale_2 = 2^m` (where `m = k_max − 8`) and `_amax = 6·2^k_j` per NVFP4 block, both read from the source `*_scales`. The resulting per-block scale `2^(k_j − m)` is exactly representable in E4M3, so `round_to_E2M1(value / 2^k_j)` yields the original MXFP4 nibble verbatim. For out-of-range blocks (`k_max − k_j > 17`) the per-block amax falls back to data-derived `max(|w_block|)`, which keeps the post-E4M3-clamp scale close to the block's actual magnitude.

Verification
- End-to-end on `openai/gpt-oss-20b` with `--qformat=nvfp4_mlp_only --cast_mxfp4_to_nvfp4`:
- End-to-end on `openai/gpt-oss-120b` with the same flags (4×B200, `--use_seq_device_map --gpu_max_mem_percentage 0.5 --calib_batch_size 4`): Five layers fall into the OOR regime (block-spread > 17); the remaining 1,214 OOR blocks use the data-derived per-block amax fallback. Block-level losslessness is 99.99996% end-to-end.
- Per-tensor MSE between MXFP4 source dequant and NVFP4 export dequant (~19B elements):

Modelopt-side enablers
- `max_calibrate` auto-promotes static-block NVFP4 weight quantizers to `NVFP4StaticQuantizer` at the end of calibration.
- `static_blockwise_fp4_fake_quant` kernel accepts N-D inputs (was 2D-only), unblocking MoE expert weights of shape `(E, F, K)`.
- BMM-experts NVFP4 export routes through `get_weights_scaling_factor_from_quantizer` for static-mode quantizers, so the pinned `_amax` is actually consumed.
- `set_expert_quantizer_amax` scalar-reduces per-quantizer amax before stacking, supporting per-block (vs scalar) static-mode amax.

Test plan
- `tests/examples/llm_ptq/test_cast_mxfp4_to_nvfp4.py` (15 tests, all passing) cover: scalar/global-amax math, per-block hybrid (in-range closed-form vs OOR data-derived), shape preservation, key collection, and end-to-end `build_amax_map` against a synthetic safetensors checkpoint.
- `openai/gpt-oss-20b` (`nvfp4_mlp_only` qformat) with `--cast_mxfp4_to_nvfp4` succeeds; export takes ~21 s. 100% lossless cast (48/48 layers, 597,196,800 / 597,196,800 blocks).
- `openai/gpt-oss-120b` (4×B200, `nvfp4_mlp_only`, `--use_seq_device_map --gpu_max_mem_percentage 0.5 --calib_batch_size 4`). 67/72 layers fully lossless; 99.99996% block-level losslessness (3,583,179,586 / 3,583,180,800).
- `examples/llm_ptq/run_tensorrt_llm.py`:

🤖 Generated with Claude Code
Summary by CodeRabbit
New Features
Bug Fixes
Tests
Documentation