Conversation

Contributor

@leisuzz leisuzz commented Dec 18, 2025

What does this PR do?

The text encoder in Flux2 is very large, and offloading it to the CPU makes computing the prompt embeddings very slow.

  1. I added a --fsdp_text_encoder option that shards the text encoder with FSDP, so prompt embeddings can be computed efficiently across multiple GPUs (see the sketch below).
  2. Checkpoint saving did not support FSDP, so I added a code path for when Accelerate is launched with an FSDP config.
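
For context, here is a minimal sketch of what FSDP2-style sharding of a large text encoder can look like, using PyTorch's public fully_shard API (PyTorch 2.6+). The attribute names are illustrative placeholders, not the exact code added in this PR:

# Illustrative sketch only -- not the exact implementation in this PR.
# Assumes PyTorch >= 2.6 and that accelerate has already initialized the
# distributed process group.
import torch
from torch.distributed.fsdp import fully_shard

def shard_text_encoder(text_encoder: torch.nn.Module, blocks) -> torch.nn.Module:
    # `blocks` is the iterable of the encoder's transformer layers; the exact
    # attribute name depends on the text encoder class (placeholder here).
    for block in blocks:
        fully_shard(block)  # shard each layer so params are gathered one layer at a time
    fully_shard(text_encoder)  # shard the root to cover embeddings and remaining params
    return text_encoder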


Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@leisuzz leisuzz force-pushed the fsdp branch 3 times, most recently from 559a7a3 to 343b12a on December 18, 2025 at 12:31
Contributor Author

leisuzz commented Dec 18, 2025

@sayakpaul Please take a look at this PR. Thank you for your help :)

Member

@sayakpaul sayakpaul left a comment


Very cool work, thank you for this!

Just confirming -- this is FSDP2, right?

Also, could you provide an example command and your setup so that we can test?

Additionally, can we handle the denoiser the same way?

Contributor Author

leisuzz commented Dec 19, 2025

Very cool work, thank you for this!

Just confirming -- this is FSDP2, right?

Also, could you provide an example command and your setup so that we can test?

Additionally, can we handle the denoiser the same way?

It is FSDP2, and the script is:

accelerate launch --config_file ${config_file} \
  ./train_dreambooth_lora_flux2_img2img.py \
  --pretrained_model_name_or_path=$model_name  \
  --dataset_name=$dataset_name \
  --image_column="output" --cond_image_column="file_name" --caption_column="instruction" \
  --resolution=$resolution \
  --train_batch_size=$batch_size \
  --guidance_scale=1 \
  --mixed_precision=$mixed_precision \
  --max_grad_norm=1 \
  --dataloader_num_workers=0 \
  --gradient_accumulation_steps=$gradient_accumulation_steps \
  --learning_rate=1e-05 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --gradient_checkpointing \
  --max_train_steps=$max_train_steps \
  --checkpointing_steps=5000 \
  --enable_npu_flash_attention \
  --rank=16 \
  --seed="0" \
  --skip_final_inference \
  --cache_latents \
  --offload \
  --fsdp_text_encoder \
  --output_dir=${output_path}

The accelerate config is:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_version: 2
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Flux2TransformerBlock,Flux2SingleTransformerBlock
  fsdp_forward_prefetch: true
  fsdp_sync_module_states: false
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_use_orig_params: false
  fsdp_activation_checkpointing: true
  fsdp_reshard_after_forward: true
  fsdp_cpu_ram_efficient_loading: false
main_training_function: main
machine_rank: 0
main_process_ip: localhost
main_process_port: 6878
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
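
With this config (fsdp_version: 2, FULL_SHARD across 8 processes), the training script also needs to know it is running under FSDP so it can take the FSDP-aware checkpointing path mentioned in the PR description. Below is a minimal sketch of such a check using Accelerate's public API; it is an assumption about one possible approach, not necessarily the exact logic in this PR:

# Sketch: detect an FSDP launch via accelerate and branch the save path.
from accelerate import Accelerator
from accelerate.utils import DistributedType

accelerator = Accelerator()
is_fsdp = accelerator.distributed_type == DistributedType.FSDP

if is_fsdp:
    # Gather the full (unsharded) state dict before extracting LoRA weights.
    ...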

Member

@sayakpaul sayakpaul left a comment


Changes look neat to me!

Let's also update the README about this.

Comment on lines 13 to 16
import torch.distributed as dist
from torch.distributed.fsdp import CPUOffload, ShardingStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
Member


We should guard this like so:

if getattr(torch, "distributed", None) is not None:
    import torch.distributed as dist

Same for FSDP.
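
For completeness, one possible shape of the same guard extended to the FSDP imports from the snippet above (a sketch, not necessarily the exact change that ended up in the PR):

# Sketch: guard torch.distributed and the FSDP imports so the script still
# imports on torch builds without distributed support.
import torch

if getattr(torch, "distributed", None) is not None and torch.distributed.is_available():
    import torch.distributed as dist
    from torch.distributed.fsdp import CPUOffload, ShardingStrategy
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
else:
    dist = None  # downstream code should check `dist is not None and dist.is_initialized()`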

Contributor Author


I've modified it; please take a look.

Member


Doesn't seem like the commits were pushed?

Contributor Author


Please check it out

Comment on lines 469 to 470
if dist.is_initialized():
    dist.barrier()
Member


Why is this needed?

Contributor Author


I've modified it; please take a look.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul
Member

@bot /style

Contributor

github-actions bot commented Dec 22, 2025

Style bot fixed some files and pushed the changes.

Member

@sayakpaul sayakpaul left a comment


Just a few more comments, mostly minor. As mentioned earlier, let's make a note of this in README_flux2.md.


import numpy as np
import torch
import torch.distributed as dist
Member


This should be guarded as well.

if accelerator.is_main_process:
    transformer_lora_layers_to_save = None
    modules_to_save = {}
    transformer_lora_layers_to_save = None
Member


Let's simplify this block of code a bit:

transformer_cls = type(unwrap_model(transformer))

def _to_cpu_contiguous(sd):
    return {
        k: (v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v)
        for k, v in sd.items()
    }

# 1) Validate and pick the transformer model
modules_to_save: dict[str, Any] = {}
transformer_model = None

for m in models:
    if isinstance(unwrap_model(m), transformer_cls):
        transformer_model = m
        modules_to_save["transformer"] = m
    else:
        raise ValueError(f"unexpected save model: {m.__class__}")

if transformer_model is None:
    raise ValueError("No transformer model found in `models`.")

# 2) Optionally gather FSDP state dict once
state_dict = accelerator.get_state_dict(transformer_model) if is_fsdp else None

# 3) Only main process materializes the LoRA state dict
transformer_lora_layers_to_save = None
if accelerator.is_main_process:
    peft_kwargs = {}
    if is_fsdp:
        peft_kwargs["state_dict"] = state_dict

    transformer_lora_layers_to_save = get_peft_model_state_dict(
        unwrap_model(transformer_model) if is_fsdp else transformer_model,
        **peft_kwargs,
    )

    if is_fsdp:
        transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)

        # make sure to pop weight so that corresponding model is not saved again
        if weights:
            weights.pop()

We can move _to_cpu_contiguous() to the training_utils.py module.
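
For reference, a sketch of how the helper might look once hoisted into training_utils.py (hypothetical placement following the suggestion above; this is not an existing diffusers API):

# Hypothetical: _to_cpu_contiguous() moved out of the save hook, e.g. into
# diffusers' training_utils.py as suggested above.
import torch

def _to_cpu_contiguous(sd: dict) -> dict:
    # Detach tensors, move them to CPU, and make them contiguous so a gathered
    # FSDP state dict can be serialized safely from the main process.
    return {
        k: (v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v)
        for k, v in sd.items()
    }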

if accelerator.is_main_process:
    transformer_lora_layers_to_save = None
    modules_to_save = {}
    transformer_lora_layers_to_save = None
Member


Same as above.
