Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion src/pie_modules/models/components/seq2seq_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from copy import copy
from typing import Any, Dict, List, Optional, Tuple

from torch import Tensor, nn
import torch
from torch import FloatTensor, LongTensor, Tensor, nn

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -31,6 +32,47 @@ def output_size(self) -> int:
return self.rnn.hidden_size


class ConcatenatedSequencesWrapper(nn.Module):
    """Wrapper that lets a module process several batch entries as one long sequence.

    The input tensor is expected to have the shape (batch_size, sequence_length, input_size).
    All batch entries that share the same value in `sequence_ids` are joined, in batch order,
    along the sequence dimension into a single sequence of length
    num_selected * sequence_length. That sequence is passed to the wrapped module as a batch of
    size one. The module output is then split back into chunks of sequence_length and written
    to the batch positions the chunks came from.

    Args:
        module: The module to wrap. It is called with a tensor of shape
            (1, num_selected * sequence_length, input_size) and is expected to return a tensor
            of shape (1, num_selected * sequence_length, module_output_size).
        module_output_size: Size of the last dimension of the output of `module`.
    """

    def __init__(self, module: nn.Module, module_output_size: int):
        super().__init__()
        self.module = module
        self.module_output_size = module_output_size

    def forward(
        self, values: FloatTensor, sequence_ids: LongTensor, *args, **kwargs
    ) -> FloatTensor:
        """Apply the wrapped module to each group of batch entries with the same sequence id.

        Args:
            values: Tensor of shape (batch_size, sequence_length, input_size).
            sequence_ids: Tensor of shape (batch_size,) that assigns each batch entry to a
                sequence. Entries with equal ids are processed together.
            *args: Additional positional arguments forwarded to the wrapped module.
            **kwargs: Additional keyword arguments forwarded to the wrapped module.

        Returns:
            Tensor of shape (batch_size, sequence_length, module_output_size) with the same
            dtype and device as `values`.
        """
        # new_zeros keeps dtype AND device of the input, so non-float32 inputs (e.g. float64
        # or half precision) do not end up in a float32 buffer
        results = values.new_zeros(values.size(0), values.size(1), self.module_output_size)
        for seq_idx in torch.unique(sequence_ids):
            # get values for the current sequence (from multiple batch entries)
            mask = sequence_ids == seq_idx
            # shape: (num_selected, sequence_length, input_size)
            selected_values = values[mask]
            # flatten the batch dimension: (num_selected * sequence_length, input_size)
            concatenated_sequence = selected_values.view(-1, selected_values.size(-1))
            # add a batch dimension of size one for the module call and remove it afterwards:
            # (num_selected * sequence_length, input_size)
            # -> (num_selected * sequence_length, output_size)
            processed_sequence = self.module(
                concatenated_sequence.unsqueeze(0), *args, **kwargs
            ).squeeze(0)
            # restore the batch dimension: (num_selected, sequence_length, output_size)
            reconstructed_sequence = processed_sequence.view(
                selected_values.size(0), selected_values.size(1), processed_sequence.size(-1)
            )
            # store the processed sequence back to the results tensor at the correct batch
            # indices; the cast is a no-op when the dtypes already match and avoids a dtype
            # mismatch in the masked assignment otherwise
            results[mask] = reconstructed_sequence.to(results.dtype)
        return results


def build_seq2seq_encoder(
config: Dict[str, Any], input_size: int
) -> Tuple[Optional[nn.Module], int]:
Expand All @@ -54,6 +96,9 @@ def build_seq2seq_encoder(
input_size = output_size

seq2seq_encoder = nn.Sequential(*modules)
elif seq2seq_encoder_type == "concatenate_sequences":
submodule, output_size = build_seq2seq_encoder(config["module"], input_size)
seq2seq_encoder = ConcatenatedSequencesWrapper(submodule, output_size)
elif seq2seq_encoder_type in RNN_TYPE2CLASS:
rnn_class = RNN_TYPE2CLASS[seq2seq_encoder_type]
seq2seq_encoder = RNNWrapper(rnn_class(input_size=input_size, batch_first=True, **config))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,13 +181,20 @@ def forward(
self, inputs: InputType, targets: Optional[TargetType] = None
) -> TokenClassifierOutput:
inputs_without_special_tokens_mask = {
k: v for k, v in inputs.items() if k != "special_tokens_mask"
k: v
for k, v in inputs.items()
if k != "special_tokens_mask" and not k.startswith("seq2seq_encoder_")
}
outputs = self.model(**inputs_without_special_tokens_mask)
sequence_output = outputs[0]

if self.seq2seq_encoder is not None:
sequence_output = self.seq2seq_encoder(sequence_output)
seq2seq_encoder_kwargs = {
k[len("seq2seq_encoder_") :]: v
for k, v in inputs.items()
if k.startswith("seq2seq_encoder_")
}
sequence_output = self.seq2seq_encoder(sequence_output, **seq2seq_encoder_kwargs)

sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def __init__(
tokenize_kwargs: Optional[Dict[str, Any]] = None,
pad_kwargs: Optional[Dict[str, Any]] = None,
log_precision_recall_metrics: bool = True,
inputs_key_document_indices: Optional[str] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -173,6 +174,7 @@ def __init__(
self.tokenize_kwargs = tokenize_kwargs or {}
self.pad_kwargs = pad_kwargs or {}
self.log_precision_recall_metrics = log_precision_recall_metrics
self.inputs_key_document_indices = inputs_key_document_indices

self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

Expand Down Expand Up @@ -230,6 +232,15 @@ def _post_prepare(self):

self.id_to_label = {v: k for k, v in self.label_to_id.items()}

def encode_inputs(
    self,
    documents: Sequence[DocumentType],
    show_progress: bool = False,
) -> Tuple[Sequence[TaskEncodingType], Sequence[DocumentType]]:
    """Encode the documents via the parent implementation with a fresh document counter.

    Args:
        documents: The documents to encode.
        show_progress: Forwarded unchanged to the parent implementation.

    Returns:
        Whatever the parent `encode_inputs` returns for these documents.
    """
    # every call starts counting documents from zero again; encode_input() reads this
    # counter into the task encoding metadata ("document_idx") and increments it
    self._doc_idx = 0
    return super().encode_inputs(documents, show_progress=show_progress)

def encode_input(
self,
document: TextDocument,
Expand Down Expand Up @@ -263,10 +274,11 @@ def encode_input(
TaskEncoding(
document=document,
inputs=tokenized_doc.metadata["tokenizer_encoding"],
metadata={"tokenized_document": tokenized_doc},
metadata={"tokenized_document": tokenized_doc, "document_idx": self._doc_idx},
)
)

self._doc_idx += 1
return task_encodings

def encode_target(
Expand Down Expand Up @@ -317,6 +329,11 @@ def collate(self, task_encodings: Sequence[TaskEncodingType]) -> ModelStepInputT
inputs = self.tokenizer.pad(
list_of_dicts2dict_of_lists(input_encodings), return_tensors="pt", **self.pad_kwargs
)
if self.inputs_key_document_indices is not None:
document_indices = torch.tensor(
[task_encoding.metadata["document_idx"] for task_encoding in task_encodings]
)
inputs[self.inputs_key_document_indices] = document_indices

if not task_encodings[0].has_targets:
return inputs, None
Expand Down
1 change: 1 addition & 0 deletions tests/models/test_simple_token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def taskmodule_config():
"tokenize_kwargs": None,
"pad_kwargs": None,
"log_precision_recall_metrics": True,
"inputs_key_document_indices": None,
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def taskmodule_config():
"tokenize_kwargs": None,
"pad_kwargs": None,
"log_precision_recall_metrics": True,
"inputs_key_document_indices": None,
}


Expand Down