Skip to content
Open

RLLM #1232

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
6095032
skywork and some qwen changes
Jul 22, 2025
f3d876a
Removing think tokens
Jul 22, 2025
d7e9fd0
Merge branch 'ot_merge' into swarna/skyworkv2
swarnaHub Jul 22, 2025
164458b
Fixing GRMs
Jul 23, 2025
b162c05
Merge branch 'swarna/skyworkv2' of github.com:facebookresearch/fairse…
Jul 23, 2025
4fc9aea
Black
Jul 23, 2025
a6ad0a8
Merge branch online_training
Jul 29, 2025
f829d82
Import issue
Jul 29, 2025
9392f0d
add missing sw import
Aug 6, 2025
55dc622
Different configs for pairwise GRM
Aug 7, 2025
ee17161
Minor fix and more logging
Aug 7, 2025
18ff4c4
Online dpo: pairwise GRM should sample at least two rollouts
Aug 7, 2025
585b744
zero reward for rollouts not involved in pairwise judgments
Aug 7, 2025
510bdf2
simplifying
Aug 7, 2025
38aaf53
SequenceBatch seq_lens type ensure to be a list
Aug 9, 2025
a6ab8b0
add pairwise J1 with reference answer
Aug 13, 2025
2004533
fix None ref answer
Aug 18, 2025
bfc255b
Pairwise with pivot changes
Aug 19, 2025
1f942d7
New pivot changes + cleanup
Aug 20, 2025
5eee4ee
Fix
Aug 20, 2025
5cdb6b9
Making pair type configurable
Aug 20, 2025
e14421d
Config change
Aug 21, 2025
8831f36
update prompt
Aug 24, 2025
9fc9dbb
some more logging
Aug 27, 2025
d14fb90
Merge branch 'swarna/skyworkv2' of github.com:facebookresearch/fairse…
Aug 27, 2025
b1ba0e2
Fixing typo in comment
Aug 27, 2025
d47ef15
kwise judgment support
Sep 2, 2025
d474070
Adding support for acemath
Sep 3, 2025
1162d60
Skywork-RM from hf
Sep 4, 2025
4ea811d
add parsed ref
Sep 6, 2025
6703f1b
update prompt template
Sep 7, 2025
fe84d9e
all comparisons in k-wise
Sep 10, 2025
e7137ac
Jacklanchantin/qwen (#1260)
jacklanchantin Sep 30, 2025
dfb958a
octothinker assets
Sep 30, 2025
a746129
Changes
Sep 30, 2025
1116f02
Merging
Sep 30, 2025
9355ce6
Minor changes
Oct 16, 2025
474a537
Logging judge input
Oct 22, 2025
8211664
Tracking a second reward (for debugging)
Oct 29, 2025
b9681ad
Clean up and principia changes
Jan 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/fairseq2/assets/cards/models/llama.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,34 @@ model_arch: llama3_1_8b
checkpoint: "hg://deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
tokenizer: "hg://deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
tokenizer_family: llama
use_v2_tokenizer: true
use_v2_tokenizer: true

---

name: octothinker_8b_hybrid
model_family: llama
model_arch: llama3_1_8b
checkpoint: /datasets/pretrained-llms/OctoThinker-8B-Hybrid-Base/
tokenizer: /datasets/pretrained-llms/OctoThinker-8B-Hybrid-Base/
tokenizer_family: llama
use_v2_tokenizer: true

---

name: octothinker_8b_long
model_family: llama
model_arch: llama3_1_8b
checkpoint: /datasets/pretrained-llms/OctoThinker-8B-Long-Base/
tokenizer: /datasets/pretrained-llms/OctoThinker-8B-Long-Base/
tokenizer_family: llama
use_v2_tokenizer: true

---

name: octothinker_8b_short
model_family: llama
model_arch: llama3_1_8b
checkpoint: /datasets/pretrained-llms/OctoThinker-8B-Short-Base/
tokenizer: /datasets/pretrained-llms/OctoThinker-8B-Short-Base/
tokenizer_family: llama
use_v2_tokenizer: true
36 changes: 36 additions & 0 deletions src/fairseq2/recipes/lm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -42,12 +42,24 @@
from fairseq2.recipes.lm._online_finetune._generative_judge import (
GeneralVerifierExtractorHandler as GeneralVerifierExtractorHandler,
)
from fairseq2.recipes.lm._online_finetune._generative_judge import (
PrincipiaExtractor as PrincipiaExtractor,
)
from fairseq2.recipes.lm._online_finetune._generative_judge import (
PrincipiaExtractorHandler as PrincipiaExtractorHandler,
)
from fairseq2.recipes.lm._online_finetune._generative_judge import (
J1PairwiseScoreExtractor as J1PairwiseScoreExtractor,
)
from fairseq2.recipes.lm._online_finetune._generative_judge import (
J1PairwiseScoreExtractorHandler as J1PairwiseScoreExtractorHandler,
)
from fairseq2.recipes.lm._online_finetune._generative_judge import (
J1KwiseScoreExtractor as J1KwiseScoreExtractor,
)
from fairseq2.recipes.lm._online_finetune._generative_judge import (
J1KwiseScoreExtractorHandler as J1KwiseScoreExtractorHandler,
)
from fairseq2.recipes.lm._online_finetune._generative_judge import (
J1PointwiseExtractor as J1PointwiseExtractor,
)
Expand Down Expand Up @@ -84,6 +96,12 @@
from fairseq2.recipes.lm._online_finetune._remote_model import (
NoEnvGeneralVerifierPipeline as NoEnvGeneralVerifierPipeline,
)
from fairseq2.recipes.lm._online_finetune._remote_model import (
NoEnvAceMathRMPipeline as NoEnvAceMathRMPipeline,
)
from fairseq2.recipes.lm._online_finetune._remote_model import (
NoEnvSkyworkRMPipeline as NoEnvSkyworkRMPipeline,
)
from fairseq2.recipes.lm._online_finetune._remote_model import (
RemoteModelHandler as RemoteModelHandler,
)
Expand All @@ -93,6 +111,18 @@
from fairseq2.recipes.lm._online_finetune._rewards import (
AtheneVerifierHandler as AtheneVerifierHandler,
)
from fairseq2.recipes.lm._online_finetune._rewards import (
SkyworkVerifier as SkyworkVerifier,
)
from fairseq2.recipes.lm._online_finetune._rewards import (
SkyworkVerifierHandler as SkyworkVerifierHandler,
)
from fairseq2.recipes.lm._online_finetune._rewards import (
AceMathVerifier as AceMathVerifier,
)
from fairseq2.recipes.lm._online_finetune._rewards import (
AceMathVerifierHandler as AceMathVerifierHandler,
)
from fairseq2.recipes.lm._online_finetune._rewards import (
GenerativePairwiseVerifier as GenerativePairwiseVerifier,
)
Expand All @@ -105,6 +135,12 @@
from fairseq2.recipes.lm._online_finetune._rewards import (
GenerativePointwiseVerifierHandler as GenerativePointwiseVerifierHandler,
)
from fairseq2.recipes.lm._online_finetune._rewards import (
GenerativeKwiseVerifier as GenerativeKwiseVerifier,
)
from fairseq2.recipes.lm._online_finetune._rewards import (
GenerativeKwiseVerifierHandler as GenerativeKwiseVerifierHandler,
)
from fairseq2.recipes.lm._online_finetune._rewards import GSM8kVerifier as GSM8kVerifier
from fairseq2.recipes.lm._online_finetune._rewards import (
GSM8kVerifierHandler as GSM8kVerifierHandler,
Expand Down
132 changes: 123 additions & 9 deletions src/fairseq2/recipes/lm/_online_finetune/_common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand All @@ -8,6 +8,7 @@

import contextlib
import io
import re
from dataclasses import dataclass
from typing import List, cast

Expand All @@ -17,14 +18,8 @@
from torch import Tensor
from vllm import RequestOutput

from fairseq2.data import (
CollateOptionsOverride,
Collater,
SequenceData,
)
from fairseq2.datasets import (
SequenceBatch,
)
from fairseq2.data import CollateOptionsOverride, Collater, SequenceData
from fairseq2.datasets import SequenceBatch
from fairseq2.datasets.preference import PreferenceBatch
from fairseq2.datasets.prompt import PromptBatch
from fairseq2.gang import Gang, Gangs
Expand Down Expand Up @@ -93,9 +88,13 @@

seq_data = cast(SequenceData, collater(to_collate))

seq_lens = seq_data["seqs"]["seq_lens"]
assert isinstance(seq_lens, Tensor) or isinstance(seq_lens, list)
if isinstance(seq_lens, Tensor):
seq_lens = seq_lens.tolist()
batch = SequenceBatch(
seq_data["seqs"]["seqs"],
seq_data["seqs"]["seq_lens"],
seq_lens,
target_mask=seq_data["target_loss_mask"]["seqs"],
)
batch.to(device)
Expand Down Expand Up @@ -364,6 +363,54 @@

return responses

def get_vllm_logprobs(
vllm_outputs: List[RequestOutput],
gangs,
rollout_start_end: tuple[int, int] | None = None,
):
"""Compute per-token logprobs for selected continuations across a list of requests.

For each RequestOutput (one prompt) and each of its sampled continuations we
concatenate the prompt logprobs (skipping the first entry) with the generation
logprobs. All resulting sequences are then right-padded with 0.0 to the global
maximum length and stacked into a single tensor.

Parameters
----------
vllm_outputs:
List of vLLM RequestOutput objects (one per prompt).
gangs:
Fairseq2 gangs object (unused, kept for parity/extensibility).
rollout_start_end:
Optional (start, end) slice specifying which continuation indices to include
per prompt (used for micro-batching when forward_group_size < group_size).

Returns
-------
Tensor
Shape ``(num_selected_continuations, max_seq_len)`` with 0.0 padding.
"""
sequences: List[Tensor] = []
for request in vllm_outputs:
prompt_logprobs = [
list(d.values())[0].logprob for d in request.prompt_logprobs[1:]
]
outputs = request.outputs
if rollout_start_end is not None: # micro-batching
s, e = rollout_start_end
outputs = outputs[s:e]
for output in outputs:
gen_logprobs = [list(d.values())[0].logprob for d in output.logprobs]
seq = torch.tensor(prompt_logprobs + gen_logprobs)
sequences.append(seq)

max_len = max(t.size(0) for t in sequences)
padded = torch.zeros(len(sequences), max_len)
for i, t in enumerate(sequences):
padded[i, : t.size(0)] = t

return padded


def convert_vllm_output_to_ref_score(vllm_outputs: List[RequestOutput], gangs):
ref_scores = []
Expand Down Expand Up @@ -395,6 +442,8 @@
prompt = prompt_batch.meta_info.get("prompt_raw")[0]
elif "raw_prompt" in prompt_batch.meta_info:
prompt = prompt_batch.meta_info.get("raw_prompt")[0]
elif "problem" in prompt_batch.meta_info:
prompt = prompt_batch.meta_info.get("problem")[0]
else:
# raw text prompt doesn't exist for this dataset
prompt = "DUMMY PROMPT"
Expand All @@ -416,6 +465,57 @@
return rollout_lengths


def compute_reward_agreement_metrics(rewards: Tensor) -> dict[str, Tensor]:
    """Compute agreement metrics over the rollouts of each prompt.

    Args:
        rewards: Tensor of shape [Batch, Rollouts] containing binary rewards (0 or 1)

    Returns:
        Dictionary containing:
        - 'frac_all_correct': Fraction of prompts where all rollouts got reward=1
        - 'frac_all_incorrect': Fraction of prompts where all rollouts got reward=0
    """
    # Reduce over the rollout dimension first: a prompt counts only when every
    # one of its rollouts received the same extreme reward.
    unanimous_correct = rewards.eq(1).all(dim=1)  # [Batch]
    unanimous_incorrect = rewards.eq(0).all(dim=1)  # [Batch]

    return {
        "frac_all_correct": unanimous_correct.float().mean(),
        "frac_all_incorrect": unanimous_incorrect.float().mean(),
    }


def strip_think_tokens(rollouts: List[SequenceData]):
    """Remove ``<think>...</think>`` spans from every rollout's text in place.

    Also logs how often think tokens appeared and what fraction of rollouts
    finished cleanly (``stop``) versus were truncated (``length``) — a
    truncated rollout typically still contains an unterminated think span
    that the regex cannot remove.

    Args:
        rollouts: Elements expose ``.outputs`` whose items carry ``.text`` and
            ``.finish_reason``.
            NOTE(review): the annotation says ``SequenceData`` but the access
            pattern matches vLLM ``RequestOutput`` — confirm and fix the hint.

    Returns:
        The same list, with each rollout's ``text`` rewritten.
    """
    count_stripped, count_not_stripped, total_count, think_present = 0, 0, 0, 0
    for sample in rollouts:
        for rollout in sample.outputs:
            rollout_text = rollout.text
            if "<think>" in rollout_text:
                think_present += 1
            if rollout.finish_reason == "length":
                count_not_stripped += 1
            if rollout.finish_reason == "stop":
                count_stripped += 1
            total_count += 1
            # DOTALL lets a span cross newlines; the non-greedy quantifier
            # keeps multiple think blocks from being merged into one match.
            rollout.text = re.sub(
                r"<think>.*?</think>", "", rollout_text, flags=re.DOTALL
            ).strip()

    log.info(f"Total count: {total_count}")
    log.info(f"Think present: {think_present}")
    # Guard against ZeroDivisionError when the batch carries no rollouts.
    if total_count > 0:
        log.info(f"Count stripped: {count_stripped/total_count}")
        log.info(f"Count not stripped: {count_not_stripped/total_count}")

    return rollouts


class StatefulRolloutBag:
"""A stateful container for managing and reusing model rollouts across multiple micro-batches.

Expand Down Expand Up @@ -505,6 +605,20 @@
def update_avg_reward(metric_bag: MetricBag, avg_reward):
metric_bag.get(Mean, "avg_reward").update(avg_reward, weight=1)

@torch.inference_mode()
def update_reward_matches(metric_bag: MetricBag, reward_matches):
    """Fold the reward-agreement value into the ``reward_matches`` mean metric."""
    mean_metric = metric_bag.get(Mean, "reward_matches")
    mean_metric.update(reward_matches, weight=1)


@torch.inference_mode()
def update_frac_all_correct(metric_bag: MetricBag, frac_all_correct):
    """Fold the all-rollouts-correct fraction into the ``frac_all_correct`` mean metric."""
    mean_metric = metric_bag.get(Mean, "frac_all_correct")
    mean_metric.update(frac_all_correct, weight=1)


@torch.inference_mode()
def update_frac_all_incorrect(metric_bag: MetricBag, frac_all_incorrect):
    """Fold the all-rollouts-incorrect fraction into the ``frac_all_incorrect`` mean metric."""
    mean_metric = metric_bag.get(Mean, "frac_all_incorrect")
    mean_metric.update(frac_all_incorrect, weight=1)


@torch.inference_mode()
def update_std_reward(metric_bag: MetricBag, std_reward):
Expand Down
Loading
Loading