diff --git a/examples/orm_rl_demo/README.md b/examples/orm_rl_demo/README.md
new file mode 100755
index 00000000..ef6e8e46
--- /dev/null
+++ b/examples/orm_rl_demo/README.md
@@ -0,0 +1,105 @@
+
+
+# ORM RL Demo
+
+Complete ORM trajectory-scoring RL training demo based on Geo3K.
+
+
+
+## Overview
+
+This demo shows the full pipeline of using an ORM to score trajectories for RL training:
+- dataset: Geo3K
+- actor: Qwen2.5-VL 7B model
+- reward: one general outcome reward model combined with rule-based accuracy reward and format reward, all contributing to the GRPO loss
+- training engine: FSDP, inference engine: SGLang
+
+The actor generates Geo3K trajectories, the general ORM scores them, and the scores are combined with a rule-based accuracy reward (`accuracy_reward`) and a format reward (`format_reward`) to compute the final GRPO loss. To avoid rewriting the Geo3K dataset files, the demo overrides the dataset label to `geo3k_general` at runtime so the original dataset path can be reused while routing through the general ORM reward mix.
+
+Environment requirements match the repository-level [README.md](../../README.md); refer to it for setup details.
+
+## Project Structure
+
+```text
+orm_rl_demo/
+├── train_colocate.py
+├── reward_models.py
+├── reward_models_utils.py
+├── test_reward_models.py
+└── run_general_fsdp_qwenvl.sh
+```
+
+## Quick Start
+
+Set the data and model paths, then run the entry script:
+
+```bash
+export DATA_PATH=/path/to/geo3k
+export PRETRAIN_PATH=/path/to/Qwen2.5-VL-7B-Instruct
+export REWARD_PRETRAIN_PATHS='{"general":"/path/to/general-reward-model"}'
+bash examples/orm_rl_demo/run_general_fsdp_qwenvl.sh
+```
+
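+`REWARD_PRETRAIN_PATHS` is a JSON object that maps each reward-model name (here `general`) to its checkpoint path. All three variables must point to real locations before the script is run.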
+
+## Results
+
+### Experiment Setup
+
+This demo has been validated with one real 2-GPU full training run (W&B: [ORM-RL-Demo-QwenVL-7B-Geo3K](https://wandb.ai/hansbug/ORM-RL-Demo-QwenVL-7B-Geo3K/runs/zrekazyw)):
+
+| Item | Value |
+| --- | --- |
+| Actor | Qwen2.5-VL-7B-Instruct |
+| General RM | Qwen2.5-VL-7B general reward model |
+| Dataset | Geo3K |
+| Training engine | FSDP |
+| Inference engine | SGLang (`rm_use_engine=True`) |
+| Reward mixing | `format_reward × 0.1 + general_model_reward × 0.2 + accuracy_reward × 0.7` |
+| Batch sizes | `train_batch_size=128`, `rollout_batch_size=128` |
+| Sampling | `n_samples_per_prompt=8`, `num_episodes=20` |
+| Sequence length | `prompt_max_len=1024`, `generate_max_len=2048` |
+| Optimizer / KL | `actor_learning_rate=1e-6`, `init_kl_coef=0.001`, `lr_warmup_ratio=0.03` |
+
+The three reward components (rule-based `accuracy_reward`, ORM score `general_model_reward`, and `format_reward`) are combined with the weights above, and the mixed reward drives the GRPO loss. The ORM itself outputs 0.0, 0.5, or 1.0; the `general_model` values reported in this README (e.g. `0.2`) are that raw score already multiplied by the 0.2 coefficient.
+
+### Curve Results
+
+The run completed successfully (`train/global_step=320`, 16 eval passes):
+- `eval/reward_mean` improved from `0.4636` to `0.5679`
+- Best `eval/reward_mean=0.5686` at `train_step=260`
+- Final `eval/accuracy_reward_mean=0.5166`, `eval/format_reward_mean=0.9956`, `eval/general_model_reward_mean=0.1067`
+
+
+
+
+
+
+
+### Case Study
+
+Two question stems appear in the saved trajectories at both step 80 and step 320. The comparison below shows how the actor's output on each of them changed between early and late training.
+
+#### Question A: Parallelogram Area
+
+
+
+
+
+- Step 80 rewards: `total=0.3`, `format=1.0`, `accuracy=0.0`, `general_model=0.2`, `rule=0.1`
+- Step 320 rewards: `total=1.0`, `format=1.0`, `accuracy=1.0`, `general_model=0.2`, `rule=0.8`
+- At step 80 the actor's answer was already close, so the ORM scored it 1.0 (contributing 0.2 after the 0.2 coefficient); by step 320 the output moved from `38.97` to the rule-matching `39.0`, flipping `accuracy_reward` from `0.0` to `1.0`.
+
+#### Question B: Tangent Geometry `y`
+
+
+
+
+
+- Step 80 rewards: `total=0.1`, `format=1.0`, `accuracy=0.0`, `general_model=0.0`, `rule=0.1`
+- Step 320 rewards: `total=1.0`, `format=1.0`, `accuracy=1.0`, `general_model=0.2`, `rule=0.8`
+- At step 80 only format was preserved while both accuracy and ORM failed to reward the answer; by step 320 both became positive contributions.
+
+## License
+
+This project is licensed under the Apache 2.0 License. See [LICENSE](../../LICENSE) for details.
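
The reward mix described in the README above can be checked against the case-study numbers with a few lines of arithmetic. The sketch below is illustrative only and is not part of the demo code: `mix_rewards` and `RewardBreakdown` are hypothetical names, and reading the logged `rule` value as `0.1 * format + 0.7 * accuracy` is an assumption inferred from the four reward lines in the case study, not something the README states.

```python
from dataclasses import dataclass

# Weights taken from the "Reward mixing" row of the README table.
W_FORMAT, W_GENERAL, W_ACCURACY = 0.1, 0.2, 0.7


@dataclass
class RewardBreakdown:
    """Hypothetical container mirroring the fields logged in the case study."""
    rule: float           # assumed: weighted format + accuracy part
    general_model: float  # ORM score after the 0.2 coefficient
    total: float


def mix_rewards(format_reward: float, accuracy_reward: float, orm_score: float) -> RewardBreakdown:
    """Combine the three components; orm_score is the raw ORM output (0.0, 0.5, or 1.0)."""
    rule = W_FORMAT * format_reward + W_ACCURACY * accuracy_reward
    general = W_GENERAL * orm_score
    # Round to hide floating-point noise such as 0.30000000000000004.
    return RewardBreakdown(round(rule, 4), round(general, 4), round(rule + general, 4))


if __name__ == "__main__":
    # Question A, step 80: correct format, wrong rule-based answer, ORM raw score 1.0.
    print(mix_rewards(1.0, 0.0, 1.0))
    # Question A, step 320: everything rewarded.
    print(mix_rewards(1.0, 1.0, 1.0))
    # Question B, step 80: only the format reward fires.
    print(mix_rewards(1.0, 0.0, 0.0))
```

Under these assumptions the three calls give totals of 0.3, 1.0, and 0.1, matching the `total` values listed for Question A (steps 80 and 320) and Question B (step 80).
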
diff --git a/examples/orm_rl_demo/README_zh.md b/examples/orm_rl_demo/README_zh.md
new file mode 100755
index 00000000..96046056
--- /dev/null
+++ b/examples/orm_rl_demo/README_zh.md
@@ -0,0 +1,105 @@
+
+
+# ORM RL Demo 训练示例
+
+基于 Geo3K 的 ORM 轨迹打分 RL 训练完整 demo。
+
+
+
+## 概述
+
+本示例展示了使用 ORM 对轨迹打分进行 RL 训练的完整流程,包含以下配置:
+- 数据集:Geo3K
+- actor:Qwen2.5-VL 7B 模型
+- reward:单个 general outcome reward model,与规则正确性奖励和格式奖励组合后共同计算 GRPO loss
+- 训练引擎:FSDP,推理引擎:SGLang
+
+训练时,actor 在 Geo3K 上生成轨迹,general ORM 对轨迹打分,与规则正确性奖励(`accuracy_reward`)和格式奖励(`format_reward`)三路混合后,共同计算 GRPO loss。为了不直接改写 Geo3K 数据集文件,本 demo 在运行时将数据标签覆盖为 `geo3k_general`,沿用原始数据路径的同时走 general ORM reward 融合逻辑。
+
+环境要求与仓库根目录 [README_zh.md](../../README_zh.md#环境要求) 保持一致,请直接参考主文档。
+
+## 项目结构
+
+```text
+orm_rl_demo/
+├── train_colocate.py
+├── reward_models.py
+├── reward_models_utils.py
+├── test_reward_models.py
+└── run_general_fsdp_qwenvl.sh
+```
+
+## 快速开始
+
+设置数据和模型路径后,运行入口脚本:
+
+```bash
+export DATA_PATH=/path/to/geo3k
+export PRETRAIN_PATH=/path/to/Qwen2.5-VL-7B-Instruct
+export REWARD_PRETRAIN_PATHS='{"general":"/path/to/general-reward-model"}'
+bash examples/orm_rl_demo/run_general_fsdp_qwenvl.sh
+```
+
+运行前通过上述环境变量指定数据和模型路径。
+
+## 实验结果
+
+### 实验设置
+
+本 demo 已通过一次真实的 2 卡全量训练验通(W&B run:[ORM-RL-Demo-QwenVL-7B-Geo3K](https://wandb.ai/hansbug/ORM-RL-Demo-QwenVL-7B-Geo3K/runs/zrekazyw)),关键配置如下:
+
+| 项目 | 值 |
+| --- | --- |
+| Actor | Qwen2.5-VL-7B-Instruct |
+| General RM | Qwen2.5-VL-7B general reward model |
+| 数据 | Geo3K |
+| 训练引擎 | FSDP |
+| 推理引擎 | SGLang(`rm_use_engine=True`) |
+| Reward 融合 | `format_reward × 0.1 + general_model_reward × 0.2 + accuracy_reward × 0.7` |
+| Batch 大小 | `train_batch_size=128`, `rollout_batch_size=128` |
+| 采样 | `n_samples_per_prompt=8`, `num_episodes=20` |
+| 长度 | `prompt_max_len=1024`, `generate_max_len=2048` |
+| 优化 / KL | `actor_learning_rate=1e-6`, `init_kl_coef=0.001`, `lr_warmup_ratio=0.03` |
+
+三路奖励(规则正确性奖励 `accuracy_reward`、ORM 打分 `general_model_reward` 系数 0.2、格式奖励 `format_reward`)按上述权重混合,共同计算最终的 GRPO loss。其中 `general_model_reward` 对应的 0.2 是权重系数,ORM 模型本身的输出范围为 0.0 / 0.5 / 1.0,乘以 0.2 后得到 reward 贡献。
+
+### 整体曲线结果
+
+训练完整跑完(`train/global_step=320`,共 16 次 eval):
+- `eval/reward_mean` 从 `0.4636` 提升到 `0.5679`
+- Best `eval/reward_mean=0.5686`,出现在 `train_step=260`
+- Final `eval/accuracy_reward_mean=0.5166`,`eval/format_reward_mean=0.9956`,`eval/general_model_reward_mean=0.1067`
+
+
+
+
+
+
+
+### 案例分析
+
+Step 80 和 Step 320 之间共有 2 道题目重叠,以下展示这 2 道题从早期到末期的真实对照。
+
+#### Question A:平行四边形面积题
+
+
+
+
+
+- Step 80 reward:`total=0.3`,`format=1.0`,`accuracy=0.0`,`general_model=0.2`,`rule=0.1`
+- Step 320 reward:`total=1.0`,`format=1.0`,`accuracy=1.0`,`general_model=0.2`,`rule=0.8`
+- 含义:step 80 时 actor 答案已很接近,ORM 打出 1.0,乘系数 0.2 后贡献 0.2;到 step 320 时,输出从 `38.97` 修正为规则答案 `39.0`,`accuracy_reward` 从 `0.0` 跳至 `1.0`。
+
+#### Question B:切线几何 `y`
+
+
+
+
+
+- Step 80 reward:`total=0.1`,`format=1.0`,`accuracy=0.0`,`general_model=0.0`,`rule=0.1`
+- Step 320 reward:`total=1.0`,`format=1.0`,`accuracy=1.0`,`general_model=0.2`,`rule=0.8`
+- 含义:step 80 时只保住了格式,accuracy 和 general RM 均未给分;到 step 320 时,两项均变为正向贡献。
+
+## 许可证
+
+本项目采用 Apache 2.0 许可证。详见 [LICENSE](../../LICENSE)。
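
The step counts reported in both READMEs (`train/global_step=320`, `num_episodes=20`, `rollout_batch_size=128`, 16 eval passes) can be cross-checked with simple arithmetic. This is a sketch under two assumptions the READMEs do not confirm: each global step consumes exactly one rollout batch, and a trailing partial batch is dropped. The function names are hypothetical.

```python
def steps_per_episode(global_steps: int, num_episodes: int) -> int:
    """Number of global steps in one pass over the prompts (assumes an even split)."""
    if global_steps % num_episodes != 0:
        raise ValueError("global_steps is not a multiple of num_episodes")
    return global_steps // num_episodes


def prompt_count_bounds(steps: int, rollout_batch_size: int) -> tuple[int, int]:
    """Smallest and largest prompt counts that yield `steps` full batches."""
    low = steps * rollout_batch_size
    high = (steps + 1) * rollout_batch_size - 1
    return low, high


if __name__ == "__main__":
    steps = steps_per_episode(320, 20)
    print("steps per episode:", steps)
    print("implied prompts per episode:", prompt_count_bounds(steps, 128))
    # 16 evenly spaced eval passes over 320 steps.
    print("implied eval interval:", 320 // 16)
```

If the assumptions hold, the run used 16 steps per episode and evaluated every 20 steps, which is consistent with the best checkpoint at `train_step=260` and the case-study snapshots at steps 80 and 320 all being multiples of 20.
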
diff --git a/examples/orm_rl_demo/assets/exp_20260417/optimization_dashboard.png b/examples/orm_rl_demo/assets/exp_20260417/optimization_dashboard.png
new file mode 100644
index 00000000..12dc06cf
Binary files /dev/null and b/examples/orm_rl_demo/assets/exp_20260417/optimization_dashboard.png differ
diff --git a/examples/orm_rl_demo/assets/exp_20260417/question_a_step320.png b/examples/orm_rl_demo/assets/exp_20260417/question_a_step320.png
new file mode 100644
index 00000000..b4ffb426
Binary files /dev/null and b/examples/orm_rl_demo/assets/exp_20260417/question_a_step320.png differ
diff --git a/examples/orm_rl_demo/assets/exp_20260417/question_a_step80.png b/examples/orm_rl_demo/assets/exp_20260417/question_a_step80.png
new file mode 100644
index 00000000..edf199a9
Binary files /dev/null and b/examples/orm_rl_demo/assets/exp_20260417/question_a_step80.png differ
diff --git a/examples/orm_rl_demo/assets/exp_20260417/question_b_step320.png b/examples/orm_rl_demo/assets/exp_20260417/question_b_step320.png
new file mode 100644
index 00000000..5f1c83f7
Binary files /dev/null and b/examples/orm_rl_demo/assets/exp_20260417/question_b_step320.png differ
diff --git a/examples/orm_rl_demo/assets/exp_20260417/question_b_step80.png b/examples/orm_rl_demo/assets/exp_20260417/question_b_step80.png
new file mode 100644
index 00000000..5eba3e72
Binary files /dev/null and b/examples/orm_rl_demo/assets/exp_20260417/question_b_step80.png differ
diff --git a/examples/orm_rl_demo/assets/exp_20260417/reward_dashboard.png b/examples/orm_rl_demo/assets/exp_20260417/reward_dashboard.png
new file mode 100644
index 00000000..2b2d8e9a
Binary files /dev/null and b/examples/orm_rl_demo/assets/exp_20260417/reward_dashboard.png differ
diff --git a/examples/orm_rl_demo/assets/exp_20260417/summary_card.png b/examples/orm_rl_demo/assets/exp_20260417/summary_card.png
new file mode 100644
index 00000000..43e69a06
Binary files /dev/null and b/examples/orm_rl_demo/assets/exp_20260417/summary_card.png differ
diff --git a/examples/orm_rl_demo/reward_models.py b/examples/orm_rl_demo/reward_models.py
new file mode 100755
index 00000000..dc671cf4
--- /dev/null
+++ b/examples/orm_rl_demo/reward_models.py
@@ -0,0 +1,1087 @@
+"""
+General reward model helpers for the ORM RL Geo3K demo.
+
+This example keeps only the general outcome reward-model path that is exercised by
+`examples/orm_rl_demo/train_colocate.py` and `examples/orm_rl_demo/test_reward_models.py`.
+The shared helper functions below are used by that general reward model for both
+HuggingFace and engine-based inference.
+"""
+from __future__ import annotations
+
+from typing import Optional, List, Tuple
+import re
+import copy
+import os
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from transformers import LogitsProcessor
+from itertools import zip_longest
+
+from lightrft.utils import get_current_device
+from lightrft.strategy.utils.distributed_util import gather_inputs_object_for_inference
+from lightrft.strategy import is_engine
+
+
+# ============================================================================
+# Utility Functions
+# ============================================================================
+
+def is_chinese(text):
+ """
+ Detect whether text contains Chinese characters.
+
+ :param text: Text string to detect
+ :type text: str
+ :return: True if text contains Chinese characters, False otherwise
+ :rtype: bool
+ """
+ if not isinstance(text, str):
+ return False
+ chinese_pattern = re.compile(r'[\u4e00-\u9fff]')
+ return bool(chinese_pattern.search(text))
+
+
+def _pack_engine_inputs(
+ prompts: list[str],
+ image_data: list[list] | None,
+) -> tuple[list[str], list[list] | None]:
+ """
+ Pack engine inputs ensuring prompts and image_data have consistent lengths.
+
+ Returns None for image_data when all images are empty to skip redundant parameters.
+
+ :param prompts: List of text prompts
+ :type prompts: list[str]
+ :param image_data: List of image data, each element is a list of images
+ :type image_data: list[list] or None
+ :return: Processed (prompts, image_data or None)
+ :rtype: tuple
+ """
+ if image_data is None:
+ return prompts, None
+
+ fixed_prompts, fixed_images = [], []
+ for p, imgs in zip(prompts, image_data):
+ if "<|image_pad|>" in p:
+ fixed_prompts.append(p)
+ fixed_images.append(imgs[:1] or [None]) # at least one placeholder
+ else:
+ fixed_prompts.append(p)
+ fixed_images.append([])
+
+ assert len(fixed_prompts) == len(fixed_images)
+
+ if all(len(imgs) == 0 for imgs in fixed_images):
+ fixed_images = None
+
+ return fixed_prompts, fixed_images
+
+
+def _align_prompts_images(
+    prompts: list[str],
+    image_data: list[list] | None,
+) -> tuple[list[str], list[int], list[str], list[list]]:
+    """
+    Align prompts and images, separating text-only and multimodal data.
+
+    Prompts containing ``<|image_pad|>`` must have at least one placeholder image;
+    prompts without placeholders must have no images.
+
+    :param prompts: List of text prompts
+    :type prompts: list[str]
+    :param image_data: List of image data (None if no images)
+    :type image_data: list[list] or None
+    :return: (text_prompts, text_indices, mm_prompts, mm_images)
+    :rtype: tuple
+    """
+    if image_data is None:  # No images passed at all: treat every prompt as text-only
+        # Return the same 4-tuple shape as the main path so callers can always unpack it.
+        return prompts, list(range(len(prompts))), [], []
+    text_prompts = []
+    mm_prompts, mm_images = [], []
+    text_inds = []
+
+    ind = 0
+    for p, imgs in zip_longest(prompts, image_data, fillvalue=None):
+        if p is None:  # Extra images → discard
+            continue
+
+        imgs = [] if imgs is None else imgs  # Ensure imgs is a list
+        if "<|image_pad|>" in p:  # Must keep 1 placeholder
+            imgs = imgs[:1] or [None]
+            if isinstance(imgs[0], list):
+                imgs = imgs[0]
+            mm_images.append(imgs)
+            mm_prompts.append(p)
+        else:  # Pure text prompt cannot have images
+            text_prompts.append(p)
+            text_inds.append(ind)
+
+        ind += 1
+
+    return text_prompts, text_inds, mm_prompts, mm_images
+
+
+def _hf_or_engine_generate(
+ model,
+ *,
+ input_ids : torch.Tensor | None = None,
+ attention_mask : torch.Tensor | None = None,
+ pixel_values : torch.Tensor | None = None,
+ image_grid_thw : torch.Tensor | None = None,
+ prompts : List[str] | None = None,
+ image_data : List[List] | None = None,
+ **gen_kwargs,
+) -> Tuple[List[str], torch.Tensor | None]:
+ """
+    Unified generation interface supporting HuggingFace models and SGLang/vLLM engines.
+
+    Automatically detects model type. Engine mode uses string prompts and image_data;
+    HF mode uses tensor inputs (input_ids, pixel_values, etc.).
+
+    :param model: HF model, or an SGLang or vLLM engine instance
+ :param input_ids: Input token IDs for HF mode
+ :type input_ids: torch.Tensor or None
+ :param attention_mask: Attention mask for HF mode
+ :type attention_mask: torch.Tensor or None
+ :param pixel_values: Image pixel values for HF mode
+ :type pixel_values: torch.Tensor or None
+ :param image_grid_thw: Image grid size for HF mode
+ :type image_grid_thw: torch.Tensor or None
+ :param prompts: Text prompts for Engine mode
+ :type prompts: list[str] or None
+ :param image_data: Image data for Engine mode
+ :type image_data: list[list] or None
+ :param gen_kwargs: Generation parameters (max_new_tokens, temperature, etc.)
+ :return: (list of generated texts, generated token IDs or None)
+ :rtype: tuple
+ """
+ if is_engine(model):
+ assert input_ids is None, "Cannot pass input_ids in engine mode"
+ enable_sleep_mode = True
+ if hasattr(model, "llm_engine"):
+ enable_sleep_mode = model.llm_engine.vllm_config.model_config.enable_sleep_mode
+
+ if enable_sleep_mode:
+ model.wake_up()
+
+ if hasattr(model, "tp_group_cpu"):
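+            # Drop the HF-only "do_sample" flag. The original `k not in ("do_sample")`
+            # was a substring test against a plain string, not a tuple membership test.
+            sampling_params = {
+                k: v for k, v in gen_kwargs.items() if k != "do_sample"
+            }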
+
+ prompt_and_output = gather_inputs_object_for_inference(prompts, model.tp_group_cpu)
+ image_data = gather_inputs_object_for_inference(image_data, model.tp_group_cpu)
+
+ text_prompts, text_inds, mm_prompts, mm_images = _align_prompts_images(prompt_and_output, image_data)
+ text_output = []
+ mm_output = []
+
+ if len(text_prompts) > 0:
+ sgl_outputs = model.generate(prompt=text_prompts, sampling_params=sampling_params, gather_inputs=False)
+ text_output = [sgl_out["text"] for sgl_out in sgl_outputs]
+
+ if len(mm_prompts) > 0:
+ sgl_outputs = model.generate(
+ prompt=mm_prompts,
+ image_data=mm_images,
+ sampling_params=sampling_params,
+ gather_inputs=False,
+ )
+ mm_output = [sgl_out["text"] for sgl_out in sgl_outputs]
+
+ texts = []
+ text_output_iter = iter(text_output)
+ mm_output_iter = iter(mm_output)
+ # merge results in original order
+ if len(text_inds) > 0:
+ for i in range(len(prompt_and_output)):
+ if i in text_inds:
+ texts.append(next(text_output_iter))
+ else:
+ texts.append(next(mm_output_iter))
+ else:
+ texts = mm_output
+
+ if model._tp_size > 1:
+ num_per_rank = len(texts) // model._tp_size
+ texts = texts[model._tp_rank * num_per_rank : (model._tp_rank + 1) * num_per_rank]
+ else:
+ from vllm import SamplingParams
+
+ sampling_kwargs = dict(gen_kwargs)
+ max_tokens = sampling_kwargs.pop("max_new_tokens", None)
+ if max_tokens is not None:
+ sampling_kwargs["max_tokens"] = max_tokens
+ sampling_kwargs.pop("do_sample", None)
+ sampling_params = SamplingParams(**sampling_kwargs)
+
+ prompt_and_output = prompts or []
+ prompt_and_output, image_data = _pack_engine_inputs(
+ prompt_and_output,
+ image_data,
+ )
+ if image_data is None:
+ image_data = [None] * len(prompt_and_output)
+
+ vllm_prompts = []
+ for prompt, imgs in zip(prompt_and_output, image_data):
+ prompt_item = {"prompt": prompt}
+ if imgs:
+ prompt_item["multi_modal_data"] = {
+ "image": imgs[0] if len(imgs) == 1 else imgs
+ }
+ vllm_prompts.append(prompt_item)
+
+ vllm_outputs = model.generate(
+ vllm_prompts,
+ sampling_params=sampling_params,
+ use_tqdm=False,
+ )
+ texts = [
+ out.outputs[0].text if getattr(out, "outputs", None) else ""
+ for out in vllm_outputs
+ ]
+
+ if dist.is_initialized() and dist.get_rank() == 0:
+ if not texts or all(not t for t in texts):
+ print("WARNING: _hf_or_engine_generate produced empty output for all prompts.")
+
+ if enable_sleep_mode:
+ model.sleep()
+ torch.cuda.empty_cache()
+ return texts, None
+
+ else:
+ gen_ids = model.generate(
+ input_ids = input_ids,
+ attention_mask = attention_mask,
+ pixel_values = pixel_values,
+ image_grid_thw = image_grid_thw,
+ **gen_kwargs,
+ )
+ trim = [o[len(i):] for i, o in zip(input_ids, gen_ids)]
+ return trim, trim
+
+
+# ============================================================================
+# Vision Token Processing
+# ============================================================================
+
+_VISION_RE = re.compile(r"<\|vision_start\|>.*?<\|vision_end\|>", re.S)
+
+def _strip_vision_tokens(text: str) -> str:
+    """Remove vision token markers from text."""
+    return re.sub(_VISION_RE, "", text).replace("<image>", "").strip()
+
+
+def _clean_vision_token(text: str) -> str:
+    """
+    Clean vision tokens from text, supporting multiple formats.
+
+    Supported formats:
+    - <|vision_start|><|image_pad|>...<|vision_end|>
+    - <img><IMG_CONTEXT>...</img>
+    - <image>
+    """
+    patterns = [
+        r"<\|vision_start\|>(<\|image_pad\|>)+<\|vision_end\|>",
+        r"<img>(<IMG_CONTEXT>)+</img>",
+        r"<image>",
+    ]
+    for p in patterns:
+        text = re.sub(p, "", text)
+    return text
+
+
+def _replace_vision_token(text: str) -> str:
+    """
+    Replace vision tokens with standard markers.
+
+    Conversion rules:
+    - <|vision_start|>...<|vision_end|> -> <image>
+    - <img><IMG_CONTEXT>...</img> -> <image> (internvl format)
+    """
+    text = re.sub(r"<\|vision_start\|>(<\|image_pad\|>)+<\|vision_end\|>", "<image>", text)
+    text = re.sub(r"<img>(<IMG_CONTEXT>)+</img>", "<image>", text)  # internvl
+
+    return text
+
+
+def _strip_pad_eos(text: str, pad: str, eos: str) -> str:
+ """
+ Remove leading and trailing pad and eos tokens from text.
+
+ :param text: Text to process
+ :type text: str
+ :param pad: Pad token string
+ :type pad: str
+ :param eos: EOS token string
+ :type eos: str
+ :return: Cleaned text
+ :rtype: str
+ """
+ pad, eos = map(re.escape, (pad, eos))
+ text = re.sub(f"^({eos}|{pad})+", "", text)
+ text = re.sub(f"({eos}|{pad})+$", "", text)
+ return text
+
+# ============================================================================
+# Dialog Parsing Constants and Functions
+# ============================================================================
+
+# Define constants for vertical bars used in role tags for better readability
+FULL_BAR = "|" # U+FF5C Full-width vertical bar
+HALF_BAR = "|" # U+007C ASCII vertical bar
+
+def _parse_dialog(text: str) -> dict:
+ """
+ Parse a full conversation string into a dictionary mapping roles to their content.
+
+ Identifies role tags like ``<| role_name |>`` and extracts the text that follows
+ each tag. If a role appears multiple times, only the last occurrence is kept.
+
+ :param text: Conversation string with role tags
+ :type text: str
+ :return: Dict mapping role names to their message content
+ :rtype: dict
+ """
+ # 1. Define the regex pattern to find all possible role tags.
+ # The pattern is written in verbose mode (re.X) for clarity.
+ tag_pattern = re.compile(
+ rf"""
+ < # Match the opening '<'
+ [{HALF_BAR}{FULL_BAR}] # Match either a half-width or full-width vertical bar
+ \s*? # Match any whitespace characters (non-greedy)
+ (.*?) # Capture the role name (non-greedy)
+ \s*? # Match any whitespace characters (non-greedy)
+ [{HALF_BAR}{FULL_BAR}] # Match either a half-width or full-width vertical bar
+ > # Match the closing '>'
+ """, re.X | re.S
+ )
+
+ # Find all occurrences of role tags in the text.
+ tags = list(tag_pattern.finditer(text))
+ dialog = {}
+
+ # 2. Iterate through the found tags to extract roles and content.
+ for idx, tag in enumerate(tags):
+ # Extract the role name and normalize it by stripping whitespace and converting to lowercase.
+ raw_role = tag.group(1).strip()
+ role = raw_role.lower()
+
+ # Skip special meta-tags that define structure but are not roles.
+ if role in {"im_start", "im_end", "begin of sentence", "end of sentence"}:
+ continue
+
+ # Determine the start and end positions of the content for the current role.
+ # The content starts right after the current tag.
+ start_pos = tag.end()
+ # The content ends right before the next tag starts, or at the end of the text.
+ end_pos = tags[idx + 1].start() if idx + 1 < len(tags) else len(text)
+ content = text[start_pos:end_pos].strip()
+
+        # 3. Special handling for the 'assistant' role to remove the chain-of-thought block.
+        # If the content contains <think>...</think>, we extract only the final response
+        # that appears after the last </think> tag.
+        if role == "assistant" and "<think>" in content and "</think>" in content:
+            think_end = content.rfind("</think>")
+            if think_end != -1:
+                content = content[think_end + len("</think>"):].strip()
+
+ # Store the role and its content in the dictionary.
+ # If the role already exists, its value will be updated with the new content.
+ dialog[role] = content
+
+ return dialog
+
+def preprocess_inputs_sglang(
+ prompt_and_outputs: list,
+ references: list,
+    question_response_format_zh: list | str,
+ question_response_format_en: str,
+ system_prompt_zh: str = None,
+ system_prompt_en: str = None,
+ system_prompt: bool = False,
+):
+ """
+ Preprocess batch conversation inputs for SGLang engine.
+
+ Parses conversation text, selects a format template based on detected language,
+ and optionally prepends a system prompt.
+
+ :param prompt_and_outputs: List of conversation texts
+ :type prompt_and_outputs: list
+ :param references: List of reference answers
+ :type references: list
+ :param question_response_format_zh: Chinese format template (string or per-sample list)
+ :type question_response_format_zh: str or list
+ :param question_response_format_en: English format template
+ :type question_response_format_en: str
+ :param system_prompt_zh: Chinese system prompt
+ :type system_prompt_zh: str or None
+ :param system_prompt_en: English system prompt
+ :type system_prompt_en: str or None
+ :param system_prompt: Whether to prepend a system prompt
+ :type system_prompt: bool
+ :return: List of formatted texts ready for model input
+ :rtype: list
+ """
+ raw_texts = []
+ # Process each conversation in the batch.
+ for i, po in enumerate(prompt_and_outputs):
+ # Parse the conversation string into a role-content dictionary.
+ dialog = _parse_dialog(po)
+
+ # --- Step 1: Extract the question ---
+ if "user" in dialog:
+ question_raw = dialog["user"]
+ else:
+ # Fallback logic: if 'user' role is not found, use the content from the
+ # first role that is not 'assistant'. If no such role exists,
+ # use the entire original string as the question.
+ question_raw = next(
+ (txt for role, txt in dialog.items() if role != "assistant"), po
+ )
+        # Clean the extracted question (e.g., remove special vision tokens)
+        # with the module-level _clean_vision_token helper.
+        question = _clean_vision_token(question_raw)
+
+ # --- Step 2: Extract the response ---
+ if "assistant" in dialog:
+ response = dialog["assistant"]
+ else:
+            # Fallback logic: if 'assistant' role is not found, assume the response
+            # is the text following the last </think> tag.
+            response = po.split("</think>")[-1].strip()
+
+ reference = references[i]
+
+        # --- Step 3: Select the appropriate formatting template ---
+        # is_chinese is the module-level language-detection helper.
+        is_zh = is_chinese(question)
+ if isinstance(question_response_format_zh, list):
+ # New feature: Use a custom template for each item in the batch.
+ fmt = question_response_format_zh[i]
+ else:
+ # Old logic: Choose the template based on the detected language.
+ fmt = question_response_format_zh if is_zh else question_response_format_en
+
+ # --- Step 4: Format the final input string ---
+ # The template may or may not include a placeholder for the reference text.
+ if "{reference}" in fmt:
+ raw_text = fmt.format(
+ question=question,
+ reference=reference,
+ response=response
+ )
+ else:
+ raw_text = fmt.format(question=question, response=response)
+
+ # --- Step 5: Prepend a system prompt if enabled ---
+ if system_prompt:
+ # Select the system prompt based on the language.
+ system_prompt_text = system_prompt_zh if is_zh else system_prompt_en
+ # Using deepcopy to avoid modifying the original system prompt object.
+ final_text = copy.deepcopy(system_prompt_text) + "\n" + raw_text
+ raw_texts.append(final_text)
+ else:
+ raw_texts.append(raw_text)
+
+ return raw_texts
+
+
+def build_general_engine_queries(
+ processor,
+ prompt_and_outputs: list,
+ references: list,
+ raw_images: list | None,
+ question_response_format_zh: str,
+ question_response_format_en: str,
+ system_prompt_zh: str,
+ system_prompt_en: str,
+):
+ """
+ Build general-RM engine prompts using the model's chat template.
+
+ The vLLM engine path for Qwen2.5-VL is much more stable when we append the
+ assistant generation prompt explicitly. Without this, the engine often
+ returns empty strings or prompt-continuation fragments instead of verdicts.
+ """
+ test_data = []
+ expected_image_counts = []
+ normalized_image_data = []
+
+ if raw_images is None:
+ raw_images = [None] * len(prompt_and_outputs)
+
+ for i, prompt_and_output in enumerate(prompt_and_outputs):
+ dialog = _parse_dialog(prompt_and_output)
+
+ if "user" in dialog:
+ question_raw = dialog["user"]
+ else:
+ question_raw = next(
+ (txt for role, txt in dialog.items() if role != "assistant"),
+ prompt_and_output,
+ )
+
+ if "assistant" in dialog:
+ response = dialog["assistant"]
+ else:
+            response = prompt_and_output.split("</think>")[-1].strip()
+
+ question = _clean_vision_token(question_raw)
+ reference = references[i] if references is not None and i < len(references) else ""
+ is_zh = is_chinese(question)
+ fmt = question_response_format_zh if is_zh else question_response_format_en
+ system_prompt = system_prompt_zh if is_zh else system_prompt_en
+ user_text = fmt.format(question=question, response=response, reference=reference)
+
+ raw_image = raw_images[i] if i < len(raw_images) else None
+ has_image = raw_image is not None
+ expected_image_counts.append(1 if has_image else 0)
+ normalized_image_data.append([raw_image] if has_image else [])
+
+ user_content = [{"type": "text", "text": user_text}]
+ if has_image:
+ user_content = [
+ {
+ "type": "image",
+ "image": [],
+ "min_pixels": 224 * 224,
+ "max_pixels": 1280 * 1280,
+ },
+ {"type": "text", "text": user_text},
+ ]
+
+ test_data.append(
+ [
+ {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
+ {"role": "user", "content": user_content},
+ ]
+ )
+
+ queries = processor.apply_chat_template(
+ test_data,
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ if isinstance(queries, str):
+ queries = [queries]
+
+ fixed_queries = []
+ for query, expected_image_count in zip(queries, expected_image_counts):
+ query_image_token_count = query.count("<|image_pad|>")
+ if query_image_token_count > expected_image_count:
+ excess_tokens = query_image_token_count - expected_image_count
+ query = query.replace("<|image_pad|>", "", excess_tokens)
+ fixed_queries.append(query)
+
+ return fixed_queries, normalized_image_data
+
+
+def preprocess_inputs(
+ tokenizer = None,
+ processor = None,
+ device = get_current_device(),
+ system_prompt: Optional[str] = None,
+ question_response_format: str = "",
+ input_ids: Optional[torch.Tensor] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pad_token: str = "",
+ eos_token: str = "<|endoftext|>",
+ clean_or_replace_vision_token: bool = False,
+ vision_token_process_type: str = 'clean',
+ padding_side: str = "left",
+ return_think_content: bool = False,
+ debug: bool = False,
+ queries: Optional[list] = None,
+ return_raw_texts: bool = False,
+):
+ """
+ Preprocess inputs for HuggingFace models.
+
+ Supports building inputs from ``input_ids`` or ``queries``, optional vision-token
+ processing, and chain-of-thought content separation.
+
+ :param tokenizer: HF tokenizer instance
+ :param processor: HF processor instance
+ :param device: Target device
+ :param system_prompt: System prompt (optional; use to distinguish value/knowledge from safety/normal data)
+ :type system_prompt: str or None
+ :param question_response_format: Q&A format template
+ :type question_response_format: str
+ :param input_ids: Input token IDs
+ :type input_ids: torch.Tensor or None
+ :param pixel_values: Image pixel values
+ :type pixel_values: torch.Tensor or None
+ :param pad_token: Padding token
+ :type pad_token: str
+ :param eos_token: End-of-sequence token
+ :type eos_token: str
+ :param clean_or_replace_vision_token: Whether to process vision tokens
+ :type clean_or_replace_vision_token: bool
+ :param vision_token_process_type: Processing method (``'clean'`` or ``'replace'``)
+ :type vision_token_process_type: str
+ :param padding_side: Padding direction
+ :type padding_side: str
+ :param return_think_content: Whether to separate chain-of-thought content
+ :type return_think_content: bool
+ :param debug: Debug mode
+ :type debug: bool
+ :param queries: List of query texts
+ :type queries: list or None
+ :param return_raw_texts: Whether to return raw texts instead of tensors
+ :type return_raw_texts: bool
+ :return: Standard mode returns ``(input_ids, attention_mask, response_empty)``;
+ CoT mode returns ``(answer_input_ids, answer_mask, think_input_ids, think_mask, valid_think, response_empty)``;
+ raw text mode returns ``(raw_texts, ...)``.
+ :rtype: tuple
+ """
+ if input_ids is not None:
+ processor.tokenizer.padding_side = padding_side
+ queries = tokenizer.batch_decode(input_ids, skip_special_tokens=False)
+ else:
+ assert queries is not None
+
+ for i, query in enumerate(queries):
+ if clean_or_replace_vision_token:
+ if vision_token_process_type == 'clean': # value, knowledge
+ queries[i] = _clean_vision_token(query)
+ elif vision_token_process_type == 'replace': # safety, normal
+ queries[i] = _replace_vision_token(query)
+ else:
+ raise KeyError(f"Invalid vision token process type: {vision_token_process_type}")
+ queries[i] = _strip_pad_eos(queries[i], pad_token, eos_token) + eos_token
+
+ # Extract question and response from query using regex
+ pattern = r"<\|im_start\|>(\w+)\n(.*?)<\|im_end\|>"
+    # NOTE: the dialog-parsing logic has not been adapted to DeepSeek models yet
+ def _prepare_message(dialog, test_data, image_token_count_list):
+ question = dialog.get('user', '')
+ response = dialog.get('assistant', '')
+ image_token_count_list.append(question.count('<|image_pad|>'))
+ if system_prompt is not None:
+ test_data.append(
+ [
+ {"role": "system", "content":[{"type": "text", "text": system_prompt}]},
+ {"role": "user", "content": [{"type": "image", "image": [], "min_pixels": 224 * 224, "max_pixels": 1280 * 1280}, {"type": "text", "text": question_response_format.format(question=question, response=response)}]}
+ ]
+ )
+ else:
+ test_data.append(
+ [
+ {"role": "user", "content": [{"type": "text", "text": question_response_format.format(question=question, response=response)}]}
+ ]
+ )
+ if debug and dist.is_initialized() and dist.get_rank() == 0:
+ print(f"test_data:\n {test_data[0]}\n")
+
+ # Process all queries in the batch at once
+ test_data, image_token_count_list = [], []
+ think_test_data, think_image_token_count_list, valid_think = [], [], []
+ response_empty = []
+ for query in queries:
+ matches = re.findall(pattern, query, re.DOTALL)
+ dialog = {}
+ if return_think_content:
+ think_dialog = {}
+ valid_think_flag = False
+ for role, content in matches:
+ dialog[role] = content.strip()
+ if return_think_content:
+ think_dialog[role] = content.strip()
+ # If the assistant's reply contains chain-of-thought content wrapped in <think> and </think>, keep only the content after </think>
+ if role == "assistant" and "<think>" in content and "</think>" in content:
+ # Find the position of the last </think>
+ think_end_pos = content.rfind("</think>")
+ if think_end_pos != -1:
+ # Extract content after </think> and remove leading/trailing whitespace
+ dialog[role] = content[think_end_pos + len("</think>"):].strip()
+ if return_think_content:
+ think_dialog[role] = content[:think_end_pos + len("</think>") + 1].strip()
+ valid_think_flag = True
+
+ _prepare_message(dialog, test_data, image_token_count_list)
+ response_empty.append(dialog.get('assistant', '') == '')
+ if return_think_content:
+ valid_think.append(valid_think_flag)
+ _prepare_message(think_dialog, think_test_data, think_image_token_count_list)
+
+ def _get_batch_input(test_data, image_token_count_list, return_raw_texts):
+ # Process the entire batch at once
+ if system_prompt is not None:
+ # Only apply chat template when system prompt is provided
+ queries = processor.apply_chat_template(test_data, tokenize=False, add_generation_prompt=False)
+ else:
+ # For data without system prompt, format directly without applying chat template
+ queries = [item[0]["content"][0]["text"] for item in test_data]
+
+ # TODO: `apply_chat_template` adds an extra image token to the query, so we remove it here; find a more elegant fix
+ for i, query in enumerate(queries):
+ query_image_token_count = query.count('<|image_pad|>')
+ if query_image_token_count > image_token_count_list[i]:
+ # Replace all excess image tokens to match the expected count
+ excess_tokens = query_image_token_count - image_token_count_list[i]
+ queries[i] = query.replace('<|image_pad|>', '', excess_tokens)
+
+ if not return_raw_texts:
+ with torch.no_grad():
+ batch_inputs = processor(
+ text=queries,
+ padding=True,
+ return_tensors="pt",
+ ).to(device)
+ return batch_inputs
+ else:
+ return queries
+
+ answer_batch_input = _get_batch_input(test_data, image_token_count_list, return_raw_texts)
+ if return_think_content:
+ think_batch_input = _get_batch_input(think_test_data, think_image_token_count_list, return_raw_texts)
+ if not return_raw_texts:
+ return answer_batch_input['input_ids'], answer_batch_input['attention_mask'], think_batch_input['input_ids'], think_batch_input['attention_mask'], valid_think, response_empty
+ else:
+ return answer_batch_input, think_batch_input, valid_think
+ else:
+ if not return_raw_texts:
+ return answer_batch_input['input_ids'], answer_batch_input['attention_mask'], response_empty
+ else:
+ return answer_batch_input
+
+
+ if engine._tp_size > 1:
+ num_per_rank = len(texts) // engine._tp_size
+ texts = texts[engine._tp_rank * num_per_rank : (engine._tp_rank+1) * num_per_rank]
+
+ return texts
+
+
+# ============================================================================
+# General Reward Model
+# ============================================================================
+
+class AllowedTokensLogitsProcessor(LogitsProcessor):
+ def __init__(self, allowed_token_ids):
+ self.allowed_token_ids = set(allowed_token_ids)
+
+ def __call__(self, input_ids, scores):
+ # Mask every token outside the allowed set with -inf; allowed tokens keep their scores
+ mask = torch.full_like(scores, float('-inf'))
+ for token_id in self.allowed_token_ids:
+ mask[:, token_id] = 0
+ return scores + mask
+
+
+class Qwen2VLRewardModelGeneral(nn.Module):
+ """
+ General quality reward model that evaluates answer correctness based on reference answers.
+
+ Scoring rules:
+
+ - ``1.0``: Completely correct (all sub-questions correct)
+ - ``0.5``: Partially correct (at least one sub-question correct, but not all)
+ - ``0.0``: Incorrect (all sub-questions wrong or answer irrelevant)
+
+ :param base_model: HF model or Engine instance
+ :param tokenizer: Tokenizer instance
+ :param processor: Processor instance
+ :param text_only: Whether to use text-only mode (no image inputs)
+ :type text_only: bool
+ """
+
+ general_scores = [0.0, 0.5, 1.0]
+ general_system_prompt_zh = """你是一个评分专家,负责根据参考答案reference评估assistant对user的回复是否正确且合理。
+ **你将收到包含以下XML标签的内容:``表示用户的问题,``表示助手的回答,``表示参考答案。**
+ 请严格按以下规则输出固定稀疏奖励:
+
+ 评估规则:
+ 1. 答案等价性:
+ - 简洁答案和带解题步骤的答案都接受,只要包含正确答案
+ - 答案可能出现在回答的开头、中间或结尾
+ - 只比较核心答案,忽略解释部分
+
+ 2. 数值等价性:
+ - 不同格式的数字视为等价(如2,"2",['2'],"答案是2")
+ - 百分比可以用小数或%表示(如28%=0.28)
+ - 带/不带逗号的数字视为等价(如123,456.7=123456.7)
+
+ 3. 格式灵活性:
+ - 列表、引号、表格或纯文本中的正确答案都接受
+ - 正确答案周围的额外解释或格式不影响评分
+ - 大小写不敏感
+
+ 4. 多参考答案情况:
+ - 参考答案有多个可接受答案时,匹配一个即可视为该部分正确。
+
+ 5. 多子问题情况:
+ - 如果问题包含多个子问题,需要逐一评估assistant对每个子问题的回答。
+ - 只有当所有子问题都回答正确时,总分才为 1.0。
+ - 如果至少有一个子问题回答正确,但并非所有子问题都正确,则总分为 0.5。
+ - 如果所有子问题都回答错误或回答与问题无关,则总分为 0.0。
+
+ 6. 容错性:
+ - 轻微拼写错误或措辞差异不影响评分
+ - 等价数学表达式视为正确
+
+ 输出要求:
+ 1. **仅允许输出以下三个数值之一:0.0、0.5、1.0**
+ 2. 根据参考答案与回答的匹配程度选择:
+ - 完全正确 (所有子问题均正确) → 1.0
+ - 部分正确 (至少答对一个子问题,但非全部) → 0.5
+ - 错误 (所有子问题均错误或回答与问题无关) → 0.0
+ 3. 直接输出数值,不需要任何解释"""
+
+ question_response_format_zh = """请根据以下内容进行评估:
+
+
+ {question}
+
+
+
+
+ {response}
+
+
+
+ {reference}
+ """
+
+ general_system_prompt_en = """You are a scoring expert responsible for evaluating whether the assistant's response to the user is correct and reasonable based on the reference answer.
+ **You will receive content with the following XML tags: `` represents the user's question, `` represents the assistant's answer, and `` represents the reference answer.**
+ Please strictly output fixed sparse rewards according to the following rules:
+
+ Evaluation Rules:
+ 1. Answer Equivalence:
+ - Both concise answers and answers with solution steps are accepted, as long as they contain the correct answer
+ - The answer may appear at the beginning, middle, or end of the response
+ - Only compare core answers, ignore explanation parts
+
+ 2. Numerical Equivalence:
+ - Numbers in different formats are considered equivalent (e.g., 2, "2", ['2'], "the answer is 2")
+ - Percentages can be expressed as decimals or % (e.g., 28% = 0.28)
+ - Numbers with/without commas are equivalent (e.g., 123,456.7 = 123456.7)
+
+ 3. Format Flexibility:
+ - Correct answers in lists, quotes, tables, or plain text are all accepted
+ - Additional explanations or formatting around the correct answer do not affect scoring
+ - Case insensitive
+
+ 4. Multiple Reference Answers:
+ - When there are multiple acceptable reference answers, matching any one is considered correct for that part.
+
+ 5. Multiple Sub-questions:
+ - If the question contains multiple sub-questions, evaluate the assistant's answer for each sub-question.
+ - Only when all sub-questions are answered correctly will the total score be 1.0.
+ - If at least one sub-question is answered correctly, but not all sub-questions are correct, the total score is 0.5.
+ - If all sub-questions are answered incorrectly or the answer is irrelevant to the question, the total score is 0.0.
+
+ 6. Error Tolerance:
+ - Minor spelling errors or wording differences do not affect scoring
+ - Equivalent mathematical expressions are considered correct
+
+ Output Requirements:
+ 1. **Only the following three values are allowed: 0.0, 0.5, 1.0**
+ 2. Choose based on the degree of match between the reference answer and the response:
+ - Completely correct (all sub-questions correct) → 1.0
+ - Partially correct (at least one sub-question correct, but not all) → 0.5
+ - Incorrect (all sub-questions incorrect or answer irrelevant to question) → 0.0
+ 3. Output the value (0.0, 0.5, 1.0) directly, no explanation needed"""
+
+ question_response_format_en = """Please evaluate based on the following content:
+
+
+ {question}
+
+
+
+
+ {response}
+
+
+
+ {reference}
+ """
+
+ ALLOWED_STR_TOKENS = ["0", "1", "0.0", "0.5", "1.0"]
+
+ def __init__(self, base_model, tokenizer, processor, text_only: bool = False):
+ super().__init__()
+ self.base_model: nn.Module = base_model
+ self.tokenizer = tokenizer
+ self.processor = processor
+ self.device = torch.cuda.current_device()
+ self.text_only = text_only
+
+ self._allowed_token_seqs: list[list[int]] = []
+ for s in self.ALLOWED_STR_TOKENS:
+ ids = self.tokenizer.encode(s, add_special_tokens=False)
+ self._allowed_token_seqs.append(ids)
+
+ self._verdict_log_enabled = os.environ.get("ORM_RL_DEMO_RM_VERDICT_LOG", "0") == "1"
+ self._verdict_log_max = int(os.environ.get("ORM_RL_DEMO_RM_VERDICT_LOG_MAX", "128"))
+ self._verdict_log_count = 0
+
+ first_ids = {seq[0] for seq in self._allowed_token_seqs}
+ self._logits_proc = [AllowedTokensLogitsProcessor(first_ids)]
+ self._max_answer_len = max(len(x) for x in self._allowed_token_seqs)
+
+ @torch.no_grad()
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ pixel_values: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.Tensor] = None,
+ references: List[str] | None = None,
+ prompt_and_outputs=None,
+ prompt_and_output=None,
+ raw_images=None,
+ **kwargs, # for compatibility
+ ):
+ """
+ Returns ``{'score': FloatTensor[B]}``; each score is snapped to 0.0, 0.5, or 1.0.
+ """
+ if prompt_and_outputs is None:
+ prompt_and_outputs = prompt_and_output
+ if prompt_and_outputs is None:
+ prompt_and_outputs = kwargs.get("prompt_and_output")
+ if prompt_and_outputs is None:
+ raise ValueError("`prompt_and_outputs` or `prompt_and_output` is required")
+
+ raw_texts = preprocess_inputs_sglang(
+ prompt_and_outputs,
+ references,
+ self.question_response_format_zh,
+ self.question_response_format_en,
+ self.general_system_prompt_zh,
+ self.general_system_prompt_en,
+ system_prompt=True,
+ )
+
+ if is_engine(self.base_model):
+ raw_texts, raw_images = build_general_engine_queries(
+ self.processor,
+ prompt_and_outputs,
+ references,
+ raw_images,
+ self.question_response_format_zh,
+ self.question_response_format_en,
+ self.general_system_prompt_zh,
+ self.general_system_prompt_en,
+ )
+ gen_texts, _ = _hf_or_engine_generate(
+ self.base_model,
+ prompts=raw_texts,
+ image_data=raw_images,
+ max_new_tokens=4,
+ temperature=0.0,
+ )
+ else:
+ model_in = self.processor(
+ text=raw_texts, padding=True, return_tensors="pt"
+ ).to(self.device)
+ _, gen_ids = _hf_or_engine_generate(
+ self.base_model,
+ input_ids=model_in["input_ids"],
+ attention_mask=model_in["attention_mask"],
+ pixel_values=None,
+ image_grid_thw=None,
+ max_new_tokens=self._max_answer_len,
+ temperature=0.0,
+ do_sample=False,
+ logits_processor=self._logits_proc,
+ )
+ gen_texts = self.tokenizer.batch_decode(
+ gen_ids, skip_special_tokens=True
+ )
+
+ log_on_rank0 = (not dist.is_initialized()) or dist.get_rank() == 0
+
+ def _log_verdict_detail(tag: str, sample_idx: int, raw_text: str, **fields) -> None:
+ if not (self._verdict_log_enabled and log_on_rank0):
+ return
+ if self._verdict_log_count >= self._verdict_log_max:
+ return
+ raw_text = raw_text if isinstance(raw_text, str) else str(raw_text)
+ preview = " ".join(raw_text.split())
+ if len(preview) > 200:
+ preview = preview[:200] + "..."
+ extras = " ".join(f"{key}={value}" for key, value in fields.items())
+ print(
+ f"[ORM_RM_GENERAL_VERDICT_{tag}] "
+ f"sample_idx={sample_idx} text_len={len(raw_text)} raw={preview!r} {extras}".rstrip(),
+ flush=True,
+ )
+ self._verdict_log_count += 1
+
+ verdict_summary = {
+ "total": len(gen_texts),
+ "empty": 0,
+ "no_numeric": 0,
+ "value_error": 0,
+ "parsed": 0,
+ "parsed_0": 0,
+ "parsed_0_5": 0,
+ "parsed_1": 0,
+ }
+
+ scores = []
+ for sample_idx, txt in enumerate(gen_texts):
+ txt = txt if isinstance(txt, str) else str(txt)
+ if txt == "":
+ verdict_summary["empty"] += 1
+ scores.append(0.0)
+ _log_verdict_detail("EMPTY", sample_idx, txt, fallback="0.0")
+ continue
+
+ m = re.search(r"[-+]?\d*\.?\d+", txt)
+ if not m:
+ verdict_summary["no_numeric"] += 1
+ scores.append(0.0)
+ _log_verdict_detail("NO_NUMERIC", sample_idx, txt, fallback="0.0")
+ continue
+
+ matched_token = m.group()
+ try:
+ val = float(matched_token)
+ except ValueError as exc:
+ verdict_summary["value_error"] += 1
+ scores.append(0.0)
+ _log_verdict_detail(
+ "VALUE_ERROR",
+ sample_idx,
+ txt,
+ token=repr(matched_token),
+ error=repr(exc),
+ fallback="0.0",
+ )
+ continue
+
+ nearest = min(self.general_scores, key=lambda x: abs(x - val))
+ verdict_summary["parsed"] += 1
+ if nearest == 0.0:
+ verdict_summary["parsed_0"] += 1
+ elif nearest == 0.5:
+ verdict_summary["parsed_0_5"] += 1
+ elif nearest == 1.0:
+ verdict_summary["parsed_1"] += 1
+ _log_verdict_detail(
+ "PARSED",
+ sample_idx,
+ txt,
+ token=repr(matched_token),
+ parsed=repr(val),
+ snapped=repr(nearest),
+ )
+ scores.append(nearest)
+
+ if self._verdict_log_enabled and log_on_rank0:
+ print(
+ "[ORM_RM_GENERAL_VERDICT_SUMMARY] "
+ + " ".join(f"{key}={value}" for key, value in verdict_summary.items()),
+ flush=True,
+ )
+
+ return {"score": torch.tensor(scores, device=self.device)}
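Reviewer note on the verdict parsing in `forward` above: the sketch below isolates the parse-and-snap step so it can be checked without a model. It assumes the same regex and the same `[0.0, 0.5, 1.0]` score set as the patch; the helper name `snap_verdict` and its `fallback` parameter are hypothetical and do not exist in the patch.

```python
import re

# Mirrors `general_scores` in the patch; the tuple name is illustrative only.
ALLOWED_SCORES = (0.0, 0.5, 1.0)


def snap_verdict(text: str, fallback: float = 0.0) -> float:
    """Parse the first number in `text` and snap it to the nearest allowed score.

    Hypothetical helper: empty or non-numeric output yields `fallback`,
    matching the EMPTY / NO_NUMERIC branches of the patch.
    """
    match = re.search(r"[-+]?\d*\.?\d+", text)
    if match is None:
        return fallback
    value = float(match.group())
    # On ties, min() returns the first (lowest) allowed score.
    return min(ALLOWED_SCORES, key=lambda score: abs(score - value))


print([snap_verdict(t) for t in ["1.0", "Score: 0.5", "0.4", "", "n/a"]])
# -> [1.0, 0.5, 0.5, 0.0, 0.0]
```

Snapping means an out-of-range verdict such as `7` is credited as `1.0` rather than rejected, so a stricter variant could return the fallback for values outside `[0, 1]`.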
diff --git a/examples/orm_rl_demo/reward_models_utils.py b/examples/orm_rl_demo/reward_models_utils.py
new file mode 100755
index 00000000..cacb3662
--- /dev/null
+++ b/examples/orm_rl_demo/reward_models_utils.py
@@ -0,0 +1,886 @@
+"""
+General reward model utilities for the ORM RL Geo3K demo.
+
+This example intentionally keeps a single general reward-model path, plus the
+Geo3K-specific reward mixing logic that combines format, general-model, and
+accuracy rewards during training and evaluation.
+"""
+from __future__ import annotations
+
+import re
+import os
+import json
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
+
+import torch
+from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+
+from lightrft.models.monkey_patch.hf_generate_patch import (
+ apply_monkey_patch_to_generation_mixin,
+)
+from lightrft.utils import get_current_device
+
+from reward_models import Qwen2VLRewardModelGeneral
+
+# ============================================================================
+# Configuration Classes
+# ============================================================================
+
+class RewardModelType(str, Enum):
+ """Enumeration of reward model types supported by this demo."""
+ GENERAL = "general"
+
+
+@dataclass
+class RewardModelConfig:
+ """
+ Configuration for a single reward model.
+
+ :param rtype: Reward model type.
+ :type rtype: RewardModelType
+ :param path: Model directory path or HuggingFace model name
+ :type path: str
+ :param use_engine: Whether to use an inference engine (SGLang or vLLM) instead of HuggingFace. Defaults to False
+ :type use_engine: bool
+ """
+ rtype: RewardModelType
+ path: str
+ use_engine: bool = False
+
+
+# ============================================================================
+# Model Builder Registry
+# ============================================================================
+
+_BUILDERS: Dict[RewardModelType, Callable] = {}
+
+def register_builder(rtype: RewardModelType) -> Callable:
+ """
+ Decorator to register a builder function for a specific reward model type.
+
+ Usage:
+ @register_builder(RewardModelType.GENERAL)
+ def build_general(cfg, strategy):
+ ...
+
+ :param rtype: Reward model type to register builder for
+ :type rtype: RewardModelType
+ :return: Decorator function
+ :rtype: Callable
+ """
+ def deco(fn: Callable) -> Callable:
+ _BUILDERS[rtype] = fn
+ return fn
+ return deco
+
+
+RawRewardInput = Union[str, Dict[str, str], List[Dict[str, str]], None]
+
+
+def extract_response(text: str) -> str:
+ """
+ Extract assistant completion from a full chat transcript.
+
+ Running format/accuracy checks on the full decoded sequence can produce
+ false positives because the system prompt itself contains formatting
+ examples such as ``<think>...</think>`` and ``\\boxed{}``.
+ """
+ if not isinstance(text, str):
+ return ""
+
+ s = text.strip()
+ if not s:
+ return s
+
+ assistant_marker = "<|im_start|>assistant"
+ if assistant_marker in s:
+ start = s.rfind(assistant_marker) + len(assistant_marker)
+ tail = s[start:]
+ end_idx = tail.find("<|im_end|>")
+ if end_idx != -1:
+ tail = tail[:end_idx]
+ return tail.strip()
+ return s
+
+
+# ============================================================================
+# Configuration Parsing
+# ============================================================================
+
+def _guess_rtype_from_path(path: str) -> RewardModelType:
+ """
+ Return the reward model type for a path; this demo always resolves to GENERAL without inspecting the path.
+
+ :param path: Model path or name
+ :type path: str
+ :return: Inferred reward type
+ :rtype: RewardModelType
+ """
+ return RewardModelType.GENERAL
+
+def parse_reward_pretrain(
+ raw: RawRewardInput,
+ *,
+ global_use_engine: bool
+) -> Tuple[List[RewardModelConfig], Dict[str, int]]:
+ """
+ Parse reward model configuration from various input formats.
+
+ Supported formats:
+ 1. JSON: '{"general":"/path/to/rm"}'
+ 2. Comma-separated ``type:path`` pairs: 'general:/path/to/rm'
+ 3. Plain path: '/path/to/rm' (treated as the general reward model)
+ 4. Dict/List: {'type':'general','path':'/path/to/rm'} or [{'type':'general','path':'/path/to/rm'}]
+
+ Extra feature: Append ?engine=true to path to override global engine setting
+ Example: 'general:/path/to/model?engine=true'
+
+ :param raw: Raw configuration input (string, dict, list, or None)
+ :type raw: RawRewardInput
+ :param global_use_engine: Global flag for whether to use engine mode
+ :type global_use_engine: bool
+ :return: Tuple of (cfgs, label_map) where cfgs is a list of RewardModelConfig objects
+ and label_map is a dict mapping reward type to index {str: int}
+ :rtype: Tuple[List[RewardModelConfig], Dict[str, int]]
+ :raises TypeError: If raw input format is not supported
+
+ Note:
+ The demo only supports RewardModelType.GENERAL; any other reward type raises ValueError.
+ """
+ if raw is None: raw = ""
+
+ # ---------- 1. Convert string to unified list[(key,path,flag)] ----------
+ pair_list: List[Tuple[str, str, Optional[bool]]] = []
+ if isinstance(raw, str):
+ s = raw.strip().lstrip("{").rstrip("}")
+ # ① JSON
+ if raw.strip().startswith("{") and raw.strip().endswith("}"):
+ try:
+ obj = json.loads(raw)
+ pair_list = [(k, v, None) for k, v in obj.items()]
+ except json.JSONDecodeError:
+ pass
+ if not pair_list:
+ # ② kv/comma-separated string
+ for seg in re.split(r"\s*,\s*", s):
+ if not seg: continue
+ if ":" in seg:
+ k, v = seg.split(":", 1)
+ pair_list.append((k.strip(), v.strip(), None))
+ else: # pure path
+ pair_list.append(("?", seg.strip(), None))
+ elif isinstance(raw, dict):
+ pair_list = [(k, v, None) for k, v in raw.items()]
+ elif isinstance(raw, list):
+ for d in raw:
+ pair_list.append((d["type"], d["path"], d.get("engine")))
+ else:
+ raise TypeError("Unsupported --reward_pretrain format")
+
+ # ---------- 2. Generate cfg list ----------
+ cfgs: List[RewardModelConfig] = []
+ for key, path, flag in pair_list:
+ # Parse path?engine=true/false
+ use_engine = global_use_engine
+ if "?engine=" in path:
+ path, qs = path.split("?engine=", 1)
+ use_engine = qs.lower() in ("1", "true", "yes")
+ if flag is not None:
+ use_engine = flag
+ if key == "?":
+ rtype = _guess_rtype_from_path(path)
+ else:
+ try:
+ rtype = RewardModelType(key)
+ except ValueError as exc:
+ raise ValueError(
+ "examples/orm_rl_demo only supports the general reward model. "
+ f"Got reward type: {key}"
+ ) from exc
+ if rtype is not RewardModelType.GENERAL:
+ raise ValueError(
+ "examples/orm_rl_demo only supports the general reward model. "
+ f"Got reward type: {rtype.value}"
+ )
+ cfgs.append(RewardModelConfig(rtype, path, use_engine))
+
+ # Ensure label_map order is stable and contains general
+ uniq: List[RewardModelType] = []
+ for c in cfgs:
+ if c.rtype not in uniq: uniq.append(c.rtype)
+ if RewardModelType.GENERAL not in uniq:
+ uniq.append(RewardModelType.GENERAL)
+ label_map = {rt.value: i for i, rt in enumerate(uniq)}
+ return cfgs, label_map
+
+
+def _normalized_geo3k_general_weights() -> Tuple[float, float, float]:
+ """
+ Return normalized (format, model, accuracy) weights for the Geo3K+ORM mix.
+ """
+ format_w = float(os.environ.get("ORM_RL_DEMO_GEO3K_FORMAT_WEIGHT", "0.1"))
+ model_w = float(os.environ.get("ORM_RL_DEMO_GEO3K_MODEL_WEIGHT", "0.2"))
+ accuracy_w = float(os.environ.get("ORM_RL_DEMO_GEO3K_ACCURACY_WEIGHT", "0.7"))
+ total = format_w + model_w + accuracy_w
+ if total <= 0:
+ return 0.1, 0.2, 0.7
+ return format_w / total, model_w / total, accuracy_w / total
+
+
+def _infer_rm_engine_tp_size(pretrain_path: str) -> int:
+ override = os.environ.get("ORM_RL_DEMO_RM_ENGINE_TP")
+ if override:
+ return int(override)
+
+ path = pretrain_path.lower()
+ if "72b" in path:
+ return 8
+ if "32b" in path:
+ return 4
+ if "14b" in path:
+ return 2
+ if "8b" in path or "7b" in path:
+ return 1
+ return 1
+
+
+def _rm_engine_mem_util() -> float:
+ return float(os.environ.get("ORM_RL_DEMO_RM_ENGINE_MEM_UTIL", "0.15"))
+
+
+# ============================================================================
+# Model Loading Functions
+# ============================================================================
+
+def _load_hf_model(
+ pretrain_path: str,
+ device: torch.device
+) -> Tuple[Qwen2_5_VLForConditionalGeneration, Any]:
+ """
+ Load HuggingFace model and processor.
+
+ :param pretrain_path: Model path or HuggingFace model name
+ :type pretrain_path: str
+ :param device: Target device (currently unused by this loader)
+ :type device: torch.device
+ :return: Tuple of (base_model, processor)
+ :rtype: Tuple[Qwen2_5_VLForConditionalGeneration, Any]
+ """
+ base = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ pretrain_path,
+ torch_dtype=torch.bfloat16,
+ attn_implementation="flash_attention_2",
+ )
+ processor = AutoProcessor.from_pretrained(
+ pretrain_path, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28
+ )
+ processor.tokenizer.padding_side = "left"
+ return base, processor
+
+
+def _load_engine(
+ pretrain_path: str,
+ device: torch.device
+) -> Tuple[Any, Any]:
+ """
+ Load reward-model inference engine and processor.
+
+ Automatically determines the tensor-parallel size from the model size in the path
+ (7B/8B -> 1, 14B -> 2, 32B -> 4, 72B -> 8; unknown sizes default to 1);
+ set ORM_RL_DEMO_RM_ENGINE_TP to override it.
+
+ :param pretrain_path: Model path or HuggingFace model name
+ :type pretrain_path: str
+ :param device: Target device (currently unused by this loader)
+ :type device: torch.device
+ :return: Tuple of (engine, processor)
+ :rtype: Tuple[Any, Any]
+
+ Note:
+ The engine is put to sleep after loading to save memory, unless it is a vLLM engine without sleep mode (the vLLM fallback uses enable_sleep=False).
+ Backend selection is controlled by ORM_RL_DEMO_RM_ENGINE_BACKEND:
+ - sglang: use SGLang only
+ - vllm: use vLLM only
+ - auto: try SGLang first, then fall back to vLLM
+ """
+ tp_size = _infer_rm_engine_tp_size(pretrain_path)
+ engine_mem_util = _rm_engine_mem_util()
+ backend = os.environ.get("ORM_RL_DEMO_RM_ENGINE_BACKEND", "auto").strip().lower()
+
+ print(
+ f"[reward_models_utils] Loading engine from {pretrain_path} "
+ f"with tp_size={tp_size}, engine_mem_util={engine_mem_util}, backend={backend}"
+ )
+
+ engine = None
+
+ if backend in ("auto", "sglang"):
+ try:
+ from lightrft.strategy.sglang_utils import get_sglang_engine
+
+ engine = get_sglang_engine(
+ pretrain_path,
+ engine_mem_util=engine_mem_util,
+ tp_size=tp_size,
+ skip_tokenizer_init=False,
+ disable_cuda_graph=True, # TODO: originally needed only for the DeepSeek pipeline (examples/orm_rl_demo/run_fsdp_deepseek.sh); find out why it is required
+ )
+ print(f"[reward_models_utils] Loaded SGLang engine from {pretrain_path} with tp_size={tp_size}")
+ except Exception as exc:
+ if backend == "sglang":
+ raise
+ print(
+ f"[reward_models_utils] SGLang engine init failed for {pretrain_path}: "
+ f"{type(exc).__name__}: {exc}. Falling back to vLLM."
+ )
+
+ if engine is None:
+ from lightrft.strategy.vllm_utils import get_vllm_engine
+
+ max_model_len = int(os.environ.get("ORM_RL_DEMO_RM_ENGINE_MAX_MODEL_LEN", "4096"))
+ engine = get_vllm_engine(
+ pretrain_path,
+ dtype="bfloat16",
+ tp_size=tp_size,
+ mem_util=engine_mem_util,
+ max_model_len=max_model_len,
+ enable_sleep=False,
+ limit_mm_per_prompt={"image": 1},
+ )
+ print(
+ f"[reward_models_utils] Loaded vLLM engine from {pretrain_path} "
+ f"with tp_size={tp_size}, max_model_len={max_model_len}, enable_sleep=False"
+ )
+
+ if not hasattr(engine, "llm_engine") or engine.llm_engine.vllm_config.model_config.enable_sleep_mode:
+ engine.sleep() # Sleep to save memory
+
+ processor = AutoProcessor.from_pretrained(
+ pretrain_path, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28
+ )
+ processor.tokenizer.padding_side = "left"
+ return engine, processor
+
+
+# ============================================================================
+# Model Builders for Each Reward Type
+# ============================================================================
+
+@register_builder(RewardModelType.GENERAL)
+def build_general(
+ cfg: RewardModelConfig,
+ strategy: Any,
+ base: Optional[Tuple[Any, Any]] = None
+) -> Tuple[Qwen2VLRewardModelGeneral, Any]:
+ """
+ Build General quality reward model.
+
+ :param cfg: Reward model configuration
+ :type cfg: RewardModelConfig
+ :param strategy: Training strategy instance
+ :type strategy: Any
+ :param base: Optional shared (engine, processor) tuple, used only when ``cfg.use_engine`` is set. Defaults to None
+ :type base: Optional[Tuple[Any, Any]]
+ :return: Tuple of (model, tokenizer)
+ :rtype: Tuple[Qwen2VLRewardModelGeneral, Any]
+ """
+ if cfg.use_engine:
+ if base:
+ engine, proc = base
+ else:
+ engine, proc = _load_engine(cfg.path, get_current_device())
+ model = Qwen2VLRewardModelGeneral(engine, proc.tokenizer, proc, text_only=strategy.args.text_only)
+ return model, proc.tokenizer
+ else:
+ base_model, proc = _load_hf_model(cfg.path, get_current_device())
+ model = Qwen2VLRewardModelGeneral(base_model, proc.tokenizer, proc, text_only=strategy.args.text_only)
+ model.eval()
+ return model, proc.tokenizer
+
+# ============================================================================
+# Main Initialization Entry Point
+# ============================================================================
+
+def load_reward_models(
+ raw_reward_pretrain: RawRewardInput,
+ strategy: Any,
+ use_engine: bool = False,
+) -> Tuple[List[Any], List[Any], Dict[str, int]]:
+ """
+ Load and initialize all reward models from configuration.
+
+ This is the main entry point for loading reward models. It handles:
+ - Configuration parsing
+ - Base model sharing (to save memory)
+ - Model initialization with proper context
+ - Monkey patching for HuggingFace generation
+
+ :param raw_reward_pretrain: Raw configuration (see parse_reward_pretrain)
+ :type raw_reward_pretrain: RawRewardInput
+ :param strategy: Training strategy instance
+ :type strategy: Any
+ :param use_engine: Global flag for using an inference engine (SGLang or vLLM). Defaults to False
+ :type use_engine: bool
+ :return: Tuple of (reward_models, reward_tokenizers, label_map) where
+ reward_models is a list of initialized reward model instances,
+ reward_tokenizers is a list of corresponding tokenizers,
+ and label_map is a dict mapping reward type to index
+ :rtype: Tuple[List[Any], List[Any], Dict[str, int]]
+
+ Note:
+ Models sharing the same base path will reuse the same loaded base model
+ to reduce memory footprint.
+ """
+ apply_monkey_patch_to_generation_mixin()
+
+ cfgs, label_map = parse_reward_pretrain(
+ raw_reward_pretrain, global_use_engine=use_engine
+ )
+
+ rms: List[Any] = []
+ toks: List[Any] = []
+
+ # Share base models across reward models to save memory:
+ # configs with the same (path, use_engine) key load their base model only once
+ shared_bases: Dict[Tuple[str, bool], Tuple[Any, Any]] = {}
+ shared_count: Dict[Tuple[str, bool], int] = {}
+ for cfg in cfgs:
+ cache_key = (cfg.path, cfg.use_engine)
+ if cache_key not in shared_count:
+ shared_count[cache_key] = 1
+ else:
+ shared_count[cache_key] += 1
+
+ if shared_count[cache_key] == 1:
+ loader = _load_engine if cfg.use_engine else _load_hf_model
+ shared_bases[cache_key] = loader(cfg.path, get_current_device())
+ strategy.print(f"Init reward model {cfg.path} (engine={cfg.use_engine})")
+ else:
+ strategy.print(f"Use shared base model {cfg.path}")
+
+ for cfg in cfgs:
+ if cfg.rtype not in _BUILDERS:
+ raise RuntimeError(f"No builder for {cfg.rtype}")
+ strategy.print(f"Loading {cfg.rtype} from {cfg.path} (engine={cfg.use_engine})")
+
+ # Initialize model with proper context (supports FSDP/meta device init)
+ with strategy.init_model_context() as _:
+ # All reward types now support shared base models
+ rm, tok = _BUILDERS[cfg.rtype](cfg, strategy, base=shared_bases.get((cfg.path, cfg.use_engine)))
+
+ rms.append(rm)
+ toks.append(tok)
+ strategy.print(f"Loaded {cfg.rtype}")
+
+ return rms, toks, label_map
+
+
+
+# ============================================================================
+# Reward Functions
+# ============================================================================
+
+def format_reward_fn(sol: str) -> float:
+ """
+ Check if solution matches format: <think>...</think> + non-empty content.
+
+ :param sol: Solution string to check
+ :type sol: str
+ :return: 1.0 if format is valid, 0.0 otherwise
+ :rtype: float
+ """
+ return 1.0 if re.match(r".*<think>.+?</think>\s*\S+", sol, re.DOTALL) else 0.0
+
+
+def rule_reward_fn(sol: str, gt: str) -> float:
+ """
+ Extract content after </think> and verify against ground truth using mathruler.
+
+ :param sol: Solution string (may contain <think>...</think>)
+ :type sol: str
+ :param gt: Ground truth answer
+ :type gt: str
+ :return: 1.0 if correct, 0.0 otherwise
+ :rtype: float
+ """
+ from mathruler.grader import extract_boxed_content, grade_answer
+ ans = sol.split("</think>")[-1]
+ pred = extract_boxed_content(ans)
+ if pred == gt or grade_answer(pred, gt):
+ return 1.0
+ return 0.0
+
+# ============================================================================
+# Reward Recipe Configuration
+# ============================================================================
+
+# Original reward recipe for SVKG dataset training (after KG dataset training)
+
+def geo3k_accuracy_reward_fn(sol: str, gt: str) -> float:
+ """
+ Geo3K accuracy reward function.
+
+ Extract answer from \\boxed{} notation and use mathruler to verify correctness.
+ This is based on the verl implementation for geo3k dataset.
+
+ :param sol: Solution string from model (should contain \\boxed{answer})
+ :type sol: str
+ :param gt: Ground truth answer
+ :type gt: str
+ :return: 1.0 if answer is correct, 0.0 otherwise
+ :rtype: float
+ """
+ from mathruler.grader import extract_boxed_content, grade_answer
+ pred = extract_boxed_content(sol)
+ return 1.0 if grade_answer(pred, gt) else 0.0
+
+
+def geo3k_format_reward_fn(sol: str) -> float:
+ """
+ Geo3K format reward function.
+
+ Check if the solution follows the required format:
+ - Contains <think>...</think> tags for reasoning
+ - Contains \\boxed{} for final answer
+ - The think tags must appear BEFORE the boxed answer
+
+ This is based on the verl implementation for geo3k dataset.
+
+ :param sol: Solution string from model
+ :type sol: str
+ :return: 1.0 if format is correct, 0.0 otherwise
+ :rtype: float
+ """
+ # Strip leading/trailing whitespace for robust matching
+ sol_stripped = sol.strip()
+
+    # Check if solution contains both <think>...</think> and \boxed{...}
+    # Use re.search to find positions
+    think_match = re.search(r'<think>.*?</think>', sol_stripped, re.DOTALL)
+ boxed_match = re.search(r'\\boxed\{.*?\}', sol_stripped, re.DOTALL)
+
+ # Both components must be present AND think must come before boxed
+ if think_match and boxed_match:
+        # Check that </think> comes before \boxed
+ think_end = think_match.end()
+ boxed_start = boxed_match.start()
+ return 1.0 if think_end <= boxed_start else 0.0
+ else:
+ return 0.0
+
+
+def geo3k_combined_reward_fn(
+ sol: str,
+ gt: str,
+ format_weight: float = 0.1
+) -> float:
+ """
+ Geo3K combined reward function.
+
+ Combines format reward and accuracy reward with specified weights.
+ Default: 90% accuracy + 10% format (matching verl implementation)
+
+ :param sol: Solution string from model
+ :type sol: str
+ :param gt: Ground truth answer
+ :type gt: str
+ :param format_weight: Weight for format reward. Default to 0.1
+ :type format_weight: float
+ :return: Weighted combination of format and accuracy rewards
+ :rtype: float
+ """
+ acc_reward = geo3k_accuracy_reward_fn(sol, gt)
+ fmt_reward = geo3k_format_reward_fn(sol)
+ return (1.0 - format_weight) * acc_reward + format_weight * fmt_reward
+
+
+def gsm8k_accuracy_reward_fn(sol: str, gt: str) -> float:
+ """
+ GSM8K accuracy reward function.
+
+ Extract answer from \boxed{} notation and use mathruler to verify correctness.
+ This follows the same pattern as geo3k but for GSM8K dataset.
+
+ :param sol: Solution string from model (should contain \boxed{answer})
+ :type sol: str
+ :param gt: Ground truth answer
+ :type gt: str
+ :return: 1.0 if answer is correct, 0.0 otherwise
+ :rtype: float
+ """
+ from mathruler.grader import extract_boxed_content, grade_answer
+ pred = extract_boxed_content(sol)
+ return 1.0 if grade_answer(pred, gt) else 0.0
+
+
+def gsm8k_format_reward_fn(sol: str) -> float:
+ """
+ GSM8K format reward function.
+
+ Check if the solution follows the required format:
+    - Contains <think>...</think> tags for reasoning
+ - Contains \boxed{} for final answer
+ - The think tags must appear BEFORE the boxed answer
+
+ This follows the same pattern as geo3k format checking.
+
+ :param sol: Solution string from model
+ :type sol: str
+ :return: 1.0 if format is correct, 0.0 otherwise
+ :rtype: float
+ """
+ # Strip leading/trailing whitespace for robust matching
+ sol_stripped = sol.strip()
+
+    # Check if solution contains both <think>...</think> and \boxed{...}
+    # Use re.search to find positions
+    think_match = re.search(r'<think>.*?</think>', sol_stripped, re.DOTALL)
+ boxed_match = re.search(r'\\boxed\{.*?\}', sol_stripped, re.DOTALL)
+
+ # Both components must be present AND think must come before boxed
+ if think_match and boxed_match:
+        # Check that </think> comes before \boxed
+ think_end = think_match.end()
+ boxed_start = boxed_match.start()
+ return 1.0 if think_end <= boxed_start else 0.0
+ else:
+ return 0.0
+
+
+def gsm8k_combined_reward_fn(
+ sol: str,
+ gt: str,
+ format_weight: float = 0.1
+) -> float:
+ """
+ GSM8K combined reward function.
+
+ Combines format reward and accuracy reward with specified weights.
+ Default: 90% accuracy + 10% format (matching verl and geo3k implementation)
+
+ :param sol: Solution string from model
+ :type sol: str
+ :param gt: Ground truth answer
+ :type gt: str
+ :param format_weight: Weight for format reward. Default to 0.1
+ :type format_weight: float
+ :return: Weighted combination of format and accuracy rewards
+ :rtype: float
+ """
+ acc_reward = gsm8k_accuracy_reward_fn(sol, gt)
+ fmt_reward = gsm8k_format_reward_fn(sol)
+ return (1.0 - format_weight) * acc_reward + format_weight * fmt_reward
+
+RECIPE: Dict[str, List[Tuple[str, Optional[str], float]]] = {
+ "general": [("model", "general", 1.0)],
+}
+
+
+def mix_rewards(
+ labels: Sequence[str],
+ model_scores: torch.Tensor,
+ label_map: Dict[str, int],
+ solution_strs: Sequence[str],
+ refs: Sequence[str],
+) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+ """
+ Mix rewards from multiple sources according to recipe configuration.
+
+ This function combines:
+ 1. Format reward (always applied)
+ 2. Model-based rewards (from neural reward models)
+ 3. Rule-based rewards (from heuristic functions)
+
+ :param labels: List of data labels (length B)
+ :type labels: Sequence[str]
+ :param model_scores: Tensor of model scores, shape (n_model, B)
+ :type model_scores: torch.Tensor
+ :param label_map: Mapping from reward type to model index
+ :type label_map: Dict[str, int]
+ :param solution_strs: List of solution strings (length B)
+ :type solution_strs: Sequence[str]
+ :param refs: List of reference answers (length B)
+ :type refs: Sequence[str]
+ :return: Tuple of (final_reward, metrics_dict) where final_reward is tensor of shape (B,)
+ containing combined rewards and metrics_dict contains detailed reward metrics
+ :rtype: Tuple[torch.Tensor, Dict[str, torch.Tensor]]
+
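+    Error handling:
+        - If a model is not loaded or its index is out of bounds, its score defaults to 1.0 with a warning
+        - If a label is not in RECIPE, a warning is printed and only the format reward is applied
+        - Never raises IndexError for model lookups; always returns a valid reward
+
+    Note:
+        "geo3k_general" labels use the weights from _normalized_geo3k_general_weights(); for other labels the format reward is computed first, then recipe rewards are added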
+ """
+ if torch.distributed.get_rank() == 0:
+ print(f"labels:{labels}, model_scores:{model_scores.tolist()}")
+ device = model_scores.device
+ n_model, B = model_scores.shape[0], len(labels)
+ assert model_scores.shape[1] == B, "model_scores second dimension must equal batch size"
+
+ final_reward = torch.zeros(B, dtype=torch.float32, device=device)
+
+ # Initialize metrics dict to track individual reward components
+ metrics_dict: Dict[str, torch.Tensor] = {
+ 'format_reward': torch.zeros(B, dtype=torch.float32, device=device),
+ 'accuracy_reward': torch.zeros(B, dtype=torch.float32, device=device),
+ 'general_model_reward': torch.zeros(B, dtype=torch.float32, device=device),
+ 'rule_reward': torch.zeros(B, dtype=torch.float32, device=device),
+ }
+
+ # ---------- Fallback scoring function ----------
+ def get_model_reward(key: str, i: int) -> float:
+ """
+        Return the model score for `key`, or 1.0 if it is unavailable.
+
+ :param key: Reward model type key
+ :type key: str
+ :param i: Sample index
+ :type i: int
+ :return: Model score or 1.0 if not available
+ :rtype: float
+ """
+ if key not in label_map:
+ print(f"Model reward <{key}> not loaded, using 1 as default reward")
+ return 1.0
+
+ idx = label_map[key]
+ if idx >= n_model:
+ print(f"Model reward <{key}> index {idx} out of bounds "
+ f"(n_model={n_model}), using 1 as default reward")
+ return 1.0
+
+ return float(model_scores[idx, i].item())
+
+ # ---------- Main loop ----------
+ geo3k_fmt_w, geo3k_model_w, geo3k_acc_w = _normalized_geo3k_general_weights()
+
+ for i, lab in enumerate(labels):
+ sol = solution_strs[i]
+ sol_completion = extract_response(sol)
+ gt = refs[i] if i < len(refs) else ""
+
+ if lab == "geo3k_general":
+ fmt_r = geo3k_format_reward_fn(sol_completion)
+ acc_r = geo3k_accuracy_reward_fn(sol_completion, gt)
+ model_score = get_model_reward("general", i)
+
+ final_reward[i] = (
+ geo3k_fmt_w * fmt_r
+ + geo3k_model_w * model_score
+ + geo3k_acc_w * acc_r
+ )
+ metrics_dict['format_reward'][i] = fmt_r
+ metrics_dict['accuracy_reward'][i] = acc_r
+ metrics_dict['general_model_reward'][i] = model_score
+            rule_w_total = geo3k_fmt_w + geo3k_acc_w
+            metrics_dict['rule_reward'][i] = (geo3k_fmt_w * fmt_r + geo3k_acc_w * acc_r) / rule_w_total if rule_w_total > 0 else 0.0
+ continue
+
+ # 1) format reward (always present)
+ r = format_reward_fn(sol_completion)
+ # Track separately
+ metrics_dict['format_reward'][i] = r
+
+ # 2) accumulate according to recipe
+ recipe = RECIPE.get(lab)
+ if recipe is None:
+            print(f"label <{lab}> not registered in RECIPE, using the format reward only")
+            recipe = []  # no recipe components apply; alternatively raise here
+
+ for typ, key, w in recipe:
+ if typ == "model":
+ model_r = w * get_model_reward(key, i)
+ r += model_r
+ metrics_dict['general_model_reward'][i] += model_r
+
+ elif typ == "rule":
+ rule_r = w * rule_reward_fn(sol_completion, gt)
+ r += rule_r
+ metrics_dict['rule_reward'][i] += rule_r
+ metrics_dict['accuracy_reward'][i] = rule_r
+
+ elif typ == "geo3k_rule":
+                r = 0  # TODO: Geo3K has its own format reward
+ # Track separately
+ metrics_dict['accuracy_reward'][i] = 0
+ metrics_dict['format_reward'][i] = 0
+ # Geo3K pure rule-based reward (format + accuracy)
+ # Get individual components
+ acc_r = geo3k_accuracy_reward_fn(sol_completion, gt)
+ fmt_r = geo3k_format_reward_fn(sol_completion)
+ combined_r = (1.0 - 0.1) * acc_r + 0.1 * fmt_r
+ r += w * combined_r
+ # Track separately
+ metrics_dict['accuracy_reward'][i] = acc_r
+ metrics_dict['format_reward'][i] = fmt_r
+ elif typ == "gsm8k_rule":
+                r = 0  # TODO: GSM8K has its own format reward
+ # Track separately
+ metrics_dict['accuracy_reward'][i] = 0
+ metrics_dict['format_reward'][i] = 0
+ # GSM8K pure rule-based reward (format + accuracy)
+ # Get individual components
+ acc_r = gsm8k_accuracy_reward_fn(sol_completion, gt)
+ fmt_r = gsm8k_format_reward_fn(sol_completion)
+ combined_r = (1.0 - 0.1) * acc_r + 0.1 * fmt_r
+ r += w * combined_r
+ # Track separately
+ metrics_dict['accuracy_reward'][i] = acc_r
+ metrics_dict['format_reward'][i] = fmt_r
+ else:
+ print(f"Unknown component type {typ}, ignoring")
+
+ final_reward[i] = r
+
+ return final_reward, metrics_dict
+
+
+def reward_fn(
+ model_reward_list: List[torch.Tensor],
+ labels: Sequence[str],
+ queries: Sequence[str],
+ refs: Sequence[str],
+ label_map: Dict[str, int],
+) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+ """
+ External unified interface for computing final rewards.
+
+ This is the main entry point called by the trainer. It:
+ 1. Stacks individual model rewards into a single tensor
+ 2. Calls mix_rewards to combine all reward sources
+ 3. Returns final reward tensor
+
+ :param model_reward_list: List of reward tensors from each model, each shape (B,)
+ :type model_reward_list: List[torch.Tensor]
+ :param labels: List of data labels indicating reward type (length B)
+ :type labels: Sequence[str]
+ :param queries: List of query/solution strings (length B)
+ :type queries: Sequence[str]
+ :param refs: List of reference answers (length B)
+ :type refs: Sequence[str]
+ :param label_map: Mapping from reward type to model index
+ :type label_map: Dict[str, int]
+ :return: Tuple of (final_reward, metrics_dict) where final_reward is combined reward tensor
+ of shape (B,) and metrics_dict contains detailed reward metrics
+ :rtype: Tuple[torch.Tensor, Dict[str, torch.Tensor]]
+
+ Note:
+        If model_reward_list is empty (no NN models), an empty placeholder tensor of shape (0, B) is created
+ """
+ # print(f"model_reward_list:{model_reward_list}, labels:{labels}, queries:{queries}, refs:{refs}, label_map:{label_map}")
+ # print(f"label_map:{label_map}")
+
+ # ------ stack to (n_model, B) ------
+ if model_reward_list:
+ model_scores = torch.stack(model_reward_list) # (n_model, B)
+ else:
+        # When no torch.nn reward model is available, use an empty (0, B) placeholder tensor
+ B = len(labels)
+ model_scores = torch.zeros(0, B, dtype=torch.float32, device="cuda")
+
+ # ------ call combination logic ------
+ return mix_rewards(labels, model_scores, label_map, queries, refs)
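
The `geo3k_general` branch of `mix_rewards` combines three signals: a format check, the general ORM score, and a rule-based accuracy score. The following is a minimal, dependency-free sketch of that idea. It assumes the format check looks for a `<think>...</think>` block that ends before the first `\boxed{...}`. The raw weights `(0.1, 0.45, 0.45)` and the helper names `toy_format_reward`, `normalize`, and `toy_mix` are hypothetical placeholders. The real weights come from `_normalized_geo3k_general_weights()`, which this hunk does not show.

```python
import re


def toy_format_reward(sol: str) -> float:
    """Return 1.0 when a <think>...</think> block ends before the first \\boxed{...}."""
    text = sol.strip()
    think = re.search(r"<think>.*?</think>", text, re.DOTALL)
    boxed = re.search(r"\\boxed\{.*?\}", text, re.DOTALL)
    if think and boxed:
        return 1.0 if think.end() <= boxed.start() else 0.0
    return 0.0


def normalize(weights):
    """Scale non-negative weights so that they sum to 1."""
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must sum to a positive value")
    return [w / total for w in weights]


def toy_mix(fmt_r, model_r, acc_r, raw_weights=(0.1, 0.45, 0.45)):
    """Weighted sum of format, model, and accuracy rewards (hypothetical weights)."""
    fmt_w, model_w, acc_w = normalize(raw_weights)
    return fmt_w * fmt_r + model_w * model_r + acc_w * acc_r


good = "<think>Angles sum to 180.</think> The answer is \\boxed{60}."
bad = "\\boxed{60} <think>late reasoning</think>"
print(toy_format_reward(good), toy_format_reward(bad))
print(toy_mix(toy_format_reward(good), model_r=1.0, acc_r=1.0))
```

Normalizing the weights keeps the mixed reward in `[0, 1]` whenever each component is in `[0, 1]`, which makes runs with different weightings easier to compare.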
diff --git a/examples/orm_rl_demo/run_general_fsdp_qwenvl.sh b/examples/orm_rl_demo/run_general_fsdp_qwenvl.sh
new file mode 100755
index 00000000..3dfc0206
--- /dev/null
+++ b/examples/orm_rl_demo/run_general_fsdp_qwenvl.sh
@@ -0,0 +1,153 @@
+#!/bin/bash
+
+set -euo pipefail
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)"
+
+NAME="${NAME:-orm-rl-demo-general-geo3k}"
+N_SAMPLES=8
+EPISODE="${EPISODE:-20}"
+WARMUP=0.03
+RBS=128
+TBS=128
+KL=0.001
+LR=1e-6
+
+PROMPT_MAX_LEN=1024
+GENERATE_MAX_LEN=2048
+EVAL_SPLIT="${EVAL_SPLIT:-test}"
+MAX_EVAL_SAMPLES="${MAX_EVAL_SAMPLES:-700}"
+limit_mm_image_per_prompt=1
+ENGINE_TP=1
+
+export IGNORE_EOS=0
+
+DEFAULT_DATA_PATH="/path/to/geo3k"
+DEFAULT_PRETRAIN_PATH="/path/to/Qwen2.5-VL-7B-Instruct"
+DEFAULT_REWARD_PRETRAIN_PATHS='{"general":"/path/to/general-reward-model"}'
+
+DATA_PATH="${DATA_PATH:-${DEFAULT_DATA_PATH}}"
+PRETRAIN_PATH="${PRETRAIN_PATH:-${DEFAULT_PRETRAIN_PATH}}"
+REWARD_PRETRAIN_PATHS="${REWARD_PRETRAIN_PATHS:-${DEFAULT_REWARD_PRETRAIN_PATHS}}"
+LABEL_OVERRIDE="${LABEL_OVERRIDE:-geo3k_general}"
+USE_RM_ENGINE="${USE_RM_ENGINE:-1}"
+
+current_time=$(date +"%m%d%H%M")
+
+cd "${REPO_ROOT}"
+
+export PYTHONPATH="${REPO_ROOT}:${PYTHONPATH:-}"
+
+if [ "${DATA_PATH}" = "${DEFAULT_DATA_PATH}" ] || [ "${PRETRAIN_PATH}" = "${DEFAULT_PRETRAIN_PATH}" ] || [ "${REWARD_PRETRAIN_PATHS}" = "${DEFAULT_REWARD_PRETRAIN_PATHS}" ]; then
+ echo "Set DATA_PATH, PRETRAIN_PATH, and REWARD_PRETRAIN_PATHS before running this template." >&2
+ exit 1
+fi
+
+mkdir -p log
+mkdir -p wandb
+
+export TORCH_NCCL_AVOID_RECORD_STREAMS=1
+export NCCL_DEBUG=WARN
+
+export MLP_WORKER_NUM=1
+export MLP_WORKER_GPU="${MLP_WORKER_GPU:-2}"
+export MLP_ROLE_INDEX=0
+export MLP_WORKER_0_PORT=20090
+export MLP_WORKER_0_HOST=localhost
+
+export MASTER_ADDR=$MLP_WORKER_0_HOST
+export NNODES=$MLP_WORKER_NUM
+export NODE_RANK=$MLP_ROLE_INDEX
+export GPUS_PER_NODE=$MLP_WORKER_GPU
+export MASTER_PORT=$MLP_WORKER_0_PORT
+
+SAVE_MODEL_NAME="LightRFT-geo3k-general-orm-len_${PROMPT_MAX_LEN}_${GENERATE_MAX_LEN}-tbs_${TBS}-rbs_${RBS}-sample_${N_SAMPLES}-kl_${KL}-warmup_${WARMUP}-ep_${EPISODE}-lr_${LR}"
+
+mkdir -p "results/${NAME}/${SAVE_MODEL_NAME}"
+mkdir -p "rft_logs/${NAME}"
+
+set -x
+
+export WANDB_MODE="${WANDB_MODE:-offline}"
+export WANDB_DIR="${WANDB_DIR:-${REPO_ROOT}/wandb}"
+mkdir -p "${WANDB_DIR}"
+
+WANDB_PROJECT="${WANDB_PROJECT:-orm-rl-demo-geo3k}"
+WANDB_RUN_NAME="${WANDB_RUN_NAME:-orm-rl-demo-general-${current_time}}"
+WANDB_ORG="${WANDB_ORG:-}"
+ENGINE_TYPE="${ENGINE_TYPE:-sglang}"
+ENGINE_MEM_UTIL="${ENGINE_MEM_UTIL:-0.4}"
+export ORM_RL_DEMO_RM_ENGINE_BACKEND="${ORM_RL_DEMO_RM_ENGINE_BACKEND:-sglang}"
+
+rm_use_engine_args=()
+if [ "${USE_RM_ENGINE}" = "1" ]; then
+ rm_use_engine_args+=(--rm_use_engine)
+fi
+
+wandb_org_args=()
+if [ -n "${WANDB_ORG}" ]; then
+ wandb_org_args+=(--wandb_org "${WANDB_ORG}")
+fi
+
+wandb_args=()
+if [ -n "${WANDB_API_KEY:-}" ]; then
+ wandb_args+=(--use_wandb "${WANDB_API_KEY}")
+ wandb_args+=(--wandb_project "${WANDB_PROJECT}")
+ wandb_args+=(--wandb_run_name "${WANDB_RUN_NAME}")
+ wandb_args+=("${wandb_org_args[@]}")
+fi
+
+torchrun --nnodes $NNODES --nproc-per-node $GPUS_PER_NODE --node_rank $NODE_RANK --master-port $MASTER_PORT --master-addr $MASTER_ADDR "${SCRIPT_DIR}/train_colocate.py" \
+ --pretrain "${PRETRAIN_PATH}" \
+ --loss_agg_mode seq-mean-token-mean \
+ --save_trajectories \
+ --num_trajectories_to_save 16 \
+ --print_replay_buffer_stats \
+ --fsdp \
+ --use_kl_loss \
+ "${rm_use_engine_args[@]}" \
+ --mixed_mm_data \
+ --reward_pretrain "${REWARD_PRETRAIN_PATHS}" \
+ --save_path "results/${NAME}/${SAVE_MODEL_NAME}" \
+ --ckpt_path "results/${NAME}/${SAVE_MODEL_NAME}" \
+ --micro_train_batch_size 4 \
+ --train_batch_size ${TBS} \
+ --micro_rollout_batch_size 4 \
+ --rollout_batch_size ${RBS} \
+ --advantage_estimator group_norm \
+ --max_epochs 1 \
+ --num_episodes ${EPISODE} \
+ --lr_warmup_ratio ${WARMUP} \
+ --n_samples_per_prompt ${N_SAMPLES} \
+ --prompt_max_len ${PROMPT_MAX_LEN} \
+ --generate_max_len ${GENERATE_MAX_LEN} \
+ --zero_stage 3 \
+ --bf16 \
+ --actor_learning_rate ${LR} \
+ --init_kl_coef ${KL} \
+ --kl_estimator k3 \
+ --prompt_data "${DATA_PATH}" \
+ --input_key prompt \
+ --images_key images \
+ --label_key label \
+ --label_override "${LABEL_OVERRIDE}" \
+ --eval_steps 20 \
+ --eval_split "${EVAL_SPLIT}" \
+ --max_eval_samples "${MAX_EVAL_SAMPLES}" \
+ --apply_chat_template \
+ --flash_attn \
+ --gradient_checkpointing \
+ --save_steps 20 \
+ --max_ckpt_num 1 \
+ --engine_type "${ENGINE_TYPE}" \
+ --engine_mem_util "${ENGINE_MEM_UTIL}" \
+ --engine_tp_size ${ENGINE_TP} \
+ --enable_engine_sleep \
+    --system_prompt 'A conversation between the User and Assistant. The User asks a question, and the Assistant provides a solution. The Assistant first thinks through the reasoning process internally with self-reflection and consistency check and then gives the final analysis and answer. The reasoning process should be enclosed within <think> </think>, followed directly by the final thought and answer, and the final answer should be put in \boxed{}, like this: <think> reasoning process here </think> final thought and \boxed{answer} here.' \
+ --l2 1.0e-2 \
+ --freeze_prefix \
+ --adam_offload \
+ --limit_mm_image_per_prompt ${limit_mm_image_per_prompt} \
+ "${wandb_args[@]}" \
+ 2>&1 | tee "rft_logs/${NAME}/${NAME}_node${NODE_RANK}_$(date +%Y%m%d_%H%M%S).log"
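
The `REWARD_PRETRAIN_PATHS` variable in the script above is a JSON object keyed by reward type. As a rough illustration, such a string could be turned into an ordered list of model paths plus a `label_map` from reward type to model index, which is the shape `mix_rewards` expects when it indexes `model_scores[idx, i]`. `parse_reward_paths` is a hypothetical helper written for this sketch. It is not the repository's `load_reward_models`, whose implementation is not shown in this part of the diff.

```python
import json
from typing import Dict, List, Tuple


def parse_reward_paths(raw: str) -> Tuple[List[str], Dict[str, int]]:
    """Parse a JSON object {reward_type: model_path} into (paths, label_map)."""
    mapping = json.loads(raw)
    if not isinstance(mapping, dict):
        raise ValueError("expected a JSON object mapping reward type to model path")
    paths: List[str] = []
    label_map: Dict[str, int] = {}
    # Insertion order of the JSON object decides each model's row index.
    for idx, (key, path) in enumerate(mapping.items()):
        label_map[key] = idx
        paths.append(path)
    return paths, label_map


paths, label_map = parse_reward_paths('{"general": "/path/to/general-reward-model"}')
print(paths, label_map)
```

With a single `general` entry, the map has one key pointing at row 0, so a lookup such as `label_map["general"]` selects the only row of the stacked score tensor.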
diff --git a/examples/orm_rl_demo/test_reward_models.py b/examples/orm_rl_demo/test_reward_models.py
new file mode 100644
index 00000000..0d789482
--- /dev/null
+++ b/examples/orm_rl_demo/test_reward_models.py
@@ -0,0 +1,105 @@
+"""
+Smoke test for the general reward model used by the ORM RL demo.
+
+This script intentionally uses public text-only examples so it can validate the
+general reward model without depending on private datasets or absolute paths.
+"""
+
+import argparse
+import os
+import sys
+
+import torch
+from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+
+sys.path.append(os.path.dirname(__file__))
+from reward_models import Qwen2VLRewardModelGeneral
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="Smoke test the general reward model used by orm_rl_demo."
+ )
+ parser.add_argument(
+ "--model",
+ required=True,
+ help="Path or HuggingFace id for the general reward model.",
+ )
+ return parser.parse_args()
+
+
+def build_dialog(question: str, response: str) -> str:
+ return (
+ f"<|im_start|>user\n{question}<|im_end|>\n"
+ f"<|im_start|>assistant\n{response}<|im_end|>\n"
+ )
+
+
+def load_reward_model(model_path: str) -> Qwen2VLRewardModelGeneral:
+ base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+ model_path,
+ torch_dtype=torch.bfloat16,
+ attn_implementation="flash_attention_2",
+ device_map="auto",
+ )
+ processor = AutoProcessor.from_pretrained(
+ model_path,
+ min_pixels=256 * 28 * 28,
+ max_pixels=1280 * 28 * 28,
+ )
+ reward_model = Qwen2VLRewardModelGeneral(
+ base_model,
+ processor.tokenizer,
+ processor,
+ text_only=True,
+ )
+ reward_model.eval()
+ return reward_model
+
+
+def run_case(reward_model: Qwen2VLRewardModelGeneral, case: dict) -> None:
+ outputs = reward_model(
+ input_ids=None,
+ attention_mask=None,
+ references=[case["reference"]],
+ prompt_and_output=[build_dialog(case["question"], case["response"])],
+ raw_images=[None],
+ )
+ score = float(outputs["score"].item())
+ print(f"{case['name']}: score={score:.1f}, expected={case['expected']:.1f}")
+ if abs(score - case["expected"]) > 1e-6:
+ raise AssertionError(
+ f"{case['name']} expected {case['expected']}, got {score}"
+ )
+
+
+def main() -> None:
+ args = parse_args()
+ reward_model = load_reward_model(args.model)
+
+ test_cases = [
+ {
+ "name": "correct_answer",
+ "question": "What is 2 + 2?",
+ "response": "The answer is 4.",
+ "reference": "4",
+ "expected": 1.0,
+ },
+ {
+ "name": "incorrect_answer",
+ "question": "What is 2 + 2?",
+ "response": "The answer is 5.",
+ "reference": "4",
+ "expected": 0.0,
+ },
+ ]
+
+ for case in test_cases:
+ run_case(reward_model, case)
+
+ print("general reward model smoke test passed")
+
+
+if __name__ == "__main__":
+ with torch.no_grad():
+ main()
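
The training entry point (`train_colocate.py`) derives the scheduler length from the prompt count, `n_samples_per_prompt`, `train_batch_size`, `max_epochs`, and `num_episodes`. The sketch below reproduces that arithmetic with the run script's values (`N_SAMPLES=8`, `TBS=128`, `--max_epochs 1`, `EPISODE` defaulting to 20). The prompt count of 2100 is a hypothetical placeholder, not a measured dataset size, and `scheduler_steps` is an illustrative helper name.

```python
import math


def scheduler_steps(num_prompts, n_samples_per_prompt, train_batch_size, max_epochs, num_episodes):
    """Mirror the step-count arithmetic: floor-divide samples by batch size, then scale."""
    # Evaluated left to right: ((num_prompts * n_samples) // batch_size) * max_epochs
    steps_per_episode = num_prompts * n_samples_per_prompt // train_batch_size * max_epochs
    max_steps = math.ceil(num_episodes * steps_per_episode)
    return steps_per_episode, max_steps


steps_per_episode, max_steps = scheduler_steps(
    num_prompts=2100,          # hypothetical prompt count
    n_samples_per_prompt=8,    # N_SAMPLES in the run script
    train_batch_size=128,      # TBS in the run script
    max_epochs=1,
    num_episodes=20,           # default EPISODE in the run script
)
print(steps_per_episode, max_steps)
```

Because of the floor division, samples that do not fill a complete training batch are not counted toward the schedule. This matters when choosing `lr_warmup_ratio`, since the warmup length is a fraction of `max_steps`.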
diff --git a/examples/orm_rl_demo/train_colocate.py b/examples/orm_rl_demo/train_colocate.py
new file mode 100755
index 00000000..16babf07
--- /dev/null
+++ b/examples/orm_rl_demo/train_colocate.py
@@ -0,0 +1,656 @@
+"""
+Training entry for the Geo3K ORM RL demo with a co-located general reward model.
+
+This script keeps the demo-specific dataset override and reward wiring local to
+`examples/orm_rl_demo` while reusing the shared GRPO training stack.
+"""
+import argparse
+import itertools
+import math
+import re
+import os
+import sys
+import json
+from datetime import datetime
+from typing import Callable, Dict, List, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from lightrft.utils import add_arguments, ensure_video_input_available
+ensure_video_input_available()
+
+from lightrft.datasets import PromptDatasetVL, SFTDatasetVL
+from lightrft.models.actor_language import ActorLanguage
+from lightrft.models.actor_vl import ActorVL
+from lightrft.models.critic_vl import CriticVL
+from lightrft.utils import blending_datasets, get_tokenizer_processor_vl
+
+from lightrft.strategy import get_strategy
+from lightrft.trainer.spmd_ppo_trainer import SPMDPPOTrainerVL
+
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+from reward_models_utils import load_reward_models, reward_fn, RECIPE
+
+
+def _apply_label_override(dataset, label_key: str, label_override: str, strategy, dataset_name: str):
+ """
+ Apply a demo-local label override without touching the shared dataset library.
+
+ :param dataset: Source dataset to update.
+ :type dataset: Any
+ :param label_key: Dataset field storing the reward label.
+ :type label_key: str
+ :param label_override: Override value to inject when non-empty.
+ :type label_override: str
+ :param strategy: Training strategy used for logging.
+ :type strategy: Any
+ :param dataset_name: Human-readable dataset name for logs.
+ :type dataset_name: str
+ :return: Updated dataset.
+ :rtype: Any
+ """
+ if not label_override:
+ return dataset
+
+ strategy.print(f"Applying label override '{label_override}' to {dataset_name}")
+
+ def override_label(example):
+ example[label_key] = label_override
+ extra = example.get("extra_info")
+ if isinstance(extra, dict):
+ extra = dict(extra)
+ extra[label_key] = label_override
+ example["extra_info"] = extra
+ return example
+
+ return dataset.map(override_label)
+
+
+def train(args):
+ """
+ Main training function for GRPO with co-located reward models.
+
+ :param args: Parsed command-line arguments containing all training configuration.
+ :type args: argparse.Namespace
+ """
+ # configure strategy
+ strategy = get_strategy(args)
+
+ ds_train_cfg = strategy.get_ds_train_config(is_actor=True) if not args.fsdp else None
+ ds_eval_cfg = strategy.get_ds_eval_config(offload=False) if not args.fsdp else None
+
+ # configure model
+ with strategy.init_model_context(meta_init=args.meta_init):
+ strategy.print(f"Initializing models with meta_init={args.meta_init}")
+
+ # Select Actor class based on text_only flag
+ if args.text_only:
+ Actor = ActorLanguage
+ else:
+ Actor = ActorVL
+
+ # Initialize Actor (policy model)
+ actor = Actor(
+ args.pretrain,
+ use_flash_attention_2=args.flash_attn,
+ bf16=args.bf16,
+ load_in_4bit=args.load_in_4bit,
+ lora_rank=args.lora_rank,
+ lora_alpha=args.lora_alpha,
+ target_modules=args.target_modules,
+ lora_dropout=args.lora_dropout,
+ ds_config=ds_train_cfg,
+ packing_samples=args.packing_samples,
+ disable_logprobs_flashattn=args.disable_logprobs_flashattn,
+ fused_linear_logprob=args.fused_linear_logprob,
+ )
+
+ if args.actor_init_on_gpu:
+ actor = actor.to(torch.cuda.current_device())
+
+ # pre-prepare is used for saving RAM memory when training 72B model
+ if args.fsdp:
+ setattr(actor, "is_actor", True)
+ actor = strategy.prepare_model(actor, is_training=True)
+
+ # Optionally freeze parameters (e.g., vision encoder)
+ if args.freeze_prefix:
+ freeze_prefix = ["visual"]
+ frozen_params_count = 0
+ total_params_count = 0
+ for name, param in actor.model.named_parameters():
+ total_params_count += 1
+ if any(name.startswith(prefix) for prefix in freeze_prefix):
+ param.requires_grad = False
+ frozen_params_count += 1
+ strategy.print(f"Froze {frozen_params_count}/{total_params_count} parameters based on prefixes: {freeze_prefix}")
+
+ if args.critic_pretrain:
+ with strategy.init_model_context(meta_init=args.meta_init):
+ critic = CriticVL(
+ args.critic_pretrain,
+ use_flash_attention_2=args.flash_attn,
+ bf16=args.bf16,
+ load_in_4bit=args.load_in_4bit,
+ lora_rank=args.lora_rank,
+ lora_alpha=args.lora_alpha,
+ target_modules=args.target_modules,
+ lora_dropout=args.lora_dropout,
+ normalize_reward=args.normalize_reward_for_critic,
+ ds_config=ds_train_cfg,
+ init_value_head=strategy.args.pretrain == strategy.args.critic_pretrain,
+ value_head_prefix=args.value_head_prefix,
+ )
+ else:
+ critic = None
+
+ if args.fsdp and critic is not None:
+ critic = strategy.prepare_model(critic, is_training=True)
+
+ # Load the general reward model used by this demo.
+    strategy.report_memory("before loading reward models in main entry")
+ reward_models, reward_tokenizers, label_map = load_reward_models(
+ raw_reward_pretrain=args.reward_pretrain,
+ strategy=strategy,
+ use_engine=args.rm_use_engine,
+ )
+ strategy.print(f"label_map: {label_map}")
+    strategy.report_memory("after loading reward models in main entry")
+
+ strategy.print(actor)
+ strategy.print(critic)
+
+ # load weights for reference actor
+ if args.init_kl_coef == 0:
+ initial_model = None
+ else:
+ initial_model = Actor(
+ args.pretrain,
+ use_flash_attention_2=args.flash_attn,
+ bf16=args.bf16,
+ load_in_4bit=args.load_in_4bit,
+ ds_config=ds_eval_cfg,
+ packing_samples=args.packing_samples,
+ fused_linear_logprob=args.fused_linear_logprob,
+ )
+
+ if args.fsdp:
+ initial_model = strategy.prepare_model(initial_model, is_training=False, shard_size=strategy.world_size)
+ strategy.offload_model(initial_model)
+
+ if args.enable_ema:
+ ema_model = Actor(
+ args.pretrain,
+ use_flash_attention_2=args.flash_attn,
+ bf16=args.bf16,
+ load_in_4bit=args.load_in_4bit,
+ ds_config=ds_eval_cfg,
+ )
+ else:
+ ema_model = None
+
+ # configure tokenizer and processor
+ tokenizer, processor = get_tokenizer_processor_vl(
+ args.pretrain, actor.model, "left", use_fast=not strategy.args.disable_fast_tokenizer
+ )
+ assert processor is not None, "processor is None"
+
+ # ==================== Data Loading Optimization ====================
+ # The following sections now rely on the robust `blending_datasets` function.
+ # We add more logging for clarity.
+
+ # Prepare prompts dataset
+ strategy.print(f"Loading prompts dataset from: {args.prompt_data} with split: {args.prompt_split}")
+ prompts_data = blending_datasets(
+ args.prompt_data,
+ args.prompt_data_probs,
+ strategy,
+ args.seed,
+ return_eval=False,
+ train_split=args.prompt_split,
+ )
+ prompts_data = _apply_label_override(
+ prompts_data, args.label_key, args.label_override, strategy, "prompt dataset"
+ )
+
+ prompts_data = prompts_data.select(range(min(args.max_samples, len(prompts_data))))
+ prompts_dataset = PromptDatasetVL(prompts_data, tokenizer, processor, args.prompt_max_len, strategy, input_template=args.input_template)
+ strategy.print(f"Loaded {len(prompts_dataset)} samples for prompts.")
+
+ # Prepare evaluation dataset
+ eval_dataloader = None
+ if args.eval_data or args.eval_split:
+ eval_data_path = args.eval_data if args.eval_data else args.prompt_data
+ if eval_data_path:
+ strategy.print(f"Loading evaluation dataset from {eval_data_path}, split='{args.eval_split}'")
+ eval_data = blending_datasets(
+ eval_data_path, "1.0", strategy, args.seed, return_eval=False,
+ # Note: `train_split` parameter is used to specify the desired split name for evaluation data.
+ train_split=args.eval_split,
+ )
+ eval_data = _apply_label_override(
+ eval_data, args.label_key, args.label_override, strategy, "evaluation dataset"
+ )
+ if len(eval_data) == 0:
+ strategy.print(f"Warning: Evaluation dataset at {eval_data_path} with split '{args.eval_split}' is empty. Skipping evaluation.")
+ else:
+ eval_data = eval_data.select(range(min(args.max_eval_samples, len(eval_data))))
+
+ eval_dataset = PromptDatasetVL(eval_data, tokenizer, processor, args.prompt_max_len, strategy, input_template=args.input_template)
+ eval_dataloader = strategy.setup_dataloader(
+ eval_dataset, args.rollout_batch_size // strategy.world_size, False, False, collate_fn=eval_dataset.collate_fn
+ )
+ strategy.print(f"Evaluation dataset loaded: {len(eval_dataset)} samples")
+ else:
+ strategy.print("Warning: eval_split specified but no data path available for evaluation.")
+
+ # Prepare pretrain dataset
+ pretrain_dataloader = None
+ if args.pretrain_data:
+ strategy.print(f"Loading pretrain dataset from: {args.pretrain_data} with split: {args.pretrain_split}")
+ pretrain_data = blending_datasets(
+ args.pretrain_data, args.pretrain_data_probs, strategy, args.seed,
+ return_eval=False, train_split=args.pretrain_split,
+ )
+ if len(pretrain_data) == 0:
+ strategy.print(f"Warning: Pretrain dataset at {args.pretrain_data} is empty. PTX loss will not be applied.")
+ pretrain_dataloader = None
+ else:
+ pretrain_max_len = args.max_len if args.max_len else args.prompt_max_len + args.generate_max_len
+ # Calculate total samples needed for pretraining
+ total_pretrain_samples = args.max_epochs * len(prompts_dataset) * args.n_samples_per_prompt
+ pretrain_data_subset = pretrain_data.select(range(min(len(pretrain_data), total_pretrain_samples)))
+
+ pretrain_dataset = SFTDatasetVL(
+ pretrain_data_subset, tokenizer, pretrain_max_len, strategy, pretrain_mode=True,
+ )
+ strategy.print(f"Loaded {len(pretrain_dataset)} samples for pretraining.")
+ pretrain_dataloader = itertools.cycle(
+ iter(
+ strategy.setup_dataloader(
+ pretrain_dataset, args.micro_train_batch_size, True, True, pretrain_dataset.collate_fn,
+ )
+ )
+ )
+ else:
+ pretrain_dataloader = None
+
+ # Prepare prompts dataloader
+ prompts_dataloader = strategy.setup_dataloader(
+ prompts_dataset, args.rollout_batch_size // strategy.world_size, True, True, collate_fn=prompts_dataset.collate_fn
+ )
+
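+    if args.pretrain_data and pretrain_dataloader is not None:  # skip when the PTX dataset was empty and pretrain_dataset was never built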
+ pretrain_dataloader = itertools.cycle(
+ iter(
+ strategy.setup_dataloader(
+ pretrain_dataset,
+ args.micro_train_batch_size,
+ True,
+ True,
+ pretrain_dataset.collate_fn,
+ )
+ )
+ )
+ else:
+ pretrain_dataloader = None
+
+ # for scheduler
+ num_update_steps_per_episodes = (
+ len(prompts_dataset) * args.n_samples_per_prompt // args.train_batch_size * args.max_epochs
+ )
+ max_steps = math.ceil(args.num_episodes * num_update_steps_per_episodes)
+
+ # gradient_checkpointing
+ if args.gradient_checkpointing:
+ actor.gradient_checkpointing_enable(
+ gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant}
+ )
+ if critic is not None:
+ critic.gradient_checkpointing_enable(
+ gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant}
+ )
+
+ (
+ (actor, actor_optim, actor_scheduler),
+ (critic, critic_optim, critic_scheduler),
+ reward_models,
+ initial_model,
+ ) = strategy.prepare_models_and_optimizers(actor, critic, reward_models, initial_model, args, max_steps)
+
+ strategy.print(reward_models)
+
+ if ema_model:
+ ema_model._offload = True
+ ema_model = strategy.prepare(ema_model, is_rlhf=True)
+
+ # load checkpoint
+ consumed_samples = 0
+ if args.load_checkpoint and os.path.exists(os.path.join(args.ckpt_path, "_actor")):
+ _, states = strategy.load_ckpt(actor.model, os.path.join(args.ckpt_path, "_actor"),
+ optimizer=actor_optim, scheduler=actor_scheduler)
+ if args.critic_pretrain:
+ strategy.load_ckpt(critic, os.path.join(args.ckpt_path, "_critic"))
+ consumed_samples = states["consumed_samples"]
+ strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}")
+
+ os.makedirs(args.save_path, exist_ok=True)
+ strategy.report_memory("after models init")
+
+ strategy.report_memory("before setup_inference_engine")
+ strategy.setup_inference_engine(args, engine_type=args.engine_type, actor=actor)
+ strategy.report_memory("after setup_inference_engine")
+
+ # configure Trainer
+ trainer = SPMDPPOTrainerVL(
+ strategy,
+ actor,
+ critic,
+ reward_models,
+ initial_model,
+ ema_model,
+ actor_optim,
+ critic_optim,
+ actor_scheduler,
+ critic_scheduler,
+ max_epochs=args.max_epochs,
+ micro_train_batch_size=args.micro_train_batch_size,
+ micro_rollout_batch_size=args.micro_rollout_batch_size,
+ gradient_checkpointing=args.gradient_checkpointing,
+ tokenizer=tokenizer,
+ processor=processor,
+ prompt_max_len=args.prompt_max_len,
+ value_clip=args.value_clip,
+ eps_clip=args.eps_clip,
+ loss_agg_mode=args.loss_agg_mode,
+ use_gspo=args.use_gspo,
+ normalize_advantages=args.normalize_advantages,
+ use_sequence_rewards=args.use_sequence_rewards,
+ gamma=args.gamma,
+ lambd=args.lambd,
+ init_kl_coef=args.init_kl_coef,
+ kl_target=args.kl_target,
+ ema_beta=0.992,
+ ptx_coef=args.ptx_coef,
+ max_norm=args.max_norm,
+ # for GPT generation
+ do_sample=True,
+ max_new_tokens=args.generate_max_len,
+ max_length=args.max_len,
+ temperature=args.temperature,
+ top_p=args.top_p,
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ # reward model
+ reward_fn=reward_fn,
+ reward_fn_label_map=label_map,
+ reward_recipe=RECIPE,
+ reward_tokenizers=reward_tokenizers,
+ save_hf_ckpt=args.save_hf_ckpt,
+ disable_ds_ckpt=args.disable_ds_ckpt,
+ packing_samples=args.packing_samples,
+ # overlong_reward
+ dynamic_sampling=args.dynamic_sampling,
+ overlong_buffer=args.overlong_buffer,
+ overlong_buffer_len=args.overlong_buffer_len,
+ overlong_buffer_penalty_factor=args.overlong_buffer_penalty_factor,
+ print_replay_buffer_stats=args.print_replay_buffer_stats,
+ )
+
+ trainer.fit(args, prompts_dataloader=prompts_dataloader, pretrain_dataloader=pretrain_dataloader, eval_dataloader=eval_dataloader, consumed_samples=consumed_samples, num_update_steps_per_episodes=num_update_steps_per_episodes)
+
+ # save model checkpoint after fitting on only rank0
+ strategy.save_model(
+ ema_model if args.enable_ema else actor,
+ tokenizer,
+ args.save_path,
+ )
+
+ if args.critic_pretrain and args.save_value_network:
+ strategy.save_model(
+ critic,
+ tokenizer,
+ args.save_path + "_critic",
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--engine_type", type=str, default="vllm", help="Choose inference engine type: vllm, sglang")
+ parser.add_argument("--text_only", action="store_true", default=False)
+
+ # Checkpoint
+ parser.add_argument("--save_path", type=str, default="./ckpt")
+ parser.add_argument("--save_steps", type=int, default=-1)
+ parser.add_argument("--save_hf_ckpt", action="store_true", default=False)
+ parser.add_argument("--disable_ds_ckpt", action="store_true", default=False)
+ parser.add_argument("--save_trajectories", action="store_true", default=False, help="Save experience trajectories to JSON for debugging")
+ parser.add_argument("--num_trajectories_to_save", type=int, default=10, help="Number of trajectories to save per checkpoint")
+ parser.add_argument("--trajectory_analysis", action="store_true", default=False, help="Enable trajectory analysis metrics (repeat_score, reflection_pattern, policy_entropy) and log to wandb")
+ parser.add_argument("--print_replay_buffer_stats", action="store_true", default=False, help="Print detailed replay buffer statistics during training")
+ parser.add_argument("--logging_steps", type=int, default=1)
+ parser.add_argument("--eval_steps", type=int, default=-1)
+ parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_ppo")
+ parser.add_argument("--max_ckpt_num", type=int, default=3)
+ parser.add_argument("--max_ckpt_mem", type=int, default=int(1e8))
+ parser.add_argument("--load_checkpoint", action="store_true", default=False)
+
+ # DAPO
+ parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy")
+ parser.add_argument("--overlong_buffer", action="store_true", default=False, help="Apply overlong sequence buffer in DAPO")
+ parser.add_argument("--overlong_buffer_len", type=int, default=1024, help="Max token threshold for overlong buffer")
+ parser.add_argument("--overlong_buffer_penalty_factor", type=float, default=1.0, help="Penalty scaling factor for overlong sequences, <1 discourages long outputs; >1 encourages them")
+
+ # PPO
+ parser.add_argument("--num_episodes", type=int, default=1)
+ parser.add_argument("--rollout_batch_size", type=int, default=512)
+ parser.add_argument("--micro_rollout_batch_size", type=int, default=8)
+ parser.add_argument("--max_epochs", type=int, default=1)
+ parser.add_argument("--prompt_max_len", type=int, default=1024, help="Max tokens for each prompt")
+ parser.add_argument("--generate_max_len", type=int, default=1024, help="Max tokens to generate in PPO")
+ parser.add_argument("--max_len", type=int, default=None, help="deprecated max_len")
+ parser.add_argument("--max_samples", type=int, default=1000000)
+ parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping")
+ parser.add_argument("--l2", type=float, default=0.0, help="weight decay loss")
+ parser.add_argument("--ptx_coef", type=float, default=0.05, help="PPO-ptx loss coef")
+ parser.add_argument("--eps_clip", type=float, default=0.2, help="PPO clip range")
+ parser.add_argument("--loss_agg_mode", type=str, default='seq-mean-token-mean',
+ help="Loss aggregation mode. Options: ['token-mean', 'seq-mean-token-sum', 'seq-mean-token-mean', 'seq-mean-token-sum-norm']")
+ parser.add_argument("--use_gspo", action="store_true", default=False, help="Enable GSPO (Group Sequence Policy Optimization) mode")
+ parser.add_argument("--normalize_advantages", action="store_true", default=True, help="Enable advantage normalization in GSPO (always on: store_true with default=True)")
+ parser.add_argument("--use_sequence_rewards", action="store_true", default=True, help="Use sequence-level rewards in GSPO (always on: store_true with default=True)")
+ parser.add_argument("--value_clip", type=float, default=0.2, help="PPO value clip range")
+ parser.add_argument("--lambd", type=float, default=0.95, help="PPO GAE lambd")
+ parser.add_argument("--gamma", type=float, default=1, help="PPO GAE gamma")
+ parser.add_argument("--micro_train_batch_size", type=int, default=4, help="batch size per GPU")
+ parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size")
+ parser.add_argument("--normalize_reward_for_critic", action="store_true", default=False, help="Enable Reward Normalization in critic model")
+ parser.add_argument("--top_p", type=float, default=1.0)
+ parser.add_argument("--temperature", type=float, default=1.0)
+ parser.add_argument("--freeze_prefix", action="store_true", default=False, help="Freeze the prefix part (e.g. vision encoder) of the actor model")
+ parser.add_argument("--freezing_actor_steps", type=int, default=-1, help="Used for critic initialization")
+ parser.add_argument(
+ "--n_samples_per_prompt", type=int, default=1, help="number of responses for each prompt in generation"
+ )
+ parser.add_argument("--save_value_network", action="store_true", default=False, help="Save critic model")
+ parser.add_argument("--actor_learning_rate", type=float, default=1e-6)
+ parser.add_argument("--critic_learning_rate", type=float, default=9e-6)
+ parser.add_argument("--lr_warmup_ratio", type=float, default=0.03)
+ parser.add_argument("--kl_target", type=float, default=None)
+ parser.add_argument("--init_kl_coef", type=float, default=0.01, help="KL penalty in PPO")
+ parser.add_argument(
+ "--kl_estimator",
+ type=str,
+ default="k1",
+ choices=["k1", "k2", "k3"],
+ help=(
+ "In GRPO, k3 is utilized as the loss function, while k2, when used as the loss, is nearly equivalent to k1."
+ ),
+ )
+ parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer")
+
+ # Reward/Advantage Norm/Clip Arguments
+ parser.add_argument("--reward_running_norm", action="store_true", default=False, help="Enable running normalization for rewards.")
+ parser.add_argument("--reward_running_norm_minus_mean", action="store_true", default=False, help="When using reward normalization, subtract the mean; otherwise, only scale by the std.")
+ parser.add_argument("--reward_clip", type=float, default=0.0, help="Clip rewards to the range [-reward_clip, reward_clip]. 0.0 means no clipping.")
+ parser.add_argument("--advantages_norm", action="store_true", default=False, help="Enable whitening for advantages.")
+ parser.add_argument("--advantage_clip", type=float, default=0.0, help="Clip advantages to the range [-advantage_clip, advantage_clip]. 0.0 means no clipping.")
+
+ # DeepSpeed
+ parser.add_argument("--seed", type=int, default=42)
+ parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed")
+ parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage")
+ parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
+ parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16")
+ parser.add_argument("--enable_ema", action="store_true", help="Enable EMA checkpoint for the model.")
+ parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size")
+ parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer")
+ parser.add_argument("--actor_init_on_gpu", action="store_true", default=False)
+ parser.add_argument("--flash_attn", action="store_true", default=False, help="Enable FlashAttention2")
+ parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss")
+ parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type")
+ parser.add_argument("--overlap_comm", action="store_true", default=False)
+ parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False)
+ parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False)
+ parser.add_argument("--disable_logprobs_flashattn", action="store_true", default=False, help="Disable flash attn implementation in log_probs calculation")
+
+ # FSDP
+ parser.add_argument("--no_shard_vit", action="store_true", default=False, help="Disable sharding for vision transformer")
+ parser.add_argument("--meta_init", action="store_true", default=False, help="Initialize models on meta device to save CPU memory")
+
+ # Reinforce
+ parser.add_argument(
+ "--advantage_estimator",
+ type=str,
+ choices=["gae", "reinforce", "rloo", "reinforce_baseline", "group_norm", "cpgd", "reinforce++"],
+ default="gae",
+ help="Choose advantage estimation method: gae, reinforce, rloo, reinforce_baseline, group_norm, cpgd, reinforce++",
+ )
+
+ parser.add_argument("--use_kl_loss", action="store_true", default=False, help="whether to use KL loss from GRPO")
+
+ # LoRA
+ parser.add_argument("--load_in_4bit", action="store_true", default=False)
+ parser.add_argument("--lora_rank", type=int, default=0)
+ parser.add_argument("--lora_alpha", type=int, default=16)
+ parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear")
+ parser.add_argument("--lora_dropout", type=float, default=0)
+
+ # Models
+ parser.add_argument("--pretrain", type=str, default=None, help="HF model name or path")
+ parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path")
+ parser.add_argument("--remote_rm_url", type=str, default=None, help="remote RM API")
+ parser.add_argument("--critic_pretrain", type=str, default=None, help="HF model name or path")
+ parser.add_argument("--value_head_prefix", type=str, default="score")
+
+ # Custom dataset
+ parser.add_argument("--prompt_data", type=str, default=None, help="HF dataset name or path")
+ parser.add_argument(
+ "--prompt_data_probs",
+ type=str,
+ default="1.0",
+ help="sampling probs for datasets",
+ )
+ parser.add_argument("--prompt_split", type=str, default="train")
+
+ # Evaluation dataset
+ parser.add_argument("--eval_data", type=str, default=None, help="HF evaluation dataset name or path (default: use prompt_data)")
+ parser.add_argument("--eval_split", type=str, default="test", help="Evaluation data split (default: test)")
+ parser.add_argument("--max_eval_samples", type=int, default=500, help="Maximum number of samples to evaluate (default: 500)")
+
+ parser.add_argument("--pretrain_data", type=str, default=None, help="HF dataset name or path")
+ parser.add_argument(
+ "--pretrain_data_probs",
+ type=str,
+ default="1.0",
+ help="sampling probs for datasets",
+ )
+ parser.add_argument("--pretrain_split", type=str, default="train")
+ parser.add_argument("--input_key", type=str, default="input", help="JSON dataset key")
+ parser.add_argument("--images_key", type=str, default="image", help="JSON dataset key for images")
+ parser.add_argument("--reference_key", type=str, default="reference", help="JSON dataset key for reference answers")
+ parser.add_argument("--label_key", type=str, default="label", help="JSON dataset key")
+ parser.add_argument(
+ "--label_override",
+ type=str,
+ default=None,
+ help="Optional label override applied after dataset loading.",
+ )
+ parser.add_argument("--input_template", type=str, default=None)
+ parser.add_argument(
+ "--apply_chat_template", action="store_true", default=False, help="Use HF tokenizer chat template"
+ )
+
+ parser.add_argument("--system_prompt", type=str, default=None, help="HF System Prompt")
+
+
+ # wandb parameters
+ parser.add_argument("--use_wandb", type=str, default=None)
+ parser.add_argument("--wandb_org", type=str, default=None)
+ parser.add_argument("--wandb_group", type=str, default=None)
+ parser.add_argument("--wandb_project", type=str, default="lightrft_train_ppo")
+ parser.add_argument(
+ "--wandb_run_name",
+ type=str,
+ default="ppo_%s" % datetime.now().strftime("%m%dT%H:%M"),
+ )
+
+ # TensorBoard parameters
+ parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path")
+
+ # ModelScope parameters
+ parser.add_argument("--use_ms", action="store_true", default=False)
+
+ # MultiModal
+ parser.add_argument("--limit_mm_image_per_prompt", type=int, default=-1, help="Max number of images per prompt for the multimodal inference backend")
+
+ # CPGD
+ parser.add_argument("--use_cpg_loss", action="store_true", default=False, help="whether to use the clipped policy gradient loss from CPGD")
+
+ add_arguments(parser)
+
+ args = parser.parse_args()
+
+
+ if args.advantage_estimator not in ["gae"]:
+ args.critic_pretrain = None
+ elif args.critic_pretrain is None:
+ args.critic_pretrain = args.pretrain
+
+ if args.advantage_estimator in ["rloo", "reinforce_baseline", "group_norm"]:
+ assert args.n_samples_per_prompt > 1, f"{args.advantage_estimator} requires n_samples_per_prompt > 1"
+
+ if args.use_kl_loss:
+ if args.kl_estimator not in ["k2", "k3"]:
+ print(f"Recommend setting kl_estimator to 'k2' or 'k3' when using KL as a loss (got '{args.kl_estimator}').")
+ else:
+ if args.kl_estimator not in ["k1"]:
+ print(f"Recommend setting kl_estimator to 'k1' when not using KL as a loss (got '{args.kl_estimator}').")
+
+ if args.advantage_estimator in ["gae", "cpgd"] and args.use_kl_loss:
+ warnings.warn(
+ "Using use_kl_loss=True with non-normalized advantage estimator "
+ "may result in double KL penalty. Consider disabling --use_kl_loss "
+ "or using --advantage_estimator group_norm"
+ )
+
+ if args.input_template and "{}" not in args.input_template:
+ print("[Warning] {} not in args.input_template, set to None")
+ args.input_template = None
+
+ if args.input_template and "\\n" in args.input_template:
+ print(
+ "[Warning] input_template contains \\n characters instead of newline. "
+ "You likely want to pass $'\\n' in Bash or \"`n\" in PowerShell."
+ )
+
+ if args.use_ms:
+ from modelscope.utils.hf_util import patch_hub
+
+ # Patch hub to download models from modelscope to speed up.
+ patch_hub()
+
+ train(args)
diff --git a/lightrft/strategy/strategy_base.py b/lightrft/strategy/strategy_base.py
index ac72ba25..050c5a00 100644
--- a/lightrft/strategy/strategy_base.py
+++ b/lightrft/strategy/strategy_base.py
@@ -535,10 +535,13 @@ def prepare_models_and_optimizers(self, actor, critic, reward_models, initial_mo
if critic is not None:
critic = self.prepare_model(critic, is_training=True)
if not self.config.remote_rm_url:
+ reward_model_shard_size = self.world_size
if isinstance(reward_models, (tuple, list)):
- reward_models = [self.prepare_model(model, shard_size=8) for model in reward_models]
+ reward_models = [
+ self.prepare_model(model, shard_size=reward_model_shard_size) for model in reward_models
+ ]
else:
- reward_models = self.prepare_model(reward_models, shard_size=8)
+ reward_models = self.prepare_model(reward_models, shard_size=reward_model_shard_size)
# Configure optimizers
actor_optim = self.create_optimizer(
diff --git a/lightrft/strategy/vllm_utils/vllm_worker_wrap_no_ray.py b/lightrft/strategy/vllm_utils/vllm_worker_wrap_no_ray.py
index 89c6104e..3747e4e5 100644
--- a/lightrft/strategy/vllm_utils/vllm_worker_wrap_no_ray.py
+++ b/lightrft/strategy/vllm_utils/vllm_worker_wrap_no_ray.py
@@ -7,6 +7,8 @@
"""
import torch
+import vllm
+from packaging.version import Version
# vLLM version compatibility notes:
# --------------------------------
# In older versions of vLLM (< 0.13.0), the Worker class is located under:
@@ -16,11 +18,22 @@
# vllm.v1.worker.gpu_worker.Worker
#
# To maintain compatibility across different vLLM versions, we try importing Worker
-# from the new v1 path first (for vllm>=0.13.0). If the import fails (ModuleNotFoundError),
-# we fall back to importing from the old path (for vllm<0.13.0).
-try:
- from vllm.v1.worker.gpu_worker import Worker
-except (ModuleNotFoundError, ImportError):
+# from the new v1 path only for vllm>=0.13.0. Older releases like v0.7.x may
+# already expose a v1 module tree, but their uniproc executor still expects the
+# legacy Worker implementation with methods such as determine_num_available_blocks.
+if Version(vllm.__version__) >= Version("0.13.0"):
+ try:
+ from vllm.v1.worker.gpu_worker import Worker
+ except (ModuleNotFoundError, ImportError):
+ try:
+ from vllm.worker.worker import Worker
+ except (ModuleNotFoundError, ImportError):
+ raise ImportError(
+ "Could not import Worker from vllm. "
+ "Please ensure you have a compatible version of vllm installed. "
+ "Supported versions: vllm>=0.6.3 or vllm>=0.13.0"
+ )
+else:
try:
from vllm.worker.worker import Worker
except (ModuleNotFoundError, ImportError):
diff --git a/lightrft/trainer/fast_exp_maker.py b/lightrft/trainer/fast_exp_maker.py
index fe12d534..4f11022b 100644
--- a/lightrft/trainer/fast_exp_maker.py
+++ b/lightrft/trainer/fast_exp_maker.py
@@ -791,6 +791,7 @@ def _compute_standard_torch_rewards(
rm_output = rm(
sequences,
output.attention_mask,
+ references=output.references,
prompt_and_output=output.prompt_and_output,
raw_images=output.raw_images,
img_num=output.image_num,
diff --git a/lightrft/trainer/ppo_trainer_vl.py b/lightrft/trainer/ppo_trainer_vl.py
index e3b44091..ab2fc37c 100644
--- a/lightrft/trainer/ppo_trainer_vl.py
+++ b/lightrft/trainer/ppo_trainer_vl.py
@@ -405,6 +405,7 @@ def fit(
all_rewards = []
all_format_rewards = []
all_accuracy_rewards = []
+ all_general_model_rewards = []
all_response_lengths = []
for item in self.replay_buffer.items:
@@ -428,6 +429,11 @@ def fit(
all_format_rewards.append(reward_metrics['format_reward'])
if 'accuracy_reward' in reward_metrics:
all_accuracy_rewards.append(reward_metrics['accuracy_reward'])
+ general_model_reward = reward_metrics.get("general_model_reward")
+ if general_model_reward is None:
+ general_model_reward = reward_metrics.get("model_reward")
+ if general_model_reward is not None:
+ all_general_model_rewards.append(general_model_reward)
# Collect response lengths from rollout
if hasattr(item, 'info') and item.info is not None and 'response_length' in item.info:
@@ -476,6 +482,18 @@ def fit(
if abs(mean_accuracy_reward) > 1e-6:
rollout_status["rollout_accuracy_reward"] = mean_accuracy_reward
+ if all_general_model_rewards:
+ if isinstance(all_general_model_rewards[0], torch.Tensor):
+ general_model_tensor = torch.cat([t.to(device).float() for t in all_general_model_rewards])
+ else:
+ general_model_tensor = torch.tensor(
+ all_general_model_rewards, dtype=torch.float32, device=device
+ )
+
+ mean_general_model_reward = general_model_tensor.mean().item()
+ if abs(mean_general_model_reward) > 1e-6:
+ rollout_status["rollout_general_model_reward"] = mean_general_model_reward
+
if all_response_lengths:
# [TENSOR-FIX] Handle both tensor lists and scalar lists
if isinstance(all_response_lengths[0], torch.Tensor):
@@ -1200,6 +1218,7 @@ def evaluate(self, eval_dataloader, global_step):
all_rewards = []
all_format_rewards = []
all_accuracy_rewards = []
+ all_general_model_rewards = []
all_response_lengths = []
num_eval_batches = 0
@@ -1249,6 +1268,11 @@ def extract_values(val):
all_format_rewards.extend(extract_values(rm['format_reward']))
if 'accuracy_reward' in rm:
all_accuracy_rewards.extend(extract_values(rm['accuracy_reward']))
+ general_model_reward = rm.get("general_model_reward")
+ if general_model_reward is None:
+ general_model_reward = rm.get("model_reward")
+ if general_model_reward is not None:
+ all_general_model_rewards.extend(extract_values(general_model_reward))
num_eval_batches += 1
if num_eval_batches >= len(eval_dataloader):
@@ -1271,6 +1295,7 @@ def compute_stats(name, values_list):
compute_stats("reward", all_rewards)
compute_stats("format_reward", all_format_rewards)
compute_stats("accuracy_reward", all_accuracy_rewards)
+ compute_stats("general_model_reward", all_general_model_rewards)
compute_stats("response_length", all_response_lengths)
metrics["num_samples"] = len(all_rewards)
diff --git a/lightrft/trainer/spmd_ppo_trainer.py b/lightrft/trainer/spmd_ppo_trainer.py
index d79a7458..a0574905 100644
--- a/lightrft/trainer/spmd_ppo_trainer.py
+++ b/lightrft/trainer/spmd_ppo_trainer.py
@@ -316,7 +316,7 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train
all_rewards = []
all_format_rewards = []
all_accuracy_rewards = []
- all_model_rewards = []
+ all_general_model_rewards = []
all_rule_rewards = []
all_advantages = []
all_returns = []
@@ -334,8 +334,11 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train
all_format_rewards.append(reward_metrics['format_reward'])
if 'accuracy_reward' in reward_metrics:
all_accuracy_rewards.append(reward_metrics['accuracy_reward'])
- if 'model_reward' in reward_metrics:
- all_model_rewards.append(reward_metrics['model_reward'])
+ general_model_reward = reward_metrics.get("general_model_reward")
+ if general_model_reward is None:
+ general_model_reward = reward_metrics.get("model_reward")
+ if general_model_reward is not None:
+ all_general_model_rewards.append(general_model_reward)
if 'rule_reward' in reward_metrics:
all_rule_rewards.append(reward_metrics['rule_reward'])
@@ -379,15 +382,15 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train
status_mean["accuracy_reward_mean"] = accuracy_tensor.mean().item()
status_mean["accuracy_reward_std"] = accuracy_tensor.std().item()
- if all_model_rewards:
+ if all_general_model_rewards:
# [TENSOR-FIX] Handle both tensor lists and scalar lists
- if isinstance(all_model_rewards[0], torch.Tensor):
- model_tensor = torch.cat([t.to(device).float() for t in all_model_rewards])
+ if isinstance(all_general_model_rewards[0], torch.Tensor):
+ model_tensor = torch.cat([t.to(device).float() for t in all_general_model_rewards])
else:
- model_tensor = torch.tensor(all_model_rewards, dtype=torch.float32, device=device)
+ model_tensor = torch.tensor(all_general_model_rewards, dtype=torch.float32, device=device)
if model_tensor.abs().sum() > 0: # Only log if model rewards are non-zero
- status_mean["model_reward_mean"] = model_tensor.mean().item()
- self.strategy.print(f" model_reward_mean: {status_mean['model_reward_mean']}")
+ status_mean["general_model_reward_mean"] = model_tensor.mean().item()
+ self.strategy.print(f" general_model_reward_mean: {status_mean['general_model_reward_mean']}")
if all_rule_rewards:
# [TENSOR-FIX] Handle both tensor lists and scalar lists
@@ -444,6 +447,9 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train
f"✅ Accuracy Reward: {status_mean['accuracy_reward_mean']:.4f} ± {status_mean['accuracy_reward_std']:.4f}" # noqa
)
+ if all_general_model_rewards and "general_model_reward_mean" in status_mean:
+ self.strategy.print(f"🧠 General RM Reward: {status_mean['general_model_reward_mean']:.4f}")
+
if all_advantages:
self.strategy.print(
f"📈 Advantages: {status_mean['advantages_mean']:.4f} ± {status_mean['advantages_std']:.4f} " # noqa