diff --git a/examples/grm_vl_rl/train_colocate.py b/examples/grm_vl_rl/train_colocate.py index 7005195b..54d5a978 100644 --- a/examples/grm_vl_rl/train_colocate.py +++ b/examples/grm_vl_rl/train_colocate.py @@ -314,8 +314,11 @@ def train(args: argparse.Namespace) -> None: save_hf_ckpt=args.save_hf_ckpt, disable_ds_ckpt=args.disable_ds_ckpt, packing_samples=args.packing_samples, - # overlong_reward + # DAPO dynamic sampling dynamic_sampling=args.dynamic_sampling, + dynamic_sampling_metric=getattr(args, 'dynamic_sampling_metric', 'reward'), + max_num_gen_batches=getattr(args, 'max_num_gen_batches', 10), + # overlong_reward overlong_buffer=args.overlong_buffer, overlong_buffer_len=args.overlong_buffer_len, overlong_buffer_penalty_factor=args.overlong_buffer_penalty_factor, @@ -365,6 +368,8 @@ def train(args: argparse.Namespace) -> None: # DAPO parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy") + parser.add_argument("--dynamic_sampling_metric", type=str, default="reward", choices=["reward", "acc"], help="Metric for dynamic sampling group filtering") + parser.add_argument("--max_num_gen_batches", type=int, default=10, help="Max generation batches for dynamic sampling accumulation") parser.add_argument("--overlong_buffer", action="store_true", default=False, help="Apply overlong sequence buffer in DAPO") parser.add_argument("--overlong_buffer_len", type=int, default=1024, help="Max token threshold for overlong buffer") parser.add_argument("--overlong_buffer_penalty_factor", type=float, default=1.0, help="Penalty scaling factor for overlong sequences, <1 discourages long outputs; >1 encourages them") diff --git a/examples/gsm8k_geo3k/train_colocate.py b/examples/gsm8k_geo3k/train_colocate.py index 01eeb3c3..e96be89c 100644 --- a/examples/gsm8k_geo3k/train_colocate.py +++ b/examples/gsm8k_geo3k/train_colocate.py @@ -417,8 +417,11 @@ def train(args): save_hf_ckpt=args.save_hf_ckpt, disable_ds_ckpt=args.disable_ds_ckpt, packing_samples=args.packing_samples, - # overlong_reward + # DAPO dynamic sampling dynamic_sampling=args.dynamic_sampling, + dynamic_sampling_metric=args.dynamic_sampling_metric, + max_num_gen_batches=args.max_num_gen_batches, + # overlong_reward overlong_buffer=args.overlong_buffer, overlong_buffer_len=args.overlong_buffer_len, overlong_buffer_penalty_factor=args.overlong_buffer_penalty_factor, @@ -466,7 +469,9 @@ def train(args): parser.add_argument("--load_checkpoint", action="store_true", default=False) # DAPO - parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy") + parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy: filter out prompt groups with zero metric variance and accumulate until train_batch_size is reached") + parser.add_argument("--dynamic_sampling_metric", type=str, default="reward", choices=["reward", "acc"], help="Metric for dynamic sampling group filtering (default: reward)") + parser.add_argument("--max_num_gen_batches", type=int, default=10, help="Max generation batches for dynamic sampling accumulation. Non-positive means no limit (default: 10)") parser.add_argument("--overlong_buffer", action="store_true", default=False, help="Apply overlong sequence buffer in DAPO") parser.add_argument("--overlong_buffer_len", type=int, default=1024, help="Max token threshold for overlong buffer") parser.add_argument("--overlong_buffer_penalty_factor", type=float, default=1.0, help="Penalty scaling factor for overlong sequences, <1 discourages long outputs; >1 encourages them") diff --git a/examples/r1_aqa/train_colocate.py b/examples/r1_aqa/train_colocate.py index fee2fb23..e391fa08 100644 --- a/examples/r1_aqa/train_colocate.py +++ b/examples/r1_aqa/train_colocate.py @@ -329,8 +329,11 @@ def train(args): save_hf_ckpt=args.save_hf_ckpt, disable_ds_ckpt=args.disable_ds_ckpt, packing_samples=args.packing_samples, - # DAPO / overlong + # DAPO dynamic sampling dynamic_sampling=args.dynamic_sampling, + dynamic_sampling_metric=getattr(args, 'dynamic_sampling_metric', 'reward'), + max_num_gen_batches=getattr(args, 'max_num_gen_batches', 10), + # overlong_reward overlong_buffer=args.overlong_buffer, overlong_buffer_len=args.overlong_buffer_len, overlong_buffer_penalty_factor=args.overlong_buffer_penalty_factor, @@ -388,7 +391,9 @@ def train(args): parser.add_argument("--load_checkpoint", action="store_true", default=False) # DAPO - parser.add_argument("--dynamic_sampling", action="store_true", default=False) + parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy") + parser.add_argument("--dynamic_sampling_metric", type=str, default="reward", choices=["reward", "acc"], help="Metric for dynamic sampling group filtering") + parser.add_argument("--max_num_gen_batches", type=int, default=10, help="Max generation batches for dynamic sampling accumulation") parser.add_argument("--overlong_buffer", action="store_true", default=False) parser.add_argument("--overlong_buffer_len", type=int, default=1024) parser.add_argument("--overlong_buffer_penalty_factor", type=float, default=1.0) diff --git a/lightrft/strategy/config.py b/lightrft/strategy/config.py index c6993005..50feb138 100644 --- a/lightrft/strategy/config.py +++ b/lightrft/strategy/config.py @@ -112,6 +112,11 @@ class StrategyConfig: # Dynamic sampling and advantage estimation # (bool): Enable dynamic sampling for advantage estimation, defaults to False dynamic_sampling: bool = False + # (str): Metric used for dynamic sampling group filtering ("acc", "reward"), defaults to "reward" + dynamic_sampling_metric: str = "reward" + # (int): Max number of generation batches for dynamic sampling accumulation. + # Non-positive values mean no upper limit. defaults to 10 + max_num_gen_batches: int = 10 # (str): Advantage estimator method, defaults to "gae" advantage_estimator: str = "group_norm" @@ -280,7 +285,7 @@ def print_config_summary(self) -> None: # Dynamic Sampling and Advantage Estimation Parameters print("\nDynamic Sampling and Advantage Estimation Parameters:") - for attr in ['dynamic_sampling', 'advantage_estimator']: + for attr in ['dynamic_sampling', 'dynamic_sampling_metric', 'max_num_gen_batches', 'advantage_estimator']: current = getattr(self, attr) default = getattr(default_config, attr) status = "Overridden" if current != default else "Default" diff --git a/lightrft/trainer/ppo_trainer_vl.py b/lightrft/trainer/ppo_trainer_vl.py index 7e998817..f0e655cb 100644 --- a/lightrft/trainer/ppo_trainer_vl.py +++ b/lightrft/trainer/ppo_trainer_vl.py @@ -6,6 +6,7 @@ import torch import math +import numpy as np import torch.nn as nn from torch.optim import Optimizer from torch.utils.data import DataLoader @@ -16,6 +17,7 @@ from lightrft.models.utils import masked_mean, unpacking_samples, compute_approx_kl from lightrft.utils.distributed_sampler import DistributedSampler from lightrft.trainer import AdaptiveKLController, ExperienceVL, FixedKLController, NaiveExperienceMakerVL, NaiveReplayBufferVL # noqa +from lightrft.trainer.replay_buffer_utils import split_experience_batch, remove_padding_in_sequences class PPOTrainerVL(ABC): @@ -252,6 +254,136 @@ def __init__( log_dir = os.path.join(self.strategy.args.use_tensorboard, strategy.args.wandb_run_name) self._tensorboard = SummaryWriter(log_dir=log_dir) + def _filter_buffer_items_by_metric(self, items, n_samples_per_prompt, metric_name="reward"): + """ + Filter replay buffer items by group metric variance (DAPO dynamic sampling). + + Groups items by prompt (every n_samples_per_prompt consecutive items form a group). + Keeps only groups where the metric has non-zero variance (i.e., not all-correct or all-wrong). + + :param items: List of BufferItemVL from replay buffer + :type items: List + :param n_samples_per_prompt: Number of samples per prompt + :type n_samples_per_prompt: int + :param metric_name: Metric to use for filtering ("reward" or "acc") + :type metric_name: str + :return: Tuple of (kept_items, num_kept_prompts, num_total_prompts, num_filtered_prompts) + :rtype: Tuple[List, int, int, int] + """ + n = n_samples_per_prompt + num_groups = len(items) // n + kept_items = [] + num_filtered = 0 + + for g in range(num_groups): + group = items[g * n:(g + 1) * n] + + metric_vals = [] + for item in group: + if metric_name == "reward" and hasattr( + item, 'info' + ) and item.info is not None and 'reward' in item.info: + val = item.info['reward'] + if isinstance(val, torch.Tensor): + metric_vals.append(val.float().mean().item()) + else: + metric_vals.append(float(val)) + elif metric_name == "acc" and hasattr( + item, 'info' + ) and item.info is not None and 'reward_metrics' in item.info: + rm = item.info.get('reward_metrics') + if rm is not None and 'accuracy_reward' in rm: + val = rm['accuracy_reward'] + if isinstance(val, torch.Tensor): + metric_vals.append(val.float().mean().item()) + else: + metric_vals.append(float(val)) + else: + metric_vals.append(0.0) + else: + metric_vals.append(0.0) + + metric_std = np.std(metric_vals) + + if metric_std > 1e-8 or n == 1: + kept_items.extend(group) + else: + num_filtered += 1 + + return kept_items, num_groups - num_filtered, num_groups, num_filtered + + def _collect_rollout_status(self, items): + """ + Collect rollout statistics from replay buffer items. + + :param items: List of experience items + :type items: List + :return: Dictionary of rollout statistics + :rtype: dict + """ + rollout_status = {} + if not items: + return rollout_status + + all_rewards = [] + all_format_rewards = [] + all_accuracy_rewards = [] + all_response_lengths = [] + + for item in items: + if hasattr(item, 'info') and item.info is not None and 'reward' in item.info: + all_rewards.append(item.info['reward']) + + if ( + hasattr(item, 'info') and item.info is not None and 'reward_metrics' in item.info + and item.info['reward_metrics'] is not None + ): + reward_metrics = item.info['reward_metrics'] + if 'format_reward' in reward_metrics: + all_format_rewards.append(reward_metrics['format_reward']) + if 'accuracy_reward' in reward_metrics: + all_accuracy_rewards.append(reward_metrics['accuracy_reward']) + + if hasattr(item, 'info') and item.info is not None and 'response_length' in item.info: + all_response_lengths.append(item.info['response_length']) + + device = torch.cuda.current_device() + + if all_rewards: + if isinstance(all_rewards[0], torch.Tensor): + rewards_tensor = torch.cat([t.to(device).float() for t in all_rewards]) + else: + rewards_tensor = torch.tensor(all_rewards, dtype=torch.float32, device=device) + rollout_status["rollout_reward"] = rewards_tensor.mean().item() + rollout_status["rollout_reward_std"] = rewards_tensor.std().item() + + if all_format_rewards: + if isinstance(all_format_rewards[0], torch.Tensor): + format_tensor = torch.cat([t.to(device).float() for t in all_format_rewards]) + else: + format_tensor = torch.tensor(all_format_rewards, dtype=torch.float32, device=device) + mean_format_reward = format_tensor.mean().item() + if abs(mean_format_reward) > 1e-6: + rollout_status["rollout_format_reward"] = mean_format_reward + + if all_accuracy_rewards: + if isinstance(all_accuracy_rewards[0], torch.Tensor): + accuracy_tensor = torch.cat([t.to(device).float() for t in all_accuracy_rewards]) + else: + accuracy_tensor = torch.tensor(all_accuracy_rewards, dtype=torch.float32, device=device) + mean_accuracy_reward = accuracy_tensor.mean().item() + if abs(mean_accuracy_reward) > 1e-6: + rollout_status["rollout_accuracy_reward"] = mean_accuracy_reward + + if all_response_lengths: + if isinstance(all_response_lengths[0], torch.Tensor): + lengths_tensor = torch.cat([t.to(device).float() for t in all_response_lengths]) + else: + lengths_tensor = torch.tensor(all_response_lengths, dtype=torch.float32, device=device) + rollout_status["rollout_response_length"] = lengths_tensor.mean().item() + + return rollout_status + def fit( self, args, @@ -262,7 +394,11 @@ def fit( num_update_steps_per_episodes=1, ) -> None: """ - Main training loop for PPO. + Main training loop for PPO with optional DAPO dynamic sampling. + + When dynamic_sampling is enabled, the loop filters out prompt groups with zero + metric variance and accumulates qualified groups across multiple generation batches + until train_batch_size is reached or max_num_gen_batches is hit. :param args: Training arguments. :type args: Namespace @@ -282,8 +418,34 @@ def fit( samples_per_rollout = args.rollout_batch_size * args.n_samples_per_prompt samples_per_train = args.train_batch_size * args.n_samples_per_prompt + # Dynamic sampling configuration + use_dynamic_sampling = getattr(args, 'dynamic_sampling', False) + dynamic_sampling_metric = getattr(args, 'dynamic_sampling_metric', 'reward') + max_num_gen_batches = getattr(args, 'max_num_gen_batches', 10) + # Print training mode information - if args.train_batch_size < args.rollout_batch_size: + if use_dynamic_sampling: + ds_world_size = torch.distributed.get_world_size() + ds_local_target = args.train_batch_size // ds_world_size + self.strategy.print( + f"\n{'=' * 80}\n" + f"DYNAMIC SAMPLING MODE (DAPO)\n" + f"{'=' * 80}\n" + f"Configuration:\n" + f" - Metric: {dynamic_sampling_metric}\n" + f" - Max generation batches: {max_num_gen_batches} " + f"({'unlimited' if max_num_gen_batches <= 0 else max_num_gen_batches})\n" + f" - train_batch_size: {args.train_batch_size} prompts (global)\n" + f" - local target per rank: {ds_local_target} prompts (world_size={ds_world_size})\n" + f" - rollout_batch_size: {args.rollout_batch_size} prompts per generation\n" + f" - n_samples_per_prompt: {args.n_samples_per_prompt}\n" + f"Behavior:\n" + f" - Groups with zero metric variance (all-correct/all-wrong) are filtered out.\n" + f" - Generation repeats until each rank has {ds_local_target} qualified prompts.\n" + f" - Training decision is synchronized across all ranks via all_reduce.\n" + f"{'=' * 80}\n" + ) + elif args.train_batch_size < args.rollout_batch_size: updates_per_rollout = samples_per_rollout / samples_per_train self.strategy.print( f"\n{'=' * 80}\n" @@ -306,14 +468,6 @@ def fit( ) # Calculate number of rollouts per episode. - # Regardless of TBS and RBS relationship, rollout count should be determined by "total data / rollout size". - # Numerator (num_update_steps * train_batch_size) equals "total samples planned for this episode". - # Denominator (rollout_batch_size * n_samples) equals "samples produced per rollout". - # This calculation ensures data collection volume is constant. - # When TBS=64, num_update_steps is naturally twice as large as when TBS=128. - # Substituting into formula: (2N * 0.5T) / R = (N * T) / R. - # Conclusion: Rollout count unchanged, but internal update loop count doubles due to smaller TBS. - num_rollouts_per_episodes = ( num_update_steps_per_episodes * args.train_batch_size // args.max_epochs // args.rollout_batch_size // args.n_samples_per_prompt @@ -321,7 +475,6 @@ def fit( # Safeguard to prevent num_rollouts_per_episodes from being 0 if num_rollouts_per_episodes == 0: - # Try recalculating with ceil to prevent fractional values from being discarded by integer division val = (num_update_steps_per_episodes * args.train_batch_size) / (args.max_epochs * args.rollout_batch_size * args.n_samples_per_prompt) num_rollouts_per_episodes = math.ceil(val) @@ -356,162 +509,326 @@ def fit( disable=not self.strategy.is_rank_0(), ) - for batch in self.prompts_dataloader: - # Compatible with both image-only (4 args) and video (5 args) dataloaders - if len(batch) == 5: - rand_prompts, rand_images, rand_videos, rand_references, rand_labels = batch - else: - rand_prompts, rand_images, rand_references, rand_labels = batch - rand_videos = None + if use_dynamic_sampling: + # ============================================================ + # DYNAMIC SAMPLING PATH (DAPO) + # Accumulate qualified prompt groups across multiple generation + # batches until train_batch_size is reached. + # Reference implementation: https://github.com/verl-project/verl + # + # Flow: + # 1. Generate rollouts → append to temp buffer (split into items) + # 2. Filter items by prompt group metric variance + # 3. Accumulate kept items across generation batches + # 4. When enough qualified prompts → train + # + # MULTI-GPU SYNCHRONIZATION: + # Each rank processes different prompts (via DistributedSampler), + # so filtering rates differ across ranks. To prevent deadlocks + # from divergent control flow (one rank training while another + # is still generating), we synchronize the "ready to train" + # decision across all ranks using all_reduce(MIN). + # ============================================================ + world_size = torch.distributed.get_world_size() + local_target_prompts = args.train_batch_size // world_size + device = torch.cuda.current_device() + + accumulated_items = [] # List of BufferItemVL + num_qualified_prompts = 0 + num_gen_batches = 0 + total_generated_prompts = 0 + total_filtered_prompts = 0 + + for batch in self.prompts_dataloader: + if len(batch) == 5: + rand_prompts, rand_images, rand_videos, rand_references, rand_labels = batch + else: + rand_prompts, rand_images, rand_references, rand_labels = batch + rand_videos = None + + self.strategy.print( + f"[DynSamp] Gen batch {num_gen_batches + 1}: generating rollouts " + f"for {len(rand_prompts)} prompts..." + ) - # TODO: Remove debug print - self.strategy.print( - f"rand_prompts:\n {rand_prompts}\n , rand_images:{rand_images}\n , rand_references:{rand_references}\n, rand_labels:{rand_labels}\n " # noqa - ) + # Generate experiences and collect individual items + batch_items = [] + for i, experience in enumerate( + self.experience_maker.make_experience_list( + rand_prompts, + rand_images, + all_videos=rand_videos, + all_references=rand_references, + all_labels=rand_labels, + **self.generate_kwargs + ) + ): + if i == 0: + output = self.tokenizer.batch_decode( + experience.sequences[0].unsqueeze(0), skip_special_tokens=True + ) + self.strategy.print("collect phase: experience.sequences w skip_special_tokens: ", output) + + # Split batched experience into individual items (same as replay_buffer.append) + experience.to_device(torch.device("cpu")) + items = split_experience_batch(experience) + items = remove_padding_in_sequences(items) + batch_items.extend(items) + + num_gen_batches += 1 + + # Filter by metric variance at the item level + kept, num_kept, num_total, num_filtered = self._filter_buffer_items_by_metric( + batch_items, args.n_samples_per_prompt, metric_name=dynamic_sampling_metric + ) - for i, experience in enumerate( - self.experience_maker.make_experience_list( - rand_prompts, - rand_images, - all_videos=rand_videos, - all_references=rand_references, - all_labels=rand_labels, - **self.generate_kwargs + accumulated_items.extend(kept) + num_qualified_prompts += num_kept + total_generated_prompts += num_total + total_filtered_prompts += num_filtered + + self.strategy.print( + f"[DynSamp] Batch {num_gen_batches}: kept {num_kept}/{num_total} prompts " + f"(filtered {num_filtered}). " + f"Accumulated: {num_qualified_prompts}/{local_target_prompts} local qualified prompts." ) - ): - if i == 0: - output = self.tokenizer.batch_decode( - experience.sequences[0].unsqueeze(0), skip_special_tokens=True + + # Synchronize "ready to train" decision across all ranks. + # Use all_reduce(MIN) so training only starts when ALL ranks + # have enough qualified prompts, preventing deadlock. + local_ready = torch.tensor( + [1.0 if num_qualified_prompts >= local_target_prompts else 0.0], + device=device, + ) + torch.distributed.all_reduce(local_ready, op=torch.distributed.ReduceOp.MIN) + all_ranks_ready = local_ready.item() > 0 + + # Also synchronize the max_gen_batches fallback decision + local_max_reached = torch.tensor( + [1.0 if (max_num_gen_batches > 0 and num_gen_batches >= max_num_gen_batches) else 0.0], + device=device, + ) + torch.distributed.all_reduce(local_max_reached, op=torch.distributed.ReduceOp.MAX) + any_rank_max_reached = local_max_reached.item() > 0 + + if all_ranks_ready: + # Trim to exact local_target_prompts * n_samples_per_prompt + target_num_items = local_target_prompts * args.n_samples_per_prompt + accumulated_items = accumulated_items[:target_num_items] + num_qualified_prompts = local_target_prompts + + self.strategy.print( + f"[DynSamp] All ranks ready. Local target={local_target_prompts} reached after " + f"{num_gen_batches} generation batches. " + f"Total generated: {total_generated_prompts}, filtered: {total_filtered_prompts} " + f"({total_filtered_prompts / max(total_generated_prompts, 1) * 100:.1f}% filtered)." + ) + + # Add filtered items directly to replay buffer + self.replay_buffer.items.extend(accumulated_items) + + self.strategy.report_memory('after replay_buffer ready (dynamic sampling)') + + # Collect rollout statistics + rollout_status = self._collect_rollout_status(self.replay_buffer.items) + rollout_status["ds_gen_batches"] = num_gen_batches + rollout_status["ds_filter_rate"] = total_filtered_prompts / max(total_generated_prompts, 1) + rollout_status["ds_qualified_prompts"] = num_qualified_prompts * world_size + + if self.args.advantage_estimator != "group_norm": + self.replay_buffer.normalize("advantages", self.strategy) + + self.strategy.report_memory('before train') + status = self.ppo_train(steps) + + self.strategy.report_memory('before clear buffer') + self.replay_buffer.clear() + self.strategy.report_memory('after train') + + if "kl" in status: + self.kl_ctl.update(status["kl"], local_target_prompts * args.n_samples_per_prompt) + + pbar.set_postfix(rollout_status) + + client_states = {"consumed_samples": steps * args.rollout_batch_size} + logs_dict_combined = {**rollout_status, **status} + self.save_logs_and_checkpoints( + args, steps, pbar, logs_dict_combined, client_states, episode=episode ) - self.strategy.print("collect phase: experience.sequences w skip_special_tokens: ", output) + + pbar.update() + steps += 1 + + # Reset accumulation for next training step + accumulated_items = [] + num_qualified_prompts = 0 + num_gen_batches = 0 + total_generated_prompts = 0 + total_filtered_prompts = 0 + + elif any_rank_max_reached: + # At least one rank reached max generation batches self.strategy.print( - f"collect phase: rand_prompts:\n {rand_prompts[0:2]}\n , rand_images:{rand_images[0:2]}\n , rand_references:{rand_references[0:2]}\n, rand_labels:{rand_labels[0:2]}\n " # noqa + f"[DynSamp] WARNING: max_num_gen_batches={max_num_gen_batches} reached by some rank " + f"with local {num_qualified_prompts}/{local_target_prompts} qualified prompts. " + f"Training with available data." ) - # print all - # self.strategy.print( - # f"rand_prompts:\n {rand_prompts}\n , rand_images:{rand_images}\n , rand_references:{rand_references}\n, rand_labels:{rand_labels}\n " # noqa - # ) - - self.replay_buffer.append(experience) - - self.strategy.report_memory('after replay_buffer ready') - - # Aggregate rollout statistics from replay buffer - # Collect metrics from the rollout/collection phase - rollout_status = {} - if self.replay_buffer.items: - all_rewards = [] - all_format_rewards = [] - all_accuracy_rewards = [] - all_response_lengths = [] - - for item in self.replay_buffer.items: - # Collect rewards from rollout - if hasattr(item, 'info') and item.info is not None and 'reward' in item.info: - all_rewards.append(item.info['reward']) - - # Robust handling of reward_metrics - # 1. Check if info exists - # 2. Check if 'reward_metrics' key exists - # 3. Check if reward_metrics is not None (critical!) - if ( - hasattr(item, 'info') and item.info is not None and 'reward_metrics' in item.info - and item.info['reward_metrics'] is not None - ): - - reward_metrics = item.info['reward_metrics'] - - # Safely extract sub-metrics - if 'format_reward' in reward_metrics: - all_format_rewards.append(reward_metrics['format_reward']) - if 'accuracy_reward' in reward_metrics: - all_accuracy_rewards.append(reward_metrics['accuracy_reward']) - - # Collect response lengths from rollout - if hasattr(item, 'info') and item.info is not None and 'response_length' in item.info: - all_response_lengths.append(item.info['response_length']) - - # Compute rollout statistics - device = torch.cuda.current_device() - - if all_rewards: - # [TENSOR-FIX] Handle both tensor lists and scalar lists - if isinstance(all_rewards[0], torch.Tensor): - rewards_tensor = torch.cat([t.to(device).float() for t in all_rewards]) - else: - rewards_tensor = torch.tensor(all_rewards, dtype=torch.float32, device=device) - rollout_status["rollout_reward"] = rewards_tensor.mean().item() - rollout_status["rollout_reward_std"] = rewards_tensor.std().item() - - if all_format_rewards: - # [TENSOR-FIX] Handle both tensor lists and scalar lists - # Issue: all_format_rewards may contain tensors (from reward_metrics), - # but torch.tensor() cannot convert a list of tensors directly. - # Solution: Use torch.cat() for tensor lists, torch.tensor() for scalar lists - if isinstance(all_format_rewards[0], torch.Tensor): - # List of tensors: concatenate them - format_tensor = torch.cat([t.to(device).float() for t in all_format_rewards]) + + # Synchronize item count: use the minimum across ranks + # to ensure all ranks have the same number of training steps. + # All ranks must participate in all_reduce. + local_count = torch.tensor([len(accumulated_items)], device=device, dtype=torch.long) + torch.distributed.all_reduce(local_count, op=torch.distributed.ReduceOp.MIN) + min_items = local_count.item() + # Round down to multiple of n_samples_per_prompt to keep prompt groups intact + min_items = (min_items // args.n_samples_per_prompt) * args.n_samples_per_prompt + accumulated_items = accumulated_items[:min_items] + + if accumulated_items: + self.replay_buffer.items.extend(accumulated_items) + + rollout_status = self._collect_rollout_status(self.replay_buffer.items) + rollout_status["ds_gen_batches"] = num_gen_batches + rollout_status["ds_filter_rate"] = total_filtered_prompts / max(total_generated_prompts, 1) + rollout_status["ds_qualified_prompts"] = len(accumulated_items) // args.n_samples_per_prompt + + if self.args.advantage_estimator != "group_norm": + self.replay_buffer.normalize("advantages", self.strategy) + + status = self.ppo_train(steps) + self.replay_buffer.clear() + + if "kl" in status: + self.kl_ctl.update(status["kl"], len(accumulated_items)) + + pbar.set_postfix(rollout_status) + client_states = {"consumed_samples": steps * args.rollout_batch_size} + logs_dict_combined = {**rollout_status, **status} + self.save_logs_and_checkpoints( + args, steps, pbar, logs_dict_combined, client_states, episode=episode + ) + + pbar.update() + steps += 1 else: - # List of scalars: convert to tensor - format_tensor = torch.tensor(all_format_rewards, dtype=torch.float32, device=device) + # All ranks have zero items after sync - skip training + self.strategy.print( + "[DynSamp] WARNING: No items available after synchronization. Skipping training step." + ) + + # Reset + accumulated_items = [] + num_qualified_prompts = 0 + num_gen_batches = 0 + total_generated_prompts = 0 + total_filtered_prompts = 0 + + # Handle remaining accumulated items at end of episode. + # All ranks must participate in the all_reduce even if some have no items. + local_count = torch.tensor([len(accumulated_items)], device=device, dtype=torch.long) + torch.distributed.all_reduce(local_count, op=torch.distributed.ReduceOp.MIN) + min_items = local_count.item() + min_items = (min_items // args.n_samples_per_prompt) * args.n_samples_per_prompt + accumulated_items = accumulated_items[:min_items] + + if accumulated_items: + self.strategy.print( + f"[DynSamp] End of episode: training with " + f"{len(accumulated_items) // args.n_samples_per_prompt} " + f"remaining qualified prompts per rank." + ) + self.replay_buffer.items.extend(accumulated_items) - mean_format_reward = format_tensor.mean().item() + rollout_status = self._collect_rollout_status(self.replay_buffer.items) + rollout_status["ds_gen_batches"] = num_gen_batches + rollout_status["ds_filter_rate"] = total_filtered_prompts / max(total_generated_prompts, 1) + rollout_status["ds_qualified_prompts"] = len(accumulated_items) // args.n_samples_per_prompt - # Only display if mean is significantly non-zero - if abs(mean_format_reward) > 1e-6: - rollout_status["rollout_format_reward"] = mean_format_reward + if self.args.advantage_estimator != "group_norm": + self.replay_buffer.normalize("advantages", self.strategy) - if all_accuracy_rewards: - # [TENSOR-FIX] Handle both tensor lists and scalar lists - if isinstance(all_accuracy_rewards[0], torch.Tensor): - accuracy_tensor = torch.cat([t.to(device).float() for t in all_accuracy_rewards]) - else: - accuracy_tensor = torch.tensor(all_accuracy_rewards, dtype=torch.float32, device=device) + status = self.ppo_train(steps) + self.replay_buffer.clear() - mean_accuracy_reward = accuracy_tensor.mean().item() + if "kl" in status: + self.kl_ctl.update(status["kl"], len(accumulated_items)) - # Only display if mean is significantly non-zero - if abs(mean_accuracy_reward) > 1e-6: - rollout_status["rollout_accuracy_reward"] = mean_accuracy_reward + pbar.set_postfix(rollout_status) + client_states = {"consumed_samples": steps * args.rollout_batch_size} + logs_dict_combined = {**rollout_status, **status} + self.save_logs_and_checkpoints( + args, steps, pbar, logs_dict_combined, client_states, episode=episode + ) - if all_response_lengths: - # [TENSOR-FIX] Handle both tensor lists and scalar lists - if isinstance(all_response_lengths[0], torch.Tensor): - lengths_tensor = torch.cat([t.to(device).float() for t in all_response_lengths]) - else: - lengths_tensor = torch.tensor(all_response_lengths, dtype=torch.float32, device=device) + pbar.update() + steps += 1 - rollout_status["rollout_response_length"] = lengths_tensor.mean().item() + else: + # ============================================================ + # STANDARD PATH (no dynamic sampling) + # ============================================================ + for batch in self.prompts_dataloader: + if len(batch) == 5: + rand_prompts, rand_images, rand_videos, rand_references, rand_labels = batch + else: + rand_prompts, rand_images, rand_references, rand_labels = batch + rand_videos = None + + self.strategy.print( + f"rand_prompts:\n {rand_prompts}\n , rand_images:{rand_images}\n , rand_references:{rand_references}\n, rand_labels:{rand_labels}\n " # noqa + ) - # TODO: Check normalization behavior - if self.args.advantage_estimator != "group_norm": - self.replay_buffer.normalize("advantages", self.strategy) + for i, experience in enumerate( + self.experience_maker.make_experience_list( + rand_prompts, + rand_images, + all_videos=rand_videos, + all_references=rand_references, + all_labels=rand_labels, + **self.generate_kwargs + ) + ): + if i == 0: + output = self.tokenizer.batch_decode( + experience.sequences[0].unsqueeze(0), skip_special_tokens=True + ) + self.strategy.print("collect phase: experience.sequences w skip_special_tokens: ", output) + self.strategy.print( + f"collect phase: rand_prompts:\n {rand_prompts[0:2]}\n , rand_images:{rand_images[0:2]}\n , rand_references:{rand_references[0:2]}\n, rand_labels:{rand_labels[0:2]}\n " # noqa + ) - self.strategy.report_memory('before train') + self.replay_buffer.append(experience) - status = self.ppo_train(steps) + self.strategy.report_memory('after replay_buffer ready') - self.strategy.report_memory('before clear buffer') - self.replay_buffer.clear() + rollout_status = self._collect_rollout_status(self.replay_buffer.items) - self.strategy.report_memory('after train') + if self.args.advantage_estimator != "group_norm": + self.replay_buffer.normalize("advantages", self.strategy) - if "kl" in status: - self.kl_ctl.update(status["kl"], args.rollout_batch_size * args.n_samples_per_prompt) + self.strategy.report_memory('before train') - # Update Episode pbar with ROLLOUT statistics (not training statistics!) - pbar.set_postfix(rollout_status) + status = self.ppo_train(steps) - # Logs/checkpoints: save BOTH ROLLOUT and TRAINING statistics to wandb - # [FIX] Merge rollout_status (from inference) and status (from training) - # to ensure wandb logs contain both types of metrics - client_states = {"consumed_samples": steps * args.rollout_batch_size} - logs_dict_combined = {**rollout_status, **status} # Merge: rollout first, training second + self.strategy.report_memory('before clear buffer') + self.replay_buffer.clear() - self.save_logs_and_checkpoints(args, steps, pbar, logs_dict_combined, client_states, episode=episode) + self.strategy.report_memory('after train') + + if "kl" in status: + self.kl_ctl.update(status["kl"], args.rollout_batch_size * args.n_samples_per_prompt) + + pbar.set_postfix(rollout_status) + + client_states = {"consumed_samples": steps * args.rollout_batch_size} + logs_dict_combined = {**rollout_status, **status} + + self.save_logs_and_checkpoints( + args, steps, pbar, logs_dict_combined, client_states, episode=episode + ) - pbar.update() - steps = steps + 1 + pbar.update() + steps = steps + 1 if self._wandb is not None and self.strategy.is_rank_0(): self._wandb.finish()