refactor(sunjx): refactor loss-filter implementation #17

Open
Jiaxuan-Sun wants to merge 10 commits into opendilab:main from Jiaxuan-Sun:refactor/loss-filter

Jiaxuan-Sun commented Jan 1, 2026

Contributor

Add new lightrft/trainer/filter_weight/ module with:

  • metrics.py - Metrics computation layer (entropy, difficulty, staleness, etc.)
  • filters.py - Sample filtering layer (length, reward value, entropy, difficulty filters, etc.)
  • weights.py - Loss weighting layer (length, entropy, difficulty, staleness weightings, etc.)
  • manager.py - Unified management layer (FilterWeightManager)

Note: The dynamic sampling feature has been tested. Other components are reserved for future extension.

Comment thread lightrft/trainer/filter_weight/__init__.py Outdated
Comment thread lightrft/trainer/filter_weight/filters.py
Comment thread lightrft/trainer/filter_weight/__init__.py Outdated
Comment thread lightrft/trainer/filter_weight/__init__.py Outdated
puyuan1996 added the enhancement (New feature or request) and refactor (Cleanup, formatting, or restructuring of existing code) labels Jan 4, 2026
Comment thread lightrft/trainer/filter_weight/manager.py Outdated
Comment thread lightrft/trainer/filter_weight/metrics.py
ret = {}
for k in all_keys:
    ret[k] = self.all_reduce(data.get(k, 0.0), op)
return ret

Collaborator

Why was this added? Does it cause an error without it?

Contributor Author

This prevents a deadlock in distributed all-reduce operations.
After dynamic sampling, the set of keys in the status dictionary may differ across ranks (some ranks have keys like kl and ptx_loss, while others do not). The all_reduce(dict) operation calls dist.all_reduce once per key. If the keys or their order differ between ranks, the ranks issue mismatched collective calls and the process hangs.
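The hang described above can be illustrated without a process group. The following is a minimal sketch, not the PR's code: it treats each rank's status dictionary as the ordered list of collectives a per-key loop would issue and reports the first step at which the ranks disagree. The helper names and the example dictionaries are hypothetical.

```python
from itertools import zip_longest


def collective_sequence(status):
    """Keys in the order a per-key all-reduce loop would issue collectives."""
    return list(status.keys())


def first_mismatch(per_rank_status):
    """Return (step, keys_at_step) for the first step where ranks disagree.

    Returns None when every rank would issue the same keys in the same order.
    A missing call on a shorter rank shows up as None in keys_at_step.
    """
    sequences = [collective_sequence(s) for s in per_rank_status]
    for step, keys in enumerate(zip_longest(*sequences)):
        if len(set(keys)) > 1:
            return step, keys
    return None


# Hypothetical status dicts after dynamic sampling (values are made up).
rank0 = {"policy_loss": 0.3, "kl": 0.02, "ptx_loss": 1.1}
rank1 = {"policy_loss": 0.4}  # this rank never produced kl / ptx_loss

mismatch = first_mismatch([rank0, rank1])
aligned = first_mismatch([rank0, dict(rank0)])
```

Here rank 0 would enter a second collective for kl while rank 1 has already left the loop, which is the situation in which a real collective backend blocks waiting for the missing participant.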

Collaborator

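This change does effectively prevent the NCCL deadlock, but the implementation has two fairly serious hidden problems:

Mathematical correctness: with op="mean", a rank that is missing a key is padded with 0.0 by default, so 0.0 is counted in the numerator and the sum is divided by world_size. This severely drags down the true mean of the metric (for example, if only one GPU has KL=1.0, the average over 4 GPUs becomes 0.25).
Architecture and performance: all_reduce is a low-level communication primitive, and calling all_gather_object (which relies on pickle) inside it at high frequency adds performance overhead. Force-padding 0.0 at this low level also masks the upstream problem that the status is not aligned across ranks.

Suggested direction:

It is best not to align keys inside the low-level all_reduce. Instead, the business layer that runs before all_reduce is called (for example, where metrics are logged) should explicitly initialize all possible keys. A rank whose data was filtered out can pass 0.0 together with a valid_count mask, and the accurate mean is then computed as sum(values) / sum(valid_counts).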
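The arithmetic in this comment can be checked with a small sketch. It simulates four ranks in a single process rather than calling a real collective; in actual distributed code the two sums would presumably come from sum-type all-reduces over the value and the valid count. The function names are hypothetical, and returning None when no rank has the key is an assumption of this sketch, not something the review specifies.

```python
def zero_padded_mean(per_rank_status, key):
    """Mean as computed when missing keys are padded with 0.0 and divided by world size."""
    world_size = len(per_rank_status)
    return sum(s.get(key, 0.0) for s in per_rank_status) / world_size


def local_contribution(status, key):
    """What one rank would feed into the SUM reductions: (value, valid_count)."""
    if key in status:
        return float(status[key]), 1.0
    return 0.0, 0.0  # filtered rank: contributes nothing and is not counted


def masked_mean(per_rank_status, key):
    """sum(values) / sum(valid_counts), or None if no rank reported the key."""
    contributions = [local_contribution(s, key) for s in per_rank_status]
    value_sum = sum(value for value, _ in contributions)
    valid_sum = sum(count for _, count in contributions)
    return value_sum / valid_sum if valid_sum > 0 else None


# Only one of four ranks has a KL value, as in the example above.
ranks = [{"kl": 1.0}, {}, {}, {}]
biased = zero_padded_mean(ranks, "kl")   # 0.25
accurate = masked_mean(ranks, "kl")      # 1.0
```

Because every rank always contributes a (value, valid_count) pair for every key, the set and order of collectives stays identical across ranks, so the key alignment no longer has to happen inside the communication primitive.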

Comment thread lightrft/trainer/fast_exp_maker.py Outdated
# If no valid actions or base log-probs are empty, skip KL safely.
if ((experience.action_mask is not None and experience.action_mask.sum().item() == 0)
        or (base_action_log_probs is not None and base_action_log_probs.numel() == 0)):
    kl = torch.zeros_like(

Collaborator

Have these null-check branches actually been hit during testing? If it's null, we should probably just throw an error directly.

Contributor Author

Yes. A dimension-mismatch error was raised when action_mask summed to 0 or the base log-probs were empty (compute_approx_kl was entered with an empty base_action_log_probs), so these branches are actually hit in dynamic sampling and filtering scenarios.

Collaborator

If that's the case, we should probably figure out a way to avoid such issues during the upstream filter/weight stages (e.g., by filtering out these invalid batches early on), rather than just forcing the KL to 0 here. Setting it to 0 is more of a workaround and might mask underlying issues with the data flow or sampling logic.
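One way to read this suggestion is as a validity check applied before the KL step. The sketch below is illustrative only and assumes a simplified sample layout: each sample is a plain dict with list-valued action_mask and base_action_log_probs fields, whereas the real code works with tensors on an experience object. The function names are hypothetical.

```python
def has_valid_actions(sample):
    """True if the sample has at least one unmasked action and non-empty base log-probs."""
    mask = sample["action_mask"]
    base = sample["base_action_log_probs"]
    return sum(mask) > 0 and len(base) > 0


def split_valid(samples):
    """Drop unusable samples up front and report how many were dropped.

    Returning the dropped count lets the caller log or alert on it, so the
    upstream issue stays visible instead of being hidden by a zero KL.
    """
    valid = [s for s in samples if has_valid_actions(s)]
    return valid, len(samples) - len(valid)


# Hypothetical batch: one usable sample, one fully masked, one with empty log-probs.
batch = [
    {"action_mask": [1, 1, 0], "base_action_log_probs": [-0.1, -0.2, -0.3]},
    {"action_mask": [0, 0, 0], "base_action_log_probs": [-0.1, -0.2, -0.3]},
    {"action_mask": [1], "base_action_log_probs": []},
]
valid, dropped = split_valid(batch)
```

With this kind of check in the filter stage, the KL computation only ever sees samples it can handle, and an unexpectedly high dropped count points back at the sampling or filtering logic.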

Comment thread lightrft/trainer/fast_exp_maker.py Outdated
Comment thread lightrft/trainer/fast_exp_maker.py Outdated
Comment thread lightrft/trainer/fast_exp_maker.py Outdated
puyuan1996 mentioned this pull request Jan 21, 2026
1 task
@Jiaxuan-Sun

Contributor Author

This still needs to be tested.

Labels

enhancement: New feature or request
refactor: Cleanup, formatting, or restructuring of existing code.

2 participants