Task-agnostic implementation of Training Language Models to Self-Correct via Reinforcement Learning (Kumar et al., ICLR 2025), built on Unsloth LoRA for small open-weight models.
The core algorithm follows the paper: Stage I anchors attempt 1 to the reference policy via a KL penalty; Stage II adds the reward-shaping bonus α · (r(y2) − r(y1)). Adaptations we introduced:
- Reasoning-tag format from DeepSeek-R1 (which is trained with GRPO): the model reasons inside `<think>...</think>` and gives the final answer inside `<answer>...</answer>`. The compound reward below scores this format directly.
- K3 KL estimator (Schulman): `K3 = exp(log π_ref − log π) − (log π_ref − log π) − 1`. For tokens sampled from π it is an unbiased, always non-negative estimate of KL(π ‖ π_ref), computed from per-token log-probs only. No `[B, T, V]` log-softmax tensor is needed, which is what makes Stage II's two-graph backward fit in memory.
- Optional Dr.GRPO-style length normalization (opt-in via `train.length_norm` in the YAML): the default `"sequence"` divides the PG and K3 KL terms by the actual per-sample generated length, which is the original form. Setting `"constant"` instead divides by `max_new_tokens` (Dr.GRPO), which avoids the length bias where the policy games the loss by inflating generation length (each extra low-KL / low-grad token otherwise dilutes the per-token signal).
- Compound reward (`format_and_match`): 0.25 for a `<think>...</think>` pair, 0.25 for one `<answer>...</answer>` pair, 0.5 for an extracted-answer match. This gives the α-bonus more signal than a pure binary reward and explicitly anchors the format. A stricter `strict_format_and_match` variant (0.5 format / 0.5 match) requires the whole output to be exactly the two blocks: no prose between `</think>` and `<answer>`, nothing after `</answer>`. Select either via `reward.fn` in the YAML.
- LoRA-only: the reference policy is the same model run under `model.disable_adapter()`, so there is no second model in VRAM.
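The K3 estimator and the two length-normalization modes from the list above can be illustrated with a small pure-Python sketch. It is not the code in `train.py`: the names `k3_per_token` and `normalized_kl` are hypothetical, and a real implementation would presumably work on batched tensors of gathered log-probs rather than Python lists.

```python
import math


def k3_per_token(logp_policy, logp_ref):
    """Schulman's K3 estimate for each sampled token.

    Both inputs hold log-probs of the *sampled* tokens only, so no full
    vocabulary distribution is needed.
    """
    values = []
    for lp, lr in zip(logp_policy, logp_ref):
        log_ratio = lr - lp  # log pi_ref - log pi
        values.append(math.exp(log_ratio) - log_ratio - 1.0)
    return values


def normalized_kl(logp_policy, logp_ref, length_norm="sequence", max_new_tokens=None):
    """Sum per-token K3 and divide by the chosen length denominator."""
    k3 = k3_per_token(logp_policy, logp_ref)
    if length_norm == "sequence":
        denom = max(len(k3), 1)  # actual generated length
    elif length_norm == "constant":
        if max_new_tokens is None:
            raise ValueError("max_new_tokens is required for 'constant'")
        denom = max_new_tokens  # fixed budget, Dr.GRPO style
    else:
        raise ValueError(f"unknown length_norm: {length_norm!r}")
    return sum(k3) / denom


# Made-up log-probs for a 3-token generation (illustrative values only).
logp_policy = [-0.2, -1.5, -0.7]
logp_ref = [-0.3, -1.0, -0.7]

print(k3_per_token(logp_policy, logp_ref))
print(normalized_kl(logp_policy, logp_ref, "sequence"))
print(normalized_kl(logp_policy, logp_ref, "constant", max_new_tokens=8))
```

With `"constant"`, appending extra near-zero-KL tokens no longer shrinks the contribution of the tokens that do diverge, because the denominator stays fixed at `max_new_tokens`.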
train.py # task-agnostic entry point
reward_function.py # @register_reward / @register_extractor + shipped fns
configs/gsm8k.yaml # primary target
configs/arithmetic.yaml # fast toy task — 5-digit subtraction Qwen3-0.6B can't solve; full run < 1h
configs/math.yaml # math baseline
build_hard_arithmetic.py # builds the toy-task dataset (greedy-decode, keep the failures)
profile_run.py # optional length/extraction diagnostic
train.py never references a specific task — model, dataset, prompts, reward, and extractor are all YAML.
pip install -r requirements.txt
# full run
python train.py --config configs/gsm8k.yaml
# fast toy task (~1h two-stage run) — good for debugging the loop end-to-end
python train.py --config configs/arithmetic.yaml
# 1-minute end-to-end smoke (N examples, capped eval)
python train.py --config configs/gsm8k.yaml --smoke 8
# skip stage 1 and resume stage 2 from a saved stage-1 adapter
python train.py --config configs/gsm8k.yaml \
--start-stage 2 --resume-adapter outputs/score-gsm8k/stage1/step_200

Logs go to W&B (`wandb_project` in the config). End-of-stage adapters save to `outputs/{run_name}/stage{1,2}/`. Set `train.checkpoint_every: N` in the YAML to also save mid-stage adapters to `outputs/{run_name}/stage{1,2}/step_N/` every N optimizer steps, which is useful for resuming or for picking the best checkpoint by reward curve.
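When choosing a mid-stage checkpoint to pass to `--resume-adapter`, it can help to list the `step_N` directories in numeric order. The helper below is a hypothetical convenience sketch, not part of the repo; it assumes only the `outputs/{run_name}/stage{1,2}/step_N/` layout described above.

```python
import re
from pathlib import Path

_STEP_DIR = re.compile(r"^step_(\d+)$")


def list_step_checkpoints(stage_dir):
    """Return (step, path) pairs for step_N subdirectories, sorted by step."""
    stage_dir = Path(stage_dir)
    if not stage_dir.is_dir():
        return []
    found = []
    for child in stage_dir.iterdir():
        match = _STEP_DIR.match(child.name)
        if match and child.is_dir():
            found.append((int(match.group(1)), child))
    # Sort numerically: a plain string sort would put step_1000 before step_200.
    return sorted(found, key=lambda pair: pair[0])


if __name__ == "__main__":
    # Example path taken from the resume command above.
    for step, path in list_step_checkpoints("outputs/score-gsm8k/stage1"):
        print(step, path)
```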
`configs/arithmetic.yaml` uses `torchtrade/arithmetic-hard-qwen3-0.6b`: 200 train / 50 eval 5-digit subtractions that Qwen3-0.6B fails under greedy decoding (built by `build_hard_arithmetic.py`; see the dataset card for source + method). It's a debug/shake-out task, not a benchmark, because difficulty is defined relative to that one model. It is small enough that a full two-stage run finishes in ~1h while still stressing every part of the loop.
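The "greedy-decode, keep the failures" construction can be sketched as a filter over generated problems. This illustrates the idea only and is not the contents of `build_hard_arithmetic.py`: `solve` stands in for a greedy-decoded model call, and `make_problem`, `build_hard_set`, and the toy `weak_solver` are hypothetical names.

```python
import random


def make_problem(rng):
    """Sample a 5-digit subtraction with a non-negative result."""
    a = rng.randint(10000, 99999)
    b = rng.randint(10000, a)
    return {"question": f"{a} - {b}", "answer": str(a - b)}


def build_hard_set(solve, n_keep, seed=0, max_tries=100_000):
    """Keep only the problems that `solve` answers incorrectly."""
    rng = random.Random(seed)
    kept = []
    for _ in range(max_tries):
        if len(kept) >= n_keep:
            break
        problem = make_problem(rng)
        if solve(problem["question"]).strip() != problem["answer"]:
            kept.append(problem)
    return kept


def weak_solver(question):
    """Toy stand-in for the model: forgets the borrow in the ones digit."""
    a, b = (int(x) for x in question.split(" - "))
    if a % 10 < b % 10:
        return str(a - b + 10)  # deliberately wrong
    return str(a - b)


hard = build_hard_set(weak_solver, n_keep=5)
for problem in hard:
    print(problem["question"], "=", problem["answer"])
```

Because the kept set depends entirely on which problems the chosen solver gets wrong, a set built this way is only "hard" for that solver.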
One full two-stage run of configs/arithmetic.yaml (Qwen3-0.6B, 50-example held-out eval):
| checkpoint | acc attempt 1 | acc attempt 2 | Δ (attempt 2 − attempt 1) |
|---|---|---|---|
| pretrain | 0.42 | 0.51 | +0.09 |
| after Stage I | 0.57 | 0.65 | +0.08 |
| after Stage II | 0.71 | 0.735 | +0.025 |
Total accuracy climbs steadily through both stages — +29 pts on attempt 1 (0.42 → 0.71) and +22.5 pts on attempt 2 (0.51 → 0.735). The held-out Δ narrows as attempt 1 itself gets strong (less left to correct), but the self-correction signal shows up clearly in the training rollouts: the attempt-2-minus-attempt-1 reward gap flips from −0.38 in Stage I to +0.25 in Stage II once the α · (r(y2) − r(y1)) bonus kicks in — Stage II is what turns a worse second attempt into a better one.
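To make the Stage II bonus concrete, the sketch below evaluates the shaped second-attempt reward r(y2) + α · (r(y2) − r(y1)) for the four binary outcomes. It assumes the bonus is added to the second attempt's reward as in the paper; the value α = 2.0 and the name `shaped_second_attempt_reward` are illustrative assumptions, not values taken from the configs.

```python
def shaped_second_attempt_reward(r1, r2, alpha):
    """Second-attempt reward plus the progress bonus alpha * (r2 - r1)."""
    return r2 + alpha * (r2 - r1)


alpha = 2.0  # assumed for illustration; the real value comes from the config

cases = {
    "wrong -> right": (0.0, 1.0),
    "right -> right": (1.0, 1.0),
    "right -> wrong": (1.0, 0.0),
    "wrong -> wrong": (0.0, 0.0),
}

for name, (r1, r2) in cases.items():
    print(f"{name}: {shaped_second_attempt_reward(r1, r2, alpha):+.2f}")
```

For any α > 0, fixing a wrong first attempt earns more than merely staying right, and breaking a correct first attempt is actively penalized.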
To add a new task:

- Push your dataset to the HF Hub with `train` and `test` splits. For datasets needing a sub-config (e.g. `openai/gsm8k` has `main` / `socratic`), set `dataset.config_name`.
- Add an answer extractor in `reward_function.py` if the shipped ones don't fit (`gsm8k_hash`, `math_final_answer`, `identity`).
- Add a reward function if none of the shipped ones fit.
- Copy `configs/gsm8k.yaml`, edit `model.*` / `dataset.*` / `prompts.*` / `reward.*`.
- Smoke first (`--smoke 8`), then full.
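As a starting point for a custom reward, here is a sketch of scoring in the spirit of the `format_and_match` reward described earlier (0.25 + 0.25 + 0.5). It is written from that description, not copied from the shipped function: the regexes, the exact-string comparison, the choice to read the answer from the single `<answer>` block, and the name `format_and_match_sketch` are all assumptions. Registration via `@register_reward` is omitted because its signature is not shown here.

```python
import re

_THINK = re.compile(r"<think>.*?</think>", re.DOTALL)
_ANSWER = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def format_and_match_sketch(completion, gold):
    """Score a completion in [0, 1]: format credit plus answer-match credit."""
    score = 0.0

    # 0.25 for a <think>...</think> pair anywhere in the output.
    if _THINK.search(completion):
        score += 0.25

    # 0.25 for exactly one <answer>...</answer> pair (assumed reading of "one").
    answers = _ANSWER.findall(completion)
    if len(answers) == 1:
        score += 0.25
        # 0.5 if the extracted answer matches the gold answer
        # (assumption: plain string equality after stripping whitespace).
        if answers[0].strip() == gold.strip():
            score += 0.5

    return score


print(format_and_match_sketch("<think>2 + 2</think><answer>4</answer>", "4"))
print(format_and_match_sketch("<think>2 + 2</think><answer>5</answer>", "4"))
print(format_and_match_sketch("The answer is 4", "4"))
```

A task with a different answer format would swap the string comparison for its own extractor, which is the role the registered extractors play in the steps above.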