From 5cbc55d84b5a7cbf05a9cf020c468052e8d94d00 Mon Sep 17 00:00:00 2001 From: Ronglai Zuo <43442976+2000ZRL@users.noreply.github.com> Date: Sat, 21 Feb 2026 23:49:51 +0000 Subject: [PATCH] fix batched generation --- mGPT/archs/lm_multihead.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mGPT/archs/lm_multihead.py b/mGPT/archs/lm_multihead.py index 2839677..2bca1ac 100644 --- a/mGPT/archs/lm_multihead.py +++ b/mGPT/archs/lm_multihead.py @@ -620,7 +620,7 @@ def generate( idx_next_rhand = torch.argmax(probs_rhand, dim=1, keepdim=True) idx_next_lhand[finished] = idx_next_rhand[finished] = self.eos_idx - finished = torch.any(idx_next_body.squeeze(-1) == self.eos_idx, dim=-1) + finished = finished | (idx_next_body.squeeze(-1) == self.eos_idx) # append sampled index to the running sequence and continue decoder_input_ids = torch.cat((decoder_input_ids, idx_next_body), dim=1)