diff --git a/README copy.md b/README copy.md new file mode 100644 index 0000000..521a9fa --- /dev/null +++ b/README copy.md @@ -0,0 +1,126 @@ +
+ + +

LingBot-World: Advancing Open-source World Models

+ +Robbyant Team + +
+ + +
+ +[![Page](https://img.shields.io/badge/%F0%9F%8C%90%20Project%20Page-Demo-00bfff)](https://technology.robbyant.com/lingbot-world) +[![Tech Report](https://img.shields.io/badge/%F0%9F%93%84%20Tech%20Report-Document-teal)](LingBot_World_paper.pdf) +[![Paper](https://img.shields.io/static/v1?label=Paper&message=PDF&color=red&logo=arxiv)](https://github.com/robbyant/lingbot-world) +[![Model](https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Model&message=HuggingFace&color=yellow)](https://huggingface.co/robbyant/lingbot-world-base-cam) +[![Model](https://img.shields.io/static/v1?label=%F0%9F%A4%96%20Model&message=ModelScope&color=purple)](https://www.modelscope.cn/models/Robbyant/lingbot-world-base-cam) +[![License](https://img.shields.io/badge/License-Apache--2.0-green)](LICENSE.txt) + +
+ +----- + +We are excited to introduce **LingBot-World**, an open-sourced world simulator stemming from video generation. Positioned +as a top-tier world model, LingBot-World offers the following features. +- **High-Fidelity & Diverse Environments**: It maintains high fidelity and robust dynamics in a broad spectrum of environments, including realism, scientific contexts, cartoon styles, and beyond. +- **Long-Term Memory & Consistency**: It enables a minute-level horizon while preserving contextual consistency over time, which is also known as long-term memory. +- **Real-Time Interactivity & Open Access**: It supports real-time interactivity, achieving a latency of under 1 second when producing 16 frames per second. We provide public access to the code and model in an effort to narrow the divide between open-source and closed-source technologies. We believe our release will empower the community with practical applications across areas like content creation, gaming, and robot learning. + +## 🎬 Video Demo +
+ +
+ +## 🔥 News +- Jan 29, 2026: 🎉 We release the technical report, code, and models for LingBot-World. + + + +## ⚙️ Quick Start +This codebase is built upon [Wan2.2](https://github.com/Wan-Video/Wan2.2). Please refer to their documentation for installation instructions. +### Installation +Clone the repo: +```sh +git clone https://github.com/robbyant/lingbot-world.git +cd lingbot-world +``` +Install dependencies: +```sh +# Ensure torch >= 2.4.0 +pip install -r requirements.txt +``` +Install [`flash_attn`](https://github.com/Dao-AILab/flash-attention): +```sh +pip install flash-attn --no-build-isolation +``` +### Model Download + +| Model | Control Signals | Resolution | Download Links | +| :--- | :--- | :--- | :--- | +| **LingBot-World-Base (Cam)** | Camera Poses | 480P & 720P | 🤗 [HuggingFace](https://huggingface.co/robbyant/lingbot-world-base-cam) 🤖 [ModelScope](https://www.modelscope.cn/models/Robbyant/lingbot-world-base-cam) | +| **LingBot-World-Base (Act)** | Actions | - | *To be released* | +| **LingBot-World-Fast** | - | - | *To be released* | + + +Download models using modelscope-cli: + ```sh +pip install modelscope +modelscope download robbyant/lingbot-world-base-cam --local_dir ./lingbot-world-base-cam +``` +### Inference +Our model supports video generation at both 480P and 720P resolutions. You can find data samples for inference in the `examples/` directory, which includes the corresponding input images, prompts, and control signals. To enable long video generation, we utilize multi-GPU inference powered by FSDP and DeepSpeed Ulysses. +- 480P: + +This means the frame_num must be in the form of 4n + 1, where n is an integer (e.g., 1, 2, 3, etc.). For example, valid values include 5, 9, 13, 161, 321, etc. + +python generate.py --task i2v-A14B --size 480*832 --ckpt_dir lingbot-world-base-cam --image examples/00/image.jpg --action_path examples/00 --frame_num 31 --prompt "The video presents a soaring journey through a fantasy jungle. The wind whips past the rider's blue hands gripping the reins, causing the leather straps to vibrate. The ancient gothic castle approaches steadily, its stone details becoming clearer against the backdrop of floating islands and distant waterfalls." --save_file C:\workspace\world\lingbot-world\out + + python generate.py --task i2v-A14B --size 480*832 --ckpt_dir lingbot-world-base-cam --image examples/00/image.jpg --action_path examples/00 --frame_num 21 --prompt "The video presents a soaring journey through a fantasy jungle. The wind whips past the rider's blue hands gripping the reins, causing the leather straps to vibrate. The ancient gothic castle approaches steadily, its stone details becoming clearer against the backdrop of floating islands and distant waterfalls." + +``` sh +$env:USE_LIBUV=0 +torchrun --nproc_per_node=1 generate.py --task i2v-A14B --size 480*832 --ckpt_dir lingbot-world-base-cam --image examples/00/image.jpg --action_path examples/00 --dit_fsdp --t5_fsdp --ulysses_size 8 --frame_num 161 --prompt "The video presents a soaring journey through a fantasy jungle. The wind whips past the rider's blue hands gripping the reins, causing the leather straps to vibrate. The ancient gothic castle approaches steadily, its stone details becoming clearer against the backdrop of floating islands and distant waterfalls." +``` + python -m torch.distributed.run --nproc_per_node=1 generate.py --task i2v-A14B --size 480*832 --ckpt_dir lingbot-world-base-cam --image examples/00/image.jpg --action_path examples/00 --dit_fsdp --t5_fsdp --ulysses_size 8 --frame_num 161 --prompt "The video presents a soaring journey through a fantasy jungle. The wind whips past the rider's blue hands gripping the reins, causing the leather straps to vibrate. The ancient gothic castle approaches steadily, its stone details becoming clearer against the backdrop of floating islands and distant waterfalls. + +- 720P: +``` sh +torchrun --nproc_per_node=8 generate.py --task i2v-A14B --size 720*1280 --ckpt_dir lingbot-world-base-cam --image examples/00/image.jpg --action_path examples/00 --dit_fsdp --t5_fsdp --ulysses_size 8 --frame_num 161 --prompt "The video presents a soaring journey through a fantasy jungle. The wind whips past the rider's blue hands gripping the reins, causing the leather straps to vibrate. The ancient gothic castle approaches steadily, its stone details becoming clearer against the backdrop of floating islands and distant waterfalls." +``` +Alternatively, you can run inference without control actions: +``` sh +torchrun --nproc_per_node=8 generate.py --task i2v-A14B --size 480*832 --ckpt_dir lingbot-world-base-cam --image examples/00/image.jpg --dit_fsdp --t5_fsdp --ulysses_size 8 --frame_num 161 --prompt "The video presents a soaring journey through a fantasy jungle. The wind whips past the rider's blue hands gripping the reins, causing the leather straps to vibrate. The ancient gothic castle approaches steadily, its stone details becoming clearer against the backdrop of floating islands and distant waterfalls." +``` +Tips: +If you have sufficient CUDA memory, you may increase the `frame_num` parameter to a value such as 961 to generate a one-minute video at 16 FPS. + +## 📚 Related Projects +- [HoloCine](https://holo-cine.github.io/) +- [Ditto](https://editto.net/) +- [WorldCanvas](https://worldcanvas.github.io/) +- [RewardForcing](https://reward-forcing.github.io/) +- [CoDeF](https://qiuyu96.github.io/CoDeF/) + +## 📜 License +This project is licensed under the Apache 2.0 License. Please refer to the [LICENSE file](LICENSE.txt) for the full text, including details on rights and restrictions. + +## ✨ Acknowledgement +We would like to express our gratitude to the Wan Team for open-sourcing their code and models. Their contributions have been instrumental to the development of this project. + +## 📖 Citation +If you find this work useful for your research, please cite our paper: + +``` +@article{lingbot-world, + title={Advancing Open-source World Models}, + author={Robbyant Team}, + journal={arXiv preprint arXiv:xx.xx}, + year={2026} +} +``` diff --git a/assets/teaser.png b/assets/teaser.png deleted file mode 100644 index c90aeca..0000000 Binary files a/assets/teaser.png and /dev/null differ diff --git a/download.py b/download.py new file mode 100644 index 0000000..084389a --- /dev/null +++ b/download.py @@ -0,0 +1,41 @@ +import argparse +from huggingface_hub import snapshot_download + +if __name__ == "__main__": + # Available models + MODELS = { + #"base-cam": "robbyant/lingbot-world-base-cam", + "base-cam-nf4": "cahlen/lingbot-world-base-cam-nf4", + "base-act": "robbyant/lingbot-world-base-act" + } + + # Set up argument parser + parser = argparse.ArgumentParser(description="Download Lingbot World models from Hugging Face") + parser.add_argument( + "--model", + type=str, + nargs="+", + choices=list(MODELS.keys()), + default=["base-act", "base-cam-nf4"], + help=f"Model(s) to download. Available options: {', '.join(MODELS.keys())} (default: base-act base-cam-nf4)" + ) + parser.add_argument( + "--local-dir", + type=str, + default=None, + help="Local directory to save the model (default: ./model-name)" + ) + + args = parser.parse_args() + + for model in args.model: + repo_id = MODELS[model] + local_dir = args.local_dir if args.local_dir else f"./{model}" + + print(f"Downloading model: {model}") + print(f"Repository: {repo_id}") + print(f"Local directory: {local_dir}") + print() + + snapshot_download(repo_id=repo_id, repo_type="model", local_dir=local_dir) + print(f"Model '{model}' downloaded to {local_dir}") diff --git a/examples/racer/Screenshot.png b/examples/racer/Screenshot.png new file mode 100644 index 0000000..65beaad Binary files /dev/null and b/examples/racer/Screenshot.png differ diff --git a/generate_vbench.py b/generate_vbench.py new file mode 100644 index 0000000..b04807f --- /dev/null +++ b/generate_vbench.py @@ -0,0 +1,491 @@ +import argparse +import csv +import json +import logging +import os +import re +import sys +import time +import warnings +from datetime import datetime + +# Disable torch compile to avoid inductor import errors +os.environ['TORCHDYNAMO_DISABLE'] = '1' + +warnings.filterwarnings('ignore') + +import random + +import torch +import torch.distributed as dist +from PIL import Image + +import wan +from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS +from wan.distributed.util import init_distributed_group +from wan.utils.utils import merge_video_audio, save_video, str2bool + + +_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +_VBENCH_ROOT = os.path.join(_SCRIPT_DIR, "..", "VBench", "vbench2_beta_i2v") +_DEFAULT_INFO_JSON = os.path.join(_VBENCH_ROOT, "vbench2_i2v_full_info.json") +_DEFAULT_CROP_DIR = os.path.join(_VBENCH_ROOT, "vbench2_beta_i2v", "data", "crop") +def _safe(s): + return re.sub(r'[<>:"/\\|?*]', "_", s)[:150] + + +EXAMPLE_PROMPT = { + "i2v-A14B": { + "prompt": + "The video presents a cinematic, first-person wandering experience through a hyper-realistic urban environment rendered in a video game engine. It begins with a static, sun-drenched alley framed by graffiti-laden industrial walls and overhead power lines, immediately establishing a gritty, lived-in atmosphere. As the camera pans right and tilts upward, it reveals a sprawling cityscape dominated by towering skyscrapers and industrial infrastructure, all bathed in warm, late-afternoon light that casts long shadows and produces dramatic lens flares. The perspective then transitions into a smooth forward tracking shot along a cracked sidewalk, passing weathered fences, palm trees, and distant pedestrians, creating a sense of immersion and exploration. Midway, the camera briefly follows a walking figure before refocusing on the broader streetscape, culminating in a stabilized view of a small blue van parked at an intersection surrounded by urban elements like parking garages and traffic lights. The entire sequence is characterized by its photorealistic detail, dynamic lighting, and deliberate pacing, evoking the feel of a quiet, sunlit afternoon in a futuristic metropolis.", + "image": + "examples/02/image.jpg", + }, +} + + +def _validate_args(args): + # Basic check + assert args.ckpt_dir is not None, "Please specify the checkpoint directory." + assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}" + assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}" + + if args.prompt is None: + args.prompt = EXAMPLE_PROMPT[args.task]["prompt"] + if args.image is None and "image" in EXAMPLE_PROMPT[args.task]: + args.image = EXAMPLE_PROMPT[args.task]["image"] + + if args.task == "i2v-A14B": + assert args.image is not None, "Please specify the image path for i2v." + + cfg = WAN_CONFIGS[args.task] + + if args.sample_steps is None: + args.sample_steps = cfg.sample_steps + + if args.sample_shift is None: + args.sample_shift = cfg.sample_shift + + if args.sample_guide_scale is None: + args.sample_guide_scale = cfg.sample_guide_scale + + if args.frame_num is None: + args.frame_num = cfg.frame_num + + args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint( + 0, sys.maxsize) + # Size check + if not 's2v' in args.task: + assert args.size in SUPPORTED_SIZES[ + args. + task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}" + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Generate a image or video from a text prompt or image using Wan" + ) + parser.add_argument( + "--task", + type=str, + default="i2v-A14B", + choices=list(WAN_CONFIGS.keys()), + help="The task to run.") + parser.add_argument( + "--size", + type=str, + default="1280*720", + choices=list(SIZE_CONFIGS.keys()), + help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image." + ) + parser.add_argument( + "--frame_num", + type=int, + default=81, + help="How many frames of video are generated. The number should be 4n+1 (default: 81)" + ) + parser.add_argument( + "--ckpt_dir", + type=str, + default=None, + help="The path to the checkpoint directory.") + parser.add_argument( + "--offload_model", + type=str2bool, + default=None, + help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage.") + parser.add_argument( + "--ulysses_size", + type=int, + default=1, + help="The size of the ulysses parallelism in DiT.") + parser.add_argument( + "--t5_fsdp", + action="store_true", + default=False, + help="Whether to use FSDP for T5.") + parser.add_argument( + "--t5_cpu", + action="store_true", + default=False, + help="Whether to place T5 model on CPU.") + parser.add_argument( + "--dit_fsdp", + action="store_true", + default=False, + help="Whether to use FSDP for DiT.") + parser.add_argument( + "--save_file", + type=str, + default=None, + help="The file to save the generated video to.") + parser.add_argument( + "--prompt", + type=str, + default=None, + help="The prompt to generate the video from.") + parser.add_argument( + "--use_prompt_extend", + action="store_true", + default=False, + help="Whether to use prompt extend.") + parser.add_argument( + "--prompt_extend_method", + type=str, + default="local_qwen", + choices=["dashscope", "local_qwen"], + help="The prompt extend method to use.") + parser.add_argument( + "--prompt_extend_model", + type=str, + default=None, + help="The prompt extend model to use.") + parser.add_argument( + "--prompt_extend_target_lang", + type=str, + default="zh", + choices=["zh", "en"], + help="The target language of prompt extend.") + parser.add_argument( + "--base_seed", + type=int, + default=42, + help="The seed to use for generating the video.") + parser.add_argument( + "--image", + type=str, + default=None, + help="The image to generate the video from.") + parser.add_argument( + "--action_path", + type=str, + default=None, + help="The camera path to generate the video from.") + parser.add_argument( + "--sample_solver", + type=str, + default='unipc', + choices=['unipc', 'dpm++'], + help="The solver used to sample.") + parser.add_argument( + "--sample_steps", type=int, default=None, help="The sampling steps.") + parser.add_argument( + "--sample_shift", + type=float, + default=None, + help="Sampling shift factor for flow matching schedulers.") + parser.add_argument( + "--sample_guide_scale", + type=float, + default=None, + help="Classifier free guidance scale.") + parser.add_argument( + "--convert_model_dtype", + action="store_true", + default=False, + help="Whether to convert model paramerters dtype.") + # ---- VBench batch args ---- + parser.add_argument("--vbench", action="store_true", default=True, + help="Run VBench batch generation instead of single-video mode.") + parser.add_argument("--image_types", type=str, default="indoor,scenery", + help="Comma-separated image_type values to include (default: scenery,indoor).") + parser.add_argument("--vbench_output_dir", type=str, default="results_vbench/videos", + help="Output directory for vbench videos.") + parser.add_argument("--num_samples", type=int, default=5, + help="Number of samples per prompt.") + parser.add_argument("--vbench_info_json", type=str, default=None, + help="Path to vbench2_i2v_full_info.json.") + parser.add_argument("--crop_dir", type=str, default=None, + help="Path to VBench crop directory.") + parser.add_argument("--resolution", type=str, default="1-1", + help="Crop resolution subfolder.") + + args = parser.parse_args() + if args.vbench: + assert args.ckpt_dir is not None, "Please specify --ckpt_dir (path to Wan checkpoint directory)." + else: + _validate_args(args) + + return args + + +def _init_logging(rank): + # logging + if rank == 0: + # set format + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] %(levelname)s: %(message)s", + handlers=[logging.StreamHandler(stream=sys.stdout)]) + else: + logging.basicConfig(level=logging.ERROR) + + +def generate(args): + rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + local_rank = int(os.getenv("LOCAL_RANK", 0)) + device = local_rank + _init_logging(rank) + + logging.info("Starting the generation process...") + + if args.offload_model is None: + args.offload_model = False if world_size > 1 else True + logging.info( + f"offload_model is not specified, set to {args.offload_model}.") + if world_size > 1: + logging.info("Initializing distributed environment...") + torch.cuda.set_device(local_rank) + dist.init_process_group( + backend="nccl", + init_method="env://", + rank=rank, + world_size=world_size) + logging.info("Distributed environment initialized.") + else: + assert not ( + args.t5_fsdp or args.dit_fsdp + ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments." + assert not ( + args.ulysses_size > 1 + ), f"sequence parallel are not supported in non-distributed environments." + + if args.ulysses_size > 1: + assert args.ulysses_size == world_size, f"The number of ulysses_size should be equal to the world size." + init_distributed_group() + + logging.info("Loading model configuration...") + cfg = WAN_CONFIGS[args.task] + if args.ulysses_size > 1: + assert cfg.num_heads % args.ulysses_size == 0, f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`." + + logging.info(f"Generation job args: {args}") + logging.info(f"Generation model config: {cfg}") + + if dist.is_initialized(): + base_seed = [args.base_seed] if rank == 0 else [None] + dist.broadcast_object_list(base_seed, src=0) + args.base_seed = base_seed[0] + + logging.info(f"Input prompt: {args.prompt}") + img = None + if args.image is not None: + logging.info(f"Loading input image from {args.image}...") + img = Image.open(args.image).convert("RGB") + logging.info("Input image loaded.") + + # prompt extend + if args.use_prompt_extend: + logging.info("Extending prompt...") + if rank == 0: + input_prompt = args.prompt + input_prompt = [input_prompt] + else: + input_prompt = [None] + if dist.is_initialized(): + dist.broadcast_object_list(input_prompt, src=0) + args.prompt = input_prompt[0] + logging.info(f"Extended prompt: {args.prompt}") + + logging.info("Creating WanI2V pipeline...") + wan_i2v = wan.WanI2V( + config=cfg, + checkpoint_dir=args.ckpt_dir, + device_id=device, + rank=rank, + t5_fsdp=args.t5_fsdp, + dit_fsdp=args.dit_fsdp, + use_sp=(args.ulysses_size > 1), + t5_cpu=args.t5_cpu, + convert_model_dtype=args.convert_model_dtype, + ) + logging.info("WanI2V pipeline created.") + + logging.info("Generating video...") + video = wan_i2v.generate( + args.prompt, + img, + action_path=args.action_path, + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + logging.info("Video generation completed.") + + if rank == 0: + if args.save_file is None: + formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") + formatted_prompt = args.prompt.replace(" ", "_").replace("/", + "_")[:50] + suffix = '.mp4' + args.save_file = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{formatted_prompt}_{formatted_time}" + suffix + + logging.info(f"Saving generated video to {args.save_file}...") + save_video( + tensor=video[None], + save_file=args.save_file, + fps=cfg.sample_fps, + nrow=1, + normalize=True, + value_range=(-1, 1)) + if "s2v" in args.task: + if args.enable_tts is False: + merge_video_audio(video_path=args.save_file, audio_path=args.audio) + else: + merge_video_audio(video_path=args.save_file, audio_path="tts.wav") + logging.info("Video saved successfully.") + del video + + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + + logging.info("Generation process finished.") + + +def vbench_batch(args): + info_json = os.path.abspath(args.vbench_info_json or _DEFAULT_INFO_JSON) + crop_base = os.path.abspath(args.crop_dir or _DEFAULT_CROP_DIR) + image_dir = os.path.join(crop_base, args.resolution) + out_dir = os.path.abspath(args.vbench_output_dir) + os.makedirs(out_dir, exist_ok=True) + + stats_path = os.path.join(os.path.dirname(out_dir), 'vbench_stats.csv') + stats_f = open(stats_path, 'w', newline='', encoding='utf-8') + stats_w = csv.writer(stats_f) + stats_w.writerow(['task_idx', 'prompt', 'sample_idx', 'duration_s', 'gen_fps', 'out_path', 'status']) + + if not os.path.isfile(info_json): + print(f'[vbench] ERROR: info JSON not found: {info_json}'); return + if not os.path.isdir(image_dir): + print(f'[vbench] ERROR: crop dir not found: {image_dir}'); return + + with open(info_json, encoding='utf-8') as f: + entries = json.load(f) + + allowed = {t.strip() for t in args.image_types.split(',') if t.strip()} if args.image_types else None + populate = None + + seen, prompts = set(), [] + for e in entries: + name = e['image_name'] + if name in seen: continue + if allowed and e.get('image_type') not in allowed: continue + if populate is not None and (e.get('image_type') in _POPULATED_TYPES) != populate: continue + seen.add(name) + prompts.append((name, e['prompt_en'])) + + print(f'[vbench] {len(prompts)} prompts × {args.num_samples} samples = {len(prompts) * args.num_samples} total') + + cfg = WAN_CONFIGS[args.task] + wan_i2v = wan.WanI2V( + config=cfg, + checkpoint_dir=args.ckpt_dir, + device_id=0, + rank=0, + t5_fsdp=False, + dit_fsdp=False, + use_sp=False, + t5_cpu=args.t5_cpu, + convert_model_dtype=args.convert_model_dtype, + ) + + skipped = generated = errors = 0 + total = len(prompts) * args.num_samples + done = 0 + t_start = time.time() + print(f'[vbench] {len(prompts)} prompts × {args.num_samples} samples = {total} total') + + for task_idx, (image_name, prompt) in enumerate(prompts): + image_path = os.path.join(image_dir, image_name) + if not os.path.isfile(image_path): + print(f'[vbench] skip {task_idx}: image not found — {image_path}') + continue + + img = Image.open(image_path).convert("RGB") + + for sample_idx in range(args.num_samples): + out_path = os.path.join(out_dir, f'{_safe(prompt)}-{sample_idx}.mp4') + if os.path.exists(out_path): + skipped += 1 + done += 1 + stats_w.writerow([task_idx, prompt, sample_idx, '', '', out_path, 'skipped']) + stats_f.flush() + continue + + pct = 100 * done / total if total else 0 + eta = '' + if done > 0: + secs_left = (time.time() - t_start) / done * (total - done) + eta = f' ETA {int(secs_left//3600):02d}h{int(secs_left%3600//60):02d}m{int(secs_left%60):02d}s' + print(f'[vbench] [{done+1}/{total} {pct:.0f}%{eta}] prompt {task_idx+1}/{len(prompts)} sample {sample_idx+1}/{args.num_samples}: {prompt[:50]}') + seed = args.base_seed + sample_idx + try: + with torch.inference_mode(): + t0 = time.time() + video = wan_i2v.generate( + prompt, img, + max_area=MAX_AREA_CONFIGS[args.size], + frame_num=args.frame_num, + shift=args.sample_shift or cfg.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps or cfg.sample_steps, + guide_scale=args.sample_guide_scale or cfg.sample_guide_scale, + seed=seed, + offload_model=True, + ) + elapsed = time.time() - t0 + frame_num = args.frame_num + gen_fps = frame_num / elapsed if elapsed > 0 else 0.0 + from wan.utils.utils import save_video as _save_video + _save_video(tensor=video[None], save_file=out_path, fps=cfg.sample_fps, + nrow=1, normalize=True, value_range=(-1, 1)) + print(f'[vbench] saved {out_path} ({gen_fps:.1f} gen-fps)') + stats_w.writerow([task_idx, prompt, sample_idx, f'{elapsed:.2f}', f'{gen_fps:.2f}', out_path, 'ok']) + stats_f.flush() + generated += 1 + except Exception as exc: + print(f'[vbench] ERROR task {task_idx} sample {sample_idx}: {exc}') + stats_w.writerow([task_idx, prompt, sample_idx, '', '', out_path, 'error']) + stats_f.flush() + errors += 1 + done += 1 + + elapsed_total = time.time() - t_start + stats_f.close() + print(f'\n[vbench] done — generated={generated} skipped={skipped} errors={errors} elapsed={elapsed_total/60:.1f}m') + print(f'[vbench] stats → {stats_path}') + + +if __name__ == "__main__": + args = _parse_args() + if args.vbench: + vbench_batch(args) + else: + generate(args) diff --git a/out/test.mp4 b/out/test.mp4 new file mode 100644 index 0000000..f4d4d9d Binary files /dev/null and b/out/test.mp4 differ diff --git a/pyproject.toml b/pyproject.toml index 97e0df0..6c72f28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ dependencies = [ "torchvision>=0.19.0", "opencv-python>=4.9.0.80", "diffusers>=0.31.0", - "transformers>=4.49.0", + "transformers>=4.49.0,<5.0", "tokenizers>=0.20.3", "accelerate>=1.1.1", "tqdm", diff --git a/quantize_nf4.py b/quantize_nf4.py new file mode 100644 index 0000000..feaa1d4 --- /dev/null +++ b/quantize_nf4.py @@ -0,0 +1,51 @@ +""" +Quantize a WanModel checkpoint to NF4 (4-bit NormalFloat) using bitsandbytes. +Saves a reload-compatible checkpoint via diffusers save_pretrained. + +Usage: + python quantize_nf4.py \ + --src base-act/high_noise_model \ + --dst base-act/high_noise_model_nf4 +""" +import argparse +import sys +import torch +import torch.nn as nn + +sys.path.insert(0, ".") + +from diffusers import BitsAndBytesConfig +from wan.modules.model import WanModel + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("--src", default="base-act/high_noise_model") + p.add_argument("--dst", default="base-act/high_noise_model_nf4") + return p.parse_args() + + +def main(): + args = parse_args() + + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + ) + + print(f"Loading {args.src} ...") + model = WanModel.from_pretrained( + args.src, + quantization_config=nf4_config, + torch_dtype=torch.bfloat16, + ) + + print(f"Saving NF4 model to {args.dst} ...") + model.save_pretrained(args.dst) + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt index 0d7ff99..7fb007a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,11 @@ +# PyTorch with CUDA 12.4 support (closest to CUDA 12.8) +# Install with: pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 torch>=2.4.0 torchvision>=0.19.0 -torchaudio +torchaudio>=2.4.0 opencv-python>=4.9.0.80 diffusers>=0.31.0 -transformers>=4.49.0,<=4.51.3 +transformers>=4.49.0,<5.0 tokenizers>=0.20.3 accelerate>=1.1.1 tqdm diff --git a/requirements_win.txt b/requirements_win.txt new file mode 100644 index 0000000..e69de29 diff --git a/wan/modules/model.py b/wan/modules/model.py index b17a39f..071e6f1 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -321,6 +321,7 @@ class WanModel(ModelMixin, ConfigMixin): def __init__(self, model_type='t2v', control_type='cam', + cam_channels: int = None, patch_size=(1, 2, 2), text_len=512, in_dim=16, @@ -391,7 +392,9 @@ def __init__(self, self.cross_attn_norm = cross_attn_norm self.eps = eps - if control_type == 'cam': + if cam_channels is not None: + control_dim = cam_channels + elif control_type == 'cam': control_dim = 6 elif control_type == 'act': control_dim = 7