Skip Softmax diffusion export #1269

Open

jingyu-ml wants to merge 45 commits into main from jingyux/diffusion-skip-softmax-2

Conversation

Contributor

@jingyu-ml jingyu-ml commented Apr 15, 2026

What does this PR do?

Type of change: New Feature

Adds HuggingFace checkpoint export for diffusion pipelines calibrated with skip-softmax, on top of the base skip-softmax MR (jingyux/diffusion-skip-softmax). Concretely:

  • _export_diffusers_checkpoint now walks every nn.Module component of a diffusers pipeline, calls export_sparse_attention_config, injects the result into that component's config.json as sparse_attention_config, and additionally writes a unified top-level sparse.yaml keyed by pipeline component (transformer, transformer_2, …). The existing LLM export_hf_checkpoint path also gains a sibling sparse.yaml dump (a minimal sketch of the component walk follows this list).
  • export_sparse_attention_config is generalized: per-group nesting (group_0.threshold_scale_factor, group_0.raw_threshold, group_0.disabled_layers) so future sparse methods can coexist, plus per-layer disabled_layers reporting and a raw_threshold-only path for uncalibrated use.
  • Log-space calibration export: the calibrator now propagates log_a / fit_logspace through its result dict, and the exporter emits the matching formula: "log_a + b * target_sparsity" for diffusion (linear-space a * exp(b * S) is still used for LLMs).
  • Example wiring: examples/diffusers/sparsity/wan22_skip_softmax.py gets an --export-dir flag that calls export_hf_checkpoint(pipe, export_dir=...) after calibration.
  • Updated CHANGELOG.rst to note diffusion coverage for skip-softmax.
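
A minimal sketch of the component walk described in the first bullet (illustrative only: collect_sparse_configs is a hypothetical name; only export_sparse_attention_config comes from this PR, and diffusers pipelines expose their parts via .components):

import torch.nn as nn

def collect_sparse_configs(pipe) -> dict:
    """Hypothetical helper mirroring the export walk."""
    configs = {}
    for name, component in pipe.components.items():  # transformer, transformer_2, vae, ...
        if isinstance(component, nn.Module):
            cfg = export_sparse_attention_config(component)  # from the base skip-softmax MR
            if cfg is not None:
                configs[name] = cfg  # becomes one top-level key in sparse.yaml
    return configs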

Usage

python examples/diffusers/sparsity/wan22_skip_softmax.py \
    --model-path Wan-AI/Wan2.2-TI2V-5B-Diffusers \
    --calibrate \
    --target-sparsity 0.5 \
    --export-dir ./wan22_skip_softmax_ckpt

Equivalent Python:

from diffusers import WanPipeline
from modelopt.torch.export import export_hf_checkpoint
import modelopt.torch.sparsity.attention_sparsity as mtsa

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.2-TI2V-5B-Diffusers")
mtsa.calibrate(pipe.transformer, ...)          # from the base skip-softmax MR
export_hf_checkpoint(pipe, export_dir="./wan22_skip_softmax_ckpt")

Resulting layout:

wan22_skip_softmax_ckpt/
├── sparse.yaml                      # unified, keyed by component
├── transformer/
│   └── config.json                  # carries sparse_attention_config
├── transformer_2/
│   └── config.json
├── vae/ …
└── scheduler/ …
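
The unified YAML can then be inspected with standard tooling. A minimal sketch (the per-component schema is an assumption, mirroring the config.json entry shown next):

import yaml

with open("wan22_skip_softmax_ckpt/sparse.yaml") as f:
    sparse_cfgs = yaml.safe_load(f)
print(list(sparse_cfgs))  # e.g. ['transformer', 'transformer_2']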

A representative config.json entry for a diffusion component:

"sparse_attention_config": {
  "config_groups": {
    "group_0": {
      "sparse_algo": "softmax_skip",
      "targets": ["WanAttention"],
      "threshold_scale_factor": {
        "formula": "log_a + b * target_sparsity",
        "prefill": {"log_a": 0.21, "b": 3.45}
      },
      "disabled_layers": ["blocks.0.attn1", "blocks.39.attn1"]
    }
  },
  "producer": {"name": "modelopt", "version": "0.37.0"}
}
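
As a worked example of consuming this entry (a sketch, not ModelOpt API: it assumes the diffusion formula string denotes log(scale_factor), and that the runtime threshold is scale_factor / seq_k as in the inference code quoted later in this review):

import math

log_a, b = 0.21, 3.45        # "prefill" values from the example above
target_sparsity = 0.5

scale_factor = math.exp(log_a + b * target_sparsity)  # log-space fit (diffusion)
# LLM checkpoints instead store a linear-space fit: scale_factor = a * math.exp(b * S)

seq_k = 32760                # hypothetical KV sequence length
threshold = scale_factor / seq_k  # dynamic threshold applied per attention call
print(f"scale_factor={scale_factor:.3f}, threshold={threshold:.3e}")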

Testing

  • Extended tests/examples/diffusers/test_sparsity.py with a calibrate → export → reload round-trip on a small diffusion pipeline, asserting the presence and shape of sparse_attention_config in each component's config.json and the unified top-level sparse.yaml.
  • Manually verified on Wan2.2-T2V-14B: sparse.yaml and transformer{,_2}/config.json contain the expected log-space threshold_scale_factor, any disabled layers, and producer metadata; a freshly loaded pipeline from the exported checkpoint reproduces the calibrated sparsity target end-to-end.
  • LLM export path regression-checked by re-running existing test_sparsity tests — the new sparse.yaml sibling is emitted without changing the existing config.json patching behavior.

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ✅

Additional Information

Summary by CodeRabbit

  • New Features

    • Sparse attention now supports video diffusion models in addition to language models
    • New CLI option to export sparsified diffusers checkpoints to a user-specified directory
    • Exports include a top-level YAML summary of per-component sparse attention configs alongside existing JSON configs
    • Calibration can optionally fit in log-space and records observed sparsity bounds
  • Tests

    • Added end-to-end and unit tests covering diffusers sparsity flows and kernel backend registration/configuration

jingyu-ml and others added 30 commits April 2, 2026 06:02
@jingyu-ml jingyu-ml requested review from a team as code owners April 15, 2026 22:00
@jingyu-ml jingyu-ml requested review from ajrasane and kaix-nv April 15, 2026 22:00

copy-pr-bot Bot commented Apr 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@jingyu-ml jingyu-ml requested review from ChenhanYu and realAsma April 15, 2026 22:00
@jingyu-ml jingyu-ml marked this pull request as draft April 15, 2026 22:00
Contributor

coderabbitai Bot commented Apr 15, 2026

📝 Walkthrough
🚥 Pre-merge checks | ✅ 6
✅ Passed checks (6 passed)
  • Description Check: ✅ Passed. Check skipped - CodeRabbit's high-level summary is enabled.
  • Title check: ✅ Passed. The title 'Skip Softmax diffusion export' directly summarizes the main change: adding HuggingFace checkpoint export support for diffusers pipelines with skip-softmax sparsity.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 81.16%, which meets the required threshold of 80.00%.
  • Linked Issues check: ✅ Passed. Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check: ✅ Passed. Check skipped because no linked issues were found for this pull request.
  • Security Anti-Patterns: ✅ Passed. No security anti-patterns detected: no unsafe torch.load, numpy.load, eval/exec, or yaml.load calls found in modified files.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@jingyu-ml jingyu-ml self-assigned this Apr 15, 2026
Contributor

github-actions Bot commented Apr 15, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1269/

Built to branch gh-pages at 2026-05-04 22:44 UTC.
Preview will be ready when the GitHub Pages deployment is complete.


codecov Bot commented Apr 15, 2026

Codecov Report

❌ Patch coverage is 25.38330% with 438 lines in your changes missing coverage. Please review.
✅ Project coverage is 55.31%. Comparing base (04fcf24) to head (b77d098).

Files with missing lines Patch % Lines
modelopt/torch/kernels/triton_fa.py 0.00% 108 Missing ⚠️
.../attention_sparsity/methods/triton_skip_softmax.py 16.37% 97 Missing ⚠️
...attention_sparsity/kernels/ltx_triton_attention.py 19.54% 70 Missing ⚠️
...ion_sparsity/kernels/diffusers_triton_attention.py 48.51% 52 Missing ⚠️
...rsity/attention_sparsity/calibration/calibrator.py 9.75% 37 Missing ⚠️
...pt/torch/sparsity/attention_sparsity/conversion.py 40.42% 28 Missing ⚠️
modelopt/torch/export/unified_export_hf.py 26.08% 17 Missing ⚠️
...arsity/attention_sparsity/calibration/calibrate.py 19.04% 17 Missing ⚠️
...y/attention_sparsity/methods/flash_skip_softmax.py 14.28% 6 Missing ⚠️
modelopt/torch/kernels/__init__.py 33.33% 2 Missing ⚠️
... and 4 more
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1269       +/-   ##
===========================================
- Coverage   75.58%   55.31%   -20.28%     
===========================================
  Files         459      460        +1     
  Lines       48612    49345      +733     
===========================================
- Hits        36745    27295     -9450     
- Misses      11867    22050    +10183     
Flag Coverage Δ
unit 51.95% <25.38%> (-0.27%) ⬇️

Flags with carried forward coverage won't be shown.


jingyu-ml and others added 5 commits April 15, 2026 22:14
@jingyu-ml jingyu-ml changed the title from "Jingyux/diffusion skip softmax 2" to "Skip Softmax diffusion export" Apr 16, 2026
@jingyu-ml jingyu-ml marked this pull request as ready for review April 17, 2026 00:03
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 6

🧹 Nitpick comments (5)
modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py (1)

387-394: Minor: Consider wrapping set_skip_softmax_context(True) in the ExitStack for exception safety.

If an exception occurs after set_skip_softmax_context(True) but before the callback is registered (e.g., during stack.enter_context), the skip-softmax context would remain enabled without cleanup.

♻️ Safer ordering
         from ..kernels import set_skip_softmax_context

         stack = ExitStack()
+        stack.callback(set_skip_softmax_context, False)
         set_skip_softmax_context(True)
-        stack.callback(set_skip_softmax_context, False)

         stack.enter_context(replace_function(torch.nn.functional, "softmax", sparse_softmax))
         return stack

This ensures the cleanup callback is registered before the state is modified, so any subsequent exception during setup will still trigger cleanup when the stack is garbage collected or explicitly closed.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`
around lines 387 - 394, Register the cleanup callback on the ExitStack before
enabling the skip-softmax flag to ensure exception safety: call
stack.callback(set_skip_softmax_context, False) first, then call
set_skip_softmax_context(True), and only after that perform
stack.enter_context(replace_function(torch.nn.functional, "softmax",
sparse_softmax)); this guarantees that if an exception occurs during
enter_context the skip-softmax state (managed by set_skip_softmax_context) will
still be cleaned up.
modelopt/torch/export/diffusers_utils.py (1)

49-59: Consider making this a one-time warning or moving it to actual usage.

This warning fires at module import time whenever ltx_pipelines is installed, which may be noisy for users who import diffusers_utils but don't use LTX-2 features. Additionally, stacklevel=2 at module load time may not point to a meaningful location.

Consider using warnings.warn(..., stacklevel=2) with filterwarnings to show once, or deferring the warning to actual LTX-2 usage (similar to line 395-404 in _ltx2_dummy_forward).

♻️ One-time warning option
import warnings

# At the top of the module
_LTX2_LICENSE_WARNING_SHOWN = False

# Inside the try block
try:
    from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline as _TI2VidTwoStagesPipeline
    
    def _show_ltx2_license_warning():
        global _LTX2_LICENSE_WARNING_SHOWN
        if not _LTX2_LICENSE_WARNING_SHOWN:
            warnings.warn(
                "LTX-2 packages ... (license text)",
                UserWarning,
                stacklevel=3,
            )
            _LTX2_LICENSE_WARNING_SHOWN = True
    
    TI2VidTwoStagesPipeline = _TI2VidTwoStagesPipeline
except Exception:
    TI2VidTwoStagesPipeline = None

Then call _show_ltx2_license_warning() in functions that actually use LTX-2.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/diffusers_utils.py` around lines 49 - 59, The module
currently emits a loud license UserWarning at import via the warnings.warn call;
change this to a one-time or deferred warning: wrap the import of ltx_pipelines
and the current warnings.warn invocation behind a module-level flag (e.g.,
_LTX2_LICENSE_WARNING_SHOWN) or remove the warn from import and instead call a
small helper like _show_ltx2_license_warning() from actual LTX-2 entrypoints
(for example inside TI2VidTwoStagesPipeline usage code or _ltx2_dummy_forward)
so the warning is emitted at first use only; keep stacklevel at an appropriate
value (e.g., 3) when invoking warnings.warn to point to user code and ensure
TI2VidTwoStagesPipeline remains set to None on import failure.
modelopt/torch/export/unified_export_hf.py (1)

1261-1266: Remove redundant import.

The yaml module is already imported at line 31. This inline import is unnecessary.

♻️ Proposed fix
                 config_data["sparse_attention_config"] = sparse_attn_config

                 # Also save as standalone YAML for easy inspection and reuse
-                import yaml
-
                 yaml_path = Path(export_dir) / "sparse.yaml"
                 with open(yaml_path, "w") as file:
                     yaml.dump(sparse_attn_config, file, default_flow_style=False, sort_keys=False)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_hf.py` around lines 1261 - 1266, Remove
the redundant inline "import yaml" inside the export block; the module is
already imported at the top of the file, so delete the inline import and keep
the yaml usage (yaml_path = Path(export_dir) / "sparse.yaml" and
yaml.dump(sparse_attn_config, ...)) as-is to avoid duplicate imports and retain
the YAML file write behavior.
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1)

314-318: Consider documenting this asymmetry.

The decode phase always requires calibration_data and tokenizer (RULER-based), even when a custom forward_loop was provided for prefill. This asymmetry between prefill (supports custom forward_loop) and decode (always requires RULER) could be confusing.

Consider adding a docstring note or raising a more descriptive error message explaining why decode requires the RULER dataset.
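
One possible shape for that error message (illustrative; the symbol names follow the AI prompt below):

if calibration_data is None or tokenizer is None:
    raise RuntimeError(
        "Decode calibration always requires RULER-style calibration_data and a "
        "tokenizer (consumed by create_decode_calibration_forward_loop), even "
        "when a custom prefill forward_loop is provided."
    )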

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` around
lines 314 - 318, Update the documentation and error to make the prefill/decode
asymmetry explicit: add a docstring note on the top-level function in
calibrate.py (where decode_forward_loop is created) explaining that prefill
accepts a custom forward_loop but decode always requires a RULER-style
calibration_data and tokenizer because decode uses
create_decode_calibration_forward_loop; and replace the RuntimeError raised when
calibration_data or tokenizer is missing with a more descriptive message that
states "decode requires a RULER-style calibration_data and tokenizer (used by
create_decode_calibration_forward_loop) even if a custom prefill forward_loop
was provided." Reference the symbols calibration_data, tokenizer,
create_decode_calibration_forward_loop, and decode_forward_loop in the
docstring/error.
examples/diffusers/sparsity/wan22_skip_softmax.py (1)

56-63: Lazy-load the optional datasets/diffusers dependencies.

This example hard-imports optional integrations at module load time, which makes simple imports fail unless the full diffusers stack is already installed. Moving these imports into build_pipeline(), load_calib_prompts(), and main() keeps the example gated behind the right extras and avoids breaking unrelated tooling that imports example modules. As per coding guidelines, "Gate optional features by install extras ([onnx], [hf], [all]); avoid hard imports of optional dependencies at module level" and "Use the plugin system with import_plugin() for lazy loading of optional integrations (HuggingFace, Megatron, etc.)".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/sparsity/wan22_skip_softmax.py` around lines 56 - 63, The
module currently hard-imports optional packages (datasets, diffusers,
diffusers.utils.export_to_video and diffusers classes
AutoencoderKLWan/WanPipeline) at top-level; move those imports into the
functions that actually use them (e.g., build_pipeline(), load_calib_prompts(),
and main()) so the example can be imported without the optional extras
installed, and prefer the plugin loader where available (e.g., call
import_plugin or similar before importing HuggingFace/diffusers modules) to gate
the integrations; update references to SparseAttentionModule and any diffusers
types after the local imports so runtime code uses the lazily-loaded symbols.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/diffusers/sparsity/wan22_skip_softmax.py`:
- Around line 271-276: The function load_calib_prompts currently loads the
entire "train" split which is wasteful; change it to only load the first
calib_size examples by using a sliced split (e.g. "train[:{calib_size}]" or
equivalent) or streaming so only the needed items are materialized, then collect
the captions and return them; update load_calib_prompts to call load_dataset
with the sliced split string based on the calib_size parameter and build the
prompts list from that smaller dataset.

In `@modelopt/torch/kernels/triton_fa.py`:
- Around line 1241-1246: The program index calculation uses a per-program tile
count computed from tl.load(b_seq_len + 0) which collides when batches have
variable lengths; replace the local num_q_tiles = tl.cdiv(tl.load(b_seq_len +
0), BLOCK_M) with a tile count derived from the same max_input_len used when
allocating counters (e.g. num_q_tiles = tl.cdiv(max_input_len, BLOCK_M)), ensure
the kernel signature accepts that max_input_len scalar (or otherwise pass the
same allocation param) and update uses of num_q_tiles/prog_idx/base so each
(batch, head, tile) maps to a unique counter slot matching
attention_calibrate()'s allocation.

In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 127-136: The current broad try/except hides real errors from
register_diffusers_triton_attention() so integration failures are swallowed;
change the logic to only suppress ImportError when importing ModelMixin but
allow exceptions from register_diffusers_triton_attention() to surface: first
try importing ModelMixin and on ImportError simply return/skip, then if
isinstance(model, ModelMixin) import register_diffusers_triton_attention and if
it's not None call register_diffusers_triton_attention() without a broad except
so any runtime errors propagate (or re-raise after logging) — refer to
ModelMixin and register_diffusers_triton_attention to locate and update the
conversion.py block.

In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`:
- Around line 123-213: The function _diffusers_triton_attention currently
accepts attn_mask, dropout_p, and enable_gqa but ignores them; either remove
these params from the signature and ensure the backend's _supported_arg_names
(where supported args are derived) no longer lists them, or implement their
semantics: detect attn_mask != None and convert it into the Triton/kernel
expected mask metadata (or pass it via kw as "attn_mask" / appropriate key),
handle dropout_p > 0 by passing a "dropout_p" kw and ensuring training mode
semantics are respected, and honor enable_gqa by adjusting q/k/v shapes/scale
(grouped-query attention layout changes) and passing an "enable_gqa" or
equivalent flag into kw; update callers/registration so _supported_arg_names
matches the final signature/behavior.

In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py`:
- Around line 164-168: The current __call__ wrapper incorrectly routes masked
attention through _ltx_triton_attention which ignores mask and forces
is_causal=False; change the conditional to only use the Triton path when active
and mask is None (e.g., if active and mask is None: return
_ltx_triton_attention(...)); otherwise always call and return
self._original_fn(q, k, v, heads, mask) so mask semantics are preserved. Ensure
you reference _get_ltx_triton_context, _ltx_triton_attention, and
self._original_fn when making the change.

In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`:
- Around line 255-269: get_threshold_info() currently always reports the static
lambda threshold even when _triton_inference_context() supplies a
skip_softmax_raw_threshold used by the kernel; update get_threshold_info to
check the triton inference context (via _triton_inference_context()) and, if a
skip_softmax_raw_threshold is present, return a "raw" type with that raw
threshold value (and note any related keys like skip_softmax_threshold or
target_sparse_ratio as applicable); otherwise preserve the existing
dynamic_calibrated/static return paths so the reported info matches the actual
value used by the kernel.

---

Nitpick comments:
In `@examples/diffusers/sparsity/wan22_skip_softmax.py`:
- Around line 56-63: The module currently hard-imports optional packages
(datasets, diffusers, diffusers.utils.export_to_video and diffusers classes
AutoencoderKLWan/WanPipeline) at top-level; move those imports into the
functions that actually use them (e.g., build_pipeline(), load_calib_prompts(),
and main()) so the example can be imported without the optional extras
installed, and prefer the plugin loader where available (e.g., call
import_plugin or similar before importing HuggingFace/diffusers modules) to gate
the integrations; update references to SparseAttentionModule and any diffusers
types after the local imports so runtime code uses the lazily-loaded symbols.

In `@modelopt/torch/export/diffusers_utils.py`:
- Around line 49-59: The module currently emits a loud license UserWarning at
import via the warnings.warn call; change this to a one-time or deferred
warning: wrap the import of ltx_pipelines and the current warnings.warn
invocation behind a module-level flag (e.g., _LTX2_LICENSE_WARNING_SHOWN) or
remove the warn from import and instead call a small helper like
_show_ltx2_license_warning() from actual LTX-2 entrypoints (for example inside
TI2VidTwoStagesPipeline usage code or _ltx2_dummy_forward) so the warning is
emitted at first use only; keep stacklevel at an appropriate value (e.g., 3)
when invoking warnings.warn to point to user code and ensure
TI2VidTwoStagesPipeline remains set to None on import failure.

In `@modelopt/torch/export/unified_export_hf.py`:
- Around line 1261-1266: Remove the redundant inline "import yaml" inside the
export block; the module is already imported at the top of the file, so delete
the inline import and keep the yaml usage (yaml_path = Path(export_dir) /
"sparse.yaml" and yaml.dump(sparse_attn_config, ...)) as-is to avoid duplicate
imports and retain the YAML file write behavior.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py`:
- Around line 314-318: Update the documentation and error to make the
prefill/decode asymmetry explicit: add a docstring note on the top-level
function in calibrate.py (where decode_forward_loop is created) explaining that
prefill accepts a custom forward_loop but decode always requires a RULER-style
calibration_data and tokenizer because decode uses
create_decode_calibration_forward_loop; and replace the RuntimeError raised when
calibration_data or tokenizer is missing with a more descriptive message that
states "decode requires a RULER-style calibration_data and tokenizer (used by
create_decode_calibration_forward_loop) even if a custom prefill forward_loop
was provided." Reference the symbols calibration_data, tokenizer,
create_decode_calibration_forward_loop, and decode_forward_loop in the
docstring/error.

In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`:
- Around line 387-394: Register the cleanup callback on the ExitStack before
enabling the skip-softmax flag to ensure exception safety: call
stack.callback(set_skip_softmax_context, False) first, then call
set_skip_softmax_context(True), and only after that perform
stack.enter_context(replace_function(torch.nn.functional, "softmax",
sparse_softmax)); this guarantees that if an exception occurs during
enter_context the skip-softmax state (managed by set_skip_softmax_context) will
still be cleaned up.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: b4db3917-e462-4d17-b588-79e5f63acc4d

📥 Commits

Reviewing files that changed from the base of the PR and between 04fcf24 and b77d098.

📒 Files selected for processing (23)
  • CHANGELOG.rst
  • examples/diffusers/README.md
  • examples/diffusers/sparsity/README.md
  • examples/diffusers/sparsity/wan22_skip_softmax.py
  • modelopt/torch/export/diffusers_utils.py
  • modelopt/torch/export/unified_export_hf.py
  • modelopt/torch/kernels/__init__.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/conversion.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py
  • modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/methods/registry.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py
  • modelopt/torch/sparsity/attention_sparsity/stats_manager.py
  • tests/_test_utils/torch/diffusers_models.py
  • tests/examples/diffusers/test_sparsity.py
  • tests/unit/torch/sparsity/attention_sparsity/test_kernel_backends.py

Comment on lines +271 to +276
def load_calib_prompts(calib_size: int) -> list[str]:
    """Load calibration prompts from OpenVid-1M dataset."""
    dataset = load_dataset("nkp37/OpenVid-1M", split="train")
    prompts = list(dataset["caption"][:calib_size])
    print(f"Loaded {len(prompts)} calibration prompts from OpenVid-1M")
    return prompts
Contributor


⚠️ Potential issue | 🟠 Major

Don’t load the full OpenVid-1M split just to grab a few prompts.

load_dataset(..., split="train") materializes the whole split metadata, and then this code only uses the first calib_size captions. For the default calib_size=4, that is a lot of unnecessary I/O and makes calibration much slower and more fragile than it needs to be.

🔧 Suggested fix
 def load_calib_prompts(calib_size: int) -> list[str]:
     """Load calibration prompts from OpenVid-1M dataset."""
-    dataset = load_dataset("nkp37/OpenVid-1M", split="train")
-    prompts = list(dataset["caption"][:calib_size])
+    dataset = load_dataset("nkp37/OpenVid-1M", split=f"train[:{calib_size}]")
+    prompts = list(dataset["caption"])
     print(f"Loaded {len(prompts)} calibration prompts from OpenVid-1M")
     return prompts
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/sparsity/wan22_skip_softmax.py` around lines 271 - 276,
The function load_calib_prompts currently loads the entire "train" split which
is wasteful; change it to only load the first calib_size examples by using a
sliced split (e.g. "train[:{calib_size}]" or equivalent) or streaming so only
the needed items are materialized, then collect the captions and return them;
update load_calib_prompts to call load_dataset with the sliced split string
based on the calib_size parameter and build the prompts list from that smaller
dataset.

Comment thread modelopt/torch/kernels/triton_fa.py Outdated
Comment on lines +1241 to +1246
# --- Write per-program counters (no atomics, just stores) ---
# Compute unique flat program index for this (batch, head, q_tile)
num_q_tiles = tl.cdiv(tl.load(b_seq_len + 0), BLOCK_M) # conservative upper bound
num_heads = tl.num_programs(1)
prog_idx = batch_idx * num_heads * num_q_tiles + head_idx * num_q_tiles + tile_q
base = prog_idx * NUM_THRESHOLDS
Contributor


⚠️ Potential issue | 🟠 Major

Program indexing collides for variable-length batches.

attention_calibrate() allocates one counter row per launched program using triton.cdiv(max_input_len, BLOCK_M), but the kernel flattens (batch, head, tile) with tl.cdiv(tl.load(b_seq_len + 0), BLOCK_M). If batch element 0 is shorter than another sequence, different programs write into the same counter slots and the exported calibration stats are wrong.

🔧 Suggested fix
-    num_q_tiles = tl.cdiv(tl.load(b_seq_len + 0), BLOCK_M)  # conservative upper bound
+    num_q_tiles = tl.num_programs(2)
     num_heads = tl.num_programs(1)
     prog_idx = batch_idx * num_heads * num_q_tiles + head_idx * num_q_tiles + tile_q
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 1241 - 1246, The program
index calculation uses a per-program tile count computed from tl.load(b_seq_len
+ 0) which collides when batches have variable lengths; replace the local
num_q_tiles = tl.cdiv(tl.load(b_seq_len + 0), BLOCK_M) with a tile count derived
from the same max_input_len used when allocating counters (e.g. num_q_tiles =
tl.cdiv(max_input_len, BLOCK_M)), ensure the kernel signature accepts that
max_input_len scalar (or otherwise pass the same allocation param) and update
uses of num_q_tiles/prog_idx/base so each (batch, head, tile) maps to a unique
counter slot matching attention_calibrate()'s allocation.

Comment thread modelopt/torch/sparsity/attention_sparsity/conversion.py
Comment on lines +123 to +213
def _diffusers_triton_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """Compute attention via Triton FA kernel on diffusers layout ``[B, S, H, D]``."""
    batch, seq_q, num_heads_q, head_dim = query.shape
    seq_k = key.shape[1]
    device = query.device

    # Reshape from diffusers [B, S, H, D] -> flat [B*S, H, D]
    q = query.reshape(batch * seq_q, num_heads_q, head_dim).contiguous()
    k = key.reshape(batch * seq_k, key.shape[2], head_dim).contiguous()
    v = value.reshape(batch * seq_k, value.shape[2], head_dim).contiguous()

    # Build varlen metadata
    b_start_loc_q = torch.arange(batch, device=device, dtype=torch.int32) * seq_q
    b_seq_len_q = torch.full((batch,), seq_q, device=device, dtype=torch.int32)

    if scale is None:
        scale = 1.0 / math.sqrt(head_dim)

    kw: dict = {
        "b_start_loc": b_start_loc_q,
        "b_seq_len": b_seq_len_q,
        "max_input_len": seq_q,
        "is_causal": is_causal,
        "softmax_scale": scale,
    }

    if seq_q != seq_k:
        b_start_loc_k = torch.arange(batch, device=device, dtype=torch.int32) * seq_k
        b_seq_len_k = torch.full((batch,), seq_k, device=device, dtype=torch.int32)
        kw["b_start_loc_k"] = b_start_loc_k
        kw["b_seq_len_k"] = b_seq_len_k
        kw["max_input_len_k"] = seq_k

    # --- Calibration mode: collect multi-threshold stats ---
    calib_mode = getattr(_thread_local, "calibration_mode", False)
    if calib_mode:
        trials = getattr(_thread_local, "threshold_trials", None)
        if trials and attention_calibrate is not None:
            o, counters = attention_calibrate(q, k, v, **kw, threshold_trials=trials)

            # Accumulate counters across all attention calls in this forward pass
            prev = getattr(_thread_local, "calibration_counters", None)
            if prev is None:
                _thread_local.calibration_counters = counters
            else:
                _thread_local.calibration_counters = prev + counters

            # Store actual KV sequence length for calibration stats
            _thread_local.calibration_seq_k = seq_k

            return o.view(batch, seq_q, num_heads_q, head_dim)

    # --- Inference mode: skip-softmax with raw, dynamic, or static threshold ---
    raw_thresh = getattr(_thread_local, "raw_threshold", None)
    if raw_thresh is not None:
        # Raw threshold: passed directly to kernel as skip_threshold_log2
        kw["skip_softmax_raw_threshold"] = raw_thresh
    else:
        scale_factor = getattr(_thread_local, "scale_factor", None)
        if scale_factor is not None and scale_factor > 0.0:
            # Dynamic threshold: adapt to actual sequence length
            kw["skip_softmax_threshold"] = scale_factor / seq_k
        else:
            threshold = getattr(_thread_local, "skip_threshold", None)
            if threshold is not None and threshold > 0.0:
                kw["skip_softmax_threshold"] = threshold

    assert attention is not None, "Triton attention kernel not available (requires CUDA + triton)"
    do_measure = getattr(_thread_local, "measure_sparsity", False)
    if do_measure:
        kw["measure_sparsity"] = True
    o = attention(q, k, v, **kw)

    # Accumulate runtime sparsity counters from the kernel output
    if do_measure and hasattr(o, "_sparsity_total"):
        prev_total = getattr(_thread_local, "sparsity_total", 0)
        prev_skipped = getattr(_thread_local, "sparsity_skipped", 0)
        _thread_local.sparsity_total = prev_total + o._sparsity_total
        _thread_local.sparsity_skipped = prev_skipped + o._sparsity_skipped

    return o.view(batch, seq_q, num_heads_q, head_dim)

Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

In diffusers 0.32.2 custom attention backends, can attn_mask, dropout_p, or enable_gqa be non-default for registered backends, and are custom backends expected to preserve those semantics?

💡 Result:

In diffusers 0.32.2, custom attention backends registered via _AttentionBackendRegistry can have non-default support for attn_mask, dropout_p, and enable_gqa. The dispatch_attention_fn passes these parameters (along with query, key, value, is_causal, scale, attention_kwargs) to the registered backend function. Only parameters matching the backend's supported_arg_names are passed; unsupported ones are filtered out without error (though checks may raise if enabled). Custom backends are expected to handle these parameters if listed in their supported args during registration, or ignore/filter them as appropriate to preserve semantics when non-None/non-default. Some registered backends do not support attn_mask, leading to silent ignoring unless checks are enabled (DIFFUSERS_ATTN_CHECKS=yes), which validates shapes and raises errors for mismatches.



🏁 Script executed:

# Find the file and look for registration/decorator code
cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py | head -150

Repository: NVIDIA/Model-Optimizer

Length of output: 6967


🏁 Script executed:

# Continue reading to find registration/decorator code
cat -n modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py | tail -100

Repository: NVIDIA/Model-Optimizer

Length of output: 4972


🏁 Script executed:

# Double-check the signature one more time to ensure accuracy
python3 << 'EOF'
import inspect
sig = """def _diffusers_triton_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
) -> torch.Tensor:"""

# Extract parameter names
params = ['query', 'key', 'value', 'attn_mask', 'dropout_p', 'is_causal', 'scale', 'enable_gqa']
print("Function parameters:", params)
print("\nParameters with defaults that could be non-default:")
print("- attn_mask (default None, but can be passed as Tensor)")
print("- dropout_p (default 0.0, but can be non-zero)")
print("- enable_gqa (default False, but can be True)")
EOF

Repository: NVIDIA/Model-Optimizer

Length of output: 376


Remove unsupported parameters from the function signature or implement their semantics.

This backend is registered with attn_mask, dropout_p, and enable_gqa in its _supported_arg_names (lines 237–239 derive supported args from the function signature), which means diffusers will pass these arguments to the implementation. However, lines 123–212 silently ignore all three, computing plain attention regardless. This breaks correctness if a caller provides a non-None attention mask, non-zero dropout, or enables GQA.

Either (a minimal sketch of a fallback variant follows this list):

  • Remove these parameters from the signature and handle the filtered-out case explicitly, or
  • Implement the masking, dropout, and GQA semantics in the Triton path.
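
A minimal sketch of the fallback variant: keep the signature but defer to eager SDPA whenever the unimplemented semantics are requested (names as in the snippet above; the enable_gqa pass-through assumes a recent PyTorch):

import torch.nn.functional as F

# Inside _diffusers_triton_attention, before building the Triton kwargs:
if attn_mask is not None or dropout_p > 0.0 or enable_gqa:
    # Diffusers layout is [B, S, H, D]; SDPA expects [B, H, S, D].
    out = F.scaled_dot_product_attention(
        query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2),
        attn_mask=attn_mask, dropout_p=dropout_p,
        is_causal=is_causal, scale=scale, enable_gqa=enable_gqa,
    )
    return out.transpose(1, 2)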
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@modelopt/torch/sparsity/attention_sparsity/kernels/diffusers_triton_attention.py`
around lines 123 - 213, The function _diffusers_triton_attention currently
accepts attn_mask, dropout_p, and enable_gqa but ignores them; either remove
these params from the signature and ensure the backend's _supported_arg_names
(where supported args are derived) no longer lists them, or implement their
semantics: detect attn_mask != None and convert it into the Triton/kernel
expected mask metadata (or pass it via kw as "attn_mask" / appropriate key),
handle dropout_p > 0 by passing a "dropout_p" kw and ensuring training mode
semantics are respected, and honor enable_gqa by adjusting q/k/v shapes/scale
(grouped-query attention layout changes) and passing an "enable_gqa" or
equivalent flag into kw; update callers/registration so _supported_arg_names
matches the final signature/behavior.

Comment on lines +164 to +168
def __call__(self, q, k, v, heads, mask=None):
    active, threshold, _scale_factor = _get_ltx_triton_context()
    if active:
        return _ltx_triton_attention(q, k, v, heads, mask, threshold)
    return self._original_fn(q, k, v, heads, mask)
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

In ltx_core's Attention.attention_function API, can the mask argument be non-None during normal inference or training, and what semantics must a custom backend preserve for it?

💡 Result:

Yes, the mask argument in ltx_core's Attention.attention_function API can be non-None during both normal inference and training. The API signature for AttentionCallable.call is def call(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask: torch.Tensor | None = None) -> torch.Tensor:. The Attention module's forward method passes a potentially non-None mask from the Modality.attention_mask, which is described as "Optional 2-D self-attention mask, shape (B, T, T). Values in [0, 1] where 1 = full attention and 0 = no attention. None means unrestricted (full) attention between all tokens. Built incrementally by conditioning items." Custom backends must preserve the semantics of applying the mask as an additive attention bias in scaled dot-product attention, matching PyTorch's torch.nn.functional.scaled_dot_product_attention behavior where attn_mask=mask is passed with dropout_p=0.0 and is_causal=False. The PytorchAttention implementation reshapes the mask if ndim==2 by unsqueezing to add batch dim, then passes directly. XFormers and FlashAttention implementations also accept and use the mask parameter. During inference, dropout_p=0.0 is used (no dropout), and the same API applies to training where masks from conditioning (e.g., attention_strength_wrapper) can be non-None. Perturbations skip entire attention computations but do not alter mask handling when attention runs.



Fall back when mask is provided until Triton path supports masking.

The _ltx_triton_attention() function ignores the mask parameter and always sets is_causal=False. Since ltx_core passes non-None masks during normal inference and training, and custom backends must preserve mask semantics as an additive attention bias, this wrapper changes attention behavior for any masked case. Only take the Triton path when mask is None; otherwise, defer to self._original_fn(...).
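
Concretely, the guarded wrapper would look like (a sketch based on the snippet above):

def __call__(self, q, k, v, heads, mask=None):
    active, threshold, _scale_factor = _get_ltx_triton_context()
    if active and mask is None:  # Triton path only for unmasked attention
        return _ltx_triton_attention(q, k, v, heads, mask, threshold)
    return self._original_fn(q, k, v, heads, mask)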

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py`
around lines 164 - 168, The current __call__ wrapper incorrectly routes masked
attention through _ltx_triton_attention which ignores mask and forces
is_causal=False; change the conditional to only use the Triton path when active
and mask is None (e.g., if active and mask is None: return
_ltx_triton_attention(...)); otherwise always call and return
self._original_fn(q, k, v, heads, mask) so mask semantics are preserved. Ensure
you reference _get_ltx_triton_context, _ltx_triton_attention, and
self._original_fn when making the change.

Comment on lines +255 to +269
def get_threshold_info(self) -> dict:
    """Get threshold information for debugging/display."""
    scale_factor = self._get_scale_factor()
    if scale_factor is not None:
        return {
            "type": "dynamic_calibrated",
            "formula": "threshold = scale_factor / seq_k (computed at runtime)",
            "scale_factor": scale_factor,
            "calibration_params": self.calibration_params,
            "target_sparse_ratio": self.target_sparse_ratio,
        }
    return {
        "type": "static",
        "value": self.skip_softmax_threshold,
    }
Contributor


⚠️ Potential issue | 🟡 Minor

Report raw-threshold mode in get_threshold_info().

_triton_inference_context() gives skip_softmax_raw_threshold highest priority, but get_threshold_info() still reports the static lambda threshold. In raw-threshold runs the printed summary is therefore misleading even though the kernel is using a different value.

🔧 Suggested fix
     def get_threshold_info(self) -> dict:
         """Get threshold information for debugging/display."""
+        if self.skip_softmax_raw_threshold is not None:
+            return {
+                "type": "raw",
+                "value": self.skip_softmax_raw_threshold,
+            }
         scale_factor = self._get_scale_factor()
         if scale_factor is not None:
             return {
                 "type": "dynamic_calibrated",
                 "formula": "threshold = scale_factor / seq_k (computed at runtime)",
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py`
around lines 255 - 269, get_threshold_info() currently always reports the static
lambda threshold even when _triton_inference_context() supplies a
skip_softmax_raw_threshold used by the kernel; update get_threshold_info to
check the triton inference context (via _triton_inference_context()) and, if a
skip_softmax_raw_threshold is present, return a "raw" type with that raw
threshold value (and note any related keys like skip_softmax_threshold or
target_sparse_ratio as applicable); otherwise preserve the existing
dynamic_calibrated/static return paths so the reported info matches the actual
value used by the kernel.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (2)
examples/diffusers/sparsity/wan22_skip_softmax.py (1)

62-63: Lazy-load export_hf_checkpoint only when --export-dir is used.

Line [62] makes export dependencies mandatory even for regular inference runs. Move this import into the export branch.

♻️ Suggested refactor
 import modelopt.torch.sparsity.attention_sparsity as mtsa
-from modelopt.torch.export import export_hf_checkpoint
 from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule
@@
     # ---- Export (optional) ----
     if args.export_dir and not args.baseline:
+        from modelopt.torch.export import export_hf_checkpoint
+
         print(f"Exporting sparsified checkpoint to {args.export_dir}...")
         export_hf_checkpoint(pipe, export_dir=args.export_dir)

As per coding guidelines: "Avoid hard imports of optional dependencies at module level; features should be gated by install extras ([onnx], [hf], [all]) and loaded lazily via import_plugin()."

Also applies to: 456-460

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/sparsity/wan22_skip_softmax.py` around lines 62 - 63, The
top-level import export_hf_checkpoint makes optional HF/ONNX dependencies
mandatory; remove the module-level import of export_hf_checkpoint and instead
lazily import it inside the branch that handles exporting (the code path that
checks args.export_dir or the export branch around where lines 456-460 run),
e.g., perform a local import of export_hf_checkpoint (or use import_plugin())
just before calling it in the export branch and also remove any other hard
imports of HF-export helpers at module scope (referencing export_hf_checkpoint
and the export-related code near lines 456-460).
modelopt/torch/export/unified_export_hf.py (1)

31-31: Make YAML dependency lazy and prefer safe_dump.

Line [31] hard-imports YAML for all imports of this module, but YAML is only needed on sparse-export paths. Also, use safe_dump in both write sites for predictable output.

♻️ Suggested refactor
-import yaml
@@
     if pipeline_sparse_configs:
+        import yaml
+
         yaml_path = export_dir / "sparse.yaml"
         with open(yaml_path, "w") as file:
-            yaml.dump(pipeline_sparse_configs, file, default_flow_style=False, sort_keys=False)
+            yaml.safe_dump(pipeline_sparse_configs, file, default_flow_style=False, sort_keys=False)
@@
-                # Also save as standalone YAML for easy inspection and reuse
-                import yaml
-
+                # Also save as standalone YAML for easy inspection and reuse
+                import yaml
                 yaml_path = Path(export_dir) / "sparse.yaml"
                 with open(yaml_path, "w") as file:
-                    yaml.dump(sparse_attn_config, file, default_flow_style=False, sort_keys=False)
+                    yaml.safe_dump(sparse_attn_config, file, default_flow_style=False, sort_keys=False)

As per coding guidelines: "Avoid hard imports of optional dependencies at module level; features should be gated by install extras ([onnx], [hf], [all]) and loaded lazily via import_plugin()."

Also applies to: 1037-1040, 1261-1267

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_hf.py` at line 31, Remove the
module-level "import yaml" and instead lazily import YAML only where it’s needed
using import_plugin() inside the sparse-export code paths, e.g. at the start of
the functions/branches that emit YAML and assign the result to a local yaml
variable; also replace any uses of yaml.dump with yaml.safe_dump at both YAML
write sites in this module (the two sparse-export write branches) so the
optional dependency is only loaded on demand and safer serialization is used.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py`:
- Around line 349-351: The code currently assumes result contains "log_a" when
result.get("fit_logspace") is true, which can raise a KeyError; update the block
that sets params["log_a"] and params["fit_logspace"] to first confirm "log_a"
exists (e.g., if "fit_logspace" in result and "log_a" in result) or use
result.get("log_a") and only assign params["log_a"] when that value is not None,
and still set params["fit_logspace"]=True when fitting was attempted; locate
this logic around the params/result handling in calibrate.py (the result dict,
params dict, and the "fit_logspace"/"log_a" keys) and add the guard so
incomplete calibration output cannot break the flow.

In `@modelopt/torch/sparsity/attention_sparsity/conversion.py`:
- Around line 439-441: The early-return drops per-layer disable metadata:
instead of returning None when calibration_params and raw_threshold are both
None, check if disabled_layer_names (or the variable storing disabled layers) is
non-empty and, if so, return an export payload that contains the
disabled_layer_names metadata (even if calibration_params and raw_threshold are
absent); otherwise keep returning None. Update the conditional around
calibration_params and raw_threshold to preserve/emit disabled_layer_names in
the exported structure so disabled-layer-only exports are retained.

---

Nitpick comments:
In `@examples/diffusers/sparsity/wan22_skip_softmax.py`:
- Around line 62-63: The top-level import export_hf_checkpoint makes optional
HF/ONNX dependencies mandatory; remove the module-level import of
export_hf_checkpoint and instead lazily import it inside the branch that handles
exporting (the code path that checks args.export_dir or the export branch around
where lines 456-460 run), e.g., perform a local import of export_hf_checkpoint
(or use import_plugin()) just before calling it in the export branch and also
remove any other hard imports of HF-export helpers at module scope (referencing
export_hf_checkpoint and the export-related code near lines 456-460).

In `@modelopt/torch/export/unified_export_hf.py`:
- Line 31: Remove the module-level "import yaml" and instead lazily import YAML
only where it’s needed using import_plugin() inside the sparse-export code
paths, e.g. at the start of the functions/branches that emit YAML and assign the
result to a local yaml variable; also replace any uses of yaml.dump with
yaml.safe_dump at both YAML write sites in this module (the two sparse-export
write branches) so the optional dependency is only loaded on demand and safer
serialization is used.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 9abb7cd1-0818-4739-a05f-4d886d9b4a0b

📥 Commits

Reviewing files that changed from the base of the PR and between b77d098 and 65f380d.

📒 Files selected for processing (5)
  • CHANGELOG.rst
  • examples/diffusers/sparsity/wan22_skip_softmax.py
  • modelopt/torch/export/unified_export_hf.py
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py
  • modelopt/torch/sparsity/attention_sparsity/conversion.py
✅ Files skipped from review due to trivial changes (1)
  • CHANGELOG.rst

Comment on lines +349 to +351
if result.get("fit_logspace"):
    params["log_a"] = result["log_a"]
    params["fit_logspace"] = True
Contributor


⚠️ Potential issue | 🟡 Minor

Guard log_a access when fit_logspace is enabled.

Line [350] directly indexes result["log_a"]; if calibration output is incomplete, this throws and drops the whole calibration flow.

🔧 Suggested fix
             if result.get("fit_logspace"):
-                params["log_a"] = result["log_a"]
-                params["fit_logspace"] = True
+                if "log_a" not in result:
+                    warnings.warn(f"{phase} calibration marked fit_logspace=True but missing log_a")
+                else:
+                    params["log_a"] = result["log_a"]
+                    params["fit_logspace"] = True
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` around
lines 349 - 351, The code currently assumes result contains "log_a" when
result.get("fit_logspace") is true, which can raise a KeyError; update the block
that sets params["log_a"] and params["fit_logspace"] to first confirm "log_a"
exists (e.g., if "fit_logspace" in result and "log_a" in result) or use
result.get("log_a") and only assign params["log_a"] when that value is not None,
and still set params["fit_logspace"]=True when fitting was attempted; locate
this logic around the params/result handling in calibrate.py (the result dict,
params dict, and the "fit_logspace"/"log_a" keys) and add the guard so
incomplete calibration output cannot break the flow.

Comment on lines +439 to 441
# Nothing exportable if no calibration params and no raw threshold
if calibration_params is None and raw_threshold is None:
    return None
Contributor


⚠️ Potential issue | 🟡 Minor

Don’t drop disabled_layers-only exports.

Lines [439]-[441] return None even when disabled_layer_names was populated, so per-layer disable metadata is lost for non-calibrated/static-threshold runs.

🔧 Suggested fix
-    if calibration_params is None and raw_threshold is None:
+    if calibration_params is None and raw_threshold is None and not disabled_layer_names:
         return None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/conversion.py` around lines 439 -
441, The early-return drops per-layer disable metadata: instead of returning
None when calibration_params and raw_threshold are both None, check if
disabled_layer_names (or the variable storing disabled layers) is non-empty and,
if so, return an export payload that contains the disabled_layer_names metadata
(even if calibration_params and raw_threshold are absent); otherwise keep
returning None. Update the conditional around calibration_params and
raw_threshold to preserve/emit disabled_layer_names in the exported structure so
disabled-layer-only exports are retained.

Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
modelopt/torch/export/unified_export_hf.py (1)

1045-1046: ⚡ Quick win

Use yaml.safe_dump() for exported sparse artifacts to ensure downstream yaml.safe_load() compatibility.

Both YAML writes at lines 1046 and 1286 currently use yaml.dump(...). Prefer yaml.safe_dump(...) to ensure emitted files stay in the safe YAML subset and remain compatible with downstream consumers using yaml.safe_load() workflows.

Suggested changes
-            yaml.dump(pipeline_sparse_configs, file, default_flow_style=False, sort_keys=False)
+            yaml.safe_dump(pipeline_sparse_configs, file, default_flow_style=False, sort_keys=False)
-                    yaml.dump(sparse_attn_config, file, default_flow_style=False, sort_keys=False)
+                    yaml.safe_dump(sparse_attn_config, file, default_flow_style=False, sort_keys=False)

The data structures contain only basic Python types (dicts, lists, strings, numbers) and are fully compatible with yaml.safe_dump(), which supports the same parameters.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/export/unified_export_hf.py` around lines 1045 - 1046, Replace
unsafe yaml.dump calls with yaml.safe_dump when writing exported sparse
artifacts: change the write that opens yaml_path and dumps
pipeline_sparse_configs to use yaml.safe_dump(..., default_flow_style=False,
sort_keys=False) and do the same for the other yaml.dump call in this module
that emits the sparse artifact (the second YAML write around the later export
step). Keep the same parameters (default_flow_style=False, sort_keys=False) so
output stays identical but safe_load-compatible.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@modelopt/torch/export/unified_export_hf.py`:
- Around line 1045-1046: Replace unsafe yaml.dump calls with yaml.safe_dump when
writing exported sparse artifacts: change the write that opens yaml_path and
dumps pipeline_sparse_configs to use yaml.safe_dump(...,
default_flow_style=False, sort_keys=False) and do the same for the other
yaml.dump call in this module that emits the sparse artifact (the second YAML
write around the later export step). Keep the same parameters
(default_flow_style=False, sort_keys=False) so output stays identical but
safe_load-compatible.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: b3ad4ebe-4472-4ec2-a748-bd5aa4a8e8e1

📥 Commits

Reviewing files that changed from the base of the PR and between 65f380d and 5eb9352.

📒 Files selected for processing (2)
  • CHANGELOG.rst
  • modelopt/torch/export/unified_export_hf.py
✅ Files skipped from review due to trivial changes (1)
  • CHANGELOG.rst
