diff --git a/examples/gsm8k_geo3k/train_colocate.py b/examples/gsm8k_geo3k/train_colocate.py index 818bcce2..5a378a44 100644 --- a/examples/gsm8k_geo3k/train_colocate.py +++ b/examples/gsm8k_geo3k/train_colocate.py @@ -398,6 +398,10 @@ def train(args): overlong_buffer_len=args.overlong_buffer_len, overlong_buffer_penalty_factor=args.overlong_buffer_penalty_factor, print_replay_buffer_stats=args.print_replay_buffer_stats, + # partial rollout + use_partial=args.use_partial, + partial_percent=args.partial_percent, + max_budget=args.max_budget, ) trainer.fit(args, prompts_dataloader=prompts_dataloader, pretrain_dataloader=pretrain_dataloader, eval_dataloader=eval_dataloader, consumed_samples=0, num_update_steps_per_episodes=num_update_steps_per_episodes) @@ -610,6 +614,11 @@ def train(args): # High-entropy token filtering (from "Beyond the 80/20 Rule" paper) parser.add_argument("--high_entropy_token_ratio", type=float, default=0.0, help="Ratio of high-entropy tokens to use for gradient updates (0.0 means use all tokens, 0.2 means use top 20% highest entropy tokens). Common value when enabled: 0.2. Based on 'Beyond the 80/20 Rule: High-Entropy Minority Tokens Drive Effective Reinforcement Learning for LLM Reasoning' (https://arxiv.org/abs/2506.01939)") + # Partial Rollout + parser.add_argument("--use_partial", action="store_true", default=False, help="whether to use the partial rollout trainer") + parser.add_argument("--partial_percent", type=float, default=0.7, help="partial rollout percent") + parser.add_argument("--max_budget", type=float, default=1024, help="partial rollout max_budget") + add_arguments(parser) args = parser.parse_args() diff --git a/lightrft/trainer/fast_exp_maker_partial.py b/lightrft/trainer/fast_exp_maker_partial.py new file mode 100644 index 00000000..720c127a --- /dev/null +++ b/lightrft/trainer/fast_exp_maker_partial.py @@ -0,0 +1,712 @@ +""" +PartialFastExperienceMaker Module – FastExperienceMaker with Partial Rollout and Token‑Budget Regeneration. + +This module extends FastExperienceMaker to support incremental rollout and controlled generation +for long‑form or high‑cost tasks. It introduces two core mechanisms: + 1. Partial rollout: only a fraction (partial_percent) of the total rollout batch is generated + in each call; the remaining samples are kept in buffers for subsequent steps, reducing + per‑iteration latency and enabling smoother pipeline scheduling. + 2. Token‑budget regeneration: samples that reach a predefined token budget (max_token_budget) + are flagged and can be regenerated later, allowing continuation of long‑form generation + without discarding already‑produced content. + +The class reuses the parent's infrastructure (MultimodalDataProcessor, RewardComputationEngine, +etc.) and only overrides the methods that implement the partial‑rollout logic. + +Implementation Overview: + - The rollout batch is split into "regeneration" and "non‑regeneration" buffers based on + whether a sample has exhausted its token budget. + - The method `need_new_prompts` determines whether fresh prompts are required to fill the + partial batch. + - Generation is performed via the parent's inference engine (VLLM/SGLang), but outputs are + post‑processed to respect the token budget and partial fraction. + - Advantage estimation methods (RLOO, Group Norm, REINFORCE) are adjusted to account for the + reduced group size introduced by partial rollout. + +Key Features: + - Partial rollout with configurable fraction (partial_percent) for reduced iteration latency + - Token‑budget regeneration for long‑form continuation (max_token_budget) + - Buffered sample management (regen_buffer, noregen_buffer) for stateful rollout + - Seamless integration with VLLM/SGLang backends and multimodal processing + - Adaptive advantage estimation that respects the partial batch size + - Support for both packed and unpacked sample formats + +Parameters: + partial_percent (float): Fraction of the total rollout batch to generate in one call. + Values between 0.0 and 1.0. Default: 0.7. + max_token_budget (int): Maximum allowed generation length before a sample is flagged for + regeneration. Samples that reach this length are stored in the regeneration buffer + and can be continued in a later step. Default: 1024. + +References: + - Kimi1.5: "Kimi k1.5: Scaling Reinforcement Learning with LLMs" (https://arxiv.org/abs/2501.12599) + - MiMo: "MiMo: Unlocking the Reasoning Potential of Language Model + -- From Pretraining to Posttraining" (https://arxiv.org/abs/2505.07608) + +Classes: + PartialFastExperienceMaker: Main experience generation class with partial‑rollout support. +""" + +from typing import List, Optional, Union, Tuple, Dict, Any +import os +import time +from copy import deepcopy + +import torch +from vllm import SamplingParams +from easydict import EasyDict + +from lightrft.trainer.experience_maker import Experience, Samples +from lightrft.trainer.experience_maker_vl import SamplesVL +from lightrft.trainer.fast_exp_maker import FastExperienceMaker +from lightrft.utils import Timer, get_current_device + + +class PartialFastExperienceMaker(FastExperienceMaker): + """ + FastExperienceMaker with partial rollout and token‑budget regeneration. + + This class extends FastExperienceMaker to support incremental rollout and controlled generation + for long‑form or high‑cost tasks. It introduces two core mechanisms: + 1. Partial rollout: only a fraction (partial_percent) of the total rollout batch is generated + in each call; the remaining samples are kept in buffers for subsequent steps, reducing + per‑iteration latency and enabling smoother pipeline scheduling. + 2. Token‑budget regeneration: samples that reach a predefined token budget (max_token_budget) + are flagged and can be regenerated later, allowing continuation of long‑form generation + without discarding already‑produced content. + + The class reuses the parent's infrastructure (MultimodalDataProcessor, RewardComputationEngine, + etc.) and only overrides the methods that implement the partial‑rollout logic. + + The partial‑rollout pipeline: + 1. Buffer Management: Maintain regeneration (regen) and non‑regeneration (noregen) buffers + 2. Need‑Prompts Check: Determine if fresh prompts are required to fill the partial batch + 3. Generation: Use parent's inference engine (VLLM/SGLang) but respect token budget and partial fraction + 4. Regeneration: For samples that exceed token budget, regenerate with continuation + 5. Advantage Adaptation: Adjust advantage estimators (RLOO, Group Norm, REINFORCE) for partial group size + + Args: + partial_percent: Fraction of the total rollout batch to generate in one call (0.0‑1.0) + max_token_budget: Maximum allowed generation length before a sample is flagged for regeneration + packing_samples: Whether to pack multiple sequences into single batch (inherited from parent) + processor: Multimodal processor for vision‑language models (inherited from parent) + *args, **kwargs: Arguments passed to parent FastExperienceMaker + """ + + def __init__( + self, + *args, + partial_percent: float = 0.7, + max_token_budget: int = 1024, + packing_samples: bool = False, + processor=None, + **kwargs + ): + """ + Initialize PartialFastExperienceMaker. + + :param args: Positional arguments for parent FastExperienceMaker + :type args: tuple + :param partial_percent: Fraction of total rollout batch to generate per call (0.0‑1.0) + :type partial_percent: float + :param max_token_budget: Maximum generation length before a sample is flagged for regeneration + :type max_token_budget: int + :param packing_samples: Enable sample packing for efficiency (inherited) + :type packing_samples: bool + :param processor: Multimodal processor for vision‑language models (inherited) + :type processor: Optional[Any] + :param kwargs: Keyword arguments for parent FastExperienceMaker + :type kwargs: dict + """ + super().__init__(*args, packing_samples=packing_samples, processor=processor, **kwargs) + self.partial_percent = partial_percent + self.max_token_budget = max_token_budget + + # Buffers for regeneration (regen) and non‑regeneration (noregen) samples. + # Each buffer is a dict mapping field names to lists of data. + self.regen_buffer: Dict[str, List] = {} + self.noregen_buffer: Dict[str, List] = {} + fields = [ + 'output', 'labels', 'prompts', 'references', + 'images', 'images_num', 'images_grid_thw', 'images_pixel_values', + 'videos', 'videos_num', 'videos_grid_thw', 'videos_pixel_values' + ] + for field in fields: + self.regen_buffer[field] = [] + self.noregen_buffer[field] = [] + + # Placeholders for batch‑size parameters (set by need_new_prompts) + self.rollout_batch_size = None + self.micro_rollout_batch_size = None + + def need_new_prompts(self, rollout_batch_size: int, micro_rollout_batch_size: int) -> bool: + """ + Determine whether new prompts are required to fill the partial rollout batch. + + This method checks the current regeneration and non‑regeneration buffers + and compares the total stored samples against the number needed for the + current partial rollout (partial_percent × total rollout batch size). + + If the buffers contain insufficient samples, the caller should fetch fresh + prompts and call generate_samples with those prompts. + + :param rollout_batch_size: Total number of samples in a full rollout batch + :type rollout_batch_size: int + :param micro_rollout_batch_size: Size of each micro‑batch used in generation + :type micro_rollout_batch_size: int + :return: True if new prompts are needed (buffers below partial threshold), else False + :rtype: bool + """ + self.rollout_batch_size = rollout_batch_size + self.micro_rollout_batch_size = micro_rollout_batch_size + + # Total micro‑batches needed for a full rollout + total_micro = rollout_batch_size // self.strategy.world_size + # Micro‑batches we want to generate in one call + target_micro = int(self.partial_percent * total_micro) + required_samples = target_micro * micro_rollout_batch_size + + # Count how many samples are already stored in both buffers + total_samples = len(self.noregen_buffer.get('output', [])) + len(self.regen_buffer.get('output', [])) + return total_samples < required_samples + + @torch.no_grad() + def generate_samples( + self, + all_prompts: List[str], + all_images: Optional[List] = None, + all_videos: Optional[List] = None, + images_num: Optional[List[int]] = None, + videos_num: Optional[List[int]] = None, + all_references: Optional[List[str]] = None, + all_labels: Optional[List] = None, + **generate_kwargs + ) -> List[Samples]: + """ + Generate samples using the parent's pipeline, but only a partial fraction. + + This method implements the partial‑rollout logic: + 1. If new prompts are provided, generate them via the parent's inference engine + (VLLM/SGLang) using token‑budget‑limited generation. + 2. Split the generated outputs into regeneration (regen) and non‑regeneration (noregen) + buffers based on whether they have reached the token budget. + 3. Draw from the buffers to produce the requested number of samples + (partial_percent × total rollout batch size). + 4. If the noregen buffer is insufficient, regenerate some samples from the regen buffer + by continuing generation from the partially‑produced output. + + The method returns a list of Samples (or SamplesVL) ready for experience making. + When new prompts are provided, it also returns the image counts for multimodal data. + + :param all_prompts: List of text prompts (or None to only draw from buffers) + :type all_prompts: List[str] + :param all_images: Optional images for vision‑language models + :type all_images: Optional[List] + :param all_videos: Optional videos for vision‑language models + :type all_videos: Optional[List] + :param images_num: Number of images per prompt + :type images_num: Optional[List[int]] + :param videos_num: Number of videos per prompt + :type videos_num: Optional[List[int]] + :param all_references: Reference texts for evaluation + :type all_references: Optional[List[str]] + :param all_labels: Sample labels for reward shaping + :type all_labels: Optional[List] + :param generate_kwargs: Generation parameters (temperature, max_new_tokens, etc.) + :type generate_kwargs: dict + :return: List of Samples (or SamplesVL) when all_prompts is None, + otherwise tuple (samples_list, images_num_list) + :rtype: Union[List[Samples], Tuple[List[Samples], Optional[List[int]]]] + """ + assert self.strategy.inference_engine is not None, "Inference engine required" + + torch.cuda.synchronize() + start_time = time.time() + + config = self.strategy.config + if all_prompts is not None: + is_multimodal = all_images is not None or all_videos is not None + else: + is_multimodal = (len(self.noregen_buffer.get('images', [])) + len(self.regen_buffer.get('images', []))) != 0 or \ + (len(self.noregen_buffer.get('videos', [])) + len(self.regen_buffer.get('videos', []))) != 0 + n_samples = config.n_samples_per_prompt + + # -------------------------------------------------------------------- + # Step 1: Generate new samples if inputs are provided + # -------------------------------------------------------------------- + if all_prompts is not None: + # Replicate the generation logic from fast_exp_maker_partial.py + # Prepare sampling parameters + if config.engine_type == "vllm": + sampling_params = SamplingParams( + temperature=generate_kwargs.get("temperature", 1.0), + top_p=generate_kwargs.get("top_p", 1.0), + top_k=generate_kwargs.get("top_k", -1), + max_tokens=self.max_token_budget, # use token budget + min_tokens=generate_kwargs.get("min_new_tokens", 1), + skip_special_tokens=generate_kwargs.get("skip_special_tokens", False), + include_stop_str_in_output=True, + ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1", + ) + elif config.engine_type == "sglang": + sampling_params = dict( + n=1, + temperature=generate_kwargs.get("temperature", 1.0), + top_p=generate_kwargs.get("top_p", 1.0), + top_k=generate_kwargs.get("top_k", -1), + max_new_tokens=self.max_token_budget, + presence_penalty=0.0, + frequency_penalty=0.0, + repetition_penalty=1.0, + skip_special_tokens=generate_kwargs.get("skip_special_tokens", False), + spaces_between_special_tokens=True, + ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1", + ) + else: + raise ValueError(f"Unsupported backend: {config.engine_type}") + + # Expand labels + expanded_labels = sum([[label] * n_samples for label in all_labels], []) if all_labels else [] + + # Process multimodal data + if is_multimodal: + processed = self.multimodal_processor.process_multimodal_batch( + all_prompts=all_prompts, + all_images=all_images, + all_videos=all_videos, + all_references=all_references, + images_num=images_num, + videos_num=videos_num, + n_samples_per_prompt=n_samples, + ) + prompt_token_ids = processed["all_prompt_token_ids"] + prompts = processed["all_prompts"] + images = processed["all_images"] + images_num = processed["all_images_num"] + images_pixel_values = processed["all_images_pixel_values"] + images_grid_thw = processed["all_images_grid_thw"] + references = processed["all_references"] + videos = processed.get("all_videos") + videos_num = processed.get("all_videos_num") + videos_pixel_values = processed.get("all_videos_pixel_values") + videos_grid_thw = processed.get("all_videos_grid_thw") + else: + # Text-only processing + tokenized = self.tokenize_fn(all_prompts, self.prompt_max_len, padding=False) + prompt_token_ids = sum([[token_ids] * n_samples for token_ids in tokenized["input_ids"]], []) + + # ========== Generate via Inference Engine ========== + # Call fire_sampling function or direct generation + try: + if hasattr(self.strategy.args, 'use_fire') and self.strategy.args.use_fire: + # Use FIRE sampling (Flaming-hot Initiation with Regular Execution) + outputs = fire_sampling( + all_prompt_token_ids=all_prompt_token_ids, + generate_fn=generate_fn, # noqa: TODO + engine_type=config.engine_type, + first_token_temperature=generate_kwargs.get("first_token_temperature", 10.0), + temperature=generate_kwargs.get("temperature", 1.0), + first_token_top_k=generate_kwargs.get( + "first_token_top_k", sampling_params.top_k if hasattr(sampling_params, 'top_k') else -1 + ), + first_token_top_p=generate_kwargs.get( + "first_token_top_p", sampling_params.top_p if hasattr(sampling_params, 'top_p') else 1.0 + ), + is_multimodal=is_multimodal, + all_prompts=prompts, + all_images=images, + all_videos=videos, + all_images_num=images_num, + all_videos_num=videos_num, + sampling_params=sampling_params, + ) + else: + # maybe this can be called in if and else respectively? or like this? + # Use original single-shot generation + outputs = self.strategy.gather_and_generate( + sampling_params=sampling_params, + all_prompt_token_ids=prompt_token_ids, + all_prompts=prompts if is_multimodal else None, + sleep_engine=self.strategy.args.enable_engine_sleep, + all_images=images if is_multimodal else None, + all_videos=videos if is_multimodal else None, + images_num=images_num if is_multimodal else None, + videos_num=videos_num if is_multimodal else None, + ) + except ValueError as e: + if "prompt" in str(e) and "too long" in str(e): + self.strategy.print(f"[Skip] {e}") + return None # Return None, subsequent experience_maker will ignore + else: + raise + + # Process outputs in micro-batches and store in buffers + for i in range(0, len(outputs), n_samples): + batch_slice = slice(i, i + n_samples) + output_batch = outputs[batch_slice] + labels_batch = expanded_labels[batch_slice] if expanded_labels else [] + prompts_batch = prompts[batch_slice] + images_batch = images[batch_slice] if images else None + images_num_batch = images_num[batch_slice] if images_num else None + videos_batch = videos[batch_slice] if videos else None + videos_num_batch = videos_num[batch_slice] if videos_num else None + references_batch = references[batch_slice] if references else None + + # Check if regeneration is needed + needs_regen = any(len(out.output_token_ids) >= self.max_token_budget for out in output_batch) + buffer_type = "regen" if needs_regen else "noregen" + + # Add to appropriate buffer + self._add_to_buffer(buffer_type, "output", output_batch) + self._add_to_buffer(buffer_type, "labels", labels_batch) + self._add_to_buffer(buffer_type, "prompts", prompts_batch) + if images_batch is not None: + self._add_to_buffer(buffer_type, "images", images_batch) + if images_num_batch is not None: + self._add_to_buffer(buffer_type, "images_num", images_num_batch) + if videos_batch is not None: + self._add_to_buffer(buffer_type, "videos", videos_batch) + if videos_num_batch is not None: + self._add_to_buffer(buffer_type, "videos_num", videos_num_batch) + if references_batch is not None: + self._add_to_buffer(buffer_type, "references", references_batch) + + if is_multimodal: + # Handle image tensors + grid_batch = images_grid_thw[batch_slice] + self._add_to_buffer(buffer_type, "images_grid_thw", grid_batch) + # Calculate pixel values slice + patch_start = sum(g[0] * g[1] * g[2] for g in images_grid_thw[:i]) + patch_end = patch_start + sum(g[0] * g[1] * g[2] for g in grid_batch) + self._add_to_buffer(buffer_type, "images_pixel_values", images_pixel_values[patch_start:patch_end]) + # Handle video tensors + if videos_grid_thw is not None: + videos_grid_batch = videos_grid_thw[batch_slice] + self._add_to_buffer(buffer_type, "videos_grid_thw", videos_grid_batch) + + + # -------------------------------------------------------------------- + # Step 2: Determine how many micro‑batches we need to return + # -------------------------------------------------------------------- + total_micro = self.rollout_batch_size // self.strategy.world_size + target_micro = int(self.partial_percent * total_micro) + + # How many micro‑batches are already available in the noregen buffer? + noregen_micro = len(self.noregen_buffer['output']) // self.micro_rollout_batch_size + if noregen_micro >= target_micro: + # Enough noregen samples – just take them + samples_data = self._get_from_buffer('noregen', target_micro * self.micro_rollout_batch_size) + else: + # Take all noregen samples and supplement with regenerated ones + samples_needed = target_micro - noregen_micro + noregen_data = self._get_from_buffer('noregen', noregen_micro * self.micro_rollout_batch_size) + regen_data = self._regenerate_from_buffer(samples_needed * self.micro_rollout_batch_size, is_multimodal, **generate_kwargs) + samples_data = self._merge_data(noregen_data, regen_data) + + # -------------------------------------------------------------------- + # Step 3: Convert the collected data back to Samples objects + # -------------------------------------------------------------------- + samples_list = [] + image_patch_idx = 0 + video_patch_idx = 0 + image_start_idx = 0 + video_start_idx = 0 + + all_outputs = samples_data.get("output", []) + all_labels = samples_data.get("labels", []) + all_prompts = samples_data.get("prompts", []) + all_images = samples_data.get("images", []) + all_images_num = samples_data.get("images_num", None) + all_images_grid_thw = samples_data.get("images_grid_thw", None) + all_images_pixel_values = samples_data.get("images_pixel_values", None) + all_videos_num = samples_data.get("videos_num", None) + all_videos_grid_thw = samples_data.get("videos_grid_thw", None) + all_videos_pixel_values = samples_data.get("videos_pixel_values", None) + all_references = samples_data.get("references", []) + + for i in range(0, len(all_outputs), config.micro_rollout_batch_size): + micro_batch_outputs = all_outputs[i:i + config.micro_rollout_batch_size] + micro_batch_prompts = all_prompts[i:i + config.micro_rollout_batch_size] + + # Extract micro-batch data + micro_batch_grid_thw = None + micro_batch_video_grid_thw = None + micro_batch_raw_images = None + + if is_multimodal: + rollout_image_count = sum(all_images_num[i:i + config.micro_rollout_batch_size]) + micro_batch_grid_thw = all_images_grid_thw[image_start_idx:image_start_idx + rollout_image_count] + micro_batch_raw_images = all_images[i:i + config.micro_rollout_batch_size] + image_start_idx += rollout_image_count + + rollout_video_count = sum(all_videos_num[i:i + config.micro_rollout_batch_size]) + micro_batch_video_grid_thw = all_videos_grid_thw[video_start_idx:video_start_idx + rollout_video_count] + video_start_idx += rollout_video_count + + micro_batch_references = (all_references[i:i + config.micro_rollout_batch_size] if all_references else None) + micro_batch_labels = (all_labels[i:i + config.micro_rollout_batch_size] if all_labels else None) + # Build samples + if not self.packing_samples: + sample, updated_patch_idx, updated_video_patch_idx = self._build_unpacked_sample( + outputs=micro_batch_outputs, + prompts=micro_batch_prompts, + labels=micro_batch_labels, + references=micro_batch_references, + is_multimodal=is_multimodal, + grid_thw=micro_batch_grid_thw, + video_grid_thw=micro_batch_video_grid_thw, + raw_images=micro_batch_raw_images, + pixel_values=all_images_pixel_values if is_multimodal else None, + pixel_values_videos=all_videos_pixel_values if is_multimodal else None, + images_num=all_images_num[i:i + config.micro_rollout_batch_size] if is_multimodal else None, + videos_num=all_videos_num[i:i + config.micro_rollout_batch_size] if is_multimodal else None, + image_patch_idx=image_patch_idx, + video_patch_idx=video_patch_idx, + ) + # Update patch indices from the returned values + if updated_patch_idx is not None: + image_patch_idx = updated_patch_idx + if updated_video_patch_idx is not None: + video_patch_idx = updated_video_patch_idx + samples_list.append(sample) + else: + # Packed samples + sample = self._build_packed_sample( + outputs=micro_batch_outputs, + prompts=micro_batch_prompts, + labels=micro_batch_labels, + references=micro_batch_references, + ) + samples_list.append(sample) + + # Report timing + torch.cuda.synchronize() + gen_time = torch.tensor(time.time() - start_time, device=get_current_device()) + torch.distributed.all_reduce(gen_time, op=torch.distributed.ReduceOp.MAX) + self.strategy.print(f"***Rollout engine generation time (global max): {gen_time.item():.4f}s") + self.strategy.report_memory("after rollout engine generation") + + return samples_list + + def _add_to_buffer(self, buffer_type: str, data_name: str, data): + """Add data to specified buffer. + + Args: + buffer_type: 'regen' or 'noregen' + data_name: Key name for storing data + data: Data to add (can be tensor, list, or other) + + Special handling: + - Keys with 'grid_thw': split 2D tensors by rows + - Keys with 'pixel_values': keep 2D tensors as-is + - Other 2D tensors: split by rows + """ + buffer = self.regen_buffer if buffer_type == 'regen' else self.noregen_buffer + if data_name not in buffer: + buffer[data_name] = [] + + if isinstance(data, torch.Tensor): + is_grid_thw = 'grid_thw' in data_name + is_pixel_values = 'pixel_values' in data_name + + if data.dim() == 2: + if is_grid_thw: + # Split grid_thw 2D tensors by rows + buffer[data_name].extend(torch.unbind(data, dim=0)) + elif is_pixel_values: + # Keep pixel_values 2D tensors intact + buffer[data_name].append(data) + else: + # Split other 2D tensors by rows + buffer[data_name].extend(torch.unbind(data, dim=0)) + else: + # Add 1D or higher-dim tensors as-is + buffer[data_name].append(data) + else: + buffer[data_name].extend(data if isinstance(data, list) else [data]) + + + def _get_from_buffer(self, buffer_type: str, count: Optional[int] = None): + """Retrieve data from buffer, optionally limiting the amount. + + Args: + buffer_type: 'regen' or 'noregen' + count: Number of items to retrieve. If None, retrieve all. + + Returns: + Dictionary with retrieved data. + Special handling: + - grid_thw keys: stack 1D tensors to 2D + - pixel_values keys: concatenate 2D tensors + """ + buffer = self.regen_buffer if buffer_type == 'regen' else self.noregen_buffer + result = {} + + for key, lst in buffer.items(): + if not lst: + # Return empty tensor with proper shape + if 'grid_thw' in key: + result[key] = torch.tensor([]).reshape(0, 3) + elif 'pixel_values' in key: + result[key] = torch.tensor([]) + else: + result[key] = torch.tensor([]) + + if count is None: + buffer[key] = [] + continue + + all_tensors = all(isinstance(item, torch.Tensor) for item in lst) + + if count is None: + # Retrieve all data + if all_tensors: + if 'pixel_values' in key and lst[0].dim() >= 2: + # Concatenate pixel_values 2D tensors + result[key] = torch.cat(lst, dim=0) if lst else torch.tensor([]) + elif 'grid_thw' in key and lst[0].dim() == 1: + # Stack grid_thw 1D tensors to 2D + result[key] = torch.stack(lst, dim=0) if lst else torch.tensor([]).reshape(0, 3) + elif lst[0].dim() == 1: + # Stack 1D tensors to 2D + result[key] = torch.stack(lst, dim=0) if lst else torch.tensor([]) + else: + # Concatenate other tensors + result[key] = torch.cat(lst, dim=0) if lst else torch.tensor([]) + else: + result[key] = lst.copy() + buffer[key] = [] + else: + # Retrieve specified number of items + items_to_take = lst[:count] + + if all_tensors and items_to_take: + if 'pixel_values' in key and items_to_take[0].dim() >= 2: + result[key] = torch.cat(items_to_take, dim=0) + elif 'grid_thw' in key and items_to_take[0].dim() == 1: + result[key] = torch.stack(items_to_take, dim=0) + elif items_to_take[0].dim() == 1: + result[key] = torch.stack(items_to_take, dim=0) + else: + result[key] = torch.cat(items_to_take, dim=0) + else: + result[key] = items_to_take + + # Update buffer + buffer[key] = lst[count:] + + return result + + @torch.no_grad() + def _regenerate_from_buffer(self, num_needed: int, is_multimodal: bool, **kwargs) -> dict: + """Regenerate outputs for samples that reached token budget.""" + config = self.strategy.config + + # Get data from regeneration buffer + regen_data = self._get_from_buffer("regen", num_needed) + if not regen_data.get("output"): + return {} + + # Identify indices needing regeneration + regen_indices = [ + i for i, output in enumerate(regen_data["output"]) + if len(output.output_token_ids) >= self.max_token_budget + ] + + if not regen_indices: + return regen_data + + # Prepare regeneration inputs + regen_outputs = [regen_data["output"][i] for i in regen_indices] + regen_tokens = [output.output_token_ids for output in regen_outputs] + decoded_outputs = self.tokenizer.batch_decode(regen_tokens, skip_special_tokens=False) + + # Create new inputs by combining original prompts and partial outputs + new_inputs = [ + prompt + output + for prompt, output in zip( + [regen_data["prompts"][i] for i in regen_indices], + decoded_outputs + ) + ] + + # Prepare sampling parameters + if config.engine_type == "vllm": + sampling_params = SamplingParams( + temperature=kwargs.get("temperature", 1.0), + top_p=kwargs.get("top_p", 1.0), + top_k=kwargs.get("top_k", -1), + max_tokens=kwargs.get("max_new_tokens", 1024), + min_tokens=kwargs.get("min_new_tokens", 1), + skip_special_tokens=kwargs.get("skip_special_tokens", False), + include_stop_str_in_output=True, + ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1", + ) + elif config.engine_type == "sglang": + sampling_params = dict( + n=1, + temperature=kwargs.get("temperature", 1.0), + top_p=kwargs.get("top_p", 1.0), + top_k=kwargs.get("top_k", -1), + max_new_tokens=kwargs.get("max_new_tokens", 1024), + presence_penalty=0.0, + frequency_penalty=0.0, + repetition_penalty=1.0, + skip_special_tokens=kwargs.get("skip_special_tokens", False), + spaces_between_special_tokens=True, + ignore_eos=os.environ.get("IGNORE_EOS", "0") == "1", + ) + else: + raise ValueError(f"Unsupported backend: {config.engine_type}") + + # Build inputs and regenerate using the same pattern + if is_multimodal: + # Use strategy._build_multimodal_inputs + inputs = self.strategy._build_multimodal_inputs( + all_prompts=new_inputs, + all_images=[regen_data["images"][i] for i in regen_indices], + images_num=[regen_data["images_num"][i] for i in regen_indices] + ) + # Use engine_generate_local for multimodal regeneration + regenerated = self.strategy.engine_generate_local( + sampling_params=sampling_params, + prompt_token_ids=None, + multi_modal_inputs=inputs, + ) + else: + # For text-only, we can reuse parent's generate_samples but need raw outputs. + # Instead, we can directly call strategy.gather_and_generate with tokenized inputs. + # Tokenize new prompts + tokenized = self.tokenize_fn(new_inputs, self.prompt_max_len, padding=False) + prompt_token_ids = tokenized["input_ids"] + # Expand by n_samples_per_prompt (should be 1 for regeneration?) + # In partial rollout, each sample is already expanded, so we assume n_samples_per_prompt=1. + # Use strategy.gather_and_generate + regenerated = self.strategy.gather_and_generate( + sampling_params=sampling_params, + all_prompt_token_ids=prompt_token_ids, + all_prompts=None, + all_images=None, + sleep_engine=False, + images_num=None, + ) + + # Update regenerated outputs in regen_data + for idx, new_output in zip(regen_indices, regenerated): + regen_data["output"][idx] = new_output + + return regen_data + + def _merge_data(self, data1: Dict[str, List], data2: Dict[str, List]) -> Dict[str, List]: + """Merge two data dictionaries, concatenating lists or tensors.""" + merged = {} + for key in set(data1.keys()) | set(data2.keys()): + val1 = data1.get(key, []) + val2 = data2.get(key, []) + if isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor): + merged[key] = torch.cat([val1, val2]) + elif isinstance(val1, list) and isinstance(val2, list): + merged[key] = val1 + val2 + else: + merged[key] = val1 if val1 else val2 + return merged diff --git a/lightrft/trainer/spmd_ppo_trainer.py b/lightrft/trainer/spmd_ppo_trainer.py index d79a7458..93132235 100644 --- a/lightrft/trainer/spmd_ppo_trainer.py +++ b/lightrft/trainer/spmd_ppo_trainer.py @@ -21,15 +21,18 @@ import time import torch +import math from tqdm import tqdm from lightrft.trainer import PPOTrainer, PPOTrainerVL from lightrft.trainer.fast_exp_maker import FastExperienceMaker +from lightrft.trainer.fast_exp_maker_partial import PartialFastExperienceMaker from lightrft.utils.trajectory_saver import create_trajectory_saver from lightrft.trainer.replay_buffer import make_experience_batch from lightrft.trainer.replay_buffer_vl import make_experience_batch as make_experience_batch_vl from lightrft.models.utils import create_high_entropy_mask +from lightrft.utils.distributed_sampler import DistributedSampler from lightrft.utils import init_logger logger = init_logger(__name__) @@ -689,3 +692,371 @@ def __init__( # Then initialize our base class assert "processor" in kwargs and kwargs["processor"] is not None, "processor is required for SPMDPPOTrainerVL" SPMDPPOTrainerBase.__init__(self, *args, VLM=True, **kwargs) + if self.args.use_partial: + # Replace experience maker with partial version + processor = kwargs.pop("processor", None) + self.experience_maker = PartialFastExperienceMaker( + self.actor, + self.critic, + self.reward_model, + self.initial_model, + self.tokenizer, + self.prompt_max_len, + self.kl_ctl, + self.strategy, + self.remote_rm_url, + self.reward_fn, + self.reward_fn_label_map, + self.reward_recipe, + packing_samples=self.packing_samples, + processor=processor, + partial_percent=getattr(self.args, "partial_percent", 0.7), + max_token_budget=getattr(self.args, "max_token_budget", 1024), + ) + + def _make_experience_iterator(self, dataloader, use_partial): + """ + Create an iterator that yields batches of experiences. + + This method handles both partial and non‑partial rollout logic. + For partial rollouts, it reuses cached prompts when possible to reduce + data loading overhead. For standard rollouts, it processes each batch + from the dataloader sequentially. + + :param dataloader: DataLoader providing prompts, images, references, and labels + :type dataloader: torch.utils.data.DataLoader + :param use_partial: Whether to use partial rollout logic + :type use_partial: bool + :yield: List of Experience objects for each training step + :ytype: List[lightrft.trainer.experience_maker_vl.Experience] + """ + if use_partial: + # Partial rollout logic + dataloader_iter = iter(dataloader) + while True: + # Generate experiences either from new prompts or cached ones + if self.experience_maker.need_new_prompts(self.args.rollout_batch_size, self.micro_rollout_batch_size): + try: + # Get next batch of prompts, images, references, and labels + batch = next(dataloader_iter) + # Handle variable batch size (4 or 5 elements) + if len(batch) == 5: + rand_prompts, rand_images, rand_videos, rand_references, rand_labels = batch + else: + rand_prompts, rand_images, rand_references, rand_labels = batch + rand_videos = None + except StopIteration: + # End of epoch reached + break + + # Generate experiences from new prompts + experiences = self.experience_maker.make_experience_list( + rand_prompts, rand_images, rand_videos, rand_references, rand_labels, + **self.generate_kwargs + ) + else: + # Generate experiences from cached prompts + experiences = self.experience_maker.make_experience_list( + None, None, None, None, None, **self.generate_kwargs + ) + yield experiences + else: + # Non-partial rollout logic + for batch in dataloader: + # Compatible with both image-only (4 args) and video (5 args) dataloaders + if len(batch) == 5: + rand_prompts, rand_images, rand_videos, rand_references, rand_labels = batch + else: + rand_prompts, rand_images, rand_references, rand_labels = batch + rand_videos = None + + # TODO: Remove debug print + self.strategy.print( + f"rand_prompts:\n {rand_prompts}\n , rand_images:{rand_images}\n , rand_references:{rand_references}\n, rand_labels:{rand_labels}\n " # noqa + ) + + experiences = self.experience_maker.make_experience_list( + rand_prompts, rand_images, rand_videos, rand_references, rand_labels, + **self.generate_kwargs + ) + + # Debug print for first experience + for i, experience in enumerate(experiences): + if i == 0: + output = self.tokenizer.batch_decode( + experience.sequences[0].unsqueeze(0), skip_special_tokens=True + ) + self.strategy.print("collect phase: experience.sequences w skip_special_tokens: ", output) + self.strategy.print( + f"collect phase: rand_prompts:\n {rand_prompts[0:2]}\n , rand_images:{rand_images[0:2]}\n , rand_references:{rand_references[0:2]}\n, rand_labels:{rand_labels[0:2]}\n " # noqa + ) + break + + yield experiences + + def _process_experiences_and_train(self, experiences, steps): + """ + Process a batch of experiences: add to replay buffer, train, and update metrics. + + This method handles the core training loop for each batch of experiences: + 1. Appends experiences to the replay buffer + 2. Reports memory usage + 3. Normalizes advantages (if not using group normalization) + 4. Executes PPO training + 5. Clears the replay buffer + 6. Updates KL control coefficient + + :param experiences: List of Experience objects to process + :type experiences: List[lightrft.trainer.experience_maker_vl.Experience] + :param steps: Current step counter for training progress tracking + :type steps: int + :return: Dictionary containing training status metrics (policy loss, critic loss, reward, etc.) + :rtype: Dict[str, float] + """ + # Add experiences to replay buffer + for i, experience in enumerate(experiences): + if i == 0: + # Decode first experience for debugging/monitoring + output = self.tokenizer.batch_decode( + experience.sequences[0].unsqueeze(0), skip_special_tokens=True + ) + self.replay_buffer.append(experience) + + # Report memory usage after replay buffer is filled + self.strategy.report_memory('after replay_buffer ready') + + # Aggregate rollout statistics from replay buffer + # Collect metrics from the rollout/collection phase + rollout_status = {} + if self.replay_buffer.items: + all_rewards = [] + all_format_rewards = [] + all_accuracy_rewards = [] + all_response_lengths = [] + + for item in self.replay_buffer.items: + # Collect rewards from rollout + if hasattr(item, 'info') and item.info is not None and 'reward' in item.info: + all_rewards.append(item.info['reward']) + + # Robust handling of reward_metrics + # 1. Check if info exists + # 2. Check if 'reward_metrics' key exists + # 3. Check if reward_metrics is not None (critical!) + if ( + hasattr(item, 'info') and item.info is not None and 'reward_metrics' in item.info + and item.info['reward_metrics'] is not None + ): + + reward_metrics = item.info['reward_metrics'] + + # Safely extract sub-metrics + if 'format_reward' in reward_metrics: + all_format_rewards.append(reward_metrics['format_reward']) + if 'accuracy_reward' in reward_metrics: + all_accuracy_rewards.append(reward_metrics['accuracy_reward']) + + # Collect response lengths from rollout + if hasattr(item, 'info') and item.info is not None and 'response_length' in item.info: + all_response_lengths.append(item.info['response_length']) + + # Compute rollout statistics + device = torch.cuda.current_device() + + if all_rewards: + # [TENSOR-FIX] Handle both tensor lists and scalar lists + if isinstance(all_rewards[0], torch.Tensor): + rewards_tensor = torch.cat([t.to(device).float() for t in all_rewards]) + else: + rewards_tensor = torch.tensor(all_rewards, dtype=torch.float32, device=device) + rollout_status["rollout_reward"] = rewards_tensor.mean().item() + rollout_status["rollout_reward_std"] = rewards_tensor.std().item() + + if all_format_rewards: + # [TENSOR-FIX] Handle both tensor lists and scalar lists + if isinstance(all_format_rewards[0], torch.Tensor): + format_tensor = torch.cat([t.to(device).float() for t in all_format_rewards]) + else: + format_tensor = torch.tensor(all_format_rewards, dtype=torch.float32, device=device) + + mean_format_reward = format_tensor.mean().item() + + # Only display if mean is significantly non-zero + if abs(mean_format_reward) > 1e-6: + rollout_status["rollout_format_reward"] = mean_format_reward + + if all_accuracy_rewards: + # [TENSOR-FIX] Handle both tensor lists and scalar lists + if isinstance(all_accuracy_rewards[0], torch.Tensor): + accuracy_tensor = torch.cat([t.to(device).float() for t in all_accuracy_rewards]) + else: + accuracy_tensor = torch.tensor(all_accuracy_rewards, dtype=torch.float32, device=device) + + mean_accuracy_reward = accuracy_tensor.mean().item() + + # Only display if mean is significantly non-zero + if abs(mean_accuracy_reward) > 1e-6: + rollout_status["rollout_accuracy_reward"] = mean_accuracy_reward + + if all_response_lengths: + # [TENSOR-FIX] Handle both tensor lists and scalar lists + if isinstance(all_response_lengths[0], torch.Tensor): + lengths_tensor = torch.cat([t.to(device).float() for t in all_response_lengths]) + else: + lengths_tensor = torch.tensor(all_response_lengths, dtype=torch.float32, device=device) + + rollout_status["rollout_response_length"] = lengths_tensor.mean().item() + + # Normalize advantages if not using group normalization + if self.args.advantage_estimator != "group_norm": + self.replay_buffer.normalize("advantages", self.strategy) + + # Execute training phase + self.strategy.report_memory('before train') + status = self.ppo_train(steps) + self.strategy.report_memory('before clear buffer') + + # Clear replay buffer for next iteration + self.replay_buffer.clear() + self.strategy.report_memory('after train') + + # Update KL control coefficient + if "kl" in status: + self.kl_ctl.update(status["kl"], self.args.rollout_batch_size * self.args.n_samples_per_prompt) + + # Merge rollout status and training status + merged_status = {**rollout_status, **status} + return merged_status + + def fit( + self, + args, + prompts_dataloader, + pretrain_dataloader, + eval_dataloader=None, + consumed_samples=0, + num_update_steps_per_episodes=1, + ) -> None: + """ + Main training loop for PPO. + + :param args: Training arguments. + :type args: Namespace + :param prompts_dataloader: DataLoader for prompt data. + :type prompts_dataloader: DataLoader + :param pretrain_dataloader: DataLoader for pre-training data. + :type pretrain_dataloader: DataLoader + :param eval_dataloader: DataLoader for evaluation data, defaults to None. + :type eval_dataloader: DataLoader, optional + :param consumed_samples: Number of samples already consumed, defaults to 0. + :type consumed_samples: int + :param num_update_steps_per_episodes: Number of update steps per episode, defaults to 1. + :type num_update_steps_per_episodes: int + """ + # Determine if using partial rollout + use_partial = getattr(self.args, 'use_partial', False) + + # Calculate samples per rollout and per training iteration + samples_per_rollout = args.rollout_batch_size * args.n_samples_per_prompt + samples_per_train = args.train_batch_size * args.n_samples_per_prompt + + # Print training mode information + if args.train_batch_size < args.rollout_batch_size: + updates_per_rollout = samples_per_rollout / samples_per_train + self.strategy.print( + f"\n{'=' * 80}\n" + f"HIGH FREQUENCY UPDATE MODE: train_batch_size ({args.train_batch_size}) < rollout_batch_size ({args.rollout_batch_size})\n" # noqa + f"{'=' * 80}\n" + f"Behavior:\n" + f" - Each rollout generates {samples_per_rollout} samples.\n" + f" - Each rollout will trigger {updates_per_rollout:.2f} optimizer updates.\n" + f" - Total updates will be HIGHER than standard mode for the same amount of data.\n" + f"{'=' * 80}\n" + ) + elif args.train_batch_size > args.rollout_batch_size: + self.strategy.print( + f"\n{'=' * 80}\n" + f"ACCUMULATION MODE: train_batch_size ({args.train_batch_size}) > rollout_batch_size ({args.rollout_batch_size})\n" # noqa + f"{'=' * 80}\n" + f"Behavior:\n" + f" - Multiple rollouts needed for one update.\n" + f"{'=' * 80}\n" + ) + + # Calculate number of rollouts per episode. + # Regardless of TBS and RBS relationship, rollout count should be determined by "total data / rollout size". + # Numerator (num_update_steps * train_batch_size) equals "total samples planned for this episode". + # Denominator (rollout_batch_size * n_samples) equals "samples produced per rollout". + # This calculation ensures data collection volume is constant. + # When TBS=64, num_update_steps is naturally twice as large as when TBS=128. + # Substituting into formula: (2N * 0.5T) / R = (N * T) / R. + # Conclusion: Rollout count unchanged, but internal update loop count doubles due to smaller TBS. + + num_rollouts_per_episodes = ( + num_update_steps_per_episodes * args.train_batch_size // args.max_epochs // args.rollout_batch_size // + args.n_samples_per_prompt + ) + + # Safeguard to prevent num_rollouts_per_episodes from being 0 + if num_rollouts_per_episodes == 0: + # Try recalculating with ceil to prevent fractional values from being discarded by integer division + val = (num_update_steps_per_episodes * + args.train_batch_size) / (args.max_epochs * args.rollout_batch_size * args.n_samples_per_prompt) + num_rollouts_per_episodes = math.ceil(val) + + if num_rollouts_per_episodes == 0: + self.strategy.print("[WARNING] Calculated num_rollouts_per_episodes is 0. Forcing to 1.") + num_rollouts_per_episodes = 1 + + # Get eval and save steps + if args.eval_steps == -1: + args.eval_steps = num_rollouts_per_episodes # Evaluate once per epoch + if args.save_steps == -1: + args.save_steps = float("inf") # Do not save checkpoint + + self.prompts_dataloader = prompts_dataloader + self.pretrain_dataloader = pretrain_dataloader + self.eval_dataloader = eval_dataloader # Save for evaluation + + # Restore step and start_episode + steps = consumed_samples // args.rollout_batch_size + 1 + start_episode = consumed_samples // args.rollout_batch_size // num_rollouts_per_episodes + consumed_samples = consumed_samples % (num_rollouts_per_episodes * args.rollout_batch_size) + + # Main training loop over episodes + for episode in range(start_episode, args.num_episodes): + # Configure distributed sampler for current episode + if isinstance(self.prompts_dataloader.sampler, DistributedSampler): + self.prompts_dataloader.sampler.set_epoch( + episode, consumed_samples=0 if episode > start_episode else consumed_samples + ) + + # Progress bar for monitoring training progress + pbar = tqdm( + range(self.prompts_dataloader.__len__()), + desc=f"Episode [{episode + 1}/{args.num_episodes}]", + disable=not self.strategy.is_rank_0(), + ) + + # Unified training loop using experience iterator + experience_iterator = self._make_experience_iterator(self.prompts_dataloader, use_partial) + + for experiences in experience_iterator: + # Process experiences and perform training step + status = self._process_experiences_and_train(experiences, steps) + + # Update progress bar with training status (includes rollout stats) + pbar.set_postfix(status) + + # Save logs and checkpoints at appropriate intervals + client_states = {"consumed_samples": steps * args.rollout_batch_size} + self.save_logs_and_checkpoints(args, steps, pbar, status, client_states, episode=episode) + + # Update step counter and progress bar + pbar.update() + steps = steps + 1 + # Clean up monitoring tools + if self._wandb is not None and self.strategy.is_rank_0(): + self._wandb.finish() + if self._tensorboard is not None and self.strategy.is_rank_0(): + self._tensorboard.close() \ No newline at end of file