12 changes: 12 additions & 0 deletions diffsynth/diffusion/base_pipeline.py
@@ -112,6 +112,11 @@ def check_resize_height_width(self, height, width, num_frames=None):

def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1):
# Transform a PIL.Image to torch.Tensor
if isinstance(image, torch.Tensor):
image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
if image.dim() == 3 and "B" in pattern:
image = image.unsqueeze(0)
return image
image = torch.Tensor(np.array(image, dtype=np.float32))
image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
image = image * ((max_value - min_value) / 255) + min_value
@@ -121,6 +126,13 @@ def preprocess_image(self, torch_dtype=None, device=None, pattern="B C H

def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1):
# Transform a list of PIL.Image to torch.Tensor
if isinstance(video, torch.Tensor):
video = video.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device)
if video.dim() == 4:
video = video.unsqueeze(0)
elif video.dim() == 3:
video = video.unsqueeze(0).unsqueeze(2)
return video
video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video]
video = torch.stack(video, dim=pattern.index("T") // 2)
return video
13 changes: 13 additions & 0 deletions diffsynth/diffusion/parsers.py
@@ -32,6 +32,7 @@ def add_model_config(parser: argparse.ArgumentParser):

def add_training_config(parser: argparse.ArgumentParser):
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
@@ -60,11 +61,23 @@ def add_gradient_config(parser: argparse.ArgumentParser):
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
return parser

def add_eval_config(parser: argparse.ArgumentParser):
parser.add_argument("--val_dataset_base_path", type=str, default=None, help="Base path of the validation dataset.")
parser.add_argument("--val_dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the validation dataset.")
parser.add_argument("--val_dataset_repeat", type=int, default=1, help="Number of times to repeat the validation dataset per epoch.")
parser.add_argument("--val_dataset_num_workers", type=int, default=0, help="Number of workers for validation data loading.")
parser.add_argument("--val_data_file_keys", type=str, default=None, help="Data file keys for validation metadata. Comma-separated.")
parser.add_argument("--val_batch_size", type=int, default=None, help="Batch size for validation. Defaults to --batch_size when unset.")
parser.add_argument("--eval_every_n_epochs", type=int, default=1, help="Run evaluation every N epochs when validation data is provided.")
parser.add_argument("--eval_max_batches", type=int, default=None, help="Maximum validation batches per eval pass.")
return parser

def add_general_config(parser: argparse.ArgumentParser):
parser = add_dataset_base_config(parser)
parser = add_model_config(parser)
parser = add_training_config(parser)
parser = add_output_config(parser)
parser = add_lora_config(parser)
parser = add_gradient_config(parser)
parser = add_eval_config(parser)
return parser
181 changes: 180 additions & 1 deletion diffsynth/diffusion/runner.py
@@ -1,10 +1,127 @@
import os, torch
import numpy as np
from tqdm import tqdm
from accelerate import Accelerator
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger


def _pad_frames(frames, target_frames):
if target_frames is None:
return frames
if len(frames) >= target_frames:
return frames[:target_frames]
if len(frames) == 0:
raise ValueError("Cannot pad empty frame list.")
pad_frame = frames[-1]
return frames + [pad_frame] * (target_frames - len(frames))


def _frame_to_tensor(frame, min_value=-1.0, max_value=1.0):
if isinstance(frame, torch.Tensor):
tensor = frame
if tensor.dim() == 3 and tensor.shape[0] not in (1, 3):
tensor = tensor.permute(2, 0, 1)
return tensor
array = np.array(frame, dtype=np.float32)
tensor = torch.from_numpy(array).permute(2, 0, 1)
tensor = tensor * ((max_value - min_value) / 255.0) + min_value
return tensor


def _frames_to_tensor(frames, min_value=-1.0, max_value=1.0):
frame_tensors = [_frame_to_tensor(frame, min_value=min_value, max_value=max_value) for frame in frames]
return torch.stack(frame_tensors, dim=1)


def _collate_batch(batch, data_file_keys, num_frames):
if len(batch) == 1:
return batch[0]
single_frame_keys = {"reference_image", "vace_reference_image"}
output = {}
keys = batch[0].keys()
for key in keys:
values = [sample.get(key) for sample in batch]
if key in data_file_keys:
is_mask = "mask" in key
min_value = 0.0 if is_mask else -1.0
max_value = 1.0 if is_mask else 1.0
if any(value is None for value in values):
raise ValueError(f"Missing key '{key}' in one or more batch samples.")
if key in single_frame_keys:
frames = []
for value in values:
if isinstance(value, list):
Contributor (review comment, critical):

The check isinstance(value, list) is too restrictive. It will likely fail for custom sequence-like objects such as VideoData returned by the dataset, causing an error during batch collation. To make this more robust, you should check for the general Sequence type from collections.abc and explicitly exclude strings. You'll need to add from collections.abc import Sequence at the top of the file.

Suggested change:
- if isinstance(value, list):
+ if isinstance(value, Sequence) and not isinstance(value, str):
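For illustration, a minimal runnable sketch of the suggested check (not part of this PR); the VideoData class below is a hypothetical stand-in for a lazily loaded, sequence-like frame container, not necessarily the class used by this repository:

from collections.abc import Sequence

class VideoData(Sequence):
    # Hypothetical lazily loaded frame container returned by a dataset.
    def __init__(self, frames):
        self._frames = frames
    def __getitem__(self, index):
        return self._frames[index]
    def __len__(self):
        return len(self._frames)

def is_frame_sequence(value):
    # Strings are Sequences too, so they must be excluded explicitly.
    return isinstance(value, Sequence) and not isinstance(value, str)

video = VideoData(["frame_0", "frame_1"])
assert not isinstance(video, list)        # the original check would reject this object
assert is_frame_sequence(video)           # the suggested check accepts it
assert not is_frame_sequence("clip.mp4")  # file-path strings stay scalar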

if len(value) == 0:
raise ValueError(f"Key '{key}' has empty frame list.")
frames.append(value[0])
else:
frames.append(value)
tensors = [_frame_to_tensor(frame, min_value=min_value, max_value=max_value) for frame in frames]
output[key] = torch.stack(tensors, dim=0)
else:
tensors = []
for value in values:
if isinstance(value, list):
Contributor (review comment, critical):

Similar to the previous comment, this isinstance(value, list) check is too restrictive and will fail for custom sequence types like VideoData. Using collections.abc.Sequence will make the collate function more general and prevent errors with different dataset implementations.

Suggested change:
- if isinstance(value, list):
+ if isinstance(value, Sequence) and not isinstance(value, str):

padded = _pad_frames(value, num_frames)
tensors.append(_frames_to_tensor(padded, min_value=min_value, max_value=max_value))
elif isinstance(value, torch.Tensor):
tensors.append(value)
else:
raise ValueError(f"Unsupported value type for key '{key}': {type(value)}")
output[key] = torch.stack(tensors, dim=0)
else:
output[key] = values
return output


def run_validation(
accelerator: Accelerator,
dataset: torch.utils.data.Dataset,
model: DiffusionTrainingModule,
num_workers: int,
batch_size: int,
data_file_keys: list[str],
num_frames: int,
max_batches: int = None,
):
if dataset is None:
return None
if batch_size > 1:
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
collate_fn=lambda batch: _collate_batch(batch, data_file_keys, num_frames),
num_workers=num_workers,
)
else:
dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers)
dataloader = accelerator.prepare(dataloader)
was_training = model.training
model.eval()
losses = []
with torch.no_grad():
for step, data in enumerate(tqdm(dataloader, desc="Eval")):
if max_batches is not None and step >= max_batches:
break
if dataset.load_from_cache:
loss = model({}, inputs=data)
else:
loss = model(data)
loss = loss.detach().float()
loss = accelerator.gather(loss)
losses.append(loss.flatten())
if was_training:
model.train()
if not losses:
return None
mean_loss = torch.cat(losses).mean().item()
if accelerator.is_main_process:
print(f"Eval loss: {mean_loss:.6f}")
return mean_loss


def launch_training_task(
accelerator: Accelerator,
dataset: torch.utils.data.Dataset,
@@ -15,6 +132,7 @@ def launch_training_task(
num_workers: int = 1,
save_steps: int = None,
num_epochs: int = 1,
val_dataset: torch.utils.data.Dataset = None,
args = None,
):
if args is not None:
@@ -23,27 +141,88 @@
num_workers = args.dataset_num_workers
save_steps = args.save_steps
num_epochs = args.num_epochs
batch_size = args.batch_size
data_file_keys = args.data_file_keys.split(",")
num_frames = getattr(args, "num_frames", None)
val_num_workers = args.val_dataset_num_workers
val_batch_size = args.val_batch_size or batch_size
val_data_file_keys = (args.val_data_file_keys or args.data_file_keys).split(",")
eval_every_n_epochs = args.eval_every_n_epochs
eval_max_batches = args.eval_max_batches
else:
batch_size = 1
data_file_keys = []
num_frames = None
val_num_workers = 0
val_batch_size = 1
val_data_file_keys = []
eval_every_n_epochs = 0
eval_max_batches = None

optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
if batch_size > 1:
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
collate_fn=lambda batch: _collate_batch(batch, data_file_keys, num_frames),
num_workers=num_workers,
)
else:
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)

model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)

best_val_loss = None
for epoch_id in range(num_epochs):
epoch_loss_sum = None
epoch_steps = 0
for data in tqdm(dataloader):
with accelerator.accumulate(model):
optimizer.zero_grad()
if dataset.load_from_cache:
loss = model({}, inputs=data)
else:
loss = model(data)
loss_value = loss.detach().float()
if epoch_loss_sum is None:
epoch_loss_sum = loss_value
else:
epoch_loss_sum = epoch_loss_sum + loss_value
epoch_steps += 1
accelerator.backward(loss)
optimizer.step()
model_logger.on_step_end(accelerator, model, save_steps)
scheduler.step()
if epoch_loss_sum is None:
epoch_loss_sum = torch.tensor(0.0, device=accelerator.device)
steps_tensor = torch.tensor(float(epoch_steps), device=epoch_loss_sum.device)
loss_stats = torch.stack([epoch_loss_sum, steps_tensor]).unsqueeze(0)
gathered_stats = accelerator.gather(loss_stats)
if accelerator.is_main_process:
total_loss = gathered_stats[:, 0].sum().item()
total_steps = gathered_stats[:, 1].sum().item()
avg_loss = total_loss / total_steps if total_steps > 0 else float("nan")
print(f"Train loss (epoch {epoch_id}): {avg_loss:.6f}")
if save_steps is None:
model_logger.on_epoch_end(accelerator, model, epoch_id)
if val_dataset is not None and eval_every_n_epochs > 0 and (epoch_id + 1) % eval_every_n_epochs == 0:
val_loss = run_validation(
accelerator,
val_dataset,
model,
val_num_workers,
val_batch_size,
val_data_file_keys,
num_frames,
max_batches=eval_max_batches,
)
if val_loss is not None and (best_val_loss is None or val_loss < best_val_loss):
best_val_loss = val_loss
if accelerator.is_main_process:
print(f"New best eval loss: {best_val_loss:.6f}. Saving best checkpoint.")
model_logger.save_model(accelerator, model, "best.safetensors")
model_logger.on_training_end(accelerator, model, save_steps)


14 changes: 14 additions & 0 deletions diffsynth/models/wan_video_vae.py
@@ -1216,6 +1216,14 @@ def single_decode(self, hidden_state, device):


def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
if isinstance(videos, torch.Tensor):
if tiled:
if videos.shape[0] != 1:
raise ValueError("tiled encode does not support batch size > 1")
tile_size = (tile_size[0] * self.upsampling_factor, tile_size[1] * self.upsampling_factor)
tile_stride = (tile_stride[0] * self.upsampling_factor, tile_stride[1] * self.upsampling_factor)
return self.tiled_encode(videos, device, tile_size, tile_stride)
return self.single_encode(videos, device)
videos = [video.to("cpu") for video in videos]
hidden_states = []
for video in videos:
@@ -1233,6 +1241,12 @@ def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(1


def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
if isinstance(hidden_states, torch.Tensor):
if tiled:
if hidden_states.shape[0] != 1:
raise ValueError("tiled decode does not support batch size > 1")
return self.tiled_decode(hidden_states, device, tile_size, tile_stride)
return self.single_decode(hidden_states, device)
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
videos = []
for hidden_state in hidden_states:
24 changes: 13 additions & 11 deletions diffsynth/pipelines/wan_video.py
@@ -350,16 +350,18 @@ def process(self, pipe: WanVideoPipeline, height, width, num_frames):
class WanVideoUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image"),
input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image", "batch_size"),
output_params=("noise",)
)

def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image):
def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image, batch_size):
if batch_size is None:
batch_size = 1
length = (num_frames - 1) // 4 + 1
if vace_reference_image is not None:
f = len(vace_reference_image) if isinstance(vace_reference_image, list) else 1
length += f
shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor)
shape = (batch_size, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor)
noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device)
if vace_reference_image is not None:
noise = torch.concat((noise[:, :, -f:], noise[:, :, :-f]), dim=2)
@@ -650,7 +652,7 @@ def process(
reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_video_latents = torch.concat((inactive, reactive), dim=1)

vace_mask_latents = rearrange(vace_video_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
vace_mask_latents = rearrange(vace_video_mask[:, 0], "B T (H P) (W Q) -> B (P Q) T H W", P=8, Q=8)
vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')

if vace_reference_image is None:
Expand All @@ -661,15 +663,15 @@ def process(

vace_reference_image = pipe.preprocess_video(vace_reference_image)

if vace_reference_image.dim() == 4:
vace_reference_image = vace_reference_image.unsqueeze(2)
bs, c, f, h, w = vace_reference_image.shape
new_vace_ref_images = []
vace_reference_latents = []
for j in range(f):
new_vace_ref_images.append(vace_reference_image[0, :, j:j+1])
vace_reference_image = new_vace_ref_images

vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
vace_reference_latents = [u.unsqueeze(0) for u in vace_reference_latents]
frame = vace_reference_image[:, :, j:j+1]
frame_latents = pipe.vae.encode(frame, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
frame_latents = torch.concat((frame_latents, torch.zeros_like(frame_latents)), dim=1)
vace_reference_latents.append(frame_latents)

vace_video_latents = torch.concat((*vace_reference_latents, vace_video_latents), dim=2)
vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :f]), vace_mask_latents), dim=2)