From 530ce9b009116c6ac12d48c273cd25de69ce6f3b Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 4 Mar 2025 18:59:14 +0100 Subject: [PATCH 1/6] implement ConcatenatedSequencesWrapper --- .../models/components/seq2seq_encoder.py | 41 ++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/src/pie_modules/models/components/seq2seq_encoder.py b/src/pie_modules/models/components/seq2seq_encoder.py index 52866be9b..c470f30b5 100644 --- a/src/pie_modules/models/components/seq2seq_encoder.py +++ b/src/pie_modules/models/components/seq2seq_encoder.py @@ -2,7 +2,8 @@ from copy import copy from typing import Any, Dict, List, Optional, Tuple -from torch import Tensor, nn +import torch +from torch import FloatTensor, LongTensor, Tensor, nn logger = logging.getLogger(__name__) @@ -31,6 +32,41 @@ def output_size(self) -> int: return self.rnn.hidden_size +class ConcatenatedSequencesWrapper(nn.Module): + """Wrapper for a module that processes concatenated sequences. + + The input tensor is expected to have the shape (batch_size, sequence_length, input_size) and + multiple sequences are concatenated along the batch dimension. The module is expected to + process the concatenated sequences and return the processed sequences with the same shape. The + processed sequences are then separated back to the original sequences and returned as the + output tensor. + """ + + def __init__(self, module: nn.Module, module_output_size: int): + super().__init__() + self.module = module + self.module_output_size = module_output_size + + def forward(self, values: FloatTensor, sequence_ids: LongTensor, *args, **kwargs) -> Tensor: + results = torch.zeros( + values.size(0), values.size(1), self.module_output_size, device=values.device + ) + for seq_idx in torch.unique(sequence_ids): + # get values for the current sequence (from multiple batch entries) + mask = sequence_ids == seq_idx + selected_values = values[mask] + # flatten the batch dimension + concatenated_sequence = selected_values.view(-1, selected_values.size(-1)) + processed_sequence = self.module(concatenated_sequence.unsqueeze(0), *args, **kwargs) + # restore the batch dimension + reconstructed_sequence = processed_sequence.view( + selected_values.size(), processed_sequence.size(-1) + ) + # store the processed sequence back to the results tensor at the correct batch indices + results[mask] = reconstructed_sequence + return results + + def build_seq2seq_encoder( config: Dict[str, Any], input_size: int ) -> Tuple[Optional[nn.Module], int]: @@ -54,6 +90,9 @@ def build_seq2seq_encoder( input_size = output_size seq2seq_encoder = nn.Sequential(*modules) + elif seq2seq_encoder_type == "concatenate_sequences": + submodule, output_size = build_seq2seq_encoder(config["module"], input_size) + seq2seq_encoder = ConcatenatedSequencesWrapper(submodule, output_size) elif seq2seq_encoder_type in RNN_TYPE2CLASS: rnn_class = RNN_TYPE2CLASS[seq2seq_encoder_type] seq2seq_encoder = RNNWrapper(rnn_class(input_size=input_size, batch_first=True, **config)) From 34ae14d8c4b75a5f63044966460da2300c9eaf1f Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 4 Mar 2025 18:59:57 +0100 Subject: [PATCH 2/6] handle seq2seq_encoder kwargs in forward --- ...ken_classification_with_seq2seq_encoder_and_crf.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py b/src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py index 745a6e1b1..64d7a55c0 100644 --- a/src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py +++ b/src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py @@ -181,13 +181,20 @@ def forward( self, inputs: InputType, targets: Optional[TargetType] = None ) -> TokenClassifierOutput: inputs_without_special_tokens_mask = { - k: v for k, v in inputs.items() if k != "special_tokens_mask" + k: v + for k, v in inputs.items() + if k != "special_tokens_mask" and not k.startswith("seq2seq_encoder_") } outputs = self.model(**inputs_without_special_tokens_mask) sequence_output = outputs[0] if self.seq2seq_encoder is not None: - sequence_output = self.seq2seq_encoder(sequence_output) + seq2seq_encoder_kwargs = { + k[len("seq2seq_encoder_") :]: v + for k, v in inputs.items() + if k.startswith("seq2seq_encoder_") + } + sequence_output = self.seq2seq_encoder(sequence_output, **seq2seq_encoder_kwargs) sequence_output = self.dropout(sequence_output) logits = self.classifier(sequence_output) From 0b5b2cb7a515d93dc33279537eb522e92dc0c298 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 4 Mar 2025 19:00:44 +0100 Subject: [PATCH 3/6] implement parameter inputs_key_document_indices --- ...span_extraction_by_token_classification.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py b/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py index 7a457ebed..f8fcdafc3 100644 --- a/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py +++ b/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py @@ -160,6 +160,7 @@ def __init__( tokenize_kwargs: Optional[Dict[str, Any]] = None, pad_kwargs: Optional[Dict[str, Any]] = None, log_precision_recall_metrics: bool = True, + inputs_key_document_indices: Optional[str] = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -173,6 +174,7 @@ def __init__( self.tokenize_kwargs = tokenize_kwargs or {} self.pad_kwargs = pad_kwargs or {} self.log_precision_recall_metrics = log_precision_recall_metrics + self.inputs_key_document_indices = inputs_key_document_indices self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) @@ -230,6 +232,15 @@ def _post_prepare(self): self.id_to_label = {v: k for k, v in self.label_to_id.items()} + def encode_inputs( + self, + documents: Sequence[DocumentType], + show_progress: bool = False, + ) -> Tuple[Sequence[TaskEncodingType], Sequence[DocumentType]]: + self._doc_idx = 0 + result = super().encode_inputs(documents, show_progress=show_progress) + return result + def encode_input( self, document: TextDocument, @@ -263,10 +274,11 @@ def encode_input( TaskEncoding( document=document, inputs=tokenized_doc.metadata["tokenizer_encoding"], - metadata={"tokenized_document": tokenized_doc}, + metadata={"tokenized_document": tokenized_doc, "document_idx": self._doc_idx}, ) ) + self._doc_idx += 1 return task_encodings def encode_target( @@ -317,6 +329,11 @@ def collate(self, task_encodings: Sequence[TaskEncodingType]) -> ModelStepInputT inputs = self.tokenizer.pad( list_of_dicts2dict_of_lists(input_encodings), return_tensors="pt", **self.pad_kwargs ) + if self.inputs_key_document_indices is not None: + document_indices = torch.tensor( + [task_encoding.metadata["document_idx"] for task_encoding in task_encodings] + ) + inputs[self.inputs_key_document_indices] = document_indices if not task_encodings[0].has_targets: return inputs, None From 26b67548042f61873b001c949ea6a5fb961c58fd Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 4 Mar 2025 19:10:19 +0100 Subject: [PATCH 4/6] fix tests --- tests/models/test_simple_token_classification.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/test_simple_token_classification.py b/tests/models/test_simple_token_classification.py index bce5f5f1d..66adec89c 100644 --- a/tests/models/test_simple_token_classification.py +++ b/tests/models/test_simple_token_classification.py @@ -34,6 +34,7 @@ def taskmodule_config(): "tokenize_kwargs": None, "pad_kwargs": None, "log_precision_recall_metrics": True, + "inputs_key_document_indices": None, } From e871aced225d14724f1ee3fdfc809634d899cad1 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 4 Mar 2025 19:15:00 +0100 Subject: [PATCH 5/6] fix tests --- .../test_token_classification_with_seq2seq_encoder_and_crf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/test_token_classification_with_seq2seq_encoder_and_crf.py b/tests/models/test_token_classification_with_seq2seq_encoder_and_crf.py index 70aa02044..e83ed9c9b 100644 --- a/tests/models/test_token_classification_with_seq2seq_encoder_and_crf.py +++ b/tests/models/test_token_classification_with_seq2seq_encoder_and_crf.py @@ -35,6 +35,7 @@ def taskmodule_config(): "tokenize_kwargs": None, "pad_kwargs": None, "log_precision_recall_metrics": True, + "inputs_key_document_indices": None, } From a91461483d763495160c31eaba0c3b58ee0cde20 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Thu, 6 Mar 2025 21:49:15 +0100 Subject: [PATCH 6/6] fix reshaping --- .../models/components/seq2seq_encoder.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/pie_modules/models/components/seq2seq_encoder.py b/src/pie_modules/models/components/seq2seq_encoder.py index c470f30b5..8f79fc790 100644 --- a/src/pie_modules/models/components/seq2seq_encoder.py +++ b/src/pie_modules/models/components/seq2seq_encoder.py @@ -47,20 +47,26 @@ def __init__(self, module: nn.Module, module_output_size: int): self.module = module self.module_output_size = module_output_size - def forward(self, values: FloatTensor, sequence_ids: LongTensor, *args, **kwargs) -> Tensor: + def forward( + self, values: FloatTensor, sequence_ids: LongTensor, *args, **kwargs + ) -> FloatTensor: results = torch.zeros( values.size(0), values.size(1), self.module_output_size, device=values.device ) for seq_idx in torch.unique(sequence_ids): # get values for the current sequence (from multiple batch entries) mask = sequence_ids == seq_idx + # shape: (num_selected, sequence_length, input_size) selected_values = values[mask] # flatten the batch dimension concatenated_sequence = selected_values.view(-1, selected_values.size(-1)) - processed_sequence = self.module(concatenated_sequence.unsqueeze(0), *args, **kwargs) - # restore the batch dimension + # (num_selected * sequence_length, input_size) -> (num_selected * sequence_length, output_size) + processed_sequence = self.module( + concatenated_sequence.unsqueeze(0), *args, **kwargs + ).squeeze(0) + # restore the batch dimension: (num_selected, sequence_length, output_size) reconstructed_sequence = processed_sequence.view( - selected_values.size(), processed_sequence.size(-1) + selected_values.size(0), selected_values.size(1), processed_sequence.size(-1) ) # store the processed sequence back to the results tensor at the correct batch indices results[mask] = reconstructed_sequence