#!/usr/bin/env bash
# Launch sharded Wan2.2 I2V generation on a local VBVR-Bench checkout with
# vllm-omni, then run one final evaluation pass over the merged outputs.
#
# Every setting below can be overridden from the environment. REPO_ROOT and
# PYTHON_BIN default to the values that used to be hard-coded, so existing
# invocations behave exactly as before.
set -euo pipefail

REPO_ROOT=${REPO_ROOT:-/mnt/umm/users/pufanyi/workspace/lmms-eval-vllm}
PYTHON_BIN=${PYTHON_BIN:-.venv/bin/python}
RUNNER=tools/run_vllm_omni_vbvr_local_parallel.py

cd "$REPO_ROOT"

MODEL_DIR=${MODEL_DIR:-/mnt/umm/users/pufanyi/workspace/Wan-Trainer/storage/models/Wan2.2-I2V-A14B-Diffusers}
VBVR_ROOT=${VBVR_ROOT:-/mnt/umm/users/pufanyi/workspace/Wan-Trainer/storage/datasets/VBVR-Bench}
OUTPUT_ROOT=${OUTPUT_ROOT:-/mnt/umm/users/pufanyi/workspace/Wan-Trainer/storage/eval_out/vbvr_wan22_vllm_omni_local_dp8}

SPLIT=${SPLIT:-all}
LIMIT=${LIMIT:-}

TP=${TP:-1}
DP=${DP:-}
GPU_MEM_UTIL=${GPU_MEM_UTIL:-0.9}

NUM_INFERENCE_STEPS=${NUM_INFERENCE_STEPS:-50}
GUIDANCE_SCALE=${GUIDANCE_SCALE:-5.0}
NUM_FRAMES=${NUM_FRAMES:-81}
HEIGHT=${HEIGHT:-384}
WIDTH=${WIDTH:-384}
FPS=${FPS:-16}
SEED=${SEED:-42}
BOUNDARY_RATIO=${BOUNDARY_RATIO:-}
FLOW_SHIFT=${FLOW_SHIFT:-}

CACHE_BACKEND=${CACHE_BACKEND:-cache_dit}
DIFFUSION_BATCH_SIZE=${DIFFUSION_BATCH_SIZE:-}
REQUEST_BATCH_SIZE=${REQUEST_BATCH_SIZE:-}
OVERWRITE=${OVERWRITE:-0}
SKIP_EVAL=${SKIP_EVAL:-0}
RUN_NAME=${RUN_NAME:-}

VISIBLE_GPUS=${VISIBLE_GPUS:-${CUDA_VISIBLE_DEVICES:-}}
WORKERS=${WORKERS:-}
LOG_DIR=${LOG_DIR:-$OUTPUT_ROOT/logs}

# Resolve the GPU list: an explicit comma-separated list wins, otherwise ask
# nvidia-smi for every device index.
if [[ -n "$VISIBLE_GPUS" ]]; then
  IFS=',' read -r -a GPU_IDS <<<"$VISIBLE_GPUS"
else
  mapfile -t GPU_IDS < <(nvidia-smi --query-gpu=index --format=csv,noheader | tr -d ' ')
fi

GPU_COUNT=${#GPU_IDS[@]}

if (( GPU_COUNT == 0 )); then
  echo "No visible GPUs found." >&2
  exit 1
fi

# Unset parallelism knobs default to "use every visible GPU in one worker".
if [[ -z "$DP" ]]; then
  DP=$GPU_COUNT
fi

if [[ -z "$DIFFUSION_BATCH_SIZE" ]]; then
  DIFFUSION_BATCH_SIZE=$DP
fi

if [[ -z "$REQUEST_BATCH_SIZE" ]]; then
  REQUEST_BATCH_SIZE=$DP
fi

GPUS_PER_WORKER=$((TP * DP))

if (( GPUS_PER_WORKER < 1 )); then
  echo "Invalid parallelism: TP=$TP DP=$DP" >&2
  exit 1
fi

MAX_WORKERS=$((GPU_COUNT / GPUS_PER_WORKER))

if [[ -z "$WORKERS" ]]; then
  WORKERS=$MAX_WORKERS
fi

if (( WORKERS < 1 )); then
  echo "No worker can be launched with GPU_COUNT=$GPU_COUNT and GPUS_PER_WORKER=$GPUS_PER_WORKER." >&2
  exit 1
fi

if (( WORKERS > MAX_WORKERS )); then
  echo "WORKERS=$WORKERS exceeds the available GPU groups ($MAX_WORKERS)." >&2
  exit 1
fi

mkdir -p "$OUTPUT_ROOT" "$LOG_DIR"

COMMON_ARGS=(
  --model "$MODEL_DIR"
  --vbvr-root "$VBVR_ROOT"
  --output-root "$OUTPUT_ROOT"
  --split "$SPLIT"
  --tensor-parallel-size "$TP"
  --data-parallel-size "$DP"
  --gpu-memory-utilization "$GPU_MEM_UTIL"
  --cache-backend "$CACHE_BACKEND"
  --diffusion-batch-size "$DIFFUSION_BATCH_SIZE"
  --request-batch-size "$REQUEST_BATCH_SIZE"
  --num-inference-steps "$NUM_INFERENCE_STEPS"
  --guidance-scale "$GUIDANCE_SCALE"
  --num-frames "$NUM_FRAMES"
  --height "$HEIGHT"
  --width "$WIDTH"
  --fps "$FPS"
  --seed "$SEED"
)

# Optional flags are appended only when the corresponding variable is set.
if [[ -n "$LIMIT" ]]; then
  COMMON_ARGS+=(--limit "$LIMIT")
fi

if [[ -n "$FLOW_SHIFT" ]]; then
  COMMON_ARGS+=(--flow-shift "$FLOW_SHIFT")
fi

if [[ -n "$BOUNDARY_RATIO" ]]; then
  COMMON_ARGS+=(--boundary-ratio "$BOUNDARY_RATIO")
fi

if [[ "$OVERWRITE" == "1" ]]; then
  COMMON_ARGS+=(--overwrite)
fi

if [[ -n "$RUN_NAME" ]]; then
  COMMON_ARGS+=(--run-name "$RUN_NAME")
fi

PIDS=()

# Send TERM to every worker that is still running.
cleanup_children() {
  for pid in "${PIDS[@]:-}"; do
    if [[ -n "$pid" ]] && kill -0 "$pid" 2>/dev/null; then
      kill "$pid" 2>/dev/null || true
    fi
  done
}

# Stop the script after cleaning up; otherwise an interrupted run would fall
# through to the wait loop and report a misleading "workers failed" message.
trap 'cleanup_children; exit 130' INT TERM

# One background worker per contiguous group of GPUS_PER_WORKER devices.
for ((worker_idx = 0; worker_idx < WORKERS; worker_idx++)); do
  start=$((worker_idx * GPUS_PER_WORKER))
  worker_gpus=("${GPU_IDS[@]:start:GPUS_PER_WORKER}")
  if (( ${#worker_gpus[@]} != GPUS_PER_WORKER )); then
    echo "Worker $worker_idx expected $GPUS_PER_WORKER GPUs but found ${#worker_gpus[@]}." >&2
    cleanup_children
    exit 1
  fi

  worker_visible_gpus=$(IFS=','; echo "${worker_gpus[*]}")
  worker_log="$LOG_DIR/worker_${worker_idx}.log"

  echo "Launching worker $worker_idx/$((WORKERS - 1)) on GPUs [$worker_visible_gpus] -> $worker_log"
  CUDA_VISIBLE_DEVICES="$worker_visible_gpus" \
    "$PYTHON_BIN" "$RUNNER" \
    "${COMMON_ARGS[@]}" \
    --shard-id "$worker_idx" \
    --num-shards "$WORKERS" \
    --skip-eval \
    >"$worker_log" 2>&1 &

  PIDS+=("$!")
done

# Wait for every worker so that all logs are complete before reporting.
worker_failed=0
for pid in "${PIDS[@]}"; do
  if ! wait "$pid"; then
    worker_failed=1
  fi
done

if (( worker_failed != 0 )); then
  echo "One or more generation workers failed. Check $LOG_DIR." >&2
  exit 1
fi

if [[ "$SKIP_EVAL" == "1" ]]; then
  exit 0
fi

echo "All generation workers finished. Running final evaluation."
"$PYTHON_BIN" "$RUNNER" \
  "${COMMON_ARGS[@]}" \
  --skip-generate
_has_transformers = is_package_available("transformers")
_has_vllm = is_package_available("vllm")
_has_vllm_omni = is_package_available("vllm_omni")
_has_soundfile = is_package_available("soundfile")
_has_diffusers = is_package_available("diffusers")

# Optional runtime dependencies; resolved on first model construction by
# VLLMOmni._lazy_import_runtime_dependencies().
AutoProcessor = None
SamplingParams = None
Omni = None
fetch_audio = None
fetch_image = None
fetch_video = None
soundfile = None
export_to_video = None

# Upper bound on threads used to prepare the requests of one batch.
WORKERS = int(os.getenv("WORKERS", "8"))


def _safe(name: Any, default: str = "x") -> str:
    """Return *name* as a path-safe token of at most 128 characters.

    Characters other than alphanumerics and ``._-`` become ``_``; leading and
    trailing underscores are stripped, and *default* is used if nothing is left.
    """
    cleaned = "".join(ch if ch.isalnum() or ch in "._-" else "_" for ch in str(name)).strip("_")
    return (cleaned or default)[:128]


def _generate_run_id() -> str:
    """Return a run id made of a local timestamp and 8 characters of a UUID4."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f"{timestamp}_{str(uuid.uuid4())[:8]}"


def _model_slug(model_path: str) -> str:
    """Return a path-safe name derived from the last component of *model_path*."""
    return _safe(os.path.basename(str(model_path).rstrip("/")) or "model", default="model")


def _default_output_dir(model_path: str) -> str:
    """Return a fresh ``./logs/vllm_omni/<model>/<run id>`` directory path."""
    return os.path.join("./logs/vllm_omni", _model_slug(model_path), _generate_run_id())


def _build_diffusion_parallel_config(
    tensor_parallel_size: int,
    data_parallel_size: int,
    kwargs: dict[str, Any],
) -> dict[str, Any]:
    """Build the ``parallel_config`` dict from explicit sizes and extra kwargs.

    The parallelism-related keys are *popped* from ``kwargs`` (the caller's dict
    is mutated) so they are not passed on a second time as loose keyword
    arguments. Falsy values fall back to the defaults shown below.
    """
    parallel_config = {
        "pipeline_parallel_size": int(kwargs.pop("pipeline_parallel_size", 1) or 1),
        "data_parallel_size": int(data_parallel_size),
        "tensor_parallel_size": int(tensor_parallel_size),
        "ulysses_degree": int(kwargs.pop("ulysses_degree", 1) or 1),
        "ring_degree": int(kwargs.pop("ring_degree", 1) or 1),
        "ulysses_mode": kwargs.pop("ulysses_mode", "strict") or "strict",
        "cfg_parallel_size": int(kwargs.pop("cfg_parallel_size", 1) or 1),
        "vae_patch_parallel_size": int(kwargs.pop("vae_patch_parallel_size", 1) or 1),
        "use_hsdp": bool(kwargs.pop("use_hsdp", False)),
        "hsdp_shard_size": int(kwargs.pop("hsdp_shard_size", -1) or -1),
        "hsdp_replicate_size": int(kwargs.pop("hsdp_replicate_size", 1) or 1),
    }
    sequence_parallel_size = kwargs.pop("sequence_parallel_size", None)
    if sequence_parallel_size is not None:
        parallel_config["sequence_parallel_size"] = int(sequence_parallel_size)
    return parallel_config


def _read_model_index_float(model_path: str, key: str) -> float | None:
    """Read ``key`` from ``<model_path>/model_index.json`` as a float.

    Returns ``None`` when the file is missing, unreadable, not a JSON object,
    or when the value is absent or not convertible to ``float``.
    """
    model_index_path = os.path.join(os.path.expanduser(str(model_path)), "model_index.json")
    if not os.path.isfile(model_index_path):
        return None
    try:
        with open(model_index_path, "r", encoding="utf-8") as handle:
            value = json.load(handle).get(key)
    except (OSError, ValueError, AttributeError):
        # ValueError covers malformed JSON and bad encodings; AttributeError
        # covers a top-level JSON value that is not an object.
        return None
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


@dataclass
class _PreparedRequest:
    """One request, ready to submit, plus the ids needed to file its outputs."""

    prompt: dict[str, Any]
    sampling_params_list: Sequence[Any]
    task: str
    split: Any
    doc_id: Any


@register_model("vllm_omni", "vllm-omni")
class VLLMOmni(lmms):
    """lmms-eval chat model backed by a ``vllm_omni.Omni`` client."""

    is_simple = False

    @staticmethod
    def _lazy_import_runtime_dependencies() -> None:
        """Fill the module-level optional-dependency globals that are still ``None``."""
        global AutoProcessor, SamplingParams, Omni, fetch_audio, fetch_image, fetch_video, soundfile, export_to_video

        if AutoProcessor is None:
            AutoProcessor, _ = optional_import("transformers", "AutoProcessor")
        if SamplingParams is None:
            SamplingParams, _ = optional_import("vllm", "SamplingParams")
        if Omni is None:
            Omni, _ = optional_import("vllm_omni", "Omni")
        if fetch_audio is None:
            fetch_audio, _ = optional_import("vllm.multimodal.utils", "fetch_audio")
        if fetch_image is None:
            fetch_image, _ = optional_import("vllm.multimodal.utils", "fetch_image")
        if fetch_video is None:
            fetch_video, _ = optional_import("vllm.multimodal.utils", "fetch_video")
        if soundfile is None and _has_soundfile:
            soundfile, _ = optional_import("soundfile")
        if export_to_video is None and _has_diffusers:
            export_to_video, _ = optional_import("diffusers.utils", "export_to_video")

    def __init__(
        self,
        model: str = "Qwen/Qwen2.5-Omni-7B",
        tensor_parallel_size: int = 1,
        data_parallel_size: int = 1,
        gpu_memory_utilization: float = 0.8,
        batch_size: int = 1,
        max_frame_num: int = 32,
        trust_remote_code: bool = True,
        chat_template: Optional[str] = None,
        processor_name: Optional[str] = None,
        processor_kwargs: Optional[dict[str, Any]] = None,
        fps: Optional[int] = 16,
        nframes: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        guidance_scale_2: Optional[float] = None,
        num_frames: Optional[int] = None,
        height: int = 480,
        width: int = 832,
        seed: int = 42,
        boundary_ratio: Optional[float] = None,
        flow_shift: Optional[float] = None,
        output_dir: Optional[str] = None,
        output_modalities: Optional[str | list[str]] = None,
        extract_audio_from_video: bool = True,
        disable_log_stats: bool = False,
        max_new_tokens: int = 4096,
        **kwargs,
    ) -> None:
        """Validate dependencies, store generation defaults and start the client.

        ``num_frames`` takes precedence over ``nframes``; if both are ``None``
        the frame count falls back to ``max_frame_num``. ``boundary_ratio``
        falls back to the value in the model's ``model_index.json`` when not
        given. ``processor_kwargs`` and string-valued extra ``kwargs`` may be
        JSON object strings. Remaining ``kwargs`` are forwarded to ``Omni``.

        Raises:
            ImportError: if ``vllm_omni`` or ``vllm`` cannot be imported.
        """
        super().__init__()
        self._lazy_import_runtime_dependencies()
        if not _has_vllm_omni or Omni is None:
            raise ImportError("vllm-omni is not installed. Please install `vllm-omni` first.")
        if not _has_vllm or SamplingParams is None:
            raise ImportError("vllm is required by vllm_omni.")

        self.model = model
        self.batch_size_per_gpu = int(batch_size)
        self.max_frame_num = int(max_frame_num)
        self.fps = int(fps) if fps is not None else None
        resolved_num_frames = num_frames if num_frames is not None else nframes
        self.num_frames = int(resolved_num_frames) if resolved_num_frames is not None else int(max_frame_num)
        self.nframes = self.num_frames
        self.max_new_tokens = int(max_new_tokens)
        self.num_inference_steps = int(num_inference_steps)
        self.guidance_scale = float(guidance_scale)
        self.guidance_scale_2 = None if guidance_scale_2 is None else float(guidance_scale_2)
        self.height = int(height)
        self.width = int(width)
        self.seed = int(seed)
        self.boundary_ratio = None if boundary_ratio is None else float(boundary_ratio)
        if self.boundary_ratio is None:
            self.boundary_ratio = _read_model_index_float(self.model, "boundary_ratio")
        self.flow_shift = None if flow_shift is None else float(flow_shift)
        self.extract_audio_from_video = bool(extract_audio_from_video)
        self.disable_log_stats = bool(disable_log_stats)
        self.output_modalities = self._normalize_output_modalities(output_modalities)

        self.output_dir = os.path.abspath(os.path.expanduser(output_dir or _default_output_dir(self.model)))
        os.makedirs(self.output_dir, exist_ok=True)

        processor_kwargs = self._maybe_parse_json_dict(processor_kwargs) or {}
        kwargs = self._maybe_parse_json_like_kwargs(kwargs)
        if kwargs.get("parallel_config") is None:
            kwargs["parallel_config"] = _build_diffusion_parallel_config(
                tensor_parallel_size=tensor_parallel_size,
                data_parallel_size=data_parallel_size,
                kwargs=kwargs,
            )
        if "log_stats" not in kwargs:
            kwargs["log_stats"] = not self.disable_log_stats

        self.processor = None
        self.chat_template = self._load_chat_template(chat_template)
        if _has_transformers and AutoProcessor is not None:
            try:
                self.processor = AutoProcessor.from_pretrained(
                    processor_name or model,
                    trust_remote_code=trust_remote_code,
                    **processor_kwargs,
                )
            except Exception as e:
                # Best effort: without a processor, prompts are built from the
                # plain user text instead of a chat template.
                warnings.warn(
                    f"Failed to load AutoProcessor for {processor_name or model}: {type(e).__name__}: {e}. " "Falling back to plain-text prompts.",
                    stacklevel=2,
                )
        if self.chat_template is not None and self.processor is not None:
            self.processor.chat_template = self.chat_template

        self.client = Omni(
            model=self.model,
            tensor_parallel_size=tensor_parallel_size,
            data_parallel_size=data_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
        atexit.register(self.close)

    @staticmethod
    def _maybe_parse_json_dict(value: Any) -> dict[str, Any] | None:
        """Return *value* as a dict, parsing it when it is a ``{...}`` JSON string.

        Raises:
            TypeError: if *value* is neither ``None``, a dict, nor a string
                wrapped in braces.
        """
        if value is None:
            return None
        if isinstance(value, dict):
            return value
        if isinstance(value, str):
            stripped = value.strip()
            if stripped.startswith("{") and stripped.endswith("}"):
                return json.loads(value)
        raise TypeError(f"Expected a dict or JSON object string, got {type(value).__name__}")

    @staticmethod
    def _maybe_parse_json_like_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of *kwargs* with ``{...}`` string values JSON-decoded.

        Strings that look like JSON objects but fail to parse are kept as-is.
        """
        parsed = dict(kwargs)
        for key, value in kwargs.items():
            if not isinstance(value, str):
                continue
            stripped = value.strip()
            if stripped.startswith("{") and stripped.endswith("}"):
                try:
                    parsed[key] = json.loads(value)
                except json.JSONDecodeError:
                    pass
        return parsed

    @staticmethod
    def _normalize_output_modalities(value: Optional[str | list[str]]) -> Optional[list[str]]:
        """Return modalities as a list of non-empty stripped strings, or ``None``.

        A string is split on commas; any other iterable is stringified item by item.
        """
        if value is None:
            return None
        if isinstance(value, str):
            return [item.strip() for item in value.split(",") if item.strip()]
        return [str(item).strip() for item in value if str(item).strip()]

    @staticmethod
    def _load_chat_template(chat_template: Optional[str]) -> Optional[str]:
        """Return the chat template text, reading it from disk when given a path.

        A value is treated as a path when it contains the path separator or
        ends in a Jinja extension. Inline templates routinely contain ``/``
        (for example ``</s>``), so a path-like value that is not an existing
        file but contains Jinja markers is returned unchanged instead of
        raising.

        Raises:
            FileNotFoundError: for a path-like value that is neither an
                existing file nor an inline Jinja template.
        """
        if chat_template is None:
            return None
        path_like = os.path.sep in chat_template or chat_template.endswith((".jinja", ".jinja2", ".j2"))
        if not path_like:
            return chat_template
        if os.path.isfile(chat_template):
            with open(chat_template, "r", encoding="utf-8") as handle:
                return handle.read()
        if "{{" in chat_template or "{%" in chat_template:
            return chat_template
        raise FileNotFoundError(f"Chat template file not found: {chat_template}")

    def _select_max_new_tokens(self, request_max_new_tokens: Any) -> int:
        """Return the larger of the request's token budget and the model default.

        A missing or non-integer request value yields ``self.max_new_tokens``.
        """
        if request_max_new_tokens is None:
            return self.max_new_tokens
        try:
            requested = int(request_max_new_tokens)
        except (TypeError, ValueError):
            return self.max_new_tokens
        return max(requested, self.max_new_tokens)

    @staticmethod
    def _normalize_top_p_for_vllm(top_p: Any) -> Any:
        """Map a numeric ``top_p`` of 0 to 1.0; return every other value unchanged."""
        if isinstance(top_p, bool):
            return top_p
        try:
            numeric_top_p = float(top_p)
        except (TypeError, ValueError):
            return top_p
        if numeric_top_p == 0.0:
            return 1.0
        return top_p

    def _build_stage0_sampling_params(self, gen_kwargs: dict[str, Any]) -> Sequence[Any]:
        """Return a deep copy of the client's default params with stage 0 customized.

        Only attributes that already exist on the stage-0 object are written.
        Instance-level diffusion defaults are applied first, then the
        text-generation controls, and finally any other ``gen_kwargs`` key that
        matches an attribute, so per-request values win over instance defaults.
        """
        sampling_params_list = copy.deepcopy(list(self.client.default_sampling_params_list))
        if not sampling_params_list:
            return sampling_params_list

        stage0 = sampling_params_list[0]
        gen = dict(gen_kwargs or {})
        diffusion_defaults = {
            "num_inference_steps": self.num_inference_steps,
            "guidance_scale": self.guidance_scale,
            "guidance_scale_2": self.guidance_scale_2,
            "num_frames": self.num_frames,
            "height": self.height,
            "width": self.width,
            "seed": self.seed,
            "fps": self.fps,
            "boundary_ratio": self.boundary_ratio,
            "flow_shift": self.flow_shift,
        }
        for key, value in diffusion_defaults.items():
            if value is not None and hasattr(stage0, key):
                setattr(stage0, key, value)
        if hasattr(stage0, "guidance_scale_provided"):
            stage0.guidance_scale_provided = True
        if hasattr(stage0, "max_tokens"):
            stage0.max_tokens = self._select_max_new_tokens(gen.get("max_new_tokens"))
        if hasattr(stage0, "temperature") and "temperature" in gen:
            stage0.temperature = gen["temperature"]
        if hasattr(stage0, "top_p") and "top_p" in gen:
            stage0.top_p = self._normalize_top_p_for_vllm(gen["top_p"])

        # These keys are either handled above or have no sampling-param equivalent.
        handled = {"until", "max_new_tokens", "temperature", "top_p"}
        for key, value in gen.items():
            if key not in handled and hasattr(stage0, key):
                setattr(stage0, key, value)
        sampling_params_list[0] = stage0
        return sampling_params_list
+ + def _apply_chat_template(self, messages: list[dict[str, Any]]) -> str: + if hasattr(self.processor, "apply_chat_template"): + return self.processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + tokenizer = getattr(self.processor, "tokenizer", None) + if tokenizer is not None and hasattr(tokenizer, "apply_chat_template"): + return tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + raise AttributeError(f"{type(self.processor).__name__} does not provide apply_chat_template") + + @staticmethod + def _extract_plain_text_prompt(chat_messages: ChatMessages) -> str: + texts: list[str] = [] + for msg in chat_messages.messages: + if msg.role != "user": + continue + for content in msg.content: + if content.type == "text" and content.text: + texts.append(content.text) + return "\n".join(texts).strip() + + @staticmethod + def _is_video_string(value: Any) -> bool: + return isinstance(value, str) and value.lower().endswith((".mp4", ".avi", ".mov", ".flv", ".wmv", ".mkv", ".webm")) + + def _video_fetch_kwargs(self) -> dict[str, Any]: + if self.fps is not None: + return {"fps": self.fps} + return {"num_frames": self.nframes} + + @staticmethod + def _maybe_decode_audio_object(audio_obj: Any) -> tuple[np.ndarray, float] | None: + if isinstance(audio_obj, dict) and "array" in audio_obj: + return VLLMOmni._to_numpy_audio(audio_obj["array"]), float(audio_obj.get("sampling_rate", 16000)) + if isinstance(audio_obj, tuple) and len(audio_obj) == 2 and isinstance(audio_obj[1], (int, float)): + return VLLMOmni._to_numpy_audio(audio_obj[0]), float(audio_obj[1]) + if isinstance(audio_obj, np.ndarray): + return audio_obj.astype(np.float32, copy=False), 16000.0 + if isinstance(audio_obj, list) and audio_obj and all(isinstance(x, (int, float)) for x in audio_obj): + return np.asarray(audio_obj, dtype=np.float32), 16000.0 + if torch.is_tensor(audio_obj): + return 
audio_obj.detach().cpu().numpy().astype(np.float32, copy=False), 16000.0 + if isinstance(audio_obj, str): + return None + + candidates = [] + if hasattr(audio_obj, "get_all_samples"): + try: + candidates.append(audio_obj.get_all_samples()) + except Exception: + pass + if hasattr(audio_obj, "decode"): + try: + candidates.append(audio_obj.decode()) + except Exception: + pass + if hasattr(audio_obj, "__call__"): + try: + candidates.append(audio_obj()) + except Exception: + pass + candidates.append(audio_obj) + + for candidate in candidates: + if isinstance(candidate, dict) and "array" in candidate: + return VLLMOmni._to_numpy_audio(candidate["array"]), float(candidate.get("sampling_rate", 16000)) + if hasattr(candidate, "array") and hasattr(candidate, "sampling_rate"): + return VLLMOmni._to_numpy_audio(candidate.array), float(candidate.sampling_rate) + if hasattr(candidate, "samples"): + sample_rate = getattr(candidate, "sample_rate", getattr(candidate, "sampling_rate", 16000)) + return VLLMOmni._to_numpy_audio(candidate.samples), float(sample_rate) + if hasattr(candidate, "data") and hasattr(candidate, "sample_rate"): + return VLLMOmni._to_numpy_audio(candidate.data), float(candidate.sample_rate) + return None + + def _prepare_image_input(self, image: Any) -> Any: + if isinstance(image, (Image.Image, np.ndarray)) or torch.is_tensor(image): + return image + if isinstance(image, str): + if fetch_image is None: + raise ImportError("vllm.multimodal.utils.fetch_image is required for image path/url inputs.") + return fetch_image(image) + return image + + def _prepare_audio_input(self, audio: Any) -> Any: + decoded = self._maybe_decode_audio_object(audio) + if decoded is not None: + return decoded + if isinstance(audio, dict): + for key in ("path", "audio", "url"): + value = audio.get(key) + if isinstance(value, str): + audio = value + break + if isinstance(audio, str): + if fetch_audio is None: + raise ImportError("vllm.multimodal.utils.fetch_audio is required for audio 
path/url inputs.") + return fetch_audio(audio) + raise TypeError(f"Unsupported audio input type: {type(audio).__name__}") + + def _prepare_video_input(self, video: Any) -> tuple[Any, Any | None]: + if isinstance(video, tuple) and len(video) == 2 and isinstance(video[1], dict): + return video, None + if isinstance(video, np.ndarray) or torch.is_tensor(video): + return video, None + if isinstance(video, list) and video: + return video, None + if isinstance(video, str): + if fetch_video is None: + raise ImportError("vllm.multimodal.utils.fetch_video is required for video path/url inputs.") + frames, metadata = fetch_video(video, self._video_fetch_kwargs()) + extracted_audio = None + if self.extract_audio_from_video: + try: + extracted_audio = self._prepare_audio_input(video) + except Exception: + extracted_audio = None + return (frames, metadata), extracted_audio + raise TypeError(f"Unsupported video input type: {type(video).__name__}") + + def _build_multi_modal_data(self, chat_messages: ChatMessages) -> dict[str, Any]: + images, videos, audios = chat_messages.extract_media() + multi_modal_data: dict[str, Any] = {} + + if images: + prepared_images = [self._prepare_image_input(image) for image in images] + if self.processor is None and len(prepared_images) == 1: + multi_modal_data["image"] = prepared_images[0] + else: + multi_modal_data["image"] = prepared_images + + extracted_video_audios = [] + if videos: + prepared_videos = [] + for video in videos: + prepared_video, extracted_audio = self._prepare_video_input(video) + prepared_videos.append(prepared_video) + if extracted_audio is not None: + extracted_video_audios.append(extracted_audio) + multi_modal_data["video"] = prepared_videos + + all_audios = list(audios) + extracted_video_audios + if all_audios: + multi_modal_data["audio"] = [self._prepare_audio_input(audio) for audio in all_audios] + + return multi_modal_data + + def make_one_request(self, request: Instance) -> _PreparedRequest: + ctx, doc_to_messages, 
gen_kwargs, doc_id, task, split = request.arguments + raw_messages = doc_to_messages(self.task_dict[task][split][doc_id]) + chat_messages = ChatMessages(messages=raw_messages) + if self.processor is not None: + hf_messages = chat_messages.to_hf_messages() + prompt_text = self._apply_chat_template(hf_messages) + else: + prompt_text = self._extract_plain_text_prompt(chat_messages) + prompt = {"prompt": prompt_text} + + multi_modal_data = self._build_multi_modal_data(chat_messages) + if multi_modal_data: + prompt["multi_modal_data"] = multi_modal_data + if self.output_modalities is not None: + prompt["modalities"] = self.output_modalities + + sampling_params_list = self._build_stage0_sampling_params(dict(gen_kwargs or {})) + return _PreparedRequest( + prompt=prompt, + sampling_params_list=sampling_params_list, + task=str(task), + split=split, + doc_id=doc_id, + ) + + @staticmethod + def _sampling_signature(sampling_params_list: Sequence[Any]) -> tuple[str, ...]: + return tuple(repr(params) for params in sampling_params_list) + + @staticmethod + def _extract_text(output: Any) -> str: + outputs = getattr(output, "outputs", []) or [] + if outputs: + return getattr(outputs[0], "text", "") or "" + return "" + + @staticmethod + def _extract_token_counts(output: Any) -> TokenCounts | None: + outputs = getattr(output, "outputs", []) or [] + if not outputs: + return None + token_ids = getattr(outputs[0], "token_ids", None) + if token_ids is None: + return None + return TokenCounts(output_tokens=len(token_ids)) + + @staticmethod + def _to_numpy_audio(audio: Any) -> np.ndarray: + if torch.is_tensor(audio): + audio = audio.detach().cpu().numpy() + elif isinstance(audio, list): + audio = np.asarray(audio, dtype=np.float32) + elif not isinstance(audio, np.ndarray): + raise TypeError(f"Unsupported audio array type: {type(audio).__name__}") + + audio = np.asarray(audio, dtype=np.float32) + audio = np.squeeze(audio) + if audio.ndim == 2 and audio.shape[0] <= 8 and audio.shape[1] > 8: 
+ audio = audio.T + return audio + + def _collect_audio_payloads(self, payload: Any, fallback_sr: Optional[float] = None) -> list[tuple[np.ndarray, float]]: + if payload is None: + return [] + if isinstance(payload, dict): + next_sr = payload.get("audio_sample_rate", payload.get("sampling_rate", payload.get("sample_rate", payload.get("sr", fallback_sr)))) + if "audio" in payload: + return self._collect_audio_payloads(payload["audio"], next_sr) + if "array" in payload: + return [(self._to_numpy_audio(payload["array"]), float(next_sr or 16000))] + clips: list[tuple[np.ndarray, float]] = [] + for value in payload.values(): + clips.extend(self._collect_audio_payloads(value, next_sr)) + return clips + if isinstance(payload, tuple) and len(payload) == 2 and isinstance(payload[1], (int, float)): + return [(self._to_numpy_audio(payload[0]), float(payload[1]))] + if isinstance(payload, (list, tuple)): + if payload and all(isinstance(item, (int, float)) for item in payload): + return [(self._to_numpy_audio(payload), float(fallback_sr or 16000))] + clips: list[tuple[np.ndarray, float]] = [] + for item in payload: + clips.extend(self._collect_audio_payloads(item, fallback_sr)) + return clips + if torch.is_tensor(payload) or isinstance(payload, np.ndarray): + return [(self._to_numpy_audio(payload), float(fallback_sr or 16000))] + return [] + + @staticmethod + def _to_pil_image(image: Any) -> Image.Image: + if isinstance(image, Image.Image): + return image + if torch.is_tensor(image): + image = image.detach().cpu().numpy() + if not isinstance(image, np.ndarray): + raise TypeError(f"Unsupported image output type: {type(image).__name__}") + if image.ndim == 3 and image.shape[0] in {1, 3, 4} and image.shape[-1] not in {1, 3, 4}: + image = np.transpose(image, (1, 2, 0)) + if image.dtype != np.uint8: + image = np.clip(image, 0, 1) * 255 if image.max() <= 1.0 else np.clip(image, 0, 255) + image = image.astype(np.uint8) + return Image.fromarray(image) + + def _request_output_dir(self, 
task: str, split: Any, doc_id: Any) -> str: + out_dir = os.path.join(self.output_dir, _safe(task), _safe(split), _safe(doc_id)) + os.makedirs(out_dir, exist_ok=True) + return out_dir + + def _save_images(self, images: Sequence[Any], out_dir: str) -> list[str]: + paths = [] + for idx, image in enumerate(images): + image_path = os.path.join(out_dir, f"image_{idx}.png") + self._to_pil_image(image).save(image_path) + paths.append(image_path) + return paths + + def _normalize_video_frames(self, frames: Any) -> list[Any]: + if isinstance(frames, list): + normalized: list[Any] = [] + for item in frames: + normalized.extend(self._normalize_video_frames(item)) + return normalized + if torch.is_tensor(frames): + frames = frames.detach().cpu().numpy() + if isinstance(frames, np.ndarray): + if frames.ndim == 5 and frames.shape[0] == 1: + return self._normalize_video_frames(frames[0]) + if frames.ndim == 4: + return [frames[i] for i in range(frames.shape[0])] + if frames.ndim == 3: + return [frames] + return [frames] + + def _save_video(self, images: Sequence[Any], out_dir: str) -> list[str]: + if not images: + return [] + + video_path = os.path.join(out_dir, "video.mp4") + fps = int(self.fps or 16) + pil_images = [self._to_pil_image(image).convert("RGB") for image in self._normalize_video_frames(list(images))] + if export_to_video is not None: + export_to_video(pil_images, output_video_path=video_path, fps=fps) + return [video_path] + + try: + imageio_v2 = importlib.import_module("imageio.v2") + except Exception as e: + raise ImportError("Saving video outputs requires `diffusers` or `imageio`.") from e + + frames = [np.asarray(self._to_pil_image(image).convert("RGB")) for image in images] + imageio_v2.mimsave(video_path, frames, fps=fps) + return [video_path] + + def _save_audios(self, audio_payload: Any, out_dir: str, fallback_sr: Optional[float]) -> list[str]: + clips = self._collect_audio_payloads(audio_payload, fallback_sr=fallback_sr) + if not clips: + return [] + if not 
_has_soundfile or soundfile is None: + raise ImportError("soundfile is required to save audio outputs from vllm_omni.") + + paths = [] + for idx, (audio, sample_rate) in enumerate(clips): + audio_path = os.path.join(out_dir, f"audio_{idx}.wav") + soundfile.write(audio_path, audio, int(round(sample_rate))) + paths.append(audio_path) + return paths + + def _format_output(self, text: str, image_paths: list[str], audio_paths: list[str], video_paths: list[str]) -> str: + if not image_paths and not audio_paths and not video_paths: + return text + payload: dict[str, Any] = {"text": text} + if video_paths: + payload["videos"] = video_paths + if image_paths: + payload["images"] = image_paths + if audio_paths: + payload["audios"] = audio_paths + return json.dumps(payload, ensure_ascii=False) + + def _to_generation_result(self, output: Any, prepared: _PreparedRequest) -> GenerationResult: + text = self._extract_text(output) + token_counts = self._extract_token_counts(output) + + image_paths: list[str] = [] + audio_paths: list[str] = [] + video_paths: list[str] = [] + out_dir = self._request_output_dir(prepared.task, prepared.split, prepared.doc_id) + + images = getattr(output, "images", []) or [] + if images: + if len(images) > 1: + video_paths = self._save_video(images, out_dir) + else: + image_paths = self._save_images(images, out_dir) + + multimodal_output = getattr(output, "multimodal_output", {}) or {} + fallback_sr = multimodal_output.get("audio_sample_rate", multimodal_output.get("sampling_rate", multimodal_output.get("sample_rate", multimodal_output.get("sr")))) + if "audio" in multimodal_output: + audio_paths = self._save_audios(multimodal_output["audio"], out_dir, fallback_sr) + + formatted = self._format_output(text, image_paths, audio_paths, video_paths) + return GenerationResult(text=formatted, token_counts=token_counts) + + def _generate_batch(self, prepared_requests: Sequence[_PreparedRequest]) -> tuple[list[Any], float]: + prompts = [prepared.prompt for 
prepared in prepared_requests] + start_time = time.time() + outputs = self.client.generate( + prompts, + sampling_params_list=prepared_requests[0].sampling_params_list, + use_tqdm=False, + ) + return outputs, time.time() - start_time + + def _generate_single(self, prepared_request: _PreparedRequest) -> tuple[Any, float]: + start_time = time.time() + outputs = self.client.generate( + prepared_request.prompt, + sampling_params_list=prepared_request.sampling_params_list, + use_tqdm=False, + ) + return outputs[0], time.time() - start_time + + def generate_until(self, requests) -> List[GenerationResult]: + res: list[GenerationResult] = [] + pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding") + total_elapsed_time = 0.0 + total_tokens = 0 + + batch_size = self.batch_size_per_gpu + batched_requests = [requests[i : i + batch_size] for i in range(0, len(requests), batch_size)] + for batch_requests in batched_requests: + with ThreadPoolExecutor(max_workers=max(1, min(WORKERS, len(batch_requests)))) as executor: + prepared_requests = list(executor.map(self.make_one_request, batch_requests)) + + can_batch = len({self._sampling_signature(prepared.sampling_params_list) for prepared in prepared_requests}) == 1 + if can_batch: + outputs, elapsed = self._generate_batch(prepared_requests) + total_elapsed_time += elapsed + for prepared, output in zip(prepared_requests, outputs): + result = self._to_generation_result(output, prepared) + if result.token_counts and result.token_counts.output_tokens is not None: + total_tokens += result.token_counts.output_tokens + res.append(result) + else: + for prepared in prepared_requests: + output, elapsed = self._generate_single(prepared) + total_elapsed_time += elapsed + result = self._to_generation_result(output, prepared) + if result.token_counts and result.token_counts.output_tokens is not None: + total_tokens += result.token_counts.output_tokens + res.append(result) + + pbar.update(len(batch_requests)) + + if 
not self.disable_log_stats: + avg_speed = total_tokens / total_elapsed_time if total_elapsed_time > 0 else 0.0 + log_metrics( + total_elapsed_time=total_elapsed_time, + total_gen_tokens=total_tokens, + avg_speed=avg_speed, + additional_metrics={"request_count": len(requests)}, + ) + + pbar.close() + return res + + def close(self) -> None: + client = getattr(self, "client", None) + if client is None: + return + try: + client.close() + except Exception: + pass + + def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]: + assert False, "vllm_omni does not support loglikelihood" + + def generate_until_multi_round(self, requests) -> List[str]: + raise NotImplementedError("TODO: Implement multi-round generation") diff --git a/tools/run_vllm_omni_vbvr_local.py b/tools/run_vllm_omni_vbvr_local.py new file mode 100644 index 000000000..12e6a0a0f --- /dev/null +++ b/tools/run_vllm_omni_vbvr_local.py @@ -0,0 +1,288 @@ +from __future__ import annotations + +import argparse +import base64 +import copy +import importlib +import io +import json +import os +from datetime import datetime +from pathlib import Path +from typing import Any + +from PIL import Image +from tqdm import tqdm + +from lmms_eval.tasks.vbvr.vbvr_bench import VBVRBench +from vllm_omni import Omni + +try: + from diffusers.utils import export_to_video +except Exception: # pragma: no cover + export_to_video = None + + +FILE_SPLIT_MAP = { + "all": None, + "in_domain": "In-Domain_50", + "out_of_domain": "Out-of-Domain_50", +} + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run Wan2.2 with vllm-omni on a local VBVR-Bench checkout.") + parser.add_argument("--model", required=True, help="Path to the local Wan2.2 checkpoint") + parser.add_argument("--vbvr-root", required=True, help="Path to the local VBVR-Bench root") + parser.add_argument("--manifest", default=None, help="Optional path to VBVR-Bench.json; defaults to /VBVR-Bench.json") + 
parser.add_argument("--output-root", required=True, help="Directory for generated videos and metrics") + parser.add_argument("--split", choices=sorted(FILE_SPLIT_MAP), default="all") + parser.add_argument("--limit", type=int, default=None, help="Optional sample limit after split filtering") + parser.add_argument("--tensor-parallel-size", type=int, default=1) + parser.add_argument("--data-parallel-size", type=int, default=1) + parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) + parser.add_argument("--cache-backend", default="cache_dit") + parser.add_argument("--num-inference-steps", type=int, default=50) + parser.add_argument("--guidance-scale", type=float, default=5.0) + parser.add_argument("--guidance-scale-2", type=float, default=None) + parser.add_argument("--num-frames", type=int, default=81) + parser.add_argument("--height", type=int, default=384) + parser.add_argument("--width", type=int, default=384) + parser.add_argument("--fps", type=int, default=16) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--boundary-ratio", type=float, default=None) + parser.add_argument("--flow-shift", type=float, default=None) + parser.add_argument("--diffusion-batch-size", type=int, default=None) + parser.add_argument("--overwrite", action="store_true", help="Regenerate videos even if the output mp4 already exists") + parser.add_argument("--skip-eval", action="store_true", help="Only generate videos; skip VBVR scoring") + parser.add_argument("--task-specific-only", action="store_true", help="Score only VBVR task-specific rules instead of the default weighted aggregate") + parser.add_argument("--run-name", default=None, help="Optional name for the evaluation result JSON") + return parser.parse_args() + + +def decode_base64_image(data: str) -> Image.Image: + payload = data.split(",", 1)[-1] + return Image.open(io.BytesIO(base64.b64decode(payload))).convert("RGB") + + +def parse_task_meta(doc: dict[str, Any]) -> tuple[str, str, 
str]: + raw = ( + doc.get("first_frame_path") + or doc.get("final_frame_path") + or doc.get("prompt_path") + or doc.get("ground_truth_video_path") + or "" + ) + parts = [part for part in str(raw).split("/") if part] + if len(parts) < 3: + raise ValueError(f"Cannot parse VBVR task meta from path: {raw!r}") + return parts[0], parts[1], parts[2] + + +def filtered_docs(manifest_path: Path, split: str, limit: int | None) -> list[dict[str, Any]]: + docs = json.loads(manifest_path.read_text()) + file_split = FILE_SPLIT_MAP[split] + selected: list[dict[str, Any]] = [] + for doc in docs: + doc_file_split, _, _ = parse_task_meta(doc) + if file_split is not None and doc_file_split != file_split: + continue + selected.append(doc) + selected.sort(key=lambda doc: parse_task_meta(doc)) + if limit is not None: + selected = selected[:limit] + return selected + + +def build_sampling_params(omni: Omni, args: argparse.Namespace) -> list[Any]: + sampling_params_list = copy.deepcopy(list(omni.default_sampling_params_list)) + if not sampling_params_list: + raise RuntimeError("vllm-omni returned an empty default_sampling_params_list") + + stage0 = sampling_params_list[0] + boundary_ratio = args.boundary_ratio + if boundary_ratio is None: + boundary_ratio = read_model_index_float(args.model, "boundary_ratio") + + values = { + "num_inference_steps": args.num_inference_steps, + "guidance_scale": args.guidance_scale, + "guidance_scale_2": args.guidance_scale_2, + "num_frames": args.num_frames, + "height": args.height, + "width": args.width, + "fps": args.fps, + "seed": args.seed, + "boundary_ratio": boundary_ratio, + "flow_shift": args.flow_shift, + } + for key, value in values.items(): + if value is not None and hasattr(stage0, key): + setattr(stage0, key, value) + if hasattr(stage0, "guidance_scale_provided"): + stage0.guidance_scale_provided = True + sampling_params_list[0] = stage0 + return sampling_params_list + + +def read_model_index_float(model_path: str, key: str) -> float | None: + 
model_index_path = Path(model_path).expanduser() / "model_index.json" + if not model_index_path.is_file(): + return None + try: + value = json.loads(model_index_path.read_text()).get(key) + except Exception: + return None + if value is None: + return None + try: + return float(value) + except (TypeError, ValueError): + return None + + +def build_parallel_config(args: argparse.Namespace) -> dict[str, int]: + return { + "pipeline_parallel_size": 1, + "data_parallel_size": args.data_parallel_size, + "tensor_parallel_size": args.tensor_parallel_size, + } + + +def to_pil_image(image: Any) -> Image.Image: + if isinstance(image, Image.Image): + return image.convert("RGB") + + import numpy as np + import torch + + if torch.is_tensor(image): + image = image.detach().cpu().numpy() + if not isinstance(image, np.ndarray): + raise TypeError(f"Unsupported image output type: {type(image).__name__}") + if image.ndim == 3 and image.shape[0] in {1, 3, 4} and image.shape[-1] not in {1, 3, 4}: + image = image.transpose(1, 2, 0) + if image.dtype != np.uint8: + image = np.clip(image, 0, 1) * 255 if image.max() <= 1.0 else np.clip(image, 0, 255) + image = image.astype(np.uint8) + return Image.fromarray(image).convert("RGB") + + +def normalize_video_frames(frames: Any) -> list[Any]: + import numpy as np + import torch + + if isinstance(frames, list): + if not frames: + return [] + normalized: list[Any] = [] + for item in frames: + normalized.extend(normalize_video_frames(item)) + return normalized + if torch.is_tensor(frames): + frames = frames.detach().cpu().numpy() + if isinstance(frames, np.ndarray): + if frames.ndim == 5 and frames.shape[0] == 1: + return normalize_video_frames(frames[0]) + if frames.ndim == 4: + return [frames[i] for i in range(frames.shape[0])] + if frames.ndim == 3: + return [frames] + return [frames] + + +def save_video(frames: list[Any], output_path: Path, fps: int) -> None: + output_path.parent.mkdir(parents=True, exist_ok=True) + pil_frames = 
[to_pil_image(frame) for frame in normalize_video_frames(frames)] + if export_to_video is not None: + export_to_video(pil_frames, output_video_path=str(output_path), fps=fps) + return + + imageio_v2 = importlib.import_module("imageio.v2") + imageio_v2.mimsave(str(output_path), pil_frames, fps=fps) + + +def main() -> None: + args = parse_args() + if args.diffusion_batch_size is None: + args.diffusion_batch_size = max(1, args.data_parallel_size) + vbvr_root = Path(args.vbvr_root).expanduser().resolve() + manifest_path = Path(args.manifest).expanduser().resolve() if args.manifest else vbvr_root / "VBVR-Bench.json" + output_root = Path(args.output_root).expanduser().resolve() + videos_root = output_root / "videos" + metrics_root = output_root / "metrics" + videos_root.mkdir(parents=True, exist_ok=True) + metrics_root.mkdir(parents=True, exist_ok=True) + + docs = filtered_docs(manifest_path, args.split, args.limit) + if not docs: + raise SystemExit("No VBVR samples matched the requested split/limit") + + print(f"Using manifest: {manifest_path}") + print(f"Using GT root: {vbvr_root}") + print(f"Output root: {output_root}") + print(f"Samples: {len(docs)}") + + omni = Omni( + model=args.model, + parallel_config=build_parallel_config(args), + gpu_memory_utilization=args.gpu_memory_utilization, + trust_remote_code=True, + cache_backend=args.cache_backend, + diffusion_batch_size=args.diffusion_batch_size, + ) + + failures: list[dict[str, Any]] = [] + try: + sampling_params_list = build_sampling_params(omni, args) + for doc in tqdm(docs, desc="Generating VBVR videos", dynamic_ncols=True): + file_split, task_name, video_idx = parse_task_meta(doc) + output_path = videos_root / file_split / task_name / f"{video_idx}.mp4" + if output_path.exists() and not args.overwrite: + continue + + prompt = str(doc.get("prompt") or "").strip() + image = decode_base64_image(str(doc["first_image"])) + request = { + "prompt": prompt, + "multi_modal_data": {"image": image}, + } + + try: + outputs 
= omni.generate(request, sampling_params_list=sampling_params_list, use_tqdm=False) + result = outputs[0] + if getattr(result, "error", None): + raise RuntimeError(str(result.error)) + frames = list(getattr(result, "images", []) or []) + if not frames: + raise RuntimeError("Omni returned no image frames") + save_video(frames, output_path, fps=args.fps) + except Exception as e: # noqa: BLE001 + failures.append( + { + "file_split": file_split, + "task_name": task_name, + "video_idx": video_idx, + "output_path": str(output_path), + "error": f"{type(e).__name__}: {e}", + } + ) + + failures_path = metrics_root / "generation_failures.json" + failures_path.write_text(json.dumps(failures, indent=2)) + print(f"Generation failures: {len(failures)}") + print(f"Failure log: {failures_path}") + finally: + omni.close() + + if args.skip_eval: + return + + file_split = FILE_SPLIT_MAP[args.split] + run_name = args.run_name or datetime.now().strftime("%Y%m%d_%H%M%S") + bench = VBVRBench(gt_base_path=str(vbvr_root), output_path=str(metrics_root)) + bench.evaluate(str(videos_root), name=run_name, split=file_split, task_specific_only=args.task_specific_only) + + +if __name__ == "__main__": + main() diff --git a/tools/run_vllm_omni_vbvr_local_parallel.py b/tools/run_vllm_omni_vbvr_local_parallel.py new file mode 100644 index 000000000..a5ea619e4 --- /dev/null +++ b/tools/run_vllm_omni_vbvr_local_parallel.py @@ -0,0 +1,379 @@ +from __future__ import annotations + +import argparse +import base64 +import copy +import importlib +import io +import json +from datetime import datetime +from pathlib import Path +from typing import Any + +from PIL import Image +from tqdm import tqdm + +from lmms_eval.tasks.vbvr.vbvr_bench import VBVRBench +from vllm_omni import Omni + +try: + from diffusers.utils import export_to_video +except Exception: # pragma: no cover + export_to_video = None + + +FILE_SPLIT_MAP = { + "all": None, + "in_domain": "In-Domain_50", + "out_of_domain": "Out-of-Domain_50", +} + + 
+def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run Wan2.2 with vllm-omni on a local VBVR-Bench checkout.") + parser.add_argument("--model", required=True, help="Path to the local Wan2.2 checkpoint") + parser.add_argument("--vbvr-root", required=True, help="Path to the local VBVR-Bench root") + parser.add_argument("--manifest", default=None, help="Optional path to VBVR-Bench.json; defaults to /VBVR-Bench.json") + parser.add_argument("--output-root", required=True, help="Directory for generated videos and metrics") + parser.add_argument("--split", choices=sorted(FILE_SPLIT_MAP), default="all") + parser.add_argument("--limit", type=int, default=None, help="Optional sample limit after split filtering") + parser.add_argument("--tensor-parallel-size", type=int, default=1) + parser.add_argument("--data-parallel-size", type=int, default=1) + parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) + parser.add_argument("--cache-backend", default="cache_dit") + parser.add_argument("--num-inference-steps", type=int, default=50) + parser.add_argument("--guidance-scale", type=float, default=5.0) + parser.add_argument("--guidance-scale-2", type=float, default=None) + parser.add_argument("--num-frames", type=int, default=81) + parser.add_argument("--height", type=int, default=384) + parser.add_argument("--width", type=int, default=384) + parser.add_argument("--fps", type=int, default=16) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--boundary-ratio", type=float, default=None) + parser.add_argument("--flow-shift", type=float, default=None) + parser.add_argument("--diffusion-batch-size", type=int, default=None) + parser.add_argument("--request-batch-size", type=int, default=None, help="How many samples to submit in one Omni.generate call.") + parser.add_argument("--shard-id", type=int, default=0, help="0-based shard index for process-level sample parallelism.") + 
parser.add_argument("--num-shards", type=int, default=1, help="Total number of disjoint shards.") + parser.add_argument("--overwrite", action="store_true", help="Regenerate videos even if the output mp4 already exists") + parser.add_argument("--skip-generate", action="store_true", help="Only run VBVR scoring on existing videos; skip generation") + parser.add_argument("--skip-eval", action="store_true", help="Only generate videos; skip VBVR scoring") + parser.add_argument("--task-specific-only", action="store_true", help="Score only VBVR task-specific rules instead of the default weighted aggregate") + parser.add_argument("--run-name", default=None, help="Optional name for the evaluation result JSON") + return parser.parse_args() + + +def decode_base64_image(data: str) -> Image.Image: + payload = data.split(",", 1)[-1] + return Image.open(io.BytesIO(base64.b64decode(payload))).convert("RGB") + + +def parse_task_meta(doc: dict[str, Any]) -> tuple[str, str, str]: + raw = ( + doc.get("first_frame_path") + or doc.get("final_frame_path") + or doc.get("prompt_path") + or doc.get("ground_truth_video_path") + or "" + ) + parts = [part for part in str(raw).split("/") if part] + if len(parts) < 3: + raise ValueError(f"Cannot parse VBVR task meta from path: {raw!r}") + return parts[0], parts[1], parts[2] + + +def filtered_docs( + manifest_path: Path, + split: str, + limit: int | None, + shard_id: int, + num_shards: int, +) -> list[dict[str, Any]]: + docs = json.loads(manifest_path.read_text()) + file_split = FILE_SPLIT_MAP[split] + selected: list[dict[str, Any]] = [] + for doc in docs: + doc_file_split, _, _ = parse_task_meta(doc) + if file_split is not None and doc_file_split != file_split: + continue + selected.append(doc) + selected.sort(key=lambda doc: parse_task_meta(doc)) + if limit is not None: + selected = selected[:limit] + if num_shards > 1: + selected = [doc for idx, doc in enumerate(selected) if idx % num_shards == shard_id] + return selected + + +def 
build_sampling_params(omni: Omni, args: argparse.Namespace) -> list[Any]: + sampling_params_list = copy.deepcopy(list(omni.default_sampling_params_list)) + if not sampling_params_list: + raise RuntimeError("vllm-omni returned an empty default_sampling_params_list") + + stage0 = sampling_params_list[0] + boundary_ratio = args.boundary_ratio + if boundary_ratio is None: + boundary_ratio = read_model_index_float(args.model, "boundary_ratio") + + values = { + "num_inference_steps": args.num_inference_steps, + "guidance_scale": args.guidance_scale, + "guidance_scale_2": args.guidance_scale_2, + "num_frames": args.num_frames, + "height": args.height, + "width": args.width, + "fps": args.fps, + "seed": args.seed, + "boundary_ratio": boundary_ratio, + "flow_shift": args.flow_shift, + } + for key, value in values.items(): + if value is not None and hasattr(stage0, key): + setattr(stage0, key, value) + if hasattr(stage0, "guidance_scale_provided"): + stage0.guidance_scale_provided = True + sampling_params_list[0] = stage0 + return sampling_params_list + + +def read_model_index_float(model_path: str, key: str) -> float | None: + model_index_path = Path(model_path).expanduser() / "model_index.json" + if not model_index_path.is_file(): + return None + try: + value = json.loads(model_index_path.read_text()).get(key) + except Exception: + return None + if value is None: + return None + try: + return float(value) + except (TypeError, ValueError): + return None + + +def build_parallel_config(args: argparse.Namespace) -> dict[str, int]: + return { + "pipeline_parallel_size": 1, + "data_parallel_size": args.data_parallel_size, + "tensor_parallel_size": args.tensor_parallel_size, + } + + +def to_pil_image(image: Any) -> Image.Image: + if isinstance(image, Image.Image): + return image.convert("RGB") + + import numpy as np + import torch + + if torch.is_tensor(image): + image = image.detach().cpu().numpy() + if not isinstance(image, np.ndarray): + raise TypeError(f"Unsupported image 
output type: {type(image).__name__}") + if image.ndim == 3 and image.shape[0] in {1, 3, 4} and image.shape[-1] not in {1, 3, 4}: + image = image.transpose(1, 2, 0) + if image.dtype != np.uint8: + image = np.clip(image, 0, 1) * 255 if image.max() <= 1.0 else np.clip(image, 0, 255) + image = image.astype(np.uint8) + return Image.fromarray(image).convert("RGB") + + +def normalize_video_frames(frames: Any) -> list[Any]: + import numpy as np + import torch + + if isinstance(frames, list): + if not frames: + return [] + normalized: list[Any] = [] + for item in frames: + normalized.extend(normalize_video_frames(item)) + return normalized + if torch.is_tensor(frames): + frames = frames.detach().cpu().numpy() + if isinstance(frames, np.ndarray): + if frames.ndim == 5 and frames.shape[0] == 1: + return normalize_video_frames(frames[0]) + if frames.ndim == 4: + return [frames[i] for i in range(frames.shape[0])] + if frames.ndim == 3: + return [frames] + return [frames] + + +def save_video(frames: Any, output_path: Path, fps: int) -> None: + output_path.parent.mkdir(parents=True, exist_ok=True) + pil_frames = [to_pil_image(frame) for frame in normalize_video_frames(frames)] + if export_to_video is not None: + export_to_video(pil_frames, output_video_path=str(output_path), fps=fps) + return + + imageio_v2 = importlib.import_module("imageio.v2") + imageio_v2.mimsave(str(output_path), pil_frames, fps=fps) + + +def chunked(items: list[dict[str, Any]], size: int) -> list[list[dict[str, Any]]]: + return [items[idx : idx + size] for idx in range(0, len(items), size)] + + +def generation_failures_path(metrics_root: Path, shard_id: int, num_shards: int) -> Path: + if num_shards <= 1: + return metrics_root / "generation_failures.json" + return metrics_root / f"generation_failures_shard_{shard_id:02d}_of_{num_shards:02d}.json" + + +def generate_videos( + args: argparse.Namespace, + docs: list[dict[str, Any]], + videos_root: Path, + metrics_root: Path, +) -> None: + failures: 
list[dict[str, Any]] = [] + failures_path = generation_failures_path(metrics_root, args.shard_id, args.num_shards) + if not docs: + failures_path.write_text(json.dumps(failures, indent=2)) + print("Generation failures: 0") + print(f"Failure log: {failures_path}") + return + + omni = Omni( + model=args.model, + parallel_config=build_parallel_config(args), + gpu_memory_utilization=args.gpu_memory_utilization, + trust_remote_code=True, + cache_backend=args.cache_backend, + diffusion_batch_size=args.diffusion_batch_size, + ) + + try: + sampling_params_list = build_sampling_params(omni, args) + batches = chunked(docs, args.request_batch_size) + desc = f"Generating shard {args.shard_id + 1}/{args.num_shards}" + for batch_docs in tqdm(batches, desc=desc, dynamic_ncols=True): + request_batch: list[dict[str, Any]] = [] + meta_batch: list[tuple[str, str, str, Path]] = [] + for doc in batch_docs: + file_split, task_name, video_idx = parse_task_meta(doc) + output_path = videos_root / file_split / task_name / f"{video_idx}.mp4" + if output_path.exists() and not args.overwrite: + continue + try: + prompt = str(doc.get("prompt") or "").strip() + image = decode_base64_image(str(doc["first_image"])) + request_batch.append({"prompt": prompt, "multi_modal_data": {"image": image}}) + meta_batch.append((file_split, task_name, video_idx, output_path)) + except Exception as e: # noqa: BLE001 + failures.append( + { + "file_split": file_split, + "task_name": task_name, + "video_idx": video_idx, + "output_path": str(output_path), + "error": f"{type(e).__name__}: {e}", + } + ) + + if not request_batch: + continue + + prompts: list[dict[str, Any]] | dict[str, Any] + prompts = request_batch if len(request_batch) > 1 else request_batch[0] + try: + outputs = omni.generate(prompts, sampling_params_list=sampling_params_list, use_tqdm=False) + except Exception as e: # noqa: BLE001 + for file_split, task_name, video_idx, output_path in meta_batch: + failures.append( + { + "file_split": file_split, + 
"task_name": task_name, + "video_idx": video_idx, + "output_path": str(output_path), + "error": f"{type(e).__name__}: {e}", + } + ) + continue + + if len(outputs) != len(meta_batch): + error_text = f"Expected {len(meta_batch)} outputs, got {len(outputs)}" + for file_split, task_name, video_idx, output_path in meta_batch: + failures.append( + { + "file_split": file_split, + "task_name": task_name, + "video_idx": video_idx, + "output_path": str(output_path), + "error": error_text, + } + ) + continue + + for (file_split, task_name, video_idx, output_path), result in zip(meta_batch, outputs): + try: + if getattr(result, "error", None): + raise RuntimeError(str(result.error)) + frames = getattr(result, "images", None) + if frames is None or (isinstance(frames, list) and not frames): + raise RuntimeError("Omni returned no image frames") + save_video(frames, output_path, fps=args.fps) + except Exception as e: # noqa: BLE001 + failures.append( + { + "file_split": file_split, + "task_name": task_name, + "video_idx": video_idx, + "output_path": str(output_path), + "error": f"{type(e).__name__}: {e}", + } + ) + finally: + omni.close() + + failures_path.write_text(json.dumps(failures, indent=2)) + print(f"Generation failures: {len(failures)}") + print(f"Failure log: {failures_path}") + + +def main() -> None: + args = parse_args() + if args.diffusion_batch_size is None: + args.diffusion_batch_size = max(1, args.data_parallel_size) + if args.request_batch_size is None: + args.request_batch_size = max(1, args.data_parallel_size) + if args.num_shards < 1: + raise SystemExit("--num-shards must be >= 1") + if args.shard_id < 0 or args.shard_id >= args.num_shards: + raise SystemExit("--shard-id must be in [0, num_shards)") + if args.request_batch_size < 1: + raise SystemExit("--request-batch-size must be >= 1") + + vbvr_root = Path(args.vbvr_root).expanduser().resolve() + manifest_path = Path(args.manifest).expanduser().resolve() if args.manifest else vbvr_root / "VBVR-Bench.json" + 
output_root = Path(args.output_root).expanduser().resolve() + videos_root = output_root / "videos" + metrics_root = output_root / "metrics" + videos_root.mkdir(parents=True, exist_ok=True) + metrics_root.mkdir(parents=True, exist_ok=True) + + docs = filtered_docs(manifest_path, args.split, args.limit, args.shard_id, args.num_shards) + + print(f"Using manifest: {manifest_path}") + print(f"Using GT root: {vbvr_root}") + print(f"Output root: {output_root}") + print(f"Shard: {args.shard_id + 1}/{args.num_shards}") + print(f"Samples: {len(docs)}") + + if not args.skip_generate: + generate_videos(args, docs, videos_root, metrics_root) + + if args.skip_eval: + return + + file_split = FILE_SPLIT_MAP[args.split] + run_name = args.run_name or datetime.now().strftime("%Y%m%d_%H%M%S") + bench = VBVRBench(gt_base_path=str(vbvr_root), output_path=str(metrics_root)) + bench.evaluate(str(videos_root), name=run_name, split=file_split, task_specific_only=args.task_specific_only) + + +if __name__ == "__main__": + main() From ac1eac6fb2bb3adb99a461ce890b86df4d16583b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 27 Apr 2026 09:36:36 +0000 Subject: [PATCH 2/3] style: auto-fix lint (black + isort) --- lmms_eval/models/chat/vllm_omni.py | 3 +-- tools/run_vllm_omni_vbvr_local.py | 10 ++-------- tools/run_vllm_omni_vbvr_local_parallel.py | 10 ++-------- 3 files changed, 5 insertions(+), 18 deletions(-) diff --git a/lmms_eval/models/chat/vllm_omni.py b/lmms_eval/models/chat/vllm_omni.py index a11cab4cb..fcde88296 100644 --- a/lmms_eval/models/chat/vllm_omni.py +++ b/lmms_eval/models/chat/vllm_omni.py @@ -220,8 +220,7 @@ def __init__( ) except Exception as e: warnings.warn( - f"Failed to load AutoProcessor for {processor_name or model}: {type(e).__name__}: {e}. " - "Falling back to plain-text prompts.", + f"Failed to load AutoProcessor for {processor_name or model}: {type(e).__name__}: {e}. 
" "Falling back to plain-text prompts.", stacklevel=2, ) if self.chat_template is not None and self.processor is not None: diff --git a/tools/run_vllm_omni_vbvr_local.py b/tools/run_vllm_omni_vbvr_local.py index 12e6a0a0f..3113867d4 100644 --- a/tools/run_vllm_omni_vbvr_local.py +++ b/tools/run_vllm_omni_vbvr_local.py @@ -13,9 +13,9 @@ from PIL import Image from tqdm import tqdm +from vllm_omni import Omni from lmms_eval.tasks.vbvr.vbvr_bench import VBVRBench -from vllm_omni import Omni try: from diffusers.utils import export_to_video @@ -66,13 +66,7 @@ def decode_base64_image(data: str) -> Image.Image: def parse_task_meta(doc: dict[str, Any]) -> tuple[str, str, str]: - raw = ( - doc.get("first_frame_path") - or doc.get("final_frame_path") - or doc.get("prompt_path") - or doc.get("ground_truth_video_path") - or "" - ) + raw = doc.get("first_frame_path") or doc.get("final_frame_path") or doc.get("prompt_path") or doc.get("ground_truth_video_path") or "" parts = [part for part in str(raw).split("/") if part] if len(parts) < 3: raise ValueError(f"Cannot parse VBVR task meta from path: {raw!r}") diff --git a/tools/run_vllm_omni_vbvr_local_parallel.py b/tools/run_vllm_omni_vbvr_local_parallel.py index a5ea619e4..9dac84068 100644 --- a/tools/run_vllm_omni_vbvr_local_parallel.py +++ b/tools/run_vllm_omni_vbvr_local_parallel.py @@ -12,9 +12,9 @@ from PIL import Image from tqdm import tqdm +from vllm_omni import Omni from lmms_eval.tasks.vbvr.vbvr_bench import VBVRBench -from vllm_omni import Omni try: from diffusers.utils import export_to_video @@ -69,13 +69,7 @@ def decode_base64_image(data: str) -> Image.Image: def parse_task_meta(doc: dict[str, Any]) -> tuple[str, str, str]: - raw = ( - doc.get("first_frame_path") - or doc.get("final_frame_path") - or doc.get("prompt_path") - or doc.get("ground_truth_video_path") - or "" - ) + raw = doc.get("first_frame_path") or doc.get("final_frame_path") or doc.get("prompt_path") or doc.get("ground_truth_video_path") or "" parts = 
[part for part in str(raw).split("/") if part] if len(parts) < 3: raise ValueError(f"Cannot parse VBVR task meta from path: {raw!r}") From d1338e7e201acec5e1e303b566097c64b8bb25c9 Mon Sep 17 00:00:00 2001 From: pufanyi Date: Tue, 28 Apr 2026 17:15:49 +0800 Subject: [PATCH 3/3] update --- lmms_eval/models/chat/vllm_omni.py | 290 +++++++++++++++- lmms_eval/tasks/__init__.py | 6 +- tools/run_vllm_omni_vbvr_local.py | 282 ---------------- tools/run_vllm_omni_vbvr_local_parallel.py | 373 --------------------- 4 files changed, 285 insertions(+), 666 deletions(-) delete mode 100644 tools/run_vllm_omni_vbvr_local.py delete mode 100644 tools/run_vllm_omni_vbvr_local_parallel.py diff --git a/lmms_eval/models/chat/vllm_omni.py b/lmms_eval/models/chat/vllm_omni.py index fcde88296..d65c930a4 100644 --- a/lmms_eval/models/chat/vllm_omni.py +++ b/lmms_eval/models/chat/vllm_omni.py @@ -102,6 +102,10 @@ def _read_model_index_float(model_path: str, key: str) -> float | None: return None +def _has_local_model_index(model_path: str) -> bool: + return os.path.isfile(os.path.join(os.path.expanduser(str(model_path)), "model_index.json")) + + @dataclass class _PreparedRequest: prompt: dict[str, Any] @@ -136,6 +140,177 @@ def _lazy_import_runtime_dependencies() -> None: if export_to_video is None and _has_diffusers: export_to_video, _ = optional_import("diffusers.utils", "export_to_video") + @staticmethod + def _int_env(name: str, default: int) -> int: + try: + return int(os.environ.get(name, str(default)) or default) + except (TypeError, ValueError): + return default + + @staticmethod + def _has_internal_parallelism(tensor_parallel_size: int, data_parallel_size: int, kwargs: dict[str, Any]) -> bool: + sizes: list[Any] = [tensor_parallel_size, data_parallel_size] + parallel_config = kwargs.get("parallel_config") + if isinstance(parallel_config, dict): + sizes.extend( + parallel_config.get(key, 1) + for key in ( + "pipeline_parallel_size", + "data_parallel_size", + "tensor_parallel_size", + 
"cfg_parallel_size", + "vae_patch_parallel_size", + "sequence_parallel_size", + "ulysses_degree", + "ring_degree", + ) + ) + else: + sizes.extend( + getattr(parallel_config, key, 1) + for key in ( + "pipeline_parallel_size", + "data_parallel_size", + "tensor_parallel_size", + "cfg_parallel_size", + "vae_patch_parallel_size", + "sequence_parallel_size", + "ulysses_degree", + "ring_degree", + ) + ) + for size in sizes: + try: + if int(size) > 1: + return True + except (TypeError, ValueError): + return True + return False + + @classmethod + def _pin_default_diffusion_stage_to_local_rank( + cls, + model: str, + tensor_parallel_size: int, + data_parallel_size: int, + kwargs: dict[str, Any], + ) -> None: + world_size = cls._int_env("WORLD_SIZE", 1) + if world_size <= 1 or not _has_local_model_index(model): + return + if cls._has_internal_parallelism(tensor_parallel_size, data_parallel_size, kwargs): + return + if kwargs.get("stage_overrides") or kwargs.get("stage_0_devices") is not None: + return + + local_rank = cls._int_env("LOCAL_RANK", cls._int_env("RANK", 0)) + kwargs["stage_overrides"] = {"0": {"devices": str(local_rank)}} + + @classmethod + def _patch_default_diffusion_stage_devices_for_external_dp(cls, model: str) -> None: + if cls._int_env("WORLD_SIZE", 1) <= 1 or not _has_local_model_index(model): + return + try: + from vllm_omni.engine import async_omni_engine + except Exception: + return + + engine_cls = async_omni_engine.AsyncOmniEngine + original = engine_cls._create_default_diffusion_stage_cfg + if getattr(original, "_lmms_eval_external_dp_devices", False): + return + + def create_default_diffusion_stage_cfg(kwargs): + stage_configs = original(kwargs) + if stage_configs: + local_rank = cls._int_env("LOCAL_RANK", cls._int_env("RANK", 0)) + runtime = stage_configs[0].setdefault("runtime", {}) + runtime["devices"] = str(local_rank) + engine_args = stage_configs[0].setdefault("engine_args", {}) + engine_args.setdefault("master_port", 30005 + local_rank * 1000) 
+ return stage_configs + + create_default_diffusion_stage_cfg._lmms_eval_external_dp_devices = True + engine_cls._create_default_diffusion_stage_cfg = staticmethod(create_default_diffusion_stage_cfg) + + @classmethod + def _patch_diffusion_stage_spawn_env_for_external_dp(cls, model: str) -> None: + if cls._int_env("WORLD_SIZE", 1) <= 1 or not _has_local_model_index(model): + return + try: + from vllm_omni.engine import async_omni_engine + except Exception: + return + + original = async_omni_engine.initialize_diffusion_stage + if getattr(original, "_lmms_eval_external_dp_single_rank_env", False): + return + + def initialize_diffusion_stage_single_rank_env(*args, **kwargs): + elastic_env_keys = tuple(key for key in os.environ if key.startswith("TORCHELASTIC_")) + dist_env_keys = ( + "RANK", + "WORLD_SIZE", + "LOCAL_RANK", + "LOCAL_WORLD_SIZE", + "GROUP_RANK", + "GROUP_WORLD_SIZE", + "ROLE_RANK", + "ROLE_WORLD_SIZE", + "MASTER_ADDR", + "MASTER_PORT", + ) + elastic_env_keys + saved_env = {key: os.environ.get(key) for key in dist_env_keys} + try: + for key in elastic_env_keys: + os.environ.pop(key, None) + os.environ.update( + { + "RANK": "0", + "WORLD_SIZE": "1", + "LOCAL_RANK": "0", + "LOCAL_WORLD_SIZE": "1", + "GROUP_RANK": "0", + "GROUP_WORLD_SIZE": "1", + "ROLE_RANK": "0", + "ROLE_WORLD_SIZE": "1", + "MASTER_ADDR": "localhost", + } + ) + os.environ.pop("MASTER_PORT", None) + return original(*args, **kwargs) + finally: + for key, value in saved_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + + initialize_diffusion_stage_single_rank_env._lmms_eval_external_dp_single_rank_env = True + async_omni_engine.initialize_diffusion_stage = initialize_diffusion_stage_single_rank_env + + @classmethod + def _patch_inline_diffusion_device_for_external_dp(cls, model: str) -> None: + if cls._int_env("WORLD_SIZE", 1) <= 1 or not _has_local_model_index(model): + return + try: + from vllm_omni.platforms import current_omni_platform + 
except Exception: + return + if vars(current_omni_platform).get("_lmms_eval_external_dp_device_patch", False): + return + + original_get_torch_device = current_omni_platform.get_torch_device + + def get_torch_device(local_rank: int | None = None) -> torch.device: + if local_rank == 0: + external_local_rank = cls._int_env("LOCAL_RANK", cls._int_env("RANK", 0)) + return torch.device("cuda", external_local_rank) + return original_get_torch_device(local_rank) + + current_omni_platform.get_torch_device = staticmethod(get_torch_device) + current_omni_platform._lmms_eval_external_dp_device_patch = True + def __init__( self, model: str = "Qwen/Qwen2.5-Omni-7B", @@ -174,6 +349,12 @@ def __init__( raise ImportError("vllm is required by vllm_omni.") self.model = model + self._rank = self._int_env("RANK", 0) + self._world_size = self._int_env("WORLD_SIZE", 1) + self._local_rank = self._int_env("LOCAL_RANK", self._rank) + self._device = torch.device(f"cuda:{self._local_rank}" if self._world_size > 1 else ("cuda" if torch.cuda.is_available() else "cpu")) + if self._world_size > 1 and torch.cuda.is_available(): + torch.cuda.set_device(self._local_rank) self.batch_size_per_gpu = int(batch_size) self.max_frame_num = int(max_frame_num) self.fps = int(fps) if fps is not None else None @@ -206,6 +387,15 @@ def __init__( data_parallel_size=data_parallel_size, kwargs=kwargs, ) + self._pin_default_diffusion_stage_to_local_rank( + model=self.model, + tensor_parallel_size=tensor_parallel_size, + data_parallel_size=data_parallel_size, + kwargs=kwargs, + ) + self._patch_default_diffusion_stage_devices_for_external_dp(self.model) + self._patch_diffusion_stage_spawn_env_for_external_dp(self.model) + self._patch_inline_diffusion_device_for_external_dp(self.model) if "log_stats" not in kwargs: kwargs["log_stats"] = not self.disable_log_stats @@ -236,6 +426,14 @@ def __init__( ) atexit.register(self.close) + @property + def batch_size(self) -> int: + return self.batch_size_per_gpu + + @property + 
def device(self) -> torch.device: + return self._device + @staticmethod def _maybe_parse_json_dict(value: Any) -> dict[str, Any] | None: if value is None: @@ -491,6 +689,15 @@ def _build_multi_modal_data(self, chat_messages: ChatMessages) -> dict[str, Any] def make_one_request(self, request: Instance) -> _PreparedRequest: ctx, doc_to_messages, gen_kwargs, doc_id, task, split = request.arguments + if task is None: + if len(self.task_dict) != 1: + raise KeyError(f"Request did not include a task name and multiple tasks are loaded: {list(self.task_dict)}") + task = next(iter(self.task_dict)) + if split is None: + task_splits = self.task_dict[task] + if len(task_splits) != 1: + raise KeyError(f"Request for task {task!r} did not include a split and multiple splits are loaded: {list(task_splits)}") + split = next(iter(task_splits)) raw_messages = doc_to_messages(self.task_dict[task][split][doc_id]) chat_messages = ChatMessages(messages=raw_messages) if self.processor is not None: @@ -605,6 +812,13 @@ def _save_images(self, images: Sequence[Any], out_dir: str) -> list[str]: paths.append(image_path) return paths + @staticmethod + def _is_video_payload(payload: Any) -> bool: + if torch.is_tensor(payload): + payload = payload.detach().cpu() + return payload.ndim >= 4 + return isinstance(payload, np.ndarray) and payload.ndim >= 4 + def _normalize_video_frames(self, frames: Any) -> list[Any]: if isinstance(frames, list): normalized: list[Any] = [] @@ -638,7 +852,7 @@ def _save_video(self, images: Sequence[Any], out_dir: str) -> list[str]: except Exception as e: raise ImportError("Saving video outputs requires `diffusers` or `imageio`.") from e - frames = [np.asarray(self._to_pil_image(image).convert("RGB")) for image in images] + frames = [np.asarray(self._to_pil_image(image).convert("RGB")) for image in self._normalize_video_frames(list(images))] imageio_v2.mimsave(video_path, frames, fps=fps) return [video_path] @@ -677,14 +891,21 @@ def _to_generation_result(self, output: 
Any, prepared: _PreparedRequest) -> Gene video_paths: list[str] = [] out_dir = self._request_output_dir(prepared.task, prepared.split, prepared.doc_id) - images = getattr(output, "images", []) or [] + images = getattr(output, "images", None) + if images is None: + images = [] + elif isinstance(images, (Image.Image, np.ndarray)) or torch.is_tensor(images): + images = [images] if images: - if len(images) > 1: + if len(images) > 1 or any(self._is_video_payload(image) for image in images): video_paths = self._save_video(images, out_dir) else: image_paths = self._save_images(images, out_dir) multimodal_output = getattr(output, "multimodal_output", {}) or {} + video_payload = multimodal_output.get("video", multimodal_output.get("videos")) + if video_payload is not None and not video_paths: + video_paths = self._save_video(video_payload if isinstance(video_payload, list) else [video_payload], out_dir) fallback_sr = multimodal_output.get("audio_sample_rate", multimodal_output.get("sampling_rate", multimodal_output.get("sample_rate", multimodal_output.get("sr")))) if "audio" in multimodal_output: audio_paths = self._save_audios(multimodal_output["audio"], out_dir, fallback_sr) @@ -692,14 +913,67 @@ def _to_generation_result(self, output: Any, prepared: _PreparedRequest) -> Gene formatted = self._format_output(text, image_paths, audio_paths, video_paths) return GenerationResult(text=formatted, token_counts=token_counts) + def _can_use_diffusion_batch_request(self) -> bool: + return False + + @classmethod + def _slice_batched_payload(cls, payload: Any, idx: int, count: int) -> Any: + if isinstance(payload, dict): + return {key: cls._slice_batched_payload(value, idx, count) for key, value in payload.items()} + if isinstance(payload, list): + if len(payload) == count: + return payload[idx] + return payload + if isinstance(payload, tuple): + if len(payload) == count: + return payload[idx] + return payload + if torch.is_tensor(payload): + return payload[idx] if payload.ndim > 0 
and payload.shape[0] == count else payload + if isinstance(payload, np.ndarray): + return payload[idx] if payload.ndim > 0 and payload.shape[0] == count else payload + return payload + + def _split_batched_output(self, output: Any, count: int) -> list[Any]: + images = getattr(output, "images", None) or [] + if not isinstance(images, list): + images = [images] + + if images and len(images) % count != 0: + raise ValueError(f"Batched vllm_omni output has {len(images)} image payloads for {count} requests") + images_per_request = (len(images) // count) if images else 0 + multimodal_output = getattr(output, "multimodal_output", {}) or {} + + split_outputs = [] + for idx in range(count): + split_output = copy.copy(output) + start = idx * images_per_request + end = start + images_per_request + if hasattr(split_output, "images"): + split_output.images = images[start:end] + if hasattr(split_output, "_multimodal_output"): + split_output._multimodal_output = self._slice_batched_payload(multimodal_output, idx, count) + split_outputs.append(split_output) + return split_outputs + def _generate_batch(self, prepared_requests: Sequence[_PreparedRequest]) -> tuple[list[Any], float]: prompts = [prepared.prompt for prepared in prepared_requests] start_time = time.time() - outputs = self.client.generate( - prompts, - sampling_params_list=prepared_requests[0].sampling_params_list, - use_tqdm=False, - ) + if len(prompts) > 1 and self._can_use_diffusion_batch_request(): + batched_outputs = self.client.generate( + [prompts], + sampling_params_list=prepared_requests[0].sampling_params_list, + use_tqdm=False, + ) + if len(batched_outputs) != 1: + raise ValueError(f"Expected one vllm_omni batched output, got {len(batched_outputs)}") + outputs = self._split_batched_output(batched_outputs[0], len(prompts)) + else: + outputs = self.client.generate( + prompts, + sampling_params_list=prepared_requests[0].sampling_params_list, + use_tqdm=False, + ) return outputs, time.time() - start_time def 
_generate_single(self, prepared_request: _PreparedRequest) -> tuple[Any, float]: diff --git a/lmms_eval/tasks/__init__.py b/lmms_eval/tasks/__init__.py index e8f367caa..c8d7830d9 100755 --- a/lmms_eval/tasks/__init__.py +++ b/lmms_eval/tasks/__init__.py @@ -378,8 +378,8 @@ def load_task_or_group(self, task_list: Optional[Union[str, list]] = None, task_ all_loaded_tasks = dict(collections.ChainMap(*map(load_fn, task_list))) return all_loaded_tasks - def load_config(self, config: Dict): - return self._load_individual_task_or_group(config) + def load_config(self, config: Dict, task_type: Literal["simple", "chat"] = "simple"): + return self._load_individual_task_or_group(config, task_type=task_type) def _get_task_and_group(self, task_dir: str): """Creates a dictionary of tasks index with the following metadata, @@ -571,7 +571,7 @@ def get_task_dict( if isinstance(task_element, dict): task_name_from_config_dict = { **task_name_from_config_dict, - **task_manager.load_config(config=task_element), + **task_manager.load_config(config=task_element, task_type=task_type), } elif isinstance(task_element, Task): diff --git a/tools/run_vllm_omni_vbvr_local.py b/tools/run_vllm_omni_vbvr_local.py deleted file mode 100644 index 3113867d4..000000000 --- a/tools/run_vllm_omni_vbvr_local.py +++ /dev/null @@ -1,282 +0,0 @@ -from __future__ import annotations - -import argparse -import base64 -import copy -import importlib -import io -import json -import os -from datetime import datetime -from pathlib import Path -from typing import Any - -from PIL import Image -from tqdm import tqdm -from vllm_omni import Omni - -from lmms_eval.tasks.vbvr.vbvr_bench import VBVRBench - -try: - from diffusers.utils import export_to_video -except Exception: # pragma: no cover - export_to_video = None - - -FILE_SPLIT_MAP = { - "all": None, - "in_domain": "In-Domain_50", - "out_of_domain": "Out-of-Domain_50", -} - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Run 
Wan2.2 with vllm-omni on a local VBVR-Bench checkout.") - parser.add_argument("--model", required=True, help="Path to the local Wan2.2 checkpoint") - parser.add_argument("--vbvr-root", required=True, help="Path to the local VBVR-Bench root") - parser.add_argument("--manifest", default=None, help="Optional path to VBVR-Bench.json; defaults to /VBVR-Bench.json") - parser.add_argument("--output-root", required=True, help="Directory for generated videos and metrics") - parser.add_argument("--split", choices=sorted(FILE_SPLIT_MAP), default="all") - parser.add_argument("--limit", type=int, default=None, help="Optional sample limit after split filtering") - parser.add_argument("--tensor-parallel-size", type=int, default=1) - parser.add_argument("--data-parallel-size", type=int, default=1) - parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) - parser.add_argument("--cache-backend", default="cache_dit") - parser.add_argument("--num-inference-steps", type=int, default=50) - parser.add_argument("--guidance-scale", type=float, default=5.0) - parser.add_argument("--guidance-scale-2", type=float, default=None) - parser.add_argument("--num-frames", type=int, default=81) - parser.add_argument("--height", type=int, default=384) - parser.add_argument("--width", type=int, default=384) - parser.add_argument("--fps", type=int, default=16) - parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--boundary-ratio", type=float, default=None) - parser.add_argument("--flow-shift", type=float, default=None) - parser.add_argument("--diffusion-batch-size", type=int, default=None) - parser.add_argument("--overwrite", action="store_true", help="Regenerate videos even if the output mp4 already exists") - parser.add_argument("--skip-eval", action="store_true", help="Only generate videos; skip VBVR scoring") - parser.add_argument("--task-specific-only", action="store_true", help="Score only VBVR task-specific rules instead of the default weighted 
aggregate") - parser.add_argument("--run-name", default=None, help="Optional name for the evaluation result JSON") - return parser.parse_args() - - -def decode_base64_image(data: str) -> Image.Image: - payload = data.split(",", 1)[-1] - return Image.open(io.BytesIO(base64.b64decode(payload))).convert("RGB") - - -def parse_task_meta(doc: dict[str, Any]) -> tuple[str, str, str]: - raw = doc.get("first_frame_path") or doc.get("final_frame_path") or doc.get("prompt_path") or doc.get("ground_truth_video_path") or "" - parts = [part for part in str(raw).split("/") if part] - if len(parts) < 3: - raise ValueError(f"Cannot parse VBVR task meta from path: {raw!r}") - return parts[0], parts[1], parts[2] - - -def filtered_docs(manifest_path: Path, split: str, limit: int | None) -> list[dict[str, Any]]: - docs = json.loads(manifest_path.read_text()) - file_split = FILE_SPLIT_MAP[split] - selected: list[dict[str, Any]] = [] - for doc in docs: - doc_file_split, _, _ = parse_task_meta(doc) - if file_split is not None and doc_file_split != file_split: - continue - selected.append(doc) - selected.sort(key=lambda doc: parse_task_meta(doc)) - if limit is not None: - selected = selected[:limit] - return selected - - -def build_sampling_params(omni: Omni, args: argparse.Namespace) -> list[Any]: - sampling_params_list = copy.deepcopy(list(omni.default_sampling_params_list)) - if not sampling_params_list: - raise RuntimeError("vllm-omni returned an empty default_sampling_params_list") - - stage0 = sampling_params_list[0] - boundary_ratio = args.boundary_ratio - if boundary_ratio is None: - boundary_ratio = read_model_index_float(args.model, "boundary_ratio") - - values = { - "num_inference_steps": args.num_inference_steps, - "guidance_scale": args.guidance_scale, - "guidance_scale_2": args.guidance_scale_2, - "num_frames": args.num_frames, - "height": args.height, - "width": args.width, - "fps": args.fps, - "seed": args.seed, - "boundary_ratio": boundary_ratio, - "flow_shift": 
args.flow_shift, - } - for key, value in values.items(): - if value is not None and hasattr(stage0, key): - setattr(stage0, key, value) - if hasattr(stage0, "guidance_scale_provided"): - stage0.guidance_scale_provided = True - sampling_params_list[0] = stage0 - return sampling_params_list - - -def read_model_index_float(model_path: str, key: str) -> float | None: - model_index_path = Path(model_path).expanduser() / "model_index.json" - if not model_index_path.is_file(): - return None - try: - value = json.loads(model_index_path.read_text()).get(key) - except Exception: - return None - if value is None: - return None - try: - return float(value) - except (TypeError, ValueError): - return None - - -def build_parallel_config(args: argparse.Namespace) -> dict[str, int]: - return { - "pipeline_parallel_size": 1, - "data_parallel_size": args.data_parallel_size, - "tensor_parallel_size": args.tensor_parallel_size, - } - - -def to_pil_image(image: Any) -> Image.Image: - if isinstance(image, Image.Image): - return image.convert("RGB") - - import numpy as np - import torch - - if torch.is_tensor(image): - image = image.detach().cpu().numpy() - if not isinstance(image, np.ndarray): - raise TypeError(f"Unsupported image output type: {type(image).__name__}") - if image.ndim == 3 and image.shape[0] in {1, 3, 4} and image.shape[-1] not in {1, 3, 4}: - image = image.transpose(1, 2, 0) - if image.dtype != np.uint8: - image = np.clip(image, 0, 1) * 255 if image.max() <= 1.0 else np.clip(image, 0, 255) - image = image.astype(np.uint8) - return Image.fromarray(image).convert("RGB") - - -def normalize_video_frames(frames: Any) -> list[Any]: - import numpy as np - import torch - - if isinstance(frames, list): - if not frames: - return [] - normalized: list[Any] = [] - for item in frames: - normalized.extend(normalize_video_frames(item)) - return normalized - if torch.is_tensor(frames): - frames = frames.detach().cpu().numpy() - if isinstance(frames, np.ndarray): - if frames.ndim == 5 
and frames.shape[0] == 1: - return normalize_video_frames(frames[0]) - if frames.ndim == 4: - return [frames[i] for i in range(frames.shape[0])] - if frames.ndim == 3: - return [frames] - return [frames] - - -def save_video(frames: list[Any], output_path: Path, fps: int) -> None: - output_path.parent.mkdir(parents=True, exist_ok=True) - pil_frames = [to_pil_image(frame) for frame in normalize_video_frames(frames)] - if export_to_video is not None: - export_to_video(pil_frames, output_video_path=str(output_path), fps=fps) - return - - imageio_v2 = importlib.import_module("imageio.v2") - imageio_v2.mimsave(str(output_path), pil_frames, fps=fps) - - -def main() -> None: - args = parse_args() - if args.diffusion_batch_size is None: - args.diffusion_batch_size = max(1, args.data_parallel_size) - vbvr_root = Path(args.vbvr_root).expanduser().resolve() - manifest_path = Path(args.manifest).expanduser().resolve() if args.manifest else vbvr_root / "VBVR-Bench.json" - output_root = Path(args.output_root).expanduser().resolve() - videos_root = output_root / "videos" - metrics_root = output_root / "metrics" - videos_root.mkdir(parents=True, exist_ok=True) - metrics_root.mkdir(parents=True, exist_ok=True) - - docs = filtered_docs(manifest_path, args.split, args.limit) - if not docs: - raise SystemExit("No VBVR samples matched the requested split/limit") - - print(f"Using manifest: {manifest_path}") - print(f"Using GT root: {vbvr_root}") - print(f"Output root: {output_root}") - print(f"Samples: {len(docs)}") - - omni = Omni( - model=args.model, - parallel_config=build_parallel_config(args), - gpu_memory_utilization=args.gpu_memory_utilization, - trust_remote_code=True, - cache_backend=args.cache_backend, - diffusion_batch_size=args.diffusion_batch_size, - ) - - failures: list[dict[str, Any]] = [] - try: - sampling_params_list = build_sampling_params(omni, args) - for doc in tqdm(docs, desc="Generating VBVR videos", dynamic_ncols=True): - file_split, task_name, video_idx = 
parse_task_meta(doc) - output_path = videos_root / file_split / task_name / f"{video_idx}.mp4" - if output_path.exists() and not args.overwrite: - continue - - prompt = str(doc.get("prompt") or "").strip() - image = decode_base64_image(str(doc["first_image"])) - request = { - "prompt": prompt, - "multi_modal_data": {"image": image}, - } - - try: - outputs = omni.generate(request, sampling_params_list=sampling_params_list, use_tqdm=False) - result = outputs[0] - if getattr(result, "error", None): - raise RuntimeError(str(result.error)) - frames = list(getattr(result, "images", []) or []) - if not frames: - raise RuntimeError("Omni returned no image frames") - save_video(frames, output_path, fps=args.fps) - except Exception as e: # noqa: BLE001 - failures.append( - { - "file_split": file_split, - "task_name": task_name, - "video_idx": video_idx, - "output_path": str(output_path), - "error": f"{type(e).__name__}: {e}", - } - ) - - failures_path = metrics_root / "generation_failures.json" - failures_path.write_text(json.dumps(failures, indent=2)) - print(f"Generation failures: {len(failures)}") - print(f"Failure log: {failures_path}") - finally: - omni.close() - - if args.skip_eval: - return - - file_split = FILE_SPLIT_MAP[args.split] - run_name = args.run_name or datetime.now().strftime("%Y%m%d_%H%M%S") - bench = VBVRBench(gt_base_path=str(vbvr_root), output_path=str(metrics_root)) - bench.evaluate(str(videos_root), name=run_name, split=file_split, task_specific_only=args.task_specific_only) - - -if __name__ == "__main__": - main() diff --git a/tools/run_vllm_omni_vbvr_local_parallel.py b/tools/run_vllm_omni_vbvr_local_parallel.py deleted file mode 100644 index 9dac84068..000000000 --- a/tools/run_vllm_omni_vbvr_local_parallel.py +++ /dev/null @@ -1,373 +0,0 @@ -from __future__ import annotations - -import argparse -import base64 -import copy -import importlib -import io -import json -from datetime import datetime -from pathlib import Path -from typing import Any - 
-from PIL import Image -from tqdm import tqdm -from vllm_omni import Omni - -from lmms_eval.tasks.vbvr.vbvr_bench import VBVRBench - -try: - from diffusers.utils import export_to_video -except Exception: # pragma: no cover - export_to_video = None - - -FILE_SPLIT_MAP = { - "all": None, - "in_domain": "In-Domain_50", - "out_of_domain": "Out-of-Domain_50", -} - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Run Wan2.2 with vllm-omni on a local VBVR-Bench checkout.") - parser.add_argument("--model", required=True, help="Path to the local Wan2.2 checkpoint") - parser.add_argument("--vbvr-root", required=True, help="Path to the local VBVR-Bench root") - parser.add_argument("--manifest", default=None, help="Optional path to VBVR-Bench.json; defaults to /VBVR-Bench.json") - parser.add_argument("--output-root", required=True, help="Directory for generated videos and metrics") - parser.add_argument("--split", choices=sorted(FILE_SPLIT_MAP), default="all") - parser.add_argument("--limit", type=int, default=None, help="Optional sample limit after split filtering") - parser.add_argument("--tensor-parallel-size", type=int, default=1) - parser.add_argument("--data-parallel-size", type=int, default=1) - parser.add_argument("--gpu-memory-utilization", type=float, default=0.9) - parser.add_argument("--cache-backend", default="cache_dit") - parser.add_argument("--num-inference-steps", type=int, default=50) - parser.add_argument("--guidance-scale", type=float, default=5.0) - parser.add_argument("--guidance-scale-2", type=float, default=None) - parser.add_argument("--num-frames", type=int, default=81) - parser.add_argument("--height", type=int, default=384) - parser.add_argument("--width", type=int, default=384) - parser.add_argument("--fps", type=int, default=16) - parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--boundary-ratio", type=float, default=None) - parser.add_argument("--flow-shift", type=float, 
default=None) - parser.add_argument("--diffusion-batch-size", type=int, default=None) - parser.add_argument("--request-batch-size", type=int, default=None, help="How many samples to submit in one Omni.generate call.") - parser.add_argument("--shard-id", type=int, default=0, help="0-based shard index for process-level sample parallelism.") - parser.add_argument("--num-shards", type=int, default=1, help="Total number of disjoint shards.") - parser.add_argument("--overwrite", action="store_true", help="Regenerate videos even if the output mp4 already exists") - parser.add_argument("--skip-generate", action="store_true", help="Only run VBVR scoring on existing videos; skip generation") - parser.add_argument("--skip-eval", action="store_true", help="Only generate videos; skip VBVR scoring") - parser.add_argument("--task-specific-only", action="store_true", help="Score only VBVR task-specific rules instead of the default weighted aggregate") - parser.add_argument("--run-name", default=None, help="Optional name for the evaluation result JSON") - return parser.parse_args() - - -def decode_base64_image(data: str) -> Image.Image: - payload = data.split(",", 1)[-1] - return Image.open(io.BytesIO(base64.b64decode(payload))).convert("RGB") - - -def parse_task_meta(doc: dict[str, Any]) -> tuple[str, str, str]: - raw = doc.get("first_frame_path") or doc.get("final_frame_path") or doc.get("prompt_path") or doc.get("ground_truth_video_path") or "" - parts = [part for part in str(raw).split("/") if part] - if len(parts) < 3: - raise ValueError(f"Cannot parse VBVR task meta from path: {raw!r}") - return parts[0], parts[1], parts[2] - - -def filtered_docs( - manifest_path: Path, - split: str, - limit: int | None, - shard_id: int, - num_shards: int, -) -> list[dict[str, Any]]: - docs = json.loads(manifest_path.read_text()) - file_split = FILE_SPLIT_MAP[split] - selected: list[dict[str, Any]] = [] - for doc in docs: - doc_file_split, _, _ = parse_task_meta(doc) - if file_split is not 
None and doc_file_split != file_split: - continue - selected.append(doc) - selected.sort(key=lambda doc: parse_task_meta(doc)) - if limit is not None: - selected = selected[:limit] - if num_shards > 1: - selected = [doc for idx, doc in enumerate(selected) if idx % num_shards == shard_id] - return selected - - -def build_sampling_params(omni: Omni, args: argparse.Namespace) -> list[Any]: - sampling_params_list = copy.deepcopy(list(omni.default_sampling_params_list)) - if not sampling_params_list: - raise RuntimeError("vllm-omni returned an empty default_sampling_params_list") - - stage0 = sampling_params_list[0] - boundary_ratio = args.boundary_ratio - if boundary_ratio is None: - boundary_ratio = read_model_index_float(args.model, "boundary_ratio") - - values = { - "num_inference_steps": args.num_inference_steps, - "guidance_scale": args.guidance_scale, - "guidance_scale_2": args.guidance_scale_2, - "num_frames": args.num_frames, - "height": args.height, - "width": args.width, - "fps": args.fps, - "seed": args.seed, - "boundary_ratio": boundary_ratio, - "flow_shift": args.flow_shift, - } - for key, value in values.items(): - if value is not None and hasattr(stage0, key): - setattr(stage0, key, value) - if hasattr(stage0, "guidance_scale_provided"): - stage0.guidance_scale_provided = True - sampling_params_list[0] = stage0 - return sampling_params_list - - -def read_model_index_float(model_path: str, key: str) -> float | None: - model_index_path = Path(model_path).expanduser() / "model_index.json" - if not model_index_path.is_file(): - return None - try: - value = json.loads(model_index_path.read_text()).get(key) - except Exception: - return None - if value is None: - return None - try: - return float(value) - except (TypeError, ValueError): - return None - - -def build_parallel_config(args: argparse.Namespace) -> dict[str, int]: - return { - "pipeline_parallel_size": 1, - "data_parallel_size": args.data_parallel_size, - "tensor_parallel_size": 
args.tensor_parallel_size, - } - - -def to_pil_image(image: Any) -> Image.Image: - if isinstance(image, Image.Image): - return image.convert("RGB") - - import numpy as np - import torch - - if torch.is_tensor(image): - image = image.detach().cpu().numpy() - if not isinstance(image, np.ndarray): - raise TypeError(f"Unsupported image output type: {type(image).__name__}") - if image.ndim == 3 and image.shape[0] in {1, 3, 4} and image.shape[-1] not in {1, 3, 4}: - image = image.transpose(1, 2, 0) - if image.dtype != np.uint8: - image = np.clip(image, 0, 1) * 255 if image.max() <= 1.0 else np.clip(image, 0, 255) - image = image.astype(np.uint8) - return Image.fromarray(image).convert("RGB") - - -def normalize_video_frames(frames: Any) -> list[Any]: - import numpy as np - import torch - - if isinstance(frames, list): - if not frames: - return [] - normalized: list[Any] = [] - for item in frames: - normalized.extend(normalize_video_frames(item)) - return normalized - if torch.is_tensor(frames): - frames = frames.detach().cpu().numpy() - if isinstance(frames, np.ndarray): - if frames.ndim == 5 and frames.shape[0] == 1: - return normalize_video_frames(frames[0]) - if frames.ndim == 4: - return [frames[i] for i in range(frames.shape[0])] - if frames.ndim == 3: - return [frames] - return [frames] - - -def save_video(frames: Any, output_path: Path, fps: int) -> None: - output_path.parent.mkdir(parents=True, exist_ok=True) - pil_frames = [to_pil_image(frame) for frame in normalize_video_frames(frames)] - if export_to_video is not None: - export_to_video(pil_frames, output_video_path=str(output_path), fps=fps) - return - - imageio_v2 = importlib.import_module("imageio.v2") - imageio_v2.mimsave(str(output_path), pil_frames, fps=fps) - - -def chunked(items: list[dict[str, Any]], size: int) -> list[list[dict[str, Any]]]: - return [items[idx : idx + size] for idx in range(0, len(items), size)] - - -def generation_failures_path(metrics_root: Path, shard_id: int, num_shards: int) -> 
Path: - if num_shards <= 1: - return metrics_root / "generation_failures.json" - return metrics_root / f"generation_failures_shard_{shard_id:02d}_of_{num_shards:02d}.json" - - -def generate_videos( - args: argparse.Namespace, - docs: list[dict[str, Any]], - videos_root: Path, - metrics_root: Path, -) -> None: - failures: list[dict[str, Any]] = [] - failures_path = generation_failures_path(metrics_root, args.shard_id, args.num_shards) - if not docs: - failures_path.write_text(json.dumps(failures, indent=2)) - print("Generation failures: 0") - print(f"Failure log: {failures_path}") - return - - omni = Omni( - model=args.model, - parallel_config=build_parallel_config(args), - gpu_memory_utilization=args.gpu_memory_utilization, - trust_remote_code=True, - cache_backend=args.cache_backend, - diffusion_batch_size=args.diffusion_batch_size, - ) - - try: - sampling_params_list = build_sampling_params(omni, args) - batches = chunked(docs, args.request_batch_size) - desc = f"Generating shard {args.shard_id + 1}/{args.num_shards}" - for batch_docs in tqdm(batches, desc=desc, dynamic_ncols=True): - request_batch: list[dict[str, Any]] = [] - meta_batch: list[tuple[str, str, str, Path]] = [] - for doc in batch_docs: - file_split, task_name, video_idx = parse_task_meta(doc) - output_path = videos_root / file_split / task_name / f"{video_idx}.mp4" - if output_path.exists() and not args.overwrite: - continue - try: - prompt = str(doc.get("prompt") or "").strip() - image = decode_base64_image(str(doc["first_image"])) - request_batch.append({"prompt": prompt, "multi_modal_data": {"image": image}}) - meta_batch.append((file_split, task_name, video_idx, output_path)) - except Exception as e: # noqa: BLE001 - failures.append( - { - "file_split": file_split, - "task_name": task_name, - "video_idx": video_idx, - "output_path": str(output_path), - "error": f"{type(e).__name__}: {e}", - } - ) - - if not request_batch: - continue - - prompts: list[dict[str, Any]] | dict[str, Any] - prompts = 
request_batch if len(request_batch) > 1 else request_batch[0] - try: - outputs = omni.generate(prompts, sampling_params_list=sampling_params_list, use_tqdm=False) - except Exception as e: # noqa: BLE001 - for file_split, task_name, video_idx, output_path in meta_batch: - failures.append( - { - "file_split": file_split, - "task_name": task_name, - "video_idx": video_idx, - "output_path": str(output_path), - "error": f"{type(e).__name__}: {e}", - } - ) - continue - - if len(outputs) != len(meta_batch): - error_text = f"Expected {len(meta_batch)} outputs, got {len(outputs)}" - for file_split, task_name, video_idx, output_path in meta_batch: - failures.append( - { - "file_split": file_split, - "task_name": task_name, - "video_idx": video_idx, - "output_path": str(output_path), - "error": error_text, - } - ) - continue - - for (file_split, task_name, video_idx, output_path), result in zip(meta_batch, outputs): - try: - if getattr(result, "error", None): - raise RuntimeError(str(result.error)) - frames = getattr(result, "images", None) - if frames is None or (isinstance(frames, list) and not frames): - raise RuntimeError("Omni returned no image frames") - save_video(frames, output_path, fps=args.fps) - except Exception as e: # noqa: BLE001 - failures.append( - { - "file_split": file_split, - "task_name": task_name, - "video_idx": video_idx, - "output_path": str(output_path), - "error": f"{type(e).__name__}: {e}", - } - ) - finally: - omni.close() - - failures_path.write_text(json.dumps(failures, indent=2)) - print(f"Generation failures: {len(failures)}") - print(f"Failure log: {failures_path}") - - -def main() -> None: - args = parse_args() - if args.diffusion_batch_size is None: - args.diffusion_batch_size = max(1, args.data_parallel_size) - if args.request_batch_size is None: - args.request_batch_size = max(1, args.data_parallel_size) - if args.num_shards < 1: - raise SystemExit("--num-shards must be >= 1") - if args.shard_id < 0 or args.shard_id >= args.num_shards: - 
raise SystemExit("--shard-id must be in [0, num_shards)") - if args.request_batch_size < 1: - raise SystemExit("--request-batch-size must be >= 1") - - vbvr_root = Path(args.vbvr_root).expanduser().resolve() - manifest_path = Path(args.manifest).expanduser().resolve() if args.manifest else vbvr_root / "VBVR-Bench.json" - output_root = Path(args.output_root).expanduser().resolve() - videos_root = output_root / "videos" - metrics_root = output_root / "metrics" - videos_root.mkdir(parents=True, exist_ok=True) - metrics_root.mkdir(parents=True, exist_ok=True) - - docs = filtered_docs(manifest_path, args.split, args.limit, args.shard_id, args.num_shards) - - print(f"Using manifest: {manifest_path}") - print(f"Using GT root: {vbvr_root}") - print(f"Output root: {output_root}") - print(f"Shard: {args.shard_id + 1}/{args.num_shards}") - print(f"Samples: {len(docs)}") - - if not args.skip_generate: - generate_videos(args, docs, videos_root, metrics_root) - - if args.skip_eval: - return - - file_split = FILE_SPLIT_MAP[args.split] - run_name = args.run_name or datetime.now().strftime("%Y%m%d_%H%M%S") - bench = VBVRBench(gt_base_path=str(vbvr_root), output_path=str(metrics_root)) - bench.evaluate(str(videos_root), name=run_name, split=file_split, task_specific_only=args.task_specific_only) - - -if __name__ == "__main__": - main()