From 27c354dbada46f986930aa47a6306c2390d76f85 Mon Sep 17 00:00:00 2001 From: ahmeddawy Date: Wed, 14 Jan 2026 15:19:43 +0200 Subject: [PATCH 1/5] enable batch size --- diffsynth/diffusion/base_pipeline.py | 12 +++ diffsynth/diffusion/parsers.py | 1 + diffsynth/diffusion/runner.py | 88 ++++++++++++++++++- diffsynth/models/wan_video_vae.py | 14 +++ diffsynth/pipelines/wan_video.py | 24 ++--- .../model_training/lora/Wan2.1-VACE-1.3B.sh | 3 +- examples/wanvideo/model_training/train.py | 60 +++++++++++-- 7 files changed, 182 insertions(+), 20 deletions(-) diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py index 4fe155963..a0fb2bebf 100644 --- a/diffsynth/diffusion/base_pipeline.py +++ b/diffsynth/diffusion/base_pipeline.py @@ -112,6 +112,11 @@ def check_resize_height_width(self, height, width, num_frames=None): def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H W", min_value=-1, max_value=1): # Transform a PIL.Image to torch.Tensor + if isinstance(image, torch.Tensor): + image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device) + if image.dim() == 3 and "B" in pattern: + image = image.unsqueeze(0) + return image image = torch.Tensor(np.array(image, dtype=np.float32)) image = image.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device) image = image * ((max_value - min_value) / 255) + min_value @@ -121,6 +126,13 @@ def preprocess_image(self, image, torch_dtype=None, device=None, pattern="B C H def preprocess_video(self, video, torch_dtype=None, device=None, pattern="B C T H W", min_value=-1, max_value=1): # Transform a list of PIL.Image to torch.Tensor + if isinstance(video, torch.Tensor): + video = video.to(dtype=torch_dtype or self.torch_dtype, device=device or self.device) + if video.dim() == 4: + video = video.unsqueeze(0) + elif video.dim() == 3: + video = video.unsqueeze(0).unsqueeze(2) + return video video = [self.preprocess_image(image, torch_dtype=torch_dtype, device=device, min_value=min_value, max_value=max_value) for image in video] video = torch.stack(video, dim=pattern.index("T") // 2) return video diff --git a/diffsynth/diffusion/parsers.py b/diffsynth/diffusion/parsers.py index b8c6c6afd..8e5b65880 100644 --- a/diffsynth/diffusion/parsers.py +++ b/diffsynth/diffusion/parsers.py @@ -32,6 +32,7 @@ def add_model_config(parser: argparse.ArgumentParser): def add_training_config(parser: argparse.ArgumentParser): parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.") + parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.") parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.") parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.") parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.") diff --git a/diffsynth/diffusion/runner.py b/diffsynth/diffusion/runner.py index 63cd85616..7497ac063 100644 --- a/diffsynth/diffusion/runner.py +++ b/diffsynth/diffusion/runner.py @@ -1,10 +1,80 @@ import os, torch +import numpy as np from tqdm import tqdm from accelerate import Accelerator from .training_module import DiffusionTrainingModule from .logger import ModelLogger +def _pad_frames(frames, target_frames): + if target_frames is None: + return frames + if len(frames) >= target_frames: + return frames[:target_frames] + if len(frames) == 0: + raise 
ValueError("Cannot pad empty frame list.") + pad_frame = frames[-1] + return frames + [pad_frame] * (target_frames - len(frames)) + + +def _frame_to_tensor(frame, min_value=-1.0, max_value=1.0): + if isinstance(frame, torch.Tensor): + tensor = frame + if tensor.dim() == 3 and tensor.shape[0] not in (1, 3): + tensor = tensor.permute(2, 0, 1) + return tensor + array = np.array(frame, dtype=np.float32) + tensor = torch.from_numpy(array).permute(2, 0, 1) + tensor = tensor * ((max_value - min_value) / 255.0) + min_value + return tensor + + +def _frames_to_tensor(frames, min_value=-1.0, max_value=1.0): + frame_tensors = [_frame_to_tensor(frame, min_value=min_value, max_value=max_value) for frame in frames] + return torch.stack(frame_tensors, dim=1) + + +def _collate_batch(batch, data_file_keys, num_frames): + if len(batch) == 1: + return batch[0] + single_frame_keys = {"reference_image", "vace_reference_image"} + output = {} + keys = batch[0].keys() + for key in keys: + values = [sample.get(key) for sample in batch] + if key in data_file_keys: + is_mask = "mask" in key + min_value = 0.0 if is_mask else -1.0 + max_value = 1.0 if is_mask else 1.0 + if any(value is None for value in values): + raise ValueError(f"Missing key '{key}' in one or more batch samples.") + if key in single_frame_keys: + frames = [] + for value in values: + if isinstance(value, list): + if len(value) == 0: + raise ValueError(f"Key '{key}' has empty frame list.") + frames.append(value[0]) + else: + frames.append(value) + tensors = [_frame_to_tensor(frame, min_value=min_value, max_value=max_value) for frame in frames] + output[key] = torch.stack(tensors, dim=0) + else: + tensors = [] + for value in values: + if isinstance(value, list): + padded = _pad_frames(value, num_frames) + tensors.append(_frames_to_tensor(padded, min_value=min_value, max_value=max_value)) + elif isinstance(value, torch.Tensor): + tensors.append(value) + else: + raise ValueError(f"Unsupported value type for key '{key}': {type(value)}") + output[key] = torch.stack(tensors, dim=0) + else: + output[key] = values + return output + + def launch_training_task( accelerator: Accelerator, dataset: torch.utils.data.Dataset, @@ -23,10 +93,26 @@ def launch_training_task( num_workers = args.dataset_num_workers save_steps = args.save_steps num_epochs = args.num_epochs + batch_size = args.batch_size + data_file_keys = args.data_file_keys.split(",") + num_frames = getattr(args, "num_frames", None) + else: + batch_size = 1 + data_file_keys = [] + num_frames = None optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay) scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) - dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) + if batch_size > 1: + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + collate_fn=lambda batch: _collate_batch(batch, data_file_keys, num_frames), + num_workers=num_workers, + ) + else: + dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py index d24e29d93..9ce06f29b 100644 --- a/diffsynth/models/wan_video_vae.py +++ b/diffsynth/models/wan_video_vae.py @@ -1216,6 +1216,14 @@ def single_decode(self, hidden_state, device): def encode(self, videos, 
device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + if isinstance(videos, torch.Tensor): + if tiled: + if videos.shape[0] != 1: + raise ValueError("tiled encode does not support batch size > 1") + tile_size = (tile_size[0] * self.upsampling_factor, tile_size[1] * self.upsampling_factor) + tile_stride = (tile_stride[0] * self.upsampling_factor, tile_stride[1] * self.upsampling_factor) + return self.tiled_encode(videos, device, tile_size, tile_stride) + return self.single_encode(videos, device) videos = [video.to("cpu") for video in videos] hidden_states = [] for video in videos: @@ -1233,6 +1241,12 @@ def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(1 def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + if isinstance(hidden_states, torch.Tensor): + if tiled: + if hidden_states.shape[0] != 1: + raise ValueError("tiled decode does not support batch size > 1") + return self.tiled_decode(hidden_states, device, tile_size, tile_stride) + return self.single_decode(hidden_states, device) hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states] videos = [] for hidden_state in hidden_states: diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index 5b4c0b41a..d5e921b05 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -350,16 +350,18 @@ def process(self, pipe: WanVideoPipeline, height, width, num_frames): class WanVideoUnit_NoiseInitializer(PipelineUnit): def __init__(self): super().__init__( - input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image"), + input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image", "batch_size"), output_params=("noise",) ) - def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image): + def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image, batch_size): + if batch_size is None: + batch_size = 1 length = (num_frames - 1) // 4 + 1 if vace_reference_image is not None: f = len(vace_reference_image) if isinstance(vace_reference_image, list) else 1 length += f - shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor) + shape = (batch_size, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor) noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device) if vace_reference_image is not None: noise = torch.concat((noise[:, :, -f:], noise[:, :, :-f]), dim=2) @@ -650,7 +652,7 @@ def process( reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) vace_video_latents = torch.concat((inactive, reactive), dim=1) - vace_mask_latents = rearrange(vace_video_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8) + vace_mask_latents = rearrange(vace_video_mask[:, 0], "B T (H P) (W Q) -> B (P Q) T H W", P=8, Q=8) vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact') if vace_reference_image is None: @@ -661,15 +663,15 @@ def process( vace_reference_image = pipe.preprocess_video(vace_reference_image) + if vace_reference_image.dim() == 4: + vace_reference_image = 
vace_reference_image.unsqueeze(2) bs, c, f, h, w = vace_reference_image.shape - new_vace_ref_images = [] + vace_reference_latents = [] for j in range(f): - new_vace_ref_images.append(vace_reference_image[0, :, j:j+1]) - vace_reference_image = new_vace_ref_images - - vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) - vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1) - vace_reference_latents = [u.unsqueeze(0) for u in vace_reference_latents] + frame = vace_reference_image[:, :, j:j+1] + frame_latents = pipe.vae.encode(frame, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + frame_latents = torch.concat((frame_latents, torch.zeros_like(frame_latents)), dim=1) + vace_reference_latents.append(frame_latents) vace_video_latents = torch.concat((*vace_reference_latents, vace_video_latents), dim=2) vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :f]), vace_mask_latents), dim=2) diff --git a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh index b56507889..5c17b7e3d 100644 --- a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh +++ b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh @@ -7,6 +7,7 @@ accelerate launch examples/wanvideo/model_training/train.py \ --dataset_repeat 100 \ --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-1.3B:Wan2.1_VAE.pth" \ --learning_rate 1e-4 \ + --batch_size 2 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.vace." 
\ --output_path "./models/train/Wan2.1-VACE-1.3B_lora" \ @@ -14,4 +15,4 @@ accelerate launch examples/wanvideo/model_training/train.py \ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ --lora_rank 32 \ --extra_inputs "vace_video,vace_reference_image" \ - --use_gradient_checkpointing_offload \ No newline at end of file + --use_gradient_checkpointing_offload diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index 497343822..7c746c0c4 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -63,13 +63,39 @@ def __init__( self.min_timestep_boundary = min_timestep_boundary def parse_extra_inputs(self, data, extra_inputs, inputs_shared): + def pick_first_frame(value): + if isinstance(value, torch.Tensor): + if value.dim() == 5: + return value[:, :, 0] + if value.dim() == 4: + if value.shape[0] in (1, 3) and value.shape[1] not in (1, 3): + return value[:, 0] + return value + return value + if isinstance(value, list): + return value[0] + return value + + def pick_last_frame(value): + if isinstance(value, torch.Tensor): + if value.dim() == 5: + return value[:, :, -1] + if value.dim() == 4: + if value.shape[0] in (1, 3) and value.shape[1] not in (1, 3): + return value[:, -1] + return value + return value + if isinstance(value, list): + return value[-1] + return value + for extra_input in extra_inputs: if extra_input == "input_image": - inputs_shared["input_image"] = data["video"][0] + inputs_shared["input_image"] = pick_first_frame(data["video"]) elif extra_input == "end_image": - inputs_shared["end_image"] = data["video"][-1] + inputs_shared["end_image"] = pick_last_frame(data["video"]) elif extra_input == "reference_image" or extra_input == "vace_reference_image": - inputs_shared[extra_input] = data[extra_input][0] + inputs_shared[extra_input] = pick_first_frame(data[extra_input]) else: inputs_shared[extra_input] = data[extra_input] return inputs_shared @@ -77,13 +103,33 @@ def parse_extra_inputs(self, data, extra_inputs, inputs_shared): def get_pipeline_inputs(self, data): inputs_posi = {"prompt": data["prompt"]} inputs_nega = {} + input_video = data["video"] + if isinstance(input_video, torch.Tensor): + if input_video.dim() == 5: + batch_size = input_video.shape[0] + num_frames = input_video.shape[2] + height = input_video.shape[3] + width = input_video.shape[4] + elif input_video.dim() == 4: + batch_size = 1 + num_frames = input_video.shape[1] + height = input_video.shape[2] + width = input_video.shape[3] + else: + raise ValueError(f"Unsupported input_video tensor shape: {input_video.shape}") + else: + batch_size = 1 + num_frames = len(input_video) + height = input_video[0].size[1] + width = input_video[0].size[0] inputs_shared = { # Assume you are using this pipeline for inference, # please fill in the input parameters. - "input_video": data["video"], - "height": data["video"][0].size[1], - "width": data["video"][0].size[0], - "num_frames": len(data["video"]), + "input_video": input_video, + "height": height, + "width": width, + "num_frames": num_frames, + "batch_size": batch_size, # Please do not modify the following parameters # unless you clearly know what this will cause. 
"cfg_scale": 1, From 6a5a0fed25b434ea46f2abf3d6334630099c41e7 Mon Sep 17 00:00:00 2001 From: ahmeddawy Date: Wed, 14 Jan 2026 15:26:49 +0200 Subject: [PATCH 2/5] update bash --- .../wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh index 5c17b7e3d..31f51ebc9 100644 --- a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh +++ b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh @@ -1,10 +1,10 @@ accelerate launch examples/wanvideo/model_training/train.py \ - --dataset_base_path data/example_video_dataset \ - --dataset_metadata_path data/example_video_dataset/metadata_vace.csv \ - --data_file_keys "video,vace_video,vace_reference_image" \ + --dataset_base_path /mnt/bucket/dawy/video_generation/two_stage_dataset \ + --dataset_metadata_path /mnt/bucket/dawy/video_generation/two_stage_dataset/metadata_vanilla_stage1.csv \ + --data_file_keys "video,vace_video,vace_reference_image,vace_video_mask" \ --height 480 \ --width 832 \ - --dataset_repeat 100 \ + --dataset_repeat 1 \ --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-1.3B:Wan2.1_VAE.pth" \ --learning_rate 1e-4 \ --batch_size 2 \ @@ -16,3 +16,6 @@ accelerate launch examples/wanvideo/model_training/train.py \ --lora_rank 32 \ --extra_inputs "vace_video,vace_reference_image" \ --use_gradient_checkpointing_offload + + + From 1cb14f7cd918e4902c70deda33a4068e7d7113e3 Mon Sep 17 00:00:00 2001 From: ahmeddawy Date: Thu, 15 Jan 2026 11:44:17 +0200 Subject: [PATCH 3/5] Adding eval videos and losses print --- diffsynth/diffusion/parsers.py | 12 +++ diffsynth/diffusion/runner.py | 87 +++++++++++++++++++ .../model_training/lora/Wan2.1-VACE-1.3B.sh | 12 ++- examples/wanvideo/model_training/train.py | 27 +++++- 4 files changed, 134 insertions(+), 4 deletions(-) diff --git a/diffsynth/diffusion/parsers.py b/diffsynth/diffusion/parsers.py index 8e5b65880..a1e174425 100644 --- a/diffsynth/diffusion/parsers.py +++ b/diffsynth/diffusion/parsers.py @@ -61,6 +61,17 @@ def add_gradient_config(parser: argparse.ArgumentParser): parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") return parser +def add_eval_config(parser: argparse.ArgumentParser): + parser.add_argument("--val_dataset_base_path", type=str, default=None, help="Base path of the validation dataset.") + parser.add_argument("--val_dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the validation dataset.") + parser.add_argument("--val_dataset_repeat", type=int, default=1, help="Number of times to repeat the validation dataset per epoch.") + parser.add_argument("--val_dataset_num_workers", type=int, default=0, help="Number of workers for validation data loading.") + parser.add_argument("--val_data_file_keys", type=str, default=None, help="Data file keys for validation metadata. Comma-separated.") + parser.add_argument("--val_batch_size", type=int, default=None, help="Batch size for validation. 
Defaults to --batch_size when unset.") + parser.add_argument("--eval_every_n_epochs", type=int, default=1, help="Run evaluation every N epochs when validation data is provided.") + parser.add_argument("--eval_max_batches", type=int, default=None, help="Maximum validation batches per eval pass.") + return parser + def add_general_config(parser: argparse.ArgumentParser): parser = add_dataset_base_config(parser) parser = add_model_config(parser) @@ -68,4 +79,5 @@ def add_general_config(parser: argparse.ArgumentParser): parser = add_output_config(parser) parser = add_lora_config(parser) parser = add_gradient_config(parser) + parser = add_eval_config(parser) return parser diff --git a/diffsynth/diffusion/runner.py b/diffsynth/diffusion/runner.py index 7497ac063..d6f5344e4 100644 --- a/diffsynth/diffusion/runner.py +++ b/diffsynth/diffusion/runner.py @@ -75,6 +75,53 @@ def _collate_batch(batch, data_file_keys, num_frames): return output +def run_validation( + accelerator: Accelerator, + dataset: torch.utils.data.Dataset, + model: DiffusionTrainingModule, + num_workers: int, + batch_size: int, + data_file_keys: list[str], + num_frames: int, + max_batches: int = None, +): + if dataset is None: + return None + if batch_size > 1: + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + collate_fn=lambda batch: _collate_batch(batch, data_file_keys, num_frames), + num_workers=num_workers, + ) + else: + dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0], num_workers=num_workers) + dataloader = accelerator.prepare(dataloader) + was_training = model.training + model.eval() + losses = [] + with torch.no_grad(): + for step, data in enumerate(tqdm(dataloader, desc="Eval")): + if max_batches is not None and step >= max_batches: + break + if dataset.load_from_cache: + loss = model({}, inputs=data) + else: + loss = model(data) + loss = loss.detach().float() + loss = accelerator.gather(loss) + losses.append(loss.flatten()) + if was_training: + model.train() + if not losses: + return None + mean_loss = torch.cat(losses).mean().item() + if accelerator.is_main_process: + print(f"Eval loss: {mean_loss:.6f}") + return mean_loss + + def launch_training_task( accelerator: Accelerator, dataset: torch.utils.data.Dataset, @@ -85,6 +132,7 @@ def launch_training_task( num_workers: int = 1, save_steps: int = None, num_epochs: int = 1, + val_dataset: torch.utils.data.Dataset = None, args = None, ): if args is not None: @@ -96,10 +144,20 @@ def launch_training_task( batch_size = args.batch_size data_file_keys = args.data_file_keys.split(",") num_frames = getattr(args, "num_frames", None) + val_num_workers = args.val_dataset_num_workers + val_batch_size = args.val_batch_size or batch_size + val_data_file_keys = (args.val_data_file_keys or args.data_file_keys).split(",") + eval_every_n_epochs = args.eval_every_n_epochs + eval_max_batches = args.eval_max_batches else: batch_size = 1 data_file_keys = [] num_frames = None + val_num_workers = 0 + val_batch_size = 1 + val_data_file_keys = [] + eval_every_n_epochs = 0 + eval_max_batches = None optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay) scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) @@ -117,6 +175,8 @@ def launch_training_task( model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) for epoch_id in range(num_epochs): + epoch_loss_sum = None + epoch_steps = 0 for data in tqdm(dataloader): 
with accelerator.accumulate(model): optimizer.zero_grad() @@ -124,12 +184,39 @@ def launch_training_task( loss = model({}, inputs=data) else: loss = model(data) + loss_value = loss.detach().float() + if epoch_loss_sum is None: + epoch_loss_sum = loss_value + else: + epoch_loss_sum = epoch_loss_sum + loss_value + epoch_steps += 1 accelerator.backward(loss) optimizer.step() model_logger.on_step_end(accelerator, model, save_steps) scheduler.step() + if epoch_loss_sum is None: + epoch_loss_sum = torch.tensor(0.0, device=accelerator.device) + steps_tensor = torch.tensor(float(epoch_steps), device=epoch_loss_sum.device) + loss_stats = torch.stack([epoch_loss_sum, steps_tensor]).unsqueeze(0) + gathered_stats = accelerator.gather(loss_stats) + if accelerator.is_main_process: + total_loss = gathered_stats[:, 0].sum().item() + total_steps = gathered_stats[:, 1].sum().item() + avg_loss = total_loss / total_steps if total_steps > 0 else float("nan") + print(f"Train loss (epoch {epoch_id}): {avg_loss:.6f}") if save_steps is None: model_logger.on_epoch_end(accelerator, model, epoch_id) + if val_dataset is not None and eval_every_n_epochs > 0 and (epoch_id + 1) % eval_every_n_epochs == 0: + run_validation( + accelerator, + val_dataset, + model, + val_num_workers, + val_batch_size, + val_data_file_keys, + num_frames, + max_batches=eval_max_batches, + ) model_logger.on_training_end(accelerator, model, save_steps) diff --git a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh index 31f51ebc9..4907885ab 100644 --- a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh +++ b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh @@ -8,14 +8,20 @@ accelerate launch examples/wanvideo/model_training/train.py \ --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-1.3B:Wan2.1_VAE.pth" \ --learning_rate 1e-4 \ --batch_size 2 \ + --gradient_accumulation_steps 4 \ --num_epochs 5 \ --remove_prefix_in_ckpt "pipe.vace." 
\ --output_path "./models/train/Wan2.1-VACE-1.3B_lora" \ --lora_base_model "vace" \ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ --lora_rank 32 \ - --extra_inputs "vace_video,vace_reference_image" \ - --use_gradient_checkpointing_offload - + --extra_inputs "video,vace_video,vace_reference_image,vace_video_mask" \ + --use_gradient_checkpointing_offload \ + --val_dataset_base_path /mnt/bucket/dawy/video_generation/two_stage_dataset \ + --val_dataset_metadata_path /mnt/bucket/dawy/video_generation/two_stage_dataset/metadata_vanilla_stage1_val.csv \ + --val_data_file_keys "video,vace_video,vace_reference_image,vace_video_mask" \ + --val_batch_size 2 \ + --eval_every_n_epochs 1 \ + --eval_max_batches 50 diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index 7c746c0c4..c43e6e3a5 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -173,6 +173,31 @@ def wan_parser(): gradient_accumulation_steps=args.gradient_accumulation_steps, kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)], ) + val_dataset = None + if args.val_dataset_metadata_path is not None: + val_base_path = args.val_dataset_base_path or args.dataset_base_path + val_data_file_keys = (args.val_data_file_keys or args.data_file_keys).split(",") + val_dataset = UnifiedDataset( + base_path=val_base_path, + metadata_path=args.val_dataset_metadata_path, + repeat=args.val_dataset_repeat, + data_file_keys=val_data_file_keys, + main_data_operator=UnifiedDataset.default_video_operator( + base_path=val_base_path, + max_pixels=args.max_pixels, + height=args.height, + width=args.width, + height_division_factor=16, + width_division_factor=16, + num_frames=args.num_frames, + time_division_factor=4, + time_division_remainder=1, + ), + special_operator_map={ + "animate_face_video": ToAbsolutePath(val_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)), + "input_audio": ToAbsolutePath(val_base_path) >> LoadAudio(sr=16000), + } + ) dataset = UnifiedDataset( base_path=args.dataset_base_path, metadata_path=args.dataset_metadata_path, @@ -228,4 +253,4 @@ def wan_parser(): "direct_distill": launch_training_task, "direct_distill:train": launch_training_task, } - launcher_map[args.task](accelerator, dataset, model, model_logger, args=args) + launcher_map[args.task](accelerator, dataset, model, model_logger, args=args, val_dataset=val_dataset) From edce95313849c94762ee777033a27654b435e6db Mon Sep 17 00:00:00 2001 From: ahmeddawy Date: Thu, 15 Jan 2026 10:08:30 +0000 Subject: [PATCH 4/5] Resolve stash conflict in Wan2.1-VACE-1.3B lora train script --- .../model_training/lora/Wan2.1-VACE-1.3B.sh | 10 ++-- .../validate_lora/test_Wan2.1-VACE-1.3B.py | 55 +++++++++++++++++++ 2 files changed, 60 insertions(+), 5 deletions(-) create mode 100644 examples/wanvideo/model_training/validate_lora/test_Wan2.1-VACE-1.3B.py diff --git a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh index 4907885ab..43c967b97 100644 --- a/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh +++ b/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh @@ -1,15 +1,15 @@ accelerate launch examples/wanvideo/model_training/train.py \ --dataset_base_path /mnt/bucket/dawy/video_generation/two_stage_dataset \ - --dataset_metadata_path 
/mnt/bucket/dawy/video_generation/two_stage_dataset/metadata_vanilla_stage1.csv \ + --dataset_metadata_path /mnt/bucket/dawy/video_generation/two_stage_dataset/metadata_vanilla_stage1_train.csv \ --data_file_keys "video,vace_video,vace_reference_image,vace_video_mask" \ --height 480 \ --width 832 \ - --dataset_repeat 1 \ + --dataset_repeat 2 \ --model_id_with_origin_paths "Wan-AI/Wan2.1-VACE-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-VACE-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-VACE-1.3B:Wan2.1_VAE.pth" \ --learning_rate 1e-4 \ - --batch_size 2 \ + --batch_size 5 \ --gradient_accumulation_steps 4 \ - --num_epochs 5 \ + --num_epochs 25 \ --remove_prefix_in_ckpt "pipe.vace." \ --output_path "./models/train/Wan2.1-VACE-1.3B_lora" \ --lora_base_model "vace" \ @@ -18,7 +18,7 @@ accelerate launch examples/wanvideo/model_training/train.py \ --extra_inputs "video,vace_video,vace_reference_image,vace_video_mask" \ --use_gradient_checkpointing_offload \ --val_dataset_base_path /mnt/bucket/dawy/video_generation/two_stage_dataset \ - --val_dataset_metadata_path /mnt/bucket/dawy/video_generation/two_stage_dataset/metadata_vanilla_stage1_val.csv \ + --val_dataset_metadata_path /mnt/bucket/dawy/video_generation/two_stage_dataset/metadata_vanilla_stage1_eval.csv \ --val_data_file_keys "video,vace_video,vace_reference_image,vace_video_mask" \ --val_batch_size 2 \ --eval_every_n_epochs 1 \ diff --git a/examples/wanvideo/model_training/validate_lora/test_Wan2.1-VACE-1.3B.py b/examples/wanvideo/model_training/validate_lora/test_Wan2.1-VACE-1.3B.py new file mode 100644 index 000000000..a5b50f4a7 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/test_Wan2.1-VACE-1.3B.py @@ -0,0 +1,55 @@ +import csv +import re +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from pathlib import Path +import os +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-VACE-1.3B", origin_file_pattern="Wan2.1_VAE.pth"), + ], +) +pipe.load_lora(pipe.vace, "models/train/Wan2.1-VACE-1.3B_lora/epoch-6.safetensors", alpha=1) + +out_dir = Path("/mnt/bucket/dawy/video_generation/vace_training_exps/vace_refactor/batching/epoch_6/") +out_dir.mkdir(parents=True, exist_ok=True) + +dataset_path = Path("/mnt/bucket/dawy/video_generation/two_stage_dataset/") +test_csv_path = Path("/mnt/bucket/dawy/video_generation/two_stage_dataset/Vace_Video_Generation_Dataset_Analysis_test.csv") + +with test_csv_path.open(newline="") as f: + rows = list(csv.DictReader(f)) + +for row in rows: + video_folder = row["video_id"] + camera_motion = row["camera_motion"] + motion_slug = re.sub(r"[^a-z0-9]+", "_", camera_motion.strip().lower()).strip("_") + video_path = dataset_path / video_folder / f"vace_video_{video_folder}.mp4" + mask_path = dataset_path / video_folder / f"vace_video_mask_{video_folder}.mp4" + if not video_path.exists() or not mask_path.exists(): + print(f"[SKIP] {video_folder}: missing vace_video or mask") + continue + + print(video_folder, camera_motion) + + video = VideoData(str(video_path), height=480, width=832) + if len(video) >= 81 : + video = [video[i] for i in range(81)] + else: + video 
= [video[i] for i in range(len(video))] + reference_image = VideoData(str(video_path), height=480, width=832)[0] + vace_mask = VideoData(str(mask_path), height=480, width=832) + vace_mask = [vace_mask[i] for i in range(len(video))] + video = pipe( + prompt=" ", + vace_video=video, vace_reference_image=reference_image, num_frames=len(video), vace_video_mask=vace_mask, + seed=1, tiled=True + ) + out_path = out_dir / f"video_{video_folder}_{motion_slug}.mp4" + save_video(video, str(out_path), fps=15, quality=5) From 2a790e9bf404617bba1e50e64b1a7812b68505c9 Mon Sep 17 00:00:00 2001 From: ahmeddawy Date: Thu, 15 Jan 2026 13:48:54 +0200 Subject: [PATCH 5/5] save best eval ckpt --- diffsynth/diffusion/runner.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/diffsynth/diffusion/runner.py b/diffsynth/diffusion/runner.py index d6f5344e4..f557a1e12 100644 --- a/diffsynth/diffusion/runner.py +++ b/diffsynth/diffusion/runner.py @@ -174,6 +174,7 @@ def launch_training_task( model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) + best_val_loss = None for epoch_id in range(num_epochs): epoch_loss_sum = None epoch_steps = 0 @@ -207,7 +208,7 @@ def launch_training_task( if save_steps is None: model_logger.on_epoch_end(accelerator, model, epoch_id) if val_dataset is not None and eval_every_n_epochs > 0 and (epoch_id + 1) % eval_every_n_epochs == 0: - run_validation( + val_loss = run_validation( accelerator, val_dataset, model, @@ -217,6 +218,11 @@ def launch_training_task( num_frames, max_batches=eval_max_batches, ) + if val_loss is not None and (best_val_loss is None or val_loss < best_val_loss): + best_val_loss = val_loss + if accelerator.is_main_process: + print(f"New best eval loss: {best_val_loss:.6f}. Saving best checkpoint.") + model_logger.save_model(accelerator, model, "best.safetensors") model_logger.on_training_end(accelerator, model, save_steps)
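
A minimal sketch of how the batch collation added in PATCH 1/5 is expected to behave. This is an illustration, not part of the patch series: it assumes the private helpers in diffsynth/diffusion/runner.py keep the names and signatures shown above (_collate_batch(batch, data_file_keys, num_frames), _pad_frames, _frames_to_tensor) and remain importable.

from PIL import Image
from diffsynth.diffusion.runner import _collate_batch

# Two samples with different frame counts; "video" is declared as a data-file key,
# while "prompt" is not and is therefore passed through as a plain Python list.
frames_a = [Image.new("RGB", (64, 48), color=(255, 0, 0)) for _ in range(3)]
frames_b = [Image.new("RGB", (64, 48), color=(0, 255, 0)) for _ in range(5)]
batch = [
    {"prompt": "sample a", "video": frames_a},
    {"prompt": "sample b", "video": frames_b},
]

out = _collate_batch(batch, data_file_keys=["video"], num_frames=4)
# frames_a is padded with its last frame up to 4 frames, frames_b is truncated to 4,
# and each clip is rescaled to [-1, 1] and stacked as (B, C, T, H, W).
print(out["video"].shape)  # torch.Size([2, 3, 4, 48, 64])
print(out["prompt"])       # ['sample a', 'sample b']

With --batch_size 1 the runner keeps the original collate_fn=lambda x: x[0] path, so single-sample training data still reaches the model in its previous list/PIL form.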