diff --git a/podium/datasets/iterator.py b/podium/datasets/iterator.py index b141f0c6..926a2fc2 100644 --- a/podium/datasets/iterator.py +++ b/podium/datasets/iterator.py @@ -257,7 +257,6 @@ def _create_batch(self, examples): for field in self._dataset.fields: if field.is_numericalizable and field.batch_as_matrix: # If this field is numericalizable, generate a possibly padded matrix - # the length to which all the rows are padded (or truncated) pad_length = Iterator._get_pad_length(field, examples) @@ -268,7 +267,7 @@ def _create_batch(self, examples): matrix = None # np.empty(shape=(n_rows, pad_length)) # non-sequential fields all have length = 1, no padding necessary - should_pad = True if field.is_sequential else False + should_pad = field.is_sequential for i, example in enumerate(examples): @@ -321,14 +320,14 @@ def _create_batch(self, examples): @staticmethod def _get_pad_length(field, examples): - if not field.is_sequential: - return 1 - # the fixed_length attribute of Field has priority over the max length # of all the examples in the batch if field.fixed_length is not None: return field.fixed_length + if not field.is_sequential: + return 1 + # if fixed_length is None, then return the maximum length of all the # examples in the batch def length_of_field(example): diff --git a/podium/storage/__init__.py b/podium/storage/__init__.py index ee192886..44a1c04f 100644 --- a/podium/storage/__init__.py +++ b/podium/storage/__init__.py @@ -2,7 +2,7 @@ from .example_factory import ExampleFactory, ExampleFormat from .field import Field, TokenizedField, MultilabelField, MultioutputField, \ - unpack_fields, LabelField + unpack_fields, LabelField, SentenceEmbeddingField from .resources.downloader import (BaseDownloader, SCPDownloader, HttpDownloader, SimpleHttpDownloader) from .resources.large_resource import LargeResource, SCPLargeResource @@ -21,6 +21,6 @@ __all__ = ["BaseDownloader", "SCPDownloader", "HttpDownloader", "SimpleHttpDownloader", "Field", "TokenizedField", 
"LabelField", "MultilabelField", "MultioutputField", - "unpack_fields", "LargeResource", "SCPLargeResource", + "unpack_fields", "LargeResource", "SCPLargeResource", "SentenceEmbeddingField", "VectorStorage", "BasicVectorStorage", "SpecialVocabSymbols", "Vocab", "ExampleFactory", "ExampleFormat", "TfIdfVectorizer"] diff --git a/podium/storage/field.py b/podium/storage/field.py index 794badaf..3cb0dd0a 100644 --- a/podium/storage/field.py +++ b/podium/storage/field.py @@ -2,6 +2,7 @@ import logging import itertools from collections import deque +from typing import Callable import numpy as np @@ -249,13 +250,13 @@ def __init__(self, If true, the output of the tokenizer is presumed to be a list of tokens and will be numericalized using the provided Vocab or custom_numericalize. For numericalizable fields, Iterator will generate batch fields containing - numpy matrices. + numpy matrices. - If false, the out of the tokenizer is presumed to be a custom datatype. - Posttokenization hooks aren't allowed to be added as they can't be called - on custom datatypes. For non-numericalizable fields, Iterator will generate - batch fields containing lists of these custom data type instances returned - by the tokenizer. + If false, the output of the tokenizer is presumed to be a custom datatype. + Posttokenization hooks aren't allowed to be added as they can't be called + on custom datatypes. For non-numericalizable fields, Iterator will generate + batch fields containing lists of these custom data type instances returned + by the tokenizer. custom_numericalize : callable The numericalization function that will be called if the field doesn't use a vocabulary. 
If using custom_numericalize and padding is @@ -666,7 +667,7 @@ def numericalize(self, data): _LOGGER.error(error_msg) raise ValueError(error_msg) - else: + elif not self.custom_numericalize: return None # raw data is just a string, so we need to wrap it into an iterable @@ -1005,6 +1006,50 @@ def _numericalize_tokens(self, tokens): return numericalize_multihot(tokens, token_numericalize, self.num_of_classes) +class SentenceEmbeddingField(Field): + """Field used for sentence-level multidimensional embeddings.""" + + def __init__(self, + name: str, + embedding_fn: Callable[[str], np.array], + embedding_size: int, + vocab=None, + is_target=False, + language='en', + allow_missing_data=False): + """ + Field used for sentence-level multidimensional embeddings. + + Parameters + ---------- + name: str + Field name, used for referencing data in the dataset. + embedding_fn: Callable[[str], np.array] + Callable that takes a string and returns a fixed-width embedding. + In case of missing data, this callable will be passed a None. + embedding_size: int + Width of the embedding. + vocab: Vocab + Vocab that will be updated with the sentences passed to this field. + Keep in mind that whole sentences will be passed to the vocab. + language: str + Language of the data. Not used in this field. + allow_missing_data: bool + Whether this field will allow the processing of missing data. 
+ """ + super().__init__(name, + custom_numericalize=embedding_fn, + tokenizer=None, + language=language, + vocab=vocab, + tokenize=False, + store_as_raw=True, + store_as_tokenized=False, + is_target=is_target, + fixed_length=embedding_size, + allow_missing_data=allow_missing_data) + + def numericalize_multihot(tokens, token_indexer, num_of_classes): active_classes = list(map(token_indexer, tokens)) multihot_encoding = np.zeros(num_of_classes, dtype=np.bool) diff --git a/test/storage/test_field.py b/test/storage/test_field.py index 49b11c1c..8b777e78 100644 --- a/test/storage/test_field.py +++ b/test/storage/test_field.py @@ -5,7 +5,7 @@ from mock import patch from podium.storage import Field, TokenizedField, MultilabelField, \ - Vocab, SpecialVocabSymbols, MultioutputField, LabelField + Vocab, SpecialVocabSymbols, MultioutputField, LabelField, SentenceEmbeddingField ONE_TO_FIVE = [1, 2, 3, 4, 5] @@ -689,36 +689,14 @@ def test_missing_values_default_sequential(): custom_numericalize=lambda x: hash(x), allow_missing_data=True) - _, data_missing = fld.preprocess(None)[0] _, data_exists = fld.preprocess("data_string")[0] - assert data_missing == (None, None) assert data_exists == (None, ["data_string"]) fld.finalize() - assert fld.numericalize(data_missing) is None assert np.all(fld.numericalize(data_exists) == np.array([hash("data_string")])) -def test_missing_values_custom_numericalize(): - fld = Field(name="test_field", - store_as_raw=True, - tokenize=False, - custom_numericalize=int, - allow_missing_data=True) - - _, data_missing = fld.preprocess(None)[0] - _, data_exists = fld.preprocess("404")[0] - - assert data_missing == (None, None) - assert data_exists == ("404", None) - - fld.finalize() - - assert fld.numericalize(data_missing) is None - assert np.all(fld.numericalize(data_exists) == np.array([404])) - - def test_missing_symbol_index_vocab(): vocab = Vocab() fld = Field(name="test_field", @@ -874,3 +852,25 @@ def test_label_field(): _, example = x[0] raw, _ 
= example assert label_field.numericalize(example) == vocab.stoi[raw] + + +def test_sentence_embedding_field(): + def mock_embedding_fn(sentence): + if sentence == "test_sentence": + return np.array([1, 2, 3, 4]) + + if sentence is None: + return np.zeros(4) + + field = SentenceEmbeddingField("test_field", + embedding_fn=mock_embedding_fn, + embedding_size=4, + allow_missing_data=True) + + (_, data), = field.preprocess("test_sentence") + numericalization_1 = field.numericalize(data) + assert np.all(numericalization_1 == np.array([1, 2, 3, 4])) + + (_, data), = field.preprocess(None) + numericalization_2 = field.numericalize(data) + assert np.all(numericalization_2 == np.zeros(4))