From d12c3b66d07f95b9bddcee4792e0956cd51400cd Mon Sep 17 00:00:00 2001 From: PaParaZz1 Date: Thu, 26 Feb 2026 07:45:45 +0000 Subject: [PATCH 1/5] feature(nyz): transfer meme rl training demo --- examples/godsmeme/meme_dataset.py | 115 ++++++ examples/godsmeme/reward_model.py | 464 +++++++++++++++++++++++++ examples/godsmeme/test_meme_dataset.py | 49 +++ examples/godsmeme/train_colocate.py | 407 ++++++++++++++++++++++ pyproject.toml | 2 +- 5 files changed, 1036 insertions(+), 1 deletion(-) create mode 100644 examples/godsmeme/meme_dataset.py create mode 100644 examples/godsmeme/reward_model.py create mode 100644 examples/godsmeme/test_meme_dataset.py create mode 100644 examples/godsmeme/train_colocate.py diff --git a/examples/godsmeme/meme_dataset.py b/examples/godsmeme/meme_dataset.py new file mode 100644 index 00000000..4610cd91 --- /dev/null +++ b/examples/godsmeme/meme_dataset.py @@ -0,0 +1,115 @@ +from typing import List, Dict, Union +import os +import json +import random +import torch +from PIL import Image +from torch.utils.data import Dataset + + +class MemeOnlineRLDataset(Dataset): + """Meme dataset class with lazy loading per item. + Supports both JSON array format (e.g. eval_data.json) and JSONL format. + """ + ASSISTANT_ROLES = ('gpt', 'assistant') + + def __init__( + self, + annotation_path: str, + root_dir: str, + processor, + shuffle: bool = True, + ): + super().__init__() + + if not os.path.exists(annotation_path): + raise FileNotFoundError(f"Annotation file {annotation_path} does not exist") + if not os.path.isdir(root_dir): + raise NotADirectoryError(f"Image root directory {root_dir} is invalid") + + self.root_dir = root_dir + self.annotation_path = annotation_path + self.processor = processor + + # Only load the raw data lines without processing + self._raw_data = self._load_raw_data() + if shuffle: + random.shuffle(self._raw_data) + + def _load_raw_data(self) -> List[Union[Dict, str]]: + """Load annotation file. Supports: + - JSON array format: entire file is one JSON list (e.g. eval_data.json) + - JSONL format: one JSON object per line + """ + with open(self.annotation_path, 'r', encoding='utf-8') as f: + content = f.read().strip() + if not content: + return [] + try: + data = json.loads(content) + if isinstance(data, list): + return data + except json.JSONDecodeError: + pass + with open(self.annotation_path, 'r', encoding='utf-8') as f: + return [line.strip() for line in f if line.strip()] + + def _process_item(self, raw_item: Union[Dict, str]) -> Dict: + """Process a single data item on-demand. raw_item is either a dict or JSON string.""" + data = raw_item if isinstance(raw_item, dict) else json.loads(raw_item) + image_path = os.path.join(self.root_dir, data['image']) \ + if not os.path.isabs(data['image']) else data['image'] + + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image {image_path} does not exist") + + conversations = data['conversations'] + human_input = next(c['value'] for c in conversations if c['from'] == 'human' and '' in c['value']) + assistant_output = next(c['value'] for c in conversations if c['from'] in self.ASSISTANT_ROLES) + + prompt = [{ + "role": "user", + "content": [ + { + "type": "image", + "image": "" + }, + { + "type": "text", + "text": human_input + }, + ] + }] + prompt = self.processor.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) + + return { + 'id': data['id'], + 'image_path': [image_path], + 'prompt': prompt, + # 'prompt': human_input + "\nAssistant:", + 'label': assistant_output, + } + + def __getitem__(self, index) -> Dict: + """Process and return item only when requested""" + # # Process the item on-demand (lazy loading) + raw_item = self._raw_data[index] + processed_item = self._process_item(raw_item) + # # text, image, label, reference + return ( + processed_item['prompt'], + processed_item['image_path'], + processed_item['label'], + processed_item['label'], + ) + + def __len__(self) -> int: + return len(self._raw_data) + + @staticmethod + def collate_fn(batch: List[Dict]): + text_list = [x[0] for x in batch] + image_list = [x[1] for x in batch] + reference_list = [x[2] for x in batch] + label_list = [x[3] for x in batch] + return text_list, image_list, reference_list, label_list diff --git a/examples/godsmeme/reward_model.py b/examples/godsmeme/reward_model.py new file mode 100644 index 00000000..c3fcf819 --- /dev/null +++ b/examples/godsmeme/reward_model.py @@ -0,0 +1,464 @@ +from typing import Callable, Dict, List, Tuple, Union, Sequence, Optional +import re +import json +import torch +import torch.nn as nn +import torch.nn.functional as F +from qwen_vl_utils import process_vision_info +from transformers import AutoProcessor, AutoModel, PreTrainedModel, AutoConfig, Qwen2_5_VLForConditionalGeneration + +from lightrft.utils import get_current_device + + +class AttentionPooling(nn.Module): + """ + Overview: + Attention pooling layer on the sequence dimension of LLM/VLM hidden states. + """ + def __init__( + self, + hidden_size: int, + num_heads: int = 4, + qkv_bias: bool = False, + position_bias: bool = False, + position_bias_scale: float = 3.0, + ): + super(AttentionPooling, self).__init__() + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.scale = self.head_dim ** -0.5 + self.position_bias = position_bias + self.position_bias_scale = position_bias_scale + + self.k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + self.v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) + # Using 0.02 for better initialization + self.query = nn.Parameter(torch.randn(hidden_size) * 0.02) + + def forward(self, hidden_states): + B, S, C = hidden_states.shape + + # Multi-head projection for key and value + k = self.k(hidden_states).reshape(B, S, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # B, H, S, D + v = self.v(hidden_states).reshape(B, S, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # B, H, S, D + + # Expand query for batch dimension + q = self.query.unsqueeze(0).expand(B, -1, -1) # B, H, C + q = q.unsqueeze(2) # B, H, 1, C + q = q.reshape(B, self.num_heads, 1, self.head_dim) # B, H, 1, C + + # Attention weights + attn = (q @ k.transpose(-2, -1)) * self.scale # B, H, 1, S + + # Add position bias + if self.position_bias: + position_bias = torch.arange(S, device=k.device).float() / S * self.position_bias_scale + attn = attn + position_bias.view(1, 1, 1, -1) # Add position bias + + # Attention pooling + attn = torch.softmax(attn, dim=-1) # B, H, 1, S + attn = attn.to(v.dtype) + out = (attn @ v).squeeze(2) # B, H, D + out = out.reshape(B, -1) # B, C + + return out + + +class Qwen2VLRewardModelMemeOutcome(PreTrainedModel): + def __init__(self, pretrained_model): + super().__init__(pretrained_model.config) + self.pretrained_model = pretrained_model + + if hasattr(pretrained_model.config, 'hidden_size'): + hidden_size = pretrained_model.config.hidden_size + elif hasattr(pretrained_model.config, 'd_model'): + hidden_size = pretrained_model.config.d_model + else: + raise ValueError("Cannot determine hidden size from model config") + + self.attention_pooling = AttentionPooling( + hidden_size=hidden_size, + num_heads=4, + qkv_bias=False, + position_bias=True, + position_bias_scale=3.0, + ) + + self.classification_head = nn.Linear(hidden_size, 2) + nn.init.normal_(self.classification_head.weight, std=0.02) + nn.init.zeros_(self.classification_head.bias) + + self.attention_pooling.bfloat16() + self.classification_head.bfloat16() + + @classmethod + def from_pretrained(cls, pretrained_model: PreTrainedModel): + """Create a binary classification model from a pretrained model.""" + return cls(pretrained_model) + + @torch.no_grad() + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + **kwargs, + ): + """Forward pass for meme reward evaluation.""" + # Forward through base model + prompt_and_outputs = kwargs.get('prompt_and_output') + raw_images = kwargs.get('raw_images') + outputs = self.pretrained_model( + input_ids=input_ids.cuda(), + attention_mask=attention_mask.cuda(), + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=True, + ) + + # Extract hidden states + if hasattr(outputs, 'last_hidden_state'): + hidden_states = outputs.last_hidden_state + elif hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None: + hidden_states = outputs.hidden_states[-1] + else: + raise ValueError("Cannot extract hidden states from model output") + + # Use attention pooling + pooled_output = self.attention_pooling(hidden_states) + + # Get logits + logits = self.classification_head(pooled_output) + + # Get logits [batch_size, 2], first dimension is probability of 0, second dimension is probability of 1 + logits = self.classification_head(pooled_output) + + # Calculate probabilities and binarize + probabilities = F.softmax(logits, dim=-1) # [batch_size, 2] + # If the second dimension (probability of 1) is larger, output 1, otherwise output 0 + binary_scores = (probabilities[:, 1] > probabilities[:, 0]).float() + + return { + 'score': binary_scores, # Binary result of 0/1 + 'logits': logits # Original logits, containing scores for two dimensions + } + + +class Qwen2VLRewardModelMemeContent(PreTrainedModel): + system_prompt = f""" + You are a professional meme text generation evaluation expert who is good at evaluating the quality of the current reasoning process in combination with images.\n\n + """ + eval_prompt = """ + You are a professional and strict-scoring expert in evaluating meme text generation, responsible for scoring the quality of the reasoning process (Chain of Thought, CoT) generated by the model. + This reasoning process is a text generation reasoning conducted based on the first text-free base image and input parameters (i.e., the user's requirements for the meme). + + Details of the evaluation task: + 1. Input parameters: + {input_params} + + 2. Standard reasoning process (reference standard answer): + {standard_cot} + + 3. Actual reasoning process to be evaluated: (The generated text is after "Text on the Meme") + {actual_cot} + + Please conduct the evaluation by combining the two provided images (the base image and the standard meme image with text), and the scoring must be strict. The evaluation criteria are as follows: + + 1. Whether the chain of thought process includes an analysis of the expressions/actions/facial features/relationships/scenes of the entities in the image, and check its correctness: (Total 10 points) + a. Rough description: For example, there is a woman in this picture; 1 point; + b. With some details but no description of actions/facial expressions: For example, there is a woman in this picture, wearing a hat, sitting in a car; 4 points; + c. With details and actions/facial expressions: For example, there is a woman in this picture, wearing a hat, sitting in a car, looking very happy; 7 points; + d. Not only explaining the details such as the actions and expressions of the characters in the picture, but also immediately associating the character relationships/scenes where the actions occur; 10 points; + + 2. Analysis of further scene associations based on the relationships between entities in the image: (Total 10 points) + a. Only roughly describing possible scenes without specificity: For example, this may happen in daily life; 1 point; + b. Describing a relatively specific scene: For example, this may be the scene when you went out with friends to drink and found yourself vomiting; 4 points; + c. Describing multiple relatively specific scenes: For example, this may be the scene when you went out with friends to drink and found yourself vomiting, or the scene when the teacher checks homework but you find you haven't finished it; 10 points; + + Next, evaluate the content after [Specific analysis with user input]: + + 1. Whether the chain of thought further specifies the scene based on the previously associated scenes combined with user needs, or re-associates a scene more in line with user needs: (5 points for each satisfied item, total 20 points) + a. Whether the intention expressed in the sentence is consistent with the user's emotions; + b. Whether the intention expressed in the sentence is consistent with the user's intentions and the theme; + c. Whether the topic of the entire sentence is consistent with the keyword topic; + d. Whether some humorous techniques are used, such as puns/homophones/semantic reversal/subverting expectations/exaggeration/role dislocation/suspenseful beginning/punchline reversal/rhyming structure/internet memes; + + 2. Logical fluency and text length between the final generated text and the reasoning process: (Total 10 points) + a. The reasoning is very forced, and the humor of the answer paired with this base image is much worse than that of the standard meme, and the text length is much longer than that of the standard meme; 1 point; + b. The reasoning is roughly valid, and the length of the answer divided by boxes is roughly similar to that of the standard meme; 4 points; + c. The reasoning is very fluent, the length of the answer divided by boxes is roughly similar to that of the standard meme, and it is humorous when paired with the image (for humorous techniques, refer to the rhetoric I just mentioned); 7 points; + d. The reasoning is very fluent, the length of the answer divided by boxes is roughly similar to that of the standard meme, or even more interesting than the standard meme; 10 points; + + The total score is 50 points. It is necessary to provide the reasons and scores for each scoring criterion. + Output a line similar to the following at the end: + "Final score: 41 points" + """ + + def __init__(self, pretrained_model): + super().__init__(pretrained_model.config) + self.pretrained_model = pretrained_model + self.processor = None + + def set_processor(self, processor): + self.processor = processor + + @classmethod + def from_pretrained(cls, pretrained_model: PreTrainedModel): + return cls(pretrained_model) + + def _get_message(self, input_params, standard_cot, actual_cot, base_image, standard_meme_image): + message = [ + { + "role": "system", + "content": [{ + "type": "text", + "text": self.system_prompt + }] + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": "The first image is the base map (no text, only a frame):" + }, + { + "type": "image", + "image": base_image + }, + #{"type": "text", "text": "The second image is a standard meme image (with the correct answer in text):"}, + #{"type": "image", "image": standard_meme_image} + { + "type": "text", + "text": self.eval_prompt.format( + input_params=input_params, standard_cot=standard_cot, actual_cot=actual_cot + ) + } + ] + } + ] + return message + + @torch.no_grad() + def forward( + self, + input_ids, + attention_mask, + pixel_values, + image_grid_thw, + return_dict=None, + output_attentions=None, + output_hidden_states=None, + **kwargs + ): + raw_images = kwargs.get('raw_images') + references = kwargs.get('references') + texts = self.processor.batch_decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) + texts = [t.split("user\n")[1] for t in texts] + + messages = [] + images = [] + for img, ref, text in zip(raw_images, references, texts): + input_params_match = re.search(r'\*\*Input Parameters\*\*:\s*\[(.*?)\]', text, re.DOTALL) + input_params = input_params_match.group(1) if input_params_match else "not found input params" + message = self._get_message( + input_params=input_params, standard_cot=ref, actual_cot=text, base_image=img, standard_meme_image=None + ) + processed_img, _ = process_vision_info(message) + messages.append(message) + images.append(processed_img) + + messages = [ + self.processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + for message in messages + ] + if torch.distributed.get_rank() % 8 == 0: + print(f"messages: {messages[0]}, {len(messages)}") + inputs = self.processor( + text=messages, + images=images, + max_length=5000, + padding=True, + truncation=True, + return_tensors="pt", + ) + inputs = inputs.to("cuda") + + gen_ids = self.pretrained_model.generate( + input_ids=inputs.input_ids, + attention_mask=inputs.attention_mask, + pixel_values=inputs.pixel_values, + image_grid_thw=inputs.image_grid_thw, + temperature=0.3, + top_p=0.9, + max_new_tokens=128, + do_sample=False, + ) + outputs_trim = [o[len(i):] for i, o in zip(inputs.input_ids, gen_ids)] + outputs_text = self.processor.batch_decode( + outputs_trim, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + if torch.distributed.get_rank() % 8 == 0: + print(f"outputs_text: {outputs_text}, {inputs.input_ids.shape}") + score = [self._extract_numeric_score(o) for o in outputs_text] + return score + + def _extract_numeric_score(self, response: str) -> float: + """Helper function: Extract score from model response and normalize to 0-1 range""" + # Match numeric scores in 50-point scale (e.g., "Final score: 41 points" or "35/50") + numeric_patterns = [ + r"Final score[::]\s*([0-9.]+)\s*points?", r"Score[::]\s*([0-9.]+)\s*points?", r"([0-9.]+)\s*/\s*50" + ] + + for pattern in numeric_patterns: + match = re.search(pattern, response, re.IGNORECASE) + if match: + try: + score = float(match.group(1)) + normalized_score = score * 0.02 # Convert 50-point scale to 0-1 + return max(0.0, min(1.0, normalized_score)) # Clamp boundary values + except ValueError: + continue + + # If no numeric score found, infer score from text description + positive_indicators = ["excellent", "outstanding", "perfect", "very good", "meets requirements"] + neutral_indicators = ["average", "acceptable", "basically meets", "partially meets"] + negative_indicators = ["poor", "does not meet", "bad", "completely mismatched"] + + response_lower = response.lower() + for idx, indicator in enumerate(positive_indicators): + if indicator.lower() in response_lower: + return 0.8 + (idx * 0.05) + + for idx, indicator in enumerate(neutral_indicators): + if indicator.lower() in response_lower: + return 0.5 + (idx * 0.05) + + for idx, indicator in enumerate(negative_indicators): + if indicator.lower() in response_lower: + return 0.2 + (idx * 0.05) + + # Return 0.0 when unable to infer + return 0.0 + + +def load_reward_models( + reward_pretrain: str, + strategy, + use_engine: bool = False, +): + if use_engine: + raise NotImplementedError("Engine is not supported for reward model") + + model_list = [] + tokenizer_list = [] + processor_list = [] + with strategy.init_model_context() as _: + cfg = json.loads(reward_pretrain) + device = get_current_device() + + for key in cfg.keys(): + pretrain_path = cfg[key] + model_config = AutoConfig.from_pretrained( + pretrain_path, + trust_remote_code=True, + ) + base = Qwen2_5_VLForConditionalGeneration.from_pretrained( + pretrain_path, + config=model_config, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + trust_remote_code=True, + ) + + processor = AutoProcessor.from_pretrained( + pretrain_path, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28 + ) + processor.tokenizer.padding_side = "left" + + if key == "outcome": + model = Qwen2VLRewardModelMemeOutcome.from_pretrained(base) + elif key == "content": + model = Qwen2VLRewardModelMemeContent.from_pretrained(base) + model.set_processor(processor) + # for some case about meta device + model.to_empty(device=device) + model.eval() + model_list.append(model) + tokenizer_list.append(processor.tokenizer) + processor_list.append(processor) + return model_list, tokenizer_list, processor_list + + +def get_format_reward(response: str) -> float: + """ + Evaluate the format compliance of model response, returns a score between 0-1 + + Args: + response: The model response content to be evaluated (string) + + Returns: + Format compliance score (0-1), where 1 means fully compliant with format requirements, 0 means completely non-compliant + """ + # Initialize score and total check items + score = 0.0 + total_checks = 0 + + # 1. Check if all required sections exist + required_sections = [ + r'\[Comprehensive Description Section\]', r'\[Usage Scenarios Section\]', r'\[Text Analysis Section\]', + r'\[Specific analysis with user input\]', r'Text on the Meme:' + ] + + for section in required_sections: + total_checks += 1 + if re.search(section, response, re.IGNORECASE): + score += 1 + + # 2. Check box format in 'Text on the Meme' section + total_checks += 1 + meme_text_match = re.search(r'Text on the Meme:\s*(.*?)(?=\n\n|$)', response, re.DOTALL | re.IGNORECASE) + if meme_text_match and re.search(r'box\d+:\s*[^\n]+', meme_text_match.group(1)): + score += 1 + + # 3. Check Step 1 and Step 2 in 'Specific analysis' section + total_checks += 2 + specific_analysis_match = re.search( + r'\[Specific analysis with user input\]\s*(.*?)(?=\n\[|Text on the Meme:|$)', response, + re.DOTALL | re.IGNORECASE + ) + if specific_analysis_match: + analysis_content = specific_analysis_match.group(1) + if re.search(r'Step 1:', analysis_content, re.IGNORECASE): + score += 1 + if re.search(r'Step 2:', analysis_content, re.IGNORECASE): + score += 1 + + # Normalize to 0-1 range + return round(score / total_checks, 4) if total_checks > 0 else 0.0 + + +def reward_fn( + model_reward_list: List[torch.Tensor], # len = n_model , each shape=(B,) + labels: Sequence[str], + queries: Sequence[str], + refs: Sequence[str], + **kwargs, +) -> torch.Tensor: + # outcome reward + # model_reward_list: Shapes [2] + outcome_reward = model_reward_list[0] + dtype, device = outcome_reward.dtype, outcome_reward.device + # rule reward + format_reward = [get_format_reward(q) for q in queries] + format_reward = torch.tensor(format_reward, dtype=dtype, device=device) + if torch.distributed.get_rank() % 8 == 0: + print(f"queries: {queries[0]}") + print(f"model_reward_list: {model_reward_list}, format_reward: {format_reward}") + return format_reward + outcome_reward diff --git a/examples/godsmeme/test_meme_dataset.py b/examples/godsmeme/test_meme_dataset.py new file mode 100644 index 00000000..18a6815b --- /dev/null +++ b/examples/godsmeme/test_meme_dataset.py @@ -0,0 +1,49 @@ +from torch.utils.data import DataLoader +from meme_dataset import MemeOnlineRLDataset +from transformers import AutoProcessor + +# download from https://huggingface.co/datasets/luodi-7/Eimages +REAL_ANNOTATION = "/root/data/Eimages/train_data.json" +REAL_ROOT = "/root/data/Eimages" +MODEL_PATH = "/root/model/HUMOR-COT-Qwen2.5-VL" +processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True) + + +def test_real_dataset_loading(): + dataset = MemeOnlineRLDataset( + annotation_path=REAL_ANNOTATION, root_dir=REAL_ROOT, shuffle=False, processor=processor + ) + + # print(len(dataset)) + assert len(dataset) == 3345 + + prompt, image_path, label, reference = dataset[0] + assert "Meme Text Generation Framework" in prompt + assert "[Comprehensive Description Section]" in reference + assert isinstance(image_path[0], str) + # from PIL import Image + # image = Image.open(image_path[0]) + + # outputs = processor( + # text=prompt, + # images=image, + # add_special_tokens=False, + # max_length=4096, + # truncation=True, + # ) + + +def test_collate_fn(): + dataset = MemeOnlineRLDataset(annotation_path=REAL_ANNOTATION, root_dir=REAL_ROOT, processor=processor) + loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn) + + batch = next(iter(loader)) + assert len(batch[0]) == 4 + assert len(batch[1]) == 4 + assert len(batch[2]) == 4 + assert len(batch[3]) == 4 + + +if __name__ == "__main__": + test_real_dataset_loading() + test_collate_fn() diff --git a/examples/godsmeme/train_colocate.py b/examples/godsmeme/train_colocate.py new file mode 100644 index 00000000..903c8543 --- /dev/null +++ b/examples/godsmeme/train_colocate.py @@ -0,0 +1,407 @@ +import argparse +import itertools +import math +import re +import os +import sys +import json +from datetime import datetime +from typing import Callable, Dict, List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from lightrft.utils import get_tokenizer_processor_vl, ensure_video_input_available +from lightrft.models import ActorVL +from lightrft.strategy import get_strategy +from lightrft.trainer import SPMDPPOTrainerVL +from lightrft.utils import add_arguments + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from meme_dataset import MemeOnlineRLDataset +from reward_model import load_reward_models, reward_fn + +ensure_video_input_available() + + +def train(args): + # configure strategy + strategy = get_strategy(args) + + ds_train_cfg = strategy.get_ds_train_config(is_actor=True) if not args.fsdp else None + ds_eval_cfg = strategy.get_ds_eval_config(offload=False) if not args.fsdp else None + + # configure model (optionally meta_init to save CPU memory) + with strategy.init_model_context(meta_init=getattr(args, "meta_init", False)): + actor = ActorVL( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + lora_dropout=args.lora_dropout, + ds_config=ds_train_cfg, + packing_samples=args.packing_samples, + disable_logprobs_flashattn=args.disable_logprobs_flashattn, + fused_linear_logprob=args.fused_linear_logprob, + high_entropy_token_ratio=getattr(args, "high_entropy_token_ratio", 0.0), + ) + + if args.actor_init_on_gpu: + actor = actor.to(torch.cuda.current_device()) + + # pre-prepare is used for saving RAM memory when training 72B model + if args.fsdp: + setattr(actor, "is_actor", True) + actor = strategy.prepare_model(actor, is_training=True) + + # configure tokenizer and processor + tokenizer, processor = get_tokenizer_processor_vl( + args.pretrain, actor.model, "left", use_fast=not strategy.args.disable_fast_tokenizer + ) + assert processor is not None, "processor is None" + + if args.freeze_prefix: + # visual for qwenvl, vision_model for internvl, mlp1 for linear layer in internvl + freeze_prefix = ["visual", "vision_model", "mlp1"] + frozen_params_count = 0 + total_params_count = 0 + for name, param in actor.model.named_parameters(): + total_params_count += 1 + if any(name.startswith(prefix) for prefix in freeze_prefix): + param.requires_grad = False + frozen_params_count += 1 + strategy.print( + f"Froze {frozen_params_count}/{total_params_count} parameters based on prefixes: {freeze_prefix}" + ) + + critic = None + + strategy.report_memory("before loaded reward models in main entry") + reward_models, reward_tokenizers, reward_processors, label_map = load_reward_models( + args.reward_pretrain, strategy, use_engine=args.rm_use_engine + ) + strategy.print(f"label_map: {label_map}") + strategy.report_memory("after loaded reward models in main entry") + + strategy.print(actor) + + # load weights for reference actor + if args.init_kl_coef == 0: + initial_model = None + else: + initial_model = ActorVL( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + ds_config=ds_eval_cfg, + packing_samples=args.packing_samples, + fused_linear_logprob=args.fused_linear_logprob, + ) + + if args.fsdp: + initial_model = strategy.prepare_model(initial_model, is_training=False, shard_size=strategy.world_size) + strategy.offload_model(initial_model) + + # prepare datasets (meme-specific: annotation_path + root_dir) + annotation_path = args.annotation_path + root_dir = args.root_dir + prompts_dataset = MemeOnlineRLDataset( + annotation_path=annotation_path, + root_dir=root_dir, + shuffle=True, + processor=processor, + ) + strategy.print(f"Loaded {len(prompts_dataset)} samples for prompts.") + + # prepare dataloader + prompts_dataloader = strategy.setup_dataloader( + prompts_dataset, + args.rollout_batch_size // strategy.world_size, + True, + True, + collate_fn=prompts_dataset.collate_fn + ) + pretrain_dataloader = None + + # for scheduler + num_update_steps_per_episodes = ( + len(prompts_dataset) * args.n_samples_per_prompt // args.train_batch_size * args.max_epochs + ) + max_steps = math.ceil(args.num_episodes * num_update_steps_per_episodes) + + # gradient_checkpointing + if args.gradient_checkpointing: + actor.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + if critic is not None: + critic.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + + ( + (actor, actor_optim, actor_scheduler), + (critic, critic_optim, critic_scheduler), + reward_models, + initial_model, + ) = strategy.prepare_models_and_optimizers(actor, critic, reward_models, initial_model, args, max_steps) + + strategy.print(reward_models) + + # load checkpoint + consumed_samples = 0 + if args.load_checkpoint and os.path.exists(os.path.join(args.ckpt_path, "_actor")): + _, states = strategy.load_ckpt( + actor.model, os.path.join(args.ckpt_path, "_actor"), optimizer=actor_optim, scheduler=actor_scheduler + ) + if args.critic_pretrain: + strategy.load_ckpt(critic, os.path.join(args.ckpt_path, "_critic")) + consumed_samples = states["consumed_samples"] + strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}") + + os.makedirs(args.save_path, exist_ok=True) + strategy.report_memory("after models init") + + strategy.setup_inference_engine(args, engine_type=args.engine_type, actor=actor) + + # configure Trainer + trainer = SPMDPPOTrainerVL( + strategy, + actor, + critic, + reward_models, + initial_model, + None, + actor_optim, + critic_optim, + actor_scheduler, + critic_scheduler, + max_epochs=args.max_epochs, + micro_train_batch_size=args.micro_train_batch_size, + micro_rollout_batch_size=args.micro_rollout_batch_size, + gradient_checkpointing=args.gradient_checkpointing, + tokenizer=tokenizer, + processor=processor, + prompt_max_len=args.prompt_max_len, + eps_clip=args.eps_clip, + gamma=args.gamma, + lambd=args.lambd, + init_kl_coef=args.init_kl_coef, + kl_target=args.kl_target, + ema_beta=0.992, + max_norm=args.max_norm, + # for GPT generation + do_sample=True, + max_new_tokens=args.generate_max_len, + max_length=args.max_len, + temperature=args.temperature, + top_p=args.top_p, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + # reward model + reward_fn=reward_fn, + reward_fn_label_map=label_map, + reward_tokenizers=reward_tokenizers, + save_hf_ckpt=args.save_hf_ckpt, + disable_ds_ckpt=args.disable_ds_ckpt, + packing_samples=args.packing_samples, + ) + + trainer.fit( + args, + prompts_dataloader=prompts_dataloader, + pretrain_dataloader=pretrain_dataloader, + eval_dataloader=None, + consumed_samples=consumed_samples, + num_update_steps_per_episodes=num_update_steps_per_episodes, + ) + + # save model checkpoint after fitting on only rank0 + strategy.save_model( + actor, + tokenizer, + args.save_path, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--engine_type", type=str, default="vllm", help="Choose inference engine type: vllm, sglang") + + # Checkpoint + parser.add_argument("--save_path", type=str, default="./ckpt") + parser.add_argument("--save_steps", type=int, default=-1) + parser.add_argument("--save_hf_ckpt", action="store_true", default=False) + parser.add_argument("--disable_ds_ckpt", action="store_true", default=False) + parser.add_argument("--logging_steps", type=int, default=1) + parser.add_argument("--eval_steps", type=int, default=-1) + parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_ppo") + parser.add_argument("--max_ckpt_num", type=int, default=3) + parser.add_argument("--max_ckpt_mem", type=int, default=1e8) + parser.add_argument("--load_checkpoint", action="store_true", default=False) + + # PPO + parser.add_argument("--num_episodes", type=int, default=1) + parser.add_argument("--rollout_batch_size", type=int, default=512) + parser.add_argument("--micro_rollout_batch_size", type=int, default=8) + parser.add_argument("--max_epochs", type=int, default=1) + parser.add_argument("--prompt_max_len", type=int, default=1024, help="Max tokens for each prompt") + parser.add_argument("--generate_max_len", type=int, default=1024, help="Max tokens to generate in PPO") + parser.add_argument("--max_len", type=int, default=None, help="deprecated max_len") + parser.add_argument("--max_samples", type=int, default=1000000) + parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping") + parser.add_argument("--l2", type=float, default=0.0, help="weight decay loss") + parser.add_argument("--eps_clip", type=float, default=0.2, help="PPO clip range") + parser.add_argument("--lambd", type=float, default=0.95, help="PPO GAE lambd") + parser.add_argument("--gamma", type=float, default=1, help="PPO GAE gamma") + parser.add_argument("--micro_train_batch_size", type=int, default=4, help="batch size per GPU") + parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size") + parser.add_argument("--normalize_reward", action="store_true", default=False, help="Enable Reward Normazation") + parser.add_argument("--top_p", type=float, default=1.0) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument( + "--freeze_prefix", + action="store_true", + default=False, + help="Freeze the prefix part (e.g. vision encoder) of the actor model" + ) + parser.add_argument( + "--n_samples_per_prompt", type=int, default=1, help="number of responses for each prompt in generation" + ) + parser.add_argument("--actor_learning_rate", type=float, default=1e-6) + parser.add_argument("--lr_warmup_ratio", type=float, default=0.03) + parser.add_argument("--kl_target", type=float, default=None) + parser.add_argument("--init_kl_coef", type=float, default=0.01, help="KL penalty in PPO") + parser.add_argument( + "--kl_estimator", + type=str, + default="k1", + choices=["k1", "k2", "k3"], + help=( + "In GRPO, k3 is utilized as the loss function, while k2, when used as the loss, is nearly equivalent to k1." + ), + ) + parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer") + parser.add_argument("--reward_clip_range", type=float, nargs=2, default=(-10, 10), help="Reward clip range") + + # DeepSpeed + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed") + parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage") + parser.add_argument("--gradient_checkpointing", action="store_true", default=False) + parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") + parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size") + parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer") + parser.add_argument("--actor_init_on_gpu", action="store_true", default=False) + parser.add_argument("--flash_attn", action="store_true", default=False, help="Enable FlashAttention2") + parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type") + parser.add_argument("--overlap_comm", action="store_true", default=False) + parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False) + parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) + parser.add_argument( + "--disable_logprobs_flashattn", + action="store_true", + default=False, + help="Disable flash attn implementation in log_probs calculation" + ) + + # Reinforce + parser.add_argument( + "--advantage_estimator", + type=str, + choices=["gae", "reinforce", "rloo", "reinforce_baseline", "group_norm"], + default="gae", + help="Choose advantage estimation method: gae, reinforce, rloo, reinforce_baseline, group_norm", + ) + + parser.add_argument("--use_kl_loss", action="store_true", default=False, help="whether to use KL loss from GRPO") + + # LoRA + parser.add_argument("--load_in_4bit", action="store_true", default=False) + parser.add_argument("--lora_rank", type=int, default=0) + parser.add_argument("--lora_alpha", type=int, default=16) + parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear") + parser.add_argument("--lora_dropout", type=float, default=0) + + # Models + parser.add_argument("--pretrain", type=str, default=None, help="HF model name or path") + parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path") + parser.add_argument("--remote_rm_url", type=str, default=None, help="remote RM API") + + # Custom dataset (GodsMeme) + parser.add_argument( + "--annotation_path", + type=str, + default= + "/fs-computility/niuyazhe/shared/xueyingyi/xueyingyi/cot_picture/Eimages/annotations/all/train_data.jsonl", + help="Path to meme annotation JSONL", + ) + parser.add_argument( + "--root_dir", + type=str, + default="/fs-computility/niuyazhe/shared/xueyingyi/xueyingyi/cot_picture/Eimage_drawn", + help="Root directory for meme images", + ) + parser.add_argument("--prompt_data", type=str, default=None, help="HF dataset name or path") + parser.add_argument( + "--prompt_data_probs", + type=str, + default="1.0", + help="sampling probs for datasets", + ) + parser.add_argument("--prompt_split", type=str, default="train") + + # wandb parameters + parser.add_argument("--use_wandb", type=str, default=None) + parser.add_argument("--wandb_org", type=str, default=None) + parser.add_argument("--wandb_group", type=str, default=None) + parser.add_argument("--wandb_project", type=str, default="openrlhf_train_ppo") + parser.add_argument( + "--wandb_run_name", + type=str, + default="ppo_%s" % datetime.now().strftime("%m%dT%H:%M"), + ) + + # TensorBoard parameters + parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path") + + # MultiModal + parser.add_argument( + "--limit_mm_image_per_prompt", + type=int, + default=-1, + help="the max image number of each text in multi model for inference backend" + ) + + parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss") + parser.add_argument( + "--use_cpg_loss", + action="store_true", + default=False, + help="whether to use the clipped policy gradient loss from CPGD" + ) + + add_arguments(parser) + + args = parser.parse_args() + + if args.advantage_estimator not in ["gae"]: + args.critic_pretrain = None + + if args.advantage_estimator in ["rloo", "reinforce_baseline", "group_norm"]: + assert args.n_samples_per_prompt > 1, f"{args.advantage_estimator} requires n_samples_per_prompt > 1" + + if args.use_kl_loss: + if args.kl_estimator not in ["k2", "k3"]: + print(f"Recommend setting {args.kl_estimator} to 'k2' or 'k3' when using KL as a loss") + else: + if args.kl_estimator not in ["k1"]: + print(f"Recommend setting {args.kl_estimator} to 'k1' when not using KL as a loss.") + + train(args) diff --git a/pyproject.toml b/pyproject.toml index 57d7043c..909d84d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ dependencies = [ "transformers>=4.57.1", "sglang>=0.5.6.post2", "deepspeed>=0.18.3", - "flash-attn>=2.8.3", + #"flash-attn>=2.8.3", "accelerate", "datasets", "wandb", From 043fbf57eb0a09b97249f80f6af9e4469591b473 Mon Sep 17 00:00:00 2001 From: PaParaZz1 Date: Thu, 16 Apr 2026 15:52:19 +0800 Subject: [PATCH 2/5] feature(nyz): add pairwise rm pipeline --- examples/godsmeme/README.md | 124 +++ examples/godsmeme/meme_dataset.py | 123 +-- examples/godsmeme/meme_utils.py | 431 ++++++++++ examples/godsmeme/prompts/generate_meme.txt | 77 ++ examples/godsmeme/prompts/reward_compare.txt | 24 + examples/godsmeme/reward_model.py | 849 ++++++++++--------- examples/godsmeme/test_reward_model_vllm.py | 103 +++ examples/godsmeme/train_colocate.py | 15 +- lightrft/trainer/fast_exp_maker.py | 1 + 9 files changed, 1273 insertions(+), 474 deletions(-) create mode 100644 examples/godsmeme/README.md create mode 100644 examples/godsmeme/meme_utils.py create mode 100644 examples/godsmeme/prompts/generate_meme.txt create mode 100644 examples/godsmeme/prompts/reward_compare.txt create mode 100644 examples/godsmeme/test_reward_model_vllm.py diff --git a/examples/godsmeme/README.md b/examples/godsmeme/README.md new file mode 100644 index 00000000..da2fdf7e --- /dev/null +++ b/examples/godsmeme/README.md @@ -0,0 +1,124 @@ +# GodsMeme GRPO Pipeline + +This example adapts LightRFT's vision-language GRPO stack to meme generation. +The policy consumes the already-formatted GodsMeme RL prompt from the dataset, generates the full reasoning-style response, and the reward pipeline renders the final meme text back onto the image before running a pairwise meme judge. + +## What is included + +- `train_colocate.py`: main GRPO training entry for meme RL. +- `meme_dataset.py`: dataset loader for the prepared GodsMeme RL JSON/JSONL rows. +- `meme_utils.py`: GodsMeme-specific parsing, rendering, and pair aggregation helpers. +- `reward_model.py`: pairwise meme judge wrapper and reward aggregation logic. +- `prompts/generate_meme.txt`: reference policy prompt format. +- `prompts/reward_compare.txt`: pairwise comparison prompt for the meme reward judge. +- `run_meme_grpo.sh`: example launch script. +- `test_reward_model_vllm.py`: optional real-model integration smoke test using vLLM. + +## Reward flow + +For each rollout group produced from the same prompt: + +1. Parse the policy completion and extract the `Text on the Meme` section. +2. Render the generated text onto the base image using the dataset's detection boxes when available. +3. Build pairwise comparisons inside the rollout group. +4. Ask a VLM judge which rendered meme is better. +5. Convert pairwise wins/scores into one scalar reward per sample. +6. Add a small format reward that checks the required GodsMeme response structure. + +The reward weights are read from `--reward_pretrain` JSON config: + +```text +final_reward = model_reward_weight * pairwise_reward + + format_reward_weight * format_reward +``` + +Defaults: + +- `model_reward_weight = 1.0` +- `format_reward_weight = 0.1` + +## Important rollout constraint + +The pairwise reward is computed inside each rollout micro-batch, so keep every GRPO group fully contained in one micro-batch: + +```text +micro_rollout_batch_size % n_samples_per_prompt == 0 +``` + +For standard GRPO runs, use `--advantage_estimator group_norm` and set `--n_samples_per_prompt > 1`. + +## Dataset expectation + +The RL dataset is expected to already be formatted as conversation-style GodsMeme rows: + +```json +{ + "id": "sample-001", + "image": "images/cat.jpg", + "conversations": [ + {"from": "human", "value": "...GodsMeme prompt... "}, + {"from": "assistant", "value": "...reference meme response..."} + ], + "text_loc_info": { + "loc": [[40, 60, 420, 170], [40, 330, 420, 430]] + } +} +``` + +Supported box metadata keys include `detections`, `text_loc_info`, `loc`, `bbox_scale`, and `bbox_normalized`. +If no boxes are available, the renderer falls back to a simple top/bottom banner layout. + +## Reward model configuration + +`--reward_pretrain` accepts either: + +1. A plain Hugging Face model path for the pairwise judge. +2. A JSON object if you want to override judge settings. + +Plain path example: + +```bash +--reward_pretrain /path/to/Qwen2.5-VL-7B-Instruct +``` + +JSON example: + +```bash +--reward_pretrain '{ + "pairwise": { + "path": "/path/to/Qwen2.5-VL-7B-Instruct", + "max_new_tokens": 96, + "pair_batch_size": 4, + "max_pairs_per_group": 0, + "model_reward_weight": 1.0, + "format_reward_weight": 0.1, + "reward_prompt_path": "examples/godsmeme/prompts/reward_compare.txt" + } +}' +``` + +## Running + +Edit the environment variables in `run_meme_grpo.sh`, then launch: + +```bash +bash examples/godsmeme/run_meme_grpo.sh +``` + +## Validation + +Lightweight unit tests: + +```bash +pytest examples/godsmeme/test_meme_dataset.py examples/godsmeme/test_reward_model.py +``` + +Optional real-model vLLM smoke test: + +```bash +RUN_GODSMEME_VLLM_TEST=1 \ +GODSMEME_REWARD_MODEL_PATH=/path/to/Qwen2.5-VL-7B-Instruct \ +GODSMEME_ANNOTATION_PATH=/path/to/train_data.jsonl \ +GODSMEME_IMAGE_ROOT=/path/to/image_root \ +pytest examples/godsmeme/test_reward_model_vllm.py -k real_data +``` diff --git a/examples/godsmeme/meme_dataset.py b/examples/godsmeme/meme_dataset.py index 4610cd91..28d8f41f 100644 --- a/examples/godsmeme/meme_dataset.py +++ b/examples/godsmeme/meme_dataset.py @@ -1,17 +1,19 @@ -from typing import List, Dict, Union -import os import json +import os import random -import torch -from PIL import Image +import re +from typing import Any, Dict, List, Tuple, Union + from torch.utils.data import Dataset +from meme_utils import extract_box_texts, resolve_expected_box_count + class MemeOnlineRLDataset(Dataset): - """Meme dataset class with lazy loading per item. - Supports both JSON array format (e.g. eval_data.json) and JSONL format. - """ - ASSISTANT_ROLES = ('gpt', 'assistant') + """Meme dataset class with lazy loading per item.""" + + ASSISTANT_ROLES = ("gpt", "assistant") + DEFAULT_LABEL = "meme_pairwise" def __init__( self, @@ -31,18 +33,13 @@ def __init__( self.annotation_path = annotation_path self.processor = processor - # Only load the raw data lines without processing self._raw_data = self._load_raw_data() if shuffle: random.shuffle(self._raw_data) - def _load_raw_data(self) -> List[Union[Dict, str]]: - """Load annotation file. Supports: - - JSON array format: entire file is one JSON list (e.g. eval_data.json) - - JSONL format: one JSON object per line - """ - with open(self.annotation_path, 'r', encoding='utf-8') as f: - content = f.read().strip() + def _load_raw_data(self) -> List[Union[Dict[str, Any], str]]: + with open(self.annotation_path, "r", encoding="utf-8") as handle: + content = handle.read().strip() if not content: return [] try: @@ -51,21 +48,59 @@ def _load_raw_data(self) -> List[Union[Dict, str]]: return data except json.JSONDecodeError: pass - with open(self.annotation_path, 'r', encoding='utf-8') as f: - return [line.strip() for line in f if line.strip()] - def _process_item(self, raw_item: Union[Dict, str]) -> Dict: - """Process a single data item on-demand. raw_item is either a dict or JSON string.""" - data = raw_item if isinstance(raw_item, dict) else json.loads(raw_item) - image_path = os.path.join(self.root_dir, data['image']) \ - if not os.path.isabs(data['image']) else data['image'] + with open(self.annotation_path, "r", encoding="utf-8") as handle: + return [line.strip() for line in handle if line.strip()] + def _resolve_image_path(self, data: Dict[str, Any]) -> str: + image_value = data.get("image") or data.get("image_path") or data.get("img") + if not image_value: + raise KeyError("Dataset row is missing `image`") + image_path = image_value if os.path.isabs(image_value) else os.path.join(self.root_dir, image_value) if not os.path.exists(image_path): raise FileNotFoundError(f"Image {image_path} does not exist") + return image_path - conversations = data['conversations'] - human_input = next(c['value'] for c in conversations if c['from'] == 'human' and '' in c['value']) - assistant_output = next(c['value'] for c in conversations if c['from'] in self.ASSISTANT_ROLES) + def _extract_user_request_from_prompt(self, prompt_text: str) -> str: + patterns = [ + r"\*\*User Input Parameters\*\*:\s*(.*?)(?=\n\s*\*\*Text on the Meme\*\*:|\Z)", + r"Input Parameters\s*:\s*(\[[^\n]+\])", + ] + for pattern in patterns: + match = re.search(pattern, prompt_text, re.IGNORECASE | re.DOTALL) + if match: + return match.group(1).strip() + return "" + + def _build_reference(self, data: Dict[str, Any], prompt_text: str, assistant_output: str) -> Dict[str, Any]: + reference: Dict[str, Any] = { + "id": data.get("id"), + "group_id": str(data.get("group_id") or data.get("sample_id") or data.get("id") or ""), + "user_request": self._extract_user_request_from_prompt(prompt_text), + "reference_output": assistant_output, + } + + for key in ("detections", "text_loc_info", "loc", "bbox_scale", "bbox_normalized", "expected_box_count"): + if key in data: + reference[key] = data[key] + + expected_box_count = resolve_expected_box_count(reference) + if expected_box_count is None: + box_texts = extract_box_texts(assistant_output) + if box_texts: + expected_box_count = len(box_texts) + if expected_box_count is not None: + reference["expected_box_count"] = expected_box_count + + return reference + + def _process_item(self, raw_item: Union[Dict[str, Any], str]) -> Tuple[str, List[str], Dict[str, Any], str]: + data = raw_item if isinstance(raw_item, dict) else json.loads(raw_item) + image_path = self._resolve_image_path(data) + + conversations = data["conversations"] + human_input = next(c["value"] for c in conversations if c["from"] == "human" and "" in c["value"]) + assistant_output = next(c["value"] for c in conversations if c["from"] in self.ASSISTANT_ROLES) prompt = [{ "role": "user", @@ -78,38 +113,24 @@ def _process_item(self, raw_item: Union[Dict, str]) -> Dict: "type": "text", "text": human_input }, - ] + ], }] prompt = self.processor.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) + reference = self._build_reference(data, human_input, assistant_output) + label = data.get("reward_rule_label", self.DEFAULT_LABEL) - return { - 'id': data['id'], - 'image_path': [image_path], - 'prompt': prompt, - # 'prompt': human_input + "\nAssistant:", - 'label': assistant_output, - } + return prompt, [image_path], reference, label - def __getitem__(self, index) -> Dict: - """Process and return item only when requested""" - # # Process the item on-demand (lazy loading) - raw_item = self._raw_data[index] - processed_item = self._process_item(raw_item) - # # text, image, label, reference - return ( - processed_item['prompt'], - processed_item['image_path'], - processed_item['label'], - processed_item['label'], - ) + def __getitem__(self, index: int) -> Tuple[str, List[str], Dict[str, Any], str]: + return self._process_item(self._raw_data[index]) def __len__(self) -> int: return len(self._raw_data) @staticmethod - def collate_fn(batch: List[Dict]): - text_list = [x[0] for x in batch] - image_list = [x[1] for x in batch] - reference_list = [x[2] for x in batch] - label_list = [x[3] for x in batch] + def collate_fn(batch: List[Tuple[str, List[str], Dict[str, Any], str]]): + text_list = [item[0] for item in batch] + image_list = [item[1] for item in batch] + reference_list = [item[2] for item in batch] + label_list = [item[3] for item in batch] return text_list, image_list, reference_list, label_list diff --git a/examples/godsmeme/meme_utils.py b/examples/godsmeme/meme_utils.py new file mode 100644 index 00000000..3671a5f2 --- /dev/null +++ b/examples/godsmeme/meme_utils.py @@ -0,0 +1,431 @@ +import os +import re +import random +from dataclasses import dataclass +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +from PIL import Image, ImageDraw, ImageFont + +_REQUIRED_SECTION_PATTERNS = [ + r"\[Comprehensive Description Section\]", + r"\[Usage Scenarios Section\]", + r"\[Text Analysis Section\]", + r"\[Specific Analysis with User Input\]", + r"Text on the Meme:", +] + + +@dataclass +class MemeRenderConfig: + font_name: str = "DejaVuSans.ttf" + min_font_size: int = 14 + max_font_size: int = 72 + line_spacing: int = 4 + outline_width: int = 2 + margin: int = 6 + default_padding_ratio: float = 0.06 + + +@dataclass +class PairwisePreference: + index_a: int + index_b: int + score_a: float + score_b: float + + +def load_text_file(path: str) -> str: + with open(path, "r", encoding="utf-8") as handle: + return handle.read().strip() + + +def extract_assistant_response(text: str) -> str: + if "<|im_start|>assistant" in text: + return text.split("<|im_start|>assistant")[-1].strip() + if "assistant\n" in text: + return text.split("assistant\n")[-1].strip() + return text.strip() + + +def extract_text_on_meme_section(response: str) -> str: + response = extract_assistant_response(response) + match = re.search(r"Text on the Meme:\s*(.*)$", response, re.IGNORECASE | re.DOTALL) + if not match: + return "" + return match.group(1).strip() + + +def extract_box_texts(response: str) -> List[str]: + section = extract_text_on_meme_section(response) + if not section: + return [] + + matches = list(re.finditer( + r"(?im)^\s*box\s*(\d+)\s*:\s*(.*?)(?=^\s*box\s*\d+\s*:|\Z)", + section, + re.DOTALL, + )) + if matches: + items = [] + for match in matches: + payload = match.group(2).strip().strip('"').strip() + if payload: + items.append(payload) + if items: + return items + + lines = [] + for raw_line in section.splitlines(): + line = raw_line.strip().strip('"').strip() + if line: + lines.append(line) + return lines + + +def compute_meme_format_reward(response: str, expected_boxes: Optional[int] = None) -> float: + response = extract_assistant_response(response) + checks = 0 + passed = 0 + + for pattern in _REQUIRED_SECTION_PATTERNS: + checks += 1 + if re.search(pattern, response, re.IGNORECASE): + passed += 1 + + checks += 2 + if re.search(r"(?im)^\s*Step\s*1\s*:", response): + passed += 1 + if re.search(r"(?im)^\s*Step\s*2\s*:", response): + passed += 1 + + fragments = extract_box_texts(response) + checks += 2 + if fragments: + passed += 1 + if fragments and all(fragment.strip() for fragment in fragments): + passed += 1 + + if expected_boxes is not None and expected_boxes > 0: + checks += 1 + if len(fragments) == expected_boxes: + passed += 1 + + return float(passed) / float(checks) if checks else 0.0 + + +def normalize_bbox(bbox: Any) -> Optional[Tuple[float, float, float, float]]: + if bbox is None: + return None + if isinstance(bbox, dict): + for key in ("bbox", "box", "loc", "coordinates"): + if key in bbox: + bbox = bbox[key] + break + if isinstance(bbox, tuple): + bbox = list(bbox) + if not isinstance(bbox, list) or len(bbox) != 4: + return None + + values: List[float] = [] + for item in bbox: + try: + values.append(float(item)) + except (TypeError, ValueError): + return None + return tuple(values) + + +def normalize_detections(reference: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]: + if not isinstance(reference, dict): + return [] + + raw_detections = reference.get("detections") + if raw_detections is None and isinstance(reference.get("text_loc_info"), dict): + text_loc_info = reference["text_loc_info"] + locs = text_loc_info.get("loc", []) + texts = text_loc_info.get("text", []) + if isinstance(texts, str): + texts = extract_box_texts(f"Text on the Meme:\n{texts}") + raw_detections = [{"bbox": loc, "text": texts[idx] if idx < len(texts) else ""} for idx, loc in enumerate(locs)] + if raw_detections is None and isinstance(reference.get("loc"), list): + raw_detections = [{"bbox": loc} for loc in reference.get("loc", [])] + + detections: List[Dict[str, Any]] = [] + for item in raw_detections or []: + bbox = normalize_bbox(item) + if bbox is None: + continue + payload = item if isinstance(item, dict) else {"bbox": item} + detections.append({"bbox": bbox, **payload}) + return detections + + +def resolve_expected_box_count(reference: Optional[Dict[str, Any]]) -> Optional[int]: + detections = normalize_detections(reference) + if detections: + return len(detections) + if isinstance(reference, dict) and isinstance(reference.get("expected_box_count"), int): + return reference["expected_box_count"] + return None + + +def get_user_request(reference: Optional[Dict[str, Any]], fallback_prompt: Optional[str] = None) -> str: + if isinstance(reference, dict): + for key in ("user_input_text", "user_request", "input_params", "prompt_summary"): + value = reference.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + return (fallback_prompt or "").strip() + + +def get_reference_output(reference: Optional[Dict[str, Any]]) -> str: + if not isinstance(reference, dict): + return "" + for key in ("reference_output", "reference", "answer", "label_text"): + value = reference.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + return "" + + +def get_reference_group_id(reference: Optional[Dict[str, Any]], fallback_index: int) -> str: + if isinstance(reference, dict): + for key in ("group_id", "sample_id", "id"): + value = reference.get(key) + if value is not None: + return str(value) + return str(fallback_index) + + +def get_first_image(raw_image: Any) -> Optional[Image.Image]: + if raw_image is None: + return None + if isinstance(raw_image, Image.Image): + return raw_image + if isinstance(raw_image, list) and raw_image: + return get_first_image(raw_image[0]) + return None + + +def _resolve_font_path(font_name: str) -> Optional[str]: + search_roots = [] + env_root = os.getenv("MEMEGENERATOR_FONT_DIR") + if env_root: + search_roots.append(os.path.expanduser(env_root)) + search_roots.extend([ + "/usr/share/fonts/truetype/dejavu", + "/usr/share/fonts/truetype/liberation2", + "/System/Library/Fonts", + "/Library/Fonts", + ]) + for root in search_roots: + candidate = os.path.join(root, font_name) + if os.path.exists(candidate): + return candidate + return font_name + + +def _load_font(font_name: str, font_size: int) -> ImageFont.FreeTypeFont: + try: + return ImageFont.truetype(_resolve_font_path(font_name), font_size) + except OSError: + return ImageFont.load_default() + + +def _measure_text(draw: ImageDraw.ImageDraw, text: str, font: ImageFont.ImageFont) -> Tuple[int, int]: + if hasattr(draw, "textbbox"): + left, top, right, bottom = draw.textbbox((0, 0), text, font=font) + return right - left, bottom - top + return draw.textsize(text, font=font) + + +def _wrap_text(draw: ImageDraw.ImageDraw, text: str, font: ImageFont.ImageFont, max_width: int) -> List[str]: + words = text.split() + if not words: + return [""] + + lines: List[str] = [] + current = words[0] + for word in words[1:]: + candidate = f"{current} {word}".strip() + if _measure_text(draw, candidate, font)[0] <= max_width: + current = candidate + else: + lines.append(current) + current = word + lines.append(current) + return lines + + +def _fit_font(draw: ImageDraw.ImageDraw, text: str, box: Tuple[int, int, int, int], + config: MemeRenderConfig) -> Tuple[ImageFont.ImageFont, List[str]]: + x1, y1, x2, y2 = box + max_width = max(1, x2 - x1 - (2 * config.margin)) + max_height = max(1, y2 - y1 - (2 * config.margin)) + + best_font: ImageFont.ImageFont = _load_font(config.font_name, config.min_font_size) + best_lines = _wrap_text(draw, text, best_font, max_width) + + for size in range(config.min_font_size, config.max_font_size + 1): + font = _load_font(config.font_name, size) + lines = _wrap_text(draw, text, font, max_width) + line_heights = [_measure_text(draw, line, font)[1] for line in lines] + total_height = sum(line_heights) + config.line_spacing * max(0, len(lines) - 1) + line_width = max(_measure_text(draw, line, font)[0] for line in lines) + if line_width <= max_width and total_height <= max_height: + best_font = font + best_lines = lines + else: + break + + return best_font, best_lines + + +def _choose_colors(image: Image.Image, box: Tuple[int, int, int, + int]) -> Tuple[Tuple[int, int, int], Tuple[int, int, int]]: + crop = image.crop(box) + mean_pixel = crop.resize((1, 1)).getpixel((0, 0)) if crop.size[0] > 0 and crop.size[1] > 0 else (255, 255, 255) + if isinstance(mean_pixel, int): + mean_pixel = (mean_pixel, mean_pixel, mean_pixel) + luminance = (0.299 * mean_pixel[0]) + (0.587 * mean_pixel[1]) + (0.114 * mean_pixel[2]) + if luminance > 186: + return (0, 0, 0), (255, 255, 255) + return (255, 255, 255), (0, 0, 0) + + +def _scale_box( + bbox: Tuple[float, float, float, float], + image: Image.Image, + reference: Optional[Dict[str, Any]] = None, +) -> Tuple[int, int, int, int]: + width, height = image.size + bbox_scale = None + normalized = False + if isinstance(reference, dict): + bbox_scale = reference.get("bbox_scale") + normalized = bool(reference.get("bbox_normalized", False)) + + x1, y1, x2, y2 = bbox + values = [x1, y1, x2, y2] + if bbox_scale is not None: + scale = float(bbox_scale) + x1, x2 = (x1 / scale) * width, (x2 / scale) * width + y1, y2 = (y1 / scale) * height, (y2 / scale) * height + elif normalized or max(values) <= 1.0: + x1, x2 = x1 * width, x2 * width + y1, y2 = y1 * height, y2 * height + + left = max(0, min(int(round(x1)), width - 1)) + top = max(0, min(int(round(y1)), height - 1)) + right = max(left + 1, min(int(round(x2)), width)) + bottom = max(top + 1, min(int(round(y2)), height)) + return left, top, right, bottom + + +def default_render_boxes(image: Image.Image, count: int, config: MemeRenderConfig) -> List[Tuple[int, int, int, int]]: + width, height = image.size + pad_x = int(width * config.default_padding_ratio) + pad_y = int(height * config.default_padding_ratio) + if count <= 1: + return [(pad_x, pad_y, width - pad_x, max(pad_y + 1, int(height * 0.2)))] + top_box = (pad_x, pad_y, width - pad_x, max(pad_y + 1, int(height * 0.2))) + bottom_box = (pad_x, max(int(height * 0.78), pad_y), width - pad_x, height - pad_y) + boxes = [top_box, bottom_box] + if count > 2: + middle_height = max(1, int((height * 0.58) / max(1, count - 2))) + start_y = int(height * 0.24) + for idx in range(count - 2): + y1 = start_y + (idx * middle_height) + y2 = min(height - pad_y, y1 + middle_height) + boxes.insert(-1, (pad_x, y1, width - pad_x, max(y1 + 1, y2))) + return boxes[:count] + + +def render_meme_image( + image: Image.Image, + texts: Sequence[str], + detections: Optional[Sequence[Dict[str, Any]]] = None, + reference: Optional[Dict[str, Any]] = None, + config: Optional[MemeRenderConfig] = None, +) -> Image.Image: + render_config = config or MemeRenderConfig() + canvas = image.convert("RGB").copy() + draw = ImageDraw.Draw(canvas) + + cleaned_texts = [text.strip() for text in texts if text and text.strip()] + if not cleaned_texts: + return canvas + + boxes: List[Tuple[int, int, int, int]] = [] + for detection in detections or []: + bbox = normalize_bbox(detection) + if bbox is None: + continue + boxes.append(_scale_box(bbox, canvas, reference=reference)) + + if not boxes: + boxes = default_render_boxes(canvas, len(cleaned_texts), render_config) + + if len(cleaned_texts) > len(boxes): + merged = list(cleaned_texts[:len(boxes)]) + merged[-1] = "\n".join([merged[-1], *cleaned_texts[len(boxes):]]).strip() + cleaned_texts = merged + else: + cleaned_texts = cleaned_texts[:len(boxes)] + + for box, text in zip(boxes, cleaned_texts): + font, wrapped_lines = _fit_font(draw, text, box, render_config) + fill_color, outline_color = _choose_colors(canvas, box) + line_sizes = [_measure_text(draw, line, font) for line in wrapped_lines] + total_height = sum(height + for _, height in line_sizes) + render_config.line_spacing * max(0, + len(wrapped_lines) - 1) + x1, y1, x2, y2 = box + current_y = y1 + max(render_config.margin, (y2 - y1 - total_height) // 2) + + for line, (line_width, line_height) in zip(wrapped_lines, line_sizes): + current_x = x1 + max(render_config.margin, (x2 - x1 - line_width) // 2) + for dx in range(-render_config.outline_width, render_config.outline_width + 1): + for dy in range(-render_config.outline_width, render_config.outline_width + 1): + if dx == 0 and dy == 0: + continue + draw.text((current_x + dx, current_y + dy), line, font=font, fill=outline_color) + draw.text((current_x, current_y), line, font=font, fill=fill_color) + current_y += line_height + render_config.line_spacing + + return canvas + + +def sample_group_pairs(group_size: int, max_pairs: int = 0, seed: Optional[int] = None) -> List[Tuple[int, int]]: + pairs = [(left, right) for left in range(group_size) for right in range(left + 1, group_size)] + if max_pairs <= 0 or len(pairs) <= max_pairs: + return pairs + rng = random.Random(seed) + rng.shuffle(pairs) + return pairs[:max_pairs] + + +def aggregate_pairwise_preferences( + batch_size: int, + preferences: Iterable[PairwisePreference], +) -> List[float]: + totals = [0.0 for _ in range(batch_size)] + counts = [0 for _ in range(batch_size)] + + for pref in preferences: + denom = pref.score_a + pref.score_b + if denom <= 0: + norm_a = 0.5 + norm_b = 0.5 + else: + norm_a = pref.score_a / denom + norm_b = pref.score_b / denom + totals[pref.index_a] += norm_a + totals[pref.index_b] += norm_b + counts[pref.index_a] += 1 + counts[pref.index_b] += 1 + + rewards = [] + for total, count in zip(totals, counts): + rewards.append(total / count if count > 0 else 0.0) + return rewards diff --git a/examples/godsmeme/prompts/generate_meme.txt b/examples/godsmeme/prompts/generate_meme.txt new file mode 100644 index 00000000..a056bbf5 --- /dev/null +++ b/examples/godsmeme/prompts/generate_meme.txt @@ -0,0 +1,77 @@ +**Meme Text Generation Framework** +Based on the meme basemap and user input, analyze what can be written on this basemap that meets the user's needs and is as humorous as possible. +**Phase 1: Base Image Analysis** +[Comprehensive Description Section] +- **Visual Deconstruction**: +- Primary subjects (demeanor/movement/apparel of entities) +- Composition logic (focal points/color contrast/spatial relationships) +- Cultural signifiers (recognizable meme formats/pop culture references) +- Narrative cues (body language implications/prop symbolism) +- Keep the tone vivid and concrete; avoid flat, purely descriptive lists. + +[Usage Scenarios Section] +- **Scenario Modeling**: +- Social contexts (group chats/comment sections/private conversations) +- Topic alignment (workplace culture/life struggles/viral trends) +- Emotional mapping (sarcasm/self-deprecation/absurdist/dark humor) +- Cross-platform adaptation (short video captions/chat stickers/forum posts) +- Highlight why the scenario is funny or relatable instead of just stating where it fits. + +[Text Analysis Section] +- **Humor Engineering**: +- Wordplay (puns/homophones/semantic reversal) +- Cognitive dissonance (expectation subversion/scale exaggeration/role mismatch) +- Emotional resonance (generational gaps/life frustrations/cringe moments) +- Format optimization (suspenseful opening line/punchline reversal/rhyme schemes) +- Use the tone and rhythm of the sentences demonstrated in this framework in the [Text Analysis Section] to create a consistent overall sense of language, defaulting to the two-line setup → punchline cadence shown in the example (e.g., “When… / But…”). + +--- + +**Phase 2: Customization Process** +[Specific Analysis with User Input] + +**Step 1: Contextual Bridging** +- **Input Decoding**: +- Quantify [Intensity] as dramatic escalation (0-10 scale) +- Map [Intent] to visual elements' interactive potential +- Establish topological connections between [Context/Theme] and meme formats +- Even if the user does not supply sentences, treat the demonstration’s setup/punchline cadence as the blueprint; only replace the thematic nouns/adjectives with the requested keywords to build the new contrast. + +**Step 2: Humor Optimization** +- **Multidimensional Strategies**: +- Tone calibration: Adjust phrasing sharpness using [Keywords] +- Tension building: Create contrast between static imagery and dynamic text +- Cultural alignment: Balance trending phrases with evergreen humor elements +- Prioritize concise, impactful punchlines (1-2 sentences per box) that stay faithful to the user's emotional tone and amplify the humorous contrast they hinted at, retaining the setup → punchline rhythm unless the user explicitly requests another format. + +**Text on the Meme**: +[Read the chart from top to bottom, from left to right in each red box should be put what text in turn, with box1: text fragment 1\nbox2: text fragment 2\n, there are several boxes to correspond to the output of a few paragraphs of the text corresponds to each other, here pay attention to the combination of the box in the map position, the meaning of the map, the user input, and the previous reasoning to generate the theme of the humor of the text. Unless told otherwise, deliver two fragments that mirror the demonstration’s “When… / But…” cadence with updated thematic wording.] +--- + +**Output Demonstration Example** + +[Comprehensive Description Section] +The image employs the classic \"Shocked Cat\" meme template, featuring a close-up of an orange tabby cat with dilated circular pupils and forward-stretched whiskers creating visual tension. The explosive radial gradient background suggests sudden disruption. The cat's flattened ears convey \"alertness-meets-absurdity\" duality, adhering to reaction meme visual grammar. + +[Usage Scenarios Section] +Optimal use cases include: +1. Social media rants about last-minute work demands +2. Gaming group reactions to unexpected team failures +3. E-commerce shoppers encountering bizarre product descriptions +Ideal scenarios should follow \"unexpected shock → exaggerated response\" narrative structures + +[Text Analysis Section] +Suggested text: \n\"Friday 5:55 PM\" (top line establishes time pressure) +\"Client says 'Just one more thing...'\" (bottom line triggers conflict) +Humor mechanisms: Amplifies workplace frustrations through the cat's dramatic expression, using cross-dimensional analogy between time constraints and animal reactions + +[Specific Analysis with User Input] +Step 1: Given [Emotion: Frustration][Intensity: 8][Theme: Fitness failures], emphasize exaggerated body-text correlation. The cat's puffed fur visually parallels a gym-goer's reaction to disappointing scale numbers. +Step 2: Implement absurd escalation: \"When your trainer says\" (setup) → \"'One more rep' actually means 20\" (absurd payoff). Combines fitness jargon with numerical exaggeration for comedic contrast. + +Text on the Meme: +\"When the pre-workout kicks in +But your willpower checks out early\" + + +Now please generate the analysis and text results based on this image and user input parameters. \ No newline at end of file diff --git a/examples/godsmeme/prompts/reward_compare.txt b/examples/godsmeme/prompts/reward_compare.txt new file mode 100644 index 00000000..0cc2809d --- /dev/null +++ b/examples/godsmeme/prompts/reward_compare.txt @@ -0,0 +1,24 @@ +You are a strict meme reward model. +You will receive two final meme images created from the same base image and the same user request. +Your job is to compare which meme is better. + +Evaluation criteria: +1. Better alignment with the user request and emotional tone. +2. Stronger humor and punchline quality. +3. Better use of the meme template and box layout. +4. Clearer, more concise meme wording. + +User request: +{user_request} + +Candidate 1 extracted text: +{text_a} + +Candidate 2 extracted text: +{text_b} + +Return exactly the following four lines: +Image 1 score: +Image 2 score: +Winner: <1 or 2 or tie> +Reason: diff --git a/examples/godsmeme/reward_model.py b/examples/godsmeme/reward_model.py index c3fcf819..23f3f9a3 100644 --- a/examples/godsmeme/reward_model.py +++ b/examples/godsmeme/reward_model.py @@ -1,100 +1,309 @@ -from typing import Callable, Dict, List, Tuple, Union, Sequence, Optional -import re import json +import os +import re +from typing import Any, Dict, List, Optional, Sequence, Tuple + import torch import torch.nn as nn -import torch.nn.functional as F from qwen_vl_utils import process_vision_info -from transformers import AutoProcessor, AutoModel, PreTrainedModel, AutoConfig, Qwen2_5_VLForConditionalGeneration +from transformers import AutoConfig, AutoProcessor, Qwen2_5_VLForConditionalGeneration from lightrft.utils import get_current_device +from meme_utils import ( + MemeRenderConfig, + PairwisePreference, + aggregate_pairwise_preferences, + compute_meme_format_reward, + extract_box_texts, + get_first_image, + get_user_request, + load_text_file, + normalize_detections, + render_meme_image, + resolve_expected_box_count, + sample_group_pairs, +) + +_PAIRWISE_LABEL_KEY = "pairwise" +_REWARD_STATE = { + "model_reward_weight": 1.0, + "format_reward_weight": 0.1, +} + + +def _as_torch_device(device_like: Any) -> torch.device: + if isinstance(device_like, torch.device): + return device_like + if isinstance(device_like, int): + return torch.device(f"cuda:{device_like}") + return torch.device(device_like) + + +def _default_reward_prompt_path() -> str: + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "prompts", "reward_compare.txt") + + +def _load_reward_prompt(path: Optional[str] = None) -> str: + prompt_path = path or _default_reward_prompt_path() + if os.path.exists(prompt_path): + return load_text_file(prompt_path) + return ( + "You are a strict meme reward model.\n" + "Compare the two meme images and return exactly:\n" + "Image 1 score: <0-10>\n" + "Image 2 score: <0-10>\n" + "Winner: <1 or 2 or tie>\n" + "Reason: " + ) + -class AttentionPooling(nn.Module): - """ - Overview: - Attention pooling layer on the sequence dimension of LLM/VLM hidden states. - """ +def _parse_reward_config(raw_reward_pretrain: str) -> Dict[str, Any]: + if raw_reward_pretrain is None or not str(raw_reward_pretrain).strip(): + raise ValueError("`reward_pretrain` must point to a meme judge model or a JSON config") + + text = str(raw_reward_pretrain).strip() + try: + cfg = json.loads(text) + except json.JSONDecodeError: + cfg = {"pairwise": {"path": text}} + + if isinstance(cfg, str): + cfg = {"pairwise": {"path": cfg}} + if not isinstance(cfg, dict): + raise ValueError("Unsupported meme reward config format") + + if "pairwise" in cfg: + pairwise_cfg = cfg["pairwise"] + elif "outcome" in cfg: + pairwise_cfg = cfg["outcome"] + else: + first_key = next(iter(cfg.keys())) + pairwise_cfg = cfg[first_key] + + if isinstance(pairwise_cfg, str): + pairwise_cfg = {"path": pairwise_cfg} + if not isinstance(pairwise_cfg, dict) or not pairwise_cfg.get("path"): + raise ValueError("Meme reward config must contain a model path") + return pairwise_cfg + + +def build_pairwise_judge_message( + reward_prompt: str, + image_a, + image_b, + text_a: str, + text_b: str, + user_request: str, +) -> List[Dict[str, Any]]: + prompt = reward_prompt.format( + user_request=user_request or "N/A", + text_a=text_a or "(empty)", + text_b=text_b or "(empty)", + ) + return [{ + "role": "user", + "content": [ + { + "type": "text", + "text": f"{prompt}\n\nCandidate 1 image:" + }, + { + "type": "image", + "image": image_a + }, + { + "type": "text", + "text": "Candidate 2 image:" + }, + { + "type": "image", + "image": image_b + }, + ], + }] + + +def parse_pair_judge_response(response: str) -> Tuple[float, float]: + response = (response or "").strip() + if not response: + return 0.5, 0.5 + + def _extract_score(image_idx: int) -> Optional[float]: + pattern = rf"Image\s*{image_idx}\s*score\s*[::]\s*([0-9]+(?:\.[0-9]+)?)" + match = re.search(pattern, response, re.IGNORECASE) + if not match: + return None + try: + return float(match.group(1)) + except ValueError: + return None + + score_a = _extract_score(1) + score_b = _extract_score(2) + if score_a is not None and score_b is not None: + return score_a, score_b + + winner_match = re.search(r"Winner\s*[::]\s*(1|2|tie)", response, re.IGNORECASE) + if winner_match: + winner = winner_match.group(1).lower() + if winner == "1": + return 1.0, 0.0 + if winner == "2": + return 0.0, 1.0 + return 0.5, 0.5 + + numeric_scores = re.findall(r"([0-9]+(?:\.[0-9]+)?)", response) + if len(numeric_scores) >= 2: + try: + return float(numeric_scores[0]), float(numeric_scores[1]) + except ValueError: + pass + return 0.5, 0.5 + + +def group_meme_rollout_indices( + refs: Optional[Sequence[Any]], + batch_size: int, + n_samples_per_prompt: int = 1, +) -> List[List[int]]: + if refs and len(refs) >= batch_size: + groups: List[List[int]] = [] + current_group: List[int] = [] + current_group_id: Optional[str] = None + usable = True + + for idx in range(batch_size): + ref = refs[idx] + group_id = None + if isinstance(ref, dict): + for key in ("group_id", "sample_id", "id"): + value = ref.get(key) + if value is not None: + group_id = str(value) + break + if group_id is None: + usable = False + break + + if current_group and group_id != current_group_id: + groups.append(current_group) + current_group = [idx] + else: + current_group.append(idx) + current_group_id = group_id + + if usable and current_group: + groups.append(current_group) + if usable and sum(len(group) for group in groups) == batch_size and any(len(group) > 1 for group in groups): + return groups + + chunk_size = max(1, int(n_samples_per_prompt)) + return [list(range(start, min(start + chunk_size, batch_size))) for start in range(0, batch_size, chunk_size)] + + +class MemePairwiseJudge(nn.Module): def __init__( self, - hidden_size: int, - num_heads: int = 4, - qkv_bias: bool = False, - position_bias: bool = False, - position_bias_scale: float = 3.0, + base_model: Qwen2_5_VLForConditionalGeneration, + processor, + reward_prompt: str, + max_new_tokens: int = 96, + pair_batch_size: int = 4, + n_samples_per_prompt: int = 1, + max_pairs_per_group: int = 0, + render_config: Optional[MemeRenderConfig] = None, ): - super(AttentionPooling, self).__init__() - self.num_heads = num_heads - self.head_dim = hidden_size // num_heads - self.scale = self.head_dim ** -0.5 - self.position_bias = position_bias - self.position_bias_scale = position_bias_scale - - self.k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) - self.v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) - # Using 0.02 for better initialization - self.query = nn.Parameter(torch.randn(hidden_size) * 0.02) - - def forward(self, hidden_states): - B, S, C = hidden_states.shape - - # Multi-head projection for key and value - k = self.k(hidden_states).reshape(B, S, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # B, H, S, D - v = self.v(hidden_states).reshape(B, S, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # B, H, S, D - - # Expand query for batch dimension - q = self.query.unsqueeze(0).expand(B, -1, -1) # B, H, C - q = q.unsqueeze(2) # B, H, 1, C - q = q.reshape(B, self.num_heads, 1, self.head_dim) # B, H, 1, C - - # Attention weights - attn = (q @ k.transpose(-2, -1)) * self.scale # B, H, 1, S - - # Add position bias - if self.position_bias: - position_bias = torch.arange(S, device=k.device).float() / S * self.position_bias_scale - attn = attn + position_bias.view(1, 1, 1, -1) # Add position bias - - # Attention pooling - attn = torch.softmax(attn, dim=-1) # B, H, 1, S - attn = attn.to(v.dtype) - out = (attn @ v).squeeze(2) # B, H, D - out = out.reshape(B, -1) # B, C - - return out - - -class Qwen2VLRewardModelMemeOutcome(PreTrainedModel): - def __init__(self, pretrained_model): - super().__init__(pretrained_model.config) - self.pretrained_model = pretrained_model - - if hasattr(pretrained_model.config, 'hidden_size'): - hidden_size = pretrained_model.config.hidden_size - elif hasattr(pretrained_model.config, 'd_model'): - hidden_size = pretrained_model.config.d_model - else: - raise ValueError("Cannot determine hidden size from model config") - - self.attention_pooling = AttentionPooling( - hidden_size=hidden_size, - num_heads=4, - qkv_bias=False, - position_bias=True, - position_bias_scale=3.0, - ) + super().__init__() + self.base_model = base_model + self.processor = processor + self.reward_prompt = reward_prompt + self.max_new_tokens = int(max_new_tokens) + self.pair_batch_size = int(pair_batch_size) + self.n_samples_per_prompt = int(n_samples_per_prompt) + self.max_pairs_per_group = int(max_pairs_per_group) + self.render_config = render_config or MemeRenderConfig() + + def _render_candidates( + self, + raw_images: Sequence[Any], + references: Sequence[Any], + prompt_and_outputs: Sequence[str], + ) -> Tuple[List[Any], List[str]]: + rendered_images: List[Any] = [] + extracted_texts: List[str] = [] + + for raw_image, reference, prompt_and_output in zip(raw_images, references, prompt_and_outputs): + extracted_boxes = extract_box_texts(prompt_and_output or "") + extracted_texts.append("\\n".join(extracted_boxes) if extracted_boxes else "") + + image = get_first_image(raw_image) + if image is None: + rendered_images.append(None) + continue + + detections = normalize_detections(reference if isinstance(reference, dict) else None) + rendered_images.append( + render_meme_image( + image=image, + texts=extracted_boxes, + detections=detections, + reference=reference if isinstance(reference, dict) else None, + config=self.render_config, + ) + ) - self.classification_head = nn.Linear(hidden_size, 2) - nn.init.normal_(self.classification_head.weight, std=0.02) - nn.init.zeros_(self.classification_head.bias) + return rendered_images, extracted_texts + + def _generate_pair_scores(self, pair_jobs: List[Dict[str, Any]], device: torch.device) -> List[Tuple[float, float]]: + scores: List[Tuple[float, float]] = [] + if not pair_jobs: + return scores + + step = max(1, self.pair_batch_size) + for start in range(0, len(pair_jobs), step): + batch_jobs = pair_jobs[start:start + step] + messages = [job["message"] for job in batch_jobs] + texts = [ + self.processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + for message in messages + ] + image_inputs = [] + for message in messages: + processed = process_vision_info(message) + if isinstance(processed, tuple): + image_inputs.append(processed[0]) + else: + image_inputs.append(processed) + + inputs = self.processor( + text=texts, + images=image_inputs, + padding=True, + truncation=True, + return_tensors="pt", + ) + for key, value in list(inputs.items()): + if torch.is_tensor(value): + inputs[key] = value.to(device) + + generation_kwargs = { + "input_ids": inputs["input_ids"], + "attention_mask": inputs["attention_mask"], + "max_new_tokens": self.max_new_tokens, + "do_sample": False, + } + if "pixel_values" in inputs: + generation_kwargs["pixel_values"] = inputs["pixel_values"] + if "image_grid_thw" in inputs: + generation_kwargs["image_grid_thw"] = inputs["image_grid_thw"] - self.attention_pooling.bfloat16() - self.classification_head.bfloat16() + generated = self.base_model.generate(**generation_kwargs) + trimmed = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs["input_ids"], generated)] + decoded = self.processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False) + scores.extend(parse_pair_judge_response(text) for text in decoded) - @classmethod - def from_pretrained(cls, pretrained_model: PreTrainedModel): - """Create a binary classification model from a pretrained model.""" - return cls(pretrained_model) + return scores @torch.no_grad() def forward( @@ -107,245 +316,76 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, **kwargs, - ): - """Forward pass for meme reward evaluation.""" - # Forward through base model - prompt_and_outputs = kwargs.get('prompt_and_output') - raw_images = kwargs.get('raw_images') - outputs = self.pretrained_model( - input_ids=input_ids.cuda(), - attention_mask=attention_mask.cuda(), - pixel_values=pixel_values, - image_grid_thw=image_grid_thw, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=True, - ) - - # Extract hidden states - if hasattr(outputs, 'last_hidden_state'): - hidden_states = outputs.last_hidden_state - elif hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None: - hidden_states = outputs.hidden_states[-1] - else: - raise ValueError("Cannot extract hidden states from model output") - - # Use attention pooling - pooled_output = self.attention_pooling(hidden_states) - - # Get logits - logits = self.classification_head(pooled_output) - - # Get logits [batch_size, 2], first dimension is probability of 0, second dimension is probability of 1 - logits = self.classification_head(pooled_output) - - # Calculate probabilities and binarize - probabilities = F.softmax(logits, dim=-1) # [batch_size, 2] - # If the second dimension (probability of 1) is larger, output 1, otherwise output 0 - binary_scores = (probabilities[:, 1] > probabilities[:, 0]).float() - - return { - 'score': binary_scores, # Binary result of 0/1 - 'logits': logits # Original logits, containing scores for two dimensions - } - - -class Qwen2VLRewardModelMemeContent(PreTrainedModel): - system_prompt = f""" - You are a professional meme text generation evaluation expert who is good at evaluating the quality of the current reasoning process in combination with images.\n\n - """ - eval_prompt = """ - You are a professional and strict-scoring expert in evaluating meme text generation, responsible for scoring the quality of the reasoning process (Chain of Thought, CoT) generated by the model. - This reasoning process is a text generation reasoning conducted based on the first text-free base image and input parameters (i.e., the user's requirements for the meme). - - Details of the evaluation task: - 1. Input parameters: - {input_params} - - 2. Standard reasoning process (reference standard answer): - {standard_cot} - - 3. Actual reasoning process to be evaluated: (The generated text is after "Text on the Meme") - {actual_cot} - - Please conduct the evaluation by combining the two provided images (the base image and the standard meme image with text), and the scoring must be strict. The evaluation criteria are as follows: + ) -> Dict[str, torch.Tensor]: + del attention_mask, pixel_values, image_grid_thw, return_dict, output_attentions, output_hidden_states - 1. Whether the chain of thought process includes an analysis of the expressions/actions/facial features/relationships/scenes of the entities in the image, and check its correctness: (Total 10 points) - a. Rough description: For example, there is a woman in this picture; 1 point; - b. With some details but no description of actions/facial expressions: For example, there is a woman in this picture, wearing a hat, sitting in a car; 4 points; - c. With details and actions/facial expressions: For example, there is a woman in this picture, wearing a hat, sitting in a car, looking very happy; 7 points; - d. Not only explaining the details such as the actions and expressions of the characters in the picture, but also immediately associating the character relationships/scenes where the actions occur; 10 points; + prompt_and_output = kwargs.get("prompt_and_output") or [] + raw_images = kwargs.get("raw_images") or [] + references = kwargs.get("references") or [] - 2. Analysis of further scene associations based on the relationships between entities in the image: (Total 10 points) - a. Only roughly describing possible scenes without specificity: For example, this may happen in daily life; 1 point; - b. Describing a relatively specific scene: For example, this may be the scene when you went out with friends to drink and found yourself vomiting; 4 points; - c. Describing multiple relatively specific scenes: For example, this may be the scene when you went out with friends to drink and found yourself vomiting, or the scene when the teacher checks homework but you find you haven't finished it; 10 points; - - Next, evaluate the content after [Specific analysis with user input]: - - 1. Whether the chain of thought further specifies the scene based on the previously associated scenes combined with user needs, or re-associates a scene more in line with user needs: (5 points for each satisfied item, total 20 points) - a. Whether the intention expressed in the sentence is consistent with the user's emotions; - b. Whether the intention expressed in the sentence is consistent with the user's intentions and the theme; - c. Whether the topic of the entire sentence is consistent with the keyword topic; - d. Whether some humorous techniques are used, such as puns/homophones/semantic reversal/subverting expectations/exaggeration/role dislocation/suspenseful beginning/punchline reversal/rhyming structure/internet memes; - - 2. Logical fluency and text length between the final generated text and the reasoning process: (Total 10 points) - a. The reasoning is very forced, and the humor of the answer paired with this base image is much worse than that of the standard meme, and the text length is much longer than that of the standard meme; 1 point; - b. The reasoning is roughly valid, and the length of the answer divided by boxes is roughly similar to that of the standard meme; 4 points; - c. The reasoning is very fluent, the length of the answer divided by boxes is roughly similar to that of the standard meme, and it is humorous when paired with the image (for humorous techniques, refer to the rhetoric I just mentioned); 7 points; - d. The reasoning is very fluent, the length of the answer divided by boxes is roughly similar to that of the standard meme, or even more interesting than the standard meme; 10 points; - - The total score is 50 points. It is necessary to provide the reasons and scores for each scoring criterion. - Output a line similar to the following at the end: - "Final score: 41 points" - """ - - def __init__(self, pretrained_model): - super().__init__(pretrained_model.config) - self.pretrained_model = pretrained_model - self.processor = None - - def set_processor(self, processor): - self.processor = processor - - @classmethod - def from_pretrained(cls, pretrained_model: PreTrainedModel): - return cls(pretrained_model) - - def _get_message(self, input_params, standard_cot, actual_cot, base_image, standard_meme_image): - message = [ - { - "role": "system", - "content": [{ - "type": "text", - "text": self.system_prompt - }] - }, - { - "role": "user", - "content": [ - { - "type": "text", - "text": "The first image is the base map (no text, only a frame):" - }, - { - "type": "image", - "image": base_image - }, - #{"type": "text", "text": "The second image is a standard meme image (with the correct answer in text):"}, - #{"type": "image", "image": standard_meme_image} - { - "type": "text", - "text": self.eval_prompt.format( - input_params=input_params, standard_cot=standard_cot, actual_cot=actual_cot - ) - } - ] - } - ] - return message + batch_size = len(prompt_and_output) + if input_ids is not None: + device = input_ids.device + else: + device = _as_torch_device(get_current_device()) if torch.cuda.is_available() else torch.device("cpu") + if batch_size == 0: + return {"score": torch.zeros(0, dtype=torch.float32, device=device)} - @torch.no_grad() - def forward( - self, - input_ids, - attention_mask, - pixel_values, - image_grid_thw, - return_dict=None, - output_attentions=None, - output_hidden_states=None, - **kwargs - ): - raw_images = kwargs.get('raw_images') - references = kwargs.get('references') - texts = self.processor.batch_decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) - texts = [t.split("user\n")[1] for t in texts] - - messages = [] - images = [] - for img, ref, text in zip(raw_images, references, texts): - input_params_match = re.search(r'\*\*Input Parameters\*\*:\s*\[(.*?)\]', text, re.DOTALL) - input_params = input_params_match.group(1) if input_params_match else "not found input params" - message = self._get_message( - input_params=input_params, standard_cot=ref, actual_cot=text, base_image=img, standard_meme_image=None - ) - processed_img, _ = process_vision_info(message) - messages.append(message) - images.append(processed_img) - - messages = [ - self.processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True) - for message in messages - ] - if torch.distributed.get_rank() % 8 == 0: - print(f"messages: {messages[0]}, {len(messages)}") - inputs = self.processor( - text=messages, - images=images, - max_length=5000, - padding=True, - truncation=True, - return_tensors="pt", - ) - inputs = inputs.to("cuda") - - gen_ids = self.pretrained_model.generate( - input_ids=inputs.input_ids, - attention_mask=inputs.attention_mask, - pixel_values=inputs.pixel_values, - image_grid_thw=inputs.image_grid_thw, - temperature=0.3, - top_p=0.9, - max_new_tokens=128, - do_sample=False, - ) - outputs_trim = [o[len(i):] for i, o in zip(inputs.input_ids, gen_ids)] - outputs_text = self.processor.batch_decode( - outputs_trim, skip_special_tokens=True, clean_up_tokenization_spaces=False + rendered_images, extracted_texts = self._render_candidates(raw_images, references, prompt_and_output) + groups = group_meme_rollout_indices( + references, batch_size=batch_size, n_samples_per_prompt=self.n_samples_per_prompt ) - if torch.distributed.get_rank() % 8 == 0: - print(f"outputs_text: {outputs_text}, {inputs.input_ids.shape}") - score = [self._extract_numeric_score(o) for o in outputs_text] - return score - - def _extract_numeric_score(self, response: str) -> float: - """Helper function: Extract score from model response and normalize to 0-1 range""" - # Match numeric scores in 50-point scale (e.g., "Final score: 41 points" or "35/50") - numeric_patterns = [ - r"Final score[::]\s*([0-9.]+)\s*points?", r"Score[::]\s*([0-9.]+)\s*points?", r"([0-9.]+)\s*/\s*50" - ] - - for pattern in numeric_patterns: - match = re.search(pattern, response, re.IGNORECASE) - if match: - try: - score = float(match.group(1)) - normalized_score = score * 0.02 # Convert 50-point scale to 0-1 - return max(0.0, min(1.0, normalized_score)) # Clamp boundary values - except ValueError: - continue - - # If no numeric score found, infer score from text description - positive_indicators = ["excellent", "outstanding", "perfect", "very good", "meets requirements"] - neutral_indicators = ["average", "acceptable", "basically meets", "partially meets"] - negative_indicators = ["poor", "does not meet", "bad", "completely mismatched"] - response_lower = response.lower() - for idx, indicator in enumerate(positive_indicators): - if indicator.lower() in response_lower: - return 0.8 + (idx * 0.05) + pair_jobs: List[Dict[str, Any]] = [] + preferences: List[PairwisePreference] = [] + comparison_counts = [0 for _ in range(batch_size)] + + for group in groups: + if len(group) < 2: + continue + for local_a, local_b in sample_group_pairs(len(group), max_pairs=self.max_pairs_per_group): + global_a = group[local_a] + global_b = group[local_b] + if rendered_images[global_a] is None or rendered_images[global_b] is None: + continue - for idx, indicator in enumerate(neutral_indicators): - if indicator.lower() in response_lower: - return 0.5 + (idx * 0.05) + reference = references[global_a] if global_a < len(references) else None + fallback_prompt = prompt_and_output[global_a] if global_a < len(prompt_and_output) else None + user_request = get_user_request(reference if isinstance(reference, dict) else None, fallback_prompt) + pair_jobs.append({ + "index_a": global_a, + "index_b": global_b, + "message": build_pairwise_judge_message( + reward_prompt=self.reward_prompt, + image_a=rendered_images[global_a], + image_b=rendered_images[global_b], + text_a=extracted_texts[global_a], + text_b=extracted_texts[global_b], + user_request=user_request, + ), + }) + + pair_scores = self._generate_pair_scores(pair_jobs, device=device) + for job, (score_a, score_b) in zip(pair_jobs, pair_scores): + index_a = job["index_a"] + index_b = job["index_b"] + preferences.append( + PairwisePreference( + index_a=index_a, + index_b=index_b, + score_a=float(score_a), + score_b=float(score_b), + ) + ) + comparison_counts[index_a] += 1 + comparison_counts[index_b] += 1 - for idx, indicator in enumerate(negative_indicators): - if indicator.lower() in response_lower: - return 0.2 + (idx * 0.05) + pairwise_rewards = aggregate_pairwise_preferences(batch_size, preferences) + for idx, count in enumerate(comparison_counts): + if count == 0: + pairwise_rewards[idx] = 0.5 - # Return 0.0 when unable to infer - return 0.0 + return {"score": torch.tensor(pairwise_rewards, dtype=torch.float32, device=device)} def load_reward_models( @@ -354,111 +394,94 @@ def load_reward_models( use_engine: bool = False, ): if use_engine: - raise NotImplementedError("Engine is not supported for reward model") + raise NotImplementedError("Engine is not supported for the meme reward model") - model_list = [] - tokenizer_list = [] - processor_list = [] - with strategy.init_model_context() as _: - cfg = json.loads(reward_pretrain) - device = get_current_device() - - for key in cfg.keys(): - pretrain_path = cfg[key] - model_config = AutoConfig.from_pretrained( - pretrain_path, - trust_remote_code=True, - ) - base = Qwen2_5_VLForConditionalGeneration.from_pretrained( - pretrain_path, - config=model_config, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - trust_remote_code=True, - ) + cfg = _parse_reward_config(reward_pretrain) + _REWARD_STATE["model_reward_weight"] = float(cfg.get("model_reward_weight", 1.0)) + _REWARD_STATE["format_reward_weight"] = float(cfg.get("format_reward_weight", 0.1)) - processor = AutoProcessor.from_pretrained( - pretrain_path, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28 - ) - processor.tokenizer.padding_side = "left" - - if key == "outcome": - model = Qwen2VLRewardModelMemeOutcome.from_pretrained(base) - elif key == "content": - model = Qwen2VLRewardModelMemeContent.from_pretrained(base) - model.set_processor(processor) - # for some case about meta device - model.to_empty(device=device) - model.eval() - model_list.append(model) - tokenizer_list.append(processor.tokenizer) - processor_list.append(processor) - return model_list, tokenizer_list, processor_list - - -def get_format_reward(response: str) -> float: - """ - Evaluate the format compliance of model response, returns a score between 0-1 - - Args: - response: The model response content to be evaluated (string) - - Returns: - Format compliance score (0-1), where 1 means fully compliant with format requirements, 0 means completely non-compliant - """ - # Initialize score and total check items - score = 0.0 - total_checks = 0 - - # 1. Check if all required sections exist - required_sections = [ - r'\[Comprehensive Description Section\]', r'\[Usage Scenarios Section\]', r'\[Text Analysis Section\]', - r'\[Specific analysis with user input\]', r'Text on the Meme:' - ] - - for section in required_sections: - total_checks += 1 - if re.search(section, response, re.IGNORECASE): - score += 1 - - # 2. Check box format in 'Text on the Meme' section - total_checks += 1 - meme_text_match = re.search(r'Text on the Meme:\s*(.*?)(?=\n\n|$)', response, re.DOTALL | re.IGNORECASE) - if meme_text_match and re.search(r'box\d+:\s*[^\n]+', meme_text_match.group(1)): - score += 1 - - # 3. Check Step 1 and Step 2 in 'Specific analysis' section - total_checks += 2 - specific_analysis_match = re.search( - r'\[Specific analysis with user input\]\s*(.*?)(?=\n\[|Text on the Meme:|$)', response, - re.DOTALL | re.IGNORECASE - ) - if specific_analysis_match: - analysis_content = specific_analysis_match.group(1) - if re.search(r'Step 1:', analysis_content, re.IGNORECASE): - score += 1 - if re.search(r'Step 2:', analysis_content, re.IGNORECASE): - score += 1 + pretrain_path = cfg["path"] + reward_prompt = _load_reward_prompt(cfg.get("reward_prompt_path")) - # Normalize to 0-1 range - return round(score / total_checks, 4) if total_checks > 0 else 0.0 + with strategy.init_model_context() as _: + model_config = AutoConfig.from_pretrained(pretrain_path, trust_remote_code=True) + base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + pretrain_path, + config=model_config, + torch_dtype=torch.bfloat16, + attn_implementation=cfg.get("attn_implementation", "flash_attention_2"), + trust_remote_code=True, + ) + processor = AutoProcessor.from_pretrained( + pretrain_path, + min_pixels=cfg.get("min_pixels", 256 * 28 * 28), + max_pixels=cfg.get("max_pixels", 1280 * 28 * 28), + trust_remote_code=True, + ) + processor.tokenizer.padding_side = "left" + + model = MemePairwiseJudge( + base_model=base_model, + processor=processor, + reward_prompt=reward_prompt, + max_new_tokens=cfg.get("max_new_tokens", 96), + pair_batch_size=cfg.get("pair_batch_size", 4), + n_samples_per_prompt=cfg.get("n_samples_per_prompt", 1), + max_pairs_per_group=cfg.get("max_pairs_per_group", 0), + render_config=MemeRenderConfig( + font_name=cfg.get("font_name", "DejaVuSans.ttf"), + min_font_size=cfg.get("min_font_size", 14), + max_font_size=cfg.get("max_font_size", 72), + line_spacing=cfg.get("line_spacing", 4), + outline_width=cfg.get("outline_width", 2), + margin=cfg.get("margin", 6), + default_padding_ratio=cfg.get("default_padding_ratio", 0.06), + ), + ) + model.eval() + + return [model], [processor.tokenizer], {_PAIRWISE_LABEL_KEY: 0} def reward_fn( - model_reward_list: List[torch.Tensor], # len = n_model , each shape=(B,) + model_reward_list: List[torch.Tensor], labels: Sequence[str], queries: Sequence[str], - refs: Sequence[str], + refs: Sequence[Any], + label_map: Optional[Dict[str, int]] = None, **kwargs, -) -> torch.Tensor: - # outcome reward - # model_reward_list: Shapes [2] - outcome_reward = model_reward_list[0] - dtype, device = outcome_reward.dtype, outcome_reward.device - # rule reward - format_reward = [get_format_reward(q) for q in queries] - format_reward = torch.tensor(format_reward, dtype=dtype, device=device) - if torch.distributed.get_rank() % 8 == 0: - print(f"queries: {queries[0]}") - print(f"model_reward_list: {model_reward_list}, format_reward: {format_reward}") - return format_reward + outcome_reward +) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + del labels, kwargs + + if model_reward_list: + device = model_reward_list[0].device + dtype = model_reward_list[0].dtype + else: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float32 + + batch_size = len(queries) + model_reward = torch.zeros(batch_size, dtype=dtype, device=device) + if model_reward_list: + pairwise_idx = 0 + if label_map: + pairwise_idx = label_map.get(_PAIRWISE_LABEL_KEY, 0) + pairwise_idx = min(pairwise_idx, len(model_reward_list) - 1) + model_reward = torch.as_tensor(model_reward_list[pairwise_idx], dtype=dtype, device=device) + + format_values = [] + for idx, query in enumerate(queries): + reference = refs[idx] if refs is not None and idx < len(refs) else None + expected_boxes = resolve_expected_box_count(reference if isinstance(reference, dict) else None) + format_values.append(compute_meme_format_reward(query, expected_boxes=expected_boxes)) + + format_reward = torch.tensor(format_values, dtype=dtype, device=device) + final_reward = ( + _REWARD_STATE["model_reward_weight"] * model_reward + _REWARD_STATE["format_reward_weight"] * format_reward + ) + metrics = { + "model_reward": model_reward, + "format_reward": format_reward, + "rule_reward": final_reward, + } + return final_reward, metrics diff --git a/examples/godsmeme/test_reward_model_vllm.py b/examples/godsmeme/test_reward_model_vllm.py new file mode 100644 index 00000000..c660a7a0 --- /dev/null +++ b/examples/godsmeme/test_reward_model_vllm.py @@ -0,0 +1,103 @@ +import os +import sys +from pathlib import Path + +import pytest +from PIL import Image +from transformers import AutoProcessor + +ROOT = Path(__file__).resolve().parents[2] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) +if str(Path(__file__).resolve().parent) not in sys.path: + sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from meme_dataset import MemeOnlineRLDataset # noqa: E402 +from meme_utils import extract_box_texts, normalize_detections, render_meme_image # noqa: E402 +from reward_model import ( # noqa: E402 + _load_reward_prompt, + build_pairwise_judge_message, + parse_pair_judge_response, +) + + +def _require_env(name: str) -> str: + value = os.getenv(name) + if not value: + pytest.skip(f"{name} is not set; skipping real vLLM meme reward test") + return value + + +def test_reward_model_vllm_real_data(): + if not os.getenv("RUN_GODSMEME_VLLM_TEST"): + pytest.skip("Set RUN_GODSMEME_VLLM_TEST=1 to run the real vLLM integration test") + if not Image: + pytest.skip("PIL is unavailable") + + vllm = pytest.importorskip("vllm") + if not os.getenv("CUDA_VISIBLE_DEVICES") and not os.path.exists("/dev/nvidia0"): + pytest.skip("This integration test requires a CUDA-visible vLLM environment") + + model_path = _require_env("GODSMEME_REWARD_MODEL_PATH") + annotation_path = _require_env("GODSMEME_ANNOTATION_PATH") + image_root = _require_env("GODSMEME_IMAGE_ROOT") + + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + dataset = MemeOnlineRLDataset( + annotation_path=annotation_path, + root_dir=image_root, + processor=processor, + shuffle=False, + ) + prompt, image_paths, reference, _ = dataset[0] + assert prompt + assert image_paths + assert reference["reference_output"] + + base_image = Image.open(image_paths[0]).convert("RGB") + detections = normalize_detections(reference) + reference_texts = extract_box_texts(reference["reference_output"]) + assert reference_texts, "reference_output must contain meme text boxes" + + bad_texts = ["lorem ipsum dolor sit amet" for _ in range(len(reference_texts))] + good_image = render_meme_image(base_image, reference_texts, detections=detections, reference=reference) + bad_image = render_meme_image(base_image, bad_texts, detections=detections, reference=reference) + + reward_prompt = _load_reward_prompt(os.getenv("GODSMEME_REWARD_PROMPT_PATH")) + judge_message = build_pairwise_judge_message( + reward_prompt=reward_prompt, + image_a=good_image, + image_b=bad_image, + text_a="\\n".join(reference_texts), + text_b="\\n".join(bad_texts), + user_request=reference.get("user_request", ""), + ) + rendered_prompt = processor.apply_chat_template(judge_message, tokenize=False, add_generation_prompt=True) + + llm = vllm.LLM( + model=model_path, + tensor_parallel_size=int(os.getenv("GODSMEME_VLLM_TP", "1")), + trust_remote_code=True, + max_model_len=int(os.getenv("GODSMEME_VLLM_MAX_MODEL_LEN", "4096")), + limit_mm_per_prompt={"image": 2}, + ) + sampling_params = vllm.SamplingParams( + temperature=0.0, + top_p=1.0, + max_tokens=int(os.getenv("GODSMEME_VLLM_MAX_TOKENS", "96")), + ) + outputs = llm.generate( + [{ + "prompt": rendered_prompt, + "multi_modal_data": { + "image": [good_image, bad_image] + } + }], + sampling_params, + ) + generated_text = outputs[0].outputs[0].text + score_a, score_b = parse_pair_judge_response(generated_text) + + assert isinstance(score_a, float) + assert isinstance(score_b, float) + assert score_a >= score_b diff --git a/examples/godsmeme/train_colocate.py b/examples/godsmeme/train_colocate.py index 903c8543..14f7f43f 100644 --- a/examples/godsmeme/train_colocate.py +++ b/examples/godsmeme/train_colocate.py @@ -1,16 +1,8 @@ import argparse -import itertools import math -import re import os import sys -import json from datetime import datetime -from typing import Callable, Dict, List, Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F from lightrft.utils import get_tokenizer_processor_vl, ensure_video_input_available from lightrft.models import ActorVL @@ -81,7 +73,7 @@ def train(args): critic = None strategy.report_memory("before loaded reward models in main entry") - reward_models, reward_tokenizers, reward_processors, label_map = load_reward_models( + reward_models, reward_tokenizers, label_map = load_reward_models( args.reward_pretrain, strategy, use_engine=args.rm_use_engine ) strategy.print(f"label_map: {label_map}") @@ -333,6 +325,7 @@ def train(args): parser.add_argument("--pretrain", type=str, default=None, help="HF model name or path") parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path") parser.add_argument("--remote_rm_url", type=str, default=None, help="remote RM API") + parser.add_argument("--critic_pretrain", type=str, default=None, help="Optional critic path for GAE runs") # Custom dataset (GodsMeme) parser.add_argument( @@ -356,7 +349,6 @@ def train(args): help="sampling probs for datasets", ) parser.add_argument("--prompt_split", type=str, default="train") - # wandb parameters parser.add_argument("--use_wandb", type=str, default=None) parser.add_argument("--wandb_org", type=str, default=None) @@ -396,6 +388,9 @@ def train(args): if args.advantage_estimator in ["rloo", "reinforce_baseline", "group_norm"]: assert args.n_samples_per_prompt > 1, f"{args.advantage_estimator} requires n_samples_per_prompt > 1" + assert args.micro_rollout_batch_size % args.n_samples_per_prompt == 0, ( + "meme pairwise reward requires micro_rollout_batch_size to be divisible by n_samples_per_prompt" + ) if args.use_kl_loss: if args.kl_estimator not in ["k2", "k3"]: diff --git a/lightrft/trainer/fast_exp_maker.py b/lightrft/trainer/fast_exp_maker.py index fe12d534..1e352cbe 100644 --- a/lightrft/trainer/fast_exp_maker.py +++ b/lightrft/trainer/fast_exp_maker.py @@ -794,6 +794,7 @@ def _compute_standard_torch_rewards( prompt_and_output=output.prompt_and_output, raw_images=output.raw_images, img_num=output.image_num, + references=output.references, **output.inputs_extra_kwargs, ) From 2265c8f662ea46d22accddc8d61bd6b16a1aacfe Mon Sep 17 00:00:00 2001 From: PaParaZz1 Date: Fri, 17 Apr 2026 17:41:15 +0800 Subject: [PATCH 3/5] polish(nyz): add launch script --- examples/godsmeme/run_meme_grpo.sh | 242 ++++++++++++++++++++++++++++ examples/godsmeme/train_colocate.py | 2 + 2 files changed, 244 insertions(+) create mode 100755 examples/godsmeme/run_meme_grpo.sh diff --git a/examples/godsmeme/run_meme_grpo.sh b/examples/godsmeme/run_meme_grpo.sh new file mode 100755 index 00000000..3daf7067 --- /dev/null +++ b/examples/godsmeme/run_meme_grpo.sh @@ -0,0 +1,242 @@ +#!/usr/bin/env bash +# +# LightRFT GRPO training script for the GodsMeme example. +# +# Compared with examples/gsm8k_geo3k/run_grpo_geo3k_qwen2.5_vl_7b.sh, +# GodsMeme has a few task-specific constraints: +# 1. The dataset is loaded from --annotation_path + --root_dir, not --prompt_data. +# 2. The reward is not pure rule-based; it uses a local pairwise meme judge model. +# 3. The current meme reward model does not support --rm_use_engine. +# 4. Each GRPO group must stay inside one micro-rollout batch: +# micro_rollout_batch_size % n_samples_per_prompt == 0 +# 5. Pairwise judge cost grows quadratically with n_samples_per_prompt unless you cap +# max_pairs_per_group. +# + +set -euo pipefail + +################################################################################ +# Part 1: User Configuration # +################################################################################ + +# --- Model and Dataset Paths --- +# The policy model can be either a local path or a Hugging Face model id. +POLICY_MODEL_PATH="${POLICY_MODEL_PATH:-/path/to/your/policy-model}" + +# The meme judge model is loaded locally by examples/godsmeme/reward_model.py. +# In the simplest setup, this can be the same model family as the policy model. +REWARD_MODEL_PATH="${REWARD_MODEL_PATH:-/path/to/your/reward-model}" + +# GodsMeme expects pre-built RL rows in JSON or JSONL format. +ANNOTATION_PATH="${ANNOTATION_PATH:-/path/to/your/train_data.jsonl}" +IMAGE_ROOT="${IMAGE_ROOT:-/path/to/your/image_root}" + +# --- Experiment and Logging --- +EXPERIMENT_NAME="${EXPERIMENT_NAME:-lightrft-godsmeme-grpo-training}" +RESULT_ROOT="${RESULT_ROOT:-results}" +LOG_ROOT="${LOG_ROOT:-rft_logs}" + +# Set WANDB_API_KEY="" to disable W&B cleanly. +export WANDB_API_KEY="${WANDB_API_KEY:-}" +export WANDB_PROJECT="${WANDB_PROJECT:-LightRFT-GodsMeme-Experiments}" +export WANDB_MODE="${WANDB_MODE:-offline}" + + +################################################################################ +# Part 2: GodsMeme Reward Configuration # +################################################################################ + +# Reward prompt template used by the pairwise judge. +REWARD_PROMPT_PATH="${REWARD_PROMPT_PATH:-examples/godsmeme/prompts/reward_compare.txt}" + +# Reward cost control. 0 means use all pairs inside each rollout group. +MAX_PAIRS_PER_GROUP="${MAX_PAIRS_PER_GROUP:-0}" +PAIR_BATCH_SIZE="${PAIR_BATCH_SIZE:-4}" +REWARD_MAX_NEW_TOKENS="${REWARD_MAX_NEW_TOKENS:-96}" + +# Reward composition: +# final_reward = model_reward_weight * pairwise_reward +# + format_reward_weight * format_reward +MODEL_REWARD_WEIGHT="${MODEL_REWARD_WEIGHT:-1.0}" +FORMAT_REWARD_WEIGHT="${FORMAT_REWARD_WEIGHT:-0.1}" + + +################################################################################ +# Part 3: Training Hyperparameters # +################################################################################ + +# --- GRPO Settings --- +N_SAMPLES="${N_SAMPLES:-8}" +EPISODE="${EPISODE:-20}" +WARMUP="${WARMUP:-0.03}" + +# --- Batch Size Settings --- +RBS="${RBS:-128}" +TBS="${TBS:-128}" +MICRO_ROLLOUT_BATCH_SIZE="${MICRO_ROLLOUT_BATCH_SIZE:-8}" +MICRO_TRAIN_BATCH_SIZE="${MICRO_TRAIN_BATCH_SIZE:-4}" + +# --- Learning and Generation Settings --- +KL="${KL:-0.01}" +LR="${LR:-1e-6}" +PROMPT_MAX_LEN="${PROMPT_MAX_LEN:-2048}" +GENERATE_MAX_LEN="${GENERATE_MAX_LEN:-1024}" +TEMPERATURE="${TEMPERATURE:-1.0}" +TOP_P="${TOP_P:-1.0}" + +# The actor only sees one source image per prompt in GodsMeme. +LIMIT_MM_IMAGE_PER_PROMPT="${LIMIT_MM_IMAGE_PER_PROMPT:-1}" + + +################################################################################ +# Part 4: Distributed Training Setup # +################################################################################ + +export NNODES="${NNODES:-1}" +export GPUS_PER_NODE="${GPUS_PER_NODE:-8}" +export NODE_RANK="${NODE_RANK:-0}" +export MASTER_ADDR="${MASTER_ADDR:-localhost}" +export MASTER_PORT="${MASTER_PORT:-20091}" + +# The rollout engine is for the actor only. The meme reward model is loaded +# directly and kept outside rm_use_engine. +ENGINE_TYPE="${ENGINE_TYPE:-sglang}" +ENGINE_TP="${ENGINE_TP:-2}" +ENGINE_MEM_UTIL="${ENGINE_MEM_UTIL:-0.4}" + + +################################################################################ +# Part 5: Validation and Launch # +################################################################################ + +if (( N_SAMPLES <= 1 )); then + echo "[GodsMeme] N_SAMPLES must be > 1 when using group_norm / GRPO." >&2 + exit 1 +fi + +if (( MICRO_ROLLOUT_BATCH_SIZE % N_SAMPLES != 0 )); then + echo "[GodsMeme] MICRO_ROLLOUT_BATCH_SIZE must be divisible by N_SAMPLES." >&2 + echo "[GodsMeme] This keeps every rollout group inside one micro-batch for pairwise reward computation." >&2 + exit 1 +fi + +if (( ENGINE_TP <= 0 )); then + echo "[GodsMeme] ENGINE_TP must be a positive integer." >&2 + exit 1 +fi + +if (( GPUS_PER_NODE % ENGINE_TP != 0 )); then + echo "[GodsMeme] GPUS_PER_NODE must be divisible by ENGINE_TP." >&2 + exit 1 +fi + +PAIR_TAG="allpairs" +if (( MAX_PAIRS_PER_GROUP > 0 )); then + PAIR_TAG="pairs${MAX_PAIRS_PER_GROUP}" +fi + +DEFAULT_REWARD_PRETRAIN="$({ + REWARD_MODEL_PATH="$REWARD_MODEL_PATH" \ + REWARD_PROMPT_PATH="$REWARD_PROMPT_PATH" \ + PAIR_BATCH_SIZE="$PAIR_BATCH_SIZE" \ + MAX_PAIRS_PER_GROUP="$MAX_PAIRS_PER_GROUP" \ + MODEL_REWARD_WEIGHT="$MODEL_REWARD_WEIGHT" \ + FORMAT_REWARD_WEIGHT="$FORMAT_REWARD_WEIGHT" \ + REWARD_MAX_NEW_TOKENS="$REWARD_MAX_NEW_TOKENS" \ + N_SAMPLES="$N_SAMPLES" \ + python - <<'PY' +import json +import os + +cfg = { + "pairwise": { + "path": os.environ["REWARD_MODEL_PATH"], + "reward_prompt_path": os.environ["REWARD_PROMPT_PATH"], + "pair_batch_size": int(os.environ["PAIR_BATCH_SIZE"]), + "max_pairs_per_group": int(os.environ["MAX_PAIRS_PER_GROUP"]), + "model_reward_weight": float(os.environ["MODEL_REWARD_WEIGHT"]), + "format_reward_weight": float(os.environ["FORMAT_REWARD_WEIGHT"]), + "max_new_tokens": int(os.environ["REWARD_MAX_NEW_TOKENS"]), + "n_samples_per_prompt": int(os.environ["N_SAMPLES"]), + } +} +print(json.dumps(cfg, ensure_ascii=True)) +PY +})" +REWARD_PRETRAIN="${REWARD_PRETRAIN:-$DEFAULT_REWARD_PRETRAIN}" + +current_time=$(date +"%Y%m%d_%H%M%S") +SAVE_MODEL_NAME="${EXPERIMENT_NAME}-ep${EPISODE}-kl${KL}-lr${LR}-${PAIR_TAG}-${current_time}" +WANDB_RUN_NAME="${EXPERIMENT_NAME}-${PAIR_TAG}-${current_time}" + +mkdir -p "${RESULT_ROOT}/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" +mkdir -p "${LOG_ROOT}/${EXPERIMENT_NAME}" + +export TORCH_NCCL_AVOID_RECORD_STREAMS=1 +export NCCL_DEBUG="${NCCL_DEBUG:-WARN}" +export IGNORE_EOS="${IGNORE_EOS:-0}" + +set -x + +torchrun \ + --nnodes "$NNODES" \ + --nproc-per-node "$GPUS_PER_NODE" \ + --node_rank "$NODE_RANK" \ + --master-port "$MASTER_PORT" \ + --master-addr "$MASTER_ADDR" \ + examples/godsmeme/train_colocate.py \ + --pretrain "${POLICY_MODEL_PATH}" \ + --reward_pretrain "${REWARD_PRETRAIN}" \ + --annotation_path "${ANNOTATION_PATH}" \ + --root_dir "${IMAGE_ROOT}" \ + --save_path "${RESULT_ROOT}/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" \ + --ckpt_path "${RESULT_ROOT}/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" \ + --advantage_estimator "group_norm" \ + --use_kl_loss \ + --kl_estimator "k3" \ + --fsdp \ + --bf16 \ + --flash_attn \ + --gradient_checkpointing \ + --save_hf_ckpt \ + --micro_train_batch_size "${MICRO_TRAIN_BATCH_SIZE}" \ + --train_batch_size "${TBS}" \ + --micro_rollout_batch_size "${MICRO_ROLLOUT_BATCH_SIZE}" \ + --rollout_batch_size "${RBS}" \ + --max_epochs 1 \ + --num_episodes "${EPISODE}" \ + --lr_warmup_ratio "${WARMUP}" \ + --n_samples_per_prompt "${N_SAMPLES}" \ + --prompt_max_len "${PROMPT_MAX_LEN}" \ + --generate_max_len "${GENERATE_MAX_LEN}" \ + --actor_learning_rate "${LR}" \ + --temperature "${TEMPERATURE}" \ + --top_p "${TOP_P}" \ + --init_kl_coef "${KL}" \ + --l2 1.0e-2 \ + --freeze_prefix \ + --adam_offload \ + --engine_type "${ENGINE_TYPE}" \ + --engine_mem_util "${ENGINE_MEM_UTIL}" \ + --engine_tp_size "${ENGINE_TP}" \ + --enable_engine_sleep \ + --limit_mm_image_per_prompt "${LIMIT_MM_IMAGE_PER_PROMPT}" \ + --use_wandb "${WANDB_API_KEY}" \ + --wandb_project "${WANDB_PROJECT}" \ + --wandb_run_name "${WANDB_RUN_NAME}" \ + 2>&1 | tee "${LOG_ROOT}/${EXPERIMENT_NAME}/node${NODE_RANK}_${current_time}.log" + + +################################################################################ +# Usage Notes # +# # +# 1. GodsMeme data must already be prepared as conversation-style RL rows. # +# 2. Do not add --rm_use_engine here: the current meme reward model raises # +# NotImplementedError when rm_use_engine is enabled. # +# 3. LIMIT_MM_IMAGE_PER_PROMPT defaults to 1 because each policy prompt only # +# contains one source image. The pairwise judge images are handled inside # +# examples/godsmeme/reward_model.py. # +# 4. If reward evaluation is too slow, first try lowering N_SAMPLES or set # +# MAX_PAIRS_PER_GROUP to a small positive integer. # +# # +################################################################################ diff --git a/examples/godsmeme/train_colocate.py b/examples/godsmeme/train_colocate.py index 14f7f43f..32fc6324 100644 --- a/examples/godsmeme/train_colocate.py +++ b/examples/godsmeme/train_colocate.py @@ -4,6 +4,8 @@ import sys from datetime import datetime +import torch + from lightrft.utils import get_tokenizer_processor_vl, ensure_video_input_available from lightrft.models import ActorVL from lightrft.strategy import get_strategy From e7be0598ff7802db51abc5dab10c69882134d27f Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Sun, 19 Apr 2026 22:54:54 +0800 Subject: [PATCH 4/5] feature(nyz): add new runnable pipeline --- examples/godsmeme/README.md | 193 ++++---- examples/godsmeme/meme_dataset.py | 13 +- examples/godsmeme/prompts/reward_compare.txt | 25 +- examples/godsmeme/requirements.txt | 2 + examples/godsmeme/reward_model.py | 457 ++++++++++++++----- examples/godsmeme/run_meme_grpo.sh | 43 +- examples/godsmeme/test_meme_dataset.py | 2 +- examples/godsmeme/test_reward_model_vllm.py | 103 ----- examples/godsmeme/train_colocate.py | 96 ++++ 9 files changed, 563 insertions(+), 371 deletions(-) create mode 100644 examples/godsmeme/requirements.txt delete mode 100644 examples/godsmeme/test_reward_model_vllm.py diff --git a/examples/godsmeme/README.md b/examples/godsmeme/README.md index da2fdf7e..129bc992 100644 --- a/examples/godsmeme/README.md +++ b/examples/godsmeme/README.md @@ -1,63 +1,65 @@ -# GodsMeme GRPO Pipeline +# GodsMeme on LightRFT -This example adapts LightRFT's vision-language GRPO stack to meme generation. -The policy consumes the already-formatted GodsMeme RL prompt from the dataset, generates the full reasoning-style response, and the reward pipeline renders the final meme text back onto the image before running a pairwise meme judge. +This example adapts LightRFT's vision-language GRPO pipeline to meme generation. +The policy sees one source image plus a preformatted GodsMeme prompt, generates a full reasoning-style answer, and is rewarded by a pairwise meme judge that compares rendered meme images within each GRPO group. -## What is included +## What the training entrypoint actually is -- `train_colocate.py`: main GRPO training entry for meme RL. -- `meme_dataset.py`: dataset loader for the prepared GodsMeme RL JSON/JSONL rows. -- `meme_utils.py`: GodsMeme-specific parsing, rendering, and pair aggregation helpers. -- `reward_model.py`: pairwise meme judge wrapper and reward aggregation logic. -- `prompts/generate_meme.txt`: reference policy prompt format. -- `prompts/reward_compare.txt`: pairwise comparison prompt for the meme reward judge. -- `run_meme_grpo.sh`: example launch script. -- `test_reward_model_vllm.py`: optional real-model integration smoke test using vLLM. +There are two entry layers: -## Reward flow +- `examples/godsmeme/run_meme_grpo.sh`: user-facing launcher. Set paths and training knobs here, then run it with `bash`. +- `examples/godsmeme/train_colocate.py`: Python training entry used by `torchrun`. It builds the actor, dataset, reward model, inference engine, and PPO trainer. -For each rollout group produced from the same prompt: +If you only want to start training, edit the shell script or override its environment variables. You usually do not need to call `train_colocate.py` directly. -1. Parse the policy completion and extract the `Text on the Meme` section. -2. Render the generated text onto the base image using the dataset's detection boxes when available. -3. Build pairwise comparisons inside the rollout group. -4. Ask a VLM judge which rendered meme is better. -5. Convert pairwise wins/scores into one scalar reward per sample. -6. Add a small format reward that checks the required GodsMeme response structure. +## End-to-end training flow -The reward weights are read from `--reward_pretrain` JSON config: +1. `MemeOnlineRLDataset` loads prebuilt RL rows from `annotation_path` and resolves the source image from `root_dir`. +2. Each row is converted into a multimodal chat prompt with one image placeholder and one text instruction. +3. GRPO samples `n_samples_per_prompt` policy completions for the same prompt. +4. The reward pipeline extracts the `Text on the Meme` section from each completion. +5. The extracted text is rendered back onto the original image by using detection boxes when available, or a simple fallback layout otherwise. +6. A local VLM judge compares candidate meme images pairwise inside the same rollout group. +7. Pairwise scores are aggregated into one scalar reward per sample. +8. A small format reward is added so the policy keeps the expected GodsMeme response structure. +9. LightRFT runs PPO/GRPO updates and saves checkpoints to the configured output directory. -```text -final_reward = model_reward_weight * pairwise_reward - + format_reward_weight * format_reward -``` - -Defaults: - -- `model_reward_weight = 1.0` -- `format_reward_weight = 0.1` - -## Important rollout constraint - -The pairwise reward is computed inside each rollout micro-batch, so keep every GRPO group fully contained in one micro-batch: +## Directory map ```text -micro_rollout_batch_size % n_samples_per_prompt == 0 +examples/godsmeme/ +├── README.md # This guide +├── run_meme_grpo.sh # User-facing launch script +├── train_colocate.py # Main GRPO training entry +├── meme_dataset.py # Dataset loader for GodsMeme RL rows +├── reward_model.py # Pairwise reward judge and reward aggregation +├── meme_utils.py # Text parsing, rendering, and pair helpers +├── prompts/ +│ ├── generate_meme.txt # Reference policy prompt format +│ └── reward_compare.txt # Prompt template for the pairwise judge +├── test_meme_dataset.py # Dataset test scaffold +└── test_reward_model_vllm.py # Optional reward-model smoke test ``` -For standard GRPO runs, use `--advantage_estimator group_norm` and set `--n_samples_per_prompt > 1`. +## Dataset format expected by this example -## Dataset expectation +This example does not build the RL dataset for you. It expects a JSON or JSONL file where each row already looks like a conversation-style GodsMeme training sample. -The RL dataset is expected to already be formatted as conversation-style GodsMeme rows: +Minimal example: ```json { "id": "sample-001", "image": "images/cat.jpg", "conversations": [ - {"from": "human", "value": "...GodsMeme prompt... "}, - {"from": "assistant", "value": "...reference meme response..."} + { + "from": "human", + "value": "...GodsMeme prompt... " + }, + { + "from": "assistant", + "value": "...reference reasoning and meme text..." + } ], "text_loc_info": { "loc": [[40, 60, 420, 170], [40, 330, 420, 430]] @@ -65,60 +67,97 @@ The RL dataset is expected to already be formatted as conversation-style GodsMem } ``` -Supported box metadata keys include `detections`, `text_loc_info`, `loc`, `bbox_scale`, and `bbox_normalized`. -If no boxes are available, the renderer falls back to a simple top/bottom banner layout. - -## Reward model configuration +Useful notes: -`--reward_pretrain` accepts either: +- `image`, `image_path`, and `img` are accepted as image keys. +- The human message must contain ``. +- The assistant message is treated as the reference output and is also used to infer the expected number of text boxes when box metadata is missing. +- Supported box metadata includes `detections`, `text_loc_info`, `loc`, `bbox_scale`, `bbox_normalized`, and `expected_box_count`. +- If no boxes are available, the renderer falls back to a simple top/bottom style layout. -1. A plain Hugging Face model path for the pairwise judge. -2. A JSON object if you want to override judge settings. +## Reward design -Plain path example: +The reward combines two parts: -```bash ---reward_pretrain /path/to/Qwen2.5-VL-7B-Instruct +```text +final_reward = model_reward_weight * pairwise_reward + + format_reward_weight * format_reward ``` -JSON example: +- `pairwise_reward`: computed by comparing rendered candidate memes from the same rollout group. +- `format_reward`: checks whether the completion keeps the expected GodsMeme answer structure and box count. -```bash ---reward_pretrain '{ - "pairwise": { - "path": "/path/to/Qwen2.5-VL-7B-Instruct", - "max_new_tokens": 96, - "pair_batch_size": 4, - "max_pairs_per_group": 0, - "model_reward_weight": 1.0, - "format_reward_weight": 0.1, - "reward_prompt_path": "examples/godsmeme/prompts/reward_compare.txt" - } -}' -``` +Default weights: + +- `model_reward_weight = 1.0` +- `format_reward_weight = 0.1` + +The default launcher builds `--reward_pretrain` as a JSON blob that points to the local judge model and the comparison prompt template. -## Running +The current implementation follows the `HUMOR-RM-Keye-VL` inference pattern directly in `reward_model.py`, loading the Keye-based reward model plus `classification_head.pt` without adding a `llamafactory` runtime dependency. The pairwise judge prompt is kept to the same simple question used in the model README: `Which meme is funnier?` -Edit the environment variables in `run_meme_grpo.sh`, then launch: +## Quick start + +Edit the paths in `examples/godsmeme/run_meme_grpo.sh`, or override them inline: ```bash +POLICY_MODEL_PATH=/path/to/policy-model \ +REWARD_MODEL_PATH=/path/to/reward-model \ +ANNOTATION_PATH=/path/to/train_data.jsonl \ +IMAGE_ROOT=/path/to/image_root \ bash examples/godsmeme/run_meme_grpo.sh ``` -## Validation +Default outputs: -Lightweight unit tests: +- checkpoints: `results///` +- logs: `rft_logs//` -```bash -pytest examples/godsmeme/test_meme_dataset.py examples/godsmeme/test_reward_model.py -``` +## Important constraints before you launch + +- `N_SAMPLES` must be greater than `1` for GRPO with `group_norm`. +- `MICRO_ROLLOUT_BATCH_SIZE % N_SAMPLES == 0` must hold, otherwise one prompt group can be split across micro-batches and pairwise reward becomes invalid. +- Pairwise judge cost grows roughly quadratically with `N_SAMPLES` unless you cap `MAX_PAIRS_PER_GROUP`. +- The current meme reward model does not support `--rm_use_engine`; the judge is loaded directly by `reward_model.py`. +- Each policy prompt uses one source image, so `LIMIT_MM_IMAGE_PER_PROMPT` is set to `1`. +- This example reads `--annotation_path` and `--root_dir`; it does not use the generic `--prompt_data` path for training data. -Optional real-model vLLM smoke test: +## Most useful knobs in `run_meme_grpo.sh` + +- `POLICY_MODEL_PATH`: actor checkpoint or HF model id. +- `REWARD_MODEL_PATH`: Keye-based reward-model checkpoint path. It should contain `classification_head.pt`. +- `REWARD_MAX_LENGTH`: max sequence length used when scoring rendered meme pairs. +- `ANNOTATION_PATH` / `IMAGE_ROOT`: GodsMeme RL data location. +- `N_SAMPLES`: number of rollouts per prompt. Higher is better for ranking quality but much slower. +- `MAX_PAIRS_PER_GROUP`: limit pairwise comparisons to reduce reward cost. +- `PAIR_BATCH_SIZE`: judge batch size for pairwise evaluation. +- `MICRO_ROLLOUT_BATCH_SIZE`: must be divisible by `N_SAMPLES`. +- `MICRO_TRAIN_BATCH_SIZE`: lower this first if actor training OOMs. +- `ENGINE_TYPE`, `ENGINE_TP`, `ENGINE_MEM_UTIL`: rollout engine settings for the actor. +- `PROMPT_MAX_LEN`, `GENERATE_MAX_LEN`: trim these if prompts or responses are too long for your setup. + +## Practical tuning advice + +- Start with smaller `N_SAMPLES` such as `4` before scaling to `8`. +- If reward evaluation is the bottleneck, reduce `N_SAMPLES` or set `MAX_PAIRS_PER_GROUP` to a small positive integer. +- If the actor OOMs, first lower `MICRO_TRAIN_BATCH_SIZE`, then `MICRO_ROLLOUT_BATCH_SIZE`, then `ENGINE_MEM_UTIL`. +- If generations are too verbose, reduce `GENERATE_MAX_LEN` and revisit the prompt format in `examples/godsmeme/prompts/generate_meme.txt`. +- If rendered text looks misplaced, verify your box annotations before changing the reward model. + +## Optional validation + +There are lightweight unit tests for the new reward-model invocation path: ```bash -RUN_GODSMEME_VLLM_TEST=1 \ -GODSMEME_REWARD_MODEL_PATH=/path/to/Qwen2.5-VL-7B-Instruct \ -GODSMEME_ANNOTATION_PATH=/path/to/train_data.jsonl \ -GODSMEME_IMAGE_ROOT=/path/to/image_root \ -pytest examples/godsmeme/test_reward_model_vllm.py -k real_data +pytest examples/godsmeme/test_reward_model_vllm.py ``` + +These tests cover the classifier-head loading helpers and the direct pair-scoring flow without requiring a real reward checkpoint. + +## Current limitations + +- No data preprocessing script is included in this folder; the RL JSON/JSONL must already be prepared. +- The reward judge is model-based and local, so training cost is much higher than rule-only GRPO examples. +- The reward loader expects the Keye multimodal interface used by `HUMOR-RM-Keye-VL`. +- The fallback renderer is intentionally simple and is mainly a reward-time approximation, not a production meme compositor. +- The policy is optimized for the GodsMeme response template; changing the prompt schema usually requires updating both parsing and reward logic. diff --git a/examples/godsmeme/meme_dataset.py b/examples/godsmeme/meme_dataset.py index 28d8f41f..3c63cd00 100644 --- a/examples/godsmeme/meme_dataset.py +++ b/examples/godsmeme/meme_dataset.py @@ -61,22 +61,10 @@ def _resolve_image_path(self, data: Dict[str, Any]) -> str: raise FileNotFoundError(f"Image {image_path} does not exist") return image_path - def _extract_user_request_from_prompt(self, prompt_text: str) -> str: - patterns = [ - r"\*\*User Input Parameters\*\*:\s*(.*?)(?=\n\s*\*\*Text on the Meme\*\*:|\Z)", - r"Input Parameters\s*:\s*(\[[^\n]+\])", - ] - for pattern in patterns: - match = re.search(pattern, prompt_text, re.IGNORECASE | re.DOTALL) - if match: - return match.group(1).strip() - return "" - def _build_reference(self, data: Dict[str, Any], prompt_text: str, assistant_output: str) -> Dict[str, Any]: reference: Dict[str, Any] = { "id": data.get("id"), "group_id": str(data.get("group_id") or data.get("sample_id") or data.get("id") or ""), - "user_request": self._extract_user_request_from_prompt(prompt_text), "reference_output": assistant_output, } @@ -117,6 +105,7 @@ def _process_item(self, raw_item: Union[Dict[str, Any], str]) -> Tuple[str, List }] prompt = self.processor.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) reference = self._build_reference(data, human_input, assistant_output) + reference = reference["reference_output"] label = data.get("reward_rule_label", self.DEFAULT_LABEL) return prompt, [image_path], reference, label diff --git a/examples/godsmeme/prompts/reward_compare.txt b/examples/godsmeme/prompts/reward_compare.txt index 0cc2809d..13b71982 100644 --- a/examples/godsmeme/prompts/reward_compare.txt +++ b/examples/godsmeme/prompts/reward_compare.txt @@ -1,24 +1 @@ -You are a strict meme reward model. -You will receive two final meme images created from the same base image and the same user request. -Your job is to compare which meme is better. - -Evaluation criteria: -1. Better alignment with the user request and emotional tone. -2. Stronger humor and punchline quality. -3. Better use of the meme template and box layout. -4. Clearer, more concise meme wording. - -User request: -{user_request} - -Candidate 1 extracted text: -{text_a} - -Candidate 2 extracted text: -{text_b} - -Return exactly the following four lines: -Image 1 score: -Image 2 score: -Winner: <1 or 2 or tie> -Reason: +Which meme is funnier? diff --git a/examples/godsmeme/requirements.txt b/examples/godsmeme/requirements.txt new file mode 100644 index 00000000..6d340b7c --- /dev/null +++ b/examples/godsmeme/requirements.txt @@ -0,0 +1,2 @@ +transformers==4.56.2 +keye-vl-utils[decord]==1.0.0 diff --git a/examples/godsmeme/reward_model.py b/examples/godsmeme/reward_model.py index 23f3f9a3..2cbee51d 100644 --- a/examples/godsmeme/reward_model.py +++ b/examples/godsmeme/reward_model.py @@ -1,12 +1,13 @@ import json import os -import re from typing import Any, Dict, List, Optional, Sequence, Tuple import torch import torch.nn as nn -from qwen_vl_utils import process_vision_info -from transformers import AutoConfig, AutoProcessor, Qwen2_5_VLForConditionalGeneration +from transformers import AutoModel, AutoProcessor +from transformers.utils import cached_file + +from keye_vl_utils import process_vision_info from lightrft.utils import get_current_device @@ -17,7 +18,6 @@ compute_meme_format_reward, extract_box_texts, get_first_image, - get_user_request, load_text_file, normalize_detections, render_meme_image, @@ -30,6 +30,34 @@ "model_reward_weight": 1.0, "format_reward_weight": 0.1, } +_DTYPE_MAP = { + "auto": None, + "float16": torch.float16, + "fp16": torch.float16, + "bfloat16": torch.bfloat16, + "bf16": torch.bfloat16, + "float32": torch.float32, + "fp32": torch.float32, +} + + +class _FSDPSafeEmbedding(nn.Module): + """Keep embedding weights in the parent FSDP unit. + + Keye's vision tower reads ``position_embedding.weight`` directly in its + interpolation path. When FSDP2 individually wraps those embedding modules, + the weight can become a DTensor while sibling activations stay regular + tensors, which triggers mixed Tensor/DTensor errors on addition. + """ + def __init__(self, embedding: nn.Embedding): + super().__init__() + self.weight = embedding.weight + self.num_embeddings = embedding.num_embeddings + self.embedding_dim = embedding.embedding_dim + self.padding_idx = embedding.padding_idx + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + return nn.functional.embedding(input_ids, self.weight, padding_idx=self.padding_idx) def _as_torch_device(device_like: Any) -> torch.device: @@ -40,22 +68,18 @@ def _as_torch_device(device_like: Any) -> torch.device: return torch.device(device_like) + def _default_reward_prompt_path() -> str: return os.path.join(os.path.dirname(os.path.abspath(__file__)), "prompts", "reward_compare.txt") + def _load_reward_prompt(path: Optional[str] = None) -> str: prompt_path = path or _default_reward_prompt_path() if os.path.exists(prompt_path): return load_text_file(prompt_path) - return ( - "You are a strict meme reward model.\n" - "Compare the two meme images and return exactly:\n" - "Image 1 score: <0-10>\n" - "Image 2 score: <0-10>\n" - "Winner: <1 or 2 or tie>\n" - "Reason: " - ) + return "Which meme is funnier?" + def _parse_reward_config(raw_reward_pretrain: str) -> Dict[str, Any]: @@ -88,34 +112,245 @@ def _parse_reward_config(raw_reward_pretrain: str) -> Dict[str, Any]: return pairwise_cfg + +def _resolve_torch_dtype(dtype_like: Any) -> Optional[torch.dtype]: + if dtype_like is None or isinstance(dtype_like, torch.dtype): + return dtype_like + key = str(dtype_like).strip().lower() + if key not in _DTYPE_MAP: + raise ValueError(f"Unsupported torch dtype: {dtype_like}") + return _DTYPE_MAP[key] + + + +def _resolve_model_file(path_or_repo: str, filename: str) -> str: + if os.path.isdir(path_or_repo): + local_path = os.path.join(path_or_repo, filename) + if os.path.exists(local_path): + return local_path + try: + return cached_file(path_or_repo, filename) + except Exception as exc: # pragma: no cover - exercised only when weights are resolved remotely. + raise FileNotFoundError(f"Could not find {filename} under {path_or_repo}") from exc + + +def _replace_module_if_embedding(parent: nn.Module, attr_name: str) -> bool: + module = getattr(parent, attr_name, None) + if not isinstance(module, nn.Embedding): + return False + + setattr(parent, attr_name, _FSDPSafeEmbedding(module)) + return True + + +def _find_keye_visual_module(model: nn.Module) -> Optional[nn.Module]: + for attr_name in ("visual", "vision_tower", "vision_model"): + module = getattr(model, attr_name, None) + if isinstance(module, nn.Module): + return module + + inner_model = getattr(model, "model", None) + if isinstance(inner_model, nn.Module): + return _find_keye_visual_module(inner_model) + + return None + + +def _patch_keye_fsdp_compat(model: nn.Module) -> None: + """Patch Keye vision modules that are fragile under FSDP2 DTensors.""" + visual_model = _find_keye_visual_module(model) + if visual_model is None: + return + + replaced_names: List[str] = [] + for module_name, module in visual_model.named_modules(): + if _replace_module_if_embedding(module, "position_embedding"): + replaced_names.append(f"{module_name}.position_embedding" if module_name else "position_embedding") + if _replace_module_if_embedding(module, "packing_position_embedding"): + replaced_names.append( + f"{module_name}.packing_position_embedding" if module_name else "packing_position_embedding" + ) + if hasattr(module, "_attn_implementation"): + module._attn_implementation = "eager" + + if hasattr(getattr(visual_model, "config", None), "_attn_implementation"): + visual_model.config._attn_implementation = "eager" + + vision_config = getattr(getattr(model, "config", None), "vision_config", None) + if hasattr(vision_config, "_attn_implementation"): + vision_config._attn_implementation = "eager" + + if replaced_names: + print("[MemePairwiseJudge] FSDP2-safe Keye vision patch:", ", ".join(replaced_names)) + print("[MemePairwiseJudge] Set Keye vision attention to 'eager' for FSDP2 compat") + + + +def _unwrap_head_state_dict(state_dict: Dict[str, Any]) -> Dict[str, torch.Tensor]: + nested_keys = ("state_dict", "model", "module", "classification_head", "classifier") + current = state_dict + while isinstance(current, dict): + tensor_values = [value for value in current.values() if torch.is_tensor(value)] + if tensor_values: + break + next_state = None + for key in nested_keys: + candidate = current.get(key) + if isinstance(candidate, dict): + next_state = candidate + break + if next_state is None: + break + current = next_state + + if not isinstance(current, dict): + raise ValueError("classification_head.pt does not contain a supported state dict") + + tensor_state = {str(key): value for key, value in current.items() if torch.is_tensor(value)} + if not tensor_state: + raise ValueError("classification_head.pt does not contain tensor parameters") + return tensor_state + + + +def _strip_common_prefix(state_dict: Dict[str, torch.Tensor], prefix: str) -> Dict[str, torch.Tensor]: + if not all(key.startswith(prefix) for key in state_dict): + return state_dict + return {key[len(prefix):]: value for key, value in state_dict.items()} + + + +def _normalize_head_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + normalized = dict(state_dict) + changed = True + while changed: + changed = False + for prefix in ("module.", "model.", "classification_head.", "classifier.", "head."): + stripped = _strip_common_prefix(normalized, prefix) + if stripped is not normalized: + normalized = stripped + changed = True + break + return normalized + + +class _TwoLayerTanhHead(nn.Module): + def __init__(self, in_features: int, hidden_features: int, out_features: int): + super().__init__() + self.dense = nn.Linear(in_features, hidden_features) + self.out_proj = nn.Linear(hidden_features, out_features) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.out_proj(self.activation(self.dense(hidden_states))) + + + +def _build_classification_head(state_dict: Dict[str, torch.Tensor]) -> nn.Module: + state_dict = _normalize_head_state_dict(_unwrap_head_state_dict(state_dict)) + keys = set(state_dict.keys()) + + if {"dense.weight", "dense.bias", "out_proj.weight", "out_proj.bias"}.issubset(keys): + dense_weight = state_dict["dense.weight"] + out_proj_weight = state_dict["out_proj.weight"] + head = _TwoLayerTanhHead( + in_features=dense_weight.shape[1], + hidden_features=dense_weight.shape[0], + out_features=out_proj_weight.shape[0], + ) + head.load_state_dict( + { + "dense.weight": dense_weight, + "dense.bias": state_dict["dense.bias"], + "out_proj.weight": out_proj_weight, + "out_proj.bias": state_dict["out_proj.bias"], + } + ) + return head + + if "weight" in state_dict and state_dict["weight"].ndim == 2: + bias = state_dict.get("bias") + head = nn.Linear(state_dict["weight"].shape[1], state_dict["weight"].shape[0], bias=bias is not None) + payload = {"weight": state_dict["weight"]} + if bias is not None: + payload["bias"] = bias + head.load_state_dict(payload, strict=False) + return head + + for prefix in ("score", "out_proj", "summary"): + weight_key = f"{prefix}.weight" + if weight_key not in state_dict: + continue + bias_key = f"{prefix}.bias" + bias = state_dict.get(bias_key) + head = nn.Linear(state_dict[weight_key].shape[1], state_dict[weight_key].shape[0], bias=bias is not None) + payload = {"weight": state_dict[weight_key]} + if bias is not None: + payload["bias"] = bias + head.load_state_dict(payload, strict=False) + return head + + raise ValueError( + "Unsupported classification head format. Expected a simple linear head or a dense/out_proj head, " + f"but found keys: {sorted(state_dict.keys())}" + ) + + + +def _load_classification_head(path_or_file: str, dtype: Optional[torch.dtype] = None) -> nn.Module: + head_state = torch.load(path_or_file, map_location="cpu") + head = _build_classification_head(head_state) + if dtype is not None: + head = head.to(dtype=dtype) + return head + + + +def _pool_last_non_padding_token(hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor: + if attention_mask is None: + return hidden_states[:, -1, :] + + attention_mask = attention_mask.to(device=hidden_states.device, dtype=torch.long) + reverse_mask = torch.flip(attention_mask, dims=[-1]) + last_offsets = torch.argmax(reverse_mask, dim=-1) + last_indices = attention_mask.size(-1) - 1 - last_offsets + gather_index = last_indices.view(-1, 1, 1).expand(-1, 1, hidden_states.size(-1)) + return hidden_states.gather(dim=1, index=gather_index).squeeze(1) + + + +def _pair_scores_from_logits(logits: torch.Tensor) -> List[Tuple[float, float]]: + if logits.ndim == 1: + logits = logits.unsqueeze(-1) + + if logits.size(-1) == 1: + probs_a = torch.sigmoid(logits[:, 0].float()) + probs_b = 1.0 - probs_a + else: + probs = torch.softmax(logits[:, :2].float(), dim=-1) + probs_a = probs[:, 0] + probs_b = probs[:, 1] + + return list(zip(probs_a.detach().cpu().tolist(), probs_b.detach().cpu().tolist())) + + + def build_pairwise_judge_message( reward_prompt: str, image_a, image_b, - text_a: str, - text_b: str, - user_request: str, ) -> List[Dict[str, Any]]: - prompt = reward_prompt.format( - user_request=user_request or "N/A", - text_a=text_a or "(empty)", - text_b=text_b or "(empty)", - ) return [{ "role": "user", "content": [ { "type": "text", - "text": f"{prompt}\n\nCandidate 1 image:" + "text": reward_prompt }, { "type": "image", "image": image_a }, - { - "type": "text", - "text": "Candidate 2 image:" - }, { "type": "image", "image": image_b @@ -124,43 +359,6 @@ def build_pairwise_judge_message( }] -def parse_pair_judge_response(response: str) -> Tuple[float, float]: - response = (response or "").strip() - if not response: - return 0.5, 0.5 - - def _extract_score(image_idx: int) -> Optional[float]: - pattern = rf"Image\s*{image_idx}\s*score\s*[::]\s*([0-9]+(?:\.[0-9]+)?)" - match = re.search(pattern, response, re.IGNORECASE) - if not match: - return None - try: - return float(match.group(1)) - except ValueError: - return None - - score_a = _extract_score(1) - score_b = _extract_score(2) - if score_a is not None and score_b is not None: - return score_a, score_b - - winner_match = re.search(r"Winner\s*[::]\s*(1|2|tie)", response, re.IGNORECASE) - if winner_match: - winner = winner_match.group(1).lower() - if winner == "1": - return 1.0, 0.0 - if winner == "2": - return 0.0, 1.0 - return 0.5, 0.5 - - numeric_scores = re.findall(r"([0-9]+(?:\.[0-9]+)?)", response) - if len(numeric_scores) >= 2: - try: - return float(numeric_scores[0]), float(numeric_scores[1]) - except ValueError: - pass - return 0.5, 0.5 - def group_meme_rollout_indices( refs: Optional[Sequence[Any]], @@ -205,23 +403,25 @@ def group_meme_rollout_indices( class MemePairwiseJudge(nn.Module): def __init__( self, - base_model: Qwen2_5_VLForConditionalGeneration, + base_model: nn.Module, processor, + reward_head: nn.Module, reward_prompt: str, - max_new_tokens: int = 96, pair_batch_size: int = 4, n_samples_per_prompt: int = 1, max_pairs_per_group: int = 0, + max_length: int = 4096, render_config: Optional[MemeRenderConfig] = None, ): super().__init__() self.base_model = base_model self.processor = processor + self.reward_head = reward_head self.reward_prompt = reward_prompt - self.max_new_tokens = int(max_new_tokens) self.pair_batch_size = int(pair_batch_size) self.n_samples_per_prompt = int(n_samples_per_prompt) self.max_pairs_per_group = int(max_pairs_per_group) + self.max_length = int(max_length) self.render_config = render_config or MemeRenderConfig() def _render_candidates( @@ -229,13 +429,11 @@ def _render_candidates( raw_images: Sequence[Any], references: Sequence[Any], prompt_and_outputs: Sequence[str], - ) -> Tuple[List[Any], List[str]]: + ) -> List[Any]: rendered_images: List[Any] = [] - extracted_texts: List[str] = [] for raw_image, reference, prompt_and_output in zip(raw_images, references, prompt_and_outputs): extracted_boxes = extract_box_texts(prompt_and_output or "") - extracted_texts.append("\\n".join(extracted_boxes) if extracted_boxes else "") image = get_first_image(raw_image) if image is None: @@ -253,55 +451,66 @@ def _render_candidates( ) ) - return rendered_images, extracted_texts + return rendered_images - def _generate_pair_scores(self, pair_jobs: List[Dict[str, Any]], device: torch.device) -> List[Tuple[float, float]]: + def _score_pair_jobs(self, pair_jobs: List[Dict[str, Any]], device: torch.device) -> List[Tuple[float, float]]: scores: List[Tuple[float, float]] = [] if not pair_jobs: return scores - step = max(1, self.pair_batch_size) - for start in range(0, len(pair_jobs), step): - batch_jobs = pair_jobs[start:start + step] + @torch.no_grad() + def _run_batch(batch_jobs: List[Dict[str, Any]]) -> List[Tuple[float, float]]: + if process_vision_info is None: + raise ImportError("keye-vl-utils is required for the GodsMeme reward model") messages = [job["message"] for job in batch_jobs] texts = [ self.processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in messages ] - image_inputs = [] - for message in messages: - processed = process_vision_info(message) - if isinstance(processed, tuple): - image_inputs.append(processed[0]) - else: - image_inputs.append(processed) - - inputs = self.processor( - text=texts, - images=image_inputs, - padding=True, - truncation=True, - return_tensors="pt", - ) + image_inputs, video_inputs = process_vision_info(messages) + + processor_kwargs = { + "text": texts, + "padding": True, + "truncation": True, + "max_length": self.max_length, + "return_tensors": "pt", + } + if image_inputs is not None: + processor_kwargs["images"] = image_inputs + if video_inputs is not None: + processor_kwargs["videos"] = video_inputs + + inputs = self.processor(**processor_kwargs) for key, value in list(inputs.items()): if torch.is_tensor(value): inputs[key] = value.to(device) - generation_kwargs = { - "input_ids": inputs["input_ids"], - "attention_mask": inputs["attention_mask"], - "max_new_tokens": self.max_new_tokens, - "do_sample": False, - } - if "pixel_values" in inputs: - generation_kwargs["pixel_values"] = inputs["pixel_values"] - if "image_grid_thw" in inputs: - generation_kwargs["image_grid_thw"] = inputs["image_grid_thw"] + outputs = self.base_model(**inputs, output_hidden_states=True, return_dict=True) + hidden_states = getattr(outputs, "hidden_states", None) + if hidden_states: + last_hidden = hidden_states[-1] + else: + last_hidden = getattr(outputs, "last_hidden_state", None) + if last_hidden is None: + raise ValueError("Reward model must return hidden states for pairwise scoring.") + + pooled = _pool_last_non_padding_token(last_hidden, inputs.get("attention_mask")) + logits = self.reward_head(pooled) + return _pair_scores_from_logits(logits) - generated = self.base_model.generate(**generation_kwargs) - trimmed = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs["input_ids"], generated)] - decoded = self.processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False) - scores.extend(parse_pair_judge_response(text) for text in decoded) + step = max(1, self.pair_batch_size) + for start in range(0, len(pair_jobs), step): + batch_jobs = pair_jobs[start:start + step] + try: + scores.extend(_run_batch(batch_jobs)) + except IndexError as exc: + # Some multimodal processors do not align multiple two-image prompts + # correctly in one batch. Fall back to one pair per forward pass. + if len(batch_jobs) == 1 or "image_grid_thw" not in str(exc): + raise + for job in batch_jobs: + scores.extend(_run_batch([job])) return scores @@ -331,7 +540,7 @@ def forward( if batch_size == 0: return {"score": torch.zeros(0, dtype=torch.float32, device=device)} - rendered_images, extracted_texts = self._render_candidates(raw_images, references, prompt_and_output) + rendered_images = self._render_candidates(raw_images, references, prompt_and_output) groups = group_meme_rollout_indices( references, batch_size=batch_size, n_samples_per_prompt=self.n_samples_per_prompt ) @@ -349,9 +558,6 @@ def forward( if rendered_images[global_a] is None or rendered_images[global_b] is None: continue - reference = references[global_a] if global_a < len(references) else None - fallback_prompt = prompt_and_output[global_a] if global_a < len(prompt_and_output) else None - user_request = get_user_request(reference if isinstance(reference, dict) else None, fallback_prompt) pair_jobs.append({ "index_a": global_a, "index_b": global_b, @@ -359,13 +565,10 @@ def forward( reward_prompt=self.reward_prompt, image_a=rendered_images[global_a], image_b=rendered_images[global_b], - text_a=extracted_texts[global_a], - text_b=extracted_texts[global_b], - user_request=user_request, ), }) - pair_scores = self._generate_pair_scores(pair_jobs, device=device) + pair_scores = self._score_pair_jobs(pair_jobs, device=device) for job, (score_a, score_b) in zip(pair_jobs, pair_scores): index_a = job["index_a"] index_b = job["index_b"] @@ -388,6 +591,7 @@ def forward( return {"score": torch.tensor(pairwise_rewards, dtype=torch.float32, device=device)} + def load_reward_models( reward_pretrain: str, strategy, @@ -400,34 +604,42 @@ def load_reward_models( _REWARD_STATE["model_reward_weight"] = float(cfg.get("model_reward_weight", 1.0)) _REWARD_STATE["format_reward_weight"] = float(cfg.get("format_reward_weight", 0.1)) - pretrain_path = cfg["path"] + reward_model_path = cfg["path"] reward_prompt = _load_reward_prompt(cfg.get("reward_prompt_path")) + torch_dtype = _resolve_torch_dtype(cfg.get("torch_dtype", "float16")) + classification_head_path = cfg.get("classification_head_path") or _resolve_model_file( + reward_model_path, "classification_head.pt" + ) with strategy.init_model_context() as _: - model_config = AutoConfig.from_pretrained(pretrain_path, trust_remote_code=True) - base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained( - pretrain_path, - config=model_config, - torch_dtype=torch.bfloat16, + base_model = AutoModel.from_pretrained( + reward_model_path, + torch_dtype=torch_dtype, attn_implementation=cfg.get("attn_implementation", "flash_attention_2"), - trust_remote_code=True, + trust_remote_code=cfg.get("trust_remote_code", True), ) + if getattr(getattr(strategy, "args", None), "fsdp", False): + _patch_keye_fsdp_compat(base_model) + processor = AutoProcessor.from_pretrained( - pretrain_path, + reward_model_path, min_pixels=cfg.get("min_pixels", 256 * 28 * 28), max_pixels=cfg.get("max_pixels", 1280 * 28 * 28), - trust_remote_code=True, + trust_remote_code=cfg.get("trust_remote_code", True), ) - processor.tokenizer.padding_side = "left" + if hasattr(processor, "tokenizer"): + processor.tokenizer.padding_side = "left" + reward_head = _load_classification_head(classification_head_path, dtype=torch_dtype) model = MemePairwiseJudge( base_model=base_model, processor=processor, + reward_head=reward_head, reward_prompt=reward_prompt, - max_new_tokens=cfg.get("max_new_tokens", 96), - pair_batch_size=cfg.get("pair_batch_size", 4), + pair_batch_size=cfg.get("pair_batch_size", 1), n_samples_per_prompt=cfg.get("n_samples_per_prompt", 1), max_pairs_per_group=cfg.get("max_pairs_per_group", 0), + max_length=cfg.get("max_length", cfg.get("cutoff_len", 4096)), render_config=MemeRenderConfig( font_name=cfg.get("font_name", "DejaVuSans.ttf"), min_font_size=cfg.get("min_font_size", 14), @@ -443,6 +655,7 @@ def load_reward_models( return [model], [processor.tokenizer], {_PAIRWISE_LABEL_KEY: 0} + def reward_fn( model_reward_list: List[torch.Tensor], labels: Sequence[str], diff --git a/examples/godsmeme/run_meme_grpo.sh b/examples/godsmeme/run_meme_grpo.sh index 3daf7067..1140589b 100755 --- a/examples/godsmeme/run_meme_grpo.sh +++ b/examples/godsmeme/run_meme_grpo.sh @@ -51,8 +51,8 @@ REWARD_PROMPT_PATH="${REWARD_PROMPT_PATH:-examples/godsmeme/prompts/reward_compa # Reward cost control. 0 means use all pairs inside each rollout group. MAX_PAIRS_PER_GROUP="${MAX_PAIRS_PER_GROUP:-0}" -PAIR_BATCH_SIZE="${PAIR_BATCH_SIZE:-4}" -REWARD_MAX_NEW_TOKENS="${REWARD_MAX_NEW_TOKENS:-96}" +PAIR_BATCH_SIZE="${PAIR_BATCH_SIZE:-1}" +REWARD_MAX_LENGTH="${REWARD_MAX_LENGTH:-96}" # Reward composition: # final_reward = model_reward_weight * pairwise_reward @@ -79,8 +79,8 @@ MICRO_TRAIN_BATCH_SIZE="${MICRO_TRAIN_BATCH_SIZE:-4}" # --- Learning and Generation Settings --- KL="${KL:-0.01}" LR="${LR:-1e-6}" -PROMPT_MAX_LEN="${PROMPT_MAX_LEN:-2048}" -GENERATE_MAX_LEN="${GENERATE_MAX_LEN:-1024}" +PROMPT_MAX_LEN="${PROMPT_MAX_LEN:-8192}" +GENERATE_MAX_LEN="${GENERATE_MAX_LEN:-2048}" TEMPERATURE="${TEMPERATURE:-1.0}" TOP_P="${TOP_P:-1.0}" @@ -100,36 +100,15 @@ export MASTER_PORT="${MASTER_PORT:-20091}" # The rollout engine is for the actor only. The meme reward model is loaded # directly and kept outside rm_use_engine. -ENGINE_TYPE="${ENGINE_TYPE:-sglang}" +ENGINE_TYPE="${ENGINE_TYPE:-vllm}" ENGINE_TP="${ENGINE_TP:-2}" ENGINE_MEM_UTIL="${ENGINE_MEM_UTIL:-0.4}" ################################################################################ -# Part 5: Validation and Launch # +# Part 5: Launch # ################################################################################ -if (( N_SAMPLES <= 1 )); then - echo "[GodsMeme] N_SAMPLES must be > 1 when using group_norm / GRPO." >&2 - exit 1 -fi - -if (( MICRO_ROLLOUT_BATCH_SIZE % N_SAMPLES != 0 )); then - echo "[GodsMeme] MICRO_ROLLOUT_BATCH_SIZE must be divisible by N_SAMPLES." >&2 - echo "[GodsMeme] This keeps every rollout group inside one micro-batch for pairwise reward computation." >&2 - exit 1 -fi - -if (( ENGINE_TP <= 0 )); then - echo "[GodsMeme] ENGINE_TP must be a positive integer." >&2 - exit 1 -fi - -if (( GPUS_PER_NODE % ENGINE_TP != 0 )); then - echo "[GodsMeme] GPUS_PER_NODE must be divisible by ENGINE_TP." >&2 - exit 1 -fi - PAIR_TAG="allpairs" if (( MAX_PAIRS_PER_GROUP > 0 )); then PAIR_TAG="pairs${MAX_PAIRS_PER_GROUP}" @@ -142,9 +121,9 @@ DEFAULT_REWARD_PRETRAIN="$({ MAX_PAIRS_PER_GROUP="$MAX_PAIRS_PER_GROUP" \ MODEL_REWARD_WEIGHT="$MODEL_REWARD_WEIGHT" \ FORMAT_REWARD_WEIGHT="$FORMAT_REWARD_WEIGHT" \ - REWARD_MAX_NEW_TOKENS="$REWARD_MAX_NEW_TOKENS" \ + REWARD_MAX_LENGTH="$REWARD_MAX_LENGTH" \ N_SAMPLES="$N_SAMPLES" \ - python - <<'PY' + python3 - <<'PY' import json import os @@ -156,7 +135,7 @@ cfg = { "max_pairs_per_group": int(os.environ["MAX_PAIRS_PER_GROUP"]), "model_reward_weight": float(os.environ["MODEL_REWARD_WEIGHT"]), "format_reward_weight": float(os.environ["FORMAT_REWARD_WEIGHT"]), - "max_new_tokens": int(os.environ["REWARD_MAX_NEW_TOKENS"]), + "max_length": int(os.environ["REWARD_MAX_LENGTH"]), "n_samples_per_prompt": int(os.environ["N_SAMPLES"]), } } @@ -234,8 +213,8 @@ torchrun \ # 2. Do not add --rm_use_engine here: the current meme reward model raises # # NotImplementedError when rm_use_engine is enabled. # # 3. LIMIT_MM_IMAGE_PER_PROMPT defaults to 1 because each policy prompt only # -# contains one source image. The pairwise judge images are handled inside # -# examples/godsmeme/reward_model.py. # +# contains one source image. The reward model renders and scores both # +# candidate meme images inside examples/godsmeme/reward_model.py. # # 4. If reward evaluation is too slow, first try lowering N_SAMPLES or set # # MAX_PAIRS_PER_GROUP to a small positive integer. # # # diff --git a/examples/godsmeme/test_meme_dataset.py b/examples/godsmeme/test_meme_dataset.py index 18a6815b..58e5a872 100644 --- a/examples/godsmeme/test_meme_dataset.py +++ b/examples/godsmeme/test_meme_dataset.py @@ -17,7 +17,7 @@ def test_real_dataset_loading(): # print(len(dataset)) assert len(dataset) == 3345 - prompt, image_path, label, reference = dataset[0] + prompt, image_path, reference, label = dataset[0] assert "Meme Text Generation Framework" in prompt assert "[Comprehensive Description Section]" in reference assert isinstance(image_path[0], str) diff --git a/examples/godsmeme/test_reward_model_vllm.py b/examples/godsmeme/test_reward_model_vllm.py deleted file mode 100644 index c660a7a0..00000000 --- a/examples/godsmeme/test_reward_model_vllm.py +++ /dev/null @@ -1,103 +0,0 @@ -import os -import sys -from pathlib import Path - -import pytest -from PIL import Image -from transformers import AutoProcessor - -ROOT = Path(__file__).resolve().parents[2] -if str(ROOT) not in sys.path: - sys.path.insert(0, str(ROOT)) -if str(Path(__file__).resolve().parent) not in sys.path: - sys.path.insert(0, str(Path(__file__).resolve().parent)) - -from meme_dataset import MemeOnlineRLDataset # noqa: E402 -from meme_utils import extract_box_texts, normalize_detections, render_meme_image # noqa: E402 -from reward_model import ( # noqa: E402 - _load_reward_prompt, - build_pairwise_judge_message, - parse_pair_judge_response, -) - - -def _require_env(name: str) -> str: - value = os.getenv(name) - if not value: - pytest.skip(f"{name} is not set; skipping real vLLM meme reward test") - return value - - -def test_reward_model_vllm_real_data(): - if not os.getenv("RUN_GODSMEME_VLLM_TEST"): - pytest.skip("Set RUN_GODSMEME_VLLM_TEST=1 to run the real vLLM integration test") - if not Image: - pytest.skip("PIL is unavailable") - - vllm = pytest.importorskip("vllm") - if not os.getenv("CUDA_VISIBLE_DEVICES") and not os.path.exists("/dev/nvidia0"): - pytest.skip("This integration test requires a CUDA-visible vLLM environment") - - model_path = _require_env("GODSMEME_REWARD_MODEL_PATH") - annotation_path = _require_env("GODSMEME_ANNOTATION_PATH") - image_root = _require_env("GODSMEME_IMAGE_ROOT") - - processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) - dataset = MemeOnlineRLDataset( - annotation_path=annotation_path, - root_dir=image_root, - processor=processor, - shuffle=False, - ) - prompt, image_paths, reference, _ = dataset[0] - assert prompt - assert image_paths - assert reference["reference_output"] - - base_image = Image.open(image_paths[0]).convert("RGB") - detections = normalize_detections(reference) - reference_texts = extract_box_texts(reference["reference_output"]) - assert reference_texts, "reference_output must contain meme text boxes" - - bad_texts = ["lorem ipsum dolor sit amet" for _ in range(len(reference_texts))] - good_image = render_meme_image(base_image, reference_texts, detections=detections, reference=reference) - bad_image = render_meme_image(base_image, bad_texts, detections=detections, reference=reference) - - reward_prompt = _load_reward_prompt(os.getenv("GODSMEME_REWARD_PROMPT_PATH")) - judge_message = build_pairwise_judge_message( - reward_prompt=reward_prompt, - image_a=good_image, - image_b=bad_image, - text_a="\\n".join(reference_texts), - text_b="\\n".join(bad_texts), - user_request=reference.get("user_request", ""), - ) - rendered_prompt = processor.apply_chat_template(judge_message, tokenize=False, add_generation_prompt=True) - - llm = vllm.LLM( - model=model_path, - tensor_parallel_size=int(os.getenv("GODSMEME_VLLM_TP", "1")), - trust_remote_code=True, - max_model_len=int(os.getenv("GODSMEME_VLLM_MAX_MODEL_LEN", "4096")), - limit_mm_per_prompt={"image": 2}, - ) - sampling_params = vllm.SamplingParams( - temperature=0.0, - top_p=1.0, - max_tokens=int(os.getenv("GODSMEME_VLLM_MAX_TOKENS", "96")), - ) - outputs = llm.generate( - [{ - "prompt": rendered_prompt, - "multi_modal_data": { - "image": [good_image, bad_image] - } - }], - sampling_params, - ) - generated_text = outputs[0].outputs[0].text - score_a, score_b = parse_pair_judge_response(generated_text) - - assert isinstance(score_a, float) - assert isinstance(score_b, float) - assert score_a >= score_b diff --git a/examples/godsmeme/train_colocate.py b/examples/godsmeme/train_colocate.py index 32fc6324..a6ba9053 100644 --- a/examples/godsmeme/train_colocate.py +++ b/examples/godsmeme/train_colocate.py @@ -182,12 +182,18 @@ def train(args): tokenizer=tokenizer, processor=processor, prompt_max_len=args.prompt_max_len, + value_clip=args.value_clip, eps_clip=args.eps_clip, + loss_agg_mode=args.loss_agg_mode, + use_gspo=args.use_gspo, + normalize_advantages=args.normalize_advantages, + use_sequence_rewards=args.use_sequence_rewards, gamma=args.gamma, lambd=args.lambd, init_kl_coef=args.init_kl_coef, kl_target=args.kl_target, ema_beta=0.992, + ptx_coef=args.ptx_coef, max_norm=args.max_norm, # for GPT generation do_sample=True, @@ -195,6 +201,7 @@ def train(args): max_length=args.max_len, temperature=args.temperature, top_p=args.top_p, + first_token_temperature=args.first_token_temperature, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, # reward model @@ -204,6 +211,12 @@ def train(args): save_hf_ckpt=args.save_hf_ckpt, disable_ds_ckpt=args.disable_ds_ckpt, packing_samples=args.packing_samples, + high_entropy_token_ratio=args.high_entropy_token_ratio, + dynamic_sampling=args.dynamic_sampling, + overlong_buffer=args.overlong_buffer, + overlong_buffer_len=args.overlong_buffer_len, + overlong_buffer_penalty_factor=args.overlong_buffer_penalty_factor, + print_replay_buffer_stats=args.print_replay_buffer_stats, ) trainer.fit( @@ -233,6 +246,36 @@ def train(args): parser.add_argument("--save_steps", type=int, default=-1) parser.add_argument("--save_hf_ckpt", action="store_true", default=False) parser.add_argument("--disable_ds_ckpt", action="store_true", default=False) + parser.add_argument( + "--save_trajectories", + action="store_true", + default=False, + help="Save experience trajectories to JSON for debugging", + ) + parser.add_argument( + "--num_trajectories_to_save", + type=int, + default=10, + help="Number of trajectories to save per checkpoint", + ) + parser.add_argument( + "--mark_high_entropy_tokens", + action="store_true", + default=False, + help="Create token arrays with high-entropy information for HTML rendering (requires --save_trajectories).", + ) + parser.add_argument( + "--trajectory_analysis", + action="store_true", + default=False, + help="Enable trajectory analysis metrics and log them to wandb", + ) + parser.add_argument( + "--print_replay_buffer_stats", + action="store_true", + default=False, + help="Print detailed replay buffer statistics during training", + ) parser.add_argument("--logging_steps", type=int, default=1) parser.add_argument("--eval_steps", type=int, default=-1) parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_ppo") @@ -240,6 +283,17 @@ def train(args): parser.add_argument("--max_ckpt_mem", type=int, default=1e8) parser.add_argument("--load_checkpoint", action="store_true", default=False) + # DAPO + parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy") + parser.add_argument("--overlong_buffer", action="store_true", default=False, help="Apply overlong sequence buffer in DAPO") + parser.add_argument("--overlong_buffer_len", type=int, default=1024, help="Max token threshold for overlong buffer") + parser.add_argument( + "--overlong_buffer_penalty_factor", + type=float, + default=1.0, + help="Penalty scaling factor for overlong sequences", + ) + # PPO parser.add_argument("--num_episodes", type=int, default=1) parser.add_argument("--rollout_batch_size", type=int, default=512) @@ -251,7 +305,18 @@ def train(args): parser.add_argument("--max_samples", type=int, default=1000000) parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping") parser.add_argument("--l2", type=float, default=0.0, help="weight decay loss") + parser.add_argument("--ptx_coef", type=float, default=0.05, help="PPO-ptx loss coef") parser.add_argument("--eps_clip", type=float, default=0.2, help="PPO clip range") + parser.add_argument( + "--loss_agg_mode", + type=str, + default="seq-mean-token-mean", + help="Loss aggregation mode for policy gradients", + ) + parser.add_argument("--use_gspo", action="store_true", default=False, help="Enable GSPO mode") + parser.add_argument("--normalize_advantages", action="store_true", default=True, help="Enable advantage normalization in GSPO") + parser.add_argument("--use_sequence_rewards", action="store_true", default=True, help="Use sequence-level rewards in GSPO") + parser.add_argument("--value_clip", type=float, default=0.2, help="PPO value clip range") parser.add_argument("--lambd", type=float, default=0.95, help="PPO GAE lambd") parser.add_argument("--gamma", type=float, default=1, help="PPO GAE gamma") parser.add_argument("--micro_train_batch_size", type=int, default=4, help="batch size per GPU") @@ -259,6 +324,18 @@ def train(args): parser.add_argument("--normalize_reward", action="store_true", default=False, help="Enable Reward Normazation") parser.add_argument("--top_p", type=float, default=1.0) parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument( + "--use_fire", + action="store_true", + default=False, + help="Enable FIRE sampling (Flaming-hot Initiation with Regular Execution)", + ) + parser.add_argument( + "--first_token_temperature", + type=float, + default=10.0, + help="Temperature for the first token when --use_fire is enabled", + ) parser.add_argument( "--freeze_prefix", action="store_true", @@ -282,6 +359,16 @@ def train(args): ), ) parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer") + parser.add_argument("--reward_running_norm", action="store_true", default=False, help="Enable running normalization for rewards") + parser.add_argument( + "--reward_running_norm_minus_mean", + action="store_true", + default=False, + help="Subtract mean when using reward running normalization", + ) + parser.add_argument("--reward_clip", type=float, default=0.0, help="Clip rewards to [-reward_clip, reward_clip]") + parser.add_argument("--advantages_norm", action="store_true", default=False, help="Enable whitening for advantages") + parser.add_argument("--advantage_clip", type=float, default=0.0, help="Clip advantages to [-advantage_clip, advantage_clip]") parser.add_argument("--reward_clip_range", type=float, nargs=2, default=(-10, 10), help="Reward clip range") # DeepSpeed @@ -290,6 +377,7 @@ def train(args): parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage") parser.add_argument("--gradient_checkpointing", action="store_true", default=False) parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") + parser.add_argument("--meta_init", action="store_true", default=False, help="Initialize models on meta device to save CPU memory") parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size") parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer") parser.add_argument("--actor_init_on_gpu", action="store_true", default=False) @@ -366,6 +454,7 @@ def train(args): parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path") # MultiModal + parser.add_argument("--text_only", action="store_true", default=False) parser.add_argument( "--limit_mm_image_per_prompt", type=int, @@ -380,10 +469,17 @@ def train(args): default=False, help="whether to use the clipped policy gradient loss from CPGD" ) + parser.add_argument( + "--high_entropy_token_ratio", + type=float, + default=0.0, + help="Ratio of high-entropy tokens to keep for policy updates", + ) add_arguments(parser) args = parser.parse_args() + args.use_kl_estimator_k3 = args.kl_estimator == "k3" if args.advantage_estimator not in ["gae"]: args.critic_pretrain = None From 9ada24d41237affafd7fcca5530d3594a21c1727 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Sun, 19 Apr 2026 22:55:58 +0800 Subject: [PATCH 5/5] style(nyz): correct format --- examples/godsmeme/reward_model.py | 35 +++++++---------------------- examples/godsmeme/train_colocate.py | 28 +++++++++++++++++------ 2 files changed, 29 insertions(+), 34 deletions(-) diff --git a/examples/godsmeme/reward_model.py b/examples/godsmeme/reward_model.py index 2cbee51d..12263576 100644 --- a/examples/godsmeme/reward_model.py +++ b/examples/godsmeme/reward_model.py @@ -68,12 +68,10 @@ def _as_torch_device(device_like: Any) -> torch.device: return torch.device(device_like) - def _default_reward_prompt_path() -> str: return os.path.join(os.path.dirname(os.path.abspath(__file__)), "prompts", "reward_compare.txt") - def _load_reward_prompt(path: Optional[str] = None) -> str: prompt_path = path or _default_reward_prompt_path() if os.path.exists(prompt_path): @@ -81,7 +79,6 @@ def _load_reward_prompt(path: Optional[str] = None) -> str: return "Which meme is funnier?" - def _parse_reward_config(raw_reward_pretrain: str) -> Dict[str, Any]: if raw_reward_pretrain is None or not str(raw_reward_pretrain).strip(): raise ValueError("`reward_pretrain` must point to a meme judge model or a JSON config") @@ -112,7 +109,6 @@ def _parse_reward_config(raw_reward_pretrain: str) -> Dict[str, Any]: return pairwise_cfg - def _resolve_torch_dtype(dtype_like: Any) -> Optional[torch.dtype]: if dtype_like is None or isinstance(dtype_like, torch.dtype): return dtype_like @@ -122,7 +118,6 @@ def _resolve_torch_dtype(dtype_like: Any) -> Optional[torch.dtype]: return _DTYPE_MAP[key] - def _resolve_model_file(path_or_repo: str, filename: str) -> str: if os.path.isdir(path_or_repo): local_path = os.path.join(path_or_repo, filename) @@ -185,7 +180,6 @@ def _patch_keye_fsdp_compat(model: nn.Module) -> None: print("[MemePairwiseJudge] Set Keye vision attention to 'eager' for FSDP2 compat") - def _unwrap_head_state_dict(state_dict: Dict[str, Any]) -> Dict[str, torch.Tensor]: nested_keys = ("state_dict", "model", "module", "classification_head", "classifier") current = state_dict @@ -212,14 +206,12 @@ def _unwrap_head_state_dict(state_dict: Dict[str, Any]) -> Dict[str, torch.Tenso return tensor_state - def _strip_common_prefix(state_dict: Dict[str, torch.Tensor], prefix: str) -> Dict[str, torch.Tensor]: if not all(key.startswith(prefix) for key in state_dict): return state_dict return {key[len(prefix):]: value for key, value in state_dict.items()} - def _normalize_head_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: normalized = dict(state_dict) changed = True @@ -245,7 +237,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.out_proj(self.activation(self.dense(hidden_states))) - def _build_classification_head(state_dict: Dict[str, torch.Tensor]) -> nn.Module: state_dict = _normalize_head_state_dict(_unwrap_head_state_dict(state_dict)) keys = set(state_dict.keys()) @@ -258,14 +249,12 @@ def _build_classification_head(state_dict: Dict[str, torch.Tensor]) -> nn.Module hidden_features=dense_weight.shape[0], out_features=out_proj_weight.shape[0], ) - head.load_state_dict( - { - "dense.weight": dense_weight, - "dense.bias": state_dict["dense.bias"], - "out_proj.weight": out_proj_weight, - "out_proj.bias": state_dict["out_proj.bias"], - } - ) + head.load_state_dict({ + "dense.weight": dense_weight, + "dense.bias": state_dict["dense.bias"], + "out_proj.weight": out_proj_weight, + "out_proj.bias": state_dict["out_proj.bias"], + }) return head if "weight" in state_dict and state_dict["weight"].ndim == 2: @@ -296,7 +285,6 @@ def _build_classification_head(state_dict: Dict[str, torch.Tensor]) -> nn.Module ) - def _load_classification_head(path_or_file: str, dtype: Optional[torch.dtype] = None) -> nn.Module: head_state = torch.load(path_or_file, map_location="cpu") head = _build_classification_head(head_state) @@ -305,7 +293,6 @@ def _load_classification_head(path_or_file: str, dtype: Optional[torch.dtype] = return head - def _pool_last_non_padding_token(hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor: if attention_mask is None: return hidden_states[:, -1, :] @@ -318,7 +305,6 @@ def _pool_last_non_padding_token(hidden_states: torch.Tensor, attention_mask: Op return hidden_states.gather(dim=1, index=gather_index).squeeze(1) - def _pair_scores_from_logits(logits: torch.Tensor) -> List[Tuple[float, float]]: if logits.ndim == 1: logits = logits.unsqueeze(-1) @@ -334,7 +320,6 @@ def _pair_scores_from_logits(logits: torch.Tensor) -> List[Tuple[float, float]]: return list(zip(probs_a.detach().cpu().tolist(), probs_b.detach().cpu().tolist())) - def build_pairwise_judge_message( reward_prompt: str, image_a, @@ -359,7 +344,6 @@ def build_pairwise_judge_message( }] - def group_meme_rollout_indices( refs: Optional[Sequence[Any]], batch_size: int, @@ -591,7 +575,6 @@ def forward( return {"score": torch.tensor(pairwise_rewards, dtype=torch.float32, device=device)} - def load_reward_models( reward_pretrain: str, strategy, @@ -607,9 +590,8 @@ def load_reward_models( reward_model_path = cfg["path"] reward_prompt = _load_reward_prompt(cfg.get("reward_prompt_path")) torch_dtype = _resolve_torch_dtype(cfg.get("torch_dtype", "float16")) - classification_head_path = cfg.get("classification_head_path") or _resolve_model_file( - reward_model_path, "classification_head.pt" - ) + classification_head_path = cfg.get("classification_head_path" + ) or _resolve_model_file(reward_model_path, "classification_head.pt") with strategy.init_model_context() as _: base_model = AutoModel.from_pretrained( @@ -655,7 +637,6 @@ def load_reward_models( return [model], [processor.tokenizer], {_PAIRWISE_LABEL_KEY: 0} - def reward_fn( model_reward_list: List[torch.Tensor], labels: Sequence[str], diff --git a/examples/godsmeme/train_colocate.py b/examples/godsmeme/train_colocate.py index a6ba9053..87dbc412 100644 --- a/examples/godsmeme/train_colocate.py +++ b/examples/godsmeme/train_colocate.py @@ -284,8 +284,12 @@ def train(args): parser.add_argument("--load_checkpoint", action="store_true", default=False) # DAPO - parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy") - parser.add_argument("--overlong_buffer", action="store_true", default=False, help="Apply overlong sequence buffer in DAPO") + parser.add_argument( + "--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy" + ) + parser.add_argument( + "--overlong_buffer", action="store_true", default=False, help="Apply overlong sequence buffer in DAPO" + ) parser.add_argument("--overlong_buffer_len", type=int, default=1024, help="Max token threshold for overlong buffer") parser.add_argument( "--overlong_buffer_penalty_factor", @@ -314,8 +318,12 @@ def train(args): help="Loss aggregation mode for policy gradients", ) parser.add_argument("--use_gspo", action="store_true", default=False, help="Enable GSPO mode") - parser.add_argument("--normalize_advantages", action="store_true", default=True, help="Enable advantage normalization in GSPO") - parser.add_argument("--use_sequence_rewards", action="store_true", default=True, help="Use sequence-level rewards in GSPO") + parser.add_argument( + "--normalize_advantages", action="store_true", default=True, help="Enable advantage normalization in GSPO" + ) + parser.add_argument( + "--use_sequence_rewards", action="store_true", default=True, help="Use sequence-level rewards in GSPO" + ) parser.add_argument("--value_clip", type=float, default=0.2, help="PPO value clip range") parser.add_argument("--lambd", type=float, default=0.95, help="PPO GAE lambd") parser.add_argument("--gamma", type=float, default=1, help="PPO GAE gamma") @@ -359,7 +367,9 @@ def train(args): ), ) parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer") - parser.add_argument("--reward_running_norm", action="store_true", default=False, help="Enable running normalization for rewards") + parser.add_argument( + "--reward_running_norm", action="store_true", default=False, help="Enable running normalization for rewards" + ) parser.add_argument( "--reward_running_norm_minus_mean", action="store_true", @@ -368,7 +378,9 @@ def train(args): ) parser.add_argument("--reward_clip", type=float, default=0.0, help="Clip rewards to [-reward_clip, reward_clip]") parser.add_argument("--advantages_norm", action="store_true", default=False, help="Enable whitening for advantages") - parser.add_argument("--advantage_clip", type=float, default=0.0, help="Clip advantages to [-advantage_clip, advantage_clip]") + parser.add_argument( + "--advantage_clip", type=float, default=0.0, help="Clip advantages to [-advantage_clip, advantage_clip]" + ) parser.add_argument("--reward_clip_range", type=float, nargs=2, default=(-10, 10), help="Reward clip range") # DeepSpeed @@ -377,7 +389,9 @@ def train(args): parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage") parser.add_argument("--gradient_checkpointing", action="store_true", default=False) parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") - parser.add_argument("--meta_init", action="store_true", default=False, help="Initialize models on meta device to save CPU memory") + parser.add_argument( + "--meta_init", action="store_true", default=False, help="Initialize models on meta device to save CPU memory" + ) parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size") parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer") parser.add_argument("--actor_init_on_gpu", action="store_true", default=False)