-
Notifications
You must be signed in to change notification settings - Fork 11
refactor(sunjx): refactor loss-filter implementation #17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
Jiaxuan-Sun
wants to merge
10
commits into
opendilab:main
Choose a base branch
from
Jiaxuan-Sun:refactor/loss-filter
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
0ebcdbf
refactor(sunjx): refactor loss-filter for sample filtering and loss w…
Jiaxuan-Sun 11e81ac
refactor(sunjx): refactor loss-filter implementation
Jiaxuan-Sun 008c90a
refactor(sunjx): Unify the comment style
Jiaxuan-Sun ab61fef
Merge remote-tracking branch 'opendilab/main' into refactor/loss-filter
Jiaxuan-Sun 4d04e1d
refactor(sunjx): fix format/fcheck bugs
Jiaxuan-Sun a43ae21
feature(sunjx): fix dynamic_sampling bugs
Jiaxuan-Sun d0346d0
Merge branch 'main' into refactor/loss-filter
Jiaxuan-Sun 7d8dea4
refactor(sunjx): pass formt and fcheck
Jiaxuan-Sun a659c00
refactor(sunjx): pass format and fcheck check
Jiaxuan-Sun 97f8a92
refactor(sunjx): Organize the code
Jiaxuan-Sun File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,153 @@ | ||
| """ | ||
| Filter and Weight Module | ||
|
|
||
| Unified interface for sample filtering and loss weighting in RLHF. | ||
|
|
||
| This module provides a three-layer architecture for managing sample filtering | ||
| and loss weighting: | ||
|
|
||
| 1. **Metrics Layer**: Compute sample-level metrics (entropy, difficulty, staleness, etc.) | ||
| 2. **Filter Layer**: Filter samples based on metrics (keep/discard decisions) | ||
| 3. **Weight Layer**: Compute per-sample loss weights based on metrics | ||
|
|
||
| The FilterWeightManager provides a high-level API to orchestrate these components. | ||
|
|
||
| Example usage: | ||
| ```python | ||
| from lightrft.trainer.filter_weight import ( | ||
| FilterWeightManager, | ||
| ResponseLengthFilter, | ||
| DifficultyWeighting, | ||
| ) | ||
|
|
||
| # Create manager | ||
| manager = FilterWeightManager( | ||
| filters=[ResponseLengthFilter(max_length=1024)], | ||
| weights=[(DifficultyWeighting(mode="prioritized"), 1.0)], | ||
| enable_metrics={"difficulty": True} | ||
| ) | ||
|
|
||
| # Compute metrics | ||
| metrics = manager.compute_metrics(outputs) | ||
|
|
||
| # Apply to experiences | ||
| experiences, weights = manager.apply_to_experiences(experiences, metrics) | ||
| ``` | ||
| """ | ||
|
|
||
| # Metrics | ||
| from .metrics import ( | ||
| SampleMetrics, | ||
| MetricsComputer, | ||
| ) | ||
|
|
||
| # Filters | ||
| from .filters import ( | ||
| SampleFilter, | ||
| ResponseLengthFilter, | ||
| RewardValueFilter, | ||
| EntropyFilter, | ||
| DifficultyFilter, | ||
| CompositeFilter, | ||
| PercentileFilter, | ||
| ) | ||
|
|
||
| # Weights | ||
| from .weights import ( | ||
| LossWeighting, | ||
| ResponseLengthWeighting, | ||
| EntropyWeighting, | ||
| DifficultyWeighting, | ||
| StalenessWeighting, | ||
| RewardMagnitudeWeighting, | ||
| CompositeWeighting, | ||
| UniformWeighting, | ||
| ) | ||
|
|
||
| # Manager | ||
| from .manager import ( | ||
| FilterWeightManager, | ||
| FilterWeightManagerBuilder, | ||
| ) | ||
|
|
||
| __all__ = [ | ||
| # ========== Metrics ========== | ||
| "SampleMetrics", | ||
| "MetricsComputer", | ||
|
|
||
| # ========== Filters ========== | ||
| "SampleFilter", | ||
| "ResponseLengthFilter", | ||
| "RewardValueFilter", | ||
| "EntropyFilter", | ||
| "DifficultyFilter", | ||
| "CompositeFilter", | ||
| "PercentileFilter", | ||
|
|
||
| # ========== Weights ========== | ||
| "LossWeighting", | ||
| "ResponseLengthWeighting", | ||
| "EntropyWeighting", | ||
| "DifficultyWeighting", | ||
| "StalenessWeighting", | ||
| "RewardMagnitudeWeighting", | ||
| "CompositeWeighting", | ||
| "UniformWeighting", | ||
|
|
||
| # ========== Manager ========== | ||
| "FilterWeightManager", | ||
| "FilterWeightManagerBuilder", | ||
| ] | ||
|
|
||
|
|
||
| # Quick access functions for common use cases | ||
| def create_length_filter(max_length: int = 1024, **kwargs): | ||
| """ | ||
| Quick function to create a response length filter. | ||
|
|
||
| :param max_length: Maximum response length | ||
| :type max_length: int | ||
| :param kwargs: Additional arguments for ResponseLengthFilter | ||
| :type kwargs: dict | ||
| :return: ResponseLengthFilter instance | ||
| :rtype: ResponseLengthFilter | ||
| """ | ||
| return ResponseLengthFilter(max_length=max_length, **kwargs) | ||
|
|
||
|
|
||
| def create_difficulty_weighting(mode: str = "prioritized", alpha: float = 0.6, **kwargs): | ||
| """ | ||
| Quick function to create difficulty weighting. | ||
|
|
||
| :param mode: Weighting mode ("prioritized" or "curriculum") | ||
| :type mode: str | ||
| :param alpha: Prioritization exponent | ||
| :type alpha: float | ||
| :param kwargs: Additional arguments for DifficultyWeighting | ||
| :type kwargs: dict | ||
| :return: DifficultyWeighting instance | ||
| :rtype: DifficultyWeighting | ||
| """ | ||
| return DifficultyWeighting(mode=mode, alpha=alpha, **kwargs) | ||
|
|
||
|
|
||
| def create_manager_from_args(args, packing_samples: bool = False): | ||
| """ | ||
| Quick function to create FilterWeightManager from training arguments. | ||
|
|
||
| :param args: Training arguments | ||
| :type args: Any | ||
| :param packing_samples: Whether samples are packed | ||
| :type packing_samples: bool | ||
| :return: FilterWeightManager instance | ||
| :rtype: FilterWeightManager | ||
| """ | ||
| return FilterWeightManagerBuilder.from_args(args, packing_samples) | ||
|
|
||
|
|
||
| # Add convenience imports at module level | ||
| __all__.extend([ | ||
| "create_length_filter", | ||
| "create_difficulty_weighting", | ||
| "create_manager_from_args", | ||
| ]) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why was this added? Does it cause an error without it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is to prevent deadlock in distributed all-reduce operations.
After dynamic sampling, the set of keys in the status dictionary may differ across ranks (some ranks have keys like kl and ptx_loss, while others do not). The all_reduce(dict) operation calls dist.all_reduce for each key individually. If the keys or their order differ between ranks, the collective operations will be inconsistent, causing the process to hang.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个改动确实有效防止了 NCCL deadlock,但是,这里的实现有两个比较严重的隐患:
数学逻辑问题:如果 op="mean",对于缺失 key 的 Rank 默认补 0.0,这会把 0.0 计入分子,并除以 world_size。这会严重拉低该指标的真实均值(比如只有一张卡有 KL=1.0,4卡平均后变成了 0.25)。
架构与性能问题:all_reduce 作为底层通信原语,内部高频调用 all_gather_object (依赖 pickle) 会带来性能损耗;而且在底层强行补 0.0 掩盖了上游状态不对齐的问题。
建议的修改方向:
最好不要在底层 all_reduce 中做 key 的对齐。我们应该在调用 all_reduce 之前的业务层(比如 metrics logging 处),显式地初始化所有可能的 keys。对于被 filter 掉的 rank,可以传 0.0 并配合一个 valid_count 掩码,最后用 sum(values) / sum(valid_counts) 来计算准确的 mean。