diff --git a/.gitignore b/.gitignore
index d3a81dae..54733cc6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1219,4 +1219,8 @@ wandb*
examples/demo_grpo/results*
build/*
examples/math_benchmarks/eval_results/
-.llmconfig.yaml
\ No newline at end of file
+.llmconfig.yaml
+
+# Local agent tool state
+.claude/
+.codex
diff --git a/examples/math_prm/README.md b/examples/math_prm/README.md
new file mode 100644
index 00000000..a5c385a4
--- /dev/null
+++ b/examples/math_prm/README.md
@@ -0,0 +1,272 @@
+# Math PRM: GRPO Training with a Process Reward Model
+
+This example trains [URSA-8B](https://huggingface.co/URSA-MATH/URSA-8B) — a multimodal math VLM — with [URSA-8B-RM](https://huggingface.co/URSA-MATH/URSA-RM-8B) as a Process Reward Model (PRM), using the GRPO algorithm with a **PS-GRPO** reward signal as proposed in the [URSA paper (NeurIPS 2025)](https://arxiv.org/abs/2501.04686).
+
+Unlike the rule-based examples under `examples/gsm8k_geo3k/`, the reward here comes from a neural reward model that scores **each reasoning step**, and the final per-trajectory reward depends on *how the step scores evolve* across the response, not just on whether the final answer is right.
+
+The example ships **two algorithm paths** side by side:
+
+1. **PS-GRPO** (`run_grpo_math_prm_ursa_8b.sh`) — the paper's recommended reward `r ∈ {0, 0.5, 1}`, used as a single per-trajectory scalar by standard GRPO. This is the production recipe.
+2. **Strict paper Eq.9 variant 2** (`run_grpo_math_prm_ursa_8b_variant2.sh`) — the per-step PRM advantage `A_t^i = r_{s,t}^i · GroupNorm_G(r̄_s^i) + GroupNorm_G(r_o^i)` (paper Appendix B.1). The paper itself rejects this in favour of PS-GRPO; it ships here as an ablation comparator. The advantage estimator lives entirely in [`ursa_variant2.py`](ursa_variant2.py) (zero edits to `lightrft/`).
+
+## Overview
+
+| Item | Math PRM |
+|------|----------|
+| Task | Multimodal math reasoning (text + image questions) |
+| Modality | Multi-modal (text + image) |
+| Actor | URSA-8B (hybrid SAM-B + SigLIP-L vision tower + Qwen2.5-Math-Instruct) |
+| Reward Model | URSA-8B-RM (process reward model, step-level scoring) |
+| Reward formula (PS-GRPO) | `r ∈ {0, 0.5, 1}` (correctness × step-stability) |
+| Algorithm | GRPO (group_norm advantage estimator) or `ursa_variant2` for paper Eq.9 |
+| Rollout engine | Local Hugging Face (vLLM/SGLang URSA support is future work) |
+
+The PS-GRPO reward is computed inside `MathPRMReward` ([reward_models.py](reward_models.py)) and follows the URSA paper:
+
+```text
+r = 0 if outcome_correct == 0
+r = 1 if outcome_correct == 1 and no step-score drop
+r = 0.5 ( = 1 - DROP_GAMMA) if outcome_correct == 1 but a step-score drop occurred
+```
+
+A **step-score drop** is detected when any consecutive pair of step scores has a relative drop ≥ `_DROP_THRESHOLD = 0.3`.
+
+---
+
+## 1. Dataset Preprocessing
+
+The training data is `MMathCoT-1M` (Stage 3 split), which needs to be converted into the LightRFT manifest schema. Both `--input-path` and `--image-root` are **required** (no defaults — paths are environment-specific):
+
+```bash
+python examples/math_prm/tools/prepare_ursa_stage3_manifest.py \
+ --input-path /your/data/URSA-MATH/MMathCoT-1M/train.jsonl \
+ --image-root /your/data/URSA-MATH/images \
+ --output-path /your/output/math_psgrpo.jsonl
+```
+
+Each row in the converted manifest looks like:
+
+```json
+{
+ "prompt": "Math question text",
+ "images": ["/abs/path/to/image.png"],
+ "reference": "Ground-truth answer",
+ "label": "math_psgrpo"
+}
+```
+
+The `label` field is what selects the reward path. Available labels:
+
+| Label | Reward signal |
+|---|---|
+| `math_psgrpo` | PS-GRPO: `{0, 0.5, 1}` (default for this example) |
+| `math_prm` | Pure PRM aggregated step score (continuous in `[0, 1]`) |
+| `math_prm_combined` | PRM aggregated score + 0.5 × rule-based correctness |
+| `math_rule` | Rule-only baseline `{0, 1}` based on answer match |
+| `math_per_step_prm` | Per-step PRM scores for `--advantage_estimator ursa_variant2` (paper Eq.9, see §6) |
+
+For a smoke conversion (32 samples), pass `--max-samples 32`.
+
+---
+
+## 2. Model Checkpoints
+
+You need both the URSA-8B actor and the URSA-8B-RM reward model:
+
+```bash
+# Hugging Face IDs
+URSA-MATH/URSA-8B # actor
+URSA-MATH/URSA-RM-8B # reward model
+```
+
+Download to a local directory and set the paths in `run_grpo_math_prm_ursa_8b.sh`.
+
+---
+
+## 3. Configure and Run Training (PS-GRPO recipe)
+
+Edit `Part 1: User Configuration` at the top of [run_grpo_math_prm_ursa_8b.sh](run_grpo_math_prm_ursa_8b.sh):
+
+```bash
+PATH_TO_YOUR_BASE_MODEL="/path/to/URSA-8B"
+PATH_TO_URSA_RM="/path/to/URSA-RM-8B"
+PATH_TO_YOUR_MATH_DATASET="/path/to/math_psgrpo.jsonl"
+EXPERIMENT_NAME="lightrft-ursa8b-math-prm"
+export WANDB_API_KEY="YOUR_WANDB_API_KEY" # leave empty to disable W&B
+```
+
+Then run:
+
+```bash
+bash examples/math_prm/run_grpo_math_prm_ursa_8b.sh
+```
+
+The default machine target is `1 node × 8 A100/H100 GPUs`. For a different topology, override the standard env vars:
+
+```bash
+NNODES=2 GPUS_PER_NODE=8 NODE_RANK=0 \
+MASTER_ADDR=10.0.0.1 MASTER_PORT=20092 \
+bash examples/math_prm/run_grpo_math_prm_ursa_8b.sh
+```
+
+---
+
+## 4. Key Hyperparameters
+
+The launcher uses the URSA-MATH paper's Stage 3 defaults:
+
+| Param | Value | Notes |
+|---|---|---|
+| `N_SAMPLES` | 8 | Responses sampled per prompt for GRPO |
+| `EPISODE` | 10 | Total training episodes |
+| `RBS` / `TBS` | 128 / 128 | Rollout / training batch size |
+| `KL` | 0.001 | Initial KL coefficient |
+| `KL_TARGET` | (off) | If set, switches to AdaptiveKLController |
+| `LR` | 1e-6 | Actor learning rate |
+| `PROMPT_MAX_LEN` | 1024 | |
+| `GENERATE_MAX_LEN` | 3072 | |
+| `MAX_SAMPLES` | 15360 | Cap on training subset (paper proxy) |
+| `EVAL_HOLDOUT_SIZE` | 500 | A deterministic held-out subset is reserved from `prompt_data` for in-domain eval |
+
+To enable the adaptive KL controller (recommended if you observe the KL drifting), set `KL_TARGET` to a small positive value, e.g. `KL_TARGET=0.5`.
+
+---
+
+## 5. What's Logged
+
+W&B panels are split into three namespaces:
+
+- `rollout/*` — per-step rollout statistics: `reward`, `outcome_correct`, `model_reward`, `has_drop_moment`, `response_length`, `step_score_min/mean/last`, `step_count`, `final_reward`, `max_relative_drop`, `answer_tag_present`, `answer_extraction_failed`, `used_answer_fallback`, `used_mathruler`, `reference_supported`, plus variant-2 diagnostics `alignment_failed` / `n_aligned_steps`.
+- `train/*` — per-step training statistics: `policy_loss`, `kl`, `actor_lr`, `advantages`, `return`, plus variant-2 diagnostics `ursa_v2_adv_pos_frac` / `_neg_frac` / `_zero_frac` / `_abs_mean` / `_oc_normed_std` / `_msp_normed_std` / `_traj_step_count_mean`.
+- `eval/*` — evaluation pass on the held-out split: `reward`, `outcome_correct`, `response_length`, `answer_extraction_failed`, `has_drop_moment`, `model_reward`, `step_score_min/mean/last`, `step_count`, `final_reward`, `max_relative_drop`, `answer_tag_present`, `used_answer_fallback`, `used_mathruler`, `reference_supported`.
+
+The full per-sample reward metric set emitted by `MathPRMReward` is documented at the top of `forward()` in [reward_models.py](reward_models.py).
+
+---
+
+## 6. Strict Paper Eq.9 — variant 2 path
+
+`run_grpo_math_prm_ursa_8b_variant2.sh` runs the URSA paper's Appendix B.1 Eq.9 "variant 2" advantage formula side by side with PS-GRPO so the two can be ablated. The implementation lives in [`ursa_variant2.py`](ursa_variant2.py) as a new `--advantage_estimator ursa_variant2` registered via an idempotent monkey-patch from `examples/math_prm/` (no edits to `lightrft/`).
+
+### Formula
+
+```text
+A_t^i = r_{s,t}^i · GroupNorm_G(r̄_s^i) ← process-reward term
+ + GroupNorm_G(r_o^i) ← outcome-reward term
+```
+
+where `t` indexes a **step** (not a token), `r_{s,t}^i` is the sigmoid PRM score for step `t` in trajectory `i`, `r̄_s^i = mean_t r_{s,t}^i`, `r_o^i ∈ {0,1}` is the outcome reward, and `G` is the GRPO group size (`n_samples_per_prompt`). The per-step `A_t^i` is broadcast to every token within step `t`'s span. **There is no cumulative return**, and the outcome term is preserved (not bypassed).
+
+### Workflow
+
+The variant-2 path requires rows labeled `math_per_step_prm` instead of `math_psgrpo`. Easiest way is to sed-relabel the PS-GRPO manifest:
+
+```bash
+sed 's/"label":[ ]*"math_psgrpo"/"label": "math_per_step_prm"/g' \
+ /path/to/math_psgrpo.jsonl \
+ > /path/to/math_per_step_prm.jsonl
+```
+
+The variant-2 launcher will auto-swap to the relabeled sibling if it finds the legacy psgrpo path in `PATH_TO_YOUR_MATH_DATASET`, and assert that the first row's label is `math_per_step_prm` before training. Set `PATH_TO_YOUR_MATH_DATASET_VARIANT2` to a custom path to override this.
+
+```bash
+PATH_TO_YOUR_MATH_DATASET_VARIANT2=/path/to/math_per_step_prm.jsonl \
+bash examples/math_prm/run_grpo_math_prm_ursa_8b_variant2.sh
+```
+
+`--per_step_reward_mode` (`raw` / `group_norm`) only affects the **legacy Math-Shepherd-style per-token reward path** (different `_apply_step_reward_group_norm` aggregation); `--advantage_estimator ursa_variant2` does its own group normalization inside the calculator and is unaffected by this flag.
+
+### Unit tests
+
+`test_ursa_variant2.py` ships 9 acceptance-criterion tests (numerical equality with hand-computed Eq.9, GroupNorm correctness for K=2/K=4 groups, span broadcast, outcome-term non-bypass). Run them:
+
+```bash
+python3 -m unittest examples.math_prm.test_ursa_variant2 -v
+```
+
+---
+
+## 7. Results — 9-day production run (PS-GRPO)
+
+A 9-day full production run on 8× H100 with the PS-GRPO recipe is summarized below. The variant-2 path was validated by a parallel 9-day run (W&B `kdwjt4eo`); see [PR #53 final-report comment](https://github.com/opendilab/LightRFT/pull/53#issuecomment-4608400929) for the side-by-side comparison.
+
+| Metric | Baseline (Step 20) | Peak | Final | Δ vs baseline |
+|---|---|---|---|---|
+| `eval/outcome_correct` | 0.5952 | **0.6508** (Step 231) | 0.6290 (Step 1008) | **+3.4 pp** |
+| `eval/answer_extraction_failed` | 0.028 | 0.018 (~Step 160) | 0.034 | -0.6 pp ↓ |
+| `eval/has_drop_moment` | 0.0 | — | 0.0 | (PRM never triggered) |
+| `eval/response_length` | 400 | 337 (~Step 240) | 377 | -23 ↓ |
+| `rollout/alignment_failed` | 0 | — | 0 | 100% step-boundary alignment |
+| W&B run | [`kdwjt4eo`](https://wandb.ai/hansbug/LightRFT-URSA8B-Stage3/runs/kdwjt4eo) |
+
+#### Eval trajectory
+
+`eval/outcome_correct` peaks at Step 231 (+5.6pp), shows a transient dip near Step 300 (reward-hacking signature) but **self-heals**, and stabilizes in the 0.60–0.65 range for the remaining 7 days:
+
+
+
+#### KL + rollout overview
+
+`train/kl` exits warmup quickly (1e-4 → 1.0 by Step ~200), oscillates in 1–100 range, occasional single-batch spikes >100 always self-correct. `rollout/outcome_correct` and `rollout/model_reward` track together (no reward-hacking decoupling):
+
+
+
+#### Eval-time generation quality
+
+`eval/answer_extraction_failed` briefly spikes during the Step 300 dip (the `†Answer:` marker format drift the URSA paper warns about) but recovers back to 2–5%. `eval/response_length` and `eval/step_count` stay stable — no length collapse:
+
+
+
+#### variant 2 path health (W&B run `kdwjt4eo`)
+
+`ursa_v2_adv_pos_frac` and `_neg_frac` stay roughly balanced (30–40% / 25–35%) — GroupNorm produces signed advantages as expected. `_msp_normed_std` stays close to 1.0. `rollout/alignment_failed` is 0 the entire run:
+
+
+
+---
+
+## 8. Files Under This Directory
+
+```text
+examples/math_prm/
+├── README.md - This guide (en)
+├── README_zh.md - This guide (zh)
+├── train_colocate.py - Main training entry (called by torchrun)
+├── run_grpo_math_prm_ursa_8b.sh - PS-GRPO launcher (recommended)
+├── run_grpo_math_prm_ursa_8b_variant2.sh - Strict paper Eq.9 launcher (ablation)
+├── reward_models.py - MathPRMReward implementation (PS-GRPO)
+├── reward_models_utils.py - Reward recipe / mixing logic per label
+├── ursa_actor.py - URSA-specific actor wrapper
+├── ursa_variant2.py - UrsaVariant2Calculator (paper Eq.9, examples-only)
+├── math_prm_trainer.py - MathPRMSPMDPPOTrainerVL (wandb metric mapping)
+├── math_prm_output.py - "†Answer:" marker / structured-stop helpers
+├── rollout_eos_patch.py - StoppingCriteria injection for reliable EOS under FSDP
+├── test_ursa_variant2.py - 9 unit tests for variant 2 (AC1-AC5)
+├── ursa_model/ - Vendored URSA model code (config / processor / model)
+├── tools/
+│ ├── prepare_ursa_stage3_manifest.py - Dataset conversion tool
+│ └── prepare_ursa_engine_checkpoint.py - Engine-mode checkpoint wrapper
+└── assets/
+ └── exp_20260603/ - W&B screenshots from the 9-day production run
+```
+
+---
+
+## 9. Citation
+
+If you use this example, please cite the URSA paper:
+
+```bibtex
+@article{luo2025ursa,
+ title={URSA: Understanding and Verifying Chain-of-Thought Reasoning in Multimodal Mathematics},
+ author={Luo, Ruilin and Zheng, Zhuofan and Wang, Yifan and Yu, Yiyao and Ni, Xinzhe and Lin, Zicheng and Zeng, Jin and Yang, Yujiu},
+ journal={NeurIPS},
+ year={2025}
+}
+```
+
+---
+
+## License
+
+This example is released under the same license as the parent LightRFT project (see top-level `LICENSE`).
diff --git a/examples/math_prm/README_zh.md b/examples/math_prm/README_zh.md
new file mode 100644
index 00000000..d6ea7f2b
--- /dev/null
+++ b/examples/math_prm/README_zh.md
@@ -0,0 +1,272 @@
+# Math PRM:基于过程奖励模型 (PRM) 的 GRPO 训练
+
+本示例使用 [URSA-8B](https://huggingface.co/URSA-MATH/URSA-8B)(多模态数学 VLM)作为 actor,配合 [URSA-8B-RM](https://huggingface.co/URSA-MATH/URSA-RM-8B) 作为过程奖励模型 (PRM),按 [URSA 论文(NeurIPS 2025)](https://arxiv.org/abs/2501.04686)所提的 **PS-GRPO** 奖励路径用 GRPO 算法训练。
+
+不同于 `examples/gsm8k_geo3k/` 那类规则型 reward 示例,这里的 reward 来自一个对**每一步推理**打分的神经网络奖励模型,最终的 trajectory-level reward 取决于 step scores 沿 response 的**演化形态**,而不仅仅是最终答案是否正确。
+
+本 example 同时附带**两条算法路径**用于对比:
+
+1. **PS-GRPO**(`run_grpo_math_prm_ursa_8b.sh`)—— 论文最终采纳的 `r ∈ {0, 0.5, 1}` 单标量奖励,由标准 GRPO 处理。**生产推荐配方**。
+2. **Paper Eq.9 严格 variant 2**(`run_grpo_math_prm_ursa_8b_variant2.sh`)—— 论文附录 B.1 的逐 step PRM advantage:`A_t^i = r_{s,t}^i · GroupNorm_G(r̄_s^i) + GroupNorm_G(r_o^i)`。论文自身否决了它,本 example 保留只为做 ablation 对照。完整实现位于 [`ursa_variant2.py`](ursa_variant2.py)(不修改 `lightrft/`)。
+
+## 总览
+
+| 项 | Math PRM |
+|------|----------|
+| 任务 | 多模态数学推理(文本+图像题) |
+| 模态 | 多模态(文本 + 图像) |
+| Actor | URSA-8B(SAM-B + SigLIP-L 视觉塔 + Qwen2.5-Math-Instruct) |
+| Reward Model | URSA-8B-RM(过程奖励模型,step-level scoring) |
+| Reward 公式(PS-GRPO) | `r ∈ {0, 0.5, 1}`(正确性 × step 稳定性) |
+| 算法 | GRPO(group_norm 优势估计器)或 paper Eq.9 的 `ursa_variant2` |
+| Rollout 引擎 | 本地 Hugging Face(vLLM/SGLang 对 URSA 的支持待后续) |
+
+PS-GRPO 奖励在 `MathPRMReward`([reward_models.py](reward_models.py))中按 URSA 论文公式计算:
+
+```text
+r = 0 若 outcome_correct == 0
+r = 1 若 outcome_correct == 1 且无 step-score drop
+r = 0.5 ( = 1 - DROP_GAMMA) 若 outcome_correct == 1 但存在 step-score drop
+```
+
+**Step-score drop** 的判定:任意相邻 step score 出现相对降幅 ≥ `_DROP_THRESHOLD = 0.3`。
+
+---
+
+## 1. 数据预处理
+
+训练数据为 `MMathCoT-1M`(Stage 3 子集),需要转换成 LightRFT 的 manifest 格式。`--input-path` 与 `--image-root` 均**必填**(无默认值——路径与环境相关):
+
+```bash
+python examples/math_prm/tools/prepare_ursa_stage3_manifest.py \
+ --input-path /your/data/URSA-MATH/MMathCoT-1M/train.jsonl \
+ --image-root /your/data/URSA-MATH/images \
+ --output-path /your/output/math_psgrpo.jsonl
+```
+
+转换后每行 manifest 形如:
+
+```json
+{
+ "prompt": "数学题目文本",
+ "images": ["/abs/path/to/image.png"],
+ "reference": "标准答案",
+ "label": "math_psgrpo"
+}
+```
+
+`label` 字段决定选择哪一条 reward 路径。可选值:
+
+| Label | Reward 信号 |
+|---|---|
+| `math_psgrpo` | PS-GRPO:`{0, 0.5, 1}`(本 example 默认) |
+| `math_prm` | 纯 PRM 聚合 step score(连续值 `[0, 1]`) |
+| `math_prm_combined` | PRM 聚合分数 + 0.5 × 规则正确性 |
+| `math_rule` | 规则基线:`{0, 1}` 按答案匹配 |
+| `math_per_step_prm` | 逐 step PRM 分数,供 `--advantage_estimator ursa_variant2`(paper Eq.9,详见 §6)使用 |
+
+需要 32 行小规模转换做 smoke 时用 `--max-samples 32`。
+
+---
+
+## 2. 模型 checkpoint
+
+需要 URSA-8B(actor)与 URSA-8B-RM(reward model)两个权重:
+
+```bash
+# Hugging Face IDs
+URSA-MATH/URSA-8B # actor
+URSA-MATH/URSA-RM-8B # reward model
+```
+
+下载到本地后在 `run_grpo_math_prm_ursa_8b.sh` 里配置路径。
+
+---
+
+## 3. 配置并启动训练(PS-GRPO 配方)
+
+编辑 [run_grpo_math_prm_ursa_8b.sh](run_grpo_math_prm_ursa_8b.sh) 顶部的 `Part 1: User Configuration`:
+
+```bash
+PATH_TO_YOUR_BASE_MODEL="/path/to/URSA-8B"
+PATH_TO_URSA_RM="/path/to/URSA-RM-8B"
+PATH_TO_YOUR_MATH_DATASET="/path/to/math_psgrpo.jsonl"
+EXPERIMENT_NAME="lightrft-ursa8b-math-prm"
+export WANDB_API_KEY="YOUR_WANDB_API_KEY" # 留空则禁用 W&B
+```
+
+然后运行:
+
+```bash
+bash examples/math_prm/run_grpo_math_prm_ursa_8b.sh
+```
+
+默认目标硬件 `1 节点 × 8 A100/H100`。改 topology 用环境变量 override:
+
+```bash
+NNODES=2 GPUS_PER_NODE=8 NODE_RANK=0 \
+MASTER_ADDR=10.0.0.1 MASTER_PORT=20092 \
+bash examples/math_prm/run_grpo_math_prm_ursa_8b.sh
+```
+
+---
+
+## 4. 关键超参
+
+启动脚本使用 URSA-MATH 论文 Stage 3 默认值:
+
+| 参数 | 值 | 备注 |
+|---|---|---|
+| `N_SAMPLES` | 8 | 每个 prompt GRPO 采样数 |
+| `EPISODE` | 10 | 总训练 episodes |
+| `RBS` / `TBS` | 128 / 128 | rollout / training batch size |
+| `KL` | 0.001 | 初始 KL 系数 |
+| `KL_TARGET` | (off) | 设值后切到 AdaptiveKLController |
+| `LR` | 1e-6 | Actor 学习率 |
+| `PROMPT_MAX_LEN` | 1024 | |
+| `GENERATE_MAX_LEN` | 3072 | |
+| `MAX_SAMPLES` | 15360 | 训练子集上限(论文 proxy) |
+| `EVAL_HOLDOUT_SIZE` | 500 | 从 `prompt_data` 中保留的确定性 held-out 子集 |
+
+观察到 KL 飘移时建议开 adaptive KL 控制器:`KL_TARGET=0.5`。
+
+---
+
+## 5. 日志字段
+
+W&B 面板按三个 namespace 划分:
+
+- `rollout/*` — 每步 rollout 统计:`reward`、`outcome_correct`、`model_reward`、`has_drop_moment`、`response_length`、`step_score_min/mean/last`、`step_count`、`final_reward`、`max_relative_drop`、`answer_tag_present`、`answer_extraction_failed`、`used_answer_fallback`、`used_mathruler`、`reference_supported`,以及 variant-2 诊断字段 `alignment_failed` / `n_aligned_steps`。
+- `train/*` — 每步训练统计:`policy_loss`、`kl`、`actor_lr`、`advantages`、`return`,以及 variant-2 诊断字段 `ursa_v2_adv_pos_frac` / `_neg_frac` / `_zero_frac` / `_abs_mean` / `_oc_normed_std` / `_msp_normed_std` / `_traj_step_count_mean`。
+- `eval/*` — held-out 评测:`reward`、`outcome_correct`、`response_length`、`answer_extraction_failed`、`has_drop_moment`、`model_reward`、`step_score_min/mean/last`、`step_count`、`final_reward`、`max_relative_drop`、`answer_tag_present`、`used_answer_fallback`、`used_mathruler`、`reference_supported`。
+
+`MathPRMReward` 输出的全套 per-sample 奖励 metric 文档见 [reward_models.py](reward_models.py) 中 `forward()` 顶部注释。
+
+---
+
+## 6. Paper Eq.9 严格对齐 — variant 2 路径
+
+`run_grpo_math_prm_ursa_8b_variant2.sh` 是 URSA 论文附录 B.1 Eq.9 "variant 2" 严格实现,与 PS-GRPO 并行存在,用于 ablation 对比。实现在 [`ursa_variant2.py`](ursa_variant2.py),通过幂等 monkey-patch 注册一个新 `--advantage_estimator ursa_variant2`(不修改 `lightrft/`)。
+
+### 公式
+
+```text
+A_t^i = r_{s,t}^i · GroupNorm_G(r̄_s^i) ← process-reward 项
+ + GroupNorm_G(r_o^i) ← outcome-reward 项
+```
+
+其中 `t` 是 **step 索引**(不是 token),`r_{s,t}^i` 为 trajectory `i` 第 `t` 个 step 的 sigmoid PRM 分,`r̄_s^i = mean_t r_{s,t}^i`,`r_o^i ∈ {0,1}` 是 outcome reward,`G` 是 GRPO group size(`n_samples_per_prompt`)。逐 step `A_t^i` 广播到该 step 覆盖的所有 token。**无 cumulative return**,outcome 项保留(不像 Math-Shepherd 风格的 Mode B 那样被丢弃)。
+
+### 数据集 / 启动流程
+
+variant 2 路径要求 manifest 行 label 是 `math_per_step_prm` 而不是 `math_psgrpo`。最简单方法是 sed-relabel PS-GRPO manifest:
+
+```bash
+sed 's/"label":[ ]*"math_psgrpo"/"label": "math_per_step_prm"/g' \
+ /path/to/math_psgrpo.jsonl \
+ > /path/to/math_per_step_prm.jsonl
+```
+
+variant 2 启动脚本会自动检测 `PATH_TO_YOUR_MATH_DATASET` 是否指向 psgrpo 路径,若是则自动 swap 到 `*per_step_prm*.jsonl` 兄弟文件,并在训练前 assert 首行 label 是 `math_per_step_prm`。如需自定义路径,设 `PATH_TO_YOUR_MATH_DATASET_VARIANT2`。
+
+```bash
+PATH_TO_YOUR_MATH_DATASET_VARIANT2=/path/to/math_per_step_prm.jsonl \
+bash examples/math_prm/run_grpo_math_prm_ursa_8b_variant2.sh
+```
+
+`--per_step_reward_mode`(`raw` / `group_norm`)只影响**遗留 Math-Shepherd 风格逐 token reward 路径**(`_apply_step_reward_group_norm` 不同聚合方式);`--advantage_estimator ursa_variant2` 自带 group normalization,不受该 flag 影响。
+
+### 单元测试
+
+`test_ursa_variant2.py` 包含 9 个 AC(acceptance criterion)级单测:与手算 Eq.9 数值等价、K=2/K=4 group 的 GroupNorm 正确性、span 广播、outcome 项非旁路。运行:
+
+```bash
+python3 -m unittest examples.math_prm.test_ursa_variant2 -v
+```
+
+---
+
+## 7. 实验结果 — 9 天生产 run(PS-GRPO)
+
+8× H100 上 PS-GRPO 配方跑满 9 天的关键指标如下。variant 2 路径并行跑了同样 9 天作对照(W&B `kdwjt4eo`),完整对比见 [PR #53 最终报告 comment](https://github.com/opendilab/LightRFT/pull/53#issuecomment-4608400929)。
+
+| 指标 | baseline (Step 20) | peak | final | Δ vs baseline |
+|---|---|---|---|---|
+| `eval/outcome_correct` | 0.5952 | **0.6508** (Step 231) | 0.6290 (Step 1008) | **+3.4 pp** |
+| `eval/answer_extraction_failed` | 0.028 | 0.018 (~Step 160) | 0.034 | -0.6 pp ↓ |
+| `eval/has_drop_moment` | 0.0 | — | 0.0 | (PRM 全程未触发) |
+| `eval/response_length` | 400 | 337 (~Step 240) | 377 | -23 ↓ |
+| `rollout/alignment_failed` | 0 | — | 0 | 100% step 边界对齐 |
+| W&B run | [`kdwjt4eo`](https://wandb.ai/hansbug/LightRFT-URSA8B-Stage3/runs/kdwjt4eo) |
+
+#### eval 轨迹
+
+`eval/outcome_correct` 在 Step 231 见峰 +5.6pp,约 Step 300 出现一次 dip(reward hacking signature)但**自愈**,剩 7 天稳定在 0.60–0.65 区间:
+
+
+
+#### KL + rollout 全局视角
+
+`train/kl` 走出 warmup 后(1e-4 → 1.0 by Step ~200),在 1–100 区间震荡,偶发单 batch >100 spike 总能自愈。`rollout/outcome_correct` 与 `rollout/model_reward` 长期同向变化(无 reward hacking 解耦):
+
+
+
+#### eval 生成质量
+
+`eval/answer_extraction_failed` 在 Step 300 dip 期间短暂飙到 18%(URSA 论文警告的 `†Answer:` 格式漂移信号),之后回稳到 2–5%。`eval/response_length` 与 `eval/step_count` 稳定 — 无 length collapse:
+
+
+
+#### variant 2 路径健康度(W&B run `kdwjt4eo`)
+
+`ursa_v2_adv_pos_frac` 与 `_neg_frac` 长期保持平衡(30–40% / 25–35%)—— GroupNorm 持续产出 signed advantages。`_msp_normed_std` 贴近 1.0。`rollout/alignment_failed` 全程 0:
+
+
+
+---
+
+## 8. 目录文件清单
+
+```text
+examples/math_prm/
+├── README.md - 本文档(英文)
+├── README_zh.md - 本文档(中文)
+├── train_colocate.py - 训练主入口(由 torchrun 调用)
+├── run_grpo_math_prm_ursa_8b.sh - PS-GRPO 启动脚本(推荐)
+├── run_grpo_math_prm_ursa_8b_variant2.sh - paper Eq.9 严格启动脚本(ablation)
+├── reward_models.py - MathPRMReward 实现(PS-GRPO)
+├── reward_models_utils.py - 按 label 选 reward 配方的逻辑
+├── ursa_actor.py - URSA actor wrapper
+├── ursa_variant2.py - UrsaVariant2Calculator(paper Eq.9,纯 examples/)
+├── math_prm_trainer.py - MathPRMSPMDPPOTrainerVL(wandb metric 映射)
+├── math_prm_output.py - "†Answer:" marker / structured-stop helpers
+├── rollout_eos_patch.py - FSDP 下可靠 EOS 的 StoppingCriteria 注入
+├── test_ursa_variant2.py - variant 2 的 9 个 AC 级单测
+├── ursa_model/ - 内置 URSA 模型代码(config / processor / model)
+├── tools/
+│ ├── prepare_ursa_stage3_manifest.py - 数据集转换工具
+│ └── prepare_ursa_engine_checkpoint.py - Engine-mode checkpoint wrapper
+└── assets/
+ └── exp_20260603/ - 9 天生产 run 的 W&B 截图
+```
+
+---
+
+## 9. 引用
+
+使用本 example 请引用 URSA 论文:
+
+```bibtex
+@article{luo2025ursa,
+ title={URSA: Understanding and Verifying Chain-of-Thought Reasoning in Multimodal Mathematics},
+ author={Luo, Ruilin and Zheng, Zhuofan and Wang, Yifan and Yu, Yiyao and Ni, Xinzhe and Lin, Zicheng and Zeng, Jin and Yang, Yujiu},
+ journal={NeurIPS},
+ year={2025}
+}
+```
+
+---
+
+## License
+
+本 example 与上层 LightRFT 项目使用同一 License(见仓库根目录 `LICENSE`)。
diff --git a/examples/math_prm/assets/exp_20260603/eval_outcome.png b/examples/math_prm/assets/exp_20260603/eval_outcome.png
new file mode 100644
index 00000000..18d056c0
Binary files /dev/null and b/examples/math_prm/assets/exp_20260603/eval_outcome.png differ
diff --git a/examples/math_prm/assets/exp_20260603/eval_quality.png b/examples/math_prm/assets/exp_20260603/eval_quality.png
new file mode 100644
index 00000000..40d4bd64
Binary files /dev/null and b/examples/math_prm/assets/exp_20260603/eval_quality.png differ
diff --git a/examples/math_prm/assets/exp_20260603/kl_and_rollout.png b/examples/math_prm/assets/exp_20260603/kl_and_rollout.png
new file mode 100644
index 00000000..65a150dd
Binary files /dev/null and b/examples/math_prm/assets/exp_20260603/kl_and_rollout.png differ
diff --git a/examples/math_prm/assets/exp_20260603/variant2_health.png b/examples/math_prm/assets/exp_20260603/variant2_health.png
new file mode 100644
index 00000000..d86be4f9
Binary files /dev/null and b/examples/math_prm/assets/exp_20260603/variant2_health.png differ
diff --git a/examples/math_prm/math_prm_output.py b/examples/math_prm/math_prm_output.py
new file mode 100644
index 00000000..ac86c892
--- /dev/null
+++ b/examples/math_prm/math_prm_output.py
@@ -0,0 +1,109 @@
+"""
+Helpers for URSA Math PRM structured outputs.
+
+These helpers centralize the heuristics used to keep Phase 3 generation inside the
+expected `Step N:` / `†Answer:` format without introducing Phase 4 reward logic.
+"""
+
+import re
+from typing import Optional
+
+
+MATH_PRM_STRUCTURED_LABELS = frozenset({"math_prm", "math_prm_combined", "math_psgrpo"})
+MATH_PRM_ANSWER_MARKER = "†Answer:"
+_MAX_MATH_PRM_ANSWER_WORDS = 24
+_MAX_MATH_PRM_ANSWER_CHARS = 160
+_EARLY_STOP_ANSWER_WORDS = 12
+_EARLY_STOP_ANSWER_CHARS = 80
+_BOOLEAN_ANSWERS = {"yes", "no", "true", "false"}
+_ALGEBRAIC_ANSWER_PATTERN = re.compile(
+ r"[A-Za-z][A-Za-z0-9_]*(?:\s*[,;]\s*[A-Za-z][A-Za-z0-9_]*)*\s*=\s*[-+A-Za-z0-9$\\][A-Za-z0-9\s,./%()=\-+*$\\^{}]*"
+)
+
+
+def is_math_prm_structured_label(label: Optional[str]) -> bool:
+ return isinstance(label, str) and label.lower() in MATH_PRM_STRUCTURED_LABELS
+
+
+def find_math_prm_tail_cutoff(text: str) -> Optional[int]:
+ cut_positions = []
+ for pattern in (
+ r"(? str:
+ if not response_text:
+ return response_text
+ return re.sub(r"(?m)^StepStep\s+(\d+:)", r"Step \1", response_text)
+
+
+def _extract_answer_line(response_text: str) -> tuple[str, str, bool]:
+ normalized_text = _normalize_math_prm_response(response_text)
+ marker_index = normalized_text.find(MATH_PRM_ANSWER_MARKER)
+ if marker_index < 0:
+ return normalized_text, "", False
+
+ answer_tail = normalized_text[marker_index + len(MATH_PRM_ANSWER_MARKER):].lstrip()
+ answer_lines = answer_tail.splitlines()
+ answer_line = " ".join(answer_lines[0].split()) if answer_lines else ""
+ has_more_lines = len(answer_lines) > 1
+ return normalized_text, answer_line, has_more_lines
+
+
+def should_stop_math_prm_response_text(response_text: str) -> bool:
+ normalized_text, answer_line, has_more_lines = _extract_answer_line(response_text)
+ if normalized_text.find(MATH_PRM_ANSWER_MARKER) < 0 or not answer_line:
+ return False
+ if has_more_lines:
+ return True
+ if find_math_prm_tail_cutoff(answer_line) is not None:
+ return True
+
+ lower_answer = answer_line.lower()
+ if lower_answer in _BOOLEAN_ANSWERS:
+ return True
+ if re.fullmatch(r"[-+]?[$]?\d[\d\s,./%()=-]*", answer_line):
+ return True
+ if _ALGEBRAIC_ANSWER_PATTERN.fullmatch(answer_line):
+ return True
+ if re.fullmatch(r"[A-E]", answer_line):
+ return True
+ if answer_line.endswith((".", "!", "?", "%", ")", "]")):
+ return True
+ if len(answer_line.split()) >= _EARLY_STOP_ANSWER_WORDS:
+ return True
+ if len(answer_line) >= _EARLY_STOP_ANSWER_CHARS:
+ return True
+ return False
+
+
+def sanitize_math_prm_response_text(response_text: str) -> str:
+ normalized_text, answer_line, _ = _extract_answer_line(response_text)
+ marker_index = normalized_text.find(MATH_PRM_ANSWER_MARKER)
+ if marker_index < 0:
+ return normalized_text
+
+ prefix = normalized_text[: marker_index + len(MATH_PRM_ANSWER_MARKER)]
+
+ cutoff = find_math_prm_tail_cutoff(answer_line)
+ if cutoff is not None:
+ answer_line = answer_line[:cutoff]
+
+ answer_words = answer_line.split()
+ if len(answer_words) > _MAX_MATH_PRM_ANSWER_WORDS:
+ answer_line = " ".join(answer_words[:_MAX_MATH_PRM_ANSWER_WORDS])
+ if len(answer_line) > _MAX_MATH_PRM_ANSWER_CHARS:
+ truncated = answer_line[:_MAX_MATH_PRM_ANSWER_CHARS]
+ answer_line = truncated.rsplit(" ", 1)[0] or truncated
+
+ answer_line = answer_line.rstrip(" ,;:")
+ return prefix.rstrip() if not answer_line else f"{prefix} {answer_line}".rstrip()
diff --git a/examples/math_prm/math_prm_trainer.py b/examples/math_prm/math_prm_trainer.py
new file mode 100644
index 00000000..d6a02d09
--- /dev/null
+++ b/examples/math_prm/math_prm_trainer.py
@@ -0,0 +1,455 @@
+from contextlib import contextmanager
+from typing import Dict, Optional
+
+import torch
+
+from lightrft.trainer.spmd_ppo_trainer import SPMDPPOTrainerVL
+
+# Explicitly register the ursa_variant2 monkey-patches so
+# ``--advantage_estimator ursa_variant2`` resolves to the paper Eq.9
+# strict-alignment calculator. train_colocate.py runs with
+# cwd=examples/math_prm, so a top-level (non-package) import is used here.
+# (register_ursa_variant2() is idempotent.)
+from ursa_variant2 import register_ursa_variant2
+
+register_ursa_variant2()
+
+
+def _detach_rollout_eos_patch(rollout_actor):
+ """Detach rollout_eos_patch.StructuredAnswerStoppingCriteria wrap from a rollout actor.
+
+ Returns the unwrapped (original) generate function so the caller can restore
+ the patch later. Returns None if no patch is installed.
+
+ The patch wraps ``model.generate`` with ``functools.wraps``, so the original
+ function is reachable via ``__wrapped__``. We rely on the patch's idempotency
+ flag ``_math_prm_rollout_eos_patch_installed`` to detect installation.
+ """
+ if rollout_actor is None:
+ return None
+ model = getattr(rollout_actor, "model", None)
+ if model is None:
+ return None
+ if not getattr(model, "_math_prm_rollout_eos_patch_installed", False):
+ return None
+ patched = model.generate
+ original = getattr(patched, "__wrapped__", None)
+ if original is None:
+ return None
+ model.generate = original
+ model._math_prm_rollout_eos_patch_installed = False
+ return patched
+
+
+def _reattach_rollout_eos_patch(rollout_actor, patched_generate):
+ """Reinstall a previously detached patched generate function."""
+ if rollout_actor is None or patched_generate is None:
+ return
+ model = getattr(rollout_actor, "model", None)
+ if model is None:
+ return
+ model.generate = patched_generate
+ model._math_prm_rollout_eos_patch_installed = True
+
+
+class MathPRMSPMDPPOTrainerVL(SPMDPPOTrainerVL):
+ """SPMD PPO trainer specialized for URSA-MATH process-reward training.
+
+ Differs from the base ``SPMDPPOTrainerVL`` in two ways:
+
+ 1. **W&B namespace mapping** — rollout/train/eval metric streams are
+ projected into the ``rollout/``, ``train/``, ``eval/`` namespaces via
+ ``_ROLLOUT_KEY_SOURCES`` / ``_TRAIN_KEY_SOURCES`` / ``_EVAL_KEY_SOURCES``
+ so the dashboards stay aligned with the URSA paper's reporting
+ conventions even as upstream metric names drift.
+ 2. **URSA-specific eval** — :meth:`evaluate` runs under a runtime context
+ that prevents the actor from generating ``<|image|>`` sentinel tokens
+ and aggregates per-dataset metrics into a single weighted-average
+ view (see :meth:`_aggregate_eval_metrics`).
+
+ All checkpoint, logging, profile, and trajectory-saving wiring is unchanged
+ from the base class apart from the namespace mapping.
+ """
+
+ _ROLLOUT_KEY_SOURCES = {
+ "reward": ("rollout_reward", "step_reward_mean", "reward"),
+ "reward_std": ("rollout_reward_std", "step_reward_std"),
+ "outcome_correct": ("rollout_outcome_correct", "outcome_correct_mean", "reward_metrics/outcome_correct"),
+ "has_drop_moment": ("rollout_has_drop_moment", "has_drop_moment_mean", "reward_metrics/has_drop_moment"),
+ "model_reward": ("rollout_model_reward", "model_reward_mean", "reward_metrics/model_reward"),
+ "response_length": ("rollout_response_length", "response_length_mean", "response_length"),
+ # PRM step-score distribution (computed by MathPRMReward.forward but
+ # previously not surfaced to wandb; useful for monitoring PRM behaviour
+ # at scale, not just min-aggregation).
+ "step_score_min": ("rollout_step_score_min", "step_score_min_mean", "reward_metrics/step_score_min"),
+ "step_score_mean": ("rollout_step_score_mean", "step_score_mean_mean", "reward_metrics/step_score_mean"),
+ "step_score_last": ("rollout_step_score_last", "step_score_last_mean", "reward_metrics/step_score_last"),
+ "step_count": ("rollout_step_count", "step_count_mean", "reward_metrics/step_count"),
+ # PS-GRPO + answer-extraction diagnostics (already computed, were missing
+ # from wandb mapping — required for reward-hacking forensics).
+ "final_reward": ("rollout_final_reward", "final_reward_mean", "reward_metrics/final_reward"),
+ "max_relative_drop": (
+ "rollout_max_relative_drop", "max_relative_drop_mean", "reward_metrics/max_relative_drop"
+ ),
+ "answer_tag_present": (
+ "rollout_answer_tag_present", "answer_tag_present_mean", "reward_metrics/answer_tag_present"
+ ),
+ "answer_extraction_failed": (
+ "rollout_answer_extraction_failed", "answer_extraction_failed_mean",
+ "reward_metrics/answer_extraction_failed"
+ ),
+ "used_answer_fallback": (
+ "rollout_used_answer_fallback", "used_answer_fallback_mean", "reward_metrics/used_answer_fallback"
+ ),
+ "used_mathruler": ("rollout_used_mathruler", "used_mathruler_mean", "reward_metrics/used_mathruler"),
+ "reference_supported": (
+ "rollout_reference_supported", "reference_supported_mean", "reward_metrics/reference_supported"
+ ),
+ # Variant 2 (per-step PRM) diagnostics — populated only when
+ # the dataset row label is "math_per_step_prm". For "math_psgrpo"
+ # rows these stay 0 (no alignment was attempted).
+ "alignment_failed": ("rollout_alignment_failed", "alignment_failed_mean", "reward_metrics/alignment_failed"),
+ "n_aligned_steps": ("rollout_n_aligned_steps", "n_aligned_steps_mean", "reward_metrics/n_aligned_steps"),
+ }
+ _TRAIN_KEY_SOURCES = {
+ "policy_loss": ("policy_loss", ),
+ "kl": ("kl", ),
+ "actor_lr": ("actor_lr", ),
+ "critic_loss": ("critic_loss", ),
+ "critic_lr": ("critic_lr", ),
+ "values": ("values", ),
+ "values_std": ("values_std", ),
+ "reward": ("reward", ),
+ "reward_std": ("step_reward_std", ),
+ "return": ("return", ),
+ "return_std": ("returns_std", ),
+ "response_length": ("response_length", ),
+ "total_length": ("total_length", ),
+ "num_actions": ("num_actions", ),
+ "approx_kl": ("approx_kl", ),
+ "clipfrac": ("clipfrac", ),
+ "ratio_mean": ("ratio_mean", ),
+ "ratio_max": ("ratio_max", ),
+ "advantages": ("advantages_mean", ),
+ "advantages_std": ("advantages_std", ),
+ "ptx_loss": ("ptx_loss", ),
+ # URSA paper Eq.9 variant 2 advantage-calculator diagnostics. Populated
+ # only when --advantage_estimator ursa_variant2 is active; otherwise
+ # absent from experience.info and silently skipped by _build_train_metrics.
+ "ursa_v2_adv_pos_frac": ("ursa_v2_adv_pos_frac", ),
+ "ursa_v2_adv_neg_frac": ("ursa_v2_adv_neg_frac", ),
+ "ursa_v2_adv_zero_frac": ("ursa_v2_adv_zero_frac", ),
+ "ursa_v2_adv_abs_mean": ("ursa_v2_adv_abs_mean", ),
+ "ursa_v2_oc_normed_std": ("ursa_v2_oc_normed_std", ),
+ "ursa_v2_msp_normed_std": ("ursa_v2_msp_normed_std", ),
+ "ursa_v2_traj_step_count_mean": ("ursa_v2_traj_step_count_mean", ),
+ }
+ _EVAL_KEY_SOURCES = {
+ "reward": ("reward", "reward_mean"),
+ "outcome_correct": ("outcome_correct", "outcome_correct_mean"),
+ "has_drop_moment": ("has_drop_moment", "has_drop_moment_mean"),
+ "model_reward": ("model_reward", "model_reward_mean"),
+ "response_length": ("response_length", "response_length_mean"),
+ "answer_extraction_failed": ("answer_extraction_failed", "answer_extraction_failed_mean"),
+ # PRM step-score distribution + PS-GRPO diagnostics in eval (parity with
+ # the expanded rollout-side mapping above).
+ "step_score_min": ("step_score_min", "step_score_min_mean"),
+ "step_score_mean": ("step_score_mean", "step_score_mean_mean"),
+ "step_score_last": ("step_score_last", "step_score_last_mean"),
+ "step_count": ("step_count", "step_count_mean"),
+ "final_reward": ("final_reward", "final_reward_mean"),
+ "max_relative_drop": ("max_relative_drop", "max_relative_drop_mean"),
+ "answer_tag_present": ("answer_tag_present", "answer_tag_present_mean"),
+ "used_answer_fallback": ("used_answer_fallback", "used_answer_fallback_mean"),
+ "used_mathruler": ("used_mathruler", "used_mathruler_mean"),
+ "reference_supported": ("reference_supported", "reference_supported_mean"),
+ # Variant 2 diagnostics in eval (eval also runs the PRM forward
+ # if the dataset label is "math_per_step_prm")
+ "alignment_failed": ("alignment_failed", "alignment_failed_mean"),
+ "n_aligned_steps": ("n_aligned_steps", "n_aligned_steps_mean"),
+ }
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._train_generate_kwargs = dict(self.generate_kwargs)
+ self._eval_generate_kwargs = self._build_eval_generate_kwargs()
+ if self._wandb is not None and self.strategy.is_rank_0():
+ self._wandb.define_metric("rollout/*", step_metric=None, step_sync=False, overwrite=True)
+ self._wandb.define_metric("train/*", step_metric=None, step_sync=False, overwrite=True)
+ self._wandb.define_metric("eval/train_step")
+ self._wandb.define_metric("eval/*", step_metric="eval/train_step", step_sync=True, overwrite=True)
+ self._wandb.define_metric("profile/train_step")
+ self._wandb.define_metric("profile/*", step_metric="profile/train_step", step_sync=True, overwrite=True)
+
+ def _build_eval_generate_kwargs(self) -> Dict:
+ eval_generate_kwargs = dict(self._train_generate_kwargs)
+ eval_generate_kwargs["do_sample"] = bool(getattr(self.strategy.args, "eval_do_sample", False))
+ eval_generate_kwargs["max_new_tokens"] = (
+ getattr(self.strategy.args, "eval_generate_max_len", None)
+ or self._train_generate_kwargs.get("max_new_tokens")
+ )
+ eval_generate_kwargs["temperature"] = getattr(self.strategy.args, "eval_temperature", 0.0)
+ eval_generate_kwargs["top_p"] = getattr(self.strategy.args, "eval_top_p", 1.0)
+ eval_generate_kwargs["top_k"] = getattr(self.strategy.args, "eval_top_k", -1)
+ eval_generate_kwargs["repetition_penalty"] = getattr(self.strategy.args, "eval_repetition_penalty", 1.0)
+ eval_generate_kwargs["no_repeat_ngram_size"] = getattr(
+ self.strategy.args,
+ "eval_no_repeat_ngram_size",
+ 0,
+ )
+ return eval_generate_kwargs
+
+ @contextmanager
+ def _runtime_eval_context(self):
+ original_generate_kwargs = self.generate_kwargs
+ original_n_samples = self.strategy.args.n_samples_per_prompt
+ original_advantage_estimator = self.strategy.args.advantage_estimator
+ original_config_n_samples = getattr(self.strategy.config, "n_samples_per_prompt", None)
+ original_config_advantage_estimator = getattr(self.strategy.config, "advantage_estimator", None)
+
+ self.generate_kwargs = dict(self._eval_generate_kwargs)
+ self.strategy.args.n_samples_per_prompt = max(
+ 1, int(getattr(self.strategy.args, "eval_n_samples_per_prompt", 1))
+ )
+ self.strategy.args.advantage_estimator = "reinforce"
+ if original_config_n_samples is not None:
+ self.strategy.config.n_samples_per_prompt = self.strategy.args.n_samples_per_prompt
+ if original_config_advantage_estimator is not None:
+ self.strategy.config.advantage_estimator = "reinforce"
+
+ # Detach rollout_eos_patch on the inference engine for the duration of eval.
+ # The patch is meant to save GPU during training rollouts (early-stops at
+ # the first ``†Answer:`` line) but truncates response tokens that the
+ # reward extractor needs in eval; ablation showed it lowers eval
+ # outcome_correct by ~8pp at bs=4 and is catastrophic at bs=1
+ # (extraction-failure 44%). See PR #53 issuecomment-4394071500.
+ rollout_actor = getattr(self.strategy, "inference_engine", None)
+ detached_patch = _detach_rollout_eos_patch(rollout_actor)
+ if detached_patch is not None and self.strategy.is_rank_0():
+ self.strategy.print("[eval] rollout_eos_patch detached for the eval pass")
+
+ try:
+ yield
+ finally:
+ self.generate_kwargs = original_generate_kwargs
+ self.strategy.args.n_samples_per_prompt = original_n_samples
+ self.strategy.args.advantage_estimator = original_advantage_estimator
+ if original_config_n_samples is not None:
+ self.strategy.config.n_samples_per_prompt = original_config_n_samples
+ if original_config_advantage_estimator is not None:
+ self.strategy.config.advantage_estimator = original_config_advantage_estimator
+ if detached_patch is not None:
+ _reattach_rollout_eos_patch(rollout_actor, detached_patch)
+ if self.strategy.is_rank_0():
+ self.strategy.print("[eval] rollout_eos_patch reattached after eval")
+
+ def _build_rollout_metrics(self, logs_dict: Dict[str, float]) -> Dict[str, float]:
+ rollout_metrics = {}
+ for target_key, source_keys in self._ROLLOUT_KEY_SOURCES.items():
+ for source_key in source_keys:
+ if source_key in logs_dict:
+ rollout_metrics[target_key] = logs_dict[source_key]
+ break
+ return rollout_metrics
+
+ def _build_train_metrics(self, logs_dict: Dict[str, float]) -> Dict[str, float]:
+ train_metrics = {}
+ for target_key, source_keys in self._TRAIN_KEY_SOURCES.items():
+ for source_key in source_keys:
+ if source_key in logs_dict:
+ train_metrics[target_key] = logs_dict[source_key]
+ break
+ return train_metrics
+
+ def _build_eval_metrics(self, raw_eval_metrics: Dict[str, float]) -> Dict[str, float]:
+ eval_metrics = {}
+ for target_key, source_keys in self._EVAL_KEY_SOURCES.items():
+ for source_key in source_keys:
+ if source_key in raw_eval_metrics:
+ eval_metrics[target_key] = raw_eval_metrics[source_key]
+ break
+ return eval_metrics
+
+ def _aggregate_eval_metrics(self, raw_eval_metrics: Dict[str, float]) -> Dict[str, float]:
+ if not torch.distributed.is_available() or not torch.distributed.is_initialized():
+ return raw_eval_metrics
+
+ gathered_metrics = [None] * torch.distributed.get_world_size()
+ torch.distributed.all_gather_object(gathered_metrics, raw_eval_metrics or {})
+
+ total_samples = sum(float(metrics.get("num_samples", 0.0)) for metrics in gathered_metrics if metrics)
+ if total_samples <= 0:
+ return {}
+
+ aggregated_metrics = {"num_samples": total_samples}
+ mean_keys = {key for metrics in gathered_metrics if metrics for key in metrics.keys() if key.endswith("_mean")}
+ for key in mean_keys:
+ weighted_sum = 0.0
+ for metrics in gathered_metrics:
+ if not metrics or key not in metrics:
+ continue
+ weighted_sum += float(metrics["num_samples"]) * float(metrics[key])
+ aggregated_metrics[key] = weighted_sum / total_samples
+ return aggregated_metrics
+
+ def evaluate(self, eval_dataloader, global_step):
+ """Run URSA-flavored evaluation and return aggregated metrics.
+
+ Wraps the base trainer's :meth:`evaluate` in :meth:`_runtime_eval_context`
+ so the actor cannot emit reserved sentinel tokens (e.g. ``<|image|>``)
+ during eval rollouts, then folds per-dataset metrics into a single
+ sample-weighted average via :meth:`_aggregate_eval_metrics`.
+
+ :param eval_dataloader: Iterable over eval batches.
+ :param global_step: Current training step, used by the base evaluator
+ and logged alongside aggregated metrics on rank 0.
+ :returns: Dict of metric_name -> float ready to be uploaded under the
+ ``eval/`` namespace. Empty dict if the base evaluator produced no
+ metrics this step.
+ """
+ with self._runtime_eval_context():
+ raw_eval_metrics = super().evaluate(eval_dataloader, global_step)
+ aggregated_eval_metrics = self._aggregate_eval_metrics(raw_eval_metrics)
+ eval_metrics = self._build_eval_metrics(aggregated_eval_metrics)
+ if self.strategy.is_rank_0() and eval_metrics:
+ self.strategy.print(f"Aggregated runtime eval metrics (Step {global_step}):")
+ for key, value in eval_metrics.items():
+ self.strategy.print(f" {key}: {value:.4f}")
+ return eval_metrics
+
+ def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, client_states={}, episode=0):
+ """Drive periodic W&B/TensorBoard logging, eval, and checkpoint saves.
+
+ Called once per training step. Three gating cadences from ``args``:
+ ``logging_steps`` (rollout + train metrics), ``eval_steps`` (runs
+ :meth:`evaluate` and uploads under ``eval/``), and ``save_steps``
+ (writes a ``global_step{N}`` checkpoint).
+
+ Differences from the base trainer's same-named method:
+ - Rollout / train metric streams are routed through
+ :meth:`_build_rollout_metrics` / :meth:`_build_train_metrics` so
+ they pick up URSA-specific keys (PRM diagnostics, ursa_v2_* fields).
+ - W&B logs use a monotonic ``wandb_log_counter`` instead of
+ ``global_step`` to keep the eval and train series on a single
+ increasing x-axis when eval and train ticks interleave.
+
+ :param args: argparse-parsed runtime args (uses ``logging_steps``,
+ ``eval_steps``, ``save_steps``).
+ :param global_step: Current training step.
+ :param step_bar: tqdm progress bar (kept for base-class signature
+ compatibility; not used directly here).
+ :param logs_dict: Per-step metric dict produced by the base trainer.
+ :param client_states: Extra state forwarded to the checkpoint saver.
+ :param episode: Current episode counter logged alongside other metrics.
+ """
+ if global_step % args.logging_steps == 0:
+ rollout_metrics = self._build_rollout_metrics(logs_dict)
+ train_metrics = self._build_train_metrics(logs_dict)
+
+ if self._wandb is not None and self.strategy.is_rank_0():
+ all_wandb_logs = {}
+
+ for key, value in rollout_metrics.items():
+ all_wandb_logs[f"rollout/{key}"] = value
+ all_wandb_logs["rollout/episode"] = episode
+
+ for key, value in train_metrics.items():
+ all_wandb_logs[f"train/{key}"] = value
+ all_wandb_logs["train/episode"] = episode
+
+ if all_wandb_logs:
+ self.wandb_log_counter += 1
+ self._wandb.log(all_wandb_logs, step=self.wandb_log_counter, commit=True)
+ self._update_wandb_summary(all_wandb_logs)
+
+ elif self._tensorboard is not None and self.strategy.is_rank_0():
+ for key, value in rollout_metrics.items():
+ self._tensorboard.add_scalar(f"rollout/{key}", value, global_step)
+ for key, value in train_metrics.items():
+ self._tensorboard.add_scalar(f"train/{key}", value, global_step)
+
+ if global_step % args.eval_steps == 0 and self.eval_dataloader is not None:
+ with self.profiler.phase("eval"):
+ with self.profiler.section("total"):
+ raw_eval_metrics = self.evaluate(self.eval_dataloader, global_step)
+
+ if raw_eval_metrics and self.strategy.is_rank_0():
+ self.eval_step_counter += 1
+
+ if self._wandb is not None:
+ eval_logs = {}
+ for key, value in raw_eval_metrics.items():
+ eval_logs[f"eval/{key}"] = value
+
+ eval_logs["eval/train_step"] = global_step
+ eval_logs["eval/episode"] = episode
+
+ self.wandb_log_counter += 1
+ self._wandb.log(eval_logs, step=self.wandb_log_counter, commit=True)
+ self._update_wandb_summary(eval_logs)
+
+ elif self._tensorboard is not None:
+ for key, value in raw_eval_metrics.items():
+ self._tensorboard.add_scalar(f"eval/{key}", value, global_step)
+
+ if global_step % args.save_steps == 0:
+ with self.profiler.phase("checkpoint"):
+ with self.profiler.section("total"):
+ tag = f"global_step{global_step}"
+ self._save_checkpoint(args, tag, client_states)
+
+ def log_profile_metrics(self, global_step: int, episode: int, profile_snapshot: Optional[Dict]) -> None:
+ """Forward step-profiler snapshots into W&B/TensorBoard on rank 0 only.
+
+ Profile snapshots come from :class:`StepProfileRecorder` and contain
+ a human-readable ``summary`` plus prebuilt ``wandb_logs`` dict. On
+ W&B we upload the prebuilt dict as-is under the ``profile/`` namespace;
+ on TensorBoard we fall back to writing ``sections_max_s`` and
+ ``sections_max_ratio`` scalars individually.
+
+ :param global_step: Current training step (used only by the TB path).
+ :param episode: Current episode counter, embedded into the W&B record.
+ :param profile_snapshot: Snapshot dict from the profiler, or None
+ if profiling was disabled or this step yielded no data. None is a
+ no-op.
+ """
+ if not profile_snapshot or not self.strategy.is_rank_0():
+ return
+
+ summary = profile_snapshot.get("summary")
+ if summary:
+ self.strategy.print(summary)
+
+ if self._wandb is not None:
+ wandb_logs = dict(profile_snapshot.get("wandb_logs", {}))
+ if wandb_logs:
+ wandb_logs["profile/episode"] = episode
+ self.wandb_log_counter += 1
+ self._wandb.log(wandb_logs, step=self.wandb_log_counter, commit=True)
+ self._update_wandb_summary(wandb_logs)
+
+ elif self._tensorboard is not None:
+ record = profile_snapshot.get("record", {})
+ for key, value in record.get("sections_max_s", {}).items():
+ self._tensorboard.add_scalar(f"profile/{key}_s", value, global_step)
+ for key, value in record.get("sections_max_ratio", {}).items():
+ self._tensorboard.add_scalar(f"profile/{key}_ratio", value, global_step)
+
+ def save_trajectories(self, global_step: int):
+ """Persist the current replay buffer contents to disk when configured.
+
+ No-op when either a :class:`TrajectorySaver` has not been wired up
+ or the replay buffer is empty. The saver itself handles rank gating
+ and on-disk layout.
+
+ :param global_step: Step value embedded in the saved trajectory's
+ file name / metadata so downstream analysis can correlate runs.
+ """
+ if self.trajectory_saver is not None and self.replay_buffer.items:
+ self.trajectory_saver.save_trajectories(
+ experiences=self.replay_buffer.items,
+ step=global_step,
+ num_samples=self.num_trajectories_to_save,
+ prefix="trajectories",
+ compute_stats=self.args.trajectory_analysis,
+ )
diff --git a/examples/math_prm/reward_models.py b/examples/math_prm/reward_models.py
new file mode 100644
index 00000000..b00c7ab8
--- /dev/null
+++ b/examples/math_prm/reward_models.py
@@ -0,0 +1,715 @@
+"""URSA-MATH Stage 3 reward model helpers."""
+
+from __future__ import annotations
+
+import re
+from itertools import zip_longest
+from typing import Any, Dict
+
+import torch
+import torch.nn as nn
+
+from lightrft.evaluation.math_eval_utils import (
+ compare_answers,
+ extract_answer,
+ extract_answer_from_tags,
+ extract_boxed_answer,
+ extract_multiple_choice_answer,
+ extract_numeric_answer,
+ normalize_answer,
+)
+
+try:
+ from mathruler.grader import grade_answer as mathruler_grade_answer
+except ImportError:
+ mathruler_grade_answer = None
+
+
+_VISION_PATTERNS = [
+ r"<\|vision_start\|>(<\|image_pad\|>)+<\|vision_end\|>",
+ r"
()+",
+ r"",
+]
+
+
+def _clean_vision_token(text: str) -> str:
+ """Remove vision placeholders from a user question before PRM scoring."""
+ for pattern in _VISION_PATTERNS:
+ text = re.sub(pattern, "", text)
+ return text
+
+
+# Pattern matching "Step N:" / "†Answer:" markers — kept only for diagnostic
+# `matched_patterns` output during alignment. The actual alignment uses the
+# URSA-native path (PRM's own ``replace_specific_plus_minus_with_ki`` + PRM's
+# tokenizer offset_mapping) instead of re-implementing char-level simulation.
+_STEP_OR_ANSWER_PATTERN = re.compile(r"(Step \d+\s*:|†Answer\s*:)")
+
+
+def find_step_boundaries_in_response_tokens(prm_module, response_text: str, question_text: str = ""):
+ """URSA-native step-boundary alignment.
+
+ Algorithm (no analytic re-implementation — every step uses PRM's own
+ code):
+ 1. Build prefix exactly like ``MathPRMReward._prepare_prm_input``:
+ prefix = _PRM_PROMPT + question + "\\n" (or _PRM_PROMPT + "\\n")
+ 2. Form the same string PRM scores on:
+ prm_input_str = prm_module.replace_specific_plus_minus_with_ki(
+ prefix + response_text)
+ 3. Tokenize with prm_module.tokenizer (the EXACT tokenizer PRM uses) and
+ locate every ` и` token (id == prm_module.tag_id).
+ 4. Each ` и` token's offset_mapping char_start lies inside prm_input_str.
+ Subtract len(prefix) and ``2 * k_tag`` (each prior ` и` adds 2 chars)
+ to recover the position in the ORIGINAL ``response_text`` where the
+ step-end occurs.
+ 5. Re-tokenize ``response_text`` (without ` и`) and find the response
+ token whose char_end <= that position. That is the per-step
+ boundary token index.
+
+ Returned indices are relative to response start (0 = first response
+ token). Caller (compute_reward via fast_exp_maker._compute_advantages)
+ scatters per-step rewards onto these indices.
+
+ Why native? It avoids any divergence between an analytic char-level
+ model of ` и` insertion and PRM's actual tokenizer behavior. If the
+ tokenizer ever merges ` и` with adjacent chars, this path stays correct
+ because we read offsets from the actual tokenization PRM uses.
+
+ Parameters
+ ----------
+ prm_module : MathPRMReward
+ Provides ``_PRM_PROMPT``, ``tokenizer``, ``tag_id``, and
+ ``replace_specific_plus_minus_with_ki``.
+ response_text : str
+ The actor-generated response (assistant content only, no chat tags).
+ question_text : str, optional
+ Prompt question — passed through ``_prepare_prm_input`` so the prefix
+ length matches the PRM-side string exactly.
+
+ Returns
+ -------
+ boundaries : list[int]
+ Per-step boundary token indices in the response token sequence.
+ ``len(boundaries) == number of step_scores PRM emits``.
+ matched_patterns : list[str]
+ ``Step N:`` / ``†Answer:`` patterns found in the response (debug aid).
+ """
+ matched_patterns = [m.group() for m in _STEP_OR_ANSWER_PATTERN.finditer(response_text)]
+
+ if question_text and not isinstance(question_text, float):
+ prefix_str = prm_module._PRM_PROMPT + question_text + "\n"
+ else:
+ prefix_str = prm_module._PRM_PROMPT + "\n"
+ prefix_len = len(prefix_str)
+ prm_input_str = prm_module.replace_specific_plus_minus_with_ki(prefix_str + response_text)
+
+ tok = prm_module.tokenizer
+ enc_prm = tok(prm_input_str, return_offsets_mapping=True, add_special_tokens=False)
+ prm_offsets = enc_prm["offset_mapping"]
+ prm_ids = enc_prm["input_ids"]
+ tag_id = prm_module.tag_id
+
+ char_in_response: list[int] = []
+ k_tag = 0
+ for tid, off in zip(prm_ids, prm_offsets):
+ if tid == tag_id:
+ char_in_response.append(off[0] - prefix_len - 2 * k_tag)
+ k_tag += 1
+
+ enc_resp = tok(response_text, return_offsets_mapping=True, add_special_tokens=False)
+ resp_offsets = enc_resp["offset_mapping"]
+
+ boundaries: list[int] = []
+ for cp in char_in_response:
+ last_idx = -1
+ for tok_idx, (_, ce) in enumerate(resp_offsets):
+ if 0 < ce <= cp:
+ last_idx = tok_idx
+ boundaries.append(last_idx)
+ return boundaries, matched_patterns
+
+
+class MathPRMReward(nn.Module):
+ """Wrap URSA-RM with the original URSA-MATH step-level scoring protocol."""
+
+ _SYSTEM_PROMPT = "You are a helpful assistant."
+ _PRM_PROMPT = (
+ "You are given a problem and a step-by-step solution. "
+ "You need to check the correctness of each step.\nQuestion:"
+ )
+ _IMAGE_PAD = 575
+ # PS-GRPO step-score drop hyperparameters (URSA-MATH paper):
+ # _DROP_THRESHOLD - relative drop fraction that counts as a "drop moment"
+ # _DROP_GAMMA - reward penalty when a drop moment is observed for a
+ # correct answer; final_reward = 1 - _DROP_GAMMA = 0.5
+ _DROP_THRESHOLD = 0.3
+ _DROP_GAMMA = 0.5
+
+ def __init__(self, base_model: nn.Module, processor, aggregation: str = "min") -> None:
+ super().__init__()
+ self.model = base_model
+ self.processor = processor
+ self.tokenizer = processor.tokenizer
+ self.aggregation = aggregation
+
+ tag_ids = self.tokenizer.encode(" и", add_special_tokens=False)
+ assert len(tag_ids) == 1, (
+ "The step tag ' и' must map to exactly one token. "
+ f"Got {tag_ids!r} instead."
+ )
+ self.tag_id = int(tag_ids[0])
+
+ @staticmethod
+ def replace_specific_plus_minus_with_ki(text: str) -> str:
+ """Insert the URSA step-boundary marker `` и`` before each next step."""
+ pattern = r"Step \d+"
+ matches = list(re.finditer(pattern, text))
+ positions = [(match.start(), match.end()) for match in matches]
+ if not positions:
+ return text + " и"
+
+ text_list = list(text)
+ insert_positions = []
+ try:
+ for i in range(1, len(positions)):
+ for j in range(positions[i][0] - 1, positions[i - 1][1], -1):
+ if text_list[j] not in {" ", "\n"}:
+ insert_positions.append(j + 1)
+ break
+
+ answer_start = text.find("†Answer:")
+ if answer_start != -1:
+ for j in range(answer_start - 1, positions[-1][1], -1):
+ if text_list[j] not in {" ", "\n"}:
+ insert_positions.append(j + 1)
+ break
+
+ for index in sorted(insert_positions, reverse=True):
+ text = text[:index] + " и" + text[index:]
+ return text
+ except Exception:
+ return text + " и"
+
+ def _prepare_prm_input(self, question: str, response: str) -> str:
+ if not question or isinstance(question, float):
+ instruction = self._PRM_PROMPT + "\n" + response
+ else:
+ instruction = self._PRM_PROMPT + question + "\n" + response
+ return self.replace_specific_plus_minus_with_ki(instruction)
+
+ def _split_conversation(self, prompt_and_output: str) -> tuple[str, str]:
+ question = ""
+ response = ""
+
+ for sep in ("<|im_start|>user\n", "User:", "USER:"):
+ if sep not in prompt_and_output:
+ continue
+ user_block = prompt_and_output.split(sep)[-1]
+ for end in ("<|im_end|>", "<|im_start|>"):
+ if end in user_block:
+ user_block = user_block.split(end)[0]
+ question = self._clean_question_text(user_block)
+ break
+
+ for sep in ("<|im_start|>assistant\n", "Assistant:", "ASSISTANT:"):
+ if sep not in prompt_and_output:
+ continue
+ response_block = prompt_and_output.split(sep)[-1]
+ for end in ("<|im_end|>", "<|endoftext|>"):
+ if end in response_block:
+ response_block = response_block.split(end)[0]
+ response = response_block.strip()
+ break
+
+ if not response:
+ response = prompt_and_output
+ return question, response
+
+ @staticmethod
+ def _clean_question_text(question: str) -> str:
+ question = _clean_vision_token(question)
+ question = question.replace("<|image|>", "").replace("", "")
+ return question.strip()
+
+ @staticmethod
+ def _select_prm_image(raw_image: Any) -> list[Any]:
+ if isinstance(raw_image, (list, tuple)):
+ for item in raw_image:
+ if item is not None:
+ return [item]
+ return [None]
+ return [raw_image] if raw_image is not None else [None]
+
+ @staticmethod
+ def _safe_text(value: Any) -> str:
+ if value is None:
+ return ""
+ return str(value).strip()
+
+ @staticmethod
+ def _is_multiple_choice_reference(reference: str) -> bool:
+ ref = normalize_answer(reference).strip().upper()
+ return len(ref) == 1 and ref in {"A", "B", "C", "D"}
+
+ @classmethod
+ def _infer_reference_type(cls, reference: Any) -> tuple[str, bool]:
+ reference_text = cls._safe_text(reference)
+ if not reference_text:
+ return "missing", False
+
+ reference_norm = normalize_answer(reference_text).strip()
+ if cls._is_multiple_choice_reference(reference_norm):
+ return "multiple_choice", True
+
+ if reference_norm.lower() in {"yes", "no", "true", "false"}:
+ return "text", True
+
+ numeric_candidate = reference_norm.replace(",", "")
+ if re.fullmatch(r"-?\d+(?:\.\d+)?", numeric_candidate):
+ return "numeric", True
+ if re.fullmatch(r"-?\d+/\d+", numeric_candidate):
+ return "numeric", True
+
+ if any(token in reference_norm for token in ("\\", "=", "^", "_", "{", "}", "sqrt", "frac")):
+ return "formula", True
+ if re.search(r"[a-zA-Z]", reference_norm) and re.search(r"[\d=+\-*/()]", reference_norm):
+ return "formula", True
+
+ return "text", True
+
+ @classmethod
+ def _extract_answer_from_candidate(cls, candidate: str, reference_type: str) -> str:
+ candidate = cls._safe_text(candidate)
+ if not candidate:
+ return ""
+
+ boxed = extract_boxed_answer(candidate)
+ if boxed:
+ return boxed
+
+ tagged = extract_answer_from_tags(candidate, "answer")
+ if tagged:
+ return tagged
+
+ candidate = re.sub(
+ r"^(?:†\s*)?(?:final answer|correct answer(?: is)?|the answer is|answer)\s*[::]?\s*",
+ "",
+ candidate,
+ flags=re.IGNORECASE,
+ ).strip()
+ candidate = candidate.rstrip(" .")
+ if not candidate:
+ return ""
+
+ if reference_type == "multiple_choice":
+ extracted = extract_multiple_choice_answer(candidate)
+ return extracted or normalize_answer(candidate).strip().upper()
+
+ if reference_type == "numeric":
+ if any(token in candidate for token in ("\\", "/", "=", "^", "{", "}", "sqrt", "frac")):
+ return normalize_answer(candidate)
+ extracted = extract_numeric_answer(candidate)
+ return extracted or normalize_answer(candidate)
+
+ if reference_type in {"formula", "text"}:
+ return normalize_answer(candidate)
+
+ return extract_answer(candidate)
+
+ @classmethod
+ def _extract_final_answer_details(cls, response: str, reference_type: str) -> Dict[str, Any]:
+ response = cls._safe_text(response)
+ details: Dict[str, Any] = {
+ "predicted_answer": "",
+ "answer_tag_present": False,
+ "answer_extraction_failed": True,
+ "used_answer_fallback": False,
+ "extraction_source": "missing",
+ }
+ if not response:
+ return details
+
+ if "†Answer:" in response:
+ details["answer_tag_present"] = True
+ answer_block = response.split("†Answer:", 1)[-1]
+ answer_block = re.split(r"\n\s*Step\s+\d+\s*:", answer_block, maxsplit=1)[0]
+ candidate_lines = [line.strip() for line in answer_block.splitlines() if line.strip()]
+ candidate = candidate_lines[0] if candidate_lines else answer_block.strip()
+ predicted_answer = cls._extract_answer_from_candidate(candidate, reference_type)
+ details["predicted_answer"] = predicted_answer
+ details["answer_extraction_failed"] = predicted_answer == ""
+ details["extraction_source"] = "dagger_answer"
+ return details
+
+ explicit_fallbacks = [
+ ("boxed", extract_boxed_answer(response)),
+ ("tagged_answer", extract_answer_from_tags(response, "answer")),
+ ]
+ for source, match in explicit_fallbacks:
+ if match:
+ details["predicted_answer"] = normalize_answer(match)
+ details["answer_extraction_failed"] = False
+ details["used_answer_fallback"] = True
+ details["extraction_source"] = source
+ return details
+
+ lines = [line.strip() for line in response.splitlines() if line.strip()]
+ if lines:
+ last_line = lines[-1]
+ explicit_line = re.match(
+ r"^(?:†\s*)?(?:final answer|correct answer(?: is)?|the answer is|answer)\b",
+ last_line,
+ flags=re.IGNORECASE,
+ )
+ if explicit_line:
+ predicted_answer = cls._extract_answer_from_candidate(last_line, reference_type)
+ details["predicted_answer"] = predicted_answer
+ details["answer_extraction_failed"] = predicted_answer == ""
+ details["used_answer_fallback"] = True
+ details["extraction_source"] = "explicit_last_line"
+ return details
+
+ return details
+
+ @classmethod
+ def _compare_final_answer(
+ cls,
+ predicted_answer: str,
+ reference: Any,
+ reference_type: str,
+ reference_supported: bool,
+ ) -> tuple[bool, str]:
+ reference_text = cls._safe_text(reference)
+ if not reference_supported:
+ return False, "unsupported_reference"
+ if not reference_text:
+ return False, "missing_reference"
+ if not predicted_answer:
+ return False, "missing_prediction"
+
+ if reference_type == "multiple_choice":
+ pred_norm = normalize_answer(predicted_answer).strip().upper()
+ ref_norm = normalize_answer(reference_text).strip().upper()
+ return pred_norm == ref_norm, "multiple_choice_exact"
+
+ if reference_type in {"numeric", "formula"}:
+ if mathruler_grade_answer is not None:
+ try:
+ if mathruler_grade_answer(predicted_answer, reference_text):
+ return True, "mathruler"
+ except Exception:
+ pass
+ return compare_answers(predicted_answer, reference_text, is_multiple_choice=False), "math_eval"
+
+ return compare_answers(predicted_answer, reference_text, is_multiple_choice=False), "text_compare"
+
+ @classmethod
+ def _evaluate_answer_alignment(cls, response: str, reference: Any) -> Dict[str, Any]:
+ reference_type, reference_supported = cls._infer_reference_type(reference)
+ extraction = cls._extract_final_answer_details(response, reference_type)
+ outcome_correct, comparison_method = cls._compare_final_answer(
+ extraction["predicted_answer"],
+ reference,
+ reference_type,
+ reference_supported,
+ )
+ return {
+ "reference_type": reference_type,
+ "reference_supported": reference_supported,
+ "comparison_method": comparison_method,
+ **extraction,
+ "outcome_correct": outcome_correct,
+ }
+
+ @classmethod
+ def _compute_relative_drop(cls, step_scores: torch.Tensor) -> tuple[float, bool]:
+ if step_scores.numel() < 2:
+ return 0.0, False
+
+ scores = step_scores.detach().float()
+ prev_scores = scores[:-1]
+ next_scores = scores[1:]
+ denom = torch.clamp(prev_scores, min=1e-6)
+ relative_drops = torch.clamp((prev_scores - next_scores) / denom, min=0.0)
+ max_relative_drop = float(relative_drops.max().item()) if relative_drops.numel() else 0.0
+ return max_relative_drop, max_relative_drop >= cls._DROP_THRESHOLD
+
+ @classmethod
+ def _compute_psgrpo_metrics(
+ cls,
+ response: str,
+ reference: Any,
+ step_scores: torch.Tensor,
+ ) -> Dict[str, float]:
+ answer_eval = cls._evaluate_answer_alignment(response, reference)
+ outcome_correct = float(answer_eval["outcome_correct"])
+ max_relative_drop, has_drop_moment = cls._compute_relative_drop(step_scores)
+
+ final_reward = 0.0
+ if outcome_correct > 0.0:
+ final_reward = 1.0 - cls._DROP_GAMMA if has_drop_moment else 1.0
+
+ return {
+ "outcome_correct": outcome_correct,
+ "max_relative_drop": max_relative_drop,
+ "has_drop_moment": float(has_drop_moment),
+ "final_reward": final_reward,
+ "answer_tag_present": float(answer_eval["answer_tag_present"]),
+ "answer_extraction_failed": float(answer_eval["answer_extraction_failed"]),
+ "used_answer_fallback": float(answer_eval["used_answer_fallback"]),
+ "reference_supported": float(answer_eval["reference_supported"]),
+ "used_mathruler": float(answer_eval["comparison_method"] == "mathruler"),
+ }
+
+ @torch.no_grad()
+ def forward(
+ self,
+ sequences,
+ attention_mask,
+ prompt_and_output=None,
+ raw_images=None,
+ references=None,
+ labels=None,
+ **kwargs,
+ ) -> torch.Tensor | Dict[str, torch.Tensor]:
+ device = next(self.model.parameters()).device
+
+ if prompt_and_output is None and sequences is not None:
+ prompt_and_output = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
+ elif prompt_and_output is None:
+ raise ValueError("Either sequences or prompt_and_output must be provided")
+
+ return_dict = bool(kwargs.get("return_dict", False))
+
+ batch_rewards = []
+ # Per-sample reward metrics emitted alongside the scalar reward.
+ # They are grouped into three buckets:
+ #
+ # 1. PRM step-score statistics (continuous, distribution shape):
+ # model_reward - aggregated step score (min/avg/last per agg setting)
+ # step_score_min - lowest step score in the response
+ # step_score_mean - mean step score
+ # step_score_last - score of the final step
+ # step_count - number of "Step N:" blocks scored
+ #
+ # 2. Outcome / correctness signals (mostly binary):
+ # outcome_correct - 1 if extracted answer matches ground truth, else 0
+ # has_drop_moment - 1 if any consecutive step pair dropped > _DROP_THRESHOLD
+ # max_relative_drop - magnitude of the largest relative drop
+ # final_reward - PS-GRPO reward {0, 1-_DROP_GAMMA, 1} fed into GRPO
+ #
+ # 3. Diagnostics on answer extraction / grading path (low-volume but useful
+ # when debugging dataset / format / mathruler issues):
+ # answer_tag_present - 1 if the "†Answer:" marker appeared
+ # answer_extraction_failed - 1 if no answer string could be extracted
+ # used_answer_fallback - 1 if the heuristic last-line fallback fired
+ # reference_supported - 1 if the ground-truth schema is recognized
+ # used_mathruler - 1 if mathruler grading was the deciding step
+ #
+ # NOTE: ``accuracy_reward`` used to live here, but for math_psgrpo it is
+ # exactly equal to ``outcome_correct`` (see _compute_psgrpo_metrics).
+ # It now lives only in reward_models_utils.mix_rewards where it is set
+ # by the rule branch for the math_rule / math_prm_combined recipes.
+ batch_metrics: Dict[str, list[float]] = {
+ "model_reward": [],
+ "step_score_min": [],
+ "step_score_mean": [],
+ "step_score_last": [],
+ "step_count": [],
+ "outcome_correct": [],
+ "max_relative_drop": [],
+ "has_drop_moment": [],
+ "final_reward": [],
+ "answer_tag_present": [],
+ "answer_extraction_failed": [],
+ "used_answer_fallback": [],
+ "reference_supported": [],
+ "used_mathruler": [],
+ # math_per_step_prm diagnostics
+ "alignment_failed": [],
+ "n_aligned_steps": [],
+ }
+ # Per-step PRM data (variant 2). Lists of per-trajectory tensors,
+ # only populated for label == "math_per_step_prm" trajectories.
+ # When all trajectories in a batch have label == "math_psgrpo" the
+ # collected lists stay empty and the dict keys are dropped at the end
+ # so legacy callers that don't know about per-step rewards see no
+ # change.
+ batch_step_rewards: list[torch.Tensor] = []
+ batch_step_token_indices: list[torch.Tensor] = []
+
+ image_inputs = raw_images or [None] * len(prompt_and_output)
+ ref_inputs = references or [None] * len(prompt_and_output)
+ label_inputs = labels or ["math_prm"] * len(prompt_and_output)
+
+ for text, sample_image, reference, label in zip_longest(
+ prompt_and_output, image_inputs, ref_inputs, label_inputs, fillvalue=None
+ ):
+ if text is None:
+ continue
+
+ question, response = self._split_conversation(text)
+ input_prompt = self._prepare_prm_input(question, response)
+ conversation = [
+ {"role": "system", "content": self._SYSTEM_PROMPT},
+ {"role": "user", "content": "<|image|>" + input_prompt},
+ ]
+ formatted_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
+ inputs = self.processor(
+ formatted_prompt,
+ self._select_prm_image(sample_image),
+ return_tensors="pt",
+ ).to(device, torch.bfloat16)
+
+ # Sanity check: response (RL-generated) is not vision-cleaned, so it can
+ # contain literal `<|image|>` / `` strings that the tokenizer maps
+ # to image_token_index. The PRM only ever receives one image, so any
+ # extras would crash _merge_input_ids_with_image_features. Keep the first
+ # image token (intended placeholder) and replace the rest with a benign
+ # text token so PRM scoring continues instead of aborting the rollout.
+ image_token_id = getattr(self.model.config, "image_token_index", None)
+ if image_token_id is not None:
+ input_ids_view = inputs["input_ids"]
+ image_mask_flat = (input_ids_view == image_token_id).view(-1)
+ extras = torch.nonzero(image_mask_flat, as_tuple=False).squeeze(-1)
+ if extras.numel() > 1:
+ replacement = self.tokenizer.pad_token_id
+ if replacement is None:
+ replacement = self.tokenizer.eos_token_id
+ flat = input_ids_view.view(-1)
+ flat[extras[1:]] = replacement
+ inputs["input_ids"] = flat.view(input_ids_view.shape)
+
+ reward = self.model(**inputs).logits
+ input_ids = inputs["input_ids"].view(-1)
+ padding = torch.full((self._IMAGE_PAD,), -1, device=device)
+ input_ids_aligned = torch.cat((input_ids[:1], padding, input_ids[1:]))
+
+ reward_flat = reward.view(-1)
+ step_logits = reward_flat[input_ids_aligned == self.tag_id]
+ step_scores = torch.sigmoid(step_logits).view(-1)
+ psgrpo_metrics = self._compute_psgrpo_metrics(response, reference, step_scores)
+
+ if step_scores.numel() == 0:
+ aggregated_score = 0.0
+ elif self.aggregation == "min":
+ aggregated_score = float(torch.min(step_scores).item())
+ elif self.aggregation in {"avg", "mean"}:
+ aggregated_score = float(torch.mean(step_scores).item())
+ elif self.aggregation == "last":
+ aggregated_score = float(step_scores[-1].item())
+ else:
+ raise ValueError(f"Unknown aggregation: {self.aggregation!r}")
+
+ # ---- variant 2 (per-step PRM reward) alignment -------------------
+ # For label == "math_per_step_prm" we additionally locate the
+ # boundary token of each "Step N:" inside the response so the
+ # step_scores tensor can be scattered to per-token positions
+ # downstream (instead of being collapsed to one scalar).
+ #
+ # Alignment is *self-contained*: we re-tokenize ``response`` with
+ # the actor's (== PRM's, in URSA family) tokenizer and use the
+ # offset mapping to reverse-find token indices for each "Step N:"
+ # / "†Answer:" pattern. Indices are relative to the response
+ # start so they line up with the action_mask axis of final_reward
+ # in compute_reward.
+ #
+ # If the alignment fails (n_steps_prm != n_boundaries) we
+ # *bypass* per-step mode for that trajectory: emit empty tensors
+ # so compute_reward falls back to the trajectory-scalar path,
+ # and bump the alignment_failed metric for monitoring.
+ if label == "math_per_step_prm" and step_scores.numel() > 0:
+ boundaries, matched_patterns = find_step_boundaries_in_response_tokens(
+ self, response, question_text=question
+ )
+ aligned = (len(boundaries) == int(step_scores.numel()))
+ if aligned:
+ traj_step_rewards = step_scores.detach().to(torch.float32).cpu()
+ traj_step_tokens = torch.tensor(boundaries, dtype=torch.long)
+ n_aligned = len(boundaries)
+ else:
+ # Alignment failed: emit empties; downstream falls back to
+ # trajectory-scalar mode for this row.
+ traj_step_rewards = torch.empty(0, dtype=torch.float32)
+ traj_step_tokens = torch.empty(0, dtype=torch.long)
+ n_aligned = 0
+ batch_step_rewards.append(traj_step_rewards)
+ batch_step_token_indices.append(traj_step_tokens)
+ batch_metrics["alignment_failed"].append(0.0 if aligned else 1.0)
+ batch_metrics["n_aligned_steps"].append(float(n_aligned))
+ else:
+ # No per-step request for this row; emit empty placeholders to
+ # keep the per-traj list aligned with batch_rewards.
+ batch_step_rewards.append(torch.empty(0, dtype=torch.float32))
+ batch_step_token_indices.append(torch.empty(0, dtype=torch.long))
+ batch_metrics["alignment_failed"].append(0.0)
+ batch_metrics["n_aligned_steps"].append(0.0)
+
+ # ---- trajectory-scalar reward (PSGRPO / aggregate path) ----------
+ if label == "math_psgrpo":
+ sequence_reward = psgrpo_metrics["final_reward"]
+ elif label == "math_per_step_prm":
+ # In per-step mode the trajectory-scalar field is still used
+ # by GroupNorm baseline — use outcome (clean signal) instead
+ # of aggregated_score (which would double-count step rewards).
+ sequence_reward = float(psgrpo_metrics["outcome_correct"])
+ else:
+ sequence_reward = aggregated_score
+ batch_rewards.append(sequence_reward)
+ batch_metrics["model_reward"].append(aggregated_score)
+ batch_metrics["step_score_min"].append(float(torch.min(step_scores).item()) if step_scores.numel() else 0.0)
+ batch_metrics["step_score_mean"].append(float(torch.mean(step_scores).item()) if step_scores.numel() else 0.0)
+ batch_metrics["step_score_last"].append(float(step_scores[-1].item()) if step_scores.numel() else 0.0)
+ batch_metrics["step_count"].append(float(step_scores.numel()))
+ # Diagnostics: outcome_correct and answer-extraction signals are
+ # always meaningful (they're computed by _evaluate_answer_alignment
+ # which is independent of PSGRPO drop-moment); only the
+ # drop-moment-specific fields (max_relative_drop, has_drop_moment,
+ # final_reward) zero out for non-PSGRPO labels.
+ _UNIVERSAL_METRICS = {
+ "outcome_correct",
+ "answer_tag_present",
+ "answer_extraction_failed",
+ "used_answer_fallback",
+ "reference_supported",
+ "used_mathruler",
+ }
+ for key in (
+ "outcome_correct",
+ "max_relative_drop",
+ "has_drop_moment",
+ "final_reward",
+ "answer_tag_present",
+ "answer_extraction_failed",
+ "used_answer_fallback",
+ "reference_supported",
+ "used_mathruler",
+ ):
+ if label == "math_psgrpo" or key in _UNIVERSAL_METRICS:
+ batch_metrics[key].append(psgrpo_metrics[key])
+ else:
+ # PSGRPO-specific (max_relative_drop, has_drop_moment,
+ # final_reward) — only meaningful when label is
+ # "math_psgrpo"; zero out for other labels to preserve
+ # historical metric tensor shape & semantics.
+ batch_metrics[key].append(0.0)
+
+ score_tensor = torch.tensor(batch_rewards, dtype=torch.float32, device=device)
+ if references is None and labels is None and not return_dict:
+ return score_tensor
+
+ metrics_tensor = {
+ key: torch.tensor(values, dtype=torch.float32, device=device)
+ for key, values in batch_metrics.items()
+ }
+ out = {"score": score_tensor, **metrics_tensor}
+
+ # Only attach per-step fields if any trajectory had non-empty step data.
+ # Stored as Python lists of CPU tensors (variable length per traj) to
+ # avoid forcing every caller to handle padded tensors.
+ any_per_step = any(t.numel() > 0 for t in batch_step_rewards)
+ if any_per_step:
+ out["step_rewards"] = batch_step_rewards
+ out["step_token_indices"] = batch_step_token_indices
+
+ return out
diff --git a/examples/math_prm/reward_models_utils.py b/examples/math_prm/reward_models_utils.py
new file mode 100644
index 00000000..d6fd4ca1
--- /dev/null
+++ b/examples/math_prm/reward_models_utils.py
@@ -0,0 +1,373 @@
+"""Math-only reward loading and aggregation utilities for URSA-MATH Stage 3."""
+
+from __future__ import annotations
+
+import json
+import re
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+import torch
+
+from lightrft.models.monkey_patch.hf_generate_patch import apply_monkey_patch_to_generation_mixin
+from lightrft.utils import get_current_device
+
+from reward_models import MathPRMReward
+
+
+class RewardModelType(str, Enum):
+ """Supported reward model types for the math_prm example."""
+
+ MATH_PRM = "math_prm"
+
+
+@dataclass
+class RewardModelConfig:
+ """Configuration for one reward model instance."""
+
+ rtype: RewardModelType
+ path: str
+ use_engine: bool = False
+
+
+RawRewardInput = Union[str, Dict[str, str], List[Dict[str, str]], None]
+_BUILDERS: Dict[RewardModelType, Callable[..., Tuple[Any, Any]]] = {}
+
+
+def register_builder(rtype: RewardModelType) -> Callable:
+ def deco(fn: Callable) -> Callable:
+ _BUILDERS[rtype] = fn
+ return fn
+
+ return deco
+
+
+def _guess_rtype_from_path(path: str) -> RewardModelType:
+ lowered = path.lower()
+ if any(keyword in lowered for keyword in ("ursa", "prm", "math-rm", "step-reward", "process-reward")):
+ return RewardModelType.MATH_PRM
+ return RewardModelType.MATH_PRM
+
+
+def parse_reward_pretrain(
+ raw: RawRewardInput,
+ *,
+ global_use_engine: bool,
+) -> Tuple[List[RewardModelConfig], Dict[str, int]]:
+ """Parse reward model config while keeping the old flexible input shapes."""
+
+ if raw is None:
+ return [], {}
+
+ pair_list: List[Tuple[str, str, Optional[bool]]] = []
+ if isinstance(raw, str):
+ text = raw.strip()
+ if not text:
+ return [], {}
+ if text.startswith("{") and text.endswith("}"):
+ obj = json.loads(text)
+ pair_list = [(key, value, None) for key, value in obj.items()]
+ else:
+ for segment in re.split(r"\s*,\s*", text):
+ if not segment:
+ continue
+ if ":" in segment:
+ key, value = segment.split(":", 1)
+ pair_list.append((key.strip(), value.strip(), None))
+ else:
+ pair_list.append(("?", segment.strip(), None))
+ elif isinstance(raw, dict):
+ pair_list = [(key, value, None) for key, value in raw.items()]
+ elif isinstance(raw, list):
+ for item in raw:
+ pair_list.append((item["type"], item["path"], item.get("engine")))
+ else:
+ raise TypeError("Unsupported --reward_pretrain format")
+
+ cfgs: List[RewardModelConfig] = []
+ for key, path, flag in pair_list:
+ use_engine = global_use_engine
+ if "?engine=" in path:
+ path, qs = path.split("?engine=", 1)
+ use_engine = qs.lower() in {"1", "true", "yes"}
+ if flag is not None:
+ use_engine = bool(flag)
+ rtype = _guess_rtype_from_path(path) if key == "?" else RewardModelType(key)
+ cfgs.append(RewardModelConfig(rtype=rtype, path=path, use_engine=use_engine))
+
+ label_map = {cfg.rtype.value: index for index, cfg in enumerate(cfgs)}
+ return cfgs, label_map
+
+
+def _load_ursa_prm_model(pretrain_path: str, device: torch.device | int) -> Tuple[Any, Any]:
+ from ursa_model import UrsaForTokenClassification, UrsaProcessor
+
+ processor = UrsaProcessor.from_pretrained(pretrain_path)
+ model = UrsaForTokenClassification.from_pretrained(
+ pretrain_path,
+ torch_dtype=torch.bfloat16,
+ low_cpu_mem_usage=True,
+ trust_remote_code=True,
+ )
+ model = model.to(device)
+ model.eval()
+ return model, processor
+
+
+def _load_engine(pretrain_path: str, device: torch.device | int) -> Tuple[Any, Any]:
+ raise RuntimeError(
+ "The math_prm example no longer supports external reward-model engines. "
+ "URSA-RM is loaded through the local HF path instead."
+ )
+
+
+def _shared_base_key(cfg: RewardModelConfig) -> Optional[Tuple[str, str]]:
+ if cfg.rtype != RewardModelType.MATH_PRM:
+ return None
+ return (cfg.path, cfg.rtype.value)
+
+
+def _load_shared_base(cfg: RewardModelConfig) -> Tuple[Any, Any]:
+ return _load_ursa_prm_model(cfg.path, get_current_device())
+
+
+@register_builder(RewardModelType.MATH_PRM)
+def build_math_prm(
+ cfg: RewardModelConfig,
+ strategy: Any,
+ base: Optional[Tuple[Any, Any]] = None,
+) -> Tuple[MathPRMReward, Any]:
+ if cfg.use_engine:
+ strategy.print(
+ "[build_math_prm] Engine mode is not supported for URSA-RM. "
+ "Falling back to direct HF loading."
+ )
+
+ if base is None:
+ base_model, processor = _load_ursa_prm_model(cfg.path, get_current_device())
+ else:
+ base_model, processor = base
+
+ reward_model = MathPRMReward(
+ base_model=base_model,
+ processor=processor,
+ aggregation="min",
+ )
+ reward_model.eval()
+ return reward_model, processor.tokenizer
+
+
+def load_reward_models(
+ raw_reward_pretrain: RawRewardInput,
+ strategy: Any,
+ use_engine: bool = False,
+) -> Tuple[List[Any], List[Any], Dict[str, int]]:
+ apply_monkey_patch_to_generation_mixin()
+ cfgs, label_map = parse_reward_pretrain(raw_reward_pretrain, global_use_engine=use_engine)
+
+ reward_models: List[Any] = []
+ reward_tokenizers: List[Any] = []
+ shared_bases: Dict[Tuple[str, str], Tuple[Any, Any]] = {}
+
+ for cfg in cfgs:
+ cache_key = _shared_base_key(cfg)
+ if cache_key is not None and cache_key not in shared_bases:
+ shared_bases[cache_key] = _load_shared_base(cfg)
+ strategy.print(f"Init reward model base {cfg.path} (engine={cfg.use_engine}, type={cfg.rtype})")
+
+ for cfg in cfgs:
+ if cfg.rtype not in _BUILDERS:
+ raise RuntimeError(f"No builder registered for {cfg.rtype}")
+ strategy.print(f"Loading {cfg.rtype} from {cfg.path} (engine={cfg.use_engine})")
+ with strategy.init_model_context() as _:
+ reward_model, tokenizer = _BUILDERS[cfg.rtype](
+ cfg,
+ strategy,
+ base=shared_bases.get(_shared_base_key(cfg)),
+ )
+ reward_models.append(reward_model)
+ reward_tokenizers.append(tokenizer)
+ strategy.print(f"Loaded {cfg.rtype}")
+
+ return reward_models, reward_tokenizers, label_map
+
+
+def math_prm_format_reward_fn(sol: str) -> float:
+ """Diagnostic-only check for the required Stage 3 ``Step N`` / ``†Answer`` format."""
+ if not isinstance(sol, str):
+ return 0.0
+ step_matches = re.findall(r"(?m)^Step\s+\d+\s*:\s*\S", sol)
+ answer_matches = re.findall(r"(?m)^†Answer:\s*\S", sol)
+ non_empty_lines = [line.strip() for line in sol.splitlines() if line.strip()]
+ if not step_matches or len(answer_matches) != 1 or not non_empty_lines:
+ return 0.0
+ return 1.0 if non_empty_lines[-1].startswith("†Answer:") else 0.0
+
+
+def format_reward_fn(sol: str) -> float:
+ """Compatibility alias kept for older callers inside this example directory."""
+ return math_prm_format_reward_fn(sol)
+
+
+def rule_reward_fn(sol: str, gt: str) -> float:
+ """Rule-only baseline using the same controlled final-answer extraction as PS-GRPO."""
+ if not gt:
+ return 0.0
+ answer_eval = MathPRMReward._evaluate_answer_alignment(sol, gt)
+ return 1.0 if answer_eval["outcome_correct"] else 0.0
+
+
+RECIPE: Dict[str, List[Tuple[str, Optional[str], float]]] = {
+ "math_prm": [("model", "math_prm", 1.0)],
+ "math_psgrpo": [("model", "math_prm", 1.0)],
+ "math_per_step_prm": [("model", "math_prm", 1.0)],
+ "math_prm_combined": [("model", "math_prm", 1.0), ("rule", None, 0.5)],
+ "math_rule": [("rule", None, 1.0)],
+}
+
+
+NO_GLOBAL_FORMAT_REWARD_LABELS = {
+ "math_prm",
+ "math_psgrpo",
+ "math_per_step_prm",
+ "math_prm_combined",
+ "math_rule",
+}
+
+
+def mix_rewards(
+ labels: Sequence[str],
+ model_scores: torch.Tensor,
+ label_map: Dict[str, int],
+ solution_strs: Sequence[str],
+ refs: Sequence[str],
+ model_reward_metrics_list: Optional[List[Optional[Dict[str, torch.Tensor]]]] = None,
+) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+ if model_scores.numel() > 0:
+ device = model_scores.device
+ elif model_reward_metrics_list:
+ first_metric_tensor = next(
+ (
+ tensor
+ for metrics in model_reward_metrics_list
+ if metrics
+ for tensor in metrics.values()
+ if isinstance(tensor, torch.Tensor)
+ ),
+ None,
+ )
+ device = first_metric_tensor.device if first_metric_tensor is not None else torch.device("cpu")
+ else:
+ device = torch.device("cpu")
+
+ n_model = int(model_scores.shape[0])
+ batch_size = len(labels)
+ if model_scores.ndim != 2:
+ raise ValueError(f"model_scores must have shape (n_model, B), got {tuple(model_scores.shape)!r}")
+ if model_scores.shape[1] != batch_size:
+ raise AssertionError("model_scores second dimension must equal batch size")
+
+ final_reward = torch.zeros(batch_size, dtype=torch.float32, device=device)
+ metrics_dict: Dict[str, torch.Tensor] = {
+ "format_reward": torch.zeros(batch_size, dtype=torch.float32, device=device),
+ "accuracy_reward": torch.zeros(batch_size, dtype=torch.float32, device=device),
+ "model_reward": torch.zeros(batch_size, dtype=torch.float32, device=device),
+ "rule_reward": torch.zeros(batch_size, dtype=torch.float32, device=device),
+ "outcome_correct": torch.zeros(batch_size, dtype=torch.float32, device=device),
+ "max_relative_drop": torch.zeros(batch_size, dtype=torch.float32, device=device),
+ "has_drop_moment": torch.zeros(batch_size, dtype=torch.float32, device=device),
+ "final_reward": torch.zeros(batch_size, dtype=torch.float32, device=device),
+ }
+
+ def ensure_metric_key(metric_name: str) -> None:
+ if metric_name not in metrics_dict:
+ metrics_dict[metric_name] = torch.zeros(batch_size, dtype=torch.float32, device=device)
+
+ def get_model_reward(key: str, index: int) -> float:
+ if key not in label_map:
+ print(f"Model reward <{key}> not loaded, using 1.0 as fallback")
+ return 1.0
+ model_index = label_map[key]
+ if model_index >= n_model:
+ print(f"Model reward <{key}> index {model_index} out of bounds, using 1.0 as fallback")
+ return 1.0
+ return float(model_scores[model_index, index].item())
+
+ def get_model_metrics(key: str, index: int) -> Dict[str, float]:
+ if not model_reward_metrics_list or key not in label_map:
+ return {}
+ model_index = label_map[key]
+ if model_index >= len(model_reward_metrics_list):
+ return {}
+ metrics = model_reward_metrics_list[model_index]
+ if not metrics:
+ return {}
+
+ sample_metrics: Dict[str, float] = {}
+ for metric_name, tensor_value in metrics.items():
+ if not isinstance(tensor_value, torch.Tensor):
+ continue
+ flat_tensor = tensor_value.reshape(-1)
+ if flat_tensor.numel() <= index:
+ continue
+ sample_metrics[metric_name] = float(flat_tensor[index].item())
+ return sample_metrics
+
+ for index, label in enumerate(labels):
+ solution = solution_strs[index]
+ reference = refs[index] if index < len(refs) else ""
+ format_metric = math_prm_format_reward_fn(solution)
+ metrics_dict["format_reward"][index] = format_metric
+ reward_value = 0.0 if label in NO_GLOBAL_FORMAT_REWARD_LABELS else format_metric
+
+ recipe = RECIPE.get(label)
+ if recipe is None:
+ print(f"label <{label}> not registered in RECIPE, returning 0.0 reward")
+ recipe = []
+
+ for reward_type, key, weight in recipe:
+ if reward_type == "model":
+ model_reward = weight * get_model_reward(key, index)
+ reward_value += model_reward
+ metrics_dict["model_reward"][index] += model_reward
+ for metric_name, metric_value in get_model_metrics(key, index).items():
+ ensure_metric_key(metric_name)
+ if metric_name == "final_reward":
+ continue
+ metrics_dict[metric_name][index] = metric_value
+ elif reward_type == "rule":
+ rule_reward = weight * rule_reward_fn(solution, reference)
+ reward_value += rule_reward
+ metrics_dict["rule_reward"][index] += rule_reward
+ metrics_dict["accuracy_reward"][index] = rule_reward
+ else:
+ print(f"Unknown component type {reward_type}, ignoring")
+
+ final_reward[index] = reward_value
+ metrics_dict["final_reward"][index] = reward_value
+
+ return final_reward, metrics_dict
+
+
+def reward_fn(
+ model_reward_list: List[torch.Tensor],
+ model_reward_metrics_list: Optional[List[Optional[Dict[str, torch.Tensor]]]],
+ labels: Sequence[str],
+ queries: Sequence[str],
+ refs: Sequence[str],
+ label_map: Dict[str, int],
+) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+ if model_reward_list:
+ model_scores = torch.stack(model_reward_list)
+ else:
+ model_scores = torch.zeros((0, len(labels)), dtype=torch.float32, device="cpu")
+
+ return mix_rewards(
+ labels=labels,
+ model_scores=model_scores,
+ label_map=label_map,
+ solution_strs=queries,
+ refs=refs,
+ model_reward_metrics_list=model_reward_metrics_list,
+ )
diff --git a/examples/math_prm/rollout_eos_patch.py b/examples/math_prm/rollout_eos_patch.py
new file mode 100644
index 00000000..e46b1d5f
--- /dev/null
+++ b/examples/math_prm/rollout_eos_patch.py
@@ -0,0 +1,254 @@
+"""
+Math PRM rollout EOS patch — keeps the fix local to examples/math_prm/.
+
+Background
+----------
+On 8-GPU FSDP rollouts, historical attempts to terminate URSA generation
+through a ``LogitsProcessor`` that nudges the eos logit up were unreliable:
+logs showed the processor firing hundreds of times (``forced_eos_rows=291``
+per batch-of-4) while every sample still ran the full ``max_new_tokens``
+(``mean_length=511.8`` / 512). The "logits nudge → sampled token →
+``EosTokenCriteria``" handshake does not close under FSDP's numerical
+regime, and even on single card it is only probabilistic.
+
+Fix
+---
+Install a ``StoppingCriteria`` directly on the rollout actor's underlying
+HF model. HF's sample loop calls ``stopping_criteria(input_ids, scores)``
+*after* each new token is appended, and ANDs the returned mask into
+``unfinished_sequences``. When we return True for a row, HF marks it
+finished immediately — no sampling, no logit tricks, no numerical edge.
+
+The criteria also exposes an ``eos_token_id`` attribute so HF's
+``has_eos_stopping_criteria`` detection (``utils.py:2735``) treats our
+signal as EOS-equivalent and enables the post-EOS pad-fill path at
+``utils.py:2835`` for rows we have marked done.
+
+Shape
+-----
+This module is self-contained under ``examples/math_prm/`` and is installed
+from ``train_colocate.py`` via ``install_math_prm_rollout_eos_patch``. The
+install helper wraps ``rollout_actor.model.generate`` so every generate
+call gets a fresh criteria injected. Since math_prm's training loop only
+ever runs math_prm batches, unconditional injection is correct — on
+non-math content the criteria simply never sees ``†Answer:`` and its
+``done`` mask stays all-False (pure runtime no-op).
+
+No changes to ``lightrft/`` are required.
+"""
+
+from __future__ import annotations
+
+import functools
+import time
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList
+
+from math_prm_output import MATH_PRM_ANSWER_MARKER, should_stop_math_prm_response_text
+
+
+class StructuredAnswerStoppingCriteria(StoppingCriteria):
+ """
+ Terminate URSA-math_prm rollout generation when a fully-formed
+ ``†Answer: `` line has been emitted.
+
+ Key properties:
+
+ - Exposes ``eos_token_id`` as an attribute so HF's internal
+ ``has_eos_stopping_criteria`` detection (``utils.py:2735``) treats this
+ criteria as EOS-equivalent, which enables the post-EOS pad-fill path
+ (``utils.py:2835``). Without that attr, rows we mark done would keep
+ getting non-pad filler tokens written into their slots, and
+ ``process_sequences`` — which derives attention-mask from
+ ``seq.ne(eos_token_id) & seq.ne(pad_token_id)`` — would still count
+ those positions as real content.
+ - Checks only every ``check_interval`` tokens to amortise CPU
+ ``batch_decode`` cost (matching the existing LogitsProcessor cadence).
+ - Done bits are *sticky*: once set, the criteria re-asserts them on
+ every subsequent call, including between gated checks. This is
+ critical — HF's sample loop ANDs our return into
+ ``unfinished_sequences`` (``utils.py:2842``), so if we ever returned
+ False for a row we had previously stopped, HF would un-stop it.
+ """
+
+ def __init__(self, tokenizer, prompt_length: int, eos_token_id: int):
+ self.tokenizer = tokenizer
+ self.prompt_length = int(prompt_length)
+ self.eos_token_id = int(eos_token_id)
+ self.check_interval = 4
+ self.marker_scan_max_tokens = 192
+ self.answer_tail_max_tokens = 128
+ self.answer_marker_token_ids = tuple(
+ int(token_id) for token_id in tokenizer.encode(MATH_PRM_ANSWER_MARKER, add_special_tokens=False)
+ )
+ self._marker_seen: Optional[List[bool]] = None
+ self._done: Optional[torch.Tensor] = None
+ self._stats: Dict[str, float] = defaultdict(float)
+
+ def _ensure_state(self, batch_size: int, device) -> None:
+ if self._marker_seen is None or len(self._marker_seen) != batch_size:
+ self._marker_seen = [False] * batch_size
+ if self._done is None or self._done.numel() != batch_size:
+ self._done = torch.zeros(batch_size, dtype=torch.bool, device=device)
+ elif self._done.device != device:
+ self._done = self._done.to(device)
+
+ def _scan_row_for_answer_marker(self, row_token_ids: torch.Tensor) -> bool:
+ marker_token_ids = self.answer_marker_token_ids
+ if not marker_token_ids:
+ return MATH_PRM_ANSWER_MARKER in self.tokenizer.decode(row_token_ids, skip_special_tokens=False)
+
+ token_ids = row_token_ids.tolist()
+ marker_len = len(marker_token_ids)
+ if len(token_ids) < marker_len:
+ return False
+
+ search_start = max(0, len(token_ids) - self.marker_scan_max_tokens)
+ token_ids = token_ids[search_start:]
+ last_start = len(token_ids) - marker_len + 1
+ for start_idx in range(max(last_start, 0)):
+ if tuple(token_ids[start_idx:start_idx + marker_len]) == marker_token_ids:
+ return True
+ return False
+
+ def _decode_rows(self, row_token_ids: torch.Tensor) -> List[str]:
+ decode_t0 = time.time()
+ texts = self.tokenizer.batch_decode(row_token_ids, skip_special_tokens=False)
+ self._stats["decode_time_s"] += time.time() - decode_t0
+ self._stats["decoded_rows"] += len(texts)
+ return texts
+
+ def get_debug_stats(self) -> Optional[Dict[str, Union[int, float]]]:
+ if self._stats["calls"] <= 0:
+ return None
+ return {
+ "calls": int(self._stats["calls"]),
+ "gated_checks": int(self._stats["gated_checks"]),
+ "marker_scan_rows": int(self._stats["marker_scan_rows"]),
+ "marker_hits": int(self._stats["marker_hits"]),
+ "answer_tail_rows": int(self._stats["answer_tail_rows"]),
+ "decoded_rows": int(self._stats["decoded_rows"]),
+ "stopped_rows": int(self._stats["stopped_rows"]),
+ "decode_time_s": round(float(self._stats["decode_time_s"]), 4),
+ }
+
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
+ self._stats["calls"] += 1
+ batch_size = input_ids.size(0)
+ self._ensure_state(batch_size, input_ids.device)
+
+ if input_ids.size(1) <= self.prompt_length:
+ return self._done.clone()
+
+ generated_length = input_ids.size(1) - self.prompt_length
+ # Always return the sticky done mask, even on non-gated-check steps,
+ # so HF cannot flip previously-stopped rows back to unfinished.
+ if generated_length % self.check_interval != 0:
+ return self._done.clone()
+ self._stats["gated_checks"] += 1
+
+ unresolved_rows = [
+ idx for idx in range(batch_size)
+ if not self._marker_seen[idx] and not bool(self._done[idx].item())
+ ]
+ if unresolved_rows:
+ scan_start = max(self.prompt_length, input_ids.size(1) - self.marker_scan_max_tokens)
+ scan_ids = input_ids[unresolved_rows, scan_start:].detach().cpu()
+ self._stats["marker_scan_rows"] += len(unresolved_rows)
+ matched_row_indices = []
+ matched_scan_ids = []
+ for row_idx, row_token_ids in zip(unresolved_rows, scan_ids):
+ if self._scan_row_for_answer_marker(row_token_ids):
+ self._marker_seen[row_idx] = True
+ matched_row_indices.append(row_idx)
+ matched_scan_ids.append(row_token_ids)
+
+ if matched_row_indices:
+ self._stats["marker_hits"] += len(matched_row_indices)
+ matched_scan_ids = torch.stack(matched_scan_ids)
+ scan_texts = self._decode_rows(matched_scan_ids)
+ for row_idx, text in zip(matched_row_indices, scan_texts):
+ if should_stop_math_prm_response_text(text):
+ self._done[row_idx] = True
+
+ marker_rows = [
+ idx for idx in range(batch_size)
+ if self._marker_seen[idx] and not bool(self._done[idx].item())
+ ]
+ if marker_rows:
+ tail_start = max(self.prompt_length, input_ids.size(1) - self.answer_tail_max_tokens)
+ tail_ids = input_ids[marker_rows, tail_start:].detach().cpu()
+ self._stats["answer_tail_rows"] += len(marker_rows)
+ tail_texts = self._decode_rows(tail_ids)
+ for row_idx, text in zip(marker_rows, tail_texts):
+ if should_stop_math_prm_response_text(text):
+ self._done[row_idx] = True
+
+ self._stats["stopped_rows"] = int(self._done.sum().item())
+ return self._done.clone()
+
+
+def install_math_prm_rollout_eos_patch(rollout_actor, tokenizer, eos_token_id: int) -> None:
+ """
+ Wrap ``rollout_actor.model.generate`` so that every generate call gets a
+ fresh ``StructuredAnswerStoppingCriteria`` injected into its
+ ``stopping_criteria`` kwarg.
+
+ This is only installed from the math_prm example's ``train_colocate.py``
+ on the dedicated rollout actor that is used exclusively for math_prm
+ batches, so unconditional injection is correct and keeps the patch
+ self-contained without any reliance on lightrft-side signals.
+
+ For non-math batches the criteria simply never sees ``†Answer:`` in the
+ decoded tail, so its ``done`` mask stays all-False and the patch is a
+ no-op at runtime.
+
+ Idempotent: a second install call is a no-op.
+ """
+ model = rollout_actor.model
+ if getattr(model, "_math_prm_rollout_eos_patch_installed", False):
+ return
+
+ orig_generate = model.generate
+
+ @functools.wraps(orig_generate)
+ def patched_generate(*args: Any, **kwargs: Any):
+ input_ids = kwargs.get("input_ids")
+ if input_ids is None and args:
+ input_ids = args[0]
+ if input_ids is not None and hasattr(input_ids, "size"):
+ prompt_length = int(input_ids.size(1))
+ new_criteria = StructuredAnswerStoppingCriteria(
+ tokenizer=tokenizer,
+ prompt_length=prompt_length,
+ eos_token_id=int(eos_token_id),
+ )
+ existing = kwargs.get("stopping_criteria")
+ if existing is None:
+ kwargs["stopping_criteria"] = StoppingCriteriaList([new_criteria])
+ else:
+ # Be conservative — if caller already provided criteria,
+ # prepend ours rather than dropping theirs.
+ kwargs["stopping_criteria"] = StoppingCriteriaList([new_criteria, *existing])
+
+ # HF auto-enables `synced_gpus=True` under FSDP (see
+ # generation/utils.py:2218), but each rank here runs an independent
+ # local-HF generate on its own prompt slice: reshard_after_forward
+ # is False on the rollout actor so there are no per-step
+ # collectives. Leaving synced_gpus on causes the loop to `continue`
+ # past the input_ids append at utils.py:2838 once this rank's rows
+ # are all done — combined with URSA's prefill-vs-decode branching
+ # in modeling_ursa.py:279 (takes prefill when
+ # `input_ids.shape[1] != 1`), the stale input_ids triggers an
+ # IndexError in `_merge_input_ids_with_image_features`. Force it
+ # off so each rank's generate loop exits cleanly when its own
+ # stopping criteria fire.
+ kwargs.setdefault("synced_gpus", False)
+
+ return orig_generate(*args, **kwargs)
+
+ model.generate = patched_generate
+ model._math_prm_rollout_eos_patch_installed = True
diff --git a/examples/math_prm/run_grpo_math_prm_ursa_8b.sh b/examples/math_prm/run_grpo_math_prm_ursa_8b.sh
new file mode 100755
index 00000000..12fd9178
--- /dev/null
+++ b/examples/math_prm/run_grpo_math_prm_ursa_8b.sh
@@ -0,0 +1,261 @@
+#!/bin/bash
+# Fail fast: a crashed torchrun must propagate its exit code through the
+# `2>&1 | tee` pipeline below so multi-node orchestrators / CI see the error.
+set -eo pipefail
+#
+# LightRFT GRPO Training Script - URSA-8B with URSA-8B-RM (Math PRM).
+#
+# This script trains URSA-8B (a multimodal math VLM built on Qwen2.5-Math) with
+# URSA-8B-RM as a Process Reward Model. The reward signal is PS-GRPO over the
+# PRM step scores: r in {0, 0.5, 1} based on outcome correctness and whether
+# any step-score drop event was observed in the response.
+#
+# - Actor: URSA-8B (hybrid SAM-B + SigLIP-L vision tower + Qwen2.5-Math)
+# - Reward: URSA-8B-RM (process reward model for step-level scoring)
+# - Engine: local HF rollout (vLLM/SGLang URSA support is future work)
+# - Algorithm: GRPO with PS-GRPO reward via the math_psgrpo label
+#
+
+# Auto-load credentials/paths from .env if present (no-op when missing).
+# Useful keys: WANDB_API_KEY, WANDB_PROJECT, HF_TOKEN, PATH_TO_YOUR_BASE_MODEL,
+# PATH_TO_URSA_RM, PATH_TO_YOUR_MATH_DATASET, LIGHTRFT_OUTPUT_ROOT.
+if [ -f "$(dirname "$0")/../../.env" ]; then
+ set -a; . "$(dirname "$0")/../../.env"; set +a
+fi
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)"
+export PYTHONPATH="${REPO_ROOT}:${PYTHONPATH:-}"
+# Alias project-specific WANDB key names to the canonical WANDB_API_KEY so
+# the rest of the script (and wandb itself) can use the canonical name.
+: "${WANDB_API_KEY:=${LIGHTRFT_WANDB_API_KEY:-${WANDB_TOKEN:-${WANDB_KEY:-}}}}"
+export WANDB_API_KEY
+
+################################################################################
+# Part 1: User Configuration #
+# Please update the following paths and settings to match your environment. #
+################################################################################
+
+# --- Model and Dataset Paths ---
+# Each value can be overridden by exporting the env var with the same name
+# before invoking this script (e.g. for CI or per-machine paths). The strings
+# below are placeholders to make the script self-documenting; a real run must
+# either edit them or override via env.
+PATH_TO_YOUR_BASE_MODEL="${PATH_TO_YOUR_BASE_MODEL:-/path/to/your/URSA-8B}"
+PATH_TO_URSA_RM="${PATH_TO_URSA_RM:-/path/to/your/URSA-RM-8B}"
+PATH_TO_YOUR_MATH_DATASET="${PATH_TO_YOUR_MATH_DATASET:-/path/to/your/preprocessed/math_psgrpo.jsonl}"
+
+# --- Experiment and Logging ---
+EXPERIMENT_NAME="${EXPERIMENT_NAME:-lightrft-ursa8b-math-prm}"
+LIGHTRFT_OUTPUT_ROOT="${LIGHTRFT_OUTPUT_ROOT:-.}"
+
+# W&B configuration. Leave WANDB_API_KEY empty to disable W&B.
+export WANDB_API_KEY="${WANDB_API_KEY:-YOUR_WANDB_API_KEY}"
+WANDB_ORG="${WANDB_ORG:-${WANDB_ENTITY:-}}"
+export WANDB_PROJECT="${WANDB_PROJECT:-LightRFT-URSA8B-MathPRM}"
+
+
+################################################################################
+# Part 2: Training Hyperparameters #
+# These settings control the training process. Adjust them as needed. #
+################################################################################
+
+# --- GRPO settings ---
+N_SAMPLES=8 # Number of samples per prompt for GRPO (must be > 1).
+EPISODE=10 # Total number of training episodes.
+WARMUP=0.03 # Learning rate warmup ratio.
+RBS=128 # Rollout Batch Size.
+TBS=128 # Training Batch Size.
+
+# --- Learning and model settings ---
+# K3 estimator (Schulman) at the historical default 0.001. The earlier proposal
+# to switch to K2 + 0.005 was justified by KL ~ 11 nats observed on the broken
+# run; once the silent log-prob misalignment was fixed (see PR #53), the real
+# K3 sits at ~0.04 and the K2/K3/K1 ratios collapse to numerically equivalent
+# small values, so the estimator + coefficient change has no remaining
+# justification. Keep historical values to minimize the PR's behavior diff.
+KL_ESTIMATOR=k3 # Schulman K3 = exp(-r) - 1 + r. Historical default.
+KL=0.001 # Historical default. K3 * 0.001 ~= 4e-5 budget on real KL.
+KL_TARGET="" # If set (e.g. "0.5"), enables AdaptiveKLController.
+# Variant 2 per-step PRM reward mode. Only meaningful when prompts have label
+# "math_per_step_prm" (see fast_exp_maker._apply_step_reward_group_norm). Values:
+# raw : scatter raw sigmoid step_score (paper Figure ablation; default)
+# group_norm : per-step group-relative baseline (GRPO convention)
+PER_STEP_REWARD_MODE="${PER_STEP_REWARD_MODE:-raw}"
+
+LR=1e-6 # Actor learning rate.
+PROMPT_MAX_LEN=1024 # Max length of the input prompt.
+GENERATE_MAX_LEN=3072 # Max length of the generated response.
+MAX_SAMPLES=15360 # Cap on the training subset size.
+
+# --- Multi-modal settings ---
+limit_mm_image_per_prompt=10
+
+# --- Evaluation settings ---
+# Eval pulls a fixed deterministic held-out subset out of the training manifest
+# (URSA Stage 3 protocol).
+EVAL_STEPS=20
+EVAL_HOLDOUT_SIZE=500
+MAX_EVAL_SAMPLES=500
+
+
+################################################################################
+# Part 3: Distributed Training Setup #
+# Configure settings for multi-GPU and multi-node training. #
+################################################################################
+
+export NNODES="${NNODES:-1}"
+export GPUS_PER_NODE="${GPUS_PER_NODE:-8}"
+export NODE_RANK="${NODE_RANK:-0}"
+export MASTER_ADDR="${MASTER_ADDR:-localhost}"
+export MASTER_PORT="${MASTER_PORT:-20092}"
+
+
+################################################################################
+# Part 4: Execution and Logging #
+# This section prepares and launches the training command. #
+################################################################################
+
+# --- Generate dynamic names and paths ---
+# SAVE_MODEL_NAME / WANDB_RUN_NAME are env-overridable so a resumed run can target
+# the existing ckpt directory instead of creating a fresh timestamped one.
+current_time=$(date +"%Y%m%d_%H%M%S")
+SAVE_MODEL_NAME="${SAVE_MODEL_NAME:-${EXPERIMENT_NAME}-ep${EPISODE}-kl${KL}-lr${LR}-${current_time}}"
+WANDB_RUN_NAME="${WANDB_RUN_NAME:-${EXPERIMENT_NAME}-${current_time}}"
+SAVE_DIR="${LIGHTRFT_OUTPUT_ROOT}/results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}"
+LOG_DIR="${LIGHTRFT_OUTPUT_ROOT}/rft_logs/${EXPERIMENT_NAME}"
+export WANDB_DIR="${WANDB_DIR:-${LIGHTRFT_OUTPUT_ROOT}/wandb}"
+
+mkdir -p "${SAVE_DIR}"
+mkdir -p "${LOG_DIR}"
+mkdir -p "${WANDB_DIR}"
+TRAIN_LOG="${LOG_DIR}/node${NODE_RANK}_${current_time}.log"
+
+export TORCH_NCCL_AVOID_RECORD_STREAMS=1
+export NCCL_DEBUG="WARN"
+if [[ -n "${WANDB_API_KEY}" && "${WANDB_API_KEY}" != "YOUR_WANDB_API_KEY" ]]; then
+ export WANDB_MODE="${WANDB_MODE:-online}"
+else
+ export WANDB_MODE="${WANDB_MODE:-offline}"
+fi
+
+# Optional adaptive-KL flag block (only added when KL_TARGET is non-empty).
+KL_TARGET_ARGS=()
+if [[ -n "${KL_TARGET}" ]]; then
+ KL_TARGET_ARGS=(--kl_target "${KL_TARGET}")
+fi
+
+# Optional resume-from-checkpoint flag. Set LOAD_CHECKPOINT=1 in the environment
+# to continue training from ${ckpt_path}/_actor (and _critic if applicable).
+RESUME_ARGS=()
+if [[ "${LOAD_CHECKPOINT:-0}" == "1" ]]; then
+ RESUME_ARGS=(--load_checkpoint)
+fi
+
+WANDB_ORG_ARGS=()
+if [[ -n "${WANDB_ORG}" ]]; then
+ WANDB_ORG_ARGS=(--wandb_org "${WANDB_ORG}")
+fi
+
+# Math PRM uses a single URSA-RM checkpoint registered under the math_prm label.
+REWARD_PRETRAIN_PATHS="{\"math_prm\":\"${PATH_TO_URSA_RM}\"}"
+
+# URSA enforces a fixed structured response format for the PRM scorer.
+SYSTEM_PROMPT='A conversation between the User and Assistant. The User asks a question that may require mathematical or visual reasoning, and the Assistant solves it step by step. Each step MUST begin with "Step N:" (e.g. "Step 1:", "Step 2:") on its own line. After all steps, output exactly one final answer line prefixed with "†Answer:" (e.g. "†Answer: 42"). Stop immediately after the "†Answer:" line and do not output any extra text, repeated answer markers, or additional steps.'
+
+
+################################################################################
+# Part 5: Main Training Command #
+################################################################################
+
+# Use the conda env's torchrun explicitly: under bash -c, `conda activate` does
+# not propagate to subprocesses, so a plain `torchrun` may resolve to a system
+# python that lacks transformers/flash_attn etc. Override with TORCHRUN= if you
+# launch from a different env.
+TORCHRUN="${TORCHRUN:-torchrun}"
+"${TORCHRUN}" \
+ --nnodes $NNODES \
+ --nproc-per-node $GPUS_PER_NODE \
+ --node_rank $NODE_RANK \
+ --master-port $MASTER_PORT \
+ --master-addr $MASTER_ADDR \
+ examples/math_prm/train_colocate.py \
+ --pretrain "${PATH_TO_YOUR_BASE_MODEL}" \
+ --reward_pretrain "${REWARD_PRETRAIN_PATHS}" \
+ --prompt_data "${PATH_TO_YOUR_MATH_DATASET}" \
+ --max_samples ${MAX_SAMPLES} \
+ --input_key "prompt" \
+ --images_key "images" \
+ --label_key "label" \
+ --apply_chat_template \
+ --system_prompt "${SYSTEM_PROMPT}" \
+ --save_path "${SAVE_DIR}" \
+ --ckpt_path "${SAVE_DIR}" \
+ --save_steps 20 \
+ --max_ckpt_num 2 \
+ --save_trajectories \
+ --num_trajectories_to_save 16 \
+ --print_replay_buffer_stats \
+ --fsdp \
+ --bf16 \
+ --flash_attn \
+ --gradient_checkpointing \
+ --zero_stage 3 \
+ --adam_offload \
+ --freeze_prefix \
+ --l2 1.0e-2 \
+ --mixed_mm_data \
+ --limit_mm_image_per_prompt $limit_mm_image_per_prompt \
+ --loss_agg_mode "seq-mean-token-mean" \
+ --advantage_estimator "group_norm" \
+ --max_epochs 1 \
+ --num_episodes ${EPISODE} \
+ --lr_warmup_ratio ${WARMUP} \
+ --n_samples_per_prompt $N_SAMPLES \
+ --train_batch_size ${TBS} \
+ --rollout_batch_size ${RBS} \
+ --prompt_max_len $PROMPT_MAX_LEN \
+ --generate_max_len $GENERATE_MAX_LEN \
+ --actor_learning_rate $LR \
+ --use_kl_loss \
+ --init_kl_coef $KL \
+ --kl_estimator "${KL_ESTIMATOR}" \
+ --per_step_reward_mode "${PER_STEP_REWARD_MODE}" \
+ "${KL_TARGET_ARGS[@]}" \
+ "${RESUME_ARGS[@]}" \
+ --engine_type "hf" \
+ --engine_mem_util 0.6 \
+ --local_hf_generate_max_batch_size 4 \
+ --local_hf_max_new_tokens 512 \
+ --hf_separate_rollout_actor \
+ --hf_separate_rollout_keep_on_gpu \
+ --enable_engine_sleep \
+ --eval_steps ${EVAL_STEPS} \
+ --eval_holdout_size ${EVAL_HOLDOUT_SIZE} \
+ --max_eval_samples ${MAX_EVAL_SAMPLES} \
+ --use_wandb "true" \
+ "${WANDB_ORG_ARGS[@]}" \
+ --wandb_project "${WANDB_PROJECT}" \
+ --wandb_run_name "${WANDB_RUN_NAME}" \
+ 2>&1 | tee "${TRAIN_LOG}"
+
+
+################################################################################
+# Usage Instructions #
+# #
+# Step 1: Prepare the URSA-8B actor and URSA-8B-RM reward model checkpoints. #
+# Both are public on Hugging Face under the URSA-MATH project. Set #
+# PATH_TO_YOUR_BASE_MODEL and PATH_TO_URSA_RM to the local directories. #
+# #
+# Step 2: Preprocess the math PRM dataset. #
+# `python examples/math_prm/tools/prepare_ursa_stage3_manifest.py` #
+# produces a JSONL manifest with fields {prompt, images, reference, label} #
+# where label="math_psgrpo" enables the PS-GRPO reward path. #
+# #
+# Step 3: Configure the script. #
+# Edit "Part 1: User Configuration" at the top of this file. Set the paths #
+# to your URSA-8B actor, URSA-8B-RM reward model, and preprocessed manifest. #
+# #
+# Step 4: Run the training script. #
+# `bash examples/math_prm/run_grpo_math_prm_ursa_8b.sh` #
+# #
+################################################################################
diff --git a/examples/math_prm/run_grpo_math_prm_ursa_8b_variant2.sh b/examples/math_prm/run_grpo_math_prm_ursa_8b_variant2.sh
new file mode 100755
index 00000000..319d110a
--- /dev/null
+++ b/examples/math_prm/run_grpo_math_prm_ursa_8b_variant2.sh
@@ -0,0 +1,303 @@
+#!/bin/bash
+# Fail fast: a crashed torchrun must propagate its exit code through the
+# `2>&1 | tee` pipeline below so multi-node orchestrators / CI see the error.
+set -eo pipefail
+#
+# LightRFT GRPO Training Script — URSA-8B with URSA-8B-RM, strict URSA-paper
+# Eq.9 (variant 2) advantage.
+#
+# Differs from run_grpo_math_prm_ursa_8b.sh ONLY in --advantage_estimator:
+# ursa_variant2 — strict paper Eq.9 form, computed in
+# examples/math_prm/ursa_variant2.py:
+# A_t^i = r_{s,t}^i · GroupNorm_G(r̄_s^i)
+# + GroupNorm_G(r_o^i)
+# A_t broadcast to every token in step t's span.
+# No cumulative return. Outcome term retained.
+#
+# Auto-swaps PATH_TO_YOUR_MATH_DATASET to the .per_step_prm.jsonl sibling
+# (label="math_per_step_prm") because variant 2 needs per-step labels.
+#
+# - Actor: URSA-8B (hybrid SAM-B + SigLIP-L vision tower + Qwen2.5-Math)
+# - Reward: URSA-8B-RM (process reward model for step-level scoring)
+# - Engine: local HF rollout (vLLM/SGLang URSA support is future work)
+# - Algorithm: GRPO with strict URSA paper Eq.9 advantage via the
+# math_per_step_prm label (see examples/math_prm/ursa_variant2.py)
+#
+
+# Auto-load credentials/paths from .env if present (no-op when missing).
+# Useful keys: WANDB_API_KEY, WANDB_PROJECT, HF_TOKEN, PATH_TO_YOUR_BASE_MODEL,
+# PATH_TO_URSA_RM, PATH_TO_YOUR_MATH_DATASET, LIGHTRFT_OUTPUT_ROOT.
+if [ -f "$(dirname "$0")/../../.env" ]; then
+ set -a; . "$(dirname "$0")/../../.env"; set +a
+fi
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)"
+export PYTHONPATH="${REPO_ROOT}:${PYTHONPATH:-}"
+# Alias project-specific WANDB key names to the canonical WANDB_API_KEY so
+# the rest of the script (and wandb itself) can use the canonical name.
+: "${WANDB_API_KEY:=${LIGHTRFT_WANDB_API_KEY:-${WANDB_TOKEN:-${WANDB_KEY:-}}}}"
+export WANDB_API_KEY
+
+################################################################################
+# Part 1: User Configuration #
+# Please update the following paths and settings to match your environment. #
+################################################################################
+
+# --- Model and Dataset Paths ---
+# Each value can be overridden by exporting the env var with the same name
+# before invoking this script (e.g. for CI or per-machine paths). The strings
+# below are placeholders to make the script self-documenting; a real run must
+# either edit them or override via env.
+PATH_TO_YOUR_BASE_MODEL="${PATH_TO_YOUR_BASE_MODEL:-/path/to/your/URSA-8B}"
+PATH_TO_URSA_RM="${PATH_TO_URSA_RM:-/path/to/your/URSA-RM-8B}"
+# variant 2 NEEDS rows labeled "math_per_step_prm". The PS-GRPO dataset has
+# label="math_psgrpo" everywhere — running variant 2 on it would silently
+# emit zero step_rewards. If the caller still points PATH_TO_YOUR_MATH_DATASET
+# at a psgrpo .jsonl (legacy default), we auto-swap to its sed-relabeled
+# sibling. Build that sibling once with the one-liner documented in
+# README.md §6 "Strict Paper Eq.9 — variant 2 path":
+# sed 's/"label":[ ]*"math_psgrpo"/"label": "math_per_step_prm"/g' \
+# /path/to/math_psgrpo.jsonl > /path/to/math_per_step_prm.jsonl
+# If the caller wants a custom path, set PATH_TO_YOUR_MATH_DATASET_VARIANT2.
+if [ -n "${PATH_TO_YOUR_MATH_DATASET_VARIANT2:-}" ]; then
+ PATH_TO_YOUR_MATH_DATASET="${PATH_TO_YOUR_MATH_DATASET_VARIANT2}"
+elif [ -n "${PATH_TO_YOUR_MATH_DATASET:-}" ] && [[ "${PATH_TO_YOUR_MATH_DATASET}" != *per_step_prm* ]]; then
+ PATH_TO_YOUR_MATH_DATASET="${PATH_TO_YOUR_MATH_DATASET%.jsonl}.per_step_prm.jsonl"
+fi
+PATH_TO_YOUR_MATH_DATASET="${PATH_TO_YOUR_MATH_DATASET:-/path/to/your/preprocessed/math_per_step_prm.jsonl}"
+if [ ! -f "${PATH_TO_YOUR_MATH_DATASET}" ]; then
+ echo "[variant2 launch] FATAL: dataset not found: ${PATH_TO_YOUR_MATH_DATASET}" >&2
+ exit 1
+fi
+# Sanity: first row must already be relabeled, otherwise variant 2 silently fails.
+FIRST_LABEL=$(head -1 "${PATH_TO_YOUR_MATH_DATASET}" | python3 -c 'import sys,json; print(json.loads(sys.stdin.read()).get("label",""))' 2>/dev/null || echo "")
+if [ "${FIRST_LABEL}" != "math_per_step_prm" ]; then
+ echo "[variant2 launch] FATAL: dataset first row label='${FIRST_LABEL}', expected 'math_per_step_prm'." >&2
+ echo "Pre-process with:" >&2
+ echo " sed 's/\"label\":[ ]*\"math_psgrpo\"/\"label\": \"math_per_step_prm\"/g' SRC > DST" >&2
+ exit 1
+fi
+echo "[variant2 launch] using dataset: ${PATH_TO_YOUR_MATH_DATASET}"
+
+# --- Experiment and Logging ---
+EXPERIMENT_NAME="${EXPERIMENT_NAME:-lightrft-ursa8b-math-prm-variant2}"
+LIGHTRFT_OUTPUT_ROOT="${LIGHTRFT_OUTPUT_ROOT:-.}"
+
+# W&B configuration. Leave WANDB_API_KEY empty to disable W&B.
+export WANDB_API_KEY="${WANDB_API_KEY:-YOUR_WANDB_API_KEY}"
+WANDB_ORG="${WANDB_ORG:-${WANDB_ENTITY:-}}"
+export WANDB_PROJECT="${WANDB_PROJECT:-LightRFT-URSA8B-MathPRM}"
+
+
+################################################################################
+# Part 2: Training Hyperparameters #
+# These settings control the training process. Adjust them as needed. #
+################################################################################
+
+# --- GRPO settings ---
+N_SAMPLES=8 # Number of samples per prompt for GRPO (must be > 1).
+EPISODE=10 # Total number of training episodes.
+WARMUP=0.03 # Learning rate warmup ratio.
+RBS=128 # Rollout Batch Size.
+TBS=128 # Training Batch Size.
+
+# --- Learning and model settings ---
+# K3 estimator (Schulman) at the historical default 0.001. The earlier proposal
+# to switch to K2 + 0.005 was justified by KL ~ 11 nats observed on the broken
+# run; once the silent log-prob misalignment was fixed (see PR #53), the real
+# K3 sits at ~0.04 and the K2/K3/K1 ratios collapse to numerically equivalent
+# small values, so the estimator + coefficient change has no remaining
+# justification. Keep historical values to minimize the PR's behavior diff.
+KL_ESTIMATOR=k3 # Schulman K3 = exp(-r) - 1 + r. Historical default.
+KL=0.001 # Historical default. K3 * 0.001 ~= 4e-5 budget on real KL.
+KL_TARGET="" # If set (e.g. "0.5"), enables AdaptiveKLController.
+# NOTE: --per_step_reward_mode is intentionally NOT passed to train_colocate.py
+# in this launcher. It only affects the legacy Math-Shepherd-style per-token
+# reward path (fast_exp_maker._apply_step_reward_group_norm); the
+# ursa_variant2 advantage estimator does its own GroupNorm inside
+# UrsaVariant2Calculator.preprocess_rewards (see examples/math_prm/ursa_variant2.py),
+# so passing the flag here is inert and only adds cognitive load. The
+# PS-GRPO launcher (run_grpo_math_prm_ursa_8b.sh) still exposes it.
+
+LR=1e-6 # Actor learning rate.
+PROMPT_MAX_LEN=1024 # Max length of the input prompt.
+GENERATE_MAX_LEN=3072 # Max length of the generated response.
+MAX_SAMPLES=15360 # Cap on the training subset size.
+
+# --- Multi-modal settings ---
+limit_mm_image_per_prompt=10
+
+# --- Evaluation settings ---
+# Eval pulls a fixed deterministic held-out subset out of the training manifest
+# (URSA Stage 3 protocol).
+EVAL_STEPS=20
+EVAL_HOLDOUT_SIZE=500
+MAX_EVAL_SAMPLES=500
+
+
+################################################################################
+# Part 3: Distributed Training Setup #
+# Configure settings for multi-GPU and multi-node training. #
+################################################################################
+
+export NNODES="${NNODES:-1}"
+export GPUS_PER_NODE="${GPUS_PER_NODE:-8}"
+export NODE_RANK="${NODE_RANK:-0}"
+export MASTER_ADDR="${MASTER_ADDR:-localhost}"
+export MASTER_PORT="${MASTER_PORT:-20092}"
+
+
+################################################################################
+# Part 4: Execution and Logging #
+# This section prepares and launches the training command. #
+################################################################################
+
+# --- Generate dynamic names and paths ---
+# SAVE_MODEL_NAME / WANDB_RUN_NAME are env-overridable so a resumed run can target
+# the existing ckpt directory instead of creating a fresh timestamped one.
+current_time=$(date +"%Y%m%d_%H%M%S")
+SAVE_MODEL_NAME="${SAVE_MODEL_NAME:-${EXPERIMENT_NAME}-ep${EPISODE}-kl${KL}-lr${LR}-${current_time}}"
+WANDB_RUN_NAME="${WANDB_RUN_NAME:-${EXPERIMENT_NAME}-${current_time}}"
+SAVE_DIR="${LIGHTRFT_OUTPUT_ROOT}/results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}"
+LOG_DIR="${LIGHTRFT_OUTPUT_ROOT}/rft_logs/${EXPERIMENT_NAME}"
+export WANDB_DIR="${WANDB_DIR:-${LIGHTRFT_OUTPUT_ROOT}/wandb}"
+
+mkdir -p "${SAVE_DIR}"
+mkdir -p "${LOG_DIR}"
+mkdir -p "${WANDB_DIR}"
+TRAIN_LOG="${LOG_DIR}/node${NODE_RANK}_${current_time}.log"
+
+export TORCH_NCCL_AVOID_RECORD_STREAMS=1
+export NCCL_DEBUG="WARN"
+if [[ -n "${WANDB_API_KEY}" && "${WANDB_API_KEY}" != "YOUR_WANDB_API_KEY" ]]; then
+ export WANDB_MODE="${WANDB_MODE:-online}"
+else
+ export WANDB_MODE="${WANDB_MODE:-offline}"
+fi
+
+# Optional adaptive-KL flag block (only added when KL_TARGET is non-empty).
+KL_TARGET_ARGS=()
+if [[ -n "${KL_TARGET}" ]]; then
+ KL_TARGET_ARGS=(--kl_target "${KL_TARGET}")
+fi
+
+# Optional resume-from-checkpoint flag. Set LOAD_CHECKPOINT=1 in the environment
+# to continue training from ${ckpt_path}/_actor (and _critic if applicable).
+RESUME_ARGS=()
+if [[ "${LOAD_CHECKPOINT:-0}" == "1" ]]; then
+ RESUME_ARGS=(--load_checkpoint)
+fi
+
+WANDB_ORG_ARGS=()
+if [[ -n "${WANDB_ORG}" ]]; then
+ WANDB_ORG_ARGS=(--wandb_org "${WANDB_ORG}")
+fi
+
+# Math PRM uses a single URSA-RM checkpoint registered under the math_prm label.
+REWARD_PRETRAIN_PATHS="{\"math_prm\":\"${PATH_TO_URSA_RM}\"}"
+
+# URSA enforces a fixed structured response format for the PRM scorer.
+SYSTEM_PROMPT='A conversation between the User and Assistant. The User asks a question that may require mathematical or visual reasoning, and the Assistant solves it step by step. Each step MUST begin with "Step N:" (e.g. "Step 1:", "Step 2:") on its own line. After all steps, output exactly one final answer line prefixed with "†Answer:" (e.g. "†Answer: 42"). Stop immediately after the "†Answer:" line and do not output any extra text, repeated answer markers, or additional steps.'
+
+
+################################################################################
+# Part 5: Main Training Command #
+################################################################################
+
+# Use the conda env's torchrun explicitly: under bash -c, `conda activate` does
+# not propagate to subprocesses, so a plain `torchrun` may resolve to a system
+# python that lacks transformers/flash_attn etc. Override with TORCHRUN= if you
+# launch from a different env.
+TORCHRUN="${TORCHRUN:-torchrun}"
+"${TORCHRUN}" \
+ --nnodes $NNODES \
+ --nproc-per-node $GPUS_PER_NODE \
+ --node_rank $NODE_RANK \
+ --master-port $MASTER_PORT \
+ --master-addr $MASTER_ADDR \
+ examples/math_prm/train_colocate.py \
+ --pretrain "${PATH_TO_YOUR_BASE_MODEL}" \
+ --reward_pretrain "${REWARD_PRETRAIN_PATHS}" \
+ --prompt_data "${PATH_TO_YOUR_MATH_DATASET}" \
+ --max_samples ${MAX_SAMPLES} \
+ --input_key "prompt" \
+ --images_key "images" \
+ --label_key "label" \
+ --apply_chat_template \
+ --system_prompt "${SYSTEM_PROMPT}" \
+ --save_path "${SAVE_DIR}" \
+ --ckpt_path "${SAVE_DIR}" \
+ --save_steps 20 \
+ --max_ckpt_num 2 \
+ --save_trajectories \
+ --num_trajectories_to_save 16 \
+ --print_replay_buffer_stats \
+ --fsdp \
+ --bf16 \
+ --flash_attn \
+ --gradient_checkpointing \
+ --zero_stage 3 \
+ --adam_offload \
+ --freeze_prefix \
+ --l2 1.0e-2 \
+ --mixed_mm_data \
+ --limit_mm_image_per_prompt $limit_mm_image_per_prompt \
+ --loss_agg_mode "seq-mean-token-mean" \
+ --advantage_estimator "ursa_variant2" \
+ --max_epochs 1 \
+ --num_episodes ${EPISODE} \
+ --lr_warmup_ratio ${WARMUP} \
+ --n_samples_per_prompt $N_SAMPLES \
+ --train_batch_size ${TBS} \
+ --rollout_batch_size ${RBS} \
+ --prompt_max_len $PROMPT_MAX_LEN \
+ --generate_max_len $GENERATE_MAX_LEN \
+ --actor_learning_rate $LR \
+ --use_kl_loss \
+ --init_kl_coef $KL \
+ --kl_estimator "${KL_ESTIMATOR}" \
+ "${KL_TARGET_ARGS[@]}" \
+ "${RESUME_ARGS[@]}" \
+ --engine_type "hf" \
+ --engine_mem_util 0.6 \
+ --local_hf_generate_max_batch_size 4 \
+ --local_hf_max_new_tokens 512 \
+ --hf_separate_rollout_actor \
+ --hf_separate_rollout_keep_on_gpu \
+ --enable_engine_sleep \
+ --eval_steps ${EVAL_STEPS} \
+ --eval_holdout_size ${EVAL_HOLDOUT_SIZE} \
+ --max_eval_samples ${MAX_EVAL_SAMPLES} \
+ --use_wandb "true" \
+ "${WANDB_ORG_ARGS[@]}" \
+ --wandb_project "${WANDB_PROJECT}" \
+ --wandb_run_name "${WANDB_RUN_NAME}" \
+ 2>&1 | tee "${TRAIN_LOG}"
+
+
+################################################################################
+# Usage Instructions #
+# #
+# Step 1: Prepare the URSA-8B actor and URSA-8B-RM reward model checkpoints. #
+# Both are public on Hugging Face under the URSA-MATH project. Set #
+# PATH_TO_YOUR_BASE_MODEL and PATH_TO_URSA_RM to the local directories. #
+# #
+# Step 2: Preprocess the math PRM dataset and relabel it for variant 2. #
+# First produce the standard PS-GRPO manifest: #
+# python examples/math_prm/tools/prepare_ursa_stage3_manifest.py \ #
+# --input-path /your/data/MMathCoT-1M/train.jsonl \ #
+# --image-root /your/data/URSA-MATH/images \ #
+# --output-path /your/output/math_psgrpo.jsonl #
+# Then sed-relabel into a math_per_step_prm sibling (variant 2 requires it): #
+# sed 's/"label":[ ]*"math_psgrpo"/"label": "math_per_step_prm"/g' \ #
+# /your/output/math_psgrpo.jsonl #
+# > /your/output/math_per_step_prm.jsonl #
+# #
+# Step 3: Configure the script. #
+# Edit "Part 1: User Configuration" at the top of this file. Set the paths #
+# to your URSA-8B actor, URSA-8B-RM reward model, and preprocessed manifest. #
+# #
+# Step 4: Run the training script. #
+# `bash examples/math_prm/run_grpo_math_prm_ursa_8b_variant2.sh` #
+# #
+################################################################################
diff --git a/examples/math_prm/test_ursa_variant2.py b/examples/math_prm/test_ursa_variant2.py
new file mode 100644
index 00000000..6fccb154
--- /dev/null
+++ b/examples/math_prm/test_ursa_variant2.py
@@ -0,0 +1,369 @@
+"""Strict-alignment tests for the URSA paper Eq.9 advantage estimator.
+
+Tests cover the five acceptance criteria AC1–AC5 from the PR plan:
+
+ AC1 numerical equivalence with hand-computed paper Eq.9 (max|Δ|<1e-5)
+ AC2 outcome reward is NOT bypassed (changing r_o changes advantages)
+ AC3 group normalization is correct over K=n_samples_per_prompt
+ AC4 per-step advantage broadcast to the *full* step span (not just the
+ boundary token); advantage jumps at step boundaries
+ AC5 with realistic mixed-sign inputs, advantages contain both signs
+ (regression against the legacy ``per_step_reward_mode=raw`` failure
+ mode where every advantage was positive)
+
+Run from repo root:
+ python3 -m unittest examples.math_prm.test_ursa_variant2 -v
+Or directly:
+ python3 examples/math_prm/test_ursa_variant2.py
+"""
+
+from __future__ import annotations
+
+import math
+import os
+import sys
+import unittest
+from types import SimpleNamespace
+from typing import List
+
+# Allow `import ursa_variant2` whether run from repo root (CI) or from
+# examples/math_prm/ (developer convenience).
+_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
+if _THIS_DIR not in sys.path:
+ sys.path.insert(0, _THIS_DIR)
+
+import torch
+
+import ursa_variant2 as ursa_v2 # registers the monkey-patch at import time
+
+
+def _make_cfg(n_samples: int = 2, advantage_clip: float = 0) -> SimpleNamespace:
+ """Minimal config namespace UrsaVariant2Calculator reads from."""
+ return SimpleNamespace(
+ n_samples_per_prompt=n_samples,
+ advantage_clip=advantage_clip,
+ )
+
+
+def _make_exp(
+ action_mask: torch.Tensor,
+ reward: torch.Tensor,
+ step_rewards: List[torch.Tensor],
+ step_token_indices: List[torch.Tensor],
+):
+ """Minimal experience stub: just info dict + action_mask."""
+ return SimpleNamespace(
+ action_mask=action_mask,
+ info={
+ "reward": reward,
+ "step_rewards": step_rewards,
+ "step_token_indices": step_token_indices,
+ },
+ )
+
+
+# Hand-computed Eq.9 reference, written out explicitly so reviewers can
+# verify the test logic itself matches paper Appendix B.1 Eq.9.
+def _hand_compute_eq9(
+ step_rewards: List[torch.Tensor],
+ step_token_indices: List[torch.Tensor],
+ outcome: torch.Tensor,
+ K: int,
+ T: int,
+) -> torch.Tensor:
+ """Brute-force reference: build per-token advantages following paper Eq.9.
+
+ A_t^i = r_{s,t}^i * GroupNorm_G(r̄_s^i) + GroupNorm_G(r_o^i)
+ where t indexes steps and the value is broadcast to every token within
+ the span [start_k, end_k] (start_0=0, start_k = end_{k-1}+1).
+ """
+ B = outcome.numel()
+ assert B % K == 0
+ G = B // K
+
+ r_bar = torch.stack(
+ [sr.float().mean() if sr.numel() > 0 else torch.tensor(0.0) for sr in step_rewards]
+ )
+
+ def gn(x: torch.Tensor) -> torch.Tensor:
+ g = x.float().reshape(G, K)
+ return ((g - g.mean(dim=-1, keepdim=True))
+ / (g.std(dim=-1, unbiased=False, keepdim=True) + 1e-9)).flatten()
+
+ oc_norm = gn(outcome)
+ msp_norm = gn(r_bar)
+
+ adv = torch.zeros(B, T, dtype=torch.float32)
+ for i in range(B):
+ sr = step_rewards[i].float()
+ si = step_token_indices[i].long()
+ n = sr.numel()
+ if n == 0:
+ adv[i] = oc_norm[i]
+ continue
+ starts = torch.cat([torch.zeros(1, dtype=torch.long), si[:-1] + 1])
+ ends = si
+ for k in range(n):
+ sk = max(0, int(starts[k]))
+ ek = min(T - 1, int(ends[k]))
+ adv[i, sk:ek + 1] = sr[k] * msp_norm[i] + oc_norm[i]
+ last_end = int(ends[-1])
+ if last_end + 1 < T:
+ adv[i, last_end + 1:] = oc_norm[i]
+ return adv
+
+
+class _Base(unittest.TestCase):
+ def setUp(self):
+ torch.manual_seed(0)
+
+
+class TestAC1NumericalEquivalence(_Base):
+ """AC1: implementation matches hand-computed Eq.9 within tolerance."""
+
+ def test_basic_k2_three_steps(self):
+ K = 2
+ B = 4
+ T = 30
+ step_rewards = [
+ torch.tensor([0.80, 0.70, 0.30]), # traj 0
+ torch.tensor([0.85, 0.75, 0.90]), # traj 1
+ torch.tensor([0.50, 0.55, 0.60]), # traj 2
+ torch.tensor([0.60, 0.65, 0.70]), # traj 3
+ ]
+ step_token_indices = [torch.tensor([5, 12, 20])] * 4
+ outcome = torch.tensor([1.0, 1.0, 0.0, 1.0])
+ action_mask = torch.ones(B, T, dtype=torch.long)
+
+ expected = _hand_compute_eq9(step_rewards, step_token_indices, outcome, K, T)
+
+ calc = ursa_v2.UrsaVariant2Calculator(_make_cfg(n_samples=K))
+ # preprocess_rewards takes one experience holding all B trajectories
+ exp = _make_exp(action_mask, outcome, step_rewards, step_token_indices)
+ calc.preprocess_rewards(outcome, [exp], max_new_tokens=T)
+ adv, ret, info = calc.compute(exp, final_reward=None, gamma=None, generate_kwargs={})
+
+ max_abs = (adv - expected).abs().max().item()
+ self.assertLess(max_abs, 1e-5, f"AC1 violated: max|Δ|={max_abs}")
+ # returns mirror advantages (no value function)
+ self.assertLess((ret - adv).abs().max().item(), 1e-9)
+
+ def test_k4_variable_step_count(self):
+ K = 4
+ T = 40
+ step_rewards = [
+ torch.tensor([0.9, 0.8]), # 2 steps
+ torch.tensor([0.5, 0.4, 0.6]), # 3 steps
+ torch.tensor([0.7]), # 1 step
+ torch.tensor([0.6, 0.65, 0.55, 0.50]), # 4 steps
+ ]
+ step_token_indices = [
+ torch.tensor([10, 25]),
+ torch.tensor([8, 18, 30]),
+ torch.tensor([22]),
+ torch.tensor([7, 15, 22, 33]),
+ ]
+ outcome = torch.tensor([1.0, 0.0, 1.0, 0.0])
+ action_mask = torch.ones(K, T, dtype=torch.long)
+
+ expected = _hand_compute_eq9(step_rewards, step_token_indices, outcome, K, T)
+ calc = ursa_v2.UrsaVariant2Calculator(_make_cfg(n_samples=K))
+ exp = _make_exp(action_mask, outcome, step_rewards, step_token_indices)
+ calc.preprocess_rewards(outcome, [exp], max_new_tokens=T)
+ adv, _, _ = calc.compute(exp, None, None, {})
+ max_abs = (adv - expected).abs().max().item()
+ self.assertLess(max_abs, 1e-5, f"AC1 (K=4 variable steps) violated: max|Δ|={max_abs}")
+
+
+class TestAC2OutcomeNotBypassed(_Base):
+ """AC2: changing outcome reward must change advantages.
+
+ Regression for the Mode B bypass bug: under the old per-step path,
+ feeding outcome through ``compute_reward`` r had no effect because
+ Mode B threw it away. UrsaVariant2Calculator must NOT have this bug.
+ """
+
+ def _run(self, outcome):
+ K = 2
+ B = 4
+ T = 20
+ step_rewards = [torch.tensor([0.5, 0.6, 0.7])] * B
+ step_token_indices = [torch.tensor([5, 10, 15])] * B
+ action_mask = torch.ones(B, T, dtype=torch.long)
+ calc = ursa_v2.UrsaVariant2Calculator(_make_cfg(n_samples=K))
+ exp = _make_exp(action_mask, outcome, step_rewards, step_token_indices)
+ calc.preprocess_rewards(outcome, [exp], max_new_tokens=T)
+ adv, _, _ = calc.compute(exp, None, None, {})
+ return adv
+
+ def test_all_correct_vs_all_wrong_differs(self):
+ # If outcome is constant across the group, the GroupNorm of outcome
+ # is exactly zero, so the additive term vanishes — that's expected
+ # and not a bug; instead we compare a per-prompt mixed case.
+ # (One sample correct, one wrong in each of two prompts.)
+ oc_a = torch.tensor([1.0, 0.0, 1.0, 0.0])
+ oc_b = torch.tensor([0.0, 1.0, 0.0, 1.0])
+ adv_a = self._run(oc_a)
+ adv_b = self._run(oc_b)
+ diff = (adv_a - adv_b).abs().max().item()
+ self.assertGreater(diff, 0.5, f"AC2 violated: outcome flip should "
+ f"flip the sign of the outcome term "
+ f"(max|Δ|={diff})")
+
+ def test_outcome_anchor_extends_past_last_step(self):
+ # Pad tail past the last step should carry the outcome anchor only.
+ K = 2
+ T = 30
+ step_rewards = [
+ torch.tensor([0.6, 0.7]),
+ torch.tensor([0.6, 0.7]),
+ ]
+ step_token_indices = [torch.tensor([5, 10])] * 2
+ # Trajectory 0 wins outcome, trajectory 1 loses — group_norm gives
+ # ±1 for the outcome term.
+ outcome = torch.tensor([1.0, 0.0])
+ action_mask = torch.ones(K, T, dtype=torch.long)
+ calc = ursa_v2.UrsaVariant2Calculator(_make_cfg(n_samples=K))
+ exp = _make_exp(action_mask, outcome, step_rewards, step_token_indices)
+ calc.preprocess_rewards(outcome, [exp], max_new_tokens=T)
+ adv, _, _ = calc.compute(exp, None, None, {})
+ # tail tokens (idx 11..29) should equal oc_normed (= ±1 within tol)
+ traj0_tail = adv[0, 11:]
+ traj1_tail = adv[1, 11:]
+ # Expect tail = oc_normed[i] (no process-reward there)
+ self.assertGreater(traj0_tail.mean().item(), 0.5,
+ f"traj0 tail must carry positive outcome anchor "
+ f"(got mean={traj0_tail.mean().item():.3f})")
+ self.assertLess(traj1_tail.mean().item(), -0.5,
+ f"traj1 tail must carry negative outcome anchor "
+ f"(got mean={traj1_tail.mean().item():.3f})")
+
+
+class TestAC3GroupNormCorrect(_Base):
+ """AC3: GroupNorm zero-mean / unit-std across K siblings for both terms."""
+
+ def test_k2_msp_normed_zero_mean(self):
+ K = 2
+ B = 4
+ T = 10
+ step_rewards = [
+ torch.tensor([0.9, 0.8, 0.7]),
+ torch.tensor([0.3, 0.4, 0.2]),
+ torch.tensor([0.8, 0.7, 0.6]),
+ torch.tensor([0.5, 0.5, 0.4]),
+ ]
+ sti = [torch.tensor([2, 5, 8])] * B
+ outcome = torch.tensor([1.0, 0.0, 1.0, 0.0])
+ action_mask = torch.ones(B, T, dtype=torch.long)
+ calc = ursa_v2.UrsaVariant2Calculator(_make_cfg(n_samples=K))
+ exp = _make_exp(action_mask, outcome, step_rewards, sti)
+ calc.preprocess_rewards(outcome, [exp], max_new_tokens=T)
+
+ # Stored normed values should sum to 0 per group (within tol)
+ oc_n = exp.info["_ursa_oc_normed"].view(-1, K)
+ msp_n = exp.info["_ursa_msp_normed"].view(-1, K)
+ self.assertLess(oc_n.sum(dim=-1).abs().max().item(), 1e-5)
+ self.assertLess(msp_n.sum(dim=-1).abs().max().item(), 1e-5)
+
+ def test_k4_msp_normed_unit_std(self):
+ K = 4
+ T = 10
+ step_rewards = [
+ torch.tensor([0.9, 0.8]),
+ torch.tensor([0.5, 0.4]),
+ torch.tensor([0.7, 0.7]),
+ torch.tensor([0.2, 0.3]),
+ ]
+ sti = [torch.tensor([3, 7])] * K
+ outcome = torch.tensor([1.0, 0.0, 1.0, 0.0])
+ action_mask = torch.ones(K, T, dtype=torch.long)
+ calc = ursa_v2.UrsaVariant2Calculator(_make_cfg(n_samples=K))
+ exp = _make_exp(action_mask, outcome, step_rewards, sti)
+ calc.preprocess_rewards(outcome, [exp], max_new_tokens=T)
+ msp = exp.info["_ursa_msp_normed"]
+ # std (unbiased=False) across the single group of K=4 should be ~1
+ std = msp.std(unbiased=False).item()
+ self.assertAlmostEqual(std, 1.0, places=3,
+ msg=f"AC3 violated: msp_normed std={std}")
+
+
+class TestAC4SpanBroadcast(_Base):
+ """AC4: advantage is constant within each step span and changes at the boundary."""
+
+ def test_advantage_constant_within_span(self):
+ K = 2
+ B = 2
+ T = 25
+ step_rewards = [
+ torch.tensor([0.9, 0.5, 0.7]),
+ torch.tensor([0.4, 0.6, 0.8]),
+ ]
+ sti = [torch.tensor([4, 12, 20])] * B
+ outcome = torch.tensor([1.0, 0.0])
+ action_mask = torch.ones(B, T, dtype=torch.long)
+ calc = ursa_v2.UrsaVariant2Calculator(_make_cfg(n_samples=K))
+ exp = _make_exp(action_mask, outcome, step_rewards, sti)
+ calc.preprocess_rewards(outcome, [exp], max_new_tokens=T)
+ adv, _, _ = calc.compute(exp, None, None, {})
+
+ # span 0 of traj 0: tokens 0..4 should all be equal
+ span0 = adv[0, 0:5]
+ self.assertLess((span0 - span0[0]).abs().max().item(), 1e-6)
+ # span 1: tokens 5..12 equal
+ span1 = adv[0, 5:13]
+ self.assertLess((span1 - span1[0]).abs().max().item(), 1e-6)
+ # span 2: tokens 13..20 equal
+ span2 = adv[0, 13:21]
+ self.assertLess((span2 - span2[0]).abs().max().item(), 1e-6)
+ # adjacent spans differ (otherwise per-step credit is degenerate)
+ self.assertNotAlmostEqual(span0[0].item(), span1[0].item(), places=4)
+ self.assertNotAlmostEqual(span1[0].item(), span2[0].item(), places=4)
+
+
+class TestAC5SignedAdvantages(_Base):
+ """AC5: typical inputs produce both positive and negative advantages."""
+
+ def test_signed_advantages_on_synthetic_batch(self):
+ K = 2
+ B = 4
+ T = 25
+ step_rewards = [
+ torch.tensor([0.8, 0.7, 0.3]),
+ torch.tensor([0.85, 0.75, 0.9]),
+ torch.tensor([0.5, 0.55, 0.6]),
+ torch.tensor([0.6, 0.65, 0.7]),
+ ]
+ sti = [torch.tensor([5, 12, 20])] * B
+ outcome = torch.tensor([1.0, 1.0, 0.0, 1.0])
+ action_mask = torch.ones(B, T, dtype=torch.long)
+ calc = ursa_v2.UrsaVariant2Calculator(_make_cfg(n_samples=K))
+ exp = _make_exp(action_mask, outcome, step_rewards, sti)
+ calc.preprocess_rewards(outcome, [exp], max_new_tokens=T)
+ adv, _, info = calc.compute(exp, None, None, {})
+ # advantages must contain both signs (paper Eq.9 → zero-mean per group)
+ self.assertGreater(info["ursa_v2_adv_pos_frac"], 0.05,
+ "AC5: advantages should contain positive entries")
+ self.assertGreater(info["ursa_v2_adv_neg_frac"], 0.05,
+ "AC5: advantages should contain negative entries")
+
+
+class TestK1Fallback(_Base):
+ """When K=1, group norm is degenerate; calculator must not crash."""
+
+ def test_k1_returns_zero_advantage(self):
+ K = 1
+ T = 20
+ step_rewards = [torch.tensor([0.5, 0.6])]
+ sti = [torch.tensor([5, 10])]
+ outcome = torch.tensor([1.0])
+ action_mask = torch.ones(1, T, dtype=torch.long)
+ calc = ursa_v2.UrsaVariant2Calculator(_make_cfg(n_samples=K))
+ exp = _make_exp(action_mask, outcome, step_rewards, sti)
+ calc.preprocess_rewards(outcome, [exp], max_new_tokens=T)
+ adv, _, info = calc.compute(exp, None, None, {})
+ self.assertEqual(adv.abs().sum().item(), 0.0)
+ self.assertEqual(info.get("ursa_v2_fallback_used"), 1.0)
+
+
+if __name__ == "__main__":
+ unittest.main(verbosity=2)
diff --git a/examples/math_prm/tools/__init__.py b/examples/math_prm/tools/__init__.py
new file mode 100644
index 00000000..07de1e77
--- /dev/null
+++ b/examples/math_prm/tools/__init__.py
@@ -0,0 +1 @@
+"""Helper scripts, smoke runners, and regression checks for the URSA math_prm example."""
diff --git a/examples/math_prm/tools/prepare_ursa_engine_checkpoint.py b/examples/math_prm/tools/prepare_ursa_engine_checkpoint.py
new file mode 100644
index 00000000..d720e878
--- /dev/null
+++ b/examples/math_prm/tools/prepare_ursa_engine_checkpoint.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python
+"""
+Build an engine-friendly local wrapper checkpoint for URSA-8B.
+
+The upstream URSA checkpoints do not ship ``auto_map`` metadata or local model
+code files, which prevents inference engines such as vLLM/SGLang from loading
+the custom architecture via HuggingFace dynamic modules. This helper creates a
+thin wrapper directory that:
+
+1. symlinks the original checkpoint weights/tokenizer assets
+2. symlinks the local ``examples/math_prm/ursa_model/*.py`` files
+3. writes patched ``config.json`` / ``preprocessor_config.json`` /
+ ``tokenizer_config.json`` with the required ``auto_map`` entries
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import shutil
+from pathlib import Path
+
+
+MODEL_AUTO_MAP = {
+ "AutoConfig": "configuration_ursa.UrsaConfig",
+ "AutoModel": "modeling_ursa.UrsaForConditionalGeneration",
+ "AutoModelForVision2Seq": "modeling_ursa.UrsaForConditionalGeneration",
+}
+
+PROCESSOR_AUTO_MAP = {
+ "AutoProcessor": "processing_ursa.UrsaProcessor",
+ "AutoImageProcessor": "image_processing_vlm.VLMImageProcessor",
+}
+
+
+def _safe_unlink(path: Path) -> None:
+ if path.is_symlink() or path.is_file():
+ path.unlink()
+ elif path.is_dir():
+ shutil.rmtree(path)
+
+
+def _ensure_symlink(src: Path, dst: Path) -> None:
+ if dst.exists() or dst.is_symlink():
+ if dst.is_symlink() and dst.resolve() == src.resolve():
+ return
+ _safe_unlink(dst)
+ dst.symlink_to(src)
+
+
+def _load_json(path: Path) -> dict:
+ return json.loads(path.read_text())
+
+
+def _write_json(path: Path, payload: dict) -> None:
+ path.write_text(json.dumps(payload, ensure_ascii=False, indent=2) + "\n")
+
+
+def build_wrapper(source_model_path: Path, output_path: Path, local_ursa_dir: Path) -> None:
+ output_path.mkdir(parents=True, exist_ok=True)
+
+ for src in source_model_path.iterdir():
+ dst = output_path / src.name
+ if src.name in {"config.json", "preprocessor_config.json", "tokenizer_config.json"}:
+ continue
+ _ensure_symlink(src, dst)
+
+ for src in local_ursa_dir.glob("*.py"):
+ _ensure_symlink(src, output_path / src.name)
+
+ config = _load_json(source_model_path / "config.json")
+ auto_map = dict(config.get("auto_map") or {})
+ auto_map.update(MODEL_AUTO_MAP)
+ config["auto_map"] = auto_map
+ _write_json(output_path / "config.json", config)
+
+ preprocessor_path = source_model_path / "preprocessor_config.json"
+ if preprocessor_path.exists():
+ preprocessor = _load_json(preprocessor_path)
+ preprocessor["processor_class"] = "UrsaProcessor"
+ preprocessor["image_processor_type"] = "VLMImageProcessor"
+ preprocessor_auto_map = dict(preprocessor.get("auto_map") or {})
+ preprocessor_auto_map.update(PROCESSOR_AUTO_MAP)
+ preprocessor["auto_map"] = preprocessor_auto_map
+ _write_json(output_path / "preprocessor_config.json", preprocessor)
+
+ tokenizer_config_path = source_model_path / "tokenizer_config.json"
+ if tokenizer_config_path.exists():
+ tokenizer_config = _load_json(tokenizer_config_path)
+ tokenizer_config["processor_class"] = "UrsaProcessor"
+ tokenizer_auto_map = dict(tokenizer_config.get("auto_map") or {})
+ tokenizer_auto_map.update(PROCESSOR_AUTO_MAP)
+ tokenizer_config["auto_map"] = tokenizer_auto_map
+ _write_json(output_path / "tokenizer_config.json", tokenizer_config)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--source-model-path", required=True)
+ parser.add_argument("--output-path", required=True)
+ args = parser.parse_args()
+
+ source_model_path = Path(args.source_model_path).resolve()
+ output_path = Path(args.output_path).resolve()
+ local_ursa_dir = Path(__file__).resolve().parents[1] / "ursa_model"
+
+ build_wrapper(source_model_path, output_path, local_ursa_dir)
+ print(str(output_path))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/math_prm/tools/prepare_ursa_stage3_manifest.py b/examples/math_prm/tools/prepare_ursa_stage3_manifest.py
new file mode 100644
index 00000000..0dbac106
--- /dev/null
+++ b/examples/math_prm/tools/prepare_ursa_stage3_manifest.py
@@ -0,0 +1,322 @@
+"""
+Prepare a LightRFT-compatible Stage 3 manifest from URSA-MATH raw data.
+
+This script converts the raw `MMathCoT-1M` jsonl schema:
+
+ {"image_url": "...", "instruction": "...", "output": "..."}
+
+into a LightRFT prompt dataset schema:
+
+ {
+ "prompt": "...",
+ "images": ["/abs/path/to/image.png"],
+ "reference": "...",
+ "label": "math_psgrpo"
+ }
+
+It also performs a lightweight `PromptDatasetVL` smoke validation on the
+converted records so the output can be consumed directly by
+`examples/math_prm/train_colocate.py`.
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import re
+from collections import Counter
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Any
+
+
+REPO_ROOT = Path(__file__).resolve().parents[3]
+
+import sys
+
+if str(REPO_ROOT) not in sys.path:
+ sys.path.insert(0, str(REPO_ROOT))
+
+from lightrft.datasets.prompts_dataset_vl import PromptDatasetVL
+
+
+# --input-path and --image-root are intentionally required (no path default):
+# the data lives outside the repo and varies per environment. The output paths
+# resolve under REPO_ROOT/tmp/ so they always succeed locally.
+DEFAULT_OUTPUT_PATH = str(REPO_ROOT / "tmp" / "ursa_stage3" / "mmathcot_stage3_math_psgrpo.jsonl")
+DEFAULT_SUMMARY_PATH = str(REPO_ROOT / "tmp" / "ursa_stage3" / "mmathcot_stage3_math_psgrpo.summary.json")
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(
+ description=(
+ "Convert URSA-MATH MMathCoT-1M raw jsonl into a LightRFT "
+ "Stage 3 prompt manifest and validate it with PromptDatasetVL."
+ )
+ )
+ parser.add_argument(
+ "--input-path",
+ type=str,
+ required=True,
+ help="Path to MMathCoT-1M raw train.jsonl (e.g. /data/URSA-MATH/MMathCoT-1M/train.jsonl).",
+ )
+ parser.add_argument(
+ "--image-root",
+ type=str,
+ required=True,
+ help="Root directory for URSA-MATH image assets (e.g. /data/URSA-MATH/images).",
+ )
+ parser.add_argument(
+ "--output-path",
+ type=str,
+ default=DEFAULT_OUTPUT_PATH,
+ help="Output path for the converted LightRFT jsonl manifest.",
+ )
+ parser.add_argument(
+ "--summary-path",
+ type=str,
+ default=DEFAULT_SUMMARY_PATH,
+ help="Path to write the conversion/validation summary json.",
+ )
+ parser.add_argument(
+ "--label",
+ type=str,
+ default="math_psgrpo",
+ help="Label written into the converted manifest.",
+ )
+ parser.add_argument(
+ "--prompt-mode",
+ type=str,
+ choices=["question_only", "instruction"],
+ default="question_only",
+ help="How to build the LightRFT prompt from raw instruction.",
+ )
+ parser.add_argument(
+ "--max-samples",
+ type=int,
+ default=None,
+ help="Optional cap for the number of raw rows to process.",
+ )
+ parser.add_argument(
+ "--smoke-samples",
+ type=int,
+ default=4,
+ help="How many converted samples to use for PromptDatasetVL smoke validation.",
+ )
+ return parser.parse_args()
+
+
+def extract_prompt(raw_instruction: str, prompt_mode: str) -> tuple[str, bool]:
+ text = (raw_instruction or "").strip()
+ if not text:
+ return "", False
+
+ if prompt_mode == "instruction":
+ return text, False
+
+ marker = "Question:"
+ idx = text.find(marker)
+ if idx == -1:
+ return text, True
+
+ question = text[idx + len(marker):].strip()
+ # Some raw rows contain duplicated or malformed prefixes such as
+ # "Question:estion: ...". Strip these repeatedly before returning.
+ prefix_re = re.compile(r"^(?:(?:[Qq]uestion|[Qq]estion|[Ee]stion|[Uu]estion)\s*:)\s*")
+ while True:
+ cleaned = prefix_re.sub("", question)
+ if cleaned == question:
+ break
+ question = cleaned.strip()
+ if not question:
+ return text, True
+ return question, False
+
+
+def extract_reference(raw_output: str) -> tuple[str, bool]:
+ text = (raw_output or "").strip()
+ if not text:
+ return "", False
+
+ marker = "†Answer:"
+ idx = text.rfind(marker)
+ if idx == -1:
+ return text, True
+
+ answer = text[idx + len(marker):].strip()
+ if not answer:
+ return text, True
+ return answer, False
+
+
+def build_record(
+ raw: dict[str, Any],
+ source_index: int,
+ image_root: Path,
+ prompt_mode: str,
+ label: str,
+) -> tuple[dict[str, Any], dict[str, Any]]:
+ image_url = str(raw.get("image_url", "")).strip()
+ instruction = str(raw.get("instruction", "")).strip()
+ output = str(raw.get("output", "")).strip()
+
+ prompt, used_prompt_fallback = extract_prompt(instruction, prompt_mode)
+ reference, used_reference_fallback = extract_reference(output)
+
+ image_path = (image_root / image_url).resolve()
+ prefix = image_url.split("/", 1)[0] if image_url else ""
+
+ record = {
+ "data_source": "URSA-MATH/MMathCoT-1M",
+ "prompt": prompt,
+ "images": [str(image_path)],
+ "reference": reference,
+ "ground_truth": reference,
+ "label": label,
+ "reward_model": {
+ "ground_truth": reference,
+ },
+ "extra_info": {
+ "source_index": source_index,
+ "raw_image_url": image_url,
+ "image_prefix": prefix,
+ "prompt_mode": prompt_mode,
+ },
+ }
+
+ meta = {
+ "image_path_exists": image_path.exists(),
+ "prompt_empty": prompt == "",
+ "reference_empty": reference == "",
+ "used_prompt_fallback": used_prompt_fallback,
+ "used_reference_fallback": used_reference_fallback,
+ "image_prefix": prefix,
+ "image_path": str(image_path),
+ }
+ return record, meta
+
+
+def smoke_validate(converted_rows: list[dict[str, Any]], smoke_samples: int) -> dict[str, Any]:
+ smoke_rows = converted_rows[: max(1, min(smoke_samples, len(converted_rows)))]
+ strategy = SimpleNamespace(
+ args=SimpleNamespace(
+ input_key="prompt",
+ images_key="images",
+ reference_key="reference",
+ label_key="label",
+ apply_chat_template=False,
+ system_prompt=None,
+ )
+ )
+ dataset = PromptDatasetVL(
+ smoke_rows,
+ tokenizer=None,
+ processor=None,
+ max_length=0,
+ strategy=strategy,
+ )
+ items = [dataset[i] for i in range(len(dataset))]
+ prompts, images, references, labels = dataset.collate_fn(items)
+
+ first_prompt, first_images, first_reference, first_label = items[0]
+ return {
+ "sample_count": len(dataset),
+ "first_item": {
+ "prompt_preview": first_prompt[:240],
+ "image_count": len(first_images) if isinstance(first_images, list) else 0,
+ "first_image": first_images[0] if isinstance(first_images, list) and first_images else None,
+ "reference": first_reference,
+ "label": first_label,
+ },
+ "collate_sizes": {
+ "prompts": len(prompts),
+ "images": len(images),
+ "references": len(references),
+ "labels": len(labels),
+ },
+ }
+
+
+def main() -> None:
+ args = parse_args()
+
+ input_path = Path(args.input_path).resolve()
+ image_root = Path(args.image_root).resolve()
+ output_path = Path(args.output_path).resolve()
+ summary_path = Path(args.summary_path).resolve()
+
+ if not input_path.exists():
+ raise FileNotFoundError(f"input jsonl not found: {input_path}")
+ if not image_root.exists():
+ raise FileNotFoundError(f"image root not found: {image_root}")
+
+ counters = Counter()
+ prefix_counter: Counter[str] = Counter()
+ smoke_rows: list[dict[str, Any]] = []
+
+ output_path.parent.mkdir(parents=True, exist_ok=True)
+ with input_path.open("r", encoding="utf-8") as fp, output_path.open("w", encoding="utf-8") as out_fp:
+ for source_index, line in enumerate(fp):
+ if args.max_samples is not None and source_index >= args.max_samples:
+ break
+
+ counters["rows_seen"] += 1
+ raw = json.loads(line)
+ record, meta = build_record(
+ raw=raw,
+ source_index=source_index,
+ image_root=image_root,
+ prompt_mode=args.prompt_mode,
+ label=args.label,
+ )
+
+ prefix_counter[meta["image_prefix"]] += 1
+ if meta["used_prompt_fallback"]:
+ counters["prompt_fallback_rows"] += 1
+ if meta["used_reference_fallback"]:
+ counters["reference_fallback_rows"] += 1
+ if meta["prompt_empty"]:
+ counters["empty_prompt_rows"] += 1
+ if meta["reference_empty"]:
+ counters["empty_reference_rows"] += 1
+ if not meta["image_path_exists"]:
+ raise FileNotFoundError(
+ f"missing image for row {source_index}: {meta['image_path']}"
+ )
+
+ out_fp.write(json.dumps(record, ensure_ascii=False) + "\n")
+ counters["rows_written"] += 1
+ if len(smoke_rows) < max(1, args.smoke_samples):
+ smoke_rows.append(record)
+
+ if not smoke_rows:
+ raise ValueError("No rows were converted. Check the input path and --max-samples.")
+
+ smoke = smoke_validate(smoke_rows, args.smoke_samples)
+
+ summary = {
+ "input_path": str(input_path),
+ "image_root": str(image_root),
+ "output_path": str(output_path),
+ "summary_path": str(summary_path),
+ "label": args.label,
+ "prompt_mode": args.prompt_mode,
+ "rows_seen": counters["rows_seen"],
+ "rows_written": counters["rows_written"],
+ "prompt_fallback_rows": counters["prompt_fallback_rows"],
+ "reference_fallback_rows": counters["reference_fallback_rows"],
+ "empty_prompt_rows": counters["empty_prompt_rows"],
+ "empty_reference_rows": counters["empty_reference_rows"],
+ "image_prefix_counts": dict(prefix_counter),
+ "images_per_sample_counts": {"1": counters["rows_written"]},
+ "smoke_validation": smoke,
+ }
+
+ summary_path.parent.mkdir(parents=True, exist_ok=True)
+ summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")
+
+ print(json.dumps(summary, ensure_ascii=False, indent=2))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/math_prm/train_colocate.py b/examples/math_prm/train_colocate.py
new file mode 100755
index 00000000..70ae3e63
--- /dev/null
+++ b/examples/math_prm/train_colocate.py
@@ -0,0 +1,1115 @@
+"""
+GRPO Training with Co-located Reward Models
+
+This script implements Group Relative Policy Optimization (GRPO) training
+with co-located reward models for reinforcement learning from human feedback (RLHF).
+
+Key Features:
+ - Supports both text-only and vision-language models
+ - Multiple reward models (Value, Safety, Knowledge, Normal, General)
+ - Flexible strategy: DeepSpeed ZeRO or FSDP
+ - Meta device initialization for memory optimization
+ - EMA (Exponential Moving Average) model support
+ - Dynamic sampling and overlong buffer penalties (DAPO)
+
+Main Components:
+ - Actor: Policy model being trained
+ - Critic: Value model for advantage estimation (optional for GRPO)
+ - Reward Models: Multiple models for evaluating different aspects
+ - Initial Model: Reference model for KL divergence
+
+Training Pipeline:
+ 1. Load and initialize models (actor, critic, reward models)
+ 2. Setup data loaders (prompts + optional pretrain data)
+ 3. Configure optimizers and schedulers
+ 4. Run PPO/GRPO training loop via SPMDPPOTrainerVL
+
+Usage:
+ python examples/math_prm/train_colocate.py --pretrain --reward_pretrain ...
+
+For more details on arguments, see the argument parser at the bottom of this file.
+"""
+import argparse
+import itertools
+import math
+import re
+import os
+import sys
+import json
+from datetime import datetime
+from typing import Callable, Dict, List, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import (
+ AutoConfig,
+ AutoModelForTokenClassification,
+ AutoModelForVision2Seq,
+)
+
+from lightrft.utils import add_arguments, ensure_video_input_available
+ensure_video_input_available()
+
+from lightrft.datasets import PromptDatasetVL, SFTDatasetVL
+from lightrft.utils import blending_datasets, get_tokenizer_processor_vl
+from lightrft.models.actor_language import ActorLanguage
+from lightrft.models.actor_vl import ActorVL
+
+from lightrft.strategy import get_strategy
+
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+from math_prm_trainer import MathPRMSPMDPPOTrainerVL
+from reward_models_utils import load_reward_models, reward_fn, RECIPE
+
+
+def is_ursa_model(model_path: str) -> bool:
+ """
+ Check if the model is a URSA model by looking for URSA-specific config.
+
+ URSA models have:
+ - architectures: ["UrsaForConditionalGeneration"]
+ - model_type: "ursa"
+ - vision_config and aligner_config sections
+
+ Args:
+ model_path: Path to the model directory
+
+ Returns:
+ True if this is a URSA model, False otherwise
+ """
+ import os
+ config_path = os.path.join(model_path, "config.json")
+ if os.path.exists(config_path):
+ try:
+ import json
+ with open(config_path, 'r') as f:
+ config = json.load(f)
+ # Check for UrsaForConditionalGeneration in architectures
+ architectures = config.get("architectures", [])
+ if "UrsaForConditionalGeneration" in architectures:
+ return True
+ # Fallback: check model_type
+ if config.get("model_type") == "ursa":
+ return True
+ except:
+ pass
+ return False
+
+
+def resolve_reference_shard_size(world_size: int, preferred_shard_size: int = 8) -> int:
+ """
+ Pick a reference-model FSDP shard size that preserves the original 8-way
+ layout when possible, but still works for bounded small-world-size runs.
+ """
+ if world_size <= 0:
+ return preferred_shard_size
+ candidate = min(preferred_shard_size, world_size)
+ while candidate > 1 and world_size % candidate != 0:
+ candidate -= 1
+ return candidate
+
+
+def split_runtime_eval_dataset(prompts_data, args, strategy):
+ """
+ Build a deterministic held-out runtime eval split from prompt_data when no
+ explicit eval dataset is provided.
+
+ This follows the paper/plan intent of using a stable in-domain held-out set
+ instead of relying on an optional dataset split name.
+ """
+ if args.eval_holdout_size <= 0 or args.max_eval_samples <= 0:
+ return prompts_data, None
+
+ total_samples = len(prompts_data)
+ if total_samples <= 1:
+ strategy.print("Warning: prompt_data is too small to carve out a held-out runtime eval split.")
+ return prompts_data, None
+
+ eval_size = min(args.eval_holdout_size, args.max_eval_samples, total_samples - 1)
+ if eval_size <= 0:
+ strategy.print("Warning: held-out runtime eval split resolved to zero samples; skipping eval split.")
+ return prompts_data, None
+
+ if not hasattr(prompts_data, "train_test_split"):
+ strategy.print("Warning: prompt_data does not support train_test_split(); skipping held-out runtime eval.")
+ return prompts_data, None
+
+ split = prompts_data.train_test_split(test_size=eval_size, shuffle=True, seed=args.eval_holdout_seed)
+ train_data = split["train"]
+ eval_data = split["test"]
+ strategy.print(
+ "Prepared runtime eval holdout from prompt_data "
+ f"(train={len(train_data)}, eval={len(eval_data)}, seed={args.eval_holdout_seed})."
+ )
+ return train_data, eval_data
+
+
+def load_actor_tokenizer_processor(
+ *,
+ model_path: str,
+ model,
+ strategy,
+ use_fast: bool,
+):
+ """
+ Load the actor tokenizer/processor, using the explicit URSA processor path
+ when the checkpoint is a URSA model.
+ """
+ if is_ursa_model(model_path):
+ from ursa_model import UrsaProcessor
+
+ processor = UrsaProcessor.from_pretrained(model_path)
+ tokenizer = processor.tokenizer
+ tokenizer.padding_side = "left"
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.pad_token_id = tokenizer.eos_token_id
+ model.config.pad_token_id = tokenizer.pad_token_id
+ strategy.print(
+ f"Loaded URSA processor explicitly: tokenizer={type(tokenizer).__name__}, "
+ f"processor={type(processor).__name__}"
+ )
+ return tokenizer, processor
+
+ return get_tokenizer_processor_vl(
+ model_path,
+ model,
+ "left",
+ use_fast=use_fast,
+ )
+
+
+def build_actor_init_kwargs(
+ args,
+ *,
+ ds_config,
+ include_lora: bool,
+ include_disable_logprobs_flashattn: bool,
+):
+ """
+ Build Actor/UrsaActor initialization kwargs while keeping train/eval variants aligned.
+ """
+ kwargs = dict(
+ use_flash_attention_2=args.flash_attn,
+ bf16=args.bf16,
+ load_in_4bit=args.load_in_4bit,
+ ds_config=ds_config,
+ packing_samples=args.packing_samples,
+ fused_linear_logprob=args.fused_linear_logprob,
+ )
+ if include_lora:
+ kwargs.update(
+ lora_rank=args.lora_rank,
+ lora_alpha=args.lora_alpha,
+ target_modules=args.target_modules,
+ lora_dropout=args.lora_dropout,
+ )
+ if include_disable_logprobs_flashattn:
+ kwargs["disable_logprobs_flashattn"] = args.disable_logprobs_flashattn
+ return kwargs
+
+
+def prepare_ursa_runtime_for_inference_engines(strategy=None):
+ """
+ Register the local URSA classes with HuggingFace auto classes so rollout
+ engines that rely on ``AutoConfig`` can resolve ``model_type='ursa'``.
+ """
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ if current_dir not in sys.path:
+ sys.path.insert(0, current_dir)
+
+ pythonpath = os.environ.get("PYTHONPATH")
+ pythonpath_parts = pythonpath.split(os.pathsep) if pythonpath else []
+ if current_dir not in pythonpath_parts:
+ os.environ["PYTHONPATH"] = os.pathsep.join([current_dir, *pythonpath_parts]) if pythonpath_parts else current_dir
+
+ from ursa_model import (
+ UrsaConfig,
+ UrsaForConditionalGeneration,
+ UrsaForTokenClassification,
+ )
+
+ AutoConfig.register("ursa", UrsaConfig, exist_ok=True)
+ AutoModelForVision2Seq.register(UrsaConfig, UrsaForConditionalGeneration, exist_ok=True)
+ AutoModelForTokenClassification.register(UrsaConfig, UrsaForTokenClassification, exist_ok=True)
+
+ if strategy is not None:
+ strategy.print(
+ "Registered URSA auto classes for inference engines "
+ f"(sys.path/PYTHONPATH include {current_dir})"
+ )
+
+
+def train(args):
+ """
+ Main training function for GRPO with co-located reward models.
+
+ Training workflow:
+ 1. Initialize strategy (DeepSpeed or FSDP)
+ 2. Initialize models with meta_init option for memory efficiency
+ 3. Load reward models (multiple types supported)
+ 4. Setup dataloaders for prompts and optional pretrain data
+ 5. Configure optimizers and schedulers
+ 6. Setup inference engine (vLLM or SGLang)
+ 7. Run training loop via SPMDPPOTrainerVL
+ 8. Save final model
+
+ Args:
+ args: Parsed command-line arguments containing all training configuration
+
+ Key configurations:
+ - meta_init: Initialize models on meta device to save CPU RAM
+ - freeze_prefix: Freeze vision encoder during training
+ - fsdp: Use FSDP instead of DeepSpeed
+ - rm_use_engine: Generic flag retained for other reward types, but
+ URSA math_prm/math_psgrpo PRM paths still load via HF directly
+ """
+ if args.hf_separate_rollout_actor and args.engine_type != "hf":
+ raise ValueError("--hf_separate_rollout_actor requires --engine_type hf.")
+ if args.hf_separate_rollout_actor and not args.fsdp:
+ raise ValueError("--hf_separate_rollout_actor currently requires --fsdp.")
+
+ # configure strategy
+ strategy = get_strategy(args)
+
+ ds_train_cfg = strategy.get_ds_train_config(is_actor=True) if not args.fsdp else None
+ ds_eval_cfg = strategy.get_ds_eval_config(offload=False) if not args.fsdp else None
+
+ # configure model
+ # ==================== Model Initialization ====================
+ # Initialize all models within init_model_context for memory efficiency.
+ # When meta_init=True, models are created on "meta" device as empty shells,
+ # fundamentally resolving CPU OOM issues.
+ with strategy.init_model_context(meta_init=args.meta_init):
+ strategy.print(f"Initializing models with meta_init={args.meta_init}")
+
+ # Check if this is a URSA model
+ is_ursa = is_ursa_model(args.pretrain)
+
+ # Select Actor class based on model type and text_only flag
+ if is_ursa:
+ strategy.print(f"Detected URSA model, using UrsaActor")
+ from ursa_actor import UrsaActor
+ Actor = UrsaActor
+ elif args.text_only:
+ Actor = ActorLanguage
+ else:
+ Actor = ActorVL
+
+ # Initialize Actor (policy model)
+ actor = Actor(
+ args.pretrain,
+ **build_actor_init_kwargs(
+ args,
+ ds_config=ds_train_cfg,
+ include_lora=True,
+ include_disable_logprobs_flashattn=True,
+ ),
+ )
+
+ rollout_actor = None
+ if args.hf_separate_rollout_actor:
+ rollout_actor = Actor(
+ args.pretrain,
+ **build_actor_init_kwargs(
+ args,
+ ds_config=ds_eval_cfg,
+ include_lora=True,
+ include_disable_logprobs_flashattn=True,
+ ),
+ )
+
+ if args.actor_init_on_gpu:
+ actor = actor.to(torch.cuda.current_device())
+
+ # pre-prepare is used for saving RAM memory when training 72B model
+ if args.fsdp:
+ setattr(actor, "is_actor", True)
+ actor = strategy.prepare_model(actor, is_training=True)
+
+ # Optionally freeze parameters (e.g., vision encoder).
+ # Qwen2-VL etc. expose vision under "visual.*"; URSA uses "vision_model.*"
+ # plus an "aligner.*" projector. Match all of them so --freeze_prefix
+ # actually fires for the URSA stack.
+ if args.freeze_prefix:
+ freeze_prefix = ["visual", "vision_model", "aligner"]
+ frozen_params_count = 0
+ total_params_count = 0
+ for name, param in actor.model.named_parameters():
+ total_params_count += 1
+ if any(name.startswith(prefix) for prefix in freeze_prefix):
+ param.requires_grad = False
+ frozen_params_count += 1
+ strategy.print(f"Froze {frozen_params_count}/{total_params_count} parameters based on prefixes: {freeze_prefix}")
+
+ if args.critic_pretrain:
+ try:
+ from lightrft.models import get_vlm_for_sequence_regression
+ except ImportError as exc:
+ raise ImportError(
+ "critic_pretrain was provided, but get_vlm_for_sequence_regression "
+ "is not available in this LightRFT checkout."
+ ) from exc
+ critic = get_vlm_for_sequence_regression(
+ args.critic_pretrain,
+ "critic",
+ normalize_reward=args.normalize_reward_for_critic,
+ use_flash_attention_2=args.flash_attn,
+ bf16=args.bf16,
+ load_in_4bit=args.load_in_4bit,
+ lora_rank=args.lora_rank,
+ lora_alpha=args.lora_alpha,
+ target_modules=args.target_modules,
+ lora_dropout=args.lora_dropout,
+ ds_config=ds_train_cfg,
+ value_head_prefix=args.value_head_prefix,
+ init_value_head=strategy.args.pretrain == strategy.args.critic_pretrain,
+ )
+ else:
+ critic = None
+
+ # Load reward models (multiple types: value, safety, knowledge, etc.)
+ strategy.report_memory(f"before loaded reward models in main entry")
+ reward_models, reward_tokenizers, label_map = load_reward_models(
+ raw_reward_pretrain=args.reward_pretrain,
+ strategy=strategy,
+ use_engine=args.rm_use_engine,
+ )
+ strategy.print(f"label_map: {label_map}")
+ strategy.report_memory(f"after loaded reward models in main entry")
+
+ strategy.print(actor)
+ strategy.print(critic)
+
+ # load weights for reference actor
+ if args.init_kl_coef == 0:
+ initial_model = None
+ else:
+ # Use the same Actor class (including URSA if detected)
+ initial_model = Actor(
+ args.pretrain,
+ **build_actor_init_kwargs(
+ args,
+ ds_config=ds_eval_cfg,
+ include_lora=False,
+ include_disable_logprobs_flashattn=False,
+ ),
+ )
+ if args.fsdp:
+ reference_shard_size = resolve_reference_shard_size(
+ world_size=strategy.world_size,
+ preferred_shard_size=8,
+ )
+ strategy.print(
+ "Preparing reference model with shard_size="
+ f"{reference_shard_size} (world_size={strategy.world_size})"
+ )
+ initial_model = strategy.prepare_model(
+ initial_model,
+ is_training=False,
+ shard_size=reference_shard_size,
+ )
+ strategy.offload_model(initial_model)
+
+ if args.enable_ema:
+ # Use the same Actor class (including URSA if detected)
+ ema_model = Actor(
+ args.pretrain,
+ use_flash_attention_2=args.flash_attn,
+ bf16=args.bf16,
+ load_in_4bit=args.load_in_4bit,
+ ds_config=ds_eval_cfg,
+ )
+ else:
+ ema_model = None
+
+ # configure tokenizer and processor
+ tokenizer, processor = load_actor_tokenizer_processor(
+ model_path=args.pretrain,
+ model=actor.model,
+ strategy=strategy,
+ use_fast=not strategy.args.disable_fast_tokenizer,
+ )
+ assert processor is not None, "processor is None"
+
+ # ==================== Data Loading Optimization ====================
+ # The following sections now rely on the robust `blending_datasets` function.
+ # We add more logging for clarity.
+
+ # Prepare prompts dataset
+ strategy.print(f"Loading prompts dataset from: {args.prompt_data} with split: {args.prompt_split}")
+ prompts_data = blending_datasets(
+ args.prompt_data,
+ args.prompt_data_probs,
+ strategy,
+ args.seed,
+ return_eval=False,
+ train_split=args.prompt_split,
+ )
+
+ heldout_eval_data = None
+ if not args.eval_data and not args.eval_split:
+ prompts_data, heldout_eval_data = split_runtime_eval_dataset(prompts_data, args, strategy)
+
+ prompts_data = prompts_data.select(range(min(args.max_samples, len(prompts_data))))
+ prompts_dataset = PromptDatasetVL(prompts_data, tokenizer, processor, args.prompt_max_len, strategy, input_template=args.input_template)
+ strategy.print(f"Loaded {len(prompts_dataset)} samples for prompts.")
+
+ # Prepare evaluation dataset
+ eval_dataloader = None
+ if args.eval_data or args.eval_split:
+ eval_data_path = args.eval_data if args.eval_data else args.prompt_data
+ if eval_data_path:
+ strategy.print(f"Loading evaluation dataset from {eval_data_path}, split='{args.eval_split}'")
+ eval_data = blending_datasets(
+ eval_data_path, "1.0", strategy, args.seed, return_eval=False,
+ # Note: `train_split` parameter is used to specify the desired split name for evaluation data.
+ train_split=args.eval_split,
+ )
+ if len(eval_data) == 0:
+ strategy.print(f"Warning: Evaluation dataset at {eval_data_path} with split '{args.eval_split}' is empty. Skipping evaluation.")
+ else:
+ eval_data = eval_data.select(range(min(args.max_eval_samples, len(eval_data))))
+
+ eval_dataset = PromptDatasetVL(eval_data, tokenizer, processor, args.prompt_max_len, strategy, input_template=args.input_template)
+ # Cap eval DataLoader batch_size by local_hf_generate_max_batch_size to
+ # avoid the padding-leak bug. See heldout branch below for full rationale.
+ eval_dp_batch_size = args.rollout_batch_size // strategy.world_size
+ if args.engine_type == "hf":
+ mb_cap = int(getattr(args, "local_hf_generate_max_batch_size", 0) or 0)
+ if mb_cap > 0:
+ eval_dp_batch_size = min(eval_dp_batch_size, mb_cap)
+ eval_dataloader = strategy.setup_dataloader(
+ eval_dataset,
+ eval_dp_batch_size,
+ False,
+ False,
+ collate_fn=eval_dataset.collate_fn,
+ drop_last=False,
+ )
+ strategy.print(
+ f"Evaluation dataset loaded: {len(eval_dataset)} samples "
+ f"(eval DataLoader batch_size={eval_dp_batch_size})"
+ )
+ else:
+ strategy.print("Warning: eval_split specified but no data path available for evaluation.")
+ elif heldout_eval_data is not None:
+ eval_data = heldout_eval_data.select(range(min(args.max_eval_samples, len(heldout_eval_data))))
+ eval_dataset = PromptDatasetVL(
+ eval_data, tokenizer, processor, args.prompt_max_len, strategy, input_template=args.input_template
+ )
+ # Match DataLoader batch_size to local_hf_generate_max_batch_size for engine_type=hf.
+ # fast_exp_maker.process_multimodal_batch calls processor(padding=True) on the full
+ # DataLoader batch, then strategy_base chunks the already-padded tensor into
+ # micro-batches of local_hf_generate_max_batch_size. Without this alignment, each
+ # micro-batch keeps the max-of-DL-batch padded length (e.g. 16-wide pad in a 4-wide
+ # chunk), and the extra left-pad tokens — even with attention_mask masking — degrade
+ # URSA's greedy decode by ~8pp via RoPE / vision-path interaction. Setting DL-batch
+ # = micro-batch eliminates the leak so eval matches `tmp/ckpt_eval_aligned.py --bs N`
+ # exactly. See PR53 issuecomment-... for the 11.9pp breakdown.
+ eval_dp_batch_size = args.rollout_batch_size // strategy.world_size
+ if args.engine_type == "hf":
+ mb_cap = int(getattr(args, "local_hf_generate_max_batch_size", 0) or 0)
+ if mb_cap > 0:
+ eval_dp_batch_size = min(eval_dp_batch_size, mb_cap)
+ eval_dataloader = strategy.setup_dataloader(
+ eval_dataset,
+ eval_dp_batch_size,
+ False,
+ False,
+ collate_fn=eval_dataset.collate_fn,
+ drop_last=False,
+ )
+ strategy.print(
+ f"Held-out runtime evaluation dataset loaded: {len(eval_dataset)} samples "
+ f"(eval DataLoader batch_size={eval_dp_batch_size}, aligned with "
+ f"local_hf_generate_max_batch_size={getattr(args, 'local_hf_generate_max_batch_size', 'n/a')})"
+ )
+
+ # Prepare pretrain dataset
+ pretrain_dataloader = None
+ if args.pretrain_data:
+ strategy.print(f"Loading pretrain dataset from: {args.pretrain_data} with split: {args.pretrain_split}")
+ pretrain_data = blending_datasets(
+ args.pretrain_data, args.pretrain_data_probs, strategy, args.seed,
+ return_eval=False, train_split=args.pretrain_split,
+ )
+ if len(pretrain_data) == 0:
+ strategy.print(f"Warning: Pretrain dataset at {args.pretrain_data} is empty. PTX loss will not be applied.")
+ pretrain_dataloader = None
+ else:
+ pretrain_max_len = args.max_len if args.max_len else args.prompt_max_len + args.generate_max_len
+ # Calculate total samples needed for pretraining
+ total_pretrain_samples = args.max_epochs * len(prompts_dataset) * args.n_samples_per_prompt
+ pretrain_data_subset = pretrain_data.select(range(min(len(pretrain_data), total_pretrain_samples)))
+
+ pretrain_dataset = SFTDatasetVL(
+ pretrain_data_subset, tokenizer, pretrain_max_len, strategy, pretrain_mode=True,
+ )
+ strategy.print(f"Loaded {len(pretrain_dataset)} samples for pretraining.")
+ pretrain_dataloader = itertools.cycle(
+ iter(
+ strategy.setup_dataloader(
+ pretrain_dataset, args.micro_train_batch_size, True, True, pretrain_dataset.collate_fn,
+ )
+ )
+ )
+ else:
+ pretrain_dataloader = None
+
+ # Prepare prompts dataloader
+ prompts_dataloader = strategy.setup_dataloader(
+ prompts_dataset, args.rollout_batch_size // strategy.world_size, True, True, collate_fn=prompts_dataset.collate_fn
+ )
+
+ if args.pretrain_data:
+ pretrain_dataloader = itertools.cycle(
+ iter(
+ strategy.setup_dataloader(
+ pretrain_dataset,
+ args.micro_train_batch_size,
+ True,
+ True,
+ pretrain_dataset.collate_fn,
+ )
+ )
+ )
+ else:
+ pretrain_dataloader = None
+
+ # for scheduler
+ num_update_steps_per_episodes = (
+ len(prompts_dataset) * args.n_samples_per_prompt // args.train_batch_size * args.max_epochs
+ )
+ max_steps = math.ceil(args.num_episodes * num_update_steps_per_episodes)
+
+ # gradient_checkpointing
+ if args.gradient_checkpointing:
+ actor.gradient_checkpointing_enable(
+ gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant}
+ )
+ if critic is not None:
+ critic.gradient_checkpointing_enable(
+ gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant}
+ )
+
+ (
+ (actor, actor_optim, actor_scheduler),
+ (critic, critic_optim, critic_scheduler),
+ reward_models,
+ initial_model,
+ ) = strategy.prepare_models_and_optimizers(actor, critic, reward_models, initial_model, args, max_steps)
+
+ if rollout_actor is not None:
+ keep_rollout_on_gpu = bool(getattr(strategy.config, "hf_separate_rollout_keep_on_gpu", False))
+ rollout_actor = strategy.prepare_model(
+ rollout_actor,
+ is_training=False,
+ shard_size=-1,
+ reshard_after_forward=False,
+ )
+ rollout_actor.gradient_checkpointing_disable()
+ rollout_actor.eval()
+ if not keep_rollout_on_gpu:
+ strategy.offload_model(rollout_actor)
+ residency_note = "kept on GPU" if keep_rollout_on_gpu else "offloaded to CPU"
+ strategy.print(
+ "Prepared separate local HF rollout actor with FSDP full-shard, gc disabled, "
+ f"reshard_after_forward disabled, and {residency_note}."
+ )
+
+ # rollout_eos_patch is OPT-IN as of the eval-fix PR. With it OFF (default),
+ # rollout/eval generation falls back to HF default stopping (EosTokenCriteria +
+ # MaxLengthCriteria) which is token-equivalent to a bare model.generate call.
+ # See PR #53 issuecomment-4394197141 for why the patch was harmful by default.
+ if getattr(args, "enable_rollout_eos_patch", False):
+ from rollout_eos_patch import install_math_prm_rollout_eos_patch
+ install_math_prm_rollout_eos_patch(rollout_actor, tokenizer, tokenizer.eos_token_id)
+ strategy.print(
+ "Installed math_prm rollout EOS patch on rollout_actor.model.generate "
+ "(legacy --enable_rollout_eos_patch flag set; this BIASES rollout reward "
+ "and eval outcome — only enable to reproduce historical broken behavior)."
+ )
+ else:
+ strategy.print(
+ "rollout_eos_patch NOT installed (default). Generation uses HF default "
+ "stopping criteria (EosTokenCriteria + MaxLengthCriteria), token-equivalent "
+ "to bare model.generate. Use --enable_rollout_eos_patch to restore legacy."
+ )
+
+ strategy.print(reward_models)
+
+ if ema_model:
+ ema_model._offload = True
+ ema_model = strategy.prepare(ema_model, is_rlhf=True)
+
+ # load checkpoint
+ consumed_samples = 0
+ if args.load_checkpoint and os.path.exists(os.path.join(args.ckpt_path, "_actor")):
+ _, states = strategy.load_ckpt(actor.model, os.path.join(args.ckpt_path, "_actor"),
+ optimizer=actor_optim, scheduler=actor_scheduler)
+ if args.critic_pretrain:
+ strategy.load_ckpt(critic, os.path.join(args.ckpt_path, "_critic"))
+ consumed_samples = states["consumed_samples"]
+ strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}")
+
+ os.makedirs(args.save_path, exist_ok=True)
+ strategy.report_memory("after models init")
+
+ if is_ursa:
+ prepare_ursa_runtime_for_inference_engines(strategy)
+
+ strategy.report_memory("before setup_inference_engine")
+ strategy.setup_inference_engine(
+ args,
+ engine_type=args.engine_type,
+ actor=actor,
+ rollout_actor=rollout_actor,
+ tokenizer=tokenizer,
+ processor=processor,
+ )
+ strategy.report_memory("after setup_inference_engine")
+
+ # configure Trainer
+ trainer = MathPRMSPMDPPOTrainerVL(
+ strategy,
+ actor,
+ critic,
+ reward_models,
+ initial_model,
+ ema_model,
+ actor_optim,
+ critic_optim,
+ actor_scheduler,
+ critic_scheduler,
+ max_epochs=args.max_epochs,
+ micro_train_batch_size=args.micro_train_batch_size,
+ micro_rollout_batch_size=args.micro_rollout_batch_size,
+ gradient_checkpointing=args.gradient_checkpointing,
+ tokenizer=tokenizer,
+ processor=processor,
+ prompt_max_len=args.prompt_max_len,
+ value_clip=args.value_clip,
+ eps_clip=args.eps_clip,
+ loss_agg_mode=args.loss_agg_mode,
+ use_gspo=args.use_gspo,
+ normalize_advantages=args.normalize_advantages,
+ use_sequence_rewards=args.use_sequence_rewards,
+ gamma=args.gamma,
+ lambd=args.lambd,
+ init_kl_coef=args.init_kl_coef,
+ kl_target=args.kl_target,
+ ema_beta=0.992,
+ ptx_coef=args.ptx_coef,
+ max_norm=args.max_norm,
+ # for GPT generation
+ do_sample=True,
+ max_new_tokens=args.generate_max_len,
+ max_length=args.max_len,
+ temperature=args.temperature,
+ top_p=args.top_p,
+ top_k=args.top_k,
+ repetition_penalty=args.repetition_penalty,
+ no_repeat_ngram_size=args.no_repeat_ngram_size,
+ pad_token_id=tokenizer.pad_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ # reward model
+ reward_fn=reward_fn,
+ reward_fn_label_map=label_map,
+ reward_recipe=RECIPE,
+ reward_tokenizers=reward_tokenizers,
+ save_hf_ckpt=args.save_hf_ckpt,
+ disable_ds_ckpt=args.disable_ds_ckpt,
+ packing_samples=args.packing_samples,
+ # overlong_reward
+ dynamic_sampling=args.dynamic_sampling,
+ overlong_buffer=args.overlong_buffer,
+ overlong_buffer_len=args.overlong_buffer_len,
+ overlong_buffer_penalty_factor=args.overlong_buffer_penalty_factor,
+ print_replay_buffer_stats=args.print_replay_buffer_stats,
+ )
+
+ # ---- Optional initial evaluate-only at step 0 (no PPO update) ----
+ # Useful for diagnosing model state at step 0 vs step 1; e.g. to attribute the
+ # outcome gap between standalone bs=1 eval and the 8-rank FSDP bs=4 eval pipeline.
+ # Triggered by --initial_eval (default False, no-op).
+ if getattr(args, "initial_eval", False) and eval_dataloader is not None:
+ strategy.print(f"\n{'=' * 60}\n[initial_eval] Running evaluate at step 0 (NO PPO update)\n{'=' * 60}")
+ trainer.eval_dataloader = eval_dataloader # ensure trainer has handle
+ raw = trainer.evaluate(eval_dataloader, global_step=0)
+ if strategy.is_rank_0() and raw:
+ strategy.print(f"[initial_eval] step 0 outcome: {raw}")
+ if getattr(args, "initial_eval_only", False):
+ strategy.print("[initial_eval] --initial_eval_only set, exiting before training.")
+ return
+
+ trainer.fit(args, prompts_dataloader=prompts_dataloader, pretrain_dataloader=pretrain_dataloader, eval_dataloader=eval_dataloader, consumed_samples=0, num_update_steps_per_episodes=num_update_steps_per_episodes)
+
+ # save model checkpoint after fitting on only rank0
+ strategy.save_model(
+ ema_model if args.enable_ema else actor,
+ tokenizer,
+ args.save_path,
+ )
+
+ if args.critic_pretrain and args.save_value_network:
+ strategy.save_model(
+ critic,
+ tokenizer,
+ args.save_path + "_critic",
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--engine_type", type=str, default="hf", help="Choose inference engine type: vllm, sglang, hf")
+ parser.add_argument("--text_only", action="store_true", default=False)
+
+ # Checkpoint
+ parser.add_argument("--save_path", type=str, default="./ckpt")
+ parser.add_argument("--save_steps", type=int, default=-1)
+ parser.add_argument("--save_hf_ckpt", action="store_true", default=False)
+ parser.add_argument("--disable_ds_ckpt", action="store_true", default=False)
+ parser.add_argument("--save_trajectories", action="store_true", default=False, help="Save experience trajectories to JSON for debugging")
+ parser.add_argument(
+ "--trajectory_analysis",
+ action="store_true",
+ default=False,
+ help="Enable extra trajectory analysis metrics when saving trajectories",
+ )
+ parser.add_argument("--num_trajectories_to_save", type=int, default=10, help="Number of trajectories to save per checkpoint")
+ parser.add_argument("--print_replay_buffer_stats", action="store_true", default=False, help="Print detailed replay buffer statistics during training")
+ parser.add_argument("--enable_profile", action="store_true", default=False, help="Enable persistent step profiling with local files and W&B metrics")
+ parser.add_argument("--logging_steps", type=int, default=1)
+ parser.add_argument("--eval_steps", type=int, default=-1)
+ parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_ppo")
+ parser.add_argument("--max_ckpt_num", type=int, default=3)
+ parser.add_argument("--max_ckpt_mem", type=int, default=1e8)
+ parser.add_argument("--load_checkpoint", action="store_true", default=False)
+
+ # DAPO
+ parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy")
+ parser.add_argument("--overlong_buffer", action="store_true", default=False, help="Apply overlong sequence buffer in DAPO")
+ parser.add_argument("--overlong_buffer_len", type=int, default=1024, help="Max token threshold for overlong buffer")
+ parser.add_argument("--overlong_buffer_penalty_factor", type=float, default=1.0, help="Penalty scaling factor for overlong sequences, <1 discourages long outputs; >1 encourages them")
+
+ # PPO
+ parser.add_argument("--num_episodes", type=int, default=10)
+ parser.add_argument("--rollout_batch_size", type=int, default=128)
+ parser.add_argument("--micro_rollout_batch_size", type=int, default=4)
+ parser.add_argument("--max_epochs", type=int, default=1)
+ parser.add_argument("--prompt_max_len", type=int, default=1024, help="Max tokens for each prompt")
+ parser.add_argument("--generate_max_len", type=int, default=3072, help="Max tokens to generate in PPO")
+ parser.add_argument(
+ "--max_len",
+ type=int,
+ default=None,
+ help=(
+ "Optional explicit total max_len (prompt + generation) for the "
+ "PromptDataset/SFTDataset. Defaults to prompt_max_len + "
+ "generate_max_len when unset; see train_colocate.py:542 and :709."
+ ),
+ )
+ parser.add_argument("--max_samples", type=int, default=15360)
+ parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping")
+ parser.add_argument("--l2", type=float, default=0.0, help="weight decay loss")
+ parser.add_argument("--ptx_coef", type=float, default=0.05, help="PPO-ptx loss coef")
+ parser.add_argument("--eps_clip", type=float, default=0.2, help="PPO clip range")
+ parser.add_argument("--loss_agg_mode", type=str, default='seq-mean-token-mean',
+ help="Loss aggregation mode. Options: ['token-mean', 'seq-mean-token-sum', 'seq-mean-token-mean', 'seq-mean-token-sum-norm']")
+ parser.add_argument("--use_gspo", action="store_true", default=False, help="Enable GSPO (Group Sequence Policy Optimization) mode")
+ parser.add_argument("--normalize_advantages", action="store_true", default=True, help="Enable advantage normalization in GSPO")
+ parser.add_argument("--use_sequence_rewards", action="store_true", default=True, help="Use sequence-level rewards in GSPO")
+ parser.add_argument("--value_clip", type=float, default=0.2, help="PPO value clip range")
+ parser.add_argument("--lambd", type=float, default=0.95, help="PPO GAE lambd")
+ parser.add_argument("--gamma", type=float, default=1, help="PPO GAE gamma")
+ parser.add_argument("--micro_train_batch_size", type=int, default=4, help="batch size per GPU")
+ parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size")
+ parser.add_argument("--normalize_reward_for_critic", action="store_true", default=False, help="Enable Reward Normalization in critic model")
+ parser.add_argument("--top_p", type=float, default=1.0)
+ parser.add_argument("--top_k", type=int, default=-1)
+ parser.add_argument("--temperature", type=float, default=1.0)
+ parser.add_argument("--repetition_penalty", type=float, default=1.0)
+ parser.add_argument("--no_repeat_ngram_size", type=int, default=0)
+ parser.add_argument("--freeze_prefix", action="store_true", default=False, help="Freeze the prefix part (e.g. vision encoder) of the actor model")
+ parser.add_argument("--freezing_actor_steps", type=int, default=-1, help="Used for critic initialization")
+ parser.add_argument(
+ "--n_samples_per_prompt", type=int, default=8, help="number of responses for each prompt in generation"
+ )
+ parser.add_argument("--save_value_network", action="store_true", default=False, help="Save critic model")
+ parser.add_argument("--actor_learning_rate", type=float, default=1e-6)
+ parser.add_argument("--critic_learning_rate", type=float, default=9e-6)
+ parser.add_argument("--lr_warmup_ratio", type=float, default=0.03)
+ parser.add_argument("--kl_target", type=float, default=None)
+ parser.add_argument("--init_kl_coef", type=float, default=0.001, help="KL penalty in PPO")
+ parser.add_argument(
+ "--kl_estimator",
+ type=str,
+ default="k1",
+ choices=["k1", "k2", "k3"],
+ help=(
+ "In GRPO, k3 is utilized as the loss function, while k2, when used as the loss, is nearly equivalent to k1."
+ ),
+ )
+ parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer")
+
+ # Reward/Advantage Norm/Clip Arguments
+ parser.add_argument("--reward_running_norm", action="store_true", default=False, help="Enable running normalization for rewards.")
+ parser.add_argument("--reward_running_norm_minus_mean", action="store_true", default=False, help="When using reward normalization, subtract the mean; otherwise, only scale by the std.")
+ parser.add_argument("--reward_clip", type=float, default=0.0, help="Clip rewards to the range [-reward_clip, reward_clip]. 0.0 means no clipping.")
+ parser.add_argument("--advantages_norm", action="store_true", default=False, help="Enable whitening for advantages.")
+ parser.add_argument("--advantage_clip", type=float, default=0.0, help="Clip advantages to the range [-advantage_clip, advantage_clip]. 0.0 means no clipping.")
+
+ # DeepSpeed
+ parser.add_argument("--seed", type=int, default=42)
+ parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed")
+ parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage")
+ parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
+ parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16")
+ parser.add_argument("--enable_ema", action="store_true", help="Enable EMA checkpoint for the model.")
+ parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size")
+ parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer")
+ parser.add_argument("--actor_init_on_gpu", action="store_true", default=False)
+ parser.add_argument("--flash_attn", action="store_true", default=False, help="Enable FlashAttention2")
+ parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss")
+ parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type")
+ parser.add_argument("--overlap_comm", action="store_true", default=False)
+ parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False)
+ parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False)
+ parser.add_argument("--disable_logprobs_flashattn", action="store_true", default=False, help="Disable flash attn implementation in log_probs calculation")
+
+ # FSDP
+ parser.add_argument("--no_shard_vit", action="store_true", default=False, help="Disable sharding for vision transformer")
+ parser.add_argument("--meta_init", action="store_true", default=False, help="Initialize models on meta device to save CPU memory")
+
+ # Reinforce
+ parser.add_argument(
+ "--advantage_estimator",
+ type=str,
+ choices=["gae", "reinforce", "rloo", "reinforce_baseline", "group_norm", "cpgd", "reinforce++",
+ "ursa_variant2"],
+ default="gae",
+ help=(
+ "Choose advantage estimation method: gae, reinforce, rloo, reinforce_baseline, group_norm, "
+ "reinforce++. 'ursa_variant2' (URSA paper Eq.9 strict alignment) is provided by "
+ "examples/math_prm/ursa_variant2.py and only meaningful with label='math_per_step_prm' + "
+ "n_samples_per_prompt >= 2."
+ ),
+ )
+
+ parser.add_argument("--use_kl_loss", action="store_true", default=False, help="whether to use KL loss from GRPO")
+
+ parser.add_argument(
+ "--per_step_reward_mode",
+ type=str,
+ choices=["raw", "group_norm"],
+ default="group_norm",
+ help=(
+ "How to integrate per-step PRM rewards (Math-Shepherd-style "
+ "per-token reward path, distinct from the strict paper Eq.9 path "
+ "selected via --advantage_estimator ursa_variant2). "
+ "'group_norm' (default): for each step k, subtract group mean and "
+ "divide by group std across the K trajectories in the same prompt "
+ "group BEFORE scattering to step-boundary tokens. Produces "
+ "zero-mean signed advantages (GRPO baseline convention). "
+ "'raw': scatter raw sigmoid step_score directly. WARNING — raw "
+ "is unsafe: sigmoid scores are always positive, so every "
+ "post-cumsum token advantage is non-negative and PG pushes "
+ "every probability up. Kept only for paper Figure ablation. "
+ "Only active when label is 'math_per_step_prm' AND "
+ "--advantage_estimator is the cumsum path (group_norm/grpo). "
+ "For the strict paper Eq.9 path use --advantage_estimator "
+ "ursa_variant2 (handles its own group normalization)."
+ ),
+ )
+
+ # LoRA
+ parser.add_argument("--load_in_4bit", action="store_true", default=False)
+ parser.add_argument("--lora_rank", type=int, default=0)
+ parser.add_argument("--lora_alpha", type=int, default=16)
+ parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear")
+ parser.add_argument("--lora_dropout", type=float, default=0)
+
+ # Models
+ parser.add_argument("--pretrain", type=str, default=None, help="HF model name or path")
+ parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path")
+ parser.add_argument("--remote_rm_url", type=str, default=None, help="remote RM API")
+ parser.add_argument("--critic_pretrain", type=str, default=None, help="HF model name or path")
+ parser.add_argument("--value_head_prefix", type=str, default="score")
+
+ # Custom dataset
+ parser.add_argument("--prompt_data", type=str, default=None, help="HF dataset name or path")
+ parser.add_argument(
+ "--prompt_data_probs",
+ type=str,
+ default="1.0",
+ help="sampling probs for datasets",
+ )
+ parser.add_argument("--prompt_split", type=str, default="train")
+
+ # Evaluation dataset
+ parser.add_argument("--eval_data", type=str, default=None, help="HF evaluation dataset name or path (default: use prompt_data)")
+ parser.add_argument("--eval_split", type=str, default="", help="Evaluation data split (default: disabled)")
+ parser.add_argument("--max_eval_samples", type=int, default=500, help="Maximum number of samples to evaluate (default: 500)")
+ parser.add_argument(
+ "--eval_holdout_size",
+ type=int,
+ default=500,
+ help="Deterministic held-out eval subset size sampled from prompt_data when eval_data is unset (default: 500)",
+ )
+ parser.add_argument(
+ "--eval_holdout_seed",
+ type=int,
+ default=42,
+ help="Seed for deterministic held-out runtime eval split (default: 42)",
+ )
+ parser.add_argument(
+ "--eval_n_samples_per_prompt",
+ type=int,
+ default=1,
+ help="Number of eval generations per prompt (default: 1)",
+ )
+ parser.add_argument(
+ "--eval_do_sample",
+ action="store_true",
+ default=False,
+ help="Use sampling during runtime eval instead of greedy decoding",
+ )
+ parser.add_argument(
+ "--eval_generate_max_len",
+ type=int,
+ default=None,
+ help="Maximum generation length for runtime eval (default: use generate_max_len)",
+ )
+ parser.add_argument("--eval_temperature", type=float, default=0.0, help="Eval temperature (default: 0.0)")
+ parser.add_argument("--eval_top_p", type=float, default=1.0, help="Eval top-p (default: 1.0)")
+ parser.add_argument("--eval_top_k", type=int, default=-1, help="Eval top-k (default: -1)")
+ parser.add_argument(
+ "--eval_repetition_penalty",
+ type=float,
+ default=1.0,
+ help="Eval repetition penalty (default: 1.0)",
+ )
+ parser.add_argument(
+ "--eval_no_repeat_ngram_size",
+ type=int,
+ default=0,
+ help="Eval no-repeat-ngram size (default: 0)",
+ )
+
+ parser.add_argument("--pretrain_data", type=str, default=None, help="HF dataset name or path")
+ parser.add_argument(
+ "--pretrain_data_probs",
+ type=str,
+ default="1.0",
+ help="sampling probs for datasets",
+ )
+ parser.add_argument("--pretrain_split", type=str, default="train")
+ parser.add_argument("--input_key", type=str, default="input", help="JSON dataset key")
+ parser.add_argument("--images_key", type=str, default="images", help="JSON dataset key for images")
+ parser.add_argument("--reference_key", type=str, default="reference", help="JSON dataset key for reference answers")
+ parser.add_argument("--label_key", type=str, default="label", help="JSON dataset key")
+ parser.add_argument("--input_template", type=str, default=None)
+ parser.add_argument(
+ "--apply_chat_template", action="store_true", default=False, help="Use HF tokenizer chat template"
+ )
+
+ parser.add_argument("--system_prompt", type=str, default=None, help="HF System Prompt")
+
+
+ # wandb parameters
+ parser.add_argument("--use_wandb", type=str, default=None)
+ parser.add_argument("--wandb_org", type=str, default=None)
+ parser.add_argument("--wandb_group", type=str, default=None)
+ parser.add_argument("--wandb_project", type=str, default="lightrft_train_ppo")
+ parser.add_argument(
+ "--wandb_run_name",
+ type=str,
+ default="ppo_%s" % datetime.now().strftime("%m%dT%H:%M"),
+ )
+
+ # TensorBoard parameters
+ parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path")
+
+ # ModelScope parameters
+ parser.add_argument("--use_ms", action="store_true", default=False)
+
+ # MultiModal
+ parser.add_argument("--limit_mm_image_per_prompt", type=int, default=-1, help="the max image number of each text in multi model for inference backend")
+
+ # CPGD
+ parser.add_argument("--use_cpg_loss", action="store_true", default=False, help="whether to use the clipped policy gradient loss from CPGD")
+
+ # initial-eval (eval at step 0, before any PPO update)
+ parser.add_argument(
+ "--initial_eval", action="store_true", default=False,
+ help="Run evaluate(global_step=0) before fit(). Useful for measuring base "
+ "model outcome under the actual training eval pipeline (8-rank FSDP + "
+ "bs=4 etc.) without any PPO drift.",
+ )
+ parser.add_argument(
+ "--initial_eval_only", action="store_true", default=False,
+ help="With --initial_eval, exit immediately after initial eval (skip training).",
+ )
+
+ # math_prm rollout EOS patch
+ parser.add_argument(
+ "--enable_rollout_eos_patch", action="store_true", default=False,
+ help=(
+ "Install StructuredAnswerStoppingCriteria on rollout_actor.model.generate "
+ "(legacy behavior). DEFAULT OFF. The patch makes generation stop right after "
+ "'†Answer:' marker, but historical experiments (PR #53 issuecomment-4394197141) "
+ "showed it (a) lowers eval outcome by ~9.8pp due to truncated tokens, and "
+ "(b) biases rollout reward signal towards short responses (Goodhart's law) "
+ "causing length collapse during RL. With patch off, generation falls back to "
+ "HF default stopping (EosTokenCriteria + MaxLengthCriteria), which is what we "
+ "want for both rollout reward fidelity and eval accuracy alignment."
+ ),
+ )
+
+ add_arguments(parser)
+
+ args = parser.parse_args()
+
+
+ if args.advantage_estimator not in ["gae"]:
+ args.critic_pretrain = None
+ elif args.critic_pretrain is None:
+ args.critic_pretrain = args.pretrain
+
+ if args.advantage_estimator in ["rloo", "reinforce_baseline", "group_norm"]:
+ assert args.n_samples_per_prompt > 1, f"{args.advantage_estimator} requires n_samples_per_prompt > 1"
+
+ if args.use_kl_loss:
+ if args.kl_estimator not in ["k2", "k3"]:
+ print(f"Recommend setting {args.kl_estimator} to 'k2' or 'k3' when using KL as a loss")
+ else:
+ if args.kl_estimator not in ["k1"]:
+ print(f"Recommend setting {args.kl_estimator} to 'k1' when not using KL as a loss.")
+
+ if args.advantage_estimator in ["gae", "cpgd"] and args.use_kl_loss:
+ warnings.warn(
+ "Using use_kl_loss=True with non-normalized advantage estimator "
+ "may result in double KL penalty. Consider disabling --use_kl_loss "
+ "or using --advantage_estimator group_norm"
+ )
+
+ if args.input_template and "{}" not in args.input_template:
+ print("[Warning] {} not in args.input_template, set to None")
+ args.input_template = None
+
+ if args.input_template and "\\n" in args.input_template:
+ print(
+ "[Warning] input_template contains \\n chracters instead of newline. "
+ "You likely want to pass $'\\n' in Bash or \"`n\" in PowerShell."
+ )
+
+ if args.use_ms:
+ from modelscope.utils.hf_util import patch_hub
+
+ # Patch hub to download models from modelscope to speed up.
+ patch_hub()
+
+ train(args)
diff --git a/examples/math_prm/ursa_actor.py b/examples/math_prm/ursa_actor.py
new file mode 100644
index 00000000..7e928365
--- /dev/null
+++ b/examples/math_prm/ursa_actor.py
@@ -0,0 +1,390 @@
+"""
+URSA-8B Actor Model Loader
+
+This module provides a custom actor loader for URSA-8B models, which use
+UrsaForConditionalGeneration instead of the standard AutoModelForVision2Seq.
+
+URSA-8B architecture:
+- Hybrid vision tower: SAM-B (1024x1024) + SigLIP-L (384x384)
+- MLP projector: Maps vision features to LLM embedding space
+- Language model: Qwen2.5-Math-Instruct (8B params)
+"""
+
+import sys
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Union
+from transformers.integrations.deepspeed import HfDeepSpeedConfig
+
+# Add current directory to path for ursa_model imports
+current_dir = os.path.dirname(os.path.abspath(__file__))
+if current_dir not in sys.path:
+ sys.path.insert(0, current_dir)
+
+from ursa_model import UrsaForConditionalGeneration
+from lightrft.models.actor_vl import ActorVL
+from lightrft.models.utils import apply_lora_configuration, reset_position_ids, entropy_from_logits
+
+
+class UrsaActor(ActorVL):
+ """
+ Actor wrapper for URSA-8B models.
+
+ This class extends ActorVL to support loading URSA-8B models using
+ UrsaForConditionalGeneration instead of AutoModelForVision2Seq.
+
+ Usage:
+ actor = UrsaActor(
+ pretrain_or_model="/path/to/URSA-8B",
+ use_flash_attention_2=True,
+ bf16=True,
+ lora_rank=0,
+ )
+ """
+
+ def __init__(
+ self,
+ pretrain_or_model,
+ use_flash_attention_2=False,
+ bf16=True,
+ lora_rank=0,
+ lora_alpha=16,
+ lora_dropout=0,
+ target_modules=None,
+ ds_config=None,
+ device_map=None,
+ packing_samples=False,
+ high_entropy_token_ratio=0.0,
+ **kwargs,
+ ) -> None:
+ """
+ Initialize URSA-8B actor model.
+
+ Args:
+ pretrain_or_model: Path to URSA-8B checkpoint or model instance
+ use_flash_attention_2: Enable Flash Attention 2.0
+ bf16: Use bfloat16 precision
+ lora_rank: LoRA rank (0 disables LoRA)
+ lora_alpha: LoRA alpha scaling parameter
+ lora_dropout: LoRA dropout rate
+ target_modules: Target modules for LoRA (auto-detected if None)
+ ds_config: DeepSpeed configuration
+ device_map: Device mapping for model placement
+ packing_samples: Enable sample packing
+ high_entropy_token_ratio: High entropy token filtering ratio
+ """
+ # Initialize parent class without calling its __init__
+ # We'll handle model loading ourselves
+ nn.Module.__init__(self)
+ self.high_entropy_token_ratio = high_entropy_token_ratio
+
+ if isinstance(pretrain_or_model, str):
+ self.pretrain_or_model = pretrain_or_model
+ attn_implementation = "flash_attention_2" if use_flash_attention_2 else "eager"
+
+ # DeepSpeed ZeRO-3 integration
+ if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
+ dschf = HfDeepSpeedConfig(ds_config)
+ else:
+ dschf = None # noqa: F841
+
+ # Prepare loading kwargs
+ from_pretrained_kwargs = {
+ "trust_remote_code": True,
+ "attn_implementation": attn_implementation,
+ "torch_dtype": torch.bfloat16 if bf16 else "auto",
+ }
+
+ # Check if we're in meta device context (FSDP)
+ try:
+ test_tensor = torch.empty(1)
+ is_meta_context = test_tensor.is_meta
+ except: # noqa
+ is_meta_context = False
+
+ if not is_meta_context and device_map is not None:
+ from_pretrained_kwargs["device_map"] = device_map
+
+ print(f"[UrsaActor] Loading URSA-8B model from {pretrain_or_model}")
+
+ # Load URSA model using UrsaForConditionalGeneration
+ self.model = UrsaForConditionalGeneration.from_pretrained(
+ pretrain_or_model,
+ **from_pretrained_kwargs
+ )
+
+ print(f"[UrsaActor] Successfully loaded URSA-8B model")
+
+ # Apply LoRA if requested
+ if lora_rank > 0:
+ print(f"[UrsaActor] Applying LoRA with rank={lora_rank}, alpha={lora_alpha}")
+ self.model = apply_lora_configuration(
+ model=self.model,
+ lora_rank=lora_rank,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ target_modules=target_modules,
+ freeze_vision_tower=True,
+ )
+
+ # Disable cache for training
+ self.model.config.use_cache = False
+
+ # Enable sample packing if requested
+ self.packing_samples = packing_samples
+ else:
+ # Model instance provided directly
+ self.model = pretrain_or_model
+ self.pretrain_or_model = "ursa"
+
+ print(f"[UrsaActor] Model type: {self.pretrain_or_model}")
+
+ def forward(
+ self,
+ sequences: torch.LongTensor,
+ num_actions=None,
+ attention_mask: Optional[torch.Tensor] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.Tensor] = None,
+ video_grid_thw: Optional[torch.Tensor] = None,
+ return_output: bool = False,
+ packed_seq_lens: Optional[list] = None,
+ ) -> torch.Tensor:
+ """
+ VLM-aligned forward.
+
+ URSA's vision tower expands every <|image|> placeholder into 576 vision
+ tokens during the LM forward, so ``output["logits"]`` is longer than the
+ input ``sequences`` along the seq dim. The default ``ActorVL.forward``
+ feeds ``output["logits"][:, :-1, :]`` (length E-1) and ``sequences[:, 1:]``
+ (length T-1) into ``log_probs_from_logits``, which then hits PyTorch's
+ ``gather(dim=-1, index=...)`` — that op silently TRUNCATES the rows of
+ ``logits`` to ``len(labels)`` instead of erroring. The result: log-probs
+ are read from the wrong (vision-token / early-prompt) positions, never
+ from the actual generation positions. KL/PPO/ratio all become noise.
+
+ We sidestep the bug entirely by slicing the logits to the action range
+ on the seq dim first (where alignment is unambiguous because generation
+ always lives at the tail of the expanded sequence), then using a single
+ ``F.log_softmax + gather`` over the action labels. fp32 throughout so
+ the precision matches the rest of the PPO loss path.
+ """
+ if self.packing_samples:
+ position_ids = reset_position_ids(attention_mask)
+ attention_mask_for_model = None
+ else:
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ attention_mask_for_model = attention_mask
+
+ # Sanitize actor-leaked image tokens before forward. During GRPO rollout
+ # an actor can generate literal `<|image|>` / `` strings inside
+ # its response (rare but observed; freq goes up when KL spikes hard
+ # late in training). The tokenizer then maps those to image_token_index,
+ # so the sequence ends up with MORE image-token slots than the prompt
+ # actually had images. URSA's vision merge requires
+ # image_token_count == n_image_features and aborts with
+ # "The input provided to the model are wrong.
+ # The number of image tokens is N while the number of image given is M.
+ # This prevents correct indexing and breaks batch generation."
+ # which crashes the whole PPO step. cce5ae5 already fixed this on the
+ # PRM forward path; the actor forward needs the same protection because
+ # the same actor-generated sequences are replayed here every PPO inner
+ # epoch. Align both directions:
+ # * token_count > image_count : extras are leaked, replace with pad.
+ # * token_count < image_count : truncate pixel_values/image_grid_thw.
+ sequences, pixel_values, image_grid_thw = self._align_image_tokens_to_images(
+ sequences, pixel_values, image_grid_thw
+ )
+
+ forward_kwargs = dict(
+ attention_mask=attention_mask_for_model,
+ position_ids=position_ids,
+ pixel_values=self._cast_multimodal_tensor(pixel_values),
+ image_grid_thw=image_grid_thw,
+ pixel_values_videos=self._cast_multimodal_tensor(pixel_values_videos),
+ video_grid_thw=video_grid_thw,
+ use_cache=False,
+ )
+ for k in ("pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "use_cache"):
+ if not self._supports_model_kwarg(k):
+ forward_kwargs.pop(k, None)
+
+ output = self.model(sequences, **forward_kwargs)
+
+ if num_actions is None:
+ assert return_output
+ return output
+
+ logits = output["logits"]
+ seq_T = sequences.size(1)
+ logit_T = logits.size(1)
+ if self.packing_samples:
+ raise NotImplementedError(
+ "UrsaActor.forward does not yet support packed_seq_lens. The "
+ "default ActorVL packing path is silently miscomputed for VLMs "
+ "that expand image placeholders; we don't want to bake the same "
+ "bug in here. Add explicit packed-aware alignment when needed."
+ )
+
+ # Generation tokens always sit at the tail of the expanded sequence,
+ # so logits at expanded positions [E - num_actions - 1 .. E - 2]
+ # predict tokens at expanded positions [E - num_actions .. E - 1] —
+ # which are the same generation tokens as ``sequences[:, -num_actions:]``
+ # in the unexpanded view (the unexpanded vs expanded offset only affects
+ # positions BEFORE the image placeholders, all in the prompt).
+ action_logits = logits[:, -(num_actions + 1):-1, :]
+ action_labels = sequences[:, -num_actions:]
+ if action_logits.size(1) != action_labels.size(1):
+ raise RuntimeError(
+ f"action_logits seq len {action_logits.size(1)} does not match "
+ f"action_labels seq len {action_labels.size(1)} "
+ f"(num_actions={num_actions}, seq_T={seq_T}, logit_T={logit_T})"
+ )
+
+ action_logp_full = F.log_softmax(action_logits.float(), dim=-1)
+ action_log_probs = action_logp_full.gather(
+ -1, action_labels.unsqueeze(-1)
+ ).squeeze(-1)
+
+ if self.high_entropy_token_ratio > 0.0:
+ # Entropy of the action-position distribution, in the same fp32 used above.
+ probs = action_logp_full.exp()
+ action_entropy = -(probs * action_logp_full).sum(dim=-1)
+ else:
+ action_entropy = None
+
+ if return_output:
+ if action_entropy is not None:
+ output_dict = dict(output)
+ output_dict["action_entropy"] = action_entropy
+ return (action_log_probs, output_dict)
+ return (action_log_probs, output)
+ return action_log_probs
+
+ def _align_image_tokens_to_images(self, sequences, pixel_values, image_grid_thw):
+ """Make ``sequences``'s image-token count match the actual image count.
+
+ URSA's vision merge crashes with::
+
+ ValueError: The input provided to the model are wrong.
+ The number of image tokens is N while the number of image given is M.
+ This prevents correct indexing and breaks batch generation.
+
+ whenever the count of ``image_token_index`` markers inside the input
+ sequence is unequal to the number of image features supplied via
+ ``pixel_values`` (one row of ``image_grid_thw`` per image). During GRPO
+ rollout this can happen because the actor occasionally generates a
+ literal ``<|image|>`` / ```` token inside its response — usually
+ rare, but observed to fire when KL drifts very high mid-training and
+ the actor's output distribution gets unstable.
+
+ Strategy (no-op fast path when counts already agree):
+
+ * count_tok > count_img : leaked extras — replace the trailing
+ (count_tok - count_img) image-token slots in each row with a benign
+ text token (pad / eos). Keep the first count_img in their original
+ positions so the vision features still merge into the prompt.
+
+ * count_tok < count_img : sequence is missing some image-token slots
+ relative to the image features. Truncate ``image_grid_thw`` and the
+ corresponding rows of ``pixel_values`` to the first count_tok
+ images. (Information loss is unavoidable, but the PPO step still
+ makes progress on the other tokens.)
+
+ Returns the (possibly sanitized) ``(sequences, pixel_values, image_grid_thw)``.
+ Original tensors are returned unchanged when no mismatch exists.
+ """
+ if sequences is None:
+ return sequences, pixel_values, image_grid_thw
+ image_token_id = getattr(self.model.config, "image_token_index", None)
+ if image_token_id is None:
+ return sequences, pixel_values, image_grid_thw
+ if pixel_values is None and image_grid_thw is None:
+ # Pure text micro-batch — nothing to align; also strip any leaked
+ # image tokens so the LM head doesn't see them as content.
+ n_tok = int((sequences == image_token_id).sum().item())
+ if n_tok == 0:
+ return sequences, pixel_values, image_grid_thw
+
+ # Per-row image-token positions (flat indices).
+ # We sanitize in-place on a clone to avoid mutating shared rollout buffers.
+ seq = sequences.clone()
+ flat = seq.view(-1)
+ tok_positions = torch.nonzero(flat == image_token_id, as_tuple=False).squeeze(-1)
+ n_tok = int(tok_positions.numel())
+
+ # Number of images supplied. image_grid_thw has one row per image and
+ # is the more reliable source than pixel_values (which may be packed).
+ if image_grid_thw is not None:
+ n_img = int(image_grid_thw.size(0))
+ elif pixel_values is not None:
+ # Fallback: assume one image per row of pixel_values.
+ n_img = int(pixel_values.size(0)) if pixel_values.dim() >= 1 else 0
+ else:
+ n_img = 0
+
+ if n_tok == n_img:
+ return sequences, pixel_values, image_grid_thw
+
+ replacement = None
+ tokenizer = getattr(self, "tokenizer", None)
+ if tokenizer is not None:
+ replacement = tokenizer.pad_token_id
+ if replacement is None:
+ replacement = tokenizer.eos_token_id
+ if replacement is None:
+ # Last-resort: pick a known safe id (eos is usually safe across HF tokenizers).
+ replacement = int(getattr(self.model.config, "eos_token_id", 0) or 0)
+
+ if n_tok > n_img:
+ # Leaked extras — replace tail extras with pad/eos so token_count == n_img.
+ extras = tok_positions[n_img:]
+ flat[extras] = replacement
+ seq = flat.view_as(sequences)
+ return seq, pixel_values, image_grid_thw
+
+ # n_tok < n_img: truncate image features to match token slots.
+ new_grid = image_grid_thw[:n_tok] if image_grid_thw is not None else None
+ new_pixel = pixel_values
+ if pixel_values is not None and image_grid_thw is not None and n_tok > 0:
+ # pixel_values is the concat of per-image patches; the per-image
+ # row counts come from image_grid_thw[i, 0] * thw[i, 1] * thw[i, 2].
+ # Keep the first n_tok image's patches.
+ patch_counts = (image_grid_thw[:, 0] * image_grid_thw[:, 1] * image_grid_thw[:, 2]).long()
+ keep = int(patch_counts[:n_tok].sum().item())
+ new_pixel = pixel_values[:keep]
+ elif pixel_values is not None and n_tok == 0:
+ new_pixel = None
+ new_grid = None
+ return sequences, new_pixel, new_grid
+
+
+def create_ursa_actor(args, ds_config=None):
+ """
+ Factory function to create URSA-8B actor from training args.
+
+ Args:
+ args: Training arguments (argparse.Namespace)
+ ds_config: DeepSpeed configuration dict
+
+ Returns:
+ UrsaActor instance
+ """
+ return UrsaActor(
+ args.pretrain,
+ use_flash_attention_2=args.flash_attn,
+ bf16=args.bf16,
+ load_in_4bit=getattr(args, 'load_in_4bit', False),
+ lora_rank=args.lora_rank,
+ lora_alpha=args.lora_alpha,
+ target_modules=getattr(args, 'target_modules', None),
+ lora_dropout=args.lora_dropout,
+ ds_config=ds_config,
+ packing_samples=args.packing_samples,
+ disable_logprobs_flashattn=getattr(args, 'disable_logprobs_flashattn', False),
+ fused_linear_logprob=getattr(args, 'fused_linear_logprob', False),
+ )
diff --git a/examples/math_prm/ursa_model/__init__.py b/examples/math_prm/ursa_model/__init__.py
new file mode 100644
index 00000000..bdbaeec2
--- /dev/null
+++ b/examples/math_prm/ursa_model/__init__.py
@@ -0,0 +1,30 @@
+# Copyright (2025) Bytedance Ltd. and/or its affiliates
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .image_processing_vlm import VLMImageProcessor, VLMImageProcessorConfig
+from .modeling_ursa import UrsaForConditionalGeneration, UrsaForTokenClassification
+from .processing_ursa import UrsaProcessor
+from .configuration_ursa import VisionConfig, UrsaConfig, AlignerConfig
+from .projector import MlpProjector
+
+__all__ = [
+ "VLMImageProcessor",
+ "UrsaProcessor",
+ "UrsaForConditionalGeneration",
+ "UrsaForTokenClassification",
+ "VLMImageProcessorConfig",
+ "VisionConfig",
+ "MlpProjector",
+ "AlignerConfig",
+ "UrsaConfig"
+]
\ No newline at end of file
diff --git a/examples/math_prm/ursa_model/attrdict_compat.py b/examples/math_prm/ursa_model/attrdict_compat.py
new file mode 100644
index 00000000..74c3aa3b
--- /dev/null
+++ b/examples/math_prm/ursa_model/attrdict_compat.py
@@ -0,0 +1,23 @@
+try:
+ from attrdict import AttrDict # type: ignore
+except ImportError:
+ try:
+ from easydict import EasyDict as AttrDict # type: ignore
+ except ImportError:
+ class AttrDict(dict):
+ """Minimal AttrDict fallback for URSA config objects."""
+
+ def __getattr__(self, item):
+ try:
+ return self[item]
+ except KeyError as exc:
+ raise AttributeError(item) from exc
+
+ def __setattr__(self, key, value):
+ self[key] = value
+
+ def __delattr__(self, item):
+ try:
+ del self[item]
+ except KeyError as exc:
+ raise AttributeError(item) from exc
diff --git a/examples/math_prm/ursa_model/clip_encoder.py b/examples/math_prm/ursa_model/clip_encoder.py
new file mode 100644
index 00000000..8695701a
--- /dev/null
+++ b/examples/math_prm/ursa_model/clip_encoder.py
@@ -0,0 +1,242 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+from typing import Dict, List, Literal, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torchvision.transforms
+from einops import rearrange
+
+from .sam import create_sam_vit
+from .siglip_vit import create_siglip_vit
+
+
+class CLIPVisionTower(nn.Module):
+ def __init__(
+ self,
+ model_name: str = "siglip_large_patch16_384",
+ image_size: Union[Tuple[int, int], int] = 336,
+ select_feature: str = "patch",
+ select_layer: int = -2,
+ select_layers: list = None,
+ ckpt_path: str = "",
+ pixel_mean: Optional[List[float]] = None,
+ pixel_std: Optional[List[float]] = None,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.model_name = model_name
+ self.select_feature = select_feature
+ self.select_layer = select_layer
+ self.select_layers = select_layers
+
+ vision_tower_params = {
+ "model_name": model_name,
+ "image_size": image_size,
+ "ckpt_path": ckpt_path,
+ "select_layer": select_layer,
+ }
+ vision_tower_params.update(kwargs)
+ self.vision_tower, self.forward_kwargs = self.build_vision_tower(
+ vision_tower_params
+ )
+
+ if pixel_mean is not None and pixel_std is not None:
+ image_norm = torchvision.transforms.Normalize(
+ mean=pixel_mean, std=pixel_std
+ )
+ else:
+ image_norm = None
+
+ self.image_norm = image_norm
+
+ def build_vision_tower(self, vision_tower_params):
+ if self.model_name.startswith("siglip"):
+ self.select_feature = "same"
+ vision_tower = create_siglip_vit(**vision_tower_params)
+ forward_kwargs = dict()
+
+ elif self.model_name.startswith("sam"):
+ vision_tower = create_sam_vit(**vision_tower_params)
+ forward_kwargs = dict()
+
+ else: # huggingface
+ from transformers import CLIPVisionModel
+
+ vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
+ forward_kwargs = dict(output_hidden_states=True)
+
+ return vision_tower, forward_kwargs
+
+ def feature_select(self, image_forward_outs):
+ if isinstance(image_forward_outs, torch.Tensor):
+ # the output has been the self.select_layer"s features
+ image_features = image_forward_outs
+ else:
+ image_features = image_forward_outs.hidden_states[self.select_layer]
+
+ if self.select_feature == "patch":
+ # if the output has cls_token
+ image_features = image_features[:, 1:]
+ elif self.select_feature == "cls_patch":
+ image_features = image_features
+ elif self.select_feature == "same":
+ image_features = image_features
+
+ else:
+ raise ValueError(f"Unexpected select feature: {self.select_feature}")
+ return image_features
+
+ def forward(self, images):
+ """
+
+ Args:
+ images (torch.Tensor): [b, 3, H, W]
+
+ Returns:
+ image_features (torch.Tensor): [b, n_patch, d]
+ """
+
+ if self.image_norm is not None:
+ images = self.image_norm(images)
+
+ image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
+ image_features = self.feature_select(image_forward_outs)
+ return image_features
+
+
+class HybridVisionTower(nn.Module):
+ def __init__(
+ self,
+ high_res_cfg: Dict,
+ low_res_cfg: Dict,
+ freeze_high: bool = False,
+ freeze_low: bool = False,
+ concat_type: Literal["feature", "sequence", "add", "tuple"] = "tuple",
+ **ignore_kwargs,
+ ):
+ super().__init__()
+
+ self.vision_tower_high = CLIPVisionTower(**high_res_cfg)
+ self.vision_tower_low = CLIPVisionTower(**low_res_cfg)
+ self.low_res_size = low_res_cfg["image_size"]
+ self.concat_type = concat_type
+
+ self.high_layer_norm = nn.LayerNorm(high_res_cfg.get("output_dim", 1024))
+ self.low_layer_norm = nn.LayerNorm(low_res_cfg.get("output_dim", 1024))
+
+ if freeze_high:
+ for p_name, p in self.vision_tower_high.named_parameters():
+ p.requires_grad = False
+ self.vision_tower_high = self.vision_tower_high.eval()
+ else:
+ # train donwsamples and neck
+ for p_name, p in self.vision_tower_high.named_parameters():
+ if "downsamples" in p_name or "neck" in p_name:
+ p.requires_grad = True
+ else:
+ p.requires_grad = False
+
+ if freeze_low:
+ for p in self.vision_tower_low.parameters():
+ p.requires_grad = False
+ self.vision_tower_low = self.vision_tower_low.eval()
+
+ self.resize = torchvision.transforms.Resize(self.low_res_size, antialias=True)
+
+ def forward(self, images: torch.Tensor):
+ """
+
+ Args:
+ images (torch.Tensor): [bs, 3, H, W]
+
+ Returns:
+ res (torch.Tensor): [bs, t, c]
+ """
+
+ # [bs, c, h, w]
+ high_images = images
+
+ # [bs, c, h_low, w_low]
+ low_images = self.resize(images)
+
+ # separately run two vision towers
+ # run high_res vision tower
+ high_res = self.vision_tower_high(high_images)
+ # [bs, c, h, w] -> [bs, h*w, c]
+ high_res = rearrange(high_res, "b c h w -> b (h w) c")
+ # run low_res vision tower
+ low_res = self.vision_tower_low(low_images)
+
+ if self.concat_type == "feature":
+ images_features = torch.cat([high_res, low_res], dim=-1)
+ elif self.concat_type == "sequence":
+ images_features = torch.cat([high_res, low_res], dim=1)
+ elif self.concat_type == "add":
+ images_features = high_res + low_res
+ elif self.concat_type == "tuple":
+ images_features = (high_res, low_res)
+
+ else:
+ raise ValueError(
+ "Currently only support `feature`, `sequence`, `add` and `tuple` concat type."
+ )
+
+ return images_features
+
+
+if __name__ == "__main__":
+ image_size = 1024
+ x = torch.zeros(2, 3, image_size, image_size).bfloat16().cuda()
+
+ high_res_cfg = dict(
+ model_name="sam_b_downsample",
+ select_feature="same",
+ image_size=image_size,
+ pixel_mean=(0.48145466, 0.4578275, 0.40821073),
+ pixel_std=(0.26862954, 0.26130258, 0.27577711),
+ select_layer=-1,
+ ckpt_path="",
+ )
+
+ low_res_cfg = dict(
+ model_name="siglip_large_patch16_384",
+ select_feature="same",
+ image_size=384,
+ pixel_mean=(0.5, 0.5, 0.5),
+ pixel_std=(0.5, 0.5, 0.5),
+ select_layer=-1,
+ ckpt_path="",
+ )
+
+ net = (
+ HybridVisionTower(
+ high_res_cfg=high_res_cfg,
+ low_res_cfg=low_res_cfg,
+ freeze_high=True,
+ freeze_low=True,
+ concat_type="tuple",
+ )
+ .bfloat16()
+ .cuda()
+ )
+ high_x, low_x = net(x)
+ print(x.shape, high_x.shape, low_x.shape)
diff --git a/examples/math_prm/ursa_model/configuration_ursa.py b/examples/math_prm/ursa_model/configuration_ursa.py
new file mode 100644
index 00000000..a233dd63
--- /dev/null
+++ b/examples/math_prm/ursa_model/configuration_ursa.py
@@ -0,0 +1,144 @@
+# Copyright (2025) Bytedance Ltd. and/or its affiliates
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+
+if sys.version_info >= (3, 10):
+ print("Python version is above 3.10, patching the collections module.")
+ import collections
+ import collections.abc
+
+ for type_name in collections.abc.__all__:
+ setattr(collections, type_name, getattr(collections.abc, type_name))
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+from transformers import CONFIG_MAPPING
+from .attrdict_compat import AttrDict
+logger = logging.get_logger(__name__)
+
+
+class VisionConfig(PretrainedConfig):
+ model_type = "vision"
+ cls: str = ""
+ params: AttrDict = {}
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ self.cls = kwargs.get("cls", "")
+ if not isinstance(self.cls, str):
+ self.cls = self.cls.__name__
+
+ self.params = AttrDict(kwargs.get("params", {}))
+
+
+class AlignerConfig(PretrainedConfig):
+ model_type = "aligner"
+ cls: str = ""
+ params: AttrDict = {}
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ self.cls = kwargs.get("cls", "")
+ if not isinstance(self.cls, str):
+ self.cls = self.cls.__name__
+
+ self.params = AttrDict(kwargs.get("params", {}))
+
+class UrsaConfig(PretrainedConfig):
+ model_type = "ursa"
+ is_composition = False
+ vision_config: VisionConfig
+ aligner_config: AlignerConfig
+ text_config: PretrainedConfig
+
+ def __init__(
+ self,
+ vision_config=None,
+ aligner_config=None,
+ text_config=None,
+ ignore_index=-100,
+ image_token_index=32000,
+ projector_hidden_act="gelu",
+ vision_feature_select_strategy="default",
+ vision_feature_layer=-2,
+ **kwargs,
+ ):
+ self.ignore_index = ignore_index
+ self.image_token_index = image_token_index
+ self.projector_hidden_act = projector_hidden_act
+
+ if vision_feature_select_strategy not in ["default", "full"]:
+ raise ValueError(
+ "vision_feature_select_strategy should be one of 'default', 'full'."
+ f"Got: {vision_feature_select_strategy}"
+ )
+
+ self.vision_feature_select_strategy = vision_feature_select_strategy
+ self.vision_feature_layer = vision_feature_layer
+
+ if vision_config is None:
+ vision_config = VisionConfig()
+ vision_config.cls = "HybridVisionTower"
+ vision_config.params = {
+ "concat_type": "tuple",
+ "high_res_cfg": {
+ "ckpt_path": "",
+ "image_size": 1024,
+ "model_name": "sam_b_downsample",
+ "output_dim": 1024,
+ "pixel_mean": [
+ 0.48145466,
+ 0.4578275,
+ 0.40821073
+ ],
+ "pixel_std": [
+ 0.26862954,
+ 0.26130258,
+ 0.27577711
+ ],
+ "select_feature": "same",
+ "select_layer": -1
+ },
+ "low_res_cfg": {
+ "ckpt_path": "",
+ "image_size": 384,
+ "model_name": "siglip_large_patch16_384",
+ "output_dim": 1024,
+ "pixel_mean": [
+ 0.5,
+ 0.5,
+ 0.5
+ ],
+ "pixel_std": [
+ 0.5,
+ 0.5,
+ 0.5
+ ],
+ "select_feature": "same",
+ "select_layer": -1
+ }
+ }
+ self.vision_config = vision_config
+ self.aligner_config = aligner_config
+ if isinstance(text_config, dict):
+ text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+ elif text_config is None:
+ text_config = CONFIG_MAPPING["llama"]()
+
+ self.text_config = text_config
+
+ super().__init__(**kwargs)
diff --git a/examples/math_prm/ursa_model/image_processing_vlm.py b/examples/math_prm/ursa_model/image_processing_vlm.py
new file mode 100644
index 00000000..367dee10
--- /dev/null
+++ b/examples/math_prm/ursa_model/image_processing_vlm.py
@@ -0,0 +1,208 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+from typing import List, Tuple, Union
+
+import numpy as np
+import torch
+import torchvision
+import torchvision.transforms.functional
+from PIL import Image
+from transformers import AutoImageProcessor, PretrainedConfig
+from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
+from transformers.image_utils import to_numpy_array
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
+IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
+IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)
+IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
+IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
+
+
+def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+
+class VLMImageProcessorConfig(PretrainedConfig):
+ model_type = "deepseek_vlm"
+ image_size: int
+ min_size: int
+ image_mean: Union[Tuple[float, float, float], List[float]]
+ image_std: Union[Tuple[float, float, float], List[float]]
+ rescale_factor: float
+ do_normalize: bool
+
+ def __init__(
+ self,
+ image_size: int,
+ min_size: int = 14,
+ image_mean: Union[Tuple[float, float, float], List[float]] = (
+ 0.48145466,
+ 0.4578275,
+ 0.40821073,
+ ),
+ image_std: Union[Tuple[float, float, float], List[float]] = (
+ 0.26862954,
+ 0.26130258,
+ 0.27577711,
+ ),
+ rescale_factor: float = 1.0 / 255.0,
+ do_normalize: bool = True,
+ **kwargs,
+ ):
+ self.image_size = image_size
+ self.min_size = min_size
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.rescale_factor = rescale_factor
+ self.do_normalize = do_normalize
+
+ super().__init__(**kwargs)
+
+
+class VLMImageProcessor(BaseImageProcessor):
+ model_input_names = ["pixel_values"]
+
+ def __init__(
+ self,
+ image_size: int,
+ min_size: int = 14,
+ image_mean: Union[Tuple[float, float, float], List[float]] = (
+ 0.48145466,
+ 0.4578275,
+ 0.40821073,
+ ),
+ image_std: Union[Tuple[float, float, float], List[float]] = (
+ 0.26862954,
+ 0.26130258,
+ 0.27577711,
+ ),
+ rescale_factor: float = 1.0 / 255.0,
+ do_normalize: bool = True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.image_size = image_size
+ self.rescale_factor = rescale_factor
+ self.image_mean = image_mean
+ self.image_std = image_std
+ self.min_size = min_size
+ self.do_normalize = do_normalize
+
+ if image_mean is None:
+ self.background_color = (127, 127, 127)
+ else:
+ self.background_color = tuple([int(x * 255) for x in image_mean])
+
+ def resize(self, pil_img: Image) -> np.ndarray:
+ """
+
+ Args:
+ pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB
+
+ Returns:
+ x (np.ndarray): [3, self.image_size, self.image_size]
+ """
+
+ width, height = pil_img.size
+ max_size = max(width, height)
+
+ size = [
+ max(int(height / max_size * self.image_size), self.min_size),
+ max(int(width / max_size * self.image_size), self.min_size),
+ ]
+
+ if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
+ print(f"orig size = {pil_img.size}, new size = {size}")
+ raise ValueError("Invalid size!")
+
+ pil_img = torchvision.transforms.functional.resize(
+ pil_img,
+ size,
+ interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC,
+ antialias=True,
+ )
+
+ pil_img = expand2square(pil_img, self.background_color)
+ x = to_numpy_array(pil_img)
+
+ # [H, W, 3] -> [3, H, W]
+ x = np.transpose(x, (2, 0, 1))
+
+ return x
+
+ def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
+ # resize and pad to [self.image_size, self.image_size]
+ # then convert from [H, W, 3] to [3, H, W]
+ images: List[np.ndarray] = [self.resize(image) for image in images]
+
+ # resacle from [0, 255] -> [0, 1]
+ images = [
+ self.rescale(
+ image=image,
+ scale=self.rescale_factor,
+ input_data_format="channels_first",
+ )
+ for image in images
+ ]
+
+ # normalize
+ if self.do_normalize:
+ images = [
+ self.normalize(
+ image=image,
+ mean=self.image_mean,
+ std=self.image_std,
+ input_data_format="channels_first",
+ )
+ for image in images
+ ]
+
+ data = {"pixel_values": images}
+ return BatchFeature(data=data, tensor_type=return_tensors)
+
+ @property
+ def default_shape(self):
+ return [3, self.image_size, self.image_size]
+
+
+AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor)
+
+
+if __name__ == "__main__":
+ image_processor = VLMImageProcessor(
+ image_size=1024,
+ image_mean=IMAGENET_INCEPTION_MEAN,
+ image_std=IMAGENET_INCEPTION_STD,
+ do_normalize=True,
+ )
diff --git a/examples/math_prm/ursa_model/modeling_ursa.py b/examples/math_prm/ursa_model/modeling_ursa.py
new file mode 100644
index 00000000..215008ff
--- /dev/null
+++ b/examples/math_prm/ursa_model/modeling_ursa.py
@@ -0,0 +1,742 @@
+# Copyright (2025) Bytedance Ltd. and/or its affiliates
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from transformers import (
+ PreTrainedModel,
+ AutoModel,
+ AutoModelForCausalLM,
+)
+from transformers.generation import GenerationMixin
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import ModelOutput
+from .configuration_ursa import UrsaConfig, AlignerConfig, VisionConfig
+from .clip_encoder import CLIPVisionTower, HybridVisionTower
+from .projector import MlpProjector
+
+
+def _get_module_float_dtype(module) -> Optional[torch.dtype]:
+ module_dtype = getattr(module, "dtype", None)
+ if isinstance(module_dtype, torch.dtype):
+ return module_dtype
+ for parameter in module.parameters():
+ if torch.is_floating_point(parameter):
+ return parameter.dtype
+ return None
+
+
+def _cast_pixel_values_to_vision_dtype(pixel_values: Optional[torch.Tensor], vision_model):
+ if pixel_values is None or not torch.is_tensor(pixel_values) or not torch.is_floating_point(pixel_values):
+ return pixel_values
+ vision_dtype = _get_module_float_dtype(vision_model)
+ if vision_dtype is None or pixel_values.dtype == vision_dtype:
+ return pixel_values
+ return pixel_values.to(dtype=vision_dtype)
+
+
+@dataclass
+class UrsaCausalLMOutputWithPast(ModelOutput):
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ past_key_values: Optional[List[torch.FloatTensor]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ labels: Optional[Tuple[torch.FloatTensor]] = None
+
+
+class UrsaPreTrainedModel(PreTrainedModel):
+ config_class = UrsaConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["qwen2vlmVisionAttention"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_cache_class = True
+
+ def _init_weights(self, module):
+ std = (
+ self.config.initializer_range
+ if hasattr(self.config, "initializer_range")
+ else self.config.text_config.initializer_range
+ )
+ if hasattr(module, "class_embedding"):
+ module.class_embedding.data.normal_(mean=0.0, std=std)
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ @property
+ def _supports_sdpa(self):
+ """
+ Retrieve language_model's attribute to check whether the model supports
+ SDPA or not.
+ """
+ language_model = getattr(self, "language_model", None)
+ if language_model is None:
+ return False
+ return getattr(language_model, "_supports_sdpa", False)
+
+
+class UrsaForConditionalGeneration(UrsaPreTrainedModel, GenerationMixin):
+ def __init__(self, config: UrsaConfig):
+ super().__init__(config)
+ # print(config)
+ # print(type(config.vision_config))
+ # print(type(config.aligner_config))
+ self.vision_model = HybridVisionTower(**config.vision_config["params"])
+ # print(config.aligner_config)
+
+ # print(config.aligner_config.params)
+ aligner_config = AlignerConfig(**config.aligner_config)
+ self.aligner = MlpProjector(aligner_config.params)
+ self.vocab_size = config.text_config.vocab_size
+ self.language_model = AutoModelForCausalLM.from_config(
+ config.text_config, attn_implementation=config._attn_implementation
+ )
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def get_output_embeddings(self):
+ return self.language_model.get_output_embeddings()
+
+ def set_output_embeddings(self, new_embeddings):
+ self.language_model.set_output_embeddings(new_embeddings)
+
+ def set_decoder(self, decoder):
+ self.language_model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.language_model.get_decoder()
+
+ def tie_weights(self):
+ return self.language_model.tie_weights()
+
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ # update vocab size
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
+ self.vocab_size = model_embeds.num_embeddings
+ return model_embeds
+
+ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
+ num_images, num_image_patches, embed_dim = image_features.shape
+ batch_size, sequence_length = input_ids.shape
+ left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
+ # 1. Create a mask to know where special image tokens are
+ special_image_token_mask = input_ids == self.config.image_token_index
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
+ # Compute the maximum embed dimension
+ max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
+ batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
+
+ # 2. Compute the positions where text should be written
+ # Calculate new positions for text tokens in merged image-text sequence.
+ # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
+ # `torch.cumsum` computes how each image token shifts subsequent text token positions.
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
+ new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
+ nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
+ if left_padding:
+ new_token_positions += nb_image_pad[:, None] # offset for left padding
+ text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
+
+ # 3. Create the full embedding, already padded to the maximum position
+ final_embedding = torch.zeros(
+ batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
+ )
+ final_attention_mask = torch.zeros(
+ batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
+ )
+ if labels is not None:
+ final_labels = torch.full(
+ (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
+ )
+ # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
+ # set the corresponding tensors into their correct target device.
+ target_device = inputs_embeds.device
+ batch_indices, non_image_indices, text_to_overwrite = (
+ batch_indices.to(target_device),
+ non_image_indices.to(target_device),
+ text_to_overwrite.to(target_device),
+ )
+ attention_mask = attention_mask.to(target_device)
+
+ # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"]
+ # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
+ if labels is not None:
+ final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
+
+ # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
+ image_to_overwrite = torch.full(
+ (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
+ )
+ image_to_overwrite[batch_indices, text_to_overwrite] = False
+ image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
+
+ if image_to_overwrite.sum() != image_features.shape[:-1].numel():
+ raise ValueError(
+ f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
+ f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
+ )
+
+ final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
+ final_attention_mask |= image_to_overwrite
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
+
+ # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
+ batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
+ indices_to_mask = new_token_positions[batch_indices, pad_indices]
+
+ final_embedding[batch_indices, indices_to_mask] = 0
+
+ if labels is None:
+ final_labels = None
+
+ return final_embedding, final_attention_mask, final_labels, position_ids
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[int] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, UrsaCausalLMOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if inputs_embeds is None:
+ # 1. Extra the input embeddings
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # 2. Merge text and images
+ if pixel_values is not None and input_ids.shape[1] != 1:
+ pixel_values = _cast_pixel_values_to_vision_dtype(pixel_values, self.vision_model)
+ image_outputs = self.vision_model(pixel_values)
+ # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
+ # selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+
+ # if vision_feature_select_strategy == "default":
+ # selected_image_feature = selected_image_feature[:, 1:]
+ # elif vision_feature_select_strategy == "full":
+ # selected_image_feature = selected_image_feature
+ # else:
+ # raise ValueError(
+ # f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
+ # )
+
+ image_features = self.aligner(image_outputs)
+ inputs_embeds = inputs_embeds.to(image_features.dtype)
+ inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+ image_features, inputs_embeds, input_ids, attention_mask, labels
+ )
+
+ # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
+ # generation with cache
+ elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
+ # that are set to 0
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+ # Get the target length
+ target_length = input_ids.shape[1]
+ past_length = first_layer_past_key_value.shape[-1]
+
+ extended_attention_mask = torch.ones(
+ (attention_mask.shape[0], past_length),
+ dtype=attention_mask.dtype,
+ device=attention_mask.device,
+ )
+
+ # Filter out only the tokens that can be un-attended, this can happen
+ # if one uses qwen2vlm + Fused modules where the cache on the
+ # first iteration is already big enough, or if one passes custom cache
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+ new_batch_index = batch_index[valid_indices]
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+ # Zero-out the places where we don't need to attend
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+ attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ logits = outputs[0]
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ if attention_mask is not None:
+ shift_attention_mask = attention_mask[..., 1:]
+ shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
+ shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
+ else:
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = nn.CrossEntropyLoss()
+ loss = loss_fct(
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return UrsaCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
+ ):
+ if past_key_values is not None:
+ if isinstance(past_key_values, Cache):
+ cache_length = past_key_values.get_seq_length()
+ if hasattr(past_key_values, "seen_tokens"):
+ past_length = past_key_values.seen_tokens
+ elif hasattr(past_key_values, "_seen_tokens"):
+ past_length = past_key_values._seen_tokens
+ else:
+ past_length = cache_length
+ else:
+ cache_length = past_length = past_key_values[0][0].shape[2]
+
+ # Keep only the unprocessed tokens:
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+ # input)
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+ # input_ids based on the past_length.
+ elif past_length < input_ids.shape[1]:
+ input_ids = input_ids[:, past_length:]
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+ elif self.config.image_token_index in input_ids:
+ input_ids = input_ids[:, input_ids.shape[1] - 1 :]
+ # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
+ # older attention values, as their corresponding values are not part of the input.
+ if cache_length < past_length and attention_mask is not None:
+ attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ "pixel_values": pixel_values,
+ }
+ )
+ return model_inputs
+
+ def _reorder_cache(self, *args, **kwargs):
+ return self.language_model._reorder_cache(*args, **kwargs)
+
+
+class UrsaForTokenClassification(UrsaPreTrainedModel):
+ def __init__(self, config: UrsaConfig):
+ super().__init__(config)
+ self.vision_model = HybridVisionTower(**config.vision_config["params"])
+ aligner_config = AlignerConfig(**config.aligner_config)
+ self.aligner = MlpProjector(aligner_config.params)
+ self.vocab_size = config.text_config.vocab_size
+ self.language_model = AutoModelForCausalLM.from_config(
+ config.text_config, attn_implementation=config._attn_implementation
+ )
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+
+ self.dropout = nn.Dropout(0.1)
+ self.score = nn.Linear(self.language_model.config.hidden_size, 1)
+ self.post_init()
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ if isinstance(module, nn.Linear):
+ # print(module)
+ nn.init.xavier_uniform_(module.weight)
+ nn.init.zeros_(module.bias)
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def get_output_embeddings(self):
+ return self.language_model.get_output_embeddings()
+
+ def set_output_embeddings(self, new_embeddings):
+ self.language_model.set_output_embeddings(new_embeddings)
+
+ def set_decoder(self, decoder):
+ self.language_model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.language_model.get_decoder()
+
+ def tie_weights(self):
+ return self.language_model.tie_weights()
+
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ # update vocab size
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
+ self.vocab_size = model_embeds.num_embeddings
+ return model_embeds
+
+ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
+ num_images, num_image_patches, embed_dim = image_features.shape
+ batch_size, sequence_length = input_ids.shape
+ left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
+ # 1. Create a mask to know where special image tokens are
+ special_image_token_mask = input_ids == self.config.image_token_index
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
+ # Compute the maximum embed dimension
+ max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
+ batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
+
+ # 2. Compute the positions where text should be written
+ # Calculate new positions for text tokens in merged image-text sequence.
+ # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
+ # `torch.cumsum` computes how each image token shifts subsequent text token positions.
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
+ new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
+ nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
+ if left_padding:
+ new_token_positions += nb_image_pad[:, None] # offset for left padding
+ text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
+
+ # 3. Create the full embedding, already padded to the maximum position
+ final_embedding = torch.zeros(
+ batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
+ )
+ final_attention_mask = torch.zeros(
+ batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
+ )
+ if labels is not None:
+ final_labels = torch.full(
+ (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
+ )
+ # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
+ # set the corresponding tensors into their correct target device.
+ target_device = inputs_embeds.device
+ batch_indices, non_image_indices, text_to_overwrite = (
+ batch_indices.to(target_device),
+ non_image_indices.to(target_device),
+ text_to_overwrite.to(target_device),
+ )
+ attention_mask = attention_mask.to(target_device)
+
+ # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"]
+ # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
+ if labels is not None:
+ final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
+ # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
+ image_to_overwrite = torch.full(
+ (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
+ )
+ image_to_overwrite[batch_indices, text_to_overwrite] = False
+ image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
+
+ if image_to_overwrite.sum() != image_features.shape[:-1].numel():
+ raise ValueError(
+ f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
+ f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
+ )
+
+ final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
+ final_attention_mask |= image_to_overwrite
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
+
+ # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
+ batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
+ indices_to_mask = new_token_positions[batch_indices, pad_indices]
+
+ final_embedding[batch_indices, indices_to_mask] = 0
+
+ if labels is None:
+ final_labels = None
+
+ return final_embedding, final_attention_mask, final_labels, position_ids
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[int] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, UrsaCausalLMOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+ if inputs_embeds is None:
+ # 1. Extra the input embeddings
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # 2. Merge text and images
+ if pixel_values is not None and input_ids.shape[1] != 1:
+ pixel_values = _cast_pixel_values_to_vision_dtype(pixel_values, self.vision_model)
+ image_outputs = self.vision_model(pixel_values)
+ # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
+ # selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+
+ # if vision_feature_select_strategy == "default":
+ # selected_image_feature = selected_image_feature[:, 1:]
+ # elif vision_feature_select_strategy == "full":
+ # selected_image_feature = selected_image_feature
+ # else:
+ # raise ValueError(
+ # f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
+ # )
+
+ image_features = self.aligner(image_outputs)
+ inputs_embeds = inputs_embeds.to(image_features.dtype)
+ inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
+ image_features, inputs_embeds, input_ids, attention_mask, labels
+ )
+
+ # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
+ # generation with cache
+ elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
+ # that are set to 0
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
+
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
+
+ # Get the target length
+ target_length = input_ids.shape[1]
+ past_length = first_layer_past_key_value.shape[-1]
+
+ extended_attention_mask = torch.ones(
+ (attention_mask.shape[0], past_length),
+ dtype=attention_mask.dtype,
+ device=attention_mask.device,
+ )
+
+ # Filter out only the tokens that can be un-attended, this can happen
+ # if one uses qwen2vlm + Fused modules where the cache on the
+ # first iteration is already big enough, or if one passes custom cache
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
+ new_batch_index = batch_index[valid_indices]
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
+
+ # Zero-out the places where we don't need to attend
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
+
+ attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ return_dict=return_dict,
+ output_hidden_states=True
+ )
+
+ # logits = outputs[0]
+ logits = outputs.hidden_states[-1]
+ logits = self.dropout(logits)
+ logits = self.score(logits)
+
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ # if attention_mask is not None:
+ # shift_attention_mask = attention_mask[..., 1:]
+ # shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
+ # shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
+ # else:
+ # shift_logits = logits[..., :-1, :].contiguous()
+ # shift_labels = labels[..., 1:].contiguous()
+ # # Flatten the tokens
+ # loss_fct = nn.CrossEntropyLoss()
+ # loss = loss_fct(
+ # shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
+ # )
+ # loss_fct = nn.CrossEntropyLoss()
+ # loss = loss_fct(logits.view(-1, 1), labels.view(-1))
+ loss = None
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return UrsaCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ # past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ labels=labels
+ # attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
+ ):
+ if past_key_values is not None:
+ if isinstance(past_key_values, Cache):
+ cache_length = past_key_values.get_seq_length()
+ if hasattr(past_key_values, "seen_tokens"):
+ past_length = past_key_values.seen_tokens
+ elif hasattr(past_key_values, "_seen_tokens"):
+ past_length = past_key_values._seen_tokens
+ else:
+ past_length = cache_length
+ else:
+ cache_length = past_length = past_key_values[0][0].shape[2]
+
+ # Keep only the unprocessed tokens:
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+ # input)
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+ # input_ids based on the past_length.
+ elif past_length < input_ids.shape[1]:
+ input_ids = input_ids[:, past_length:]
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+ elif self.config.image_token_index in input_ids:
+ input_ids = input_ids[:, input_ids.shape[1] - 1 :]
+ # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
+ # older attention values, as their corresponding values are not part of the input.
+ if cache_length < past_length and attention_mask is not None:
+ attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ "pixel_values": pixel_values,
+ }
+ )
+ return model_inputs
+
+ def _reorder_cache(self, *args, **kwargs):
+ return self.language_model._reorder_cache(*args, **kwargs)
diff --git a/examples/math_prm/ursa_model/processing_ursa.py b/examples/math_prm/ursa_model/processing_ursa.py
new file mode 100644
index 00000000..1a92774a
--- /dev/null
+++ b/examples/math_prm/ursa_model/processing_ursa.py
@@ -0,0 +1,82 @@
+# Copyright (2025) Bytedance Ltd. and/or its affiliates
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Union
+
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import ImageInput
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from transformers.utils import TensorType
+
+
+class UrsaProcessor(ProcessorMixin):
+ attributes = ["image_processor", "tokenizer"]
+ valid_kwargs = ["chat_template"]
+ image_processor_class = "AutoImageProcessor"
+ tokenizer_class = "AutoTokenizer"
+
+ def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
+ super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)
+
+ @staticmethod
+ def _normalize_image_placeholders(text):
+ if isinstance(text, str):
+ return text.replace("", "<|image|>")
+ if isinstance(text, list):
+ return [UrsaProcessor._normalize_image_placeholders(item) for item in text]
+ if isinstance(text, tuple):
+ return tuple(UrsaProcessor._normalize_image_placeholders(item) for item in text)
+ return text
+
+ def __call__(
+ self,
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+ images: ImageInput = None,
+ videos: ImageInput = None,
+ padding: Union[bool, str, PaddingStrategy] = False,
+ truncation: Union[bool, str, TruncationStrategy] = None,
+ max_length=None,
+ add_special_tokens: bool = True,
+ return_tensors: Optional[Union[str, TensorType]] = None, # or TensorType.PYTORCH
+ **kwargs,
+ ) -> BatchFeature:
+ if videos is not None:
+ raise ValueError("UrsaProcessor does not support video inputs.")
+ image_inputs = {}
+ if images is not None:
+ image_inputs = self.image_processor(images, return_tensors=return_tensors)
+ text = self._normalize_image_placeholders(text)
+ text_inputs = self.tokenizer(
+ text,
+ return_tensors=return_tensors,
+ padding=padding,
+ truncation=truncation,
+ max_length=max_length,
+ add_special_tokens=add_special_tokens,
+ **kwargs,
+ )
+ return BatchFeature(data={**text_inputs, **image_inputs})
+
+ def decode(self, *args, **kwargs):
+ return self.tokenizer.decode(*args, **kwargs)
+
+ def batch_decode(self, *args, **kwargs):
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ @property
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
diff --git a/examples/math_prm/ursa_model/projector.py b/examples/math_prm/ursa_model/projector.py
new file mode 100644
index 00000000..13a2c36e
--- /dev/null
+++ b/examples/math_prm/ursa_model/projector.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from .attrdict_compat import AttrDict
+
+
+class MlpProjector(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+
+ self.cfg = cfg
+
+ if cfg.projector_type == "identity":
+ modules = nn.Identity()
+
+ elif cfg.projector_type == "linear":
+ modules = nn.Linear(cfg.input_dim, cfg.n_embed)
+
+ elif cfg.projector_type == "mlp_gelu":
+ mlp_depth = cfg.get("depth", 1)
+ modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
+ for _ in range(1, mlp_depth):
+ modules.append(nn.GELU())
+ modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
+ modules = nn.Sequential(*modules)
+
+ elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
+ mlp_depth = cfg.get("depth", 1)
+ self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
+ self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
+
+ modules = []
+ for _ in range(1, mlp_depth):
+ modules.append(nn.GELU())
+ modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
+ modules = nn.Sequential(*modules)
+
+ else:
+ raise ValueError(f"Unknown projector type: {cfg.projector_type}")
+
+ self.layers = modules
+
+ def forward(
+ self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
+ ):
+ """
+
+ Args:
+ x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: if it is a tuple of torch.Tensor,
+ then it comes from the hybrid vision encoder, and x = high_res_x, low_res_x);
+ otherwise it is the feature from the single vision encoder.
+
+ Returns:
+ x (torch.Tensor): [b, s, c]
+ """
+
+ if isinstance(x_or_tuple, tuple):
+ # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
+ high_x, low_x = x_or_tuple
+ high_x = self.high_up_proj(high_x)
+ low_x = self.low_up_proj(low_x)
+ x = torch.concat([high_x, low_x], dim=-1)
+ else:
+ x = x_or_tuple
+
+ return self.layers(x)
+
+
+if __name__ == "__main__":
+ cfg = AttrDict(
+ input_dim=1024,
+ n_embed=2048,
+ depth=2,
+ projector_type="low_high_hybrid_split_mlp_gelu",
+ )
+ inputs = (torch.rand(4, 576, 1024), torch.rand(4, 576, 1024))
+
+ m = MlpProjector(cfg)
+ out = m(inputs)
+ print(out.shape)
diff --git a/examples/math_prm/ursa_model/sam.py b/examples/math_prm/ursa_model/sam.py
new file mode 100644
index 00000000..31159a94
--- /dev/null
+++ b/examples/math_prm/ursa_model/sam.py
@@ -0,0 +1,593 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+from dataclasses import dataclass
+from functools import partial
+from typing import List, Optional, Tuple, Type, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class MLPBlock(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ mlp_dim: int,
+ act: Type[nn.Module] = nn.GELU,
+ ) -> None:
+ super().__init__()
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
+ self.act = act()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.lin2(self.act(self.lin1(x)))
+
+
+# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
+# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
+class LayerNorm2d(nn.Module):
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(num_channels))
+ self.bias = nn.Parameter(torch.zeros(num_channels))
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
+
+
+# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
+class ImageEncoderViT(nn.Module):
+ def __init__(
+ self,
+ img_size: int = 1024,
+ patch_size: int = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ depth: int = 12,
+ num_heads: int = 12,
+ mlp_ratio: float = 4.0,
+ out_chans: int = 256,
+ qkv_bias: bool = True,
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
+ act_layer: Type[nn.Module] = nn.GELU,
+ use_abs_pos: bool = True,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ window_size: int = 0,
+ global_attn_indexes: Tuple[int, ...] = (),
+ downsample_channels: Tuple[int, ...] = (512, 1024),
+ ) -> None:
+ """
+ Args:
+ img_size (int): Input image size.
+ patch_size (int): Patch size.
+ in_chans (int): Number of input image channels.
+ embed_dim (int): Patch embedding dimension.
+ depth (int): Depth of ViT.
+ num_heads (int): Number of attention heads in each ViT block.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ norm_layer (nn.Module): Normalization layer.
+ act_layer (nn.Module): Activation layer.
+ use_abs_pos (bool): If True, use absolute positional embeddings.
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ window_size (int): Window size for window attention blocks.
+ global_attn_indexes (list): Indexes for blocks using global attention.
+ downsample_channels (list): Channels for downsampling layers.
+ """
+ super().__init__()
+ self.img_size = img_size
+
+ self.patch_embed = PatchEmbed(
+ kernel_size=(patch_size, patch_size),
+ stride=(patch_size, patch_size),
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ )
+
+ self.pos_embed: Optional[nn.Parameter] = None
+ if use_abs_pos:
+ # Initialize absolute positional embedding with pretrain image size.
+ self.pos_embed = nn.Parameter(
+ torch.zeros(
+ 1, img_size // patch_size, img_size // patch_size, embed_dim
+ )
+ )
+
+ self.blocks = nn.ModuleList()
+ for i in range(depth):
+ block = Block(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ use_rel_pos=use_rel_pos,
+ rel_pos_zero_init=rel_pos_zero_init,
+ window_size=window_size if i not in global_attn_indexes else 0,
+ input_size=(img_size // patch_size, img_size // patch_size),
+ )
+ self.blocks.append(block)
+
+ self.neck = nn.Sequential(
+ nn.Conv2d(
+ embed_dim,
+ out_chans,
+ kernel_size=1,
+ bias=False,
+ ),
+ LayerNorm2d(out_chans),
+ nn.Conv2d(
+ out_chans,
+ out_chans,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ ),
+ LayerNorm2d(out_chans),
+ )
+
+ in_channels = out_chans
+ downsamples = []
+ for i in range(len(downsample_channels)):
+ out_channels = downsample_channels[i]
+ downsamples.append(
+ nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False,
+ )
+ )
+ in_channels = out_channels
+ self.downsamples = nn.Sequential(*downsamples)
+
+ self.sam_hd = True
+ if self.sam_hd:
+ self.hd_alpha_downsamples = nn.Parameter(torch.zeros(1))
+ # self.neck_hd = nn.Linear(embed_dim, embed_dim)
+ self.neck_hd = copy.deepcopy(self.neck)
+ # self.downsamples_hd = copy.deepcopy(self.downsamples)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.patch_embed(x)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+
+ global_features = []
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if self.sam_hd and blk.window_size == 0:
+ global_features.append(x)
+
+ x = self.neck(x.permute(0, 3, 1, 2))
+ x_dtype = x.dtype
+ x = F.interpolate(
+ x.float(), size=(96, 96), mode="bilinear", align_corners=False
+ ).to(x_dtype)
+ x = self.downsamples(x)
+
+ if self.sam_hd:
+ first_global_feature = self.neck_hd(global_features[0].permute(0, 3, 1, 2))
+ x_dtype = first_global_feature.dtype
+ first_global_feature = F.interpolate(
+ first_global_feature.float(),
+ size=(96, 96),
+ mode="bilinear",
+ align_corners=False,
+ )
+ first_global_feature = self.downsamples(first_global_feature.to(x_dtype))
+ x = x + first_global_feature * self.hd_alpha_downsamples
+
+ return x
+
+
+class Block(nn.Module):
+ """Transformer blocks with support of window attention and residual propagation blocks"""
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = True,
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
+ act_layer: Type[nn.Module] = nn.GELU,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ window_size: int = 0,
+ input_size: Optional[Tuple[int, int]] = None,
+ ) -> None:
+ """
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads in each ViT block.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ norm_layer (nn.Module): Normalization layer.
+ act_layer (nn.Module): Activation layer.
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ window_size (int): Window size for window attention blocks. If it equals 0, then
+ use global attention.
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
+ positional parameter size.
+ """
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ use_rel_pos=use_rel_pos,
+ rel_pos_zero_init=rel_pos_zero_init,
+ input_size=input_size if window_size == 0 else (window_size, window_size),
+ )
+
+ self.norm2 = norm_layer(dim)
+ self.mlp = MLPBlock(
+ embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
+ )
+
+ self.window_size = window_size
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ shortcut = x
+ x = self.norm1(x)
+ # Window partition
+ if self.window_size > 0:
+ H, W = x.shape[1], x.shape[2]
+ x, pad_hw = window_partition(x, self.window_size)
+
+ x = self.attn(x)
+ # Reverse window partition
+ if self.window_size > 0:
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+
+ x = shortcut + x
+ x = x + self.mlp(self.norm2(x))
+
+ return x
+
+
+class Attention(nn.Module):
+ """Multi-head Attention block with relative position embeddings."""
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = True,
+ use_rel_pos: bool = False,
+ rel_pos_zero_init: bool = True,
+ input_size: Optional[Tuple[int, int]] = None,
+ ) -> None:
+ """
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
+ positional parameter size.
+ """
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(dim, dim)
+
+ self.use_rel_pos = use_rel_pos
+ if self.use_rel_pos:
+ assert (
+ input_size is not None
+ ), "Input size must be provided if using relative positional encoding."
+ # initialize relative positional embeddings
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, H, W, _ = x.shape
+ # qkv with shape (3, B, nHead, H * W, C)
+ qkv = (
+ self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ )
+ # q, k, v with shape (B * nHead, H * W, C)
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+
+ def do_attention(q, k, v):
+ attn = (q * self.scale) @ k.transpose(-2, -1)
+ if self.use_rel_pos:
+ attn = add_decomposed_rel_pos(
+ attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
+ )
+
+ attn = attn.softmax(dim=-1)
+ x = (
+ (attn @ v)
+ .view(B, self.num_heads, H, W, -1)
+ .permute(0, 2, 3, 1, 4)
+ .reshape(B, H, W, -1)
+ )
+
+ return x
+
+ # from haiscale.utils import on_demand_checkpoint
+ # x = on_demand_checkpoint(do_attention, q, k, v)
+ x = do_attention(q, k, v)
+ x = self.proj(x)
+
+ return x
+
+
+def window_partition(
+ x: torch.Tensor, window_size: int
+) -> Tuple[torch.Tensor, Tuple[int, int]]:
+ """
+ Partition into non-overlapping windows with padding if needed.
+ Args:
+ x (tensor): input tokens with [B, H, W, C].
+ window_size (int): window size.
+
+ Returns:
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
+ (Hp, Wp): padded height and width before partition
+ """
+ B, H, W, C = x.shape
+
+ pad_h = (window_size - H % window_size) % window_size
+ pad_w = (window_size - W % window_size) % window_size
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+ Hp, Wp = H + pad_h, W + pad_w
+
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+ windows = (
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ )
+ return windows, (Hp, Wp)
+
+
+def window_unpartition(
+ windows: torch.Tensor,
+ window_size: int,
+ pad_hw: Tuple[int, int],
+ hw: Tuple[int, int],
+) -> torch.Tensor:
+ """
+ Window unpartition into original sequences and removing padding.
+ Args:
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+ window_size (int): window size.
+ pad_hw (Tuple): padded height and width (Hp, Wp).
+ hw (Tuple): original height and width (H, W) before padding.
+
+ Returns:
+ x: unpartitioned sequences with [B, H, W, C].
+ """
+ Hp, Wp = pad_hw
+ H, W = hw
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+ x = windows.view(
+ B, Hp // window_size, Wp // window_size, window_size, window_size, -1
+ )
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+ if Hp > H or Wp > W:
+ x = x[:, :H, :W, :].contiguous()
+ return x
+
+
+def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+ """
+ Get relative positional embeddings according to the relative positions of
+ query and key sizes.
+ Args:
+ q_size (int): size of query q.
+ k_size (int): size of key k.
+ rel_pos (Tensor): relative position embeddings (L, C).
+
+ Returns:
+ Extracted positional embeddings according to relative positions.
+ """
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
+ # Interpolate rel pos if needed.
+ if rel_pos.shape[0] != max_rel_dist:
+ # Interpolate rel pos.
+ rel_pos_resized = F.interpolate(
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+ size=max_rel_dist,
+ mode="linear",
+ )
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+ else:
+ rel_pos_resized = rel_pos
+
+ # Scale the coords with short length if shapes for q and k are different.
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+ return rel_pos_resized[relative_coords.long()]
+
+
+def add_decomposed_rel_pos(
+ attn: torch.Tensor,
+ q: torch.Tensor,
+ rel_pos_h: torch.Tensor,
+ rel_pos_w: torch.Tensor,
+ q_size: Tuple[int, int],
+ k_size: Tuple[int, int],
+) -> torch.Tensor:
+ """
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
+ Args:
+ attn (Tensor): attention map.
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
+
+ Returns:
+ attn (Tensor): attention map with added relative positional embeddings.
+ """
+ q_h, q_w = q_size
+ k_h, k_w = k_size
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+ B, _, dim = q.shape
+ r_q = q.reshape(B, q_h, q_w, dim)
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+
+ attn = (
+ attn.view(B, q_h, q_w, k_h, k_w)
+ + rel_h[:, :, :, :, None]
+ + rel_w[:, :, :, None, :]
+ ).view(B, q_h * q_w, k_h * k_w)
+
+ return attn
+
+
+class PatchEmbed(nn.Module):
+ """
+ Image to Patch Embedding.
+ """
+
+ def __init__(
+ self,
+ kernel_size: Tuple[int, int] = (16, 16),
+ stride: Tuple[int, int] = (16, 16),
+ padding: Tuple[int, int] = (0, 0),
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ ) -> None:
+ """
+ Args:
+ kernel_size (Tuple): kernel size of the projection layer.
+ stride (Tuple): stride of the projection layer.
+ padding (Tuple): padding size of the projection layer.
+ in_chans (int): Number of input image channels.
+ embed_dim (int): Patch embedding dimension.
+ """
+ super().__init__()
+
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj(x)
+ # B C H W -> B H W C
+ x = x.permute(0, 2, 3, 1)
+ return x
+
+
+@dataclass
+class SAMViTCfg:
+ image_size: Union[Tuple[int, int], int] = 1024
+ width: int = 1024
+ layers: int = 23
+ heads: int = 16
+ patch_size: int = 16
+ window_size: int = 14
+ prompt_embed_dim: int = 256
+ global_attn_indexes: Union[List[int], Tuple[int]] = (5, 11, 17, 23)
+ downsample_channels: Union[List[int], Tuple[int]] = (512, 1024)
+
+
+SAM_MODEL_CONFIG = {
+ "sam_vit_b": {
+ "width": 768,
+ "layers": 12,
+ "heads": 12,
+ "global_attn_indexes": [2, 5, 8, 11],
+ "downsample_channels": (),
+ },
+ "sam_b_downsample": {
+ "width": 768,
+ "layers": 12,
+ "heads": 12,
+ "global_attn_indexes": [2, 5, 8, 11],
+ "downsample_channels": (512, 1024),
+ },
+ "sam_vit_l": {
+ "width": 1024,
+ "layers": 24,
+ "heads": 16,
+ "global_attn_indexes": [5, 11, 17, 23],
+ "downsample_channels": (),
+ },
+ "sam_vit_h": {
+ "width": 1280,
+ "layers": 32,
+ "heads": 16,
+ "global_attn_indexes": [7, 15, 23, 31],
+ "downsample_channels": (),
+ },
+}
+
+
+def create_sam_vit(
+ model_name: str = "sam_b_downsample",
+ image_size: int = 1024,
+ ckpt_path: str = "",
+ **kwargs,
+):
+ assert (
+ model_name in SAM_MODEL_CONFIG.keys()
+ ), f"model name: {model_name} should be in {SAM_MODEL_CONFIG.keys()}"
+
+ sam_cfg = SAMViTCfg(**SAM_MODEL_CONFIG[model_name])
+ image_encoder = ImageEncoderViT(
+ depth=sam_cfg.layers,
+ embed_dim=sam_cfg.width,
+ img_size=image_size,
+ mlp_ratio=4,
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+ num_heads=sam_cfg.heads,
+ patch_size=sam_cfg.patch_size,
+ qkv_bias=True,
+ use_rel_pos=True,
+ global_attn_indexes=sam_cfg.global_attn_indexes,
+ window_size=14,
+ out_chans=sam_cfg.prompt_embed_dim,
+ downsample_channels=sam_cfg.downsample_channels,
+ )
+
+ if ckpt_path:
+ state_dict = torch.load(ckpt_path)
+ image_encoder.load_state_dict(state_dict, strict=False)
+ print(f"SAM-ViT restores from {ckpt_path}")
+
+ return image_encoder
+
+
+if __name__ == "__main__":
+ x = torch.zeros(2, 3, 1024, 1024).bfloat16()
+ # x.permute(0, 3, 1, 2)
+ net = create_sam_vit().bfloat16()
+ out = net(x)
+ print(x.shape, out.shape)
diff --git a/examples/math_prm/ursa_model/siglip_vit.py b/examples/math_prm/ursa_model/siglip_vit.py
new file mode 100644
index 00000000..a93707ba
--- /dev/null
+++ b/examples/math_prm/ursa_model/siglip_vit.py
@@ -0,0 +1,681 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+# https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
+import math
+import warnings
+from dataclasses import dataclass
+from functools import partial
+from typing import (
+ Callable,
+ Dict,
+ Final,
+ List,
+ Literal,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Type,
+ Union,
+)
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.layers import (
+ AttentionPoolLatent,
+ DropPath,
+ LayerType,
+ Mlp,
+ PatchDropout,
+ PatchEmbed,
+ resample_abs_pos_embed,
+)
+from timm.models._manipulate import checkpoint_seq, named_apply
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2,
+ )
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std) # noqa: E741
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.0))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
+ # type: (torch.Tensor, float, float, float, float) -> torch.Tensor
+ r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first
+ convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its orignal dtype.
+ Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn
+ from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+
+ with torch.no_grad():
+ dtype = tensor.dtype
+ tensor_fp32 = tensor.float()
+ tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b)
+ tensor_dtype = tensor_fp32.to(dtype=dtype)
+ tensor.copy_(tensor_dtype)
+
+
+def init_weights(self):
+ if self.pos_embed is not None:
+ trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
+ trunc_normal_(self.latent, std=self.latent_dim**-0.5)
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = "") -> None:
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif hasattr(module, "init_weights"):
+ module.init_weights()
+
+
+class Attention(nn.Module):
+ fused_attn: Final[bool]
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_norm: bool = False,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ norm_layer: nn.Module = nn.LayerNorm,
+ ) -> None:
+ super().__init__()
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim**-0.5
+ # self.fused_attn = use_fused_attn()
+ self.fused_attn = True
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, self.head_dim)
+ .permute(2, 0, 3, 1, 4)
+ )
+ q, k, v = qkv.unbind(0)
+ q, k = self.q_norm(q), self.k_norm(k)
+
+ if self.fused_attn:
+ x = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ dropout_p=self.attn_drop.p if self.training else 0.0,
+ )
+ else:
+ q = q * self.scale
+ attn = q @ k.transpose(-2, -1)
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = attn @ v
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: float = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ qk_norm: bool = False,
+ proj_drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values: Optional[float] = None,
+ drop_path: float = 0.0,
+ act_layer: nn.Module = nn.GELU,
+ norm_layer: nn.Module = nn.LayerNorm,
+ mlp_layer: nn.Module = Mlp,
+ ) -> None:
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_norm=qk_norm,
+ attn_drop=attn_drop,
+ proj_drop=proj_drop,
+ norm_layer=norm_layer,
+ )
+ self.ls1 = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ self.mlp = mlp_layer(
+ in_features=dim,
+ hidden_features=int(dim * mlp_ratio),
+ act_layer=act_layer,
+ drop=proj_drop,
+ )
+ self.ls2 = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
+ return x
+
+
+class VisionTransformer(nn.Module):
+ """Vision Transformer
+
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
+ - https://arxiv.org/abs/2010.11929
+ """
+
+ dynamic_img_size: Final[bool]
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ num_classes: int = 1000,
+ global_pool: Literal["", "avg", "token", "map"] = "token",
+ embed_dim: int = 768,
+ depth: int = 12,
+ num_heads: int = 12,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = True,
+ qk_norm: bool = False,
+ init_values: Optional[float] = None,
+ class_token: bool = True,
+ no_embed_class: bool = False,
+ reg_tokens: int = 0,
+ pre_norm: bool = False,
+ fc_norm: Optional[bool] = None,
+ dynamic_img_size: bool = False,
+ dynamic_img_pad: bool = False,
+ drop_rate: float = 0.0,
+ pos_drop_rate: float = 0.0,
+ patch_drop_rate: float = 0.0,
+ proj_drop_rate: float = 0.0,
+ attn_drop_rate: float = 0.0,
+ drop_path_rate: float = 0.0,
+ weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "",
+ embed_layer: Callable = PatchEmbed,
+ norm_layer: Optional[LayerType] = None,
+ act_layer: Optional[LayerType] = None,
+ block_fn: Type[nn.Module] = Block,
+ mlp_layer: Type[nn.Module] = Mlp,
+ ignore_head: bool = False,
+ ) -> None:
+ """
+ Args:
+ img_size: Input image size.
+ patch_size: Patch size.
+ in_chans: Number of image input channels.
+ num_classes: Mumber of classes for classification head.
+ global_pool: Type of global pooling for final sequence (default: 'token').
+ embed_dim: Transformer embedding dimension.
+ depth: Depth of transformer.
+ num_heads: Number of attention heads.
+ mlp_ratio: Ratio of mlp hidden dim to embedding dim.
+ qkv_bias: Enable bias for qkv projections if True.
+ init_values: Layer-scale init values (layer-scale enabled if not None).
+ class_token: Use class token.
+ no_embed_class: Don't include position embeddings for class (or reg) tokens.
+ reg_tokens: Number of register tokens.
+ fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
+ drop_rate: Head dropout rate.
+ pos_drop_rate: Position embedding dropout rate.
+ attn_drop_rate: Attention dropout rate.
+ drop_path_rate: Stochastic depth rate.
+ weight_init: Weight initialization scheme.
+ embed_layer: Patch embedding layer.
+ norm_layer: Normalization layer.
+ act_layer: MLP activation layer.
+ block_fn: Transformer block layer.
+ """
+ super().__init__()
+ assert global_pool in ("", "avg", "token", "map")
+ assert class_token or global_pool != "token"
+ use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
+ # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
+ # act_layer = get_act_layer(act_layer) or nn.GELU
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+ act_layer = nn.GELU
+
+ self.num_classes = num_classes
+ self.global_pool = global_pool
+ self.num_features = self.embed_dim = (
+ embed_dim # num_features for consistency with other models
+ )
+ self.num_prefix_tokens = 1 if class_token else 0
+ self.num_prefix_tokens += reg_tokens
+ self.num_reg_tokens = reg_tokens
+ self.has_class_token = class_token
+ self.no_embed_class = (
+ no_embed_class # don't embed prefix positions (includes reg)
+ )
+ self.dynamic_img_size = dynamic_img_size
+ self.grad_checkpointing = False
+ self.ignore_head = ignore_head
+
+ embed_args = {}
+ if dynamic_img_size:
+ # flatten deferred until after pos embed
+ embed_args.update(dict(strict_img_size=False, output_fmt="NHWC"))
+ self.patch_embed = embed_layer(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
+ dynamic_img_pad=dynamic_img_pad,
+ **embed_args,
+ )
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = (
+ nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
+ )
+ self.reg_token = (
+ nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
+ )
+ embed_len = (
+ num_patches if no_embed_class else num_patches + self.num_prefix_tokens
+ )
+ self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
+ if patch_drop_rate > 0:
+ self.patch_drop = PatchDropout(
+ patch_drop_rate,
+ num_prefix_tokens=self.num_prefix_tokens,
+ )
+ else:
+ self.patch_drop = nn.Identity()
+ self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
+
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
+ ] # stochastic depth decay rule
+ self.blocks = nn.Sequential(
+ *[
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_norm=qk_norm,
+ init_values=init_values,
+ proj_drop=proj_drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ mlp_layer=mlp_layer,
+ )
+ for i in range(depth)
+ ]
+ )
+ self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
+
+ # Classifier Head
+ if global_pool == "map":
+ AttentionPoolLatent.init_weights = init_weights
+ self.attn_pool = AttentionPoolLatent(
+ self.embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ norm_layer=norm_layer,
+ )
+ else:
+ self.attn_pool = None
+ self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
+ self.head_drop = nn.Dropout(drop_rate)
+ self.head = (
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ )
+
+ if weight_init != "skip":
+ self.init_weights(weight_init)
+
+ def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None:
+ assert mode in ("jax", "jax_nlhb", "moco", "")
+ # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0
+ trunc_normal_(self.pos_embed, std=0.02)
+ if self.cls_token is not None:
+ nn.init.normal_(self.cls_token, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
+ @torch.jit.ignore
+ def no_weight_decay(self) -> Set:
+ return {"pos_embed", "cls_token", "dist_token"}
+
+ @torch.jit.ignore
+ def group_matcher(self, coarse: bool = False) -> Dict:
+ return dict(
+ stem=r"^cls_token|pos_embed|patch_embed", # stem and embed
+ blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
+ )
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable: bool = True) -> None:
+ self.grad_checkpointing = enable
+
+ @torch.jit.ignore
+ def get_classifier(self) -> nn.Module:
+ return self.head
+
+ def reset_classifier(self, num_classes: int, global_pool=None) -> None:
+ self.num_classes = num_classes
+ if global_pool is not None:
+ assert global_pool in ("", "avg", "token", "map")
+ if global_pool == "map" and self.attn_pool is None:
+ assert (
+ False
+ ), "Cannot currently add attention pooling in reset_classifier()."
+ elif global_pool != "map " and self.attn_pool is not None:
+ self.attn_pool = None # remove attention pooling
+ self.global_pool = global_pool
+ self.head = (
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ )
+
+ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
+ if self.dynamic_img_size:
+ B, H, W, C = x.shape
+ pos_embed = resample_abs_pos_embed(
+ self.pos_embed,
+ (H, W),
+ num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
+ )
+ x = x.view(B, -1, C)
+ else:
+ pos_embed = self.pos_embed
+
+ to_cat = []
+ if self.cls_token is not None:
+ to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
+ if self.reg_token is not None:
+ to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
+
+ if self.no_embed_class:
+ # deit-3, updated JAX (big vision)
+ # position embedding does not overlap with class token, add then concat
+ x = x + pos_embed
+ if to_cat:
+ x = torch.cat(to_cat + [x], dim=1)
+ else:
+ # original timm, JAX, and deit vit impl
+ # pos_embed has entry for class token, concat then add
+ if to_cat:
+ x = torch.cat(to_cat + [x], dim=1)
+ x = x + pos_embed
+
+ return self.pos_drop(x)
+
+ def _intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1,
+ ) -> List[torch.Tensor]:
+ outputs, num_blocks = [], len(self.blocks)
+ take_indices = set(
+ range(num_blocks - n, num_blocks) if isinstance(n, int) else n
+ )
+
+ # forward pass
+ x = self.patch_embed(x)
+ x = self._pos_embed(x)
+ x = self.patch_drop(x)
+ x = self.norm_pre(x)
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+ if i in take_indices:
+ outputs.append(x)
+
+ return outputs
+
+ def get_intermediate_layers(
+ self,
+ x: torch.Tensor,
+ n: Union[int, Sequence] = 1,
+ reshape: bool = False,
+ return_prefix_tokens: bool = False,
+ norm: bool = False,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+ """Intermediate layer accessor (NOTE: This is a WIP experiment).
+ Inspired by DINO / DINOv2 interface
+ """
+ # take last n blocks if n is an int, if in is a sequence, select by matching indices
+ outputs = self._intermediate_layers(x, n)
+ if norm:
+ outputs = [self.norm(out) for out in outputs]
+ prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs]
+ outputs = [out[:, self.num_prefix_tokens :] for out in outputs]
+
+ if reshape:
+ grid_size = self.patch_embed.grid_size
+ outputs = [
+ out.reshape(x.shape[0], grid_size[0], grid_size[1], -1)
+ .permute(0, 3, 1, 2)
+ .contiguous()
+ for out in outputs
+ ]
+
+ if return_prefix_tokens:
+ return tuple(zip(outputs, prefix_tokens))
+ return tuple(outputs)
+
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.patch_embed(x)
+ x = self._pos_embed(x)
+ x = self.patch_drop(x)
+ x = self.norm_pre(x)
+ if self.grad_checkpointing and not torch.jit.is_scripting():
+ x = checkpoint_seq(self.blocks, x)
+ else:
+ x = self.blocks(x)
+ x = self.norm(x)
+ return x
+
+ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+ if self.attn_pool is not None:
+ x = self.attn_pool(x)
+ elif self.global_pool == "avg":
+ x = x[:, self.num_prefix_tokens :].mean(dim=1)
+ elif self.global_pool:
+ x = x[:, 0] # class token
+ x = self.fc_norm(x)
+ x = self.head_drop(x)
+ return x if pre_logits else self.head(x)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.forward_features(x)
+ if not self.ignore_head:
+ x = self.forward_head(x)
+ return x
+
+
+@dataclass
+class SigLIPVisionCfg:
+ width: int = 1152
+ layers: Union[Tuple[int, int, int, int], int] = 27
+ heads: int = 16
+ patch_size: int = 14
+ image_size: Union[Tuple[int, int], int] = 336
+ global_pool: str = "map"
+ mlp_ratio: float = 3.7362
+ class_token: bool = False
+ num_classes: int = 0
+ use_checkpoint: bool = False
+
+
+SigLIP_MODEL_CONFIG = {
+ "siglip_so400m_patch14_384": {
+ "image_size": 336,
+ "patch_size": 14,
+ "width": 1152,
+ "layers": 27,
+ "heads": 16,
+ "mlp_ratio": 3.7362,
+ "global_pool": "map",
+ "use_checkpoint": False,
+ },
+ "siglip_so400m_patch14_224": {
+ "image_size": 224,
+ "patch_size": 14,
+ "width": 1152,
+ "layers": 27,
+ "heads": 16,
+ "mlp_ratio": 3.7362,
+ "global_pool": "map",
+ "use_checkpoint": False,
+ },
+ "siglip_large_patch16_384": {
+ "image_size": 384,
+ "patch_size": 16,
+ "width": 1024,
+ "layers": 24,
+ "heads": 16,
+ "mlp_ratio": 4,
+ "global_pool": "map",
+ "use_checkpoint": False,
+ },
+}
+
+
+def create_siglip_vit(
+ model_name: str = "siglip_so400m_patch14_384",
+ image_size: int = 384,
+ select_layer: int = -1,
+ ckpt_path: str = "",
+ **kwargs,
+):
+ assert (
+ model_name in SigLIP_MODEL_CONFIG.keys()
+ ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}"
+
+ vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name])
+
+ if select_layer <= 0:
+ layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
+ else:
+ layers = min(vision_cfg.layers, select_layer)
+
+ model = VisionTransformer(
+ img_size=image_size,
+ patch_size=vision_cfg.patch_size,
+ embed_dim=vision_cfg.width,
+ depth=layers,
+ num_heads=vision_cfg.heads,
+ mlp_ratio=vision_cfg.mlp_ratio,
+ class_token=vision_cfg.class_token,
+ global_pool=vision_cfg.global_pool,
+ ignore_head=kwargs.get("ignore_head", True),
+ weight_init=kwargs.get("weight_init", "skip"),
+ num_classes=0,
+ )
+
+ if ckpt_path:
+ state_dict = torch.load(ckpt_path, map_location="cpu")
+
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
+ print(
+ f"SigLIP-ViT restores from {ckpt_path},\n"
+ f"\tincompatible_keys:', {incompatible_keys}."
+ )
+
+ return model
diff --git a/examples/math_prm/ursa_variant2.py b/examples/math_prm/ursa_variant2.py
new file mode 100644
index 00000000..3e376044
--- /dev/null
+++ b/examples/math_prm/ursa_variant2.py
@@ -0,0 +1,441 @@
+"""URSA paper Eq.9 strict-alignment advantage estimator.
+
+Paper: arXiv 2501.04686 (NeurIPS 2025), Appendix B.1 Eq.9 — the second
+straw-man variant the paper considers (and ultimately *rejects* in favour
+of PS-GRPO):
+
+ A_t^i = r_{s,t}^i * GroupNorm_G(r̄_s^i) (process-reward term)
+ + GroupNorm_G(r_o^i) (outcome-reward term)
+
+where t indexes *steps* (not tokens), r_{s,t}^i is the sigmoid PRM score for
+step t in trajectory i, r̄_s^i = mean_t r_{s,t}^i is the per-trajectory mean
+PRM score, r_o^i ∈ {0,1} is the outcome reward, and GroupNorm_G is
+(x - mean_G(x)) / std_G(x) over the G=K trajectories sampled from the same
+prompt. The token-level A_t is broadcast to every token spanned by step t.
+
+This file is intentionally self-contained in ``examples/math_prm/`` and does
+**not** modify any code under ``lightrft/``. It registers a new estimator
+``ursa_variant2`` by monkey-patching
+``lightrft.trainer.advantage_calculator.get_advantage_calculator`` at import
+time. The patch is idempotent.
+
+Why a separate path, not a flag on the existing per-step PRM path:
+the legacy ``per_step_reward_mode`` path (still useful as Math-Shepherd-style
+step-MC return) goes through ``compute_reward`` Mode B + reverse-cumsum +
+GroupNormCalculator. That fully bypasses the outcome reward and uses
+cumulative returns, both of which contradict Eq.9. Keeping the two paths
+side by side allows ablation between paper-strict (this estimator) and
+Math-Shepherd-style (legacy).
+"""
+
+from __future__ import annotations
+
+from typing import Dict, List, Optional, Tuple
+
+import torch
+
+from lightrft.trainer.advantage_calculator import (
+ AdvantageCalculator,
+ compute_clip_fraction,
+)
+
+
+class UrsaVariant2Calculator(AdvantageCalculator):
+ """Strict paper Eq.9 implementation.
+
+ Reads per-trajectory step PRM scores and outcome reward from
+ ``experience.info`` (already emitted by ``MathPRMReward.forward`` for
+ label ``"math_per_step_prm"``), does its own GroupNorm, and writes
+ a per-token advantage tensor where every token within the span of
+ step k carries A_k. No cumulative return. ``returns`` mirrors
+ ``advantages`` (no separate value function).
+ """
+
+ _ESTIMATOR_NAME = "ursa_variant2"
+
+ def preprocess_rewards(
+ self,
+ rewards: torch.Tensor,
+ experiences: List,
+ max_new_tokens: int,
+ ) -> Tuple[List, List[torch.Tensor]]:
+ """Compute GroupNormed (r̄_s, r_o) across all G trajectories.
+
+ ``rewards`` is the concatenated per-trajectory scalar reward from
+ every experience in the batch — for ``math_per_step_prm`` rows this
+ is ``outcome_correct ∈ {0,1}`` (see ``reward_models.py:655``).
+ ``experiences`` lets us gather per-trajectory step_rewards needed
+ to compute r̄_s^i. We write the GroupNormed values back into each
+ experience's ``info`` under reserved keys ``_ursa_oc_normed`` and
+ ``_ursa_msp_normed`` so ``compute()`` can pick them up later.
+
+ Aborts the variant-2 path (returns identity-chunked rewards
+ without touching info) when ``n_samples_per_prompt < 2`` — a
+ single-sample group has std = 0 and ``A_t`` would collapse.
+ """
+ config = self.config
+ n_samples = int(getattr(config, "n_samples_per_prompt", 1) or 1)
+
+ # Identity preprocessing if K<2 — variant 2 needs a group to normalize.
+ # We still chunk rewards back so the downstream contract is preserved.
+ if n_samples < 2:
+ reward_chunks = rewards.chunk(len(experiences)) if len(experiences) > 0 else []
+ return experiences, list(reward_chunks)
+
+ device = rewards.device
+ total_B = rewards.numel()
+ if total_B % n_samples != 0:
+ # Cannot group — bail out gracefully, fall back to identity.
+ reward_chunks = rewards.chunk(len(experiences))
+ return experiences, list(reward_chunks)
+
+ # Compute r̄_s^i for every trajectory across experiences.
+ mean_step_prm_chunks: List[torch.Tensor] = []
+ per_exp_traj_counts: List[int] = []
+ for exp in experiences:
+ sr_list = exp.info.get("step_rewards")
+ n_traj = int(exp.info["reward"].numel())
+ per_exp_traj_counts.append(n_traj)
+ if sr_list is None:
+ # No per-step data — treat r̄_s as zero (variant 2 will rely
+ # on the outcome-norm anchor only for that trajectory).
+ mean_step_prm_chunks.append(
+ torch.zeros(n_traj, dtype=torch.float32, device=device)
+ )
+ continue
+ means: List[torch.Tensor] = []
+ for sr in sr_list:
+ if sr.numel() > 0:
+ means.append(sr.to(device=device, dtype=torch.float32).mean())
+ else:
+ means.append(torch.zeros((), dtype=torch.float32, device=device))
+ if len(means) != n_traj:
+ # Misaligned bookkeeping — fall back per-traj to zero.
+ pad = [torch.zeros((), dtype=torch.float32, device=device)] * (
+ n_traj - len(means)
+ )
+ means = means + pad
+ means = means[:n_traj]
+ mean_step_prm_chunks.append(torch.stack(means).to(device=device))
+ mean_step_prm = torch.cat(mean_step_prm_chunks, dim=0) # (total_B,)
+
+ # GroupNorm both terms across G=K siblings (paper Eq.9 footer text).
+ oc_flat = rewards.to(device=device, dtype=torch.float32)
+ oc_g = oc_flat.reshape(-1, n_samples)
+ oc_normed = (
+ (oc_g - oc_g.mean(dim=-1, keepdim=True))
+ / (oc_g.std(dim=-1, unbiased=False, keepdim=True) + 1e-9)
+ ).flatten()
+
+ msp_g = mean_step_prm.reshape(-1, n_samples)
+ msp_normed = (
+ (msp_g - msp_g.mean(dim=-1, keepdim=True))
+ / (msp_g.std(dim=-1, unbiased=False, keepdim=True) + 1e-9)
+ ).flatten()
+
+ # Scatter normed values back per-experience (keep CPU-side view to avoid
+ # device contention when compute() runs in a different stream).
+ offset = 0
+ for exp, n_traj in zip(experiences, per_exp_traj_counts):
+ exp.info["_ursa_oc_normed"] = oc_normed[offset:offset + n_traj].clone().cpu()
+ exp.info["_ursa_msp_normed"] = msp_normed[offset:offset + n_traj].clone().cpu()
+ exp.info["_ursa_mean_step_prm_raw"] = mean_step_prm[offset:offset + n_traj].clone().cpu()
+ offset += n_traj
+
+ # Default behaviour: chunk the (unmodified) per-trajectory rewards back
+ # to per-experience tensors — ``compute()`` ignores this anyway.
+ reward_chunks = oc_flat.chunk(len(experiences)) if len(experiences) > 0 else []
+ return experiences, list(reward_chunks)
+
+ def compute(
+ self,
+ experience,
+ final_reward: torch.Tensor,
+ gamma: Optional[float],
+ generate_kwargs: Dict,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
+ """Build per-token advantages via paper Eq.9.
+
+ Ignores ``final_reward`` (which carries the legacy Mode B step
+ scatter + KL — orthogonal to Eq.9). KL is still applied separately
+ by the surrounding ``--use_kl_loss`` path; we only own the
+ advantage shape here.
+ """
+ action_mask = experience.action_mask
+ if action_mask is None:
+ raise ValueError(
+ "UrsaVariant2Calculator requires action_mask (token-level "
+ "broadcast over step spans is undefined without it)."
+ )
+
+ device = action_mask.device
+ B, T = action_mask.shape
+
+ info = experience.info
+ oc_normed = info.get("_ursa_oc_normed")
+ msp_normed = info.get("_ursa_msp_normed")
+ if oc_normed is None or msp_normed is None:
+ # preprocess_rewards bailed (K<2 or shape mismatch) — fall back
+ # to a degenerate advantage tensor of zeros so loss stays finite.
+ advantages = torch.zeros(B, T, device=device, dtype=torch.float32)
+ returns = advantages.clone()
+ return advantages, returns, {"ursa_v2_fallback_used": 1.0}
+
+ oc_normed = oc_normed.to(device=device, dtype=torch.float32)
+ msp_normed = msp_normed.to(device=device, dtype=torch.float32)
+
+ step_rewards_list = info.get("step_rewards") or []
+ step_indices_list = info.get("step_token_indices") or []
+
+ # One-shot diagnostic on first invocation (rank0 only). Two purposes:
+ # (1) verifies step_rewards / step_token_indices actually reached
+ # compute() — easy to miss otherwise if something upstream
+ # drops the lists (cf. the multi-RM aggregator drop fix).
+ # (2) dumps the full paper Eq.9 chain on a real trajectory so the
+ # smoke log carries acceptance evidence for AC1+AC8 directly.
+ if not getattr(UrsaVariant2Calculator, "_dumped_first_call", False):
+ try:
+ import torch.distributed as dist
+ rank = dist.get_rank() if dist.is_initialized() else 0
+ except Exception:
+ rank = 0
+ if rank == 0:
+ step_rewards_keys = bool(info.get("step_rewards"))
+ step_indices_keys = bool(info.get("step_token_indices"))
+ sr_lens = ([t.numel() for t in (info.get("step_rewards") or [])][:8])
+ sti_lens = ([t.numel() for t in (info.get("step_token_indices") or [])][:8])
+ print(f"[ursa_v2:compute] first call rank=0 B={B} T={T} "
+ f"has_step_rewards={step_rewards_keys} "
+ f"has_step_token_indices={step_indices_keys} "
+ f"sr_lens(first8)={sr_lens} sti_lens(first8)={sti_lens} "
+ f"info keys={sorted([k for k in info.keys() if not k.startswith('_')])}",
+ flush=True)
+ # Full Eq.9 chain dump — every intermediate value so a
+ # reviewer can verify the implementation matches the paper
+ # formula on real PRM output.
+ if info.get("step_rewards"):
+ sr_lst = info["step_rewards"]
+ sti_lst = info["step_token_indices"]
+ outcome = info["reward"].float()
+ K = self.config.n_samples_per_prompt
+ print(f"[ursa_v2:chain] === paper Eq.9 chain on real PRM output (K={K}) ===", flush=True)
+ print(f"[ursa_v2:chain] outcome (r_o per traj) = {outcome.tolist()}", flush=True)
+ r_bar = torch.stack([t.float().mean() if t.numel() > 0
+ else torch.tensor(0.0) for t in sr_lst])
+ print(f"[ursa_v2:chain] r_bar_s (mean step PRM)= {r_bar.tolist()}", flush=True)
+ print(f"[ursa_v2:chain] msp_normed (post GN) = {msp_normed.tolist()}", flush=True)
+ print(f"[ursa_v2:chain] oc_normed (post GN) = {oc_normed.tolist()}", flush=True)
+ for i in range(min(B, K)):
+ sr_i = sr_lst[i].float().tolist()
+ si_i = sti_lst[i].long().tolist()
+ a_steps = [float(r) * float(msp_normed[i]) + float(oc_normed[i])
+ for r in sr_i]
+ print(f"[ursa_v2:chain] traj {i}: r_o={float(outcome[i]):+.2f} "
+ f"r_bar={float(r_bar[i]):.4f} msp_normed={float(msp_normed[i]):+.4f} "
+ f"oc_normed={float(oc_normed[i]):+.4f}", flush=True)
+ for k, (r, idx, a) in enumerate(zip(sr_i, si_i, a_steps)):
+ print(f"[ursa_v2:chain] step {k+1}: r_s={r:.4f} "
+ f"end_token={idx:4d} A_step={r:.4f}·{float(msp_normed[i]):+.4f} + "
+ f"{float(oc_normed[i]):+.4f} = {a:+.4f}", flush=True)
+ UrsaVariant2Calculator._dumped_first_call = True
+
+ advantages = torch.zeros(B, T, device=device, dtype=torch.float32)
+ per_traj_step_count = []
+ for i in range(B):
+ has_steps = (
+ i < len(step_rewards_list)
+ and step_rewards_list[i].numel() > 0
+ and i < len(step_indices_list)
+ and step_indices_list[i].numel() == step_rewards_list[i].numel()
+ )
+ if not has_steps:
+ # No step data — degenerate to outcome-only term spread over
+ # the response (matches paper's natural limit when n_steps=0
+ # since the process-reward term vanishes).
+ advantages[i] = oc_normed[i] * action_mask[i].to(torch.float32)
+ per_traj_step_count.append(0)
+ continue
+
+ sr = step_rewards_list[i].to(device=device, dtype=torch.float32) # (n_steps,)
+ si = step_indices_list[i].to(device=device, dtype=torch.long) # (n_steps,) END idx
+ n_steps = sr.numel()
+ per_traj_step_count.append(int(n_steps))
+
+ # Span starts: 0 for step 0, end_{k-1}+1 for k > 0
+ starts = torch.cat([
+ torch.zeros(1, dtype=torch.long, device=device),
+ si[:-1] + 1,
+ ])
+ ends = si
+
+ # Per-step advantage: A_k = r_{s,k} * msp_normed[i] + oc_normed[i]
+ # (paper Eq.9)
+ A_steps = sr * msp_normed[i] + oc_normed[i] # (n_steps,)
+
+ for k in range(n_steps):
+ sk = max(0, int(starts[k].item()))
+ ek = min(T - 1, int(ends[k].item()))
+ if sk > ek:
+ continue
+ advantages[i, sk:ek + 1] = A_steps[k]
+
+ # Tokens past the last step boundary (e.g. final `†Answer:` line
+ # tokens) are not covered by any step. Per paper Eq.9 the second
+ # term is t-independent, so we still apply oc_normed[i] there
+ # to give the model an outcome-only signal on the tail. This
+ # matches the implicit reading that the outcome anchor lives
+ # on the whole trajectory while step rewards live on steps.
+ last_end = int(ends[-1].item()) if n_steps > 0 else -1
+ if last_end + 1 < T:
+ advantages[i, last_end + 1:] = oc_normed[i]
+
+ # Respect the response action mask everywhere.
+ advantages = advantages * action_mask.to(torch.float32)
+ returns = advantages.clone()
+
+ # Per-step credit diagnostics (these flow into the trainer's wandb
+ # under `train/`-style keys via the existing info_dict pipeline).
+ n_valid = action_mask.sum().clamp(min=1).to(torch.float32)
+ info_dict: Dict[str, float] = {
+ # Restrict the *_frac counters to valid (un-masked) tokens so they
+ # don't include padding-induced zeros in the denominator's response
+ # area. n_valid is action_mask.sum(); we mask both numerator and
+ # event-set to (action_mask == 1).
+ "ursa_v2_adv_pos_frac": ((advantages > 0) & action_mask.bool()).to(torch.float32).sum().item() / n_valid.item(),
+ "ursa_v2_adv_neg_frac": ((advantages < 0) & action_mask.bool()).to(torch.float32).sum().item() / n_valid.item(),
+ "ursa_v2_adv_zero_frac": ((advantages == 0) & action_mask.bool()).to(torch.float32).sum().item() / n_valid.item(),
+ "ursa_v2_adv_abs_mean": advantages.abs().sum().item() / n_valid.item(),
+ "ursa_v2_oc_normed_std": oc_normed.std(unbiased=False).item() if oc_normed.numel() > 1 else 0.0,
+ "ursa_v2_msp_normed_std": msp_normed.std(unbiased=False).item() if msp_normed.numel() > 1 else 0.0,
+ "ursa_v2_traj_step_count_mean": (
+ sum(per_traj_step_count) / max(1, len(per_traj_step_count))
+ ),
+ }
+
+ # Advantage clipping (config knob, optional).
+ if getattr(self.config, "advantage_clip", 0) > 0:
+ clip_val = self.config.advantage_clip
+ info_dict["advantage_clip_frac"] = compute_clip_fraction(
+ advantages, clip_val, -clip_val
+ )
+ advantages = torch.clamp(advantages, -clip_val, clip_val)
+
+ return advantages, returns, info_dict
+
+
+def _install_aggregate_rewards_patch() -> None:
+ """Forward step_rewards / step_token_indices through the multi-RM aggregator.
+
+ Background: ``examples/math_prm/reward_models_utils.load_reward_models``
+ returns reward_models as a List[nn.Module] even when there is only one
+ RM. That makes ``fast_exp_maker._aggregate_rewards`` take the
+ ``is_multi_rm=True`` branch, which writes ``outputs[i].rewards`` and
+ ``outputs[i].reward_metrics`` but — by design — drops the per-step
+ variable-length fields. That's correct for true multi-RM aggregation
+ (where combining variable-length step tensors across RMs is ill-
+ defined), but it silently breaks the single-list-of-one-RM case that
+ this example uses.
+
+ Patch: after the original ``_aggregate_rewards`` runs, scan for the
+ "single underlying RM but exposed as a 1-list" pattern and lift the
+ step_rewards / step_token_indices from that one RM's batch result
+ into ``outputs[i]``. No behaviour change for true multi-RM setups.
+ """
+ from lightrft.trainer import fast_exp_maker as _fem
+
+ # _aggregate_rewards lives on RewardComputationEngine (separate class
+ # from FastExperienceMaker; reachable via fast_exp_maker.RewardComputationEngine
+ # or self.reward_engine on the maker).
+ _RewardEngine = getattr(_fem, "RewardComputationEngine", None)
+ if _RewardEngine is None or not hasattr(_RewardEngine, "_aggregate_rewards"):
+ return
+ if getattr(_RewardEngine, "_ursa_v2_aggregator_patched", False):
+ return
+
+ _original = _RewardEngine._aggregate_rewards
+
+ def _aggregate_rewards_patched(self, outputs, all_rewards_list, is_multi_rm):
+ _original(self, outputs, all_rewards_list, is_multi_rm)
+ if not is_multi_rm:
+ return
+ # If multiple RMs actually produced step_rewards we don't know how to
+ # merge them — bail (keep lightrft's safe default).
+ rms_with_steps = [
+ rm_idx for rm_idx in range(len(all_rewards_list))
+ if any(getattr(r, "step_rewards", None) is not None
+ for r in all_rewards_list[rm_idx])
+ ]
+ if len(rms_with_steps) != 1:
+ return
+ rm_idx = rms_with_steps[0]
+ for mb_idx in range(len(outputs)):
+ res = all_rewards_list[rm_idx][mb_idx]
+ sr = getattr(res, "step_rewards", None)
+ sti = getattr(res, "step_token_indices", None)
+ if sr is not None and getattr(outputs[mb_idx], "step_rewards", None) is None:
+ outputs[mb_idx].step_rewards = sr
+ if sti is not None and getattr(outputs[mb_idx], "step_token_indices", None) is None:
+ outputs[mb_idx].step_token_indices = sti
+
+ _aggregate_rewards_patched._ursa_v2_patched = True
+ _RewardEngine._aggregate_rewards = _aggregate_rewards_patched
+ _RewardEngine._ursa_v2_aggregator_patched = True
+
+
+def _install_get_advantage_calculator_patch() -> None:
+ """Idempotently inject ``ursa_variant2`` into lightrft's calculator factory.
+
+ Done from examples/ rather than editing ``lightrft/`` to keep the new
+ estimator strictly contained in this example. The patch wraps the
+ original factory; unknown names still raise the original ValueError
+ listing the *original* supported set + this estimator.
+
+ Important: we patch every module that has already done
+ ``from .advantage_calculator import get_advantage_calculator`` because
+ those imports bind the original function object into the consumer
+ module's namespace — patching just the source module would miss them.
+ """
+ from lightrft.trainer import advantage_calculator as _ac
+
+ if getattr(_ac.get_advantage_calculator, "_ursa_v2_patched", False):
+ return
+
+ _original = _ac.get_advantage_calculator
+
+ def get_advantage_calculator_patched(estimator_name: str, config):
+ if estimator_name == UrsaVariant2Calculator._ESTIMATOR_NAME:
+ return UrsaVariant2Calculator(config)
+ return _original(estimator_name, config)
+
+ get_advantage_calculator_patched._ursa_v2_patched = True
+ _ac.get_advantage_calculator = get_advantage_calculator_patched
+
+ # Also patch known consumers that did ``from .advantage_calculator import
+ # get_advantage_calculator`` (binding the original ref into their own
+ # namespace). Currently fast_exp_maker is the only such consumer; if more
+ # appear later, list them here.
+ import sys
+ for mod_name in ("lightrft.trainer.fast_exp_maker",):
+ mod = sys.modules.get(mod_name)
+ if mod is not None and hasattr(mod, "get_advantage_calculator"):
+ mod.get_advantage_calculator = get_advantage_calculator_patched
+
+
+def register_ursa_variant2() -> None:
+ """Install both monkey-patches so ``--advantage_estimator ursa_variant2``
+ becomes a valid option and the multi-RM aggregator forwards step_rewards.
+
+ Idempotent. The two underlying ``_install_*_patch`` helpers each guard
+ themselves with a sentinel attribute, so calling this multiple times
+ (e.g. from both ``math_prm_trainer`` and a future user-side import) is
+ safe.
+ """
+ _install_get_advantage_calculator_patch()
+ _install_aggregate_rewards_patch()
+
+
+# Also install on import so existing call-sites that rely on the side-effect
+# behaviour (``import ursa_variant2`` near the top of ``math_prm_trainer``)
+# still work. New code should prefer the explicit ``register_ursa_variant2()``
+# entry point.
+register_ursa_variant2()
diff --git a/lightrft/models/actor_language.py b/lightrft/models/actor_language.py
index 912e93d3..244ebc90 100644
--- a/lightrft/models/actor_language.py
+++ b/lightrft/models/actor_language.py
@@ -210,10 +210,14 @@ def generate(
use_cache=True,
num_beams=kwargs.get("num_beams", 1),
attention_mask=kwargs.get("attention_mask"),
+ logits_processor=kwargs.get("logits_processor"),
eos_token_id=kwargs.get("eos_token_id"),
pad_token_id=kwargs.get("pad_token_id"),
min_new_tokens=kwargs.get("min_new_tokens", 1),
+ repetition_penalty=kwargs.get("repetition_penalty", 1.0),
)
+ if kwargs.get("no_repeat_ngram_size", 0) > 0:
+ generate_args["no_repeat_ngram_size"] = kwargs["no_repeat_ngram_size"]
if kwargs.get("max_new_tokens") is not None:
generate_args["max_new_tokens"] = kwargs["max_new_tokens"]
if kwargs.get("max_length") is not None:
diff --git a/lightrft/models/actor_vl.py b/lightrft/models/actor_vl.py
index 878c611e..4d9fce40 100644
--- a/lightrft/models/actor_vl.py
+++ b/lightrft/models/actor_vl.py
@@ -19,6 +19,7 @@
- MoE (Mixture of Experts) model support
"""
+import inspect
from typing import Optional, Tuple, Union
import torch
@@ -85,6 +86,40 @@ class ActorVL(nn.Module):
# Model modality declaration - defines what types of inputs this model accepts
modality = ActorModality.VISION_LANGUAGE
+ def _get_model_dtype(self) -> Optional[torch.dtype]:
+ model_dtype = getattr(self.model, "dtype", None)
+ if isinstance(model_dtype, torch.dtype):
+ return model_dtype
+ for parameter in self.model.parameters():
+ if torch.is_floating_point(parameter):
+ return parameter.dtype
+ return None
+
+ def _cast_multimodal_tensor(self, value: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ if value is None or not torch.is_tensor(value) or not torch.is_floating_point(value):
+ return value
+ model_dtype = self._get_model_dtype()
+ if model_dtype is None or value.dtype == model_dtype:
+ return value
+ return value.to(dtype=model_dtype)
+
+ def _supports_model_kwarg(self, kwarg_name: str, *, generation: bool = False) -> bool:
+ targets = []
+ if generation:
+ targets.append(getattr(self.model, "prepare_inputs_for_generation", None))
+ targets.append(getattr(self.model, "forward", None))
+
+ for target in targets:
+ if target is None:
+ continue
+ try:
+ parameters = inspect.signature(target).parameters
+ except (TypeError, ValueError):
+ continue
+ if kwarg_name in parameters:
+ return True
+ return False
+
def __init__(
self,
pretrain_or_model,
@@ -102,6 +137,7 @@ def __init__(
) -> None:
super().__init__()
self.high_entropy_token_ratio = high_entropy_token_ratio
+ self.packing_samples = packing_samples
if isinstance(pretrain_or_model, str):
self.pretrain_or_model = pretrain_or_model
@@ -151,8 +187,6 @@ def __init__(
# Use `model.generate(use_cache=True)` instead.`
self.model.config.use_cache = False
- # packing samples using Flash Attention 2
- self.packing_samples = packing_samples
else:
self.model = pretrain_or_model
self.pretrain_or_model = pretrain_or_model.config.model_type
@@ -206,9 +240,9 @@ def generate(
"""
generate_args = {
"input_ids": input_ids,
- "pixel_values": pixel_values,
+ "pixel_values": self._cast_multimodal_tensor(pixel_values),
"image_grid_thw": image_grid_thw,
- "pixel_values_videos": pixel_values_videos,
+ "pixel_values_videos": self._cast_multimodal_tensor(pixel_values_videos),
"video_grid_thw": video_grid_thw,
"top_k": kwargs.get("top_k", None),
"top_p": kwargs.get("top_p", None),
@@ -218,16 +252,26 @@ def generate(
"use_cache": True,
"num_beams": kwargs.get("num_beams", 1),
"attention_mask": kwargs.get("attention_mask"),
+ "logits_processor": kwargs.get("logits_processor"),
"eos_token_id": kwargs.get("eos_token_id"),
"pad_token_id": kwargs.get("pad_token_id"),
"min_new_tokens": kwargs.get("min_new_tokens", 1),
+ "repetition_penalty": kwargs.get("repetition_penalty", 1.0),
}
+ if kwargs.get("no_repeat_ngram_size", 0) > 0:
+ generate_args["no_repeat_ngram_size"] = kwargs["no_repeat_ngram_size"]
if kwargs.get("max_new_tokens", None):
generate_args["max_new_tokens"] = kwargs.get("max_new_tokens")
if kwargs.get("max_length", None):
generate_args["max_length"] = kwargs.get("max_length")
+ for model_kwarg in ("pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"):
+ if model_kwarg in generate_args and (
+ generate_args[model_kwarg] is None or not self._supports_model_kwarg(model_kwarg, generation=True)
+ ):
+ generate_args.pop(model_kwarg)
+
# Call generate
sequences = self.model.generate(**generate_args)
@@ -306,14 +350,21 @@ def forward(
# explicitly ignore attention_mask for packing_samples
attention_mask = None
+ forward_kwargs = {
+ "attention_mask": attention_mask,
+ "position_ids": position_ids,
+ "pixel_values": self._cast_multimodal_tensor(pixel_values),
+ "image_grid_thw": image_grid_thw,
+ "pixel_values_videos": self._cast_multimodal_tensor(pixel_values_videos),
+ "video_grid_thw": video_grid_thw,
+ }
+ for model_kwarg in ("pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"):
+ if not self._supports_model_kwarg(model_kwarg):
+ forward_kwargs.pop(model_kwarg, None)
+
output = self.model(
sequences,
- attention_mask=attention_mask,
- position_ids=position_ids,
- pixel_values=pixel_values,
- image_grid_thw=image_grid_thw,
- pixel_values_videos=pixel_values_videos,
- video_grid_thw=video_grid_thw,
+ **forward_kwargs,
)
if num_actions is None: # defult
diff --git a/lightrft/models/loss.py b/lightrft/models/loss.py
index d593cb93..a3134e74 100644
--- a/lightrft/models/loss.py
+++ b/lightrft/models/loss.py
@@ -176,6 +176,22 @@ def __init__(
self.use_dapo = use_dapo
self.use_cpg_loss = use_cpg_loss
self.high_entropy_token_ratio = high_entropy_token_ratio
+ # Per-forward diagnostic stats for ratio behavior. Populated each time
+ # forward() runs in PPO mode and read by the trainer via get_last_stats().
+ # CPGD mode has no ratio so these stay empty.
+ self._last_stats: dict = {}
+
+ def get_last_stats(self) -> dict:
+ """Return per-token ratio diagnostics from the most recent PPO forward.
+
+ Keys: ``ratio_mean``, ``ratio_max``, ``ratio_min``, ``clipfrac``,
+ ``approx_kl``. ``clipfrac`` is the fraction of valid action tokens
+ whose unclipped ratio fell outside ``[1 - clip_eps, 1 + clip_eps]``.
+ ``approx_kl`` is the K2 estimator ``0.5 * mean((log p - log p_old)^2)``
+ which is a low-variance proxy for the per-step "old vs new" KL
+ (different from the actor-vs-reference KL controller signal).
+ """
+ return dict(self._last_stats)
def forward(
self,
@@ -246,12 +262,44 @@ def forward(
return loss
# PPO loss
- ratio = (log_probs - old_log_probs).exp()
+ log_ratio = log_probs - old_log_probs
+ ratio = log_ratio.exp()
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
loss = -torch.min(surr1, surr2)
loss = masked_mean(loss, final_mask, dim=-1).mean()
+ # Diagnostic stats over valid action tokens only.
+ # Detached so they don't accidentally enter the autograd graph.
+ with torch.no_grad():
+ if final_mask is None:
+ m = torch.ones_like(ratio, dtype=torch.bool)
+ else:
+ m = final_mask.bool()
+ r_valid = ratio[m]
+ lr_valid = log_ratio[m]
+ if r_valid.numel() == 0:
+ self._last_stats = {
+ "ratio_mean": 0.0,
+ "ratio_max": 0.0,
+ "ratio_min": 0.0,
+ "clipfrac": 0.0,
+ "approx_kl": 0.0,
+ }
+ else:
+ # `clipfrac` counts the tokens whose UNCLIPPED ratio is outside
+ # [1-eps, 1+eps]. High clipfrac means PPO is suppressing many
+ # gradient updates this step, which is a signal that the new
+ # policy has moved noticeably from the rollout policy.
+ clipped = (r_valid > 1 + self.clip_eps) | (r_valid < 1 - self.clip_eps)
+ self._last_stats = {
+ "ratio_mean": r_valid.float().mean().item(),
+ "ratio_max": r_valid.float().max().item(),
+ "ratio_min": r_valid.float().min().item(),
+ "clipfrac": clipped.float().mean().item(),
+ "approx_kl": 0.5 * lr_valid.float().pow(2).mean().item(),
+ }
+
return loss
diff --git a/lightrft/models/utils.py b/lightrft/models/utils.py
index 15f9c75f..fa8de9f5 100644
--- a/lightrft/models/utils.py
+++ b/lightrft/models/utils.py
@@ -224,6 +224,24 @@ def log_probs_from_logits(
>>> log_probs.shape
torch.Size([2, 3])
"""
+ # PyTorch's torch.gather(dim=-1, index=...) does NOT require non-dim
+ # axes to match: when ``logits`` has more rows than ``labels``, gather
+ # silently truncates to ``len(labels)`` instead of raising. That made it
+ # impossible to spot a VLM-specific alignment bug where ``output["logits"]``
+ # is longer than ``sequences`` (image placeholder gets expanded into N
+ # vision-patch tokens during the LM forward) — see PR #53. Reject the
+ # mismatch up-front so any future caller using this helper crashes loudly
+ # and is forced to align logits to labels at the call site.
+ if logits.shape[:-1] != labels.shape:
+ raise ValueError(
+ "log_probs_from_logits: logits and labels must have matching "
+ f"non-vocab shapes. Got logits.shape={tuple(logits.shape)}, "
+ f"labels.shape={tuple(labels.shape)}. For VLMs, output['logits'] "
+ "may be longer than the input sequences because vision tokens "
+ "expand placeholders during the forward pass — slice the logits "
+ "to the action range before calling this helper."
+ )
+
if logits.dtype in [torch.float32, torch.float64]:
batch_dim = logits.shape[:-1]
last_dim = logits.shape[-1]
@@ -444,14 +462,24 @@ def compute_reward(
action_mask: Optional[torch.Tensor] = None,
num_actions: Optional[Union[int, list[int]]] = None,
reward_clip_range: Tuple[float, float] = None,
+ step_rewards: Optional[torch.Tensor] = None,
+ step_token_indices: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, list[torch.Tensor]]:
"""
Compute final reward by combining base reward with KL penalty.
Combines base reward with KL divergence penalty to encourage policy stability.
- Supports two modes: with action mask (efficient) and without (individual processing).
-
- :param r: Base reward tensor or scalar
+ Supports three modes:
+ A. trajectory-scalar (default, legacy behavior): scatters scalar `r[i]`
+ to the EOS position of each row. Used by ORM-style RL.
+ B. per-step (NEW, opt-in): scatters multiple step rewards to step
+ boundary token positions for true per-step credit assignment. Used
+ by PRM-RL methods like Math-Shepherd / variant 2 of URSA paper.
+ Activated when `step_rewards` is not None.
+ C. no action_mask: per-row variable-length list mode (legacy).
+
+ :param r: Base reward tensor of shape (B,) or scalar. In per-step mode,
+ used as a fallback only if step_rewards is None for any row.
:type r: Union[torch.Tensor, float]
:param kl_coef: KL penalty coefficient (<=0 disables penalty)
:type kl_coef: float
@@ -463,11 +491,23 @@ def compute_reward(
:type num_actions: Optional[Union[int, list[int]]]
:param reward_clip_range: (min, max) to clip base reward
:type reward_clip_range: Tuple[float, float]
-
- :return: Final reward tensor or list
+ :param step_rewards: PER-STEP rewards of shape (B, max_steps) padded with
+ any value (only positions in `step_token_indices` are read). When
+ provided together with ``step_token_indices``, mode (B) is enabled
+ and the scalar ``r`` EOS-scatter is bypassed.
+ :type step_rewards: Optional[torch.Tensor]
+ :param step_token_indices: Token indices in the action / response space
+ of shape (B, max_steps). Positions with value < 0 are treated as
+ padding and skipped. ``step_token_indices[i, k]`` = boundary token
+ index in sequence i for step k; ``step_rewards[i, k]`` is scattered
+ to that position (NOT the EOS) so cumulative-returns can propagate
+ per-step credit.
+ :type step_token_indices: Optional[torch.Tensor]
+
+ :return: Final reward tensor or list, shape (B, response_size).
:rtype: Union[torch.Tensor, list[torch.Tensor]]
- Example::
+ Example (mode A, legacy)::
>>> r = torch.tensor([1.0, 2.0])
>>> kl_coef = 0.1
>>> kl = torch.tensor([[0.1, 0.2, 0.3], [0.2, 0.1, 0.4]])
@@ -475,29 +515,70 @@ def compute_reward(
>>> reward = compute_reward(r, kl_coef, kl, action_mask)
>>> reward.shape
torch.Size([2, 3])
+
+ Example (mode B, per-step)::
+ >>> step_rewards = torch.tensor([[0.2, 0.8, 0.7],
+ ... [0.5, 0.6, -1.]]) # last col padded
+ >>> step_token_indices = torch.tensor([[1, 3, 4],
+ ... [0, 2, -1]])
+ >>> reward = compute_reward(
+ ... r=None, kl_coef=0.0, kl=torch.zeros(2, 5),
+ ... action_mask=torch.ones(2, 5),
+ ... step_rewards=step_rewards,
+ ... step_token_indices=step_token_indices,
+ ... )
+ >>> reward
+ tensor([[0.0000, 0.2000, 0.0000, 0.8000, 0.7000],
+ [0.5000, 0.0000, 0.6000, 0.0000, 0.0000]])
"""
if kl_coef <= 0.0:
kl_coef = 0.0
- if reward_clip_range:
+ use_per_step = step_rewards is not None and step_token_indices is not None
+
+ if not use_per_step and reward_clip_range:
r = r.clamp(min=reward_clip_range[0], max=reward_clip_range[1])
if action_mask is not None:
kl_reward = -kl_coef * kl
- # The following code is equivalent to:
- #
- # last_reward = torch.zeros_like(kl)
- # for i in range(last_reward.size(0)):
- # for t in reversed(range(last_reward.size(1))):
- # if action_mask[i][t] > 0.5:
- # last_reward[i][t] = r[i]
- # break
- #
- eos_indices = action_mask.size(1) - 1 - action_mask.long().fliplr().argmax(dim=1, keepdim=True)
- last_reward = torch.zeros_like(kl).scatter_(dim=1, index=eos_indices, src=r.unsqueeze(1).to(kl.dtype))
+
+ if use_per_step:
+ # Mode B: scatter each step's reward to its boundary token index.
+ # step_rewards: (B, S), step_token_indices: (B, S). Padding = idx < 0.
+ base = torch.zeros_like(kl)
+ valid = step_token_indices >= 0
+ if valid.any():
+ # Gather flat row/col indices, scatter via index_put_
+ row_idx = torch.arange(step_token_indices.size(0), device=kl.device)
+ row_idx = row_idx.unsqueeze(1).expand_as(step_token_indices)
+ flat_rows = row_idx[valid]
+ flat_cols = step_token_indices[valid].long()
+ # Clamp cols into valid range to avoid OOB; padded cols are filtered by `valid`
+ flat_cols = flat_cols.clamp(min=0, max=base.size(1) - 1)
+ flat_vals = step_rewards[valid].to(kl.dtype)
+ # accumulate (multiple steps could land on same token; rare but safe)
+ base.index_put_((flat_rows, flat_cols), flat_vals, accumulate=True)
+ last_reward = base
+ else:
+ # Mode A: legacy — scatter scalar r[i] to EOS index of row i.
+ #
+ # The following code is equivalent to:
+ #
+ # last_reward = torch.zeros_like(kl)
+ # for i in range(last_reward.size(0)):
+ # for t in reversed(range(last_reward.size(1))):
+ # if action_mask[i][t] > 0.5:
+ # last_reward[i][t] = r[i]
+ # break
+ #
+ eos_indices = action_mask.size(1) - 1 - action_mask.long().fliplr().argmax(dim=1, keepdim=True)
+ last_reward = torch.zeros_like(kl).scatter_(dim=1, index=eos_indices, src=r.unsqueeze(1).to(kl.dtype))
reward = last_reward + kl_reward
else:
+ # Mode C: per-row variable-length (legacy). Per-step mode is only
+ # supported with action_mask; fall back to scalar EOS even if
+ # use_per_step is set, to keep this branch backward-compat.
# TODO: write a more efficient version
reward = []
for i, (kl_seg, action_len) in enumerate(zip(kl, num_actions)):
diff --git a/lightrft/strategy/config.py b/lightrft/strategy/config.py
index 008fccd9..f523b3bb 100644
--- a/lightrft/strategy/config.py
+++ b/lightrft/strategy/config.py
@@ -48,10 +48,18 @@ class StrategyConfig:
overlap_comm: bool = False
# Engine and inference parameters
- # (str): Inference engine type, defaults to "vllm"
+ # (str): Inference engine type, defaults to "vllm". Supported values include "vllm", "sglang", and "hf".
engine_type: str = "vllm"
# (int): Engine tensor parallelism size, defaults to 1
engine_tp_size: int = 1
+ # (int): Maximum local HF generation batch size, <=0 disables chunking
+ local_hf_generate_max_batch_size: int = 0
+ # (int): Optional max_new_tokens cap applied only to local HF rollout generation, <=0 disables the cap
+ local_hf_max_new_tokens: int = 0
+ # (bool): Use a dedicated local HF rollout actor instead of reusing the training actor
+ hf_separate_rollout_actor: bool = False
+ # (bool): Keep the dedicated local HF rollout actor resident on GPU instead of sleeping/offloading it
+ hf_separate_rollout_keep_on_gpu: bool = False
# (bool): Enable engine sleep mode, defaults to False
enable_engine_sleep: bool = False
# (int): Local rank for distributed training, defaults to -1
@@ -139,7 +147,6 @@ class StrategyConfig:
plot_every: int = -1
# (bool): Use TensorBoard for logging, defaults to False
use_tensorboard: bool = False
-
# Additional arguments for backward compatibility
# (Dict[str, Any]): Extra arguments for backward compatibility, defaults to {}
extra_args: Dict[str, Any] = field(default_factory=dict)
@@ -237,7 +244,10 @@ def print_config_summary(self) -> None:
# Engine and Inference Parameters
print("\nEngine and Inference Parameters:")
- for attr in ['engine_type', 'engine_tp_size', 'enable_engine_sleep', 'local_rank', 'sp_size']:
+ for attr in [
+ 'engine_type', 'engine_tp_size', 'local_hf_generate_max_batch_size', 'hf_separate_rollout_actor',
+ 'enable_engine_sleep', 'local_rank', 'sp_size'
+ ]:
current = getattr(self, attr)
default = getattr(default_config, attr)
status = "Overridden" if current != default else "Default"
diff --git a/lightrft/strategy/fake_strategy.py b/lightrft/strategy/fake_strategy.py
index e1ed71ed..c39f005f 100644
--- a/lightrft/strategy/fake_strategy.py
+++ b/lightrft/strategy/fake_strategy.py
@@ -282,7 +282,7 @@ def get_rank(self) -> int:
"""
return 0
- def setup_inference_engine(self, args, engine_type="vllm", actor=None):
+ def setup_inference_engine(self, args, engine_type="vllm", actor=None, tokenizer=None, processor=None):
"""
Fake inference engine setup - returns None.
@@ -311,7 +311,18 @@ def wakeup_inference_engine(self):
"""
self.print("FakeStrategy: Inference engine wakeup skipped")
- def engine_generate_local(self, sampling_params, prompt_token_ids=None, multi_modal_inputs=None):
+ def engine_generate_local(
+ self,
+ sampling_params,
+ prompt_token_ids=None,
+ multi_modal_inputs=None,
+ pixel_values=None,
+ image_grid_thw=None,
+ pixel_values_videos=None,
+ video_grid_thw=None,
+ images_num=None,
+ videos_num=None,
+ ):
"""
Fake generation - returns empty results.
@@ -332,7 +343,13 @@ def gather_and_generate(
all_prompts=None,
all_images=None,
sleep_engine=True,
- images_num=None
+ images_num=None,
+ all_videos=None,
+ videos_num=None,
+ all_images_pixel_values=None,
+ all_videos_pixel_values=None,
+ all_images_grid_thw=None,
+ all_videos_grid_thw=None,
):
"""
Fake gather and generate - returns empty results.
diff --git a/lightrft/strategy/strategy_base.py b/lightrft/strategy/strategy_base.py
index 050c5a00..20a919cf 100644
--- a/lightrft/strategy/strategy_base.py
+++ b/lightrft/strategy/strategy_base.py
@@ -110,7 +110,15 @@ def __init__( # pylint: disable=R0917
# inference (rollout) engine related
self.inference_engine = None
self.inference_engine_status = EngineStatus.SLEEPED
+ self.inference_tokenizer = None
+ self.inference_processor = None
self.broadcast_manager = None
+ self.rollout_train_actor = None
+ self.rollout_train_actor_is_on_gpu = False
+ self.use_separate_hf_rollout_actor = False
+ self._separate_hf_rollout_sync_param_pairs = None
+ self._separate_hf_rollout_sync_buffer_pairs = None
+ self.last_separate_hf_rollout_sync_stats = None
self.time_steps = defaultdict(int)
@@ -184,7 +192,7 @@ def setup_distributed(self, timeout: Optional[timedelta] = None, num_gpu_per_nod
)
# TODO: unify the init_process_group for both vllm and sglang when stable version finished
- if self.config.engine_type in ("vllm", "sglang"):
+ if self.config.engine_type in ("vllm", "sglang", "hf"):
dist.init_process_group(
rank=rank,
world_size=world_size,
@@ -198,7 +206,7 @@ def setup_distributed(self, timeout: Optional[timedelta] = None, num_gpu_per_nod
raise ValueError(f"Unsupported backend: {self.config.engine_type}")
else:
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
- if self.config.engine_type in ("vllm", "sglang"):
+ if self.config.engine_type in ("vllm", "sglang", "hf"):
deepspeed.init_distributed(dist_backend="nccl", timeout=timeout)
else:
raise ValueError(f"Unsupported backend: {self.config.engine_type}")
@@ -528,6 +536,16 @@ def prepare_models_and_optimizers(self, actor, critic, reward_models, initial_mo
setattr(actor, "is_actor", True)
fsdp_enable = self.config.fsdp
+
+ def _resolve_fsdp_shard_size(preferred_shard_size: int) -> int:
+ world_size = max(int(getattr(self, "world_size", 1)), 1)
+ if world_size % preferred_shard_size == 0:
+ return preferred_shard_size
+ candidate = min(preferred_shard_size, world_size)
+ while candidate > 1 and world_size % candidate != 0:
+ candidate -= 1
+ return max(candidate, 1)
+
# For FSDP: wrap model first, then create optimizer
if fsdp_enable:
actor = self.prepare_model(actor, is_training=True)
@@ -535,7 +553,11 @@ def prepare_models_and_optimizers(self, actor, critic, reward_models, initial_mo
if critic is not None:
critic = self.prepare_model(critic, is_training=True)
if not self.config.remote_rm_url:
- reward_model_shard_size = self.world_size
+ reward_model_shard_size = _resolve_fsdp_shard_size(preferred_shard_size=8)
+ self.print(
+ "Preparing reward model(s) with shard_size="
+ f"{reward_model_shard_size} (world_size={getattr(self, 'world_size', 1)})"
+ )
if isinstance(reward_models, (tuple, list)):
reward_models = [
self.prepare_model(model, shard_size=reward_model_shard_size) for model in reward_models
@@ -655,13 +677,164 @@ def report_memory(cls, prefix=""):
f"ALLOCATED={torch.cuda.memory_allocated() / 1e9:.2f} GB"
)
- def setup_inference_engine(self, args, engine_type="vllm", actor=None):
+ def _uses_separate_hf_rollout_actor(self) -> bool:
+ return (
+ self.inference_engine_type == "hf" and self.use_separate_hf_rollout_actor
+ and self.rollout_train_actor is not None and self.inference_engine is not None
+ and self.inference_engine is not self.rollout_train_actor
+ )
+
+ def _keeps_separate_hf_rollout_actor_on_gpu(self) -> bool:
+ return self._uses_separate_hf_rollout_actor() and bool(
+ getattr(self.config, "hf_separate_rollout_keep_on_gpu", False)
+ )
+
+ def _build_local_hf_rollout_actor_sync_plan(
+ self, src_actor: nn.Module, dst_actor: nn.Module
+ ) -> Tuple[List[Tuple[str, nn.Parameter, nn.Parameter]], List[Tuple[str, torch.Tensor, torch.Tensor]]]:
+ src_params = dict(src_actor.named_parameters())
+ dst_params = dict(dst_actor.named_parameters())
+ if src_params.keys() != dst_params.keys():
+ missing_in_dst = sorted(set(src_params) - set(dst_params))
+ missing_in_src = sorted(set(dst_params) - set(src_params))
+ raise ValueError(
+ "Separate local HF rollout actor parameter mismatch. "
+ f"missing_in_dst={missing_in_dst[:8]}, missing_in_src={missing_in_src[:8]}"
+ )
+
+ param_pairs = []
+ for name, src_param in src_params.items():
+ dst_param = dst_params[name]
+ if src_param.shape != dst_param.shape:
+ raise ValueError(
+ f"Separate local HF rollout actor parameter shape mismatch for {name}: "
+ f"{tuple(src_param.shape)} vs {tuple(dst_param.shape)}"
+ )
+ param_pairs.append((name, src_param, dst_param))
+
+ src_buffers = dict(src_actor.named_buffers())
+ dst_buffers = dict(dst_actor.named_buffers())
+ buffer_pairs = []
+ for name in sorted(set(src_buffers) & set(dst_buffers)):
+ src_buffer = src_buffers[name]
+ dst_buffer = dst_buffers[name]
+ if src_buffer.shape != dst_buffer.shape:
+ raise ValueError(
+ f"Separate local HF rollout actor buffer shape mismatch for {name}: "
+ f"{tuple(src_buffer.shape)} vs {tuple(dst_buffer.shape)}"
+ )
+ buffer_pairs.append((name, src_buffer, dst_buffer))
+ return param_pairs, buffer_pairs
+
+ def _copy_local_hf_rollout_actor_state(self, src_actor: nn.Module, dst_actor: nn.Module) -> None:
+ if self._separate_hf_rollout_sync_param_pairs is None or self._separate_hf_rollout_sync_buffer_pairs is None:
+ (
+ self._separate_hf_rollout_sync_param_pairs,
+ self._separate_hf_rollout_sync_buffer_pairs,
+ ) = self._build_local_hf_rollout_actor_sync_plan(src_actor, dst_actor)
+
+ for name, src_param, dst_param in self._separate_hf_rollout_sync_param_pairs:
+ src_tensor = src_param.detach()
+ if src_tensor.device != dst_param.device or src_tensor.dtype != dst_param.dtype:
+ src_tensor = src_tensor.to(device=dst_param.device, dtype=dst_param.dtype)
+ dst_param.detach().copy_(src_tensor)
+
+ for name, src_buffer_ref, dst_buffer in self._separate_hf_rollout_sync_buffer_pairs:
+ src_buffer = src_buffer_ref.detach()
+ if src_buffer.device != dst_buffer.device or src_buffer.dtype != dst_buffer.dtype:
+ src_buffer = src_buffer.to(device=dst_buffer.device, dtype=dst_buffer.dtype)
+ dst_buffer.detach().copy_(src_buffer)
+
+ def _prepare_separate_hf_rollout_actor_for_generation(self) -> None:
+ if not self._uses_separate_hf_rollout_actor():
+ return
+ model = getattr(self.inference_engine, "model", None)
+ if isinstance(model, nn.Module):
+ model.eval()
+ self.inference_engine.eval()
+
+ def _sync_separate_hf_rollout_actor(self, actor: nn.Module) -> None:
+ if not self.config.fsdp:
+ raise NotImplementedError("Separate local HF rollout actor currently only supports FSDP.")
+ if not self._uses_separate_hf_rollout_actor():
+ raise RuntimeError("Separate local HF rollout actor is not initialized.")
+
+ keep_on_gpu = self._keeps_separate_hf_rollout_actor_on_gpu()
+ sync_t0 = time.time()
+ actor_offloaded = False
+ offload_actor_t0 = time.time()
+ if not keep_on_gpu and self.rollout_train_actor_is_on_gpu:
+ self.offload_model(actor)
+ self.rollout_train_actor_is_on_gpu = False
+ actor_offloaded = True
+ offload_actor_s = time.time() - offload_actor_t0
+
+ rollout_offloaded = False
+ offload_rollout_t0 = time.time()
+ if not keep_on_gpu and self.inference_engine_status != EngineStatus.SLEEPED:
+ self.offload_model(self.inference_engine, empty_cache=False)
+ rollout_offloaded = True
+ offload_rollout_s = time.time() - offload_rollout_t0
+
+ copy_state_t0 = time.time()
+ self._copy_local_hf_rollout_actor_state(actor, self.inference_engine)
+ copy_state_s = time.time() - copy_state_t0
+
+ reload_rollout_t0 = time.time()
+ rollout_reloaded = False
+ if keep_on_gpu:
+ # `keep_on_gpu=True` skips the wakeup path, so we must explicitly
+ # rematerialize the rollout actor here after copying the updated state.
+ self.reload_model(self.inference_engine)
+ rollout_reloaded = True
+ reload_rollout_s = time.time() - reload_rollout_t0
+
+ prepare_t0 = time.time()
+ self._prepare_separate_hf_rollout_actor_for_generation()
+ prepare_s = time.time() - prepare_t0
+
+ sync_clear_t0 = time.time()
+ if keep_on_gpu:
+ torch.cuda.synchronize()
+ torch.distributed.barrier()
+ self.inference_engine_status = EngineStatus.WAKEUP
+ else:
+ self.inference_engine_status = EngineStatus.SLEEPED
+ self.sync_and_clear_cache()
+ sync_clear_s = time.time() - sync_clear_t0
+ self.last_separate_hf_rollout_sync_stats = {
+ "total_s": round(time.time() - sync_t0, 4),
+ "keep_on_gpu": keep_on_gpu,
+ "actor_offloaded": actor_offloaded,
+ "rollout_offloaded": rollout_offloaded,
+ "rollout_reloaded": rollout_reloaded,
+ "offload_actor_s": round(offload_actor_s, 4),
+ "offload_rollout_s": round(offload_rollout_s, 4),
+ "copy_state_s": round(copy_state_s, 4),
+ "reload_rollout_s": round(reload_rollout_s, 4),
+ "prepare_s": round(prepare_s, 4),
+ "sync_clear_s": round(sync_clear_s, 4),
+ }
+ self.print(
+ "Finished update engine weights for separate local HF rollout actor",
+ self.last_separate_hf_rollout_sync_stats,
+ )
+
+ def setup_inference_engine(
+ self,
+ args,
+ engine_type="vllm",
+ actor=None,
+ rollout_actor=None,
+ tokenizer=None,
+ processor=None,
+ ):
"""
Initialize and setup the inference engine.
:param args: Configuration arguments
:type args: argparse.Namespace
- :param engine_type: Type of inference engine ('vllm' or 'sglang')
+ :param engine_type: Type of inference engine ('vllm', 'sglang', or 'hf')
:type engine_type: str
:param actor: The actor module, if passed, will be used to update engine weights
:type actor: torch.nn.Module
@@ -671,6 +844,14 @@ def setup_inference_engine(self, args, engine_type="vllm", actor=None):
:raises ValueError: If engine_type is not supported
"""
self.inference_engine_type = engine_type
+ self.inference_tokenizer = tokenizer
+ self.inference_processor = processor
+ self.rollout_train_actor = None
+ self.rollout_train_actor_is_on_gpu = False
+ self.use_separate_hf_rollout_actor = False
+ self._separate_hf_rollout_sync_param_pairs = None
+ self._separate_hf_rollout_sync_buffer_pairs = None
+ self.last_separate_hf_rollout_sync_stats = None
if engine_type == "vllm":
# Conditional import: vLLM is optional and only imported when explicitly requested
@@ -683,10 +864,30 @@ def setup_inference_engine(self, args, engine_type="vllm", actor=None):
from .sglang_utils import get_sglang_engine_for_rollout
self.inference_engine = get_sglang_engine_for_rollout(args)
self.inference_engine_status = EngineStatus.WAKEUP
+ elif engine_type == "hf":
+ if actor is None:
+ raise ValueError("engine_type='hf' requires the prepared actor to be passed in.")
+ if getattr(args, "hf_separate_rollout_actor", False):
+ if rollout_actor is None:
+ raise ValueError(
+ "engine_type='hf' with --hf_separate_rollout_actor requires a prepared rollout_actor."
+ )
+ self.use_separate_hf_rollout_actor = True
+ self.rollout_train_actor = actor
+ self.rollout_train_actor_is_on_gpu = True
+ self.inference_engine = rollout_actor
+ self._prepare_separate_hf_rollout_actor_for_generation()
+ self.inference_engine_status = (
+ EngineStatus.WAKEUP if self._keeps_separate_hf_rollout_actor_on_gpu() else EngineStatus.SLEEPED
+ )
+ else:
+ # Local HF mode reuses the actor directly for time-boxed smoke runs.
+ self.inference_engine = actor
+ self.inference_engine_status = EngineStatus.WAKEUP
else:
raise ValueError(f"Unsupported engine type: {engine_type}")
- if actor is not None:
+ if actor is not None and (engine_type != "hf" or self.use_separate_hf_rollout_actor):
self.update_engine_weights(actor)
self.maybe_sleep_inference_engine()
return self.inference_engine
@@ -700,15 +901,33 @@ def maybe_sleep_inference_engine(self):
:raises ValueError: If the inference engine type is not supported
"""
- if self.inference_engine is not None and self.args.enable_engine_sleep:
+ if self.inference_engine is None or self.inference_engine_status == EngineStatus.SLEEPED:
+ return
+ if self._keeps_separate_hf_rollout_actor_on_gpu():
+ self._prepare_separate_hf_rollout_actor_for_generation()
+ self.inference_engine_status = EngineStatus.WAKEUP
+ return
+ if self.inference_engine is not None and (
+ self.args.enable_engine_sleep or self._uses_separate_hf_rollout_actor()
+ ):
+ sleep_t0 = time.time()
if self.inference_engine_type in ["vllm", "sglang"]:
self.inference_engine.sleep()
+ elif self.inference_engine_type == "hf":
+ if self._uses_separate_hf_rollout_actor():
+ self.offload_model(self.inference_engine)
+ if not self.rollout_train_actor_is_on_gpu:
+ self.reload_model(self.rollout_train_actor)
+ self.rollout_train_actor_is_on_gpu = True
+ self.rollout_train_actor.train()
+ else:
+ return
else:
raise ValueError(f"Unsupported engine type: {self.inference_engine_type}")
self.inference_engine_status = EngineStatus.SLEEPED
self.sync_and_clear_cache()
- self.print("Sleeped inference engine")
+ self.print(f"Sleeped inference engine, TIMECOST {time.time() - sleep_t0}")
def wakeup_inference_engine(self):
"""
@@ -722,11 +941,25 @@ def wakeup_inference_engine(self):
"""
if self.inference_engine is None or self.inference_engine_status == EngineStatus.WAKEUP:
return
+ if self._keeps_separate_hf_rollout_actor_on_gpu():
+ self._prepare_separate_hf_rollout_actor_for_generation()
+ self.inference_engine_status = EngineStatus.WAKEUP
+ return
self.sync_and_clear_cache()
wkup_t0 = time.time()
if self.inference_engine_type in ["vllm", "sglang"]:
self.inference_engine.wake_up()
+ elif self.inference_engine_type == "hf":
+ if self._uses_separate_hf_rollout_actor():
+ if self.rollout_train_actor_is_on_gpu:
+ self.offload_model(self.rollout_train_actor)
+ self.rollout_train_actor_is_on_gpu = False
+ self.reload_model(self.inference_engine)
+ self._prepare_separate_hf_rollout_actor_for_generation()
+ self.print(f"Finished {self.inference_engine_type} wakeup, TIMECOST {time.time() - wkup_t0}")
+ self.inference_engine_status = EngineStatus.WAKEUP
+ return
else:
raise ValueError(f"Unsupported engine type: {self.inference_engine_type}")
# torch.cuda.reset_max_memory_allocated()
@@ -779,6 +1012,12 @@ def engine_generate_local(
sampling_params: Any,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
multi_modal_inputs: Optional[List[Dict[str, Any]]] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ image_grid_thw: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.Tensor] = None,
+ video_grid_thw: Optional[torch.Tensor] = None,
+ images_num: Optional[List[int]] = None,
+ videos_num: Optional[List[int]] = None,
) -> List[EasyDict]:
"""
Perform text or multimodal generation using different inference engines based on the input mode.
@@ -804,7 +1043,7 @@ def engine_generate_local(
if prompt_token_ids is None and multi_modal_inputs is None:
raise ValueError("Either prompt_token_ids or multi_modal_inputs must be provided.")
- if prompt_token_ids is not None and multi_modal_inputs is not None:
+ if self.inference_engine_type != "hf" and prompt_token_ids is not None and multi_modal_inputs is not None:
raise ValueError("Both prompt_token_ids and multi_modal_inputs can not be provided at the same time.")
# if inference engine is vllm
@@ -883,6 +1122,155 @@ def engine_generate_local(
output_token_ids=sglang_outputs[i]["output_ids"],
) for i in range(len(sglang_outputs))
]
+ elif self.inference_engine_type == "hf":
+ if prompt_token_ids is None:
+ raise ValueError("Local HF inference requires prompt_token_ids.")
+
+ from lightrft.datasets.utils import zero_pad_sequences
+
+ model = getattr(self.inference_engine, "model", self.inference_engine)
+ model_config = getattr(model, "config", None)
+ eos_token_id = getattr(self.inference_tokenizer, "eos_token_id", None)
+ pad_token_id = getattr(self.inference_tokenizer, "pad_token_id", None)
+
+ if eos_token_id is None and model_config is not None:
+ eos_token_id = getattr(model_config, "eos_token_id", None)
+ if pad_token_id is None and model_config is not None:
+ pad_token_id = getattr(model_config, "pad_token_id", None)
+
+ if isinstance(eos_token_id, (list, tuple)):
+ eos_token_id = eos_token_id[0]
+ if isinstance(pad_token_id, (list, tuple)):
+ pad_token_id = pad_token_id[0]
+ if pad_token_id is None:
+ pad_token_id = eos_token_id
+ if eos_token_id is None:
+ raise ValueError("Unable to resolve eos_token_id for local HF inference engine.")
+
+ normalized_prompt_ids = []
+ for token_ids in prompt_token_ids:
+ if isinstance(token_ids, torch.Tensor):
+ token_ids = token_ids.tolist()
+ normalized_prompt_ids.append(token_ids)
+
+ device = torch.cuda.current_device()
+
+ def _prepare_tensor(tensor):
+ if tensor is None:
+ return None
+ if isinstance(tensor, torch.Tensor) and tensor.numel() == 0:
+ return None
+ return tensor.to(device, non_blocking=True)
+
+ top_k = sampling_params.get("top_k", None)
+ if top_k is not None and top_k <= 0:
+ top_k = None
+ temperature = sampling_params.get("temperature", 1.0)
+ do_sample = sampling_params.get("do_sample", temperature is None or temperature > 0)
+
+ def _run_local_hf_batch(
+ batch_prompt_token_ids,
+ batch_pixel_values=None,
+ batch_image_grid_thw=None,
+ batch_pixel_values_videos=None,
+ batch_video_grid_thw=None,
+ ):
+ prompt_tensors = [torch.tensor(token_ids, dtype=torch.long) for token_ids in batch_prompt_token_ids]
+ padded_input_ids = zero_pad_sequences(prompt_tensors, side="left", value=pad_token_id).to(device)
+ attention_mask = padded_input_ids.ne(pad_token_id).long()
+ prompt_lengths = attention_mask.sum(dim=1).detach().cpu()
+
+ generate_t0 = time.time()
+ with torch.no_grad():
+ sequences, attention_mask_out, _ = self.inference_engine.generate(
+ input_ids=padded_input_ids,
+ attention_mask=attention_mask,
+ pixel_values=_prepare_tensor(batch_pixel_values),
+ image_grid_thw=_prepare_tensor(batch_image_grid_thw),
+ pixel_values_videos=_prepare_tensor(batch_pixel_values_videos),
+ video_grid_thw=_prepare_tensor(batch_video_grid_thw),
+ top_k=top_k,
+ top_p=sampling_params.get("top_p", 1.0),
+ temperature=temperature,
+ do_sample=do_sample,
+ max_new_tokens=sampling_params.get("max_new_tokens", 1024),
+ min_new_tokens=sampling_params.get("min_new_tokens", 1),
+ repetition_penalty=sampling_params.get("repetition_penalty", 1.0),
+ no_repeat_ngram_size=sampling_params.get("no_repeat_ngram_size", 0),
+ eos_token_id=eos_token_id,
+ pad_token_id=pad_token_id,
+ )
+ elapsed_s = round(time.time() - generate_t0, 4)
+ self.print(
+ "Local HF model.generate finished:", {
+ "batch_size": len(batch_prompt_token_ids),
+ "prompt_tokens": [len(token_ids) for token_ids in batch_prompt_token_ids],
+ "elapsed_s": elapsed_s,
+ }
+ )
+
+ output_start_idx = padded_input_ids.size(1)
+ sequences = sequences.detach().cpu()
+ attention_mask_out = attention_mask_out.detach().cpu()
+
+ batch_outputs = []
+ for idx in range(sequences.size(0)):
+ total_length = int(attention_mask_out[idx].sum().item())
+ generated_length = max(total_length - int(prompt_lengths[idx].item()), 0)
+ output_end_idx = output_start_idx + generated_length
+ batch_outputs.append(
+ EasyDict(
+ prompt_token_ids=batch_prompt_token_ids[idx],
+ output_token_ids=sequences[idx, output_start_idx:output_end_idx].tolist(),
+ )
+ )
+ return batch_outputs, elapsed_s
+
+ max_batch_size = max(int(getattr(self.config, "local_hf_generate_max_batch_size", 0) or 0), 0)
+ if max_batch_size > 0 and len(normalized_prompt_ids) > max_batch_size:
+ images_prefix = None
+ if images_num is not None:
+ images_prefix = [0]
+ for num in images_num:
+ images_prefix.append(images_prefix[-1] + num)
+ videos_prefix = None
+ if videos_num is not None:
+ videos_prefix = [0]
+ for num in videos_num:
+ videos_prefix.append(videos_prefix[-1] + num)
+
+ def _slice_modal_tensor(tensor, offsets, start, end):
+ if tensor is None:
+ return None
+ if not isinstance(tensor, torch.Tensor):
+ return tensor
+ if tensor.numel() == 0:
+ return tensor
+ if offsets is not None:
+ return tensor[offsets[start]:offsets[end]]
+ return tensor[start:end]
+
+ engine_outputs = []
+ for start in range(0, len(normalized_prompt_ids), max_batch_size):
+ end = min(start + max_batch_size, len(normalized_prompt_ids))
+ batch_outputs, chunk_elapsed_s = _run_local_hf_batch(
+ normalized_prompt_ids[start:end],
+ batch_pixel_values=_slice_modal_tensor(pixel_values, images_prefix, start, end),
+ batch_image_grid_thw=_slice_modal_tensor(image_grid_thw, images_prefix, start, end),
+ batch_pixel_values_videos=_slice_modal_tensor(pixel_values_videos, videos_prefix, start, end),
+ batch_video_grid_thw=_slice_modal_tensor(video_grid_thw, videos_prefix, start, end),
+ )
+ engine_outputs.extend(batch_outputs)
+ return engine_outputs
+
+ batch_outputs, chunk_elapsed_s = _run_local_hf_batch(
+ normalized_prompt_ids,
+ batch_pixel_values=pixel_values,
+ batch_image_grid_thw=image_grid_thw,
+ batch_pixel_values_videos=pixel_values_videos,
+ batch_video_grid_thw=video_grid_thw,
+ )
+ return batch_outputs
else:
raise ValueError(f"Unsupported engine type: {self.inference_engine_type}")
@@ -980,6 +1368,10 @@ def gather_and_generate(
images_num=None,
all_videos=None,
videos_num=None,
+ all_images_pixel_values=None,
+ all_videos_pixel_values=None,
+ all_images_grid_thw=None,
+ all_videos_grid_thw=None,
):
"""
Gather prompts across distributed ranks and perform text/multimodal generation.
@@ -1020,6 +1412,8 @@ def gather_and_generate(
"""
if self.inference_engine is None:
raise NotImplementedError("Inference engine is not initialized.")
+ if self._uses_separate_hf_rollout_actor():
+ sleep_engine = True
self.wakeup_inference_engine()
# is_multimodal = all_images is not None
@@ -1029,35 +1423,51 @@ def gather_and_generate(
is_multimodal = (((all_images is not None) and any(img is not None for img in all_images))
or ((all_videos is not None) and any(vid is not None for vid in all_videos)))
- if is_multimodal:
- inputs = self._build_multimodal_inputs(
- all_prompts=all_prompts,
- all_images=all_images,
- images_num=images_num,
- all_videos=all_videos,
- videos_num=videos_num,
- )
- else:
+ if self.inference_engine_type == "hf":
inputs = all_prompt_token_ids
assert inputs is not None
+ self.print(f"Start VLM gather_and_generate ..., total prompts: {len(inputs)}")
+ all_outputs = self.engine_generate_local(
+ sampling_params=sampling_params,
+ prompt_token_ids=inputs,
+ pixel_values=all_images_pixel_values if is_multimodal else None,
+ image_grid_thw=all_images_grid_thw if is_multimodal else None,
+ pixel_values_videos=all_videos_pixel_values if is_multimodal else None,
+ video_grid_thw=all_videos_grid_thw if is_multimodal else None,
+ images_num=images_num if is_multimodal else None,
+ videos_num=videos_num if is_multimodal else None,
+ )
+ local_outputs = all_outputs
+ else:
+ if is_multimodal:
+ inputs = self._build_multimodal_inputs(
+ all_prompts=all_prompts,
+ all_images=all_images,
+ images_num=images_num,
+ all_videos=all_videos,
+ videos_num=videos_num,
+ )
+ else:
+ inputs = all_prompt_token_ids
+ assert inputs is not None
- inputs = gather_inputs_object_for_inference(input_data=inputs, group=self.engine_mp_group)
+ inputs = gather_inputs_object_for_inference(input_data=inputs, group=self.engine_mp_group)
- self.print(f"Start VLM gather_and_generate ..., total prompts: {len(inputs)}")
+ self.print(f"Start VLM gather_and_generate ..., total prompts: {len(inputs)}")
- all_outputs = self.engine_generate_local(
- sampling_params=sampling_params,
- prompt_token_ids=None if is_multimodal else inputs,
- multi_modal_inputs=inputs if is_multimodal else None,
- )
+ all_outputs = self.engine_generate_local(
+ sampling_params=sampling_params,
+ prompt_token_ids=None if is_multimodal else inputs,
+ multi_modal_inputs=inputs if is_multimodal else None,
+ )
- engine_mp_size = torch.distributed.get_world_size(self.engine_mp_group)
- num_prompts_per_rank = len(all_outputs) // engine_mp_size
- assert len(all_outputs) % engine_mp_size == 0
- cur_rank = torch.distributed.get_rank(self.engine_mp_group)
- local_outputs = all_outputs[cur_rank * num_prompts_per_rank:(cur_rank + 1) * num_prompts_per_rank]
+ engine_mp_size = torch.distributed.get_world_size(self.engine_mp_group)
+ num_prompts_per_rank = len(all_outputs) // engine_mp_size
+ assert len(all_outputs) % engine_mp_size == 0
+ cur_rank = torch.distributed.get_rank(self.engine_mp_group)
+ local_outputs = all_outputs[cur_rank * num_prompts_per_rank:(cur_rank + 1) * num_prompts_per_rank]
- if self.inference_engine_type == "sglang":
+ if is_multimodal and self.inference_engine_type == "sglang":
# For SGLang VLM case, prompt_token_ids is set to None in engine_generate_local
# We need to fill it with the actual token_ids here
for i, output in enumerate(local_outputs):
@@ -1084,6 +1494,12 @@ def update_engine_weights(self, actor):
if self.inference_engine is None:
self.print("Skip update engine weights since inference engine is not initialized.")
return
+ if self.inference_engine_type == "hf":
+ if self._uses_separate_hf_rollout_actor():
+ self._sync_separate_hf_rollout_actor(actor)
+ else:
+ self.print("Skip update engine weights for local HF engine because it reuses the actor directly.")
+ return
# 1. wakeup engine if sleeped
self.wakeup_inference_engine()
diff --git a/lightrft/strategy/vllm_utils/__init__.py b/lightrft/strategy/vllm_utils/__init__.py
index fc9e67e6..2f6cd05f 100644
--- a/lightrft/strategy/vllm_utils/__init__.py
+++ b/lightrft/strategy/vllm_utils/__init__.py
@@ -14,6 +14,7 @@
To use vLLM backend, install with: pip install "LightRFT[vllm]"
"""
+import os
from typing import Any
@@ -156,6 +157,8 @@ def get_vllm_engine(
dtype=dtype,
tensor_parallel_size=tp_size,
gpu_memory_utilization=mem_util,
+ trust_remote_code=True,
+ allowed_local_media_path=os.environ.get("VLLM_ALLOWED_LOCAL_MEDIA_PATH", "/"),
distributed_executor_backend="external_launcher",
worker_cls="lightrft.strategy.vllm_utils.vllm_worker_wrap_no_ray.WorkerWrap",
enable_sleep_mode=enable_sleep,
diff --git a/lightrft/trainer/fast_exp_maker.py b/lightrft/trainer/fast_exp_maker.py
index 8230ea65..2d97f85d 100644
--- a/lightrft/trainer/fast_exp_maker.py
+++ b/lightrft/trainer/fast_exp_maker.py
@@ -25,6 +25,7 @@
import time
import pathlib
import warnings
+from contextlib import contextmanager
from typing import Callable, Dict, List, Tuple, Union, Optional
from dataclasses import dataclass
from copy import deepcopy
@@ -138,6 +139,38 @@ class _SamplesOutput:
inputs_extra_kwargs: Optional[dict] = None
prompt_and_output: Optional[List[str]] = None
+ # Per-step PRM rewards (variant 2). When the reward model returns
+ # ``step_rewards`` / ``step_token_indices`` arrays alongside the
+ # trajectory-scalar score, they're stored here so that downstream
+ # advantage computation can scatter per-step credit to specific token
+ # positions (instead of only the EOS token). Both fields are
+ # micro-batch-sized lists; index i holds the per-step data for the i-th
+ # trajectory in the micro-batch as 1-D CPU tensors of equal length.
+ # Empty tensors mean "this trajectory uses trajectory-scalar rewards"
+ # and the legacy EOS-scatter path is taken in compute_reward.
+ step_rewards: Optional[List[torch.Tensor]] = None
+ step_token_indices: Optional[List[torch.Tensor]] = None
+
+
+@dataclass
+class _RewardBatchResult:
+ """Reward scores and optional auxiliary metrics for one micro-batch."""
+
+ scores: torch.Tensor
+ metrics: Optional[Dict[str, torch.Tensor]] = None
+ # Variant 2 (per-step PRM): when present, list[i] is a 1-D CPU tensor of
+ # step rewards / step boundary token indices for trajectory i in the
+ # micro-batch. When None or all-empty, the trajectory-scalar (EOS-scatter)
+ # path is used downstream.
+ step_rewards: Optional[List[torch.Tensor]] = None
+ step_token_indices: Optional[List[torch.Tensor]] = None
+
+
+class _NullProfiler:
+ @contextmanager
+ def section(self, _name: str):
+ yield
+
# ============================================================================
# Helper Classes
@@ -284,8 +317,10 @@ def process_multimodal_batch(
processor_kwargs = {
"text": all_prompts_multimodal.copy(),
"add_special_tokens": False,
+ "padding": True,
"max_length": self.prompt_max_len,
"truncation": True,
+ "return_tensors": "pt",
}
if flat_images:
processor_kwargs["images"] = flat_images
@@ -301,6 +336,15 @@ def process_multimodal_batch(
all_images_grid_thw_multimodal = inputs_multimodal.get("image_grid_thw", None)
all_videos_grid_thw_multimodal = inputs_multimodal.get("video_grid_thw", None)
+ # Some VLM processors (for example URSA) return batched image tensors directly
+ # and do not expose Qwen2-VL style grid metadata. In that case we synthesize a
+ # minimal per-image grid so the existing sample/replay slicing logic can still
+ # split one tensor slice per image while the model simply ignores the grid input.
+ if flat_images and all_images_grid_thw_multimodal is None:
+ all_images_grid_thw_multimodal = torch.ones((len(flat_images), 3), dtype=torch.long)
+ if flat_videos and all_videos_grid_thw_multimodal is None:
+ all_videos_grid_thw_multimodal = torch.ones((len(flat_videos), 3), dtype=torch.long)
+
# ===== Stage 4: Merge back in original order =====
total_samples = L * N
all_prompts_out = [None] * total_samples
@@ -446,16 +490,7 @@ def __init__(
:type packing_samples: bool
"""
self.reward_model = reward_model
- # Ensure remote_rm_url is a list for consistent iteration
- if remote_rm_url is None:
- self.remote_rm_url = None
- elif isinstance(remote_rm_url, str):
- self.remote_rm_url = [remote_rm_url]
- elif isinstance(remote_rm_url, (list, tuple)):
- self.remote_rm_url = list(remote_rm_url)
- else:
- raise TypeError(f"remote_rm_url must be str, list, tuple, or None, got {type(remote_rm_url).__name__}")
-
+ self.remote_rm_url = remote_rm_url
self.custom_reward_func = custom_reward_func
self.reward_fn = reward_fn
self.reward_fn_label_map = reward_fn_label_map or {}
@@ -581,7 +616,7 @@ def _compute_local_rewards(
self.strategy.reload_model(rm)
# Compute rewards for each RM
- # all_rewards_list[rm_idx][micro_batch_idx] = Tensor(batch_size,)
+ # all_rewards_list[rm_idx][micro_batch_idx] = _RewardBatchResult(batch_size,)
all_rewards_list = []
for rm_idx, rm in enumerate(rm_list):
@@ -602,7 +637,7 @@ def _compute_single_rm_rewards(
outputs: List[_SamplesOutput],
vlm_mode: bool,
device: torch.device,
- ) -> List[torch.Tensor]:
+ ) -> List[_RewardBatchResult]:
"""
Compute rewards for a single reward model across all micro-batches.
@@ -616,8 +651,8 @@ def _compute_single_rm_rewards(
:type vlm_mode: bool
:param device: Target device
:type device: torch.device
- :return: List of reward tensors, one per micro-batch
- :rtype: List[torch.Tensor]
+ :return: List of reward results, one per micro-batch
+ :rtype: List[_RewardBatchResult]
"""
# Check if this is a custom engine model (non-torch base_model)
is_custom_engine = (
@@ -640,7 +675,7 @@ def _compute_filtered_rewards(
rm_idx: int,
outputs: List[_SamplesOutput],
device: torch.device,
- ) -> List[torch.Tensor]:
+ ) -> List[_RewardBatchResult]:
"""
Compute rewards using optimized filtering (only process relevant samples).
@@ -655,8 +690,8 @@ def _compute_filtered_rewards(
:type outputs: List[_SamplesOutput]
:param device: Target device
:type device: torch.device
- :return: List of reward tensors per micro-batch
- :rtype: List[torch.Tensor]
+ :return: List of reward results per micro-batch
+ :rtype: List[_RewardBatchResult]
"""
# Get RM key from inverse label map
rm_key = self.inv_label_map.get(rm_idx)
@@ -695,7 +730,12 @@ def _compute_filtered_rewards(
# ========== Process Stage: Compute or skip ==========
if not needed_positions:
# No samples need this RM, return zeros for all micro-batches
- return [torch.zeros(len(output.labels), dtype=torch.float32, device=device) for output in outputs]
+ return [
+ _RewardBatchResult(
+ scores=torch.zeros(len(output.labels), dtype=torch.float32, device=device),
+ metrics=None,
+ ) for output in outputs
+ ]
# Run single forward pass on filtered samples
rm_output = rm(
@@ -717,14 +757,14 @@ def _compute_filtered_rewards(
for (mb_idx, samp_idx), score in zip(needed_positions, filtered_scores):
micro_batch_rewards[mb_idx][samp_idx] = score
- return micro_batch_rewards
+ return [_RewardBatchResult(scores=rewards, metrics=None) for rewards in micro_batch_rewards]
def _compute_batched_custom_engine_rewards(
self,
rm,
outputs: List[_SamplesOutput],
device: torch.device, # noqa: ARG002 (unused but kept for API consistency)
- ) -> List[torch.Tensor]:
+ ) -> List[_RewardBatchResult]:
"""
Compute rewards using custom engine with full batch processing (legacy path).
@@ -734,8 +774,8 @@ def _compute_batched_custom_engine_rewards(
:type outputs: List[_SamplesOutput]
:param device: Target device (unused but kept for API consistency)
:type device: torch.device
- :return: List of reward tensors per micro-batch
- :rtype: List[torch.Tensor]
+ :return: List of reward results per micro-batch
+ :rtype: List[_RewardBatchResult]
"""
# Flatten all micro-batches into single batch
flat_data = {
@@ -767,7 +807,34 @@ def _compute_batched_custom_engine_rewards(
# Split back into micro-batches
batch_sizes = [len(output.prompt_and_output) for output in outputs]
- return list(all_scores.split(batch_sizes))
+ return [
+ _RewardBatchResult(scores=scores.to(device=device, dtype=torch.float32), metrics=None)
+ for scores in all_scores.split(batch_sizes)
+ ]
+
+ @staticmethod
+ def _normalize_reward_metrics(
+ rm_output: Dict[str, torch.Tensor],
+ batch_size: int,
+ device: torch.device,
+ ) -> Optional[Dict[str, torch.Tensor]]:
+ metrics: Dict[str, torch.Tensor] = {}
+ for key, value in rm_output.items():
+ if key == "score":
+ continue
+ if not isinstance(value, torch.Tensor):
+ if isinstance(value, (int, float, bool)):
+ value = torch.tensor(value, dtype=torch.float32)
+ else:
+ continue
+ metric = value.to(device=device)
+ if metric.ndim == 0:
+ metric = metric.repeat(batch_size)
+ metric = metric.reshape(-1).float()
+ if metric.numel() != batch_size:
+ continue
+ metrics[key] = metric
+ return metrics or None
def _compute_standard_torch_rewards(
self,
@@ -775,7 +842,7 @@ def _compute_standard_torch_rewards(
outputs: List[_SamplesOutput],
vlm_mode: bool, # noqa: ARG002 (kept for future VLM-specific logic)
device: torch.device,
- ) -> List[torch.Tensor]:
+ ) -> List[_RewardBatchResult]:
"""
Compute rewards using standard PyTorch reward model.
@@ -789,10 +856,10 @@ def _compute_standard_torch_rewards(
:type vlm_mode: bool
:param device: Target device
:type device: torch.device
- :return: List of reward tensors per micro-batch
- :rtype: List[torch.Tensor]
+ :return: List of reward results per micro-batch
+ :rtype: List[_RewardBatchResult]
"""
- micro_batch_rewards = []
+ micro_batch_rewards: List[_RewardBatchResult] = []
for output in outputs:
# Unpack sequences if needed
@@ -805,22 +872,44 @@ def _compute_standard_torch_rewards(
rm_output = rm(
sequences,
output.attention_mask,
- references=output.references,
prompt_and_output=output.prompt_and_output,
raw_images=output.raw_images,
img_num=output.image_num,
+ references=output.references,
+ labels=output.labels,
**output.inputs_extra_kwargs,
)
- score = rm_output["score"] if isinstance(rm_output, dict) else rm_output
- micro_batch_rewards.append(torch.as_tensor(score, dtype=torch.float32, device=device))
+ if isinstance(rm_output, dict):
+ score = torch.as_tensor(rm_output["score"], dtype=torch.float32, device=device)
+ metrics = self._normalize_reward_metrics(rm_output, score.numel(), device)
+ # Variant 2 (per-step PRM reward): the reward model may emit
+ # per-trajectory step_rewards / step_token_indices lists. They
+ # come back as Python lists of CPU tensors (variable length per
+ # trajectory) — pass them through unchanged so compute_reward
+ # downstream can scatter per-step credit.
+ step_rewards = rm_output.get("step_rewards")
+ step_token_indices = rm_output.get("step_token_indices")
+ else:
+ score = torch.as_tensor(rm_output, dtype=torch.float32, device=device)
+ metrics = None
+ step_rewards = None
+ step_token_indices = None
+ micro_batch_rewards.append(
+ _RewardBatchResult(
+ scores=score,
+ metrics=metrics,
+ step_rewards=step_rewards,
+ step_token_indices=step_token_indices,
+ )
+ )
return micro_batch_rewards
def _aggregate_rewards(
self,
outputs: List[_SamplesOutput],
- all_rewards_list: List[List[torch.Tensor]],
+ all_rewards_list: List[List[_RewardBatchResult]],
is_multi_rm: bool,
) -> None:
"""
@@ -828,8 +917,8 @@ def _aggregate_rewards(
:param outputs: Sample outputs (modified in-place)
:type outputs: List[_SamplesOutput]
- :param all_rewards_list: Nested list [rm_idx][micro_batch_idx] -> Tensor
- :type all_rewards_list: List[List[torch.Tensor]]
+ :param all_rewards_list: Nested list [rm_idx][micro_batch_idx] -> reward result
+ :type all_rewards_list: List[List[_RewardBatchResult]]
:param is_multi_rm: Whether using multiple reward models
:type is_multi_rm: bool
"""
@@ -838,7 +927,8 @@ def _aggregate_rewards(
for mb_idx in range(num_micro_batches):
# Collect rewards from all RMs for this micro-batch
- same_batch_rewards = [all_rewards_list[rm_idx][mb_idx] for rm_idx in range(num_rms)]
+ same_batch_results = [all_rewards_list[rm_idx][mb_idx] for rm_idx in range(num_rms)]
+ same_batch_rewards = [result.scores for result in same_batch_results]
if is_multi_rm:
# Use custom aggregation function
@@ -850,6 +940,7 @@ def _aggregate_rewards(
rewards, reward_metrics = self.reward_fn(
model_reward_list=same_batch_rewards,
+ model_reward_metrics_list=[result.metrics for result in same_batch_results],
labels=outputs[mb_idx].labels,
queries=queries,
refs=outputs[mb_idx].references,
@@ -859,8 +950,17 @@ def _aggregate_rewards(
outputs[mb_idx].reward_metrics = reward_metrics
else:
# Single RM, use score directly
- outputs[mb_idx].rewards = same_batch_rewards[0]
- outputs[mb_idx].reward_metrics = None
+ outputs[mb_idx].rewards = same_batch_results[0].scores
+ outputs[mb_idx].reward_metrics = same_batch_results[0].metrics
+ # Variant 2 (per-step PRM): forward step_rewards / token
+ # indices from the single reward model into outputs so
+ # _process_experiences / compute_reward can see them. Multi-RM
+ # mode is intentionally NOT supported here — per-step credit
+ # assignment with multiple reward models would require a
+ # bespoke aggregator (open question for follow-up); we keep
+ # the multi-RM path on trajectory-scalar rewards only.
+ outputs[mb_idx].step_rewards = same_batch_results[0].step_rewards
+ outputs[mb_idx].step_token_indices = same_batch_results[0].step_token_indices
# ============================================================================
@@ -893,7 +993,7 @@ class FastExperienceMaker(NaiveExperienceMaker):
processor: Multimodal processor for vision-language models
*args, **kwargs: Arguments passed to parent NaiveExperienceMaker
"""
- def __init__(self, *args, packing_samples: bool = False, processor=None, **kwargs):
+ def __init__(self, *args, packing_samples: bool = False, processor=None, profiler=None, **kwargs):
"""
Initialize FastExperienceMaker.
@@ -913,6 +1013,7 @@ def __init__(self, *args, packing_samples: bool = False, processor=None, **kwarg
self.backend = self.strategy.args.engine_type
self.packing_samples = packing_samples
self.processor = processor
+ self.profiler = profiler if profiler is not None else _NullProfiler()
# Initialize tokenizer (extract from processor if needed)
if self.processor is not None:
@@ -938,31 +1039,9 @@ def __init__(self, *args, packing_samples: bool = False, processor=None, **kwarg
else:
self.multimodal_processor = None
- # For On-Policy Distillation (OPD), prefer dedicated teacher_model_url.
- # Fall back to remote_rm_url with deprecation warning for backwards compatibility.
- if advantage_estimator == "on_policy_distillation":
- teacher_url = getattr(self.strategy.args, 'teacher_model_url', None)
- if teacher_url is not None:
- self.teacher_model_url = teacher_url
- elif self.remote_rm_url is not None:
- import warnings
- warnings.warn(
- "Using --remote_rm_url as teacher URL is deprecated. "
- "Use --teacher_model_url instead.",
- DeprecationWarning,
- stacklevel=2,
- )
- self.teacher_model_url = self.remote_rm_url
- else:
- self.teacher_model_url = None
- rm_url_for_reward_engine = None # Don't use remote_rm_url for rewards in OPD mode
- else:
- self.teacher_model_url = None
- rm_url_for_reward_engine = self.remote_rm_url
-
self.reward_engine = RewardComputationEngine(
reward_model=self.reward_model,
- remote_rm_url=rm_url_for_reward_engine,
+ remote_rm_url=self.remote_rm_url,
custom_reward_func=getattr(self, "custom_reward_func", None),
reward_fn=self.reward_fn,
reward_fn_label_map=getattr(self, "reward_fn_label_map", None),
@@ -1014,83 +1093,76 @@ def make_experience_list(
"""
config = self.strategy.config
- # Normalize images if provided
- if all_images is not None:
- if self.multimodal_processor is None:
- raise ValueError(
- "Multimodal data (images) provided but processor was not initialized. "
- "Please provide a processor when initializing FastExperienceMaker for VLM support."
- )
- all_images = normalize_images(all_images)
-
- # Normalize videos if provided
- if all_videos is not None:
- if self.multimodal_processor is None:
- raise ValueError(
- "Multimodal data (videos) provided but processor was not initialized. "
- "Please provide a processor when initializing FastExperienceMaker for VLM support."
+ with self.profiler.section("collect/total"):
+ if all_images is not None:
+ if self.multimodal_processor is None:
+ raise ValueError(
+ "Multimodal data (images) provided but processor was not initialized. "
+ "Please provide a processor when initializing FastExperienceMaker for VLM support."
+ )
+ all_images = normalize_images(all_images)
+
+ if all_videos is not None:
+ if self.multimodal_processor is None:
+ raise ValueError(
+ "Multimodal data (videos) provided but processor was not initialized. "
+ "Please provide a processor when initializing FastExperienceMaker for VLM support."
+ )
+ all_videos = normalize_videos(all_videos)
+
+ images_num = (get_images_num(all_images) if self.multimodal_processor and all_images is not None else None)
+ videos_num = (get_videos_num(all_videos) if self.multimodal_processor and all_videos is not None else None)
+
+ Timer.start(' generate_samples')
+ with self.profiler.section("collect/generate"):
+ samples_list = self.generate_samples(
+ all_prompts,
+ all_images=all_images,
+ images_num=images_num,
+ all_videos=all_videos,
+ videos_num=videos_num,
+ all_references=all_references,
+ all_labels=all_labels,
+ **generate_kwargs,
)
- all_videos = normalize_videos(all_videos)
-
- # Get image counts
- images_num = (get_images_num(all_images) if self.multimodal_processor and all_images is not None else None)
-
- # Get video counts
- videos_num = (get_videos_num(all_videos) if self.multimodal_processor and all_videos is not None else None)
-
- # ========== Stage 1: Sample Generation ==========
- Timer.start(' generate_samples')
- samples_list = self.generate_samples(
- all_prompts,
- all_images=all_images,
- images_num=images_num,
- all_videos=all_videos,
- videos_num=videos_num,
- all_references=all_references,
- all_labels=all_labels,
- **generate_kwargs,
- )
- Timer.stop(' generate_samples')
+ Timer.stop(' generate_samples')
- torch.distributed.barrier()
- torch.cuda.synchronize()
+ torch.distributed.barrier()
+ torch.cuda.synchronize()
- # ========== Stage 2: Shard-Parallel Preprocessing ==========
- all_samples = self.strategy.sp_data_processor.preprocess(samples_list)
+ with self.profiler.section("collect/sp_preprocess"):
+ all_samples = self.strategy.sp_data_processor.preprocess(samples_list)
- # ========== Stage 3: Model Inference ==========
- Timer.start(' make_experience')
- experiences = self._make_experience_list_by_model(all_samples)
- Timer.stop(' make_experience')
+ Timer.start(' make_experience')
+ with self.profiler.section("collect/model_total"):
+ experiences = self._make_experience_list_by_model(all_samples)
+ Timer.stop(' make_experience')
- # ========== Stage 4: Shard-Parallel Postprocessing ==========
- experiences = self.strategy.sp_data_processor.postprocess(experiences)
+ with self.profiler.section("collect/sp_postprocess"):
+ experiences = self.strategy.sp_data_processor.postprocess(experiences)
- # ========== Stage 5: Reward Processing ==========
- experiences, rewards = self._process_experiences( # GRPO's -mean / std operation is performed in this method
- experiences, generate_kwargs.get("max_new_tokens", 1024)
- )
+ with self.profiler.section("collect/process_rewards"):
+ experiences, rewards = self._process_experiences(
+ experiences, generate_kwargs.get("max_new_tokens", 1024)
+ )
- # ========== Stage 6: Multi-Image/Video Handling ==========
- if (images_num is not None and not all(num == 1 for num in images_num)) or \
- (videos_num is not None and not all(num == 1 for num in videos_num)):
- # Expand image_num by n_samples_per_prompt
- expanded_images_num = sum([[num] * config.n_samples_per_prompt
- for num in images_num], []) if images_num is not None else None
+ if (images_num is not None and not all(num == 1 for num in images_num)) or \
+ (videos_num is not None and not all(num == 1 for num in videos_num)):
+ expanded_images_num = sum([[num] * config.n_samples_per_prompt
+ for num in images_num], []) if images_num is not None else None
- expanded_videos_num = sum([[num] * config.n_samples_per_prompt
- for num in videos_num], []) if videos_num is not None else None
+ expanded_videos_num = sum([[num] * config.n_samples_per_prompt
+ for num in videos_num], []) if videos_num is not None else None
- self._process_multi_image_video_thws(experiences, expanded_images_num, expanded_videos_num)
+ self._process_multi_image_video_thws(experiences, expanded_images_num, expanded_videos_num)
- # ========== Stage 6.5: On-Policy Distillation Teacher Log-Probs ==========
- if config.advantage_estimator == "on_policy_distillation":
- self._fetch_teacher_logprobs(experiences)
+ if config.advantage_estimator == "on_policy_distillation":
+ self._fetch_teacher_logprobs(experiences)
- # ========== Stage 7: Advantage Computation ==========
- experiences = self._compute_advantages_and_returns(experiences, rewards, generate_kwargs)
+ with self.profiler.section("collect/advantages"):
+ experiences = self._compute_advantages_and_returns(experiences, rewards, generate_kwargs)
- return experiences
+ return experiences
@torch.no_grad()
def generate_samples(
@@ -1141,7 +1213,6 @@ def generate_samples(
is_multimodal = all_images is not None or all_videos is not None
n_samples = config.n_samples_per_prompt
- # Initialize multimodal-specific variables to None
all_images_num = None
all_videos_num = None
all_images_pixel_values = None
@@ -1149,144 +1220,144 @@ def generate_samples(
all_images_grid_thw = None
all_videos_grid_thw = None
- # ========== Configure Sampling Parameters ==========
- if config.engine_type == "vllm":
- # vLLM-specific sampling configuration
- # Note: vLLM is an optional dependency. Install with: pip install "LightRFT[vllm]"
- # This import is conditional and only executed when engine_type is "vllm"
- from vllm import SamplingParams
-
- # For vllm>=0.13.0, truncate_prompt_tokens must not exceed max_model_len
- # For older versions, we can use 8192 directly without validation
- if vllm_ge_0130():
- max_model_len = self.strategy.inference_engine.llm_engine.model_config.max_model_len
- truncate_tokens = min(8192, max_model_len)
+ with self.profiler.section("collect/generate_prepare"):
+ if config.engine_type == "vllm":
+ from vllm import SamplingParams
+
+ if vllm_ge_0130():
+ max_model_len = self.strategy.inference_engine.llm_engine.model_config.max_model_len
+ truncate_tokens = min(8192, max_model_len)
+ else:
+ truncate_tokens = 8192
+
+ sampling_params = SamplingParams(
+ temperature=generate_kwargs.get("temperature", 1.0),
+ top_p=generate_kwargs.get("top_p", 1.0),
+ top_k=generate_kwargs.get("top_k", -1),
+ repetition_penalty=generate_kwargs.get("repetition_penalty", 1.0),
+ max_tokens=generate_kwargs.get("max_new_tokens", 1024),
+ min_tokens=generate_kwargs.get("min_new_tokens", 1),
+ skip_special_tokens=generate_kwargs.get("skip_special_tokens", False),
+ include_stop_str_in_output=True,
+ ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1",
+ truncate_prompt_tokens=truncate_tokens,
+ )
+ elif config.engine_type == "sglang":
+ sampling_params = dict(
+ n=1,
+ temperature=generate_kwargs.get("temperature", 1.0),
+ top_p=generate_kwargs.get("top_p", 1.0),
+ top_k=generate_kwargs.get("top_k", -1),
+ max_new_tokens=generate_kwargs.get("max_new_tokens", 1024),
+ presence_penalty=0.0,
+ frequency_penalty=0.0,
+ repetition_penalty=generate_kwargs.get("repetition_penalty", 1.0),
+ skip_special_tokens=generate_kwargs.get("skip_special_tokens", False),
+ spaces_between_special_tokens=True,
+ ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1",
+ )
+ elif config.engine_type == "hf":
+ max_new_tokens = generate_kwargs.get("max_new_tokens", 1024)
+ local_hf_max_new_tokens = int(getattr(config, "local_hf_max_new_tokens", 0) or 0)
+ if local_hf_max_new_tokens > 0:
+ max_new_tokens = min(max_new_tokens, local_hf_max_new_tokens)
+ sampling_params = dict(
+ temperature=generate_kwargs.get("temperature", 1.0),
+ top_p=generate_kwargs.get("top_p", 1.0),
+ top_k=generate_kwargs.get("top_k", -1),
+ max_new_tokens=max_new_tokens,
+ min_new_tokens=generate_kwargs.get("min_new_tokens", 1),
+ do_sample=generate_kwargs.get("do_sample", True),
+ repetition_penalty=generate_kwargs.get("repetition_penalty", 1.0),
+ no_repeat_ngram_size=generate_kwargs.get("no_repeat_ngram_size", 0),
+ skip_special_tokens=generate_kwargs.get("skip_special_tokens", False),
+ ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1",
+ )
else:
- truncate_tokens = 8192
-
- sampling_params = SamplingParams(
- temperature=generate_kwargs.get("temperature", 1.0),
- top_p=generate_kwargs.get("top_p", 1.0),
- top_k=generate_kwargs.get("top_k", -1),
- max_tokens=generate_kwargs.get("max_new_tokens", 1024),
- min_tokens=generate_kwargs.get("min_new_tokens", 1),
- skip_special_tokens=generate_kwargs.get("skip_special_tokens", False),
- include_stop_str_in_output=True,
- ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1",
- truncate_prompt_tokens=truncate_tokens,
- )
- elif config.engine_type == "sglang":
- # SGLang-specific sampling configuration (default backend)
- sampling_params = dict(
- n=1,
- temperature=generate_kwargs.get("temperature", 1.0),
- top_p=generate_kwargs.get("top_p", 1.0),
- top_k=generate_kwargs.get("top_k", -1),
- max_new_tokens=generate_kwargs.get("max_new_tokens", 1024),
- presence_penalty=0.0,
- frequency_penalty=0.0,
- repetition_penalty=1.0,
- skip_special_tokens=generate_kwargs.get("skip_special_tokens", False),
- spaces_between_special_tokens=True,
- ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1",
- )
- else:
- raise ValueError(f"Unsupported engine type: {config.engine_type}")
+ raise ValueError(f"Unsupported engine type: {config.engine_type}")
- # ========== Expand Labels ==========
- if all_labels is not None:
- all_labels = sum([[label] * n_samples for label in all_labels], [])
+ if all_labels is not None:
+ all_labels = sum([[label] * n_samples for label in all_labels], [])
- # ========== Process Multimodal Data ==========
- if is_multimodal:
- processed_data = self.multimodal_processor.process_multimodal_batch(
- all_prompts=all_prompts,
- all_images=all_images,
- all_references=all_references,
- images_num=images_num,
- n_samples_per_prompt=n_samples,
- all_videos=all_videos,
- videos_num=videos_num,
- )
- all_prompt_token_ids = processed_data["all_prompt_token_ids"]
- all_prompts = processed_data["all_prompts"]
- all_images = processed_data["all_images"]
- all_videos = processed_data["all_videos"]
- all_images_num = processed_data["all_images_num"]
- all_videos_num = processed_data["all_videos_num"]
- all_images_grid_thw = processed_data["all_images_grid_thw"]
- all_videos_grid_thw = processed_data["all_videos_grid_thw"]
- all_images_pixel_values = processed_data["all_images_pixel_values"]
- all_videos_pixel_values = processed_data["all_videos_pixel_values"]
- all_references = processed_data.get("all_references", None)
- else:
- # Text-only processing
- tokenized = self.tokenize_fn(all_prompts, self.prompt_max_len, padding=False)
- all_prompt_token_ids = sum([[token_ids] * n_samples for token_ids in tokenized["input_ids"]], [])
+ if is_multimodal:
+ processed_data = self.multimodal_processor.process_multimodal_batch(
+ all_prompts=all_prompts,
+ all_images=all_images,
+ all_references=all_references,
+ images_num=images_num,
+ n_samples_per_prompt=n_samples,
+ all_videos=all_videos,
+ videos_num=videos_num,
+ )
+ all_prompt_token_ids = processed_data["all_prompt_token_ids"]
+ all_prompts = processed_data["all_prompts"]
+ all_images = processed_data["all_images"]
+ all_videos = processed_data["all_videos"]
+ all_images_num = processed_data["all_images_num"]
+ all_videos_num = processed_data["all_videos_num"]
+ all_images_grid_thw = processed_data["all_images_grid_thw"]
+ all_videos_grid_thw = processed_data["all_videos_grid_thw"]
+ all_images_pixel_values = processed_data["all_images_pixel_values"]
+ all_videos_pixel_values = processed_data["all_videos_pixel_values"]
+ all_references = processed_data.get("all_references", None)
+ else:
+ tokenized = self.tokenize_fn(all_prompts, self.prompt_max_len, padding=False)
+ all_prompt_token_ids = sum([[token_ids] * n_samples for token_ids in tokenized["input_ids"]], [])
- # ========== Generate via Inference Engine ==========
- # Call fire_sampling function or direct generation
try:
- if hasattr(self.strategy.args, 'use_fire') and self.strategy.args.use_fire:
- # Use FIRE sampling (Flaming-hot Initiation with Regular Execution)
- # According to the paper (https://arxiv.org/abs/2410.21236), FIRE only changes
- # the temperature for the first token. All other sampling parameters (top_k, top_p, etc.)
- # are kept the same between first token and remaining tokens.
- sleep_engine = getattr(self.strategy.args, 'enable_engine_sleep', False)
-
- def generate_fn(
- sampling_params,
- all_prompt_token_ids,
- all_prompts=None,
- all_images=None,
- all_videos=None,
- images_num=None,
- videos_num=None,
- ):
- return self.strategy.gather_and_generate(
- sampling_params=sampling_params,
+ with self.profiler.section("collect/generate_engine"):
+ if hasattr(self.strategy.args, 'use_fire') and self.strategy.args.use_fire:
+ sleep_engine = getattr(self.strategy.args, 'enable_engine_sleep', False)
+
+ def generate_fn(
+ sampling_params,
+ all_prompt_token_ids,
+ all_prompts=None,
+ all_images=None,
+ all_videos=None,
+ images_num=None,
+ videos_num=None,
+ ):
+ return self.strategy.gather_and_generate(
+ sampling_params=sampling_params,
+ all_prompt_token_ids=all_prompt_token_ids,
+ all_prompts=all_prompts,
+ all_images=all_images,
+ all_videos=all_videos,
+ images_num=images_num,
+ videos_num=videos_num,
+ sleep_engine=sleep_engine,
+ )
+
+ all_outputs = fire_sampling(
all_prompt_token_ids=all_prompt_token_ids,
+ generate_fn=generate_fn,
+ engine_type=config.engine_type,
+ first_token_temperature=generate_kwargs.get("first_token_temperature", 10.0),
+ temperature=generate_kwargs.get("temperature", 1.0),
+ is_multimodal=is_multimodal,
all_prompts=all_prompts,
all_images=all_images,
all_videos=all_videos,
- images_num=images_num,
- videos_num=videos_num,
- sleep_engine=sleep_engine,
+ all_images_num=all_images_num,
+ all_videos_num=all_videos_num,
+ sampling_params=sampling_params,
+ )
+ else:
+ all_outputs = self.strategy.gather_and_generate(
+ sampling_params=sampling_params,
+ all_prompt_token_ids=all_prompt_token_ids,
+ all_prompts=all_prompts if is_multimodal else None,
+ sleep_engine=self.strategy.args.enable_engine_sleep,
+ all_images=all_images if is_multimodal else None,
+ all_videos=all_videos if is_multimodal else None,
+ images_num=all_images_num if is_multimodal else None,
+ videos_num=all_videos_num if is_multimodal else None,
+ all_images_pixel_values=all_images_pixel_values if is_multimodal else None,
+ all_videos_pixel_values=all_videos_pixel_values if is_multimodal else None,
+ all_images_grid_thw=all_images_grid_thw if is_multimodal else None,
+ all_videos_grid_thw=all_videos_grid_thw if is_multimodal else None,
)
-
- all_outputs = fire_sampling(
- all_prompt_token_ids=all_prompt_token_ids,
- generate_fn=generate_fn,
- engine_type=config.engine_type,
- first_token_temperature=generate_kwargs.get(
- "first_token_temperature",
- getattr(self.strategy.args, "first_token_temperature", 10.0),
- ),
- temperature=generate_kwargs.get("temperature", 1.0),
- # Note: first_token_top_k and first_token_top_p are deprecated and ignored
- # The function will use top_k and top_p from sampling_params for both stages
- is_multimodal=is_multimodal,
- all_prompts=all_prompts,
- all_images=all_images,
- all_videos=all_videos,
- all_images_num=all_images_num,
- all_videos_num=all_videos_num,
- sampling_params=sampling_params,
- tokenizer=self.tokenizer,
- )
- else:
- # maybe this can be called in if and else respectively? or like this?
- # Use original single-shot generation
- all_outputs = self.strategy.gather_and_generate(
- sampling_params=sampling_params,
- all_prompt_token_ids=all_prompt_token_ids,
- all_prompts=all_prompts if is_multimodal else None,
- sleep_engine=self.strategy.args.enable_engine_sleep,
- all_images=all_images if is_multimodal else None,
- all_videos=all_videos if is_multimodal else None,
- images_num=all_images_num if is_multimodal else None,
- videos_num=all_videos_num if is_multimodal else None,
- )
except ValueError as e:
if "prompt" in str(e) and "too long" in str(e):
self.strategy.print(f"[Skip] {e}")
@@ -1294,68 +1365,67 @@ def generate_fn(
else:
raise
- # ========== Process Outputs into Samples ==========
- samples_list = []
- image_patch_idx = 0
- video_patch_idx = 0
- image_start_idx = 0
- video_start_idx = 0
-
- for i in range(0, len(all_outputs), config.micro_rollout_batch_size):
- micro_batch_outputs = all_outputs[i:i + config.micro_rollout_batch_size]
- micro_batch_prompts = all_prompts[i:i + config.micro_rollout_batch_size]
-
- # Extract micro-batch data
- micro_batch_grid_thw = None
- micro_batch_video_grid_thw = None
- micro_batch_raw_images = None
-
- if is_multimodal:
- rollout_image_count = sum(all_images_num[i:i + config.micro_rollout_batch_size])
- micro_batch_grid_thw = all_images_grid_thw[image_start_idx:image_start_idx + rollout_image_count]
- micro_batch_raw_images = all_images[i:i + config.micro_rollout_batch_size]
- image_start_idx += rollout_image_count
-
- rollout_video_count = sum(all_videos_num[i:i + config.micro_rollout_batch_size])
- micro_batch_video_grid_thw = all_videos_grid_thw[video_start_idx:video_start_idx + rollout_video_count]
- video_start_idx += rollout_video_count
-
- micro_batch_references = (all_references[i:i + config.micro_rollout_batch_size] if all_references else None)
- micro_batch_labels = (all_labels[i:i + config.micro_rollout_batch_size] if all_labels else None)
-
- # Build samples
- if not self.packing_samples:
- sample, updated_patch_idx, updated_video_patch_idx = self._build_unpacked_sample(
- outputs=micro_batch_outputs,
- prompts=micro_batch_prompts,
- labels=micro_batch_labels,
- references=micro_batch_references,
- is_multimodal=is_multimodal,
- grid_thw=micro_batch_grid_thw,
- video_grid_thw=micro_batch_video_grid_thw,
- raw_images=micro_batch_raw_images,
- pixel_values=all_images_pixel_values if is_multimodal else None,
- pixel_values_videos=all_videos_pixel_values if is_multimodal else None,
- images_num=all_images_num[i:i + config.micro_rollout_batch_size] if is_multimodal else None,
- videos_num=all_videos_num[i:i + config.micro_rollout_batch_size] if is_multimodal else None,
- image_patch_idx=image_patch_idx,
- video_patch_idx=video_patch_idx,
- )
- # Update patch indices from the returned values
- if updated_patch_idx is not None:
- image_patch_idx = updated_patch_idx
- if updated_video_patch_idx is not None:
- video_patch_idx = updated_video_patch_idx
- samples_list.append(sample)
- else:
- # Packed samples
- sample = self._build_packed_sample(
- outputs=micro_batch_outputs,
- prompts=micro_batch_prompts,
- labels=micro_batch_labels,
- references=micro_batch_references,
+ with self.profiler.section("collect/generate_build_samples"):
+ samples_list = []
+ image_patch_idx = 0
+ video_patch_idx = 0
+ image_start_idx = 0
+ video_start_idx = 0
+
+ for i in range(0, len(all_outputs), config.micro_rollout_batch_size):
+ micro_batch_outputs = all_outputs[i:i + config.micro_rollout_batch_size]
+ micro_batch_prompts = all_prompts[i:i + config.micro_rollout_batch_size]
+
+ micro_batch_grid_thw = None
+ micro_batch_video_grid_thw = None
+ micro_batch_raw_images = None
+
+ if is_multimodal:
+ rollout_image_count = sum(all_images_num[i:i + config.micro_rollout_batch_size])
+ micro_batch_grid_thw = all_images_grid_thw[image_start_idx:image_start_idx + rollout_image_count]
+ micro_batch_raw_images = all_images[i:i + config.micro_rollout_batch_size]
+ image_start_idx += rollout_image_count
+
+ rollout_video_count = sum(all_videos_num[i:i + config.micro_rollout_batch_size])
+ micro_batch_video_grid_thw = all_videos_grid_thw[video_start_idx:video_start_idx +
+ rollout_video_count]
+ video_start_idx += rollout_video_count
+
+ micro_batch_references = (
+ all_references[i:i + config.micro_rollout_batch_size] if all_references else None
)
- samples_list.append(sample)
+ micro_batch_labels = (all_labels[i:i + config.micro_rollout_batch_size] if all_labels else None)
+
+ if not self.packing_samples:
+ sample, updated_patch_idx, updated_video_patch_idx = self._build_unpacked_sample(
+ outputs=micro_batch_outputs,
+ prompts=micro_batch_prompts,
+ labels=micro_batch_labels,
+ references=micro_batch_references,
+ is_multimodal=is_multimodal,
+ grid_thw=micro_batch_grid_thw,
+ video_grid_thw=micro_batch_video_grid_thw,
+ raw_images=micro_batch_raw_images,
+ pixel_values=all_images_pixel_values if is_multimodal else None,
+ pixel_values_videos=all_videos_pixel_values if is_multimodal else None,
+ images_num=all_images_num[i:i + config.micro_rollout_batch_size] if is_multimodal else None,
+ videos_num=all_videos_num[i:i + config.micro_rollout_batch_size] if is_multimodal else None,
+ image_patch_idx=image_patch_idx,
+ video_patch_idx=video_patch_idx,
+ )
+ if updated_patch_idx is not None:
+ image_patch_idx = updated_patch_idx
+ if updated_video_patch_idx is not None:
+ video_patch_idx = updated_video_patch_idx
+ samples_list.append(sample)
+ else:
+ sample = self._build_packed_sample(
+ outputs=micro_batch_outputs,
+ prompts=micro_batch_prompts,
+ labels=micro_batch_labels,
+ references=micro_batch_references,
+ )
+ samples_list.append(sample)
# Report timing
torch.cuda.synchronize()
@@ -1498,7 +1568,6 @@ def _fetch_teacher_logprobs(
:type experiences: List[Union[Experience, ExperienceVL]]
"""
- # Get teacher URL from config
teacher_url = self.teacher_model_url
if isinstance(teacher_url, list):
teacher_url = teacher_url[0] if teacher_url else None
@@ -1511,18 +1580,13 @@ def _fetch_teacher_logprobs(
Timer.start(' fetch_teacher_logprobs')
for exp in experiences:
- sequences = exp.sequences # [batch_size, seq_len]
- attention_mask = exp.attention_mask # [batch_size, seq_len]
- action_mask = exp.action_mask # [batch_size, num_tokens]
+ sequences = exp.sequences
+ attention_mask = exp.attention_mask
+ action_mask = exp.action_mask
- # response_lengths must be int for slicing
response_lengths = action_mask.sum(dim=-1).int().tolist()
num_tokens = action_mask.shape[1]
- # Strip padding tokens before sending to SGLang.
- # sequences is [prompt, response, eos, pad, pad, ...] — the padding
- # tokens would cause SGLang to return logprobs for pad positions,
- # making the [-resp_len:] slice grab wrong tokens.
input_ids_list = []
for j in range(sequences.shape[0]):
valid_len = int(attention_mask[j].sum().item())
@@ -1542,14 +1606,6 @@ def _fetch_teacher_logprobs(
finally:
loop.close()
- # Align teacher log probs to action_log_probs shape [batch_size, num_tokens].
- # Use action_mask indices directly — works regardless of left/right padding.
- #
- # Correctness: teacher_lp_list[i] contains teacher logprobs for response
- # tokens in first→last order (from teacher_lp[-resp_len:]).
- # valid_indices = ascending positions where action_mask==1 (real response tokens).
- # So aligned_teacher_lp[i, valid_indices[k]] = tlp[k] correctly maps the k-th
- # teacher logprob to the k-th response token position, regardless of padding direction.
batch_size = sequences.shape[0]
aligned_teacher_lp = torch.zeros(batch_size, num_tokens, dtype=torch.float32)
for i, (tlp, resp_len) in enumerate(zip(teacher_lp_list, response_lengths)):
@@ -1609,6 +1665,89 @@ def _process_experiences(
# Use calculator's preprocess_rewards method
return self.advantage_calculator.preprocess_rewards(rewards, experiences, max_new_tokens)
+ def _apply_step_reward_group_norm(self, experiences: List) -> None:
+ """Variant 2 (per-step PRM) GRPO-style step-level baseline subtraction.
+
+ Active only when ``self.strategy.config.per_step_reward_mode == "group_norm"``.
+ For each step k, computes mean/std across the K trajectories that share
+ the same prompt (group), then replaces step_rewards with
+ ``(s - mean) / std`` so that scattered token-level rewards are zero-mean
+ signed advantages — analogous to GRPO trajectory-level baseline but at
+ step granularity.
+
+ This is the difference between paper variant 2's two interpretations:
+ - "raw": scatter raw sigmoid step_score (paper Figure ablation)
+ - "group_norm": scatter group-normalized step_score (GRPO convention)
+
+ Padding (step_token_indices < 0) is masked so it doesn't pollute mean/std.
+ Operates in-place on ``experience.info["step_rewards"]``.
+ """
+ config = self.strategy.config
+ K = config.n_samples_per_prompt
+ if K is None or K < 2:
+ return # need >= 2 samples per group for std
+
+ all_sr: List[torch.Tensor] = []
+ all_sti: List[torch.Tensor] = []
+ for exp in experiences:
+ sr_list = exp.info.get("step_rewards")
+ sti_list = exp.info.get("step_token_indices")
+ if sr_list is None or sti_list is None:
+ return # no per-step data present
+ all_sr.extend(sr_list)
+ all_sti.extend(sti_list)
+
+ if not all_sr:
+ return
+ B = len(all_sr)
+ if B % K != 0:
+ warnings.warn(
+ f"per_step_reward_mode=group_norm: trajectory count {B} not divisible "
+ f"by n_samples_per_prompt {K}; skipping step-level group norm."
+ )
+ return
+
+ max_steps = max((t.numel() for t in all_sr), default=0)
+ if max_steps == 0:
+ return
+
+ sr_padded = torch.zeros(B, max_steps, dtype=torch.float32)
+ valid = torch.zeros(B, max_steps, dtype=torch.bool)
+ for i, (sr, sti) in enumerate(zip(all_sr, all_sti)):
+ n = sr.numel()
+ if n > 0:
+ sr_padded[i, :n] = sr.float()
+ valid[i, :n] = (sti >= 0)
+
+ G = B // K
+ sr_g = sr_padded.reshape(G, K, max_steps)
+ valid_g = valid.reshape(G, K, max_steps).float()
+
+ n_valid = valid_g.sum(dim=1, keepdim=True).clamp(min=1.0)
+ sr_masked = sr_g * valid_g
+ mean = sr_masked.sum(dim=1, keepdim=True) / n_valid
+ sq = ((sr_g - mean) * valid_g) ** 2
+ var = sq.sum(dim=1, keepdim=True) / n_valid
+ std = (var + 1e-9).sqrt()
+ sr_normed = (sr_g - mean) / std
+ # Zero out invalid entries so they don't show up if accidentally scattered.
+ sr_normed = sr_normed * valid_g
+ sr_normed_flat = sr_normed.reshape(B, max_steps)
+
+ # Write back, preserving each trajectory's actual step count (n).
+ idx = 0
+ for exp in experiences:
+ sr_list = exp.info["step_rewards"]
+ new_sr_list: List[torch.Tensor] = []
+ for sr in sr_list:
+ n = sr.numel()
+ if n > 0:
+ new_sr_list.append(sr_normed_flat[idx, :n].cpu().to(sr.dtype))
+ else:
+ new_sr_list.append(sr)
+ idx += 1
+ exp.info["step_rewards"] = new_sr_list
+
def _compute_advantages_and_returns(
self,
experiences: List[ExperienceVL],
@@ -1632,6 +1771,12 @@ def _compute_advantages_and_returns(
"""
config = self.strategy.config
+ # Variant 2 step-level baseline subtraction (cross-experience, group-aware).
+ # Only active when label produced step_rewards AND user set
+ # --per_step_reward_mode=group_norm.
+ if getattr(config, "per_step_reward_mode", "raw") == "group_norm":
+ self._apply_step_reward_group_norm(experiences)
+
for experience, reward in zip(experiences, rewards):
reward = reward.to("cuda")
processed_reward = reward.clone() # TODO:check
@@ -1653,12 +1798,42 @@ def _compute_advantages_and_returns(
processed_reward = torch.clamp(processed_reward, -config.reward_clip, config.reward_clip)
# ========== Final Reward (with KL penalty) ==========
+ # Variant 2 (per-step PRM): if experience carries
+ # step_rewards/step_token_indices (placed by _pack_experience
+ # when the reward model emitted them), build a padded
+ # (batch, max_steps) tensor for compute_reward to scatter into
+ # per-step boundary tokens. Trajectories with empty step lists
+ # get all-(-1) indices => compute_reward filters them out and
+ # falls back to the trajectory-scalar (EOS) path for that row.
+ step_rewards_padded = None
+ step_indices_padded = None
+ step_rewards_list = experience.info.get("step_rewards")
+ step_indices_list = experience.info.get("step_token_indices")
+ if (
+ step_rewards_list is not None and step_indices_list is not None
+ and any(t.numel() > 0 for t in step_rewards_list)
+ ):
+ max_steps = max(t.numel() for t in step_rewards_list)
+ if max_steps > 0:
+ B = len(step_rewards_list)
+ step_rewards_padded = torch.zeros(B, max_steps, dtype=torch.float32, device="cuda")
+ step_indices_padded = torch.full((B, max_steps), -1, dtype=torch.long, device="cuda")
+ for i, (sr, sti) in enumerate(zip(step_rewards_list, step_indices_list)):
+ n = sr.numel()
+ if n > 0:
+ step_rewards_padded[
+ i, :n] = sr.to(step_rewards_padded.device, dtype=step_rewards_padded.dtype)
+ step_indices_padded[
+ i, :n] = sti.to(step_indices_padded.device, dtype=step_indices_padded.dtype)
+
final_reward = compute_reward(
processed_reward,
self.kl_ctl.value,
experience.kl,
action_mask=experience.action_mask,
num_actions=experience.info["num_actions"],
+ step_rewards=step_rewards_padded,
+ step_token_indices=step_indices_padded,
)
# ========== Advantage Estimation ==========
@@ -1714,80 +1889,59 @@ def _make_experience_list_by_model(
device = get_current_device()
vlm_mode = isinstance(all_samples[0], SamplesVL)
- # ========== Stage 0: Preprocessing ==========
outputs = [self._preprocess_sample(sample, vlm_mode, device) for sample in all_samples]
- # ========== Stage 1: Actor Forward ==========
Timer.start(' actor_logprob')
- # Check if we need to compute entropy for high-entropy token filtering
- need_entropy = hasattr(self.actor, 'high_entropy_token_ratio') and self.actor.high_entropy_token_ratio > 0.0
- for output in outputs:
- if need_entropy:
- # Request full output to get action_entropy
- action_log_probs, model_output = self.actor(
- output.sequences,
- output.num_actions,
- output.attention_mask,
- packed_seq_lens=output.packed_seq_lens,
- return_output=True,
- **output.inputs_extra_kwargs
- )
- output.action_log_probs = action_log_probs
- # Extract action_entropy if available
- if "action_entropy" in model_output:
- output.action_entropy = model_output["action_entropy"]
- else:
- output.action_log_probs = self.actor(
- output.sequences,
- output.num_actions,
- output.attention_mask,
- packed_seq_lens=output.packed_seq_lens,
- **output.inputs_extra_kwargs
- )
+ with self.profiler.section("collect/model/actor_logprob"):
+ need_entropy = hasattr(self.actor, 'high_entropy_token_ratio') and self.actor.high_entropy_token_ratio > 0.0
+ for output in outputs:
+ if need_entropy:
+ action_log_probs, model_output = self.actor(
+ output.sequences,
+ output.num_actions,
+ output.attention_mask,
+ packed_seq_lens=output.packed_seq_lens,
+ return_output=True,
+ **output.inputs_extra_kwargs
+ )
+ output.action_log_probs = action_log_probs
+ if "action_entropy" in model_output:
+ output.action_entropy = model_output["action_entropy"]
+ else:
+ output.action_log_probs = self.actor(
+ output.sequences,
+ output.num_actions,
+ output.attention_mask,
+ packed_seq_lens=output.packed_seq_lens,
+ **output.inputs_extra_kwargs
+ )
Timer.stop(' actor_logprob')
- # ========== Stage 2: Initial Model ==========
if self.initial_model is not None:
- # Note: Manual reload/offload is safe for initial_model because:
- # 1. It's initialized with is_training=False (see train_colocate.py:207)
- # 2. This means FSDP's CPUOffloadPolicy is NOT enabled (see fsdpv2.py:375)
- # 3. Without CPUOffloadPolicy, FSDP doesn't automatically manage parameter movement
- # 4. We can safely use manual reload_model() to move model from CPU to GPU
- # 5. After computing base_action_log_probs, we offload back to CPU to save memory
- # This pattern works because there's no conflict with FSDP's automatic management.
- self.strategy.reload_model(self.initial_model)
- for output in outputs:
- output.base_action_log_probs = self.initial_model(
- output.sequences,
- output.num_actions,
- output.attention_mask,
- packed_seq_lens=output.packed_seq_lens,
- **output.inputs_extra_kwargs
- )
- # Offload back to CPU to free GPU memory for subsequent stages
- self.strategy.offload_model(self.initial_model)
+ with self.profiler.section("collect/model/reference_logprob"):
+ self.strategy.reload_model(self.initial_model)
+ for output in outputs:
+ output.base_action_log_probs = self.initial_model(
+ output.sequences,
+ output.num_actions,
+ output.attention_mask,
+ packed_seq_lens=output.packed_seq_lens,
+ **output.inputs_extra_kwargs
+ )
+ self.strategy.offload_model(self.initial_model)
- # ========== Stage 3: Critic ==========
- Timer.start(' critic')
if self.critic is not None:
- # Note: When critic is initialized with is_training=True and fsdp_cpu_offload=True,
- # FSDP's CPUOffloadPolicy automatically manages parameter movement between CPU/GPU.
- # Manual reload_model/offload_model calls will conflict with FSDP's automatic management
- # and cause "FSDP parameters should be materialized on CPU" error.
- # The CPUOffloadPolicy will automatically:
- # 1. Prefetch parameters from CPU to GPU before forward pass
- # 2. Offload parameters back to CPU after forward pass
- # This is the recommended approach for memory-efficient training with FSDP2.
- for output in outputs:
- output.value = self.critic(
- output.sequences, output.num_actions, output.attention_mask, **output.inputs_extra_kwargs
- )
- Timer.stop(' critic')
+ with self.profiler.section("collect/model/critic_forward"):
+ self.strategy.reload_model(self.critic)
+ for output in outputs:
+ output.value = self.critic(
+ output.sequences, output.num_actions, output.attention_mask, **output.inputs_extra_kwargs
+ )
+ self.strategy.offload_model(self.critic)
- # ========== Stage 4: Reward Models ==========
- self.reward_engine.compute_rewards(outputs, vlm_mode, device)
+ with self.profiler.section("collect/model/reward_forward"):
+ self.reward_engine.compute_rewards(outputs, vlm_mode, device)
- # ========== Stage 5: Assemble Experiences ==========
return [self._pack_experience(output, vlm_mode) for output in outputs]
def _preprocess_sample(
@@ -1832,9 +1986,6 @@ def _preprocess_sample(
"pixel_values_videos": sample.pixel_values_videos,
"video_grid_thw": sample.video_grid_thws,
}
- # Audio-language actors expect audio_values; pipeline stores them in pixel_values slot
- if "audio_values" in self._actor_supported_params:
- candidate_params["audio_values"] = candidate_params.get("pixel_values")
# Filter to only include supported parameters
extra_kwargs = {
@@ -1892,16 +2043,16 @@ def _fix_qwen_vl_image_tokens(
return
config = self.strategy.unwrap_model(self.actor.model).config
- image_token_id = config.image_token_id
+ image_token_id = getattr(config, "image_token_id", getattr(config, "image_token_index", None))
+ if image_token_id is None:
+ return
num_tokens = (sequences == image_token_id).sum()
- vision_config = getattr(config, "vision_config", None)
- spatial_merge_size = getattr(vision_config, "spatial_merge_size", 2)
- merge_length = spatial_merge_size ** 2
-
- if sample.image_grid_thws is not None and sample.image_grid_thws.numel() > 0:
- num_patches = int(torch.prod(sample.image_grid_thws, dim=-1).sum().item() // merge_length)
+ # Qwen2-VL style processors usually flatten patches (dim < 4), while some
+ # VLMs such as URSA keep one image tensor per item (dim == 4). Handle both.
+ if sample.pixel_values.dim() >= 4:
+ num_patches = sample.pixel_values.shape[0]
else:
- num_patches = sample.pixel_values.shape[0] // merge_length
+ num_patches = sample.pixel_values.shape[0] // 4
if num_tokens != num_patches:
self.strategy.print(
@@ -1972,6 +2123,17 @@ def _pack_experience(
if output.reward_metrics is not None:
info['reward_metrics'] = output.reward_metrics
+ # Variant 2 (per-step PRM rewards): forward to experience.info so
+ # downstream advantage computation can build per-token reward signals.
+ # Stored as Python lists of CPU tensors (variable length per traj).
+ # _process_experiences chunks experiences along micro_batch_size,
+ # so we keep the *micro-batch local* indexing here — i.e. info
+ # carries the lists scoped to this single _SamplesOutput.
+ if output.step_rewards is not None:
+ info['step_rewards'] = output.step_rewards
+ if output.step_token_indices is not None:
+ info['step_token_indices'] = output.step_token_indices
+
# Create Experience object
if vlm:
return ExperienceVL(
@@ -1991,8 +2153,6 @@ def _pack_experience(
info=info,
kl=kl,
action_entropy=output.action_entropy,
- labels=output.labels, # data source labels (if available, e.g., "gsm8k_rule")
- references=output.references, # ground truth references (if available, e.g., correct answers)
)
else:
return Experience(
diff --git a/lightrft/trainer/ppo_trainer_vl.py b/lightrft/trainer/ppo_trainer_vl.py
index c9cdb8e1..08b406ae 100644
--- a/lightrft/trainer/ppo_trainer_vl.py
+++ b/lightrft/trainer/ppo_trainer_vl.py
@@ -1,7 +1,9 @@
import os
import sys
import os.path
+from collections import defaultdict
from abc import ABC
+from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional
import torch
@@ -15,10 +17,28 @@
from lightrft.models.actor_modality import ActorModality, get_supported_parameters
from lightrft.models.utils import masked_mean, unpacking_samples, compute_approx_kl
from lightrft.utils.distributed_sampler import DistributedSampler
-from lightrft.utils import rotate_ckpt_dirs
from lightrft.trainer import AdaptiveKLController, ExperienceVL, FixedKLController, NaiveExperienceMakerVL, NaiveReplayBufferVL # noqa
+class _NullStepProfiler:
+ @contextmanager
+ def section(self, _name: str):
+ yield
+
+ @contextmanager
+ def phase(self, _name: str):
+ yield
+
+ def start_step(self, *_args, **_kwargs):
+ return None
+
+ def finish_step(self, *_args, **_kwargs):
+ return None
+
+ def close(self):
+ return None
+
+
class PPOTrainerVL(ABC):
"""
Trainer for Proximal Policy Optimization (PPO) algorithm for Vision-Language Models.
@@ -162,7 +182,6 @@ def __init__(
self.reward_fn = reward_fn
self.reward_fn_label_map = reward_fn_label_map
self.reward_recipe = reward_recipe
- self.is_lora = getattr(self.args, "lora_rank", 0) > 0
self.actor = actor
self.critic = critic
@@ -217,6 +236,7 @@ def __init__(
self._tensorboard = None
self.eval_step_counter = 0 # Independent counter for eval X-axis
self.wandb_log_counter = 0 # Global counter for unique wandb system steps
+ self.profiler = getattr(self, "profiler", _NullStepProfiler())
if self.strategy.args.use_wandb and self.strategy.is_rank_0():
import wandb
@@ -254,6 +274,28 @@ def __init__(
log_dir = os.path.join(self.strategy.args.use_tensorboard, strategy.args.wandb_run_name)
self._tensorboard = SummaryWriter(log_dir=log_dir)
+ def _update_wandb_summary(self, logs: Dict[str, Any]) -> None:
+ if self._wandb is None or not self.strategy.is_rank_0() or not logs:
+ return
+
+ summary_logs = {}
+ for key, value in logs.items():
+ if isinstance(value, torch.Tensor):
+ if value.numel() != 1:
+ continue
+ value = value.item()
+ elif hasattr(value, "item") and not isinstance(value, (str, bytes)):
+ try:
+ value = value.item()
+ except (TypeError, ValueError):
+ pass
+
+ if isinstance(value, (int, float, bool, str)):
+ summary_logs[key] = value
+
+ if summary_logs and self._wandb.run is not None:
+ self._wandb.run.summary.update(summary_logs)
+
def fit(
self,
args,
@@ -359,6 +401,8 @@ def fit(
)
for batch in self.prompts_dataloader:
+ if hasattr(self, "profiler") and self.profiler is not None:
+ self.profiler.start_step(steps, episode)
# Compatible with both image-only (4 args) and video (5 args) dataloaders
if len(batch) == 5:
rand_prompts, rand_images, rand_videos, rand_references, rand_labels = batch
@@ -366,9 +410,14 @@ def fit(
rand_prompts, rand_images, rand_references, rand_labels = batch
rand_videos = None
- # TODO: Remove debug print
+ batch_preview = min(2, len(rand_prompts))
self.strategy.print(
- f"rand_prompts:\n {rand_prompts}\n , rand_images:{rand_images}\n , rand_references:{rand_references}\n, rand_labels:{rand_labels}\n " # noqa
+ "collect phase batch summary: "
+ f"batch_size={len(rand_prompts)}, "
+ f"preview_prompts={rand_prompts[:batch_preview]}, "
+ f"preview_images={rand_images[:batch_preview]}, "
+ f"preview_references={rand_references[:batch_preview]}, "
+ f"preview_labels={rand_labels[:batch_preview]}"
)
for i, experience in enumerate(
@@ -403,9 +452,7 @@ def fit(
rollout_status = {}
if self.replay_buffer.items:
all_rewards = []
- all_format_rewards = []
- all_accuracy_rewards = []
- all_general_model_rewards = []
+ reward_metric_values = defaultdict(list)
all_response_lengths = []
for item in self.replay_buffer.items:
@@ -421,19 +468,9 @@ def fit(
hasattr(item, 'info') and item.info is not None and 'reward_metrics' in item.info
and item.info['reward_metrics'] is not None
):
-
reward_metrics = item.info['reward_metrics']
-
- # Safely extract sub-metrics
- if 'format_reward' in reward_metrics:
- all_format_rewards.append(reward_metrics['format_reward'])
- if 'accuracy_reward' in reward_metrics:
- all_accuracy_rewards.append(reward_metrics['accuracy_reward'])
- general_model_reward = reward_metrics.get("general_model_reward")
- if general_model_reward is None:
- general_model_reward = reward_metrics.get("model_reward")
- if general_model_reward is not None:
- all_general_model_rewards.append(general_model_reward)
+ for key, value in reward_metrics.items():
+ reward_metric_values[key].append(value)
# Collect response lengths from rollout
if hasattr(item, 'info') and item.info is not None and 'response_length' in item.info:
@@ -451,42 +488,18 @@ def fit(
rollout_status["rollout_reward"] = rewards_tensor.mean().item()
rollout_status["rollout_reward_std"] = rewards_tensor.std().item()
- if all_format_rewards:
- # [TENSOR-FIX] Handle both tensor lists and scalar lists
- # Issue: all_format_rewards may contain tensors (from reward_metrics),
- # but torch.tensor() cannot convert a list of tensors directly.
- # Solution: Use torch.cat() for tensor lists, torch.tensor() for scalar lists
- if isinstance(all_format_rewards[0], torch.Tensor):
- # List of tensors: concatenate them
- format_tensor = torch.cat([t.to(device).float() for t in all_format_rewards])
- else:
- # List of scalars: convert to tensor
- format_tensor = torch.tensor(all_format_rewards, dtype=torch.float32, device=device)
-
- mean_format_reward = format_tensor.mean().item()
- rollout_status["rollout_format_reward"] = mean_format_reward
-
- if all_accuracy_rewards:
- # [TENSOR-FIX] Handle both tensor lists and scalar lists
- if isinstance(all_accuracy_rewards[0], torch.Tensor):
- accuracy_tensor = torch.cat([t.to(device).float() for t in all_accuracy_rewards])
- else:
- accuracy_tensor = torch.tensor(all_accuracy_rewards, dtype=torch.float32, device=device)
-
- mean_accuracy_reward = accuracy_tensor.mean().item()
- rollout_status["rollout_accuracy_reward"] = mean_accuracy_reward
-
- if all_general_model_rewards:
- if isinstance(all_general_model_rewards[0], torch.Tensor):
- general_model_tensor = torch.cat([t.to(device).float() for t in all_general_model_rewards])
+ for metric_name, values in reward_metric_values.items():
+ if not values:
+ continue
+ if isinstance(values[0], torch.Tensor):
+ metric_tensor = torch.cat([t.to(device).float() for t in values])
else:
- general_model_tensor = torch.tensor(
- all_general_model_rewards, dtype=torch.float32, device=device
- )
-
- mean_general_model_reward = general_model_tensor.mean().item()
- if abs(mean_general_model_reward) > 1e-6:
- rollout_status["rollout_general_model_reward"] = mean_general_model_reward
+ metric_tensor = torch.tensor(values, dtype=torch.float32, device=device)
+ if metric_tensor.numel() == 0:
+ continue
+ mean_metric = metric_tensor.mean().item()
+ if abs(mean_metric) > 1e-6:
+ rollout_status[f"rollout_{metric_name}"] = mean_metric
if all_response_lengths:
# [TENSOR-FIX] Handle both tensor lists and scalar lists
@@ -524,9 +537,16 @@ def fit(
self.save_logs_and_checkpoints(args, steps, pbar, logs_dict_combined, client_states, episode=episode)
+ if hasattr(self, "profiler") and self.profiler is not None:
+ profile_snapshot = self.profiler.finish_step()
+ if hasattr(self, "log_profile_metrics"):
+ self.log_profile_metrics(steps, episode, profile_snapshot)
+
pbar.update()
steps = steps + 1
+ if hasattr(self, "profiler") and self.profiler is not None:
+ self.profiler.close()
if self._wandb is not None and self.strategy.is_rank_0():
self._wandb.finish()
if self._tensorboard is not None and self.strategy.is_rank_0():
@@ -742,125 +762,120 @@ def training_step_actor(self,
)
return {} # Emergency fallback - should not normally execute
- # Actor loss
- # Build kwargs based on actor's modality - only include supported parameters
- candidate_params = {
- "pixel_values": pixel_values,
- "image_grid_thw": image_grid_thws,
- "pixel_values_videos": pixel_values_videos,
- "video_grid_thw": video_grid_thws,
- }
- # Audio-language actors expect audio_values; pipeline stores them in pixel_values slot
- if "audio_values" in self._actor_supported_params:
- candidate_params["audio_values"] = candidate_params.get("pixel_values")
-
- actor_kwargs = {key: value for key, value in candidate_params.items() if key in self._actor_supported_params}
-
- action_log_probs, output = self.actor(
- sequences,
- num_actions,
- attention_mask=attention_mask,
- return_output=True,
- packed_seq_lens=packed_seq_lens,
- **actor_kwargs
- )
+ with self.profiler.section("learn/actor/total"):
+ candidate_params = {
+ "pixel_values": pixel_values,
+ "image_grid_thw": image_grid_thws,
+ "pixel_values_videos": pixel_values_videos,
+ "video_grid_thw": video_grid_thws,
+ }
+
+ actor_kwargs = {
+ key: value
+ for key, value in candidate_params.items()
+ if key in self._actor_supported_params
+ }
+
+ with self.profiler.section("learn/actor/forward"):
+ action_log_probs, output = self.actor(
+ sequences,
+ num_actions,
+ attention_mask=attention_mask,
+ return_output=True,
+ packed_seq_lens=packed_seq_lens,
+ **actor_kwargs
+ )
# NOTE: Explicit masking in log-space is incorrect - removed
# if experience.action_mask is not None:
# # Setting masked positions to 0 to match old_action_log_probs is WRONG in log-space
# action_log_probs = action_log_probs * experience.action_mask
- # Loss function
- actor_loss = self.actor_loss_fn(
- action_log_probs,
- old_action_log_probs,
- advantages,
- action_mask=experience.action_mask,
- entropy_mask=entropy_mask,
- )
-
- if self.args.use_kl_loss:
- if self.initial_model is not None:
- # TODO(pu): Text-only action mask for KL calculation
-
- kl = compute_approx_kl(
+ with self.profiler.section("learn/actor/loss"):
+ actor_loss = self.actor_loss_fn(
action_log_probs,
- base_action_log_probs,
- experience.action_mask,
- kl_estimator=self.args.kl_estimator,
+ old_action_log_probs,
+ advantages,
+ action_mask=experience.action_mask,
+ entropy_mask=entropy_mask,
)
- # [Protection measure 2] Per-token KL Clamping
- # NOTE: Adding this causes svkng training to not converge
- # kl = torch.clamp(kl, min=0.0, max=20.0)
+ if self.args.use_kl_loss:
+ if self.initial_model is not None:
+ kl = compute_approx_kl(
+ action_log_probs,
+ base_action_log_probs,
+ experience.action_mask,
+ kl_estimator=self.args.kl_estimator,
+ )
+ else:
+ kl = torch.zeros_like(
+ action_log_probs,
+ dtype=action_log_probs.dtype,
+ device=action_log_probs.device,
+ )
- else:
- kl = torch.zeros_like(action_log_probs, dtype=action_log_probs.dtype, device=action_log_probs.device)
+ if not self.args.packing_samples:
+ kl_mean = masked_mean(kl, experience.action_mask, dim=-1)
+ else:
+ kl = unpacking_samples(kl, num_actions)
+ kl_mean = torch.tensor([each_kl.mean() for each_kl in kl], device=action_log_probs.device)
- if not self.args.packing_samples:
- kl_mean = masked_mean(kl, experience.action_mask, dim=-1)
- # Not supported for packed samples
+ kl_loss = kl_mean.mean()
+ experience.info["kl"] = kl_loss.item()
else:
- # Convert tensor into list of tensors for easier manipulation within dataset
- kl = unpacking_samples(kl, num_actions)
- kl_mean = torch.tensor([each_kl.mean() for each_kl in kl], device=action_log_probs.device)
-
- kl_loss = kl_mean.mean()
- experience.info["kl"] = kl_loss.item()
- else:
- kl_loss = 0
-
- # Mixtral auxiliary loss
- if self.aux_loss:
- aux_loss = output.aux_loss
- else:
- aux_loss = 0
-
- loss = actor_loss + aux_loss * self.args.aux_loss_coef + kl_loss * self.kl_ctl.value
-
- if torch.isnan(loss) or torch.isinf(loss):
- self.strategy.print("[CRITICAL ERROR] Actor loss is NaN or Inf at step. Skipping update.")
- self.strategy.print(f" Actor Loss: {actor_loss.item()}")
- self.strategy.print(f" KL Loss: {kl_loss.item() if isinstance(kl_loss, torch.Tensor) else kl_loss}")
-
- self.strategy.backward(loss, self.actor, self.actor_optim)
-
- # PTX loss for supervised fine-tuning
- if self.pretrain_dataloader is not None:
- data = next(self.pretrain_dataloader)
- inputs = data[1].squeeze(1).to(torch.cuda.current_device())
- attention_mask = data[2].squeeze(1).to(torch.cuda.current_device())
- label = torch.where(
- attention_mask.bool(),
- inputs,
- self.ptx_loss_fn.IGNORE_INDEX,
- )
- pixel_values = data[3].to(torch.cuda.current_device())
- image_grid_thws = data[4].to(torch.cuda.current_device())
-
- output = self.actor(
- inputs,
- attention_mask=attention_mask,
- pixel_values=pixel_values,
- image_grid_thw=image_grid_thws,
- return_output=True
- )
- ptx_log_probs = output["logits"]
+ kl_loss = 0
- # Loss function
- ptx_loss = self.ptx_loss_fn(ptx_log_probs, label)
- # Mixtral auxiliary loss
if self.aux_loss:
aux_loss = output.aux_loss
else:
aux_loss = 0
- loss = ptx_loss + aux_loss * self.args.aux_loss_coef
- self.strategy.backward(self.ptx_coef * loss, self.actor, self.actor_optim)
-
- self.strategy.optimizer_step(self.actor_optim, self.actor, self.actor_scheduler, name="actor")
- if self.ema_model:
- self.strategy.moving_average(self.actor, self.ema_model, self.ema_beta, "cuda")
+ loss = actor_loss + aux_loss * self.args.aux_loss_coef + kl_loss * self.kl_ctl.value
+
+ if torch.isnan(loss) or torch.isinf(loss):
+ self.strategy.print("[CRITICAL ERROR] Actor loss is NaN or Inf at step. Skipping update.")
+ self.strategy.print(f" Actor Loss: {actor_loss.item()}")
+ self.strategy.print(f" KL Loss: {kl_loss.item() if isinstance(kl_loss, torch.Tensor) else kl_loss}")
+
+ with self.profiler.section("learn/actor/backward"):
+ self.strategy.backward(loss, self.actor, self.actor_optim)
+
+ if self.pretrain_dataloader is not None:
+ with self.profiler.section("learn/actor/ptx"):
+ data = next(self.pretrain_dataloader)
+ inputs = data[1].squeeze(1).to(torch.cuda.current_device())
+ attention_mask = data[2].squeeze(1).to(torch.cuda.current_device())
+ label = torch.where(
+ attention_mask.bool(),
+ inputs,
+ self.ptx_loss_fn.IGNORE_INDEX,
+ )
+ pixel_values = data[3].to(torch.cuda.current_device())
+ image_grid_thws = data[4].to(torch.cuda.current_device())
+
+ output = self.actor(
+ inputs,
+ attention_mask=attention_mask,
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thws,
+ return_output=True
+ )
+ ptx_log_probs = output["logits"]
+ ptx_loss = self.ptx_loss_fn(ptx_log_probs, label)
+ if self.aux_loss:
+ aux_loss = output.aux_loss
+ else:
+ aux_loss = 0
+ loss = ptx_loss + aux_loss * self.args.aux_loss_coef
+ self.strategy.backward(self.ptx_coef * loss, self.actor, self.actor_optim)
+
+ with self.profiler.section("learn/actor/optimizer_step"):
+ self.strategy.optimizer_step(self.actor_optim, self.actor, self.actor_scheduler, name="actor")
+
+ if self.ema_model:
+ with self.profiler.section("learn/actor/ema"):
+ self.strategy.moving_average(self.actor, self.ema_model, self.ema_beta, "cuda")
# Status
status = {"policy_loss": actor_loss.item(), "actor_lr": self.actor_scheduler.get_last_lr()[0]}
@@ -988,33 +1003,35 @@ def ensure_device_and_contiguous(tensor, name="tensor"):
sequences = ensure_device_and_contiguous(sequences, "sequences")
attention_mask = ensure_device_and_contiguous(attention_mask, "attention_mask")
- # Critic loss
- values, output = self.critic(
- sequences,
- num_actions=num_actions,
- attention_mask=attention_mask,
- pixel_values=pixel_values,
- image_grid_thw=image_grid_thws,
- pixel_values_videos=pixel_values_videos,
- video_grid_thw=video_grid_thws,
- return_output=True,
- packed_seq_lens=packed_seq_lens,
- )
- # Loss function
- critic_loss = self.critic_loss_fn(
- values,
- old_values,
- returns,
- action_mask=experience.action_mask,
- )
- # Mixtral auxiliary loss
- if self.aux_loss:
- aux_loss = output.aux_loss
- else:
- aux_loss = 0
- loss = critic_loss + aux_loss * self.args.aux_loss_coef
- self.strategy.backward(loss, self.critic, self.critic_optim)
- self.strategy.optimizer_step(self.critic_optim, self.critic, self.critic_scheduler, name="critic")
+ with self.profiler.section("learn/critic/total"):
+ with self.profiler.section("learn/critic/forward"):
+ values, output = self.critic(
+ sequences,
+ num_actions=num_actions,
+ attention_mask=attention_mask,
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thws,
+ pixel_values_videos=pixel_values_videos,
+ video_grid_thw=video_grid_thws,
+ return_output=True,
+ packed_seq_lens=packed_seq_lens,
+ )
+ with self.profiler.section("learn/critic/loss"):
+ critic_loss = self.critic_loss_fn(
+ values,
+ old_values,
+ returns,
+ action_mask=experience.action_mask,
+ )
+ if self.aux_loss:
+ aux_loss = output.aux_loss
+ else:
+ aux_loss = 0
+ loss = critic_loss + aux_loss * self.args.aux_loss_coef
+ with self.profiler.section("learn/critic/backward"):
+ self.strategy.backward(loss, self.critic, self.critic_optim)
+ with self.profiler.section("learn/critic/optimizer_step"):
+ self.strategy.optimizer_step(self.critic_optim, self.critic, self.critic_scheduler, name="critic")
# Status
status = {
@@ -1090,10 +1107,10 @@ def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, c
for k, v in self.experience_maker.perf_stats.items():
all_wandb_logs[f"perf/experience_maker/{k}"] = v
- # Commit Train/Rollout logs with unique system step
if all_wandb_logs:
self.wandb_log_counter += 1
self._wandb.log(all_wandb_logs, step=self.wandb_log_counter, commit=True)
+ self._update_wandb_summary(all_wandb_logs)
# TensorBoard Logging
elif self._tensorboard is not None and self.strategy.is_rank_0():
@@ -1125,12 +1142,9 @@ def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, c
eval_logs["eval/train_step"] = global_step
eval_logs["eval/episode"] = episode
- # IMPORTANT:
- # Use wandb_log_counter to ensure eval has a unique system step
- # This prevents eval metrics from being overwritten by train metrics
- # The plots will still use eval/global_step as X-axis due to define_metric
self.wandb_log_counter += 1
self._wandb.log(eval_logs, step=self.wandb_log_counter, commit=True)
+ self._update_wandb_summary(eval_logs)
# TensorBoard Logging for Eval
elif self._tensorboard is not None:
@@ -1155,11 +1169,10 @@ def _save_checkpoint(self, args, tag, client_states):
:param client_states: Client state for checkpoint recovery.
:type client_states: dict
"""
- ckpt_path = args.ckpt_path
- if not self.disable_ds_ckpt and not self.is_lora:
+ if not self.disable_ds_ckpt:
self.strategy.save_ckpt(
self.actor.model,
- os.path.join(ckpt_path, "_actor"),
+ os.path.join(args.ckpt_path, "_actor"),
tag,
args.max_ckpt_num,
args.max_ckpt_mem,
@@ -1167,24 +1180,11 @@ def _save_checkpoint(self, args, tag, client_states):
)
if self.critic is not None:
self.strategy.save_ckpt(
- self.critic, os.path.join(ckpt_path, "_critic"), tag, args.max_ckpt_num, args.max_ckpt_mem
- )
-
- # For LoRA, we ALWAYS save the HF adapter as it is much smaller and more convenient for deployment.
- if self.save_hf_ckpt or self.is_lora:
- # Rotate HF checkpoints
- if self.strategy.is_rank_0():
- os.makedirs(ckpt_path, exist_ok=True)
- max_num = getattr(args, "max_ckpt_num", 3)
- rotate_ckpt_dirs(
- ckpt_path,
- max_num,
- suffix="_lora",
- strategy=self.strategy,
- label="HF ckpt",
+ self.critic, os.path.join(args.ckpt_path, "_critic"), tag, args.max_ckpt_num, args.max_ckpt_mem
)
- save_path = os.path.join(ckpt_path, f"{tag}_lora")
+ if self.save_hf_ckpt:
+ save_path = os.path.join(args.ckpt_path, f"{tag}_hf")
self.strategy.save_model(self.actor, self.tokenizer, save_path)
def evaluate(self, eval_dataloader, global_step):
@@ -1210,9 +1210,7 @@ def evaluate(self, eval_dataloader, global_step):
self.critic.eval()
all_rewards = []
- all_format_rewards = []
- all_accuracy_rewards = []
- all_general_model_rewards = []
+ reward_metric_values = defaultdict(list)
all_response_lengths = []
num_eval_batches = 0
@@ -1258,15 +1256,8 @@ def extract_values(val):
if 'reward_metrics' in info:
rm = info['reward_metrics']
- if 'format_reward' in rm:
- all_format_rewards.extend(extract_values(rm['format_reward']))
- if 'accuracy_reward' in rm:
- all_accuracy_rewards.extend(extract_values(rm['accuracy_reward']))
- general_model_reward = rm.get("general_model_reward")
- if general_model_reward is None:
- general_model_reward = rm.get("model_reward")
- if general_model_reward is not None:
- all_general_model_rewards.extend(extract_values(general_model_reward))
+ for key, value in rm.items():
+ reward_metric_values[key].extend(extract_values(value))
num_eval_batches += 1
if num_eval_batches >= len(eval_dataloader):
@@ -1287,9 +1278,8 @@ def compute_stats(name, values_list):
# metrics[f"{name}_std"] = t.std().item() # Optional
compute_stats("reward", all_rewards)
- compute_stats("format_reward", all_format_rewards)
- compute_stats("accuracy_reward", all_accuracy_rewards)
- compute_stats("general_model_reward", all_general_model_rewards)
+ for metric_name, values in reward_metric_values.items():
+ compute_stats(metric_name, values)
compute_stats("response_length", all_response_lengths)
metrics["num_samples"] = len(all_rewards)
diff --git a/lightrft/trainer/spmd_ppo_trainer.py b/lightrft/trainer/spmd_ppo_trainer.py
index 026056da..95c4ce9d 100644
--- a/lightrft/trainer/spmd_ppo_trainer.py
+++ b/lightrft/trainer/spmd_ppo_trainer.py
@@ -18,7 +18,9 @@
- Efficient distributed training across multiple devices and nodes
"""
+import os
import time
+from collections import defaultdict
import torch
from tqdm import tqdm
@@ -30,7 +32,7 @@
from lightrft.trainer.replay_buffer import make_experience_batch
from lightrft.trainer.replay_buffer_vl import make_experience_batch as make_experience_batch_vl
from lightrft.models.utils import create_high_entropy_mask
-from lightrft.utils import init_logger
+from lightrft.utils import StepProfileRecorder, init_logger
logger = init_logger(__name__)
@@ -115,6 +117,11 @@ def __init__(
# TODO: here we pass a list of concrete params, this may collapse in future versions.
# Create experience maker with appropriate parameters
processor = kwargs.pop("processor", None)
+ self.profiler = StepProfileRecorder(
+ enabled=bool(getattr(self.args, "enable_profile", False)),
+ output_dir=os.path.join(self.args.save_path, "profile"),
+ print_fn=self.strategy.print,
+ )
self.experience_maker = FastExperienceMaker(
self.actor,
@@ -131,6 +138,7 @@ def __init__(
self.reward_recipe,
packing_samples=self.packing_samples,
processor=processor,
+ profiler=self.profiler,
)
# Extract high_entropy_token_ratio for entropy-based token filtering
@@ -192,322 +200,272 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train
torch.cuda.synchronize()
train_begin = time.time()
- torch.cuda.empty_cache()
- self.strategy.maybe_load_optimizer(self.actor_optim)
- all_items = self.strategy.sp_data_processor.preprocess(self.replay_buffer.items)
-
- device = torch.cuda.current_device()
-
- status_list = []
- status_mean = {}
- for epoch in range(self.max_epochs):
- pbar = tqdm(
- range(0, len(all_items), self.micro_train_batch_size),
- desc=f"Train epoch [{epoch + 1}/{self.max_epochs}]",
- disable=not self.strategy.is_rank_0(),
- )
- for i in pbar:
- items = all_items[i:i + self.micro_train_batch_size]
- if self.VLM:
- experience = make_experience_batch_vl(items, packing_samples=self.packing_samples)
- else:
- experience = make_experience_batch(items, packing_samples=self.packing_samples)
- experience.to_device(device)
-
- # ======================================================================================
- # Validate data BEFORE calling training_step to prevent execution path divergence
- # If validation is done inside training_step, different ranks may follow different code paths
- # (some return early, others continue), causing deadlock in collective communication ops.
-
- # Step 1: Each rank validates its local data
- should_skip_local = False
- if self.VLM and hasattr(self, '_validate_qwen_vl_tensors'):
- # Call the same validation logic used in training_step_actor
- sequences = experience.sequences
- pixel_values = experience.pixel_values
-
- # Validate before any forward pass
- is_valid = self._validate_qwen_vl_tensors(
- sequences, pixel_values, context="pre_training_validation"
- )
- should_skip_local = not is_valid
-
- # Step 2: Synchronize skip decision across all ranks via all_reduce
- # This ensures all ranks agree on whether to skip, preventing execution divergence
- skip_flag = torch.tensor([1.0 if should_skip_local else 0.0], device=device)
- torch.distributed.all_reduce(skip_flag, op=torch.distributed.ReduceOp.MAX)
-
- # Step 3: Collectively skip if ANY rank detected invalid data
- if skip_flag.item() > 0:
- if self.strategy.is_rank_0():
- pbar.set_description(f"Train epoch [{epoch + 1}/{self.max_epochs}] (skipping invalid batch)")
- continue # All ranks skip together - no deadlock
- # ======================================================================================
-
- # Create entropy_mask if high_entropy_token_ratio > 0 and action_entropy is available
- entropy_mask = None
- if hasattr(experience, 'action_entropy') and experience.action_entropy is not None:
- if self.high_entropy_token_ratio > 0.0:
- entropy_mask = create_high_entropy_mask(
- experience.action_entropy, experience.action_mask, self.high_entropy_token_ratio
+ with self.profiler.section("learn/total"):
+ torch.cuda.empty_cache()
+ self.strategy.maybe_load_optimizer(self.actor_optim)
+ with self.profiler.section("learn/sp_preprocess"):
+ all_items = self.strategy.sp_data_processor.preprocess(self.replay_buffer.items)
+
+ device = torch.cuda.current_device()
+ status_list = []
+ status_mean = {}
+
+ for epoch in range(self.max_epochs):
+ pbar = tqdm(
+ range(0, len(all_items), self.micro_train_batch_size),
+ desc=f"Train epoch [{epoch + 1}/{self.max_epochs}]",
+ disable=not self.strategy.is_rank_0(),
+ )
+ for i in pbar:
+ items = all_items[i:i + self.micro_train_batch_size]
+ if self.VLM:
+ experience = make_experience_batch_vl(items, packing_samples=self.packing_samples)
+ else:
+ experience = make_experience_batch(items, packing_samples=self.packing_samples)
+ experience.to_device(device)
+
+ should_skip_local = False
+ if self.VLM and hasattr(self, '_validate_qwen_vl_tensors'):
+ sequences = experience.sequences
+ pixel_values = experience.pixel_values
+ is_valid = self._validate_qwen_vl_tensors(
+ sequences, pixel_values, context="pre_training_validation"
)
-
- # Call training_step which will handle both GSPO and standard modes
- status = self.training_step(experience, global_steps, entropy_mask=entropy_mask)
-
- # for DP
- # weighted mean for kl
- if "kl" in status:
- status["kl"] *= status["response_length"]
- status = self.strategy.all_reduce(status)
- status["kl"] /= status["response_length"]
-
- # Training epoch progress bar: show per-batch metrics for detailed monitoring
- short_status = {}
-
- if "policy_loss" in status:
- short_status = {
- "pg": status["policy_loss"], # policy gradient loss
- "rm": status["reward"], # per-batch reward (instantaneous)
- "ret": status["return"], # per-batch return (instantaneous)
- "glen": status["response_length"], # per-batch response length
- "tlen": status["total_length"], # per-batch total length
- "kl": status["kl"], # KL divergence
- "act_lr": status["actor_lr"], # actor learning rate
- }
-
- if "critic_loss" in status:
- short_status["cri"] = status["critic_loss"]
- short_status["vals"] = status["values"]
- short_status["cri_lr"] = status["critic_lr"]
-
- if "ptx_loss" in status:
- short_status["ptx"] = status["ptx_loss"]
-
- status_list.append(status)
- pbar.set_postfix(short_status)
-
- # Short status keys added for progress bar display:
- # "pg": policy_loss
- # "rm": reward
- # "ret": return
- # "glen": response_length
- # "tlen": total_length
- # "kl": KL divergence
- # "act_lr": actor_lr
- if status_list:
- status_mean = status_list[0]
- for m in status_list[1:]:
- for k, v in m.items():
- status_mean[k] += v
- for k in status_mean.keys():
- status_mean[k] /= len(status_list)
-
- # ========== Aggregate step-level reward metrics from replay buffer ==========
- # NOTE: These metrics are aggregated from ALL experiences in the current step's
- # replay buffer (e.g., 640 experiences if rollout_batch_size=128, n_samples=5).
- # They represent the TRUE statistics of the rollout phase, NOT the training phase
- # micro-batch averages which are less representative.
- #
- # Naming convention:
- # - "*_mean" suffix: mean across all experiences in this step
- # - "step_*" prefix: clarifies this is per-step aggregation, not per-episode
- if self.replay_buffer.items:
- all_rewards = []
- all_format_rewards = []
- all_accuracy_rewards = []
- all_general_model_rewards = []
- all_rule_rewards = []
- all_advantages = []
- all_returns = []
- all_response_lengths = []
- all_opd_reverse_kl = [] # For on-policy distillation metrics
-
- for item in self.replay_buffer.items:
- # Collect rewards
- if hasattr(item, 'info') and item.info is not None and 'reward' in item.info:
- all_rewards.append(item.info['reward'])
-
- # Collect detailed reward metrics from info dict
- if hasattr(item, 'info') and item.info is not None and 'reward_metrics' in item.info:
- reward_metrics = item.info['reward_metrics']
- if 'format_reward' in reward_metrics:
- all_format_rewards.append(reward_metrics['format_reward'])
- if 'accuracy_reward' in reward_metrics:
- all_accuracy_rewards.append(reward_metrics['accuracy_reward'])
- general_model_reward = reward_metrics.get("general_model_reward")
- if general_model_reward is None:
- general_model_reward = reward_metrics.get("model_reward")
- if general_model_reward is not None:
- all_general_model_rewards.append(general_model_reward)
- if 'rule_reward' in reward_metrics:
- all_rule_rewards.append(reward_metrics['rule_reward'])
-
- # Collect advantages and returns
- if hasattr(item, 'advantages') and item.advantages is not None:
- all_advantages.append(item.advantages)
- if hasattr(item, 'returns') and item.returns is not None:
- all_returns.append(item.returns)
- if hasattr(item, 'info') and item.info is not None and 'response_length' in item.info:
- all_response_lengths.append(item.info['response_length'])
-
- # Collect on-policy distillation reverse KL
- if hasattr(item, 'info') and item.info is not None and 'opd_reverse_kl' in item.info:
- all_opd_reverse_kl.append(item.info['opd_reverse_kl'])
-
- # Compute statistics
- # [TENSOR-FIX] Handle both tensor lists and scalar lists for all reward types
- if all_rewards:
- # Handle both tensor lists (from batched rewards) and scalar lists
- if isinstance(all_rewards[0], torch.Tensor):
- rewards_tensor = torch.cat([t.to(device).float() for t in all_rewards])
- else:
- rewards_tensor = torch.tensor(all_rewards, dtype=torch.float32, device=device)
- # Use "step_*" prefix to clarify this is per-step aggregation, not per-episode
- status_mean["step_reward_mean"] = rewards_tensor.mean().item()
- status_mean["step_reward_std"] = rewards_tensor.std().item()
- status_mean["step_reward_max"] = rewards_tensor.max().item()
- status_mean["step_reward_min"] = rewards_tensor.min().item()
-
- if all_format_rewards:
- # [TENSOR-FIX] Handle both tensor lists and scalar lists
- if isinstance(all_format_rewards[0], torch.Tensor):
- format_tensor = torch.cat([t.to(device).float() for t in all_format_rewards])
- else:
- format_tensor = torch.tensor(all_format_rewards, dtype=torch.float32, device=device)
- status_mean["format_reward_mean"] = format_tensor.mean().item()
- status_mean["format_reward_std"] = format_tensor.std().item()
-
- if all_accuracy_rewards:
- # [TENSOR-FIX] Handle both tensor lists and scalar lists
- if isinstance(all_accuracy_rewards[0], torch.Tensor):
- accuracy_tensor = torch.cat([t.to(device).float() for t in all_accuracy_rewards])
- else:
- accuracy_tensor = torch.tensor(all_accuracy_rewards, dtype=torch.float32, device=device)
- status_mean["accuracy_reward_mean"] = accuracy_tensor.mean().item()
- status_mean["accuracy_reward_std"] = accuracy_tensor.std().item()
-
- if all_general_model_rewards:
- # [TENSOR-FIX] Handle both tensor lists and scalar lists
- if isinstance(all_general_model_rewards[0], torch.Tensor):
- model_tensor = torch.cat([t.to(device).float() for t in all_general_model_rewards])
- else:
- model_tensor = torch.tensor(all_general_model_rewards, dtype=torch.float32, device=device)
- if model_tensor.abs().sum() > 0: # Only log if model rewards are non-zero
- status_mean["general_model_reward_mean"] = model_tensor.mean().item()
- self.strategy.print(f" general_model_reward_mean: {status_mean['general_model_reward_mean']}")
-
- if all_rule_rewards:
- # [TENSOR-FIX] Handle both tensor lists and scalar lists
- if isinstance(all_rule_rewards[0], torch.Tensor):
- rule_tensor = torch.cat([t.to(device).float() for t in all_rule_rewards])
- else:
- rule_tensor = torch.tensor(all_rule_rewards, dtype=torch.float32, device=device)
- status_mean["rule_reward_mean"] = rule_tensor.mean().item()
- self.strategy.print(f"rule_reward_mean: {status_mean['rule_reward_mean']}")
-
- # For advantages, returns, and lengths, they are already lists of tensors,
- # so torch.cat() is the correct function to use.
- if all_advantages:
- advantages_tensor = torch.cat(all_advantages)
- status_mean["advantages_mean"] = advantages_tensor.mean().item()
- status_mean["advantages_std"] = advantages_tensor.std().item()
- status_mean["advantages_max"] = advantages_tensor.max().item()
- status_mean["advantages_min"] = advantages_tensor.min().item()
-
- if all_returns:
- returns_tensor = torch.cat(all_returns)
- status_mean["returns_mean"] = returns_tensor.mean().item()
- status_mean["returns_std"] = returns_tensor.std().item()
-
- if all_response_lengths:
- # [TENSOR-FIX] Handle both tensor lists and scalar lists
- if isinstance(all_response_lengths[0], torch.Tensor):
- lengths_tensor = torch.cat([t.to(device).float() for t in all_response_lengths])
- else:
- lengths_tensor = torch.tensor(all_response_lengths, dtype=torch.float32, device=device)
- status_mean["response_length_mean"] = lengths_tensor.float().mean().item()
- status_mean["response_length_std"] = lengths_tensor.float().std().item()
-
- # On-Policy Distillation metrics
- if all_opd_reverse_kl:
- # Collect reverse KL from all experiences
- if isinstance(all_opd_reverse_kl[0], torch.Tensor):
- opd_kl_tensor = torch.cat([t.to(device).float() for t in all_opd_reverse_kl])
- else:
- opd_kl_tensor = torch.tensor(all_opd_reverse_kl, dtype=torch.float32, device=device)
- # Mask out zero values (padding)
- non_zero_mask = opd_kl_tensor != 0
- if non_zero_mask.any():
- masked_kl = opd_kl_tensor[non_zero_mask]
- status_mean["opd_reverse_kl_mean"] = masked_kl.mean().item()
- status_mean["opd_reverse_kl_std"] = masked_kl.std().item()
- status_mean["opd_reverse_kl_max"] = masked_kl.max().item()
- status_mean["opd_reverse_kl_min"] = masked_kl.min().item()
-
- # Print detailed reward breakdown (only on rank 0)
- if self.print_replay_buffer_stats and self.strategy.is_rank_0():
- self.strategy.print("\n" + "=" * 60)
- self.strategy.print("📊 Detailed Step Statistics")
- self.strategy.print("=" * 60)
+ should_skip_local = not is_valid
+
+ skip_flag = torch.tensor([1.0 if should_skip_local else 0.0], device=device)
+ torch.distributed.all_reduce(skip_flag, op=torch.distributed.ReduceOp.MAX)
+ if skip_flag.item() > 0:
+ if self.strategy.is_rank_0():
+ pbar.set_description(
+ f"Train epoch [{epoch + 1}/{self.max_epochs}] (skipping invalid batch)"
+ )
+ continue
+
+ entropy_mask = None
+ if hasattr(experience, 'action_entropy') and experience.action_entropy is not None:
+ if self.high_entropy_token_ratio > 0.0:
+ entropy_mask = create_high_entropy_mask(
+ experience.action_entropy,
+ experience.action_mask,
+ self.high_entropy_token_ratio,
+ )
+
+ with self.profiler.section("learn/micro_batch_total"):
+ status = self.training_step(experience, global_steps, entropy_mask=entropy_mask)
+
+ if "kl" in status:
+ status["kl"] *= status["response_length"]
+ status = self.strategy.all_reduce(status)
+ status["kl"] /= status["response_length"]
+
+ short_status = {}
+ if "policy_loss" in status:
+ short_status = {
+ "pg": status["policy_loss"],
+ "rm": status["reward"],
+ "ret": status["return"],
+ "glen": status["response_length"],
+ "tlen": status["total_length"],
+ "kl": status["kl"],
+ "act_lr": status["actor_lr"],
+ }
+ if "critic_loss" in status:
+ short_status["cri"] = status["critic_loss"]
+ short_status["vals"] = status["values"]
+ short_status["cri_lr"] = status["critic_lr"]
+ if "ptx_loss" in status:
+ short_status["ptx"] = status["ptx_loss"]
+
+ status_list.append(status)
+ pbar.set_postfix(short_status)
+
+ if status_list:
+ status_mean = status_list[0]
+ for metric_dict in status_list[1:]:
+ for key, value in metric_dict.items():
+ status_mean[key] += value
+ for key in status_mean.keys():
+ status_mean[key] /= len(status_list)
+
+ if self.replay_buffer.items:
+ all_rewards = []
+ reward_metric_values = defaultdict(list)
+ all_advantages = []
+ all_returns = []
+ all_response_lengths = []
+ all_total_lengths = []
+ all_opd_reverse_kl = []
+
+ for item in self.replay_buffer.items:
+ if hasattr(item, 'info') and item.info is not None and 'reward' in item.info:
+ all_rewards.append(item.info['reward'])
+ if hasattr(item, 'info') and item.info is not None and 'reward_metrics' in item.info:
+ reward_metrics = item.info['reward_metrics']
+ if reward_metrics is not None:
+ for key, value in reward_metrics.items():
+ reward_metric_values[key].append(value)
+ if hasattr(item, 'advantages') and item.advantages is not None:
+ all_advantages.append(item.advantages)
+ if hasattr(item, 'returns') and item.returns is not None:
+ all_returns.append(item.returns)
+ if hasattr(item, 'info') and item.info is not None and 'response_length' in item.info:
+ all_response_lengths.append(item.info['response_length'])
+ if hasattr(item, 'info') and item.info is not None and 'total_length' in item.info:
+ all_total_lengths.append(item.info['total_length'])
+ if hasattr(item, 'info') and item.info is not None and 'opd_reverse_kl' in item.info:
+ all_opd_reverse_kl.append(item.info['opd_reverse_kl'])
if all_rewards:
- self.strategy.print(
- f"🎁 Total Reward: {status_mean['step_reward_mean']:.4f} ± {status_mean['step_reward_std']:.4f} " # noqa
- f"(min={status_mean['step_reward_min']:.4f}, max={status_mean['step_reward_max']:.4f})"
- )
-
- if all_format_rewards:
- self.strategy.print(
- f"📝 Format Reward: {status_mean['format_reward_mean']:.4f} ± {status_mean['format_reward_std']:.4f}" # noqa
- )
-
- if all_accuracy_rewards:
- self.strategy.print(
- f"✅ Accuracy Reward: {status_mean['accuracy_reward_mean']:.4f} ± {status_mean['accuracy_reward_std']:.4f}" # noqa
- )
-
- if all_general_model_rewards and "general_model_reward_mean" in status_mean:
+ if isinstance(all_rewards[0], torch.Tensor):
+ rewards_tensor = torch.cat([t.to(device).float() for t in all_rewards])
+ else:
+ rewards_tensor = torch.tensor(all_rewards, dtype=torch.float32, device=device)
+ status_mean["step_reward_mean"] = rewards_tensor.mean().item()
+ status_mean["step_reward_std"] = rewards_tensor.std().item()
+ status_mean["step_reward_max"] = rewards_tensor.max().item()
+ status_mean["step_reward_min"] = rewards_tensor.min().item()
+ status_mean["step_reward_zero_ratio"] = (rewards_tensor == 0).float().mean().item()
+ status_mean["step_reward_one_ratio"] = (rewards_tensor == 1).float().mean().item()
+
+ for metric_name, values in reward_metric_values.items():
+ if not values:
+ continue
+ if isinstance(values[0], torch.Tensor):
+ metric_tensor = torch.cat([t.to(device).float() for t in values])
+ else:
+ metric_tensor = torch.tensor(values, dtype=torch.float32, device=device)
+ if metric_tensor.numel() == 0:
+ continue
+ if metric_name == "general_model_reward" and metric_tensor.abs().sum() == 0:
+ continue
+ status_mean[f"{metric_name}_mean"] = metric_tensor.mean().item()
+ status_mean[f"{metric_name}_std"] = metric_tensor.std().item()
+
+ if "general_model_reward_mean" in status_mean:
self.strategy.print(f"🧠 General RM Reward:{status_mean['general_model_reward_mean']:.4f}")
if all_advantages:
- self.strategy.print(
- f"📈 Advantages: {status_mean['advantages_mean']:.4f} ± {status_mean['advantages_std']:.4f} " # noqa
- f"(min={status_mean['advantages_min']:.4f}, max={status_mean['advantages_max']:.4f})"
- )
+ advantages_tensor = torch.cat(all_advantages)
+ status_mean["advantages_mean"] = advantages_tensor.mean().item()
+ status_mean["advantages_std"] = advantages_tensor.std().item()
+ status_mean["advantages_max"] = advantages_tensor.max().item()
+ status_mean["advantages_min"] = advantages_tensor.min().item()
if all_returns:
- self.strategy.print(
- f"💰 Returns: {status_mean['returns_mean']:.4f} ± {status_mean['returns_std']:.4f}"
- )
+ returns_tensor = torch.cat(all_returns)
+ status_mean["returns_mean"] = returns_tensor.mean().item()
+ status_mean["returns_std"] = returns_tensor.std().item()
if all_response_lengths:
- self.strategy.print(
- f"📏 Response Length: {status_mean['response_length_mean']:.1f} ± {status_mean['response_length_std']:.1f} tokens" # noqa
- )
+ if isinstance(all_response_lengths[0], torch.Tensor):
+ lengths_tensor = torch.cat([t.to(device).float() for t in all_response_lengths])
+ else:
+ lengths_tensor = torch.tensor(all_response_lengths, dtype=torch.float32, device=device)
+ status_mean["response_length_mean"] = lengths_tensor.float().mean().item()
+ status_mean["response_length_std"] = lengths_tensor.float().std().item()
+ status_mean["response_length_zero_ratio"] = (lengths_tensor <= 1).float().mean().item()
+ generate_max_len = getattr(self.args, "generate_max_len", None)
+ if generate_max_len:
+ status_mean["response_hit_max_ratio"] = (lengths_tensor
+ >= float(generate_max_len - 1)).float().mean().item()
+
+ if all_total_lengths:
+ if isinstance(all_total_lengths[0], torch.Tensor):
+ total_lengths_tensor = torch.cat([t.to(device).float() for t in all_total_lengths])
+ else:
+ total_lengths_tensor = torch.tensor(all_total_lengths, dtype=torch.float32, device=device)
+ status_mean["total_length_mean"] = total_lengths_tensor.float().mean().item()
+ status_mean["total_length_std"] = total_lengths_tensor.float().std().item()
+
+ if all_opd_reverse_kl:
+ if isinstance(all_opd_reverse_kl[0], torch.Tensor):
+ opd_kl_tensor = torch.cat([t.to(device).float() for t in all_opd_reverse_kl])
+ else:
+ opd_kl_tensor = torch.tensor(all_opd_reverse_kl, dtype=torch.float32, device=device)
+ non_zero_mask = opd_kl_tensor != 0
+ if non_zero_mask.any():
+ masked_kl = opd_kl_tensor[non_zero_mask]
+ status_mean["opd_reverse_kl_mean"] = masked_kl.mean().item()
+ status_mean["opd_reverse_kl_std"] = masked_kl.std().item()
+ status_mean["opd_reverse_kl_max"] = masked_kl.max().item()
+ status_mean["opd_reverse_kl_min"] = masked_kl.min().item()
+
+ if self.print_replay_buffer_stats and self.strategy.is_rank_0():
+ self.strategy.print("\n" + "=" * 60)
+ self.strategy.print("📊 Detailed Step Statistics")
+ self.strategy.print("=" * 60)
+
+ if all_rewards:
+ self.strategy.print(
+ f"🎁 Total Reward: {status_mean['step_reward_mean']:.4f} ± {status_mean['step_reward_std']:.4f} " # noqa
+ f"(min={status_mean['step_reward_min']:.4f}, max={status_mean['step_reward_max']:.4f})"
+ )
+ self.strategy.print(
+ f" Reward Ratios: zero={status_mean['step_reward_zero_ratio']:.4f}, "
+ f"one={status_mean['step_reward_one_ratio']:.4f}"
+ )
- if all_opd_reverse_kl and 'opd_reverse_kl_mean' in status_mean:
- self.strategy.print(
- f"🎓 OPD Reverse KL: {status_mean['opd_reverse_kl_mean']:.4f} ± {status_mean['opd_reverse_kl_std']:.4f} " # noqa
- f"(min={status_mean['opd_reverse_kl_min']:.4f}, max={status_mean['opd_reverse_kl_max']:.4f})"
- )
+ reward_metric_print_order = [
+ ("format_reward", "📝 Format Reward"),
+ ("accuracy_reward", "✅ Accuracy Reward"),
+ ("outcome_correct", "🎯 Outcome Correct"),
+ ("model_reward", "🤖 Model Reward"),
+ ("final_reward", "🏁 Final Reward"),
+ ("rule_reward", "⚖️ Rule Reward"),
+ ("max_relative_drop", "📉 Max Relative Drop"),
+ ("has_drop_moment", "🪂 Drop Moment"),
+ ("step_score_min", "🔬 Step Score Min"),
+ ("step_score_mean", "🔬 Step Score Mean"),
+ ("step_score_last", "🔬 Step Score Last"),
+ ("step_count", "🧮 Step Count"),
+ ]
+ for metric_name, title in reward_metric_print_order:
+ mean_key = f"{metric_name}_mean"
+ std_key = f"{metric_name}_std"
+ if mean_key in status_mean:
+ self.strategy.print(f"{title:<20} {status_mean[mean_key]:.4f} ± {status_mean[std_key]:.4f}")
+
+ if all_advantages:
+ self.strategy.print(
+ f"📈 Advantages: {status_mean['advantages_mean']:.4f} ± {status_mean['advantages_std']:.4f} " # noqa
+ f"(min={status_mean['advantages_min']:.4f}, max={status_mean['advantages_max']:.4f})"
+ )
- self.strategy.print("=" * 60 + "\n")
+ if all_returns:
+ self.strategy.print(
+ f"💰 Returns: {status_mean['returns_mean']:.4f} ± {status_mean['returns_std']:.4f}"
+ )
- torch.cuda.empty_cache()
+ if all_response_lengths:
+ self.strategy.print(
+ f"📏 Response Length: {status_mean['response_length_mean']:.1f} ± {status_mean['response_length_std']:.1f} tokens" # noqa
+ )
+ self.strategy.print(
+ f" Length Ratios: empty={status_mean['response_length_zero_ratio']:.4f}, "
+ f"hit_max={status_mean.get('response_hit_max_ratio', 0.0):.4f}"
+ )
- self.strategy.maybe_offload_optimizer(self.actor_optim)
- torch.cuda.synchronize()
- torch.cuda.empty_cache()
- self.strategy.print(f"PPO Train TIMECOST {time.time() - train_begin}")
- self.strategy.report_memory("after train, opt offloaded, before update weights")
- self.strategy.print(torch.cuda.memory_summary())
- self.strategy.update_engine_weights(self.actor)
-
- # Save trajectories at the end of ppo_train, BEFORE replay buffer is cleared
- # This ensures we have data to save when trajectory saving is enabled
- if global_steps % self.args.save_steps == 0:
- self.save_trajectories(global_steps)
+ if all_total_lengths:
+ self.strategy.print(
+ f"📦 Total Length: {status_mean['total_length_mean']:.1f} ± {status_mean['total_length_std']:.1f} tokens" # noqa
+ )
+
+ self.strategy.print("=" * 60 + "\n")
+
+ torch.cuda.empty_cache()
+ self.strategy.maybe_offload_optimizer(self.actor_optim)
+ torch.cuda.synchronize()
+ torch.cuda.empty_cache()
+ self.strategy.print(f"PPO Train TIMECOST {time.time() - train_begin}")
+ self.strategy.report_memory("after train, opt offloaded, before update weights")
+ self.strategy.print(torch.cuda.memory_summary())
+ with self.profiler.section("learn/update_engine_weights"):
+ self.strategy.update_engine_weights(self.actor)
+
+ if global_steps % self.args.save_steps == 0:
+ with self.profiler.section("learn/save_trajectories"):
+ self.save_trajectories(global_steps)
return status_mean
diff --git a/lightrft/utils/__init__.py b/lightrft/utils/__init__.py
index 01600077..8d090bb7 100644
--- a/lightrft/utils/__init__.py
+++ b/lightrft/utils/__init__.py
@@ -5,6 +5,7 @@
"""
from .logging_utils import init_logger
+from .profile_recorder import StepProfileRecorder
from .remote_rm_utils import remote_rm_fn
from .trajectory_saver import TrajectorySaver, create_trajectory_saver
from .distributed_sampler import DistributedSampler
@@ -21,6 +22,7 @@
__all__ = [
# logging and trajectory
"init_logger",
+ "StepProfileRecorder",
"remote_rm_fn",
'TrajectorySaver',
'create_trajectory_saver',
diff --git a/lightrft/utils/cli_args.py b/lightrft/utils/cli_args.py
index c282b422..84c3204a 100644
--- a/lightrft/utils/cli_args.py
+++ b/lightrft/utils/cli_args.py
@@ -47,6 +47,36 @@ def add_arguments(parser: argparse.ArgumentParser) -> None:
help="Fraction of GPU memory reserved for the inference engine's KV cache (range: 0.0 to 1.0). "
"Higher values improve throughput but may risk out-of-memory errors.",
)
+ parser.add_argument(
+ "--local_hf_generate_max_batch_size",
+ type=int,
+ default=0,
+ help="Maximum per-call sample batch size for the local HuggingFace inference engine. "
+ "A value <= 0 keeps the current single-shot behavior. "
+ "Set this to a small positive integer (for example 1 or 2) to chunk local HF rollout generation "
+ "when multimodal models would otherwise OOM.",
+ )
+ parser.add_argument(
+ "--local_hf_max_new_tokens",
+ type=int,
+ default=0,
+ help="Optional hard cap for max_new_tokens applied only to the local HuggingFace rollout path. "
+ "A value <= 0 keeps the launcher-provided generation length unchanged.",
+ )
+ parser.add_argument(
+ "--hf_separate_rollout_actor",
+ action="store_true",
+ default=False,
+ help="Use a dedicated local HF rollout actor instead of reusing the training actor directly. "
+ "The current implementation is intended for FSDP-based quick rollout experiments.",
+ )
+ parser.add_argument(
+ "--hf_separate_rollout_keep_on_gpu",
+ action="store_true",
+ default=False,
+ help="For local HF separate rollout: keep the rollout actor resident on GPU instead of sleeping/offloading "
+ "it between rollout and training phases. Use this when per-rank memory headroom is sufficient.",
+ )
parser.add_argument(
"--enable_engine_sleep",
action="store_true",
@@ -117,7 +147,6 @@ def add_arguments(parser: argparse.ArgumentParser) -> None:
help="Interval (in training steps) for plotting and saving the generated sequence length distribution. "
"Only effective if `--log_dir` is set.",
)
-
# for rewards models
parser.add_argument(
"--rm_use_engine",
diff --git a/lightrft/utils/profile_recorder.py b/lightrft/utils/profile_recorder.py
new file mode 100644
index 00000000..52bde311
--- /dev/null
+++ b/lightrft/utils/profile_recorder.py
@@ -0,0 +1,518 @@
+"""
+Step-level profiling utilities for long-running training jobs.
+
+This module provides a lightweight profiler that is suitable for production
+training loops:
+
+- Measures named sections in wall-clock seconds.
+- Persists per-step summaries to JSONL with flush + fsync.
+- Maintains a continuously refreshed latest snapshot so interrupted jobs still
+ leave readable profiling state on disk.
+- Optionally emits sampled ``torch.profiler`` traces on rank 0.
+"""
+
+from __future__ import annotations
+
+import json
+import os
+import threading
+import time
+from contextlib import contextmanager, nullcontext
+from pathlib import Path
+from typing import Dict, Iterator, List, Optional
+
+import torch
+
+
+class _DummyTorchProfiler:
+ def start(self) -> None:
+ pass
+
+ def stop(self) -> None:
+ pass
+
+ def step(self) -> None:
+ pass
+
+
+class StepProfileRecorder:
+ """
+ Persistent step profiler for distributed training.
+
+ The recorder keeps local timing state on every rank, aggregates it at
+ train-step boundaries, and writes the aggregated profile on rank 0.
+ """
+
+ TRACE_WAIT_STEPS = 1
+ TRACE_WARMUP_STEPS = 1
+ TRACE_ACTIVE_STEPS = 2
+ TRACE_REPEAT = 2
+ HEARTBEAT_INTERVAL_S = 1.0
+
+ def __init__(self, enabled: bool, output_dir: str, print_fn=None) -> None:
+ self.enabled = bool(enabled)
+ self.output_dir = Path(output_dir)
+ self.print_fn = print_fn
+ self.rank = torch.distributed.get_rank() if self._dist_enabled() else 0
+ self.world_size = torch.distributed.get_world_size() if self._dist_enabled() else 1
+ self.is_rank_0 = self.rank == 0
+
+ self.current_step: Optional[int] = None
+ self.current_episode: Optional[int] = None
+ self.current_step_start_wall: Optional[float] = None
+ self.current_step_started_at: Optional[float] = None
+ self.section_totals: Dict[str, float] = {}
+ self.phase_stack: List[str] = []
+ self.active_section_name: Optional[str] = None
+ self.active_section_start_wall: Optional[float] = None
+ self.last_section_name: Optional[str] = None
+ self.last_section_elapsed_s: Optional[float] = None
+ self._state_lock = threading.Lock()
+ self._write_lock = threading.Lock()
+ self._heartbeat_stop = threading.Event()
+ self._heartbeat_thread: Optional[threading.Thread] = None
+ self._snapshot_generation = 0
+
+ self.rank_step_profile_path = self.output_dir / f"step_profile.rank{self.rank}.jsonl"
+ self.rank_latest_profile_path = self.output_dir / f"step_profile.rank{self.rank}.latest.json"
+ self.rank_current_profile_path = self.output_dir / f"step_profile.rank{self.rank}.current.json"
+ self.step_profile_path = self.output_dir / "step_profile.global.jsonl"
+ self.latest_profile_path = self.output_dir / "step_profile.latest.json"
+ self.current_profile_path = self.output_dir / "step_profile.current.json"
+ self.trace_dir = self.output_dir / "traces"
+
+ self._torch_profiler = None
+ if self.enabled:
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+ if self.is_rank_0:
+ self.trace_dir.mkdir(parents=True, exist_ok=True)
+ self._torch_profiler = self._build_torch_profiler()
+ self._torch_profiler.start()
+ self._start_heartbeat()
+
+ @staticmethod
+ def _dist_enabled() -> bool:
+ return torch.distributed.is_available() and torch.distributed.is_initialized()
+
+ @staticmethod
+ def _cuda_sync_if_available() -> None:
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+ def _build_torch_profiler(self):
+ if not self.is_rank_0:
+ return _DummyTorchProfiler()
+
+ from torch.profiler import ProfilerActivity
+
+ return torch.profiler.profile(
+ schedule=torch.profiler.schedule(
+ wait=self.TRACE_WAIT_STEPS,
+ warmup=self.TRACE_WARMUP_STEPS,
+ active=self.TRACE_ACTIVE_STEPS,
+ repeat=self.TRACE_REPEAT,
+ ),
+ on_trace_ready=torch.profiler.tensorboard_trace_handler(str(self.trace_dir)),
+ activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+ record_shapes=False,
+ with_stack=False,
+ profile_memory=False,
+ )
+
+ def start_step(self, train_step: int, episode: int) -> None:
+ if not self.enabled:
+ return
+
+ self._cuda_sync_if_available()
+ with self._state_lock:
+ self._snapshot_generation += 1
+ self.current_step = int(train_step)
+ self.current_episode = int(episode)
+ self.current_step_start_wall = time.perf_counter()
+ self.current_step_started_at = time.time()
+ self.section_totals = {}
+ self.phase_stack = []
+ self.active_section_name = None
+ self.active_section_start_wall = None
+ self.last_section_name = None
+ self.last_section_elapsed_s = None
+ self._write_current_snapshot()
+
+ @contextmanager
+ def phase(self, phase_name: str) -> Iterator[None]:
+ if not self.enabled:
+ yield
+ return
+
+ cleaned = phase_name.strip("/")
+ if not cleaned:
+ yield
+ return
+
+ self.phase_stack.append(cleaned)
+ try:
+ yield
+ finally:
+ self.phase_stack.pop()
+
+ @contextmanager
+ def section(self, name: str) -> Iterator[None]:
+ if not self.enabled or self.current_step is None:
+ yield
+ return
+
+ full_name = self._qualify_name(name)
+ record_ctx = torch.profiler.record_function(full_name) if self.is_rank_0 else nullcontext()
+
+ self._cuda_sync_if_available()
+ start = time.perf_counter()
+ with self._state_lock:
+ self.active_section_name = full_name
+ self.active_section_start_wall = start
+ self._write_current_snapshot()
+ with record_ctx:
+ try:
+ yield
+ finally:
+ self._cuda_sync_if_available()
+ elapsed = time.perf_counter() - start
+ with self._state_lock:
+ self.section_totals[full_name] = self.section_totals.get(full_name, 0.0) + elapsed
+ self.active_section_name = None
+ self.active_section_start_wall = None
+ self.last_section_name = full_name
+ self.last_section_elapsed_s = elapsed
+ self._write_current_snapshot()
+
+ def _qualify_name(self, name: str) -> str:
+ cleaned = name.strip("/")
+ if not self.phase_stack:
+ return cleaned
+ return "/".join([*self.phase_stack, cleaned])
+
+ def finish_step(self, extra: Optional[Dict] = None) -> Optional[Dict]:
+ if not self.enabled or self.current_step is None or self.current_step_start_wall is None:
+ return None
+
+ self._cuda_sync_if_available()
+ with self._state_lock:
+ train_step = int(self.current_step)
+ episode = int(self.current_episode) if self.current_episode is not None else None
+ started_at = self.current_step_started_at
+ total_elapsed = time.perf_counter() - self.current_step_start_wall
+ local_sections = dict(self.section_totals)
+ local_sections["step/total"] = total_elapsed
+ self._snapshot_generation += 1
+ self.current_step = None
+ self.current_episode = None
+ self.current_step_start_wall = None
+ self.current_step_started_at = None
+ self.section_totals = {}
+ self.phase_stack = []
+ self.active_section_name = None
+ self.active_section_start_wall = None
+ self.last_section_name = None
+ self.last_section_elapsed_s = None
+
+ local_step_total_s = local_sections.get("step/total", total_elapsed)
+ local_ratios = {
+ name: (value / local_step_total_s if local_step_total_s > 0 else 0.0)
+ for name, value in local_sections.items()
+ }
+ local_record = {
+ "train_step": train_step,
+ "episode": episode,
+ "rank": self.rank,
+ "world_size": self.world_size,
+ "started_at": started_at,
+ "finished_at": time.time(),
+ "sections_local_s": local_sections,
+ "sections_local_ratio": local_ratios,
+ }
+ if extra:
+ local_record["extra"] = extra
+ self._append_jsonl(self.rank_step_profile_path, local_record)
+ self._write_atomic_json(self.rank_latest_profile_path, local_record)
+ self._write_atomic_json(self.rank_current_profile_path, local_record)
+
+ gathered_sections = self._gather_sections(local_sections)
+ self._torch_profiler.step()
+
+ result = None
+ if self.is_rank_0:
+ aggregated = self._aggregate_sections(gathered_sections)
+ step_total_s = aggregated["max_s"].get("step/total", total_elapsed)
+ ratios = {
+ name: (value / step_total_s if step_total_s > 0 else 0.0)
+ for name, value in aggregated["max_s"].items()
+ }
+ mean_ratios = {
+ name: (
+ value / aggregated["mean_s"].get("step/total", step_total_s)
+ if aggregated["mean_s"].get("step/total", step_total_s) > 0 else 0.0
+ )
+ for name, value in aggregated["mean_s"].items()
+ }
+ record = {
+ "train_step": train_step,
+ "episode": episode,
+ "world_size": self.world_size,
+ "available_ranks": list(range(self.world_size)),
+ "started_at": started_at,
+ "finished_at": time.time(),
+ "sections_max_s": aggregated["max_s"],
+ "sections_mean_s": aggregated["mean_s"],
+ "sections_max_ratio": ratios,
+ "sections_mean_ratio": mean_ratios,
+ }
+ if extra:
+ record["extra"] = extra
+ self._append_jsonl(self.step_profile_path, record)
+ self._write_atomic_json(self.latest_profile_path, record)
+ self._write_atomic_json(self.current_profile_path, record)
+ result = {
+ "record": record,
+ "wandb_logs": self._build_wandb_logs(train_step, aggregated["max_s"], ratios),
+ "summary": self._build_summary(aggregated["max_s"], ratios),
+ }
+ return result
+
+ def close(self) -> None:
+ if not self.enabled:
+ return
+ self._heartbeat_stop.set()
+ if self._heartbeat_thread is not None:
+ self._heartbeat_thread.join(timeout=max(self.HEARTBEAT_INTERVAL_S * 2, 2.0))
+ if self._torch_profiler is not None:
+ self._torch_profiler.stop()
+
+ def _gather_sections(self, local_sections: Dict[str, float]) -> List[Dict[str, float]]:
+ if not self._dist_enabled():
+ return [local_sections]
+
+ gathered_sections = [None for _ in range(self.world_size)]
+ torch.distributed.all_gather_object(gathered_sections, local_sections)
+ return [item or {} for item in gathered_sections]
+
+ @staticmethod
+ def _aggregate_sections(section_list: List[Dict[str, float]]) -> Dict[str, Dict[str, float]]:
+ section_names = sorted({name for section_dict in section_list for name in section_dict})
+ max_s: Dict[str, float] = {}
+ mean_s: Dict[str, float] = {}
+ world_size = max(len(section_list), 1)
+
+ for name in section_names:
+ values = [float(section_dict.get(name, 0.0)) for section_dict in section_list]
+ max_s[name] = max(values)
+ mean_s[name] = sum(values) / world_size
+
+ return {"max_s": max_s, "mean_s": mean_s}
+
+ @staticmethod
+ def _flatten_section_name(name: str) -> str:
+ return name.replace("/", "_").replace(" ", "_")
+
+ def _build_wandb_logs(self, train_step: int, max_s: Dict[str, float], ratios: Dict[str, float]) -> Dict[str, float]:
+ logs = {"profile/train_step": train_step}
+ for name, value in max_s.items():
+ flat_name = self._flatten_section_name(name)
+ logs[f"profile/{flat_name}_s"] = value
+ logs[f"profile/{flat_name}_ratio"] = ratios.get(name, 0.0)
+ return logs
+
+ @staticmethod
+ def _build_summary(max_s: Dict[str, float], ratios: Dict[str, float]) -> str:
+ interesting_sections = [
+ "collect/total",
+ "collect/generate",
+ "learn/total",
+ "learn/update_engine_weights",
+ "eval/total",
+ "checkpoint/total",
+ ]
+ parts = []
+ step_total = max_s.get("step/total", 0.0)
+ parts.append(f"step_total={step_total:.2f}s")
+ for name in interesting_sections:
+ if name not in max_s:
+ continue
+ parts.append(f"{name}={max_s[name]:.2f}s ({ratios.get(name, 0.0):.1%})")
+ return "profile: " + ", ".join(parts)
+
+ def _write_current_snapshot(self) -> None:
+ snapshot = self._build_current_snapshot()
+ if snapshot is None:
+ return
+ if not self._write_snapshot_if_current(self.rank_current_profile_path, snapshot):
+ return
+ if not self.is_rank_0:
+ return
+
+ global_snapshot = self._build_global_current_snapshot(snapshot)
+ if global_snapshot is None:
+ return
+ self._write_atomic_json(self.current_profile_path, global_snapshot)
+
+ def _build_current_snapshot(self) -> Optional[Dict]:
+ if not self.enabled:
+ return None
+
+ with self._state_lock:
+ if self.current_step is None or self.current_step_start_wall is None:
+ return None
+
+ current_elapsed_s = max(time.perf_counter() - self.current_step_start_wall, 0.0)
+ sections_local_s = dict(self.section_totals)
+ active_section_elapsed_s = None
+ if self.active_section_name is not None and self.active_section_start_wall is not None:
+ active_section_elapsed_s = max(time.perf_counter() - self.active_section_start_wall, 0.0)
+ sections_local_s[self.active_section_name
+ ] = (sections_local_s.get(self.active_section_name, 0.0) + active_section_elapsed_s)
+ current_ratios = {
+ name: (value / current_elapsed_s if current_elapsed_s > 0 else 0.0)
+ for name, value in sections_local_s.items()
+ }
+ snapshot = {
+ "train_step": self.current_step,
+ "episode": self.current_episode,
+ "rank": self.rank,
+ "world_size": self.world_size,
+ "started_at": self.current_step_started_at,
+ "partial": True,
+ "current_elapsed_s": current_elapsed_s,
+ "sections_local_s": sections_local_s,
+ "sections_local_ratio": current_ratios,
+ "_snapshot_generation": self._snapshot_generation,
+ }
+ if self.last_section_name is not None:
+ snapshot["last_section"] = self.last_section_name
+ if self.last_section_elapsed_s is not None:
+ snapshot["last_elapsed_s"] = self.last_section_elapsed_s
+ if self.active_section_name is not None:
+ snapshot["active_section"] = self.active_section_name
+ if active_section_elapsed_s is not None:
+ snapshot["active_section_elapsed_s"] = active_section_elapsed_s
+ return snapshot
+
+ def _build_global_current_snapshot(self, rank0_snapshot: Dict) -> Optional[Dict]:
+ current_step = rank0_snapshot.get("train_step")
+ current_episode = rank0_snapshot.get("episode")
+ if current_step is None:
+ return None
+
+ snapshots = []
+ available_ranks = []
+ active_sections = {}
+ for rank in range(self.world_size):
+ if rank == self.rank:
+ candidate = dict(rank0_snapshot)
+ else:
+ candidate = self._read_json(self.output_dir / f"step_profile.rank{rank}.current.json")
+ if not candidate:
+ continue
+ if candidate.get("train_step") != current_step or candidate.get("episode") != current_episode:
+ continue
+ snapshots.append(candidate)
+ available_ranks.append(rank)
+ active_section = candidate.get("active_section")
+ if active_section:
+ active_sections[f"rank{rank}"] = active_section
+
+ if not snapshots:
+ return None
+
+ aggregated = self._aggregate_sections([snapshot.get("sections_local_s", {}) for snapshot in snapshots])
+ elapsed_values = [float(snapshot.get("current_elapsed_s", 0.0)) for snapshot in snapshots]
+ max_elapsed = max(elapsed_values) if elapsed_values else 0.0
+ mean_elapsed = sum(elapsed_values) / len(elapsed_values) if elapsed_values else 0.0
+ max_ratios = {
+ name: (value / max_elapsed if max_elapsed > 0 else 0.0)
+ for name, value in aggregated["max_s"].items()
+ }
+ mean_ratios = {
+ name: (value / mean_elapsed if mean_elapsed > 0 else 0.0)
+ for name, value in aggregated["mean_s"].items()
+ }
+ started_at_candidates = [
+ snapshot.get("started_at") for snapshot in snapshots if snapshot.get("started_at") is not None
+ ]
+
+ global_snapshot = {
+ "train_step": current_step,
+ "episode": current_episode,
+ "world_size": self.world_size,
+ "available_ranks": available_ranks,
+ "num_rank_snapshots": len(snapshots),
+ "started_at": min(started_at_candidates) if started_at_candidates else None,
+ "partial": True,
+ "current_elapsed_max_s": max_elapsed,
+ "current_elapsed_mean_s": mean_elapsed,
+ "sections_max_s": aggregated["max_s"],
+ "sections_mean_s": aggregated["mean_s"],
+ "sections_max_ratio": max_ratios,
+ "sections_mean_ratio": mean_ratios,
+ }
+ if active_sections:
+ global_snapshot["active_sections"] = active_sections
+ return global_snapshot
+
+ def _start_heartbeat(self) -> None:
+ self._heartbeat_thread = threading.Thread(
+ target=self._heartbeat_loop,
+ name="step-profile-heartbeat",
+ daemon=True,
+ )
+ self._heartbeat_thread.start()
+
+ def _heartbeat_loop(self) -> None:
+ while not self._heartbeat_stop.wait(self.HEARTBEAT_INTERVAL_S):
+ self._write_current_snapshot()
+
+ def _write_snapshot_if_current(self, path: Path, payload: Dict) -> bool:
+ snapshot_generation = payload.get("_snapshot_generation")
+ if snapshot_generation is None:
+ self._write_atomic_json(path, payload)
+ return True
+
+ sanitized_payload = dict(payload)
+ sanitized_payload.pop("_snapshot_generation", None)
+ with self._write_lock:
+ with self._state_lock:
+ if snapshot_generation != self._snapshot_generation:
+ return False
+ self._write_atomic_json_unlocked(path, sanitized_payload)
+ return True
+
+ def _append_jsonl(self, path: Path, payload: Dict) -> None:
+ with self._write_lock:
+ self._append_jsonl_unlocked(path, payload)
+
+ @staticmethod
+ def _append_jsonl_unlocked(path: Path, payload: Dict) -> None:
+ path.parent.mkdir(parents=True, exist_ok=True)
+ with path.open("a", encoding="utf-8") as f:
+ f.write(json.dumps(payload, ensure_ascii=False, sort_keys=True) + "\n")
+ f.flush()
+ os.fsync(f.fileno())
+
+ def _write_atomic_json(self, path: Path, payload: Dict) -> None:
+ with self._write_lock:
+ self._write_atomic_json_unlocked(path, payload)
+
+ @staticmethod
+ def _write_atomic_json_unlocked(path: Path, payload: Dict) -> None:
+ path.parent.mkdir(parents=True, exist_ok=True)
+ tmp_path = path.with_suffix(path.suffix + ".tmp")
+ with tmp_path.open("w", encoding="utf-8") as f:
+ json.dump(payload, f, ensure_ascii=False, indent=2, sort_keys=True)
+ f.flush()
+ os.fsync(f.fileno())
+ os.replace(tmp_path, path)
+
+ @staticmethod
+ def _read_json(path: Path) -> Optional[Dict]:
+ try:
+ with path.open("r", encoding="utf-8") as f:
+ return json.load(f)
+ except (FileNotFoundError, json.JSONDecodeError, OSError):
+ return None
diff --git a/requirements.txt b/requirements.txt
index 56a06651..565ace1b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -25,3 +25,11 @@ librosa
qwen-vl-utils
sympy
matplotlib
+# URSA-MATH dependencies
+attrdict
+numpy
+pandas
+Pillow
+regex
+timm
+torchvision