diff --git a/mGPT/archs/lm_multihead.py b/mGPT/archs/lm_multihead.py index 2839677..2bca1ac 100644 --- a/mGPT/archs/lm_multihead.py +++ b/mGPT/archs/lm_multihead.py @@ -620,7 +620,7 @@ def generate( idx_next_rhand = torch.argmax(probs_rhand, dim=1, keepdim=True) idx_next_lhand[finished] = idx_next_rhand[finished] = self.eos_idx - finished = torch.any(idx_next_body.squeeze(-1) == self.eos_idx, dim=-1) + finished = finished | (idx_next_body.squeeze(-1) == self.eos_idx) # append sampled index to the running sequence and continue decoder_input_ids = torch.cat((decoder_input_ids, idx_next_body), dim=1)