diff --git a/.gitignore b/.gitignore index d3a81dae..54733cc6 100644 --- a/.gitignore +++ b/.gitignore @@ -1219,4 +1219,8 @@ wandb* examples/demo_grpo/results* build/* examples/math_benchmarks/eval_results/ -.llmconfig.yaml \ No newline at end of file +.llmconfig.yaml + +# Local agent tool state +.claude/ +.codex diff --git a/examples/math_prm/README.md b/examples/math_prm/README.md new file mode 100644 index 00000000..a5c385a4 --- /dev/null +++ b/examples/math_prm/README.md @@ -0,0 +1,272 @@ +# Math PRM: GRPO Training with a Process Reward Model + +This example trains [URSA-8B](https://huggingface.co/URSA-MATH/URSA-8B) — a multimodal math VLM — with [URSA-8B-RM](https://huggingface.co/URSA-MATH/URSA-RM-8B) as a Process Reward Model (PRM), using the GRPO algorithm with a **PS-GRPO** reward signal as proposed in the [URSA paper (NeurIPS 2025)](https://arxiv.org/abs/2501.04686). + +Unlike the rule-based examples under `examples/gsm8k_geo3k/`, the reward here comes from a neural reward model that scores **each reasoning step**, and the final per-trajectory reward depends on *how the step scores evolve* across the response, not just on whether the final answer is right. + +The example ships **two algorithm paths** side by side: + +1. **PS-GRPO** (`run_grpo_math_prm_ursa_8b.sh`) — the paper's recommended reward `r ∈ {0, 0.5, 1}`, used as a single per-trajectory scalar by standard GRPO. This is the production recipe. +2. **Strict paper Eq.9 variant 2** (`run_grpo_math_prm_ursa_8b_variant2.sh`) — the per-step PRM advantage `A_t^i = r_{s,t}^i · GroupNorm_G(r̄_s^i) + GroupNorm_G(r_o^i)` (paper Appendix B.1). The paper itself rejects this in favour of PS-GRPO; it ships here as an ablation comparator. The advantage estimator lives entirely in [`ursa_variant2.py`](ursa_variant2.py) (zero edits to `lightrft/`). 
+ +## Overview + +| Item | Math PRM | +|------|----------| +| Task | Multimodal math reasoning (text + image questions) | +| Modality | Multi-modal (text + image) | +| Actor | URSA-8B (hybrid SAM-B + SigLIP-L vision tower + Qwen2.5-Math-Instruct) | +| Reward Model | URSA-8B-RM (process reward model, step-level scoring) | +| Reward formula (PS-GRPO) | `r ∈ {0, 0.5, 1}` (correctness × step-stability) | +| Algorithm | GRPO (group_norm advantage estimator) or `ursa_variant2` for paper Eq.9 | +| Rollout engine | Local Hugging Face (vLLM/SGLang URSA support is future work) | + +The PS-GRPO reward is computed inside `MathPRMReward` ([reward_models.py](reward_models.py)) and follows the URSA paper: + +```text +r = 0 if outcome_correct == 0 +r = 1 if outcome_correct == 1 and no step-score drop +r = 0.5 ( = 1 - DROP_GAMMA) if outcome_correct == 1 but a step-score drop occurred +``` + +A **step-score drop** is detected when any consecutive pair of step scores has a relative drop ≥ `_DROP_THRESHOLD = 0.3`. + +--- + +## 1. Dataset Preprocessing + +The training data is `MMathCoT-1M` (Stage 3 split), which needs to be converted into the LightRFT manifest schema. Both `--input-path` and `--image-root` are **required** (no defaults — paths are environment-specific): + +```bash +python examples/math_prm/tools/prepare_ursa_stage3_manifest.py \ + --input-path /your/data/URSA-MATH/MMathCoT-1M/train.jsonl \ + --image-root /your/data/URSA-MATH/images \ + --output-path /your/output/math_psgrpo.jsonl +``` + +Each row in the converted manifest looks like: + +```json +{ + "prompt": "Math question text", + "images": ["/abs/path/to/image.png"], + "reference": "Ground-truth answer", + "label": "math_psgrpo" +} +``` + +The `label` field is what selects the reward path. 
Available labels: + +| Label | Reward signal | +|---|---| +| `math_psgrpo` | PS-GRPO: `{0, 0.5, 1}` (default for this example) | +| `math_prm` | Pure PRM aggregated step score (continuous in `[0, 1]`) | +| `math_prm_combined` | PRM aggregated score + 0.5 × rule-based correctness | +| `math_rule` | Rule-only baseline `{0, 1}` based on answer match | +| `math_per_step_prm` | Per-step PRM scores for `--advantage_estimator ursa_variant2` (paper Eq.9, see §6) | + +For a smoke conversion (32 samples), pass `--max-samples 32`. + +--- + +## 2. Model Checkpoints + +You need both the URSA-8B actor and the URSA-8B-RM reward model: + +```bash +# Hugging Face IDs +URSA-MATH/URSA-8B # actor +URSA-MATH/URSA-RM-8B # reward model +``` + +Download to a local directory and set the paths in `run_grpo_math_prm_ursa_8b.sh`. + +--- + +## 3. Configure and Run Training (PS-GRPO recipe) + +Edit `Part 1: User Configuration` at the top of [run_grpo_math_prm_ursa_8b.sh](run_grpo_math_prm_ursa_8b.sh): + +```bash +PATH_TO_YOUR_BASE_MODEL="/path/to/URSA-8B" +PATH_TO_URSA_RM="/path/to/URSA-RM-8B" +PATH_TO_YOUR_MATH_DATASET="/path/to/math_psgrpo.jsonl" +EXPERIMENT_NAME="lightrft-ursa8b-math-prm" +export WANDB_API_KEY="YOUR_WANDB_API_KEY" # leave empty to disable W&B +``` + +Then run: + +```bash +bash examples/math_prm/run_grpo_math_prm_ursa_8b.sh +``` + +The default machine target is `1 node × 8 A100/H100 GPUs`. For a different topology, override the standard env vars: + +```bash +NNODES=2 GPUS_PER_NODE=8 NODE_RANK=0 \ +MASTER_ADDR=10.0.0.1 MASTER_PORT=20092 \ +bash examples/math_prm/run_grpo_math_prm_ursa_8b.sh +``` + +--- + +## 4. 
Key Hyperparameters + +The launcher uses the URSA-MATH paper's Stage 3 defaults: + +| Param | Value | Notes | +|---|---|---| +| `N_SAMPLES` | 8 | Responses sampled per prompt for GRPO | +| `EPISODE` | 10 | Total training episodes | +| `RBS` / `TBS` | 128 / 128 | Rollout / training batch size | +| `KL` | 0.001 | Initial KL coefficient | +| `KL_TARGET` | (off) | If set, switches to AdaptiveKLController | +| `LR` | 1e-6 | Actor learning rate | +| `PROMPT_MAX_LEN` | 1024 | | +| `GENERATE_MAX_LEN` | 3072 | | +| `MAX_SAMPLES` | 15360 | Cap on training subset (paper proxy) | +| `EVAL_HOLDOUT_SIZE` | 500 | A deterministic held-out subset is reserved from `prompt_data` for in-domain eval | + +To enable the adaptive KL controller (recommended if you observe the KL drifting), set `KL_TARGET` to a small positive value, e.g. `KL_TARGET=0.5`. + +--- + +## 5. What's Logged + +W&B panels are split into three namespaces: + +- `rollout/*` — per-step rollout statistics: `reward`, `outcome_correct`, `model_reward`, `has_drop_moment`, `response_length`, `step_score_min/mean/last`, `step_count`, `final_reward`, `max_relative_drop`, `answer_tag_present`, `answer_extraction_failed`, `used_answer_fallback`, `used_mathruler`, `reference_supported`, plus variant-2 diagnostics `alignment_failed` / `n_aligned_steps`. +- `train/*` — per-step training statistics: `policy_loss`, `kl`, `actor_lr`, `advantages`, `return`, plus variant-2 diagnostics `ursa_v2_adv_pos_frac` / `_neg_frac` / `_zero_frac` / `_abs_mean` / `_oc_normed_std` / `_msp_normed_std` / `_traj_step_count_mean`. +- `eval/*` — evaluation pass on the held-out split: `reward`, `outcome_correct`, `response_length`, `answer_extraction_failed`, `has_drop_moment`, `model_reward`, `step_score_min/mean/last`, `step_count`, `final_reward`, `max_relative_drop`, `answer_tag_present`, `used_answer_fallback`, `used_mathruler`, `reference_supported`. 
+ +The full per-sample reward metric set emitted by `MathPRMReward` is documented at the top of `forward()` in [reward_models.py](reward_models.py). + +--- + +## 6. Strict Paper Eq.9 — variant 2 path + +`run_grpo_math_prm_ursa_8b_variant2.sh` runs the URSA paper's Appendix B.1 Eq.9 "variant 2" advantage formula side by side with PS-GRPO so the two can be ablated. The implementation lives in [`ursa_variant2.py`](ursa_variant2.py) as a new `--advantage_estimator ursa_variant2` registered via an idempotent monkey-patch from `examples/math_prm/` (no edits to `lightrft/`). + +### Formula + +```text +A_t^i = r_{s,t}^i · GroupNorm_G(r̄_s^i) ← process-reward term + + GroupNorm_G(r_o^i) ← outcome-reward term +``` + +where `t` indexes a **step** (not a token), `r_{s,t}^i` is the sigmoid PRM score for step `t` in trajectory `i`, `r̄_s^i = mean_t r_{s,t}^i`, `r_o^i ∈ {0,1}` is the outcome reward, and `G` is the GRPO group size (`n_samples_per_prompt`). The per-step `A_t^i` is broadcast to every token within step `t`'s span. **There is no cumulative return**, and the outcome term is preserved (not bypassed). + +### Workflow + +The variant-2 path requires rows labeled `math_per_step_prm` instead of `math_psgrpo`. Easiest way is to sed-relabel the PS-GRPO manifest: + +```bash +sed 's/"label":[ ]*"math_psgrpo"/"label": "math_per_step_prm"/g' \ + /path/to/math_psgrpo.jsonl \ + > /path/to/math_per_step_prm.jsonl +``` + +The variant-2 launcher will auto-swap to the relabeled sibling if it finds the legacy psgrpo path in `PATH_TO_YOUR_MATH_DATASET`, and assert that the first row's label is `math_per_step_prm` before training. Set `PATH_TO_YOUR_MATH_DATASET_VARIANT2` to a custom path to override this. 
+ +```bash +PATH_TO_YOUR_MATH_DATASET_VARIANT2=/path/to/math_per_step_prm.jsonl \ +bash examples/math_prm/run_grpo_math_prm_ursa_8b_variant2.sh +``` + +`--per_step_reward_mode` (`raw` / `group_norm`) only affects the **legacy Math-Shepherd-style per-token reward path** (different `_apply_step_reward_group_norm` aggregation); `--advantage_estimator ursa_variant2` does its own group normalization inside the calculator and is unaffected by this flag. + +### Unit tests + +`test_ursa_variant2.py` ships 9 acceptance-criterion tests (numerical equality with hand-computed Eq.9, GroupNorm correctness for K=2/K=4 groups, span broadcast, outcome-term non-bypass). Run them: + +```bash +python3 -m unittest examples.math_prm.test_ursa_variant2 -v +``` + +--- + +## 7. Results — 9-day production run (PS-GRPO) + +A 9-day full production run on 8× H100 with the PS-GRPO recipe is summarized below. The variant-2 path was validated by a parallel 9-day run (W&B `kdwjt4eo`); see [PR #53 final-report comment](https://github.com/opendilab/LightRFT/pull/53#issuecomment-4608400929) for the side-by-side comparison. 
+ +| Metric | Baseline (Step 20) | Peak | Final | Δ vs baseline | +|---|---|---|---|---| +| `eval/outcome_correct` | 0.5952 | **0.6508** (Step 231) | 0.6290 (Step 1008) | **+3.4 pp** | +| `eval/answer_extraction_failed` | 0.028 | 0.018 (~Step 160) | 0.034 | +0.6 pp ↑ | +| `eval/has_drop_moment` | 0.0 | — | 0.0 | (PRM never triggered) | +| `eval/response_length` | 400 | 337 (~Step 240) | 377 | -23 ↓ | +| `rollout/alignment_failed` | 0 | — | 0 | 100% step-boundary alignment | +| W&B run | [`kdwjt4eo`](https://wandb.ai/hansbug/LightRFT-URSA8B-Stage3/runs/kdwjt4eo) | + +#### Eval trajectory + +`eval/outcome_correct` peaks at Step 231 (+5.6pp), shows a transient dip near Step 300 (reward-hacking signature) but **self-heals**, and stabilizes in the 0.60–0.65 range for the remaining 7 days: + +![eval trajectory](assets/exp_20260603/eval_outcome.png) + +#### KL + rollout overview + +`train/kl` exits warmup quickly (1e-4 → 1.0 by Step ~200) and then oscillates in the 1–100 range; occasional single-batch spikes above 100 always self-correct. `rollout/outcome_correct` and `rollout/model_reward` track together (no reward-hacking decoupling): + +![KL + rollout](assets/exp_20260603/kl_and_rollout.png) + +#### Eval-time generation quality + +`eval/answer_extraction_failed` briefly spikes to 18% during the Step 300 dip (the `†Answer:` marker format drift the URSA paper warns about) but recovers back to 2–5%. `eval/response_length` and `eval/step_count` stay stable — no length collapse: + +![eval quality](assets/exp_20260603/eval_quality.png) + +#### variant 2 path health (W&B run `kdwjt4eo`) + +`ursa_v2_adv_pos_frac` and `_neg_frac` stay roughly balanced (30–40% / 25–35%) — GroupNorm produces signed advantages as expected. `_msp_normed_std` stays close to 1.0. `rollout/alignment_failed` is 0 the entire run: + +![variant 2 health](assets/exp_20260603/variant2_health.png) + +--- + +## 8. 
Files Under This Directory + +```text +examples/math_prm/ +├── README.md - This guide (en) +├── README_zh.md - This guide (zh) +├── train_colocate.py - Main training entry (called by torchrun) +├── run_grpo_math_prm_ursa_8b.sh - PS-GRPO launcher (recommended) +├── run_grpo_math_prm_ursa_8b_variant2.sh - Strict paper Eq.9 launcher (ablation) +├── reward_models.py - MathPRMReward implementation (PS-GRPO) +├── reward_models_utils.py - Reward recipe / mixing logic per label +├── ursa_actor.py - URSA-specific actor wrapper +├── ursa_variant2.py - UrsaVariant2Calculator (paper Eq.9, examples-only) +├── math_prm_trainer.py - MathPRMSPMDPPOTrainerVL (wandb metric mapping) +├── math_prm_output.py - "†Answer:" marker / structured-stop helpers +├── rollout_eos_patch.py - StoppingCriteria injection for reliable EOS under FSDP +├── test_ursa_variant2.py - 9 unit tests for variant 2 (AC1-AC5) +├── ursa_model/ - Vendored URSA model code (config / processor / model) +├── tools/ +│ ├── prepare_ursa_stage3_manifest.py - Dataset conversion tool +│ └── prepare_ursa_engine_checkpoint.py - Engine-mode checkpoint wrapper +└── assets/ + └── exp_20260603/ - W&B screenshots from the 9-day production run +``` + +--- + +## 9. Citation + +If you use this example, please cite the URSA paper: + +```bibtex +@article{luo2025ursa, + title={URSA: Understanding and Verifying Chain-of-Thought Reasoning in Multimodal Mathematics}, + author={Luo, Ruilin and Zheng, Zhuofan and Wang, Yifan and Yu, Yiyao and Ni, Xinzhe and Lin, Zicheng and Zeng, Jin and Yang, Yujiu}, + journal={NeurIPS}, + year={2025} +} +``` + +--- + +## License + +This example is released under the same license as the parent LightRFT project (see top-level `LICENSE`). 
diff --git a/examples/math_prm/README_zh.md b/examples/math_prm/README_zh.md new file mode 100644 index 00000000..d6ea7f2b --- /dev/null +++ b/examples/math_prm/README_zh.md @@ -0,0 +1,272 @@ +# Math PRM:基于过程奖励模型 (PRM) 的 GRPO 训练 + +本示例使用 [URSA-8B](https://huggingface.co/URSA-MATH/URSA-8B)(多模态数学 VLM)作为 actor,配合 [URSA-8B-RM](https://huggingface.co/URSA-MATH/URSA-RM-8B) 作为过程奖励模型 (PRM),按 [URSA 论文(NeurIPS 2025)](https://arxiv.org/abs/2501.04686)所提的 **PS-GRPO** 奖励路径用 GRPO 算法训练。 + +不同于 `examples/gsm8k_geo3k/` 那类规则型 reward 示例,这里的 reward 来自一个对**每一步推理**打分的神经网络奖励模型,最终的 trajectory-level reward 取决于 step scores 沿 response 的**演化形态**,而不仅仅是最终答案是否正确。 + +本 example 同时附带**两条算法路径**用于对比: + +1. **PS-GRPO**(`run_grpo_math_prm_ursa_8b.sh`)—— 论文最终采纳的 `r ∈ {0, 0.5, 1}` 单标量奖励,由标准 GRPO 处理。**生产推荐配方**。 +2. **Paper Eq.9 严格 variant 2**(`run_grpo_math_prm_ursa_8b_variant2.sh`)—— 论文附录 B.1 的逐 step PRM advantage:`A_t^i = r_{s,t}^i · GroupNorm_G(r̄_s^i) + GroupNorm_G(r_o^i)`。论文自身否决了它,本 example 保留只为做 ablation 对照。完整实现位于 [`ursa_variant2.py`](ursa_variant2.py)(不修改 `lightrft/`)。 + +## 总览 + +| 项 | Math PRM | +|------|----------| +| 任务 | 多模态数学推理(文本+图像题) | +| 模态 | 多模态(文本 + 图像) | +| Actor | URSA-8B(SAM-B + SigLIP-L 视觉塔 + Qwen2.5-Math-Instruct) | +| Reward Model | URSA-8B-RM(过程奖励模型,step-level scoring) | +| Reward 公式(PS-GRPO) | `r ∈ {0, 0.5, 1}`(正确性 × step 稳定性) | +| 算法 | GRPO(group_norm 优势估计器)或 paper Eq.9 的 `ursa_variant2` | +| Rollout 引擎 | 本地 Hugging Face(vLLM/SGLang 对 URSA 的支持待后续) | + +PS-GRPO 奖励在 `MathPRMReward`([reward_models.py](reward_models.py))中按 URSA 论文公式计算: + +```text +r = 0 若 outcome_correct == 0 +r = 1 若 outcome_correct == 1 且无 step-score drop +r = 0.5 ( = 1 - DROP_GAMMA) 若 outcome_correct == 1 但存在 step-score drop +``` + +**Step-score drop** 的判定:任意相邻 step score 出现相对降幅 ≥ `_DROP_THRESHOLD = 0.3`。 + +--- + +## 1. 
数据预处理 + +训练数据为 `MMathCoT-1M`(Stage 3 子集),需要转换成 LightRFT 的 manifest 格式。`--input-path` 与 `--image-root` 均**必填**(无默认值——路径与环境相关): + +```bash +python examples/math_prm/tools/prepare_ursa_stage3_manifest.py \ + --input-path /your/data/URSA-MATH/MMathCoT-1M/train.jsonl \ + --image-root /your/data/URSA-MATH/images \ + --output-path /your/output/math_psgrpo.jsonl +``` + +转换后每行 manifest 形如: + +```json +{ + "prompt": "数学题目文本", + "images": ["/abs/path/to/image.png"], + "reference": "标准答案", + "label": "math_psgrpo" +} +``` + +`label` 字段决定选择哪一条 reward 路径。可选值: + +| Label | Reward 信号 | +|---|---| +| `math_psgrpo` | PS-GRPO:`{0, 0.5, 1}`(本 example 默认) | +| `math_prm` | 纯 PRM 聚合 step score(连续值 `[0, 1]`) | +| `math_prm_combined` | PRM 聚合分数 + 0.5 × 规则正确性 | +| `math_rule` | 规则基线:`{0, 1}` 按答案匹配 | +| `math_per_step_prm` | 逐 step PRM 分数,供 `--advantage_estimator ursa_variant2`(paper Eq.9,详见 §6)使用 | + +需要 32 行小规模转换做 smoke 时用 `--max-samples 32`。 + +--- + +## 2. 模型 checkpoint + +需要 URSA-8B(actor)与 URSA-8B-RM(reward model)两个权重: + +```bash +# Hugging Face IDs +URSA-MATH/URSA-8B # actor +URSA-MATH/URSA-RM-8B # reward model +``` + +下载到本地后在 `run_grpo_math_prm_ursa_8b.sh` 里配置路径。 + +--- + +## 3. 配置并启动训练(PS-GRPO 配方) + +编辑 [run_grpo_math_prm_ursa_8b.sh](run_grpo_math_prm_ursa_8b.sh) 顶部的 `Part 1: User Configuration`: + +```bash +PATH_TO_YOUR_BASE_MODEL="/path/to/URSA-8B" +PATH_TO_URSA_RM="/path/to/URSA-RM-8B" +PATH_TO_YOUR_MATH_DATASET="/path/to/math_psgrpo.jsonl" +EXPERIMENT_NAME="lightrft-ursa8b-math-prm" +export WANDB_API_KEY="YOUR_WANDB_API_KEY" # 留空则禁用 W&B +``` + +然后运行: + +```bash +bash examples/math_prm/run_grpo_math_prm_ursa_8b.sh +``` + +默认目标硬件 `1 节点 × 8 A100/H100`。改 topology 用环境变量 override: + +```bash +NNODES=2 GPUS_PER_NODE=8 NODE_RANK=0 \ +MASTER_ADDR=10.0.0.1 MASTER_PORT=20092 \ +bash examples/math_prm/run_grpo_math_prm_ursa_8b.sh +``` + +--- + +## 4. 
关键超参 + +启动脚本使用 URSA-MATH 论文 Stage 3 默认值: + +| 参数 | 值 | 备注 | +|---|---|---| +| `N_SAMPLES` | 8 | 每个 prompt GRPO 采样数 | +| `EPISODE` | 10 | 总训练 episodes | +| `RBS` / `TBS` | 128 / 128 | rollout / training batch size | +| `KL` | 0.001 | 初始 KL 系数 | +| `KL_TARGET` | (off) | 设值后切到 AdaptiveKLController | +| `LR` | 1e-6 | Actor 学习率 | +| `PROMPT_MAX_LEN` | 1024 | | +| `GENERATE_MAX_LEN` | 3072 | | +| `MAX_SAMPLES` | 15360 | 训练子集上限(论文 proxy) | +| `EVAL_HOLDOUT_SIZE` | 500 | 从 `prompt_data` 中保留的确定性 held-out 子集 | + +观察到 KL 飘移时建议开 adaptive KL 控制器:`KL_TARGET=0.5`。 + +--- + +## 5. 日志字段 + +W&B 面板按三个 namespace 划分: + +- `rollout/*` — 每步 rollout 统计:`reward`、`outcome_correct`、`model_reward`、`has_drop_moment`、`response_length`、`step_score_min/mean/last`、`step_count`、`final_reward`、`max_relative_drop`、`answer_tag_present`、`answer_extraction_failed`、`used_answer_fallback`、`used_mathruler`、`reference_supported`,以及 variant-2 诊断字段 `alignment_failed` / `n_aligned_steps`。 +- `train/*` — 每步训练统计:`policy_loss`、`kl`、`actor_lr`、`advantages`、`return`,以及 variant-2 诊断字段 `ursa_v2_adv_pos_frac` / `_neg_frac` / `_zero_frac` / `_abs_mean` / `_oc_normed_std` / `_msp_normed_std` / `_traj_step_count_mean`。 +- `eval/*` — held-out 评测:`reward`、`outcome_correct`、`response_length`、`answer_extraction_failed`、`has_drop_moment`、`model_reward`、`step_score_min/mean/last`、`step_count`、`final_reward`、`max_relative_drop`、`answer_tag_present`、`used_answer_fallback`、`used_mathruler`、`reference_supported`。 + +`MathPRMReward` 输出的全套 per-sample 奖励 metric 文档见 [reward_models.py](reward_models.py) 中 `forward()` 顶部注释。 + +--- + +## 6. 
Paper Eq.9 严格对齐 — variant 2 路径 + +`run_grpo_math_prm_ursa_8b_variant2.sh` 是 URSA 论文附录 B.1 Eq.9 "variant 2" 严格实现,与 PS-GRPO 并行存在,用于 ablation 对比。实现在 [`ursa_variant2.py`](ursa_variant2.py),通过幂等 monkey-patch 注册一个新 `--advantage_estimator ursa_variant2`(不修改 `lightrft/`)。 + +### 公式 + +```text +A_t^i = r_{s,t}^i · GroupNorm_G(r̄_s^i) ← process-reward 项 + + GroupNorm_G(r_o^i) ← outcome-reward 项 +``` + +其中 `t` 是 **step 索引**(不是 token),`r_{s,t}^i` 为 trajectory `i` 第 `t` 个 step 的 sigmoid PRM 分,`r̄_s^i = mean_t r_{s,t}^i`,`r_o^i ∈ {0,1}` 是 outcome reward,`G` 是 GRPO group size(`n_samples_per_prompt`)。逐 step `A_t^i` 广播到该 step 覆盖的所有 token。**无 cumulative return**,outcome 项保留(不像 Math-Shepherd 风格的 Mode B 那样被丢弃)。 + +### 数据集 / 启动流程 + +variant 2 路径要求 manifest 行 label 是 `math_per_step_prm` 而不是 `math_psgrpo`。最简单方法是 sed-relabel PS-GRPO manifest: + +```bash +sed 's/"label":[ ]*"math_psgrpo"/"label": "math_per_step_prm"/g' \ + /path/to/math_psgrpo.jsonl \ + > /path/to/math_per_step_prm.jsonl +``` + +variant 2 启动脚本会自动检测 `PATH_TO_YOUR_MATH_DATASET` 是否指向 psgrpo 路径,若是则自动 swap 到 `*per_step_prm*.jsonl` 兄弟文件,并在训练前 assert 首行 label 是 `math_per_step_prm`。如需自定义路径,设 `PATH_TO_YOUR_MATH_DATASET_VARIANT2`。 + +```bash +PATH_TO_YOUR_MATH_DATASET_VARIANT2=/path/to/math_per_step_prm.jsonl \ +bash examples/math_prm/run_grpo_math_prm_ursa_8b_variant2.sh +``` + +`--per_step_reward_mode`(`raw` / `group_norm`)只影响**遗留 Math-Shepherd 风格逐 token reward 路径**(`_apply_step_reward_group_norm` 不同聚合方式);`--advantage_estimator ursa_variant2` 自带 group normalization,不受该 flag 影响。 + +### 单元测试 + +`test_ursa_variant2.py` 包含 9 个 AC(acceptance criterion)级单测:与手算 Eq.9 数值等价、K=2/K=4 group 的 GroupNorm 正确性、span 广播、outcome 项非旁路。运行: + +```bash +python3 -m unittest examples.math_prm.test_ursa_variant2 -v +``` + +--- + +## 7. 
实验结果 — 9 天生产 run(PS-GRPO) + +8× H100 上 PS-GRPO 配方跑满 9 天的关键指标如下。variant 2 路径并行跑了同样 9 天作对照(W&B `kdwjt4eo`),完整对比见 [PR #53 最终报告 comment](https://github.com/opendilab/LightRFT/pull/53#issuecomment-4608400929)。 + +| 指标 | baseline (Step 20) | peak | final | Δ vs baseline | +|---|---|---|---|---| +| `eval/outcome_correct` | 0.5952 | **0.6508** (Step 231) | 0.6290 (Step 1008) | **+3.4 pp** | +| `eval/answer_extraction_failed` | 0.028 | 0.018 (~Step 160) | 0.034 | +0.6 pp ↑ | +| `eval/has_drop_moment` | 0.0 | — | 0.0 | (PRM 全程未触发) | +| `eval/response_length` | 400 | 337 (~Step 240) | 377 | -23 ↓ | +| `rollout/alignment_failed` | 0 | — | 0 | 100% step 边界对齐 | +| W&B run | [`kdwjt4eo`](https://wandb.ai/hansbug/LightRFT-URSA8B-Stage3/runs/kdwjt4eo) | + +#### eval 轨迹 + +`eval/outcome_correct` 在 Step 231 见峰 +5.6pp,约 Step 300 出现一次 dip(reward hacking signature)但**自愈**,剩 7 天稳定在 0.60–0.65 区间: + +![eval trajectory](assets/exp_20260603/eval_outcome.png) + +#### KL + rollout 全局视角 + +`train/kl` 走出 warmup 后(1e-4 → 1.0 by Step ~200),在 1–100 区间震荡,偶发单 batch >100 spike 总能自愈。`rollout/outcome_correct` 与 `rollout/model_reward` 长期同向变化(无 reward hacking 解耦): + +![KL + rollout](assets/exp_20260603/kl_and_rollout.png) + +#### eval 生成质量 + +`eval/answer_extraction_failed` 在 Step 300 dip 期间短暂飙到 18%(URSA 论文警告的 `†Answer:` 格式漂移信号),之后回稳到 2–5%。`eval/response_length` 与 `eval/step_count` 稳定 — 无 length collapse: + +![eval quality](assets/exp_20260603/eval_quality.png) + +#### variant 2 路径健康度(W&B run `kdwjt4eo`) + +`ursa_v2_adv_pos_frac` 与 `_neg_frac` 长期保持平衡(30–40% / 25–35%)—— GroupNorm 持续产出 signed advantages。`_msp_normed_std` 贴近 1.0。`rollout/alignment_failed` 全程 0: + +![variant 2 health](assets/exp_20260603/variant2_health.png) + +--- + +## 8. 
目录文件清单 + +```text +examples/math_prm/ +├── README.md - 本文档(英文) +├── README_zh.md - 本文档(中文) +├── train_colocate.py - 训练主入口(由 torchrun 调用) +├── run_grpo_math_prm_ursa_8b.sh - PS-GRPO 启动脚本(推荐) +├── run_grpo_math_prm_ursa_8b_variant2.sh - paper Eq.9 严格启动脚本(ablation) +├── reward_models.py - MathPRMReward 实现(PS-GRPO) +├── reward_models_utils.py - 按 label 选 reward 配方的逻辑 +├── ursa_actor.py - URSA actor wrapper +├── ursa_variant2.py - UrsaVariant2Calculator(paper Eq.9,纯 examples/) +├── math_prm_trainer.py - MathPRMSPMDPPOTrainerVL(wandb metric 映射) +├── math_prm_output.py - "†Answer:" marker / structured-stop helpers +├── rollout_eos_patch.py - FSDP 下可靠 EOS 的 StoppingCriteria 注入 +├── test_ursa_variant2.py - variant 2 的 9 个 AC 级单测 +├── ursa_model/ - 内置 URSA 模型代码(config / processor / model) +├── tools/ +│ ├── prepare_ursa_stage3_manifest.py - 数据集转换工具 +│ └── prepare_ursa_engine_checkpoint.py - Engine-mode checkpoint wrapper +└── assets/ + └── exp_20260603/ - 9 天生产 run 的 W&B 截图 +``` + +--- + +## 9. 引用 + +使用本 example 请引用 URSA 论文: + +```bibtex +@article{luo2025ursa, + title={URSA: Understanding and Verifying Chain-of-Thought Reasoning in Multimodal Mathematics}, + author={Luo, Ruilin and Zheng, Zhuofan and Wang, Yifan and Yu, Yiyao and Ni, Xinzhe and Lin, Zicheng and Zeng, Jin and Yang, Yujiu}, + journal={NeurIPS}, + year={2025} +} +``` + +--- + +## License + +本 example 与上层 LightRFT 项目使用同一 License(见仓库根目录 `LICENSE`)。 diff --git a/examples/math_prm/assets/exp_20260603/eval_outcome.png b/examples/math_prm/assets/exp_20260603/eval_outcome.png new file mode 100644 index 00000000..18d056c0 Binary files /dev/null and b/examples/math_prm/assets/exp_20260603/eval_outcome.png differ diff --git a/examples/math_prm/assets/exp_20260603/eval_quality.png b/examples/math_prm/assets/exp_20260603/eval_quality.png new file mode 100644 index 00000000..40d4bd64 Binary files /dev/null and b/examples/math_prm/assets/exp_20260603/eval_quality.png differ diff --git 
a/examples/math_prm/assets/exp_20260603/kl_and_rollout.png b/examples/math_prm/assets/exp_20260603/kl_and_rollout.png new file mode 100644 index 00000000..65a150dd Binary files /dev/null and b/examples/math_prm/assets/exp_20260603/kl_and_rollout.png differ diff --git a/examples/math_prm/assets/exp_20260603/variant2_health.png b/examples/math_prm/assets/exp_20260603/variant2_health.png new file mode 100644 index 00000000..d86be4f9 Binary files /dev/null and b/examples/math_prm/assets/exp_20260603/variant2_health.png differ diff --git a/examples/math_prm/math_prm_output.py b/examples/math_prm/math_prm_output.py new file mode 100644 index 00000000..ac86c892 --- /dev/null +++ b/examples/math_prm/math_prm_output.py @@ -0,0 +1,109 @@ +""" +Helpers for URSA Math PRM structured outputs. + +These helpers centralize the heuristics used to keep Phase 3 generation inside the +expected `Step N:` / `†Answer:` format without introducing Phase 4 reward logic. +""" + +import re +from typing import Optional + + +MATH_PRM_STRUCTURED_LABELS = frozenset({"math_prm", "math_prm_combined", "math_psgrpo"}) +MATH_PRM_ANSWER_MARKER = "†Answer:" +_MAX_MATH_PRM_ANSWER_WORDS = 24 +_MAX_MATH_PRM_ANSWER_CHARS = 160 +_EARLY_STOP_ANSWER_WORDS = 12 +_EARLY_STOP_ANSWER_CHARS = 80 +_BOOLEAN_ANSWERS = {"yes", "no", "true", "false"} +_ALGEBRAIC_ANSWER_PATTERN = re.compile( + r"[A-Za-z][A-Za-z0-9_]*(?:\s*[,;]\s*[A-Za-z][A-Za-z0-9_]*)*\s*=\s*[-+A-Za-z0-9$\\][A-Za-z0-9\s,./%()=\-+*$\\^{}]*" +) + + +def is_math_prm_structured_label(label: Optional[str]) -> bool: + return isinstance(label, str) and label.lower() in MATH_PRM_STRUCTURED_LABELS + + +def find_math_prm_tail_cutoff(text: str) -> Optional[int]: + cut_positions = [] + for pattern in ( + r"(? 
str: + if not response_text: + return response_text + return re.sub(r"(?m)^StepStep\s+(\d+:)", r"Step \1", response_text) + + +def _extract_answer_line(response_text: str) -> tuple[str, str, bool]: + normalized_text = _normalize_math_prm_response(response_text) + marker_index = normalized_text.find(MATH_PRM_ANSWER_MARKER) + if marker_index < 0: + return normalized_text, "", False + + answer_tail = normalized_text[marker_index + len(MATH_PRM_ANSWER_MARKER):].lstrip() + answer_lines = answer_tail.splitlines() + answer_line = " ".join(answer_lines[0].split()) if answer_lines else "" + has_more_lines = len(answer_lines) > 1 + return normalized_text, answer_line, has_more_lines + + +def should_stop_math_prm_response_text(response_text: str) -> bool: + normalized_text, answer_line, has_more_lines = _extract_answer_line(response_text) + if normalized_text.find(MATH_PRM_ANSWER_MARKER) < 0 or not answer_line: + return False + if has_more_lines: + return True + if find_math_prm_tail_cutoff(answer_line) is not None: + return True + + lower_answer = answer_line.lower() + if lower_answer in _BOOLEAN_ANSWERS: + return True + if re.fullmatch(r"[-+]?[$]?\d[\d\s,./%()=-]*", answer_line): + return True + if _ALGEBRAIC_ANSWER_PATTERN.fullmatch(answer_line): + return True + if re.fullmatch(r"[A-E]", answer_line): + return True + if answer_line.endswith((".", "!", "?", "%", ")", "]")): + return True + if len(answer_line.split()) >= _EARLY_STOP_ANSWER_WORDS: + return True + if len(answer_line) >= _EARLY_STOP_ANSWER_CHARS: + return True + return False + + +def sanitize_math_prm_response_text(response_text: str) -> str: + normalized_text, answer_line, _ = _extract_answer_line(response_text) + marker_index = normalized_text.find(MATH_PRM_ANSWER_MARKER) + if marker_index < 0: + return normalized_text + + prefix = normalized_text[: marker_index + len(MATH_PRM_ANSWER_MARKER)] + + cutoff = find_math_prm_tail_cutoff(answer_line) + if cutoff is not None: + answer_line = answer_line[:cutoff] + + 
+ answer_words = answer_line.split() + if len(answer_words) > _MAX_MATH_PRM_ANSWER_WORDS: + answer_line = " ".join(answer_words[:_MAX_MATH_PRM_ANSWER_WORDS]) + if len(answer_line) > _MAX_MATH_PRM_ANSWER_CHARS: + truncated = answer_line[:_MAX_MATH_PRM_ANSWER_CHARS] + answer_line = truncated.rsplit(" ", 1)[0] or truncated + + answer_line = answer_line.rstrip(" ,;:") + return prefix.rstrip() if not answer_line else f"{prefix} {answer_line}".rstrip() diff --git a/examples/math_prm/math_prm_trainer.py b/examples/math_prm/math_prm_trainer.py new file mode 100644 index 00000000..d6a02d09 --- /dev/null +++ b/examples/math_prm/math_prm_trainer.py @@ -0,0 +1,455 @@ +from contextlib import contextmanager +from typing import Dict, Optional + +import torch + +from lightrft.trainer.spmd_ppo_trainer import SPMDPPOTrainerVL + +# Explicitly register the ursa_variant2 monkey-patches so +# ``--advantage_estimator ursa_variant2`` resolves to the paper Eq.9 +# strict-alignment calculator. train_colocate.py runs with +# cwd=examples/math_prm, so a top-level (non-package) import is used here. +# (register_ursa_variant2() is idempotent.) +from ursa_variant2 import register_ursa_variant2 + +register_ursa_variant2() + + +def _detach_rollout_eos_patch(rollout_actor): + """Detach rollout_eos_patch.StructuredAnswerStoppingCriteria wrap from a rollout actor. + + Restores the original ``generate`` on the model and returns the *patched* wrapper + so the caller can re-install it later. Returns None if nothing was detached. + + The patch wraps ``model.generate`` with ``functools.wraps``, so the original + function is reachable via ``__wrapped__``. We rely on the patch's idempotency + flag ``_math_prm_rollout_eos_patch_installed`` to detect installation. 
+ """ + if rollout_actor is None: + return None + model = getattr(rollout_actor, "model", None) + if model is None: + return None + if not getattr(model, "_math_prm_rollout_eos_patch_installed", False): + return None + patched = model.generate + original = getattr(patched, "__wrapped__", None) + if original is None: + return None + model.generate = original + model._math_prm_rollout_eos_patch_installed = False + return patched + + +def _reattach_rollout_eos_patch(rollout_actor, patched_generate): + """Reinstall a previously detached patched generate function.""" + if rollout_actor is None or patched_generate is None: + return + model = getattr(rollout_actor, "model", None) + if model is None: + return + model.generate = patched_generate + model._math_prm_rollout_eos_patch_installed = True + + +class MathPRMSPMDPPOTrainerVL(SPMDPPOTrainerVL): + """SPMD PPO trainer specialized for URSA-MATH process-reward training. + + Differs from the base ``SPMDPPOTrainerVL`` in two ways: + + 1. **W&B namespace mapping** — rollout/train/eval metric streams are + projected into the ``rollout/``, ``train/``, ``eval/`` namespaces via + ``_ROLLOUT_KEY_SOURCES`` / ``_TRAIN_KEY_SOURCES`` / ``_EVAL_KEY_SOURCES`` + so the dashboards stay aligned with the URSA paper's reporting + conventions even as upstream metric names drift. + 2. **URSA-specific eval** — :meth:`evaluate` runs under a runtime context + that prevents the actor from generating ``<|image|>`` sentinel tokens + and aggregates per-dataset metrics into a single weighted-average + view (see :meth:`_aggregate_eval_metrics`). + + All checkpoint, logging, profile, and trajectory-saving wiring is unchanged + from the base class apart from the namespace mapping. 
+ """ + + _ROLLOUT_KEY_SOURCES = { + "reward": ("rollout_reward", "step_reward_mean", "reward"), + "reward_std": ("rollout_reward_std", "step_reward_std"), + "outcome_correct": ("rollout_outcome_correct", "outcome_correct_mean", "reward_metrics/outcome_correct"), + "has_drop_moment": ("rollout_has_drop_moment", "has_drop_moment_mean", "reward_metrics/has_drop_moment"), + "model_reward": ("rollout_model_reward", "model_reward_mean", "reward_metrics/model_reward"), + "response_length": ("rollout_response_length", "response_length_mean", "response_length"), + # PRM step-score distribution (computed by MathPRMReward.forward but + # previously not surfaced to wandb; useful for monitoring PRM behaviour + # at scale, not just min-aggregation). + "step_score_min": ("rollout_step_score_min", "step_score_min_mean", "reward_metrics/step_score_min"), + "step_score_mean": ("rollout_step_score_mean", "step_score_mean_mean", "reward_metrics/step_score_mean"), + "step_score_last": ("rollout_step_score_last", "step_score_last_mean", "reward_metrics/step_score_last"), + "step_count": ("rollout_step_count", "step_count_mean", "reward_metrics/step_count"), + # PS-GRPO + answer-extraction diagnostics (already computed, were missing + # from wandb mapping — required for reward-hacking forensics). 
+ "final_reward": ("rollout_final_reward", "final_reward_mean", "reward_metrics/final_reward"), + "max_relative_drop": ( + "rollout_max_relative_drop", "max_relative_drop_mean", "reward_metrics/max_relative_drop" + ), + "answer_tag_present": ( + "rollout_answer_tag_present", "answer_tag_present_mean", "reward_metrics/answer_tag_present" + ), + "answer_extraction_failed": ( + "rollout_answer_extraction_failed", "answer_extraction_failed_mean", + "reward_metrics/answer_extraction_failed" + ), + "used_answer_fallback": ( + "rollout_used_answer_fallback", "used_answer_fallback_mean", "reward_metrics/used_answer_fallback" + ), + "used_mathruler": ("rollout_used_mathruler", "used_mathruler_mean", "reward_metrics/used_mathruler"), + "reference_supported": ( + "rollout_reference_supported", "reference_supported_mean", "reward_metrics/reference_supported" + ), + # Variant 2 (per-step PRM) diagnostics — populated only when + # the dataset row label is "math_per_step_prm". For "math_psgrpo" + # rows these stay 0 (no alignment was attempted). 
+ "alignment_failed": ("rollout_alignment_failed", "alignment_failed_mean", "reward_metrics/alignment_failed"), + "n_aligned_steps": ("rollout_n_aligned_steps", "n_aligned_steps_mean", "reward_metrics/n_aligned_steps"), + } + _TRAIN_KEY_SOURCES = { + "policy_loss": ("policy_loss", ), + "kl": ("kl", ), + "actor_lr": ("actor_lr", ), + "critic_loss": ("critic_loss", ), + "critic_lr": ("critic_lr", ), + "values": ("values", ), + "values_std": ("values_std", ), + "reward": ("reward", ), + "reward_std": ("step_reward_std", ), + "return": ("return", ), + "return_std": ("returns_std", ), + "response_length": ("response_length", ), + "total_length": ("total_length", ), + "num_actions": ("num_actions", ), + "approx_kl": ("approx_kl", ), + "clipfrac": ("clipfrac", ), + "ratio_mean": ("ratio_mean", ), + "ratio_max": ("ratio_max", ), + "advantages": ("advantages_mean", ), + "advantages_std": ("advantages_std", ), + "ptx_loss": ("ptx_loss", ), + # URSA paper Eq.9 variant 2 advantage-calculator diagnostics. Populated + # only when --advantage_estimator ursa_variant2 is active; otherwise + # absent from experience.info and silently skipped by _build_train_metrics. 
+ "ursa_v2_adv_pos_frac": ("ursa_v2_adv_pos_frac", ), + "ursa_v2_adv_neg_frac": ("ursa_v2_adv_neg_frac", ), + "ursa_v2_adv_zero_frac": ("ursa_v2_adv_zero_frac", ), + "ursa_v2_adv_abs_mean": ("ursa_v2_adv_abs_mean", ), + "ursa_v2_oc_normed_std": ("ursa_v2_oc_normed_std", ), + "ursa_v2_msp_normed_std": ("ursa_v2_msp_normed_std", ), + "ursa_v2_traj_step_count_mean": ("ursa_v2_traj_step_count_mean", ), + } + _EVAL_KEY_SOURCES = { + "reward": ("reward", "reward_mean"), + "outcome_correct": ("outcome_correct", "outcome_correct_mean"), + "has_drop_moment": ("has_drop_moment", "has_drop_moment_mean"), + "model_reward": ("model_reward", "model_reward_mean"), + "response_length": ("response_length", "response_length_mean"), + "answer_extraction_failed": ("answer_extraction_failed", "answer_extraction_failed_mean"), + # PRM step-score distribution + PS-GRPO diagnostics in eval (parity with + # the expanded rollout-side mapping above). + "step_score_min": ("step_score_min", "step_score_min_mean"), + "step_score_mean": ("step_score_mean", "step_score_mean_mean"), + "step_score_last": ("step_score_last", "step_score_last_mean"), + "step_count": ("step_count", "step_count_mean"), + "final_reward": ("final_reward", "final_reward_mean"), + "max_relative_drop": ("max_relative_drop", "max_relative_drop_mean"), + "answer_tag_present": ("answer_tag_present", "answer_tag_present_mean"), + "used_answer_fallback": ("used_answer_fallback", "used_answer_fallback_mean"), + "used_mathruler": ("used_mathruler", "used_mathruler_mean"), + "reference_supported": ("reference_supported", "reference_supported_mean"), + # Variant 2 diagnostics in eval (eval also runs the PRM forward + # if the dataset label is "math_per_step_prm") + "alignment_failed": ("alignment_failed", "alignment_failed_mean"), + "n_aligned_steps": ("n_aligned_steps", "n_aligned_steps_mean"), + } + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._train_generate_kwargs = dict(self.generate_kwargs) 
+ self._eval_generate_kwargs = self._build_eval_generate_kwargs() + if self._wandb is not None and self.strategy.is_rank_0(): + self._wandb.define_metric("rollout/*", step_metric=None, step_sync=False, overwrite=True) + self._wandb.define_metric("train/*", step_metric=None, step_sync=False, overwrite=True) + self._wandb.define_metric("eval/train_step") + self._wandb.define_metric("eval/*", step_metric="eval/train_step", step_sync=True, overwrite=True) + self._wandb.define_metric("profile/train_step") + self._wandb.define_metric("profile/*", step_metric="profile/train_step", step_sync=True, overwrite=True) + + def _build_eval_generate_kwargs(self) -> Dict: + eval_generate_kwargs = dict(self._train_generate_kwargs) + eval_generate_kwargs["do_sample"] = bool(getattr(self.strategy.args, "eval_do_sample", False)) + eval_generate_kwargs["max_new_tokens"] = ( + getattr(self.strategy.args, "eval_generate_max_len", None) + or self._train_generate_kwargs.get("max_new_tokens") + ) + eval_generate_kwargs["temperature"] = getattr(self.strategy.args, "eval_temperature", 0.0) + eval_generate_kwargs["top_p"] = getattr(self.strategy.args, "eval_top_p", 1.0) + eval_generate_kwargs["top_k"] = getattr(self.strategy.args, "eval_top_k", -1) + eval_generate_kwargs["repetition_penalty"] = getattr(self.strategy.args, "eval_repetition_penalty", 1.0) + eval_generate_kwargs["no_repeat_ngram_size"] = getattr( + self.strategy.args, + "eval_no_repeat_ngram_size", + 0, + ) + return eval_generate_kwargs + + @contextmanager + def _runtime_eval_context(self): + original_generate_kwargs = self.generate_kwargs + original_n_samples = self.strategy.args.n_samples_per_prompt + original_advantage_estimator = self.strategy.args.advantage_estimator + original_config_n_samples = getattr(self.strategy.config, "n_samples_per_prompt", None) + original_config_advantage_estimator = getattr(self.strategy.config, "advantage_estimator", None) + + self.generate_kwargs = dict(self._eval_generate_kwargs) + 
self.strategy.args.n_samples_per_prompt = max( + 1, int(getattr(self.strategy.args, "eval_n_samples_per_prompt", 1)) + ) + self.strategy.args.advantage_estimator = "reinforce" + if original_config_n_samples is not None: + self.strategy.config.n_samples_per_prompt = self.strategy.args.n_samples_per_prompt + if original_config_advantage_estimator is not None: + self.strategy.config.advantage_estimator = "reinforce" + + # Detach rollout_eos_patch on the inference engine for the duration of eval. + # The patch is meant to save GPU during training rollouts (early-stops at + # the first ``†Answer:`` line) but truncates response tokens that the + # reward extractor needs in eval; ablation showed it lowers eval + # outcome_correct by ~8pp at bs=4 and is catastrophic at bs=1 + # (extraction-failure 44%). See PR #53 issuecomment-4394071500. + rollout_actor = getattr(self.strategy, "inference_engine", None) + detached_patch = _detach_rollout_eos_patch(rollout_actor) + if detached_patch is not None and self.strategy.is_rank_0(): + self.strategy.print("[eval] rollout_eos_patch detached for the eval pass") + + try: + yield + finally: + self.generate_kwargs = original_generate_kwargs + self.strategy.args.n_samples_per_prompt = original_n_samples + self.strategy.args.advantage_estimator = original_advantage_estimator + if original_config_n_samples is not None: + self.strategy.config.n_samples_per_prompt = original_config_n_samples + if original_config_advantage_estimator is not None: + self.strategy.config.advantage_estimator = original_config_advantage_estimator + if detached_patch is not None: + _reattach_rollout_eos_patch(rollout_actor, detached_patch) + if self.strategy.is_rank_0(): + self.strategy.print("[eval] rollout_eos_patch reattached after eval") + + def _build_rollout_metrics(self, logs_dict: Dict[str, float]) -> Dict[str, float]: + rollout_metrics = {} + for target_key, source_keys in self._ROLLOUT_KEY_SOURCES.items(): + for source_key in source_keys: + if source_key 
in logs_dict: + rollout_metrics[target_key] = logs_dict[source_key] + break + return rollout_metrics + + def _build_train_metrics(self, logs_dict: Dict[str, float]) -> Dict[str, float]: + train_metrics = {} + for target_key, source_keys in self._TRAIN_KEY_SOURCES.items(): + for source_key in source_keys: + if source_key in logs_dict: + train_metrics[target_key] = logs_dict[source_key] + break + return train_metrics + + def _build_eval_metrics(self, raw_eval_metrics: Dict[str, float]) -> Dict[str, float]: + eval_metrics = {} + for target_key, source_keys in self._EVAL_KEY_SOURCES.items(): + for source_key in source_keys: + if source_key in raw_eval_metrics: + eval_metrics[target_key] = raw_eval_metrics[source_key] + break + return eval_metrics + + def _aggregate_eval_metrics(self, raw_eval_metrics: Dict[str, float]) -> Dict[str, float]: + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return raw_eval_metrics + + gathered_metrics = [None] * torch.distributed.get_world_size() + torch.distributed.all_gather_object(gathered_metrics, raw_eval_metrics or {}) + + total_samples = sum(float(metrics.get("num_samples", 0.0)) for metrics in gathered_metrics if metrics) + if total_samples <= 0: + return {} + + aggregated_metrics = {"num_samples": total_samples} + mean_keys = {key for metrics in gathered_metrics if metrics for key in metrics.keys() if key.endswith("_mean")} + for key in mean_keys: + weighted_sum = 0.0 + for metrics in gathered_metrics: + if not metrics or key not in metrics: + continue + weighted_sum += float(metrics["num_samples"]) * float(metrics[key]) + aggregated_metrics[key] = weighted_sum / total_samples + return aggregated_metrics + + def evaluate(self, eval_dataloader, global_step): + """Run URSA-flavored evaluation and return aggregated metrics. + + Wraps the base trainer's :meth:`evaluate` in :meth:`_runtime_eval_context` + so the actor cannot emit reserved sentinel tokens (e.g. 
def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict=None, client_states=None, episode=0):
    """Drive periodic W&B/TensorBoard logging, eval, and checkpoint saves.

    Three independent cadences are read from ``args``: ``logging_steps``
    (rollout + train metrics), ``eval_steps`` (runs :meth:`evaluate` and
    uploads the result under ``eval/``) and ``save_steps`` (writes a
    ``global_step{N}`` checkpoint via ``_save_checkpoint``).

    Rollout / train metrics are routed through
    :meth:`_build_rollout_metrics` / :meth:`_build_train_metrics`. W&B calls
    use the monotonically increasing ``self.wandb_log_counter`` as ``step``,
    while TensorBoard scalars are written against ``global_step``.

    :param args: Runtime args providing ``logging_steps``, ``eval_steps``
        and ``save_steps``.
    :param global_step: Current training step.
    :param step_bar: Progress bar; accepted for signature compatibility and
        not used in this method.
    :param logs_dict: Per-step metric dict. ``None`` (the default) means an
        empty dict.
    :param client_states: Extra state forwarded to ``_save_checkpoint``.
        ``None`` (the default) means a fresh empty dict for this call.
    :param episode: Episode counter logged alongside the metrics on W&B.
    """
    # Fresh containers per call: a literal ``{}`` default would be one shared
    # object, so anything the checkpoint saver stored in ``client_states``
    # would leak into every later call that relies on the default.
    logs_dict = {} if logs_dict is None else logs_dict
    client_states = {} if client_states is None else client_states

    if global_step % args.logging_steps == 0:
        rollout_metrics = self._build_rollout_metrics(logs_dict)
        train_metrics = self._build_train_metrics(logs_dict)

        if self._wandb is not None and self.strategy.is_rank_0():
            all_wandb_logs = {}
            for key, value in rollout_metrics.items():
                all_wandb_logs[f"rollout/{key}"] = value
            all_wandb_logs["rollout/episode"] = episode
            for key, value in train_metrics.items():
                all_wandb_logs[f"train/{key}"] = value
            all_wandb_logs["train/episode"] = episode

            # The two episode entries guarantee the payload is never empty.
            self.wandb_log_counter += 1
            self._wandb.log(all_wandb_logs, step=self.wandb_log_counter, commit=True)
            self._update_wandb_summary(all_wandb_logs)

        elif self._tensorboard is not None and self.strategy.is_rank_0():
            for key, value in rollout_metrics.items():
                self._tensorboard.add_scalar(f"rollout/{key}", value, global_step)
            for key, value in train_metrics.items():
                self._tensorboard.add_scalar(f"train/{key}", value, global_step)

    if global_step % args.eval_steps == 0 and self.eval_dataloader is not None:
        with self.profiler.phase("eval"):
            with self.profiler.section("total"):
                eval_metrics = self.evaluate(self.eval_dataloader, global_step)

        if eval_metrics and self.strategy.is_rank_0():
            self.eval_step_counter += 1

            if self._wandb is not None:
                eval_logs = {f"eval/{key}": value for key, value in eval_metrics.items()}
                eval_logs["eval/train_step"] = global_step
                eval_logs["eval/episode"] = episode

                self.wandb_log_counter += 1
                self._wandb.log(eval_logs, step=self.wandb_log_counter, commit=True)
                self._update_wandb_summary(eval_logs)

            elif self._tensorboard is not None:
                for key, value in eval_metrics.items():
                    self._tensorboard.add_scalar(f"eval/{key}", value, global_step)

    if global_step % args.save_steps == 0:
        with self.profiler.phase("checkpoint"):
            with self.profiler.section("total"):
                tag = f"global_step{global_step}"
                self._save_checkpoint(args, tag, client_states)
+ """ + if not profile_snapshot or not self.strategy.is_rank_0(): + return + + summary = profile_snapshot.get("summary") + if summary: + self.strategy.print(summary) + + if self._wandb is not None: + wandb_logs = dict(profile_snapshot.get("wandb_logs", {})) + if wandb_logs: + wandb_logs["profile/episode"] = episode + self.wandb_log_counter += 1 + self._wandb.log(wandb_logs, step=self.wandb_log_counter, commit=True) + self._update_wandb_summary(wandb_logs) + + elif self._tensorboard is not None: + record = profile_snapshot.get("record", {}) + for key, value in record.get("sections_max_s", {}).items(): + self._tensorboard.add_scalar(f"profile/{key}_s", value, global_step) + for key, value in record.get("sections_max_ratio", {}).items(): + self._tensorboard.add_scalar(f"profile/{key}_ratio", value, global_step) + + def save_trajectories(self, global_step: int): + """Persist the current replay buffer contents to disk when configured. + + No-op when either a :class:`TrajectorySaver` has not been wired up + or the replay buffer is empty. The saver itself handles rank gating + and on-disk layout. + + :param global_step: Step value embedded in the saved trajectory's + file name / metadata so downstream analysis can correlate runs. 
def find_step_boundaries_in_response_tokens(prm_module, response_text: str, question_text: str = ""):
    """Map each PRM step tag back to a token index in the response.

    The PRM input string is rebuilt as ``_PRM_PROMPT + question + "\\n"``
    (or ``_PRM_PROMPT + "\\n"`` when the question is empty or a float)
    followed by the response, then passed through
    ``prm_module.replace_specific_plus_minus_with_ki`` to insert the step
    tags. That string is tokenized with ``prm_module.tokenizer``; every token
    whose id equals ``prm_module.tag_id`` marks a step end. Its start offset
    is converted into a character position in the untagged ``response_text``
    by subtracting the prefix length and two characters for each earlier tag.
    Finally ``response_text`` alone is tokenized, and for every step-end
    position the last token whose end offset is positive and not beyond that
    position is taken as the boundary.

    Parameters
    ----------
    prm_module
        Object providing ``_PRM_PROMPT``, ``tokenizer``, ``tag_id`` and
        ``replace_specific_plus_minus_with_ki``.
    response_text : str
        Response text to align.
    question_text : str, optional
        Question used to rebuild the PRM prefix.

    Returns
    -------
    boundaries : list[int]
        One response-token index per tag token found, ``-1`` when no
        response token ends at or before the step end.
    matched_patterns : list[str]
        ``Step N:`` / ``†Answer:`` markers found in ``response_text``.
    """
    matched_patterns = [found.group() for found in _STEP_OR_ANSWER_PATTERN.finditer(response_text)]

    has_question = bool(question_text) and not isinstance(question_text, float)
    prefix = prm_module._PRM_PROMPT + (question_text + "\n" if has_question else "\n")
    tagged_input = prm_module.replace_specific_plus_minus_with_ki(prefix + response_text)

    tokenizer = prm_module.tokenizer
    prm_encoding = tokenizer(tagged_input, return_offsets_mapping=True, add_special_tokens=False)
    prm_spans = prm_encoding["offset_mapping"]
    prm_token_ids = prm_encoding["input_ids"]
    step_tag_id = prm_module.tag_id

    # Start offset of every step-tag token inside the tagged PRM string.
    tag_starts = [span[0] for token_id, span in zip(prm_token_ids, prm_spans) if token_id == step_tag_id]
    # Undo the prefix shift and the two characters contributed by each earlier tag.
    prefix_len = len(prefix)
    step_end_chars = [start - prefix_len - 2 * earlier_tags for earlier_tags, start in enumerate(tag_starts)]

    response_encoding = tokenizer(response_text, return_offsets_mapping=True, add_special_tokens=False)
    token_ends = [char_end for _, char_end in response_encoding["offset_mapping"]]

    boundaries = [
        # Indices grow monotonically, so the max qualifying index is the last one.
        max((index for index, char_end in enumerate(token_ends) if 0 < char_end <= limit), default=-1)
        for limit in step_end_chars
    ]
    return boundaries, matched_patterns
@staticmethod
def replace_specific_plus_minus_with_ki(text: str) -> str:
    """Insert the step tag `` и`` at the end of each step's content.

    Every ``Step <digits>`` occurrence is located. For each step except the
    first, the tag goes right after the last character that is neither a
    space nor a newline between the previous step header and this one. If
    ``†Answer:`` occurs, one more tag is placed the same way between the last
    step header and the answer marker. When the text contains no step header
    at all, a single tag is appended to the end.
    """
    step_tag = " и"
    header_spans = [match.span() for match in re.finditer(r"Step \d+", text)]
    if not header_spans:
        return text + step_tag

    def _slot_before(scan_from, stop_before):
        # Walk backwards from scan_from down to (excluding) stop_before and
        # return the index just past the first non-blank character.
        for index in range(scan_from, stop_before, -1):
            if text[index] not in {" ", "\n"}:
                return index + 1
        return None

    tagged = text
    try:
        slots = []
        for (_, previous_end), (next_start, _) in zip(header_spans, header_spans[1:]):
            slot = _slot_before(next_start - 1, previous_end)
            if slot is not None:
                slots.append(slot)

        answer_at = text.find("†Answer:")
        if answer_at != -1:
            slot = _slot_before(answer_at - 1, header_spans[-1][1])
            if slot is not None:
                slots.append(slot)

        # Insert from the back so earlier indices stay valid.
        for slot in sorted(slots, reverse=True):
            tagged = tagged[:slot] + step_tag + tagged[slot:]
        return tagged
    except Exception:
        return tagged + step_tag
question + "\n" + response + return self.replace_specific_plus_minus_with_ki(instruction) + + def _split_conversation(self, prompt_and_output: str) -> tuple[str, str]: + question = "" + response = "" + + for sep in ("<|im_start|>user\n", "User:", "USER:"): + if sep not in prompt_and_output: + continue + user_block = prompt_and_output.split(sep)[-1] + for end in ("<|im_end|>", "<|im_start|>"): + if end in user_block: + user_block = user_block.split(end)[0] + question = self._clean_question_text(user_block) + break + + for sep in ("<|im_start|>assistant\n", "Assistant:", "ASSISTANT:"): + if sep not in prompt_and_output: + continue + response_block = prompt_and_output.split(sep)[-1] + for end in ("<|im_end|>", "<|endoftext|>"): + if end in response_block: + response_block = response_block.split(end)[0] + response = response_block.strip() + break + + if not response: + response = prompt_and_output + return question, response + + @staticmethod + def _clean_question_text(question: str) -> str: + question = _clean_vision_token(question) + question = question.replace("<|image|>", "").replace("", "") + return question.strip() + + @staticmethod + def _select_prm_image(raw_image: Any) -> list[Any]: + if isinstance(raw_image, (list, tuple)): + for item in raw_image: + if item is not None: + return [item] + return [None] + return [raw_image] if raw_image is not None else [None] + + @staticmethod + def _safe_text(value: Any) -> str: + if value is None: + return "" + return str(value).strip() + + @staticmethod + def _is_multiple_choice_reference(reference: str) -> bool: + ref = normalize_answer(reference).strip().upper() + return len(ref) == 1 and ref in {"A", "B", "C", "D"} + + @classmethod + def _infer_reference_type(cls, reference: Any) -> tuple[str, bool]: + reference_text = cls._safe_text(reference) + if not reference_text: + return "missing", False + + reference_norm = normalize_answer(reference_text).strip() + if cls._is_multiple_choice_reference(reference_norm): + return 
"multiple_choice", True + + if reference_norm.lower() in {"yes", "no", "true", "false"}: + return "text", True + + numeric_candidate = reference_norm.replace(",", "") + if re.fullmatch(r"-?\d+(?:\.\d+)?", numeric_candidate): + return "numeric", True + if re.fullmatch(r"-?\d+/\d+", numeric_candidate): + return "numeric", True + + if any(token in reference_norm for token in ("\\", "=", "^", "_", "{", "}", "sqrt", "frac")): + return "formula", True + if re.search(r"[a-zA-Z]", reference_norm) and re.search(r"[\d=+\-*/()]", reference_norm): + return "formula", True + + return "text", True + + @classmethod + def _extract_answer_from_candidate(cls, candidate: str, reference_type: str) -> str: + candidate = cls._safe_text(candidate) + if not candidate: + return "" + + boxed = extract_boxed_answer(candidate) + if boxed: + return boxed + + tagged = extract_answer_from_tags(candidate, "answer") + if tagged: + return tagged + + candidate = re.sub( + r"^(?:†\s*)?(?:final answer|correct answer(?: is)?|the answer is|answer)\s*[::]?\s*", + "", + candidate, + flags=re.IGNORECASE, + ).strip() + candidate = candidate.rstrip(" .") + if not candidate: + return "" + + if reference_type == "multiple_choice": + extracted = extract_multiple_choice_answer(candidate) + return extracted or normalize_answer(candidate).strip().upper() + + if reference_type == "numeric": + if any(token in candidate for token in ("\\", "/", "=", "^", "{", "}", "sqrt", "frac")): + return normalize_answer(candidate) + extracted = extract_numeric_answer(candidate) + return extracted or normalize_answer(candidate) + + if reference_type in {"formula", "text"}: + return normalize_answer(candidate) + + return extract_answer(candidate) + + @classmethod + def _extract_final_answer_details(cls, response: str, reference_type: str) -> Dict[str, Any]: + response = cls._safe_text(response) + details: Dict[str, Any] = { + "predicted_answer": "", + "answer_tag_present": False, + "answer_extraction_failed": True, + 
"used_answer_fallback": False, + "extraction_source": "missing", + } + if not response: + return details + + if "†Answer:" in response: + details["answer_tag_present"] = True + answer_block = response.split("†Answer:", 1)[-1] + answer_block = re.split(r"\n\s*Step\s+\d+\s*:", answer_block, maxsplit=1)[0] + candidate_lines = [line.strip() for line in answer_block.splitlines() if line.strip()] + candidate = candidate_lines[0] if candidate_lines else answer_block.strip() + predicted_answer = cls._extract_answer_from_candidate(candidate, reference_type) + details["predicted_answer"] = predicted_answer + details["answer_extraction_failed"] = predicted_answer == "" + details["extraction_source"] = "dagger_answer" + return details + + explicit_fallbacks = [ + ("boxed", extract_boxed_answer(response)), + ("tagged_answer", extract_answer_from_tags(response, "answer")), + ] + for source, match in explicit_fallbacks: + if match: + details["predicted_answer"] = normalize_answer(match) + details["answer_extraction_failed"] = False + details["used_answer_fallback"] = True + details["extraction_source"] = source + return details + + lines = [line.strip() for line in response.splitlines() if line.strip()] + if lines: + last_line = lines[-1] + explicit_line = re.match( + r"^(?:†\s*)?(?:final answer|correct answer(?: is)?|the answer is|answer)\b", + last_line, + flags=re.IGNORECASE, + ) + if explicit_line: + predicted_answer = cls._extract_answer_from_candidate(last_line, reference_type) + details["predicted_answer"] = predicted_answer + details["answer_extraction_failed"] = predicted_answer == "" + details["used_answer_fallback"] = True + details["extraction_source"] = "explicit_last_line" + return details + + return details + + @classmethod + def _compare_final_answer( + cls, + predicted_answer: str, + reference: Any, + reference_type: str, + reference_supported: bool, + ) -> tuple[bool, str]: + reference_text = cls._safe_text(reference) + if not reference_supported: + return 
False, "unsupported_reference" + if not reference_text: + return False, "missing_reference" + if not predicted_answer: + return False, "missing_prediction" + + if reference_type == "multiple_choice": + pred_norm = normalize_answer(predicted_answer).strip().upper() + ref_norm = normalize_answer(reference_text).strip().upper() + return pred_norm == ref_norm, "multiple_choice_exact" + + if reference_type in {"numeric", "formula"}: + if mathruler_grade_answer is not None: + try: + if mathruler_grade_answer(predicted_answer, reference_text): + return True, "mathruler" + except Exception: + pass + return compare_answers(predicted_answer, reference_text, is_multiple_choice=False), "math_eval" + + return compare_answers(predicted_answer, reference_text, is_multiple_choice=False), "text_compare" + + @classmethod + def _evaluate_answer_alignment(cls, response: str, reference: Any) -> Dict[str, Any]: + reference_type, reference_supported = cls._infer_reference_type(reference) + extraction = cls._extract_final_answer_details(response, reference_type) + outcome_correct, comparison_method = cls._compare_final_answer( + extraction["predicted_answer"], + reference, + reference_type, + reference_supported, + ) + return { + "reference_type": reference_type, + "reference_supported": reference_supported, + "comparison_method": comparison_method, + **extraction, + "outcome_correct": outcome_correct, + } + + @classmethod + def _compute_relative_drop(cls, step_scores: torch.Tensor) -> tuple[float, bool]: + if step_scores.numel() < 2: + return 0.0, False + + scores = step_scores.detach().float() + prev_scores = scores[:-1] + next_scores = scores[1:] + denom = torch.clamp(prev_scores, min=1e-6) + relative_drops = torch.clamp((prev_scores - next_scores) / denom, min=0.0) + max_relative_drop = float(relative_drops.max().item()) if relative_drops.numel() else 0.0 + return max_relative_drop, max_relative_drop >= cls._DROP_THRESHOLD + + @classmethod + def _compute_psgrpo_metrics( + cls, + 
response: str, + reference: Any, + step_scores: torch.Tensor, + ) -> Dict[str, float]: + answer_eval = cls._evaluate_answer_alignment(response, reference) + outcome_correct = float(answer_eval["outcome_correct"]) + max_relative_drop, has_drop_moment = cls._compute_relative_drop(step_scores) + + final_reward = 0.0 + if outcome_correct > 0.0: + final_reward = 1.0 - cls._DROP_GAMMA if has_drop_moment else 1.0 + + return { + "outcome_correct": outcome_correct, + "max_relative_drop": max_relative_drop, + "has_drop_moment": float(has_drop_moment), + "final_reward": final_reward, + "answer_tag_present": float(answer_eval["answer_tag_present"]), + "answer_extraction_failed": float(answer_eval["answer_extraction_failed"]), + "used_answer_fallback": float(answer_eval["used_answer_fallback"]), + "reference_supported": float(answer_eval["reference_supported"]), + "used_mathruler": float(answer_eval["comparison_method"] == "mathruler"), + } + + @torch.no_grad() + def forward( + self, + sequences, + attention_mask, + prompt_and_output=None, + raw_images=None, + references=None, + labels=None, + **kwargs, + ) -> torch.Tensor | Dict[str, torch.Tensor]: + device = next(self.model.parameters()).device + + if prompt_and_output is None and sequences is not None: + prompt_and_output = self.tokenizer.batch_decode(sequences, skip_special_tokens=True) + elif prompt_and_output is None: + raise ValueError("Either sequences or prompt_and_output must be provided") + + return_dict = bool(kwargs.get("return_dict", False)) + + batch_rewards = [] + # Per-sample reward metrics emitted alongside the scalar reward. + # They are grouped into three buckets: + # + # 1. PRM step-score statistics (continuous, distribution shape): + # model_reward - aggregated step score (min/avg/last per agg setting) + # step_score_min - lowest step score in the response + # step_score_mean - mean step score + # step_score_last - score of the final step + # step_count - number of "Step N:" blocks scored + # + # 2. 
Outcome / correctness signals (mostly binary): + # outcome_correct - 1 if extracted answer matches ground truth, else 0 + # has_drop_moment - 1 if any consecutive step pair dropped > _DROP_THRESHOLD + # max_relative_drop - magnitude of the largest relative drop + # final_reward - PS-GRPO reward {0, 1-_DROP_GAMMA, 1} fed into GRPO + # + # 3. Diagnostics on answer extraction / grading path (low-volume but useful + # when debugging dataset / format / mathruler issues): + # answer_tag_present - 1 if the "†Answer:" marker appeared + # answer_extraction_failed - 1 if no answer string could be extracted + # used_answer_fallback - 1 if the heuristic last-line fallback fired + # reference_supported - 1 if the ground-truth schema is recognized + # used_mathruler - 1 if mathruler grading was the deciding step + # + # NOTE: ``accuracy_reward`` used to live here, but for math_psgrpo it is + # exactly equal to ``outcome_correct`` (see _compute_psgrpo_metrics). + # It now lives only in reward_models_utils.mix_rewards where it is set + # by the rule branch for the math_rule / math_prm_combined recipes. + batch_metrics: Dict[str, list[float]] = { + "model_reward": [], + "step_score_min": [], + "step_score_mean": [], + "step_score_last": [], + "step_count": [], + "outcome_correct": [], + "max_relative_drop": [], + "has_drop_moment": [], + "final_reward": [], + "answer_tag_present": [], + "answer_extraction_failed": [], + "used_answer_fallback": [], + "reference_supported": [], + "used_mathruler": [], + # math_per_step_prm diagnostics + "alignment_failed": [], + "n_aligned_steps": [], + } + # Per-step PRM data (variant 2). Lists of per-trajectory tensors, + # only populated for label == "math_per_step_prm" trajectories. + # When all trajectories in a batch have label == "math_psgrpo" the + # collected lists stay empty and the dict keys are dropped at the end + # so legacy callers that don't know about per-step rewards see no + # change. 
+ batch_step_rewards: list[torch.Tensor] = [] + batch_step_token_indices: list[torch.Tensor] = [] + + image_inputs = raw_images or [None] * len(prompt_and_output) + ref_inputs = references or [None] * len(prompt_and_output) + label_inputs = labels or ["math_prm"] * len(prompt_and_output) + + for text, sample_image, reference, label in zip_longest( + prompt_and_output, image_inputs, ref_inputs, label_inputs, fillvalue=None + ): + if text is None: + continue + + question, response = self._split_conversation(text) + input_prompt = self._prepare_prm_input(question, response) + conversation = [ + {"role": "system", "content": self._SYSTEM_PROMPT}, + {"role": "user", "content": "<|image|>" + input_prompt}, + ] + formatted_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True) + inputs = self.processor( + formatted_prompt, + self._select_prm_image(sample_image), + return_tensors="pt", + ).to(device, torch.bfloat16) + + # Sanity check: response (RL-generated) is not vision-cleaned, so it can + # contain literal `<|image|>` / `` strings that the tokenizer maps + # to image_token_index. The PRM only ever receives one image, so any + # extras would crash _merge_input_ids_with_image_features. Keep the first + # image token (intended placeholder) and replace the rest with a benign + # text token so PRM scoring continues instead of aborting the rollout. 
+ image_token_id = getattr(self.model.config, "image_token_index", None) + if image_token_id is not None: + input_ids_view = inputs["input_ids"] + image_mask_flat = (input_ids_view == image_token_id).view(-1) + extras = torch.nonzero(image_mask_flat, as_tuple=False).squeeze(-1) + if extras.numel() > 1: + replacement = self.tokenizer.pad_token_id + if replacement is None: + replacement = self.tokenizer.eos_token_id + flat = input_ids_view.view(-1) + flat[extras[1:]] = replacement + inputs["input_ids"] = flat.view(input_ids_view.shape) + + reward = self.model(**inputs).logits + input_ids = inputs["input_ids"].view(-1) + padding = torch.full((self._IMAGE_PAD,), -1, device=device) + input_ids_aligned = torch.cat((input_ids[:1], padding, input_ids[1:])) + + reward_flat = reward.view(-1) + step_logits = reward_flat[input_ids_aligned == self.tag_id] + step_scores = torch.sigmoid(step_logits).view(-1) + psgrpo_metrics = self._compute_psgrpo_metrics(response, reference, step_scores) + + if step_scores.numel() == 0: + aggregated_score = 0.0 + elif self.aggregation == "min": + aggregated_score = float(torch.min(step_scores).item()) + elif self.aggregation in {"avg", "mean"}: + aggregated_score = float(torch.mean(step_scores).item()) + elif self.aggregation == "last": + aggregated_score = float(step_scores[-1].item()) + else: + raise ValueError(f"Unknown aggregation: {self.aggregation!r}") + + # ---- variant 2 (per-step PRM reward) alignment ------------------- + # For label == "math_per_step_prm" we additionally locate the + # boundary token of each "Step N:" inside the response so the + # step_scores tensor can be scattered to per-token positions + # downstream (instead of being collapsed to one scalar). + # + # Alignment is *self-contained*: we re-tokenize ``response`` with + # the actor's (== PRM's, in URSA family) tokenizer and use the + # offset mapping to reverse-find token indices for each "Step N:" + # / "†Answer:" pattern. 
Indices are relative to the response + # start so they line up with the action_mask axis of final_reward + # in compute_reward. + # + # If the alignment fails (n_steps_prm != n_boundaries) we + # *bypass* per-step mode for that trajectory: emit empty tensors + # so compute_reward falls back to the trajectory-scalar path, + # and bump the alignment_failed metric for monitoring. + if label == "math_per_step_prm" and step_scores.numel() > 0: + boundaries, matched_patterns = find_step_boundaries_in_response_tokens( + self, response, question_text=question + ) + aligned = (len(boundaries) == int(step_scores.numel())) + if aligned: + traj_step_rewards = step_scores.detach().to(torch.float32).cpu() + traj_step_tokens = torch.tensor(boundaries, dtype=torch.long) + n_aligned = len(boundaries) + else: + # Alignment failed: emit empties; downstream falls back to + # trajectory-scalar mode for this row. + traj_step_rewards = torch.empty(0, dtype=torch.float32) + traj_step_tokens = torch.empty(0, dtype=torch.long) + n_aligned = 0 + batch_step_rewards.append(traj_step_rewards) + batch_step_token_indices.append(traj_step_tokens) + batch_metrics["alignment_failed"].append(0.0 if aligned else 1.0) + batch_metrics["n_aligned_steps"].append(float(n_aligned)) + else: + # No per-step request for this row; emit empty placeholders to + # keep the per-traj list aligned with batch_rewards. + batch_step_rewards.append(torch.empty(0, dtype=torch.float32)) + batch_step_token_indices.append(torch.empty(0, dtype=torch.long)) + batch_metrics["alignment_failed"].append(0.0) + batch_metrics["n_aligned_steps"].append(0.0) + + # ---- trajectory-scalar reward (PSGRPO / aggregate path) ---------- + if label == "math_psgrpo": + sequence_reward = psgrpo_metrics["final_reward"] + elif label == "math_per_step_prm": + # In per-step mode the trajectory-scalar field is still used + # by GroupNorm baseline — use outcome (clean signal) instead + # of aggregated_score (which would double-count step rewards). 
+ sequence_reward = float(psgrpo_metrics["outcome_correct"]) + else: + sequence_reward = aggregated_score + batch_rewards.append(sequence_reward) + batch_metrics["model_reward"].append(aggregated_score) + batch_metrics["step_score_min"].append(float(torch.min(step_scores).item()) if step_scores.numel() else 0.0) + batch_metrics["step_score_mean"].append(float(torch.mean(step_scores).item()) if step_scores.numel() else 0.0) + batch_metrics["step_score_last"].append(float(step_scores[-1].item()) if step_scores.numel() else 0.0) + batch_metrics["step_count"].append(float(step_scores.numel())) + # Diagnostics: outcome_correct and answer-extraction signals are + # always meaningful (they're computed by _evaluate_answer_alignment + # which is independent of PSGRPO drop-moment); only the + # drop-moment-specific fields (max_relative_drop, has_drop_moment, + # final_reward) zero out for non-PSGRPO labels. + _UNIVERSAL_METRICS = { + "outcome_correct", + "answer_tag_present", + "answer_extraction_failed", + "used_answer_fallback", + "reference_supported", + "used_mathruler", + } + for key in ( + "outcome_correct", + "max_relative_drop", + "has_drop_moment", + "final_reward", + "answer_tag_present", + "answer_extraction_failed", + "used_answer_fallback", + "reference_supported", + "used_mathruler", + ): + if label == "math_psgrpo" or key in _UNIVERSAL_METRICS: + batch_metrics[key].append(psgrpo_metrics[key]) + else: + # PSGRPO-specific (max_relative_drop, has_drop_moment, + # final_reward) — only meaningful when label is + # "math_psgrpo"; zero out for other labels to preserve + # historical metric tensor shape & semantics. 
+ batch_metrics[key].append(0.0) + + score_tensor = torch.tensor(batch_rewards, dtype=torch.float32, device=device) + if references is None and labels is None and not return_dict: + return score_tensor + + metrics_tensor = { + key: torch.tensor(values, dtype=torch.float32, device=device) + for key, values in batch_metrics.items() + } + out = {"score": score_tensor, **metrics_tensor} + + # Only attach per-step fields if any trajectory had non-empty step data. + # Stored as Python lists of CPU tensors (variable length per traj) to + # avoid forcing every caller to handle padded tensors. + any_per_step = any(t.numel() > 0 for t in batch_step_rewards) + if any_per_step: + out["step_rewards"] = batch_step_rewards + out["step_token_indices"] = batch_step_token_indices + + return out diff --git a/examples/math_prm/reward_models_utils.py b/examples/math_prm/reward_models_utils.py new file mode 100644 index 00000000..d6fd4ca1 --- /dev/null +++ b/examples/math_prm/reward_models_utils.py @@ -0,0 +1,373 @@ +"""Math-only reward loading and aggregation utilities for URSA-MATH Stage 3.""" + +from __future__ import annotations + +import json +import re +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import torch + +from lightrft.models.monkey_patch.hf_generate_patch import apply_monkey_patch_to_generation_mixin +from lightrft.utils import get_current_device + +from reward_models import MathPRMReward + + +class RewardModelType(str, Enum): + """Supported reward model types for the math_prm example.""" + + MATH_PRM = "math_prm" + + +@dataclass +class RewardModelConfig: + """Configuration for one reward model instance.""" + + rtype: RewardModelType + path: str + use_engine: bool = False + + +RawRewardInput = Union[str, Dict[str, str], List[Dict[str, str]], None] +_BUILDERS: Dict[RewardModelType, Callable[..., Tuple[Any, Any]]] = {} + + +def register_builder(rtype: RewardModelType) -> Callable: 
+ def deco(fn: Callable) -> Callable: + _BUILDERS[rtype] = fn + return fn + + return deco + + +def _guess_rtype_from_path(path: str) -> RewardModelType: + lowered = path.lower() + if any(keyword in lowered for keyword in ("ursa", "prm", "math-rm", "step-reward", "process-reward")): + return RewardModelType.MATH_PRM + return RewardModelType.MATH_PRM + + +def parse_reward_pretrain( + raw: RawRewardInput, + *, + global_use_engine: bool, +) -> Tuple[List[RewardModelConfig], Dict[str, int]]: + """Parse reward model config while keeping the old flexible input shapes.""" + + if raw is None: + return [], {} + + pair_list: List[Tuple[str, str, Optional[bool]]] = [] + if isinstance(raw, str): + text = raw.strip() + if not text: + return [], {} + if text.startswith("{") and text.endswith("}"): + obj = json.loads(text) + pair_list = [(key, value, None) for key, value in obj.items()] + else: + for segment in re.split(r"\s*,\s*", text): + if not segment: + continue + if ":" in segment: + key, value = segment.split(":", 1) + pair_list.append((key.strip(), value.strip(), None)) + else: + pair_list.append(("?", segment.strip(), None)) + elif isinstance(raw, dict): + pair_list = [(key, value, None) for key, value in raw.items()] + elif isinstance(raw, list): + for item in raw: + pair_list.append((item["type"], item["path"], item.get("engine"))) + else: + raise TypeError("Unsupported --reward_pretrain format") + + cfgs: List[RewardModelConfig] = [] + for key, path, flag in pair_list: + use_engine = global_use_engine + if "?engine=" in path: + path, qs = path.split("?engine=", 1) + use_engine = qs.lower() in {"1", "true", "yes"} + if flag is not None: + use_engine = bool(flag) + rtype = _guess_rtype_from_path(path) if key == "?" 
else RewardModelType(key) + cfgs.append(RewardModelConfig(rtype=rtype, path=path, use_engine=use_engine)) + + label_map = {cfg.rtype.value: index for index, cfg in enumerate(cfgs)} + return cfgs, label_map + + +def _load_ursa_prm_model(pretrain_path: str, device: torch.device | int) -> Tuple[Any, Any]: + from ursa_model import UrsaForTokenClassification, UrsaProcessor + + processor = UrsaProcessor.from_pretrained(pretrain_path) + model = UrsaForTokenClassification.from_pretrained( + pretrain_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, + ) + model = model.to(device) + model.eval() + return model, processor + + +def _load_engine(pretrain_path: str, device: torch.device | int) -> Tuple[Any, Any]: + raise RuntimeError( + "The math_prm example no longer supports external reward-model engines. " + "URSA-RM is loaded through the local HF path instead." + ) + + +def _shared_base_key(cfg: RewardModelConfig) -> Optional[Tuple[str, str]]: + if cfg.rtype != RewardModelType.MATH_PRM: + return None + return (cfg.path, cfg.rtype.value) + + +def _load_shared_base(cfg: RewardModelConfig) -> Tuple[Any, Any]: + return _load_ursa_prm_model(cfg.path, get_current_device()) + + +@register_builder(RewardModelType.MATH_PRM) +def build_math_prm( + cfg: RewardModelConfig, + strategy: Any, + base: Optional[Tuple[Any, Any]] = None, +) -> Tuple[MathPRMReward, Any]: + if cfg.use_engine: + strategy.print( + "[build_math_prm] Engine mode is not supported for URSA-RM. " + "Falling back to direct HF loading." 
+ ) + + if base is None: + base_model, processor = _load_ursa_prm_model(cfg.path, get_current_device()) + else: + base_model, processor = base + + reward_model = MathPRMReward( + base_model=base_model, + processor=processor, + aggregation="min", + ) + reward_model.eval() + return reward_model, processor.tokenizer + + +def load_reward_models( + raw_reward_pretrain: RawRewardInput, + strategy: Any, + use_engine: bool = False, +) -> Tuple[List[Any], List[Any], Dict[str, int]]: + apply_monkey_patch_to_generation_mixin() + cfgs, label_map = parse_reward_pretrain(raw_reward_pretrain, global_use_engine=use_engine) + + reward_models: List[Any] = [] + reward_tokenizers: List[Any] = [] + shared_bases: Dict[Tuple[str, str], Tuple[Any, Any]] = {} + + for cfg in cfgs: + cache_key = _shared_base_key(cfg) + if cache_key is not None and cache_key not in shared_bases: + shared_bases[cache_key] = _load_shared_base(cfg) + strategy.print(f"Init reward model base {cfg.path} (engine={cfg.use_engine}, type={cfg.rtype})") + + for cfg in cfgs: + if cfg.rtype not in _BUILDERS: + raise RuntimeError(f"No builder registered for {cfg.rtype}") + strategy.print(f"Loading {cfg.rtype} from {cfg.path} (engine={cfg.use_engine})") + with strategy.init_model_context() as _: + reward_model, tokenizer = _BUILDERS[cfg.rtype]( + cfg, + strategy, + base=shared_bases.get(_shared_base_key(cfg)), + ) + reward_models.append(reward_model) + reward_tokenizers.append(tokenizer) + strategy.print(f"Loaded {cfg.rtype}") + + return reward_models, reward_tokenizers, label_map + + +def math_prm_format_reward_fn(sol: str) -> float: + """Diagnostic-only check for the required Stage 3 ``Step N`` / ``†Answer`` format.""" + if not isinstance(sol, str): + return 0.0 + step_matches = re.findall(r"(?m)^Step\s+\d+\s*:\s*\S", sol) + answer_matches = re.findall(r"(?m)^†Answer:\s*\S", sol) + non_empty_lines = [line.strip() for line in sol.splitlines() if line.strip()] + if not step_matches or len(answer_matches) != 1 or not 
non_empty_lines: + return 0.0 + return 1.0 if non_empty_lines[-1].startswith("†Answer:") else 0.0 + + +def format_reward_fn(sol: str) -> float: + """Compatibility alias kept for older callers inside this example directory.""" + return math_prm_format_reward_fn(sol) + + +def rule_reward_fn(sol: str, gt: str) -> float: + """Rule-only baseline using the same controlled final-answer extraction as PS-GRPO.""" + if not gt: + return 0.0 + answer_eval = MathPRMReward._evaluate_answer_alignment(sol, gt) + return 1.0 if answer_eval["outcome_correct"] else 0.0 + + +RECIPE: Dict[str, List[Tuple[str, Optional[str], float]]] = { + "math_prm": [("model", "math_prm", 1.0)], + "math_psgrpo": [("model", "math_prm", 1.0)], + "math_per_step_prm": [("model", "math_prm", 1.0)], + "math_prm_combined": [("model", "math_prm", 1.0), ("rule", None, 0.5)], + "math_rule": [("rule", None, 1.0)], +} + + +NO_GLOBAL_FORMAT_REWARD_LABELS = { + "math_prm", + "math_psgrpo", + "math_per_step_prm", + "math_prm_combined", + "math_rule", +} + + +def mix_rewards( + labels: Sequence[str], + model_scores: torch.Tensor, + label_map: Dict[str, int], + solution_strs: Sequence[str], + refs: Sequence[str], + model_reward_metrics_list: Optional[List[Optional[Dict[str, torch.Tensor]]]] = None, +) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + if model_scores.numel() > 0: + device = model_scores.device + elif model_reward_metrics_list: + first_metric_tensor = next( + ( + tensor + for metrics in model_reward_metrics_list + if metrics + for tensor in metrics.values() + if isinstance(tensor, torch.Tensor) + ), + None, + ) + device = first_metric_tensor.device if first_metric_tensor is not None else torch.device("cpu") + else: + device = torch.device("cpu") + + n_model = int(model_scores.shape[0]) + batch_size = len(labels) + if model_scores.ndim != 2: + raise ValueError(f"model_scores must have shape (n_model, B), got {tuple(model_scores.shape)!r}") + if model_scores.shape[1] != batch_size: + raise 
AssertionError("model_scores second dimension must equal batch size") + + final_reward = torch.zeros(batch_size, dtype=torch.float32, device=device) + metrics_dict: Dict[str, torch.Tensor] = { + "format_reward": torch.zeros(batch_size, dtype=torch.float32, device=device), + "accuracy_reward": torch.zeros(batch_size, dtype=torch.float32, device=device), + "model_reward": torch.zeros(batch_size, dtype=torch.float32, device=device), + "rule_reward": torch.zeros(batch_size, dtype=torch.float32, device=device), + "outcome_correct": torch.zeros(batch_size, dtype=torch.float32, device=device), + "max_relative_drop": torch.zeros(batch_size, dtype=torch.float32, device=device), + "has_drop_moment": torch.zeros(batch_size, dtype=torch.float32, device=device), + "final_reward": torch.zeros(batch_size, dtype=torch.float32, device=device), + } + + def ensure_metric_key(metric_name: str) -> None: + if metric_name not in metrics_dict: + metrics_dict[metric_name] = torch.zeros(batch_size, dtype=torch.float32, device=device) + + def get_model_reward(key: str, index: int) -> float: + if key not in label_map: + print(f"Model reward <{key}> not loaded, using 1.0 as fallback") + return 1.0 + model_index = label_map[key] + if model_index >= n_model: + print(f"Model reward <{key}> index {model_index} out of bounds, using 1.0 as fallback") + return 1.0 + return float(model_scores[model_index, index].item()) + + def get_model_metrics(key: str, index: int) -> Dict[str, float]: + if not model_reward_metrics_list or key not in label_map: + return {} + model_index = label_map[key] + if model_index >= len(model_reward_metrics_list): + return {} + metrics = model_reward_metrics_list[model_index] + if not metrics: + return {} + + sample_metrics: Dict[str, float] = {} + for metric_name, tensor_value in metrics.items(): + if not isinstance(tensor_value, torch.Tensor): + continue + flat_tensor = tensor_value.reshape(-1) + if flat_tensor.numel() <= index: + continue + sample_metrics[metric_name] = 
float(flat_tensor[index].item()) + return sample_metrics + + for index, label in enumerate(labels): + solution = solution_strs[index] + reference = refs[index] if index < len(refs) else "" + format_metric = math_prm_format_reward_fn(solution) + metrics_dict["format_reward"][index] = format_metric + reward_value = 0.0 if label in NO_GLOBAL_FORMAT_REWARD_LABELS else format_metric + + recipe = RECIPE.get(label) + if recipe is None: + print(f"label <{label}> not registered in RECIPE, returning 0.0 reward") + recipe = [] + + for reward_type, key, weight in recipe: + if reward_type == "model": + model_reward = weight * get_model_reward(key, index) + reward_value += model_reward + metrics_dict["model_reward"][index] += model_reward + for metric_name, metric_value in get_model_metrics(key, index).items(): + ensure_metric_key(metric_name) + if metric_name == "final_reward": + continue + metrics_dict[metric_name][index] = metric_value + elif reward_type == "rule": + rule_reward = weight * rule_reward_fn(solution, reference) + reward_value += rule_reward + metrics_dict["rule_reward"][index] += rule_reward + metrics_dict["accuracy_reward"][index] = rule_reward + else: + print(f"Unknown component type {reward_type}, ignoring") + + final_reward[index] = reward_value + metrics_dict["final_reward"][index] = reward_value + + return final_reward, metrics_dict + + +def reward_fn( + model_reward_list: List[torch.Tensor], + model_reward_metrics_list: Optional[List[Optional[Dict[str, torch.Tensor]]]], + labels: Sequence[str], + queries: Sequence[str], + refs: Sequence[str], + label_map: Dict[str, int], +) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + if model_reward_list: + model_scores = torch.stack(model_reward_list) + else: + model_scores = torch.zeros((0, len(labels)), dtype=torch.float32, device="cpu") + + return mix_rewards( + labels=labels, + model_scores=model_scores, + label_map=label_map, + solution_strs=queries, + refs=refs, + 
model_reward_metrics_list=model_reward_metrics_list, + ) diff --git a/examples/math_prm/rollout_eos_patch.py b/examples/math_prm/rollout_eos_patch.py new file mode 100644 index 00000000..e46b1d5f --- /dev/null +++ b/examples/math_prm/rollout_eos_patch.py @@ -0,0 +1,254 @@ +""" +Math PRM rollout EOS patch — keeps the fix local to examples/math_prm/. + +Background +---------- +On 8-GPU FSDP rollouts, historical attempts to terminate URSA generation +through a ``LogitsProcessor`` that nudges the eos logit up were unreliable: +logs showed the processor firing hundreds of times (``forced_eos_rows=291`` +per batch-of-4) while every sample still ran the full ``max_new_tokens`` +(``mean_length=511.8`` / 512). The "logits nudge → sampled token → +``EosTokenCriteria``" handshake does not close under FSDP's numerical +regime, and even on single card it is only probabilistic. + +Fix +--- +Install a ``StoppingCriteria`` directly on the rollout actor's underlying +HF model. HF's sample loop calls ``stopping_criteria(input_ids, scores)`` +*after* each new token is appended, and ANDs the returned mask into +``unfinished_sequences``. When we return True for a row, HF marks it +finished immediately — no sampling, no logit tricks, no numerical edge. + +The criteria also exposes an ``eos_token_id`` attribute so HF's +``has_eos_stopping_criteria`` detection (``utils.py:2735``) treats our +signal as EOS-equivalent and enables the post-EOS pad-fill path at +``utils.py:2835`` for rows we have marked done. + +Shape +----- +This module is self-contained under ``examples/math_prm/`` and is installed +from ``train_colocate.py`` via ``install_math_prm_rollout_eos_patch``. The +install helper wraps ``rollout_actor.model.generate`` so every generate +call gets a fresh criteria injected. 
Since math_prm's training loop only +ever runs math_prm batches, unconditional injection is correct — on +non-math content the criteria simply never sees ``†Answer:`` and its +``done`` mask stays all-False (pure runtime no-op). + +No changes to ``lightrft/`` are required. +""" + +from __future__ import annotations + +import functools +import time +from collections import defaultdict +from typing import Any, Dict, List, Optional, Union + +import torch +from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList + +from math_prm_output import MATH_PRM_ANSWER_MARKER, should_stop_math_prm_response_text + + +class StructuredAnswerStoppingCriteria(StoppingCriteria): + """ + Terminate URSA-math_prm rollout generation when a fully-formed + ``†Answer: `` line has been emitted. + + Key properties: + + - Exposes ``eos_token_id`` as an attribute so HF's internal + ``has_eos_stopping_criteria`` detection (``utils.py:2735``) treats this + criteria as EOS-equivalent, which enables the post-EOS pad-fill path + (``utils.py:2835``). Without that attr, rows we mark done would keep + getting non-pad filler tokens written into their slots, and + ``process_sequences`` — which derives attention-mask from + ``seq.ne(eos_token_id) & seq.ne(pad_token_id)`` — would still count + those positions as real content. + - Checks only every ``check_interval`` tokens to amortise CPU + ``batch_decode`` cost (matching the existing LogitsProcessor cadence). + - Done bits are *sticky*: once set, the criteria re-asserts them on + every subsequent call, including between gated checks. This is + critical — HF's sample loop ANDs our return into + ``unfinished_sequences`` (``utils.py:2842``), so if we ever returned + False for a row we had previously stopped, HF would un-stop it. 
+ """ + + def __init__(self, tokenizer, prompt_length: int, eos_token_id: int): + self.tokenizer = tokenizer + self.prompt_length = int(prompt_length) + self.eos_token_id = int(eos_token_id) + self.check_interval = 4 + self.marker_scan_max_tokens = 192 + self.answer_tail_max_tokens = 128 + self.answer_marker_token_ids = tuple( + int(token_id) for token_id in tokenizer.encode(MATH_PRM_ANSWER_MARKER, add_special_tokens=False) + ) + self._marker_seen: Optional[List[bool]] = None + self._done: Optional[torch.Tensor] = None + self._stats: Dict[str, float] = defaultdict(float) + + def _ensure_state(self, batch_size: int, device) -> None: + if self._marker_seen is None or len(self._marker_seen) != batch_size: + self._marker_seen = [False] * batch_size + if self._done is None or self._done.numel() != batch_size: + self._done = torch.zeros(batch_size, dtype=torch.bool, device=device) + elif self._done.device != device: + self._done = self._done.to(device) + + def _scan_row_for_answer_marker(self, row_token_ids: torch.Tensor) -> bool: + marker_token_ids = self.answer_marker_token_ids + if not marker_token_ids: + return MATH_PRM_ANSWER_MARKER in self.tokenizer.decode(row_token_ids, skip_special_tokens=False) + + token_ids = row_token_ids.tolist() + marker_len = len(marker_token_ids) + if len(token_ids) < marker_len: + return False + + search_start = max(0, len(token_ids) - self.marker_scan_max_tokens) + token_ids = token_ids[search_start:] + last_start = len(token_ids) - marker_len + 1 + for start_idx in range(max(last_start, 0)): + if tuple(token_ids[start_idx:start_idx + marker_len]) == marker_token_ids: + return True + return False + + def _decode_rows(self, row_token_ids: torch.Tensor) -> List[str]: + decode_t0 = time.time() + texts = self.tokenizer.batch_decode(row_token_ids, skip_special_tokens=False) + self._stats["decode_time_s"] += time.time() - decode_t0 + self._stats["decoded_rows"] += len(texts) + return texts + + def get_debug_stats(self) -> Optional[Dict[str, 
Union[int, float]]]: + if self._stats["calls"] <= 0: + return None + return { + "calls": int(self._stats["calls"]), + "gated_checks": int(self._stats["gated_checks"]), + "marker_scan_rows": int(self._stats["marker_scan_rows"]), + "marker_hits": int(self._stats["marker_hits"]), + "answer_tail_rows": int(self._stats["answer_tail_rows"]), + "decoded_rows": int(self._stats["decoded_rows"]), + "stopped_rows": int(self._stats["stopped_rows"]), + "decode_time_s": round(float(self._stats["decode_time_s"]), 4), + } + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: + self._stats["calls"] += 1 + batch_size = input_ids.size(0) + self._ensure_state(batch_size, input_ids.device) + + if input_ids.size(1) <= self.prompt_length: + return self._done.clone() + + generated_length = input_ids.size(1) - self.prompt_length + # Always return the sticky done mask, even on non-gated-check steps, + # so HF cannot flip previously-stopped rows back to unfinished. 
+ if generated_length % self.check_interval != 0: + return self._done.clone() + self._stats["gated_checks"] += 1 + + unresolved_rows = [ + idx for idx in range(batch_size) + if not self._marker_seen[idx] and not bool(self._done[idx].item()) + ] + if unresolved_rows: + scan_start = max(self.prompt_length, input_ids.size(1) - self.marker_scan_max_tokens) + scan_ids = input_ids[unresolved_rows, scan_start:].detach().cpu() + self._stats["marker_scan_rows"] += len(unresolved_rows) + matched_row_indices = [] + matched_scan_ids = [] + for row_idx, row_token_ids in zip(unresolved_rows, scan_ids): + if self._scan_row_for_answer_marker(row_token_ids): + self._marker_seen[row_idx] = True + matched_row_indices.append(row_idx) + matched_scan_ids.append(row_token_ids) + + if matched_row_indices: + self._stats["marker_hits"] += len(matched_row_indices) + matched_scan_ids = torch.stack(matched_scan_ids) + scan_texts = self._decode_rows(matched_scan_ids) + for row_idx, text in zip(matched_row_indices, scan_texts): + if should_stop_math_prm_response_text(text): + self._done[row_idx] = True + + marker_rows = [ + idx for idx in range(batch_size) + if self._marker_seen[idx] and not bool(self._done[idx].item()) + ] + if marker_rows: + tail_start = max(self.prompt_length, input_ids.size(1) - self.answer_tail_max_tokens) + tail_ids = input_ids[marker_rows, tail_start:].detach().cpu() + self._stats["answer_tail_rows"] += len(marker_rows) + tail_texts = self._decode_rows(tail_ids) + for row_idx, text in zip(marker_rows, tail_texts): + if should_stop_math_prm_response_text(text): + self._done[row_idx] = True + + self._stats["stopped_rows"] = int(self._done.sum().item()) + return self._done.clone() + + +def install_math_prm_rollout_eos_patch(rollout_actor, tokenizer, eos_token_id: int) -> None: + """ + Wrap ``rollout_actor.model.generate`` so that every generate call gets a + fresh ``StructuredAnswerStoppingCriteria`` injected into its + ``stopping_criteria`` kwarg. 
+ + This is only installed from the math_prm example's ``train_colocate.py`` + on the dedicated rollout actor that is used exclusively for math_prm + batches, so unconditional injection is correct and keeps the patch + self-contained without any reliance on lightrft-side signals. + + For non-math batches the criteria simply never sees ``†Answer:`` in the + decoded tail, so its ``done`` mask stays all-False and the patch is a + no-op at runtime. + + Idempotent: a second install call is a no-op. + """ + model = rollout_actor.model + if getattr(model, "_math_prm_rollout_eos_patch_installed", False): + return + + orig_generate = model.generate + + @functools.wraps(orig_generate) + def patched_generate(*args: Any, **kwargs: Any): + input_ids = kwargs.get("input_ids") + if input_ids is None and args: + input_ids = args[0] + if input_ids is not None and hasattr(input_ids, "size"): + prompt_length = int(input_ids.size(1)) + new_criteria = StructuredAnswerStoppingCriteria( + tokenizer=tokenizer, + prompt_length=prompt_length, + eos_token_id=int(eos_token_id), + ) + existing = kwargs.get("stopping_criteria") + if existing is None: + kwargs["stopping_criteria"] = StoppingCriteriaList([new_criteria]) + else: + # Be conservative — if caller already provided criteria, + # prepend ours rather than dropping theirs. + kwargs["stopping_criteria"] = StoppingCriteriaList([new_criteria, *existing]) + + # HF auto-enables `synced_gpus=True` under FSDP (see + # generation/utils.py:2218), but each rank here runs an independent + # local-HF generate on its own prompt slice: reshard_after_forward + # is False on the rollout actor so there are no per-step + # collectives. 
Leaving synced_gpus on causes the loop to `continue` + # past the input_ids append at utils.py:2838 once this rank's rows + # are all done — combined with URSA's prefill-vs-decode branching + # in modeling_ursa.py:279 (takes prefill when + # `input_ids.shape[1] != 1`), the stale input_ids triggers an + # IndexError in `_merge_input_ids_with_image_features`. Force it + # off so each rank's generate loop exits cleanly when its own + # stopping criteria fire. + kwargs.setdefault("synced_gpus", False) + + return orig_generate(*args, **kwargs) + + model.generate = patched_generate + model._math_prm_rollout_eos_patch_installed = True diff --git a/examples/math_prm/run_grpo_math_prm_ursa_8b.sh b/examples/math_prm/run_grpo_math_prm_ursa_8b.sh new file mode 100755 index 00000000..12fd9178 --- /dev/null +++ b/examples/math_prm/run_grpo_math_prm_ursa_8b.sh @@ -0,0 +1,261 @@ +#!/bin/bash +# Fail fast: a crashed torchrun must propagate its exit code through the +# `2>&1 | tee` pipeline below so multi-node orchestrators / CI see the error. +set -eo pipefail +# +# LightRFT GRPO Training Script - URSA-8B with URSA-8B-RM (Math PRM). +# +# This script trains URSA-8B (a multimodal math VLM built on Qwen2.5-Math) with +# URSA-8B-RM as a Process Reward Model. The reward signal is PS-GRPO over the +# PRM step scores: r in {0, 0.5, 1} based on outcome correctness and whether +# any step-score drop event was observed in the response. +# +# - Actor: URSA-8B (hybrid SAM-B + SigLIP-L vision tower + Qwen2.5-Math) +# - Reward: URSA-8B-RM (process reward model for step-level scoring) +# - Engine: local HF rollout (vLLM/SGLang URSA support is future work) +# - Algorithm: GRPO with PS-GRPO reward via the math_psgrpo label +# + +# Auto-load credentials/paths from .env if present (no-op when missing). +# Useful keys: WANDB_API_KEY, WANDB_PROJECT, HF_TOKEN, PATH_TO_YOUR_BASE_MODEL, +# PATH_TO_URSA_RM, PATH_TO_YOUR_MATH_DATASET, LIGHTRFT_OUTPUT_ROOT. 
+if [ -f "$(dirname "$0")/../../.env" ]; then + set -a; . "$(dirname "$0")/../../.env"; set +a +fi +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" +export PYTHONPATH="${REPO_ROOT}:${PYTHONPATH:-}" +# Alias project-specific WANDB key names to the canonical WANDB_API_KEY so +# the rest of the script (and wandb itself) can use the canonical name. +: "${WANDB_API_KEY:=${LIGHTRFT_WANDB_API_KEY:-${WANDB_TOKEN:-${WANDB_KEY:-}}}}" +export WANDB_API_KEY + +################################################################################ +# Part 1: User Configuration # +# Please update the following paths and settings to match your environment. # +################################################################################ + +# --- Model and Dataset Paths --- +# Each value can be overridden by exporting the env var with the same name +# before invoking this script (e.g. for CI or per-machine paths). The strings +# below are placeholders to make the script self-documenting; a real run must +# either edit them or override via env. +PATH_TO_YOUR_BASE_MODEL="${PATH_TO_YOUR_BASE_MODEL:-/path/to/your/URSA-8B}" +PATH_TO_URSA_RM="${PATH_TO_URSA_RM:-/path/to/your/URSA-RM-8B}" +PATH_TO_YOUR_MATH_DATASET="${PATH_TO_YOUR_MATH_DATASET:-/path/to/your/preprocessed/math_psgrpo.jsonl}" + +# --- Experiment and Logging --- +EXPERIMENT_NAME="${EXPERIMENT_NAME:-lightrft-ursa8b-math-prm}" +LIGHTRFT_OUTPUT_ROOT="${LIGHTRFT_OUTPUT_ROOT:-.}" + +# W&B configuration. An unset/empty WANDB_API_KEY makes WANDB_MODE default to offline (W&B stays enabled). +export WANDB_API_KEY="${WANDB_API_KEY:-YOUR_WANDB_API_KEY}" +WANDB_ORG="${WANDB_ORG:-${WANDB_ENTITY:-}}" +export WANDB_PROJECT="${WANDB_PROJECT:-LightRFT-URSA8B-MathPRM}" + + +################################################################################ +# Part 2: Training Hyperparameters # +# These settings control the training process. Adjust them as needed. 
# +################################################################################ + +# --- GRPO settings --- +N_SAMPLES=8 # Number of samples per prompt for GRPO (must be > 1). +EPISODE=10 # Total number of training episodes. +WARMUP=0.03 # Learning rate warmup ratio. +RBS=128 # Rollout Batch Size. +TBS=128 # Training Batch Size. + +# --- Learning and model settings --- +# K3 estimator (Schulman) at the historical default 0.001. The earlier proposal +# to switch to K2 + 0.005 was justified by KL ~ 11 nats observed on the broken +# run; once the silent log-prob misalignment was fixed (see PR #53), the real +# K3 sits at ~0.04 and the K2/K3/K1 ratios collapse to numerically equivalent +# small values, so the estimator + coefficient change has no remaining +# justification. Keep historical values to minimize the PR's behavior diff. +KL_ESTIMATOR=k3 # Schulman K3 = exp(-r) - 1 + r. Historical default. +KL=0.001 # Historical default. K3 * 0.001 ~= 4e-5 budget on real KL. +KL_TARGET="" # If set (e.g. "0.5"), enables AdaptiveKLController. +# Variant 2 per-step PRM reward mode. Only meaningful when prompts have label +# "math_per_step_prm" (see fast_exp_maker._apply_step_reward_group_norm). Values: +# raw : scatter raw sigmoid step_score (paper Figure ablation; default) +# group_norm : per-step group-relative baseline (GRPO convention) +PER_STEP_REWARD_MODE="${PER_STEP_REWARD_MODE:-raw}" + +LR=1e-6 # Actor learning rate. +PROMPT_MAX_LEN=1024 # Max length of the input prompt. +GENERATE_MAX_LEN=3072 # Max length of the generated response. +MAX_SAMPLES=15360 # Cap on the training subset size. + +# --- Multi-modal settings --- +limit_mm_image_per_prompt=10 + +# --- Evaluation settings --- +# Eval pulls a fixed deterministic held-out subset out of the training manifest +# (URSA Stage 3 protocol). 
+EVAL_STEPS=20 +EVAL_HOLDOUT_SIZE=500 +MAX_EVAL_SAMPLES=500 + + +################################################################################ +# Part 3: Distributed Training Setup # +# Configure settings for multi-GPU and multi-node training. # +################################################################################ + +export NNODES="${NNODES:-1}" +export GPUS_PER_NODE="${GPUS_PER_NODE:-8}" +export NODE_RANK="${NODE_RANK:-0}" +export MASTER_ADDR="${MASTER_ADDR:-localhost}" +export MASTER_PORT="${MASTER_PORT:-20092}" + + +################################################################################ +# Part 4: Execution and Logging # +# This section prepares and launches the training command. # +################################################################################ + +# --- Generate dynamic names and paths --- +# SAVE_MODEL_NAME / WANDB_RUN_NAME are env-overridable so a resumed run can target +# the existing ckpt directory instead of creating a fresh timestamped one. +current_time=$(date +"%Y%m%d_%H%M%S") +SAVE_MODEL_NAME="${SAVE_MODEL_NAME:-${EXPERIMENT_NAME}-ep${EPISODE}-kl${KL}-lr${LR}-${current_time}}" +WANDB_RUN_NAME="${WANDB_RUN_NAME:-${EXPERIMENT_NAME}-${current_time}}" +SAVE_DIR="${LIGHTRFT_OUTPUT_ROOT}/results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" +LOG_DIR="${LIGHTRFT_OUTPUT_ROOT}/rft_logs/${EXPERIMENT_NAME}" +export WANDB_DIR="${WANDB_DIR:-${LIGHTRFT_OUTPUT_ROOT}/wandb}" + +mkdir -p "${SAVE_DIR}" +mkdir -p "${LOG_DIR}" +mkdir -p "${WANDB_DIR}" +TRAIN_LOG="${LOG_DIR}/node${NODE_RANK}_${current_time}.log" + +export TORCH_NCCL_AVOID_RECORD_STREAMS=1 +export NCCL_DEBUG="WARN" +if [[ -n "${WANDB_API_KEY}" && "${WANDB_API_KEY}" != "YOUR_WANDB_API_KEY" ]]; then + export WANDB_MODE="${WANDB_MODE:-online}" +else + export WANDB_MODE="${WANDB_MODE:-offline}" +fi + +# Optional adaptive-KL flag block (only added when KL_TARGET is non-empty). 
+KL_TARGET_ARGS=() +if [[ -n "${KL_TARGET}" ]]; then + KL_TARGET_ARGS=(--kl_target "${KL_TARGET}") +fi + +# Optional resume-from-checkpoint flag. Set LOAD_CHECKPOINT=1 in the environment +# to continue training from ${ckpt_path}/_actor (and _critic if applicable). +RESUME_ARGS=() +if [[ "${LOAD_CHECKPOINT:-0}" == "1" ]]; then + RESUME_ARGS=(--load_checkpoint) +fi + +WANDB_ORG_ARGS=() +if [[ -n "${WANDB_ORG}" ]]; then + WANDB_ORG_ARGS=(--wandb_org "${WANDB_ORG}") +fi + +# Math PRM uses a single URSA-RM checkpoint registered under the math_prm label. +REWARD_PRETRAIN_PATHS="{\"math_prm\":\"${PATH_TO_URSA_RM}\"}" + +# URSA enforces a fixed structured response format for the PRM scorer. +SYSTEM_PROMPT='A conversation between the User and Assistant. The User asks a question that may require mathematical or visual reasoning, and the Assistant solves it step by step. Each step MUST begin with "Step N:" (e.g. "Step 1:", "Step 2:") on its own line. After all steps, output exactly one final answer line prefixed with "†Answer:" (e.g. "†Answer: 42"). Stop immediately after the "†Answer:" line and do not output any extra text, repeated answer markers, or additional steps.' + + +################################################################################ +# Part 5: Main Training Command # +################################################################################ + +# Use the conda env's torchrun explicitly: under bash -c, `conda activate` does +# not propagate to subprocesses, so a plain `torchrun` may resolve to a system +# python that lacks transformers/flash_attn etc. Override with TORCHRUN= if you +# launch from a different env. 
+TORCHRUN="${TORCHRUN:-torchrun}" +"${TORCHRUN}" \ + --nnodes $NNODES \ + --nproc-per-node $GPUS_PER_NODE \ + --node_rank $NODE_RANK \ + --master-port $MASTER_PORT \ + --master-addr $MASTER_ADDR \ + "${REPO_ROOT}/examples/math_prm/train_colocate.py" \ + --pretrain "${PATH_TO_YOUR_BASE_MODEL}" \ + --reward_pretrain "${REWARD_PRETRAIN_PATHS}" \ + --prompt_data "${PATH_TO_YOUR_MATH_DATASET}" \ + --max_samples ${MAX_SAMPLES} \ + --input_key "prompt" \ + --images_key "images" \ + --label_key "label" \ + --apply_chat_template \ + --system_prompt "${SYSTEM_PROMPT}" \ + --save_path "${SAVE_DIR}" \ + --ckpt_path "${SAVE_DIR}" \ + --save_steps 20 \ + --max_ckpt_num 2 \ + --save_trajectories \ + --num_trajectories_to_save 16 \ + --print_replay_buffer_stats \ + --fsdp \ + --bf16 \ + --flash_attn \ + --gradient_checkpointing \ + --zero_stage 3 \ + --adam_offload \ + --freeze_prefix \ + --l2 1.0e-2 \ + --mixed_mm_data \ + --limit_mm_image_per_prompt $limit_mm_image_per_prompt \ + --loss_agg_mode "seq-mean-token-mean" \ + --advantage_estimator "group_norm" \ + --max_epochs 1 \ + --num_episodes ${EPISODE} \ + --lr_warmup_ratio ${WARMUP} \ + --n_samples_per_prompt $N_SAMPLES \ + --train_batch_size ${TBS} \ + --rollout_batch_size ${RBS} \ + --prompt_max_len $PROMPT_MAX_LEN \ + --generate_max_len $GENERATE_MAX_LEN \ + --actor_learning_rate $LR \ + --use_kl_loss \ + --init_kl_coef $KL \ + --kl_estimator "${KL_ESTIMATOR}" \ + --per_step_reward_mode "${PER_STEP_REWARD_MODE}" \ + "${KL_TARGET_ARGS[@]}" \ + "${RESUME_ARGS[@]}" \ + --engine_type "hf" \ + --engine_mem_util 0.6 \ + --local_hf_generate_max_batch_size 4 \ + --local_hf_max_new_tokens 512 \ + --hf_separate_rollout_actor \ + --hf_separate_rollout_keep_on_gpu \ + --enable_engine_sleep \ + --eval_steps ${EVAL_STEPS} \ + --eval_holdout_size ${EVAL_HOLDOUT_SIZE} \ + --max_eval_samples ${MAX_EVAL_SAMPLES} \ + --use_wandb "true" \ + "${WANDB_ORG_ARGS[@]}" \ + --wandb_project "${WANDB_PROJECT}" \ + --wandb_run_name "${WANDB_RUN_NAME}" \ + 2>&1 | tee 
"${TRAIN_LOG}" + + +################################################################################ +# Usage Instructions # +# # +# Step 1: Prepare the URSA-8B actor and URSA-8B-RM reward model checkpoints. # +# Both are public on Hugging Face under the URSA-MATH project. Set # +# PATH_TO_YOUR_BASE_MODEL and PATH_TO_URSA_RM to the local directories. # +# # +# Step 2: Preprocess the math PRM dataset. # +# `python examples/math_prm/tools/prepare_ursa_stage3_manifest.py` # +# produces a JSONL manifest with fields {prompt, images, reference, label} # +# where label="math_psgrpo" enables the PS-GRPO reward path. # +# # +# Step 3: Configure the script. # +# Edit "Part 1: User Configuration" at the top of this file. Set the paths # +# to your URSA-8B actor, URSA-8B-RM reward model, and preprocessed manifest. # +# # +# Step 4: Run the training script. # +# `bash examples/math_prm/run_grpo_math_prm_ursa_8b.sh` # +# # +################################################################################ diff --git a/examples/math_prm/run_grpo_math_prm_ursa_8b_variant2.sh b/examples/math_prm/run_grpo_math_prm_ursa_8b_variant2.sh new file mode 100755 index 00000000..319d110a --- /dev/null +++ b/examples/math_prm/run_grpo_math_prm_ursa_8b_variant2.sh @@ -0,0 +1,303 @@ +#!/bin/bash +# Fail fast: a crashed torchrun must propagate its exit code through the +# `2>&1 | tee` pipeline below so multi-node orchestrators / CI see the error. +set -eo pipefail +# +# LightRFT GRPO Training Script — URSA-8B with URSA-8B-RM, strict URSA-paper +# Eq.9 (variant 2) advantage. +# +# Differs from run_grpo_math_prm_ursa_8b.sh ONLY in --advantage_estimator: +# ursa_variant2 — strict paper Eq.9 form, computed in +# examples/math_prm/ursa_variant2.py: +# A_t^i = r_{s,t}^i · GroupNorm_G(r̄_s^i) +# + GroupNorm_G(r_o^i) +# A_t broadcast to every token in step t's span. +# No cumulative return. Outcome term retained. 
+# +# Auto-swaps PATH_TO_YOUR_MATH_DATASET to the .per_step_prm.jsonl sibling +# (label="math_per_step_prm") because variant 2 needs per-step labels. +# +# - Actor: URSA-8B (hybrid SAM-B + SigLIP-L vision tower + Qwen2.5-Math) +# - Reward: URSA-8B-RM (process reward model for step-level scoring) +# - Engine: local HF rollout (vLLM/SGLang URSA support is future work) +# - Algorithm: GRPO with strict URSA paper Eq.9 advantage via the +# math_per_step_prm label (see examples/math_prm/ursa_variant2.py) +# + +# Auto-load credentials/paths from .env if present (no-op when missing). +# Useful keys: WANDB_API_KEY, WANDB_PROJECT, HF_TOKEN, PATH_TO_YOUR_BASE_MODEL, +# PATH_TO_URSA_RM, PATH_TO_YOUR_MATH_DATASET, LIGHTRFT_OUTPUT_ROOT. +if [ -f "$(dirname "$0")/../../.env" ]; then + set -a; . "$(dirname "$0")/../../.env"; set +a +fi +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" +export PYTHONPATH="${REPO_ROOT}:${PYTHONPATH:-}" +# Alias project-specific WANDB key names to the canonical WANDB_API_KEY so +# the rest of the script (and wandb itself) can use the canonical name. +: "${WANDB_API_KEY:=${LIGHTRFT_WANDB_API_KEY:-${WANDB_TOKEN:-${WANDB_KEY:-}}}}" +export WANDB_API_KEY + +################################################################################ +# Part 1: User Configuration # +# Please update the following paths and settings to match your environment. # +################################################################################ + +# --- Model and Dataset Paths --- +# Each value can be overridden by exporting the env var with the same name +# before invoking this script (e.g. for CI or per-machine paths). The strings +# below are placeholders to make the script self-documenting; a real run must +# either edit them or override via env. 
+PATH_TO_YOUR_BASE_MODEL="${PATH_TO_YOUR_BASE_MODEL:-/path/to/your/URSA-8B}" +PATH_TO_URSA_RM="${PATH_TO_URSA_RM:-/path/to/your/URSA-RM-8B}" +# variant 2 NEEDS rows labeled "math_per_step_prm". The PS-GRPO dataset has +# label="math_psgrpo" everywhere — running variant 2 on it would silently +# emit zero step_rewards. If the caller still points PATH_TO_YOUR_MATH_DATASET +# at a psgrpo .jsonl (legacy default), we auto-swap to its sed-relabeled +# sibling. Build that sibling once with the one-liner documented in +# README.md §6 "Strict Paper Eq.9 — variant 2 path": +# sed 's/"label":[ ]*"math_psgrpo"/"label": "math_per_step_prm"/g' \ +# /path/to/math_psgrpo.jsonl > /path/to/math_psgrpo.per_step_prm.jsonl +# If the caller wants a custom path, set PATH_TO_YOUR_MATH_DATASET_VARIANT2. +if [ -n "${PATH_TO_YOUR_MATH_DATASET_VARIANT2:-}" ]; then + PATH_TO_YOUR_MATH_DATASET="${PATH_TO_YOUR_MATH_DATASET_VARIANT2}" +elif [ -n "${PATH_TO_YOUR_MATH_DATASET:-}" ] && [[ "${PATH_TO_YOUR_MATH_DATASET}" != *per_step_prm* ]]; then + PATH_TO_YOUR_MATH_DATASET="${PATH_TO_YOUR_MATH_DATASET%.jsonl}.per_step_prm.jsonl" +fi +PATH_TO_YOUR_MATH_DATASET="${PATH_TO_YOUR_MATH_DATASET:-/path/to/your/preprocessed/math_per_step_prm.jsonl}" +if [ ! -f "${PATH_TO_YOUR_MATH_DATASET}" ]; then + echo "[variant2 launch] FATAL: dataset not found: ${PATH_TO_YOUR_MATH_DATASET}" >&2 + exit 1 +fi +# Sanity: first row must already be relabeled, otherwise variant 2 silently fails. +FIRST_LABEL=$(head -1 "${PATH_TO_YOUR_MATH_DATASET}" | python3 -c 'import sys,json; print(json.loads(sys.stdin.read()).get("label",""))' 2>/dev/null || echo "") +if [ "${FIRST_LABEL}" != "math_per_step_prm" ]; then + echo "[variant2 launch] FATAL: dataset first row label='${FIRST_LABEL}', expected 'math_per_step_prm'." 
>&2 + echo "Pre-process with:" >&2 + echo " sed 's/\"label\":[ ]*\"math_psgrpo\"/\"label\": \"math_per_step_prm\"/g' SRC > DST" >&2 + exit 1 +fi +echo "[variant2 launch] using dataset: ${PATH_TO_YOUR_MATH_DATASET}" + +# --- Experiment and Logging --- +EXPERIMENT_NAME="${EXPERIMENT_NAME:-lightrft-ursa8b-math-prm-variant2}" +LIGHTRFT_OUTPUT_ROOT="${LIGHTRFT_OUTPUT_ROOT:-.}" + +# W&B configuration. An unset/empty WANDB_API_KEY makes WANDB_MODE default to offline (W&B stays enabled). +export WANDB_API_KEY="${WANDB_API_KEY:-YOUR_WANDB_API_KEY}" +WANDB_ORG="${WANDB_ORG:-${WANDB_ENTITY:-}}" +export WANDB_PROJECT="${WANDB_PROJECT:-LightRFT-URSA8B-MathPRM}" + + +################################################################################ +# Part 2: Training Hyperparameters # +# These settings control the training process. Adjust them as needed. # +################################################################################ + +# --- GRPO settings --- +N_SAMPLES=8 # Number of samples per prompt for GRPO (must be > 1). +EPISODE=10 # Total number of training episodes. +WARMUP=0.03 # Learning rate warmup ratio. +RBS=128 # Rollout Batch Size. +TBS=128 # Training Batch Size. + +# --- Learning and model settings --- +# K3 estimator (Schulman) at the historical default 0.001. The earlier proposal +# to switch to K2 + 0.005 was justified by KL ~ 11 nats observed on the broken +# run; once the silent log-prob misalignment was fixed (see PR #53), the real +# K3 sits at ~0.04 and the K2/K3/K1 ratios collapse to numerically equivalent +# small values, so the estimator + coefficient change has no remaining +# justification. Keep historical values to minimize the PR's behavior diff. +KL_ESTIMATOR=k3 # Schulman K3 = exp(-r) - 1 + r. Historical default. +KL=0.001 # Historical default. K3 * 0.001 ~= 4e-5 budget on real KL. +KL_TARGET="" # If set (e.g. "0.5"), enables AdaptiveKLController. +# NOTE: --per_step_reward_mode is intentionally NOT passed to train_colocate.py +# in this launcher. 
It only affects the legacy Math-Shepherd-style per-token +# reward path (fast_exp_maker._apply_step_reward_group_norm); the +# ursa_variant2 advantage estimator does its own GroupNorm inside +# UrsaVariant2Calculator.preprocess_rewards (see examples/math_prm/ursa_variant2.py), +# so passing the flag here is inert and only adds cognitive load. The +# PS-GRPO launcher (run_grpo_math_prm_ursa_8b.sh) still exposes it. + +LR=1e-6 # Actor learning rate. +PROMPT_MAX_LEN=1024 # Max length of the input prompt. +GENERATE_MAX_LEN=3072 # Max length of the generated response. +MAX_SAMPLES=15360 # Cap on the training subset size. + +# --- Multi-modal settings --- +limit_mm_image_per_prompt=10 + +# --- Evaluation settings --- +# Eval pulls a fixed deterministic held-out subset out of the training manifest +# (URSA Stage 3 protocol). +EVAL_STEPS=20 +EVAL_HOLDOUT_SIZE=500 +MAX_EVAL_SAMPLES=500 + + +################################################################################ +# Part 3: Distributed Training Setup # +# Configure settings for multi-GPU and multi-node training. # +################################################################################ + +export NNODES="${NNODES:-1}" +export GPUS_PER_NODE="${GPUS_PER_NODE:-8}" +export NODE_RANK="${NODE_RANK:-0}" +export MASTER_ADDR="${MASTER_ADDR:-localhost}" +export MASTER_PORT="${MASTER_PORT:-20092}" + + +################################################################################ +# Part 4: Execution and Logging # +# This section prepares and launches the training command. # +################################################################################ + +# --- Generate dynamic names and paths --- +# SAVE_MODEL_NAME / WANDB_RUN_NAME are env-overridable so a resumed run can target +# the existing ckpt directory instead of creating a fresh timestamped one. 
+current_time=$(date +"%Y%m%d_%H%M%S") +SAVE_MODEL_NAME="${SAVE_MODEL_NAME:-${EXPERIMENT_NAME}-ep${EPISODE}-kl${KL}-lr${LR}-${current_time}}" +WANDB_RUN_NAME="${WANDB_RUN_NAME:-${EXPERIMENT_NAME}-${current_time}}" +SAVE_DIR="${LIGHTRFT_OUTPUT_ROOT}/results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" +LOG_DIR="${LIGHTRFT_OUTPUT_ROOT}/rft_logs/${EXPERIMENT_NAME}" +export WANDB_DIR="${WANDB_DIR:-${LIGHTRFT_OUTPUT_ROOT}/wandb}" + +mkdir -p "${SAVE_DIR}" +mkdir -p "${LOG_DIR}" +mkdir -p "${WANDB_DIR}" +TRAIN_LOG="${LOG_DIR}/node${NODE_RANK}_${current_time}.log" + +export TORCH_NCCL_AVOID_RECORD_STREAMS=1 +export NCCL_DEBUG="WARN" +if [[ -n "${WANDB_API_KEY}" && "${WANDB_API_KEY}" != "YOUR_WANDB_API_KEY" ]]; then + export WANDB_MODE="${WANDB_MODE:-online}" +else + export WANDB_MODE="${WANDB_MODE:-offline}" +fi + +# Optional adaptive-KL flag block (only added when KL_TARGET is non-empty). +KL_TARGET_ARGS=() +if [[ -n "${KL_TARGET}" ]]; then + KL_TARGET_ARGS=(--kl_target "${KL_TARGET}") +fi + +# Optional resume-from-checkpoint flag. Set LOAD_CHECKPOINT=1 in the environment +# to continue training from ${ckpt_path}/_actor (and _critic if applicable). +RESUME_ARGS=() +if [[ "${LOAD_CHECKPOINT:-0}" == "1" ]]; then + RESUME_ARGS=(--load_checkpoint) +fi + +WANDB_ORG_ARGS=() +if [[ -n "${WANDB_ORG}" ]]; then + WANDB_ORG_ARGS=(--wandb_org "${WANDB_ORG}") +fi + +# Math PRM uses a single URSA-RM checkpoint registered under the math_prm label. +REWARD_PRETRAIN_PATHS="{\"math_prm\":\"${PATH_TO_URSA_RM}\"}" + +# URSA enforces a fixed structured response format for the PRM scorer. +SYSTEM_PROMPT='A conversation between the User and Assistant. The User asks a question that may require mathematical or visual reasoning, and the Assistant solves it step by step. Each step MUST begin with "Step N:" (e.g. "Step 1:", "Step 2:") on its own line. After all steps, output exactly one final answer line prefixed with "†Answer:" (e.g. "†Answer: 42"). 
Stop immediately after the "†Answer:" line and do not output any extra text, repeated answer markers, or additional steps.' + + +################################################################################ +# Part 5: Main Training Command # +################################################################################ + +# TORCHRUN defaults to the first `torchrun` on PATH. Under bash -c, `conda +# activate` does not propagate to subprocesses, so that default may resolve to a +# system python that lacks transformers/flash_attn etc. Set TORCHRUN to your +# env's torchrun binary to pin it explicitly. +TORCHRUN="${TORCHRUN:-torchrun}" +"${TORCHRUN}" \ + --nnodes $NNODES \ + --nproc-per-node $GPUS_PER_NODE \ + --node_rank $NODE_RANK \ + --master-port $MASTER_PORT \ + --master-addr $MASTER_ADDR \ + "${REPO_ROOT}/examples/math_prm/train_colocate.py" \ + --pretrain "${PATH_TO_YOUR_BASE_MODEL}" \ + --reward_pretrain "${REWARD_PRETRAIN_PATHS}" \ + --prompt_data "${PATH_TO_YOUR_MATH_DATASET}" \ + --max_samples ${MAX_SAMPLES} \ + --input_key "prompt" \ + --images_key "images" \ + --label_key "label" \ + --apply_chat_template \ + --system_prompt "${SYSTEM_PROMPT}" \ + --save_path "${SAVE_DIR}" \ + --ckpt_path "${SAVE_DIR}" \ + --save_steps 20 \ + --max_ckpt_num 2 \ + --save_trajectories \ + --num_trajectories_to_save 16 \ + --print_replay_buffer_stats \ + --fsdp \ + --bf16 \ + --flash_attn \ + --gradient_checkpointing \ + --zero_stage 3 \ + --adam_offload \ + --freeze_prefix \ + --l2 1.0e-2 \ + --mixed_mm_data \ + --limit_mm_image_per_prompt $limit_mm_image_per_prompt \ + --loss_agg_mode "seq-mean-token-mean" \ + --advantage_estimator "ursa_variant2" \ + --max_epochs 1 \ + --num_episodes ${EPISODE} \ + --lr_warmup_ratio ${WARMUP} \ + --n_samples_per_prompt $N_SAMPLES \ + --train_batch_size ${TBS} \ + --rollout_batch_size ${RBS} \ + --prompt_max_len $PROMPT_MAX_LEN \ + --generate_max_len $GENERATE_MAX_LEN \ + --actor_learning_rate $LR \ + --use_kl_loss \ + --init_kl_coef $KL \ + --kl_estimator 
"${KL_ESTIMATOR}" \ + "${KL_TARGET_ARGS[@]}" \ + "${RESUME_ARGS[@]}" \ + --engine_type "hf" \ + --engine_mem_util 0.6 \ + --local_hf_generate_max_batch_size 4 \ + --local_hf_max_new_tokens 512 \ + --hf_separate_rollout_actor \ + --hf_separate_rollout_keep_on_gpu \ + --enable_engine_sleep \ + --eval_steps ${EVAL_STEPS} \ + --eval_holdout_size ${EVAL_HOLDOUT_SIZE} \ + --max_eval_samples ${MAX_EVAL_SAMPLES} \ + --use_wandb "true" \ + "${WANDB_ORG_ARGS[@]}" \ + --wandb_project "${WANDB_PROJECT}" \ + --wandb_run_name "${WANDB_RUN_NAME}" \ + 2>&1 | tee "${TRAIN_LOG}" + + +################################################################################ +# Usage Instructions # +# # +# Step 1: Prepare the URSA-8B actor and URSA-8B-RM reward model checkpoints. # +# Both are public on Hugging Face under the URSA-MATH project. Set # +# PATH_TO_YOUR_BASE_MODEL and PATH_TO_URSA_RM to the local directories. # +# # +# Step 2: Preprocess the math PRM dataset and relabel it for variant 2. # +# First produce the standard PS-GRPO manifest: # +# python examples/math_prm/tools/prepare_ursa_stage3_manifest.py \ # +# --input-path /your/data/MMathCoT-1M/train.jsonl \ # +# --image-root /your/data/URSA-MATH/images \ # +# --output-path /your/output/math_psgrpo.jsonl # +# Then sed-relabel into a math_per_step_prm sibling (variant 2 requires it): # +# sed 's/"label":[ ]*"math_psgrpo"/"label": "math_per_step_prm"/g' \ # +# /your/output/math_psgrpo.jsonl # +# > /your/output/math_per_step_prm.jsonl # +# # +# Step 3: Configure the script. # +# Edit "Part 1: User Configuration" at the top of this file. Set the paths # +# to your URSA-8B actor, URSA-8B-RM reward model, and preprocessed manifest. # +# # +# Step 4: Run the training script. 
# +# `bash examples/math_prm/run_grpo_math_prm_ursa_8b_variant2.sh` # +# # +################################################################################ diff --git a/examples/math_prm/test_ursa_variant2.py b/examples/math_prm/test_ursa_variant2.py new file mode 100644 index 00000000..6fccb154 --- /dev/null +++ b/examples/math_prm/test_ursa_variant2.py @@ -0,0 +1,369 @@ +"""Strict-alignment tests for the URSA paper Eq.9 advantage estimator. + +Tests cover the five acceptance criteria AC1–AC5 from the PR plan: + + AC1 numerical equivalence with hand-computed paper Eq.9 (max|Δ|<1e-5) + AC2 outcome reward is NOT bypassed (changing r_o changes advantages) + AC3 group normalization is correct over K=n_samples_per_prompt + AC4 per-step advantage broadcast to the *full* step span (not just the + boundary token); advantage jumps at step boundaries + AC5 with realistic mixed-sign inputs, advantages contain both signs + (regression against the legacy ``per_step_reward_mode=raw`` failure + mode where every advantage was positive) + +Run from repo root: + python3 -m unittest examples.math_prm.test_ursa_variant2 -v +Or directly: + python3 examples/math_prm/test_ursa_variant2.py +""" + +from __future__ import annotations + +import math +import os +import sys +import unittest +from types import SimpleNamespace +from typing import List + +# Allow `import ursa_variant2` whether run from repo root (CI) or from +# examples/math_prm/ (developer convenience). 
+_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +if _THIS_DIR not in sys.path: + sys.path.insert(0, _THIS_DIR) + +import torch + +import ursa_variant2 as ursa_v2 # registers the monkey-patch at import time + + +def _make_cfg(n_samples: int = 2, advantage_clip: float = 0) -> SimpleNamespace: + """Minimal config namespace UrsaVariant2Calculator reads from.""" + return SimpleNamespace( + n_samples_per_prompt=n_samples, + advantage_clip=advantage_clip, + ) + + +def _make_exp( + action_mask: torch.Tensor, + reward: torch.Tensor, + step_rewards: List[torch.Tensor], + step_token_indices: List[torch.Tensor], +): + """Minimal experience stub: just info dict + action_mask.""" + return SimpleNamespace( + action_mask=action_mask, + info={ + "reward": reward, + "step_rewards": step_rewards, + "step_token_indices": step_token_indices, + }, + ) + + +# Hand-computed Eq.9 reference, written out explicitly so reviewers can +# verify the test logic itself matches paper Appendix B.1 Eq.9. +def _hand_compute_eq9( + step_rewards: List[torch.Tensor], + step_token_indices: List[torch.Tensor], + outcome: torch.Tensor, + K: int, + T: int, +) -> torch.Tensor: + """Brute-force reference: build per-token advantages following paper Eq.9. + + A_t^i = r_{s,t}^i * GroupNorm_G(r̄_s^i) + GroupNorm_G(r_o^i) + where t indexes steps and the value is broadcast to every token within + the span [start_k, end_k] (start_0=0, start_k = end_{k-1}+1). 
def _prepare_calculator(n_samples, max_new_tokens, action_mask, outcome,
                        step_rewards, step_token_indices):
    """Build a calculator plus one experience and run reward preprocessing.

    Returns ``(calculator, experience)`` so callers can either inspect what
    ``preprocess_rewards`` stored on the experience or go on to ``compute``.
    """
    calculator = ursa_v2.UrsaVariant2Calculator(_make_cfg(n_samples=n_samples))
    # preprocess_rewards takes one experience holding all trajectories
    experience = _make_exp(action_mask, outcome, step_rewards, step_token_indices)
    calculator.preprocess_rewards(outcome, [experience], max_new_tokens=max_new_tokens)
    return calculator, experience


class _Base(unittest.TestCase):
    """Shared fixture: every test starts from the same torch RNG seed."""

    def setUp(self):
        torch.manual_seed(0)


class TestAC1NumericalEquivalence(_Base):
    """AC1: implementation matches hand-computed Eq.9 within tolerance."""

    def test_basic_k2_three_steps(self):
        group_size, batch, horizon = 2, 4, 30
        rewards = [
            torch.tensor([0.80, 0.70, 0.30]),  # traj 0
            torch.tensor([0.85, 0.75, 0.90]),  # traj 1
            torch.tensor([0.50, 0.55, 0.60]),  # traj 2
            torch.tensor([0.60, 0.65, 0.70]),  # traj 3
        ]
        boundaries = [torch.tensor([5, 12, 20])] * 4
        outcome = torch.tensor([1.0, 1.0, 0.0, 1.0])
        mask = torch.ones(batch, horizon, dtype=torch.long)

        expected = _hand_compute_eq9(rewards, boundaries, outcome, group_size, horizon)

        calc, exp = _prepare_calculator(group_size, horizon, mask, outcome, rewards, boundaries)
        adv, ret, _ = calc.compute(exp, final_reward=None, gamma=None, generate_kwargs={})

        max_abs = (adv - expected).abs().max().item()
        self.assertLess(max_abs, 1e-5, f"AC1 violated: max|Δ|={max_abs}")
        # returns mirror advantages (no value function)
        self.assertLess((ret - adv).abs().max().item(), 1e-9)

    def test_k4_variable_step_count(self):
        group_size, horizon = 4, 40
        rewards = [
            torch.tensor([0.9, 0.8]),                # 2 steps
            torch.tensor([0.5, 0.4, 0.6]),           # 3 steps
            torch.tensor([0.7]),                     # 1 step
            torch.tensor([0.6, 0.65, 0.55, 0.50]),   # 4 steps
        ]
        boundaries = [
            torch.tensor([10, 25]),
            torch.tensor([8, 18, 30]),
            torch.tensor([22]),
            torch.tensor([7, 15, 22, 33]),
        ]
        outcome = torch.tensor([1.0, 0.0, 1.0, 0.0])
        mask = torch.ones(group_size, horizon, dtype=torch.long)

        expected = _hand_compute_eq9(rewards, boundaries, outcome, group_size, horizon)
        calc, exp = _prepare_calculator(group_size, horizon, mask, outcome, rewards, boundaries)
        adv, _, _ = calc.compute(exp, None, None, {})

        max_abs = (adv - expected).abs().max().item()
        self.assertLess(max_abs, 1e-5, f"AC1 (K=4 variable steps) violated: max|Δ|={max_abs}")


class TestAC2OutcomeNotBypassed(_Base):
    """AC2: changing outcome reward must change advantages.

    Regression for the Mode B bypass bug: under the old per-step path,
    feeding outcome through ``compute_reward`` r had no effect because
    Mode B threw it away. UrsaVariant2Calculator must NOT have this bug.
    """

    def _run(self, outcome):
        """Return the advantages for a fixed 4-trajectory batch under *outcome*."""
        group_size, batch, horizon = 2, 4, 20
        rewards = [torch.tensor([0.5, 0.6, 0.7])] * batch
        boundaries = [torch.tensor([5, 10, 15])] * batch
        mask = torch.ones(batch, horizon, dtype=torch.long)
        calc, exp = _prepare_calculator(group_size, horizon, mask, outcome, rewards, boundaries)
        adv, _, _ = calc.compute(exp, None, None, {})
        return adv

    def test_all_correct_vs_all_wrong_differs(self):
        # If outcome is constant across the group, the GroupNorm of outcome
        # is exactly zero, so the additive term vanishes — that's expected
        # and not a bug; instead we compare a per-prompt mixed case.
        # (One sample correct, one wrong in each of two prompts.)
        adv_a = self._run(torch.tensor([1.0, 0.0, 1.0, 0.0]))
        adv_b = self._run(torch.tensor([0.0, 1.0, 0.0, 1.0]))
        diff = (adv_a - adv_b).abs().max().item()
        self.assertGreater(
            diff,
            0.5,
            f"AC2 violated: outcome flip should flip the sign of the outcome term (max|Δ|={diff})",
        )

    def test_outcome_anchor_extends_past_last_step(self):
        # Pad tail past the last step should carry the outcome anchor only.
        group_size, horizon = 2, 30
        rewards = [
            torch.tensor([0.6, 0.7]),
            torch.tensor([0.6, 0.7]),
        ]
        boundaries = [torch.tensor([5, 10])] * 2
        # Trajectory 0 wins outcome, trajectory 1 loses — group_norm gives
        # ±1 for the outcome term.
        outcome = torch.tensor([1.0, 0.0])
        mask = torch.ones(group_size, horizon, dtype=torch.long)
        calc, exp = _prepare_calculator(group_size, horizon, mask, outcome, rewards, boundaries)
        adv, _, _ = calc.compute(exp, None, None, {})

        # tail tokens (idx 11..29) should equal oc_normed (= ±1 within tol);
        # there is no process reward there.
        traj0_tail = adv[0, 11:]
        traj1_tail = adv[1, 11:]
        self.assertGreater(
            traj0_tail.mean().item(),
            0.5,
            f"traj0 tail must carry positive outcome anchor (got mean={traj0_tail.mean().item():.3f})",
        )
        self.assertLess(
            traj1_tail.mean().item(),
            -0.5,
            f"traj1 tail must carry negative outcome anchor (got mean={traj1_tail.mean().item():.3f})",
        )


class TestAC3GroupNormCorrect(_Base):
    """AC3: GroupNorm zero-mean / unit-std across K siblings for both terms."""

    def test_k2_msp_normed_zero_mean(self):
        group_size, batch, horizon = 2, 4, 10
        rewards = [
            torch.tensor([0.9, 0.8, 0.7]),
            torch.tensor([0.3, 0.4, 0.2]),
            torch.tensor([0.8, 0.7, 0.6]),
            torch.tensor([0.5, 0.5, 0.4]),
        ]
        boundaries = [torch.tensor([2, 5, 8])] * batch
        outcome = torch.tensor([1.0, 0.0, 1.0, 0.0])
        mask = torch.ones(batch, horizon, dtype=torch.long)
        _, exp = _prepare_calculator(group_size, horizon, mask, outcome, rewards, boundaries)

        # Stored normed values should sum to 0 per group (within tol)
        for key in ("_ursa_oc_normed", "_ursa_msp_normed"):
            grouped = exp.info[key].view(-1, group_size)
            self.assertLess(grouped.sum(dim=-1).abs().max().item(), 1e-5)

    def test_k4_msp_normed_unit_std(self):
        group_size, horizon = 4, 10
        rewards = [
            torch.tensor([0.9, 0.8]),
            torch.tensor([0.5, 0.4]),
            torch.tensor([0.7, 0.7]),
            torch.tensor([0.2, 0.3]),
        ]
        boundaries = [torch.tensor([3, 7])] * group_size
        outcome = torch.tensor([1.0, 0.0, 1.0, 0.0])
        mask = torch.ones(group_size, horizon, dtype=torch.long)
        _, exp = _prepare_calculator(group_size, horizon, mask, outcome, rewards, boundaries)

        # std (unbiased=False) across the single group of K=4 should be ~1
        std = exp.info["_ursa_msp_normed"].std(unbiased=False).item()
        self.assertAlmostEqual(std, 1.0, places=3, msg=f"AC3 violated: msp_normed std={std}")


class TestAC4SpanBroadcast(_Base):
    """AC4: advantage is constant within each step span and changes at the boundary."""

    def test_advantage_constant_within_span(self):
        group_size, batch, horizon = 2, 2, 25
        rewards = [
            torch.tensor([0.9, 0.5, 0.7]),
            torch.tensor([0.4, 0.6, 0.8]),
        ]
        boundaries = [torch.tensor([4, 12, 20])] * batch
        outcome = torch.tensor([1.0, 0.0])
        mask = torch.ones(batch, horizon, dtype=torch.long)
        calc, exp = _prepare_calculator(group_size, horizon, mask, outcome, rewards, boundaries)
        adv, _, _ = calc.compute(exp, None, None, {})

        # Trajectory 0 spans: tokens 0..4, 5..12 and 13..20 must each be flat.
        span0 = adv[0, 0:5]
        span1 = adv[0, 5:13]
        span2 = adv[0, 13:21]
        for span in (span0, span1, span2):
            self.assertLess((span - span[0]).abs().max().item(), 1e-6)
        # adjacent spans differ (otherwise per-step credit is degenerate)
        self.assertNotAlmostEqual(span0[0].item(), span1[0].item(), places=4)
        self.assertNotAlmostEqual(span1[0].item(), span2[0].item(), places=4)


class TestAC5SignedAdvantages(_Base):
    """AC5: typical inputs produce both positive and negative advantages."""

    def test_signed_advantages_on_synthetic_batch(self):
        group_size, batch, horizon = 2, 4, 25
        rewards = [
            torch.tensor([0.8, 0.7, 0.3]),
            torch.tensor([0.85, 0.75, 0.9]),
            torch.tensor([0.5, 0.55, 0.6]),
            torch.tensor([0.6, 0.65, 0.7]),
        ]
        boundaries = [torch.tensor([5, 12, 20])] * batch
        outcome = torch.tensor([1.0, 1.0, 0.0, 1.0])
        mask = torch.ones(batch, horizon, dtype=torch.long)
        calc, exp = _prepare_calculator(group_size, horizon, mask, outcome, rewards, boundaries)
        _, _, info = calc.compute(exp, None, None, {})

        # advantages must contain both signs (paper Eq.9 → zero-mean per group)
        self.assertGreater(info["ursa_v2_adv_pos_frac"], 0.05,
                           "AC5: advantages should contain positive entries")
        self.assertGreater(info["ursa_v2_adv_neg_frac"], 0.05,
                           "AC5: advantages should contain negative entries")


class TestK1Fallback(_Base):
    """When K=1, group norm is degenerate; calculator must not crash."""

    def test_k1_returns_zero_advantage(self):
        group_size, horizon = 1, 20
        rewards = [torch.tensor([0.5, 0.6])]
        boundaries = [torch.tensor([5, 10])]
        outcome = torch.tensor([1.0])
        mask = torch.ones(1, horizon, dtype=torch.long)
        calc, exp = _prepare_calculator(group_size, horizon, mask, outcome, rewards, boundaries)
        adv, _, info = calc.compute(exp, None, None, {})

        self.assertEqual(adv.abs().sum().item(), 0.0)
        self.assertEqual(info.get("ursa_v2_fallback_used"), 1.0)


if __name__ == "__main__":
    unittest.main(verbosity=2)
# ``auto_map`` entries merged into the wrapper's ``config.json``.
MODEL_AUTO_MAP = {
    "AutoConfig": "configuration_ursa.UrsaConfig",
    "AutoModel": "modeling_ursa.UrsaForConditionalGeneration",
    "AutoModelForVision2Seq": "modeling_ursa.UrsaForConditionalGeneration",
}

# ``auto_map`` entries merged into the preprocessor and tokenizer configs.
PROCESSOR_AUTO_MAP = {
    "AutoProcessor": "processing_ursa.UrsaProcessor",
    "AutoImageProcessor": "image_processing_vlm.VLMImageProcessor",
}

# Files that are rewritten with patched content instead of being symlinked.
_PATCHED_CONFIG_NAMES = frozenset(
    {"config.json", "preprocessor_config.json", "tokenizer_config.json"}
)


def _safe_unlink(path: Path) -> None:
    """Remove *path*, whether it is a symlink, a regular file or a directory."""
    # The symlink test comes first so a link to a directory is unlinked
    # rather than having its target tree deleted.
    if path.is_symlink() or path.is_file():
        path.unlink()
    elif path.is_dir():
        shutil.rmtree(path)


def _ensure_symlink(src: Path, dst: Path) -> None:
    """Make *dst* a symlink to *src*, replacing whatever is currently at *dst*.

    The call is a no-op when *dst* is already a symlink resolving to the same
    file as *src*.
    """
    if dst.is_symlink() or dst.exists():
        if dst.is_symlink() and dst.resolve() == src.resolve():
            return
        _safe_unlink(dst)
    # A relative target would be interpreted relative to dst's directory and
    # dangle, so always link to an absolute path.
    dst.symlink_to(os.path.abspath(src))


def _load_json(path: Path) -> dict:
    """Parse the UTF-8 JSON file at *path*."""
    return json.loads(path.read_text(encoding="utf-8"))


def _write_json(path: Path, payload: dict) -> None:
    """Write *payload* to *path* as indented UTF-8 JSON with a trailing newline."""
    # Writing through a stale symlink would overwrite the file it points to
    # (possibly the upstream checkpoint's own config), so drop the link first.
    if path.is_symlink():
        path.unlink()
    # ensure_ascii=False can emit non-ASCII text, so the encoding must not be
    # left to the locale default.
    path.write_text(json.dumps(payload, ensure_ascii=False, indent=2) + "\n", encoding="utf-8")


def _patched_auto_map(config: dict, additions: dict) -> dict:
    """Return the ``auto_map`` of *config* with *additions* merged on top."""
    merged = dict(config.get("auto_map") or {})
    merged.update(additions)
    return merged


def build_wrapper(source_model_path: Path, output_path: Path, local_ursa_dir: Path) -> None:
    """Create a wrapper checkpoint directory at *output_path*.

    Every entry of *source_model_path* except the three patched config files
    is symlinked into *output_path*, as is every ``*.py`` file found directly
    in *local_ursa_dir*. ``config.json`` is then rewritten with
    ``MODEL_AUTO_MAP`` merged into its ``auto_map``. When present,
    ``preprocessor_config.json`` and ``tokenizer_config.json`` are rewritten
    with ``PROCESSOR_AUTO_MAP`` and the URSA processor class names.

    Raises:
        FileNotFoundError: *source_model_path* or *local_ursa_dir* is not a
            directory, or the source has no ``config.json``.
        ValueError: *output_path* is the same directory as *source_model_path*.
    """
    if not source_model_path.is_dir():
        raise FileNotFoundError(f"source model path is not a directory: {source_model_path}")
    if not local_ursa_dir.is_dir():
        raise FileNotFoundError(f"local URSA code directory not found: {local_ursa_dir}")
    # Building in place would make _ensure_symlink delete each original file
    # and replace it with a link to itself, destroying the checkpoint.
    if output_path.resolve() == source_model_path.resolve():
        raise ValueError(
            f"output path must differ from the source model path: {output_path}"
        )

    output_path.mkdir(parents=True, exist_ok=True)

    for src in source_model_path.iterdir():
        if src.name in _PATCHED_CONFIG_NAMES:
            continue
        _ensure_symlink(src, output_path / src.name)

    for src in local_ursa_dir.glob("*.py"):
        _ensure_symlink(src, output_path / src.name)

    config = _load_json(source_model_path / "config.json")
    config["auto_map"] = _patched_auto_map(config, MODEL_AUTO_MAP)
    _write_json(output_path / "config.json", config)

    preprocessor_path = source_model_path / "preprocessor_config.json"
    if preprocessor_path.exists():
        preprocessor = _load_json(preprocessor_path)
        preprocessor["processor_class"] = "UrsaProcessor"
        preprocessor["image_processor_type"] = "VLMImageProcessor"
        preprocessor["auto_map"] = _patched_auto_map(preprocessor, PROCESSOR_AUTO_MAP)
        _write_json(output_path / "preprocessor_config.json", preprocessor)

    tokenizer_config_path = source_model_path / "tokenizer_config.json"
    if tokenizer_config_path.exists():
        tokenizer_config = _load_json(tokenizer_config_path)
        tokenizer_config["processor_class"] = "UrsaProcessor"
        tokenizer_config["auto_map"] = _patched_auto_map(tokenizer_config, PROCESSOR_AUTO_MAP)
        _write_json(output_path / "tokenizer_config.json", tokenizer_config)


def main() -> None:
    """Parse the CLI arguments, build the wrapper, and print its resolved path."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--source-model-path", required=True)
    parser.add_argument("--output-path", required=True)
    args = parser.parse_args()

    source_model_path = Path(args.source_model_path).resolve()
    output_path = Path(args.output_path).resolve()
    # The URSA model code sits in ``ursa_model`` one directory above this script's folder.
    local_ursa_dir = Path(__file__).resolve().parents[1] / "ursa_model"

    build_wrapper(source_model_path, output_path, local_ursa_dir)
    print(str(output_path))


if __name__ == "__main__":
    main()
b/examples/math_prm/tools/prepare_ursa_stage3_manifest.py new file mode 100644 index 00000000..0dbac106 --- /dev/null +++ b/examples/math_prm/tools/prepare_ursa_stage3_manifest.py @@ -0,0 +1,322 @@ +""" +Prepare a LightRFT-compatible Stage 3 manifest from URSA-MATH raw data. + +This script converts the raw `MMathCoT-1M` jsonl schema: + + {"image_url": "...", "instruction": "...", "output": "..."} + +into a LightRFT prompt dataset schema: + + { + "prompt": "...", + "images": ["/abs/path/to/image.png"], + "reference": "...", + "label": "math_psgrpo" + } + +It also performs a lightweight `PromptDatasetVL` smoke validation on the +converted records so the output can be consumed directly by +`examples/math_prm/train_colocate.py`. +""" + +from __future__ import annotations + +import argparse +import json +import re +from collections import Counter +from pathlib import Path +from types import SimpleNamespace +from typing import Any + + +REPO_ROOT = Path(__file__).resolve().parents[3] + +import sys + +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from lightrft.datasets.prompts_dataset_vl import PromptDatasetVL + + +# --input-path and --image-root are intentionally required (no path default): +# the data lives outside the repo and varies per environment. The output paths +# resolve under REPO_ROOT/tmp/ so they always succeed locally. +DEFAULT_OUTPUT_PATH = str(REPO_ROOT / "tmp" / "ursa_stage3" / "mmathcot_stage3_math_psgrpo.jsonl") +DEFAULT_SUMMARY_PATH = str(REPO_ROOT / "tmp" / "ursa_stage3" / "mmathcot_stage3_math_psgrpo.summary.json") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Convert URSA-MATH MMathCoT-1M raw jsonl into a LightRFT " + "Stage 3 prompt manifest and validate it with PromptDatasetVL." + ) + ) + parser.add_argument( + "--input-path", + type=str, + required=True, + help="Path to MMathCoT-1M raw train.jsonl (e.g. 
/data/URSA-MATH/MMathCoT-1M/train.jsonl).", + ) + parser.add_argument( + "--image-root", + type=str, + required=True, + help="Root directory for URSA-MATH image assets (e.g. /data/URSA-MATH/images).", + ) + parser.add_argument( + "--output-path", + type=str, + default=DEFAULT_OUTPUT_PATH, + help="Output path for the converted LightRFT jsonl manifest.", + ) + parser.add_argument( + "--summary-path", + type=str, + default=DEFAULT_SUMMARY_PATH, + help="Path to write the conversion/validation summary json.", + ) + parser.add_argument( + "--label", + type=str, + default="math_psgrpo", + help="Label written into the converted manifest.", + ) + parser.add_argument( + "--prompt-mode", + type=str, + choices=["question_only", "instruction"], + default="question_only", + help="How to build the LightRFT prompt from raw instruction.", + ) + parser.add_argument( + "--max-samples", + type=int, + default=None, + help="Optional cap for the number of raw rows to process.", + ) + parser.add_argument( + "--smoke-samples", + type=int, + default=4, + help="How many converted samples to use for PromptDatasetVL smoke validation.", + ) + return parser.parse_args() + + +def extract_prompt(raw_instruction: str, prompt_mode: str) -> tuple[str, bool]: + text = (raw_instruction or "").strip() + if not text: + return "", False + + if prompt_mode == "instruction": + return text, False + + marker = "Question:" + idx = text.find(marker) + if idx == -1: + return text, True + + question = text[idx + len(marker):].strip() + # Some raw rows contain duplicated or malformed prefixes such as + # "Question:estion: ...". Strip these repeatedly before returning. 
+ prefix_re = re.compile(r"^(?:(?:[Qq]uestion|[Qq]estion|[Ee]stion|[Uu]estion)\s*:)\s*") + while True: + cleaned = prefix_re.sub("", question) + if cleaned == question: + break + question = cleaned.strip() + if not question: + return text, True + return question, False + + +def extract_reference(raw_output: str) -> tuple[str, bool]: + text = (raw_output or "").strip() + if not text: + return "", False + + marker = "†Answer:" + idx = text.rfind(marker) + if idx == -1: + return text, True + + answer = text[idx + len(marker):].strip() + if not answer: + return text, True + return answer, False + + +def build_record( + raw: dict[str, Any], + source_index: int, + image_root: Path, + prompt_mode: str, + label: str, +) -> tuple[dict[str, Any], dict[str, Any]]: + image_url = str(raw.get("image_url", "")).strip() + instruction = str(raw.get("instruction", "")).strip() + output = str(raw.get("output", "")).strip() + + prompt, used_prompt_fallback = extract_prompt(instruction, prompt_mode) + reference, used_reference_fallback = extract_reference(output) + + image_path = (image_root / image_url).resolve() + prefix = image_url.split("/", 1)[0] if image_url else "" + + record = { + "data_source": "URSA-MATH/MMathCoT-1M", + "prompt": prompt, + "images": [str(image_path)], + "reference": reference, + "ground_truth": reference, + "label": label, + "reward_model": { + "ground_truth": reference, + }, + "extra_info": { + "source_index": source_index, + "raw_image_url": image_url, + "image_prefix": prefix, + "prompt_mode": prompt_mode, + }, + } + + meta = { + "image_path_exists": image_path.exists(), + "prompt_empty": prompt == "", + "reference_empty": reference == "", + "used_prompt_fallback": used_prompt_fallback, + "used_reference_fallback": used_reference_fallback, + "image_prefix": prefix, + "image_path": str(image_path), + } + return record, meta + + +def smoke_validate(converted_rows: list[dict[str, Any]], smoke_samples: int) -> dict[str, Any]: + smoke_rows = converted_rows[: 
max(1, min(smoke_samples, len(converted_rows)))] + strategy = SimpleNamespace( + args=SimpleNamespace( + input_key="prompt", + images_key="images", + reference_key="reference", + label_key="label", + apply_chat_template=False, + system_prompt=None, + ) + ) + dataset = PromptDatasetVL( + smoke_rows, + tokenizer=None, + processor=None, + max_length=0, + strategy=strategy, + ) + items = [dataset[i] for i in range(len(dataset))] + prompts, images, references, labels = dataset.collate_fn(items) + + first_prompt, first_images, first_reference, first_label = items[0] + return { + "sample_count": len(dataset), + "first_item": { + "prompt_preview": first_prompt[:240], + "image_count": len(first_images) if isinstance(first_images, list) else 0, + "first_image": first_images[0] if isinstance(first_images, list) and first_images else None, + "reference": first_reference, + "label": first_label, + }, + "collate_sizes": { + "prompts": len(prompts), + "images": len(images), + "references": len(references), + "labels": len(labels), + }, + } + + +def main() -> None: + args = parse_args() + + input_path = Path(args.input_path).resolve() + image_root = Path(args.image_root).resolve() + output_path = Path(args.output_path).resolve() + summary_path = Path(args.summary_path).resolve() + + if not input_path.exists(): + raise FileNotFoundError(f"input jsonl not found: {input_path}") + if not image_root.exists(): + raise FileNotFoundError(f"image root not found: {image_root}") + + counters = Counter() + prefix_counter: Counter[str] = Counter() + smoke_rows: list[dict[str, Any]] = [] + + output_path.parent.mkdir(parents=True, exist_ok=True) + with input_path.open("r", encoding="utf-8") as fp, output_path.open("w", encoding="utf-8") as out_fp: + for source_index, line in enumerate(fp): + if args.max_samples is not None and source_index >= args.max_samples: + break + + counters["rows_seen"] += 1 + raw = json.loads(line) + record, meta = build_record( + raw=raw, + source_index=source_index, + 
image_root=image_root, + prompt_mode=args.prompt_mode, + label=args.label, + ) + + prefix_counter[meta["image_prefix"]] += 1 + if meta["used_prompt_fallback"]: + counters["prompt_fallback_rows"] += 1 + if meta["used_reference_fallback"]: + counters["reference_fallback_rows"] += 1 + if meta["prompt_empty"]: + counters["empty_prompt_rows"] += 1 + if meta["reference_empty"]: + counters["empty_reference_rows"] += 1 + if not meta["image_path_exists"]: + raise FileNotFoundError( + f"missing image for row {source_index}: {meta['image_path']}" + ) + + out_fp.write(json.dumps(record, ensure_ascii=False) + "\n") + counters["rows_written"] += 1 + if len(smoke_rows) < max(1, args.smoke_samples): + smoke_rows.append(record) + + if not smoke_rows: + raise ValueError("No rows were converted. Check the input path and --max-samples.") + + smoke = smoke_validate(smoke_rows, args.smoke_samples) + + summary = { + "input_path": str(input_path), + "image_root": str(image_root), + "output_path": str(output_path), + "summary_path": str(summary_path), + "label": args.label, + "prompt_mode": args.prompt_mode, + "rows_seen": counters["rows_seen"], + "rows_written": counters["rows_written"], + "prompt_fallback_rows": counters["prompt_fallback_rows"], + "reference_fallback_rows": counters["reference_fallback_rows"], + "empty_prompt_rows": counters["empty_prompt_rows"], + "empty_reference_rows": counters["empty_reference_rows"], + "image_prefix_counts": dict(prefix_counter), + "images_per_sample_counts": {"1": counters["rows_written"]}, + "smoke_validation": smoke, + } + + summary_path.parent.mkdir(parents=True, exist_ok=True) + summary_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + + print(json.dumps(summary, ensure_ascii=False, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/examples/math_prm/train_colocate.py b/examples/math_prm/train_colocate.py new file mode 100755 index 00000000..70ae3e63 --- /dev/null +++ 
b/examples/math_prm/train_colocate.py @@ -0,0 +1,1115 @@ +""" +GRPO Training with Co-located Reward Models + +This script implements Group Relative Policy Optimization (GRPO) training +with co-located reward models for reinforcement learning from human feedback (RLHF). + +Key Features: + - Supports both text-only and vision-language models + - Multiple reward models (Value, Safety, Knowledge, Normal, General) + - Flexible strategy: DeepSpeed ZeRO or FSDP + - Meta device initialization for memory optimization + - EMA (Exponential Moving Average) model support + - Dynamic sampling and overlong buffer penalties (DAPO) + +Main Components: + - Actor: Policy model being trained + - Critic: Value model for advantage estimation (optional for GRPO) + - Reward Models: Multiple models for evaluating different aspects + - Initial Model: Reference model for KL divergence + +Training Pipeline: + 1. Load and initialize models (actor, critic, reward models) + 2. Setup data loaders (prompts + optional pretrain data) + 3. Configure optimizers and schedulers + 4. Run PPO/GRPO training loop via SPMDPPOTrainerVL + +Usage: + python examples/math_prm/train_colocate.py --pretrain --reward_pretrain ... + +For more details on arguments, see the argument parser at the bottom of this file. 
def is_ursa_model(model_path: str) -> bool:
    """
    Check if the model is a URSA model by inspecting ``<model_path>/config.json``.

    A checkpoint counts as URSA when its config lists
    ``"UrsaForConditionalGeneration"`` in ``architectures`` or has
    ``model_type == "ursa"``.

    Args:
        model_path: Path to the model directory

    Returns:
        True if this is a URSA model. False otherwise, including when the
        config file is missing, unreadable, or not a JSON object.
    """
    config_path = os.path.join(model_path, "config.json")
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError):
        # Missing/unreadable file, or malformed JSON / bad encoding (both are
        # ValueError subclasses): treat as "not URSA" without masking
        # unrelated errors such as KeyboardInterrupt.
        return False
    if not isinstance(config, dict):
        return False

    # ``architectures`` may be absent or an explicit null; neither may stop
    # the model_type fallback below from running.
    architectures = config.get("architectures") or []
    if isinstance(architectures, (list, tuple, str)) and "UrsaForConditionalGeneration" in architectures:
        return True
    # Fallback: check model_type
    return config.get("model_type") == "ursa"


def resolve_reference_shard_size(world_size: int, preferred_shard_size: int = 8) -> int:
    """
    Pick a reference-model FSDP shard size that preserves the original 8-way
    layout when possible, but still works for bounded small-world-size runs.

    For a positive ``world_size`` the result is the largest divisor of
    ``world_size`` that does not exceed ``preferred_shard_size``, and never
    less than 1. A non-positive ``world_size`` returns
    ``preferred_shard_size`` unchanged.
    """
    if world_size <= 0:
        return preferred_shard_size
    # Clamp to at least 1 so a non-positive preference cannot yield an
    # unusable shard size.
    candidate = max(1, min(preferred_shard_size, world_size))
    while candidate > 1 and world_size % candidate != 0:
        candidate -= 1
    return candidate


def split_runtime_eval_dataset(prompts_data, args, strategy):
    """
    Build a deterministic held-out runtime eval split from prompt_data when no
    explicit eval dataset is provided.

    Returns ``(train_data, eval_data)``. ``eval_data`` is None, and
    ``prompts_data`` is returned untouched, when the holdout is disabled
    (``args.eval_holdout_size`` or ``args.max_eval_samples`` not positive),
    when the data has at most one sample, or when the data object has no
    ``train_test_split`` method. Otherwise the eval split holds
    ``min(eval_holdout_size, max_eval_samples, len - 1)`` samples drawn with
    ``args.eval_holdout_seed``.
    """
    if args.eval_holdout_size <= 0 or args.max_eval_samples <= 0:
        return prompts_data, None

    total_samples = len(prompts_data)
    if total_samples <= 1:
        strategy.print("Warning: prompt_data is too small to carve out a held-out runtime eval split.")
        return prompts_data, None

    # Always leave at least one sample for training.
    eval_size = min(args.eval_holdout_size, args.max_eval_samples, total_samples - 1)
    if eval_size <= 0:
        strategy.print("Warning: held-out runtime eval split resolved to zero samples; skipping eval split.")
        return prompts_data, None

    if not hasattr(prompts_data, "train_test_split"):
        strategy.print("Warning: prompt_data does not support train_test_split(); skipping held-out runtime eval.")
        return prompts_data, None

    split = prompts_data.train_test_split(test_size=eval_size, shuffle=True, seed=args.eval_holdout_seed)
    train_data = split["train"]
    eval_data = split["test"]
    strategy.print(
        "Prepared runtime eval holdout from prompt_data "
        f"(train={len(train_data)}, eval={len(eval_data)}, seed={args.eval_holdout_seed})."
    )
    return train_data, eval_data


def load_actor_tokenizer_processor(
    *,
    model_path: str,
    model,
    strategy,
    use_fast: bool,
):
    """
    Load the actor tokenizer/processor, using the explicit URSA processor path
    when the checkpoint is a URSA model.

    For URSA checkpoints the tokenizer is taken from ``UrsaProcessor``, set to
    left padding, and given the EOS token as pad token when it has none. Any
    other checkpoint is delegated to ``get_tokenizer_processor_vl`` with left
    padding.

    Returns:
        ``(tokenizer, processor)``
    """
    if not is_ursa_model(model_path):
        return get_tokenizer_processor_vl(
            model_path,
            model,
            "left",
            use_fast=use_fast,
        )

    # Imported lazily so non-URSA runs do not need the local ursa_model package.
    from ursa_model import UrsaProcessor

    processor = UrsaProcessor.from_pretrained(model_path)
    tokenizer = processor.tokenizer
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id
    strategy.print(
        f"Loaded URSA processor explicitly: tokenizer={type(tokenizer).__name__}, "
        f"processor={type(processor).__name__}"
    )
    return tokenizer, processor


def build_actor_init_kwargs(
    args,
    *,
    ds_config,
    include_lora: bool,
    include_disable_logprobs_flashattn: bool,
):
    """
    Build Actor/UrsaActor initialization kwargs while keeping train/eval variants aligned.

    The base kwargs are always present. The four LoRA settings are added only
    when ``include_lora`` is set, and ``disable_logprobs_flashattn`` only when
    ``include_disable_logprobs_flashattn`` is set.
    """
    kwargs = {
        "use_flash_attention_2": args.flash_attn,
        "bf16": args.bf16,
        "load_in_4bit": args.load_in_4bit,
        "ds_config": ds_config,
        "packing_samples": args.packing_samples,
        "fused_linear_logprob": args.fused_linear_logprob,
    }
    if include_lora:
        kwargs["lora_rank"] = args.lora_rank
        kwargs["lora_alpha"] = args.lora_alpha
        kwargs["target_modules"] = args.target_modules
        kwargs["lora_dropout"] = args.lora_dropout
    if include_disable_logprobs_flashattn:
        kwargs["disable_logprobs_flashattn"] = args.disable_logprobs_flashattn
    return kwargs


def prepare_ursa_runtime_for_inference_engines(strategy=None):
    """
    Register the local URSA classes with HuggingFace auto classes so rollout
    engines that rely on ``AutoConfig`` can resolve ``model_type='ursa'``.

    This script's directory is also put at the front of ``sys.path`` and of
    the ``PYTHONPATH`` environment variable when it is not already listed.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    if current_dir not in sys.path:
        sys.path.insert(0, current_dir)

    pythonpath = os.environ.get("PYTHONPATH")
    pythonpath_parts = pythonpath.split(os.pathsep) if pythonpath else []
    if current_dir not in pythonpath_parts:
        os.environ["PYTHONPATH"] = os.pathsep.join([current_dir, *pythonpath_parts])

    from ursa_model import (
        UrsaConfig,
        UrsaForConditionalGeneration,
        UrsaForTokenClassification,
    )

    # exist_ok=True keeps repeated calls from failing on re-registration.
    AutoConfig.register("ursa", UrsaConfig, exist_ok=True)
    AutoModelForVision2Seq.register(UrsaConfig, UrsaForConditionalGeneration, exist_ok=True)
    AutoModelForTokenClassification.register(UrsaConfig, UrsaForTokenClassification, exist_ok=True)

    if strategy is not None:
        strategy.print(
            "Registered URSA auto classes for inference engines "
            f"(sys.path/PYTHONPATH include {current_dir})"
        )
Save final model + + Args: + args: Parsed command-line arguments containing all training configuration + + Key configurations: + - meta_init: Initialize models on meta device to save CPU RAM + - freeze_prefix: Freeze vision encoder during training + - fsdp: Use FSDP instead of DeepSpeed + - rm_use_engine: Generic flag retained for other reward types, but + URSA math_prm/math_psgrpo PRM paths still load via HF directly + """ + if args.hf_separate_rollout_actor and args.engine_type != "hf": + raise ValueError("--hf_separate_rollout_actor requires --engine_type hf.") + if args.hf_separate_rollout_actor and not args.fsdp: + raise ValueError("--hf_separate_rollout_actor currently requires --fsdp.") + + # configure strategy + strategy = get_strategy(args) + + ds_train_cfg = strategy.get_ds_train_config(is_actor=True) if not args.fsdp else None + ds_eval_cfg = strategy.get_ds_eval_config(offload=False) if not args.fsdp else None + + # configure model + # ==================== Model Initialization ==================== + # Initialize all models within init_model_context for memory efficiency. + # When meta_init=True, models are created on "meta" device as empty shells, + # fundamentally resolving CPU OOM issues. 
+ with strategy.init_model_context(meta_init=args.meta_init): + strategy.print(f"Initializing models with meta_init={args.meta_init}") + + # Check if this is a URSA model + is_ursa = is_ursa_model(args.pretrain) + + # Select Actor class based on model type and text_only flag + if is_ursa: + strategy.print(f"Detected URSA model, using UrsaActor") + from ursa_actor import UrsaActor + Actor = UrsaActor + elif args.text_only: + Actor = ActorLanguage + else: + Actor = ActorVL + + # Initialize Actor (policy model) + actor = Actor( + args.pretrain, + **build_actor_init_kwargs( + args, + ds_config=ds_train_cfg, + include_lora=True, + include_disable_logprobs_flashattn=True, + ), + ) + + rollout_actor = None + if args.hf_separate_rollout_actor: + rollout_actor = Actor( + args.pretrain, + **build_actor_init_kwargs( + args, + ds_config=ds_eval_cfg, + include_lora=True, + include_disable_logprobs_flashattn=True, + ), + ) + + if args.actor_init_on_gpu: + actor = actor.to(torch.cuda.current_device()) + + # pre-prepare is used for saving RAM memory when training 72B model + if args.fsdp: + setattr(actor, "is_actor", True) + actor = strategy.prepare_model(actor, is_training=True) + + # Optionally freeze parameters (e.g., vision encoder). + # Qwen2-VL etc. expose vision under "visual.*"; URSA uses "vision_model.*" + # plus an "aligner.*" projector. Match all of them so --freeze_prefix + # actually fires for the URSA stack. 
+ if args.freeze_prefix: + freeze_prefix = ["visual", "vision_model", "aligner"] + frozen_params_count = 0 + total_params_count = 0 + for name, param in actor.model.named_parameters(): + total_params_count += 1 + if any(name.startswith(prefix) for prefix in freeze_prefix): + param.requires_grad = False + frozen_params_count += 1 + strategy.print(f"Froze {frozen_params_count}/{total_params_count} parameters based on prefixes: {freeze_prefix}") + + if args.critic_pretrain: + try: + from lightrft.models import get_vlm_for_sequence_regression + except ImportError as exc: + raise ImportError( + "critic_pretrain was provided, but get_vlm_for_sequence_regression " + "is not available in this LightRFT checkout." + ) from exc + critic = get_vlm_for_sequence_regression( + args.critic_pretrain, + "critic", + normalize_reward=args.normalize_reward_for_critic, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=args.target_modules, + lora_dropout=args.lora_dropout, + ds_config=ds_train_cfg, + value_head_prefix=args.value_head_prefix, + init_value_head=strategy.args.pretrain == strategy.args.critic_pretrain, + ) + else: + critic = None + + # Load reward models (multiple types: value, safety, knowledge, etc.) 
+ strategy.report_memory(f"before loaded reward models in main entry") + reward_models, reward_tokenizers, label_map = load_reward_models( + raw_reward_pretrain=args.reward_pretrain, + strategy=strategy, + use_engine=args.rm_use_engine, + ) + strategy.print(f"label_map: {label_map}") + strategy.report_memory(f"after loaded reward models in main entry") + + strategy.print(actor) + strategy.print(critic) + + # load weights for reference actor + if args.init_kl_coef == 0: + initial_model = None + else: + # Use the same Actor class (including URSA if detected) + initial_model = Actor( + args.pretrain, + **build_actor_init_kwargs( + args, + ds_config=ds_eval_cfg, + include_lora=False, + include_disable_logprobs_flashattn=False, + ), + ) + if args.fsdp: + reference_shard_size = resolve_reference_shard_size( + world_size=strategy.world_size, + preferred_shard_size=8, + ) + strategy.print( + "Preparing reference model with shard_size=" + f"{reference_shard_size} (world_size={strategy.world_size})" + ) + initial_model = strategy.prepare_model( + initial_model, + is_training=False, + shard_size=reference_shard_size, + ) + strategy.offload_model(initial_model) + + if args.enable_ema: + # Use the same Actor class (including URSA if detected) + ema_model = Actor( + args.pretrain, + use_flash_attention_2=args.flash_attn, + bf16=args.bf16, + load_in_4bit=args.load_in_4bit, + ds_config=ds_eval_cfg, + ) + else: + ema_model = None + + # configure tokenizer and processor + tokenizer, processor = load_actor_tokenizer_processor( + model_path=args.pretrain, + model=actor.model, + strategy=strategy, + use_fast=not strategy.args.disable_fast_tokenizer, + ) + assert processor is not None, "processor is None" + + # ==================== Data Loading Optimization ==================== + # The following sections now rely on the robust `blending_datasets` function. + # We add more logging for clarity. 
+ + # Prepare prompts dataset + strategy.print(f"Loading prompts dataset from: {args.prompt_data} with split: {args.prompt_split}") + prompts_data = blending_datasets( + args.prompt_data, + args.prompt_data_probs, + strategy, + args.seed, + return_eval=False, + train_split=args.prompt_split, + ) + + heldout_eval_data = None + if not args.eval_data and not args.eval_split: + prompts_data, heldout_eval_data = split_runtime_eval_dataset(prompts_data, args, strategy) + + prompts_data = prompts_data.select(range(min(args.max_samples, len(prompts_data)))) + prompts_dataset = PromptDatasetVL(prompts_data, tokenizer, processor, args.prompt_max_len, strategy, input_template=args.input_template) + strategy.print(f"Loaded {len(prompts_dataset)} samples for prompts.") + + # Prepare evaluation dataset + eval_dataloader = None + if args.eval_data or args.eval_split: + eval_data_path = args.eval_data if args.eval_data else args.prompt_data + if eval_data_path: + strategy.print(f"Loading evaluation dataset from {eval_data_path}, split='{args.eval_split}'") + eval_data = blending_datasets( + eval_data_path, "1.0", strategy, args.seed, return_eval=False, + # Note: `train_split` parameter is used to specify the desired split name for evaluation data. + train_split=args.eval_split, + ) + if len(eval_data) == 0: + strategy.print(f"Warning: Evaluation dataset at {eval_data_path} with split '{args.eval_split}' is empty. Skipping evaluation.") + else: + eval_data = eval_data.select(range(min(args.max_eval_samples, len(eval_data)))) + + eval_dataset = PromptDatasetVL(eval_data, tokenizer, processor, args.prompt_max_len, strategy, input_template=args.input_template) + # Cap eval DataLoader batch_size by local_hf_generate_max_batch_size to + # avoid the padding-leak bug. See heldout branch below for full rationale. 
+ eval_dp_batch_size = args.rollout_batch_size // strategy.world_size + if args.engine_type == "hf": + mb_cap = int(getattr(args, "local_hf_generate_max_batch_size", 0) or 0) + if mb_cap > 0: + eval_dp_batch_size = min(eval_dp_batch_size, mb_cap) + eval_dataloader = strategy.setup_dataloader( + eval_dataset, + eval_dp_batch_size, + False, + False, + collate_fn=eval_dataset.collate_fn, + drop_last=False, + ) + strategy.print( + f"Evaluation dataset loaded: {len(eval_dataset)} samples " + f"(eval DataLoader batch_size={eval_dp_batch_size})" + ) + else: + strategy.print("Warning: eval_split specified but no data path available for evaluation.") + elif heldout_eval_data is not None: + eval_data = heldout_eval_data.select(range(min(args.max_eval_samples, len(heldout_eval_data)))) + eval_dataset = PromptDatasetVL( + eval_data, tokenizer, processor, args.prompt_max_len, strategy, input_template=args.input_template + ) + # Match DataLoader batch_size to local_hf_generate_max_batch_size for engine_type=hf. + # fast_exp_maker.process_multimodal_batch calls processor(padding=True) on the full + # DataLoader batch, then strategy_base chunks the already-padded tensor into + # micro-batches of local_hf_generate_max_batch_size. Without this alignment, each + # micro-batch keeps the max-of-DL-batch padded length (e.g. 16-wide pad in a 4-wide + # chunk), and the extra left-pad tokens — even with attention_mask masking — degrade + # URSA's greedy decode by ~8pp via RoPE / vision-path interaction. Setting DL-batch + # = micro-batch eliminates the leak so eval matches `tmp/ckpt_eval_aligned.py --bs N` + # exactly. See PR53 issuecomment-... for the 11.9pp breakdown. 
+ eval_dp_batch_size = args.rollout_batch_size // strategy.world_size + if args.engine_type == "hf": + mb_cap = int(getattr(args, "local_hf_generate_max_batch_size", 0) or 0) + if mb_cap > 0: + eval_dp_batch_size = min(eval_dp_batch_size, mb_cap) + eval_dataloader = strategy.setup_dataloader( + eval_dataset, + eval_dp_batch_size, + False, + False, + collate_fn=eval_dataset.collate_fn, + drop_last=False, + ) + strategy.print( + f"Held-out runtime evaluation dataset loaded: {len(eval_dataset)} samples " + f"(eval DataLoader batch_size={eval_dp_batch_size}, aligned with " + f"local_hf_generate_max_batch_size={getattr(args, 'local_hf_generate_max_batch_size', 'n/a')})" + ) + + # Prepare pretrain dataset + pretrain_dataloader = None + if args.pretrain_data: + strategy.print(f"Loading pretrain dataset from: {args.pretrain_data} with split: {args.pretrain_split}") + pretrain_data = blending_datasets( + args.pretrain_data, args.pretrain_data_probs, strategy, args.seed, + return_eval=False, train_split=args.pretrain_split, + ) + if len(pretrain_data) == 0: + strategy.print(f"Warning: Pretrain dataset at {args.pretrain_data} is empty. 
PTX loss will not be applied.") + pretrain_dataloader = None + else: + pretrain_max_len = args.max_len if args.max_len else args.prompt_max_len + args.generate_max_len + # Calculate total samples needed for pretraining + total_pretrain_samples = args.max_epochs * len(prompts_dataset) * args.n_samples_per_prompt + pretrain_data_subset = pretrain_data.select(range(min(len(pretrain_data), total_pretrain_samples))) + + pretrain_dataset = SFTDatasetVL( + pretrain_data_subset, tokenizer, pretrain_max_len, strategy, pretrain_mode=True, + ) + strategy.print(f"Loaded {len(pretrain_dataset)} samples for pretraining.") + pretrain_dataloader = itertools.cycle( + iter( + strategy.setup_dataloader( + pretrain_dataset, args.micro_train_batch_size, True, True, pretrain_dataset.collate_fn, + ) + ) + ) + else: + pretrain_dataloader = None + + # Prepare prompts dataloader + prompts_dataloader = strategy.setup_dataloader( + prompts_dataset, args.rollout_batch_size // strategy.world_size, True, True, collate_fn=prompts_dataset.collate_fn + ) + + if args.pretrain_data: + pretrain_dataloader = itertools.cycle( + iter( + strategy.setup_dataloader( + pretrain_dataset, + args.micro_train_batch_size, + True, + True, + pretrain_dataset.collate_fn, + ) + ) + ) + else: + pretrain_dataloader = None + + # for scheduler + num_update_steps_per_episodes = ( + len(prompts_dataset) * args.n_samples_per_prompt // args.train_batch_size * args.max_epochs + ) + max_steps = math.ceil(args.num_episodes * num_update_steps_per_episodes) + + # gradient_checkpointing + if args.gradient_checkpointing: + actor.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + if critic is not None: + critic.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} + ) + + ( + (actor, actor_optim, actor_scheduler), + (critic, critic_optim, critic_scheduler), + reward_models, + 
initial_model, + ) = strategy.prepare_models_and_optimizers(actor, critic, reward_models, initial_model, args, max_steps) + + if rollout_actor is not None: + keep_rollout_on_gpu = bool(getattr(strategy.config, "hf_separate_rollout_keep_on_gpu", False)) + rollout_actor = strategy.prepare_model( + rollout_actor, + is_training=False, + shard_size=-1, + reshard_after_forward=False, + ) + rollout_actor.gradient_checkpointing_disable() + rollout_actor.eval() + if not keep_rollout_on_gpu: + strategy.offload_model(rollout_actor) + residency_note = "kept on GPU" if keep_rollout_on_gpu else "offloaded to CPU" + strategy.print( + "Prepared separate local HF rollout actor with FSDP full-shard, gc disabled, " + f"reshard_after_forward disabled, and {residency_note}." + ) + + # rollout_eos_patch is OPT-IN as of the eval-fix PR. With it OFF (default), + # rollout/eval generation falls back to HF default stopping (EosTokenCriteria + + # MaxLengthCriteria) which is token-equivalent to a bare model.generate call. + # See PR #53 issuecomment-4394197141 for why the patch was harmful by default. + if getattr(args, "enable_rollout_eos_patch", False): + from rollout_eos_patch import install_math_prm_rollout_eos_patch + install_math_prm_rollout_eos_patch(rollout_actor, tokenizer, tokenizer.eos_token_id) + strategy.print( + "Installed math_prm rollout EOS patch on rollout_actor.model.generate " + "(legacy --enable_rollout_eos_patch flag set; this BIASES rollout reward " + "and eval outcome — only enable to reproduce historical broken behavior)." + ) + else: + strategy.print( + "rollout_eos_patch NOT installed (default). Generation uses HF default " + "stopping criteria (EosTokenCriteria + MaxLengthCriteria), token-equivalent " + "to bare model.generate. Use --enable_rollout_eos_patch to restore legacy." 
+ ) + + strategy.print(reward_models) + + if ema_model: + ema_model._offload = True + ema_model = strategy.prepare(ema_model, is_rlhf=True) + + # load checkpoint + consumed_samples = 0 + if args.load_checkpoint and os.path.exists(os.path.join(args.ckpt_path, "_actor")): + _, states = strategy.load_ckpt(actor.model, os.path.join(args.ckpt_path, "_actor"), + optimizer=actor_optim, scheduler=actor_scheduler) + if args.critic_pretrain: + strategy.load_ckpt(critic, os.path.join(args.ckpt_path, "_critic")) + consumed_samples = states["consumed_samples"] + strategy.print(f"Loaded the checkpoint: {args.ckpt_path}, consumed_samples: {consumed_samples}") + + os.makedirs(args.save_path, exist_ok=True) + strategy.report_memory("after models init") + + if is_ursa: + prepare_ursa_runtime_for_inference_engines(strategy) + + strategy.report_memory("before setup_inference_engine") + strategy.setup_inference_engine( + args, + engine_type=args.engine_type, + actor=actor, + rollout_actor=rollout_actor, + tokenizer=tokenizer, + processor=processor, + ) + strategy.report_memory("after setup_inference_engine") + + # configure Trainer + trainer = MathPRMSPMDPPOTrainerVL( + strategy, + actor, + critic, + reward_models, + initial_model, + ema_model, + actor_optim, + critic_optim, + actor_scheduler, + critic_scheduler, + max_epochs=args.max_epochs, + micro_train_batch_size=args.micro_train_batch_size, + micro_rollout_batch_size=args.micro_rollout_batch_size, + gradient_checkpointing=args.gradient_checkpointing, + tokenizer=tokenizer, + processor=processor, + prompt_max_len=args.prompt_max_len, + value_clip=args.value_clip, + eps_clip=args.eps_clip, + loss_agg_mode=args.loss_agg_mode, + use_gspo=args.use_gspo, + normalize_advantages=args.normalize_advantages, + use_sequence_rewards=args.use_sequence_rewards, + gamma=args.gamma, + lambd=args.lambd, + init_kl_coef=args.init_kl_coef, + kl_target=args.kl_target, + ema_beta=0.992, + ptx_coef=args.ptx_coef, + max_norm=args.max_norm, + # for GPT 
generation + do_sample=True, + max_new_tokens=args.generate_max_len, + max_length=args.max_len, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + repetition_penalty=args.repetition_penalty, + no_repeat_ngram_size=args.no_repeat_ngram_size, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + # reward model + reward_fn=reward_fn, + reward_fn_label_map=label_map, + reward_recipe=RECIPE, + reward_tokenizers=reward_tokenizers, + save_hf_ckpt=args.save_hf_ckpt, + disable_ds_ckpt=args.disable_ds_ckpt, + packing_samples=args.packing_samples, + # overlong_reward + dynamic_sampling=args.dynamic_sampling, + overlong_buffer=args.overlong_buffer, + overlong_buffer_len=args.overlong_buffer_len, + overlong_buffer_penalty_factor=args.overlong_buffer_penalty_factor, + print_replay_buffer_stats=args.print_replay_buffer_stats, + ) + + # ---- Optional initial evaluate-only at step 0 (no PPO update) ---- + # Useful for diagnosing model state at step 0 vs step 1; e.g. to attribute the + # outcome gap between standalone bs=1 eval and the 8-rank FSDP bs=4 eval pipeline. + # Triggered by --initial_eval (default False, no-op). 
+ if getattr(args, "initial_eval", False) and eval_dataloader is not None: + strategy.print(f"\n{'=' * 60}\n[initial_eval] Running evaluate at step 0 (NO PPO update)\n{'=' * 60}") + trainer.eval_dataloader = eval_dataloader # ensure trainer has handle + raw = trainer.evaluate(eval_dataloader, global_step=0) + if strategy.is_rank_0() and raw: + strategy.print(f"[initial_eval] step 0 outcome: {raw}") + if getattr(args, "initial_eval_only", False): + strategy.print("[initial_eval] --initial_eval_only set, exiting before training.") + return + + trainer.fit(args, prompts_dataloader=prompts_dataloader, pretrain_dataloader=pretrain_dataloader, eval_dataloader=eval_dataloader, consumed_samples=0, num_update_steps_per_episodes=num_update_steps_per_episodes) + + # save model checkpoint after fitting on only rank0 + strategy.save_model( + ema_model if args.enable_ema else actor, + tokenizer, + args.save_path, + ) + + if args.critic_pretrain and args.save_value_network: + strategy.save_model( + critic, + tokenizer, + args.save_path + "_critic", + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--engine_type", type=str, default="hf", help="Choose inference engine type: vllm, sglang, hf") + parser.add_argument("--text_only", action="store_true", default=False) + + # Checkpoint + parser.add_argument("--save_path", type=str, default="./ckpt") + parser.add_argument("--save_steps", type=int, default=-1) + parser.add_argument("--save_hf_ckpt", action="store_true", default=False) + parser.add_argument("--disable_ds_ckpt", action="store_true", default=False) + parser.add_argument("--save_trajectories", action="store_true", default=False, help="Save experience trajectories to JSON for debugging") + parser.add_argument( + "--trajectory_analysis", + action="store_true", + default=False, + help="Enable extra trajectory analysis metrics when saving trajectories", + ) + parser.add_argument("--num_trajectories_to_save", type=int, default=10, 
help="Number of trajectories to save per checkpoint") + parser.add_argument("--print_replay_buffer_stats", action="store_true", default=False, help="Print detailed replay buffer statistics during training") + parser.add_argument("--enable_profile", action="store_true", default=False, help="Enable persistent step profiling with local files and W&B metrics") + parser.add_argument("--logging_steps", type=int, default=1) + parser.add_argument("--eval_steps", type=int, default=-1) + parser.add_argument("--ckpt_path", type=str, default="./ckpt/checkpoints_ppo") + parser.add_argument("--max_ckpt_num", type=int, default=3) + parser.add_argument("--max_ckpt_mem", type=int, default=1e8) + parser.add_argument("--load_checkpoint", action="store_true", default=False) + + # DAPO + parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy") + parser.add_argument("--overlong_buffer", action="store_true", default=False, help="Apply overlong sequence buffer in DAPO") + parser.add_argument("--overlong_buffer_len", type=int, default=1024, help="Max token threshold for overlong buffer") + parser.add_argument("--overlong_buffer_penalty_factor", type=float, default=1.0, help="Penalty scaling factor for overlong sequences, <1 discourages long outputs; >1 encourages them") + + # PPO + parser.add_argument("--num_episodes", type=int, default=10) + parser.add_argument("--rollout_batch_size", type=int, default=128) + parser.add_argument("--micro_rollout_batch_size", type=int, default=4) + parser.add_argument("--max_epochs", type=int, default=1) + parser.add_argument("--prompt_max_len", type=int, default=1024, help="Max tokens for each prompt") + parser.add_argument("--generate_max_len", type=int, default=3072, help="Max tokens to generate in PPO") + parser.add_argument( + "--max_len", + type=int, + default=None, + help=( + "Optional explicit total max_len (prompt + generation) for the " + "PromptDataset/SFTDataset. 
Defaults to prompt_max_len + " + "generate_max_len when unset; see train_colocate.py:542 and :709." + ), + ) + parser.add_argument("--max_samples", type=int, default=15360) + parser.add_argument("--max_norm", type=float, default=1.0, help="Gradient clipping") + parser.add_argument("--l2", type=float, default=0.0, help="weight decay loss") + parser.add_argument("--ptx_coef", type=float, default=0.05, help="PPO-ptx loss coef") + parser.add_argument("--eps_clip", type=float, default=0.2, help="PPO clip range") + parser.add_argument("--loss_agg_mode", type=str, default='seq-mean-token-mean', + help="Loss aggregation mode. Options: ['token-mean', 'seq-mean-token-sum', 'seq-mean-token-mean', 'seq-mean-token-sum-norm']") + parser.add_argument("--use_gspo", action="store_true", default=False, help="Enable GSPO (Group Sequence Policy Optimization) mode") + parser.add_argument("--normalize_advantages", action="store_true", default=True, help="Enable advantage normalization in GSPO") + parser.add_argument("--use_sequence_rewards", action="store_true", default=True, help="Use sequence-level rewards in GSPO") + parser.add_argument("--value_clip", type=float, default=0.2, help="PPO value clip range") + parser.add_argument("--lambd", type=float, default=0.95, help="PPO GAE lambd") + parser.add_argument("--gamma", type=float, default=1, help="PPO GAE gamma") + parser.add_argument("--micro_train_batch_size", type=int, default=4, help="batch size per GPU") + parser.add_argument("--train_batch_size", type=int, default=128, help="Global training batch size") + parser.add_argument("--normalize_reward_for_critic", action="store_true", default=False, help="Enable Reward Normalization in critic model") + parser.add_argument("--top_p", type=float, default=1.0) + parser.add_argument("--top_k", type=int, default=-1) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--repetition_penalty", type=float, default=1.0) + 
parser.add_argument("--no_repeat_ngram_size", type=int, default=0) + parser.add_argument("--freeze_prefix", action="store_true", default=False, help="Freeze the prefix part (e.g. vision encoder) of the actor model") + parser.add_argument("--freezing_actor_steps", type=int, default=-1, help="Used for critic initialization") + parser.add_argument( + "--n_samples_per_prompt", type=int, default=8, help="number of responses for each prompt in generation" + ) + parser.add_argument("--save_value_network", action="store_true", default=False, help="Save critic model") + parser.add_argument("--actor_learning_rate", type=float, default=1e-6) + parser.add_argument("--critic_learning_rate", type=float, default=9e-6) + parser.add_argument("--lr_warmup_ratio", type=float, default=0.03) + parser.add_argument("--kl_target", type=float, default=None) + parser.add_argument("--init_kl_coef", type=float, default=0.001, help="KL penalty in PPO") + parser.add_argument( + "--kl_estimator", + type=str, + default="k1", + choices=["k1", "k2", "k3"], + help=( + "In GRPO, k3 is utilized as the loss function, while k2, when used as the loss, is nearly equivalent to k1." + ), + ) + parser.add_argument("--adam_betas", type=float, nargs=2, default=(0.9, 0.95), help="Betas for Adam optimizer") + + # Reward/Advantage Norm/Clip Arguments + parser.add_argument("--reward_running_norm", action="store_true", default=False, help="Enable running normalization for rewards.") + parser.add_argument("--reward_running_norm_minus_mean", action="store_true", default=False, help="When using reward normalization, subtract the mean; otherwise, only scale by the std.") + parser.add_argument("--reward_clip", type=float, default=0.0, help="Clip rewards to the range [-reward_clip, reward_clip]. 
0.0 means no clipping.") + parser.add_argument("--advantages_norm", action="store_true", default=False, help="Enable whitening for advantages.") + parser.add_argument("--advantage_clip", type=float, default=0.0, help="Clip advantages to the range [-advantage_clip, advantage_clip]. 0.0 means no clipping.") + + # DeepSpeed + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for deepspeed") + parser.add_argument("--zero_stage", type=int, default=2, help="DeepSpeed ZeRO stage") + parser.add_argument("--gradient_checkpointing", action="store_true", default=False) + parser.add_argument("--bf16", action="store_true", default=False, help="Enable bfloat16") + parser.add_argument("--enable_ema", action="store_true", help="Enable EMA checkpoint for the model.") + parser.add_argument("--zpg", type=int, default=1, help="ZeRO++ max partition size") + parser.add_argument("--adam_offload", action="store_true", default=False, help="Offload Adam Optimizer") + parser.add_argument("--actor_init_on_gpu", action="store_true", default=False) + parser.add_argument("--flash_attn", action="store_true", default=False, help="Enable FlashAttention2") + parser.add_argument("--aux_loss_coef", type=float, default=0, help="MoE balancing loss") + parser.add_argument("--grad_accum_dtype", type=str, default=None, help="Adam grad accum data type") + parser.add_argument("--overlap_comm", action="store_true", default=False) + parser.add_argument("--gradient_checkpointing_use_reentrant", action="store_true", default=False) + parser.add_argument("--disable_fast_tokenizer", action="store_true", default=False) + parser.add_argument("--disable_logprobs_flashattn", action="store_true", default=False, help="Disable flash attn implementation in log_probs calculation") + + # FSDP + parser.add_argument("--no_shard_vit", action="store_true", default=False, help="Disable sharding for vision transformer") + 
parser.add_argument("--meta_init", action="store_true", default=False, help="Initialize models on meta device to save CPU memory") + + # Reinforce + parser.add_argument( + "--advantage_estimator", + type=str, + choices=["gae", "reinforce", "rloo", "reinforce_baseline", "group_norm", "cpgd", "reinforce++", + "ursa_variant2"], + default="gae", + help=( + "Choose advantage estimation method: gae, reinforce, rloo, reinforce_baseline, group_norm, " + "reinforce++. 'ursa_variant2' (URSA paper Eq.9 strict alignment) is provided by " + "examples/math_prm/ursa_variant2.py and only meaningful with label='math_per_step_prm' + " + "n_samples_per_prompt >= 2." + ), + ) + + parser.add_argument("--use_kl_loss", action="store_true", default=False, help="whether to use KL loss from GRPO") + + parser.add_argument( + "--per_step_reward_mode", + type=str, + choices=["raw", "group_norm"], + default="group_norm", + help=( + "How to integrate per-step PRM rewards (Math-Shepherd-style " + "per-token reward path, distinct from the strict paper Eq.9 path " + "selected via --advantage_estimator ursa_variant2). " + "'group_norm' (default): for each step k, subtract group mean and " + "divide by group std across the K trajectories in the same prompt " + "group BEFORE scattering to step-boundary tokens. Produces " + "zero-mean signed advantages (GRPO baseline convention). " + "'raw': scatter raw sigmoid step_score directly. WARNING — raw " + "is unsafe: sigmoid scores are always positive, so every " + "post-cumsum token advantage is non-negative and PG pushes " + "every probability up. Kept only for paper Figure ablation. " + "Only active when label is 'math_per_step_prm' AND " + "--advantage_estimator is the cumsum path (group_norm/grpo). " + "For the strict paper Eq.9 path use --advantage_estimator " + "ursa_variant2 (handles its own group normalization)." 
+ ), + ) + + # LoRA + parser.add_argument("--load_in_4bit", action="store_true", default=False) + parser.add_argument("--lora_rank", type=int, default=0) + parser.add_argument("--lora_alpha", type=int, default=16) + parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear") + parser.add_argument("--lora_dropout", type=float, default=0) + + # Models + parser.add_argument("--pretrain", type=str, default=None, help="HF model name or path") + parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path") + parser.add_argument("--remote_rm_url", type=str, default=None, help="remote RM API") + parser.add_argument("--critic_pretrain", type=str, default=None, help="HF model name or path") + parser.add_argument("--value_head_prefix", type=str, default="score") + + # Custom dataset + parser.add_argument("--prompt_data", type=str, default=None, help="HF dataset name or path") + parser.add_argument( + "--prompt_data_probs", + type=str, + default="1.0", + help="sampling probs for datasets", + ) + parser.add_argument("--prompt_split", type=str, default="train") + + # Evaluation dataset + parser.add_argument("--eval_data", type=str, default=None, help="HF evaluation dataset name or path (default: use prompt_data)") + parser.add_argument("--eval_split", type=str, default="", help="Evaluation data split (default: disabled)") + parser.add_argument("--max_eval_samples", type=int, default=500, help="Maximum number of samples to evaluate (default: 500)") + parser.add_argument( + "--eval_holdout_size", + type=int, + default=500, + help="Deterministic held-out eval subset size sampled from prompt_data when eval_data is unset (default: 500)", + ) + parser.add_argument( + "--eval_holdout_seed", + type=int, + default=42, + help="Seed for deterministic held-out runtime eval split (default: 42)", + ) + parser.add_argument( + "--eval_n_samples_per_prompt", + type=int, + default=1, + help="Number of eval generations per prompt (default: 
1)", + ) + parser.add_argument( + "--eval_do_sample", + action="store_true", + default=False, + help="Use sampling during runtime eval instead of greedy decoding", + ) + parser.add_argument( + "--eval_generate_max_len", + type=int, + default=None, + help="Maximum generation length for runtime eval (default: use generate_max_len)", + ) + parser.add_argument("--eval_temperature", type=float, default=0.0, help="Eval temperature (default: 0.0)") + parser.add_argument("--eval_top_p", type=float, default=1.0, help="Eval top-p (default: 1.0)") + parser.add_argument("--eval_top_k", type=int, default=-1, help="Eval top-k (default: -1)") + parser.add_argument( + "--eval_repetition_penalty", + type=float, + default=1.0, + help="Eval repetition penalty (default: 1.0)", + ) + parser.add_argument( + "--eval_no_repeat_ngram_size", + type=int, + default=0, + help="Eval no-repeat-ngram size (default: 0)", + ) + + parser.add_argument("--pretrain_data", type=str, default=None, help="HF dataset name or path") + parser.add_argument( + "--pretrain_data_probs", + type=str, + default="1.0", + help="sampling probs for datasets", + ) + parser.add_argument("--pretrain_split", type=str, default="train") + parser.add_argument("--input_key", type=str, default="input", help="JSON dataset key") + parser.add_argument("--images_key", type=str, default="images", help="JSON dataset key for images") + parser.add_argument("--reference_key", type=str, default="reference", help="JSON dataset key for reference answers") + parser.add_argument("--label_key", type=str, default="label", help="JSON dataset key") + parser.add_argument("--input_template", type=str, default=None) + parser.add_argument( + "--apply_chat_template", action="store_true", default=False, help="Use HF tokenizer chat template" + ) + + parser.add_argument("--system_prompt", type=str, default=None, help="HF System Prompt") + + + # wandb parameters + parser.add_argument("--use_wandb", type=str, default=None) + 
parser.add_argument("--wandb_org", type=str, default=None) + parser.add_argument("--wandb_group", type=str, default=None) + parser.add_argument("--wandb_project", type=str, default="lightrft_train_ppo") + parser.add_argument( + "--wandb_run_name", + type=str, + default="ppo_%s" % datetime.now().strftime("%m%dT%H:%M"), + ) + + # TensorBoard parameters + parser.add_argument("--use_tensorboard", type=str, default=None, help="TensorBoard logging path") + + # ModelScope parameters + parser.add_argument("--use_ms", action="store_true", default=False) + + # MultiModal + parser.add_argument("--limit_mm_image_per_prompt", type=int, default=-1, help="the max image number of each text in multi model for inference backend") + + # CPGD + parser.add_argument("--use_cpg_loss", action="store_true", default=False, help="whether to use the clipped policy gradient loss from CPGD") + + # initial-eval (eval at step 0, before any PPO update) + parser.add_argument( + "--initial_eval", action="store_true", default=False, + help="Run evaluate(global_step=0) before fit(). Useful for measuring base " + "model outcome under the actual training eval pipeline (8-rank FSDP + " + "bs=4 etc.) without any PPO drift.", + ) + parser.add_argument( + "--initial_eval_only", action="store_true", default=False, + help="With --initial_eval, exit immediately after initial eval (skip training).", + ) + + # math_prm rollout EOS patch + parser.add_argument( + "--enable_rollout_eos_patch", action="store_true", default=False, + help=( + "Install StructuredAnswerStoppingCriteria on rollout_actor.model.generate " + "(legacy behavior). DEFAULT OFF. The patch makes generation stop right after " + "'†Answer:' marker, but historical experiments (PR #53 issuecomment-4394197141) " + "showed it (a) lowers eval outcome by ~9.8pp due to truncated tokens, and " + "(b) biases rollout reward signal towards short responses (Goodhart's law) " + "causing length collapse during RL. 
With patch off, generation falls back to " + "HF default stopping (EosTokenCriteria + MaxLengthCriteria), which is what we " + "want for both rollout reward fidelity and eval accuracy alignment." + ), + ) + + add_arguments(parser) + + args = parser.parse_args() + + + if args.advantage_estimator not in ["gae"]: + args.critic_pretrain = None + elif args.critic_pretrain is None: + args.critic_pretrain = args.pretrain + + if args.advantage_estimator in ["rloo", "reinforce_baseline", "group_norm"]: + assert args.n_samples_per_prompt > 1, f"{args.advantage_estimator} requires n_samples_per_prompt > 1" + + if args.use_kl_loss: + if args.kl_estimator not in ["k2", "k3"]: + print(f"Recommend setting {args.kl_estimator} to 'k2' or 'k3' when using KL as a loss") + else: + if args.kl_estimator not in ["k1"]: + print(f"Recommend setting {args.kl_estimator} to 'k1' when not using KL as a loss.") + + if args.advantage_estimator in ["gae", "cpgd"] and args.use_kl_loss: + warnings.warn( + "Using use_kl_loss=True with non-normalized advantage estimator " + "may result in double KL penalty. Consider disabling --use_kl_loss " + "or using --advantage_estimator group_norm" + ) + + if args.input_template and "{}" not in args.input_template: + print("[Warning] {} not in args.input_template, set to None") + args.input_template = None + + if args.input_template and "\\n" in args.input_template: + print( + "[Warning] input_template contains \\n chracters instead of newline. " + "You likely want to pass $'\\n' in Bash or \"`n\" in PowerShell." + ) + + if args.use_ms: + from modelscope.utils.hf_util import patch_hub + + # Patch hub to download models from modelscope to speed up. 
class UrsaActor(ActorVL):
    """
    Actor wrapper for URSA-8B models.

    Extends ``ActorVL`` but loads the checkpoint with
    ``UrsaForConditionalGeneration`` instead of ``AutoModelForVision2Seq``.
    ``ActorVL.__init__`` is deliberately bypassed (only ``nn.Module.__init__``
    runs), so every attribute that ``forward`` reads is set here.

    Usage:
        actor = UrsaActor(
            pretrain_or_model="/path/to/URSA-8B",
            use_flash_attention_2=True,
            bf16=True,
            lora_rank=0,
        )
    """

    def __init__(
        self,
        pretrain_or_model,
        use_flash_attention_2=False,
        bf16=True,
        lora_rank=0,
        lora_alpha=16,
        lora_dropout=0,
        target_modules=None,
        ds_config=None,
        device_map=None,
        packing_samples=False,
        high_entropy_token_ratio=0.0,
        **kwargs,
    ) -> None:
        """
        Initialize URSA-8B actor model.

        Args:
            pretrain_or_model: Path to URSA-8B checkpoint or model instance
            use_flash_attention_2: Enable Flash Attention 2.0
            bf16: Use bfloat16 precision
            lora_rank: LoRA rank (0 disables LoRA)
            lora_alpha: LoRA alpha scaling parameter
            lora_dropout: LoRA dropout rate
            target_modules: Target modules for LoRA (auto-detected if None)
            ds_config: DeepSpeed configuration
            device_map: Device mapping for model placement
            packing_samples: Enable sample packing
            high_entropy_token_ratio: High entropy token filtering ratio
            **kwargs: Accepted for call-site compatibility; not used by this
                loader. Options that were actually switched on are reported.
        """
        # Initialize nn.Module only; model loading is handled below instead of
        # in ActorVL.__init__.
        nn.Module.__init__(self)
        self.high_entropy_token_ratio = high_entropy_token_ratio
        # forward() reads this unconditionally, so it must exist on BOTH
        # construction paths (checkpoint path and ready-made model instance).
        self.packing_samples = packing_samples

        # Surface options that are swallowed by **kwargs so that e.g.
        # load_in_4bit=True does not look like it took effect.
        ignored = sorted(
            key for key, value in kwargs.items()
            if value is not None and value is not False
        )
        if ignored:
            print(f"[UrsaActor] Ignoring unsupported options: {', '.join(ignored)}")

        if isinstance(pretrain_or_model, str):
            self.pretrain_or_model = pretrain_or_model
            attn_implementation = "flash_attention_2" if use_flash_attention_2 else "eager"

            # DeepSpeed ZeRO-3 integration: the config object has to stay
            # referenced while from_pretrained() runs below.
            if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
                dschf = HfDeepSpeedConfig(ds_config)
            else:
                dschf = None  # noqa: F841

            # Prepare loading kwargs
            from_pretrained_kwargs = {
                "trust_remote_code": True,
                "attn_implementation": attn_implementation,
                "torch_dtype": torch.bfloat16 if bf16 else "auto",
            }

            # Check if we're in meta device context (FSDP): a freshly created
            # tensor lands on the meta device in that case.
            try:
                is_meta_context = torch.empty(1).is_meta
            except Exception:  # best-effort probe; never block model loading
                is_meta_context = False

            if not is_meta_context and device_map is not None:
                from_pretrained_kwargs["device_map"] = device_map

            print(f"[UrsaActor] Loading URSA-8B model from {pretrain_or_model}")

            # Load URSA model using UrsaForConditionalGeneration
            self.model = UrsaForConditionalGeneration.from_pretrained(
                pretrain_or_model,
                **from_pretrained_kwargs
            )

            print("[UrsaActor] Successfully loaded URSA-8B model")

            # Apply LoRA if requested
            if lora_rank > 0:
                print(f"[UrsaActor] Applying LoRA with rank={lora_rank}, alpha={lora_alpha}")
                self.model = apply_lora_configuration(
                    model=self.model,
                    lora_rank=lora_rank,
                    lora_alpha=lora_alpha,
                    lora_dropout=lora_dropout,
                    target_modules=target_modules,
                    freeze_vision_tower=True,
                )

            # Disable cache for training
            self.model.config.use_cache = False
        else:
            # Model instance provided directly
            self.model = pretrain_or_model
            self.pretrain_or_model = "ursa"

        print(f"[UrsaActor] Model type: {self.pretrain_or_model}")

    def forward(
        self,
        sequences: torch.LongTensor,
        num_actions=None,
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.Tensor] = None,
        video_grid_thw: Optional[torch.Tensor] = None,
        return_output: bool = False,
        packed_seq_lens: Optional[list] = None,
    ) -> torch.Tensor:
        """
        VLM-aligned forward.

        The model's ``logits`` can be longer than ``sequences`` along the
        sequence dimension because image placeholders are expanded into vision
        tokens inside the model. Instead of aligning ``logits[:, :-1]`` with
        ``sequences[:, 1:]`` from the front, this method slices both from the
        TAIL: the last ``num_actions`` tokens of ``sequences`` are the
        generated tokens, and the logits that predict them are
        ``logits[:, -(num_actions + 1):-1]``. Log-probs are then computed with
        one fp32 ``log_softmax`` + ``gather``.

        Args:
            sequences: Token ids, prompt followed by the generated response.
            num_actions: Number of generated tokens at the tail of
                ``sequences``. When ``None`` the raw model output is returned
                (``return_output`` must be true).
            attention_mask: Mask used for the model call and to derive
                position ids.
            pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw:
                Multimodal inputs; each is dropped when the wrapped model does
                not accept that keyword.
            return_output: Also return the model output.
            packed_seq_lens: Unused; packed log-prob extraction is not
                implemented.

        Returns:
            ``action_log_probs``; or ``(action_log_probs, output)`` when
            ``return_output`` is set (``output`` is a dict with an extra
            ``"action_entropy"`` entry when ``high_entropy_token_ratio > 0``);
            or the raw model output when ``num_actions`` is ``None``.

        Raises:
            NotImplementedError: log-probs requested while ``packing_samples``
                is enabled.
            RuntimeError: the sliced logits and labels differ in length.
        """
        # Fail fast: previously this was raised only AFTER the (expensive)
        # model forward had already run. The outcome is identical.
        if self.packing_samples and num_actions is not None:
            raise NotImplementedError(
                "UrsaActor.forward does not yet support packed_seq_lens. The "
                "default ActorVL packing path is silently miscomputed for VLMs "
                "that expand image placeholders; we don't want to bake the same "
                "bug in here. Add explicit packed-aware alignment when needed."
            )

        if self.packing_samples:
            position_ids = reset_position_ids(attention_mask)
            attention_mask_for_model = None
        else:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            attention_mask_for_model = attention_mask

        # Sanitize actor-leaked image tokens before forward. The actor can
        # generate a literal image placeholder inside its response, leaving the
        # sequence with more image-token slots than images; the model then
        # aborts with "The number of image tokens is N while the number of
        # image given is M". num_actions is forwarded so that placeholders in
        # the generated tail are the ones that get neutralised.
        sequences, pixel_values, image_grid_thw = self._align_image_tokens_to_images(
            sequences, pixel_values, image_grid_thw, num_actions
        )

        forward_kwargs = dict(
            attention_mask=attention_mask_for_model,
            position_ids=position_ids,
            pixel_values=self._cast_multimodal_tensor(pixel_values),
            image_grid_thw=image_grid_thw,
            pixel_values_videos=self._cast_multimodal_tensor(pixel_values_videos),
            video_grid_thw=video_grid_thw,
            use_cache=False,
        )
        for k in ("pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "use_cache"):
            if not self._supports_model_kwarg(k):
                forward_kwargs.pop(k, None)

        output = self.model(sequences, **forward_kwargs)

        if num_actions is None:
            assert return_output
            return output

        logits = output["logits"]
        seq_T = sequences.size(1)
        logit_T = logits.size(1)

        # Logits at tail positions [-(num_actions + 1), -1) predict the tokens
        # at the last num_actions positions, i.e. sequences[:, -num_actions:].
        action_logits = logits[:, -(num_actions + 1):-1, :]
        action_labels = sequences[:, -num_actions:]
        if action_logits.size(1) != action_labels.size(1):
            raise RuntimeError(
                f"action_logits seq len {action_logits.size(1)} does not match "
                f"action_labels seq len {action_labels.size(1)} "
                f"(num_actions={num_actions}, seq_T={seq_T}, logit_T={logit_T})"
            )

        action_logp_full = F.log_softmax(action_logits.float(), dim=-1)
        action_log_probs = action_logp_full.gather(
            -1, action_labels.unsqueeze(-1)
        ).squeeze(-1)

        if self.high_entropy_token_ratio > 0.0:
            # Entropy of the action-position distribution, in the same fp32 used above.
            probs = action_logp_full.exp()
            action_entropy = -(probs * action_logp_full).sum(dim=-1)
        else:
            action_entropy = None

        if return_output:
            if action_entropy is not None:
                output_dict = dict(output)
                output_dict["action_entropy"] = action_entropy
                return (action_log_probs, output_dict)
            return (action_log_probs, output)
        return action_log_probs

    def _image_token_replacement_id(self):
        """Return the token id used to overwrite surplus image placeholders.

        Preference order: tokenizer pad id, tokenizer eos id, model-config eos
        id (first entry when it is a list), then 0.
        """
        tokenizer = getattr(self, "tokenizer", None)
        if tokenizer is not None:
            for candidate in (tokenizer.pad_token_id, tokenizer.eos_token_id):
                if candidate is not None:
                    return int(candidate)
        eos = getattr(self.model.config, "eos_token_id", None)
        if isinstance(eos, (list, tuple)):
            # Some configs carry several eos ids; int(list) would raise.
            eos = eos[0] if eos else None
        return int(eos or 0)

    @staticmethod
    def _pick_extra_image_tokens(token_mask, n_extra, num_actions):
        """Return flat (row-major) indices of ``n_extra`` image tokens to overwrite.

        Image tokens inside the last ``num_actions`` columns are chosen first:
        that region holds generated tokens, so a placeholder there was emitted
        by the actor rather than supplied by the prompt. Any shortfall is taken
        from the trailing remaining placeholders. With ``num_actions=None``
        this reduces to "the last ``n_extra`` placeholders in flat order".
        """
        response_mask = torch.zeros_like(token_mask)
        if (
            isinstance(num_actions, int)
            and token_mask.dim() == 2
            and 0 < num_actions <= token_mask.size(1)
        ):
            response_mask[:, -num_actions:] = token_mask[:, -num_actions:]

        leaked_pos = torch.nonzero(response_mask.reshape(-1), as_tuple=False).squeeze(-1)
        if leaked_pos.numel() >= n_extra:
            # Overwrite exactly n_extra so the token count still equals n_img.
            return leaked_pos[-n_extra:]

        prompt_pos = torch.nonzero(
            (token_mask & ~response_mask).reshape(-1), as_tuple=False
        ).squeeze(-1)
        shortfall = n_extra - leaked_pos.numel()
        return torch.cat([leaked_pos, prompt_pos[-shortfall:]])

    def _align_image_tokens_to_images(self, sequences, pixel_values, image_grid_thw, num_actions=None):
        """Make ``sequences``'s image-token count match the actual image count.

        The wrapped model raises ``ValueError`` ("The number of image tokens
        is N while the number of image given is M") when the number of
        ``image_token_index`` markers in the input differs from the number of
        images supplied.

        Strategy (no-op fast path when counts already agree):

        * count_tok > count_img : overwrite ``count_tok - count_img``
          placeholders with a benign text token (pad / eos). Placeholders in
          the generated tail (last ``num_actions`` columns) are overwritten
          first; without ``num_actions`` the trailing placeholders of the
          flattened batch are used. Selection is over the whole batch, not per
          row, because per-row image counts are not available here.

        * count_tok < count_img : truncate the image inputs to the first
          ``count_tok`` images. With ``image_grid_thw`` the kept
          ``pixel_values`` rows are the sum of ``t*h*w`` over the kept grid
          rows; without it, one ``pixel_values`` row per image is assumed
          (the same assumption used to count the images).

        Args:
            sequences: Token ids, or ``None``.
            pixel_values: Image tensor, or ``None``.
            image_grid_thw: One row per image, or ``None``.
            num_actions: Optional length of the generated tail of ``sequences``.

        Returns:
            ``(sequences, pixel_values, image_grid_thw)``. The input objects
            are returned untouched when nothing had to change; a modified
            ``sequences`` is always a fresh copy so shared rollout buffers are
            never mutated.
        """
        if sequences is None:
            return sequences, pixel_values, image_grid_thw
        image_token_id = getattr(self.model.config, "image_token_index", None)
        if image_token_id is None:
            return sequences, pixel_values, image_grid_thw

        token_mask = sequences == image_token_id
        n_tok = int(token_mask.sum().item())

        # Number of images supplied. image_grid_thw has one row per image and
        # is preferred over pixel_values (which may be packed).
        if image_grid_thw is not None:
            n_img = int(image_grid_thw.size(0))
        elif pixel_values is not None:
            # Fallback: assume one image per row of pixel_values.
            n_img = int(pixel_values.size(0)) if pixel_values.dim() >= 1 else 0
        else:
            n_img = 0

        if n_tok == n_img:
            return sequences, pixel_values, image_grid_thw

        if n_tok > n_img:
            extras = self._pick_extra_image_tokens(token_mask, n_tok - n_img, num_actions)
            # Clone only now that a write is certain; force a contiguous copy
            # so the flat indices computed above address the right elements.
            seq = sequences.clone(memory_format=torch.contiguous_format)
            seq.view(-1)[extras] = self._image_token_replacement_id()
            return seq, pixel_values, image_grid_thw

        # n_tok < n_img: truncate image features to match token slots.
        new_grid = image_grid_thw[:n_tok] if image_grid_thw is not None else None
        new_pixel = pixel_values
        if pixel_values is not None:
            if n_tok == 0:
                new_pixel = None
                new_grid = None
            elif image_grid_thw is not None:
                # pixel_values is treated as the concat of per-image patches;
                # per-image row counts are thw[i, 0] * thw[i, 1] * thw[i, 2].
                patch_counts = (image_grid_thw[:, 0] * image_grid_thw[:, 1] * image_grid_thw[:, 2]).long()
                keep = int(patch_counts[:n_tok].sum().item())
                new_pixel = pixel_values[:keep]
            else:
                # Images were counted as rows of pixel_values above, so the
                # truncation has to use the same unit.
                new_pixel = pixel_values[:n_tok]
        return sequences, new_pixel, new_grid


def create_ursa_actor(args, ds_config=None):
    """
    Factory function to create URSA-8B actor from training args.

    ``load_in_4bit``, ``disable_logprobs_flashattn`` and
    ``fused_linear_logprob`` are forwarded for parity with other actor
    factories; ``UrsaActor`` receives them through ``**kwargs`` and does not
    act on them.

    Args:
        args: Training arguments (argparse.Namespace)
        ds_config: DeepSpeed configuration dict

    Returns:
        UrsaActor instance
    """
    return UrsaActor(
        args.pretrain,
        use_flash_attention_2=args.flash_attn,
        bf16=args.bf16,
        load_in_4bit=getattr(args, 'load_in_4bit', False),
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        target_modules=getattr(args, 'target_modules', None),
        lora_dropout=args.lora_dropout,
        ds_config=ds_config,
        packing_samples=args.packing_samples,
        disable_logprobs_flashattn=getattr(args, 'disable_logprobs_flashattn', False),
        fused_linear_logprob=getattr(args, 'fused_linear_logprob', False),
    )
# Resolve ``AttrDict`` from whichever provider is importable, falling back to
# a small local implementation when neither third-party package is usable.
try:
    from attrdict import AttrDict  # type: ignore
except ImportError:
    try:
        from easydict import EasyDict as AttrDict  # type: ignore
    except ImportError:
        class AttrDict(dict):
            """Dict whose keys can also be read, written and deleted as attributes."""

            # Attribute assignment is plain item assignment.
            __setattr__ = dict.__setitem__

            def __getattr__(self, name):
                # Only reached when normal attribute lookup fails.
                try:
                    value = self[name]
                except KeyError as missing:
                    raise AttributeError(name) from missing
                return value

            def __delattr__(self, name):
                try:
                    del self[name]
                except KeyError as missing:
                    raise AttributeError(name) from missing
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
class CLIPVisionTower(nn.Module):
    """Single vision encoder plus optional input normalisation.

    The backbone is chosen from ``model_name``: names starting with
    ``"siglip"`` or ``"sam"`` use the local factory functions, anything else
    is loaded as a Hugging Face ``CLIPVisionModel``.
    """

    def __init__(
        self,
        model_name: str = "siglip_large_patch16_384",
        image_size: Union[Tuple[int, int], int] = 336,
        select_feature: str = "patch",
        select_layer: int = -2,
        select_layers: list = None,
        ckpt_path: str = "",
        pixel_mean: Optional[List[float]] = None,
        pixel_std: Optional[List[float]] = None,
        **kwargs,
    ):
        """Build the backbone and the optional ``Normalize`` transform.

        Args:
            model_name: Backbone identifier; its prefix selects the builder.
            image_size: Forwarded to the backbone builder.
            select_feature: One of ``"patch"``, ``"cls_patch"``, ``"same"``;
                see ``feature_select``. Overridden to ``"same"`` for siglip.
            select_layer: Hidden-state index used when the backbone returns a
                model output object; also forwarded to the builder.
            select_layers: Stored on the instance; not used in this class.
            ckpt_path: Forwarded to the backbone builder.
            pixel_mean, pixel_std: When both are given, inputs are normalised
                with them before the backbone runs.
            **kwargs: Extra builder arguments, merged over the defaults.
        """
        super().__init__()

        self.model_name = model_name
        self.select_feature = select_feature
        self.select_layer = select_layer
        self.select_layers = select_layers

        vision_tower_params = {
            "model_name": model_name,
            "image_size": image_size,
            "ckpt_path": ckpt_path,
            "select_layer": select_layer,
        }
        vision_tower_params.update(kwargs)
        self.vision_tower, self.forward_kwargs = self.build_vision_tower(
            vision_tower_params
        )

        if pixel_mean is not None and pixel_std is not None:
            image_norm = torchvision.transforms.Normalize(
                mean=pixel_mean, std=pixel_std
            )
        else:
            image_norm = None

        self.image_norm = image_norm

    def build_vision_tower(self, vision_tower_params):
        """Create the backbone and the extra kwargs to pass on every forward.

        Returns:
            ``(vision_tower, forward_kwargs)``. ``forward_kwargs`` is empty for
            siglip/sam and requests hidden states for the Hugging Face CLIP
            model.
        """
        if self.model_name.startswith("siglip"):
            # Side effect: siglip always uses the backbone output unchanged.
            self.select_feature = "same"
            vision_tower = create_siglip_vit(**vision_tower_params)
            forward_kwargs = dict()

        elif self.model_name.startswith("sam"):
            vision_tower = create_sam_vit(**vision_tower_params)
            forward_kwargs = dict()

        else:  # huggingface
            from transformers import CLIPVisionModel

            vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
            forward_kwargs = dict(output_hidden_states=True)

        return vision_tower, forward_kwargs

    def feature_select(self, image_forward_outs):
        """Pick the feature tensor out of the backbone output.

        Raises:
            ValueError: ``self.select_feature`` is not a supported mode.
        """
        if isinstance(image_forward_outs, torch.Tensor):
            # the output is already the self.select_layer's features
            image_features = image_forward_outs
        else:
            image_features = image_forward_outs.hidden_states[self.select_layer]

        if self.select_feature == "patch":
            # drop the first token (the cls token, if the output has one)
            image_features = image_features[:, 1:]
        elif self.select_feature == "cls_patch":
            image_features = image_features
        elif self.select_feature == "same":
            image_features = image_features

        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    def forward(self, images):
        """

        Args:
            images (torch.Tensor): [b, 3, H, W]

        Returns:
            image_features (torch.Tensor): [b, n_patch, d]
        """

        if self.image_norm is not None:
            images = self.image_norm(images)

        image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
        image_features = self.feature_select(image_forward_outs)
        return image_features


class HybridVisionTower(nn.Module):
    """Runs a high-resolution and a low-resolution ``CLIPVisionTower`` on the same images.

    The low-resolution tower sees the input resized to ``low_res_cfg["image_size"]``.
    The two feature sets are combined according to ``concat_type``.
    """

    def __init__(
        self,
        high_res_cfg: Dict,
        low_res_cfg: Dict,
        freeze_high: bool = False,
        freeze_low: bool = False,
        concat_type: Literal["feature", "sequence", "add", "tuple"] = "tuple",
        **ignore_kwargs,
    ):
        """Build both towers and set which parameters are trainable.

        Args:
            high_res_cfg: Keyword arguments for the high-resolution tower.
            low_res_cfg: Keyword arguments for the low-resolution tower; must
                contain ``"image_size"``.
            freeze_high: Freeze the whole high-res tower and put it in eval
                mode. When false, only parameters whose name contains
                ``"downsamples"`` or ``"neck"`` stay trainable.
            freeze_low: Freeze the low-res tower and put it in eval mode. When
                false, its ``requires_grad`` flags are left as built.
            concat_type: How ``forward`` combines the two outputs.
            **ignore_kwargs: Accepted and ignored.
        """
        super().__init__()

        self.vision_tower_high = CLIPVisionTower(**high_res_cfg)
        self.vision_tower_low = CLIPVisionTower(**low_res_cfg)
        self.low_res_size = low_res_cfg["image_size"]
        self.concat_type = concat_type

        # NOTE(review): these LayerNorms are registered (so they appear in the
        # state dict) but are not applied in forward() below.
        self.high_layer_norm = nn.LayerNorm(high_res_cfg.get("output_dim", 1024))
        self.low_layer_norm = nn.LayerNorm(low_res_cfg.get("output_dim", 1024))

        if freeze_high:
            for p_name, p in self.vision_tower_high.named_parameters():
                p.requires_grad = False
            self.vision_tower_high = self.vision_tower_high.eval()
        else:
            # train downsamples and neck
            for p_name, p in self.vision_tower_high.named_parameters():
                if "downsamples" in p_name or "neck" in p_name:
                    p.requires_grad = True
                else:
                    p.requires_grad = False

        if freeze_low:
            for p in self.vision_tower_low.parameters():
                p.requires_grad = False
            self.vision_tower_low = self.vision_tower_low.eval()

        self.resize = torchvision.transforms.Resize(self.low_res_size, antialias=True)

    def forward(self, images: torch.Tensor):
        """

        Args:
            images (torch.Tensor): [bs, 3, H, W]

        Returns:
            res (torch.Tensor): [bs, t, c]; for ``concat_type == "tuple"`` a
            ``(high_res, low_res)`` pair is returned instead.

        Raises:
            ValueError: ``concat_type`` is not a supported mode.
        """

        # [bs, c, h, w]
        high_images = images

        # [bs, c, h_low, w_low]
        low_images = self.resize(images)

        # separately run two vision towers
        # run high_res vision tower
        high_res = self.vision_tower_high(high_images)
        # [bs, c, h, w] -> [bs, h*w, c]
        high_res = rearrange(high_res, "b c h w -> b (h w) c")
        # run low_res vision tower
        low_res = self.vision_tower_low(low_images)

        if self.concat_type == "feature":
            images_features = torch.cat([high_res, low_res], dim=-1)
        elif self.concat_type == "sequence":
            images_features = torch.cat([high_res, low_res], dim=1)
        elif self.concat_type == "add":
            images_features = high_res + low_res
        elif self.concat_type == "tuple":
            images_features = (high_res, low_res)

        else:
            raise ValueError(
                "Currently only support `feature`, `sequence`, `add` and `tuple` concat type."
            )

        return images_features


# Manual smoke test: builds both towers with empty checkpoint paths, runs a
# zero batch through them in bfloat16 on CUDA and prints the output shapes.
if __name__ == "__main__":
    image_size = 1024
    x = torch.zeros(2, 3, image_size, image_size).bfloat16().cuda()

    high_res_cfg = dict(
        model_name="sam_b_downsample",
        select_feature="same",
        image_size=image_size,
        pixel_mean=(0.48145466, 0.4578275, 0.40821073),
        pixel_std=(0.26862954, 0.26130258, 0.27577711),
        select_layer=-1,
        ckpt_path="",
    )

    low_res_cfg = dict(
        model_name="siglip_large_patch16_384",
        select_feature="same",
        image_size=384,
        pixel_mean=(0.5, 0.5, 0.5),
        pixel_std=(0.5, 0.5, 0.5),
        select_layer=-1,
        ckpt_path="",
    )

    net = (
        HybridVisionTower(
            high_res_cfg=high_res_cfg,
            low_res_cfg=low_res_cfg,
            freeze_high=True,
            freeze_low=True,
            concat_type="tuple",
        )
        .bfloat16()
        .cuda()
    )
    high_x, low_x = net(x)
    print(x.shape, high_x.shape, low_x.shape)
class VisionConfig(PretrainedConfig):
    """Configuration of the vision tower: a class name plus its constructor params."""

    model_type = "vision"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # ``cls`` may be given as a class object; only its name is stored.
        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = AttrDict(kwargs.get("params", {}))


class AlignerConfig(PretrainedConfig):
    """Configuration of the vision-to-text aligner: a class name plus its params."""

    model_type = "aligner"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # ``cls`` may be given as a class object; only its name is stored.
        self.cls = kwargs.get("cls", "")
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = AttrDict(kwargs.get("params", {}))


class UrsaConfig(PretrainedConfig):
    """Top-level URSA configuration combining vision, aligner and text configs.

    Args:
        vision_config: Vision tower config. When ``None`` a default
            ``HybridVisionTower`` config (SAM-B high-res + SigLIP-L low-res) is
            built. Any other value is stored as given.
        aligner_config: Stored as given.
        text_config: A config object, a dict (instantiated through
            ``CONFIG_MAPPING`` using its ``"model_type"``, defaulting to
            ``"llama"``; the default is written back into the passed dict), or
            ``None`` for a default llama config.
        ignore_index: Stored on the config.
        image_token_index: Token id that marks an image placeholder.
        projector_hidden_act: Stored on the config.
        vision_feature_select_strategy: ``"default"`` or ``"full"``.
        vision_feature_layer: Stored on the config.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: ``vision_feature_select_strategy`` is not supported.
    """

    model_type = "ursa"
    is_composition = False
    vision_config: VisionConfig
    aligner_config: AlignerConfig
    text_config: PretrainedConfig

    def __init__(
        self,
        vision_config=None,
        aligner_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_index=32000,
        projector_hidden_act="gelu",
        vision_feature_select_strategy="default",
        vision_feature_layer=-2,
        **kwargs,
    ):
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.projector_hidden_act = projector_hidden_act

        if vision_feature_select_strategy not in ["default", "full"]:
            # The two literals are concatenated; the trailing space keeps the
            # message from reading "...'full'.Got: x".
            raise ValueError(
                "vision_feature_select_strategy should be one of 'default', 'full'. "
                f"Got: {vision_feature_select_strategy}"
            )

        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.vision_feature_layer = vision_feature_layer

        if vision_config is None:
            vision_config = VisionConfig()
            vision_config.cls = "HybridVisionTower"
            vision_config.params = {
                "concat_type": "tuple",
                "high_res_cfg": {
                    "ckpt_path": "",
                    "image_size": 1024,
                    "model_name": "sam_b_downsample",
                    "output_dim": 1024,
                    "pixel_mean": [0.48145466, 0.4578275, 0.40821073],
                    "pixel_std": [0.26862954, 0.26130258, 0.27577711],
                    "select_feature": "same",
                    "select_layer": -1,
                },
                "low_res_cfg": {
                    "ckpt_path": "",
                    "image_size": 384,
                    "model_name": "siglip_large_patch16_384",
                    "output_dim": 1024,
                    "pixel_mean": [0.5, 0.5, 0.5],
                    "pixel_std": [0.5, 0.5, 0.5],
                    "select_feature": "same",
                    "select_layer": -1,
                },
            }
        self.vision_config = vision_config
        self.aligner_config = aligner_config

        if isinstance(text_config, dict):
            text_config.setdefault("model_type", "llama")
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["llama"]()

        self.text_config = text_config

        # Called last, after the sub-configs have been attached.
        super().__init__(**kwargs)
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
logger = logging.get_logger(__name__)

ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
# Per-channel normalisation constants (RGB order).
IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)


def expand2square(pil_img, background_color):
    """Pad a PIL image to a square canvas, centring it on ``background_color``.

    A square input is returned unchanged (same object); otherwise a new image
    whose side equals the longer edge is returned.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


class VLMImageProcessorConfig(PretrainedConfig):
    """Config holding the settings of ``VLMImageProcessor``.

    It is the key under which the processor is registered with
    ``AutoImageProcessor`` at the bottom of this module.
    """

    model_type = "deepseek_vlm"
    image_size: int
    min_size: int
    image_mean: Union[Tuple[float, float, float], List[float]]
    image_std: Union[Tuple[float, float, float], List[float]]
    rescale_factor: float
    do_normalize: bool

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        self.image_size = image_size
        self.min_size = min_size
        self.image_mean = image_mean
        self.image_std = image_std
        self.rescale_factor = rescale_factor
        self.do_normalize = do_normalize

        super().__init__(**kwargs)


class VLMImageProcessor(BaseImageProcessor):
    """Resize, square-pad, rescale and (optionally) normalise PIL images.

    Output is a ``BatchFeature`` with a single ``"pixel_values"`` entry.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        image_size: int,
        min_size: int = 14,
        image_mean: Union[Tuple[float, float, float], List[float]] = (
            0.48145466,
            0.4578275,
            0.40821073,
        ),
        image_std: Union[Tuple[float, float, float], List[float]] = (
            0.26862954,
            0.26130258,
            0.27577711,
        ),
        rescale_factor: float = 1.0 / 255.0,
        do_normalize: bool = True,
        **kwargs,
    ):
        """Store the preprocessing settings.

        Args:
            image_size: Target side length of the square output.
            min_size: Lower bound for either side after the aspect-preserving
                resize.
            image_mean, image_std: Per-channel normalisation statistics.
            rescale_factor: Multiplier applied to raw pixel values.
            do_normalize: Whether ``preprocess`` applies mean/std normalisation.
            **kwargs: Forwarded to ``BaseImageProcessor.__init__``.
        """
        super().__init__(**kwargs)

        self.image_size = image_size
        self.rescale_factor = rescale_factor
        self.image_mean = image_mean
        self.image_std = image_std
        self.min_size = min_size
        self.do_normalize = do_normalize

        # Padding colour for expand2square: the mean colour mapped back to the
        # 0-255 range, or mid-grey when no mean is given.
        if image_mean is None:
            self.background_color = (127, 127, 127)
        else:
            self.background_color = tuple([int(x * 255) for x in image_mean])

    def resize(self, pil_img: Image.Image) -> np.ndarray:
        """

        Args:
            pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB

        Returns:
            x (np.ndarray): [3, self.image_size, self.image_size]

        Raises:
            ValueError: the input or the computed target size has a
                non-positive side.
        """

        width, height = pil_img.size
        max_size = max(width, height)

        # Scale so the longer side becomes image_size, keeping the aspect
        # ratio; each side is clamped to at least min_size. Order is (h, w).
        size = [
            max(int(height / max_size * self.image_size), self.min_size),
            max(int(width / max_size * self.image_size), self.min_size),
        ]

        if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
            print(f"orig size = {pil_img.size}, new size = {size}")
            raise ValueError("Invalid size!")

        pil_img = torchvision.transforms.functional.resize(
            pil_img,
            size,
            interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC,
            antialias=True,
        )

        pil_img = expand2square(pil_img, self.background_color)
        x = to_numpy_array(pil_img)

        # [H, W, 3] -> [3, H, W]
        x = np.transpose(x, (2, 0, 1))

        return x

    def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
        """Run the full pipeline over an iterable of PIL images.

        Args:
            images: Iterable of PIL images; each one goes through ``resize``.
            return_tensors: Tensor type passed to ``BatchFeature``.
            **kwargs: Accepted and unused.

        Returns:
            ``BatchFeature`` with key ``"pixel_values"``.
        """
        # resize and pad to [self.image_size, self.image_size]
        # then convert from [H, W, 3] to [3, H, W]
        images: List[np.ndarray] = [self.resize(image) for image in images]

        # rescale from [0, 255] -> [0, 1]
        images = [
            self.rescale(
                image=image,
                scale=self.rescale_factor,
                input_data_format="channels_first",
            )
            for image in images
        ]

        # normalize
        if self.do_normalize:
            images = [
                self.normalize(
                    image=image,
                    mean=self.image_mean,
                    std=self.image_std,
                    input_data_format="channels_first",
                )
                for image in images
            ]

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)

    @property
    def default_shape(self):
        """Shape ``[3, image_size, image_size]`` of one processed image."""
        return [3, self.image_size, self.image_size]


# Import-time side effect: makes AutoImageProcessor resolve
# VLMImageProcessorConfig to VLMImageProcessor.
AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor)


# Manual smoke test: only checks that the processor can be constructed.
if __name__ == "__main__":
    image_processor = VLMImageProcessor(
        image_size=1024,
        image_mean=IMAGENET_INCEPTION_MEAN,
        image_std=IMAGENET_INCEPTION_STD,
        do_normalize=True,
    )
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from transformers import (
    PreTrainedModel,
    AutoModel,
    AutoModelForCausalLM,
)
from transformers.generation import GenerationMixin
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from transformers.modeling_outputs import ModelOutput
from .configuration_ursa import UrsaConfig, AlignerConfig, VisionConfig
from .clip_encoder import CLIPVisionTower, HybridVisionTower
from .projector import MlpProjector


def _get_module_float_dtype(module) -> Optional[torch.dtype]:
    """Return the floating-point dtype ``module`` computes in, if one can be found.

    A ``dtype`` attribute on the module is trusted only when it is a
    floating-point dtype; otherwise the first floating-point parameter decides.
    Returns ``None`` when neither source yields a floating-point dtype.
    """
    module_dtype = getattr(module, "dtype", None)
    # An integer dtype (e.g. quantised storage) must never become the target of
    # a pixel-value cast, so only accept floating-point dtypes here.
    if isinstance(module_dtype, torch.dtype) and module_dtype.is_floating_point:
        return module_dtype
    for parameter in module.parameters():
        if torch.is_floating_point(parameter):
            return parameter.dtype
    return None


def _cast_pixel_values_to_vision_dtype(pixel_values: Optional[torch.Tensor], vision_model):
    """Cast floating-point ``pixel_values`` to the vision tower's float dtype.

    ``None``, non-tensor and non-floating inputs are returned untouched, as are
    tensors already in the right dtype or towers with no detectable float dtype.
    """
    if pixel_values is None or not torch.is_tensor(pixel_values) or not torch.is_floating_point(pixel_values):
        return pixel_values
    vision_dtype = _get_module_float_dtype(vision_model)
    if vision_dtype is None or pixel_values.dtype == vision_dtype:
        return pixel_values
    return pixel_values.to(dtype=vision_dtype)


@dataclass
class UrsaCausalLMOutputWithPast(ModelOutput):
    """Output container shared by both URSA heads."""

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Labels as seen by the language model, i.e. after image-token expansion
    # when images were merged in.
    labels: Optional[Tuple[torch.FloatTensor]] = None


class UrsaPreTrainedModel(PreTrainedModel):
    """Base class wiring ``UrsaConfig`` into the Hugging Face loading machinery."""

    config_class = UrsaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["qwen2vlmVisionAttention"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_cache_class = True

    def _init_weights(self, module):
        """Initialise ``module`` with N(0, std); zero biases and the padding row."""
        std = (
            self.config.initializer_range
            if hasattr(self.config, "initializer_range")
            else self.config.text_config.initializer_range
        )
        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @property
    def _supports_sdpa(self):
        """
        Retrieve language_model's attribute to check whether the model supports
        SDPA or not.
        """
        language_model = getattr(self, "language_model", None)
        if language_model is None:
            return False
        return getattr(language_model, "_supports_sdpa", False)


class _UrsaBackboneMixin:
    """Plumbing shared by the URSA heads.

    Holds everything that was previously duplicated verbatim between
    ``UrsaForConditionalGeneration`` and ``UrsaForTokenClassification``:
    backbone construction, embedding accessors, image/text merging and
    generation-input preparation. Must precede ``UrsaPreTrainedModel`` in the
    bases so these overrides win in the MRO.
    """

    def _build_backbone(self, config: UrsaConfig) -> None:
        """Instantiate the vision tower, aligner and language model from ``config``."""
        self.vision_model = HybridVisionTower(**config.vision_config["params"])
        aligner_config = AlignerConfig(**config.aligner_config)
        self.aligner = MlpProjector(aligner_config.params)
        self.vocab_size = config.text_config.vocab_size
        self.language_model = AutoModelForCausalLM.from_config(
            config.text_config, attn_implementation=config._attn_implementation
        )
        # -1 can never match a real token id, so "no pad token" disables the
        # padding-related masks below.
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def tie_weights(self):
        return self.language_model.tie_weights()

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
        """Resize the language model's embeddings and keep the cached vocab sizes in sync."""
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
        """Expand every image placeholder token into its image patch embeddings.

        Args:
            image_features: ``(num_images, num_image_patches, embed_dim)`` tensor.
            inputs_embeds: Text embeddings looked up from ``input_ids``.
            input_ids: ``(batch_size, sequence_length)`` token ids containing
                ``config.image_token_index`` placeholders.
            attention_mask: Mask aligned with ``input_ids``.
            labels: Optional labels aligned with ``input_ids``.

        Returns:
            ``(final_embedding, final_attention_mask, final_labels, position_ids)``
            in the expanded sequence layout; ``final_labels`` is ``None`` when
            ``labels`` is ``None``.

        Raises:
            ValueError: If the number of free slots does not match the number
                of image feature vectors.
        """
        num_images, num_image_patches, embed_dim = image_features.shape
        batch_size, sequence_length = input_ids.shape
        # Left padding is assumed whenever no sequence ends with the pad token.
        left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))

        # 1. Create a mask to know where special image tokens are
        special_image_token_mask = input_ids == self.config.image_token_index
        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
        # Compute the maximum embed dimension
        max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
        batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

        # 2. Compute the positions where text should be written.
        # Each image token occupies `num_image_patches` slots in the merged
        # sequence, i.e. it shifts every later token by `num_image_patches - 1`.
        # `cumsum` yields one-based end positions, hence the trailing `- 1`.
        new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
        nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_image_pad[:, None]  # offset for left padding
        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )
        final_attention_mask = torch.zeros(
            batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
        )
        final_labels = None
        if labels is not None:
            final_labels = torch.full(
                (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
            )
        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
        # set the corresponding tensors into their correct target device.
        target_device = inputs_embeds.device
        batch_indices, non_image_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_image_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )
        attention_mask = attention_mask.to(target_device)

        # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
        if labels is not None:
            final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

        # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
        image_to_overwrite = torch.full(
            (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
        )
        image_to_overwrite[batch_indices, text_to_overwrite] = False
        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
            raise ValueError(
                f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
                f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
            )

        final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
        final_attention_mask |= image_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
        indices_to_mask = new_token_positions[batch_indices, pad_indices]

        final_embedding[batch_indices, indices_to_mask] = 0

        return final_embedding, final_attention_mask, final_labels, position_ids

    def _prepare_multimodal_inputs(
        self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, inputs_embeds, labels
    ):
        """Build the language-model inputs for one forward pass.

        Caller-supplied ``inputs_embeds`` are passed through untouched.
        Otherwise ``input_ids`` are embedded and either merged with the image
        features (multi-token input with images) or, for a single-token step
        with a cache and images, given an attention mask rebuilt from the cache.

        Returns:
            ``(inputs_embeds, attention_mask, position_ids, labels)``.
        """
        if inputs_embeds is not None:
            return inputs_embeds, attention_mask, position_ids, labels

        # 1. Extract the input embeddings
        inputs_embeds = self.get_input_embeddings()(input_ids)
        single_token_step = input_ids.shape[1] == 1

        # 2. Merge text and images
        if pixel_values is not None and not single_token_step:
            pixel_values = _cast_pixel_values_to_vision_dtype(pixel_values, self.vision_model)
            image_outputs = self.vision_model(pixel_values)
            image_features = self.aligner(image_outputs)
            inputs_embeds = inputs_embeds.to(image_features.dtype)
            inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
                image_features, inputs_embeds, input_ids, attention_mask, labels
            )

        # Single new token with a cache and images present: generation with cache.
        elif past_key_values is not None and pixel_values is not None and single_token_step:
            # Retrieve the first layer to inspect the logits and mask out the hidden states
            # that are set to 0
            first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

            # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
            batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

            target_length = input_ids.shape[1]
            past_length = first_layer_past_key_value.shape[-1]

            extended_attention_mask = torch.ones(
                (attention_mask.shape[0], past_length),
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )

            # Filter out only the tokens that can be un-attended, this can happen
            # if one uses qwen2vlm + Fused modules where the cache on the
            # first iteration is already big enough, or if one passes custom cache
            valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
            new_batch_index = batch_index[valid_indices]
            new_non_attended_tokens = non_attended_tokens[valid_indices]

            # Zero-out the places where we don't need to attend
            extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

            attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
            position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1

        return inputs_embeds, attention_mask, position_ids, labels

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
    ):
        """Trim ``input_ids``/``attention_mask`` to the unprocessed suffix and assemble ``forward`` kwargs."""
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                # The attribute name differs between cache implementations.
                if hasattr(past_key_values, "seen_tokens"):
                    past_length = past_key_values.seen_tokens
                elif hasattr(past_key_values, "_seen_tokens"):
                    past_length = past_key_values._seen_tokens
                else:
                    past_length = cache_length
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
            elif self.config.image_token_index in input_ids:
                input_ids = input_ids[:, input_ids.shape[1] - 1 :]
            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
            # older attention values, as their corresponding values are not part of the input.
            if cache_length < past_length and attention_mask is not None:
                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
            }
        )
        return model_inputs

    def _reorder_cache(self, *args, **kwargs):
        return self.language_model._reorder_cache(*args, **kwargs)


class UrsaForConditionalGeneration(_UrsaBackboneMixin, UrsaPreTrainedModel, GenerationMixin):
    """URSA vision-language model with the language model's causal-LM head."""

    def __init__(self, config: UrsaConfig):
        super().__init__(config)
        self._build_backbone(config)
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, UrsaCausalLMOutputWithPast]:
        """Run the vision tower (when images are given) and the language model.

        ``vision_feature_layer`` and ``vision_feature_select_strategy`` are kept
        for signature compatibility only; the aligner consumes the vision
        tower's output directly, so neither is used.

        Returns:
            ``UrsaCausalLMOutputWithPast``, or a tuple when ``return_dict`` is
            false. ``loss`` is the cross-entropy over attended, shifted
            positions and is only computed when ``labels`` is given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        inputs_embeds, attention_mask, position_ids, labels = self._prepare_multimodal_inputs(
            input_ids, pixel_values, attention_mask, position_ids, past_key_values, inputs_embeds, labels
        )

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = outputs[0]

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                shift_attention_mask = attention_mask[..., 1:]
                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return UrsaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class UrsaForTokenClassification(_UrsaBackboneMixin, UrsaPreTrainedModel):
    """URSA backbone with a one-logit linear head applied at every position.

    NOTE(review): presumably the class behind the URSA reward-model
    checkpoint -- confirm against the loading code.
    """

    def __init__(self, config: UrsaConfig):
        super().__init__(config)
        self._build_backbone(config)

        self.dropout = nn.Dropout(0.1)
        self.score = nn.Linear(self.language_model.config.hidden_size, 1)
        self.post_init()

    def _init_weights(self, module):
        """Xavier-initialise linear weights and zero their biases."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            # Bias-free linear layers have `bias is None`; `nn.init.zeros_`
            # would fail on them.
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, UrsaCausalLMOutputWithPast]:
        """Score every position of the (image-expanded) sequence.

        ``vision_feature_layer`` and ``vision_feature_select_strategy`` are kept
        for signature compatibility only and are not used. No loss is computed
        here; ``labels`` are only realigned to the expanded sequence and
        returned so the caller can compute its own loss.

        Returns:
            ``UrsaCausalLMOutputWithPast`` whose ``logits`` are the ``score``
            head outputs (last dimension 1), or a tuple starting with those
            logits when ``return_dict`` is false.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        inputs_embeds, attention_mask, position_ids, labels = self._prepare_multimodal_inputs(
            input_ids, pixel_values, attention_mask, position_ids, past_key_values, inputs_embeds, labels
        )

        # Always request a ModelOutput with hidden states: the head reads
        # `outputs.hidden_states`, which a tuple return would not provide.
        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            return_dict=True,
            output_hidden_states=True,
        )

        last_hidden_state = outputs.hidden_states[-1]
        logits = self.score(self.dropout(last_hidden_state))

        if not return_dict:
            # Replace the language-model logits with the head's scores.
            return (logits,) + outputs.to_tuple()[1:]

        return UrsaCausalLMOutputWithPast(
            loss=None,
            logits=logits,
            hidden_states=outputs.hidden_states,
            labels=labels,
        )
+
+from typing import List, Optional, Union
+
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import ImageInput
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+from transformers.utils import TensorType
+
+
+class UrsaProcessor(ProcessorMixin):
+    """Processor that pairs an image processor with a tokenizer.
+
+    Calling the processor rewrites image placeholders in the text to the
+    ``<|image|>`` token, tokenizes the text, runs the image processor on any
+    images, and merges both results into a single ``BatchFeature``.
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    valid_kwargs = ["chat_template"]
+    image_processor_class = "AutoImageProcessor"
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
+        super().__init__(image_processor, tokenizer, chat_template=chat_template, **kwargs)
+
+    @staticmethod
+    def _normalize_image_placeholders(text):
+        """Rewrite image placeholders in ``text`` to the ``<|image|>`` token.
+
+        Strings are rewritten directly; lists and tuples are processed
+        element by element (recursively) and keep their container type.
+        Any other value, including ``None``, is returned unchanged.
+        """
+        if isinstance(text, str):
+            # NOTE(review): the literal being replaced was an empty string, and
+            # str.replace("", tok) inserts tok before every character and at the
+            # end of the text. "<image>" is the presumed placeholder - confirm
+            # against the prompt format produced by the dataset code.
+            return text.replace("<image>", "<|image|>")
+        if isinstance(text, list):
+            return [UrsaProcessor._normalize_image_placeholders(item) for item in text]
+        if isinstance(text, tuple):
+            return tuple(UrsaProcessor._normalize_image_placeholders(item) for item in text)
+        return text
+
+    def __call__(
+        self,
+        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+        images: ImageInput = None,
+        videos: ImageInput = None,
+        padding: Union[bool, str, PaddingStrategy] = False,
+        truncation: Union[bool, str, TruncationStrategy] = None,
+        max_length=None,
+        add_special_tokens: bool = True,
+        return_tensors: Optional[Union[str, TensorType]] = None,  # or TensorType.PYTORCH
+        **kwargs,
+    ) -> BatchFeature:
+        """Tokenize ``text`` and preprocess ``images`` into one ``BatchFeature``.
+
+        Args:
+            text: prompt or prompts; placeholders are normalized before tokenizing.
+            images: optional images forwarded to the image processor.
+            videos: must be ``None``; video input is rejected.
+            padding, truncation, max_length, add_special_tokens: forwarded to
+                the tokenizer unchanged.
+            return_tensors: forwarded to both the tokenizer and the image processor.
+            **kwargs: extra keyword arguments, forwarded to the tokenizer only.
+
+        Returns:
+            A ``BatchFeature`` holding the tokenizer outputs and, when images
+            were given, the image-processor outputs. On a key collision the
+            image-processor value wins because it is merged last.
+
+        Raises:
+            ValueError: if ``videos`` is not ``None``.
+        """
+        if videos is not None:
+            raise ValueError("UrsaProcessor does not support video inputs.")
+        image_inputs = {}
+        if images is not None:
+            image_inputs = self.image_processor(images, return_tensors=return_tensors)
+        text = self._normalize_image_placeholders(text)
+        text_inputs = self.tokenizer(
+            text,
+            return_tensors=return_tensors,
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            add_special_tokens=add_special_tokens,
+            **kwargs,
+        )
+        return BatchFeature(data={**text_inputs, **image_inputs})
+
+    def decode(self, *args, **kwargs):
+        """Forward to ``self.tokenizer.decode``."""
+        return self.tokenizer.decode(*args, **kwargs)
+
+    def batch_decode(self, *args, **kwargs):
+        """Forward to ``self.tokenizer.batch_decode``."""
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    @property
+    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
+    def model_input_names(self):
+        """Tokenizer input names followed by image-processor input names, de-duplicated in order."""
+        tokenizer_input_names = self.tokenizer.model_input_names
+        image_processor_input_names = self.image_processor.model_input_names
+        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
diff --git a/examples/math_prm/ursa_model/projector.py b/examples/math_prm/ursa_model/projector.py
new file mode 100644
index 00000000..13a2c36e
--- /dev/null
+++ b/examples/math_prm/ursa_model/projector.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2023-2024 DeepSeek.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of
+# this software and associated documentation files (the "Software"), to deal in
+# the Software without restriction, including without limitation the rights to
+# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+# the Software, and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from .attrdict_compat import AttrDict
+
+
+class MlpProjector(nn.Module):
+    """Project vision-encoder features to width ``cfg.n_embed``.
+
+    ``cfg.projector_type`` selects the layout built in ``__init__``:
+    ``"identity"``, ``"linear"``, ``"mlp_gelu"`` or
+    ``"low_high_hybrid_split_mlp_gelu"``.
+    """
+
+    def __init__(self, cfg):
+        """Build the projection layers described by ``cfg``.
+
+        Args:
+            cfg: config object read through attribute access
+                (``projector_type``; ``input_dim`` and ``n_embed`` for every
+                type except ``"identity"``) and through ``cfg.get("depth", 1)``
+                for the two MLP variants.
+
+        Raises:
+            ValueError: if ``cfg.projector_type`` is not a known type.
+        """
+        super().__init__()
+
+        self.cfg = cfg
+
+        if cfg.projector_type == "identity":
+            modules = nn.Identity()
+
+        elif cfg.projector_type == "linear":
+            modules = nn.Linear(cfg.input_dim, cfg.n_embed)
+
+        elif cfg.projector_type == "mlp_gelu":
+            # One input Linear, then (depth - 1) GELU + Linear pairs.
+            mlp_depth = cfg.get("depth", 1)
+            modules = [nn.Linear(cfg.input_dim, cfg.n_embed)]
+            for _ in range(1, mlp_depth):
+                modules.append(nn.GELU())
+                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
+            modules = nn.Sequential(*modules)
+
+        elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
+            mlp_depth = cfg.get("depth", 1)
+            # Each branch is projected to half of n_embed; forward() concatenates
+            # the two halves on the last dimension before applying self.layers.
+            self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
+            self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2)
+
+            # With depth == 1 the loop adds nothing and the Sequential is empty.
+            modules = []
+            for _ in range(1, mlp_depth):
+                modules.append(nn.GELU())
+                modules.append(nn.Linear(cfg.n_embed, cfg.n_embed))
+            modules = nn.Sequential(*modules)
+
+        else:
+            raise ValueError(f"Unknown projector type: {cfg.projector_type}")
+
+        self.layers = modules
+
+    def forward(
+        self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
+    ):
+        """Apply the projector to single-encoder or hybrid-encoder features.
+
+        Args:
+            x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]): if it is a tuple of
+                torch.Tensor, then it comes from the hybrid vision encoder (x = (high_res_x, low_res_x))
+                and each element goes through its own up-projection before concatenation;
+                otherwise it is the feature from the single vision encoder.
+
+        Returns:
+            x (torch.Tensor): [b, s, c]
+        """
+
+        if isinstance(x_or_tuple, tuple):
+            # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu":
+            # high_up_proj / low_up_proj only exist for that projector type.
+            high_x, low_x = x_or_tuple
+            high_x = self.high_up_proj(high_x)
+            low_x = self.low_up_proj(low_x)
+            x = torch.concat([high_x, low_x], dim=-1)
+        else:
+            x = x_or_tuple
+
+        return self.layers(x)
+
+
+if __name__ == "__main__":
+    # Smoke test: hybrid projector on two random feature tensors.
+    cfg = AttrDict(
+        input_dim=1024,
+        n_embed=2048,
+        depth=2,
+        projector_type="low_high_hybrid_split_mlp_gelu",
+    )
+    inputs = (torch.rand(4, 576, 1024), torch.rand(4, 576, 1024))
+
+    m = MlpProjector(cfg)
+    out = m(inputs)
+    print(out.shape)
diff --git a/examples/math_prm/ursa_model/sam.py b/examples/math_prm/ursa_model/sam.py
new file mode 100644
index 00000000..31159a94
--- /dev/null
+++ b/examples/math_prm/ursa_model/sam.py
@@ -0,0 +1,593 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+ +import copy +from dataclasses import dataclass +from functools import partial +from typing import List, Optional, Tuple, Type, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + 
global_attn_indexes: Tuple[int, ...] = (), + downsample_channels: Tuple[int, ...] = (512, 1024), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + downsample_channels (list): Channels for downsampling layers. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. 
+ self.pos_embed = nn.Parameter( + torch.zeros( + 1, img_size // patch_size, img_size // patch_size, embed_dim + ) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + in_channels = out_chans + downsamples = [] + for i in range(len(downsample_channels)): + out_channels = downsample_channels[i] + downsamples.append( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + bias=False, + ) + ) + in_channels = out_channels + self.downsamples = nn.Sequential(*downsamples) + + self.sam_hd = True + if self.sam_hd: + self.hd_alpha_downsamples = nn.Parameter(torch.zeros(1)) + # self.neck_hd = nn.Linear(embed_dim, embed_dim) + self.neck_hd = copy.deepcopy(self.neck) + # self.downsamples_hd = copy.deepcopy(self.downsamples) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + global_features = [] + for i, blk in enumerate(self.blocks): + x = blk(x) + if self.sam_hd and blk.window_size == 0: + global_features.append(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + x_dtype = x.dtype + x = F.interpolate( + x.float(), size=(96, 96), mode="bilinear", align_corners=False + ).to(x_dtype) + x = self.downsamples(x) + + if self.sam_hd: + first_global_feature = self.neck_hd(global_features[0].permute(0, 3, 1, 2)) + x_dtype = 
first_global_feature.dtype + first_global_feature = F.interpolate( + first_global_feature.float(), + size=(96, 96), + mode="bilinear", + align_corners=False, + ) + first_global_feature = self.downsamples(first_global_feature.to(x_dtype)) + x = x + first_global_feature * self.hd_alpha_downsamples + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. 
+ """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock( + embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer + ) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. 
+ """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = ( + self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + ) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + def do_attention(q, k, v): + attn = (q * self.scale) @ k.transpose(-2, -1) + if self.use_rel_pos: + attn = add_decomposed_rel_pos( + attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W) + ) + + attn = attn.softmax(dim=-1) + x = ( + (attn @ v) + .view(B, self.num_heads, H, W, -1) + .permute(0, 2, 3, 1, 4) + .reshape(B, H, W, -1) + ) + + return x + + # from haiscale.utils import on_demand_checkpoint + # x = on_demand_checkpoint(do_attention, q, k, v) + x = do_attention(q, k, v) + x = self.proj(x) + + return x + + +def window_partition( + x: torch.Tensor, window_size: int +) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
+ (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, + window_size: int, + pad_hw: Tuple[int, int], + hw: Tuple[int, int], +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. 
+ rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. 
+ """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + + rel_h[:, :, :, :, None] + + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x + + +@dataclass +class SAMViTCfg: + image_size: Union[Tuple[int, int], int] = 1024 + width: int = 1024 + layers: int = 23 + heads: int = 16 + patch_size: int = 16 + window_size: int = 14 + prompt_embed_dim: int = 256 + global_attn_indexes: Union[List[int], Tuple[int]] = (5, 11, 17, 23) + downsample_channels: Union[List[int], Tuple[int]] = (512, 1024) + + +SAM_MODEL_CONFIG = { + "sam_vit_b": { + "width": 768, + "layers": 12, + "heads": 12, + "global_attn_indexes": [2, 5, 8, 11], + "downsample_channels": (), + }, + "sam_b_downsample": { + "width": 768, + "layers": 12, + "heads": 12, + "global_attn_indexes": [2, 5, 8, 11], + "downsample_channels": (512, 1024), + }, + "sam_vit_l": { + "width": 1024, + "layers": 24, + "heads": 
16,
+        "global_attn_indexes": [5, 11, 17, 23],
+        "downsample_channels": (),
+    },
+    "sam_vit_h": {
+        "width": 1280,
+        "layers": 32,
+        "heads": 16,
+        "global_attn_indexes": [7, 15, 23, 31],
+        "downsample_channels": (),
+    },
+}
+
+
+def create_sam_vit(
+    model_name: str = "sam_b_downsample",
+    image_size: int = 1024,
+    ckpt_path: str = "",
+    **kwargs,
+):
+    """Build an ``ImageEncoderViT`` from a named entry of ``SAM_MODEL_CONFIG``.
+
+    Args:
+        model_name: key of ``SAM_MODEL_CONFIG`` selecting width, depth, heads,
+            global-attention block indexes and downsample channels.
+        image_size: passed to the encoder as ``img_size``.
+        ckpt_path: optional path to a state dict; when non-empty it is loaded
+            non-strictly into the encoder.
+        **kwargs: accepted for call-site compatibility; not used.
+
+    Returns:
+        The constructed ``ImageEncoderViT``.
+
+    Raises:
+        AssertionError: if ``model_name`` is not a key of ``SAM_MODEL_CONFIG``.
+    """
+    assert (
+        model_name in SAM_MODEL_CONFIG.keys()
+    ), f"model name: {model_name} should be in {SAM_MODEL_CONFIG.keys()}"
+
+    sam_cfg = SAMViTCfg(**SAM_MODEL_CONFIG[model_name])
+    image_encoder = ImageEncoderViT(
+        depth=sam_cfg.layers,
+        embed_dim=sam_cfg.width,
+        img_size=image_size,
+        mlp_ratio=4,
+        norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+        num_heads=sam_cfg.heads,
+        patch_size=sam_cfg.patch_size,
+        qkv_bias=True,
+        use_rel_pos=True,
+        global_attn_indexes=sam_cfg.global_attn_indexes,
+        # Read from the config instead of repeating the literal; SAMViTCfg
+        # defaults window_size to 14 and no SAM_MODEL_CONFIG entry overrides it.
+        window_size=sam_cfg.window_size,
+        out_chans=sam_cfg.prompt_embed_dim,
+        downsample_channels=sam_cfg.downsample_channels,
+    )
+
+    if ckpt_path:
+        # map_location="cpu" lets a checkpoint saved on a GPU load on any host;
+        # load_state_dict then copies the values into the encoder's parameters.
+        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
+        state_dict = torch.load(ckpt_path, map_location="cpu")
+        image_encoder.load_state_dict(state_dict, strict=False)
+        print(f"SAM-ViT restores from {ckpt_path}")
+
+    return image_encoder
+
+
+if __name__ == "__main__":
+    x = torch.zeros(2, 3, 1024, 1024).bfloat16()
+    # x.permute(0, 3, 1, 2)
+    net = create_sam_vit().bfloat16()
+    out = net(x)
+    print(x.shape, out.shape)
diff --git a/examples/math_prm/ursa_model/siglip_vit.py b/examples/math_prm/ursa_model/siglip_vit.py
new file mode 100644
index 00000000..a93707ba
--- /dev/null
+++ b/examples/math_prm/ursa_model/siglip_vit.py
@@ -0,0 +1,681 @@
+# Copyright (c) 2023-2024 DeepSeek.
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +# the Software, and to permit persons to whom the Software is furnished to do so, +# subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ +# https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py +import math +import warnings +from dataclasses import dataclass +from functools import partial +from typing import ( + Callable, + Dict, + Final, + List, + Literal, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, +) + +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.layers import ( + AttentionPoolLatent, + DropPath, + LayerType, + Mlp, + PatchDropout, + PatchEmbed, + resample_abs_pos_embed, +) +from timm.models._manipulate import checkpoint_seq, named_apply + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) # noqa: E741 + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. 
+ tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (torch.Tensor, float, float, float, float) -> torch.Tensor + r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first + convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its orignal dtype. + Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn + from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + + with torch.no_grad(): + dtype = tensor.dtype + tensor_fp32 = tensor.float() + tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b) + tensor_dtype = tensor_fp32.to(dtype=dtype) + tensor.copy_(tensor_dtype) + + +def init_weights(self): + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) + trunc_normal_(self.latent, std=self.latent_dim**-0.5) + + +def init_weights_vit_timm(module: nn.Module, name: str = "") -> None: + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, "init_weights"): + module.init_weights() + + +class Attention(nn.Module): + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + # self.fused_attn = use_fused_attn() + self.fused_attn = True + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = 
x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: nn.Module = Mlp, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) 
if drop_path > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class VisionTransformer(nn.Module): + """Vision Transformer + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + + dynamic_img_size: Final[bool] + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: Literal["", "avg", "token", "map"] = "token", + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + qk_norm: bool = False, + init_values: Optional[float] = None, + class_token: bool = True, + no_embed_class: bool = False, + reg_tokens: int = 0, + pre_norm: bool = False, + fc_norm: Optional[bool] = None, + dynamic_img_size: bool = False, + dynamic_img_pad: bool = False, + drop_rate: float = 0.0, + pos_drop_rate: float = 0.0, + patch_drop_rate: float = 0.0, + proj_drop_rate: float = 0.0, + attn_drop_rate: float = 0.0, + drop_path_rate: float = 0.0, + weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "", + embed_layer: Callable = PatchEmbed, + norm_layer: Optional[LayerType] = None, + act_layer: Optional[LayerType] = None, + block_fn: Type[nn.Module] = Block, + mlp_layer: Type[nn.Module] = Mlp, + ignore_head: bool = False, + ) -> None: + """ + Args: + img_size: Input image size. + patch_size: Patch size. + in_chans: Number of image input channels. + num_classes: Number of classes for classification head. + global_pool: Type of global pooling for final sequence (default: 'token'). + embed_dim: Transformer embedding dimension. + depth: Depth of transformer. + num_heads: Number of attention heads. + mlp_ratio: Ratio of mlp hidden dim to embedding dim.
+ qkv_bias: Enable bias for qkv projections if True. + init_values: Layer-scale init values (layer-scale enabled if not None). + class_token: Use class token. + no_embed_class: Don't include position embeddings for class (or reg) tokens. + reg_tokens: Number of register tokens. + fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'. + drop_rate: Head dropout rate. + pos_drop_rate: Position embedding dropout rate. + attn_drop_rate: Attention dropout rate. + drop_path_rate: Stochastic depth rate. + weight_init: Weight initialization scheme. + embed_layer: Patch embedding layer. + norm_layer: Normalization layer. + act_layer: MLP activation layer. + block_fn: Transformer block layer. + """ + super().__init__() + assert global_pool in ("", "avg", "token", "map") + assert class_token or global_pool != "token" + use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm + # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) + # act_layer = get_act_layer(act_layer) or nn.GELU + norm_layer = partial(nn.LayerNorm, eps=1e-6) + act_layer = nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = ( + embed_dim # num_features for consistency with other models + ) + self.num_prefix_tokens = 1 if class_token else 0 + self.num_prefix_tokens += reg_tokens + self.num_reg_tokens = reg_tokens + self.has_class_token = class_token + self.no_embed_class = ( + no_embed_class # don't embed prefix positions (includes reg) + ) + self.dynamic_img_size = dynamic_img_size + self.grad_checkpointing = False + self.ignore_head = ignore_head + + embed_args = {} + if dynamic_img_size: + # flatten deferred until after pos embed + embed_args.update(dict(strict_img_size=False, output_fmt="NHWC")) + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is 
used (e.g. CLIP) + dynamic_img_pad=dynamic_img_pad, + **embed_args, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = ( + nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + ) + self.reg_token = ( + nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None + ) + embed_len = ( + num_patches if no_embed_class else num_patches + self.num_prefix_tokens + ) + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02) + self.pos_drop = nn.Dropout(p=pos_drop_rate) + if patch_drop_rate > 0: + self.patch_drop = PatchDropout( + patch_drop_rate, + num_prefix_tokens=self.num_prefix_tokens, + ) + else: + self.patch_drop = nn.Identity() + self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + self.blocks = nn.Sequential( + *[ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + init_values=init_values, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + mlp_layer=mlp_layer, + ) + for i in range(depth) + ] + ) + self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() + + # Classifier Head + if global_pool == "map": + AttentionPoolLatent.init_weights = init_weights + self.attn_pool = AttentionPoolLatent( + self.embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + ) + else: + self.attn_pool = None + self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + self.head_drop = nn.Dropout(drop_rate) + self.head = ( + nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) + + if weight_init != "skip": + self.init_weights(weight_init) + + def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None: + assert mode in ("jax", "jax_nlhb", "moco", "") + # 
head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0 + trunc_normal_(self.pos_embed, std=0.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + @torch.jit.ignore + def no_weight_decay(self) -> Set: + return {"pos_embed", "cls_token", "dist_token"} + + @torch.jit.ignore + def group_matcher(self, coarse: bool = False) -> Dict: + return dict( + stem=r"^cls_token|pos_embed|patch_embed", # stem and embed + blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))], + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True) -> None: + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head + + def reset_classifier(self, num_classes: int, global_pool=None) -> None: + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ("", "avg", "token", "map") + if global_pool == "map" and self.attn_pool is None: + assert ( + False + ), "Cannot currently add attention pooling in reset_classifier()." 
+ elif global_pool != "map " and self.attn_pool is not None: + self.attn_pool = None # remove attention pooling + self.global_pool = global_pool + self.head = ( + nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + ) + + def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: + if self.dynamic_img_size: + B, H, W, C = x.shape + pos_embed = resample_abs_pos_embed( + self.pos_embed, + (H, W), + num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, + ) + x = x.view(B, -1, C) + else: + pos_embed = self.pos_embed + + to_cat = [] + if self.cls_token is not None: + to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) + if self.reg_token is not None: + to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) + + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + pos_embed + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if to_cat: + x = torch.cat(to_cat + [x], dim=1) + x = x + pos_embed + + return self.pos_drop(x) + + def _intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, + ) -> List[torch.Tensor]: + outputs, num_blocks = [], len(self.blocks) + take_indices = set( + range(num_blocks - n, num_blocks) if isinstance(n, int) else n + ) + + # forward pass + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in take_indices: + outputs.append(x) + + return outputs + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, + reshape: bool = False, + return_prefix_tokens: bool = False, + norm: bool = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + """Intermediate layer accessor (NOTE: This is a WIP experiment). 
+ Inspired by DINO / DINOv2 interface + """ + # take last n blocks if n is an int, if n is a sequence, select by matching indices + outputs = self._intermediate_layers(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs] + outputs = [out[:, self.num_prefix_tokens :] for out in outputs] + + if reshape: + grid_size = self.patch_embed.grid_size + outputs = [ + out.reshape(x.shape[0], grid_size[0], grid_size[1], -1) + .permute(0, 3, 1, 2) + .contiguous() + for out in outputs + ] + + if return_prefix_tokens: + return tuple(zip(outputs, prefix_tokens)) + return tuple(outputs) + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + x = self.norm(x) + return x + + def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: + if self.attn_pool is not None: + x = self.attn_pool(x) + elif self.global_pool == "avg": + x = x[:, self.num_prefix_tokens :].mean(dim=1) + elif self.global_pool: + x = x[:, 0] # class token + x = self.fc_norm(x) + x = self.head_drop(x) + return x if pre_logits else self.head(x) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.forward_features(x) + if not self.ignore_head: + x = self.forward_head(x) + return x + + +@dataclass +class SigLIPVisionCfg: + width: int = 1152 + layers: Union[Tuple[int, int, int, int], int] = 27 + heads: int = 16 + patch_size: int = 14 + image_size: Union[Tuple[int, int], int] = 336 + global_pool: str = "map" + mlp_ratio: float = 3.7362 + class_token: bool = False + num_classes: int = 0 + use_checkpoint: bool = False + + +SigLIP_MODEL_CONFIG = { + "siglip_so400m_patch14_384": { + "image_size": 336, + "patch_size": 14, + "width": 1152, + "layers": 27, + "heads": 16,
"mlp_ratio": 3.7362, + "global_pool": "map", + "use_checkpoint": False, + }, + "siglip_so400m_patch14_224": { + "image_size": 224, + "patch_size": 14, + "width": 1152, + "layers": 27, + "heads": 16, + "mlp_ratio": 3.7362, + "global_pool": "map", + "use_checkpoint": False, + }, + "siglip_large_patch16_384": { + "image_size": 384, + "patch_size": 16, + "width": 1024, + "layers": 24, + "heads": 16, + "mlp_ratio": 4, + "global_pool": "map", + "use_checkpoint": False, + }, +} + + +def create_siglip_vit( + model_name: str = "siglip_so400m_patch14_384", + image_size: int = 384, + select_layer: int = -1, + ckpt_path: str = "", + **kwargs, +): + assert ( + model_name in SigLIP_MODEL_CONFIG.keys() + ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}" + + vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name]) + + if select_layer <= 0: + layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1) + else: + layers = min(vision_cfg.layers, select_layer) + + model = VisionTransformer( + img_size=image_size, + patch_size=vision_cfg.patch_size, + embed_dim=vision_cfg.width, + depth=layers, + num_heads=vision_cfg.heads, + mlp_ratio=vision_cfg.mlp_ratio, + class_token=vision_cfg.class_token, + global_pool=vision_cfg.global_pool, + ignore_head=kwargs.get("ignore_head", True), + weight_init=kwargs.get("weight_init", "skip"), + num_classes=0, + ) + + if ckpt_path: + state_dict = torch.load(ckpt_path, map_location="cpu") + + incompatible_keys = model.load_state_dict(state_dict, strict=False) + print( + f"SigLIP-ViT restores from {ckpt_path},\n" + f"\tincompatible_keys:', {incompatible_keys}." + ) + + return model diff --git a/examples/math_prm/ursa_variant2.py b/examples/math_prm/ursa_variant2.py new file mode 100644 index 00000000..3e376044 --- /dev/null +++ b/examples/math_prm/ursa_variant2.py @@ -0,0 +1,441 @@ +"""URSA paper Eq.9 strict-alignment advantage estimator. 
+ +Paper: arXiv 2501.04686 (NeurIPS 2025), Appendix B.1 Eq.9 — the second +straw-man variant the paper considers (and ultimately *rejects* in favour +of PS-GRPO): + + A_t^i = r_{s,t}^i * GroupNorm_G(r̄_s^i) (process-reward term) + + GroupNorm_G(r_o^i) (outcome-reward term) + +where t indexes *steps* (not tokens), r_{s,t}^i is the sigmoid PRM score for +step t in trajectory i, r̄_s^i = mean_t r_{s,t}^i is the per-trajectory mean +PRM score, r_o^i ∈ {0,1} is the outcome reward, and GroupNorm_G is +(x - mean_G(x)) / std_G(x) over the G=K trajectories sampled from the same +prompt. The token-level A_t is broadcast to every token spanned by step t. + +This file is intentionally self-contained in ``examples/math_prm/`` and does +**not** modify any code under ``lightrft/``. It registers a new estimator +``ursa_variant2`` by monkey-patching +``lightrft.trainer.advantage_calculator.get_advantage_calculator`` at import +time. The patch is idempotent. + +Why a separate path, not a flag on the existing per-step PRM path: +the legacy ``per_step_reward_mode`` path (still useful as Math-Shepherd-style +step-MC return) goes through ``compute_reward`` Mode B + reverse-cumsum + +GroupNormCalculator. That fully bypasses the outcome reward and uses +cumulative returns, both of which contradict Eq.9. Keeping the two paths +side by side allows ablation between paper-strict (this estimator) and +Math-Shepherd-style (legacy). +""" + +from __future__ import annotations + +from typing import Dict, List, Optional, Tuple + +import torch + +from lightrft.trainer.advantage_calculator import ( + AdvantageCalculator, + compute_clip_fraction, +) + + +class UrsaVariant2Calculator(AdvantageCalculator): + """Strict paper Eq.9 implementation. 
+ + Reads per-trajectory step PRM scores and outcome reward from + ``experience.info`` (already emitted by ``MathPRMReward.forward`` for + label ``"math_per_step_prm"``), does its own GroupNorm, and writes + a per-token advantage tensor where every token within the span of + step k carries A_k. No cumulative return. ``returns`` mirrors + ``advantages`` (no separate value function). + """ + + _ESTIMATOR_NAME = "ursa_variant2" + + def preprocess_rewards( + self, + rewards: torch.Tensor, + experiences: List, + max_new_tokens: int, + ) -> Tuple[List, List[torch.Tensor]]: + """Compute GroupNormed (r̄_s, r_o) across all G trajectories. + + ``rewards`` is the concatenated per-trajectory scalar reward from + every experience in the batch — for ``math_per_step_prm`` rows this + is ``outcome_correct ∈ {0,1}`` (see ``reward_models.py:655``). + ``experiences`` lets us gather per-trajectory step_rewards needed + to compute r̄_s^i. We write the GroupNormed values back into each + experience's ``info`` under reserved keys ``_ursa_oc_normed`` and + ``_ursa_msp_normed`` so ``compute()`` can pick them up later. + + Aborts the variant-2 path (returns identity-chunked rewards + without touching info) when ``n_samples_per_prompt < 2`` — a + single-sample group has std = 0 and ``A_t`` would collapse. + """ + config = self.config + n_samples = int(getattr(config, "n_samples_per_prompt", 1) or 1) + + # Identity preprocessing if K<2 — variant 2 needs a group to normalize. + # We still chunk rewards back so the downstream contract is preserved. + if n_samples < 2: + reward_chunks = rewards.chunk(len(experiences)) if len(experiences) > 0 else [] + return experiences, list(reward_chunks) + + device = rewards.device + total_B = rewards.numel() + if total_B % n_samples != 0: + # Cannot group — bail out gracefully, fall back to identity. + reward_chunks = rewards.chunk(len(experiences)) + return experiences, list(reward_chunks) + + # Compute r̄_s^i for every trajectory across experiences. 
+ mean_step_prm_chunks: List[torch.Tensor] = [] + per_exp_traj_counts: List[int] = [] + for exp in experiences: + sr_list = exp.info.get("step_rewards") + n_traj = int(exp.info["reward"].numel()) + per_exp_traj_counts.append(n_traj) + if sr_list is None: + # No per-step data — treat r̄_s as zero (variant 2 will rely + # on the outcome-norm anchor only for that trajectory). + mean_step_prm_chunks.append( + torch.zeros(n_traj, dtype=torch.float32, device=device) + ) + continue + means: List[torch.Tensor] = [] + for sr in sr_list: + if sr.numel() > 0: + means.append(sr.to(device=device, dtype=torch.float32).mean()) + else: + means.append(torch.zeros((), dtype=torch.float32, device=device)) + if len(means) != n_traj: + # Misaligned bookkeeping — fall back per-traj to zero. + pad = [torch.zeros((), dtype=torch.float32, device=device)] * ( + n_traj - len(means) + ) + means = means + pad + means = means[:n_traj] + mean_step_prm_chunks.append(torch.stack(means).to(device=device)) + mean_step_prm = torch.cat(mean_step_prm_chunks, dim=0) # (total_B,) + + # GroupNorm both terms across G=K siblings (paper Eq.9 footer text). + oc_flat = rewards.to(device=device, dtype=torch.float32) + oc_g = oc_flat.reshape(-1, n_samples) + oc_normed = ( + (oc_g - oc_g.mean(dim=-1, keepdim=True)) + / (oc_g.std(dim=-1, unbiased=False, keepdim=True) + 1e-9) + ).flatten() + + msp_g = mean_step_prm.reshape(-1, n_samples) + msp_normed = ( + (msp_g - msp_g.mean(dim=-1, keepdim=True)) + / (msp_g.std(dim=-1, unbiased=False, keepdim=True) + 1e-9) + ).flatten() + + # Scatter normed values back per-experience (keep CPU-side view to avoid + # device contention when compute() runs in a different stream). 
+ offset = 0 + for exp, n_traj in zip(experiences, per_exp_traj_counts): + exp.info["_ursa_oc_normed"] = oc_normed[offset:offset + n_traj].clone().cpu() + exp.info["_ursa_msp_normed"] = msp_normed[offset:offset + n_traj].clone().cpu() + exp.info["_ursa_mean_step_prm_raw"] = mean_step_prm[offset:offset + n_traj].clone().cpu() + offset += n_traj + + # Default behaviour: chunk the (unmodified) per-trajectory rewards back + # to per-experience tensors — ``compute()`` ignores this anyway. + reward_chunks = oc_flat.chunk(len(experiences)) if len(experiences) > 0 else [] + return experiences, list(reward_chunks) + + def compute( + self, + experience, + final_reward: torch.Tensor, + gamma: Optional[float], + generate_kwargs: Dict, + ) -> Tuple[torch.Tensor, torch.Tensor, Dict]: + """Build per-token advantages via paper Eq.9. + + Ignores ``final_reward`` (which carries the legacy Mode B step + scatter + KL — orthogonal to Eq.9). KL is still applied separately + by the surrounding ``--use_kl_loss`` path; we only own the + advantage shape here. + """ + action_mask = experience.action_mask + if action_mask is None: + raise ValueError( + "UrsaVariant2Calculator requires action_mask (token-level " + "broadcast over step spans is undefined without it)." + ) + + device = action_mask.device + B, T = action_mask.shape + + info = experience.info + oc_normed = info.get("_ursa_oc_normed") + msp_normed = info.get("_ursa_msp_normed") + if oc_normed is None or msp_normed is None: + # preprocess_rewards bailed (K<2 or shape mismatch) — fall back + # to a degenerate advantage tensor of zeros so loss stays finite. 
+ advantages = torch.zeros(B, T, device=device, dtype=torch.float32) + returns = advantages.clone() + return advantages, returns, {"ursa_v2_fallback_used": 1.0} + + oc_normed = oc_normed.to(device=device, dtype=torch.float32) + msp_normed = msp_normed.to(device=device, dtype=torch.float32) + + step_rewards_list = info.get("step_rewards") or [] + step_indices_list = info.get("step_token_indices") or [] + + # One-shot diagnostic on first invocation (rank0 only). Two purposes: + # (1) verifies step_rewards / step_token_indices actually reached + # compute() — easy to miss otherwise if something upstream + # drops the lists (cf. the multi-RM aggregator drop fix). + # (2) dumps the full paper Eq.9 chain on a real trajectory so the + # smoke log carries acceptance evidence for AC1+AC8 directly. + if not getattr(UrsaVariant2Calculator, "_dumped_first_call", False): + try: + import torch.distributed as dist + rank = dist.get_rank() if dist.is_initialized() else 0 + except Exception: + rank = 0 + if rank == 0: + step_rewards_keys = bool(info.get("step_rewards")) + step_indices_keys = bool(info.get("step_token_indices")) + sr_lens = ([t.numel() for t in (info.get("step_rewards") or [])][:8]) + sti_lens = ([t.numel() for t in (info.get("step_token_indices") or [])][:8]) + print(f"[ursa_v2:compute] first call rank=0 B={B} T={T} " + f"has_step_rewards={step_rewards_keys} " + f"has_step_token_indices={step_indices_keys} " + f"sr_lens(first8)={sr_lens} sti_lens(first8)={sti_lens} " + f"info keys={sorted([k for k in info.keys() if not k.startswith('_')])}", + flush=True) + # Full Eq.9 chain dump — every intermediate value so a + # reviewer can verify the implementation matches the paper + # formula on real PRM output. 
+ if info.get("step_rewards"): + sr_lst = info["step_rewards"] + sti_lst = info["step_token_indices"] + outcome = info["reward"].float() + K = self.config.n_samples_per_prompt + print(f"[ursa_v2:chain] === paper Eq.9 chain on real PRM output (K={K}) ===", flush=True) + print(f"[ursa_v2:chain] outcome (r_o per traj) = {outcome.tolist()}", flush=True) + r_bar = torch.stack([t.float().mean() if t.numel() > 0 + else torch.tensor(0.0) for t in sr_lst]) + print(f"[ursa_v2:chain] r_bar_s (mean step PRM)= {r_bar.tolist()}", flush=True) + print(f"[ursa_v2:chain] msp_normed (post GN) = {msp_normed.tolist()}", flush=True) + print(f"[ursa_v2:chain] oc_normed (post GN) = {oc_normed.tolist()}", flush=True) + for i in range(min(B, K)): + sr_i = sr_lst[i].float().tolist() + si_i = sti_lst[i].long().tolist() + a_steps = [float(r) * float(msp_normed[i]) + float(oc_normed[i]) + for r in sr_i] + print(f"[ursa_v2:chain] traj {i}: r_o={float(outcome[i]):+.2f} " + f"r_bar={float(r_bar[i]):.4f} msp_normed={float(msp_normed[i]):+.4f} " + f"oc_normed={float(oc_normed[i]):+.4f}", flush=True) + for k, (r, idx, a) in enumerate(zip(sr_i, si_i, a_steps)): + print(f"[ursa_v2:chain] step {k+1}: r_s={r:.4f} " + f"end_token={idx:4d} A_step={r:.4f}·{float(msp_normed[i]):+.4f} + " + f"{float(oc_normed[i]):+.4f} = {a:+.4f}", flush=True) + UrsaVariant2Calculator._dumped_first_call = True + + advantages = torch.zeros(B, T, device=device, dtype=torch.float32) + per_traj_step_count = [] + for i in range(B): + has_steps = ( + i < len(step_rewards_list) + and step_rewards_list[i].numel() > 0 + and i < len(step_indices_list) + and step_indices_list[i].numel() == step_rewards_list[i].numel() + ) + if not has_steps: + # No step data — degenerate to outcome-only term spread over + # the response (matches paper's natural limit when n_steps=0 + # since the process-reward term vanishes). 
+ advantages[i] = oc_normed[i] * action_mask[i].to(torch.float32) + per_traj_step_count.append(0) + continue + + sr = step_rewards_list[i].to(device=device, dtype=torch.float32) # (n_steps,) + si = step_indices_list[i].to(device=device, dtype=torch.long) # (n_steps,) END idx + n_steps = sr.numel() + per_traj_step_count.append(int(n_steps)) + + # Span starts: 0 for step 0, end_{k-1}+1 for k > 0 + starts = torch.cat([ + torch.zeros(1, dtype=torch.long, device=device), + si[:-1] + 1, + ]) + ends = si + + # Per-step advantage: A_k = r_{s,k} * msp_normed[i] + oc_normed[i] + # (paper Eq.9) + A_steps = sr * msp_normed[i] + oc_normed[i] # (n_steps,) + + for k in range(n_steps): + sk = max(0, int(starts[k].item())) + ek = min(T - 1, int(ends[k].item())) + if sk > ek: + continue + advantages[i, sk:ek + 1] = A_steps[k] + + # Tokens past the last step boundary (e.g. final `†Answer:` line + # tokens) are not covered by any step. Per paper Eq.9 the second + # term is t-independent, so we still apply oc_normed[i] there + # to give the model an outcome-only signal on the tail. This + # matches the implicit reading that the outcome anchor lives + # on the whole trajectory while step rewards live on steps. + last_end = int(ends[-1].item()) if n_steps > 0 else -1 + if last_end + 1 < T: + advantages[i, last_end + 1:] = oc_normed[i] + + # Respect the response action mask everywhere. + advantages = advantages * action_mask.to(torch.float32) + returns = advantages.clone() + + # Per-step credit diagnostics (these flow into the trainer's wandb + # under `train/`-style keys via the existing info_dict pipeline). + n_valid = action_mask.sum().clamp(min=1).to(torch.float32) + info_dict: Dict[str, float] = { + # Restrict the *_frac counters to valid (un-masked) tokens so they + # don't include padding-induced zeros in the denominator's response + # area. n_valid is action_mask.sum(); we mask both numerator and + # event-set to (action_mask == 1). 
+ "ursa_v2_adv_pos_frac": ((advantages > 0) & action_mask.bool()).to(torch.float32).sum().item() / n_valid.item(), + "ursa_v2_adv_neg_frac": ((advantages < 0) & action_mask.bool()).to(torch.float32).sum().item() / n_valid.item(), + "ursa_v2_adv_zero_frac": ((advantages == 0) & action_mask.bool()).to(torch.float32).sum().item() / n_valid.item(), + "ursa_v2_adv_abs_mean": advantages.abs().sum().item() / n_valid.item(), + "ursa_v2_oc_normed_std": oc_normed.std(unbiased=False).item() if oc_normed.numel() > 1 else 0.0, + "ursa_v2_msp_normed_std": msp_normed.std(unbiased=False).item() if msp_normed.numel() > 1 else 0.0, + "ursa_v2_traj_step_count_mean": ( + sum(per_traj_step_count) / max(1, len(per_traj_step_count)) + ), + } + + # Advantage clipping (config knob, optional). + if getattr(self.config, "advantage_clip", 0) > 0: + clip_val = self.config.advantage_clip + info_dict["advantage_clip_frac"] = compute_clip_fraction( + advantages, clip_val, -clip_val + ) + advantages = torch.clamp(advantages, -clip_val, clip_val) + + return advantages, returns, info_dict + + +def _install_aggregate_rewards_patch() -> None: + """Forward step_rewards / step_token_indices through the multi-RM aggregator. + + Background: ``examples/math_prm/reward_models_utils.load_reward_models`` + returns reward_models as a List[nn.Module] even when there is only one + RM. That makes ``fast_exp_maker._aggregate_rewards`` take the + ``is_multi_rm=True`` branch, which writes ``outputs[i].rewards`` and + ``outputs[i].reward_metrics`` but — by design — drops the per-step + variable-length fields. That's correct for true multi-RM aggregation + (where combining variable-length step tensors across RMs is ill- + defined), but it silently breaks the single-list-of-one-RM case that + this example uses. 
+ + Patch: after the original ``_aggregate_rewards`` runs, scan for the + "single underlying RM but exposed as a 1-list" pattern and lift the + step_rewards / step_token_indices from that one RM's batch result + into ``outputs[i]``. No behaviour change for true multi-RM setups. + """ + from lightrft.trainer import fast_exp_maker as _fem + + # _aggregate_rewards lives on RewardComputationEngine (separate class + # from FastExperienceMaker; reachable via fast_exp_maker.RewardComputationEngine + # or self.reward_engine on the maker). + _RewardEngine = getattr(_fem, "RewardComputationEngine", None) + if _RewardEngine is None or not hasattr(_RewardEngine, "_aggregate_rewards"): + return + if getattr(_RewardEngine, "_ursa_v2_aggregator_patched", False): + return + + _original = _RewardEngine._aggregate_rewards + + def _aggregate_rewards_patched(self, outputs, all_rewards_list, is_multi_rm): + _original(self, outputs, all_rewards_list, is_multi_rm) + if not is_multi_rm: + return + # If multiple RMs actually produced step_rewards we don't know how to + # merge them — bail (keep lightrft's safe default). 
+ rms_with_steps = [ + rm_idx for rm_idx in range(len(all_rewards_list)) + if any(getattr(r, "step_rewards", None) is not None + for r in all_rewards_list[rm_idx]) + ] + if len(rms_with_steps) != 1: + return + rm_idx = rms_with_steps[0] + for mb_idx in range(len(outputs)): + res = all_rewards_list[rm_idx][mb_idx] + sr = getattr(res, "step_rewards", None) + sti = getattr(res, "step_token_indices", None) + if sr is not None and getattr(outputs[mb_idx], "step_rewards", None) is None: + outputs[mb_idx].step_rewards = sr + if sti is not None and getattr(outputs[mb_idx], "step_token_indices", None) is None: + outputs[mb_idx].step_token_indices = sti + + _aggregate_rewards_patched._ursa_v2_patched = True + _RewardEngine._aggregate_rewards = _aggregate_rewards_patched + _RewardEngine._ursa_v2_aggregator_patched = True + + +def _install_get_advantage_calculator_patch() -> None: + """Idempotently inject ``ursa_variant2`` into lightrft's calculator factory. + + Done from examples/ rather than editing ``lightrft/`` to keep the new + estimator strictly contained in this example. The patch wraps the + original factory; unknown names still raise the original ValueError + listing the *original* supported set + this estimator. + + Important: we patch every module that has already done + ``from .advantage_calculator import get_advantage_calculator`` because + those imports bind the original function object into the consumer + module's namespace — patching just the source module would miss them. 
+ """ + from lightrft.trainer import advantage_calculator as _ac + + if getattr(_ac.get_advantage_calculator, "_ursa_v2_patched", False): + return + + _original = _ac.get_advantage_calculator + + def get_advantage_calculator_patched(estimator_name: str, config): + if estimator_name == UrsaVariant2Calculator._ESTIMATOR_NAME: + return UrsaVariant2Calculator(config) + return _original(estimator_name, config) + + get_advantage_calculator_patched._ursa_v2_patched = True + _ac.get_advantage_calculator = get_advantage_calculator_patched + + # Also patch known consumers that did ``from .advantage_calculator import + # get_advantage_calculator`` (binding the original ref into their own + # namespace). Currently fast_exp_maker is the only such consumer; if more + # appear later, list them here. + import sys + for mod_name in ("lightrft.trainer.fast_exp_maker",): + mod = sys.modules.get(mod_name) + if mod is not None and hasattr(mod, "get_advantage_calculator"): + mod.get_advantage_calculator = get_advantage_calculator_patched + + +def register_ursa_variant2() -> None: + """Install both monkey-patches so ``--advantage_estimator ursa_variant2`` + becomes a valid option and the multi-RM aggregator forwards step_rewards. + + Idempotent. The two underlying ``_install_*_patch`` helpers each guard + themselves with a sentinel attribute, so calling this multiple times + (e.g. from both ``math_prm_trainer`` and a future user-side import) is + safe. + """ + _install_get_advantage_calculator_patch() + _install_aggregate_rewards_patch() + + +# Also install on import so existing call-sites that rely on the side-effect +# behaviour (``import ursa_variant2`` near the top of ``math_prm_trainer``) +# still work. New code should prefer the explicit ``register_ursa_variant2()`` +# entry point. 
+register_ursa_variant2() diff --git a/lightrft/models/actor_language.py b/lightrft/models/actor_language.py index 912e93d3..244ebc90 100644 --- a/lightrft/models/actor_language.py +++ b/lightrft/models/actor_language.py @@ -210,10 +210,14 @@ def generate( use_cache=True, num_beams=kwargs.get("num_beams", 1), attention_mask=kwargs.get("attention_mask"), + logits_processor=kwargs.get("logits_processor"), eos_token_id=kwargs.get("eos_token_id"), pad_token_id=kwargs.get("pad_token_id"), min_new_tokens=kwargs.get("min_new_tokens", 1), + repetition_penalty=kwargs.get("repetition_penalty", 1.0), ) + if kwargs.get("no_repeat_ngram_size", 0) > 0: + generate_args["no_repeat_ngram_size"] = kwargs["no_repeat_ngram_size"] if kwargs.get("max_new_tokens") is not None: generate_args["max_new_tokens"] = kwargs["max_new_tokens"] if kwargs.get("max_length") is not None: diff --git a/lightrft/models/actor_vl.py b/lightrft/models/actor_vl.py index 878c611e..4d9fce40 100644 --- a/lightrft/models/actor_vl.py +++ b/lightrft/models/actor_vl.py @@ -19,6 +19,7 @@ - MoE (Mixture of Experts) model support """ +import inspect from typing import Optional, Tuple, Union import torch @@ -85,6 +86,40 @@ class ActorVL(nn.Module): # Model modality declaration - defines what types of inputs this model accepts modality = ActorModality.VISION_LANGUAGE + def _get_model_dtype(self) -> Optional[torch.dtype]: + model_dtype = getattr(self.model, "dtype", None) + if isinstance(model_dtype, torch.dtype): + return model_dtype + for parameter in self.model.parameters(): + if torch.is_floating_point(parameter): + return parameter.dtype + return None + + def _cast_multimodal_tensor(self, value: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + if value is None or not torch.is_tensor(value) or not torch.is_floating_point(value): + return value + model_dtype = self._get_model_dtype() + if model_dtype is None or value.dtype == model_dtype: + return value + return value.to(dtype=model_dtype) + + def 
_supports_model_kwarg(self, kwarg_name: str, *, generation: bool = False) -> bool: + targets = [] + if generation: + targets.append(getattr(self.model, "prepare_inputs_for_generation", None)) + targets.append(getattr(self.model, "forward", None)) + + for target in targets: + if target is None: + continue + try: + parameters = inspect.signature(target).parameters + except (TypeError, ValueError): + continue + if kwarg_name in parameters: + return True + return False + def __init__( self, pretrain_or_model, @@ -102,6 +137,7 @@ def __init__( ) -> None: super().__init__() self.high_entropy_token_ratio = high_entropy_token_ratio + self.packing_samples = packing_samples if isinstance(pretrain_or_model, str): self.pretrain_or_model = pretrain_or_model @@ -151,8 +187,6 @@ def __init__( # Use `model.generate(use_cache=True)` instead.` self.model.config.use_cache = False - # packing samples using Flash Attention 2 - self.packing_samples = packing_samples else: self.model = pretrain_or_model self.pretrain_or_model = pretrain_or_model.config.model_type @@ -206,9 +240,9 @@ def generate( """ generate_args = { "input_ids": input_ids, - "pixel_values": pixel_values, + "pixel_values": self._cast_multimodal_tensor(pixel_values), "image_grid_thw": image_grid_thw, - "pixel_values_videos": pixel_values_videos, + "pixel_values_videos": self._cast_multimodal_tensor(pixel_values_videos), "video_grid_thw": video_grid_thw, "top_k": kwargs.get("top_k", None), "top_p": kwargs.get("top_p", None), @@ -218,16 +252,26 @@ def generate( "use_cache": True, "num_beams": kwargs.get("num_beams", 1), "attention_mask": kwargs.get("attention_mask"), + "logits_processor": kwargs.get("logits_processor"), "eos_token_id": kwargs.get("eos_token_id"), "pad_token_id": kwargs.get("pad_token_id"), "min_new_tokens": kwargs.get("min_new_tokens", 1), + "repetition_penalty": kwargs.get("repetition_penalty", 1.0), } + if kwargs.get("no_repeat_ngram_size", 0) > 0: + generate_args["no_repeat_ngram_size"] = 
kwargs["no_repeat_ngram_size"] if kwargs.get("max_new_tokens", None): generate_args["max_new_tokens"] = kwargs.get("max_new_tokens") if kwargs.get("max_length", None): generate_args["max_length"] = kwargs.get("max_length") + for model_kwarg in ("pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"): + if model_kwarg in generate_args and ( + generate_args[model_kwarg] is None or not self._supports_model_kwarg(model_kwarg, generation=True) + ): + generate_args.pop(model_kwarg) + # Call generate sequences = self.model.generate(**generate_args) @@ -306,14 +350,21 @@ def forward( # explicitly ignore attention_mask for packing_samples attention_mask = None + forward_kwargs = { + "attention_mask": attention_mask, + "position_ids": position_ids, + "pixel_values": self._cast_multimodal_tensor(pixel_values), + "image_grid_thw": image_grid_thw, + "pixel_values_videos": self._cast_multimodal_tensor(pixel_values_videos), + "video_grid_thw": video_grid_thw, + } + for model_kwarg in ("pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw"): + if not self._supports_model_kwarg(model_kwarg): + forward_kwargs.pop(model_kwarg, None) + output = self.model( sequences, - attention_mask=attention_mask, - position_ids=position_ids, - pixel_values=pixel_values, - image_grid_thw=image_grid_thw, - pixel_values_videos=pixel_values_videos, - video_grid_thw=video_grid_thw, + **forward_kwargs, ) if num_actions is None: # defult diff --git a/lightrft/models/loss.py b/lightrft/models/loss.py index d593cb93..a3134e74 100644 --- a/lightrft/models/loss.py +++ b/lightrft/models/loss.py @@ -176,6 +176,22 @@ def __init__( self.use_dapo = use_dapo self.use_cpg_loss = use_cpg_loss self.high_entropy_token_ratio = high_entropy_token_ratio + # Per-forward diagnostic stats for ratio behavior. Populated each time + # forward() runs in PPO mode and read by the trainer via get_last_stats(). + # CPGD mode has no ratio so these stay empty. 
+ self._last_stats: dict = {} + + def get_last_stats(self) -> dict: + """Return per-token ratio diagnostics from the most recent PPO forward. + + Keys: ``ratio_mean``, ``ratio_max``, ``ratio_min``, ``clipfrac``, + ``approx_kl``. ``clipfrac`` is the fraction of valid action tokens + whose unclipped ratio fell outside ``[1 - clip_eps, 1 + clip_eps]``. + ``approx_kl`` is the K2 estimator ``0.5 * mean((log p - log p_old)^2)`` + which is a low-variance proxy for the per-step "old vs new" KL + (different from the actor-vs-reference KL controller signal). + """ + return dict(self._last_stats) def forward( self, @@ -246,12 +262,44 @@ def forward( return loss # PPO loss - ratio = (log_probs - old_log_probs).exp() + log_ratio = log_probs - old_log_probs + ratio = log_ratio.exp() surr1 = ratio * advantages surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages loss = -torch.min(surr1, surr2) loss = masked_mean(loss, final_mask, dim=-1).mean() + # Diagnostic stats over valid action tokens only. + # Detached so they don't accidentally enter the autograd graph. + with torch.no_grad(): + if final_mask is None: + m = torch.ones_like(ratio, dtype=torch.bool) + else: + m = final_mask.bool() + r_valid = ratio[m] + lr_valid = log_ratio[m] + if r_valid.numel() == 0: + self._last_stats = { + "ratio_mean": 0.0, + "ratio_max": 0.0, + "ratio_min": 0.0, + "clipfrac": 0.0, + "approx_kl": 0.0, + } + else: + # `clipfrac` counts the tokens whose UNCLIPPED ratio is outside + # [1-eps, 1+eps]. High clipfrac means PPO is suppressing many + # gradient updates this step, which is a signal that the new + # policy has moved noticeably from the rollout policy. 
+ clipped = (r_valid > 1 + self.clip_eps) | (r_valid < 1 - self.clip_eps) + self._last_stats = { + "ratio_mean": r_valid.float().mean().item(), + "ratio_max": r_valid.float().max().item(), + "ratio_min": r_valid.float().min().item(), + "clipfrac": clipped.float().mean().item(), + "approx_kl": 0.5 * lr_valid.float().pow(2).mean().item(), + } + return loss diff --git a/lightrft/models/utils.py b/lightrft/models/utils.py index 15f9c75f..fa8de9f5 100644 --- a/lightrft/models/utils.py +++ b/lightrft/models/utils.py @@ -224,6 +224,24 @@ def log_probs_from_logits( >>> log_probs.shape torch.Size([2, 3]) """ + # PyTorch's torch.gather(dim=-1, index=...) does NOT require non-dim + # axes to match: when ``logits`` has more rows than ``labels``, gather + # silently truncates to ``len(labels)`` instead of raising. That made it + # impossible to spot a VLM-specific alignment bug where ``output["logits"]`` + # is longer than ``sequences`` (image placeholder gets expanded into N + # vision-patch tokens during the LM forward) — see PR #53. Reject the + # mismatch up-front so any future caller using this helper crashes loudly + # and is forced to align logits to labels at the call site. + if logits.shape[:-1] != labels.shape: + raise ValueError( + "log_probs_from_logits: logits and labels must have matching " + f"non-vocab shapes. Got logits.shape={tuple(logits.shape)}, " + f"labels.shape={tuple(labels.shape)}. For VLMs, output['logits'] " + "may be longer than the input sequences because vision tokens " + "expand placeholders during the forward pass — slice the logits " + "to the action range before calling this helper." 
+ ) + if logits.dtype in [torch.float32, torch.float64]: batch_dim = logits.shape[:-1] last_dim = logits.shape[-1] @@ -444,14 +462,24 @@ def compute_reward( action_mask: Optional[torch.Tensor] = None, num_actions: Optional[Union[int, list[int]]] = None, reward_clip_range: Tuple[float, float] = None, + step_rewards: Optional[torch.Tensor] = None, + step_token_indices: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, list[torch.Tensor]]: """ Compute final reward by combining base reward with KL penalty. Combines base reward with KL divergence penalty to encourage policy stability. - Supports two modes: with action mask (efficient) and without (individual processing). - - :param r: Base reward tensor or scalar + Supports three modes: + A. trajectory-scalar (default, legacy behavior): scatters scalar `r[i]` + to the EOS position of each row. Used by ORM-style RL. + B. per-step (NEW, opt-in): scatters multiple step rewards to step + boundary token positions for true per-step credit assignment. Used + by PRM-RL methods like Math-Shepherd / variant 2 of URSA paper. + Activated when `step_rewards` is not None. + C. no action_mask: per-row variable-length list mode (legacy). + + :param r: Base reward tensor of shape (B,) or scalar. In per-step mode, + used as a fallback only if step_rewards is None for any row. :type r: Union[torch.Tensor, float] :param kl_coef: KL penalty coefficient (<=0 disables penalty) :type kl_coef: float @@ -463,11 +491,23 @@ def compute_reward( :type num_actions: Optional[Union[int, list[int]]] :param reward_clip_range: (min, max) to clip base reward :type reward_clip_range: Tuple[float, float] - - :return: Final reward tensor or list + :param step_rewards: PER-STEP rewards of shape (B, max_steps) padded with + any value (only positions in `step_token_indices` are read). When + provided together with ``step_token_indices``, mode (B) is enabled + and the scalar ``r`` EOS-scatter is bypassed. 
+ :type step_rewards: Optional[torch.Tensor] + :param step_token_indices: Token indices in the action / response space + of shape (B, max_steps). Positions with value < 0 are treated as + padding and skipped. ``step_token_indices[i, k]`` = boundary token + index in sequence i for step k; ``step_rewards[i, k]`` is scattered + to that position (NOT the EOS) so cumulative-returns can propagate + per-step credit. + :type step_token_indices: Optional[torch.Tensor] + + :return: Final reward tensor or list, shape (B, response_size). :rtype: Union[torch.Tensor, list[torch.Tensor]] - Example:: + Example (mode A, legacy):: >>> r = torch.tensor([1.0, 2.0]) >>> kl_coef = 0.1 >>> kl = torch.tensor([[0.1, 0.2, 0.3], [0.2, 0.1, 0.4]]) @@ -475,29 +515,70 @@ def compute_reward( >>> reward = compute_reward(r, kl_coef, kl, action_mask) >>> reward.shape torch.Size([2, 3]) + + Example (mode B, per-step):: + >>> step_rewards = torch.tensor([[0.2, 0.8, 0.7], + ... [0.5, 0.6, -1.]]) # last col padded + >>> step_token_indices = torch.tensor([[1, 3, 4], + ... [0, 2, -1]]) + >>> reward = compute_reward( + ... r=None, kl_coef=0.0, kl=torch.zeros(2, 5), + ... action_mask=torch.ones(2, 5), + ... step_rewards=step_rewards, + ... step_token_indices=step_token_indices, + ... 
) + >>> reward + tensor([[0.0000, 0.2000, 0.0000, 0.8000, 0.7000], + [0.5000, 0.0000, 0.6000, 0.0000, 0.0000]]) """ if kl_coef <= 0.0: kl_coef = 0.0 - if reward_clip_range: + use_per_step = step_rewards is not None and step_token_indices is not None + + if not use_per_step and reward_clip_range: r = r.clamp(min=reward_clip_range[0], max=reward_clip_range[1]) if action_mask is not None: kl_reward = -kl_coef * kl - # The following code is equivalent to: - # - # last_reward = torch.zeros_like(kl) - # for i in range(last_reward.size(0)): - # for t in reversed(range(last_reward.size(1))): - # if action_mask[i][t] > 0.5: - # last_reward[i][t] = r[i] - # break - # - eos_indices = action_mask.size(1) - 1 - action_mask.long().fliplr().argmax(dim=1, keepdim=True) - last_reward = torch.zeros_like(kl).scatter_(dim=1, index=eos_indices, src=r.unsqueeze(1).to(kl.dtype)) + + if use_per_step: + # Mode B: scatter each step's reward to its boundary token index. + # step_rewards: (B, S), step_token_indices: (B, S). Padding = idx < 0. + base = torch.zeros_like(kl) + valid = step_token_indices >= 0 + if valid.any(): + # Gather flat row/col indices, scatter via index_put_ + row_idx = torch.arange(step_token_indices.size(0), device=kl.device) + row_idx = row_idx.unsqueeze(1).expand_as(step_token_indices) + flat_rows = row_idx[valid] + flat_cols = step_token_indices[valid].long() + # Clamp cols into valid range to avoid OOB; padded cols are filtered by `valid` + flat_cols = flat_cols.clamp(min=0, max=base.size(1) - 1) + flat_vals = step_rewards[valid].to(kl.dtype) + # accumulate (multiple steps could land on same token; rare but safe) + base.index_put_((flat_rows, flat_cols), flat_vals, accumulate=True) + last_reward = base + else: + # Mode A: legacy — scatter scalar r[i] to EOS index of row i. 
+ # + # The following code is equivalent to: + # + # last_reward = torch.zeros_like(kl) + # for i in range(last_reward.size(0)): + # for t in reversed(range(last_reward.size(1))): + # if action_mask[i][t] > 0.5: + # last_reward[i][t] = r[i] + # break + # + eos_indices = action_mask.size(1) - 1 - action_mask.long().fliplr().argmax(dim=1, keepdim=True) + last_reward = torch.zeros_like(kl).scatter_(dim=1, index=eos_indices, src=r.unsqueeze(1).to(kl.dtype)) reward = last_reward + kl_reward else: + # Mode C: per-row variable-length (legacy). Per-step mode is only + # supported with action_mask; fall back to scalar EOS even if + # use_per_step is set, to keep this branch backward-compat. # TODO: write a more efficient version reward = [] for i, (kl_seg, action_len) in enumerate(zip(kl, num_actions)): diff --git a/lightrft/strategy/config.py b/lightrft/strategy/config.py index 008fccd9..f523b3bb 100644 --- a/lightrft/strategy/config.py +++ b/lightrft/strategy/config.py @@ -48,10 +48,18 @@ class StrategyConfig: overlap_comm: bool = False # Engine and inference parameters - # (str): Inference engine type, defaults to "vllm" + # (str): Inference engine type, defaults to "vllm". Supported values include "vllm", "sglang", and "hf". 
engine_type: str = "vllm" # (int): Engine tensor parallelism size, defaults to 1 engine_tp_size: int = 1 + # (int): Maximum local HF generation batch size, <=0 disables chunking + local_hf_generate_max_batch_size: int = 0 + # (int): Optional max_new_tokens cap applied only to local HF rollout generation, <=0 disables the cap + local_hf_max_new_tokens: int = 0 + # (bool): Use a dedicated local HF rollout actor instead of reusing the training actor + hf_separate_rollout_actor: bool = False + # (bool): Keep the dedicated local HF rollout actor resident on GPU instead of sleeping/offloading it + hf_separate_rollout_keep_on_gpu: bool = False # (bool): Enable engine sleep mode, defaults to False enable_engine_sleep: bool = False # (int): Local rank for distributed training, defaults to -1 @@ -139,7 +147,6 @@ class StrategyConfig: plot_every: int = -1 # (bool): Use TensorBoard for logging, defaults to False use_tensorboard: bool = False - # Additional arguments for backward compatibility # (Dict[str, Any]): Extra arguments for backward compatibility, defaults to {} extra_args: Dict[str, Any] = field(default_factory=dict) @@ -237,7 +244,10 @@ def print_config_summary(self) -> None: # Engine and Inference Parameters print("\nEngine and Inference Parameters:") - for attr in ['engine_type', 'engine_tp_size', 'enable_engine_sleep', 'local_rank', 'sp_size']: + for attr in [ + 'engine_type', 'engine_tp_size', 'local_hf_generate_max_batch_size', 'hf_separate_rollout_actor', + 'enable_engine_sleep', 'local_rank', 'sp_size' + ]: current = getattr(self, attr) default = getattr(default_config, attr) status = "Overridden" if current != default else "Default" diff --git a/lightrft/strategy/fake_strategy.py b/lightrft/strategy/fake_strategy.py index e1ed71ed..c39f005f 100644 --- a/lightrft/strategy/fake_strategy.py +++ b/lightrft/strategy/fake_strategy.py @@ -282,7 +282,7 @@ def get_rank(self) -> int: """ return 0 - def setup_inference_engine(self, args, engine_type="vllm", actor=None): 
+ def setup_inference_engine(self, args, engine_type="vllm", actor=None, tokenizer=None, processor=None): """ Fake inference engine setup - returns None. @@ -311,7 +311,18 @@ def wakeup_inference_engine(self): """ self.print("FakeStrategy: Inference engine wakeup skipped") - def engine_generate_local(self, sampling_params, prompt_token_ids=None, multi_modal_inputs=None): + def engine_generate_local( + self, + sampling_params, + prompt_token_ids=None, + multi_modal_inputs=None, + pixel_values=None, + image_grid_thw=None, + pixel_values_videos=None, + video_grid_thw=None, + images_num=None, + videos_num=None, + ): """ Fake generation - returns empty results. @@ -332,7 +343,13 @@ def gather_and_generate( all_prompts=None, all_images=None, sleep_engine=True, - images_num=None + images_num=None, + all_videos=None, + videos_num=None, + all_images_pixel_values=None, + all_videos_pixel_values=None, + all_images_grid_thw=None, + all_videos_grid_thw=None, ): """ Fake gather and generate - returns empty results. 
diff --git a/lightrft/strategy/strategy_base.py b/lightrft/strategy/strategy_base.py index 050c5a00..20a919cf 100644 --- a/lightrft/strategy/strategy_base.py +++ b/lightrft/strategy/strategy_base.py @@ -110,7 +110,15 @@ def __init__( # pylint: disable=R0917 # inference (rollout) engine related self.inference_engine = None self.inference_engine_status = EngineStatus.SLEEPED + self.inference_tokenizer = None + self.inference_processor = None self.broadcast_manager = None + self.rollout_train_actor = None + self.rollout_train_actor_is_on_gpu = False + self.use_separate_hf_rollout_actor = False + self._separate_hf_rollout_sync_param_pairs = None + self._separate_hf_rollout_sync_buffer_pairs = None + self.last_separate_hf_rollout_sync_stats = None self.time_steps = defaultdict(int) @@ -184,7 +192,7 @@ def setup_distributed(self, timeout: Optional[timedelta] = None, num_gpu_per_nod ) # TODO: unify the init_process_group for both vllm and sglang when stable version finished - if self.config.engine_type in ("vllm", "sglang"): + if self.config.engine_type in ("vllm", "sglang", "hf"): dist.init_process_group( rank=rank, world_size=world_size, @@ -198,7 +206,7 @@ def setup_distributed(self, timeout: Optional[timedelta] = None, num_gpu_per_nod raise ValueError(f"Unsupported backend: {self.config.engine_type}") else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs - if self.config.engine_type in ("vllm", "sglang"): + if self.config.engine_type in ("vllm", "sglang", "hf"): deepspeed.init_distributed(dist_backend="nccl", timeout=timeout) else: raise ValueError(f"Unsupported backend: {self.config.engine_type}") @@ -528,6 +536,16 @@ def prepare_models_and_optimizers(self, actor, critic, reward_models, initial_mo setattr(actor, "is_actor", True) fsdp_enable = self.config.fsdp + + def _resolve_fsdp_shard_size(preferred_shard_size: int) -> int: + world_size = max(int(getattr(self, "world_size", 1)), 1) + if world_size % preferred_shard_size == 0: 
+ return preferred_shard_size + candidate = min(preferred_shard_size, world_size) + while candidate > 1 and world_size % candidate != 0: + candidate -= 1 + return max(candidate, 1) + # For FSDP: wrap model first, then create optimizer if fsdp_enable: actor = self.prepare_model(actor, is_training=True) @@ -535,7 +553,11 @@ def prepare_models_and_optimizers(self, actor, critic, reward_models, initial_mo if critic is not None: critic = self.prepare_model(critic, is_training=True) if not self.config.remote_rm_url: - reward_model_shard_size = self.world_size + reward_model_shard_size = _resolve_fsdp_shard_size(preferred_shard_size=8) + self.print( + "Preparing reward model(s) with shard_size=" + f"{reward_model_shard_size} (world_size={getattr(self, 'world_size', 1)})" + ) if isinstance(reward_models, (tuple, list)): reward_models = [ self.prepare_model(model, shard_size=reward_model_shard_size) for model in reward_models @@ -655,13 +677,164 @@ def report_memory(cls, prefix=""): f"ALLOCATED={torch.cuda.memory_allocated() / 1e9:.2f} GB" ) - def setup_inference_engine(self, args, engine_type="vllm", actor=None): + def _uses_separate_hf_rollout_actor(self) -> bool: + return ( + self.inference_engine_type == "hf" and self.use_separate_hf_rollout_actor + and self.rollout_train_actor is not None and self.inference_engine is not None + and self.inference_engine is not self.rollout_train_actor + ) + + def _keeps_separate_hf_rollout_actor_on_gpu(self) -> bool: + return self._uses_separate_hf_rollout_actor() and bool( + getattr(self.config, "hf_separate_rollout_keep_on_gpu", False) + ) + + def _build_local_hf_rollout_actor_sync_plan( + self, src_actor: nn.Module, dst_actor: nn.Module + ) -> Tuple[List[Tuple[str, nn.Parameter, nn.Parameter]], List[Tuple[str, torch.Tensor, torch.Tensor]]]: + src_params = dict(src_actor.named_parameters()) + dst_params = dict(dst_actor.named_parameters()) + if src_params.keys() != dst_params.keys(): + missing_in_dst = sorted(set(src_params) - 
set(dst_params)) + missing_in_src = sorted(set(dst_params) - set(src_params)) + raise ValueError( + "Separate local HF rollout actor parameter mismatch. " + f"missing_in_dst={missing_in_dst[:8]}, missing_in_src={missing_in_src[:8]}" + ) + + param_pairs = [] + for name, src_param in src_params.items(): + dst_param = dst_params[name] + if src_param.shape != dst_param.shape: + raise ValueError( + f"Separate local HF rollout actor parameter shape mismatch for {name}: " + f"{tuple(src_param.shape)} vs {tuple(dst_param.shape)}" + ) + param_pairs.append((name, src_param, dst_param)) + + src_buffers = dict(src_actor.named_buffers()) + dst_buffers = dict(dst_actor.named_buffers()) + buffer_pairs = [] + for name in sorted(set(src_buffers) & set(dst_buffers)): + src_buffer = src_buffers[name] + dst_buffer = dst_buffers[name] + if src_buffer.shape != dst_buffer.shape: + raise ValueError( + f"Separate local HF rollout actor buffer shape mismatch for {name}: " + f"{tuple(src_buffer.shape)} vs {tuple(dst_buffer.shape)}" + ) + buffer_pairs.append((name, src_buffer, dst_buffer)) + return param_pairs, buffer_pairs + + def _copy_local_hf_rollout_actor_state(self, src_actor: nn.Module, dst_actor: nn.Module) -> None: + if self._separate_hf_rollout_sync_param_pairs is None or self._separate_hf_rollout_sync_buffer_pairs is None: + ( + self._separate_hf_rollout_sync_param_pairs, + self._separate_hf_rollout_sync_buffer_pairs, + ) = self._build_local_hf_rollout_actor_sync_plan(src_actor, dst_actor) + + for name, src_param, dst_param in self._separate_hf_rollout_sync_param_pairs: + src_tensor = src_param.detach() + if src_tensor.device != dst_param.device or src_tensor.dtype != dst_param.dtype: + src_tensor = src_tensor.to(device=dst_param.device, dtype=dst_param.dtype) + dst_param.detach().copy_(src_tensor) + + for name, src_buffer_ref, dst_buffer in self._separate_hf_rollout_sync_buffer_pairs: + src_buffer = src_buffer_ref.detach() + if src_buffer.device != dst_buffer.device or 
src_buffer.dtype != dst_buffer.dtype: + src_buffer = src_buffer.to(device=dst_buffer.device, dtype=dst_buffer.dtype) + dst_buffer.detach().copy_(src_buffer) + + def _prepare_separate_hf_rollout_actor_for_generation(self) -> None: + if not self._uses_separate_hf_rollout_actor(): + return + model = getattr(self.inference_engine, "model", None) + if isinstance(model, nn.Module): + model.eval() + self.inference_engine.eval() + + def _sync_separate_hf_rollout_actor(self, actor: nn.Module) -> None: + if not self.config.fsdp: + raise NotImplementedError("Separate local HF rollout actor currently only supports FSDP.") + if not self._uses_separate_hf_rollout_actor(): + raise RuntimeError("Separate local HF rollout actor is not initialized.") + + keep_on_gpu = self._keeps_separate_hf_rollout_actor_on_gpu() + sync_t0 = time.time() + actor_offloaded = False + offload_actor_t0 = time.time() + if not keep_on_gpu and self.rollout_train_actor_is_on_gpu: + self.offload_model(actor) + self.rollout_train_actor_is_on_gpu = False + actor_offloaded = True + offload_actor_s = time.time() - offload_actor_t0 + + rollout_offloaded = False + offload_rollout_t0 = time.time() + if not keep_on_gpu and self.inference_engine_status != EngineStatus.SLEEPED: + self.offload_model(self.inference_engine, empty_cache=False) + rollout_offloaded = True + offload_rollout_s = time.time() - offload_rollout_t0 + + copy_state_t0 = time.time() + self._copy_local_hf_rollout_actor_state(actor, self.inference_engine) + copy_state_s = time.time() - copy_state_t0 + + reload_rollout_t0 = time.time() + rollout_reloaded = False + if keep_on_gpu: + # `keep_on_gpu=True` skips the wakeup path, so we must explicitly + # rematerialize the rollout actor here after copying the updated state. 
+ self.reload_model(self.inference_engine) + rollout_reloaded = True + reload_rollout_s = time.time() - reload_rollout_t0 + + prepare_t0 = time.time() + self._prepare_separate_hf_rollout_actor_for_generation() + prepare_s = time.time() - prepare_t0 + + sync_clear_t0 = time.time() + if keep_on_gpu: + torch.cuda.synchronize() + torch.distributed.barrier() + self.inference_engine_status = EngineStatus.WAKEUP + else: + self.inference_engine_status = EngineStatus.SLEEPED + self.sync_and_clear_cache() + sync_clear_s = time.time() - sync_clear_t0 + self.last_separate_hf_rollout_sync_stats = { + "total_s": round(time.time() - sync_t0, 4), + "keep_on_gpu": keep_on_gpu, + "actor_offloaded": actor_offloaded, + "rollout_offloaded": rollout_offloaded, + "rollout_reloaded": rollout_reloaded, + "offload_actor_s": round(offload_actor_s, 4), + "offload_rollout_s": round(offload_rollout_s, 4), + "copy_state_s": round(copy_state_s, 4), + "reload_rollout_s": round(reload_rollout_s, 4), + "prepare_s": round(prepare_s, 4), + "sync_clear_s": round(sync_clear_s, 4), + } + self.print( + "Finished update engine weights for separate local HF rollout actor", + self.last_separate_hf_rollout_sync_stats, + ) + + def setup_inference_engine( + self, + args, + engine_type="vllm", + actor=None, + rollout_actor=None, + tokenizer=None, + processor=None, + ): """ Initialize and setup the inference engine. 
:param args: Configuration arguments :type args: argparse.Namespace - :param engine_type: Type of inference engine ('vllm' or 'sglang') + :param engine_type: Type of inference engine ('vllm', 'sglang', or 'hf') :type engine_type: str :param actor: The actor module, if passed, will be used to update engine weights :type actor: torch.nn.Module @@ -671,6 +844,14 @@ def setup_inference_engine(self, args, engine_type="vllm", actor=None): :raises ValueError: If engine_type is not supported """ self.inference_engine_type = engine_type + self.inference_tokenizer = tokenizer + self.inference_processor = processor + self.rollout_train_actor = None + self.rollout_train_actor_is_on_gpu = False + self.use_separate_hf_rollout_actor = False + self._separate_hf_rollout_sync_param_pairs = None + self._separate_hf_rollout_sync_buffer_pairs = None + self.last_separate_hf_rollout_sync_stats = None if engine_type == "vllm": # Conditional import: vLLM is optional and only imported when explicitly requested @@ -683,10 +864,30 @@ def setup_inference_engine(self, args, engine_type="vllm", actor=None): from .sglang_utils import get_sglang_engine_for_rollout self.inference_engine = get_sglang_engine_for_rollout(args) self.inference_engine_status = EngineStatus.WAKEUP + elif engine_type == "hf": + if actor is None: + raise ValueError("engine_type='hf' requires the prepared actor to be passed in.") + if getattr(args, "hf_separate_rollout_actor", False): + if rollout_actor is None: + raise ValueError( + "engine_type='hf' with --hf_separate_rollout_actor requires a prepared rollout_actor." 
+ ) + self.use_separate_hf_rollout_actor = True + self.rollout_train_actor = actor + self.rollout_train_actor_is_on_gpu = True + self.inference_engine = rollout_actor + self._prepare_separate_hf_rollout_actor_for_generation() + self.inference_engine_status = ( + EngineStatus.WAKEUP if self._keeps_separate_hf_rollout_actor_on_gpu() else EngineStatus.SLEEPED + ) + else: + # Local HF mode reuses the actor directly for time-boxed smoke runs. + self.inference_engine = actor + self.inference_engine_status = EngineStatus.WAKEUP else: raise ValueError(f"Unsupported engine type: {engine_type}") - if actor is not None: + if actor is not None and (engine_type != "hf" or self.use_separate_hf_rollout_actor): self.update_engine_weights(actor) self.maybe_sleep_inference_engine() return self.inference_engine @@ -700,15 +901,33 @@ def maybe_sleep_inference_engine(self): :raises ValueError: If the inference engine type is not supported """ - if self.inference_engine is not None and self.args.enable_engine_sleep: + if self.inference_engine is None or self.inference_engine_status == EngineStatus.SLEEPED: + return + if self._keeps_separate_hf_rollout_actor_on_gpu(): + self._prepare_separate_hf_rollout_actor_for_generation() + self.inference_engine_status = EngineStatus.WAKEUP + return + if self.inference_engine is not None and ( + self.args.enable_engine_sleep or self._uses_separate_hf_rollout_actor() + ): + sleep_t0 = time.time() if self.inference_engine_type in ["vllm", "sglang"]: self.inference_engine.sleep() + elif self.inference_engine_type == "hf": + if self._uses_separate_hf_rollout_actor(): + self.offload_model(self.inference_engine) + if not self.rollout_train_actor_is_on_gpu: + self.reload_model(self.rollout_train_actor) + self.rollout_train_actor_is_on_gpu = True + self.rollout_train_actor.train() + else: + return else: raise ValueError(f"Unsupported engine type: {self.inference_engine_type}") self.inference_engine_status = EngineStatus.SLEEPED self.sync_and_clear_cache() - 
self.print("Sleeped inference engine") + self.print(f"Sleeped inference engine, TIMECOST {time.time() - sleep_t0}") def wakeup_inference_engine(self): """ @@ -722,11 +941,25 @@ def wakeup_inference_engine(self): """ if self.inference_engine is None or self.inference_engine_status == EngineStatus.WAKEUP: return + if self._keeps_separate_hf_rollout_actor_on_gpu(): + self._prepare_separate_hf_rollout_actor_for_generation() + self.inference_engine_status = EngineStatus.WAKEUP + return self.sync_and_clear_cache() wkup_t0 = time.time() if self.inference_engine_type in ["vllm", "sglang"]: self.inference_engine.wake_up() + elif self.inference_engine_type == "hf": + if self._uses_separate_hf_rollout_actor(): + if self.rollout_train_actor_is_on_gpu: + self.offload_model(self.rollout_train_actor) + self.rollout_train_actor_is_on_gpu = False + self.reload_model(self.inference_engine) + self._prepare_separate_hf_rollout_actor_for_generation() + self.print(f"Finished {self.inference_engine_type} wakeup, TIMECOST {time.time() - wkup_t0}") + self.inference_engine_status = EngineStatus.WAKEUP + return else: raise ValueError(f"Unsupported engine type: {self.inference_engine_type}") # torch.cuda.reset_max_memory_allocated() @@ -779,6 +1012,12 @@ def engine_generate_local( sampling_params: Any, prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, multi_modal_inputs: Optional[List[Dict[str, Any]]] = None, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.Tensor] = None, + video_grid_thw: Optional[torch.Tensor] = None, + images_num: Optional[List[int]] = None, + videos_num: Optional[List[int]] = None, ) -> List[EasyDict]: """ Perform text or multimodal generation using different inference engines based on the input mode. 
@@ -804,7 +1043,7 @@ def engine_generate_local( if prompt_token_ids is None and multi_modal_inputs is None: raise ValueError("Either prompt_token_ids or multi_modal_inputs must be provided.") - if prompt_token_ids is not None and multi_modal_inputs is not None: + if self.inference_engine_type != "hf" and prompt_token_ids is not None and multi_modal_inputs is not None: raise ValueError("Both prompt_token_ids and multi_modal_inputs can not be provided at the same time.") # if inference engine is vllm @@ -883,6 +1122,155 @@ def engine_generate_local( output_token_ids=sglang_outputs[i]["output_ids"], ) for i in range(len(sglang_outputs)) ] + elif self.inference_engine_type == "hf": + if prompt_token_ids is None: + raise ValueError("Local HF inference requires prompt_token_ids.") + + from lightrft.datasets.utils import zero_pad_sequences + + model = getattr(self.inference_engine, "model", self.inference_engine) + model_config = getattr(model, "config", None) + eos_token_id = getattr(self.inference_tokenizer, "eos_token_id", None) + pad_token_id = getattr(self.inference_tokenizer, "pad_token_id", None) + + if eos_token_id is None and model_config is not None: + eos_token_id = getattr(model_config, "eos_token_id", None) + if pad_token_id is None and model_config is not None: + pad_token_id = getattr(model_config, "pad_token_id", None) + + if isinstance(eos_token_id, (list, tuple)): + eos_token_id = eos_token_id[0] + if isinstance(pad_token_id, (list, tuple)): + pad_token_id = pad_token_id[0] + if pad_token_id is None: + pad_token_id = eos_token_id + if eos_token_id is None: + raise ValueError("Unable to resolve eos_token_id for local HF inference engine.") + + normalized_prompt_ids = [] + for token_ids in prompt_token_ids: + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.tolist() + normalized_prompt_ids.append(token_ids) + + device = torch.cuda.current_device() + + def _prepare_tensor(tensor): + if tensor is None: + return None + if isinstance(tensor, 
torch.Tensor) and tensor.numel() == 0: + return None + return tensor.to(device, non_blocking=True) + + top_k = sampling_params.get("top_k", None) + if top_k is not None and top_k <= 0: + top_k = None + temperature = sampling_params.get("temperature", 1.0) + do_sample = sampling_params.get("do_sample", temperature is None or temperature > 0) + + def _run_local_hf_batch( + batch_prompt_token_ids, + batch_pixel_values=None, + batch_image_grid_thw=None, + batch_pixel_values_videos=None, + batch_video_grid_thw=None, + ): + prompt_tensors = [torch.tensor(token_ids, dtype=torch.long) for token_ids in batch_prompt_token_ids] + padded_input_ids = zero_pad_sequences(prompt_tensors, side="left", value=pad_token_id).to(device) + attention_mask = padded_input_ids.ne(pad_token_id).long() + prompt_lengths = attention_mask.sum(dim=1).detach().cpu() + + generate_t0 = time.time() + with torch.no_grad(): + sequences, attention_mask_out, _ = self.inference_engine.generate( + input_ids=padded_input_ids, + attention_mask=attention_mask, + pixel_values=_prepare_tensor(batch_pixel_values), + image_grid_thw=_prepare_tensor(batch_image_grid_thw), + pixel_values_videos=_prepare_tensor(batch_pixel_values_videos), + video_grid_thw=_prepare_tensor(batch_video_grid_thw), + top_k=top_k, + top_p=sampling_params.get("top_p", 1.0), + temperature=temperature, + do_sample=do_sample, + max_new_tokens=sampling_params.get("max_new_tokens", 1024), + min_new_tokens=sampling_params.get("min_new_tokens", 1), + repetition_penalty=sampling_params.get("repetition_penalty", 1.0), + no_repeat_ngram_size=sampling_params.get("no_repeat_ngram_size", 0), + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + ) + elapsed_s = round(time.time() - generate_t0, 4) + self.print( + "Local HF model.generate finished:", { + "batch_size": len(batch_prompt_token_ids), + "prompt_tokens": [len(token_ids) for token_ids in batch_prompt_token_ids], + "elapsed_s": elapsed_s, + } + ) + + output_start_idx = 
padded_input_ids.size(1) + sequences = sequences.detach().cpu() + attention_mask_out = attention_mask_out.detach().cpu() + + batch_outputs = [] + for idx in range(sequences.size(0)): + total_length = int(attention_mask_out[idx].sum().item()) + generated_length = max(total_length - int(prompt_lengths[idx].item()), 0) + output_end_idx = output_start_idx + generated_length + batch_outputs.append( + EasyDict( + prompt_token_ids=batch_prompt_token_ids[idx], + output_token_ids=sequences[idx, output_start_idx:output_end_idx].tolist(), + ) + ) + return batch_outputs, elapsed_s + + max_batch_size = max(int(getattr(self.config, "local_hf_generate_max_batch_size", 0) or 0), 0) + if max_batch_size > 0 and len(normalized_prompt_ids) > max_batch_size: + images_prefix = None + if images_num is not None: + images_prefix = [0] + for num in images_num: + images_prefix.append(images_prefix[-1] + num) + videos_prefix = None + if videos_num is not None: + videos_prefix = [0] + for num in videos_num: + videos_prefix.append(videos_prefix[-1] + num) + + def _slice_modal_tensor(tensor, offsets, start, end): + if tensor is None: + return None + if not isinstance(tensor, torch.Tensor): + return tensor + if tensor.numel() == 0: + return tensor + if offsets is not None: + return tensor[offsets[start]:offsets[end]] + return tensor[start:end] + + engine_outputs = [] + for start in range(0, len(normalized_prompt_ids), max_batch_size): + end = min(start + max_batch_size, len(normalized_prompt_ids)) + batch_outputs, chunk_elapsed_s = _run_local_hf_batch( + normalized_prompt_ids[start:end], + batch_pixel_values=_slice_modal_tensor(pixel_values, images_prefix, start, end), + batch_image_grid_thw=_slice_modal_tensor(image_grid_thw, images_prefix, start, end), + batch_pixel_values_videos=_slice_modal_tensor(pixel_values_videos, videos_prefix, start, end), + batch_video_grid_thw=_slice_modal_tensor(video_grid_thw, videos_prefix, start, end), + ) + engine_outputs.extend(batch_outputs) + return 
engine_outputs + + batch_outputs, chunk_elapsed_s = _run_local_hf_batch( + normalized_prompt_ids, + batch_pixel_values=pixel_values, + batch_image_grid_thw=image_grid_thw, + batch_pixel_values_videos=pixel_values_videos, + batch_video_grid_thw=video_grid_thw, + ) + return batch_outputs else: raise ValueError(f"Unsupported engine type: {self.inference_engine_type}") @@ -980,6 +1368,10 @@ def gather_and_generate( images_num=None, all_videos=None, videos_num=None, + all_images_pixel_values=None, + all_videos_pixel_values=None, + all_images_grid_thw=None, + all_videos_grid_thw=None, ): """ Gather prompts across distributed ranks and perform text/multimodal generation. @@ -1020,6 +1412,8 @@ def gather_and_generate( """ if self.inference_engine is None: raise NotImplementedError("Inference engine is not initialized.") + if self._uses_separate_hf_rollout_actor(): + sleep_engine = True self.wakeup_inference_engine() # is_multimodal = all_images is not None @@ -1029,35 +1423,51 @@ def gather_and_generate( is_multimodal = (((all_images is not None) and any(img is not None for img in all_images)) or ((all_videos is not None) and any(vid is not None for vid in all_videos))) - if is_multimodal: - inputs = self._build_multimodal_inputs( - all_prompts=all_prompts, - all_images=all_images, - images_num=images_num, - all_videos=all_videos, - videos_num=videos_num, - ) - else: + if self.inference_engine_type == "hf": inputs = all_prompt_token_ids assert inputs is not None + self.print(f"Start VLM gather_and_generate ..., total prompts: {len(inputs)}") + all_outputs = self.engine_generate_local( + sampling_params=sampling_params, + prompt_token_ids=inputs, + pixel_values=all_images_pixel_values if is_multimodal else None, + image_grid_thw=all_images_grid_thw if is_multimodal else None, + pixel_values_videos=all_videos_pixel_values if is_multimodal else None, + video_grid_thw=all_videos_grid_thw if is_multimodal else None, + images_num=images_num if is_multimodal else None, + 
videos_num=videos_num if is_multimodal else None, + ) + local_outputs = all_outputs + else: + if is_multimodal: + inputs = self._build_multimodal_inputs( + all_prompts=all_prompts, + all_images=all_images, + images_num=images_num, + all_videos=all_videos, + videos_num=videos_num, + ) + else: + inputs = all_prompt_token_ids + assert inputs is not None - inputs = gather_inputs_object_for_inference(input_data=inputs, group=self.engine_mp_group) + inputs = gather_inputs_object_for_inference(input_data=inputs, group=self.engine_mp_group) - self.print(f"Start VLM gather_and_generate ..., total prompts: {len(inputs)}") + self.print(f"Start VLM gather_and_generate ..., total prompts: {len(inputs)}") - all_outputs = self.engine_generate_local( - sampling_params=sampling_params, - prompt_token_ids=None if is_multimodal else inputs, - multi_modal_inputs=inputs if is_multimodal else None, - ) + all_outputs = self.engine_generate_local( + sampling_params=sampling_params, + prompt_token_ids=None if is_multimodal else inputs, + multi_modal_inputs=inputs if is_multimodal else None, + ) - engine_mp_size = torch.distributed.get_world_size(self.engine_mp_group) - num_prompts_per_rank = len(all_outputs) // engine_mp_size - assert len(all_outputs) % engine_mp_size == 0 - cur_rank = torch.distributed.get_rank(self.engine_mp_group) - local_outputs = all_outputs[cur_rank * num_prompts_per_rank:(cur_rank + 1) * num_prompts_per_rank] + engine_mp_size = torch.distributed.get_world_size(self.engine_mp_group) + num_prompts_per_rank = len(all_outputs) // engine_mp_size + assert len(all_outputs) % engine_mp_size == 0 + cur_rank = torch.distributed.get_rank(self.engine_mp_group) + local_outputs = all_outputs[cur_rank * num_prompts_per_rank:(cur_rank + 1) * num_prompts_per_rank] - if self.inference_engine_type == "sglang": + if is_multimodal and self.inference_engine_type == "sglang": # For SGLang VLM case, prompt_token_ids is set to None in engine_generate_local # We need to fill it with the 
actual token_ids here for i, output in enumerate(local_outputs): @@ -1084,6 +1494,12 @@ def update_engine_weights(self, actor): if self.inference_engine is None: self.print("Skip update engine weights since inference engine is not initialized.") return + if self.inference_engine_type == "hf": + if self._uses_separate_hf_rollout_actor(): + self._sync_separate_hf_rollout_actor(actor) + else: + self.print("Skip update engine weights for local HF engine because it reuses the actor directly.") + return # 1. wakeup engine if sleeped self.wakeup_inference_engine() diff --git a/lightrft/strategy/vllm_utils/__init__.py b/lightrft/strategy/vllm_utils/__init__.py index fc9e67e6..2f6cd05f 100644 --- a/lightrft/strategy/vllm_utils/__init__.py +++ b/lightrft/strategy/vllm_utils/__init__.py @@ -14,6 +14,7 @@ To use vLLM backend, install with: pip install "LightRFT[vllm]" """ +import os from typing import Any @@ -156,6 +157,8 @@ def get_vllm_engine( dtype=dtype, tensor_parallel_size=tp_size, gpu_memory_utilization=mem_util, + trust_remote_code=True, + allowed_local_media_path=os.environ.get("VLLM_ALLOWED_LOCAL_MEDIA_PATH", "/"), distributed_executor_backend="external_launcher", worker_cls="lightrft.strategy.vllm_utils.vllm_worker_wrap_no_ray.WorkerWrap", enable_sleep_mode=enable_sleep, diff --git a/lightrft/trainer/fast_exp_maker.py b/lightrft/trainer/fast_exp_maker.py index 8230ea65..2d97f85d 100644 --- a/lightrft/trainer/fast_exp_maker.py +++ b/lightrft/trainer/fast_exp_maker.py @@ -25,6 +25,7 @@ import time import pathlib import warnings +from contextlib import contextmanager from typing import Callable, Dict, List, Tuple, Union, Optional from dataclasses import dataclass from copy import deepcopy @@ -138,6 +139,38 @@ class _SamplesOutput: inputs_extra_kwargs: Optional[dict] = None prompt_and_output: Optional[List[str]] = None + # Per-step PRM rewards (variant 2). 
When the reward model returns + # ``step_rewards`` / ``step_token_indices`` arrays alongside the + # trajectory-scalar score, they're stored here so that downstream + # advantage computation can scatter per-step credit to specific token + # positions (instead of only the EOS token). Both fields are + # micro-batch-sized lists; index i holds the per-step data for the i-th + # trajectory in the micro-batch as 1-D CPU tensors of equal length. + # Empty tensors mean "this trajectory uses trajectory-scalar rewards" + # and the legacy EOS-scatter path is taken in compute_reward. + step_rewards: Optional[List[torch.Tensor]] = None + step_token_indices: Optional[List[torch.Tensor]] = None + + +@dataclass +class _RewardBatchResult: + """Reward scores and optional auxiliary metrics for one micro-batch.""" + + scores: torch.Tensor + metrics: Optional[Dict[str, torch.Tensor]] = None + # Variant 2 (per-step PRM): when present, list[i] is a 1-D CPU tensor of + # step rewards / step boundary token indices for trajectory i in the + # micro-batch. When None or all-empty, the trajectory-scalar (EOS-scatter) + # path is used downstream. 
+ step_rewards: Optional[List[torch.Tensor]] = None + step_token_indices: Optional[List[torch.Tensor]] = None + + +class _NullProfiler: + @contextmanager + def section(self, _name: str): + yield + # ============================================================================ # Helper Classes @@ -284,8 +317,10 @@ def process_multimodal_batch( processor_kwargs = { "text": all_prompts_multimodal.copy(), "add_special_tokens": False, + "padding": True, "max_length": self.prompt_max_len, "truncation": True, + "return_tensors": "pt", } if flat_images: processor_kwargs["images"] = flat_images @@ -301,6 +336,15 @@ def process_multimodal_batch( all_images_grid_thw_multimodal = inputs_multimodal.get("image_grid_thw", None) all_videos_grid_thw_multimodal = inputs_multimodal.get("video_grid_thw", None) + # Some VLM processors (for example URSA) return batched image tensors directly + # and do not expose Qwen2-VL style grid metadata. In that case we synthesize a + # minimal per-image grid so the existing sample/replay slicing logic can still + # split one tensor slice per image while the model simply ignores the grid input. 
+ if flat_images and all_images_grid_thw_multimodal is None: + all_images_grid_thw_multimodal = torch.ones((len(flat_images), 3), dtype=torch.long) + if flat_videos and all_videos_grid_thw_multimodal is None: + all_videos_grid_thw_multimodal = torch.ones((len(flat_videos), 3), dtype=torch.long) + # ===== Stage 4: Merge back in original order ===== total_samples = L * N all_prompts_out = [None] * total_samples @@ -446,16 +490,7 @@ def __init__( :type packing_samples: bool """ self.reward_model = reward_model - # Ensure remote_rm_url is a list for consistent iteration - if remote_rm_url is None: - self.remote_rm_url = None - elif isinstance(remote_rm_url, str): - self.remote_rm_url = [remote_rm_url] - elif isinstance(remote_rm_url, (list, tuple)): - self.remote_rm_url = list(remote_rm_url) - else: - raise TypeError(f"remote_rm_url must be str, list, tuple, or None, got {type(remote_rm_url).__name__}") - + self.remote_rm_url = remote_rm_url self.custom_reward_func = custom_reward_func self.reward_fn = reward_fn self.reward_fn_label_map = reward_fn_label_map or {} @@ -581,7 +616,7 @@ def _compute_local_rewards( self.strategy.reload_model(rm) # Compute rewards for each RM - # all_rewards_list[rm_idx][micro_batch_idx] = Tensor(batch_size,) + # all_rewards_list[rm_idx][micro_batch_idx] = _RewardBatchResult(batch_size,) all_rewards_list = [] for rm_idx, rm in enumerate(rm_list): @@ -602,7 +637,7 @@ def _compute_single_rm_rewards( outputs: List[_SamplesOutput], vlm_mode: bool, device: torch.device, - ) -> List[torch.Tensor]: + ) -> List[_RewardBatchResult]: """ Compute rewards for a single reward model across all micro-batches. 
@@ -616,8 +651,8 @@ def _compute_single_rm_rewards( :type vlm_mode: bool :param device: Target device :type device: torch.device - :return: List of reward tensors, one per micro-batch - :rtype: List[torch.Tensor] + :return: List of reward results, one per micro-batch + :rtype: List[_RewardBatchResult] """ # Check if this is a custom engine model (non-torch base_model) is_custom_engine = ( @@ -640,7 +675,7 @@ def _compute_filtered_rewards( rm_idx: int, outputs: List[_SamplesOutput], device: torch.device, - ) -> List[torch.Tensor]: + ) -> List[_RewardBatchResult]: """ Compute rewards using optimized filtering (only process relevant samples). @@ -655,8 +690,8 @@ def _compute_filtered_rewards( :type outputs: List[_SamplesOutput] :param device: Target device :type device: torch.device - :return: List of reward tensors per micro-batch - :rtype: List[torch.Tensor] + :return: List of reward results per micro-batch + :rtype: List[_RewardBatchResult] """ # Get RM key from inverse label map rm_key = self.inv_label_map.get(rm_idx) @@ -695,7 +730,12 @@ def _compute_filtered_rewards( # ========== Process Stage: Compute or skip ========== if not needed_positions: # No samples need this RM, return zeros for all micro-batches - return [torch.zeros(len(output.labels), dtype=torch.float32, device=device) for output in outputs] + return [ + _RewardBatchResult( + scores=torch.zeros(len(output.labels), dtype=torch.float32, device=device), + metrics=None, + ) for output in outputs + ] # Run single forward pass on filtered samples rm_output = rm( @@ -717,14 +757,14 @@ def _compute_filtered_rewards( for (mb_idx, samp_idx), score in zip(needed_positions, filtered_scores): micro_batch_rewards[mb_idx][samp_idx] = score - return micro_batch_rewards + return [_RewardBatchResult(scores=rewards, metrics=None) for rewards in micro_batch_rewards] def _compute_batched_custom_engine_rewards( self, rm, outputs: List[_SamplesOutput], device: torch.device, # noqa: ARG002 (unused but kept for API 
consistency) - ) -> List[torch.Tensor]: + ) -> List[_RewardBatchResult]: """ Compute rewards using custom engine with full batch processing (legacy path). @@ -734,8 +774,8 @@ def _compute_batched_custom_engine_rewards( :type outputs: List[_SamplesOutput] :param device: Target device (unused but kept for API consistency) :type device: torch.device - :return: List of reward tensors per micro-batch - :rtype: List[torch.Tensor] + :return: List of reward results per micro-batch + :rtype: List[_RewardBatchResult] """ # Flatten all micro-batches into single batch flat_data = { @@ -767,7 +807,34 @@ def _compute_batched_custom_engine_rewards( # Split back into micro-batches batch_sizes = [len(output.prompt_and_output) for output in outputs] - return list(all_scores.split(batch_sizes)) + return [ + _RewardBatchResult(scores=scores.to(device=device, dtype=torch.float32), metrics=None) + for scores in all_scores.split(batch_sizes) + ] + + @staticmethod + def _normalize_reward_metrics( + rm_output: Dict[str, torch.Tensor], + batch_size: int, + device: torch.device, + ) -> Optional[Dict[str, torch.Tensor]]: + metrics: Dict[str, torch.Tensor] = {} + for key, value in rm_output.items(): + if key == "score": + continue + if not isinstance(value, torch.Tensor): + if isinstance(value, (int, float, bool)): + value = torch.tensor(value, dtype=torch.float32) + else: + continue + metric = value.to(device=device) + if metric.ndim == 0: + metric = metric.repeat(batch_size) + metric = metric.reshape(-1).float() + if metric.numel() != batch_size: + continue + metrics[key] = metric + return metrics or None def _compute_standard_torch_rewards( self, @@ -775,7 +842,7 @@ def _compute_standard_torch_rewards( outputs: List[_SamplesOutput], vlm_mode: bool, # noqa: ARG002 (kept for future VLM-specific logic) device: torch.device, - ) -> List[torch.Tensor]: + ) -> List[_RewardBatchResult]: """ Compute rewards using standard PyTorch reward model. 
@@ -789,10 +856,10 @@ def _compute_standard_torch_rewards( :type vlm_mode: bool :param device: Target device :type device: torch.device - :return: List of reward tensors per micro-batch - :rtype: List[torch.Tensor] + :return: List of reward results per micro-batch + :rtype: List[_RewardBatchResult] """ - micro_batch_rewards = [] + micro_batch_rewards: List[_RewardBatchResult] = [] for output in outputs: # Unpack sequences if needed @@ -805,22 +872,44 @@ def _compute_standard_torch_rewards( rm_output = rm( sequences, output.attention_mask, - references=output.references, prompt_and_output=output.prompt_and_output, raw_images=output.raw_images, img_num=output.image_num, + references=output.references, + labels=output.labels, **output.inputs_extra_kwargs, ) - score = rm_output["score"] if isinstance(rm_output, dict) else rm_output - micro_batch_rewards.append(torch.as_tensor(score, dtype=torch.float32, device=device)) + if isinstance(rm_output, dict): + score = torch.as_tensor(rm_output["score"], dtype=torch.float32, device=device) + metrics = self._normalize_reward_metrics(rm_output, score.numel(), device) + # Variant 2 (per-step PRM reward): the reward model may emit + # per-trajectory step_rewards / step_token_indices lists. They + # come back as Python lists of CPU tensors (variable length per + # trajectory) — pass them through unchanged so compute_reward + # downstream can scatter per-step credit. 
+ step_rewards = rm_output.get("step_rewards") + step_token_indices = rm_output.get("step_token_indices") + else: + score = torch.as_tensor(rm_output, dtype=torch.float32, device=device) + metrics = None + step_rewards = None + step_token_indices = None + micro_batch_rewards.append( + _RewardBatchResult( + scores=score, + metrics=metrics, + step_rewards=step_rewards, + step_token_indices=step_token_indices, + ) + ) return micro_batch_rewards def _aggregate_rewards( self, outputs: List[_SamplesOutput], - all_rewards_list: List[List[torch.Tensor]], + all_rewards_list: List[List[_RewardBatchResult]], is_multi_rm: bool, ) -> None: """ @@ -828,8 +917,8 @@ def _aggregate_rewards( :param outputs: Sample outputs (modified in-place) :type outputs: List[_SamplesOutput] - :param all_rewards_list: Nested list [rm_idx][micro_batch_idx] -> Tensor - :type all_rewards_list: List[List[torch.Tensor]] + :param all_rewards_list: Nested list [rm_idx][micro_batch_idx] -> reward result + :type all_rewards_list: List[List[_RewardBatchResult]] :param is_multi_rm: Whether using multiple reward models :type is_multi_rm: bool """ @@ -838,7 +927,8 @@ def _aggregate_rewards( for mb_idx in range(num_micro_batches): # Collect rewards from all RMs for this micro-batch - same_batch_rewards = [all_rewards_list[rm_idx][mb_idx] for rm_idx in range(num_rms)] + same_batch_results = [all_rewards_list[rm_idx][mb_idx] for rm_idx in range(num_rms)] + same_batch_rewards = [result.scores for result in same_batch_results] if is_multi_rm: # Use custom aggregation function @@ -850,6 +940,7 @@ def _aggregate_rewards( rewards, reward_metrics = self.reward_fn( model_reward_list=same_batch_rewards, + model_reward_metrics_list=[result.metrics for result in same_batch_results], labels=outputs[mb_idx].labels, queries=queries, refs=outputs[mb_idx].references, @@ -859,8 +950,17 @@ def _aggregate_rewards( outputs[mb_idx].reward_metrics = reward_metrics else: # Single RM, use score directly - outputs[mb_idx].rewards = 
same_batch_rewards[0] - outputs[mb_idx].reward_metrics = None + outputs[mb_idx].rewards = same_batch_results[0].scores + outputs[mb_idx].reward_metrics = same_batch_results[0].metrics + # Variant 2 (per-step PRM): forward step_rewards / token + # indices from the single reward model into outputs so + # _process_experiences / compute_reward can see them. Multi-RM + # mode is intentionally NOT supported here — per-step credit + # assignment with multiple reward models would require a + # bespoke aggregator (open question for follow-up); we keep + # the multi-RM path on trajectory-scalar rewards only. + outputs[mb_idx].step_rewards = same_batch_results[0].step_rewards + outputs[mb_idx].step_token_indices = same_batch_results[0].step_token_indices # ============================================================================ @@ -893,7 +993,7 @@ class FastExperienceMaker(NaiveExperienceMaker): processor: Multimodal processor for vision-language models *args, **kwargs: Arguments passed to parent NaiveExperienceMaker """ - def __init__(self, *args, packing_samples: bool = False, processor=None, **kwargs): + def __init__(self, *args, packing_samples: bool = False, processor=None, profiler=None, **kwargs): """ Initialize FastExperienceMaker. @@ -913,6 +1013,7 @@ def __init__(self, *args, packing_samples: bool = False, processor=None, **kwarg self.backend = self.strategy.args.engine_type self.packing_samples = packing_samples self.processor = processor + self.profiler = profiler if profiler is not None else _NullProfiler() # Initialize tokenizer (extract from processor if needed) if self.processor is not None: @@ -938,31 +1039,9 @@ def __init__(self, *args, packing_samples: bool = False, processor=None, **kwarg else: self.multimodal_processor = None - # For On-Policy Distillation (OPD), prefer dedicated teacher_model_url. - # Fall back to remote_rm_url with deprecation warning for backwards compatibility. 
- if advantage_estimator == "on_policy_distillation": - teacher_url = getattr(self.strategy.args, 'teacher_model_url', None) - if teacher_url is not None: - self.teacher_model_url = teacher_url - elif self.remote_rm_url is not None: - import warnings - warnings.warn( - "Using --remote_rm_url as teacher URL is deprecated. " - "Use --teacher_model_url instead.", - DeprecationWarning, - stacklevel=2, - ) - self.teacher_model_url = self.remote_rm_url - else: - self.teacher_model_url = None - rm_url_for_reward_engine = None # Don't use remote_rm_url for rewards in OPD mode - else: - self.teacher_model_url = None - rm_url_for_reward_engine = self.remote_rm_url - self.reward_engine = RewardComputationEngine( reward_model=self.reward_model, - remote_rm_url=rm_url_for_reward_engine, + remote_rm_url=self.remote_rm_url, custom_reward_func=getattr(self, "custom_reward_func", None), reward_fn=self.reward_fn, reward_fn_label_map=getattr(self, "reward_fn_label_map", None), @@ -1014,83 +1093,76 @@ def make_experience_list( """ config = self.strategy.config - # Normalize images if provided - if all_images is not None: - if self.multimodal_processor is None: - raise ValueError( - "Multimodal data (images) provided but processor was not initialized. " - "Please provide a processor when initializing FastExperienceMaker for VLM support." - ) - all_images = normalize_images(all_images) - - # Normalize videos if provided - if all_videos is not None: - if self.multimodal_processor is None: - raise ValueError( - "Multimodal data (videos) provided but processor was not initialized. " - "Please provide a processor when initializing FastExperienceMaker for VLM support." + with self.profiler.section("collect/total"): + if all_images is not None: + if self.multimodal_processor is None: + raise ValueError( + "Multimodal data (images) provided but processor was not initialized. " + "Please provide a processor when initializing FastExperienceMaker for VLM support." 
+ ) + all_images = normalize_images(all_images) + + if all_videos is not None: + if self.multimodal_processor is None: + raise ValueError( + "Multimodal data (videos) provided but processor was not initialized. " + "Please provide a processor when initializing FastExperienceMaker for VLM support." + ) + all_videos = normalize_videos(all_videos) + + images_num = (get_images_num(all_images) if self.multimodal_processor and all_images is not None else None) + videos_num = (get_videos_num(all_videos) if self.multimodal_processor and all_videos is not None else None) + + Timer.start(' generate_samples') + with self.profiler.section("collect/generate"): + samples_list = self.generate_samples( + all_prompts, + all_images=all_images, + images_num=images_num, + all_videos=all_videos, + videos_num=videos_num, + all_references=all_references, + all_labels=all_labels, + **generate_kwargs, ) - all_videos = normalize_videos(all_videos) - - # Get image counts - images_num = (get_images_num(all_images) if self.multimodal_processor and all_images is not None else None) - - # Get video counts - videos_num = (get_videos_num(all_videos) if self.multimodal_processor and all_videos is not None else None) - - # ========== Stage 1: Sample Generation ========== - Timer.start(' generate_samples') - samples_list = self.generate_samples( - all_prompts, - all_images=all_images, - images_num=images_num, - all_videos=all_videos, - videos_num=videos_num, - all_references=all_references, - all_labels=all_labels, - **generate_kwargs, - ) - Timer.stop(' generate_samples') + Timer.stop(' generate_samples') - torch.distributed.barrier() - torch.cuda.synchronize() + torch.distributed.barrier() + torch.cuda.synchronize() - # ========== Stage 2: Shard-Parallel Preprocessing ========== - all_samples = self.strategy.sp_data_processor.preprocess(samples_list) + with self.profiler.section("collect/sp_preprocess"): + all_samples = self.strategy.sp_data_processor.preprocess(samples_list) - # ========== Stage 
3: Model Inference ========== - Timer.start(' make_experience') - experiences = self._make_experience_list_by_model(all_samples) - Timer.stop(' make_experience') + Timer.start(' make_experience') + with self.profiler.section("collect/model_total"): + experiences = self._make_experience_list_by_model(all_samples) + Timer.stop(' make_experience') - # ========== Stage 4: Shard-Parallel Postprocessing ========== - experiences = self.strategy.sp_data_processor.postprocess(experiences) + with self.profiler.section("collect/sp_postprocess"): + experiences = self.strategy.sp_data_processor.postprocess(experiences) - # ========== Stage 5: Reward Processing ========== - experiences, rewards = self._process_experiences( # GRPO's -mean / std operation is performed in this method - experiences, generate_kwargs.get("max_new_tokens", 1024) - ) + with self.profiler.section("collect/process_rewards"): + experiences, rewards = self._process_experiences( + experiences, generate_kwargs.get("max_new_tokens", 1024) + ) - # ========== Stage 6: Multi-Image/Video Handling ========== - if (images_num is not None and not all(num == 1 for num in images_num)) or \ - (videos_num is not None and not all(num == 1 for num in videos_num)): - # Expand image_num by n_samples_per_prompt - expanded_images_num = sum([[num] * config.n_samples_per_prompt - for num in images_num], []) if images_num is not None else None + if (images_num is not None and not all(num == 1 for num in images_num)) or \ + (videos_num is not None and not all(num == 1 for num in videos_num)): + expanded_images_num = sum([[num] * config.n_samples_per_prompt + for num in images_num], []) if images_num is not None else None - expanded_videos_num = sum([[num] * config.n_samples_per_prompt - for num in videos_num], []) if videos_num is not None else None + expanded_videos_num = sum([[num] * config.n_samples_per_prompt + for num in videos_num], []) if videos_num is not None else None - self._process_multi_image_video_thws(experiences, 
expanded_images_num, expanded_videos_num) + self._process_multi_image_video_thws(experiences, expanded_images_num, expanded_videos_num) - # ========== Stage 6.5: On-Policy Distillation Teacher Log-Probs ========== - if config.advantage_estimator == "on_policy_distillation": - self._fetch_teacher_logprobs(experiences) + if config.advantage_estimator == "on_policy_distillation": + self._fetch_teacher_logprobs(experiences) - # ========== Stage 7: Advantage Computation ========== - experiences = self._compute_advantages_and_returns(experiences, rewards, generate_kwargs) + with self.profiler.section("collect/advantages"): + experiences = self._compute_advantages_and_returns(experiences, rewards, generate_kwargs) - return experiences + return experiences @torch.no_grad() def generate_samples( @@ -1141,7 +1213,6 @@ def generate_samples( is_multimodal = all_images is not None or all_videos is not None n_samples = config.n_samples_per_prompt - # Initialize multimodal-specific variables to None all_images_num = None all_videos_num = None all_images_pixel_values = None @@ -1149,144 +1220,144 @@ def generate_samples( all_images_grid_thw = None all_videos_grid_thw = None - # ========== Configure Sampling Parameters ========== - if config.engine_type == "vllm": - # vLLM-specific sampling configuration - # Note: vLLM is an optional dependency. 
Install with: pip install "LightRFT[vllm]" - # This import is conditional and only executed when engine_type is "vllm" - from vllm import SamplingParams - - # For vllm>=0.13.0, truncate_prompt_tokens must not exceed max_model_len - # For older versions, we can use 8192 directly without validation - if vllm_ge_0130(): - max_model_len = self.strategy.inference_engine.llm_engine.model_config.max_model_len - truncate_tokens = min(8192, max_model_len) + with self.profiler.section("collect/generate_prepare"): + if config.engine_type == "vllm": + from vllm import SamplingParams + + if vllm_ge_0130(): + max_model_len = self.strategy.inference_engine.llm_engine.model_config.max_model_len + truncate_tokens = min(8192, max_model_len) + else: + truncate_tokens = 8192 + + sampling_params = SamplingParams( + temperature=generate_kwargs.get("temperature", 1.0), + top_p=generate_kwargs.get("top_p", 1.0), + top_k=generate_kwargs.get("top_k", -1), + repetition_penalty=generate_kwargs.get("repetition_penalty", 1.0), + max_tokens=generate_kwargs.get("max_new_tokens", 1024), + min_tokens=generate_kwargs.get("min_new_tokens", 1), + skip_special_tokens=generate_kwargs.get("skip_special_tokens", False), + include_stop_str_in_output=True, + ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1", + truncate_prompt_tokens=truncate_tokens, + ) + elif config.engine_type == "sglang": + sampling_params = dict( + n=1, + temperature=generate_kwargs.get("temperature", 1.0), + top_p=generate_kwargs.get("top_p", 1.0), + top_k=generate_kwargs.get("top_k", -1), + max_new_tokens=generate_kwargs.get("max_new_tokens", 1024), + presence_penalty=0.0, + frequency_penalty=0.0, + repetition_penalty=generate_kwargs.get("repetition_penalty", 1.0), + skip_special_tokens=generate_kwargs.get("skip_special_tokens", False), + spaces_between_special_tokens=True, + ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1", + ) + elif config.engine_type == "hf": + max_new_tokens = generate_kwargs.get("max_new_tokens", 1024) + 
local_hf_max_new_tokens = int(getattr(config, "local_hf_max_new_tokens", 0) or 0) + if local_hf_max_new_tokens > 0: + max_new_tokens = min(max_new_tokens, local_hf_max_new_tokens) + sampling_params = dict( + temperature=generate_kwargs.get("temperature", 1.0), + top_p=generate_kwargs.get("top_p", 1.0), + top_k=generate_kwargs.get("top_k", -1), + max_new_tokens=max_new_tokens, + min_new_tokens=generate_kwargs.get("min_new_tokens", 1), + do_sample=generate_kwargs.get("do_sample", True), + repetition_penalty=generate_kwargs.get("repetition_penalty", 1.0), + no_repeat_ngram_size=generate_kwargs.get("no_repeat_ngram_size", 0), + skip_special_tokens=generate_kwargs.get("skip_special_tokens", False), + ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1", + ) else: - truncate_tokens = 8192 - - sampling_params = SamplingParams( - temperature=generate_kwargs.get("temperature", 1.0), - top_p=generate_kwargs.get("top_p", 1.0), - top_k=generate_kwargs.get("top_k", -1), - max_tokens=generate_kwargs.get("max_new_tokens", 1024), - min_tokens=generate_kwargs.get("min_new_tokens", 1), - skip_special_tokens=generate_kwargs.get("skip_special_tokens", False), - include_stop_str_in_output=True, - ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1", - truncate_prompt_tokens=truncate_tokens, - ) - elif config.engine_type == "sglang": - # SGLang-specific sampling configuration (default backend) - sampling_params = dict( - n=1, - temperature=generate_kwargs.get("temperature", 1.0), - top_p=generate_kwargs.get("top_p", 1.0), - top_k=generate_kwargs.get("top_k", -1), - max_new_tokens=generate_kwargs.get("max_new_tokens", 1024), - presence_penalty=0.0, - frequency_penalty=0.0, - repetition_penalty=1.0, - skip_special_tokens=generate_kwargs.get("skip_special_tokens", False), - spaces_between_special_tokens=True, - ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1", - ) - else: - raise ValueError(f"Unsupported engine type: {config.engine_type}") + raise ValueError(f"Unsupported engine type: 
{config.engine_type}") - # ========== Expand Labels ========== - if all_labels is not None: - all_labels = sum([[label] * n_samples for label in all_labels], []) + if all_labels is not None: + all_labels = sum([[label] * n_samples for label in all_labels], []) - # ========== Process Multimodal Data ========== - if is_multimodal: - processed_data = self.multimodal_processor.process_multimodal_batch( - all_prompts=all_prompts, - all_images=all_images, - all_references=all_references, - images_num=images_num, - n_samples_per_prompt=n_samples, - all_videos=all_videos, - videos_num=videos_num, - ) - all_prompt_token_ids = processed_data["all_prompt_token_ids"] - all_prompts = processed_data["all_prompts"] - all_images = processed_data["all_images"] - all_videos = processed_data["all_videos"] - all_images_num = processed_data["all_images_num"] - all_videos_num = processed_data["all_videos_num"] - all_images_grid_thw = processed_data["all_images_grid_thw"] - all_videos_grid_thw = processed_data["all_videos_grid_thw"] - all_images_pixel_values = processed_data["all_images_pixel_values"] - all_videos_pixel_values = processed_data["all_videos_pixel_values"] - all_references = processed_data.get("all_references", None) - else: - # Text-only processing - tokenized = self.tokenize_fn(all_prompts, self.prompt_max_len, padding=False) - all_prompt_token_ids = sum([[token_ids] * n_samples for token_ids in tokenized["input_ids"]], []) + if is_multimodal: + processed_data = self.multimodal_processor.process_multimodal_batch( + all_prompts=all_prompts, + all_images=all_images, + all_references=all_references, + images_num=images_num, + n_samples_per_prompt=n_samples, + all_videos=all_videos, + videos_num=videos_num, + ) + all_prompt_token_ids = processed_data["all_prompt_token_ids"] + all_prompts = processed_data["all_prompts"] + all_images = processed_data["all_images"] + all_videos = processed_data["all_videos"] + all_images_num = processed_data["all_images_num"] + all_videos_num = 
processed_data["all_videos_num"] + all_images_grid_thw = processed_data["all_images_grid_thw"] + all_videos_grid_thw = processed_data["all_videos_grid_thw"] + all_images_pixel_values = processed_data["all_images_pixel_values"] + all_videos_pixel_values = processed_data["all_videos_pixel_values"] + all_references = processed_data.get("all_references", None) + else: + tokenized = self.tokenize_fn(all_prompts, self.prompt_max_len, padding=False) + all_prompt_token_ids = sum([[token_ids] * n_samples for token_ids in tokenized["input_ids"]], []) - # ========== Generate via Inference Engine ========== - # Call fire_sampling function or direct generation try: - if hasattr(self.strategy.args, 'use_fire') and self.strategy.args.use_fire: - # Use FIRE sampling (Flaming-hot Initiation with Regular Execution) - # According to the paper (https://arxiv.org/abs/2410.21236), FIRE only changes - # the temperature for the first token. All other sampling parameters (top_k, top_p, etc.) - # are kept the same between first token and remaining tokens. 
- sleep_engine = getattr(self.strategy.args, 'enable_engine_sleep', False) - - def generate_fn( - sampling_params, - all_prompt_token_ids, - all_prompts=None, - all_images=None, - all_videos=None, - images_num=None, - videos_num=None, - ): - return self.strategy.gather_and_generate( - sampling_params=sampling_params, + with self.profiler.section("collect/generate_engine"): + if hasattr(self.strategy.args, 'use_fire') and self.strategy.args.use_fire: + sleep_engine = getattr(self.strategy.args, 'enable_engine_sleep', False) + + def generate_fn( + sampling_params, + all_prompt_token_ids, + all_prompts=None, + all_images=None, + all_videos=None, + images_num=None, + videos_num=None, + ): + return self.strategy.gather_and_generate( + sampling_params=sampling_params, + all_prompt_token_ids=all_prompt_token_ids, + all_prompts=all_prompts, + all_images=all_images, + all_videos=all_videos, + images_num=images_num, + videos_num=videos_num, + sleep_engine=sleep_engine, + ) + + all_outputs = fire_sampling( all_prompt_token_ids=all_prompt_token_ids, + generate_fn=generate_fn, + engine_type=config.engine_type, + first_token_temperature=generate_kwargs.get("first_token_temperature", 10.0), + temperature=generate_kwargs.get("temperature", 1.0), + is_multimodal=is_multimodal, all_prompts=all_prompts, all_images=all_images, all_videos=all_videos, - images_num=images_num, - videos_num=videos_num, - sleep_engine=sleep_engine, + all_images_num=all_images_num, + all_videos_num=all_videos_num, + sampling_params=sampling_params, + ) + else: + all_outputs = self.strategy.gather_and_generate( + sampling_params=sampling_params, + all_prompt_token_ids=all_prompt_token_ids, + all_prompts=all_prompts if is_multimodal else None, + sleep_engine=self.strategy.args.enable_engine_sleep, + all_images=all_images if is_multimodal else None, + all_videos=all_videos if is_multimodal else None, + images_num=all_images_num if is_multimodal else None, + videos_num=all_videos_num if is_multimodal else None, 
+ all_images_pixel_values=all_images_pixel_values if is_multimodal else None, + all_videos_pixel_values=all_videos_pixel_values if is_multimodal else None, + all_images_grid_thw=all_images_grid_thw if is_multimodal else None, + all_videos_grid_thw=all_videos_grid_thw if is_multimodal else None, ) - - all_outputs = fire_sampling( - all_prompt_token_ids=all_prompt_token_ids, - generate_fn=generate_fn, - engine_type=config.engine_type, - first_token_temperature=generate_kwargs.get( - "first_token_temperature", - getattr(self.strategy.args, "first_token_temperature", 10.0), - ), - temperature=generate_kwargs.get("temperature", 1.0), - # Note: first_token_top_k and first_token_top_p are deprecated and ignored - # The function will use top_k and top_p from sampling_params for both stages - is_multimodal=is_multimodal, - all_prompts=all_prompts, - all_images=all_images, - all_videos=all_videos, - all_images_num=all_images_num, - all_videos_num=all_videos_num, - sampling_params=sampling_params, - tokenizer=self.tokenizer, - ) - else: - # maybe this can be called in if and else respectively? or like this? 
- # Use original single-shot generation - all_outputs = self.strategy.gather_and_generate( - sampling_params=sampling_params, - all_prompt_token_ids=all_prompt_token_ids, - all_prompts=all_prompts if is_multimodal else None, - sleep_engine=self.strategy.args.enable_engine_sleep, - all_images=all_images if is_multimodal else None, - all_videos=all_videos if is_multimodal else None, - images_num=all_images_num if is_multimodal else None, - videos_num=all_videos_num if is_multimodal else None, - ) except ValueError as e: if "prompt" in str(e) and "too long" in str(e): self.strategy.print(f"[Skip] {e}") @@ -1294,68 +1365,67 @@ def generate_fn( else: raise - # ========== Process Outputs into Samples ========== - samples_list = [] - image_patch_idx = 0 - video_patch_idx = 0 - image_start_idx = 0 - video_start_idx = 0 - - for i in range(0, len(all_outputs), config.micro_rollout_batch_size): - micro_batch_outputs = all_outputs[i:i + config.micro_rollout_batch_size] - micro_batch_prompts = all_prompts[i:i + config.micro_rollout_batch_size] - - # Extract micro-batch data - micro_batch_grid_thw = None - micro_batch_video_grid_thw = None - micro_batch_raw_images = None - - if is_multimodal: - rollout_image_count = sum(all_images_num[i:i + config.micro_rollout_batch_size]) - micro_batch_grid_thw = all_images_grid_thw[image_start_idx:image_start_idx + rollout_image_count] - micro_batch_raw_images = all_images[i:i + config.micro_rollout_batch_size] - image_start_idx += rollout_image_count - - rollout_video_count = sum(all_videos_num[i:i + config.micro_rollout_batch_size]) - micro_batch_video_grid_thw = all_videos_grid_thw[video_start_idx:video_start_idx + rollout_video_count] - video_start_idx += rollout_video_count - - micro_batch_references = (all_references[i:i + config.micro_rollout_batch_size] if all_references else None) - micro_batch_labels = (all_labels[i:i + config.micro_rollout_batch_size] if all_labels else None) - - # Build samples - if not self.packing_samples: - 
sample, updated_patch_idx, updated_video_patch_idx = self._build_unpacked_sample( - outputs=micro_batch_outputs, - prompts=micro_batch_prompts, - labels=micro_batch_labels, - references=micro_batch_references, - is_multimodal=is_multimodal, - grid_thw=micro_batch_grid_thw, - video_grid_thw=micro_batch_video_grid_thw, - raw_images=micro_batch_raw_images, - pixel_values=all_images_pixel_values if is_multimodal else None, - pixel_values_videos=all_videos_pixel_values if is_multimodal else None, - images_num=all_images_num[i:i + config.micro_rollout_batch_size] if is_multimodal else None, - videos_num=all_videos_num[i:i + config.micro_rollout_batch_size] if is_multimodal else None, - image_patch_idx=image_patch_idx, - video_patch_idx=video_patch_idx, - ) - # Update patch indices from the returned values - if updated_patch_idx is not None: - image_patch_idx = updated_patch_idx - if updated_video_patch_idx is not None: - video_patch_idx = updated_video_patch_idx - samples_list.append(sample) - else: - # Packed samples - sample = self._build_packed_sample( - outputs=micro_batch_outputs, - prompts=micro_batch_prompts, - labels=micro_batch_labels, - references=micro_batch_references, + with self.profiler.section("collect/generate_build_samples"): + samples_list = [] + image_patch_idx = 0 + video_patch_idx = 0 + image_start_idx = 0 + video_start_idx = 0 + + for i in range(0, len(all_outputs), config.micro_rollout_batch_size): + micro_batch_outputs = all_outputs[i:i + config.micro_rollout_batch_size] + micro_batch_prompts = all_prompts[i:i + config.micro_rollout_batch_size] + + micro_batch_grid_thw = None + micro_batch_video_grid_thw = None + micro_batch_raw_images = None + + if is_multimodal: + rollout_image_count = sum(all_images_num[i:i + config.micro_rollout_batch_size]) + micro_batch_grid_thw = all_images_grid_thw[image_start_idx:image_start_idx + rollout_image_count] + micro_batch_raw_images = all_images[i:i + config.micro_rollout_batch_size] + image_start_idx += 
rollout_image_count + + rollout_video_count = sum(all_videos_num[i:i + config.micro_rollout_batch_size]) + micro_batch_video_grid_thw = all_videos_grid_thw[video_start_idx:video_start_idx + + rollout_video_count] + video_start_idx += rollout_video_count + + micro_batch_references = ( + all_references[i:i + config.micro_rollout_batch_size] if all_references else None ) - samples_list.append(sample) + micro_batch_labels = (all_labels[i:i + config.micro_rollout_batch_size] if all_labels else None) + + if not self.packing_samples: + sample, updated_patch_idx, updated_video_patch_idx = self._build_unpacked_sample( + outputs=micro_batch_outputs, + prompts=micro_batch_prompts, + labels=micro_batch_labels, + references=micro_batch_references, + is_multimodal=is_multimodal, + grid_thw=micro_batch_grid_thw, + video_grid_thw=micro_batch_video_grid_thw, + raw_images=micro_batch_raw_images, + pixel_values=all_images_pixel_values if is_multimodal else None, + pixel_values_videos=all_videos_pixel_values if is_multimodal else None, + images_num=all_images_num[i:i + config.micro_rollout_batch_size] if is_multimodal else None, + videos_num=all_videos_num[i:i + config.micro_rollout_batch_size] if is_multimodal else None, + image_patch_idx=image_patch_idx, + video_patch_idx=video_patch_idx, + ) + if updated_patch_idx is not None: + image_patch_idx = updated_patch_idx + if updated_video_patch_idx is not None: + video_patch_idx = updated_video_patch_idx + samples_list.append(sample) + else: + sample = self._build_packed_sample( + outputs=micro_batch_outputs, + prompts=micro_batch_prompts, + labels=micro_batch_labels, + references=micro_batch_references, + ) + samples_list.append(sample) # Report timing torch.cuda.synchronize() @@ -1498,7 +1568,6 @@ def _fetch_teacher_logprobs( :type experiences: List[Union[Experience, ExperienceVL]] """ - # Get teacher URL from config teacher_url = self.teacher_model_url if isinstance(teacher_url, list): teacher_url = teacher_url[0] if teacher_url else 
None @@ -1511,18 +1580,13 @@ def _fetch_teacher_logprobs( Timer.start(' fetch_teacher_logprobs') for exp in experiences: - sequences = exp.sequences # [batch_size, seq_len] - attention_mask = exp.attention_mask # [batch_size, seq_len] - action_mask = exp.action_mask # [batch_size, num_tokens] + sequences = exp.sequences + attention_mask = exp.attention_mask + action_mask = exp.action_mask - # response_lengths must be int for slicing response_lengths = action_mask.sum(dim=-1).int().tolist() num_tokens = action_mask.shape[1] - # Strip padding tokens before sending to SGLang. - # sequences is [prompt, response, eos, pad, pad, ...] — the padding - # tokens would cause SGLang to return logprobs for pad positions, - # making the [-resp_len:] slice grab wrong tokens. input_ids_list = [] for j in range(sequences.shape[0]): valid_len = int(attention_mask[j].sum().item()) @@ -1542,14 +1606,6 @@ def _fetch_teacher_logprobs( finally: loop.close() - # Align teacher log probs to action_log_probs shape [batch_size, num_tokens]. - # Use action_mask indices directly — works regardless of left/right padding. - # - # Correctness: teacher_lp_list[i] contains teacher logprobs for response - # tokens in first→last order (from teacher_lp[-resp_len:]). - # valid_indices = ascending positions where action_mask==1 (real response tokens). - # So aligned_teacher_lp[i, valid_indices[k]] = tlp[k] correctly maps the k-th - # teacher logprob to the k-th response token position, regardless of padding direction. 
batch_size = sequences.shape[0] aligned_teacher_lp = torch.zeros(batch_size, num_tokens, dtype=torch.float32) for i, (tlp, resp_len) in enumerate(zip(teacher_lp_list, response_lengths)): @@ -1609,6 +1665,89 @@ def _process_experiences( # Use calculator's preprocess_rewards method return self.advantage_calculator.preprocess_rewards(rewards, experiences, max_new_tokens) + def _apply_step_reward_group_norm(self, experiences: List) -> None: + """Variant 2 (per-step PRM) GRPO-style step-level baseline subtraction. + + Active only when ``self.strategy.config.per_step_reward_mode == "group_norm"``. + For each step k, computes mean/std across the K trajectories that share + the same prompt (group), then replaces step_rewards with + ``(s - mean) / std`` so that scattered token-level rewards are zero-mean + signed advantages — analogous to GRPO trajectory-level baseline but at + step granularity. + + This is the difference between paper variant 2's two interpretations: + - "raw": scatter raw sigmoid step_score (paper Figure ablation) + - "group_norm": scatter group-normalized step_score (GRPO convention) + + Padding (step_token_indices < 0) is masked so it doesn't pollute mean/std. + Operates in-place on ``experience.info["step_rewards"]``. + """ + config = self.strategy.config + K = config.n_samples_per_prompt + if K is None or K < 2: + return # need >= 2 samples per group for std + + all_sr: List[torch.Tensor] = [] + all_sti: List[torch.Tensor] = [] + for exp in experiences: + sr_list = exp.info.get("step_rewards") + sti_list = exp.info.get("step_token_indices") + if sr_list is None or sti_list is None: + return # no per-step data present + all_sr.extend(sr_list) + all_sti.extend(sti_list) + + if not all_sr: + return + B = len(all_sr) + if B % K != 0: + warnings.warn( + f"per_step_reward_mode=group_norm: trajectory count {B} not divisible " + f"by n_samples_per_prompt {K}; skipping step-level group norm." 
+ ) + return + + max_steps = max((t.numel() for t in all_sr), default=0) + if max_steps == 0: + return + + sr_padded = torch.zeros(B, max_steps, dtype=torch.float32) + valid = torch.zeros(B, max_steps, dtype=torch.bool) + for i, (sr, sti) in enumerate(zip(all_sr, all_sti)): + n = sr.numel() + if n > 0: + sr_padded[i, :n] = sr.float() + valid[i, :n] = (sti >= 0) + + G = B // K + sr_g = sr_padded.reshape(G, K, max_steps) + valid_g = valid.reshape(G, K, max_steps).float() + + n_valid = valid_g.sum(dim=1, keepdim=True).clamp(min=1.0) + sr_masked = sr_g * valid_g + mean = sr_masked.sum(dim=1, keepdim=True) / n_valid + sq = ((sr_g - mean) * valid_g) ** 2 + var = sq.sum(dim=1, keepdim=True) / n_valid + std = (var + 1e-9).sqrt() + sr_normed = (sr_g - mean) / std + # Zero out invalid entries so they don't show up if accidentally scattered. + sr_normed = sr_normed * valid_g + sr_normed_flat = sr_normed.reshape(B, max_steps) + + # Write back, preserving each trajectory's actual step count (n). + idx = 0 + for exp in experiences: + sr_list = exp.info["step_rewards"] + new_sr_list: List[torch.Tensor] = [] + for sr in sr_list: + n = sr.numel() + if n > 0: + new_sr_list.append(sr_normed_flat[idx, :n].cpu().to(sr.dtype)) + else: + new_sr_list.append(sr) + idx += 1 + exp.info["step_rewards"] = new_sr_list + def _compute_advantages_and_returns( self, experiences: List[ExperienceVL], @@ -1632,6 +1771,12 @@ def _compute_advantages_and_returns( """ config = self.strategy.config + # Variant 2 step-level baseline subtraction (cross-experience, group-aware). + # Only active when label produced step_rewards AND user set + # --per_step_reward_mode=group_norm. 
+ if getattr(config, "per_step_reward_mode", "raw") == "group_norm": + self._apply_step_reward_group_norm(experiences) + for experience, reward in zip(experiences, rewards): reward = reward.to("cuda") processed_reward = reward.clone() # TODO:check @@ -1653,12 +1798,42 @@ def _compute_advantages_and_returns( processed_reward = torch.clamp(processed_reward, -config.reward_clip, config.reward_clip) # ========== Final Reward (with KL penalty) ========== + # Variant 2 (per-step PRM): if experience carries + # step_rewards/step_token_indices (placed by _pack_experience + # when the reward model emitted them), build a padded + # (batch, max_steps) tensor for compute_reward to scatter into + # per-step boundary tokens. Trajectories with empty step lists + # get all-(-1) indices => compute_reward filters them out and + # falls back to the trajectory-scalar (EOS) path for that row. + step_rewards_padded = None + step_indices_padded = None + step_rewards_list = experience.info.get("step_rewards") + step_indices_list = experience.info.get("step_token_indices") + if ( + step_rewards_list is not None and step_indices_list is not None + and any(t.numel() > 0 for t in step_rewards_list) + ): + max_steps = max(t.numel() for t in step_rewards_list) + if max_steps > 0: + B = len(step_rewards_list) + step_rewards_padded = torch.zeros(B, max_steps, dtype=torch.float32, device="cuda") + step_indices_padded = torch.full((B, max_steps), -1, dtype=torch.long, device="cuda") + for i, (sr, sti) in enumerate(zip(step_rewards_list, step_indices_list)): + n = sr.numel() + if n > 0: + step_rewards_padded[ + i, :n] = sr.to(step_rewards_padded.device, dtype=step_rewards_padded.dtype) + step_indices_padded[ + i, :n] = sti.to(step_indices_padded.device, dtype=step_indices_padded.dtype) + final_reward = compute_reward( processed_reward, self.kl_ctl.value, experience.kl, action_mask=experience.action_mask, num_actions=experience.info["num_actions"], + step_rewards=step_rewards_padded, + 
step_token_indices=step_indices_padded, ) # ========== Advantage Estimation ========== @@ -1714,80 +1889,59 @@ def _make_experience_list_by_model( device = get_current_device() vlm_mode = isinstance(all_samples[0], SamplesVL) - # ========== Stage 0: Preprocessing ========== outputs = [self._preprocess_sample(sample, vlm_mode, device) for sample in all_samples] - # ========== Stage 1: Actor Forward ========== Timer.start(' actor_logprob') - # Check if we need to compute entropy for high-entropy token filtering - need_entropy = hasattr(self.actor, 'high_entropy_token_ratio') and self.actor.high_entropy_token_ratio > 0.0 - for output in outputs: - if need_entropy: - # Request full output to get action_entropy - action_log_probs, model_output = self.actor( - output.sequences, - output.num_actions, - output.attention_mask, - packed_seq_lens=output.packed_seq_lens, - return_output=True, - **output.inputs_extra_kwargs - ) - output.action_log_probs = action_log_probs - # Extract action_entropy if available - if "action_entropy" in model_output: - output.action_entropy = model_output["action_entropy"] - else: - output.action_log_probs = self.actor( - output.sequences, - output.num_actions, - output.attention_mask, - packed_seq_lens=output.packed_seq_lens, - **output.inputs_extra_kwargs - ) + with self.profiler.section("collect/model/actor_logprob"): + need_entropy = hasattr(self.actor, 'high_entropy_token_ratio') and self.actor.high_entropy_token_ratio > 0.0 + for output in outputs: + if need_entropy: + action_log_probs, model_output = self.actor( + output.sequences, + output.num_actions, + output.attention_mask, + packed_seq_lens=output.packed_seq_lens, + return_output=True, + **output.inputs_extra_kwargs + ) + output.action_log_probs = action_log_probs + if "action_entropy" in model_output: + output.action_entropy = model_output["action_entropy"] + else: + output.action_log_probs = self.actor( + output.sequences, + output.num_actions, + output.attention_mask, + 
packed_seq_lens=output.packed_seq_lens, + **output.inputs_extra_kwargs + ) Timer.stop(' actor_logprob') - # ========== Stage 2: Initial Model ========== if self.initial_model is not None: - # Note: Manual reload/offload is safe for initial_model because: - # 1. It's initialized with is_training=False (see train_colocate.py:207) - # 2. This means FSDP's CPUOffloadPolicy is NOT enabled (see fsdpv2.py:375) - # 3. Without CPUOffloadPolicy, FSDP doesn't automatically manage parameter movement - # 4. We can safely use manual reload_model() to move model from CPU to GPU - # 5. After computing base_action_log_probs, we offload back to CPU to save memory - # This pattern works because there's no conflict with FSDP's automatic management. - self.strategy.reload_model(self.initial_model) - for output in outputs: - output.base_action_log_probs = self.initial_model( - output.sequences, - output.num_actions, - output.attention_mask, - packed_seq_lens=output.packed_seq_lens, - **output.inputs_extra_kwargs - ) - # Offload back to CPU to free GPU memory for subsequent stages - self.strategy.offload_model(self.initial_model) + with self.profiler.section("collect/model/reference_logprob"): + self.strategy.reload_model(self.initial_model) + for output in outputs: + output.base_action_log_probs = self.initial_model( + output.sequences, + output.num_actions, + output.attention_mask, + packed_seq_lens=output.packed_seq_lens, + **output.inputs_extra_kwargs + ) + self.strategy.offload_model(self.initial_model) - # ========== Stage 3: Critic ========== - Timer.start(' critic') if self.critic is not None: - # Note: When critic is initialized with is_training=True and fsdp_cpu_offload=True, - # FSDP's CPUOffloadPolicy automatically manages parameter movement between CPU/GPU. - # Manual reload_model/offload_model calls will conflict with FSDP's automatic management - # and cause "FSDP parameters should be materialized on CPU" error. - # The CPUOffloadPolicy will automatically: - # 1. 
Prefetch parameters from CPU to GPU before forward pass - # 2. Offload parameters back to CPU after forward pass - # This is the recommended approach for memory-efficient training with FSDP2. - for output in outputs: - output.value = self.critic( - output.sequences, output.num_actions, output.attention_mask, **output.inputs_extra_kwargs - ) - Timer.stop(' critic') + with self.profiler.section("collect/model/critic_forward"): + self.strategy.reload_model(self.critic) + for output in outputs: + output.value = self.critic( + output.sequences, output.num_actions, output.attention_mask, **output.inputs_extra_kwargs + ) + self.strategy.offload_model(self.critic) - # ========== Stage 4: Reward Models ========== - self.reward_engine.compute_rewards(outputs, vlm_mode, device) + with self.profiler.section("collect/model/reward_forward"): + self.reward_engine.compute_rewards(outputs, vlm_mode, device) - # ========== Stage 5: Assemble Experiences ========== return [self._pack_experience(output, vlm_mode) for output in outputs] def _preprocess_sample( @@ -1832,9 +1986,6 @@ def _preprocess_sample( "pixel_values_videos": sample.pixel_values_videos, "video_grid_thw": sample.video_grid_thws, } - # Audio-language actors expect audio_values; pipeline stores them in pixel_values slot - if "audio_values" in self._actor_supported_params: - candidate_params["audio_values"] = candidate_params.get("pixel_values") # Filter to only include supported parameters extra_kwargs = { @@ -1892,16 +2043,16 @@ def _fix_qwen_vl_image_tokens( return config = self.strategy.unwrap_model(self.actor.model).config - image_token_id = config.image_token_id + image_token_id = getattr(config, "image_token_id", getattr(config, "image_token_index", None)) + if image_token_id is None: + return num_tokens = (sequences == image_token_id).sum() - vision_config = getattr(config, "vision_config", None) - spatial_merge_size = getattr(vision_config, "spatial_merge_size", 2) - merge_length = spatial_merge_size ** 2 - - if 
sample.image_grid_thws is not None and sample.image_grid_thws.numel() > 0: - num_patches = int(torch.prod(sample.image_grid_thws, dim=-1).sum().item() // merge_length) + # Qwen2-VL style processors usually flatten patches (dim < 4), while some + # VLMs such as URSA keep one image tensor per item (dim == 4). Handle both. + if sample.pixel_values.dim() >= 4: + num_patches = sample.pixel_values.shape[0] else: - num_patches = sample.pixel_values.shape[0] // merge_length + num_patches = sample.pixel_values.shape[0] // 4 if num_tokens != num_patches: self.strategy.print( @@ -1972,6 +2123,17 @@ def _pack_experience( if output.reward_metrics is not None: info['reward_metrics'] = output.reward_metrics + # Variant 2 (per-step PRM rewards): forward to experience.info so + # downstream advantage computation can build per-token reward signals. + # Stored as Python lists of CPU tensors (variable length per traj). + # _process_experiences chunks experiences along micro_batch_size, + # so we keep the *micro-batch local* indexing here — i.e. info + # carries the lists scoped to this single _SamplesOutput. 
+ if output.step_rewards is not None: + info['step_rewards'] = output.step_rewards + if output.step_token_indices is not None: + info['step_token_indices'] = output.step_token_indices + # Create Experience object if vlm: return ExperienceVL( @@ -1991,8 +2153,6 @@ def _pack_experience( info=info, kl=kl, action_entropy=output.action_entropy, - labels=output.labels, # data source labels (if available, e.g., "gsm8k_rule") - references=output.references, # ground truth references (if available, e.g., correct answers) ) else: return Experience( diff --git a/lightrft/trainer/ppo_trainer_vl.py b/lightrft/trainer/ppo_trainer_vl.py index c9cdb8e1..08b406ae 100644 --- a/lightrft/trainer/ppo_trainer_vl.py +++ b/lightrft/trainer/ppo_trainer_vl.py @@ -1,7 +1,9 @@ import os import sys import os.path +from collections import defaultdict from abc import ABC +from contextlib import contextmanager from typing import Any, Callable, Dict, List, Optional import torch @@ -15,10 +17,28 @@ from lightrft.models.actor_modality import ActorModality, get_supported_parameters from lightrft.models.utils import masked_mean, unpacking_samples, compute_approx_kl from lightrft.utils.distributed_sampler import DistributedSampler -from lightrft.utils import rotate_ckpt_dirs from lightrft.trainer import AdaptiveKLController, ExperienceVL, FixedKLController, NaiveExperienceMakerVL, NaiveReplayBufferVL # noqa +class _NullStepProfiler: + @contextmanager + def section(self, _name: str): + yield + + @contextmanager + def phase(self, _name: str): + yield + + def start_step(self, *_args, **_kwargs): + return None + + def finish_step(self, *_args, **_kwargs): + return None + + def close(self): + return None + + class PPOTrainerVL(ABC): """ Trainer for Proximal Policy Optimization (PPO) algorithm for Vision-Language Models. 
@@ -162,7 +182,6 @@ def __init__( self.reward_fn = reward_fn self.reward_fn_label_map = reward_fn_label_map self.reward_recipe = reward_recipe - self.is_lora = getattr(self.args, "lora_rank", 0) > 0 self.actor = actor self.critic = critic @@ -217,6 +236,7 @@ def __init__( self._tensorboard = None self.eval_step_counter = 0 # Independent counter for eval X-axis self.wandb_log_counter = 0 # Global counter for unique wandb system steps + self.profiler = getattr(self, "profiler", _NullStepProfiler()) if self.strategy.args.use_wandb and self.strategy.is_rank_0(): import wandb @@ -254,6 +274,28 @@ def __init__( log_dir = os.path.join(self.strategy.args.use_tensorboard, strategy.args.wandb_run_name) self._tensorboard = SummaryWriter(log_dir=log_dir) + def _update_wandb_summary(self, logs: Dict[str, Any]) -> None: + if self._wandb is None or not self.strategy.is_rank_0() or not logs: + return + + summary_logs = {} + for key, value in logs.items(): + if isinstance(value, torch.Tensor): + if value.numel() != 1: + continue + value = value.item() + elif hasattr(value, "item") and not isinstance(value, (str, bytes)): + try: + value = value.item() + except (TypeError, ValueError): + pass + + if isinstance(value, (int, float, bool, str)): + summary_logs[key] = value + + if summary_logs and self._wandb.run is not None: + self._wandb.run.summary.update(summary_logs) + def fit( self, args, @@ -359,6 +401,8 @@ def fit( ) for batch in self.prompts_dataloader: + if hasattr(self, "profiler") and self.profiler is not None: + self.profiler.start_step(steps, episode) # Compatible with both image-only (4 args) and video (5 args) dataloaders if len(batch) == 5: rand_prompts, rand_images, rand_videos, rand_references, rand_labels = batch @@ -366,9 +410,14 @@ def fit( rand_prompts, rand_images, rand_references, rand_labels = batch rand_videos = None - # TODO: Remove debug print + batch_preview = min(2, len(rand_prompts)) self.strategy.print( - f"rand_prompts:\n {rand_prompts}\n , 
rand_images:{rand_images}\n , rand_references:{rand_references}\n, rand_labels:{rand_labels}\n " # noqa + "collect phase batch summary: " + f"batch_size={len(rand_prompts)}, " + f"preview_prompts={rand_prompts[:batch_preview]}, " + f"preview_images={rand_images[:batch_preview]}, " + f"preview_references={rand_references[:batch_preview]}, " + f"preview_labels={rand_labels[:batch_preview]}" ) for i, experience in enumerate( @@ -403,9 +452,7 @@ def fit( rollout_status = {} if self.replay_buffer.items: all_rewards = [] - all_format_rewards = [] - all_accuracy_rewards = [] - all_general_model_rewards = [] + reward_metric_values = defaultdict(list) all_response_lengths = [] for item in self.replay_buffer.items: @@ -421,19 +468,9 @@ def fit( hasattr(item, 'info') and item.info is not None and 'reward_metrics' in item.info and item.info['reward_metrics'] is not None ): - reward_metrics = item.info['reward_metrics'] - - # Safely extract sub-metrics - if 'format_reward' in reward_metrics: - all_format_rewards.append(reward_metrics['format_reward']) - if 'accuracy_reward' in reward_metrics: - all_accuracy_rewards.append(reward_metrics['accuracy_reward']) - general_model_reward = reward_metrics.get("general_model_reward") - if general_model_reward is None: - general_model_reward = reward_metrics.get("model_reward") - if general_model_reward is not None: - all_general_model_rewards.append(general_model_reward) + for key, value in reward_metrics.items(): + reward_metric_values[key].append(value) # Collect response lengths from rollout if hasattr(item, 'info') and item.info is not None and 'response_length' in item.info: @@ -451,42 +488,18 @@ def fit( rollout_status["rollout_reward"] = rewards_tensor.mean().item() rollout_status["rollout_reward_std"] = rewards_tensor.std().item() - if all_format_rewards: - # [TENSOR-FIX] Handle both tensor lists and scalar lists - # Issue: all_format_rewards may contain tensors (from reward_metrics), - # but torch.tensor() cannot convert a list 
of tensors directly. - # Solution: Use torch.cat() for tensor lists, torch.tensor() for scalar lists - if isinstance(all_format_rewards[0], torch.Tensor): - # List of tensors: concatenate them - format_tensor = torch.cat([t.to(device).float() for t in all_format_rewards]) - else: - # List of scalars: convert to tensor - format_tensor = torch.tensor(all_format_rewards, dtype=torch.float32, device=device) - - mean_format_reward = format_tensor.mean().item() - rollout_status["rollout_format_reward"] = mean_format_reward - - if all_accuracy_rewards: - # [TENSOR-FIX] Handle both tensor lists and scalar lists - if isinstance(all_accuracy_rewards[0], torch.Tensor): - accuracy_tensor = torch.cat([t.to(device).float() for t in all_accuracy_rewards]) - else: - accuracy_tensor = torch.tensor(all_accuracy_rewards, dtype=torch.float32, device=device) - - mean_accuracy_reward = accuracy_tensor.mean().item() - rollout_status["rollout_accuracy_reward"] = mean_accuracy_reward - - if all_general_model_rewards: - if isinstance(all_general_model_rewards[0], torch.Tensor): - general_model_tensor = torch.cat([t.to(device).float() for t in all_general_model_rewards]) + for metric_name, values in reward_metric_values.items(): + if not values: + continue + if isinstance(values[0], torch.Tensor): + metric_tensor = torch.cat([t.to(device).float() for t in values]) else: - general_model_tensor = torch.tensor( - all_general_model_rewards, dtype=torch.float32, device=device - ) - - mean_general_model_reward = general_model_tensor.mean().item() - if abs(mean_general_model_reward) > 1e-6: - rollout_status["rollout_general_model_reward"] = mean_general_model_reward + metric_tensor = torch.tensor(values, dtype=torch.float32, device=device) + if metric_tensor.numel() == 0: + continue + mean_metric = metric_tensor.mean().item() + if abs(mean_metric) > 1e-6: + rollout_status[f"rollout_{metric_name}"] = mean_metric if all_response_lengths: # [TENSOR-FIX] Handle both tensor lists and scalar lists @@ 
-524,9 +537,16 @@ def fit( self.save_logs_and_checkpoints(args, steps, pbar, logs_dict_combined, client_states, episode=episode) + if hasattr(self, "profiler") and self.profiler is not None: + profile_snapshot = self.profiler.finish_step() + if hasattr(self, "log_profile_metrics"): + self.log_profile_metrics(steps, episode, profile_snapshot) + pbar.update() steps = steps + 1 + if hasattr(self, "profiler") and self.profiler is not None: + self.profiler.close() if self._wandb is not None and self.strategy.is_rank_0(): self._wandb.finish() if self._tensorboard is not None and self.strategy.is_rank_0(): @@ -742,125 +762,120 @@ def training_step_actor(self, ) return {} # Emergency fallback - should not normally execute - # Actor loss - # Build kwargs based on actor's modality - only include supported parameters - candidate_params = { - "pixel_values": pixel_values, - "image_grid_thw": image_grid_thws, - "pixel_values_videos": pixel_values_videos, - "video_grid_thw": video_grid_thws, - } - # Audio-language actors expect audio_values; pipeline stores them in pixel_values slot - if "audio_values" in self._actor_supported_params: - candidate_params["audio_values"] = candidate_params.get("pixel_values") - - actor_kwargs = {key: value for key, value in candidate_params.items() if key in self._actor_supported_params} - - action_log_probs, output = self.actor( - sequences, - num_actions, - attention_mask=attention_mask, - return_output=True, - packed_seq_lens=packed_seq_lens, - **actor_kwargs - ) + with self.profiler.section("learn/actor/total"): + candidate_params = { + "pixel_values": pixel_values, + "image_grid_thw": image_grid_thws, + "pixel_values_videos": pixel_values_videos, + "video_grid_thw": video_grid_thws, + } + + actor_kwargs = { + key: value + for key, value in candidate_params.items() + if key in self._actor_supported_params + } + + with self.profiler.section("learn/actor/forward"): + action_log_probs, output = self.actor( + sequences, + num_actions, + 
attention_mask=attention_mask, + return_output=True, + packed_seq_lens=packed_seq_lens, + **actor_kwargs + ) # NOTE: Explicit masking in log-space is incorrect - removed # if experience.action_mask is not None: # # Setting masked positions to 0 to match old_action_log_probs is WRONG in log-space # action_log_probs = action_log_probs * experience.action_mask - # Loss function - actor_loss = self.actor_loss_fn( - action_log_probs, - old_action_log_probs, - advantages, - action_mask=experience.action_mask, - entropy_mask=entropy_mask, - ) - - if self.args.use_kl_loss: - if self.initial_model is not None: - # TODO(pu): Text-only action mask for KL calculation - - kl = compute_approx_kl( + with self.profiler.section("learn/actor/loss"): + actor_loss = self.actor_loss_fn( action_log_probs, - base_action_log_probs, - experience.action_mask, - kl_estimator=self.args.kl_estimator, + old_action_log_probs, + advantages, + action_mask=experience.action_mask, + entropy_mask=entropy_mask, ) - # [Protection measure 2] Per-token KL Clamping - # NOTE: Adding this causes svkng training to not converge - # kl = torch.clamp(kl, min=0.0, max=20.0) + if self.args.use_kl_loss: + if self.initial_model is not None: + kl = compute_approx_kl( + action_log_probs, + base_action_log_probs, + experience.action_mask, + kl_estimator=self.args.kl_estimator, + ) + else: + kl = torch.zeros_like( + action_log_probs, + dtype=action_log_probs.dtype, + device=action_log_probs.device, + ) - else: - kl = torch.zeros_like(action_log_probs, dtype=action_log_probs.dtype, device=action_log_probs.device) + if not self.args.packing_samples: + kl_mean = masked_mean(kl, experience.action_mask, dim=-1) + else: + kl = unpacking_samples(kl, num_actions) + kl_mean = torch.tensor([each_kl.mean() for each_kl in kl], device=action_log_probs.device) - if not self.args.packing_samples: - kl_mean = masked_mean(kl, experience.action_mask, dim=-1) - # Not supported for packed samples + kl_loss = kl_mean.mean() + 
experience.info["kl"] = kl_loss.item() else: - # Convert tensor into list of tensors for easier manipulation within dataset - kl = unpacking_samples(kl, num_actions) - kl_mean = torch.tensor([each_kl.mean() for each_kl in kl], device=action_log_probs.device) - - kl_loss = kl_mean.mean() - experience.info["kl"] = kl_loss.item() - else: - kl_loss = 0 - - # Mixtral auxiliary loss - if self.aux_loss: - aux_loss = output.aux_loss - else: - aux_loss = 0 - - loss = actor_loss + aux_loss * self.args.aux_loss_coef + kl_loss * self.kl_ctl.value - - if torch.isnan(loss) or torch.isinf(loss): - self.strategy.print("[CRITICAL ERROR] Actor loss is NaN or Inf at step. Skipping update.") - self.strategy.print(f" Actor Loss: {actor_loss.item()}") - self.strategy.print(f" KL Loss: {kl_loss.item() if isinstance(kl_loss, torch.Tensor) else kl_loss}") - - self.strategy.backward(loss, self.actor, self.actor_optim) - - # PTX loss for supervised fine-tuning - if self.pretrain_dataloader is not None: - data = next(self.pretrain_dataloader) - inputs = data[1].squeeze(1).to(torch.cuda.current_device()) - attention_mask = data[2].squeeze(1).to(torch.cuda.current_device()) - label = torch.where( - attention_mask.bool(), - inputs, - self.ptx_loss_fn.IGNORE_INDEX, - ) - pixel_values = data[3].to(torch.cuda.current_device()) - image_grid_thws = data[4].to(torch.cuda.current_device()) - - output = self.actor( - inputs, - attention_mask=attention_mask, - pixel_values=pixel_values, - image_grid_thw=image_grid_thws, - return_output=True - ) - ptx_log_probs = output["logits"] + kl_loss = 0 - # Loss function - ptx_loss = self.ptx_loss_fn(ptx_log_probs, label) - # Mixtral auxiliary loss if self.aux_loss: aux_loss = output.aux_loss else: aux_loss = 0 - loss = ptx_loss + aux_loss * self.args.aux_loss_coef - self.strategy.backward(self.ptx_coef * loss, self.actor, self.actor_optim) - - self.strategy.optimizer_step(self.actor_optim, self.actor, self.actor_scheduler, name="actor") - if self.ema_model: - 
self.strategy.moving_average(self.actor, self.ema_model, self.ema_beta, "cuda") + loss = actor_loss + aux_loss * self.args.aux_loss_coef + kl_loss * self.kl_ctl.value + + if torch.isnan(loss) or torch.isinf(loss): + self.strategy.print("[CRITICAL ERROR] Actor loss is NaN or Inf at step. Skipping update.") + self.strategy.print(f" Actor Loss: {actor_loss.item()}") + self.strategy.print(f" KL Loss: {kl_loss.item() if isinstance(kl_loss, torch.Tensor) else kl_loss}") + + with self.profiler.section("learn/actor/backward"): + self.strategy.backward(loss, self.actor, self.actor_optim) + + if self.pretrain_dataloader is not None: + with self.profiler.section("learn/actor/ptx"): + data = next(self.pretrain_dataloader) + inputs = data[1].squeeze(1).to(torch.cuda.current_device()) + attention_mask = data[2].squeeze(1).to(torch.cuda.current_device()) + label = torch.where( + attention_mask.bool(), + inputs, + self.ptx_loss_fn.IGNORE_INDEX, + ) + pixel_values = data[3].to(torch.cuda.current_device()) + image_grid_thws = data[4].to(torch.cuda.current_device()) + + output = self.actor( + inputs, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thws, + return_output=True + ) + ptx_log_probs = output["logits"] + ptx_loss = self.ptx_loss_fn(ptx_log_probs, label) + if self.aux_loss: + aux_loss = output.aux_loss + else: + aux_loss = 0 + loss = ptx_loss + aux_loss * self.args.aux_loss_coef + self.strategy.backward(self.ptx_coef * loss, self.actor, self.actor_optim) + + with self.profiler.section("learn/actor/optimizer_step"): + self.strategy.optimizer_step(self.actor_optim, self.actor, self.actor_scheduler, name="actor") + + if self.ema_model: + with self.profiler.section("learn/actor/ema"): + self.strategy.moving_average(self.actor, self.ema_model, self.ema_beta, "cuda") # Status status = {"policy_loss": actor_loss.item(), "actor_lr": self.actor_scheduler.get_last_lr()[0]} @@ -988,33 +1003,35 @@ def ensure_device_and_contiguous(tensor, 
name="tensor"): sequences = ensure_device_and_contiguous(sequences, "sequences") attention_mask = ensure_device_and_contiguous(attention_mask, "attention_mask") - # Critic loss - values, output = self.critic( - sequences, - num_actions=num_actions, - attention_mask=attention_mask, - pixel_values=pixel_values, - image_grid_thw=image_grid_thws, - pixel_values_videos=pixel_values_videos, - video_grid_thw=video_grid_thws, - return_output=True, - packed_seq_lens=packed_seq_lens, - ) - # Loss function - critic_loss = self.critic_loss_fn( - values, - old_values, - returns, - action_mask=experience.action_mask, - ) - # Mixtral auxiliary loss - if self.aux_loss: - aux_loss = output.aux_loss - else: - aux_loss = 0 - loss = critic_loss + aux_loss * self.args.aux_loss_coef - self.strategy.backward(loss, self.critic, self.critic_optim) - self.strategy.optimizer_step(self.critic_optim, self.critic, self.critic_scheduler, name="critic") + with self.profiler.section("learn/critic/total"): + with self.profiler.section("learn/critic/forward"): + values, output = self.critic( + sequences, + num_actions=num_actions, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thws, + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thws, + return_output=True, + packed_seq_lens=packed_seq_lens, + ) + with self.profiler.section("learn/critic/loss"): + critic_loss = self.critic_loss_fn( + values, + old_values, + returns, + action_mask=experience.action_mask, + ) + if self.aux_loss: + aux_loss = output.aux_loss + else: + aux_loss = 0 + loss = critic_loss + aux_loss * self.args.aux_loss_coef + with self.profiler.section("learn/critic/backward"): + self.strategy.backward(loss, self.critic, self.critic_optim) + with self.profiler.section("learn/critic/optimizer_step"): + self.strategy.optimizer_step(self.critic_optim, self.critic, self.critic_scheduler, name="critic") # Status status = { @@ -1090,10 +1107,10 @@ def 
save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, c for k, v in self.experience_maker.perf_stats.items(): all_wandb_logs[f"perf/experience_maker/{k}"] = v - # Commit Train/Rollout logs with unique system step if all_wandb_logs: self.wandb_log_counter += 1 self._wandb.log(all_wandb_logs, step=self.wandb_log_counter, commit=True) + self._update_wandb_summary(all_wandb_logs) # TensorBoard Logging elif self._tensorboard is not None and self.strategy.is_rank_0(): @@ -1125,12 +1142,9 @@ def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict={}, c eval_logs["eval/train_step"] = global_step eval_logs["eval/episode"] = episode - # IMPORTANT: - # Use wandb_log_counter to ensure eval has a unique system step - # This prevents eval metrics from being overwritten by train metrics - # The plots will still use eval/global_step as X-axis due to define_metric self.wandb_log_counter += 1 self._wandb.log(eval_logs, step=self.wandb_log_counter, commit=True) + self._update_wandb_summary(eval_logs) # TensorBoard Logging for Eval elif self._tensorboard is not None: @@ -1155,11 +1169,10 @@ def _save_checkpoint(self, args, tag, client_states): :param client_states: Client state for checkpoint recovery. :type client_states: dict """ - ckpt_path = args.ckpt_path - if not self.disable_ds_ckpt and not self.is_lora: + if not self.disable_ds_ckpt: self.strategy.save_ckpt( self.actor.model, - os.path.join(ckpt_path, "_actor"), + os.path.join(args.ckpt_path, "_actor"), tag, args.max_ckpt_num, args.max_ckpt_mem, @@ -1167,24 +1180,11 @@ def _save_checkpoint(self, args, tag, client_states): ) if self.critic is not None: self.strategy.save_ckpt( - self.critic, os.path.join(ckpt_path, "_critic"), tag, args.max_ckpt_num, args.max_ckpt_mem - ) - - # For LoRA, we ALWAYS save the HF adapter as it is much smaller and more convenient for deployment. 
- if self.save_hf_ckpt or self.is_lora: - # Rotate HF checkpoints - if self.strategy.is_rank_0(): - os.makedirs(ckpt_path, exist_ok=True) - max_num = getattr(args, "max_ckpt_num", 3) - rotate_ckpt_dirs( - ckpt_path, - max_num, - suffix="_lora", - strategy=self.strategy, - label="HF ckpt", + self.critic, os.path.join(args.ckpt_path, "_critic"), tag, args.max_ckpt_num, args.max_ckpt_mem ) - save_path = os.path.join(ckpt_path, f"{tag}_lora") + if self.save_hf_ckpt: + save_path = os.path.join(args.ckpt_path, f"{tag}_hf") self.strategy.save_model(self.actor, self.tokenizer, save_path) def evaluate(self, eval_dataloader, global_step): @@ -1210,9 +1210,7 @@ def evaluate(self, eval_dataloader, global_step): self.critic.eval() all_rewards = [] - all_format_rewards = [] - all_accuracy_rewards = [] - all_general_model_rewards = [] + reward_metric_values = defaultdict(list) all_response_lengths = [] num_eval_batches = 0 @@ -1258,15 +1256,8 @@ def extract_values(val): if 'reward_metrics' in info: rm = info['reward_metrics'] - if 'format_reward' in rm: - all_format_rewards.extend(extract_values(rm['format_reward'])) - if 'accuracy_reward' in rm: - all_accuracy_rewards.extend(extract_values(rm['accuracy_reward'])) - general_model_reward = rm.get("general_model_reward") - if general_model_reward is None: - general_model_reward = rm.get("model_reward") - if general_model_reward is not None: - all_general_model_rewards.extend(extract_values(general_model_reward)) + for key, value in rm.items(): + reward_metric_values[key].extend(extract_values(value)) num_eval_batches += 1 if num_eval_batches >= len(eval_dataloader): @@ -1287,9 +1278,8 @@ def compute_stats(name, values_list): # metrics[f"{name}_std"] = t.std().item() # Optional compute_stats("reward", all_rewards) - compute_stats("format_reward", all_format_rewards) - compute_stats("accuracy_reward", all_accuracy_rewards) - compute_stats("general_model_reward", all_general_model_rewards) + for metric_name, values in 
reward_metric_values.items(): + compute_stats(metric_name, values) compute_stats("response_length", all_response_lengths) metrics["num_samples"] = len(all_rewards) diff --git a/lightrft/trainer/spmd_ppo_trainer.py b/lightrft/trainer/spmd_ppo_trainer.py index 026056da..95c4ce9d 100644 --- a/lightrft/trainer/spmd_ppo_trainer.py +++ b/lightrft/trainer/spmd_ppo_trainer.py @@ -18,7 +18,9 @@ - Efficient distributed training across multiple devices and nodes """ +import os import time +from collections import defaultdict import torch from tqdm import tqdm @@ -30,7 +32,7 @@ from lightrft.trainer.replay_buffer import make_experience_batch from lightrft.trainer.replay_buffer_vl import make_experience_batch as make_experience_batch_vl from lightrft.models.utils import create_high_entropy_mask -from lightrft.utils import init_logger +from lightrft.utils import StepProfileRecorder, init_logger logger = init_logger(__name__) @@ -115,6 +117,11 @@ def __init__( # TODO: here we pass a list of concrete params, this may collapse in future versions. 
# Create experience maker with appropriate parameters processor = kwargs.pop("processor", None) + self.profiler = StepProfileRecorder( + enabled=bool(getattr(self.args, "enable_profile", False)), + output_dir=os.path.join(self.args.save_path, "profile"), + print_fn=self.strategy.print, + ) self.experience_maker = FastExperienceMaker( self.actor, @@ -131,6 +138,7 @@ def __init__( self.reward_recipe, packing_samples=self.packing_samples, processor=processor, + profiler=self.profiler, ) # Extract high_entropy_token_ratio for entropy-based token filtering @@ -192,322 +200,272 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train torch.cuda.synchronize() train_begin = time.time() - torch.cuda.empty_cache() - self.strategy.maybe_load_optimizer(self.actor_optim) - all_items = self.strategy.sp_data_processor.preprocess(self.replay_buffer.items) - - device = torch.cuda.current_device() - - status_list = [] - status_mean = {} - for epoch in range(self.max_epochs): - pbar = tqdm( - range(0, len(all_items), self.micro_train_batch_size), - desc=f"Train epoch [{epoch + 1}/{self.max_epochs}]", - disable=not self.strategy.is_rank_0(), - ) - for i in pbar: - items = all_items[i:i + self.micro_train_batch_size] - if self.VLM: - experience = make_experience_batch_vl(items, packing_samples=self.packing_samples) - else: - experience = make_experience_batch(items, packing_samples=self.packing_samples) - experience.to_device(device) - - # ====================================================================================== - # Validate data BEFORE calling training_step to prevent execution path divergence - # If validation is done inside training_step, different ranks may follow different code paths - # (some return early, others continue), causing deadlock in collective communication ops. 
- - # Step 1: Each rank validates its local data - should_skip_local = False - if self.VLM and hasattr(self, '_validate_qwen_vl_tensors'): - # Call the same validation logic used in training_step_actor - sequences = experience.sequences - pixel_values = experience.pixel_values - - # Validate before any forward pass - is_valid = self._validate_qwen_vl_tensors( - sequences, pixel_values, context="pre_training_validation" - ) - should_skip_local = not is_valid - - # Step 2: Synchronize skip decision across all ranks via all_reduce - # This ensures all ranks agree on whether to skip, preventing execution divergence - skip_flag = torch.tensor([1.0 if should_skip_local else 0.0], device=device) - torch.distributed.all_reduce(skip_flag, op=torch.distributed.ReduceOp.MAX) - - # Step 3: Collectively skip if ANY rank detected invalid data - if skip_flag.item() > 0: - if self.strategy.is_rank_0(): - pbar.set_description(f"Train epoch [{epoch + 1}/{self.max_epochs}] (skipping invalid batch)") - continue # All ranks skip together - no deadlock - # ====================================================================================== - - # Create entropy_mask if high_entropy_token_ratio > 0 and action_entropy is available - entropy_mask = None - if hasattr(experience, 'action_entropy') and experience.action_entropy is not None: - if self.high_entropy_token_ratio > 0.0: - entropy_mask = create_high_entropy_mask( - experience.action_entropy, experience.action_mask, self.high_entropy_token_ratio + with self.profiler.section("learn/total"): + torch.cuda.empty_cache() + self.strategy.maybe_load_optimizer(self.actor_optim) + with self.profiler.section("learn/sp_preprocess"): + all_items = self.strategy.sp_data_processor.preprocess(self.replay_buffer.items) + + device = torch.cuda.current_device() + status_list = [] + status_mean = {} + + for epoch in range(self.max_epochs): + pbar = tqdm( + range(0, len(all_items), self.micro_train_batch_size), + desc=f"Train epoch [{epoch + 
1}/{self.max_epochs}]", + disable=not self.strategy.is_rank_0(), + ) + for i in pbar: + items = all_items[i:i + self.micro_train_batch_size] + if self.VLM: + experience = make_experience_batch_vl(items, packing_samples=self.packing_samples) + else: + experience = make_experience_batch(items, packing_samples=self.packing_samples) + experience.to_device(device) + + should_skip_local = False + if self.VLM and hasattr(self, '_validate_qwen_vl_tensors'): + sequences = experience.sequences + pixel_values = experience.pixel_values + is_valid = self._validate_qwen_vl_tensors( + sequences, pixel_values, context="pre_training_validation" ) - - # Call training_step which will handle both GSPO and standard modes - status = self.training_step(experience, global_steps, entropy_mask=entropy_mask) - - # for DP - # weighted mean for kl - if "kl" in status: - status["kl"] *= status["response_length"] - status = self.strategy.all_reduce(status) - status["kl"] /= status["response_length"] - - # Training epoch progress bar: show per-batch metrics for detailed monitoring - short_status = {} - - if "policy_loss" in status: - short_status = { - "pg": status["policy_loss"], # policy gradient loss - "rm": status["reward"], # per-batch reward (instantaneous) - "ret": status["return"], # per-batch return (instantaneous) - "glen": status["response_length"], # per-batch response length - "tlen": status["total_length"], # per-batch total length - "kl": status["kl"], # KL divergence - "act_lr": status["actor_lr"], # actor learning rate - } - - if "critic_loss" in status: - short_status["cri"] = status["critic_loss"] - short_status["vals"] = status["values"] - short_status["cri_lr"] = status["critic_lr"] - - if "ptx_loss" in status: - short_status["ptx"] = status["ptx_loss"] - - status_list.append(status) - pbar.set_postfix(short_status) - - # Short status keys added for progress bar display: - # "pg": policy_loss - # "rm": reward - # "ret": return - # "glen": response_length - # "tlen": 
total_length - # "kl": KL divergence - # "act_lr": actor_lr - if status_list: - status_mean = status_list[0] - for m in status_list[1:]: - for k, v in m.items(): - status_mean[k] += v - for k in status_mean.keys(): - status_mean[k] /= len(status_list) - - # ========== Aggregate step-level reward metrics from replay buffer ========== - # NOTE: These metrics are aggregated from ALL experiences in the current step's - # replay buffer (e.g., 640 experiences if rollout_batch_size=128, n_samples=5). - # They represent the TRUE statistics of the rollout phase, NOT the training phase - # micro-batch averages which are less representative. - # - # Naming convention: - # - "*_mean" suffix: mean across all experiences in this step - # - "step_*" prefix: clarifies this is per-step aggregation, not per-episode - if self.replay_buffer.items: - all_rewards = [] - all_format_rewards = [] - all_accuracy_rewards = [] - all_general_model_rewards = [] - all_rule_rewards = [] - all_advantages = [] - all_returns = [] - all_response_lengths = [] - all_opd_reverse_kl = [] # For on-policy distillation metrics - - for item in self.replay_buffer.items: - # Collect rewards - if hasattr(item, 'info') and item.info is not None and 'reward' in item.info: - all_rewards.append(item.info['reward']) - - # Collect detailed reward metrics from info dict - if hasattr(item, 'info') and item.info is not None and 'reward_metrics' in item.info: - reward_metrics = item.info['reward_metrics'] - if 'format_reward' in reward_metrics: - all_format_rewards.append(reward_metrics['format_reward']) - if 'accuracy_reward' in reward_metrics: - all_accuracy_rewards.append(reward_metrics['accuracy_reward']) - general_model_reward = reward_metrics.get("general_model_reward") - if general_model_reward is None: - general_model_reward = reward_metrics.get("model_reward") - if general_model_reward is not None: - all_general_model_rewards.append(general_model_reward) - if 'rule_reward' in reward_metrics: - 
all_rule_rewards.append(reward_metrics['rule_reward']) - - # Collect advantages and returns - if hasattr(item, 'advantages') and item.advantages is not None: - all_advantages.append(item.advantages) - if hasattr(item, 'returns') and item.returns is not None: - all_returns.append(item.returns) - if hasattr(item, 'info') and item.info is not None and 'response_length' in item.info: - all_response_lengths.append(item.info['response_length']) - - # Collect on-policy distillation reverse KL - if hasattr(item, 'info') and item.info is not None and 'opd_reverse_kl' in item.info: - all_opd_reverse_kl.append(item.info['opd_reverse_kl']) - - # Compute statistics - # [TENSOR-FIX] Handle both tensor lists and scalar lists for all reward types - if all_rewards: - # Handle both tensor lists (from batched rewards) and scalar lists - if isinstance(all_rewards[0], torch.Tensor): - rewards_tensor = torch.cat([t.to(device).float() for t in all_rewards]) - else: - rewards_tensor = torch.tensor(all_rewards, dtype=torch.float32, device=device) - # Use "step_*" prefix to clarify this is per-step aggregation, not per-episode - status_mean["step_reward_mean"] = rewards_tensor.mean().item() - status_mean["step_reward_std"] = rewards_tensor.std().item() - status_mean["step_reward_max"] = rewards_tensor.max().item() - status_mean["step_reward_min"] = rewards_tensor.min().item() - - if all_format_rewards: - # [TENSOR-FIX] Handle both tensor lists and scalar lists - if isinstance(all_format_rewards[0], torch.Tensor): - format_tensor = torch.cat([t.to(device).float() for t in all_format_rewards]) - else: - format_tensor = torch.tensor(all_format_rewards, dtype=torch.float32, device=device) - status_mean["format_reward_mean"] = format_tensor.mean().item() - status_mean["format_reward_std"] = format_tensor.std().item() - - if all_accuracy_rewards: - # [TENSOR-FIX] Handle both tensor lists and scalar lists - if isinstance(all_accuracy_rewards[0], torch.Tensor): - accuracy_tensor = 
torch.cat([t.to(device).float() for t in all_accuracy_rewards]) - else: - accuracy_tensor = torch.tensor(all_accuracy_rewards, dtype=torch.float32, device=device) - status_mean["accuracy_reward_mean"] = accuracy_tensor.mean().item() - status_mean["accuracy_reward_std"] = accuracy_tensor.std().item() - - if all_general_model_rewards: - # [TENSOR-FIX] Handle both tensor lists and scalar lists - if isinstance(all_general_model_rewards[0], torch.Tensor): - model_tensor = torch.cat([t.to(device).float() for t in all_general_model_rewards]) - else: - model_tensor = torch.tensor(all_general_model_rewards, dtype=torch.float32, device=device) - if model_tensor.abs().sum() > 0: # Only log if model rewards are non-zero - status_mean["general_model_reward_mean"] = model_tensor.mean().item() - self.strategy.print(f" general_model_reward_mean: {status_mean['general_model_reward_mean']}") - - if all_rule_rewards: - # [TENSOR-FIX] Handle both tensor lists and scalar lists - if isinstance(all_rule_rewards[0], torch.Tensor): - rule_tensor = torch.cat([t.to(device).float() for t in all_rule_rewards]) - else: - rule_tensor = torch.tensor(all_rule_rewards, dtype=torch.float32, device=device) - status_mean["rule_reward_mean"] = rule_tensor.mean().item() - self.strategy.print(f"rule_reward_mean: {status_mean['rule_reward_mean']}") - - # For advantages, returns, and lengths, they are already lists of tensors, - # so torch.cat() is the correct function to use. 
- if all_advantages: - advantages_tensor = torch.cat(all_advantages) - status_mean["advantages_mean"] = advantages_tensor.mean().item() - status_mean["advantages_std"] = advantages_tensor.std().item() - status_mean["advantages_max"] = advantages_tensor.max().item() - status_mean["advantages_min"] = advantages_tensor.min().item() - - if all_returns: - returns_tensor = torch.cat(all_returns) - status_mean["returns_mean"] = returns_tensor.mean().item() - status_mean["returns_std"] = returns_tensor.std().item() - - if all_response_lengths: - # [TENSOR-FIX] Handle both tensor lists and scalar lists - if isinstance(all_response_lengths[0], torch.Tensor): - lengths_tensor = torch.cat([t.to(device).float() for t in all_response_lengths]) - else: - lengths_tensor = torch.tensor(all_response_lengths, dtype=torch.float32, device=device) - status_mean["response_length_mean"] = lengths_tensor.float().mean().item() - status_mean["response_length_std"] = lengths_tensor.float().std().item() - - # On-Policy Distillation metrics - if all_opd_reverse_kl: - # Collect reverse KL from all experiences - if isinstance(all_opd_reverse_kl[0], torch.Tensor): - opd_kl_tensor = torch.cat([t.to(device).float() for t in all_opd_reverse_kl]) - else: - opd_kl_tensor = torch.tensor(all_opd_reverse_kl, dtype=torch.float32, device=device) - # Mask out zero values (padding) - non_zero_mask = opd_kl_tensor != 0 - if non_zero_mask.any(): - masked_kl = opd_kl_tensor[non_zero_mask] - status_mean["opd_reverse_kl_mean"] = masked_kl.mean().item() - status_mean["opd_reverse_kl_std"] = masked_kl.std().item() - status_mean["opd_reverse_kl_max"] = masked_kl.max().item() - status_mean["opd_reverse_kl_min"] = masked_kl.min().item() - - # Print detailed reward breakdown (only on rank 0) - if self.print_replay_buffer_stats and self.strategy.is_rank_0(): - self.strategy.print("\n" + "=" * 60) - self.strategy.print("📊 Detailed Step Statistics") - self.strategy.print("=" * 60) + should_skip_local = not is_valid + + 
skip_flag = torch.tensor([1.0 if should_skip_local else 0.0], device=device) + torch.distributed.all_reduce(skip_flag, op=torch.distributed.ReduceOp.MAX) + if skip_flag.item() > 0: + if self.strategy.is_rank_0(): + pbar.set_description( + f"Train epoch [{epoch + 1}/{self.max_epochs}] (skipping invalid batch)" + ) + continue + + entropy_mask = None + if hasattr(experience, 'action_entropy') and experience.action_entropy is not None: + if self.high_entropy_token_ratio > 0.0: + entropy_mask = create_high_entropy_mask( + experience.action_entropy, + experience.action_mask, + self.high_entropy_token_ratio, + ) + + with self.profiler.section("learn/micro_batch_total"): + status = self.training_step(experience, global_steps, entropy_mask=entropy_mask) + + if "kl" in status: + status["kl"] *= status["response_length"] + status = self.strategy.all_reduce(status) + status["kl"] /= status["response_length"] + + short_status = {} + if "policy_loss" in status: + short_status = { + "pg": status["policy_loss"], + "rm": status["reward"], + "ret": status["return"], + "glen": status["response_length"], + "tlen": status["total_length"], + "kl": status["kl"], + "act_lr": status["actor_lr"], + } + if "critic_loss" in status: + short_status["cri"] = status["critic_loss"] + short_status["vals"] = status["values"] + short_status["cri_lr"] = status["critic_lr"] + if "ptx_loss" in status: + short_status["ptx"] = status["ptx_loss"] + + status_list.append(status) + pbar.set_postfix(short_status) + + if status_list: + status_mean = status_list[0] + for metric_dict in status_list[1:]: + for key, value in metric_dict.items(): + status_mean[key] += value + for key in status_mean.keys(): + status_mean[key] /= len(status_list) + + if self.replay_buffer.items: + all_rewards = [] + reward_metric_values = defaultdict(list) + all_advantages = [] + all_returns = [] + all_response_lengths = [] + all_total_lengths = [] + all_opd_reverse_kl = [] + + for item in self.replay_buffer.items: + if hasattr(item, 
'info') and item.info is not None and 'reward' in item.info: + all_rewards.append(item.info['reward']) + if hasattr(item, 'info') and item.info is not None and 'reward_metrics' in item.info: + reward_metrics = item.info['reward_metrics'] + if reward_metrics is not None: + for key, value in reward_metrics.items(): + reward_metric_values[key].append(value) + if hasattr(item, 'advantages') and item.advantages is not None: + all_advantages.append(item.advantages) + if hasattr(item, 'returns') and item.returns is not None: + all_returns.append(item.returns) + if hasattr(item, 'info') and item.info is not None and 'response_length' in item.info: + all_response_lengths.append(item.info['response_length']) + if hasattr(item, 'info') and item.info is not None and 'total_length' in item.info: + all_total_lengths.append(item.info['total_length']) + if hasattr(item, 'info') and item.info is not None and 'opd_reverse_kl' in item.info: + all_opd_reverse_kl.append(item.info['opd_reverse_kl']) if all_rewards: - self.strategy.print( - f"🎁 Total Reward: {status_mean['step_reward_mean']:.4f} ± {status_mean['step_reward_std']:.4f} " # noqa - f"(min={status_mean['step_reward_min']:.4f}, max={status_mean['step_reward_max']:.4f})" - ) - - if all_format_rewards: - self.strategy.print( - f"📝 Format Reward: {status_mean['format_reward_mean']:.4f} ± {status_mean['format_reward_std']:.4f}" # noqa - ) - - if all_accuracy_rewards: - self.strategy.print( - f"✅ Accuracy Reward: {status_mean['accuracy_reward_mean']:.4f} ± {status_mean['accuracy_reward_std']:.4f}" # noqa - ) - - if all_general_model_rewards and "general_model_reward_mean" in status_mean: + if isinstance(all_rewards[0], torch.Tensor): + rewards_tensor = torch.cat([t.to(device).float() for t in all_rewards]) + else: + rewards_tensor = torch.tensor(all_rewards, dtype=torch.float32, device=device) + status_mean["step_reward_mean"] = rewards_tensor.mean().item() + status_mean["step_reward_std"] = rewards_tensor.std().item() + 
status_mean["step_reward_max"] = rewards_tensor.max().item() + status_mean["step_reward_min"] = rewards_tensor.min().item() + status_mean["step_reward_zero_ratio"] = (rewards_tensor == 0).float().mean().item() + status_mean["step_reward_one_ratio"] = (rewards_tensor == 1).float().mean().item() + + for metric_name, values in reward_metric_values.items(): + if not values: + continue + if isinstance(values[0], torch.Tensor): + metric_tensor = torch.cat([t.to(device).float() for t in values]) + else: + metric_tensor = torch.tensor(values, dtype=torch.float32, device=device) + if metric_tensor.numel() == 0: + continue + if metric_name == "general_model_reward" and metric_tensor.abs().sum() == 0: + continue + status_mean[f"{metric_name}_mean"] = metric_tensor.mean().item() + status_mean[f"{metric_name}_std"] = metric_tensor.std().item() + + if "general_model_reward_mean" in status_mean: self.strategy.print(f"🧠 General RM Reward:{status_mean['general_model_reward_mean']:.4f}") if all_advantages: - self.strategy.print( - f"📈 Advantages: {status_mean['advantages_mean']:.4f} ± {status_mean['advantages_std']:.4f} " # noqa - f"(min={status_mean['advantages_min']:.4f}, max={status_mean['advantages_max']:.4f})" - ) + advantages_tensor = torch.cat(all_advantages) + status_mean["advantages_mean"] = advantages_tensor.mean().item() + status_mean["advantages_std"] = advantages_tensor.std().item() + status_mean["advantages_max"] = advantages_tensor.max().item() + status_mean["advantages_min"] = advantages_tensor.min().item() if all_returns: - self.strategy.print( - f"💰 Returns: {status_mean['returns_mean']:.4f} ± {status_mean['returns_std']:.4f}" - ) + returns_tensor = torch.cat(all_returns) + status_mean["returns_mean"] = returns_tensor.mean().item() + status_mean["returns_std"] = returns_tensor.std().item() if all_response_lengths: - self.strategy.print( - f"📏 Response Length: {status_mean['response_length_mean']:.1f} ± {status_mean['response_length_std']:.1f} tokens" # noqa - ) + 
if isinstance(all_response_lengths[0], torch.Tensor): + lengths_tensor = torch.cat([t.to(device).float() for t in all_response_lengths]) + else: + lengths_tensor = torch.tensor(all_response_lengths, dtype=torch.float32, device=device) + status_mean["response_length_mean"] = lengths_tensor.float().mean().item() + status_mean["response_length_std"] = lengths_tensor.float().std().item() + status_mean["response_length_zero_ratio"] = (lengths_tensor <= 1).float().mean().item() + generate_max_len = getattr(self.args, "generate_max_len", None) + if generate_max_len: + status_mean["response_hit_max_ratio"] = (lengths_tensor + >= float(generate_max_len - 1)).float().mean().item() + + if all_total_lengths: + if isinstance(all_total_lengths[0], torch.Tensor): + total_lengths_tensor = torch.cat([t.to(device).float() for t in all_total_lengths]) + else: + total_lengths_tensor = torch.tensor(all_total_lengths, dtype=torch.float32, device=device) + status_mean["total_length_mean"] = total_lengths_tensor.float().mean().item() + status_mean["total_length_std"] = total_lengths_tensor.float().std().item() + + if all_opd_reverse_kl: + if isinstance(all_opd_reverse_kl[0], torch.Tensor): + opd_kl_tensor = torch.cat([t.to(device).float() for t in all_opd_reverse_kl]) + else: + opd_kl_tensor = torch.tensor(all_opd_reverse_kl, dtype=torch.float32, device=device) + non_zero_mask = opd_kl_tensor != 0 + if non_zero_mask.any(): + masked_kl = opd_kl_tensor[non_zero_mask] + status_mean["opd_reverse_kl_mean"] = masked_kl.mean().item() + status_mean["opd_reverse_kl_std"] = masked_kl.std().item() + status_mean["opd_reverse_kl_max"] = masked_kl.max().item() + status_mean["opd_reverse_kl_min"] = masked_kl.min().item() + + if self.print_replay_buffer_stats and self.strategy.is_rank_0(): + self.strategy.print("\n" + "=" * 60) + self.strategy.print("📊 Detailed Step Statistics") + self.strategy.print("=" * 60) + + if all_rewards: + self.strategy.print( + f"🎁 Total Reward: 
{status_mean['step_reward_mean']:.4f} ± {status_mean['step_reward_std']:.4f} " # noqa + f"(min={status_mean['step_reward_min']:.4f}, max={status_mean['step_reward_max']:.4f})" + ) + self.strategy.print( + f" Reward Ratios: zero={status_mean['step_reward_zero_ratio']:.4f}, " + f"one={status_mean['step_reward_one_ratio']:.4f}" + ) - if all_opd_reverse_kl and 'opd_reverse_kl_mean' in status_mean: - self.strategy.print( - f"🎓 OPD Reverse KL: {status_mean['opd_reverse_kl_mean']:.4f} ± {status_mean['opd_reverse_kl_std']:.4f} " # noqa - f"(min={status_mean['opd_reverse_kl_min']:.4f}, max={status_mean['opd_reverse_kl_max']:.4f})" - ) + reward_metric_print_order = [ + ("format_reward", "📝 Format Reward"), + ("accuracy_reward", "✅ Accuracy Reward"), + ("outcome_correct", "🎯 Outcome Correct"), + ("model_reward", "🤖 Model Reward"), + ("final_reward", "🏁 Final Reward"), + ("rule_reward", "⚖️ Rule Reward"), + ("max_relative_drop", "📉 Max Relative Drop"), + ("has_drop_moment", "🪂 Drop Moment"), + ("step_score_min", "🔬 Step Score Min"), + ("step_score_mean", "🔬 Step Score Mean"), + ("step_score_last", "🔬 Step Score Last"), + ("step_count", "🧮 Step Count"), + ] + for metric_name, title in reward_metric_print_order: + mean_key = f"{metric_name}_mean" + std_key = f"{metric_name}_std" + if mean_key in status_mean: + self.strategy.print(f"{title:<20} {status_mean[mean_key]:.4f} ± {status_mean[std_key]:.4f}") + + if all_advantages: + self.strategy.print( + f"📈 Advantages: {status_mean['advantages_mean']:.4f} ± {status_mean['advantages_std']:.4f} " # noqa + f"(min={status_mean['advantages_min']:.4f}, max={status_mean['advantages_max']:.4f})" + ) - self.strategy.print("=" * 60 + "\n") + if all_returns: + self.strategy.print( + f"💰 Returns: {status_mean['returns_mean']:.4f} ± {status_mean['returns_std']:.4f}" + ) - torch.cuda.empty_cache() + if all_response_lengths: + self.strategy.print( + f"📏 Response Length: {status_mean['response_length_mean']:.1f} ± 
{status_mean['response_length_std']:.1f} tokens" # noqa + ) + self.strategy.print( + f" Length Ratios: empty={status_mean['response_length_zero_ratio']:.4f}, " + f"hit_max={status_mean.get('response_hit_max_ratio', 0.0):.4f}" + ) - self.strategy.maybe_offload_optimizer(self.actor_optim) - torch.cuda.synchronize() - torch.cuda.empty_cache() - self.strategy.print(f"PPO Train TIMECOST {time.time() - train_begin}") - self.strategy.report_memory("after train, opt offloaded, before update weights") - self.strategy.print(torch.cuda.memory_summary()) - self.strategy.update_engine_weights(self.actor) - - # Save trajectories at the end of ppo_train, BEFORE replay buffer is cleared - # This ensures we have data to save when trajectory saving is enabled - if global_steps % self.args.save_steps == 0: - self.save_trajectories(global_steps) + if all_total_lengths: + self.strategy.print( + f"📦 Total Length: {status_mean['total_length_mean']:.1f} ± {status_mean['total_length_std']:.1f} tokens" # noqa + ) + + self.strategy.print("=" * 60 + "\n") + + torch.cuda.empty_cache() + self.strategy.maybe_offload_optimizer(self.actor_optim) + torch.cuda.synchronize() + torch.cuda.empty_cache() + self.strategy.print(f"PPO Train TIMECOST {time.time() - train_begin}") + self.strategy.report_memory("after train, opt offloaded, before update weights") + self.strategy.print(torch.cuda.memory_summary()) + with self.profiler.section("learn/update_engine_weights"): + self.strategy.update_engine_weights(self.actor) + + if global_steps % self.args.save_steps == 0: + with self.profiler.section("learn/save_trajectories"): + self.save_trajectories(global_steps) return status_mean diff --git a/lightrft/utils/__init__.py b/lightrft/utils/__init__.py index 01600077..8d090bb7 100644 --- a/lightrft/utils/__init__.py +++ b/lightrft/utils/__init__.py @@ -5,6 +5,7 @@ """ from .logging_utils import init_logger +from .profile_recorder import StepProfileRecorder from .remote_rm_utils import remote_rm_fn from 
.trajectory_saver import TrajectorySaver, create_trajectory_saver from .distributed_sampler import DistributedSampler @@ -21,6 +22,7 @@ __all__ = [ # logging and trajectory "init_logger", + "StepProfileRecorder", "remote_rm_fn", 'TrajectorySaver', 'create_trajectory_saver', diff --git a/lightrft/utils/cli_args.py b/lightrft/utils/cli_args.py index c282b422..84c3204a 100644 --- a/lightrft/utils/cli_args.py +++ b/lightrft/utils/cli_args.py @@ -47,6 +47,36 @@ def add_arguments(parser: argparse.ArgumentParser) -> None: help="Fraction of GPU memory reserved for the inference engine's KV cache (range: 0.0 to 1.0). " "Higher values improve throughput but may risk out-of-memory errors.", ) + parser.add_argument( + "--local_hf_generate_max_batch_size", + type=int, + default=0, + help="Maximum per-call sample batch size for the local HuggingFace inference engine. " + "A value <= 0 keeps the current single-shot behavior. " + "Set this to a small positive integer (for example 1 or 2) to chunk local HF rollout generation " + "when multimodal models would otherwise OOM.", + ) + parser.add_argument( + "--local_hf_max_new_tokens", + type=int, + default=0, + help="Optional hard cap for max_new_tokens applied only to the local HuggingFace rollout path. " + "A value <= 0 keeps the launcher-provided generation length unchanged.", + ) + parser.add_argument( + "--hf_separate_rollout_actor", + action="store_true", + default=False, + help="Use a dedicated local HF rollout actor instead of reusing the training actor directly. " + "The current implementation is intended for FSDP-based quick rollout experiments.", + ) + parser.add_argument( + "--hf_separate_rollout_keep_on_gpu", + action="store_true", + default=False, + help="For local HF separate rollout: keep the rollout actor resident on GPU instead of sleeping/offloading " + "it between rollout and training phases. 
Use this when per-rank memory headroom is sufficient.", + ) parser.add_argument( "--enable_engine_sleep", action="store_true", @@ -117,7 +147,6 @@ def add_arguments(parser: argparse.ArgumentParser) -> None: help="Interval (in training steps) for plotting and saving the generated sequence length distribution. " "Only effective if `--log_dir` is set.", ) - # for rewards models parser.add_argument( "--rm_use_engine", diff --git a/lightrft/utils/profile_recorder.py b/lightrft/utils/profile_recorder.py new file mode 100644 index 00000000..52bde311 --- /dev/null +++ b/lightrft/utils/profile_recorder.py @@ -0,0 +1,518 @@ +""" +Step-level profiling utilities for long-running training jobs. + +This module provides a lightweight profiler that is suitable for production +training loops: + +- Measures named sections in wall-clock seconds. +- Persists per-step summaries to JSONL with flush + fsync. +- Maintains a continuously refreshed latest snapshot so interrupted jobs still + leave readable profiling state on disk. +- Optionally emits sampled ``torch.profiler`` traces on rank 0. +""" + +from __future__ import annotations + +import json +import os +import threading +import time +from contextlib import contextmanager, nullcontext +from pathlib import Path +from typing import Dict, Iterator, List, Optional + +import torch + + +class _DummyTorchProfiler: + def start(self) -> None: + pass + + def stop(self) -> None: + pass + + def step(self) -> None: + pass + + +class StepProfileRecorder: + """ + Persistent step profiler for distributed training. + + The recorder keeps local timing state on every rank, aggregates it at + train-step boundaries, and writes the aggregated profile on rank 0. 
+ """ + + TRACE_WAIT_STEPS = 1 + TRACE_WARMUP_STEPS = 1 + TRACE_ACTIVE_STEPS = 2 + TRACE_REPEAT = 2 + HEARTBEAT_INTERVAL_S = 1.0 + + def __init__(self, enabled: bool, output_dir: str, print_fn=None) -> None: + self.enabled = bool(enabled) + self.output_dir = Path(output_dir) + self.print_fn = print_fn + self.rank = torch.distributed.get_rank() if self._dist_enabled() else 0 + self.world_size = torch.distributed.get_world_size() if self._dist_enabled() else 1 + self.is_rank_0 = self.rank == 0 + + self.current_step: Optional[int] = None + self.current_episode: Optional[int] = None + self.current_step_start_wall: Optional[float] = None + self.current_step_started_at: Optional[float] = None + self.section_totals: Dict[str, float] = {} + self.phase_stack: List[str] = [] + self.active_section_name: Optional[str] = None + self.active_section_start_wall: Optional[float] = None + self.last_section_name: Optional[str] = None + self.last_section_elapsed_s: Optional[float] = None + self._state_lock = threading.Lock() + self._write_lock = threading.Lock() + self._heartbeat_stop = threading.Event() + self._heartbeat_thread: Optional[threading.Thread] = None + self._snapshot_generation = 0 + + self.rank_step_profile_path = self.output_dir / f"step_profile.rank{self.rank}.jsonl" + self.rank_latest_profile_path = self.output_dir / f"step_profile.rank{self.rank}.latest.json" + self.rank_current_profile_path = self.output_dir / f"step_profile.rank{self.rank}.current.json" + self.step_profile_path = self.output_dir / "step_profile.global.jsonl" + self.latest_profile_path = self.output_dir / "step_profile.latest.json" + self.current_profile_path = self.output_dir / "step_profile.current.json" + self.trace_dir = self.output_dir / "traces" + + self._torch_profiler = None + if self.enabled: + self.output_dir.mkdir(parents=True, exist_ok=True) + if self.is_rank_0: + self.trace_dir.mkdir(parents=True, exist_ok=True) + self._torch_profiler = self._build_torch_profiler() + 
self._torch_profiler.start() + self._start_heartbeat() + + @staticmethod + def _dist_enabled() -> bool: + return torch.distributed.is_available() and torch.distributed.is_initialized() + + @staticmethod + def _cuda_sync_if_available() -> None: + if torch.cuda.is_available(): + torch.cuda.synchronize() + + def _build_torch_profiler(self): + if not self.is_rank_0: + return _DummyTorchProfiler() + + from torch.profiler import ProfilerActivity + + return torch.profiler.profile( + schedule=torch.profiler.schedule( + wait=self.TRACE_WAIT_STEPS, + warmup=self.TRACE_WARMUP_STEPS, + active=self.TRACE_ACTIVE_STEPS, + repeat=self.TRACE_REPEAT, + ), + on_trace_ready=torch.profiler.tensorboard_trace_handler(str(self.trace_dir)), + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=False, + with_stack=False, + profile_memory=False, + ) + + def start_step(self, train_step: int, episode: int) -> None: + if not self.enabled: + return + + self._cuda_sync_if_available() + with self._state_lock: + self._snapshot_generation += 1 + self.current_step = int(train_step) + self.current_episode = int(episode) + self.current_step_start_wall = time.perf_counter() + self.current_step_started_at = time.time() + self.section_totals = {} + self.phase_stack = [] + self.active_section_name = None + self.active_section_start_wall = None + self.last_section_name = None + self.last_section_elapsed_s = None + self._write_current_snapshot() + + @contextmanager + def phase(self, phase_name: str) -> Iterator[None]: + if not self.enabled: + yield + return + + cleaned = phase_name.strip("/") + if not cleaned: + yield + return + + self.phase_stack.append(cleaned) + try: + yield + finally: + self.phase_stack.pop() + + @contextmanager + def section(self, name: str) -> Iterator[None]: + if not self.enabled or self.current_step is None: + yield + return + + full_name = self._qualify_name(name) + record_ctx = torch.profiler.record_function(full_name) if self.is_rank_0 else nullcontext() + + 
self._cuda_sync_if_available() + start = time.perf_counter() + with self._state_lock: + self.active_section_name = full_name + self.active_section_start_wall = start + self._write_current_snapshot() + with record_ctx: + try: + yield + finally: + self._cuda_sync_if_available() + elapsed = time.perf_counter() - start + with self._state_lock: + self.section_totals[full_name] = self.section_totals.get(full_name, 0.0) + elapsed + self.active_section_name = None + self.active_section_start_wall = None + self.last_section_name = full_name + self.last_section_elapsed_s = elapsed + self._write_current_snapshot() + + def _qualify_name(self, name: str) -> str: + cleaned = name.strip("/") + if not self.phase_stack: + return cleaned + return "/".join([*self.phase_stack, cleaned]) + + def finish_step(self, extra: Optional[Dict] = None) -> Optional[Dict]: + if not self.enabled or self.current_step is None or self.current_step_start_wall is None: + return None + + self._cuda_sync_if_available() + with self._state_lock: + train_step = int(self.current_step) + episode = int(self.current_episode) if self.current_episode is not None else None + started_at = self.current_step_started_at + total_elapsed = time.perf_counter() - self.current_step_start_wall + local_sections = dict(self.section_totals) + local_sections["step/total"] = total_elapsed + self._snapshot_generation += 1 + self.current_step = None + self.current_episode = None + self.current_step_start_wall = None + self.current_step_started_at = None + self.section_totals = {} + self.phase_stack = [] + self.active_section_name = None + self.active_section_start_wall = None + self.last_section_name = None + self.last_section_elapsed_s = None + + local_step_total_s = local_sections.get("step/total", total_elapsed) + local_ratios = { + name: (value / local_step_total_s if local_step_total_s > 0 else 0.0) + for name, value in local_sections.items() + } + local_record = { + "train_step": train_step, + "episode": episode, + "rank": 
self.rank, + "world_size": self.world_size, + "started_at": started_at, + "finished_at": time.time(), + "sections_local_s": local_sections, + "sections_local_ratio": local_ratios, + } + if extra: + local_record["extra"] = extra + self._append_jsonl(self.rank_step_profile_path, local_record) + self._write_atomic_json(self.rank_latest_profile_path, local_record) + self._write_atomic_json(self.rank_current_profile_path, local_record) + + gathered_sections = self._gather_sections(local_sections) + self._torch_profiler.step() + + result = None + if self.is_rank_0: + aggregated = self._aggregate_sections(gathered_sections) + step_total_s = aggregated["max_s"].get("step/total", total_elapsed) + ratios = { + name: (value / step_total_s if step_total_s > 0 else 0.0) + for name, value in aggregated["max_s"].items() + } + mean_ratios = { + name: ( + value / aggregated["mean_s"].get("step/total", step_total_s) + if aggregated["mean_s"].get("step/total", step_total_s) > 0 else 0.0 + ) + for name, value in aggregated["mean_s"].items() + } + record = { + "train_step": train_step, + "episode": episode, + "world_size": self.world_size, + "available_ranks": list(range(self.world_size)), + "started_at": started_at, + "finished_at": time.time(), + "sections_max_s": aggregated["max_s"], + "sections_mean_s": aggregated["mean_s"], + "sections_max_ratio": ratios, + "sections_mean_ratio": mean_ratios, + } + if extra: + record["extra"] = extra + self._append_jsonl(self.step_profile_path, record) + self._write_atomic_json(self.latest_profile_path, record) + self._write_atomic_json(self.current_profile_path, record) + result = { + "record": record, + "wandb_logs": self._build_wandb_logs(train_step, aggregated["max_s"], ratios), + "summary": self._build_summary(aggregated["max_s"], ratios), + } + return result + + def close(self) -> None: + if not self.enabled: + return + self._heartbeat_stop.set() + if self._heartbeat_thread is not None: + 
self._heartbeat_thread.join(timeout=max(self.HEARTBEAT_INTERVAL_S * 2, 2.0)) + if self._torch_profiler is not None: + self._torch_profiler.stop() + + def _gather_sections(self, local_sections: Dict[str, float]) -> List[Dict[str, float]]: + if not self._dist_enabled(): + return [local_sections] + + gathered_sections = [None for _ in range(self.world_size)] + torch.distributed.all_gather_object(gathered_sections, local_sections) + return [item or {} for item in gathered_sections] + + @staticmethod + def _aggregate_sections(section_list: List[Dict[str, float]]) -> Dict[str, Dict[str, float]]: + section_names = sorted({name for section_dict in section_list for name in section_dict}) + max_s: Dict[str, float] = {} + mean_s: Dict[str, float] = {} + world_size = max(len(section_list), 1) + + for name in section_names: + values = [float(section_dict.get(name, 0.0)) for section_dict in section_list] + max_s[name] = max(values) + mean_s[name] = sum(values) / world_size + + return {"max_s": max_s, "mean_s": mean_s} + + @staticmethod + def _flatten_section_name(name: str) -> str: + return name.replace("/", "_").replace(" ", "_") + + def _build_wandb_logs(self, train_step: int, max_s: Dict[str, float], ratios: Dict[str, float]) -> Dict[str, float]: + logs = {"profile/train_step": train_step} + for name, value in max_s.items(): + flat_name = self._flatten_section_name(name) + logs[f"profile/{flat_name}_s"] = value + logs[f"profile/{flat_name}_ratio"] = ratios.get(name, 0.0) + return logs + + @staticmethod + def _build_summary(max_s: Dict[str, float], ratios: Dict[str, float]) -> str: + interesting_sections = [ + "collect/total", + "collect/generate", + "learn/total", + "learn/update_engine_weights", + "eval/total", + "checkpoint/total", + ] + parts = [] + step_total = max_s.get("step/total", 0.0) + parts.append(f"step_total={step_total:.2f}s") + for name in interesting_sections: + if name not in max_s: + continue + parts.append(f"{name}={max_s[name]:.2f}s ({ratios.get(name, 
0.0):.1%})") + return "profile: " + ", ".join(parts) + + def _write_current_snapshot(self) -> None: + snapshot = self._build_current_snapshot() + if snapshot is None: + return + if not self._write_snapshot_if_current(self.rank_current_profile_path, snapshot): + return + if not self.is_rank_0: + return + + global_snapshot = self._build_global_current_snapshot(snapshot) + if global_snapshot is None: + return + self._write_atomic_json(self.current_profile_path, global_snapshot) + + def _build_current_snapshot(self) -> Optional[Dict]: + if not self.enabled: + return None + + with self._state_lock: + if self.current_step is None or self.current_step_start_wall is None: + return None + + current_elapsed_s = max(time.perf_counter() - self.current_step_start_wall, 0.0) + sections_local_s = dict(self.section_totals) + active_section_elapsed_s = None + if self.active_section_name is not None and self.active_section_start_wall is not None: + active_section_elapsed_s = max(time.perf_counter() - self.active_section_start_wall, 0.0) + sections_local_s[self.active_section_name + ] = (sections_local_s.get(self.active_section_name, 0.0) + active_section_elapsed_s) + current_ratios = { + name: (value / current_elapsed_s if current_elapsed_s > 0 else 0.0) + for name, value in sections_local_s.items() + } + snapshot = { + "train_step": self.current_step, + "episode": self.current_episode, + "rank": self.rank, + "world_size": self.world_size, + "started_at": self.current_step_started_at, + "partial": True, + "current_elapsed_s": current_elapsed_s, + "sections_local_s": sections_local_s, + "sections_local_ratio": current_ratios, + "_snapshot_generation": self._snapshot_generation, + } + if self.last_section_name is not None: + snapshot["last_section"] = self.last_section_name + if self.last_section_elapsed_s is not None: + snapshot["last_elapsed_s"] = self.last_section_elapsed_s + if self.active_section_name is not None: + snapshot["active_section"] = self.active_section_name + if 
active_section_elapsed_s is not None: + snapshot["active_section_elapsed_s"] = active_section_elapsed_s + return snapshot + + def _build_global_current_snapshot(self, rank0_snapshot: Dict) -> Optional[Dict]: + current_step = rank0_snapshot.get("train_step") + current_episode = rank0_snapshot.get("episode") + if current_step is None: + return None + + snapshots = [] + available_ranks = [] + active_sections = {} + for rank in range(self.world_size): + if rank == self.rank: + candidate = dict(rank0_snapshot) + else: + candidate = self._read_json(self.output_dir / f"step_profile.rank{rank}.current.json") + if not candidate: + continue + if candidate.get("train_step") != current_step or candidate.get("episode") != current_episode: + continue + snapshots.append(candidate) + available_ranks.append(rank) + active_section = candidate.get("active_section") + if active_section: + active_sections[f"rank{rank}"] = active_section + + if not snapshots: + return None + + aggregated = self._aggregate_sections([snapshot.get("sections_local_s", {}) for snapshot in snapshots]) + elapsed_values = [float(snapshot.get("current_elapsed_s", 0.0)) for snapshot in snapshots] + max_elapsed = max(elapsed_values) if elapsed_values else 0.0 + mean_elapsed = sum(elapsed_values) / len(elapsed_values) if elapsed_values else 0.0 + max_ratios = { + name: (value / max_elapsed if max_elapsed > 0 else 0.0) + for name, value in aggregated["max_s"].items() + } + mean_ratios = { + name: (value / mean_elapsed if mean_elapsed > 0 else 0.0) + for name, value in aggregated["mean_s"].items() + } + started_at_candidates = [ + snapshot.get("started_at") for snapshot in snapshots if snapshot.get("started_at") is not None + ] + + global_snapshot = { + "train_step": current_step, + "episode": current_episode, + "world_size": self.world_size, + "available_ranks": available_ranks, + "num_rank_snapshots": len(snapshots), + "started_at": min(started_at_candidates) if started_at_candidates else None, + "partial": True, + 
"current_elapsed_max_s": max_elapsed, + "current_elapsed_mean_s": mean_elapsed, + "sections_max_s": aggregated["max_s"], + "sections_mean_s": aggregated["mean_s"], + "sections_max_ratio": max_ratios, + "sections_mean_ratio": mean_ratios, + } + if active_sections: + global_snapshot["active_sections"] = active_sections + return global_snapshot + + def _start_heartbeat(self) -> None: + self._heartbeat_thread = threading.Thread( + target=self._heartbeat_loop, + name="step-profile-heartbeat", + daemon=True, + ) + self._heartbeat_thread.start() + + def _heartbeat_loop(self) -> None: + while not self._heartbeat_stop.wait(self.HEARTBEAT_INTERVAL_S): + self._write_current_snapshot() + + def _write_snapshot_if_current(self, path: Path, payload: Dict) -> bool: + snapshot_generation = payload.get("_snapshot_generation") + if snapshot_generation is None: + self._write_atomic_json(path, payload) + return True + + sanitized_payload = dict(payload) + sanitized_payload.pop("_snapshot_generation", None) + with self._write_lock: + with self._state_lock: + if snapshot_generation != self._snapshot_generation: + return False + self._write_atomic_json_unlocked(path, sanitized_payload) + return True + + def _append_jsonl(self, path: Path, payload: Dict) -> None: + with self._write_lock: + self._append_jsonl_unlocked(path, payload) + + @staticmethod + def _append_jsonl_unlocked(path: Path, payload: Dict) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("a", encoding="utf-8") as f: + f.write(json.dumps(payload, ensure_ascii=False, sort_keys=True) + "\n") + f.flush() + os.fsync(f.fileno()) + + def _write_atomic_json(self, path: Path, payload: Dict) -> None: + with self._write_lock: + self._write_atomic_json_unlocked(path, payload) + + @staticmethod + def _write_atomic_json_unlocked(path: Path, payload: Dict) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.with_suffix(path.suffix + ".tmp") + with tmp_path.open("w", encoding="utf-8") as f: 
+ json.dump(payload, f, ensure_ascii=False, indent=2, sort_keys=True) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp_path, path) + + @staticmethod + def _read_json(path: Path) -> Optional[Dict]: + try: + with path.open("r", encoding="utf-8") as f: + return json.load(f) + except (FileNotFoundError, json.JSONDecodeError, OSError): + return None diff --git a/requirements.txt b/requirements.txt index 56a06651..565ace1b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,3 +25,11 @@ librosa qwen-vl-utils sympy matplotlib +# URSA-MATH dependencies +attrdict +numpy +pandas +Pillow +regex +timm +torchvision