From 5331c4e4e59047aa380c397be79668ef6a9b1ef2 Mon Sep 17 00:00:00 2001 From: AltmanD Date: Thu, 22 Jan 2026 15:47:53 +0800 Subject: [PATCH 1/3] Migrate partial rollout feature code --- examples/gsm8k_geo3k/train_colocate.py | 9 + lightrft/trainer/fast_exp_maker_partial.py | 606 +++++++++++++++++++++ lightrft/trainer/spmd_ppo_trainer.py | 224 ++++++++ 3 files changed, 839 insertions(+) create mode 100644 lightrft/trainer/fast_exp_maker_partial.py diff --git a/examples/gsm8k_geo3k/train_colocate.py b/examples/gsm8k_geo3k/train_colocate.py index 20abc358..4d8ccd37 100644 --- a/examples/gsm8k_geo3k/train_colocate.py +++ b/examples/gsm8k_geo3k/train_colocate.py @@ -397,6 +397,10 @@ def train(args): overlong_buffer_len=args.overlong_buffer_len, overlong_buffer_penalty_factor=args.overlong_buffer_penalty_factor, print_replay_buffer_stats=args.print_replay_buffer_stats, + # partial rollout + use_partial=args.use_partial, + partial_percent=args.partial_percent, + max_budget=args.max_budget, ) trainer.fit(args, prompts_dataloader=prompts_dataloader, pretrain_dataloader=pretrain_dataloader, eval_dataloader=eval_dataloader, consumed_samples=0, num_update_steps_per_episodes=num_update_steps_per_episodes) @@ -608,6 +612,11 @@ def train(args): # High-entropy token filtering (from "Beyond the 80/20 Rule" paper) parser.add_argument("--high_entropy_token_ratio", type=float, default=0.0, help="Ratio of high-entropy tokens to use for gradient updates (0.0 means use all tokens, 0.2 means use top 20% highest entropy tokens). Common value when enabled: 0.2. 
Based on 'Beyond the 80/20 Rule: High-Entropy Minority Tokens Drive Effective Reinforcement Learning for LLM Reasoning' (https://arxiv.org/abs/2506.01939)") + # Partial Rollout + parser.add_argument("--use_partial", action="store_true", default=False, help="whether to use the partial rollout trainer") + parser.add_argument("--partial_percent", type=float, default=0.7, help="partial rollout percent") + parser.add_argument("--max_budget", type=float, default=1024, help="partial rollout max_budget") + add_arguments(parser) args = parser.parse_args() diff --git a/lightrft/trainer/fast_exp_maker_partial.py b/lightrft/trainer/fast_exp_maker_partial.py new file mode 100644 index 00000000..b9e7cab7 --- /dev/null +++ b/lightrft/trainer/fast_exp_maker_partial.py @@ -0,0 +1,606 @@ +""" +PartialFastExperienceMaker – FastExperienceMaker with partial rollout and token‑budget regeneration. + +This subclass adds two key features: + 1. Partial rollout: only a fraction (partial_percent) of the total rollout batch is generated + in each call; the rest is kept in buffers. + 2. Token‑budget regeneration: samples whose generation reaches max_token_budget are flagged + and can be regenerated later (e.g., for continuing long‑form tasks). + +The class reuses the parent's infrastructure (MultimodalDataProcessor, RewardComputationEngine, +etc.) and only overrides the methods that implement the partial‑rollout logic. +""" + +from typing import List, Optional, Union, Tuple, Dict, Any +import os +import time +from copy import deepcopy + +import torch +import torch.distributed as dist +from vllm import SamplingParams +from easydict import EasyDict + +from openrlhf.trainer.ppo_utils.experience_maker import Experience, Samples +from openrlhf.trainer.ppo_utils.experience_maker_vl import SamplesVL +from lightrft.trainer.fast_exp_maker import FastExperienceMaker + + +class PartialFastExperienceMaker(FastExperienceMaker): + """ + FastExperienceMaker with partial rollout and token‑budget regeneration. 
+ + Args: + partial_percent (float): fraction of the rollout batch to generate in one call. + max_token_budget (int): maximum allowed generation length before regeneration. + packing_samples (bool): whether to pack samples (inherited). + processor: multimodal processor (inherited). + *args, **kwargs: passed to parent. + """ + + def __init__( + self, + *args, + partial_percent: float = 0.7, + max_token_budget: int = 1024, + packing_samples: bool = False, + processor=None, + **kwargs + ): + super().__init__(*args, packing_samples=packing_samples, processor=processor, **kwargs) + self.partial_percent = partial_percent + self.max_token_budget = max_token_budget + + # Buffers for regeneration (regen) and non‑regeneration (noregen) samples. + # Each buffer is a dict mapping field names to lists of data. + self.regen_buffer: Dict[str, List] = {} + self.noregen_buffer: Dict[str, List] = {} + fields = [ + 'output', 'labels', 'prompts', 'images', 'images_num', + 'images_pixel_values', 'images_grid_thw', 'image_flags', 'references' + ] + for field in fields: + self.regen_buffer[field] = [] + self.noregen_buffer[field] = [] + + # Placeholders for batch‑size parameters (set by need_new_prompts) + self.rollout_batch_size = None + self.micro_rollout_batch_size = None + + def need_new_prompts(self, rollout_batch_size: int, micro_rollout_batch_size: int) -> bool: + """ + Check whether the buffers contain enough data to make a full experience batch. + + Returns: + True if new prompts need to be fetched (i.e., buffers are below the partial threshold). 
+ """ + self.rollout_batch_size = rollout_batch_size + self.micro_rollout_batch_size = micro_rollout_batch_size + + # Total micro‑batches needed for a full rollout + total_micro = rollout_batch_size // micro_rollout_batch_size + # Micro‑batches we want to generate in one call + target_micro = int(self.partial_percent * total_micro) + required_samples = target_micro * micro_rollout_batch_size + + # Count how many samples are already stored in both buffers + total_samples = len(self.noregen_buffer.get('output', [])) + len(self.regen_buffer.get('output', [])) + return total_samples < required_samples + + @torch.no_grad() + def generate_samples( + self, + all_prompts: List[str], + all_images: Optional[List] = None, + images_num: Optional[List[int]] = None, + all_references: Optional[List[str]] = None, + all_labels: Optional[List] = None, + **generate_kwargs + ) -> List[Samples]: + """ + Generate samples using the parent's pipeline, but only a partial fraction. + + The method: + 1. If new inputs are provided, generate them with the parent's generate_samples. + 2. Split the generated outputs into regeneration and non‑regeneration buffers. + 3. Draw from the buffers to produce the requested number of samples (partial_percent). + 4. If the noregen buffer is insufficient, regenerate some samples from the regen buffer. + + Returns: + List of Samples (or SamplesVL) ready for experience making. 
+ """ + args = self.strategy.args + is_multimodal = all_images is not None + internvl = "internvl" in self.actor.pretrain_or_model.lower() if is_multimodal else False + + # -------------------------------------------------------------------- + # Step 1: Generate new samples if inputs are provided + # -------------------------------------------------------------------- + if all_prompts is not None: + # Replicate the generation logic from fast_exp_maker_partial.py + # Prepare sampling parameters + if args.engine_type == "vllm": + sampling_params = SamplingParams( + temperature=generate_kwargs.get("temperature", 1.0), + top_p=generate_kwargs.get("top_p", 1.0), + top_k=generate_kwargs.get("top_k", -1), + max_tokens=self.max_token_budget, # use token budget + min_tokens=generate_kwargs.get("min_new_tokens", 1), + skip_special_tokens=generate_kwargs.get("skip_special_tokens", False), + include_stop_str_in_output=True, + ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1", + ) + elif args.engine_type == "sglang": + sampling_params = dict( + n=1, + temperature=generate_kwargs.get("temperature", 1.0), + top_p=generate_kwargs.get("top_p", 1.0), + top_k=generate_kwargs.get("top_k", -1), + max_new_tokens=self.max_token_budget, + presence_penalty=0.0, + frequency_penalty=0.0, + repetition_penalty=1.0, + skip_special_tokens=generate_kwargs.get("skip_special_tokens", False), + spaces_between_special_tokens=True, + ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1", + ) + else: + raise ValueError(f"Unsupported backend: {args.engine_type}") + + # Expand labels + expanded_labels = sum([[label] * args.n_samples_per_prompt for label in all_labels], []) if all_labels else [] + + # Process multimodal data + if is_multimodal: + processed = self._process_multimodal_data( + all_prompts=all_prompts, + all_images=all_images, + is_internvl=internvl, + all_references=all_references, + images_num=images_num + ) + prompt_token_ids = processed["all_prompt_token_ids"] + prompts = 
processed["all_prompts"] + images = processed["all_images"] + images_num = processed["all_images_num"] + pixel_values = processed["all_images_pixel_values"] + grid_thw = processed["all_images_grid_thw"] + image_flags = processed["all_image_flags"] + references = processed["all_references"] + else: + tokenized = self.tokenize_fn(all_prompts, self.prompt_max_len, padding=False) + prompt_token_ids = tokenized["input_ids"] + prompt_token_ids = sum([[token_ids] * args.n_samples_per_prompt for token_ids in prompt_token_ids], []) + prompts = all_prompts * args.n_samples_per_prompt + images = None + references = all_references * args.n_samples_per_prompt if all_references else None + + # Generate outputs via inference engine + outputs = self.strategy.gather_and_generate( + sampling_params=sampling_params, + all_prompt_token_ids=prompt_token_ids, + all_prompts=prompts if is_multimodal else None, + all_images=images if is_multimodal else None, + sleep_engine=False, + images_num=images_num if is_multimodal else None, + ) + + # Process outputs in micro-batches and store in buffers + for i in range(0, len(outputs), args.micro_rollout_batch_size): + batch_slice = slice(i, i + args.micro_rollout_batch_size) + output_batch = outputs[batch_slice] + labels_batch = expanded_labels[batch_slice] if expanded_labels else [] + prompts_batch = prompts[batch_slice] + images_batch = images[batch_slice] if images else None + images_num_batch = images_num[batch_slice] if images_num else None + references_batch = references[batch_slice] if references else None + + # Check if regeneration is needed + needs_regen = any(len(out.output_token_ids) >= self.max_token_budget for out in output_batch) + buffer_type = "regen" if needs_regen else "noregen" + + # Add to appropriate buffer + self._add_to_buffer(buffer_type, "output", output_batch) + self._add_to_buffer(buffer_type, "labels", labels_batch) + self._add_to_buffer(buffer_type, "prompts", prompts_batch) + if images_batch is not None: + 
self._add_to_buffer(buffer_type, "images", images_batch) + if images_num_batch is not None: + self._add_to_buffer(buffer_type, "images_num", images_num_batch) + if references_batch is not None: + self._add_to_buffer(buffer_type, "references", references_batch) + + if is_multimodal: + self._add_to_buffer(buffer_type, "image_flags", image_flags[batch_slice]) + # Handle image tensors + grid_batch = grid_thw[batch_slice] + self._add_to_buffer(buffer_type, "images_grid_thw", grid_batch) + # Calculate pixel values slice + patch_start = sum(g[0] * g[1] * g[2] for g in grid_thw[:i]) + patch_end = patch_start + sum(g[0] * g[1] * g[2] for g in grid_batch) + self._add_to_buffer(buffer_type, "images_pixel_values", pixel_values[patch_start:patch_end]) + + # -------------------------------------------------------------------- + # Step 2: Determine how many micro‑batches we need to return + # -------------------------------------------------------------------- + total_micro = self.rollout_batch_size // self.micro_rollout_batch_size + target_micro = int(self.partial_percent * total_micro) + + # How many micro‑batches are already available in the noregen buffer? 
+ noregen_micro = len(self.noregen_buffer['output']) // self.micro_rollout_batch_size + if noregen_micro >= target_micro: + # Enough noregen samples – just take them + samples_data = self._get_from_buffer('noregen', target_micro * self.micro_rollout_batch_size) + else: + # Take all noregen samples and supplement with regenerated ones + samples_needed = target_micro - noregen_micro + noregen_data = self._get_from_buffer('noregen', noregen_micro * self.micro_rollout_batch_size) + regen_data = self._regenerate_from_buffer(samples_needed * self.micro_rollout_batch_size, **generate_kwargs) + samples_data = self._merge_data(noregen_data, regen_data) + + # -------------------------------------------------------------------- + # Step 3: Convert the collected data back to Samples objects + # -------------------------------------------------------------------- + samples_list = self._generate_sample_list( + samples_data, + is_multimodal, + internvl, + **generate_kwargs + ) + self.strategy.maybe_sleep_inference_engine() + if all_prompts is None: + return samples_list + else: + # Return tuple with samples_list and images_num, consistent with fast_exp_maker_partial.py + images_num_list = samples_data.get("images_num") + return samples_list, images_num_list + + def _process_multimodal_data(self, all_prompts, all_images, is_internvl, all_references, images_num): + """Wrapper around parent's multimodal_processor.process_multimodal_batch.""" + if self.multimodal_processor is None: + raise ValueError("Multimodal processor not initialized.") + return self.multimodal_processor.process_multimodal_batch( + all_prompts=all_prompts, + all_images=all_images, + all_references=all_references, + images_num=images_num, + n_samples_per_prompt=self.strategy.config.n_samples_per_prompt, + is_internvl=is_internvl, + ) + + def _add_to_buffer(self, buffer_type: str, data_name: str, data): + """Add data to specified buffer.""" + buffer = self.regen_buffer if buffer_type == 'regen' else 
self.noregen_buffer + if data_name not in buffer: + buffer[data_name] = [] + if isinstance(data, torch.Tensor): + buffer[data_name].append(data) + else: + buffer[data_name].extend(data if isinstance(data, list) else [data]) + + def _get_from_buffer(self, buffer_type: str, count: Optional[int]): + """Retrieve data from buffer, optionally limiting the amount.""" + buffer = self.regen_buffer if buffer_type == 'regen' else self.noregen_buffer + result = {} + for key, lst in buffer.items(): + if count is None: + result[key] = lst.copy() + buffer[key] = [] + else: + result[key] = lst[:count] + buffer[key] = lst[count:] + return result + + @torch.no_grad() + def _regenerate_from_buffer(self, num_needed: int, **kwargs) -> dict: + """Regenerate outputs for samples that reached token budget.""" + args = self.strategy.args + + # Get data from regeneration buffer + regen_data = self._get_from_buffer("regen", num_needed) + if not regen_data.get("output"): + return {} + + # Identify indices needing regeneration + regen_indices = [ + i for i, output in enumerate(regen_data["output"]) + if len(output.output_token_ids) >= self.max_token_budget + ] + + if not regen_indices: + return regen_data + + # Prepare regeneration inputs + regen_outputs = [regen_data["output"][i] for i in regen_indices] + regen_tokens = [output.output_token_ids for output in regen_outputs] + decoded_outputs = self.tokenizer.batch_decode(regen_tokens, skip_special_tokens=False) + + # Create new inputs by combining original prompts and partial outputs + new_inputs = [ + prompt + output + for prompt, output in zip( + [regen_data["prompts"][i] for i in regen_indices], + decoded_outputs + ) + ] + + # Prepare sampling parameters + if args.engine_type == "vllm": + sampling_params = SamplingParams( + temperature=kwargs.get("temperature", 1.0), + top_p=kwargs.get("top_p", 1.0), + top_k=kwargs.get("top_k", -1), + max_tokens=kwargs.get("max_new_tokens", 1024), + min_tokens=kwargs.get("min_new_tokens", 1), + 
skip_special_tokens=kwargs.get("skip_special_tokens", False), + include_stop_str_in_output=True, + ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1", + ) + elif args.engine_type == "sglang": + sampling_params = dict( + n=1, + temperature=kwargs.get("temperature", 1.0), + top_p=kwargs.get("top_p", 1.0), + top_k=kwargs.get("top_k", -1), + max_new_tokens=kwargs.get("max_new_tokens", 1024), + presence_penalty=0.0, + frequency_penalty=0.0, + repetition_penalty=1.0, + skip_special_tokens=kwargs.get("skip_special_tokens", False), + spaces_between_special_tokens=True, + ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1", + ) + else: + raise ValueError(f"Unsupported backend: {args.engine_type}") + + # Build inputs and regenerate using the same pattern as fast_exp_maker_partial.py + # First, check if multimodal + is_multimodal = regen_data.get("images") is not None and len(regen_data["images"]) > 0 + if is_multimodal: + # Use strategy._build_multimodal_inputs + inputs = self.strategy._build_multimodal_inputs( + all_prompts=new_inputs, + all_images=[regen_data["images"][i] for i in regen_indices], + images_num=[regen_data["images_num"][i] for i in regen_indices] + ) + # Use engine_generate_local for multimodal regeneration + regenerated = self.strategy.engine_generate_local( + sampling_params=sampling_params, + prompt_token_ids=None, + multi_modal_inputs=inputs, + ) + else: + # For text-only, we can reuse parent's generate_samples but need raw outputs. + # Instead, we can directly call strategy.gather_and_generate with tokenized inputs. + # Tokenize new prompts + tokenized = self.tokenize_fn(new_inputs, self.prompt_max_len, padding=False) + prompt_token_ids = tokenized["input_ids"] + # Expand by n_samples_per_prompt (should be 1 for regeneration?) + # In partial rollout, each sample is already expanded, so we assume n_samples_per_prompt=1. 
+ # Use strategy.gather_and_generate + regenerated = self.strategy.gather_and_generate( + sampling_params=sampling_params, + all_prompt_token_ids=prompt_token_ids, + all_prompts=None, + all_images=None, + sleep_engine=False, + images_num=None, + ) + + # Update regenerated outputs in regen_data + for idx, new_output in zip(regen_indices, regenerated): + regen_data["output"][idx] = new_output + + return regen_data + + def _merge_data(self, data1: Dict[str, List], data2: Dict[str, List]) -> Dict[str, List]: + """Merge two data dictionaries, concatenating lists or tensors.""" + merged = {} + for key in set(data1.keys()) | set(data2.keys()): + val1 = data1.get(key, []) + val2 = data2.get(key, []) + if isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor): + merged[key] = torch.cat([val1, val2]) + elif isinstance(val1, list) and isinstance(val2, list): + merged[key] = val1 + val2 + else: + merged[key] = val1 if val1 else val2 + return merged + + def _generate_sample_list( + self, + samples_data: Dict[str, List], + is_multimodal: bool, + internvl: bool, + **kwargs + ) -> List[Samples]: + """Convert buffered data into a list of Samples.""" + args = self.strategy.args + samples_list = [] + gen_max_len, gen_min_len = 0, 102400000 + index_pixel_patch = 0 + image_start_idx = 0 + + all_outputs = samples_data.get("output", []) + all_labels = samples_data.get("labels", []) + all_prompts = samples_data.get("prompts", []) + all_images = samples_data.get("images", []) + all_images_num = samples_data.get("images_num", []) + all_images_pixel_values = samples_data.get("images_pixel_values", []) + all_images_grid_thw = samples_data.get("images_grid_thw", []) + all_image_flags = samples_data.get("image_flags", []) + all_references = samples_data.get("references", []) + + for i in range(0, len(all_outputs), args.micro_rollout_batch_size): + outputs = all_outputs[i: i + args.micro_rollout_batch_size] + prompts = all_prompts[i: i + args.micro_rollout_batch_size] + if all_images: 
+ assert all_images_num is not None + rollout_image_num = sum(all_images_num[i: i + args.micro_rollout_batch_size]) + images_grid_thw = all_images_grid_thw[image_start_idx: image_start_idx + rollout_image_num] + raw_images = all_images[image_start_idx: image_start_idx + rollout_image_num] + image_start_idx += rollout_image_num + if all_references: + references = all_references[i: i + args.micro_rollout_batch_size] + labels = all_labels[i: i + args.micro_rollout_batch_size] + + if not self.packing_samples: + # Build unpacked samples + max_input_len, max_output_len = 0, 0 + for output in outputs: + max_input_len = max(max_input_len, len(output.prompt_token_ids)) + max_output_len = max(max_output_len, len(output.output_token_ids)) + + pad_token_id, eos_token_id = self.tokenizer.pad_token_id, self.tokenizer.eos_token_id + sequences = [] + pixel_values = [] + image_grid_thw_list = [] + image_flags = [] + images_grid_id = 0 + for j in range(len(outputs)): + output = outputs[j] + input_len = len(output.prompt_token_ids) + input_ids = [pad_token_id] * (max_input_len - input_len) + list(output.prompt_token_ids) + output_len = len(output.output_token_ids) + output_ids = list(output.output_token_ids) + [pad_token_id] * (max_output_len - output_len) + # split pixel_patch + if all_images: + image_num = all_images_num[i + j] + for image_id in range(0, image_num): + images_grid = images_grid_thw[images_grid_id + image_id] + if internvl: + num_patch = images_grid if isinstance(images_grid, int) else images_grid.sum().item() + _image_flags = all_image_flags[index_pixel_patch: index_pixel_patch + num_patch] + image_flags.append(_image_flags) + image_grid_thw_list.append(torch.tensor([1, 1, num_patch]).unsqueeze(0)) + else: + num_patch = images_grid[0] * images_grid[1] * images_grid[2] + image_grid_thw_list.append(images_grid.clone().unsqueeze(0)) + images_pixel_value = all_images_pixel_values[index_pixel_patch: index_pixel_patch + num_patch] + 
pixel_values.append(images_pixel_value.clone()) + index_pixel_patch += num_patch + images_grid_id += image_num + sequences.append(input_ids + output_ids) + + sequences = torch.tensor(sequences) + sequences, attention_mask, action_mask = self.actor.process_sequences( + sequences, max_input_len, eos_token_id, pad_token_id + ) + sequences = sequences.to("cuda") + attention_mask = attention_mask.to("cuda") + action_mask = action_mask.to("cuda") + if not all_images: + samples_list.append( + Samples( + sequences=sequences, + attention_mask=attention_mask, + action_mask=action_mask, + num_actions=action_mask.size(1), + packed_seq_lens=None, + response_length=action_mask.float().sum(dim=-1), + total_length=attention_mask.float().sum(dim=-1), + prompts=prompts, + labels=labels, + pad_len=None, + ) + ) + else: + if internvl: + pixel_values_intern = torch.cat(pixel_values, dim=0).to("cuda") if pixel_values else None + pixel_values = None + else: + pixel_values = torch.cat(pixel_values, dim=0).to("cuda") if pixel_values else None + pixel_values_intern = None + samples_list.append( + SamplesVL( + sequences=sequences, + attention_mask=attention_mask, + action_mask=action_mask, + image_grid_thws=torch.cat(image_grid_thw_list, dim=0).to("cuda") if not internvl else None, + raw_images=raw_images, + pixel_values=pixel_values, + pixel_values_intern=pixel_values_intern, + image_flags=torch.cat(image_flags, dim=0).to("cuda") if internvl else None, + num_actions=action_mask.size(1), + packed_seq_lens=None, + response_length=action_mask.float().sum(dim=-1), + total_length=attention_mask.float().sum(dim=-1), + references=references, + labels=labels, + prompts=prompts, + ) + ) + else: + # Packed samples (not supporting VLM yet) + pad_token_id, eos_token_id = self.tokenizer.pad_token_id, self.tokenizer.eos_token_id + sequences = [] + packed_seq_lens = [] + attention_mask = [] + num_actions = [] + for idx, output in enumerate(outputs): + input_len = len(output.prompt_token_ids) + output_len 
= len(output.output_token_ids) + packed_seq_lens.append(input_len + output_len) + sequences.extend(output.prompt_token_ids + list(output.output_token_ids)) + attention_mask.extend([idx + 1] * (input_len + output_len)) + num_actions.append(max(1, output_len)) + gen_max_len = max(gen_max_len, output_len) + gen_min_len = min(gen_min_len, output_len) + + sequences = torch.tensor(sequences, device="cuda").unsqueeze(0) + attention_mask = torch.tensor(attention_mask, device="cuda").unsqueeze(0) + action_mask = None + response_length = torch.tensor(num_actions, device="cuda", dtype=torch.float) + total_length = torch.tensor(packed_seq_lens, device="cuda", dtype=torch.float) + samples_list.append( + Samples( + sequences=sequences, + attention_mask=attention_mask, + action_mask=None, + num_actions=num_actions, + packed_seq_lens=packed_seq_lens, + response_length=response_length, + total_length=total_length, + prompts=prompts, + labels=labels, + pad_len=None, + ) + ) + + if dist.get_rank(self.backend_mp_group) == 0: + print(f"*** response_length {gen_max_len=}, {gen_min_len=}") + + return samples_list + + def process_experiences(self, experiences: List[Experience]) -> Tuple[List[Experience], List[torch.Tensor]]: + """ + Process experiences (reward shaping for partial rollout). + + This method overrides the parent's _process_experiences to handle + advantage estimators that expect a different group size (partial_percent). 
+ """ + args = self.strategy.args + if args.advantage_estimator == "rloo": + rewards = torch.cat([exp.info["reward"] for exp in experiences]) + rewards = rewards.reshape(-1, args.n_samples_per_prompt) + baseline = (rewards.sum(-1, keepdim=True) - rewards) / (args.n_samples_per_prompt - 1) + rewards = rewards - baseline + rewards = rewards.flatten().chunk(len(experiences)) + return experiences, rewards + elif args.advantage_estimator in ["grpo", "group_norm"]: + # Adjust group size according to partial rollout + group_size = int(self.partial_percent * self.rollout_batch_size // args.micro_rollout_batch_size) + rewards = torch.cat([exp.info["reward"] for exp in experiences]) + rewards = rewards.reshape(-1, group_size) + baseline = rewards.mean(-1, keepdim=True) + rewards = (rewards - baseline) / (rewards.std(1, keepdim=True) + 1e-8) + rewards = rewards.flatten().chunk(len(experiences)) + return experiences, rewards + elif args.advantage_estimator == "reinforce_baseline": + rewards = torch.cat([exp.info["reward"] for exp in experiences]) + rewards = rewards.reshape(-1, args.n_samples_per_prompt).to(device="cuda") + rewards = rewards - rewards.mean(-1, keepdim=True) + rewards = rewards.reshape(-1).to(device="cpu").chunk(len(experiences)) + return experiences, rewards + else: + raise ValueError(f"Unhandled advantage_estimator: {args.advantage_estimator}") diff --git a/lightrft/trainer/spmd_ppo_trainer.py b/lightrft/trainer/spmd_ppo_trainer.py index 33e85d5b..b2fd5cc2 100644 --- a/lightrft/trainer/spmd_ppo_trainer.py +++ b/lightrft/trainer/spmd_ppo_trainer.py @@ -25,6 +25,7 @@ from lightrft.trainer import PPOTrainer, PPOTrainerVL from lightrft.trainer.fast_exp_maker import FastExperienceMaker +from lightrft.trainer.fast_exp_maker_partial import PartialFastExperienceMaker from lightrft.utils.trajectory_saver import create_trajectory_saver from lightrft.trainer.replay_buffer import make_experience_batch @@ -659,3 +660,226 @@ def __init__( # Then initialize our base class 
assert "processor" in kwargs and kwargs["processor"] is not None, "processor is required for SPMDPPOTrainerVL" SPMDPPOTrainerBase.__init__(self, *args, VLM=True, **kwargs) + if getattr(self.args, 'use_partial', False): + # Replace experience maker with partial version + self.experience_maker = PartialFastExperienceMaker( + self.actor, + self.critic, + self.reward_model, + self.initial_model, + self.tokenizer, + self.prompt_max_len, + self.kl_ctl, + self.strategy, + self.remote_rm_url, + self.reward_fn, + self.reward_fn_label_map, + self.reward_recipe, + packing_samples=self.packing_samples, + processor=self._processor, + partial_percent=getattr(self.args, "partial_percent", 0.7), + max_token_budget=getattr(self.args, "max_token_budget", 1024), + ) + + def _make_experience_iterator(self, dataloader, use_partial): + """ + Create an iterator that yields batches of experiences. + + Args: + dataloader: The dataloader providing prompts, images, references, and labels. + use_partial: Whether to use partial rollout logic. + + Yields: + List[Experience]: A list of experiences for each training step. 
+ """ + if use_partial: + # Partial rollout logic + dataloader_iter = iter(dataloader) + while True: + # Generate experiences either from new prompts or cached ones + if self.experience_maker.need_new_prompts(self.args.rollout_batch_size, self.micro_rollout_batch_size): + try: + # Get next batch of prompts, images, references, and labels + rand_prompts, rand_images, rand_references, rand_labels = next(dataloader_iter) + except StopIteration: + # End of epoch reached + break + + # Generate experiences from new prompts + experiences = self.experience_maker.make_experience_list( + rand_prompts, rand_images, rand_references, rand_labels, **self.generate_kwargs + ) + else: + # Generate experiences from cached prompts + experiences = self.experience_maker.make_experience_list( + None, None, None, None, **self.generate_kwargs + ) + yield experiences + else: + # Non-partial rollout logic + for rand_prompts, rand_images, rand_references, rand_labels in dataloader: + experiences = self.experience_maker.make_experience_list( + rand_prompts, rand_images, rand_references, rand_labels, **self.generate_kwargs + ) + yield experiences + + def _process_experiences_and_train(self, experiences, steps): + """ + Process a batch of experiences: add to replay buffer, train, and update metrics. + + Args: + experiences: List of Experience objects. + steps: Current step counter. + + Returns: + dict: Training status metrics. 
+ """ + # Add experiences to replay buffer + for i, experience in enumerate(experiences): + if i == 0: + # Decode first experience for debugging/monitoring + output = self.tokenizer.batch_decode( + experience.sequences[0].unsqueeze(0), skip_special_tokens=True + ) + self.replay_buffer.append(experience) + + # Report memory usage after replay buffer is filled + self.strategy.report_memory('after replay_buffer ready') + + # Normalize advantages if not using group normalization + if self.args.advantage_estimator != "group_norm": + self.replay_buffer.normalize("advantages", self.strategy) + + # Execute training phase + self.strategy.report_memory('before train') + status = self.ppo_train(steps) + self.strategy.report_memory('before clear buffer') + + # Clear replay buffer for next iteration + self.replay_buffer.clear() + self.strategy.report_memory('after train') + + # Update KL control coefficient + if "kl" in status: + self.kl_ctl.update(status["kl"], self.args.rollout_batch_size * self.args.n_samples_per_prompt) + + return status + + def fit( + self, + args, + prompts_dataloader, + pretrain_dataloader, + consumed_samples=0, + num_update_steps_per_episodes=1, + ) -> None: + """ + Execute the main training loop for vision-language models using SPMD PPO. + + This method orchestrates the complete training process including: + 1. Rollout phase: Generate experiences by interacting with the environment + 2. Training phase: Update actor and critic models using collected experiences + 3. Evaluation and checkpointing: Save models and logs at specified intervals + + The training loop follows the standard PPO pattern with distributed data + parallelism for efficient multi-device training. It handles both + image-text prompts and pre-training data for mixed objective optimization. + + :param args: Training configuration arguments containing hyperparameters like num_episodes, rollout_batch_size, n_samples_per_prompt, etc. 
+ :type args: argparse.Namespace + :param prompts_dataloader: DataLoader providing image-text prompts for training + :type prompts_dataloader: torch.utils.data.DataLoader + :param pretrain_dataloader: DataLoader providing pre-training data for supervised fine-tuning + :type pretrain_dataloader: torch.utils.data.DataLoader + :param consumed_samples: Number of samples already processed (for resuming training) + :type consumed_samples: int + :param num_update_steps_per_episodes: Number of update steps per training episode + :type num_update_steps_per_episodes: int + :return: None + + Example:: + + trainer.fit( + args=training_args, + prompts_dataloader=image_prompt_loader, + pretrain_dataloader=pretrain_loader, + consumed_samples=0, + num_update_steps_per_episodes=4 + ) + + .. note:: Distributed Training + This method handles distributed data parallelism automatically using + DistributedSampler when available. Each rank processes a different + subset of the data. + + .. warning:: Memory Management + The method includes explicit memory reporting and cache clearing + to manage GPU memory efficiently during training. 
+ """ + # Calculate number of rollouts per episode based on batch sizes and epochs + num_rollouts_per_episodes = ( + num_update_steps_per_episodes + * args.train_batch_size + // args.max_epochs + // args.rollout_batch_size + // args.n_samples_per_prompt + ) + + # Configure evaluation and checkpointing intervals + # If eval_steps is -1, evaluate once per epoch + if args.eval_steps == -1: + args.eval_steps = num_rollouts_per_episodes + # If save_steps is -1, disable checkpoint saving + if args.save_steps == -1: + args.save_steps = float("inf") + + # Store data loaders for access throughout training + self.prompts_dataloader = prompts_dataloader + self.pretrain_dataloader = pretrain_dataloader + + # Calculate starting step and episode for resuming training + steps = consumed_samples // args.rollout_batch_size + 1 + start_episode = consumed_samples // args.rollout_batch_size // num_rollouts_per_episodes + consumed_samples = consumed_samples % (num_rollouts_per_episodes * args.rollout_batch_size) + + # Determine if using partial rollout + use_partial = getattr(self.args, 'use_partial', False) + + # Main training loop over episodes + for episode in range(start_episode, args.num_episodes): + # Configure distributed sampler for current episode + if isinstance(self.prompts_dataloader.sampler, DistributedSampler): + self.prompts_dataloader.sampler.set_epoch( + episode, consumed_samples=0 if episode > start_episode else consumed_samples + ) + + # Progress bar for monitoring training progress + pbar = tqdm( + range(self.prompts_dataloader.__len__()), + desc=f"Episode [{episode + 1}/{args.num_episodes}]", + disable=not self.strategy.is_rank_0(), + ) + + # Create experience iterator + experience_iterator = self._make_experience_iterator(self.prompts_dataloader, use_partial) + + for experiences in experience_iterator: + # Process experiences and perform training step + status = self._process_experiences_and_train(experiences, steps) + + # Update progress bar with training status 
+ pbar.set_postfix(status) + + # Save logs and checkpoints at appropriate intervals + client_states = {"consumed_samples": steps * args.rollout_batch_size} + self.save_logs_and_checkpoints(args, steps, pbar, status, client_states) + + # Update step counter and progress bar + pbar.update() + steps = steps + 1 + + # Clean up monitoring tools + if self._wandb is not None and self.strategy.is_rank_0(): + self._wandb.finish() + if self._tensorboard is not None and self.strategy.is_rank_0(): + self._tensorboard.close() From 9e22a3d0dc1672e3cc20439e13d7e0bee60b98fe Mon Sep 17 00:00:00 2001 From: AltmanD Date: Thu, 5 Feb 2026 14:18:49 +0800 Subject: [PATCH 2/3] fix the comment format --- lightrft/trainer/fast_exp_maker_partial.py | 216 ++++++++++++++++----- lightrft/trainer/spmd_ppo_trainer.py | 39 ++-- 2 files changed, 192 insertions(+), 63 deletions(-) diff --git a/lightrft/trainer/fast_exp_maker_partial.py b/lightrft/trainer/fast_exp_maker_partial.py index b9e7cab7..f43c84bd 100644 --- a/lightrft/trainer/fast_exp_maker_partial.py +++ b/lightrft/trainer/fast_exp_maker_partial.py @@ -1,14 +1,50 @@ """ -PartialFastExperienceMaker – FastExperienceMaker with partial rollout and token‑budget regeneration. +PartialFastExperienceMaker Module – FastExperienceMaker with Partial Rollout and Token‑Budget Regeneration. -This subclass adds two key features: - 1. Partial rollout: only a fraction (partial_percent) of the total rollout batch is generated - in each call; the rest is kept in buffers. - 2. Token‑budget regeneration: samples whose generation reaches max_token_budget are flagged - and can be regenerated later (e.g., for continuing long‑form tasks). +This module extends FastExperienceMaker to support incremental rollout and controlled generation +for long‑form or high‑cost tasks. It introduces two core mechanisms: + 1. 
Partial rollout: only a fraction (partial_percent) of the total rollout batch is generated + in each call; the remaining samples are kept in buffers for subsequent steps, reducing + per‑iteration latency and enabling smoother pipeline scheduling. + 2. Token‑budget regeneration: samples that reach a predefined token budget (max_token_budget) + are flagged and can be regenerated later, allowing continuation of long‑form generation + without discarding already‑produced content. The class reuses the parent's infrastructure (MultimodalDataProcessor, RewardComputationEngine, etc.) and only overrides the methods that implement the partial‑rollout logic. + +Implementation Overview: + - The rollout batch is split into "regeneration" and "non‑regeneration" buffers based on + whether a sample has exhausted its token budget. + - The method `need_new_prompts` determines whether fresh prompts are required to fill the + partial batch. + - Generation is performed via the parent's inference engine (VLLM/SGLang), but outputs are + post‑processed to respect the token budget and partial fraction. + - Advantage estimation methods (RLOO, Group Norm, REINFORCE) are adjusted to account for the + reduced group size introduced by partial rollout. + +Key Features: + - Partial rollout with configurable fraction (partial_percent) for reduced iteration latency + - Token‑budget regeneration for long‑form continuation (max_token_budget) + - Buffered sample management (regen_buffer, noregen_buffer) for stateful rollout + - Seamless integration with VLLM/SGLang backends and multimodal processing + - Adaptive advantage estimation that respects the partial batch size + - Support for both packed and unpacked sample formats + +Parameters: + partial_percent (float): Fraction of the total rollout batch to generate in one call. + Values between 0.0 and 1.0. Default: 0.7. + max_token_budget (int): Maximum allowed generation length before a sample is flagged for + regeneration. 
Samples that reach this length are stored in the regeneration buffer + and can be continued in a later step. Default: 1024. + +References: + - Kimi1.5: "Kimi k1.5: Scaling Reinforcement Learning with LLMs" (https://arxiv.org/abs/2501.12599) + - MiMo: "MiMo: Unlocking the Reasoning Potential of Language Model + -- From Pretraining to Posttraining" (https://arxiv.org/abs/2505.07608) + +Classes: + PartialFastExperienceMaker: Main experience generation class with partial‑rollout support. """ from typing import List, Optional, Union, Tuple, Dict, Any @@ -30,12 +66,31 @@ class PartialFastExperienceMaker(FastExperienceMaker): """ FastExperienceMaker with partial rollout and token‑budget regeneration. + This class extends FastExperienceMaker to support incremental rollout and controlled generation + for long‑form or high‑cost tasks. It introduces two core mechanisms: + 1. Partial rollout: only a fraction (partial_percent) of the total rollout batch is generated + in each call; the remaining samples are kept in buffers for subsequent steps, reducing + per‑iteration latency and enabling smoother pipeline scheduling. + 2. Token‑budget regeneration: samples that reach a predefined token budget (max_token_budget) + are flagged and can be regenerated later, allowing continuation of long‑form generation + without discarding already‑produced content. + + The class reuses the parent's infrastructure (MultimodalDataProcessor, RewardComputationEngine, + etc.) and only overrides the methods that implement the partial‑rollout logic. + + The partial‑rollout pipeline: + 1. Buffer Management: Maintain regeneration (regen) and non‑regeneration (noregen) buffers + 2. Need‑Prompts Check: Determine if fresh prompts are required to fill the partial batch + 3. Generation: Use parent's inference engine (VLLM/SGLang) but respect token budget and partial fraction + 4. Regeneration: For samples that exceed token budget, regenerate with continuation + 5. 
Advantage Adaptation: Adjust advantage estimators (RLOO, Group Norm, REINFORCE) for partial group size + Args: - partial_percent (float): fraction of the rollout batch to generate in one call. - max_token_budget (int): maximum allowed generation length before regeneration. - packing_samples (bool): whether to pack samples (inherited). - processor: multimodal processor (inherited). - *args, **kwargs: passed to parent. + partial_percent: Fraction of the total rollout batch to generate in one call (0.0‑1.0) + max_token_budget: Maximum allowed generation length before a sample is flagged for regeneration + packing_samples: Whether to pack multiple sequences into single batch (inherited from parent) + processor: Multimodal processor for vision‑language models (inherited from parent) + *args, **kwargs: Arguments passed to parent FastExperienceMaker """ def __init__( @@ -47,6 +102,22 @@ def __init__( processor=None, **kwargs ): + """ + Initialize PartialFastExperienceMaker. + + :param args: Positional arguments for parent FastExperienceMaker + :type args: tuple + :param partial_percent: Fraction of total rollout batch to generate per call (0.0‑1.0) + :type partial_percent: float + :param max_token_budget: Maximum generation length before a sample is flagged for regeneration + :type max_token_budget: int + :param packing_samples: Enable sample packing for efficiency (inherited) + :type packing_samples: bool + :param processor: Multimodal processor for vision‑language models (inherited) + :type processor: Optional[Any] + :param kwargs: Keyword arguments for parent FastExperienceMaker + :type kwargs: dict + """ super().__init__(*args, packing_samples=packing_samples, processor=processor, **kwargs) self.partial_percent = partial_percent self.max_token_budget = max_token_budget @@ -57,7 +128,8 @@ def __init__( self.noregen_buffer: Dict[str, List] = {} fields = [ 'output', 'labels', 'prompts', 'images', 'images_num', - 'images_pixel_values', 'images_grid_thw', 'image_flags', 
'references' + 'images_pixel_values', 'images_grid_thw', 'image_flags', 'references', + 'videos', 'videos_num', 'videos_pixel_values', 'videos_grid_thw', 'video_flags' ] for field in fields: self.regen_buffer[field] = [] @@ -69,10 +141,21 @@ def __init__( def need_new_prompts(self, rollout_batch_size: int, micro_rollout_batch_size: int) -> bool: """ - Check whether the buffers contain enough data to make a full experience batch. + Determine whether new prompts are required to fill the partial rollout batch. - Returns: - True if new prompts need to be fetched (i.e., buffers are below the partial threshold). + This method checks the current regeneration and non‑regeneration buffers + and compares the total stored samples against the number needed for the + current partial rollout (partial_percent × total rollout batch size). + + If the buffers contain insufficient samples, the caller should fetch fresh + prompts and call generate_samples with those prompts. + + :param rollout_batch_size: Total number of samples in a full rollout batch + :type rollout_batch_size: int + :param micro_rollout_batch_size: Size of each micro‑batch used in generation + :type micro_rollout_batch_size: int + :return: True if new prompts are needed (buffers below partial threshold), else False + :rtype: bool """ self.rollout_batch_size = rollout_batch_size self.micro_rollout_batch_size = micro_rollout_batch_size @@ -92,7 +175,9 @@ def generate_samples( self, all_prompts: List[str], all_images: Optional[List] = None, + all_videos: Optional[List] = None, images_num: Optional[List[int]] = None, + videos_num: Optional[List[int]] = None, all_references: Optional[List[str]] = None, all_labels: Optional[List] = None, **generate_kwargs @@ -100,18 +185,41 @@ def generate_samples( """ Generate samples using the parent's pipeline, but only a partial fraction. - The method: - 1. If new inputs are provided, generate them with the parent's generate_samples. - 2. 
Split the generated outputs into regeneration and non‑regeneration buffers. - 3. Draw from the buffers to produce the requested number of samples (partial_percent). - 4. If the noregen buffer is insufficient, regenerate some samples from the regen buffer. - - Returns: - List of Samples (or SamplesVL) ready for experience making. + This method implements the partial‑rollout logic: + 1. If new prompts are provided, generate them via the parent's inference engine + (VLLM/SGLang) using token‑budget‑limited generation. + 2. Split the generated outputs into regeneration (regen) and non‑regeneration (noregen) + buffers based on whether they have reached the token budget. + 3. Draw from the buffers to produce the requested number of samples + (partial_percent × total rollout batch size). + 4. If the noregen buffer is insufficient, regenerate some samples from the regen buffer + by continuing generation from the partially‑produced output. + + The method returns a list of Samples (or SamplesVL) ready for experience making. + When new prompts are provided, it also returns the image counts for multimodal data. + + :param all_prompts: List of text prompts (or None to only draw from buffers) + :type all_prompts: List[str] + :param all_images: Optional images for vision‑language models + :type all_images: Optional[List] + :param all_videos: Optional videos for vision‑language models + :type all_videos: Optional[List] + :param images_num: Number of images per prompt + :type images_num: Optional[List[int]] + :param videos_num: Number of videos per prompt + :type videos_num: Optional[List[int]] + :param all_references: Reference texts for evaluation + :type all_references: Optional[List[str]] + :param all_labels: Sample labels for reward shaping + :type all_labels: Optional[List] + :param generate_kwargs: Generation parameters (temperature, max_new_tokens, etc.) 
+ :type generate_kwargs: dict + :return: List of Samples (or SamplesVL) when all_prompts is None, + otherwise tuple (samples_list, images_num_list) + :rtype: Union[List[Samples], Tuple[List[Samples], Optional[List[int]]]] """ args = self.strategy.args - is_multimodal = all_images is not None - internvl = "internvl" in self.actor.pretrain_or_model.lower() if is_multimodal else False + is_multimodal = all_images is not None or all_videos is not None # -------------------------------------------------------------------- # Step 1: Generate new samples if inputs are provided @@ -155,9 +263,10 @@ def generate_samples( processed = self._process_multimodal_data( all_prompts=all_prompts, all_images=all_images, - is_internvl=internvl, + all_videos=all_videos, all_references=all_references, - images_num=images_num + images_num=images_num, + videos_num=videos_num ) prompt_token_ids = processed["all_prompt_token_ids"] prompts = processed["all_prompts"] @@ -167,6 +276,11 @@ def generate_samples( grid_thw = processed["all_images_grid_thw"] image_flags = processed["all_image_flags"] references = processed["all_references"] + videos = processed.get("all_videos") + videos_num = processed.get("all_videos_num") + pixel_values_videos = processed.get("all_videos_pixel_values") + video_grid_thw = processed.get("all_videos_grid_thw") + video_flags = processed.get("all_video_flags") else: tokenized = self.tokenize_fn(all_prompts, self.prompt_max_len, padding=False) prompt_token_ids = tokenized["input_ids"] @@ -244,7 +358,6 @@ def generate_samples( samples_list = self._generate_sample_list( samples_data, is_multimodal, - internvl, **generate_kwargs ) self.strategy.maybe_sleep_inference_engine() @@ -255,15 +368,17 @@ def generate_samples( images_num_list = samples_data.get("images_num") return samples_list, images_num_list - def _process_multimodal_data(self, all_prompts, all_images, is_internvl, all_references, images_num): + def _process_multimodal_data(self, all_prompts, all_images, 
all_videos, is_internvl, all_references, images_num, videos_num): """Wrapper around parent's multimodal_processor.process_multimodal_batch.""" if self.multimodal_processor is None: raise ValueError("Multimodal processor not initialized.") return self.multimodal_processor.process_multimodal_batch( all_prompts=all_prompts, all_images=all_images, + all_videos=all_videos, all_references=all_references, images_num=images_num, + videos_num=videos_num, n_samples_per_prompt=self.strategy.config.n_samples_per_prompt, is_internvl=is_internvl, ) @@ -411,7 +526,6 @@ def _generate_sample_list( self, samples_data: Dict[str, List], is_multimodal: bool, - internvl: bool, **kwargs ) -> List[Samples]: """Convert buffered data into a list of Samples.""" @@ -468,14 +582,8 @@ def _generate_sample_list( image_num = all_images_num[i + j] for image_id in range(0, image_num): images_grid = images_grid_thw[images_grid_id + image_id] - if internvl: - num_patch = images_grid if isinstance(images_grid, int) else images_grid.sum().item() - _image_flags = all_image_flags[index_pixel_patch: index_pixel_patch + num_patch] - image_flags.append(_image_flags) - image_grid_thw_list.append(torch.tensor([1, 1, num_patch]).unsqueeze(0)) - else: - num_patch = images_grid[0] * images_grid[1] * images_grid[2] - image_grid_thw_list.append(images_grid.clone().unsqueeze(0)) + num_patch = images_grid[0] * images_grid[1] * images_grid[2] + image_grid_thw_list.append(images_grid.clone().unsqueeze(0)) images_pixel_value = all_images_pixel_values[index_pixel_patch: index_pixel_patch + num_patch] pixel_values.append(images_pixel_value.clone()) index_pixel_patch += num_patch @@ -505,22 +613,18 @@ def _generate_sample_list( ) ) else: - if internvl: - pixel_values_intern = torch.cat(pixel_values, dim=0).to("cuda") if pixel_values else None - pixel_values = None - else: - pixel_values = torch.cat(pixel_values, dim=0).to("cuda") if pixel_values else None - pixel_values_intern = None + pixel_values = 
torch.cat(pixel_values, dim=0).to("cuda") if pixel_values else None + pixel_values_intern = None samples_list.append( SamplesVL( sequences=sequences, attention_mask=attention_mask, action_mask=action_mask, - image_grid_thws=torch.cat(image_grid_thw_list, dim=0).to("cuda") if not internvl else None, + image_grid_thws=torch.cat(image_grid_thw_list, dim=0).to("cuda"), raw_images=raw_images, pixel_values=pixel_values, pixel_values_intern=pixel_values_intern, - image_flags=torch.cat(image_flags, dim=0).to("cuda") if internvl else None, + image_flags=torch.cat(image_flags, dim=0).to("cuda"), num_actions=action_mask.size(1), packed_seq_lens=None, response_length=action_mask.float().sum(dim=-1), @@ -574,10 +678,22 @@ def _generate_sample_list( def process_experiences(self, experiences: List[Experience]) -> Tuple[List[Experience], List[torch.Tensor]]: """ - Process experiences (reward shaping for partial rollout). - - This method overrides the parent's _process_experiences to handle - advantage estimators that expect a different group size (partial_percent). + Process experiences for reward shaping and filtering under partial rollout. + + This method overrides the parent's process_experiences to adjust advantage + estimators (RLOO, Group Norm, REINFORCE) for the reduced group size introduced + by partial rollout. The group size is scaled by partial_percent to reflect the + actual number of samples generated per call. 
+ + For each estimator: + - RLOO: Uses n_samples_per_prompt as usual (since RLOO operates per‑prompt) + - GRPO/Group Norm: Adjusts group size to partial_percent × rollout_batch_size // micro_rollout_batch_size + - REINFORCE Baseline: Removes the baseline using the same n_samples_per_prompt + + :param experiences: List of Experience objects with raw rewards stored in info["reward"] + :type experiences: List[Experience] + :return: Tuple of (unchanged experiences, shaped reward tensors split per experience) + :rtype: Tuple[List[Experience], List[torch.Tensor]] """ args = self.strategy.args if args.advantage_estimator == "rloo": diff --git a/lightrft/trainer/spmd_ppo_trainer.py b/lightrft/trainer/spmd_ppo_trainer.py index 2b45ca52..cbbe868d 100644 --- a/lightrft/trainer/spmd_ppo_trainer.py +++ b/lightrft/trainer/spmd_ppo_trainer.py @@ -690,7 +690,7 @@ def __init__( # Then initialize our base class assert "processor" in kwargs and kwargs["processor"] is not None, "processor is required for SPMDPPOTrainerVL" SPMDPPOTrainerBase.__init__(self, *args, VLM=True, **kwargs) - if getattr(self.args, 'use_partial', False): + if self.args.use_partial: # Replace experience maker with partial version self.experience_maker = PartialFastExperienceMaker( self.actor, @@ -715,12 +715,17 @@ def _make_experience_iterator(self, dataloader, use_partial): """ Create an iterator that yields batches of experiences. - Args: - dataloader: The dataloader providing prompts, images, references, and labels. - use_partial: Whether to use partial rollout logic. - - Yields: - List[Experience]: A list of experiences for each training step. + This method handles both partial and non‑partial rollout logic. + For partial rollouts, it reuses cached prompts when possible to reduce + data loading overhead. For standard rollouts, it processes each batch + from the dataloader sequentially. 
+ + :param dataloader: DataLoader providing prompts, images, references, and labels + :type dataloader: torch.utils.data.DataLoader + :param use_partial: Whether to use partial rollout logic + :type use_partial: bool + :yield: List of Experience objects for each training step + :ytype: List[lightrft.trainer.experience_maker_vl.Experience] """ if use_partial: # Partial rollout logic @@ -757,12 +762,20 @@ def _process_experiences_and_train(self, experiences, steps): """ Process a batch of experiences: add to replay buffer, train, and update metrics. - Args: - experiences: List of Experience objects. - steps: Current step counter. - - Returns: - dict: Training status metrics. + This method handles the core training loop for each batch of experiences: + 1. Appends experiences to the replay buffer + 2. Reports memory usage + 3. Normalizes advantages (if not using group normalization) + 4. Executes PPO training + 5. Clears the replay buffer + 6. Updates KL control coefficient + + :param experiences: List of Experience objects to process + :type experiences: List[lightrft.trainer.experience_maker_vl.Experience] + :param steps: Current step counter for training progress tracking + :type steps: int + :return: Dictionary containing training status metrics (policy loss, critic loss, reward, etc.) 
+ :rtype: Dict[str, float] """ # Add experiences to replay buffer for i, experience in enumerate(experiences): From 6f3c161bd29e29f713c64dfc6c661b1f3450a54e Mon Sep 17 00:00:00 2001 From: AltmanD Date: Tue, 10 Feb 2026 11:37:50 +0800 Subject: [PATCH 3/3] fix for new training pipeline --- lightrft/trainer/fast_exp_maker_partial.py | 558 ++++++++++----------- lightrft/trainer/spmd_ppo_trainer.py | 262 +++++++--- 2 files changed, 472 insertions(+), 348 deletions(-) diff --git a/lightrft/trainer/fast_exp_maker_partial.py b/lightrft/trainer/fast_exp_maker_partial.py index f43c84bd..720c127a 100644 --- a/lightrft/trainer/fast_exp_maker_partial.py +++ b/lightrft/trainer/fast_exp_maker_partial.py @@ -53,13 +53,13 @@ from copy import deepcopy import torch -import torch.distributed as dist from vllm import SamplingParams from easydict import EasyDict -from openrlhf.trainer.ppo_utils.experience_maker import Experience, Samples -from openrlhf.trainer.ppo_utils.experience_maker_vl import SamplesVL +from lightrft.trainer.experience_maker import Experience, Samples +from lightrft.trainer.experience_maker_vl import SamplesVL from lightrft.trainer.fast_exp_maker import FastExperienceMaker +from lightrft.utils import Timer, get_current_device class PartialFastExperienceMaker(FastExperienceMaker): @@ -127,9 +127,9 @@ def __init__( self.regen_buffer: Dict[str, List] = {} self.noregen_buffer: Dict[str, List] = {} fields = [ - 'output', 'labels', 'prompts', 'images', 'images_num', - 'images_pixel_values', 'images_grid_thw', 'image_flags', 'references', - 'videos', 'videos_num', 'videos_pixel_values', 'videos_grid_thw', 'video_flags' + 'output', 'labels', 'prompts', 'references', + 'images', 'images_num', 'images_grid_thw', 'images_pixel_values', + 'videos', 'videos_num', 'videos_grid_thw', 'videos_pixel_values' ] for field in fields: self.regen_buffer[field] = [] @@ -161,7 +161,7 @@ def need_new_prompts(self, rollout_batch_size: int, micro_rollout_batch_size: in
self.micro_rollout_batch_size = micro_rollout_batch_size # Total micro‑batches needed for a full rollout - total_micro = rollout_batch_size // micro_rollout_batch_size + total_micro = rollout_batch_size // self.strategy.world_size # Micro‑batches we want to generate in one call target_micro = int(self.partial_percent * total_micro) required_samples = target_micro * micro_rollout_batch_size @@ -218,8 +218,18 @@ def generate_samples( otherwise tuple (samples_list, images_num_list) :rtype: Union[List[Samples], Tuple[List[Samples], Optional[List[int]]]] """ - args = self.strategy.args - is_multimodal = all_images is not None or all_videos is not None + assert self.strategy.inference_engine is not None, "Inference engine required" + + torch.cuda.synchronize() + start_time = time.time() + + config = self.strategy.config + if all_prompts is not None: + is_multimodal = all_images is not None or all_videos is not None + else: + is_multimodal = (len(self.noregen_buffer.get('images', [])) + len(self.regen_buffer.get('images', []))) != 0 or \ + (len(self.noregen_buffer.get('videos', [])) + len(self.regen_buffer.get('videos', []))) != 0 + n_samples = config.n_samples_per_prompt # -------------------------------------------------------------------- # Step 1: Generate new samples if inputs are provided @@ -227,7 +237,7 @@ def generate_samples( if all_prompts is not None: # Replicate the generation logic from fast_exp_maker_partial.py # Prepare sampling parameters - if args.engine_type == "vllm": + if config.engine_type == "vllm": sampling_params = SamplingParams( temperature=generate_kwargs.get("temperature", 1.0), top_p=generate_kwargs.get("top_p", 1.0), @@ -238,7 +248,7 @@ def generate_samples( include_stop_str_in_output=True, ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1", ) - elif args.engine_type == "sglang": + elif config.engine_type == "sglang": sampling_params = dict( n=1, temperature=generate_kwargs.get("temperature", 1.0), @@ -253,60 +263,93 @@ def 
generate_samples( ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1", ) else: - raise ValueError(f"Unsupported backend: {args.engine_type}") + raise ValueError(f"Unsupported backend: {config.engine_type}") # Expand labels - expanded_labels = sum([[label] * args.n_samples_per_prompt for label in all_labels], []) if all_labels else [] + expanded_labels = sum([[label] * n_samples for label in all_labels], []) if all_labels else [] # Process multimodal data if is_multimodal: - processed = self._process_multimodal_data( + processed = self.multimodal_processor.process_multimodal_batch( all_prompts=all_prompts, all_images=all_images, all_videos=all_videos, all_references=all_references, images_num=images_num, - videos_num=videos_num + videos_num=videos_num, + n_samples_per_prompt=n_samples, ) prompt_token_ids = processed["all_prompt_token_ids"] prompts = processed["all_prompts"] images = processed["all_images"] images_num = processed["all_images_num"] - pixel_values = processed["all_images_pixel_values"] - grid_thw = processed["all_images_grid_thw"] - image_flags = processed["all_image_flags"] + images_pixel_values = processed["all_images_pixel_values"] + images_grid_thw = processed["all_images_grid_thw"] references = processed["all_references"] videos = processed.get("all_videos") videos_num = processed.get("all_videos_num") - pixel_values_videos = processed.get("all_videos_pixel_values") - video_grid_thw = processed.get("all_videos_grid_thw") - video_flags = processed.get("all_video_flags") + videos_pixel_values = processed.get("all_videos_pixel_values") + videos_grid_thw = processed.get("all_videos_grid_thw") else: + # Text-only processing tokenized = self.tokenize_fn(all_prompts, self.prompt_max_len, padding=False) - prompt_token_ids = tokenized["input_ids"] - prompt_token_ids = sum([[token_ids] * args.n_samples_per_prompt for token_ids in prompt_token_ids], []) - prompts = all_prompts * args.n_samples_per_prompt - images = None - references = all_references * 
args.n_samples_per_prompt if all_references else None - - # Generate outputs via inference engine - outputs = self.strategy.gather_and_generate( - sampling_params=sampling_params, - all_prompt_token_ids=prompt_token_ids, - all_prompts=prompts if is_multimodal else None, - all_images=images if is_multimodal else None, - sleep_engine=False, - images_num=images_num if is_multimodal else None, - ) + prompt_token_ids = sum([[token_ids] * n_samples for token_ids in tokenized["input_ids"]], []) + + # ========== Generate via Inference Engine ========== + # Call fire_sampling function or direct generation + try: + if hasattr(self.strategy.args, 'use_fire') and self.strategy.args.use_fire: + # Use FIRE sampling (Flaming-hot Initiation with Regular Execution) + outputs = fire_sampling( + all_prompt_token_ids=all_prompt_token_ids, + generate_fn=generate_fn, # noqa: TODO + engine_type=config.engine_type, + first_token_temperature=generate_kwargs.get("first_token_temperature", 10.0), + temperature=generate_kwargs.get("temperature", 1.0), + first_token_top_k=generate_kwargs.get( + "first_token_top_k", sampling_params.top_k if hasattr(sampling_params, 'top_k') else -1 + ), + first_token_top_p=generate_kwargs.get( + "first_token_top_p", sampling_params.top_p if hasattr(sampling_params, 'top_p') else 1.0 + ), + is_multimodal=is_multimodal, + all_prompts=prompts, + all_images=images, + all_videos=videos, + all_images_num=images_num, + all_videos_num=videos_num, + sampling_params=sampling_params, + ) + else: + # maybe this can be called in if and else respectively? or like this? 
+ # Use original single-shot generation + outputs = self.strategy.gather_and_generate( + sampling_params=sampling_params, + all_prompt_token_ids=prompt_token_ids, + all_prompts=prompts if is_multimodal else None, + sleep_engine=self.strategy.args.enable_engine_sleep, + all_images=images if is_multimodal else None, + all_videos=videos if is_multimodal else None, + images_num=images_num if is_multimodal else None, + videos_num=videos_num if is_multimodal else None, + ) + except ValueError as e: + if "prompt" in str(e) and "too long" in str(e): + self.strategy.print(f"[Skip] {e}") + return None # Return None, subsequent experience_maker will ignore + else: + raise # Process outputs in micro-batches and store in buffers - for i in range(0, len(outputs), args.micro_rollout_batch_size): - batch_slice = slice(i, i + args.micro_rollout_batch_size) + for i in range(0, len(outputs), n_samples): + batch_slice = slice(i, i + n_samples) output_batch = outputs[batch_slice] labels_batch = expanded_labels[batch_slice] if expanded_labels else [] prompts_batch = prompts[batch_slice] images_batch = images[batch_slice] if images else None images_num_batch = images_num[batch_slice] if images_num else None + videos_batch = videos[batch_slice] if videos else None + videos_num_batch = videos_num[batch_slice] if videos_num else None references_batch = references[batch_slice] if references else None # Check if regeneration is needed @@ -321,23 +364,31 @@ def generate_samples( self._add_to_buffer(buffer_type, "images", images_batch) if images_num_batch is not None: self._add_to_buffer(buffer_type, "images_num", images_num_batch) + if videos_batch is not None: + self._add_to_buffer(buffer_type, "videos", videos_batch) + if videos_num_batch is not None: + self._add_to_buffer(buffer_type, "videos_num", videos_num_batch) if references_batch is not None: self._add_to_buffer(buffer_type, "references", references_batch) if is_multimodal: - self._add_to_buffer(buffer_type, "image_flags", 
image_flags[batch_slice]) # Handle image tensors - grid_batch = grid_thw[batch_slice] + grid_batch = images_grid_thw[batch_slice] self._add_to_buffer(buffer_type, "images_grid_thw", grid_batch) # Calculate pixel values slice - patch_start = sum(g[0] * g[1] * g[2] for g in grid_thw[:i]) + patch_start = sum(g[0] * g[1] * g[2] for g in images_grid_thw[:i]) patch_end = patch_start + sum(g[0] * g[1] * g[2] for g in grid_batch) - self._add_to_buffer(buffer_type, "images_pixel_values", pixel_values[patch_start:patch_end]) + self._add_to_buffer(buffer_type, "images_pixel_values", images_pixel_values[patch_start:patch_end]) + # Handle video tensors + if videos_grid_thw is not None: + videos_grid_batch = videos_grid_thw[batch_slice] + self._add_to_buffer(buffer_type, "videos_grid_thw", videos_grid_batch) + # -------------------------------------------------------------------- # Step 2: Determine how many micro‑batches we need to return # -------------------------------------------------------------------- - total_micro = self.rollout_batch_size // self.micro_rollout_batch_size + total_micro = self.rollout_batch_size // self.strategy.world_size target_micro = int(self.partial_percent * total_micro) # How many micro‑batches are already available in the noregen buffer? 
@@ -349,67 +400,207 @@ def generate_samples( # Take all noregen samples and supplement with regenerated ones samples_needed = target_micro - noregen_micro noregen_data = self._get_from_buffer('noregen', noregen_micro * self.micro_rollout_batch_size) - regen_data = self._regenerate_from_buffer(samples_needed * self.micro_rollout_batch_size, **generate_kwargs) + regen_data = self._regenerate_from_buffer(samples_needed * self.micro_rollout_batch_size, is_multimodal, **generate_kwargs) samples_data = self._merge_data(noregen_data, regen_data) # -------------------------------------------------------------------- # Step 3: Convert the collected data back to Samples objects # -------------------------------------------------------------------- - samples_list = self._generate_sample_list( - samples_data, - is_multimodal, - **generate_kwargs - ) - self.strategy.maybe_sleep_inference_engine() - if all_prompts is None: - return samples_list - else: - # Return tuple with samples_list and images_num, consistent with fast_exp_maker_partial.py - images_num_list = samples_data.get("images_num") - return samples_list, images_num_list - - def _process_multimodal_data(self, all_prompts, all_images, all_videos, is_internvl, all_references, images_num, videos_num): - """Wrapper around parent's multimodal_processor.process_multimodal_batch.""" - if self.multimodal_processor is None: - raise ValueError("Multimodal processor not initialized.") - return self.multimodal_processor.process_multimodal_batch( - all_prompts=all_prompts, - all_images=all_images, - all_videos=all_videos, - all_references=all_references, - images_num=images_num, - videos_num=videos_num, - n_samples_per_prompt=self.strategy.config.n_samples_per_prompt, - is_internvl=is_internvl, - ) + samples_list = [] + image_patch_idx = 0 + video_patch_idx = 0 + image_start_idx = 0 + video_start_idx = 0 + + all_outputs = samples_data.get("output", []) + all_labels = samples_data.get("labels", []) + all_prompts = 
samples_data.get("prompts", []) + all_images = samples_data.get("images", []) + all_images_num = samples_data.get("images_num", None) + all_images_grid_thw = samples_data.get("images_grid_thw", None) + all_images_pixel_values = samples_data.get("images_pixel_values", None) + all_videos_num = samples_data.get("videos_num", None) + all_videos_grid_thw = samples_data.get("videos_grid_thw", None) + all_videos_pixel_values = samples_data.get("videos_pixel_values", None) + all_references = samples_data.get("references", []) + + for i in range(0, len(all_outputs), config.micro_rollout_batch_size): + micro_batch_outputs = all_outputs[i:i + config.micro_rollout_batch_size] + micro_batch_prompts = all_prompts[i:i + config.micro_rollout_batch_size] + + # Extract micro-batch data + micro_batch_grid_thw = None + micro_batch_video_grid_thw = None + micro_batch_raw_images = None + + if is_multimodal: + rollout_image_count = sum(all_images_num[i:i + config.micro_rollout_batch_size]) + micro_batch_grid_thw = all_images_grid_thw[image_start_idx:image_start_idx + rollout_image_count] + micro_batch_raw_images = all_images[i:i + config.micro_rollout_batch_size] + image_start_idx += rollout_image_count + + rollout_video_count = sum(all_videos_num[i:i + config.micro_rollout_batch_size]) + micro_batch_video_grid_thw = all_videos_grid_thw[video_start_idx:video_start_idx + rollout_video_count] + video_start_idx += rollout_video_count + + micro_batch_references = (all_references[i:i + config.micro_rollout_batch_size] if all_references else None) + micro_batch_labels = (all_labels[i:i + config.micro_rollout_batch_size] if all_labels else None) + # Build samples + if not self.packing_samples: + sample, updated_patch_idx, updated_video_patch_idx = self._build_unpacked_sample( + outputs=micro_batch_outputs, + prompts=micro_batch_prompts, + labels=micro_batch_labels, + references=micro_batch_references, + is_multimodal=is_multimodal, + grid_thw=micro_batch_grid_thw, + 
video_grid_thw=micro_batch_video_grid_thw, + raw_images=micro_batch_raw_images, + pixel_values=all_images_pixel_values if is_multimodal else None, + pixel_values_videos=all_videos_pixel_values if is_multimodal else None, + images_num=all_images_num[i:i + config.micro_rollout_batch_size] if is_multimodal else None, + videos_num=all_videos_num[i:i + config.micro_rollout_batch_size] if is_multimodal else None, + image_patch_idx=image_patch_idx, + video_patch_idx=video_patch_idx, + ) + # Update patch indices from the returned values + if updated_patch_idx is not None: + image_patch_idx = updated_patch_idx + if updated_video_patch_idx is not None: + video_patch_idx = updated_video_patch_idx + samples_list.append(sample) + else: + # Packed samples + sample = self._build_packed_sample( + outputs=micro_batch_outputs, + prompts=micro_batch_prompts, + labels=micro_batch_labels, + references=micro_batch_references, + ) + samples_list.append(sample) + + # Report timing + torch.cuda.synchronize() + gen_time = torch.tensor(time.time() - start_time, device=get_current_device()) + torch.distributed.all_reduce(gen_time, op=torch.distributed.ReduceOp.MAX) + self.strategy.print(f"***Rollout engine generation time (global max): {gen_time.item():.4f}s") + self.strategy.report_memory("after rollout engine generation") + + return samples_list def _add_to_buffer(self, buffer_type: str, data_name: str, data): - """Add data to specified buffer.""" + """Add data to specified buffer. 
+ + Args: + buffer_type: 'regen' or 'noregen' + data_name: Key name for storing data + data: Data to add (can be tensor, list, or other) + + Special handling: + - Keys with 'grid_thw': split 2D tensors by rows + - Keys with 'pixel_values': keep 2D tensors as-is + - Other 2D tensors: split by rows + """ buffer = self.regen_buffer if buffer_type == 'regen' else self.noregen_buffer if data_name not in buffer: buffer[data_name] = [] + if isinstance(data, torch.Tensor): - buffer[data_name].append(data) + is_grid_thw = 'grid_thw' in data_name + is_pixel_values = 'pixel_values' in data_name + + if data.dim() == 2: + if is_grid_thw: + # Split grid_thw 2D tensors by rows + buffer[data_name].extend(torch.unbind(data, dim=0)) + elif is_pixel_values: + # Keep pixel_values 2D tensors intact + buffer[data_name].append(data) + else: + # Split other 2D tensors by rows + buffer[data_name].extend(torch.unbind(data, dim=0)) + else: + # Add 1D or higher-dim tensors as-is + buffer[data_name].append(data) else: buffer[data_name].extend(data if isinstance(data, list) else [data]) - def _get_from_buffer(self, buffer_type: str, count: Optional[int]): - """Retrieve data from buffer, optionally limiting the amount.""" + + def _get_from_buffer(self, buffer_type: str, count: Optional[int] = None): + """Retrieve data from buffer, optionally limiting the amount. + + Args: + buffer_type: 'regen' or 'noregen' + count: Number of items to retrieve. If None, retrieve all. + + Returns: + Dictionary with retrieved data. 
+ Special handling: + - grid_thw keys: stack 1D tensors to 2D + - pixel_values keys: concatenate 2D tensors + """ buffer = self.regen_buffer if buffer_type == 'regen' else self.noregen_buffer result = {} + for key, lst in buffer.items(): + if not lst: + # Return empty tensor with proper shape + if 'grid_thw' in key: + result[key] = torch.tensor([]).reshape(0, 3) + elif 'pixel_values' in key: + result[key] = torch.tensor([]) + else: + result[key] = torch.tensor([]) + + if count is None: + buffer[key] = [] + continue + + all_tensors = all(isinstance(item, torch.Tensor) for item in lst) + if count is None: - result[key] = lst.copy() + # Retrieve all data + if all_tensors: + if 'pixel_values' in key and lst[0].dim() >= 2: + # Concatenate pixel_values 2D tensors + result[key] = torch.cat(lst, dim=0) if lst else torch.tensor([]) + elif 'grid_thw' in key and lst[0].dim() == 1: + # Stack grid_thw 1D tensors to 2D + result[key] = torch.stack(lst, dim=0) if lst else torch.tensor([]).reshape(0, 3) + elif lst[0].dim() == 1: + # Stack 1D tensors to 2D + result[key] = torch.stack(lst, dim=0) if lst else torch.tensor([]) + else: + # Concatenate other tensors + result[key] = torch.cat(lst, dim=0) if lst else torch.tensor([]) + else: + result[key] = lst.copy() buffer[key] = [] else: - result[key] = lst[:count] + # Retrieve specified number of items + items_to_take = lst[:count] + + if all_tensors and items_to_take: + if 'pixel_values' in key and items_to_take[0].dim() >= 2: + result[key] = torch.cat(items_to_take, dim=0) + elif 'grid_thw' in key and items_to_take[0].dim() == 1: + result[key] = torch.stack(items_to_take, dim=0) + elif items_to_take[0].dim() == 1: + result[key] = torch.stack(items_to_take, dim=0) + else: + result[key] = torch.cat(items_to_take, dim=0) + else: + result[key] = items_to_take + + # Update buffer buffer[key] = lst[count:] + return result @torch.no_grad() - def _regenerate_from_buffer(self, num_needed: int, **kwargs) -> dict: + def 
_regenerate_from_buffer(self, num_needed: int, is_multimodal: bool, **kwargs) -> dict: """Regenerate outputs for samples that reached token budget.""" - args = self.strategy.args + config = self.strategy.config # Get data from regeneration buffer regen_data = self._get_from_buffer("regen", num_needed) @@ -440,7 +631,7 @@ def _regenerate_from_buffer(self, num_needed: int, **kwargs) -> dict: ] # Prepare sampling parameters - if args.engine_type == "vllm": + if config.engine_type == "vllm": sampling_params = SamplingParams( temperature=kwargs.get("temperature", 1.0), top_p=kwargs.get("top_p", 1.0), @@ -451,7 +642,7 @@ def _regenerate_from_buffer(self, num_needed: int, **kwargs) -> dict: include_stop_str_in_output=True, ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1", ) - elif args.engine_type == "sglang": + elif config.engine_type == "sglang": sampling_params = dict( n=1, temperature=kwargs.get("temperature", 1.0), @@ -466,11 +657,9 @@ def _regenerate_from_buffer(self, num_needed: int, **kwargs) -> dict: ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1", ) else: - raise ValueError(f"Unsupported backend: {args.engine_type}") + raise ValueError(f"Unsupported backend: {config.engine_type}") - # Build inputs and regenerate using the same pattern as fast_exp_maker_partial.py - # First, check if multimodal - is_multimodal = regen_data.get("images") is not None and len(regen_data["images"]) > 0 + # Build inputs and regenerate using the same pattern if is_multimodal: # Use strategy._build_multimodal_inputs inputs = self.strategy._build_multimodal_inputs( @@ -521,202 +710,3 @@ def _merge_data(self, data1: Dict[str, List], data2: Dict[str, List]) -> Dict[st else: merged[key] = val1 if val1 else val2 return merged - - def _generate_sample_list( - self, - samples_data: Dict[str, List], - is_multimodal: bool, - **kwargs - ) -> List[Samples]: - """Convert buffered data into a list of Samples.""" - args = self.strategy.args - samples_list = [] - gen_max_len, gen_min_len = 0, 
102400000 - index_pixel_patch = 0 - image_start_idx = 0 - - all_outputs = samples_data.get("output", []) - all_labels = samples_data.get("labels", []) - all_prompts = samples_data.get("prompts", []) - all_images = samples_data.get("images", []) - all_images_num = samples_data.get("images_num", []) - all_images_pixel_values = samples_data.get("images_pixel_values", []) - all_images_grid_thw = samples_data.get("images_grid_thw", []) - all_image_flags = samples_data.get("image_flags", []) - all_references = samples_data.get("references", []) - - for i in range(0, len(all_outputs), args.micro_rollout_batch_size): - outputs = all_outputs[i: i + args.micro_rollout_batch_size] - prompts = all_prompts[i: i + args.micro_rollout_batch_size] - if all_images: - assert all_images_num is not None - rollout_image_num = sum(all_images_num[i: i + args.micro_rollout_batch_size]) - images_grid_thw = all_images_grid_thw[image_start_idx: image_start_idx + rollout_image_num] - raw_images = all_images[image_start_idx: image_start_idx + rollout_image_num] - image_start_idx += rollout_image_num - if all_references: - references = all_references[i: i + args.micro_rollout_batch_size] - labels = all_labels[i: i + args.micro_rollout_batch_size] - - if not self.packing_samples: - # Build unpacked samples - max_input_len, max_output_len = 0, 0 - for output in outputs: - max_input_len = max(max_input_len, len(output.prompt_token_ids)) - max_output_len = max(max_output_len, len(output.output_token_ids)) - - pad_token_id, eos_token_id = self.tokenizer.pad_token_id, self.tokenizer.eos_token_id - sequences = [] - pixel_values = [] - image_grid_thw_list = [] - image_flags = [] - images_grid_id = 0 - for j in range(len(outputs)): - output = outputs[j] - input_len = len(output.prompt_token_ids) - input_ids = [pad_token_id] * (max_input_len - input_len) + list(output.prompt_token_ids) - output_len = len(output.output_token_ids) - output_ids = list(output.output_token_ids) + [pad_token_id] * 
(max_output_len - output_len) - # split pixel_patch - if all_images: - image_num = all_images_num[i + j] - for image_id in range(0, image_num): - images_grid = images_grid_thw[images_grid_id + image_id] - num_patch = images_grid[0] * images_grid[1] * images_grid[2] - image_grid_thw_list.append(images_grid.clone().unsqueeze(0)) - images_pixel_value = all_images_pixel_values[index_pixel_patch: index_pixel_patch + num_patch] - pixel_values.append(images_pixel_value.clone()) - index_pixel_patch += num_patch - images_grid_id += image_num - sequences.append(input_ids + output_ids) - - sequences = torch.tensor(sequences) - sequences, attention_mask, action_mask = self.actor.process_sequences( - sequences, max_input_len, eos_token_id, pad_token_id - ) - sequences = sequences.to("cuda") - attention_mask = attention_mask.to("cuda") - action_mask = action_mask.to("cuda") - if not all_images: - samples_list.append( - Samples( - sequences=sequences, - attention_mask=attention_mask, - action_mask=action_mask, - num_actions=action_mask.size(1), - packed_seq_lens=None, - response_length=action_mask.float().sum(dim=-1), - total_length=attention_mask.float().sum(dim=-1), - prompts=prompts, - labels=labels, - pad_len=None, - ) - ) - else: - pixel_values = torch.cat(pixel_values, dim=0).to("cuda") if pixel_values else None - pixel_values_intern = None - samples_list.append( - SamplesVL( - sequences=sequences, - attention_mask=attention_mask, - action_mask=action_mask, - image_grid_thws=torch.cat(image_grid_thw_list, dim=0).to("cuda"), - raw_images=raw_images, - pixel_values=pixel_values, - pixel_values_intern=pixel_values_intern, - image_flags=torch.cat(image_flags, dim=0).to("cuda"), - num_actions=action_mask.size(1), - packed_seq_lens=None, - response_length=action_mask.float().sum(dim=-1), - total_length=attention_mask.float().sum(dim=-1), - references=references, - labels=labels, - prompts=prompts, - ) - ) - else: - # Packed samples (not supporting VLM yet) - pad_token_id, 
eos_token_id = self.tokenizer.pad_token_id, self.tokenizer.eos_token_id - sequences = [] - packed_seq_lens = [] - attention_mask = [] - num_actions = [] - for idx, output in enumerate(outputs): - input_len = len(output.prompt_token_ids) - output_len = len(output.output_token_ids) - packed_seq_lens.append(input_len + output_len) - sequences.extend(output.prompt_token_ids + list(output.output_token_ids)) - attention_mask.extend([idx + 1] * (input_len + output_len)) - num_actions.append(max(1, output_len)) - gen_max_len = max(gen_max_len, output_len) - gen_min_len = min(gen_min_len, output_len) - - sequences = torch.tensor(sequences, device="cuda").unsqueeze(0) - attention_mask = torch.tensor(attention_mask, device="cuda").unsqueeze(0) - action_mask = None - response_length = torch.tensor(num_actions, device="cuda", dtype=torch.float) - total_length = torch.tensor(packed_seq_lens, device="cuda", dtype=torch.float) - samples_list.append( - Samples( - sequences=sequences, - attention_mask=attention_mask, - action_mask=None, - num_actions=num_actions, - packed_seq_lens=packed_seq_lens, - response_length=response_length, - total_length=total_length, - prompts=prompts, - labels=labels, - pad_len=None, - ) - ) - - if dist.get_rank(self.backend_mp_group) == 0: - print(f"*** response_length {gen_max_len=}, {gen_min_len=}") - - return samples_list - - def process_experiences(self, experiences: List[Experience]) -> Tuple[List[Experience], List[torch.Tensor]]: - """ - Process experiences for reward shaping and filtering under partial rollout. - - This method overrides the parent's process_experiences to adjust advantage - estimators (RLOO, Group Norm, REINFORCE) for the reduced group size introduced - by partial rollout. The group size is scaled by partial_percent to reflect the - actual number of samples generated per call. 
- - For each estimator: - - RLOO: Uses n_samples_per_prompt as usual (since RLOO operates per‑prompt) - - GRPO/Group Norm: Adjusts group size to partial_percent × rollout_batch_size // micro_rollout_batch_size - - REINFORCE Baseline: Removes the baseline using the same n_samples_per_prompt - - :param experiences: List of Experience objects with raw rewards stored in info["reward"] - :type experiences: List[Experience] - :return: Tuple of (unchanged experiences, shaped reward tensors split per experience) - :rtype: Tuple[List[Experience], List[torch.Tensor]] - """ - args = self.strategy.args - if args.advantage_estimator == "rloo": - rewards = torch.cat([exp.info["reward"] for exp in experiences]) - rewards = rewards.reshape(-1, args.n_samples_per_prompt) - baseline = (rewards.sum(-1, keepdim=True) - rewards) / (args.n_samples_per_prompt - 1) - rewards = rewards - baseline - rewards = rewards.flatten().chunk(len(experiences)) - return experiences, rewards - elif args.advantage_estimator in ["grpo", "group_norm"]: - # Adjust group size according to partial rollout - group_size = int(self.partial_percent * self.rollout_batch_size // args.micro_rollout_batch_size) - rewards = torch.cat([exp.info["reward"] for exp in experiences]) - rewards = rewards.reshape(-1, group_size) - baseline = rewards.mean(-1, keepdim=True) - rewards = (rewards - baseline) / (rewards.std(1, keepdim=True) + 1e-8) - rewards = rewards.flatten().chunk(len(experiences)) - return experiences, rewards - elif args.advantage_estimator == "reinforce_baseline": - rewards = torch.cat([exp.info["reward"] for exp in experiences]) - rewards = rewards.reshape(-1, args.n_samples_per_prompt).to(device="cuda") - rewards = rewards - rewards.mean(-1, keepdim=True) - rewards = rewards.reshape(-1).to(device="cpu").chunk(len(experiences)) - return experiences, rewards - else: - raise ValueError(f"Unhandled advantage_estimator: {args.advantage_estimator}") diff --git a/lightrft/trainer/spmd_ppo_trainer.py 
b/lightrft/trainer/spmd_ppo_trainer.py index cbbe868d..93132235 100644 --- a/lightrft/trainer/spmd_ppo_trainer.py +++ b/lightrft/trainer/spmd_ppo_trainer.py @@ -21,6 +21,7 @@ import time import torch +import math from tqdm import tqdm from lightrft.trainer import PPOTrainer, PPOTrainerVL @@ -31,6 +32,7 @@ from lightrft.trainer.replay_buffer import make_experience_batch from lightrft.trainer.replay_buffer_vl import make_experience_batch as make_experience_batch_vl from lightrft.models.utils import create_high_entropy_mask +from lightrft.utils.distributed_sampler import DistributedSampler from lightrft.utils import init_logger logger = init_logger(__name__) @@ -692,6 +694,7 @@ def __init__( SPMDPPOTrainerBase.__init__(self, *args, VLM=True, **kwargs) if self.args.use_partial: # Replace experience maker with partial version + processor = kwargs.pop("processor", None) self.experience_maker = PartialFastExperienceMaker( self.actor, self.critic, @@ -706,7 +709,7 @@ def __init__( self.reward_fn_label_map, self.reward_recipe, packing_samples=self.packing_samples, - processor=self._processor, + processor=processor, partial_percent=getattr(self.args, "partial_percent", 0.7), max_token_budget=getattr(self.args, "max_token_budget", 1024), ) @@ -735,27 +738,60 @@ def _make_experience_iterator(self, dataloader, use_partial): if self.experience_maker.need_new_prompts(self.args.rollout_batch_size, self.micro_rollout_batch_size): try: # Get next batch of prompts, images, references, and labels - rand_prompts, rand_images, rand_references, rand_labels = next(dataloader_iter) + batch = next(dataloader_iter) + # Handle variable batch size (4 or 5 elements) + if len(batch) == 5: + rand_prompts, rand_images, rand_videos, rand_references, rand_labels = batch + else: + rand_prompts, rand_images, rand_references, rand_labels = batch + rand_videos = None except StopIteration: # End of epoch reached break # Generate experiences from new prompts experiences = 
self.experience_maker.make_experience_list( - rand_prompts, rand_images, rand_references, rand_labels, **self.generate_kwargs + rand_prompts, rand_images, rand_videos, rand_references, rand_labels, + **self.generate_kwargs ) else: # Generate experiences from cached prompts experiences = self.experience_maker.make_experience_list( - None, None, None, None, **self.generate_kwargs + None, None, None, None, None, **self.generate_kwargs ) yield experiences else: # Non-partial rollout logic - for rand_prompts, rand_images, rand_references, rand_labels in dataloader: + for batch in dataloader: + # Compatible with both image-only (4 args) and video (5 args) dataloaders + if len(batch) == 5: + rand_prompts, rand_images, rand_videos, rand_references, rand_labels = batch + else: + rand_prompts, rand_images, rand_references, rand_labels = batch + rand_videos = None + + # TODO: Remove debug print + self.strategy.print( + f"rand_prompts:\n {rand_prompts}\n , rand_images:{rand_images}\n , rand_references:{rand_references}\n, rand_labels:{rand_labels}\n " # noqa + ) + experiences = self.experience_maker.make_experience_list( - rand_prompts, rand_images, rand_references, rand_labels, **self.generate_kwargs + rand_prompts, rand_images, rand_videos, rand_references, rand_labels, + **self.generate_kwargs ) + + # Debug print for first experience + for i, experience in enumerate(experiences): + if i == 0: + output = self.tokenizer.batch_decode( + experience.sequences[0].unsqueeze(0), skip_special_tokens=True + ) + self.strategy.print("collect phase: experience.sequences w skip_special_tokens: ", output) + self.strategy.print( + f"collect phase: rand_prompts:\n {rand_prompts[0:2]}\n , rand_images:{rand_images[0:2]}\n , rand_references:{rand_references[0:2]}\n, rand_labels:{rand_labels[0:2]}\n " # noqa + ) + break + yield experiences def _process_experiences_and_train(self, experiences, steps): @@ -789,6 +825,88 @@ def _process_experiences_and_train(self, experiences, steps): # Report 
memory usage after replay buffer is filled self.strategy.report_memory('after replay_buffer ready') + # Aggregate rollout statistics from replay buffer + # Collect metrics from the rollout/collection phase + rollout_status = {} + if self.replay_buffer.items: + all_rewards = [] + all_format_rewards = [] + all_accuracy_rewards = [] + all_response_lengths = [] + + for item in self.replay_buffer.items: + # Collect rewards from rollout + if hasattr(item, 'info') and item.info is not None and 'reward' in item.info: + all_rewards.append(item.info['reward']) + + # Robust handling of reward_metrics + # 1. Check if info exists + # 2. Check if 'reward_metrics' key exists + # 3. Check if reward_metrics is not None (critical!) + if ( + hasattr(item, 'info') and item.info is not None and 'reward_metrics' in item.info + and item.info['reward_metrics'] is not None + ): + + reward_metrics = item.info['reward_metrics'] + + # Safely extract sub-metrics + if 'format_reward' in reward_metrics: + all_format_rewards.append(reward_metrics['format_reward']) + if 'accuracy_reward' in reward_metrics: + all_accuracy_rewards.append(reward_metrics['accuracy_reward']) + + # Collect response lengths from rollout + if hasattr(item, 'info') and item.info is not None and 'response_length' in item.info: + all_response_lengths.append(item.info['response_length']) + + # Compute rollout statistics + device = torch.cuda.current_device() + + if all_rewards: + # [TENSOR-FIX] Handle both tensor lists and scalar lists + if isinstance(all_rewards[0], torch.Tensor): + rewards_tensor = torch.cat([t.to(device).float() for t in all_rewards]) + else: + rewards_tensor = torch.tensor(all_rewards, dtype=torch.float32, device=device) + rollout_status["rollout_reward"] = rewards_tensor.mean().item() + rollout_status["rollout_reward_std"] = rewards_tensor.std().item() + + if all_format_rewards: + # [TENSOR-FIX] Handle both tensor lists and scalar lists + if isinstance(all_format_rewards[0], torch.Tensor): + 
format_tensor = torch.cat([t.to(device).float() for t in all_format_rewards]) + else: + format_tensor = torch.tensor(all_format_rewards, dtype=torch.float32, device=device) + + mean_format_reward = format_tensor.mean().item() + + # Only display if mean is significantly non-zero + if abs(mean_format_reward) > 1e-6: + rollout_status["rollout_format_reward"] = mean_format_reward + + if all_accuracy_rewards: + # [TENSOR-FIX] Handle both tensor lists and scalar lists + if isinstance(all_accuracy_rewards[0], torch.Tensor): + accuracy_tensor = torch.cat([t.to(device).float() for t in all_accuracy_rewards]) + else: + accuracy_tensor = torch.tensor(all_accuracy_rewards, dtype=torch.float32, device=device) + + mean_accuracy_reward = accuracy_tensor.mean().item() + + # Only display if mean is significantly non-zero + if abs(mean_accuracy_reward) > 1e-6: + rollout_status["rollout_accuracy_reward"] = mean_accuracy_reward + + if all_response_lengths: + # [TENSOR-FIX] Handle both tensor lists and scalar lists + if isinstance(all_response_lengths[0], torch.Tensor): + lengths_tensor = torch.cat([t.to(device).float() for t in all_response_lengths]) + else: + lengths_tensor = torch.tensor(all_response_lengths, dtype=torch.float32, device=device) + + rollout_status["rollout_response_length"] = lengths_tensor.mean().item() + # Normalize advantages if not using group normalization if self.args.advantage_estimator != "group_norm": self.replay_buffer.normalize("advantages", self.strategy) @@ -806,88 +924,105 @@ def _process_experiences_and_train(self, experiences, steps): if "kl" in status: self.kl_ctl.update(status["kl"], self.args.rollout_batch_size * self.args.n_samples_per_prompt) - return status + # Merge rollout status and training status + merged_status = {**rollout_status, **status} + return merged_status def fit( self, args, prompts_dataloader, pretrain_dataloader, + eval_dataloader=None, consumed_samples=0, num_update_steps_per_episodes=1, ) -> None: """ - Execute the main 
training loop for vision-language models using SPMD PPO. - - This method orchestrates the complete training process including: - 1. Rollout phase: Generate experiences by interacting with the environment - 2. Training phase: Update actor and critic models using collected experiences - 3. Evaluation and checkpointing: Save models and logs at specified intervals - - The training loop follows the standard PPO pattern with distributed data - parallelism for efficient multi-device training. It handles both - image-text prompts and pre-training data for mixed objective optimization. - - :param args: Training configuration arguments containing hyperparameters like num_episodes, rollout_batch_size, n_samples_per_prompt, etc. - :type args: argparse.Namespace - :param prompts_dataloader: DataLoader providing image-text prompts for training - :type prompts_dataloader: torch.utils.data.DataLoader - :param pretrain_dataloader: DataLoader providing pre-training data for supervised fine-tuning - :type pretrain_dataloader: torch.utils.data.DataLoader - :param consumed_samples: Number of samples already processed (for resuming training) + Main training loop for PPO. + + :param args: Training arguments. + :type args: Namespace + :param prompts_dataloader: DataLoader for prompt data. + :type prompts_dataloader: DataLoader + :param pretrain_dataloader: DataLoader for pre-training data. + :type pretrain_dataloader: DataLoader + :param eval_dataloader: DataLoader for evaluation data, defaults to None. + :type eval_dataloader: DataLoader, optional + :param consumed_samples: Number of samples already consumed, defaults to 0. :type consumed_samples: int - :param num_update_steps_per_episodes: Number of update steps per training episode + :param num_update_steps_per_episodes: Number of update steps per episode, defaults to 1. 
:type num_update_steps_per_episodes: int - :return: None - - Example:: + """ + # Determine if using partial rollout + use_partial = getattr(self.args, 'use_partial', False) - trainer.fit( - args=training_args, - prompts_dataloader=image_prompt_loader, - pretrain_dataloader=pretrain_loader, - consumed_samples=0, - num_update_steps_per_episodes=4 + # Calculate samples per rollout and per training iteration + samples_per_rollout = args.rollout_batch_size * args.n_samples_per_prompt + samples_per_train = args.train_batch_size * args.n_samples_per_prompt + + # Print training mode information + if args.train_batch_size < args.rollout_batch_size: + updates_per_rollout = samples_per_rollout / samples_per_train + self.strategy.print( + f"\n{'=' * 80}\n" + f"HIGH FREQUENCY UPDATE MODE: train_batch_size ({args.train_batch_size}) < rollout_batch_size ({args.rollout_batch_size})\n" # noqa + f"{'=' * 80}\n" + f"Behavior:\n" + f" - Each rollout generates {samples_per_rollout} samples.\n" + f" - Each rollout will trigger {updates_per_rollout:.2f} optimizer updates.\n" + f" - Total updates will be HIGHER than standard mode for the same amount of data.\n" + f"{'=' * 80}\n" + ) + elif args.train_batch_size > args.rollout_batch_size: + self.strategy.print( + f"\n{'=' * 80}\n" + f"ACCUMULATION MODE: train_batch_size ({args.train_batch_size}) > rollout_batch_size ({args.rollout_batch_size})\n" # noqa + f"{'=' * 80}\n" + f"Behavior:\n" + f" - Multiple rollouts needed for one update.\n" + f"{'=' * 80}\n" ) - .. note:: Distributed Training - This method handles distributed data parallelism automatically using - DistributedSampler when available. Each rank processes a different - subset of the data. + # Calculate number of rollouts per episode. + # Regardless of TBS and RBS relationship, rollout count should be determined by "total data / rollout size". + # Numerator (num_update_steps * train_batch_size) equals "total samples planned for this episode". 
+ # Denominator (rollout_batch_size * n_samples) equals "samples produced per rollout". + # This calculation ensures data collection volume is constant. + # When TBS=64, num_update_steps is naturally twice as large as when TBS=128. + # Substituting into formula: (2N * 0.5T) / R = (N * T) / R. + # Conclusion: Rollout count unchanged, but internal update loop count doubles due to smaller TBS. - .. warning:: Memory Management - The method includes explicit memory reporting and cache clearing - to manage GPU memory efficiently during training. - """ - # Calculate number of rollouts per episode based on batch sizes and epochs num_rollouts_per_episodes = ( - num_update_steps_per_episodes - * args.train_batch_size - // args.max_epochs - // args.rollout_batch_size - // args.n_samples_per_prompt + num_update_steps_per_episodes * args.train_batch_size // args.max_epochs // args.rollout_batch_size // + args.n_samples_per_prompt ) - # Configure evaluation and checkpointing intervals - # If eval_steps is -1, evaluate once per epoch + # Safeguard to prevent num_rollouts_per_episodes from being 0 + if num_rollouts_per_episodes == 0: + # Try recalculating with ceil to prevent fractional values from being discarded by integer division + val = (num_update_steps_per_episodes * + args.train_batch_size) / (args.max_epochs * args.rollout_batch_size * args.n_samples_per_prompt) + num_rollouts_per_episodes = math.ceil(val) + + if num_rollouts_per_episodes == 0: + self.strategy.print("[WARNING] Calculated num_rollouts_per_episodes is 0. 
Forcing to 1.") + num_rollouts_per_episodes = 1 + + # Get eval and save steps if args.eval_steps == -1: - args.eval_steps = num_rollouts_per_episodes - # If save_steps is -1, disable checkpoint saving + args.eval_steps = num_rollouts_per_episodes # Evaluate once per epoch if args.save_steps == -1: - args.save_steps = float("inf") + args.save_steps = float("inf") # Do not save checkpoint - # Store data loaders for access throughout training self.prompts_dataloader = prompts_dataloader self.pretrain_dataloader = pretrain_dataloader + self.eval_dataloader = eval_dataloader # Save for evaluation - # Calculate starting step and episode for resuming training + # Restore step and start_episode steps = consumed_samples // args.rollout_batch_size + 1 start_episode = consumed_samples // args.rollout_batch_size // num_rollouts_per_episodes consumed_samples = consumed_samples % (num_rollouts_per_episodes * args.rollout_batch_size) - # Determine if using partial rollout - use_partial = getattr(self.args, 'use_partial', False) - # Main training loop over episodes for episode in range(start_episode, args.num_episodes): # Configure distributed sampler for current episode @@ -903,26 +1038,25 @@ def fit( disable=not self.strategy.is_rank_0(), ) - # Create experience iterator + # Unified training loop using experience iterator experience_iterator = self._make_experience_iterator(self.prompts_dataloader, use_partial) for experiences in experience_iterator: # Process experiences and perform training step status = self._process_experiences_and_train(experiences, steps) - # Update progress bar with training status + # Update progress bar with training status (includes rollout stats) pbar.set_postfix(status) # Save logs and checkpoints at appropriate intervals client_states = {"consumed_samples": steps * args.rollout_batch_size} - self.save_logs_and_checkpoints(args, steps, pbar, status, client_states) + self.save_logs_and_checkpoints(args, steps, pbar, status, client_states, 
episode=episode) # Update step counter and progress bar pbar.update() steps = steps + 1 - # Clean up monitoring tools if self._wandb is not None and self.strategy.is_rank_0(): self._wandb.finish() if self._tensorboard is not None and self.strategy.is_rank_0(): - self._tensorboard.close() + self._tensorboard.close() \ No newline at end of file