Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion examples/grm_vl_rl/train_colocate.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,11 @@ def train(args: argparse.Namespace) -> None:
save_hf_ckpt=args.save_hf_ckpt,
disable_ds_ckpt=args.disable_ds_ckpt,
packing_samples=args.packing_samples,
# overlong_reward
# DAPO dynamic sampling
dynamic_sampling=args.dynamic_sampling,
dynamic_sampling_metric=getattr(args, 'dynamic_sampling_metric', 'reward'),
max_num_gen_batches=getattr(args, 'max_num_gen_batches', 10),
# overlong_reward
overlong_buffer=args.overlong_buffer,
overlong_buffer_len=args.overlong_buffer_len,
overlong_buffer_penalty_factor=args.overlong_buffer_penalty_factor,
Expand Down Expand Up @@ -365,6 +368,8 @@ def train(args: argparse.Namespace) -> None:

# DAPO
parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy")
parser.add_argument("--dynamic_sampling_metric", type=str, default="reward", choices=["reward", "acc"], help="Metric for dynamic sampling group filtering")
parser.add_argument("--max_num_gen_batches", type=int, default=10, help="Max generation batches for dynamic sampling accumulation")
parser.add_argument("--overlong_buffer", action="store_true", default=False, help="Apply overlong sequence buffer in DAPO")
parser.add_argument("--overlong_buffer_len", type=int, default=1024, help="Max token threshold for overlong buffer")
parser.add_argument("--overlong_buffer_penalty_factor", type=float, default=1.0, help="Penalty scaling factor for overlong sequences, <1 discourages long outputs; >1 encourages them")
Expand Down
9 changes: 7 additions & 2 deletions examples/gsm8k_geo3k/train_colocate.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,11 @@ def train(args):
save_hf_ckpt=args.save_hf_ckpt,
disable_ds_ckpt=args.disable_ds_ckpt,
packing_samples=args.packing_samples,
# overlong_reward
# DAPO dynamic sampling
dynamic_sampling=args.dynamic_sampling,
dynamic_sampling_metric=args.dynamic_sampling_metric,
max_num_gen_batches=args.max_num_gen_batches,
# overlong_reward
overlong_buffer=args.overlong_buffer,
overlong_buffer_len=args.overlong_buffer_len,
overlong_buffer_penalty_factor=args.overlong_buffer_penalty_factor,
Expand Down Expand Up @@ -466,7 +469,9 @@ def train(args):
parser.add_argument("--load_checkpoint", action="store_true", default=False)

# DAPO
parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy")
parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy: filter out prompt groups with zero metric variance and accumulate until train_batch_size is reached")
parser.add_argument("--dynamic_sampling_metric", type=str, default="reward", choices=["reward", "acc"], help="Metric for dynamic sampling group filtering (default: reward)")
parser.add_argument("--max_num_gen_batches", type=int, default=10, help="Max generation batches for dynamic sampling accumulation. Non-positive means no limit (default: 10)")
parser.add_argument("--overlong_buffer", action="store_true", default=False, help="Apply overlong sequence buffer in DAPO")
parser.add_argument("--overlong_buffer_len", type=int, default=1024, help="Max token threshold for overlong buffer")
parser.add_argument("--overlong_buffer_penalty_factor", type=float, default=1.0, help="Penalty scaling factor for overlong sequences, <1 discourages long outputs; >1 encourages them")
Expand Down
9 changes: 7 additions & 2 deletions examples/r1_aqa/train_colocate.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,11 @@ def train(args):
save_hf_ckpt=args.save_hf_ckpt,
disable_ds_ckpt=args.disable_ds_ckpt,
packing_samples=args.packing_samples,
# DAPO / overlong
# DAPO dynamic sampling
dynamic_sampling=args.dynamic_sampling,
dynamic_sampling_metric=getattr(args, 'dynamic_sampling_metric', 'reward'),
max_num_gen_batches=getattr(args, 'max_num_gen_batches', 10),
# overlong_reward
overlong_buffer=args.overlong_buffer,
overlong_buffer_len=args.overlong_buffer_len,
overlong_buffer_penalty_factor=args.overlong_buffer_penalty_factor,
Expand Down Expand Up @@ -388,7 +391,9 @@ def train(args):
parser.add_argument("--load_checkpoint", action="store_true", default=False)

# DAPO
parser.add_argument("--dynamic_sampling", action="store_true", default=False)
parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy")
parser.add_argument("--dynamic_sampling_metric", type=str, default="reward", choices=["reward", "acc"], help="Metric for dynamic sampling group filtering")
parser.add_argument("--max_num_gen_batches", type=int, default=10, help="Max generation batches for dynamic sampling accumulation")
parser.add_argument("--overlong_buffer", action="store_true", default=False)
parser.add_argument("--overlong_buffer_len", type=int, default=1024)
parser.add_argument("--overlong_buffer_penalty_factor", type=float, default=1.0)
Expand Down
7 changes: 6 additions & 1 deletion lightrft/strategy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ class StrategyConfig:
# Dynamic sampling and advantage estimation
# (bool): Enable dynamic sampling for advantage estimation, defaults to False
dynamic_sampling: bool = False
# (str): Metric used for dynamic sampling group filtering ("acc", "reward"), defaults to "reward"
dynamic_sampling_metric: str = "reward"
# (int): Max number of generation batches for dynamic sampling accumulation.
# Non-positive values mean no upper limit. defaults to 10
max_num_gen_batches: int = 10
# (str): Advantage estimator method, defaults to "gae"
advantage_estimator: str = "group_norm"

Expand Down Expand Up @@ -280,7 +285,7 @@ def print_config_summary(self) -> None:

# Dynamic Sampling and Advantage Estimation Parameters
print("\nDynamic Sampling and Advantage Estimation Parameters:")
for attr in ['dynamic_sampling', 'advantage_estimator']:
for attr in ['dynamic_sampling', 'dynamic_sampling_metric', 'max_num_gen_batches', 'advantage_estimator']:
current = getattr(self, attr)
default = getattr(default_config, attr)
status = "Overridden" if current != default else "Default"
Expand Down
Loading
Loading