Skip Softmax diffusion export #1269
Open

jingyu-ml wants to merge 45 commits into main from jingyux/diffusion-skip-softmax-2
Commits (45, all by jingyu-ml):

1d8ac33  Add the Skip softmax diffusion
1f8f0d3  Add test case
5873652  Fixed error
4c179a3  Fixed the test case
2c323df  Merge branch 'main' into jingyux/diffusion-skip-softmax
8702b7b  Removed the token import
bbe2123  Merge branch 'main' into jingyux/diffusion-skip-softmax
70099a5  removed the unused code
6cc96a4  Update the README
4de0d3b  Updated the example script
b3d3d4d  Update the readme
8dc6162  Update the calibration kernel
8aa32cc  ADd the readme
fbeabcf  Update the example script
6a4ab8b  Update the code
d7dd15c  Update the calibration loop
b86d311  Remove the eager attention
f5a9af9  Merge branch 'main' into jingyux/diffusion-skip-softmax
45bcad6  Update the calibration, fixed some bugs
22c5b85  Add the test case
aa44a9d  Fixed the lint error
e5293de  Updated the README
40fdd44  Merge branch 'main' into jingyux/diffusion-skip-softmax
40d61dd  Update the test case
3845b47  Fixed the CICD
560015c  Merge branch 'main' into jingyux/diffusion-skip-softmax
f86580c  Added the ltx2 warning
ee162b3  addressed the ltx2 issue and the import issue
eef0577  Merge branch 'main' into jingyux/diffusion-skip-softmax
5219ab0  Merge branch 'main' into jingyux/diffusion-skip-softmax
5ba76ac  Address comments
d2d6d83  Update the readme
5ab4ebb  Merge branch 'main' into jingyux/diffusion-skip-softmax
8d71522  Update
79b5f2a  Merge branch 'main' into jingyux/diffusion-skip-softmax
8787d47  Added the test case
7bae5d7  Merge branch 'main' into jingyux/diffusion-skip-softmax
a21cac2  Remove the eager import
849cc1a  Merge branch 'jingyux/diffusion-skip-softmax' into jingyux/diffusion-…
a1b002e  Update the format
b3f4cab  Add the change to the changelog
4dbab66  Merge branch 'main' into jingyux/diffusion-skip-softmax
b77d098  Merge branch 'jingyux/diffusion-skip-softmax' into jingyux/diffusion-…
65f380d  Merge branch 'main' into jingyux/diffusion-skip-softmax-2
5eb9352  Merge branch 'main' into jingyux/diffusion-skip-softmax-2
Changes from all commits. Hunk header: @@ -349,73 +349,148 @@ def update_sparse_attention_metadata(

The updated `export_sparse_attention_config` (new side of the diff shown):

```python
def export_sparse_attention_config(model: nn.Module) -> dict[str, Any] | None:
    """Extract sparse attention config for export to config.json.

    Extracts calibration parameters, method metadata, and per-layer enable/disable
    state from sparse attention modules. Supports both LLM and diffusion models.

    Algorithm-specific parameters (``threshold_scale_factor``, ``raw_threshold``,
    ``disabled_layers``) are nested inside the config group that owns them.
    This allows future sparse attention methods to define their own parameter
    schemas in separate groups without collision.

    The formula in the export reflects the actual fitting mode used during
    calibration:

    - **Linear-space fit** (default, LLMs): ``scale_factor = a * exp(b * S)``;
      exports ``a`` and ``b``.
    - **Log-space fit** (diffusion): ``log(scale_factor) = log_a + b * S``;
      exports ``log_a`` and ``b``.

    At runtime: ``threshold = scale_factor / seqlen``.

    Args:
        model: Model with sparse attention applied

    Returns:
        Dictionary with sparse attention config for HuggingFace config.json export.
        Returns None if no sparse attention modules are found, or if no calibration
        parameters and no raw threshold are available.

    Example output (LLM, linear-space fit)::

        {
            "config_groups": {
                "group_0": {
                    "sparse_algo": "softmax_skip",
                    "targets": ["LlamaAttention"],
                    "threshold_scale_factor": {
                        "formula": "a * exp(b * target_sparsity)",
                        "prefill": {"a": 7.93, "b": 8.61},
                        "decode": {"a": 0.12, "b": 9.85},
                    },
                }
            },
            "producer": {"name": "modelopt", "version": "0.37.0"},
        }

    Example output (diffusion, log-space fit)::

        {
            "config_groups": {
                "group_0": {
                    "sparse_algo": "softmax_skip",
                    "targets": ["Attention"],
                    "threshold_scale_factor": {
                        "formula": "log_a + b * target_sparsity",
                        "prefill": {"log_a": 0.21, "b": 3.45},
                    },
                    "disabled_layers": ["blocks.0.attn1", "blocks.39.attn1"],
                }
            },
            "producer": {"name": "modelopt", "version": "0.37.0"},
        }
    """
    # Collect sparse attention module info
    calibration_params = None
    raw_threshold = None
    target_classes: set[str] = set()
    disabled_layer_names: list[str] = []

    for name, module in get_named_sparse_attention_modules(model):
        # Get the original wrapped module's class name
        if hasattr(module, "get_original_cls_by_level"):
            original_cls = module.get_original_cls_by_level(level=0)
            if original_cls is not None:
                target_classes.add(original_cls.__name__)

        if not module.is_enabled:
            disabled_layer_names.append(get_unwrapped_name(name, model))
            continue

        # Get calibration params from first enabled module that has them
        if calibration_params is None:
            calibration_params = getattr(module._sparse_method_instance, "calibration_params", None)

        # Get raw threshold from first enabled module that has one
        if raw_threshold is None:
            raw_threshold = getattr(
                module._sparse_method_instance, "skip_softmax_raw_threshold", None
            )

    # Nothing exportable if no calibration params and no raw threshold
    if calibration_params is None and raw_threshold is None:
        return None
```
Comment on lines +439 to 441 (Contributor):

Don't drop `disabled_layers` here: lines 439-441 return `None` whenever there are no calibration parameters and no raw threshold, even if `disabled_layer_names` is non-empty, so the per-layer disable state would be lost from the export.

Suggested fix:

```diff
-    if calibration_params is None and raw_threshold is None:
+    if calibration_params is None and raw_threshold is None and not disabled_layer_names:
         return None
```
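With the suggested guard, a model whose only exportable state is its list of disabled layers would still produce a config group instead of exporting nothing.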
The hunk continues below the comment:

```python
    # Build the config group for softmax_skip.
    # All algorithm-specific parameters live inside the group so that future
    # sparse attention methods can define their own parameter schemas in
    # separate groups without collision.
    group_0: dict[str, Any] = {
        "sparse_algo": "softmax_skip",
        "targets": sorted(target_classes) if target_classes else ["Attention"],
    }

    # Build threshold_scale_factor from calibration params.
    # The formula depends on the fitting mode used during calibration:
    #   - Linear-space fit: scale_factor = a * exp(b * target_sparsity)
    #   - Log-space fit:    log(scale_factor) = log_a + b * target_sparsity
    if calibration_params is not None:
        first_phase = next((p for p in ["prefill", "decode"] if p in calibration_params), None)
        fit_logspace = first_phase is not None and calibration_params[first_phase].get(
            "fit_logspace", False
        )

        if fit_logspace:
            threshold_scale_factor: dict[str, Any] = {
                "formula": "log_a + b * target_sparsity",
            }
            for phase in ["prefill", "decode"]:
                if phase in calibration_params and "log_a" in calibration_params[phase]:
                    threshold_scale_factor[phase] = {
                        "log_a": calibration_params[phase]["log_a"],
                        "b": calibration_params[phase]["b"],
                    }
        else:
            threshold_scale_factor = {
                "formula": "a * exp(b * target_sparsity)",
            }
            for phase in ["prefill", "decode"]:
                if phase in calibration_params:
                    threshold_scale_factor[phase] = {
                        "a": calibration_params[phase]["a"],
                        "b": calibration_params[phase]["b"],
                    }

        group_0["threshold_scale_factor"] = threshold_scale_factor

    if raw_threshold is not None:
        group_0["raw_threshold"] = raw_threshold

    if disabled_layer_names:
        group_0["disabled_layers"] = disabled_layer_names

    # Build the export config
    export_config: dict[str, Any] = {
        "config_groups": {"group_0": group_0},
        "producer": {
            "name": "modelopt",
            "version": mo_version,
```
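For reference, a consumer of the exported config could recompute the runtime threshold along these lines. This is a minimal sketch based only on the formulas in the docstring above; the helper name and call shape are hypothetical, not part of this PR or the modelopt API:

```python
import math


def runtime_threshold(group: dict, phase: str, target_sparsity: float, seqlen: int) -> float:
    # `group` is one entry from "config_groups" in the exported config.json.
    # (Hypothetical helper; dispatches on the exported "formula" string.)
    tsf = group["threshold_scale_factor"]
    params = tsf[phase]
    if tsf["formula"] == "log_a + b * target_sparsity":
        # Log-space fit (diffusion): log(scale_factor) = log_a + b * S
        scale_factor = math.exp(params["log_a"] + params["b"] * target_sparsity)
    else:
        # Linear-space fit (LLMs): scale_factor = a * exp(b * S)
        scale_factor = params["a"] * math.exp(params["b"] * target_sparsity)
    # Runtime rule from the docstring: threshold = scale_factor / seqlen
    return scale_factor / seqlen


# Using the diffusion sample values from the docstring's example output:
group_0 = {
    "sparse_algo": "softmax_skip",
    "threshold_scale_factor": {
        "formula": "log_a + b * target_sparsity",
        "prefill": {"log_a": 0.21, "b": 3.45},
    },
}
print(runtime_threshold(group_0, "prefill", target_sparsity=0.9, seqlen=4096))
```

Keeping the formula string in the export lets such a consumer dispatch on it rather than guessing which fitting mode produced the parameters.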
Review comment on the calibration code:

Guard `log_a` access when `fit_logspace` is enabled. Line 350 directly indexes `result["log_a"]`; if the calibration output is incomplete, this raises a `KeyError` and drops the whole calibration flow.

Suggested fix:

```diff
     if result.get("fit_logspace"):
-        params["log_a"] = result["log_a"]
-        params["fit_logspace"] = True
+        if "log_a" not in result:
+            warnings.warn(f"{phase} calibration marked fit_logspace=True but missing log_a")
+        else:
+            params["log_a"] = result["log_a"]
+            params["fit_logspace"] = True
```
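Note that the suggested fix assumes `warnings` is already imported in the calibration module; if not, the import would need to be added alongside it.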