refactor(sunjx): refactor loss-filter implementation #17
ret = {}
for k in all_keys:
    ret[k] = self.all_reduce(data.get(k, 0.0), op)
return ret
Why was this added? Does it cause an error without it?
This is to prevent deadlock in distributed all-reduce operations.
After dynamic sampling, the set of keys in the status dictionary may differ across ranks (some ranks have keys like kl and ptx_loss, while others do not). The all_reduce(dict) operation calls dist.all_reduce for each key individually. If the keys or their order differ between ranks, the collective operations will be inconsistent, causing the process to hang.
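The hang described above comes from ranks issuing different sequences of collectives. Below is a minimal sketch of the alignment idea. It simulates ranks with plain dictionaries instead of real `torch.distributed` calls, and the helper name `aligned_call_plan` is hypothetical, not part of the PR:

```python
from typing import Dict, List, Tuple


def aligned_call_plan(
    per_rank_status: List[Dict[str, float]], fill: float = 0.0
) -> List[List[Tuple[str, float]]]:
    """Return, per simulated rank, the ordered (key, value) pairs it would reduce.

    Every rank walks the same sorted union of keys, so each rank would issue
    the same number of collectives in the same order.
    """
    all_keys = sorted(set().union(*(d.keys() for d in per_rank_status)))
    return [[(k, d.get(k, fill)) for k in all_keys] for d in per_rank_status]


# Rank 0 computed kl and ptx_loss; rank 1 had those samples filtered out.
ranks = [{"loss": 0.5, "kl": 1.0, "ptx_loss": 0.2}, {"loss": 0.7}]
plan = aligned_call_plan(ranks)
key_orders = [[k for k, _ in rank_plan] for rank_plan in plan]
```

In a real job the ranks would first have to agree on that key union through a collective of their own, which this sketch does not model.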
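This change does prevent the NCCL deadlock, but the implementation has two fairly serious hidden problems:
Math issue: with op="mean", a rank that is missing a key is padded with 0.0, so that 0.0 is added to the numerator and the sum is still divided by world_size. This drags the metric well below its true mean (for example, if only one GPU has KL=1.0, the 4-GPU average becomes 0.25).
Architecture and performance issue: all_reduce is a low-level communication primitive, and calling all_gather_object (which relies on pickle) inside it at high frequency adds overhead. Padding with 0.0 at this level also masks the upstream problem that state is not aligned across ranks.
Suggested direction:
It is better not to align keys inside the low-level all_reduce. Instead, the business layer that runs before all_reduce is called (for example, the metrics logging code) should explicitly initialize every possible key. A rank whose samples were filtered out can pass 0.0 together with a valid_count mask, and the accurate mean is then sum(values) / sum(valid_counts).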
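A minimal sketch of the suggested masked mean follows. It assumes a fixed key list and simulates the SUM all-reduce in pure Python; `METRIC_KEYS`, `pack_metrics`, `simulated_sum_reduce`, and `masked_mean` are hypothetical names used only for illustration:

```python
from typing import Dict, List, Optional

# Hypothetical: every metric key that any rank may report.
METRIC_KEYS = ["kl", "ptx_loss"]


def pack_metrics(status: Dict[str, float]) -> Dict[str, float]:
    """Give every rank the same keys: a value plus a 0/1 valid count per metric."""
    packed = {}
    for k in METRIC_KEYS:
        packed[k] = status.get(k, 0.0)
        packed[k + "_valid_count"] = 1.0 if k in status else 0.0
    return packed


def simulated_sum_reduce(per_rank: List[Dict[str, float]]) -> Dict[str, float]:
    """Stand-in for a SUM all-reduce: add each key across all ranks."""
    return {k: sum(d[k] for d in per_rank) for k in per_rank[0]}


def masked_mean(summed: Dict[str, float], key: str) -> Optional[float]:
    """sum(values) / sum(valid_counts), or None when no rank reported the key."""
    count = summed[key + "_valid_count"]
    return summed[key] / count if count > 0 else None


# Four ranks; only rank 0 has a KL value, as in the example above.
per_rank = [pack_metrics(s) for s in ({"kl": 1.0}, {}, {}, {})]
summed = simulated_sum_reduce(per_rank)
kl_masked = masked_mean(summed, "kl")      # divides by the one valid rank
kl_naive = summed["kl"] / len(per_rank)    # divides by world_size
```

Because every rank packs the same keys in the same order, the reduce needs no key negotiation, and a metric that no rank reported comes back as None instead of a misleading 0.0.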
# If no valid actions or base log-probs are empty, skip KL safely.
if ((experience.action_mask is not None and experience.action_mask.sum().item() == 0)
        or (base_action_log_probs is not None and base_action_log_probs.numel() == 0)):
    kl = torch.zeros_like(
Have these null-check branches actually been hit during testing? If it's null, we should probably just throw an error directly.
Yes. We hit a dimension-mismatch error when the action_mask was all zeros or the base log-probs were empty (compute_approx_kl was entered with an empty base_action_log_probs), so these branches are actually triggered in dynamic sampling and filtering scenarios.
If that's the case, we should probably figure out a way to avoid such issues during the upstream filter/weight stages (e.g., by filtering out these invalid batches early on), rather than just forcing the KL to 0 here. Setting it to 0 is more of a workaround and might mask underlying issues with the data flow or sampling logic.
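One way to read this suggestion is to drop invalid samples before the KL step and report how many were dropped. The sketch below uses plain lists in place of tensors; `drop_invalid_samples` and the dictionary layout are hypothetical and do not reflect the project's actual Experience type:

```python
from typing import Dict, List, Tuple

Sample = Dict[str, List[float]]


def drop_invalid_samples(samples: List[Sample]) -> Tuple[List[Sample], int]:
    """Keep samples with at least one valid action and non-empty base log-probs.

    Returns the kept samples and the number dropped, so the caller can log
    the count instead of silently forcing KL to 0 downstream.
    """
    kept = [
        s for s in samples
        if sum(s["action_mask"]) > 0 and len(s["base_action_log_probs"]) > 0
    ]
    return kept, len(samples) - len(kept)


batch = [
    {"action_mask": [1, 1, 0], "base_action_log_probs": [-0.1, -0.2, -0.3]},
    {"action_mask": [0, 0, 0], "base_action_log_probs": [-0.4, -0.5, -0.6]},
    {"action_mask": [1, 0, 0], "base_action_log_probs": []},
]
valid, num_dropped = drop_invalid_samples(batch)
```

Returning the dropped count keeps the data-flow problem visible in logs, which is the concern raised about the zero-KL workaround.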
Needs to be tested.
Add new lightrft/trainer/filter_weight/ module with:
- metrics.py: Metrics computation layer (entropy, difficulty, staleness, etc.)
- filters.py: Sample filtering layer (length, reward value, entropy, difficulty filters, etc.)
- weights.py: Loss weighting layer (length, entropy, difficulty, staleness weightings, etc.)
- manager.py: Unified management layer (FilterWeightManager)

Note: The dynamic sampling feature has been tested. Other components are reserved for future extension.