From 165814258b2bfc2d3c940ebf50269962b30ca705 Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Fri, 6 Feb 2026 15:01:44 +0800 Subject: [PATCH] feature(sunjx): add rejection sampling in grm_training --- .../grm_training/rejection_sampling/README.md | 325 ++++++++++++ .../convert_to_rejection_sampling_data_t2i.py | 163 ++++++ .../convert_to_rejection_sampling_data_t2v.py | 155 ++++++ .../rejection_sampling_inference_t2i.py | 370 ++++++++++++++ .../rejection_sampling_inference_t2v.py | 477 ++++++++++++++++++ .../run_rejection_sampling_t2i.sh | 151 ++++++ .../run_rejection_sampling_t2v.sh | 155 ++++++ .../train_rejection_sampling_mix.sh | 130 +++++ .../train_rejection_sampling_t2i.sh | 134 +++++ .../train_rejection_sampling_t2v.sh | 138 +++++ 10 files changed, 2198 insertions(+) create mode 100644 examples/grm_training/rejection_sampling/README.md create mode 100644 examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data_t2i.py create mode 100644 examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data_t2v.py create mode 100644 examples/grm_training/rejection_sampling/rejection_sampling_inference_t2i.py create mode 100644 examples/grm_training/rejection_sampling/rejection_sampling_inference_t2v.py create mode 100644 examples/grm_training/rejection_sampling/run_rejection_sampling_t2i.sh create mode 100755 examples/grm_training/rejection_sampling/run_rejection_sampling_t2v.sh create mode 100644 examples/grm_training/rejection_sampling/train_rejection_sampling_mix.sh create mode 100755 examples/grm_training/rejection_sampling/train_rejection_sampling_t2i.sh create mode 100755 examples/grm_training/rejection_sampling/train_rejection_sampling_t2v.sh diff --git a/examples/grm_training/rejection_sampling/README.md b/examples/grm_training/rejection_sampling/README.md new file mode 100644 index 00000000..7239d2f4 --- /dev/null +++ b/examples/grm_training/rejection_sampling/README.md @@ -0,0 +1,325 @@ +# Rejection Sampling for GRM Training + +This directory contains scripts and tools for preparing rejection sampling training data and training GRM (Generative Reward Model) models on both text-to-image (T2I) and text-to-video (T2V) tasks. + +## Overview + +Rejection sampling is a technique to filter high-quality training samples by: +1. Running inference on a dataset using a trained GRM model +2. Filtering correctly predicted samples (where model prediction matches ground truth) +3. Converting filtered samples into training format with Chain-of-Thought (CoT) reasoning +4. Training the model on these high-quality filtered samples + +## Directory Structure + +``` +rejection_sampling/ +├── README.md # This file +├── run_rejection_sampling_t2i.sh # T2I rejection sampling data preparation +├── run_rejection_sampling_t2v.sh # T2V rejection sampling data preparation +├── rejection_sampling_inference_t2i.py # T2I inference and filtering script +├── rejection_sampling_inference_t2v.py # T2V inference and filtering script +├── convert_to_rejection_sampling_data_t2i.py # Convert T2I filtered samples to training format +├── convert_to_rejection_sampling_data_t2v.py # Convert T2V filtered samples to training format +├── train_rejection_sampling_t2i.sh # Train GRM on T2I rejection sampling data +├── train_rejection_sampling_t2v.sh # Train GRM on T2V rejection sampling data +└── train_rejection_sampling_mix.sh # Train GRM on mixed T2I + T2V data +``` + +## Workflow + +### Text-to-Image (T2I) Rejection Sampling + +#### Step 1: Data Preparation + +Run the rejection sampling data preparation pipeline: + +```bash +bash run_rejection_sampling_t2i.sh +``` + +**Configuration** (edit the script before running): +- `MODEL_PATH`: Path to your pre-trained GRM model +- `DATA_PATH`: Dataset path in format `"source:path"` (e.g., `"hpdv3:path/to/dataset.json"`) +- `DATA_ROOT`: Root directory of the dataset (for resolving image paths) +- `OUTPUT_DIR`: Directory to save filtered samples and training data +- `INFERENCE_BATCH_SIZE`: Batch size for inference (default: 8) +- `MAX_NEW_TOKENS`: Maximum tokens to generate (default: 2048) +- `TASK_INSTRUCTION`: CoT instruction template for the model + +This script will: +1. Run inference on the dataset using vLLM +2. Filter correctly predicted samples +3. Convert filtered samples to training format +4. Save the results to `${OUTPUT_DIR}/rejection_sampling_train.json` + +#### Step 2: Training + +Train the model on the filtered rejection sampling data: + +```bash +bash train_rejection_sampling_t2i.sh +``` + +**Configuration** (edit the script before running): +- `MODEL_PATH`: Path to your pre-trained model +- `TRAINING_DATA_PATH`: Path to the training data generated in Step 1 +- `OUTPUT_DIR`: Directory to save model checkpoints +- `LOG_DIR`: Directory to save training logs +- Training hyperparameters: `TBS`, `LR`, `MAX_LENGTH`, `MAX_EPOCHS`, etc. + +### Text-to-Video (T2V) Rejection Sampling + +#### Step 1: Data Preparation + +Run the T2V rejection sampling data preparation pipeline: + +```bash +bash run_rejection_sampling_t2v.sh +``` + +**Configuration** (edit the script before running): +- `MODEL_PATH`: Path to your pre-trained GRM model +- `DATA_PATH`: Array of dataset paths in format `"rapidata-t2v:path/to/dataset.parquet"` +- `OUTPUT_DIR`: Directory to save filtered samples and training data +- `INFERENCE_BATCH_SIZE`: Batch size for inference (default: 8) +- `MAX_NEW_TOKENS`: Maximum tokens to generate (default: 2048) +- `VIDEO_FPS`: Video frames per second for processing (default: 2.0) +- `TASK_INSTRUCTION`: CoT instruction template for video evaluation + +This script will: +1. Run inference on video datasets using vLLM +2. Filter correctly predicted samples based on video quality scores +3. Convert filtered samples to training format +4. Save the results to `${OUTPUT_DIR}/rejection_sampling_train.json` + +#### Step 2: Training + +Train the model on the filtered T2V rejection sampling data: + +```bash +bash train_rejection_sampling_t2v.sh +``` + +**Configuration** (edit the script before running): +- `MODEL_PATH`: Path to your pre-trained model +- `TRAINING_DATA_PATH`: Path to the training data generated in Step 1 +- `OUTPUT_DIR`: Directory to save model checkpoints +- `LOG_DIR`: Directory to save training logs +- `VIDEO_FPS`: Video FPS (must match the value used in data preparation) +- Training hyperparameters: `TBS`, `LR`, `MAX_LENGTH`, `MAX_EPOCHS`, etc. + +### Mixed Training (T2I + T2V) + +Train a single GRM model on both image and video rejection sampling data: + +```bash +bash train_rejection_sampling_mix.sh +``` + +**Configuration** (edit the script before running): +- `MODEL_PATH`: Path to your pre-trained model +- `T2I_TRAINING_DATA_PATH`: Path to T2I rejection sampling training data +- `T2V_TRAINING_DATA_PATH`: Path to T2V rejection sampling training data +- `OUTPUT_DIR`: Directory to save model checkpoints +- `LOG_DIR`: Directory to save training logs +- Training hyperparameters: `TBS`, `LR`, `MAX_LENGTH`, `MAX_EPOCHS`, etc. + +## Python Scripts + +### Inference Scripts + +#### `rejection_sampling_inference_t2i.py` + +Performs inference on T2I datasets and filters correctly predicted samples. + +**Usage:** + +```bash +python rejection_sampling_inference_t2i.py \ + --model_path path/to/model \ + --data_path "hpdv3:path/to/dataset.json" \ + --output_path path/to/filtered_samples.json \ + --batch_size 8 \ + --max_new_tokens 2048 \ + --use_cot \ + --task_instruction "Your task instruction here" \ + --tensor_parallel_size 2 \ + --gpu_memory_utilization 0.9 +``` + +**Key Features:** +- Uses vLLM for efficient inference +- Supports multi-GPU inference via tensor parallelism +- Extracts CoT reasoning from generated text +- Filters samples where model prediction matches ground truth +- Saves filtered samples and statistics + +#### `rejection_sampling_inference_t2v.py` + +Performs inference on T2V datasets and filters correctly predicted samples. + +**Usage:** + +```bash +python rejection_sampling_inference_t2v.py \ + --model_path path/to/model \ + --data_path "rapidata-t2v:path/to/dataset1.parquet,rapidata-t2v:path/to/dataset2.parquet" \ + --output_path path/to/filtered_samples.json \ + --batch_size 8 \ + --max_new_tokens 2048 \ + --use_cot \ + --task_instruction "Your task instruction here" \ + --tensor_parallel_size 2 \ + --gpu_memory_utilization 0.9 \ + --video_fps 2.0 +``` + +**Key Features:** +- Supports multiple T2V datasets (comma-separated) +- Computes ground truth preference based on video quality scores (Alignment + Coherence + Preference) +- Processes videos at specified FPS +- Extracts CoT reasoning from generated text + +### Conversion Scripts + +#### `convert_to_rejection_sampling_data_t2i.py` + +Converts filtered T2I samples to training format. + +**Usage:** + +```bash +python convert_to_rejection_sampling_data_t2i.py \ + --filtered_samples_path path/to/filtered_samples.json \ + --output_path path/to/training_data.json \ + --data_root path/to/dataset/root \ + --task_instruction "Your task instruction here" +``` + +**Output Format:** + +```json +{ + "conversations": [ + { + "from": "human", + "value": "Task instruction with prompt" + }, + { + "from": "gpt", + "value": "\nReasoning process...\n\nImage 1 is better" + } + ], + "images": [ + "path/to/image1.jpg", + "path/to/image2.jpg" + ] +} +``` + +#### `convert_to_rejection_sampling_data_t2v.py` + +Converts filtered T2V samples to training format. + +**Usage:** + +```bash +python convert_to_rejection_sampling_data_t2v.py \ + --filtered_samples_path path/to/filtered_samples.json \ + --output_path path/to/training_data.json \ + --task_instruction "Your task instruction here" \ + --video_fps 2.0 +``` + +**Output Format:** + +```json +{ + "conversations": [ + { + "from": "human", + "value": "Task instruction with prompt" + }, + { + "from": "gpt", + "value": "\nReasoning process...\n\nVideo 1 is better" + } + ], + "images": [ + "path/to/video1.mp4", + "path/to/video2.mp4" + ], + "video_fps": 2.0 +} +``` + +## Environment Variables + +The scripts support the following environment variables: + +- `GPUS_PER_NODE`: Number of GPUs to use (default: 2) +- `NNODES`: Number of nodes (default: 1) +- `NODE_RANK`: Current node rank (default: 0) +- `MASTER_ADDR`: Master node address (default: "localhost") +- `MASTER_PORT`: Master port (default: 29500) + +## Requirements + +- PyTorch +- vLLM +- Transformers +- LightRFT +- loguru +- tqdm + +## Task Instructions + +### T2I Task Instruction + +The default T2I task instruction asks the model to: +1. Evaluate two images on multiple dimensions (semantic consistency, aesthetics, authenticity) +2. Provide scores (1-10) for each dimension with rationale +3. Calculate total scores by summing dimension scores +4. Use Chain-of-Thought reasoning within `` tags +5. Output the final answer in `` tags + +### T2V Task Instruction + +The default T2V task instruction asks the model to: +1. Evaluate two videos on multiple dimensions (semantic consistency, temporal coherence, authenticity) +2. Provide scores (1-10) for each dimension with rationale +3. Calculate total scores by summing dimension scores +4. Use Chain-of-Thought reasoning within `` tags +5. Output the final answer in `` tags + +## Notes + +- The inference scripts use vLLM for efficient batched inference +- Tensor parallelism is used for multi-GPU inference +- Ground truth preferences are determined from dataset annotations +- Only correctly predicted samples are used for training +- The training format is compatible with the GRM training pipeline + +## Troubleshooting + +### Out of Memory (OOM) Issues + +1. Reduce `INFERENCE_BATCH_SIZE` or `MICRO_BATCH_SIZE` +2. Reduce `gpu_memory_utilization` +3. Enable gradient checkpointing (already enabled in training scripts) +4. Reduce `MAX_LENGTH` or `prompt_max_len` + +### Low Accuracy + +1. Check if the task instruction matches the training instruction +2. Verify that the model is properly trained on similar tasks +3. Check if the dataset format is correct +4. Review sample predictions in the generated text + +### Dataset Loading Issues + +1. Verify dataset paths are correct +2. Check dataset format (JSON for T2I, Parquet for T2V) +3. Ensure `DATA_ROOT` is set correctly for T2I +4. Check image/video file paths in the dataset diff --git a/examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data_t2i.py b/examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data_t2i.py new file mode 100644 index 00000000..c36e36b6 --- /dev/null +++ b/examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data_t2i.py @@ -0,0 +1,163 @@ +""" +Convert filtered samples to rejection sampling training data format. + +This script converts the filtered correct samples into the format required +for rejection sampling training, similar to imagegen-cot-reward dataset. +""" + +import os +import json +import argparse +from typing import List, Dict +from loguru import logger + + +def convert_to_rejection_sampling_format( + filtered_samples_path: str, + output_path: str, + data_root: str, + task_instruction_template: str = None, +): + """ + Convert filtered samples to rejection sampling training format. + + :param filtered_samples_path: Path to filtered samples JSON file + :type filtered_samples_path: str + :param output_path: Path to save converted training data + :type output_path: str + :param data_root: Root directory of the dataset (for image paths) + :type data_root: str + :param task_instruction_template: Template for task instruction + :type task_instruction_template: str, optional + :return: List of training data items in imagegen-cot-reward format + :rtype: List[Dict] + """ + logger.info(f"Loading filtered samples from {filtered_samples_path}") + + with open(filtered_samples_path, 'r', encoding='utf-8') as f: + filtered_samples = json.load(f) + + logger.info(f"Loaded {len(filtered_samples)} filtered samples") + + # Default task instruction template + if task_instruction_template is None: + task_instruction_template = """Given a caption and two images generated based on this caption, please analyze in detail the two provided images. +Evaluate them on various dimensions such as semantic consistency (how closely the image content aligns with the caption), +aesthetics (composition, color usage, artistic expression), authenticity (realism and attention to detail), +and any other factors you deem relevant. For each evaluation dimension, +provide a score between 1-10 for both images (e.g., Image 1: 8/10, Image 2: 6/10) and provide a concise rationale for the score. +Calculate the total score for each image by summing all dimension scores. +Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within and tags. +Then, in the tag, output exactly one of the following strings: 'Image 1 is better' or 'Image 2 is better' based on the total scores. +No additional text is allowed in the section. +Example output format: + +Semantic consistency: Image 1 (9/10) - ...; Image 2 (7/10) - ... +Aesthetics: Image 2 (8/10) - ...; Image 1 (8/10) - ... +Authenticity: Image 1 (8/10) - ...; Image 2 (5/10) - ... +[Additional dimensions if any]: Image 2 (8/10) - ...; Image 1 (6/10) - ... +Total score: +Image 1: 9+8+8+6=31 +Image 2: 7+8+5+8=28 + +Image 1 is better +Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given images. +Your task is provided as follows: +Text Caption: {prompt}""" + + training_data = [] + + for idx, sample in enumerate(filtered_samples): + prompt = sample['prompt'] + path1 = sample['path1'] + path2 = sample['path2'] + preference = sample['preference'] + generated_text = sample.get('generated_text', '') + reasoning = sample.get('reasoning', '') + + # Determine which image is better based on preference + # In HPDv3, path1 is the preferred path, path2 is the rejected path + # preference "A" means Image 1 (which is path1) is better + # preference "B" means Image 2 (which is path2) is better, but this means path2 was randomly chosen as Image 1 + # Actually, in HPDv3GRMHandler, preference A means image0 (first shown) is preferred + # So we need to check which path corresponds to which image + + # Since we stored preferred_path and rejected_path, we know: + # - preferred_path (path1) should be the better one + # - rejected_path (path2) should be the worse one + # But the handler randomly assigns them to Image 1 or Image 2 + + # For training data, we always use: Image 1 = preferred, Image 2 = rejected + # This ensures consistency + answer = "Image 1 is better" if preference == "A" else "Image 2 is better" + image1_path = path1 # preferred + image2_path = path2 # rejected + + # Build the response with CoT reasoning + # Note: We use instead of to match the instruction format + if reasoning: + # Use the extracted reasoning from inference + # Clean up the reasoning text + reasoning_clean = reasoning.strip() + response = f"\n{reasoning_clean}\n\n{answer}" + else: + # If no reasoning was extracted, create a placeholder + # In practice, you might want to regenerate this or use a template + response = f"\nBased on the evaluation of semantic consistency, aesthetics, and authenticity, I will compare the two images.\n\n{answer}" + + # Build conversations format + task_instruction = task_instruction_template.format(prompt=prompt) + + # Create training data item in imagegen-cot-reward format + training_item = { + "conversations": [ + { + "from": "human", + "value": task_instruction + }, + { + "from": "gpt", + "value": response + } + ], + "images": [ + image1_path if os.path.isabs(image1_path) else os.path.join(data_root, image1_path), + image2_path if os.path.isabs(image2_path) else os.path.join(data_root, image2_path), + ] + } + + training_data.append(training_item) + + if (idx + 1) % 100 == 0: + logger.info(f"Converted {idx + 1}/{len(filtered_samples)} samples") + + # Save training data + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(training_data, f, indent=2, ensure_ascii=False) + + logger.info(f"Saved {len(training_data)} training samples to {output_path}") + + return training_data + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert filtered samples to rejection sampling training format") + parser.add_argument("--filtered_samples_path", type=str, required=True, + help="Path to filtered samples JSON file") + parser.add_argument("--output_path", type=str, required=True, + help="Path to save converted training data") + parser.add_argument("--data_root", type=str, required=True, + help="Root directory of the dataset (for image paths)") + parser.add_argument("--task_instruction", type=str, default=None, + help="Task instruction template (optional)") + + args = parser.parse_args() + + convert_to_rejection_sampling_format( + filtered_samples_path=args.filtered_samples_path, + output_path=args.output_path, + data_root=args.data_root, + task_instruction_template=args.task_instruction, + ) + diff --git a/examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data_t2v.py b/examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data_t2v.py new file mode 100644 index 00000000..08613cf9 --- /dev/null +++ b/examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data_t2v.py @@ -0,0 +1,155 @@ +""" +Convert filtered samples to rejection sampling training data format for T2V. + +This script converts the filtered correct samples into the format required +for rejection sampling training, similar to imagegen-cot-reward dataset but for videos. +""" + +import os +import json +import argparse +from typing import List, Dict +from loguru import logger + + +def convert_to_rejection_sampling_format( + filtered_samples_path: str, + output_path: str, + task_instruction_template: str = None, + video_fps: float = 2.0, +): + """ + Convert filtered samples to rejection sampling training format for T2V. + + :param filtered_samples_path: Path to filtered samples JSON file + :type filtered_samples_path: str + :param output_path: Path to save converted training data + :type output_path: str + :param task_instruction_template: Template for task instruction + :type task_instruction_template: str, optional + :param video_fps: FPS for video processing + :type video_fps: float + :return: List of training data items in imagegen-cot-reward format (for videos) + :rtype: List[Dict] + """ + logger.info(f"Loading filtered samples from {filtered_samples_path}") + + with open(filtered_samples_path, 'r', encoding='utf-8') as f: + filtered_samples = json.load(f) + + logger.info(f"Loaded {len(filtered_samples)} filtered samples") + + # Default task instruction template for T2V + if task_instruction_template is None: + task_instruction_template = """Given a caption and two videos generated based on this caption, please analyze in detail the two provided videos. +Evaluate them on various dimensions such as semantic consistency (how closely the video content aligns with the caption), temporal coherence (smoothness and logical flow of motion across frames), authenticity (realism and attention to detail), and any other factors you deem relevant. +For each evaluation dimension, provide a score between 1-10 for both videos (e.g., Video 1: 8/10, Video 2: 6/10) and provide a concise rationale for the score. +Calculate the total score for each video by summing all dimension scores. +Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within and tags. Then, in the tag, output exactly one of the following strings: +'Video 1 is better' or 'Video 2 is better' based on the total scores. No additional text is allowed in the section. +Example output format: + +1. Semantic consistency: Video 1 (9/10) - ...; Video 2 (7/10) - ... +2. Temporal coherence: Video 1 (8/10) - ...; Video 2 (6/10) - ... +3. Authenticity: Video 1 (7/10) - ...; Video 2 (5/10) - ... +... +[Additional dimensions if any]: Video 2 (8/10) - ...; Video 1 (6/10) - ... +Total score: +Video 1: 9+8+7+6=30 +Video 2: 7+6+5+8=26 + +Video 1 is better + +Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given videos. +Your task is provided as follows: +Text Caption: {prompt}""" + + training_data = [] + + for idx, sample in enumerate(filtered_samples): + prompt = sample['prompt'] + path1 = sample['path1'] + path2 = sample['path2'] + preference = sample['preference'] + generated_text = sample.get('generated_text', '') + reasoning = sample.get('reasoning', '') + + # Determine which video is better based on preference + # preference "A" means Video 1 is better + # preference "B" means Video 2 is better + # For training data, we always use: Video 1 = preferred, Video 2 = rejected + # This ensures consistency + answer = "Video 1 is better" if preference == "A" else "Video 2 is better" + video1_path = path1 if preference == "A" else path2 # preferred + video2_path = path2 if preference == "A" else path1 # rejected + + # Build the response with CoT reasoning + if reasoning: + # Use the extracted reasoning from inference + # Clean up the reasoning text + reasoning_clean = reasoning.strip() + response = f"\n{reasoning_clean}\n\n{answer}" + else: + # If no reasoning was extracted, create a placeholder + response = f"\nBased on the evaluation of semantic consistency, temporal coherence, and authenticity, I will compare the two videos.\n\n{answer}" + + # Build conversations format + task_instruction = task_instruction_template.format(prompt=prompt) + + # Create training data item in imagegen-cot-reward format (but for videos) + # We use "images" field name to be compatible with ImageGenCoTRewardHandler + # but store video paths - the handler will need to be modified to support videos + # For now, we store relative paths from data_root + training_item = { + "conversations": [ + { + "from": "human", + "value": task_instruction + }, + { + "from": "gpt", + "value": response + } + ], + "images": [ + video1_path if os.path.isabs(video1_path) else video1_path, + video2_path if os.path.isabs(video2_path) else video2_path, + ], + "video_fps": video_fps, # Store FPS for video processing + } + + training_data.append(training_item) + + if (idx + 1) % 100 == 0: + logger.info(f"Converted {idx + 1}/{len(filtered_samples)} samples") + + # Save training data + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(training_data, f, indent=2, ensure_ascii=False) + + logger.info(f"Saved {len(training_data)} training samples to {output_path}") + + return training_data + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert filtered samples to rejection sampling training format for T2V") + parser.add_argument("--filtered_samples_path", type=str, required=True, + help="Path to filtered samples JSON file") + parser.add_argument("--output_path", type=str, required=True, + help="Path to save converted training data") + parser.add_argument("--task_instruction", type=str, default=None, + help="Task instruction template (optional)") + parser.add_argument("--video_fps", type=float, default=2.0, + help="FPS for video processing") + + args = parser.parse_args() + + convert_to_rejection_sampling_format( + filtered_samples_path=args.filtered_samples_path, + output_path=args.output_path, + task_instruction_template=args.task_instruction, + video_fps=args.video_fps, + ) + diff --git a/examples/grm_training/rejection_sampling/rejection_sampling_inference_t2i.py b/examples/grm_training/rejection_sampling/rejection_sampling_inference_t2i.py new file mode 100644 index 00000000..9203996a --- /dev/null +++ b/examples/grm_training/rejection_sampling/rejection_sampling_inference_t2i.py @@ -0,0 +1,370 @@ +""" +Rejection Sampling Inference Script + +This script performs inference on a dataset using a trained GRM model, +filters out correctly predicted samples, and generates training data +with CoT reasoning for rejection sampling training. +""" + +import os +import json +import argparse +import re +from tqdm import tqdm +from typing import List, Dict +from loguru import logger +from torch.utils.data import DataLoader + +from transformers import AutoProcessor, AutoTokenizer +from vllm import LLM, SamplingParams +from lightrft.datasets import RFTDatasetVL, extract_answer +from lightrft.datasets.hpdv3 import HPDv3GRMHandler + + +def extract_response(text: str, media_type: str = "Image") -> str: + """ + Extract the preference from the generated text. + + It first tries to extract the content from ```` tags using :func:`extract_answer`. + If no tags are found, it performs a heuristic search for key phrases (e.g., "Image 1 is better") + at the end of the text. + + :param text: The generated text from the model + :type text: str + :param media_type: The type of media being evaluated ("Image", "Video", or "Audio"), defaults to "Image" + :type media_type: str, optional + + :return: The extracted preference string (e.g., "Image 1 is better") or None if not found + :rtype: str + """ + # 1. Try extracting from tags + ans = extract_answer(text) + if ans: + return ans + + # 2. Heuristic search if no tags found + text_lower = text.lower() + media_lower = media_type.lower() + + key_1 = f"{media_lower} 1 is better" + key_2 = f"{media_lower} 2 is better" + key_equal = f"both {media_lower}s are equally good" + + idx_1 = text_lower.rfind(key_1) + idx_2 = text_lower.rfind(key_2) + idx_equal = text_lower.rfind(key_equal) + + candidates = {} + if idx_1 != -1: + candidates[f"{media_type} 1 is better"] = idx_1 + if idx_2 != -1: + candidates[f"{media_type} 2 is better"] = idx_2 + if idx_equal != -1: + candidates[f"Both {media_lower}s are equally good"] = idx_equal + + if not candidates: + return None + + # Return the one that appears last in the text + return max(candidates, key=candidates.get) + + +TASK_INSTRUCTION_COT = """Given a caption and two images generated based on this caption, please analyze in detail the two provided images. +Evaluate them on various dimensions such as semantic consistency (how closely the image content aligns with the caption), +aesthetics (composition, color usage, artistic expression), authenticity (realism and attention to detail), +and any other factors you deem relevant. For each evaluation dimension, +provide a score between 1-10 for both images (e.g., Image 1: 8/10, Image 2: 6/10) and provide a concise rationale for the score. +Calculate the total score for each image by summing all dimension scores. +Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within and tags. +Then, in the tag, output exactly one of the following strings: 'Image 1 is better' or 'Image 2 is better' based on the total scores. +No additional text is allowed in the section. +Example output format: + +Semantic consistency: Image 1 (9/10) - ...; Image 2 (7/10) - ... +Aesthetics: Image 2 (8/10) - ...; Image 1 (8/10) - ... +Authenticity: Image 1 (8/10) - ...; Image 2 (5/10) - ... +[Additional dimensions if any]: Image 2 (8/10) - ...; Image 1 (6/10) - ... +Total score: +Image 1: 9+8+8+6=31 +Image 2: 7+8+5+8=28 + +Image 1 is better +Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given images. +Your task is provided as follows: +Text Caption: {prompt} +""" + + + + +def inference_and_filter( + model_path: str, + data_path: List[str], + output_path: str, + config: dict = None, + batch_size: int = 32, + max_new_tokens: int = 2048, + use_cot: bool = True, + tensor_parallel_size: int = 1, + gpu_memory_utilization: float = 0.9, +): + """ + Perform inference on dataset and filter correctly predicted samples. + + :param model_path: Path to the trained GRM model + :type model_path: str + :param data_path: List of dataset paths in format "source:path" + :type data_path: List[str] + :param output_path: Path to save filtered samples + :type output_path: str + :param config: Configuration dict for dataset + :type config: dict, optional + :param batch_size: Batch size for inference + :type batch_size: int + :param max_new_tokens: Maximum tokens to generate + :type max_new_tokens: int + :param use_cot: Whether to use CoT instruction (for generating reasoning) + :type use_cot: bool + :param tensor_parallel_size: Number of GPUs for tensor parallelism + :type tensor_parallel_size: int + :param gpu_memory_utilization: GPU memory utilization ratio + :type gpu_memory_utilization: float + :return: List of correctly predicted samples with their generated text and reasoning + :rtype: List[Dict] + """ + logger.info(f"Loading model from: {model_path}") + + # Initialize vLLM + llm = LLM( + model=model_path, + tensor_parallel_size=tensor_parallel_size, + trust_remote_code=True, + gpu_memory_utilization=gpu_memory_utilization, + limit_mm_per_prompt={ + "image": 2, + "video": 2 + }, + ) + + sampling_params = SamplingParams( + temperature=0.0, # For deterministic output + max_tokens=max_new_tokens, + ) + + logger.info(f"Model loaded successfully from {model_path}.") + + # Load Processor and Tokenizer for Dataset + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + # Load Dataset + dataset = RFTDatasetVL( + data_path, + processor=processor, + tokenizer=tokenizer, + strategy=None, + max_length=8192, + config=config, + is_train=False, + ) + + # Fix handler mapping: RFTDatasetVL uses HPDv3PairHandler which returns 3 values, + # but we need HPDv3GRMHandler which returns 2 values for compatibility + for source in dataset.handlers.keys(): + if source == "hpdv3": + dataset.handlers[source] = HPDv3GRMHandler() + logger.info(f"Replaced handler for {source} with HPDv3GRMHandler for compatibility") + + data_loader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + drop_last=False, + collate_fn=dataset.collate_fn, + ) + + logger.info(f"Starting inference with CoT: {use_cot}, batch_size: {batch_size}") + + correct_samples = [] + total_samples = 0 + correct_count = 0 + parse_failures = 0 + + for batch_idx, batch in enumerate(tqdm(data_loader)): + try: + input_texts, image_inputs_list, video_inputs_list, extras, _ = batch + + # Prepare inputs for vLLM + inputs = [] + for i in range(len(input_texts)): + prompt = input_texts[i] + image_inputs = image_inputs_list[i] + video_inputs = video_inputs_list[i] + + mm_data = {} + if image_inputs is not None: + mm_data["image"] = image_inputs + if video_inputs is not None: + mm_data["video"] = video_inputs + + inputs.append({ + "prompt": prompt, + "multi_modal_data": mm_data + }) + + # Generate with vLLM + outputs = llm.generate(inputs, sampling_params=sampling_params) + + # Decode + gen_texts = [output.outputs[0].text for output in outputs] + + # Evaluate and filter + batch_correct = 0 + batch_total = len(gen_texts) + for i, (gen_text, extra) in enumerate(zip(gen_texts, extras)): + total_samples += 1 + predicted_answer = extract_response(gen_text, media_type="Image") + gt_preference = extra['preference'] # A or B + + # Mapping logic: + # In HPDv3GRMHandler, preference "A" means Image 1 (first shown) is preferred + # preference "B" means Image 2 (second shown) is preferred + # But the handler randomly swaps images, so we need to check the actual mapping + # The handler stores: preferred_path (path1) and rejected_path (path2) + # When preference is "A", image0 (which could be preferred or rejected) is shown as Image 1 + # When preference is "B", image1 (which could be preferred or rejected) is shown as Image 1 + + # Since the handler randomly assigns, we check based on the stored preference + # If gt_preference is "A", it means Image 1 (first shown) is better + # If gt_preference is "B", it means Image 2 (second shown) is better + is_correct = False + if predicted_answer is None: + parse_failures += 1 + logger.warning(f"Could not extract answer from generated text: {gen_text[:200]}...") + elif gt_preference == "A" and predicted_answer == "Image 1 is better": + is_correct = True + elif gt_preference == "B" and predicted_answer == "Image 2 is better": + is_correct = True + + if is_correct: + correct_count += 1 + batch_correct += 1 + # Prepare sample for rejection sampling training + sample = { + "prompt": extra['prompt'], + "path1": extra['preferred_path'], + "path2": extra['rejected_path'], + "preference": gt_preference, + "generated_text": gen_text, + "predicted_answer": predicted_answer, + } + + # If we want to use the generated CoT reasoning, extract it + if use_cot: + # Extract reasoning from generated text + # Try both and tags (in case of different formats) + reasoning_match = None + import re + # Try first (standard format) + if "" in gen_text: + reasoning_pattern = r"(.*?)" + reasoning_match = re.search(reasoning_pattern, gen_text, re.DOTALL) + # Try as fallback + elif "" in gen_text: + reasoning_pattern = r"(.*?)" + reasoning_match = re.search(reasoning_pattern, gen_text, re.DOTALL) + + if reasoning_match: + reasoning = reasoning_match.group(1).strip() + sample["reasoning"] = reasoning + else: + # If no reasoning found, we'll use the full generated text (excluding answer) + # or generate it during training data preparation + # Remove answer part to get reasoning + answer_part = f"{predicted_answer}" if predicted_answer else "" + reasoning_candidate = gen_text.replace(answer_part, "").strip() + sample["reasoning"] = reasoning_candidate if reasoning_candidate else None + + correct_samples.append(sample) + + # Output real-time accuracy after each batch + current_accuracy = correct_count / total_samples if total_samples > 0 else 0.0 + batch_accuracy = batch_correct / batch_total if batch_total > 0 else 0.0 + parse_failure_rate = parse_failures / total_samples if total_samples > 0 else 0.0 + logger.info( + f"Batch {batch_idx + 1} | " + f"Batch Acc: {batch_accuracy*100:.2f}% ({batch_correct}/{batch_total}) | " + f"Overall Acc: {current_accuracy*100:.2f}% ({correct_count}/{total_samples}) | " + f"Parse Fail: {parse_failure_rate*100:.2f}% ({parse_failures}/{total_samples}) | " + f"Filtered: {len(correct_samples)}" + ) + + except Exception as e: + logger.error(f"Error at batch {batch_idx}: {e}") + raise + + # Summary + accuracy = correct_count / total_samples if total_samples > 0 else 0.0 + failure_rate = parse_failures / total_samples if total_samples > 0 else 0.0 + logger.info(f"Inference completed. Accuracy: {accuracy*100:.2f}% ({correct_count}/{total_samples})") + logger.info(f"Parse Failure Rate: {failure_rate*100:.2f}% ({parse_failures}/{total_samples})") + logger.info(f"Filtered {len(correct_samples)} correct samples for rejection sampling training") + + # Save filtered samples + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(correct_samples, f, indent=2, ensure_ascii=False) + + logger.info(f"Saved {len(correct_samples)} correct samples to {output_path}") + + # Save statistics + stats_path = output_path.replace('.json', '_stats.txt') + with open(stats_path, 'w', encoding='utf-8') as f: + f.write(f"Dataset paths: {data_path}\n") + f.write(f"Model path: {model_path}\n") + f.write(f"Total samples: {total_samples}\n") + f.write(f"Correct samples: {correct_count}\n") + f.write(f"Accuracy: {accuracy*100:.2f}%\n") + f.write(f"Parse failures: {parse_failures}\n") + f.write(f"Parse Failure Rate: {failure_rate*100:.2f}%\n") + f.write(f"Filtered samples for training: {len(correct_samples)}\n") + + return correct_samples + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Rejection Sampling Inference") + parser.add_argument("--model_path", type=str, required=True, help="Path to trained GRM model") + parser.add_argument("--data_path", type=str, required=True, help="Dataset path in format 'source:path'") + parser.add_argument("--output_path", type=str, required=True, help="Path to save filtered samples") + parser.add_argument("--batch_size", type=int, default=32, help="Batch size for inference") + parser.add_argument("--max_new_tokens", type=int, default=2048, help="Maximum tokens to generate") + parser.add_argument("--use_cot", action="store_true", default=True, help="Use CoT instruction for reasoning") + parser.add_argument("--task_instruction", type=str, default=TASK_INSTRUCTION_COT, help="Task instruction template") + + # vLLM arguments + parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Number of GPUs for tensor parallelism") + parser.add_argument("--gpu_memory_utilization", type=float, default=0.9, help="GPU memory utilization ratio") + + args = parser.parse_args() + + # Parse data path + data_paths = [args.data_path] if isinstance(args.data_path, str) else args.data_path.split(",") + + config = { + "task_instruction": args.task_instruction, + "name": "rejection_sampling_inference", + } + + inference_and_filter( + model_path=args.model_path, + data_path=data_paths, + output_path=args.output_path, + config=config, + batch_size=args.batch_size, + max_new_tokens=args.max_new_tokens, + use_cot=args.use_cot, + tensor_parallel_size=args.tensor_parallel_size, + gpu_memory_utilization=args.gpu_memory_utilization, + ) + diff --git a/examples/grm_training/rejection_sampling/rejection_sampling_inference_t2v.py b/examples/grm_training/rejection_sampling/rejection_sampling_inference_t2v.py new file mode 100644 index 00000000..65d4e3ae --- /dev/null +++ b/examples/grm_training/rejection_sampling/rejection_sampling_inference_t2v.py @@ -0,0 +1,477 @@ +""" +Rejection Sampling Inference Script for Text-to-Video (T2V) + +This script performs inference on a dataset using a trained GRM model, +filters out correctly predicted samples, and generates training data +with CoT reasoning for rejection sampling training. + +For Rapidata-T2V, we compute gt_preference based on the sum of three dimensions: +Alignment + Coherence + Preference +""" + +import os +import json +import argparse +import re +from tqdm import tqdm +from typing import List, Dict +from loguru import logger +from torch.utils.data import DataLoader + +from transformers import AutoProcessor, AutoTokenizer +from vllm import LLM, SamplingParams +from lightrft.datasets import extract_answer, RFTDatasetVL + +# Import qwen_vl_utils for processing vision info +try: + from qwen_vl_utils import process_vision_info +except ImportError: + try: + from keye_vl_utils import process_vision_info + except ImportError: + raise ImportError("Neither qwen_vl_utils nor keye_vl_utils is available") + + +TASK_INSTRUCTION_COT_T2V = """Given a caption and two videos generated based on this caption, please analyze in detail the two provided videos. +Evaluate them on various dimensions such as semantic consistency (how closely the video content aligns with the caption), temporal coherence (smoothness and logical flow of motion across frames), authenticity (realism and attention to detail), and any other factors you deem relevant. +For each evaluation dimension, provide a score between 1-10 for both videos (e.g., Video 1: 8/10, Video 2: 6/10) and provide a concise rationale for the score. +Calculate the total score for each video by summing all dimension scores. +Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within and tags. Then, in the tag, output exactly one of the following strings: +'Video 1 is better' or 'Video 2 is better' based on the total scores. No additional text is allowed in the section. +Example output format: + +1. Semantic consistency: Video 1 (9/10) - ...; Video 2 (7/10) - ... +2. Temporal coherence: Video 1 (8/10) - ...; Video 2 (6/10) - ... +3. Authenticity: Video 1 (7/10) - ...; Video 2 (5/10) - ... +... +[Additional dimensions if any]: Video 2 (8/10) - ...; Video 1 (6/10) - ... +Total score: +Video 1: 9+8+7+6=30 +Video 2: 7+6+5+8=26 + +Video 1 is better + +Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given videos. +Your task is provided as follows: +Text Caption: **{prompt}** +""" + + +class GRMPromptDatasetVLT2V: + """ + Dataset wrapper for vLLM inference that returns prompts and video paths + instead of tokenized inputs. Adapted for T2V with RFTDatasetVL. + """ + def __init__( + self, + dataset_paths: List[str], + processor: AutoProcessor, + tokenizer: AutoTokenizer, + strategy=None, + max_length: int = 8192, + config: Dict = None, + is_training: bool = False, + ): + self.base_dataset = RFTDatasetVL( + dataset_paths, + processor=processor, + tokenizer=tokenizer, + strategy=strategy, + max_length=max_length, + config=config, + is_train=is_training, + ) + self.processor = processor + self.tokenizer = tokenizer + + def __len__(self): + return len(self.base_dataset) + + def __getitem__(self, idx): + item = self.base_dataset.data[idx] + source = item["source"] + handler = self.base_dataset.handlers[source] + + # Get media info (paths) + media_info = handler.get_media_info(item) + + # Load media content (needed for parse_item) + loaded_content = self.base_dataset.media_content_loader(media_info) + if loaded_content is None: + raise RuntimeError(f"Failed to load media content: {media_info}") + + # Parse item to get messages (returns messages0, messages1, other for PairHandler) + messages0, messages1, other = handler.parse_item(item, loaded_content, self.base_dataset.config) + + # Combine messages0 and messages1 to show both videos in the same conversation + # Similar to HPDv3GRMHandler format: system prompt + Video 1 + Video 2 + messages = [] + + # Add system prompt (from messages0) + if len(messages0) > 0 and messages0[0].get("role") == "system": + messages.append(messages0[0]) + + # Add Video 1 with label + if len(messages0) > 1 and messages0[1].get("role") == "user": + video1_content = messages0[1]["content"] + messages.append({ + "role": "user", + "content": [ + { + "type": "text", + "text": "**Video 1:**" + }, + video1_content[0] if isinstance(video1_content, list) and len(video1_content) > 0 else video1_content + ] + }) + + # Add Video 2 with label (from messages1) + if len(messages1) > 1 and messages1[1].get("role") == "user": + video2_content = messages1[1]["content"] + messages.append({ + "role": "user", + "content": [ + { + "type": "text", + "text": "**Video 2:**" + }, + video2_content[0] if isinstance(video2_content, list) and len(video2_content) > 0 else video2_content + ] + }) + + # Get prompt text (exclude the last assistant message for inference) + messages_for_prompt = messages[:-1] if len(messages) > 0 and messages[-1].get("role") == "assistant" else messages + prompt_text = self.processor.apply_chat_template( + messages_for_prompt, + tokenize=False, + add_generation_prompt=True, + ) + + # Extract video information from messages using process_vision_info + # This is the same way test_grm_vl_vllm.py does it + # process_vision_info returns (image_inputs, video_inputs, video_kwargs) + # but we only need image_inputs and video_inputs for vLLM + image_inputs, video_inputs, _ = process_vision_info( + messages_for_prompt, + return_video_kwargs=True, + ) + + # Store original item for accessing raw scores + other['_raw_item'] = item + + return prompt_text, image_inputs, video_inputs, other + + def collate_fn(self, batch): + input_texts = [] + image_inputs_list = [] + video_inputs_list = [] + extras = [] + + for prompt_text, image_inputs, video_inputs, other in batch: + input_texts.append(prompt_text) + image_inputs_list.append(image_inputs if image_inputs else None) + video_inputs_list.append(video_inputs if video_inputs else None) + extras.append(other) + + return input_texts, image_inputs_list, video_inputs_list, extras + + +def safe_get_score(item: Dict, key: str, default: float = 0.0) -> float: + """ + Safely get score value from item, handling None values. + + :param item: Dictionary containing score values + :param key: Key to look up in the dictionary + :param default: Default value to use if key is missing or value is None + :return: Float score value + """ + value = item.get(key, default) + return default if value is None else float(value) + + +def compute_total_score(item: Dict, video_num: int) -> float: + """ + Compute total score for a video based on three dimensions. + + :param item: Dictionary containing score values + :param video_num: Video number (1 or 2) + :return: Total score (Alignment + Coherence + Preference) + """ + alignment = safe_get_score(item, f"weighted_results{video_num}_Alignment", 0.0) + coherence = safe_get_score(item, f"weighted_results{video_num}_Coherence", 0.0) + preference = safe_get_score(item, f"weighted_results{video_num}_Preference", 0.0) + return alignment + coherence + preference + + +def compute_gt_preference_from_scores(item: Dict) -> str: + """ + Compute ground truth preference based on sum of three dimensions: + Alignment + Coherence + Preference + + Returns "A" if video1 has higher total score, "B" if video2 has higher total score. + """ + total_score1 = compute_total_score(item, 1) + total_score2 = compute_total_score(item, 2) + + if total_score1 > total_score2: + return "A" # Video 1 is better + elif total_score1 < total_score2: + return "B" # Video 2 is better + else: + return "C" # Equal (shouldn't happen often, but handle it) + + +def inference_and_filter( + model_path: str, + data_path: List[str], + output_path: str, + config: dict = None, + batch_size: int = 32, + max_new_tokens: int = 2048, + use_cot: bool = True, + tensor_parallel_size: int = 1, + gpu_memory_utilization: float = 0.9, + video_fps: float = 2.0, +): + """ + Perform inference on dataset and filter correctly predicted samples. + + :param model_path: Path to the trained GRM model + :type model_path: str + :param data_path: List of dataset paths in format "source:path" + :type data_path: List[str] + :param output_path: Path to save filtered samples + :type output_path: str + :param config: Configuration dict for dataset + :type config: dict, optional + :param batch_size: Batch size for inference + :type batch_size: int + :param max_new_tokens: Maximum tokens to generate + :type max_new_tokens: int + :param use_cot: Whether to use CoT instruction (for generating reasoning) + :type use_cot: bool + :param tensor_parallel_size: Number of GPUs for tensor parallelism + :type tensor_parallel_size: int + :param gpu_memory_utilization: GPU memory utilization ratio + :type gpu_memory_utilization: float + :param video_fps: FPS for video processing + :type video_fps: float + :return: List of correctly predicted samples with their generated text and reasoning + :rtype: List[Dict] + """ + logger.info(f"Loading model from: {model_path}") + + # Initialize vLLM + llm = LLM( + model=model_path, + tensor_parallel_size=tensor_parallel_size, + trust_remote_code=True, + gpu_memory_utilization=gpu_memory_utilization, + limit_mm_per_prompt={ + "image": 0, + "video": 2 + }, + ) + + sampling_params = SamplingParams( + temperature=0.0, # For deterministic output + max_tokens=max_new_tokens, + ) + + logger.info(f"Model loaded successfully from {model_path}.") + + # Load Processor and Tokenizer for Dataset + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + # Load Dataset + dataset = GRMPromptDatasetVLT2V( + data_path, + processor=processor, + tokenizer=tokenizer, + strategy=None, + max_length=8192, + config=config, + is_training=False, + ) + + data_loader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + drop_last=False, + collate_fn=dataset.collate_fn, + ) + + logger.info(f"Starting inference with CoT: {use_cot}, batch_size: {batch_size}") + + correct_samples = [] + total_samples = 0 + correct_count = 0 + + for batch_idx, batch in enumerate(tqdm(data_loader)): + try: + input_texts, image_inputs_list, video_inputs_list, extras = batch + + # Prepare inputs for vLLM (same format as test_grm_vl_vllm.py) + inputs = [] + for i in range(len(input_texts)): + prompt = input_texts[i] + image_inputs = image_inputs_list[i] + video_inputs = video_inputs_list[i] + + mm_data = {} + if image_inputs is not None: + mm_data["image"] = image_inputs + if video_inputs is not None: + mm_data["video"] = video_inputs + + inputs.append({ + "prompt": prompt, + "multi_modal_data": mm_data + }) + + # Generate with vLLM + outputs = llm.generate(inputs, sampling_params=sampling_params) + + # Decode + gen_texts = [output.outputs[0].text for output in outputs] + + # Evaluate and filter + for i, (gen_text, extra) in enumerate(zip(gen_texts, extras)): + total_samples += 1 + predicted_answer = extract_answer(gen_text) + + # Get raw item to compute gt_preference from scores + raw_item = extra.get('_raw_item', {}) + gt_preference = compute_gt_preference_from_scores(raw_item) + + # Mapping logic: + # "A" means Video 1 is better + # "B" means Video 2 is better + is_correct = False + if gt_preference == "A" and predicted_answer == "Video 1 is better": + is_correct = True + elif gt_preference == "B" and predicted_answer == "Video 2 is better": + is_correct = True + elif gt_preference == "C": + # Handle tie case (should be rare) + logger.warning(f"Tie detected in sample {total_samples}, skipping") + continue + + if is_correct: + correct_count += 1 + # Get video paths from raw item + data_root = raw_item.get('data_root', '') + video1_path = os.path.join(data_root, "videos", raw_item.get('file_name1', '')) + video2_path = os.path.join(data_root, "videos", raw_item.get('file_name2', '')) + + # Prepare sample for rejection sampling training + sample = { + "prompt": raw_item.get('prompt', ''), + "path1": video1_path, + "path2": video2_path, + "preference": gt_preference, + "generated_text": gen_text, + "predicted_answer": predicted_answer, + "score1_total": compute_total_score(raw_item, 1), + "score2_total": compute_total_score(raw_item, 2), + } + + # If we want to use the generated CoT reasoning, extract it + if use_cot: + # Extract reasoning from generated text + reasoning_match = None + # Try tag first + if "" in gen_text: + reasoning_pattern = r"(.*?)" + reasoning_match = re.search(reasoning_pattern, gen_text, re.DOTALL) + # Try as fallback (in case model uses different format) + elif "" in gen_text: + reasoning_pattern = r"(.*?)" + reasoning_match = re.search(reasoning_pattern, gen_text, re.DOTALL) + + if reasoning_match: + reasoning = reasoning_match.group(1).strip() + sample["reasoning"] = reasoning + else: + # If no reasoning found, use the full generated text (excluding answer) + answer_part = f"{predicted_answer}" if predicted_answer else "" + reasoning_candidate = gen_text.replace(answer_part, "").strip() + sample["reasoning"] = reasoning_candidate if reasoning_candidate else None + + correct_samples.append(sample) + + if total_samples % 100 == 0: + logger.info(f"Processed {total_samples} samples, {correct_count} correct ({correct_count/total_samples*100:.2f}%)") + + except Exception as e: + logger.error(f"Error at batch {batch_idx}: {e}") + import traceback + traceback.print_exc() + raise + + # Summary + accuracy = correct_count / total_samples if total_samples > 0 else 0.0 + logger.info(f"Inference completed. Accuracy: {accuracy*100:.2f}% ({correct_count}/{total_samples})") + logger.info(f"Filtered {len(correct_samples)} correct samples for rejection sampling training") + + # Save filtered samples + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(correct_samples, f, indent=2, ensure_ascii=False) + + logger.info(f"Saved {len(correct_samples)} correct samples to {output_path}") + + # Save statistics + stats_path = output_path.replace('.json', '_stats.txt') + with open(stats_path, 'w', encoding='utf-8') as f: + f.write(f"Dataset paths: {data_path}\n") + f.write(f"Model path: {model_path}\n") + f.write(f"Total samples: {total_samples}\n") + f.write(f"Correct samples: {correct_count}\n") + f.write(f"Accuracy: {accuracy*100:.2f}%\n") + f.write(f"Filtered samples for training: {len(correct_samples)}\n") + + return correct_samples + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Rejection Sampling Inference for T2V") + parser.add_argument("--model_path", type=str, required=True, help="Path to trained GRM model") + parser.add_argument("--data_path", type=str, required=True, help="Dataset path(s) in format 'source:path' (comma-separated for multiple)") + parser.add_argument("--output_path", type=str, required=True, help="Path to save filtered samples") + parser.add_argument("--batch_size", type=int, default=32, help="Batch size for inference") + parser.add_argument("--max_new_tokens", type=int, default=2048, help="Maximum tokens to generate") + parser.add_argument("--use_cot", action="store_true", default=True, help="Use CoT instruction for reasoning") + parser.add_argument("--task_instruction", type=str, default=TASK_INSTRUCTION_COT_T2V, help="Task instruction template") + + # vLLM arguments + parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Number of GPUs for tensor parallelism") + parser.add_argument("--gpu_memory_utilization", type=float, default=0.9, help="GPU memory utilization ratio") + parser.add_argument("--video_fps", type=float, default=2.0, help="FPS for video processing") + + args = parser.parse_args() + + # Parse data path + data_paths = args.data_path.split(",") if isinstance(args.data_path, str) else args.data_path + + config = { + "task_instruction": args.task_instruction, + "name": "rejection_sampling_inference_t2v", + "video_fps": args.video_fps, + } + + inference_and_filter( + model_path=args.model_path, + data_path=data_paths, + output_path=args.output_path, + config=config, + batch_size=args.batch_size, + max_new_tokens=args.max_new_tokens, + use_cot=args.use_cot, + tensor_parallel_size=args.tensor_parallel_size, + gpu_memory_utilization=args.gpu_memory_utilization, + video_fps=args.video_fps, + ) + diff --git a/examples/grm_training/rejection_sampling/run_rejection_sampling_t2i.sh b/examples/grm_training/rejection_sampling/run_rejection_sampling_t2i.sh new file mode 100644 index 00000000..98228661 --- /dev/null +++ b/examples/grm_training/rejection_sampling/run_rejection_sampling_t2i.sh @@ -0,0 +1,151 @@ +#!/bin/bash + +# Rejection Sampling Data Preparation Script +# This script performs rejection sampling data preparation: +# 1. Inference on dataset and filter correct samples +# 2. Convert filtered samples to training format + +set -e + +unset http_proxy +unset https_proxy +unset HTTP_PROXY +unset HTTPS_PROXY + +############################# Configuration ########################## +# Model path (cold-start model) +# Please set your model path here +MODEL_PATH="path/to/your/model" + +# Dataset configuration +# Please set your dataset path here (format: "source:path") +DATA_PATH="hpdv3:path/to/dataset.json" +# Please set your dataset root directory here +DATA_ROOT="path/to/dataset/root" + +# Output paths +OUTPUT_DIR="./results/rejection_sampling_$(date +%Y%m%d_%H%M%S)" +FILTERED_SAMPLES_PATH="${OUTPUT_DIR}/filtered_samples.json" +TRAINING_DATA_PATH="${OUTPUT_DIR}/rejection_sampling_train.json" + +# Inference parameters +INFERENCE_BATCH_SIZE=8 +MAX_NEW_TOKENS=2048 + +# Task instruction for CoT reasoning +TASK_INSTRUCTION="""Given a caption and two images generated based on this caption, please analyze in detail the two provided images. +Evaluate them on various dimensions such as semantic consistency (how closely the image content aligns with the caption), +aesthetics (composition, color usage, artistic expression), authenticity (realism and attention to detail), +and any other factors you deem relevant. For each evaluation dimension, +provide a score between 1-10 for both images (e.g., Image 1: 8/10, Image 2: 6/10) and provide a concise rationale for the score. +Calculate the total score for each image by summing all dimension scores. +Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within tags. +Then, in the tag, output exactly one of the following strings: 'Image 1 is better' or 'Image 2 is better' or 'Both are equal' based on the total scores. +No additional text is allowed in the section. +Example output format: + +Semantic consistency: Image 1 (9/10) - ...; Image 2 (7/10) - ... +Aesthetics: Image 2 (8/10) - ...; Image 1 (8/10) - ... +Authenticity: Image 1 (8/10) - ...; Image 2 (5/10) - ... +[Additional dimensions if any]: Image 2 (8/10) - ...; Image 1 (6/10) - ... +Total score: +Image 1: 9+8+8+6=31 +Image 2: 7+8+5+8=28 + +Image 1 is better +Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given images. +Your task is provided as follows: +Text Caption: **{prompt}** +""" + +############################### Environment ##################### +export GPUS_PER_NODE=${GPUS_PER_NODE:-2} # Use 2 GPUs +export NNODES=${NNODES:-1} +export NODE_RANK=${RANK:-0} +export MASTER_ADDR=${MASTER_ADDR:-"localhost"} +export MASTER_PORT=${MASTER_PORT:-29500} + +export WORLD_SIZE=$((NNODES * GPUS_PER_NODE)) + +# Validate required configuration +if [ -z "${MODEL_PATH}" ]; then + echo "Error: MODEL_PATH is not set. Please configure it in the script." + exit 1 +fi + +if [ -z "${DATA_PATH}" ]; then + echo "Error: DATA_PATH is not set. Please configure it in the script." + exit 1 +fi + +if [ -z "${DATA_ROOT}" ]; then + echo "Error: DATA_ROOT is not set. Please configure it in the script." + exit 1 +fi + +# Create output directory +mkdir -p ${OUTPUT_DIR} +LOG_BASE="${OUTPUT_DIR}/logs" +mkdir -p ${LOG_BASE} + +echo "==========================================" +echo "Rejection Sampling Data Preparation" +echo "==========================================" +echo "Model: ${MODEL_PATH}" +echo "Data: ${DATA_PATH}" +echo "Output: ${OUTPUT_DIR}" +echo "==========================================" + +############################### Step 1: Inference and Filter ########################## +echo "" +echo "Step 1: Running inference and filtering correct samples..." +echo "==========================================" + +# Use vLLM for inference (vLLM handles multi-GPU internally via tensor_parallel_size) +python examples/grm_training/rejection_sampling/rejection_sampling_inference.py \ + --model_path ${MODEL_PATH} \ + --data_path ${DATA_PATH} \ + --output_path ${FILTERED_SAMPLES_PATH} \ + --batch_size ${INFERENCE_BATCH_SIZE} \ + --max_new_tokens ${MAX_NEW_TOKENS} \ + --use_cot \ + --task_instruction "${TASK_INSTRUCTION}" \ + --tensor_parallel_size ${GPUS_PER_NODE} \ + --gpu_memory_utilization 0.9 \ + 2>&1 | tee ${LOG_BASE}/inference.log + +if [ ! -f "${FILTERED_SAMPLES_PATH}" ]; then + echo "Error: Filtered samples file not created!" + exit 1 +fi + +echo "Step 1 completed. Filtered samples saved to: ${FILTERED_SAMPLES_PATH}" + +############################### Step 2: Convert to Training Format ########################## +echo "" +echo "Step 2: Converting filtered samples to training format..." +echo "==========================================" + +python examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data.py \ + --filtered_samples_path ${FILTERED_SAMPLES_PATH} \ + --output_path ${TRAINING_DATA_PATH} \ + --data_root ${DATA_ROOT} \ + --task_instruction "${TASK_INSTRUCTION}" \ + 2>&1 | tee ${LOG_BASE}/convert.log + +if [ ! -f "${TRAINING_DATA_PATH}" ]; then + echo "Error: Training data file not created!" + exit 1 +fi + +echo "Step 2 completed. Training data saved to: ${TRAINING_DATA_PATH}" + +echo "" +echo "==========================================" +echo "Rejection Sampling Data Preparation Completed!" +echo "==========================================" +echo "Filtered samples: ${FILTERED_SAMPLES_PATH}" +echo "Training data: ${TRAINING_DATA_PATH}" +echo "All outputs saved to: ${OUTPUT_DIR}" +echo "==========================================" + diff --git a/examples/grm_training/rejection_sampling/run_rejection_sampling_t2v.sh b/examples/grm_training/rejection_sampling/run_rejection_sampling_t2v.sh new file mode 100755 index 00000000..d15ec144 --- /dev/null +++ b/examples/grm_training/rejection_sampling/run_rejection_sampling_t2v.sh @@ -0,0 +1,155 @@ +#!/bin/bash + +# Rejection Sampling Data Preparation Script for Text-to-Video (T2V) +# This script performs rejection sampling data preparation: +# 1. Inference on dataset and filter correct samples +# 2. Convert filtered samples to training format + +set -e + +unset http_proxy +unset https_proxy +unset HTTP_PROXY +unset HTTPS_PROXY + +############################# Configuration ########################## +# Model path (cold-start model) +# Please set your model path here +MODEL_PATH="path/to/your/model" + +# Dataset configuration +# Multiple rapidata-t2v datasets +# Format: "rapidata-t2v:path/to/dataset.parquet" +DATA_PATH=( + "rapidata-t2v:path/to/dataset1.parquet" + "rapidata-t2v:path/to/dataset2.parquet" + "rapidata-t2v:path/to/dataset3.parquet" +) + +# Output paths +OUTPUT_DIR="./results/rejection_sampling_t2v_$(date +%Y%m%d_%H%M%S)" +FILTERED_SAMPLES_PATH="${OUTPUT_DIR}/filtered_samples.json" +TRAINING_DATA_PATH="${OUTPUT_DIR}/rejection_sampling_train.json" + +# Inference parameters +INFERENCE_BATCH_SIZE=8 +MAX_NEW_TOKENS=2048 + +# Video FPS configuration +VIDEO_FPS=2.0 + +# Task instruction for CoT reasoning (T2V) +TASK_INSTRUCTION="""Given a caption and two videos generated based on this caption, please analyze in detail the two provided videos. +Evaluate them on various dimensions such as semantic consistency (how closely the video content aligns with the caption), temporal coherence (smoothness and logical flow of motion across frames), authenticity (realism and attention to detail), and any other factors you deem relevant. +For each evaluation dimension, provide a score between 1-10 for both videos (e.g., Video 1: 8/10, Video 2: 6/10) and provide a concise rationale for the score. +Calculate the total score for each video by summing all dimension scores. +Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within and tags. Then, in the tag, output exactly one of the following strings: +'Video 1 is better' or 'Video 2 is better' based on the total scores. No additional text is allowed in the section. +Example output format: + +1. Semantic consistency: Video 1 (9/10) - ...; Video 2 (7/10) - ... +2. Temporal coherence: Video 1 (8/10) - ...; Video 2 (6/10) - ... +3. Authenticity: Video 1 (7/10) - ...; Video 2 (5/10) - ... +... +[Additional dimensions if any]: Video 2 (8/10) - ...; Video 1 (6/10) - ... +Total score: +Video 1: 9+8+7+6=30 +Video 2: 7+6+5+8=26 + +Video 1 is better + +Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given videos. +Your task is provided as follows: +Text Caption: **{prompt}** +""" + +############################### Environment ##################### +export GPUS_PER_NODE=${GPUS_PER_NODE:-2} # Use 2 GPUs +export NNODES=${NNODES:-1} +export NODE_RANK=${RANK:-0} +export MASTER_ADDR=${MASTER_ADDR:-"localhost"} +export MASTER_PORT=${MASTER_PORT:-29500} + +export WORLD_SIZE=$((NNODES * GPUS_PER_NODE)) + +# Validate required configuration +if [ -z "${MODEL_PATH}" ]; then + echo "Error: MODEL_PATH is not set. Please configure it in the script." + exit 1 +fi + +if [ ${#DATA_PATH[@]} -eq 0 ]; then + echo "Error: DATA_PATH is not set. Please configure it in the script." + exit 1 +fi + +# Create output directory +mkdir -p ${OUTPUT_DIR} +LOG_BASE="${OUTPUT_DIR}/logs" +mkdir -p ${LOG_BASE} + +echo "==========================================" +echo "Rejection Sampling Data Preparation (T2V)" +echo "==========================================" +echo "Model: ${MODEL_PATH}" +echo "Data: ${DATA_PATH[@]}" +echo "Output: ${OUTPUT_DIR}" +echo "==========================================" + +############################### Step 1: Inference and Filter ########################## +echo "" +echo "Step 1: Running inference and filtering correct samples..." +echo "==========================================" + +# Convert array to comma-separated string for Python script +DATA_PATH_STR=$(IFS=','; echo "${DATA_PATH[*]}") + +# Use vLLM for inference (vLLM handles multi-GPU internally via tensor_parallel_size) +python examples/grm_training/rejection_sampling/rejection_sampling_inference_t2v.py \ + --model_path ${MODEL_PATH} \ + --data_path ${DATA_PATH_STR} \ + --output_path ${FILTERED_SAMPLES_PATH} \ + --batch_size ${INFERENCE_BATCH_SIZE} \ + --max_new_tokens ${MAX_NEW_TOKENS} \ + --use_cot \ + --task_instruction "${TASK_INSTRUCTION}" \ + --tensor_parallel_size ${GPUS_PER_NODE} \ + --gpu_memory_utilization 0.9 \ + --video_fps ${VIDEO_FPS} \ + 2>&1 | tee ${LOG_BASE}/inference.log + +if [ ! -f "${FILTERED_SAMPLES_PATH}" ]; then + echo "Error: Filtered samples file not created!" + exit 1 +fi + +echo "Step 1 completed. Filtered samples saved to: ${FILTERED_SAMPLES_PATH}" + +############################### Step 2: Convert to Training Format ########################## +echo "" +echo "Step 2: Converting filtered samples to training format..." +echo "==========================================" + +python examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data_t2v.py \ + --filtered_samples_path ${FILTERED_SAMPLES_PATH} \ + --output_path ${TRAINING_DATA_PATH} \ + --task_instruction "${TASK_INSTRUCTION}" \ + --video_fps ${VIDEO_FPS} \ + 2>&1 | tee ${LOG_BASE}/convert.log + +if [ ! -f "${TRAINING_DATA_PATH}" ]; then + echo "Error: Training data file not created!" + exit 1 +fi + +echo "Step 2 completed. Training data saved to: ${TRAINING_DATA_PATH}" + +echo "" +echo "==========================================" +echo "Rejection Sampling Data Preparation Completed (T2V)!" +echo "==========================================" +echo "Filtered samples: ${FILTERED_SAMPLES_PATH}" +echo "Training data: ${TRAINING_DATA_PATH}" +echo "All outputs saved to: ${OUTPUT_DIR}" +echo "==========================================" + diff --git a/examples/grm_training/rejection_sampling/train_rejection_sampling_mix.sh b/examples/grm_training/rejection_sampling/train_rejection_sampling_mix.sh new file mode 100644 index 00000000..2b46eb57 --- /dev/null +++ b/examples/grm_training/rejection_sampling/train_rejection_sampling_mix.sh @@ -0,0 +1,130 @@ +#!/bin/bash + +# Mixed training script for rejection sampling Image (T2I) + Video (T2V) data +# 使用同一个 GRM 模型,在图像和视频拒绝采样数据上进行联合训练。 + +set -e + +unset http_proxy +unset https_proxy +unset HTTP_PROXY +unset HTTPS_PROXY + +############################# Configuration ########################## +# 预训练模型路径(从该模型继续训练) +MODEL_PATH="path/to/your/model" + +# 图像拒绝采样数据(convert_to_rejection_sampling_data.py 的输出) +T2I_TRAINING_DATA_PATH="path/to/t2i_rejection_sampling_train.json" + +# 视频拒绝采样数据(convert_to_rejection_sampling_data_t2v.py 的输出) +T2V_TRAINING_DATA_PATH="path/to/t2v_rejection_sampling_train.json" + +# 输出目录 +OUTPUT_DIR="./results/rejection_sampling_mix_$(date +%Y%m%d_%H%M%S)/checkpoint" +LOG_DIR="$(dirname "${OUTPUT_DIR}")/logs" + +# 训练超参(与 examples/grm_training/run_grm_vl.sh 对齐) +TBS=2 # global train batch size(与 run_grm_vl.sh 中 TBS 一致) +LR=1e-5 # 与 run_grm_vl.sh 中 LR 一致 +MAX_LENGTH=8196 # 与 run_grm_vl.sh 中 MAX_LENGTH 一致 +MAX_EPOCHS=3 # 与 run_grm_vl.sh 中 --max_epochs 5 一致 +MICRO_BATCH_SIZE=1 # 与 run_grm_vl.sh 中 micro_train_batch_size 4 一致 +GRADIENT_ACCUMULATION_STEPS=1 # 当前脚本未显式传入该参数,仅作记录 + +# 视频 FPS 配置 +VIDEO_FPS=2.0 + +# 注意: +# - Image 数据的 system prompt(task_instruction)从 T2I_TRAINING_DATA_PATH 对应的 json 里读取; +# - Video 数据的 system prompt 从 T2V_TRAINING_DATA_PATH 对应的 json 里读取; +# 每条样本在 json 的 conversations[0]['value'] 里已经带了各自的 CoT 说明,因此这里不再额外传统一的 TASK_INSTRUCTION。 + +############################### Environment ##################### +export GPUS_PER_NODE=${GPUS_PER_NODE:-2} +export NNODES=${NNODES:-1} +export NODE_RANK=${RANK:-0} +export MASTER_ADDR=${MASTER_ADDR:-"localhost"} +export MASTER_PORT=${MASTER_PORT:-29500} + +export WORLD_SIZE=$((NNODES * GPUS_PER_NODE)) + +# 减少显存碎片 +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# 检查配置 +if [ ! -f "${T2I_TRAINING_DATA_PATH}" ]; then + echo "Error: T2I training data file not found: ${T2I_TRAINING_DATA_PATH}" + exit 1 +fi + +if [ ! -f "${T2V_TRAINING_DATA_PATH}" ]; then + echo "Error: T2V training data file not found: ${T2V_TRAINING_DATA_PATH}" + exit 1 +fi + +if [ -z "${MODEL_PATH}" ]; then + echo "Error: MODEL_PATH is not set. Please configure it in the script." + exit 1 +fi + +# 创建输出目录 +mkdir -p "${OUTPUT_DIR}" +mkdir -p "${LOG_DIR}" + +echo "==========================================" +echo "Rejection Sampling Mixed Training (T2I + T2V)" +echo "==========================================" +echo "Model: ${MODEL_PATH}" +echo "T2I Training Data: ${T2I_TRAINING_DATA_PATH}" +echo "T2V Training Data: ${T2V_TRAINING_DATA_PATH}" +echo "Output: ${OUTPUT_DIR}" +echo "GPUs: ${GPUS_PER_NODE}" +echo "Video FPS: ${VIDEO_FPS}" +echo "==========================================" + +# 这里利用 GRMDataset 中的多 handler 能力: +# args.train_data 会在 train_grm_vl.py 中被按逗号切分成 list, +# 每个元素是 "source:path" 的形式。 +T2I_SOURCE="imagegen-cot-reward-5k:${T2I_TRAINING_DATA_PATH}" +T2V_SOURCE="rejection-sampling-t2v:${T2V_TRAINING_DATA_PATH}" + +TRAINING_DATA_SOURCES="${T2I_SOURCE},${T2V_SOURCE}" + +############################### Training ########################## +echo "" +echo "Starting mixed training on T2I + T2V rejection sampling data..." +echo "==========================================" + +torchrun --nnodes $NNODES --nproc-per-node $GPUS_PER_NODE \ + --node_rank $NODE_RANK --master-port $MASTER_PORT --master-addr $MASTER_ADDR \ + examples/grm_training/train_grm_vl.py \ + --pretrain ${MODEL_PATH} \ + --save_path ${OUTPUT_DIR} \ + --ckpt_path ${OUTPUT_DIR} \ + --train_batch_size ${TBS} \ + --micro_train_batch_size ${MICRO_BATCH_SIZE} \ + --max_epochs ${MAX_EPOCHS} \ + --lr_warmup_ratio 0.0 \ + --prompt_max_len ${MAX_LENGTH} \ + --fps ${VIDEO_FPS} \ + --zero_stage 3 \ + --bf16 \ + --actor_learning_rate ${LR} \ + --train_data ${TRAINING_DATA_SOURCES} \ + --gradient_checkpointing \ + --save_steps 1000 \ + --max_ckpt_num 2 \ + --use_tensorboard "$(dirname "${OUTPUT_DIR}")/tensorboard" \ + --l2 1.0e-4 \ + --flash_attn \ + 2>&1 | tee ${LOG_DIR}/training.log + +echo "" +echo "==========================================" +echo "Mixed Training Completed!" +echo "==========================================" +echo "Final checkpoint: ${OUTPUT_DIR}/final_checkpoint" +echo "Training logs: ${LOG_DIR}/training.log" +echo "==========================================" + diff --git a/examples/grm_training/rejection_sampling/train_rejection_sampling_t2i.sh b/examples/grm_training/rejection_sampling/train_rejection_sampling_t2i.sh new file mode 100755 index 00000000..f12178a1 --- /dev/null +++ b/examples/grm_training/rejection_sampling/train_rejection_sampling_t2i.sh @@ -0,0 +1,134 @@ +#!/bin/bash + +# Training script for rejection sampling data +# This script trains the model on the filtered rejection sampling data + +set -e + +unset http_proxy +unset https_proxy +unset HTTP_PROXY +unset HTTPS_PROXY + +############################# Configuration ########################## +# Model path (pretrained model to continue training from) +MODEL_PATH="path/to/your/model" + +# Training data path (already converted rejection sampling data) +TRAINING_DATA_PATH="path/to/rejection_sampling_train.json" + +# Output directory for checkpoints +OUTPUT_DIR="./results/rejection_sampling_training_$(date +%Y%m%d_%H%M%S)/checkpoint" +LOG_DIR="./results/rejection_sampling_training_$(date +%Y%m%d_%H%M%S)/logs" + +# Training hyperparameters +TBS=4 # Reduced from 8 to save memory +LR=2.5e-6 +MAX_LENGTH=13000 +MAX_EPOCHS=3 +MICRO_BATCH_SIZE=1 +GRADIENT_ACCUMULATION_STEPS=16 # Increase to maintain effective batch size (4 * 16 = 64) + +# Task instruction for CoT reasoning (must match the one used during inference) +TASK_INSTRUCTION="""Given a caption and two images generated based on this caption, please analyze in detail the two provided images. +Evaluate them on various dimensions such as semantic consistency (how closely the image content aligns with the caption), +aesthetics (composition, color usage, artistic expression), authenticity (realism and attention to detail), +and any other factors you deem relevant. For each evaluation dimension, +provide a score between 1-10 for both images (e.g., Image 1: 8/10, Image 2: 6/10) and provide a concise rationale for the score. +Calculate the total score for each image by summing all dimension scores. +Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within tags. +Then, in the tag, output exactly one of the following strings: 'Image 1 is better' or 'Image 2 is better' or 'Both are equal' based on the total scores. +No additional text is allowed in the section. +Example output format: + +Semantic consistency: Image 1 (9/10) - ...; Image 2 (7/10) - ... +Aesthetics: Image 2 (8/10) - ...; Image 1 (8/10) - ... +Authenticity: Image 1 (8/10) - ...; Image 2 (5/10) - ... +[Additional dimensions if any]: Image 2 (8/10) - ...; Image 1 (6/10) - ... +Total score: +Image 1: 9+8+8+6=31 +Image 2: 7+8+5+8=28 + +Image 1 is better +Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given images. +Your task is provided as follows: +Text Caption: **{prompt}** +""" + +############################### Environment ##################### +export GPUS_PER_NODE=${GPUS_PER_NODE:-2} # Use 2 GPUs by default +export NNODES=${NNODES:-1} +export NODE_RANK=${RANK:-0} +export MASTER_ADDR=${MASTER_ADDR:-"localhost"} +export MASTER_PORT=${MASTER_PORT:-29500} + +export WORLD_SIZE=$((NNODES * GPUS_PER_NODE)) + +# Memory optimization: reduce fragmentation +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# Validate required configuration +if [ ! -f "${TRAINING_DATA_PATH}" ]; then + echo "Error: Training data file not found: ${TRAINING_DATA_PATH}" + exit 1 +fi + +if [ -z "${MODEL_PATH}" ]; then + echo "Error: MODEL_PATH is not set. Please configure it in the script." + exit 1 +fi + +# Create output directories +mkdir -p ${OUTPUT_DIR} +mkdir -p ${LOG_DIR} + +echo "==========================================" +echo "Rejection Sampling Training" +echo "==========================================" +echo "Model: ${MODEL_PATH}" +echo "Training Data: ${TRAINING_DATA_PATH}" +echo "Output: ${OUTPUT_DIR}" +echo "GPUs: ${GPUS_PER_NODE}" +echo "==========================================" + +# Use imagegen-cot-reward handler for the converted data +TRAINING_DATA_SOURCE="imagegen-cot-reward-5k:${TRAINING_DATA_PATH}" + +############################### Training ########################## +echo "" +echo "Starting training on rejection sampling data..." +echo "==========================================" + +torchrun --nnodes $NNODES --nproc-per-node $GPUS_PER_NODE \ + --node_rank $NODE_RANK --master-port $MASTER_PORT --master-addr $MASTER_ADDR \ + examples/grm_training/train_grm_vl.py \ + --pretrain ${MODEL_PATH} \ + --save_path ${OUTPUT_DIR} \ + --ckpt_path ${OUTPUT_DIR} \ + --train_batch_size ${TBS} \ + --micro_train_batch_size ${MICRO_BATCH_SIZE} \ + --max_epochs ${MAX_EPOCHS} \ + --lr_warmup_ratio 0.03 \ + --prompt_max_len ${MAX_LENGTH} \ + --fps 2.0 \ + --zero_stage 3 \ + --bf16 \ + --actor_learning_rate ${LR} \ + --train_data ${TRAINING_DATA_SOURCE} \ + --gradient_checkpointing \ + --save_steps 1000 \ + --max_ckpt_num 2 \ + --use_tensorboard "${OUTPUT_DIR}/../tensorboard" \ + --l2 0.0 \ + --flash_attn \ + --task_instruction "${TASK_INSTRUCTION}" \ + 2>&1 | tee ${LOG_DIR}/training.log + +echo "" +echo "==========================================" +echo "Training Completed!" +echo "==========================================" +echo "Final checkpoint: ${OUTPUT_DIR}/final_checkpoint" +echo "Training logs: ${LOG_DIR}/training.log" +echo "==========================================" + diff --git a/examples/grm_training/rejection_sampling/train_rejection_sampling_t2v.sh b/examples/grm_training/rejection_sampling/train_rejection_sampling_t2v.sh new file mode 100755 index 00000000..e8ceb796 --- /dev/null +++ b/examples/grm_training/rejection_sampling/train_rejection_sampling_t2v.sh @@ -0,0 +1,138 @@ +#!/bin/bash + +# Training script for rejection sampling T2V (text-to-video) data +# This script trains the model on the filtered rejection sampling video data + +set -e + +unset http_proxy +unset https_proxy +unset HTTP_PROXY +unset HTTPS_PROXY + +############################# Configuration ########################## +# Model path (pretrained model to continue training from) +MODEL_PATH="path/to/your/model" + +# Training data path (already converted rejection sampling data) +# This should be the output from convert_to_rejection_sampling_data_t2v.py +TRAINING_DATA_PATH="path/to/rejection_sampling_train.json" + +# Output directory for checkpoints +OUTPUT_DIR="./results/rejection_sampling_t2v_training_$(date +%Y%m%d_%H%M%S)/checkpoint" +LOG_DIR="./results/rejection_sampling_t2v_training_$(date +%Y%m%d_%H%M%S)/logs" + +# Training hyperparameters +TBS=8 +LR=2.5e-6 +MAX_LENGTH=13000 +MAX_EPOCHS=3 +MICRO_BATCH_SIZE=1 +GRADIENT_ACCUMULATION_STEPS=32 # Increase to maintain effective batch size + +# Video FPS configuration +VIDEO_FPS=2.0 + +# Task instruction for CoT reasoning (T2V) - must match the one used during inference +TASK_INSTRUCTION="""Given a caption and two videos generated based on this caption, please analyze in detail the two provided videos. +Evaluate them on various dimensions such as semantic consistency (how closely the video content aligns with the caption), temporal coherence (smoothness and logical flow of motion across frames), authenticity (realism and attention to detail), and any other factors you deem relevant. +For each evaluation dimension, provide a score between 1-10 for both videos (e.g., Video 1: 8/10, Video 2: 6/10) and provide a concise rationale for the score. +Calculate the total score for each video by summing all dimension scores. +Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within and tags. Then, in the tag, output exactly one of the following strings: +'Video 1 is better' or 'Video 2 is better' based on the total scores. No additional text is allowed in the section. +Example output format: + +1. Semantic consistency: Video 1 (9/10) - ...; Video 2 (7/10) - ... +2. Temporal coherence: Video 1 (8/10) - ...; Video 2 (6/10) - ... +3. Authenticity: Video 1 (7/10) - ...; Video 2 (5/10) - ... +... +[Additional dimensions if any]: Video 2 (8/10) - ...; Video 1 (6/10) - ... +Total score: +Video 1: 9+8+7+6=30 +Video 2: 7+6+5+8=26 + +Video 1 is better + +Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given videos. +Your task is provided as follows: +Text Caption: **{prompt}** +""" + +############################### Environment ##################### +export GPUS_PER_NODE=${GPUS_PER_NODE:-2} # Use 2 GPUs by default +export NNODES=${NNODES:-1} +export NODE_RANK=${RANK:-0} +export MASTER_ADDR=${MASTER_ADDR:-"localhost"} +export MASTER_PORT=${MASTER_PORT:-29500} + +export WORLD_SIZE=$((NNODES * GPUS_PER_NODE)) + +# Memory optimization: reduce fragmentation +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# Validate required configuration +if [ ! -f "${TRAINING_DATA_PATH}" ]; then + echo "Error: Training data file not found: ${TRAINING_DATA_PATH}" + echo "Please run convert_to_rejection_sampling_data_t2v.py first to convert filtered samples." + exit 1 +fi + +if [ -z "${MODEL_PATH}" ]; then + echo "Error: MODEL_PATH is not set. Please configure it in the script." + exit 1 +fi + +# Create output directories +mkdir -p ${OUTPUT_DIR} +mkdir -p ${LOG_DIR} + +echo "==========================================" +echo "Rejection Sampling Training (T2V)" +echo "==========================================" +echo "Model: ${MODEL_PATH}" +echo "Training Data: ${TRAINING_DATA_PATH}" +echo "Output: ${OUTPUT_DIR}" +echo "GPUs: ${GPUS_PER_NODE}" +echo "Video FPS: ${VIDEO_FPS}" +echo "==========================================" + +# Use rejection-sampling-t2v handler for the converted data +TRAINING_DATA_SOURCE="rejection-sampling-t2v:${TRAINING_DATA_PATH}" + +############################### Training ########################## +echo "" +echo "Starting training on rejection sampling T2V data..." +echo "==========================================" + +torchrun --nnodes $NNODES --nproc-per-node $GPUS_PER_NODE \ + --node_rank $NODE_RANK --master-port $MASTER_PORT --master-addr $MASTER_ADDR \ + examples/grm_training/train_grm_vl.py \ + --pretrain ${MODEL_PATH} \ + --save_path ${OUTPUT_DIR} \ + --ckpt_path ${OUTPUT_DIR} \ + --train_batch_size ${TBS} \ + --micro_train_batch_size ${MICRO_BATCH_SIZE} \ + --max_epochs ${MAX_EPOCHS} \ + --lr_warmup_ratio 0.03 \ + --prompt_max_len ${MAX_LENGTH} \ + --fps ${VIDEO_FPS} \ + --zero_stage 3 \ + --bf16 \ + --actor_learning_rate ${LR} \ + --train_data ${TRAINING_DATA_SOURCE} \ + --gradient_checkpointing \ + --save_steps 1000 \ + --max_ckpt_num 8 \ + --use_tensorboard "${OUTPUT_DIR}/../tensorboard" \ + --l2 0.0 \ + --flash_attn \ + --task_instruction "${TASK_INSTRUCTION}" \ + 2>&1 | tee ${LOG_DIR}/training.log + +echo "" +echo "==========================================" +echo "Training Completed!" +echo "==========================================" +echo "Final checkpoint: ${OUTPUT_DIR}/final_checkpoint" +echo "Training logs: ${LOG_DIR}/training.log" +echo "=========================================="