From 2d3651a5f1cc604d77418585df1a2181e68800f6 Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Mon, 9 Feb 2026 16:32:17 +0800 Subject: [PATCH 1/3] feature(sunjx): dynamic sampling --- examples/gsm8k_geo3k/train_colocate.py | 6 ++- lightrft/strategy/config.py | 8 ++- lightrft/trainer/advantage_calculator.py | 50 +++++++++++++----- lightrft/trainer/ppo_trainer_vl.py | 64 ++++++++++++++++++++++-- 4 files changed, 108 insertions(+), 20 deletions(-) diff --git a/examples/gsm8k_geo3k/train_colocate.py b/examples/gsm8k_geo3k/train_colocate.py index 818bcce2..da036e94 100644 --- a/examples/gsm8k_geo3k/train_colocate.py +++ b/examples/gsm8k_geo3k/train_colocate.py @@ -440,8 +440,10 @@ def train(args): parser.add_argument("--max_ckpt_mem", type=int, default=1e8) parser.add_argument("--load_checkpoint", action="store_true", default=False) - # DAPO - parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy") + # DAPO (Decoupled Clip and Dynamic Sampling Policy Optimization) + parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy to filter out groups with uniform rewards") + parser.add_argument("--dynamic_sampling_metric", type=str, default="reward", help="Metric for dynamic sampling filtering: 'reward', 'acc', etc.") + parser.add_argument("--max_num_gen_batches", type=int, default=10, help="Maximum number of generation batches for dynamic sampling (<=0 means no limit)") parser.add_argument("--overlong_buffer", action="store_true", default=False, help="Apply overlong sequence buffer in DAPO") parser.add_argument("--overlong_buffer_len", type=int, default=1024, help="Max token threshold for overlong buffer") parser.add_argument("--overlong_buffer_penalty_factor", type=float, default=1.0, help="Penalty scaling factor for overlong sequences, <1 discourages long outputs; >1 encourages them") diff --git a/lightrft/strategy/config.py b/lightrft/strategy/config.py index c6993005..70318e6c 100644 --- a/lightrft/strategy/config.py +++ b/lightrft/strategy/config.py @@ -110,8 +110,12 @@ class StrategyConfig: overlong_buffer_penalty_factor: float = 1.0 # Dynamic sampling and advantage estimation - # (bool): Enable dynamic sampling for advantage estimation, defaults to False + # (bool): Enable dynamic sampling for advantage estimation (DAPO), defaults to False dynamic_sampling: bool = False + # (str): Metric to filter groups in dynamic sampling: "reward", "acc", etc., defaults to "reward" + dynamic_sampling_metric: str = "reward" + # (int): Maximum number of generation batches for dynamic sampling, <=0 means no limit, defaults to 10 + max_num_gen_batches: int = 10 # (str): Advantage estimator method, defaults to "gae" advantage_estimator: str = "group_norm" @@ -280,7 +284,7 @@ def print_config_summary(self) -> None: # Dynamic Sampling and Advantage Estimation Parameters print("\nDynamic Sampling and Advantage Estimation Parameters:") - for attr in ['dynamic_sampling', 'advantage_estimator']: + for attr in ['dynamic_sampling', 'dynamic_sampling_metric', 'max_num_gen_batches', 'advantage_estimator']: current = getattr(self, attr) default = getattr(default_config, attr) status = "Overridden" if current != default else "Default" diff --git a/lightrft/trainer/advantage_calculator.py b/lightrft/trainer/advantage_calculator.py index 1ce9a8f7..ab1da834 100644 --- a/lightrft/trainer/advantage_calculator.py +++ b/lightrft/trainer/advantage_calculator.py @@ -627,9 +627,12 @@ class GroupNormCalculator(BaseREINFORCECalculator): """ Group normalization calculator (GRPO). - Normalizes rewards within each group and optionally filters degenerate cases. + Normalizes rewards within each group and optionally filters degenerate cases + using dynamic sampling strategy (DAPO). - Reference: GRPO: https://arxiv.org/pdf/2402.03300 + Reference: + - GRPO: https://arxiv.org/pdf/2402.03300 + - DAPO: https://arxiv.org/abs/2503.14476 """ def preprocess_rewards( self, @@ -640,6 +643,10 @@ def preprocess_rewards( """ Preprocess rewards using group normalization with optional dynamic filtering. + Dynamic sampling (DAPO) filters out groups where all samples have the same metric value + (e.g., all rewards are 0 or all are 1), as these groups provide no learning signal. + This is achieved by setting action_mask to all zeros for filtered groups. + :param rewards: Concatenated reward tensor :type rewards: torch.Tensor :param experiences: List of experiences (may be filtered) @@ -652,17 +659,36 @@ def preprocess_rewards( config = self.config n_samples = config.n_samples_per_prompt - # Dynamic sampling filtering + # Dynamic sampling filtering (DAPO) + # Filter out groups where all outputs have the same metric value if config.dynamic_sampling: - step_size = n_samples // config.micro_train_batch_size - for i in range(0, len(experiences), step_size): - chunk = experiences[i:i + step_size] - chunk_rewards = torch.cat([exp.info["reward"] for exp in chunk]) - - # Filter out degenerate cases (all 0s or all 1s) - if torch.all(chunk_rewards == 0) or torch.all(chunk_rewards == 1): - for exp in chunk: - exp.action_mask = torch.zeros_like(exp.action_mask, dtype=torch.bool) + metric = config.dynamic_sampling_metric + + # When micro_rollout_batch_size == n_samples_per_prompt, each experience + # contains all samples for one prompt in batched format + # exp.info["reward"] has shape=[n_samples], representing all samples for that prompt + for exp in experiences: + reward_tensor = exp.info["reward"] # shape=[n_samples] + + # Extract metric values (all samples within this experience/prompt) + if metric == "reward": + metric_values = reward_tensor + elif metric == "acc": + # Use accuracy if available + if "accuracy" in exp.info: + metric_values = exp.info["accuracy"] + else: + # Fallback: treat reward as binary accuracy + metric_values = reward_tensor + else: + # Default to reward + metric_values = reward_tensor + + # Check if all samples have the same metric value (degenerate group) + # This prompt provides no learning signal for relative comparison + if torch.all(metric_values == metric_values[0]): + # Mark this experience for filtering by zeroing out action mask + exp.action_mask = torch.zeros_like(exp.action_mask, dtype=torch.bool) # Group normalization rewards = rewards.reshape(-1, n_samples).to("cuda") diff --git a/lightrft/trainer/ppo_trainer_vl.py b/lightrft/trainer/ppo_trainer_vl.py index 4584129b..25736a1a 100644 --- a/lightrft/trainer/ppo_trainer_vl.py +++ b/lightrft/trainer/ppo_trainer_vl.py @@ -369,6 +369,16 @@ def fit( f"rand_prompts:\n {rand_prompts}\n , rand_images:{rand_images}\n , rand_references:{rand_references}\n, rand_labels:{rand_labels}\n " # noqa ) + # ========== Dynamic Sampling Loop (DAPO) ========== + # When dynamic_sampling is enabled, we may need to generate multiple batches + # to collect enough valid prompts (groups with varying rewards) + num_gen_batches = 0 + target_num_prompts = args.rollout_batch_size + + while True: + num_gen_batches += 1 + + # Generate experiences for current batch for i, experience in enumerate( self.experience_maker.make_experience_list( rand_prompts, @@ -387,12 +397,58 @@ def fit( self.strategy.print( f"collect phase: rand_prompts:\n {rand_prompts[0:2]}\n , rand_images:{rand_images[0:2]}\n , rand_references:{rand_references[0:2]}\n, rand_labels:{rand_labels[0:2]}\n " # noqa ) - # print all - # self.strategy.print( - # f"rand_prompts:\n {rand_prompts}\n , rand_images:{rand_images}\n , rand_references:{rand_references}\n, rand_labels:{rand_labels}\n " # noqa - # ) self.replay_buffer.append(experience) + + # Check if dynamic sampling is enabled + if not self.strategy.config.dynamic_sampling: + # No dynamic sampling, exit after first batch + break + + # Count valid prompts (groups with non-zero action masks after filtering) + n_samples = args.n_samples_per_prompt + num_valid_prompts = 0 + + for i in range(0, len(self.replay_buffer.items), n_samples): + group = self.replay_buffer.items[i:i + n_samples] + # Check if any experience in this group has non-zero action mask + has_valid_actions = any(exp.action_mask.sum() > 0 for exp in group) + if has_valid_actions: + num_valid_prompts += 1 + + if self.strategy.is_rank_0(): + self.strategy.print( + f"Dynamic Sampling: num_valid_prompts={num_valid_prompts}, " + f"target={target_num_prompts}, num_gen_batches={num_gen_batches}" + ) + + # Check if we have enough valid prompts + if num_valid_prompts >= target_num_prompts: + # Trim to exact target size + target_num_experiences = target_num_prompts * n_samples + self.replay_buffer.items = self.replay_buffer.items[:target_num_experiences] + break + + # Check if we've reached the maximum number of generation batches + max_num_gen_batches = self.strategy.config.max_num_gen_batches + if max_num_gen_batches > 0 and num_gen_batches >= max_num_gen_batches: + if self.strategy.is_rank_0(): + self.strategy.print( + f"Warning: Reached max_num_gen_batches={max_num_gen_batches} " + f"with only {num_valid_prompts} valid prompts. Proceeding with available data." + ) + break + + # Need more samples, continue to next batch + # Note: In a real implementation, you would fetch the next batch from dataloader + # For now, we break to avoid infinite loop (this is a simplified implementation) + if self.strategy.is_rank_0(): + self.strategy.print( + f"Warning: Dynamic sampling requires fetching more batches, " + f"but current implementation processes one batch at a time. " + f"Proceeding with {num_valid_prompts} valid prompts." + ) + break self.strategy.report_memory('after replay_buffer ready') From 36ef5c12a3e53ff7eb2eeb8f319821a238dfb4c4 Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Mon, 9 Feb 2026 18:00:13 +0800 Subject: [PATCH 2/3] feature(sunjx): fix bugs --- lightrft/trainer/ppo_trainer_vl.py | 99 +++++++++++++++++++--------- lightrft/trainer/spmd_ppo_trainer.py | 11 ++++ 2 files changed, 79 insertions(+), 31 deletions(-) diff --git a/lightrft/trainer/ppo_trainer_vl.py b/lightrft/trainer/ppo_trainer_vl.py index 25736a1a..06527dd4 100644 --- a/lightrft/trainer/ppo_trainer_vl.py +++ b/lightrft/trainer/ppo_trainer_vl.py @@ -374,31 +374,32 @@ def fit( # to collect enough valid prompts (groups with varying rewards) num_gen_batches = 0 target_num_prompts = args.rollout_batch_size + n_samples = args.n_samples_per_prompt while True: num_gen_batches += 1 # Generate experiences for current batch - for i, experience in enumerate( - self.experience_maker.make_experience_list( - rand_prompts, - rand_images, - all_videos=rand_videos, - all_references=rand_references, - all_labels=rand_labels, - **self.generate_kwargs - ) - ): - if i == 0: - output = self.tokenizer.batch_decode( - experience.sequences[0].unsqueeze(0), skip_special_tokens=True - ) - self.strategy.print("collect phase: experience.sequences w skip_special_tokens: ", output) - self.strategy.print( - f"collect phase: rand_prompts:\n {rand_prompts[0:2]}\n , rand_images:{rand_images[0:2]}\n , rand_references:{rand_references[0:2]}\n, rand_labels:{rand_labels[0:2]}\n " # noqa + for i, experience in enumerate( + self.experience_maker.make_experience_list( + rand_prompts, + rand_images, + all_videos=rand_videos, + all_references=rand_references, + all_labels=rand_labels, + **self.generate_kwargs ) + ): + if i == 0: + output = self.tokenizer.batch_decode( + experience.sequences[0].unsqueeze(0), skip_special_tokens=True + ) + self.strategy.print("collect phase: experience.sequences w skip_special_tokens: ", output) + self.strategy.print( + f"collect phase: rand_prompts:\n {rand_prompts[0:2]}\n , rand_images:{rand_images[0:2]}\n , rand_references:{rand_references[0:2]}\n, rand_labels:{rand_labels[0:2]}\n " # noqa + ) - self.replay_buffer.append(experience) + self.replay_buffer.append(experience) # Check if dynamic sampling is enabled if not self.strategy.config.dynamic_sampling: @@ -406,19 +407,48 @@ def fit( break # Count valid prompts (groups with non-zero action masks after filtering) - n_samples = args.n_samples_per_prompt + # This check happens AFTER all experiences in the batch are generated + # Note: When micro_rollout_batch_size == n_samples_per_prompt, each experience + # contains all samples for one prompt. So we count experiences directly. + total_experiences = len(self.replay_buffer.items) num_valid_prompts = 0 + num_filtered_prompts = 0 - for i in range(0, len(self.replay_buffer.items), n_samples): - group = self.replay_buffer.items[i:i + n_samples] - # Check if any experience in this group has non-zero action mask - has_valid_actions = any(exp.action_mask.sum() > 0 for exp in group) - if has_valid_actions: + # Debug: Check experience structure + if self.strategy.is_rank_0() and total_experiences > 0: + first_exp = self.replay_buffer.items[0] + if hasattr(first_exp, 'action_mask'): + action_mask_shape = first_exp.action_mask.shape + # Safely get reward shape + reward_shape = "N/A" + if hasattr(first_exp, 'info') and first_exp.info is not None: + reward = first_exp.info.get("reward", None) + if reward is not None: + if isinstance(reward, torch.Tensor): + reward_shape = reward.shape + else: + reward_shape = f"scalar({type(reward).__name__})" + self.strategy.print( + f"Debug: total_experiences={total_experiences}, " + f"n_samples={n_samples}, " + f"first_exp.action_mask.shape={action_mask_shape}, " + f"first_exp.info['reward']={reward_shape}" + ) + + # Count valid prompts: each experience corresponds to one prompt + # when micro_rollout_batch_size == n_samples_per_prompt + for exp in self.replay_buffer.items: + # Check if this experience has any valid actions (not all filtered) + if exp.action_mask.sum() > 0: num_valid_prompts += 1 + else: + num_filtered_prompts += 1 if self.strategy.is_rank_0(): self.strategy.print( - f"Dynamic Sampling: num_valid_prompts={num_valid_prompts}, " + f"Dynamic Sampling: total_experiences={total_experiences}, " + f"num_valid_prompts={num_valid_prompts}, " + f"num_filtered_prompts={num_filtered_prompts}, " f"target={target_num_prompts}, num_gen_batches={num_gen_batches}" ) @@ -439,14 +469,13 @@ def fit( ) break - # Need more samples, continue to next batch - # Note: In a real implementation, you would fetch the next batch from dataloader - # For now, we break to avoid infinite loop (this is a simplified implementation) + # Need more samples, but current implementation only processes one batch + # In a full implementation, we would fetch the next batch from dataloader here + # For now, we proceed with what we have if self.strategy.is_rank_0(): self.strategy.print( - f"Warning: Dynamic sampling requires fetching more batches, " - f"but current implementation processes one batch at a time. " - f"Proceeding with {num_valid_prompts} valid prompts." + f"Warning: Dynamic sampling needs more batches, but current implementation " + f"processes one batch at a time. Proceeding with {num_valid_prompts} valid prompts." ) break @@ -767,6 +796,14 @@ def training_step_actor(self, base_action_log_probs = experience.base_action_log_probs if advantages is not None: + # Check if advantages is empty (can happen when dynamic sampling filters all samples) + if advantages.numel() == 0: + self.strategy.print( + "[Warning] Empty advantages tensor detected. This may occur when dynamic sampling " + "filters out all samples in a batch. Skipping this training step." + ) + return {} # Return empty status dict to skip this step + # Log max advantage before clipping for debugging (optional) max_adv = advantages.max().item() if max_adv > 10.0: diff --git a/lightrft/trainer/spmd_ppo_trainer.py b/lightrft/trainer/spmd_ppo_trainer.py index d79a7458..692c98f2 100644 --- a/lightrft/trainer/spmd_ppo_trainer.py +++ b/lightrft/trainer/spmd_ppo_trainer.py @@ -231,6 +231,17 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train sequences, pixel_values, context="pre_training_validation" ) should_skip_local = not is_valid + + # Check for empty advantages (can happen when dynamic sampling filters all samples) + if not should_skip_local and hasattr(experience, 'advantages') and experience.advantages is not None: + if isinstance(experience.advantages, list): + # Packed samples: check if any advantages are empty + if any(adv.numel() == 0 for adv in experience.advantages): + should_skip_local = True + else: + # Single tensor: check if empty + if experience.advantages.numel() == 0: + should_skip_local = True # Step 2: Synchronize skip decision across all ranks via all_reduce # This ensures all ranks agree on whether to skip, preventing execution divergence From 67dbd0d71e84272763df6f87f71c590352eb4ca0 Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Tue, 10 Feb 2026 11:15:23 +0800 Subject: [PATCH 3/3] feature(sunjx): pass code check --- lightrft/trainer/advantage_calculator.py | 8 ++-- lightrft/trainer/ppo_trainer_vl.py | 50 +++++------------------- lightrft/trainer/spmd_ppo_trainer.py | 2 +- 3 files changed, 15 insertions(+), 45 deletions(-) diff --git a/lightrft/trainer/advantage_calculator.py b/lightrft/trainer/advantage_calculator.py index ab1da834..cb3844fc 100644 --- a/lightrft/trainer/advantage_calculator.py +++ b/lightrft/trainer/advantage_calculator.py @@ -630,7 +630,7 @@ class GroupNormCalculator(BaseREINFORCECalculator): Normalizes rewards within each group and optionally filters degenerate cases using dynamic sampling strategy (DAPO). - Reference: + Reference: - GRPO: https://arxiv.org/pdf/2402.03300 - DAPO: https://arxiv.org/abs/2503.14476 """ @@ -663,13 +663,13 @@ def preprocess_rewards( # Filter out groups where all outputs have the same metric value if config.dynamic_sampling: metric = config.dynamic_sampling_metric - + # When micro_rollout_batch_size == n_samples_per_prompt, each experience # contains all samples for one prompt in batched format # exp.info["reward"] has shape=[n_samples], representing all samples for that prompt for exp in experiences: reward_tensor = exp.info["reward"] # shape=[n_samples] - + # Extract metric values (all samples within this experience/prompt) if metric == "reward": metric_values = reward_tensor @@ -683,7 +683,7 @@ def preprocess_rewards( else: # Default to reward metric_values = reward_tensor - + # Check if all samples have the same metric value (degenerate group) # This prompt provides no learning signal for relative comparison if torch.all(metric_values == metric_values[0]): diff --git a/lightrft/trainer/ppo_trainer_vl.py b/lightrft/trainer/ppo_trainer_vl.py index 06527dd4..ba00a3e0 100644 --- a/lightrft/trainer/ppo_trainer_vl.py +++ b/lightrft/trainer/ppo_trainer_vl.py @@ -375,10 +375,10 @@ def fit( num_gen_batches = 0 target_num_prompts = args.rollout_batch_size n_samples = args.n_samples_per_prompt - + while True: num_gen_batches += 1 - + # Generate experiences for current batch for i, experience in enumerate( self.experience_maker.make_experience_list( @@ -400,65 +400,35 @@ def fit( ) self.replay_buffer.append(experience) - + # Check if dynamic sampling is enabled if not self.strategy.config.dynamic_sampling: # No dynamic sampling, exit after first batch break - + # Count valid prompts (groups with non-zero action masks after filtering) # This check happens AFTER all experiences in the batch are generated # Note: When micro_rollout_batch_size == n_samples_per_prompt, each experience # contains all samples for one prompt. So we count experiences directly. - total_experiences = len(self.replay_buffer.items) num_valid_prompts = 0 - num_filtered_prompts = 0 - - # Debug: Check experience structure - if self.strategy.is_rank_0() and total_experiences > 0: - first_exp = self.replay_buffer.items[0] - if hasattr(first_exp, 'action_mask'): - action_mask_shape = first_exp.action_mask.shape - # Safely get reward shape - reward_shape = "N/A" - if hasattr(first_exp, 'info') and first_exp.info is not None: - reward = first_exp.info.get("reward", None) - if reward is not None: - if isinstance(reward, torch.Tensor): - reward_shape = reward.shape - else: - reward_shape = f"scalar({type(reward).__name__})" - self.strategy.print( - f"Debug: total_experiences={total_experiences}, " - f"n_samples={n_samples}, " - f"first_exp.action_mask.shape={action_mask_shape}, " - f"first_exp.info['reward']={reward_shape}" - ) - - # Count valid prompts: each experience corresponds to one prompt - # when micro_rollout_batch_size == n_samples_per_prompt for exp in self.replay_buffer.items: # Check if this experience has any valid actions (not all filtered) if exp.action_mask.sum() > 0: num_valid_prompts += 1 - else: - num_filtered_prompts += 1 - + if self.strategy.is_rank_0(): self.strategy.print( - f"Dynamic Sampling: total_experiences={total_experiences}, " - f"num_valid_prompts={num_valid_prompts}, " - f"num_filtered_prompts={num_filtered_prompts}, " + f"Dynamic Sampling: num_valid_prompts={num_valid_prompts}, " f"target={target_num_prompts}, num_gen_batches={num_gen_batches}" ) - + # Check if we have enough valid prompts if num_valid_prompts >= target_num_prompts: # Trim to exact target size target_num_experiences = target_num_prompts * n_samples self.replay_buffer.items = self.replay_buffer.items[:target_num_experiences] break - + # Check if we've reached the maximum number of generation batches max_num_gen_batches = self.strategy.config.max_num_gen_batches if max_num_gen_batches > 0 and num_gen_batches >= max_num_gen_batches: @@ -468,7 +438,7 @@ def fit( f"with only {num_valid_prompts} valid prompts. Proceeding with available data." ) break - + # Need more samples, but current implementation only processes one batch # In a full implementation, we would fetch the next batch from dataloader here # For now, we proceed with what we have @@ -803,7 +773,7 @@ def training_step_actor(self, "filters out all samples in a batch. Skipping this training step." ) return {} # Return empty status dict to skip this step - + # Log max advantage before clipping for debugging (optional) max_adv = advantages.max().item() if max_adv > 10.0: diff --git a/lightrft/trainer/spmd_ppo_trainer.py b/lightrft/trainer/spmd_ppo_trainer.py index 692c98f2..7b7b5bd6 100644 --- a/lightrft/trainer/spmd_ppo_trainer.py +++ b/lightrft/trainer/spmd_ppo_trainer.py @@ -231,7 +231,7 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train sequences, pixel_values, context="pre_training_validation" ) should_skip_local = not is_valid - + # Check for empty advantages (can happen when dynamic sampling filters all samples) if not should_skip_local and hasattr(experience, 'advantages') and experience.advantages is not None: if isinstance(experience.advantages, list):