Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ tabulate = "^0.9"
# for TokenClassificationModelWithSeq2SeqEncoderAndCrf
pytorch-crf = ">=0.7.2"
# for rouge metric (tests only) and for NltkSentenceSplitter
nltk = "^3.8.1"
nltk = "^3.9.1"
# for NltkSentenceSplitter
flair = "^0.13.1"
# for SpansViaRelationMerger
Expand Down
15 changes: 12 additions & 3 deletions src/pie_modules/document/processing/sentence_splitter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import os
from typing import TypeVar

from pie_modules.annotations import LabeledSpan
Expand Down Expand Up @@ -30,8 +31,9 @@ def __init__(
self,
partition_layer_name: str = "labeled_partitions",
text_field_name: str = "text",
sentencizer_url: str = "tokenizers/punkt/PY3/english.pickle",
language: str = "english",
inplace: bool = True,
sentencizer_url: str | None = None,
):
try:
import nltk
Expand All @@ -41,12 +43,19 @@ def __init__(
"You can install NLTK with `pip install nltk`."
)

if sentencizer_url is not None:
logger.warning(
"The 'sentencizer_url' argument is deprecated. Please use 'language' instead."
)
if sentencizer_url[-7:] == ".pickle":
language = os.path.split(sentencizer_url[:-7])[-1]

self.partition_layer_name = partition_layer_name
self.text_field_name = text_field_name
self.inplace = inplace
# download the NLTK Punkt tokenizer model
nltk.download("punkt")
self.sentencizer = nltk.data.load(sentencizer_url)
nltk.download("punkt_tab")
self.sentencizer = nltk.tokenize.PunktTokenizer(language)

def __call__(self, document: D) -> D:
if not self.inplace:
Expand Down
29 changes: 29 additions & 0 deletions tests/document/processing/test_sentence_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,35 @@ def test_nltk_sentence_splitter(caplog, inplace):
assert str(doc.labeled_partitions[1]) == "This is another one."


def test_nltk_sentence_splitter_deprecated_arg(caplog):
doc = TextDocumentWithLabeledPartitions(
text="This is a test sentence. This is another one.", id="test_doc"
)
# add a dummy text partition to trigger the warning (see below)
doc.labeled_partitions.append(LabeledSpan(start=0, end=len(doc.text), label="text"))
caplog.clear()
# create the sentence splitter
with caplog.at_level("WARNING"):
sentence_splitter = NltkSentenceSplitter(
inplace=True, sentencizer_url="tokenizers/punkt/PY3/english.pickle"
)
# call the sentence splitter
result = sentence_splitter(doc)

assert result is doc
# check the log messages
assert len(caplog.records) == 2
assert caplog.messages == [
"The 'sentencizer_url' argument is deprecated. Please use 'language' instead.",
"Layer labeled_partitions in document test_doc is not empty. "
"Clearing it before adding new sentence partitions.",
]
# check the result
assert len(doc.labeled_partitions) == 2
assert str(doc.labeled_partitions[0]) == "This is a test sentence."
assert str(doc.labeled_partitions[1]) == "This is another one."


@pytest.mark.parametrize("inplace", [True, False])
def test_flair_segtok_sentence_splitter(caplog, inplace):
doc = TextDocumentWithLabeledPartitions(
Expand Down
Loading