Skip to content
Open
66 changes: 66 additions & 0 deletions lightrft/strategy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,58 @@ class StrategyConfig:
# (bool): Use TensorBoard for logging, defaults to False
use_tensorboard: bool = False

# Filter and weight parameters
# (int): Maximum number of new tokens to generate, defaults to 1024
max_new_tokens: int = 1024

# Entropy-based filtering and weighting
# (bool): Enable entropy-based filtering, defaults to False
enable_entropy_filter: bool = False
# (Optional[float]): Minimum entropy threshold for filtering, defaults to None
min_entropy: Optional[float] = None
# (Optional[float]): Maximum entropy threshold for filtering, defaults to None
max_entropy: Optional[float] = None
# (bool): Compute entropy metrics, defaults to False
compute_entropy: bool = False
# (bool): Enable entropy-based loss weighting, defaults to False
enable_entropy_weighting: bool = False
# (str): Entropy weighting mode ("favor_high", "favor_low"), defaults to "favor_high"
entropy_weight_mode: str = "favor_high"
# (float): Temperature parameter for entropy weighting, defaults to 1.0
entropy_weight_temperature: float = 1.0
# (float): Coefficient for entropy weighting, defaults to 1.0
entropy_weight_coef: float = 1.0

# Length-based weighting
# (bool): Enable response length-based loss weighting, defaults to False
enable_length_weighting: bool = False
# (str): Length weighting mode ("inverse", "linear", "squared"), defaults to "inverse"
length_weight_mode: str = "inverse"
# (float): Coefficient for length weighting, defaults to 1.0
length_weight_coef: float = 1.0

# Difficulty-based weighting
# (bool): Enable difficulty-based loss weighting, defaults to False
enable_difficulty_weighting: bool = False
# (str): Difficulty weighting mode ("prioritized", "linear"), defaults to "prioritized"
difficulty_weight_mode: str = "prioritized"
# (str): Difficulty computation mode ("td_error", "abs_advantage"), defaults to "td_error"
difficulty_mode: str = "td_error"
# (float): Alpha parameter for prioritized difficulty weighting, defaults to 0.6
difficulty_alpha: float = 0.6
# (float): Coefficient for difficulty weighting, defaults to 1.0
difficulty_weight_coef: float = 1.0

# Staleness-based weighting
# (bool): Enable staleness-based loss weighting, defaults to False
enable_staleness_weighting: bool = False
# (float): Decay factor for staleness weighting, defaults to 0.95
staleness_decay_factor: float = 0.95
# (str): Staleness computation mode ("linear", "exponential"), defaults to "linear"
staleness_mode: str = "linear"
# (float): Coefficient for staleness weighting, defaults to 1.0
staleness_weight_coef: float = 1.0

# Additional arguments for backward compatibility
# (Dict[str, Any]): Extra arguments for backward compatibility, defaults to {}
extra_args: Dict[str, Any] = field(default_factory=dict)
Expand Down Expand Up @@ -310,6 +362,20 @@ def print_config_summary(self) -> None:
status = "Overridden" if current != default else "Default"
print(f" {attr}: {current} ({status})")

# Filter and Weight Parameters
print("\nFilter and Weight Parameters:")
for attr in [
'max_new_tokens', 'enable_entropy_filter', 'min_entropy', 'max_entropy', 'compute_entropy',
'enable_entropy_weighting', 'entropy_weight_mode', 'entropy_weight_temperature', 'entropy_weight_coef',
'enable_length_weighting', 'length_weight_mode', 'length_weight_coef', 'enable_difficulty_weighting',
'difficulty_weight_mode', 'difficulty_mode', 'difficulty_alpha', 'difficulty_weight_coef',
'enable_staleness_weighting', 'staleness_decay_factor', 'staleness_mode', 'staleness_weight_coef'
]:
current = getattr(self, attr)
default = getattr(default_config, attr)
status = "Overridden" if current != default else "Default"
print(f" {attr}: {current} ({status})")

# extra_args
if self.extra_args:
print("\nExtra Parameters (extra_args):")
Expand Down
11 changes: 11 additions & 0 deletions lightrft/strategy/strategy_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,17 @@ def all_reduce(self,
"""
assert op in ("mean", "max", "sum")
if isinstance(data, dict):
# Ensure consistent key order and presence across ranks to avoid NCCL hangs.
if dist.is_available() and dist.is_initialized() and self.world_size > 1:
local_keys = list(data.keys())
gathered_keys = [None for _ in range(self.world_size)]
dist.all_gather_object(gathered_keys, local_keys)
all_keys = sorted({k for keys in gathered_keys for k in keys})
ret = {}
for k in all_keys:
ret[k] = self.all_reduce(data.get(k, 0.0), op)
return ret

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this added? Does it cause an error without it?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to prevent deadlock in distributed all-reduce operations.
After dynamic sampling, the set of keys in the status dictionary may differ across ranks (some ranks have keys like kl and ptx_loss, while others do not). The all_reduce(dict) operation calls dist.all_reduce for each key individually. If the keys or their order differ between ranks, the collective operations will be inconsistent, causing the process to hang.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个改动确实有效防止了 NCCL deadlock,但是,这里的实现有两个比较严重的隐患:

数学逻辑问题:如果 op="mean",对于缺失 key 的 Rank 默认补 0.0,这会把 0.0 计入分子,并除以 world_size。这会严重拉低该指标的真实均值(比如只有一张卡有 KL=1.0,4卡平均后变成了 0.25)。
架构与性能问题:all_reduce 作为底层通信原语,内部高频调用 all_gather_object (依赖 pickle) 会带来性能损耗;而且在底层强行补 0.0 掩盖了上游状态不对齐的问题。

建议的修改方向:

最好不要在底层 all_reduce 中做 key 的对齐。我们应该在调用 all_reduce 之前的业务层(比如 metrics logging 处),显式地初始化所有可能的 keys。对于被 filter 掉的 rank,可以传 0.0 并配合一个 valid_count 掩码,最后用 sum(values) / sum(valid_counts) 来计算准确的 mean。


ret = {}
for k, v in data.items():
ret[k] = self.all_reduce(v, op)
Expand Down
81 changes: 52 additions & 29 deletions lightrft/trainer/fast_exp_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from copy import deepcopy

import torch
import torch.distributed as dist
import numpy as np
from PIL import Image
from easydict import EasyDict
Expand All @@ -50,10 +51,12 @@
ExperienceVL,
SamplesVL,
)
from lightrft.trainer.replay_buffer_utils import make_experience_batch, split_experience_batch

from lightrft.utils.remote_rm_utils import remote_rm_fn
from lightrft.utils import Timer, get_current_device
from .utils import RunningMoments, compute_clip_fraction, get_cpgd_advantages_returns, fire_sampling
from .filter_weight import FilterWeightManagerBuilder
from .image_utils import normalize_images, get_images_num
from .video_utils import normalize_videos, get_videos_num

Expand Down Expand Up @@ -923,6 +926,19 @@ def __init__(self, *args, packing_samples: bool = False, processor=None, **kwarg
packing_samples=self.packing_samples,
)

# Initialize filter-weight manager
self.filter_weight_manager = self._init_filter_weight_manager()

def _init_filter_weight_manager(self):
"""
Initialize filter-weight manager from strategy args.

:return: FilterWeightManager instance
:rtype: FilterWeightManager
"""
args = self.strategy.args
return FilterWeightManagerBuilder.from_args(args, packing_samples=self.packing_samples)

# ========================================================================
# Public API Methods
# ========================================================================
Expand Down Expand Up @@ -1006,12 +1022,25 @@ def make_experience_list(

# ========== Stage 3: Model Inference ==========
Timer.start(' make_experience')
experiences = self._make_experience_list_by_model(all_samples)
experiences, outputs = self._make_experience_list_by_model(all_samples)
Timer.stop(' make_experience')

# ========== Stage 4: Shard-Parallel Postprocessing ==========
experiences = self.strategy.sp_data_processor.postprocess(experiences)

# ========== Stage 4.5: Apply Filter-Weight Framework ==========
if self.filter_weight_manager is not None and (
self.filter_weight_manager.filters or self.filter_weight_manager.weights
):
# Compute metrics from outputs
current_step = getattr(self.strategy, "global_step", None)
metrics = self.filter_weight_manager.compute_metrics(outputs, current_step=current_step)

# Apply filters and weights (with distributed training support)
experiences, _ = self.filter_weight_manager.apply_to_experiences(
experiences, metrics, strategy=self.strategy
)

# ========== Stage 5: Reward Processing ==========
experiences, rewards = self._process_experiences( # GRPO's -mean / std operation is performed in this method
experiences, generate_kwargs.get("max_new_tokens", 1024)
Expand Down Expand Up @@ -1406,21 +1435,25 @@ def _process_experiences(
config = self.strategy.config
rewards = torch.cat([exp.info["reward"] for exp in experiences])

# ========== Overlong Sequence Penalty ==========
if config.overlong_buffer:
expected_len = max_new_tokens - config.overlong_buffer_len
actual_lens = torch.cat([exp.action_mask.sum(dim=1) for exp in experiences])
exceed_len = actual_lens - expected_len

# Penalty: clamp(-exceed_len / buffer_len * penalty_factor, max=0)
penalty = torch.clamp(
-exceed_len / config.overlong_buffer_len * config.overlong_buffer_penalty_factor, max=0.0
)
rewards += penalty

# ========== Dynamic Sampling Warning ==========
if config.dynamic_sampling and config.advantage_estimator in ["rloo", "reinforce_baseline"]:
warnings.warn(f"dynamic_sampling not implemented for {config.advantage_estimator}, ignoring", UserWarning)
if config.dynamic_sampling:
if config.advantage_estimator in ["rloo", "reinforce_baseline"]:
warnings.warn(
f"dynamic_sampling not implemented for {config.advantage_estimator}, ignoring", UserWarning
)
elif config.advantage_estimator in ["group_norm", "grpo"]:
# Check if filter_weight_manager is properly configured
from .filter_weight import RewardValueFilter
has_reward_filter = (
self.filter_weight_manager is not None and self.filter_weight_manager.filters
and any(isinstance(f, RewardValueFilter) for f in self.filter_weight_manager.filters)
)
if not has_reward_filter:
warnings.warn(
"dynamic_sampling is enabled but FilterWeightManager is not configured with "
"RewardValueFilter. Dynamic sampling will not be applied. "
"Please ensure filter_weight_manager is properly initialized.", UserWarning
)

# ========== Advantage Estimator-Specific Shaping ==========
if config.advantage_estimator == "rloo":
Expand All @@ -1439,18 +1472,7 @@ def _process_experiences(
return experiences, rewards

elif config.advantage_estimator in ["group_norm", "grpo"]:
# Group normalization with optional dynamic filtering
if config.dynamic_sampling:
step_size = config.n_samples_per_prompt // config.micro_train_batch_size
for i in range(0, len(experiences), step_size):
chunk = experiences[i:i + step_size]
chunk_rewards = torch.cat([exp.info["reward"] for exp in chunk])

# Filter out degenerate cases (all 0s or all 1s)
if torch.all(chunk_rewards == 0) or torch.all(chunk_rewards == 1):
for exp in chunk:
exp.action_mask = torch.zeros_like(exp.action_mask, dtype=torch.bool)

# Group normalization
# # Normalize within groups
# rewards = rewards.reshape(-1, config.n_samples_per_prompt).to("cuda")
# rewards = (rewards - rewards.mean(-1, keepdim=True)) / (rewards.std(-1, keepdim=True) + 1e-9)
Expand Down Expand Up @@ -1592,7 +1614,7 @@ def _compute_advantages_and_returns(
def _make_experience_list_by_model(
self,
all_samples: List[Union[Samples, SamplesVL]],
) -> List[Union[Experience, ExperienceVL]]:
) -> Tuple[List[Union[Experience, ExperienceVL]], List[_SamplesOutput]]:
"""
Batch forward pass through all models to create experiences.

Expand Down Expand Up @@ -1670,7 +1692,8 @@ def _make_experience_list_by_model(
self.reward_engine.compute_rewards(outputs, vlm_mode, device)

# ========== Stage 5: Assemble Experiences ==========
return [self._pack_experience(output, vlm_mode) for output in outputs]
experiences = [self._pack_experience(output, vlm_mode) for output in outputs]
return experiences, outputs

def _preprocess_sample(
self,
Expand Down
153 changes: 153 additions & 0 deletions lightrft/trainer/filter_weight/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
"""
Filter and Weight Module

Unified interface for sample filtering and loss weighting in RLHF.

This module provides a three-layer architecture for managing sample filtering
and loss weighting:

1. **Metrics Layer**: Compute sample-level metrics (entropy, difficulty, staleness, etc.)
2. **Filter Layer**: Filter samples based on metrics (keep/discard decisions)
3. **Weight Layer**: Compute per-sample loss weights based on metrics

The FilterWeightManager provides a high-level API to orchestrate these components.

Example usage:
```python
from lightrft.trainer.filter_weight import (
FilterWeightManager,
ResponseLengthFilter,
DifficultyWeighting,
)

# Create manager
manager = FilterWeightManager(
filters=[ResponseLengthFilter(max_length=1024)],
weights=[(DifficultyWeighting(mode="prioritized"), 1.0)],
enable_metrics={"difficulty": True}
)

# Compute metrics
metrics = manager.compute_metrics(outputs)

# Apply to experiences
experiences, weights = manager.apply_to_experiences(experiences, metrics)
```
"""

# Metrics
from .metrics import (
SampleMetrics,
MetricsComputer,
)

# Filters
from .filters import (
SampleFilter,
ResponseLengthFilter,
RewardValueFilter,
EntropyFilter,
DifficultyFilter,
CompositeFilter,
PercentileFilter,
)

# Weights
from .weights import (
LossWeighting,
ResponseLengthWeighting,
EntropyWeighting,
DifficultyWeighting,
StalenessWeighting,
RewardMagnitudeWeighting,
CompositeWeighting,
UniformWeighting,
)

# Manager
from .manager import (
FilterWeightManager,
FilterWeightManagerBuilder,
)

__all__ = [
# ========== Metrics ==========
"SampleMetrics",
"MetricsComputer",

# ========== Filters ==========
"SampleFilter",
"ResponseLengthFilter",
"RewardValueFilter",
"EntropyFilter",
"DifficultyFilter",
"CompositeFilter",
"PercentileFilter",

# ========== Weights ==========
"LossWeighting",
"ResponseLengthWeighting",
"EntropyWeighting",
"DifficultyWeighting",
"StalenessWeighting",
"RewardMagnitudeWeighting",
"CompositeWeighting",
"UniformWeighting",

# ========== Manager ==========
"FilterWeightManager",
"FilterWeightManagerBuilder",
]


# Quick access functions for common use cases
def create_length_filter(max_length: int = 1024, **kwargs):
"""
Quick function to create a response length filter.

:param max_length: Maximum response length
:type max_length: int
:param kwargs: Additional arguments for ResponseLengthFilter
:type kwargs: dict
:return: ResponseLengthFilter instance
:rtype: ResponseLengthFilter
"""
return ResponseLengthFilter(max_length=max_length, **kwargs)


def create_difficulty_weighting(mode: str = "prioritized", alpha: float = 0.6, **kwargs):
"""
Quick function to create difficulty weighting.

:param mode: Weighting mode ("prioritized" or "curriculum")
:type mode: str
:param alpha: Prioritization exponent
:type alpha: float
:param kwargs: Additional arguments for DifficultyWeighting
:type kwargs: dict
:return: DifficultyWeighting instance
:rtype: DifficultyWeighting
"""
return DifficultyWeighting(mode=mode, alpha=alpha, **kwargs)


def create_manager_from_args(args, packing_samples: bool = False):
"""
Quick function to create FilterWeightManager from training arguments.

:param args: Training arguments
:type args: Any
:param packing_samples: Whether samples are packed
:type packing_samples: bool
:return: FilterWeightManager instance
:rtype: FilterWeightManager
"""
return FilterWeightManagerBuilder.from_args(args, packing_samples)


# Add convenience imports at module level
__all__.extend([
"create_length_filter",
"create_difficulty_weighting",
"create_manager_from_args",
])
Loading
Loading