45 commits
1d8ac33
Add the Skip softmax diffusion
jingyu-ml Apr 2, 2026
1f8f0d3
Add test case
jingyu-ml Apr 2, 2026
5873652
Fixed error
jingyu-ml Apr 2, 2026
4c179a3
Fixed the test case
jingyu-ml Apr 2, 2026
2c323df
Merge branch 'main' into jingyux/diffusion-skip-softmax
jingyu-ml Apr 2, 2026
8702b7b
Removed the token import
jingyu-ml Apr 6, 2026
bbe2123
Merge branch 'main' into jingyux/diffusion-skip-softmax
jingyu-ml Apr 6, 2026
70099a5
removed the unused code
jingyu-ml Apr 6, 2026
6cc96a4
Update the README
jingyu-ml Apr 6, 2026
4de0d3b
Updated the example script
jingyu-ml Apr 7, 2026
b3d3d4d
Update the readme
jingyu-ml Apr 7, 2026
8dc6162
Update the calibration kernel
jingyu-ml Apr 7, 2026
8aa32cc
Add the readme
jingyu-ml Apr 7, 2026
fbeabcf
Update the example script
jingyu-ml Apr 7, 2026
6a4ab8b
Update the code
jingyu-ml Apr 7, 2026
d7dd15c
Update the calibration loop
jingyu-ml Apr 8, 2026
b86d311
Remove the eager attention
jingyu-ml Apr 8, 2026
f5a9af9
Merge branch 'main' into jingyux/diffusion-skip-softmax
jingyu-ml Apr 8, 2026
45bcad6
Update the calibration, fixed some bugs
jingyu-ml Apr 9, 2026
22c5b85
Add the test case
jingyu-ml Apr 9, 2026
aa44a9d
Fixed the lint error
jingyu-ml Apr 9, 2026
e5293de
Updated the README
jingyu-ml Apr 9, 2026
40fdd44
Merge branch 'main' into jingyux/diffusion-skip-softmax
jingyu-ml Apr 9, 2026
40d61dd
Update the test case
jingyu-ml Apr 9, 2026
3845b47
Fixed the CICD
jingyu-ml Apr 9, 2026
560015c
Merge branch 'main' into jingyux/diffusion-skip-softmax
jingyu-ml Apr 13, 2026
f86580c
Added the ltx2 warning
jingyu-ml Apr 13, 2026
ee162b3
addressed the ltx2 issue and the import issue
jingyu-ml Apr 13, 2026
eef0577
Merge branch 'main' into jingyux/diffusion-skip-softmax
jingyu-ml Apr 13, 2026
5219ab0
Merge branch 'main' into jingyux/diffusion-skip-softmax
jingyu-ml Apr 14, 2026
5ba76ac
Address comments
jingyu-ml Apr 14, 2026
d2d6d83
Update the readme
jingyu-ml Apr 15, 2026
5ab4ebb
Merge branch 'main' into jingyux/diffusion-skip-softmax
jingyu-ml Apr 15, 2026
8d71522
Update
jingyu-ml Apr 15, 2026
79b5f2a
Merge branch 'main' into jingyux/diffusion-skip-softmax
jingyu-ml Apr 15, 2026
8787d47
Added the test case
jingyu-ml Apr 15, 2026
7bae5d7
Merge branch 'main' into jingyux/diffusion-skip-softmax
jingyu-ml Apr 16, 2026
a21cac2
Remove the eager import
jingyu-ml Apr 16, 2026
849cc1a
Merge branch 'jingyux/diffusion-skip-softmax' into jingyux/diffusion-…
jingyu-ml Apr 16, 2026
a1b002e
Update the format
jingyu-ml Apr 16, 2026
b3f4cab
Add the change to the changelog
jingyu-ml Apr 16, 2026
4dbab66
Merge branch 'main' into jingyux/diffusion-skip-softmax
jingyu-ml Apr 16, 2026
b77d098
Merge branch 'jingyux/diffusion-skip-softmax' into jingyux/diffusion-…
jingyu-ml Apr 16, 2026
65f380d
Merge branch 'main' into jingyux/diffusion-skip-softmax-2
jingyu-ml Apr 27, 2026
5eb9352
Merge branch 'main' into jingyux/diffusion-skip-softmax-2
jingyu-ml May 4, 2026
2 changes: 1 addition & 1 deletion CHANGELOG.rst
@@ -31,7 +31,7 @@ Changelog
- Add Puzzletron - a new algorithm for heterogeneous pruning of LLM and VLM models. See `examples/puzzletron/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/puzzletron>`_ for more details.
- Added iterator interface using CalibrationDataReader in ONNX quantization workflow.
- Add N:M sparse softmax support to the Triton flash attention kernel (``modelopt.torch.kernels.common.attention.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
- Add skip-softmax skipping to the Triton flash attention kernel (``modelopt.torch.kernels.common.attention.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ for usage.
- Add skip-softmax skipping to the Triton flash attention kernel for both language models and video diffusion models (``modelopt.torch.kernels.common.attention.triton_fa``). See `examples/llm_sparsity/attention_sparsity/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_sparsity/attention_sparsity>`_ and `examples/diffusers/sparsity/ <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/diffusers/sparsity>`_ for usage.
- Add Video Sparse Attention (VSA) method for video diffusion models (``modelopt.torch.sparsity.attention_sparsity``). VSA uses 3D block tiling with a two-branch architecture for attention speedup.
- Enable PTQ workflow for the Step3.5-Flash MoE model with NVFP4 W4A4 + FP8 KV cache quantization. See `modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/models/Step3.5-Flash/nvfp4-mlp-only.yaml>`_ for more details.
- Add support for vLLM fakequant reload using ModelOpt state for HF models. See `examples/vllm_serve/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/vllm_serve#load-qatptq-model-and-serve-in-vllm-wip>`_ for more details.
16 changes: 16 additions & 0 deletions examples/diffusers/sparsity/wan22_skip_softmax.py
@@ -59,6 +59,7 @@
from diffusers.utils import export_to_video

import modelopt.torch.sparsity.attention_sparsity as mtsa
from modelopt.torch.export import export_hf_checkpoint
from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule

DEFAULT_MODEL_PATH = os.environ.get("WAN22_MODEL_PATH", "Wan-AI/Wan2.2-TI2V-5B-Diffusers")
@@ -199,6 +200,16 @@ def parse_args() -> argparse.Namespace:
default=4,
help="Number of calibration prompts from OpenVid-1M dataset",
)

# Export options
parser.add_argument(
"--export-dir",
type=str,
default=None,
help="Export sparsified model as a HuggingFace checkpoint to this directory. "
"The sparse_attention_config (calibration params, disabled layers, etc.) "
"is written into each component's config.json.",
)
return parser.parse_args()


@@ -442,6 +453,11 @@ def main() -> None:
torch.cuda.empty_cache()
print("Cleared CUDA cache after calibration")

# ---- Export (optional) ----
if args.export_dir and not args.baseline:
print(f"Exporting sparsified checkpoint to {args.export_dir}...")
export_hf_checkpoint(pipe, export_dir=args.export_dir)

# ---- Generate (optional) ----
if args.prompt:
# Enable runtime sparsity measurement before generation
36 changes: 36 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
@@ -28,6 +28,7 @@

import torch
import torch.nn as nn
import yaml
from safetensors.torch import save_file

try:
@@ -949,6 +950,9 @@ def _export_diffusers_checkpoint(
is_diffusers_pipe = False

# Step 3: Export each nn.Module component with quantization handling
# Collect sparse attention configs across all components for a unified sparse.yaml
pipeline_sparse_configs: dict[str, Any] = {}

for component_name, component in module_components.items():
is_quantized = has_quantized_modules(component)
status = "quantized" if is_quantized else "non-quantized"
@@ -1015,8 +1019,33 @@
model_type=model_type,
)

# Step 8: Update config.json with sparse attention info (both quantized and non-quantized)
if export_sparse_attention_config is not None:
sparse_attn_config = export_sparse_attention_config(component)
if sparse_attn_config is not None:
config_path = component_export_dir / "config.json"
if config_path.exists():
with open(config_path) as file:
config_data = json.load(file)
config_data["sparse_attention_config"] = sparse_attn_config
with open(config_path, "w") as file:
json.dump(config_data, file, indent=4)
print(f" Added sparse_attention_config to {config_path.name}")

# Collect for the unified sparse.yaml (keyed by component name)
pipeline_sparse_configs[component_name] = sparse_attn_config

print(f" Saved to: {component_export_dir}")

# Step 8.5: Write unified sparse.yaml at the top-level export directory.
# Combines sparse configs from all components (keyed by pipeline component name)
# so downstream consumers get the full pipeline's sparse config in one file.
if pipeline_sparse_configs:
yaml_path = export_dir / "sparse.yaml"
with open(yaml_path, "w") as file:
yaml.dump(pipeline_sparse_configs, file, default_flow_style=False, sort_keys=False)
print(f"Saved unified sparse config to {yaml_path}")

# Step 4: Export non-nn.Module components (tokenizers, schedulers, feature extractors, etc.)
if is_diffusers_pipe:
for component_name, component in all_components.items():
@@ -1249,6 +1278,13 @@ def export_hf_checkpoint(
if sparse_attn_config is not None:
config_data["sparse_attention_config"] = sparse_attn_config

# Also save as standalone YAML for easy inspection and reuse
import yaml

yaml_path = Path(export_dir) / "sparse.yaml"
with open(yaml_path, "w") as file:
yaml.dump(sparse_attn_config, file, default_flow_style=False, sort_keys=False)

with open(original_config, "w") as file:
json.dump(config_data, file, indent=4)

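For reference, a minimal sketch of how a downstream consumer might read the artifacts written above: the unified sparse.yaml at the top-level export directory and the sparse_attention_config embedded in each component's config.json. The export path and the "transformer" component name are illustrative assumptions, not part of this diff.

import json
from pathlib import Path

import yaml

export_dir = Path("./wan22-sparse-export")  # hypothetical export directory

# Pipeline-level view: sparse configs for every component, keyed by component name
with open(export_dir / "sparse.yaml") as f:
    pipeline_sparse_configs = yaml.safe_load(f)
print(list(pipeline_sparse_configs))  # e.g. ["transformer"]

# Per-component view: the same config is embedded in that component's config.json
with open(export_dir / "transformer" / "config.json") as f:
    config_data = json.load(f)
print(config_data.get("sparse_attention_config", {}).get("config_groups"))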
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py
@@ -346,6 +346,9 @@ def calibrate_sparse_attention(
"a": result["a"],
"b": result["b"],
}
if result.get("fit_logspace"):
params["log_a"] = result["log_a"]
params["fit_logspace"] = True
Comment on lines +349 to +351
⚠️ Potential issue | 🟡 Minor

Guard log_a access when fit_logspace is enabled.

Line [350] directly indexes result["log_a"]; if calibration output is incomplete, this throws and drops the whole calibration flow.

🔧 Suggested fix
             if result.get("fit_logspace"):
-                params["log_a"] = result["log_a"]
-                params["fit_logspace"] = True
+                if "log_a" not in result:
+                    warnings.warn(f"{phase} calibration marked fit_logspace=True but missing log_a")
+                else:
+                    params["log_a"] = result["log_a"]
+                    params["fit_logspace"] = True
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py` around
lines 349 - 351, The code currently assumes result contains "log_a" when
result.get("fit_logspace") is true, which can raise a KeyError; update the block
that sets params["log_a"] and params["fit_logspace"] to first confirm "log_a"
exists (e.g., if "fit_logspace" in result and "log_a" in result) or use
result.get("log_a") and only assign params["log_a"] when that value is not None,
and still set params["fit_logspace"]=True when fitting was attempted; locate
this logic around the params/result handling in calibrate.py (the result dict,
params dict, and the "fit_logspace"/"log_a" keys) and add the guard so
incomplete calibration output cannot break the flow.

if "min_observed_sparsity" in result:
params["min_observed_sparsity"] = result["min_observed_sparsity"]
if "max_observed_sparsity" in result:
@@ -275,17 +275,21 @@ def exponential(sparsity, a, b):
avg_s = np.mean([p["sparsity"] for p in points])
print(f" {threshold:<12.4f} {avg_sf:<12.2f} {avg_s:<12.2%} {len(points):<8}")

return {
result = {
"phase": phase,
"a": float(a),
"b": float(b),
"r_squared": float(r_squared),
"num_data_points": int(np.sum(valid_mask)),
"total_samples": len(all_data_points),
"calibration_type": "exponential",
"fit_logspace": self.fit_logspace,
"min_observed_sparsity": min_observed_sparsity,
"max_observed_sparsity": max_observed_sparsity,
}
if self.fit_logspace:
result["log_a"] = float(log_a)
return result

def _enable_calibration_mode(self, modules: list[nn.Module]):
"""Enable calibration mode on sparse attention modules."""
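For context, a minimal sketch of the two fitting modes this result distinguishes, assuming the calibrator has collected per-sample (sparsity, scale_factor) pairs. Only the exponential model and the a / b / log_a / fit_logspace keys come from the diff; the helper name and the use of SciPy here are illustrative.

import numpy as np
from scipy.optimize import curve_fit

def exponential(sparsity, a, b):
    # Model used above: scale_factor = a * exp(b * sparsity)
    return a * np.exp(b * sparsity)

def fit_scale_factor_curve(sparsities, scale_factors, fit_logspace=False):
    s = np.asarray(sparsities, dtype=float)
    y = np.asarray(scale_factors, dtype=float)
    if fit_logspace:
        # Log-space fit: log(scale_factor) = log_a + b * sparsity (linear least squares)
        b, log_a = np.polyfit(s, np.log(y), 1)
        return {"a": float(np.exp(log_a)), "b": float(b),
                "log_a": float(log_a), "fit_logspace": True}
    # Linear-space fit: nonlinear least squares directly on the exponential model
    (a, b), _ = curve_fit(exponential, s, y, maxfev=10000)
    return {"a": float(a), "b": float(b), "fit_logspace": False}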
137 changes: 106 additions & 31 deletions modelopt/torch/sparsity/attention_sparsity/conversion.py
@@ -349,73 +349,148 @@ def update_sparse_attention_metadata(
def export_sparse_attention_config(model: nn.Module) -> dict[str, Any] | None:
"""Extract sparse attention config for export to config.json.

Extracts the calibration parameters (a, b) for the exponential threshold model
from the first sparse attention module that has calibrated thresholds.
Extracts calibration parameters, method metadata, and per-layer enable/disable
state from sparse attention modules. Supports both LLM and diffusion models.

The exported config allows computing threshold at runtime:
scale_factor = a * exp(b * target_sparsity)
threshold = scale_factor / seqlen
Algorithm-specific parameters (``threshold_scale_factor``, ``raw_threshold``,
``disabled_layers``) are nested inside the config group that owns them.
This allows future sparse attention methods to define their own parameter
schemas in separate groups without collision.

The formula in the export reflects the actual fitting mode used during
calibration:

- **Linear-space fit** (default, LLMs): ``scale_factor = a * exp(b * S)``;
exports ``a`` and ``b``.
- **Log-space fit** (diffusion): ``log(scale_factor) = log_a + b * S``;
exports ``log_a`` and ``b``.

At runtime: ``threshold = scale_factor / seqlen``.

Args:
model: Model with sparse attention applied

Returns:
Dictionary with sparse attention config for HuggingFace config.json export.
Returns None if no calibrated sparse attention modules found.
Returns None if no sparse attention modules are found, or if no calibration
parameters and no raw threshold are available.

Example output::
Example output (LLM, linear-space fit)::

{
"config_groups": {
"group_0": {"sparse_algo": "softmax_skip", "targets": ["LlamaAttention"]}
"group_0": {
"sparse_algo": "softmax_skip",
"targets": ["LlamaAttention"],
"threshold_scale_factor": {
"formula": "a * exp(b * target_sparsity)",
"prefill": {"a": 7.93, "b": 8.61},
"decode": {"a": 0.12, "b": 9.85},
},
}
},
"threshold_scale_factor": {
"formula": "a * exp(b * target_sparsity)",
"prefill": {"a": 7.93, "b": 8.61},
"decode": {"a": 0.12, "b": 9.85},
"producer": {"name": "modelopt", "version": "0.37.0"},
}

Example output (diffusion, log-space fit)::

{
"config_groups": {
"group_0": {
"sparse_algo": "softmax_skip",
"targets": ["Attention"],
"threshold_scale_factor": {
"formula": "log_a + b * target_sparsity",
"prefill": {"log_a": 0.21, "b": 3.45},
},
"disabled_layers": ["blocks.0.attn1", "blocks.39.attn1"],
}
},
"producer": {"name": "modelopt", "version": "0.37.0"},
}
"""
# Collect sparse attention module info
calibration_params = None
raw_threshold = None
target_classes: set[str] = set()
disabled_layer_names: list[str] = []

for module in get_sparse_attention_modules(model):
for name, module in get_named_sparse_attention_modules(model):
# Get the original wrapped module's class name
if hasattr(module, "get_original_cls_by_level"):
original_cls = module.get_original_cls_by_level(level=0)
if original_cls is not None:
target_classes.add(original_cls.__name__)

# Get calibration params from first module that has them
if not module.is_enabled:
disabled_layer_names.append(get_unwrapped_name(name, model))
continue

# Get calibration params from first enabled module that has them
if calibration_params is None:
calibration_params = getattr(module._sparse_method_instance, "calibration_params", None)

# Return None if no calibration params found
if calibration_params is None:
# Get raw threshold from first enabled module that has one
if raw_threshold is None:
raw_threshold = getattr(
module._sparse_method_instance, "skip_softmax_raw_threshold", None
)

# Nothing exportable if no calibration params and no raw threshold
if calibration_params is None and raw_threshold is None:
return None
Comment on lines +439 to 441
⚠️ Potential issue | 🟡 Minor

Don’t drop disabled_layers-only exports.

Lines [439]-[441] return None even when disabled_layer_names was populated, so per-layer disable metadata is lost for non-calibrated/static-threshold runs.

🔧 Suggested fix
-    if calibration_params is None and raw_threshold is None:
+    if calibration_params is None and raw_threshold is None and not disabled_layer_names:
         return None
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/conversion.py` around lines 439 -
441, The early-return drops per-layer disable metadata: instead of returning
None when calibration_params and raw_threshold are both None, check if
disabled_layer_names (or the variable storing disabled layers) is non-empty and,
if so, return an export payload that contains the disabled_layer_names metadata
(even if calibration_params and raw_threshold are absent); otherwise keep
returning None. Update the conditional around calibration_params and
raw_threshold to preserve/emit disabled_layer_names in the exported structure so
disabled-layer-only exports are retained.


# Build threshold_scale_factor with model parameters
threshold_scale_factor: dict[str, Any] = {
"formula": "a * exp(b * target_sparsity)",
# Build the config group for softmax_skip.
# All algorithm-specific parameters live inside the group so that future
# sparse attention methods can define their own parameter schemas in
# separate groups without collision.
group_0: dict[str, Any] = {
"sparse_algo": "softmax_skip",
"targets": sorted(target_classes) if target_classes else ["Attention"],
}
for phase in ["prefill", "decode"]:
if phase in calibration_params:
threshold_scale_factor[phase] = {
"a": calibration_params[phase]["a"],
"b": calibration_params[phase]["b"],

# Build threshold_scale_factor from calibration params.
# The formula depends on the fitting mode used during calibration:
# - Linear-space fit: scale_factor = a * exp(b * target_sparsity)
# - Log-space fit: log(scale_factor) = log_a + b * target_sparsity
if calibration_params is not None:
first_phase = next((p for p in ["prefill", "decode"] if p in calibration_params), None)
fit_logspace = first_phase is not None and calibration_params[first_phase].get(
"fit_logspace", False
)

if fit_logspace:
threshold_scale_factor: dict[str, Any] = {
"formula": "log_a + b * target_sparsity",
}
for phase in ["prefill", "decode"]:
if phase in calibration_params and "log_a" in calibration_params[phase]:
threshold_scale_factor[phase] = {
"log_a": calibration_params[phase]["log_a"],
"b": calibration_params[phase]["b"],
}
else:
threshold_scale_factor = {
"formula": "a * exp(b * target_sparsity)",
}
for phase in ["prefill", "decode"]:
if phase in calibration_params:
threshold_scale_factor[phase] = {
"a": calibration_params[phase]["a"],
"b": calibration_params[phase]["b"],
}

group_0["threshold_scale_factor"] = threshold_scale_factor

if raw_threshold is not None:
group_0["raw_threshold"] = raw_threshold

if disabled_layer_names:
group_0["disabled_layers"] = disabled_layer_names

# Build the export config
export_config: dict[str, Any] = {
"config_groups": {
"group_0": {
"sparse_algo": "softmax_skip",
"targets": sorted(target_classes) if target_classes else ["Attention"],
}
},
"threshold_scale_factor": threshold_scale_factor,
"config_groups": {"group_0": group_0},
"producer": {
"name": "modelopt",
"version": mo_version,
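Finally, a hedged sketch of how a consumer of the exported sparse_attention_config might recover the runtime threshold for either fitting mode, following the formulas and example outputs in the conversion.py docstring above; the helper name and the group selection are illustrative.

import math

def threshold_from_export(sparse_attention_config: dict, phase: str,
                          target_sparsity: float, seqlen: int) -> float:
    group = sparse_attention_config["config_groups"]["group_0"]
    tsf = group["threshold_scale_factor"]
    params = tsf[phase]
    if tsf["formula"] == "log_a + b * target_sparsity":
        # Log-space fit (diffusion): log(scale_factor) = log_a + b * S
        scale_factor = math.exp(params["log_a"] + params["b"] * target_sparsity)
    else:
        # Linear-space fit (LLMs): scale_factor = a * exp(b * S)
        scale_factor = params["a"] * math.exp(params["b"] * target_sparsity)
    # Runtime threshold as described in the docstring: threshold = scale_factor / seqlen
    return scale_factor / seqlen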