18 changes: 18 additions & 0 deletions examples/diffusers/quantization/models_utils.py
@@ -25,6 +25,11 @@
    StableDiffusion3Pipeline,
    WanPipeline,
)

try:
    from diffusers import Flux2Pipeline
except ImportError:
    Flux2Pipeline = None
from utils import (
    filter_func_default,
    filter_func_flux_dev,
@@ -42,6 +47,7 @@ class ModelType(str, Enum):
    SD35_MEDIUM = "sd3.5-medium"
    FLUX_DEV = "flux-dev"
    FLUX_SCHNELL = "flux-schnell"
    FLUX2_DEV = "flux2-dev"
    LTX_VIDEO_DEV = "ltx-video-dev"
    LTX2 = "ltx-2"
    WAN22_T2V_14b = "wan2.2-t2v-14b"
@@ -61,6 +67,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
    filter_func_map = {
        ModelType.FLUX_DEV: filter_func_flux_dev,
        ModelType.FLUX_SCHNELL: filter_func_default,
        ModelType.FLUX2_DEV: filter_func_flux_dev,
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

```shell
# Check if there's any Flux2-specific filter function defined
rg -n 'filter_func.*flux2' --type=py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 48

🏁 Script executed:

```shell
# Also search for all flux-related filter function definitions
rg -n 'def filter_func.*flux' --type=py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 154

🏁 Script executed:

```shell
# Check the context around line 66 in the file to understand the mapping
cat -n examples/diffusers/quantization/models_utils.py | sed -n '50,80p'
```

Repository: NVIDIA/Model-Optimizer
Length of output: 1309

🏁 Script executed:

```shell
# Read the filter_func_flux_dev implementation
sed -n '82,120p' examples/diffusers/quantization/utils.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 1456

🏁 Script executed:

```shell
# Search for Flux2 model references and layer patterns
rg -n 'flux2|FLUX2' --type=py -i -B 2 -A 2
```

Repository: NVIDIA/Model-Optimizer
Length of output: 6822

🏁 Script executed:

```shell
# Check if there are any tests or examples showing Flux2 layer structure
fd -type f -name '*.py' | xargs grep -l 'flux2\|FLUX2' | head -20
```

Repository: NVIDIA/Model-Optimizer
Length of output: 236

🏁 Script executed:

```shell
# Look for Flux2 transformer model structure
rg -n 'Flux2Transformer2DModel|transformer_flux2' --type=py -A 5 | head -50
```

Repository: NVIDIA/Model-Optimizer
Length of output: 1926

🏁 Script executed:

```shell
# Check if there are any config or model files for Flux2
fd -type f \( -name '*flux2*' -o -name '*Flux2*' \)
```

Repository: NVIDIA/Model-Optimizer
Length of output: 236

🏁 Script executed:

```shell
# Search for any documentation or tests showing Flux2 layer naming
rg -n 'time_text_embed|context_embedder|x_embedder' --type=py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 800


Create a Flux2-specific filter function.

filter_func_flux_dev uses layer names (time_text_embed, context_embedder, x_embedder, norm_out) that are specific to Flux1. Since Flux2Transformer2DModel is architecturally different—it does not use pooled_projections, has distinct attention modules (Flux2Attention, Flux2ParallelSelfAttention), and different input handling—it likely has different layer naming conventions. A separate filter function should be created for Flux2 to ensure the correct layers are selected for quantization.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/diffusers/quantization/models_utils.py` at line 66, the mapping uses
filter_func_flux_dev (Flux1-specific) for ModelType.FLUX2_DEV; create a new
Flux2-specific filter function (e.g., filter_func_flux2_dev) that targets the
actual Flux2 layer/attribute names and structures (handle absence of
pooled_projections, different embedder names, and the Flux2 attention modules
Flux2Attention and Flux2ParallelSelfAttention) and replace the mapping entry to
use it; implement the selection logic in filter_func_flux2_dev to inspect module
class/type and attribute names used by Flux2Transformer2DModel and return True
only for the quantizable weight/bias tensors for those modules, then register
filter_func_flux2_dev in the ModelType.FLUX2_DEV entry.
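A minimal sketch of such a filter, should it prove necessary (hedged: the function name and the excluded-layer patterns are assumptions, informed by the `time_guidance_embed`/`stream_modulation` names this PR adds to `filter_func_flux_dev` in `utils.py` below, not verified against `Flux2Transformer2DModel`):

```python
import re


def filter_func_flux2_dev(name: str) -> bool:
    """Hypothetical Flux2 filter: True means the layer is excluded from quantization.

    Layer-name patterns here are assumptions; verify them against the actual
    names reported by Flux2Transformer2DModel.named_modules() before relying on them.
    """
    pattern = re.compile(
        r"(proj_out.*|.*(time_guidance_embed|context_embedder|x_embedder|norm_out|stream_modulation).*)"
    )
    return pattern.match(name) is not None
```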

        ModelType.SDXL_BASE: filter_func_default,
        ModelType.SDXL_TURBO: filter_func_default,
        ModelType.SD3_MEDIUM: filter_func_default,
@@ -82,6 +89,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
    ModelType.SD35_MEDIUM: "stabilityai/stable-diffusion-3.5-medium",
    ModelType.FLUX_DEV: "black-forest-labs/FLUX.1-dev",
    ModelType.FLUX_SCHNELL: "black-forest-labs/FLUX.1-schnell",
    ModelType.FLUX2_DEV: "black-forest-labs/FLUX.2-dev",
    ModelType.LTX_VIDEO_DEV: "Lightricks/LTX-Video-0.9.7-dev",
    ModelType.LTX2: "Lightricks/LTX-2",
    ModelType.WAN22_T2V_14b: "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
@@ -95,6 +103,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
    ModelType.SD35_MEDIUM: StableDiffusion3Pipeline,
    ModelType.FLUX_DEV: FluxPipeline,
    ModelType.FLUX_SCHNELL: FluxPipeline,
    ModelType.FLUX2_DEV: Flux2Pipeline,
    ModelType.LTX_VIDEO_DEV: LTXConditionPipeline,
    ModelType.LTX2: None,
    ModelType.WAN22_T2V_14b: WanPipeline,
@@ -149,6 +158,15 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
    ModelType.SD35_MEDIUM: _SD3_BASE_CONFIG,
    ModelType.FLUX_DEV: _FLUX_BASE_CONFIG,
    ModelType.FLUX_SCHNELL: _FLUX_BASE_CONFIG,
    ModelType.FLUX2_DEV: {
        "backbone": "transformer",
        "dataset": _SD_PROMPTS_DATASET,
        "inference_extra_args": {
            "height": 768,
            "width": 1024,
            "guidance_scale": 4.0,
        },
    },
    ModelType.LTX_VIDEO_DEV: {
        "backbone": "transformer",
        "dataset": _OPENVID_DATASET,
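Taken together, the new registry entries wire up like this (a small usage sketch against this file; the printed values follow from the mappings above):

```python
from models_utils import ModelType, get_model_filter_func

mt = ModelType("flux2-dev")  # str-Enum: constructible directly from its string value
assert mt is ModelType.FLUX2_DEV

# The filter decides which layers stay unquantized; Flux2 currently reuses
# the broadened Flux-dev filter (see utils.py below).
skip = get_model_filter_func(mt)
print(skip("x_embedder"))  # True -> layer excluded from quantization
```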
18 changes: 12 additions & 6 deletions examples/diffusers/quantization/utils.py
@@ -55,11 +55,15 @@ def check_conv_and_mha(backbone, if_fp4, quantize_mha):
        elif isinstance(module, (Attention, AttentionModuleMixin)):
            head_size = int(module.inner_dim / module.heads)
            if not quantize_mha or head_size % 16 != 0:
                for attr in (
                    "q_bmm_quantizer",
                    "k_bmm_quantizer",
                    "v_bmm_quantizer",
                    "softmax_quantizer",
                    "bmm2_output_quantizer",
                ):
                    if hasattr(module, attr):
                        getattr(module, attr).disable()
                setattr(module, "_disable_fp8_mha", True)

                print(f"Disabled Attention layer quantization for layer {name}")
@@ -77,7 +81,9 @@ def filter_func_ltx_video(name: str) -> bool:


def filter_func_flux_dev(name: str) -> bool:
    """Filter function specifically for Flux-dev models."""
    pattern = re.compile(
        r"(proj_out.*|.*(time_text_embed|context_embedder|x_embedder|norm_out|time_guidance_embed|stream_modulation).*)"
    )
    return pattern.match(name) is not None
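One subtlety worth noting: `re.match` anchors at the start of the name, so `proj_out.*` only catches names that begin with `proj_out`, while the keyword alternation matches anywhere thanks to the surrounding `.*`. A quick illustration (the layer names below are made up for demonstration, not taken from a real model dump):

```python
from utils import filter_func_flux_dev

# Illustrative names only; real ones come from transformer.named_modules().
print(filter_func_flux_dev("proj_out"))                        # True: prefix match
print(filter_func_flux_dev("blocks.0.proj_out"))               # False: not at the start
print(filter_func_flux_dev("blocks.3.stream_modulation.lin"))  # True: keyword anywhere
```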


39 changes: 38 additions & 1 deletion modelopt/torch/export/diffusers_utils.py
@@ -101,7 +101,12 @@ def _is_model_type(module_path: str, class_name: str, fallback: bool) -> bool:
        except (ImportError, AttributeError):
            return fallback

    is_flux2 = _is_model_type(
        "diffusers.models.transformers",
        "Flux2Transformer2DModel",
        model_class_name == "Flux2Transformer2DModel",
    )
    is_flux = not is_flux2 and _is_model_type(
        "diffusers.models.transformers",
        "FluxTransformer2DModel",
        "flux" in model_class_name.lower(),
@@ -160,6 +165,37 @@ def _flux_inputs() -> dict[str, torch.Tensor]:
        dummy_inputs["guidance"] = torch.tensor([3.5], device=device, dtype=torch.float32)
        return dummy_inputs

    def _flux2_inputs() -> dict[str, torch.Tensor]:
        # Flux2Transformer2DModel: 3D hidden_states (batch, seq_len, in_channels)
        # Requires: hidden_states, encoder_hidden_states, timestep, img_ids, txt_ids
        # Unlike Flux1, Flux2 does NOT use pooled_projections.
        # RoPE uses 4 axes (32, 32, 32, 32), so img_ids/txt_ids have 4 columns.
        in_channels = getattr(cfg, "in_channels", 128)
        joint_attention_dim = getattr(cfg, "joint_attention_dim", 15360)
        axes_dims_rope = getattr(cfg, "axes_dims_rope", (32, 32, 32, 32))
        guidance_embeds = getattr(cfg, "guidance_embeds", True)

        # Use small dimensions for the dummy forward pass
        img_seq_len = 16  # 4x4 latent grid
        text_seq_len = 8
        rope_ndim = len(axes_dims_rope)

        dummy_inputs = {
            "hidden_states": torch.randn(
                batch_size, img_seq_len, in_channels, device=device, dtype=dtype
            ),
            "encoder_hidden_states": torch.randn(
                batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype
            ),
            "timestep": torch.tensor([0.5], device=device, dtype=dtype).expand(batch_size),
            "img_ids": torch.zeros(img_seq_len, rope_ndim, device=device, dtype=torch.float32),
            "txt_ids": torch.zeros(text_seq_len, rope_ndim, device=device, dtype=torch.float32),
            "return_dict": False,
        }
        if guidance_embeds:
            dummy_inputs["guidance"] = torch.tensor([4.0], device=device, dtype=torch.float32)
        return dummy_inputs
jingyu-ml marked this conversation as resolved.

    def _sd3_inputs() -> dict[str, torch.Tensor]:
        # SD3Transformer2DModel: 4D hidden_states (batch, channels, height, width)
        # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep
@@ -313,6 +349,7 @@ def _generic_transformer_inputs() -> dict[str, torch.Tensor] | None:
        return dummy_inputs

    model_input_builders = [
        ("flux2", is_flux2, _flux2_inputs),
        ("flux", is_flux, _flux_inputs),
        ("sd3", is_sd3, _sd3_inputs),
        ("dit", is_dit, _dit_inputs),
15 changes: 15 additions & 0 deletions modelopt/torch/quantization/plugins/diffusion/diffusers.py
@@ -33,6 +33,15 @@
    from diffusers.models.transformers.transformer_flux import FluxAttention
    from diffusers.models.transformers.transformer_ltx import LTXAttention
    from diffusers.models.transformers.transformer_wan import WanAttention

    try:
        from diffusers.models.transformers.transformer_flux2 import (
            Flux2Attention,
            Flux2ParallelSelfAttention,
        )
    except ImportError:
        Flux2Attention = None
        Flux2ParallelSelfAttention = None
else:
    AttentionModuleMixin = type("_dummy_type_no_instance", (), {})  # pylint: disable=invalid-name
from torch.autograd import Function
@@ -190,6 +199,12 @@ def forward(self, *args, **kwargs):
QuantModuleRegistry.register({FluxAttention: "FluxAttention"})(_QuantAttentionModuleMixin)
QuantModuleRegistry.register({WanAttention: "WanAttention"})(_QuantAttentionModuleMixin)
QuantModuleRegistry.register({LTXAttention: "LTXAttention"})(_QuantAttentionModuleMixin)
if Flux2Attention is not None:
    QuantModuleRegistry.register({Flux2Attention: "Flux2Attention"})(_QuantAttentionModuleMixin)
if Flux2ParallelSelfAttention is not None:
    QuantModuleRegistry.register({Flux2ParallelSelfAttention: "Flux2ParallelSelfAttention"})(
        _QuantAttentionModuleMixin
    )


original_scaled_dot_product_attention = F.scaled_dot_product_attention
26 changes: 26 additions & 0 deletions tests/_test_utils/torch/diffusers_models.py
@@ -27,6 +27,11 @@
    DiTTransformer2DModel = None
    FluxTransformer2DModel = None

try:
    from diffusers.models.transformers import Flux2Transformer2DModel
except Exception:  # pragma: no cover - optional diffusers models
    Flux2Transformer2DModel = None

import modelopt.torch.opt as mto
@@ -93,6 +98,27 @@ def get_tiny_flux(**config_kwargs):
    return FluxTransformer2DModel(**kwargs)


def get_tiny_flux2(**config_kwargs):
    """Create a tiny Flux2Transformer2DModel for testing."""
    if Flux2Transformer2DModel is None:
        pytest.skip("Flux2Transformer2DModel is not available in this diffusers version.")

    kwargs = {
        "patch_size": 1,
        "in_channels": 16,
        "num_layers": 1,
        "num_single_layers": 1,
        "attention_head_dim": 16,
        "num_attention_heads": 2,
        "joint_attention_dim": 32,
        "timestep_guidance_channels": 16,
        "mlp_ratio": 3.0,
        "axes_dims_rope": (4, 4, 4, 4),
    }
    kwargs.update(**config_kwargs)
    return Flux2Transformer2DModel(**kwargs)


def create_tiny_unet_dir(tmp_path: Path, **config_kwargs) -> Path:
    """Create and save a tiny UNet model to a directory."""
    tiny_unet = get_tiny_unet(**config_kwargs)
42 changes: 40 additions & 2 deletions tests/unit/torch/export/test_export_diffusers.py
@@ -16,7 +16,12 @@
import json

import pytest
from _test_utils.torch.diffusers_models import (
    get_tiny_dit,
    get_tiny_flux,
    get_tiny_flux2,
    get_tiny_unet,
)

pytest.importorskip("diffusers")

@@ -29,7 +34,9 @@ def _load_config(config_path):
    return json.load(file)


@pytest.mark.parametrize(
    "model_factory", [get_tiny_unet, get_tiny_dit, get_tiny_flux, get_tiny_flux2]
)
def test_export_diffusers_models_non_quantized(tmp_path, model_factory):
    model = model_factory()
    export_dir = tmp_path / f"export_{type(model).__name__}"
@@ -82,3 +89,34 @@ def _process_stub(*_args, **_kwargs):
    config_data = _load_config(config_path)
    assert "quantization_config" in config_data
    assert config_data["quantization_config"] == convert_hf_quant_config_format(dummy_quant_config)


def test_flux2_dummy_inputs_shape():
    """Verify Flux2-specific dummy input shapes: 4-col RoPE ids, no pooled_projections, guidance."""
    import torch

    from modelopt.torch.export.diffusers_utils import generate_diffusion_dummy_inputs

    model = get_tiny_flux2()
    cfg = model.config
    inputs = generate_diffusion_dummy_inputs(model, torch.device("cpu"), torch.float32)

    assert inputs is not None, "generate_diffusion_dummy_inputs returned None for Flux2"

    # hidden_states: (batch, seq_len, in_channels)
    assert inputs["hidden_states"].shape == (1, 16, cfg.in_channels)

    # encoder_hidden_states: (batch, text_seq_len, joint_attention_dim)
    assert inputs["encoder_hidden_states"].shape == (1, 8, cfg.joint_attention_dim)

    # RoPE ids must have 4 columns (not 3 like Flux1)
    rope_ndim = len(cfg.axes_dims_rope)
    assert rope_ndim == 4
    assert inputs["img_ids"].shape == (16, rope_ndim)
    assert inputs["txt_ids"].shape == (8, rope_ndim)

    # Flux2 must NOT have pooled_projections (unlike Flux1)
    assert "pooled_projections" not in inputs

    # guidance_embeds defaults to True for Flux2
    assert "guidance" in inputs