From 2fdf00344635763773b02084131850703d33ccfd Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Sat, 7 Mar 2026 16:09:41 +0800 Subject: [PATCH 1/4] feature(sunjx): implement dynamic sampling --- examples/grm_vl_rl/train_colocate.py | 7 +- examples/gsm8k_geo3k/train_colocate.py | 9 +- examples/r1_aqa/train_colocate.py | 9 +- lightrft/strategy/config.py | 7 +- lightrft/trainer/ppo_trainer_vl.py | 534 ++++++++++++++++++------- 5 files changed, 416 insertions(+), 150 deletions(-) diff --git a/examples/grm_vl_rl/train_colocate.py b/examples/grm_vl_rl/train_colocate.py index 7005195b..54d5a978 100644 --- a/examples/grm_vl_rl/train_colocate.py +++ b/examples/grm_vl_rl/train_colocate.py @@ -314,8 +314,11 @@ def train(args: argparse.Namespace) -> None: save_hf_ckpt=args.save_hf_ckpt, disable_ds_ckpt=args.disable_ds_ckpt, packing_samples=args.packing_samples, - # overlong_reward + # DAPO dynamic sampling dynamic_sampling=args.dynamic_sampling, + dynamic_sampling_metric=getattr(args, 'dynamic_sampling_metric', 'reward'), + max_num_gen_batches=getattr(args, 'max_num_gen_batches', 10), + # overlong_reward overlong_buffer=args.overlong_buffer, overlong_buffer_len=args.overlong_buffer_len, overlong_buffer_penalty_factor=args.overlong_buffer_penalty_factor, @@ -365,6 +368,8 @@ def train(args: argparse.Namespace) -> None: # DAPO parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy") + parser.add_argument("--dynamic_sampling_metric", type=str, default="reward", choices=["reward", "acc"], help="Metric for dynamic sampling group filtering") + parser.add_argument("--max_num_gen_batches", type=int, default=10, help="Max generation batches for dynamic sampling accumulation") parser.add_argument("--overlong_buffer", action="store_true", default=False, help="Apply overlong sequence buffer in DAPO") parser.add_argument("--overlong_buffer_len", type=int, default=1024, help="Max token threshold for overlong buffer") parser.add_argument("--overlong_buffer_penalty_factor", type=float, default=1.0, help="Penalty scaling factor for overlong sequences, <1 discourages long outputs; >1 encourages them") diff --git a/examples/gsm8k_geo3k/train_colocate.py b/examples/gsm8k_geo3k/train_colocate.py index 01eeb3c3..e96be89c 100644 --- a/examples/gsm8k_geo3k/train_colocate.py +++ b/examples/gsm8k_geo3k/train_colocate.py @@ -417,8 +417,11 @@ def train(args): save_hf_ckpt=args.save_hf_ckpt, disable_ds_ckpt=args.disable_ds_ckpt, packing_samples=args.packing_samples, - # overlong_reward + # DAPO dynamic sampling dynamic_sampling=args.dynamic_sampling, + dynamic_sampling_metric=args.dynamic_sampling_metric, + max_num_gen_batches=args.max_num_gen_batches, + # overlong_reward overlong_buffer=args.overlong_buffer, overlong_buffer_len=args.overlong_buffer_len, overlong_buffer_penalty_factor=args.overlong_buffer_penalty_factor, @@ -466,7 +469,9 @@ def train(args): parser.add_argument("--load_checkpoint", action="store_true", default=False) # DAPO - parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy") + parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy: filter out prompt groups with zero metric variance and accumulate until train_batch_size is reached") + parser.add_argument("--dynamic_sampling_metric", type=str, default="reward", choices=["reward", "acc"], help="Metric for dynamic sampling group filtering (default: reward)") + parser.add_argument("--max_num_gen_batches", type=int, default=10, help="Max generation batches for dynamic sampling accumulation. Non-positive means no limit (default: 10)") parser.add_argument("--overlong_buffer", action="store_true", default=False, help="Apply overlong sequence buffer in DAPO") parser.add_argument("--overlong_buffer_len", type=int, default=1024, help="Max token threshold for overlong buffer") parser.add_argument("--overlong_buffer_penalty_factor", type=float, default=1.0, help="Penalty scaling factor for overlong sequences, <1 discourages long outputs; >1 encourages them") diff --git a/examples/r1_aqa/train_colocate.py b/examples/r1_aqa/train_colocate.py index fee2fb23..e391fa08 100644 --- a/examples/r1_aqa/train_colocate.py +++ b/examples/r1_aqa/train_colocate.py @@ -329,8 +329,11 @@ def train(args): save_hf_ckpt=args.save_hf_ckpt, disable_ds_ckpt=args.disable_ds_ckpt, packing_samples=args.packing_samples, - # DAPO / overlong + # DAPO dynamic sampling dynamic_sampling=args.dynamic_sampling, + dynamic_sampling_metric=getattr(args, 'dynamic_sampling_metric', 'reward'), + max_num_gen_batches=getattr(args, 'max_num_gen_batches', 10), + # overlong_reward overlong_buffer=args.overlong_buffer, overlong_buffer_len=args.overlong_buffer_len, overlong_buffer_penalty_factor=args.overlong_buffer_penalty_factor, @@ -388,7 +391,9 @@ def train(args): parser.add_argument("--load_checkpoint", action="store_true", default=False) # DAPO - parser.add_argument("--dynamic_sampling", action="store_true", default=False) + parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy") + parser.add_argument("--dynamic_sampling_metric", type=str, default="reward", choices=["reward", "acc"], help="Metric for dynamic sampling group filtering") + parser.add_argument("--max_num_gen_batches", type=int, default=10, help="Max generation batches for dynamic sampling accumulation") parser.add_argument("--overlong_buffer", action="store_true", default=False) parser.add_argument("--overlong_buffer_len", type=int, default=1024) parser.add_argument("--overlong_buffer_penalty_factor", type=float, default=1.0) diff --git a/lightrft/strategy/config.py b/lightrft/strategy/config.py index c6993005..50feb138 100644 --- a/lightrft/strategy/config.py +++ b/lightrft/strategy/config.py @@ -112,6 +112,11 @@ class StrategyConfig: # Dynamic sampling and advantage estimation # (bool): Enable dynamic sampling for advantage estimation, defaults to False dynamic_sampling: bool = False + # (str): Metric used for dynamic sampling group filtering ("acc", "reward"), defaults to "reward" + dynamic_sampling_metric: str = "reward" + # (int): Max number of generation batches for dynamic sampling accumulation. + # Non-positive values mean no upper limit. defaults to 10 + max_num_gen_batches: int = 10 # (str): Advantage estimator method, defaults to "gae" advantage_estimator: str = "group_norm" @@ -280,7 +285,7 @@ def print_config_summary(self) -> None: # Dynamic Sampling and Advantage Estimation Parameters print("\nDynamic Sampling and Advantage Estimation Parameters:") - for attr in ['dynamic_sampling', 'advantage_estimator']: + for attr in ['dynamic_sampling', 'dynamic_sampling_metric', 'max_num_gen_batches', 'advantage_estimator']: current = getattr(self, attr) default = getattr(default_config, attr) status = "Overridden" if current != default else "Default" diff --git a/lightrft/trainer/ppo_trainer_vl.py b/lightrft/trainer/ppo_trainer_vl.py index 7e998817..5fa820d3 100644 --- a/lightrft/trainer/ppo_trainer_vl.py +++ b/lightrft/trainer/ppo_trainer_vl.py @@ -6,6 +6,7 @@ import torch import math +import numpy as np import torch.nn as nn from torch.optim import Optimizer from torch.utils.data import DataLoader @@ -16,6 +17,7 @@ from lightrft.models.utils import masked_mean, unpacking_samples, compute_approx_kl from lightrft.utils.distributed_sampler import DistributedSampler from lightrft.trainer import AdaptiveKLController, ExperienceVL, FixedKLController, NaiveExperienceMakerVL, NaiveReplayBufferVL # noqa +from lightrft.trainer.replay_buffer_utils import split_experience_batch, remove_padding_in_sequences class PPOTrainerVL(ABC): @@ -252,6 +254,132 @@ def __init__( log_dir = os.path.join(self.strategy.args.use_tensorboard, strategy.args.wandb_run_name) self._tensorboard = SummaryWriter(log_dir=log_dir) + def _filter_buffer_items_by_metric(self, items, n_samples_per_prompt, metric_name="reward"): + """ + Filter replay buffer items by group metric variance (DAPO dynamic sampling). + + Groups items by prompt (every n_samples_per_prompt consecutive items form a group). + Keeps only groups where the metric has non-zero variance (i.e., not all-correct or all-wrong). + + :param items: List of BufferItemVL from replay buffer + :type items: List + :param n_samples_per_prompt: Number of samples per prompt + :type n_samples_per_prompt: int + :param metric_name: Metric to use for filtering ("reward" or "acc") + :type metric_name: str + :return: Tuple of (kept_items, num_kept_prompts, num_total_prompts, num_filtered_prompts) + :rtype: Tuple[List, int, int, int] + """ + n = n_samples_per_prompt + num_groups = len(items) // n + kept_items = [] + num_filtered = 0 + + for g in range(num_groups): + group = items[g * n:(g + 1) * n] + + metric_vals = [] + for item in group: + if metric_name == "reward" and hasattr(item, 'info') and item.info is not None and 'reward' in item.info: + val = item.info['reward'] + if isinstance(val, torch.Tensor): + metric_vals.append(val.float().mean().item()) + else: + metric_vals.append(float(val)) + elif metric_name == "acc" and hasattr(item, 'info') and item.info is not None and 'reward_metrics' in item.info: + rm = item.info.get('reward_metrics') + if rm is not None and 'accuracy_reward' in rm: + val = rm['accuracy_reward'] + if isinstance(val, torch.Tensor): + metric_vals.append(val.float().mean().item()) + else: + metric_vals.append(float(val)) + else: + metric_vals.append(0.0) + else: + metric_vals.append(0.0) + + metric_std = np.std(metric_vals) + + if metric_std > 1e-8 or n == 1: + kept_items.extend(group) + else: + num_filtered += 1 + + return kept_items, num_groups - num_filtered, num_groups, num_filtered + + def _collect_rollout_status(self, items): + """ + Collect rollout statistics from replay buffer items. + + :param items: List of experience items + :type items: List + :return: Dictionary of rollout statistics + :rtype: dict + """ + rollout_status = {} + if not items: + return rollout_status + + all_rewards = [] + all_format_rewards = [] + all_accuracy_rewards = [] + all_response_lengths = [] + + for item in items: + if hasattr(item, 'info') and item.info is not None and 'reward' in item.info: + all_rewards.append(item.info['reward']) + + if ( + hasattr(item, 'info') and item.info is not None and 'reward_metrics' in item.info + and item.info['reward_metrics'] is not None + ): + reward_metrics = item.info['reward_metrics'] + if 'format_reward' in reward_metrics: + all_format_rewards.append(reward_metrics['format_reward']) + if 'accuracy_reward' in reward_metrics: + all_accuracy_rewards.append(reward_metrics['accuracy_reward']) + + if hasattr(item, 'info') and item.info is not None and 'response_length' in item.info: + all_response_lengths.append(item.info['response_length']) + + device = torch.cuda.current_device() + + if all_rewards: + if isinstance(all_rewards[0], torch.Tensor): + rewards_tensor = torch.cat([t.to(device).float() for t in all_rewards]) + else: + rewards_tensor = torch.tensor(all_rewards, dtype=torch.float32, device=device) + rollout_status["rollout_reward"] = rewards_tensor.mean().item() + rollout_status["rollout_reward_std"] = rewards_tensor.std().item() + + if all_format_rewards: + if isinstance(all_format_rewards[0], torch.Tensor): + format_tensor = torch.cat([t.to(device).float() for t in all_format_rewards]) + else: + format_tensor = torch.tensor(all_format_rewards, dtype=torch.float32, device=device) + mean_format_reward = format_tensor.mean().item() + if abs(mean_format_reward) > 1e-6: + rollout_status["rollout_format_reward"] = mean_format_reward + + if all_accuracy_rewards: + if isinstance(all_accuracy_rewards[0], torch.Tensor): + accuracy_tensor = torch.cat([t.to(device).float() for t in all_accuracy_rewards]) + else: + accuracy_tensor = torch.tensor(all_accuracy_rewards, dtype=torch.float32, device=device) + mean_accuracy_reward = accuracy_tensor.mean().item() + if abs(mean_accuracy_reward) > 1e-6: + rollout_status["rollout_accuracy_reward"] = mean_accuracy_reward + + if all_response_lengths: + if isinstance(all_response_lengths[0], torch.Tensor): + lengths_tensor = torch.cat([t.to(device).float() for t in all_response_lengths]) + else: + lengths_tensor = torch.tensor(all_response_lengths, dtype=torch.float32, device=device) + rollout_status["rollout_response_length"] = lengths_tensor.mean().item() + + return rollout_status + def fit( self, args, @@ -262,7 +390,11 @@ def fit( num_update_steps_per_episodes=1, ) -> None: """ - Main training loop for PPO. + Main training loop for PPO with optional DAPO dynamic sampling. + + When dynamic_sampling is enabled, the loop filters out prompt groups with zero + metric variance and accumulates qualified groups across multiple generation batches + until train_batch_size is reached or max_num_gen_batches is hit. :param args: Training arguments. :type args: Namespace @@ -282,8 +414,29 @@ def fit( samples_per_rollout = args.rollout_batch_size * args.n_samples_per_prompt samples_per_train = args.train_batch_size * args.n_samples_per_prompt + # Dynamic sampling configuration + use_dynamic_sampling = getattr(args, 'dynamic_sampling', False) + dynamic_sampling_metric = getattr(args, 'dynamic_sampling_metric', 'reward') + max_num_gen_batches = getattr(args, 'max_num_gen_batches', 10) + # Print training mode information - if args.train_batch_size < args.rollout_batch_size: + if use_dynamic_sampling: + self.strategy.print( + f"\n{'=' * 80}\n" + f"DYNAMIC SAMPLING MODE (DAPO)\n" + f"{'=' * 80}\n" + f"Configuration:\n" + f" - Metric: {dynamic_sampling_metric}\n" + f" - Max generation batches: {max_num_gen_batches} ({'unlimited' if max_num_gen_batches <= 0 else max_num_gen_batches})\n" + f" - train_batch_size: {args.train_batch_size} prompts\n" + f" - rollout_batch_size: {args.rollout_batch_size} prompts per generation\n" + f" - n_samples_per_prompt: {args.n_samples_per_prompt}\n" + f"Behavior:\n" + f" - Groups with zero metric variance (all-correct/all-wrong) are filtered out.\n" + f" - Generation repeats until {args.train_batch_size} qualified prompts are accumulated.\n" + f"{'=' * 80}\n" + ) + elif args.train_batch_size < args.rollout_batch_size: updates_per_rollout = samples_per_rollout / samples_per_train self.strategy.print( f"\n{'=' * 80}\n" @@ -306,14 +459,6 @@ def fit( ) # Calculate number of rollouts per episode. - # Regardless of TBS and RBS relationship, rollout count should be determined by "total data / rollout size". - # Numerator (num_update_steps * train_batch_size) equals "total samples planned for this episode". - # Denominator (rollout_batch_size * n_samples) equals "samples produced per rollout". - # This calculation ensures data collection volume is constant. - # When TBS=64, num_update_steps is naturally twice as large as when TBS=128. - # Substituting into formula: (2N * 0.5T) / R = (N * T) / R. - # Conclusion: Rollout count unchanged, but internal update loop count doubles due to smaller TBS. - num_rollouts_per_episodes = ( num_update_steps_per_episodes * args.train_batch_size // args.max_epochs // args.rollout_batch_size // args.n_samples_per_prompt @@ -321,7 +466,6 @@ def fit( # Safeguard to prevent num_rollouts_per_episodes from being 0 if num_rollouts_per_episodes == 0: - # Try recalculating with ceil to prevent fractional values from being discarded by integer division val = (num_update_steps_per_episodes * args.train_batch_size) / (args.max_epochs * args.rollout_batch_size * args.n_samples_per_prompt) num_rollouts_per_episodes = math.ceil(val) @@ -356,162 +500,264 @@ def fit( disable=not self.strategy.is_rank_0(), ) - for batch in self.prompts_dataloader: - # Compatible with both image-only (4 args) and video (5 args) dataloaders - if len(batch) == 5: - rand_prompts, rand_images, rand_videos, rand_references, rand_labels = batch - else: - rand_prompts, rand_images, rand_references, rand_labels = batch - rand_videos = None + if use_dynamic_sampling: + # ============================================================ + # DYNAMIC SAMPLING PATH (DAPO) + # Accumulate qualified prompt groups across multiple generation + # batches until train_batch_size is reached. + # + # Flow: + # 1. Generate rollouts → append to temp buffer (split into items) + # 2. Filter items by prompt group metric variance + # 3. Accumulate kept items across generation batches + # 4. When enough qualified prompts → train + # ============================================================ + accumulated_items = [] # List of BufferItemVL + num_qualified_prompts = 0 + num_gen_batches = 0 + total_generated_prompts = 0 + total_filtered_prompts = 0 + + for batch in self.prompts_dataloader: + if len(batch) == 5: + rand_prompts, rand_images, rand_videos, rand_references, rand_labels = batch + else: + rand_prompts, rand_images, rand_references, rand_labels = batch + rand_videos = None + + self.strategy.print( + f"[DynSamp] Gen batch {num_gen_batches + 1}: generating rollouts for {len(rand_prompts)} prompts..." + ) - # TODO: Remove debug print - self.strategy.print( - f"rand_prompts:\n {rand_prompts}\n , rand_images:{rand_images}\n , rand_references:{rand_references}\n, rand_labels:{rand_labels}\n " # noqa - ) + # Generate experiences and collect individual items + batch_items = [] + for i, experience in enumerate( + self.experience_maker.make_experience_list( + rand_prompts, + rand_images, + all_videos=rand_videos, + all_references=rand_references, + all_labels=rand_labels, + **self.generate_kwargs + ) + ): + if i == 0: + output = self.tokenizer.batch_decode( + experience.sequences[0].unsqueeze(0), skip_special_tokens=True + ) + self.strategy.print("collect phase: experience.sequences w skip_special_tokens: ", output) + + # Split batched experience into individual items (same as replay_buffer.append) + experience.to_device(torch.device("cpu")) + items = split_experience_batch(experience) + items = remove_padding_in_sequences(items) + batch_items.extend(items) + + num_gen_batches += 1 + + # Filter by metric variance at the item level + kept, num_kept, num_total, num_filtered = self._filter_buffer_items_by_metric( + batch_items, args.n_samples_per_prompt, metric_name=dynamic_sampling_metric + ) - for i, experience in enumerate( - self.experience_maker.make_experience_list( - rand_prompts, - rand_images, - all_videos=rand_videos, - all_references=rand_references, - all_labels=rand_labels, - **self.generate_kwargs + accumulated_items.extend(kept) + num_qualified_prompts += num_kept + total_generated_prompts += num_total + total_filtered_prompts += num_filtered + + self.strategy.print( + f"[DynSamp] Batch {num_gen_batches}: kept {num_kept}/{num_total} prompts " + f"(filtered {num_filtered}). " + f"Accumulated: {num_qualified_prompts}/{args.train_batch_size} qualified prompts." ) - ): - if i == 0: - output = self.tokenizer.batch_decode( - experience.sequences[0].unsqueeze(0), skip_special_tokens=True + + # Check if we have enough qualified prompts + if num_qualified_prompts >= args.train_batch_size: + # Trim to exact train_batch_size * n_samples_per_prompt + target_num_items = args.train_batch_size * args.n_samples_per_prompt + accumulated_items = accumulated_items[:target_num_items] + num_qualified_prompts = args.train_batch_size + + self.strategy.print( + f"[DynSamp] Reached train_batch_size={args.train_batch_size} after " + f"{num_gen_batches} generation batches. " + f"Total generated: {total_generated_prompts}, filtered: {total_filtered_prompts} " + f"({total_filtered_prompts / max(total_generated_prompts, 1) * 100:.1f}% filtered)." ) - self.strategy.print("collect phase: experience.sequences w skip_special_tokens: ", output) + + # Add filtered items directly to replay buffer + self.replay_buffer.items.extend(accumulated_items) + + self.strategy.report_memory('after replay_buffer ready (dynamic sampling)') + + # Collect rollout statistics + rollout_status = self._collect_rollout_status(self.replay_buffer.items) + rollout_status["ds_gen_batches"] = num_gen_batches + rollout_status["ds_filter_rate"] = total_filtered_prompts / max(total_generated_prompts, 1) + rollout_status["ds_qualified_prompts"] = num_qualified_prompts + + if self.args.advantage_estimator != "group_norm": + self.replay_buffer.normalize("advantages", self.strategy) + + self.strategy.report_memory('before train') + status = self.ppo_train(steps) + + self.strategy.report_memory('before clear buffer') + self.replay_buffer.clear() + self.strategy.report_memory('after train') + + if "kl" in status: + self.kl_ctl.update(status["kl"], args.train_batch_size * args.n_samples_per_prompt) + + pbar.set_postfix(rollout_status) + + client_states = {"consumed_samples": steps * args.rollout_batch_size} + logs_dict_combined = {**rollout_status, **status} + self.save_logs_and_checkpoints(args, steps, pbar, logs_dict_combined, client_states, episode=episode) + + pbar.update() + steps += 1 + + # Reset accumulation for next training step + accumulated_items = [] + num_qualified_prompts = 0 + num_gen_batches = 0 + total_generated_prompts = 0 + total_filtered_prompts = 0 + + elif max_num_gen_batches > 0 and num_gen_batches >= max_num_gen_batches: + # Reached max generation batches without enough qualified prompts self.strategy.print( - f"collect phase: rand_prompts:\n {rand_prompts[0:2]}\n , rand_images:{rand_images[0:2]}\n , rand_references:{rand_references[0:2]}\n, rand_labels:{rand_labels[0:2]}\n " # noqa + f"[DynSamp] WARNING: Reached max_num_gen_batches={max_num_gen_batches} " + f"with only {num_qualified_prompts}/{args.train_batch_size} qualified prompts. " + f"Training with available data." ) - # print all - # self.strategy.print( - # f"rand_prompts:\n {rand_prompts}\n , rand_images:{rand_images}\n , rand_references:{rand_references}\n, rand_labels:{rand_labels}\n " # noqa - # ) - - self.replay_buffer.append(experience) - - self.strategy.report_memory('after replay_buffer ready') - - # Aggregate rollout statistics from replay buffer - # Collect metrics from the rollout/collection phase - rollout_status = {} - if self.replay_buffer.items: - all_rewards = [] - all_format_rewards = [] - all_accuracy_rewards = [] - all_response_lengths = [] - - for item in self.replay_buffer.items: - # Collect rewards from rollout - if hasattr(item, 'info') and item.info is not None and 'reward' in item.info: - all_rewards.append(item.info['reward']) - - # Robust handling of reward_metrics - # 1. Check if info exists - # 2. Check if 'reward_metrics' key exists - # 3. Check if reward_metrics is not None (critical!) - if ( - hasattr(item, 'info') and item.info is not None and 'reward_metrics' in item.info - and item.info['reward_metrics'] is not None - ): - - reward_metrics = item.info['reward_metrics'] - - # Safely extract sub-metrics - if 'format_reward' in reward_metrics: - all_format_rewards.append(reward_metrics['format_reward']) - if 'accuracy_reward' in reward_metrics: - all_accuracy_rewards.append(reward_metrics['accuracy_reward']) - - # Collect response lengths from rollout - if hasattr(item, 'info') and item.info is not None and 'response_length' in item.info: - all_response_lengths.append(item.info['response_length']) - - # Compute rollout statistics - device = torch.cuda.current_device() - - if all_rewards: - # [TENSOR-FIX] Handle both tensor lists and scalar lists - if isinstance(all_rewards[0], torch.Tensor): - rewards_tensor = torch.cat([t.to(device).float() for t in all_rewards]) - else: - rewards_tensor = torch.tensor(all_rewards, dtype=torch.float32, device=device) - rollout_status["rollout_reward"] = rewards_tensor.mean().item() - rollout_status["rollout_reward_std"] = rewards_tensor.std().item() - - if all_format_rewards: - # [TENSOR-FIX] Handle both tensor lists and scalar lists - # Issue: all_format_rewards may contain tensors (from reward_metrics), - # but torch.tensor() cannot convert a list of tensors directly. - # Solution: Use torch.cat() for tensor lists, torch.tensor() for scalar lists - if isinstance(all_format_rewards[0], torch.Tensor): - # List of tensors: concatenate them - format_tensor = torch.cat([t.to(device).float() for t in all_format_rewards]) - else: - # List of scalars: convert to tensor - format_tensor = torch.tensor(all_format_rewards, dtype=torch.float32, device=device) - mean_format_reward = format_tensor.mean().item() + if accumulated_items: + self.replay_buffer.items.extend(accumulated_items) - # Only display if mean is significantly non-zero - if abs(mean_format_reward) > 1e-6: - rollout_status["rollout_format_reward"] = mean_format_reward + rollout_status = self._collect_rollout_status(self.replay_buffer.items) + rollout_status["ds_gen_batches"] = num_gen_batches + rollout_status["ds_filter_rate"] = total_filtered_prompts / max(total_generated_prompts, 1) + rollout_status["ds_qualified_prompts"] = num_qualified_prompts - if all_accuracy_rewards: - # [TENSOR-FIX] Handle both tensor lists and scalar lists - if isinstance(all_accuracy_rewards[0], torch.Tensor): - accuracy_tensor = torch.cat([t.to(device).float() for t in all_accuracy_rewards]) - else: - accuracy_tensor = torch.tensor(all_accuracy_rewards, dtype=torch.float32, device=device) + if self.args.advantage_estimator != "group_norm": + self.replay_buffer.normalize("advantages", self.strategy) - mean_accuracy_reward = accuracy_tensor.mean().item() + status = self.ppo_train(steps) + self.replay_buffer.clear() - # Only display if mean is significantly non-zero - if abs(mean_accuracy_reward) > 1e-6: - rollout_status["rollout_accuracy_reward"] = mean_accuracy_reward + if "kl" in status: + self.kl_ctl.update(status["kl"], len(accumulated_items)) - if all_response_lengths: - # [TENSOR-FIX] Handle both tensor lists and scalar lists - if isinstance(all_response_lengths[0], torch.Tensor): - lengths_tensor = torch.cat([t.to(device).float() for t in all_response_lengths]) - else: - lengths_tensor = torch.tensor(all_response_lengths, dtype=torch.float32, device=device) + pbar.set_postfix(rollout_status) + client_states = {"consumed_samples": steps * args.rollout_batch_size} + logs_dict_combined = {**rollout_status, **status} + self.save_logs_and_checkpoints(args, steps, pbar, logs_dict_combined, client_states, episode=episode) - rollout_status["rollout_response_length"] = lengths_tensor.mean().item() + pbar.update() + steps += 1 - # TODO: Check normalization behavior - if self.args.advantage_estimator != "group_norm": - self.replay_buffer.normalize("advantages", self.strategy) + # Reset + accumulated_items = [] + num_qualified_prompts = 0 + num_gen_batches = 0 + total_generated_prompts = 0 + total_filtered_prompts = 0 - self.strategy.report_memory('before train') + # Handle remaining accumulated items at end of episode + if accumulated_items: + self.strategy.print( + f"[DynSamp] End of episode: training with {num_qualified_prompts} remaining qualified prompts." + ) + self.replay_buffer.items.extend(accumulated_items) - status = self.ppo_train(steps) + rollout_status = self._collect_rollout_status(self.replay_buffer.items) + rollout_status["ds_gen_batches"] = num_gen_batches + rollout_status["ds_filter_rate"] = total_filtered_prompts / max(total_generated_prompts, 1) + rollout_status["ds_qualified_prompts"] = num_qualified_prompts - self.strategy.report_memory('before clear buffer') - self.replay_buffer.clear() + if self.args.advantage_estimator != "group_norm": + self.replay_buffer.normalize("advantages", self.strategy) - self.strategy.report_memory('after train') + status = self.ppo_train(steps) + self.replay_buffer.clear() - if "kl" in status: - self.kl_ctl.update(status["kl"], args.rollout_batch_size * args.n_samples_per_prompt) + if "kl" in status: + self.kl_ctl.update(status["kl"], len(accumulated_items)) + + pbar.set_postfix(rollout_status) + client_states = {"consumed_samples": steps * args.rollout_batch_size} + logs_dict_combined = {**rollout_status, **status} + self.save_logs_and_checkpoints(args, steps, pbar, logs_dict_combined, client_states, episode=episode) + + pbar.update() + steps += 1 + + else: + # ============================================================ + # STANDARD PATH (no dynamic sampling) + # ============================================================ + for batch in self.prompts_dataloader: + if len(batch) == 5: + rand_prompts, rand_images, rand_videos, rand_references, rand_labels = batch + else: + rand_prompts, rand_images, rand_references, rand_labels = batch + rand_videos = None + + self.strategy.print( + f"rand_prompts:\n {rand_prompts}\n , rand_images:{rand_images}\n , rand_references:{rand_references}\n, rand_labels:{rand_labels}\n " # noqa + ) + + for i, experience in enumerate( + self.experience_maker.make_experience_list( + rand_prompts, + rand_images, + all_videos=rand_videos, + all_references=rand_references, + all_labels=rand_labels, + **self.generate_kwargs + ) + ): + if i == 0: + output = self.tokenizer.batch_decode( + experience.sequences[0].unsqueeze(0), skip_special_tokens=True + ) + self.strategy.print("collect phase: experience.sequences w skip_special_tokens: ", output) + self.strategy.print( + f"collect phase: rand_prompts:\n {rand_prompts[0:2]}\n , rand_images:{rand_images[0:2]}\n , rand_references:{rand_references[0:2]}\n, rand_labels:{rand_labels[0:2]}\n " # noqa + ) + + self.replay_buffer.append(experience) + + self.strategy.report_memory('after replay_buffer ready') + + rollout_status = self._collect_rollout_status(self.replay_buffer.items) + + if self.args.advantage_estimator != "group_norm": + self.replay_buffer.normalize("advantages", self.strategy) + + self.strategy.report_memory('before train') + + status = self.ppo_train(steps) + + self.strategy.report_memory('before clear buffer') + self.replay_buffer.clear() + + self.strategy.report_memory('after train') + + if "kl" in status: + self.kl_ctl.update(status["kl"], args.rollout_batch_size * args.n_samples_per_prompt) - # Update Episode pbar with ROLLOUT statistics (not training statistics!) - pbar.set_postfix(rollout_status) + pbar.set_postfix(rollout_status) - # Logs/checkpoints: save BOTH ROLLOUT and TRAINING statistics to wandb - # [FIX] Merge rollout_status (from inference) and status (from training) - # to ensure wandb logs contain both types of metrics - client_states = {"consumed_samples": steps * args.rollout_batch_size} - logs_dict_combined = {**rollout_status, **status} # Merge: rollout first, training second + client_states = {"consumed_samples": steps * args.rollout_batch_size} + logs_dict_combined = {**rollout_status, **status} - self.save_logs_and_checkpoints(args, steps, pbar, logs_dict_combined, client_states, episode=episode) + self.save_logs_and_checkpoints(args, steps, pbar, logs_dict_combined, client_states, episode=episode) - pbar.update() - steps = steps + 1 + pbar.update() + steps = steps + 1 if self._wandb is not None and self.strategy.is_rank_0(): self._wandb.finish() From 1834ad25251da8f9098684d836a456a62866b0f8 Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Sat, 7 Mar 2026 20:28:08 +0800 Subject: [PATCH 2/4] feature(sunjx): Resolved multi-GPU synchronization to avoid deadlocks. --- lightrft/trainer/ppo_trainer_vl.py | 95 ++++++++++++++++++++++++------ 1 file changed, 76 insertions(+), 19 deletions(-) diff --git a/lightrft/trainer/ppo_trainer_vl.py b/lightrft/trainer/ppo_trainer_vl.py index 5fa820d3..c411b990 100644 --- a/lightrft/trainer/ppo_trainer_vl.py +++ b/lightrft/trainer/ppo_trainer_vl.py @@ -421,6 +421,8 @@ def fit( # Print training mode information if use_dynamic_sampling: + ds_world_size = torch.distributed.get_world_size() + ds_local_target = args.train_batch_size // ds_world_size self.strategy.print( f"\n{'=' * 80}\n" f"DYNAMIC SAMPLING MODE (DAPO)\n" @@ -428,12 +430,14 @@ def fit( f"Configuration:\n" f" - Metric: {dynamic_sampling_metric}\n" f" - Max generation batches: {max_num_gen_batches} ({'unlimited' if max_num_gen_batches <= 0 else max_num_gen_batches})\n" - f" - train_batch_size: {args.train_batch_size} prompts\n" + f" - train_batch_size: {args.train_batch_size} prompts (global)\n" + f" - local target per rank: {ds_local_target} prompts (world_size={ds_world_size})\n" f" - rollout_batch_size: {args.rollout_batch_size} prompts per generation\n" f" - n_samples_per_prompt: {args.n_samples_per_prompt}\n" f"Behavior:\n" f" - Groups with zero metric variance (all-correct/all-wrong) are filtered out.\n" - f" - Generation repeats until {args.train_batch_size} qualified prompts are accumulated.\n" + f" - Generation repeats until each rank has {ds_local_target} qualified prompts.\n" + f" - Training decision is synchronized across all ranks via all_reduce.\n" f"{'=' * 80}\n" ) elif args.train_batch_size < args.rollout_batch_size: @@ -511,7 +515,18 @@ def fit( # 2. Filter items by prompt group metric variance # 3. Accumulate kept items across generation batches # 4. When enough qualified prompts → train + # + # MULTI-GPU SYNCHRONIZATION: + # Each rank processes different prompts (via DistributedSampler), + # so filtering rates differ across ranks. To prevent deadlocks + # from divergent control flow (one rank training while another + # is still generating), we synchronize the "ready to train" + # decision across all ranks using all_reduce(MIN). # ============================================================ + world_size = torch.distributed.get_world_size() + local_target_prompts = args.train_batch_size // world_size + device = torch.cuda.current_device() + accumulated_items = [] # List of BufferItemVL num_qualified_prompts = 0 num_gen_batches = 0 @@ -568,18 +583,35 @@ def fit( self.strategy.print( f"[DynSamp] Batch {num_gen_batches}: kept {num_kept}/{num_total} prompts " f"(filtered {num_filtered}). " - f"Accumulated: {num_qualified_prompts}/{args.train_batch_size} qualified prompts." + f"Accumulated: {num_qualified_prompts}/{local_target_prompts} local qualified prompts." + ) + + # Synchronize "ready to train" decision across all ranks. + # Use all_reduce(MIN) so training only starts when ALL ranks + # have enough qualified prompts, preventing deadlock. + local_ready = torch.tensor( + [1.0 if num_qualified_prompts >= local_target_prompts else 0.0], + device=device, + ) + torch.distributed.all_reduce(local_ready, op=torch.distributed.ReduceOp.MIN) + all_ranks_ready = local_ready.item() > 0 + + # Also synchronize the max_gen_batches fallback decision + local_max_reached = torch.tensor( + [1.0 if (max_num_gen_batches > 0 and num_gen_batches >= max_num_gen_batches) else 0.0], + device=device, ) + torch.distributed.all_reduce(local_max_reached, op=torch.distributed.ReduceOp.MAX) + any_rank_max_reached = local_max_reached.item() > 0 - # Check if we have enough qualified prompts - if num_qualified_prompts >= args.train_batch_size: - # Trim to exact train_batch_size * n_samples_per_prompt - target_num_items = args.train_batch_size * args.n_samples_per_prompt + if all_ranks_ready: + # Trim to exact local_target_prompts * n_samples_per_prompt + target_num_items = local_target_prompts * args.n_samples_per_prompt accumulated_items = accumulated_items[:target_num_items] - num_qualified_prompts = args.train_batch_size + num_qualified_prompts = local_target_prompts self.strategy.print( - f"[DynSamp] Reached train_batch_size={args.train_batch_size} after " + f"[DynSamp] All ranks ready. Local target={local_target_prompts} reached after " f"{num_gen_batches} generation batches. " f"Total generated: {total_generated_prompts}, filtered: {total_filtered_prompts} " f"({total_filtered_prompts / max(total_generated_prompts, 1) * 100:.1f}% filtered)." @@ -594,7 +626,7 @@ def fit( rollout_status = self._collect_rollout_status(self.replay_buffer.items) rollout_status["ds_gen_batches"] = num_gen_batches rollout_status["ds_filter_rate"] = total_filtered_prompts / max(total_generated_prompts, 1) - rollout_status["ds_qualified_prompts"] = num_qualified_prompts + rollout_status["ds_qualified_prompts"] = num_qualified_prompts * world_size if self.args.advantage_estimator != "group_norm": self.replay_buffer.normalize("advantages", self.strategy) @@ -607,7 +639,7 @@ def fit( self.strategy.report_memory('after train') if "kl" in status: - self.kl_ctl.update(status["kl"], args.train_batch_size * args.n_samples_per_prompt) + self.kl_ctl.update(status["kl"], local_target_prompts * args.n_samples_per_prompt) pbar.set_postfix(rollout_status) @@ -625,21 +657,31 @@ def fit( total_generated_prompts = 0 total_filtered_prompts = 0 - elif max_num_gen_batches > 0 and num_gen_batches >= max_num_gen_batches: - # Reached max generation batches without enough qualified prompts + elif any_rank_max_reached: + # At least one rank reached max generation batches self.strategy.print( - f"[DynSamp] WARNING: Reached max_num_gen_batches={max_num_gen_batches} " - f"with only {num_qualified_prompts}/{args.train_batch_size} qualified prompts. " + f"[DynSamp] WARNING: max_num_gen_batches={max_num_gen_batches} reached by some rank " + f"with local {num_qualified_prompts}/{local_target_prompts} qualified prompts. " f"Training with available data." ) + # Synchronize item count: use the minimum across ranks + # to ensure all ranks have the same number of training steps. + # All ranks must participate in all_reduce. + local_count = torch.tensor([len(accumulated_items)], device=device, dtype=torch.long) + torch.distributed.all_reduce(local_count, op=torch.distributed.ReduceOp.MIN) + min_items = local_count.item() + # Round down to multiple of n_samples_per_prompt to keep prompt groups intact + min_items = (min_items // args.n_samples_per_prompt) * args.n_samples_per_prompt + accumulated_items = accumulated_items[:min_items] + if accumulated_items: self.replay_buffer.items.extend(accumulated_items) rollout_status = self._collect_rollout_status(self.replay_buffer.items) rollout_status["ds_gen_batches"] = num_gen_batches rollout_status["ds_filter_rate"] = total_filtered_prompts / max(total_generated_prompts, 1) - rollout_status["ds_qualified_prompts"] = num_qualified_prompts + rollout_status["ds_qualified_prompts"] = len(accumulated_items) // args.n_samples_per_prompt if self.args.advantage_estimator != "group_norm": self.replay_buffer.normalize("advantages", self.strategy) @@ -657,6 +699,11 @@ def fit( pbar.update() steps += 1 + else: + # All ranks have zero items after sync - skip training + self.strategy.print( + "[DynSamp] WARNING: No items available after synchronization. Skipping training step." + ) # Reset accumulated_items = [] @@ -665,17 +712,27 @@ def fit( total_generated_prompts = 0 total_filtered_prompts = 0 - # Handle remaining accumulated items at end of episode + # Handle remaining accumulated items at end of episode. + # All ranks must participate in the all_reduce even if some have no items. + local_count = torch.tensor( + [len(accumulated_items)], device=device, dtype=torch.long + ) + torch.distributed.all_reduce(local_count, op=torch.distributed.ReduceOp.MIN) + min_items = local_count.item() + min_items = (min_items // args.n_samples_per_prompt) * args.n_samples_per_prompt + accumulated_items = accumulated_items[:min_items] + if accumulated_items: self.strategy.print( - f"[DynSamp] End of episode: training with {num_qualified_prompts} remaining qualified prompts." + f"[DynSamp] End of episode: training with {len(accumulated_items) // args.n_samples_per_prompt} " + f"remaining qualified prompts per rank." ) self.replay_buffer.items.extend(accumulated_items) rollout_status = self._collect_rollout_status(self.replay_buffer.items) rollout_status["ds_gen_batches"] = num_gen_batches rollout_status["ds_filter_rate"] = total_filtered_prompts / max(total_generated_prompts, 1) - rollout_status["ds_qualified_prompts"] = num_qualified_prompts + rollout_status["ds_qualified_prompts"] = len(accumulated_items) // args.n_samples_per_prompt if self.args.advantage_estimator != "group_norm": self.replay_buffer.normalize("advantages", self.strategy) From 2513fe1d3202eb61dda37fe040357c85001b7789 Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Sat, 7 Mar 2026 20:31:05 +0800 Subject: [PATCH 3/4] feature(sunjx): pass fcheck and format --- lightrft/trainer/ppo_trainer_vl.py | 37 ++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/lightrft/trainer/ppo_trainer_vl.py b/lightrft/trainer/ppo_trainer_vl.py index c411b990..ff8fca2d 100644 --- a/lightrft/trainer/ppo_trainer_vl.py +++ b/lightrft/trainer/ppo_trainer_vl.py @@ -280,13 +280,17 @@ def _filter_buffer_items_by_metric(self, items, n_samples_per_prompt, metric_nam metric_vals = [] for item in group: - if metric_name == "reward" and hasattr(item, 'info') and item.info is not None and 'reward' in item.info: + if metric_name == "reward" and hasattr( + item, 'info' + ) and item.info is not None and 'reward' in item.info: val = item.info['reward'] if isinstance(val, torch.Tensor): metric_vals.append(val.float().mean().item()) else: metric_vals.append(float(val)) - elif metric_name == "acc" and hasattr(item, 'info') and item.info is not None and 'reward_metrics' in item.info: + elif metric_name == "acc" and hasattr( + item, 'info' + ) and item.info is not None and 'reward_metrics' in item.info: rm = item.info.get('reward_metrics') if rm is not None and 'accuracy_reward' in rm: val = rm['accuracy_reward'] @@ -429,7 +433,8 @@ def fit( f"{'=' * 80}\n" f"Configuration:\n" f" - Metric: {dynamic_sampling_metric}\n" - f" - Max generation batches: {max_num_gen_batches} ({'unlimited' if max_num_gen_batches <= 0 else max_num_gen_batches})\n" + f" - Max generation batches: {max_num_gen_batches} " + f"({'unlimited' if max_num_gen_batches <= 0 else max_num_gen_batches})\n" f" - train_batch_size: {args.train_batch_size} prompts (global)\n" f" - local target per rank: {ds_local_target} prompts (world_size={ds_world_size})\n" f" - rollout_batch_size: {args.rollout_batch_size} prompts per generation\n" @@ -541,7 +546,8 @@ def fit( rand_videos = None self.strategy.print( - f"[DynSamp] Gen batch {num_gen_batches + 1}: generating rollouts for {len(rand_prompts)} prompts..." + f"[DynSamp] Gen batch {num_gen_batches + 1}: generating rollouts " + f"for {len(rand_prompts)} prompts..." ) # Generate experiences and collect individual items @@ -645,7 +651,9 @@ def fit( client_states = {"consumed_samples": steps * args.rollout_batch_size} logs_dict_combined = {**rollout_status, **status} - self.save_logs_and_checkpoints(args, steps, pbar, logs_dict_combined, client_states, episode=episode) + self.save_logs_and_checkpoints( + args, steps, pbar, logs_dict_combined, client_states, episode=episode + ) pbar.update() steps += 1 @@ -695,7 +703,9 @@ def fit( pbar.set_postfix(rollout_status) client_states = {"consumed_samples": steps * args.rollout_batch_size} logs_dict_combined = {**rollout_status, **status} - self.save_logs_and_checkpoints(args, steps, pbar, logs_dict_combined, client_states, episode=episode) + self.save_logs_and_checkpoints( + args, steps, pbar, logs_dict_combined, client_states, episode=episode + ) pbar.update() steps += 1 @@ -714,9 +724,7 @@ def fit( # Handle remaining accumulated items at end of episode. # All ranks must participate in the all_reduce even if some have no items. - local_count = torch.tensor( - [len(accumulated_items)], device=device, dtype=torch.long - ) + local_count = torch.tensor([len(accumulated_items)], device=device, dtype=torch.long) torch.distributed.all_reduce(local_count, op=torch.distributed.ReduceOp.MIN) min_items = local_count.item() min_items = (min_items // args.n_samples_per_prompt) * args.n_samples_per_prompt @@ -724,7 +732,8 @@ def fit( if accumulated_items: self.strategy.print( - f"[DynSamp] End of episode: training with {len(accumulated_items) // args.n_samples_per_prompt} " + f"[DynSamp] End of episode: training with " + f"{len(accumulated_items) // args.n_samples_per_prompt} " f"remaining qualified prompts per rank." ) self.replay_buffer.items.extend(accumulated_items) @@ -746,7 +755,9 @@ def fit( pbar.set_postfix(rollout_status) client_states = {"consumed_samples": steps * args.rollout_batch_size} logs_dict_combined = {**rollout_status, **status} - self.save_logs_and_checkpoints(args, steps, pbar, logs_dict_combined, client_states, episode=episode) + self.save_logs_and_checkpoints( + args, steps, pbar, logs_dict_combined, client_states, episode=episode + ) pbar.update() steps += 1 @@ -811,7 +822,9 @@ def fit( client_states = {"consumed_samples": steps * args.rollout_batch_size} logs_dict_combined = {**rollout_status, **status} - self.save_logs_and_checkpoints(args, steps, pbar, logs_dict_combined, client_states, episode=episode) + self.save_logs_and_checkpoints( + args, steps, pbar, logs_dict_combined, client_states, episode=episode + ) pbar.update() steps = steps + 1 From c70cdf3d9dd0467b1c7e24bf42c93e448dbdeb85 Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Wed, 18 Mar 2026 15:35:35 +0800 Subject: [PATCH 4/4] feature(sunjx): add implementations reference --- lightrft/trainer/ppo_trainer_vl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lightrft/trainer/ppo_trainer_vl.py b/lightrft/trainer/ppo_trainer_vl.py index ff8fca2d..f0e655cb 100644 --- a/lightrft/trainer/ppo_trainer_vl.py +++ b/lightrft/trainer/ppo_trainer_vl.py @@ -514,6 +514,7 @@ def fit( # DYNAMIC SAMPLING PATH (DAPO) # Accumulate qualified prompt groups across multiple generation # batches until train_batch_size is reached. + # Reference implementation: https://github.com/verl-project/verl # # Flow: # 1. Generate rollouts → append to temp buffer (split into items)