
VSA support for Wan 2.2 and LTX2#1315

Open
jingyu-ml wants to merge 15 commits into main from jingyux/vsa-diffusion

Conversation

Contributor

@jingyu-ml jingyu-ml commented Apr 22, 2026

What does this PR do?

Type of change: new feature

Adds end-to-end Video Sparse Attention (VSA) inference support for Wan 2.2
(5B and 14B) under modelopt.torch.sparsity.attention_sparsity (mtsa). The
core VSA method landed in #1053 but the original LTX-2 plugin was dropped and
Wan 2.2 had none, so neither model was actually runnable with VSA. This PR
fills that gap and unifies the Wan 2.2 example around a single
--method {skip_softmax,vsa} entry point.

Main changes:

  • New plugin plugins/wan22.py — forward pre-hook on WanTransformer3DModel
    that reads hidden_states.shape = (B, C, T, H, W), divides by
    config.patch_size, and propagates the post-patchify (T, H, W) to every
    SparseAttentionModule via method.set_video_shape(). Wan uses
    F.scaled_dot_product_attention, so VSA's existing SDPA patch handles the
    rest — no module subclass needed.
  • VSA gate_compress=None fix (methods/vsa.py) — the fastvideo kernel's
    default compress_attn_weight=None returns out_c + out_s, which doubles
    the attention signal on any model without a learned gate (e.g. Wan 2.2). VSA
    now passes an explicit gate=0 tensor so out = 0 * out_c + out_s = out_s.
    Side effect: top_k_ratio=1.0 now cleanly degenerates to dense SDPA
    (modulo bf16 rounding).
  • Plugin registry (plugins/__init__.py) — CUSTOM_MODEL_PLUGINS changed
    from list to set so re-imports stay idempotent (matches quantization /
    peft convention). Wan 2.2 plugin registered via import_plugin so a missing
    optional dep never breaks the core sparse-attention API.
  • Example unification (wan22_skip_softmax.py → wan22_sparse_attn.py) —
    single script with --method {skip_softmax,vsa} plus VSA flags
    (--top-k-ratio, --skip-first-last, --enable-vae-tiling). Skip-softmax
    behaviour and CLI are preserved.
  • README rewrite (examples/diffusers/sparsity/README.md) — method
    comparison table, VSA quick-start, config reference, and a dense-equivalence
    sanity-check section with measured PSNR numbers on Wan 2.2 14B.
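The shape handling in the Wan 2.2 pre-hook boils down to simple integer division. A minimal sketch of that arithmetic only — `derive_video_shape` is a hypothetical helper for illustration, not the plugin's actual function:

```python
def derive_video_shape(hidden_states_shape, patch_size):
    """Derive the post-patchify (T, H, W) grid from a 5D latent shape.

    hidden_states_shape: (B, C, T, H, W) shape of the latent tensor.
    patch_size: (pt, ph, pw) taken from the transformer config.
    """
    _, _, t, h, w = hidden_states_shape
    pt, ph, pw = patch_size
    if t % pt or h % ph or w % pw:
        raise ValueError(f"latent dims {(t, h, w)} not divisible by patch {patch_size}")
    # Each patch becomes one token, so the token grid is the latent grid
    # divided element-wise by the patch size.
    return (t // pt, h // ph, w // pw)
```

For example, a latent of shape (1, 16, 21, 90, 160) with patch_size (1, 2, 2) yields a (21, 45, 80) token grid, which is what each VSA method instance would receive via set_video_shape().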

LTX-2 plugin (plugins/ltx2.py) is included as well — it wraps LTX-2's native
LTXSelfAttention and calls VSA.forward_attention directly, with a
zero-initialised trainable to_gate_compress — but the LTX-2 example is
not in this PR (it depends on third-party ltx_core / ltx_trainer /
ltx_pipelines under the LTX Community License). Example will land separately
once the training loop and license plumbing are finalised.
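The gate fix described above reduces to simple arithmetic on the kernel's two output branches. A schematic scalar model of the combination rule — illustrative only, not the fastvideo kernel's actual code:

```python
def combine_outputs(out_c, out_s, gate=None):
    """Schematic combination of the compressed (out_c) and sparse (out_s) branches.

    When the kernel receives compress_attn_weight=None it effectively behaves
    as gate=1 and returns out_c + out_s, doubling the attention signal on any
    model without a learned gate. Passing an explicit gate=0 keeps only out_s.
    """
    g = 1.0 if gate is None else gate  # kernel default: the doubling behaviour
    return g * out_c + out_s
```

With gate=0 and top_k_ratio=1.0, out_s is itself the dense result, which is why the fix makes that setting degenerate to dense SDPA up to bf16 rounding.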

Usage

import torch
from diffusers import AutoencoderKLWan, WanPipeline
import modelopt.torch.sparsity.attention_sparsity as mtsa

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.2-T2V-A14B-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")

# VSA config: 50% top-K on self-attention, cross-attention left dense.
# ``video_shape`` is auto-derived by the Wan 2.2 plugin on each forward.
config = {
    "sparse_cfg": {
        "*.attn1*": {
            "method": "vsa",
            "block_size_3d": (4, 4, 4),
            "top_k_ratio": 0.5,
            "enable": True,
        },
        "*.attn2*": {"enable": False},
        "default": {"enable": False},
    },
}
pipe.transformer = mtsa.sparsify(pipe.transformer, config)

video = pipe(prompt="A cat playing piano", num_frames=81).frames[0]

Or the built-in default via the example script:

# VSA at 50% top-K (default block_size_3d=(4,4,4), self-attn only)
python examples/diffusers/sparsity/wan22_sparse_attn.py --method vsa \
    --top-k-ratio 0.5 \
    --prompt "A cat playing piano" --output vsa.mp4

# Skip-softmax (unchanged behaviour, still the default method)
python examples/diffusers/sparsity/wan22_sparse_attn.py \
    --raw-threshold -0.7 \
    --prompt "A cat playing piano" --output out.mp4

Testing

  • Unit tests — conda run -n modelopt python -m pytest tests/unit/torch/sparsity/attention_sparsity/ → 149 passed (sparse-attention
    conversion, kernel backends, registry).

  • Wan 2.2 plugin hook test — end-to-end check that video_shape is
    correctly derived from hidden_states.shape / patch_size and propagated
    to every VSA method instance before the SDPA patch fires.

  • Dense-equivalence sanity check on Wan 2.2 14B (720×1280 / 81 frames
    / 40 steps), first-frame PSNR vs dense baseline:

    Comparison                              PSNR
    baseline vs baseline w/ VAE tiling      40.5 dB
    baseline vs VSA top_k_ratio=1.0         23.9 dB
    baseline vs VSA top_k_ratio=0.5         13.1 dB

    The ~24 dB PSNR at top_k_ratio=1.0 reflects error accumulation over the
    6400 attention calls of the denoising loop; single-call PSNR vs dense
    SDPA on random inputs is ~50 dB, confirming the dense-equivalence
    property at the kernel level.
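    The PSNR figures above follow the standard peak signal-to-noise ratio definition. A minimal computation over flattened frame pixels, assuming an 8-bit peak of 255:

```python
import math

def psnr(a, b, peak=255.0):
    """Peak signal-to-noise ratio between two equal-length pixel sequences."""
    mse = sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)
    if mse == 0:
        return float("inf")  # identical frames: infinite PSNR
    return 10.0 * math.log10(peak * peak / mse)
```

    Values above ~40 dB (as in the VAE-tiling control) are typically visually indistinguishable; the single-call ~50 dB figure is well into that regime.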

  • No regression on skip-softmax — existing Wan 2.2 skip-softmax flows
    (raw threshold, calibration, dense Triton baseline) verified through the
    renamed wan22_sparse_attn.py script.

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ✅

Additional Information

Summary by CodeRabbit

  • New Features

    • Added VSA sparse-attention alongside Skip-Softmax with CLI method switching, VSA-specific options (block size, top-k ratio, video-shape) and an option to enable VAE tiling.
    • Automatic VSA support for Wan 2.2 and LTX‑2 models with inferred video-shape propagation.
  • Documentation

    • Rewrote sparsity README with quick-starts, method comparisons, model-to-script mapping, defaults and known issues.
  • Tests

    • Added unit tests covering LTX‑2 and Wan 2.2 VSA plugins and video-shape extraction/behavior.

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml requested review from a team as code owners April 22, 2026 07:34
@jingyu-ml jingyu-ml marked this pull request as draft April 22, 2026 07:34

copy-pr-bot Bot commented Apr 22, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@jingyu-ml jingyu-ml self-assigned this Apr 22, 2026
Contributor

coderabbitai Bot commented Apr 22, 2026

📝 Walkthrough


Adds VSA (Video Sparse Attention) support alongside Skip-Softmax: new VSA method behavior, LTX‑2 and Wan‑2.2 model plugins, plugin-system import/registration changes, example script/README updates, and unit tests exercising the new plugins and wiring.

Changes

VSA feature, plugins, wiring, examples, and tests

Layer / File(s) Summary
Method semantics
modelopt/torch/sparsity/attention_sparsity/methods/vsa.py
forward_attention treats gate_compress=None as disabled by creating a scalar zero tensor (so kernel never receives None); docstring updated to map None → gate=0.
Plugin registry / guarded imports
modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
CUSTOM_MODEL_PLUGINS changed from list → set to deduplicate callbacks; added guarded imports for ltx2 and wan22 via import_plugin(...), keeping huggingface unguarded.
HuggingFace plugin adjustment
modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py
Registration changed from append(...) to add(...) to match set semantics.
LTX‑2 plugin implementation
modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py
New plugin: detects LTX‑2 models/attention, registers _LTX2SparseAttention wrappers, installs model pre-hook to extract (T,H,W) from Modality.positions, projects Q/K/V, applies q/k norms and RoPE, optionally exposes trainable to_gate_compress, dispatches to VSA forward_attention, bubbles stats, and falls back to original forward on incompatibility. Exposes registration entry added to plugins set.
Wan‑2.2 plugin implementation
modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py
New plugin: finds WanTransformer3DModel instances, installs per-transformer forward pre-hooks (with_kwargs=True) to compute post-patchify video_shape from 5D hidden-states and config.patch_size, stamps _vsa_video_shape and calls method.set_video_shape(...) for VSA methods (tracking auto-set via _wan22_auto_video_shape); hook registration is idempotent. Registration added to plugins set.
Adapter wiring in example CLI
examples/diffusers/sparsity/wan22_sparse_attn.py
Script refactored to support --method (skip_softmax default, or vsa), split method-agnostic flags (--baseline, --triton-baseline) from method-specific flags, added VSA CLI/config options (--top-k-ratio, --block-size, --video-shape), new helpers _apply_skip_softmax and _apply_vsa, renamed/adjusted builders and runtime-sparsity printing logic, and added --enable-vae-tiling.
Documentation / examples
examples/diffusers/sparsity/README.md
README expanded from skip-softmax-only to document both Skip‑Softmax and VSA, method-switching guidance, model→script mapping (wan22_sparse_attn.py, ltx2_vsa.py), VSA quick-start commands and config parameters (block_size_3d, top_k_ratio, video_shape, enable), and updated defaults/known issues.
Unit tests: LTX‑2 plugin
tests/unit/torch/sparsity/attention_sparsity/test_ltx2_plugin.py
New CPU-only tests: detection by class-name and structural duck-typing, _extract_video_shape_hook behavior, and _LTX2SparseAttention._resolve_video_shape resolution and fallback cases (including dead weakref handling).
Unit tests: Wan‑2.2 plugin
tests/unit/torch/sparsity/attention_sparsity/test_wan22_plugin.py
New CPU-only tests: Wan‑2.2 detection, hook idempotency, patch-size parsing, hook behavior deriving/stamping _vsa_video_shape, set_video_shape calls for VSA methods, preservation of user-supplied shapes, and refresh across calls.
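The list → set registry change can be illustrated with a minimal sketch (names simplified; `register_plugin` is a hypothetical stand-in for the registration code in plugins/__init__.py):

```python
# A set keeps registration idempotent under module re-imports.
CUSTOM_MODEL_PLUGINS: set = set()

def register_plugin(callback):
    # set.add() is a no-op for an already-registered callback, so a
    # re-import that re-runs registration leaves exactly one copy.
    # A list.append() here would accumulate duplicates on every re-import.
    CUSTOM_MODEL_PLUGINS.add(callback)
    return callback
```

Because functions are hashable by identity, registering the same callback twice leaves the set unchanged, which is the idempotency property the PR relies on.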

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant Transformer as Transformer (root model)
    participant Hook as Pre-hook (wan22/ltx2)
    participant SparseModule as SparseAttentionModule
    participant VSA as VSA.method (forward_attention)
    Transformer->>Hook: forward (hidden_states / Modality.positions)
    Hook-->>Transformer: compute & store _vsa_video_shape
    Hook->>SparseModule: set method.video_shape (if VSA & auto)
    SparseModule->>VSA: forward_attention(q,k,v, video_shape, gate_compress)
    VSA->>SparseModule: sparse_output (+ stats)
    SparseModule->>Transformer: apply to_out → return output

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 62.96% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title 'VSA support for Wan 2.2 and LTX2' accurately and clearly summarizes the main objective of the pull request—adding Video Sparse Attention support for two specific models.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed No torch.load() calls found in Python files


Contributor

github-actions Bot commented Apr 22, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1315/

Built to branch gh-pages at 2026-05-04 21:51 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 6

🧹 Nitpick comments (3)
examples/diffusers/sparsity/ltx2_vsa.py (1)

176-182: Consider including the original exception as the cause.

The re-raised ImportError includes the original error message as a string, but chaining with from _LTX_IMPORT_ERROR would preserve the full traceback for debugging.

Proposed fix
     if not _LTX_AVAILABLE:
         raise ImportError(
             "LTX-2 packages are required for this example. Install with: "
-            "pip install ltx-core ltx-trainer ltx-pipelines. "
-            f"(original error: {_LTX_IMPORT_ERROR})"
-        )
+            "pip install ltx-core ltx-trainer ltx-pipelines."
+        ) from _LTX_IMPORT_ERROR
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/sparsity/ltx2_vsa.py` around lines 176 - 182, The
ImportError raised in main() discards the original exception traceback; modify
the raise to chain the original exception by using "raise ImportError(...)" with
"from _LTX_IMPORT_ERROR" so the original _LTX_IMPORT_ERROR is preserved in the
traceback (referencing the main function and the _LTX_IMPORT_ERROR symbol).
examples/diffusers/sparsity/wan22_sparse_attn.py (1)

263-268: Consider more robust error handling in _parse_int_triple.

The function raises a generic ValueError for invalid input. For better UX, consider using argparse.ArgumentTypeError when used as an argument type converter, or providing more specific error messages distinguishing between parse failures and validation failures.

Proposed enhancement
 def _parse_int_triple(spec: str) -> tuple[int, int, int]:
     """Parse 'T,H,W' into a triple of positive ints."""
+    try:
-    parts = [int(p.strip()) for p in spec.split(",")]
+        parts = [int(p.strip()) for p in spec.split(",")]
+    except ValueError:
+        raise ValueError(f"expected 3 comma-separated integers T,H,W — got {spec!r}")
     if len(parts) != 3 or any(p <= 0 for p in parts):
-        raise ValueError(f"expected 3 positive integers T,H,W — got {spec!r}")
+        raise ValueError(f"expected 3 positive integers T,H,W — got {parts!r}")
     return (parts[0], parts[1], parts[2])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/sparsity/wan22_sparse_attn.py` around lines 263 - 268, The
_parse_int_triple function currently raises a generic ValueError; change it to
raise argparse.ArgumentTypeError so it can be used as an argparse type converter
and improve error clarity by distinguishing parse failures from validation
failures: catch exceptions from int(...) and raise ArgumentTypeError with a
message like "invalid int triple: failed to parse T,H,W from '...'", and if
parsing succeeds but length or positivity checks fail raise ArgumentTypeError
with a message like "expected 3 positive integers T,H,W — got '...'". Update
references to _parse_int_triple to import argparse if needed.
modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py (1)

197-209: Consider lazy import at module level or caching the import.

The ltx_core.model.transformer.rope.apply_rotary_emb import inside _compute_qkv will be executed on every forward pass. While Python caches imports, moving this to a module-level lazy import pattern (similar to _load_sparsity_helpers in triton_fa.py) would make the dependency check explicit and slightly more efficient.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py` around lines 197
- 209, The inline import of ltx_core.model.transformer.rope.apply_rotary_emb
inside _compute_qkv causes repeated import attempts at every forward; instead
implement a module-level lazy loader (e.g., a helper like
_load_apply_rotary_emb) that on first call imports apply_rotary_emb, caches it
in a module-level variable, and raises the same ModuleNotFoundError with the
existing message if unavailable; then replace the local import with a call to
that loader and call the cached apply_rotary_emb(query, pe, self.rope_type) /
apply_rotary_emb(key, pe if k_pe is None else k_pe, self.rope_type). Ensure you
reference rope_type and k_pe handling exactly as in the current code.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/kernels/__init__.py`:
- Line 16: Restore a compatibility shim in modelopt.torch.kernels.__init__.py
that re-exports the former public symbols (e.g., attention, attention_calibrate,
IS_AVAILABLE) from their new location modelopt.torch.kernels.common.attention
and emit a DeprecationWarning on import; specifically import those symbols from
modelopt.torch.kernels.common.attention, set them in the package namespace, and
call warnings.warn with a clear deprecation message indicating the new import
path so old code using modelopt.torch.kernels.attention continues to work while
notifying users to update.

In `@modelopt/torch/kernels/sparsity/attention/calibrate.py`:
- Around line 214-239: Reject non-positive threshold candidates before calling
math.log2: validate the caller-provided threshold_trials list (used to build
threshold_tensor) and raise a clear ValueError if any value <= 0; then proceed
to construct threshold_tensor (the current list comprehension using math.log2(t)
* sm_scale) knowing all values are > 0. Locate the logic around threshold_trials
and threshold_tensor in calibrate.py (symbols: threshold_trials,
threshold_tensor, math.log2) and add the check so the error message explicitly
states which input is invalid.
- Around line 157-166: The prog_idx flattening uses a per-program num_q_tiles
computed from tl.load(b_seq_len + 0) (sequence length of batch 0), causing
aliasing across batches; change num_q_tiles to the launch-wide Q-tile count (use
tl.num_programs(2) or equivalent launch dimension) so prog_idx = batch_idx *
num_heads * num_q_tiles + head_idx * num_q_tiles + tile_q uses the global tile
count; update the same computation in the other occurrence (the block around
lines 243-286) so both Per_program_totals / Per_program_skipped use the
launch-wide Q-tile stride.

In `@modelopt/torch/sparsity/attention_sparsity/methods/vsa.py`:
- Around line 289-303: The code now treats gate_compress=None as disabling the
compression branch (equivalent to gate_compress=0), but the forward_attention()
docstring/caller contract still says gate_compress=None means equal weighting
(0.5); update the forward_attention() docstring and any public API docs/tests to
state that gate_compress=None disables compression (i.e., treated as 0) and not
0.5, and adjust any callers/tests that rely on the old semantics to pass an
explicit 0.5 if they need equal weighting; reference the forward_attention
function and the gate_compress handling in VSA (the branch that creates
gate_tiled = torch.zeros(...) when gate_compress is None) when making these
changes.

In `@modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py`:
- Around line 32-42: The module currently imports .huggingface eagerly which can
raise import errors; change it to a soft-loaded plugin by wrapping that import
in the same import_plugin guard used for ltx2 and wan22 (i.e., use
import_plugin("huggingface") and then from . import huggingface) so the
huggingface integration is loaded lazily and won’t break importing the core
package; update the block containing import_plugin, ltx2, wan22 to include the
huggingface guarded import and remove the top-level from . import huggingface.

In `@modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py`:
- Around line 113-139: The hook currently always derives and pushes video_shape
into every VSA method; change it so _hook only auto-populates when the method
has no explicit shape: inside the loop over module.modules() for
SparseAttentionModule instances (and after filtering for method.name == "vsa"),
check whether the method already exposes an explicit shape (e.g.,
getattr(method, "video_shape", None) is not None or (callable(getattr(method,
"get_video_shape", None)) and method.get_video_shape() is not None)); only call
method.set_video_shape(video_shape) and set module._vsa_video_shape when no
explicit shape is present. Ensure you reference the symbols _hook,
SparseAttentionModule, _sparse_method_instance, method.name,
method.set_video_shape, and module._vsa_video_shape in your change.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 360df248-bbf0-4613-bc64-b09b1f17f5a7

📥 Commits

Reviewing files that changed from the base of the PR and between 785d3a2 and a849d88.

📒 Files selected for processing (69)
  • CHANGELOG.rst
  • CLAUDE.md
  • examples/deepseek/ptq.py
  • examples/deepseek/quantize_to_nvfp4.py
  • examples/diffusers/README.md
  • examples/diffusers/sparsity/README.md
  • examples/diffusers/sparsity/ltx2_vsa.py
  • examples/diffusers/sparsity/wan22_sparse_attn.py
  • modelopt/torch/kernels/__init__.py
  • modelopt/torch/kernels/common/__init__.py
  • modelopt/torch/kernels/common/attention/__init__.py
  • modelopt/torch/kernels/common/attention/hf_triton_attention.py
  • modelopt/torch/kernels/common/attention/triton_fa.py
  • modelopt/torch/kernels/quantization/__init__.py
  • modelopt/torch/kernels/quantization/attention/__init__.py
  • modelopt/torch/kernels/quantization/conv/README.md
  • modelopt/torch/kernels/quantization/conv/__init__.py
  • modelopt/torch/kernels/quantization/conv/bench_implicit_gemm.py
  • modelopt/torch/kernels/quantization/conv/implicit_gemm_binding.cpp
  • modelopt/torch/kernels/quantization/conv/implicit_gemm_cuda.py
  • modelopt/torch/kernels/quantization/conv/implicit_gemm_kernel.cu
  • modelopt/torch/kernels/quantization/gemm/__init__.py
  • modelopt/torch/kernels/quantization/gemm/fp4_kernel.py
  • modelopt/torch/kernels/quantization/gemm/fp4_kernel_hopper.py
  • modelopt/torch/kernels/quantization/gemm/fp8_kernel.py
  • modelopt/torch/kernels/quantization/gemm/gptq_fused_kernel.py
  • modelopt/torch/kernels/quantization/gemm/nvfp4_quant.py
  • modelopt/torch/kernels/quantization/gemm/tensor_quant.cpp
  • modelopt/torch/kernels/quantization/gemm/tensor_quant.h
  • modelopt/torch/kernels/quantization/gemm/tensor_quant_gpu.cu
  • modelopt/torch/kernels/quantization/gemm/tensor_quant_gpu_fp8.cu
  • modelopt/torch/kernels/quantization/gemm/tensor_quant_mx.cu
  • modelopt/torch/kernels/quantization/gemm/tensor_quant_mx.h
  • modelopt/torch/kernels/sparsity/__init__.py
  • modelopt/torch/kernels/sparsity/attention/__init__.py
  • modelopt/torch/kernels/sparsity/attention/calibrate.py
  • modelopt/torch/kernels/sparsity/attention/diffusers_triton_attention.py
  • modelopt/torch/kernels/sparsity/attention/ltx_triton_attention.py
  • modelopt/torch/kernels/sparsity/attention/skip_softmax_helpers.py
  • modelopt/torch/kernels/sparsity/gemm/__init__.py
  • modelopt/torch/quantization/extensions.py
  • modelopt/torch/quantization/nn/modules/quant_conv.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • modelopt/torch/quantization/qtensor/nvfp4_tensor.py
  • modelopt/torch/quantization/tensor_quant.py
  • modelopt/torch/quantization/utils/calib_utils.py
  • modelopt/torch/sparsity/attention_sparsity/conversion.py
  • modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/methods/triton_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/methods/vsa.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py
  • pyproject.toml
  • tests/gpu/torch/kernels/common/attention/test_triton_fa.py
  • tests/gpu/torch/kernels/conftest.py
  • tests/gpu/torch/kernels/quantization/conv/test_implicit_gemm.py
  • tests/gpu/torch/kernels/sparsity/attention/test_diffusers_triton_attention.py
  • tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_calibrate.py
  • tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_skip_softmax.py
  • tests/gpu/torch/kernels/sparsity/attention/test_triton_fa_sparse_nm.py
  • tests/gpu/torch/quantization/conftest.py
  • tests/gpu/torch/quantization/test_tensor_quant_cuda.py
  • tests/gpu/torch/sparsity/attention_sparsity/test_wan22_skip_softmax.py
  • tests/unit/torch/kernels/common/attention/test_triton_fa.py
  • tests/unit/torch/kernels/sparsity/attention/test_kernel_backends.py
  • tests/unit/torch/kernels/sparsity/attention/test_ltx_triton_attention.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py
💤 Files with no reviewable changes (1)
  • modelopt/torch/kernels/quantization/gemm/__init__.py

Review context (modelopt/torch/kernels/__init__.py):

    "attention_calibrate",
    "register_triton_attention",
    ]
    """ModelOpt kernel library: common, quantization (conv, gemm), sparsity (attention, gemm)."""
Contributor


⚠️ Potential issue | 🟠 Major

Keep a compatibility shim for the old modelopt.torch.kernels imports.

Reducing this package to only a docstring drops previously-exported symbols like attention, attention_calibrate, and IS_AVAILABLE, so existing downstream imports break immediately on upgrade. Please re-export the moved symbols from modelopt.torch.kernels.common.attention and deprecate the old path instead of removing it outright.

Possible shim
 """ModelOpt kernel library: common, quantization (conv, gemm), sparsity (attention, gemm)."""
+
+from .common.attention import (
+    IS_AVAILABLE,
+    attention,
+    attention_calibrate,
+    register_triton_attention,
+)
+
+__all__ = [
+    "IS_AVAILABLE",
+    "attention",
+    "attention_calibrate",
+    "register_triton_attention",
+]

Comment thread modelopt/torch/kernels/sparsity/attention/calibrate.py
Comment thread modelopt/torch/kernels/sparsity/attention/calibrate.py
Comment thread modelopt/torch/sparsity/attention_sparsity/methods/vsa.py
Comment thread modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
Comment thread modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py

codecov Bot commented Apr 22, 2026

Codecov Report

❌ Patch coverage is 58.73606% with 111 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.65%. Comparing base (06ef935) to head (416f27e).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
.../torch/sparsity/attention_sparsity/plugins/ltx2.py 39.20% 107 Missing ⚠️
...torch/sparsity/attention_sparsity/plugins/wan22.py 96.38% 3 Missing ⚠️
...t/torch/sparsity/attention_sparsity/methods/vsa.py 66.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1315      +/-   ##
==========================================
- Coverage   76.90%   76.65%   -0.25%     
==========================================
  Files         471      473       +2     
  Lines       50565    50831     +266     
==========================================
+ Hits        38885    38965      +80     
- Misses      11680    11866     +186     
Flag Coverage Δ
examples 39.66% <21.18%> (-1.00%) ⬇️
gpu 59.56% <37.91%> (-0.77%) ⬇️
regression 14.89% <17.10%> (+0.06%) ⬆️
unit 52.80% <56.87%> (+<0.01%) ⬆️


Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml changed the title Jingyux/vsa diffusion VSA support for Wan 2.2 and LTX2 Apr 23, 2026
@jingyu-ml jingyu-ml marked this pull request as ready for review April 23, 2026 00:15
Collaborator

@cjluo-nv cjluo-nv left a comment


This PR adds VSA support for Wan 2.2 and LTX-2 models. The core logic looks correct and well-documented. The VSA gate_compress=None fix is important and well-reasoned. The plugin architecture follows good patterns (idempotent registration, graceful fallbacks, class-name-based detection).

Key observations:

  1. Copyright year: Both new files (ltx2.py, wan22.py) use "Copyright (c) 2024" but should be 2025 for new files.
  2. Missing plugin-specific tests: While test_vsa.py covers the core VSA method, config validation, and integration, there are no unit tests for the wan22 or ltx2 plugins specifically (hook installation, video_shape extraction, idempotency guards). The PR description mentions a "Wan 2.2 plugin hook test" but it doesn't appear in the test files.
  3. The LTX-2 plugin is included without an example: This is explicitly called out in the PR description as intentional (pending license/training loop), which is fine, but makes the lack of tests more concerning since there's no way to validate it even manually.
  4. Size: ~1052 lines is at the boundary but the changes are cohesive.

The code quality is high overall — good docstrings, defensive fallbacks, and proper error messages. The CUSTOM_MODEL_PLUGINS list→set migration is a nice improvement.

Comment thread modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py
Comment thread modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py
Comment thread modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py
@cjluo-nv
Collaborator

@kaix-nv could you take a look?

Collaborator

@kevalmorabia97 kevalmorabia97 left a comment


LGTM from LTX license notice point of view

jingyu-ml added 2 commits May 4, 2026 21:19
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
@jingyu-ml jingyu-ml requested a review from cjluo-nv May 4, 2026 21:34
jingyu-ml and others added 2 commits May 4, 2026 21:47
Collaborator

@cjluo-nv cjluo-nv left a comment


Bot review — DM the bot to share feedback.

Re-review summary: All critical issues from the previous review have been addressed:

  1. calibrate.py prog_idx flattening — Fixed; now uses tl.num_programs(2) instead of tl.load(b_seq_len + 0).
  2. vsa.py gate_compress=None docstring — Updated to document the new "disabled/zero" semantics.
  3. wan22.py hook unconditionally overwriting video_shape — Now uses _wan22_auto_video_shape marker to preserve user-supplied shapes.
  4. Missing plugin-specific tests — Comprehensive unit tests added for both plugins (test_wan22_plugin.py, test_ltx2_plugin.py) covering detection, hook registration, idempotency, shape extraction, and edge cases.
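Item 2 above concerns the gate semantics. A tiny numeric sketch of the documented behavior: the kernel combines out = gate * out_c + out_s, so an explicit zero gate reduces the output to the sparse branch alone, which is the fix for models without a learned gate. The values below are illustrative, not taken from the kernel.

```python
def combine(gate, out_c, out_s):
    """Elementwise out = gate * out_c + out_s, mirroring the VSA combine."""
    return [g * c + s for g, c, s in zip(gate, out_c, out_s)]

out_c = [0.5, 0.25]   # compressed-attention branch
out_s = [1.0, 2.0]    # sparse-attention branch
print(combine([0.0, 0.0], out_c, out_s))  # [1.0, 2.0]  -> out_s only
print(combine([1.0, 1.0], out_c, out_s))  # [1.5, 2.25] -> out_c + out_s
```

The second line shows why a missing gate previously doubled the attention signal: with no gate, the branches were simply summed.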

Design: The PR extends the existing attention_sparsity plugin architecture with two new model-specific plugins (Wan 2.2, LTX-2) and a new VSA method. No new subsystem or framework is introduced — it plugs into the existing SparseAttentionMethod registry and CUSTOM_MODEL_PLUGINS set. This is the right approach.

Remaining minor items (not blocking, but worth owner attention):

  • The huggingface plugin is still eagerly imported in plugins/__init__.py (not guarded by import_plugin), while ltx2 and wan22 are. This inconsistency was flagged in the prior review and is still present. Low risk since HF is effectively a hard dep, but worth aligning.
  • The threshold validation in calibrate.py (rejecting t <= 0 before math.log2) was suggested but not added. Minor defensive improvement.
  • Size is ~1500 lines but cohesive: two plugins, one method fix, one example refactor, two test files, and a README rewrite. Reasonable not to split.

Nudging for human sign-off since this is a large, multi-file change touching kernel code and adding new plugin integrations.
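The shape arithmetic the wan22 pre-hook performs, per the PR description, can be sketched in plain Python: take a (B, C, T, H, W) latent shape, divide the temporal and spatial dims by config.patch_size, and hand the post-patchify (T, H, W) to the sparse-attention modules. The function name and the example dims below are illustrative assumptions, not the plugin's code.

```python
def post_patchify_video_shape(hidden_shape, patch_size):
    """hidden_shape: (B, C, T, H, W); patch_size: (pt, ph, pw)."""
    _, _, t, h, w = hidden_shape
    pt, ph, pw = patch_size
    if t % pt or h % ph or w % pw:
        raise ValueError("latent dims must be divisible by patch_size")
    # This tuple is what the hook would pass to method.set_video_shape().
    return (t // pt, h // ph, w // pw)

# Example latent: 21 frames on a 60x104 latent grid, patchified 1x2x2.
print(post_patchify_video_shape((1, 16, 21, 60, 104), (1, 2, 2)))  # (21, 30, 52)
```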

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tests/unit/torch/sparsity/attention_sparsity/test_ltx2_plugin.py`:
- Around line 102-138: The tests call _extract_video_shape_hook with
args=(modality,) but the hook derives seq_len from a tensor arg, so update each
test (test_extracts_3d_positions, test_extracts_4d_positions_taking_start_coord,
test_skips_when_video_is_none, test_skips_when_positions_is_none,
test_skips_when_product_mismatches_seq_len) to pass the positions tensor as the
first positional arg (e.g. args=(modality.positions,)) or otherwise provide the
actual tensor argument the hook expects so seq_len is computed correctly;
additionally, fix test_skips_when_product_mismatches_seq_len to assert the guard
behavior by expecting no _vsa_video_shape attribute (use assert not
hasattr(model, "_vsa_video_shape")) since the product/seq_len mismatch should
skip setting it.
- Line 29: Remove the module-level test gate by deleting the
pytest.importorskip("transformers") call so the suite no longer gets skipped
when transformers isn't installed; keep the test support approach that uses
_make_named_module() and types.SimpleNamespace() to mock any transformer
behavior and ensure any optional transformers usage in the ltx2 plugin is loaded
lazily via import_plugin() rather than gating the whole file.

In `@tests/unit/torch/sparsity/attention_sparsity/test_wan22_plugin.py`:
- Line 28: Remove the module-level pytest.importorskip("transformers") call from
the test file so tests for the WAN22 plugin (which use only mocks and helpers)
are not skipped in environments without transformers; locate the importorskip
invocation and delete it, and if any future tests require transformers, wrap
only those specific test functions or methods with pytest.importorskip or an
appropriate marker instead of skipping the whole module.

📥 Commits

Reviewing files that changed from the base of the PR and between 051b55a and 416f27e.

📒 Files selected for processing (4)
  • modelopt/torch/sparsity/attention_sparsity/methods/vsa.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py
  • tests/unit/torch/sparsity/attention_sparsity/test_ltx2_plugin.py
  • tests/unit/torch/sparsity/attention_sparsity/test_wan22_plugin.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py
  • modelopt/torch/sparsity/attention_sparsity/methods/vsa.py


import pytest

pytest.importorskip("transformers")
Contributor


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check whether this test file references transformers APIs beyond importorskip.
rg -n --type=py 'importorskip\("transformers"\)|\bimport transformers\b|\bfrom transformers\b' tests/unit/torch/sparsity/attention_sparsity/test_ltx2_plugin.py

# Check whether the ltx2 plugin itself has a transformers hard dependency.
rg -n --type=py '\btransformers\b' modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py

Repository: NVIDIA/Model-Optimizer

Length of output: 105


🏁 Script executed:

# First, let's see the test file structure and what it actually tests
head -60 tests/unit/torch/sparsity/attention_sparsity/test_ltx2_plugin.py

# Also check the imports and setup of the ltx2 plugin
head -30 modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py

Repository: NVIDIA/Model-Optimizer

Length of output: 3557


🏁 Script executed:

# Check all imports in ltx2.py to see if transformers is pulled in transitively
head -50 modelopt/torch/sparsity/attention_sparsity/plugins/ltx2.py

# Check a representative test method to confirm it uses mocks, not transformers APIs
sed -n '90,150p' tests/unit/torch/sparsity/attention_sparsity/test_ltx2_plugin.py

Repository: NVIDIA/Model-Optimizer

Length of output: 4721


Remove unnecessary module-level pytest.importorskip("transformers") gate.

Line 29 unnecessarily skips this entire CPU unit test suite. The tests use mock modules (via _make_named_module() and types.SimpleNamespace()) and do not invoke any transformers APIs. The ltx2 plugin itself has no transformers dependency. This skip prevents the test suite from running in environments without transformers installed, contrary to the coding guideline that optional dependencies should be "loaded lazily via import_plugin()" rather than gating entire test suites at module level. Remove line 29.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unit/torch/sparsity/attention_sparsity/test_ltx2_plugin.py` at line 29,
Remove the module-level test gate by deleting the
pytest.importorskip("transformers") call so the suite no longer gets skipped
when transformers isn't installed; keep the test support approach that uses
_make_named_module() and types.SimpleNamespace() to mock any transformer
behavior and ensure any optional transformers usage in the ltx2 plugin is loaded
lazily via import_plugin() rather than gating the whole file.

Comment on lines +102 to +138
    def test_extracts_3d_positions(self):
        model = nn.Module()
        modality = types.SimpleNamespace(positions=_build_positions(2, 3, 4))
        _extract_video_shape_hook(model, args=(modality,))
        assert model._vsa_video_shape == (2, 3, 4)

    def test_extracts_4d_positions_taking_start_coord(self):
        model = nn.Module()
        # (B, 3, T, 2) — _extract_video_shape_hook drops the trailing dim.
        positions_3d = _build_positions(2, 2, 2)  # (1, 3, 8)
        positions_4d = positions_3d.unsqueeze(-1).expand(-1, -1, -1, 2).contiguous()
        modality = types.SimpleNamespace(positions=positions_4d)
        _extract_video_shape_hook(model, args=(modality,))
        assert model._vsa_video_shape == (2, 2, 2)

    def test_skips_when_video_is_none(self):
        model = nn.Module()
        _extract_video_shape_hook(model, args=(None,))
        assert not hasattr(model, "_vsa_video_shape")

    def test_skips_when_positions_is_none(self):
        model = nn.Module()
        modality = types.SimpleNamespace(positions=None)
        _extract_video_shape_hook(model, args=(modality,))
        assert not hasattr(model, "_vsa_video_shape")

    def test_skips_when_product_mismatches_seq_len(self):
        """Defensive guard: if unique-counts don't multiply to seq_len, bail."""
        model = nn.Module()
        # Both T-dim and H-dim share the same single value, so unique counts
        # collapse and the product no longer equals seq_len.
        positions = torch.zeros(1, 3, 4, dtype=torch.long)
        positions[0, 2] = torch.arange(4)  # only W varies
        modality = types.SimpleNamespace(positions=positions)
        _extract_video_shape_hook(model, args=(modality,))
        # 1 * 1 * 4 == 4 still matches seq_len, so this should succeed.
        assert model._vsa_video_shape == (1, 1, 4)
Contributor


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Align _extract_video_shape_hook tests with the hook’s seq-len guard.

On Lines 105/114/136, the hook is called with args=(modality,), but implementation resolves seq_len from tensor args before setting _vsa_video_shape. Also, test_skips_when_product_mismatches_seq_len (Line 128) currently asserts success on Line 138, which contradicts its intent.

🔧 Suggested test fixes
     def test_extracts_3d_positions(self):
         model = nn.Module()
         modality = types.SimpleNamespace(positions=_build_positions(2, 3, 4))
-        _extract_video_shape_hook(model, args=(modality,))
+        seq = torch.zeros(1, 24)  # seq_len = T*H*W
+        _extract_video_shape_hook(model, args=(modality, seq))
         assert model._vsa_video_shape == (2, 3, 4)

     def test_extracts_4d_positions_taking_start_coord(self):
         model = nn.Module()
         # (B, 3, T, 2) — _extract_video_shape_hook drops the trailing dim.
         positions_3d = _build_positions(2, 2, 2)  # (1, 3, 8)
         positions_4d = positions_3d.unsqueeze(-1).expand(-1, -1, -1, 2).contiguous()
         modality = types.SimpleNamespace(positions=positions_4d)
-        _extract_video_shape_hook(model, args=(modality,))
+        seq = torch.zeros(1, 8)  # seq_len = T*H*W
+        _extract_video_shape_hook(model, args=(modality, seq))
         assert model._vsa_video_shape == (2, 2, 2)

     def test_skips_when_product_mismatches_seq_len(self):
         """Defensive guard: if unique-counts don't multiply to seq_len, bail."""
         model = nn.Module()
         # Both T-dim and H-dim share the same single value, so unique counts
         # collapse and the product no longer equals seq_len.
         positions = torch.zeros(1, 3, 4, dtype=torch.long)
         positions[0, 2] = torch.arange(4)  # only W varies
         modality = types.SimpleNamespace(positions=positions)
-        _extract_video_shape_hook(model, args=(modality,))
-        # 1 * 1 * 4 == 4 still matches seq_len, so this should succeed.
-        assert model._vsa_video_shape == (1, 1, 4)
+        seq = torch.zeros(1, 5)  # mismatched seq_len
+        _extract_video_shape_hook(model, args=(modality, seq))
+        assert not hasattr(model, "_vsa_video_shape")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unit/torch/sparsity/attention_sparsity/test_ltx2_plugin.py` around
lines 102 - 138, The tests call _extract_video_shape_hook with args=(modality,)
but the hook derives seq_len from a tensor arg, so update each test
(test_extracts_3d_positions, test_extracts_4d_positions_taking_start_coord,
test_skips_when_video_is_none, test_skips_when_positions_is_none,
test_skips_when_product_mismatches_seq_len) to pass the positions tensor as the
first positional arg (e.g. args=(modality.positions,)) or otherwise provide the
actual tensor argument the hook expects so seq_len is computed correctly;
additionally, fix test_skips_when_product_mismatches_seq_len to assert the guard
behavior by expecting no _vsa_video_shape attribute (use assert not
hasattr(model, "_vsa_video_shape")) since the product/seq_len mismatch should
skip setting it.
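The unique-count recovery and the seq-len guard under discussion can be reconstructed in plain Python (no torch): count distinct values per coordinate axis of the flattened video positions and accept the result only when the product equals the sequence length. This is a hedged sketch inferred from the review thread, not the LTX-2 plugin's actual hook.

```python
def shape_from_positions(axes):
    """axes: three equal-length lists of t, h, w coordinates per token."""
    seq_len = len(axes[0])
    counts = tuple(len(set(axis)) for axis in axes)
    if counts[0] * counts[1] * counts[2] != seq_len:
        return None  # defensive guard: collapsed axes, bail out
    return counts

# A 2x3x4 grid flattened to 24 positions in row-major order.
t_axis = [t for t in range(2) for _ in range(12)]
h_axis = [h for _ in range(2) for h in range(3) for _ in range(4)]
w_axis = [w for _ in range(6) for w in range(4)]
print(shape_from_positions([t_axis, h_axis, w_axis]))  # (2, 3, 4)

# Collapsed axes whose product no longer matches seq_len trigger the guard.
print(shape_from_positions([[0, 0, 0, 0], [0, 0, 0, 0], [0, 1, 0, 1]]))  # None
```

Note the subtlety the test debate hinges on: a degenerate grid like (1, 1, 4) over 4 tokens still satisfies the product check, so only a genuine product/seq-len mismatch should skip setting the shape.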


import pytest

pytest.importorskip("transformers")
Contributor


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify whether this test file actually uses transformers symbols beyond importorskip.
rg -n --type=py 'importorskip\("transformers"\)|\bimport transformers\b|\bfrom transformers\b' tests/unit/torch/sparsity/attention_sparsity/test_wan22_plugin.py

# Verify whether the plugin under test hard-depends on transformers.
rg -n --type=py '\btransformers\b' modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py

Repository: NVIDIA/Model-Optimizer

Length of output: 303


🏁 Script executed:

#!/bin/bash
# Check test file imports and structure around the skip statement
head -50 tests/unit/torch/sparsity/attention_sparsity/test_wan22_plugin.py

# Check if wan22.py imports transformers anywhere
grep -n "import.*transformers\|from.*transformers" modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py

# Find _find_wan22_transformers function definition
grep -n "_find_wan22_transformers" modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1966


🏁 Script executed:

#!/bin/bash
# Check imports at the top of wan22.py
head -60 modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py | grep -E "^import|^from"

# Also check the full plugin file for transformers imports
rg -n "^import.*transformers|^from.*transformers" modelopt/torch/sparsity/attention_sparsity/plugins/wan22.py

Repository: NVIDIA/Model-Optimizer

Length of output: 191


🏁 Script executed:

#!/bin/bash
# Check if any test function actually uses transformers (not just variable name "transformers")
# Get the test function definitions
rg -n "^def test_" tests/unit/torch/sparsity/attention_sparsity/test_wan22_plugin.py

# Check if transformers module is referenced anywhere after the skip (not just variable "transformers")
rg -n "transformers\." tests/unit/torch/sparsity/attention_sparsity/test_wan22_plugin.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

#!/bin/bash
# Get file size first
wc -l tests/unit/torch/sparsity/attention_sparsity/test_wan22_plugin.py

# Read the full test file
cat -n tests/unit/torch/sparsity/attention_sparsity/test_wan22_plugin.py

Repository: NVIDIA/Model-Optimizer

Length of output: 10447


Remove the unnecessary module-level skip on transformers.

The test file's docstring explicitly states it tests "without requiring a diffusers/Wan install." The file uses only mock classes and test helpers—no transformers APIs are invoked. The plugin itself (wan22.py) has no transformers dependency. Line 28's pytest.importorskip("transformers") will silently skip all tests in CI environments lacking transformers, reducing coverage of core plugin functionality.

Remove the skip entirely. If specific tests in the future need transformers, gate only those individual test methods.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unit/torch/sparsity/attention_sparsity/test_wan22_plugin.py` at line
28, Remove the module-level pytest.importorskip("transformers") call from the
test file so tests for the WAN22 plugin (which use only mocks and helpers) are
not skipped in environments without transformers; locate the importorskip
invocation and delete it, and if any future tests require transformers, wrap
only those specific test functions or methods with pytest.importorskip or an
appropriate marker instead of skipping the whole module.
