diff --git a/poetry.lock b/poetry.lock index 911cda7bf..c9d17702b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1767,14 +1767,14 @@ test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] [[package]] name = "nltk" -version = "3.8.1" +version = "3.9.1" description = "Natural Language Toolkit" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" groups = ["dev"] files = [ - {file = "nltk-3.8.1-py3-none-any.whl", hash = "sha256:fd5c9109f976fa86bcadba8f91e47f5e9293bd034474752e92a520f81c93dda5"}, - {file = "nltk-3.8.1.zip", hash = "sha256:1834da3d0682cba4f2cede2f9aad6b0fafb6461ba451db0efb6f9c39798d64d3"}, + {file = "nltk-3.9.1-py3-none-any.whl", hash = "sha256:4fa26829c5b00715afe3061398a8989dc643b92ce7dd93fb4585a70930d168a1"}, + {file = "nltk-3.9.1.tar.gz", hash = "sha256:87d127bd3de4bd89a4f81265e5fa59cb1b199b27440175370f7417d2bc7ae868"}, ] [package.dependencies] @@ -3586,4 +3586,4 @@ test = ["big-O", "importlib-resources ; python_version < \"3.9\"", "jaraco.funct [metadata] lock-version = "2.1" python-versions = "^3.9" -content-hash = "360e7128e1296a81a16070db02a7145bf9e807f3795bcfe56f4e1e453c976f6b" +content-hash = "84243d0376dded1435e1defadde55a26a8d2b1428792e307c8db139721262fe6" diff --git a/pyproject.toml b/pyproject.toml index 9af0c843c..bb3bba8e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ tabulate = "^0.9" # for TokenClassificationModelWithSeq2SeqEncoderAndCrf pytorch-crf = ">=0.7.2" # for rouge metric (tests only) and for NltkSentenceSplitter -nltk = "^3.8.1" +nltk = "^3.9.1" # for NltkSentenceSplitter flair = "^0.13.1" # for SpansViaRelationMerger diff --git a/src/pie_modules/document/processing/sentence_splitter.py b/src/pie_modules/document/processing/sentence_splitter.py index 5230e2d19..ac28b94a9 100644 --- a/src/pie_modules/document/processing/sentence_splitter.py +++ b/src/pie_modules/document/processing/sentence_splitter.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import os from typing import TypeVar from pie_modules.annotations import LabeledSpan @@ -30,8 +31,9 @@ def __init__( self, partition_layer_name: str = "labeled_partitions", text_field_name: str = "text", - sentencizer_url: str = "tokenizers/punkt/PY3/english.pickle", + language: str = "english", inplace: bool = True, + sentencizer_url: str | None = None, ): try: import nltk @@ -41,12 +43,19 @@ def __init__( "You can install NLTK with `pip install nltk`." ) + if sentencizer_url is not None: + logger.warning( + "The 'sentencizer_url' argument is deprecated. Please use 'language' instead." + ) + if sentencizer_url[-7:] == ".pickle": + language = os.path.split(sentencizer_url[:-7])[-1] + self.partition_layer_name = partition_layer_name self.text_field_name = text_field_name self.inplace = inplace # download the NLTK Punkt tokenizer model - nltk.download("punkt") - self.sentencizer = nltk.data.load(sentencizer_url) + nltk.download("punkt_tab") + self.sentencizer = nltk.tokenize.PunktTokenizer(language) def __call__(self, document: D) -> D: if not self.inplace: diff --git a/tests/document/processing/test_sentence_splitter.py b/tests/document/processing/test_sentence_splitter.py index 590670bc5..a64fadf02 100644 --- a/tests/document/processing/test_sentence_splitter.py +++ b/tests/document/processing/test_sentence_splitter.py @@ -37,6 +37,35 @@ def test_nltk_sentence_splitter(caplog, inplace): assert str(doc.labeled_partitions[1]) == "This is another one." +def test_nltk_sentence_splitter_deprecated_arg(caplog): + doc = TextDocumentWithLabeledPartitions( + text="This is a test sentence. This is another one.", id="test_doc" + ) + # add a dummy text partition to trigger the warning (see below) + doc.labeled_partitions.append(LabeledSpan(start=0, end=len(doc.text), label="text")) + caplog.clear() + # create the sentence splitter + with caplog.at_level("WARNING"): + sentence_splitter = NltkSentenceSplitter( + inplace=True, sentencizer_url="tokenizers/punkt/PY3/english.pickle" + ) + # call the sentence splitter + result = sentence_splitter(doc) + + assert result is doc + # check the log messages + assert len(caplog.records) == 2 + assert caplog.messages == [ + "The 'sentencizer_url' argument is deprecated. Please use 'language' instead.", + "Layer labeled_partitions in document test_doc is not empty. " + "Clearing it before adding new sentence partitions.", + ] + # check the result + assert len(doc.labeled_partitions) == 2 + assert str(doc.labeled_partitions[0]) == "This is a test sentence." + assert str(doc.labeled_partitions[1]) == "This is another one." + + @pytest.mark.parametrize("inplace", [True, False]) def test_flair_segtok_sentence_splitter(caplog, inplace): doc = TextDocumentWithLabeledPartitions(