refactor(sunjx): refactor dataset and reward module #13
Open
Jiaxuan-Sun wants to merge 6 commits into
puyuan1996 requested changes on Jan 4, 2026
80 tasks
HansBug added a commit to HansBug/LightRFT that referenced this pull request on May 8, 2026
…nt 2
The URSA paper's variant 2 (figure ablation) describes a per-step PRM reward
but does not specify how the resulting N step_scores integrate with the GRPO
baseline-subtraction convention. Two interpretations:
  raw        : scatter the raw sigmoid step_score directly (paper figure
               ablation). The PRM gives all valid steps a positive value
               (typically 0.6-0.95), so token-level returns are nearly
               always positive and the PG signal is weak (smoke pg ~ 6e-5,
               ~10^3x weaker than the PSGRPO baseline). Reproduces the
               paper's "variant 2 underperforms" observation.
  group_norm : for each step k, subtract the group mean and divide by the
               group std across the K trajectories sharing the same prompt
               BEFORE scattering. Matches the GRPO convention: zero-mean
               signed advantages with magnitude ~1, restoring a directional
               PG signal.
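
Written out, the group_norm interpretation amounts to the following. This is a sketch: the index names, the use of the population (rather than sample) standard deviation, and the stabilizer \varepsilon are assumptions not stated in the commit message. Here r_{i,k} is the raw sigmoid step_score of step k in trajectory i, and V_k is the set of trajectories for the same prompt that actually contain a step k.

```latex
\mu_k = \frac{1}{|V_k|} \sum_{i \in V_k} r_{i,k}, \qquad
\sigma_k = \sqrt{\frac{1}{|V_k|} \sum_{i \in V_k} \left(r_{i,k} - \mu_k\right)^2}, \qquad
\hat{r}_{i,k} = \frac{r_{i,k} - \mu_k}{\sigma_k + \varepsilon}
```

In raw mode, by contrast, r_{i,k} \in (0, 1) is scattered unchanged, so every valid step contributes a positive return regardless of how it compares with the other trajectories in the group.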
Implementation:
- examples/math_prm/train_colocate.py: --per_step_reward_mode CLI flag,
  default "raw" (preserves prior behavior).
- lightrft/trainer/fast_exp_maker.py::_apply_step_reward_group_norm:
  cross-experience step-level normalization at the entry of
  _compute_advantages_and_returns. Reshapes per-traj step_rewards to
  (G, K, max_steps), computes masked mean/std along the K dim (padding,
  marked by step_token_indices < 0, is masked out), and writes the result
  back into experience.info. Downstream compute_reward scatter logic is
  unchanged.
- examples/math_prm/run_smoke_per_step_prm_groupnorm.sh: smoke variant
  exercising --per_step_reward_mode group_norm.
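
The masked normalization described above can be sketched in NumPy as follows. This is an illustration only, not the code in fast_exp_maker.py: the function name `group_norm_step_rewards`, its signature, the `eps` value, and the choice of population std are all assumptions, and the real implementation presumably operates on tensors stored in experience.info.

```python
import numpy as np

def group_norm_step_rewards(step_rewards, step_token_indices, eps=1e-6):
    """Hypothetical sketch of per-step group normalization.

    step_rewards:       float array of shape (G, K, max_steps)
    step_token_indices: int array of the same shape; entries < 0 mark padding
    Returns an array of the same shape with padded entries set to 0.
    """
    r = np.asarray(step_rewards, dtype=np.float64)
    valid = np.asarray(step_token_indices) >= 0

    # Number of trajectories in each group that actually have step k.
    count = np.maximum(valid.sum(axis=1, keepdims=True), 1)

    # Masked mean and (population) variance along the K dimension.
    mean = np.where(valid, r, 0.0).sum(axis=1, keepdims=True) / count
    var = np.where(valid, (r - mean) ** 2, 0.0).sum(axis=1, keepdims=True) / count

    normed = (r - mean) / (np.sqrt(var) + eps)
    return np.where(valid, normed, 0.0)

# One prompt (G=1), four trajectories (K=4), up to three steps each.
# The last trajectory has only two steps, so its third slot is padding.
rewards = np.array([[[0.90, 0.80, 0.70],
                     [0.60, 0.90, 0.95],
                     [0.70, 0.60, 0.80],
                     [0.80, 0.70, 0.00]]])
token_idx = np.array([[[12, 30, 55],
                       [10, 28, 61],
                       [15, 33, 49],
                       [11, 27, -1]]])
normed = group_norm_step_rewards(rewards, token_idx)
```

Each step column of `normed` is zero-mean over its valid trajectories, which is the property the math sanity check below reports; the padded slot stays at 0 so it cannot leak into the downstream scatter.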
Validation:
- Math sanity (4 trajectories, geometry holdout opendilab#13): per-step
  group-normalized values have mean=0, std=1 across all 5 steps.
- Smoke v6 end-to-end (K=2 + 1 PPO step + 500-sample eval):
  alignment_failed=0.00%, outcome=0.5952, n_aligned_steps=4.97; no
  regression vs raw mode (small K can't yet show a group_norm benefit;
  an effective ablation requires K=8 + multi-step training).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1. Dataset Module Refactoring (`lightrft/datasets/`)

Modified:
- `__init__.py`: Refactored imports with unified interfaces and improved optional dependency handling

Added:
- `config.py`: `DatasetConfig` class
  - `data_path` and `data_probs` (supports string/list)
  - `for_train()`, `for_eval()`, `for_pretrain()`
- `loader.py`: `DatasetLoader` class
  - `blending_datasets` parameters
  - `PromptDatasetVL` and `SFTDatasetVL`

2. Reward Module (`lightrft/reward/`)

Added:
- `__init__.py`: Module entry point with unified exports
- `base.py`: `BaseReward` abstract base class
  - `compute()` method signature
  - `(rewards, metrics)`
- `rule.py`: `RuleReward` class
  - `<think>` tags, `\boxed{}` notation
  - `default`, `geo3k_*`, `gsm8k_*`
- `model.py`: Reward model implementations
  - `SingleRewardModel`: Single reward model wrapper with auto load/offload
  - `MultiRewardModel`: Multiple reward model ensemble with recipe-based aggregation
- `manager.py`: `RewardManager` class
  - `from_config()` factory method
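
To make the reward-module shape concrete, here is a minimal sketch of an abstract base with a `compute()` method that returns `(rewards, metrics)`, plus one rule-based subclass that checks a `\boxed{}` answer. It is not the code from this PR: the class names `BaseRewardSketch` and `BoxedAnswerReward`, the argument names, the exact-match rule, and the `accuracy` metric key are all assumptions made for illustration.

```python
import re
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple


class BaseRewardSketch(ABC):
    """Hypothetical stand-in for an abstract reward interface."""

    @abstractmethod
    def compute(
        self, responses: List[str], references: List[str]
    ) -> Tuple[List[float], Dict[str, float]]:
        """Return one reward per response plus aggregate metrics."""


class BoxedAnswerReward(BaseRewardSketch):
    """Rule-based reward: 1.0 if the last \\boxed{...} matches the reference."""

    # Matches \boxed{...} with no nested braces (a simplifying assumption).
    _BOXED = re.compile(r"\\boxed\{([^{}]*)\}")

    def compute(self, responses, references):
        rewards = []
        for response, reference in zip(responses, references):
            matches = self._BOXED.findall(response)
            predicted = matches[-1].strip() if matches else None
            rewards.append(1.0 if predicted == reference.strip() else 0.0)
        accuracy = sum(rewards) / len(rewards) if rewards else 0.0
        return rewards, {"accuracy": accuracy}


rewards, metrics = BoxedAnswerReward().compute(
    ["<think>2 + 2</think> The answer is \\boxed{4}", "no final answer"],
    ["4", "7"],
)
```

Returning metrics alongside the per-sample rewards lets a manager-style class aggregate several reward sources without each one needing its own logging path.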