From 0ebcdbfd984fd987ff6830746ab9530bc5fd9cf8 Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Wed, 31 Dec 2025 20:07:15 +0800 Subject: [PATCH 1/8] refactor(sunjx): refactor loss-filter for sample filtering and loss weighting --- lightrft/trainer/filter_weight/__init__.py | 167 +++++++ lightrft/trainer/filter_weight/filters.py | 424 +++++++++++++++++ lightrft/trainer/filter_weight/manager.py | 461 ++++++++++++++++++ lightrft/trainer/filter_weight/metrics.py | 332 +++++++++++++ lightrft/trainer/filter_weight/weights.py | 525 +++++++++++++++++++++ 5 files changed, 1909 insertions(+) create mode 100644 lightrft/trainer/filter_weight/__init__.py create mode 100644 lightrft/trainer/filter_weight/filters.py create mode 100644 lightrft/trainer/filter_weight/manager.py create mode 100644 lightrft/trainer/filter_weight/metrics.py create mode 100644 lightrft/trainer/filter_weight/weights.py diff --git a/lightrft/trainer/filter_weight/__init__.py b/lightrft/trainer/filter_weight/__init__.py new file mode 100644 index 00000000..2489e31d --- /dev/null +++ b/lightrft/trainer/filter_weight/__init__.py @@ -0,0 +1,167 @@ +""" +Filter and Weight Module + +Unified interface for sample filtering and loss weighting in RLHF. + +This module provides a three-layer architecture for managing sample filtering +and loss weighting: + +1. **Metrics Layer**: Compute sample-level metrics (entropy, difficulty, staleness, etc.) +2. **Filter Layer**: Filter samples based on metrics (keep/discard decisions) +3. **Weight Layer**: Compute per-sample loss weights based on metrics + +The FilterWeightManager provides a high-level API to orchestrate these components. 
+ +Example usage: + ```python + from lightrft.trainer.filter_weight import ( + FilterWeightManager, + ResponseLengthFilter, + DifficultyWeighting, + ) + + # Create manager + manager = FilterWeightManager( + filters=[ResponseLengthFilter(max_length=1024)], + weights=[(DifficultyWeighting(mode="prioritized"), 1.0)], + enable_metrics={"difficulty": True} + ) + + # Compute metrics + metrics = manager.compute_metrics(outputs) + + # Apply to experiences + experiences, weights = manager.apply_to_experiences(experiences, metrics) + ``` + +Author: LightRLHF Team +""" + +# Metrics +from .metrics import ( + SampleMetrics, + MetricsComputer, +) + +# Filters +from .filters import ( + SampleFilter, + ResponseLengthFilter, + RewardValueFilter, + EntropyFilter, + DifficultyFilter, + CompositeFilter, + PercentileFilter, +) + +# Weights +from .weights import ( + LossWeighting, + ResponseLengthWeighting, + EntropyWeighting, + DifficultyWeighting, + StalenessWeighting, + RewardMagnitudeWeighting, + CompositeWeighting, + UniformWeighting, +) + +# Manager +from .manager import ( + FilterWeightManager, + FilterWeightManagerBuilder, +) + + +__all__ = [ + # ========== Metrics ========== + "SampleMetrics", + "MetricsComputer", + + # ========== Filters ========== + "SampleFilter", + "ResponseLengthFilter", + "RewardValueFilter", + "EntropyFilter", + "DifficultyFilter", + "CompositeFilter", + "PercentileFilter", + + # ========== Weights ========== + "LossWeighting", + "ResponseLengthWeighting", + "EntropyWeighting", + "DifficultyWeighting", + "StalenessWeighting", + "RewardMagnitudeWeighting", + "CompositeWeighting", + "UniformWeighting", + + # ========== Manager ========== + "FilterWeightManager", + "FilterWeightManagerBuilder", +] + + +# Version info +__version__ = "1.0.0" +__author__ = "LightRLHF Team" + + +def get_version(): + """Get the version of the filter_weight module.""" + return __version__ + + +# Quick access functions for common use cases +def create_length_filter(max_length: int 
"""
Sample Filtering Module

Provides a unified interface for filtering samples based on various criteria.
Every filter returns a boolean keep-mask with one entry per sample
(True = keep the sample, False = filter it out).

Author: LightRLHF Team
"""

from abc import ABC, abstractmethod
from typing import List, Optional
import torch
import warnings


def _keep_all_mask(metrics, experiences) -> torch.Tensor:
    """
    Build an all-True mask for the case where a filter has nothing to decide on.

    The size is the number of sequences across ``experiences``. The device is
    taken from the first experience; when ``experiences`` is empty it falls back
    to the device of ``metrics.response_length`` and finally to CUDA only if
    CUDA is actually available (otherwise CPU), so the fallback never raises
    on CPU-only hosts.

    Args:
        metrics: SampleMetrics-like object (only ``response_length`` is probed)
        experiences: List of objects exposing a ``sequences`` tensor

    Returns:
        mask: BoolTensor (total_samples,) filled with True
    """
    if experiences:
        total_samples = sum(len(exp.sequences) for exp in experiences)
        device = experiences[0].sequences.device
        return torch.ones(total_samples, dtype=torch.bool, device=device)

    lengths = getattr(metrics, "response_length", None)
    if isinstance(lengths, torch.Tensor):
        device = lengths.device
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.ones(0, dtype=torch.bool, device=device)


def _quantile_ready(values: torch.Tensor) -> torch.Tensor:
    """
    Return ``values`` in a dtype accepted by ``torch.quantile``.

    ``torch.quantile`` only accepts float32/float64 input, so integer metrics
    (e.g. response lengths) and half-precision metrics are upcast to float32.
    """
    if values.dtype in (torch.float32, torch.float64):
        return values
    return values.float()


class SampleFilter(ABC):
    """
    Base class for sample filters.

    Filters determine which samples should be kept for training. They return
    a boolean mask where True indicates the sample should be kept.
    """

    @abstractmethod
    def filter(
        self,
        metrics,  # SampleMetrics
        experiences: List  # List[ExperienceVL]
    ) -> torch.Tensor:
        """
        Compute filter mask.

        Args:
            metrics: SampleMetrics containing computed metrics
            experiences: List of Experience/ExperienceVL objects

        Returns:
            mask: BoolTensor (total_samples,) where True = keep, False = filter out
        """


class ResponseLengthFilter(SampleFilter):
    """
    Filter samples based on response length.

    This filter can enforce minimum/maximum length constraints and/or a
    buffer-based upper bound (length <= expected_length + buffer_length).
    Note that the buffer only adds an upper bound; it does not reject
    responses shorter than ``expected_length - buffer_length``.
    """

    def __init__(
        self,
        min_length: Optional[int] = None,
        max_length: Optional[int] = None,
        expected_length: Optional[int] = None,
        buffer_length: Optional[int] = None
    ):
        """
        Initialize response length filter.

        Args:
            min_length: Minimum allowed response length (inclusive)
            max_length: Maximum allowed response length (inclusive)
            expected_length: Expected response length (for buffer-based filtering)
            buffer_length: Buffer above expected length (filters if length > expected + buffer)
        """
        self.min_length = min_length
        self.max_length = max_length
        self.expected_length = expected_length
        self.buffer_length = buffer_length

        # The buffer rule needs both values; tell the user it is inactive otherwise.
        if expected_length is not None and buffer_length is None:
            warnings.warn(
                "expected_length specified but buffer_length is None. "
                "Buffer-based filtering will not be applied."
            )

    def filter(self, metrics, experiences):
        """
        Filter based on length constraints.

        Args:
            metrics: SampleMetrics with response_length
            experiences: List of experiences (unused but kept for interface consistency)

        Returns:
            mask: BoolTensor indicating which samples to keep
        """
        lengths = metrics.response_length
        mask = torch.ones(len(lengths), dtype=torch.bool, device=lengths.device)

        if self.min_length is not None:
            mask = mask & (lengths >= self.min_length)

        if self.max_length is not None:
            mask = mask & (lengths <= self.max_length)

        if self.expected_length is not None and self.buffer_length is not None:
            max_allowed = self.expected_length + self.buffer_length
            mask = mask & (lengths <= max_allowed)

        return mask


class RewardValueFilter(SampleFilter):
    """
    Filter samples with degenerate reward values (DAPO dynamic sampling).

    This filter detects and removes groups of samples where all rewards are identical
    (all 0s or all 1s, within ``tolerance``), which provides no learning signal.
    Groups are formed from consecutive samples, ``group_size`` at a time.
    """

    def __init__(
        self,
        filter_all_zeros: bool = True,
        filter_all_ones: bool = True,
        n_samples_per_prompt: int = 1,
        group_size: Optional[int] = None,
        tolerance: float = 1e-6
    ):
        """
        Initialize reward value filter.

        Args:
            filter_all_zeros: Filter groups where all rewards are 0
            filter_all_ones: Filter groups where all rewards are 1
            n_samples_per_prompt: Number of samples per prompt (for grouping)
            group_size: Custom group size (if None, uses n_samples_per_prompt)
            tolerance: Tolerance for comparing reward values

        Raises:
            ValueError: If the effective group size is not a positive integer
        """
        self.filter_all_zeros = filter_all_zeros
        self.filter_all_ones = filter_all_ones
        self.group_size = group_size or n_samples_per_prompt
        self.tolerance = tolerance

        # A non-positive group size would otherwise surface later as a
        # ZeroDivisionError / reshape error in the middle of training.
        if not isinstance(self.group_size, int) or self.group_size < 1:
            raise ValueError(
                f"group_size must be a positive integer, got {self.group_size!r}"
            )

    def filter(self, metrics, experiences):
        """
        Filter groups with degenerate rewards.

        Args:
            metrics: SampleMetrics with reward_value
            experiences: List of experiences (only used to size the keep-all
                mask when no rewards are available)

        Returns:
            mask: BoolTensor indicating which samples to keep
        """
        rewards = metrics.reward_value
        if rewards is None:
            # No rewards available: keep everything, on the experiences' device
            # rather than a hard-coded CUDA device.
            return _keep_all_mask(metrics, experiences)

        keep_all = torch.ones(len(rewards), dtype=torch.bool, device=rewards.device)

        # Nothing to group (reshape(-1, g) is ill-defined for zero elements).
        if rewards.numel() == 0:
            return keep_all

        if len(rewards) % self.group_size != 0:
            warnings.warn(
                f"Number of samples ({len(rewards)}) not divisible by group_size ({self.group_size}). "
                f"Skipping reward value filtering."
            )
            return keep_all

        # One row per prompt group.
        grouped_rewards = rewards.reshape(-1, self.group_size)
        group_mask = torch.ones(len(grouped_rewards), dtype=torch.bool, device=rewards.device)

        if self.filter_all_zeros:
            all_zeros = torch.all(torch.abs(grouped_rewards) < self.tolerance, dim=1)
            group_mask = group_mask & ~all_zeros

        if self.filter_all_ones:
            all_ones = torch.all(torch.abs(grouped_rewards - 1.0) < self.tolerance, dim=1)
            group_mask = group_mask & ~all_ones

        # Expand the per-group decision back to one entry per sample.
        return group_mask.repeat_interleave(self.group_size)


class EntropyFilter(SampleFilter):
    """
    Filter samples based on policy entropy.

    Entropy measures the uncertainty/diversity of the policy. Low entropy indicates
    confident/deterministic generation, high entropy indicates uncertain/exploratory generation.
    """

    def __init__(
        self,
        min_entropy: Optional[float] = None,
        max_entropy: Optional[float] = None
    ):
        """
        Initialize entropy filter.

        Args:
            min_entropy: Minimum entropy threshold (filter out low entropy)
            max_entropy: Maximum entropy threshold (filter out high entropy)
        """
        self.min_entropy = min_entropy
        self.max_entropy = max_entropy

    def filter(self, metrics, experiences):
        """
        Filter based on entropy.

        Args:
            metrics: SampleMetrics with entropy
            experiences: List of experiences

        Returns:
            mask: BoolTensor indicating which samples to keep
        """
        entropy = metrics.entropy
        if entropy is None:
            # Entropy not computed, keep all samples
            return _keep_all_mask(metrics, experiences)

        mask = torch.ones(len(entropy), dtype=torch.bool, device=entropy.device)

        if self.min_entropy is not None:
            mask = mask & (entropy >= self.min_entropy)

        if self.max_entropy is not None:
            mask = mask & (entropy <= self.max_entropy)

        return mask


class DifficultyFilter(SampleFilter):
    """
    Filter samples based on difficulty scores.

    This can be used to implement curriculum learning (filter out hard samples early)
    or focus training (filter out easy samples).
    """

    def __init__(
        self,
        min_difficulty: Optional[float] = None,
        max_difficulty: Optional[float] = None,
        mode: str = "absolute"  # "absolute" or "percentile"
    ):
        """
        Initialize difficulty filter.

        Args:
            min_difficulty: Minimum difficulty threshold
            max_difficulty: Maximum difficulty threshold
            mode: Filtering mode
                - "absolute": Use absolute threshold values
                - "percentile": Interpret thresholds as percentiles (0-100)
                Any other value is treated as "absolute" (with a warning).
        """
        self.min_difficulty = min_difficulty
        self.max_difficulty = max_difficulty
        self.mode = mode

        # Unknown modes have always fallen through to absolute thresholds;
        # keep that behavior but make the likely typo visible.
        if mode not in ("absolute", "percentile"):
            warnings.warn(
                f"Unknown DifficultyFilter mode {mode!r}; falling back to 'absolute'."
            )

    def filter(self, metrics, experiences):
        """
        Filter based on difficulty.

        Args:
            metrics: SampleMetrics with difficulty
            experiences: List of experiences

        Returns:
            mask: BoolTensor indicating which samples to keep
        """
        difficulty = metrics.difficulty
        if difficulty is None:
            # Difficulty not computed, keep all samples
            return _keep_all_mask(metrics, experiences)

        mask = torch.ones(len(difficulty), dtype=torch.bool, device=difficulty.device)

        if self.mode == "percentile":
            # torch.quantile rejects empty input; an empty batch keeps nothing to drop.
            if difficulty.numel() == 0:
                return mask

            # torch.quantile needs float32/float64 input.
            scores = _quantile_ready(difficulty)

            if self.min_difficulty is not None:
                threshold = torch.quantile(scores, self.min_difficulty / 100.0)
                mask = mask & (scores >= threshold)

            if self.max_difficulty is not None:
                threshold = torch.quantile(scores, self.max_difficulty / 100.0)
                mask = mask & (scores <= threshold)

        else:  # absolute mode
            if self.min_difficulty is not None:
                mask = mask & (difficulty >= self.min_difficulty)

            if self.max_difficulty is not None:
                mask = mask & (difficulty <= self.max_difficulty)

        return mask


class CompositeFilter(SampleFilter):
    """
    Combine multiple filters with AND/OR logic.

    This allows building complex filtering logic by composing simple filters.
    """

    def __init__(self, filters: List[SampleFilter], logic: str = "AND"):
        """
        Initialize composite filter.

        Args:
            filters: List of filters to combine
            logic: Combination logic (case-insensitive)
                - "AND": All filters must pass (intersection)
                - "OR": Any filter must pass (union)

        Raises:
            ValueError: If ``logic`` is neither "AND" nor "OR"
        """
        self.filters = filters
        self.logic = logic.upper()

        if self.logic not in ["AND", "OR"]:
            raise ValueError(f"Invalid logic: {logic}. Must be 'AND' or 'OR'")

    def filter(self, metrics, experiences):
        """
        Combine filters according to logic.

        The masks returned by the child filters are never modified in place,
        so a child may safely return a tensor it keeps a reference to.

        Args:
            metrics: SampleMetrics
            experiences: List of experiences

        Returns:
            mask: Combined BoolTensor
        """
        if not self.filters:
            # No filters, keep all samples
            return _keep_all_mask(metrics, experiences)

        combined_mask = self.filters[0].filter(metrics, experiences)

        for child in self.filters[1:]:
            # Children may derive their device from different tensors
            # (metrics vs. experiences); align before combining.
            mask = child.filter(metrics, experiences).to(combined_mask.device)

            if self.logic == "AND":
                combined_mask = combined_mask & mask
            else:  # OR
                combined_mask = combined_mask | mask

        return combined_mask


class PercentileFilter(SampleFilter):
    """
    Filter samples based on percentile ranking of a metric.

    This is useful for keeping top-k% or bottom-k% samples according to some metric.
    When both percentiles are given, the union of the two tails is kept.
    """

    def __init__(
        self,
        metric_name: str,
        top_percentile: Optional[float] = None,
        bottom_percentile: Optional[float] = None
    ):
        """
        Initialize percentile filter.

        Args:
            metric_name: Name of metric in SampleMetrics (e.g., "entropy", "difficulty")
            top_percentile: Keep top X% (e.g., 20 = keep top 20%)
            bottom_percentile: Keep bottom X% (e.g., 20 = keep bottom 20%)

        Raises:
            ValueError: If neither percentile is given, or a percentile is
                outside the range [0, 100]
        """
        self.metric_name = metric_name
        self.top_percentile = top_percentile
        self.bottom_percentile = bottom_percentile

        if top_percentile is None and bottom_percentile is None:
            raise ValueError("At least one of top_percentile or bottom_percentile must be specified")

        # Out-of-range values would only fail later inside torch.quantile.
        for label, value in (("top_percentile", top_percentile),
                             ("bottom_percentile", bottom_percentile)):
            if value is not None and not 0 <= value <= 100:
                raise ValueError(f"{label} must be within [0, 100], got {value}")

    def filter(self, metrics, experiences):
        """
        Filter based on percentile ranking.

        Args:
            metrics: SampleMetrics
            experiences: List of experiences

        Returns:
            mask: BoolTensor indicating which samples to keep
        """
        metric_value = getattr(metrics, self.metric_name, None)

        if metric_value is None:
            # Metric not available, keep all samples
            return _keep_all_mask(metrics, experiences)

        mask = torch.zeros(len(metric_value), dtype=torch.bool, device=metric_value.device)

        # torch.quantile rejects empty input; there is nothing to rank anyway.
        if metric_value.numel() == 0:
            return mask

        # torch.quantile needs float32/float64 input (response_length is integral).
        values = _quantile_ready(metric_value)

        # Keep top percentile (highest values)
        if self.top_percentile is not None:
            threshold = torch.quantile(values, 1 - self.top_percentile / 100.0)
            mask = mask | (values >= threshold)

        # Keep bottom percentile (lowest values)
        if self.bottom_percentile is not None:
            threshold = torch.quantile(values, self.bottom_percentile / 100.0)
            mask = mask | (values <= threshold)

        return mask
Applying filters and weights to experiences + + Example usage: + ```python + manager = FilterWeightManager( + filters=[ + ResponseLengthFilter(max_length=1024), + RewardValueFilter(n_samples_per_prompt=4), + ], + weights=[ + (ResponseLengthWeighting(mode="inverse"), 0.5), + (DifficultyWeighting(mode="prioritized"), 0.5), + ], + enable_metrics={ + "entropy": True, + "difficulty": True, + "difficulty_mode": "td_error", + } + ) + + # In make_experience_list: + metrics = manager.compute_metrics(outputs) + experiences, weights = manager.apply_to_experiences(experiences, metrics) + ``` + + Args: + metrics_computer: Custom metrics computer (if None, creates default) + filters: List of filters to apply + weights: List of (weighting, coefficient) pairs + enable_metrics: Dict of metric names to enable + packing_samples: Whether samples are packed + """ + + def __init__( + self, + metrics_computer: Optional[MetricsComputer] = None, + filters: Optional[List[SampleFilter]] = None, + weights: Optional[List[Tuple[LossWeighting, float]]] = None, + enable_metrics: Optional[Dict[str, bool]] = None, + packing_samples: bool = False + ): + """ + Initialize filter-weight manager. 
+ + Args: + metrics_computer: Custom metrics computer (default: MetricsComputer()) + filters: List of filters to apply (default: []) + weights: List of (weighting, coefficient) tuples (default: []) + enable_metrics: Dict specifying which metrics to compute (default: {}) + packing_samples: Whether samples are packed (affects metric computation) + """ + self.metrics_computer = metrics_computer or MetricsComputer(packing_samples) + self.filters = filters or [] + self.weights = weights or [] + self.enable_metrics = enable_metrics or {} + self.packing_samples = packing_samples + + # Validate configuration + self._validate_config() + + def _validate_config(self): + """Validate configuration and emit warnings if needed.""" + # Check if any filter/weight requires metrics that are not enabled + required_metrics = set() + + # Map filter/weight types to required metrics + filter_metric_map = { + "EntropyFilter": "entropy", + "DifficultyFilter": "difficulty", + } + weight_metric_map = { + "EntropyWeighting": "entropy", + "DifficultyWeighting": "difficulty", + "StalenessWeighting": "staleness", + } + + # Check filters + for f in self.filters: + filter_type = type(f).__name__ + if filter_type in filter_metric_map: + metric_name = filter_metric_map[filter_type] + if not self.enable_metrics.get(metric_name, False): + warnings.warn( + f"{filter_type} requires '{metric_name}' metric but it is not enabled. " + f"Filter may not work correctly." + ) + required_metrics.add(metric_name) + + # Check weights + for w, _ in self.weights: + weight_type = type(w).__name__ + if weight_type in weight_metric_map: + metric_name = weight_metric_map[weight_type] + if not self.enable_metrics.get(metric_name, False): + warnings.warn( + f"{weight_type} requires '{metric_name}' metric but it is not enabled. " + f"Weighting may not work correctly." 
+ ) + required_metrics.add(metric_name) + + def compute_metrics( + self, + outputs: List, # List[_SamplesOutput] + current_step: Optional[int] = None + ) -> SampleMetrics: + """ + Compute all enabled metrics. + + Args: + outputs: List of sample outputs from model inference (_SamplesOutput objects) + current_step: Current training step (required for staleness computation) + + Returns: + SampleMetrics with computed metrics + """ + return self.metrics_computer.compute_all_metrics( + outputs, + self.enable_metrics, + current_step + ) + + def apply_filters( + self, + metrics: SampleMetrics, + experiences: List # List[ExperienceVL] + ) -> torch.Tensor: + """ + Apply all configured filters. + + Args: + metrics: Computed sample metrics + experiences: List of Experience/ExperienceVL objects + + Returns: + mask: BoolTensor (total_samples,) indicating which samples to keep + True = keep, False = filter out + """ + if not self.filters: + # No filters, keep all samples + total_samples = sum(len(exp.sequences) for exp in experiences) + device = experiences[0].sequences.device if experiences else 'cuda' + return torch.ones(total_samples, dtype=torch.bool, device=device) + + # Combine all filters with AND logic + composite = CompositeFilter(self.filters, logic="AND") + return composite.filter(metrics, experiences) + + def compute_weights( + self, + metrics: SampleMetrics, + experiences: List # List[ExperienceVL] + ) -> torch.Tensor: + """ + Compute combined sample weights. 
+ + Args: + metrics: Computed sample metrics + experiences: List of Experience/ExperienceVL objects + + Returns: + weights: FloatTensor (total_samples,) with loss weights + """ + if not self.weights: + # No weighting, return uniform weights + total_samples = sum(len(exp.sequences) for exp in experiences) + device = experiences[0].sequences.device if experiences else 'cuda' + return torch.ones(total_samples, device=device) + + # Combine all weightings + composite = CompositeWeighting(self.weights, mode="weighted_sum") + return composite.compute_weights(metrics, experiences) + + def apply_to_experiences( + self, + experiences: List, # List[ExperienceVL] + metrics: SampleMetrics, + apply_filter_to_mask: bool = True, + apply_filter_to_weights: bool = True + ) -> Tuple[List, torch.Tensor]: + """ + Apply filtering and weighting to experiences. + + This method: + 1. Computes filter mask + 2. Updates action_mask to exclude filtered samples (if apply_filter_to_mask=True) + 3. Computes loss weights + 4. 
Zeros out weights for filtered samples (if apply_filter_to_weights=True) + + Args: + experiences: List of Experience/ExperienceVL objects to process + metrics: Computed sample metrics + apply_filter_to_mask: If True, update action_mask to exclude filtered samples + apply_filter_to_weights: If True, zero out weights for filtered samples + + Returns: + (experiences, weights): Modified experiences and per-sample weights + """ + # Apply filters + keep_mask = self.apply_filters(metrics, experiences) + + # Update action masks if requested + if apply_filter_to_mask: + sample_idx = 0 + for exp in experiences: + batch_size = len(exp.sequences) + batch_mask = keep_mask[sample_idx : sample_idx + batch_size] + + # Zero out action_mask for filtered samples + # This effectively removes them from loss computation + if exp.action_mask is not None: + exp.action_mask = exp.action_mask & batch_mask.unsqueeze(-1).to(exp.action_mask.device) + + sample_idx += batch_size + + # Compute weights + weights = self.compute_weights(metrics, experiences) + + # Zero out weights for filtered samples if requested + if apply_filter_to_weights: + weights = weights * keep_mask.float() + + return experiences, weights + + def get_filter_stats( + self, + metrics: SampleMetrics, + experiences: List # List[ExperienceVL] + ) -> Dict[str, float]: + """ + Get statistics about filtering. 
+ + Args: + metrics: Computed sample metrics + experiences: List of experiences + + Returns: + Dict with statistics: + - "total_samples": Total number of samples + - "filtered_samples": Number of filtered samples + - "filter_rate": Fraction of samples filtered + - "kept_samples": Number of kept samples + """ + total_samples = sum(len(exp.sequences) for exp in experiences) + keep_mask = self.apply_filters(metrics, experiences) + kept_samples = keep_mask.sum().item() + filtered_samples = total_samples - kept_samples + filter_rate = filtered_samples / total_samples if total_samples > 0 else 0.0 + + return { + "total_samples": total_samples, + "filtered_samples": filtered_samples, + "kept_samples": kept_samples, + "filter_rate": filter_rate, + } + + def get_weight_stats( + self, + metrics: SampleMetrics, + experiences: List # List[ExperienceVL] + ) -> Dict[str, float]: + """ + Get statistics about weighting. + + Args: + metrics: Computed sample metrics + experiences: List of experiences + + Returns: + Dict with statistics: + - "weight_mean": Mean weight + - "weight_std": Standard deviation of weights + - "weight_min": Minimum weight + - "weight_max": Maximum weight + """ + weights = self.compute_weights(metrics, experiences) + + return { + "weight_mean": weights.mean().item(), + "weight_std": weights.std().item(), + "weight_min": weights.min().item(), + "weight_max": weights.max().item(), + } + + def log_stats( + self, + metrics: SampleMetrics, + experiences: List, # List[ExperienceVL] + logger=None + ): + """ + Log filtering and weighting statistics. 
+ + Args: + metrics: Computed sample metrics + experiences: List of experiences + logger: Logger object (if None, uses print) + """ + filter_stats = self.get_filter_stats(metrics, experiences) + weight_stats = self.get_weight_stats(metrics, experiences) + + log_fn = logger.info if logger else print + + log_fn("=" * 60) + log_fn("Filter & Weight Statistics") + log_fn("=" * 60) + + # Filter stats + log_fn(f"Total samples: {filter_stats['total_samples']}") + log_fn(f"Kept samples: {filter_stats['kept_samples']}") + log_fn(f"Filtered samples: {filter_stats['filtered_samples']}") + log_fn(f"Filter rate: {filter_stats['filter_rate']:.2%}") + + # Weight stats + log_fn(f"Weight mean: {weight_stats['weight_mean']:.4f}") + log_fn(f"Weight std: {weight_stats['weight_std']:.4f}") + log_fn(f"Weight range: [{weight_stats['weight_min']:.4f}, {weight_stats['weight_max']:.4f}]") + + log_fn("=" * 60) + + +class FilterWeightManagerBuilder: + """ + Builder pattern for constructing FilterWeightManager from args. + + This provides a convenient way to construct a FilterWeightManager from + training arguments. + + Example: + ```python + builder = FilterWeightManagerBuilder() + manager = builder.from_args(args) + ``` + """ + + @staticmethod + def from_args(args, packing_samples: bool = False) -> FilterWeightManager: + """ + Build FilterWeightManager from training arguments. 
+ + Args: + args: Training arguments object + packing_samples: Whether samples are packed + + Returns: + FilterWeightManager configured according to args + """ + from .filters import ResponseLengthFilter, RewardValueFilter, EntropyFilter + from .weights import ( + ResponseLengthWeighting, + EntropyWeighting, + DifficultyWeighting, + StalenessWeighting, + ) + + # Build filters + filters = [] + + # Response length filter (from overlong_buffer settings) + if getattr(args, "overlong_buffer", False): + expected_len = getattr(args, "max_new_tokens", 1024) - getattr(args, "overlong_buffer_len", 0) + buffer_len = getattr(args, "overlong_buffer_len", 0) + filters.append(ResponseLengthFilter( + expected_length=expected_len, + buffer_length=buffer_len + )) + + # Reward value filter (for dynamic sampling) + if getattr(args, "dynamic_sampling", False) and getattr(args, "advantage_estimator", "") == "group_norm": + filters.append(RewardValueFilter( + n_samples_per_prompt=getattr(args, "n_samples_per_prompt", 1) + )) + + # Entropy filter + if getattr(args, "enable_entropy_filter", False): + filters.append(EntropyFilter( + min_entropy=getattr(args, "min_entropy", None), + max_entropy=getattr(args, "max_entropy", None) + )) + + # Build weights + weights = [] + + # Response length weighting + if getattr(args, "enable_length_weighting", False): + weight = ResponseLengthWeighting( + mode=getattr(args, "length_weight_mode", "inverse"), + normalize=True + ) + coef = getattr(args, "length_weight_coef", 1.0) + weights.append((weight, coef)) + + # Entropy weighting + if getattr(args, "enable_entropy_weighting", False): + weight = EntropyWeighting( + mode=getattr(args, "entropy_weight_mode", "favor_high"), + temperature=getattr(args, "entropy_weight_temperature", 1.0), + normalize=True + ) + coef = getattr(args, "entropy_weight_coef", 1.0) + weights.append((weight, coef)) + + # Difficulty weighting + if getattr(args, "enable_difficulty_weighting", False): + weight = DifficultyWeighting( + 
mode=getattr(args, "difficulty_weight_mode", "prioritized"), + alpha=getattr(args, "difficulty_alpha", 0.6), + normalize=True + ) + coef = getattr(args, "difficulty_weight_coef", 1.0) + weights.append((weight, coef)) + + # Staleness weighting + if getattr(args, "enable_staleness_weighting", False): + weight = StalenessWeighting( + decay_factor=getattr(args, "staleness_decay_factor", 0.95), + normalize=True + ) + coef = getattr(args, "staleness_weight_coef", 1.0) + weights.append((weight, coef)) + + # Build enable_metrics dict + enable_metrics = { + "entropy": getattr(args, "compute_entropy", False) or getattr(args, "enable_entropy_filter", False) or getattr(args, "enable_entropy_weighting", False), + "difficulty": getattr(args, "enable_difficulty_weighting", False), + "difficulty_mode": getattr(args, "difficulty_mode", "td_error"), + "staleness": getattr(args, "enable_staleness_weighting", False), + "staleness_mode": getattr(args, "staleness_mode", "linear"), + } + + return FilterWeightManager( + filters=filters, + weights=weights, + enable_metrics=enable_metrics, + packing_samples=packing_samples + ) + diff --git a/lightrft/trainer/filter_weight/metrics.py b/lightrft/trainer/filter_weight/metrics.py new file mode 100644 index 00000000..346bbcef --- /dev/null +++ b/lightrft/trainer/filter_weight/metrics.py @@ -0,0 +1,332 @@ +""" +Metrics Computation Module + +Provides unified interface for computing various sample-level metrics +used in filtering and weighting. + +Author: LightRLHF Team +""" + +from dataclasses import dataclass +from typing import Optional, Dict, List +import torch +from lightrft.models.utils import masked_mean, unpacking_samples + + +@dataclass +class SampleMetrics: + """ + Container for all computed sample metrics. + + All metrics are per-sample tensors aligned with the sample indices. + Shapes are (total_samples,) where total_samples = sum(batch_size across all micro-batches). 
+ + Attributes: + response_length: Length of generated responses (total_samples,) + entropy: Policy entropy -sum(p * log(p)) (total_samples,) + logit_kl: KL divergence at logit level (total_samples,) + difficulty: Sample difficulty score (total_samples,) + staleness: Sample age/staleness (total_samples,) + reward_value: Reward values (total_samples,) + n_samples_per_prompt: Number of samples per prompt (for grouping) + micro_batch_size: Micro batch size (for splitting) + """ + # Core metrics (always available) + response_length: torch.Tensor # (total_samples,) + + # Optional metrics + entropy: Optional[torch.Tensor] = None # (total_samples,) + logit_kl: Optional[torch.Tensor] = None # (total_samples,) + difficulty: Optional[torch.Tensor] = None # (total_samples,) + staleness: Optional[torch.Tensor] = None # (total_samples,) + reward_value: Optional[torch.Tensor] = None # (total_samples,) + + # Auxiliary data for filtering/weighting + n_samples_per_prompt: Optional[int] = None + micro_batch_size: Optional[int] = None + + +class MetricsComputer: + """ + Compute various metrics for experience samples. + + This class provides methods to compute sample-level metrics that can be used + for filtering and weighting during experience generation. + + Args: + packing_samples: Whether samples are packed (affects unpacking logic) + """ + + def __init__(self, packing_samples: bool = False): + """ + Initialize metrics computer. + + Args: + packing_samples: Whether samples are packed into single sequences + """ + self.packing_samples = packing_samples + + def compute_entropy( + self, + action_log_probs: torch.Tensor, + action_mask: torch.Tensor, + num_actions: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Compute policy entropy per sample. + + Entropy = -sum(p * log(p)) where p = exp(log_prob) + Higher entropy indicates more uncertain/exploratory policy. 
+ + Args: + action_log_probs: Log probabilities (batch_size, seq_len) or (1, total_len) for packed + action_mask: Action mask (batch_size, seq_len) or (1, total_len) for packed + num_actions: Number of actions per sample (for packed samples) + + Returns: + entropy: Per-sample entropy (batch_size,) + """ + # Convert log probs to probs + probs = torch.exp(action_log_probs) + + # Entropy = -sum(p * log(p)) = -sum(p * log_p) + entropy_per_token = -(probs * action_log_probs) + + # Handle packed vs unpacked samples + if self.packing_samples and num_actions is not None: + # Unpack and compute mean for each sample + entropy_unpacked = unpacking_samples(entropy_per_token, num_actions) + return torch.tensor([ent.mean() for ent in entropy_unpacked], device=action_log_probs.device) + else: + # Mask and average over sequence + return masked_mean(entropy_per_token, action_mask, dim=-1) + + def compute_logit_kl( + self, + current_logits: torch.Tensor, + reference_logits: torch.Tensor, + action_mask: torch.Tensor, + num_actions: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """ + Compute KL divergence between logit distributions. + + This differs from log-prob KL in that it operates on the full + distribution rather than just the selected actions. 
+ + KL(current || reference) = sum(p_curr * (log p_curr - log p_ref)) + + Args: + current_logits: Logits from current policy (batch, seq, vocab) + reference_logits: Logits from reference policy (batch, seq, vocab) + action_mask: Action mask (batch, seq) + num_actions: Number of actions per sample (for packed samples) + + Returns: + kl: Per-sample KL divergence (batch_size,) + """ + # Convert to log probabilities + current_log_probs = torch.log_softmax(current_logits, dim=-1) + reference_log_probs = torch.log_softmax(reference_logits, dim=-1) + + # KL(current || reference) = sum(p_curr * (log p_curr - log p_ref)) + kl_per_token = torch.sum( + torch.exp(current_log_probs) * (current_log_probs - reference_log_probs), + dim=-1 + ) + + # Handle packed vs unpacked samples + if self.packing_samples and num_actions is not None: + kl_unpacked = unpacking_samples(kl_per_token, num_actions) + return torch.tensor([kl.mean() for kl in kl_unpacked], device=kl_per_token.device) + else: + return masked_mean(kl_per_token, action_mask, dim=-1) + + def compute_difficulty( + self, + rewards: torch.Tensor, + values: Optional[torch.Tensor] = None, + mode: str = "td_error" + ) -> torch.Tensor: + """ + Compute sample difficulty. 
+ + Difficulty can be measured in various ways: + - "td_error": Temporal difference error |reward - value| + - "low_reward": Inverse of reward (harder = lower reward) + - "high_variance": Variance of reward within group (placeholder) + - "abs_reward": Absolute reward magnitude + + Args: + rewards: Per-sample rewards (total_samples,) + values: Per-sample value estimates (total_samples,) - required for "td_error" + mode: Difficulty computation mode + + Returns: + difficulty: Per-sample difficulty scores (total_samples,) + """ + if mode == "td_error": + if values is None: + raise ValueError("Values required for td_error difficulty mode") + # Higher TD error = more surprising = harder + return torch.abs(rewards - values) + + elif mode == "low_reward": + # Lower reward = harder (normalized to [0, 1]) + min_r, max_r = rewards.min(), rewards.max() + if max_r - min_r < 1e-6: + return torch.ones_like(rewards) * 0.5 + return 1 - (rewards - min_r) / (max_r - min_r) + + elif mode == "abs_reward": + # Absolute reward magnitude + return torch.abs(rewards) + + elif mode == "high_variance": + # Compute variance within groups (placeholder - needs grouping info) + # This would require n_samples_per_prompt and group-wise computation + # For now, return zeros as placeholder + return torch.zeros_like(rewards) + + else: + raise ValueError(f"Unknown difficulty mode: {mode}") + + def compute_staleness( + self, + generation_steps: torch.Tensor, + current_step: int, + mode: str = "linear" + ) -> torch.Tensor: + """ + Compute sample staleness based on age. + + Staleness measures how old a sample is relative to current training step. + Older samples may be less relevant due to policy shift. 
+ + Args: + generation_steps: Step when each sample was generated (total_samples,) + current_step: Current training step + mode: Staleness computation mode + - "linear": age normalized to [0, 1] + - "exponential": 1 - exp(-age / tau) + + Returns: + staleness: Per-sample staleness scores (total_samples,) + """ + age = current_step - generation_steps + + if mode == "linear": + # Normalize to [0, 1] range + max_age = age.max() + if max_age < 1: + return torch.zeros_like(age, dtype=torch.float32) + return age.float() / max_age + + elif mode == "exponential": + # Exponential decay: 1 - exp(-age / tau) + tau = 10.0 # Half-life parameter + return 1 - torch.exp(-age.float() / tau) + + else: + raise ValueError(f"Unknown staleness mode: {mode}") + + def compute_all_metrics( + self, + outputs: List, # List[_SamplesOutput] + enable_flags: Dict[str, bool], + current_step: Optional[int] = None + ) -> SampleMetrics: + """ + Compute all enabled metrics from sample outputs. + + This is the main entry point for computing metrics. It checks enable_flags + to determine which metrics to compute and returns a SampleMetrics object. 
+ + Args: + outputs: List of _SamplesOutput objects from model inference + enable_flags: Dict indicating which metrics to compute + - "entropy": bool + - "logit_kl": bool + - "difficulty": bool + - "difficulty_mode": str + - "staleness": bool + - "staleness_mode": str + current_step: Current training step (required for staleness) + + Returns: + SampleMetrics with all enabled metrics computed + """ + # Collect core metrics (always available) + response_lengths = torch.cat([out.response_length for out in outputs]) + + # Get rewards if available + rewards = None + if outputs[0].rewards is not None: + rewards = torch.cat([out.rewards for out in outputs]) + + # Initialize optional metrics + entropy = None + logit_kl = None + difficulty = None + staleness = None + + # Compute entropy if enabled + if enable_flags.get("entropy", False): + if outputs[0].action_log_probs is not None: + action_log_probs_list = [] + action_mask_list = [] + num_actions_list = [] + + for out in outputs: + action_log_probs_list.append(out.action_log_probs) + action_mask_list.append(out.action_mask) + if self.packing_samples: + num_actions_list.extend(out.num_actions) + + # Concatenate + action_log_probs = torch.cat(action_log_probs_list, dim=0) + action_mask = torch.cat(action_mask_list, dim=0) + + if self.packing_samples: + entropy = self.compute_entropy( + action_log_probs, action_mask, num_actions_list + ) + else: + entropy = self.compute_entropy(action_log_probs, action_mask) + + # Compute logit KL if enabled (requires storing logits, not typically available) + # This is a placeholder for future implementation + if enable_flags.get("logit_kl", False): + # Would need access to current_logits and reference_logits + # For now, leave as None + pass + + # Compute difficulty if enabled + if enable_flags.get("difficulty", False) and rewards is not None: + difficulty_mode = enable_flags.get("difficulty_mode", "td_error") + + # Get values if needed + values = None + if difficulty_mode == "td_error" 
and outputs[0].value is not None: + values = torch.cat([out.value for out in outputs]) + + difficulty = self.compute_difficulty(rewards, values, mode=difficulty_mode) + + # Compute staleness if enabled + if enable_flags.get("staleness", False) and current_step is not None: + # Would need to track generation_steps in outputs + # For now, this is a placeholder + # In practice, you'd need to add generation_step to _SamplesOutput + pass + + return SampleMetrics( + response_length=response_lengths, + entropy=entropy, + logit_kl=logit_kl, + difficulty=difficulty, + staleness=staleness, + reward_value=rewards, + n_samples_per_prompt=None, # Can be set externally + micro_batch_size=None, # Can be set externally + ) + diff --git a/lightrft/trainer/filter_weight/weights.py b/lightrft/trainer/filter_weight/weights.py new file mode 100644 index 00000000..a17666ab --- /dev/null +++ b/lightrft/trainer/filter_weight/weights.py @@ -0,0 +1,525 @@ +""" +Loss Weighting Module + +Provides unified interface for computing sample-level loss weights. + +Author: LightRLHF Team +""" + +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple +import torch +import warnings + + +class LossWeighting(ABC): + """ + Base class for loss weighting. + + Weightings compute per-sample weights that modulate the contribution + of each sample to the loss function. + """ + + @abstractmethod + def compute_weights( + self, + metrics, # SampleMetrics + experiences: List # List[ExperienceVL] + ) -> torch.Tensor: + """ + Compute per-sample weights. + + Args: + metrics: SampleMetrics containing computed metrics + experiences: List of Experience/ExperienceVL objects + + Returns: + weights: FloatTensor (total_samples,) with loss weights + """ + pass + + +class ResponseLengthWeighting(LossWeighting): + """ + Weight samples by response length. 
+ + This can be used to: + - Give more weight to longer responses (mode="linear") + - Give more weight to shorter responses (mode="inverse") + - Balance weights by length (mode="sqrt", "log") + """ + + def __init__( + self, + mode: str = "linear", + normalize: bool = True, + clip_min: Optional[float] = None, + clip_max: Optional[float] = None, + epsilon: float = 1e-6 + ): + """ + Initialize response length weighting. + + Args: + mode: Weighting mode + - "linear": weight = length + - "inverse": weight = 1/length + - "sqrt": weight = sqrt(length) + - "log": weight = log(1 + length) + normalize: Whether to normalize weights to mean=1 + clip_min: Minimum weight value + clip_max: Maximum weight value + epsilon: Small constant to avoid division by zero + """ + self.mode = mode + self.normalize = normalize + self.clip_min = clip_min + self.clip_max = clip_max + self.epsilon = epsilon + + valid_modes = ["linear", "inverse", "sqrt", "log"] + if mode not in valid_modes: + raise ValueError(f"Invalid mode: {mode}. Must be one of {valid_modes}") + + def compute_weights(self, metrics, experiences): + """ + Compute length-based weights. 
+ + Args: + metrics: SampleMetrics with response_length + experiences: List of experiences (unused) + + Returns: + weights: FloatTensor of per-sample weights + """ + lengths = metrics.response_length.float() + + # Compute weights according to mode + if self.mode == "linear": + weights = lengths + elif self.mode == "inverse": + weights = 1.0 / (lengths + self.epsilon) + elif self.mode == "sqrt": + weights = torch.sqrt(lengths + self.epsilon) + elif self.mode == "log": + weights = torch.log(1.0 + lengths) + else: + raise ValueError(f"Unknown mode: {self.mode}") + + # Clip if specified + if self.clip_min is not None: + weights = torch.clamp(weights, min=self.clip_min) + if self.clip_max is not None: + weights = torch.clamp(weights, max=self.clip_max) + + # Normalize to mean=1 + if self.normalize: + weights = weights / (weights.mean() + self.epsilon) + + return weights + + +class EntropyWeighting(LossWeighting): + """ + Weight samples by policy entropy. + + This can encourage exploration (favor high entropy) or exploitation (favor low entropy). + """ + + def __init__( + self, + mode: str = "favor_high", + temperature: float = 1.0, + normalize: bool = True, + epsilon: float = 1e-6 + ): + """ + Initialize entropy weighting. + + Args: + mode: Weighting mode + - "favor_high": Higher entropy → higher weight (encourage exploration) + - "favor_low": Lower entropy → higher weight (encourage exploitation) + - "linear": weight = entropy (linear scaling) + - "inverse": weight = 1/entropy (inverse scaling) + temperature: Temperature for softmax weighting (used in favor_high/favor_low modes) + normalize: Normalize to mean=1 + epsilon: Small constant to avoid numerical issues + """ + self.mode = mode + self.temperature = temperature + self.normalize = normalize + self.epsilon = epsilon + + valid_modes = ["favor_high", "favor_low", "linear", "inverse"] + if mode not in valid_modes: + raise ValueError(f"Invalid mode: {mode}. 
Must be one of {valid_modes}") + + def compute_weights(self, metrics, experiences): + """ + Compute entropy-based weights. + + Args: + metrics: SampleMetrics with entropy + experiences: List of experiences + + Returns: + weights: FloatTensor of per-sample weights + """ + if metrics.entropy is None: + # Entropy not available, return uniform weights + total_samples = sum(len(exp.sequences) for exp in experiences) + device = experiences[0].sequences.device if experiences else 'cuda' + return torch.ones(total_samples, device=device) + + entropy = metrics.entropy + + # Compute weights according to mode + if self.mode == "favor_high": + # Softmax over entropy (higher entropy → higher weight) + weights = torch.softmax(entropy / self.temperature, dim=0) * len(entropy) + elif self.mode == "favor_low": + # Inverse softmax (lower entropy → higher weight) + weights = torch.softmax(-entropy / self.temperature, dim=0) * len(entropy) + elif self.mode == "linear": + weights = entropy + elif self.mode == "inverse": + weights = 1.0 / (entropy + self.epsilon) + else: + raise ValueError(f"Unknown mode: {self.mode}") + + # Normalize to mean=1 + if self.normalize: + weights = weights / (weights.mean() + self.epsilon) + + return weights + + +class DifficultyWeighting(LossWeighting): + """ + Weight samples by difficulty. + + This implements prioritized experience replay (PER) style weighting or + curriculum learning approaches. + """ + + def __init__( + self, + mode: str = "prioritized", + alpha: float = 0.6, + normalize: bool = True, + epsilon: float = 1e-6 + ): + """ + Initialize difficulty weighting. 
+ + Args: + mode: Weighting mode + - "prioritized": Prioritized experience replay (difficulty^alpha) + - "curriculum": Curriculum learning (favor easier samples) + - "linear": weight = difficulty + - "inverse": weight = 1/difficulty + alpha: Exponent for prioritization (typical range: 0.4-0.8) + normalize: Normalize to mean=1 + epsilon: Small constant to avoid numerical issues + """ + self.mode = mode + self.alpha = alpha + self.normalize = normalize + self.epsilon = epsilon + + valid_modes = ["prioritized", "curriculum", "linear", "inverse"] + if mode not in valid_modes: + raise ValueError(f"Invalid mode: {mode}. Must be one of {valid_modes}") + + def compute_weights(self, metrics, experiences): + """ + Compute difficulty-based weights. + + Args: + metrics: SampleMetrics with difficulty + experiences: List of experiences + + Returns: + weights: FloatTensor of per-sample weights + """ + if metrics.difficulty is None: + # Difficulty not available, return uniform weights + total_samples = sum(len(exp.sequences) for exp in experiences) + device = experiences[0].sequences.device if experiences else 'cuda' + return torch.ones(total_samples, device=device) + + difficulty = metrics.difficulty + + # Compute weights according to mode + if self.mode == "prioritized": + # Higher difficulty → higher weight (PER-style) + weights = torch.pow(difficulty + self.epsilon, self.alpha) + elif self.mode == "curriculum": + # Lower difficulty → higher weight (curriculum learning) + # Use inverse with alpha exponent + weights = torch.pow(1.0 / (difficulty + self.epsilon), self.alpha) + elif self.mode == "linear": + weights = difficulty + elif self.mode == "inverse": + weights = 1.0 / (difficulty + self.epsilon) + else: + raise ValueError(f"Unknown mode: {self.mode}") + + # Normalize to mean=1 + if self.normalize: + weights = weights / (weights.mean() + self.epsilon) + + return weights + + +class StalenessWeighting(LossWeighting): + """ + Weight samples by staleness (age). 
+ + Older samples may be less relevant due to policy shift, so they + receive exponentially decaying weights. + """ + + def __init__( + self, + decay_factor: float = 0.95, + normalize: bool = True, + epsilon: float = 1e-6 + ): + """ + Initialize staleness weighting. + + Args: + decay_factor: Exponential decay factor (0 < decay_factor < 1) + - decay_factor close to 1: slow decay + - decay_factor close to 0: fast decay + normalize: Normalize to mean=1 + epsilon: Small constant + """ + self.decay_factor = decay_factor + self.normalize = normalize + self.epsilon = epsilon + + if not 0 < decay_factor <= 1: + raise ValueError(f"decay_factor must be in (0, 1], got {decay_factor}") + + def compute_weights(self, metrics, experiences): + """ + Compute staleness-based weights. + + Args: + metrics: SampleMetrics with staleness + experiences: List of experiences + + Returns: + weights: FloatTensor of per-sample weights + """ + if metrics.staleness is None: + # Staleness not available, return uniform weights + total_samples = sum(len(exp.sequences) for exp in experiences) + device = experiences[0].sequences.device if experiences else 'cuda' + return torch.ones(total_samples, device=device) + + staleness = metrics.staleness + + # Exponential decay: weight = decay_factor^staleness + weights = torch.pow(self.decay_factor, staleness) + + # Normalize to mean=1 + if self.normalize: + weights = weights / (weights.mean() + self.epsilon) + + return weights + + +class RewardMagnitudeWeighting(LossWeighting): + """ + Weight samples by reward magnitude. + + This can be used to focus on high-reward or low-reward samples. + """ + + def __init__( + self, + mode: str = "favor_high", + temperature: float = 1.0, + normalize: bool = True, + epsilon: float = 1e-6 + ): + """ + Initialize reward magnitude weighting. 
+ + Args: + mode: Weighting mode + - "favor_high": Higher reward → higher weight + - "favor_low": Lower reward → higher weight + - "absolute": weight = |reward| + temperature: Temperature for softmax (used in favor_high/favor_low) + normalize: Normalize to mean=1 + epsilon: Small constant + """ + self.mode = mode + self.temperature = temperature + self.normalize = normalize + self.epsilon = epsilon + + def compute_weights(self, metrics, experiences): + """ + Compute reward-based weights. + + Args: + metrics: SampleMetrics with reward_value + experiences: List of experiences + + Returns: + weights: FloatTensor of per-sample weights + """ + if metrics.reward_value is None: + # Reward not available, return uniform weights + total_samples = sum(len(exp.sequences) for exp in experiences) + device = experiences[0].sequences.device if experiences else 'cuda' + return torch.ones(total_samples, device=device) + + rewards = metrics.reward_value + + # Compute weights according to mode + if self.mode == "favor_high": + weights = torch.softmax(rewards / self.temperature, dim=0) * len(rewards) + elif self.mode == "favor_low": + weights = torch.softmax(-rewards / self.temperature, dim=0) * len(rewards) + elif self.mode == "absolute": + weights = torch.abs(rewards) + else: + raise ValueError(f"Unknown mode: {self.mode}") + + # Normalize to mean=1 + if self.normalize: + weights = weights / (weights.mean() + self.epsilon) + + return weights + + +class CompositeWeighting(LossWeighting): + """ + Combine multiple weighting schemes. + + This allows building complex weighting strategies by composing simple weightings. + """ + + def __init__( + self, + weightings: List[Tuple[LossWeighting, float]], + mode: str = "product", + normalize: bool = True, + epsilon: float = 1e-6 + ): + """ + Initialize composite weighting. 
+ + Args: + weightings: List of (weighting, coefficient) pairs + mode: Combination mode + - "product": Multiply all weights + - "sum": Sum all weights + - "weighted_sum": Weighted sum using coefficients + - "weighted_product": Product of (weight^coefficient) + normalize: Normalize final weights to mean=1 + epsilon: Small constant + """ + self.weightings = weightings + self.mode = mode + self.normalize = normalize + self.epsilon = epsilon + + valid_modes = ["product", "sum", "weighted_sum", "weighted_product"] + if mode not in valid_modes: + raise ValueError(f"Invalid mode: {mode}. Must be one of {valid_modes}") + + if not weightings: + warnings.warn("CompositeWeighting initialized with empty weightings list") + + def compute_weights(self, metrics, experiences): + """ + Combine multiple weights. + + Args: + metrics: SampleMetrics + experiences: List of experiences + + Returns: + weights: Combined FloatTensor of per-sample weights + """ + if not self.weightings: + # No weightings, return uniform weights + total_samples = sum(len(exp.sequences) for exp in experiences) + device = experiences[0].sequences.device if experiences else 'cuda' + return torch.ones(total_samples, device=device) + + # Compute first weight + first_weighting, first_coef = self.weightings[0] + combined = first_weighting.compute_weights(metrics, experiences) + + # Combine with rest + if self.mode == "product": + # Multiply all weights + for weighting, _ in self.weightings[1:]: + w = weighting.compute_weights(metrics, experiences) + combined = combined * w + + elif self.mode == "sum": + # Sum all weights + for weighting, _ in self.weightings[1:]: + w = weighting.compute_weights(metrics, experiences) + combined = combined + w + + elif self.mode == "weighted_sum": + # Weighted sum using coefficients + combined = combined * first_coef + for weighting, coef in self.weightings[1:]: + w = weighting.compute_weights(metrics, experiences) + combined = combined + coef * w + + elif self.mode == 
"weighted_product": + # Product of (weight^coefficient) + combined = torch.pow(combined + self.epsilon, first_coef) + for weighting, coef in self.weightings[1:]: + w = weighting.compute_weights(metrics, experiences) + combined = combined * torch.pow(w + self.epsilon, coef) + + else: + raise ValueError(f"Unknown mode: {self.mode}") + + # Normalize to mean=1 + if self.normalize: + combined = combined / (combined.mean() + self.epsilon) + + return combined + + +class UniformWeighting(LossWeighting): + """ + Uniform weighting (all weights = 1). + + This is a no-op weighting for baseline comparisons. + """ + + def __init__(self): + """Initialize uniform weighting.""" + pass + + def compute_weights(self, metrics, experiences): + """ + Return uniform weights. + + Args: + metrics: SampleMetrics (unused) + experiences: List of experiences + + Returns: + weights: FloatTensor of ones + """ + total_samples = sum(len(exp.sequences) for exp in experiences) + device = experiences[0].sequences.device if experiences else 'cuda' + return torch.ones(total_samples, device=device) + From 11e81ac498417dfa193b761170e1c4aa231b9e0c Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Fri, 2 Jan 2026 00:59:10 +0800 Subject: [PATCH 2/8] refactor(sunjx): refactor loss-filter implementation --- lightrft/trainer/fast_exp_maker.py | 63 ++++++++++++++++++++-- lightrft/trainer/filter_weight/__init__.py | 2 + lightrft/trainer/filter_weight/filters.py | 2 + lightrft/trainer/filter_weight/manager.py | 2 + lightrft/trainer/filter_weight/metrics.py | 2 + lightrft/trainer/filter_weight/weights.py | 2 + 6 files changed, 68 insertions(+), 5 deletions(-) diff --git a/lightrft/trainer/fast_exp_maker.py b/lightrft/trainer/fast_exp_maker.py index 93ecd89f..05ee19a1 100644 --- a/lightrft/trainer/fast_exp_maker.py +++ b/lightrft/trainer/fast_exp_maker.py @@ -53,6 +53,7 @@ from lightrft.utils.remote_rm_utils import remote_rm_fn from lightrft.utils import Timer, get_current_device from .utils import RunningMoments, 
compute_clip_fraction, get_cpgd_advantages_returns, fire_sampling +from .filter_weight import FilterWeightManagerBuilder # ============================================================================ # Data Structures @@ -951,6 +952,19 @@ def __init__(self, *args, packing_samples: bool = False, processor=None, **kwarg packing_samples=self.packing_samples, ) + # Initialize filter-weight manager + self.filter_weight_manager = self._init_filter_weight_manager() + + def _init_filter_weight_manager(self): + """ + Initialize filter-weight manager from strategy args. + + :return: FilterWeightManager instance + :rtype: FilterWeightManager + """ + args = self.strategy.args + return FilterWeightManagerBuilder.from_args(args, packing_samples=self.packing_samples) + # ======================================================================== # Public API Methods # ======================================================================== @@ -1020,12 +1034,32 @@ def make_experience_list( # ========== Stage 3: Model Inference ========== Timer.start(' make_experience') - experiences = self._make_experience_list_by_model(all_samples) + experiences, outputs = self._make_experience_list_by_model(all_samples) Timer.stop(' make_experience') # ========== Stage 4: Shard-Parallel Postprocessing ========== experiences = self.strategy.sp_data_processor.postprocess(experiences) + # ========== Stage 4.5: Apply Filter-Weight Framework ========== + if self.filter_weight_manager is not None and ( + self.filter_weight_manager.filters or self.filter_weight_manager.weights + ): + # Compute metrics from outputs + current_step = getattr(self.strategy, "global_step", None) + metrics = self.filter_weight_manager.compute_metrics(outputs, current_step=current_step) + + # Apply filters and weights + experiences, sample_weights = self.filter_weight_manager.apply_to_experiences( + experiences, metrics + ) + + # Store sample weights in experience info for later use + sample_idx = 0 + for exp in experiences: + 
batch_size = len(exp.sequences) + exp.info["sample_weights"] = sample_weights[sample_idx:sample_idx + batch_size] + sample_idx += batch_size + # ========== Stage 5: Reward Processing ========== experiences, rewards = self._process_experiences( # GRPO's -mean / std operation is performed in this method experiences, generate_kwargs.get("max_new_tokens", 1024) @@ -1360,7 +1394,16 @@ def _process_experiences( rewards = torch.cat([exp.info["reward"] for exp in experiences]) # ========== Overlong Sequence Penalty ========== - if config.overlong_buffer: + # Use new filter_weight framework if enabled, otherwise use legacy logic + from .filter_weight import ResponseLengthFilter + use_filter_weight = ( + self.filter_weight_manager is not None + and self.filter_weight_manager.filters + and any(isinstance(f, ResponseLengthFilter) for f in self.filter_weight_manager.filters) + ) + + if config.overlong_buffer and not use_filter_weight: + # Legacy overlong buffer penalty (only if not using filter_weight framework) expected_len = max_new_tokens - config.overlong_buffer_len actual_lens = torch.cat([exp.action_mask.sum(dim=1) for exp in experiences]) exceed_len = actual_lens - expected_len @@ -1393,7 +1436,16 @@ def _process_experiences( elif config.advantage_estimator in ["group_norm", "grpo"]: # Group normalization with optional dynamic filtering - if config.dynamic_sampling: + # Use new filter_weight framework if enabled, otherwise use legacy logic + from .filter_weight import RewardValueFilter + use_dynamic_filter = ( + self.filter_weight_manager is not None + and self.filter_weight_manager.filters + and any(isinstance(f, RewardValueFilter) for f in self.filter_weight_manager.filters) + ) + + if config.dynamic_sampling and not use_dynamic_filter: + # Legacy dynamic sampling (only if not using filter_weight framework) step_size = config.n_samples_per_prompt // config.micro_train_batch_size for i in range(0, len(experiences), step_size): chunk = experiences[i:i + step_size] @@ 
-1545,7 +1597,7 @@ def _compute_advantages_and_returns( def _make_experience_list_by_model( self, all_samples: List[Union[Samples, SamplesVL]], - ) -> List[Union[Experience, ExperienceVL]]: + ) -> Tuple[List[Union[Experience, ExperienceVL]], List[_SamplesOutput]]: """ Batch forward pass through all models to create experiences. @@ -1607,7 +1659,8 @@ def _make_experience_list_by_model( self.reward_engine.compute_rewards(outputs, vlm_mode, device) # ========== Stage 5: Assemble Experiences ========== - return [self._pack_experience(output, vlm_mode) for output in outputs] + experiences = [self._pack_experience(output, vlm_mode) for output in outputs] + return experiences, outputs def _preprocess_sample( self, diff --git a/lightrft/trainer/filter_weight/__init__.py b/lightrft/trainer/filter_weight/__init__.py index 2489e31d..ee547cfe 100644 --- a/lightrft/trainer/filter_weight/__init__.py +++ b/lightrft/trainer/filter_weight/__init__.py @@ -165,3 +165,5 @@ def create_manager_from_args(args, packing_samples: bool = False): "create_manager_from_args", ]) + + diff --git a/lightrft/trainer/filter_weight/filters.py b/lightrft/trainer/filter_weight/filters.py index a7df4944..50f72334 100644 --- a/lightrft/trainer/filter_weight/filters.py +++ b/lightrft/trainer/filter_weight/filters.py @@ -422,3 +422,5 @@ def filter(self, metrics, experiences): return mask + + diff --git a/lightrft/trainer/filter_weight/manager.py b/lightrft/trainer/filter_weight/manager.py index ae0435ce..fe5dc873 100644 --- a/lightrft/trainer/filter_weight/manager.py +++ b/lightrft/trainer/filter_weight/manager.py @@ -459,3 +459,5 @@ def from_args(args, packing_samples: bool = False) -> FilterWeightManager: packing_samples=packing_samples ) + + diff --git a/lightrft/trainer/filter_weight/metrics.py b/lightrft/trainer/filter_weight/metrics.py index 346bbcef..53a08f1d 100644 --- a/lightrft/trainer/filter_weight/metrics.py +++ b/lightrft/trainer/filter_weight/metrics.py @@ -330,3 +330,5 @@ def 
compute_all_metrics( micro_batch_size=None, # Can be set externally ) + + diff --git a/lightrft/trainer/filter_weight/weights.py b/lightrft/trainer/filter_weight/weights.py index a17666ab..df3cee90 100644 --- a/lightrft/trainer/filter_weight/weights.py +++ b/lightrft/trainer/filter_weight/weights.py @@ -523,3 +523,5 @@ def compute_weights(self, metrics, experiences): device = experiences[0].sequences.device if experiences else 'cuda' return torch.ones(total_samples, device=device) + + From 008c90a1db269c84acb8ee7f3d882eda12e0f5b1 Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Thu, 8 Jan 2026 15:56:01 +0800 Subject: [PATCH 3/8] refactor(sunjx): Unify the comment style --- lightrft/trainer/filter_weight/__init__.py | 53 +++--- lightrft/trainer/filter_weight/filters.py | 153 ++++++++-------- lightrft/trainer/filter_weight/manager.py | 143 +++++++-------- lightrft/trainer/filter_weight/metrics.py | 96 +++++----- lightrft/trainer/filter_weight/weights.py | 197 ++++++++++----------- 5 files changed, 320 insertions(+), 322 deletions(-) diff --git a/lightrft/trainer/filter_weight/__init__.py b/lightrft/trainer/filter_weight/__init__.py index ee547cfe..e10d883d 100644 --- a/lightrft/trainer/filter_weight/__init__.py +++ b/lightrft/trainer/filter_weight/__init__.py @@ -33,8 +33,6 @@ # Apply to experiences experiences, weights = manager.apply_to_experiences(experiences, metrics) ``` - -Author: LightRLHF Team """ # Metrics @@ -103,13 +101,13 @@ ] -# Version info -__version__ = "1.0.0" -__author__ = "LightRLHF Team" - - def get_version(): - """Get the version of the filter_weight module.""" + """ + Get the version of the filter_weight module. + + :return: Version string + :rtype: str + """ return __version__ @@ -118,12 +116,12 @@ def create_length_filter(max_length: int = 1024, **kwargs): """ Quick function to create a response length filter. 
- Args: - max_length: Maximum response length - **kwargs: Additional arguments for ResponseLengthFilter - - Returns: - ResponseLengthFilter instance + :param max_length: Maximum response length + :type max_length: int + :param kwargs: Additional arguments for ResponseLengthFilter + :type kwargs: dict + :return: ResponseLengthFilter instance + :rtype: ResponseLengthFilter """ return ResponseLengthFilter(max_length=max_length, **kwargs) @@ -132,13 +130,14 @@ def create_difficulty_weighting(mode: str = "prioritized", alpha: float = 0.6, * """ Quick function to create difficulty weighting. - Args: - mode: Weighting mode ("prioritized" or "curriculum") - alpha: Prioritization exponent - **kwargs: Additional arguments for DifficultyWeighting - - Returns: - DifficultyWeighting instance + :param mode: Weighting mode ("prioritized" or "curriculum") + :type mode: str + :param alpha: Prioritization exponent + :type alpha: float + :param kwargs: Additional arguments for DifficultyWeighting + :type kwargs: dict + :return: DifficultyWeighting instance + :rtype: DifficultyWeighting """ return DifficultyWeighting(mode=mode, alpha=alpha, **kwargs) @@ -147,12 +146,12 @@ def create_manager_from_args(args, packing_samples: bool = False): """ Quick function to create FilterWeightManager from training arguments. 
- Args: - args: Training arguments - packing_samples: Whether samples are packed - - Returns: - FilterWeightManager instance + :param args: Training arguments + :type args: Any + :param packing_samples: Whether samples are packed + :type packing_samples: bool + :return: FilterWeightManager instance + :rtype: FilterWeightManager """ return FilterWeightManagerBuilder.from_args(args, packing_samples) diff --git a/lightrft/trainer/filter_weight/filters.py b/lightrft/trainer/filter_weight/filters.py index 50f72334..d36767d2 100644 --- a/lightrft/trainer/filter_weight/filters.py +++ b/lightrft/trainer/filter_weight/filters.py @@ -2,8 +2,6 @@ Sample Filtering Module Provides unified interface for filtering samples based on various criteria. - -Author: LightRLHF Team """ from abc import ABC, abstractmethod @@ -29,12 +27,12 @@ def filter( """ Compute filter mask. - Args: - metrics: SampleMetrics containing computed metrics - experiences: List of Experience/ExperienceVL objects - - Returns: - mask: BoolTensor (total_samples,) where True = keep, False = filter out + :param metrics: SampleMetrics containing computed metrics + :type metrics: SampleMetrics + :param experiences: List of Experience/ExperienceVL objects + :type experiences: List + :return: BoolTensor (total_samples,) where True = keep, False = filter out + :rtype: torch.Tensor """ pass @@ -57,11 +55,14 @@ def __init__( """ Initialize response length filter. 
- Args: - min_length: Minimum allowed response length (inclusive) - max_length: Maximum allowed response length (inclusive) - expected_length: Expected response length (for buffer-based filtering) - buffer_length: Buffer around expected length (filters if length > expected + buffer) + :param min_length: Minimum allowed response length (inclusive) + :type min_length: Optional[int] + :param max_length: Maximum allowed response length (inclusive) + :type max_length: Optional[int] + :param expected_length: Expected response length (for buffer-based filtering) + :type expected_length: Optional[int] + :param buffer_length: Buffer around expected length (filters if length > expected + buffer) + :type buffer_length: Optional[int] """ self.min_length = min_length self.max_length = max_length @@ -79,12 +80,12 @@ def filter(self, metrics, experiences): """ Filter based on length constraints. - Args: - metrics: SampleMetrics with response_length - experiences: List of experiences (unused but kept for interface consistency) - - Returns: - mask: BoolTensor indicating which samples to keep + :param metrics: SampleMetrics with response_length + :type metrics: SampleMetrics + :param experiences: List of experiences (unused but kept for interface consistency) + :type experiences: List + :return: BoolTensor indicating which samples to keep + :rtype: torch.Tensor """ lengths = metrics.response_length mask = torch.ones(len(lengths), dtype=torch.bool, device=lengths.device) @@ -124,12 +125,16 @@ def __init__( """ Initialize reward value filter. 
- Args: - filter_all_zeros: Filter groups where all rewards are 0 - filter_all_ones: Filter groups where all rewards are 1 - n_samples_per_prompt: Number of samples per prompt (for grouping) - group_size: Custom group size (if None, uses n_samples_per_prompt) - tolerance: Tolerance for comparing reward values + :param filter_all_zeros: Filter groups where all rewards are 0 + :type filter_all_zeros: bool + :param filter_all_ones: Filter groups where all rewards are 1 + :type filter_all_ones: bool + :param n_samples_per_prompt: Number of samples per prompt (for grouping) + :type n_samples_per_prompt: int + :param group_size: Custom group size (if None, uses n_samples_per_prompt) + :type group_size: Optional[int] + :param tolerance: Tolerance for comparing reward values + :type tolerance: float """ self.filter_all_zeros = filter_all_zeros self.filter_all_ones = filter_all_ones @@ -140,12 +145,12 @@ def filter(self, metrics, experiences): """ Filter groups with degenerate rewards. - Args: - metrics: SampleMetrics with reward_value - experiences: List of experiences (unused) - - Returns: - mask: BoolTensor indicating which samples to keep + :param metrics: SampleMetrics with reward_value + :type metrics: SampleMetrics + :param experiences: List of experiences (unused) + :type experiences: List + :return: BoolTensor indicating which samples to keep + :rtype: torch.Tensor """ if metrics.reward_value is None: # No rewards available, keep all samples @@ -199,9 +204,10 @@ def __init__( """ Initialize entropy filter. 
- Args: - min_entropy: Minimum entropy threshold (filter out low entropy) - max_entropy: Maximum entropy threshold (filter out high entropy) + :param min_entropy: Minimum entropy threshold (filter out low entropy) + :type min_entropy: Optional[float] + :param max_entropy: Maximum entropy threshold (filter out high entropy) + :type max_entropy: Optional[float] """ self.min_entropy = min_entropy self.max_entropy = max_entropy @@ -210,12 +216,12 @@ def filter(self, metrics, experiences): """ Filter based on entropy. - Args: - metrics: SampleMetrics with entropy - experiences: List of experiences - - Returns: - mask: BoolTensor indicating which samples to keep + :param metrics: SampleMetrics with entropy + :type metrics: SampleMetrics + :param experiences: List of experiences + :type experiences: List + :return: BoolTensor indicating which samples to keep + :rtype: torch.Tensor """ if metrics.entropy is None: # Entropy not computed, keep all samples @@ -254,12 +260,12 @@ def __init__( """ Initialize difficulty filter. - Args: - min_difficulty: Minimum difficulty threshold - max_difficulty: Maximum difficulty threshold - mode: Filtering mode - - "absolute": Use absolute threshold values - - "percentile": Interpret thresholds as percentiles (0-100) + :param min_difficulty: Minimum difficulty threshold + :type min_difficulty: Optional[float] + :param max_difficulty: Maximum difficulty threshold + :type max_difficulty: Optional[float] + :param mode: Filtering mode ("absolute" or "percentile") + :type mode: str """ self.min_difficulty = min_difficulty self.max_difficulty = max_difficulty @@ -269,12 +275,12 @@ def filter(self, metrics, experiences): """ Filter based on difficulty. 
- Args: - metrics: SampleMetrics with difficulty - experiences: List of experiences - - Returns: - mask: BoolTensor indicating which samples to keep + :param metrics: SampleMetrics with difficulty + :type metrics: SampleMetrics + :param experiences: List of experiences + :type experiences: List + :return: BoolTensor indicating which samples to keep + :rtype: torch.Tensor """ if metrics.difficulty is None: # Difficulty not computed, keep all samples @@ -316,11 +322,10 @@ def __init__(self, filters: List[SampleFilter], logic: str = "AND"): """ Initialize composite filter. - Args: - filters: List of filters to combine - logic: Combination logic - - "AND": All filters must pass (intersection) - - "OR": Any filter must pass (union) + :param filters: List of filters to combine + :type filters: List[SampleFilter] + :param logic: Combination logic ("AND" or "OR") + :type logic: str """ self.filters = filters self.logic = logic.upper() @@ -332,12 +337,12 @@ def filter(self, metrics, experiences): """ Combine filters according to logic. - Args: - metrics: SampleMetrics - experiences: List of experiences - - Returns: - mask: Combined BoolTensor + :param metrics: SampleMetrics + :type metrics: SampleMetrics + :param experiences: List of experiences + :type experiences: List + :return: Combined BoolTensor + :rtype: torch.Tensor """ if not self.filters: # No filters, keep all samples @@ -376,10 +381,12 @@ def __init__( """ Initialize percentile filter. 
- Args: - metric_name: Name of metric in SampleMetrics (e.g., "entropy", "difficulty") - top_percentile: Keep top X% (e.g., 20 = keep top 20%) - bottom_percentile: Keep bottom X% (e.g., 20 = keep bottom 20%) + :param metric_name: Name of metric in SampleMetrics (e.g., "entropy", "difficulty") + :type metric_name: str + :param top_percentile: Keep top X% (e.g., 20 = keep top 20%) + :type top_percentile: Optional[float] + :param bottom_percentile: Keep bottom X% (e.g., 20 = keep bottom 20%) + :type bottom_percentile: Optional[float] """ self.metric_name = metric_name self.top_percentile = top_percentile @@ -392,12 +399,12 @@ def filter(self, metrics, experiences): """ Filter based on percentile ranking. - Args: - metrics: SampleMetrics - experiences: List of experiences - - Returns: - mask: BoolTensor indicating which samples to keep + :param metrics: SampleMetrics + :type metrics: SampleMetrics + :param experiences: List of experiences + :type experiences: List + :return: BoolTensor indicating which samples to keep + :rtype: torch.Tensor """ # Get metric value metric_value = getattr(metrics, self.metric_name, None) diff --git a/lightrft/trainer/filter_weight/manager.py b/lightrft/trainer/filter_weight/manager.py index fe5dc873..285d9568 100644 --- a/lightrft/trainer/filter_weight/manager.py +++ b/lightrft/trainer/filter_weight/manager.py @@ -2,8 +2,6 @@ Unified Filter-Weight Manager Provides high-level API for managing sample filtering and loss weighting. 
- -Author: LightRLHF Team """ from typing import List, Optional, Dict, Tuple @@ -48,12 +46,16 @@ class FilterWeightManager: experiences, weights = manager.apply_to_experiences(experiences, metrics) ``` - Args: - metrics_computer: Custom metrics computer (if None, creates default) - filters: List of filters to apply - weights: List of (weighting, coefficient) pairs - enable_metrics: Dict of metric names to enable - packing_samples: Whether samples are packed + :param metrics_computer: Custom metrics computer (if None, creates default) + :type metrics_computer: Optional[MetricsComputer] + :param filters: List of filters to apply + :type filters: Optional[List[SampleFilter]] + :param weights: List of (weighting, coefficient) pairs + :type weights: Optional[List[Tuple[LossWeighting, float]]] + :param enable_metrics: Dict of metric names to enable + :type enable_metrics: Optional[Dict[str, bool]] + :param packing_samples: Whether samples are packed + :type packing_samples: bool """ def __init__( @@ -67,12 +69,16 @@ def __init__( """ Initialize filter-weight manager. 
- Args: - metrics_computer: Custom metrics computer (default: MetricsComputer()) - filters: List of filters to apply (default: []) - weights: List of (weighting, coefficient) tuples (default: []) - enable_metrics: Dict specifying which metrics to compute (default: {}) - packing_samples: Whether samples are packed (affects metric computation) + :param metrics_computer: Custom metrics computer (default: MetricsComputer()) + :type metrics_computer: Optional[MetricsComputer] + :param filters: List of filters to apply (default: []) + :type filters: Optional[List[SampleFilter]] + :param weights: List of (weighting, coefficient) tuples (default: []) + :type weights: Optional[List[Tuple[LossWeighting, float]]] + :param enable_metrics: Dict specifying which metrics to compute (default: {}) + :type enable_metrics: Optional[Dict[str, bool]] + :param packing_samples: Whether samples are packed (affects metric computation) + :type packing_samples: bool """ self.metrics_computer = metrics_computer or MetricsComputer(packing_samples) self.filters = filters or [] @@ -131,12 +137,12 @@ def compute_metrics( """ Compute all enabled metrics. - Args: - outputs: List of sample outputs from model inference (_SamplesOutput objects) - current_step: Current training step (required for staleness computation) - - Returns: - SampleMetrics with computed metrics + :param outputs: List of sample outputs from model inference (_SamplesOutput objects) + :type outputs: List + :param current_step: Current training step (required for staleness computation) + :type current_step: Optional[int] + :return: SampleMetrics with computed metrics + :rtype: SampleMetrics """ return self.metrics_computer.compute_all_metrics( outputs, @@ -152,13 +158,12 @@ def apply_filters( """ Apply all configured filters. 
- Args: - metrics: Computed sample metrics - experiences: List of Experience/ExperienceVL objects - - Returns: - mask: BoolTensor (total_samples,) indicating which samples to keep - True = keep, False = filter out + :param metrics: Computed sample metrics + :type metrics: SampleMetrics + :param experiences: List of Experience/ExperienceVL objects + :type experiences: List + :return: BoolTensor (total_samples,) indicating which samples to keep (True = keep, False = filter out) + :rtype: torch.Tensor """ if not self.filters: # No filters, keep all samples @@ -178,12 +183,12 @@ def compute_weights( """ Compute combined sample weights. - Args: - metrics: Computed sample metrics - experiences: List of Experience/ExperienceVL objects - - Returns: - weights: FloatTensor (total_samples,) with loss weights + :param metrics: Computed sample metrics + :type metrics: SampleMetrics + :param experiences: List of Experience/ExperienceVL objects + :type experiences: List + :return: FloatTensor (total_samples,) with loss weights + :rtype: torch.Tensor """ if not self.weights: # No weighting, return uniform weights @@ -211,14 +216,16 @@ def apply_to_experiences( 3. Computes loss weights 4. 
Zeros out weights for filtered samples (if apply_filter_to_weights=True) - Args: - experiences: List of Experience/ExperienceVL objects to process - metrics: Computed sample metrics - apply_filter_to_mask: If True, update action_mask to exclude filtered samples - apply_filter_to_weights: If True, zero out weights for filtered samples - - Returns: - (experiences, weights): Modified experiences and per-sample weights + :param experiences: List of Experience/ExperienceVL objects to process + :type experiences: List + :param metrics: Computed sample metrics + :type metrics: SampleMetrics + :param apply_filter_to_mask: If True, update action_mask to exclude filtered samples + :type apply_filter_to_mask: bool + :param apply_filter_to_weights: If True, zero out weights for filtered samples + :type apply_filter_to_weights: bool + :return: Modified experiences and per-sample weights + :rtype: Tuple[List, torch.Tensor] """ # Apply filters keep_mask = self.apply_filters(metrics, experiences) @@ -254,16 +261,12 @@ def get_filter_stats( """ Get statistics about filtering. - Args: - metrics: Computed sample metrics - experiences: List of experiences - - Returns: - Dict with statistics: - - "total_samples": Total number of samples - - "filtered_samples": Number of filtered samples - - "filter_rate": Fraction of samples filtered - - "kept_samples": Number of kept samples + :param metrics: Computed sample metrics + :type metrics: SampleMetrics + :param experiences: List of experiences + :type experiences: List + :return: Dict with statistics (total_samples, filtered_samples, filter_rate, kept_samples) + :rtype: Dict[str, float] """ total_samples = sum(len(exp.sequences) for exp in experiences) keep_mask = self.apply_filters(metrics, experiences) @@ -286,16 +289,12 @@ def get_weight_stats( """ Get statistics about weighting. 
- Args: - metrics: Computed sample metrics - experiences: List of experiences - - Returns: - Dict with statistics: - - "weight_mean": Mean weight - - "weight_std": Standard deviation of weights - - "weight_min": Minimum weight - - "weight_max": Maximum weight + :param metrics: Computed sample metrics + :type metrics: SampleMetrics + :param experiences: List of experiences + :type experiences: List + :return: Dict with statistics (weight_mean, weight_std, weight_min, weight_max) + :rtype: Dict[str, float] """ weights = self.compute_weights(metrics, experiences) @@ -315,10 +314,12 @@ def log_stats( """ Log filtering and weighting statistics. - Args: - metrics: Computed sample metrics - experiences: List of experiences - logger: Logger object (if None, uses print) + :param metrics: Computed sample metrics + :type metrics: SampleMetrics + :param experiences: List of experiences + :type experiences: List + :param logger: Logger object (if None, uses print) + :type logger: Optional[Any] """ filter_stats = self.get_filter_stats(metrics, experiences) weight_stats = self.get_weight_stats(metrics, experiences) @@ -362,12 +363,12 @@ def from_args(args, packing_samples: bool = False) -> FilterWeightManager: """ Build FilterWeightManager from training arguments. 
- Args: - args: Training arguments object - packing_samples: Whether samples are packed - - Returns: - FilterWeightManager configured according to args + :param args: Training arguments object + :type args: Any + :param packing_samples: Whether samples are packed + :type packing_samples: bool + :return: FilterWeightManager configured according to args + :rtype: FilterWeightManager """ from .filters import ResponseLengthFilter, RewardValueFilter, EntropyFilter from .weights import ( diff --git a/lightrft/trainer/filter_weight/metrics.py b/lightrft/trainer/filter_weight/metrics.py index 53a08f1d..d345a9e8 100644 --- a/lightrft/trainer/filter_weight/metrics.py +++ b/lightrft/trainer/filter_weight/metrics.py @@ -3,8 +3,6 @@ Provides unified interface for computing various sample-level metrics used in filtering and weighting. - -Author: LightRLHF Team """ from dataclasses import dataclass @@ -53,16 +51,16 @@ class MetricsComputer: This class provides methods to compute sample-level metrics that can be used for filtering and weighting during experience generation. - Args: - packing_samples: Whether samples are packed (affects unpacking logic) + :param packing_samples: Whether samples are packed (affects unpacking logic) + :type packing_samples: bool """ def __init__(self, packing_samples: bool = False): """ Initialize metrics computer. - Args: - packing_samples: Whether samples are packed into single sequences + :param packing_samples: Whether samples are packed into single sequences + :type packing_samples: bool """ self.packing_samples = packing_samples @@ -78,13 +76,14 @@ def compute_entropy( Entropy = -sum(p * log(p)) where p = exp(log_prob) Higher entropy indicates more uncertain/exploratory policy. 
- Args: - action_log_probs: Log probabilities (batch_size, seq_len) or (1, total_len) for packed - action_mask: Action mask (batch_size, seq_len) or (1, total_len) for packed - num_actions: Number of actions per sample (for packed samples) - - Returns: - entropy: Per-sample entropy (batch_size,) + :param action_log_probs: Log probabilities (batch_size, seq_len) or (1, total_len) for packed + :type action_log_probs: torch.Tensor + :param action_mask: Action mask (batch_size, seq_len) or (1, total_len) for packed + :type action_mask: torch.Tensor + :param num_actions: Number of actions per sample (for packed samples) + :type num_actions: Optional[torch.Tensor] + :return: Per-sample entropy (batch_size,) + :rtype: torch.Tensor """ # Convert log probs to probs probs = torch.exp(action_log_probs) @@ -116,14 +115,16 @@ def compute_logit_kl( KL(current || reference) = sum(p_curr * (log p_curr - log p_ref)) - Args: - current_logits: Logits from current policy (batch, seq, vocab) - reference_logits: Logits from reference policy (batch, seq, vocab) - action_mask: Action mask (batch, seq) - num_actions: Number of actions per sample (for packed samples) - - Returns: - kl: Per-sample KL divergence (batch_size,) + :param current_logits: Logits from current policy (batch, seq, vocab) + :type current_logits: torch.Tensor + :param reference_logits: Logits from reference policy (batch, seq, vocab) + :type reference_logits: torch.Tensor + :param action_mask: Action mask (batch, seq) + :type action_mask: torch.Tensor + :param num_actions: Number of actions per sample (for packed samples) + :type num_actions: Optional[torch.Tensor] + :return: Per-sample KL divergence (batch_size,) + :rtype: torch.Tensor """ # Convert to log probabilities current_log_probs = torch.log_softmax(current_logits, dim=-1) @@ -157,13 +158,14 @@ def compute_difficulty( - "high_variance": Variance of reward within group (placeholder) - "abs_reward": Absolute reward magnitude - Args: - rewards: Per-sample rewards 
(total_samples,) - values: Per-sample value estimates (total_samples,) - required for "td_error" - mode: Difficulty computation mode - - Returns: - difficulty: Per-sample difficulty scores (total_samples,) + :param rewards: Per-sample rewards (total_samples,) + :type rewards: torch.Tensor + :param values: Per-sample value estimates (total_samples,) - required for "td_error" + :type values: Optional[torch.Tensor] + :param mode: Difficulty computation mode + :type mode: str + :return: Per-sample difficulty scores (total_samples,) + :rtype: torch.Tensor """ if mode == "td_error": if values is None: @@ -203,15 +205,14 @@ def compute_staleness( Staleness measures how old a sample is relative to current training step. Older samples may be less relevant due to policy shift. - Args: - generation_steps: Step when each sample was generated (total_samples,) - current_step: Current training step - mode: Staleness computation mode - - "linear": age normalized to [0, 1] - - "exponential": 1 - exp(-age / tau) - - Returns: - staleness: Per-sample staleness scores (total_samples,) + :param generation_steps: Step when each sample was generated (total_samples,) + :type generation_steps: torch.Tensor + :param current_step: Current training step + :type current_step: int + :param mode: Staleness computation mode ("linear" or "exponential") + :type mode: str + :return: Per-sample staleness scores (total_samples,) + :rtype: torch.Tensor """ age = current_step - generation_steps @@ -242,19 +243,14 @@ def compute_all_metrics( This is the main entry point for computing metrics. It checks enable_flags to determine which metrics to compute and returns a SampleMetrics object. 
- Args: - outputs: List of _SamplesOutput objects from model inference - enable_flags: Dict indicating which metrics to compute - - "entropy": bool - - "logit_kl": bool - - "difficulty": bool - - "difficulty_mode": str - - "staleness": bool - - "staleness_mode": str - current_step: Current training step (required for staleness) - - Returns: - SampleMetrics with all enabled metrics computed + :param outputs: List of _SamplesOutput objects from model inference + :type outputs: List + :param enable_flags: Dict indicating which metrics to compute + :type enable_flags: Dict[str, bool] + :param current_step: Current training step (required for staleness) + :type current_step: Optional[int] + :return: SampleMetrics with all enabled metrics computed + :rtype: SampleMetrics """ # Collect core metrics (always available) response_lengths = torch.cat([out.response_length for out in outputs]) diff --git a/lightrft/trainer/filter_weight/weights.py b/lightrft/trainer/filter_weight/weights.py index df3cee90..0efef420 100644 --- a/lightrft/trainer/filter_weight/weights.py +++ b/lightrft/trainer/filter_weight/weights.py @@ -2,8 +2,6 @@ Loss Weighting Module Provides unified interface for computing sample-level loss weights. - -Author: LightRLHF Team """ from abc import ABC, abstractmethod @@ -29,12 +27,12 @@ def compute_weights( """ Compute per-sample weights. - Args: - metrics: SampleMetrics containing computed metrics - experiences: List of Experience/ExperienceVL objects - - Returns: - weights: FloatTensor (total_samples,) with loss weights + :param metrics: SampleMetrics containing computed metrics + :type metrics: SampleMetrics + :param experiences: List of Experience/ExperienceVL objects + :type experiences: List + :return: FloatTensor (total_samples,) with loss weights + :rtype: torch.Tensor """ pass @@ -60,16 +58,16 @@ def __init__( """ Initialize response length weighting. 
- Args: - mode: Weighting mode - - "linear": weight = length - - "inverse": weight = 1/length - - "sqrt": weight = sqrt(length) - - "log": weight = log(1 + length) - normalize: Whether to normalize weights to mean=1 - clip_min: Minimum weight value - clip_max: Maximum weight value - epsilon: Small constant to avoid division by zero + :param mode: Weighting mode ("linear", "inverse", "sqrt", or "log") + :type mode: str + :param normalize: Whether to normalize weights to mean=1 + :type normalize: bool + :param clip_min: Minimum weight value + :type clip_min: Optional[float] + :param clip_max: Maximum weight value + :type clip_max: Optional[float] + :param epsilon: Small constant to avoid division by zero + :type epsilon: float """ self.mode = mode self.normalize = normalize @@ -85,12 +83,12 @@ def compute_weights(self, metrics, experiences): """ Compute length-based weights. - Args: - metrics: SampleMetrics with response_length - experiences: List of experiences (unused) - - Returns: - weights: FloatTensor of per-sample weights + :param metrics: SampleMetrics with response_length + :type metrics: SampleMetrics + :param experiences: List of experiences (unused) + :type experiences: List + :return: FloatTensor of per-sample weights + :rtype: torch.Tensor """ lengths = metrics.response_length.float() @@ -136,15 +134,14 @@ def __init__( """ Initialize entropy weighting. 
- Args: - mode: Weighting mode - - "favor_high": Higher entropy → higher weight (encourage exploration) - - "favor_low": Lower entropy → higher weight (encourage exploitation) - - "linear": weight = entropy (linear scaling) - - "inverse": weight = 1/entropy (inverse scaling) - temperature: Temperature for softmax weighting (used in favor_high/favor_low modes) - normalize: Normalize to mean=1 - epsilon: Small constant to avoid numerical issues + :param mode: Weighting mode ("favor_high", "favor_low", "linear", or "inverse") + :type mode: str + :param temperature: Temperature for softmax weighting (used in favor_high/favor_low modes) + :type temperature: float + :param normalize: Normalize to mean=1 + :type normalize: bool + :param epsilon: Small constant to avoid numerical issues + :type epsilon: float """ self.mode = mode self.temperature = temperature @@ -159,12 +156,12 @@ def compute_weights(self, metrics, experiences): """ Compute entropy-based weights. - Args: - metrics: SampleMetrics with entropy - experiences: List of experiences - - Returns: - weights: FloatTensor of per-sample weights + :param metrics: SampleMetrics with entropy + :type metrics: SampleMetrics + :param experiences: List of experiences + :type experiences: List + :return: FloatTensor of per-sample weights + :rtype: torch.Tensor """ if metrics.entropy is None: # Entropy not available, return uniform weights @@ -213,15 +210,14 @@ def __init__( """ Initialize difficulty weighting. 
- Args: - mode: Weighting mode - - "prioritized": Prioritized experience replay (difficulty^alpha) - - "curriculum": Curriculum learning (favor easier samples) - - "linear": weight = difficulty - - "inverse": weight = 1/difficulty - alpha: Exponent for prioritization (typical range: 0.4-0.8) - normalize: Normalize to mean=1 - epsilon: Small constant to avoid numerical issues + :param mode: Weighting mode ("prioritized", "curriculum", "linear", or "inverse") + :type mode: str + :param alpha: Exponent for prioritization (typical range: 0.4-0.8) + :type alpha: float + :param normalize: Normalize to mean=1 + :type normalize: bool + :param epsilon: Small constant to avoid numerical issues + :type epsilon: float """ self.mode = mode self.alpha = alpha @@ -236,12 +232,12 @@ def compute_weights(self, metrics, experiences): """ Compute difficulty-based weights. - Args: - metrics: SampleMetrics with difficulty - experiences: List of experiences - - Returns: - weights: FloatTensor of per-sample weights + :param metrics: SampleMetrics with difficulty + :type metrics: SampleMetrics + :param experiences: List of experiences + :type experiences: List + :return: FloatTensor of per-sample weights + :rtype: torch.Tensor """ if metrics.difficulty is None: # Difficulty not available, return uniform weights @@ -290,12 +286,12 @@ def __init__( """ Initialize staleness weighting. - Args: - decay_factor: Exponential decay factor (0 < decay_factor < 1) - - decay_factor close to 1: slow decay - - decay_factor close to 0: fast decay - normalize: Normalize to mean=1 - epsilon: Small constant + :param decay_factor: Exponential decay factor (0 < decay_factor < 1) + :type decay_factor: float + :param normalize: Normalize to mean=1 + :type normalize: bool + :param epsilon: Small constant + :type epsilon: float """ self.decay_factor = decay_factor self.normalize = normalize @@ -308,12 +304,12 @@ def compute_weights(self, metrics, experiences): """ Compute staleness-based weights. 
- Args: - metrics: SampleMetrics with staleness - experiences: List of experiences - - Returns: - weights: FloatTensor of per-sample weights + :param metrics: SampleMetrics with staleness + :type metrics: SampleMetrics + :param experiences: List of experiences + :type experiences: List + :return: FloatTensor of per-sample weights + :rtype: torch.Tensor """ if metrics.staleness is None: # Staleness not available, return uniform weights @@ -350,14 +346,14 @@ def __init__( """ Initialize reward magnitude weighting. - Args: - mode: Weighting mode - - "favor_high": Higher reward → higher weight - - "favor_low": Lower reward → higher weight - - "absolute": weight = |reward| - temperature: Temperature for softmax (used in favor_high/favor_low) - normalize: Normalize to mean=1 - epsilon: Small constant + :param mode: Weighting mode ("favor_high", "favor_low", or "absolute") + :type mode: str + :param temperature: Temperature for softmax (used in favor_high/favor_low) + :type temperature: float + :param normalize: Normalize to mean=1 + :type normalize: bool + :param epsilon: Small constant + :type epsilon: float """ self.mode = mode self.temperature = temperature @@ -368,12 +364,12 @@ def compute_weights(self, metrics, experiences): """ Compute reward-based weights. - Args: - metrics: SampleMetrics with reward_value - experiences: List of experiences - - Returns: - weights: FloatTensor of per-sample weights + :param metrics: SampleMetrics with reward_value + :type metrics: SampleMetrics + :param experiences: List of experiences + :type experiences: List + :return: FloatTensor of per-sample weights + :rtype: torch.Tensor """ if metrics.reward_value is None: # Reward not available, return uniform weights @@ -417,15 +413,14 @@ def __init__( """ Initialize composite weighting. 
- Args: - weightings: List of (weighting, coefficient) pairs - mode: Combination mode - - "product": Multiply all weights - - "sum": Sum all weights - - "weighted_sum": Weighted sum using coefficients - - "weighted_product": Product of (weight^coefficient) - normalize: Normalize final weights to mean=1 - epsilon: Small constant + :param weightings: List of (weighting, coefficient) pairs + :type weightings: List[Tuple[LossWeighting, float]] + :param mode: Combination mode ("product", "sum", "weighted_sum", or "weighted_product") + :type mode: str + :param normalize: Normalize final weights to mean=1 + :type normalize: bool + :param epsilon: Small constant + :type epsilon: float """ self.weightings = weightings self.mode = mode @@ -443,12 +438,12 @@ def compute_weights(self, metrics, experiences): """ Combine multiple weights. - Args: - metrics: SampleMetrics - experiences: List of experiences - - Returns: - weights: Combined FloatTensor of per-sample weights + :param metrics: SampleMetrics + :type metrics: SampleMetrics + :param experiences: List of experiences + :type experiences: List + :return: Combined FloatTensor of per-sample weights + :rtype: torch.Tensor """ if not self.weightings: # No weightings, return uniform weights @@ -512,12 +507,12 @@ def compute_weights(self, metrics, experiences): """ Return uniform weights. 
- Args: - metrics: SampleMetrics (unused) - experiences: List of experiences - - Returns: - weights: FloatTensor of ones + :param metrics: SampleMetrics (unused) + :type metrics: SampleMetrics + :param experiences: List of experiences + :type experiences: List + :return: FloatTensor of ones + :rtype: torch.Tensor """ total_samples = sum(len(exp.sequences) for exp in experiences) device = experiences[0].sequences.device if experiences else 'cuda' From 4d04e1d259fedfd5ea7ca644d045d6fe444ddb81 Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Wed, 14 Jan 2026 12:03:00 +0800 Subject: [PATCH 4/8] refactor(sunjx): fix format/fcheck bugs --- lightrft/trainer/fast_exp_maker.py | 18 ++++---- lightrft/trainer/filter_weight/__init__.py | 15 ------- lightrft/trainer/filter_weight/filters.py | 21 +--------- lightrft/trainer/filter_weight/manager.py | 49 +++++++--------------- lightrft/trainer/filter_weight/metrics.py | 23 ++-------- lightrft/trainer/filter_weight/weights.py | 38 ++--------------- 6 files changed, 31 insertions(+), 133 deletions(-) diff --git a/lightrft/trainer/fast_exp_maker.py b/lightrft/trainer/fast_exp_maker.py index 4a734ce7..ec1b817a 100644 --- a/lightrft/trainer/fast_exp_maker.py +++ b/lightrft/trainer/fast_exp_maker.py @@ -1027,12 +1027,10 @@ def make_experience_list( # Compute metrics from outputs current_step = getattr(self.strategy, "global_step", None) metrics = self.filter_weight_manager.compute_metrics(outputs, current_step=current_step) - + # Apply filters and weights - experiences, sample_weights = self.filter_weight_manager.apply_to_experiences( - experiences, metrics - ) - + experiences, sample_weights = self.filter_weight_manager.apply_to_experiences(experiences, metrics) + # Store sample weights in experience info for later use sample_idx = 0 for exp in experiences: @@ -1370,11 +1368,10 @@ def _process_experiences( # Use new filter_weight framework if enabled, otherwise use legacy logic from .filter_weight import ResponseLengthFilter 
use_filter_weight = ( - self.filter_weight_manager is not None - and self.filter_weight_manager.filters + self.filter_weight_manager is not None and self.filter_weight_manager.filters and any(isinstance(f, ResponseLengthFilter) for f in self.filter_weight_manager.filters) ) - + if config.overlong_buffer and not use_filter_weight: # Legacy overlong buffer penalty (only if not using filter_weight framework) expected_len = max_new_tokens - config.overlong_buffer_len @@ -1412,11 +1409,10 @@ def _process_experiences( # Use new filter_weight framework if enabled, otherwise use legacy logic from .filter_weight import RewardValueFilter use_dynamic_filter = ( - self.filter_weight_manager is not None - and self.filter_weight_manager.filters + self.filter_weight_manager is not None and self.filter_weight_manager.filters and any(isinstance(f, RewardValueFilter) for f in self.filter_weight_manager.filters) ) - + if config.dynamic_sampling and not use_dynamic_filter: # Legacy dynamic sampling (only if not using filter_weight framework) step_size = config.n_samples_per_prompt // config.micro_train_batch_size diff --git a/lightrft/trainer/filter_weight/__init__.py b/lightrft/trainer/filter_weight/__init__.py index e10d883d..21cad70c 100644 --- a/lightrft/trainer/filter_weight/__init__.py +++ b/lightrft/trainer/filter_weight/__init__.py @@ -70,7 +70,6 @@ FilterWeightManagerBuilder, ) - __all__ = [ # ========== Metrics ========== "SampleMetrics", @@ -101,16 +100,6 @@ ] -def get_version(): - """ - Get the version of the filter_weight module. 
- - :return: Version string - :rtype: str - """ - return __version__ - - # Quick access functions for common use cases def create_length_filter(max_length: int = 1024, **kwargs): """ @@ -158,11 +147,7 @@ def create_manager_from_args(args, packing_samples: bool = False): # Add convenience imports at module level __all__.extend([ - "get_version", "create_length_filter", "create_difficulty_weighting", "create_manager_from_args", ]) - - - diff --git a/lightrft/trainer/filter_weight/filters.py b/lightrft/trainer/filter_weight/filters.py index d36767d2..5beb74f9 100644 --- a/lightrft/trainer/filter_weight/filters.py +++ b/lightrft/trainer/filter_weight/filters.py @@ -17,7 +17,6 @@ class SampleFilter(ABC): Filters determine which samples should be kept for training. They return a boolean mask where True indicates the sample should be kept. """ - @abstractmethod def filter( self, @@ -44,7 +43,6 @@ class ResponseLengthFilter(SampleFilter): This filter can enforce minimum/maximum length constraints or use a buffer-based approach (e.g., expected_length ± buffer_length). """ - def __init__( self, min_length: Optional[int] = None, @@ -113,7 +111,6 @@ class RewardValueFilter(SampleFilter): This filter detects and removes groups of samples where all rewards are identical (e.g., all 0s or all 1s), which provides no learning signal. """ - def __init__( self, filter_all_zeros: bool = True, @@ -195,12 +192,7 @@ class EntropyFilter(SampleFilter): Entropy measures the uncertainty/diversity of the policy. Low entropy indicates confident/deterministic generation, high entropy indicates uncertain/exploratory generation. """ - - def __init__( - self, - min_entropy: Optional[float] = None, - max_entropy: Optional[float] = None - ): + def __init__(self, min_entropy: Optional[float] = None, max_entropy: Optional[float] = None): """ Initialize entropy filter. 
@@ -250,7 +242,6 @@ class DifficultyFilter(SampleFilter): This can be used to implement curriculum learning (filter out hard samples early) or focus training (filter out easy samples). """ - def __init__( self, min_difficulty: Optional[float] = None, @@ -317,7 +308,6 @@ class CompositeFilter(SampleFilter): This allows building complex filtering logic by composing simple filters. """ - def __init__(self, filters: List[SampleFilter], logic: str = "AND"): """ Initialize composite filter. @@ -371,12 +361,8 @@ class PercentileFilter(SampleFilter): This is useful for keeping top-k% or bottom-k% samples according to some metric. """ - def __init__( - self, - metric_name: str, - top_percentile: Optional[float] = None, - bottom_percentile: Optional[float] = None + self, metric_name: str, top_percentile: Optional[float] = None, bottom_percentile: Optional[float] = None ): """ Initialize percentile filter. @@ -428,6 +414,3 @@ def filter(self, metrics, experiences): mask |= (metric_value <= threshold) return mask - - - diff --git a/lightrft/trainer/filter_weight/manager.py b/lightrft/trainer/filter_weight/manager.py index 285d9568..146b04e2 100644 --- a/lightrft/trainer/filter_weight/manager.py +++ b/lightrft/trainer/filter_weight/manager.py @@ -57,7 +57,6 @@ class FilterWeightManager: :param packing_samples: Whether samples are packed :type packing_samples: bool """ - def __init__( self, metrics_computer: Optional[MetricsComputer] = None, @@ -144,11 +143,7 @@ def compute_metrics( :return: SampleMetrics with computed metrics :rtype: SampleMetrics """ - return self.metrics_computer.compute_all_metrics( - outputs, - self.enable_metrics, - current_step - ) + return self.metrics_computer.compute_all_metrics(outputs, self.enable_metrics, current_step) def apply_filters( self, @@ -235,7 +230,7 @@ def apply_to_experiences( sample_idx = 0 for exp in experiences: batch_size = len(exp.sequences) - batch_mask = keep_mask[sample_idx : sample_idx + batch_size] + batch_mask = 
keep_mask[sample_idx:sample_idx + batch_size] # Zero out action_mask for filtered samples # This effectively removes them from loss computation @@ -357,7 +352,6 @@ class FilterWeightManagerBuilder: manager = builder.from_args(args) ``` """ - @staticmethod def from_args(args, packing_samples: bool = False) -> FilterWeightManager: """ @@ -385,33 +379,26 @@ def from_args(args, packing_samples: bool = False) -> FilterWeightManager: if getattr(args, "overlong_buffer", False): expected_len = getattr(args, "max_new_tokens", 1024) - getattr(args, "overlong_buffer_len", 0) buffer_len = getattr(args, "overlong_buffer_len", 0) - filters.append(ResponseLengthFilter( - expected_length=expected_len, - buffer_length=buffer_len - )) + filters.append(ResponseLengthFilter(expected_length=expected_len, buffer_length=buffer_len)) # Reward value filter (for dynamic sampling) if getattr(args, "dynamic_sampling", False) and getattr(args, "advantage_estimator", "") == "group_norm": - filters.append(RewardValueFilter( - n_samples_per_prompt=getattr(args, "n_samples_per_prompt", 1) - )) + filters.append(RewardValueFilter(n_samples_per_prompt=getattr(args, "n_samples_per_prompt", 1))) # Entropy filter if getattr(args, "enable_entropy_filter", False): - filters.append(EntropyFilter( - min_entropy=getattr(args, "min_entropy", None), - max_entropy=getattr(args, "max_entropy", None) - )) + filters.append( + EntropyFilter( + min_entropy=getattr(args, "min_entropy", None), max_entropy=getattr(args, "max_entropy", None) + ) + ) # Build weights weights = [] # Response length weighting if getattr(args, "enable_length_weighting", False): - weight = ResponseLengthWeighting( - mode=getattr(args, "length_weight_mode", "inverse"), - normalize=True - ) + weight = ResponseLengthWeighting(mode=getattr(args, "length_weight_mode", "inverse"), normalize=True) coef = getattr(args, "length_weight_coef", 1.0) weights.append((weight, coef)) @@ -437,16 +424,14 @@ def from_args(args, packing_samples: bool = False) -> 
FilterWeightManager: # Staleness weighting if getattr(args, "enable_staleness_weighting", False): - weight = StalenessWeighting( - decay_factor=getattr(args, "staleness_decay_factor", 0.95), - normalize=True - ) + weight = StalenessWeighting(decay_factor=getattr(args, "staleness_decay_factor", 0.95), normalize=True) coef = getattr(args, "staleness_weight_coef", 1.0) weights.append((weight, coef)) # Build enable_metrics dict enable_metrics = { - "entropy": getattr(args, "compute_entropy", False) or getattr(args, "enable_entropy_filter", False) or getattr(args, "enable_entropy_weighting", False), + "entropy": getattr(args, "compute_entropy", False) or getattr(args, "enable_entropy_filter", False) + or getattr(args, "enable_entropy_weighting", False), "difficulty": getattr(args, "enable_difficulty_weighting", False), "difficulty_mode": getattr(args, "difficulty_mode", "td_error"), "staleness": getattr(args, "enable_staleness_weighting", False), @@ -454,11 +439,5 @@ def from_args(args, packing_samples: bool = False) -> FilterWeightManager: } return FilterWeightManager( - filters=filters, - weights=weights, - enable_metrics=enable_metrics, - packing_samples=packing_samples + filters=filters, weights=weights, enable_metrics=enable_metrics, packing_samples=packing_samples ) - - - diff --git a/lightrft/trainer/filter_weight/metrics.py b/lightrft/trainer/filter_weight/metrics.py index d345a9e8..28e149af 100644 --- a/lightrft/trainer/filter_weight/metrics.py +++ b/lightrft/trainer/filter_weight/metrics.py @@ -54,7 +54,6 @@ class MetricsComputer: :param packing_samples: Whether samples are packed (affects unpacking logic) :type packing_samples: bool """ - def __init__(self, packing_samples: bool = False): """ Initialize metrics computer. 
@@ -131,10 +130,7 @@ def compute_logit_kl( reference_log_probs = torch.log_softmax(reference_logits, dim=-1) # KL(current || reference) = sum(p_curr * (log p_curr - log p_ref)) - kl_per_token = torch.sum( - torch.exp(current_log_probs) * (current_log_probs - reference_log_probs), - dim=-1 - ) + kl_per_token = torch.sum(torch.exp(current_log_probs) * (current_log_probs - reference_log_probs), dim=-1) # Handle packed vs unpacked samples if self.packing_samples and num_actions is not None: @@ -144,10 +140,7 @@ def compute_logit_kl( return masked_mean(kl_per_token, action_mask, dim=-1) def compute_difficulty( - self, - rewards: torch.Tensor, - values: Optional[torch.Tensor] = None, - mode: str = "td_error" + self, rewards: torch.Tensor, values: Optional[torch.Tensor] = None, mode: str = "td_error" ) -> torch.Tensor: """ Compute sample difficulty. @@ -194,10 +187,7 @@ def compute_difficulty( raise ValueError(f"Unknown difficulty mode: {mode}") def compute_staleness( - self, - generation_steps: torch.Tensor, - current_step: int, - mode: str = "linear" + self, generation_steps: torch.Tensor, current_step: int, mode: str = "linear" ) -> torch.Tensor: """ Compute sample staleness based on age. 
@@ -284,9 +274,7 @@ def compute_all_metrics( action_mask = torch.cat(action_mask_list, dim=0) if self.packing_samples: - entropy = self.compute_entropy( - action_log_probs, action_mask, num_actions_list - ) + entropy = self.compute_entropy(action_log_probs, action_mask, num_actions_list) else: entropy = self.compute_entropy(action_log_probs, action_mask) @@ -325,6 +313,3 @@ def compute_all_metrics( n_samples_per_prompt=None, # Can be set externally micro_batch_size=None, # Can be set externally ) - - - diff --git a/lightrft/trainer/filter_weight/weights.py b/lightrft/trainer/filter_weight/weights.py index 0efef420..dae2d498 100644 --- a/lightrft/trainer/filter_weight/weights.py +++ b/lightrft/trainer/filter_weight/weights.py @@ -17,7 +17,6 @@ class LossWeighting(ABC): Weightings compute per-sample weights that modulate the contribution of each sample to the loss function. """ - @abstractmethod def compute_weights( self, @@ -46,7 +45,6 @@ class ResponseLengthWeighting(LossWeighting): - Give more weight to shorter responses (mode="inverse") - Balance weights by length (mode="sqrt", "log") """ - def __init__( self, mode: str = "linear", @@ -123,13 +121,8 @@ class EntropyWeighting(LossWeighting): This can encourage exploration (favor high entropy) or exploitation (favor low entropy). """ - def __init__( - self, - mode: str = "favor_high", - temperature: float = 1.0, - normalize: bool = True, - epsilon: float = 1e-6 + self, mode: str = "favor_high", temperature: float = 1.0, normalize: bool = True, epsilon: float = 1e-6 ): """ Initialize entropy weighting. @@ -199,14 +192,7 @@ class DifficultyWeighting(LossWeighting): This implements prioritized experience replay (PER) style weighting or curriculum learning approaches. 
""" - - def __init__( - self, - mode: str = "prioritized", - alpha: float = 0.6, - normalize: bool = True, - epsilon: float = 1e-6 - ): + def __init__(self, mode: str = "prioritized", alpha: float = 0.6, normalize: bool = True, epsilon: float = 1e-6): """ Initialize difficulty weighting. @@ -276,13 +262,7 @@ class StalenessWeighting(LossWeighting): Older samples may be less relevant due to policy shift, so they receive exponentially decaying weights. """ - - def __init__( - self, - decay_factor: float = 0.95, - normalize: bool = True, - epsilon: float = 1e-6 - ): + def __init__(self, decay_factor: float = 0.95, normalize: bool = True, epsilon: float = 1e-6): """ Initialize staleness weighting. @@ -335,13 +315,8 @@ class RewardMagnitudeWeighting(LossWeighting): This can be used to focus on high-reward or low-reward samples. """ - def __init__( - self, - mode: str = "favor_high", - temperature: float = 1.0, - normalize: bool = True, - epsilon: float = 1e-6 + self, mode: str = "favor_high", temperature: float = 1.0, normalize: bool = True, epsilon: float = 1e-6 ): """ Initialize reward magnitude weighting. @@ -402,7 +377,6 @@ class CompositeWeighting(LossWeighting): This allows building complex weighting strategies by composing simple weightings. """ - def __init__( self, weightings: List[Tuple[LossWeighting, float]], @@ -498,7 +472,6 @@ class UniformWeighting(LossWeighting): This is a no-op weighting for baseline comparisons. 
""" - def __init__(self): """Initialize uniform weighting.""" pass @@ -517,6 +490,3 @@ def compute_weights(self, metrics, experiences): total_samples = sum(len(exp.sequences) for exp in experiences) device = experiences[0].sequences.device if experiences else 'cuda' return torch.ones(total_samples, device=device) - - - From a43ae21419a6dd76edced117360ef99a25c9d0e5 Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Sun, 18 Jan 2026 13:23:42 +0800 Subject: [PATCH 5/8] feature(sunjx): fix dynamic_sampling bugs --- lightrft/strategy/strategy_base.py | 11 +++ lightrft/trainer/fast_exp_maker.py | 86 ++++++++++++++++++--- lightrft/trainer/ppo_trainer_vl.py | 110 ++++++++++++++++++++------- lightrft/trainer/spmd_ppo_trainer.py | 27 +++++-- 4 files changed, 193 insertions(+), 41 deletions(-) diff --git a/lightrft/strategy/strategy_base.py b/lightrft/strategy/strategy_base.py index d68e47b5..fbde2f08 100644 --- a/lightrft/strategy/strategy_base.py +++ b/lightrft/strategy/strategy_base.py @@ -402,6 +402,17 @@ def all_reduce(self, data, op="mean"): """ assert op in ("mean", "max", "sum") if isinstance(data, dict): + # Ensure consistent key order and presence across ranks to avoid NCCL hangs. 
+ if dist.is_available() and dist.is_initialized() and self.world_size > 1: + local_keys = list(data.keys()) + gathered_keys = [None for _ in range(self.world_size)] + dist.all_gather_object(gathered_keys, local_keys) + all_keys = sorted({k for keys in gathered_keys for k in keys}) + ret = {} + for k in all_keys: + ret[k] = self.all_reduce(data.get(k, 0.0), op) + return ret + ret = {} for k, v in data.items(): ret[k] = self.all_reduce(v, op) diff --git a/lightrft/trainer/fast_exp_maker.py b/lightrft/trainer/fast_exp_maker.py index ec1b817a..4ddbf4d3 100644 --- a/lightrft/trainer/fast_exp_maker.py +++ b/lightrft/trainer/fast_exp_maker.py @@ -30,6 +30,7 @@ from copy import deepcopy import torch +import torch.distributed as dist from PIL import Image from easydict import EasyDict from vllm import SamplingParams @@ -49,6 +50,7 @@ ExperienceVL, SamplesVL, ) +from lightrft.trainer.replay_buffer_utils import make_experience_batch, split_experience_batch from lightrft.utils.remote_rm_utils import remote_rm_fn from lightrft.utils import Timer, get_current_device @@ -1415,15 +1417,81 @@ def _process_experiences( if config.dynamic_sampling and not use_dynamic_filter: # Legacy dynamic sampling (only if not using filter_weight framework) - step_size = config.n_samples_per_prompt // config.micro_train_batch_size - for i in range(0, len(experiences), step_size): - chunk = experiences[i:i + step_size] - chunk_rewards = torch.cat([exp.info["reward"] for exp in chunk]) - - # Filter out degenerate cases (all 0s or all 1s) - if torch.all(chunk_rewards == 0) or torch.all(chunk_rewards == 1): - for exp in chunk: - exp.action_mask = torch.zeros_like(exp.action_mask, dtype=torch.bool) + group_size = config.n_samples_per_prompt + tolerance = 1e-6 + if rewards.numel() % group_size != 0: + warnings.warn( + f"Number of samples ({rewards.numel()}) not divisible by group_size ({group_size}). " + f"Skipping dynamic sampling filtering." 
+ ) + else: + grouped_rewards = rewards.reshape(-1, group_size) + all_zeros = torch.all(torch.abs(grouped_rewards) < tolerance, dim=1) + all_ones = torch.all(torch.abs(grouped_rewards - 1.0) < tolerance, dim=1) + keep_group_mask = ~(all_zeros | all_ones) + keep_mask = keep_group_mask.repeat_interleave(group_size) + + # In distributed training, keep sample counts aligned across ranks. + # We apply filtering by masking (no removal) to avoid NCCL hangs. + is_distributed = dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1 + if is_distributed: + if keep_mask.sum().item() == 0: + self.strategy.print( + "[Warning] No sample kept after filtering on this rank; skip filtering this step." + ) + keep_mask = torch.ones_like(keep_mask, dtype=torch.bool) + + if keep_mask.sum().item() < keep_mask.numel(): + offset = 0 + for exp in experiences: + batch_size = len(exp.sequences) + exp_mask = keep_mask[offset:offset + batch_size] + offset += batch_size + if exp.action_mask is not None: + exp.action_mask = exp.action_mask & exp_mask.unsqueeze(-1).to( + exp.action_mask.device + ) + exp_rewards = exp.info["reward"] + exp.info["reward"] = exp_rewards * exp_mask.to(exp_rewards.device).float() + + rewards = rewards * keep_mask.to(rewards.device).float() + else: + if keep_mask.sum().item() == 0: + raise RuntimeError("No sample is kept after filtering. Please check your data.") + + if keep_mask.sum().item() < keep_mask.numel(): + filtered_experiences = [] + filtered_rewards = [] + offset = 0 + for exp in experiences: + batch_size = len(exp.sequences) + exp_mask = keep_mask[offset:offset + batch_size] + offset += batch_size + if exp_mask.sum().item() == 0: + continue + + # Filter rewards for this batch + exp_rewards = exp.info["reward"] + filtered_rewards.append(exp_rewards[exp_mask.to(exp_rewards.device)]) + + if exp_mask.all(): + filtered_experiences.append(exp) + continue + + # Rebuild experience batch to keep shapes consistent (esp. 
pixel_values) + items = split_experience_batch(exp) + kept_items = [item for item, keep in zip(items, exp_mask.cpu().tolist()) if keep] + if not kept_items: + continue + filtered_experiences.append( + make_experience_batch(kept_items, packing_samples=self.packing_samples) + ) + + if not filtered_experiences: + raise RuntimeError("No sample is kept after filtering. Please check your data.") + + experiences = filtered_experiences + rewards = torch.cat(filtered_rewards) # # Normalize within groups # rewards = rewards.reshape(-1, config.n_samples_per_prompt).to("cuda") diff --git a/lightrft/trainer/ppo_trainer_vl.py b/lightrft/trainer/ppo_trainer_vl.py index e5b98d56..1ee310c5 100644 --- a/lightrft/trainer/ppo_trainer_vl.py +++ b/lightrft/trainer/ppo_trainer_vl.py @@ -689,11 +689,16 @@ def training_step_actor(self, experience: ExperienceVL) -> Dict[str, float]: base_action_log_probs = experience.base_action_log_probs if advantages is not None: - # Log max advantage before clipping for debugging (optional) - max_adv = advantages.max().item() - if max_adv > 10.0: - self.strategy.print(f"[Warning] Huge advantage detected: {max_adv}") - advantages = torch.clamp(advantages, min=-10.0, max=10.0) + # Check if advantages is empty (e.g., after dynamic sampling filtering) + if advantages.numel() == 0: + self.strategy.print("[Warning] Empty advantages after filtering; using zero-loss step.") + advantages = None + else: + # Log max advantage before clipping for debugging (optional) + max_adv = advantages.max().item() + if max_adv > 10.0: + self.strategy.print(f"[Warning] Huge advantage detected: {max_adv}") + advantages = torch.clamp(advantages, min=-10.0, max=10.0) # Actor loss action_log_probs, output = self.actor( @@ -712,23 +717,33 @@ def training_step_actor(self, experience: ExperienceVL) -> Dict[str, float]: # action_log_probs = action_log_probs * experience.action_mask # Loss function - actor_loss = self.actor_loss_fn( - action_log_probs, - old_action_log_probs, - 
advantages, - action_mask=experience.action_mask, - ) + if experience.action_mask is not None and experience.action_mask.sum().item() == 0: + # No valid actions; use zero loss to keep distributed steps in sync. + actor_loss = action_log_probs.sum() * 0.0 + else: + actor_loss = self.actor_loss_fn( + action_log_probs, + old_action_log_probs, + advantages, + action_mask=experience.action_mask, + ) if self.args.use_kl_loss: if self.initial_model is not None: # TODO(pu): Text-only action mask for KL calculation - - kl = compute_approx_kl( - action_log_probs, - base_action_log_probs, - experience.action_mask, - kl_estimator=self.args.kl_estimator, - ) + # If no valid actions or base log-probs are empty, skip KL safely. + if ((experience.action_mask is not None and experience.action_mask.sum().item() == 0) + or (base_action_log_probs is not None and base_action_log_probs.numel() == 0)): + kl = torch.zeros_like( + action_log_probs, dtype=action_log_probs.dtype, device=action_log_probs.device + ) + else: + kl = compute_approx_kl( + action_log_probs, + base_action_log_probs, + experience.action_mask, + kl_estimator=self.args.kl_estimator, + ) # [Protection measure 2] Per-token KL Clamping # NOTE: Adding this causes svkng training to not converge @@ -738,12 +753,22 @@ def training_step_actor(self, experience: ExperienceVL) -> Dict[str, float]: kl = torch.zeros_like(action_log_probs, dtype=action_log_probs.dtype, device=action_log_probs.device) if not self.args.packing_samples: - kl_mean = masked_mean(kl, experience.action_mask, dim=-1) + # Guard against empty or mismatched masks (can happen after filtering) + if ( + experience.action_mask is None or experience.action_mask.numel() == 0 + or experience.action_mask.size(-1) != kl.size(-1) or experience.action_mask.sum().item() == 0 + ): + kl_mean = torch.zeros(kl.size(0), device=kl.device, dtype=kl.dtype) + else: + kl_mean = masked_mean(kl, experience.action_mask, dim=-1) # Not supported for packed samples else: # Convert tensor 
into list of tensors for easier manipulation within dataset kl = unpacking_samples(kl, num_actions) - kl_mean = torch.tensor([each_kl.mean() for each_kl in kl], device=action_log_probs.device) + if not kl: + kl_mean = torch.zeros(0, device=action_log_probs.device) + else: + kl_mean = torch.tensor([each_kl.mean() for each_kl in kl], device=action_log_probs.device) kl_loss = kl_mean.mean() experience.info["kl"] = kl_loss.item() @@ -903,6 +928,18 @@ def ensure_device_and_contiguous(tensor, name="tensor"): # TODO: This is a bad indicator to say that data is packed... if isinstance(experience.sequences, list): + # Check if sequences list is empty after filtering + if len(experience.sequences) == 0 or all(s.numel() == 0 for s in experience.sequences): + self.strategy.print("[Warning] No valid samples for critic after filtering, using zero-loss step") + # Create zero loss to keep distributed steps in sync + dummy = sum([s.sum() for s in experience.sequences]) * 0.0 + self.strategy.backward(dummy, self.critic, self.critic_optim) + self.strategy.optimizer_step(self.critic_optim, self.critic, self.critic_scheduler, name="critic") + return { + "critic_loss": 0.0, + "values": 0.0, + "critic_lr": self.critic_scheduler.get_last_lr()[0] if self.critic_scheduler else 0.0, + } sequences = torch.cat(experience.sequences, dim=0).unsqueeze(0) old_values = torch.cat(experience.values, dim=0).unsqueeze(0) returns = torch.cat(experience.returns, dim=0).unsqueeze(0) @@ -917,6 +954,17 @@ def ensure_device_and_contiguous(tensor, name="tensor"): num_actions = experience.action_mask.size(1) packed_seq_lens = None attention_mask = experience.attention_mask + # Check if sequences is empty after filtering + if sequences.numel() == 0: + self.strategy.print("[Warning] No valid samples for critic after filtering, using zero-loss step") + dummy = sequences.sum() * 0.0 + self.strategy.backward(dummy, self.critic, self.critic_optim) + self.strategy.optimizer_step(self.critic_optim, self.critic, 
self.critic_scheduler, name="critic") + return { + "critic_loss": 0.0, + "values": 0.0, + "critic_lr": self.critic_scheduler.get_last_lr()[0] if self.critic_scheduler else 0.0, + } # Ensure sequences and attention_mask are also on device and contiguous sequences = ensure_device_and_contiguous(sequences, "sequences") @@ -933,12 +981,16 @@ def ensure_device_and_contiguous(tensor, name="tensor"): packed_seq_lens=packed_seq_lens, ) # Loss function - critic_loss = self.critic_loss_fn( - values, - old_values, - returns, - action_mask=experience.action_mask, - ) + if experience.action_mask is not None and experience.action_mask.sum().item() == 0: + # No valid actions; use zero loss to keep distributed steps in sync. + critic_loss = values.sum() * 0.0 + else: + critic_loss = self.critic_loss_fn( + values, + old_values, + returns, + action_mask=experience.action_mask, + ) # Mixtral auxiliary loss if self.aux_loss: aux_loss = output.aux_loss @@ -949,9 +1001,13 @@ def ensure_device_and_contiguous(tensor, name="tensor"): self.strategy.optimizer_step(self.critic_optim, self.critic, self.critic_scheduler, name="critic") # Status + if experience.action_mask is not None and experience.action_mask.sum().item() == 0: + values_mean = 0.0 + else: + values_mean = masked_mean(values, experience.action_mask).item() status = { "critic_loss": critic_loss.item(), - "values": masked_mean(values, experience.action_mask).item(), + "values": values_mean, "critic_lr": self.critic_scheduler.get_last_lr()[0], } return status diff --git a/lightrft/trainer/spmd_ppo_trainer.py b/lightrft/trainer/spmd_ppo_trainer.py index 414f7373..25dbfe46 100644 --- a/lightrft/trainer/spmd_ppo_trainer.py +++ b/lightrft/trainer/spmd_ppo_trainer.py @@ -214,11 +214,18 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train status = self.training_step(experience, global_steps) # for DP - # weighted mean for kl - if "kl" in status: + # weighted mean for kl (ensure all ranks enter collective) + 
if self.strategy.world_size > 1: + if "kl" not in status: + status["kl"] = 0.0 + if "response_length" not in status: + status["response_length"] = 0.0 status["kl"] *= status["response_length"] status = self.strategy.all_reduce(status) - status["kl"] /= status["response_length"] + if status["response_length"] != 0: + status["kl"] /= status["response_length"] + else: + status["kl"] = 0.0 # Training epoch progress bar: show per-batch metrics for detailed monitoring short_status = {} @@ -254,10 +261,20 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train # "kl": KL divergence # "act_lr": actor_lr if status_list: - status_mean = status_list[0] - for m in status_list[1:]: + # Collect all unique keys from all status dicts + all_keys = set() + for m in status_list: + all_keys.update(m.keys()) + + # Initialize status_mean with all keys set to 0 + status_mean = {k: 0.0 for k in all_keys} + + # Sum up values from all status dicts + for m in status_list: for k, v in m.items(): status_mean[k] += v + + # Compute mean for k in status_mean.keys(): status_mean[k] /= len(status_list) From 7d8dea47c6ece1a56ac1b2258e84d539fa9574b8 Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Sun, 18 Jan 2026 13:51:43 +0800 Subject: [PATCH 6/8] refactor(sunjx): pass formt and fcheck --- lightrft/trainer/ppo_trainer_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrft/trainer/ppo_trainer_vl.py b/lightrft/trainer/ppo_trainer_vl.py index c9b60671..4aa3637e 100644 --- a/lightrft/trainer/ppo_trainer_vl.py +++ b/lightrft/trainer/ppo_trainer_vl.py @@ -759,7 +759,7 @@ def training_step_actor(self, # TODO(pu): Text-only action mask for KL calculation # If no valid actions or base log-probs are empty, skip KL safely. 
if ((experience.action_mask is not None and experience.action_mask.sum().item() == 0) - or (base_action_log_probs is not None and base_action_log_probs.numel() == 0)): + or (base_action_log_probs is not None and base_action_log_probs.numel() == 0)): kl = torch.zeros_like( action_log_probs, dtype=action_log_probs.dtype, device=action_log_probs.device ) From a659c0026f62e5a9de16dbd76784dbed776f4c94 Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Tue, 20 Jan 2026 11:15:09 +0800 Subject: [PATCH 7/8] refactor(sunjx): pass format and fcheck check --- lightrft/trainer/ppo_trainer_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightrft/trainer/ppo_trainer_vl.py b/lightrft/trainer/ppo_trainer_vl.py index 4aa3637e..c9b60671 100644 --- a/lightrft/trainer/ppo_trainer_vl.py +++ b/lightrft/trainer/ppo_trainer_vl.py @@ -759,7 +759,7 @@ def training_step_actor(self, # TODO(pu): Text-only action mask for KL calculation # If no valid actions or base log-probs are empty, skip KL safely. 
if ((experience.action_mask is not None and experience.action_mask.sum().item() == 0) - or (base_action_log_probs is not None and base_action_log_probs.numel() == 0)): + or (base_action_log_probs is not None and base_action_log_probs.numel() == 0)): kl = torch.zeros_like( action_log_probs, dtype=action_log_probs.dtype, device=action_log_probs.device ) From 97f8a9260ce95ae6c42967cc60fbc15cc9fb25d5 Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Tue, 20 Jan 2026 21:57:49 +0800 Subject: [PATCH 8/8] refactor(sunjx): Organize the code --- lightrft/strategy/config.py | 66 ++++++++++ lightrft/trainer/fast_exp_maker.py | 140 ++++------------------ lightrft/trainer/filter_weight/manager.py | 104 +++++++++------- lightrft/trainer/filter_weight/metrics.py | 11 +- 4 files changed, 158 insertions(+), 163 deletions(-) diff --git a/lightrft/strategy/config.py b/lightrft/strategy/config.py index c6993005..87154e4f 100644 --- a/lightrft/strategy/config.py +++ b/lightrft/strategy/config.py @@ -133,6 +133,58 @@ class StrategyConfig: # (bool): Use TensorBoard for logging, defaults to False use_tensorboard: bool = False + # Filter and weight parameters + # (int): Maximum number of new tokens to generate, defaults to 1024 + max_new_tokens: int = 1024 + + # Entropy-based filtering and weighting + # (bool): Enable entropy-based filtering, defaults to False + enable_entropy_filter: bool = False + # (Optional[float]): Minimum entropy threshold for filtering, defaults to None + min_entropy: Optional[float] = None + # (Optional[float]): Maximum entropy threshold for filtering, defaults to None + max_entropy: Optional[float] = None + # (bool): Compute entropy metrics, defaults to False + compute_entropy: bool = False + # (bool): Enable entropy-based loss weighting, defaults to False + enable_entropy_weighting: bool = False + # (str): Entropy weighting mode ("favor_high", "favor_low"), defaults to "favor_high" + entropy_weight_mode: str = "favor_high" + # (float): Temperature parameter for 
entropy weighting, defaults to 1.0 + entropy_weight_temperature: float = 1.0 + # (float): Coefficient for entropy weighting, defaults to 1.0 + entropy_weight_coef: float = 1.0 + + # Length-based weighting + # (bool): Enable response length-based loss weighting, defaults to False + enable_length_weighting: bool = False + # (str): Length weighting mode ("inverse", "linear", "squared"), defaults to "inverse" + length_weight_mode: str = "inverse" + # (float): Coefficient for length weighting, defaults to 1.0 + length_weight_coef: float = 1.0 + + # Difficulty-based weighting + # (bool): Enable difficulty-based loss weighting, defaults to False + enable_difficulty_weighting: bool = False + # (str): Difficulty weighting mode ("prioritized", "linear"), defaults to "prioritized" + difficulty_weight_mode: str = "prioritized" + # (str): Difficulty computation mode ("td_error", "abs_advantage"), defaults to "td_error" + difficulty_mode: str = "td_error" + # (float): Alpha parameter for prioritized difficulty weighting, defaults to 0.6 + difficulty_alpha: float = 0.6 + # (float): Coefficient for difficulty weighting, defaults to 1.0 + difficulty_weight_coef: float = 1.0 + + # Staleness-based weighting + # (bool): Enable staleness-based loss weighting, defaults to False + enable_staleness_weighting: bool = False + # (float): Decay factor for staleness weighting, defaults to 0.95 + staleness_decay_factor: float = 0.95 + # (str): Staleness computation mode ("linear", "exponential"), defaults to "linear" + staleness_mode: str = "linear" + # (float): Coefficient for staleness weighting, defaults to 1.0 + staleness_weight_coef: float = 1.0 + # Additional arguments for backward compatibility # (Dict[str, Any]): Extra arguments for backward compatibility, defaults to {} extra_args: Dict[str, Any] = field(default_factory=dict) @@ -310,6 +362,20 @@ def print_config_summary(self) -> None: status = "Overridden" if current != default else "Default" print(f" {attr}: {current} ({status})") + 
# Filter and Weight Parameters + print("\nFilter and Weight Parameters:") + for attr in [ + 'max_new_tokens', 'enable_entropy_filter', 'min_entropy', 'max_entropy', 'compute_entropy', + 'enable_entropy_weighting', 'entropy_weight_mode', 'entropy_weight_temperature', 'entropy_weight_coef', + 'enable_length_weighting', 'length_weight_mode', 'length_weight_coef', 'enable_difficulty_weighting', + 'difficulty_weight_mode', 'difficulty_mode', 'difficulty_alpha', 'difficulty_weight_coef', + 'enable_staleness_weighting', 'staleness_decay_factor', 'staleness_mode', 'staleness_weight_coef' + ]: + current = getattr(self, attr) + default = getattr(default_config, attr) + status = "Overridden" if current != default else "Default" + print(f" {attr}: {current} ({status})") + # extra_args if self.extra_args: print("\nExtra Parameters (extra_args):") diff --git a/lightrft/trainer/fast_exp_maker.py b/lightrft/trainer/fast_exp_maker.py index c01873fb..060f8d52 100644 --- a/lightrft/trainer/fast_exp_maker.py +++ b/lightrft/trainer/fast_exp_maker.py @@ -1036,15 +1036,10 @@ def make_experience_list( current_step = getattr(self.strategy, "global_step", None) metrics = self.filter_weight_manager.compute_metrics(outputs, current_step=current_step) - # Apply filters and weights - experiences, sample_weights = self.filter_weight_manager.apply_to_experiences(experiences, metrics) - - # Store sample weights in experience info for later use - sample_idx = 0 - for exp in experiences: - batch_size = len(exp.sequences) - exp.info["sample_weights"] = sample_weights[sample_idx:sample_idx + batch_size] - sample_idx += batch_size + # Apply filters and weights (with distributed training support) + experiences, _ = self.filter_weight_manager.apply_to_experiences( + experiences, metrics, strategy=self.strategy + ) # ========== Stage 5: Reward Processing ========== experiences, rewards = self._process_experiences( # GRPO's -mean / std operation is performed in this method @@ -1440,29 +1435,25 @@ def 
_process_experiences( config = self.strategy.config rewards = torch.cat([exp.info["reward"] for exp in experiences]) - # ========== Overlong Sequence Penalty ========== - # Use new filter_weight framework if enabled, otherwise use legacy logic - from .filter_weight import ResponseLengthFilter - use_filter_weight = ( - self.filter_weight_manager is not None and self.filter_weight_manager.filters - and any(isinstance(f, ResponseLengthFilter) for f in self.filter_weight_manager.filters) - ) - - if config.overlong_buffer and not use_filter_weight: - # Legacy overlong buffer penalty (only if not using filter_weight framework) - expected_len = max_new_tokens - config.overlong_buffer_len - actual_lens = torch.cat([exp.action_mask.sum(dim=1) for exp in experiences]) - exceed_len = actual_lens - expected_len - - # Penalty: clamp(-exceed_len / buffer_len * penalty_factor, max=0) - penalty = torch.clamp( - -exceed_len / config.overlong_buffer_len * config.overlong_buffer_penalty_factor, max=0.0 - ) - rewards += penalty - # ========== Dynamic Sampling Warning ========== - if config.dynamic_sampling and config.advantage_estimator in ["rloo", "reinforce_baseline"]: - warnings.warn(f"dynamic_sampling not implemented for {config.advantage_estimator}, ignoring", UserWarning) + if config.dynamic_sampling: + if config.advantage_estimator in ["rloo", "reinforce_baseline"]: + warnings.warn( + f"dynamic_sampling not implemented for {config.advantage_estimator}, ignoring", UserWarning + ) + elif config.advantage_estimator in ["group_norm", "grpo"]: + # Check if filter_weight_manager is properly configured + from .filter_weight import RewardValueFilter + has_reward_filter = ( + self.filter_weight_manager is not None and self.filter_weight_manager.filters + and any(isinstance(f, RewardValueFilter) for f in self.filter_weight_manager.filters) + ) + if not has_reward_filter: + warnings.warn( + "dynamic_sampling is enabled but FilterWeightManager is not configured with " + "RewardValueFilter. 
Dynamic sampling will not be applied. " + "Please ensure filter_weight_manager is properly initialized.", UserWarning + ) # ========== Advantage Estimator-Specific Shaping ========== if config.advantage_estimator == "rloo": @@ -1481,92 +1472,7 @@ def _process_experiences( return experiences, rewards elif config.advantage_estimator in ["group_norm", "grpo"]: - # Group normalization with optional dynamic filtering - # Use new filter_weight framework if enabled, otherwise use legacy logic - from .filter_weight import RewardValueFilter - use_dynamic_filter = ( - self.filter_weight_manager is not None and self.filter_weight_manager.filters - and any(isinstance(f, RewardValueFilter) for f in self.filter_weight_manager.filters) - ) - - if config.dynamic_sampling and not use_dynamic_filter: - # Legacy dynamic sampling (only if not using filter_weight framework) - group_size = config.n_samples_per_prompt - tolerance = 1e-6 - if rewards.numel() % group_size != 0: - warnings.warn( - f"Number of samples ({rewards.numel()}) not divisible by group_size ({group_size}). " - f"Skipping dynamic sampling filtering." - ) - else: - grouped_rewards = rewards.reshape(-1, group_size) - all_zeros = torch.all(torch.abs(grouped_rewards) < tolerance, dim=1) - all_ones = torch.all(torch.abs(grouped_rewards - 1.0) < tolerance, dim=1) - keep_group_mask = ~(all_zeros | all_ones) - keep_mask = keep_group_mask.repeat_interleave(group_size) - - # In distributed training, keep sample counts aligned across ranks. - # We apply filtering by masking (no removal) to avoid NCCL hangs. - is_distributed = dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1 - if is_distributed: - if keep_mask.sum().item() == 0: - self.strategy.print( - "[Warning] No sample kept after filtering on this rank; skip filtering this step." 
- ) - keep_mask = torch.ones_like(keep_mask, dtype=torch.bool) - - if keep_mask.sum().item() < keep_mask.numel(): - offset = 0 - for exp in experiences: - batch_size = len(exp.sequences) - exp_mask = keep_mask[offset:offset + batch_size] - offset += batch_size - if exp.action_mask is not None: - exp.action_mask = exp.action_mask & exp_mask.unsqueeze(-1).to( - exp.action_mask.device - ) - exp_rewards = exp.info["reward"] - exp.info["reward"] = exp_rewards * exp_mask.to(exp_rewards.device).float() - - rewards = rewards * keep_mask.to(rewards.device).float() - else: - if keep_mask.sum().item() == 0: - raise RuntimeError("No sample is kept after filtering. Please check your data.") - - if keep_mask.sum().item() < keep_mask.numel(): - filtered_experiences = [] - filtered_rewards = [] - offset = 0 - for exp in experiences: - batch_size = len(exp.sequences) - exp_mask = keep_mask[offset:offset + batch_size] - offset += batch_size - if exp_mask.sum().item() == 0: - continue - - # Filter rewards for this batch - exp_rewards = exp.info["reward"] - filtered_rewards.append(exp_rewards[exp_mask.to(exp_rewards.device)]) - - if exp_mask.all(): - filtered_experiences.append(exp) - continue - - # Rebuild experience batch to keep shapes consistent (esp. pixel_values) - items = split_experience_batch(exp) - kept_items = [item for item, keep in zip(items, exp_mask.cpu().tolist()) if keep] - if not kept_items: - continue - filtered_experiences.append( - make_experience_batch(kept_items, packing_samples=self.packing_samples) - ) - - if not filtered_experiences: - raise RuntimeError("No sample is kept after filtering. 
Please check your data.") - - experiences = filtered_experiences - rewards = torch.cat(filtered_rewards) - + # Group normalization # # Normalize within groups # rewards = rewards.reshape(-1, config.n_samples_per_prompt).to("cuda") # rewards = (rewards - rewards.mean(-1, keepdim=True)) / (rewards.std(-1, keepdim=True) + 1e-9) diff --git a/lightrft/trainer/filter_weight/manager.py b/lightrft/trainer/filter_weight/manager.py index 146b04e2..264aa9ee 100644 --- a/lightrft/trainer/filter_weight/manager.py +++ b/lightrft/trainer/filter_weight/manager.py @@ -200,16 +200,20 @@ def apply_to_experiences( experiences: List, # List[ExperienceVL] metrics: SampleMetrics, apply_filter_to_mask: bool = True, - apply_filter_to_weights: bool = True + apply_filter_to_weights: bool = True, + handle_distributed: bool = True, + strategy=None ) -> Tuple[List, torch.Tensor]: """ Apply filtering and weighting to experiences. This method: 1. Computes filter mask - 2. Updates action_mask to exclude filtered samples (if apply_filter_to_mask=True) - 3. Computes loss weights - 4. Zeros out weights for filtered samples (if apply_filter_to_weights=True) + 2. Handles distributed training edge cases (e.g., all samples filtered) + 3. Updates action_mask to exclude filtered samples (if apply_filter_to_mask=True) + 4. Optionally updates exp.info["reward"] for filtered samples + 5. Computes loss weights + 6. 
Zeros out weights for filtered samples (if apply_filter_to_weights=True) :param experiences: List of Experience/ExperienceVL objects to process :type experiences: List @@ -219,14 +223,38 @@ def apply_to_experiences( :type apply_filter_to_mask: bool :param apply_filter_to_weights: If True, zero out weights for filtered samples :type apply_filter_to_weights: bool + :param handle_distributed: If True, handle distributed training edge cases + :type handle_distributed: bool + :param strategy: Strategy object for logging (optional) + :type strategy: Optional[Any] :return: Modified experiences and per-sample weights :rtype: Tuple[List, torch.Tensor] """ + import torch.distributed as dist + # Apply filters keep_mask = self.apply_filters(metrics, experiences) + # Handle distributed training edge cases + if handle_distributed: + is_distributed = dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1 + + if is_distributed and keep_mask.sum().item() == 0: + # All samples filtered on this rank - reset mask to avoid NCCL issues + if strategy is not None: + strategy.print( + "[Warning] FilterWeightManager: No sample kept after filtering on this rank; " + "skipping filtering this step to maintain synchronization." + ) + else: + warnings.warn( + "FilterWeightManager: No sample kept after filtering on this rank; " + "skipping filtering this step." 
+ ) + keep_mask = torch.ones_like(keep_mask, dtype=torch.bool) + # Update action masks if requested - if apply_filter_to_mask: + if apply_filter_to_mask and keep_mask.sum().item() < keep_mask.numel(): sample_idx = 0 for exp in experiences: batch_size = len(exp.sequences) @@ -237,6 +265,11 @@ def apply_to_experiences( if exp.action_mask is not None: exp.action_mask = exp.action_mask & batch_mask.unsqueeze(-1).to(exp.action_mask.device) + # Also mask rewards in exp.info for consistency with legacy behavior + if "reward" in exp.info: + exp_rewards = exp.info["reward"] + exp.info["reward"] = exp_rewards * batch_mask.to(exp_rewards.device).float() + sample_idx += batch_size # Compute weights @@ -357,7 +390,7 @@ def from_args(args, packing_samples: bool = False) -> FilterWeightManager: """ Build FilterWeightManager from training arguments. - :param args: Training arguments object + :param args: Training arguments object with filter/weight configuration :type args: Any :param packing_samples: Whether samples are packed :type packing_samples: bool @@ -376,66 +409,55 @@ def from_args(args, packing_samples: bool = False) -> FilterWeightManager: filters = [] # Response length filter (from overlong_buffer settings) - if getattr(args, "overlong_buffer", False): - expected_len = getattr(args, "max_new_tokens", 1024) - getattr(args, "overlong_buffer_len", 0) - buffer_len = getattr(args, "overlong_buffer_len", 0) + if args.overlong_buffer: + expected_len = args.max_new_tokens - args.overlong_buffer_len + buffer_len = args.overlong_buffer_len filters.append(ResponseLengthFilter(expected_length=expected_len, buffer_length=buffer_len)) # Reward value filter (for dynamic sampling) - if getattr(args, "dynamic_sampling", False) and getattr(args, "advantage_estimator", "") == "group_norm": - filters.append(RewardValueFilter(n_samples_per_prompt=getattr(args, "n_samples_per_prompt", 1))) + if args.dynamic_sampling and args.advantage_estimator == "group_norm": + 
filters.append(RewardValueFilter(n_samples_per_prompt=args.n_samples_per_prompt)) # Entropy filter - if getattr(args, "enable_entropy_filter", False): - filters.append( - EntropyFilter( - min_entropy=getattr(args, "min_entropy", None), max_entropy=getattr(args, "max_entropy", None) - ) - ) + if args.enable_entropy_filter: + filters.append(EntropyFilter(min_entropy=args.min_entropy, max_entropy=args.max_entropy)) # Build weights weights = [] # Response length weighting - if getattr(args, "enable_length_weighting", False): - weight = ResponseLengthWeighting(mode=getattr(args, "length_weight_mode", "inverse"), normalize=True) - coef = getattr(args, "length_weight_coef", 1.0) + if args.enable_length_weighting: + weight = ResponseLengthWeighting(mode=args.length_weight_mode, normalize=True) + coef = args.length_weight_coef weights.append((weight, coef)) # Entropy weighting - if getattr(args, "enable_entropy_weighting", False): + if args.enable_entropy_weighting: weight = EntropyWeighting( - mode=getattr(args, "entropy_weight_mode", "favor_high"), - temperature=getattr(args, "entropy_weight_temperature", 1.0), - normalize=True + mode=args.entropy_weight_mode, temperature=args.entropy_weight_temperature, normalize=True ) - coef = getattr(args, "entropy_weight_coef", 1.0) + coef = args.entropy_weight_coef weights.append((weight, coef)) # Difficulty weighting - if getattr(args, "enable_difficulty_weighting", False): - weight = DifficultyWeighting( - mode=getattr(args, "difficulty_weight_mode", "prioritized"), - alpha=getattr(args, "difficulty_alpha", 0.6), - normalize=True - ) - coef = getattr(args, "difficulty_weight_coef", 1.0) + if args.enable_difficulty_weighting: + weight = DifficultyWeighting(mode=args.difficulty_weight_mode, alpha=args.difficulty_alpha, normalize=True) + coef = args.difficulty_weight_coef weights.append((weight, coef)) # Staleness weighting - if getattr(args, "enable_staleness_weighting", False): - weight = 
StalenessWeighting(decay_factor=getattr(args, "staleness_decay_factor", 0.95), normalize=True) - coef = getattr(args, "staleness_weight_coef", 1.0) + if args.enable_staleness_weighting: + weight = StalenessWeighting(decay_factor=args.staleness_decay_factor, normalize=True) + coef = args.staleness_weight_coef weights.append((weight, coef)) # Build enable_metrics dict enable_metrics = { - "entropy": getattr(args, "compute_entropy", False) or getattr(args, "enable_entropy_filter", False) - or getattr(args, "enable_entropy_weighting", False), - "difficulty": getattr(args, "enable_difficulty_weighting", False), - "difficulty_mode": getattr(args, "difficulty_mode", "td_error"), - "staleness": getattr(args, "enable_staleness_weighting", False), - "staleness_mode": getattr(args, "staleness_mode", "linear"), + "entropy": args.compute_entropy or args.enable_entropy_filter or args.enable_entropy_weighting, + "difficulty": args.enable_difficulty_weighting, + "difficulty_mode": args.difficulty_mode, + "staleness": args.enable_staleness_weighting, + "staleness_mode": args.staleness_mode, } return FilterWeightManager( diff --git a/lightrft/trainer/filter_weight/metrics.py b/lightrft/trainer/filter_weight/metrics.py index 28e149af..763ea98b 100644 --- a/lightrft/trainer/filter_weight/metrics.py +++ b/lightrft/trainer/filter_weight/metrics.py @@ -303,12 +303,13 @@ def compute_all_metrics( # In practice, you'd need to add generation_step to _SamplesOutput pass + # NOTE: The following metrics have not been fully tested yet return SampleMetrics( - response_length=response_lengths, - entropy=entropy, - logit_kl=logit_kl, - difficulty=difficulty, - staleness=staleness, + response_length=response_lengths, # Not tested + entropy=entropy, # Not tested + logit_kl=logit_kl, # Not tested + difficulty=difficulty, # Not tested + staleness=staleness, # Not tested reward_value=rewards, n_samples_per_prompt=None, # Can be set externally micro_batch_size=None, # Can be set externally