Description
It seems like the `num_tokens_since_fired` update mechanism is inconsistent across distributed training modes, leading to incorrect dead feature detection behavior.
Issue 1: `--distribute_modules=False` (DDP)
- Each replica maintains its own `num_tokens_since_fired` and `did_fire` counters
- No cross-rank synchronization occurs
- With 2 GPUs, when a feature fires on rank 0 but not on rank 1:
  - Only rank 0 resets its counter
  - Rank 1 keeps incrementing and may cross the threshold, so the feature gets penalized by the aux loss even though it fired
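The divergence above can be shown with a toy, non-distributed simulation. The counter names follow the issue; `step`, `TOKENS_PER_STEP`, and the numbers are illustrative, and the elementwise OR stands in for what a `dist.all_reduce(did_fire, op=ReduceOp.MAX)` would compute across ranks:

```python
# Toy simulation of Issue 1 (names and numbers are illustrative).
# Two DDP replicas track did_fire independently; a feature fires only on
# rank 0, so rank 1 keeps counting toward dead_feature_threshold.

WORLD_SIZE = 2
NUM_FEATURES = 3
TOKENS_PER_STEP = 100
DEAD_FEATURE_THRESHOLD = 150

# did_fire per rank for one step: feature 0 fires only on rank 0.
did_fire = [
    [True, False, False],   # rank 0
    [False, False, False],  # rank 1
]

def step(counters, fired):
    """Advance counters by one step's tokens, resetting fired features."""
    return [0 if f else c + TOKENS_PER_STEP for c, f in zip(counters, fired)]

# Without synchronization, each rank uses only its local did_fire.
local = [[0] * NUM_FEATURES for _ in range(WORLD_SIZE)]
for _ in range(2):
    local = [step(local[r], did_fire[r]) for r in range(WORLD_SIZE)]

# With synchronization: elementwise OR across ranks, i.e. what
# dist.all_reduce(did_fire, op=ReduceOp.MAX) would produce on each rank.
global_fire = [any(did_fire[r][j] for r in range(WORLD_SIZE))
               for j in range(NUM_FEATURES)]
synced = [[0] * NUM_FEATURES for _ in range(WORLD_SIZE)]
for _ in range(2):
    synced = [step(synced[r], global_fire) for r in range(WORLD_SIZE)]

# Rank 1 wrongly flags feature 0 as dead; after syncing it does not.
print(local[1][0] >= DEAD_FEATURE_THRESHOLD)   # True
print(synced[1][0] >= DEAD_FEATURE_THRESHOLD)  # False
```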
Issue 2: `--distribute_modules=True`
- In the model hook we gather a global batch but increment the counters with the local token count:

```python
def hook(...):
    ...
    dist.all_gather_into_tensor(world_outputs, outputs)
    outputs = world_outputs
    ...

...
N = tokens_mask.sum().item()
num_tokens_in_step += N
self.num_tokens_since_fired += num_tokens_in_step
```

- Each SAE processes `world_size * B` samples but only tracks `B`, so the effective `dead_feature_threshold` is multiplied by `world_size`.
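The mismatch above is simple arithmetic; a sketch with illustrative numbers (the variable names here are not the library's API):

```python
# Toy arithmetic for Issue 2 (illustrative numbers, not the library's API).
# With distribute_modules=True each SAE sees the gathered global batch of
# world_size * B tokens per step, but the counter only advances by the
# local B.

world_size = 4
B = 1_000                         # local tokens per step
dead_feature_threshold = 10_000_000

# Steps until a never-firing feature is flagged dead, as measured by the
# counter vs. by the tokens actually processed:
steps_by_counter = dead_feature_threshold / B                 # what the code does
steps_by_tokens = dead_feature_threshold / (world_size * B)   # what it should do

# Detection fires world_size times too late: by the time the counter
# reaches the threshold, world_size * dead_feature_threshold tokens have
# actually been processed.
effective_threshold = steps_by_counter * world_size * B
print(effective_threshold == world_size * dead_feature_threshold)  # True
```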
Proposed Fix
Implement proper counting and synchronization of global tokens and fired flags.
I believe this is not hard to handle, so I can take it on if appropriate, after some discussion of the implementation.