diff --git a/examples/godsmeme/README.md b/examples/godsmeme/README.md new file mode 100644 index 00000000..129bc992 --- /dev/null +++ b/examples/godsmeme/README.md @@ -0,0 +1,163 @@ +# GodsMeme on LightRFT + +This example adapts LightRFT's vision-language GRPO pipeline to meme generation. +The policy sees one source image plus a preformatted GodsMeme prompt, generates a full reasoning-style answer, and is rewarded by a pairwise meme judge that compares rendered meme images within each GRPO group. + +## What the training entrypoint actually is + +There are two entry layers: + +- `examples/godsmeme/run_meme_grpo.sh`: user-facing launcher. Set paths and training knobs here, then run it with `bash`. +- `examples/godsmeme/train_colocate.py`: Python training entry used by `torchrun`. It builds the actor, dataset, reward model, inference engine, and PPO trainer. + +If you only want to start training, edit the shell script or override its environment variables. You usually do not need to call `train_colocate.py` directly. + +## End-to-end training flow + +1. `MemeOnlineRLDataset` loads prebuilt RL rows from `annotation_path` and resolves the source image from `root_dir`. +2. Each row is converted into a multimodal chat prompt with one image placeholder and one text instruction. +3. GRPO samples `n_samples_per_prompt` policy completions for the same prompt. +4. The reward pipeline extracts the `Text on the Meme` section from each completion. +5. The extracted text is rendered back onto the original image by using detection boxes when available, or a simple fallback layout otherwise. +6. A local VLM judge compares candidate meme images pairwise inside the same rollout group. +7. Pairwise scores are aggregated into one scalar reward per sample. +8. A small format reward is added so the policy keeps the expected GodsMeme response structure. +9. LightRFT runs PPO/GRPO updates and saves checkpoints to the configured output directory. + +## Directory map + +```text +examples/godsmeme/ +├── README.md # This guide +├── run_meme_grpo.sh # User-facing launch script +├── train_colocate.py # Main GRPO training entry +├── meme_dataset.py # Dataset loader for GodsMeme RL rows +├── reward_model.py # Pairwise reward judge and reward aggregation +├── meme_utils.py # Text parsing, rendering, and pair helpers +├── prompts/ +│ ├── generate_meme.txt # Reference policy prompt format +│ └── reward_compare.txt # Prompt template for the pairwise judge +├── test_meme_dataset.py # Dataset test scaffold +└── test_reward_model_vllm.py # Optional reward-model smoke test +``` + +## Dataset format expected by this example + +This example does not build the RL dataset for you. It expects a JSON or JSONL file where each row already looks like a conversation-style GodsMeme training sample. + +Minimal example: + +```json +{ + "id": "sample-001", + "image": "images/cat.jpg", + "conversations": [ + { + "from": "human", + "value": "...GodsMeme prompt... " + }, + { + "from": "assistant", + "value": "...reference reasoning and meme text..." + } + ], + "text_loc_info": { + "loc": [[40, 60, 420, 170], [40, 330, 420, 430]] + } +} +``` + +Useful notes: + +- `image`, `image_path`, and `img` are accepted as image keys. +- The human message must contain ``. +- The assistant message is treated as the reference output and is also used to infer the expected number of text boxes when box metadata is missing. +- Supported box metadata includes `detections`, `text_loc_info`, `loc`, `bbox_scale`, `bbox_normalized`, and `expected_box_count`. +- If no boxes are available, the renderer falls back to a simple top/bottom style layout. + +## Reward design + +The reward combines two parts: + +```text +final_reward = model_reward_weight * pairwise_reward + + format_reward_weight * format_reward +``` + +- `pairwise_reward`: computed by comparing rendered candidate memes from the same rollout group. +- `format_reward`: checks whether the completion keeps the expected GodsMeme answer structure and box count. + +Default weights: + +- `model_reward_weight = 1.0` +- `format_reward_weight = 0.1` + +The default launcher builds `--reward_pretrain` as a JSON blob that points to the local judge model and the comparison prompt template. + +The current implementation follows the `HUMOR-RM-Keye-VL` inference pattern directly in `reward_model.py`, loading the Keye-based reward model plus `classification_head.pt` without adding a `llamafactory` runtime dependency. The pairwise judge prompt is kept to the same simple question used in the model README: `Which meme is funnier?` + +## Quick start + +Edit the paths in `examples/godsmeme/run_meme_grpo.sh`, or override them inline: + +```bash +POLICY_MODEL_PATH=/path/to/policy-model \ +REWARD_MODEL_PATH=/path/to/reward-model \ +ANNOTATION_PATH=/path/to/train_data.jsonl \ +IMAGE_ROOT=/path/to/image_root \ +bash examples/godsmeme/run_meme_grpo.sh +``` + +Default outputs: + +- checkpoints: `results///` +- logs: `rft_logs//` + +## Important constraints before you launch + +- `N_SAMPLES` must be greater than `1` for GRPO with `group_norm`. +- `MICRO_ROLLOUT_BATCH_SIZE % N_SAMPLES == 0` must hold, otherwise one prompt group can be split across micro-batches and pairwise reward becomes invalid. +- Pairwise judge cost grows roughly quadratically with `N_SAMPLES` unless you cap `MAX_PAIRS_PER_GROUP`. +- The current meme reward model does not support `--rm_use_engine`; the judge is loaded directly by `reward_model.py`. +- Each policy prompt uses one source image, so `LIMIT_MM_IMAGE_PER_PROMPT` is set to `1`. +- This example reads `--annotation_path` and `--root_dir`; it does not use the generic `--prompt_data` path for training data. + +## Most useful knobs in `run_meme_grpo.sh` + +- `POLICY_MODEL_PATH`: actor checkpoint or HF model id. +- `REWARD_MODEL_PATH`: Keye-based reward-model checkpoint path. It should contain `classification_head.pt`. +- `REWARD_MAX_LENGTH`: max sequence length used when scoring rendered meme pairs. +- `ANNOTATION_PATH` / `IMAGE_ROOT`: GodsMeme RL data location. +- `N_SAMPLES`: number of rollouts per prompt. Higher is better for ranking quality but much slower. +- `MAX_PAIRS_PER_GROUP`: limit pairwise comparisons to reduce reward cost. +- `PAIR_BATCH_SIZE`: judge batch size for pairwise evaluation. +- `MICRO_ROLLOUT_BATCH_SIZE`: must be divisible by `N_SAMPLES`. +- `MICRO_TRAIN_BATCH_SIZE`: lower this first if actor training OOMs. +- `ENGINE_TYPE`, `ENGINE_TP`, `ENGINE_MEM_UTIL`: rollout engine settings for the actor. +- `PROMPT_MAX_LEN`, `GENERATE_MAX_LEN`: trim these if prompts or responses are too long for your setup. + +## Practical tuning advice + +- Start with smaller `N_SAMPLES` such as `4` before scaling to `8`. +- If reward evaluation is the bottleneck, reduce `N_SAMPLES` or set `MAX_PAIRS_PER_GROUP` to a small positive integer. +- If the actor OOMs, first lower `MICRO_TRAIN_BATCH_SIZE`, then `MICRO_ROLLOUT_BATCH_SIZE`, then `ENGINE_MEM_UTIL`. +- If generations are too verbose, reduce `GENERATE_MAX_LEN` and revisit the prompt format in `examples/godsmeme/prompts/generate_meme.txt`. +- If rendered text looks misplaced, verify your box annotations before changing the reward model. + +## Optional validation + +There are lightweight unit tests for the new reward-model invocation path: + +```bash +pytest examples/godsmeme/test_reward_model_vllm.py +``` + +These tests cover the classifier-head loading helpers and the direct pair-scoring flow without requiring a real reward checkpoint. + +## Current limitations + +- No data preprocessing script is included in this folder; the RL JSON/JSONL must already be prepared. +- The reward judge is model-based and local, so training cost is much higher than rule-only GRPO examples. +- The reward loader expects the Keye multimodal interface used by `HUMOR-RM-Keye-VL`. +- The fallback renderer is intentionally simple and is mainly a reward-time approximation, not a production meme compositor. +- The policy is optimized for the GodsMeme response template; changing the prompt schema usually requires updating both parsing and reward logic. diff --git a/examples/godsmeme/meme_dataset.py b/examples/godsmeme/meme_dataset.py new file mode 100644 index 00000000..3c63cd00 --- /dev/null +++ b/examples/godsmeme/meme_dataset.py @@ -0,0 +1,125 @@ +import json +import os +import random +import re +from typing import Any, Dict, List, Tuple, Union + +from torch.utils.data import Dataset + +from meme_utils import extract_box_texts, resolve_expected_box_count + + +class MemeOnlineRLDataset(Dataset): + """Meme dataset class with lazy loading per item.""" + + ASSISTANT_ROLES = ("gpt", "assistant") + DEFAULT_LABEL = "meme_pairwise" + + def __init__( + self, + annotation_path: str, + root_dir: str, + processor, + shuffle: bool = True, + ): + super().__init__() + + if not os.path.exists(annotation_path): + raise FileNotFoundError(f"Annotation file {annotation_path} does not exist") + if not os.path.isdir(root_dir): + raise NotADirectoryError(f"Image root directory {root_dir} is invalid") + + self.root_dir = root_dir + self.annotation_path = annotation_path + self.processor = processor + + self._raw_data = self._load_raw_data() + if shuffle: + random.shuffle(self._raw_data) + + def _load_raw_data(self) -> List[Union[Dict[str, Any], str]]: + with open(self.annotation_path, "r", encoding="utf-8") as handle: + content = handle.read().strip() + if not content: + return [] + try: + data = json.loads(content) + if isinstance(data, list): + return data + except json.JSONDecodeError: + pass + + with open(self.annotation_path, "r", encoding="utf-8") as handle: + return [line.strip() for line in handle if line.strip()] + + def _resolve_image_path(self, data: Dict[str, Any]) -> str: + image_value = data.get("image") or data.get("image_path") or data.get("img") + if not image_value: + raise KeyError("Dataset row is missing `image`") + image_path = image_value if os.path.isabs(image_value) else os.path.join(self.root_dir, image_value) + if not os.path.exists(image_path): + raise FileNotFoundError(f"Image {image_path} does not exist") + return image_path + + def _build_reference(self, data: Dict[str, Any], prompt_text: str, assistant_output: str) -> Dict[str, Any]: + reference: Dict[str, Any] = { + "id": data.get("id"), + "group_id": str(data.get("group_id") or data.get("sample_id") or data.get("id") or ""), + "reference_output": assistant_output, + } + + for key in ("detections", "text_loc_info", "loc", "bbox_scale", "bbox_normalized", "expected_box_count"): + if key in data: + reference[key] = data[key] + + expected_box_count = resolve_expected_box_count(reference) + if expected_box_count is None: + box_texts = extract_box_texts(assistant_output) + if box_texts: + expected_box_count = len(box_texts) + if expected_box_count is not None: + reference["expected_box_count"] = expected_box_count + + return reference + + def _process_item(self, raw_item: Union[Dict[str, Any], str]) -> Tuple[str, List[str], Dict[str, Any], str]: + data = raw_item if isinstance(raw_item, dict) else json.loads(raw_item) + image_path = self._resolve_image_path(data) + + conversations = data["conversations"] + human_input = next(c["value"] for c in conversations if c["from"] == "human" and "" in c["value"]) + assistant_output = next(c["value"] for c in conversations if c["from"] in self.ASSISTANT_ROLES) + + prompt = [{ + "role": "user", + "content": [ + { + "type": "image", + "image": "" + }, + { + "type": "text", + "text": human_input + }, + ], + }] + prompt = self.processor.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) + reference = self._build_reference(data, human_input, assistant_output) + reference = reference["reference_output"] + label = data.get("reward_rule_label", self.DEFAULT_LABEL) + + return prompt, [image_path], reference, label + + def __getitem__(self, index: int) -> Tuple[str, List[str], Dict[str, Any], str]: + return self._process_item(self._raw_data[index]) + + def __len__(self) -> int: + return len(self._raw_data) + + @staticmethod + def collate_fn(batch: List[Tuple[str, List[str], Dict[str, Any], str]]): + text_list = [item[0] for item in batch] + image_list = [item[1] for item in batch] + reference_list = [item[2] for item in batch] + label_list = [item[3] for item in batch] + return text_list, image_list, reference_list, label_list diff --git a/examples/godsmeme/meme_utils.py b/examples/godsmeme/meme_utils.py new file mode 100644 index 00000000..3671a5f2 --- /dev/null +++ b/examples/godsmeme/meme_utils.py @@ -0,0 +1,431 @@ +import os +import re +import random +from dataclasses import dataclass +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +from PIL import Image, ImageDraw, ImageFont + +_REQUIRED_SECTION_PATTERNS = [ + r"\[Comprehensive Description Section\]", + r"\[Usage Scenarios Section\]", + r"\[Text Analysis Section\]", + r"\[Specific Analysis with User Input\]", + r"Text on the Meme:", +] + + +@dataclass +class MemeRenderConfig: + font_name: str = "DejaVuSans.ttf" + min_font_size: int = 14 + max_font_size: int = 72 + line_spacing: int = 4 + outline_width: int = 2 + margin: int = 6 + default_padding_ratio: float = 0.06 + + +@dataclass +class PairwisePreference: + index_a: int + index_b: int + score_a: float + score_b: float + + +def load_text_file(path: str) -> str: + with open(path, "r", encoding="utf-8") as handle: + return handle.read().strip() + + +def extract_assistant_response(text: str) -> str: + if "<|im_start|>assistant" in text: + return text.split("<|im_start|>assistant")[-1].strip() + if "assistant\n" in text: + return text.split("assistant\n")[-1].strip() + return text.strip() + + +def extract_text_on_meme_section(response: str) -> str: + response = extract_assistant_response(response) + match = re.search(r"Text on the Meme:\s*(.*)$", response, re.IGNORECASE | re.DOTALL) + if not match: + return "" + return match.group(1).strip() + + +def extract_box_texts(response: str) -> List[str]: + section = extract_text_on_meme_section(response) + if not section: + return [] + + matches = list(re.finditer( + r"(?im)^\s*box\s*(\d+)\s*:\s*(.*?)(?=^\s*box\s*\d+\s*:|\Z)", + section, + re.DOTALL, + )) + if matches: + items = [] + for match in matches: + payload = match.group(2).strip().strip('"').strip() + if payload: + items.append(payload) + if items: + return items + + lines = [] + for raw_line in section.splitlines(): + line = raw_line.strip().strip('"').strip() + if line: + lines.append(line) + return lines + + +def compute_meme_format_reward(response: str, expected_boxes: Optional[int] = None) -> float: + response = extract_assistant_response(response) + checks = 0 + passed = 0 + + for pattern in _REQUIRED_SECTION_PATTERNS: + checks += 1 + if re.search(pattern, response, re.IGNORECASE): + passed += 1 + + checks += 2 + if re.search(r"(?im)^\s*Step\s*1\s*:", response): + passed += 1 + if re.search(r"(?im)^\s*Step\s*2\s*:", response): + passed += 1 + + fragments = extract_box_texts(response) + checks += 2 + if fragments: + passed += 1 + if fragments and all(fragment.strip() for fragment in fragments): + passed += 1 + + if expected_boxes is not None and expected_boxes > 0: + checks += 1 + if len(fragments) == expected_boxes: + passed += 1 + + return float(passed) / float(checks) if checks else 0.0 + + +def normalize_bbox(bbox: Any) -> Optional[Tuple[float, float, float, float]]: + if bbox is None: + return None + if isinstance(bbox, dict): + for key in ("bbox", "box", "loc", "coordinates"): + if key in bbox: + bbox = bbox[key] + break + if isinstance(bbox, tuple): + bbox = list(bbox) + if not isinstance(bbox, list) or len(bbox) != 4: + return None + + values: List[float] = [] + for item in bbox: + try: + values.append(float(item)) + except (TypeError, ValueError): + return None + return tuple(values) + + +def normalize_detections(reference: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]: + if not isinstance(reference, dict): + return [] + + raw_detections = reference.get("detections") + if raw_detections is None and isinstance(reference.get("text_loc_info"), dict): + text_loc_info = reference["text_loc_info"] + locs = text_loc_info.get("loc", []) + texts = text_loc_info.get("text", []) + if isinstance(texts, str): + texts = extract_box_texts(f"Text on the Meme:\n{texts}") + raw_detections = [{"bbox": loc, "text": texts[idx] if idx < len(texts) else ""} for idx, loc in enumerate(locs)] + if raw_detections is None and isinstance(reference.get("loc"), list): + raw_detections = [{"bbox": loc} for loc in reference.get("loc", [])] + + detections: List[Dict[str, Any]] = [] + for item in raw_detections or []: + bbox = normalize_bbox(item) + if bbox is None: + continue + payload = item if isinstance(item, dict) else {"bbox": item} + detections.append({"bbox": bbox, **payload}) + return detections + + +def resolve_expected_box_count(reference: Optional[Dict[str, Any]]) -> Optional[int]: + detections = normalize_detections(reference) + if detections: + return len(detections) + if isinstance(reference, dict) and isinstance(reference.get("expected_box_count"), int): + return reference["expected_box_count"] + return None + + +def get_user_request(reference: Optional[Dict[str, Any]], fallback_prompt: Optional[str] = None) -> str: + if isinstance(reference, dict): + for key in ("user_input_text", "user_request", "input_params", "prompt_summary"): + value = reference.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + return (fallback_prompt or "").strip() + + +def get_reference_output(reference: Optional[Dict[str, Any]]) -> str: + if not isinstance(reference, dict): + return "" + for key in ("reference_output", "reference", "answer", "label_text"): + value = reference.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + return "" + + +def get_reference_group_id(reference: Optional[Dict[str, Any]], fallback_index: int) -> str: + if isinstance(reference, dict): + for key in ("group_id", "sample_id", "id"): + value = reference.get(key) + if value is not None: + return str(value) + return str(fallback_index) + + +def get_first_image(raw_image: Any) -> Optional[Image.Image]: + if raw_image is None: + return None + if isinstance(raw_image, Image.Image): + return raw_image + if isinstance(raw_image, list) and raw_image: + return get_first_image(raw_image[0]) + return None + + +def _resolve_font_path(font_name: str) -> Optional[str]: + search_roots = [] + env_root = os.getenv("MEMEGENERATOR_FONT_DIR") + if env_root: + search_roots.append(os.path.expanduser(env_root)) + search_roots.extend([ + "/usr/share/fonts/truetype/dejavu", + "/usr/share/fonts/truetype/liberation2", + "/System/Library/Fonts", + "/Library/Fonts", + ]) + for root in search_roots: + candidate = os.path.join(root, font_name) + if os.path.exists(candidate): + return candidate + return font_name + + +def _load_font(font_name: str, font_size: int) -> ImageFont.FreeTypeFont: + try: + return ImageFont.truetype(_resolve_font_path(font_name), font_size) + except OSError: + return ImageFont.load_default() + + +def _measure_text(draw: ImageDraw.ImageDraw, text: str, font: ImageFont.ImageFont) -> Tuple[int, int]: + if hasattr(draw, "textbbox"): + left, top, right, bottom = draw.textbbox((0, 0), text, font=font) + return right - left, bottom - top + return draw.textsize(text, font=font) + + +def _wrap_text(draw: ImageDraw.ImageDraw, text: str, font: ImageFont.ImageFont, max_width: int) -> List[str]: + words = text.split() + if not words: + return [""] + + lines: List[str] = [] + current = words[0] + for word in words[1:]: + candidate = f"{current} {word}".strip() + if _measure_text(draw, candidate, font)[0] <= max_width: + current = candidate + else: + lines.append(current) + current = word + lines.append(current) + return lines + + +def _fit_font(draw: ImageDraw.ImageDraw, text: str, box: Tuple[int, int, int, int], + config: MemeRenderConfig) -> Tuple[ImageFont.ImageFont, List[str]]: + x1, y1, x2, y2 = box + max_width = max(1, x2 - x1 - (2 * config.margin)) + max_height = max(1, y2 - y1 - (2 * config.margin)) + + best_font: ImageFont.ImageFont = _load_font(config.font_name, config.min_font_size) + best_lines = _wrap_text(draw, text, best_font, max_width) + + for size in range(config.min_font_size, config.max_font_size + 1): + font = _load_font(config.font_name, size) + lines = _wrap_text(draw, text, font, max_width) + line_heights = [_measure_text(draw, line, font)[1] for line in lines] + total_height = sum(line_heights) + config.line_spacing * max(0, len(lines) - 1) + line_width = max(_measure_text(draw, line, font)[0] for line in lines) + if line_width <= max_width and total_height <= max_height: + best_font = font + best_lines = lines + else: + break + + return best_font, best_lines + + +def _choose_colors(image: Image.Image, box: Tuple[int, int, int, + int]) -> Tuple[Tuple[int, int, int], Tuple[int, int, int]]: + crop = image.crop(box) + mean_pixel = crop.resize((1, 1)).getpixel((0, 0)) if crop.size[0] > 0 and crop.size[1] > 0 else (255, 255, 255) + if isinstance(mean_pixel, int): + mean_pixel = (mean_pixel, mean_pixel, mean_pixel) + luminance = (0.299 * mean_pixel[0]) + (0.587 * mean_pixel[1]) + (0.114 * mean_pixel[2]) + if luminance > 186: + return (0, 0, 0), (255, 255, 255) + return (255, 255, 255), (0, 0, 0) + + +def _scale_box( + bbox: Tuple[float, float, float, float], + image: Image.Image, + reference: Optional[Dict[str, Any]] = None, +) -> Tuple[int, int, int, int]: + width, height = image.size + bbox_scale = None + normalized = False + if isinstance(reference, dict): + bbox_scale = reference.get("bbox_scale") + normalized = bool(reference.get("bbox_normalized", False)) + + x1, y1, x2, y2 = bbox + values = [x1, y1, x2, y2] + if bbox_scale is not None: + scale = float(bbox_scale) + x1, x2 = (x1 / scale) * width, (x2 / scale) * width + y1, y2 = (y1 / scale) * height, (y2 / scale) * height + elif normalized or max(values) <= 1.0: + x1, x2 = x1 * width, x2 * width + y1, y2 = y1 * height, y2 * height + + left = max(0, min(int(round(x1)), width - 1)) + top = max(0, min(int(round(y1)), height - 1)) + right = max(left + 1, min(int(round(x2)), width)) + bottom = max(top + 1, min(int(round(y2)), height)) + return left, top, right, bottom + + +def default_render_boxes(image: Image.Image, count: int, config: MemeRenderConfig) -> List[Tuple[int, int, int, int]]: + width, height = image.size + pad_x = int(width * config.default_padding_ratio) + pad_y = int(height * config.default_padding_ratio) + if count <= 1: + return [(pad_x, pad_y, width - pad_x, max(pad_y + 1, int(height * 0.2)))] + top_box = (pad_x, pad_y, width - pad_x, max(pad_y + 1, int(height * 0.2))) + bottom_box = (pad_x, max(int(height * 0.78), pad_y), width - pad_x, height - pad_y) + boxes = [top_box, bottom_box] + if count > 2: + middle_height = max(1, int((height * 0.58) / max(1, count - 2))) + start_y = int(height * 0.24) + for idx in range(count - 2): + y1 = start_y + (idx * middle_height) + y2 = min(height - pad_y, y1 + middle_height) + boxes.insert(-1, (pad_x, y1, width - pad_x, max(y1 + 1, y2))) + return boxes[:count] + + +def render_meme_image( + image: Image.Image, + texts: Sequence[str], + detections: Optional[Sequence[Dict[str, Any]]] = None, + reference: Optional[Dict[str, Any]] = None, + config: Optional[MemeRenderConfig] = None, +) -> Image.Image: + render_config = config or MemeRenderConfig() + canvas = image.convert("RGB").copy() + draw = ImageDraw.Draw(canvas) + + cleaned_texts = [text.strip() for text in texts if text and text.strip()] + if not cleaned_texts: + return canvas + + boxes: List[Tuple[int, int, int, int]] = [] + for detection in detections or []: + bbox = normalize_bbox(detection) + if bbox is None: + continue + boxes.append(_scale_box(bbox, canvas, reference=reference)) + + if not boxes: + boxes = default_render_boxes(canvas, len(cleaned_texts), render_config) + + if len(cleaned_texts) > len(boxes): + merged = list(cleaned_texts[:len(boxes)]) + merged[-1] = "\n".join([merged[-1], *cleaned_texts[len(boxes):]]).strip() + cleaned_texts = merged + else: + cleaned_texts = cleaned_texts[:len(boxes)] + + for box, text in zip(boxes, cleaned_texts): + font, wrapped_lines = _fit_font(draw, text, box, render_config) + fill_color, outline_color = _choose_colors(canvas, box) + line_sizes = [_measure_text(draw, line, font) for line in wrapped_lines] + total_height = sum(height + for _, height in line_sizes) + render_config.line_spacing * max(0, + len(wrapped_lines) - 1) + x1, y1, x2, y2 = box + current_y = y1 + max(render_config.margin, (y2 - y1 - total_height) // 2) + + for line, (line_width, line_height) in zip(wrapped_lines, line_sizes): + current_x = x1 + max(render_config.margin, (x2 - x1 - line_width) // 2) + for dx in range(-render_config.outline_width, render_config.outline_width + 1): + for dy in range(-render_config.outline_width, render_config.outline_width + 1): + if dx == 0 and dy == 0: + continue + draw.text((current_x + dx, current_y + dy), line, font=font, fill=outline_color) + draw.text((current_x, current_y), line, font=font, fill=fill_color) + current_y += line_height + render_config.line_spacing + + return canvas + + +def sample_group_pairs(group_size: int, max_pairs: int = 0, seed: Optional[int] = None) -> List[Tuple[int, int]]: + pairs = [(left, right) for left in range(group_size) for right in range(left + 1, group_size)] + if max_pairs <= 0 or len(pairs) <= max_pairs: + return pairs + rng = random.Random(seed) + rng.shuffle(pairs) + return pairs[:max_pairs] + + +def aggregate_pairwise_preferences( + batch_size: int, + preferences: Iterable[PairwisePreference], +) -> List[float]: + totals = [0.0 for _ in range(batch_size)] + counts = [0 for _ in range(batch_size)] + + for pref in preferences: + denom = pref.score_a + pref.score_b + if denom <= 0: + norm_a = 0.5 + norm_b = 0.5 + else: + norm_a = pref.score_a / denom + norm_b = pref.score_b / denom + totals[pref.index_a] += norm_a + totals[pref.index_b] += norm_b + counts[pref.index_a] += 1 + counts[pref.index_b] += 1 + + rewards = [] + for total, count in zip(totals, counts): + rewards.append(total / count if count > 0 else 0.0) + return rewards diff --git a/examples/godsmeme/prompts/generate_meme.txt b/examples/godsmeme/prompts/generate_meme.txt new file mode 100644 index 00000000..a056bbf5 --- /dev/null +++ b/examples/godsmeme/prompts/generate_meme.txt @@ -0,0 +1,77 @@ +**Meme Text Generation Framework** +Based on the meme basemap and user input, analyze what can be written on this basemap that meets the user's needs and is as humorous as possible. +**Phase 1: Base Image Analysis** +[Comprehensive Description Section] +- **Visual Deconstruction**: +- Primary subjects (demeanor/movement/apparel of entities) +- Composition logic (focal points/color contrast/spatial relationships) +- Cultural signifiers (recognizable meme formats/pop culture references) +- Narrative cues (body language implications/prop symbolism) +- Keep the tone vivid and concrete; avoid flat, purely descriptive lists. + +[Usage Scenarios Section] +- **Scenario Modeling**: +- Social contexts (group chats/comment sections/private conversations) +- Topic alignment (workplace culture/life struggles/viral trends) +- Emotional mapping (sarcasm/self-deprecation/absurdist/dark humor) +- Cross-platform adaptation (short video captions/chat stickers/forum posts) +- Highlight why the scenario is funny or relatable instead of just stating where it fits. + +[Text Analysis Section] +- **Humor Engineering**: +- Wordplay (puns/homophones/semantic reversal) +- Cognitive dissonance (expectation subversion/scale exaggeration/role mismatch) +- Emotional resonance (generational gaps/life frustrations/cringe moments) +- Format optimization (suspenseful opening line/punchline reversal/rhyme schemes) +- Use the tone and rhythm of the sentences demonstrated in this framework in the [Text Analysis Section] to create a consistent overall sense of language, defaulting to the two-line setup → punchline cadence shown in the example (e.g., “When… / But…”). + +--- + +**Phase 2: Customization Process** +[Specific Analysis with User Input] + +**Step 1: Contextual Bridging** +- **Input Decoding**: +- Quantify [Intensity] as dramatic escalation (0-10 scale) +- Map [Intent] to visual elements' interactive potential +- Establish topological connections between [Context/Theme] and meme formats +- Even if the user does not supply sentences, treat the demonstration’s setup/punchline cadence as the blueprint; only replace the thematic nouns/adjectives with the requested keywords to build the new contrast. + +**Step 2: Humor Optimization** +- **Multidimensional Strategies**: +- Tone calibration: Adjust phrasing sharpness using [Keywords] +- Tension building: Create contrast between static imagery and dynamic text +- Cultural alignment: Balance trending phrases with evergreen humor elements +- Prioritize concise, impactful punchlines (1-2 sentences per box) that stay faithful to the user's emotional tone and amplify the humorous contrast they hinted at, retaining the setup → punchline rhythm unless the user explicitly requests another format. + +**Text on the Meme**: +[Read the chart from top to bottom, from left to right in each red box should be put what text in turn, with box1: text fragment 1\nbox2: text fragment 2\n, there are several boxes to correspond to the output of a few paragraphs of the text corresponds to each other, here pay attention to the combination of the box in the map position, the meaning of the map, the user input, and the previous reasoning to generate the theme of the humor of the text. Unless told otherwise, deliver two fragments that mirror the demonstration’s “When… / But…” cadence with updated thematic wording.] +--- + +**Output Demonstration Example** + +[Comprehensive Description Section] +The image employs the classic \"Shocked Cat\" meme template, featuring a close-up of an orange tabby cat with dilated circular pupils and forward-stretched whiskers creating visual tension. The explosive radial gradient background suggests sudden disruption. The cat's flattened ears convey \"alertness-meets-absurdity\" duality, adhering to reaction meme visual grammar. + +[Usage Scenarios Section] +Optimal use cases include: +1. Social media rants about last-minute work demands +2. Gaming group reactions to unexpected team failures +3. E-commerce shoppers encountering bizarre product descriptions +Ideal scenarios should follow \"unexpected shock → exaggerated response\" narrative structures + +[Text Analysis Section] +Suggested text: \n\"Friday 5:55 PM\" (top line establishes time pressure) +\"Client says 'Just one more thing...'\" (bottom line triggers conflict) +Humor mechanisms: Amplifies workplace frustrations through the cat's dramatic expression, using cross-dimensional analogy between time constraints and animal reactions + +[Specific Analysis with User Input] +Step 1: Given [Emotion: Frustration][Intensity: 8][Theme: Fitness failures], emphasize exaggerated body-text correlation. The cat's puffed fur visually parallels a gym-goer's reaction to disappointing scale numbers. +Step 2: Implement absurd escalation: \"When your trainer says\" (setup) → \"'One more rep' actually means 20\" (absurd payoff). Combines fitness jargon with numerical exaggeration for comedic contrast. + +Text on the Meme: +\"When the pre-workout kicks in +But your willpower checks out early\" + + +Now please generate the analysis and text results based on this image and user input parameters. \ No newline at end of file diff --git a/examples/godsmeme/prompts/reward_compare.txt b/examples/godsmeme/prompts/reward_compare.txt new file mode 100644 index 00000000..13b71982 --- /dev/null +++ b/examples/godsmeme/prompts/reward_compare.txt @@ -0,0 +1 @@ +Which meme is funnier? diff --git a/examples/godsmeme/requirements.txt b/examples/godsmeme/requirements.txt new file mode 100644 index 00000000..6d340b7c --- /dev/null +++ b/examples/godsmeme/requirements.txt @@ -0,0 +1,2 @@ +transformers==4.56.2 +keye-vl-utils[decord]==1.0.0 diff --git a/examples/godsmeme/reward_model.py b/examples/godsmeme/reward_model.py new file mode 100644 index 00000000..12263576 --- /dev/null +++ b/examples/godsmeme/reward_model.py @@ -0,0 +1,681 @@ +import json +import os +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn +from transformers import AutoModel, AutoProcessor +from transformers.utils import cached_file + +from keye_vl_utils import process_vision_info + +from lightrft.utils import get_current_device + +from meme_utils import ( + MemeRenderConfig, + PairwisePreference, + aggregate_pairwise_preferences, + compute_meme_format_reward, + extract_box_texts, + get_first_image, + load_text_file, + normalize_detections, + render_meme_image, + resolve_expected_box_count, + sample_group_pairs, +) + +_PAIRWISE_LABEL_KEY = "pairwise" +_REWARD_STATE = { + "model_reward_weight": 1.0, + "format_reward_weight": 0.1, +} +_DTYPE_MAP = { + "auto": None, + "float16": torch.float16, + "fp16": torch.float16, + "bfloat16": torch.bfloat16, + "bf16": torch.bfloat16, + "float32": torch.float32, + "fp32": torch.float32, +} + + +class _FSDPSafeEmbedding(nn.Module): + """Keep embedding weights in the parent FSDP unit. + + Keye's vision tower reads ``position_embedding.weight`` directly in its + interpolation path. When FSDP2 individually wraps those embedding modules, + the weight can become a DTensor while sibling activations stay regular + tensors, which triggers mixed Tensor/DTensor errors on addition. + """ + def __init__(self, embedding: nn.Embedding): + super().__init__() + self.weight = embedding.weight + self.num_embeddings = embedding.num_embeddings + self.embedding_dim = embedding.embedding_dim + self.padding_idx = embedding.padding_idx + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + return nn.functional.embedding(input_ids, self.weight, padding_idx=self.padding_idx) + + +def _as_torch_device(device_like: Any) -> torch.device: + if isinstance(device_like, torch.device): + return device_like + if isinstance(device_like, int): + return torch.device(f"cuda:{device_like}") + return torch.device(device_like) + + +def _default_reward_prompt_path() -> str: + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "prompts", "reward_compare.txt") + + +def _load_reward_prompt(path: Optional[str] = None) -> str: + prompt_path = path or _default_reward_prompt_path() + if os.path.exists(prompt_path): + return load_text_file(prompt_path) + return "Which meme is funnier?" + + +def _parse_reward_config(raw_reward_pretrain: str) -> Dict[str, Any]: + if raw_reward_pretrain is None or not str(raw_reward_pretrain).strip(): + raise ValueError("`reward_pretrain` must point to a meme judge model or a JSON config") + + text = str(raw_reward_pretrain).strip() + try: + cfg = json.loads(text) + except json.JSONDecodeError: + cfg = {"pairwise": {"path": text}} + + if isinstance(cfg, str): + cfg = {"pairwise": {"path": cfg}} + if not isinstance(cfg, dict): + raise ValueError("Unsupported meme reward config format") + + if "pairwise" in cfg: + pairwise_cfg = cfg["pairwise"] + elif "outcome" in cfg: + pairwise_cfg = cfg["outcome"] + else: + first_key = next(iter(cfg.keys())) + pairwise_cfg = cfg[first_key] + + if isinstance(pairwise_cfg, str): + pairwise_cfg = {"path": pairwise_cfg} + if not isinstance(pairwise_cfg, dict) or not pairwise_cfg.get("path"): + raise ValueError("Meme reward config must contain a model path") + return pairwise_cfg + + +def _resolve_torch_dtype(dtype_like: Any) -> Optional[torch.dtype]: + if dtype_like is None or isinstance(dtype_like, torch.dtype): + return dtype_like + key = str(dtype_like).strip().lower() + if key not in _DTYPE_MAP: + raise ValueError(f"Unsupported torch dtype: {dtype_like}") + return _DTYPE_MAP[key] + + +def _resolve_model_file(path_or_repo: str, filename: str) -> str: + if os.path.isdir(path_or_repo): + local_path = os.path.join(path_or_repo, filename) + if os.path.exists(local_path): + return local_path + try: + return cached_file(path_or_repo, filename) + except Exception as exc: # pragma: no cover - exercised only when weights are resolved remotely. + raise FileNotFoundError(f"Could not find {filename} under {path_or_repo}") from exc + + +def _replace_module_if_embedding(parent: nn.Module, attr_name: str) -> bool: + module = getattr(parent, attr_name, None) + if not isinstance(module, nn.Embedding): + return False + + setattr(parent, attr_name, _FSDPSafeEmbedding(module)) + return True + + +def _find_keye_visual_module(model: nn.Module) -> Optional[nn.Module]: + for attr_name in ("visual", "vision_tower", "vision_model"): + module = getattr(model, attr_name, None) + if isinstance(module, nn.Module): + return module + + inner_model = getattr(model, "model", None) + if isinstance(inner_model, nn.Module): + return _find_keye_visual_module(inner_model) + + return None + + +def _patch_keye_fsdp_compat(model: nn.Module) -> None: + """Patch Keye vision modules that are fragile under FSDP2 DTensors.""" + visual_model = _find_keye_visual_module(model) + if visual_model is None: + return + + replaced_names: List[str] = [] + for module_name, module in visual_model.named_modules(): + if _replace_module_if_embedding(module, "position_embedding"): + replaced_names.append(f"{module_name}.position_embedding" if module_name else "position_embedding") + if _replace_module_if_embedding(module, "packing_position_embedding"): + replaced_names.append( + f"{module_name}.packing_position_embedding" if module_name else "packing_position_embedding" + ) + if hasattr(module, "_attn_implementation"): + module._attn_implementation = "eager" + + if hasattr(getattr(visual_model, "config", None), "_attn_implementation"): + visual_model.config._attn_implementation = "eager" + + vision_config = getattr(getattr(model, "config", None), "vision_config", None) + if hasattr(vision_config, "_attn_implementation"): + vision_config._attn_implementation = "eager" + + if replaced_names: + print("[MemePairwiseJudge] FSDP2-safe Keye vision patch:", ", ".join(replaced_names)) + print("[MemePairwiseJudge] Set Keye vision attention to 'eager' for FSDP2 compat") + + +def _unwrap_head_state_dict(state_dict: Dict[str, Any]) -> Dict[str, torch.Tensor]: + nested_keys = ("state_dict", "model", "module", "classification_head", "classifier") + current = state_dict + while isinstance(current, dict): + tensor_values = [value for value in current.values() if torch.is_tensor(value)] + if tensor_values: + break + next_state = None + for key in nested_keys: + candidate = current.get(key) + if isinstance(candidate, dict): + next_state = candidate + break + if next_state is None: + break + current = next_state + + if not isinstance(current, dict): + raise ValueError("classification_head.pt does not contain a supported state dict") + + tensor_state = {str(key): value for key, value in current.items() if torch.is_tensor(value)} + if not tensor_state: + raise ValueError("classification_head.pt does not contain tensor parameters") + return tensor_state + + +def _strip_common_prefix(state_dict: Dict[str, torch.Tensor], prefix: str) -> Dict[str, torch.Tensor]: + if not all(key.startswith(prefix) for key in state_dict): + return state_dict + return {key[len(prefix):]: value for key, value in state_dict.items()} + + +def _normalize_head_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + normalized = dict(state_dict) + changed = True + while changed: + changed = False + for prefix in ("module.", "model.", "classification_head.", "classifier.", "head."): + stripped = _strip_common_prefix(normalized, prefix) + if stripped is not normalized: + normalized = stripped + changed = True + break + return normalized + + +class _TwoLayerTanhHead(nn.Module): + def __init__(self, in_features: int, hidden_features: int, out_features: int): + super().__init__() + self.dense = nn.Linear(in_features, hidden_features) + self.out_proj = nn.Linear(hidden_features, out_features) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.out_proj(self.activation(self.dense(hidden_states))) + + +def _build_classification_head(state_dict: Dict[str, torch.Tensor]) -> nn.Module: + state_dict = _normalize_head_state_dict(_unwrap_head_state_dict(state_dict)) + keys = set(state_dict.keys()) + + if {"dense.weight", "dense.bias", "out_proj.weight", "out_proj.bias"}.issubset(keys): + dense_weight = state_dict["dense.weight"] + out_proj_weight = state_dict["out_proj.weight"] + head = _TwoLayerTanhHead( + in_features=dense_weight.shape[1], + hidden_features=dense_weight.shape[0], + out_features=out_proj_weight.shape[0], + ) + head.load_state_dict({ + "dense.weight": dense_weight, + "dense.bias": state_dict["dense.bias"], + "out_proj.weight": out_proj_weight, + "out_proj.bias": state_dict["out_proj.bias"], + }) + return head + + if "weight" in state_dict and state_dict["weight"].ndim == 2: + bias = state_dict.get("bias") + head = nn.Linear(state_dict["weight"].shape[1], state_dict["weight"].shape[0], bias=bias is not None) + payload = {"weight": state_dict["weight"]} + if bias is not None: + payload["bias"] = bias + head.load_state_dict(payload, strict=False) + return head + + for prefix in ("score", "out_proj", "summary"): + weight_key = f"{prefix}.weight" + if weight_key not in state_dict: + continue + bias_key = f"{prefix}.bias" + bias = state_dict.get(bias_key) + head = nn.Linear(state_dict[weight_key].shape[1], state_dict[weight_key].shape[0], bias=bias is not None) + payload = {"weight": state_dict[weight_key]} + if bias is not None: + payload["bias"] = bias + head.load_state_dict(payload, strict=False) + return head + + raise ValueError( + "Unsupported classification head format. Expected a simple linear head or a dense/out_proj head, " + f"but found keys: {sorted(state_dict.keys())}" + ) + + +def _load_classification_head(path_or_file: str, dtype: Optional[torch.dtype] = None) -> nn.Module: + head_state = torch.load(path_or_file, map_location="cpu") + head = _build_classification_head(head_state) + if dtype is not None: + head = head.to(dtype=dtype) + return head + + +def _pool_last_non_padding_token(hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor: + if attention_mask is None: + return hidden_states[:, -1, :] + + attention_mask = attention_mask.to(device=hidden_states.device, dtype=torch.long) + reverse_mask = torch.flip(attention_mask, dims=[-1]) + last_offsets = torch.argmax(reverse_mask, dim=-1) + last_indices = attention_mask.size(-1) - 1 - last_offsets + gather_index = last_indices.view(-1, 1, 1).expand(-1, 1, hidden_states.size(-1)) + return hidden_states.gather(dim=1, index=gather_index).squeeze(1) + + +def _pair_scores_from_logits(logits: torch.Tensor) -> List[Tuple[float, float]]: + if logits.ndim == 1: + logits = logits.unsqueeze(-1) + + if logits.size(-1) == 1: + probs_a = torch.sigmoid(logits[:, 0].float()) + probs_b = 1.0 - probs_a + else: + probs = torch.softmax(logits[:, :2].float(), dim=-1) + probs_a = probs[:, 0] + probs_b = probs[:, 1] + + return list(zip(probs_a.detach().cpu().tolist(), probs_b.detach().cpu().tolist())) + + +def build_pairwise_judge_message( + reward_prompt: str, + image_a, + image_b, +) -> List[Dict[str, Any]]: + return [{ + "role": "user", + "content": [ + { + "type": "text", + "text": reward_prompt + }, + { + "type": "image", + "image": image_a + }, + { + "type": "image", + "image": image_b + }, + ], + }] + + +def group_meme_rollout_indices( + refs: Optional[Sequence[Any]], + batch_size: int, + n_samples_per_prompt: int = 1, +) -> List[List[int]]: + if refs and len(refs) >= batch_size: + groups: List[List[int]] = [] + current_group: List[int] = [] + current_group_id: Optional[str] = None + usable = True + + for idx in range(batch_size): + ref = refs[idx] + group_id = None + if isinstance(ref, dict): + for key in ("group_id", "sample_id", "id"): + value = ref.get(key) + if value is not None: + group_id = str(value) + break + if group_id is None: + usable = False + break + + if current_group and group_id != current_group_id: + groups.append(current_group) + current_group = [idx] + else: + current_group.append(idx) + current_group_id = group_id + + if usable and current_group: + groups.append(current_group) + if usable and sum(len(group) for group in groups) == batch_size and any(len(group) > 1 for group in groups): + return groups + + chunk_size = max(1, int(n_samples_per_prompt)) + return [list(range(start, min(start + chunk_size, batch_size))) for start in range(0, batch_size, chunk_size)] + + +class MemePairwiseJudge(nn.Module): + def __init__( + self, + base_model: nn.Module, + processor, + reward_head: nn.Module, + reward_prompt: str, + pair_batch_size: int = 4, + n_samples_per_prompt: int = 1, + max_pairs_per_group: int = 0, + max_length: int = 4096, + render_config: Optional[MemeRenderConfig] = None, + ): + super().__init__() + self.base_model = base_model + self.processor = processor + self.reward_head = reward_head + self.reward_prompt = reward_prompt + self.pair_batch_size = int(pair_batch_size) + self.n_samples_per_prompt = int(n_samples_per_prompt) + self.max_pairs_per_group = int(max_pairs_per_group) + self.max_length = int(max_length) + self.render_config = render_config or MemeRenderConfig() + + def _render_candidates( + self, + raw_images: Sequence[Any], + references: Sequence[Any], + prompt_and_outputs: Sequence[str], + ) -> List[Any]: + rendered_images: List[Any] = [] + + for raw_image, reference, prompt_and_output in zip(raw_images, references, prompt_and_outputs): + extracted_boxes = extract_box_texts(prompt_and_output or "") + + image = get_first_image(raw_image) + if image is None: + rendered_images.append(None) + continue + + detections = normalize_detections(reference if isinstance(reference, dict) else None) + rendered_images.append( + render_meme_image( + image=image, + texts=extracted_boxes, + detections=detections, + reference=reference if isinstance(reference, dict) else None, + config=self.render_config, + ) + ) + + return rendered_images + + def _score_pair_jobs(self, pair_jobs: List[Dict[str, Any]], device: torch.device) -> List[Tuple[float, float]]: + scores: List[Tuple[float, float]] = [] + if not pair_jobs: + return scores + + @torch.no_grad() + def _run_batch(batch_jobs: List[Dict[str, Any]]) -> List[Tuple[float, float]]: + if process_vision_info is None: + raise ImportError("keye-vl-utils is required for the GodsMeme reward model") + messages = [job["message"] for job in batch_jobs] + texts = [ + self.processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True) + for message in messages + ] + image_inputs, video_inputs = process_vision_info(messages) + + processor_kwargs = { + "text": texts, + "padding": True, + "truncation": True, + "max_length": self.max_length, + "return_tensors": "pt", + } + if image_inputs is not None: + processor_kwargs["images"] = image_inputs + if video_inputs is not None: + processor_kwargs["videos"] = video_inputs + + inputs = self.processor(**processor_kwargs) + for key, value in list(inputs.items()): + if torch.is_tensor(value): + inputs[key] = value.to(device) + + outputs = self.base_model(**inputs, output_hidden_states=True, return_dict=True) + hidden_states = getattr(outputs, "hidden_states", None) + if hidden_states: + last_hidden = hidden_states[-1] + else: + last_hidden = getattr(outputs, "last_hidden_state", None) + if last_hidden is None: + raise ValueError("Reward model must return hidden states for pairwise scoring.") + + pooled = _pool_last_non_padding_token(last_hidden, inputs.get("attention_mask")) + logits = self.reward_head(pooled) + return _pair_scores_from_logits(logits) + + step = max(1, self.pair_batch_size) + for start in range(0, len(pair_jobs), step): + batch_jobs = pair_jobs[start:start + step] + try: + scores.extend(_run_batch(batch_jobs)) + except IndexError as exc: + # Some multimodal processors do not align multiple two-image prompts + # correctly in one batch. Fall back to one pair per forward pass. + if len(batch_jobs) == 1 or "image_grid_thw" not in str(exc): + raise + for job in batch_jobs: + scores.extend(_run_batch([job])) + + return scores + + @torch.no_grad() + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + return_dict: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + **kwargs, + ) -> Dict[str, torch.Tensor]: + del attention_mask, pixel_values, image_grid_thw, return_dict, output_attentions, output_hidden_states + + prompt_and_output = kwargs.get("prompt_and_output") or [] + raw_images = kwargs.get("raw_images") or [] + references = kwargs.get("references") or [] + + batch_size = len(prompt_and_output) + if input_ids is not None: + device = input_ids.device + else: + device = _as_torch_device(get_current_device()) if torch.cuda.is_available() else torch.device("cpu") + if batch_size == 0: + return {"score": torch.zeros(0, dtype=torch.float32, device=device)} + + rendered_images = self._render_candidates(raw_images, references, prompt_and_output) + groups = group_meme_rollout_indices( + references, batch_size=batch_size, n_samples_per_prompt=self.n_samples_per_prompt + ) + + pair_jobs: List[Dict[str, Any]] = [] + preferences: List[PairwisePreference] = [] + comparison_counts = [0 for _ in range(batch_size)] + + for group in groups: + if len(group) < 2: + continue + for local_a, local_b in sample_group_pairs(len(group), max_pairs=self.max_pairs_per_group): + global_a = group[local_a] + global_b = group[local_b] + if rendered_images[global_a] is None or rendered_images[global_b] is None: + continue + + pair_jobs.append({ + "index_a": global_a, + "index_b": global_b, + "message": build_pairwise_judge_message( + reward_prompt=self.reward_prompt, + image_a=rendered_images[global_a], + image_b=rendered_images[global_b], + ), + }) + + pair_scores = self._score_pair_jobs(pair_jobs, device=device) + for job, (score_a, score_b) in zip(pair_jobs, pair_scores): + index_a = job["index_a"] + index_b = job["index_b"] + preferences.append( + PairwisePreference( + index_a=index_a, + index_b=index_b, + score_a=float(score_a), + score_b=float(score_b), + ) + ) + comparison_counts[index_a] += 1 + comparison_counts[index_b] += 1 + + pairwise_rewards = aggregate_pairwise_preferences(batch_size, preferences) + for idx, count in enumerate(comparison_counts): + if count == 0: + pairwise_rewards[idx] = 0.5 + + return {"score": torch.tensor(pairwise_rewards, dtype=torch.float32, device=device)} + + +def load_reward_models( + reward_pretrain: str, + strategy, + use_engine: bool = False, +): + if use_engine: + raise NotImplementedError("Engine is not supported for the meme reward model") + + cfg = _parse_reward_config(reward_pretrain) + _REWARD_STATE["model_reward_weight"] = float(cfg.get("model_reward_weight", 1.0)) + _REWARD_STATE["format_reward_weight"] = float(cfg.get("format_reward_weight", 0.1)) + + reward_model_path = cfg["path"] + reward_prompt = _load_reward_prompt(cfg.get("reward_prompt_path")) + torch_dtype = _resolve_torch_dtype(cfg.get("torch_dtype", "float16")) + classification_head_path = cfg.get("classification_head_path" + ) or _resolve_model_file(reward_model_path, "classification_head.pt") + + with strategy.init_model_context() as _: + base_model = AutoModel.from_pretrained( + reward_model_path, + torch_dtype=torch_dtype, + attn_implementation=cfg.get("attn_implementation", "flash_attention_2"), + trust_remote_code=cfg.get("trust_remote_code", True), + ) + if getattr(getattr(strategy, "args", None), "fsdp", False): + _patch_keye_fsdp_compat(base_model) + + processor = AutoProcessor.from_pretrained( + reward_model_path, + min_pixels=cfg.get("min_pixels", 256 * 28 * 28), + max_pixels=cfg.get("max_pixels", 1280 * 28 * 28), + trust_remote_code=cfg.get("trust_remote_code", True), + ) + if hasattr(processor, "tokenizer"): + processor.tokenizer.padding_side = "left" + + reward_head = _load_classification_head(classification_head_path, dtype=torch_dtype) + model = MemePairwiseJudge( + base_model=base_model, + processor=processor, + reward_head=reward_head, + reward_prompt=reward_prompt, + pair_batch_size=cfg.get("pair_batch_size", 1), + n_samples_per_prompt=cfg.get("n_samples_per_prompt", 1), + max_pairs_per_group=cfg.get("max_pairs_per_group", 0), + max_length=cfg.get("max_length", cfg.get("cutoff_len", 4096)), + render_config=MemeRenderConfig( + font_name=cfg.get("font_name", "DejaVuSans.ttf"), + min_font_size=cfg.get("min_font_size", 14), + max_font_size=cfg.get("max_font_size", 72), + line_spacing=cfg.get("line_spacing", 4), + outline_width=cfg.get("outline_width", 2), + margin=cfg.get("margin", 6), + default_padding_ratio=cfg.get("default_padding_ratio", 0.06), + ), + ) + model.eval() + + return [model], [processor.tokenizer], {_PAIRWISE_LABEL_KEY: 0} + + +def reward_fn( + model_reward_list: List[torch.Tensor], + labels: Sequence[str], + queries: Sequence[str], + refs: Sequence[Any], + label_map: Optional[Dict[str, int]] = None, + **kwargs, +) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + del labels, kwargs + + if model_reward_list: + device = model_reward_list[0].device + dtype = model_reward_list[0].dtype + else: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float32 + + batch_size = len(queries) + model_reward = torch.zeros(batch_size, dtype=dtype, device=device) + if model_reward_list: + pairwise_idx = 0 + if label_map: + pairwise_idx = label_map.get(_PAIRWISE_LABEL_KEY, 0) + pairwise_idx = min(pairwise_idx, len(model_reward_list) - 1) + model_reward = torch.as_tensor(model_reward_list[pairwise_idx], dtype=dtype, device=device) + + format_values = [] + for idx, query in enumerate(queries): + reference = refs[idx] if refs is not None and idx < len(refs) else None + expected_boxes = resolve_expected_box_count(reference if isinstance(reference, dict) else None) + format_values.append(compute_meme_format_reward(query, expected_boxes=expected_boxes)) + + format_reward = torch.tensor(format_values, dtype=dtype, device=device) + final_reward = ( + _REWARD_STATE["model_reward_weight"] * model_reward + _REWARD_STATE["format_reward_weight"] * format_reward + ) + metrics = { + "model_reward": model_reward, + "format_reward": format_reward, + "rule_reward": final_reward, + } + return final_reward, metrics diff --git a/examples/godsmeme/run_meme_grpo.sh b/examples/godsmeme/run_meme_grpo.sh new file mode 100755 index 00000000..1140589b --- /dev/null +++ b/examples/godsmeme/run_meme_grpo.sh @@ -0,0 +1,221 @@ +#!/usr/bin/env bash +# +# LightRFT GRPO training script for the GodsMeme example. +# +# Compared with examples/gsm8k_geo3k/run_grpo_geo3k_qwen2.5_vl_7b.sh, +# GodsMeme has a few task-specific constraints: +# 1. The dataset is loaded from --annotation_path + --root_dir, not --prompt_data. +# 2. The reward is not pure rule-based; it uses a local pairwise meme judge model. +# 3. The current meme reward model does not support --rm_use_engine. +# 4. Each GRPO group must stay inside one micro-rollout batch: +# micro_rollout_batch_size % n_samples_per_prompt == 0 +# 5. Pairwise judge cost grows quadratically with n_samples_per_prompt unless you cap +# max_pairs_per_group. +# + +set -euo pipefail + +################################################################################ +# Part 1: User Configuration # +################################################################################ + +# --- Model and Dataset Paths --- +# The policy model can be either a local path or a Hugging Face model id. +POLICY_MODEL_PATH="${POLICY_MODEL_PATH:-/path/to/your/policy-model}" + +# The meme judge model is loaded locally by examples/godsmeme/reward_model.py. +# In the simplest setup, this can be the same model family as the policy model. +REWARD_MODEL_PATH="${REWARD_MODEL_PATH:-/path/to/your/reward-model}" + +# GodsMeme expects pre-built RL rows in JSON or JSONL format. +ANNOTATION_PATH="${ANNOTATION_PATH:-/path/to/your/train_data.jsonl}" +IMAGE_ROOT="${IMAGE_ROOT:-/path/to/your/image_root}" + +# --- Experiment and Logging --- +EXPERIMENT_NAME="${EXPERIMENT_NAME:-lightrft-godsmeme-grpo-training}" +RESULT_ROOT="${RESULT_ROOT:-results}" +LOG_ROOT="${LOG_ROOT:-rft_logs}" + +# Set WANDB_API_KEY="" to disable W&B cleanly. +export WANDB_API_KEY="${WANDB_API_KEY:-}" +export WANDB_PROJECT="${WANDB_PROJECT:-LightRFT-GodsMeme-Experiments}" +export WANDB_MODE="${WANDB_MODE:-offline}" + + +################################################################################ +# Part 2: GodsMeme Reward Configuration # +################################################################################ + +# Reward prompt template used by the pairwise judge. +REWARD_PROMPT_PATH="${REWARD_PROMPT_PATH:-examples/godsmeme/prompts/reward_compare.txt}" + +# Reward cost control. 0 means use all pairs inside each rollout group. +MAX_PAIRS_PER_GROUP="${MAX_PAIRS_PER_GROUP:-0}" +PAIR_BATCH_SIZE="${PAIR_BATCH_SIZE:-1}" +REWARD_MAX_LENGTH="${REWARD_MAX_LENGTH:-96}" + +# Reward composition: +# final_reward = model_reward_weight * pairwise_reward +# + format_reward_weight * format_reward +MODEL_REWARD_WEIGHT="${MODEL_REWARD_WEIGHT:-1.0}" +FORMAT_REWARD_WEIGHT="${FORMAT_REWARD_WEIGHT:-0.1}" + + +################################################################################ +# Part 3: Training Hyperparameters # +################################################################################ + +# --- GRPO Settings --- +N_SAMPLES="${N_SAMPLES:-8}" +EPISODE="${EPISODE:-20}" +WARMUP="${WARMUP:-0.03}" + +# --- Batch Size Settings --- +RBS="${RBS:-128}" +TBS="${TBS:-128}" +MICRO_ROLLOUT_BATCH_SIZE="${MICRO_ROLLOUT_BATCH_SIZE:-8}" +MICRO_TRAIN_BATCH_SIZE="${MICRO_TRAIN_BATCH_SIZE:-4}" + +# --- Learning and Generation Settings --- +KL="${KL:-0.01}" +LR="${LR:-1e-6}" +PROMPT_MAX_LEN="${PROMPT_MAX_LEN:-8192}" +GENERATE_MAX_LEN="${GENERATE_MAX_LEN:-2048}" +TEMPERATURE="${TEMPERATURE:-1.0}" +TOP_P="${TOP_P:-1.0}" + +# The actor only sees one source image per prompt in GodsMeme. +LIMIT_MM_IMAGE_PER_PROMPT="${LIMIT_MM_IMAGE_PER_PROMPT:-1}" + + +################################################################################ +# Part 4: Distributed Training Setup # +################################################################################ + +export NNODES="${NNODES:-1}" +export GPUS_PER_NODE="${GPUS_PER_NODE:-8}" +export NODE_RANK="${NODE_RANK:-0}" +export MASTER_ADDR="${MASTER_ADDR:-localhost}" +export MASTER_PORT="${MASTER_PORT:-20091}" + +# The rollout engine is for the actor only. The meme reward model is loaded +# directly and kept outside rm_use_engine. +ENGINE_TYPE="${ENGINE_TYPE:-vllm}" +ENGINE_TP="${ENGINE_TP:-2}" +ENGINE_MEM_UTIL="${ENGINE_MEM_UTIL:-0.4}" + + +################################################################################ +# Part 5: Launch # +################################################################################ + +PAIR_TAG="allpairs" +if (( MAX_PAIRS_PER_GROUP > 0 )); then + PAIR_TAG="pairs${MAX_PAIRS_PER_GROUP}" +fi + +DEFAULT_REWARD_PRETRAIN="$({ + REWARD_MODEL_PATH="$REWARD_MODEL_PATH" \ + REWARD_PROMPT_PATH="$REWARD_PROMPT_PATH" \ + PAIR_BATCH_SIZE="$PAIR_BATCH_SIZE" \ + MAX_PAIRS_PER_GROUP="$MAX_PAIRS_PER_GROUP" \ + MODEL_REWARD_WEIGHT="$MODEL_REWARD_WEIGHT" \ + FORMAT_REWARD_WEIGHT="$FORMAT_REWARD_WEIGHT" \ + REWARD_MAX_LENGTH="$REWARD_MAX_LENGTH" \ + N_SAMPLES="$N_SAMPLES" \ + python3 - <<'PY' +import json +import os + +cfg = { + "pairwise": { + "path": os.environ["REWARD_MODEL_PATH"], + "reward_prompt_path": os.environ["REWARD_PROMPT_PATH"], + "pair_batch_size": int(os.environ["PAIR_BATCH_SIZE"]), + "max_pairs_per_group": int(os.environ["MAX_PAIRS_PER_GROUP"]), + "model_reward_weight": float(os.environ["MODEL_REWARD_WEIGHT"]), + "format_reward_weight": float(os.environ["FORMAT_REWARD_WEIGHT"]), + "max_length": int(os.environ["REWARD_MAX_LENGTH"]), + "n_samples_per_prompt": int(os.environ["N_SAMPLES"]), + } +} +print(json.dumps(cfg, ensure_ascii=True)) +PY +})" +REWARD_PRETRAIN="${REWARD_PRETRAIN:-$DEFAULT_REWARD_PRETRAIN}" + +current_time=$(date +"%Y%m%d_%H%M%S") +SAVE_MODEL_NAME="${EXPERIMENT_NAME}-ep${EPISODE}-kl${KL}-lr${LR}-${PAIR_TAG}-${current_time}" +WANDB_RUN_NAME="${EXPERIMENT_NAME}-${PAIR_TAG}-${current_time}" + +mkdir -p "${RESULT_ROOT}/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" +mkdir -p "${LOG_ROOT}/${EXPERIMENT_NAME}" + +export TORCH_NCCL_AVOID_RECORD_STREAMS=1 +export NCCL_DEBUG="${NCCL_DEBUG:-WARN}" +export IGNORE_EOS="${IGNORE_EOS:-0}" + +set -x + +torchrun \ + --nnodes "$NNODES" \ + --nproc-per-node "$GPUS_PER_NODE" \ + --node_rank "$NODE_RANK" \ + --master-port "$MASTER_PORT" \ + --master-addr "$MASTER_ADDR" \ + examples/godsmeme/train_colocate.py \ + --pretrain "${POLICY_MODEL_PATH}" \ + --reward_pretrain "${REWARD_PRETRAIN}" \ + --annotation_path "${ANNOTATION_PATH}" \ + --root_dir "${IMAGE_ROOT}" \ + --save_path "${RESULT_ROOT}/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" \ + --ckpt_path "${RESULT_ROOT}/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" \ + --advantage_estimator "group_norm" \ + --use_kl_loss \ + --kl_estimator "k3" \ + --fsdp \ + --bf16 \ + --flash_attn \ + --gradient_checkpointing \ + --save_hf_ckpt \ + --micro_train_batch_size "${MICRO_TRAIN_BATCH_SIZE}" \ + --train_batch_size "${TBS}" \ + --micro_rollout_batch_size "${MICRO_ROLLOUT_BATCH_SIZE}" \ + --rollout_batch_size "${RBS}" \ + --max_epochs 1 \ + --num_episodes "${EPISODE}" \ + --lr_warmup_ratio "${WARMUP}" \ + --n_samples_per_prompt "${N_SAMPLES}" \ + --prompt_max_len "${PROMPT_MAX_LEN}" \ + --generate_max_len "${GENERATE_MAX_LEN}" \ + --actor_learning_rate "${LR}" \ + --temperature "${TEMPERATURE}" \ + --top_p "${TOP_P}" \ + --init_kl_coef "${KL}" \ + --l2 1.0e-2 \ + --freeze_prefix \ + --adam_offload \ + --engine_type "${ENGINE_TYPE}" \ + --engine_mem_util "${ENGINE_MEM_UTIL}" \ + --engine_tp_size "${ENGINE_TP}" \ + --enable_engine_sleep \ + --limit_mm_image_per_prompt "${LIMIT_MM_IMAGE_PER_PROMPT}" \ + --use_wandb "${WANDB_API_KEY}" \ + --wandb_project "${WANDB_PROJECT}" \ + --wandb_run_name "${WANDB_RUN_NAME}" \ + 2>&1 | tee "${LOG_ROOT}/${EXPERIMENT_NAME}/node${NODE_RANK}_${current_time}.log" + + +################################################################################ +# Usage Notes # +# # +# 1. GodsMeme data must already be prepared as conversation-style RL rows. # +# 2. Do not add --rm_use_engine here: the current meme reward model raises # +# NotImplementedError when rm_use_engine is enabled. # +# 3. LIMIT_MM_IMAGE_PER_PROMPT defaults to 1 because each policy prompt only # +# contains one source image. The reward model renders and scores both # +# candidate meme images inside examples/godsmeme/reward_model.py. # +# 4. If reward evaluation is too slow, first try lowering N_SAMPLES or set # +# MAX_PAIRS_PER_GROUP to a small positive integer. # +# # +################################################################################ diff --git a/examples/godsmeme/test_meme_dataset.py b/examples/godsmeme/test_meme_dataset.py new file mode 100644 index 00000000..58e5a872 --- /dev/null +++ b/examples/godsmeme/test_meme_dataset.py @@ -0,0 +1,49 @@ +from torch.utils.data import DataLoader +from meme_dataset import MemeOnlineRLDataset +from transformers import AutoProcessor + +# download from https://huggingface.co/datasets/luodi-7/Eimages +REAL_ANNOTATION = "/root/data/Eimages/train_data.json" +REAL_ROOT = "/root/data/Eimages" +MODEL_PATH = "/root/model/HUMOR-COT-Qwen2.5-VL" +processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True) + + +def test_real_dataset_loading(): + dataset = MemeOnlineRLDataset( + annotation_path=REAL_ANNOTATION, root_dir=REAL_ROOT, shuffle=False, processor=processor + ) + + # print(len(dataset)) + assert len(dataset) == 3345 + + prompt, image_path, reference, label = dataset[0] + assert "Meme Text Generation Framework" in prompt + assert "[Comprehensive Description Section]" in reference + assert isinstance(image_path[0], str) + # from PIL import Image + # image = Image.open(image_path[0]) + + # outputs = processor( + # text=prompt, + # images=image, + # add_special_tokens=False, + # max_length=4096, + # truncation=True, + # ) + + +def test_collate_fn(): + dataset = MemeOnlineRLDataset(annotation_path=REAL_ANNOTATION, root_dir=REAL_ROOT, processor=processor) + loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn) + + batch = next(iter(loader)) + assert len(batch[0]) == 4 + assert len(batch[1]) == 4 + assert len(batch[2]) == 4 + assert len(batch[3]) == 4 + + +if __name__ == "__main__": + test_real_dataset_loading() + test_collate_fn() diff --git a/examples/godsmeme/train_colocate.py b/examples/godsmeme/train_colocate.py new file mode 100644 index 00000000..87dbc412 --- /dev/null +++ b/examples/godsmeme/train_colocate.py @@ -0,0 +1,514 @@ +import argparse +import math +import os +import sys +from datetime import datetime + +import torch + +from lightrft.utils import get_tokenizer_processor_vl, ensure_video_input_available +from lightrft.models import ActorVL +from lightrft.strategy import get_strategy +from lightrft.trainer import SPMDPPOTrainerVL +from lightrft.utils import add_arguments + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from meme_dataset import MemeOnlineRLDataset +from reward_model import load_reward_models, reward_fn + +ensure_video_input_available() + + +def train(args): + # configure strategy + strategy = get_strategy(args) + + ds_train_cfg = strategy.get_ds_train_config(is_actor=True) if not args.fsdp else None + ds_eval_cfg = strategy.get_ds_eval_config(offload=False) if not args.fsdp else None + + # configure model (optionally meta_init to save CPU memory) + with strategy.init_model_context(meta_init=getattr(args, "meta_init", False)): + actor = ActorVL( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + lora_dropout=args.lora_dropout, + ds_config=ds_train_cfg, + packing_samples=args.packing_samples, + disable_logprobs_flashattn=args.disable_logprobs_flashattn, + fused_linear_logprob=args.fused_linear_logprob, + high_entropy_token_ratio=getattr(args, "high_entropy_token_ratio", 0.0), + ) + + if args.actor_init_on_gpu: + actor = actor.to(torch.cuda.current_device()) + + # pre-prepare is used for saving RAM memory when training 72B model + if args.fsdp: + setattr(actor, "is_actor", True) + actor = strategy.prepare_model(actor, is_training=True) + + # configure tokenizer and processor + tokenizer, processor = get_tokenizer_processor_vl( + args.pretrain, actor.model, "left", use_fast=not strategy.args.disable_fast_tokenizer + ) + assert processor is not None, "processor is None" + + if args.freeze_prefix: + # visual for qwenvl, vision_model for internvl, mlp1 for linear layer in internvl + freeze_prefix = ["visual", "vision_model", "mlp1"] + frozen_params_count = 0 + total_params_count = 0 + for name, param in actor.model.named_parameters(): + total_params_count += 1 + if any(name.startswith(prefix) for prefix in freeze_prefix): + param.requires_grad = False + frozen_params_count += 1 + strategy.print( + f"Froze {frozen_params_count}/{total_params_count} parameters based on prefixes: {freeze_prefix}" + ) + + critic = None + + strategy.report_memory("before loaded reward models in main entry") + reward_models, reward_tokenizers, label_map = load_reward_models( + args.reward_pretrain, strategy, use_engine=args.rm_use_engine + ) + strategy.print(f"label_map: {label_map}") + strategy.report_memory("after loaded reward models in main entry") + + strategy.print(actor) + + # load weights for reference actor + if args.init_kl_coef == 0: + initial_model = None + else: + initial_model = ActorVL( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + ds_config=ds_eval_cfg, + packing_samples=args.packing_samples, + fused_linear_logprob=args.fused_linear_logprob, + ) + + if args.fsdp: + initial_model = strategy.prepare_model(initial_model, is_training=False, shard_size=strategy.world_size) + strategy.offload_model(initial_model) + + # prepare datasets (meme-specific: annotation_path + root_dir) + annotation_path = args.annotation_path + root_dir = args.root_dir + prompts_dataset = MemeOnlineRLDataset( + annotation_path=annotation_path, + root_dir=root_dir, + shuffle=True, + processor=processor, + ) + strategy.print(f"Loaded {len(prompts_dataset)} samples for prompts.") + + # prepare dataloader + prompts_dataloader = strategy.setup_dataloader( + prompts_dataset, + args.rollout_batch_size // strategy.world_size, + True, + True, + collate_fn=prompts_dataset.collate_fn + ) + pretrain_dataloader = None + + # for scheduler + num_update_steps_per_episodes = ( + len(prompts_dataset) * args.n_samples_per_prompt // args.train_batch_size * args.max_epochs + ) + max_steps = math.ceil(args.num_episodes * num_update_steps_per_episodes) + + # gradient_checkpointing + if args.gradient_checkpointing: + actor.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + if critic is not None: + critic.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + + ( + (actor, actor_optim, actor_scheduler), + (critic, critic_optim, critic_scheduler), + reward_models, + initial_model, + ) = strategy.prepare_models_and_optimizers(actor, critic, reward_models, initial_model, args, max_steps) + + strategy.print(reward_models) + + # load checkpoint + consumed_samples = 0 + if args.load_checkpoint and os.path.exists(os.path.join(args.ckpt_path, "_actor")): + _, states = strategy.load_ckpt( + actor.model, os.path.join(args.ckpt_path, "_actor"), optimizer=actor_optim, scheduler=actor_scheduler + ) + if args.critic_pretrain: + strategy.load_ckpt(critic, os.path.join(args.ckpt_path, "_critic")) + consumed_samples = states["consumed_samples"] + strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}") + + os.makedirs(args.save_path, exist_ok=True) + strategy.report_memory("after models init") + + strategy.setup_inference_engine(args, engine_type=args.engine_type, actor=actor) + + # configure Trainer + trainer = SPMDPPOTrainerVL( + strategy, + actor, + critic, + reward_models, + initial_model, + None, + actor_optim, + critic_optim, + actor_scheduler, + critic_scheduler, + max_epochs=args.max_epochs, + micro_train_batch_size=args.micro_train_batch_size, + micro_rollout_batch_size=args.micro_rollout_batch_size, + gradient_checkpointing=args.gradient_checkpointing, + tokenizer=tokenizer, + processor=processor, + prompt_max_len=args.prompt_max_len, + value_clip=args.value_clip, + eps_clip=args.eps_clip, + loss_agg_mode=args.loss_agg_mode, + use_gspo=args.use_gspo, + normalize_advantages=args.normalize_advantages, + use_sequence_rewards=args.use_sequence_rewards, + gamma=args.gamma, + lambd=args.lambd, + init_kl_coef=args.init_kl_coef, + kl_target=args.kl_target, + ema_beta=0.992, + ptx_coef=args.ptx_coef, + max_norm=args.max_norm, + # for GPT generation + do_sample=True, + max_new_tokens=args.generate_max_len, + max_length=args.max_len, + temperature=args.temperature, + top_p=args.top_p, + first_token_temperature=args.first_token_temperature, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + # reward model + reward_fn=reward_fn, + reward_fn_label_map=label_map, + reward_tokenizers=reward_tokenizers, + save_hf_ckpt=args.save_hf_ckpt, + disable_ds_ckpt=args.disable_ds_ckpt, + packing_samples=args.packing_samples, + high_entropy_token_ratio=args.high_entropy_token_ratio, + dynamic_sampling=args.dynamic_sampling, + overlong_buffer=args.overlong_buffer, + overlong_buffer_len=args.overlong_buffer_len, + overlong_buffer_penalty_factor=args.overlong_buffer_penalty_factor, + print_replay_buffer_stats=args.print_replay_buffer_stats, + ) + + trainer.fit( + args, + prompts_dataloader=prompts_dataloader, + pretrain_dataloader=pretrain_dataloader, + eval_dataloader=None, + consumed_samples=consumed_samples, + num_update_steps_per_episodes=num_update_steps_per_episodes, + ) + + # save model checkpoint after fitting on only rank0 + strategy.save_model( + actor, + tokenizer, + args.save_path, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--engine_type", type=str, default="vllm", help="Choose inference engine type: vllm, sglang") + + # Checkpoint + parser.add_argument("--save_path", type=str, default="./ckpt") + parser.add_argument("--save_steps", type=int, default=-1) + parser.add_argument("--save_hf_ckpt", action="store_true", default=False) + parser.add_argument("--disable_ds_ckpt", action="store_true", default=False) + parser.add_argument( + "--save_trajectories", + action="store_true", + default=False, + help="Save experience trajectories to JSON for debugging", + ) + parser.add_argument( + "--num_trajectories_to_save", + type=int, + default=10, + help="Number of trajectories to save per checkpoint", + ) + parser.add_argument( + "--mark_high_entropy_tokens", + action="store_true", + default=False, + help="Create token arrays with high-entropy information for HTML rendering (requires --save_trajectories).", + ) + parser.add_argument( + "--trajectory_analysis", + action="store_true", + default=False, + help="Enable trajectory analysis metrics and log them to wandb", + ) + parser.add_argument( + "--print_replay_buffer_stats", + action="store_true", + default=False, + help="Print detailed replay buffer statistics during training", + ) + parser.add_argument("--logging_steps", type=int, default=1) + parser.add_argument("--eval_steps", type=int, default=-1) + parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_ppo") + parser.add_argument("--max_ckpt_num", type=int, default=3) + parser.add_argument("--max_ckpt_mem", type=int, default=1e8) + parser.add_argument("--load_checkpoint", action="store_true", default=False) + + # DAPO + parser.add_argument( + "--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy" + ) + parser.add_argument( + "--overlong_buffer", action="store_true", default=False, help="Apply overlong sequence buffer in DAPO" + ) + parser.add_argument("--overlong_buffer_len", type=int, default=1024, help="Max token threshold for overlong buffer") + parser.add_argument( + "--overlong_buffer_penalty_factor", + type=float, + default=1.0, + help="Penalty scaling factor for overlong sequences", + ) + + # PPO + parser.add_argument("--num_episodes", type=int, default=1) + parser.add_argument("--rollout_batch_size", type=int, default=512) + parser.add_argument("--micro_rollout_batch_size", type=int, default=8) + parser.add_argument("--max_epochs", type=int, default=1) + parser.add_argument("--prompt_max_len", type=int, default=1024, help="Max tokens for each prompt") + parser.add_argument("--generate_max_len", type=int, default=1024, help="Max tokens to generate in PPO") + parser.add_argument("--max_len", type=int, default=None, help="deprecated max_len") + parser.add_argument("--max_samples", type=int, default=1000000) + parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping") + parser.add_argument("--l2", type=float, default=0.0, help="weight decay loss") + parser.add_argument("--ptx_coef", type=float, default=0.05, help="PPO-ptx loss coef") + parser.add_argument("--eps_clip", type=float, default=0.2, help="PPO clip range") + parser.add_argument( + "--loss_agg_mode", + type=str, + default="seq-mean-token-mean", + help="Loss aggregation mode for policy gradients", + ) + parser.add_argument("--use_gspo", action="store_true", default=False, help="Enable GSPO mode") + parser.add_argument( + "--normalize_advantages", action="store_true", default=True, help="Enable advantage normalization in GSPO" + ) + parser.add_argument( + "--use_sequence_rewards", action="store_true", default=True, help="Use sequence-level rewards in GSPO" + ) + parser.add_argument("--value_clip", type=float, default=0.2, help="PPO value clip range") + parser.add_argument("--lambd", type=float, default=0.95, help="PPO GAE lambd") + parser.add_argument("--gamma", type=float, default=1, help="PPO GAE gamma") + parser.add_argument("--micro_train_batch_size", type=int, default=4, help="batch size per GPU") + parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size") + parser.add_argument("--normalize_reward", action="store_true", default=False, help="Enable Reward Normazation") + parser.add_argument("--top_p", type=float, default=1.0) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument( + "--use_fire", + action="store_true", + default=False, + help="Enable FIRE sampling (Flaming-hot Initiation with Regular Execution)", + ) + parser.add_argument( + "--first_token_temperature", + type=float, + default=10.0, + help="Temperature for the first token when --use_fire is enabled", + ) + parser.add_argument( + "--freeze_prefix", + action="store_true", + default=False, + help="Freeze the prefix part (e.g. vision encoder) of the actor model" + ) + parser.add_argument( + "--n_samples_per_prompt", type=int, default=1, help="number of responses for each prompt in generation" + ) + parser.add_argument("--actor_learning_rate", type=float, default=1e-6) + parser.add_argument("--lr_warmup_ratio", type=float, default=0.03) + parser.add_argument("--kl_target", type=float, default=None) + parser.add_argument("--init_kl_coef", type=float, default=0.01, help="KL penalty in PPO") + parser.add_argument( + "--kl_estimator", + type=str, + default="k1", + choices=["k1", "k2", "k3"], + help=( + "In GRPO, k3 is utilized as the loss function, while k2, when used as the loss, is nearly equivalent to k1." + ), + ) + parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer") + parser.add_argument( + "--reward_running_norm", action="store_true", default=False, help="Enable running normalization for rewards" + ) + parser.add_argument( + "--reward_running_norm_minus_mean", + action="store_true", + default=False, + help="Subtract mean when using reward running normalization", + ) + parser.add_argument("--reward_clip", type=float, default=0.0, help="Clip rewards to [-reward_clip, reward_clip]") + parser.add_argument("--advantages_norm", action="store_true", default=False, help="Enable whitening for advantages") + parser.add_argument( + "--advantage_clip", type=float, default=0.0, help="Clip advantages to [-advantage_clip, advantage_clip]" + ) + parser.add_argument("--reward_clip_range", type=float, nargs=2, default=(-10, 10), help="Reward clip range") + + # DeepSpeed + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed") + parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage") + parser.add_argument("--gradient_checkpointing", action="store_true", default=False) + parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") + parser.add_argument( + "--meta_init", action="store_true", default=False, help="Initialize models on meta device to save CPU memory" + ) + parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size") + parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer") + parser.add_argument("--actor_init_on_gpu", action="store_true", default=False) + parser.add_argument("--flash_attn", action="store_true", default=False, help="Enable FlashAttention2") + parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type") + parser.add_argument("--overlap_comm", action="store_true", default=False) + parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False) + parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) + parser.add_argument( + "--disable_logprobs_flashattn", + action="store_true", + default=False, + help="Disable flash attn implementation in log_probs calculation" + ) + + # Reinforce + parser.add_argument( + "--advantage_estimator", + type=str, + choices=["gae", "reinforce", "rloo", "reinforce_baseline", "group_norm"], + default="gae", + help="Choose advantage estimation method: gae, reinforce, rloo, reinforce_baseline, group_norm", + ) + + parser.add_argument("--use_kl_loss", action="store_true", default=False, help="whether to use KL loss from GRPO") + + # LoRA + parser.add_argument("--load_in_4bit", action="store_true", default=False) + parser.add_argument("--lora_rank", type=int, default=0) + parser.add_argument("--lora_alpha", type=int, default=16) + parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear") + parser.add_argument("--lora_dropout", type=float, default=0) + + # Models + parser.add_argument("--pretrain", type=str, default=None, help="HF model name or path") + parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path") + parser.add_argument("--remote_rm_url", type=str, default=None, help="remote RM API") + parser.add_argument("--critic_pretrain", type=str, default=None, help="Optional critic path for GAE runs") + + # Custom dataset (GodsMeme) + parser.add_argument( + "--annotation_path", + type=str, + default= + "/fs-computility/niuyazhe/shared/xueyingyi/xueyingyi/cot_picture/Eimages/annotations/all/train_data.jsonl", + help="Path to meme annotation JSONL", + ) + parser.add_argument( + "--root_dir", + type=str, + default="/fs-computility/niuyazhe/shared/xueyingyi/xueyingyi/cot_picture/Eimage_drawn", + help="Root directory for meme images", + ) + parser.add_argument("--prompt_data", type=str, default=None, help="HF dataset name or path") + parser.add_argument( + "--prompt_data_probs", + type=str, + default="1.0", + help="sampling probs for datasets", + ) + parser.add_argument("--prompt_split", type=str, default="train") + # wandb parameters + parser.add_argument("--use_wandb", type=str, default=None) + parser.add_argument("--wandb_org", type=str, default=None) + parser.add_argument("--wandb_group", type=str, default=None) + parser.add_argument("--wandb_project", type=str, default="openrlhf_train_ppo") + parser.add_argument( + "--wandb_run_name", + type=str, + default="ppo_%s" % datetime.now().strftime("%m%dT%H:%M"), + ) + + # TensorBoard parameters + parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path") + + # MultiModal + parser.add_argument("--text_only", action="store_true", default=False) + parser.add_argument( + "--limit_mm_image_per_prompt", + type=int, + default=-1, + help="the max image number of each text in multi model for inference backend" + ) + + parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss") + parser.add_argument( + "--use_cpg_loss", + action="store_true", + default=False, + help="whether to use the clipped policy gradient loss from CPGD" + ) + parser.add_argument( + "--high_entropy_token_ratio", + type=float, + default=0.0, + help="Ratio of high-entropy tokens to keep for policy updates", + ) + + add_arguments(parser) + + args = parser.parse_args() + args.use_kl_estimator_k3 = args.kl_estimator == "k3" + + if args.advantage_estimator not in ["gae"]: + args.critic_pretrain = None + + if args.advantage_estimator in ["rloo", "reinforce_baseline", "group_norm"]: + assert args.n_samples_per_prompt > 1, f"{args.advantage_estimator} requires n_samples_per_prompt > 1" + assert args.micro_rollout_batch_size % args.n_samples_per_prompt == 0, ( + "meme pairwise reward requires micro_rollout_batch_size to be divisible by n_samples_per_prompt" + ) + + if args.use_kl_loss: + if args.kl_estimator not in ["k2", "k3"]: + print(f"Recommend setting {args.kl_estimator} to 'k2' or 'k3' when using KL as a loss") + else: + if args.kl_estimator not in ["k1"]: + print(f"Recommend setting {args.kl_estimator} to 'k1' when not using KL as a loss.") + + train(args) diff --git a/lightrft/trainer/fast_exp_maker.py b/lightrft/trainer/fast_exp_maker.py index fe12d534..1e352cbe 100644 --- a/lightrft/trainer/fast_exp_maker.py +++ b/lightrft/trainer/fast_exp_maker.py @@ -794,6 +794,7 @@ def _compute_standard_torch_rewards( prompt_and_output=output.prompt_and_output, raw_images=output.raw_images, img_num=output.image_num, + references=output.references, **output.inputs_extra_kwargs, ) diff --git a/pyproject.toml b/pyproject.toml index 57d7043c..909d84d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ dependencies = [ "transformers>=4.57.1", "sglang>=0.5.6.post2", "deepspeed>=0.18.3", - "flash-attn>=2.8.3", + #"flash-attn>=2.8.3", "accelerate", "datasets", "wandb",