From e3a2fa7968084befcfc15543fcd5b1ffcc5bb5a4 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Sat, 28 Feb 2026 00:17:35 +0000 Subject: [PATCH 1/4] Update Signed-off-by: Jingyu Xin --- examples/diffusers/quantization/models_utils.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/examples/diffusers/quantization/models_utils.py b/examples/diffusers/quantization/models_utils.py index d8ca11ed3f..65cfb61007 100644 --- a/examples/diffusers/quantization/models_utils.py +++ b/examples/diffusers/quantization/models_utils.py @@ -20,6 +20,7 @@ from diffusers import ( DiffusionPipeline, + Flux2Pipeline, FluxPipeline, LTXConditionPipeline, StableDiffusion3Pipeline, @@ -42,6 +43,7 @@ class ModelType(str, Enum): SD35_MEDIUM = "sd3.5-medium" FLUX_DEV = "flux-dev" FLUX_SCHNELL = "flux-schnell" + FLUX2_DEV = "flux2-dev" LTX_VIDEO_DEV = "ltx-video-dev" LTX2 = "ltx-2" WAN22_T2V_14b = "wan2.2-t2v-14b" @@ -61,6 +63,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: filter_func_map = { ModelType.FLUX_DEV: filter_func_flux_dev, ModelType.FLUX_SCHNELL: filter_func_default, + ModelType.FLUX2_DEV: filter_func_flux_dev, ModelType.SDXL_BASE: filter_func_default, ModelType.SDXL_TURBO: filter_func_default, ModelType.SD3_MEDIUM: filter_func_default, @@ -82,6 +85,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ModelType.SD35_MEDIUM: "stabilityai/stable-diffusion-3.5-medium", ModelType.FLUX_DEV: "black-forest-labs/FLUX.1-dev", ModelType.FLUX_SCHNELL: "black-forest-labs/FLUX.1-schnell", + ModelType.FLUX2_DEV: "black-forest-labs/FLUX.2-dev", ModelType.LTX_VIDEO_DEV: "Lightricks/LTX-Video-0.9.7-dev", ModelType.LTX2: "Lightricks/LTX-2", ModelType.WAN22_T2V_14b: "Wan-AI/Wan2.2-T2V-A14B-Diffusers", @@ -95,6 +99,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ModelType.SD35_MEDIUM: StableDiffusion3Pipeline, ModelType.FLUX_DEV: FluxPipeline, ModelType.FLUX_SCHNELL: FluxPipeline, + ModelType.FLUX2_DEV: Flux2Pipeline, ModelType.LTX_VIDEO_DEV: LTXConditionPipeline, ModelType.LTX2: None, ModelType.WAN22_T2V_14b: WanPipeline, @@ -149,6 +154,15 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]: ModelType.SD35_MEDIUM: _SD3_BASE_CONFIG, ModelType.FLUX_DEV: _FLUX_BASE_CONFIG, ModelType.FLUX_SCHNELL: _FLUX_BASE_CONFIG, + ModelType.FLUX2_DEV: { + "backbone": "transformer", + "dataset": _SD_PROMPTS_DATASET, + "inference_extra_args": { + "height": 768, + "width": 1024, + "guidance_scale": 4.0, + }, + }, ModelType.LTX_VIDEO_DEV: { "backbone": "transformer", "dataset": _SD_PROMPTS_DATASET, From 643a74150398a1d73e957831d809364da6a5beb0 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Mon, 2 Mar 2026 04:56:26 +0000 Subject: [PATCH 2/4] Add the flux2-dev export support Signed-off-by: Jingyu Xin --- examples/diffusers/quantization/utils.py | 14 ++++--- modelopt/torch/export/diffusers_utils.py | 39 ++++++++++++++++++- .../plugins/diffusion/diffusers.py | 15 +++++++ 3 files changed, 62 insertions(+), 6 deletions(-) diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py index 21fcd87d0b..7578e53244 100644 --- a/examples/diffusers/quantization/utils.py +++ b/examples/diffusers/quantization/utils.py @@ -55,11 +55,15 @@ def check_conv_and_mha(backbone, if_fp4, quantize_mha): elif isinstance(module, (Attention, AttentionModuleMixin)): head_size = int(module.inner_dim / module.heads) if not quantize_mha or head_size % 16 != 0: - module.q_bmm_quantizer.disable() - module.k_bmm_quantizer.disable() - module.v_bmm_quantizer.disable() - module.softmax_quantizer.disable() - module.bmm2_output_quantizer.disable() + for attr in ( + "q_bmm_quantizer", + "k_bmm_quantizer", + "v_bmm_quantizer", + "softmax_quantizer", + "bmm2_output_quantizer", + ): + if hasattr(module, attr): + getattr(module, attr).disable() setattr(module, "_disable_fp8_mha", True) print(f"Disabled Attention layer quantization for layer {name}") diff --git a/modelopt/torch/export/diffusers_utils.py b/modelopt/torch/export/diffusers_utils.py index a9bf138767..c6bbee75df 100644 --- a/modelopt/torch/export/diffusers_utils.py +++ b/modelopt/torch/export/diffusers_utils.py @@ -99,7 +99,12 @@ def _is_model_type(module_path: str, class_name: str, fallback: bool) -> bool: except (ImportError, AttributeError): return fallback - is_flux = _is_model_type( + is_flux2 = _is_model_type( + "diffusers.models.transformers", + "Flux2Transformer2DModel", + model_class_name == "Flux2Transformer2DModel", + ) + is_flux = not is_flux2 and _is_model_type( "diffusers.models.transformers", "FluxTransformer2DModel", "flux" in model_class_name.lower(), @@ -158,6 +163,37 @@ def _flux_inputs() -> dict[str, torch.Tensor]: dummy_inputs["guidance"] = torch.tensor([3.5], device=device, dtype=torch.float32) return dummy_inputs + def _flux2_inputs() -> dict[str, torch.Tensor]: + # Flux2Transformer2DModel: 3D hidden_states (batch, seq_len, in_channels) + # Requires: hidden_states, encoder_hidden_states, timestep, img_ids, txt_ids + # Unlike Flux1, Flux2 does NOT use pooled_projections. + # RoPE uses 4 axes (32,32,32,32) so img_ids/txt_ids have 4 columns. + in_channels = getattr(cfg, "in_channels", 128) + joint_attention_dim = getattr(cfg, "joint_attention_dim", 15360) + axes_dims_rope = getattr(cfg, "axes_dims_rope", (32, 32, 32, 32)) + guidance_embeds = getattr(cfg, "guidance_embeds", True) + + # Use small dimensions for dummy forward + img_seq_len = 16 # 4x4 latent grid + text_seq_len = 8 + rope_ndim = len(axes_dims_rope) + + dummy_inputs = { + "hidden_states": torch.randn( + batch_size, img_seq_len, in_channels, device=device, dtype=dtype + ), + "encoder_hidden_states": torch.randn( + batch_size, text_seq_len, joint_attention_dim, device=device, dtype=dtype + ), + "timestep": torch.tensor([0.5], device=device, dtype=dtype).expand(batch_size), + "img_ids": torch.zeros(img_seq_len, rope_ndim, device=device, dtype=torch.float32), + "txt_ids": torch.zeros(text_seq_len, rope_ndim, device=device, dtype=torch.float32), + "return_dict": False, + } + if guidance_embeds: + dummy_inputs["guidance"] = torch.tensor([4.0], device=device, dtype=torch.float32) + return dummy_inputs + def _sd3_inputs() -> dict[str, torch.Tensor]: # SD3Transformer2DModel: 4D hidden_states (batch, channels, height, width) # Requires: hidden_states, encoder_hidden_states, pooled_projections, timestep @@ -311,6 +347,7 @@ def _generic_transformer_inputs() -> dict[str, torch.Tensor] | None: return dummy_inputs model_input_builders = [ + ("flux2", is_flux2, _flux2_inputs), ("flux", is_flux, _flux_inputs), ("sd3", is_sd3, _sd3_inputs), ("dit", is_dit, _dit_inputs), diff --git a/modelopt/torch/quantization/plugins/diffusion/diffusers.py b/modelopt/torch/quantization/plugins/diffusion/diffusers.py index 2ec057766b..f9ae55b3e2 100644 --- a/modelopt/torch/quantization/plugins/diffusion/diffusers.py +++ b/modelopt/torch/quantization/plugins/diffusion/diffusers.py @@ -33,6 +33,15 @@ from diffusers.models.transformers.transformer_flux import FluxAttention from diffusers.models.transformers.transformer_ltx import LTXAttention from diffusers.models.transformers.transformer_wan import WanAttention + + try: + from diffusers.models.transformers.transformer_flux2 import ( + Flux2Attention, + Flux2ParallelSelfAttention, + ) + except ImportError: + Flux2Attention = None + Flux2ParallelSelfAttention = None else: AttentionModuleMixin = type("_dummy_type_no_instance", (), {}) # pylint: disable=invalid-name from torch.autograd import Function @@ -190,6 +199,12 @@ def forward(self, *args, **kwargs): QuantModuleRegistry.register({FluxAttention: "FluxAttention"})(_QuantAttentionModuleMixin) QuantModuleRegistry.register({WanAttention: "WanAttention"})(_QuantAttentionModuleMixin) QuantModuleRegistry.register({LTXAttention: "LTXAttention"})(_QuantAttentionModuleMixin) + if Flux2Attention is not None: + QuantModuleRegistry.register({Flux2Attention: "Flux2Attention"})(_QuantAttentionModuleMixin) + if Flux2ParallelSelfAttention is not None: + QuantModuleRegistry.register({Flux2ParallelSelfAttention: "Flux2ParallelSelfAttention"})( + _QuantAttentionModuleMixin + ) original_scaled_dot_product_attention = F.scaled_dot_product_attention From 185fd2857ab5b4b321e98a890280fff0260fc21e Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Fri, 6 Mar 2026 05:37:52 +0000 Subject: [PATCH 3/4] Update the ltx2 filter Signed-off-by: Jingyu Xin --- examples/diffusers/quantization/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py index b1e7861a5c..dd351b0757 100644 --- a/examples/diffusers/quantization/utils.py +++ b/examples/diffusers/quantization/utils.py @@ -74,7 +74,7 @@ def check_conv_and_mha(backbone, if_fp4, quantize_mha): def filter_func_ltx_video(name: str) -> bool: """Filter function specifically for LTX-Video models.""" pattern = re.compile( - r".*(proj_in|time_embed|caption_projection|proj_out|patchify_proj|adaln_single).*" + r".*(proj_in|time_embed|caption_projection|proj_out|patchify_proj|adaln_single|blocks\.(0|1|2|45|46|47)\.).*" ) return pattern.match(name) is not None From ce5e0041fda30dfde05bbd9be302defaf3664923 Mon Sep 17 00:00:00 2001 From: Jingyu Xin Date: Sat, 14 Mar 2026 05:12:09 +0000 Subject: [PATCH 4/4] Update Signed-off-by: Jingyu Xin --- .../diffusers/quantization/models_utils.py | 6 ++- examples/diffusers/quantization/utils.py | 4 +- tests/_test_utils/torch/diffusers_models.py | 26 ++++++++++++ .../torch/export/test_export_diffusers.py | 42 ++++++++++++++++++- 4 files changed, 74 insertions(+), 4 deletions(-) diff --git a/examples/diffusers/quantization/models_utils.py b/examples/diffusers/quantization/models_utils.py index 05df0632ee..1e90c54714 100644 --- a/examples/diffusers/quantization/models_utils.py +++ b/examples/diffusers/quantization/models_utils.py @@ -20,12 +20,16 @@ from diffusers import ( DiffusionPipeline, - Flux2Pipeline, FluxPipeline, LTXConditionPipeline, StableDiffusion3Pipeline, WanPipeline, ) + +try: + from diffusers import Flux2Pipeline +except ImportError: + Flux2Pipeline = None from utils import ( filter_func_default, filter_func_flux_dev, diff --git a/examples/diffusers/quantization/utils.py b/examples/diffusers/quantization/utils.py index 21de84e466..0c38fc2860 100644 --- a/examples/diffusers/quantization/utils.py +++ b/examples/diffusers/quantization/utils.py @@ -81,7 +81,9 @@ def filter_func_ltx_video(name: str) -> bool: def filter_func_flux_dev(name: str) -> bool: """Filter function specifically for Flux-dev models.""" - pattern = re.compile(r"(proj_out.*|.*(time_text_embed|context_embedder|x_embedder|norm_out).*)") + pattern = re.compile( + r"(proj_out.*|.*(time_text_embed|context_embedder|x_embedder|norm_out|time_guidance_embed|stream_modulation).*)" + ) return pattern.match(name) is not None diff --git a/tests/_test_utils/torch/diffusers_models.py b/tests/_test_utils/torch/diffusers_models.py index 7d91b8909b..c2f9d9b3d7 100644 --- a/tests/_test_utils/torch/diffusers_models.py +++ b/tests/_test_utils/torch/diffusers_models.py @@ -27,6 +27,11 @@ DiTTransformer2DModel = None FluxTransformer2DModel = None +try: + from diffusers.models.transformers import Flux2Transformer2DModel +except Exception: # pragma: no cover - optional diffusers models + Flux2Transformer2DModel = None + import modelopt.torch.opt as mto @@ -93,6 +98,27 @@ def get_tiny_flux(**config_kwargs): return FluxTransformer2DModel(**kwargs) +def get_tiny_flux2(**config_kwargs): + """Create a tiny Flux2Transformer2DModel for testing.""" + if Flux2Transformer2DModel is None: + pytest.skip("Flux2Transformer2DModel is not available in this diffusers version.") + + kwargs = { + "patch_size": 1, + "in_channels": 16, + "num_layers": 1, + "num_single_layers": 1, + "attention_head_dim": 16, + "num_attention_heads": 2, + "joint_attention_dim": 32, + "timestep_guidance_channels": 16, + "mlp_ratio": 3.0, + "axes_dims_rope": (4, 4, 4, 4), + } + kwargs.update(**config_kwargs) + return Flux2Transformer2DModel(**kwargs) + + def create_tiny_unet_dir(tmp_path: Path, **config_kwargs) -> Path: """Create and save a tiny UNet model to a directory.""" tiny_unet = get_tiny_unet(**config_kwargs) diff --git a/tests/unit/torch/export/test_export_diffusers.py b/tests/unit/torch/export/test_export_diffusers.py index 856a11de8c..450509d852 100644 --- a/tests/unit/torch/export/test_export_diffusers.py +++ b/tests/unit/torch/export/test_export_diffusers.py @@ -16,7 +16,12 @@ import json import pytest -from _test_utils.torch.diffusers_models import get_tiny_dit, get_tiny_flux, get_tiny_unet +from _test_utils.torch.diffusers_models import ( + get_tiny_dit, + get_tiny_flux, + get_tiny_flux2, + get_tiny_unet, +) pytest.importorskip("diffusers") @@ -29,7 +34,9 @@ def _load_config(config_path): return json.load(file) -@pytest.mark.parametrize("model_factory", [get_tiny_unet, get_tiny_dit, get_tiny_flux]) +@pytest.mark.parametrize( + "model_factory", [get_tiny_unet, get_tiny_dit, get_tiny_flux, get_tiny_flux2] +) def test_export_diffusers_models_non_quantized(tmp_path, model_factory): model = model_factory() export_dir = tmp_path / f"export_{type(model).__name__}" @@ -82,3 +89,34 @@ def _process_stub(*_args, **_kwargs): config_data = _load_config(config_path) assert "quantization_config" in config_data assert config_data["quantization_config"] == convert_hf_quant_config_format(dummy_quant_config) + + +def test_flux2_dummy_inputs_shape(): + """Verify Flux2-specific dummy input shapes: 4-col RoPE ids, no pooled_projections, guidance.""" + import torch + + from modelopt.torch.export.diffusers_utils import generate_diffusion_dummy_inputs + + model = get_tiny_flux2() + cfg = model.config + inputs = generate_diffusion_dummy_inputs(model, torch.device("cpu"), torch.float32) + + assert inputs is not None, "generate_diffusion_dummy_inputs returned None for Flux2" + + # hidden_states: (batch, seq_len, in_channels) + assert inputs["hidden_states"].shape == (1, 16, cfg.in_channels) + + # encoder_hidden_states: (batch, text_seq_len, joint_attention_dim) + assert inputs["encoder_hidden_states"].shape == (1, 8, cfg.joint_attention_dim) + + # RoPE ids must have 4 columns (not 3 like Flux1) + rope_ndim = len(cfg.axes_dims_rope) + assert rope_ndim == 4 + assert inputs["img_ids"].shape == (16, rope_ndim) + assert inputs["txt_ids"].shape == (8, rope_ndim) + + # Flux2 must NOT have pooled_projections (unlike Flux1) + assert "pooled_projections" not in inputs + + # guidance_embeds defaults to True for Flux2 + assert "guidance" in inputs