diff --git a/pyabsa/tasks/AspectTermExtraction/prediction/aspect_extractor.py b/pyabsa/tasks/AspectTermExtraction/prediction/aspect_extractor.py index c05068e1..968a7bbf 100644 --- a/pyabsa/tasks/AspectTermExtraction/prediction/aspect_extractor.py +++ b/pyabsa/tasks/AspectTermExtraction/prediction/aspect_extractor.py @@ -548,6 +548,7 @@ def _extract(self, examples): ate_logits = torch.argmax(F.log_softmax(ate_logits, dim=2), dim=2) ate_logits = ate_logits.detach().cpu().numpy() label_ids = label_ids.to(DeviceTypeOption.CPU).numpy() + valid_ids = valid_ids.to(DeviceTypeOption.CPU).numpy() for i, i_ate_logits in enumerate(ate_logits): pred_iobs = [] sentence_res.append( @@ -561,7 +562,9 @@ def _extract(self, examples): ): break else: - pred_iobs.append(label_map.get(i_ate_logits[j], "O")) + # Only use predictions for the first BPE token of each original word + if valid_ids[i][j] == 1: + pred_iobs.append(label_map.get(i_ate_logits[j], "O")) ate_result = [] polarity = []