Looking at sparse_coder.py#L247, the code calls decode() on auxk_acts to get e_hat, but decode() also adds the decoder bias b_dec. This seems like a mistake: e, the residual, is the difference between sae_out and y, so b_dec is already accounted for there. As written, the aux loss pulls the dead latents' contribution towards e - b_dec rather than towards e, as is likely intended, and it probably also puts an unintended gradient on b_dec.
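To make the concern concrete, here is a minimal pure-Python sketch, not the repository's code. It assumes decode(acts) computes W_dec @ acts + b_dec and that the aux loss is a squared error between e_hat and e. The helper decode_no_bias and all numeric values are hypothetical and exist only for illustration.

```python
# Toy illustration only: names mirror the issue text, values are made up.

def matvec(W, a):
    """Multiply a (d_out x n_latents) matrix by a latent vector."""
    return [sum(w * x for w, x in zip(row, a)) for row in W]

def decode(acts, W_dec, b_dec):
    """Assumed behaviour of decode(): linear map plus the decoder bias."""
    return [v + b for v, b in zip(matvec(W_dec, acts), b_dec)]

def decode_no_bias(acts, W_dec):
    """Hypothetical bias-free decode for the auxiliary path."""
    return matvec(W_dec, acts)

def sq_err(a, b):
    """Sum of squared differences between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

W_dec = [[1.0, 0.0], [0.0, 1.0]]  # identity decoder keeps the arithmetic obvious
b_dec = [0.5, -0.25]
e = [0.25, 0.125]                 # residual left by the main reconstruction

# Dead-latent activations that reproduce e exactly through W_dec alone.
auxk_acts = [0.25, 0.125]

loss_with_bias = sq_err(decode(auxk_acts, W_dec, b_dec), e)
loss_no_bias = sq_err(decode_no_bias(auxk_acts, W_dec), e)

# With the bias included, zero loss requires W_dec @ acts == e - b_dec instead.
shifted_acts = [ei - bi for ei, bi in zip(e, b_dec)]
loss_shifted = sq_err(decode(shifted_acts, W_dec, b_dec), e)

# Gradient of the biased squared error w.r.t. b_dec is 2 * (e_hat - e).
e_hat = decode(auxk_acts, W_dec, b_dec)
grad_b_dec = [2 * (h - t) for h, t in zip(e_hat, e)]

print(loss_with_bias, loss_no_bias, loss_shifted, grad_b_dec)
```

In this toy setup the bias-free path reaches zero loss when the dead latents reproduce e, while the biased path only reaches zero loss at e - b_dec and sends a nonzero gradient into b_dec. Whether dropping the bias or detaching b_dec is the right fix depends on how the actual code defines e.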