diff --git a/examples/orm_rl_demo/README.md b/examples/orm_rl_demo/README.md new file mode 100755 index 00000000..ef6e8e46 --- /dev/null +++ b/examples/orm_rl_demo/README.md @@ -0,0 +1,105 @@ +
+ +# ORM RL Demo + +Complete ORM trajectory-scoring RL training demo based on Geo3K. + +
+ +## Overview + +This demo shows the full pipeline of using an outcome reward model (ORM) to score trajectories for RL training: +- dataset: Geo3K +- actor: Qwen2.5-VL 7B model +- reward: one general outcome reward model combined with a rule-based accuracy reward and a format reward, all contributing to the GRPO loss +- training engine: FSDP, inference engine: SGLang + +The actor generates Geo3K trajectories, the general ORM scores them, and the scores are combined with a rule-based accuracy reward (`accuracy_reward`) and a format reward (`format_reward`) to compute the final GRPO loss. To avoid rewriting the Geo3K dataset files, the demo overrides the dataset label to `geo3k_general` at runtime, so the original dataset path is reused while samples are routed through the general ORM reward mix. + +Environment requirements match the repository-level [README.md](../../README.md); see that document for setup details. + +## Project Structure + +```text +orm_rl_demo/ +├── train_colocate.py +├── reward_models.py +├── reward_models_utils.py +├── test_reward_models.py +└── run_general_fsdp_qwenvl.sh +``` + +## Quick Start + +Set the data and model paths, then run the entry script: + +```bash +export DATA_PATH=/path/to/geo3k +export PRETRAIN_PATH=/path/to/Qwen2.5-VL-7B-Instruct +export REWARD_PRETRAIN_PATHS='{"general":"/path/to/general-reward-model"}' +bash examples/orm_rl_demo/run_general_fsdp_qwenvl.sh +``` + +All three environment variables must be set before running. `REWARD_PRETRAIN_PATHS` is a JSON object that maps the reward model name (`general`) to its checkpoint path. 
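A quick pre-flight check can catch a missing or malformed variable before the launch script starts loading models. The sketch below is only an illustration: `check_demo_env` is a hypothetical helper that is not part of this repository, and it assumes the launch script expects `REWARD_PRETRAIN_PATHS` to be a JSON object with a `general` key, as the example value above suggests.

```python
import json
import os


def check_demo_env(env=os.environ):
    """Return a list of problems found in the demo's environment variables.

    Hypothetical helper: the variable names come from the Quick Start
    section; the JSON-object requirement is an assumption based on the
    example value shown there.
    """
    problems = []
    for name in ("DATA_PATH", "PRETRAIN_PATH", "REWARD_PRETRAIN_PATHS"):
        if not env.get(name):
            problems.append(f"{name} is not set")

    raw = env.get("REWARD_PRETRAIN_PATHS")
    if raw:
        try:
            mapping = json.loads(raw)
        except json.JSONDecodeError as exc:
            problems.append(f"REWARD_PRETRAIN_PATHS is not valid JSON: {exc}")
        else:
            if not isinstance(mapping, dict) or "general" not in mapping:
                problems.append(
                    'REWARD_PRETRAIN_PATHS must be a JSON object with a "general" key'
                )
    return problems


if __name__ == "__main__":
    for problem in check_demo_env():
        print("config problem:", problem)
```

Running it in the same shell as the `export` lines prints nothing when the three variables look well formed.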
+ +## Results + +### Experiment Setup + +This demo has been validated with one real 2-GPU full training run (W&B: [ORM-RL-Demo-QwenVL-7B-Geo3K](https://wandb.ai/hansbug/ORM-RL-Demo-QwenVL-7B-Geo3K/runs/zrekazyw)): + +| Item | Value | +| --- | --- | +| Actor | Qwen2.5-VL-7B-Instruct | +| General RM | Qwen2.5-VL-7B general reward model | +| Dataset | Geo3K | +| Training engine | FSDP | +| Inference engine | SGLang (`rm_use_engine=True`) | +| Reward mixing | `format_reward × 0.1 + general_model_reward × 0.2 + accuracy_reward × 0.7` | +| Batch sizes | `train_batch_size=128`, `rollout_batch_size=128` | +| Sampling | `n_samples_per_prompt=8`, `num_episodes=20` | +| Sequence length | `prompt_max_len=1024`, `generate_max_len=2048` | +| Optimizer / KL | `actor_learning_rate=1e-6`, `init_kl_coef=0.001`, `lr_warmup_ratio=0.03` | + +Three reward components are combined with the weights above to form the reward used in the GRPO loss: the rule-based accuracy reward (`accuracy_reward`), the ORM score (`general_model_reward`, coefficient 0.2), and the format reward (`format_reward`). The `general_model` values reported in the case study (e.g. `0.2`) are the ORM output (one of 0.0 / 0.5 / 1.0) multiplied by the 0.2 coefficient, not the raw model score. + +### Curve Results + +The run completed successfully (`train/global_step=320`, 16 eval passes): +- `eval/reward_mean` improved from `0.4636` to `0.5679` +- Best `eval/reward_mean=0.5686` at `train_step=260` +- Final `eval/accuracy_reward_mean=0.5166`, `eval/format_reward_mean=0.9956`, `eval/general_model_reward_mean=0.1067` + +![](assets/exp_20260417/summary_card.png) + +![](assets/exp_20260417/reward_dashboard.png) + +![](assets/exp_20260417/optimization_dashboard.png) + +### Case Study + +Two question stems appear in the saved trajectories of both step 80 and step 320. They are compared here to show the same questions early and late in training. 
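The per-question reward breakdowns in this case study can be reproduced from the mixing weights in the setup table. The sketch below is a worked example rather than the demo's actual reward code: `mix_rewards` is a hypothetical function, and it assumes that the reported `rule` value is the weighted sum of the format and accuracy rewards and that `general_model` is the ORM score after the 0.2 coefficient.

```python
# Weights from the "Reward mixing" row of the setup table.
W_FORMAT, W_GENERAL, W_ACCURACY = 0.1, 0.2, 0.7


def mix_rewards(format_reward, accuracy_reward, orm_score):
    """Return the weighted reward components and their total.

    Assumptions (not confirmed by the repository code):
    - `rule` = 0.1 * format_reward + 0.7 * accuracy_reward
    - `general_model` = 0.2 * orm_score, with orm_score in {0.0, 0.5, 1.0}
    """
    rule = W_FORMAT * format_reward + W_ACCURACY * accuracy_reward
    general_model = W_GENERAL * orm_score
    # Round to hide floating-point noise such as 0.30000000000000004.
    return {
        "rule": round(rule, 4),
        "general_model": round(general_model, 4),
        "total": round(rule + general_model, 4),
    }


if __name__ == "__main__":
    # Correct format, wrong final answer, ORM score 1.0.
    print(mix_rewards(format_reward=1.0, accuracy_reward=0.0, orm_score=1.0))
    # Correct format, correct final answer, ORM score 1.0.
    print(mix_rewards(format_reward=1.0, accuracy_reward=1.0, orm_score=1.0))
```

Under these assumptions, a well-formatted but rule-incorrect answer that the ORM still scores 1.0 totals 0.3, while a fully correct answer totals 1.0.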
+ +#### Question A: Parallelogram Area + +![](assets/exp_20260417/question_a_step80.png) + +![](assets/exp_20260417/question_a_step320.png) + +- Step 80 rewards: `total=0.3`, `format=1.0`, `accuracy=0.0`, `general_model=0.2`, `rule=0.1` +- Step 320 rewards: `total=1.0`, `format=1.0`, `accuracy=1.0`, `general_model=0.2`, `rule=0.8` +- The actor already produced a close answer at step 80, so the ORM scored it 1.0 (contributing 0.2 after the 0.2 coefficient); by step 320 the output moved from `38.97` to the rule-matching `39.0`, flipping `accuracy_reward` from `0.0` to `1.0`. + +#### Question B: Tangent Geometry `y` + +![](assets/exp_20260417/question_b_step80.png) + +![](assets/exp_20260417/question_b_step320.png) + +- Step 80 rewards: `total=0.1`, `format=1.0`, `accuracy=0.0`, `general_model=0.0`, `rule=0.1` +- Step 320 rewards: `total=1.0`, `format=1.0`, `accuracy=1.0`, `general_model=0.2`, `rule=0.8` +- At step 80 only the format reward was earned, while neither the accuracy rule nor the ORM rewarded the answer; by step 320 both contributed positively. + +## License + +This project is licensed under the Apache 2.0 License. See [LICENSE](../../LICENSE) for details. diff --git a/examples/orm_rl_demo/README_zh.md b/examples/orm_rl_demo/README_zh.md new file mode 100755 index 00000000..96046056 --- /dev/null +++ b/examples/orm_rl_demo/README_zh.md @@ -0,0 +1,105 @@ +
+ +# ORM RL Demo Training Example + +Complete ORM trajectory-scoring RL training demo based on Geo3K. + +
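+ +## Overview + +This example shows the full pipeline of using an ORM to score trajectories for RL training, with the following configuration: +- dataset: Geo3K +- actor: Qwen2.5-VL 7B model +- reward: a single general outcome reward model, combined with a rule-based accuracy reward and a format reward to compute the GRPO loss +- training engine: FSDP, inference engine: SGLang + +During training, the actor generates trajectories on Geo3K and the general ORM scores them. The ORM score is mixed with the rule-based accuracy reward (`accuracy_reward`) and the format reward (`format_reward`), and the three together compute the GRPO loss. To avoid rewriting the Geo3K dataset files, the demo overrides the data label to `geo3k_general` at runtime, reusing the original data path while routing through the general ORM reward mixing logic. + +Environment requirements match the repository-level [README_zh.md](../../README_zh.md#环境要求); refer to the main document directly. + +## Project Structure + +```text +orm_rl_demo/ +├── train_colocate.py +├── reward_models.py +├── reward_models_utils.py +├── test_reward_models.py +└── run_general_fsdp_qwenvl.sh +``` + +## Quick Start + +Set the data and model paths, then run the entry script: + +```bash +export DATA_PATH=/path/to/geo3k +export PRETRAIN_PATH=/path/to/Qwen2.5-VL-7B-Instruct +export REWARD_PRETRAIN_PATHS='{"general":"/path/to/general-reward-model"}' +bash examples/orm_rl_demo/run_general_fsdp_qwenvl.sh +``` + +Specify the data and model paths through the environment variables above before running. + +## Results + +### Experiment Setup + +This demo has been validated with one real 2-GPU full training run (W&B run: [ORM-RL-Demo-QwenVL-7B-Geo3K](https://wandb.ai/hansbug/ORM-RL-Demo-QwenVL-7B-Geo3K/runs/zrekazyw)). Key settings: + +| Item | Value | +| --- | --- | +| Actor | Qwen2.5-VL-7B-Instruct | +| General RM | Qwen2.5-VL-7B general reward model | +| Dataset | Geo3K | +| Training engine | FSDP | +| Inference engine | SGLang (`rm_use_engine=True`) | +| Reward mixing | `format_reward × 0.1 + general_model_reward × 0.2 + accuracy_reward × 0.7` | +| Batch sizes | `train_batch_size=128`, `rollout_batch_size=128` | +| Sampling | `n_samples_per_prompt=8`, `num_episodes=20` | +| Sequence length | `prompt_max_len=1024`, `generate_max_len=2048` | +| Optimizer / KL | `actor_learning_rate=1e-6`, `init_kl_coef=0.001`, `lr_warmup_ratio=0.03` | + +The three rewards (rule-based accuracy reward `accuracy_reward`, ORM score `general_model_reward` with coefficient 0.2, and format reward `format_reward`) are mixed with the weights above and together compute the final GRPO loss. The 0.2 attached to `general_model_reward` is a weight coefficient: the ORM itself outputs 0.0 / 0.5 / 1.0, and multiplying by 0.2 gives its reward contribution. + +### Overall Curve Results + +Training ran to completion (`train/global_step=320`, 16 eval passes): +- `eval/reward_mean` improved from `0.4636` to `0.5679` +- Best 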
`eval/reward_mean=0.5686` at `train_step=260` +- Final `eval/accuracy_reward_mean=0.5166`, `eval/format_reward_mean=0.9956`, `eval/general_model_reward_mean=0.1067` + +![](assets/exp_20260417/summary_card.png) + +![](assets/exp_20260417/reward_dashboard.png) + +![](assets/exp_20260417/optimization_dashboard.png) + +### Case Study + +Two questions overlap between the step 80 and step 320 trajectories. The following shows a real early-versus-late comparison for these two questions. + +#### Question A: Parallelogram Area + +![](assets/exp_20260417/question_a_step80.png) + +![](assets/exp_20260417/question_a_step320.png) + +- Step 80 reward: `total=0.3`, `format=1.0`, `accuracy=0.0`, `general_model=0.2`, `rule=0.1` +- Step 320 reward: `total=1.0`, `format=1.0`, `accuracy=1.0`, `general_model=0.2`, `rule=0.8` +- Interpretation: at step 80 the actor's answer was already very close, so the ORM gave 1.0, contributing 0.2 after the 0.2 coefficient; by step 320 the output was corrected from `38.97` to the rule answer `39.0`, and `accuracy_reward` jumped from `0.0` to `1.0`. + +#### Question B: Tangent Geometry `y` + +![](assets/exp_20260417/question_b_step80.png) + +![](assets/exp_20260417/question_b_step320.png) + +- Step 80 reward: `total=0.1`, `format=1.0`, `accuracy=0.0`, `general_model=0.0`, `rule=0.1` +- Step 320 reward: `total=1.0`, `format=1.0`, `accuracy=1.0`, `general_model=0.2`, `rule=0.8` +- Interpretation: at step 80 only the format was preserved, and neither accuracy nor the general RM gave any score; by step 320 both became positive contributions. + +## License + +This project is licensed under the Apache 2.0 License. See [LICENSE](../../LICENSE) for details. diff --git a/examples/orm_rl_demo/assets/exp_20260417/optimization_dashboard.png b/examples/orm_rl_demo/assets/exp_20260417/optimization_dashboard.png new file mode 100644 index 00000000..12dc06cf Binary files /dev/null and b/examples/orm_rl_demo/assets/exp_20260417/optimization_dashboard.png differ diff --git a/examples/orm_rl_demo/assets/exp_20260417/question_a_step320.png b/examples/orm_rl_demo/assets/exp_20260417/question_a_step320.png new file mode 100644 index 00000000..b4ffb426 Binary files /dev/null and b/examples/orm_rl_demo/assets/exp_20260417/question_a_step320.png differ diff --git a/examples/orm_rl_demo/assets/exp_20260417/question_a_step80.png 
b/examples/orm_rl_demo/assets/exp_20260417/question_a_step80.png new file mode 100644 index 00000000..edf199a9 Binary files /dev/null and b/examples/orm_rl_demo/assets/exp_20260417/question_a_step80.png differ diff --git a/examples/orm_rl_demo/assets/exp_20260417/question_b_step320.png b/examples/orm_rl_demo/assets/exp_20260417/question_b_step320.png new file mode 100644 index 00000000..5f1c83f7 Binary files /dev/null and b/examples/orm_rl_demo/assets/exp_20260417/question_b_step320.png differ diff --git a/examples/orm_rl_demo/assets/exp_20260417/question_b_step80.png b/examples/orm_rl_demo/assets/exp_20260417/question_b_step80.png new file mode 100644 index 00000000..5eba3e72 Binary files /dev/null and b/examples/orm_rl_demo/assets/exp_20260417/question_b_step80.png differ diff --git a/examples/orm_rl_demo/assets/exp_20260417/reward_dashboard.png b/examples/orm_rl_demo/assets/exp_20260417/reward_dashboard.png new file mode 100644 index 00000000..2b2d8e9a Binary files /dev/null and b/examples/orm_rl_demo/assets/exp_20260417/reward_dashboard.png differ diff --git a/examples/orm_rl_demo/assets/exp_20260417/summary_card.png b/examples/orm_rl_demo/assets/exp_20260417/summary_card.png new file mode 100644 index 00000000..43e69a06 Binary files /dev/null and b/examples/orm_rl_demo/assets/exp_20260417/summary_card.png differ diff --git a/examples/orm_rl_demo/reward_models.py b/examples/orm_rl_demo/reward_models.py new file mode 100755 index 00000000..dc671cf4 --- /dev/null +++ b/examples/orm_rl_demo/reward_models.py @@ -0,0 +1,1087 @@ +""" +General reward model helpers for the ORM RL Geo3K demo. + +This example keeps only the general outcome reward-model path that is exercised by +`examples/orm_rl_demo/train_colocate.py` and `examples/orm_rl_demo/test_reward_models.py`. +The shared helper functions below are used by that general reward model for both +HuggingFace and engine-based inference. 
+""" +from __future__ import annotations + +from typing import Optional, List, Tuple +import re +import copy +import os +import torch +import torch.nn as nn +import torch.distributed as dist +from transformers import LogitsProcessor +from itertools import zip_longest + +from lightrft.utils import get_current_device +from lightrft.strategy.utils.distributed_util import gather_inputs_object_for_inference +from lightrft.strategy import is_engine + + +# ============================================================================ +# Utility Functions +# ============================================================================ + +def is_chinese(text): + """ + Detect whether text contains Chinese characters. + + :param text: Text string to detect + :type text: str + :return: True if text contains Chinese characters, False otherwise + :rtype: bool + """ + if not isinstance(text, str): + return False + chinese_pattern = re.compile(r'[\u4e00-\u9fff]') + return bool(chinese_pattern.search(text)) + + +def _pack_engine_inputs( + prompts: list[str], + image_data: list[list] | None, +) -> tuple[list[str], list[list] | None]: + """ + Pack engine inputs ensuring prompts and image_data have consistent lengths. + + Returns None for image_data when all images are empty to skip redundant parameters. 
+ + :param prompts: List of text prompts + :type prompts: list[str] + :param image_data: List of image data, each element is a list of images + :type image_data: list[list] or None + :return: Processed (prompts, image_data or None) + :rtype: tuple + """ + if image_data is None: + return prompts, None + + fixed_prompts, fixed_images = [], [] + for p, imgs in zip(prompts, image_data): + if "<|image_pad|>" in p: + fixed_prompts.append(p) + fixed_images.append(imgs[:1] or [None]) # at least one placeholder + else: + fixed_prompts.append(p) + fixed_images.append([]) + + assert len(fixed_prompts) == len(fixed_images) + + if all(len(imgs) == 0 for imgs in fixed_images): + fixed_images = None + + return fixed_prompts, fixed_images + + +def _align_prompts_images( + prompts: list[str], + image_data: list[list] | None, +) -> tuple[list[str], list[int], list[str], list[list]]: + """ + Align prompts and images, separating text-only and multimodal data. + + Prompts containing ``<|image_pad|>`` must have at least one placeholder image; + prompts without placeholders must have no images. 
+ + :param prompts: List of text prompts + :type prompts: list[str] + :param image_data: List of image data (None if no images) + :type image_data: list[list] or None + :return: (text_prompts, text_indices, mm_prompts, mm_images) + :rtype: tuple + """ + if image_data is None: # No images passed at all: treat every prompt as text-only + # Keep the 4-tuple contract that the caller unpacks. + return list(prompts), list(range(len(prompts))), [], [] + text_prompts = [] + mm_prompts, mm_images = [], [] + text_inds = [] + + ind = 0 + for p, imgs in zip_longest(prompts, image_data, fillvalue=None): + if p is None: # Extra images → discard + continue + + imgs = [] if imgs is None else imgs # Ensure imgs is a list + if "<|image_pad|>" in p: # Must keep 1 placeholder + imgs = imgs[:1] or [None] + if isinstance(imgs[0], list): + imgs = imgs[0] + mm_images.append(imgs) + mm_prompts.append(p) + else: # Pure text prompt cannot have images + text_prompts.append(p) + text_inds.append(ind) + + ind += 1 + + return text_prompts, text_inds, mm_prompts, mm_images + + +def _hf_or_engine_generate( + model, + *, + input_ids : torch.Tensor | None = None, + attention_mask : torch.Tensor | None = None, + pixel_values : torch.Tensor | None = None, + image_grid_thw : torch.Tensor | None = None, + prompts : List[str] | None = None, + image_data : List[List] | None = None, + **gen_kwargs, +) -> Tuple[List[str], torch.Tensor | None]: + """ + Unified generation interface supporting both HuggingFace models and SGLang engines. + + Automatically detects model type. Engine mode uses string prompts and image_data; + HF mode uses tensor inputs (input_ids, pixel_values, etc.). 
+ + :param model: HF model or SGLang engine instance + :param input_ids: Input token IDs for HF mode + :type input_ids: torch.Tensor or None + :param attention_mask: Attention mask for HF mode + :type attention_mask: torch.Tensor or None + :param pixel_values: Image pixel values for HF mode + :type pixel_values: torch.Tensor or None + :param image_grid_thw: Image grid size for HF mode + :type image_grid_thw: torch.Tensor or None + :param prompts: Text prompts for Engine mode + :type prompts: list[str] or None + :param image_data: Image data for Engine mode + :type image_data: list[list] or None + :param gen_kwargs: Generation parameters (max_new_tokens, temperature, etc.) + :return: (list of generated texts, generated token IDs or None) + :rtype: tuple + """ + if is_engine(model): + assert input_ids is None, "Cannot pass input_ids in engine mode" + enable_sleep_mode = True + if hasattr(model, "llm_engine"): + enable_sleep_mode = model.llm_engine.vllm_config.model_config.enable_sleep_mode + + if enable_sleep_mode: + model.wake_up() + + if hasattr(model, "tp_group_cpu"): + # Drop the HF-only `do_sample` flag before passing kwargs to the engine. + sampling_params = {k: v for k, v in gen_kwargs.items() if k != "do_sample"} + + prompt_and_output = gather_inputs_object_for_inference(prompts, model.tp_group_cpu) + image_data = gather_inputs_object_for_inference(image_data, model.tp_group_cpu) + + text_prompts, text_inds, mm_prompts, mm_images = _align_prompts_images(prompt_and_output, image_data) + text_output = [] + mm_output = [] + + if len(text_prompts) > 0: + sgl_outputs = model.generate(prompt=text_prompts, sampling_params=sampling_params, gather_inputs=False) + text_output = [sgl_out["text"] for sgl_out in sgl_outputs] + + if len(mm_prompts) > 0: + sgl_outputs = model.generate( + prompt=mm_prompts, + image_data=mm_images, + sampling_params=sampling_params, + gather_inputs=False, + ) + mm_output = [sgl_out["text"] for sgl_out in sgl_outputs] + + texts = [] + text_output_iter = iter(text_output) + mm_output_iter = 
iter(mm_output) + # merge results in original order + if len(text_inds) > 0: + for i in range(len(prompt_and_output)): + if i in text_inds: + texts.append(next(text_output_iter)) + else: + texts.append(next(mm_output_iter)) + else: + texts = mm_output + + if model._tp_size > 1: + num_per_rank = len(texts) // model._tp_size + texts = texts[model._tp_rank * num_per_rank : (model._tp_rank + 1) * num_per_rank] + else: + from vllm import SamplingParams + + sampling_kwargs = dict(gen_kwargs) + max_tokens = sampling_kwargs.pop("max_new_tokens", None) + if max_tokens is not None: + sampling_kwargs["max_tokens"] = max_tokens + sampling_kwargs.pop("do_sample", None) + sampling_params = SamplingParams(**sampling_kwargs) + + prompt_and_output = prompts or [] + prompt_and_output, image_data = _pack_engine_inputs( + prompt_and_output, + image_data, + ) + if image_data is None: + image_data = [None] * len(prompt_and_output) + + vllm_prompts = [] + for prompt, imgs in zip(prompt_and_output, image_data): + prompt_item = {"prompt": prompt} + if imgs: + prompt_item["multi_modal_data"] = { + "image": imgs[0] if len(imgs) == 1 else imgs + } + vllm_prompts.append(prompt_item) + + vllm_outputs = model.generate( + vllm_prompts, + sampling_params=sampling_params, + use_tqdm=False, + ) + texts = [ + out.outputs[0].text if getattr(out, "outputs", None) else "" + for out in vllm_outputs + ] + + if dist.is_initialized() and dist.get_rank() == 0: + if not texts or all(not t for t in texts): + print("WARNING: _hf_or_engine_generate produced empty output for all prompts.") + + if enable_sleep_mode: + model.sleep() + torch.cuda.empty_cache() + return texts, None + + else: + gen_ids = model.generate( + input_ids = input_ids, + attention_mask = attention_mask, + pixel_values = pixel_values, + image_grid_thw = image_grid_thw, + **gen_kwargs, + ) + trim = [o[len(i):] for i, o in zip(input_ids, gen_ids)] + return trim, trim + + +# 
============================================================================ +# Vision Token Processing +# ============================================================================ + +_VISION_RE = re.compile(r"<\|vision_start\|>.*?<\|vision_end\|>", re.S) + +def _strip_vision_tokens(text: str) -> str: + """Remove vision token markers from text.""" + return re.sub(_VISION_RE, "", text).replace("<image>", "").strip() + + +def _clean_vision_token(text: str) -> str: + """ + Clean vision tokens from text, supporting multiple formats. + + Supported formats: + - <|vision_start|><|image_pad|>...<|vision_end|> + - <img><IMG_CONTEXT>...</img> + - <image> + """ + patterns = [ + r"<\|vision_start\|>(<\|image_pad\|>)+<\|vision_end\|>", + r"<img>(<IMG_CONTEXT>)+</img>", + r"<image>" + ] + for p in patterns: + text = re.sub(p, "", text) + return text + + +def _replace_vision_token(text: str) -> str: + """ + Replace vision tokens with standard markers. + + Conversion rules: + - <|vision_start|>...<|vision_end|> -> <image> + - <img><IMG_CONTEXT>...</img> -> <image> (internvl format) + """ + text = re.sub(r"<\|vision_start\|>(<\|image_pad\|>)+<\|vision_end\|>", "<image>", text) + text = re.sub(r"<img>(<IMG_CONTEXT>)+</img>", "<image>", text) # internvl + + return text + + +def _strip_pad_eos(text: str, pad: str, eos: str) -> str: + """ + Remove leading and trailing pad and eos tokens from text. 
+ + :param text: Text to process + :type text: str + :param pad: Pad token string + :type pad: str + :param eos: EOS token string + :type eos: str + :return: Cleaned text + :rtype: str + """ + pad, eos = map(re.escape, (pad, eos)) + text = re.sub(f"^({eos}|{pad})+", "", text) + text = re.sub(f"({eos}|{pad})+$", "", text) + return text + +# ============================================================================ +# Dialog Parsing Constants and Functions +# ============================================================================ + +# Define constants for vertical bars used in role tags for better readability +FULL_BAR = "｜" # U+FF5C Full-width vertical bar +HALF_BAR = "|" # U+007C ASCII vertical bar + +def _parse_dialog(text: str) -> dict: + """ + Parse a full conversation string into a dictionary mapping roles to their content. + + Identifies role tags like ``<| role_name |>`` and extracts the text that follows + each tag. If a role appears multiple times, only the last occurrence is kept. + + :param text: Conversation string with role tags + :type text: str + :return: Dict mapping role names to their message content + :rtype: dict + """ + # 1. Define the regex pattern to find all possible role tags. + # The pattern is written in verbose mode (re.X) for clarity. + tag_pattern = re.compile( + rf""" + < # Match the opening '<' + [{HALF_BAR}{FULL_BAR}] # Match either a half-width or full-width vertical bar + \s*? # Match any whitespace characters (non-greedy) + (.*?) # Capture the role name (non-greedy) + \s*? # Match any whitespace characters (non-greedy) + [{HALF_BAR}{FULL_BAR}] # Match either a half-width or full-width vertical bar + > # Match the closing '>' + """, re.X | re.S + ) + + # Find all occurrences of role tags in the text. + tags = list(tag_pattern.finditer(text)) + dialog = {} + + # 2. Iterate through the found tags to extract roles and content. 
+ for idx, tag in enumerate(tags): + # Extract the role name and normalize it by stripping whitespace and converting to lowercase. + raw_role = tag.group(1).strip() + role = raw_role.lower() + + # Skip special meta-tags that define structure but are not roles. + if role in {"im_start", "im_end", "begin of sentence", "end of sentence"}: + continue + + # Determine the start and end positions of the content for the current role. + # The content starts right after the current tag. + start_pos = tag.end() + # The content ends right before the next tag starts, or at the end of the text. + end_pos = tags[idx + 1].start() if idx + 1 < len(tags) else len(text) + content = text[start_pos:end_pos].strip() + + # 3. Special handling for the 'assistant' role to remove the chain-of-thought block. + # If the content contains <think>...</think>, we extract only the final response + # that appears after the last </think> tag. + if role == "assistant" and "<think>" in content and "</think>" in content: + think_end = content.rfind("</think>") + if think_end != -1: + content = content[think_end + len("</think>"):].strip() + + # Store the role and its content in the dictionary. + # If the role already exists, its value will be updated with the new content. + dialog[role] = content + + return dialog + +def preprocess_inputs_sglang( + prompt_and_outputs: list, + references: list, + question_response_format_zh: list | str, + question_response_format_en: str, + system_prompt_zh: str | None = None, + system_prompt_en: str | None = None, + system_prompt: bool = False, +): + """ + Preprocess batch conversation inputs for SGLang engine. + + Parses conversation text, selects a format template based on detected language, + and optionally prepends a system prompt. 
+ + :param prompt_and_outputs: List of conversation texts + :type prompt_and_outputs: list + :param references: List of reference answers + :type references: list + :param question_response_format_zh: Chinese format template (string or per-sample list) + :type question_response_format_zh: str or list + :param question_response_format_en: English format template + :type question_response_format_en: str + :param system_prompt_zh: Chinese system prompt + :type system_prompt_zh: str or None + :param system_prompt_en: English system prompt + :type system_prompt_en: str or None + :param system_prompt: Whether to prepend a system prompt + :type system_prompt: bool + :return: List of formatted texts ready for model input + :rtype: list + """ + raw_texts = [] + # Process each conversation in the batch. + for i, po in enumerate(prompt_and_outputs): + # Parse the conversation string into a role-content dictionary. + dialog = _parse_dialog(po) + + # --- Step 1: Extract the question --- + if "user" in dialog: + question_raw = dialog["user"] + else: + # Fallback logic: if 'user' role is not found, use the content from the + # first role that is not 'assistant'. If no such role exists, + # use the entire original string as the question. + question_raw = next( + (txt for role, txt in dialog.items() if role != "assistant"), po + ) + # Clean the extracted question (e.g., remove special vision tokens). + question = _clean_vision_token(question_raw) + + # --- Step 2: Extract the response --- + if "assistant" in dialog: + response = dialog["assistant"] + else: + # Fallback logic: if 'assistant' role is not found, assume the response + # is the text following the last </think> tag. + response = po.split("</think>")[-1].strip() + + reference = references[i] + + # --- Step 3: Select the appropriate formatting template --- + # Detect the question language with is_chinese (defined above). 
+ is_zh = is_chinese(question) + if isinstance(question_response_format_zh, list): + # New feature: Use a custom template for each item in the batch. + fmt = question_response_format_zh[i] + else: + # Old logic: Choose the template based on the detected language. + fmt = question_response_format_zh if is_zh else question_response_format_en + + # --- Step 4: Format the final input string --- + # The template may or may not include a placeholder for the reference text. + if "{reference}" in fmt: + raw_text = fmt.format( + question=question, + reference=reference, + response=response + ) + else: + raw_text = fmt.format(question=question, response=response) + + # --- Step 5: Prepend a system prompt if enabled --- + if system_prompt: + # Select the system prompt based on the language. + system_prompt_text = system_prompt_zh if is_zh else system_prompt_en + # Using deepcopy to avoid modifying the original system prompt object. + final_text = copy.deepcopy(system_prompt_text) + "\n" + raw_text + raw_texts.append(final_text) + else: + raw_texts.append(raw_text) + + return raw_texts + + +def build_general_engine_queries( + processor, + prompt_and_outputs: list, + references: list, + raw_images: list | None, + question_response_format_zh: str, + question_response_format_en: str, + system_prompt_zh: str, + system_prompt_en: str, +): + """ + Build general-RM engine prompts using the model's chat template. + + The vLLM engine path for Qwen2.5-VL is much more stable when we append the + assistant generation prompt explicitly. Without this, the engine often + returns empty strings or prompt-continuation fragments instead of verdicts. 
+ """ + test_data = [] + expected_image_counts = [] + normalized_image_data = [] + + if raw_images is None: + raw_images = [None] * len(prompt_and_outputs) + + for i, prompt_and_output in enumerate(prompt_and_outputs): + dialog = _parse_dialog(prompt_and_output) + + if "user" in dialog: + question_raw = dialog["user"] + else: + question_raw = next( + (txt for role, txt in dialog.items() if role != "assistant"), + prompt_and_output, + ) + + if "assistant" in dialog: + response = dialog["assistant"] + else: + response = prompt_and_output.split("</think>")[-1].strip() + + question = _clean_vision_token(question_raw) + reference = references[i] if references is not None and i < len(references) else "" + is_zh = is_chinese(question) + fmt = question_response_format_zh if is_zh else question_response_format_en + system_prompt = system_prompt_zh if is_zh else system_prompt_en + user_text = fmt.format(question=question, response=response, reference=reference) + + raw_image = raw_images[i] if i < len(raw_images) else None + has_image = raw_image is not None + expected_image_counts.append(1 if has_image else 0) + normalized_image_data.append([raw_image] if has_image else []) + + user_content = [{"type": "text", "text": user_text}] + if has_image: + user_content = [ + { + "type": "image", + "image": [], + "min_pixels": 224 * 224, + "max_pixels": 1280 * 1280, + }, + {"type": "text", "text": user_text}, + ] + + test_data.append( + [ + {"role": "system", "content": [{"type": "text", "text": system_prompt}]}, + {"role": "user", "content": user_content}, + ] + ) + + queries = processor.apply_chat_template( + test_data, + tokenize=False, + add_generation_prompt=True, + ) + if isinstance(queries, str): + queries = [queries] + + fixed_queries = [] + for query, expected_image_count in zip(queries, expected_image_counts): + query_image_token_count = query.count("<|image_pad|>") + if query_image_token_count > expected_image_count: + excess_tokens = query_image_token_count - expected_image_count 
+ query = query.replace("<|image_pad|>", "", excess_tokens) + fixed_queries.append(query) + + return fixed_queries, normalized_image_data + + +def preprocess_inputs( + tokenizer = None, + processor = None, + device = get_current_device(), + system_prompt: Optional[str] = None, + question_response_format: str = "", + input_ids: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pad_token: str = "", + eos_token: str = "<|endoftext|>", + clean_or_replace_vision_token: bool = False, + vision_token_process_type: str = 'clean', + padding_side: str = "left", + return_think_content: bool = False, + debug: bool = False, + queries: Optional[list] = None, + return_raw_texts: bool = False, +): + """ + Preprocess inputs for HuggingFace models. + + Supports building inputs from ``input_ids`` or ``queries``, optional vision-token + processing, and chain-of-thought content separation. + + :param tokenizer: HF tokenizer instance + :param processor: HF processor instance + :param device: Target device + :param system_prompt: System prompt (optional; use to distinguish value/knowledge from safety/normal data) + :type system_prompt: str or None + :param question_response_format: Q&A format template + :type question_response_format: str + :param input_ids: Input token IDs + :type input_ids: torch.Tensor or None + :param pixel_values: Image pixel values + :type pixel_values: torch.Tensor or None + :param pad_token: Padding token + :type pad_token: str + :param eos_token: End-of-sequence token + :type eos_token: str + :param clean_or_replace_vision_token: Whether to process vision tokens + :type clean_or_replace_vision_token: bool + :param vision_token_process_type: Processing method (``'clean'`` or ``'replace'``) + :type vision_token_process_type: str + :param padding_side: Padding direction + :type padding_side: str + :param return_think_content: Whether to separate chain-of-thought content + :type return_think_content: bool + :param debug: Debug mode + 
:type debug: bool + :param queries: List of query texts + :type queries: list or None + :param return_raw_texts: Whether to return raw texts instead of tensors + :type return_raw_texts: bool + :return: Standard mode returns ``(input_ids, attention_mask, response_empty)``; + CoT mode returns ``(answer_input_ids, answer_mask, think_input_ids, think_mask, valid_think, response_empty)``; + raw text mode returns ``(raw_texts, ...)``. + :rtype: tuple + """ + if input_ids is not None: + processor.tokenizer.padding_side = padding_side + queries = tokenizer.batch_decode(input_ids, skip_special_tokens=False) + else: + assert queries is not None + + for i, query in enumerate(queries): + if clean_or_replace_vision_token: + if vision_token_process_type == 'clean': # value, knowledge + queries[i] = _clean_vision_token(query) + elif vision_token_process_type == 'replace': # safety, normal + queries[i] = _replace_vision_token(query) + else: + raise KeyError(f"Invalid vision token process type: {vision_token_process_type}") + queries[i] = _strip_pad_eos(queries[i], pad_token, eos_token) + eos_token + + # Extract question and response from query using regex + pattern = r"<\|im_start\|>(\w+)\n(.*?)<\|im_end\|>" + # NOTE: parse dialog logic haven't adapt to deepseek model now + def _prepare_message(dialog, test_data, image_token_count_list): + question = dialog.get('user', '') + response = dialog.get('assistant', '') + image_token_count_list.append(question.count('<|image_pad|>')) + if system_prompt is not None: + test_data.append( + [ + {"role": "system", "content":[{"type": "text", "text": system_prompt}]}, + {"role": "user", "content": [{"type": "image", "image": [], "min_pixels": 224 * 224, "max_pixels": 1280 * 1280}, {"type": "text", "text": question_response_format.format(question=question, response=response)}]} + ] + ) + else: + test_data.append( + [ + {"role": "user", "content": [{"type": "text", "text": question_response_format.format(question=question, response=response)}]} 
+ ] + ) + if debug and dist.is_initialized() and dist.get_rank() == 0: + print(f"test_data:\n {test_data[0]}\n") + + # Process all queries in the batch at once + test_data, image_token_count_list = [], [] + think_test_data, think_image_token_count_list, valid_think = [], [], [] + response_empty = [] + for query in queries: + matches = re.findall(pattern, query, re.DOTALL) + dialog = {} + if return_think_content: + think_dialog = {} + valid_think_flag = False + for role, content in matches: + dialog[role] = content.strip() + if return_think_content: + think_dialog[role] = content.strip() + # If assistant's reply contains thinking chain content wrapped in <think> and </think>, extract only the content after </think> + if role == "assistant" and "<think>" in content and "</think>" in content: + # Find the position of the last </think> + think_end_pos = content.rfind("</think>") + if think_end_pos != -1: + # Extract content after </think> and remove leading/trailing whitespace + dialog[role] = content[think_end_pos + len("</think>"):].strip() + if return_think_content: + think_dialog[role] = content[:think_end_pos + len("</think>") + 1].strip() + valid_think_flag = True + + _prepare_message(dialog, test_data, image_token_count_list) + response_empty.append(dialog.get('assistant', '') == '') + if return_think_content: + valid_think.append(valid_think_flag) + _prepare_message(think_dialog, think_test_data, think_image_token_count_list) + + def _get_batch_input(test_data, image_token_count_list, return_raw_texts): + # Process the entire batch at once + if system_prompt is not None: + # Only apply chat template when system prompt is provided + queries = processor.apply_chat_template(test_data, tokenize=False, add_generation_prompt=False) + else: + # For data without system prompt, format directly without applying chat template + queries = [item[0]["content"][0]["text"] for item in test_data] + + # TODO: `apply_chat_template` adds an extra image token to the query, so we remove it here for now; a more elegant fix is needed + for i, query in 
enumerate(queries): + query_image_token_count = query.count('<|image_pad|>') + if query_image_token_count > image_token_count_list[i]: + # Replace all excess image tokens to match the expected count + excess_tokens = query_image_token_count - image_token_count_list[i] + queries[i] = query.replace('<|image_pad|>', '', excess_tokens) + + if not return_raw_texts: + with torch.no_grad(): + batch_inputs = processor( + text=queries, + padding=True, + return_tensors="pt", + ).to(device) + return batch_inputs + else: + return queries + + answer_batch_input = _get_batch_input(test_data, image_token_count_list, return_raw_texts) + if return_think_content: + think_batch_input = _get_batch_input(think_test_data, think_image_token_count_list, return_raw_texts) + if not return_raw_texts: + return answer_batch_input['input_ids'], answer_batch_input['attention_mask'], think_batch_input['input_ids'], think_batch_input['attention_mask'], valid_think, response_empty + else: + return answer_batch_input, think_batch_input, valid_think + else: + if not return_raw_texts: + return answer_batch_input['input_ids'], answer_batch_input['attention_mask'], response_empty + else: + return answer_batch_input + + + if engine._tp_size > 1: + num_per_rank = len(texts) // engine._tp_size + texts = texts[engine._tp_rank * num_per_rank : (engine._tp_rank+1) * num_per_rank] + + return texts + + +# ============================================================================ +# General Reward Model +# ============================================================================ + +class AllowedTokensLogitsProcessor(LogitsProcessor): + def __init__(self, allowed_token_ids): + self.allowed_token_ids = set(allowed_token_ids) + + def __call__(self, input_ids, scores): + # Set all non-allowed tokens to very negative values + mask = torch.ones_like(scores) * float('-inf') + for token_id in self.allowed_token_ids: + mask[:, token_id] = 0 + return scores + mask + + +class Qwen2VLRewardModelGeneral(nn.Module): + 
""" + General quality reward model that evaluates answer correctness based on reference answers. + + Scoring rules: + + - ``1.0``: Completely correct (all sub-questions correct) + - ``0.5``: Partially correct (at least one sub-question correct, but not all) + - ``0.0``: Incorrect (all sub-questions wrong or answer irrelevant) + + :param base_model: HF model or Engine instance + :param tokenizer: Tokenizer instance + :param processor: Processor instance + :param text_only: Whether to use text-only mode (no image inputs) + :type text_only: bool + """ + + general_scores = [0.0, 0.5, 1.0] + general_system_prompt_zh = """你是一个评分专家,负责根据参考答案reference评估assistant对user的回复是否正确且合理。 + **你将收到包含以下XML标签的内容:``表示用户的问题,``表示助手的回答,``表示参考答案。** + 请严格按以下规则输出固定稀疏奖励: + + 评估规则: + 1. 答案等价性: + - 简洁答案和带解题步骤的答案都接受,只要包含正确答案 + - 答案可能出现在回答的开头、中间或结尾 + - 只比较核心答案,忽略解释部分 + + 2. 数值等价性: + - 不同格式的数字视为等价(如2,"2",['2'],"答案是2") + - 百分比可以用小数或%表示(如28%=0.28) + - 带/不带逗号的数字视为等价(如123,456.7=123456.7) + + 3. 格式灵活性: + - 列表、引号、表格或纯文本中的正确答案都接受 + - 正确答案周围的额外解释或格式不影响评分 + - 大小写不敏感 + + 4. 多参考答案情况: + - 参考答案有多个可接受答案时,匹配一个即可视为该部分正确。 + + 5. 多子问题情况: + - 如果问题包含多个子问题,需要逐一评估assistant对每个子问题的回答。 + - 只有当所有子问题都回答正确时,总分才为 1.0。 + - 如果至少有一个子问题回答正确,但并非所有子问题都正确,则总分为 0.5。 + - 如果所有子问题都回答错误或回答与问题无关,则总分为 0.0。 + + 6. 容错性: + - 轻微拼写错误或措辞差异不影响评分 + - 等价数学表达式视为正确 + + 输出要求: + 1. **仅允许输出以下三个数值之一:0.0、0.5、1.0** + 2. 根据参考答案与回答的匹配程度选择: + - 完全正确 (所有子问题均正确) → 1.0 + - 部分正确 (至少答对一个子问题,但非全部) → 0.5 + - 错误 (所有子问题均错误或回答与问题无关) → 0.0 + 3. 直接输出数值,不需要任何解释""" + + question_response_format_zh = """请根据以下内容进行评估: + + + {question} + + + + + {response} + + + + {reference} + """ + + general_system_prompt_en = """You are a scoring expert responsible for evaluating whether the assistant's response to the user is correct and reasonable based on the reference answer. 
+ **You will receive content with the following XML tags: `` represents the user's question, `` represents the assistant's answer, and `` represents the reference answer.** + Please strictly output fixed sparse rewards according to the following rules: + + Evaluation Rules: + 1. Answer Equivalence: + - Both concise answers and answers with solution steps are accepted, as long as they contain the correct answer + - The answer may appear at the beginning, middle, or end of the response + - Only compare core answers, ignore explanation parts + + 2. Numerical Equivalence: + - Numbers in different formats are considered equivalent (e.g., 2, "2", ['2'], "the answer is 2") + - Percentages can be expressed as decimals or % (e.g., 28% = 0.28) + - Numbers with/without commas are equivalent (e.g., 123,456.7 = 123456.7) + + 3. Format Flexibility: + - Correct answers in lists, quotes, tables, or plain text are all accepted + - Additional explanations or formatting around the correct answer do not affect scoring + - Case insensitive + + 4. Multiple Reference Answers: + - When there are multiple acceptable reference answers, matching any one is considered correct for that part. + + 5. Multiple Sub-questions: + - If the question contains multiple sub-questions, evaluate the assistant's answer for each sub-question. + - Only when all sub-questions are answered correctly will the total score be 1.0. + - If at least one sub-question is answered correctly, but not all sub-questions are correct, the total score is 0.5. + - If all sub-questions are answered incorrectly or the answer is irrelevant to the question, the total score is 0.0. + + 6. Error Tolerance: + - Minor spelling errors or wording differences do not affect scoring + - Equivalent mathematical expressions are considered correct + + Output Requirements: + 1. **Only the following three values are allowed: 0.0, 0.5, 1.0** + 2. 
Choose based on the degree of match between the reference answer and the response: + - Completely correct (all sub-questions correct) → 1.0 + - Partially correct (at least one sub-question correct, but not all) → 0.5 + - Incorrect (all sub-questions incorrect or answer irrelevant to question) → 0.0 + 3. Output the value (0.0, 0.5, 1.0) directly, no explanation needed""" + + question_response_format_en = """Please evaluate based on the following content: + + + {question} + + + + + {response} + + + + {reference} + """ + + ALLOWED_STR_TOKENS = ["0", "1", "0.0", "0.5", "1.0"] + + def __init__(self, base_model, tokenizer, processor, text_only: bool = False): + super().__init__() + self.base_model: nn.Module = base_model + self.tokenizer = tokenizer + self.processor = processor + self.device = torch.cuda.current_device() + self.text_only = text_only + + self._allowed_token_seqs: list[list[int]] = [] + for s in self.ALLOWED_STR_TOKENS: + ids = self.tokenizer.encode(s, add_special_tokens=False) + self._allowed_token_seqs.append(ids) + + self._verdict_log_enabled = os.environ.get("ORM_RL_DEMO_RM_VERDICT_LOG", "0") == "1" + self._verdict_log_max = int(os.environ.get("ORM_RL_DEMO_RM_VERDICT_LOG_MAX", "128")) + self._verdict_log_count = 0 + + first_ids = {seq[0] for seq in self._allowed_token_seqs} + self._logits_proc = [AllowedTokensLogitsProcessor(first_ids)] + self._max_answer_len = max(len(x) for x in self._allowed_token_seqs) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + references: List[str] | None = None, + prompt_and_outputs=None, + prompt_and_output=None, + raw_images=None, + **kwargs, # for compatibility + ): + """ + Returns: {'score': FloatTensor[B]}, only in 0/0.5/1 + """ + if prompt_and_outputs is None: + prompt_and_outputs = prompt_and_output + if prompt_and_outputs is None: + prompt_and_outputs = 
kwargs.get("prompt_and_output") + if prompt_and_outputs is None: + raise ValueError("`prompt_and_outputs` or `prompt_and_output` is required") + + raw_texts = preprocess_inputs_sglang( + prompt_and_outputs, + references, + self.question_response_format_zh, + self.question_response_format_en, + self.general_system_prompt_zh, + self.general_system_prompt_en, + system_prompt=True, + ) + + if is_engine(self.base_model): + raw_texts, raw_images = build_general_engine_queries( + self.processor, + prompt_and_outputs, + references, + raw_images, + self.question_response_format_zh, + self.question_response_format_en, + self.general_system_prompt_zh, + self.general_system_prompt_en, + ) + gen_texts, _ = _hf_or_engine_generate( + self.base_model, + prompts=raw_texts, + image_data=raw_images, + max_new_tokens=4, + temperature=0.0, + ) + else: + model_in = self.processor( + text=raw_texts, padding=True, return_tensors="pt" + ).to(self.device) + _, gen_ids = _hf_or_engine_generate( + self.base_model, + input_ids=model_in["input_ids"], + attention_mask=model_in["attention_mask"], + pixel_values=None, + image_grid_thw=None, + max_new_tokens=self._max_answer_len, + temperature=0.0, + do_sample=False, + logits_processor=self._logits_proc, + ) + gen_texts = self.tokenizer.batch_decode( + gen_ids, skip_special_tokens=True + ) + + log_on_rank0 = (not dist.is_initialized()) or dist.get_rank() == 0 + + def _log_verdict_detail(tag: str, sample_idx: int, raw_text: str, **fields) -> None: + if not (self._verdict_log_enabled and log_on_rank0): + return + if self._verdict_log_count >= self._verdict_log_max: + return + raw_text = raw_text if isinstance(raw_text, str) else str(raw_text) + preview = " ".join(raw_text.split()) + if len(preview) > 200: + preview = preview[:200] + "..." 
+ extras = " ".join(f"{key}={value}" for key, value in fields.items()) + print( + f"[ORM_RM_GENERAL_VERDICT_{tag}] " + f"sample_idx={sample_idx} text_len={len(raw_text)} raw={preview!r} {extras}".rstrip(), + flush=True, + ) + self._verdict_log_count += 1 + + verdict_summary = { + "total": len(gen_texts), + "empty": 0, + "no_numeric": 0, + "value_error": 0, + "parsed": 0, + "parsed_0": 0, + "parsed_0_5": 0, + "parsed_1": 0, + } + + scores = [] + for sample_idx, txt in enumerate(gen_texts): + txt = txt if isinstance(txt, str) else str(txt) + if txt == "": + verdict_summary["empty"] += 1 + scores.append(0.0) + _log_verdict_detail("EMPTY", sample_idx, txt, fallback="0.0") + continue + + m = re.search(r"[-+]?\d*\.?\d+", txt) + if not m: + verdict_summary["no_numeric"] += 1 + scores.append(0.0) + _log_verdict_detail("NO_NUMERIC", sample_idx, txt, fallback="0.0") + continue + + matched_token = m.group() + try: + val = float(matched_token) + except ValueError as exc: + verdict_summary["value_error"] += 1 + scores.append(0.0) + _log_verdict_detail( + "VALUE_ERROR", + sample_idx, + txt, + token=repr(matched_token), + error=repr(exc), + fallback="0.0", + ) + continue + + nearest = min(self.general_scores, key=lambda x: abs(x - val)) + verdict_summary["parsed"] += 1 + if nearest == 0.0: + verdict_summary["parsed_0"] += 1 + elif nearest == 0.5: + verdict_summary["parsed_0_5"] += 1 + elif nearest == 1.0: + verdict_summary["parsed_1"] += 1 + _log_verdict_detail( + "PARSED", + sample_idx, + txt, + token=repr(matched_token), + parsed=repr(val), + snapped=repr(nearest), + ) + scores.append(nearest) + + if self._verdict_log_enabled and log_on_rank0: + print( + "[ORM_RM_GENERAL_VERDICT_SUMMARY] " + + " ".join(f"{key}={value}" for key, value in verdict_summary.items()), + flush=True, + ) + + return {"score": torch.tensor(scores, device=self.device)} diff --git a/examples/orm_rl_demo/reward_models_utils.py b/examples/orm_rl_demo/reward_models_utils.py new file mode 100755 index 
00000000..cacb3662 --- /dev/null +++ b/examples/orm_rl_demo/reward_models_utils.py @@ -0,0 +1,886 @@ +""" +General reward model utilities for the ORM RL Geo3K demo. + +This example intentionally keeps a single general reward-model path, plus the +Geo3K-specific reward mixing logic that combines format, general-model, and +accuracy rewards during training and evaluation. +""" +from __future__ import annotations + +import re +import os +import json +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence + +import torch +from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + +from lightrft.models.monkey_patch.hf_generate_patch import ( + apply_monkey_patch_to_generation_mixin, +) +from lightrft.utils import get_current_device + +from reward_models import Qwen2VLRewardModelGeneral + +# ============================================================================ +# Configuration Classes +# ============================================================================ + +class RewardModelType(str, Enum): + """Enumeration of reward model types supported by this demo.""" + GENERAL = "general" + + +@dataclass +class RewardModelConfig: + """ + Configuration for a single reward model. + + :param rtype: Reward model type. + :type rtype: RewardModelType + :param path: Model directory path or HuggingFace model name + :type path: str + :param use_engine: Whether to use SGLang engine instead of HuggingFace. 
Default to False + :type use_engine: bool + """ + rtype: RewardModelType + path : str + use_engine: bool = False + + +# ============================================================================ +# Model Builder Registry +# ============================================================================ + +_BUILDERS: Dict[RewardModelType, Callable] = {} + +def register_builder(rtype: RewardModelType) -> Callable: + """ + Decorator to register a builder function for a specific reward model type. + + Usage: + @register_builder(RewardModelType.GENERAL) + def build_general(cfg, strategy): + ... + + :param rtype: Reward model type to register builder for + :type rtype: RewardModelType + :return: Decorator function + :rtype: Callable + """ + def deco(fn: Callable) -> Callable: + _BUILDERS[rtype] = fn + return fn + return deco + + +RawRewardInput = Union[str, Dict[str, str], List[Dict[str, str]], None] + + +def extract_response(text: str) -> str: + """ + Extract assistant completion from a full chat transcript. + + Running format/accuracy checks on the full decoded sequence can produce + false positives because the system prompt itself contains formatting + examples such as ``...`` and ``\\boxed{}``. + """ + if not isinstance(text, str): + return "" + + s = text.strip() + if not s: + return s + + assistant_marker = "<|im_start|>assistant" + if assistant_marker in s: + start = s.rfind(assistant_marker) + len(assistant_marker) + tail = s[start:] + end_idx = tail.find("<|im_end|>") + if end_idx != -1: + tail = tail[:end_idx] + return tail.strip() + return s + + +# ============================================================================ +# Configuration Parsing +# ============================================================================ + +def _guess_rtype_from_path(path: str) -> RewardModelType: + """ + Infer reward model type from path string. 
+ + :param path: Model path or name + :type path: str + :return: Inferred reward type + :rtype: RewardModelType + """ + return RewardModelType.GENERAL + +def parse_reward_pretrain( + raw: RawRewardInput, + *, + global_use_engine: bool +) -> Tuple[List[RewardModelConfig], Dict[str, int]]: + """ + Parse reward model configuration from various input formats. + + Supported formats: + 1. JSON: '{"general":"/path/to/rm"}' + 2. CSV: 'general:/path/to/rm' + 3. Plain path: '/path/to/rm' (treated as the general reward model) + 4. Dict/List: {'type':'general','path':'/path/to/rm'} or [{'type':'general','path':'/path/to/rm'}] + + Extra feature: Append ?engine=true to path to override global engine setting + Example: 'general:/path/to/model?engine=true' + + :param raw: Raw configuration input (string, dict, list, or None) + :type raw: RawRewardInput + :param global_use_engine: Global flag for whether to use engine mode + :type global_use_engine: bool + :return: Tuple of (cfgs, label_map) where cfgs is a list of RewardModelConfig objects + and label_map is a dict mapping reward type to index {str: int} + :rtype: Tuple[List[RewardModelConfig], Dict[str, int]] + :raises TypeError: If raw input format is not supported + + Note: + The demo only supports RewardModelType.GENERAL. + """ + if raw is None: raw = "" + + # ---------- 1. 
Convert string to unified list[(key,path,flag)] ---------- + pair_list: List[Tuple[str, str, Optional[bool]]] = [] + if isinstance(raw, str): + s = raw.strip().lstrip("{").rstrip("}") + # ① JSON + if raw.strip().startswith("{") and raw.strip().endswith("}"): + try: + obj = json.loads(raw) + pair_list = [(k, v, None) for k, v in obj.items()] + except json.JSONDecodeError: + pass + if not pair_list: + # ② kv/comma-separated string + for seg in re.split(r"\s*,\s*", s): + if not seg: continue + if ":" in seg: + k, v = seg.split(":", 1) + pair_list.append((k.strip(), v.strip(), None)) + else: # pure path + pair_list.append(("?", seg.strip(), None)) + elif isinstance(raw, dict): + pair_list = [(k, v, None) for k, v in raw.items()] + elif isinstance(raw, list): + for d in raw: + pair_list.append((d["type"], d["path"], d.get("engine"))) + else: + raise TypeError("Unsupported --reward_pretrain format") + + # ---------- 2. Generate cfg list ---------- + cfgs: List[RewardModelConfig] = [] + for key, path, flag in pair_list: + # Parse path?engine=true/false + use_engine = global_use_engine + if "?engine=" in path: + path, qs = path.split("?engine=", 1) + use_engine = qs.lower() in ("1", "true", "yes") + if flag is not None: + use_engine = flag + if key == "?": + rtype = _guess_rtype_from_path(path) + else: + try: + rtype = RewardModelType(key) + except ValueError as exc: + raise ValueError( + "examples/orm_rl_demo only supports the general reward model. " + f"Got reward type: {key}" + ) from exc + if rtype is not RewardModelType.GENERAL: + raise ValueError( + "examples/orm_rl_demo only supports the general reward model. 
" + f"Got reward type: {rtype.value}" + ) + cfgs.append(RewardModelConfig(rtype, path, use_engine)) + + # Ensure label_map order is stable and contains general + uniq: List[RewardModelType] = [] + for c in cfgs: + if c.rtype not in uniq: uniq.append(c.rtype) + if RewardModelType.GENERAL not in uniq: + uniq.append(RewardModelType.GENERAL) + label_map = {rt.value: i for i, rt in enumerate(uniq)} + return cfgs, label_map + + +def _normalized_geo3k_general_weights() -> Tuple[float, float, float]: + """ + Return normalized (format, model, accuracy) weights for the Geo3K+ORM mix. + """ + format_w = float(os.environ.get("ORM_RL_DEMO_GEO3K_FORMAT_WEIGHT", "0.1")) + model_w = float(os.environ.get("ORM_RL_DEMO_GEO3K_MODEL_WEIGHT", "0.2")) + accuracy_w = float(os.environ.get("ORM_RL_DEMO_GEO3K_ACCURACY_WEIGHT", "0.7")) + total = format_w + model_w + accuracy_w + if total <= 0: + return 0.1, 0.2, 0.7 + return format_w / total, model_w / total, accuracy_w / total + + +def _infer_rm_engine_tp_size(pretrain_path: str) -> int: + override = os.environ.get("ORM_RL_DEMO_RM_ENGINE_TP") + if override: + return int(override) + + path = pretrain_path.lower() + if "72b" in path: + return 8 + if "32b" in path: + return 4 + if "14b" in path: + return 2 + if "8b" in path or "7b" in path: + return 1 + return 1 + + +def _rm_engine_mem_util() -> float: + return float(os.environ.get("ORM_RL_DEMO_RM_ENGINE_MEM_UTIL", "0.15")) + + +# ============================================================================ +# Model Loading Functions +# ============================================================================ + +def _load_hf_model( + pretrain_path: str, + device: torch.device +) -> Tuple[Qwen2_5_VLForConditionalGeneration, Any]: + """ + Load HuggingFace model and processor. 
+ + :param pretrain_path: Model path or HuggingFace model name + :type pretrain_path: str + :param device: Target device + :type device: torch.device + :return: Tuple of (base_model, processor) + :rtype: Tuple[Qwen2_5_VLForConditionalGeneration, Any] + """ + base = Qwen2_5_VLForConditionalGeneration.from_pretrained( + pretrain_path, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + ) + processor = AutoProcessor.from_pretrained( + pretrain_path, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28 + ) + processor.tokenizer.padding_side = "left" + return base, processor + + +def _load_engine( + pretrain_path: str, + device: torch.device +) -> Tuple[Any, Any]: + """ + Load reward-model inference engine and processor. + + Automatically determines tensor parallelism size based on reward model type: + - 7B/8B models default to tp_size = 1 + - larger models scale tp_size by model size + + :param pretrain_path: Model path or HuggingFace model name + :type pretrain_path: str + :param device: Target device + :type device: torch.device + :return: Tuple of (engine, processor) + :rtype: Tuple[Any, Any] + + Note: + Engine is set to sleep mode after loading to save memory. 
+ Backend selection is controlled by ORM_RL_DEMO_RM_ENGINE_BACKEND: + - sglang: use SGLang only + - vllm: use vLLM only + - auto: try SGLang first, then fall back to vLLM + """ + tp_size = _infer_rm_engine_tp_size(pretrain_path) + engine_mem_util = _rm_engine_mem_util() + backend = os.environ.get("ORM_RL_DEMO_RM_ENGINE_BACKEND", "auto").strip().lower() + + print( + f"[reward_models_utils] Loading engine from {pretrain_path} " + f"with tp_size={tp_size}, engine_mem_util={engine_mem_util}, backend={backend}" + ) + + engine = None + + if backend in ("auto", "sglang"): + try: + from lightrft.strategy.sglang_utils import get_sglang_engine + + engine = get_sglang_engine( + pretrain_path, + engine_mem_util=engine_mem_util, + tp_size=tp_size, + skip_tokenizer_init=False, + disable_cuda_graph=True, # only for deepseek, TODO: why deepseek pipeline (examples/orm_rl_demo/run_fsdp_deepseek.sh) need this? + ) + print(f"[reward_models_utils] Loaded SGLang engine from {pretrain_path} with tp_size={tp_size}") + except Exception as exc: + if backend == "sglang": + raise + print( + f"[reward_models_utils] SGLang engine init failed for {pretrain_path}: " + f"{type(exc).__name__}: {exc}. Falling back to vLLM." 
+ ) + + if engine is None: + from lightrft.strategy.vllm_utils import get_vllm_engine + + max_model_len = int(os.environ.get("ORM_RL_DEMO_RM_ENGINE_MAX_MODEL_LEN", "4096")) + engine = get_vllm_engine( + pretrain_path, + dtype="bfloat16", + tp_size=tp_size, + mem_util=engine_mem_util, + max_model_len=max_model_len, + enable_sleep=False, + limit_mm_per_prompt={"image": 1}, + ) + print( + f"[reward_models_utils] Loaded vLLM engine from {pretrain_path} " + f"with tp_size={tp_size}, max_model_len={max_model_len}, enable_sleep=False" + ) + + if not hasattr(engine, "llm_engine") or engine.llm_engine.vllm_config.model_config.enable_sleep_mode: + engine.sleep() # Sleep to save memory + + processor = AutoProcessor.from_pretrained( + pretrain_path, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28 + ) + processor.tokenizer.padding_side = "left" + return engine, processor + + +# ============================================================================ +# Model Builders for Each Reward Type +# ============================================================================ + +@register_builder(RewardModelType.GENERAL) +def build_general( + cfg: RewardModelConfig, + strategy: Any, + base: Optional[Tuple[Any, Any]] = None +) -> Tuple[Qwen2VLRewardModelGeneral, Any]: + """ + Build General quality reward model. + + :param cfg: Reward model configuration + :type cfg: RewardModelConfig + :param strategy: Training strategy instance + :type strategy: Any + :param base: Optional shared base model (engine, processor) tuple. 
Default to None + :type base: Optional[Tuple[Any, Any]] + :return: Tuple of (model, tokenizer) + :rtype: Tuple[Qwen2VLRewardModelGeneral, Any] + """ + if cfg.use_engine: + if base: + engine, proc = base + else: + engine, proc = _load_engine(cfg.path, get_current_device()) + model = Qwen2VLRewardModelGeneral(engine, proc.tokenizer, proc, text_only=strategy.args.text_only) + return model, proc.tokenizer + else: + base_model, proc = _load_hf_model(cfg.path, get_current_device()) + model = Qwen2VLRewardModelGeneral(base_model, proc.tokenizer, proc, text_only=strategy.args.text_only) + model.eval() + return model, proc.tokenizer + +# ============================================================================ +# Main Initialization Entry Point +# ============================================================================ + +def load_reward_models( + raw_reward_pretrain: RawRewardInput, + strategy: Any, + use_engine: bool = False, +) -> Tuple[List[Any], List[Any], Dict[str, int]]: + """ + Load and initialize all reward models from configuration. + + This is the main entry point for loading reward models. It handles: + - Configuration parsing + - Base model sharing (to save memory) + - Model initialization with proper context + - Monkey patching for HuggingFace generation + + :param raw_reward_pretrain: Raw configuration (see parse_reward_pretrain) + :type raw_reward_pretrain: RawRewardInput + :param strategy: Training strategy instance + :type strategy: Any + :param use_engine: Global flag for using SGLang engine. Default to False + :type use_engine: bool + :return: Tuple of (reward_models, reward_tokenizers, label_map) where + reward_models is a list of initialized reward model instances, + reward_tokenizers is a list of corresponding tokenizers, + and label_map is a dict mapping reward type to index + :rtype: Tuple[List[Any], List[Any], Dict[str, int]] + + Note: + Models sharing the same base path will reuse the same loaded base model + to reduce memory footprint. 
+ """ + apply_monkey_patch_to_generation_mixin() + + cfgs, label_map = parse_reward_pretrain( + raw_reward_pretrain, global_use_engine=use_engine + ) + + rms: List[Any] = [] + toks: List[Any] = [] + + # Share base models across reward models to save memory + # Since some reward models can share the same base model, we only load it once + shared_bases: Dict[Tuple[str, bool], Tuple[Any, Any]] = {} + shared_count: Dict[Tuple[str, bool], int] = {} + for cfg in cfgs: + cache_key = (cfg.path, cfg.use_engine) + if cache_key not in shared_count: + shared_count[cache_key] = 1 + else: + shared_count[cache_key] += 1 + + if shared_count[cache_key] == 1: + loader = _load_engine if cfg.use_engine else _load_hf_model + shared_bases[cache_key] = loader(cfg.path, get_current_device()) + strategy.print(f"Init reward model {cfg.path} (engine={cfg.use_engine})") + else: + strategy.print(f"Use shared base model {cfg.path}") + + for cfg in cfgs: + if cfg.rtype not in _BUILDERS: + raise RuntimeError(f"No builder for {cfg.rtype}") + strategy.print(f"Loading {cfg.rtype} from {cfg.path} (engine={cfg.use_engine})") + + # Initialize model with proper context (supports FSDP/meta device init) + with strategy.init_model_context() as _: + # All reward types now support shared base models + rm, tok = _BUILDERS[cfg.rtype](cfg, strategy, base=shared_bases.get((cfg.path, cfg.use_engine))) + + rms.append(rm) + toks.append(tok) + strategy.print(f"Loaded {cfg.rtype}") + + return rms, toks, label_map + + + +# ============================================================================ +# Reward Functions +# ============================================================================ + +def format_reward_fn(sol: str) -> float: + """ + Check if solution matches format: <think>...</think> + non-empty content. 
+ + :param sol: Solution string to check + :type sol: str + :return: 1.0 if format is valid, 0.0 otherwise + :rtype: float + """ + return 1.0 if re.match(r".*<think>.+?</think>\s*\S+", sol, re.DOTALL) else 0.0 + + +def rule_reward_fn(sol: str, gt: str) -> float: + """ + Extract content after </think> and verify against ground truth using mathruler. + + :param sol: Solution string (may contain <think>...</think>) + :type sol: str + :param gt: Ground truth answer + :type gt: str + :return: 1.0 if correct, 0.0 otherwise + :rtype: float + """ + from mathruler.grader import extract_boxed_content, grade_answer + ans = sol.split("</think>")[-1] + pred = extract_boxed_content(ans) + if pred == gt or grade_answer(pred, gt): + return 1.0 + return 0.0 + +# ============================================================================ +# Reward Recipe Configuration +# ============================================================================ + +# Original reward recipe for SVKG dataset training (after KG dataset training) + +def geo3k_accuracy_reward_fn(sol: str, gt: str) -> float: + """ + Geo3K accuracy reward function. + + Extract answer from \\boxed{} notation and use mathruler to verify correctness. + This is based on the verl implementation for geo3k dataset. + + :param sol: Solution string from model (should contain \\boxed{answer}) + :type sol: str + :param gt: Ground truth answer + :type gt: str + :return: 1.0 if answer is correct, 0.0 otherwise + :rtype: float + """ + from mathruler.grader import extract_boxed_content, grade_answer + pred = extract_boxed_content(sol) + return 1.0 if grade_answer(pred, gt) else 0.0 + + +def geo3k_format_reward_fn(sol: str) -> float: + """ + Geo3K format reward function. + + Check if the solution follows the required format: + - Contains <think>...</think> tags for reasoning + - Contains \\boxed{} for final answer + - The think tags must appear BEFORE the boxed answer + + This is based on the verl implementation for geo3k dataset. 
+ + :param sol: Solution string from model + :type sol: str + :return: 1.0 if format is correct, 0.0 otherwise + :rtype: float + """ + # Strip leading/trailing whitespace for robust matching + sol_stripped = sol.strip() + + # Check if solution contains both <think>...</think> and \boxed{...} + # Use re.search to find positions + think_match = re.search(r'<think>.*?</think>', sol_stripped, re.DOTALL) + boxed_match = re.search(r'\\boxed\{.*?\}', sol_stripped, re.DOTALL) + + # Both components must be present AND think must come before boxed + if think_match and boxed_match: + # Check that </think> comes before \boxed + think_end = think_match.end() + boxed_start = boxed_match.start() + return 1.0 if think_end <= boxed_start else 0.0 + else: + return 0.0 + + +def geo3k_combined_reward_fn( + sol: str, + gt: str, + format_weight: float = 0.1 +) -> float: + """ + Geo3K combined reward function. + + Combines format reward and accuracy reward with specified weights. + Default: 90% accuracy + 10% format (matching verl implementation) + + :param sol: Solution string from model + :type sol: str + :param gt: Ground truth answer + :type gt: str + :param format_weight: Weight for format reward. Default to 0.1 + :type format_weight: float + :return: Weighted combination of format and accuracy rewards + :rtype: float + """ + acc_reward = geo3k_accuracy_reward_fn(sol, gt) + fmt_reward = geo3k_format_reward_fn(sol) + return (1.0 - format_weight) * acc_reward + format_weight * fmt_reward + + +def gsm8k_accuracy_reward_fn(sol: str, gt: str) -> float: + """ + GSM8K accuracy reward function. + + Extract answer from \\boxed{} notation and use mathruler to verify correctness. + This follows the same pattern as geo3k but for GSM8K dataset. 
+ + :param sol: Solution string from model (should contain \boxed{answer}) + :type sol: str + :param gt: Ground truth answer + :type gt: str + :return: 1.0 if answer is correct, 0.0 otherwise + :rtype: float + """ + from mathruler.grader import extract_boxed_content, grade_answer + pred = extract_boxed_content(sol) + return 1.0 if grade_answer(pred, gt) else 0.0 + + +def gsm8k_format_reward_fn(sol: str) -> float: + r""" + GSM8K format reward function. + + Check if the solution follows the required format: + - Contains <think>...</think> tags for reasoning + - Contains \boxed{} for final answer + - The think tags must appear BEFORE the boxed answer + + This follows the same pattern as geo3k format checking. + + :param sol: Solution string from model + :type sol: str + :return: 1.0 if format is correct, 0.0 otherwise + :rtype: float + """ + # Strip leading/trailing whitespace for robust matching + sol_stripped = sol.strip() + + # Check if solution contains both <think>...</think> and \boxed{...} + # Use re.search to find positions + think_match = re.search(r'<think>.*?</think>', sol_stripped, re.DOTALL) + boxed_match = re.search(r'\\boxed\{.*?\}', sol_stripped, re.DOTALL) + + # Both components must be present AND think must come before boxed + if think_match and boxed_match: + # Check that </think> comes before \boxed + think_end = think_match.end() + boxed_start = boxed_match.start() + return 1.0 if think_end <= boxed_start else 0.0 + else: + return 0.0 + + +def gsm8k_combined_reward_fn( + sol: str, + gt: str, + format_weight: float = 0.1 +) -> float: + """ + GSM8K combined reward function. + + Combines format reward and accuracy reward with specified weights. + Default: 90% accuracy + 10% format (matching verl and geo3k implementation) + + :param sol: Solution string from model + :type sol: str + :param gt: Ground truth answer + :type gt: str + :param format_weight: Weight for format reward.
Default to 0.1 + :type format_weight: float + :return: Weighted combination of format and accuracy rewards + :rtype: float + """ + acc_reward = gsm8k_accuracy_reward_fn(sol, gt) + fmt_reward = gsm8k_format_reward_fn(sol) + return (1.0 - format_weight) * acc_reward + format_weight * fmt_reward + +RECIPE: Dict[str, List[Tuple[str, Optional[str], float]]] = { + "general": [("model", "general", 1.0)], +} + + +def mix_rewards( + labels: Sequence[str], + model_scores: torch.Tensor, + label_map: Dict[str, int], + solution_strs: Sequence[str], + refs: Sequence[str], +) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Mix rewards from multiple sources according to recipe configuration. + + This function combines: + 1. Format reward (always applied) + 2. Model-based rewards (from neural reward models) + 3. Rule-based rewards (from heuristic functions) + + :param labels: List of data labels (length B) + :type labels: Sequence[str] + :param model_scores: Tensor of model scores, shape (n_model, B) + :type model_scores: torch.Tensor + :param label_map: Mapping from reward type to model index + :type label_map: Dict[str, int] + :param solution_strs: List of solution strings (length B) + :type solution_strs: Sequence[str] + :param refs: List of reference answers (length B) + :type refs: Sequence[str] + :return: Tuple of (final_reward, metrics_dict) where final_reward is tensor of shape (B,) + containing combined rewards and metrics_dict contains detailed reward metrics + :rtype: Tuple[torch.Tensor, Dict[str, torch.Tensor]] + + Error handling: + - If a model is not loaded or index out of bounds, returns 1.0 with warning + - If label not in RECIPE, returns 1.0 with warning + - Never raises IndexError, always returns valid reward + + Note: + Format reward is always computed first, then rewards from recipe are added + """ + if torch.distributed.get_rank() == 0: + print(f"labels:{labels}, model_scores:{model_scores.tolist()}") + device = model_scores.device + n_model, B = 
model_scores.shape[0], len(labels) + assert model_scores.shape[1] == B, "model_scores second dimension must equal batch size" + + final_reward = torch.zeros(B, dtype=torch.float32, device=device) + + # Initialize metrics dict to track individual reward components + metrics_dict: Dict[str, torch.Tensor] = { + 'format_reward': torch.zeros(B, dtype=torch.float32, device=device), + 'accuracy_reward': torch.zeros(B, dtype=torch.float32, device=device), + 'general_model_reward': torch.zeros(B, dtype=torch.float32, device=device), + 'rule_reward': torch.zeros(B, dtype=torch.float32, device=device), + } + + # ---------- Fallback scoring function ---------- + def get_model_reward(key: str, i: int) -> float: + """ + Try to return model score for <key>, return 1.0 on failure. + + :param key: Reward model type key + :type key: str + :param i: Sample index + :type i: int + :return: Model score or 1.0 if not available + :rtype: float + """ + if key not in label_map: + print(f"Model reward <{key}> not loaded, using 1 as default reward") + return 1.0 + + idx = label_map[key] + if idx >= n_model: + print(f"Model reward <{key}> index {idx} out of bounds " + f"(n_model={n_model}), using 1 as default reward") + return 1.0 + + return float(model_scores[idx, i].item()) + + # ---------- Main loop ---------- + geo3k_fmt_w, geo3k_model_w, geo3k_acc_w = _normalized_geo3k_general_weights() + + for i, lab in enumerate(labels): + sol = solution_strs[i] + sol_completion = extract_response(sol) + gt = refs[i] if i < len(refs) else "" + + if lab == "geo3k_general": + fmt_r = geo3k_format_reward_fn(sol_completion) + acc_r = geo3k_accuracy_reward_fn(sol_completion, gt) + model_score = get_model_reward("general", i) + + final_reward[i] = ( + geo3k_fmt_w * fmt_r + + geo3k_model_w * model_score + + geo3k_acc_w * acc_r + ) + metrics_dict['format_reward'][i] = fmt_r + metrics_dict['accuracy_reward'][i] = acc_r + metrics_dict['general_model_reward'][i] = model_score + rule_w_total = geo3k_fmt_w + geo3k_acc_w
+ metrics_dict['rule_reward'][i] = (geo3k_fmt_w * fmt_r + geo3k_acc_w * acc_r) / rule_w_total + continue + + # 1) format reward (always present) + r = format_reward_fn(sol_completion) + # Track separately + metrics_dict['format_reward'][i] = r + + # 2) accumulate according to recipe + recipe = RECIPE.get(lab) + if recipe is None: + print(f"label <{lab}> not registered in RECIPE, giving 1 reward directly") + recipe = [] # or raise + + for typ, key, w in recipe: + if typ == "model": + model_r = w * get_model_reward(key, i) + r += model_r + metrics_dict['general_model_reward'][i] += model_r + + elif typ == "rule": + rule_r = w * rule_reward_fn(sol_completion, gt) + r += rule_r + metrics_dict['rule_reward'][i] += rule_r + metrics_dict['accuracy_reward'][i] = rule_r + + elif typ == "geo3k_rule": + r = 0 # TODO: geo3k have own format reward + # Track separately + metrics_dict['accuracy_reward'][i] = 0 + metrics_dict['format_reward'][i] = 0 + # Geo3K pure rule-based reward (format + accuracy) + # Get individual components + acc_r = geo3k_accuracy_reward_fn(sol_completion, gt) + fmt_r = geo3k_format_reward_fn(sol_completion) + combined_r = (1.0 - 0.1) * acc_r + 0.1 * fmt_r + r += w * combined_r + # Track separately + metrics_dict['accuracy_reward'][i] = acc_r + metrics_dict['format_reward'][i] = fmt_r + elif typ == "gsm8k_rule": + r = 0 # TODO: gsm8k have own format reward + # Track separately + metrics_dict['accuracy_reward'][i] = 0 + metrics_dict['format_reward'][i] = 0 + # GSM8K pure rule-based reward (format + accuracy) + # Get individual components + acc_r = gsm8k_accuracy_reward_fn(sol_completion, gt) + fmt_r = gsm8k_format_reward_fn(sol_completion) + combined_r = (1.0 - 0.1) * acc_r + 0.1 * fmt_r + r += w * combined_r + # Track separately + metrics_dict['accuracy_reward'][i] = acc_r + metrics_dict['format_reward'][i] = fmt_r + else: + print(f"Unknown component type {typ}, ignoring") + + final_reward[i] = r + + return final_reward, metrics_dict + + +def reward_fn( + 
model_reward_list: List[torch.Tensor], + labels: Sequence[str], + queries: Sequence[str], + refs: Sequence[str], + label_map: Dict[str, int], +) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + External unified interface for computing final rewards. + + This is the main entry point called by the trainer. It: + 1. Stacks individual model rewards into a single tensor + 2. Calls mix_rewards to combine all reward sources + 3. Returns final reward tensor + + :param model_reward_list: List of reward tensors from each model, each shape (B,) + :type model_reward_list: List[torch.Tensor] + :param labels: List of data labels indicating reward type (length B) + :type labels: Sequence[str] + :param queries: List of query/solution strings (length B) + :type queries: Sequence[str] + :param refs: List of reference answers (length B) + :type refs: Sequence[str] + :param label_map: Mapping from reward type to model index + :type label_map: Dict[str, int] + :return: Tuple of (final_reward, metrics_dict) where final_reward is combined reward tensor + of shape (B,) and metrics_dict contains detailed reward metrics + :rtype: Tuple[torch.Tensor, Dict[str, torch.Tensor]] + + Note: + If model_reward_list is empty (no NN models), a placeholder zero tensor is created + """ + # print(f"model_reward_list:{model_reward_list}, labels:{labels}, queries:{queries}, refs:{refs}, label_map:{label_map}") + # print(f"label_map:{label_map}") + + # ------ stack to (n_model, B) ------ + if model_reward_list: + model_scores = torch.stack(model_reward_list) # (n_model, B) + else: + # When no torch.nn model RM is available, give placeholder zero score + B = len(labels) + model_scores = torch.zeros(0, B, dtype=torch.float32, device="cuda") + + # ------ call combination logic ------ + return mix_rewards(labels, model_scores, label_map, queries, refs) diff --git a/examples/orm_rl_demo/run_general_fsdp_qwenvl.sh b/examples/orm_rl_demo/run_general_fsdp_qwenvl.sh new file mode 100755 index 
00000000..3dfc0206 --- /dev/null +++ b/examples/orm_rl_demo/run_general_fsdp_qwenvl.sh @@ -0,0 +1,153 @@ +#!/bin/bash + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" + +NAME="${NAME:-orm-rl-demo-general-geo3k}" +N_SAMPLES=8 +EPISODE="${EPISODE:-20}" +WARMUP=0.03 +RBS=128 +TBS=128 +KL=0.001 +LR=1e-6 + +PROMPT_MAX_LEN=1024 +GENERATE_MAX_LEN=2048 +EVAL_SPLIT="${EVAL_SPLIT:-test}" +MAX_EVAL_SAMPLES="${MAX_EVAL_SAMPLES:-700}" +limit_mm_image_per_prompt=1 +ENGINE_TP=1 + +export IGNORE_EOS=0 + +DEFAULT_DATA_PATH="/path/to/geo3k" +DEFAULT_PRETRAIN_PATH="/path/to/Qwen2.5-VL-7B-Instruct" +DEFAULT_REWARD_PRETRAIN_PATHS='{"general":"/path/to/general-reward-model"}' + +DATA_PATH="${DATA_PATH:-${DEFAULT_DATA_PATH}}" +PRETRAIN_PATH="${PRETRAIN_PATH:-${DEFAULT_PRETRAIN_PATH}}" +REWARD_PRETRAIN_PATHS="${REWARD_PRETRAIN_PATHS:-${DEFAULT_REWARD_PRETRAIN_PATHS}}" +LABEL_OVERRIDE="${LABEL_OVERRIDE:-geo3k_general}" +USE_RM_ENGINE="${USE_RM_ENGINE:-1}" + +current_time=$(date +"%m%d%H%M") + +cd "${REPO_ROOT}" + +export PYTHONPATH="${REPO_ROOT}:${PYTHONPATH:-}" + +if [ "${DATA_PATH}" = "${DEFAULT_DATA_PATH}" ] || [ "${PRETRAIN_PATH}" = "${DEFAULT_PRETRAIN_PATH}" ] || [ "${REWARD_PRETRAIN_PATHS}" = "${DEFAULT_REWARD_PRETRAIN_PATHS}" ]; then + echo "Set DATA_PATH, PRETRAIN_PATH, and REWARD_PRETRAIN_PATHS before running this template." 
>&2 + exit 1 +fi + +mkdir -p log +mkdir -p wandb + +export TORCH_NCCL_AVOID_RECORD_STREAMS=1 +export NCCL_DEBUG=WARN + +export MLP_WORKER_NUM=1 +export MLP_WORKER_GPU="${MLP_WORKER_GPU:-2}" +export MLP_ROLE_INDEX=0 +export MLP_WORKER_0_PORT=20090 +export MLP_WORKER_0_HOST=localhost + +export MASTER_ADDR=$MLP_WORKER_0_HOST +export NNODES=$MLP_WORKER_NUM +export NODE_RANK=$MLP_ROLE_INDEX +export GPUS_PER_NODE=$MLP_WORKER_GPU +export MASTER_PORT=$MLP_WORKER_0_PORT + +SAVE_MODEL_NAME="LightRFT-geo3k-general-orm-len_${PROMPT_MAX_LEN}_${GENERATE_MAX_LEN}-tbs_${TBS}-rbs_${RBS}-sample_${N_SAMPLES}-kl_${KL}-warmup_${WARMUP}-ep_${EPISODE}-lr_${LR}" + +mkdir -p "results/${NAME}/${SAVE_MODEL_NAME}" +mkdir -p "rft_logs/${NAME}" + +set -x + +export WANDB_MODE="${WANDB_MODE:-offline}" +export WANDB_DIR="${WANDB_DIR:-${REPO_ROOT}/wandb}" +mkdir -p "${WANDB_DIR}" + +WANDB_PROJECT="${WANDB_PROJECT:-orm-rl-demo-geo3k}" +WANDB_RUN_NAME="${WANDB_RUN_NAME:-orm-rl-demo-general-${current_time}}" +WANDB_ORG="${WANDB_ORG:-}" +ENGINE_TYPE="${ENGINE_TYPE:-sglang}" +ENGINE_MEM_UTIL="${ENGINE_MEM_UTIL:-0.4}" +export ORM_RL_DEMO_RM_ENGINE_BACKEND="${ORM_RL_DEMO_RM_ENGINE_BACKEND:-sglang}" + +rm_use_engine_args=() +if [ "${USE_RM_ENGINE}" = "1" ]; then + rm_use_engine_args+=(--rm_use_engine) +fi + +wandb_org_args=() +if [ -n "${WANDB_ORG}" ]; then + wandb_org_args+=(--wandb_org "${WANDB_ORG}") +fi + +wandb_args=() +if [ -n "${WANDB_API_KEY:-}" ]; then + wandb_args+=(--use_wandb "${WANDB_API_KEY}") + wandb_args+=(--wandb_project "${WANDB_PROJECT}") + wandb_args+=(--wandb_run_name "${WANDB_RUN_NAME}") + wandb_args+=("${wandb_org_args[@]}") +fi + +torchrun --nnodes $NNODES --nproc-per-node $GPUS_PER_NODE --node_rank $NODE_RANK --master-port $MASTER_PORT --master-addr $MASTER_ADDR "${SCRIPT_DIR}/train_colocate.py" \ + --pretrain "${PRETRAIN_PATH}" \ + --loss_agg_mode seq-mean-token-mean \ + --save_trajectories \ + --num_trajectories_to_save 16 \ + --print_replay_buffer_stats \ + --fsdp \ + 
--use_kl_loss \ + "${rm_use_engine_args[@]}" \ + --mixed_mm_data \ + --reward_pretrain "${REWARD_PRETRAIN_PATHS}" \ + --save_path "results/${NAME}/${SAVE_MODEL_NAME}" \ + --ckpt_path "results/${NAME}/${SAVE_MODEL_NAME}" \ + --micro_train_batch_size 4 \ + --train_batch_size ${TBS} \ + --micro_rollout_batch_size 4 \ + --rollout_batch_size ${RBS} \ + --advantage_estimator group_norm \ + --max_epochs 1 \ + --num_episodes ${EPISODE} \ + --lr_warmup_ratio ${WARMUP} \ + --n_samples_per_prompt ${N_SAMPLES} \ + --prompt_max_len ${PROMPT_MAX_LEN} \ + --generate_max_len ${GENERATE_MAX_LEN} \ + --zero_stage 3 \ + --bf16 \ + --actor_learning_rate ${LR} \ + --init_kl_coef ${KL} \ + --kl_estimator k3 \ + --prompt_data "${DATA_PATH}" \ + --input_key prompt \ + --images_key images \ + --label_key label \ + --label_override "${LABEL_OVERRIDE}" \ + --eval_steps 20 \ + --eval_split "${EVAL_SPLIT}" \ + --max_eval_samples "${MAX_EVAL_SAMPLES}" \ + --apply_chat_template \ + --flash_attn \ + --gradient_checkpointing \ + --save_steps 20 \ + --max_ckpt_num 1 \ + --engine_type "${ENGINE_TYPE}" \ + --engine_mem_util "${ENGINE_MEM_UTIL}" \ + --engine_tp_size ${ENGINE_TP} \ + --enable_engine_sleep \ + --system_prompt 'A conversation between the User and Assistant. The User asks a question, and the Assistant provides a solution. The Assistant first thinks through the reasoning process internally with self-reflection and consistency check and then gives the final analysis and answer. The reasoning process should be enclosed within <think> </think>, followed directly by the final thought and answer, and the final answer should be put in \boxed{}, like this: <think> reasoning process here </think> final thought and \boxed{answer} here.'
\ + --l2 1.0e-2 \ + --freeze_prefix \ + --adam_offload \ + --limit_mm_image_per_prompt ${limit_mm_image_per_prompt} \ + "${wandb_args[@]}" \ + 2>&1 | tee "rft_logs/${NAME}/${NAME}_node${NODE_RANK}_$(date +%Y%m%d_%H%M%S).log" diff --git a/examples/orm_rl_demo/test_reward_models.py b/examples/orm_rl_demo/test_reward_models.py new file mode 100644 index 00000000..0d789482 --- /dev/null +++ b/examples/orm_rl_demo/test_reward_models.py @@ -0,0 +1,105 @@ +""" +Smoke test for the general reward model used by the ORM RL demo. + +This script intentionally uses public text-only examples so it can validate the +general reward model without depending on private datasets or absolute paths. +""" + +import argparse +import os +import sys + +import torch +from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + +sys.path.append(os.path.dirname(__file__)) +from reward_models import Qwen2VLRewardModelGeneral + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Smoke test the general reward model used by orm_rl_demo." 
+ ) + parser.add_argument( + "--model", + required=True, + help="Path or HuggingFace id for the general reward model.", + ) + return parser.parse_args() + + +def build_dialog(question: str, response: str) -> str: + return ( + f"<|im_start|>user\n{question}<|im_end|>\n" + f"<|im_start|>assistant\n{response}<|im_end|>\n" + ) + + +def load_reward_model(model_path: str) -> Qwen2VLRewardModelGeneral: + base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + device_map="auto", + ) + processor = AutoProcessor.from_pretrained( + model_path, + min_pixels=256 * 28 * 28, + max_pixels=1280 * 28 * 28, + ) + reward_model = Qwen2VLRewardModelGeneral( + base_model, + processor.tokenizer, + processor, + text_only=True, + ) + reward_model.eval() + return reward_model + + +def run_case(reward_model: Qwen2VLRewardModelGeneral, case: dict) -> None: + outputs = reward_model( + input_ids=None, + attention_mask=None, + references=[case["reference"]], + prompt_and_output=[build_dialog(case["question"], case["response"])], + raw_images=[None], + ) + score = float(outputs["score"].item()) + print(f"{case['name']}: score={score:.1f}, expected={case['expected']:.1f}") + if abs(score - case["expected"]) > 1e-6: + raise AssertionError( + f"{case['name']} expected {case['expected']}, got {score}" + ) + + +def main() -> None: + args = parse_args() + reward_model = load_reward_model(args.model) + + test_cases = [ + { + "name": "correct_answer", + "question": "What is 2 + 2?", + "response": "The answer is 4.", + "reference": "4", + "expected": 1.0, + }, + { + "name": "incorrect_answer", + "question": "What is 2 + 2?", + "response": "The answer is 5.", + "reference": "4", + "expected": 0.0, + }, + ] + + for case in test_cases: + run_case(reward_model, case) + + print("general reward model smoke test passed") + + +if __name__ == "__main__": + with torch.no_grad(): + main() diff --git 
a/examples/orm_rl_demo/train_colocate.py b/examples/orm_rl_demo/train_colocate.py new file mode 100755 index 00000000..16babf07 --- /dev/null +++ b/examples/orm_rl_demo/train_colocate.py @@ -0,0 +1,656 @@ +""" +Training entry for the Geo3K ORM RL demo with a co-located general reward model. + +This script keeps the demo-specific dataset override and reward wiring local to +`examples/orm_rl_demo` while reusing the shared GRPO training stack. +""" +import argparse +import itertools +import math +import re +import os +import sys +import json +from datetime import datetime +from typing import Callable, Dict, List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from lightrft.utils import add_arguments, ensure_video_input_available +ensure_video_input_available() + +from lightrft.datasets import PromptDatasetVL, SFTDatasetVL +from lightrft.models.actor_language import ActorLanguage +from lightrft.models.actor_vl import ActorVL +from lightrft.models.critic_vl import CriticVL +from lightrft.utils import blending_datasets, get_tokenizer_processor_vl + +from lightrft.strategy import get_strategy +from lightrft.trainer.spmd_ppo_trainer import SPMDPPOTrainerVL + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from reward_models_utils import load_reward_models, reward_fn, RECIPE + + +def _apply_label_override(dataset, label_key: str, label_override: str, strategy, dataset_name: str): + """ + Apply a demo-local label override without touching the shared dataset library. + + :param dataset: Source dataset to update. + :type dataset: Any + :param label_key: Dataset field storing the reward label. + :type label_key: str + :param label_override: Override value to inject when non-empty. + :type label_override: str + :param strategy: Training strategy used for logging. + :type strategy: Any + :param dataset_name: Human-readable dataset name for logs. + :type dataset_name: str + :return: Updated dataset. 
+ :rtype: Any + """ + if not label_override: + return dataset + + strategy.print(f"Applying label override '{label_override}' to {dataset_name}") + + def override_label(example): + example[label_key] = label_override + extra = example.get("extra_info") + if isinstance(extra, dict): + extra = dict(extra) + extra[label_key] = label_override + example["extra_info"] = extra + return example + + return dataset.map(override_label) + + +def train(args): + """ + Main training function for GRPO with co-located reward models. + + :param args: Parsed command-line arguments containing all training configuration. + :type args: argparse.Namespace + """ + # configure strategy + strategy = get_strategy(args) + + ds_train_cfg = strategy.get_ds_train_config(is_actor=True) if not args.fsdp else None + ds_eval_cfg = strategy.get_ds_eval_config(offload=False) if not args.fsdp else None + + # configure model + with strategy.init_model_context(meta_init=args.meta_init): + strategy.print(f"Initializing models with meta_init={args.meta_init}") + + # Select Actor class based on text_only flag + if args.text_only: + Actor = ActorLanguage + else: + Actor = ActorVL + + # Initialize Actor (policy model) + actor = Actor( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + lora_dropout=args.lora_dropout, + ds_config=ds_train_cfg, + packing_samples=args.packing_samples, + disable_logprobs_flashattn=args.disable_logprobs_flashattn, + fused_linear_logprob=args.fused_linear_logprob, + ) + + if args.actor_init_on_gpu: + actor = actor.to(torch.cuda.current_device()) + + # pre-prepare is used for saving RAM memory when training 72B model + if args.fsdp: + setattr(actor, "is_actor", True) + actor = strategy.prepare_model(actor, is_training=True) + + # Optionally freeze parameters (e.g., vision encoder) + if args.freeze_prefix: + freeze_prefix = 
["visual"] + frozen_params_count = 0 + total_params_count = 0 + for name, param in actor.model.named_parameters(): + total_params_count += 1 + if any(name.startswith(prefix) for prefix in freeze_prefix): + param.requires_grad = False + frozen_params_count += 1 + strategy.print(f"Froze {frozen_params_count}/{total_params_count} parameters based on prefixes: {freeze_prefix}") + + if args.critic_pretrain: + with strategy.init_model_context(meta_init=args.meta_init): + critic = CriticVL( + args.critic_pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + lora_dropout=args.lora_dropout, + normalize_reward=args.normalize_reward_for_critic, + ds_config=ds_train_cfg, + init_value_head=strategy.args.pretrain == strategy.args.critic_pretrain, + value_head_prefix=args.value_head_prefix, + ) + else: + critic = None + + if args.fsdp and critic is not None: + critic = strategy.prepare_model(critic, is_training=True) + + # Load the general reward model used by this demo. 
+ strategy.report_memory(f"before loaded reward models in main entry") + reward_models, reward_tokenizers, label_map = load_reward_models( + raw_reward_pretrain=args.reward_pretrain, + strategy=strategy, + use_engine=args.rm_use_engine, + ) + strategy.print(f"label_map: {label_map}") + strategy.report_memory(f"after loaded reward models in main entry") + + strategy.print(actor) + strategy.print(critic) + + # load weights for reference actor + if args.init_kl_coef == 0: + initial_model = None + else: + initial_model = Actor( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + ds_config=ds_eval_cfg, + packing_samples=args.packing_samples, + fused_linear_logprob=args.fused_linear_logprob, + ) + + if args.fsdp: + initial_model = strategy.prepare_model(initial_model, is_training=False, shard_size=strategy.world_size) + strategy.offload_model(initial_model) + + if args.enable_ema: + ema_model = Actor( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + ds_config=ds_eval_cfg, + ) + else: + ema_model = None + + # configure tokenizer and processor + tokenizer, processor = get_tokenizer_processor_vl( + args.pretrain, actor.model, "left", use_fast=not strategy.args.disable_fast_tokenizer + ) + assert processor is not None, "processor is None" + + # ==================== Data Loading Optimization ==================== + # The following sections now rely on the robust `blending_datasets` function. + # We add more logging for clarity. 
+ + # Prepare prompts dataset + strategy.print(f"Loading prompts dataset from: {args.prompt_data} with split: {args.prompt_split}") + prompts_data = blending_datasets( + args.prompt_data, + args.prompt_data_probs, + strategy, + args.seed, + return_eval=False, + train_split=args.prompt_split, + ) + prompts_data = _apply_label_override( + prompts_data, args.label_key, args.label_override, strategy, "prompt dataset" + ) + + prompts_data = prompts_data.select(range(min(args.max_samples, len(prompts_data)))) + prompts_dataset = PromptDatasetVL(prompts_data, tokenizer, processor, args.prompt_max_len, strategy, input_template=args.input_template) + strategy.print(f"Loaded {len(prompts_dataset)} samples for prompts.") + + # Prepare evaluation dataset + eval_dataloader = None + if args.eval_data or args.eval_split: + eval_data_path = args.eval_data if args.eval_data else args.prompt_data + if eval_data_path: + strategy.print(f"Loading evaluation dataset from {eval_data_path}, split='{args.eval_split}'") + eval_data = blending_datasets( + eval_data_path, "1.0", strategy, args.seed, return_eval=False, + # Note: `train_split` parameter is used to specify the desired split name for evaluation data. + train_split=args.eval_split, + ) + eval_data = _apply_label_override( + eval_data, args.label_key, args.label_override, strategy, "evaluation dataset" + ) + if len(eval_data) == 0: + strategy.print(f"Warning: Evaluation dataset at {eval_data_path} with split '{args.eval_split}' is empty. 
Skipping evaluation.") + else: + eval_data = eval_data.select(range(min(args.max_eval_samples, len(eval_data)))) + + eval_dataset = PromptDatasetVL(eval_data, tokenizer, processor, args.prompt_max_len, strategy, input_template=args.input_template) + eval_dataloader = strategy.setup_dataloader( + eval_dataset, args.rollout_batch_size // strategy.world_size, False, False, collate_fn=eval_dataset.collate_fn + ) + strategy.print(f"Evaluation dataset loaded: {len(eval_dataset)} samples") + else: + strategy.print("Warning: eval_split specified but no data path available for evaluation.") + + # Prepare pretrain dataset + pretrain_dataloader = None + if args.pretrain_data: + strategy.print(f"Loading pretrain dataset from: {args.pretrain_data} with split: {args.pretrain_split}") + pretrain_data = blending_datasets( + args.pretrain_data, args.pretrain_data_probs, strategy, args.seed, + return_eval=False, train_split=args.pretrain_split, + ) + if len(pretrain_data) == 0: + strategy.print(f"Warning: Pretrain dataset at {args.pretrain_data} is empty. 
PTX loss will not be applied.") + pretrain_dataloader = None + else: + pretrain_max_len = args.max_len if args.max_len else args.prompt_max_len + args.generate_max_len + # Calculate total samples needed for pretraining + total_pretrain_samples = args.max_epochs * len(prompts_dataset) * args.n_samples_per_prompt + pretrain_data_subset = pretrain_data.select(range(min(len(pretrain_data), total_pretrain_samples))) + + pretrain_dataset = SFTDatasetVL( + pretrain_data_subset, tokenizer, pretrain_max_len, strategy, pretrain_mode=True, + ) + strategy.print(f"Loaded {len(pretrain_dataset)} samples for pretraining.") + pretrain_dataloader = itertools.cycle( + iter( + strategy.setup_dataloader( + pretrain_dataset, args.micro_train_batch_size, True, True, pretrain_dataset.collate_fn, + ) + ) + ) + else: + pretrain_dataloader = None + + # Prepare prompts dataloader + prompts_dataloader = strategy.setup_dataloader( + prompts_dataset, args.rollout_batch_size // strategy.world_size, True, True, collate_fn=prompts_dataset.collate_fn + ) + + if args.pretrain_data: + pretrain_dataloader = itertools.cycle( + iter( + strategy.setup_dataloader( + pretrain_dataset, + args.micro_train_batch_size, + True, + True, + pretrain_dataset.collate_fn, + ) + ) + ) + else: + pretrain_dataloader = None + + # for scheduler + num_update_steps_per_episodes = ( + len(prompts_dataset) * args.n_samples_per_prompt // args.train_batch_size * args.max_epochs + ) + max_steps = math.ceil(args.num_episodes * num_update_steps_per_episodes) + + # gradient_checkpointing + if args.gradient_checkpointing: + actor.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + if critic is not None: + critic.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + + ( + (actor, actor_optim, actor_scheduler), + (critic, critic_optim, critic_scheduler), + reward_models, + 
initial_model, + ) = strategy.prepare_models_and_optimizers(actor, critic, reward_models, initial_model, args, max_steps) + + strategy.print(reward_models) + + if ema_model: + ema_model._offload = True + ema_model = strategy.prepare(ema_model, is_rlhf=True) + + # load checkpoint + consumed_samples = 0 + if args.load_checkpoint and os.path.exists(os.path.join(args.ckpt_path, "_actor")): + _, states = strategy.load_ckpt(actor.model, os.path.join(args.ckpt_path, "_actor"), + optimizer=actor_optim, scheduler=actor_scheduler) + if args.critic_pretrain: + strategy.load_ckpt(critic, os.path.join(args.ckpt_path, "_critic")) + consumed_samples = states["consumed_samples"] + strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}") + + os.makedirs(args.save_path, exist_ok=True) + strategy.report_memory("after models init") + + strategy.report_memory("before setup_inference_engine") + strategy.setup_inference_engine(args, engine_type=args.engine_type, actor=actor) + strategy.report_memory("after setup_inference_engine") + + # configure Trainer + trainer = SPMDPPOTrainerVL( + strategy, + actor, + critic, + reward_models, + initial_model, + ema_model, + actor_optim, + critic_optim, + actor_scheduler, + critic_scheduler, + max_epochs=args.max_epochs, + micro_train_batch_size=args.micro_train_batch_size, + micro_rollout_batch_size=args.micro_rollout_batch_size, + gradient_checkpointing=args.gradient_checkpointing, + tokenizer=tokenizer, + processor=processor, + prompt_max_len=args.prompt_max_len, + value_clip=args.value_clip, + eps_clip=args.eps_clip, + loss_agg_mode=args.loss_agg_mode, + use_gspo=args.use_gspo, + normalize_advantages=args.normalize_advantages, + use_sequence_rewards=args.use_sequence_rewards, + gamma=args.gamma, + lambd=args.lambd, + init_kl_coef=args.init_kl_coef, + kl_target=args.kl_target, + ema_beta=0.992, + ptx_coef=args.ptx_coef, + max_norm=args.max_norm, + # for GPT generation + do_sample=True, + 
max_new_tokens=args.generate_max_len, + max_length=args.max_len, + temperature=args.temperature, + top_p=args.top_p, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + # reward model + reward_fn=reward_fn, + reward_fn_label_map=label_map, + reward_recipe=RECIPE, + reward_tokenizers=reward_tokenizers, + save_hf_ckpt=args.save_hf_ckpt, + disable_ds_ckpt=args.disable_ds_ckpt, + packing_samples=args.packing_samples, + # overlong_reward + dynamic_sampling=args.dynamic_sampling, + overlong_buffer=args.overlong_buffer, + overlong_buffer_len=args.overlong_buffer_len, + overlong_buffer_penalty_factor=args.overlong_buffer_penalty_factor, + print_replay_buffer_stats=args.print_replay_buffer_stats, + ) + + trainer.fit(args, prompts_dataloader=prompts_dataloader, pretrain_dataloader=pretrain_dataloader, eval_dataloader=eval_dataloader, consumed_samples=0, num_update_steps_per_episodes=num_update_steps_per_episodes) + + # save model checkpoint after fitting on only rank0 + strategy.save_model( + ema_model if args.enable_ema else actor, + tokenizer, + args.save_path, + ) + + if args.critic_pretrain and args.save_value_network: + strategy.save_model( + critic, + tokenizer, + args.save_path + "_critic", + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--engine_type", type=str, default="vllm", help="Choose inference engine type: vllm, sglang") + parser.add_argument("--text_only", action="store_true", default=False) + + # Checkpoint + parser.add_argument("--save_path", type=str, default="./ckpt") + parser.add_argument("--save_steps", type=int, default=-1) + parser.add_argument("--save_hf_ckpt", action="store_true", default=False) + parser.add_argument("--disable_ds_ckpt", action="store_true", default=False) + parser.add_argument("--save_trajectories", action="store_true", default=False, help="Save experience trajectories to JSON for debugging") + parser.add_argument("--num_trajectories_to_save", type=int, 
default=10, help="Number of trajectories to save per checkpoint") + parser.add_argument("--trajectory_analysis", action="store_true", default=False, help="Enable trajectory analysis metrics (repeat_score, reflection_pattern, policy_entropy) and log to wandb") + parser.add_argument("--print_replay_buffer_stats", action="store_true", default=False, help="Print detailed replay buffer statistics during training") + parser.add_argument("--logging_steps", type=int, default=1) + parser.add_argument("--eval_steps", type=int, default=-1) + parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_ppo") + parser.add_argument("--max_ckpt_num", type=int, default=3) + parser.add_argument("--max_ckpt_mem", type=int, default=1e8) + parser.add_argument("--load_checkpoint", action="store_true", default=False) + + # DAPO + parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy") + parser.add_argument("--overlong_buffer", action="store_true", default=False, help="Apply overlong sequence buffer in DAPO") + parser.add_argument("--overlong_buffer_len", type=int, default=1024, help="Max token threshold for overlong buffer") + parser.add_argument("--overlong_buffer_penalty_factor", type=float, default=1.0, help="Penalty scaling factor for overlong sequences, <1 discourages long outputs; >1 encourages them") + + # PPO + parser.add_argument("--num_episodes", type=int, default=1) + parser.add_argument("--rollout_batch_size", type=int, default=512) + parser.add_argument("--micro_rollout_batch_size", type=int, default=8) + parser.add_argument("--max_epochs", type=int, default=1) + parser.add_argument("--prompt_max_len", type=int, default=1024, help="Max tokens for each prompt") + parser.add_argument("--generate_max_len", type=int, default=1024, help="Max tokens to generate in PPO") + parser.add_argument("--max_len", type=int, default=None, help="deprecated max_len") + parser.add_argument("--max_samples", 
type=int, default=1000000) + parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping") + parser.add_argument("--l2", type=float, default=0.0, help="weight decay loss") + parser.add_argument("--ptx_coef", type=float, default=0.05, help="PPO-ptx loss coef") + parser.add_argument("--eps_clip", type=float, default=0.2, help="PPO clip range") + parser.add_argument("--loss_agg_mode", type=str, default='seq-mean-token-mean', + help="Loss aggregation mode. Options: ['token-mean', 'seq-mean-token-sum', 'seq-mean-token-mean', 'seq-mean-token-sum-norm']") + parser.add_argument("--use_gspo", action="store_true", default=False, help="Enable GSPO (Group Sequence Policy Optimization) mode") + parser.add_argument("--normalize_advantages", action="store_true", default=True, help="Enable advantage normalization in GSPO") + parser.add_argument("--use_sequence_rewards", action="store_true", default=True, help="Use sequence-level rewards in GSPO") + parser.add_argument("--value_clip", type=float, default=0.2, help="PPO value clip range") + parser.add_argument("--lambd", type=float, default=0.95, help="PPO GAE lambd") + parser.add_argument("--gamma", type=float, default=1, help="PPO GAE gamma") + parser.add_argument("--micro_train_batch_size", type=int, default=4, help="batch size per GPU") + parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size") + parser.add_argument("--normalize_reward_for_critic", action="store_true", default=False, help="Enable Reward Normalization in critic model") + parser.add_argument("--top_p", type=float, default=1.0) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--freeze_prefix", action="store_true", default=False, help="Freeze the prefix part (e.g. 
vision encoder) of the actor model") + parser.add_argument("--freezing_actor_steps", type=int, default=-1, help="Used for critic initialization") + parser.add_argument( + "--n_samples_per_prompt", type=int, default=1, help="number of responses for each prompt in generation" + ) + parser.add_argument("--save_value_network", action="store_true", default=False, help="Save critic model") + parser.add_argument("--actor_learning_rate", type=float, default=1e-6) + parser.add_argument("--critic_learning_rate", type=float, default=9e-6) + parser.add_argument("--lr_warmup_ratio", type=float, default=0.03) + parser.add_argument("--kl_target", type=float, default=None) + parser.add_argument("--init_kl_coef", type=float, default=0.01, help="KL penalty in PPO") + parser.add_argument( + "--kl_estimator", + type=str, + default="k1", + choices=["k1", "k2", "k3"], + help=( + "In GRPO, k3 is utilized as the loss function, while k2, when used as the loss, is nearly equivalent to k1." + ), + ) + parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer") + + # Reward/Advantage Norm/Clip Arguments + parser.add_argument("--reward_running_norm", action="store_true", default=False, help="Enable running normalization for rewards.") + parser.add_argument("--reward_running_norm_minus_mean", action="store_true", default=False, help="When using reward normalization, subtract the mean; otherwise, only scale by the std.") + parser.add_argument("--reward_clip", type=float, default=0.0, help="Clip rewards to the range [-reward_clip, reward_clip]. 0.0 means no clipping.") + parser.add_argument("--advantages_norm", action="store_true", default=False, help="Enable whitening for advantages.") + parser.add_argument("--advantage_clip", type=float, default=0.0, help="Clip advantages to the range [-advantage_clip, advantage_clip]. 
0.0 means no clipping.") + + # DeepSpeed + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed") + parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage") + parser.add_argument("--gradient_checkpointing", action="store_true", default=False) + parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") + parser.add_argument("--enable_ema", action="store_true", help="Enable EMA checkpoint for the model.") + parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size") + parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer") + parser.add_argument("--actor_init_on_gpu", action="store_true", default=False) + parser.add_argument("--flash_attn", action="store_true", default=False, help="Enable FlashAttention2") + parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss") + parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type") + parser.add_argument("--overlap_comm", action="store_true", default=False) + parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False) + parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) + parser.add_argument("--disable_logprobs_flashattn", action="store_true", default=False, help="Disable flash attn implementation in log_probs calculation") + + # FSDP + parser.add_argument("--no_shard_vit", action="store_true", default=False, help="Disable sharding for vision transformer") + parser.add_argument("--meta_init", action="store_true", default=False, help="Initialize models on meta device to save CPU memory") + + # Reinforce + parser.add_argument( + "--advantage_estimator", + type=str, + choices=["gae", "reinforce", "rloo", "reinforce_baseline", "group_norm", "cpgd", "reinforce++"], + 
default="gae", + help="Choose advantage estimation method: gae, reinforce, rloo, reinforce_baseline, group_norm, cpgd, reinforce++", + ) + + parser.add_argument("--use_kl_loss", action="store_true", default=False, help="whether to use KL loss from GRPO") + + # LoRA + parser.add_argument("--load_in_4bit", action="store_true", default=False) + parser.add_argument("--lora_rank", type=int, default=0) + parser.add_argument("--lora_alpha", type=int, default=16) + parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear") + parser.add_argument("--lora_dropout", type=float, default=0) + + # Models + parser.add_argument("--pretrain", type=str, default=None, help="HF model name or path") + parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path") + parser.add_argument("--remote_rm_url", type=str, default=None, help="remote RM API") + parser.add_argument("--critic_pretrain", type=str, default=None, help="HF model name or path") + parser.add_argument("--value_head_prefix", type=str, default="score") + + # Custom dataset + parser.add_argument("--prompt_data", type=str, default=None, help="HF dataset name or path") + parser.add_argument( + "--prompt_data_probs", + type=str, + default="1.0", + help="sampling probs for datasets", + ) + parser.add_argument("--prompt_split", type=str, default="train") + + # Evaluation dataset + parser.add_argument("--eval_data", type=str, default=None, help="HF evaluation dataset name or path (default: use prompt_data)") + parser.add_argument("--eval_split", type=str, default="test", help="Evaluation data split (default: test)") + parser.add_argument("--max_eval_samples", type=int, default=500, help="Maximum number of samples to evaluate (default: 500)") + + parser.add_argument("--pretrain_data", type=str, default=None, help="HF dataset name or path") + parser.add_argument( + "--pretrain_data_probs", + type=str, + default="1.0", + help="sampling probs for datasets", + ) +
parser.add_argument("--pretrain_split", type=str, default="train") + parser.add_argument("--input_key", type=str, default="input", help="JSON dataset key") + parser.add_argument("--images_key", type=str, default="image", help="JSON dataset key for images") + parser.add_argument("--reference_key", type=str, default="reference", help="JSON dataset key for reference answers") + parser.add_argument("--label_key", type=str, default="label", help="JSON dataset key") + parser.add_argument( + "--label_override", + type=str, + default=None, + help="Optional label override applied after dataset loading.", + ) + parser.add_argument("--input_template", type=str, default=None) + parser.add_argument( + "--apply_chat_template", action="store_true", default=False, help="Use HF tokenizer chat template" + ) + + parser.add_argument("--system_prompt", type=str, default=None, help="HF System Prompt") + + + # wandb parameters + parser.add_argument("--use_wandb", type=str, default=None) + parser.add_argument("--wandb_org", type=str, default=None) + parser.add_argument("--wandb_group", type=str, default=None) + parser.add_argument("--wandb_project", type=str, default="lightrft_train_ppo") + parser.add_argument( + "--wandb_run_name", + type=str, + default="ppo_%s" % datetime.now().strftime("%m%dT%H:%M"), + ) + + # TensorBoard parameters + parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path") + + # ModelScope parameters + parser.add_argument("--use_ms", action="store_true", default=False) + + # MultiModal + parser.add_argument("--limit_mm_image_per_prompt", type=int, default=-1, help="the max image number of each text in multimodal for inference backend") + + # CPGD + parser.add_argument("--use_cpg_loss", action="store_true", default=False, help="whether to use the clipped policy gradient loss from CPGD") + + add_arguments(parser) + + args = parser.parse_args() + + + if args.advantage_estimator not in ["gae"]: + args.critic_pretrain = None + elif
args.critic_pretrain is None: + args.critic_pretrain = args.pretrain + + if args.advantage_estimator in ["rloo", "reinforce_baseline", "group_norm"]: + assert args.n_samples_per_prompt > 1, f"{args.advantage_estimator} requires n_samples_per_prompt > 1" + + if args.use_kl_loss: + if args.kl_estimator not in ["k2", "k3"]: + print(f"Recommend setting --kl_estimator to 'k2' or 'k3' when using KL as a loss (got '{args.kl_estimator}')") + else: + if args.kl_estimator not in ["k1"]: + print(f"Recommend setting --kl_estimator to 'k1' when not using KL as a loss (got '{args.kl_estimator}').") + + if args.advantage_estimator in ["gae", "cpgd"] and args.use_kl_loss: + warnings.warn( + "Using use_kl_loss=True with non-normalized advantage estimator " + "may result in double KL penalty. Consider disabling --use_kl_loss " + "or using --advantage_estimator group_norm" + ) + + if args.input_template and "{}" not in args.input_template: + print("[Warning] {} not in args.input_template, set to None") + args.input_template = None + + if args.input_template and "\\n" in args.input_template: + print( + "[Warning] input_template contains \\n characters instead of newline. " + "You likely want to pass $'\\n' in Bash or \"`n\" in PowerShell." + ) + + if args.use_ms: + from modelscope.utils.hf_util import patch_hub + + # Patch hub to download models from modelscope to speed up.
+ patch_hub() + + train(args) diff --git a/lightrft/strategy/strategy_base.py b/lightrft/strategy/strategy_base.py index ac72ba25..050c5a00 100644 --- a/lightrft/strategy/strategy_base.py +++ b/lightrft/strategy/strategy_base.py @@ -535,10 +535,13 @@ def prepare_models_and_optimizers(self, actor, critic, reward_models, initial_mo if critic is not None: critic = self.prepare_model(critic, is_training=True) if not self.config.remote_rm_url: + reward_model_shard_size = self.world_size if isinstance(reward_models, (tuple, list)): - reward_models = [self.prepare_model(model, shard_size=8) for model in reward_models] + reward_models = [ + self.prepare_model(model, shard_size=reward_model_shard_size) for model in reward_models + ] else: - reward_models = self.prepare_model(reward_models, shard_size=8) + reward_models = self.prepare_model(reward_models, shard_size=reward_model_shard_size) # Configure optimizers actor_optim = self.create_optimizer( diff --git a/lightrft/strategy/vllm_utils/vllm_worker_wrap_no_ray.py b/lightrft/strategy/vllm_utils/vllm_worker_wrap_no_ray.py index 89c6104e..3747e4e5 100644 --- a/lightrft/strategy/vllm_utils/vllm_worker_wrap_no_ray.py +++ b/lightrft/strategy/vllm_utils/vllm_worker_wrap_no_ray.py @@ -7,6 +7,8 @@ """ import torch +import vllm +from packaging.version import Version # vLLM version compatibility notes: # -------------------------------- # In older versions of vLLM (< 0.13.0), the Worker class is located under: @@ -16,11 +18,22 @@ # vllm.v1.worker.gpu_worker.Worker # # To maintain compatibility across different vLLM versions, we try importing Worker -# from the new v1 path first (for vllm>=0.13.0). If the import fails (ModuleNotFoundError), -# we fall back to importing from the old path (for vllm<0.13.0). -try: - from vllm.v1.worker.gpu_worker import Worker -except (ModuleNotFoundError, ImportError): +# from the new v1 path only for vllm>=0.13.0. 
Older releases like v0.7.x may +# already expose a v1 module tree, but their uniproc executor still expects the +# legacy Worker implementation with methods such as determine_num_available_blocks. +if Version(vllm.__version__) >= Version("0.13.0"): + try: + from vllm.v1.worker.gpu_worker import Worker + except (ModuleNotFoundError, ImportError): + try: + from vllm.worker.worker import Worker + except (ModuleNotFoundError, ImportError): + raise ImportError( + "Could not import Worker from vllm. " + "Please ensure you have a compatible version of vllm installed. " + "Supported versions: vllm>=0.6.3 or vllm>=0.13.0" + ) +else: try: from vllm.worker.worker import Worker except (ModuleNotFoundError, ImportError): diff --git a/lightrft/trainer/fast_exp_maker.py b/lightrft/trainer/fast_exp_maker.py index fe12d534..4f11022b 100644 --- a/lightrft/trainer/fast_exp_maker.py +++ b/lightrft/trainer/fast_exp_maker.py @@ -791,6 +791,7 @@ def _compute_standard_torch_rewards( rm_output = rm( sequences, output.attention_mask, + references=output.references, prompt_and_output=output.prompt_and_output, raw_images=output.raw_images, img_num=output.image_num, diff --git a/lightrft/trainer/ppo_trainer_vl.py b/lightrft/trainer/ppo_trainer_vl.py index e3b44091..ab2fc37c 100644 --- a/lightrft/trainer/ppo_trainer_vl.py +++ b/lightrft/trainer/ppo_trainer_vl.py @@ -405,6 +405,7 @@ def fit( all_rewards = [] all_format_rewards = [] all_accuracy_rewards = [] + all_general_model_rewards = [] all_response_lengths = [] for item in self.replay_buffer.items: @@ -428,6 +429,11 @@ def fit( all_format_rewards.append(reward_metrics['format_reward']) if 'accuracy_reward' in reward_metrics: all_accuracy_rewards.append(reward_metrics['accuracy_reward']) + general_model_reward = reward_metrics.get("general_model_reward") + if general_model_reward is None: + general_model_reward = reward_metrics.get("model_reward") + if general_model_reward is not None: + all_general_model_rewards.append(general_model_reward) 
# Collect response lengths from rollout if hasattr(item, 'info') and item.info is not None and 'response_length' in item.info: @@ -476,6 +482,18 @@ def fit( if abs(mean_accuracy_reward) > 1e-6: rollout_status["rollout_accuracy_reward"] = mean_accuracy_reward + if all_general_model_rewards: + if isinstance(all_general_model_rewards[0], torch.Tensor): + general_model_tensor = torch.cat([t.to(device).float() for t in all_general_model_rewards]) + else: + general_model_tensor = torch.tensor( + all_general_model_rewards, dtype=torch.float32, device=device + ) + + mean_general_model_reward = general_model_tensor.mean().item() + if abs(mean_general_model_reward) > 1e-6: + rollout_status["rollout_general_model_reward"] = mean_general_model_reward + if all_response_lengths: # [TENSOR-FIX] Handle both tensor lists and scalar lists if isinstance(all_response_lengths[0], torch.Tensor): @@ -1200,6 +1218,7 @@ def evaluate(self, eval_dataloader, global_step): all_rewards = [] all_format_rewards = [] all_accuracy_rewards = [] + all_general_model_rewards = [] all_response_lengths = [] num_eval_batches = 0 @@ -1249,6 +1268,11 @@ def extract_values(val): all_format_rewards.extend(extract_values(rm['format_reward'])) if 'accuracy_reward' in rm: all_accuracy_rewards.extend(extract_values(rm['accuracy_reward'])) + general_model_reward = rm.get("general_model_reward") + if general_model_reward is None: + general_model_reward = rm.get("model_reward") + if general_model_reward is not None: + all_general_model_rewards.extend(extract_values(general_model_reward)) num_eval_batches += 1 if num_eval_batches >= len(eval_dataloader): @@ -1271,6 +1295,7 @@ def compute_stats(name, values_list): compute_stats("reward", all_rewards) compute_stats("format_reward", all_format_rewards) compute_stats("accuracy_reward", all_accuracy_rewards) + compute_stats("general_model_reward", all_general_model_rewards) compute_stats("response_length", all_response_lengths) metrics["num_samples"] = len(all_rewards) 
diff --git a/lightrft/trainer/spmd_ppo_trainer.py b/lightrft/trainer/spmd_ppo_trainer.py index d79a7458..a0574905 100644 --- a/lightrft/trainer/spmd_ppo_trainer.py +++ b/lightrft/trainer/spmd_ppo_trainer.py @@ -316,7 +316,7 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train all_rewards = [] all_format_rewards = [] all_accuracy_rewards = [] - all_model_rewards = [] + all_general_model_rewards = [] all_rule_rewards = [] all_advantages = [] all_returns = [] @@ -334,8 +334,11 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train all_format_rewards.append(reward_metrics['format_reward']) if 'accuracy_reward' in reward_metrics: all_accuracy_rewards.append(reward_metrics['accuracy_reward']) - if 'model_reward' in reward_metrics: - all_model_rewards.append(reward_metrics['model_reward']) + general_model_reward = reward_metrics.get("general_model_reward") + if general_model_reward is None: + general_model_reward = reward_metrics.get("model_reward") + if general_model_reward is not None: + all_general_model_rewards.append(general_model_reward) if 'rule_reward' in reward_metrics: all_rule_rewards.append(reward_metrics['rule_reward']) @@ -379,15 +382,15 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train status_mean["accuracy_reward_mean"] = accuracy_tensor.mean().item() status_mean["accuracy_reward_std"] = accuracy_tensor.std().item() - if all_model_rewards: + if all_general_model_rewards: # [TENSOR-FIX] Handle both tensor lists and scalar lists - if isinstance(all_model_rewards[0], torch.Tensor): - model_tensor = torch.cat([t.to(device).float() for t in all_model_rewards]) + if isinstance(all_general_model_rewards[0], torch.Tensor): + model_tensor = torch.cat([t.to(device).float() for t in all_general_model_rewards]) else: - model_tensor = torch.tensor(all_model_rewards, dtype=torch.float32, device=device) + model_tensor = torch.tensor(all_general_model_rewards, dtype=torch.float32, 
device=device) if model_tensor.abs().sum() > 0: # Only log if model rewards are non-zero - status_mean["model_reward_mean"] = model_tensor.mean().item() - self.strategy.print(f" model_reward_mean: {status_mean['model_reward_mean']}") + status_mean["general_model_reward_mean"] = model_tensor.mean().item() + self.strategy.print(f" general_model_reward_mean: {status_mean['general_model_reward_mean']}") if all_rule_rewards: # [TENSOR-FIX] Handle both tensor lists and scalar lists @@ -444,6 +447,9 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train f"✅ Accuracy Reward: {status_mean['accuracy_reward_mean']:.4f} ± {status_mean['accuracy_reward_std']:.4f}" # noqa ) + if all_general_model_rewards and "general_model_reward_mean" in status_mean: + self.strategy.print(f"🧠 General RM Reward:{status_mean['general_model_reward_mean']:.4f}") + if all_advantages: self.strategy.print( f"📈 Advantages: {status_mean['advantages_mean']:.4f} ± {status_mean['advantages_std']:.4f} " # noqa