Conversation
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
📝 WalkthroughWalkthroughAdds VSA (vendor-specific) sparse-attention support alongside Skip-Softmax: new VSA method behavior, LTX‑2 and Wan‑2.2 model plugins, plugin-system import/registration changes, example script/README updates, and unit tests exercising the new plugins and wiring. ChangesVSA feature, plugins, wiring, examples, and tests
Sequence Diagram(s)sequenceDiagram
autonumber
participant Transformer as Transformer (root model)
participant Hook as Pre-hook (wan22/ltx2)
participant SparseModule as SparseAttentionModule
participant VSA as VSA.method (forward_attention)
Transformer->>Hook: forward (hidden_states / Modality.positions)
Hook-->>Transformer: compute & store _vsa_video_shape
Hook->>SparseModule: set method.video_shape (if VSA & auto)
SparseModule->>VSA: forward_attention(q,k,v, video_shape, gate_compress)
VSA->>SparseModule: sparse_output (+ stats)
SparseModule->>Transformer: apply to_out → return output
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes 🚥 Pre-merge checks | ✅ 5 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Tip 💬 Introducing Slack Agent: The best way for teams to turn conversations into code.Slack Agent is built on CodeRabbit's deep understanding of your code, so your team can collaborate across the entire SDLC without losing context.
Built for teams:
One agent for your entire SDLC. Right inside Slack. Comment |
|
There was a problem hiding this comment.
Actionable comments posted: 6
🧹 Nitpick comments (3)
examples/diffusers/sparsity/ltx2_vsa.py (1)
176-182: Consider including the original exception as the cause.The re-raised
ImportErrorincludes the original error message as a string, but chaining withfrom _LTX_IMPORT_ERRORwould preserve the full traceback for debugging.Proposed fix
if not _LTX_AVAILABLE: raise ImportError( "LTX-2 packages are required for this example. Install with: " - "pip install ltx-core ltx-trainer ltx-pipelines. " - f"(original error: {_LTX_IMPORT_ERROR})" - ) + "pip install ltx-core ltx-trainer ltx-pipelines." + ) from _LTX_IMPORT_ERROR🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/diffusers/sparsity/ltx2_vsa.py` around lines 176 - 182, The ImportError raised in main() discards the original exception traceback; modify the raise to chain the original exception by using "raise ImportError(...)" with "from _LTX_IMPORT_ERROR" so the original _LTX_IMPORT_ERROR is preserved in the traceback (referencing the main function and the _LTX_IMPORT_ERROR symbol).examples/diffusers/sparsity/wan22_sparse_attn.py (1)
263-268: Consider more robust error handling in_parse_int_triple.The function raises a generic
ValueErrorfor invalid input. For better UX, consider usingargparse.ArgumentTypeErrorwhen used as an argument type converter, or providing more specific error messages distinguishing between parse failures and validation failures.Proposed enhancement
def _parse_int_triple(spec: str) -> tuple[int, int, int]: """Parse 'T,H,W' into a triple of positive ints.""" + try: - parts = [int(p.strip()) for p in spec.split(",")] + parts = [int(p.strip()) for p in spec.split(",")] + except ValueError: + raise ValueError(f"expected 3 comma-separated integers T,H,W — got {spec!r}") if len(parts) != 3 or any(p <= 0 for p in parts): - raise ValueError(f"expected 3 positive integers T,H,W — got {spec!r}") + raise ValueError(f"expected 3 positive integers T,H,W — got {parts!r}") return (parts[0], parts[1], parts[2])🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/diffusers/sparsity/wan22_sparse_attn.py` around lines 263 - 268, The _parse_int_triple function currently raises a generic ValueError; change it to raise argparse.ArgumentTypeError so it can be used as an argparse type converter and improve error clarity by distinguishing parse failures from validation failures: catch exceptions from int(...) and raise ArgumentTypeError with a message like "invalid int triple: failed to parse T,H,W from '...'", and if parsing succeeds but length or positivity checks fail raise ArgumentTypeError with a message like "expected 3 positive integers T,H,W — got '...'". Update references to _parse_int_triple to import argparse if needed.modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py (1)
197-209: Consider lazy import at module level or caching the import.The
ltx_core.model.transformer.rope.apply_rotary_embimport inside_compute_qkvwill be executed on every forward pass. While Python caches imports, moving this to a module-level lazy import pattern (similar to_load_sparsity_helpersintriton_fa.py) would make the dependency check explicit and slightly more efficient.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py` around lines 197 - 209, The inline import of ltx_core.model.transformer.rope.apply_rotary_emb inside _compute_qkv causes repeated import attempts at every forward; instead implement a module-level lazy loader (e.g., a helper like _load_apply_rotary_emb) that on first call imports apply_rotary_emb, caches it in a module-level variable, and raises the same ModuleNotFoundError with the existing message if unavailable; then replace the local import with a call to that loader and call the cached apply_rotary_emb(query, pe, self.rope_type) / apply_rotary_emb(key, pe if k_pe is None else k_pe, self.rope_type). Ensure you reference rope_type and k_pe handling exactly as in the current code.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/kernels/__init__.py`:
- Line 16: Restore a compatibility shim in modelopt.torch.kernels.__init__.py
that re-exports the former public symbols (e.g., attention, attention_calibrate,
IS_AVAILABLE) from their new location modelopt.torch.kernels.common.attention
and emit a DeprecationWarning on import; specifically import those symbols from
modelopt.torch.kernels.common.attention, set them in the package namespace, and
call warnings.warn with a clear deprecation message indicating the new import
path so old code using modelopt.torch.kernels.attention continues to work while
notifying users to update.
In `@modelopt/torch/kernels/sparsity/attention/calibrate.py`:
- Around line 214-239: Reject non-positive threshold candidates before calling
math.log2: validate the caller-provided threshold_trials list (used to build
threshold_tensor) and raise a clear ValueError if any value <= 0; then proceed
to construct threshold_tensor (the current list comprehension using math.log2(t)
* sm_scale) knowing all values are > 0. Locate the logic around threshold_trials
and threshold_tensor in calibrate.py (symbols: threshold_trials,
threshold_tensor, math.log2) and add the check so the error message explicitly
states which input is invalid.
- Around line 157-166: The prog_idx flattening uses a per-program num_q_tiles
computed from tl.load(b_seq_len + 0) (sequence length of batch 0), causing
aliasing across batches; change num_q_tiles to the launch-wide Q-tile count (use
tl.num_programs(2) or equivalent launch dimension) so prog_idx = batch_idx *
num_heads * num_q_tiles + head_idx * num_q_tiles + tile_q uses the global tile
count; update the same computation in the other occurrence (the block around
lines 243-286) so both Per_program_totals / Per_program_skipped use the
launch-wide Q-tile stride.
In `@modelopt/torch/sparsity/attention_sparsity/methods/vsa.py`:
- Around line 289-303: The code now treats gate_compress=None as disabling the
compression branch (equivalent to gate_compress=0), but the forward_attention()
docstring/caller contract still says gate_compress=None means equal weighting
(0.5); update the forward_attention() docstring and any public API docs/tests to
state that gate_compress=None disables compression (i.e., treated as 0) and not
0.5, and adjust any callers/tests that rely on the old semantics to pass an
explicit 0.5 if they need equal weighting; reference the forward_attention
function and the gate_compress handling in VSA (the branch that creates
gate_tiled = torch.zeros(...) when gate_compress is None) when making these
changes.
In `@modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py`:
- Around line 32-42: The module currently imports .huggingface eagerly which can
raise import errors; change it to a soft-loaded plugin by wrapping that import
in the same import_plugin guard used for ltx2 and wan22 (i.e., use
import_plugin("huggingface") and then from . import huggingface) so the
huggingface integration is loaded lazily and won’t break importing the core
package; update the block containing import_plugin, ltx2, wan22 to include the
huggingface guarded import and remove the top-level from . import huggingface.
In `@modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py`:
- Around line 113-139: The hook currently always derives and pushes video_shape
into every VSA method; change it so _hook only auto-populates when the method
has no explicit shape: inside the loop over module.modules() for
SparseAttentionModule instances (and after filtering for method.name == "vsa"),
check whether the method already exposes an explicit shape (e.g.,
getattr(method, "video_shape", None) is not None or (callable(getattr(method,
"get_video_shape", None)) and method.get_video_shape() is not None)); only call
method.set_video_shape(video_shape) and set module._vsa_video_shape when no
explicit shape is present. Ensure you reference the symbols _hook,
SparseAttentionModule, _sparse_method_instance, method.name,
method.set_video_shape, and module._vsa_video_shape in your change.
---
Nitpick comments:
In `@examples/diffusers/sparsity/ltx2_vsa.py`:
- Around line 176-182: The ImportError raised in main() discards the original
exception traceback; modify the raise to chain the original exception by using
"raise ImportError(...)" with "from _LTX_IMPORT_ERROR" so the original
_LTX_IMPORT_ERROR is preserved in the traceback (referencing the main function
and the _LTX_IMPORT_ERROR symbol).
In `@examples/diffusers/sparsity/wan22_sparse_attn.py`:
- Around line 263-268: The _parse_int_triple function currently raises a generic
ValueError; change it to raise argparse.ArgumentTypeError so it can be used as
an argparse type converter and improve error clarity by distinguishing parse
failures from validation failures: catch exceptions from int(...) and raise
ArgumentTypeError with a message like "invalid int triple: failed to parse T,H,W
from '...'", and if parsing succeeds but length or positivity checks fail raise
ArgumentTypeError with a message like "expected 3 positive integers T,H,W — got
'...'". Update references to _parse_int_triple to import argparse if needed.
In `@modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py`:
- Around line 197-209: The inline import of
ltx_core.model.transformer.rope.apply_rotary_emb inside _compute_qkv causes
repeated import attempts at every forward; instead implement a module-level lazy
loader (e.g., a helper like _load_apply_rotary_emb) that on first call imports
apply_rotary_emb, caches it in a module-level variable, and raises the same
ModuleNotFoundError with the existing message if unavailable; then replace the
local import with a call to that loader and call the cached
apply_rotary_emb(query, pe, self.rope_type) / apply_rotary_emb(key, pe if k_pe
is None else k_pe, self.rope_type). Ensure you reference rope_type and k_pe
handling exactly as in the current code.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 360df248-bbf0-4613-bc64-b09b1f17f5a7
📒 Files selected for processing (69)
CHANGELOG.rstCLAUDE.mdexamples/deepseek/ptq.pyexamples/deepseek/quantize_to_nvfp4.pyexamples/diffusers/README.mdexamples/diffusers/sparsity/README.mdexamples/diffusers/sparsity/ltx2_vsa.pyexamples/diffusers/sparsity/wan22_sparse_attn.pymodelopt/torch/kernels/__init__.pymodelopt/torch/kernels/common/__init__.pymodelopt/torch/kernels/common/attention/__init__.pymodelopt/torch/kernels/common/attention/hf_triton_attention.pymodelopt/torch/kernels/common/attention/triton_fa.pymodelopt/torch/kernels/quantization/__init__.pymodelopt/torch/kernels/quantization/attention/__init__.pymodelopt/torch/kernels/quantization/conv/README.mdmodelopt/torch/kernels/quantization/conv/__init__.pymodelopt/torch/kernels/quantization/conv/bench_implicit_gemm.pymodelopt/torch/kernels/quantization/conv/implicit_gemm_binding.cppmodelopt/torch/kernels/quantization/conv/implicit_gemm_cuda.pymodelopt/torch/kernels/quantization/conv/implicit_gemm_kernel.cumodelopt/torch/kernels/quantization/gemm/__init__.pymodelopt/torch/kernels/quantization/gemm/fp4_kernel.pymodelopt/torch/kernels/quantization/gemm/fp4_kernel_hopper.pymodelopt/torch/kernels/quantization/gemm/fp8_kernel.pymodelopt/torch/kernels/quantization/gemm/gptq_fused_kernel.pymodelopt/torch/kernels/quantization/gemm/nvfp4_quant.pymodelopt/torch/kernels/quantization/gemm/tensor_quant.cppmodelopt/torch/kernels/quantization/gemm/tensor_quant.hmodelopt/torch/kernels/quantization/gemm/tensor_quant_gpu.cumodelopt/torch/kernels/quantization/gemm/tensor_quant_gpu_fp8.cumodelopt/torch/kernels/quantization/gemm/tensor_quant_mx.cumodelopt/torch/kernels/quantization/gemm/tensor_quant_mx.hmodelopt/torch/kernels/sparsity/__init__.pymodelopt/torch/kernels/sparsity/attention/__init__.pymodelopt/torch/kernels/sparsity/attention/calibrate.pymodelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.pymodelopt/torch/kernels/sparsity/attention/ltx_triton_attention.pymodelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.pymodelopt/torch/kernels/sparsity/gemm/__init__.pymodelopt/torch/quantization/extensions.pymodelopt/torch/quantization/nn/modules/quant_conv.pymodelopt/torch/quantization/plugins/huggingface.pymodelopt/torch/quantization/qtensor/nvfp4_tensor.pymodelopt/torch/quantization/tensor_quant.pymodelopt/torch/quantization/utils/calib_utils.pymodelopt/torch/sparsity/attention_sparsity/conversion.pymodelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.pymodelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.pymodelopt/torch/sparsity/attention_sparsity/methods/vsa.pymodelopt/torch/sparsity/attention_sparsity/plugins/__init__.pymodelopt/torch/sparsity/attention_sparsity/plugins/huggingface.pymodelopt/torch/sparsity/attention_sparsity/plugins/ltx2.pymodelopt/torch/sparsity/attention_sparsity/plugins/wan22.pypyproject.tomltests/gpu/torch/kernels/common/attention/test_triton_fa.pytests/gpu/torch/kernels/conftest.pytests/gpu/torch/kernels/quantization/conv/test_implicit_gemm.pytests/gpu/torch/kernels/sparsity/attention/test_diffusers_triton_attention.pytests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.pytests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.pytests/gpu/torch/kernels/sparsity/attention/test_triton_fa_sparse_nm.pytests/gpu/torch/quantization/conftest.pytests/gpu/torch/quantization/test_tensor_quant_cuda.pytests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.pytests/unit/torch/kernels/common/attention/test_triton_fa.pytests/unit/torch/kernels/sparsity/attention/test_kernel_backends.pytests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.pytests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
💤 Files with no reviewable changes (1)
- modelopt/torch/kernels/quantization/gemm/init.py
| "attention_calibrate", | ||
| "register_triton_attention", | ||
| ] | ||
| """ModelOpt kernel library: common, quantization (conv, gemm), sparsity (attention, gemm).""" |
There was a problem hiding this comment.
Keep a compatibility shim for the old modelopt.torch.kernels imports.
Reducing this package to only a docstring drops previously-exported symbols like attention, attention_calibrate, and IS_AVAILABLE, so existing downstream imports break immediately on upgrade. Please re-export the moved symbols from modelopt.torch.kernels.common.attention and deprecate the old path instead of removing it outright.
Possible shim
"""ModelOpt kernel library: common, quantization (conv, gemm), sparsity (attention, gemm)."""
+
+from .common.attention import (
+ IS_AVAILABLE,
+ attention,
+ attention_calibrate,
+ register_triton_attention,
+)
+
+__all__ = [
+ "IS_AVAILABLE",
+ "attention",
+ "attention_calibrate",
+ "register_triton_attention",
+]📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| """ModelOpt kernel library: common, quantization (conv, gemm), sparsity (attention, gemm).""" | |
| """ModelOpt kernel library: common, quantization (conv, gemm), sparsity (attention, gemm).""" | |
| from .common.attention import ( | |
| IS_AVAILABLE, | |
| attention, | |
| attention_calibrate, | |
| register_triton_attention, | |
| ) | |
| __all__ = [ | |
| "IS_AVAILABLE", | |
| "attention", | |
| "attention_calibrate", | |
| "register_triton_attention", | |
| ] |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/kernels/__init__.py` at line 16, Restore a compatibility shim
in modelopt.torch.kernels.__init__.py that re-exports the former public symbols
(e.g., attention, attention_calibrate, IS_AVAILABLE) from their new location
modelopt.torch.kernels.common.attention and emit a DeprecationWarning on import;
specifically import those symbols from modelopt.torch.kernels.common.attention,
set them in the package namespace, and call warnings.warn with a clear
deprecation message indicating the new import path so old code using
modelopt.torch.kernels.attention continues to work while notifying users to
update.
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## main #1315 +/- ##
==========================================
- Coverage 76.90% 76.65% -0.25%
==========================================
Files 471 473 +2
Lines 50565 50831 +266
==========================================
+ Hits 38885 38965 +80
- Misses 11680 11866 +186
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
cjluo-nv
left a comment
There was a problem hiding this comment.
This PR adds VSA support for Wan 2.2 and LTX-2 models. The core logic looks correct and well-documented. The VSA gate_compress=None fix is important and well-reasoned. The plugin architecture follows good patterns (idempotent registration, graceful fallbacks, class-name-based detection).
Key observations:
- Copyright year: Both new files (
ltx2.py,wan22.py) use "Copyright (c) 2024" but should be 2025 for new files. - Missing plugin-specific tests: While
test_vsa.pycovers the core VSA method, config validation, and integration, there are no unit tests for the wan22 or ltx2 plugins specifically (hook installation, video_shape extraction, idempotency guards). The PR description mentions a "Wan 2.2 plugin hook test" but it doesn't appear in the test files. - The LTX-2 plugin is included without an example: This is explicitly called out in the PR description as intentional (pending license/training loop), which is fine, but makes the lack of tests more concerning since there's no way to validate it even manually.
- Size: ~1052 lines is at the boundary but the changes are cohesive.
The code quality is high overall — good docstrings, defensive fallbacks, and proper error messages. The CUSTOM_MODEL_PLUGINS list→set migration is a nice improvement.
|
@kaix-nv could you take a look? |
kevalmorabia97
left a comment
There was a problem hiding this comment.
LGTM from LTX license notice point of view
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
cjluo-nv
left a comment
There was a problem hiding this comment.
Bot review — DM the bot to share feedback.
Re-review summary: All critical issues from the previous review have been addressed:
- ✅ calibrate.py prog_idx flattening — Fixed; now uses
tl.num_programs(2)instead oftl.load(b_seq_len + 0). - ✅ vsa.py gate_compress=None docstring — Updated to document the new "disabled/zero" semantics.
- ✅ wan22.py hook unconditionally overwriting video_shape — Now uses
_wan22_auto_video_shapemarker to preserve user-supplied shapes. - ✅ Missing plugin-specific tests — Comprehensive unit tests added for both plugins (
test_wan22_plugin.py,test_ltx2_plugin.py) covering detection, hook registration, idempotency, shape extraction, and edge cases.
Design: The PR extends the existing attention_sparsity plugin architecture with two new model-specific plugins (Wan 2.2, LTX-2) and a new VSA method. No new subsystem or framework is introduced — it plugs into the existing SparseAttentionMethod registry and CUSTOM_MODEL_PLUGINS set. This is the right approach.
Remaining minor items (not blocking, but worth owner attention):
- The
huggingfaceplugin is still eagerly imported inplugins/__init__.py(not guarded byimport_plugin), whileltx2andwan22are. This inconsistency was flagged in the prior review and is still present. Low risk since HF is effectively a hard dep, but worth aligning. - The threshold validation in
calibrate.py(rejectingt <= 0beforemath.log2) was suggested but not added. Minor defensive improvement. - Size is ~1500 lines but cohesive: two plugins, one method fix, one example refactor, two test files, and a README rewrite. Reasonable not to split.
Nudging for human sign-off since this is a large, multi-file change touching kernel code and adding new plugin integrations.
There was a problem hiding this comment.
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/unit/torch/sparsity/attention_sparsity/test_ltx2_plugin.py`:
- Around line 102-138: The tests call _extract_video_shape_hook with
args=(modality,) but the hook derives seq_len from a tensor arg, so update each
test (test_extracts_3d_positions, test_extracts_4d_positions_taking_start_coord,
test_skips_when_video_is_none, test_skips_when_positions_is_none,
test_skips_when_product_mismatches_seq_len) to pass the positions tensor as the
first positional arg (e.g. args=(modality.positions,)) or otherwise provide the
actual tensor argument the hook expects so seq_len is computed correctly;
additionally, fix test_skips_when_product_mismatches_seq_len to assert the guard
behavior by expecting no _vsa_video_shape attribute (use assert not
hasattr(model, "_vsa_video_shape")) since the product/seq_len mismatch should
skip setting it.
- Line 29: Remove the module-level test gate by deleting the
pytest.importorskip("transformers") call so the suite no longer gets skipped
when transformers isn't installed; keep the test support approach that uses
_make_named_module() and types.SimpleNamespace() to mock any transformer
behavior and ensure any optional transformers usage in the ltx2 plugin is loaded
lazily via import_plugin() rather than gating the whole file.
In `@tests/unit/torch/sparsity/attention_sparsity/test_wan22_plugin.py`:
- Line 28: Remove the module-level pytest.importorskip("transformers") call from
the test file so tests for the WAN22 plugin (which use only mocks and helpers)
are not skipped in environments without transformers; locate the importorskip
invocation and delete it, and if any future tests require transformers, wrap
only those specific test functions or methods with pytest.importorskip or an
appropriate marker instead of skipping the whole module.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 5f2df6e2-9216-4f5a-8cfe-16d9046738b5
📒 Files selected for processing (4)
modelopt/torch/sparsity/attention_sparsity/methods/vsa.pymodelopt/torch/sparsity/attention_sparsity/plugins/wan22.pytests/unit/torch/sparsity/attention_sparsity/test_ltx2_plugin.pytests/unit/torch/sparsity/attention_sparsity/test_wan22_plugin.py
🚧 Files skipped from review as they are similar to previous changes (2)
- modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py
- modelopt/torch/sparsity/attention_sparsity/methods/vsa.py
|
|
||
| import pytest | ||
|
|
||
| pytest.importorskip("transformers") |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Check whether this test file references transformers APIs beyond importorskip.
rg -n --type=py 'importorskip\("transformers"\)|\bimport transformers\b|\bfrom transformers\b' tests/unit/torch/sparsity/attention_sparsity/test_ltx2_plugin.py
# Check whether the ltx2 plugin itself has a transformers hard dependency.
rg -n --type=py '\btransformers\b' modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.pyRepository: NVIDIA/Model-Optimizer
Length of output: 105
🏁 Script executed:
# First, let's see the test file structure and what it actually tests
head -60 tests/unit/torch/sparsity/attention_sparsity/test_ltx2_plugin.py
# Also check the imports and setup of the ltx2 plugin
head -30 modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.pyRepository: NVIDIA/Model-Optimizer
Length of output: 3557
🏁 Script executed:
# Check all imports in ltx2.py to see if transformers is pulled in transitively
head -50 modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py
# Check a representative test method to confirm it uses mocks, not transformers APIs
sed -n '90,150p' tests/unit/torch/sparsity/attention_sparsity/test_ltx2_plugin.pyRepository: NVIDIA/Model-Optimizer
Length of output: 4721
Remove unnecessary module-level pytest.importorskip("transformers") gate.
Line 29 unnecessarily skips this entire CPU unit test suite. The tests use mock modules (via _make_named_module() and types.SimpleNamespace()) and do not invoke any transformers APIs. The ltx2 plugin itself has no transformers dependency. This skip prevents the test suite from running in environments without transformers installed, contrary to the coding guideline that optional dependencies should be "loaded lazily via import_plugin()" rather than gating entire test suites at module level. Remove line 29.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/unit/torch/sparsity/attention_sparsity/test_ltx2_plugin.py` at line 29,
Remove the module-level test gate by deleting the
pytest.importorskip("transformers") call so the suite no longer gets skipped
when transformers isn't installed; keep the test support approach that uses
_make_named_module() and types.SimpleNamespace() to mock any transformer
behavior and ensure any optional transformers usage in the ltx2 plugin is loaded
lazily via import_plugin() rather than gating the whole file.
| def test_extracts_3d_positions(self): | ||
| model = nn.Module() | ||
| modality = types.SimpleNamespace(positions=_build_positions(2, 3, 4)) | ||
| _extract_video_shape_hook(model, args=(modality,)) | ||
| assert model._vsa_video_shape == (2, 3, 4) | ||
|
|
||
| def test_extracts_4d_positions_taking_start_coord(self): | ||
| model = nn.Module() | ||
| # (B, 3, T, 2) — _extract_video_shape_hook drops the trailing dim. | ||
| positions_3d = _build_positions(2, 2, 2) # (1, 3, 8) | ||
| positions_4d = positions_3d.unsqueeze(-1).expand(-1, -1, -1, 2).contiguous() | ||
| modality = types.SimpleNamespace(positions=positions_4d) | ||
| _extract_video_shape_hook(model, args=(modality,)) | ||
| assert model._vsa_video_shape == (2, 2, 2) | ||
|
|
||
| def test_skips_when_video_is_none(self): | ||
| model = nn.Module() | ||
| _extract_video_shape_hook(model, args=(None,)) | ||
| assert not hasattr(model, "_vsa_video_shape") | ||
|
|
||
| def test_skips_when_positions_is_none(self): | ||
| model = nn.Module() | ||
| modality = types.SimpleNamespace(positions=None) | ||
| _extract_video_shape_hook(model, args=(modality,)) | ||
| assert not hasattr(model, "_vsa_video_shape") | ||
|
|
||
| def test_skips_when_product_mismatches_seq_len(self): | ||
| """Defensive guard: if unique-counts don't multiply to seq_len, bail.""" | ||
| model = nn.Module() | ||
| # Both T-dim and H-dim share the same single value, so unique counts | ||
| # collapse and the product no longer equals seq_len. | ||
| positions = torch.zeros(1, 3, 4, dtype=torch.long) | ||
| positions[0, 2] = torch.arange(4) # only W varies | ||
| modality = types.SimpleNamespace(positions=positions) | ||
| _extract_video_shape_hook(model, args=(modality,)) | ||
| # 1 * 1 * 4 == 4 still matches seq_len, so this should succeed. | ||
| assert model._vsa_video_shape == (1, 1, 4) |
There was a problem hiding this comment.
Align _extract_video_shape_hook tests with the hook’s seq-len guard.
On Lines 105/114/136, the hook is called with args=(modality,), but implementation resolves seq_len from tensor args before setting _vsa_video_shape. Also, test_skips_when_product_mismatches_seq_len (Line 128) currently asserts success on Line 138, which contradicts its intent.
🔧 Suggested test fixes
def test_extracts_3d_positions(self):
model = nn.Module()
modality = types.SimpleNamespace(positions=_build_positions(2, 3, 4))
- _extract_video_shape_hook(model, args=(modality,))
+ seq = torch.zeros(1, 24) # seq_len = T*H*W
+ _extract_video_shape_hook(model, args=(modality, seq))
assert model._vsa_video_shape == (2, 3, 4)
def test_extracts_4d_positions_taking_start_coord(self):
model = nn.Module()
# (B, 3, T, 2) — _extract_video_shape_hook drops the trailing dim.
positions_3d = _build_positions(2, 2, 2) # (1, 3, 8)
positions_4d = positions_3d.unsqueeze(-1).expand(-1, -1, -1, 2).contiguous()
modality = types.SimpleNamespace(positions=positions_4d)
- _extract_video_shape_hook(model, args=(modality,))
+ seq = torch.zeros(1, 8) # seq_len = T*H*W
+ _extract_video_shape_hook(model, args=(modality, seq))
assert model._vsa_video_shape == (2, 2, 2)
def test_skips_when_product_mismatches_seq_len(self):
"""Defensive guard: if unique-counts don't multiply to seq_len, bail."""
model = nn.Module()
# Both T-dim and H-dim share the same single value, so unique counts
# collapse and the product no longer equals seq_len.
positions = torch.zeros(1, 3, 4, dtype=torch.long)
positions[0, 2] = torch.arange(4) # only W varies
modality = types.SimpleNamespace(positions=positions)
- _extract_video_shape_hook(model, args=(modality,))
- # 1 * 1 * 4 == 4 still matches seq_len, so this should succeed.
- assert model._vsa_video_shape == (1, 1, 4)
+ seq = torch.zeros(1, 5) # mismatched seq_len
+ _extract_video_shape_hook(model, args=(modality, seq))
+ assert not hasattr(model, "_vsa_video_shape")🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/unit/torch/sparsity/attention_sparsity/test_ltx2_plugin.py` around
lines 102 - 138, The tests call _extract_video_shape_hook with args=(modality,)
but the hook derives seq_len from a tensor arg, so update each test
(test_extracts_3d_positions, test_extracts_4d_positions_taking_start_coord,
test_skips_when_video_is_none, test_skips_when_positions_is_none,
test_skips_when_product_mismatches_seq_len) to pass the positions tensor as the
first positional arg (e.g. args=(modality.positions,)) or otherwise provide the
actual tensor argument the hook expects so seq_len is computed correctly;
additionally, fix test_skips_when_product_mismatches_seq_len to assert the guard
behavior by expecting no _vsa_video_shape attribute (use assert not
hasattr(model, "_vsa_video_shape")) since the product/seq_len mismatch should
skip setting it.
|
|
||
| import pytest | ||
|
|
||
| pytest.importorskip("transformers") |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify whether this test file actually uses transformers symbols beyond importorskip.
rg -n --type=py 'importorskip\("transformers"\)|\bimport transformers\b|\bfrom transformers\b' tests/unit/torch/sparsity/attention_sparsity/test_wan22_plugin.py
# Verify whether the plugin under test hard-depends on transformers.
rg -n --type=py '\btransformers\b' modelopt/torch/sparsity/attention_sparsity/plugins/wan22.pyRepository: NVIDIA/Model-Optimizer
Length of output: 303
🏁 Script executed:
#!/bin/bash
# Check test file imports and structure around the skip statement
head -50 tests/unit/torch/sparsity/attention_sparsity/test_wan22_plugin.py
# Check if wan22.py imports transformers anywhere
grep -n "import.*transformers\|from.*transformers" modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py
# Find _find_wan22_transformers function definition
grep -n "_find_wan22_transformers" modelopt/torch/sparsity/attention_sparsity/plugins/wan22.pyRepository: NVIDIA/Model-Optimizer
Length of output: 1966
🏁 Script executed:
#!/bin/bash
# Check imports at the top of wan22.py
head -60 modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py | grep -E "^import|^from"
# Also check the full plugin file for transformers imports
rg -n "^import.*transformers|^from.*transformers" modelopt/torch/sparsity/attention_sparsity/plugins/wan22.pyRepository: NVIDIA/Model-Optimizer
Length of output: 191
🏁 Script executed:
#!/bin/bash
# Check if any test function actually uses transformers (not just variable name "transformers")
# Get the test function definitions
rg -n "^def test_" tests/unit/torch/sparsity/attention_sparsity/test_wan22_plugin.py
# Check if transformers module is referenced anywhere after the skip (not just variable "transformers")
rg -n "transformers\." tests/unit/torch/sparsity/attention_sparsity/test_wan22_plugin.pyRepository: NVIDIA/Model-Optimizer
Length of output: 48
🏁 Script executed:
#!/bin/bash
# Get file size first
wc -l tests/unit/torch/sparsity/attention_sparsity/test_wan22_plugin.py
# Read the full test file
cat -n tests/unit/torch/sparsity/attention_sparsity/test_wan22_plugin.pyRepository: NVIDIA/Model-Optimizer
Length of output: 10447
Remove the unnecessary module-level skip on transformers.
The test file's docstring explicitly states it tests "without requiring a diffusers/Wan install." The file uses only mock classes and test helpers—no transformers APIs are invoked. The plugin itself (wan22.py) has no transformers dependency. Line 28's pytest.importorskip("transformers") will silently skip all tests in CI environments lacking transformers, reducing coverage of core plugin functionality.
Remove the skip entirely. If specific tests in the future need transformers, gate only those individual test methods.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@tests/unit/torch/sparsity/attention_sparsity/test_wan22_plugin.py` at line
28, Remove the module-level pytest.importorskip("transformers") call from the
test file so tests for the WAN22 plugin (which use only mocks and helpers) are
not skipped in environments without transformers; locate the importorskip
invocation and delete it, and if any future tests require transformers, wrap
only those specific test functions or methods with pytest.importorskip or an
appropriate marker instead of skipping the whole module.
What does this PR do?
Type of change: new feature
Adds end-to-end Video Sparse Attention (VSA) inference support for Wan 2.2
(5B and 14B) under
modelopt.torch.sparsity.attention_sparsity(mtsa). Thecore VSA method landed in #1053 but the original LTX-2 plugin was dropped and
Wan 2.2 had none, so neither model was actually runnable with VSA. This PR
fills that gap and unifies the Wan 2.2 example around a single
--method {skip_softmax,vsa}entry point.Main changes:
plugins/wan22.py— forward pre-hook onWanTransformer3DModelthat reads
hidden_states.shape = (B, C, T, H, W), divides byconfig.patch_size, and propagates the post-patchify(T, H, W)to everySparseAttentionModuleviamethod.set_video_shape(). Wan usesF.scaled_dot_product_attention, so VSA's existing SDPA patch handles therest — no module subclass needed.
gate_compress=Nonefix (methods/vsa.py) — the fastvideo kernel'sdefault
compress_attn_weight=Nonereturnsout_c + out_s, which doublesthe attention signal on any model without a learned gate (e.g. Wan 2.2). VSA
now passes an explicit
gate=0tensor soout = 0 * out_c + out_s = out_s.Side effect:
top_k_ratio=1.0now cleanly degenerates to dense SDPA(modulo bf16 rounding).
plugins/__init__.py) —CUSTOM_MODEL_PLUGINSchangedfrom
listtosetso re-imports stay idempotent (matches quantization /peft convention). Wan 2.2 plugin registered via
import_pluginso a missingoptional dep never breaks the core sparse-attention API.
wan22_skip_softmax.py→wan22_sparse_attn.py) —single script with
--method {skip_softmax,vsa}plus VSA flags(
--top-k-ratio,--skip-first-last,--enable-vae-tiling). Skip-softmaxbehaviour and CLI are preserved.
examples/diffusers/sparsity/README.md) — methodcomparison table, VSA quick-start, config reference, and a dense-equivalence
sanity-check section with measured PSNR numbers on Wan 2.2 14B.
LTX-2 plugin (
plugins/ltx2.py) is included as well — it wraps LTX-2's nativeLTXSelfAttentionand callsVSA.forward_attentiondirectly, with azero-initialised trainable
to_gate_compress— but the LTX-2 example isnot in this PR (it depends on third-party
ltx_core/ltx_trainer/ltx_pipelinesunder the LTX Community License). Example will land separatelyonce the training loop and license plumbing are finalised.
Usage
Or the built-in default via the example script:
Testing
Unit tests —
conda run -n modelopt python -m pytest tests/unit/torch/sparsity/attention_sparsity/→ 149 passed (sparse-attentionconversion, kernel backends, registry).
Wan 2.2 plugin hook test — end-to-end check that
video_shapeiscorrectly derived from
hidden_states.shape / patch_sizeand propagatedto every VSA method instance before the SDPA patch fires.
Dense-equivalence sanity check on Wan 2.2 14B (720×1280 / 81 frames
/ 40 steps), first-frame PSNR vs dense baseline:
top_k_ratio=1.0top_k_ratio=0.5The ~24 dB drop at
top_k_ratio=1.0is error accumulation over 6400attention calls through the denoising loop; single-call PSNR vs dense
SDPA on random inputs is ~50 dB, confirming the dense-equivalence property
at the kernel level.
No regression on skip-softmax — existing Wan 2.2 skip-softmax flows
(raw threshold, calibration, dense Triton baseline) verified through the
renamed
wan22_sparse_attn.pyscript.Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (
git commit -s -S).Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded
trust_remote_code=True,torch.load(..., weights_only=False),pickle, etc.).CONTRIBUTING.md: ✅Additional Information
Summary by CodeRabbit
New Features
Documentation
Tests