diff --git a/README.md b/README.md
index d45e7afa1..2f3b7d2fd 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,6 @@
# pie-modules
-
-
-
+
[][pypi status]
[][tests]
@@ -10,31 +8,16 @@
[][pre-commit]
[][black]
-Model-, taskmodule-, and metric-implementations as well as document processing utilities for [PyTorch-IE](https://github.com/ChristophAlt/pytorch-ie).
+Annotation, document, and metric implementations as well as utilities for [Python-IE](https://github.com/ArneBinder/pie-core).
-Available models:
+Available annotation types: see [here](src/pie_modules/annotations.py).
-- [SimpleSequenceClassificationModel](src/pie_modules/models/simple_sequence_classification.py)
-- [SequenceClassificationModelWithPooler](src/pie_modules/models/sequence_classification_with_pooler.py)
-- [SequencePairSimilarityModelWithPooler](src/pie_modules/models/sequence_classification_with_pooler.py)
-- [SimpleTokenClassificationModel](src/pie_modules/models/simple_token_classification.py)
-- [TokenClassificationModelWithSeq2SeqEncoderAndCrf](src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py)
-- [SimpleExtractiveQuestionAnsweringModel](src/pie_modules/models/simple_extractive_question_answering.py)
-- [SimpleGenerativeModel](src/pie_modules/models/simple_generative.py)
-- [SpanTupleClassificationModel](src/pie_modules/models/span_tuple_classification.py)
-
-Available taskmodules:
-
-- [RETextClassificationWithIndicesTaskModule](src/pie_modules/taskmodules/re_text_classification_with_indices.py)
-- [CrossTextBinaryCorefTaskModule](src/pie_modules/taskmodules/cross_text_binary_coref.py)
-- [LabeledSpanExtractionByTokenClassificationTaskModule](src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py)
-- [ExtractiveQuestionAnsweringTaskModule](src/pie_modules/taskmodules/extractive_question_answering.py)
-- [TextToTextTaskModule](src/pie_modules/taskmodules/text_to_text.py)
-- [PointerNetworkTaskModuleForEnd2EndRE](src/pie_modules/taskmodules/pointer_network_for_end2end_re.py)
-- [RESpanPairClassificationTaskModule](src/pie_modules/taskmodules/re_span_pair_classification.py)
+Available document types: see [here](src/pie_modules/documents.py).
Available metrics:
+- [F1Metric](src/pie_modules/metrics/f1.py)
+- [ConfusionMatrix](src/pie_modules/metrics/confusion_matrix.py)
- [SpanLengthCollector](src/pie_modules/metrics/span_length_collector.py)
- [RelationArgumentDistanceCollector](src/pie_modules/metrics/relation_argument_distance_collector.py)
- [SpanCoverageCollector](src/pie_modules/metrics/span_coverage_collector.py)
@@ -48,7 +31,7 @@ Document processing utilities:
- [RelationArgumentSorter](src/pie_modules/document/processing/relation_argument_sorter.py)
- [SentenceSplitter](src/pie_modules/document/processing/sentence_splitter.py)
- [TextSpanTrimmer](src/pie_modules/document/processing/text_span_trimmer.py)
-- [tokenize_document](src/pie_modules/document/processing/tokenization.py)
+- [tokenization utils](src/pie_modules/document/processing/tokenization.py), e.g., `text_based_document_to_token_based` and `token_based_document_to_text_based`
## Setup
diff --git a/poetry.lock b/poetry.lock
index 911cda7bf..5fbd4d80e 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,17 +1,5 @@
# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
-[[package]]
-name = "absl-py"
-version = "1.4.0"
-description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py."
-optional = false
-python-versions = ">=3.6"
-groups = ["main"]
-files = [
- {file = "absl-py-1.4.0.tar.gz", hash = "sha256:d2c244d01048ba476e7c080bd2c6df5e141d211de80223460d5b3b8a2a58433d"},
- {file = "absl_py-1.4.0-py3-none-any.whl", hash = "sha256:0d3fe606adfa4f7db64792dd4c7aee4ee0c38ab75dfd353b7a83ed3e957fcb47"},
-]
-
[[package]]
name = "accelerate"
version = "0.32.1"
@@ -44,151 +32,6 @@ test-prod = ["parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "py
test-trackers = ["comet-ml", "dvclive", "tensorboard", "wandb"]
testing = ["bitsandbytes", "datasets", "diffusers", "evaluate", "parameterized", "pytest (>=7.2.0,<=8.0.0)", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "timm", "torchpippy (>=0.2.0)", "tqdm", "transformers"]
-[[package]]
-name = "aiohttp"
-version = "3.9.5"
-description = "Async http client/server framework (asyncio)"
-optional = false
-python-versions = ">=3.8"
-groups = ["main"]
-files = [
- {file = "aiohttp-3.9.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fcde4c397f673fdec23e6b05ebf8d4751314fa7c24f93334bf1f1364c1c69ac7"},
- {file = "aiohttp-3.9.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5d6b3f1fabe465e819aed2c421a6743d8debbde79b6a8600739300630a01bf2c"},
- {file = "aiohttp-3.9.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ae79c1bc12c34082d92bf9422764f799aee4746fd7a392db46b7fd357d4a17a"},
- {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d3ebb9e1316ec74277d19c5f482f98cc65a73ccd5430540d6d11682cd857430"},
- {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84dabd95154f43a2ea80deffec9cb44d2e301e38a0c9d331cc4aa0166fe28ae3"},
- {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c8a02fbeca6f63cb1f0475c799679057fc9268b77075ab7cf3f1c600e81dd46b"},
- {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c26959ca7b75ff768e2776d8055bf9582a6267e24556bb7f7bd29e677932be72"},
- {file = "aiohttp-3.9.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:714d4e5231fed4ba2762ed489b4aec07b2b9953cf4ee31e9871caac895a839c0"},
- {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e7a6a8354f1b62e15d48e04350f13e726fa08b62c3d7b8401c0a1314f02e3558"},
- {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c413016880e03e69d166efb5a1a95d40f83d5a3a648d16486592c49ffb76d0db"},
- {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ff84aeb864e0fac81f676be9f4685f0527b660f1efdc40dcede3c251ef1e867f"},
- {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ad7f2919d7dac062f24d6f5fe95d401597fbb015a25771f85e692d043c9d7832"},
- {file = "aiohttp-3.9.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:702e2c7c187c1a498a4e2b03155d52658fdd6fda882d3d7fbb891a5cf108bb10"},
- {file = "aiohttp-3.9.5-cp310-cp310-win32.whl", hash = "sha256:67c3119f5ddc7261d47163ed86d760ddf0e625cd6246b4ed852e82159617b5fb"},
- {file = "aiohttp-3.9.5-cp310-cp310-win_amd64.whl", hash = "sha256:471f0ef53ccedec9995287f02caf0c068732f026455f07db3f01a46e49d76bbb"},
- {file = "aiohttp-3.9.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e0ae53e33ee7476dd3d1132f932eeb39bf6125083820049d06edcdca4381f342"},
- {file = "aiohttp-3.9.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c088c4d70d21f8ca5c0b8b5403fe84a7bc8e024161febdd4ef04575ef35d474d"},
- {file = "aiohttp-3.9.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:639d0042b7670222f33b0028de6b4e2fad6451462ce7df2af8aee37dcac55424"},
- {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f26383adb94da5e7fb388d441bf09c61e5e35f455a3217bfd790c6b6bc64b2ee"},
- {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:66331d00fb28dc90aa606d9a54304af76b335ae204d1836f65797d6fe27f1ca2"},
- {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4ff550491f5492ab5ed3533e76b8567f4b37bd2995e780a1f46bca2024223233"},
- {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f22eb3a6c1080d862befa0a89c380b4dafce29dc6cd56083f630073d102eb595"},
- {file = "aiohttp-3.9.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a81b1143d42b66ffc40a441379387076243ef7b51019204fd3ec36b9f69e77d6"},
- {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f64fd07515dad67f24b6ea4a66ae2876c01031de91c93075b8093f07c0a2d93d"},
- {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:93e22add827447d2e26d67c9ac0161756007f152fdc5210277d00a85f6c92323"},
- {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:55b39c8684a46e56ef8c8d24faf02de4a2b2ac60d26cee93bc595651ff545de9"},
- {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4715a9b778f4293b9f8ae7a0a7cef9829f02ff8d6277a39d7f40565c737d3771"},
- {file = "aiohttp-3.9.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:afc52b8d969eff14e069a710057d15ab9ac17cd4b6753042c407dcea0e40bf75"},
- {file = "aiohttp-3.9.5-cp311-cp311-win32.whl", hash = "sha256:b3df71da99c98534be076196791adca8819761f0bf6e08e07fd7da25127150d6"},
- {file = "aiohttp-3.9.5-cp311-cp311-win_amd64.whl", hash = "sha256:88e311d98cc0bf45b62fc46c66753a83445f5ab20038bcc1b8a1cc05666f428a"},
- {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:c7a4b7a6cf5b6eb11e109a9755fd4fda7d57395f8c575e166d363b9fc3ec4678"},
- {file = "aiohttp-3.9.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:0a158704edf0abcac8ac371fbb54044f3270bdbc93e254a82b6c82be1ef08f3c"},
- {file = "aiohttp-3.9.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d153f652a687a8e95ad367a86a61e8d53d528b0530ef382ec5aaf533140ed00f"},
- {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:82a6a97d9771cb48ae16979c3a3a9a18b600a8505b1115cfe354dfb2054468b4"},
- {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60cdbd56f4cad9f69c35eaac0fbbdf1f77b0ff9456cebd4902f3dd1cf096464c"},
- {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8676e8fd73141ded15ea586de0b7cda1542960a7b9ad89b2b06428e97125d4fa"},
- {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da00da442a0e31f1c69d26d224e1efd3a1ca5bcbf210978a2ca7426dfcae9f58"},
- {file = "aiohttp-3.9.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18f634d540dd099c262e9f887c8bbacc959847cfe5da7a0e2e1cf3f14dbf2daf"},
- {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:320e8618eda64e19d11bdb3bd04ccc0a816c17eaecb7e4945d01deee2a22f95f"},
- {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:2faa61a904b83142747fc6a6d7ad8fccff898c849123030f8e75d5d967fd4a81"},
- {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:8c64a6dc3fe5db7b1b4d2b5cb84c4f677768bdc340611eca673afb7cf416ef5a"},
- {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:393c7aba2b55559ef7ab791c94b44f7482a07bf7640d17b341b79081f5e5cd1a"},
- {file = "aiohttp-3.9.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c671dc117c2c21a1ca10c116cfcd6e3e44da7fcde37bf83b2be485ab377b25da"},
- {file = "aiohttp-3.9.5-cp312-cp312-win32.whl", hash = "sha256:5a7ee16aab26e76add4afc45e8f8206c95d1d75540f1039b84a03c3b3800dd59"},
- {file = "aiohttp-3.9.5-cp312-cp312-win_amd64.whl", hash = "sha256:5ca51eadbd67045396bc92a4345d1790b7301c14d1848feaac1d6a6c9289e888"},
- {file = "aiohttp-3.9.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:694d828b5c41255e54bc2dddb51a9f5150b4eefa9886e38b52605a05d96566e8"},
- {file = "aiohttp-3.9.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0605cc2c0088fcaae79f01c913a38611ad09ba68ff482402d3410bf59039bfb8"},
- {file = "aiohttp-3.9.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4558e5012ee03d2638c681e156461d37b7a113fe13970d438d95d10173d25f78"},
- {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9dbc053ac75ccc63dc3a3cc547b98c7258ec35a215a92bd9f983e0aac95d3d5b"},
- {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4109adee842b90671f1b689901b948f347325045c15f46b39797ae1bf17019de"},
- {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a6ea1a5b409a85477fd8e5ee6ad8f0e40bf2844c270955e09360418cfd09abac"},
- {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f3c2890ca8c59ee683fd09adf32321a40fe1cf164e3387799efb2acebf090c11"},
- {file = "aiohttp-3.9.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3916c8692dbd9d55c523374a3b8213e628424d19116ac4308e434dbf6d95bbdd"},
- {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8d1964eb7617907c792ca00b341b5ec3e01ae8c280825deadbbd678447b127e1"},
- {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d5ab8e1f6bee051a4bf6195e38a5c13e5e161cb7bad83d8854524798bd9fcd6e"},
- {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:52c27110f3862a1afbcb2af4281fc9fdc40327fa286c4625dfee247c3ba90156"},
- {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:7f64cbd44443e80094309875d4f9c71d0401e966d191c3d469cde4642bc2e031"},
- {file = "aiohttp-3.9.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8b4f72fbb66279624bfe83fd5eb6aea0022dad8eec62b71e7bf63ee1caadeafe"},
- {file = "aiohttp-3.9.5-cp38-cp38-win32.whl", hash = "sha256:6380c039ec52866c06d69b5c7aad5478b24ed11696f0e72f6b807cfb261453da"},
- {file = "aiohttp-3.9.5-cp38-cp38-win_amd64.whl", hash = "sha256:da22dab31d7180f8c3ac7c7635f3bcd53808f374f6aa333fe0b0b9e14b01f91a"},
- {file = "aiohttp-3.9.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1732102949ff6087589408d76cd6dea656b93c896b011ecafff418c9661dc4ed"},
- {file = "aiohttp-3.9.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c6021d296318cb6f9414b48e6a439a7f5d1f665464da507e8ff640848ee2a58a"},
- {file = "aiohttp-3.9.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:239f975589a944eeb1bad26b8b140a59a3a320067fb3cd10b75c3092405a1372"},
- {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3b7b30258348082826d274504fbc7c849959f1989d86c29bc355107accec6cfb"},
- {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd2adf5c87ff6d8b277814a28a535b59e20bfea40a101db6b3bdca7e9926bc24"},
- {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e9a3d838441bebcf5cf442700e3963f58b5c33f015341f9ea86dcd7d503c07e2"},
- {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e3a1ae66e3d0c17cf65c08968a5ee3180c5a95920ec2731f53343fac9bad106"},
- {file = "aiohttp-3.9.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9c69e77370cce2d6df5d12b4e12bdcca60c47ba13d1cbbc8645dd005a20b738b"},
- {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0cbf56238f4bbf49dab8c2dc2e6b1b68502b1e88d335bea59b3f5b9f4c001475"},
- {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:d1469f228cd9ffddd396d9948b8c9cd8022b6d1bf1e40c6f25b0fb90b4f893ed"},
- {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:45731330e754f5811c314901cebdf19dd776a44b31927fa4b4dbecab9e457b0c"},
- {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:3fcb4046d2904378e3aeea1df51f697b0467f2aac55d232c87ba162709478c46"},
- {file = "aiohttp-3.9.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8cf142aa6c1a751fcb364158fd710b8a9be874b81889c2bd13aa8893197455e2"},
- {file = "aiohttp-3.9.5-cp39-cp39-win32.whl", hash = "sha256:7b179eea70833c8dee51ec42f3b4097bd6370892fa93f510f76762105568cf09"},
- {file = "aiohttp-3.9.5-cp39-cp39-win_amd64.whl", hash = "sha256:38d80498e2e169bc61418ff36170e0aad0cd268da8b38a17c4cf29d254a8b3f1"},
- {file = "aiohttp-3.9.5.tar.gz", hash = "sha256:edea7d15772ceeb29db4aff55e482d4bcfb6ae160ce144f2682de02f6d693551"},
-]
-
-[package.dependencies]
-aiosignal = ">=1.1.2"
-async-timeout = {version = ">=4.0,<5.0", markers = "python_version < \"3.11\""}
-attrs = ">=17.3.0"
-frozenlist = ">=1.1.1"
-multidict = ">=4.5,<7.0"
-yarl = ">=1.0,<2.0"
-
-[package.extras]
-speedups = ["Brotli ; platform_python_implementation == \"CPython\"", "aiodns ; sys_platform == \"linux\" or sys_platform == \"darwin\"", "brotlicffi ; platform_python_implementation != \"CPython\""]
-
-[[package]]
-name = "aiosignal"
-version = "1.3.1"
-description = "aiosignal: a list of registered asynchronous callbacks"
-optional = false
-python-versions = ">=3.7"
-groups = ["main"]
-files = [
- {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"},
- {file = "aiosignal-1.3.1.tar.gz", hash = "sha256:54cd96e15e1649b75d6c87526a6ff0b6c1b0dd3459f43d9ca11d48c339b68cfc"},
-]
-
-[package.dependencies]
-frozenlist = ">=1.1.0"
-
-[[package]]
-name = "async-timeout"
-version = "4.0.3"
-description = "Timeout context manager for asyncio programs"
-optional = false
-python-versions = ">=3.7"
-groups = ["main"]
-markers = "python_version < \"3.11\""
-files = [
- {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"},
- {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
-]
-
-[[package]]
-name = "attrs"
-version = "23.2.0"
-description = "Classes Without Boilerplate"
-optional = false
-python-versions = ">=3.7"
-groups = ["main"]
-files = [
- {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"},
- {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"},
-]
-
-[package.extras]
-cov = ["attrs[tests]", "coverage[toml] (>=5.3)"]
-dev = ["attrs[tests]", "pre-commit"]
-docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"]
-tests = ["attrs[tests-no-zope]", "zope-interface"]
-tests-mypy = ["mypy (>=1.6) ; platform_python_implementation == \"CPython\" and python_version >= \"3.8\"", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.8\""]
-tests-no-zope = ["attrs[tests-mypy]", "cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "pympler", "pytest (>=4.3.0)", "pytest-xdist[psutil]"]
-
[[package]]
name = "beautifulsoup4"
version = "4.12.3"
@@ -760,93 +603,6 @@ ufo = ["fs (>=2.2.0,<3)"]
unicode = ["unicodedata2 (>=15.1.0) ; python_version <= \"3.12\""]
woff = ["brotli (>=1.0.1) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\"", "zopfli (>=0.1.4)"]
-[[package]]
-name = "frozenlist"
-version = "1.4.1"
-description = "A list-like structure which implements collections.abc.MutableSequence"
-optional = false
-python-versions = ">=3.8"
-groups = ["main"]
-files = [
- {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f9aa1878d1083b276b0196f2dfbe00c9b7e752475ed3b682025ff20c1c1f51ac"},
- {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29acab3f66f0f24674b7dc4736477bcd4bc3ad4b896f5f45379a67bce8b96868"},
- {file = "frozenlist-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:74fb4bee6880b529a0c6560885fce4dc95936920f9f20f53d99a213f7bf66776"},
- {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:590344787a90ae57d62511dd7c736ed56b428f04cd8c161fcc5e7232c130c69a"},
- {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:068b63f23b17df8569b7fdca5517edef76171cf3897eb68beb01341131fbd2ad"},
- {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c849d495bf5154cd8da18a9eb15db127d4dba2968d88831aff6f0331ea9bd4c"},
- {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9750cc7fe1ae3b1611bb8cfc3f9ec11d532244235d75901fb6b8e42ce9229dfe"},
- {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9b2de4cf0cdd5bd2dee4c4f63a653c61d2408055ab77b151c1957f221cabf2a"},
- {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0633c8d5337cb5c77acbccc6357ac49a1770b8c487e5b3505c57b949b4b82e98"},
- {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:27657df69e8801be6c3638054e202a135c7f299267f1a55ed3a598934f6c0d75"},
- {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:f9a3ea26252bd92f570600098783d1371354d89d5f6b7dfd87359d669f2109b5"},
- {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4f57dab5fe3407b6c0c1cc907ac98e8a189f9e418f3b6e54d65a718aaafe3950"},
- {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e02a0e11cf6597299b9f3bbd3f93d79217cb90cfd1411aec33848b13f5c656cc"},
- {file = "frozenlist-1.4.1-cp310-cp310-win32.whl", hash = "sha256:a828c57f00f729620a442881cc60e57cfcec6842ba38e1b19fd3e47ac0ff8dc1"},
- {file = "frozenlist-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:f56e2333dda1fe0f909e7cc59f021eba0d2307bc6f012a1ccf2beca6ba362439"},
- {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a0cb6f11204443f27a1628b0e460f37fb30f624be6051d490fa7d7e26d4af3d0"},
- {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b46c8ae3a8f1f41a0d2ef350c0b6e65822d80772fe46b653ab6b6274f61d4a49"},
- {file = "frozenlist-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fde5bd59ab5357e3853313127f4d3565fc7dad314a74d7b5d43c22c6a5ed2ced"},
- {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:722e1124aec435320ae01ee3ac7bec11a5d47f25d0ed6328f2273d287bc3abb0"},
- {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2471c201b70d58a0f0c1f91261542a03d9a5e088ed3dc6c160d614c01649c106"},
- {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c757a9dd70d72b076d6f68efdbb9bc943665ae954dad2801b874c8c69e185068"},
- {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f146e0911cb2f1da549fc58fc7bcd2b836a44b79ef871980d605ec392ff6b0d2"},
- {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9c515e7914626b2a2e1e311794b4c35720a0be87af52b79ff8e1429fc25f19"},
- {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c302220494f5c1ebeb0912ea782bcd5e2f8308037b3c7553fad0e48ebad6ad82"},
- {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:442acde1e068288a4ba7acfe05f5f343e19fac87bfc96d89eb886b0363e977ec"},
- {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:1b280e6507ea8a4fa0c0a7150b4e526a8d113989e28eaaef946cc77ffd7efc0a"},
- {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:fe1a06da377e3a1062ae5fe0926e12b84eceb8a50b350ddca72dc85015873f74"},
- {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:db9e724bebd621d9beca794f2a4ff1d26eed5965b004a97f1f1685a173b869c2"},
- {file = "frozenlist-1.4.1-cp311-cp311-win32.whl", hash = "sha256:e774d53b1a477a67838a904131c4b0eef6b3d8a651f8b138b04f748fccfefe17"},
- {file = "frozenlist-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb3c2db03683b5767dedb5769b8a40ebb47d6f7f45b1b3e3b4b51ec8ad9d9825"},
- {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1979bc0aeb89b33b588c51c54ab0161791149f2461ea7c7c946d95d5f93b56ae"},
- {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cc7b01b3754ea68a62bd77ce6020afaffb44a590c2289089289363472d13aedb"},
- {file = "frozenlist-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9c92be9fd329ac801cc420e08452b70e7aeab94ea4233a4804f0915c14eba9b"},
- {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3894db91f5a489fc8fa6a9991820f368f0b3cbdb9cd8849547ccfab3392d86"},
- {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba60bb19387e13597fb059f32cd4d59445d7b18b69a745b8f8e5db0346f33480"},
- {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8aefbba5f69d42246543407ed2461db31006b0f76c4e32dfd6f42215a2c41d09"},
- {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780d3a35680ced9ce682fbcf4cb9c2bad3136eeff760ab33707b71db84664e3a"},
- {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9acbb16f06fe7f52f441bb6f413ebae6c37baa6ef9edd49cdd567216da8600cd"},
- {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23b701e65c7b36e4bf15546a89279bd4d8675faabc287d06bbcfac7d3c33e1e6"},
- {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3e0153a805a98f5ada7e09826255ba99fb4f7524bb81bf6b47fb702666484ae1"},
- {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:dd9b1baec094d91bf36ec729445f7769d0d0cf6b64d04d86e45baf89e2b9059b"},
- {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:1a4471094e146b6790f61b98616ab8e44f72661879cc63fa1049d13ef711e71e"},
- {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5667ed53d68d91920defdf4035d1cdaa3c3121dc0b113255124bcfada1cfa1b8"},
- {file = "frozenlist-1.4.1-cp312-cp312-win32.whl", hash = "sha256:beee944ae828747fd7cb216a70f120767fc9f4f00bacae8543c14a6831673f89"},
- {file = "frozenlist-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:64536573d0a2cb6e625cf309984e2d873979709f2cf22839bf2d61790b448ad5"},
- {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20b51fa3f588ff2fe658663db52a41a4f7aa6c04f6201449c6c7c476bd255c0d"},
- {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:410478a0c562d1a5bcc2f7ea448359fcb050ed48b3c6f6f4f18c313a9bdb1826"},
- {file = "frozenlist-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c6321c9efe29975232da3bd0af0ad216800a47e93d763ce64f291917a381b8eb"},
- {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48f6a4533887e189dae092f1cf981f2e3885175f7a0f33c91fb5b7b682b6bab6"},
- {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6eb73fa5426ea69ee0e012fb59cdc76a15b1283d6e32e4f8dc4482ec67d1194d"},
- {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fbeb989b5cc29e8daf7f976b421c220f1b8c731cbf22b9130d8815418ea45887"},
- {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:32453c1de775c889eb4e22f1197fe3bdfe457d16476ea407472b9442e6295f7a"},
- {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:693945278a31f2086d9bf3df0fe8254bbeaef1fe71e1351c3bd730aa7d31c41b"},
- {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1d0ce09d36d53bbbe566fe296965b23b961764c0bcf3ce2fa45f463745c04701"},
- {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3a670dc61eb0d0eb7080890c13de3066790f9049b47b0de04007090807c776b0"},
- {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:dca69045298ce5c11fd539682cff879cc1e664c245d1c64da929813e54241d11"},
- {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a06339f38e9ed3a64e4c4e43aec7f59084033647f908e4259d279a52d3757d09"},
- {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b7f2f9f912dca3934c1baec2e4585a674ef16fe00218d833856408c48d5beee7"},
- {file = "frozenlist-1.4.1-cp38-cp38-win32.whl", hash = "sha256:e7004be74cbb7d9f34553a5ce5fb08be14fb33bc86f332fb71cbe5216362a497"},
- {file = "frozenlist-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:5a7d70357e7cee13f470c7883a063aae5fe209a493c57d86eb7f5a6f910fae09"},
- {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bfa4a17e17ce9abf47a74ae02f32d014c5e9404b6d9ac7f729e01562bbee601e"},
- {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b7e3ed87d4138356775346e6845cccbe66cd9e207f3cd11d2f0b9fd13681359d"},
- {file = "frozenlist-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c99169d4ff810155ca50b4da3b075cbde79752443117d89429595c2e8e37fed8"},
- {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edb678da49d9f72c9f6c609fbe41a5dfb9a9282f9e6a2253d5a91e0fc382d7c0"},
- {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6db4667b187a6742b33afbbaf05a7bc551ffcf1ced0000a571aedbb4aa42fc7b"},
- {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55fdc093b5a3cb41d420884cdaf37a1e74c3c37a31f46e66286d9145d2063bd0"},
- {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82e8211d69a4f4bc360ea22cd6555f8e61a1bd211d1d5d39d3d228b48c83a897"},
- {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89aa2c2eeb20957be2d950b85974b30a01a762f3308cd02bb15e1ad632e22dc7"},
- {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9d3e0c25a2350080e9319724dede4f31f43a6c9779be48021a7f4ebde8b2d742"},
- {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7268252af60904bf52c26173cbadc3a071cece75f873705419c8681f24d3edea"},
- {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:0c250a29735d4f15321007fb02865f0e6b6a41a6b88f1f523ca1596ab5f50bd5"},
- {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:96ec70beabbd3b10e8bfe52616a13561e58fe84c0101dd031dc78f250d5128b9"},
- {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:23b2d7679b73fe0e5a4560b672a39f98dfc6f60df63823b0a9970525325b95f6"},
- {file = "frozenlist-1.4.1-cp39-cp39-win32.whl", hash = "sha256:a7496bfe1da7fb1a4e1cc23bb67c58fab69311cc7d32b5a99c2007b4b2a0e932"},
- {file = "frozenlist-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:e6a20a581f9ce92d389a8c7d7c3dd47c81fd5d6e655c8dddf341e14aa48659d0"},
- {file = "frozenlist-1.4.1-py3-none-any.whl", hash = "sha256:04ced3e6a46b4cfffe20f9ae482818e34eba9b5fb0ce4056e4cc9b6e212d09b7"},
- {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"},
-]
-
[[package]]
name = "fsspec"
version = "2023.6.0"
@@ -859,10 +615,6 @@ files = [
{file = "fsspec-2023.6.0.tar.gz", hash = "sha256:d0b2f935446169753e7a5c5c55681c54ea91996cc67be93c39a154fb3a2742af"},
]
-[package.dependencies]
-aiohttp = {version = "<4.0.0a0 || >4.0.0a0,<4.0.0a1 || >4.0.0a1", optional = true, markers = "extra == \"http\""}
-requests = {version = "*", optional = true, markers = "extra == \"http\""}
-
[package.extras]
abfs = ["adlfs"]
adl = ["adlfs"]
@@ -1064,22 +816,6 @@ files = [
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
]
-[[package]]
-name = "intel-openmp"
-version = "2021.4.0"
-description = "Intel OpenMP* Runtime Library"
-optional = false
-python-versions = "*"
-groups = ["main", "dev"]
-markers = "platform_system == \"Windows\""
-files = [
- {file = "intel_openmp-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:41c01e266a7fdb631a7609191709322da2bbf24b252ba763f125dd651bcc7675"},
- {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:3b921236a38384e2016f0f3d65af6732cf2c12918087128a9163225451e776f2"},
- {file = "intel_openmp-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:e2240ab8d01472fed04f3544a878cda5da16c26232b7ea1b59132dbfb48b186e"},
- {file = "intel_openmp-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:6e863d8fd3d7e8ef389d52cf97a50fe2afe1a19247e8c0d168ce021546f96fc9"},
- {file = "intel_openmp-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:eef4c8bcc8acefd7f5cd3b9384dbf73d59e2c99fc56545712ded913f43c4a94f"},
-]
-
[[package]]
name = "janome"
version = "0.5.0"
@@ -1263,28 +999,6 @@ files = [
[package.dependencies]
six = "*"
-[[package]]
-name = "lightning-utilities"
-version = "0.11.2"
-description = "Lightning toolbox for across the our ecosystem."
-optional = false
-python-versions = ">=3.8"
-groups = ["main"]
-files = [
- {file = "lightning-utilities-0.11.2.tar.gz", hash = "sha256:adf4cf9c5d912fe505db4729e51d1369c6927f3a8ac55a9dff895ce5c0da08d9"},
- {file = "lightning_utilities-0.11.2-py3-none-any.whl", hash = "sha256:541f471ed94e18a28d72879338c8c52e873bb46f4c47644d89228faeb6751159"},
-]
-
-[package.dependencies]
-packaging = ">=17.1"
-setuptools = "*"
-typing-extensions = "*"
-
-[package.extras]
-cli = ["fire"]
-docs = ["requests (>=2.0.0)"]
-typing = ["mypy (>=1.0.0)", "types-setuptools"]
-
[[package]]
name = "lxml"
version = "5.2.2"
@@ -1580,26 +1294,6 @@ python-dateutil = ">=2.7"
[package.extras]
dev = ["meson-python (>=0.13.1,<0.17.0)", "numpy (>=1.25)", "pybind11 (>=2.6,!=2.13.3)", "setuptools (>=64)", "setuptools_scm (>=7)"]
-[[package]]
-name = "mkl"
-version = "2021.4.0"
-description = "Intel® oneAPI Math Kernel Library"
-optional = false
-python-versions = "*"
-groups = ["main", "dev"]
-markers = "platform_system == \"Windows\""
-files = [
- {file = "mkl-2021.4.0-py2.py3-none-macosx_10_15_x86_64.macosx_11_0_x86_64.whl", hash = "sha256:67460f5cd7e30e405b54d70d1ed3ca78118370b65f7327d495e9c8847705e2fb"},
- {file = "mkl-2021.4.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:636d07d90e68ccc9630c654d47ce9fdeb036bb46e2b193b3a9ac8cfea683cce5"},
- {file = "mkl-2021.4.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:398dbf2b0d12acaf54117a5210e8f191827f373d362d796091d161f610c1ebfb"},
- {file = "mkl-2021.4.0-py2.py3-none-win32.whl", hash = "sha256:439c640b269a5668134e3dcbcea4350459c4a8bc46469669b2d67e07e3d330e8"},
- {file = "mkl-2021.4.0-py2.py3-none-win_amd64.whl", hash = "sha256:ceef3cafce4c009dd25f65d7ad0d833a0fbadc3d8903991ec92351fe5de1e718"},
-]
-
-[package.dependencies]
-intel-openmp = "==2021.*"
-tbb = "==2021.*"
-
[[package]]
name = "more-itertools"
version = "10.3.0"
@@ -1646,106 +1340,6 @@ docs = ["sphinx"]
gmpy = ["gmpy2 (>=2.1.0a4) ; platform_python_implementation != \"PyPy\""]
tests = ["pytest (>=4.6)"]
-[[package]]
-name = "multidict"
-version = "6.0.5"
-description = "multidict implementation"
-optional = false
-python-versions = ">=3.7"
-groups = ["main"]
-files = [
- {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b644ae063c10e7f324ab1ab6b548bdf6f8b47f3ec234fef1093bc2735e5f9"},
- {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:896ebdcf62683551312c30e20614305f53125750803b614e9e6ce74a96232604"},
- {file = "multidict-6.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:411bf8515f3be9813d06004cac41ccf7d1cd46dfe233705933dd163b60e37600"},
- {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d147090048129ce3c453f0292e7697d333db95e52616b3793922945804a433c"},
- {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:215ed703caf15f578dca76ee6f6b21b7603791ae090fbf1ef9d865571039ade5"},
- {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c6390cf87ff6234643428991b7359b5f59cc15155695deb4eda5c777d2b880f"},
- {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fd81c4ebdb4f214161be351eb5bcf385426bf023041da2fd9e60681f3cebae"},
- {file = "multidict-6.0.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3cc2ad10255f903656017363cd59436f2111443a76f996584d1077e43ee51182"},
- {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6939c95381e003f54cd4c5516740faba40cf5ad3eeff460c3ad1d3e0ea2549bf"},
- {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:220dd781e3f7af2c2c1053da9fa96d9cf3072ca58f057f4c5adaaa1cab8fc442"},
- {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:766c8f7511df26d9f11cd3a8be623e59cca73d44643abab3f8c8c07620524e4a"},
- {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:fe5d7785250541f7f5019ab9cba2c71169dc7d74d0f45253f8313f436458a4ef"},
- {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1c1496e73051918fcd4f58ff2e0f2f3066d1c76a0c6aeffd9b45d53243702cc"},
- {file = "multidict-6.0.5-cp310-cp310-win32.whl", hash = "sha256:7afcdd1fc07befad18ec4523a782cde4e93e0a2bf71239894b8d61ee578c1319"},
- {file = "multidict-6.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:99f60d34c048c5c2fabc766108c103612344c46e35d4ed9ae0673d33c8fb26e8"},
- {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f285e862d2f153a70586579c15c44656f888806ed0e5b56b64489afe4a2dbfba"},
- {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:53689bb4e102200a4fafa9de9c7c3c212ab40a7ab2c8e474491914d2305f187e"},
- {file = "multidict-6.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:612d1156111ae11d14afaf3a0669ebf6c170dbb735e510a7438ffe2369a847fd"},
- {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7be7047bd08accdb7487737631d25735c9a04327911de89ff1b26b81745bd4e3"},
- {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de170c7b4fe6859beb8926e84f7d7d6c693dfe8e27372ce3b76f01c46e489fcf"},
- {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04bde7a7b3de05732a4eb39c94574db1ec99abb56162d6c520ad26f83267de29"},
- {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85f67aed7bb647f93e7520633d8f51d3cbc6ab96957c71272b286b2f30dc70ed"},
- {file = "multidict-6.0.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425bf820055005bfc8aa9a0b99ccb52cc2f4070153e34b701acc98d201693733"},
- {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3eb1ceec286eba8220c26f3b0096cf189aea7057b6e7b7a2e60ed36b373b77f"},
- {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7901c05ead4b3fb75113fb1dd33eb1253c6d3ee37ce93305acd9d38e0b5f21a4"},
- {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e0e79d91e71b9867c73323a3444724d496c037e578a0e1755ae159ba14f4f3d1"},
- {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:29bfeb0dff5cb5fdab2023a7a9947b3b4af63e9c47cae2a10ad58394b517fddc"},
- {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e030047e85cbcedbfc073f71836d62dd5dadfbe7531cae27789ff66bc551bd5e"},
- {file = "multidict-6.0.5-cp311-cp311-win32.whl", hash = "sha256:2f4848aa3baa109e6ab81fe2006c77ed4d3cd1e0ac2c1fbddb7b1277c168788c"},
- {file = "multidict-6.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:2faa5ae9376faba05f630d7e5e6be05be22913782b927b19d12b8145968a85ea"},
- {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:51d035609b86722963404f711db441cf7134f1889107fb171a970c9701f92e1e"},
- {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cbebcd5bcaf1eaf302617c114aa67569dd3f090dd0ce8ba9e35e9985b41ac35b"},
- {file = "multidict-6.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ffc42c922dbfddb4a4c3b438eb056828719f07608af27d163191cb3e3aa6cc5"},
- {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceb3b7e6a0135e092de86110c5a74e46bda4bd4fbfeeb3a3bcec79c0f861e450"},
- {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79660376075cfd4b2c80f295528aa6beb2058fd289f4c9252f986751a4cd0496"},
- {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4428b29611e989719874670fd152b6625500ad6c686d464e99f5aaeeaca175a"},
- {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d84a5c3a5f7ce6db1f999fb9438f686bc2e09d38143f2d93d8406ed2dd6b9226"},
- {file = "multidict-6.0.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76c0de87358b192de7ea9649beb392f107dcad9ad27276324c24c91774ca5271"},
- {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:79a6d2ba910adb2cbafc95dad936f8b9386e77c84c35bc0add315b856d7c3abb"},
- {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:92d16a3e275e38293623ebf639c471d3e03bb20b8ebb845237e0d3664914caef"},
- {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:fb616be3538599e797a2017cccca78e354c767165e8858ab5116813146041a24"},
- {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:14c2976aa9038c2629efa2c148022ed5eb4cb939e15ec7aace7ca932f48f9ba6"},
- {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:435a0984199d81ca178b9ae2c26ec3d49692d20ee29bc4c11a2a8d4514c67eda"},
- {file = "multidict-6.0.5-cp312-cp312-win32.whl", hash = "sha256:9fe7b0653ba3d9d65cbe7698cca585bf0f8c83dbbcc710db9c90f478e175f2d5"},
- {file = "multidict-6.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:01265f5e40f5a17f8241d52656ed27192be03bfa8764d88e8220141d1e4b3556"},
- {file = "multidict-6.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:19fe01cea168585ba0f678cad6f58133db2aa14eccaf22f88e4a6dccadfad8b3"},
- {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf7a982604375a8d49b6cc1b781c1747f243d91b81035a9b43a2126c04766f5"},
- {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:107c0cdefe028703fb5dafe640a409cb146d44a6ae201e55b35a4af8e95457dd"},
- {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:403c0911cd5d5791605808b942c88a8155c2592e05332d2bf78f18697a5fa15e"},
- {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aeaf541ddbad8311a87dd695ed9642401131ea39ad7bc8cf3ef3967fd093b626"},
- {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4972624066095e52b569e02b5ca97dbd7a7ddd4294bf4e7247d52635630dd83"},
- {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d946b0a9eb8aaa590df1fe082cee553ceab173e6cb5b03239716338629c50c7a"},
- {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b55358304d7a73d7bdf5de62494aaf70bd33015831ffd98bc498b433dfe5b10c"},
- {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:a3145cb08d8625b2d3fee1b2d596a8766352979c9bffe5d7833e0503d0f0b5e5"},
- {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d65f25da8e248202bd47445cec78e0025c0fe7582b23ec69c3b27a640dd7a8e3"},
- {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:c9bf56195c6bbd293340ea82eafd0071cb3d450c703d2c93afb89f93b8386ccc"},
- {file = "multidict-6.0.5-cp37-cp37m-win32.whl", hash = "sha256:69db76c09796b313331bb7048229e3bee7928eb62bab5e071e9f7fcc4879caee"},
- {file = "multidict-6.0.5-cp37-cp37m-win_amd64.whl", hash = "sha256:fce28b3c8a81b6b36dfac9feb1de115bab619b3c13905b419ec71d03a3fc1423"},
- {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:76f067f5121dcecf0d63a67f29080b26c43c71a98b10c701b0677e4a065fbd54"},
- {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b82cc8ace10ab5bd93235dfaab2021c70637005e1ac787031f4d1da63d493c1d"},
- {file = "multidict-6.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5cb241881eefd96b46f89b1a056187ea8e9ba14ab88ba632e68d7a2ecb7aadf7"},
- {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8e94e6912639a02ce173341ff62cc1201232ab86b8a8fcc05572741a5dc7d93"},
- {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09a892e4a9fb47331da06948690ae38eaa2426de97b4ccbfafbdcbe5c8f37ff8"},
- {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55205d03e8a598cfc688c71ca8ea5f66447164efff8869517f175ea632c7cb7b"},
- {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37b15024f864916b4951adb95d3a80c9431299080341ab9544ed148091b53f50"},
- {file = "multidict-6.0.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2a1dee728b52b33eebff5072817176c172050d44d67befd681609b4746e1c2e"},
- {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:edd08e6f2f1a390bf137080507e44ccc086353c8e98c657e666c017718561b89"},
- {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:60d698e8179a42ec85172d12f50b1668254628425a6bd611aba022257cac1386"},
- {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:3d25f19500588cbc47dc19081d78131c32637c25804df8414463ec908631e453"},
- {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4cc0ef8b962ac7a5e62b9e826bd0cd5040e7d401bc45a6835910ed699037a461"},
- {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:eca2e9d0cc5a889850e9bbd68e98314ada174ff6ccd1129500103df7a94a7a44"},
- {file = "multidict-6.0.5-cp38-cp38-win32.whl", hash = "sha256:4a6a4f196f08c58c59e0b8ef8ec441d12aee4125a7d4f4fef000ccb22f8d7241"},
- {file = "multidict-6.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:0275e35209c27a3f7951e1ce7aaf93ce0d163b28948444bec61dd7badc6d3f8c"},
- {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e7be68734bd8c9a513f2b0cfd508802d6609da068f40dc57d4e3494cefc92929"},
- {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1d9ea7a7e779d7a3561aade7d596649fbecfa5c08a7674b11b423783217933f9"},
- {file = "multidict-6.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ea1456df2a27c73ce51120fa2f519f1bea2f4a03a917f4a43c8707cf4cbbae1a"},
- {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf590b134eb70629e350691ecca88eac3e3b8b3c86992042fb82e3cb1830d5e1"},
- {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5c0631926c4f58e9a5ccce555ad7747d9a9f8b10619621f22f9635f069f6233e"},
- {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dce1c6912ab9ff5f179eaf6efe7365c1f425ed690b03341911bf4939ef2f3046"},
- {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0868d64af83169e4d4152ec612637a543f7a336e4a307b119e98042e852ad9c"},
- {file = "multidict-6.0.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:141b43360bfd3bdd75f15ed811850763555a251e38b2405967f8e25fb43f7d40"},
- {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7df704ca8cf4a073334e0427ae2345323613e4df18cc224f647f251e5e75a527"},
- {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6214c5a5571802c33f80e6c84713b2c79e024995b9c5897f794b43e714daeec9"},
- {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:cd6c8fca38178e12c00418de737aef1261576bd1b6e8c6134d3e729a4e858b38"},
- {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:e02021f87a5b6932fa6ce916ca004c4d441509d33bbdbeca70d05dff5e9d2479"},
- {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ebd8d160f91a764652d3e51ce0d2956b38efe37c9231cd82cfc0bed2e40b581c"},
- {file = "multidict-6.0.5-cp39-cp39-win32.whl", hash = "sha256:04da1bb8c8dbadf2a18a452639771951c662c5ad03aefe4884775454be322c9b"},
- {file = "multidict-6.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:d6f6d4f185481c9669b9447bf9d9cf3b95a0e9df9d169bbc17e363b7d5487755"},
- {file = "multidict-6.0.5-py3-none-any.whl", hash = "sha256:0d63c74e3d7ab26de115c49bffc92cc77ed23395303d496eae515d4204a625e7"},
- {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"},
-]
-
[[package]]
name = "networkx"
version = "3.2.1"
@@ -1849,6 +1443,228 @@ files = [
{file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
]
+[[package]]
+name = "nvidia-cublas-cu12"
+version = "12.6.4.1"
+description = "CUBLAS native runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main", "dev"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:08ed2686e9875d01b58e3cb379c6896df8e76c75e0d4a7f7dace3d7b6d9ef8eb"},
+ {file = "nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:235f728d6e2a409eddf1df58d5b0921cf80cfa9e72b9f2775ccb7b4a87984668"},
+ {file = "nvidia_cublas_cu12-12.6.4.1-py3-none-win_amd64.whl", hash = "sha256:9e4fa264f4d8a4eb0cdbd34beadc029f453b3bafae02401e999cf3d5a5af75f8"},
+]
+
+[[package]]
+name = "nvidia-cuda-cupti-cu12"
+version = "12.6.80"
+description = "CUDA profiling tools runtime libs."
+optional = false
+python-versions = ">=3"
+groups = ["main", "dev"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:166ee35a3ff1587f2490364f90eeeb8da06cd867bd5b701bf7f9a02b78bc63fc"},
+ {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_aarch64.whl", hash = "sha256:358b4a1d35370353d52e12f0a7d1769fc01ff74a191689d3870b2123156184c4"},
+ {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6768bad6cab4f19e8292125e5f1ac8aa7d1718704012a0e3272a6f61c4bce132"},
+ {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a3eff6cdfcc6a4c35db968a06fcadb061cbc7d6dde548609a941ff8701b98b73"},
+ {file = "nvidia_cuda_cupti_cu12-12.6.80-py3-none-win_amd64.whl", hash = "sha256:bbe6ae76e83ce5251b56e8c8e61a964f757175682bbad058b170b136266ab00a"},
+]
+
+[[package]]
+name = "nvidia-cuda-nvrtc-cu12"
+version = "12.6.77"
+description = "NVRTC native runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main", "dev"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5847f1d6e5b757f1d2b3991a01082a44aad6f10ab3c5c0213fa3e25bddc25a13"},
+ {file = "nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:35b0cc6ee3a9636d5409133e79273ce1f3fd087abb0532d2d2e8fff1fe9efc53"},
+ {file = "nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:f7007dbd914c56bd80ea31bc43e8e149da38f68158f423ba845fc3292684e45a"},
+]
+
+[[package]]
+name = "nvidia-cuda-runtime-cu12"
+version = "12.6.77"
+description = "CUDA Runtime native Libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main", "dev"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6116fad3e049e04791c0256a9778c16237837c08b27ed8c8401e2e45de8d60cd"},
+ {file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:d461264ecb429c84c8879a7153499ddc7b19b5f8d84c204307491989a365588e"},
+ {file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ba3b56a4f896141e25e19ab287cd71e52a6a0f4b29d0d31609f60e3b4d5219b7"},
+ {file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a84d15d5e1da416dd4774cb42edf5e954a3e60cc945698dc1d5be02321c44dc8"},
+ {file = "nvidia_cuda_runtime_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:86c58044c824bf3c173c49a2dbc7a6c8b53cb4e4dca50068be0bf64e9dab3f7f"},
+]
+
+[[package]]
+name = "nvidia-cudnn-cu12"
+version = "9.5.1.17"
+description = "cuDNN runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main", "dev"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:9fd4584468533c61873e5fda8ca41bac3a38bcb2d12350830c69b0a96a7e4def"},
+ {file = "nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:30ac3869f6db17d170e0e556dd6cc5eee02647abc31ca856634d5a40f82c15b2"},
+ {file = "nvidia_cudnn_cu12-9.5.1.17-py3-none-win_amd64.whl", hash = "sha256:d7af0f8a4f3b4b9dbb3122f2ef553b45694ed9c384d5a75bab197b8eefb79ab8"},
+]
+
+[package.dependencies]
+nvidia-cublas-cu12 = "*"
+
+[[package]]
+name = "nvidia-cufft-cu12"
+version = "11.3.0.4"
+description = "CUFFT native runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main", "dev"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d16079550df460376455cba121db6564089176d9bac9e4f360493ca4741b22a6"},
+ {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8510990de9f96c803a051822618d42bf6cb8f069ff3f48d93a8486efdacb48fb"},
+ {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ccba62eb9cef5559abd5e0d54ceed2d9934030f51163df018532142a8ec533e5"},
+ {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.whl", hash = "sha256:768160ac89f6f7b459bee747e8d175dbf53619cfe74b2a5636264163138013ca"},
+ {file = "nvidia_cufft_cu12-11.3.0.4-py3-none-win_amd64.whl", hash = "sha256:6048ebddfb90d09d2707efb1fd78d4e3a77cb3ae4dc60e19aab6be0ece2ae464"},
+]
+
+[package.dependencies]
+nvidia-nvjitlink-cu12 = "*"
+
+[[package]]
+name = "nvidia-cufile-cu12"
+version = "1.11.1.6"
+description = "cuFile GPUDirect libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main", "dev"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cc23469d1c7e52ce6c1d55253273d32c565dd22068647f3aa59b3c6b005bf159"},
+ {file = "nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:8f57a0051dcf2543f6dc2b98a98cb2719c37d3cee1baba8965d57f3bbc90d4db"},
+]
+
+[[package]]
+name = "nvidia-curand-cu12"
+version = "10.3.7.77"
+description = "CURAND native runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main", "dev"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:6e82df077060ea28e37f48a3ec442a8f47690c7499bff392a5938614b56c98d8"},
+ {file = "nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a42cd1344297f70b9e39a1e4f467a4e1c10f1da54ff7a85c12197f6c652c8bdf"},
+ {file = "nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:99f1a32f1ac2bd134897fc7a203f779303261268a65762a623bf30cc9fe79117"},
+ {file = "nvidia_curand_cu12-10.3.7.77-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:7b2ed8e95595c3591d984ea3603dd66fe6ce6812b886d59049988a712ed06b6e"},
+ {file = "nvidia_curand_cu12-10.3.7.77-py3-none-win_amd64.whl", hash = "sha256:6d6d935ffba0f3d439b7cd968192ff068fafd9018dbf1b85b37261b13cfc9905"},
+]
+
+[[package]]
+name = "nvidia-cusolver-cu12"
+version = "11.7.1.2"
+description = "CUDA solver native runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main", "dev"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0ce237ef60acde1efc457335a2ddadfd7610b892d94efee7b776c64bb1cac9e0"},
+ {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e9e49843a7707e42022babb9bcfa33c29857a93b88020c4e4434656a655b698c"},
+ {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6cf28f17f64107a0c4d7802be5ff5537b2130bfc112f25d5a30df227058ca0e6"},
+ {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:dbbe4fc38ec1289c7e5230e16248365e375c3673c9c8bac5796e2e20db07f56e"},
+ {file = "nvidia_cusolver_cu12-11.7.1.2-py3-none-win_amd64.whl", hash = "sha256:6813f9d8073f555444a8705f3ab0296d3e1cb37a16d694c5fc8b862a0d8706d7"},
+]
+
+[package.dependencies]
+nvidia-cublas-cu12 = "*"
+nvidia-cusparse-cu12 = "*"
+nvidia-nvjitlink-cu12 = "*"
+
+[[package]]
+name = "nvidia-cusparse-cu12"
+version = "12.5.4.2"
+description = "CUSPARSE native runtime libraries"
+optional = false
+python-versions = ">=3"
+groups = ["main", "dev"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d25b62fb18751758fe3c93a4a08eff08effedfe4edf1c6bb5afd0890fe88f887"},
+ {file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7aa32fa5470cf754f72d1116c7cbc300b4e638d3ae5304cfa4a638a5b87161b1"},
+ {file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7556d9eca156e18184b94947ade0fba5bb47d69cec46bf8660fd2c71a4b48b73"},
+ {file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:23749a6571191a215cb74d1cdbff4a86e7b19f1200c071b3fcf844a5bea23a2f"},
+ {file = "nvidia_cusparse_cu12-12.5.4.2-py3-none-win_amd64.whl", hash = "sha256:4acb8c08855a26d737398cba8fb6f8f5045d93f82612b4cfd84645a2332ccf20"},
+]
+
+[package.dependencies]
+nvidia-nvjitlink-cu12 = "*"
+
+[[package]]
+name = "nvidia-cusparselt-cu12"
+version = "0.6.3"
+description = "NVIDIA cuSPARSELt"
+optional = false
+python-versions = "*"
+groups = ["main", "dev"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8371549623ba601a06322af2133c4a44350575f5a3108fb75f3ef20b822ad5f1"},
+ {file = "nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e5c8a26c36445dd2e6812f1177978a24e2d37cacce7e090f297a688d1ec44f46"},
+ {file = "nvidia_cusparselt_cu12-0.6.3-py3-none-win_amd64.whl", hash = "sha256:3b325bcbd9b754ba43df5a311488fca11a6b5dc3d11df4d190c000cf1a0765c7"},
+]
+
+[[package]]
+name = "nvidia-nccl-cu12"
+version = "2.26.2"
+description = "NVIDIA Collective Communication Library (NCCL) Runtime"
+optional = false
+python-versions = ">=3"
+groups = ["main", "dev"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5c196e95e832ad30fbbb50381eb3cbd1fadd5675e587a548563993609af19522"},
+ {file = "nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:694cf3879a206553cc9d7dbda76b13efaf610fdb70a50cba303de1b0d1530ac6"},
+]
+
+[[package]]
+name = "nvidia-nvjitlink-cu12"
+version = "12.6.85"
+description = "Nvidia JIT LTO Library"
+optional = false
+python-versions = ">=3"
+groups = ["main", "dev"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:eedc36df9e88b682efe4309aa16b5b4e78c2407eac59e8c10a6a47535164369a"},
+ {file = "nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cf4eaa7d4b6b543ffd69d6abfb11efdeb2db48270d94dfd3a452c24150829e41"},
+ {file = "nvidia_nvjitlink_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:e61120e52ed675747825cdd16febc6a0730537451d867ee58bee3853b1b13d1c"},
+]
+
+[[package]]
+name = "nvidia-nvtx-cu12"
+version = "12.6.77"
+description = "NVIDIA Tools Extension"
+optional = false
+python-versions = ">=3"
+groups = ["main", "dev"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f44f8d86bb7d5629988d61c8d3ae61dddb2015dee142740536bc7481b022fe4b"},
+ {file = "nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_aarch64.whl", hash = "sha256:adcaabb9d436c9761fca2b13959a2d237c5f9fd406c8e4b723c695409ff88059"},
+ {file = "nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b90bed3df379fa79afbd21be8e04a0314336b8ae16768b58f2d34cb1d04cd7d2"},
+ {file = "nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6574241a3ec5fdc9334353ab8c479fe75841dbe8f4532a8fc97ce63503330ba1"},
+ {file = "nvidia_nvtx_cu12-12.6.77-py3-none-win_amd64.whl", hash = "sha256:2fb11a4af04a5e6c84073e6404d26588a34afd35379f0855a99797897efa75c0"},
+]
+
[[package]]
name = "packaging"
version = "24.0"
@@ -2248,72 +2064,6 @@ files = [
[package.dependencies]
six = ">=1.5"
-[[package]]
-name = "pytorch-crf"
-version = "0.7.2"
-description = "Conditional random field in PyTorch"
-optional = false
-python-versions = ">=3.6, <4"
-groups = ["dev"]
-files = [
- {file = "pytorch-crf-0.7.2.tar.gz", hash = "sha256:e6456e22ccfc99a3d4fe1e03e996103b1b39e9830bf3c7e12e7a9077d3be866d"},
- {file = "pytorch_crf-0.7.2-py3-none-any.whl", hash = "sha256:1b2d7d5eea3255f6e0cac09ab8b645472e76ff70d9333bc88762cf7317a4992d"},
-]
-
-[[package]]
-name = "pytorch-ie"
-version = "0.31.9"
-description = "State-of-the-art Information Extraction in PyTorch"
-optional = false
-python-versions = "<4.0,>=3.9"
-groups = ["main"]
-files = [
- {file = "pytorch_ie-0.31.9-py3-none-any.whl", hash = "sha256:002eab323d529022e13a1ed1a7effc43e1bc172bcc11abe58c46501a8c37eb54"},
- {file = "pytorch_ie-0.31.9.tar.gz", hash = "sha256:bd516817ce759c059fcbe61c8d420367366c82014d3ba49938a61cb564102610"},
-]
-
-[package.dependencies]
-absl-py = ">=1.0.0,<2.0.0"
-fsspec = "<2023.9.0"
-pandas = ">=2.0.0,<3.0.0"
-pie-core = ">=0.2.0,<0.3.0"
-pytorch-lightning = ">=2,<3"
-torch = ">=1.10"
-torchmetrics = ">=1,<2"
-transformers = ">=4.18,<5.0"
-
-[[package]]
-name = "pytorch-lightning"
-version = "2.2.5"
-description = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate."
-optional = false
-python-versions = ">=3.8"
-groups = ["main"]
-files = [
- {file = "pytorch-lightning-2.2.5.tar.gz", hash = "sha256:8d06d0166e2204f82864f5d2b53a367c2c375d9cd5a7f6174434b2dffeaef7e9"},
- {file = "pytorch_lightning-2.2.5-py3-none-any.whl", hash = "sha256:67a7800863326914f68f6afd68f427855ef2315b4f00d554be8ea4c0f0557fd8"},
-]
-
-[package.dependencies]
-fsspec = {version = ">=2022.5.0", extras = ["http"]}
-lightning-utilities = ">=0.8.0"
-numpy = ">=1.17.2"
-packaging = ">=20.0"
-PyYAML = ">=5.4"
-torch = ">=1.13.0"
-torchmetrics = ">=0.7.0"
-tqdm = ">=4.57.0"
-typing-extensions = ">=4.4.0"
-
-[package.extras]
-all = ["bitsandbytes (==0.41.0)", "deepspeed (>=0.8.2,<=0.9.3) ; platform_system != \"Windows\"", "gym[classic-control] (>=0.17.0)", "hydra-core (>=1.0.5)", "ipython[all] (<8.15.0)", "jsonargparse[signatures] (>=4.27.7)", "lightning-utilities (>=0.8.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "requests (<2.32.0)", "rich (>=12.3.0)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchvision (>=0.14.0)"]
-deepspeed = ["deepspeed (>=0.8.2,<=0.9.3) ; platform_system != \"Windows\""]
-dev = ["bitsandbytes (==0.41.0)", "cloudpickle (>=1.3)", "coverage (==7.3.1)", "deepspeed (>=0.8.2,<=0.9.3) ; platform_system != \"Windows\"", "fastapi", "gym[classic-control] (>=0.17.0)", "hydra-core (>=1.0.5)", "ipython[all] (<8.15.0)", "jsonargparse[signatures] (>=4.27.7)", "lightning-utilities (>=0.8.0)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "onnx (>=0.14.0)", "onnxruntime (>=0.15.0)", "pandas (>1.0)", "psutil (<5.9.6)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "requests (<2.32.0)", "rich (>=12.3.0)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "tensorboardX (>=2.2)", "torchmetrics (>=0.10.0)", "torchvision (>=0.14.0)", "uvicorn"]
-examples = ["gym[classic-control] (>=0.17.0)", "ipython[all] (<8.15.0)", "lightning-utilities (>=0.8.0)", "requests (<2.32.0)", "torchmetrics (>=0.10.0)", "torchvision (>=0.14.0)"]
-extra = ["bitsandbytes (==0.41.0)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.27.7)", "matplotlib (>3.1)", "omegaconf (>=2.0.5)", "rich (>=12.3.0)", "tensorboardX (>=2.2)"]
-strategies = ["deepspeed (>=0.8.2,<=0.9.3) ; platform_system != \"Windows\""]
-test = ["cloudpickle (>=1.3)", "coverage (==7.3.1)", "fastapi", "onnx (>=0.14.0)", "onnxruntime (>=0.15.0)", "pandas (>1.0)", "psutil (<5.9.6)", "pytest (==7.4.0)", "pytest-cov (==4.1.0)", "pytest-random-order (==1.1.0)", "pytest-rerunfailures (==12.0)", "pytest-timeout (==2.1.0)", "scikit-learn (>0.22.1)", "tensorboard (>=2.9.1)", "uvicorn"]
-
[[package]]
name = "pytorch-revgrad"
version = "0.2.0"
@@ -2413,7 +2163,7 @@ version = "2024.5.15"
description = "Alternative regular expression module, to replace re."
optional = false
python-versions = ">=3.8"
-groups = ["main", "dev"]
+groups = ["dev"]
files = [
{file = "regex-2024.5.15-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a81e3cfbae20378d75185171587cbf756015ccb14840702944f014e0d93ea09f"},
{file = "regex-2024.5.15-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7b59138b219ffa8979013be7bc85bb60c6f7b7575df3d56dc1e403a438c7a3f6"},
@@ -2543,7 +2293,7 @@ version = "0.4.3"
description = ""
optional = false
python-versions = ">=3.7"
-groups = ["main", "dev"]
+groups = ["dev"]
files = [
{file = "safetensors-0.4.3-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:dcf5705cab159ce0130cd56057f5f3425023c407e170bca60b4868048bae64fd"},
{file = "safetensors-0.4.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bb4f8c5d0358a31e9a08daeebb68f5e161cdd4018855426d3f0c23bb51087055"},
@@ -2845,7 +2595,8 @@ version = "70.0.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
optional = false
python-versions = ">=3.8"
-groups = ["main"]
+groups = ["main", "dev"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" or python_version >= \"3.12\""
files = [
{file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"},
{file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"},
@@ -2918,18 +2669,21 @@ files = [
[[package]]
name = "sympy"
-version = "1.12.1"
+version = "1.14.0"
description = "Computer algebra system (CAS) in Python"
optional = false
-python-versions = ">=3.8"
+python-versions = ">=3.9"
groups = ["main", "dev"]
files = [
- {file = "sympy-1.12.1-py3-none-any.whl", hash = "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515"},
- {file = "sympy-1.12.1.tar.gz", hash = "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88"},
+ {file = "sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5"},
+ {file = "sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517"},
]
[package.dependencies]
-mpmath = ">=1.1.0,<1.4.0"
+mpmath = ">=1.1.0,<1.4"
+
+[package.extras]
+dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"]
[[package]]
name = "tabulate"
@@ -2946,21 +2700,6 @@ files = [
[package.extras]
widechars = ["wcwidth"]
-[[package]]
-name = "tbb"
-version = "2021.12.0"
-description = "Intel® oneAPI Threading Building Blocks (oneTBB)"
-optional = false
-python-versions = "*"
-groups = ["main", "dev"]
-markers = "platform_system == \"Windows\""
-files = [
- {file = "tbb-2021.12.0-py2.py3-none-manylinux1_i686.whl", hash = "sha256:f2cc9a7f8ababaa506cbff796ce97c3bf91062ba521e15054394f773375d81d8"},
- {file = "tbb-2021.12.0-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:a925e9a7c77d3a46ae31c34b0bb7f801c4118e857d137b68f68a8e458fcf2bd7"},
- {file = "tbb-2021.12.0-py3-none-win32.whl", hash = "sha256:b1725b30c174048edc8be70bd43bb95473f396ce895d91151a474d0fa9f450a8"},
- {file = "tbb-2021.12.0-py3-none-win_amd64.whl", hash = "sha256:fc2772d850229f2f3df85f1109c4844c495a2db7433d38200959ee9265b34789"},
-]
-
[[package]]
name = "threadpoolctl"
version = "3.5.0"
@@ -2979,7 +2718,7 @@ version = "0.15.2"
description = ""
optional = false
python-versions = ">=3.7"
-groups = ["main", "dev"]
+groups = ["dev"]
files = [
{file = "tokenizers-0.15.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:52f6130c9cbf70544287575a985bf44ae1bda2da7e8c24e97716080593638012"},
{file = "tokenizers-0.15.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:054c1cc9c6d68f7ffa4e810b3d5131e0ba511b6e4be34157aa08ee54c2f8d9ee"},
@@ -3116,71 +2855,65 @@ files = [
[[package]]
name = "torch"
-version = "2.3.0+cpu"
+version = "2.7.1"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
optional = false
-python-versions = ">=3.8.0"
+python-versions = ">=3.9.0"
groups = ["main", "dev"]
files = [
- {file = "torch-2.3.0+cpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:e3c220702d82c7596924150e0499fbbffcf62a88a59adc860fa357cd8dc1c302"},
- {file = "torch-2.3.0+cpu-cp310-cp310-win_amd64.whl", hash = "sha256:ab0c05525195b8fecdf2ea75968ed32ccd87dff16381b6e13249babb4a9596ff"},
- {file = "torch-2.3.0+cpu-cp311-cp311-linux_x86_64.whl", hash = "sha256:97a38b25ee0e3d020691e7846efbca62a3d8a57645c027dcb5ba0adfec36fe55"},
- {file = "torch-2.3.0+cpu-cp311-cp311-win_amd64.whl", hash = "sha256:a8ac195974be6f067245bae8156b8c06fb0a723b0eed8f2e244b5dd58c7e2a49"},
- {file = "torch-2.3.0+cpu-cp312-cp312-linux_x86_64.whl", hash = "sha256:a8982e52185771591dad577a124a7770f72f288f8ae5833317b1e329c0d2f07e"},
- {file = "torch-2.3.0+cpu-cp312-cp312-win_amd64.whl", hash = "sha256:483131a7997995d867313ee902743084e844e830ab2a0c5e079c61ec2da3cd17"},
- {file = "torch-2.3.0+cpu-cp38-cp38-linux_x86_64.whl", hash = "sha256:8c52484880d5fbe511cffc255dd34847ddeced3f94334c6bf7eb2b0445f10cb4"},
- {file = "torch-2.3.0+cpu-cp38-cp38-win_amd64.whl", hash = "sha256:28a11bcc0d709b397d675cff689707019b8cc122e6bf328b57b900f47c36f156"},
- {file = "torch-2.3.0+cpu-cp39-cp39-linux_x86_64.whl", hash = "sha256:1e86e225e472392440ace378ba3165b5e87648e8b5fbf16adc41c0df881c38b8"},
- {file = "torch-2.3.0+cpu-cp39-cp39-win_amd64.whl", hash = "sha256:5c2afdff80203eaabf4c223a294c2f465020b3360e8e87f76b52ace9c5801ebe"},
+ {file = "torch-2.7.1-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:a103b5d782af5bd119b81dbcc7ffc6fa09904c423ff8db397a1e6ea8fd71508f"},
+ {file = "torch-2.7.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:fe955951bdf32d182ee8ead6c3186ad54781492bf03d547d31771a01b3d6fb7d"},
+ {file = "torch-2.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:885453d6fba67d9991132143bf7fa06b79b24352f4506fd4d10b309f53454162"},
+ {file = "torch-2.7.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:d72acfdb86cee2a32c0ce0101606f3758f0d8bb5f8f31e7920dc2809e963aa7c"},
+ {file = "torch-2.7.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:236f501f2e383f1cb861337bdf057712182f910f10aeaf509065d54d339e49b2"},
+ {file = "torch-2.7.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:06eea61f859436622e78dd0cdd51dbc8f8c6d76917a9cf0555a333f9eac31ec1"},
+ {file = "torch-2.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:8273145a2e0a3c6f9fd2ac36762d6ee89c26d430e612b95a99885df083b04e52"},
+ {file = "torch-2.7.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:aea4fc1bf433d12843eb2c6b2204861f43d8364597697074c8d38ae2507f8730"},
+ {file = "torch-2.7.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:27ea1e518df4c9de73af7e8a720770f3628e7f667280bce2be7a16292697e3fa"},
+ {file = "torch-2.7.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:c33360cfc2edd976c2633b3b66c769bdcbbf0e0b6550606d188431c81e7dd1fc"},
+ {file = "torch-2.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:d8bf6e1856ddd1807e79dc57e54d3335f2b62e6f316ed13ed3ecfe1fc1df3d8b"},
+ {file = "torch-2.7.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:787687087412c4bd68d315e39bc1223f08aae1d16a9e9771d95eabbb04ae98fb"},
+ {file = "torch-2.7.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:03563603d931e70722dce0e11999d53aa80a375a3d78e6b39b9f6805ea0a8d28"},
+ {file = "torch-2.7.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:d632f5417b6980f61404a125b999ca6ebd0b8b4bbdbb5fbbba44374ab619a412"},
+ {file = "torch-2.7.1-cp313-cp313-win_amd64.whl", hash = "sha256:23660443e13995ee93e3d844786701ea4ca69f337027b05182f5ba053ce43b38"},
+ {file = "torch-2.7.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:0da4f4dba9f65d0d203794e619fe7ca3247a55ffdcbd17ae8fb83c8b2dc9b585"},
+ {file = "torch-2.7.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:e08d7e6f21a617fe38eeb46dd2213ded43f27c072e9165dc27300c9ef9570934"},
+ {file = "torch-2.7.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:30207f672328a42df4f2174b8f426f354b2baa0b7cca3a0adb3d6ab5daf00dc8"},
+ {file = "torch-2.7.1-cp313-cp313t-win_amd64.whl", hash = "sha256:79042feca1c634aaf6603fe6feea8c6b30dfa140a6bbc0b973e2260c7e79a22e"},
+ {file = "torch-2.7.1-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:988b0cbc4333618a1056d2ebad9eb10089637b659eb645434d0809d8d937b946"},
+ {file = "torch-2.7.1-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:e0d81e9a12764b6f3879a866607c8ae93113cbcad57ce01ebde63eb48a576369"},
+ {file = "torch-2.7.1-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:8394833c44484547ed4a47162318337b88c97acdb3273d85ea06e03ffff44998"},
+ {file = "torch-2.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:df41989d9300e6e3c19ec9f56f856187a6ef060c3662fe54f4b6baf1fc90bd19"},
+ {file = "torch-2.7.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:a737b5edd1c44a5c1ece2e9f3d00df9d1b3fb9541138bee56d83d38293fb6c9d"},
]
[package.dependencies]
filelock = "*"
fsspec = "*"
jinja2 = "*"
-mkl = {version = ">=2021.1.1,<=2021.4.0", markers = "platform_system == \"Windows\""}
networkx = "*"
-sympy = "*"
-typing-extensions = ">=4.8.0"
+nvidia-cublas-cu12 = {version = "12.6.4.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cuda-cupti-cu12 = {version = "12.6.80", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cuda-nvrtc-cu12 = {version = "12.6.77", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cuda-runtime-cu12 = {version = "12.6.77", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cudnn-cu12 = {version = "9.5.1.17", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cufft-cu12 = {version = "11.3.0.4", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cufile-cu12 = {version = "1.11.1.6", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-curand-cu12 = {version = "10.3.7.77", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cusolver-cu12 = {version = "11.7.1.2", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cusparse-cu12 = {version = "12.5.4.2", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-cusparselt-cu12 = {version = "0.6.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-nccl-cu12 = {version = "2.26.2", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-nvjitlink-cu12 = {version = "12.6.85", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+nvidia-nvtx-cu12 = {version = "12.6.77", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+setuptools = {version = "*", markers = "python_version >= \"3.12\""}
+sympy = ">=1.13.3"
+triton = {version = "3.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
+typing-extensions = ">=4.10.0"
[package.extras]
opt-einsum = ["opt-einsum (>=3.3)"]
-optree = ["optree (>=0.9.1)"]
-
-[package.source]
-type = "legacy"
-url = "https://download.pytorch.org/whl/cpu"
-reference = "pytorch"
-
-[[package]]
-name = "torchmetrics"
-version = "1.4.0.post0"
-description = "PyTorch native Metrics"
-optional = false
-python-versions = ">=3.8"
-groups = ["main"]
-files = [
- {file = "torchmetrics-1.4.0.post0-py3-none-any.whl", hash = "sha256:ab234216598e3fbd8d62ee4541a0e74e7e8fc935d099683af5b8da50f745b3c8"},
- {file = "torchmetrics-1.4.0.post0.tar.gz", hash = "sha256:ab9bcfe80e65dbabbddb6cecd9be21f1f1d5207bb74051ef95260740f2762358"},
-]
-
-[package.dependencies]
-lightning-utilities = ">=0.8.0"
-numpy = ">1.20.0"
-packaging = ">17.1"
-torch = ">=1.10.0"
-
-[package.extras]
-all = ["SciencePlots (>=2.0.0)", "ipadic (>=1.0.0)", "matplotlib (>=3.3.0)", "mecab-python3 (>=1.0.6)", "mypy (==1.9.0)", "nltk (>=3.6)", "piq (<=0.8.0)", "pretty-errors (>=1.2.0)", "pycocotools (>2.0.0)", "pystoi (>=0.3.0)", "regex (>=2021.9.24)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "torch (==2.3.0)", "torch-fidelity (<=0.4.0)", "torchaudio (>=0.10.0)", "torchvision (>=0.8)", "tqdm (>=4.41.0)", "transformers (>4.4.0)", "transformers (>=4.10.0)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"]
-audio = ["pystoi (>=0.3.0)", "torchaudio (>=0.10.0)"]
-debug = ["pretty-errors (>=1.2.0)"]
-detection = ["pycocotools (>2.0.0)", "torchvision (>=0.8)"]
-dev = ["SciencePlots (>=2.0.0)", "bert-score (==0.3.13)", "dython (<=0.7.5)", "fairlearn", "fast-bss-eval (>=0.1.0)", "faster-coco-eval (>=1.3.3)", "huggingface-hub (<0.23)", "ipadic (>=1.0.0)", "jiwer (>=2.3.0)", "kornia (>=0.6.7)", "lpips (<=0.1.4)", "matplotlib (>=3.3.0)", "mecab-ko (>=1.0.0)", "mecab-ko-dic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "mir-eval (>=0.6)", "monai (==1.3.0)", "mypy (==1.9.0)", "netcal (>1.0.0)", "nltk (>=3.6)", "numpy (<1.27.0)", "pandas (>1.0.0)", "pandas (>=1.4.0)", "piq (<=0.8.0)", "pretty-errors (>=1.2.0)", "pycocotools (>2.0.0)", "pystoi (>=0.3.0)", "pytorch-msssim (==1.0.0)", "regex (>=2021.9.24)", "rouge-score (>0.1.0)", "sacrebleu (>=2.3.0)", "scikit-image (>=0.19.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "sewar (>=0.4.4)", "statsmodels (>0.13.5)", "torch (==2.3.0)", "torch-complex (<=0.4.3)", "torch-fidelity (<=0.4.0)", "torchaudio (>=0.10.0)", "torchvision (>=0.8)", "tqdm (>=4.41.0)", "transformers (>4.4.0)", "transformers (>=4.10.0)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"]
-image = ["scipy (>1.0.0)", "torch-fidelity (<=0.4.0)", "torchvision (>=0.8)"]
-multimodal = ["piq (<=0.8.0)", "transformers (>=4.10.0)"]
-text = ["ipadic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "nltk (>=3.6)", "regex (>=2021.9.24)", "sentencepiece (>=0.2.0)", "tqdm (>=4.41.0)", "transformers (>4.4.0)"]
-typing = ["mypy (==1.9.0)", "torch (==2.3.0)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"]
-visual = ["SciencePlots (>=2.0.0)", "matplotlib (>=3.3.0)"]
+optree = ["optree (>=0.13.0)"]
[[package]]
name = "tqdm"
@@ -3226,7 +2959,7 @@ version = "4.36.2"
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
optional = false
python-versions = ">=3.8.0"
-groups = ["main", "dev"]
+groups = ["dev"]
files = [
{file = "transformers-4.36.2-py3-none-any.whl", hash = "sha256:462066c4f74ee52516f12890dcc9ec71d1a5e97998db621668455117a54330f6"},
{file = "transformers-4.36.2.tar.gz", hash = "sha256:d8068e897e47793281501e547d2bbdfc5b8556409c2cb6c3d9e2ca77d4c0b4ec"},
@@ -3293,6 +3026,31 @@ torchhub = ["filelock", "huggingface-hub (>=0.19.3,<1.0)", "importlib-metadata",
video = ["av (==9.2.0)", "decord (==0.6.0)"]
vision = ["Pillow (>=10.0.1,<=15.0)"]
+[[package]]
+name = "triton"
+version = "3.3.1"
+description = "A language and compiler for custom Deep Learning operations"
+optional = false
+python-versions = "*"
+groups = ["main", "dev"]
+markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""
+files = [
+ {file = "triton-3.3.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b74db445b1c562844d3cfad6e9679c72e93fdfb1a90a24052b03bb5c49d1242e"},
+ {file = "triton-3.3.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b31e3aa26f8cb3cc5bf4e187bf737cbacf17311e1112b781d4a059353dfd731b"},
+ {file = "triton-3.3.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9999e83aba21e1a78c1f36f21bce621b77bcaa530277a50484a7cb4a822f6e43"},
+ {file = "triton-3.3.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b89d846b5a4198317fec27a5d3a609ea96b6d557ff44b56c23176546023c4240"},
+ {file = "triton-3.3.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3198adb9d78b77818a5388bff89fa72ff36f9da0bc689db2f0a651a67ce6a42"},
+ {file = "triton-3.3.1-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f6139aeb04a146b0b8e0fbbd89ad1e65861c57cfed881f21d62d3cb94a36bab7"},
+]
+
+[package.dependencies]
+setuptools = ">=40.8.0"
+
+[package.extras]
+build = ["cmake (>=3.20)", "lit"]
+tests = ["autopep8", "isort", "llnl-hatchet", "numpy", "pytest", "pytest-forked", "pytest-xdist", "scipy (>=1.7.1)"]
+tutorials = ["matplotlib", "pandas", "tabulate"]
+
[[package]]
name = "typing-extensions"
version = "4.12.0"
@@ -3462,110 +3220,6 @@ files = [
{file = "wrapt-1.16.0.tar.gz", hash = "sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d"},
]
-[[package]]
-name = "yarl"
-version = "1.9.4"
-description = "Yet another URL library"
-optional = false
-python-versions = ">=3.7"
-groups = ["main"]
-files = [
- {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"},
- {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a3a6ed1d525bfb91b3fc9b690c5a21bb52de28c018530ad85093cc488bee2dd2"},
- {file = "yarl-1.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c38c9ddb6103ceae4e4498f9c08fac9b590c5c71b0370f98714768e22ac6fa66"},
- {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9e09c9d74f4566e905a0b8fa668c58109f7624db96a2171f21747abc7524234"},
- {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8477c1ee4bd47c57d49621a062121c3023609f7a13b8a46953eb6c9716ca392"},
- {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5ff2c858f5f6a42c2a8e751100f237c5e869cbde669a724f2062d4c4ef93551"},
- {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:357495293086c5b6d34ca9616a43d329317feab7917518bc97a08f9e55648455"},
- {file = "yarl-1.9.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54525ae423d7b7a8ee81ba189f131054defdb122cde31ff17477951464c1691c"},
- {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:801e9264d19643548651b9db361ce3287176671fb0117f96b5ac0ee1c3530d53"},
- {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e516dc8baf7b380e6c1c26792610230f37147bb754d6426462ab115a02944385"},
- {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:7d5aaac37d19b2904bb9dfe12cdb08c8443e7ba7d2852894ad448d4b8f442863"},
- {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:54beabb809ffcacbd9d28ac57b0db46e42a6e341a030293fb3185c409e626b8b"},
- {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bac8d525a8dbc2a1507ec731d2867025d11ceadcb4dd421423a5d42c56818541"},
- {file = "yarl-1.9.4-cp310-cp310-win32.whl", hash = "sha256:7855426dfbddac81896b6e533ebefc0af2f132d4a47340cee6d22cac7190022d"},
- {file = "yarl-1.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:848cd2a1df56ddbffeb375535fb62c9d1645dde33ca4d51341378b3f5954429b"},
- {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:35a2b9396879ce32754bd457d31a51ff0a9d426fd9e0e3c33394bf4b9036b099"},
- {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c7d56b293cc071e82532f70adcbd8b61909eec973ae9d2d1f9b233f3d943f2c"},
- {file = "yarl-1.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d8a1c6c0be645c745a081c192e747c5de06e944a0d21245f4cf7c05e457c36e0"},
- {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b3c1ffe10069f655ea2d731808e76e0f452fc6c749bea04781daf18e6039525"},
- {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:549d19c84c55d11687ddbd47eeb348a89df9cb30e1993f1b128f4685cd0ebbf8"},
- {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7409f968456111140c1c95301cadf071bd30a81cbd7ab829169fb9e3d72eae9"},
- {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e23a6d84d9d1738dbc6e38167776107e63307dfc8ad108e580548d1f2c587f42"},
- {file = "yarl-1.9.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8b889777de69897406c9fb0b76cdf2fd0f31267861ae7501d93003d55f54fbe"},
- {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:03caa9507d3d3c83bca08650678e25364e1843b484f19986a527630ca376ecce"},
- {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e9035df8d0880b2f1c7f5031f33f69e071dfe72ee9310cfc76f7b605958ceb9"},
- {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c0ec0ed476f77db9fb29bca17f0a8fcc7bc97ad4c6c1d8959c507decb22e8572"},
- {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:ee04010f26d5102399bd17f8df8bc38dc7ccd7701dc77f4a68c5b8d733406958"},
- {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49a180c2e0743d5d6e0b4d1a9e5f633c62eca3f8a86ba5dd3c471060e352ca98"},
- {file = "yarl-1.9.4-cp311-cp311-win32.whl", hash = "sha256:81eb57278deb6098a5b62e88ad8281b2ba09f2f1147c4767522353eaa6260b31"},
- {file = "yarl-1.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:d1d2532b340b692880261c15aee4dc94dd22ca5d61b9db9a8a361953d36410b1"},
- {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0d2454f0aef65ea81037759be5ca9947539667eecebca092733b2eb43c965a81"},
- {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:44d8ffbb9c06e5a7f529f38f53eda23e50d1ed33c6c869e01481d3fafa6b8142"},
- {file = "yarl-1.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aaaea1e536f98754a6e5c56091baa1b6ce2f2700cc4a00b0d49eca8dea471074"},
- {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3777ce5536d17989c91696db1d459574e9a9bd37660ea7ee4d3344579bb6f129"},
- {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc5fc1eeb029757349ad26bbc5880557389a03fa6ada41703db5e068881e5f2"},
- {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea65804b5dc88dacd4a40279af0cdadcfe74b3e5b4c897aa0d81cf86927fee78"},
- {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa102d6d280a5455ad6a0f9e6d769989638718e938a6a0a2ff3f4a7ff8c62cc4"},
- {file = "yarl-1.9.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09efe4615ada057ba2d30df871d2f668af661e971dfeedf0c159927d48bbeff0"},
- {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:008d3e808d03ef28542372d01057fd09168419cdc8f848efe2804f894ae03e51"},
- {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6f5cb257bc2ec58f437da2b37a8cd48f666db96d47b8a3115c29f316313654ff"},
- {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:992f18e0ea248ee03b5a6e8b3b4738850ae7dbb172cc41c966462801cbf62cf7"},
- {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0e9d124c191d5b881060a9e5060627694c3bdd1fe24c5eecc8d5d7d0eb6faabc"},
- {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3986b6f41ad22988e53d5778f91855dc0399b043fc8946d4f2e68af22ee9ff10"},
- {file = "yarl-1.9.4-cp312-cp312-win32.whl", hash = "sha256:4b21516d181cd77ebd06ce160ef8cc2a5e9ad35fb1c5930882baff5ac865eee7"},
- {file = "yarl-1.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:a9bd00dc3bc395a662900f33f74feb3e757429e545d831eef5bb280252631984"},
- {file = "yarl-1.9.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:63b20738b5aac74e239622d2fe30df4fca4942a86e31bf47a81a0e94c14df94f"},
- {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7d7f7de27b8944f1fee2c26a88b4dabc2409d2fea7a9ed3df79b67277644e17"},
- {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c74018551e31269d56fab81a728f683667e7c28c04e807ba08f8c9e3bba32f14"},
- {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ca06675212f94e7a610e85ca36948bb8fc023e458dd6c63ef71abfd482481aa5"},
- {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5aef935237d60a51a62b86249839b51345f47564208c6ee615ed2a40878dccdd"},
- {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b134fd795e2322b7684155b7855cc99409d10b2e408056db2b93b51a52accc7"},
- {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d25039a474c4c72a5ad4b52495056f843a7ff07b632c1b92ea9043a3d9950f6e"},
- {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f7d6b36dd2e029b6bcb8a13cf19664c7b8e19ab3a58e0fefbb5b8461447ed5ec"},
- {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:957b4774373cf6f709359e5c8c4a0af9f6d7875db657adb0feaf8d6cb3c3964c"},
- {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d7eeb6d22331e2fd42fce928a81c697c9ee2d51400bd1a28803965883e13cead"},
- {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:6a962e04b8f91f8c4e5917e518d17958e3bdee71fd1d8b88cdce74dd0ebbf434"},
- {file = "yarl-1.9.4-cp37-cp37m-win32.whl", hash = "sha256:f3bc6af6e2b8f92eced34ef6a96ffb248e863af20ef4fde9448cc8c9b858b749"},
- {file = "yarl-1.9.4-cp37-cp37m-win_amd64.whl", hash = "sha256:ad4d7a90a92e528aadf4965d685c17dacff3df282db1121136c382dc0b6014d2"},
- {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ec61d826d80fc293ed46c9dd26995921e3a82146feacd952ef0757236fc137be"},
- {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8be9e837ea9113676e5754b43b940b50cce76d9ed7d2461df1af39a8ee674d9f"},
- {file = "yarl-1.9.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bef596fdaa8f26e3d66af846bbe77057237cb6e8efff8cd7cc8dff9a62278bbf"},
- {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d47552b6e52c3319fede1b60b3de120fe83bde9b7bddad11a69fb0af7db32f1"},
- {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84fc30f71689d7fc9168b92788abc977dc8cefa806909565fc2951d02f6b7d57"},
- {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa9741085f635934f3a2583e16fcf62ba835719a8b2b28fb2917bb0537c1dfa"},
- {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:206a55215e6d05dbc6c98ce598a59e6fbd0c493e2de4ea6cc2f4934d5a18d130"},
- {file = "yarl-1.9.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07574b007ee20e5c375a8fe4a0789fad26db905f9813be0f9fef5a68080de559"},
- {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5a2e2433eb9344a163aced6a5f6c9222c0786e5a9e9cac2c89f0b28433f56e23"},
- {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6ad6d10ed9b67a382b45f29ea028f92d25bc0bc1daf6c5b801b90b5aa70fb9ec"},
- {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:6fe79f998a4052d79e1c30eeb7d6c1c1056ad33300f682465e1b4e9b5a188b78"},
- {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a825ec844298c791fd28ed14ed1bffc56a98d15b8c58a20e0e08c1f5f2bea1be"},
- {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8619d6915b3b0b34420cf9b2bb6d81ef59d984cb0fde7544e9ece32b4b3043c3"},
- {file = "yarl-1.9.4-cp38-cp38-win32.whl", hash = "sha256:686a0c2f85f83463272ddffd4deb5e591c98aac1897d65e92319f729c320eece"},
- {file = "yarl-1.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:a00862fb23195b6b8322f7d781b0dc1d82cb3bcac346d1e38689370cc1cc398b"},
- {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:604f31d97fa493083ea21bd9b92c419012531c4e17ea6da0f65cacdcf5d0bd27"},
- {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8a854227cf581330ffa2c4824d96e52ee621dd571078a252c25e3a3b3d94a1b1"},
- {file = "yarl-1.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ba6f52cbc7809cd8d74604cce9c14868306ae4aa0282016b641c661f981a6e91"},
- {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a6327976c7c2f4ee6816eff196e25385ccc02cb81427952414a64811037bbc8b"},
- {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8397a3817d7dcdd14bb266283cd1d6fc7264a48c186b986f32e86d86d35fbac5"},
- {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0381b4ce23ff92f8170080c97678040fc5b08da85e9e292292aba67fdac6c34"},
- {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23d32a2594cb5d565d358a92e151315d1b2268bc10f4610d098f96b147370136"},
- {file = "yarl-1.9.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ddb2a5c08a4eaaba605340fdee8fc08e406c56617566d9643ad8bf6852778fc7"},
- {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:26a1dc6285e03f3cc9e839a2da83bcbf31dcb0d004c72d0730e755b33466c30e"},
- {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:18580f672e44ce1238b82f7fb87d727c4a131f3a9d33a5e0e82b793362bf18b4"},
- {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:29e0f83f37610f173eb7e7b5562dd71467993495e568e708d99e9d1944f561ec"},
- {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:1f23e4fe1e8794f74b6027d7cf19dc25f8b63af1483d91d595d4a07eca1fb26c"},
- {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db8e58b9d79200c76956cefd14d5c90af54416ff5353c5bfd7cbe58818e26ef0"},
- {file = "yarl-1.9.4-cp39-cp39-win32.whl", hash = "sha256:c7224cab95645c7ab53791022ae77a4509472613e839dab722a72abe5a684575"},
- {file = "yarl-1.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:824d6c50492add5da9374875ce72db7a0733b29c2394890aef23d533106e2b15"},
- {file = "yarl-1.9.4-py3-none-any.whl", hash = "sha256:928cecb0ef9d5a7946eb6ff58417ad2fe9375762382f1bf5c55e61645f2c43ad"},
- {file = "yarl-1.9.4.tar.gz", hash = "sha256:566db86717cf8080b99b58b083b773a908ae40f06681e87e589a976faf8246bf"},
-]
-
-[package.dependencies]
-idna = ">=2.0"
-multidict = ">=4.0"
-
[[package]]
name = "zipp"
version = "3.19.2"
@@ -3586,4 +3240,4 @@ test = ["big-O", "importlib-resources ; python_version < \"3.9\"", "jaraco.funct
[metadata]
lock-version = "2.1"
python-versions = "^3.9"
-content-hash = "360e7128e1296a81a16070db02a7145bf9e807f3795bcfe56f4e1e453c976f6b"
+content-hash = "084124893c8d1054d2f99e0c3a8faa3346695641ecca8e97f2af90d38f842a84"
diff --git a/pyproject.toml b/pyproject.toml
index 9af0c843c..1b27be77c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "pie-modules"
version = "0.15.9"
-description = "Model and Taskmodule implementations for PyTorch-IE"
+description = "Utility modules for Python-IE"
authors = ["Arne Binder "]
readme = "README.md"
homepage = "https://github.com/arnebinder/pie-modules"
@@ -24,22 +24,15 @@ classifiers = [
[tool.poetry.dependencies]
python = "^3.9"
-# TODO: remove and use pie-core instead
-pytorch-ie = ">=0.31.9,<0.32.0"
-pytorch-lightning = "^2.1.0"
-torchmetrics = "^1"
-# >=4.35 because of BartModelWithDecoderPositionIds, <4.37 because of generation config
-# created from model config in BartAsPointerNetwork
-transformers = ">=4.35.0,<4.37.0"
+pie-core = ">=0.2.0,<0.3.0"
+# for show_as_markdown in metrics
+pandas = ">=2.0.3,<3.0.0"
[tool.poetry.group.dev.dependencies]
-torch = {version = "^2.1.0+cpu", source = "pytorch"}
pytest = "^7.4.2"
pytest-cov = "^4.1.0"
pre-commit = "^3.4.0"
tabulate = "^0.9"
-# for TokenClassificationModelWithSeq2SeqEncoderAndCrf
-pytorch-crf = ">=0.7.2"
# for rouge metric (tests only) and for NltkSentenceSplitter
nltk = "^3.8.1"
# for NltkSentenceSplitter
@@ -60,7 +53,6 @@ name = "pre-release"
url = "https://test.pypi.org/simple/"
priority = "explicit"
-
[tool.pytest.ini_options]
addopts = [
"--color=yes",
diff --git a/src/pie_modules/annotations.py b/src/pie_modules/annotations.py
index d7710595f..dfdbb586a 100644
--- a/src/pie_modules/annotations.py
+++ b/src/pie_modules/annotations.py
@@ -1,22 +1,192 @@
import dataclasses
-from typing import Optional
+from dataclasses import dataclass, field
+from typing import Any, Optional, Tuple
from pie_core import Annotation
-# re-export all annotations from pytorch_ie to have a single entry point
-from pytorch_ie.annotations import (
- BinaryRelation,
- Label,
- LabeledMultiSpan,
- LabeledSpan,
- MultiLabel,
- MultiLabeledBinaryRelation,
- MultiLabeledSpan,
- MultiSpan,
- NaryRelation,
- Span,
- _post_init_single_label,
-)
+
+def _post_init_single_label(self):
+ if not isinstance(self.label, str):
+ raise ValueError("label must be a single string.")
+
+ if not isinstance(self.score, float):
+ raise ValueError("score must be a single float.")
+
+
+def _post_init_multi_label(self):
+ if self.score is None:
+ score = tuple([1.0] * len(self.label))
+ object.__setattr__(self, "score", score)
+
+ if not isinstance(self.label, tuple):
+ object.__setattr__(self, "label", tuple(self.label))
+
+ if not isinstance(self.score, tuple):
+ object.__setattr__(self, "score", tuple(self.score))
+
+ if len(self.label) != len(self.score):
+ raise ValueError(
+ f"Number of labels ({len(self.label)}) and scores ({len(self.score)}) must be equal."
+ )
+
+
+def _post_init_multi_span(self):
+ if isinstance(self.slices, list):
+ object.__setattr__(self, "slices", tuple(tuple(s) for s in self.slices))
+
+
+def _post_init_arguments_and_roles(self):
+ if len(self.arguments) != len(self.roles):
+ raise ValueError(
+ f"Number of arguments ({len(self.arguments)}) and roles ({len(self.roles)}) must be equal"
+ )
+ if not isinstance(self.arguments, tuple):
+ object.__setattr__(self, "arguments", tuple(self.arguments))
+ if not isinstance(self.roles, tuple):
+ object.__setattr__(self, "roles", tuple(self.roles))
+
+
+@dataclass(eq=True, frozen=True)
+class Label(Annotation):
+ label: str
+ score: float = field(default=1.0, compare=False)
+
+ def __post_init__(self) -> None:
+ _post_init_single_label(self)
+
+ def resolve(self) -> Any:
+ return self.label
+
+
+@dataclass(eq=True, frozen=True)
+class MultiLabel(Annotation):
+ label: Tuple[str, ...]
+ score: Optional[Tuple[float, ...]] = field(default=None, compare=False)
+
+ def __post_init__(self) -> None:
+ _post_init_multi_label(self)
+
+ def resolve(self) -> Any:
+ return self.label
+
+
+@dataclass(eq=True, frozen=True)
+class Span(Annotation):
+ start: int
+ end: int
+
+ def __str__(self) -> str:
+ if not self.is_attached:
+ return super().__str__()
+ return str(self.target[self.start : self.end])
+
+ def resolve(self) -> Any:
+ if self.is_attached:
+ return self.target[self.start : self.end]
+ else:
+ raise ValueError(f"{self} is not attached to a target.")
+
+
+@dataclass(eq=True, frozen=True)
+class LabeledSpan(Span):
+ label: str
+ score: float = field(default=1.0, compare=False)
+
+ def __post_init__(self) -> None:
+ _post_init_single_label(self)
+
+ def resolve(self) -> Any:
+ return self.label, super().resolve()
+
+
+@dataclass(eq=True, frozen=True)
+class MultiLabeledSpan(Span):
+ label: Tuple[str, ...]
+ score: Optional[Tuple[float, ...]] = field(default=None, compare=False)
+
+ def __post_init__(self) -> None:
+ _post_init_multi_label(self)
+
+ def resolve(self) -> Any:
+ return self.label, super().resolve()
+
+
+@dataclass(eq=True, frozen=True)
+class MultiSpan(Annotation):
+ slices: Tuple[Tuple[int, int], ...]
+
+ def __post_init__(self) -> None:
+ _post_init_multi_span(self)
+
+ def __str__(self) -> str:
+ if not self.is_attached:
+ return super().__str__()
+ return str(tuple(self.target[start:end] for start, end in self.slices))
+
+ def resolve(self) -> Any:
+ if self.is_attached:
+ return tuple(self.target[start:end] for start, end in self.slices)
+ else:
+ raise ValueError(f"{self} is not attached to a target.")
+
+
+@dataclass(eq=True, frozen=True)
+class LabeledMultiSpan(MultiSpan):
+ label: str
+ score: float = field(default=1.0, compare=False)
+
+ def __post_init__(self) -> None:
+ super().__post_init__()
+ _post_init_single_label(self)
+
+ def resolve(self) -> Any:
+ return self.label, super().resolve()
+
+
+@dataclass(eq=True, frozen=True)
+class BinaryRelation(Annotation):
+ head: Annotation
+ tail: Annotation
+ label: str
+ score: float = field(default=1.0, compare=False)
+
+ def __post_init__(self) -> None:
+ _post_init_single_label(self)
+
+ def resolve(self) -> Any:
+ return self.label, (self.head.resolve(), self.tail.resolve())
+
+
+@dataclass(eq=True, frozen=True)
+class MultiLabeledBinaryRelation(Annotation):
+ head: Annotation
+ tail: Annotation
+ label: Tuple[str, ...]
+ score: Optional[Tuple[float, ...]] = field(default=None, compare=False)
+
+ def __post_init__(self) -> None:
+ _post_init_multi_label(self)
+
+ def resolve(self) -> Any:
+ return self.label, (self.head.resolve(), self.tail.resolve())
+
+
+@dataclass(eq=True, frozen=True)
+class NaryRelation(Annotation):
+ arguments: Tuple[Annotation, ...]
+ roles: Tuple[str, ...]
+ label: str
+ score: float = field(default=1.0, compare=False)
+
+ def __post_init__(self) -> None:
+ _post_init_arguments_and_roles(self)
+ _post_init_single_label(self)
+
+ def resolve(self) -> Any:
+ return (
+ self.label,
+ tuple((role, arg.resolve()) for arg, role in zip(self.arguments, self.roles)),
+ )
@dataclasses.dataclass(eq=True, frozen=True)
diff --git a/src/pie_modules/document/processing/__init__.py b/src/pie_modules/document/processing/__init__.py
index 296944c63..28d15ba1a 100644
--- a/src/pie_modules/document/processing/__init__.py
+++ b/src/pie_modules/document/processing/__init__.py
@@ -7,5 +7,4 @@
from .tokenization import (
text_based_document_to_token_based,
token_based_document_to_text_based,
- tokenize_document,
)
diff --git a/src/pie_modules/document/processing/text_span_trimmer.py b/src/pie_modules/document/processing/text_span_trimmer.py
index 412e01110..eb27a58f4 100644
--- a/src/pie_modules/document/processing/text_span_trimmer.py
+++ b/src/pie_modules/document/processing/text_span_trimmer.py
@@ -1,7 +1,7 @@
from __future__ import annotations
import logging
-from typing import TypeVar
+from typing import Any, Dict, TypeVar
from pie_core import AnnotationLayer, Document
@@ -45,6 +45,7 @@ def trim_text_spans(
text = spans.target
+ original_kwargs: dict[str, Any]
for span in spans:
if isinstance(span, Span):
starts_and_ends = [(span.start, span.end)]
@@ -99,6 +100,7 @@ def trim_text_spans(
)
removed_span_ids.append(span._id)
continue
+ new_kwargs: dict[str, Any]
if isinstance(span, Span):
if not len(new_starts_and_ends) == 1:
raise ValueError(f"Expected one span, got {len(new_starts_and_ends)}")
diff --git a/src/pie_modules/document/processing/tokenization.py b/src/pie_modules/document/processing/tokenization.py
index 12a8a6be3..c45b6ebc8 100644
--- a/src/pie_modules/document/processing/tokenization.py
+++ b/src/pie_modules/document/processing/tokenization.py
@@ -1,8 +1,6 @@
-import functools
-import json
import logging
from collections import defaultdict
-from copy import copy, deepcopy
+from copy import deepcopy
from typing import (
Callable,
Dict,
@@ -18,7 +16,6 @@
from pie_core import Annotation
from pie_core.utils.hydra import resolve_type
-from transformers import PreTrainedTokenizer
from pie_modules.annotations import MultiSpan, Span
from pie_modules.documents import TextBasedDocument, TokenBasedDocument
@@ -105,13 +102,13 @@ def char_span_to_token_span(
f"The first target of a text targeting span must be a string, but found {type(base_text)} as first "
f"target type. Can not convert the span {span}."
)
- stripped_slices = [
- get_stripped_offsets(start, end, base_text) for start, end in span.slices
- ]
+ stripped_slices = tuple(
+ [get_stripped_offsets(start, end, base_text) for start, end in span.slices]
+ )
else:
stripped_slices = span.slices
# remove empty and invalid slices
- stripped_slices = [(start, end) for start, end in stripped_slices if start < end]
+ stripped_slices = tuple([(start, end) for start, end in stripped_slices if start < end])
if len(stripped_slices) == 0:
return None
slices_inclusive_end = [
@@ -449,123 +446,3 @@ def token_based_document_to_text_based(
added_annotations.setdefault(layer_name, {}).update(annotation_mapping)
return result
-
-
-def tokenize_document(
- doc: TextBasedDocument,
- tokenizer: PreTrainedTokenizer,
- result_document_type: Type[ToD],
- partition_layer: Optional[str] = None,
- strip_spans: bool = False,
- strict_span_conversion: bool = True,
- added_annotations: Optional[List[Dict[str, Dict[Annotation, Annotation]]]] = None,
- verbose: bool = True,
- **tokenize_kwargs,
-) -> List[ToD]:
- """Tokenize a document with a given tokenizer and return a list of token based documents. The
- document is tokenized in partitions if a partition layer is provided. The annotations that
- target the text are converted to target the tokens and also all dependent annotations are
- converted.
-
- Args:
- doc (TextBasedDocument): The document to tokenize.
- tokenizer (PreTrainedTokenizer): The tokenizer.
- result_document_type (Type[ToD]): The exact type of the token based documents.
- partition_layer (Optional[str], optional): The layer to use for partitioning the document. If None, the whole
- document is tokenized. Defaults to None.
- strip_spans (bool, optional): If True, strip the whitespace from the character spans before converting them to
- token spans. Defaults to False.
- strict_span_conversion (bool, optional): If True, raise an error if not all annotations can be converted to
- token based documents. Defaults to True.
- added_annotations (Optional[List[Dict[str, Dict[Annotation, Annotation]]]], optional): Pass an empty list to
- collect the added annotations. Defaults to None.
- verbose (bool, optional): If True, log warnings if annotations can not be converted. Defaults to True.
-
- Returns:
- List[ToD]: The token based documents of type result_document_type with the converted annotations.
- """
-
- added_annotation_lists: Dict[str, List[Annotation]] = defaultdict(list)
- result = []
- partitions: Iterable[Span]
- if partition_layer is None:
- partitions = [Span(start=0, end=len(doc.text))]
- else:
- partitions = doc[partition_layer]
- for partition in partitions:
- text = doc.text[partition.start : partition.end]
- current_tokenize_kwargs = copy(tokenize_kwargs)
- if "text" in tokenize_kwargs:
- current_tokenize_kwargs["text_pair"] = text
- sequence_index = 1
- else:
- current_tokenize_kwargs["text"] = text
- sequence_index = 0
- tokenized_text = tokenizer(**current_tokenize_kwargs)
- for batch_encoding in tokenized_text.encodings:
- token_offset_mapping = batch_encoding.offsets
- char_to_token: Optional[Callable[[int], Optional[int]]]
- char_to_token = functools.partial(
- batch_encoding.char_to_token, sequence_index=sequence_index
- )
- token_offset_mapping = [
- offsets if s_id == sequence_index else (0, 0)
- for s_id, offsets in zip(batch_encoding.sequence_ids, token_offset_mapping)
- ]
- if partition.start > 0:
- token_offset_mapping = [
- (start + partition.start, end + partition.start)
- for start, end in token_offset_mapping
- ]
- char_to_token = None
- current_added_annotations: Dict[str, Dict[Annotation, Annotation]] = defaultdict(dict)
- tokenized_document = text_based_document_to_token_based(
- doc,
- tokens=batch_encoding.tokens,
- result_document_type=result_document_type,
- token_offset_mapping=token_offset_mapping,
- char_to_token=char_to_token,
- strict_span_conversion=False,
- strip_spans=strip_spans,
- verbose=False,
- added_annotations=current_added_annotations,
- )
- tokenized_document.metadata["tokenizer_encoding"] = batch_encoding
- result.append(tokenized_document)
- for k, v in current_added_annotations.items():
- added_annotation_lists[k].extend(v)
- if added_annotations is not None:
- added_annotations.append(current_added_annotations)
-
- missed_annotations = defaultdict(set)
- if strict_span_conversion or verbose:
- # We check the annotations with respect to the layers of the result_document_type.
- # Note that the original document may have more layers, but since result documents
- # are of type result_document_type, we only check the layers of this type.
- for annotation_field in result_document_type.annotation_fields():
- # do not check the partition layer because the partitions are not required later on
- # and entries get quite probably removed when windowing is applied, so this just pollutes the logs
- if annotation_field.name != partition_layer:
- current_missed_annotations = set(doc[annotation_field.name]) - set(
- added_annotation_lists[annotation_field.name]
- )
- if len(current_missed_annotations) > 0:
- missed_annotations[annotation_field.name] = current_missed_annotations
-
- if len(missed_annotations) > 0:
- missed_annotations_simplified = {k: str(v) for k, v in missed_annotations.items()}
- if strict_span_conversion:
- raise ValueError(
- f"could not convert all annotations from document with id={doc.id} to token based documents, "
- f"but strict_span_conversion is True, so raise an error, "
- f"missed annotations:\n{json.dumps(missed_annotations_simplified, sort_keys=True, indent=2)}"
- )
- else:
- if verbose:
- logger.warning(
- f"could not convert all annotations from document with id={doc.id} to token based documents, "
- f"missed annotations (disable this message with verbose=False):\n"
- f"{json.dumps(missed_annotations_simplified, sort_keys=True, indent=2)}"
- )
-
- return result
diff --git a/src/pie_modules/documents.py b/src/pie_modules/documents.py
index bf07de932..326621cdb 100644
--- a/src/pie_modules/documents.py
+++ b/src/pie_modules/documents.py
@@ -1,29 +1,8 @@
import dataclasses
+from typing import Any, Dict, Optional, Tuple
-from pie_core import AnnotationLayer, annotation_field
-
-# re-export all documents from pytorch_ie to have a single entry point
-from pytorch_ie.documents import (
- TextBasedDocument,
- TextDocumentWithLabel,
- TextDocumentWithLabeledMultiSpans,
- TextDocumentWithLabeledMultiSpansAndBinaryRelations,
- TextDocumentWithLabeledMultiSpansAndLabeledPartitions,
- TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
- TextDocumentWithLabeledPartitions,
- TextDocumentWithLabeledSpans,
- TextDocumentWithLabeledSpansAndBinaryRelations,
- TextDocumentWithLabeledSpansAndLabeledPartitions,
- TextDocumentWithLabeledSpansAndSentences,
- TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
- TextDocumentWithMultiLabel,
- TextDocumentWithSentences,
- TextDocumentWithSpans,
- TextDocumentWithSpansAndBinaryRelations,
- TextDocumentWithSpansAndLabeledPartitions,
- TextDocumentWithSpansBinaryRelationsAndLabeledPartitions,
- TokenBasedDocument,
-)
+from pie_core import AnnotationLayer, Document, annotation_field
+from typing_extensions import TypeAlias
from pie_modules.annotations import (
AbstractiveSummary,
@@ -31,12 +10,172 @@
BinaryRelation,
ExtractiveAnswer,
GenerativeAnswer,
+ Label,
LabeledMultiSpan,
LabeledSpan,
+ MultiLabel,
Question,
+ Span,
)
+@dataclasses.dataclass
+class WithMetadata:
+ id: Optional[str] = None
+ metadata: Dict[str, Any] = dataclasses.field(default_factory=dict)
+
+
+@dataclasses.dataclass
+class WithTokens:
+ tokens: Tuple[str, ...]
+
+
+@dataclasses.dataclass
+class WithText:
+ text: str
+
+
+@dataclasses.dataclass
+class TextBasedDocument(WithMetadata, WithText, Document):
+ pass
+
+
+@dataclasses.dataclass
+class TokenBasedDocument(WithMetadata, WithTokens, Document):
+ def __post_init__(self) -> None:
+
+        # When used in a dataset, the document gets serialized to a JSON-like structure, which has no tuples,
+        # so they get converted to lists. This workaround automatically converts the "tokens" back to a tuple
+        # when the document is created from a dataset.
+ if isinstance(self.tokens, list):
+ object.__setattr__(self, "tokens", tuple(self.tokens))
+ elif not isinstance(self.tokens, tuple):
+ raise ValueError("tokens must be a tuple.")
+
+ # Call the default document construction code
+ super().__post_init__()
+
+
+# backwards compatibility
+TextDocument: TypeAlias = TextBasedDocument
+
+
+@dataclasses.dataclass
+class DocumentWithLabel(Document):
+ label: AnnotationLayer[Label] = annotation_field()
+
+
+@dataclasses.dataclass
+class DocumentWithMultiLabel(Document):
+ label: AnnotationLayer[MultiLabel] = annotation_field()
+
+
+@dataclasses.dataclass
+class TextDocumentWithLabel(DocumentWithLabel, TextBasedDocument):
+ pass
+
+
+@dataclasses.dataclass
+class TextDocumentWithMultiLabel(DocumentWithMultiLabel, TextBasedDocument):
+ pass
+
+
+@dataclasses.dataclass
+class TextDocumentWithLabeledPartitions(TextBasedDocument):
+ labeled_partitions: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
+
+
+@dataclasses.dataclass
+class TextDocumentWithSentences(TextBasedDocument):
+ sentences: AnnotationLayer[Span] = annotation_field(target="text")
+
+
+@dataclasses.dataclass
+class TextDocumentWithSpans(TextBasedDocument):
+ spans: AnnotationLayer[Span] = annotation_field(target="text")
+
+
+@dataclasses.dataclass
+class TextDocumentWithLabeledSpans(TextBasedDocument):
+ labeled_spans: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
+
+
+@dataclasses.dataclass
+class TextDocumentWithLabeledSpansAndLabeledPartitions(
+ TextDocumentWithLabeledSpans, TextDocumentWithLabeledPartitions
+):
+ pass
+
+
+@dataclasses.dataclass
+class TextDocumentWithLabeledSpansAndSentences(
+ TextDocumentWithLabeledSpans, TextDocumentWithSentences
+):
+ pass
+
+
+@dataclasses.dataclass
+class TextDocumentWithLabeledSpansAndBinaryRelations(TextDocumentWithLabeledSpans):
+ binary_relations: AnnotationLayer[BinaryRelation] = annotation_field(target="labeled_spans")
+
+
+@dataclasses.dataclass
+class TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
+ TextDocumentWithLabeledSpansAndLabeledPartitions,
+ TextDocumentWithLabeledSpansAndBinaryRelations,
+ TextDocumentWithLabeledPartitions,
+):
+ pass
+
+
+@dataclasses.dataclass
+class TextDocumentWithSpansAndBinaryRelations(TextDocumentWithSpans):
+ binary_relations: AnnotationLayer[BinaryRelation] = annotation_field(target="spans")
+
+
+@dataclasses.dataclass
+class TextDocumentWithSpansAndLabeledPartitions(
+ TextDocumentWithSpans, TextDocumentWithLabeledPartitions
+):
+ pass
+
+
+@dataclasses.dataclass
+class TextDocumentWithSpansBinaryRelationsAndLabeledPartitions(
+ TextDocumentWithSpansAndLabeledPartitions,
+ TextDocumentWithSpansAndBinaryRelations,
+ TextDocumentWithLabeledPartitions,
+):
+ pass
+
+
+@dataclasses.dataclass
+class TextDocumentWithLabeledMultiSpans(TextBasedDocument):
+ labeled_multi_spans: AnnotationLayer[LabeledMultiSpan] = annotation_field(target="text")
+
+
+@dataclasses.dataclass
+class TextDocumentWithLabeledMultiSpansAndLabeledPartitions(
+ TextDocumentWithLabeledMultiSpans, TextDocumentWithLabeledPartitions
+):
+ pass
+
+
+@dataclasses.dataclass
+class TextDocumentWithLabeledMultiSpansAndBinaryRelations(TextDocumentWithLabeledMultiSpans):
+ binary_relations: AnnotationLayer[BinaryRelation] = annotation_field(
+ target="labeled_multi_spans"
+ )
+
+
+@dataclasses.dataclass
+class TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions(
+ TextDocumentWithLabeledMultiSpansAndLabeledPartitions,
+ TextDocumentWithLabeledMultiSpansAndBinaryRelations,
+):
+ pass
+
+
@dataclasses.dataclass
class TextDocumentWithQuestionsAndExtractiveAnswers(TextBasedDocument):
"""A text based PIE document with annotations for extractive question answering."""
@@ -66,7 +205,6 @@ class TokenDocumentWithQuestionsAndExtractiveAnswers(TokenBasedDocument):
# backwards compatibility
ExtractiveQADocument = TextDocumentWithQuestionsAndExtractiveAnswers
TokenizedExtractiveQADocument = TokenDocumentWithQuestionsAndExtractiveAnswers
-TextDocument = TextBasedDocument
@dataclasses.dataclass
diff --git a/src/pie_modules/metrics/__init__.py b/src/pie_modules/metrics/__init__.py
index 7038ade07..674d2ef0c 100644
--- a/src/pie_modules/metrics/__init__.py
+++ b/src/pie_modules/metrics/__init__.py
@@ -8,5 +8,4 @@
FieldLengthCollector,
LabelCountCollector,
SubFieldLengthCollector,
- TokenCountCollector,
)
diff --git a/src/pie_modules/metrics/relation_argument_distance_collector.py b/src/pie_modules/metrics/relation_argument_distance_collector.py
index 935c11cbc..138d3ede2 100644
--- a/src/pie_modules/metrics/relation_argument_distance_collector.py
+++ b/src/pie_modules/metrics/relation_argument_distance_collector.py
@@ -1,13 +1,9 @@
from collections import defaultdict
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Dict, List
from pie_core import Document, DocumentStatistic
-from pie_core.utils.hydra import resolve_target
-from transformers import AutoTokenizer, PreTrainedTokenizer
from pie_modules.annotations import BinaryRelation, NaryRelation, Span
-from pie_modules.document.processing import tokenize_document
-from pie_modules.documents import TextBasedDocument, TokenBasedDocument
from pie_modules.utils.span import distance
@@ -33,10 +29,6 @@ def __init__(
self,
layer: str,
distance_type: str = "outer",
- tokenize: bool = False,
- tokenize_kwargs: Optional[Dict[str, Any]] = None,
- tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
- tokenized_document_type: Optional[Union[str, Type[TokenBasedDocument]]] = None,
key_all: str = "ALL",
**kwargs,
):
@@ -44,45 +36,16 @@ def __init__(
self.layer = layer
self.distance_type = distance_type
self.key_all = key_all
- self.tokenize = tokenize
- self.tokenize_kwargs = tokenize_kwargs or {}
- if self.tokenize:
- if tokenizer is None:
- raise ValueError(
- "tokenizer must be provided to calculate distance in means of tokens"
- )
- if isinstance(tokenizer, str):
- tokenizer = AutoTokenizer.from_pretrained(tokenizer)
- self.tokenizer = tokenizer
- if tokenized_document_type is None:
- raise ValueError(
- "tokenized_document_type must be provided to calculate distance in means of tokens"
- )
- self.tokenized_document_type: Type[TokenBasedDocument] = resolve_target(
- tokenized_document_type
- )
def _collect(self, doc: Document) -> Dict[str, List[float]]:
- if self.tokenize:
- if not isinstance(doc, TextBasedDocument):
- raise ValueError(
- "doc must be a TextBasedDocument to calculate distance in means of tokens"
- )
- docs = tokenize_document(
- doc,
- tokenizer=self.tokenizer,
- result_document_type=self.tokenized_document_type,
- **self.tokenize_kwargs,
- )
- else:
- docs = [doc]
+ docs = [doc]
values: Dict[str, List[float]] = defaultdict(list)
for doc in docs:
layer_obj = getattr(doc, self.layer)
for binary_relation in layer_obj:
if isinstance(binary_relation, BinaryRelation):
- args = [binary_relation.head, binary_relation.tail]
+ args = (binary_relation.head, binary_relation.tail)
label = binary_relation.label
elif isinstance(binary_relation, NaryRelation):
args = binary_relation.arguments
diff --git a/src/pie_modules/metrics/span_coverage_collector.py b/src/pie_modules/metrics/span_coverage_collector.py
index df783f228..eb9338386 100644
--- a/src/pie_modules/metrics/span_coverage_collector.py
+++ b/src/pie_modules/metrics/span_coverage_collector.py
@@ -1,13 +1,9 @@
import logging
-from typing import Any, Dict, List, Optional, Set, Type, Union
+from typing import List, Optional, Set, Union
from pie_core import Document, DocumentStatistic
-from pie_core.utils.hydra import resolve_type
-from transformers import AutoTokenizer, PreTrainedTokenizer
from pie_modules.annotations import LabeledMultiSpan, Span
-from pie_modules.document.processing import tokenize_document
-from pie_modules.documents import TextBasedDocument, TokenBasedDocument
logger = logging.getLogger(__name__)
@@ -36,57 +32,16 @@ class SpanCoverageCollector(DocumentStatistic):
def __init__(
self,
layer: str,
- tokenize: bool = False,
- tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
- tokenized_document_type: Optional[Union[str, Type[TokenBasedDocument]]] = None,
labels: Optional[Union[List[str], str]] = None,
label_attribute: str = "label",
- tokenize_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
):
super().__init__(**kwargs)
self.layer = layer
self.labels = labels
self.label_field = label_attribute
- self.tokenize = tokenize
- if self.tokenize:
- if tokenizer is None:
- raise ValueError(
- "tokenizer must be provided to calculate the span coverage in means of tokens"
- )
- if isinstance(tokenizer, str):
- tokenizer = AutoTokenizer.from_pretrained(tokenizer)
- self.tokenizer = tokenizer
- if tokenized_document_type is None:
- raise ValueError(
- "tokenized_document_type must be provided to calculate the span coverage in means of tokens"
- )
- self.tokenized_document_type = resolve_type(
- tokenized_document_type, expected_super_type=TokenBasedDocument
- )
- self.tokenize_kwargs = tokenize_kwargs or {}
def _collect(self, doc: Document) -> float:
- docs: Union[List[Document], List[TokenBasedDocument]]
- if self.tokenize:
- if not isinstance(doc, TextBasedDocument):
- raise ValueError(
- "doc must be a TextBasedDocument to calculate the span coverage in means of tokens"
- )
- docs = tokenize_document(
- doc,
- tokenizer=self.tokenizer,
- result_document_type=self.tokenized_document_type,
- **self.tokenize_kwargs,
- )
- if len(docs) != 1:
- raise ValueError(
- "tokenization of a single document must result in a single document to calculate the "
- "span coverage correctly. Please check your tokenization settings, especially that "
- "no windowing is applied because of max input length restrictions."
- )
- doc = docs[0]
-
layer_obj = getattr(doc, self.layer)
target = layer_obj.target
covered_indices: Set[int] = set()
diff --git a/src/pie_modules/metrics/span_length_collector.py b/src/pie_modules/metrics/span_length_collector.py
index 8acac1399..649a8e801 100644
--- a/src/pie_modules/metrics/span_length_collector.py
+++ b/src/pie_modules/metrics/span_length_collector.py
@@ -1,14 +1,10 @@
import logging
from collections import defaultdict
-from typing import Any, Callable, Dict, List, Optional, Type, Union
+from typing import Dict, List, Optional, Union
from pie_core import Document, DocumentStatistic
-from pie_core.utils.hydra import resolve_type
-from transformers import AutoTokenizer, PreTrainedTokenizer
from pie_modules.annotations import Span
-from pie_modules.document.processing import tokenize_document
-from pie_modules.documents import TextBasedDocument, TokenBasedDocument
logger = logging.getLogger(__name__)
@@ -26,12 +22,8 @@ class SpanLengthCollector(DocumentStatistic):
def __init__(
self,
layer: str,
- tokenize: bool = False,
- tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None,
- tokenized_document_type: Optional[Union[str, Type[TokenBasedDocument]]] = None,
labels: Optional[Union[List[str], str]] = None,
label_attribute: str = "label",
- tokenize_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -40,39 +32,9 @@ def __init__(
raise ValueError("labels must be a list of strings or 'INFERRED'")
self.labels = labels
self.label_field = label_attribute
- self.tokenize = tokenize
- if self.tokenize:
- if tokenizer is None:
- raise ValueError(
- "tokenizer must be provided to calculate the span length in means of tokens"
- )
- if isinstance(tokenizer, str):
- tokenizer = AutoTokenizer.from_pretrained(tokenizer)
- self.tokenizer = tokenizer
- if tokenized_document_type is None:
- raise ValueError(
- "tokenized_document_type must be provided to calculate the span length in means of tokens"
- )
- self.tokenized_document_type = resolve_type(
- tokenized_document_type, expected_super_type=TokenBasedDocument
- )
- self.tokenize_kwargs = tokenize_kwargs or {}
def _collect(self, doc: Document) -> Union[List[int], Dict[str, List[int]]]:
- docs: Union[List[Document], List[TokenBasedDocument]]
- if self.tokenize:
- if not isinstance(doc, TextBasedDocument):
- raise ValueError(
- "doc must be a TextBasedDocument to calculate the span length in means of tokens"
- )
- docs = tokenize_document(
- doc,
- tokenizer=self.tokenizer,
- result_document_type=self.tokenized_document_type,
- **self.tokenize_kwargs,
- )
- else:
- docs = [doc]
+ docs = [doc]
values: Dict[str, List[int]]
if isinstance(self.labels, str):
diff --git a/src/pie_modules/metrics/statistics.py b/src/pie_modules/metrics/statistics.py
index 16b9d41d3..9802b6356 100644
--- a/src/pie_modules/metrics/statistics.py
+++ b/src/pie_modules/metrics/statistics.py
@@ -1,46 +1,12 @@
import logging
from collections import defaultdict
-from typing import Any, Callable, Dict, List, Optional, Type, Union
+from typing import Any, Callable, Dict, List, Union
from pie_core import Document, DocumentStatistic
-from transformers import AutoTokenizer, PreTrainedTokenizer
-
-from pie_modules.documents import TextBasedDocument
logger = logging.getLogger(__name__)
-class TokenCountCollector(DocumentStatistic):
- """Collects the token count of a field when tokenizing its content with a Huggingface
- tokenizer.
-
- The content of the field should be a string.
- """
-
- def __init__(
- self,
- tokenizer: Union[str, PreTrainedTokenizer],
- text_field: str = "text",
- tokenizer_kwargs: Optional[Dict[str, Any]] = None,
- document_type: Optional[Type[Document]] = None,
- **kwargs,
- ):
- if document_type is None and text_field == "text":
- document_type = TextBasedDocument
- super().__init__(document_type=document_type, **kwargs)
- self.tokenizer = (
- AutoTokenizer.from_pretrained(tokenizer) if isinstance(tokenizer, str) else tokenizer
- )
- self.tokenizer_kwargs = tokenizer_kwargs or {}
- self.text_field = text_field
-
- def _collect(self, doc: Document) -> int:
- text = getattr(doc, self.text_field)
- encodings = self.tokenizer(text, **self.tokenizer_kwargs)
- tokens = encodings.tokens()
- return len(tokens)
-
-
class FieldLengthCollector(DocumentStatistic):
"""Collects the length of a field, e.g. to collect the number the characters in the input text.
diff --git a/src/pie_modules/models/__init__.py b/src/pie_modules/models/__init__.py
deleted file mode 100644
index df8f4a035..000000000
--- a/src/pie_modules/models/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from .sequence_classification_with_pooler import (
- SequenceClassificationModelWithPooler,
- SequencePairSimilarityModelWithPooler,
-)
-from .simple_extractive_question_answering import SimpleExtractiveQuestionAnsweringModel
-from .simple_generative import SimpleGenerativeModel
-from .simple_sequence_classification import SimpleSequenceClassificationModel
-from .simple_token_classification import SimpleTokenClassificationModel
-from .span_tuple_classification import SpanTupleClassificationModel
-from .token_classification_with_seq2seq_encoder_and_crf import (
- TokenClassificationModelWithSeq2SeqEncoderAndCrf,
-)
diff --git a/src/pie_modules/models/base_models/__init__.py b/src/pie_modules/models/base_models/__init__.py
deleted file mode 100644
index 8bc2cf097..000000000
--- a/src/pie_modules/models/base_models/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .bart_as_pointer_network import BartAsPointerNetwork
-from .bart_with_decoder_position_ids import BartModelWithDecoderPositionIds
diff --git a/src/pie_modules/models/base_models/bart_as_pointer_network.py b/src/pie_modules/models/base_models/bart_as_pointer_network.py
index b17540263..e69de29bb 100644
--- a/src/pie_modules/models/base_models/bart_as_pointer_network.py
+++ b/src/pie_modules/models/base_models/bart_as_pointer_network.py
@@ -1,476 +0,0 @@
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
-
-import torch.utils.checkpoint
-from torch import nn
-from torch.nn import Parameter
-from torch.optim import Optimizer
-from transformers import BartConfig, BartModel, BartPreTrainedModel, GenerationConfig
-from transformers.modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput
-from transformers.models.bart.modeling_bart import shift_tokens_right
-from transformers.utils import logging
-
-from pie_modules.models.base_models.bart_with_decoder_position_ids import (
- BartModelWithDecoderPositionIds,
-)
-from pie_modules.models.components.pointer_head import PointerHead
-
-logger = logging.get_logger(__name__)
-
-
-def get_layer_norm_parameters(
- named_parameters: Iterator[Tuple[str, Parameter]],
-) -> Iterator[Parameter]:
- return (
- param for name, param in named_parameters if "layernorm" in name or "layer_norm" in name
- )
-
-
-def get_non_layer_norm_parameters(
- named_parameters: Iterator[Tuple[str, Parameter]],
-) -> Iterator[Parameter]:
- return (
- param
- for name, param in named_parameters
- if not ("layernorm" in name or "layer_norm" in name)
- )
-
-
-class BartAsPointerNetworkConfig(BartConfig):
- def __init__(
- self,
- # respective token ids for the label-, eos-, and pad ids. Can be used as a mapping from the
- # target ids to the token ids.
- target_token_ids: Optional[List[int]] = None,
- # token id mapping to better initialize the label embedding weights
- embedding_weight_mapping: Optional[Dict[Union[int, str], List[int]]] = None,
- # special decoder position id handling
- decoder_position_id_mode: Optional[str] = None,
- decoder_position_id_pattern: Optional[List[int]] = None,
- decoder_position_id_mapping: Optional[Dict[int, int]] = None,
- # other parameters
- use_encoder_mlp: bool = True,
- use_constraints_encoder_mlp: bool = False,
- # optimizer
- lr: float = 5e-5,
- task_lr: Optional[float] = None,
- weight_decay: float = 1e-2,
- head_decay: Optional[float] = None,
- shared_decay: Optional[float] = None,
- encoder_layer_norm_decay: Optional[float] = 0.001,
- decoder_layer_norm_decay: Optional[float] = None,
- # other BartConfig parameters
- **kwargs,
- ):
- super().__init__(**kwargs)
-
- self.target_token_ids = target_token_ids
-
- self.embedding_weight_mapping = embedding_weight_mapping
-
- self.use_encoder_mlp = use_encoder_mlp
- self.use_constraints_encoder_mlp = use_constraints_encoder_mlp
-
- self.decoder_position_id_mode = decoder_position_id_mode
- self.decoder_position_id_pattern = decoder_position_id_pattern
- self.decoder_position_id_mapping = decoder_position_id_mapping
-
- self.lr = lr
- self.task_lr = task_lr
- self.weight_decay = weight_decay
- self.head_decay = head_decay
- self.shared_decay = shared_decay
- self.encoder_layer_norm_decay = encoder_layer_norm_decay
- self.decoder_layer_norm_decay = decoder_layer_norm_decay
-
-
-class BartAsPointerNetwork(BartPreTrainedModel):
- config_class = BartAsPointerNetworkConfig
- base_model_prefix = "model"
- _tied_weights_keys = [
- "encoder.embed_tokens.weight",
- "decoder.embed_tokens.weight",
- ]
-
- def __init__(self, config: BartAsPointerNetworkConfig):
- super().__init__(config)
- if self.config.decoder_position_id_mode is not None:
- self.model = BartModelWithDecoderPositionIds(config)
- else:
- self.model = BartModel(config)
-
- self.pointer_head = PointerHead(
- # target space ids
- bos_id=self.model.config.bos_token_id,
- eos_id=self.model.config.eos_token_id,
- pad_id=self.model.config.pad_token_id,
- # decoder-input token ids
- target_token_ids=self.model.config.target_token_ids,
- # embeddings
- embeddings=self.model.decoder.embed_tokens,
- embedding_weight_mapping=self.model.config.embedding_weight_mapping,
- # other parameters
- use_encoder_mlp=self.model.config.use_encoder_mlp,
- use_constraints_encoder_mlp=self.model.config.use_constraints_encoder_mlp,
- decoder_position_id_mode=self.model.config.decoder_position_id_mode,
- decoder_position_id_pattern=self.model.config.decoder_position_id_pattern,
- decoder_position_id_mapping=self.model.config.decoder_position_id_mapping,
- )
-
- # Initialize weights and apply final processing
- self.post_init()
-
- @classmethod
- def _load_pretrained_model(
- cls,
- *args,
- **kwargs,
- ):
- (
- model,
- missing_keys,
- unexpected_keys,
- mismatched_keys,
- offload_index,
- error_msgs,
- ) = super()._load_pretrained_model(*args, **kwargs)
- # adjust the model after loading the original model (e.g. vanilla BartModel)
- model.adjust_after_loading_original_model()
- return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs
-
- def resize_token_embeddings(
- self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
- ) -> nn.Embedding:
- new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
- # we also need to update the embeddings in the pointer head
- self.pointer_head.set_embeddings(new_embeddings)
- return new_embeddings
-
- def adjust_after_loading_original_model(self):
- # target_token_ids contains all new target tokens for the labels and new tokens were added to the end
- # of the vocabulary, so we can use its maximum to resize the embedding weights
- self.resize_token_embeddings(new_num_tokens=max(self.config.target_token_ids) + 1)
- # initialize the newly added embeddings for the labels with better weights from the original embeddings
- self.pointer_head.overwrite_embeddings_with_mapping()
-
- # adjust generation settings
- # set the correct decoder_start_token_id
- self.config.decoder_start_token_id = self.config.bos_token_id
- # disable ForcedBOSTokenLogitsProcessor
- self.config.forced_bos_token_id = None
- # disable ForcedEOSTokenLogitsProcessor
- self.config.forced_eos_token_id = None
- # update the generation config accordingly
- self.generation_config = GenerationConfig.from_model_config(self.config)
-
- def base_model_named_params(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]:
- yield from self.model.named_parameters(prefix=prefix + self.base_model_prefix)
-
- def head_named_params(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]:
- base_model_param_names = {
- name for name, param in self.base_model_named_params(prefix=prefix)
- }
- for name, param in self.named_parameters(prefix=prefix):
- if name not in base_model_param_names:
- yield name, param
-
- def encoder_only_named_params(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]:
- shared_params = set(dict(self.encoder_decoder_shared_named_params(prefix=prefix)).values())
- for name, param in self.model.encoder.named_parameters(
- prefix=prefix + self.base_model_prefix + ".encoder"
- ):
- if param not in shared_params:
- yield name, param
-
- def decoder_only_named_params(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]:
- shared_params = set(dict(self.encoder_decoder_shared_named_params(prefix=prefix)).values())
- for name, param in self.model.decoder.named_parameters(
- prefix=prefix + self.base_model_prefix + ".decoder"
- ):
- if param not in shared_params:
- yield name, param
-
- def encoder_decoder_shared_named_params(
- self, prefix: str = ""
- ) -> Iterator[Tuple[str, Parameter]]:
- encoder_params = set(self.model.encoder.parameters())
- decoder_params = set(self.model.decoder.parameters())
- for name, param in self.base_model_named_params(prefix=prefix):
- if param in encoder_params and param in decoder_params:
- yield name, param
-
- def get_encoder(self):
- return self.model.get_encoder()
-
- def get_decoder(self):
- return self.model.get_decoder()
-
- # @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
- # @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
- # @add_end_docstrings(BART_GENERATION_EXAMPLE)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- decoder_input_ids: Optional[torch.LongTensor] = None,
- decoder_attention_mask: Optional[torch.LongTensor] = None,
- decoder_position_ids: Optional[torch.LongTensor] = None,
- constraints: Optional[torch.LongTensor] = None,
- head_mask: Optional[torch.Tensor] = None,
- decoder_head_mask: Optional[torch.Tensor] = None,
- cross_attn_head_mask: Optional[torch.Tensor] = None,
- encoder_outputs: Optional[List[torch.FloatTensor]] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, Seq2SeqLMOutput]:
- r"""Labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels
- for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100`
- are ignored (masked), the loss is only computed for the tokens with labels in `[0, ...,
- config.vocab_size]`.
-
- Returns:
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if labels is not None:
- if use_cache:
- logger.warning(
- "The `use_cache` argument is changed to `False` since `labels` is provided."
- )
- use_cache = False
- if decoder_input_ids is None and decoder_inputs_embeds is None:
- decoder_input_ids = shift_tokens_right(
- labels, self.config.pad_token_id, self.config.decoder_start_token_id
- )
-
- if decoder_input_ids is None:
- # we can not create the decoder_input_ids from input_ids, because we need the
- # encoder_input_ids for the pointer network
- raise ValueError("decoder_input_ids has to be set!")
-
- # this adjusts the input_ids and, if available, the position_ids
- decoder_inputs = self.pointer_head.prepare_decoder_inputs(
- input_ids=decoder_input_ids,
- # in the case of generation (with past_key_values) the position_ids are already prepared
- position_ids=decoder_position_ids,
- encoder_input_ids=input_ids,
- )
-
- model_inputs = dict(
- input_ids=input_ids,
- encoder_outputs=encoder_outputs,
- attention_mask=attention_mask,
- decoder_attention_mask=decoder_attention_mask,
- head_mask=head_mask,
- decoder_head_mask=decoder_head_mask,
- cross_attn_head_mask=cross_attn_head_mask,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- decoder_inputs_embeds=decoder_inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=True,
- )
- for k, v in decoder_inputs.items():
- model_inputs[f"decoder_{k}"] = v
- outputs = self.model(**model_inputs)
-
- if not isinstance(outputs, Seq2SeqModelOutput):
- raise ValueError(
- "Inconsistent output: The output of the model forward should be of type "
- f"`Seq2SeqLMOutput`, but is of type `{type(outputs)}`."
- )
- logits, loss = self.pointer_head(
- last_hidden_state=outputs.last_hidden_state,
- encoder_last_hidden_state=outputs.encoder_last_hidden_state,
- encoder_input_ids=input_ids,
- encoder_attention_mask=attention_mask,
- labels=labels,
- decoder_attention_mask=decoder_attention_mask,
- constraints=constraints,
- )
-
- if not return_dict:
- output = (logits,) + outputs[1:]
- return ((loss,) + output) if loss is not None else output
-
- return Seq2SeqLMOutput(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- decoder_hidden_states=outputs.decoder_hidden_states,
- decoder_attentions=outputs.decoder_attentions,
- cross_attentions=outputs.cross_attentions,
- encoder_last_hidden_state=outputs.encoder_last_hidden_state,
- encoder_hidden_states=outputs.encoder_hidden_states,
- encoder_attentions=outputs.encoder_attentions,
- )
-
- def prepare_inputs_for_generation(
- self,
- decoder_input_ids,
- encoder_input_ids, # added for pointer network
- encoder_attention_mask, # added for pointer network
- past_key_values=None,
- attention_mask=None,
- decoder_attention_mask=None,
- head_mask=None,
- decoder_head_mask=None,
- cross_attn_head_mask=None,
- use_cache=None,
- encoder_outputs=None,
- **kwargs,
- ):
- result = {}
- if self.pointer_head.use_prepared_position_ids:
- # we need to prepare the position ids for the decoder here, because later we do not have the full
- # input_ids anymore
- result["decoder_position_ids"] = self.pointer_head.prepare_decoder_position_ids(
- input_ids=decoder_input_ids
- )
-
- # cut decoder_input_ids if past_key_values is used
- if past_key_values is not None:
- past_length = past_key_values[0][0].shape[2]
-
- # Some generation methods already pass only the last input ID
- if decoder_input_ids.shape[1] > past_length:
- remove_prefix_length = past_length
- else:
- # Default to old behavior: keep only final ID
- remove_prefix_length = decoder_input_ids.shape[1] - 1
-
- decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
-
- if "decoder_position_ids" in result:
- result["decoder_position_ids"] = result["decoder_position_ids"][
- :, remove_prefix_length:
- ]
-
- result.update(
- {
- "input_ids": encoder_input_ids,
- "encoder_outputs": encoder_outputs,
- "past_key_values": past_key_values,
- "decoder_input_ids": decoder_input_ids,
- "attention_mask": encoder_attention_mask,
- "decoder_attention_mask": decoder_attention_mask,
- "head_mask": head_mask,
- "decoder_head_mask": decoder_head_mask,
- "cross_attn_head_mask": cross_attn_head_mask,
- "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
- }
- )
- return result
-
- def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
- return shift_tokens_right(
- labels, self.config.pad_token_id, self.config.decoder_start_token_id
- )
-
- @staticmethod
- def _reorder_cache(past_key_values, beam_idx):
- reordered_past = ()
- for layer_past in past_key_values:
- # cached cross_attention states don't have to be reordered -> they are always the same
- reordered_past += (
- tuple(
- past_state.index_select(0, beam_idx.to(past_state.device))
- for past_state in layer_past[:2]
- )
- + layer_past[2:],
- )
- return reordered_past
-
- def _prepare_encoder_decoder_kwargs_for_generation(
- self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
- ) -> Dict[str, Any]:
- result = super()._prepare_encoder_decoder_kwargs_for_generation(
- inputs_tensor=inputs_tensor,
- model_kwargs=model_kwargs,
- model_input_name=model_input_name,
- )
- # add items that are needed for pointer network
- result["encoder_input_ids"] = inputs_tensor
- result["encoder_attention_mask"] = result["attention_mask"]
- return result
-
- def configure_optimizer(self) -> Optimizer:
- parameters = []
-
- # head parameters
- head_decay = (
- self.config.head_decay
- if self.config.head_decay is not None
- else self.config.weight_decay
- )
- params = {
- "lr": self.config.task_lr if self.config.task_lr is not None else self.config.lr,
- "weight_decay": head_decay,
- "params": dict(self.head_named_params()).values(),
- }
- parameters.append(params)
-
- # decoder only layer norm parameters
- decoder_layer_norm_decay = (
- self.config.decoder_layer_norm_decay
- if self.config.decoder_layer_norm_decay is not None
- else self.config.weight_decay
- )
- params = {
- "lr": self.config.lr,
- "weight_decay": decoder_layer_norm_decay,
- "params": get_layer_norm_parameters(self.decoder_only_named_params()),
- }
- parameters.append(params)
-
- # decoder only other parameters
- params = {
- "lr": self.config.lr,
- "weight_decay": self.config.weight_decay,
- "params": get_non_layer_norm_parameters(self.decoder_only_named_params()),
- }
- parameters.append(params)
-
- # encoder only layer norm parameters
- encoder_layer_norm_decay = (
- self.config.encoder_layer_norm_decay
- if self.config.encoder_layer_norm_decay is not None
- else self.config.weight_decay
- )
- params = {
- "lr": self.config.lr,
- "weight_decay": encoder_layer_norm_decay,
- "params": get_layer_norm_parameters(self.encoder_only_named_params()),
- }
- parameters.append(params)
-
- # encoder only other parameters
- params = {
- "lr": self.config.lr,
- "weight_decay": self.config.weight_decay,
- "params": get_non_layer_norm_parameters(self.encoder_only_named_params()),
- }
- parameters.append(params)
-
- # encoder-decoder shared parameters
- shared_decay = (
- self.config.shared_decay
- if self.config.shared_decay is not None
- else self.config.weight_decay
- )
- params = {
- "lr": self.config.lr,
- "weight_decay": shared_decay,
- "params": dict(self.encoder_decoder_shared_named_params()).values(),
- }
- parameters.append(params)
-
- return torch.optim.AdamW(parameters)
diff --git a/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py b/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py
deleted file mode 100644
index 03ca5dea6..000000000
--- a/src/pie_modules/models/base_models/bart_with_decoder_position_ids.py
+++ /dev/null
@@ -1,536 +0,0 @@
-# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""PyTorch BART model, but the decoder accepts predefined position ids. If not provided, the
-original logic is used to create the position ids.
-
-The model is based on the BartModel from Transformers 4.35.0,
-i.e. https://github.com/huggingface/transformers/blob/v4.35.0/src/transformers/models/bart/modeling_bart.py.
-
-Note: This also contains some minor modifications to make the code mypy (v1.4.1) compliant.
-.
-"""
-import math
-from typing import Any, List, Optional, Tuple, Union
-
-import torch
-import torch.utils.checkpoint
-from torch import nn
-from transformers.modeling_attn_mask_utils import (
- _prepare_4d_attention_mask,
- _prepare_4d_causal_attention_mask,
-)
-from transformers.modeling_outputs import (
- BaseModelOutput,
- BaseModelOutputWithPastAndCrossAttentions,
- Seq2SeqModelOutput,
-)
-from transformers.models.bart import BartConfig
-from transformers.models.bart.modeling_bart import (
- _CHECKPOINT_FOR_DOC,
- _CONFIG_FOR_DOC,
- _EXPECTED_OUTPUT_SHAPE,
- BART_INPUTS_DOCSTRING,
- BART_START_DOCSTRING,
- BartDecoderLayer,
- BartEncoder,
- BartPreTrainedModel,
- shift_tokens_right,
-)
-from transformers.utils import (
- add_code_sample_docstrings,
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- logging,
-)
-
-logger = logging.get_logger(__name__)
-
-
-class BartLearnedPositionalEmbeddingWithPositionIds(nn.Embedding):
- """This module learns positional embeddings up to a fixed maximum size."""
-
- def __init__(self, num_embeddings: int, embedding_dim: int):
- # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
- # and adjust num_embeddings appropriately. Other models don't have this hack
- self.offset = 2
- super().__init__(num_embeddings + self.offset, embedding_dim)
-
- def forward(
- self,
- input_ids: torch.Tensor,
- past_key_values_length: int = 0,
- position_ids: Optional[torch.Tensor] = None,
- ):
- """`input_ids' shape is expected to be [bsz x seqlen]."""
-
- if position_ids is None:
- bsz, seq_len = input_ids.shape[:2]
- positions = torch.arange(
- past_key_values_length,
- past_key_values_length + seq_len,
- dtype=torch.long,
- device=self.weight.device,
- ).expand(bsz, -1)
- else:
- positions = position_ids
-
- return super().forward(positions + self.offset)
-
-
-class BartDecoderWithPositionIds(BartPreTrainedModel):
- """Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a
- [`BartDecoderLayer`]
-
- Args:
- config: BartConfig
- embed_tokens (nn.Embedding): output embedding
- """
-
- def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
- super().__init__(config)
- self.dropout = config.dropout
- self.layerdrop = config.decoder_layerdrop
- self.padding_idx = config.pad_token_id
- self.max_target_positions = config.max_position_embeddings
- self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
-
- self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
-
- if embed_tokens is not None:
- self.embed_tokens.weight = embed_tokens.weight
-
- self.embed_positions = BartLearnedPositionalEmbeddingWithPositionIds(
- config.max_position_embeddings,
- config.d_model,
- )
- self.layers = nn.ModuleList(
- [BartDecoderLayer(config) for _ in range(config.decoder_layers)]
- )
- self.layernorm_embedding = nn.LayerNorm(config.d_model)
-
- self.gradient_checkpointing = False
- # Initialize weights and apply final processing
- self.post_init()
-
- def get_input_embeddings(self):
- return self.embed_tokens
-
- def set_input_embeddings(self, value):
- self.embed_tokens = value
-
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- position_ids: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- encoder_attention_mask: Optional[torch.LongTensor] = None,
- head_mask: Optional[torch.Tensor] = None,
- cross_attn_head_mask: Optional[torch.Tensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
- r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
- provide it.
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- [What are input IDs?](../glossary#input-ids)
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Position indices for each input sequence token that are used to create the position embedding
- of the sequence. If `None` (default), position ids are automatically created as sequential
- integers (takes previous `past_key_values` into account, if provided).
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- [What are attention masks?](../glossary#attention-mask)
- encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
- Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
- of the decoder.
- encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
- Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
- selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- [What are attention masks?](../glossary#attention-mask)
- head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
- Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
-
- cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
- Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
- cross-attention on hidden heads. Mask values selected in `[0, 1]`:
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
-
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
- shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
-
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
- cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
-
- If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
- that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
- all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded
- representation. This is useful if you want more control over how to convert `input_ids` indices
- into associated vectors than the model's internal embedding lookup matrix.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
- for more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- """
- output_attentions = (
- output_attentions if output_attentions is not None else self.config.output_attentions
- )
- output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # retrieve input_ids and inputs_embeds
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError(
- "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
- )
- elif input_ids is not None:
- input = input_ids
- input_shape = input.shape
- input_ids = input_ids.view(-1, input_shape[-1])
- elif inputs_embeds is not None:
- input_shape = inputs_embeds.size()[:-1]
- input = inputs_embeds[:, :, -1]
- else:
- raise ValueError(
- "You have to specify either decoder_input_ids or decoder_inputs_embeds"
- )
-
- # past_key_values_length
- past_key_values_length = (
- past_key_values[0][0].shape[2] if past_key_values is not None else 0
- )
-
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input) * self.embed_scale
-
- if getattr(self.config, "_flash_attn_2_enabled", False):
- # 2d mask is passed through the layers
- attention_mask = (
- attention_mask if (attention_mask is not None and 0 in attention_mask) else None
- )
- else:
- # 4d mask is passed through the layers
- attention_mask = _prepare_4d_causal_attention_mask(
- attention_mask, input_shape, inputs_embeds, past_key_values_length
- )
-
- # expand encoder attention mask
- if encoder_hidden_states is not None and encoder_attention_mask is not None:
- if getattr(self.config, "_flash_attn_2_enabled", False):
- encoder_attention_mask = (
- encoder_attention_mask if 0 in encoder_attention_mask else None
- )
- else:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask(
- encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
- )
-
- # embed positions
- if position_ids is not None and position_ids.shape != input_shape:
- raise ValueError(
- f"Position IDs shape {position_ids.shape} does not match input ids shape {input_shape}."
- )
- positions = self.embed_positions(input, past_key_values_length, position_ids)
- positions = positions.to(inputs_embeds.device)
-
- hidden_states = inputs_embeds + positions
- hidden_states = self.layernorm_embedding(hidden_states)
-
- hidden_states = nn.functional.dropout(
- hidden_states, p=self.dropout, training=self.training
- )
-
- if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- )
- use_cache = False
-
- # decoder layers
- all_hidden_states: Optional[Tuple[Any, ...]] = () if output_hidden_states else None
- all_self_attns: Optional[Tuple[Any, ...]] = () if output_attentions else None
- all_cross_attentions: Optional[Tuple[Any, ...]] = (
- () if (output_attentions and encoder_hidden_states is not None) else None
- )
- next_decoder_cache: Optional[Tuple[Any, ...]] = () if use_cache else None
-
- # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
- for attn_mask, mask_name in zip(
- [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
- ):
- if attn_mask is not None:
- if attn_mask.size()[0] != (len(self.layers)):
- raise ValueError(
- f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
- f" {attn_mask.size()[0]}."
- )
-
- for idx, decoder_layer in enumerate(self.layers):
- # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
- if all_hidden_states is not None:
- all_hidden_states += (hidden_states,)
- if self.training:
- dropout_probability = torch.rand([])
- if dropout_probability < self.layerdrop:
- continue
-
- past_key_value = past_key_values[idx] if past_key_values is not None else None
-
- if self.gradient_checkpointing and self.training:
- layer_outputs = self._gradient_checkpointing_func(
- decoder_layer.__call__,
- hidden_states,
- attention_mask,
- encoder_hidden_states,
- encoder_attention_mask,
- head_mask[idx] if head_mask is not None else None,
- cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
- None,
- output_attentions,
- use_cache,
- )
- else:
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=attention_mask,
- encoder_hidden_states=encoder_hidden_states,
- encoder_attention_mask=encoder_attention_mask,
- layer_head_mask=(head_mask[idx] if head_mask is not None else None),
- cross_attn_layer_head_mask=(
- cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
- ),
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- )
- hidden_states = layer_outputs[0]
-
- if next_decoder_cache is not None:
- next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
-
- if all_self_attns is not None:
- all_self_attns += (layer_outputs[1],)
-
- if all_cross_attentions is not None:
- all_cross_attentions += (layer_outputs[2],)
-
- # add hidden states from the last decoder layer
- if all_hidden_states is not None:
- all_hidden_states += (hidden_states,)
-
- next_cache = next_decoder_cache if use_cache else None
- if not return_dict:
- return tuple(
- v
- for v in [
- hidden_states,
- next_cache,
- all_hidden_states,
- all_self_attns,
- all_cross_attentions,
- ]
- if v is not None
- )
- return BaseModelOutputWithPastAndCrossAttentions(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- cross_attentions=all_cross_attentions,
- )
-
-
-@add_start_docstrings(
- "The bare BART Model outputting raw hidden-states without any specific head on top.",
- BART_START_DOCSTRING,
-)
-class BartModelWithDecoderPositionIds(BartPreTrainedModel):
- _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
-
- def __init__(self, config: BartConfig):
- super().__init__(config)
-
- padding_idx, vocab_size = config.pad_token_id, config.vocab_size
- self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
-
- self.encoder = BartEncoder(config, self.shared)
- self.decoder = BartDecoderWithPositionIds(config, self.shared)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- def _tie_weights(self):
- if self.config.tie_word_embeddings:
- self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
- self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
-
- def get_input_embeddings(self):
- return self.shared
-
- def set_input_embeddings(self, value):
- self.shared = value
- self.encoder.embed_tokens = self.shared
- self.decoder.embed_tokens = self.shared
-
- def get_encoder(self):
- return self.encoder
-
- def get_decoder(self):
- return self.decoder
-
- @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
- @add_code_sample_docstrings(
- checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=Seq2SeqModelOutput,
- config_class=_CONFIG_FOR_DOC,
- expected_output=_EXPECTED_OUTPUT_SHAPE,
- )
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- decoder_input_ids: Optional[torch.LongTensor] = None,
- decoder_attention_mask: Optional[torch.LongTensor] = None,
- decoder_position_ids: Optional[torch.Tensor] = None,
- head_mask: Optional[torch.Tensor] = None,
- decoder_head_mask: Optional[torch.Tensor] = None,
- cross_attn_head_mask: Optional[torch.Tensor] = None,
- encoder_outputs: Optional[List[torch.FloatTensor]] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, Seq2SeqModelOutput]:
- # different to other models, Bart automatically creates decoder_input_ids from
- # input_ids if no decoder_input_ids are provided
- if decoder_input_ids is None and decoder_inputs_embeds is None:
- if input_ids is None:
- raise ValueError(
- "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
- "passed, `input_ids` cannot be `None`. Please pass either "
- "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
- )
-
- decoder_input_ids = shift_tokens_right(
- input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
- )
-
- output_attentions = (
- output_attentions if output_attentions is not None else self.config.output_attentions
- )
- output_hidden_states = (
- output_hidden_states
- if output_hidden_states is not None
- else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if encoder_outputs is None:
- encoder_outputs = self.encoder(
- input_ids=input_ids,
- attention_mask=attention_mask,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
- encoder_outputs = BaseModelOutput(
- last_hidden_state=encoder_outputs[0],
- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
- )
-
- if not (
- isinstance(encoder_outputs, BaseModelOutput) or isinstance(encoder_outputs, tuple)
- ):
- raise ValueError(
- "Inconsistent output: The output of the model encoder should be of type "
- f"`BaseModelOutput` or tuple, but is of type `{type(encoder_outputs)}`."
- )
-
- # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
- decoder_outputs = self.decoder(
- input_ids=decoder_input_ids,
- attention_mask=decoder_attention_mask,
- position_ids=decoder_position_ids,
- encoder_hidden_states=encoder_outputs[0],
- encoder_attention_mask=attention_mask,
- head_mask=decoder_head_mask,
- cross_attn_head_mask=cross_attn_head_mask,
- past_key_values=past_key_values,
- inputs_embeds=decoder_inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- if not return_dict:
- return decoder_outputs + encoder_outputs
-
- return Seq2SeqModelOutput(
- last_hidden_state=decoder_outputs.last_hidden_state,
- past_key_values=decoder_outputs.past_key_values,
- decoder_hidden_states=decoder_outputs.hidden_states,
- decoder_attentions=decoder_outputs.attentions,
- cross_attentions=decoder_outputs.cross_attentions,
- encoder_last_hidden_state=encoder_outputs.last_hidden_state,
- encoder_hidden_states=encoder_outputs.hidden_states,
- encoder_attentions=encoder_outputs.attentions,
- )
diff --git a/src/pie_modules/models/common/__init__.py b/src/pie_modules/models/common/__init__.py
deleted file mode 100644
index b021af996..000000000
--- a/src/pie_modules/models/common/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .has_taskmodule import HasTaskmodule
-from .model_with_boilerplate import ModelWithBoilerplate
-from .model_with_metrics_from_taskmodule import ModelWithMetricsFromTaskModule
-from .stages import TESTING, TRAINING, VALIDATION
diff --git a/src/pie_modules/models/common/has_taskmodule.py b/src/pie_modules/models/common/has_taskmodule.py
deleted file mode 100644
index 03cd8154c..000000000
--- a/src/pie_modules/models/common/has_taskmodule.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from typing import Any, Dict, Optional
-
-from pie_core import AutoTaskModule, TaskModule
-
-from pie_modules.models.interface import RequiresTaskmoduleConfig
-
-
-class HasTaskmodule(RequiresTaskmoduleConfig):
- """A mixin class for models that have a taskmodule.
-
- Args:
- taskmodule_config: The config for the taskmodule which can be obtained from the
- taskmodule.config property.
- """
-
- def __init__(self, taskmodule_config: Optional[Dict[str, Any]] = None, **kwargs):
- super().__init__(**kwargs)
- self.taskmodule: Optional[TaskModule] = None
- if taskmodule_config is not None:
- self.taskmodule = AutoTaskModule.from_config(taskmodule_config)
diff --git a/src/pie_modules/models/common/model_with_boilerplate.py b/src/pie_modules/models/common/model_with_boilerplate.py
deleted file mode 100644
index a58d51e5a..000000000
--- a/src/pie_modules/models/common/model_with_boilerplate.py
+++ /dev/null
@@ -1,80 +0,0 @@
-import logging
-from typing import Generic, Optional, Tuple, TypeVar
-
-from typing_extensions import TypeAlias
-
-from .model_with_metrics_from_taskmodule import ModelWithMetricsFromTaskModule
-from .stages import TESTING, TRAINING, VALIDATION
-
-InputType = TypeVar("InputType")
-OutputType = TypeVar("OutputType")
-TargetType = TypeVar("TargetType")
-StepInputType: TypeAlias = Tuple[InputType, TargetType]
-StepOutputType = TypeVar("StepOutputType")
-
-logger = logging.getLogger(__name__)
-
-
-class ModelWithBoilerplate(
- ModelWithMetricsFromTaskModule[InputType, TargetType, OutputType],
- Generic[InputType, OutputType, TargetType, StepOutputType],
-):
- """A PyTorchIEModel that adds boilerplate code for training, validation, and testing.
-
- Especially, it handles updating the metrics and logging of losses and metric results. Also see
- ModelWithMetricsFromTaskModule for more details on how metrics are handled.
- """
-
- def get_loss_from_outputs(self, outputs: OutputType) -> StepOutputType:
- if hasattr(outputs, "loss"):
- return outputs.loss
- else:
- raise ValueError(
- f"The model {self.__class__.__name__} does not define a 'loss' attribute in its output, "
- "so the loss cannot be automatically extracted from the outputs. Please override the"
- "get_loss_from_outputs() method for this model."
- )
-
- def log_loss(self, stage: str, loss: StepOutputType) -> None:
- # show loss on each step only during training
- self.log(
- f"loss/{stage}",
- loss,
- on_step=(stage == TRAINING),
- on_epoch=True,
- prog_bar=True,
- sync_dist=True,
- )
-
- def _step(self, stage: str, batch: StepInputType) -> StepOutputType:
- inputs, targets = batch
- outputs = self(inputs=inputs, targets=targets)
- loss = self.get_loss_from_outputs(outputs=outputs)
- self.log_loss(stage=stage, loss=loss)
- self.update_metric(inputs=inputs, outputs=outputs, targets=targets, stage=stage)
-
- return loss
-
- def training_step(self, batch: StepInputType, batch_idx: int) -> StepOutputType:
- return self._step(stage=TRAINING, batch=batch)
-
- def validation_step(self, batch: StepInputType, batch_idx: int) -> StepOutputType:
- return self._step(stage=VALIDATION, batch=batch)
-
- def test_step(self, batch: StepInputType, batch_idx: int) -> StepOutputType:
- return self._step(stage=TESTING, batch=batch)
-
- def predict_step(
- self, batch: StepInputType, batch_idx: int, dataloader_idx: int = 0
- ) -> TargetType:
- inputs, targets = batch
- return self.predict(inputs=inputs)
-
- def on_train_epoch_end(self) -> None:
- self.log_metric(stage=TRAINING)
-
- def on_validation_epoch_end(self) -> None:
- self.log_metric(stage=VALIDATION)
-
- def on_test_epoch_end(self) -> None:
- self.log_metric(stage=TESTING)
diff --git a/src/pie_modules/models/common/model_with_metrics_from_taskmodule.py b/src/pie_modules/models/common/model_with_metrics_from_taskmodule.py
deleted file mode 100644
index ea65ffd35..000000000
--- a/src/pie_modules/models/common/model_with_metrics_from_taskmodule.py
+++ /dev/null
@@ -1,152 +0,0 @@
-import logging
-from typing import Dict, Generic, List, Optional, Set, TypeVar, Union
-
-from pie_core.utils.dictionary import flatten_dict_s
-from pytorch_ie import PyTorchIEModel
-from torchmetrics import Metric, MetricCollection
-
-from .has_taskmodule import HasTaskmodule
-from .stages import TESTING, TRAINING, VALIDATION
-
-InputType = TypeVar("InputType")
-TargetType = TypeVar("TargetType")
-OutputType = TypeVar("OutputType")
-
-logger = logging.getLogger(__name__)
-
-
-class ModelWithMetricsFromTaskModule(
- HasTaskmodule, PyTorchIEModel, Generic[InputType, TargetType, OutputType]
-):
- """A PyTorchIEModel that adds metrics from a taskmodule.
-
- The metrics are added to the model as attributes with the names metric_{stage} via
- setup_metrics method, where stage is one of "train", "val", or "test". The metrics are updated
- with the update_metric method and logged with the log_metric method.
-
- Args:
- metric_stages: The stages for which to set up metrics. Must be one of "train", "val", or
- "test".
- metric_intervals: A dict mapping metric stages to the number of steps between metric
- calculation. If not provided, the metrics are calculated at the end of each epoch.
- metric_call_predict: Whether to call predict() and use its result for metric calculation
- instead of the (decoded) model output. This is useful, for instance, for generative models
- that define special logic to produce predictions, e.g. beam search, which requires multiple
- passes through the model. If True, predict() is called for all metric stages. If False (default),
- the model outputs are passed to decode() and that is used for all metric stages. If a list of
- metric stages is provided, predict() is called for these stages and the (decoded) model
- outputs for the remaining stages.
- """
-
- def __init__(
- self,
- metric_stages: List[str] = [TRAINING, VALIDATION, TESTING],
- metric_intervals: Optional[Dict[str, int]] = None,
- metric_call_predict: Union[bool, List[str]] = False,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
-
- self.setup_metrics(metric_stages=metric_stages)
-
- self.metric_intervals = metric_intervals or {}
- missed_stages = set(self.metric_intervals) - set(metric_stages)
- if len(missed_stages) > 0:
- logger.warning(
- f"There are stages in metric_intervals that are not in metric_stages: "
- f"{missed_stages}. Available metric stages: {metric_stages}."
- )
-
- self.use_prediction_for_metrics: Set[str]
- if isinstance(metric_call_predict, bool):
- self.metric_call_predict = set(metric_stages) if metric_call_predict else set()
- else:
- self.metric_call_predict = set(metric_call_predict)
- missed_stages = self.metric_call_predict - set(metric_stages)
- if len(missed_stages) > 0:
- logger.warning(
- f"There are stages in metric_call_predict that are not in metric_stages: "
- f"{missed_stages}. Available metric stages: {metric_stages}."
- )
-
- def setup_metrics(self, metric_stages: List[str]) -> None:
- """Set up metrics for the given stages if a taskmodule is available.
-
- Args:
- metric_stages: The stages for which to set up metrics. Must be one of "train", "val", or
- "test".
- """
- if self.taskmodule is not None:
- for stage in metric_stages:
- metric = self.taskmodule.configure_model_metric(stage=stage)
- if metric is not None:
- self._set_metric(stage=stage, metric=metric)
- else:
- logger.warning(
- f"The taskmodule {self.taskmodule.__class__.__name__} does not define a metric for stage "
- f"'{stage}'."
- )
- elif len(metric_stages) > 0:
- logger.warning(
- "No taskmodule is available, so no metrics are set up. "
- "Please provide a taskmodule_config to enable metrics for stages "
- f"{metric_stages}."
- )
-
- def _get_metric(
- self, stage: str, batch_idx: int = 0
- ) -> Optional[Union[Metric, MetricCollection]]:
- metric_interval = self.metric_intervals.get(stage, 1)
- if (batch_idx + 1) % metric_interval == 0:
- return getattr(self, f"metric_{stage}", None)
- else:
- return None
-
- def _set_metric(self, stage: str, metric: Optional[Union[Metric, MetricCollection]]) -> None:
- setattr(self, f"metric_{stage}", metric)
-
- def update_metric(
- self,
- stage: str,
- inputs: InputType,
- targets: TargetType,
- outputs: OutputType,
- ) -> None:
- """Update the metric for the given stage. If outputs is provided, the predictions are
- decoded from the outputs. Otherwise, the predictions are obtained by directly calling the
- predict method with the inputs (note that this causes the model to be called a second
- time). Finally, the metric is updated with the predictions and targets.
-
- Args:
- stage: The stage for which to update the metric. Must be one of "train", "val", or "test".
- inputs: The inputs to the model.
- targets: The targets for the inputs.
- outputs: The outputs of the model. They are decoded into predictions if provided. If
- outputs is None, the predictions are obtained by directly calling the predict method
- on the inputs.
- """
-
- metric = self._get_metric(stage=stage)
- if metric is not None:
- if stage in self.metric_call_predict:
- predictions = self.predict(inputs=inputs)
- else:
- predictions = self.decode(inputs=inputs, outputs=outputs)
- metric.update(predictions, targets)
-
- def log_metric(self, stage: str, reset: bool = True) -> None:
- """Log the metric for the given stage and reset it."""
-
- metric = self._get_metric(stage=stage)
- if metric is not None:
- values = metric.compute()
- log_kwargs = {"on_step": False, "on_epoch": True, "sync_dist": True}
- if isinstance(values, dict):
- values_flat = flatten_dict_s(values, sep="/")
- for key, value in values_flat.items():
- self.log(f"metric/{key}/{stage}", value, **log_kwargs)
- else:
- metric_name = getattr(metric, "name", None) or type(metric).__name__
- self.log(f"metric/{metric_name}/{stage}", values, **log_kwargs)
- if reset:
- metric.reset()
diff --git a/src/pie_modules/models/common/stages.py b/src/pie_modules/models/common/stages.py
deleted file mode 100644
index 7299e2092..000000000
--- a/src/pie_modules/models/common/stages.py
+++ /dev/null
@@ -1,3 +0,0 @@
-TRAINING = "train"
-VALIDATION = "val"
-TESTING = "test"
diff --git a/src/pie_modules/models/components/__init__.py b/src/pie_modules/models/components/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/src/pie_modules/models/components/pointer_head.py b/src/pie_modules/models/components/pointer_head.py
deleted file mode 100644
index e8a02acfa..000000000
--- a/src/pie_modules/models/components/pointer_head.py
+++ /dev/null
@@ -1,357 +0,0 @@
-from typing import Dict, List, Optional, Tuple, Union
-
-import torch
-import torch.nn.functional as F
-import torch.utils.checkpoint
-from torch import nn
-from torch.nn import CrossEntropyLoss
-from transformers.utils import logging
-
-logger = logging.get_logger(__name__)
-
-
-class PointerHead(torch.nn.Module):
- # Copy and generate,
- def __init__(
- self,
- # (decoder) input space
- target_token_ids: List[int],
- # output space (targets)
- bos_id: int,
- eos_id: int,
- pad_id: int,
- # embeddings
- embeddings: nn.Embedding,
- embedding_weight_mapping: Optional[Dict[Union[int, str], List[int]]] = None,
- # other parameters
- use_encoder_mlp: bool = False,
- use_constraints_encoder_mlp: bool = False,
- decoder_position_id_mode: Optional[nn.Module] = None,
- decoder_position_id_pattern: Optional[List[int]] = None,
- decoder_position_id_mapping: Optional[Dict[str, int]] = None,
- ):
- super().__init__()
-
- self.embeddings = embeddings
-
- self.pointer_offset = len(target_token_ids)
-
- # check that bos, eos, and pad are not out of bounds
- for target_id, target_id_name in zip(
- [bos_id, eos_id, pad_id], ["bos_id", "eos_id", "pad_id"]
- ):
- if target_id >= len(target_token_ids):
- raise ValueError(
- f"{target_id_name} [{target_id}] must be smaller than the number of target token ids "
- f"[{len(target_token_ids)}]!"
- )
-
- self.bos_id = bos_id
- self.eos_id = eos_id
- self.pad_id = pad_id
- # all ids that are not bos, eos or pad are label ids
- self.label_ids = [
- target_id
- for target_id in range(len(target_token_ids))
- if target_id not in [self.bos_id, self.eos_id, self.pad_id]
- ]
-
- target2token_id = torch.LongTensor(target_token_ids)
- self.register_buffer("target2token_id", target2token_id)
- self.label_token_ids = self.target2token_id[self.label_ids]
- self.eos_token_id = target_token_ids[self.eos_id]
- self.pad_token_id = target_token_ids[self.pad_id]
-
- hidden_size = self.embeddings.embedding_dim
- if use_encoder_mlp:
- self.encoder_mlp = nn.Sequential(
- nn.Linear(hidden_size, hidden_size),
- nn.Dropout(0.3),
- nn.ReLU(),
- nn.Linear(hidden_size, hidden_size),
- )
- if use_constraints_encoder_mlp:
- self.constraints_encoder_mlp = nn.Sequential(
- nn.Linear(hidden_size, hidden_size),
- nn.Dropout(0.3),
- nn.ReLU(),
- nn.Linear(hidden_size, hidden_size),
- )
-
- self.embedding_weight_mapping = None
- if embedding_weight_mapping is not None:
- # Because of config serialization, the keys may be strings. Convert them back to ints.
- self.embedding_weight_mapping = {
- int(k): v for k, v in embedding_weight_mapping.items()
- }
-
- self.decoder_position_id_mode = decoder_position_id_mode
- self.decoder_position_id_mapping = decoder_position_id_mapping
- if self.decoder_position_id_mode is None:
- pass
- elif self.decoder_position_id_mode in ["pattern", "pattern_with_increment"]:
- if decoder_position_id_pattern is None:
- raise ValueError(
- "decoder_position_id_pattern must be provided when using "
- 'decoder_position_id_mode="pattern" or "pattern_with_increment"!'
- )
- self.register_buffer(
- "decoder_position_id_pattern", torch.tensor(decoder_position_id_pattern)
- )
- elif self.decoder_position_id_mode == "mapping":
- if self.decoder_position_id_mapping is None:
- raise ValueError(
- 'decoder_position_id_mode="mapping" requires decoder_position_id_mapping to be provided!'
- )
- else:
- raise ValueError(
- f'decoder_position_id_mode="{self.decoder_position_id_mode}" is not supported, '
- 'use one of "pattern", "pattern_with_increment", or "mapping"!'
- )
-
- @property
- def use_prepared_position_ids(self):
- return self.decoder_position_id_mode is not None
-
- def set_embeddings(self, embedding: nn.Embedding) -> None:
- self.embeddings = embedding
-
- def overwrite_embeddings_with_mapping(self) -> None:
- """Overwrite individual embeddings with embeddings for other tokens.
-
- This is useful, for instance, if the label vocabulary is a subset of the source vocabulary.
- In this case, this method can be used to initialize each label embedding with one or
- multiple (averaged) source embeddings.
- """
- if self.embedding_weight_mapping is not None:
- for special_token_index, source_indices in self.embedding_weight_mapping.items():
- self.embeddings.weight.data[special_token_index] = self.embeddings.weight.data[
- source_indices
- ].mean(dim=0)
-
- def prepare_decoder_input_ids(
- self,
- input_ids: torch.LongTensor,
- encoder_input_ids: torch.LongTensor,
- ) -> torch.LongTensor:
- mapping_token_mask = input_ids.lt(self.pointer_offset)
- mapped_tokens = input_ids.masked_fill(input_ids.ge(self.pointer_offset), 0)
- tag_mapped_tokens = self.target2token_id[mapped_tokens]
-
- encoder_input_ids_index = input_ids - self.pointer_offset
- encoder_input_ids_index = encoder_input_ids_index.masked_fill(
- encoder_input_ids_index.lt(0), 0
- )
- encoder_input_length = encoder_input_ids.size(1)
- if encoder_input_ids_index.max() >= encoder_input_length:
- raise ValueError(
- f"encoder_input_ids_index.max() [{encoder_input_ids_index.max()}] must be smaller "
- f"than encoder_input_length [{encoder_input_length}]!"
- )
-
- word_mapped_tokens = encoder_input_ids.gather(index=encoder_input_ids_index, dim=1)
-
- decoder_input_ids = torch.where(
- mapping_token_mask, tag_mapped_tokens, word_mapped_tokens
- ).to(torch.long)
-
- # Note: we do not need to explicitly handle the padding (via a decoder attention mask) because
- # it gets automatically mapped to the pad token id
-
- return decoder_input_ids
-
- def prepare_decoder_position_ids(self, input_ids: torch.LongTensor) -> torch.LongTensor:
- if self.decoder_position_id_mode in ["pattern", "pattern_with_increment"]:
- bsz, tokens_len = input_ids.size()
- pattern_len = len(self.decoder_position_id_pattern)
- # the number of full and partly records. note that tokens_len includes the bos token
- repeat_num = (tokens_len - 2) // pattern_len + 1
- position_ids = self.decoder_position_id_pattern.repeat(bsz, repeat_num)
-
- if self.decoder_position_id_mode == "pattern_with_increment":
- position_ids_reshaped = position_ids.view(bsz, -1, pattern_len)
- add_shift_pos = (
- torch.arange(0, repeat_num, device=position_ids_reshaped.device)
- .repeat(bsz)
- .view(bsz, -1)
- .unsqueeze(-1)
- )
- # multiply by the highest position id in the pattern so that the position ids are unique
- # for any decoder_position_id_pattern across all records
- add_shift_pos *= max(self.decoder_position_id_pattern) + 1
- position_ids_reshaped = add_shift_pos + position_ids_reshaped
- position_ids = position_ids_reshaped.view(bsz, -1).long()
- # use start_position_id=0
- start_pos = torch.zeros(bsz, 1, dtype=position_ids.dtype, device=position_ids.device)
- # shift by 2 to account for start_position_id=0 and pad_position_id=1
- all_position_ids = torch.cat([start_pos, position_ids + 2], dim=-1)
- all_position_ids_truncated = all_position_ids[:bsz, :tokens_len]
-
- # mask the padding tokens
- mask_invalid = input_ids.eq(self.pad_id)
- all_position_ids_truncated_masked = all_position_ids_truncated.masked_fill(
- mask_invalid, 1
- )
-
- return all_position_ids_truncated_masked
- elif self.decoder_position_id_mode == "mapping":
- # we ignor the typing issue here because we ensure that the mapping is not None in the __init__
- mapping: Dict[str, int] = self.decoder_position_id_mapping # type: ignore
- if "default" not in mapping:
- raise ValueError(
- f"mapping must contain a default entry, but only contains {list(mapping)}!"
- )
- position_ids = input_ids.new_full(input_ids.size(), fill_value=mapping["default"])
- # ensure that values for all vocab entries are set first
- if "vocab" in mapping:
- position_ids[input_ids.lt(self.pointer_offset)] = mapping["vocab"]
- already_set: Dict[int, Tuple[str, int]] = {}
- for key, value in mapping.items():
- if key in ["default", "vocab"]:
- continue
- elif key == "bos":
- input_id = self.bos_id
- elif key == "eos":
- input_id = self.eos_id
- elif key == "pad":
- input_id = self.pad_id
- else:
- raise ValueError(f"Mapping contains unknown key '{key}' (mapping: {mapping}).")
- if already_set.get(input_id, (key, value))[1] != value:
- previous_key, previous_value = already_set[input_id]
- raise ValueError(
- f"Can not set the position ids for '{key}' to {value} because it was already "
- f"set to {previous_value} by key '{previous_key}'. Note that both, '{key}' and "
- f"'{previous_key}', have the same id ({input_id}), so their position_ids need to "
- f"be also the same (position id mapping: {mapping})."
- )
- position_ids[input_ids.eq(input_id)] = value
- already_set[input_id] = key, value
- return position_ids
- else:
- raise ValueError(
- f"decoder_position_id_mode={self.decoder_position_id_mode} not supported!"
- )
-
- def prepare_decoder_inputs(
- self,
- input_ids: torch.LongTensor,
- encoder_input_ids: torch.LongTensor,
- position_ids: Optional[torch.LongTensor] = None,
- ) -> Dict[str, torch.Tensor]:
- inputs = {}
- if self.use_prepared_position_ids:
- if position_ids is None:
- position_ids = self.prepare_decoder_position_ids(input_ids=input_ids)
- inputs["position_ids"] = position_ids
-
- inputs["input_ids"] = self.prepare_decoder_input_ids(
- input_ids=input_ids,
- encoder_input_ids=encoder_input_ids,
- )
- return inputs
-
- def forward(
- self,
- last_hidden_state,
- encoder_input_ids,
- encoder_last_hidden_state,
- encoder_attention_mask,
- labels: Optional[torch.LongTensor] = None,
- decoder_attention_mask: Optional[torch.LongTensor] = None,
- constraints: Optional[torch.LongTensor] = None,
- ):
- # assemble the logits
- logits = last_hidden_state.new_full(
- (
- last_hidden_state.size(0),
- last_hidden_state.size(1),
- self.pointer_offset + encoder_input_ids.size(-1),
- ),
- fill_value=-1e24,
- )
-
- # eos and label scores depend only on the decoder output
- # bsz x max_len x 1
- eos_scores = F.linear(last_hidden_state, self.embeddings.weight[[self.eos_token_id]])
- label_embeddings = self.embeddings.weight[self.label_token_ids]
- # bsz x max_len x num_class
- label_scores = F.linear(last_hidden_state, label_embeddings)
-
- # the pointer depends on the src token embeddings, the encoder output and the decoder output
- # bsz x max_bpe_len x hidden_size
- src_outputs = encoder_last_hidden_state
- if getattr(self, "encoder_mlp", None) is not None:
- src_outputs = self.encoder_mlp(src_outputs)
-
- # bsz x max_word_len x hidden_size
- input_embed = self.embeddings(encoder_input_ids)
-
- # bsz x max_len x max_word_len
- word_scores = torch.einsum("blh,bnh->bln", last_hidden_state, src_outputs)
- gen_scores = torch.einsum("blh,bnh->bln", last_hidden_state, input_embed)
- avg_word_scores = (gen_scores + word_scores) / 2
-
- # never point to the padding or the eos token in the encoder input
- # TODO: why not excluding the bos token? seems to give worse results, but not tested extensively
- mask_invalid = encoder_attention_mask.eq(0) | encoder_input_ids.eq(self.eos_token_id)
- avg_word_scores = avg_word_scores.masked_fill(mask_invalid.unsqueeze(1), -1e32)
-
- # Note: the remaining row in logits contains the score for the bos token which should be never generated!
- logits[:, :, [self.eos_id]] = eos_scores
- logits[:, :, self.label_ids] = label_scores
- logits[:, :, self.pointer_offset :] = avg_word_scores
-
- loss = None
- # compute the loss if labels are provided
- if labels is not None:
- loss_fct = CrossEntropyLoss()
- logits_resized = logits.reshape(-1, logits.size(-1))
- labels_resized = labels.reshape(-1)
- if decoder_attention_mask is None:
- raise ValueError("decoder_attention_mask must be provided to compute the loss!")
- mask_resized = decoder_attention_mask.reshape(-1)
- labels_masked = labels_resized.masked_fill(
- ~mask_resized.to(torch.bool), loss_fct.ignore_index
- )
- loss = loss_fct(logits_resized, labels_masked)
-
- # compute the constraints loss if constraints are provided
- if constraints is not None:
- if getattr(self, "constraints_encoder_mlp", None) is not None:
- # TODO: is it fine to apply constraints_encoder_mlp to both src_outputs and label_embeddings?
- # This is what the original code seems to do, but this is different from the usage of encoder_mlp.
- constraints_src_outputs = self.constraints_encoder_mlp(src_outputs)
- constraints_label_embeddings = self.constraints_encoder_mlp(label_embeddings)
- else:
- constraints_src_outputs = src_outputs
- constraints_label_embeddings = label_embeddings
- constraints_label_scores = F.linear(last_hidden_state, constraints_label_embeddings)
- # bsz x max_len x max_word_len
- constraints_word_scores = torch.einsum(
- "blh,bnh->bln", last_hidden_state, constraints_src_outputs
- )
- constraints_logits = last_hidden_state.new_full(
- (
- last_hidden_state.size(0),
- last_hidden_state.size(1),
- self.pointer_offset + encoder_input_ids.size(-1),
- ),
- fill_value=-1e24,
- )
- constraints_logits[:, :, self.label_ids] = constraints_label_scores
- constraints_logits[:, :, self.pointer_offset :] = constraints_word_scores
-
- mask = constraints >= 0
- constraints_logits_valid = constraints_logits[mask]
- constraints_valid = constraints[mask]
- loss_c = F.binary_cross_entropy(
- torch.sigmoid(constraints_logits_valid), constraints_valid.float()
- )
-
- if loss is None:
- loss = loss_c
- else:
- loss += loss_c
-
- return logits, loss
diff --git a/src/pie_modules/models/components/pooler.py b/src/pie_modules/models/components/pooler.py
index b835e2f9c..e69de29bb 100644
--- a/src/pie_modules/models/components/pooler.py
+++ b/src/pie_modules/models/components/pooler.py
@@ -1,274 +0,0 @@
-import logging
-from typing import Any, Callable, Dict, Tuple, Union
-
-import torch
-from torch import Tensor, cat, nn
-
-# possible pooler types
-CLS_TOKEN = "cls_token" # CLS token
-START_TOKENS = "start_tokens" # MTB start tokens concat
-MENTION_POOLING = "mention_pooling" # mention token pooling and concat
-
-
-logger = logging.getLogger(__name__)
-
-
-def pool_cls(hidden_state: Tensor, **kwargs) -> Tensor:
- return hidden_state[:, 0, :]
-
-
-class AtIndexPooler(nn.Module):
- """Pooler that takes the hidden states at given indices. If the index is negative, a learned
- embedding is used.
-
- The indices are expected to have the shape [batch_size, num_indices]. The resulting embeddings are concatenated,
- so the output shape is [batch_size, num_indices * input_dim].
-
- Args:
- input_dim: The input dimension of the hidden state.
- num_indices: The number of indices to pool.
- offset: An offset to add to the indices. This can be useful if the input is prepared with special
- tokens at the beginning / at the end of indexed sequences, and we want to use the hidden state of this
- token instead of the first / last token of the sequence.
-
- Returns:
- The pooled hidden states with shape [batch_size, num_indices * input_dim].
- """
-
- def __init__(self, input_dim: int, num_indices: int = 2, offset: int = 0, **kwargs):
- super().__init__(**kwargs)
- self.input_dim = input_dim
- self.num_indices = num_indices
- self.offset = offset
- self.missing_embeddings = nn.Parameter(torch.empty(num_indices, self.input_dim))
- nn.init.normal_(self.missing_embeddings)
-
- def forward(self, hidden_state: Tensor, indices: Tensor, **kwargs) -> Tensor:
- batch_size, seq_len, hidden_size = hidden_state.shape
- if indices.shape[1] != self.num_indices:
- raise ValueError(
- f"number of indices [{indices.shape[1]}] has to be the same as num_types [{self.num_indices}]"
- )
-
- # respect the offset
- indices = indices + self.offset
-
- # times num_types due to concat
- result = torch.zeros(
- batch_size, hidden_size * self.num_indices, device=hidden_state.device
- )
- for batch_idx, current_indices in enumerate(indices):
- current_embeddings = [
- (
- hidden_state[batch_idx, current_indices[i], :]
- if current_indices[i] >= 0
- else self.missing_embeddings[i]
- )
- for i in range(self.num_indices)
- ]
- result[batch_idx] = cat(current_embeddings, 0)
- return result
-
- @property
- def output_dim(self) -> int:
- return self.input_dim * self.num_indices
-
-
-class ArgumentWrappedPooler(nn.Module):
- """Wraps a pooler and maps the arguments to the pooler.
-
- Args:
- pooler: The pooler to wrap.
- argument_mapping: A mapping from the arguments of the forward method to the arguments of the pooler.
- """
-
- def __init__(
- self, pooler: Union[nn.Module, Callable], argument_mapping: Dict[str, str], **kwargs
- ):
- super().__init__(**kwargs)
- self.pooler = pooler
- self.argument_mapping = argument_mapping
-
- def forward(self, hidden_state: Tensor, **kwargs) -> Tensor:
- pooler_kwargs = {}
- for k, v in kwargs.items():
- if k in self.argument_mapping:
- pooler_kwargs[self.argument_mapping[k]] = v
- return self.pooler(hidden_state, **pooler_kwargs)
-
-
-class SpanMaxPooler(nn.Module):
- """Pooler that takes the max hidden state over spans. If the start or end index is negative, a
- learned.
-
- embedding is used. The indices are expected to have the shape [batch_size, num_indices]. The resulting embeddings
- are concatenated, so the output shape is [batch_size, num_indices * input_dim].
-
- Args:
- input_dim: The input dimension of the hidden state.
- num_indices: The number of indices to pool.
-
- Returns:
- The pooled hidden states with shape [batch_size, num_indices * input_dim].
- """
-
- def __init__(self, input_dim: int, num_indices: int = 2, **kwargs):
- super().__init__(**kwargs)
- self.input_dim = input_dim
- self.num_indices = num_indices
- self.missing_embeddings = nn.Parameter(torch.empty(num_indices, self.input_dim))
- nn.init.normal_(self.missing_embeddings)
-
- def forward(
- self, hidden_state: Tensor, start_indices: Tensor, end_indices: Tensor, **kwargs
- ) -> Tensor:
- batch_size, seq_len, hidden_size = hidden_state.shape
- if start_indices.shape[1] != self.num_indices:
- raise ValueError(
- f"number of start indices [{start_indices.shape[1]}] has to be the same as num_types [{self.num_indices}]"
- )
-
- if end_indices.shape[1] != self.num_indices:
- raise ValueError(
- f"number of end indices [{end_indices.shape[1]}] has to be the same as num_types [{self.num_indices}]"
- )
-
- # check that start_indices are before end_indices
- mask_both_positive = (start_indices >= 0) & (end_indices >= 0)
- mask_start_before_end = start_indices < end_indices
- mask_valid = mask_start_before_end | ~mask_both_positive
- if not torch.all(mask_valid):
- raise ValueError(
- f"values in start_indices have to be smaller than respective values in "
- f"end_indices, but start_indices=\n{start_indices}\n and end_indices=\n{end_indices}"
- )
-
- # times num_indices due to concat
- result = torch.zeros(
- batch_size, hidden_size * self.num_indices, device=hidden_state.device
- )
- for batch_idx in range(batch_size):
- current_start_indices = start_indices[batch_idx]
- current_end_indices = end_indices[batch_idx]
- current_embeddings = [
- (
- torch.amax(
- hidden_state[
- batch_idx, current_start_indices[i] : current_end_indices[i], :
- ],
- 0,
- )
- if current_start_indices[i] >= 0 and current_end_indices[i] >= 0
- else self.missing_embeddings[i]
- )
- for i in range(self.num_indices)
- ]
- result[batch_idx] = cat(current_embeddings, 0)
-
- return result
-
- @property
- def output_dim(self) -> int:
- return self.input_dim * self.num_indices
-
-
-class SpanMeanPooler(nn.Module):
- """Pooler that takes the mean hidden state over spans. If the start or end index is negative, a
- learned embedding is used. The indices are expected to have the shape [batch_size,
- num_indices].
-
- The resulting embeddings are concatenated, so the output shape is [batch_size, num_indices * input_dim].
- Note this a slightly modified version of the pie_modules.models.components.pooler.SpanMaxPooler,
- i.e. we changed the aggregation method from torch.amax to torch.mean.
-
- Args:
- input_dim: The input dimension of the hidden state.
- num_indices: The number of indices to pool.
-
- Returns:
- The pooled hidden states with shape [batch_size, num_indices * input_dim].
- """
-
- def __init__(self, input_dim: int, num_indices: int = 2, **kwargs):
- super().__init__(**kwargs)
- self.input_dim = input_dim
- self.num_indices = num_indices
- self.missing_embeddings = nn.Parameter(torch.empty(num_indices, self.input_dim))
- nn.init.normal_(self.missing_embeddings)
-
- def forward(
- self, hidden_state: Tensor, start_indices: Tensor, end_indices: Tensor, **kwargs
- ) -> Tensor:
- batch_size, seq_len, hidden_size = hidden_state.shape
- if start_indices.shape[1] != self.num_indices:
- raise ValueError(
- f"number of start indices [{start_indices.shape[1]}] has to be the same as num_types [{self.num_indices}]"
- )
-
- if end_indices.shape[1] != self.num_indices:
- raise ValueError(
- f"number of end indices [{end_indices.shape[1]}] has to be the same as num_types [{self.num_indices}]"
- )
-
- # check that start_indices are before end_indices
- mask_both_positive = (start_indices >= 0) & (end_indices >= 0)
- mask_start_before_end = start_indices < end_indices
- mask_valid = mask_start_before_end | ~mask_both_positive
- if not torch.all(mask_valid):
- raise ValueError(
- f"values in start_indices have to be smaller than respective values in "
- f"end_indices, but start_indices=\n{start_indices}\n and end_indices=\n{end_indices}"
- )
-
- # times num_indices due to concat
- result = torch.zeros(
- batch_size, hidden_size * self.num_indices, device=hidden_state.device
- )
- for batch_idx in range(batch_size):
- current_start_indices = start_indices[batch_idx]
- current_end_indices = end_indices[batch_idx]
- current_embeddings = [
- (
- torch.mean(
- hidden_state[
- batch_idx, current_start_indices[i] : current_end_indices[i], :
- ],
- dim=0,
- )
- if current_start_indices[i] >= 0 and current_end_indices[i] >= 0
- else self.missing_embeddings[i]
- )
- for i in range(self.num_indices)
- ]
- result[batch_idx] = cat(current_embeddings, 0)
-
- return result
-
- @property
- def output_dim(self) -> int:
- return self.input_dim * self.num_indices
-
-
-def get_pooler_and_output_size(config: Dict[str, Any], input_dim: int) -> Tuple[Callable, int]:
- pooler_config = dict(config)
- pooler_type = pooler_config.pop("type", CLS_TOKEN)
- if pooler_type == CLS_TOKEN:
- return pool_cls, input_dim
- elif pooler_type == START_TOKENS:
- pooler = AtIndexPooler(input_dim=input_dim, offset=-1, **pooler_config)
- pooler_wrapped = ArgumentWrappedPooler(
- pooler=pooler, argument_mapping={"start_indices": "indices"}
- )
- return pooler_wrapped, pooler.output_dim
- elif pooler_type == MENTION_POOLING:
- aggregate = pooler_config.pop("aggregate", "max")
- if aggregate == "max":
- pooler = SpanMaxPooler(input_dim=input_dim, **pooler_config)
- return pooler, pooler.output_dim
- elif aggregate == "mean":
- pooler = SpanMeanPooler(input_dim=input_dim, **pooler_config)
- return pooler, pooler.output_dim
- else:
- raise ValueError(f'Unknown aggregation method for mention pooling: "{aggregate}"')
- else:
- raise ValueError(f'Unknown pooler type "{pooler_type}"')
diff --git a/src/pie_modules/models/components/seq2seq_encoder.py b/src/pie_modules/models/components/seq2seq_encoder.py
deleted file mode 100644
index 52866be9b..000000000
--- a/src/pie_modules/models/components/seq2seq_encoder.py
+++ /dev/null
@@ -1,77 +0,0 @@
-import logging
-from copy import copy
-from typing import Any, Dict, List, Optional, Tuple
-
-from torch import Tensor, nn
-
-logger = logging.getLogger(__name__)
-
-RNN_TYPE2CLASS = {"lstm": nn.LSTM, "gru": nn.GRU, "rnn": nn.RNN}
-ACTIVATION_TYPE2CLASS = {
- "relu": nn.ReLU,
- "tanh": nn.Tanh,
- "sigmoid": nn.Sigmoid,
- "gelu": nn.GELU,
-}
-
-
-class RNNWrapper(nn.Module):
- def __init__(self, rnn: nn.Module):
- super().__init__()
- self.rnn = rnn
-
- def forward(self, *args, **kwargs) -> Tensor:
- return self.rnn(*args, **kwargs)[0]
-
- @property
- def output_size(self) -> int:
- if self.rnn.bidirectional:
- return self.rnn.hidden_size * 2
- else:
- return self.rnn.hidden_size
-
-
-def build_seq2seq_encoder(
- config: Dict[str, Any], input_size: int
-) -> Tuple[Optional[nn.Module], int]:
- # copy the config to avoid side effects
- config = copy(config)
- seq2seq_encoder_type = config.pop("type", None)
- if seq2seq_encoder_type is None:
- logger.warning(
- f"seq2seq_encoder_type is not specified in the seq2seq_encoder: {config}. "
- f"Do not build this seq2seq_encoder."
- )
- return None, input_size
-
- if seq2seq_encoder_type == "sequential":
- modules: List[nn.Module] = []
- output_size = input_size
- for key, subconfig in config.items():
- module, output_size = build_seq2seq_encoder(subconfig, input_size)
- if module is not None:
- modules.append(module)
- input_size = output_size
-
- seq2seq_encoder = nn.Sequential(*modules)
- elif seq2seq_encoder_type in RNN_TYPE2CLASS:
- rnn_class = RNN_TYPE2CLASS[seq2seq_encoder_type]
- seq2seq_encoder = RNNWrapper(rnn_class(input_size=input_size, batch_first=True, **config))
- output_size = seq2seq_encoder.output_size
- elif seq2seq_encoder_type == "linear":
- seq2seq_encoder = nn.Linear(in_features=input_size, **config)
- output_size = seq2seq_encoder.out_features
- elif seq2seq_encoder_type in ACTIVATION_TYPE2CLASS:
- activation_class = ACTIVATION_TYPE2CLASS[seq2seq_encoder_type]
- seq2seq_encoder = activation_class(**config)
- output_size = input_size
- elif seq2seq_encoder_type == "dropout":
- seq2seq_encoder = nn.Dropout(**config)
- output_size = input_size
- elif seq2seq_encoder_type == "none":
- seq2seq_encoder = None
- output_size = input_size
- else:
- raise ValueError(f"Unknown seq2seq_encoder_type: {seq2seq_encoder_type}")
-
- return seq2seq_encoder, output_size
diff --git a/src/pie_modules/models/interface.py b/src/pie_modules/models/interface.py
deleted file mode 100644
index a5bc0bfad..000000000
--- a/src/pie_modules/models/interface.py
+++ /dev/null
@@ -1,12 +0,0 @@
-class RequiresMaxInputLength:
- """Any class inheriting from this class should require a constructor parameter
- 'max_input_length'."""
-
- pass
-
-
-class RequiresTaskmoduleConfig:
- """Any class inheriting from this class should require a constructor parameter
- 'taskmodule_config'."""
-
- pass
diff --git a/src/pie_modules/models/sequence_classification_with_pooler.py b/src/pie_modules/models/sequence_classification_with_pooler.py
deleted file mode 100644
index bb7d72785..000000000
--- a/src/pie_modules/models/sequence_classification_with_pooler.py
+++ /dev/null
@@ -1,362 +0,0 @@
-import logging
-from abc import ABC, abstractmethod
-from typing import (
- Any,
- Callable,
- Dict,
- Iterator,
- List,
- MutableMapping,
- Optional,
- Tuple,
- TypeVar,
- Union,
-)
-
-import torch
-from pytorch_ie import PyTorchIEModel
-from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
-from torch import FloatTensor, LongTensor, nn
-from torch.nn import Parameter
-from torch.optim import AdamW
-from transformers import (
- AutoConfig,
- AutoModel,
- PreTrainedModel,
- get_linear_schedule_with_warmup,
-)
-from transformers.modeling_outputs import SequenceClassifierOutput
-from typing_extensions import TypeAlias
-
-from .common import ModelWithBoilerplate
-from .components.pooler import get_pooler_and_output_size
-
-# model inputs / outputs / targets
-InputType: TypeAlias = MutableMapping[str, LongTensor]
-OutputType: TypeAlias = SequenceClassifierOutput
-TargetType: TypeAlias = MutableMapping[str, Union[LongTensor, FloatTensor]]
-# step inputs (batch) / outputs (loss)
-StepInputType: TypeAlias = Tuple[InputType, TargetType]
-StepOutputType: TypeAlias = FloatTensor
-
-
-HF_MODEL_TYPE_TO_CLASSIFIER_DROPOUT_ATTRIBUTE = {
- "albert": "classifier_dropout_prob",
- "distilbert": "seq_classif_dropout",
-}
-
-logger = logging.getLogger(__name__)
-
-T = TypeVar("T")
-
-
-def separate_arguments_by_prefix(
- arguments: MutableMapping[str, T], prefixes: List[str]
-) -> Dict[str, Dict[str, T]]:
- result: Dict[str, Dict[str, T]] = {prefix: {} for prefix in prefixes + ["remaining"]}
- for k, v in arguments.items():
- found = False
- for prefix in prefixes:
- if k.startswith(prefix):
- result[prefix][k[len(prefix) :]] = v
- found = True
- break
- if not found:
- result["remaining"][k] = v
- return result
-
-
-class SequenceClassificationModelWithPoolerBase(
- ABC,
- ModelWithBoilerplate[InputType, OutputType, TargetType, StepOutputType],
- RequiresModelNameOrPath,
-):
- """Abstract base model for sequence classification with a pooler.
-
- Args:
- model_name_or_path: The name or path of the HuggingFace model to use.
- tokenizer_vocab_size: The size of the tokenizer vocabulary. If provided, the model's
- tokenizer embeddings are resized to this size.
- classifier_dropout: The dropout probability for the classifier. If not provided, the
- dropout probability is taken from the Huggingface model config.
- learning_rate: The learning rate for the optimizer.
- task_learning_rate: The learning rate for the task-specific parameters. If None, the
- learning rate for all parameters is set to `learning_rate`.
- warmup_proportion: The proportion of steps to warm up the learning rate.
- pooler: The pooler configuration. If None, CLS token pooling is used.
- freeze_base_model: If True, the base model parameters are frozen.
- base_model_prefix: The prefix of the base model parameters when using a task_learning_rate
- or freeze_base_model. If None, the base_model_prefix of the model is used.
- **kwargs: Additional keyword arguments passed to the parent class,
- see :class:`ModelWithBoilerplate`.
- """
-
- def __init__(
- self,
- model_name_or_path: str,
- tokenizer_vocab_size: Optional[int] = None,
- classifier_dropout: Optional[float] = None,
- learning_rate: float = 1e-5,
- task_learning_rate: Optional[float] = None,
- warmup_proportion: float = 0.1,
- pooler: Optional[Union[Dict[str, Any], str]] = None,
- freeze_base_model: bool = False,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
-
- self.save_hyperparameters()
-
- self.learning_rate = learning_rate
- self.task_learning_rate = task_learning_rate
- self.warmup_proportion = warmup_proportion
- self.freeze_base_model = freeze_base_model
- self.model_name_or_path = model_name_or_path
-
- self.model = self.setup_base_model()
-
- if tokenizer_vocab_size is not None:
- self.model.resize_token_embeddings(tokenizer_vocab_size)
-
- if self.freeze_base_model:
- for param in self.model.parameters():
- param.requires_grad = False
-
- if classifier_dropout is None:
- # Get the classifier dropout value from the Huggingface model config.
- # This is a bit of a mess since some Configs use different variable names or change the semantics
- # of the dropout (e.g. DistilBert has one dropout prob for QA and one for Seq classification, and a
- # general one for embeddings, encoder and pooler).
- classifier_dropout_attr = HF_MODEL_TYPE_TO_CLASSIFIER_DROPOUT_ATTRIBUTE.get(
- self.model.config.model_type, "classifier_dropout"
- )
- classifier_dropout = getattr(self.model.config, classifier_dropout_attr) or 0.0
- self.dropout = nn.Dropout(classifier_dropout)
-
- if isinstance(pooler, str):
- pooler = {"type": pooler}
- self.pooler_config = pooler or {}
- self.pooler, pooler_output_dim = self.setup_pooler(input_dim=self.model.config.hidden_size)
- self.classifier = self.setup_classifier(pooler_output_dim=pooler_output_dim)
- self.loss_fct = self.setup_loss_fct()
-
- def setup_base_model(self) -> PreTrainedModel:
- config = AutoConfig.from_pretrained(self.model_name_or_path)
- if self.is_from_pretrained:
- return AutoModel.from_config(config=config)
- else:
- return AutoModel.from_pretrained(self.model_name_or_path, config=config)
-
- @abstractmethod
- def setup_classifier(self, pooler_output_dim: int) -> Callable:
- pass
-
- @abstractmethod
- def setup_loss_fct(self) -> Callable:
- pass
-
- def setup_pooler(self, input_dim: int) -> Tuple[Callable, int]:
- """Set up the pooler. The pooler is used to get a representation of the input sequence(s)
- that can be used by the classifier. It is a callable that takes the hidden states of the
- base model (and additional model inputs that are prefixed with "pooler_") and returns the
- pooled output.
-
- Args:
- input_dim: The input dimension of the pooler, i.e. the hidden size of the base model.
-
- Returns:
- A tuple with the pooler and the output dimension of the pooler.
- """
- return get_pooler_and_output_size(config=self.pooler_config, input_dim=input_dim)
-
- def get_pooled_output(self, model_inputs, pooler_inputs) -> torch.FloatTensor:
- output = self.model(**model_inputs)
- hidden_state = output.last_hidden_state
- pooled_output = self.pooler(hidden_state, **pooler_inputs)
- pooled_output = self.dropout(pooled_output)
- return pooled_output
-
- def forward(
- self,
- inputs: InputType,
- targets: Optional[TargetType] = None,
- return_hidden_states: bool = False,
- ) -> OutputType:
- sanitized_inputs = separate_arguments_by_prefix(arguments=inputs, prefixes=["pooler_"])
-
- pooled_output = self.get_pooled_output(
- model_inputs=sanitized_inputs["remaining"], pooler_inputs=sanitized_inputs["pooler_"]
- )
-
- logits = self.classifier(pooled_output)
-
- result = {"logits": logits}
- if targets is not None:
- labels = targets["labels"]
- loss = self.loss_fct(logits, labels)
- result["loss"] = loss
- if return_hidden_states:
- raise NotImplementedError("return_hidden_states is not yet implemented")
-
- return SequenceClassifierOutput(**result)
-
- @abstractmethod
- def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
- pass
-
- def base_model_named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]:
- if prefix:
- prefix = f"{prefix}."
- return self.model.named_parameters(prefix=f"{prefix}model")
-
- def task_named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]:
- if prefix:
- prefix = f"{prefix}."
- base_model_parameter_names = dict(self.base_model_named_parameters(prefix=prefix)).keys()
- for name, param in self.named_parameters(prefix=prefix):
- if name not in base_model_parameter_names:
- yield name, param
-
- def configure_optimizers(self):
- if self.task_learning_rate is not None:
- base_model_params = (param for name, param in self.base_model_named_parameters())
- task_params = (param for name, param in self.task_named_parameters())
- optimizer = AdamW(
- [
- {"params": base_model_params, "lr": self.learning_rate},
- {"params": task_params, "lr": self.task_learning_rate},
- ]
- )
- else:
- optimizer = AdamW(self.parameters(), lr=self.learning_rate)
-
- if self.warmup_proportion > 0.0:
- stepping_batches = self.trainer.estimated_stepping_batches
- scheduler = get_linear_schedule_with_warmup(
- optimizer, int(stepping_batches * self.warmup_proportion), stepping_batches
- )
- return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
- else:
- return optimizer
-
-
-@PyTorchIEModel.register()
-class SequenceClassificationModelWithPooler(
- SequenceClassificationModelWithPoolerBase,
- RequiresNumClasses,
-):
- """A sequence classification model that uses a pooler to get a representation of the input
- sequence and then applies a linear classifier to that representation. The pooler can be
- configured via the `pooler` argument, see :func:`get_pooler_and_output_size` for details.
-
- Args:
- num_classes: The number of classes for the classification task.
- multi_label: If True, the model is trained as a multi-label classifier.
- multi_label_threshold: The threshold for the multi-label classifier, i.e. the probability
- above which a class is predicted.
- **kwargs
- """
-
- def __init__(
- self,
- num_classes: int,
- multi_label: bool = False,
- multi_label_threshold: float = 0.5,
- **kwargs,
- ):
- # set num_classes and multi_label before call to super init because they are used there
- # in setup_classifier and setup_loss_fct
- self.num_classes = num_classes
- self.multi_label = multi_label
- super().__init__(**kwargs)
-
- self.multi_label_threshold = multi_label_threshold
-
- def setup_classifier(self, pooler_output_dim: int) -> Callable:
- return nn.Linear(pooler_output_dim, self.num_classes)
-
- def setup_loss_fct(self) -> Callable:
- return nn.BCEWithLogitsLoss() if self.multi_label else nn.CrossEntropyLoss()
-
- def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
- if not self.multi_label:
- labels = torch.argmax(outputs.logits, dim=-1).to(torch.long)
- probabilities = torch.softmax(outputs.logits, dim=-1)
- else:
- probabilities = torch.sigmoid(outputs.logits)
- labels = (probabilities > self.multi_label_threshold).to(torch.long)
- return {"labels": labels, "probabilities": probabilities}
-
-
-@PyTorchIEModel.register()
-class SequencePairSimilarityModelWithPooler(
- SequenceClassificationModelWithPoolerBase,
-):
- """A span pair similarity model to detect of two spans occurring in different texts are
- similar. It uses an encoder to independently calculate contextualized embeddings of both texts,
- then uses a pooler to get representations of the spans and, finally, calculates the cosine to
- get the similarity scores.
-
- Args:
- label_threshold: The threshold above which score the spans are considered as similar.
- pooler: The pooler identifier or config, see :func:`get_pooler_and_output_size` for details.
- Defaults to "mention_pooling" (max pooling over the span token embeddings).
- **kwargs
- """
-
- def __init__(
- self,
- pooler: Optional[Union[Dict[str, Any], str]] = None,
- **kwargs,
- ):
- if pooler is None:
- # use (max) mention pooling per default
- pooler = {"type": "mention_pooling", "num_indices": 1}
- super().__init__(pooler=pooler, **kwargs)
-
- def setup_classifier(
- self, pooler_output_dim: int
- ) -> Callable[[torch.FloatTensor, torch.FloatTensor], torch.FloatTensor]:
- return torch.nn.functional.cosine_similarity
-
- def setup_loss_fct(self) -> Callable:
- return nn.BCELoss()
-
- def forward(
- self,
- inputs: InputType,
- targets: Optional[TargetType] = None,
- return_hidden_states: bool = False,
- ) -> OutputType:
- sanitized_inputs = separate_arguments_by_prefix(
- # Note that the order of the prefixes is important because one is a prefix of the other,
- # so we need to start with the longer!
- arguments=inputs,
- prefixes=["pooler_pair_", "pooler_"],
- )
-
- pooled_output = self.get_pooled_output(
- model_inputs=sanitized_inputs["remaining"]["encoding"],
- pooler_inputs=sanitized_inputs["pooler_"],
- )
- pooled_output_pair = self.get_pooled_output(
- model_inputs=sanitized_inputs["remaining"]["encoding_pair"],
- pooler_inputs=sanitized_inputs["pooler_pair_"],
- )
-
- logits = self.classifier(pooled_output, pooled_output_pair)
-
- result = {"logits": logits}
- if targets is not None:
- labels = targets["scores"]
- loss = self.loss_fct(logits, labels)
- result["loss"] = loss
- if return_hidden_states:
- raise NotImplementedError("return_hidden_states is not yet implemented")
-
- return SequenceClassifierOutput(**result)
-
- def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
- # probabilities = torch.sigmoid(outputs.logits)
- scores = outputs.logits
- return {"scores": scores}
diff --git a/src/pie_modules/models/simple_extractive_question_answering.py b/src/pie_modules/models/simple_extractive_question_answering.py
deleted file mode 100644
index d54436515..000000000
--- a/src/pie_modules/models/simple_extractive_question_answering.py
+++ /dev/null
@@ -1,165 +0,0 @@
-from typing import Any, Dict, MutableMapping, Optional, Tuple
-
-from pytorch_ie import PyTorchIEModel
-from pytorch_ie.models.interface import RequiresModelNameOrPath
-from pytorch_lightning.utilities.types import OptimizerLRScheduler
-from torch import Tensor
-from torch.nn import ModuleDict, functional
-from torch.optim import Adam
-from torchmetrics import F1Score
-from transformers import (
- AutoConfig,
- AutoModelForQuestionAnswering,
- BatchEncoding,
- get_linear_schedule_with_warmup,
-)
-from transformers.modeling_outputs import QuestionAnsweringModelOutput
-from typing_extensions import TypeAlias
-
-from pie_modules.models.interface import RequiresMaxInputLength
-
-BatchOutput: TypeAlias = Dict[str, Any]
-
-# The input to the forward method of this model. It is passed to
-# the base transformer model.
-ModelInputType: TypeAlias = MutableMapping[str, Any]
-# The output of the forward method of this model.
-ModelOutputType: TypeAlias = QuestionAnsweringModelOutput
-# The input to the step methods, i.e. training_step, validation_step, test_step.
-# It contains the input and target tensors for a single training step.
-StepBatchEncoding: TypeAlias = Tuple[
- ModelInputType,
- Optional[Dict[str, Tensor]],
-]
-
-
-TRAINING = "train"
-VALIDATION = "val"
-TEST = "test"
-
-
-@PyTorchIEModel.register()
-class SimpleExtractiveQuestionAnsweringModel(
- PyTorchIEModel, RequiresModelNameOrPath, RequiresMaxInputLength
-):
- """A PIE model for extractive question answering. It is a simple Pytorch-Lightning module that
- wraps around a question answering model from the Huggingface transformers library. The
- ExtractiveQuestionAnsweringTaskModule can be used create the input and target encodings as well
- as to decode the model output.
-
- Args:
- model_name_or_path: The name (Huggingface Hub model identifier) or local path of the model to use.
- max_input_length: The maximum length of the input sequence. Required for metric calculation.
- learning_rate: The learning rate to use for training. Defaults to 1e-5.
- """
-
- def __init__(
- self,
- model_name_or_path: str,
- max_input_length: int,
- learning_rate: float = 1e-5,
- warmup_proportion: float = 0.0,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.save_hyperparameters()
-
- self.learning_rate = learning_rate
- self.warmup_proportion = warmup_proportion
- self.max_input_length = max_input_length
-
- config = AutoConfig.from_pretrained(model_name_or_path)
- if self.is_from_pretrained:
- self.model = AutoModelForQuestionAnswering.from_config(config=config)
- else:
- self.model = AutoModelForQuestionAnswering.from_pretrained(
- model_name_or_path, config=config
- )
-
- self.f1_start: Dict[str, F1Score] = ModuleDict(
- {
- f"stage_{stage}": F1Score(task="multiclass", num_classes=max_input_length)
- for stage in [TRAINING, VALIDATION, TEST]
- }
- )
- self.f1_end: Dict[str, F1Score] = ModuleDict(
- {
- f"stage_{stage}": F1Score(task="multiclass", num_classes=max_input_length)
- for stage in [TRAINING, VALIDATION, TEST]
- }
- )
-
- def forward(self, inputs: BatchEncoding) -> ModelOutputType:
- return self.model(**inputs)
-
- def step(
- self,
- stage: str,
- batch: StepBatchEncoding,
- ) -> Tensor:
- inputs, targets = batch
- if targets is None:
- raise ValueError("targets has to be available for training, but it is None")
-
- output = self({**inputs, **targets})
-
- loss = output.loss
- # show loss on each step only during training
- self.log(f"{stage}/loss", loss, on_step=(stage == TRAINING), on_epoch=True, prog_bar=True)
-
- start_positions = targets["start_positions"]
- end_positions = targets["end_positions"]
- start_logits = output.start_logits
- end_logits = output.end_logits
-
- sequence_length = inputs["input_ids"].size(1)
- f1_start = self.f1_end[f"stage_{stage}"]
- # We need to pad the logits to the max_input_length, otherwise the F1 metric complains
- # that the shape does not match the num_classes.
- start_logits_padded = functional.pad(
- start_logits, (0, self.max_input_length - sequence_length), value=float("-inf")
- )
- f1_start(start_logits_padded, start_positions)
- self.log(
- f"{stage}/f1_start",
- f1_start,
- on_step=(stage == TRAINING),
- on_epoch=True,
- prog_bar=True,
- )
- f1_end = self.f1_end[f"stage_{stage}"]
- # We need to pad the logits to the max_input_length, otherwise the F1 metric complains
- # that the shape does not match the num_classes.
- end_logits_padded = functional.pad(
- end_logits, (0, self.max_input_length - sequence_length), value=float("-inf")
- )
- f1_end(end_logits_padded, end_positions)
- self.log(
- f"{stage}/f1_end", f1_end, on_step=(stage == TRAINING), on_epoch=True, prog_bar=True
- )
- # log f1 as simple average of start and end f1. we need to call compute() on the metric to get
- # the actual value, otherwise lightning complains that there is no model attribute with name "f1"
- f1_value = (f1_start.compute() + f1_end.compute()) / 2
- self.log(f"{stage}/f1", f1_value, on_step=False, on_epoch=True, prog_bar=True)
- return loss
-
- def training_step(self, batch: StepBatchEncoding, batch_idx: int) -> Tensor:
- return self.step(stage=TRAINING, batch=batch)
-
- def validation_step(self, batch: StepBatchEncoding, batch_idx: int) -> Tensor:
- return self.step(stage=VALIDATION, batch=batch)
-
- def test_step(self, batch: StepBatchEncoding, batch_idx: int) -> Tensor:
- return self.step(stage=TEST, batch=batch)
-
- def configure_optimizers(self) -> OptimizerLRScheduler:
- optimizer = Adam(self.parameters(), lr=self.learning_rate)
-
- if self.warmup_proportion > 0.0:
- stepping_batches = self.trainer.estimated_stepping_batches
- scheduler = get_linear_schedule_with_warmup(
- optimizer, int(stepping_batches * self.warmup_proportion), stepping_batches
- )
- return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
- else:
- return optimizer
diff --git a/src/pie_modules/models/simple_generative.py b/src/pie_modules/models/simple_generative.py
deleted file mode 100644
index e5d7d4549..000000000
--- a/src/pie_modules/models/simple_generative.py
+++ /dev/null
@@ -1,196 +0,0 @@
-import copy
-import logging
-from typing import Any, Dict, Optional, Tuple, Type, Union
-
-import torch
-from pie_core.utils.hydra import resolve_type
-from pytorch_ie import PyTorchIEModel
-from pytorch_lightning.utilities.types import OptimizerLRScheduler
-from torch import FloatTensor, LongTensor
-from torch.optim import Optimizer
-from transformers import PreTrainedModel, SchedulerType, get_scheduler
-from transformers.modeling_outputs import Seq2SeqLMOutput
-from typing_extensions import TypeAlias
-
-from pie_modules.models.common import ModelWithBoilerplate
-
-logger = logging.getLogger(__name__)
-
-# model inputs / outputs / targets
-InputType: TypeAlias = Dict[str, LongTensor]
-OutputType: TypeAlias = Seq2SeqLMOutput
-TargetType: TypeAlias = Dict[str, LongTensor]
-# step inputs (batch) / outputs (loss)
-StepInputType: TypeAlias = Tuple[InputType, TargetType]
-StepOutputType: TypeAlias = FloatTensor
-
-
-@PyTorchIEModel.register()
-class SimpleGenerativeModel(
- ModelWithBoilerplate[InputType, OutputType, TargetType, StepOutputType],
-):
- """This model is a simple wrapper around a generative model from Huggingface transformers. That
- means, its predict() and predict_step() methods will call the generate() method of the base
- model.
-
- If a taskmodule config is provided, the taskmodule will be instantiated and used to create metrics and
- a generation config with its configure_model_metric() and configure_model_generation() methods,
- respectively.
-
- If the base model has a configure_optimizer() method, this will be used to create the optimizer. Otherwise,
- the optimizer_type and learning_rate will be used to create an optimizer.
-
- Args:
- base_model_type: The type of the base model, e.g. "transformers.AutoModelForSeq2SeqLM". It should have a
- from_pretrained() method.
- base_model_config: A dictionary with the keyword arguments that will be passed to the from_pretrained()
- method of the base model.
- override_generation_kwargs: The generation config for the base model. This will override the generation config
- from the taskmodule, if one is provided.
- warmup_proportion: The proportion of the training steps that will be used for the warmup of the learning rate
- scheduler.
- learning_rate: The learning rate for the optimizer. If the base model has a configure_optimizer() method, this
- will be ignored.
- optimizer_type: The type of the optimizer. If the base model has a configure_optimizer() method, this will be
- ignored.
- **kwargs: Additional keyword arguments that will be passed to the PyTorchIEModel constructor.
- """
-
- def __init__(
- self,
- # base model setup
- base_model: Optional[Dict[str, Any]] = None,
- # old setup
- base_model_type: Optional[str] = None,
- base_model_config: Optional[Dict[str, Any]] = None,
- # generation
- override_generation_kwargs: Optional[Dict[str, Any]] = None,
- # optimizer / schedular
- # important: the following entries (optimizer_type and learning_rate) are only used
- # if the base model does not have a configure_optimizer method!
- optimizer_type: Optional[Union[str, Type[Optimizer]]] = None,
- learning_rate: Optional[float] = None,
- warmup_proportion: float = 0.0,
- scheduler_name: Optional[Union[str, SchedulerType]] = None,
- scheduler_kwargs: Optional[Dict[str, Any]] = None,
- **kwargs,
- ):
- super().__init__(**kwargs)
-
- if base_model is None:
- if base_model_type is None:
- raise ValueError(
- "Either base_model or base_model_type must be provided. If base_model is not provided, "
- "base_model_type must be a valid model type, e.g. 'transformers.AutoModelForSeq2SeqLM'."
- )
- logger.warning(
- "The base_model_type and base_model_config arguments are deprecated. Please use base_model. "
- "You can use the following code to create the base_model argument: "
- "base_model = {'_type_': base_model_type, **base_model_config}"
- )
- base_model = {"_type_": base_model_type, **(base_model_config or {})}
-
- if scheduler_name is None and warmup_proportion > 0.0:
- logger.warning(
- "warmup_proportion is set to a value > 0.0, but scheduler_name is not set. "
- "Setting scheduler_name to 'linear' by default."
- )
- scheduler_name = "linear"
-
- self.save_hyperparameters(ignore=["base_model_type", "base_model_config"])
-
- # optimizer / scheduler
- self.learning_rate = learning_rate
- self.optimizer_type = optimizer_type
- self.scheduler_name = scheduler_name
- self.warmup_proportion = warmup_proportion
- self.scheduler_kwargs = scheduler_kwargs or {}
-
- self.model = self.setup_base_model(config=base_model)
- self.generation_config = self.configure_generation(**(override_generation_kwargs or {}))
-
- def setup_base_model(self, config: Dict[str, Any]) -> PreTrainedModel:
- config = copy.copy(config)
- resolved_base_model_type: Type[PreTrainedModel] = resolve_type(config.pop("_type_"))
- return resolved_base_model_type.from_pretrained(**config)
-
- def configure_generation(self, **kwargs) -> Dict[str, Any]:
- if self.taskmodule is not None:
- # get the generation config from the taskmodule
- generation_config = self.taskmodule.configure_model_generation()
- else:
- logger.warning(
- "No taskmodule is available, so no generation config will be created. Consider "
- "setting taskmodule_config to a valid taskmodule config to use specific setup for generation."
- )
- generation_config = {}
- generation_config.update(kwargs)
- return generation_config
-
- def predict(self, inputs, **kwargs) -> TargetType:
- is_training = self.training
- self.eval()
-
- generation_kwargs = copy.deepcopy(self.generation_config)
- generation_kwargs.update(kwargs)
- outputs = self.model.generate(**inputs, **generation_kwargs)
-
- if is_training:
- self.train()
-
- # TODO: move into base model? or does this work for "all" generative models?
- # strip the bos_id
- if isinstance(outputs, torch.Tensor):
- return {"labels": outputs[:, 1:]}
- else:
- raise ValueError(f"Unsupported output type: {type(outputs)}")
-
- def forward(self, inputs: InputType, targets: Optional[TargetType] = None) -> OutputType:
- kwargs = {**inputs, **(targets or {})}
- return self.model(**kwargs)
-
- def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
- # construct prediction from the model output
- logits = outputs.logits
- # get the indices (these are without the initial bos_ids, see above)
- prediction = torch.argmax(logits, dim=-1)
- return {"labels": prediction.to(torch.long)}
-
- def configure_optimizers(self) -> OptimizerLRScheduler:
- if hasattr(self.model, "configure_optimizer") and callable(self.model.configure_optimizer):
- if self.learning_rate is not None:
- raise ValueError(
- f"learning_rate is set to {self.learning_rate}, but the *base model* ({type(self.model)}) has a "
- f"configure_optimizer method. Please set learning_rate to None and configure the optimizer "
- f"inside the *base model*."
- )
- optimizer = self.model.configure_optimizer()
- else:
- logger.warning(
- f"The model does not have a configure_optimizer method. Creating an optimizer of "
- f"optimizer_type={self.optimizer_type} with the learning_rate={self.learning_rate} instead."
- )
- if self.optimizer_type is None:
- raise ValueError(
- f"optimizer_type is None, but the *base model* ({type(self.model)}) does not have a "
- f"configure_optimizer method. Please set the optimizer_type to a valid optimizer type, "
- f"e.g. optimizer_type=torch.optim.Adam."
- )
- resolved_optimizer_type = resolve_type(
- self.optimizer_type, expected_super_type=Optimizer
- )
- optimizer = resolved_optimizer_type(self.parameters(), lr=self.learning_rate)
-
- if self.scheduler_name is not None:
- num_training_steps = self.trainer.estimated_stepping_batches
- num_warmup_steps = int(num_training_steps * self.warmup_proportion)
- scheduler = get_scheduler(
- name=self.scheduler_name,
- optimizer=optimizer,
- num_warmup_steps=num_warmup_steps,
- num_training_steps=num_training_steps,
- scheduler_specific_kwargs=self.scheduler_kwargs,
- )
- return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
- else:
- return optimizer
diff --git a/src/pie_modules/models/simple_sequence_classification.py b/src/pie_modules/models/simple_sequence_classification.py
deleted file mode 100644
index 651a043a2..000000000
--- a/src/pie_modules/models/simple_sequence_classification.py
+++ /dev/null
@@ -1,140 +0,0 @@
-import logging
-from typing import Iterator, MutableMapping, Optional, Tuple, Union
-
-import torch.nn
-from pytorch_ie import PyTorchIEModel
-from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
-from torch import FloatTensor, LongTensor
-from torch.nn import Parameter
-from torch.optim import AdamW
-from transformers import (
- AutoConfig,
- AutoModelForSequenceClassification,
- get_linear_schedule_with_warmup,
-)
-from transformers.modeling_outputs import SequenceClassifierOutput
-from typing_extensions import TypeAlias
-
-from pie_modules.models.common import ModelWithBoilerplate
-
-# model inputs / outputs / targets
-InputType: TypeAlias = MutableMapping[str, LongTensor]
-OutputType: TypeAlias = SequenceClassifierOutput
-TargetType: TypeAlias = MutableMapping[str, Union[LongTensor, FloatTensor]]
-# step inputs (batch) / outputs (loss)
-StepInputType: TypeAlias = Tuple[InputType, TargetType]
-StepOutputType: TypeAlias = FloatTensor
-
-
-logger = logging.getLogger(__name__)
-
-
-@PyTorchIEModel.register()
-class SimpleSequenceClassificationModel(
- ModelWithBoilerplate[InputType, OutputType, TargetType, StepOutputType],
- RequiresModelNameOrPath,
- RequiresNumClasses,
-):
- """A simple sequence classification model. It wraps a HuggingFace
- AutoModelForSequenceClassification and adds boilerplate code for training and inference.
-
- Args:
- model_name_or_path: The name or path of the HuggingFace model to use.
- num_classes: The number of classes for the classification task.
- tokenizer_vocab_size: The size of the tokenizer vocabulary. If provided, the model's
- tokenizer embeddings are resized to this size.
- learning_rate: The learning rate for the optimizer.
- task_learning_rate: The learning rate for the task-specific parameters. If None, the
- learning rate for all parameters is set to `learning_rate`.
- warmup_proportion: The proportion of steps to warm up the learning rate.
- freeze_base_model: If True, the base model parameters are frozen.
- base_model_prefix: The prefix of the base model parameters when using a task_learning_rate
- or freeze_base_model. If None, the base_model_prefix of the model is used.
- **kwargs: Additional keyword arguments passed to the parent class,
- see :class:`ModelWithBoilerplate`.
- """
-
- def __init__(
- self,
- model_name_or_path: str,
- num_classes: int,
- tokenizer_vocab_size: Optional[int] = None,
- learning_rate: float = 1e-5,
- task_learning_rate: Optional[float] = None,
- warmup_proportion: float = 0.1,
- freeze_base_model: bool = False,
- base_model_prefix: Optional[str] = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
-
- self.save_hyperparameters()
-
- self.learning_rate = learning_rate
- self.task_learning_rate = task_learning_rate
- self.warmup_proportion = warmup_proportion
- self.freeze_base_model = freeze_base_model
-
- config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_classes)
- if self.is_from_pretrained:
- self.model = AutoModelForSequenceClassification.from_config(config=config)
- else:
- self.model = AutoModelForSequenceClassification.from_pretrained(
- model_name_or_path, config=config
- )
-
- self.base_model_prefix = base_model_prefix or self.model.base_model_prefix
-
- if tokenizer_vocab_size is not None:
- self.model.resize_token_embeddings(tokenizer_vocab_size)
-
- if self.freeze_base_model:
- for name, param in self.base_model_named_parameters():
- param.requires_grad = False
-
- def base_model_named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]:
- base_model: torch.nn.Module = getattr(self.model, self.base_model_prefix, None)
- if base_model is None:
- raise ValueError(
- f"Base model with prefix '{self.base_model_prefix}' not found in {type(self.model).__name__}"
- )
- if prefix:
- prefix = f"{prefix}."
- return base_model.named_parameters(prefix=f"{prefix}model.{self.base_model_prefix}")
-
- def task_named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]:
- base_model_parameter_names = dict(self.base_model_named_parameters(prefix=prefix)).keys()
- for name, param in self.named_parameters(prefix=prefix):
- if name not in base_model_parameter_names:
- yield name, param
-
- def forward(self, inputs: InputType, targets: Optional[TargetType] = None) -> OutputType:
- kwargs = {**inputs, **(targets or {})}
- return self.model(**kwargs)
-
- def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
- labels = torch.argmax(outputs.logits, dim=-1).to(torch.long)
- probabilities = torch.softmax(outputs.logits, dim=-1)
- return {"labels": labels, "probabilities": probabilities}
-
- def configure_optimizers(self):
- if self.task_learning_rate is not None:
- base_model_params = [param for name, param in self.base_model_named_parameters()]
- task_params = [param for name, param in self.task_named_parameters()]
- optimizer = AdamW(
- [
- {"params": base_model_params, "lr": self.learning_rate},
- {"params": task_params, "lr": self.task_learning_rate},
- ]
- )
- else:
- optimizer = AdamW(self.parameters(), lr=self.learning_rate)
-
- if self.warmup_proportion > 0.0:
- stepping_batches = self.trainer.estimated_stepping_batches
- scheduler = get_linear_schedule_with_warmup(
- optimizer, int(stepping_batches * self.warmup_proportion), stepping_batches
- )
- return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
- else:
- return optimizer
diff --git a/src/pie_modules/models/simple_token_classification.py b/src/pie_modules/models/simple_token_classification.py
deleted file mode 100644
index 7879aae04..000000000
--- a/src/pie_modules/models/simple_token_classification.py
+++ /dev/null
@@ -1,97 +0,0 @@
-import logging
-from typing import MutableMapping, Optional, Tuple, Union
-
-import torch
-from pytorch_ie import PyTorchIEModel
-from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
-from pytorch_lightning.utilities.types import OptimizerLRScheduler
-from torch import FloatTensor, LongTensor
-from transformers import AutoConfig, AutoModelForTokenClassification, BatchEncoding
-from transformers.modeling_outputs import TokenClassifierOutput
-from typing_extensions import TypeAlias
-
-from pie_modules.models.common import ModelWithBoilerplate
-
-# model inputs / outputs / targets
-InputType: TypeAlias = BatchEncoding
-OutputType: TypeAlias = TokenClassifierOutput
-TargetType: TypeAlias = MutableMapping[str, Union[LongTensor, FloatTensor]]
-# step inputs (batch) / outputs (loss)
-StepInputType: TypeAlias = Tuple[InputType, TargetType]
-StepOutputType: TypeAlias = FloatTensor
-
-
-logger = logging.getLogger(__name__)
-
-
-@PyTorchIEModel.register()
-class SimpleTokenClassificationModel(
- ModelWithBoilerplate[InputType, OutputType, TargetType, StepOutputType],
- RequiresModelNameOrPath,
- RequiresNumClasses,
-):
- """A simple token classification model that wraps a (pretrained) model loaded with
- AutoModelForTokenClassification from the transformers library.
-
- The model is trained with a cross-entropy loss function and uses the Adam optimizer.
-
- Note that for training, the labels for the special tokens (as well as for padding tokens)
- are expected to have the value label_pad_id (-100 by default, which is the default ignore_index
- value for the CrossEntropyLoss). The predictions for these tokens are also replaced with
- label_pad_id to match the training labels for correct metric calculation. Therefore, the model
- requires the special_tokens_mask and attention_mask (for padding) to be passed as inputs.
-
- Args:
- model_name_or_path: The name or path of the pretrained transformer model to use.
- num_classes: The number of classes to predict.
- learning_rate: The learning rate to use for training.
- label_pad_id: The label id to use for padding labels (at the padding token positions
- as well as for the special tokens).
- """
-
- def __init__(
- self,
- model_name_or_path: str,
- num_classes: int,
- learning_rate: float = 1e-5,
- label_pad_id: int = -100,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.save_hyperparameters()
-
- self.learning_rate = learning_rate
- self.label_pad_id = label_pad_id
- self.num_classes = num_classes
-
- config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_classes)
- if self.is_from_pretrained:
- self.model = AutoModelForTokenClassification.from_config(config=config)
- else:
- self.model = AutoModelForTokenClassification.from_pretrained(
- model_name_or_path, config=config
- )
-
- def forward(self, inputs: InputType, targets: Optional[TargetType] = None) -> OutputType:
- inputs_without_special_tokens_mask = {
- k: v for k, v in inputs.items() if k != "special_tokens_mask"
- }
- return self.model(**inputs_without_special_tokens_mask, **(targets or {}))
-
- def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
- # get the max index for each token from the logits
- tags_tensor = torch.argmax(outputs.logits, dim=-1).to(torch.long)
-
- # mask out the padding and special tokens
- tags_tensor = tags_tensor.masked_fill(inputs["attention_mask"] == 0, self.label_pad_id)
-
- # mask out the special tokens
- tags_tensor = tags_tensor.masked_fill(
- inputs["special_tokens_mask"] == 1, self.label_pad_id
- )
- probabilities = torch.softmax(outputs.logits, dim=-1)
-
- return {"labels": tags_tensor, "probabilities": probabilities}
-
- def configure_optimizers(self) -> OptimizerLRScheduler:
- return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
diff --git a/src/pie_modules/models/span_tuple_classification.py b/src/pie_modules/models/span_tuple_classification.py
deleted file mode 100644
index bbdbf4938..000000000
--- a/src/pie_modules/models/span_tuple_classification.py
+++ /dev/null
@@ -1,457 +0,0 @@
-import logging
-from dataclasses import dataclass
-from typing import Iterator, List, MutableMapping, Optional, Tuple, TypeVar, Union
-
-import torch
-from pytorch_ie import PyTorchIEModel
-from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
-from torch import BoolTensor, FloatTensor, LongTensor, Tensor, nn
-from torch.nn import Dropout, Parameter
-from torch.optim import AdamW
-from transformers import AutoConfig, AutoModel, get_linear_schedule_with_warmup
-from transformers.utils import ModelOutput
-from typing_extensions import TypeAlias
-
-from .common import ModelWithBoilerplate
-
-
-class MLP(nn.Module):
- def __init__(self, n_in, n_out, dropout=0, activation=nn.GELU()):
- super().__init__()
- self.linear = nn.Linear(n_in, n_out)
- self.f = activation
- self.dropout = Dropout(p=dropout)
- self.reset_parameters()
-
- def reset_parameters(self):
- nn.init.xavier_normal_(self.linear.weight)
- nn.init.zeros_(self.linear.bias)
-
- def forward(self, x):
- x = self.f(self.linear(x))
- x = self.dropout(x)
- return x
-
-
-@dataclass
-class SpanPairClassifierOutput(ModelOutput):
- """Base class for outputs of span pair classification models.
-
- Args:
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) :
- Classification loss.
- logits (`torch.FloatTensor` of shape `(num_valid_input_pairs_in_batch, config.num_labels)`):
- Classification scores (before SoftMax).
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
- The last hidden state of the transformer model. Returned if `return_embeddings=True`.
- span_embeddings (`torch.FloatTensor` of shape `(batch_size, num_spans, span_embedding_dim)`, *optional*):
- The embeddings of the spans. Returned if `return_embeddings=True`.
- tuple_embeddings (`torch.FloatTensor` of shape `(num_valid_input_pairs_in_batch, tuple_embedding_dim)`, *optional*):
- The embeddings of the tuples. Returned if `return_embeddings=True`.
- """
-
- loss: Optional[torch.FloatTensor] = None
- logits: torch.FloatTensor = None
- last_hidden_state: Optional[torch.FloatTensor] = None
- span_embeddings: Optional[torch.FloatTensor] = None
- tuple_embeddings: Optional[torch.FloatTensor] = None
-
-
-# model inputs / outputs / targets
-InputType: TypeAlias = MutableMapping[str, LongTensor]
-OutputType: TypeAlias = SpanPairClassifierOutput
-TargetType: TypeAlias = MutableMapping[str, Union[LongTensor, FloatTensor]]
-# step inputs (batch) / outputs (loss)
-StepInputType: TypeAlias = Tuple[InputType, TargetType]
-StepOutputType: TypeAlias = FloatTensor
-
-
-HF_MODEL_TYPE_TO_CLASSIFIER_DROPOUT_ATTRIBUTE = {
- "albert": "classifier_dropout_prob",
- "distilbert": "seq_classif_dropout",
-}
-
-logger = logging.getLogger(__name__)
-
-T = TypeVar("T", bound=Tensor)
-
-
-def get_embeddings_at_indices(embeddings: T, indices: LongTensor) -> T:
- # embeddings: (bs, seq_len, hidden_size)
- # indices: (bs, num_indices)
- hidden_size = embeddings.size(-1)
- # Expand dimensions of start_marker_positions to match hidden_states
- indices_expanded = indices.unsqueeze(-1).expand(-1, -1, hidden_size)
- # result: (bs, num_indices, hidden_size)
- result = embeddings.gather(1, indices_expanded)
- return result
-
-
-@PyTorchIEModel.register()
-class SpanTupleClassificationModel(
- ModelWithBoilerplate[InputType, OutputType, TargetType, StepOutputType],
- RequiresModelNameOrPath,
- RequiresNumClasses,
-):
- """A span tuple classification model that uses a pooler to get a representation of the input
- spans and then applies a linear classifier to that representation. The pooler can be configured
- via the `span_embedding_mode` and `tuple_embedding_mode` arguments. It expects the input to
- contain the indices of the start and end tokens of the spans (for the span pooler) and the
- indices of the spans in the tuples to classify (for the tuple pooler).
-
- Args:
- model_name_or_path: The name or path of the HuggingFace model to use.
- num_classes: The number of classes for the classification task.
- span_embedding_mode: The mode to pool the hidden states for the spans. One of "start_token",
- "end_token", "start_and_end_token".
- tuple_embedding_mode: The mode to pool the span embeddings for the tuples. Possible values are
- "concat" (concatenate the embeddings of the tuple entries), "multiply2_and_concat"
- (multiply the embeddings of the first two entries and concatenate them with the
- embeddings of the first two entries) and "index_{idx}" (use the embedding of the entry
- at index {idx} as the tuple embedding). Note that "multiply2_and_concat" requires
- `num_tuple_entries=2`. Default: "multiply2_and_concat".
- num_tuple_entries: The number of entries in the tuples.
- tuple_entry_hidden_dim: If provided, the tuple entries (i.e. the span embeddings at the tuple indices)
- are mapped to this dimensionality before combining them. Default: 768.
- tokenizer_vocab_size: The size of the tokenizer vocabulary. If provided, the model's
- tokenizer embeddings are resized to this size.
- classifier_dropout: The dropout probability for the classifier. If not provided, the
- dropout probability is taken from the Huggingface model config.
- learning_rate: The learning rate for the optimizer.
- task_learning_rate: The learning rate for the task-specific parameters. If None, the
- learning rate for all parameters is set to `learning_rate`.
- warmup_proportion: The proportion of steps to warm up the learning rate.
- multi_label: If True, the model is trained as a multi-label classifier.
- multi_label_threshold: The threshold for the multi-label classifier, i.e. the probability
- above which a class is predicted.
- freeze_base_model: If True, the base model parameters are frozen.
- label_pad_value: The padding value for the labels.
- probability_pad_value: The padding value for the probabilities.
- **kwargs: Additional keyword arguments passed to the parent class,
- see :class:`ModelWithBoilerplate`.
- """
-
- def __init__(
- self,
- model_name_or_path: str,
- num_classes: int,
- span_embedding_mode: str = "start_and_end_token",
- tuple_embedding_mode: str = "multiply2_and_concat",
- num_tuple_entries: int = 2,
- tuple_entry_hidden_dim: Optional[int] = 768,
- tokenizer_vocab_size: Optional[int] = None,
- classifier_dropout: Optional[float] = None,
- learning_rate: float = 1e-5,
- task_learning_rate: Optional[float] = None,
- warmup_proportion: float = 0.1,
- multi_label: bool = False,
- multi_label_threshold: float = 0.5,
- freeze_base_model: bool = False,
- label_pad_value: int = -100,
- probability_pad_value: float = -1.0,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
-
- self.save_hyperparameters()
-
- self.learning_rate = learning_rate
- self.task_learning_rate = task_learning_rate
- self.warmup_proportion = warmup_proportion
- self.freeze_base_model = freeze_base_model
- self.label_pad_value = label_pad_value
- self.probability_pad_value = probability_pad_value
-
- config = AutoConfig.from_pretrained(model_name_or_path)
- if self.is_from_pretrained:
- self.model = AutoModel.from_config(config=config)
- else:
- self.model = AutoModel.from_pretrained(model_name_or_path, config=config)
-
- if tokenizer_vocab_size is not None:
- self.model.resize_token_embeddings(tokenizer_vocab_size)
-
- if self.freeze_base_model:
- for param in self.model.parameters():
- param.requires_grad = False
-
- if classifier_dropout is None:
- # Get the classifier dropout value from the Huggingface model config.
- # This is a bit of a mess since some Configs use different variable names or change the semantics
- # of the dropout (e.g. DistilBert has one dropout prob for QA and one for Seq classification, and a
- # general one for embeddings, encoder and pooler).
- classifier_dropout_attr = HF_MODEL_TYPE_TO_CLASSIFIER_DROPOUT_ATTRIBUTE.get(
- config.model_type, "classifier_dropout"
- )
- classifier_dropout = getattr(config, classifier_dropout_attr) or 0.0
- self.dropout = nn.Dropout(classifier_dropout)
-
- # embedder for the spans
- self.span_embedding_mode = span_embedding_mode
- if self.span_embedding_mode in ["start_token", "end_token"]:
- self.span_embedding_dim = self.model.config.hidden_size
- elif self.span_embedding_mode in ["start_and_end_token"]:
- self.span_embedding_dim = self.model.config.hidden_size * 2
- else:
- raise ValueError(f"Invalid value for span_embedding_mode: {self.span_embedding_mode}")
-
- # embedder for the tuples
- self.num_tuple_entries = num_tuple_entries
- self.tuple_entry_hidden_dim = tuple_entry_hidden_dim
- if self.tuple_entry_hidden_dim is not None:
- self.tuple_entry_embedders = nn.ModuleList(
- [
- MLP(self.span_embedding_dim, self.tuple_entry_hidden_dim)
- for _ in range(num_tuple_entries)
- ]
- )
- tuple_entry_dim = self.tuple_entry_hidden_dim
- else:
- self.tuple_entry_embedders = None
- tuple_entry_dim = self.span_embedding_dim
- self.tuple_embedding_mode = tuple_embedding_mode
- if self.tuple_embedding_mode == "concat":
- tuple_embedding_dim = tuple_entry_dim * self.num_tuple_entries
- elif self.tuple_embedding_mode == "multiply2_and_concat":
- if self.num_tuple_entries != 2:
- raise ValueError(
- "tuple_embedding_mode='multiply2_and_concat' requires num_tuple_entries=2"
- )
- tuple_embedding_dim = tuple_entry_dim * 3
- elif self.tuple_embedding_mode.startswith("index_"):
- idx = int(self.tuple_embedding_mode.split("_")[1])
- if idx >= self.num_tuple_entries:
- raise ValueError(
- f"Invalid index IDX={idx} for tuple_embedding_mode='index_IDX'. "
- f"Number of entries in tuple: {self.num_tuple_entries}"
- )
- tuple_embedding_dim = tuple_entry_dim
- else:
- raise ValueError(
- f"Invalid value for tuple_embedding_mode: {self.tuple_embedding_mode}"
- )
-
- # classifier
- # TODO: do sth more sophisticated here
- self.classifier = nn.Linear(tuple_embedding_dim, num_classes)
-
- self.multi_label = multi_label
- self.multi_label_threshold = multi_label_threshold
- self.loss_fct = nn.BCEWithLogitsLoss() if self.multi_label else nn.CrossEntropyLoss()
-
- def span_embedder(
- self,
- hidden_state: FloatTensor,
- span_start_indices: LongTensor,
- span_end_indices: LongTensor,
- ) -> FloatTensor:
- """Create the span embeddings from the hidden states and the span start and end indices.
-
- Args:
- hidden_state: The last hidden state from the transformer model. shape: (batch_size, seq_len, hidden_size)
- span_start_indices: The indices of the start tokens of the spans. shape: (batch_size, num_spans)
- span_end_indices: The indices of the end tokens of the spans. shape: (batch_size, num_spans)
-
- Returns:
- The pooled span embeddings. shape: (batch_size, num_spans, hidden_size)
- """
-
- if self.span_embedding_mode == "start_token":
- span_embeddings = get_embeddings_at_indices(hidden_state, span_start_indices)
- elif self.span_embedding_mode == "end_token":
- span_embeddings = get_embeddings_at_indices(hidden_state, span_end_indices)
- elif self.span_embedding_mode == "start_and_end_token":
- span_embeddings = torch.cat(
- [
- get_embeddings_at_indices(hidden_state, span_start_indices),
- get_embeddings_at_indices(hidden_state, span_end_indices),
- ],
- dim=-1,
- )
- else:
- raise ValueError(f"Invalid value for span_embedding_mode: {self.span_embedding_mode}")
-
- return span_embeddings
-
- def tuple_embedder(
- self,
- span_embeddings: FloatTensor,
- tuple_indices: LongTensor,
- tuple_indices_mask: BoolTensor,
- ) -> FloatTensor:
- """Create the tuple embeddings from the span embeddings and the tuple indices.
-
- Args:
- span_embeddings: The span embeddings. shape: (batch_size, num_spans, span_embedding_size)
- tuple_indices: The indices of the spans in the tuples. shape: (batch_size, num_tuples, num_tuple_entries)
- tuple_indices_mask: A mask indicating which tuples are valid. shape: (batch_size, num_tuples)
-
- Returns:
- The pooled tuple embeddings. shape: (num_tuples_in_batch, num_tuple_entries * span_embedding_size)
- """
-
- if not tuple_indices.shape[-1] == self.num_tuple_entries:
- raise ValueError(
- f"Number of entries in tuple_indices should be equal to num_tuple_entries={self.num_tuple_entries}"
- )
- batch_size, max_num_spans = span_embeddings.shape[:2]
- # we need to add the batch offsets to the tuple indices to get the correct indices in the
- # flattened span_embeddings
- batch_offsets = (
- torch.arange(batch_size, device=tuple_indices.device).unsqueeze(-1).unsqueeze(-1)
- * max_num_spans
- )
- tuple_indices_with_offsets = tuple_indices + batch_offsets
- # shape: (num_tuples_in_batch, num_entries)
- valid_tuple_indices_flat = tuple_indices_with_offsets[tuple_indices_mask]
-
- # we need to flatten the span_embeddings to get the embeddings at the tuple indices
- # shape: (batch_size * num_spans, span_embedding_size)
- span_embeddings_flat = span_embeddings.view(-1, span_embeddings.size(-1))
-
- # map the span embeddings individually for each tuple entry
- # each entry has the shape: (batch_size * num_spans, tuple_entry_dim)
- if self.tuple_entry_embedders is not None:
- span_embeddings_mapped = [
- mlp(span_embeddings_flat) for mlp in self.tuple_entry_embedders
- ]
- else:
- span_embeddings_mapped = [span_embeddings_flat] * self.num_tuple_entries
-
- tuple_embeddings_list: List[FloatTensor] = []
- for i in range(self.num_tuple_entries):
- # shape: (num_tuples_in_batch)
- current_tuple_indices = valid_tuple_indices_flat[:, i]
- # get the embeddings that were mapped with the mlp for the current entry
- # shape: (batch_size * num_spans, tuple_entry_dim)
- span_embeddings_mapped_for_entry = span_embeddings_mapped[i]
- # shape: (num_tuples_in_batch, tuple_entry_dim)
- current_embeddings = span_embeddings_mapped_for_entry[current_tuple_indices]
- tuple_embeddings_list.append(current_embeddings)
- if self.tuple_embedding_mode == "concat":
- tuple_embeddings = torch.cat(tuple_embeddings_list, dim=-1).to(span_embeddings.dtype)
- elif self.tuple_embedding_mode == "multiply2_and_concat":
- tuple_embeddings = torch.cat(
- [
- tuple_embeddings_list[0] * tuple_embeddings_list[1],
- tuple_embeddings_list[0],
- tuple_embeddings_list[1],
- ],
- dim=-1,
- )
- elif self.tuple_embedding_mode.startswith("index_"):
- index = int(self.tuple_embedding_mode.split("_")[1])
- tuple_embeddings = tuple_embeddings_list[index]
- else:
- raise ValueError(
- f"Invalid value for tuple_embedding_mode: {self.tuple_embedding_mode}"
- )
- return tuple_embeddings
-
- def forward(
- self,
- inputs: InputType,
- targets: Optional[TargetType] = None,
- return_embeddings: bool = False,
- ) -> OutputType:
- span_embedder_inputs = {}
- tuple_embedder_inputs = {}
- base_model_inputs = {}
- for k, v in inputs.items():
- if k.startswith("span_"):
- span_embedder_inputs[k] = v
- elif k.startswith("tuple_"):
- tuple_embedder_inputs[k] = v
- else:
- base_model_inputs[k] = v
-
- output = self.model(**base_model_inputs)
- last_hidden_state = self.dropout(output.last_hidden_state)
-
- # get the span embeddings from the hidden states and the start and end marker positions
- span_embeddings = self.span_embedder(
- hidden_state=last_hidden_state, **span_embedder_inputs
- )
- # get the tuple embeddings from the span embeddings and the tuple indices
- # Note that this flattens the batch dimension to not compute embeddings for padding tuples!
- tuple_embeddings_flat = self.tuple_embedder(
- span_embeddings=span_embeddings, **tuple_embedder_inputs
- )
-
- logits_valid = self.classifier(tuple_embeddings_flat)
-
- result = {"logits": logits_valid}
- if targets is not None:
- labels = targets["labels"]
- mask = inputs["tuple_indices_mask"]
- valid_labels = labels[mask]
- loss = self.loss_fct(logits_valid, valid_labels)
- result["loss"] = loss
-
- if return_embeddings:
- result["last_hidden_state"] = last_hidden_state
- result["tuple_embeddings"] = tuple_embeddings_flat
- result["span_embeddings"] = span_embeddings
-
- return SpanPairClassifierOutput(**result)
-
- def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
- if not self.multi_label:
- labels_flat = torch.argmax(outputs.logits, dim=-1).to(torch.long)
- probabilities_flat = torch.softmax(outputs.logits, dim=-1)
- else:
- probabilities_flat = torch.sigmoid(outputs.logits)
- labels_flat = (probabilities_flat > self.multi_label_threshold).to(torch.long)
-
- # re-construct the original shape
- mask = inputs["tuple_indices_mask"]
- # create "empty" labels and probabilities tensors
- labels = (
- torch.ones(mask.shape, dtype=torch.long, device=labels_flat.device)
- * self.label_pad_value
- )
- prob_shape = list(mask.shape) + [probabilities_flat.shape[-1]]
- probabilities = (
- torch.ones(prob_shape, dtype=torch.float, device=probabilities_flat.device)
- * self.probability_pad_value
- )
- # fill in the valid values
- labels[mask] = labels_flat
- probabilities[mask] = probabilities_flat
-
- return {"labels": labels, "probabilities": probabilities}
-
- def base_model_named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]:
- if prefix:
- prefix = f"{prefix}."
- return self.model.named_parameters(prefix=f"{prefix}model")
-
- def task_named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]:
- if prefix:
- prefix = f"{prefix}."
- base_model_parameter_names = dict(self.base_model_named_parameters(prefix=prefix)).keys()
- for name, param in self.named_parameters(prefix=prefix):
- if name not in base_model_parameter_names:
- yield name, param
-
- def configure_optimizers(self):
- if self.task_learning_rate is not None:
- base_model_params = (param for name, param in self.base_model_named_parameters())
- task_params = (param for name, param in self.task_named_parameters())
- optimizer = AdamW(
- [
- {"params": base_model_params, "lr": self.learning_rate},
- {"params": task_params, "lr": self.task_learning_rate},
- ]
- )
- else:
- optimizer = AdamW(self.parameters(), lr=self.learning_rate)
-
- if self.warmup_proportion > 0.0:
- stepping_batches = self.trainer.estimated_stepping_batches
- scheduler = get_linear_schedule_with_warmup(
- optimizer, int(stepping_batches * self.warmup_proportion), stepping_batches
- )
- return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
- else:
- return optimizer
diff --git a/src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py b/src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py
deleted file mode 100644
index e5b0b6e16..000000000
--- a/src/pie_modules/models/token_classification_with_seq2seq_encoder_and_crf.py
+++ /dev/null
@@ -1,247 +0,0 @@
-import logging
-from typing import Any, Dict, MutableMapping, Optional, Tuple, Union
-
-import torch
-from pytorch_ie import PyTorchIEModel
-from pytorch_ie.models.interface import RequiresModelNameOrPath, RequiresNumClasses
-from pytorch_lightning.utilities.types import OptimizerLRScheduler
-from torch import FloatTensor, LongTensor, nn
-from transformers import (
- AutoConfig,
- AutoModel,
- BatchEncoding,
- get_linear_schedule_with_warmup,
-)
-from transformers.modeling_outputs import TokenClassifierOutput
-from typing_extensions import TypeAlias
-
-from pie_modules.models.common import ModelWithBoilerplate
-from pie_modules.models.components.seq2seq_encoder import build_seq2seq_encoder
-
-# model inputs / outputs / targets
-InputType: TypeAlias = BatchEncoding
-OutputType: TypeAlias = TokenClassifierOutput
-TargetType: TypeAlias = MutableMapping[str, Union[LongTensor, FloatTensor]]
-# step inputs (batch) / outputs (loss)
-StepInputType: TypeAlias = Tuple[InputType, TargetType]
-StepOutputType: TypeAlias = FloatTensor
-
-HF_MODEL_TYPE_TO_CLASSIFIER_DROPOUT_ATTRIBUTE = {
- "bert": "hidden_dropout_prob",
- "roberta": "hidden_dropout_prob",
- "albert": "classifier_dropout_prob",
- "distilbert": "seq_classif_dropout",
- "deberta-v2": "hidden_dropout_prob",
- "longformer": "hidden_dropout_prob",
-}
-
-logger = logging.getLogger(__name__)
-
-
-@PyTorchIEModel.register()
-class TokenClassificationModelWithSeq2SeqEncoderAndCrf(
- ModelWithBoilerplate[InputType, OutputType, TargetType, StepOutputType],
- RequiresNumClasses,
- RequiresModelNameOrPath,
-):
- """A token classification model that wraps a (pretrained) model loaded with AutoModel from the
- transformers library. The model can optionally be followed by a seq2seq encoder (e.g. an LSTM).
- Finally, Conditional Random Fields (CRFs) can be used to decode the predictions.
-
- The model is trained with a cross-entropy loss function and uses the Adam optimizer.
-
- Note that for training, the labels for the special tokens (as well as for padding tokens)
- are expected to have the value label_pad_id (-100 by default, which is the default ignore_index
- value for the CrossEntropyLoss). The predictions for these tokens are also replaced with
- label_pad_id to match the training labels for correct metric calculation. Therefore, the model
- requires the special_tokens_mask and attention_mask (for padding) to be passed as inputs.
-
- Args:
- model_name_or_path: The name or path of the (pretrained) transformer model to use.
- num_classes: The number of classes to predict.
- learning_rate: The learning rate to use for training.
- task_learning_rate: The learning rate to use for the task-specific parameters, i.e.
- for the sequence-to-sequence encoder, classification head, and CRF. If None, the
- learning_rate is used for all parameters.
- use_crf: Whether to use a CRF to decode the predictions.
- label_pad_id: The label id to use for padding labels (at the padding token positions
- as well as for the special tokens).
- special_token_label_id: The label id to use for special tokens (e.g. [CLS], [SEP]). This
- is used to replace the targets for special tokens with the label_pad_id before passing
- them to the CRF because the CRF does not allow the first token to be masked out.
- classifier_dropout: The dropout probability to use for the classification head.
- freeze_base_model: Whether to freeze the base model (i.e. the transformer) during training.
- warmup_proportion: The proportion of training steps to use for the linear warmup.
- seq2seq_encoder: A dictionary with the configuration for the seq2seq encoder. If None, no
- seq2seq encoder is used. See ./components/seq2seq_encoder.py for further information.
- """
-
- def __init__(
- self,
- model_name_or_path: str,
- num_classes: int,
- learning_rate: float = 1e-5,
- task_learning_rate: Optional[float] = None,
- use_crf: bool = True,
- label_pad_id: int = -100,
- special_token_label_id: int = 0,
- classifier_dropout: Optional[float] = None,
- freeze_base_model: bool = False,
- warmup_proportion: float = 0.1,
- seq2seq_encoder: Optional[Dict[str, Any]] = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.save_hyperparameters()
-
- self.special_token_label_id = special_token_label_id
-
- self.learning_rate = learning_rate
- self.warmup_proportion = warmup_proportion
- self.task_learning_rate = task_learning_rate
- self.label_pad_id = label_pad_id
- self.num_classes = num_classes
-
- config = AutoConfig.from_pretrained(model_name_or_path)
- if self.is_from_pretrained:
- self.model = AutoModel.from_config(config=config)
- else:
- self.model = AutoModel.from_pretrained(model_name_or_path, config=config)
-
- if freeze_base_model:
- self.model.requires_grad_(False)
-
- hidden_size = config.hidden_size
- self.seq2seq_encoder = None
- if seq2seq_encoder is not None:
- self.seq2seq_encoder, hidden_size = build_seq2seq_encoder(
- config=seq2seq_encoder, input_size=hidden_size
- )
-
- if classifier_dropout is None:
- # Get the classifier dropout value from the Huggingface model config.
- # This is a bit of a mess since some Configs use different variable names or change the semantics
- # of the dropout (e.g. DistilBert has one dropout prob for QA and one for Seq classification, and a
- # general one for embeddings, encoder and pooler).
- classifier_dropout_attr = HF_MODEL_TYPE_TO_CLASSIFIER_DROPOUT_ATTRIBUTE.get(
- config.model_type, "classifier_dropout"
- )
- if hasattr(config, classifier_dropout_attr):
- classifier_dropout = getattr(config, classifier_dropout_attr)
- else:
- raise ValueError(
- f"The config {type(config),__name__} loaded from {model_name_or_path} has no attribute "
- f"{classifier_dropout_attr}"
- )
- self.dropout = nn.Dropout(classifier_dropout)
-
- self.classifier = nn.Linear(hidden_size, num_classes)
-
- if use_crf:
- try:
- from torchcrf import CRF
- except ImportError:
- raise ImportError(
- "To use CRFs, the torchcrf package must be installed. "
- "You can install it with `pip install pytorch-crf`."
- )
-
- self.crf = CRF(num_tags=num_classes, batch_first=True)
- else:
- self.crf = None
-
- def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
- result = {}
- logits = outputs.logits
- attention_mask = inputs["attention_mask"]
- special_tokens_mask = inputs["special_tokens_mask"]
- attention_mask_bool = attention_mask.to(torch.bool)
- if self.crf is not None:
- decoded_tags = self.crf.decode(emissions=logits, mask=attention_mask_bool)
- # pad the decoded tags to the length of the logits to have the same shape as when not using the crf
- seq_len = logits.shape[1]
- padded_tags = [
- tags + [self.label_pad_id] * (seq_len - len(tags)) for tags in decoded_tags
- ]
- tags_tensor = torch.tensor(padded_tags, device=logits.device).to(torch.long)
- else:
- # get the max index for each token from the logits
- tags_tensor = torch.argmax(logits, dim=-1).to(torch.long)
- # set the padding and special tokens to the label_pad_id
- mask = attention_mask_bool & ~special_tokens_mask.to(torch.bool)
- tags_tensor = tags_tensor.masked_fill(~mask, self.label_pad_id)
-
- result["labels"] = tags_tensor
- # TODO: is it correct to use this also in the case of the crf?
- result["probabilities"] = torch.softmax(logits, dim=-1)
-
- return result
-
- def forward(
- self, inputs: InputType, targets: Optional[TargetType] = None
- ) -> TokenClassifierOutput:
- inputs_without_special_tokens_mask = {
- k: v for k, v in inputs.items() if k != "special_tokens_mask"
- }
- outputs = self.model(**inputs_without_special_tokens_mask)
- sequence_output = outputs[0]
-
- if self.seq2seq_encoder is not None:
- sequence_output = self.seq2seq_encoder(sequence_output)
-
- sequence_output = self.dropout(sequence_output)
- logits = self.classifier(sequence_output)
-
- loss = None
- if targets is not None:
- labels = targets["labels"]
- if self.crf is not None:
- # Overwrite the padding labels with ignore_index. Note that this is different from the
- # attention_mask, because the attention_mask includes special tokens, whereas the labels
- # are set to label_pad_id also for special tokens (e.g. [CLS]). We need handle all
- # occurrences of label_pad_id because usually that index is out of range with respect to
- # the number of logits in which case the crf would complain. However, we can not simply
- # pass a mask to the crf that also masks out the special tokens, because the crf does not
- # allow the first token to be masked out.
- mask_pad_or_special = labels == self.label_pad_id
- labels_valid = labels.masked_fill(mask_pad_or_special, self.special_token_label_id)
- # the crf expects a bool mask
- if "attention_mask" in inputs:
- mask_bool = inputs["attention_mask"].to(torch.bool)
- else:
- mask_bool = None
- log_likelihood = self.crf(emissions=logits, tags=labels_valid, mask=mask_bool)
- loss = -log_likelihood
- else:
- loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_pad_id)
- loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
-
- return TokenClassifierOutput(
- loss=loss,
- logits=logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
-
- def configure_optimizers(self) -> OptimizerLRScheduler:
- if self.task_learning_rate is not None:
- all_params = dict(self.named_parameters())
- base_model_params = dict(self.model.named_parameters(prefix="model"))
- task_params = {k: v for k, v in all_params.items() if k not in base_model_params}
- optimizer = torch.optim.AdamW(
- [
- {"params": base_model_params.values(), "lr": self.learning_rate},
- {"params": task_params.values(), "lr": self.task_learning_rate},
- ]
- )
- else:
- optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
-
- if self.warmup_proportion > 0.0:
- stepping_batches = self.trainer.estimated_stepping_batches
- scheduler = get_linear_schedule_with_warmup(
- optimizer, int(stepping_batches * self.warmup_proportion), stepping_batches
- )
- return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
- else:
- return optimizer
diff --git a/src/pie_modules/taskmodules/__init__.py b/src/pie_modules/taskmodules/__init__.py
deleted file mode 100644
index 46d3766ae..000000000
--- a/src/pie_modules/taskmodules/__init__.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from .cross_text_binary_coref import CrossTextBinaryCorefTaskModule
-from .extractive_question_answering import ExtractiveQuestionAnsweringTaskModule
-from .labeled_span_extraction_by_token_classification import (
- LabeledSpanExtractionByTokenClassificationTaskModule,
-)
-from .pointer_network_for_end2end_re import PointerNetworkTaskModuleForEnd2EndRE
-from .re_span_pair_classification import RESpanPairClassificationTaskModule
-from .re_text_classification_with_indices import (
- RETextClassificationWithIndicesTaskModule,
-)
-from .text_to_text import TextToTextTaskModule
diff --git a/src/pie_modules/taskmodules/common/__init__.py b/src/pie_modules/taskmodules/common/__init__.py
index e95a3f006..e69de29bb 100644
--- a/src/pie_modules/taskmodules/common/__init__.py
+++ b/src/pie_modules/taskmodules/common/__init__.py
@@ -1,4 +0,0 @@
-from .interfaces import AnnotationEncoderDecoder, DecodingException
-from .mixins import BatchableMixin, RelationStatisticsMixin, StatisticsMixin
-from .taskmodule_with_document_converter import TaskModuleWithDocumentConverter
-from .utils import get_first_occurrence_index
diff --git a/src/pie_modules/taskmodules/common/interfaces.py b/src/pie_modules/taskmodules/common/interfaces.py
deleted file mode 100644
index 388b60ae9..000000000
--- a/src/pie_modules/taskmodules/common/interfaces.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import abc
-from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar
-
-from pie_core import Annotation
-
-# Annotation Encoding type: encoding for a single annotation
-AE = TypeVar("AE")
-# Annotation type
-A = TypeVar("A", bound=Annotation)
-# Annotation Collection Encoding type: encoding for a collection of annotations,
-# e.g. all relevant annotations for a document
-ACE = TypeVar("ACE")
-
-
-class DecodingException(Exception, Generic[AE], abc.ABC):
- """Exception raised when decoding fails."""
-
- identifier: str
-
- def __init__(self, message: str, encoding: AE):
- self.message = message
- self.encoding = encoding
-
-
-class AnnotationEncoderDecoder(abc.ABC, Generic[A, AE]):
- """Base class for annotation encoders and decoders."""
-
- @abc.abstractmethod
- def encode(self, annotation: A, metadata: Optional[Dict[str, Any]] = None) -> AE:
- pass
-
- @abc.abstractmethod
- def decode(self, encoding: AE, metadata: Optional[Dict[str, Any]] = None) -> A:
- pass
-
- def build_decoding_constraints(
- self, partial_encoding: AE
- ) -> Tuple[Optional[Any], Optional[Any]]:
- """Given a partial encoding, build the constraints for the next encoding step.
-
- Returns:
- - A tuple of two elements:
- - The first element is a set of positive constraints for the decoder.
- - The second element is a set of negative constraints for the decoder.
- """
- raise NotImplementedError(
- "build_decoder_constraints is not implemented for this encoder/decoder."
- )
-
- def parse(self, encoding: AE) -> Tuple[List[A], Dict[str, int], AE]:
- """Parse the encoding and return a list of annotations. This should be error tolerant and
- return all annotations that can be parsed and the remaining encoding.
-
- Args:
- encoding: The encoding to parse. Can be incomplete.
-
- Returns:
- - A tuple of three elements:
- - A list of encoded annotations.
- - A dictionary mapping error types to their counts.
- - The remaining encoding after parsing.
- """
- raise NotImplementedError("parse is not implemented for this encoder/decoder.")
diff --git a/src/pie_modules/taskmodules/common/mixins.py b/src/pie_modules/taskmodules/common/mixins.py
deleted file mode 100644
index 0e2909893..000000000
--- a/src/pie_modules/taskmodules/common/mixins.py
+++ /dev/null
@@ -1,297 +0,0 @@
-import dataclasses
-import logging
-from abc import ABC, abstractmethod
-from collections import defaultdict
-from typing import Any, Dict, Generic, Iterable, List, Optional, Tuple, TypeVar
-
-import pandas as pd
-import torch
-import torch.nn.functional as F
-from pie_core import Annotation
-from torch import Tensor
-
-logger = logging.getLogger(__name__)
-
-
-def _pad_tensor(tensor: Tensor, target_shape: List[int], pad_value: float) -> Tensor:
- """Pad a tensor to a target shape.
-
- Args:
- tensor: The tensor to pad.
- target_shape: The target shape.
- pad_value: The value to use for padding.
-
- Returns: The padded tensor.
- """
-
- shape = tensor.shape
- pad: List[int] = []
- for i, s in enumerate(shape):
- pad = [0, target_shape[i] - s] + pad
- result = F.pad(tensor, pad=pad, value=pad_value)
-
- return result
-
-
-def maybe_pad_values(
- values: Any, pad_value: Optional[Any] = None, strategy: str = "longest"
-) -> Any:
- """If an iterable of values is passed and a pad value is given, pad the values to the same
- length and create a tensor from them. Otherwise, return the values unchanged.
-
- Note that the padding is done on all dimensions.
-
- Args:
- values: The values to pad.
- pad_value: The value to use for padding.
- strategy: The padding strategy. Currently only "longest" is supported.
-
- Returns: The padded values.
- """
-
- if pad_value is None:
- return values
- if not isinstance(values, Iterable):
- raise TypeError(f"values must be iterable to pad them, but got {type(values)}")
- if strategy != "longest":
- raise ValueError(f"unknown padding strategy: {strategy}")
- tensor_list = [torch.tensor(value_list) for value_list in values]
- shape_lists = list(zip(*[t.shape for t in tensor_list]))
- max_shape = [max(dims) for dims in shape_lists]
- padded = [
- _pad_tensor(tensor=t, target_shape=max_shape, pad_value=pad_value)
- for i, t in enumerate(tensor_list)
- ]
- return torch.stack(padded)
-
-
-def maybe_to_tensor(
- values: Iterable[Any], dtype: Optional[torch.dtype] = None, pad_value: Optional[Any] = None
-) -> Any:
- """If an iterable of values is passed and a dtype is given, convert the values to a tensor of
- the given type.
-
- Args:
- values: The values to convert.
- dtype: A dtype to convert the values to.
- pad_value: A pad value to use if the values are padded.
-
- Returns: A tensor or the values unchanged.
- """
-
- if all(v is None for v in values):
- return None
- if dtype is None:
- return values
- maybe_padded = maybe_pad_values(values=values, pad_value=pad_value)
- if not isinstance(maybe_padded, torch.Tensor):
- maybe_padded = torch.Tensor(maybe_padded)
- tensor = maybe_padded.to(dtype=dtype)
- return tensor
-
-
-class BatchableMixin:
- """A mixin class that provides a batch method to batch a list of instances of the class. All
- attributes, but also property methods, are batched. The batch method returns a dictionary with
- all attribute / property names as keys. The values are tensors created from the stacked values
- of the attributes / properties. The tensors are padded to the length of the longest instance in
- the batch and converted to the given dtype.
-
- Example:
- >>> import dataclasses
- >>> from typing import List, Dict
- >>> import torch
- >>>
- >>> @dataclasses.dataclass
- >>> class Foo(BatchableMixin):
- >>> a: List[int]
- >>>
- >>> @property
- >>> def len_a(self):
- >>> return len(self.a)
- >>>
- >>> x = Foo(a=[1, 2, 3])
- >>> y = Foo(a=[4, 5])
- >>>
- >>> Foo.batch(values=[x, y], dtypes={"a": torch.int64, "len_a": torch.int64}, pad_values={"a": 0})
- {'a': tensor([[1, 2, 3],[4, 5, 0]]), 'len_a': tensor([3, 2])}
- """
-
- @classmethod
- def get_property_names(cls) -> List[str]:
- return [name for name in cls.__dict__ if isinstance(getattr(cls, name), property)]
-
- @classmethod
- def get_dataclass_field_names(cls) -> List[str]:
- if dataclasses.is_dataclass(cls):
- return [f.name for f in dataclasses.fields(cls)]
- else:
- return []
-
- @classmethod
- def get_attribute_names(cls) -> List[str]:
- return cls.get_property_names() + cls.get_dataclass_field_names()
-
- @classmethod
- def batch(
- cls,
- values: List[Any],
- dtypes: Dict[str, torch.dtype],
- pad_values: Dict[str, Any],
- ) -> Dict[str, Any]:
- attribute_names = cls.get_attribute_names()
- return {
- k: maybe_to_tensor(
- values=[getattr(x, k) for x in values],
- dtype=dtypes.get(k, None),
- pad_value=pad_values.get(k, None),
- )
- for k in attribute_names
- # Only batch attributes that are not None for any of the values.
- if not all(getattr(x, k) is None for x in values)
- }
-
-
-T = TypeVar("T")
-
-
-def increase_counter(
- key: Tuple[Any, ...],
- statistics: Dict[Tuple[Any, ...], int],
- value: int = 1,
-):
- key_s = tuple(str(k) for k in key)
- statistics[key_s] += value
-
-
-class StatisticsMixin(ABC, Generic[T]):
- """A mixin class that provides methods to collect and format statistics.
-
- Args:
- collect_statistics: Control whether statistics should be collected.
- If `False`, the mixin will not show any statistics when calling
- `show_statistics`. Further effects depend on the implementation
- of the mixin.
- **kwargs: Additional keyword arguments to pass to the parent class.
- """
-
- def __init__(self, collect_statistics: bool = False, **kwargs):
- super().__init__(**kwargs)
- self.collect_statistics = collect_statistics
- self.reset_statistics()
-
- @abstractmethod
- def reset_statistics(self):
- """Reset the statistics collected by this mixin (state)."""
- pass
-
- @abstractmethod
- def get_statistics(self) -> T:
- """Get the statistics collected by this mixin.
-
- This should *not* modify the state of the mixin, repeated calls should return the same
- result!
- """
- pass
-
- def format_statistics(self, statistics: T) -> str:
- """Format the statistics collected by this mixin as string for display (usually on
- console)."""
- raise NotImplementedError(
- f"format_statistics is not implemented for {self.__class__.__name__}. "
- "Please implement this method to show formatted statistics."
- )
-
- def show_statistics(self):
- if self.collect_statistics:
- logger.info(f"statistics:\n{self.format_statistics(self.get_statistics())}")
-
-
-class RelationStatisticsMixin(StatisticsMixin[Dict[Tuple[str, str], int]]):
- """A mixin class that provides methods to collect and format statistics about relations.
-
- This mixin collects statistics about relations, such as the number of available, used, and
- skipped relations.
- """
-
- def get_none_label_for_statistics(self) -> str:
- if not hasattr(self, "_statistics_none_label"):
- if hasattr(self, "none_label"):
- # If the mixin has a `none_label` attribute, use it as the label for "no relation".
- self._statistics_none_label = self.none_label
- else:
- self._statistics_none_label = "no_relation"
- logger.warning(
- f"{type(self).__name__} does not have a `none_label` attribute. "
- "Using default value 'no_relation'. "
- "`none_label` is used as the label for relations with score 0 in statistics and "
- "all relations with label different from `none_label` will be summarized to 'all_relations'. "
- "Set the `none_label` attribute before using statistics or "
- "overwrite `get_none_label_for_statistics()` function to get rid of this message."
- )
-
- return self._statistics_none_label
-
- def reset_statistics(self):
- self._collected_relations: Dict[str, List[Annotation]] = defaultdict(list)
-
- def collect_relation(self, kind: str, relation: Annotation):
- if self.collect_statistics:
- self._collected_relations[kind].append(relation)
-
- def collect_all_relations(self, kind: str, relations: Iterable[Annotation]):
- if self.collect_statistics:
- self._collected_relations[kind].extend(relations)
-
- def get_statistics(self) -> Dict[Tuple[str, str], int]:
- if self.collect_statistics:
- # create statistics from the collected relations
- statistics: Dict[Tuple[str, str], int] = defaultdict(int)
- all_relations = set(self._collected_relations["available"])
- used_relations = set(self._collected_relations["used"])
- skipped_other = all_relations - used_relations
- for key, rels in self._collected_relations.items():
- rels_set = set(rels)
- if key.startswith("skipped_"):
- skipped_other -= rels_set
- elif key.startswith("used_"):
- pass
- elif key in ["available", "used"]:
- pass
- else:
- raise ValueError(f"unknown key: {key}")
- for rel in rels_set:
- # Set `none_label` as label when the score is zero. We encode negative relations
- # in such a way in the case of multi-label or binary (similarity for coref).
- label = rel.label if rel.score > 0 else self.get_none_label_for_statistics()
- increase_counter(key=(key, label), statistics=statistics)
- for rel in skipped_other:
- increase_counter(key=("skipped_other", rel.label), statistics=statistics)
-
- return dict(statistics)
- else:
- return {}
-
- def format_statistics(self, statistics: Dict[Tuple[str, str], int]) -> str:
- if len(statistics) > 0:
- to_show_series = pd.Series(statistics)
- # unstack index to have relation labels as column names
- to_show = to_show_series.unstack()
- else:
- # If there were no statistics, create an empty dummy dataframe.
- to_show = pd.DataFrame(pd.Series(dict()))
- # fill missing values with 0 and convert back to int (unstacking may introduce NaNs which are float type)
- to_show = to_show.fillna(0).astype(int)
- if to_show.columns.size > 1:
- to_show["all_relations"] = to_show.loc[
- :, to_show.columns != self.get_none_label_for_statistics()
- ].sum(axis=1)
-
- # transpose
- # to have the labels (which may be a lot) as index for improved readability and
- # to allow to keep counts as int columns (dtypes are per-column, not per-row)
- to_show = to_show.T
- if "used" in to_show.columns and "available" in to_show.columns:
- to_show["used %"] = (100 * to_show["used"] / to_show["available"]).round()
-
- return to_show.to_markdown()
diff --git a/src/pie_modules/taskmodules/common/taskmodule_with_document_converter.py b/src/pie_modules/taskmodules/common/taskmodule_with_document_converter.py
deleted file mode 100644
index 69f8e7668..000000000
--- a/src/pie_modules/taskmodules/common/taskmodule_with_document_converter.py
+++ /dev/null
@@ -1,117 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Generic, Iterable, Iterator, Optional, Sequence, Type, TypeVar, Union
-
-from pie_core import (
- Document,
- IterableTaskEncodingDataset,
- TaskEncoding,
- TaskEncodingDataset,
- TaskEncodingSequence,
- TaskModule,
-)
-from typing_extensions import TypeAlias
-
-DocumentType = TypeVar("DocumentType", bound=Document)
-ConvertedDocumentType = TypeVar("ConvertedDocumentType", bound=Document)
-InputEncodingType = TypeVar("InputEncodingType")
-TargetEncodingType = TypeVar("TargetEncodingType")
-# TaskEncoding: defined below
-TaskBatchEncodingType = TypeVar("TaskBatchEncodingType")
-# ModelBatchEncoding: defined in models
-ModelBatchOutputType = TypeVar("ModelBatchOutputType")
-TaskOutputType = TypeVar("TaskOutputType")
-
-TaskEncodingType: TypeAlias = TaskEncoding[
- DocumentType,
- InputEncodingType,
- TargetEncodingType,
-]
-
-
-class TaskModuleWithDocumentConverter(
- TaskModule,
- ABC,
- Generic[
- ConvertedDocumentType,
- DocumentType,
- InputEncodingType,
- TargetEncodingType,
- TaskBatchEncodingType,
- ModelBatchOutputType,
- TaskOutputType,
- ],
-):
- @property
- def document_type(self) -> Optional[Type[Document]]:
- if super().document_type is not None:
- raise NotImplementedError(f"please overwrite document_type for {type(self).__name__}")
- else:
- return None
-
- @abstractmethod
- def _convert_document(self, document: DocumentType) -> ConvertedDocumentType:
- """Convert a document of the taskmodule document type to the expected document type of the
- wrapped taskmodule.
-
- Args:
- document: the input document
-
- Returns: the converted document
- """
- pass
-
- def _prepare(self, documents: Sequence[DocumentType]) -> None:
- # use an iterator for lazy processing
- documents_converted = (self._convert_document(doc) for doc in documents)
- super()._prepare(documents=documents_converted)
-
- def convert_document(self, document: DocumentType) -> ConvertedDocumentType:
- converted_document = self._convert_document(document)
- if "original_document" in converted_document.metadata:
- raise ValueError(
- f"metadata of converted_document has already and entry 'original_document', "
- f"this is not allowed. Please adjust '{type(self).__name__}._convert_document()' "
- f"to produce documents without that key in metadata."
- )
- converted_document.metadata["original_document"] = document
- return converted_document
-
- def encode(self, documents: Union[DocumentType, Iterable[DocumentType]], **kwargs) -> Union[
- Sequence[TaskEncodingType],
- TaskEncodingSequence[TaskEncodingType, DocumentType],
- Iterator[TaskEncodingType],
- TaskEncodingDataset[TaskEncodingType],
- IterableTaskEncodingDataset[TaskEncodingType],
- ]:
- converted_documents: Union[DocumentType, Iterable[DocumentType]]
- if isinstance(documents, Document):
- converted_documents = self.convert_document(documents)
- else:
- converted_documents = [self.convert_document(doc) for doc in documents]
- return super().encode(documents=converted_documents, **kwargs)
-
- def decode(self, **kwargs) -> Sequence[DocumentType]:
- decoded_documents = super().decode(**kwargs)
- result = []
- for doc in decoded_documents:
- original_document = doc.metadata["original_document"]
- self._integrate_predictions_from_converted_document(
- converted_document=doc, document=original_document
- )
- result.append(original_document)
- return result
-
- @abstractmethod
- def _integrate_predictions_from_converted_document(
- self,
- document: DocumentType,
- converted_document: ConvertedDocumentType,
- ) -> None:
- """Convert the predictions at the respective layers of the converted_document and add them
- to the original document predictions.
-
- Args:
- document: document to attach the converted predictions to
- converted_document: the document returned by the wrapped taskmodule, including predictions
- """
- pass
diff --git a/src/pie_modules/taskmodules/common/utils.py b/src/pie_modules/taskmodules/common/utils.py
deleted file mode 100644
index cc4daf626..000000000
--- a/src/pie_modules/taskmodules/common/utils.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import logging
-from typing import Union
-
-import torch
-
-logger = logging.getLogger(__name__)
-
-
-def get_first_occurrence_index(
- tensor: Union[torch.FloatTensor, torch.LongTensor], value: Union[float, int]
-) -> torch.LongTensor:
- """Returns the index of the first occurrence of `value` in each row of `tensor`. If `value` is
- not found, seq_len is returned.
-
- Args:
- tensor: the tensor of shape (bsz, seq_len) to search in
- value: the value to search for
-
- Returns: a tensor of shape (bsz,) containing the index of the first occurrence of `value` in each row of `tensor`.
- """
-
- mask_value = tensor.eq(value)
- # count matching positions from the end
- value_counts_to_end = mask_value.flip(dims=[1]).cumsum(dim=1).flip(dims=[1])
- # at the first position stands the number of total matches
- total_matches = value_counts_to_end[:, 0]
- # the sum of all positions where the number of matches is equal to the total number of matches
- # is the index *after* the first occurrence
- result = value_counts_to_end.eq(total_matches.unsqueeze(-1)).sum(dim=1) - 1
- # set result to seq_len if no match was found
- result[total_matches == 0] = tensor.size(1)
- return result
diff --git a/src/pie_modules/taskmodules/cross_text_binary_coref.py b/src/pie_modules/taskmodules/cross_text_binary_coref.py
index cff2fa139..e69de29bb 100644
--- a/src/pie_modules/taskmodules/cross_text_binary_coref.py
+++ b/src/pie_modules/taskmodules/cross_text_binary_coref.py
@@ -1,292 +0,0 @@
-import copy
-import logging
-from typing import (
- Any,
- Dict,
- Iterable,
- Iterator,
- List,
- Optional,
- Sequence,
- Tuple,
- TypedDict,
- TypeVar,
- Union,
-)
-
-import torch
-from pie_core import Annotation, TaskEncoding, TaskModule
-from pie_core.utils.dictionary import list_of_dicts2dict_of_lists
-from pytorch_ie.utils.window import get_window_around_slice
-from torchmetrics import MetricCollection
-from torchmetrics.classification import (
- BinaryAUROC,
- BinaryAveragePrecision,
- BinaryF1Score,
-)
-from transformers import AutoTokenizer, BatchEncoding
-from typing_extensions import TypeAlias
-
-from pie_modules.annotations import Span
-from pie_modules.documents import (
- TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
-)
-from pie_modules.taskmodules.common.mixins import RelationStatisticsMixin
-from pie_modules.taskmodules.metrics import WrappedMetricWithPrepareFunction
-from pie_modules.utils.tokenization import (
- SpanNotAlignedWithTokenException,
- get_aligned_token_span,
-)
-
-logger = logging.getLogger(__name__)
-
-InputEncodingType: TypeAlias = Dict[str, Any]
-TargetEncodingType: TypeAlias = Sequence[float]
-DocumentType: TypeAlias = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
-
-TaskEncodingType: TypeAlias = TaskEncoding[
- DocumentType,
- InputEncodingType,
- TargetEncodingType,
-]
-
-
-class TaskOutputType(TypedDict, total=False):
- score: float
- is_similar: bool
-
-
-ModelInputType: TypeAlias = Dict[str, torch.Tensor]
-ModelTargetType: TypeAlias = Dict[str, torch.Tensor]
-ModelOutputType: TypeAlias = Dict[str, torch.Tensor]
-
-TaskModuleType: TypeAlias = TaskModule[
- # _InputEncoding, _TargetEncoding, _TaskBatchEncoding, _ModelBatchOutput, _TaskOutput
- DocumentType,
- InputEncodingType,
- TargetEncodingType,
- Tuple[ModelInputType, Optional[ModelTargetType]],
- ModelTargetType,
- TaskOutputType,
-]
-
-
-class SpanDoesNotFitIntoAvailableWindow(Exception):
- def __init__(self, span):
- self.span = span
-
-
-def _get_labels(model_output: ModelTargetType, label_threshold: float) -> torch.Tensor:
- return (model_output["scores"] > label_threshold).to(torch.int)
-
-
-def _get_scores(model_output: ModelTargetType) -> torch.Tensor:
- return model_output["scores"]
-
-
-S = TypeVar("S", bound=Span)
-
-
-def shift_span(span: S, offset: int) -> S:
- return span.copy(start=span.start + offset, end=span.end + offset)
-
-
-@TaskModule.register()
-class CrossTextBinaryCorefTaskModule(RelationStatisticsMixin, TaskModuleType):
- """This taskmodule processes documents of type
- TextPairDocumentWithLabeledSpansAndBinaryCorefRelations in preparation for a
- SequencePairSimilarityModelWithPooler."""
-
- DOCUMENT_TYPE = DocumentType
-
- def __init__(
- self,
- tokenizer_name_or_path: str,
- similarity_threshold: float = 0.9,
- max_window: Optional[int] = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.save_hyperparameters()
-
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
- self.similarity_threshold = similarity_threshold
- self.max_window = max_window if max_window is not None else self.tokenizer.model_max_length
- self.available_window = self.max_window - self.tokenizer.num_special_tokens_to_add()
- self.num_special_tokens_before = len(self._get_special_tokens_before_input())
-
- def _get_special_tokens_before_input(self) -> List[int]:
- dummy_ids = self.tokenizer.build_inputs_with_special_tokens(token_ids_0=[-1])
- return dummy_ids[: dummy_ids.index(-1)]
-
- def encode(self, documents: Union[DocumentType, Iterable[DocumentType]], **kwargs):
- self.reset_statistics()
- result = super().encode(documents=documents, **kwargs)
- self.show_statistics()
- return result
-
- def truncate_encoding_around_span(
- self, encoding: BatchEncoding, char_span: Span
- ) -> Tuple[Dict[str, List[int]], Span]:
- input_ids = copy.deepcopy(encoding["input_ids"])
-
- token_span = get_aligned_token_span(encoding=encoding, char_span=char_span)
-
- # truncate input_ids and shift token_start and token_end
- if len(input_ids) > self.available_window:
- window_slice = get_window_around_slice(
- slice=(token_span.start, token_span.end),
- max_window_size=self.available_window,
- available_input_length=len(input_ids),
- )
- if window_slice is None:
- raise SpanDoesNotFitIntoAvailableWindow(span=token_span)
- window_start, window_end = window_slice
- input_ids = input_ids[window_start:window_end]
- token_span = shift_span(token_span, offset=-window_start)
-
- truncated_encoding = self.tokenizer.prepare_for_model(ids=input_ids)
- # shift indices because we added special tokens to the input_ids
- token_span = shift_span(token_span, offset=self.num_special_tokens_before)
-
- return truncated_encoding, token_span
-
- def encode_input(
- self,
- document: DocumentType,
- is_training: bool = False,
- ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
- self.collect_all_relations(kind="available", relations=document.binary_coref_relations)
- tokenizer_kwargs = dict(
- padding=False,
- truncation=False,
- add_special_tokens=False,
- )
- encoding = self.tokenizer(text=document.text, **tokenizer_kwargs)
- encoding_pair = self.tokenizer(text=document.text_pair, **tokenizer_kwargs)
-
- task_encodings = []
- for coref_rel in document.binary_coref_relations:
- # TODO: This can miss instances if both texts are the same. We could check that
- # coref_rel.head is in document.labeled_spans (same for the tail), but would this
- # slow down the encoding?
- if not (
- coref_rel.head.target == document.text
- or coref_rel.tail.target == document.text_pair
- ):
- raise ValueError(
- f"It is expected that coref relations go from (head) spans over 'text' "
- f"to (tail) spans over 'text_pair', but this is not the case for this "
- f"relation (i.e. it points into the other direction): {coref_rel.resolve()}"
- )
- try:
- current_encoding, token_span = self.truncate_encoding_around_span(
- encoding=encoding, char_span=coref_rel.head
- )
- current_encoding_pair, token_span_pair = self.truncate_encoding_around_span(
- encoding=encoding_pair, char_span=coref_rel.tail
- )
- except SpanNotAlignedWithTokenException as e:
- logger.warning(
- f"Could not get token offsets for argument ({e.span}) of coref relation: "
- f"{coref_rel.resolve()}. Skip it."
- )
- self.collect_relation(kind="skipped_args_not_aligned", relation=coref_rel)
- continue
- except SpanDoesNotFitIntoAvailableWindow as e:
- logger.warning(
- f"Argument span [{e.span}] does not fit into available token window "
- f"({self.available_window}). Skip it."
- )
- self.collect_relation(
- kind="skipped_span_does_not_fit_into_window", relation=coref_rel
- )
- continue
-
- task_encodings.append(
- TaskEncoding(
- document=document,
- inputs={
- "encoding": current_encoding,
- "encoding_pair": current_encoding_pair,
- "pooler_start_indices": token_span.start,
- "pooler_end_indices": token_span.end,
- "pooler_pair_start_indices": token_span_pair.start,
- "pooler_pair_end_indices": token_span_pair.end,
- },
- metadata={"candidate_annotation": coref_rel},
- )
- )
- self.collect_relation("used", coref_rel)
- return task_encodings
-
- def encode_target(
- self,
- task_encoding: TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType],
- ) -> Optional[TargetEncodingType]:
- return task_encoding.metadata["candidate_annotation"].score
-
- def collate(
- self,
- task_encodings: Sequence[
- TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType]
- ],
- ) -> Tuple[ModelInputType, Optional[ModelTargetType]]:
- inputs_dict = list_of_dicts2dict_of_lists(
- [task_encoding.inputs for task_encoding in task_encodings]
- )
-
- inputs = {
- k: (
- self.tokenizer.pad(v, return_tensors="pt").data
- if k in ["encoding", "encoding_pair"]
- else torch.tensor(v)
- )
- for k, v in inputs_dict.items()
- }
- for k, v in inputs.items():
- if k.startswith("pooler_") and k.endswith("_indices"):
- inputs[k] = v.unsqueeze(-1)
-
- if not task_encodings[0].has_targets:
- return inputs, None
- targets = {
- "scores": torch.tensor([task_encoding.targets for task_encoding in task_encodings])
- }
- return inputs, targets
-
- def configure_model_metric(self, stage: str) -> MetricCollection:
- return MetricCollection(
- metrics={
- "continuous": WrappedMetricWithPrepareFunction(
- metric=MetricCollection(
- {
- "auroc": BinaryAUROC(),
- "avg-P": BinaryAveragePrecision(validate_args=False),
- # "roc": BinaryROC(validate_args=False),
- # "PRCurve": BinaryPrecisionRecallCurve(validate_args=False),
- "f1": BinaryF1Score(threshold=self.similarity_threshold),
- }
- ),
- prepare_function=_get_scores,
- ),
- }
- )
-
- def unbatch_output(self, model_output: ModelTargetType) -> Sequence[TaskOutputType]:
- is_similar = (model_output["scores"] > self.similarity_threshold).detach().cpu().tolist()
- scores = model_output["scores"].detach().cpu().tolist()
- result: List[TaskOutputType] = [
- {"is_similar": is_sim, "score": prob} for is_sim, prob in zip(is_similar, scores)
- ]
- return result
-
- def create_annotations_from_output(
- self,
- task_encoding: TaskEncoding[DocumentType, InputEncodingType, TargetEncodingType],
- task_output: TaskOutputType,
- ) -> Iterator[Tuple[str, Annotation]]:
- if task_output["is_similar"]:
- score = task_output["score"]
- new_coref_rel = task_encoding.metadata["candidate_annotation"].copy(score=score)
- yield "binary_coref_relations", new_coref_rel
diff --git a/src/pie_modules/taskmodules/extractive_question_answering.py b/src/pie_modules/taskmodules/extractive_question_answering.py
deleted file mode 100644
index d774cf18b..000000000
--- a/src/pie_modules/taskmodules/extractive_question_answering.py
+++ /dev/null
@@ -1,239 +0,0 @@
-import dataclasses
-import logging
-from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
-
-import numpy as np
-import torch
-from pie_core import Annotation, AnnotationLayer, TaskEncoding, TaskModule
-from tokenizers import Encoding
-from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizer
-from transformers.modeling_outputs import QuestionAnsweringModelOutput
-from typing_extensions import TypeAlias
-
-from pie_modules.annotations import ExtractiveAnswer, Question
-from pie_modules.document.processing import tokenize_document
-from pie_modules.documents import (
- TextBasedDocument,
- TextDocumentWithQuestionsAndExtractiveAnswers,
- TokenDocumentWithQuestionsAndExtractiveAnswers,
-)
-
-logger = logging.getLogger(__name__)
-
-
-DocumentType: TypeAlias = TextBasedDocument
-InputEncoding: TypeAlias = Union[Dict[str, Any], BatchEncoding]
-
-
-@dataclasses.dataclass
-class TargetEncoding:
- start_position: int
- end_position: int
-
-
-TaskEncodingType: TypeAlias = TaskEncoding[
- TextDocumentWithQuestionsAndExtractiveAnswers,
- InputEncoding,
- TargetEncoding,
-]
-
-TaskBatchEncoding: TypeAlias = Tuple[BatchEncoding, Optional[Dict[str, Any]]]
-ModelBatchOutput: TypeAlias = QuestionAnsweringModelOutput
-
-
-@dataclasses.dataclass
-class TaskOutput:
- start: int
- end: int
- start_probability: float
- end_probability: float
-
-
-@TaskModule.register()
-class ExtractiveQuestionAnsweringTaskModule(TaskModule):
- """PIE task module for extractive question answering.
-
- This task module expects that the document is text based and contains an annotation layer for answers
- and one for questions.
-
- The task module will create a task encoding for each question-answer pair.
- The input encoding will be the tokenized document with the question as the second sequence.
- The target encoding will be the start and end position of the answer in the context.
- The task module will create a dummy target encoding where both start and end index are set to 0 (usually
- the CLS token position), if there is no answer for the question.
-
- Args:
- tokenizer_name_or_path: The name (Huggingface Hub identifier) or local path to a config of the tokenizer to use.
- max_length: The maximum length of the input sequence in means of tokens.
- answer_annotation: The name of the annotation layer for answers. Defaults to "answers".
- question_annotation: The name of the annotation layer for questions. Defaults to "questions".
- tokenize_kwargs: Additional keyword arguments for the tokenizer. Defaults to None.
- """
-
- DOCUMENT_TYPE = TextDocumentWithQuestionsAndExtractiveAnswers
-
- def __init__(
- self,
- tokenizer_name_or_path: str,
- max_length: int,
- answer_annotation: str = "answers",
- question_annotation: str = "questions",
- tokenize_kwargs: Optional[Dict[str, Any]] = None,
- **kwargs,
- ):
- super().__init__(**kwargs)
- self.save_hyperparameters()
-
- self.answer_annotation = answer_annotation
- self.question_annotation = question_annotation
- self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
- self.max_length = max_length
- self.tokenize_kwargs = tokenize_kwargs or {}
-
- def get_answer_layer(self, document: DocumentType) -> AnnotationLayer[ExtractiveAnswer]:
- # we expect that each document have an annotation layer for answers
- # where each entry is of type ExtractiveAnswer
- return document[self.answer_annotation]
-
- def get_question_layer(self, document: DocumentType) -> AnnotationLayer[Question]:
- answers = self.get_answer_layer(document)
- # we expect that the answers annotation layer targets the questions annotation layer
- # where each entry is of type Question
- return answers.target_layers[self.question_annotation]
-
- def get_context(self, document: DocumentType) -> str:
- answers = self.get_answer_layer(document)
- # we expect that the answers annotation layer targets the text field
- # which is a simple string
- return answers.targets["text"]
-
- def encode_input(
- self,
- document: DocumentType,
- is_training: bool = False,
- ) -> Optional[
- Union[
- TaskEncoding[DocumentType, InputEncoding, TargetEncoding],
- Sequence[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]],
- ]
- ]:
- questions = self.get_question_layer(document)
- task_encodings: List[TaskEncodingType] = []
- for question in questions:
- tokenized_docs = tokenize_document(
- document,
- tokenizer=self.tokenizer,
- text=question.text.strip(),
- truncation="only_second",
- max_length=self.max_length,
- return_overflowing_tokens=True,
- result_document_type=TokenDocumentWithQuestionsAndExtractiveAnswers,
- strict_span_conversion=False,
- verbose=False,
- **self.tokenize_kwargs,
- )
- for doc in tokenized_docs:
- inputs = self.tokenizer.convert_tokens_to_ids(list(doc.tokens))
- task_encodings.append(
- TaskEncodingType(
- document=document,
- inputs=inputs,
- metadata=dict(question=question, tokenized_document=doc),
- )
- )
- return task_encodings
-
- def encode_target(
- self,
- task_encoding: TaskEncodingType,
- ) -> Optional[TargetEncoding]:
- all_answers = self.get_answer_layer(task_encoding.metadata["tokenized_document"])
- # the document can contain multiple questions, so we filter the answers by the target question
- answers = [
- answer
- for answer in all_answers
- if answer.question == task_encoding.metadata["question"]
- ]
- # if there is no answer for the target question, we return a dummy target encoding
- if len(answers) == 0:
- return TargetEncoding(0, 0)
- if len(answers) > 1:
- logger.warning(
- f"The answers annotation layer is expected to have not more than one answer per question, "
- f"but it has {len(answers)} answers. We take just the first one."
- )
- answer = answers[0]
- return TargetEncoding(answer.start, answer.end - 1)
-
- def collate(
- self, task_encodings: Sequence[TaskEncoding[DocumentType, InputEncoding, TargetEncoding]]
- ) -> TaskBatchEncoding:
- def task_encoding2input_features(task_encoding: TaskEncodingType) -> Dict[str, Any]:
- encoding = task_encoding.metadata["tokenized_document"].metadata["tokenizer_encoding"]
- return {"input_ids": encoding.ids, "token_type_ids": encoding.type_ids}
-
- input_features = [
- task_encoding2input_features(task_encoding) for task_encoding in task_encodings
- ]
-
- # will contain: input_ids, token_type_ids, attention_mask
- inputs: BatchEncoding = self.tokenizer.pad(
- input_features, padding="longest", max_length=self.max_length, return_tensors="pt"
- )
-
- if not task_encodings[0].has_targets:
- return inputs, None
-
- start_positions = torch.tensor(
- [task_encoding.targets.start_position for task_encoding in task_encodings],
- dtype=torch.int64,
- )
- end_positions = torch.tensor(
- [task_encoding.targets.end_position for task_encoding in task_encodings],
- dtype=torch.int64,
- )
- targets = {"start_positions": start_positions, "end_positions": end_positions}
-
- return inputs, targets
-
- def unbatch_output(self, model_output: ModelBatchOutput) -> Sequence[TaskOutput]:
- batch_size = len(model_output.start_logits)
- start_probs = torch.softmax(model_output.start_logits, dim=-1).detach().cpu().numpy()
- end_probs = torch.softmax(model_output.end_logits, dim=-1).detach().cpu().numpy()
- best_start = np.argmax(start_probs, axis=1)
- best_end = np.argmax(end_probs, axis=1)
- return [
- TaskOutput(
- start=best_start[i],
- end=best_end[i],
- start_probability=start_probs[i, best_start[i]],
- end_probability=end_probs[i, best_end[i]],
- )
- for i in range(batch_size)
- ]
-
- def create_annotations_from_output(
- self,
- task_encoding: TaskEncoding[DocumentType, InputEncoding, TargetEncoding],
- task_output: TaskOutput,
- ) -> Iterator[Tuple[str, Annotation]]:
- tokenizer_encoding: Encoding = task_encoding.metadata["tokenized_document"].metadata[
- "tokenizer_encoding"
- ]
- start_chars = tokenizer_encoding.token_to_chars(task_output.start)
- end_chars = tokenizer_encoding.token_to_chars(task_output.end)
- if start_chars is not None and end_chars is not None:
- start_sequence_index = tokenizer_encoding.token_to_sequence(task_output.start)
- end_sequence_index = tokenizer_encoding.token_to_sequence(task_output.end)
- # the indices need to point into the context which is the second sequence
- if start_sequence_index == 1 and end_sequence_index == 1:
- start_char = start_chars[0]
- end_char = end_chars[-1]
- context = self.get_context(task_encoding.document)
- if 0 <= start_char < end_char <= len(context):
- yield self.answer_annotation, ExtractiveAnswer(
- start=start_char,
- end=end_char,
- question=task_encoding.metadata["question"],
- score=float(task_output.start_probability * task_output.end_probability),
- )
diff --git a/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py b/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py
deleted file mode 100644
index ca566a024..000000000
--- a/src/pie_modules/taskmodules/labeled_span_extraction_by_token_classification.py
+++ /dev/null
@@ -1,468 +0,0 @@
-"""
-workflow:
- Document
- -> (InputEncoding, TargetEncoding) -> TaskEncoding
- -> ModelStepInputType -> ModelBatchOutput
- -> TaskOutput
- -> Document
-"""
-
-import logging
-from functools import partial
-from typing import (
- Any,
- Dict,
- Iterator,
- List,
- Optional,
- Sequence,
- Set,
- Tuple,
- Type,
- TypedDict,
- Union,
-)
-
-import torch
-from pie_core import AnnotationLayer, TaskEncoding, TaskModule
-from pie_core.utils.dictionary import list_of_dicts2dict_of_lists
-from tokenizers import Encoding
-from torchmetrics import F1Score, Metric, MetricCollection, Precision, Recall
-from transformers import AutoTokenizer
-from typing_extensions import TypeAlias
-
-from pie_modules.annotations import LabeledSpan
-from pie_modules.document.processing import (
- token_based_document_to_text_based,
- tokenize_document,
-)
-from pie_modules.documents import (
- TextBasedDocument,
- TextDocumentWithLabeledSpans,
- TextDocumentWithLabeledSpansAndLabeledPartitions,
- TokenDocumentWithLabeledSpans,
- TokenDocumentWithLabeledSpansAndLabeledPartitions,
-)
-from pie_modules.models.simple_token_classification import InputType as ModelInputType
-from pie_modules.models.simple_token_classification import TargetType as ModelTargetType
-from pie_modules.taskmodules.metrics import (
- PrecisionRecallAndF1ForLabeledAnnotations,
- WrappedMetricWithPrepareFunction,
-)
-from pie_modules.utils.sequence_tagging import tag_sequence_to_token_spans
-
-DocumentType: TypeAlias = TextBasedDocument
-
-InputEncodingType: TypeAlias = Encoding
-TargetEncodingType: TypeAlias = Sequence[int]
-TaskEncodingType: TypeAlias = TaskEncoding[
- DocumentType,
- InputEncodingType,
- TargetEncodingType,
-]
-ModelStepInputType: TypeAlias = Tuple[
- ModelInputType,
- Optional[ModelTargetType],
-]
-ModelOutputType: TypeAlias = ModelTargetType
-
-
-class TaskOutputType(TypedDict, total=False):
- labels: torch.LongTensor
- probabilities: torch.FloatTensor
-
-
-TaskModuleType: TypeAlias = TaskModule[
- DocumentType,
- InputEncodingType,
- TargetEncodingType,
- ModelStepInputType,
- ModelOutputType,
- TaskOutputType,
-]
-
-logger = logging.getLogger(__name__)
-
-
-def _get_label_ids_from_model_output(
- model_output: ModelTargetType,
-) -> torch.LongTensor:
- return model_output["labels"]
-
-
-def unbatch_and_decode_annotations(
- model_output: ModelOutputType,
- taskmodule: "LabeledSpanExtractionByTokenClassificationTaskModule",
-) -> List[Sequence[LabeledSpan]]:
- task_outputs = taskmodule.unbatch_output(model_output)
- annotations = [
- taskmodule.decode_annotations(task_output)["labeled_spans"] for task_output in task_outputs
- ]
- return annotations
-
-
-@TaskModule.register()
-class LabeledSpanExtractionByTokenClassificationTaskModule(TaskModuleType):
- """Taskmodule for span prediction (e.g. NER) as token classification.
-
- This taskmodule expects the input documents to be of TextBasedDocument with an annotation layer of
- labeled spans (e.g. TextDocumentWithLabeledSpans). The text is tokenized using the provided tokenizer and
- the labels are converted to BIO tags.
-
- To handle long documents, the text can be windowed using the respective parameters for the tokenizer,
- i.e. max_length (and stride). Note, that this requires to set return_overflowing_tokens=True, otherwise just
- the first window of input tokens is considered. The windowing is done in a way that the spans are not split
- across windows. If a span is split across windows, it is ignored during training and evaluation. Thus, if you
- have long spans in your data, it is recommended to set a stride that is as large as the average span length
- to avoid missing many spans.
-
- If a partition annotation is provided, the taskmodule expects the input documents to be of
- TextBasedDocument with two annotation layers of labeled spans, one for the spans and one for the partitions
- (e.g. TextDocumentWithLabeledSpansAndLabeledPartitions). Then, the text is tokenized and fed to the model
- individually per partition (e.g. per sentence). This is useful for long documents that can not be processed
- by the model as a whole, but where a natural partitioning exists (e.g. sentences or paragraphs) and, thus,
- windowing is not necessary (or a combination of both can be used).
-
- If labels are not provided, they are collected from the data during the prepare() step. If provided, they act as
- whitelist, i.e. spans with labels that are not in the labels are ignored during training and evaluation.
-
- Args:
- tokenizer_name_or_path: Name or path of the HuggingFace tokenizer to use.
- span_annotation: Name of the annotation layer that contains the labeled spans. Default: "labeled_spans".
- partition_annotation: Name of the annotation layer that contains the labeled partitions. If provided, the
- text is tokenized individually per partition. Default: None.
- label_pad_id: ID of the padding tag label. The model should ignore this for training. Default: -100.
- labels: List of labels to use. If not provided, the labels are collected from the labeled span annotations
- in the data during the prepare() step. Default: None.
- include_ill_formed_predictions: Whether to include ill-formed predictions in the output. If False, the
- predictions are corrected to be well-formed. Default: True.
- tokenize_kwargs: Keyword arguments to pass to the tokenizer during tokenization. Default: None.
- pad_kwargs: Keyword arguments to pass to the tokenizer during padding. Note, that this is used to pad the
- token ids *and* the tag ids, if available (i.e. during training or evaluation). Default: None.
- combine_token_scores_method: Method to combine the token scores to a span score. Options are "mean", "max",
- "min", and "product". Default: "mean".
- log_precision_recall_metrics: Whether to log precision and recall metrics (in addition to F1) for the
- spans. Default: True.
- """
-
- # list of attribute names that need to be set by _prepare()
- PREPARED_ATTRIBUTES: List[str] = ["labels"]
-
- def __init__(
- self,
- tokenizer_name_or_path: str,
- span_annotation: str = "labeled_spans",
- partition_annotation: Optional[str] = None,
- label_pad_id: int = -100,
- labels: Optional[List[str]] = None,
- include_ill_formed_predictions: bool = True,
- tokenize_kwargs: Optional[Dict[str, Any]] = None,
- pad_kwargs: Optional[Dict[str, Any]] = None,
- combine_token_scores_method: str = "mean",
- log_precision_recall_metrics: bool = True,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.save_hyperparameters()
-
- self.span_annotation = span_annotation
- self.partition_annotation = partition_annotation
- self.labels = labels
- self.label_pad_id = label_pad_id
- self.include_ill_formed_predictions = include_ill_formed_predictions
- self.tokenize_kwargs = tokenize_kwargs or {}
- self.pad_kwargs = pad_kwargs or {}
- self.log_precision_recall_metrics = log_precision_recall_metrics
- self.combine_token_scores_method = combine_token_scores_method
-
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
-
- @property
- def document_type(self) -> Optional[Type[TextBasedDocument]]:
- dt: Type[TextBasedDocument]
- errors = []
- if self.span_annotation != "labeled_spans":
- errors.append(
- f"span_annotation={self.span_annotation} is not the default value ('labeled_spans')"
- )
- if self.partition_annotation is None:
- dt = TextDocumentWithLabeledSpans
- else:
- if self.partition_annotation != "labeled_partitions":
- errors.append(
- f"partition_annotation={self.partition_annotation} is not the default value "
- f"('labeled_partitions')"
- )
- dt = TextDocumentWithLabeledSpansAndLabeledPartitions
-
- if len(errors) == 0:
- return dt
- else:
- logger.warning(
- f"{' and '.join(errors)}, so the taskmodule {type(self).__name__} can not request "
- f"the usual document type ({dt.__name__}) for auto-conversion because this has the bespoken default "
- f"value as layer name(s) instead of the provided one(s)."
- )
- return None
-
- def get_span_layer(self, document: DocumentType) -> AnnotationLayer[LabeledSpan]:
- return document[self.span_annotation]
-
- def _prepare(self, documents: Sequence[DocumentType]) -> None:
- # collect all possible labels
- labels: Set[str] = set()
- for document in documents:
- spans: AnnotationLayer[LabeledSpan] = self.get_span_layer(document)
-
- for span in spans:
- labels.add(span.label)
-
- self.labels = sorted(labels)
- logger.info(f"Collected {len(self.labels)} labels from the data: {self.labels}")
-
- def _post_prepare(self):
- # create the real token labels (BIO scheme) from the labels
- self.label_to_id = {"O": 0}
- current_id = 1
- for label in sorted(self.labels):
- for prefix in ["B", "I"]:
- self.label_to_id[f"{prefix}-{label}"] = current_id
- current_id += 1
-
- self.id_to_label = {v: k for k, v in self.label_to_id.items()}
-
- def encode_input(
- self,
- document: TextBasedDocument,
- ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
- if self.partition_annotation is None:
- tokenized_document_type = TokenDocumentWithLabeledSpans
- casted_document_type = TextDocumentWithLabeledSpans
- field_mapping = {self.span_annotation: "labeled_spans"}
- else:
- tokenized_document_type = TokenDocumentWithLabeledSpansAndLabeledPartitions
- casted_document_type = TextDocumentWithLabeledSpansAndLabeledPartitions
- field_mapping = {
- self.span_annotation: "labeled_spans",
- self.partition_annotation: "labeled_partitions",
- }
- casted_document = document.as_type(casted_document_type, field_mapping=field_mapping)
- tokenized_docs = tokenize_document(
- casted_document,
- tokenizer=self.tokenizer,
- result_document_type=tokenized_document_type,
- partition_layer=(
- "labeled_partitions" if self.partition_annotation is not None else None
- ),
- strict_span_conversion=False,
- **self.tokenize_kwargs,
- )
-
- task_encodings: List[TaskEncodingType] = []
- for tokenized_doc in tokenized_docs:
- task_encodings.append(
- TaskEncoding(
- document=document,
- inputs=tokenized_doc.metadata["tokenizer_encoding"],
- metadata={"tokenized_document": tokenized_doc},
- )
- )
-
- return task_encodings
-
- def encode_target(
- self,
- task_encoding: TaskEncodingType,
- ) -> Optional[TargetEncodingType]:
- metadata = task_encoding.metadata
- tokenized_document = metadata["tokenized_document"]
- tokenizer_encoding: Encoding = tokenized_document.metadata["tokenizer_encoding"]
-
- tag_sequence = [
- None if tokenizer_encoding.special_tokens_mask[j] else "O"
- for j in range(len(tokenizer_encoding.ids))
- ]
- if self.labels is None:
- raise ValueError(
- "'labels' must be set before calling encode_target(). Was prepare() called on the taskmodule?"
- )
- sorted_spans = sorted(tokenized_document.labeled_spans, key=lambda s: (s.start, s.end))
- for span in sorted_spans:
- if span.label not in self.labels:
- continue
- start = span.start
- end = span.end
- if any(tag != "O" for tag in tag_sequence[start:end]):
- logger.warning(f"tag already assigned (current span has an overlap: {span}).")
- continue
-
- tag_sequence[start] = f"B-{span.label}"
- for j in range(start + 1, end):
- tag_sequence[j] = f"I-{span.label}"
-
- targets = [
- self.label_to_id[tag] if tag is not None else self.label_pad_id for tag in tag_sequence
- ]
-
- return targets
-
- def collate(self, task_encodings: Sequence[TaskEncodingType]) -> ModelStepInputType:
- input_encodings = [
- {
- "input_ids": task_encoding.inputs.ids,
- "attention_mask": task_encoding.inputs.attention_mask,
- "special_tokens_mask": task_encoding.inputs.special_tokens_mask,
- }
- for task_encoding in task_encodings
- ]
- inputs = self.tokenizer.pad(
- list_of_dicts2dict_of_lists(input_encodings), return_tensors="pt", **self.pad_kwargs
- )
-
- if not task_encodings[0].has_targets:
- return inputs, None
-
- tag_ids = [task_encoding.targets for task_encoding in task_encodings]
- targets = self.tokenizer.pad(
- {"input_ids": tag_ids}, return_tensors="pt", **self.pad_kwargs
- )["input_ids"]
-
- # set the padding label to the label_pad_token_id
- pad_mask = inputs["input_ids"] == self.tokenizer.pad_token_id
- targets[pad_mask] = self.label_pad_id
-
- return inputs, {"labels": targets}
-
- def unbatch_output(self, model_output: ModelOutputType) -> Sequence[TaskOutputType]:
- labels = model_output["labels"]
- probabilities = model_output.get("probabilities", None)
- batch_size = labels.shape[0]
- task_outputs: List[TaskOutputType] = []
- for batch_idx in range(batch_size):
- task_output: TaskOutputType = {"labels": labels[batch_idx]}
- if probabilities is not None:
- task_output["probabilities"] = probabilities[batch_idx]
- task_outputs.append(task_output)
- return task_outputs
-
- def decode_annotations(self, encoding: TaskOutputType) -> Dict[str, Sequence[LabeledSpan]]:
- labels = encoding["labels"]
- tag_sequence = [
- "O" if tag_id == self.label_pad_id else self.id_to_label[tag_id]
- for tag_id in labels.tolist()
- ]
- labeled_spans: List[LabeledSpan] = []
- for label, (start, end_inclusive) in tag_sequence_to_token_spans(
- tag_sequence,
- coding_scheme="IOB2",
- include_ill_formed=self.include_ill_formed_predictions,
- ):
- end = end_inclusive + 1
- # do not set the score if the probabilities are not available
- annotation_kwargs = {}
- if encoding.get("probabilities") is not None:
- span_probabilities = encoding["probabilities"][start:end]
- span_label_ids = labels[start:end]
- # get the probabilities at the label indices
- span_label_probs = torch.stack(
- [span_probabilities[i, l] for i, l in enumerate(span_label_ids)]
- )
- if self.combine_token_scores_method == "mean":
- # use mean probability of the span as score
- annotation_kwargs["score"] = span_label_probs.mean().item()
- elif self.combine_token_scores_method == "max":
- # use max probability of the span as score
- annotation_kwargs["score"] = span_label_probs.max().item()
- elif self.combine_token_scores_method == "min":
- # use min probability of the span as score
- annotation_kwargs["score"] = span_label_probs.min().item()
- elif self.combine_token_scores_method == "product":
- # use product of probabilities of the span as score
- annotation_kwargs["score"] = span_label_probs.prod().item()
- else:
- raise ValueError(
- f"combine_token_scores_method={self.combine_token_scores_method} is not supported."
- )
- labeled_span = LabeledSpan(label=label, start=start, end=end, **annotation_kwargs)
- labeled_spans.append(labeled_span)
- return {"labeled_spans": labeled_spans}
-
- def create_annotations_from_output(
- self,
- task_encoding: TaskEncodingType,
- task_output: TaskOutputType,
- ) -> Iterator[Tuple[str, LabeledSpan]]:
- tokenized_document = task_encoding.metadata["tokenized_document"]
- decoded_annotations = self.decode_annotations(task_output)
-
- # Note: token_based_document_to_text_based() does not yet consider predictions, so we need to clear
- # the main annotations and attach the predictions to that
- for layer_name, annotations in decoded_annotations.items():
- tokenized_document[layer_name].clear()
- for annotation in annotations:
- tokenized_document[layer_name].append(annotation)
-
- # we can not use self.document_type here because that may be None if self.span_annotation or
- # self.partition_annotation is not the default value
- document_type = (
- TextDocumentWithLabeledSpansAndLabeledPartitions
- if self.partition_annotation
- else TextDocumentWithLabeledSpans
- )
- untokenized_document: Union[
- TextDocumentWithLabeledSpans, TextDocumentWithLabeledSpansAndLabeledPartitions
- ] = token_based_document_to_text_based(
- tokenized_document, result_document_type=document_type
- )
-
- for span in untokenized_document.labeled_spans:
- # need to copy the span because it can be attached to only one document
- yield self.span_annotation, span.copy()
-
- def configure_model_metric(self, stage: str) -> Union[Metric, MetricCollection]:
- common_metric_kwargs = {
- "num_classes": len(self.label_to_id),
- "task": "multiclass",
- "ignore_index": self.label_pad_id,
- }
- token_scores = MetricCollection(
- {
- "token/macro/f1": WrappedMetricWithPrepareFunction(
- metric=F1Score(average="macro", **common_metric_kwargs),
- prepare_function=_get_label_ids_from_model_output,
- ),
- "token/micro/f1": WrappedMetricWithPrepareFunction(
- metric=F1Score(average="micro", **common_metric_kwargs),
- prepare_function=_get_label_ids_from_model_output,
- ),
- "token/macro/precision": WrappedMetricWithPrepareFunction(
- metric=Precision(average="macro", **common_metric_kwargs),
- prepare_function=_get_label_ids_from_model_output,
- ),
- "token/macro/recall": WrappedMetricWithPrepareFunction(
- metric=Recall(average="macro", **common_metric_kwargs),
- prepare_function=_get_label_ids_from_model_output,
- ),
- "token/micro/precision": WrappedMetricWithPrepareFunction(
- metric=Precision(average="micro", **common_metric_kwargs),
- prepare_function=_get_label_ids_from_model_output,
- ),
- "token/micro/recall": WrappedMetricWithPrepareFunction(
- metric=Recall(average="micro", **common_metric_kwargs),
- prepare_function=_get_label_ids_from_model_output,
- ),
- }
- )
-
- span_scores = PrecisionRecallAndF1ForLabeledAnnotations(
- flatten_result_with_sep="/",
- prefix="span/",
- return_recall_and_precision=self.log_precision_recall_metrics,
- )
- span_scores_wrapped = WrappedMetricWithPrepareFunction(
- metric=span_scores,
- prepare_function=partial(unbatch_and_decode_annotations, taskmodule=self),
- prepare_does_unbatch=True,
- )
-
- return MetricCollection([token_scores, span_scores_wrapped])
diff --git a/src/pie_modules/taskmodules/metrics/__init__.py b/src/pie_modules/taskmodules/metrics/__init__.py
deleted file mode 100644
index 8d23c8efb..000000000
--- a/src/pie_modules/taskmodules/metrics/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from .precision_recall_and_f1_for_labeled_annotations import (
- PrecisionRecallAndF1ForLabeledAnnotations,
-)
-from .wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function import (
- WrappedLayerMetricsWithUnbatchAndDecodeWithErrorsFunction,
-)
-from .wrapped_metric_with_prepare_function import WrappedMetricWithPrepareFunction
diff --git a/src/pie_modules/taskmodules/metrics/common.py b/src/pie_modules/taskmodules/metrics/common.py
deleted file mode 100644
index 18c8361eb..000000000
--- a/src/pie_modules/taskmodules/metrics/common.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import logging
-from abc import ABC
-from typing import Dict, Optional
-
-import torch
-from torch import LongTensor, Tensor
-from torchmetrics import Metric
-
-logger = logging.getLogger(__name__)
-
-
-class MetricWithArbitraryCounts(Metric, ABC):
- """A metric that hold counts for arbitrary keys."""
-
- def inc_counts(self, counts: LongTensor, key: Optional[str], prefix: str = "counts_"):
- full_key = prefix
- if key is not None:
- full_key += key
-
- if not hasattr(self, full_key):
- self.add_state(full_key, default=torch.zeros_like(counts), dist_reduce_fx="sum")
-
- prev_value = getattr(self, full_key)
- setattr(self, full_key, prev_value + counts)
-
- def get_counts(self, key_prefix: str = "counts_") -> Dict[Optional[str], LongTensor]:
- result = {}
- for k, v in self.metric_state.items():
- if k.startswith(key_prefix):
- if not isinstance(v, Tensor):
- raise ValueError(
- f"Expected metric state for key {k} to be a LongTensor, but got {type(v)}."
- )
- if not isinstance(v, LongTensor):
- v = v.long()
- key = k[len(key_prefix) :] or None
- result[key] = v
- return result
diff --git a/src/pie_modules/taskmodules/metrics/precision_recall_and_f1_for_labeled_annotations.py b/src/pie_modules/taskmodules/metrics/precision_recall_and_f1_for_labeled_annotations.py
deleted file mode 100644
index 2bb83eb4c..000000000
--- a/src/pie_modules/taskmodules/metrics/precision_recall_and_f1_for_labeled_annotations.py
+++ /dev/null
@@ -1,137 +0,0 @@
-import logging
-from collections import Counter
-from typing import Any, Collection, Dict, Iterable, Optional, Union
-
-import torch
-from pie_core import Annotation
-from pie_core.utils.dictionary import flatten_dict_s
-from torch import LongTensor
-
-from .common import MetricWithArbitraryCounts
-
-logger = logging.getLogger(__name__)
-
-
-class PrecisionRecallAndF1ForLabeledAnnotations(MetricWithArbitraryCounts):
- """Computes precision, recall and F1 for labeled annotations. Inputs and targets are lists of
- annotations. True positives are counted as the number of annotations that are the same in both
- inputs and targets calculated as exact matches via set operation, false positives and false
- negatives accordingly. The annotations are deduplicated for each instance. But if the same
- annotation occurs in different instances, it is counted as two separate annotations.
-
- Args:
- label_mapping: A dictionary mapping annotation labels to human-readable labels. If None,
- the annotation labels are used as they are. Can be used to map label ids to string labels.
- key_micro: The key to use for the micro-average in the metric result dictionary.
- in_percent: Whether to return the results in percent, i.e. values between 0 and 100 instead of
- between 0 and 1.
- flatten_result_with_sep: If not None, the result dictionary is flattened and the keys of the
- different nesting levels are concatenated with the given separator.
- prefix: If not None, the most outer keys of the result dictionary are prefixed with this string.
- return_recall_and_precision: Whether to return recall and precision in addition to F1.
- """
-
- def __init__(
- self,
- label_mapping: Optional[Dict[Any, str]] = None,
- key_micro: Optional[str] = "micro",
- key_macro: Optional[str] = "macro",
- in_percent: bool = False,
- flatten_result_with_sep: Optional[str] = None,
- prefix: Optional[str] = None,
- return_recall_and_precision: bool = True,
- ):
- super().__init__()
- self.label_mapping = label_mapping
- self.key_micro = key_micro
- self.key_macro = key_macro
- self.in_percent = in_percent
- self.flatten_result_with_sep = flatten_result_with_sep
- self.prefix = prefix
- self.return_recall_and_precision = return_recall_and_precision
-
- def update(self, gold: Iterable[Annotation], predicted: Iterable[Annotation]) -> None:
- # remove duplicates within each list
- gold_set = set(gold)
- predicted_set = set(predicted)
- new_counts = self.calculate_counts(gold_set, predicted_set, gold_set & predicted_set)
- for k, v in new_counts.items():
- self.inc_counts(counts=v, key=k)
-
- def get_precision_recall_f1(self, n_gold_predicted_correct: LongTensor) -> Dict[str, float]:
- n_gold = n_gold_predicted_correct[0]
- n_predicted = n_gold_predicted_correct[1]
- n_correct = n_gold_predicted_correct[2]
- zero = torch.tensor(0.0).to(self.device)
- recall = zero if n_gold == 0 else (n_correct / n_gold)
- precision = zero if n_predicted == 0 else (n_correct / n_predicted)
- f1 = zero if recall + precision == 0 else (2 * precision * recall) / (precision + recall)
-
- result = {"f1": f1}
- if self.return_recall_and_precision:
- result["recall"] = recall
- result["precision"] = precision
-
- if self.in_percent:
- result = {k: v * 100 for k, v in result.items()}
- return result
-
- def get_label(self, annotation: Annotation) -> Optional[str]:
- label: Optional[str] = getattr(annotation, "label", None)
- if self.label_mapping is not None:
- return self.label_mapping[label]
- return label
-
- def calculate_counts(
- self,
- gold: Collection[Annotation],
- predicted: Collection[Annotation],
- correct: Collection[Annotation],
- ) -> Dict[Optional[str], LongTensor]:
- result = {}
- # per class
- gold_counter = Counter([self.get_label(ann) for ann in gold])
- predicted_counter = Counter([self.get_label(ann) for ann in predicted])
- correct_counter = Counter([self.get_label(ann) for ann in correct])
- for label in gold_counter.keys() | predicted_counter.keys():
- if self.key_micro is not None and label == self.key_micro:
- raise ValueError(
- f"The key '{self.key_micro}' was used as an annotation label, but it is reserved for "
- f"the micro average. You can change which key is used for that with the 'key_micro' argument."
- )
- result[label] = torch.tensor(
- [
- gold_counter.get(label, 0),
- predicted_counter.get(label, 0),
- correct_counter.get(label, 0),
- ]
- ).to(device=self.device)
-
- # overall
- if self.key_micro is not None:
- result[self.key_micro] = torch.tensor([len(gold), len(predicted), len(correct)]).to(
- device=self.device
- )
- return result
-
- def compute(self) -> Union[Dict[str, Any], Dict[Optional[str], dict[str, float]]]:
- counts = self.get_counts()
- result = {label: self.get_precision_recall_f1(counts[label]) for label in counts.keys()}
- if self.key_macro is not None:
- result_without_micro = {
- k: v for k, v in result.items() if self.key_micro is None or k != self.key_micro
- }
- if len(result_without_micro) > 0:
- sub_keys = list(result_without_micro.values())[0].keys()
- result[self.key_macro] = {
- k: torch.stack([v[k] for v in result_without_micro.values()]).mean()
- for k in sub_keys
- }
-
- if self.prefix is not None:
- result = {f"{self.prefix}{k}": v for k, v in result.items()}
-
- if self.flatten_result_with_sep is not None:
- return flatten_dict_s(result, sep=self.flatten_result_with_sep)
- else:
- return result
diff --git a/src/pie_modules/taskmodules/metrics/wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function.py b/src/pie_modules/taskmodules/metrics/wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function.py
deleted file mode 100644
index c7c45aaea..000000000
--- a/src/pie_modules/taskmodules/metrics/wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function.py
+++ /dev/null
@@ -1,147 +0,0 @@
-import logging
-from typing import Any, Callable, Dict, Generic, Optional, Sequence, Tuple, TypeVar
-
-import torch
-from torch.nn import ModuleDict
-from torchmetrics import Metric
-
-from .common import MetricWithArbitraryCounts
-
-logger = logging.getLogger(__name__)
-T = TypeVar("T")
-U = TypeVar("U")
-
-
-class WrappedLayerMetricsWithUnbatchAndDecodeWithErrorsFunction(
- MetricWithArbitraryCounts, Generic[T, U]
-):
- """A wrapper around annotation layer metrics that can be used with batched encoded annotations.
-
- Args:
- layer_metrics: A dictionary mapping layer names to annotation layer metrics. Each metric
- should be a subclass of torchmetrics.Metric and should take two sets of annotations as
- input.
- unbatch_function: A function that takes a batched input and returns an iterable of
- individual inputs. This is used to unbatch the input before passing it to the annotation
- decoding function (decode_annotations_with_errors_function).
- decode_layers_with_errors_function: A function that takes an annotation encoding and
- returns a tuple of two dictionaries. The first dictionary maps layer names to a list of
- annotations. The second dictionary maps error names to the number of errors that were
- encountered while decoding the annotations.
- round_precision: The number of digits to round the results to. If None, no rounding is
- performed.
- error_key_correct: The key in the error dictionary whose value should be the number of *correctly*
- decoded annotations, so that the sum of all values in the error dictionary can be used to
- normalize the error counts. If None, the total number of training examples is used to
- normalize the error counts.
- collect_exact_encoding_matches: Whether to collect the number of examples where the full target encoding
- was predicted correctly (exact matches).
- """
-
- def __init__(
- self,
- layer_metrics: Dict[str, Metric],
- unbatch_function: Callable[[T], Sequence[U]],
- decode_layers_with_errors_function: Callable[[U], Tuple[Dict[str, Any], Dict[str, int]]],
- round_precision: Optional[int] = 4,
- error_key_correct: Optional[str] = None,
- collect_exact_encoding_matches: bool = True,
- ):
- super().__init__()
-
- self.key_error_correct = error_key_correct
- self.collect_exact_encoding_matches = collect_exact_encoding_matches
- self.round_precision = round_precision
- self.unbatch_function = unbatch_function
- self.decode_layers_with_errors_function = decode_layers_with_errors_function
- self.layer_metrics = ModuleDict(layer_metrics)
-
- # total number of encodings
- self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
- # this contains the number of examples where the full target sequence was predicted correctly (exact matches)
- self.add_state("exact_encoding_matches", default=torch.tensor(0), dist_reduce_fx="sum")
- # note: the error counts are stored via the MetricWithArbitraryCounts base class
-
- def update(self, prediction, expected):
- prediction_list = self.unbatch_function(prediction)
- expected_list = self.unbatch_function(expected)
- if len(prediction_list) != len(expected_list):
- raise ValueError(
- f"Number of predictions ({len(prediction_list)}) and targets ({len(expected_list)}) do not match."
- )
-
- for expected_encoding, prediction_encoding in zip(expected_list, prediction_list):
- expected_layers, _ = self.decode_layers_with_errors_function(expected_encoding)
- predicted_layers, predicted_errors = self.decode_layers_with_errors_function(
- prediction_encoding
- )
- for k, v in predicted_errors.items():
- self.inc_counts(counts=torch.tensor(v).to(self.device), key=k, prefix="errors_")
-
- for layer_name, metric in self.layer_metrics.items():
- metric.update(expected_layers[layer_name], predicted_layers[layer_name])
-
- if self.collect_exact_encoding_matches:
- if isinstance(expected_encoding, torch.Tensor) and isinstance(
- prediction_encoding, torch.Tensor
- ):
- is_match = torch.equal(expected_encoding, prediction_encoding)
- else:
- is_match = expected_encoding == prediction_encoding
- if is_match:
- self.exact_encoding_matches += 1
-
- self.total += 1
-
- def reset(self):
- super().reset()
-
- for metric in self.layer_metrics.values():
- metric.reset()
-
- def _nested_round(self, d: Dict[str, Any]) -> Dict[str, Any]:
- if self.round_precision is None:
- return d
- res: Dict[str, Any] = {}
- for k, v in d.items():
- if isinstance(v, dict):
- res[k] = self._nested_round(v)
- elif isinstance(v, float):
- res[k] = round(v, self.round_precision)
- else:
- res[k] = v
- return res
-
- def compute(self):
- res = {}
- if self.collect_exact_encoding_matches:
- res["exact_encoding_matches"] = (
- self.exact_encoding_matches / self.total if self.total > 0 else 0.0
- )
-
- errors = self.get_counts(key_prefix="errors_")
- # if errors contains a "correct" key, use that to normalize, otherwise use the number of training examples
- if self.key_error_correct in errors:
- errors_total = sum(errors.values())
- else:
- errors_total = self.total
- res["decoding_errors"] = {
- k: v / errors_total if errors_total > 0 else 0.0 for k, v in errors.items()
- }
- if "all" not in res["decoding_errors"]:
- res["decoding_errors"]["all"] = (
- sum(v for k, v in errors.items() if k != self.key_error_correct) / errors_total
- if errors_total > 0
- else 0.0
- )
-
- for layer_name, metric in self.layer_metrics.items():
- if layer_name in res:
- raise ValueError(
- f"Layer name '{layer_name}' is already used in the metric result dictionary."
- )
- res[layer_name] = metric.compute()
-
- res = self._nested_round(res)
-
- return res
diff --git a/src/pie_modules/taskmodules/metrics/wrapped_metric_with_prepare_function.py b/src/pie_modules/taskmodules/metrics/wrapped_metric_with_prepare_function.py
deleted file mode 100644
index 7daceb9dd..000000000
--- a/src/pie_modules/taskmodules/metrics/wrapped_metric_with_prepare_function.py
+++ /dev/null
@@ -1,129 +0,0 @@
-import logging
-from collections.abc import Collection, Sized
-from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
-
-from torch import Tensor
-from torchmetrics import Metric, MetricCollection
-from torchmetrics.wrappers.abstract import WrapperMetric
-
-logger = logging.getLogger(__name__)
-
-T = TypeVar("T")
-T2 = TypeVar("T2")
-
-
-class WrappedMetricWithPrepareFunction(WrapperMetric, Generic[T]):
- """A wrapper around a metric that can be used with predictions and targets that are need to be
- prepared (e.g. un-batched) before passing them to the metric.
-
- Args:
- metric: The metric to wrap. It should be a subclass of torchmetrics.Metric.
- prepare_function: A function that prepares the input for the metric. If provided, It is called with
- the predictions as well as the targets (separately).
- prepare_together_function: A function that prepares both the predictions and the targets together and
- should return them as a tuple. If provided, it is called with the predictions and the targets as
- arguments.
- prepare_does_unbatch: If True, the prepare_function is expected to return an iterable of
- individual inputs. This can be used to un-batch the input before passing it to the
- wrapped metric.
- """
-
- def __init__(
- self,
- metric: Union[Metric, MetricCollection],
- prepare_function: Optional[Callable[[T], Any]] = None,
- prepare_together_function: Optional[Callable[[T, T], Tuple[Any, Any]]] = None,
- prepare_does_unbatch: bool = False,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.metric = metric
- self.prepare_function = prepare_function
- self.prepare_both_function = prepare_together_function
- self.prepare_does_unbatch = prepare_does_unbatch
-
- def _is_empty_batch(self, prediction: T2, target: T2) -> bool:
- if isinstance(prediction, Sized) and isinstance(target, Sized):
- pred_len = len(prediction)
- target_len = len(target)
- else:
- raise ValueError(
- "Both prediction and target need to be sized when prepare_does_unbatch=False."
- )
- if pred_len != target_len:
- raise ValueError(
- f"Number of elements in prediction ({pred_len}) and target ({target_len}) do not match."
- )
- if pred_len == 0:
- return True
- return False
-
- def forward(self, prediction: T, target: T) -> Any:
- if self.prepare_function is not None:
- prediction = self.prepare_function(prediction)
- target = self.prepare_function(target)
- if self.prepare_both_function is not None:
- prediction, target = self.prepare_both_function(prediction, target)
- if self.prepare_does_unbatch:
- if not isinstance(prediction, Collection) or not isinstance(target, Collection):
- raise ValueError(
- "Both prediction and target need to be iterable and sized when prepare_does_unbatch=True."
- )
- if len(prediction) != len(target):
- raise ValueError(
- f"Number of prepared predictions ({len(prediction)}) and targets "
- f"({len(target)}) do not match."
- )
- if len(prediction) == 0:
- raise ValueError("Empty batch.")
- results = []
- for prediction_str, target_str in zip(prediction, target):
- current_result = self.metric(prediction_str, target_str)
- results.append(current_result)
- return results
- else:
- if not self._is_empty_batch(prediction, target):
- return self.metric(prediction, target)
- else:
- return None
-
- def update(self, prediction: T, target: T) -> None:
- if self.prepare_function is not None:
- prediction = self.prepare_function(prediction)
- target = self.prepare_function(target)
- if self.prepare_both_function is not None:
- prediction, target = self.prepare_both_function(prediction, target)
- if self.prepare_does_unbatch:
- if not isinstance(prediction, Collection) or not isinstance(target, Collection):
- raise ValueError(
- "Both prediction and target need to be iterable and sized when prepare_does_unbatch=True."
- )
- if len(prediction) != len(target):
- raise ValueError(
- f"Number of prepared predictions ({len(prediction)}) and targets "
- f"({len(target)}) do not match."
- )
- if len(prediction) == 0:
- raise ValueError("Empty batch.")
- for prediction_str, target_str in zip(prediction, target):
- self.metric.update(prediction_str, target_str)
- else:
- if not self._is_empty_batch(prediction, target):
- self.metric.update(prediction, target)
-
- def compute(self) -> Any:
- return self.metric.compute()
-
- def reset(self) -> None:
- self.metric.reset()
-
- @property
- def metric_state(self) -> Dict[str, Union[List[Tensor], Tensor]]:
- if isinstance(self.metric, Metric):
- return self.metric.metric_state
- elif isinstance(self.metric, MetricCollection):
- return {
- metric_name: metric.metric_state for metric_name, metric in self.metric.items()
- }
- else:
- raise ValueError(f"Unsupported metric type: {type(self.metric)}")
diff --git a/src/pie_modules/taskmodules/pointer_network/__init__.py b/src/pie_modules/taskmodules/pointer_network/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/src/pie_modules/taskmodules/pointer_network/annotation_encoder_decoder.py b/src/pie_modules/taskmodules/pointer_network/annotation_encoder_decoder.py
deleted file mode 100644
index 5d48629bd..000000000
--- a/src/pie_modules/taskmodules/pointer_network/annotation_encoder_decoder.py
+++ /dev/null
@@ -1,397 +0,0 @@
-import logging
-from collections import defaultdict
-from typing import Any, Dict, List, Optional, Set, Tuple
-
-from pie_modules.annotations import BinaryRelation, LabeledSpan, Span
-from pie_modules.taskmodules.common import AnnotationEncoderDecoder, DecodingException
-
-logger = logging.getLogger(__name__)
-
-
-class DecodingLengthException(DecodingException[List[int]]):
- identifier = "len"
-
-
-class DecodingOrderException(DecodingException[List[int]]):
- identifier = "order"
-
-
-class DecodingSpanOverlapException(DecodingException[List[int]]):
- identifier = "overlap"
-
-
-class DecodingLabelException(DecodingException[List[int]]):
- identifier = "label"
-
-
-class DecodingNegativeIndexException(DecodingException[List[int]]):
- identifier = "index"
-
-
-KEY_INVALID_CORRECT = "correct"
-
-
-class SpanEncoderDecoder(AnnotationEncoderDecoder[Span, List[int]]):
- def __init__(self, exclusive_end: bool = True):
- self.exclusive_end = exclusive_end
-
- def encode(self, annotation: Span, metadata: Optional[Dict[str, Any]] = None) -> List[int]:
- end_idx = annotation.end
- if not self.exclusive_end:
- end_idx -= 1
- return [annotation.start, end_idx]
-
- def decode(self, encoding: List[int], metadata: Optional[Dict[str, Any]] = None) -> Span:
- if len(encoding) != 2:
- raise DecodingLengthException(
- f"two values are required to decode as Span, but encoding has length {len(encoding)}",
- encoding=encoding,
- )
- end_idx = encoding[1]
- if not self.exclusive_end:
- end_idx += 1
- if end_idx < encoding[0]:
- raise DecodingOrderException(
- f"end index can not be smaller than start index, but got: start={encoding[0]}, "
- f"end={end_idx}",
- encoding=encoding,
- )
- if any(idx < 0 for idx in encoding):
- raise DecodingNegativeIndexException(
- f"indices must be positive, but got: {encoding}", encoding=encoding
- )
- return Span(start=encoding[0], end=end_idx)
-
-
-class SpanEncoderDecoderWithOffset(SpanEncoderDecoder):
- def __init__(self, offset: int, **kwargs):
- super().__init__(**kwargs)
- self.offset = offset
-
- def encode(self, annotation: Span, metadata: Optional[Dict[str, Any]] = None) -> List[int]:
- encoding = super().encode(annotation=annotation, metadata=metadata)
- return [x + self.offset for x in encoding]
-
- def decode(self, encoding: List[int], metadata: Optional[Dict[str, Any]] = None) -> Span:
- encoding = [x - self.offset for x in encoding]
- return super().decode(encoding=encoding, metadata=metadata)
-
-
-class LabeledSpanEncoderDecoder(AnnotationEncoderDecoder[LabeledSpan, List[int]]):
- def __init__(
- self,
- span_encoder_decoder: AnnotationEncoderDecoder[Span, List[int]],
- label2id: Dict[str, int],
- mode: str,
- ):
- self.span_encoder_decoder = span_encoder_decoder
- self.label2id = label2id
- self.id2label = {idx: label for label, idx in self.label2id.items()}
- self.mode = mode
-
- def encode(
- self, annotation: LabeledSpan, metadata: Optional[Dict[str, Any]] = None
- ) -> List[int]:
- encoded_span = self.span_encoder_decoder.encode(annotation=annotation, metadata=metadata)
- encoded_label = self.label2id[annotation.label]
- if self.mode == "indices_label":
- return encoded_span + [encoded_label]
- elif self.mode == "label_indices":
- return [encoded_label] + encoded_span
- else:
- raise ValueError(f"unknown mode: {self.mode}")
-
- def decode(
- self, encoding: List[int], metadata: Optional[Dict[str, Any]] = None
- ) -> LabeledSpan:
- if self.mode == "label_indices":
- encoded_label = encoding[0]
- encoded_span = encoding[1:]
- elif self.mode == "indices_label":
- encoded_label = encoding[-1]
- encoded_span = encoding[:-1]
- else:
- raise ValueError(f"unknown mode: {self.mode}")
-
- decoded_span = self.span_encoder_decoder.decode(encoding=encoded_span, metadata=metadata)
- if encoded_label not in self.id2label:
- raise DecodingLabelException(
- f"unknown label id: {encoded_label} (label2id: {self.label2id})", encoding=encoding
- )
- result = LabeledSpan(
- start=decoded_span.start,
- end=decoded_span.end,
- label=self.id2label[encoded_label],
- )
- return result
-
-
-class BinaryRelationEncoderDecoder(AnnotationEncoderDecoder[BinaryRelation, List[int]]):
- def __init__(
- self,
- head_encoder_decoder: AnnotationEncoderDecoder[Span, List[int]],
- tail_encoder_decoder: AnnotationEncoderDecoder[Span, List[int]],
- label2id: Dict[str, int],
- mode: str,
- loop_dummy_relation_name: Optional[str] = None,
- none_label: Optional[str] = None,
- ):
- self.head_encoder_decoder = head_encoder_decoder
- self.tail_encoder_decoder = tail_encoder_decoder
- self.loop_dummy_relation_name = loop_dummy_relation_name
- self.none_label = none_label
- self.label2id = label2id
- self.id2label = {idx: label for label, idx in self.label2id.items()}
- self.mode = mode
-
- def encode(
- self, annotation: BinaryRelation, metadata: Optional[Dict[str, Any]] = None
- ) -> List[int]:
- encoded_head = self.head_encoder_decoder.encode(annotation=annotation.head)
- encoded_tail = self.tail_encoder_decoder.encode(annotation=annotation.tail)
-
- if (
- self.loop_dummy_relation_name is not None
- and annotation.label == self.loop_dummy_relation_name
- ):
- if annotation.head != annotation.tail:
- raise ValueError(
- f"expected head == tail for loop_dummy_relation, but got: {annotation.head}, "
- f"{annotation.tail}"
- )
- if self.none_label is None:
- raise ValueError(
- f"loop_dummy_relation_name is set, but none_label is not set: {self.none_label}"
- )
- none_id = self.label2id[self.none_label]
- encoded_none_argument = [none_id, none_id, none_id]
- if self.mode == "head_tail_label":
- return encoded_head + encoded_none_argument + [none_id]
- elif self.mode == "tail_head_label":
- return encoded_tail + encoded_none_argument + [none_id]
- elif self.mode == "label_head_tail":
- return [none_id] + encoded_head + encoded_none_argument
- elif self.mode == "label_tail_head":
- return [none_id] + encoded_tail + encoded_none_argument
- else:
- raise ValueError(f"unknown mode: {self.mode}")
- else:
- encoded_label = self.label2id[annotation.label]
- if self.mode == "tail_head_label":
- return encoded_tail + encoded_head + [encoded_label]
- elif self.mode == "head_tail_label":
- return encoded_head + encoded_tail + [encoded_label]
- elif self.mode == "label_head_tail":
- return [encoded_label] + encoded_head + encoded_tail
- elif self.mode == "label_tail_head":
- return [encoded_label] + encoded_tail + encoded_head
- else:
- raise ValueError(f"unknown mode: {self.mode}")
-
- def is_single_span_label(self, label: str) -> bool:
- return self.none_label is not None and label == self.none_label
-
- def decode(
- self, encoding: List[int], metadata: Optional[Dict[str, Any]] = None
- ) -> BinaryRelation:
- if len(encoding) != 7:
- raise DecodingLengthException(
- f"seven values are required to decode as BinaryRelation, but the encoding has length {len(encoding)}",
- encoding=encoding,
- )
- if self.mode.endswith("_label"):
- encoded_label = encoding[6]
- encoded_arguments = encoding[:6]
- argument_mode = self.mode[: -len("_label")]
- elif self.mode.startswith("label_"):
- encoded_label = encoding[0]
- encoded_arguments = encoding[1:]
- argument_mode = self.mode[len("label_") :]
- else:
- raise ValueError(f"unknown mode: {self.mode}")
- if encoded_label not in self.id2label:
- raise DecodingLabelException(
- f"unknown label id: {encoded_label} (label2id: {self.label2id})", encoding=encoding
- )
- label = self.id2label[encoded_label]
- if self.is_single_span_label(label=label):
- if argument_mode == "head_tail":
- span_encoder = self.head_encoder_decoder
- elif argument_mode == "tail_head":
- span_encoder = self.tail_encoder_decoder
- else:
- raise ValueError(f"unknown argument mode: {argument_mode}")
- encoded_span = encoded_arguments[:3]
- span = span_encoder.decode(encoding=encoded_span, metadata=metadata)
- if self.loop_dummy_relation_name is None:
- raise ValueError(
- f"loop_dummy_relation_name is not set, but none_label={self.none_label} "
- f"was found in decoded encoding: {encoding} (label2id: {self.label2id}))"
- )
- rel = BinaryRelation(head=span, tail=span, label=self.loop_dummy_relation_name)
- else:
- if argument_mode == "head_tail":
- encoded_head = encoded_arguments[:3]
- encoded_tail = encoded_arguments[3:]
- elif argument_mode == "tail_head":
- encoded_tail = encoded_arguments[:3]
- encoded_head = encoded_arguments[3:]
- else:
- raise ValueError(f"unknown argument mode: {argument_mode}")
- head = self.head_encoder_decoder.decode(encoding=encoded_head, metadata=metadata)
- tail = self.tail_encoder_decoder.decode(encoding=encoded_tail, metadata=metadata)
- rel = BinaryRelation(head=head, tail=tail, label=label)
-
- return rel
-
- def build_decoding_constraints(
- self, partial_encoding: List[int]
- ) -> Tuple[Optional[Set[int]], Optional[Set[int]]]:
- """Given a partial encoding, build the constraints for the next encoding step.
-
- Returns:
- Tuple[Optional[Set[int]], Optional[Set[int]]]: A tuple of two sets of integers representing the allowed
- and disallowed next indices. The first set contains the allowed indices, and the second set contains
- the disallowed indices. If no constraints are needed, both sets can be None.
- """
- allowed = None
- disallowed = None
-
- if self.mode != "tail_head_label":
- raise NotImplementedError(
- f"build_decoder_constraints is not implemented for mode {self.mode}"
- )
-
- if self.none_label not in self.label2id:
- raise ValueError(
- f"none_label not found in label2id: {self.label2id} (none_label: {self.none_label})"
- )
- none_id = self.label2id[self.none_label]
- if self.head_encoder_decoder != self.tail_encoder_decoder:
- raise NotImplementedError(
- "head and tail encoder/decoder must be the same for build_decoder_constraints"
- )
-
- if not isinstance(self.head_encoder_decoder, LabeledSpanEncoderDecoder):
- raise NotImplementedError(
- "head and tail encoder/decoder must be LabeledSpanEncoderDecoder for build_decoder_constraints"
- )
- if not isinstance(
- self.head_encoder_decoder.span_encoder_decoder, SpanEncoderDecoderWithOffset
- ):
- raise NotImplementedError(
- "head and tail encoder/decoder must be SpanEncoderDecoderWithOffset for build_decoder_constraints"
- )
- pointer_offset = self.head_encoder_decoder.span_encoder_decoder.offset
- if self.head_encoder_decoder.mode != "indices_label":
- raise NotImplementedError(
- "head and tail encoder/decoder must be indices_label for build_decoder_constraints"
- )
- if (
- not isinstance(self.head_encoder_decoder.span_encoder_decoder, SpanEncoderDecoder)
- or self.head_encoder_decoder.span_encoder_decoder.exclusive_end
- ):
- raise NotImplementedError(
- "head and tail encoder/decoder must be exclusive_end for build_decoder_constraints"
- )
- span_ids = set(self.head_encoder_decoder.label2id.values())
- relation_ids = set(self.label2id.values()) - {self.label2id[self.none_label]}
- contains_none = none_id in partial_encoding
- idx = len(partial_encoding)
- if idx == 0: # [] -> first span start or eos
- # Disallow all labels:
- disallowed = set(range(pointer_offset))
- elif idx == 1: # [14] -> first span end
- # Allow all offsets greater than the span start.
- span_start = partial_encoding[-1]
- # result[span_start:] = 1
- disallowed = set(range(span_start))
- # Disallow the none label:
- disallowed.add(none_id)
- elif idx == 2: # [14,14] -> first span label
- # Allow only span ids.
- allowed = span_ids
- elif idx == 3: # [14,14,s1] -> second span start or none
- # Disallow overlap of first and second span:
- first_span_start = partial_encoding[0]
- first_span_end = partial_encoding[1] + 1
- disallowed = set(range(first_span_start, first_span_end))
- # Disallow all span labels:
- disallowed.update(span_ids)
- # Disallow all relation labels:
- disallowed.update(relation_ids)
- # But allow the none label:
- disallowed.discard(none_id)
-
- elif idx == 4: # [14,14,s1,23] -> second span end or none
- # if we have a none label, allow only none
- if contains_none:
- allowed = {none_id}
- else:
-
- first_span_start = partial_encoding[0]
- # first_span_end = partial_encoding[1] + 1
- second_span_start = partial_encoding[-1]
- # if first span is after the second span,
- if second_span_start < first_span_start:
- # just allow the offsets between the two spans:
- allowed = set(range(second_span_start, first_span_start))
- else:
- # otherwise, disallow all offsets before the second span start:
- disallowed = set(range(second_span_start))
-
- # Disallow all span labels:
- disallowed.update(span_ids)
- # Disallow all relation labels:
- disallowed.update(relation_ids)
-
- elif idx == 5: # [14,14,s1,23,25] -> second span label or none
- # if we have a none label, allow only none
- if contains_none:
- # result[none_id] = 1
- allowed = {none_id}
- else:
- # allow only span ids
- allowed = span_ids
- elif idx == 6: # [14,14,s1,23,25,s2] -> relation label or none
- # if we have a none label, allow only none
- if contains_none:
- allowed = {none_id}
- else:
- # allow only relation ids
- allowed = relation_ids
- else:
- raise ValueError(
- f"unknown partial encoding length: {len(partial_encoding)} (encoding: {partial_encoding})"
- )
-
- return allowed, disallowed
-
- def parse(self, encoding: List[int]) -> Tuple[List[BinaryRelation], Dict[str, int], List[int]]:
- errors: Dict[str, int] = defaultdict(int)
- if self.none_label is None:
- raise ValueError(
- f"none_label is not set, but is required for parsing: {self.none_label}"
- )
- none_id = self.label2id[self.none_label]
- relation_ids = set(self.label2id.values()) - {none_id}
- encodings = []
- current_encoding: List[int] = []
- valid_encoding: BinaryRelation
- if len(encoding):
- for i in encoding:
- current_encoding.append(i)
- # An encoding is complete when it ends with a relation_id
- # or when it contains a none_id and has a length of 7
- if i in relation_ids or (i == none_id and len(current_encoding) == 7):
- # try to decode the current relation encoding
- try:
- valid_encoding = self.decode(encoding=current_encoding)
- encodings.append(valid_encoding)
- errors[KEY_INVALID_CORRECT] += 1
- except DecodingException as e:
- errors[e.identifier] += 1
-
- current_encoding = []
-
- return encodings, dict(errors), current_encoding
diff --git a/src/pie_modules/taskmodules/pointer_network/logits_processor.py b/src/pie_modules/taskmodules/pointer_network/logits_processor.py
deleted file mode 100644
index 777fd9763..000000000
--- a/src/pie_modules/taskmodules/pointer_network/logits_processor.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import math
-from typing import Callable, List
-
-import torch
-from transformers import LogitsProcessor, add_start_docstrings
-from transformers.generation.logits_process import LOGITS_PROCESSOR_INPUTS_DOCSTRING
-
-
-class PrefixConstrainedLogitsProcessorWithMaximum(LogitsProcessor):
- r"""This is similar to [`PrefixConstrainedLogitsProcessor`] but the constraint function gets the
- maximum possible index as input. This is useful for Pointer Network where the generated token
- can be an index into the input which depends on the length of that input.
-
- Args:
- prefix_allowed_tokens_fn (Callable[[int, torch.LongTensor, int], List[int]]):
- Should return the list of token ids allowed at the next generation step,
- given (`batch_id`, `input_ids_so_far`, `max_index`).
- """
-
- def __init__(
- self,
- prefix_allowed_tokens_fn: Callable[[int, torch.LongTensor, int], List[int]],
- num_beams: int,
- ):
- self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
- self._num_beams = num_beams
-
- @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
- def __call__(
- self, input_ids: torch.LongTensor, scores: torch.FloatTensor
- ) -> torch.FloatTensor:
- if not torch.isfinite(scores).all():
- raise ValueError(
- "scores contains ±inf or NaN, which is not allowed by "
- "PrefixConstrainedLogitsProcessorWithMaximum. "
- "Insert FinitizeLogitsProcessor earlier to clean them."
- )
- mask = torch.full_like(scores, -math.inf)
- for batch_id, beam_sent in enumerate(
- input_ids.view(-1, self._num_beams, input_ids.shape[-1])
- ):
- for beam_id, sent in enumerate(beam_sent):
- allowed_ids = self._prefix_allowed_tokens_fn(batch_id, sent, mask.size(1))
- if len(allowed_ids) == 0:
- raise ValueError(
- f"No allowed token ids for batch_id {batch_id}, beam_id {beam_id} with "
- f"previous ids: {sent}. This would result in undefined behaviour, "
- "so this is not allowed. Please adjust the prefix_allowed_tokens_fn "
- "implementation."
- )
- mask[batch_id * self._num_beams + beam_id, allowed_ids] = 0
-
- return scores + mask
-
-
-class FinitizeLogitsProcessor(LogitsProcessor):
- r"""Replaces any `±inf` logits with the largest-magnitude finite values for the tensor’s dtype,
- ensuring all logits are valid for downstream ops."""
-
- @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
- def __call__(
- self, input_ids: torch.LongTensor, scores: torch.FloatTensor
- ) -> torch.FloatTensor:
- finite_min = torch.finfo(scores.dtype).min
- finite_max = torch.finfo(scores.dtype).max
- # Use nan_to_num for a fast, fused replacement (PyTorch ≥ 1.8)
- return torch.nan_to_num(scores, neginf=finite_min, posinf=finite_max)
diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py
deleted file mode 100644
index 9ddb4254b..000000000
--- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py
+++ /dev/null
@@ -1,865 +0,0 @@
-import dataclasses
-import json
-import logging
-from collections import Counter, defaultdict
-from functools import cmp_to_key
-from typing import (
- Any,
- Dict,
- Iterable,
- Iterator,
- List,
- Optional,
- Sequence,
- Set,
- Tuple,
- Type,
- Union,
-)
-
-import torch
-from pie_core import (
- Annotation,
- AnnotationLayer,
- Document,
- TaskEncoding,
- TaskModule,
-)
-from pie_core.taskmodule import (
- InputEncoding,
- ModelBatchOutput,
- TargetEncoding,
- TaskBatchEncoding,
-)
-from pie_core.utils.hydra import resolve_type
-from torchmetrics import Metric
-from transformers import AutoTokenizer, LogitsProcessorList, PreTrainedTokenizer
-from typing_extensions import TypeAlias
-
-from pie_modules.annotations import BinaryRelation, LabeledSpan
-
-# import for backwards compatibility (don't remove!)
-from pie_modules.documents import (
- TextBasedDocument,
- TokenBasedDocument,
- TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-)
-
-from ..document.processing import token_based_document_to_text_based, tokenize_document
-from .common import BatchableMixin, get_first_occurrence_index
-from .metrics import (
- PrecisionRecallAndF1ForLabeledAnnotations,
- WrappedLayerMetricsWithUnbatchAndDecodeWithErrorsFunction,
-)
-from .pointer_network.annotation_encoder_decoder import (
- KEY_INVALID_CORRECT,
- BinaryRelationEncoderDecoder,
- LabeledSpanEncoderDecoder,
- SpanEncoderDecoderWithOffset,
-)
-from .pointer_network.logits_processor import (
- FinitizeLogitsProcessor,
- PrefixConstrainedLogitsProcessorWithMaximum,
-)
-
-logger = logging.getLogger(__name__)
-
-
-DocumentType: TypeAlias = TextBasedDocument
-
-
-@dataclasses.dataclass
-class InputEncodingType(BatchableMixin):
- input_ids: List[int]
- attention_mask: List[int]
-
-
-@dataclasses.dataclass
-class LabelsAndOptionalConstraints(BatchableMixin):
- labels: List[int]
- constraints: Optional[List[List[int]]] = None
-
- @property
- def decoder_attention_mask(self) -> List[int]:
- return [1] * len(self.labels)
-
-
-TargetEncodingType: TypeAlias = LabelsAndOptionalConstraints
-TaskEncodingType: TypeAlias = TaskEncoding[
- DocumentType,
- InputEncodingType,
- TargetEncodingType,
-]
-TaskOutputType: TypeAlias = LabelsAndOptionalConstraints
-
-
-def cmp_src_rel(v1: BinaryRelation, v2: BinaryRelation) -> int:
- if not all(isinstance(ann, LabeledSpan) for ann in [v1.head, v1.tail, v2.head, v2.tail]):
- raise Exception(f"expected LabeledSpan, but got: {v1}, {v2}")
- if v1.head.start == v2.head.start: # v1[0]["from"] == v2[0]["from"]:
- return v1.tail.start - v2.tail.start # v1[1]["from"] - v2[1]["from"]
- return v1.head.start - v2.head.start # v1[0]["from"] - v2[0]["from"]
-
-
-@TaskModule.register()
-class PointerNetworkTaskModuleForEnd2EndRE(
- TaskModule[
- DocumentType,
- InputEncoding,
- TargetEncoding,
- TaskBatchEncoding,
- ModelBatchOutput,
- TaskOutputType,
- ],
-):
- PREPARED_ATTRIBUTES = ["labels_per_layer"]
- REVERSED_RELATION_LABEL_SUFFIX = "_reversed"
-
- def __init__(
- self,
- tokenizer_name_or_path: str,
- # specific for this use case
- document_type: str = "pytorch_ie.documents.TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions",
- tokenized_document_type: str = "pie_modules.documents.TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions",
- relation_layer_name: str = "binary_relations",
- add_reversed_relations: bool = False,
- symmetric_relations: Optional[List[str]] = None,
- none_label: str = "none",
- loop_dummy_relation_name: str = "loop",
- constrained_generation: bool = False,
- # generic pointer network
- label_tokens: Optional[Dict[str, str]] = None,
- label_representations: Optional[Dict[str, str]] = None,
- labels_per_layer: Optional[Dict[str, List[str]]] = None,
- exclude_labels_per_layer: Optional[Dict[str, List[str]]] = None,
- # target encoding
- create_constraints: bool = False,
- # tokenization
- tokenizer_init_kwargs: Optional[Dict[str, Any]] = None,
- tokenizer_kwargs: Optional[Dict[str, Any]] = None,
- partition_layer_name: Optional[str] = None,
- annotation_field_mapping: Optional[Dict[str, str]] = None,
- # logging
- log_first_n_examples: Optional[int] = None,
- **kwargs,
- ):
- super().__init__(**kwargs)
- self.save_hyperparameters()
-
- # tokenization
- self._document_type: Type[TextBasedDocument] = resolve_type(
- document_type, expected_super_type=TextBasedDocument
- )
- self._tokenized_document_type: Type[TokenBasedDocument] = resolve_type(
- tokenized_document_type, expected_super_type=TokenBasedDocument
- )
- self.tokenizer_name_or_path = tokenizer_name_or_path
- self.tokenizer_kwargs = tokenizer_kwargs or {}
- self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
- tokenizer_name_or_path,
- **(tokenizer_init_kwargs or {}),
- )
- self.annotation_field_mapping = annotation_field_mapping or dict()
- annotation_field_mapping_inv = {v: k for k, v in self.annotation_field_mapping.items()}
- if len(self.annotation_field_mapping) != len(annotation_field_mapping_inv):
- raise ValueError(
- f"inverted annotation_field_mapping is not unique. annotation_field_mapping: "
- f"{self.annotation_field_mapping}"
- )
- self.partition_layer_name = partition_layer_name
-
- # for this specific use case: end-to-end relation extraction
- self.relation_layer_name = relation_layer_name
- relation_layer_mapped = self.annotation_field_mapping.get(
- relation_layer_name, relation_layer_name
- )
- relation_layer_target = self.document_type.target_name(relation_layer_mapped)
- self.span_layer_name = annotation_field_mapping_inv.get(
- relation_layer_target, relation_layer_target
- )
- self.add_reversed_relations = add_reversed_relations
- self.symmetric_relations = set(symmetric_relations or [])
- self.none_label = none_label
- self.loop_dummy_relation_name = loop_dummy_relation_name
- self.constrained_generation = constrained_generation
- # will be set in _post_prepare()
- self.relation_encoder_decoder: BinaryRelationEncoderDecoder
-
- # collected in prepare(), if not passed in
- self.labels_per_layer = labels_per_layer
- self.exclude_labels_per_layer = exclude_labels_per_layer or {}
-
- # how to encode and decode the annotations
- self.bos_token = self.tokenizer.bos_token
- self.eos_token = self.tokenizer.eos_token
- self.label_tokens = label_tokens or dict()
- self.label_representations = label_representations or dict()
-
- # target encoding
- self.create_constraints = create_constraints
- self.pad_values = {
- "input_ids": self.tokenizer.pad_token_id,
- "attention_mask": 0,
- "labels": self.target_pad_id,
- "decoder_attention_mask": 0,
- "constraints": -1,
- }
- self.dtypes = {
- "input_ids": torch.int64,
- "attention_mask": torch.int64,
- "labels": torch.int64,
- "decoder_attention_mask": torch.int64,
- "constraints": torch.int64,
- }
-
- # logging
- self.log_first_n_examples = log_first_n_examples
-
- @property
- def document_type(self) -> Type[TextBasedDocument]:
- return self._document_type
-
- @property
- def tokenized_document_type(self) -> Type[TokenBasedDocument]:
- return self._tokenized_document_type
-
- @property
- def layer_names(self) -> List[str]:
- return [self.span_layer_name, self.relation_layer_name]
-
- @property
- def special_targets(self) -> list[str]:
- return [self.bos_token, self.eos_token]
-
- @property
- def special_target2id(self) -> Dict[str, int]:
- return {target: idx for idx, target in enumerate(self.special_targets)}
-
- @property
- def target_pad_id(self) -> int:
- return self.special_target2id[self.eos_token]
-
- def configure_model_generation(self) -> Dict[str, Any]:
- result: Dict[str, Any] = {"no_repeat_ngram_size": 7}
- if self.constrained_generation:
- logits_processor = LogitsProcessorList()
- # PrefixConstrainedLogitsProcessorWithMaximum requires finite logits
- logits_processor.append(FinitizeLogitsProcessor())
- logits_processor.append(
- PrefixConstrainedLogitsProcessorWithMaximum(
- prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn_with_maximum,
- # use dummy value of 1, this is fine because num_beams affects only the value of batch_id
- # which is not used in _prefix_allowed_tokens_fn_with_maximum()
- num_beams=1,
- )
- )
- result["logits_processor"] = logits_processor
- return result
-
- def _prefix_allowed_tokens_fn_with_maximum(
- self, batch_id: int, input_ids: torch.LongTensor, maximum: int
- ) -> List[int]:
- # remove the first token (bos_token) and use unbatch_output to un-pad the label_ids
- label_ids_without_bos = input_ids[1:]
- if len(label_ids_without_bos) > 0:
- unpadded_label_ids = self.unbatch_output(
- {"labels": label_ids_without_bos.unsqueeze(0)}
- )[0].labels
- else:
- unpadded_label_ids = []
- _, _, remaining = self.relation_encoder_decoder.parse(encoding=unpadded_label_ids)
- # this is a binary mask
- constraint = self._build_constraint(
- previous_ids=remaining, input_len=maximum - self.pointer_offset
- )
- # convert to indices
- allowed_indices = torch.nonzero(constraint).squeeze(1)
- # convert to a list
- return allowed_indices.tolist()
-
- def add_reversed_relation_labels(self, relation_labels: Iterable[str]) -> Set[str]:
- result = set(relation_labels)
- for rel_label in set(relation_labels):
- if rel_label not in self.symmetric_relations:
- reversed_label = rel_label + self.REVERSED_RELATION_LABEL_SUFFIX
- if reversed_label in result:
- raise ValueError(
- f"reversed relation label {reversed_label} already exists in relation layer labels"
- )
- result.add(reversed_label)
- return result
-
- def _prepare(self, documents: Sequence[DocumentType]) -> None:
- # collect all labels
- labels: Dict[str, Set[str]] = {layer_name: set() for layer_name in self.layer_names}
- for doc in documents:
- for layer_name in self.layer_names:
- exclude_labels = self.exclude_labels_per_layer.get(layer_name, [])
- labels[layer_name].update(
- ac.label for ac in doc[layer_name] if ac.label not in exclude_labels
- )
-
- if self.add_reversed_relations:
- labels[self.relation_layer_name] = self.add_reversed_relation_labels(
- relation_labels=labels[self.relation_layer_name]
- )
-
- self.labels_per_layer = {
- # sort labels to ensure deterministic order
- layer_name: sorted(labels)
- for layer_name, labels in labels.items()
- }
-
- def construct_label_token(self, label: str) -> str:
- return self.label_tokens.get(label, f"<<{label}>>")
-
- def get_label_representation(self, label: str) -> str:
- return self.label_representations.get(label, label)
-
- def _post_prepare(self) -> None:
- # set up labels
- if self.labels_per_layer is None:
- raise Exception("labels_per_layer is not defined. Call prepare() first or pass it in.")
- self.labels: List[str] = [self.none_label]
- for layer_name in self.layer_names:
- self.labels.extend(self.labels_per_layer[layer_name])
- if len(set(self.labels)) != len(self.labels):
- raise Exception(f"labels are not unique: {self.labels}")
-
- # set up targets and ids
- self.targets: List[str] = self.special_targets + self.labels
- self.target2id: Dict[str, int] = {target: idx for idx, target in enumerate(self.targets)}
-
- # generic ids
- self.eos_id: int = self.target2id[self.eos_token]
- self.bos_id: int = self.target2id[self.bos_token]
-
- # span and relation ids
- self.span_ids: List[int] = [
- self.target2id[label] for label in self.labels_per_layer[self.span_layer_name]
- ]
- self.relation_ids: List[int] = [
- self.target2id[label] for label in self.labels_per_layer[self.relation_layer_name]
- ]
- # the none id is used for the dummy relation which models out-of-relation spans
- self.none_id: int = self.target2id[self.none_label]
-
- # helpers (same as targets / target2id, but only for labels)
- self.label2id: Dict[str, int] = {label: self.target2id[label] for label in self.labels}
- self.id2label: Dict[int, str] = {idx: label for label, idx in self.label2id.items()}
- self.label_ids: List[int] = [self.label2id[label] for label in self.labels]
-
- # annotation-encoder-decoders
- span_encoder_decoder = SpanEncoderDecoderWithOffset(
- offset=self.pointer_offset, exclusive_end=False
- )
- span_labels = self.labels_per_layer[self.span_layer_name]
- labeled_span_encoder_decoder = LabeledSpanEncoderDecoder(
- span_encoder_decoder=span_encoder_decoder,
- # restrict label2id to get better error messages
- label2id={label: idx for label, idx in self.label2id.items() if label in span_labels},
- mode="indices_label",
- )
- relation_labels = self.labels_per_layer[self.relation_layer_name] + [
- self.loop_dummy_relation_name,
- self.none_label,
- ]
- self.relation_encoder_decoder = BinaryRelationEncoderDecoder(
- head_encoder_decoder=labeled_span_encoder_decoder,
- tail_encoder_decoder=labeled_span_encoder_decoder,
- # restrict label2id to get better error messages
- label2id={
- label: idx for label, idx in self.label2id.items() if label in relation_labels
- },
- loop_dummy_relation_name=self.loop_dummy_relation_name,
- none_label=self.none_label,
- mode="tail_head_label",
- )
-
- label2token = {label: self.construct_label_token(label=label) for label in self.labels}
- if len(set(label2token.values())) != len(label2token):
- raise Exception(f"label2token values are not unique: {label2token}")
-
- already_in_vocab = [
- tok
- for tok in label2token.values()
- if self.tokenizer.convert_tokens_to_ids(tok) != self.tokenizer.unk_token_id
- ]
- if len(already_in_vocab) > 0:
- raise Exception(
- f"some special tokens to add (mapped label ids) are already in the tokenizer vocabulary, "
- f"this is not allowed: {already_in_vocab}. You may want to adjust the label2special_token mapping"
- )
- # sort by length, so that longer tokens are added first
- label_tokens_sorted = sorted(label2token.values(), key=lambda x: len(x), reverse=True)
- self.tokenizer.add_special_tokens(
- special_tokens_dict={"additional_special_tokens": label_tokens_sorted}
- )
-
- # target tokens are the special tokens plus the mapped label tokens
- self.target_tokens: List[str] = self.special_targets + [
- label2token[label] for label in self.labels
- ]
- self.target_token_ids: List[int] = self.tokenizer.convert_tokens_to_ids(self.target_tokens)
-
- # construct a mapping from label_token_id to token_ids that will be used to initialize the embeddings
- # of the labels
- self.label_embedding_weight_mapping = dict()
- for label, label_token in label2token.items():
- label_token_indices = self.tokenizer.convert_tokens_to_ids(
- self.tokenizer.tokenize(label_token)
- )
- # sanity check: label_tokens should not be split up
- if len(label_token_indices) > 1:
- raise RuntimeError(f"{label_token} wrong split")
- else:
- label_token_idx = label_token_indices[0]
-
- label_representation = self.get_label_representation(label)
- source_indices = self.tokenizer.convert_tokens_to_ids(
- self.tokenizer.tokenize(label_representation)
- )
- if self.tokenizer.unk_token_id in source_indices:
- raise RuntimeError(
- f"tokenized label_token={label_token} [{source_indices}] contains unk_token"
- )
- self.label_embedding_weight_mapping[label_token_idx] = source_indices
-
- @property
- def pointer_offset(self) -> int:
- return len(self.targets)
-
- @property
- def target_ids(self) -> Set[int]:
- return set(range(self.pointer_offset))
-
- def configure_model_metric(self, stage: Optional[str] = None) -> Optional[Metric]:
- layer_metrics = {
- layer_name: PrecisionRecallAndF1ForLabeledAnnotations()
- for layer_name in self.layer_names
- }
-
- return WrappedLayerMetricsWithUnbatchAndDecodeWithErrorsFunction(
- unbatch_function=self.unbatch_output,
- decode_layers_with_errors_function=self.decode_annotations,
- layer_metrics=layer_metrics,
- error_key_correct=KEY_INVALID_CORRECT,
- )
-
- def reverse_relation(self, relation: Annotation) -> BinaryRelation:
- if isinstance(relation, BinaryRelation):
- reversed_label = relation.label
- if (
- reversed_label not in self.symmetric_relations
- and reversed_label != self.none_label
- ):
- reversed_label += self.REVERSED_RELATION_LABEL_SUFFIX
- reversed_rel = relation.copy(
- head=relation.tail, tail=relation.head, label=reversed_label
- )
- return reversed_rel
- else:
- raise Exception(f"reversing of relations of type {type(relation)} is not supported")
-
- def unreverse_relation(self, relation: Annotation) -> BinaryRelation:
- if isinstance(relation, BinaryRelation):
- head, tail, label = relation.head, relation.tail, relation.label
- # if the relation is symmetric, we sort head and tail to ensure consistent order
- if relation.label in self.symmetric_relations:
- head, tail = sorted([head, tail], key=lambda x: (x.start, x.end))
- # if the relation was reversed, we need to reconstruct the original label and swap head and tail
- elif label.endswith(self.REVERSED_RELATION_LABEL_SUFFIX):
- # reconstruct the original label and swap head and tail
- label = label[: -len(self.REVERSED_RELATION_LABEL_SUFFIX)]
- head, tail = tail, head
- return relation.copy(head=head, tail=tail, label=label)
- else:
- raise Exception(f"un-reversing of relations of type {type(relation)} is not supported")
-
- def encode_annotations(
- self, layers: Dict[str, Iterable[Annotation]], metadata: Optional[Dict[str, Any]] = None
- ) -> TaskOutputType:
- if not set(layers.keys()) == set(self.layer_names):
- raise Exception(f"unexpected layers: {layers.keys()}. expected: {self.layer_names}")
-
- if self.labels_per_layer is None:
- raise Exception("labels_per_layer is not defined. Call prepare() first or pass it in.")
-
- # encode relations
- all_relation_arguments = set()
- relation_arguments2label: Dict[Tuple[Annotation, ...], str] = dict()
- relation_encodings = dict()
- for rel in layers[self.relation_layer_name]:
- if not isinstance(rel, BinaryRelation):
- raise Exception(f"expected BinaryRelation, but got: {rel}")
- if rel.label in self.labels_per_layer[self.relation_layer_name]:
- if (rel.head, rel.tail) in relation_arguments2label:
- previous_label = relation_arguments2label[(rel.head, rel.tail)]
- if previous_label != rel.label:
- raise ValueError(
- f"relation {rel.head} -> {rel.tail} already exists, but has another label: "
- f"{previous_label} (current label: {rel.label})."
- )
- continue
- encoded_relation = self.relation_encoder_decoder.encode(
- annotation=rel, metadata=metadata
- )
- if encoded_relation is None:
- raise Exception(f"failed to encode relation: {rel}")
- relation_encodings[rel] = encoded_relation
- all_relation_arguments.update([rel.head, rel.tail])
- relation_arguments2label[(rel.head, rel.tail)] = rel.label
-
- # encode spans that are not arguments of any relation
- no_relation_spans = [
- span for span in layers[self.span_layer_name] if span not in all_relation_arguments
- ]
- for span in no_relation_spans:
- dummy_relation = BinaryRelation(
- head=span, tail=span, label=self.loop_dummy_relation_name
- )
- encoded_relation = self.relation_encoder_decoder.encode(
- annotation=dummy_relation, metadata=metadata
- )
- if encoded_relation is not None:
- relation_encodings[dummy_relation] = encoded_relation
-
- # sort relations by start indices of head and tail # TODO: is this correct?
- sorted_relations = sorted(relation_encodings, key=cmp_to_key(cmp_src_rel))
-
- # this should never be accessed as it is, so use negative pointer offset to provoke an error
- input_len = -self.pointer_offset - 1
- if self.create_constraints:
- if metadata is None or "src_len" not in metadata:
- raise Exception("metadata with 'src_len' is required to create constraints")
- input_len = metadata["src_len"]
-
- # build target_ids
- target_ids = []
- constraints_list = []
- for rel in sorted_relations:
- encoded_relation = relation_encodings[rel]
- target_ids.extend(encoded_relation)
-
- if self.create_constraints:
- # iterate over all prefixes of the relation encoding
- for idx, t in enumerate(encoded_relation):
- # get the constraints for the current prefix
- current_constraints = self._build_constraint(
- previous_ids=encoded_relation[:idx], input_len=input_len
- )
- # sanity check
- if current_constraints[t] == 0:
- raise Exception(
- f"current_constraints[{t}] is 0, but should be 1: {current_constraints}"
- )
- # add the constraints to the list
- constraints_list.append(current_constraints)
-
- target_ids.append(self.eos_id)
-
- if self.create_constraints:
- # add constraints for the eos_id
- eos_constraint = torch.zeros(input_len + self.pointer_offset, dtype=torch.int64)
- eos_constraint[self.eos_id] = 1
- constraints_list.append(eos_constraint)
- # combine all constraints
- constraints = torch.stack(constraints_list).tolist()
- else:
- constraints = None
-
- # sanity check
- _, encoding_errors, remaining = self.relation_encoder_decoder.parse(encoding=target_ids)
- if (
- not all(v == 0 for k, v in encoding_errors.items() if k != "correct")
- or len(remaining) > 0
- ):
- decoded, invalid = self.decode_annotations(LabelsAndOptionalConstraints(target_ids))
- not_encoded = {}
- for layer_name in layers:
- # convert to dicts to make them comparable (original annotations are attached which breaks comparison)
- decoded_dicts = [ann.asdict() for ann in decoded[layer_name]]
- # filter annotations and convert to str to make them json serializable
- filtered = {
- str(ann) for ann in layers[layer_name] if ann.asdict() not in decoded_dicts
- }
- if len(filtered) > 0:
- not_encoded[layer_name] = list(filtered)
- if len(not_encoded) > 0:
- logger.warning(
- f"encoding errors: {encoding_errors}, skipped annotations:\n"
- f"{json.dumps(not_encoded, sort_keys=True, indent=2)}"
- )
- elif len([tag for tag in remaining if tag != self.eos_id]) > 0:
- logger.warning(
- f"encoding errors: {encoding_errors}, remaining encoding ids: {remaining}"
- )
-
- return LabelsAndOptionalConstraints(labels=target_ids, constraints=constraints)
-
- def decode_annotations(
- self, encoding: TaskOutputType
- ) -> Tuple[Dict[str, Iterable[Annotation]], Dict[str, int]]:
- decoded_relations, errors, remaining = self.relation_encoder_decoder.parse(
- encoding=encoding.labels
- )
- relation_tuples: List[Tuple[Annotation, Annotation, str]] = []
- entity_labels: Dict[Annotation, List[str]] = defaultdict(list)
- for rel in decoded_relations:
- head_dummy = rel.head.copy(label="dummy")
- entity_labels[head_dummy].append(rel.head.label)
-
- if rel.label != self.loop_dummy_relation_name:
- tail_dummy = rel.tail.copy(label="dummy")
- entity_labels[tail_dummy].append(rel.tail.label)
- relation_tuples.append((head_dummy, tail_dummy, rel.label))
- else:
- assert rel.head == rel.tail
-
- # It may happen that some spans take part in multiple relations, but got generated with different labels.
- # In this case, we just create one span and take the most common label.
- entities: Dict[Annotation, Annotation] = {}
- for entity_dummy, labels in entity_labels.items():
- c = Counter(labels)
- # if len(c) > 1:
- # logger.warning(f"multiple labels for span, take the most common: {dict(c)}")
- most_common_label = c.most_common(1)[0][0]
- entities[entity_dummy] = entity_dummy.copy(label=most_common_label)
-
- entity_layer = list(entities.values())
- relation_layer = [
- BinaryRelation(head=entities[head_dummy], tail=entities[tail_dummy], label=label)
- for head_dummy, tail_dummy, label in relation_tuples
- ]
- return {
- self.span_layer_name: entity_layer,
- self.relation_layer_name: relation_layer,
- }, errors
-
- def _build_constraint(
- self,
- previous_ids: List[int],
- input_len: int,
- ) -> torch.LongTensor:
- """Build a constraint for the decoder. The constraint is a binary mask that indicates which
- ids are allowed to be predicted in the next decoding step. The mask is of size input_len +
- pointer_offset, where input_len is the length of the input sequence and pointer_offset is
- the number of labels and special tokens. Uses the relation_encoder_decoder to build the
- actual constraints.
-
- Args:
- previous_ids: previously decoded ids
- input_len: length of the input sequence
-
- Returns:
- A binary mask of size input_len + pointer_offset, where 1 indicates that the id is
- allowed to be predicted next, and 0 indicates that the id is not allowed to be predicted next.
- """
- result: torch.LongTensor = torch.zeros(input_len + self.pointer_offset, dtype=torch.int64)
- if self.eos_id in previous_ids:
- # once eos is predicted, only allow padding
- result[self.target_pad_id] = 1
- return result
-
- allowed_ids, disallowed_ids = self.relation_encoder_decoder.build_decoding_constraints(
- partial_encoding=previous_ids
- )
- if allowed_ids is not None and disallowed_ids is not None:
- raise Exception(
- f"allowed_ids and disallowed_ids are both not None: {allowed_ids}, {disallowed_ids}"
- )
- elif allowed_ids is not None:
- for allowed_id in allowed_ids:
- result[allowed_id] = 1
- elif disallowed_ids is not None:
- for id in range(len(result)):
- if id not in disallowed_ids:
- result[id] = 1
- else:
- raise Exception(
- f"allowed_ids and disallowed_ids are both None: {allowed_ids}, {disallowed_ids}"
- )
- if len(previous_ids) == 0:
- # if there are no previous ids, we also allow the eos_id
- result[self.eos_id] = 1
- else:
- # if there are previous ids, we don't allow the eos_id
- result[self.eos_id] = 0
- # never allow the bos_id
- result[self.bos_id] = 0
-
- return result
-
- def maybe_log_example(
- self,
- task_encoding: TaskEncodingType,
- targets: Optional[TargetEncodingType] = None,
- ):
- if self.log_first_n_examples is not None and self.log_first_n_examples > 0:
- tokenized_doc_id = task_encoding.metadata["tokenized_document"].id
- inputs = task_encoding.inputs
- targets = targets or task_encoding.targets
- input_tokens = self.tokenizer.convert_ids_to_tokens(inputs.input_ids)
- label_tokens = [
- (
- self.targets[target_id_or_offset]
- if target_id_or_offset < self.pointer_offset
- else str(target_id_or_offset)
- + " {"
- + str(input_tokens[target_id_or_offset - self.pointer_offset])
- + "}"
- )
- for target_id_or_offset in targets.labels
- ]
- logger.info("*** Example ***")
- logger.info(f"doc.id: {tokenized_doc_id}")
- logger.info(f"input_ids: {' '.join([str(i) for i in inputs.input_ids])}")
- logger.info(f"input_tokens: {' '.join(input_tokens)}")
- logger.info(f"label_ids: {' '.join([str(i) for i in targets.labels])}")
- logger.info(f"label_tokens: {' '.join(label_tokens)}")
- if self.create_constraints:
- # only show the shape because the content is not very readable
- logger.info(
- f"constraints: {torch.tensor(targets.constraints).shape} (content is omitted)"
- )
- self.log_first_n_examples -= 1
-
- def tokenize_document(self, document: DocumentType) -> List[TokenBasedDocument]:
- field_mapping = dict(self.annotation_field_mapping)
- if self.partition_layer_name is not None:
- field_mapping[self.partition_layer_name] = "labeled_partitions"
- partition_layer = "labeled_partitions"
- else:
- partition_layer = None
- casted_document = document.as_type(self.document_type, field_mapping=field_mapping)
- tokenized_docs = tokenize_document(
- casted_document,
- tokenizer=self.tokenizer,
- result_document_type=self.tokenized_document_type,
- partition_layer=partition_layer,
- **self.tokenizer_kwargs,
- )
- for idx, tokenized_doc in enumerate(tokenized_docs):
- tokenized_doc.id = f"{document.id}-tokenized-{idx+1}-of-{len(tokenized_docs)}"
-
- return tokenized_docs
-
- def encode_input(
- self, document: DocumentType, is_training: bool = False
- ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
- tokenized_docs = self.tokenize_document(document)
- task_encodings: List[TaskEncodingType] = []
- for tokenized_doc in tokenized_docs:
- tokenizer_encoding = tokenized_doc.metadata["tokenizer_encoding"]
- task_encodings.append(
- TaskEncoding(
- document=document,
- inputs=InputEncodingType(
- input_ids=tokenizer_encoding.ids,
- attention_mask=tokenizer_encoding.attention_mask,
- ),
- metadata={"tokenized_document": tokenized_doc},
- )
- )
-
- return task_encodings
-
- def get_mapped_layer(self, document: Document, layer_name: str) -> AnnotationLayer:
- if layer_name in self.annotation_field_mapping:
- layer_name = self.annotation_field_mapping[layer_name]
- return document[layer_name]
-
- def encode_target(self, task_encoding: TaskEncodingType) -> Optional[TargetEncodingType]:
- try:
- document = task_encoding.metadata["tokenized_document"]
-
- layers = {
- layer_name: self.get_mapped_layer(document, layer_name=layer_name)
- for layer_name in self.layer_names
- }
-
- if self.add_reversed_relations:
- # create a copy to avoid modifying the annotation layer in the document
- relations = list(layers[self.relation_layer_name])
- reversed_relations = [self.reverse_relation(rel) for rel in relations]
- layers[self.relation_layer_name] = relations + reversed_relations
-
- result = self.encode_annotations(
- layers=layers,
- metadata={
- **task_encoding.metadata,
- "src_len": len(task_encoding.inputs.input_ids),
- },
- )
-
- self.maybe_log_example(task_encoding=task_encoding, targets=result)
- return result
- except Exception as e:
- logger.error(f"failed to encode target, it will be skipped: {e}")
- return None
-
- def collate(self, task_encodings: Sequence[TaskEncodingType]) -> TaskBatchEncoding:
- if len(task_encodings) == 0:
- raise ValueError("no task_encodings available")
- inputs = InputEncodingType.batch(
- values=[x.inputs for x in task_encodings],
- dtypes=self.dtypes,
- pad_values=self.pad_values,
- )
-
- targets = None
- if task_encodings[0].has_targets:
- targets = TargetEncodingType.batch(
- values=[x.targets for x in task_encodings],
- dtypes=self.dtypes,
- pad_values=self.pad_values,
- )
-
- return inputs, targets
-
- def unbatch_output(self, model_output: ModelBatchOutput) -> Sequence[TaskOutputType]:
- labels = model_output["labels"]
- batch_size = labels.size(0)
-
- # We use the position after the first eos token as the seq_len.
- # Note that, if eos_id is not in model_output for a given batch item, the result will be
- # model_output.size(1) + 1 (i.e. seq_len + 1) for that batch item. This is fine, because we use the
- # seq_lengths just to truncate the output and want to keep everything if eos_id is not present.
- seq_lengths = get_first_occurrence_index(labels, self.eos_id) + 1
-
- result = [
- LabelsAndOptionalConstraints(labels[i, : seq_lengths[i]].to(device="cpu").tolist())
- for i in range(batch_size)
- ]
- return result
-
- def create_annotations_from_output(
- self,
- task_encoding: TaskEncodingType,
- task_output: TaskOutputType,
- ) -> Iterator[Tuple[str, Annotation]]:
- layers, errors = self.decode_annotations(
- encoding=task_output, # metadata=task_encoding.metadata
- )
- tokenized_document = task_encoding.metadata["tokenized_document"]
-
- # Note: token_based_document_to_text_based() does not yet consider predictions, so we need to clear
- # the main annotations and attach the predictions to that
- for layer_name, annotations in layers.items():
- layer = self.get_mapped_layer(tokenized_document, layer_name=layer_name)
- layer.clear()
- layer.extend(annotations)
-
- untokenized_document = token_based_document_to_text_based(
- tokenized_document, result_document_type=self.document_type
- )
-
- for layer_name in layers:
- annotations = self.get_mapped_layer(untokenized_document, layer_name=layer_name)
- for annotation in annotations:
- # handle relations that may be reversed
- if layer_name == self.relation_layer_name and self.add_reversed_relations:
- unreversed_relation = self.unreverse_relation(annotation)
- yield layer_name, unreversed_relation
- else:
- yield layer_name, annotation.copy()
diff --git a/src/pie_modules/taskmodules/re_span_pair_classification.py b/src/pie_modules/taskmodules/re_span_pair_classification.py
deleted file mode 100644
index ad259bc3d..000000000
--- a/src/pie_modules/taskmodules/re_span_pair_classification.py
+++ /dev/null
@@ -1,829 +0,0 @@
-"""
-workflow:
- Document
- -> (InputEncoding, TargetEncoding) -> TaskEncoding -> TaskBatchEncoding
- -> ModelBatchEncoding -> ModelBatchOutput
- -> TaskOutput
- -> Document
-"""
-
-import logging
-from collections import defaultdict
-from copy import deepcopy
-from typing import (
- Any,
- Dict,
- Iterable,
- Iterator,
- List,
- Optional,
- Sequence,
- Set,
- Tuple,
- Type,
- TypedDict,
- Union,
-)
-
-import pandas as pd
-import torch
-from pie_core import (
- Annotation,
- AnnotationLayer,
- Document,
- TaskEncoding,
- TaskModule,
-)
-from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize
-from tokenizers import AddedToken
-from torch import BoolTensor, LongTensor, Tensor
-from torch.nn.utils.rnn import pad_sequence
-from torchmetrics import ClasswiseWrapper, F1Score, Metric, MetricCollection
-from transformers import AutoTokenizer
-from typing_extensions import TypeAlias
-
-from pie_modules.annotations import (
- BinaryRelation,
- LabeledSpan,
- MultiLabeledBinaryRelation,
- NaryRelation,
-)
-from pie_modules.document.processing import (
- token_based_document_to_text_based,
- tokenize_document,
-)
-from pie_modules.documents import (
- TextBasedDocument,
- TextDocumentWithLabeledPartitions,
- TextDocumentWithLabeledSpansAndBinaryRelations,
- TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
- TokenDocumentWithLabeledSpansAndBinaryRelations,
- TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-)
-from pie_modules.taskmodules.metrics import WrappedMetricWithPrepareFunction
-from pie_modules.utils.span import distance as get_span_distance
-
-PAD_VALUES = {
- "input_ids": 0,
- "attention_mask": 0,
- "span_start_indices": 0,
- "span_end_indices": 0,
- "tuple_indices": -1,
- "labels": -100,
- "tuple_indices_mask": False,
-}
-DTYPES = {
- "input_ids": torch.long,
- "attention_mask": torch.long,
- "span_start_indices": torch.long,
- "span_end_indices": torch.long,
- "tuple_indices": torch.long,
- "labels": torch.long,
- "tuple_indices_mask": torch.bool,
-}
-
-
-class InputEncodingType(TypedDict, total=False):
- # shape: (sequence_length,)
- input_ids: LongTensor
- # shape: (sequence_length,)
- attention_mask: LongTensor
- # shape: (num_entities,)
- span_start_indices: LongTensor
- # shape: (num_entities,)
- span_end_indices: LongTensor
- # list of lists of argument indices: [[head_idx, tail_idx], ...]
- # NOTE: these indices point into span_start_indices and span_end_indices!
- tuple_indices: LongTensor
- tuple_indices_mask: BoolTensor
-
-
-class TargetEncodingType(TypedDict, total=False):
- # list of label indices: [label_idx, ...]
- labels: LongTensor
-
-
-DocumentType: TypeAlias = TextBasedDocument
-TaskEncodingType: TypeAlias = TaskEncoding[
- DocumentType,
- InputEncodingType,
- TargetEncodingType,
-]
-
-
-class TaskOutputType(TypedDict, total=False):
- labels: Sequence[str]
- probabilities: Sequence[float]
-
-
-class ModelInputType(TypedDict, total=False):
- input_ids: LongTensor
- attention_mask: LongTensor
- span_start_indices: LongTensor
- span_end_indices: LongTensor
- tuple_indices: LongTensor
- tuple_indices_mask: BoolTensor
-
-
-class ModelTargetType(TypedDict, total=False):
- labels: LongTensor
- probabilities: LongTensor
-
-
-TaskModuleType: TypeAlias = TaskModule[
- # _InputEncoding, _TargetEncoding, _TaskBatchEncoding, _ModelBatchOutput, _TaskOutput
- DocumentType,
- InputEncodingType,
- TargetEncodingType,
- Tuple[ModelInputType, Optional[ModelTargetType]],
- ModelTargetType,
- TaskOutputType,
-]
-
-
-HEAD = "head"
-TAIL = "tail"
-START = "start"
-END = "end"
-
-
-logger = logging.getLogger(__name__)
-
-
-def _get_label_ids_from_model_output(
- model_output: ModelTargetType,
-) -> LongTensor:
- return model_output["labels"]
-
-
-def get_relation_argument_spans_and_roles(
- relation: Annotation,
-) -> Tuple[Tuple[str, Annotation], ...]:
- if isinstance(relation, BinaryRelation):
- return (HEAD, relation.head), (TAIL, relation.tail)
- elif isinstance(relation, NaryRelation):
- # create unique order by sorting the arguments by their start and end positions and role
- sorted_args = sorted(
- zip(relation.roles, relation.arguments),
- key=lambda role_and_span: (
- role_and_span[1].start,
- role_and_span[1].end,
- role_and_span[0],
- ),
- )
- return tuple(sorted_args)
- else:
- raise NotImplementedError(
- f"the taskmodule does not yet support getting relation arguments for type: {type(relation)}"
- )
-
-
-def construct_argument_marker(pos: str, label: Optional[str] = None, role: str = "SPAN") -> str:
- if pos not in [START, END]:
- raise ValueError(f"pos must be one of {START} or {END}, but got: {pos}")
- start_or_end_marker = "" if pos == START else "/"
- if label is not None:
- return f"[{start_or_end_marker}{role}:{label}]"
- else:
- return f"[{start_or_end_marker}{role}]"
-
-
-def inject_markers_into_text(
- text: str, positions_and_markers: List[Tuple[int, str]]
-) -> Tuple[str, Dict[int, int]]:
- offset = 0
- original2new_pos = dict()
- for original_pos, marker in sorted(positions_and_markers):
- text = text[: original_pos + offset] + marker + text[original_pos + offset :]
- offset += len(marker)
- original2new_pos[original_pos] = original_pos + offset
- return text, original2new_pos
-
-
-def to_tensor(key: str, value: Any) -> Tensor:
- return torch.tensor(value, dtype=DTYPES[key])
-
-
-def pad_or_stack(key: str, values: List[LongTensor]) -> Tensor:
- if key in PAD_VALUES:
- max_last_dim = None
- if key == "tuple_indices":
- max_last_dim = max(v.shape[-1] for v in values if len(v.shape) == 2)
- values = [v.reshape(-1) for v in values]
- result = pad_sequence(values, batch_first=True, padding_value=PAD_VALUES[key])
- if key == "tuple_indices":
- batch_size = len(values)
- result = result.reshape(batch_size, -1, max_last_dim)
- else:
- result = torch.stack(values, dim=0)
- return result
-
-
-@TaskModule.register()
-class RESpanPairClassificationTaskModule(TaskModuleType, ChangesTokenizerVocabSize):
- """Task module for relation extraction as span pair classification.
-
- This task module frames relation extraction as a span pair classification task where all candidate
- pairs in a given text are classified at once. The task module injects start and end markers for
- each entity (i.e. "[SPAN]" and "[/SPAN]") into the text and tokenizes the text (the markers are
- handled as special tokens, and thus, kept as they are). It then collects the start- and end-marker
- positions for each entity and constructs a model input encoding from the tokenized text and these
- positions. The model target encoding consists of a list of label indices and a list of tuples
- (head and tail) of argument indices that point into the start- and end-marker positions from the
- model inputs. The model output is expected to be of the same format as the model target encoding,
- but with probabilities for each label.
-
- This means, that the model should return only positive relations (argument indices + label) and
- discard all negative ones.
-
- Args:
- tokenizer_name_or_path: The name or path of the tokenizer to use.
- relation_annotation: The name of the annotation layer that contains the binary relations.
- partition_annotation: The name of the annotation layer that contains the labeled partitions.
- If provided, the task module expects the document to have a partition layer with the
- given name containing LabeledSpans. These entries are used to split the text into
- partitions, e.g. paragraphs or sentences, that are treated as separate documents during
- tokenization. Defaults to None.
- tokenize_kwargs: Additional keyword arguments passed to the tokenizer during tokenization.
- create_candidate_relations: Whether to create candidate relations for training. If True, the
- task module creates all possible pairs of entities in the text as candidate relations.
- Defaults to False.
- create_candidate_relations_kwargs: Additional keyword arguments passed to the method that
- creates the candidate relations (e.g. max_argument_distance). Defaults to None.
- labels: The list of relation labels. If not provided, the task module will collect the labels
- from the documents during preparation. Defaults to None.
- entity_labels: The list of entity labels. If not provided, the task module will collect the
- entity labels from the documents during preparation. Defaults to None.
- add_type_to_marker: Whether to add the entity type to the markers. If True, the markers will
- look like this: "[SPAN:entity_type]" and "[/SPAN:entity_type]" where entity_type is the
- type of the respective entity. Defaults to False.
- log_first_n_examples: The number of examples to log during training. If 0, no examples are logged.
- Defaults to 0.
- collect_statistics: Whether to collect statistics during preparation. If True, the task module
- will collect statistics about the available, used, and skipped relations. Defaults to False.
- """
-
- PREPARED_ATTRIBUTES = ["labels", "entity_labels"]
-
- def __init__(
- self,
- tokenizer_name_or_path: str,
- relation_annotation: str = "binary_relations",
- no_relation_label: str = "no_relation",
- partition_annotation: Optional[str] = None,
- tokenize_kwargs: Optional[Dict[str, Any]] = None,
- create_candidate_relations: bool = False,
- create_candidate_relations_kwargs: Optional[Dict[str, Any]] = None,
- labels: Optional[List[str]] = None,
- entity_labels: Optional[List[str]] = None,
- add_type_to_marker: bool = True,
- log_first_n_examples: int = 0,
- collect_statistics: bool = False,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.save_hyperparameters()
-
- self.relation_annotation = relation_annotation
- self.no_relation_label = no_relation_label
- self.tokenize_kwargs = tokenize_kwargs or {}
- self.create_candidate_relations = create_candidate_relations
- self.create_candidate_relations_kwargs = create_candidate_relations_kwargs or {}
- self.labels = labels
- self.add_type_to_marker = add_type_to_marker
- self.entity_labels = entity_labels
- self.partition_annotation = partition_annotation
- # overwrite None with 0 for backward compatibility
- self.log_first_n_examples = log_first_n_examples or 0
- self.collect_statistics = collect_statistics
-
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
-
- self.argument_markers = None
-
- self._logged_examples_counter = 0
-
- self.reset_statistics()
-
- @property
- def document_type(self) -> Optional[Type[DocumentType]]:
- if self.partition_annotation is not None:
- dt = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
- else:
- dt = TextDocumentWithLabeledSpansAndBinaryRelations
- if self.relation_annotation == "binary_relations":
- return dt
- else:
- logger.warning(
- f"relation_annotation={self.relation_annotation} is "
- f"not the default value ('binary_relations'), so the taskmodule {type(self).__name__} can not request "
- f"the usual document type for auto-conversion ({dt.__name__}) because this has the bespoken default "
- f"value as layer name instead of the provided one."
- )
- return None
-
- @property
- def tokenized_document_type(self) -> Type[TokenDocumentWithLabeledSpansAndBinaryRelations]:
- if self.partition_annotation is None:
- return TokenDocumentWithLabeledSpansAndBinaryRelations
- else:
- return TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-
- @property
- def normalized_document_type(self) -> Type[TextDocumentWithLabeledSpansAndBinaryRelations]:
- if self.partition_annotation is None:
- return TextDocumentWithLabeledSpansAndBinaryRelations
- else:
- return TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-
- def normalize_document(self, document) -> TextDocumentWithLabeledSpansAndBinaryRelations:
- span_layer_name = self.get_span_layer_name(document)
- field_mapping = {
- span_layer_name: "labeled_spans",
- self.relation_annotation: "binary_relations",
- }
- if self.partition_annotation is not None:
- field_mapping[self.partition_annotation] = "labeled_partitions"
- casted_document = document.as_type(
- self.normalized_document_type, field_mapping=field_mapping
- )
- return casted_document
-
- def get_relation_layer(self, document: Document) -> AnnotationLayer[BinaryRelation]:
- return document[self.relation_annotation]
-
- def get_span_layer_name(self, document: Document) -> str:
- return document[self.relation_annotation].target_name
-
- def get_entity_layer(self, document: Document) -> AnnotationLayer[LabeledSpan]:
- relations: AnnotationLayer[BinaryRelation] = self.get_relation_layer(document)
- return relations.target_layer
-
- def _prepare(self, documents: Sequence[DocumentType]) -> None:
- entity_labels: Set[str] = set()
- relation_labels: Set[str] = set()
- for document in documents:
- relations: AnnotationLayer[BinaryRelation] = self.get_relation_layer(document)
- entities: AnnotationLayer[LabeledSpan] = self.get_entity_layer(document)
-
- for entity in entities:
- entity_labels.add(entity.label)
-
- for relation in relations:
- relation_labels.add(relation.label)
-
- if self.no_relation_label in relation_labels:
- relation_labels.remove(self.no_relation_label)
-
- self.labels = sorted(relation_labels)
- self.entity_labels = sorted(entity_labels)
-
- def reset_statistics(self):
- self._statistics = defaultdict(int)
- self._collected_relations: Dict[str, List[Annotation]] = defaultdict(list)
-
- def collect_relation(self, kind: str, relation: Annotation):
- if self.collect_statistics:
- self._collected_relations[kind].append(relation)
-
- def collect_all_relations(self, kind: str, relations: Iterable[Annotation]):
- if self.collect_statistics:
- self._collected_relations[kind].extend(relations)
-
- def finalize_statistics(self):
- if self.collect_statistics:
- all_relations = set(self._collected_relations["available_tokenized"])
- used_relations = set(self._collected_relations["used"])
- skipped_other = all_relations - used_relations
- for key, rels in self._collected_relations.items():
- rels_set = set(rels)
- if key.startswith("skipped_"):
- skipped_other -= rels_set
- elif key.startswith("used_"):
- pass
- elif key in ["available", "available_tokenized", "used"]:
- pass
- else:
- raise ValueError(f"unknown key: {key}")
- for rel in rels_set:
- self.increase_counter(key=(key, rel.label))
- for rel in skipped_other:
- self.increase_counter(key=("skipped_other", rel.label))
-
- def show_statistics(self):
- if self.collect_statistics:
- self.finalize_statistics()
-
- to_show = pd.Series(self._statistics)
- if len(to_show.index.names) > 1:
- to_show = to_show.unstack()
- logger.info(f"statistics:\n{to_show.to_markdown()}")
-
- def increase_counter(self, key: Tuple[Any, ...], value: Optional[int] = 1):
- if self.collect_statistics:
- key_str = tuple(str(k) for k in key)
- self._statistics[key_str] += value
-
- def encode(self, *args, **kwargs):
- self.reset_statistics()
- res = super().encode(*args, **kwargs)
- self.show_statistics()
- return res
-
- def collect_argument_markers(self, entity_labels: Iterable[str]) -> List[str]:
- argument_markers: Set[str] = set()
- for arg_pos in [START, END]:
- if self.add_type_to_marker:
- for entity_label in entity_labels:
- argument_markers.add(
- construct_argument_marker(pos=arg_pos, label=entity_label)
- )
- else:
- argument_markers.add(construct_argument_marker(pos=arg_pos))
-
- return sorted(list(argument_markers))
-
- def _post_prepare(self):
- self.label_to_id = {label: i + 1 for i, label in enumerate(self.labels)}
- self.label_to_id[self.no_relation_label] = 0
- self.id_to_label = {v: k for k, v in self.label_to_id.items()}
-
- self.argument_markers = self.collect_argument_markers(self.entity_labels)
- num_added = self.tokenizer.add_special_tokens(
- {"additional_special_tokens": self.argument_markers}
- )
- if len(self.argument_markers) != num_added:
- logger.warning(
- f"expected to add {len(self.argument_markers)} argument markers, but added {num_added}. It seems "
- f"that the tokenizer already contains some of the argument markers."
- )
-
- self.argument_markers_to_id = {
- marker: self.tokenizer.vocab[marker] for marker in self.argument_markers
- }
-
- def _create_candidate_relations(
- self,
- document: TokenDocumentWithLabeledSpansAndBinaryRelations,
- max_argument_distance: Optional[int] = None,
- argument_distance_type: str = "inner",
- ) -> Sequence[Annotation]:
- # TODO: ensure that the relation layer type is BinaryRelation!
- labeled_spans = document.labeled_spans
- candidate_relations = []
- for i, head in enumerate(labeled_spans):
- for j, tail in enumerate(labeled_spans):
- if i == j:
- continue
- rel = BinaryRelation(head=head, tail=tail, label=self.no_relation_label)
- if max_argument_distance is not None:
- arg_distance = get_span_distance(
- start_end=(head.start, head.end),
- other_start_end=(tail.start, tail.end),
- distance_type=argument_distance_type,
- )
- if arg_distance > max_argument_distance:
- self.collect_relation("skipped_argument_distance", rel)
- continue
- candidate_relations.append(rel)
- return candidate_relations
-
- def inject_markers_for_labeled_spans(
- self,
- document: TextDocumentWithLabeledSpansAndBinaryRelations,
- ) -> Tuple[TextDocumentWithLabeledSpansAndBinaryRelations, Dict[LabeledSpan, LabeledSpan]]:
- # collect markers and injection positions
- positions_and_markers = []
- for labeled_span in document.labeled_spans:
- label_or_none = labeled_span.label if self.add_type_to_marker else None
- start_marker = construct_argument_marker(pos=START, label=label_or_none)
- positions_and_markers.append((labeled_span.start, start_marker))
- end_marker = construct_argument_marker(pos=END, label=label_or_none)
- positions_and_markers.append((labeled_span.end, end_marker))
-
- if isinstance(document, TextDocumentWithLabeledPartitions):
- # create "dummy" markers for the partitions so that entries for these positions are created
- # in original2new_pos
- for labeled_partition in document.labeled_partitions:
- positions_and_markers.append((labeled_partition.start, ""))
- positions_and_markers.append((labeled_partition.end, ""))
-
- # inject markers into the text
- marked_text, original2new_pos = inject_markers_into_text(
- document.text, positions_and_markers
- )
-
- # construct new spans
- old2new_spans = dict()
- for labeled_span in document.labeled_spans:
- start = original2new_pos[labeled_span.start]
- end = original2new_pos[labeled_span.end]
- new_span = LabeledSpan(start=start, end=end, label=labeled_span.label)
- old2new_spans[labeled_span] = new_span
-
- # construct new relations
- old2new_relations = dict()
- for relation in document.binary_relations:
- if isinstance(relation, BinaryRelation):
- head = old2new_spans[relation.head]
- tail = old2new_spans[relation.tail]
- new_relation = BinaryRelation(head=head, tail=tail, label=relation.label)
- else:
- raise NotImplementedError(
- f"the taskmodule does not yet support relations of type {type(relation)}"
- )
- old2new_relations[relation] = new_relation
-
- # construct new document
- new_document = type(document)(
- id=document.id,
- metadata=deepcopy(document.metadata),
- text=marked_text,
- )
- new_document.labeled_spans.extend(old2new_spans.values())
- new_document.binary_relations.extend(old2new_relations.values())
- if isinstance(document, TextDocumentWithLabeledPartitions):
- for labeled_partition in document.labeled_partitions:
- new_start = original2new_pos[labeled_partition.start]
- new_end = original2new_pos[labeled_partition.end]
- new_labeled_partitions = labeled_partition.copy(start=new_start, end=new_end)
- new_document.labeled_partitions.append(new_labeled_partitions)
-
- new2old_spans = {new_span: old_span for old_span, new_span in old2new_spans.items()}
- return new_document, new2old_spans
-
- def encode_input(
- self,
- document: DocumentType,
- is_training: bool = False,
- ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
- self.collect_all_relations("available", self.get_relation_layer(document))
-
- # 1. inject start and end markers for each entity into the text
- # - save mapping from new entities to original entities
- # 2. tokenize the text
- # - add the marker tokens to the tokenizer as special tokens
- # - tokenize with tokenize_document()
- # 3. get start- and end-token positions for each entity
- # 4. construct task encoding from tokenized text and entity positions
-
- normalized_document = self.normalize_document(document)
- document_with_markers, injected2original_spans = self.inject_markers_for_labeled_spans(
- normalized_document
- )
- all_added_annotations: List[Dict[str, Dict[Annotation, Annotation]]] = []
- tokenized_docs = tokenize_document(
- document_with_markers,
- tokenizer=self.tokenizer,
- result_document_type=self.tokenized_document_type,
- partition_layer=(
- "labeled_partitions" if self.partition_annotation is not None else None
- ),
- added_annotations=all_added_annotations,
- strict_span_conversion=False,
- **self.tokenize_kwargs,
- )
-
- task_encodings: List[TaskEncodingType] = []
- for tokenized_doc, tokenized_annotations in zip(tokenized_docs, all_added_annotations):
- self.collect_all_relations("available_tokenized", tokenized_doc.binary_relations)
- # collect start- and end-token positions for each entity
- span_start_indices = []
- span_end_indices = []
- for labeled_span in tokenized_doc.labeled_spans:
- # the start marker is one token before the start of the span
- span_start_indices.append(labeled_span.start - 1)
- # the end marker is one token after the end of the span, but the end index is exclusive
- span_end_indices.append(labeled_span.end)
-
- labeled_span2idx = {span: idx for idx, span in enumerate(tokenized_doc.labeled_spans)}
- tuple_indices = [] # list of lists of argument indices: [[head_idx, tail_idx], ...]
- if self.create_candidate_relations:
- candidate_relations = self._create_candidate_relations(
- tokenized_doc, **self.create_candidate_relations_kwargs
- )
- else:
- candidate_relations = tokenized_doc.binary_relations
-
- # if there are no candidate relations, skip the whole (tokenized) document
- if len(candidate_relations) == 0:
- continue
-
- for relation in candidate_relations:
- current_args_indices = []
- for _, arg_span in get_relation_argument_spans_and_roles(relation):
- arg_idx = labeled_span2idx[arg_span]
- current_args_indices.append(arg_idx)
- tuple_indices.append(current_args_indices)
-
- encoding = tokenized_doc.metadata["tokenizer_encoding"]
- inputs = {
- "input_ids": encoding.ids,
- "attention_mask": encoding.attention_mask,
- "span_start_indices": span_start_indices,
- "span_end_indices": span_end_indices,
- "tuple_indices": tuple_indices,
- "tuple_indices_mask": [True] * len(tuple_indices),
- }
- inputs_tensors = {k: to_tensor(k, v) for k, v in inputs.items()}
- task_encodings.append(
- TaskEncoding(
- document=document,
- inputs=inputs_tensors,
- metadata={
- "tokenized_document": tokenized_doc,
- "injected2original_spans": injected2original_spans,
- "candidate_relations": candidate_relations,
- "tokenized_annotations": tokenized_annotations,
- },
- )
- )
-
- return task_encodings
-
- def encode_target(
- self,
- task_encoding: TaskEncodingType,
- ) -> TargetEncodingType:
- gold_relations = task_encoding.metadata["tokenized_document"].binary_relations
- gold_roles_and_args2relation = defaultdict(list)
- for relation in gold_relations:
- # If we manually set the labels, we only consider relations with a label in the label_to_id mapping
- # This allows us to ignore relations with certain labels during training.
- if relation.label in self.label_to_id:
- gold_roles_and_args2relation[
- get_relation_argument_spans_and_roles(relation)
- ].append(relation)
- label_indices = [] # list of label indices
- candidate_relations = []
- for candidate_relation in task_encoding.metadata["candidate_relations"]:
- candidate_roles_and_args = get_relation_argument_spans_and_roles(candidate_relation)
- gold_relations = gold_roles_and_args2relation.get(candidate_roles_and_args, [])
- if len(gold_relations) == 0:
- label_idx = self.label_to_id[candidate_relation.label]
- self.collect_relation("used", candidate_relation)
- elif len(gold_relations) == 1:
- label_idx = self.label_to_id[gold_relations[0].label]
- self.collect_relation("used", gold_relations[0])
- else:
- # TODO: or should we add all gold relations with the same arguments?
- logger.warning(
- f"skip the candidate relation because there are more than one gold relation "
- f"for its args and roles: {gold_relations}"
- )
- for gold_relation in gold_relations:
- self.collect_relation("skipped_same_arguments", gold_relation)
- label_idx = PAD_VALUES["labels"]
-
- label_indices.append(label_idx)
- candidate_relations.append(candidate_relation)
-
- task_encoding.metadata["candidate_relations"] = candidate_relations
- target: TargetEncodingType = {"labels": to_tensor("labels", label_indices)}
-
- self._maybe_log_example(task_encoding=task_encoding, target=target)
-
- return target
-
- def _maybe_log_example(
- self,
- task_encoding: TaskEncodingType,
- target: TargetEncodingType,
- ):
- """Maybe log the example."""
-
- # log the first n examples
- if self._logged_examples_counter < self.log_first_n_examples:
- input_ids = task_encoding.inputs["input_ids"]
- tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
- logger.info("*** Example ***")
- logger.info(f"doc id: {task_encoding.document.id}")
- logger.info(f"tokens: {' '.join([x for x in tokens])}")
- logger.info(f"input_ids: {' '.join([str(x) for x in input_ids.tolist()])}")
- # target data
- span_start_indices = task_encoding.inputs["span_start_indices"]
- span_end_indices = task_encoding.inputs["span_end_indices"]
- labels = [self.id_to_label[label] for label in target["labels"].tolist()]
- for i, (label, tuple_indices) in enumerate(
- zip(labels, task_encoding.inputs["tuple_indices"])
- ):
- logger.info(f"relation {i}: {label}")
- for j, arg_idx in enumerate(tuple_indices):
- arg_tokens = tokens[span_start_indices[arg_idx] : span_end_indices[arg_idx]]
- logger.info(f"\targ {j}: {' '.join([str(x) for x in arg_tokens])}")
-
- self._logged_examples_counter += 1
-
- def collate(
- self, task_encodings: Sequence[TaskEncodingType]
- ) -> Tuple[ModelInputType, Optional[ModelTargetType]]:
- input_keys = task_encodings[0].inputs.keys()
- inputs: ModelInputType = { # type: ignore
- key: pad_or_stack(key, [task_encoding.inputs[key] for task_encoding in task_encodings])
- for key in input_keys
- }
-
- targets: Optional[ModelTargetType] = None
- if task_encodings[0].has_targets:
- target_keys = task_encodings[0].targets.keys()
- targets: ModelTargetType = { # type: ignore
- key: pad_or_stack(
- key, [task_encoding.targets[key] for task_encoding in task_encodings]
- )
- for key in target_keys
- }
-
- return inputs, targets
-
- def unbatch_output(self, model_output: ModelTargetType) -> Sequence[TaskOutputType]:
- # shape: (batch_size, num_candidates)
- label_ids = model_output["labels"].detach().cpu().tolist()
- # shape: (batch_size, num_candidates, num_labels)
- all_probabilities = model_output["probabilities"].detach().cpu().tolist()
- unbatched_output = []
- for batch_idx in range(len(label_ids)):
- labels = []
- probabilities = []
- for label_id, probs in zip(label_ids[batch_idx], all_probabilities[batch_idx]):
- labels.append(self.id_to_label[label_id])
- probabilities.append(probs[label_id])
- entry: TaskOutputType = {
- "labels": labels,
- "probabilities": probabilities,
- }
- unbatched_output.append(entry)
-
- return unbatched_output
-
- def decode_annotations(
- self,
- task_output: TaskOutputType,
- task_encoding: TaskEncodingType,
- ) -> Dict[str, List[Annotation]]:
- char2token_spans = task_encoding.metadata["tokenized_annotations"]["labeled_spans"]
- token2char_spans = {v: k for k, v in char2token_spans.items()}
- injected2original_spans = task_encoding.metadata["injected2original_spans"]
- new_relations = []
- for candidate_relation, label, probability, is_valid in zip(
- task_encoding.metadata["candidate_relations"],
- task_output["labels"],
- task_output["probabilities"],
- task_encoding.inputs["tuple_indices_mask"],
- ):
- # exclude
- # - padding entries (is_valid=False)
- # - negative relations (if we have added them)
- if is_valid and (
- label != self.no_relation_label or not self.create_candidate_relations
- ):
- token_head, token_tail = candidate_relation.head, candidate_relation.tail
- char_head = token2char_spans[token_head]
- char_tail = token2char_spans[token_tail]
- original_head = injected2original_spans[char_head]
- original_tail = injected2original_spans[char_tail]
- new_annotation = candidate_relation.copy(
- head=original_head, tail=original_tail, label=label, score=probability
- )
- new_relations.append(new_annotation)
-
- return {"binary_relations": new_relations}
-
- def create_annotations_from_output(
- self,
- task_encoding: TaskEncodingType,
- task_output: TaskOutputType,
- ) -> Iterator[Tuple[str, Union[BinaryRelation, MultiLabeledBinaryRelation, NaryRelation]]]:
- decoded_annotations = self.decode_annotations(
- task_output=task_output, task_encoding=task_encoding
- )
-
- for relation in decoded_annotations["binary_relations"]:
- yield self.relation_annotation, relation
-
- def configure_model_metric(self, stage: str) -> Metric:
- if self.label_to_id is None:
- raise ValueError(
- "The taskmodule has not been prepared yet, so label_to_id is not known. "
- "Please call taskmodule.prepare(documents) before configuring the model metric "
- "or pass the labels to the taskmodule constructor an call taskmodule.post_prepare()."
- )
- labels = [self.id_to_label[i] for i in range(len(self.label_to_id))]
- common_metric_kwargs = {
- "num_classes": len(labels),
- "task": "multiclass",
- "ignore_index": PAD_VALUES["labels"],
- }
- return WrappedMetricWithPrepareFunction(
- metric=MetricCollection(
- {
- "micro/f1": F1Score(average="micro", **common_metric_kwargs),
- "macro/f1": F1Score(average="macro", **common_metric_kwargs),
- "f1_per_label": ClasswiseWrapper(
- F1Score(average=None, **common_metric_kwargs),
- labels=labels,
- postfix="/f1",
- ),
- }
- ),
- prepare_function=_get_label_ids_from_model_output,
- )
diff --git a/src/pie_modules/taskmodules/re_text_classification_with_indices.py b/src/pie_modules/taskmodules/re_text_classification_with_indices.py
deleted file mode 100644
index 61fa41f25..000000000
--- a/src/pie_modules/taskmodules/re_text_classification_with_indices.py
+++ /dev/null
@@ -1,1508 +0,0 @@
-"""
-workflow:
- Document
- -> (InputEncoding, TargetEncoding) -> TaskEncoding -> TaskBatchEncoding
- -> ModelBatchEncoding -> ModelBatchOutput
- -> TaskOutput
- -> Document
-"""
-
-import logging
-from collections import defaultdict
-from functools import partial
-from typing import (
- Any,
- Dict,
- Iterable,
- Iterator,
- List,
- Optional,
- Sequence,
- Set,
- Tuple,
- Type,
- TypedDict,
- Union,
-)
-
-import numpy as np
-import torch
-from pie_core import (
- Annotation,
- AnnotationLayer,
- Document,
- TaskEncoding,
- TaskModule,
-)
-from pytorch_ie.taskmodules.interface import ChangesTokenizerVocabSize
-from pytorch_ie.utils.window import get_window_around_slice
-from torch import LongTensor
-from torchmetrics import ClasswiseWrapper, F1Score, MetricCollection
-from transformers import AutoTokenizer
-from transformers.file_utils import PaddingStrategy
-from transformers.tokenization_utils_base import TruncationStrategy
-from typing_extensions import TypeAlias, TypeVar
-
-from pie_modules.annotations import (
- BinaryRelation,
- LabeledSpan,
- MultiLabeledBinaryRelation,
- NaryRelation,
- Span,
-)
-from pie_modules.documents import (
- TextBasedDocument,
- TextDocumentWithLabeledSpansAndBinaryRelations,
- TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-)
-from pie_modules.models.simple_sequence_classification import (
- InputType as ModelInputType,
-)
-from pie_modules.models.simple_sequence_classification import (
- TargetType as ModelTargetType,
-)
-from pie_modules.taskmodules.common.mixins import RelationStatisticsMixin
-from pie_modules.taskmodules.metrics import WrappedMetricWithPrepareFunction
-from pie_modules.utils.span import distance as span_distance
-from pie_modules.utils.span import is_contained_in
-from pie_modules.utils.tokenization import (
- SpanNotAlignedWithTokenException,
- get_aligned_token_span,
-)
-
-InputEncodingType: TypeAlias = Dict[str, Any]
-TargetEncodingType: TypeAlias = Sequence[int]
-DocumentType: TypeAlias = TextBasedDocument
-
-TaskEncodingType: TypeAlias = TaskEncoding[
- DocumentType,
- InputEncodingType,
- TargetEncodingType,
-]
-
-
-class TaskOutputType(TypedDict, total=False):
- labels: Sequence[str]
- probabilities: Sequence[float]
-
-
-TaskModuleType: TypeAlias = TaskModule[
- # _InputEncoding, _TargetEncoding, _TaskBatchEncoding, _ModelBatchOutput, _TaskOutput
- DocumentType,
- InputEncodingType,
- TargetEncodingType,
- Tuple[ModelInputType, Optional[ModelTargetType]],
- ModelTargetType,
- TaskOutputType,
-]
-
-
-HEAD = "head"
-TAIL = "tail"
-START = "start"
-END = "end"
-
-
-logger = logging.getLogger(__name__)
-
-
-def _get_labels(model_output: ModelTargetType) -> LongTensor:
- return model_output["labels"]
-
-
-def _get_labels_together_remove_none_label(
- predictions: ModelTargetType, targets: ModelTargetType, none_idx: int
-) -> Tuple[LongTensor, LongTensor]:
- mask_not_both_none = (predictions["labels"] != none_idx) | (targets["labels"] != none_idx)
- predictions_not_none = predictions["labels"][mask_not_both_none]
- targets_not_none = targets["labels"][mask_not_both_none]
- return predictions_not_none, targets_not_none
-
-
-def find_sublist(sub: List, bigger: List) -> int:
- if not bigger:
- return -1
- if not sub:
- return 0
- first, rest = sub[0], sub[1:]
- pos = 0
- try:
- while True:
- pos = bigger.index(first, pos) + 1
- if not rest or bigger[pos : pos + len(rest)] == rest:
- return pos - 1
- except ValueError:
- return -1
-
-
-class MarkerFactory:
- def __init__(self, role_to_marker: Dict[str, str]):
- self.role_to_marker = role_to_marker
-
- def _get_role_marker(self, role: str) -> str:
- return self.role_to_marker[role]
-
- def _get_marker(self, role: str, is_start: bool, label: Optional[str] = None) -> str:
- result = "["
- if not is_start:
- result += "/"
- result += self._get_role_marker(role)
- if label is not None:
- result += f":{label}"
- result += "]"
- return result
-
- def get_start_marker(self, role: str, label: Optional[str] = None) -> str:
- return self._get_marker(role=role, is_start=True, label=label)
-
- def get_end_marker(self, role: str, label: Optional[str] = None) -> str:
- return self._get_marker(role=role, is_start=False, label=label)
-
- def get_append_marker(self, role: str, label: Optional[str] = None) -> str:
- role_marker = self._get_role_marker(role)
- if label is None:
- return f"[{role_marker}]"
- else:
- return f"[{role_marker}={label}]"
-
- @property
- def all_roles(self) -> Set[str]:
- return set(self.role_to_marker)
-
- def get_all_markers(
- self,
- entity_labels: List[str],
- append_markers: bool = False,
- add_type_to_marker: bool = False,
- ) -> List[str]:
- result: Set[str] = set()
- if add_type_to_marker:
- none_and_labels = [None] + entity_labels
- else:
- none_and_labels = [None]
- for role in self.all_roles:
- # create start and end markers without label and for all labels, if add_type_to_marker
- for maybe_label in none_and_labels:
- result.add(self.get_start_marker(role=role, label=maybe_label))
- result.add(self.get_end_marker(role=role, label=maybe_label))
- # create append markers for all labels
- if append_markers:
- for entity_label in entity_labels:
- result.add(self.get_append_marker(role=role, label=entity_label))
-
- # sort and convert to list
- return sorted(result)
-
-
-class RelationArgument:
- def __init__(
- self,
- entity: LabeledSpan,
- role: str,
- token_span: Span,
- add_type_to_marker: bool,
- marker_factory: MarkerFactory,
- ) -> None:
- self.marker_factory = marker_factory
- if role not in self.marker_factory.all_roles:
- raise ValueError(
- f"role='{role}' not in known roles={sorted(self.marker_factory.all_roles)} (did you "
- f"initialise the taskmodule with the correct argument_role_to_marker dictionary?)"
- )
-
- self.entity = entity
-
- self.role = role
- self.token_span = token_span
- self.add_type_to_marker = add_type_to_marker
-
- @property
- def maybe_label(self) -> Optional[str]:
- return self.entity.label if self.add_type_to_marker else None
-
- @property
- def as_start_marker(self) -> str:
- return self.marker_factory.get_start_marker(role=self.role, label=self.maybe_label)
-
- @property
- def as_end_marker(self) -> str:
- return self.marker_factory.get_end_marker(role=self.role, label=self.maybe_label)
-
- @property
- def as_append_marker(self) -> str:
- # Note: we add the label in either case (we use self.entity.label instead of self.label)
- return self.marker_factory.get_append_marker(role=self.role, label=self.entity.label)
-
- def shift_token_span(self, value: int):
- self.token_span = Span(
- start=self.token_span.start + value, end=self.token_span.end + value
- )
-
-
-def get_relation_argument_spans_and_roles(
- relation: Annotation,
-) -> Tuple[Tuple[str, Annotation], ...]:
- if isinstance(relation, BinaryRelation):
- return (HEAD, relation.head), (TAIL, relation.tail)
- elif isinstance(relation, NaryRelation):
- # create unique order by sorting the arguments by their start and end positions and role
- sorted_args = sorted(
- zip(relation.roles, relation.arguments),
- key=lambda role_and_span: (
- role_and_span[1].start,
- role_and_span[1].end,
- role_and_span[0],
- ),
- )
- return tuple(sorted_args)
- else:
- raise NotImplementedError(
- f"the taskmodule does not yet support getting relation arguments for type: {type(relation)}"
- )
-
-
-def construct_mask(input_ids: torch.LongTensor, positive_ids: List[Any]) -> torch.LongTensor:
- """Construct a mask for the input_ids where all entries in mask_ids are 1."""
- masks = [torch.nonzero(input_ids == marker_token_id) for marker_token_id in positive_ids]
- globs = torch.cat(masks)
- value = torch.ones(globs.shape[0], dtype=int)
- mask = torch.zeros(input_ids.shape, dtype=int)
- mask.index_put_(tuple(globs.t()), value)
- return mask
-
-
-S = TypeVar("S", bound=Span)
-
-
-def shift_span(span: S, offset: int) -> S:
- return span.copy(start=span.start + offset, end=span.end + offset)
-
-
-def bio_encode_spans(
- spans: List[Tuple[int, int, str]], total_length: int, label2idx: Dict[str, int]
-) -> List[int]:
- # result = ["O"] * total_length
- result = [0] * total_length
- for start, end, label in spans:
- # result[start] = f"B-{label}"
- result[start] = label2idx[label] * 2 + 1
- for i in range(start + 1, end):
- # result[i] = f"I-{label}"
- result[i] = label2idx[label] * 2 + 2
- return result
-
-
-@TaskModule.register()
-class RETextClassificationWithIndicesTaskModule(
- RelationStatisticsMixin,
- TaskModuleType,
- ChangesTokenizerVocabSize,
-):
- """Marker based relation extraction. This taskmodule prepares the input token ids in such a way
- that before and after the candidate head and tail entities special marker tokens are inserted.
- Then, the modified token ids can be simply passed into a transformer based text classifier
- model.
-
- parameters:
-
- partition_annotation: str, optional. If specified, LabeledSpan annotations with this name are
- expected to define partitions of the document that will be processed individually, e.g. sentences
- or sections of the document text.
- none_label: str, defaults to "no_relation". The relation label that indicate dummy/negative relations.
- Predicted relations with that label will not be added to the document(s).
- max_window: int, optional. If specified, use the tokens in a window of maximal this amount of tokens
- around the center of head and tail entities and pass only that into the transformer.
- create_relation_candidates: bool, defaults to False. If True, create relation candidates by pairwise
- combining all entities in the document and assigning the none_label. If the document already contains
- a relation with the entity pair, we do not add it again. If False, assume that the document already
- contains relation annotations including negative examples (i.e. relations with the none_label).
- handle_relations_with_same_arguments: str, defaults to "keep_none". If "keep_none", all relations that
- share same arguments will be removed. If "keep_first", first occurred duplicate will be kept.
- argument_type_whitelist: List[List[str]], optional, defaults to None. If set, only relations (candidates)
- with given argument type tuples are created from document and by by `create_relation_candidates`.
- This affects only model input.
- argument_and_relation_type_whitelist: Union[Dict[str, List[List[str]]], List[List[str]]], optional,
- defaults None. If set, only given relation types with given argument types will persist in
- documents and generated by `create_relation_candidates`. This also affects predictions on
- `decode()`, so it strictly filters both model input and output. Can also be passed as a list
- of lists, where the first element is the relation type and the rest are the argument types.
- """
-
- PREPARED_ATTRIBUTES = ["labels", "entity_labels"]
-
- def __init__(
- self,
- tokenizer_name_or_path: str,
- relation_annotation: str = "binary_relations",
- add_candidate_relations: bool = False,
- add_reversed_relations: bool = False,
- partition_annotation: Optional[str] = None,
- none_label: str = "no_relation",
- padding: Union[bool, str, PaddingStrategy] = True,
- truncation: Union[bool, str, TruncationStrategy] = True,
- max_length: Optional[int] = None,
- pad_to_multiple_of: Optional[int] = None,
- multi_label: bool = False,
- labels: Optional[List[str]] = None,
- label_to_id: Optional[Dict[str, int]] = None,
- add_type_to_marker: bool = False,
- argument_role_to_marker: Optional[Dict[str, str]] = None,
- single_argument_pair: bool = True,
- append_markers: bool = False,
- insert_markers: bool = True,
- entity_labels: Optional[List[str]] = None,
- reversed_relation_label_suffix: str = "_reversed",
- symmetric_relations: Optional[List[str]] = None,
- reverse_symmetric_relations: bool = True,
- max_argument_distance: Optional[int] = None,
- max_argument_distance_type: str = "inner",
- max_argument_distance_tokens: Optional[int] = None,
- max_argument_distance_type_tokens: str = "inner",
- max_window: Optional[int] = None,
- allow_discontinuous_text: bool = False,
- log_first_n_examples: int = 0,
- add_argument_indices_to_input: bool = False,
- add_argument_tags_to_input: bool = False,
- add_entity_tags_to_input: bool = False,
- add_global_attention_mask_to_input: bool = False,
- argument_type_whitelist: Optional[List[List[str]]] = None,
- handle_relations_with_same_arguments: str = "keep_none",
- argument_and_relation_type_whitelist: Optional[
- Union[Dict[str, List[List[str]]], List[List[str]]]
- ] = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- if label_to_id is not None:
- logger.warning(
- "The parameter label_to_id is deprecated and will be removed in a future version. "
- "Please use labels instead."
- )
- id_to_label = {v: k for k, v in label_to_id.items()}
- # reconstruct labels from label_to_id. Note that we need to remove the none_label
- labels = [
- id_to_label[i] for i in range(len(id_to_label)) if id_to_label[i] != none_label
- ]
- self.save_hyperparameters(ignore=["label_to_id"])
-
- self.relation_annotation = relation_annotation
- self.add_candidate_relations = add_candidate_relations
- self.add_reversed_relations = add_reversed_relations
- self.padding = padding
- self.truncation = truncation
- self.labels = labels
- self.max_length = max_length
- self.pad_to_multiple_of = pad_to_multiple_of
- self.multi_label = multi_label
- self.add_type_to_marker = add_type_to_marker
- self.single_argument_pair = single_argument_pair
- self.append_markers = append_markers
- self.insert_markers = insert_markers
- self.entity_labels = entity_labels
- self.partition_annotation = partition_annotation
- self.none_label = none_label
- self.reversed_relation_label_suffix = reversed_relation_label_suffix
- self.symmetric_relations = set(symmetric_relations or [])
- self.reverse_symmetric_relations = reverse_symmetric_relations
- self.max_argument_distance = max_argument_distance
- self.max_argument_distance_type = max_argument_distance_type
- self.max_argument_distance_tokens = max_argument_distance_tokens
- self.max_argument_distance_type_tokens = max_argument_distance_type_tokens
- self.max_window = max_window
- self.allow_discontinuous_text = allow_discontinuous_text
- self.handle_relations_with_same_arguments = handle_relations_with_same_arguments
- self.argument_type_whitelist: Optional[Set[Tuple[str, ...]]] = None
- self.argument_and_relation_type_whitelist: Optional[Dict[str, Set[Tuple[str, ...]]]] = None
-
- if argument_type_whitelist is not None:
- # hydra does not support tuples, so we got lists and need to convert them
- self.argument_type_whitelist = {tuple(types) for types in argument_type_whitelist}
- if argument_and_relation_type_whitelist is not None:
- # hydra does not support tuples, so we got lists and need to convert them
- if isinstance(argument_and_relation_type_whitelist, list):
- self.argument_and_relation_type_whitelist = defaultdict(set)
- for types_list in argument_and_relation_type_whitelist:
- if len(types_list) < 1:
- raise ValueError(
- "argument_and_relation_type_whitelist must be a list of lists with at least one element"
- )
- self.argument_and_relation_type_whitelist[types_list[0]].add(
- tuple(types_list[1:])
- )
- else:
- self.argument_and_relation_type_whitelist = {
- rel: {tuple(types) for types in types_list}
- for rel, types_list in argument_and_relation_type_whitelist.items()
- }
- # overwrite None with 0 for backward compatibility
- self.log_first_n_examples = log_first_n_examples or 0
- self.add_argument_indices_to_input = add_argument_indices_to_input
- self.add_argument_tags_to_input = add_argument_tags_to_input
- self.add_entity_tags_to_input = add_entity_tags_to_input
- self.add_global_attention_mask_to_input = add_global_attention_mask_to_input
- if argument_role_to_marker is None:
- self.argument_role_to_marker = {HEAD: "H", TAIL: "T"}
- else:
- self.argument_role_to_marker = argument_role_to_marker
-
- self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
-
- # used when allow_discontinuous_text
- self.glue_token_ids = self._get_glue_token_ids()
-
- self.argument_markers = None
-
- self._logged_examples_counter = 0
-
- def _get_glue_token_ids(self):
- dummy_ids = self.tokenizer.build_inputs_with_special_tokens(
- token_ids_0=[-1], token_ids_1=[-2]
- )
- return dummy_ids[dummy_ids.index(-1) + 1 : dummy_ids.index(-2)]
-
- @property
- def document_type(self) -> Optional[Type[DocumentType]]:
- if self.partition_annotation is not None:
- dt = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
- else:
- dt = TextDocumentWithLabeledSpansAndBinaryRelations
- if self.relation_annotation == "binary_relations":
- return dt
- else:
- logger.warning(
- f"relation_annotation={self.relation_annotation} is "
- f"not the default value ('binary_relations'), so the taskmodule {type(self).__name__} can not request "
- f"the usual document type for auto-conversion ({dt.__name__}) because this has the bespoken default "
- f"value as layer name instead of the provided one."
- )
- return None
-
- def get_relation_layer(self, document: Document) -> AnnotationLayer[BinaryRelation]:
- return document[self.relation_annotation]
-
- def get_entity_layer(self, document: Document) -> AnnotationLayer[LabeledSpan]:
- relations: AnnotationLayer[BinaryRelation] = self.get_relation_layer(document)
- return relations.target_layer
-
- def get_marker_factory(self) -> MarkerFactory:
- return MarkerFactory(role_to_marker=self.argument_role_to_marker)
-
- def _prepare(self, documents: Sequence[DocumentType]) -> None:
- entity_labels: Set[str] = set()
- relation_labels: Set[str] = set()
- for document in documents:
- relations: AnnotationLayer[BinaryRelation] = self.get_relation_layer(document)
- entities: AnnotationLayer[LabeledSpan] = self.get_entity_layer(document)
-
- for entity in entities:
- entity_labels.add(entity.label)
-
- for relation in relations:
- relation_labels.add(relation.label)
- if self.add_reversed_relations:
- if relation.label.endswith(self.reversed_relation_label_suffix):
- raise ValueError(
- f"doc.id={document.id}: the relation label '{relation.label}' already ends with "
- f"the reversed_relation_label_suffix '{self.reversed_relation_label_suffix}', "
- f"this is not allowed because we would not know if we should strip the suffix and "
- f"revert the arguments during inference or not"
- )
- if relation.label not in self.symmetric_relations:
- relation_labels.add(relation.label + self.reversed_relation_label_suffix)
-
- if self.none_label in relation_labels:
- relation_labels.remove(self.none_label)
-
- self.labels = sorted(relation_labels)
- self.entity_labels = sorted(entity_labels)
-
- def encode(self, *args, **kwargs):
- self.reset_statistics()
- res = super().encode(*args, **kwargs)
- self.show_statistics()
- return res
-
- def _post_prepare(self):
- self.label_to_id = {label: i + 1 for i, label in enumerate(self.labels)}
- self.label_to_id[self.none_label] = 0
- self.id_to_label = {v: k for k, v in self.label_to_id.items()}
-
- self.marker_factory = self.get_marker_factory()
- self.argument_markers = self.marker_factory.get_all_markers(
- append_markers=self.append_markers,
- add_type_to_marker=self.add_type_to_marker,
- entity_labels=self.entity_labels,
- )
- self.tokenizer.add_tokens(self.argument_markers, special_tokens=True)
-
- self.argument_markers_to_id = {
- marker: self.tokenizer.vocab[marker] for marker in self.argument_markers
- }
-
- self.argument_role2idx = {
- role: i for i, role in enumerate(sorted(self.marker_factory.all_roles))
- }
-
- def _add_reversed_relations(
- self,
- arguments2relation: Dict[Tuple[Tuple[str, Annotation], ...], Annotation],
- doc_id: Optional[str] = None,
- ) -> None:
- if self.add_reversed_relations:
- for arguments, rel in list(arguments2relation.items()):
- arg_roles, arg_spans = zip(*arguments)
- if isinstance(rel, BinaryRelation):
- label = rel.label
- if label in self.symmetric_relations and not self.reverse_symmetric_relations:
- continue
- if label.endswith(self.reversed_relation_label_suffix):
- raise ValueError(
- f"doc.id={doc_id}: The relation has the label '{label}' which already ends with the "
- f"reversed_relation_label_suffix='{self.reversed_relation_label_suffix}'. "
- f"It looks like the relation is already reversed, which is not allowed."
- )
- if rel.label not in self.symmetric_relations:
- label += self.reversed_relation_label_suffix
-
- reversed_rel = BinaryRelation(
- head=rel.tail,
- tail=rel.head,
- label=label,
- score=rel.score,
- )
- reversed_arguments = get_relation_argument_spans_and_roles(reversed_rel)
- if reversed_arguments in arguments2relation:
- prev_rel = arguments2relation[reversed_arguments]
- prev_label = prev_rel.label
- logger.warning(
- f"doc.id={doc_id}: there is already a relation with reversed "
- f"arguments={reversed_arguments} and label={prev_label}, so we do not add the reversed "
- f"relation (with label {prev_label}) for these arguments"
- )
- if self.collect_statistics:
- self.collect_relation("skipped_reversed_same_arguments", reversed_rel)
- continue
- elif rel.label in self.symmetric_relations:
- # warn if the original relation arguments were not sorted by their start and end positions
- # in the case of symmetric relations
- if not all(isinstance(arg_span, Span) for arg_span in arg_spans):
- raise NotImplementedError(
- f"doc.id={doc_id}: the taskmodule does not yet support adding reversed relations "
- f"for symmetric relations with arguments that are no Spans: {arguments}"
- )
- args_sorted = sorted(
- [rel.head, rel.tail], key=lambda span: (span.start, span.end)
- )
- if args_sorted != [rel.head, rel.tail]:
- logger.warning(
- f"doc.id={doc_id}: The symmetric relation with label '{label}' has arguments "
- f"{arguments} which are not sorted by their start and end positions. "
- f"This may lead to problems during evaluation because we assume that the "
- f"arguments of symmetric relations were sorted in the beginning and, thus, interpret "
- f"relations where this is not the case as reversed. All reversed relations will get "
- f"their arguments swapped during inference in the case of add_reversed_relations=True "
- f"to remove duplicates. You may consider adding reversed versions of the *symmetric* "
- f"relations on your own and then setting *reverse_symmetric_relations* to False."
- )
- if self.collect_statistics:
- self.collect_relation(
- "used_not_sorted_reversed_arguments", reversed_rel
- )
-
- arguments2relation[reversed_arguments] = reversed_rel
- else:
- raise NotImplementedError(
- f"doc.id={doc_id}: the taskmodule does not yet support adding reversed relations for type: "
- f"{type(rel)}"
- )
-
- def _filter_relations_by_argument_and_relation_type_whitelist(
- self,
- arguments2relation: Dict[Tuple[Tuple[str, Annotation], ...], Annotation],
- doc_id: Optional[str] = None,
- ) -> None:
- if self.argument_and_relation_type_whitelist is not None:
- for arguments, relation in list(arguments2relation.items()):
- argument_labels = tuple(getattr(ann, "label") for role, ann in arguments)
- relation_label = getattr(relation, "label")
- if (
- relation_label not in self.argument_and_relation_type_whitelist
- or argument_labels
- not in self.argument_and_relation_type_whitelist[relation_label]
- ):
- rel = arguments2relation.pop(arguments)
- self.collect_relation("skipped_argument_and_relation_type_whitelist", rel)
-
- def _filter_relations_by_argument_type_whitelist(
- self,
- arguments2relation: Dict[Tuple[Tuple[str, Annotation], ...], Annotation],
- doc_id: Optional[str] = None,
- ) -> None:
- if self.argument_type_whitelist is not None:
- for arguments, rel in list(arguments2relation.items()):
- argument_labels = tuple(getattr(arg, "label") for _, arg in arguments)
- if argument_labels not in self.argument_type_whitelist:
- rel = arguments2relation.pop(arguments)
- self.collect_relation("skipped_argument_type_whitelist", rel)
-
- def _add_candidate_relations(
- self,
- arguments2relation: Dict[Tuple[Tuple[str, Annotation], ...], Annotation],
- entities: Iterable[Span],
- arguments_blacklist: Optional[Set[Tuple[Tuple[str, Annotation], ...]]] = None,
- doc_id: Optional[str] = None,
- ) -> None:
- if self.add_candidate_relations:
- if self.marker_factory.all_roles == {HEAD, TAIL}:
- # flatten argument_and_relation_type_whitelist values
- arg_rel_whitelist_vals_set = (
- None
- if self.argument_and_relation_type_whitelist is None
- else {i for j in self.argument_and_relation_type_whitelist.values() for i in j}
- )
- # iterate over all possible argument candidates
- for head in entities:
- for tail in entities:
- if head == tail:
- continue
-
- # Create a relation candidate with the none label. Otherwise, we use the existing relation.
- new_relation = BinaryRelation(
- head=head, tail=tail, label=self.none_label, score=1.0
- )
- new_relation_args = get_relation_argument_spans_and_roles(new_relation)
- arg_roles, arg_spans = zip(*new_relation_args)
- arg_labels = tuple(getattr(ann, "label") for ann in arg_spans)
-
- # Skip if argument_type_whitelist and/or argument_and_relation_type_whitelist
- # are defined and current candidates do not fit.
- if (
- self.argument_type_whitelist is not None
- and arg_labels not in self.argument_type_whitelist
- ) or (
- arg_rel_whitelist_vals_set is not None
- and arg_labels not in arg_rel_whitelist_vals_set
- ):
- continue
-
- # check blacklist
- if (
- arguments_blacklist is not None
- and new_relation_args in arguments_blacklist
- ):
- continue
-
- # we use the new relation only if there is no existing relation with the same arguments
- if new_relation_args not in arguments2relation:
- arguments2relation[new_relation_args] = new_relation
- else:
- raise NotImplementedError(
- f"doc.id={doc_id}: the taskmodule does not yet support adding relation candidates "
- f"with argument roles other than 'head' and 'tail': {sorted(self.marker_factory.all_roles)}"
- )
-
- def _filter_relations_by_argument_distance(
- self,
- arguments2relation: Dict[Tuple[Tuple[str, Annotation], ...], Annotation],
- doc_id: Optional[str] = None,
- ) -> None:
- if self.max_argument_distance is not None:
- for arguments, rel in list(arguments2relation.items()):
- if isinstance(rel, BinaryRelation):
- if isinstance(rel.head, Span) and isinstance(rel.tail, Span):
- dist = span_distance(
- (rel.head.start, rel.head.end),
- (rel.tail.start, rel.tail.end),
- self.max_argument_distance_type,
- )
- if dist > self.max_argument_distance:
- arguments2relation.pop(arguments)
- self.collect_relation("skipped_argument_distance", rel)
- else:
- raise NotImplementedError(
- f"doc.id={doc_id}: the taskmodule does not yet support filtering relation candidates "
- f"with arguments of type: {type(rel.head)} and {type(rel.tail)}"
- )
- else:
- raise NotImplementedError(
- f"doc.id={doc_id}: the taskmodule does not yet support filtering relation candidates for "
- f"type: {type(rel)}"
- )
-
- def encode_input(
- self,
- document: DocumentType,
- is_training: bool = False,
- ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
- all_relations: Sequence[Annotation] = self.get_relation_layer(document)
- all_entities: Sequence[Span] = self.get_entity_layer(document)
- self.collect_all_relations("available", all_relations)
-
- partitions: Sequence[Span]
- if self.partition_annotation is not None:
- partitions = document[self.partition_annotation]
- if len(partitions) == 0:
- logger.warning(
- f"the document {document.id} has no '{self.partition_annotation}' partition entries, "
- f"no inputs will be created!"
- )
- else:
- # use single dummy partition
- partitions = [Span(start=0, end=len(document.text))]
-
- task_encodings: List[TaskEncodingType] = []
- for partition in partitions:
- # get all entities that are contained in the current partition
- entities: List[Span] = [
- entity
- for entity in all_entities
- if is_contained_in((entity.start, entity.end), (partition.start, partition.end))
- ]
-
- # Create a mapping from relation arguments to the respective relation objects.
- # Note that the data can contain multiple relations with the same arguments.
- entities_set = set(entities)
- arguments2relations: Dict[Tuple[Tuple[str, Annotation], ...], List[Annotation]] = (
- defaultdict(list)
- )
- for rel in all_relations:
- # Skip relations with unknown labels. Use label_to_id because that contains the none_label
- if rel.label not in self.label_to_id:
- self.collect_relation("skipped_unknown_label", rel)
- continue
-
- arguments = get_relation_argument_spans_and_roles(rel)
- arg_roles, arg_spans = zip(*arguments)
-
- # filter out all relations that are completely outside the current partition
- if all(arg_span not in entities_set for arg_span in arg_spans):
- continue
-
- # filter relations that are only partially contained in the current partition,
- # i.e. some arguments are in the partition and some are not
- if any(arg_span not in entities_set for arg_span in arg_spans):
- logger.warning(
- f"doc.id={document.id}: there is a relation with label '{rel.label}' and arguments "
- f"{arguments} that is only partially contained in the current partition. "
- f"We skip this relation."
- )
- self.collect_relation("skipped_partially_contained", rel)
- continue
- arguments2relations[arguments].append(rel)
-
- # resolve duplicates for same arguments
- arguments2relation: Dict[Tuple[Tuple[str, Annotation], ...], Annotation] = {}
- # we will never create an encoding for the relation candidates in arguments_blacklist
- arguments_blacklist: Set[Tuple[Tuple[str, Annotation], ...]] = set()
- for arguments, relations in arguments2relations.items():
- relations_set = set(relations)
- # more than one unique relation with the same arguments
- if len(relations_set) > 1:
- arguments_resolved = tuple(map(lambda x: (x[0], x[1].resolve()), arguments))
- labels = [rel.label for rel in relations]
- if self.handle_relations_with_same_arguments == "keep_first":
- # keep only the first relation
- arguments2relation[arguments] = relations[0]
- for discard_rel in set(relations) - {
- relations[0]
- }: # remove all other relations
- self.collect_relation("skipped_same_arguments", discard_rel)
- if not self.collect_statistics:
- # We show this warning only if statistics are disabled.
- # We want to be informed if such skip occurs, but having it in statistics and
- # getting lots of warnings in the same time seemed overwhelming.
- logger.warning(
- f"doc.id={document.id}: there are multiple relations with the same arguments "
- f"{arguments_resolved}, but different labels: {labels}. We only keep the first "
- f"occurring relation which has the label='{relations[0].label}'."
- )
- elif self.handle_relations_with_same_arguments == "keep_none":
- # add these arguments to the blacklist to not add them as 'no-relation's back again
- arguments_blacklist.add(arguments)
- # remove all relations with the same arguments
- for discard_rel in relations_set:
- self.collect_relation("skipped_same_arguments", discard_rel)
- if not self.collect_statistics:
- logger.warning(
- f"doc.id={document.id}: there are multiple relations with the same arguments "
- f"{arguments_resolved}, but different labels: {labels}. All relations will be removed."
- )
- else:
- raise ValueError(
- f"'handle_relations_with_same_arguments' must be 'keep_first' or 'keep_none', "
- f"but got `{self.handle_relations_with_same_arguments}`."
- )
- else:
- arguments2relation[arguments] = relations[0]
- # more than one duplicate relation (with the same arguments)
- if len(relations) > 1:
- # if 'collect_statistics=true' such duplicates won't be collected and are not counted in
- # statistics if 'collect_statistics=true' either as 'available' or as 'skipped_same_arguments'
- logger.warning(
- f"doc.id={document.id}: Relation annotation `{rel.resolve()}` is duplicated. "
- f"We keep only one of them. Duplicate won't appear in statistics either as 'available' "
- f"or as skipped."
- )
-
- # We use this filter before adding reversed relations because we also don't want them to be reversed
- self._filter_relations_by_argument_and_relation_type_whitelist(
- arguments2relation=arguments2relation, doc_id=document.id
- )
- self._add_reversed_relations(arguments2relation=arguments2relation, doc_id=document.id)
- self._filter_relations_by_argument_type_whitelist(
- arguments2relation=arguments2relation, doc_id=document.id
- )
- self._add_candidate_relations(
- arguments2relation=arguments2relation,
- arguments_blacklist=arguments_blacklist,
- entities=entities,
- doc_id=document.id,
- )
-
- self._filter_relations_by_argument_distance(
- arguments2relation=arguments2relation, doc_id=document.id
- )
-
- without_special_tokens = self.max_window is not None
- text = document.text[partition.start : partition.end]
- encoding = self.tokenizer(
- text,
- padding=False,
- truncation=self.truncation if self.max_window is None else False,
- max_length=self.max_length,
- is_split_into_words=False,
- return_offsets_mapping=False,
- add_special_tokens=not without_special_tokens,
- )
-
- for arguments, rel in arguments2relation.items():
- arg_roles, arg_spans = zip(*arguments)
- if not all(isinstance(arg, LabeledSpan) for arg in arg_spans):
- # TODO: add test case for this
- raise ValueError(
- f"the taskmodule expects the relation arguments to be of type LabeledSpan, "
- f"but got {[type(arg) for arg in arg_spans]}"
- )
-
- arg_spans_partition = [
- shift_span(span, offset=-partition.start) for span in arg_spans
- ]
- # map character spans to token spans
- try:
- arg_token_spans = [
- get_aligned_token_span(
- encoding=encoding,
- char_span=arg,
- )
- for arg in arg_spans_partition
- ]
- # Check if the mapping was successful. It may fail (and is None) if any argument start or end does not
- # match a token start or end, respectively.
- except SpanNotAlignedWithTokenException as e:
- span_original = shift_span(e.span, offset=partition.start)
- # the span is not attached because we shifted it above, so we can not use str(e.span)
- span_text = document.text[span_original.start : span_original.end]
- logger.warning(
- f"doc.id={document.id}: Skipping invalid example, cannot get argument token slice for "
- f'{span_original}: "{span_text}"'
- )
- self.collect_relation("skipped_args_not_aligned", rel)
- continue
-
- # create the argument objects
- args = [
- RelationArgument(
- entity=span,
- role=role,
- token_span=token_span,
- add_type_to_marker=self.add_type_to_marker,
- marker_factory=self.marker_factory,
- )
- for span, role, token_span in zip(arg_spans, arg_roles, arg_token_spans)
- ]
-
- if self.max_argument_distance_tokens is not None:
- token_distances = []
- for idx1 in range(len(args) - 1):
- for idx in range(idx1 + 1, len(args)):
- arg1 = args[idx1]
- arg2 = args[idx]
- dist = span_distance(
- (arg1.token_span.start, arg1.token_span.end),
- (arg2.token_span.start, arg2.token_span.end),
- self.max_argument_distance_type_tokens,
- )
- token_distances.append(dist)
- if len(token_distances) > 0:
- if self.max_argument_distance_type_tokens == "outer":
- max_dist = max(token_distances)
- elif self.max_argument_distance_type_tokens == "inner":
- if len(args) > 2:
- raise NotImplementedError(
- f"max_argument_distance_type_tokens={self.max_argument_distance_type_tokens} "
- f"is not supported for relations with more than 2 arguments"
- )
- max_dist = max(token_distances)
- else:
- raise NotImplementedError(
- f"max_argument_distance_type_tokens={self.max_argument_distance_type_tokens} "
- f"is not supported"
- )
- if max_dist > self.max_argument_distance_tokens:
- self.collect_relation("skipped_argument_distance_tokens", rel)
- continue
-
- input_ids = encoding["input_ids"]
-
- entity_tags = None
- if self.add_entity_tags_to_input:
- entity_spans_partition = [
- shift_span(span, offset=-partition.start) for span in entities
- ]
- entity_token_spans = []
- for span in entity_spans_partition:
- try:
- entity_token_spans.append(
- get_aligned_token_span(
- encoding=encoding,
- char_span=span,
- )
- )
- except SpanNotAlignedWithTokenException as e:
- span_original = shift_span(e.span, offset=partition.start)
- span_text = document.text[span_original.start : span_original.end]
- logger.warning(
- f"doc.id={document.id}: Skipping invalid example, cannot get entity token slice for "
- f'{span_original}: "{span_text}"'
- )
- self.collect_relation("skipped_entity_not_aligned", rel)
- continue
-
- entity_tags = bio_encode_spans(
- spans=[
- (span.start, span.end, getattr(span, "label", "ENTITY"))
- for span in entity_token_spans
- ],
- total_length=len(input_ids),
- label2idx={
- label: idx for idx, label in enumerate(self.entity_labels or [])
- },
- )
-
- # windowing: we restrict the input to a window of a maximal size (max_window) with the arguments
- # of the candidate relation in the center (as much as possible)
- if self.max_window is not None:
- # The actual number of tokens needs to be lower than max_window because we add two
- # marker tokens (before / after) each argument and the default special tokens
- # (e.g. CLS and SEP).
- max_tokens = self.max_window - self.tokenizer.num_special_tokens_to_add()
- if self.insert_markers:
- max_tokens -= len(args) * 2
- # if we add the markers also to the end, this decreases the available window again by
- # two tokens (marker + sep) per argument
- if self.append_markers:
- # TODO: add test case for this
- max_tokens -= len(args) * 2
-
- if self.allow_discontinuous_text:
- if entity_tags is not None:
- raise NotImplementedError(
- "allow_discontinuous_text=True is not yet supported with add_entity_tags_to_input=True"
- )
-
- max_tokens_per_argument = max_tokens // len(args)
- max_tokens_per_argument -= len(self.glue_token_ids)
- if any(
- arg.token_span.end - arg.token_span.start > max_tokens_per_argument
- for arg in args
- ):
- self.collect_relation("skipped_too_long_argument", rel)
- continue
-
- mask = np.zeros_like(input_ids)
- for arg in args:
- # if the input is already fully covered by one argument frame, we keep everything
- if len(input_ids) <= max_tokens_per_argument:
- mask[:] = 1
- break
- arg_center = (arg.token_span.end + arg.token_span.start) // 2
- arg_frame_start = arg_center - max_tokens_per_argument // 2
- # shift the frame to the right if it is out of bounds
- if arg_frame_start < 0:
- arg_frame_start = 0
- arg_frame_end = arg_frame_start + max_tokens_per_argument
- # shift the frame to the left if it is out of bounds
- # Note that this can not cause to have arg_frame_start < 0 because we already
- # checked that the frame is not larger than the input.
- if arg_frame_end > len(input_ids):
- arg_frame_end = len(input_ids)
- arg_frame_start = arg_frame_end - max_tokens_per_argument
- # still, a sanity check
- if arg_frame_start < 0:
- raise ValueError(
- f"arg_frame_start={arg_frame_start} < 0 after adjusting arg_frame_end={arg_frame_end}"
- )
- mask[arg_frame_start:arg_frame_end] = 1
- offsets = np.cumsum(mask != 1)
- arg_cluster_offset_values = set()
- # sort by start indices
- args_sorted = sorted(args, key=lambda x: x.token_span.start)
- for arg in args_sorted:
- offset = offsets[arg.token_span.start]
- arg_cluster_offset_values.add(offset)
- arg.shift_token_span(-offset)
- # shift back according to inserted glue patterns
- num_glues = len(arg_cluster_offset_values) - 1
- arg.shift_token_span(num_glues * len(self.glue_token_ids))
-
- new_input_ids: List[int] = []
- for arg_cluster_offset_value in sorted(arg_cluster_offset_values):
- if len(new_input_ids) > 0:
- new_input_ids.extend(self.glue_token_ids)
- segment_mask = offsets == arg_cluster_offset_value
- segment_input_ids = [
- input_id
- for input_id, keep in zip(input_ids, mask & segment_mask)
- if keep
- ]
- new_input_ids.extend(segment_input_ids)
-
- input_ids = new_input_ids
- else:
- # the slice from the beginning of the first entity to the end of the second is required
- slice_required = (
- min(arg.token_span.start for arg in args),
- max(arg.token_span.end for arg in args),
- )
- window_slice = get_window_around_slice(
- slice=slice_required,
- max_window_size=max_tokens,
- available_input_length=len(input_ids),
- )
- # this happens if slice_required (all arguments) does not fit into max_tokens (the available window)
- if window_slice is None:
- self.collect_relation("skipped_too_long", rel)
- continue
-
- window_start, window_end = window_slice
- input_ids = input_ids[window_start:window_end]
-
- if entity_tags is not None:
- entity_tags = entity_tags[window_start:window_end]
-
- for arg in args:
- arg.shift_token_span(-window_start)
-
- # collect all markers with their target positions, the source argument, and
- marker_ids_with_positions = []
- for arg in args:
- marker_ids_with_positions.append(
- (
- self.argument_markers_to_id[arg.as_start_marker],
- arg.token_span.start,
- arg,
- START,
- )
- )
- marker_ids_with_positions.append(
- (
- self.argument_markers_to_id[arg.as_end_marker],
- arg.token_span.end,
- arg,
- END,
- )
- )
-
- # create new input ids with the markers inserted and collect new mention offsets
- input_ids_with_markers = list(input_ids)
- offset = 0
- arg_start_indices = [-1] * len(self.argument_role2idx)
- arg_end_indices = [-1] * len(self.argument_role2idx)
- marker_ids_with_positions_sorted = sorted(
- marker_ids_with_positions, key=lambda id_pos: id_pos[1]
- )
- for (
- marker_id,
- token_position,
- arg,
- marker_type,
- ) in marker_ids_with_positions_sorted:
- if self.insert_markers:
- input_ids_with_markers = (
- input_ids_with_markers[: token_position + offset]
- + [marker_id]
- + input_ids_with_markers[token_position + offset :]
- )
- if entity_tags is not None:
- entity_tags = (
- entity_tags[: token_position + offset]
- + [0]
- + entity_tags[token_position + offset :]
- )
- offset += 1
- if self.add_argument_indices_to_input or self.add_argument_tags_to_input:
- idx = self.argument_role2idx[arg.role]
- if marker_type == START:
- if arg_start_indices[idx] != -1:
- # TODO: add test case for this
- raise ValueError(
- f"Trying to overwrite arg_start_indices[{idx}]={arg_start_indices[idx]} with "
- f"{token_position + offset} for document {document.id}"
- )
- arg_start_indices[idx] = token_position + offset
- elif marker_type == END:
- if arg_end_indices[idx] != -1:
- # TODO: add test case for this
- raise ValueError(
- f"Trying to overwrite arg_start_indices[{idx}]={arg_end_indices[idx]} with "
- f"{token_position + offset} for document {document.id}"
- )
- # -1 to undo the additional offset for the end marker which does not
- # affect the mention offset
- arg_end_indices[idx] = (
- token_position + offset - (1 if self.insert_markers else 0)
- )
-
- if self.append_markers:
- if self.tokenizer.sep_token is None:
- # TODO: add test case for this
- raise ValueError("append_markers is True, but tokenizer has no sep_token")
- sep_token_id = self.tokenizer.vocab[self.tokenizer.sep_token]
- for arg in args:
- if without_special_tokens:
- # TODO: add test case for this
- input_ids_with_markers.append(sep_token_id)
- input_ids_with_markers.append(
- self.argument_markers_to_id[arg.as_append_marker]
- )
- else:
- input_ids_with_markers.append(
- self.argument_markers_to_id[arg.as_append_marker]
- )
- input_ids_with_markers.append(sep_token_id)
- if entity_tags is not None:
- entity_tags.append(0)
- entity_tags.append(0)
-
- # when windowing is used, we have to add the special tokens manually
- if without_special_tokens:
- original_input_ids_with_markers = input_ids_with_markers
- input_ids_with_markers = self.tokenizer.build_inputs_with_special_tokens(
- token_ids_0=input_ids_with_markers
- )
- if self.add_argument_indices_to_input or self.add_argument_tags_to_input:
- # get the number of prefix tokens
- index_offset = find_sublist(
- sub=original_input_ids_with_markers, bigger=input_ids_with_markers
- )
- if index_offset == -1:
- raise ValueError(
- f"Could not find the original tokens in the prefixed tokens for document {document.id}"
- )
- arg_start_indices = [
- idx + index_offset if idx != -1 else -1 for idx in arg_start_indices
- ]
- arg_end_indices = [
- idx + index_offset if idx != -1 else -1 for idx in arg_end_indices
- ]
- if entity_tags is not None:
- special_tokens_mask = self.tokenizer.get_special_tokens_mask(
- token_ids_0=input_ids_with_markers, already_has_special_tokens=True
- )
- entity_tags_with_special = self.tokenizer.build_inputs_with_special_tokens(
- token_ids_0=entity_tags
- )
- entity_tags = [
- tag if not is_special else 0
- for tag, is_special in zip(
- entity_tags_with_special, special_tokens_mask
- )
- ]
-
- inputs = {"input_ids": input_ids_with_markers}
- if self.add_argument_indices_to_input:
- inputs["pooler_start_indices"] = arg_start_indices
- inputs["pooler_end_indices"] = arg_end_indices
- if self.add_argument_tags_to_input:
- # create bio-encoded tags for the arguments
- # using arg_start_indices, arg_end_indices, and marker_ids_with_positions_sorted
- argument_spans = [
- (
- arg_start_indices[self.argument_role2idx[arg.role]],
- arg_end_indices[self.argument_role2idx[arg.role]],
- arg.role,
- )
- for marker_id, token_position, arg, marker_type in marker_ids_with_positions_sorted
- ]
- argument_tag_ids = bio_encode_spans(
- spans=argument_spans,
- total_length=len(input_ids_with_markers),
- label2idx=self.argument_role2idx,
- )
- inputs["argument_tags"] = argument_tag_ids
-
- if entity_tags is not None:
- inputs["entity_tags"] = entity_tags
-
- task_encodings.append(
- TaskEncoding(
- document=document,
- inputs=inputs,
- metadata=({"candidate_annotation": rel}),
- )
- )
-
- self.collect_relation("used", rel)
-
- return task_encodings
-
- def _maybe_log_example(
- self,
- task_encoding: TaskEncodingType,
- target: TargetEncodingType,
- ):
- """Maybe log the example."""
-
- # log the first n examples
- if self._logged_examples_counter < self.log_first_n_examples:
- input_ids = task_encoding.inputs["input_ids"]
- tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
- target_labels = [self.id_to_label[label_id] for label_id in target]
- logger.info("*** Example ***")
- logger.info("doc id: %s", task_encoding.document.id)
- logger.info("tokens: %s", " ".join([str(x) for x in tokens]))
- logger.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
- logger.info("Expected label: %s (ids = %s)", target_labels, target)
-
- self._logged_examples_counter += 1
-
- def encode_target(
- self,
- task_encoding: TaskEncodingType,
- ) -> TargetEncodingType:
- candidate_annotation = task_encoding.metadata["candidate_annotation"]
- if isinstance(candidate_annotation, (BinaryRelation, NaryRelation)):
- labels = [candidate_annotation.label]
- else:
- raise NotImplementedError(
- f"encoding the target with a candidate_annotation of another type than BinaryRelation or"
- f"NaryRelation is not yet supported. candidate_annotation has the type: "
- f"{type(candidate_annotation)}"
- )
- target = [self.label_to_id[label] for label in labels]
-
- self._maybe_log_example(task_encoding=task_encoding, target=target)
-
- return target
-
- def unbatch_output(self, model_output: ModelTargetType) -> Sequence[TaskOutputType]:
- unbatched_output = []
- if self.multi_label:
- raise NotImplementedError
- else:
- label_ids = model_output["labels"].detach().cpu().tolist()
- probabilities = model_output["probabilities"].detach().cpu().tolist()
- for batch_idx in range(len(label_ids)):
- label_id = label_ids[batch_idx]
- result: TaskOutputType = {
- "labels": [self.id_to_label[label_id]],
- "probabilities": [probabilities[batch_idx][label_id]],
- }
- unbatched_output.append(result)
-
- return unbatched_output
-
- def create_annotations_from_output(
- self,
- task_encoding: TaskEncodingType,
- task_output: TaskOutputType,
- ) -> Iterator[Tuple[str, Union[BinaryRelation, MultiLabeledBinaryRelation, NaryRelation]]]:
- candidate_annotation = task_encoding.metadata["candidate_annotation"]
- new_annotation: Union[BinaryRelation, MultiLabeledBinaryRelation, NaryRelation]
- if self.multi_label:
- raise NotImplementedError
- else:
- label = task_output["labels"][0]
- probability = (
- task_output["probabilities"][0] if "probabilities" in task_output else 1.0
- )
- if isinstance(candidate_annotation, BinaryRelation):
- head = candidate_annotation.head
- tail = candidate_annotation.tail
- # Reverse predicted reversed relations back. Serialization will remove any duplicated relations.
- if self.add_reversed_relations:
- # TODO: add test case for this
- if label.endswith(self.reversed_relation_label_suffix):
- label = label[: -len(self.reversed_relation_label_suffix)]
- head, tail = tail, head
- # If the predicted label is symmetric, we sort the arguments by its center.
- elif label in self.symmetric_relations and self.reverse_symmetric_relations:
- if not (isinstance(head, Span) and isinstance(tail, Span)):
- raise ValueError(
- f"the taskmodule expects the relation arguments of the candidate_annotation"
- f"to be of type Span, but got head of type: {type(head)} and tail of type: "
- f"{type(tail)}"
- )
- # use a unique order for the arguments: sort by start and end positions
- head, tail = sorted([head, tail], key=lambda span: (span.start, span.end))
- new_annotation = BinaryRelation(
- head=head, tail=tail, label=label, score=probability
- )
- elif isinstance(candidate_annotation, NaryRelation):
- # TODO: add test case for this
- if self.add_reversed_relations:
- raise ValueError("can not reverse a NaryRelation")
- new_annotation = NaryRelation(
- arguments=candidate_annotation.arguments,
- roles=candidate_annotation.roles,
- label=label,
- score=probability,
- )
- else:
- raise NotImplementedError(
- f"creating a new annotation from a candidate_annotation of another type than BinaryRelation is "
- f"not yet supported. candidate_annotation has the type: {type(candidate_annotation)}"
- )
-
- new_annotation_args = get_relation_argument_spans_and_roles(new_annotation)
- arg_roles, arg_spans = zip(*new_annotation_args)
- arg_labels = tuple(getattr(ann, "label") for ann in arg_spans)
-
- # Create annotation only if 1. and 2. are fulfilled:
- if (
- # 1. the label is not the no-relation-label,
- label != self.none_label
- # or we did not create candidate relations,
- or not self.add_candidate_relations
- ) and (
- # 2. the argument_and_relation_type_whitelist is not set,
- self.argument_and_relation_type_whitelist is None
- # or the label and argument types are in the whitelist
- or arg_labels in self.argument_and_relation_type_whitelist.get(label, {})
- ):
- yield self.relation_annotation, new_annotation
-
- def _get_global_attention(self, input_ids: torch.LongTensor) -> torch.LongTensor:
- # we want to have global attention on all marker tokens and the cls token
- positive_token_ids = list(self.argument_markers_to_id.values()) + [
- self.tokenizer.cls_token_id
- ]
- global_attention_mask = construct_mask(
- input_ids=input_ids, positive_ids=positive_token_ids
- )
- return global_attention_mask
-
- def collate(
- self, task_encodings: Sequence[TaskEncodingType]
- ) -> Tuple[ModelInputType, Optional[ModelTargetType]]:
- input_features = [
- {"input_ids": task_encoding.inputs["input_ids"]} for task_encoding in task_encodings
- ]
-
- inputs: Dict[str, torch.LongTensor] = self.tokenizer.pad(
- input_features,
- padding=self.padding,
- max_length=self.max_length,
- pad_to_multiple_of=self.pad_to_multiple_of,
- return_tensors="pt",
- )
- if self.add_argument_tags_to_input:
- argument_tags = [
- {"input_ids": task_encoding.inputs["argument_tags"]}
- for task_encoding in task_encodings
- ]
- argument_tags_padded = self.tokenizer.pad(
- argument_tags,
- padding=self.padding,
- max_length=self.max_length,
- pad_to_multiple_of=self.pad_to_multiple_of,
- return_tensors="pt",
- )
- # increase all values by 1 because 0 is used for padding
- inputs["argument_tags"] = argument_tags_padded["input_ids"] + 1
- # overwrite padding with 0
- inputs["argument_tags"][argument_tags_padded["attention_mask"] == 0] = 0
-
- if self.add_entity_tags_to_input:
- entity_tags = [
- {"input_ids": task_encoding.inputs["entity_tags"]}
- for task_encoding in task_encodings
- ]
- entity_tags_padded = self.tokenizer.pad(
- entity_tags,
- padding=self.padding,
- max_length=self.max_length,
- pad_to_multiple_of=self.pad_to_multiple_of,
- return_tensors="pt",
- )
- # increase all values by 1 because 0 is used for padding
- inputs["entity_tags"] = entity_tags_padded["input_ids"] + 1
- # overwrite padding with 0
- inputs["entity_tags"][entity_tags_padded["attention_mask"] == 0] = 0
-
- if self.add_argument_indices_to_input:
- inputs["pooler_start_indices"] = torch.tensor(
- [task_encoding.inputs["pooler_start_indices"] for task_encoding in task_encodings]
- ).to(torch.long)
- inputs["pooler_end_indices"] = torch.tensor(
- [task_encoding.inputs["pooler_end_indices"] for task_encoding in task_encodings]
- ).to(torch.long)
-
- if self.add_global_attention_mask_to_input:
- inputs["global_attention_mask"] = self._get_global_attention(
- input_ids=inputs["input_ids"]
- )
-
- if not task_encodings[0].has_targets:
- return inputs, None
-
- target_list: List[TargetEncodingType] = [
- task_encoding.targets for task_encoding in task_encodings
- ]
- targets = torch.tensor(target_list, dtype=torch.int64)
-
- if not self.multi_label:
- targets = targets.flatten()
-
- return inputs, {"labels": targets}
-
- def configure_model_metric(self, stage: str) -> MetricCollection:
- if self.label_to_id is None:
- raise ValueError(
- "The taskmodule has not been prepared yet, so label_to_id is not known. "
- "Please call taskmodule.prepare(documents) before configuring the model metric "
- "or pass the labels to the taskmodule constructor an call taskmodule.post_prepare()."
- )
- # we use the length of label_to_id because that contains the none_label (in contrast to labels)
- labels = [self.id_to_label[i] for i in range(len(self.label_to_id))]
- common_metric_kwargs = {
- "num_classes": len(labels),
- "task": "multilabel" if self.multi_label else "multiclass",
- }
- return MetricCollection(
- {
- "with_tn": WrappedMetricWithPrepareFunction(
- metric=MetricCollection(
- {
- "micro/f1": F1Score(average="micro", **common_metric_kwargs),
- "macro/f1": F1Score(average="macro", **common_metric_kwargs),
- "f1_per_label": ClasswiseWrapper(
- F1Score(average=None, **common_metric_kwargs),
- labels=labels,
- postfix="/f1",
- ),
- }
- ),
- prepare_function=_get_labels,
- ),
- # We can not easily calculate the macro f1 here, because
- # F1Score with average="macro" would still include the none_label.
- "micro/f1_without_tn": WrappedMetricWithPrepareFunction(
- metric=F1Score(average="micro", **common_metric_kwargs),
- prepare_together_function=partial(
- _get_labels_together_remove_none_label,
- none_idx=self.label_to_id[self.none_label],
- ),
- ),
- }
- )
diff --git a/src/pie_modules/taskmodules/text_to_text.py b/src/pie_modules/taskmodules/text_to_text.py
deleted file mode 100644
index 1a90a7854..000000000
--- a/src/pie_modules/taskmodules/text_to_text.py
+++ /dev/null
@@ -1,458 +0,0 @@
-import dataclasses
-import logging
-from functools import partial
-from typing import (
- Any,
- Dict,
- Iterator,
- List,
- Optional,
- Sequence,
- Set,
- Tuple,
- Type,
- Union,
-)
-
-import torch
-from pie_core import (
- Annotation,
- AnnotationLayer,
- Document,
- TaskEncoding,
- TaskModule,
-)
-from pie_core.taskmodule import (
- InputEncoding,
- ModelBatchOutput,
- TargetEncoding,
- TaskBatchEncoding,
-)
-from pie_core.utils.hydra import resolve_type
-from torchmetrics import Metric
-from transformers import AutoTokenizer, PreTrainedTokenizer
-from typing_extensions import TypeAlias
-
-from pie_modules.annotations import AnnotationWithText
-from pie_modules.document.processing import (
- token_based_document_to_text_based,
- tokenize_document,
-)
-from pie_modules.documents import TextBasedDocument, TokenBasedDocument
-
-from .common import BatchableMixin, get_first_occurrence_index
-from .metrics import WrappedMetricWithPrepareFunction
-
-logger = logging.getLogger(__name__)
-
-
-DocumentType: TypeAlias = TextBasedDocument
-
-
-@dataclasses.dataclass
-class InputEncodingType(BatchableMixin):
- input_ids: List[int]
- attention_mask: List[int]
-
-
-@dataclasses.dataclass
-class TargetEncodingType(BatchableMixin):
- labels: List[int]
- # this is optional because we use the same type for TaskOutputType, which does not have this field
- decoder_attention_mask: Optional[List[int]] = None
-
-
-TaskEncodingType: TypeAlias = TaskEncoding[
- DocumentType,
- InputEncodingType,
- TargetEncodingType,
-]
-TaskOutputType: TypeAlias = TargetEncodingType
-
-
-# we use a custom un-batch function for metrics, because the text metrics such as ROUGEScore metric expects
-# strings for input and target
-def unbatch_and_untokenize(
- batch: ModelBatchOutput, taskmodule: "TextToTextTaskModule"
-) -> Sequence[str]:
- unbatched = taskmodule.unbatch_output(batch)
- texts = [
- taskmodule.tokenizer.decode(encoding.labels, skip_special_tokens=True)
- for encoding in unbatched
- ]
- return texts
-
-
-@TaskModule.register()
-class TextToTextTaskModule(
- TaskModule[
- DocumentType,
- InputEncoding,
- TargetEncoding,
- TaskBatchEncoding,
- ModelBatchOutput,
- TaskOutputType,
- ],
-):
- """A PIE task module for text-to-text tasks. It works with simple text annotations, e.g.
- abstractive summaries, as target annotations.
-
- It can also be used with additional guidance annotations, e.g. questions for generative question answering, in
- which case the text of the guidance annotation is prepended to the input text.
-
- Args:
- tokenizer_name_or_path: The name (Huggingface Hub model identifier) or local path of the tokenizer to use.
- document_type: The type of the input document. Must be a string that resolves to a subclass of
- TextBasedDocument, e.g. "pie_modules.documents.TextDocumentWithAbstractiveSummary" for abstractive
- summarization.
- tokenized_document_type: The type of the tokenized document. Must be a string that resolves to a
- subclass of TokenBasedDocument, e.g. "pie_modules.documents.TokenDocumentWithAbstractiveSummary" for
- abstractive summarization.
- target_layer: The name of the annotation layer that contains the target annotations, e.g. "abstractive_summary"
- for abstractive summarization.
- target_annotation_type: The type of the target annotations. Must be a string that resolves to a subclass
- of AnnotationWithText, e.g. "pie_modules.annotations.AbstractiveSummary" for abstractive summarization.
- guidance_layer: The name of the annotation layer that contains the guidance annotations. If set, the text of
- the guidance annotation is prepended to the input text.
- guidance_annotation_field: The name of the field in the target annotations that contains the guidance
- annotation. Required if guidance_layer is defined to attach the guidance annotation to the newly created
- target annotation.
- text_metric_type: The type of the text metric to use for evaluation. Must be a string that resolves to a
- subclass of Metric, e.g. "torchmetrics.text.ROUGEScore" for ROUGE score.
- tokenizer_init_kwargs: Additional keyword arguments that are passed to the tokenizer constructor.
- tokenizer_kwargs: Additional keyword arguments that are passed when calling the tokenizer.
- partition_layer_name: The name of the annotation layer that contains the partitions. If set, the partitions
- will be used to split the input text into multiple parts which are then tokenized separately. This can be
- used to split long documents into multiple parts to avoid exceeding the maximum input length of the
- tokenizer / model.
- annotation_field_mapping: A mapping from input document annotation layer names to layer names defined in the
- document_type / tokenized_document_type. This can be used if the actual input documents have different
- annotation layer names than the provided document_type / tokenized_document_type.
- log_first_n_examples: The number of examples to log. If set to a positive integer n, the first n examples will
- be logged. This can be used to check if the input and target encodings are as expected.
- """
-
- def __init__(
- self,
- tokenizer_name_or_path: str,
- document_type: str,
- tokenized_document_type: str,
- target_layer: str,
- target_annotation_type: str,
- guidance_layer: Optional[str] = None,
- guidance_annotation_field: Optional[str] = None,
- text_metric_type: Optional[str] = None,
- tokenizer_init_kwargs: Optional[Dict[str, Any]] = None,
- tokenizer_kwargs: Optional[Dict[str, Any]] = None,
- partition_layer_name: Optional[str] = None,
- annotation_field_mapping: Optional[Dict[str, str]] = None,
- log_first_n_examples: Optional[int] = None,
- **kwargs,
- ):
- super().__init__(**kwargs)
- self.save_hyperparameters()
-
- self.target_layer = target_layer
- self.guidance_layer = guidance_layer
- self.target_annotation_type: Type[AnnotationWithText] = resolve_type(
- target_annotation_type, expected_super_type=AnnotationWithText
- )
- self.guidance_annotation_field = guidance_annotation_field
- self.text_metric_type: Optional[Metric] = None
- if text_metric_type is not None:
- self.text_metric_type = resolve_type(text_metric_type, expected_super_type=Metric)
-
- # tokenization
- self._document_type: Type[TextBasedDocument] = resolve_type(
- document_type, expected_super_type=TextBasedDocument
- )
- self._tokenized_document_type: Type[TokenBasedDocument] = resolve_type(
- tokenized_document_type, expected_super_type=TokenBasedDocument
- )
- self.tokenizer_name_or_path = tokenizer_name_or_path
- self.tokenizer_kwargs = tokenizer_kwargs or {}
- self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
- tokenizer_name_or_path,
- **(tokenizer_init_kwargs or {}),
- )
- self.annotation_field_mapping = annotation_field_mapping or dict()
- self.partition_layer_name = partition_layer_name
-
- # target encoding
- self.pad_values = {
- "input_ids": self.tokenizer.pad_token_id,
- "attention_mask": 0,
- "labels": self.tokenizer.pad_token_id,
- "decoder_attention_mask": 0,
- }
- self.dtypes = {
- "input_ids": torch.int64,
- "attention_mask": torch.int64,
- "labels": torch.int64,
- "decoder_attention_mask": torch.int64,
- }
-
- # logging
- self.log_first_n_examples = log_first_n_examples
-
- @property
- def document_type(self) -> Type[TextBasedDocument]:
- return self._document_type
-
- @property
- def tokenized_document_type(self) -> Type[TokenBasedDocument]:
- return self._tokenized_document_type
-
- @property
- def layer_names(self) -> List[str]:
- return [self.target_layer]
-
- def get_mapped_layer(self, document: Document, layer_name: str) -> AnnotationLayer:
- if layer_name in self.annotation_field_mapping:
- layer_name = self.annotation_field_mapping[layer_name]
- return document[layer_name]
-
- @property
- def generation_config(self) -> Dict[str, Any]:
- return {}
-
- def maybe_log_example(
- self,
- task_encoding: TaskEncodingType,
- targets: Optional[TargetEncodingType] = None,
- ) -> None:
- if self.log_first_n_examples is not None and self.log_first_n_examples > 0:
- inputs = task_encoding.inputs
-
- logger.info(f"input_ids: {inputs.input_ids}")
- logger.info(f"attention_mask: {inputs.attention_mask}")
- if targets is not None or task_encoding.has_targets:
- targets = targets or task_encoding.targets
- logger.info(f"labels: {targets.labels}")
- self.log_first_n_examples -= 1
-
- def warn_only_once(self, message: str) -> None:
- if not hasattr(self, "_warned"):
- self._warned: Set[str] = set()
- if message not in self._warned:
- logger.warning(f"{message} (This warning will only be shown once)")
- self._warned.add(message)
-
- def encode_annotations(
- self,
- layers: Dict[str, AnnotationLayer],
- metadata: Optional[Dict[str, Any]] = None,
- ) -> TargetEncodingType:
- target_annotations = []
- guidance_annotation = (
- metadata.get("guidance_annotation", None) if metadata is not None else None
- )
- if guidance_annotation is not None:
- if self.guidance_annotation_field is None:
- raise ValueError(
- "guidance_annotation is available, but guidance_annotation_field is not set"
- )
- # filter annotations that belong to the guidance_annotation
- for target_annotation in layers[self.target_layer]:
- current_guidance_annotation = getattr(
- target_annotation, self.guidance_annotation_field
- )
- if current_guidance_annotation == guidance_annotation:
- target_annotations.append(target_annotation)
- else:
- target_annotations = layers[self.target_layer]
-
- if len(target_annotations) == 0:
- raise ValueError(f"target_annotations {self.target_layer} contains no annotation")
- elif len(target_annotations) > 1:
- self.warn_only_once(
- f"target_annotations {self.target_layer} contains more than one annotation, "
- f"but only the first one will be used"
- )
- annotation = target_annotations[0]
- if isinstance(annotation, self.target_annotation_type):
- text = target_annotations[0].text
- else:
- raise ValueError(
- f"target_annotations {self.target_layer} contains an annotation of type {type(annotation)}, "
- f"but expected {self.target_annotation_type}"
- )
- encoding = self.tokenizer(text)
- return TargetEncodingType(
- labels=encoding["input_ids"], decoder_attention_mask=encoding["attention_mask"]
- )
-
- def decode_annotations(
- self, encoding: TaskOutputType, metadata: Optional[Dict[str, Any]] = None
- ) -> Tuple[Dict[str, List[Annotation]], Any]:
- text = self.tokenizer.decode(encoding.labels, skip_special_tokens=True)
- annotation_kwargs = {}
- if self.guidance_annotation_field is not None:
- if metadata is None:
- raise ValueError(
- "metadata is required to decode annotations with guidance_annotation_field"
- )
- guidance_annotation = metadata.get("guidance_annotation", None)
- if guidance_annotation is not None:
- if self.guidance_annotation_field is None:
- raise ValueError(
- "guidance_annotation is available, but guidance_annotation_field is not set"
- )
- annotation_kwargs[self.guidance_annotation_field] = guidance_annotation
-
- decoded_layers = {
- self.target_layer: [self.target_annotation_type(text=text, **annotation_kwargs)]
- }
- # no error collection yet
- errors: Dict[str, Any] = {}
- return decoded_layers, errors
-
- def tokenize_document(
- self, document: DocumentType, source_text: Optional[str] = None
- ) -> List[TokenBasedDocument]:
- field_mapping = dict(self.annotation_field_mapping)
- if self.partition_layer_name is not None:
- field_mapping[self.partition_layer_name] = "labeled_partitions"
- partition_layer = "labeled_partitions"
- else:
- partition_layer = None
- casted_document = document.as_type(self.document_type, field_mapping=field_mapping)
-
- tokenizer_kwargs = dict(self.tokenizer_kwargs)
- if source_text is not None:
- tokenizer_kwargs["text"] = source_text
- tokenized_docs = tokenize_document(
- casted_document,
- tokenizer=self.tokenizer,
- result_document_type=self.tokenized_document_type,
- partition_layer=partition_layer,
- **tokenizer_kwargs,
- )
- for idx, tokenized_doc in enumerate(tokenized_docs):
- tokenized_doc.id = f"{document.id}-tokenized-{idx+1}-of-{len(tokenized_docs)}"
-
- return tokenized_docs
-
- def encode_input(
- self, document: DocumentType, is_training: bool = False
- ) -> Optional[Union[TaskEncodingType, Sequence[TaskEncodingType]]]:
- task_encodings: List[TaskEncodingType] = []
- if self.guidance_layer is None:
- guidance_annotations = [None]
- else:
- guidance_annotations = document[self.guidance_layer]
- for guidance_annotation in guidance_annotations:
- source_text = None
- if guidance_annotation is not None:
- # Here could also more sophisticated logic be implemented
- source_text = guidance_annotation.text
- tokenized_docs = self.tokenize_document(document, source_text=source_text)
- for tokenized_doc in tokenized_docs:
- tokenizer_encoding = tokenized_doc.metadata["tokenizer_encoding"]
- task_encodings.append(
- TaskEncoding(
- document=document,
- inputs=InputEncodingType(
- input_ids=tokenizer_encoding.ids,
- attention_mask=tokenizer_encoding.attention_mask,
- ),
- metadata={
- "tokenized_document": tokenized_doc,
- "guidance_annotation": guidance_annotation,
- },
- )
- )
-
- return task_encodings
-
- def encode_target(self, task_encoding: TaskEncodingType) -> Optional[TargetEncodingType]:
- document = task_encoding.metadata["tokenized_document"]
- guidance_annotation = task_encoding.metadata["guidance_annotation"]
-
- layers = {
- layer_name: self.get_mapped_layer(document, layer_name=layer_name)
- for layer_name in self.layer_names
- }
- result = self.encode_annotations(
- layers=layers,
- metadata={**task_encoding.metadata, "guidance_annotation": guidance_annotation},
- )
-
- self.maybe_log_example(task_encoding=task_encoding, targets=result)
- return result
-
- def collate(self, task_encodings: Sequence[TaskEncodingType]) -> TaskBatchEncoding:
- if len(task_encodings) == 0:
- raise ValueError("no task_encodings available")
- inputs = InputEncodingType.batch(
- values=[x.inputs for x in task_encodings],
- dtypes=self.dtypes,
- pad_values=self.pad_values,
- )
-
- targets = None
- if task_encodings[0].has_targets:
- targets = TargetEncodingType.batch(
- values=[x.targets for x in task_encodings],
- dtypes=self.dtypes,
- pad_values=self.pad_values,
- )
-
- return inputs, targets
-
- def unbatch_output(self, model_output: ModelBatchOutput) -> Sequence[TaskOutputType]:
- labels = model_output["labels"]
- batch_size = labels.size(0)
-
- # We use the position after the first eos token as the seq_len.
- # Note that, if eos_id is not in model_output for a given batch item, the result will be
- # model_output.size(1) + 1 (i.e. seq_len + 1) for that batch item. This is fine, because we use the
- # seq_lengths just to truncate the output and want to keep everything if eos_id is not present.
- seq_lengths = get_first_occurrence_index(labels, self.tokenizer.eos_token_id) + 1
-
- result = [
- TaskOutputType(labels[i, : seq_lengths[i]].to(device="cpu").tolist())
- for i in range(batch_size)
- ]
- return result
-
- def create_annotations_from_output(
- self,
- task_encoding: TaskEncodingType,
- task_output: TaskOutputType,
- ) -> Iterator[Tuple[str, Annotation]]:
- layers, errors = self.decode_annotations(
- encoding=task_output, metadata=task_encoding.metadata
- )
- tokenized_document = task_encoding.metadata["tokenized_document"]
-
- # Note: token_based_document_to_text_based() does not yet consider predictions, so we need to clear
- # the main annotations and attach the predictions to that
- for layer_name, annotations in layers.items():
- layer = self.get_mapped_layer(tokenized_document, layer_name=layer_name)
- layer.clear()
- layer.extend(annotations)
-
- untokenized_document = token_based_document_to_text_based(
- tokenized_document, result_document_type=self.document_type
- )
-
- for layer_name in layers:
- annotations = self.get_mapped_layer(untokenized_document, layer_name=layer_name)
- for annotation in annotations:
- yield layer_name, annotation.copy()
-
- def configure_model_generation(self) -> Optional[Dict[str, Any]]:
- # we do not set any overrides here, because we want to use the default generation config as
- # it is derived from the Huggingface base model config.json
- return {}
-
- def configure_model_metric(self, stage: str) -> Optional[Metric]:
- if self.text_metric_type is None:
- return None
-
- return WrappedMetricWithPrepareFunction(
- metric=self.text_metric_type(),
- prepare_function=partial(unbatch_and_untokenize, taskmodule=self),
- prepare_does_unbatch=True,
- )
diff --git a/src/pie_modules/utils/__init__.py b/src/pie_modules/utils/__init__.py
index c8017711c..e69de29bb 100644
--- a/src/pie_modules/utils/__init__.py
+++ b/src/pie_modules/utils/__init__.py
@@ -1,3 +0,0 @@
-# backwards compatibility
-from .dictionary import flatten_dict, list_of_dicts2dict_of_lists
-from .hydra import resolve_type
diff --git a/src/pie_modules/utils/dictionary.py b/src/pie_modules/utils/dictionary.py
deleted file mode 100644
index 34cb20be0..000000000
--- a/src/pie_modules/utils/dictionary.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# backwards compatibility
-from pie_core.utils.dictionary import flatten_dict_s as flatten_dict
-from pie_core.utils.dictionary import list_of_dicts2dict_of_lists
diff --git a/src/pie_modules/utils/hydra.py b/src/pie_modules/utils/hydra.py
deleted file mode 100644
index 21e55eeb9..000000000
--- a/src/pie_modules/utils/hydra.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# backwards compatibility
-from pie_core.utils.hydra import resolve_target, resolve_type
diff --git a/src/pie_modules/utils/tokenization.py b/src/pie_modules/utils/tokenization.py
deleted file mode 100644
index 85e89d448..000000000
--- a/src/pie_modules/utils/tokenization.py
+++ /dev/null
@@ -1,36 +0,0 @@
-from typing import TypeVar
-
-from transformers import BatchEncoding
-
-from pie_modules.annotations import Span
-
-S = TypeVar("S", bound=Span)
-
-
-class SpanNotAlignedWithTokenException(Exception):
- def __init__(self, span):
- self.span = span
-
-
-def get_aligned_token_span(encoding: BatchEncoding, char_span: S) -> S:
- # find the start
- token_start = None
- token_end_before = None
- char_start = None
- for idx in range(char_span.start, char_span.end):
- token_start = encoding.char_to_token(idx)
- if token_start is not None:
- char_start = idx
- break
-
- if char_start is None:
- raise SpanNotAlignedWithTokenException(span=char_span)
- for idx in range(char_span.end - 1, char_start - 1, -1):
- token_end_before = encoding.char_to_token(idx)
- if token_end_before is not None:
- break
-
- if token_start is None or token_end_before is None:
- raise SpanNotAlignedWithTokenException(span=char_span)
-
- return char_span.copy(start=token_start, end=token_end_before + 1)
diff --git a/tests/document/processing/test_relation_argument_sorter.py b/tests/document/processing/test_relation_argument_sorter.py
index fec8e1200..2068384ea 100644
--- a/tests/document/processing/test_relation_argument_sorter.py
+++ b/tests/document/processing/test_relation_argument_sorter.py
@@ -104,7 +104,7 @@ def test_get_args_wrong_type(document_with_nary_relation):
== "relation NaryRelation(arguments=(LabeledSpan(start=0, end=8, label='PER', score=1.0), "
"LabeledSpan(start=18, end=19, label='ORG', score=1.0), LabeledSpan(start=33, end=34, "
"label='ORG', score=1.0)), roles=('person', 'worksAt', 'founded'), label='event', score=1.0) "
- "has unknown type [], cannot get arguments from it"
+ "has unknown type [], cannot get arguments from it"
)
@@ -122,7 +122,7 @@ def test_construct_relation_with_new_args_wrong_type(document_with_nary_relation
== "original relation NaryRelation(arguments=(LabeledSpan(start=0, end=8, label='PER', score=1.0), "
"LabeledSpan(start=18, end=19, label='ORG', score=1.0), LabeledSpan(start=33, end=34, label='ORG', "
"score=1.0)), roles=('person', 'worksAt', 'founded'), label='event', score=1.0) has unknown type "
- "[], cannot reconstruct it with new arguments"
+ "[], cannot reconstruct it with new arguments"
)
diff --git a/tests/document/processing/test_tokenization.py b/tests/document/processing/test_tokenization.py
index c940be177..c6fdacd17 100644
--- a/tests/document/processing/test_tokenization.py
+++ b/tests/document/processing/test_tokenization.py
@@ -4,7 +4,6 @@
import pytest
from pie_core import Annotation, AnnotationLayer, Document, annotation_field
-from transformers import AutoTokenizer, PreTrainedTokenizer
from pie_modules.annotations import (
BinaryRelation,
@@ -16,7 +15,6 @@
from pie_modules.document.processing import (
text_based_document_to_token_based,
token_based_document_to_text_based,
- tokenize_document,
)
from pie_modules.document.processing.tokenization import find_token_offset_mapping
from pie_modules.documents import TextBasedDocument, TokenBasedDocument
@@ -229,11 +227,6 @@ def _test_token_document_with_multi_spans(doc):
]
-@pytest.fixture(scope="module")
-def tokenizer() -> PreTrainedTokenizer:
- return AutoTokenizer.from_pretrained("bert-base-cased")
-
-
def test_find_token_offset_mapping(text_document, token_document):
token_offset_mapping = find_token_offset_mapping(
text=text_document.text, tokens=list(token_document.tokens)
@@ -559,7 +552,7 @@ class WrongAnnotationType(TextBasedDocument):
assert (
str(excinfo.value)
== "can not convert layers that target the text but contain non-span annotations, "
- "but found "
+ "but found "
)
@@ -650,370 +643,5 @@ class WrongAnnotationType(TokenBasedDocument):
assert (
str(excinfo.value)
== "can not convert layers that target the tokens but contain non-span annotations, "
- "but found "
- )
-
-
-def test_tokenize_document(text_document, tokenizer):
- added_annotations = []
- tokenized_docs = tokenize_document(
- text_document,
- tokenizer=tokenizer,
- result_document_type=TokenizedTestDocument,
- added_annotations=added_annotations,
- )
- assert len(tokenized_docs) == 1
- tokenized_doc = tokenized_docs[0]
-
- # check (de-)serialization
- tokenized_doc.copy()
-
- assert (
- tokenized_doc.metadata["text"]
- == text_document.text
- == "First sentence. Entity M works at N. And it founded O."
- )
- assert tokenized_doc.tokens == (
- "[CLS]",
- "First",
- "sentence",
- ".",
- "En",
- "##ti",
- "##ty",
- "M",
- "works",
- "at",
- "N",
- ".",
- "And",
- "it",
- "founded",
- "O",
- ".",
- "[SEP]",
- )
- assert len(tokenized_doc.sentences) == len(text_document.sentences) == 3
- sentences = [str(sentence) for sentence in tokenized_doc.sentences]
- assert sentences == [
- "('First', 'sentence', '.')",
- "('En', '##ti', '##ty', 'M', 'works', 'at', 'N', '.')",
- "('And', 'it', 'founded', 'O', '.')",
- ]
- assert len(tokenized_doc.entities) == len(text_document.entities) == 4
- entities = [str(entity) for entity in tokenized_doc.entities]
- assert entities == ["('En', '##ti', '##ty', 'M')", "('N',)", "('it',)", "('O',)"]
- assert len(tokenized_doc.relations) == len(text_document.relations) == 2
- relation_tuples = [
- (str(rel.head), rel.label, str(rel.tail)) for rel in tokenized_doc.relations
- ]
- assert relation_tuples == [
- ("('En', '##ti', '##ty', 'M')", "per:employee_of", "('N',)"),
- ("('it',)", "per:founder", "('O',)"),
- ]
-
- assert len(added_annotations) == 1
- first_added_annotations = added_annotations[0]
- _assert_added_annotations(text_document, tokenized_doc, first_added_annotations)
-
-
-def test_tokenize_document_max_length(text_document, tokenizer, caplog):
- added_annotations = []
- caplog.clear()
- with caplog.at_level("WARNING"):
- tokenized_docs = tokenize_document(
- text_document,
- tokenizer=tokenizer,
- result_document_type=TokenizedTestDocument,
- # max_length is set to 10, so the document is split into two parts
- strict_span_conversion=False,
- max_length=10,
- return_overflowing_tokens=True,
- added_annotations=added_annotations,
- )
- assert len(caplog.records) == 1
- assert (
- caplog.records[0].message
- == "could not convert all annotations from document with id=None to token based documents, missed annotations "
- "(disable this message with verbose=False):\n"
- "{\n"
- ' "relations": "{BinaryRelation(head=LabeledSpan(start=16, end=24, label=\'per\', score=1.0), '
- "tail=LabeledSpan(start=34, end=35, label='org', score=1.0), label='per:employee_of', score=1.0)}\",\n"
- ' "sentences": "{Span(start=16, end=36)}"\n'
- "}"
- )
- assert len(tokenized_docs) == 2
- assert len(added_annotations) == 2
- tokenized_doc = tokenized_docs[0]
-
- # check (de-)serialization
- tokenized_doc.copy()
-
- assert (
- tokenized_doc.metadata["text"]
- == text_document.text
- == "First sentence. Entity M works at N. And it founded O."
- )
- assert tokenized_doc.tokens == (
- "[CLS]",
- "First",
- "sentence",
- ".",
- "En",
- "##ti",
- "##ty",
- "M",
- "works",
- "[SEP]",
- )
- assert len(tokenized_doc.sentences) == 1
- sentences = [str(sentence) for sentence in tokenized_doc.sentences]
- assert sentences == ["('First', 'sentence', '.')"]
- assert len(tokenized_doc.entities) == 1
- entities = [str(entity) for entity in tokenized_doc.entities]
- assert entities == ["('En', '##ti', '##ty', 'M')"]
- assert len(tokenized_doc.relations) == 0
- # check annotation mapping
- current_added_annotations = added_annotations[0]
- # no relations are added in the first tokenized document
- assert set(current_added_annotations) == {"sentences", "entities"}
- # check sentences
- sentence_mapping = {
- k.resolve(): v.resolve() for k, v in current_added_annotations["sentences"].items()
- }
- assert sentence_mapping == {"First sentence.": ("First", "sentence", ".")}
- # check entities
- entity_mapping = {
- k.resolve(): v.resolve() for k, v in current_added_annotations["entities"].items()
- }
- assert entity_mapping == {("per", "Entity M"): ("per", ("En", "##ti", "##ty", "M"))}
-
- tokenized_doc = tokenized_docs[1]
-
- # check (de-)serialization
- tokenized_doc.copy()
-
- assert (
- tokenized_doc.metadata["text"]
- == text_document.text
- == "First sentence. Entity M works at N. And it founded O."
- )
- assert tokenized_doc.tokens == (
- "[CLS]",
- "at",
- "N",
- ".",
- "And",
- "it",
- "founded",
- "O",
- ".",
- "[SEP]",
- )
- assert len(tokenized_doc.sentences) == 1
- sentences = [str(sentence) for sentence in tokenized_doc.sentences]
- assert sentences == ["('And', 'it', 'founded', 'O', '.')"]
- assert len(tokenized_doc.entities) == 3
- entities = [str(entity) for entity in tokenized_doc.entities]
- assert entities == ["('N',)", "('it',)", "('O',)"]
- assert len(tokenized_doc.relations) == 1
- relation_tuples = [
- (str(rel.head), rel.label, str(rel.tail)) for rel in tokenized_doc.relations
- ]
- assert relation_tuples == [("('it',)", "per:founder", "('O',)")]
- # check annotation mapping
- current_added_annotations = added_annotations[1]
- assert set(current_added_annotations) == {"sentences", "entities", "relations"}
- # check sentences
- sentence_mapping = {
- k.resolve(): v.resolve() for k, v in current_added_annotations["sentences"].items()
- }
- assert sentence_mapping == {"And it founded O.": ("And", "it", "founded", "O", ".")}
- # check entities
- entity_mapping = {
- k.resolve(): v.resolve() for k, v in current_added_annotations["entities"].items()
- }
- assert entity_mapping == {
- ("org", "N"): ("org", ("N",)),
- ("per", "it"): ("per", ("it",)),
- ("org", "O"): ("org", ("O",)),
- }
- # check relations
- relation_mapping = {
- k.resolve(): v.resolve() for k, v in current_added_annotations["relations"].items()
- }
- assert relation_mapping == {
- ("per:founder", (("per", "it"), ("org", "O"))): (
- "per:founder",
- (("per", ("it",)), ("org", ("O",))),
- )
- }
-
-
-def test_tokenize_document_max_length_strict(text_document, tokenizer):
- with pytest.raises(ValueError) as excinfo:
- tokenize_document(
- text_document,
- tokenizer=tokenizer,
- result_document_type=TokenizedTestDocument,
- # max_length is set to 10, so the document is split into two parts
- strict_span_conversion=True,
- max_length=10,
- return_overflowing_tokens=True,
- )
- assert (
- str(excinfo.value)
- == "could not convert all annotations from document with id=None to token based documents, "
- "but strict_span_conversion is True, so raise an error, missed annotations:\n"
- "{\n"
- ' "relations": "{BinaryRelation(head=LabeledSpan(start=16, end=24, label=\'per\', score=1.0), '
- "tail=LabeledSpan(start=34, end=35, label='org', score=1.0), label='per:employee_of', score=1.0)}\",\n"
- ' "sentences": "{Span(start=16, end=36)}"\n'
- "}"
- )
-
-
-def test_tokenize_document_partition(text_document, tokenizer):
- added_annotations = []
- tokenized_docs = tokenize_document(
- text_document,
- tokenizer=tokenizer,
- result_document_type=TokenizedTestDocument,
- partition_layer="sentences",
- added_annotations=added_annotations,
- )
- assert len(tokenized_docs) == 3
- assert len(added_annotations) == 3
- tokenized_doc = tokenized_docs[0]
-
- # check (de-)serialization
- tokenized_doc.copy()
-
- assert (
- tokenized_doc.metadata["text"]
- == text_document.text
- == "First sentence. Entity M works at N. And it founded O."
+ "but found "
)
- assert tokenized_doc.tokens == ("[CLS]", "First", "sentence", ".", "[SEP]")
- assert len(tokenized_doc.sentences) == 1
- assert len(tokenized_doc.entities) == 0
- assert len(tokenized_doc.relations) == 0
-
- # check annotation mapping
- current_added_annotations = added_annotations[0]
- assert set(current_added_annotations) == {"sentences"}
- # check sentences
- sentence_mapping = {
- k.resolve(): v.resolve() for k, v in current_added_annotations["sentences"].items()
- }
- assert sentence_mapping == {"First sentence.": ("First", "sentence", ".")}
-
- tokenized_doc = tokenized_docs[1]
-
- # check (de-)serialization
- tokenized_doc.copy()
-
- assert (
- tokenized_doc.metadata["text"]
- == text_document.text
- == "First sentence. Entity M works at N. And it founded O."
- )
- assert tokenized_doc.tokens == (
- "[CLS]",
- "En",
- "##ti",
- "##ty",
- "M",
- "works",
- "at",
- "N",
- ".",
- "[SEP]",
- )
- assert len(tokenized_doc.sentences) == 1
- sentences = [str(sentence) for sentence in tokenized_doc.sentences]
- assert sentences == ["('En', '##ti', '##ty', 'M', 'works', 'at', 'N', '.')"]
- assert len(tokenized_doc.entities) == 2
- entities = [str(entity) for entity in tokenized_doc.entities]
- assert entities == ["('En', '##ti', '##ty', 'M')", "('N',)"]
- assert len(tokenized_doc.relations) == 1
- relation_tuples = [
- (str(rel.head), rel.label, str(rel.tail)) for rel in tokenized_doc.relations
- ]
- assert relation_tuples == [("('En', '##ti', '##ty', 'M')", "per:employee_of", "('N',)")]
-
- # check annotation mapping
- current_added_annotations = added_annotations[1]
- assert set(current_added_annotations) == {"sentences", "entities", "relations"}
- # check sentences
- sentence_mapping = {
- k.resolve(): v.resolve() for k, v in current_added_annotations["sentences"].items()
- }
- assert sentence_mapping == {
- "Entity M works at N.": ("En", "##ti", "##ty", "M", "works", "at", "N", ".")
- }
- # check entities
- entity_mapping = {
- k.resolve(): v.resolve() for k, v in current_added_annotations["entities"].items()
- }
- assert entity_mapping == {
- ("per", "Entity M"): ("per", ("En", "##ti", "##ty", "M")),
- ("org", "N"): ("org", ("N",)),
- }
- # check relations
- relation_mapping = {
- k.resolve(): v.resolve() for k, v in current_added_annotations["relations"].items()
- }
- assert relation_mapping == {
- ("per:employee_of", (("per", "Entity M"), ("org", "N"))): (
- "per:employee_of",
- (("per", ("En", "##ti", "##ty", "M")), ("org", ("N",))),
- )
- }
-
- tokenized_doc = tokenized_docs[2]
-
- # check (de-)serialization
- tokenized_doc.copy()
-
- assert (
- tokenized_doc.metadata["text"]
- == text_document.text
- == "First sentence. Entity M works at N. And it founded O."
- )
- assert tokenized_doc.tokens == ("[CLS]", "And", "it", "founded", "O", ".", "[SEP]")
- assert len(tokenized_doc.sentences) == 1
- sentences = [str(sentence) for sentence in tokenized_doc.sentences]
- assert sentences == ["('And', 'it', 'founded', 'O', '.')"]
- assert len(tokenized_doc.entities) == 2
- entities = [str(entity) for entity in tokenized_doc.entities]
- assert entities == ["('it',)", "('O',)"]
- assert len(tokenized_doc.relations) == 1
- relation_tuples = [
- (str(rel.head), rel.label, str(rel.tail)) for rel in tokenized_doc.relations
- ]
- assert relation_tuples == [("('it',)", "per:founder", "('O',)")]
-
- # check annotation mapping
- current_added_annotations = added_annotations[2]
- assert set(current_added_annotations) == {"sentences", "entities", "relations"}
- # check sentences
- sentence_mapping = {
- k.resolve(): v.resolve() for k, v in current_added_annotations["sentences"].items()
- }
- assert sentence_mapping == {"And it founded O.": ("And", "it", "founded", "O", ".")}
- # check entities
- entity_mapping = {
- k.resolve(): v.resolve() for k, v in current_added_annotations["entities"].items()
- }
- assert entity_mapping == {("per", "it"): ("per", ("it",)), ("org", "O"): ("org", ("O",))}
- # check relations
- relation_mapping = {
- k.resolve(): v.resolve() for k, v in current_added_annotations["relations"].items()
- }
- assert relation_mapping == {
- ("per:founder", (("per", "it"), ("org", "O"))): (
- "per:founder",
- (("per", ("it",)), ("org", ("O",))),
- )
- }
diff --git a/tests/metrics/test_relation_argument_distance_collector.py b/tests/metrics/test_relation_argument_distance_collector.py
index f51d6e819..2dc05ce90 100644
--- a/tests/metrics/test_relation_argument_distance_collector.py
+++ b/tests/metrics/test_relation_argument_distance_collector.py
@@ -82,107 +82,7 @@ def test_relation_argument_distance_collector_with_n_ary_relation():
}
-def test_relation_argument_distance_collector_with_tokenize():
- doc = TestDocument(
- text="This is the first entity. This is the second entity. And, this is the third entity."
- )
-
- doc.entities.append(LabeledSpan(start=0, end=25, label="entity"))
- doc.entities.append(LabeledSpan(start=26, end=52, label="entity"))
- doc.entities.append(LabeledSpan(start=53, end=83, label="entity"))
- doc.relations.append(
- BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="relation_label_1")
- )
- doc.relations.append(
- BinaryRelation(head=doc.entities[1], tail=doc.entities[2], label="relation_label_2")
- )
-
- @dataclasses.dataclass
- class TokenizedTestDocument(TokenBasedDocument):
- entities: AnnotationLayer[LabeledSpan] = annotation_field(target="tokens")
- relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities")
-
- statistic = RelationArgumentDistanceCollector(
- layer="relations",
- tokenize=True,
- tokenizer="bert-base-uncased",
- tokenized_document_type=TokenizedTestDocument,
- )
- values = statistic(doc)
- assert values == {
- "ALL": {"len": 4, "mean": 13.0, "std": 1.0, "min": 12.0, "max": 14.0},
- "relation_label_1": {"len": 2, "mean": 12.0, "std": 0.0, "min": 12.0, "max": 12.0},
- "relation_label_2": {"len": 2, "mean": 14.0, "std": 0.0, "min": 14.0, "max": 14.0},
- }
-
-
-def test_relation_argument_distance_collector_with_tokenize_missing_tokenizer():
- with pytest.raises(ValueError) as excinfo:
- RelationArgumentDistanceCollector(
- layer="relations",
- tokenize=True,
- tokenized_document_type=TokenBasedDocument,
- )
- assert (
- str(excinfo.value) == "tokenizer must be provided to calculate distance in means of tokens"
- )
-
-
-def test_relation_argument_distance_collector_with_tokenize_missing_tokenized_document_type():
- with pytest.raises(ValueError) as excinfo:
- RelationArgumentDistanceCollector(
- layer="relations",
- tokenize=True,
- tokenizer="bert-base-uncased",
- )
- assert (
- str(excinfo.value)
- == "tokenized_document_type must be provided to calculate distance in means of tokens"
- )
-
-
-def test_relation_argument_distance_collector_with_tokenize_wrong_document_type():
- @dataclasses.dataclass
- class TestDocument(Document):
- data: str
- entities: AnnotationLayer[LabeledSpan] = annotation_field(target="data")
- relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities")
-
- doc = TestDocument(
- data="This is the first entity. This is the second entity. This is the third entity."
- )
-
- doc.entities.append(LabeledSpan(start=0, end=25, label="entity"))
- doc.entities.append(LabeledSpan(start=26, end=52, label="entity"))
- doc.entities.append(LabeledSpan(start=53, end=78, label="entity"))
- doc.relations.append(
- BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="relation_label_1")
- )
- doc.relations.append(
- BinaryRelation(head=doc.entities[1], tail=doc.entities[2], label="relation_label_2")
- )
-
- @dataclasses.dataclass
- class TokenizedTestDocument(TokenBasedDocument):
- entities: AnnotationLayer[LabeledSpan] = annotation_field(target="tokens")
- relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities")
-
- statistic = RelationArgumentDistanceCollector(
- layer="relations",
- tokenize=True,
- tokenizer="bert-base-uncased",
- tokenized_document_type=TokenizedTestDocument,
- )
-
- with pytest.raises(ValueError) as excinfo:
- statistic(doc)
- assert (
- str(excinfo.value)
- == "doc must be a TextBasedDocument to calculate distance in means of tokens"
- )
-
-
-def test_relation_argument_distance_collector_with_tokenize_wrong_span_annotation_type():
+def test_relation_argument_distance_collector_with_wrong_span_annotation_type():
@dataclasses.dataclass(eq=True, frozen=True)
class UnknownSpan(Annotation):
start: int
@@ -216,7 +116,7 @@ class TestDocument(TextBasedDocument):
)
-def test_relation_argument_distance_collector_with_tokenize_wrong_relation_annotation_type():
+def test_relation_argument_distance_collector_with_wrong_relation_annotation_type():
@dataclasses.dataclass(eq=True, frozen=True)
class UnknownRelation(Annotation):
head: Annotation
diff --git a/tests/metrics/test_span_coverage_collector.py b/tests/metrics/test_span_coverage_collector.py
index 026ec28b7..109352df3 100644
--- a/tests/metrics/test_span_coverage_collector.py
+++ b/tests/metrics/test_span_coverage_collector.py
@@ -1,10 +1,10 @@
import dataclasses
import pytest
-from pie_core import Annotation, AnnotationLayer, Document, annotation_field
+from pie_core import Annotation, AnnotationLayer, annotation_field
from pie_modules.annotations import LabeledMultiSpan, LabeledSpan
-from pie_modules.documents import TextBasedDocument, TokenBasedDocument
+from pie_modules.documents import TextBasedDocument
from pie_modules.metrics import SpanCoverageCollector
@@ -53,85 +53,7 @@ def test_span_coverage_collector_with_labels():
assert values == {"len": 1, "max": 0.125, "mean": 0.125, "min": 0.125, "std": 0.0}
-def test_span_coverage_collector_with_tokenize():
- doc = TestDocument(text="A and O.")
- doc.entities.append(LabeledSpan(start=0, end=1, label="entity"))
- doc.entities.append(LabeledSpan(start=6, end=7, label="entity"))
-
- @dataclasses.dataclass
- class TokenizedTestDocument(TokenBasedDocument):
- entities: AnnotationLayer[LabeledSpan] = annotation_field(target="tokens")
-
- statistic = SpanCoverageCollector(
- layer="entities",
- tokenize=True,
- tokenizer="bert-base-uncased",
- tokenized_document_type=TokenizedTestDocument,
- )
- values = statistic(doc)
- assert values == {
- "len": 1,
- "max": 0.3333333333333333,
- "mean": 0.3333333333333333,
- "min": 0.3333333333333333,
- "std": 0.0,
- }
-
-
-def test_span_coverage_collector_with_tokenize_missing_tokenizer():
- with pytest.raises(ValueError) as excinfo:
- SpanCoverageCollector(
- layer="entities",
- tokenize=True,
- tokenized_document_type=TokenBasedDocument,
- )
- assert (
- str(excinfo.value)
- == "tokenizer must be provided to calculate the span coverage in means of tokens"
- )
-
-
-def test_span_coverage_collector_with_tokenize_missing_tokenized_document_type():
- with pytest.raises(ValueError) as excinfo:
- SpanCoverageCollector(
- layer="entities",
- tokenize=True,
- tokenizer="bert-base-uncased",
- )
- assert (
- str(excinfo.value)
- == "tokenized_document_type must be provided to calculate the span coverage in means of tokens"
- )
-
-
-def test_span_coverage_collector_with_tokenize_wrong_document_type():
- @dataclasses.dataclass
- class TestDocument(Document):
- data: str
- entities: AnnotationLayer[LabeledSpan] = annotation_field(target="data")
-
- doc = TestDocument(data="A and O")
-
- @dataclasses.dataclass
- class TokenizedTestDocument(TokenBasedDocument):
- entities: AnnotationLayer[LabeledSpan] = annotation_field(target="tokens")
-
- statistic = SpanCoverageCollector(
- layer="entities",
- tokenize=True,
- tokenizer="bert-base-uncased",
- tokenized_document_type=TokenizedTestDocument,
- )
-
- with pytest.raises(ValueError) as excinfo:
- statistic(doc)
- assert (
- str(excinfo.value)
- == "doc must be a TextBasedDocument to calculate the span coverage in means of tokens"
- )
-
-
-def test_span_coverage_collector_with_tokenize_wrong_annotation_type():
+def test_span_coverage_collector_with_wrong_annotation_type():
@dataclasses.dataclass(eq=True, frozen=True)
class UnknownSpan(Annotation):
start: int
diff --git a/tests/metrics/test_span_length_collector.py b/tests/metrics/test_span_length_collector.py
index dcd69af01..f482b1066 100644
--- a/tests/metrics/test_span_length_collector.py
+++ b/tests/metrics/test_span_length_collector.py
@@ -78,83 +78,7 @@ def test_span_length_collector_wrong_label_value():
assert str(excinfo.value) == "labels must be a list of strings or 'INFERRED'"
-def test_span_length_collector_with_tokenize(documents):
- @dataclasses.dataclass
- class TokenizedTestDocument(TokenBasedDocument):
- entities: AnnotationLayer[LabeledSpan] = annotation_field(target="tokens")
-
- statistic = SpanLengthCollector(
- layer="entities",
- tokenize=True,
- tokenizer="bert-base-uncased",
- tokenized_document_type=TokenizedTestDocument,
- )
- values = statistic(documents)
- assert values == {
- "len": 7,
- "max": 3,
- "mean": 1.8571428571428572,
- "min": 1,
- "std": 0.8329931278350429,
- }
-
-
-def test_span_length_collector_with_tokenize_missing_tokenizer():
- with pytest.raises(ValueError) as excinfo:
- SpanLengthCollector(
- layer="entities",
- tokenize=True,
- tokenized_document_type=TokenBasedDocument,
- )
- assert (
- str(excinfo.value)
- == "tokenizer must be provided to calculate the span length in means of tokens"
- )
-
-
-def test_span_length_collector_with_tokenize_missing_tokenized_document_type():
- with pytest.raises(ValueError) as excinfo:
- SpanLengthCollector(
- layer="entities",
- tokenize=True,
- tokenizer="bert-base-uncased",
- )
- assert (
- str(excinfo.value)
- == "tokenized_document_type must be provided to calculate the span length in means of tokens"
- )
-
-
-def test_span_length_collector_with_tokenize_wrong_document_type():
- @dataclasses.dataclass
- class TestDocument(Document):
- data: str
- entities: AnnotationLayer[LabeledSpan] = annotation_field(target="data")
-
- doc = TestDocument(data="First sentence. Entity M works at N. And it founded O.")
- doc.entities.append(LabeledSpan(start=16, end=24, label="per"))
- assert str(doc.entities[0]) == "Entity M"
-
- @dataclasses.dataclass
- class TokenizedTestDocument(TokenBasedDocument):
- entities: AnnotationLayer[LabeledSpan] = annotation_field(target="tokens")
-
- statistic = SpanLengthCollector(
- layer="entities",
- tokenize=True,
- tokenizer="bert-base-uncased",
- tokenized_document_type=TokenizedTestDocument,
- )
-
- with pytest.raises(ValueError) as excinfo:
- statistic(doc)
- assert (
- str(excinfo.value)
- == "doc must be a TextBasedDocument to calculate the span length in means of tokens"
- )
-
-
-def test_span_length_collector_with_tokenize_wrong_annotation_type():
+def test_span_length_collector_with_wrong_annotation_type():
@dataclasses.dataclass
class TestDocument(TextBasedDocument):
label: AnnotationLayer[Label] = annotation_field()
@@ -168,5 +92,5 @@ class TestDocument(TextBasedDocument):
statistic(doc)
assert (
str(excinfo.value)
- == "span length calculation is not yet supported for "
+ == "span length calculation is not yet supported for "
)
diff --git a/tests/metrics/test_statistics.py b/tests/metrics/test_statistics.py
index 4d1325795..951dbd57c 100644
--- a/tests/metrics/test_statistics.py
+++ b/tests/metrics/test_statistics.py
@@ -3,7 +3,6 @@
FieldLengthCollector,
LabelCountCollector,
SubFieldLengthCollector,
- TokenCountCollector,
)
@@ -71,17 +70,3 @@ def test_statistics(document_dataset):
"val": {"mean": 3.0, "std": 0.0, "min": 3, "max": 3},
"train": {"mean": 3.0, "std": 0.0, "min": 3, "max": 3},
}
-
-
-def test_statistics_with_tokenize(document_dataset):
- statistic = TokenCountCollector(
- text_field="text",
- tokenizer="bert-base-uncased",
- tokenizer_kwargs=dict(add_special_tokens=False),
- )
- values = statistic(document_dataset)
- assert values == {
- "test": {"max": 13, "mean": 8.5, "min": 4, "std": 4.5},
- "train": {"max": 14, "mean": 8.285714285714286, "min": 4, "std": 3.5742845723419436},
- "val": {"max": 13, "mean": 8.5, "min": 4, "std": 4.5},
- }
diff --git a/tests/models/__init__.py b/tests/models/__init__.py
deleted file mode 100644
index 1e3e58002..000000000
--- a/tests/models/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-def trunc_number(x: float, n: int) -> float:
- return int(x * 10**n) / 10**n
diff --git a/tests/models/base_models/__init__.py b/tests/models/base_models/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/tests/models/base_models/test_bart_as_pointer_network.py b/tests/models/base_models/test_bart_as_pointer_network.py
deleted file mode 100644
index 554e52108..000000000
--- a/tests/models/base_models/test_bart_as_pointer_network.py
+++ /dev/null
@@ -1,983 +0,0 @@
-import pytest
-import torch
-from transformers import (
- BartModel,
- BeamSearchScorer,
- LogitsProcessorList,
- MinLengthLogitsProcessor,
-)
-from transformers.generation import BeamSearchEncoderDecoderOutput
-
-from pie_modules.models.base_models import (
- BartAsPointerNetwork,
- BartModelWithDecoderPositionIds,
-)
-from tests import _config_to_str
-from tests.models import trunc_number
-
-# this is a small model that can be used for testing
-MODEL_NAME_OR_PATH = "sshleifer/bart-tiny-random"
-DECODER_POSITION_ID_PATTERN = [0, 0, 1, 0, 0, 1, 1]
-CONFIGS = [{}, {"decoder_position_id_mode": "pattern"}]
-CONFIG_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS}
-
-
-@pytest.fixture(scope="module", params=CONFIG_DICT.keys())
-def config_str(request):
- return request.param
-
-
-@pytest.fixture(scope="module")
-def config(config_str):
- return CONFIG_DICT[config_str]
-
-
-@pytest.fixture(scope="module")
-def document():
- from pie_modules.annotations import BinaryRelation, LabeledSpan
- from pie_modules.documents import (
- TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
- )
-
- doc = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
- text="This is a dummy text about nothing. Trust me."
- )
- span1 = LabeledSpan(start=10, end=20, label="content")
- span2 = LabeledSpan(start=27, end=34, label="topic")
- span3 = LabeledSpan(start=42, end=44, label="person")
- doc.labeled_spans.extend([span1, span2, span3])
- assert str(span1) == "dummy text"
- assert str(span2) == "nothing"
- assert str(span3) == "me"
- rel = BinaryRelation(head=span1, tail=span2, label="is_about")
- doc.binary_relations.append(rel)
- assert str(rel.label) == "is_about"
- assert str(rel.head) == "dummy text"
- assert str(rel.tail) == "nothing"
-
- sent1 = LabeledSpan(start=0, end=35, label="1")
- sent2 = LabeledSpan(start=36, end=45, label="2")
- doc.labeled_partitions.extend([sent1, sent2])
- assert str(sent1) == "This is a dummy text about nothing."
- assert str(sent2) == "Trust me."
- return doc
-
-
-@pytest.fixture(scope="module")
-def taskmodule(document):
- from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
-
- taskmodule = PointerNetworkTaskModuleForEnd2EndRE(
- tokenizer_name_or_path=MODEL_NAME_OR_PATH,
- partition_layer_name="labeled_partitions",
- create_constraints=True,
- )
-
- taskmodule.prepare(documents=[document])
-
- return taskmodule
-
-
-@pytest.fixture(scope="module")
-def model(config) -> BartAsPointerNetwork:
- model_name_or_path = MODEL_NAME_OR_PATH
-
- torch.random.manual_seed(42)
- model = BartAsPointerNetwork.from_pretrained(
- model_name_or_path,
- # label id space
- bos_token_id=0, # taskmodule.bos_id,
- eos_token_id=1, # taskmodule.eos_id,
- pad_token_id=1, # taskmodule.eos_id,
- # target token id space
- target_token_ids=[0, 2, 50266, 50269, 50268, 50265, 50267], # taskmodule.target_token_ids,
- # mapping to better initialize the label embedding weights
- # taken from taskmodule.label_embedding_weight_mapping
- embedding_weight_mapping={
- 50266: [39763],
- 50269: [10166],
- 50268: [5970],
- 50265: [45260],
- 50267: [354, 1215, 9006],
- },
- decoder_position_id_pattern=DECODER_POSITION_ID_PATTERN,
- **config,
- )
-
- return model
-
-
-def test_model(model, config):
- assert model is not None
- named_parameters = dict(model.named_parameters())
- parameter_means = {k: trunc_number(v.mean().item(), 7) for k, v in named_parameters.items()}
- parameter_means_expected = {
- "model.shared.weight": -1.41e-05,
- "model.encoder.embed_positions.weight": -0.0001324,
- "model.encoder.layers.0.self_attn.k_proj.weight": -0.0004574,
- "model.encoder.layers.0.self_attn.k_proj.bias": 0.0,
- "model.encoder.layers.0.self_attn.v_proj.weight": -0.0005457,
- "model.encoder.layers.0.self_attn.v_proj.bias": 0.0,
- "model.encoder.layers.0.self_attn.q_proj.weight": -0.0009775,
- "model.encoder.layers.0.self_attn.q_proj.bias": 0.0,
- "model.encoder.layers.0.self_attn.out_proj.weight": -0.0001075,
- "model.encoder.layers.0.self_attn.out_proj.bias": 0.0,
- "model.encoder.layers.0.self_attn_layer_norm.weight": 1.0,
- "model.encoder.layers.0.self_attn_layer_norm.bias": 0.0,
- "model.encoder.layers.0.fc1.weight": -0.0008655,
- "model.encoder.layers.0.fc1.bias": 0.0,
- "model.encoder.layers.0.fc2.weight": 0.0015535,
- "model.encoder.layers.0.fc2.bias": 0.0,
- "model.encoder.layers.0.final_layer_norm.weight": 1.0,
- "model.encoder.layers.0.final_layer_norm.bias": 0.0,
- "model.encoder.layers.1.self_attn.k_proj.weight": -0.0007831,
- "model.encoder.layers.1.self_attn.k_proj.bias": 0.0,
- "model.encoder.layers.1.self_attn.v_proj.weight": 0.0001186,
- "model.encoder.layers.1.self_attn.v_proj.bias": 0.0,
- "model.encoder.layers.1.self_attn.q_proj.weight": 0.0006847,
- "model.encoder.layers.1.self_attn.q_proj.bias": 0.0,
- "model.encoder.layers.1.self_attn.out_proj.weight": 0.0011724,
- "model.encoder.layers.1.self_attn.out_proj.bias": 0.0,
- "model.encoder.layers.1.self_attn_layer_norm.weight": 1.0,
- "model.encoder.layers.1.self_attn_layer_norm.bias": 0.0,
- "model.encoder.layers.1.fc1.weight": 0.0007757,
- "model.encoder.layers.1.fc1.bias": 0.0,
- "model.encoder.layers.1.fc2.weight": -0.0002014,
- "model.encoder.layers.1.fc2.bias": 0.0,
- "model.encoder.layers.1.final_layer_norm.weight": 1.0,
- "model.encoder.layers.1.final_layer_norm.bias": 0.0,
- "model.encoder.layernorm_embedding.weight": 1.0,
- "model.encoder.layernorm_embedding.bias": 0.0,
- "model.decoder.embed_positions.weight": -0.0001275,
- "model.decoder.layers.0.self_attn.k_proj.weight": -0.0010682,
- "model.decoder.layers.0.self_attn.k_proj.bias": 0.0,
- "model.decoder.layers.0.self_attn.v_proj.weight": 0.0005057,
- "model.decoder.layers.0.self_attn.v_proj.bias": 0.0,
- "model.decoder.layers.0.self_attn.q_proj.weight": 0.0003248,
- "model.decoder.layers.0.self_attn.q_proj.bias": 0.0,
- "model.decoder.layers.0.self_attn.out_proj.weight": -0.0002014,
- "model.decoder.layers.0.self_attn.out_proj.bias": 0.0,
- "model.decoder.layers.0.self_attn_layer_norm.weight": 1.0,
- "model.decoder.layers.0.self_attn_layer_norm.bias": 0.0,
- "model.decoder.layers.0.encoder_attn.k_proj.weight": -0.0004254,
- "model.decoder.layers.0.encoder_attn.k_proj.bias": 0.0,
- "model.decoder.layers.0.encoder_attn.v_proj.weight": -0.0004049,
- "model.decoder.layers.0.encoder_attn.v_proj.bias": 0.0,
- "model.decoder.layers.0.encoder_attn.q_proj.weight": -0.0003516,
- "model.decoder.layers.0.encoder_attn.q_proj.bias": 0.0,
- "model.decoder.layers.0.encoder_attn.out_proj.weight": 0.0009908,
- "model.decoder.layers.0.encoder_attn.out_proj.bias": 0.0,
- "model.decoder.layers.0.encoder_attn_layer_norm.weight": 1.0,
- "model.decoder.layers.0.encoder_attn_layer_norm.bias": 0.0,
- "model.decoder.layers.0.fc1.weight": 0.0008378,
- "model.decoder.layers.0.fc1.bias": 0.0,
- "model.decoder.layers.0.fc2.weight": -2e-05,
- "model.decoder.layers.0.fc2.bias": 0.0,
- "model.decoder.layers.0.final_layer_norm.weight": 1.0,
- "model.decoder.layers.0.final_layer_norm.bias": 0.0,
- "model.decoder.layers.1.self_attn.k_proj.weight": -0.0007669,
- "model.decoder.layers.1.self_attn.k_proj.bias": 0.0,
- "model.decoder.layers.1.self_attn.v_proj.weight": -0.0007123,
- "model.decoder.layers.1.self_attn.v_proj.bias": 0.0,
- "model.decoder.layers.1.self_attn.q_proj.weight": 0.0012958,
- "model.decoder.layers.1.self_attn.q_proj.bias": 0.0,
- "model.decoder.layers.1.self_attn.out_proj.weight": -0.0006818,
- "model.decoder.layers.1.self_attn.out_proj.bias": 0.0,
- "model.decoder.layers.1.self_attn_layer_norm.weight": 1.0,
- "model.decoder.layers.1.self_attn_layer_norm.bias": 0.0,
- "model.decoder.layers.1.encoder_attn.k_proj.weight": -0.0006906,
- "model.decoder.layers.1.encoder_attn.k_proj.bias": 0.0,
- "model.decoder.layers.1.encoder_attn.v_proj.weight": -0.0009213,
- "model.decoder.layers.1.encoder_attn.v_proj.bias": 0.0,
- "model.decoder.layers.1.encoder_attn.q_proj.weight": -0.000842,
- "model.decoder.layers.1.encoder_attn.q_proj.bias": 0.0,
- "model.decoder.layers.1.encoder_attn.out_proj.weight": 0.0008073,
- "model.decoder.layers.1.encoder_attn.out_proj.bias": 0.0,
- "model.decoder.layers.1.encoder_attn_layer_norm.weight": 1.0,
- "model.decoder.layers.1.encoder_attn_layer_norm.bias": 0.0,
- "model.decoder.layers.1.fc1.weight": 0.0015493,
- "model.decoder.layers.1.fc1.bias": 0.0,
- "model.decoder.layers.1.fc2.weight": -0.0009827,
- "model.decoder.layers.1.fc2.bias": 0.0,
- "model.decoder.layers.1.final_layer_norm.weight": 1.0,
- "model.decoder.layers.1.final_layer_norm.bias": 0.0,
- "model.decoder.layernorm_embedding.weight": 1.0,
- "model.decoder.layernorm_embedding.bias": 0.0,
- "pointer_head.encoder_mlp.0.weight": 0.0004805,
- "pointer_head.encoder_mlp.0.bias": 0.0,
- "pointer_head.encoder_mlp.3.weight": 0.0001837,
- "pointer_head.encoder_mlp.3.bias": 0.0,
- }
- assert parameter_means == parameter_means_expected
- assert isinstance(model, BartAsPointerNetwork)
- if config == {}:
- assert isinstance(model.model, BartModel)
- elif config == {"decoder_position_id_mode": "pattern"}:
- assert isinstance(model.model, BartModelWithDecoderPositionIds)
- else:
- raise ValueError(f"Unknown config: {config}")
-
-
-@pytest.fixture(scope="module")
-def batch():
- inputs = {
- "input_ids": torch.tensor(
- [
- [0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 2],
- [0, 18823, 162, 4, 2, 1, 1, 1, 1, 1],
- ]
- ),
- "attention_mask": torch.tensor(
- [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]]
- ),
- }
- targets = {
- "labels": torch.tensor([[14, 14, 5, 11, 12, 3, 6, 1], [9, 9, 4, 2, 2, 2, 2, 1]]),
- "decoder_attention_mask": torch.tensor(
- [[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]]
- ),
- }
- return inputs, targets
-
-
-@pytest.fixture(scope="module")
-def batch_with_constraints(batch):
- constraints = torch.tensor(
- [
- [
- [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
- [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
- [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- ],
- [
- [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, -1, -1, -1, -1, -1],
- [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1],
- [0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, -1, -1, -1, -1, -1],
- [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1],
- [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1],
- [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1],
- [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1],
- ],
- ]
- )
- targets_with_constraints = {**batch[1], "constraints": constraints}
- return batch[0], targets_with_constraints
-
-
-@pytest.mark.skip(reason="This is just to show how to create the batch.")
-def test_batch_with_constraints(batch_with_constraints, taskmodule, document):
- inputs, targets = batch_with_constraints
- task_encodings = taskmodule.encode([document], encode_target=True)
- batch_from_documents = taskmodule.collate(task_encodings)
- inputs_from_documents, targets_from_documents = batch_from_documents
- for key in inputs:
- torch.testing.assert_close(inputs[key], inputs_from_documents[key])
-
- for key in targets:
- torch.testing.assert_close(targets[key], targets_from_documents[key])
-
-
-@pytest.fixture(scope="module")
-def decoder_input_ids(model):
- # taken from batch[1]["labels"]
- labels = torch.tensor([[14, 14, 5, 11, 12, 3, 6, 1], [9, 9, 4, 2, 2, 2, 2, 1]])
- decoder_input_ids = model.prepare_decoder_input_ids_from_labels(labels=labels)
- return decoder_input_ids
-
-
-def test_prepare_decoder_input_ids_from_labels(decoder_input_ids):
- assert decoder_input_ids.shape == (2, 8)
- torch.testing.assert_close(
- decoder_input_ids,
- torch.tensor([[0, 14, 14, 5, 11, 12, 3, 6], [0, 9, 9, 4, 2, 2, 2, 2]]),
- )
-
-
-def test_forward(model, batch, decoder_input_ids, config):
- inputs, targets = batch
- torch.manual_seed(42)
- outputs = model(**inputs, decoder_input_ids=decoder_input_ids)
- assert outputs.loss is None
- assert outputs.logits is not None
- # shape: (batch_size, output_seq_len, target_size=num_target_ids+num_offsets)
- assert outputs.logits.shape == (2, 8, 17)
- # check exact values only for the first sequence output
- torch.testing.assert_close(
- outputs.logits[:, 0, :],
- torch.tensor(
- [
- [
- -1.0000000138484279e24,
- -0.23238050937652588,
- 0.2958170175552368,
- 0.05529244244098663,
- 0.04253090173006058,
- 0.10081345587968826,
- -0.07145103067159653,
- 0.12317530065774918,
- -0.06861806660890579,
- 0.07819556444883347,
- 0.006490768864750862,
- -0.040455855429172516,
- 0.03176971897482872,
- 0.05362509936094284,
- 0.04528001323342323,
- -0.0684177577495575,
- -1.0000000331813535e32,
- ],
- [
- -1.0000000138484279e24,
- -0.23274855315685272,
- 0.2960396707057953,
- 0.05556505173444748,
- 0.04273710399866104,
- 0.10071954131126404,
- -0.071356862783432,
- 0.12314081937074661,
- 0.06498698145151138,
- 0.07938676327466965,
- -0.07943986356258392,
- -1.0000000331813535e32,
- -1.0000000331813535e32,
- -1.0000000331813535e32,
- -1.0000000331813535e32,
- -1.0000000331813535e32,
- -1.0000000331813535e32,
- ],
- ]
- ),
- )
- # check the sum of all logits
- if config == {}:
- torch.testing.assert_close(
- outputs.logits.sum(0).sum(0),
- torch.tensor(
- [
- -1.6000000221574846e25,
- -0.9064984321594238,
- 1.189674735069275,
- 0.9796359539031982,
- 0.1837124526500702,
- 1.3070943355560303,
- -0.1210818886756897,
- 0.5316579937934875,
- -0.12306825071573257,
- 0.6218758225440979,
- -0.4374474287033081,
- -8.000000265450828e32,
- -8.000000265450828e32,
- -8.000000265450828e32,
- -8.000000265450828e32,
- -8.000000265450828e32,
- -1.6000000530901656e33,
- ]
- ),
- )
- elif config == {"decoder_position_id_mode": "pattern"}:
- torch.testing.assert_close(
- outputs.logits.sum(0).sum(0),
- torch.tensor(
- [
- -1.6000000221574846e25,
- -0.5539568662643433,
- 0.7004716396331787,
- 1.5720455646514893,
- -0.3760950267314911,
- 0.7738710641860962,
- -0.1090446263551712,
- 0.287150502204895,
- -0.04344810172915459,
- 0.3674442768096924,
- -0.6838937997817993,
- -8.000000265450828e32,
- -8.000000265450828e32,
- -8.000000265450828e32,
- -8.000000265450828e32,
- -8.000000265450828e32,
- -1.6000000530901656e33,
- ]
- ),
- )
- else:
- raise ValueError(f"Unknown config: {config}")
-
-
-def test_forward_with_labels(model, batch, config):
- inputs, targets = batch
- targets_without_constraints = {
- key: value for key, value in targets.items() if key != "constraints"
- }
- assert set(inputs) == {"input_ids", "attention_mask"}
- assert set(targets_without_constraints) == {"labels", "decoder_attention_mask"}
- torch.manual_seed(42)
- outputs = model(**inputs, **targets_without_constraints)
- loss = outputs.loss
- if config == {}:
- torch.testing.assert_close(loss, torch.tensor(2.4516539573669434))
- elif config == {"decoder_position_id_mode": "pattern"}:
- torch.testing.assert_close(loss, torch.tensor(2.4184868335723877))
- else:
- raise ValueError(f"Unknown config: {config}")
-
-
-def test_forward_with_labels_and_constraints(model, batch_with_constraints, config):
- inputs, targets = batch_with_constraints
- assert set(inputs) == {"input_ids", "attention_mask"}
- assert set(targets) == {"labels", "decoder_attention_mask", "constraints"}
- torch.manual_seed(42)
- outputs = model(**inputs, **targets)
- loss = outputs.loss
- if config == {}:
- torch.testing.assert_close(loss, torch.tensor(4.776531219482422))
- elif config == {"decoder_position_id_mode": "pattern"}:
- torch.testing.assert_close(loss, torch.tensor(4.742183685302734))
- else:
- raise ValueError(f"Unknown model type {type(model.model)}")
-
-
-@pytest.fixture(scope="module")
-def empty_decoder_input_ids(batch, model):
- inputs, targets = batch
- batch_size, seq_len = inputs["input_ids"].shape
- decoder_input_ids = torch.ones((batch_size, 1), dtype=torch.long) * model.config.bos_token_id
- torch.testing.assert_close(
- decoder_input_ids,
- torch.tensor([[0], [0]]),
- )
- return decoder_input_ids
-
-
-@pytest.fixture(scope="module")
-def encoder_outputs(model, batch):
- inputs, targets = batch
- torch.manual_seed(42)
- encoder_outputs = model.get_encoder()(
- input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
- )
- return encoder_outputs
-
-
-@pytest.fixture(scope="module")
-def prepared_encoder_decoder_kwargs_for_generation(
- model, batch, empty_decoder_input_ids, encoder_outputs
-):
- model_kwargs = {
- "attention_mask": batch[0]["attention_mask"],
- "output_attentions": False,
- "output_hidden_states": False,
- "use_cache": True,
- }
- torch.manual_seed(42)
- prepared_kwargs = model._prepare_encoder_decoder_kwargs_for_generation(
- inputs_tensor=batch[0]["input_ids"],
- model_kwargs=model_kwargs,
- model_input_name="input_ids",
- )
- return prepared_kwargs
-
-
-def test_prepare_encoder_decoder_kwargs_for_generation(
- prepared_encoder_decoder_kwargs_for_generation, batch, encoder_outputs
-):
- model_kwargs = {
- "attention_mask": batch[0]["attention_mask"],
- "output_attentions": False,
- "output_hidden_states": False,
- "use_cache": True,
- }
-
- assert set(prepared_encoder_decoder_kwargs_for_generation) == set(model_kwargs) | {
- "encoder_input_ids",
- "encoder_attention_mask",
- "encoder_outputs",
- }
- torch.testing.assert_close(
- prepared_encoder_decoder_kwargs_for_generation["encoder_input_ids"],
- batch[0]["input_ids"],
- )
- torch.testing.assert_close(
- prepared_encoder_decoder_kwargs_for_generation["encoder_attention_mask"],
- batch[0]["attention_mask"],
- )
- torch.testing.assert_close(
- prepared_encoder_decoder_kwargs_for_generation["encoder_outputs"].last_hidden_state,
- encoder_outputs.last_hidden_state,
- )
-
-
-def test_prepare_inputs_for_generation(
- model,
- prepared_encoder_decoder_kwargs_for_generation,
- empty_decoder_input_ids,
- batch,
- encoder_outputs,
- config,
-):
- result = model.prepare_inputs_for_generation(
- decoder_input_ids=empty_decoder_input_ids, **prepared_encoder_decoder_kwargs_for_generation
- )
- result_keys = {
- "input_ids",
- "attention_mask",
- "encoder_outputs",
- "decoder_input_ids",
- "decoder_attention_mask",
- "past_key_values",
- "use_cache",
- "head_mask",
- "decoder_head_mask",
- "cross_attn_head_mask",
- }
- if model.pointer_head.use_prepared_position_ids:
- result_keys.add("decoder_position_ids")
- assert set(result) == result_keys
- torch.testing.assert_close(
- result["input_ids"],
- batch[0]["input_ids"],
- )
- torch.testing.assert_close(
- result["attention_mask"],
- batch[0]["attention_mask"],
- )
- torch.testing.assert_close(
- result["encoder_outputs"].last_hidden_state,
- encoder_outputs.last_hidden_state,
- )
- torch.testing.assert_close(
- result["decoder_input_ids"],
- empty_decoder_input_ids,
- )
- assert result["decoder_attention_mask"] is None
- assert result["past_key_values"] is None
- assert result["use_cache"] is True
- assert result["head_mask"] is None
- assert result["decoder_head_mask"] is None
- assert result["cross_attn_head_mask"] is None
- if config == {}:
- assert "decoder_position_ids" not in result
- elif config == {"decoder_position_id_mode": "pattern"}:
- torch.testing.assert_close(result["decoder_position_ids"], torch.tensor([[0], [0]]))
- else:
- raise ValueError(f"Unknown config: {config}")
-
-
-def test_prepare_inputs_for_generation_with_past_key_values(
- model,
- prepared_encoder_decoder_kwargs_for_generation,
- batch,
- encoder_outputs,
- config,
-):
- # shallow copy to avoid changing the original dict
- kwargs = dict(prepared_encoder_decoder_kwargs_for_generation)
- kwargs["decoder_input_ids"] = torch.tensor(
- [
- [0, 8, 9],
- [0, 8, 10],
- [0, 8, 15],
- [0, 8, 8],
- [0, 9, 10],
- [0, 8, 12],
- [0, 8, 9],
- [0, 8, 10],
- [0, 9, 10],
- [0, 8, 8],
- [0, 9, 9],
- [0, 8, 6],
- ]
- )
- # 12 is batch_size (2) * num_beams (6),
- # 16 is number of encoder / decoder attention heads,
- # 2 is the length of already generated tokens / 10 is the length of the encoder input,
- # 64 seems to be the size of the hidden states
- dummy_past_key_values = (
- torch.zeros((12, 16, 2, 64)),
- torch.zeros((12, 16, 2, 64)),
- torch.zeros((12, 16, 10, 64)),
- torch.zeros((12, 16, 10, 64)),
- )
-
- result = model.prepare_inputs_for_generation(past_key_values=dummy_past_key_values, **kwargs)
- if config == {}:
- assert len(result) == 10
- elif config == {"decoder_position_id_mode": "pattern"}:
- assert len(result) == 11
- else:
- raise ValueError(f"Unknown config: {config}")
- torch.testing.assert_close(
- result["input_ids"],
- batch[0]["input_ids"],
- )
- torch.testing.assert_close(
- result["attention_mask"],
- batch[0]["attention_mask"],
- )
- torch.testing.assert_close(
- result["encoder_outputs"].last_hidden_state,
- encoder_outputs.last_hidden_state,
- )
- torch.testing.assert_close(
- result["decoder_input_ids"],
- # just the last id for each entry
- torch.tensor([[9], [10], [15], [8], [10], [12], [9], [10], [10], [8], [9], [6]]),
- )
- assert result["decoder_attention_mask"] is None
- assert result["past_key_values"] is dummy_past_key_values
- assert result["use_cache"] is True
- assert result["head_mask"] is None
- assert result["decoder_head_mask"] is None
- assert result["cross_attn_head_mask"] is None
- if "decoder_position_ids" in result:
- torch.testing.assert_close(
- result["decoder_position_ids"],
- # originally this was 0 from the pattern, but got shifted for the position-bos and position-pad indices
- torch.tensor([[2], [2], [2], [2], [2], [2], [2], [2], [2], [2], [2], [2]]),
- )
-
-
-def test_generate(model, batch, empty_decoder_input_ids, config):
- inputs, targets = batch
- batch_size, seq_len = inputs["input_ids"].shape
- torch.manual_seed(42)
- outputs = model.generate(**inputs)
- if config == {}:
- assert outputs.shape == (batch_size, 20) # note that 20 is the model.config.max_length
- torch.testing.assert_close(
- outputs,
- torch.tensor(
- [
- [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
- [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
- ]
- ),
- )
- elif config == {"decoder_position_id_mode": "pattern"}:
- assert outputs.shape == (batch_size, 20) # note that 20 is the model.config.max_length
- torch.testing.assert_close(
- outputs,
- torch.tensor(
- [
- [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
- [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
- ]
- ),
- )
- else:
- raise ValueError(f"Unknown config: {config}")
-
-
-def test_head_named_params(model):
- parameter_shapes = {name: tuple(param.shape) for name, param in model.head_named_params()}
- assert parameter_shapes == {
- "pointer_head.encoder_mlp.0.bias": (24,),
- "pointer_head.encoder_mlp.0.weight": (24, 24),
- "pointer_head.encoder_mlp.3.bias": (24,),
- "pointer_head.encoder_mlp.3.weight": (24, 24),
- }
-
-
-def test_encoder_only_named_params(model):
- parameter_shapes = {
- name: tuple(param.shape) for name, param in model.encoder_only_named_params()
- }
- assert len(parameter_shapes) == 35
- assert parameter_shapes == {
- "model.encoder.embed_positions.weight": (1026, 24),
- "model.encoder.layernorm_embedding.bias": (24,),
- "model.encoder.layernorm_embedding.weight": (24,),
- "model.encoder.layers.0.fc1.bias": (16,),
- "model.encoder.layers.0.fc1.weight": (16, 24),
- "model.encoder.layers.0.fc2.bias": (24,),
- "model.encoder.layers.0.fc2.weight": (24, 16),
- "model.encoder.layers.0.final_layer_norm.bias": (24,),
- "model.encoder.layers.0.final_layer_norm.weight": (24,),
- "model.encoder.layers.0.self_attn.k_proj.bias": (24,),
- "model.encoder.layers.0.self_attn.k_proj.weight": (24, 24),
- "model.encoder.layers.0.self_attn.out_proj.bias": (24,),
- "model.encoder.layers.0.self_attn.out_proj.weight": (24, 24),
- "model.encoder.layers.0.self_attn.q_proj.bias": (24,),
- "model.encoder.layers.0.self_attn.q_proj.weight": (24, 24),
- "model.encoder.layers.0.self_attn.v_proj.bias": (24,),
- "model.encoder.layers.0.self_attn.v_proj.weight": (24, 24),
- "model.encoder.layers.0.self_attn_layer_norm.bias": (24,),
- "model.encoder.layers.0.self_attn_layer_norm.weight": (24,),
- "model.encoder.layers.1.fc1.bias": (16,),
- "model.encoder.layers.1.fc1.weight": (16, 24),
- "model.encoder.layers.1.fc2.bias": (24,),
- "model.encoder.layers.1.fc2.weight": (24, 16),
- "model.encoder.layers.1.final_layer_norm.bias": (24,),
- "model.encoder.layers.1.final_layer_norm.weight": (24,),
- "model.encoder.layers.1.self_attn.k_proj.bias": (24,),
- "model.encoder.layers.1.self_attn.k_proj.weight": (24, 24),
- "model.encoder.layers.1.self_attn.out_proj.bias": (24,),
- "model.encoder.layers.1.self_attn.out_proj.weight": (24, 24),
- "model.encoder.layers.1.self_attn.q_proj.bias": (24,),
- "model.encoder.layers.1.self_attn.q_proj.weight": (24, 24),
- "model.encoder.layers.1.self_attn.v_proj.bias": (24,),
- "model.encoder.layers.1.self_attn.v_proj.weight": (24, 24),
- "model.encoder.layers.1.self_attn_layer_norm.bias": (24,),
- "model.encoder.layers.1.self_attn_layer_norm.weight": (24,),
- }
-
-
-def test_decoder_only_named_params(model):
- parameter_shapes = {
- name: tuple(param.shape) for name, param in model.decoder_only_named_params()
- }
- assert len(parameter_shapes) == 55
- assert parameter_shapes == {
- "model.decoder.embed_positions.weight": (1026, 24),
- "model.decoder.layernorm_embedding.bias": (24,),
- "model.decoder.layernorm_embedding.weight": (24,),
- "model.decoder.layers.0.encoder_attn.k_proj.bias": (24,),
- "model.decoder.layers.0.encoder_attn.k_proj.weight": (24, 24),
- "model.decoder.layers.0.encoder_attn.out_proj.bias": (24,),
- "model.decoder.layers.0.encoder_attn.out_proj.weight": (24, 24),
- "model.decoder.layers.0.encoder_attn.q_proj.bias": (24,),
- "model.decoder.layers.0.encoder_attn.q_proj.weight": (24, 24),
- "model.decoder.layers.0.encoder_attn.v_proj.bias": (24,),
- "model.decoder.layers.0.encoder_attn.v_proj.weight": (24, 24),
- "model.decoder.layers.0.encoder_attn_layer_norm.bias": (24,),
- "model.decoder.layers.0.encoder_attn_layer_norm.weight": (24,),
- "model.decoder.layers.0.fc1.bias": (16,),
- "model.decoder.layers.0.fc1.weight": (16, 24),
- "model.decoder.layers.0.fc2.bias": (24,),
- "model.decoder.layers.0.fc2.weight": (24, 16),
- "model.decoder.layers.0.final_layer_norm.bias": (24,),
- "model.decoder.layers.0.final_layer_norm.weight": (24,),
- "model.decoder.layers.0.self_attn.k_proj.bias": (24,),
- "model.decoder.layers.0.self_attn.k_proj.weight": (24, 24),
- "model.decoder.layers.0.self_attn.out_proj.bias": (24,),
- "model.decoder.layers.0.self_attn.out_proj.weight": (24, 24),
- "model.decoder.layers.0.self_attn.q_proj.bias": (24,),
- "model.decoder.layers.0.self_attn.q_proj.weight": (24, 24),
- "model.decoder.layers.0.self_attn.v_proj.bias": (24,),
- "model.decoder.layers.0.self_attn.v_proj.weight": (24, 24),
- "model.decoder.layers.0.self_attn_layer_norm.bias": (24,),
- "model.decoder.layers.0.self_attn_layer_norm.weight": (24,),
- "model.decoder.layers.1.encoder_attn.k_proj.bias": (24,),
- "model.decoder.layers.1.encoder_attn.k_proj.weight": (24, 24),
- "model.decoder.layers.1.encoder_attn.out_proj.bias": (24,),
- "model.decoder.layers.1.encoder_attn.out_proj.weight": (24, 24),
- "model.decoder.layers.1.encoder_attn.q_proj.bias": (24,),
- "model.decoder.layers.1.encoder_attn.q_proj.weight": (24, 24),
- "model.decoder.layers.1.encoder_attn.v_proj.bias": (24,),
- "model.decoder.layers.1.encoder_attn.v_proj.weight": (24, 24),
- "model.decoder.layers.1.encoder_attn_layer_norm.bias": (24,),
- "model.decoder.layers.1.encoder_attn_layer_norm.weight": (24,),
- "model.decoder.layers.1.fc1.bias": (16,),
- "model.decoder.layers.1.fc1.weight": (16, 24),
- "model.decoder.layers.1.fc2.bias": (24,),
- "model.decoder.layers.1.fc2.weight": (24, 16),
- "model.decoder.layers.1.final_layer_norm.bias": (24,),
- "model.decoder.layers.1.final_layer_norm.weight": (24,),
- "model.decoder.layers.1.self_attn.k_proj.bias": (24,),
- "model.decoder.layers.1.self_attn.k_proj.weight": (24, 24),
- "model.decoder.layers.1.self_attn.out_proj.bias": (24,),
- "model.decoder.layers.1.self_attn.out_proj.weight": (24, 24),
- "model.decoder.layers.1.self_attn.q_proj.bias": (24,),
- "model.decoder.layers.1.self_attn.q_proj.weight": (24, 24),
- "model.decoder.layers.1.self_attn.v_proj.bias": (24,),
- "model.decoder.layers.1.self_attn.v_proj.weight": (24, 24),
- "model.decoder.layers.1.self_attn_layer_norm.bias": (24,),
- "model.decoder.layers.1.self_attn_layer_norm.weight": (24,),
- }
-
-
-def test_encoder_decoder_shared_named_params(model):
- parameter_shapes = {
- name: tuple(param.shape) for name, param in model.encoder_decoder_shared_named_params()
- }
- assert len(parameter_shapes) == 1
- assert parameter_shapes == {"model.shared.weight": (50270, 24)}
-
-
-def test_base_model_named_params(model):
- parameter_shapes = {
- name: tuple(param.shape) for name, param in model.base_model_named_params()
- }
- assert len(parameter_shapes) == 91
- encoder_only_parameter_shapes = {
- name: tuple(param.shape) for name, param in model.encoder_only_named_params()
- }
- decoder_only_parameter_shapes = {
- name: tuple(param.shape) for name, param in model.decoder_only_named_params()
- }
- shared_parameter_shapes = {
- name: tuple(param.shape) for name, param in model.encoder_decoder_shared_named_params()
- }
- expected_parameter_shapes = {
- **encoder_only_parameter_shapes,
- **decoder_only_parameter_shapes,
- **shared_parameter_shapes,
- }
-
- assert parameter_shapes == expected_parameter_shapes
-
-
-def test_configure_optimizer(model):
- optimizer = model.configure_optimizer()
- assert isinstance(optimizer, torch.optim.AdamW)
- assert optimizer.defaults["lr"] == 0.001
- assert optimizer.defaults["weight_decay"] == model.config.weight_decay == 0.01
- assert len(optimizer.param_groups) == 6
- assert all(param_group["lr"] == model.config.lr for param_group in optimizer.param_groups)
-
- # head parameters
- assert optimizer.param_groups[0]["weight_decay"] == model.config.weight_decay == 0.01
- # decoder only layer norm parameters
- assert optimizer.param_groups[1]["weight_decay"] == model.config.weight_decay == 0.01
- # decoder only other parameters
- assert optimizer.param_groups[2]["weight_decay"] == model.config.weight_decay == 0.01
- # encoder only layer norm parameters
- assert (
- optimizer.param_groups[3]["weight_decay"] == model.config.encoder_layer_norm_decay == 0.001
- )
- # encoder only other parameters
- assert optimizer.param_groups[4]["weight_decay"] == model.config.weight_decay == 0.01
- # encoder-decoder shared parameters
- assert optimizer.param_groups[5]["weight_decay"] == model.config.weight_decay == 0.01
-
- all_optimized_parameters = set()
- for param_group in optimizer.param_groups:
- all_optimized_parameters.update(set(param_group["params"]))
- assert len(all_optimized_parameters) > 0
- # check that all model parameters are covered
- all_model_parameters = {param for name, param in model.named_parameters()}
- assert all_optimized_parameters == all_model_parameters
-
-
-# note that this is only used for the tests below which are marked as slow
-# and are primarily meant to show how beam search works
-@pytest.fixture(scope="module")
-def pretrained_model() -> BartAsPointerNetwork:
- torch.random.manual_seed(42)
- model = BartAsPointerNetwork.from_pretrained(
- "sshleifer/distilbart-xsum-12-1",
- # label id space
- bos_token_id=0, # taskmodule.bos_id,
- eos_token_id=1, # taskmodule.eos_id,
- pad_token_id=1, # taskmodule.eos_id,
- # target token id space
- target_token_ids=[0, 2, 50266, 50269, 50268, 50265, 50267], # taskmodule.target_token_ids,
- # mapping to better initialize the label embedding weights
- # taken from taskmodule.label_embedding_weight_mapping
- embedding_weight_mapping={
- 50266: [39763],
- 50269: [10166],
- 50268: [5970],
- 50265: [45260],
- 50267: [354, 1215, 9006],
- },
- decoder_position_id_mode="pattern",
- decoder_position_id_pattern=[0, 0, 1, 0, 0, 1, 1],
- )
-
- return model
-
-
-ARTICLE_TO_SUMMARIZE = (
- "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
- "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
- "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
-)
-
-
-@pytest.mark.slow
-def test_bart_pointer_network_beam_search(pretrained_model, taskmodule):
- model = pretrained_model
- encoder_input_str = ARTICLE_TO_SUMMARIZE # "translate English to German: How old are you?"
- encoder_input_tokenized = taskmodule.tokenizer(encoder_input_str, return_tensors="pt")
-
- # lets run beam search using 3 beams
- num_beams = 3
- # define decoder start token ids
- decoder_input_ids = torch.ones((num_beams, 1), device=model.device, dtype=torch.long)
- decoder_input_ids = decoder_input_ids * model.config.decoder_start_token_id
-
- # add encoder_outputs to model keyword arguments
- encoder = model.get_encoder()
- encoder_input_ids = encoder_input_tokenized.input_ids.repeat_interleave(num_beams, dim=0)
- encoder_attention_mask = encoder_input_tokenized.attention_mask.repeat_interleave(
- num_beams, dim=0
- )
- torch.manual_seed(42)
- encoder_outputs = encoder(encoder_input_ids, return_dict=True)
- model_kwargs = {
- "encoder_outputs": encoder_outputs,
- "encoder_input_ids": encoder_input_ids,
- "encoder_attention_mask": encoder_attention_mask,
- }
-
- # instantiate beam scorer
- beam_scorer = BeamSearchScorer(
- batch_size=1,
- num_beams=num_beams,
- device=model.device,
- )
-
- # instantiate logits processors
- logits_processor = LogitsProcessorList(
- [
- MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id),
- ]
- )
-
- torch.manual_seed(42)
- outputs = model.beam_search(
- decoder_input_ids,
- beam_scorer,
- logits_processor=logits_processor,
- pad_token_id=model.config.pad_token_id,
- eos_token_id=model.config.eos_token_id,
- max_length=20,
- **model_kwargs,
- )
-
- torch.testing.assert_close(
- outputs,
- torch.tensor(
- [[0, 10, 30, 53, 54, 45, 15, 16, 17, 33, 33, 33, 35, 33, 58, 39, 41, 35, 33, 35]]
- ),
- )
-
- # result = tokenizer.batch_decode(outputs, skip_special_tokens=True)
- # assert result == [
- # " power lines in California have been shut down after a power provider said it was due to high winds."
- # ]
-
-
-@pytest.mark.slow
-def test_bart_pointer_network_generate_with_scores(pretrained_model, taskmodule):
- model = pretrained_model
- encoder_input_str = ARTICLE_TO_SUMMARIZE # "translate English to German: How old are you?"
- inputs = taskmodule.tokenizer(encoder_input_str, max_length=1024, return_tensors="pt")
-
- torch.manual_seed(42)
- outputs = model.generate(
- inputs["input_ids"],
- num_beams=3,
- min_length=5,
- max_length=20,
- return_dict_in_generate=True,
- output_scores=True,
- )
- assert isinstance(outputs, BeamSearchEncoderDecoderOutput)
- torch.testing.assert_close(outputs.sequences_scores, torch.tensor([-8.088160514831543]))
- torch.testing.assert_close(
- outputs.sequences,
- torch.tensor(
- [[0, 10, 30, 53, 54, 45, 15, 16, 17, 33, 33, 33, 35, 33, 58, 39, 41, 35, 33, 35]]
- ),
- )
-
- # result = tokenizer.batch_decode(
- # summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
- # )
- # assert result == [" power lines in California have been shut down on Friday."]
diff --git a/tests/models/base_models/test_bart_with_decoder_position_ids.py b/tests/models/base_models/test_bart_with_decoder_position_ids.py
deleted file mode 100644
index 0b814c64b..000000000
--- a/tests/models/base_models/test_bart_with_decoder_position_ids.py
+++ /dev/null
@@ -1,301 +0,0 @@
-import pytest
-import torch
-from torch.nn import Embedding
-from transformers import BartConfig
-from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
-from transformers.models.bart.modeling_bart import BartEncoder
-
-from pie_modules.models.base_models import BartModelWithDecoderPositionIds
-from pie_modules.models.base_models.bart_with_decoder_position_ids import (
- BartDecoderWithPositionIds,
- BartLearnedPositionalEmbeddingWithPositionIds,
-)
-
-
-def test_bart_learned_positional_embedding_with_position_ids():
- # Arrange
- torch.manual_seed(42)
- model = BartLearnedPositionalEmbeddingWithPositionIds(10, 6)
- input_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
- position_ids_original = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
- position_ids_different = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 2]])
-
- # Act
- original = model(input_ids=input_ids)
- replaced_original = model(input_ids=input_ids, position_ids=position_ids_original)
- replaced_different = model(input_ids=input_ids, position_ids=position_ids_different)
-
- # Assert
- assert original.shape == (1, 10, 6)
- assert replaced_original.shape == (1, 10, 6)
- torch.testing.assert_close(original, replaced_original)
- assert replaced_different.shape == (1, 10, 6)
- assert not torch.allclose(original, replaced_different)
-
-
-@pytest.fixture(scope="module")
-def bart_config():
- return BartConfig(
- vocab_size=30,
- d_model=10,
- encoder_layers=1,
- decoder_layers=1,
- encoder_attention_heads=2,
- decoder_attention_heads=2,
- encoder_ffn_dim=20,
- decoder_ffn_dim=20,
- max_position_embeddings=10,
- )
-
-
-@pytest.fixture(scope="module")
-def bart_decoder_with_position_ids(bart_config):
- return BartDecoderWithPositionIds(config=bart_config)
-
-
-def test_bart_decoder_with_position_ids(bart_decoder_with_position_ids):
- assert bart_decoder_with_position_ids is not None
-
-
-def test_bart_decoder_with_position_ids_get_input_embeddings(bart_decoder_with_position_ids):
- input_embeddings = bart_decoder_with_position_ids.get_input_embeddings()
- assert input_embeddings is not None
- assert isinstance(input_embeddings, Embedding)
- assert input_embeddings.embedding_dim == 10
- assert input_embeddings.num_embeddings == 30
-
-
-def test_bart_decoder_with_position_ids_set_input_embeddings(bart_decoder_with_position_ids):
- original_input_embeddings = bart_decoder_with_position_ids.get_input_embeddings()
- torch.manual_seed(42)
- new_input_embeddings = Embedding(
- original_input_embeddings.num_embeddings, original_input_embeddings.embedding_dim
- )
- bart_decoder_with_position_ids.set_input_embeddings(new_input_embeddings)
- input_embeddings = bart_decoder_with_position_ids.get_input_embeddings()
- assert input_embeddings == new_input_embeddings
- assert input_embeddings is not original_input_embeddings
- # recover original input embeddings
- bart_decoder_with_position_ids.set_input_embeddings(original_input_embeddings)
-
-
-def test_bart_decoder_with_position_ids_forward(bart_decoder_with_position_ids):
- # Arrange
- model = bart_decoder_with_position_ids
- input_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
- position_ids_original = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
- position_ids_different = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2]])
-
- # Act
- torch.manual_seed(42)
- original = model(input_ids=input_ids)
- torch.manual_seed(42)
- replaced_original = model(input_ids=input_ids, position_ids=position_ids_original)
- torch.manual_seed(42)
- replaced_different = model(input_ids=input_ids, position_ids=position_ids_different)
-
- # Assert
- assert isinstance(original, BaseModelOutputWithPastAndCrossAttentions)
- assert original.last_hidden_state.shape == (1, 8, 10)
- assert isinstance(replaced_original, BaseModelOutputWithPastAndCrossAttentions)
- torch.testing.assert_close(original.last_hidden_state, replaced_original.last_hidden_state)
-
- assert isinstance(replaced_different, BaseModelOutputWithPastAndCrossAttentions)
- assert replaced_different.last_hidden_state.shape == (1, 8, 10)
- assert not torch.allclose(original.last_hidden_state, replaced_different.last_hidden_state)
-
-
-def test_bart_decoder_with_position_ids_forward_with_inputs_embeds(bart_decoder_with_position_ids):
- # Arrange
- model = bart_decoder_with_position_ids
- inputs_embeds = torch.randn(1, 8, 10)
- position_ids_original = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
- position_ids_different = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2]])
-
- # Act
- torch.manual_seed(42)
- original = model(inputs_embeds=inputs_embeds)
- torch.manual_seed(42)
- replaced_original = model(inputs_embeds=inputs_embeds, position_ids=position_ids_original)
- torch.manual_seed(42)
- replaced_different = model(inputs_embeds=inputs_embeds, position_ids=position_ids_different)
-
- # Assert
- assert isinstance(original, BaseModelOutputWithPastAndCrossAttentions)
- assert original.last_hidden_state.shape == (1, 8, 10)
- assert isinstance(replaced_original, BaseModelOutputWithPastAndCrossAttentions)
- torch.testing.assert_close(original.last_hidden_state, replaced_original.last_hidden_state)
-
- assert isinstance(replaced_different, BaseModelOutputWithPastAndCrossAttentions)
- assert replaced_different.last_hidden_state.shape == (1, 8, 10)
- assert not torch.allclose(original.last_hidden_state, replaced_different.last_hidden_state)
-
-
-def test_bart_decoder_with_position_ids_forward_wrong_position_ids_shape(
- bart_decoder_with_position_ids,
-):
- # Arrange
- model = bart_decoder_with_position_ids
- input_ids = torch.tensor([[0, 1, 2, 3]])
- position_ids_wrong_shape = torch.tensor([[0, 1, 2]])
-
- # Act
- torch.manual_seed(42)
- with pytest.raises(ValueError) as excinfo:
- model(input_ids=input_ids, position_ids=position_ids_wrong_shape)
- assert (
- str(excinfo.value)
- == "Position IDs shape torch.Size([1, 3]) does not match input ids shape torch.Size([1, 4])."
- )
-
-
-@pytest.fixture(scope="module")
-def bart_model_with_decoder_position_ids(bart_config):
- torch.manual_seed(42)
- model = BartModelWithDecoderPositionIds(config=bart_config)
- model.train()
- return model
-
-
-def test_bart_model_with_decoder_position_ids(bart_model_with_decoder_position_ids):
- assert bart_model_with_decoder_position_ids is not None
-
-
-def test_bart_model_with_decoder_position_ids_get_input_embeddings(
- bart_model_with_decoder_position_ids,
-):
- input_embeddings = bart_model_with_decoder_position_ids.get_input_embeddings()
- assert input_embeddings is not None
- assert isinstance(input_embeddings, Embedding)
- assert input_embeddings.embedding_dim == 10
- assert input_embeddings.num_embeddings == 30
-
-
-def test_bart_model_with_decoder_position_ids_set_input_embeddings(
- bart_model_with_decoder_position_ids,
-):
- original_input_embeddings = bart_model_with_decoder_position_ids.get_input_embeddings()
- torch.manual_seed(42)
- new_input_embeddings = Embedding(
- original_input_embeddings.num_embeddings, original_input_embeddings.embedding_dim
- )
- bart_model_with_decoder_position_ids.set_input_embeddings(new_input_embeddings)
- input_embeddings = bart_model_with_decoder_position_ids.get_input_embeddings()
- assert input_embeddings == new_input_embeddings
- assert input_embeddings is not original_input_embeddings
- # recover original input embeddings
- bart_model_with_decoder_position_ids.set_input_embeddings(original_input_embeddings)
-
-
-def test_bart_model_with_decoder_position_ids_get_encoder(bart_model_with_decoder_position_ids):
- encoder = bart_model_with_decoder_position_ids.get_encoder()
- assert encoder is not None
- assert isinstance(encoder, BartEncoder)
-
-
-def test_bart_model_with_decoder_position_ids_get_decoder(bart_model_with_decoder_position_ids):
- decoder = bart_model_with_decoder_position_ids.get_decoder()
- assert decoder is not None
- assert isinstance(decoder, BartDecoderWithPositionIds)
-
-
-@pytest.mark.parametrize(
- "return_dict, prepare_encoder_outputs, output_everything",
- [(True, True, True), (False, False, False)],
-)
-def test_bart_model_with_decoder_position_forward(
- bart_model_with_decoder_position_ids, return_dict, prepare_encoder_outputs, output_everything
-):
- model = bart_model_with_decoder_position_ids
-
- # Arrange
- model.eval()
- input_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
- position_ids_original = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
- position_ids_different = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2]])
- common_kwargs = {"input_ids": input_ids, "return_dict": return_dict}
- if prepare_encoder_outputs:
- common_kwargs["encoder_outputs"] = bart_model_with_decoder_position_ids.get_encoder()(
- input_ids=input_ids, return_dict=False
- )
- else:
- common_kwargs["encoder_outputs"] = None
- if output_everything:
- common_kwargs["output_attentions"] = True
- common_kwargs["output_hidden_states"] = True
-
- # Act
- original = model(**common_kwargs)[0]
- replaced_original = model(
- decoder_position_ids=position_ids_original,
- **common_kwargs,
- )[0]
- replaced_different = model(decoder_position_ids=position_ids_different, **common_kwargs)[0]
-
- # Assert
- assert isinstance(original, torch.FloatTensor)
- assert original.shape == (1, 8, 10)
- torch.testing.assert_close(
- original[0, :5, :3],
- torch.tensor(
- [
- [0.7589594721794128, 1.0452316999435425, 0.7063764333724976],
- [-0.12192550301551819, -0.9932114481925964, -0.722382664680481],
- [0.24711951613426208, -0.291597843170166, -1.0466505289077759],
- [1.1228691339492798, -0.0873560905456543, 1.534016728401184],
- [-1.1132177114486694, 0.2277398556470871, 1.6456809043884277],
- ]
- ),
- )
- torch.testing.assert_close(
- original.sum(dim=-1),
- torch.tensor(
- [
- [
- 0.0,
- -1.1920928955078125e-07,
- -1.1920928955078125e-07,
- -2.682209014892578e-07,
- 5.960464477539063e-08,
- 5.960464477539063e-08,
- 2.384185791015625e-07,
- -5.960464477539063e-08,
- ]
- ]
- ),
- )
- assert isinstance(replaced_original, torch.FloatTensor)
- torch.testing.assert_close(original, replaced_original)
-
- assert isinstance(replaced_different, torch.FloatTensor)
- assert replaced_different.shape == (1, 8, 10)
- torch.testing.assert_close(
- replaced_different[0, :5, :3],
- torch.tensor(
- [
- [0.7589594721794128, 1.0452316999435425, 0.7063764333724976],
- [-0.0127173513174057, -0.8127143383026123, -1.256797194480896],
- [1.0517312288284302, 0.037927787750959396, -0.28661563992500305],
- [0.5884698629379272, 0.9930593371391296, 1.3842554092407227],
- [0.6132885813713074, -1.0105736255645752, 2.361264228820801],
- ]
- ),
- )
- torch.testing.assert_close(
- replaced_different.sum(dim=-1),
- torch.tensor(
- [
- [
- 0.0,
- -2.384185791015625e-07,
- -1.7881393432617188e-07,
- 2.5331974029541016e-07,
- 1.4901161193847656e-07,
- 1.1920928955078125e-07,
- -1.1920928955078125e-07,
- -1.7881393432617188e-07,
- ]
- ]
- ),
- )
- assert not torch.allclose(replaced_different, original)
diff --git a/tests/models/components/__init__.py b/tests/models/components/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/tests/models/components/test_pointer_head.py b/tests/models/components/test_pointer_head.py
deleted file mode 100644
index eb2650f68..000000000
--- a/tests/models/components/test_pointer_head.py
+++ /dev/null
@@ -1,713 +0,0 @@
-import pytest
-import torch
-from torch import nn
-
-from pie_modules.models.components.pointer_head import PointerHead
-
-
-def get_pointer_head(num_embeddings=120, embedding_dim=3, eos_id=1, pad_id=2, **kwargs):
- torch.manual_seed(42)
- return PointerHead(
- embeddings=nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim),
- # bos, eos, pad, 3 x label ids
- target_token_ids=[100, 101, 102, 110, 111, 112],
- bos_id=0, # -> 100
- eos_id=eos_id, # 1 (default) -> 101
- pad_id=pad_id, # 2 (default) -> 102
- embedding_weight_mapping={
- "110": [20, 21],
- "111": [30],
- },
- use_encoder_mlp=True,
- use_constraints_encoder_mlp=True,
- **kwargs,
- )
-
-
-def test_get_pointer_head():
- pointer_head = get_pointer_head()
- assert pointer_head is not None
- assert not pointer_head.use_prepared_position_ids
-
-
-def test_set_embeddings():
- pointer_head = get_pointer_head()
- original_embeddings = pointer_head.embeddings
- new_embeddings = nn.Embedding(
- original_embeddings.num_embeddings, original_embeddings.embedding_dim
- )
- pointer_head.set_embeddings(new_embeddings)
- assert pointer_head.embeddings is not None
- assert pointer_head.embeddings != original_embeddings
- assert pointer_head.embeddings == new_embeddings
-
-
-def test_overwrite_embeddings_with_mapping():
- pointer_head = get_pointer_head()
- original_embeddings_weight = pointer_head.embeddings.weight.clone()
- pointer_head.overwrite_embeddings_with_mapping()
- assert pointer_head.embeddings is not None
- assert not torch.equal(pointer_head.embeddings.weight, original_embeddings_weight)
- torch.testing.assert_close(
- pointer_head.embeddings.weight[110], original_embeddings_weight[[20, 21]].mean(dim=0)
- )
- torch.testing.assert_close(
- pointer_head.embeddings.weight[111], original_embeddings_weight[[30]].mean(dim=0)
- )
-
-
-@pytest.mark.parametrize(
- "use_attention_mask",
- [True, False],
-)
-def test_prepare_decoder_input_ids(use_attention_mask):
- pointer_head = get_pointer_head()
- encoder_input_ids = torch.tensor(
- [
- [10, 11, 12, 13, 14, 15],
- [20, 21, 22, 23, 24, 0],
- ]
- ).to(torch.long)
- # we have 3 special tokens (bos, eos, pad) and 3 labels, so the offset is 6
- input_ids = torch.tensor(
- [
- # bos, offset (0=6-6), offset (1=7-6), label (3), label (4), offset (2=8-6)
- [0, 6, 7, 3, 4, 8],
- # bos, label (3), offset (3=9-6), eos, pad, pad
- [0, 3, 9, 1, 2, 2],
- ]
- ).to(torch.long)
- # this is the attention mask for the (decoder) input_ids, not the encoder_input_ids
- attention_mask = (
- torch.tensor(
- [
- [1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 0, 0],
- ]
- ).to(torch.long)
- if use_attention_mask
- else None
- )
-
- prepared_decoder_input_ids = pointer_head.prepare_decoder_input_ids(
- input_ids=input_ids,
- encoder_input_ids=encoder_input_ids,
- )
- assert prepared_decoder_input_ids is not None
- assert prepared_decoder_input_ids.shape == input_ids.shape
- # to recap, the target2token_id mapping is (bos, eos, pad, 3 x label ids)
- torch.testing.assert_close(
- pointer_head.target2token_id, torch.tensor([100, 101, 102, 110, 111, 112])
- )
- # 3 labels + bos / pad
- assert pointer_head.pointer_offset == 6
- assert prepared_decoder_input_ids.tolist() == [
- # bos (0), offset (0=6-6), offset (1=7-6), label (3), label (4), offset (2=8-6)
- [100, 10, 11, 110, 111, 12],
- # bos (0), label (3), offset (3=9-6), eos (1), pad (2), pad (2)
- [100, 110, 23, 101, 102, 102],
- ]
-
-
-def test_prepare_decoder_input_ids_out_of_bounds():
- pointer_head = get_pointer_head()
- # 3 labels + bos / pad
- assert pointer_head.pointer_offset == 6
- encoder_input_ids = torch.tensor(
- [
- [100, 101, 102],
- ]
- ).to(torch.long)
- input_ids = torch.tensor(
- [
- # 9 is out of bounds: > pointer_head.pointer_offset + len(encoder_input_ids)
- [0, 9],
- ]
- ).to(torch.long)
-
- with pytest.raises(ValueError) as excinfo:
- pointer_head.prepare_decoder_input_ids(
- input_ids=input_ids, encoder_input_ids=encoder_input_ids
- )
- assert str(excinfo.value) == (
- "encoder_input_ids_index.max() [3] must be smaller than encoder_input_length [3]!"
- )
-
-
-@pytest.mark.parametrize(
- "decoder_position_id_mode",
- ["pattern", "pattern_with_increment", "mapping"],
-)
-def test_prepare_decoder_position_ids(decoder_position_id_mode):
- pointer_head = get_pointer_head(
- decoder_position_id_mode=decoder_position_id_mode,
- decoder_position_id_pattern=[0, 1, 1, 2],
- decoder_position_id_mapping={"default": 3, "vocab": 2, "bos": 0, "eos": 0, "pad": 1},
- )
- input_ids = torch.tensor(
- [
- # bos, offset (0=6-6), offset (1=7-6), label (3), label (4), offset (2=8-6)
- [0, 6, 7, 3, 4, 8],
- # bos, label (3), offset (3=9-6), eos, pad, pad
- [0, 3, 9, 1, 2, 2],
- ]
- ).to(torch.long)
-
- prepared_decoder_position_ids = pointer_head.prepare_decoder_position_ids(input_ids=input_ids)
- assert prepared_decoder_position_ids is not None
- assert prepared_decoder_position_ids.shape == input_ids.shape
- if decoder_position_id_mode == "pattern":
- assert prepared_decoder_position_ids.tolist() == [
- [0, 2, 3, 3, 4, 2],
- [0, 2, 3, 3, 1, 1],
- ]
- elif decoder_position_id_mode == "pattern_with_increment":
- # the position ids (except for position-bos=0 and position-pad=1) get increased by 3 per record
- # (which has length 4)
- assert prepared_decoder_position_ids.tolist() == [
- [0, 2, 3, 3, 4, 5],
- [0, 2, 3, 3, 1, 1],
- ]
- elif decoder_position_id_mode == "mapping":
- assert prepared_decoder_position_ids.tolist() == [
- [0, 3, 3, 2, 2, 3],
- [0, 2, 3, 0, 1, 1],
- ]
- else:
- raise ValueError(f"unknown decoder_position_id_mode={decoder_position_id_mode}")
-
-
-def test_prepare_decoder_position_ids_unknown_mode():
- with pytest.raises(ValueError) as excinfo:
- get_pointer_head(decoder_position_id_mode="unknown")
- assert str(excinfo.value) == (
- 'decoder_position_id_mode="unknown" is not supported, use one of "pattern", '
- '"pattern_with_increment", or "mapping"!'
- )
-
-
-@pytest.mark.parametrize(
- "decoder_position_id_mode",
- ["pattern", "pattern_with_increment", "mapping"],
-)
-def test_prepare_decoder_position_ids_missing_parameter(decoder_position_id_mode):
- with pytest.raises(ValueError) as excinfo:
- get_pointer_head(decoder_position_id_mode=decoder_position_id_mode)
- if decoder_position_id_mode in ["pattern", "pattern_with_increment"]:
- assert (
- str(excinfo.value) == "decoder_position_id_pattern must be provided when using "
- 'decoder_position_id_mode="pattern" or "pattern_with_increment"!'
- )
- elif decoder_position_id_mode == "mapping":
- assert (
- str(excinfo.value)
- == 'decoder_position_id_mode="mapping" requires decoder_position_id_mapping to be provided!'
- )
- else:
- raise ValueError(f"unknown decoder_position_id_mode={decoder_position_id_mode}")
-
-
-def test_prepare_decoder_position_ids_with_wrong_mapping():
- input_ids = torch.tensor(
- [
- # bos, offset (0=6-6), offset (1=7-6), label (3), label (4), offset (2=8-6)
- [0, 6, 7, 3, 4, 8],
- # bos, label (3), offset (3=9-6), eos, pad, pad
- [0, 3, 9, 1, 2, 2],
- ]
- ).to(torch.long)
-
- # missing default
- pointer_head = get_pointer_head(
- decoder_position_id_mode="mapping",
- decoder_position_id_mapping={"vocab": 2, "bos": 0, "eos": 0, "pad": 1},
- )
- with pytest.raises(ValueError) as excinfo:
- pointer_head.prepare_decoder_position_ids(input_ids=input_ids)
- assert (
- str(excinfo.value)
- == "mapping must contain a default entry, but only contains ['vocab', 'bos', 'eos', 'pad']!"
- )
-
- # unknown key
- pointer_head = get_pointer_head(
- decoder_position_id_mode="mapping",
- decoder_position_id_mapping={
- "default": 3,
- "vocab": 2,
- "bos": 0,
- "eos": 0,
- "pad": 1,
- "unknown": 4,
- },
- )
- with pytest.raises(ValueError) as excinfo:
- pointer_head.prepare_decoder_position_ids(input_ids=input_ids)
- assert (
- str(excinfo.value) == "Mapping contains unknown key 'unknown' "
- "(mapping: {'default': 3, 'vocab': 2, 'bos': 0, 'eos': 0, 'pad': 1, 'unknown': 4})."
- )
-
- # multiple values for same input id
- pointer_head = get_pointer_head(
- # same id for eos and pad
- eos_id=1,
- pad_id=1,
- decoder_position_id_mode="mapping",
- decoder_position_id_mapping={
- "default": 3,
- "vocab": 2,
- "bos": 0,
- # different position ids for eos and pad, this is not allowed when eos and pad have the same id
- "eos": 0,
- "pad": 1,
- },
- )
- with pytest.raises(ValueError) as excinfo:
- pointer_head.prepare_decoder_position_ids(input_ids=input_ids)
- assert (
- str(excinfo.value)
- == "Can not set the position ids for 'pad' to 1 because it was already set to 0 by key 'eos'. "
- "Note that both, 'pad' and 'eos', have the same id (1), so their position_ids need to be "
- "also the same (position id mapping: {'default': 3, 'vocab': 2, 'bos': 0, 'eos': 0, 'pad': 1})."
- )
-
-
-def test_prepare_decoder_inputs():
- pointer_head = get_pointer_head(
- decoder_position_id_mode="pattern", decoder_position_id_pattern=[0, 1, 1, 2]
- )
- encoder_input_ids = torch.tensor(
- [
- [10, 11, 12, 13, 14, 15],
- [20, 21, 22, 23, 24, 0],
- ]
- ).to(torch.long)
- input_ids = torch.tensor(
- [
- # bos, offset (0=6-6), offset (1=7-6), label (3), label (4), offset (2=8-6)
- [0, 6, 7, 3, 4, 8],
- # bos, label (3), offset (3=9-6), eos, pad, pad
- [0, 3, 9, 1, 2, 2],
- ]
- ).to(torch.long)
-
- decoder_inputs = pointer_head.prepare_decoder_inputs(
- input_ids=input_ids,
- encoder_input_ids=encoder_input_ids,
- )
- assert set(decoder_inputs.keys()) == {"input_ids", "position_ids"}
- assert decoder_inputs["input_ids"].shape == input_ids.shape
- assert decoder_inputs["position_ids"].shape == input_ids.shape
- # to recap, the target2token_id mapping is (bos, eos, pad, 3 x label ids)
- torch.testing.assert_close(
- pointer_head.target2token_id, torch.tensor([100, 101, 102, 110, 111, 112])
- )
- # 3 labels + bos / pad
- assert pointer_head.pointer_offset == 6
- assert decoder_inputs["input_ids"].tolist() == [
- # bos (0), offset (0=6-6), offset (1=7-6), label (3), label (4), offset (2=8-6)
- [100, 10, 11, 110, 111, 12],
- # bos (0), label (3), offset (3=9-6), eos (1), pad (2), pad (2)
- [100, 110, 23, 101, 102, 102],
- ]
- assert decoder_inputs["position_ids"].tolist() == [
- [0, 2, 3, 3, 4, 2],
- [0, 2, 3, 3, 1, 1],
- ]
-
-
-def test_forward():
- pointer_head = get_pointer_head()
- # shape: (batch_size=2, input_sequence_length=5)
- encoder_input_ids = torch.tensor(
- [
- [10, 11, 12, 13, 14],
- [20, 21, 22, 23, 0],
- ]
- ).to(torch.long)
- encoder_attention_mask = torch.tensor(
- [
- [1, 1, 1, 1, 1],
- [1, 1, 1, 1, 0],
- ]
- ).to(torch.long)
- # shape: (batch_size=2, input_sequence_length=5, hidden_size=3)
- encoder_last_hidden_state = pointer_head.embeddings(encoder_input_ids)
- # shape: (batch_size=2, target_sequence_length=4)
- prepared_input_ids = torch.tensor(
- [
- # bos (0), offset (0=6-6), offset (1=7-6), label (3)
- [100, 10, 11, 110],
- # bos (0), label (3), offset (3=9-6), eos (1)
- [100, 110, 23, 101],
- ]
- ).to(torch.long)
- # shape: (batch_size=2, target_sequence_length=4)
- last_hidden_state = pointer_head.embeddings(prepared_input_ids)
-
- torch.manual_seed(42)
- logits, loss = pointer_head(
- encoder_input_ids=encoder_input_ids,
- encoder_attention_mask=encoder_attention_mask,
- encoder_last_hidden_state=encoder_last_hidden_state,
- last_hidden_state=last_hidden_state,
- )
- assert loss is None
- assert logits is not None
- # shape: (batch_size=2, target_sequence_length=4, num_targets+num_offsets=6+5==11)
- assert logits.shape == (2, 4, 11)
- torch.testing.assert_close(
- logits,
- torch.tensor(
- [
- [
- [
- -1.0000000138484279e24,
- -0.9407045245170593,
- -1.0000000138484279e24,
- 0.5535521507263184,
- 0.04295700043439865,
- 1.0467679500579834,
- -1.110795497894287,
- 1.1652655601501465,
- 0.09444020688533783,
- 0.43052661418914795,
- -1.0437036752700806,
- ],
- [
- -1.0000000138484279e24,
- 1.1563994884490967,
- -1.0000000138484279e24,
- -0.8941665887832642,
- -0.6862093806266785,
- -1.154745101928711,
- 1.6984729766845703,
- -1.3889904022216797,
- -0.4076152741909027,
- -1.0112841129302979,
- 0.9846026301383972,
- ],
- [
- -1.0000000138484279e24,
- -1.9377808570861816,
- -1.0000000138484279e24,
- 2.437451124191284,
- 0.041493892669677734,
- 0.5383729338645935,
- -1.5238577127456665,
- 1.6700562238693237,
- -0.07231226563453674,
- 1.0911093950271606,
- -0.9189060926437378,
- ],
- [
- -1.0000000138484279e24,
- -1.880744218826294,
- -1.0000000138484279e24,
- 3.8719429969787598,
- 0.07287894189357758,
- -1.3378281593322754,
- -0.653921365737915,
- 0.783344566822052,
- -0.3344290256500244,
- 1.3571363687515259,
- 0.5505899786949158,
- ],
- ],
- [
- [
- -1.0000000138484279e24,
- -0.9407045245170593,
- -1.0000000138484279e24,
- 0.5535521507263184,
- 0.04295700043439865,
- 1.0467679500579834,
- -1.0019789934158325,
- 0.6891120672225952,
- -0.002076566219329834,
- 0.7561025619506836,
- -1.0000000331813535e32,
- ],
- [
- -1.0000000138484279e24,
- -1.880744218826294,
- -1.0000000138484279e24,
- 3.8719429969787598,
- 0.07287894189357758,
- -1.3378281593322754,
- -1.3875324726104736,
- -2.124865770339966,
- -2.559859275817871,
- 0.5425653457641602,
- -1.0000000331813535e32,
- ],
- [
- -1.0000000138484279e24,
- -1.479057788848877,
- -1.0000000138484279e24,
- 1.7857770919799805,
- 0.6723557114601135,
- 0.6378745436668396,
- -2.262815475463867,
- -0.1536862850189209,
- -0.5338708758354187,
- 1.3628911972045898,
- -1.0000000331813535e32,
- ],
- [
- -1.0000000138484279e24,
- 1.1815755367279053,
- -1.0000000138484279e24,
- -1.880744218826294,
- -0.10646091401576996,
- 0.1437276005744934,
- 1.0795626640319824,
- 0.6434042453765869,
- 1.0681594610214233,
- -0.5814396142959595,
- -1.0000000331813535e32,
- ],
- ],
- ]
- ),
- )
-
-
-@pytest.mark.parametrize(
- "with_constraints",
- [True, False],
-)
-def test_forward_with_labels(with_constraints):
- pointer_head = get_pointer_head(num_embeddings=300, embedding_dim=3)
-
- # shape: (batch_size=2, input_sequence_length=5)
- encoder_input_ids = torch.tensor(
- [
- [10, 11, 12, 13, 14],
- [20, 21, 22, 0, 0],
- ]
- ).to(torch.long)
- encoder_attention_mask = torch.tensor(
- [
- [1, 1, 1, 1, 1],
- [1, 1, 1, 0, 0],
- ]
- ).to(torch.long)
- # shape: (batch_size=2, input_sequence_length=5, hidden_size=3)
- # encoder_last_hidden_state = pointer_head.embeddings(encoder_input_ids)
- # shape: (batch_size=2, target_sequence_length=4)
- prepared_input_ids = torch.tensor(
- [
- # bos (0), offset (0=6-6), offset (1=7-6), label (3)
- [100, 10, 11, 110],
- # bos (0), label (3), offset (3=9-6), eos (1)
- [100, 110, 23, 101],
- ]
- ).to(torch.long)
- # shape: (batch_size=2, target_sequence_length=4)
- # last_hidden_state = pointer_head.embeddings(prepared_input_ids)
- labels = torch.tensor(
- [
- # offset (0=6-6), offset (1=7-6), label (3), label (4)
- [6, 7, 3, 4],
- # label (3), offset (3=9-6), eos, pad, pad
- [3, 9, 1, 2],
- ]
- ).to(torch.long)
- decoder_attention_mask = torch.tensor(
- [
- [1, 1, 1, 1],
- [1, 1, 1, 0],
- ]
- ).to(torch.long)
-
- # shape: (batch_size=2, target_sequence_length=4, num_targets+num_offsets=6+5==11)
- constraints = (
- # recap: the target2token_id mapping is (bos, eos, pad, 3 x label ids)
- torch.tensor(
- [
- [
- # allow all labels
- [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0],
- # allow all offsets different from previous label id (3)
- [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0],
- # allow all offsets different from previous label ids (3, 4)
- [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
- # allow all offsets
- [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
- ],
- [
- # allow all labels
- [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0],
- # allow all offsets
- [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
- # allow all offsets equal or bigger than previous one (9) or eos
- [0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1],
- # allow only pad
- [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
- ],
- ]
- ).to(torch.long)
- )
-
- torch.manual_seed(42)
- # shape: (batch_size=2, input_sequence_length=6, hidden_size=3)
- encoder_last_hidden_state = pointer_head.embeddings(encoder_input_ids)
- last_hidden_state = pointer_head.embeddings(prepared_input_ids)
- _, loss = pointer_head(
- encoder_input_ids=encoder_input_ids,
- encoder_attention_mask=encoder_attention_mask,
- encoder_last_hidden_state=encoder_last_hidden_state,
- last_hidden_state=last_hidden_state,
- labels=labels,
- decoder_attention_mask=decoder_attention_mask,
- constraints=constraints if with_constraints else None,
- )
- assert loss is not None
- maybe_gradients = torch.autograd.grad(loss, pointer_head.parameters(), allow_unused=True)
- gradients = [g for g in maybe_gradients if g is not None]
- if not with_constraints:
- # embeddings.weight, 2 x (encoder_mlp.weight, encoder_mlp.bias)
- assert len(gradients) == 5
- # embeddings.weight (just check entries for special tokens and labels)
- torch.testing.assert_close(
- gradients[0][100:113],
- torch.tensor(
- [
- [0.29642319679260254, 0.012336060404777527, 0.14099650084972382],
- [0.015981415286660194, 0.17855659127235413, -0.21089009940624237],
- [0.0, 0.0, 0.0],
- [0.0, 0.0, 0.0],
- [0.0, 0.0, 0.0],
- [0.0, 0.0, 0.0],
- [0.0, 0.0, 0.0],
- [0.0, 0.0, 0.0],
- [0.0, 0.0, 0.0],
- [0.0, 0.0, 0.0],
- [-0.8812153935432434, -0.43322375416755676, 0.07359108328819275],
- [0.22255337238311768, 0.09604272246360779, 0.017692387104034424],
- [-0.021408570930361748, -0.01747075282037258, 0.15882402658462524],
- ]
- ),
- )
- # first encoder_mlp.weight
- torch.testing.assert_close(
- gradients[1],
- torch.tensor(
- [
- [6.044770998414606e-05, -0.001140016596764326, 0.0007320810691453516],
- [0.014351745136082172, 0.01521987747400999, -0.028653975576162338],
- [0.011420723050832748, 0.0070406426675617695, -0.030101824551820755],
- ]
- ),
- )
- # first encoder_mlp.bias
- torch.testing.assert_close(
- gradients[2],
- torch.tensor([-0.0006180311902426183, -0.023118967190384865, -0.024205176159739494]),
- )
- # second encoder_mlp.weight
- torch.testing.assert_close(
- gradients[3],
- torch.tensor(
- [
- [-0.0005463349516503513, -0.016356423497200012, 0.01958528161048889],
- [-0.0005303063080646098, -0.029644077643752098, -0.1391362100839615],
- [0.0028533015865832567, 0.08096987009048462, 0.28279614448547363],
- ]
- ),
- )
- # second encoder_mlp.bias
- torch.testing.assert_close(
- gradients[4],
- torch.tensor([-0.030467912554740906, -0.045307278633117676, 0.06145985424518585]),
- )
- else:
- # embeddings.weight, 2 x (encoder_mlp.weight, encoder_mlp.bias), 2 x (constraints_encoder_mlp.weight, constraints_encoder_mlp.bias)
- assert len(gradients) == 9
- # embeddings.weight (just check entries for special tokens and labels)
- torch.testing.assert_close(
- gradients[0][100:113],
- torch.tensor(
- [
- [0.2915953993797302, 0.009700030088424683, 0.1484404355287552],
- [0.02216985821723938, 0.15251068770885468, -0.21624334156513214],
- [0.0, 0.0, 0.0],
- [0.0, 0.0, 0.0],
- [0.0, 0.0, 0.0],
- [0.0, 0.0, 0.0],
- [0.0, 0.0, 0.0],
- [0.0, 0.0, 0.0],
- [0.0, 0.0, 0.0],
- [0.0, 0.0, 0.0],
- [-0.8804605007171631, -0.4300656318664551, 0.0664108395576477],
- [0.21543428301811218, 0.093157559633255, 0.013825103640556335],
- [-0.021408570930361748, -0.01747075282037258, 0.15882402658462524],
- ]
- ),
- )
- # first encoder_mlp.weight
- torch.testing.assert_close(
- gradients[1],
- torch.tensor(
- [
- [-0.0003244421095587313, 0.006118832156062126, -0.003929311875253916],
- [0.013681752607226372, 0.013532182201743126, -0.027564184740185738],
- [0.012365758419036865, 0.00791379064321518, -0.02969365194439888],
- ]
- ),
- )
- # first encoder_mlp.bias
- torch.testing.assert_close(
- gradients[2],
- torch.tensor([0.003317170077934861, -0.021803036332130432, -0.023893579840660095]),
- )
- # second encoder_mlp.weight
- torch.testing.assert_close(
- gradients[3],
- torch.tensor(
- [
- [-0.004014550242573023, -0.018573174253106117, 0.019694898277521133],
- [-0.0019358742283657193, -0.030542463064193726, -0.13909178972244263],
- [0.0009692738531157374, 0.0797656774520874, 0.28285568952560425],
- ]
- ),
- )
- # second encoder_mlp.bias
- torch.testing.assert_close(
- gradients[4],
- torch.tensor([-0.046919066458940506, -0.05197446048259735, 0.05252313241362572]),
- )
- # first constraints_encoder_mlp.weight
- torch.testing.assert_close(
- gradients[5],
- torch.tensor(
- [
- [0.010755524039268494, -0.009512078016996384, -0.007983260788023472],
- [0.004236628767102957, 0.002073169220238924, -0.0010695274686440825],
- [-0.008700753562152386, -0.00425766222178936, 0.002196485875174403],
- ]
- ),
- )
- # first constraints_encoder_mlp.bias
- torch.testing.assert_close(
- gradients[6],
- torch.tensor([0.05254765599966049, -0.0024578727316111326, 0.005047726910561323]),
- )
- # second constraints_encoder_mlp.weight
- torch.testing.assert_close(
- gradients[7],
- torch.tensor(
- [
- [0.004190368112176657, -0.01078515499830246, -0.015312351286411285],
- [0.001505501102656126, -0.006679146084934473, -0.009482797235250473],
- [0.02189277485013008, -0.010388202033936977, -0.014748772606253624],
- ]
- ),
- )
- # second constraints_encoder_mlp.bias
- torch.testing.assert_close(
- gradients[8],
- torch.tensor([0.016296036541461945, -0.00018996000289916992, 0.05888192355632782]),
- )
diff --git a/tests/models/components/test_pooler.py b/tests/models/components/test_pooler.py
index a63b36a68..e69de29bb 100644
--- a/tests/models/components/test_pooler.py
+++ b/tests/models/components/test_pooler.py
@@ -1,218 +0,0 @@
-import pytest
-import torch
-
-from pie_modules.models.components.pooler import (
- CLS_TOKEN,
- MENTION_POOLING,
- START_TOKENS,
- ArgumentWrappedPooler,
- AtIndexPooler,
- SpanMaxPooler,
- SpanMeanPooler,
- get_pooler_and_output_size,
- pool_cls,
-)
-
-
-@pytest.mark.parametrize(
- "pooler_type",
- [
- CLS_TOKEN,
- START_TOKENS,
- MENTION_POOLING,
- ],
-)
-def test_get_pooler_and_output_size(pooler_type):
- pooler, output_size = get_pooler_and_output_size(config={"type": pooler_type}, input_dim=20)
- assert pooler is not None
- if pooler_type == CLS_TOKEN:
- assert output_size == 20
- elif pooler_type in (START_TOKENS, MENTION_POOLING):
- # pre default, num_indices is 2
- assert output_size == 20 * 2
- else:
- raise ValueError(f"Unknown pooler type {pooler_type}")
-
-
-@pytest.mark.parametrize("aggregate", ["max", "mean"])
-def test_get_pooler_and_output_size_mention(aggregate):
- pooler, output_size = get_pooler_and_output_size(
- config={"type": MENTION_POOLING, "aggregate": aggregate}, input_dim=20
- )
- assert pooler is not None
- assert output_size == 20 * 2
- if aggregate == "max":
- assert isinstance(pooler, SpanMaxPooler)
- elif aggregate == "mean":
- assert isinstance(pooler, SpanMeanPooler)
- else:
- raise ValueError(f"Unknown aggregate type {aggregate}")
-
-
-def test_get_pooler_and_output_size_mention_unknown_aggregate():
- with pytest.raises(ValueError) as excinfo:
- get_pooler_and_output_size(
- config={"type": MENTION_POOLING, "aggregate": "unknown"}, input_dim=20
- )
- assert str(excinfo.value) == 'Unknown aggregation method for mention pooling: "unknown"'
-
-
-def test_get_pooler_and_output_size_wrong_type():
- with pytest.raises(ValueError) as excinfo:
- get_pooler_and_output_size(config={"type": "wrong_type"}, input_dim=20)
- assert str(excinfo.value) == 'Unknown pooler type "wrong_type"'
-
-
-@pytest.fixture(scope="session")
-def hidde_state():
- result = torch.tensor(
- [
- [[0.00, 0.01], [0.10, 0.11], [0.20, 0.21], [0.30, 0.31]],
- [[1.00, 1.01], [1.10, 1.11], [1.20, 1.21], [1.30, 1.31]],
- ]
- )
- # batch_size x sequence_length x hidden_size
- assert result.shape == (2, 4, 2)
- return result
-
-
-def test_pool_cls(hidde_state):
- pooler = pool_cls
- output = pooler(hidden_state=hidde_state)
- assert output is not None
- batch_size = hidde_state.shape[0]
- hidden_size = hidde_state.shape[-1]
- assert output.shape == (batch_size, hidden_size)
- torch.testing.assert_close(output, hidde_state[:, 0, :])
- torch.testing.assert_close(output, torch.tensor([[0.00, 0.01], [1.00, 1.01]]))
-
-
-def test_at_index_pooler(hidde_state):
- pooler = AtIndexPooler(input_dim=hidde_state.shape[-1], num_indices=2)
- indices = torch.tensor([[2, 0], [1, 0]])
- output = pooler(hidden_state=hidde_state, indices=indices)
- assert output is not None
- batch_size = hidde_state.shape[0]
- hidden_size = hidde_state.shape[-1]
- # times num_indices (=2) due to concat
- assert output.shape == (batch_size, hidden_size * 2)
- torch.testing.assert_close(
- output, torch.tensor([[0.20, 0.21, 0.00, 0.01], [1.10, 1.11, 1.00, 1.01]])
- )
-
-
-def test_at_index_pooler_with_offset(hidde_state):
- # set the seed to make sure that we get the same missing embeddings
- torch.manual_seed(42)
- pooler = AtIndexPooler(input_dim=hidde_state.shape[-1], num_indices=2, offset=-1)
- indices = torch.tensor([[2, 1], [0, -10]])
- output = pooler(hidden_state=hidde_state, indices=indices)
- assert output is not None
- batch_size = hidde_state.shape[0]
- hidden_size = hidde_state.shape[-1]
- # times num_indices (=2) due to concat
- assert output.shape == (batch_size, hidden_size * 2)
- # the second batch element has out of bounds indices, so we expect the missing embeddings
- # it needs to be flattened, because the output is concatenated
- torch.testing.assert_close(output[1], pooler.missing_embeddings.view(-1))
- torch.testing.assert_close(
- output,
- torch.tensor(
- [
- [0.10, 0.11, 0.00, 0.01],
- [
- 0.33669036626815796,
- 0.12880940735340118,
- 0.23446236550807953,
- 0.23033303022384644,
- ],
- ]
- ),
- )
-
-
-def test_at_index_pooler_wrong_indices_shapes(hidde_state):
- pooler = AtIndexPooler(input_dim=hidde_state.shape[-1], num_indices=2)
- indices = torch.tensor([[2, 0, 1], [1, 0, 0]])
- with pytest.raises(ValueError) as excinfo:
- pooler(hidden_state=hidde_state, indices=indices)
- assert str(excinfo.value) == "number of indices [3] has to be the same as num_types [2]"
-
-
-def test_argument_wrapped_pooler(hidde_state):
- def dummy_pooler(hidden_state, y):
- return hidden_state[:, 0, :]
-
- pooler = ArgumentWrappedPooler(pooler=dummy_pooler, argument_mapping={"x": "y"})
- output = pooler(hidden_state=hidde_state, x="dummy")
- assert output is not None
- batch_size = hidde_state.shape[0]
- hidden_size = hidde_state.shape[-1]
- assert output.shape == (batch_size, hidden_size)
- torch.testing.assert_close(output, hidde_state[:, 0, :])
-
-
-def test_span_max_pooler(hidde_state):
- pooler = SpanMaxPooler(input_dim=hidde_state.shape[-1], num_indices=2)
- start_indices = torch.tensor([[2, 0], [0, 1]])
- end_indices = torch.tensor([[3, 3], [1, 2]])
- output = pooler(hidden_state=hidde_state, start_indices=start_indices, end_indices=end_indices)
- assert output is not None
- batch_size = hidde_state.shape[0]
- hidden_size = hidde_state.shape[-1]
- # times num_indices (=2) due to concat
- assert output.shape == (batch_size, hidden_size * 2)
- torch.testing.assert_close(
- output, torch.tensor([[0.20, 0.21, 0.20, 0.21], [1.00, 1.01, 1.10, 1.11]])
- )
-
-
-def test_span_max_pooler_wrong_start_indices_shape(hidde_state):
- pooler = SpanMaxPooler(input_dim=hidde_state.shape[-1], num_indices=2)
- start_indices = torch.tensor([[2, 0, 1], [0, 1, 0]])
- end_indices = torch.tensor([[3, 3], [1, 2]])
- with pytest.raises(ValueError) as excinfo:
- pooler(hidden_state=hidde_state, start_indices=start_indices, end_indices=end_indices)
- assert str(excinfo.value) == (
- "number of start indices [3] has to be the same as num_types [2]"
- )
-
-
-def test_span_max_pooler_wrong_end_indices_shape(hidde_state):
- pooler = SpanMaxPooler(input_dim=hidde_state.shape[-1], num_indices=2)
- start_indices = torch.tensor([[2, 0], [0, 1]])
- end_indices = torch.tensor([[3, 3, 3], [1, 2, 1]])
- with pytest.raises(ValueError) as excinfo:
- pooler(hidden_state=hidde_state, start_indices=start_indices, end_indices=end_indices)
- assert str(excinfo.value) == ("number of end indices [3] has to be the same as num_types [2]")
-
-
-def test_span_max_pooler_start_indices_bigger_than_end_indices(hidde_state):
- pooler = SpanMaxPooler(input_dim=hidde_state.shape[-1], num_indices=2)
- start_indices = torch.tensor([[2, 0], [0, 1]])
- end_indices = torch.tensor([[1, 3], [1, 2]])
- with pytest.raises(ValueError) as excinfo:
- pooler(hidden_state=hidde_state, start_indices=start_indices, end_indices=end_indices)
- assert str(excinfo.value) == (
- "values in start_indices have to be smaller than respective values in end_indices, but start_indices=\n"
- "tensor([[2, 0],\n"
- " [0, 1]])\n "
- "and end_indices=\n"
- "tensor([[1, 3],\n"
- " [1, 2]])"
- )
-
-
-def test_span_mean_pooler(hidde_state):
- pooler = SpanMeanPooler(input_dim=hidde_state.shape[-1], num_indices=2)
- start_indices = torch.tensor([[2, 0], [0, 1]])
- end_indices = torch.tensor([[3, 3], [1, 2]])
- output = pooler(hidden_state=hidde_state, start_indices=start_indices, end_indices=end_indices)
- assert output is not None
- batch_size = hidde_state.shape[0]
- hidden_size = hidde_state.shape[-1]
- # times num_indices (=2) due to concat
- assert output.shape == (batch_size, hidden_size * 2)
- torch.testing.assert_close(
- output, torch.tensor([[0.20, 0.21, 0.10, 0.11], [1.00, 1.01, 1.10, 1.11]])
- )
diff --git a/tests/models/components/test_seq2seq_encoder.py b/tests/models/components/test_seq2seq_encoder.py
deleted file mode 100644
index 58e2fc0c5..000000000
--- a/tests/models/components/test_seq2seq_encoder.py
+++ /dev/null
@@ -1,93 +0,0 @@
-import pytest
-import torch
-
-from pie_modules.models.components.seq2seq_encoder import (
- ACTIVATION_TYPE2CLASS,
- RNN_TYPE2CLASS,
- build_seq2seq_encoder,
-)
-
-
-def test_no_encoder():
- seq2seq_dict = {}
- input_size = 10
- encoder, output_size = build_seq2seq_encoder(seq2seq_dict, input_size)
- assert encoder is None
- assert output_size == input_size
-
- seq2seq_dict = {
- "type": "sequential",
- "rnn_layer": {
- "type": "none",
- },
- }
- input_size = 10
- encoder, output_size = build_seq2seq_encoder(seq2seq_dict, input_size)
- assert len(encoder) == 0
- assert output_size == input_size
-
-
-@pytest.mark.parametrize("seq2seq_enc_type", list(RNN_TYPE2CLASS))
-@pytest.mark.parametrize("bidirectional", [True, False])
-def test_rnn_encoder(seq2seq_enc_type, bidirectional):
- hidden_size = 99
- seq2seq_dict = {
- "type": seq2seq_enc_type,
- "hidden_size": hidden_size,
- "bidirectional": bidirectional,
- }
- input_size = 10
- encoder, output_size = build_seq2seq_encoder(seq2seq_dict, input_size)
- assert encoder is not None
- assert isinstance(encoder.rnn, RNN_TYPE2CLASS[seq2seq_enc_type])
-
- expected_output_size = hidden_size * 2 if bidirectional else hidden_size
- assert output_size is not None
- assert output_size == expected_output_size
-
-
-@pytest.mark.parametrize("activation_type", list(ACTIVATION_TYPE2CLASS))
-def test_activations(activation_type):
- seq2seq_dict = {
- "type": activation_type,
- }
- input_size = 10
- encoder, output_size = build_seq2seq_encoder(seq2seq_dict, input_size)
- assert encoder is not None
- assert isinstance(encoder, ACTIVATION_TYPE2CLASS[activation_type])
- assert output_size == input_size
-
-
-def test_dropout():
- seq2seq_dict = {
- "type": "dropout",
- "p": 0.5,
- }
- input_size = 10
- encoder, output_size = build_seq2seq_encoder(seq2seq_dict, input_size)
- assert encoder is not None
- assert isinstance(encoder, torch.nn.Dropout)
- assert output_size == input_size
-
-
-def test_linear():
- out_features = 99
- seq2seq_dict = {
- "type": "linear",
- "out_features": out_features,
- }
-
- input_size = 10
- encoder, output_size = build_seq2seq_encoder(seq2seq_dict, input_size)
- assert encoder is not None
- assert isinstance(encoder, torch.nn.Linear)
- assert output_size == out_features
-
-
-def test_unknown_rnn_type():
- seq2seq_dict = {
- "type": "unknown",
- }
- with pytest.raises(ValueError) as exc_info:
- build_seq2seq_encoder(seq2seq_dict, 10)
- assert str(exc_info.value) == "Unknown seq2seq_encoder_type: unknown"
diff --git a/tests/models/test_extractive_question_answering.py b/tests/models/test_extractive_question_answering.py
deleted file mode 100644
index 9d8b6c1cb..000000000
--- a/tests/models/test_extractive_question_answering.py
+++ /dev/null
@@ -1,245 +0,0 @@
-import json
-
-import pytest
-import torch
-import transformers
-from pytorch_lightning import Trainer
-
-from pie_modules.annotations import ExtractiveAnswer, Question
-from pie_modules.documents import TextDocumentWithQuestionsAndExtractiveAnswers
-from pie_modules.models.simple_extractive_question_answering import (
- SimpleExtractiveQuestionAnsweringModel,
-)
-from pie_modules.taskmodules.extractive_question_answering import (
- ExtractiveQuestionAnsweringTaskModule,
-)
-from tests import DUMP_FIXTURE_DATA, FIXTURES_ROOT
-
-FIXTURES_TASKMODULE_DATA_PATH = FIXTURES_ROOT / "taskmodules" / "extractive_question_answering"
-
-
-@pytest.fixture
-def documents():
- document0 = TextDocumentWithQuestionsAndExtractiveAnswers(
- text="This is a test document", id="doc0"
- )
- document0.questions.append(Question(text="What is the first word?"))
- document0.answers.append(ExtractiveAnswer(question=document0.questions[0], start=0, end=3))
-
- document1 = TextDocumentWithQuestionsAndExtractiveAnswers(
- text="Oranges are orange in color.", id="doc1"
- )
- document1.questions.append(Question(text="What color are oranges?"))
- document1.answers.append(ExtractiveAnswer(question=document1.questions[0], start=23, end=27))
-
- document2 = TextDocumentWithQuestionsAndExtractiveAnswers(
- text="This is a test document that has two questions attached to it.", id="doc2"
- )
- document2.questions.append(Question(text="What type of document is this?"))
- document2.questions.append(Question(text="How many questions are attached to this document?"))
- document2.answers.append(ExtractiveAnswer(question=document2.questions[0], start=11, end=14))
- document2.answers.append(ExtractiveAnswer(question=document2.questions[1], start=34, end=36))
-
- documents = [document0, document1, document2]
- return documents
-
-
-@pytest.mark.skipif(
- condition=not DUMP_FIXTURE_DATA,
- reason="Only need to dump the data if taskmodule has changed",
-)
-def test_dump_fixtures(documents):
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = ExtractiveQuestionAnsweringTaskModule(
- tokenizer_name_or_path=tokenizer_name_or_path,
- max_length=512,
- )
-
- task_encodings = taskmodule.encode(documents, encode_target=True)
- batch_encoding = taskmodule.collate(task_encodings)
-
- FIXTURES_TASKMODULE_DATA_PATH.mkdir(parents=True, exist_ok=True)
- filepath = FIXTURES_TASKMODULE_DATA_PATH / "batch_encoding_inputs.json"
-
- inputs = {key: tensor.tolist() for key, tensor in batch_encoding[0].items()}
- targets = {key: tensor.tolist() for key, tensor in batch_encoding[1].items()}
- converted_batch_encoding = {
- "inputs": inputs,
- "targets": targets,
- }
-
- with open(filepath, "w") as f:
- json.dump(converted_batch_encoding, f)
- return converted_batch_encoding
-
-
-@pytest.fixture
-def batch():
- filepath = FIXTURES_TASKMODULE_DATA_PATH / "batch_encoding_inputs.json"
- with open(filepath) as f:
- batch_encoding = json.load(f)
-
- inputs = {key: torch.LongTensor(tensor) for key, tensor in batch_encoding["inputs"].items()}
- targets = {key: torch.LongTensor(tensor) for key, tensor in batch_encoding["targets"].items()}
- return inputs, targets
-
-
-def get_model(
- monkeypatch,
- model_type,
- batch_size,
- seq_len,
- add_dummy_linear=False,
- **model_kwargs,
-):
- class MockConfig:
- def __init__(
- self,
- hidden_size: int = 10,
- model_type=model_type,
- ) -> None:
- self.hidden_size = hidden_size
- self.model_type = model_type
-
- class MockModel(torch.nn.Module):
- def __init__(self, batch_size, seq_len, hidden_size, add_dummy_linear) -> None:
- super().__init__()
- self.batch_size = batch_size
- self.seq_len = seq_len
- self.hidden_size = hidden_size
- if add_dummy_linear:
- self.dummy_linear = torch.nn.Linear(self.hidden_size, 99)
-
- def __call__(self, *args, **kwargs):
- torch.manual_seed(42)
- start_logits = torch.FloatTensor(torch.rand(self.batch_size, self.seq_len))
- end_logits = torch.FloatTensor(torch.rand(self.batch_size, self.seq_len))
- loss = torch.FloatTensor(torch.rand(1))
- return transformers.modeling_outputs.QuestionAnsweringModelOutput(
- start_logits=start_logits,
- end_logits=end_logits,
- loss=loss,
- )
-
- hidden_size = 10
-
- monkeypatch.setattr(
- transformers.AutoConfig,
- "from_pretrained",
- lambda model_name_or_path: MockConfig(hidden_size=hidden_size, model_type=model_type),
- )
- monkeypatch.setattr(
- transformers.AutoModelForQuestionAnswering,
- "from_pretrained",
- lambda model_name_or_path, config: MockModel(
- batch_size=batch_size,
- seq_len=seq_len,
- hidden_size=hidden_size,
- add_dummy_linear=add_dummy_linear,
- ),
- )
-
- # set seed to make the classifier deterministic
- torch.manual_seed(42)
- result = SimpleExtractiveQuestionAnsweringModel(
- model_name_or_path=model_type,
- max_input_length=seq_len,
- **model_kwargs,
- )
- assert not result.is_from_pretrained
-
- return result
-
-
-@pytest.fixture
-def model(monkeypatch, batch):
- inputs, targets = batch
- model = get_model(
- monkeypatch=monkeypatch,
- model_type="bert",
- batch_size=inputs["input_ids"].shape[0],
- seq_len=inputs["input_ids"].shape[1],
- add_dummy_linear=True,
- )
- return model
-
-
-def test_get_model(monkeypatch, model):
- assert model is not None
- assert isinstance(model, SimpleExtractiveQuestionAnsweringModel)
-
-
-def test_forward(batch, model):
- inputs, targets = batch
- batch_size, seq_len = inputs["input_ids"].shape
-
- # set seed to make sure the output is deterministic
- torch.manual_seed(42)
- output = model.forward(inputs)
- assert set(output) == {"start_logits", "end_logits", "loss"}
- start_logits = output["start_logits"]
- end_logits = output["end_logits"]
- loss = output["loss"]
- assert start_logits.shape == (batch_size, seq_len)
- assert end_logits.shape == (batch_size, seq_len)
- assert loss.shape == (1,)
- expected_loss = torch.FloatTensor([0.04587])
- torch.testing.assert_close(output["loss"], expected_loss)
-
-
-def test_step(batch, model):
- torch.manual_seed(42)
- loss = model.step("train", batch)
- assert loss is not None
- expected_loss = torch.FloatTensor([0.04587])
- torch.testing.assert_close(loss, expected_loss)
-
-
-def test_training_step(batch, model):
- loss = model.training_step(batch, batch_idx=0)
- assert loss is not None
- expected_loss = torch.FloatTensor([0.04587])
- torch.testing.assert_close(loss, expected_loss)
-
-
-def test_validation_step(batch, model):
- loss = model.validation_step(batch, batch_idx=0)
- assert loss is not None
- expected_loss = torch.FloatTensor([0.04587])
- torch.testing.assert_close(loss, expected_loss)
-
-
-def test_test_step(batch, model):
- loss = model.test_step(batch, batch_idx=0)
- assert loss is not None
- expected_loss = torch.FloatTensor([0.04587])
- torch.testing.assert_close(loss, expected_loss)
-
-
-def test_optim(model):
- optimizer = model.configure_optimizers()
- assert optimizer is not None
- assert isinstance(optimizer, torch.optim.Adam)
- assert optimizer.defaults["lr"] == 1e-05
-
-
-def test_optim_with_warmup_proportion(monkeypatch, batch):
- inputs, targets = batch
- model = get_model(
- monkeypatch=monkeypatch,
- model_type="bert",
- batch_size=inputs["input_ids"].shape[0],
- seq_len=inputs["input_ids"].shape[1],
- add_dummy_linear=True,
- warmup_proportion=0.1,
- )
- model.trainer = Trainer(max_epochs=10)
- optimizers_and_schedulars = model.configure_optimizers()
- assert optimizers_and_schedulars is not None
- assert isinstance(optimizers_and_schedulars, tuple) and len(optimizers_and_schedulars) == 2
-
- optimizers, schedulers = optimizers_and_schedulars
- assert isinstance(optimizers[0], torch.optim.Optimizer)
- assert set(schedulers[0]) == {"scheduler", "interval"}
- schedular = schedulers[0]["scheduler"]
- assert isinstance(schedular, torch.optim.lr_scheduler.LRScheduler)
diff --git a/tests/models/test_sequence_classification_with_pooler.py b/tests/models/test_sequence_classification_with_pooler.py
deleted file mode 100644
index 07f9e10cd..000000000
--- a/tests/models/test_sequence_classification_with_pooler.py
+++ /dev/null
@@ -1,578 +0,0 @@
-from typing import Dict
-
-import pytest
-import torch
-from pytorch_lightning import Trainer
-from torch import LongTensor, tensor
-from torch.optim.lr_scheduler import LambdaLR
-from transformers.modeling_outputs import SequenceClassifierOutput
-
-from pie_modules.models import SequenceClassificationModelWithPooler
-from pie_modules.models.sequence_classification_with_pooler import OutputType
-from tests.models import trunc_number
-
-NUM_CLASSES = 4
-POOLER = "start_tokens"
-
-
-@pytest.fixture
-def inputs() -> Dict[str, LongTensor]:
- result_dict = {
- "input_ids": torch.tensor(
- [
- [
- 101,
- 28998,
- 13832,
- 3121,
- 2340,
- 138,
- 28996,
- 1759,
- 1120,
- 28999,
- 139,
- 28997,
- 119,
- 102,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- ],
- [
- 101,
- 1752,
- 5650,
- 119,
- 28998,
- 13832,
- 3121,
- 2340,
- 144,
- 28996,
- 1759,
- 1120,
- 28999,
- 145,
- 28997,
- 119,
- 1262,
- 1771,
- 146,
- 119,
- 102,
- 0,
- ],
- [
- 101,
- 1752,
- 5650,
- 119,
- 28998,
- 13832,
- 3121,
- 2340,
- 144,
- 28996,
- 1759,
- 1120,
- 145,
- 119,
- 1262,
- 1771,
- 28999,
- 146,
- 28997,
- 119,
- 102,
- 0,
- ],
- [
- 101,
- 1752,
- 5650,
- 119,
- 13832,
- 3121,
- 2340,
- 144,
- 1759,
- 1120,
- 28999,
- 145,
- 28997,
- 119,
- 1262,
- 1771,
- 28998,
- 146,
- 28996,
- 119,
- 102,
- 0,
- ],
- [
- 101,
- 1752,
- 5650,
- 119,
- 28998,
- 13832,
- 3121,
- 2340,
- 150,
- 28996,
- 1759,
- 1120,
- 28999,
- 151,
- 28997,
- 119,
- 1262,
- 1122,
- 1771,
- 152,
- 119,
- 102,
- ],
- [
- 101,
- 1752,
- 5650,
- 119,
- 13832,
- 3121,
- 2340,
- 150,
- 1759,
- 1120,
- 151,
- 119,
- 1262,
- 28998,
- 1122,
- 28996,
- 1771,
- 28999,
- 152,
- 28997,
- 119,
- 102,
- ],
- [
- 101,
- 1752,
- 5650,
- 119,
- 13832,
- 3121,
- 2340,
- 150,
- 1759,
- 1120,
- 151,
- 119,
- 1262,
- 28999,
- 1122,
- 28997,
- 1771,
- 28998,
- 152,
- 28996,
- 119,
- 102,
- ],
- ]
- ).to(torch.long),
- "attention_mask": torch.tensor(
- [
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- ]
- ).to(torch.long),
- "pooler_start_indices": torch.tensor(
- [[2, 10], [5, 13], [5, 17], [17, 11], [5, 13], [14, 18], [18, 14]]
- ).to(torch.long),
- "pooler_end_indices": torch.tensor(
- [[6, 11], [9, 14], [9, 18], [18, 12], [9, 14], [15, 19], [19, 15]]
- ).to(torch.long),
- }
-
- return result_dict
-
-
-@pytest.fixture
-def targets() -> Dict[str, LongTensor]:
- return {"labels": torch.tensor([0, 1, 2, 3, 1, 2, 3]).to(torch.long)}
-
-
-@pytest.fixture
-def model() -> SequenceClassificationModelWithPooler:
- torch.manual_seed(42)
- result = SequenceClassificationModelWithPooler(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=NUM_CLASSES,
- pooler=POOLER,
- )
- return result
-
-
-def test_model(model):
- assert model is not None
- named_parameters = dict(model.named_parameters())
- parameter_means = {k: trunc_number(v.mean().item(), 7) for k, v in named_parameters.items()}
- parameter_means_expected = {
- "classifier.bias": -0.0253964,
- "classifier.weight": -0.000511,
- "model.embeddings.LayerNorm.bias": -0.0294608,
- "model.embeddings.LayerNorm.weight": 1.312345,
- "model.embeddings.position_embeddings.weight": 5.5e-05,
- "model.embeddings.token_type_embeddings.weight": -0.0015419,
- "model.embeddings.word_embeddings.weight": 0.0031152,
- "model.encoder.layer.0.attention.output.LayerNorm.bias": 0.0608714,
- "model.encoder.layer.0.attention.output.LayerNorm.weight": 1.199831,
- "model.encoder.layer.0.attention.output.dense.bias": 0.0007209,
- "model.encoder.layer.0.attention.output.dense.weight": 3.01e-05,
- "model.encoder.layer.0.attention.self.key.bias": 0.0020557,
- "model.encoder.layer.0.attention.self.key.weight": 0.0003863,
- "model.encoder.layer.0.attention.self.query.bias": 0.0185744,
- "model.encoder.layer.0.attention.self.query.weight": -0.0003949,
- "model.encoder.layer.0.attention.self.value.bias": 0.0065417,
- "model.encoder.layer.0.attention.self.value.weight": 4.22e-05,
- "model.encoder.layer.0.intermediate.dense.bias": -0.1219958,
- "model.encoder.layer.0.intermediate.dense.weight": -0.0011731,
- "model.encoder.layer.0.output.LayerNorm.bias": 0.005295,
- "model.encoder.layer.0.output.LayerNorm.weight": 1.2419648,
- "model.encoder.layer.0.output.dense.bias": -0.0013031,
- "model.encoder.layer.0.output.dense.weight": -0.0002212,
- "model.encoder.layer.1.attention.output.LayerNorm.bias": 0.0443237,
- "model.encoder.layer.1.attention.output.LayerNorm.weight": 1.0377343,
- "model.encoder.layer.1.attention.output.dense.bias": 0.0041446,
- "model.encoder.layer.1.attention.output.dense.weight": -2.43e-05,
- "model.encoder.layer.1.attention.self.key.bias": 0.0045062,
- "model.encoder.layer.1.attention.self.key.weight": 0.0001333,
- "model.encoder.layer.1.attention.self.query.bias": -0.0358397,
- "model.encoder.layer.1.attention.self.query.weight": -0.0007321,
- "model.encoder.layer.1.attention.self.value.bias": -0.0007094,
- "model.encoder.layer.1.attention.self.value.weight": 0.0001012,
- "model.encoder.layer.1.intermediate.dense.bias": -0.1247257,
- "model.encoder.layer.1.intermediate.dense.weight": -0.001344,
- "model.encoder.layer.1.output.LayerNorm.bias": -0.0474442,
- "model.encoder.layer.1.output.LayerNorm.weight": 1.017162,
- "model.encoder.layer.1.output.dense.bias": 0.000677,
- "model.encoder.layer.1.output.dense.weight": -5.32e-05,
- "model.pooler.dense.bias": -0.0052078,
- "model.pooler.dense.weight": 0.0001295,
- "pooler.pooler.missing_embeddings": 0.0630417,
- }
- assert parameter_means == parameter_means_expected
-
-
-def test_model_pickleable(model):
- import pickle
-
- pickle.dumps(model)
-
-
-@pytest.fixture
-def model_output(model, inputs) -> OutputType:
- # set seed to make sure the output is deterministic
- torch.manual_seed(42)
- return model(inputs)
-
-
-def test_forward_logits(model_output, inputs):
- batch_size, seq_len = inputs["input_ids"].shape
-
- assert isinstance(model_output, SequenceClassifierOutput)
-
- logits = model_output.logits
-
- assert logits.shape == (batch_size, NUM_CLASSES)
-
- torch.testing.assert_close(
- logits,
- torch.tensor(
- [
- [
- -0.29492446780204773,
- -0.804599940776825,
- -0.19659805297851562,
- -1.0868580341339111,
- ],
- [
- -0.3601434826850891,
- -0.7454482316970825,
- 0.4882846474647522,
- -1.0253472328186035,
- ],
- [
- -0.40172430872917175,
- -1.2119399309158325,
- 0.5856620669364929,
- -1.0999149084091187,
- ],
- [
- -0.09260234981775284,
- -1.0742112398147583,
- 0.3299948275089264,
- -0.5182554125785828,
- ],
- [
- -0.40149545669555664,
- -0.7731614708900452,
- 0.4616103768348694,
- -1.0583568811416626,
- ],
- [
- -0.14356234669685364,
- -1.2634986639022827,
- 0.31660008430480957,
- -0.7487461566925049,
- ],
- [
- -0.11717557162046432,
- -0.971996009349823,
- 0.3477852940559387,
- -0.5993944406509399,
- ],
- ]
- ),
- )
-
-
-def test_decode(model, model_output, inputs):
- decoded = model.decode(inputs=inputs, outputs=model_output)
- assert isinstance(decoded, dict)
- assert set(decoded) == {"labels", "probabilities"}
- labels = decoded["labels"]
- assert labels.shape == (inputs["input_ids"].shape[0],)
- torch.testing.assert_close(
- labels,
- torch.tensor([2, 2, 2, 2, 2, 2, 2]),
- )
- probabilities = decoded["probabilities"]
- assert probabilities.shape == (inputs["input_ids"].shape[0], NUM_CLASSES)
- torch.testing.assert_close(
- probabilities.round(decimals=4),
- torch.tensor(
- [
- [0.3168, 0.1903, 0.3495, 0.1435],
- [0.2207, 0.1502, 0.5156, 0.1135],
- [0.2161, 0.0961, 0.5802, 0.1075],
- [0.2814, 0.1054, 0.4294, 0.1838],
- [0.2184, 0.1506, 0.5177, 0.1132],
- [0.2893, 0.0944, 0.4583, 0.1580],
- [0.2751, 0.1170, 0.4380, 0.1699],
- ]
- ),
- )
-
-
-def test_decode_with_multi_label(model_output, inputs):
- torch.manual_seed(42)
- model = SequenceClassificationModelWithPooler(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=NUM_CLASSES,
- pooler=POOLER,
- multi_label=True,
- )
- decoded = model.decode(inputs=inputs, outputs=model_output)
- assert isinstance(decoded, dict)
- assert set(decoded) == {"labels", "probabilities"}
- labels = decoded["labels"]
- assert labels.shape == (inputs["input_ids"].shape[0], NUM_CLASSES)
- torch.testing.assert_close(
- labels,
- torch.tensor(
- [
- [0, 0, 0, 0],
- [0, 0, 1, 0],
- [0, 0, 1, 0],
- [0, 0, 1, 0],
- [0, 0, 1, 0],
- [0, 0, 1, 0],
- [0, 0, 1, 0],
- ]
- ),
- )
- probabilities = decoded["probabilities"]
- assert probabilities.shape == (inputs["input_ids"].shape[0], NUM_CLASSES)
- torch.testing.assert_close(
- probabilities.round(decimals=4),
- torch.tensor(
- [
- [0.4268, 0.3090, 0.4510, 0.2522],
- [0.4109, 0.3218, 0.6197, 0.2640],
- [0.4009, 0.2294, 0.6424, 0.2498],
- [0.4769, 0.2546, 0.5818, 0.3733],
- [0.4010, 0.3158, 0.6134, 0.2576],
- [0.4642, 0.2204, 0.5785, 0.3211],
- [0.4707, 0.2745, 0.5861, 0.3545],
- ]
- ),
- )
-
-
-@pytest.fixture
-def batch(inputs, targets):
- return inputs, targets
-
-
-def test_training_step(batch, model):
- # set the seed to make sure the loss is deterministic
- torch.manual_seed(42)
- loss = model.training_step(batch, batch_idx=0)
- assert loss is not None
- torch.testing.assert_close(loss, torch.tensor(1.3899686336517334))
-
-
-def test_validation_step(batch, model):
- # set the seed to make sure the loss is deterministic
- torch.manual_seed(42)
- loss = model.validation_step(batch, batch_idx=0)
- assert loss is not None
- torch.testing.assert_close(loss, torch.tensor(1.3899686336517334))
-
-
-def test_test_step(batch, model):
- # set the seed to make sure the loss is deterministic
- torch.manual_seed(42)
- loss = model.test_step(batch, batch_idx=0)
- assert loss is not None
- torch.testing.assert_close(loss, torch.tensor(1.3899686336517334))
-
-
-def test_base_model_named_parameters(model):
- base_model_named_parameters = dict(model.base_model_named_parameters())
- assert set(base_model_named_parameters) == {
- "model.pooler.dense.bias",
- "model.encoder.layer.0.intermediate.dense.weight",
- "model.encoder.layer.0.intermediate.dense.bias",
- "model.encoder.layer.1.attention.output.dense.weight",
- "model.encoder.layer.1.attention.output.LayerNorm.weight",
- "model.encoder.layer.1.attention.self.query.weight",
- "model.encoder.layer.1.output.dense.weight",
- "model.encoder.layer.0.output.dense.bias",
- "model.encoder.layer.1.intermediate.dense.bias",
- "model.encoder.layer.1.attention.self.value.bias",
- "model.encoder.layer.0.attention.output.dense.weight",
- "model.encoder.layer.0.attention.self.query.bias",
- "model.encoder.layer.0.attention.self.value.bias",
- "model.encoder.layer.1.output.dense.bias",
- "model.encoder.layer.1.attention.self.query.bias",
- "model.encoder.layer.1.attention.output.LayerNorm.bias",
- "model.encoder.layer.0.attention.self.query.weight",
- "model.encoder.layer.0.attention.output.LayerNorm.bias",
- "model.encoder.layer.0.attention.self.key.bias",
- "model.encoder.layer.1.intermediate.dense.weight",
- "model.encoder.layer.1.output.LayerNorm.bias",
- "model.encoder.layer.1.output.LayerNorm.weight",
- "model.encoder.layer.0.attention.self.key.weight",
- "model.encoder.layer.1.attention.output.dense.bias",
- "model.encoder.layer.0.attention.output.dense.bias",
- "model.embeddings.LayerNorm.bias",
- "model.encoder.layer.0.attention.self.value.weight",
- "model.encoder.layer.0.attention.output.LayerNorm.weight",
- "model.embeddings.token_type_embeddings.weight",
- "model.encoder.layer.0.output.LayerNorm.weight",
- "model.embeddings.position_embeddings.weight",
- "model.encoder.layer.1.attention.self.key.bias",
- "model.embeddings.LayerNorm.weight",
- "model.encoder.layer.0.output.LayerNorm.bias",
- "model.encoder.layer.1.attention.self.key.weight",
- "model.pooler.dense.weight",
- "model.encoder.layer.0.output.dense.weight",
- "model.embeddings.word_embeddings.weight",
- "model.encoder.layer.1.attention.self.value.weight",
- }
-
-
-def test_task_named_parameters(model):
- task_named_parameters = dict(model.task_named_parameters())
- assert set(task_named_parameters) == {
- "classifier.weight",
- "pooler.pooler.missing_embeddings",
- "classifier.bias",
- }
-
-
-def test_configure_optimizers_with_warmup():
- model = SequenceClassificationModelWithPooler(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=NUM_CLASSES,
- )
- model.trainer = Trainer(max_epochs=10)
- optimizers_and_schedulers = model.configure_optimizers()
- assert len(optimizers_and_schedulers) == 2
- optimizers, schedulers = optimizers_and_schedulers
- assert len(optimizers) == 1
- assert len(schedulers) == 1
- optimizer = optimizers[0]
- assert optimizer is not None
- assert isinstance(optimizer, torch.optim.AdamW)
- assert optimizer.defaults["lr"] == 1e-05
- assert optimizer.defaults["weight_decay"] == 0.01
- assert optimizer.defaults["eps"] == 1e-08
-
- scheduler = schedulers[0]
- assert isinstance(scheduler, dict)
- assert set(scheduler) == {"scheduler", "interval"}
- assert isinstance(scheduler["scheduler"], LambdaLR)
-
-
-def test_configure_optimizers_with_task_learning_rate(monkeypatch):
- model = SequenceClassificationModelWithPooler(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=NUM_CLASSES,
- learning_rate=1e-5,
- task_learning_rate=1e-3,
- # disable warmup to make sure the scheduler is not added which would set the learning rate
- # to 0
- warmup_proportion=0.0,
- )
- optimizer = model.configure_optimizers()
- assert optimizer is not None
- assert isinstance(optimizer, torch.optim.AdamW)
- assert len(optimizer.param_groups) == 2
- # base model parameters
- param_group = optimizer.param_groups[0]
- assert len(param_group["params"]) == 39
- assert param_group["lr"] == 1e-5
- # classifier head parameters
- param_group = optimizer.param_groups[1]
- assert len(param_group["params"]) == 2
- assert param_group["lr"] == 1e-3
- # ensure that all parameters are covered
- assert set(optimizer.param_groups[0]["params"] + optimizer.param_groups[1]["params"]) == set(
- model.parameters()
- )
-
-
-def test_freeze_base_model(monkeypatch, inputs, targets):
- model = SequenceClassificationModelWithPooler(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=NUM_CLASSES,
- freeze_base_model=True,
- # disable warmup to make sure the scheduler is not added which would set the learning rate
- # to 0
- warmup_proportion=0.0,
- )
- base_model_params = [param for name, param in model.base_model_named_parameters()]
- task_params = [param for name, param in model.task_named_parameters()]
- assert len(base_model_params) + len(task_params) == len(list(model.parameters()))
- for param in base_model_params:
- assert not param.requires_grad
- for param in task_params:
- assert param.requires_grad
diff --git a/tests/models/test_sequence_pair_similarity_model_with_pooler.py b/tests/models/test_sequence_pair_similarity_model_with_pooler.py
deleted file mode 100644
index 3e6871a11..000000000
--- a/tests/models/test_sequence_pair_similarity_model_with_pooler.py
+++ /dev/null
@@ -1,326 +0,0 @@
-from typing import Dict
-
-import pytest
-import torch
-from pytorch_lightning import Trainer
-from torch import LongTensor, tensor
-from torch.optim.lr_scheduler import LambdaLR
-from transformers.modeling_outputs import SequenceClassifierOutput
-
-from pie_modules.models import SequencePairSimilarityModelWithPooler
-from pie_modules.models.sequence_classification_with_pooler import OutputType
-from tests.models import trunc_number
-
-
-@pytest.fixture
-def inputs() -> Dict[str, LongTensor]:
- result_dict = {
- "encoding": {
- "input_ids": tensor(
- [
- [101, 1262, 1131, 1771, 140, 119, 102],
- [101, 1262, 1131, 1771, 140, 119, 102],
- [101, 1262, 1131, 1771, 140, 119, 102],
- [101, 1262, 1131, 1771, 140, 119, 102],
- ]
- ),
- "token_type_ids": tensor(
- [
- [0, 0, 0, 0, 0, 0, 0],
- [0, 0, 0, 0, 0, 0, 0],
- [0, 0, 0, 0, 0, 0, 0],
- [0, 0, 0, 0, 0, 0, 0],
- ]
- ),
- "attention_mask": tensor(
- [
- [1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1],
- ]
- ),
- },
- "encoding_pair": {
- "input_ids": tensor(
- [
- [101, 3162, 7871, 1117, 5855, 119, 102],
- [101, 3162, 7871, 1117, 5855, 119, 102],
- [101, 3162, 7871, 1117, 5855, 119, 102],
- [101, 3162, 7871, 1117, 5855, 119, 102],
- ]
- ),
- "token_type_ids": tensor(
- [
- [0, 0, 0, 0, 0, 0, 0],
- [0, 0, 0, 0, 0, 0, 0],
- [0, 0, 0, 0, 0, 0, 0],
- [0, 0, 0, 0, 0, 0, 0],
- ]
- ),
- "attention_mask": tensor(
- [
- [1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1],
- ]
- ),
- },
- "pooler_start_indices": tensor([[2], [2], [4], [4]]),
- "pooler_end_indices": tensor([[3], [3], [5], [5]]),
- "pooler_pair_start_indices": tensor([[1], [3], [1], [3]]),
- "pooler_pair_end_indices": tensor([[2], [5], [2], [5]]),
- }
-
- return result_dict
-
-
-@pytest.fixture
-def targets() -> Dict[str, LongTensor]:
- return {"scores": tensor([0.0, 0.0, 0.0, 0.0])}
-
-
-@pytest.fixture
-def model() -> SequencePairSimilarityModelWithPooler:
- torch.manual_seed(42)
- result = SequencePairSimilarityModelWithPooler(
- model_name_or_path="prajjwal1/bert-tiny",
- )
- return result
-
-
-def test_model(model):
- assert model is not None
- named_parameters = dict(model.named_parameters())
- parameter_means = {k: trunc_number(v.mean().item(), 7) for k, v in named_parameters.items()}
- parameter_means_expected = {
- "model.embeddings.word_embeddings.weight": 0.0031152,
- "model.embeddings.position_embeddings.weight": 5.5e-05,
- "model.embeddings.token_type_embeddings.weight": -0.0015419,
- "model.embeddings.LayerNorm.weight": 1.312345,
- "model.embeddings.LayerNorm.bias": -0.0294608,
- "model.encoder.layer.0.attention.self.query.weight": -0.0003949,
- "model.encoder.layer.0.attention.self.query.bias": 0.0185744,
- "model.encoder.layer.0.attention.self.key.weight": 0.0003863,
- "model.encoder.layer.0.attention.self.key.bias": 0.0020557,
- "model.encoder.layer.0.attention.self.value.weight": 4.22e-05,
- "model.encoder.layer.0.attention.self.value.bias": 0.0065417,
- "model.encoder.layer.0.attention.output.dense.weight": 3.01e-05,
- "model.encoder.layer.0.attention.output.dense.bias": 0.0007209,
- "model.encoder.layer.0.attention.output.LayerNorm.weight": 1.199831,
- "model.encoder.layer.0.attention.output.LayerNorm.bias": 0.0608714,
- "model.encoder.layer.0.intermediate.dense.weight": -0.0011731,
- "model.encoder.layer.0.intermediate.dense.bias": -0.1219958,
- "model.encoder.layer.0.output.dense.weight": -0.0002212,
- "model.encoder.layer.0.output.dense.bias": -0.0013031,
- "model.encoder.layer.0.output.LayerNorm.weight": 1.2419648,
- "model.encoder.layer.0.output.LayerNorm.bias": 0.005295,
- "model.encoder.layer.1.attention.self.query.weight": -0.0007321,
- "model.encoder.layer.1.attention.self.query.bias": -0.0358397,
- "model.encoder.layer.1.attention.self.key.weight": 0.0001333,
- "model.encoder.layer.1.attention.self.key.bias": 0.0045062,
- "model.encoder.layer.1.attention.self.value.weight": 0.0001012,
- "model.encoder.layer.1.attention.self.value.bias": -0.0007094,
- "model.encoder.layer.1.attention.output.dense.weight": -2.43e-05,
- "model.encoder.layer.1.attention.output.dense.bias": 0.0041446,
- "model.encoder.layer.1.attention.output.LayerNorm.weight": 1.0377343,
- "model.encoder.layer.1.attention.output.LayerNorm.bias": 0.0443237,
- "model.encoder.layer.1.intermediate.dense.weight": -0.001344,
- "model.encoder.layer.1.intermediate.dense.bias": -0.1247257,
- "model.encoder.layer.1.output.dense.weight": -5.32e-05,
- "model.encoder.layer.1.output.dense.bias": 0.000677,
- "model.encoder.layer.1.output.LayerNorm.weight": 1.017162,
- "model.encoder.layer.1.output.LayerNorm.bias": -0.0474442,
- "model.pooler.dense.weight": 0.0001295,
- "model.pooler.dense.bias": -0.0052078,
- "pooler.missing_embeddings": 0.0812017,
- }
- assert parameter_means == parameter_means_expected
-
-
-def test_model_pickleable(model):
- import pickle
-
- pickle.dumps(model)
-
-
-@pytest.fixture
-def model_output(model, inputs) -> OutputType:
- # set seed to make sure the output is deterministic
- torch.manual_seed(42)
- return model(inputs)
-
-
-def test_forward_logits(model_output, inputs):
- assert isinstance(model_output, SequenceClassifierOutput)
-
- logits = model_output.logits
-
- torch.testing.assert_close(
- logits,
- torch.tensor(
- [0.5338148474693298, 0.5866107940673828, 0.5076886415481567, 0.5946245789527893]
- ),
- )
-
-
-def test_decode(model, model_output, inputs):
- decoded = model.decode(inputs=inputs, outputs=model_output)
- assert isinstance(decoded, dict)
- assert set(decoded) == {"scores"}
- scores = decoded["scores"]
- torch.testing.assert_close(
- scores,
- torch.tensor(
- [0.5338148474693298, 0.5866107940673828, 0.5076886415481567, 0.5946245789527893]
- ),
- )
-
-
-@pytest.fixture
-def batch(inputs, targets):
- return inputs, targets
-
-
-def test_training_step(batch, model):
- # set the seed to make sure the loss is deterministic
- torch.manual_seed(42)
- loss = model.training_step(batch, batch_idx=0)
- assert loss is not None
- torch.testing.assert_close(loss, torch.tensor(0.8145309686660767))
-
-
-def test_validation_step(batch, model):
- # set the seed to make sure the loss is deterministic
- torch.manual_seed(42)
- loss = model.validation_step(batch, batch_idx=0)
- assert loss is not None
- torch.testing.assert_close(loss, torch.tensor(0.8145309686660767))
-
-
-def test_test_step(batch, model):
- # set the seed to make sure the loss is deterministic
- torch.manual_seed(42)
- loss = model.test_step(batch, batch_idx=0)
- assert loss is not None
- torch.testing.assert_close(loss, torch.tensor(0.8145309686660767))
-
-
-def test_base_model_named_parameters(model):
- base_model_named_parameters = dict(model.base_model_named_parameters())
- assert set(base_model_named_parameters) == {
- "model.pooler.dense.bias",
- "model.encoder.layer.0.intermediate.dense.weight",
- "model.encoder.layer.0.intermediate.dense.bias",
- "model.encoder.layer.1.attention.output.dense.weight",
- "model.encoder.layer.1.attention.output.LayerNorm.weight",
- "model.encoder.layer.1.attention.self.query.weight",
- "model.encoder.layer.1.output.dense.weight",
- "model.encoder.layer.0.output.dense.bias",
- "model.encoder.layer.1.intermediate.dense.bias",
- "model.encoder.layer.1.attention.self.value.bias",
- "model.encoder.layer.0.attention.output.dense.weight",
- "model.encoder.layer.0.attention.self.query.bias",
- "model.encoder.layer.0.attention.self.value.bias",
- "model.encoder.layer.1.output.dense.bias",
- "model.encoder.layer.1.attention.self.query.bias",
- "model.encoder.layer.1.attention.output.LayerNorm.bias",
- "model.encoder.layer.0.attention.self.query.weight",
- "model.encoder.layer.0.attention.output.LayerNorm.bias",
- "model.encoder.layer.0.attention.self.key.bias",
- "model.encoder.layer.1.intermediate.dense.weight",
- "model.encoder.layer.1.output.LayerNorm.bias",
- "model.encoder.layer.1.output.LayerNorm.weight",
- "model.encoder.layer.0.attention.self.key.weight",
- "model.encoder.layer.1.attention.output.dense.bias",
- "model.encoder.layer.0.attention.output.dense.bias",
- "model.embeddings.LayerNorm.bias",
- "model.encoder.layer.0.attention.self.value.weight",
- "model.encoder.layer.0.attention.output.LayerNorm.weight",
- "model.embeddings.token_type_embeddings.weight",
- "model.encoder.layer.0.output.LayerNorm.weight",
- "model.embeddings.position_embeddings.weight",
- "model.encoder.layer.1.attention.self.key.bias",
- "model.embeddings.LayerNorm.weight",
- "model.encoder.layer.0.output.LayerNorm.bias",
- "model.encoder.layer.1.attention.self.key.weight",
- "model.pooler.dense.weight",
- "model.encoder.layer.0.output.dense.weight",
- "model.embeddings.word_embeddings.weight",
- "model.encoder.layer.1.attention.self.value.weight",
- }
-
-
-def test_task_named_parameters(model):
- task_named_parameters = dict(model.task_named_parameters())
- assert set(task_named_parameters) == {
- "pooler.missing_embeddings",
- }
-
-
-def test_configure_optimizers_with_warmup():
- model = SequencePairSimilarityModelWithPooler(
- model_name_or_path="prajjwal1/bert-tiny",
- )
- model.trainer = Trainer(max_epochs=10)
- optimizers_and_schedulers = model.configure_optimizers()
- assert len(optimizers_and_schedulers) == 2
- optimizers, schedulers = optimizers_and_schedulers
- assert len(optimizers) == 1
- assert len(schedulers) == 1
- optimizer = optimizers[0]
- assert optimizer is not None
- assert isinstance(optimizer, torch.optim.AdamW)
- assert optimizer.defaults["lr"] == 1e-05
- assert optimizer.defaults["weight_decay"] == 0.01
- assert optimizer.defaults["eps"] == 1e-08
-
- scheduler = schedulers[0]
- assert isinstance(scheduler, dict)
- assert set(scheduler) == {"scheduler", "interval"}
- assert isinstance(scheduler["scheduler"], LambdaLR)
-
-
-def test_configure_optimizers_with_task_learning_rate(monkeypatch):
- model = SequencePairSimilarityModelWithPooler(
- model_name_or_path="prajjwal1/bert-tiny",
- learning_rate=1e-5,
- task_learning_rate=1e-3,
- # disable warmup to make sure the scheduler is not added which would set the learning rate
- # to 0
- warmup_proportion=0.0,
- )
- optimizer = model.configure_optimizers()
- assert optimizer is not None
- assert isinstance(optimizer, torch.optim.AdamW)
- assert len(optimizer.param_groups) == 2
- # base model parameters
- param_group = optimizer.param_groups[0]
- assert len(param_group["params"]) == 39
- assert param_group["lr"] == 1e-5
- # classifier head parameters - there is just the default embedding (which is not used)
- param_group = optimizer.param_groups[1]
- assert len(param_group["params"]) == 1
- assert param_group["lr"] == 1e-3
- # ensure that all parameters are covered
- assert set(optimizer.param_groups[0]["params"] + optimizer.param_groups[1]["params"]) == set(
- model.parameters()
- )
-
-
-def test_freeze_base_model(monkeypatch, inputs, targets):
- model = SequencePairSimilarityModelWithPooler(
- model_name_or_path="prajjwal1/bert-tiny",
- freeze_base_model=True,
- # disable warmup to make sure the scheduler is not added which would set the learning rate
- # to 0
- warmup_proportion=0.0,
- )
- base_model_params = [param for name, param in model.base_model_named_parameters()]
- task_params = [param for name, param in model.task_named_parameters()]
- assert len(base_model_params) + len(task_params) == len(list(model.parameters()))
- for param in base_model_params:
- assert not param.requires_grad
- for param in task_params:
- assert param.requires_grad
diff --git a/tests/models/test_simple_generative.py b/tests/models/test_simple_generative.py
deleted file mode 100644
index 147ce1f67..000000000
--- a/tests/models/test_simple_generative.py
+++ /dev/null
@@ -1,439 +0,0 @@
-import math
-from typing import List, Optional
-
-import pytest
-import torch
-from pytorch_lightning import Trainer
-from torch.optim import Optimizer
-
-from pie_modules.models import SimpleGenerativeModel
-from pie_modules.models.common import TESTING, VALIDATION
-from pie_modules.taskmodules import TextToTextTaskModule
-from tests.models import trunc_number
-
-MODEL_ID = "google/t5-efficient-tiny-nl2"
-
-
-@pytest.fixture(scope="module")
-def taskmodule():
- return TextToTextTaskModule(
- tokenizer_name_or_path=MODEL_ID,
- document_type="pie_modules.documents.TextDocumentWithAbstractiveSummary",
- target_layer="abstractive_summary",
- target_annotation_type="pie_modules.annotations.AbstractiveSummary",
- tokenized_document_type="pie_modules.documents.TokenDocumentWithAbstractiveSummary",
- text_metric_type="torchmetrics.text.ROUGEScore",
- )
-
-
-@pytest.fixture(scope="module")
-def model(taskmodule) -> SimpleGenerativeModel:
- return SimpleGenerativeModel(
- base_model={
- "_type_": "transformers.AutoModelForSeq2SeqLM",
- "pretrained_model_name_or_path": MODEL_ID,
- },
- # only use predictions for metrics in test stage to cover all cases (default is all stages)
- metric_call_predict=[TESTING],
- taskmodule_config=taskmodule.config,
- # use a strange learning rate to make sure it is passed through
- learning_rate=13e-3,
- optimizer_type="torch.optim.Adam",
- )
-
-
-def test_model(model):
- assert model is not None
- assert model.model is not None
- assert model.taskmodule is not None
- named_parameters = dict(model.named_parameters())
- parameter_means = {k: trunc_number(v.mean().item(), 7) for k, v in named_parameters.items()}
- parameter_means_expected = {
- "model.shared.weight": -0.3906954,
- "model.encoder.block.0.layer.0.SelfAttention.q.weight": 2.15e-05,
- "model.encoder.block.0.layer.0.SelfAttention.k.weight": -0.0015166,
- "model.encoder.block.0.layer.0.SelfAttention.v.weight": -0.0018635,
- "model.encoder.block.0.layer.0.SelfAttention.o.weight": 0.000866,
- "model.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight": -2.8229351,
- "model.encoder.block.0.layer.0.layer_norm.weight": 0.226491,
- "model.encoder.block.0.layer.1.DenseReluDense.wi.weight": 0.0034651,
- "model.encoder.block.0.layer.1.DenseReluDense.wo.weight": 0.00017,
- "model.encoder.block.0.layer.1.layer_norm.weight": 1.2047424,
- "model.encoder.block.1.layer.0.SelfAttention.q.weight": -7.88e-05,
- "model.encoder.block.1.layer.0.SelfAttention.k.weight": -0.0017292,
- "model.encoder.block.1.layer.0.SelfAttention.v.weight": -0.0025692,
- "model.encoder.block.1.layer.0.SelfAttention.o.weight": 0.000484,
- "model.encoder.block.1.layer.0.layer_norm.weight": 0.4024209,
- "model.encoder.block.1.layer.1.DenseReluDense.wi.weight": 0.0012148,
- "model.encoder.block.1.layer.1.DenseReluDense.wo.weight": -0.000555,
- "model.encoder.block.1.layer.1.layer_norm.weight": 1.9719848,
- "model.encoder.final_layer_norm.weight": 1.3045949,
- "model.decoder.block.0.layer.0.SelfAttention.q.weight": 4.21e-05,
- "model.decoder.block.0.layer.0.SelfAttention.k.weight": 0.0006944,
- "model.decoder.block.0.layer.0.SelfAttention.v.weight": -0.0001296,
- "model.decoder.block.0.layer.0.SelfAttention.o.weight": 0.0020978,
- "model.decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight": -0.5869011,
- "model.decoder.block.0.layer.0.layer_norm.weight": 0.1958751,
- "model.decoder.block.0.layer.1.EncDecAttention.q.weight": 7.8e-06,
- "model.decoder.block.0.layer.1.EncDecAttention.k.weight": -0.0001409,
- "model.decoder.block.0.layer.1.EncDecAttention.v.weight": -0.0010971,
- "model.decoder.block.0.layer.1.EncDecAttention.o.weight": 0.0026751,
- "model.decoder.block.0.layer.1.layer_norm.weight": 0.0658893,
- "model.decoder.block.0.layer.2.DenseReluDense.wi.weight": 0.0012591,
- "model.decoder.block.0.layer.2.DenseReluDense.wo.weight": 0.0033682,
- "model.decoder.block.0.layer.2.layer_norm.weight": 2.9871673,
- "model.decoder.block.1.layer.0.SelfAttention.q.weight": 6.16e-05,
- "model.decoder.block.1.layer.0.SelfAttention.k.weight": 0.0004128,
- "model.decoder.block.1.layer.0.SelfAttention.v.weight": -0.0003878,
- "model.decoder.block.1.layer.0.SelfAttention.o.weight": -0.0040457,
- "model.decoder.block.1.layer.0.layer_norm.weight": 1.1167399,
- "model.decoder.block.1.layer.1.EncDecAttention.q.weight": -0.0001246,
- "model.decoder.block.1.layer.1.EncDecAttention.k.weight": 0.0013352,
- "model.decoder.block.1.layer.1.EncDecAttention.v.weight": -0.0024415,
- "model.decoder.block.1.layer.1.EncDecAttention.o.weight": -9.83e-05,
- "model.decoder.block.1.layer.1.layer_norm.weight": 0.0755381,
- "model.decoder.block.1.layer.2.DenseReluDense.wi.weight": -0.0045786,
- "model.decoder.block.1.layer.2.DenseReluDense.wo.weight": 0.0101685,
- "model.decoder.block.1.layer.2.layer_norm.weight": 7.3835659,
- "model.decoder.final_layer_norm.weight": 0.8366433,
- }
- assert parameter_means == parameter_means_expected
-
-
-def test_model_pickleable(model):
- import pickle
-
- pickle.dumps(model)
-
-
-def test_model_without_taskmodule(caplog):
- with caplog.at_level("WARNING"):
- model = SimpleGenerativeModel(
- base_model={
- "_type_": "transformers.AutoModelForSeq2SeqLM",
- "pretrained_model_name_or_path": MODEL_ID,
- },
- )
- assert model is not None
- assert caplog.messages == [
- "No taskmodule is available, so no metrics are set up. Please provide a taskmodule_config "
- "to enable metrics for stages ['train', 'val', 'test'].",
- "No taskmodule is available, so no generation config will be created. Consider setting "
- "taskmodule_config to a valid taskmodule config to use specific setup for generation.",
- ]
-
-
-def test_missing_base_model_and_type():
- with pytest.raises(ValueError) as excinfo:
- SimpleGenerativeModel()
- assert (
- str(excinfo.value)
- == "Either base_model or base_model_type must be provided. If base_model is "
- "not provided, base_model_type must be a valid model type, "
- "e.g. 'transformers.AutoModelForSeq2SeqLM'."
- )
-
-
-def test_model_with_deprecated_base_model_setup(caplog, taskmodule):
- with caplog.at_level("WARNING"):
- model = SimpleGenerativeModel(
- base_model_type="transformers.AutoModelForSeq2SeqLM",
- base_model_config=dict(pretrained_model_name_or_path=MODEL_ID),
- taskmodule_config=taskmodule.config,
- )
- assert model is not None
- assert caplog.messages == [
- "The base_model_type and base_model_config arguments are deprecated. Please use base_model. "
- "You can use the following code to create the base_model argument: "
- "base_model = {'_type_': base_model_type, **base_model_config}",
- ]
-
-
-@pytest.fixture(scope="module")
-def batch(model):
- inputs = {
- "input_ids": torch.tensor(
- [
- [100, 19, 3, 9, 794, 1708, 1, 0, 0, 0, 0, 0],
- [100, 19, 430, 794, 1708, 84, 19, 3, 9, 720, 1200, 1],
- ]
- ),
- "attention_mask": torch.tensor(
- [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
- ),
- }
-
- targets = {
- "labels": torch.tensor([[3, 9, 1708, 1, 0], [3, 9, 1200, 1708, 1]]),
- "decoder_attention_mask": torch.tensor([[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]),
- }
-
- return inputs, targets
-
-
-def test_batch(batch, taskmodule):
- inputs, targets = batch
- input_ids_tokens = [
- taskmodule.tokenizer.convert_ids_to_tokens(input_ids) for input_ids in inputs["input_ids"]
- ]
- assert input_ids_tokens == [
- [
- "▁This",
- "▁is",
- "▁",
- "a",
- "▁test",
- "▁document",
- "",
- "",
- "",
- "",
- "",
- "",
- ],
- [
- "▁This",
- "▁is",
- "▁another",
- "▁test",
- "▁document",
- "▁which",
- "▁is",
- "▁",
- "a",
- "▁bit",
- "▁longer",
- "",
- ],
- ]
-
- labels_tokens = [
- taskmodule.tokenizer.convert_ids_to_tokens(labels) for labels in targets["labels"]
- ]
- assert labels_tokens == [
- ["▁", "a", "▁document", "", ""],
- ["▁", "a", "▁longer", "▁document", ""],
- ]
-
-
-def test_training_step(batch, model):
- model.train()
- torch.manual_seed(42)
- metric = model._get_metric(VALIDATION, batch_idx=0)
- metric.reset()
- loss = model.training_step(batch, batch_idx=0)
- assert loss is not None
- torch.testing.assert_close(loss, torch.tensor(8.98222827911377))
-
- metric_values = metric.compute()
- metric_values_float = {key: value.item() for key, value in metric_values.items()}
-
- # we do not collect metrics during training, so all entries should be NaN
- assert len(metric_values_float) > 0
- assert all([math.isnan(value) for value in metric_values_float.values()])
-
- model.on_train_epoch_end()
-
-
-def test_validation_step(batch, model):
- model.eval()
- torch.manual_seed(42)
- metric = model._get_metric(VALIDATION, batch_idx=0)
- metric.reset()
- loss = model.validation_step(batch, batch_idx=0)
- assert loss is not None
- torch.testing.assert_close(loss, torch.tensor(10.146586418151855))
-
- metric_values = metric.compute()
- metric_values_float = {key: value.item() for key, value in metric_values.items()}
- assert metric_values_float == {
- "rouge1_fmeasure": 0.0,
- "rouge1_precision": 0.0,
- "rouge1_recall": 0.0,
- "rouge2_fmeasure": 0.0,
- "rouge2_precision": 0.0,
- "rouge2_recall": 0.0,
- "rougeL_fmeasure": 0.0,
- "rougeL_precision": 0.0,
- "rougeL_recall": 0.0,
- "rougeLsum_fmeasure": 0.0,
- "rougeLsum_precision": 0.0,
- "rougeLsum_recall": 0.0,
- }
-
- model.on_validation_epoch_end()
-
-
-def test_test_step(batch, model):
- model.eval()
- torch.manual_seed(42)
- metric = model._get_metric(TESTING, batch_idx=0)
- metric.reset()
- loss = model.test_step(batch, batch_idx=0)
- assert loss is not None
- torch.testing.assert_close(loss, torch.tensor(10.146586418151855))
-
- metric_values = metric.compute()
- metric_values_float = {key: value.item() for key, value in metric_values.items()}
- assert metric_values_float == {
- "rouge1_fmeasure": 0.1111111119389534,
- "rouge1_precision": 0.06666667014360428,
- "rouge1_recall": 0.3333333432674408,
- "rouge2_fmeasure": 0.0,
- "rouge2_precision": 0.0,
- "rouge2_recall": 0.0,
- "rougeL_fmeasure": 0.1111111119389534,
- "rougeL_precision": 0.06666667014360428,
- "rougeL_recall": 0.3333333432674408,
- "rougeLsum_fmeasure": 0.0555555559694767,
- "rougeLsum_precision": 0.03333333507180214,
- "rougeLsum_recall": 0.1666666716337204,
- }
-
- model.on_test_epoch_end()
-
-
-def test_predict_step(batch, model):
- model.eval()
- torch.manual_seed(42)
- predictions = model.predict_step(batch, batch_idx=0)
- labels = predictions["labels"]
- assert labels is not None
- torch.testing.assert_close(
- labels,
- torch.tensor(
- [
- [32099, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [
- 32099,
- 19,
- 3,
- 9,
- 248,
- 194,
- 12,
- 129,
- 25,
- 708,
- 5,
- 37,
- 166,
- 794,
- 1708,
- 19,
- 3,
- 9,
- 794,
- ],
- ]
- ),
- )
-
- predicted_tokens = [
- model.taskmodule.tokenizer.convert_ids_to_tokens(label) for label in labels
- ]
- assert predicted_tokens == [
- [
- "",
- "",
- "",
- "",
- "",
- "",
- "",
- "",
- "",
- "",
- "",
- "",
- "",
- "",
- "",
- "",
- "",
- "",
- "",
- ],
- [
- "",
- "▁is",
- "▁",
- "a",
- "▁great",
- "▁way",
- "▁to",
- "▁get",
- "▁you",
- "▁started",
- ".",
- "▁The",
- "▁first",
- "▁test",
- "▁document",
- "▁is",
- "▁",
- "a",
- "▁test",
- ],
- ]
-
-
-@pytest.fixture(scope="module")
-def optimizer(model):
- return model.configure_optimizers()
-
-
-def test_optimizer(optimizer):
- assert optimizer is not None
- assert isinstance(optimizer, torch.optim.Adam)
- assert optimizer.defaults["lr"] == 13e-3
- assert len(optimizer.param_groups) == 1
- param_group = optimizer.param_groups[0]
- assert len(param_group["params"]) == 47
-
-
-def _assert_optimizer(
- actual: Optimizer,
- expected: Optimizer,
- allow_mismatching_param_group_keys: Optional[List[str]] = None,
-):
- allow_mismatching_param_group_key_set = set(allow_mismatching_param_group_keys or [])
- assert actual is not None
- assert isinstance(actual, type(expected))
- assert actual.defaults == expected.defaults
- assert len(actual.param_groups) == len(expected.param_groups)
- for actual_param_group, expected_param_group in zip(
- actual.param_groups, expected.param_groups
- ):
- actual_keys = set(actual_param_group) - allow_mismatching_param_group_key_set
- expected_keys = set(expected_param_group) - allow_mismatching_param_group_key_set
- assert actual_keys == expected_keys
- for key in actual_keys:
- # also include the key in the comparison to have it in the assertion error message
- assert (key, actual_param_group[key]) == (key, expected_param_group[key])
-
-
-def test_configure_optimizers_with_warmup(model, optimizer):
- warmup_proportion_backup_value = model.warmup_proportion
- scheduler_name_backup_value = model.scheduler_name
- model.warmup_proportion = 0.1
- model.scheduler_name = "linear"
- model.trainer = Trainer(max_epochs=10)
- optimizer_and_schedular = model.configure_optimizers()
- assert optimizer_and_schedular is not None
- assert isinstance(optimizer_and_schedular, tuple)
- assert len(optimizer_and_schedular) == 2
- optimizers, schedulers = optimizer_and_schedular
- assert len(optimizers) == 1
- _assert_optimizer(
- optimizers[0], optimizer, allow_mismatching_param_group_keys=["initial_lr", "lr"]
- )
- assert len(schedulers) == 1
- assert set(schedulers[0]) == {"scheduler", "interval"}
- scheduler = schedulers[0]["scheduler"]
- assert isinstance(scheduler, torch.optim.lr_scheduler.LambdaLR)
- assert scheduler.optimizer is optimizers[0]
- assert scheduler.base_lrs == [13e-3]
-
- model.warmup_proportion = warmup_proportion_backup_value
- model.scheduler_name = scheduler_name_backup_value
diff --git a/tests/models/test_simple_sequence_classification.py b/tests/models/test_simple_sequence_classification.py
deleted file mode 100644
index 1fe740cd3..000000000
--- a/tests/models/test_simple_sequence_classification.py
+++ /dev/null
@@ -1,550 +0,0 @@
-from typing import Dict
-
-import pytest
-import torch
-from pytorch_lightning import Trainer
-from torch.optim.lr_scheduler import LambdaLR
-from transformers.modeling_outputs import SequenceClassifierOutput
-
-from pie_modules.models import SimpleSequenceClassificationModel
-from pie_modules.models.simple_sequence_classification import OutputType
-from tests.models import trunc_number
-
-NUM_CLASSES = 4
-
-
-@pytest.fixture
-def inputs() -> Dict[str, torch.LongTensor]:
- result_dict = {
- "input_ids": torch.tensor(
- [
- [
- 101,
- 28998,
- 13832,
- 3121,
- 2340,
- 138,
- 28996,
- 1759,
- 1120,
- 28999,
- 139,
- 28997,
- 119,
- 102,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- ],
- [
- 101,
- 1752,
- 5650,
- 119,
- 28998,
- 13832,
- 3121,
- 2340,
- 144,
- 28996,
- 1759,
- 1120,
- 28999,
- 145,
- 28997,
- 119,
- 1262,
- 1771,
- 146,
- 119,
- 102,
- 0,
- ],
- [
- 101,
- 1752,
- 5650,
- 119,
- 28998,
- 13832,
- 3121,
- 2340,
- 144,
- 28996,
- 1759,
- 1120,
- 145,
- 119,
- 1262,
- 1771,
- 28999,
- 146,
- 28997,
- 119,
- 102,
- 0,
- ],
- [
- 101,
- 1752,
- 5650,
- 119,
- 13832,
- 3121,
- 2340,
- 144,
- 1759,
- 1120,
- 28999,
- 145,
- 28997,
- 119,
- 1262,
- 1771,
- 28998,
- 146,
- 28996,
- 119,
- 102,
- 0,
- ],
- [
- 101,
- 1752,
- 5650,
- 119,
- 28998,
- 13832,
- 3121,
- 2340,
- 150,
- 28996,
- 1759,
- 1120,
- 28999,
- 151,
- 28997,
- 119,
- 1262,
- 1122,
- 1771,
- 152,
- 119,
- 102,
- ],
- [
- 101,
- 1752,
- 5650,
- 119,
- 13832,
- 3121,
- 2340,
- 150,
- 1759,
- 1120,
- 151,
- 119,
- 1262,
- 28998,
- 1122,
- 28996,
- 1771,
- 28999,
- 152,
- 28997,
- 119,
- 102,
- ],
- [
- 101,
- 1752,
- 5650,
- 119,
- 13832,
- 3121,
- 2340,
- 150,
- 1759,
- 1120,
- 151,
- 119,
- 1262,
- 28999,
- 1122,
- 28997,
- 1771,
- 28998,
- 152,
- 28996,
- 119,
- 102,
- ],
- ]
- ).to(torch.long),
- "attention_mask": torch.tensor(
- [
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- ]
- ).to(torch.long),
- }
-
- return result_dict
-
-
-@pytest.fixture
-def targets() -> Dict[str, torch.LongTensor]:
- return {"labels": torch.tensor([0, 1, 2, 3, 1, 2, 3]).to(torch.long)}
-
-
-@pytest.fixture
-def model() -> SimpleSequenceClassificationModel:
- torch.manual_seed(42)
- result = SimpleSequenceClassificationModel(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=NUM_CLASSES,
- )
- return result
-
-
-def test_model(model):
- assert model is not None
- named_parameters = dict(model.named_parameters())
- parameter_means = {k: trunc_number(v.mean().item(), 7) for k, v in named_parameters.items()}
- parameter_means_expected = {
- "model.bert.embeddings.word_embeddings.weight": 0.0031152,
- "model.bert.embeddings.position_embeddings.weight": 5.5e-05,
- "model.bert.embeddings.token_type_embeddings.weight": -0.0015419,
- "model.bert.embeddings.LayerNorm.weight": 1.312345,
- "model.bert.embeddings.LayerNorm.bias": -0.0294608,
- "model.bert.encoder.layer.0.attention.self.query.weight": -0.0003949,
- "model.bert.encoder.layer.0.attention.self.query.bias": 0.0185744,
- "model.bert.encoder.layer.0.attention.self.key.weight": 0.0003863,
- "model.bert.encoder.layer.0.attention.self.key.bias": 0.0020557,
- "model.bert.encoder.layer.0.attention.self.value.weight": 4.22e-05,
- "model.bert.encoder.layer.0.attention.self.value.bias": 0.0065417,
- "model.bert.encoder.layer.0.attention.output.dense.weight": 3.01e-05,
- "model.bert.encoder.layer.0.attention.output.dense.bias": 0.0007209,
- "model.bert.encoder.layer.0.attention.output.LayerNorm.weight": 1.199831,
- "model.bert.encoder.layer.0.attention.output.LayerNorm.bias": 0.0608714,
- "model.bert.encoder.layer.0.intermediate.dense.weight": -0.0011731,
- "model.bert.encoder.layer.0.intermediate.dense.bias": -0.1219958,
- "model.bert.encoder.layer.0.output.dense.weight": -0.0002212,
- "model.bert.encoder.layer.0.output.dense.bias": -0.0013031,
- "model.bert.encoder.layer.0.output.LayerNorm.weight": 1.2419648,
- "model.bert.encoder.layer.0.output.LayerNorm.bias": 0.005295,
- "model.bert.encoder.layer.1.attention.self.query.weight": -0.0007321,
- "model.bert.encoder.layer.1.attention.self.query.bias": -0.0358397,
- "model.bert.encoder.layer.1.attention.self.key.weight": 0.0001333,
- "model.bert.encoder.layer.1.attention.self.key.bias": 0.0045062,
- "model.bert.encoder.layer.1.attention.self.value.weight": 0.0001012,
- "model.bert.encoder.layer.1.attention.self.value.bias": -0.0007094,
- "model.bert.encoder.layer.1.attention.output.dense.weight": -2.43e-05,
- "model.bert.encoder.layer.1.attention.output.dense.bias": 0.0041446,
- "model.bert.encoder.layer.1.attention.output.LayerNorm.weight": 1.0377343,
- "model.bert.encoder.layer.1.attention.output.LayerNorm.bias": 0.0443237,
- "model.bert.encoder.layer.1.intermediate.dense.weight": -0.001344,
- "model.bert.encoder.layer.1.intermediate.dense.bias": -0.1247257,
- "model.bert.encoder.layer.1.output.dense.weight": -5.32e-05,
- "model.bert.encoder.layer.1.output.dense.bias": 0.000677,
- "model.bert.encoder.layer.1.output.LayerNorm.weight": 1.017162,
- "model.bert.encoder.layer.1.output.LayerNorm.bias": -0.0474442,
- "model.bert.pooler.dense.weight": 0.0001295,
- "model.bert.pooler.dense.bias": -0.0052078,
- "model.classifier.weight": 0.0005538,
- "model.classifier.bias": 0.0,
- }
- assert parameter_means == parameter_means_expected
-
-
-def test_model_pickleable(model):
- import pickle
-
- pickle.dumps(model)
-
-
-@pytest.fixture
-def model_output(model, inputs) -> OutputType:
- # set seed to make sure the output is deterministic
- torch.manual_seed(42)
- return model(inputs)
-
-
-def test_forward(model_output, inputs):
- batch_size = inputs["input_ids"].shape[0]
- assert isinstance(model_output, SequenceClassifierOutput)
- assert set(model_output) == {"logits"}
- logits = model_output["logits"]
-
- assert logits.shape == (batch_size, NUM_CLASSES)
-
- torch.testing.assert_close(
- logits,
- torch.tensor(
- [
- [
- 0.16545572876930237,
- 0.17544983327388763,
- -0.011048287153244019,
- 0.05337674915790558,
- ],
- [
- 0.14748695492744446,
- 0.16249355673789978,
- -0.058017998933792114,
- 0.025398850440979004,
- ],
- [
- 0.14271709322929382,
- 0.16188383102416992,
- -0.061113521456718445,
- 0.026494741439819336,
- ],
- [
- 0.15641027688980103,
- 0.17225395143032074,
- -0.05567866563796997,
- 0.022433891892433167,
- ],
- [
- 0.15785054862499237,
- 0.16935551166534424,
- -0.054724469780921936,
- 0.012338697910308838,
- ],
- [
- 0.16152460873126984,
- 0.17789196968078613,
- -0.053754448890686035,
- 0.008724510669708252,
- ],
- [
- 0.16836002469062805,
- 0.17842254042625427,
- -0.052499815821647644,
- 0.006823211908340454,
- ],
- ]
- ),
- )
-
-
-def test_decode(model, model_output, inputs):
- decoded = model.decode(inputs=inputs, outputs=model_output)
- assert isinstance(decoded, dict)
- assert set(decoded) == {"labels", "probabilities"}
- labels = decoded["labels"]
- assert labels.shape == (inputs["input_ids"].shape[0],)
- torch.testing.assert_close(
- labels,
- torch.tensor([1, 1, 1, 1, 1, 1, 1]),
- )
- probabilities = decoded["probabilities"]
- assert probabilities.shape == (inputs["input_ids"].shape[0], NUM_CLASSES)
- torch.testing.assert_close(
- probabilities,
- torch.tensor(
- [
- [
- 0.2672215402126312,
- 0.26990556716918945,
- 0.22398385405540466,
- 0.23888900876045227,
- ],
- [
- 0.26922059059143066,
- 0.27329114079475403,
- 0.21920911967754364,
- 0.23827917873859406,
- ],
- [0.2684398889541626, 0.2736346125602722, 0.21893969178199768, 0.23898591101169586],
- [0.2703087329864502, 0.2746255099773407, 0.21865077316761017, 0.23641489446163177],
- [0.2713961601257324, 0.2745365798473358, 0.21942369639873505, 0.2346435934305191],
- [
- 0.27165648341178894,
- 0.27613937854766846,
- 0.21904107928276062,
- 0.23316311836242676,
- ],
- [0.2730168402194977, 0.2757779359817505, 0.21891282498836517, 0.23229233920574188],
- ]
- ),
- )
-
-
-@pytest.fixture
-def batch(inputs, targets):
- return inputs, targets
-
-
-def test_training_step(batch, model):
- # set the seed to make sure the loss is deterministic
- torch.manual_seed(42)
- loss = model.training_step(batch, batch_idx=0)
- assert loss is not None
- torch.testing.assert_close(loss, torch.tensor(1.4069921970367432))
-
-
-def test_validation_step(batch, model):
- # set the seed to make sure the loss is deterministic
- torch.manual_seed(42)
- loss = model.validation_step(batch, batch_idx=0)
- assert loss is not None
- torch.testing.assert_close(loss, torch.tensor(1.4069921970367432))
-
-
-def test_test_step(batch, model):
- # set the seed to make sure the loss is deterministic
- torch.manual_seed(42)
- loss = model.test_step(batch, batch_idx=0)
- assert loss is not None
- torch.testing.assert_close(loss, torch.tensor(1.4069921970367432))
-
-
-def test_base_model_named_parameters(model):
- base_model_named_parameters = dict(model.base_model_named_parameters())
- assert set(base_model_named_parameters) == {
- "model.bert.pooler.dense.bias",
- "model.bert.encoder.layer.0.intermediate.dense.weight",
- "model.bert.encoder.layer.0.intermediate.dense.bias",
- "model.bert.encoder.layer.1.attention.output.dense.weight",
- "model.bert.encoder.layer.1.attention.output.LayerNorm.weight",
- "model.bert.encoder.layer.1.attention.self.query.weight",
- "model.bert.encoder.layer.1.output.dense.weight",
- "model.bert.encoder.layer.0.output.dense.bias",
- "model.bert.encoder.layer.1.intermediate.dense.bias",
- "model.bert.encoder.layer.1.attention.self.value.bias",
- "model.bert.encoder.layer.0.attention.output.dense.weight",
- "model.bert.encoder.layer.0.attention.self.query.bias",
- "model.bert.encoder.layer.0.attention.self.value.bias",
- "model.bert.encoder.layer.1.output.dense.bias",
- "model.bert.encoder.layer.1.attention.self.query.bias",
- "model.bert.encoder.layer.1.attention.output.LayerNorm.bias",
- "model.bert.encoder.layer.0.attention.self.query.weight",
- "model.bert.encoder.layer.0.attention.output.LayerNorm.bias",
- "model.bert.encoder.layer.0.attention.self.key.bias",
- "model.bert.encoder.layer.1.intermediate.dense.weight",
- "model.bert.encoder.layer.1.output.LayerNorm.bias",
- "model.bert.encoder.layer.1.output.LayerNorm.weight",
- "model.bert.encoder.layer.0.attention.self.key.weight",
- "model.bert.encoder.layer.1.attention.output.dense.bias",
- "model.bert.encoder.layer.0.attention.output.dense.bias",
- "model.bert.embeddings.LayerNorm.bias",
- "model.bert.encoder.layer.0.attention.self.value.weight",
- "model.bert.encoder.layer.0.attention.output.LayerNorm.weight",
- "model.bert.embeddings.token_type_embeddings.weight",
- "model.bert.encoder.layer.0.output.LayerNorm.weight",
- "model.bert.embeddings.position_embeddings.weight",
- "model.bert.encoder.layer.1.attention.self.key.bias",
- "model.bert.embeddings.LayerNorm.weight",
- "model.bert.encoder.layer.0.output.LayerNorm.bias",
- "model.bert.encoder.layer.1.attention.self.key.weight",
- "model.bert.pooler.dense.weight",
- "model.bert.encoder.layer.0.output.dense.weight",
- "model.bert.embeddings.word_embeddings.weight",
- "model.bert.encoder.layer.1.attention.self.value.weight",
- }
-
-
-def test_task_named_parameters(model):
- task_named_parameters = dict(model.task_named_parameters())
- assert set(task_named_parameters) == {
- "model.classifier.weight",
- "model.classifier.bias",
- }
-
-
-def test_configure_optimizers_with_warmup():
- model = SimpleSequenceClassificationModel(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=NUM_CLASSES,
- )
- model.trainer = Trainer(max_epochs=10)
- optimizers_and_schedulers = model.configure_optimizers()
- assert len(optimizers_and_schedulers) == 2
- optimizers, schedulers = optimizers_and_schedulers
- assert len(optimizers) == 1
- assert len(schedulers) == 1
- optimizer = optimizers[0]
- assert optimizer is not None
- assert isinstance(optimizer, torch.optim.AdamW)
- assert optimizer.defaults["lr"] == 1e-05
- assert optimizer.defaults["weight_decay"] == 0.01
- assert optimizer.defaults["eps"] == 1e-08
-
- scheduler = schedulers[0]
- assert isinstance(scheduler, dict)
- assert set(scheduler) == {"scheduler", "interval"}
- assert isinstance(scheduler["scheduler"], LambdaLR)
-
-
-def test_configure_optimizers_with_task_learning_rate(monkeypatch):
- model = SimpleSequenceClassificationModel(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=NUM_CLASSES,
- learning_rate=1e-5,
- task_learning_rate=1e-3,
- # disable warmup to make sure the scheduler is not added which would set the learning rate
- # to 0
- warmup_proportion=0.0,
- )
- optimizer = model.configure_optimizers()
- assert optimizer is not None
- assert isinstance(optimizer, torch.optim.AdamW)
- assert len(optimizer.param_groups) == 2
- # base model parameters
- param_group = optimizer.param_groups[0]
- assert len(param_group["params"]) == 39
- assert param_group["lr"] == 1e-5
- # classifier head parameters
- param_group = optimizer.param_groups[1]
- assert len(param_group["params"]) == 2
- assert param_group["lr"] == 1e-3
- # ensure that all parameters are covered
- assert set(optimizer.param_groups[0]["params"] + optimizer.param_groups[1]["params"]) == set(
- model.parameters()
- )
-
-
-def test_freeze_base_model(monkeypatch, inputs, targets):
- model = SimpleSequenceClassificationModel(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=NUM_CLASSES,
- freeze_base_model=True,
- # disable warmup to make sure the scheduler is not added which would set the learning rate
- # to 0
- warmup_proportion=0.0,
- )
- base_model_params = [param for name, param in model.base_model_named_parameters()]
- task_params = [param for name, param in model.task_named_parameters()]
- assert len(base_model_params) + len(task_params) == len(list(model.parameters()))
- for param in base_model_params:
- assert not param.requires_grad
- for param in task_params:
- assert param.requires_grad
-
-
-def test_base_model_named_parameters_wrong_prefix(monkeypatch):
- model = SimpleSequenceClassificationModel(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=NUM_CLASSES,
- base_model_prefix="wrong_prefix",
- )
- with pytest.raises(ValueError) as excinfo:
- model.base_model_named_parameters()
- assert (
- str(excinfo.value)
- == "Base model with prefix 'wrong_prefix' not found in BertForSequenceClassification"
- )
diff --git a/tests/models/test_simple_token_classification.py b/tests/models/test_simple_token_classification.py
deleted file mode 100644
index f6b190f6a..000000000
--- a/tests/models/test_simple_token_classification.py
+++ /dev/null
@@ -1,453 +0,0 @@
-import pytest
-import torch
-
-from pie_modules.models import SimpleTokenClassificationModel
-from pie_modules.models.common import TESTING, TRAINING, VALIDATION
-from pie_modules.taskmodules import LabeledSpanExtractionByTokenClassificationTaskModule
-from tests import _config_to_str
-from tests.models import trunc_number
-
-CONFIGS = [{}]
-CONFIG_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS}
-
-
-@pytest.fixture(scope="module", params=CONFIG_DICT.keys())
-def config_str(request):
- return request.param
-
-
-@pytest.fixture(scope="module")
-def config(config_str):
- return CONFIG_DICT[config_str]
-
-
-@pytest.fixture
-def taskmodule_config():
- return {
- "taskmodule_type": "LabeledSpanExtractionByTokenClassificationTaskModule",
- "tokenizer_name_or_path": "bert-base-cased",
- "span_annotation": "entities",
- "partition_annotation": None,
- "label_pad_id": -100,
- "labels": ["ORG", "PER"],
- "include_ill_formed_predictions": True,
- "tokenize_kwargs": None,
- "pad_kwargs": None,
- "combine_token_scores_method": "mean",
- "log_precision_recall_metrics": True,
- }
-
-
-def test_taskmodule_config(documents, taskmodule_config):
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule(
- span_annotation="entities",
- tokenizer_name_or_path=tokenizer_name_or_path,
- )
- taskmodule.prepare(documents)
- assert taskmodule.config == taskmodule_config
-
-
-def test_batch(documents, batch, taskmodule_config):
- taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule.from_config(
- taskmodule_config
- )
- encodings = taskmodule.encode(documents, encode_target=True)
- # just take the first 4 encodings
- batch_from_documents = taskmodule.collate(encodings[:4])
-
- inputs, targets = batch
- inputs_from_documents, targets_from_documents = batch_from_documents
- assert set(inputs) == set(inputs_from_documents)
- torch.testing.assert_close(inputs["input_ids"], inputs_from_documents["input_ids"])
- torch.testing.assert_close(inputs["attention_mask"], inputs_from_documents["attention_mask"])
- torch.testing.assert_close(targets, targets_from_documents)
-
-
-@pytest.fixture
-def batch():
- inputs = {
- "input_ids": torch.tensor(
- [
- [101, 138, 1423, 5650, 119, 102, 0, 0, 0, 0, 0, 0],
- [101, 13832, 3121, 2340, 138, 1759, 1120, 139, 119, 102, 0, 0],
- [101, 13832, 3121, 2340, 140, 1105, 141, 119, 102, 0, 0, 0],
- [101, 1752, 5650, 119, 13832, 3121, 2340, 142, 1105, 143, 119, 102],
- ]
- ).to(torch.long),
- "attention_mask": torch.tensor(
- [
- [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- ]
- ),
- "special_tokens_mask": torch.tensor(
- [
- [1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
- [1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
- [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
- [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
- ]
- ),
- }
- targets = {
- "labels": torch.tensor(
- [
- [-100, 0, 0, 0, 0, -100, -100, -100, -100, -100, -100, -100],
- [-100, 3, 4, 4, 4, 0, 0, 1, 0, -100, -100, -100],
- [-100, 3, 4, 4, 4, 0, 1, 0, -100, -100, -100, -100],
- [-100, 0, 0, 0, 3, 4, 4, 4, 0, 1, 0, -100],
- ]
- )
- }
- return inputs, targets
-
-
-@pytest.fixture
-def model(monkeypatch, batch, config, taskmodule_config) -> SimpleTokenClassificationModel:
- torch.manual_seed(42)
- model = SimpleTokenClassificationModel(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=5,
- taskmodule_config=taskmodule_config,
- metric_stages=["val", "test"],
- )
- return model
-
-
-def test_model(model):
- assert model is not None
- named_parameters = dict(model.named_parameters())
- parameter_means = {k: trunc_number(v.mean().item(), 7) for k, v in named_parameters.items()}
- parameter_means_expected = {
- "model.bert.embeddings.word_embeddings.weight": 0.0031152,
- "model.bert.embeddings.position_embeddings.weight": 5.5e-05,
- "model.bert.embeddings.token_type_embeddings.weight": -0.0015419,
- "model.bert.embeddings.LayerNorm.weight": 1.312345,
- "model.bert.embeddings.LayerNorm.bias": -0.0294608,
- "model.bert.encoder.layer.0.attention.self.query.weight": -0.0003949,
- "model.bert.encoder.layer.0.attention.self.query.bias": 0.0185744,
- "model.bert.encoder.layer.0.attention.self.key.weight": 0.0003863,
- "model.bert.encoder.layer.0.attention.self.key.bias": 0.0020557,
- "model.bert.encoder.layer.0.attention.self.value.weight": 4.22e-05,
- "model.bert.encoder.layer.0.attention.self.value.bias": 0.0065417,
- "model.bert.encoder.layer.0.attention.output.dense.weight": 3.01e-05,
- "model.bert.encoder.layer.0.attention.output.dense.bias": 0.0007209,
- "model.bert.encoder.layer.0.attention.output.LayerNorm.weight": 1.199831,
- "model.bert.encoder.layer.0.attention.output.LayerNorm.bias": 0.0608714,
- "model.bert.encoder.layer.0.intermediate.dense.weight": -0.0011731,
- "model.bert.encoder.layer.0.intermediate.dense.bias": -0.1219958,
- "model.bert.encoder.layer.0.output.dense.weight": -0.0002212,
- "model.bert.encoder.layer.0.output.dense.bias": -0.0013031,
- "model.bert.encoder.layer.0.output.LayerNorm.weight": 1.2419648,
- "model.bert.encoder.layer.0.output.LayerNorm.bias": 0.005295,
- "model.bert.encoder.layer.1.attention.self.query.weight": -0.0007321,
- "model.bert.encoder.layer.1.attention.self.query.bias": -0.0358397,
- "model.bert.encoder.layer.1.attention.self.key.weight": 0.0001333,
- "model.bert.encoder.layer.1.attention.self.key.bias": 0.0045062,
- "model.bert.encoder.layer.1.attention.self.value.weight": 0.0001012,
- "model.bert.encoder.layer.1.attention.self.value.bias": -0.0007094,
- "model.bert.encoder.layer.1.attention.output.dense.weight": -2.43e-05,
- "model.bert.encoder.layer.1.attention.output.dense.bias": 0.0041446,
- "model.bert.encoder.layer.1.attention.output.LayerNorm.weight": 1.0377343,
- "model.bert.encoder.layer.1.attention.output.LayerNorm.bias": 0.0443237,
- "model.bert.encoder.layer.1.intermediate.dense.weight": -0.001344,
- "model.bert.encoder.layer.1.intermediate.dense.bias": -0.1247257,
- "model.bert.encoder.layer.1.output.dense.weight": -5.32e-05,
- "model.bert.encoder.layer.1.output.dense.bias": 0.000677,
- "model.bert.encoder.layer.1.output.LayerNorm.weight": 1.017162,
- "model.bert.encoder.layer.1.output.LayerNorm.bias": -0.0474442,
- "model.classifier.weight": 0.0005138,
- "model.classifier.bias": 0.0,
- }
- assert parameter_means == parameter_means_expected
-
-
-def test_model_pickleable(model):
- import pickle
-
- pickle.dumps(model)
-
-
-def test_forward(batch, model):
- inputs, targets = batch
- batch_size, seq_len = inputs["input_ids"].shape
- num_classes = model.config["num_classes"]
-
- # set seed to make sure the output is deterministic
- torch.manual_seed(42)
- output = model.forward(inputs)
- assert set(output) == {"logits"}
- logits = output["logits"]
- assert logits.shape == (batch_size, seq_len, num_classes)
- # check the first batch entry
- torch.testing.assert_close(
- logits[0],
- torch.tensor(
- [
- [
- -0.13442197442054749,
- -0.06983129680156708,
- 0.17513807117938995,
- -0.24002864956855774,
- 0.08871676027774811,
- ],
- [
- -0.032687313854694366,
- -0.2071131318807602,
- 0.10695032775402069,
- -0.05829116329550743,
- -0.21174949407577515,
- ],
- [
- -0.17153336107730865,
- -0.2230629324913025,
- -0.11457862704992294,
- 0.03658870607614517,
- -0.242639422416687,
- ],
- [
- -0.07552017271518707,
- -0.20950022339820862,
- 0.041016221046447754,
- -0.13453879952430725,
- -0.09942213445901871,
- ],
- [
- -0.19299760460853577,
- -0.2081824392080307,
- 0.20880958437919617,
- -0.028745755553245544,
- -0.14375154674053192,
- ],
- [
- -0.20548884570598602,
- -0.17012161016464233,
- 0.0647551566362381,
- -0.090476393699646,
- -0.1362220048904419,
- ],
- [
- -0.09553629904985428,
- -0.1303575187921524,
- 0.2995688021183014,
- -0.04689876735210419,
- -0.17737819254398346,
- ],
- [
- -0.030023209750652313,
- -0.12308696657419205,
- 0.2582213580608368,
- -0.04085375368595123,
- -0.16487300395965576,
- ],
- [
- -0.04765648394823074,
- -0.18347612023353577,
- 0.24941012263298035,
- 0.022468380630016327,
- -0.19706891477108002,
- ],
- [
- -0.09828818589448929,
- -0.18449409306049347,
- 0.2711920738220215,
- 0.044708192348480225,
- -0.15743865072727203,
- ],
- [
- -0.13639293611049652,
- -0.16482298076152802,
- 0.3018418848514557,
- 0.0815257728099823,
- -0.15574774146080017,
- ],
- [
- -0.14846578240394592,
- -0.17294010519981384,
- 0.31513816118240356,
- 0.10425455123186111,
- -0.16388092935085297,
- ],
- ]
- ),
- )
-
- # check the sums per sequence
- torch.testing.assert_close(
- logits.sum(1),
- torch.tensor(
- [
- [
- -1.3690122365951538,
- -2.0469894409179688,
- 2.1774630546569824,
- -0.35028770565986633,
- -1.7614551782608032,
- ],
- [
- -0.892522394657135,
- -1.3144632577896118,
- 2.683281898498535,
- -1.4629074335098267,
- -3.3516180515289307,
- ],
- [
- -1.3936796188354492,
- 0.21844607591629028,
- 4.501010417938232,
- -0.15485064685344696,
- -2.651848316192627,
- ],
- [
- -1.7388781309127808,
- -0.7211084365844727,
- 3.463726043701172,
- -0.2992384433746338,
- -2.65508770942688,
- ],
- ]
- ),
- )
-
-
-def test_training_step_and_on_epoch_end(batch, model, config):
- assert model._get_metric(TRAINING) is None
- loss = model.training_step(batch, batch_idx=0)
- assert loss is not None
- torch.testing.assert_close(loss, torch.tensor(1.730902075767517))
-
- model.on_train_epoch_end()
-
-
-def test_validation_step_and_on_epoch_end(batch, model, config):
- metric = model._get_metric(VALIDATION)
- metric.reset()
- loss = model.validation_step(batch, batch_idx=0)
- assert loss is not None
- torch.testing.assert_close(loss, torch.tensor(1.730902075767517))
- metric_values = {k: v.item() for k, v in metric.compute().items()}
- assert metric_values == {
- "span/ORG/f1": 0.0,
- "span/ORG/precision": 0.0,
- "span/ORG/recall": 0.0,
- "span/PER/f1": 0.0,
- "span/PER/precision": 0.0,
- "span/PER/recall": 0.0,
- "span/macro/f1": 0.0,
- "span/macro/precision": 0.0,
- "span/macro/recall": 0.0,
- "span/micro/f1": 0.0,
- "span/micro/precision": 0.0,
- "span/micro/recall": 0.0,
- "token/macro/f1": 0.0,
- "token/micro/f1": 0.0,
- "token/macro/precision": 0.0,
- "token/macro/recall": 0.0,
- "token/micro/precision": 0.0,
- "token/micro/recall": 0.0,
- }
-
- model.on_validation_epoch_end()
-
-
-def test_test_step_and_on_epoch_end(batch, model, config):
- metric = model._get_metric(TESTING)
- metric.reset()
- loss = model.test_step(batch, batch_idx=0)
- assert loss is not None
- torch.testing.assert_close(loss, torch.tensor(1.730902075767517))
- metric_values = {k: v.item() for k, v in metric.compute().items()}
- assert metric_values == {
- "span/ORG/f1": 0.0,
- "span/ORG/precision": 0.0,
- "span/ORG/recall": 0.0,
- "span/PER/f1": 0.0,
- "span/PER/precision": 0.0,
- "span/PER/recall": 0.0,
- "span/macro/f1": 0.0,
- "span/macro/precision": 0.0,
- "span/macro/recall": 0.0,
- "span/micro/f1": 0.0,
- "span/micro/precision": 0.0,
- "span/micro/recall": 0.0,
- "token/macro/f1": 0.0,
- "token/micro/f1": 0.0,
- "token/macro/precision": 0.0,
- "token/macro/recall": 0.0,
- "token/micro/precision": 0.0,
- "token/micro/recall": 0.0,
- }
-
- model.on_test_epoch_end()
-
-
-@pytest.mark.parametrize("test_step", [False, True])
-def test_predict_and_predict_step(model, batch, config, test_step):
- torch.manual_seed(42)
- if test_step:
- predictions = model.predict_step(batch, batch_idx=0, dataloader_idx=0)
- else:
- predictions = model.predict(batch[0])
- assert set(predictions) == {"labels", "probabilities"}
-
- assert predictions["labels"].shape == batch[1]["labels"].shape
- torch.testing.assert_close(
- predictions["labels"],
- torch.tensor(
- [
- [-100, 2, 3, 2, 2, -100, -100, -100, -100, -100, -100, -100],
- [-100, 2, 2, 2, 2, 2, 2, 2, 2, -100, -100, -100],
- [-100, 2, 2, 2, 2, 2, 2, 2, -100, -100, -100, -100],
- [-100, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, -100],
- ]
- ),
- )
- torch.testing.assert_close(
- # just check the first two batch entries
- predictions["probabilities"][:2].round(decimals=4),
- torch.tensor(
- [
- [
- [0.1792, 0.1912, 0.2443, 0.1613, 0.2240],
- [0.2083, 0.1750, 0.2395, 0.2030, 0.1742],
- [0.1934, 0.1837, 0.2047, 0.2381, 0.1801],
- [0.2034, 0.1779, 0.2285, 0.1917, 0.1986],
- [0.1752, 0.1725, 0.2618, 0.2065, 0.1840],
- [0.1805, 0.1870, 0.2365, 0.2025, 0.1935],
- [0.1844, 0.1781, 0.2738, 0.1936, 0.1700],
- [0.1958, 0.1784, 0.2612, 0.1937, 0.1711],
- [0.1941, 0.1694, 0.2612, 0.2082, 0.1671],
- [0.1831, 0.1680, 0.2650, 0.2113, 0.1726],
- [0.1740, 0.1691, 0.2697, 0.2164, 0.1707],
- [0.1713, 0.1672, 0.2723, 0.2205, 0.1687],
- ],
- [
- [0.1654, 0.1989, 0.2729, 0.1542, 0.2086],
- [0.1787, 0.1511, 0.3093, 0.1968, 0.1641],
- [0.1888, 0.1966, 0.2365, 0.2081, 0.1700],
- [0.2092, 0.1935, 0.2428, 0.2034, 0.1511],
- [0.2275, 0.1784, 0.2546, 0.1856, 0.1539],
- [0.2254, 0.1959, 0.2377, 0.1873, 0.1536],
- [0.2177, 0.1879, 0.2485, 0.1975, 0.1484],
- [0.2227, 0.1906, 0.2541, 0.1906, 0.1420],
- [0.2080, 0.2098, 0.2667, 0.1764, 0.1391],
- [0.1815, 0.2015, 0.2600, 0.1852, 0.1718],
- [0.1672, 0.1883, 0.3065, 0.1773, 0.1607],
- [0.1750, 0.1846, 0.2911, 0.1862, 0.1630],
- ],
- ]
- ),
- )
-
-
-def test_configure_optimizers(model):
- optimizer = model.configure_optimizers()
- assert optimizer is not None
- assert isinstance(optimizer, torch.optim.Adam)
- assert optimizer.defaults["lr"] == 1e-05
- assert len(optimizer.param_groups) == 1
- assert len(optimizer.param_groups[0]["params"]) > 0
- assert set(optimizer.param_groups[0]["params"]) == set(model.parameters())
diff --git a/tests/models/test_span_tuple_classification.py b/tests/models/test_span_tuple_classification.py
deleted file mode 100644
index f984fdbdb..000000000
--- a/tests/models/test_span_tuple_classification.py
+++ /dev/null
@@ -1,549 +0,0 @@
-import pytest
-import torch
-from pytorch_lightning import Trainer
-from torch import tensor
-
-from pie_modules.models import SpanTupleClassificationModel
-from pie_modules.models.common import TESTING, TRAINING, VALIDATION
-from pie_modules.taskmodules import RESpanPairClassificationTaskModule
-from tests import _config_to_str
-
-CONFIGS = [{}]
-CONFIG_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS}
-NUM_CLASSES = 4
-
-
-@pytest.fixture(scope="module", params=CONFIG_DICT.keys())
-def config_str(request):
- return request.param
-
-
-@pytest.fixture(scope="module")
-def config(config_str):
- return CONFIG_DICT[config_str]
-
-
-@pytest.fixture
-def taskmodule_config():
- return {
- "taskmodule_type": "RESpanPairClassificationTaskModule",
- "tokenizer_name_or_path": "bert-base-cased",
- "relation_annotation": "relations",
- "no_relation_label": "no_relation",
- "partition_annotation": None,
- "tokenize_kwargs": None,
- "create_candidate_relations": False,
- "create_candidate_relations_kwargs": None,
- "labels": ["org:founded_by", "per:employee_of", "per:founder"],
- "entity_labels": ["ORG", "PER"],
- "add_type_to_marker": True,
- "log_first_n_examples": 0,
- "collect_statistics": False,
- }
-
-
-def test_taskmodule_config(documents, taskmodule_config):
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RESpanPairClassificationTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- )
- taskmodule.prepare(documents)
- assert taskmodule.config == taskmodule_config
- assert len(taskmodule.id_to_label) == NUM_CLASSES
-
-
-def test_batch(documents, batch, taskmodule_config):
- taskmodule = RESpanPairClassificationTaskModule.from_config(taskmodule_config)
- encodings = taskmodule.encode(documents, encode_target=True, as_dataset=True)
- batch_from_documents = taskmodule.collate(encodings[:4])
-
- inputs, targets = batch
- inputs_from_documents, targets_from_documents = batch_from_documents
- assert set(inputs) == set(inputs_from_documents)
- for key in inputs:
- torch.testing.assert_close(inputs[key], inputs_from_documents[key])
- assert set(targets) == set(targets_from_documents)
- for key in targets:
- torch.testing.assert_close(targets[key], targets_from_documents[key])
-
-
-@pytest.fixture
-def batch():
- inputs = {
- "input_ids": tensor(
- [
- [
- 101,
- 28996,
- 13832,
- 3121,
- 2340,
- 138,
- 28998,
- 1759,
- 1120,
- 28999,
- 139,
- 28997,
- 119,
- 102,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- ],
- [
- 101,
- 1752,
- 5650,
- 119,
- 28996,
- 13832,
- 3121,
- 2340,
- 144,
- 28998,
- 1759,
- 1120,
- 28999,
- 145,
- 28997,
- 119,
- 1262,
- 1771,
- 28999,
- 146,
- 28997,
- 119,
- 102,
- 0,
- 0,
- 0,
- ],
- [
- 101,
- 1752,
- 5650,
- 119,
- 28996,
- 13832,
- 3121,
- 2340,
- 150,
- 28998,
- 1759,
- 1120,
- 28999,
- 151,
- 28997,
- 119,
- 1262,
- 28996,
- 1122,
- 28998,
- 1771,
- 28999,
- 152,
- 28997,
- 119,
- 102,
- ],
- ]
- ),
- "attention_mask": tensor(
- [
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- ]
- ),
- "span_start_indices": tensor([[1, 9, 0, 0], [4, 12, 18, 0], [4, 12, 17, 21]]),
- "span_end_indices": tensor([[7, 12, 0, 0], [10, 15, 21, 0], [10, 15, 20, 24]]),
- "tuple_indices": tensor(
- [[[0, 1], [-1, -1], [-1, -1]], [[0, 1], [0, 2], [2, 1]], [[0, 1], [2, 3], [3, 2]]]
- ),
- "tuple_indices_mask": tensor(
- [[True, False, False], [True, True, True], [True, True, True]]
- ),
- }
- targets = {"labels": tensor([[2, -100, -100], [2, 3, 1], [2, 3, 1]])}
- return inputs, targets
-
-
-@pytest.fixture
-def model(batch, config, taskmodule_config) -> SpanTupleClassificationModel:
- torch.manual_seed(42)
- model = SpanTupleClassificationModel(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=NUM_CLASSES,
- taskmodule_config=taskmodule_config,
- metric_stages=["val", "test"],
- **config,
- )
- return model
-
-
-def test_model_pickleable(model):
- import pickle
-
- pickle.dumps(model)
-
-
-def test_freeze_base_model():
- model = SpanTupleClassificationModel(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=NUM_CLASSES,
- freeze_base_model=True,
- )
-
- base_model_params = dict(model.model.named_parameters(prefix="model"))
- assert len(base_model_params) > 0
- for param in base_model_params.values():
- assert not param.requires_grad
- task_params = {
- name: param for name, param in model.named_parameters() if name not in base_model_params
- }
- assert len(task_params) > 0
- for param in task_params.values():
- assert param.requires_grad
-
-
-def test_tune_base_model():
- model = SpanTupleClassificationModel(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=5,
- )
- base_model_params = dict(model.model.named_parameters(prefix="model"))
- assert len(base_model_params) > 0
- for param in base_model_params.values():
- assert param.requires_grad
- task_params = {
- name: param for name, param in model.named_parameters() if name not in base_model_params
- }
- assert len(task_params) > 0
- for param in task_params.values():
- assert param.requires_grad
-
-
-@pytest.mark.parametrize(
- "span_embedding_mode", ["start_and_end_token", "start_token", "end_token"]
-)
-@pytest.mark.parametrize(
- "tuple_embedding_mode", ["concat", "multiply2_and_concat", "index_0", "index_1"]
-)
-def test_forward_embeddings(batch, taskmodule_config, span_embedding_mode, tuple_embedding_mode):
- torch.manual_seed(42)
- simple_model = SpanTupleClassificationModel(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=NUM_CLASSES,
- # disable the tuple mlp to allow for checking the intermediate embeddings via the indices
- tuple_entry_hidden_dim=None,
- taskmodule_config=taskmodule_config,
- span_embedding_mode=span_embedding_mode,
- tuple_embedding_mode=tuple_embedding_mode,
- )
-
- inputs, targets = batch
- batch_size, seq_len = inputs["input_ids"].shape
-
- # set seed to make sure the output is deterministic
- torch.manual_seed(42)
- # return embeddings to check the logits
- output = simple_model.forward(inputs, return_embeddings=True)
- assert set(output) == {"logits", "last_hidden_state", "span_embeddings", "tuple_embeddings"}
- logits_flat = output["logits"]
- assert len(logits_flat.shape) == 2
- assert logits_flat.shape[-1] == NUM_CLASSES
-
- # check span_embeddings: they should be the entries of last_hidden_state at the
- # span_start_indices and span_end_indices
- for batch_idx in range(batch_size):
- for j, (start, end) in enumerate(
- zip(inputs["span_start_indices"][batch_idx], inputs["span_end_indices"][batch_idx])
- ):
- current_expected_span_embedding_list = []
- if simple_model.span_embedding_mode == "start_and_end_token":
- current_expected_span_embedding_list.append(
- output["last_hidden_state"][batch_idx, start]
- )
- current_expected_span_embedding_list.append(
- output["last_hidden_state"][batch_idx, end]
- )
- elif simple_model.span_embedding_mode == "start_token":
- current_expected_span_embedding_list.append(
- output["last_hidden_state"][batch_idx, start]
- )
- elif simple_model.span_embedding_mode == "end_token":
- current_expected_span_embedding_list.append(
- output["last_hidden_state"][batch_idx, end]
- )
- else:
- raise ValueError(
- f"Unknown span_embedding_mode: {simple_model.span_embedding_mode}"
- )
- expected_current_span_embedding = torch.concat(
- current_expected_span_embedding_list, dim=-1
- )
- current_span_embeddings = output["span_embeddings"][batch_idx, j]
- torch.testing.assert_close(current_span_embeddings, expected_current_span_embedding)
-
- # check tuple_embeddings: they should be the entries of span_embeddings at the tuple_indices
- tuple_idx = 0
- for batch_idx in range(batch_size):
- for indices, is_valid in zip(
- inputs["tuple_indices"][batch_idx], inputs["tuple_indices_mask"][batch_idx]
- ):
- if is_valid:
- current_expected_tuple_embedding_list = [
- output["span_embeddings"][batch_idx, idx] for idx in indices
- ]
- if simple_model.tuple_embedding_mode == "concat":
- expected_current_tuple_embedding = torch.concat(
- current_expected_tuple_embedding_list, dim=-1
- )
- elif simple_model.tuple_embedding_mode == "multiply2_and_concat":
- expected_current_tuple_embedding = torch.cat(
- [
- current_expected_tuple_embedding_list[0]
- * current_expected_tuple_embedding_list[1],
- current_expected_tuple_embedding_list[0],
- current_expected_tuple_embedding_list[1],
- ],
- dim=-1,
- )
- elif simple_model.tuple_embedding_mode.startswith("index_"):
- idx = int(simple_model.tuple_embedding_mode.split("_")[1])
- expected_current_tuple_embedding = current_expected_tuple_embedding_list[idx]
- else:
- raise ValueError(
- f"Unknown tuple_embedding_mode: {simple_model.tuple_embedding_mode}"
- )
- current_tuple_embedding = output["tuple_embeddings"][tuple_idx]
- torch.testing.assert_close(
- current_tuple_embedding, expected_current_tuple_embedding
- )
- tuple_idx += 1
-
-
-def test_forward_logits(batch, model):
- inputs, targets = batch
-
- # set seed to make sure the output is deterministic
- torch.manual_seed(42)
- # return embeddings to check the logits
- output = model.forward(inputs)
- assert set(output) == {"logits"}
- logits_flat = output["logits"]
- assert len(logits_flat.shape) == 2
- assert logits_flat.shape[-1] == NUM_CLASSES
- # check the actual logits
- torch.testing.assert_close(
- logits_flat,
- tensor(
- [
- [
- -0.23075447976589203,
- 0.08129829168319702,
- -0.26441076397895813,
- 0.3208361268043518,
- ],
- [
- -0.2247302085161209,
- 0.21453489363193512,
- -0.20609508454799652,
- 0.2984844148159027,
- ],
- [
- -0.0552724152803421,
- 0.18319237232208252,
- -0.14115819334983826,
- 0.23137536644935608,
- ],
- [
- -0.2897184491157532,
- 0.17462071776390076,
- -0.12004873156547546,
- 0.1817789375782013,
- ],
- [
- -0.3101494312286377,
- 0.18245069682598114,
- -0.13525372743606567,
- 0.28625163435935974,
- ],
- [
- -0.33728304505348206,
- 0.22038179636001587,
- -0.0482308566570282,
- 0.25237396359443665,
- ],
- [
- -0.3835912048816681,
- 0.20549766719341278,
- 0.15333643555641174,
- 0.23370930552482605,
- ],
- ]
- ),
- )
-
-
-def test_step(batch, model, config):
- torch.manual_seed(42)
- loss = model._step("train", batch)
- assert loss is not None
- if config == {}:
- torch.testing.assert_close(loss, torch.tensor(1.3911350965499878))
- else:
- raise ValueError(f"Unknown config: {config}")
-
-
-def test_training_step_and_on_epoch_end(batch, model, config):
- metric = model._get_metric(TRAINING, batch_idx=0)
- assert metric is None
- loss = model.training_step(batch, batch_idx=0)
- assert loss is not None
- if config == {}:
- torch.testing.assert_close(loss, torch.tensor(1.3911350965499878))
- else:
- raise ValueError(f"Unknown config: {config}")
-
- model.on_train_epoch_end()
-
-
-def test_validation_step_and_on_epoch_end(batch, model, config):
- metric = model._get_metric(VALIDATION, batch_idx=0)
- metric.reset()
- loss = model.validation_step(batch, batch_idx=0)
- assert loss is not None
- metric_values = {k: v.item() for k, v in metric.compute().items()}
- if config == {}:
- torch.testing.assert_close(loss, torch.tensor(1.3911350965499878))
- assert metric_values == {
- "macro/f1": 0.14814814925193787,
- "micro/f1": 0.2857142984867096,
- "no_relation/f1": 0.0,
- "org:founded_by/f1": 0.0,
- "per:employee_of/f1": 0.0,
- "per:founder/f1": 0.4444444477558136,
- }
- else:
- raise ValueError(f"Unknown config: {config}")
-
- model.on_validation_epoch_end()
-
-
-def test_test_step_and_on_epoch_end(batch, model, config):
- metric = model._get_metric(TESTING, batch_idx=0)
- metric.reset()
- loss = model.test_step(batch, batch_idx=0)
- assert loss is not None
- metric_values = {k: v.item() for k, v in metric.compute().items()}
- if config == {}:
- torch.testing.assert_close(loss, torch.tensor(1.3911350965499878))
- assert metric_values == {
- "macro/f1": 0.14814814925193787,
- "micro/f1": 0.2857142984867096,
- "no_relation/f1": 0.0,
- "org:founded_by/f1": 0.0,
- "per:employee_of/f1": 0.0,
- "per:founder/f1": 0.4444444477558136,
- }
- else:
- raise ValueError(f"Unknown config: {config}")
-
- model.on_test_epoch_end()
-
-
-@pytest.mark.parametrize("test_step", [False, True])
-def test_predict_and_predict_step(model, batch, config, test_step):
- torch.manual_seed(42)
- if test_step:
- predictions = model.predict_step(batch, batch_idx=0, dataloader_idx=0)
- else:
- predictions = model.predict(batch[0])
-
- assert set(predictions) == {"labels", "probabilities"}
- labels = predictions["labels"]
- assert labels.shape == batch[1]["labels"].shape
- probabilities = predictions["probabilities"]
- if config == {}:
- torch.testing.assert_close(labels, tensor([[3, -100, -100], [3, 3, 3], [3, 3, 3]]))
- torch.testing.assert_close(
- probabilities.round(decimals=4),
- tensor(
- [
- [
- [0.1973, 0.2695, 0.1907, 0.3425],
- [-1.0000, -1.0000, -1.0000, -1.0000],
- [-1.0000, -1.0000, -1.0000, -1.0000],
- ],
- [
- [0.1902, 0.2951, 0.1938, 0.3209],
- [0.2213, 0.2809, 0.2031, 0.2947],
- [0.1859, 0.2958, 0.2203, 0.2979],
- ],
- [
- [0.1772, 0.2900, 0.2111, 0.3217],
- [0.1699, 0.2968, 0.2269, 0.3064],
- [0.1571, 0.2831, 0.2687, 0.2912],
- ],
- ],
- ),
- )
- else:
- raise ValueError(f"Unknown config: {config}")
-
-
-def test_configure_optimizers(model):
- model.trainer = Trainer(max_epochs=10)
- optimizer_and_schedular = model.configure_optimizers()
- assert optimizer_and_schedular is not None
- optimizers, schedulers = optimizer_and_schedular
-
- assert len(optimizers) == 1
- optimizer = optimizers[0]
- assert isinstance(optimizer, torch.optim.AdamW)
- assert optimizer.defaults["lr"] == 1e-05
- assert optimizer.defaults["weight_decay"] == 0.01
- assert optimizer.defaults["eps"] == 1e-08
-
- assert len(schedulers) == 1
- scheduler = schedulers[0]
- assert isinstance(scheduler["scheduler"], torch.optim.lr_scheduler.LambdaLR)
-
-
-def test_configure_optimizers_with_task_learning_rate():
- model = SpanTupleClassificationModel(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=NUM_CLASSES,
- warmup_proportion=0.0,
- task_learning_rate=1e-4,
- )
- optimizer = model.configure_optimizers()
- assert optimizer is not None
- assert isinstance(optimizer, torch.optim.AdamW)
- assert len(optimizer.param_groups) == 2
- # check that all parameters are in the optimizer
- assert set(optimizer.param_groups[0]["params"]) | set(
- optimizer.param_groups[1]["params"]
- ) == set(model.parameters())
-
- # base model parameters
- param_group = optimizer.param_groups[0]
- assert param_group["lr"] == 1e-05
- assert len(param_group["params"]) == 39
-
- # task parameters
- param_group = optimizer.param_groups[1]
- assert param_group["lr"] == 1e-04
- assert len(param_group["params"]) == 6
diff --git a/tests/models/test_token_classification_with_seq2seq_encoder_and_crf.py b/tests/models/test_token_classification_with_seq2seq_encoder_and_crf.py
deleted file mode 100644
index 5b31b2f13..000000000
--- a/tests/models/test_token_classification_with_seq2seq_encoder_and_crf.py
+++ /dev/null
@@ -1,639 +0,0 @@
-import pytest
-import torch
-from pytorch_lightning import Trainer
-
-from pie_modules.models import TokenClassificationModelWithSeq2SeqEncoderAndCrf
-from pie_modules.models.common import TESTING, TRAINING, VALIDATION
-from pie_modules.taskmodules import LabeledSpanExtractionByTokenClassificationTaskModule
-from tests import _config_to_str
-from tests.models import trunc_number
-
-CONFIGS = [{}, {"use_crf": False}]
-CONFIG_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS}
-
-
-@pytest.fixture(scope="module", params=CONFIG_DICT.keys())
-def config_str(request):
- return request.param
-
-
-@pytest.fixture(scope="module")
-def config(config_str):
- return CONFIG_DICT[config_str]
-
-
-@pytest.fixture
-def taskmodule_config():
- return {
- "taskmodule_type": "LabeledSpanExtractionByTokenClassificationTaskModule",
- "tokenizer_name_or_path": "bert-base-cased",
- "span_annotation": "entities",
- "partition_annotation": None,
- "label_pad_id": -100,
- "labels": ["ORG", "PER"],
- "include_ill_formed_predictions": True,
- "combine_token_scores_method": "mean",
- "tokenize_kwargs": None,
- "pad_kwargs": None,
- "log_precision_recall_metrics": True,
- }
-
-
-def test_taskmodule_config(documents, taskmodule_config):
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule(
- span_annotation="entities",
- tokenizer_name_or_path=tokenizer_name_or_path,
- )
- taskmodule.prepare(documents)
- assert taskmodule.config == taskmodule_config
-
-
-def test_batch(documents, batch, taskmodule_config):
- taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule.from_config(
- taskmodule_config
- )
- encodings = taskmodule.encode(documents, encode_target=True, as_dataset=True)
- batch_from_documents = taskmodule.collate(encodings[:4])
-
- inputs, targets = batch
- inputs_from_documents, targets_from_documents = batch_from_documents
- torch.testing.assert_close(inputs["input_ids"], inputs_from_documents["input_ids"])
- torch.testing.assert_close(inputs["attention_mask"], inputs_from_documents["attention_mask"])
- torch.testing.assert_close(targets, targets_from_documents)
-
-
-@pytest.fixture
-def batch():
- inputs = {
- "input_ids": torch.tensor(
- [
- [101, 138, 1423, 5650, 119, 102, 0, 0, 0, 0, 0, 0],
- [101, 13832, 3121, 2340, 138, 1759, 1120, 139, 119, 102, 0, 0],
- [101, 13832, 3121, 2340, 140, 1105, 141, 119, 102, 0, 0, 0],
- [101, 1752, 5650, 119, 13832, 3121, 2340, 142, 1105, 143, 119, 102],
- ]
- ).to(torch.long),
- "attention_mask": torch.tensor(
- [
- [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- ]
- ),
- "special_tokens_mask": torch.tensor(
- [
- [1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
- [1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
- [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
- [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
- ]
- ),
- }
- targets = {
- "labels": torch.tensor(
- [
- [-100, 0, 0, 0, 0, -100, -100, -100, -100, -100, -100, -100],
- [-100, 3, 4, 4, 4, 0, 0, 1, 0, -100, -100, -100],
- [-100, 3, 4, 4, 4, 0, 1, 0, -100, -100, -100, -100],
- [-100, 0, 0, 0, 3, 4, 4, 4, 0, 1, 0, -100],
- ]
- )
- }
- return inputs, targets
-
-
-@pytest.fixture
-def model(batch, config, taskmodule_config) -> TokenClassificationModelWithSeq2SeqEncoderAndCrf:
- seq2seq_dict = {
- "type": "linear",
- "out_features": 10,
- }
- torch.manual_seed(42)
- model = TokenClassificationModelWithSeq2SeqEncoderAndCrf(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=5,
- seq2seq_encoder=seq2seq_dict,
- taskmodule_config=taskmodule_config,
- metric_stages=["val", "test"],
- **config,
- )
- return model
-
-
-def test_model(model, config):
- assert model is not None
- named_parameters = dict(model.named_parameters())
- parameter_means = {k: trunc_number(v.mean().item(), 7) for k, v in named_parameters.items()}
- parameter_means_expected = {
- "model.embeddings.word_embeddings.weight": 0.0031152,
- "model.embeddings.position_embeddings.weight": 5.5e-05,
- "model.embeddings.token_type_embeddings.weight": -0.0015419,
- "model.embeddings.LayerNorm.weight": 1.312345,
- "model.embeddings.LayerNorm.bias": -0.0294608,
- "model.encoder.layer.0.attention.self.query.weight": -0.0003949,
- "model.encoder.layer.0.attention.self.query.bias": 0.0185744,
- "model.encoder.layer.0.attention.self.key.weight": 0.0003863,
- "model.encoder.layer.0.attention.self.key.bias": 0.0020557,
- "model.encoder.layer.0.attention.self.value.weight": 4.22e-05,
- "model.encoder.layer.0.attention.self.value.bias": 0.0065417,
- "model.encoder.layer.0.attention.output.dense.weight": 3.01e-05,
- "model.encoder.layer.0.attention.output.dense.bias": 0.0007209,
- "model.encoder.layer.0.attention.output.LayerNorm.weight": 1.199831,
- "model.encoder.layer.0.attention.output.LayerNorm.bias": 0.0608714,
- "model.encoder.layer.0.intermediate.dense.weight": -0.0011731,
- "model.encoder.layer.0.intermediate.dense.bias": -0.1219958,
- "model.encoder.layer.0.output.dense.weight": -0.0002212,
- "model.encoder.layer.0.output.dense.bias": -0.0013031,
- "model.encoder.layer.0.output.LayerNorm.weight": 1.2419648,
- "model.encoder.layer.0.output.LayerNorm.bias": 0.005295,
- "model.encoder.layer.1.attention.self.query.weight": -0.0007321,
- "model.encoder.layer.1.attention.self.query.bias": -0.0358397,
- "model.encoder.layer.1.attention.self.key.weight": 0.0001333,
- "model.encoder.layer.1.attention.self.key.bias": 0.0045062,
- "model.encoder.layer.1.attention.self.value.weight": 0.0001012,
- "model.encoder.layer.1.attention.self.value.bias": -0.0007094,
- "model.encoder.layer.1.attention.output.dense.weight": -2.43e-05,
- "model.encoder.layer.1.attention.output.dense.bias": 0.0041446,
- "model.encoder.layer.1.attention.output.LayerNorm.weight": 1.0377343,
- "model.encoder.layer.1.attention.output.LayerNorm.bias": 0.0443237,
- "model.encoder.layer.1.intermediate.dense.weight": -0.001344,
- "model.encoder.layer.1.intermediate.dense.bias": -0.1247257,
- "model.encoder.layer.1.output.dense.weight": -5.32e-05,
- "model.encoder.layer.1.output.dense.bias": 0.000677,
- "model.encoder.layer.1.output.LayerNorm.weight": 1.017162,
- "model.encoder.layer.1.output.LayerNorm.bias": -0.0474442,
- "model.pooler.dense.weight": 0.0001295,
- "model.pooler.dense.bias": -0.0052078,
- "seq2seq_encoder.weight": -0.0015382,
- "seq2seq_encoder.bias": -0.0105704,
- "classifier.weight": 0.0261459,
- "classifier.bias": -0.0157966,
- }
- if config.get("use_crf", True):
- parameter_means_expected.update(
- {
- "crf.start_transitions": -0.0341042,
- "crf.end_transitions": 0.0140624,
- "crf.transitions": 0.0056733,
- }
- )
- assert parameter_means == parameter_means_expected
-
-
-def test_model_pickleable(model):
- import pickle
-
- pickle.dumps(model)
-
-
-def test_freeze_base_model():
- model = TokenClassificationModelWithSeq2SeqEncoderAndCrf(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=5,
- freeze_base_model=True,
- )
-
- base_model_params = dict(model.model.named_parameters(prefix="model"))
- assert len(base_model_params) > 0
- for param in base_model_params.values():
- assert not param.requires_grad
- task_params = {
- name: param for name, param in model.named_parameters() if name not in base_model_params
- }
- assert len(task_params) > 0
- for param in task_params.values():
- assert param.requires_grad
-
-
-def test_tune_base_model():
- model = TokenClassificationModelWithSeq2SeqEncoderAndCrf(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=5,
- )
- base_model_params = dict(model.model.named_parameters(prefix="model"))
- assert len(base_model_params) > 0
- for param in base_model_params.values():
- assert param.requires_grad
- task_params = {
- name: param for name, param in model.named_parameters() if name not in base_model_params
- }
- assert len(task_params) > 0
- for param in task_params.values():
- assert param.requires_grad
-
-
-def test_forward(batch, model):
- inputs, targets = batch
- batch_size, seq_len = inputs["input_ids"].shape
- num_classes = int(torch.max(targets["labels"]) + 1)
-
- # set seed to make sure the output is deterministic
- torch.manual_seed(42)
- output = model.forward(inputs)
- assert set(output) == {"logits"}
- logits = output["logits"]
- assert logits.shape == (batch_size, seq_len, num_classes)
- # check the first batch entry
- torch.testing.assert_close(
- logits[0],
- torch.tensor(
- [
- [
- -1.065280795097351,
- 0.22260898351669312,
- -0.013371739536523819,
- 1.0213487148284912,
- -0.08737741410732269,
- ],
- [
- -1.092915415763855,
- 0.07986105978488922,
- 0.011286348104476929,
- 0.7147902250289917,
- -0.014343257993459702,
- ],
- [
- -1.0107779502868652,
- 0.2041827142238617,
- -0.06531291455030441,
- 0.6551182270050049,
- 0.04944971576333046,
- ],
- [
- -0.3324984312057495,
- 0.27757787704467773,
- 0.13295423984527588,
- 0.26407280564308167,
- -0.007371138781309128,
- ],
- [
- -0.6176304817199707,
- 0.12915551662445068,
- 0.268213152885437,
- 0.43618908524513245,
- -0.13303528726100922,
- ],
- [
- -0.5220450758934021,
- 0.37291139364242554,
- 0.2522115111351013,
- 0.7383102178573608,
- 0.1278681606054306,
- ],
- [
- -1.0737248659133911,
- 0.0029090046882629395,
- 0.06924695521593094,
- 0.6680881977081299,
- -0.15523286163806915,
- ],
- [
- -0.5176048278808594,
- -0.01018303632736206,
- 0.14543311297893524,
- 0.5191693305969238,
- -0.3461107611656189,
- ],
- [
- -0.9277648329734802,
- 0.3154565095901489,
- -0.07648143172264099,
- 0.4210910201072693,
- 0.2663896083831787,
- ],
- [
- -0.8864655494689941,
- 0.2862459421157837,
- -0.04168111830949783,
- 0.4992614984512329,
- 0.28455498814582825,
- ],
- [
- -0.9500657916069031,
- 0.1869449019432068,
- -0.005329027771949768,
- 0.5908203721046448,
- 0.06730394065380096,
- ],
- [
- -0.5336291193962097,
- -0.053214408457279205,
- 0.22038350999355316,
- 0.48135989904403687,
- -0.4338146448135376,
- ],
- ]
- ),
- )
-
- # check the sums per sequence
- torch.testing.assert_close(
- logits.sum(1),
- torch.tensor(
- [
- [
- -9.530403137207031,
- 2.0144565105438232,
- 0.8975526690483093,
- 7.009620189666748,
- -0.3817189633846283,
- ],
- [
- -4.351415634155273,
- 0.3694552183151245,
- -0.8337129354476929,
- 3.612205743789673,
- 0.15454095602035522,
- ],
- [
- -6.173098564147949,
- -2.6261491775512695,
- 0.47521746158599854,
- 3.344158172607422,
- -5.086399078369141,
- ],
- [
- -9.28173542022705,
- -1.6196215152740479,
- 0.18393829464912415,
- 5.492751121520996,
- -4.148656845092773,
- ],
- ]
- ),
- )
-
-
-def test_step(batch, model, config):
- torch.manual_seed(42)
- loss = model._step("train", batch)
- assert loss is not None
- if config == {}:
- torch.testing.assert_close(loss, torch.tensor(75.52511596679688))
- elif config == {"use_crf": False}:
- torch.testing.assert_close(loss, torch.tensor(1.9434731006622314))
- else:
- raise ValueError(f"Unknown config: {config}")
-
-
-def test_training_step_and_on_epoch_end(batch, model, config):
- metric = model._get_metric(TRAINING, batch_idx=0)
- assert metric is None
- loss = model.training_step(batch, batch_idx=0)
- assert loss is not None
- if config == {}:
- torch.testing.assert_close(loss, torch.tensor(77.59623718261719))
- elif config == {"use_crf": False}:
- torch.testing.assert_close(loss, torch.tensor(1.9865683317184448))
- else:
- raise ValueError(f"Unknown config: {config}")
-
- model.on_train_epoch_end()
-
-
-def test_training_step_without_attention_mask(batch, model, config):
- inputs, targets = batch
- inputs_without_attention_mask = {k: v for k, v in inputs.items() if k != "attention_mask"}
- loss = model.training_step(batch=(inputs_without_attention_mask, targets), batch_idx=0)
- assert loss is not None
- if config == {}:
- torch.testing.assert_close(loss, torch.tensor(103.0061264038086))
- elif config == {"use_crf": False}:
- torch.testing.assert_close(loss, torch.tensor(1.9988830089569092))
- else:
- raise ValueError(f"Unknown config: {config}")
-
-
-def test_validation_step_and_on_epoch_end(batch, model, config):
- metric = model._get_metric(VALIDATION, batch_idx=0)
- metric.reset()
- loss = model.validation_step(batch, batch_idx=0)
- assert loss is not None
- metric_values = {k: v.item() for k, v in metric.compute().items()}
- if config == {}:
- torch.testing.assert_close(loss, torch.tensor(77.59623718261719))
- assert metric_values == {
- "token/macro/f1": 0.20666667819023132,
- "token/micro/f1": 0.2068965584039688,
- "token/macro/precision": 0.29019609093666077,
- "token/macro/recall": 0.2666666805744171,
- "token/micro/precision": 0.2068965584039688,
- "token/micro/recall": 0.2068965584039688,
- "span/ORG/f1": 0.3636363744735718,
- "span/ORG/recall": 0.25,
- "span/ORG/precision": 0.6666666865348816,
- "span/PER/f1": 0.0,
- "span/PER/recall": 0.0,
- "span/PER/precision": 0.0,
- "span/micro/f1": 0.12121212482452393,
- "span/micro/recall": 0.07407407462596893,
- "span/micro/precision": 0.3333333432674408,
- "span/macro/f1": 0.1818181872367859,
- "span/macro/recall": 0.125,
- "span/macro/precision": 0.3333333432674408,
- }
- elif config == {"use_crf": False}:
- torch.testing.assert_close(loss, torch.tensor(1.9865683317184448))
- assert metric_values == {
- "token/macro/f1": 0.11717171967029572,
- "token/micro/f1": 0.17241379618644714,
- "token/macro/precision": 0.22500000894069672,
- "token/macro/recall": 0.24444444477558136,
- "token/micro/precision": 0.17241379618644714,
- "token/micro/recall": 0.17241379618644714,
- "span/ORG/f1": 0.0,
- "span/ORG/recall": 0.0,
- "span/ORG/precision": 0.0,
- "span/PER/f1": 0.0,
- "span/PER/recall": 0.0,
- "span/PER/precision": 0.0,
- "span/micro/f1": 0.0,
- "span/micro/recall": 0.0,
- "span/micro/precision": 0.0,
- "span/macro/f1": 0.0,
- "span/macro/recall": 0.0,
- "span/macro/precision": 0.0,
- }
- else:
- raise ValueError(f"Unknown config: {config}")
-
- model.on_validation_epoch_end()
-
-
-def test_test_step_and_on_epoch_end(batch, model, config):
- metric = model._get_metric(TESTING, batch_idx=0)
- metric.reset()
- loss = model.test_step(batch, batch_idx=0)
- assert loss is not None
- metric_values = {k: v.item() for k, v in metric.compute().items()}
- if config == {}:
- torch.testing.assert_close(loss, torch.tensor(77.59623718261719))
- assert metric_values == {
- "token/macro/f1": 0.20666667819023132,
- "token/micro/f1": 0.2068965584039688,
- "token/macro/precision": 0.29019609093666077,
- "token/macro/recall": 0.2666666805744171,
- "token/micro/precision": 0.2068965584039688,
- "token/micro/recall": 0.2068965584039688,
- "span/ORG/f1": 0.3636363744735718,
- "span/ORG/recall": 0.25,
- "span/ORG/precision": 0.6666666865348816,
- "span/PER/f1": 0.0,
- "span/PER/recall": 0.0,
- "span/PER/precision": 0.0,
- "span/micro/f1": 0.12121212482452393,
- "span/micro/recall": 0.07407407462596893,
- "span/micro/precision": 0.3333333432674408,
- "span/macro/f1": 0.1818181872367859,
- "span/macro/recall": 0.125,
- "span/macro/precision": 0.3333333432674408,
- }
- elif config == {"use_crf": False}:
- torch.testing.assert_close(loss, torch.tensor(1.9865683317184448))
- assert metric_values == {
- "token/macro/f1": 0.11717171967029572,
- "token/micro/f1": 0.17241379618644714,
- "token/macro/precision": 0.22500000894069672,
- "token/macro/recall": 0.24444444477558136,
- "token/micro/precision": 0.17241379618644714,
- "token/micro/recall": 0.17241379618644714,
- "span/ORG/f1": 0.0,
- "span/ORG/recall": 0.0,
- "span/ORG/precision": 0.0,
- "span/PER/f1": 0.0,
- "span/PER/recall": 0.0,
- "span/PER/precision": 0.0,
- "span/micro/f1": 0.0,
- "span/micro/recall": 0.0,
- "span/micro/precision": 0.0,
- "span/macro/f1": 0.0,
- "span/macro/recall": 0.0,
- "span/macro/precision": 0.0,
- }
- else:
- raise ValueError(f"Unknown config: {config}")
-
- model.on_test_epoch_end()
-
-
-@pytest.mark.parametrize("test_step", [False, True])
-def test_predict_and_predict_step(model, batch, config, test_step):
- torch.manual_seed(42)
- if test_step:
- predictions = model.predict_step(batch, batch_idx=0, dataloader_idx=0)
- else:
- predictions = model.predict(batch[0])
-
- assert set(predictions) == {"labels", "probabilities"}
- labels = predictions["labels"]
- probabilities = predictions["probabilities"]
- if config == {}:
- torch.testing.assert_close(
- labels,
- torch.tensor(
- [
- [-100, 3, 3, 1, 3, -100, -100, -100, -100, -100, -100, -100],
- [-100, 3, 1, 4, 4, 3, 3, 3, 2, -100, -100, -100],
- [-100, 3, 2, 2, 3, 3, 3, 2, -100, -100, -100, -100],
- [-100, 3, 3, 3, 2, 3, 1, 4, 3, 3, 2, -100],
- ]
- ),
- )
- elif config == {"use_crf": False}:
- torch.testing.assert_close(
- labels,
- torch.tensor(
- [
- [-100, 3, 3, 1, 3, -100, -100, -100, -100, -100, -100, -100],
- [-100, 3, 3, 4, 4, 3, 3, 3, 3, -100, -100, -100],
- [-100, 3, 2, 2, 3, 3, 3, 2, -100, -100, -100, -100],
- [-100, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, -100],
- ]
- ),
- )
- else:
- raise ValueError(f"Unknown config: {config}")
-
- assert labels.shape == batch[1]["labels"].shape
- torch.testing.assert_close(
- probabilities[:2].round(decimals=4),
- torch.tensor(
- [
- [
- [0.0549, 0.1991, 0.1573, 0.4426, 0.1461],
- [0.0614, 0.1984, 0.1853, 0.3744, 0.1806],
- [0.0661, 0.2229, 0.1702, 0.3499, 0.1909],
- [0.1310, 0.2411, 0.2087, 0.2379, 0.1813],
- [0.0997, 0.2104, 0.2418, 0.2861, 0.1619],
- [0.0904, 0.2213, 0.1961, 0.3189, 0.1732],
- [0.0654, 0.1920, 0.2052, 0.3734, 0.1639],
- [0.1162, 0.1929, 0.2254, 0.3276, 0.1379],
- [0.0716, 0.2483, 0.1678, 0.2759, 0.2364],
- [0.0726, 0.2344, 0.1689, 0.2901, 0.2340],
- [0.0708, 0.2207, 0.1821, 0.3305, 0.1958],
- [0.1162, 0.1879, 0.2470, 0.3206, 0.1284],
- ],
- [
- [0.1242, 0.1911, 0.1516, 0.3256, 0.2075],
- [0.1291, 0.2089, 0.2046, 0.2890, 0.1684],
- [0.2033, 0.2016, 0.1920, 0.2260, 0.1771],
- [0.1793, 0.2191, 0.1800, 0.1889, 0.2328],
- [0.1854, 0.2150, 0.1638, 0.1898, 0.2460],
- [0.1363, 0.2007, 0.1738, 0.2887, 0.2005],
- [0.1254, 0.2014, 0.1826, 0.2890, 0.2016],
- [0.1305, 0.2056, 0.2056, 0.2590, 0.1993],
- [0.1400, 0.2022, 0.2252, 0.2544, 0.1783],
- [0.1299, 0.2051, 0.1933, 0.2751, 0.1966],
- [0.1088, 0.2086, 0.1599, 0.2861, 0.2367],
- [0.0910, 0.1794, 0.1840, 0.3793, 0.1663],
- ],
- ]
- ),
- )
-
-
-def test_configure_optimizers(model):
- model.trainer = Trainer(max_epochs=10)
- optimizer_and_schedular = model.configure_optimizers()
- assert optimizer_and_schedular is not None
- optimizers, schedulers = optimizer_and_schedular
-
- assert len(optimizers) == 1
- optimizer = optimizers[0]
- assert isinstance(optimizer, torch.optim.AdamW)
- assert optimizer.defaults["lr"] == 1e-05
- assert optimizer.defaults["weight_decay"] == 0.01
- assert optimizer.defaults["eps"] == 1e-08
-
- assert len(schedulers) == 1
- scheduler = schedulers[0]
- assert isinstance(scheduler["scheduler"], torch.optim.lr_scheduler.LambdaLR)
-
-
-def test_configure_optimizers_with_task_learning_rate():
- model = TokenClassificationModelWithSeq2SeqEncoderAndCrf(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=5,
- warmup_proportion=0.0,
- task_learning_rate=1e-4,
- )
- optimizer = model.configure_optimizers()
- assert optimizer is not None
- assert isinstance(optimizer, torch.optim.AdamW)
- assert len(optimizer.param_groups) == 2
- # check that all parameters are in the optimizer
- assert set(optimizer.param_groups[0]["params"]) | set(
- optimizer.param_groups[1]["params"]
- ) == set(model.parameters())
-
- # base model parameters
- param_group = optimizer.param_groups[0]
- assert param_group["lr"] == 1e-05
- assert len(param_group["params"]) == 39
-
- # task parameters
- param_group = optimizer.param_groups[1]
- assert param_group["lr"] == 1e-04
- assert len(param_group["params"]) == 5
diff --git a/tests/taskmodules/__init__.py b/tests/taskmodules/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/tests/taskmodules/common/__init__.py b/tests/taskmodules/common/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/tests/taskmodules/common/test_interfaces.py b/tests/taskmodules/common/test_interfaces.py
deleted file mode 100644
index 5c65c797a..000000000
--- a/tests/taskmodules/common/test_interfaces.py
+++ /dev/null
@@ -1,27 +0,0 @@
-from typing import Any, Dict, List, Set, Tuple
-
-from pie_modules.annotations import Span
-from pie_modules.taskmodules.common import AnnotationEncoderDecoder
-
-
-def test_annotation_encoder_decoder():
- """Test the AnnotationEncoderDecoder class."""
-
- class SpanAnnotationEncoderDecoder(AnnotationEncoderDecoder[Span, Tuple[int, int]]):
- """A class that uses the AnnotationEncoderDecoder class."""
-
- def encode(self, annotation: Span, **kwargs) -> Tuple[int, int]:
- return annotation.start, annotation.end
-
- def decode(self, encoding: Tuple[int, int], **kwargs) -> Span:
- return Span(start=encoding[0], end=encoding[1])
-
- def validate_encoding(self, encoding: Tuple[int, int]) -> Set[str]:
- return {"order"} if encoding[0] > encoding[1] else set()
-
- encoder_decoder = SpanAnnotationEncoderDecoder()
-
- assert encoder_decoder.encode(Span(start=1, end=2)) == (1, 2)
- assert encoder_decoder.decode((1, 2)) == Span(start=1, end=2)
- assert encoder_decoder.validate_encoding((1, 2)) == set()
- assert encoder_decoder.validate_encoding((2, 1)) == {"order"}
diff --git a/tests/taskmodules/common/test_mixins.py b/tests/taskmodules/common/test_mixins.py
deleted file mode 100644
index b6b921bb2..000000000
--- a/tests/taskmodules/common/test_mixins.py
+++ /dev/null
@@ -1,166 +0,0 @@
-import dataclasses
-import logging
-from typing import List
-
-import torch
-from pie_core import Annotation
-
-from pie_modules.taskmodules.common import BatchableMixin
-from pie_modules.taskmodules.common.mixins import RelationStatisticsMixin
-
-
-def test_batchable_mixin():
- """Test the BatchableMixin class."""
-
- @dataclasses.dataclass
- class Foo(BatchableMixin):
- """A class that uses the BatchableMixin class."""
-
- a: List[int]
-
- @property
- def len_a(self):
- """Return the length of the list a."""
- return len(self.a)
-
- x = Foo(a=[1, 2, 3])
- y = Foo(a=[4, 5])
-
- batch = Foo.batch(
- values=[x, y], dtypes={"a": torch.int64, "len_a": torch.int64}, pad_values={"a": 0}
- )
- torch.testing.assert_close(batch["a"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
- torch.testing.assert_close(batch["len_a"], torch.tensor([3, 2]))
-
-
-def test_relation_statistics_mixin_show_statistics(caplog):
- """Test the RelationStatisticsMixin class."""
-
- class Foo(RelationStatisticsMixin):
- """A class that uses the RelationStatisticsMixin class."""
-
- pass
-
- @dataclasses.dataclass(eq=True, frozen=True)
- class TestAnnotation(Annotation):
- label: str
- score: float = dataclasses.field(default=1.0, compare=False)
-
- x = Foo(collect_statistics=True)
-
- relations = [
- TestAnnotation(label="A", score=1.0),
- TestAnnotation(label="B", score=0.5),
- TestAnnotation(label="C", score=0.0),
- TestAnnotation(label="D", score=0.3),
- ]
- # all available relations
- x.collect_all_relations(kind="available", relations=relations)
- # relations skipped for a reason ("test")
- x.collect_relation(kind="skipped_test", relation=relations[1])
- # mark two relations as used, one of them is skipped for another (unknown) reason
- x.collect_all_relations(kind="used", relations=[relations[0], relations[2]])
-
- statistics = x.get_statistics()
-
- assert statistics == {
- ("available", "A"): 1,
- ("available", "B"): 1,
- ("available", "D"): 1,
- ("available", "no_relation"): 1,
- ("skipped_other", "D"): 1,
- ("skipped_test", "B"): 1,
- ("used", "A"): 1,
- ("used", "no_relation"): 1,
- }
-
- with caplog.at_level(logging.INFO):
- x.show_statistics()
- assert caplog.messages[0] == (
- "Foo does not have a `none_label` attribute. "
- "Using default value 'no_relation'. "
- "`none_label` is used as the label for relations with score 0 in statistics and "
- "all relations with label different from `none_label` will be summarized to 'all_relations'. "
- "Set the `none_label` attribute before using statistics or "
- "overwrite `get_none_label_for_statistics()` function to get rid of this message."
- )
- assert caplog.messages[1] == (
- "statistics:\n"
- "| | available | skipped_other | skipped_test | used | used % |\n"
- "|:--------------|------------:|----------------:|---------------:|-------:|---------:|\n"
- "| A | 1 | 0 | 0 | 1 | 100 |\n"
- "| B | 1 | 0 | 1 | 0 | 0 |\n"
- "| D | 1 | 1 | 0 | 0 | 0 |\n"
- "| no_relation | 1 | 0 | 0 | 1 | 100 |\n"
- "| all_relations | 3 | 1 | 1 | 1 | 33 |"
- )
-
-
-def test_relation_statistics_mixin_show_statistics_no_relations(caplog):
- """Test the RelationStatisticsMixin class with no predictions."""
-
- class Foo(RelationStatisticsMixin):
- """A class that uses the RelationStatisticsMixin class."""
-
- pass
-
- x = Foo(collect_statistics=True)
-
- # Test with no relations collected
- x.collect_all_relations(kind="available", relations=[])
- x.collect_all_relations(kind="used", relations=[])
-
- statistics = x.get_statistics()
-
- assert statistics == {}
-
- with caplog.at_level(logging.INFO):
- x.show_statistics()
- assert caplog.messages[0] == "statistics:\n" "|--:|\n" "| 0 |"
-
-
-def test_relation_statistics_mixin_show_statistics_custom_none_label(caplog):
- """Test the RelationStatisticsMixin class with custom none_label."""
-
- class Foo(RelationStatisticsMixin):
- """A class that uses the RelationStatisticsMixin class.
-
- It also sets the `none_label` attribute which will be used by statistics.
- """
-
- def __init__(self, none_label: str = "no_relation", **kwargs):
- super().__init__(**kwargs)
- self.none_label = none_label
-
- @dataclasses.dataclass(eq=True, frozen=True)
- class TestAnnotation(Annotation):
- label: str
- score: float = dataclasses.field(default=1.0, compare=False)
-
- x = Foo(collect_statistics=True, none_label="None_Label")
-
- relations = [
- TestAnnotation(label="A", score=1.0),
- TestAnnotation(label="B", score=0.5),
- TestAnnotation(label="C", score=0.0),
- TestAnnotation(label="D", score=0.3),
- ]
- # all available relations
- x.collect_all_relations(kind="available", relations=relations)
- # relations skipped for a reason ("test")
- x.collect_relation(kind="skipped_test", relation=relations[1])
- # mark two relations as used, one of them is skipped for another (unknown) reason
- x.collect_all_relations(kind="used", relations=[relations[0], relations[2]])
-
- with caplog.at_level(logging.INFO):
- x.show_statistics()
- assert caplog.messages[0] == (
- "statistics:\n"
- "| | available | skipped_other | skipped_test | used | used % |\n"
- "|:--------------|------------:|----------------:|---------------:|-------:|---------:|\n"
- "| A | 1 | 0 | 0 | 1 | 100 |\n"
- "| B | 1 | 0 | 1 | 0 | 0 |\n"
- "| D | 1 | 1 | 0 | 0 | 0 |\n"
- "| None_Label | 1 | 0 | 0 | 1 | 100 |\n"
- "| all_relations | 3 | 1 | 1 | 1 | 33 |"
- )
diff --git a/tests/taskmodules/common/test_taskmodule_with_document_converter.py b/tests/taskmodules/common/test_taskmodule_with_document_converter.py
deleted file mode 100644
index 8540b5a3f..000000000
--- a/tests/taskmodules/common/test_taskmodule_with_document_converter.py
+++ /dev/null
@@ -1,163 +0,0 @@
-from typing import Optional, Type
-
-import pytest
-from pie_core import Document
-from typing_extensions import TypeAlias
-
-from pie_modules.documents import TextDocumentWithLabeledSpansAndBinaryRelations
-from pie_modules.taskmodules import RETextClassificationWithIndicesTaskModule
-from pie_modules.taskmodules.common import TaskModuleWithDocumentConverter
-from tests.conftest import TestDocument
-
-DocumentType: TypeAlias = TestDocument
-ConvertedDocumentType: TypeAlias = TextDocumentWithLabeledSpansAndBinaryRelations
-
-
-class MyRETaskModuleWithDocConverter(
- TaskModuleWithDocumentConverter, RETextClassificationWithIndicesTaskModule
-):
- @property
- def document_type(self) -> Optional[Type[Document]]:
- return TestDocument
-
- def _convert_document(self, document: DocumentType) -> ConvertedDocumentType:
- result = document.as_type(
- TextDocumentWithLabeledSpansAndBinaryRelations,
- field_mapping={"entities": "labeled_spans", "relations": "binary_relations"},
- )
- new2old_span = {
- new_s: old_s for old_s, new_s in zip(document.entities, result.labeled_spans)
- }
- result.metadata["new2old_span"] = new2old_span
- return result
-
- def _integrate_predictions_from_converted_document(
- self, document: DocumentType, converted_document: ConvertedDocumentType
- ) -> None:
- new2old_span = converted_document.metadata["new2old_span"]
- for rel in converted_document.binary_relations.predictions:
- new_rel = rel.copy(head=new2old_span[rel.head], tail=new2old_span[rel.tail])
- document.relations.predictions.append(new_rel)
-
-
-@pytest.fixture(scope="module")
-def taskmodule(documents):
- result = MyRETaskModuleWithDocConverter(tokenizer_name_or_path="bert-base-cased")
- result.prepare(documents)
- return result
-
-
-def test_taskmodule(taskmodule):
- assert taskmodule is not None
- assert taskmodule.document_type == TestDocument
-
-
-@pytest.fixture(scope="module")
-def task_encodings(taskmodule, documents):
- return taskmodule.encode(documents, encode_target=True)
-
-
-def test_task_encodings(task_encodings):
- assert len(task_encodings) == 7
-
-
-def test_decode(taskmodule, task_encodings):
- label_indices = [0, 1, 3, 0, 0, 2, 0]
- probabilities = [0.1738, 0.6643, 0.2101, 0.0801, 0.0319, 0.81, 0.3079]
- task_outputs = [
- {"labels": [taskmodule.id_to_label[label_idx]], "probabilities": [prob]}
- for label_idx, prob in zip(label_indices, probabilities)
- ]
- docs_with_predictions = taskmodule.decode(
- task_encodings=task_encodings, task_outputs=task_outputs
- )
- assert all(isinstance(doc, TestDocument) for doc in docs_with_predictions)
-
- all_gold_relations = [doc.relations.resolve() for doc in docs_with_predictions]
- assert all_gold_relations == [
- [("per:employee_of", (("PER", "Entity A"), ("ORG", "B")))],
- [
- ("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))),
- ("per:founder", (("PER", "Entity G"), ("ORG", "I"))),
- ("org:founded_by", (("ORG", "I"), ("ORG", "H"))),
- ],
- [
- ("per:employee_of", (("PER", "Entity M"), ("ORG", "N"))),
- ("per:founder", (("PER", "it"), ("ORG", "O"))),
- ("org:founded_by", (("ORG", "O"), ("PER", "it"))),
- ],
- ]
-
- all_predicted_relations = [
- doc.relations.predictions.resolve() for doc in docs_with_predictions
- ]
- assert all_predicted_relations == [
- [("no_relation", (("PER", "Entity A"), ("ORG", "B")))],
- [
- ("org:founded_by", (("PER", "Entity G"), ("ORG", "H"))),
- ("per:founder", (("PER", "Entity G"), ("ORG", "I"))),
- ("no_relation", (("ORG", "I"), ("ORG", "H"))),
- ],
- [
- ("no_relation", (("PER", "Entity M"), ("ORG", "N"))),
- ("per:employee_of", (("PER", "it"), ("ORG", "O"))),
- ("no_relation", (("ORG", "O"), ("PER", "it"))),
- ],
- ]
-
-
-class MyRETaskModuleWithDocConverterWithoutDocType(
- TaskModuleWithDocumentConverter, RETextClassificationWithIndicesTaskModule
-):
- def _convert_document(self, document: DocumentType) -> ConvertedDocumentType:
- pass
-
- def _integrate_predictions_from_converted_document(
- self, document: DocumentType, converted_document: ConvertedDocumentType
- ) -> None:
- pass
-
-
-def test_missing_document_type_overwrite():
- taskmodule = MyRETaskModuleWithDocConverterWithoutDocType(
- tokenizer_name_or_path="bert-base-cased"
- )
-
- with pytest.raises(NotImplementedError) as e:
- taskmodule.document_type
- assert (
- str(e.value)
- == "please overwrite document_type for MyRETaskModuleWithDocConverterWithoutDocType"
- )
-
-
-class MyRETaskModuleWithWrongDocConverter(
- TaskModuleWithDocumentConverter, RETextClassificationWithIndicesTaskModule
-):
- @property
- def document_type(self) -> Optional[Type[Document]]:
- return TestDocument
-
- def _convert_document(self, document: DocumentType) -> ConvertedDocumentType:
- result = TextDocumentWithLabeledSpansAndBinaryRelations(text="dummy")
- result.metadata["original_document"] = None
- return result
-
- def _integrate_predictions_from_converted_document(
- self, document: DocumentType, converted_document: ConvertedDocumentType
- ) -> None:
- pass
-
-
-def test_wrong_doc_converter(documents):
- taskmodule = MyRETaskModuleWithWrongDocConverter(tokenizer_name_or_path="bert-base-cased")
- taskmodule.prepare(documents)
- with pytest.raises(ValueError) as e:
- taskmodule.encode(documents, encode_target=True)
- assert (
- str(e.value)
- == "metadata of converted_document has already and entry 'original_document', "
- "this is not allowed. Please adjust "
- "'MyRETaskModuleWithWrongDocConverter._convert_document()' to produce "
- "documents without that key in metadata."
- )
diff --git a/tests/taskmodules/common/test_utils.py b/tests/taskmodules/common/test_utils.py
deleted file mode 100644
index b9a426402..000000000
--- a/tests/taskmodules/common/test_utils.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import torch
-
-from pie_modules.taskmodules.common.utils import get_first_occurrence_index
-
-
-def test_get_first_occurrence_index():
- tensor: torch.LongTensor = torch.tensor(
- [
- [0, 1, 1, 1, 1, 1], # 1
- [0, 0, 1, 1, 1, 1], # 2
- [0, 1, 1, 0, 0, 1], # 1
- [1, 1, 1, 1, 1, 1], # 0
- [0, 0, 0, 0, 0, 0], # 6 (=size of input) because no 1s at all
- ]
- ).to(torch.long)
- indices = get_first_occurrence_index(tensor, 1)
- torch.testing.assert_close(indices, torch.tensor([1, 2, 1, 0, 6]))
diff --git a/tests/taskmodules/metrics/__init__.py b/tests/taskmodules/metrics/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/tests/taskmodules/metrics/test_precision_recall_and_f1_for_labeled_annotations.py b/tests/taskmodules/metrics/test_precision_recall_and_f1_for_labeled_annotations.py
deleted file mode 100644
index 90cc93e59..000000000
--- a/tests/taskmodules/metrics/test_precision_recall_and_f1_for_labeled_annotations.py
+++ /dev/null
@@ -1,151 +0,0 @@
-import pytest
-from torch import tensor
-
-from pie_modules.annotations import LabeledSpan
-from pie_modules.taskmodules.metrics import PrecisionRecallAndF1ForLabeledAnnotations
-
-
-def test_precision_recall_and_f1_for_labeled_annotations():
- metric = PrecisionRecallAndF1ForLabeledAnnotations()
- assert metric.metric_state == {}
-
- metric.update(
- gold=[LabeledSpan(start=0, end=1, label="a")],
- predicted=[LabeledSpan(start=0, end=1, label="a")],
- )
- metric_state = {k: v.tolist() for k, v in metric.metric_state.items()}
- assert metric_state == {"counts_a": [1, 1, 1], "counts_micro": [1, 1, 1]}
- value = metric.compute()
- assert value == {
- "a": {"recall": 1.0, "precision": 1.0, "f1": 1.0},
- "macro": {"recall": 1.0, "precision": 1.0, "f1": 1.0},
- "micro": {"recall": 1.0, "precision": 1.0, "f1": 1.0},
- }
-
- metric.reset()
- metric.update(
- gold=[LabeledSpan(start=0, end=1, label="a"), LabeledSpan(start=0, end=1, label="b")],
- predicted=[LabeledSpan(start=0, end=1, label="b"), LabeledSpan(start=0, end=1, label="c")],
- )
- metric_state = {k: v.tolist() for k, v in metric.metric_state.items()}
- assert metric_state == {
- "counts_a": [1, 0, 0],
- "counts_b": [1, 1, 1],
- "counts_c": [0, 1, 0],
- "counts_micro": [2, 2, 1],
- }
- assert metric.compute() == {
- "b": {"recall": 1.0, "precision": 1.0, "f1": 1.0},
- "a": {"recall": 0.0, "precision": 0.0, "f1": 0.0},
- "c": {"recall": 0.0, "precision": 0.0, "f1": 0.0},
- "macro": {
- "f1": tensor(0.3333333432674408),
- "precision": tensor(0.3333333432674408),
- "recall": tensor(0.3333333432674408),
- },
- "micro": {"recall": 0.5, "precision": 0.5, "f1": 0.5},
- }
-
- # check deduplication in same update
- metric.reset()
- metric.update(
- gold=[
- LabeledSpan(start=0, end=1, label="a"),
- LabeledSpan(start=0, end=1, label="a"),
- LabeledSpan(start=0, end=1, label="b"),
- ],
- predicted=[
- LabeledSpan(start=0, end=1, label="b"),
- LabeledSpan(start=0, end=1, label="b"),
- LabeledSpan(start=0, end=1, label="c"),
- ],
- )
- metric_state = {k: v.tolist() for k, v in metric.metric_state.items()}
- assert metric_state == {
- "counts_a": [1, 0, 0],
- "counts_b": [1, 1, 1],
- "counts_c": [0, 1, 0],
- "counts_micro": [2, 2, 1],
- }
-
- # assert no deduplication over multiple updates
- metric.reset()
- metric.update(
- gold=[LabeledSpan(start=0, end=1, label="a")],
- predicted=[LabeledSpan(start=0, end=1, label="b")],
- )
- metric.update(
- gold=[LabeledSpan(start=0, end=1, label="b")],
- predicted=[LabeledSpan(start=0, end=1, label="a")],
- )
- metric_state = {k: v.tolist() for k, v in metric.metric_state.items()}
- assert metric_state == {
- "counts_a": [1, 1, 0],
- "counts_b": [1, 1, 0],
- "counts_c": [0, 0, 0],
- "counts_micro": [2, 2, 0],
- }
- assert metric.compute() == {
- "a": {"f1": 0.0, "precision": 0.0, "recall": 0.0},
- "b": {"f1": 0.0, "precision": 0.0, "recall": 0.0},
- "c": {"f1": 0.0, "precision": 0.0, "recall": 0.0},
- "macro": {"f1": 0.0, "precision": 0.0, "recall": 0.0},
- "micro": {"f1": 0.0, "precision": 0.0, "recall": 0.0},
- }
-
-
-def test_precision_recall_and_f1_for_labeled_annotations_in_percent():
- metric = PrecisionRecallAndF1ForLabeledAnnotations(
- in_percent=True, flatten_result_with_sep="/"
- )
-
- metric.update(
- gold=[LabeledSpan(start=0, end=1, label="a")],
- predicted=[LabeledSpan(start=0, end=1, label="a"), LabeledSpan(start=0, end=1, label="b")],
- )
- values = {k: v.item() for k, v in metric.compute().items()}
- assert values == {
- "a/f1": 100.0,
- "a/precision": 100.0,
- "a/recall": 100.0,
- "b/f1": 0.0,
- "b/precision": 0.0,
- "b/recall": 0.0,
- "macro/f1": 50.0,
- "macro/precision": 50.0,
- "macro/recall": 50.0,
- "micro/f1": 66.66667175292969,
- "micro/precision": 50.0,
- "micro/recall": 100.0,
- }
-
-
-def test_precision_recall_and_f1_for_labeled_annotations_with_label_mapping():
- metric = PrecisionRecallAndF1ForLabeledAnnotations(
- label_mapping={"a": "label_a", "b": "label_b"}
- )
-
- metric.update(
- gold=[LabeledSpan(start=0, end=1, label="a")],
- predicted=[LabeledSpan(start=0, end=1, label="a"), LabeledSpan(start=0, end=1, label="b")],
- )
- assert metric.compute() == {
- "label_a": {"f1": 1.0, "precision": 1.0, "recall": 1.0},
- "label_b": {"f1": 0.0, "precision": 0.0, "recall": 0.0},
- "macro": {"f1": 0.5, "precision": 0.5, "recall": 0.5},
- "micro": {"f1": 0.6666666666666666, "precision": 0.5, "recall": 1.0},
- }
-
-
-def test_precision_recall_and_f1_for_labeled_annotations_key_micro_error():
- metric = PrecisionRecallAndF1ForLabeledAnnotations()
- with pytest.raises(ValueError) as excinfo:
- metric.update(
- gold=[LabeledSpan(start=0, end=1, label="micro")],
- predicted=[],
- )
- assert (
- str(excinfo.value)
- == "The key 'micro' was used as an annotation label, but it is reserved for the micro average. "
- "You can change which key is used for that with the 'key_micro' argument."
- )
diff --git a/tests/taskmodules/metrics/test_wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function.py b/tests/taskmodules/metrics/test_wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function.py
deleted file mode 100644
index b4e34ceb7..000000000
--- a/tests/taskmodules/metrics/test_wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function.py
+++ /dev/null
@@ -1,144 +0,0 @@
-import json
-from typing import Any, Dict, Tuple
-
-import pytest
-from torch import tensor
-from torchmetrics import Metric
-
-from pie_modules.taskmodules.metrics import (
- WrappedLayerMetricsWithUnbatchAndDecodeWithErrorsFunction,
-)
-
-
-class TestMetric(Metric):
- """A simple metric that computes the exact match ratio between predictions and targets."""
-
- def __init__(self):
- super().__init__()
- self.add_state("matching", default=[])
-
- def update(self, prediction: str, target: str):
- self.matching.append(prediction == target)
-
- def compute(self):
- # Note: returning NaN in the case of an empty list would be more correct, but
- # returning 0.0 is more convenient for testing.
- return sum(self.matching) / len(self.matching) if self.matching else 0.0
-
-
-@pytest.fixture(scope="module")
-def wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function():
- def decode_with_errors_function(x: str) -> Tuple[Dict[str, Any], Dict[str, int]]:
- if x == "error":
- return {"entities": [], "relations": []}, {"dummy": 1}
- else:
- return json.loads(x), {"dummy": 0}
-
- layer_metrics = {
- "entities": TestMetric(),
- "relations": TestMetric(),
- }
- metric = WrappedLayerMetricsWithUnbatchAndDecodeWithErrorsFunction(
- layer_metrics=layer_metrics,
- unbatch_function=lambda x: x.split("\n"),
- decode_layers_with_errors_function=decode_with_errors_function,
- )
- return metric
-
-
-def test_wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function(
- wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function,
-):
- metric = wrapped_layer_metrics_with_unbatch_and_decode_with_errors_function
- assert metric is not None
- assert metric.unbatch_function is not None
- assert metric.decode_layers_with_errors_function is not None
- assert metric.layer_metrics is not None
- assert metric.metric_state == {
- "total": tensor(0),
- "exact_encoding_matches": tensor(0),
- }
-
- values = metric.compute()
- assert metric.metric_state
- assert values == {
- "decoding_errors": {"all": 0.0},
- "entities": 0.0,
- "exact_encoding_matches": 0.0,
- "relations": 0.0,
- }
-
- metric.reset()
- # Prediction and expected are the same.
- metric.update(
- prediction=json.dumps({"entities": ["E1"], "relations": ["R1"]}),
- expected=json.dumps({"entities": ["E1"], "relations": ["R1"]}),
- )
- assert metric.metric_state == {
- "total": tensor(1),
- "exact_encoding_matches": tensor(1),
- "errors_dummy": tensor(0),
- }
- values = metric.compute()
- assert values == {
- "decoding_errors": {"all": 0.0, "dummy": 0.0},
- "entities": 1.0,
- "exact_encoding_matches": 1.0,
- "relations": 1.0,
- }
-
- metric.reset()
- # Prediction and expected are different and there are multiple entries.
- # The first entry is an exact match, the second entry is not.
- metric.update(
- prediction=json.dumps({"entities": ["E1"], "relations": ["R1"]})
- + "\n"
- + json.dumps({"entities": ["E1"], "relations": ["R1"]}),
- expected=json.dumps({"entities": ["E1"], "relations": ["R1"]})
- + "\n"
- + json.dumps({"entities": ["E1"], "relations": ["R2"]}),
- )
- assert metric.metric_state == {
- "total": tensor(2),
- "exact_encoding_matches": tensor(1),
- "errors_dummy": tensor(0),
- }
- values = metric.compute()
- assert values == {
- "decoding_errors": {"all": 0.0, "dummy": 0.0},
- "entities": 1.0,
- "exact_encoding_matches": 0.5,
- "relations": 0.5,
- }
-
- metric.reset()
- # Encoding error
- metric.update(
- prediction="error",
- expected=json.dumps({"entities": ["E1"], "relations": []}),
- )
- assert metric.metric_state == {
- "total": tensor(1),
- "exact_encoding_matches": tensor(0),
- "errors_dummy": tensor(1),
- }
- values = metric.compute()
- # In the case on an error, the decoding function returns adict with empty lists for entities and relations.
- # Thus, we get a perfect match for entities and a 0.0 match for relations.
- assert values == {
- "decoding_errors": {"all": 1.0, "dummy": 1.0},
- "entities": 0.0,
- "exact_encoding_matches": 0.0,
- "relations": 1.0,
- }
-
- # test mismatched number of predictions and targets
- metric.reset()
- with pytest.raises(ValueError) as excinfo:
- metric.update(
- prediction=json.dumps({"entities": ["E1"], "relations": ["R1"]}),
- expected=json.dumps({"entities": ["E1"], "relations": ["R1"]})
- + "\n"
- + json.dumps({"entities": ["E1"], "relations": ["R1"]}),
- )
- assert str(excinfo.value) == "Number of predictions (1) and targets (2) do not match."
diff --git a/tests/taskmodules/metrics/test_wrapped_metric_with_prepare_function.py b/tests/taskmodules/metrics/test_wrapped_metric_with_prepare_function.py
deleted file mode 100644
index e06130b09..000000000
--- a/tests/taskmodules/metrics/test_wrapped_metric_with_prepare_function.py
+++ /dev/null
@@ -1,137 +0,0 @@
-from functools import partial
-
-import pytest
-from torchmetrics import Metric
-
-from pie_modules.taskmodules.metrics import WrappedMetricWithPrepareFunction
-
-
-class TestMetric(Metric):
- """A simple metric that computes the exact match ratio between predictions and targets."""
-
- def __init__(self):
- super().__init__()
- self.add_state("matching", default=[])
-
- def update(self, prediction: str, target: str):
- self.matching.append(prediction == target)
-
- def compute(self):
- # Note: returning NaN in the case of an empty list would be more correct, but
- # returning 0.0 is more convenient for testing.
- return sum(self.matching) / len(self.matching) if self.matching else 0.0
-
-
-def test_metric():
- metric = WrappedMetricWithPrepareFunction(
- metric=TestMetric(), prepare_function=lambda x: x.split()[0]
- )
-
- assert metric is not None
- assert metric.prepare_function is not None
-
- assert metric.compute() == 0.0
-
- metric.reset()
- metric(prediction="abc", target="abc")
- assert metric.compute() == 1.0
-
- metric.reset()
- metric(prediction="abc", target="def")
- assert metric.compute() == 0.0
-
- metric.reset()
- metric(prediction="abc def", target="abc xyz")
- # we consider just the first word, so this is still 1.0
- assert metric.compute() == 1.0
-
- metric.reset()
- metric(prediction="abc def", target="xyz def")
- assert metric.compute() == 0.0
-
-
-def split_both_and_remove_where_both_match(
- preds: str, targets: str, match: str
-) -> tuple[list[str], list[str]]:
- preds = preds.split()
- targets = targets.split()
- not_both_none_indices = [
- i for i, (p, t) in enumerate(zip(preds, targets)) if p != match or t != match
- ]
- preds = [preds[i] for i in not_both_none_indices]
- targets = [targets[i] for i in not_both_none_indices]
- return preds, targets
-
-
-def test_wrapped_metric_with_prepare_both_function():
- metric = WrappedMetricWithPrepareFunction(
- metric=TestMetric(),
- prepare_together_function=partial(split_both_and_remove_where_both_match, match="none"),
- prepare_does_unbatch=True,
- )
-
- assert metric is not None
- assert metric.prepare_both_function is not None
-
- assert metric.compute() == 0.0
-
- # none is removed from both, remaining is the same
- metric.reset()
- metric(prediction="abc none", target="abc none")
- assert metric.compute() == 1.0
-
- # none is removed from both, remaining is different
- metric.reset()
- metric(prediction="abc none", target="def none")
- assert metric.compute() == 0.0
-
- # none is not removed from both, remaining is partially the same
- metric.reset()
- metric(prediction="abc def", target="abc none")
- assert metric.compute() == 0.5
-
- # none is not removed from both, remaining is different
- metric.reset()
- metric(prediction="abc def", target="def none")
- assert metric.compute() == 0.0
-
-
-@pytest.fixture(scope="module")
-def wrapped_metric_with_unbatch_function():
- # just split the strings to unbatch the inputs
- return WrappedMetricWithPrepareFunction(
- metric=TestMetric(), prepare_function=lambda x: x.split(), prepare_does_unbatch=True
- )
-
-
-def test_wrapped_metric_with_unbatch_function(wrapped_metric_with_unbatch_function):
- metric = wrapped_metric_with_unbatch_function
- assert metric is not None
-
- assert metric.compute() == 0.0
-
- metric.reset()
- metric(prediction="abc", target="abc")
- assert metric.compute() == 1.0
-
- metric.reset()
- metric(prediction="abc", target="def")
- assert metric.compute() == 0.0
-
- metric.reset()
- metric(prediction="abc def", target="abc def")
- assert metric.compute() == 1.0
-
- metric.reset()
- metric(prediction="abc def", target="def abc")
- assert metric.compute() == 0.0
-
- metric.reset()
- metric(prediction="abc xyz", target="def xyz")
- assert metric.compute() == 0.5
-
-
-def test_wrapped_metric_with_unbatch_function_size_mismatch(wrapped_metric_with_unbatch_function):
- with pytest.raises(ValueError) as excinfo:
- wrapped_metric_with_unbatch_function(prediction="abc", target="abc def")
- assert str(excinfo.value) == "Number of prepared predictions (1) and targets (2) do not match."
diff --git a/tests/taskmodules/pointer_network/__init__.py b/tests/taskmodules/pointer_network/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/tests/taskmodules/pointer_network/test_annotation_encoder_decoder.py b/tests/taskmodules/pointer_network/test_annotation_encoder_decoder.py
deleted file mode 100644
index 4dbeaab8a..000000000
--- a/tests/taskmodules/pointer_network/test_annotation_encoder_decoder.py
+++ /dev/null
@@ -1,441 +0,0 @@
-import pytest
-
-from pie_modules.annotations import BinaryRelation, LabeledSpan, Span
-from pie_modules.taskmodules.pointer_network.annotation_encoder_decoder import (
- BinaryRelationEncoderDecoder,
- DecodingLabelException,
- DecodingLengthException,
- DecodingNegativeIndexException,
- DecodingOrderException,
- LabeledSpanEncoderDecoder,
- SpanEncoderDecoder,
- SpanEncoderDecoderWithOffset,
-)
-
-
-@pytest.mark.parametrize("exclusive_end", [True, False])
-def test_span_encoder_decoder(exclusive_end):
- """Test the SimpleSpanEncoderDecoder class."""
-
- encoder_decoder = SpanEncoderDecoder(exclusive_end)
- if exclusive_end:
- assert encoder_decoder.encode(Span(start=1, end=2)) == [1, 2]
- assert encoder_decoder.decode([1, 2]) == Span(start=1, end=2)
- else:
- assert encoder_decoder.encode(Span(start=1, end=2)) == [1, 1]
- assert encoder_decoder.decode([1, 1]) == Span(start=1, end=2)
-
-
-def test_span_encoder_decoder_wrong_length():
- """Test the SimpleSpanEncoderDecoder class."""
-
- encoder_decoder = SpanEncoderDecoder()
- with pytest.raises(DecodingLengthException) as excinfo:
- encoder_decoder.decode([1])
- assert (
- str(excinfo.value)
- == "two values are required to decode as Span, but encoding has length 1"
- )
- assert excinfo.value.identifier == "len"
-
- with pytest.raises(DecodingLengthException) as excinfo:
- encoder_decoder.decode([1, 2, 3])
- assert (
- str(excinfo.value)
- == "two values are required to decode as Span, but encoding has length 3"
- )
- assert excinfo.value.identifier == "len"
-
-
-def test_span_encoder_decoder_wrong_order():
- """Test the SimpleSpanEncoderDecoder class."""
-
- encoder_decoder = SpanEncoderDecoder()
-
- with pytest.raises(DecodingOrderException) as excinfo:
- encoder_decoder.decode([3, 2])
- assert (
- str(excinfo.value)
- == "end index can not be smaller than start index, but got: start=3, end=2"
- )
- assert excinfo.value.identifier == "order"
-
- # zero-length span
- span = encoder_decoder.decode([1, 1])
- assert span is not None
-
-
-def test_span_encoder_decoder_wrong_offset():
- """Test the SimpleSpanEncoderDecoder class."""
-
- encoder_decoder = SpanEncoderDecoder()
-
- with pytest.raises(DecodingNegativeIndexException) as excinfo:
- encoder_decoder.decode([-1, 2])
- assert str(excinfo.value) == "indices must be positive, but got: [-1, 2]"
- assert excinfo.value.identifier == "index"
-
-
-def test_span_encoder_decoder_with_offset():
- """Test the SpanEncoderDecoderWithOffset class."""
-
- encoder_decoder = SpanEncoderDecoderWithOffset(offset=1)
-
- assert encoder_decoder.encode(Span(start=1, end=2)) == [2, 3]
- assert encoder_decoder.decode([2, 3]) == Span(start=1, end=2)
-
-
-@pytest.mark.parametrize("mode", ["indices_label", "label_indices"])
-def test_labeled_span_encoder_decoder(mode):
- """Test the LabeledSpanEncoderDecoder class."""
-
- label2id = {"A": 0, "B": 1}
- encoder_decoder = LabeledSpanEncoderDecoder(
- span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)),
- label2id=label2id,
- mode=mode,
- )
-
- if mode == "indices_label":
- assert encoder_decoder.encode(LabeledSpan(start=1, end=2, label="A")) == [3, 4, 0]
- assert encoder_decoder.decode([3, 4, 0]) == LabeledSpan(start=1, end=2, label="A")
- elif mode == "label_indices":
- assert encoder_decoder.encode(LabeledSpan(start=1, end=2, label="A")) == [0, 3, 4]
- assert encoder_decoder.decode([0, 3, 4]) == LabeledSpan(start=1, end=2, label="A")
- else:
- raise ValueError(f"unknown mode: {mode}")
-
-
-@pytest.mark.parametrize("mode", ["indices_label", "label_indices"])
-def test_labeled_span_encoder_decoder_wrong_label_encoding(mode):
- """Test the LabeledSpanEncoderDecoder class."""
-
- label2id = {"A": 0, "B": 1}
- encoder_decoder = LabeledSpanEncoderDecoder(
- span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)),
- label2id=label2id,
- mode=mode,
- )
-
- if mode == "indices_label":
- with pytest.raises(DecodingLabelException) as excinfo:
- encoder_decoder.decode([2, 3, 4])
- elif mode == "label_indices":
- with pytest.raises(DecodingLabelException) as excinfo:
- encoder_decoder.decode([4, 2, 3])
- assert str(excinfo.value) == "unknown label id: 4 (label2id: {'A': 0, 'B': 1})"
- assert excinfo.value.identifier == "label"
-
-
-def test_labeled_span_encoder_decoder_unknown_mode():
- """Test the LabeledSpanEncoderDecoder class."""
-
- label2id = {"A": 0, "B": 1}
- encoder_decoder = LabeledSpanEncoderDecoder(
- span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)),
- label2id=label2id,
- mode="unknown",
- )
- with pytest.raises(ValueError) as excinfo:
- encoder_decoder.encode(LabeledSpan(start=1, end=2, label="A"))
- assert str(excinfo.value) == "unknown mode: unknown"
-
- with pytest.raises(ValueError) as excinfo:
- encoder_decoder.decode([0, 3, 4])
- assert str(excinfo.value) == "unknown mode: unknown"
-
-
-@pytest.mark.parametrize(
- "mode", ["head_tail_label", "tail_head_label", "label_head_tail", "label_tail_head"]
-)
-def test_binary_relation_encoder_decoder(mode):
- """Test the BinaryRelationEncoderDecoder class."""
-
- label2id = {"A": 0, "B": 1, "C": 2}
- labeled_span_encoder_decoder = LabeledSpanEncoderDecoder(
- span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)),
- label2id=label2id,
- mode="indices_label",
- )
- encoder_decoder = BinaryRelationEncoderDecoder(
- head_encoder_decoder=labeled_span_encoder_decoder,
- tail_encoder_decoder=labeled_span_encoder_decoder,
- label2id=label2id,
- mode=mode,
- )
-
- if mode == "head_tail_label":
- assert encoder_decoder.encode(
- BinaryRelation(
- head=LabeledSpan(start=1, end=2, label="A"),
- tail=LabeledSpan(start=3, end=4, label="B"),
- label="C",
- )
- ) == [4, 5, 0, 6, 7, 1, 2]
- assert encoder_decoder.decode([4, 5, 0, 6, 7, 1, 2]) == BinaryRelation(
- head=LabeledSpan(start=1, end=2, label="A"),
- tail=LabeledSpan(start=3, end=4, label="B"),
- label="C",
- )
- elif mode == "tail_head_label":
- assert encoder_decoder.encode(
- BinaryRelation(
- head=LabeledSpan(start=1, end=2, label="A"),
- tail=LabeledSpan(start=3, end=4, label="B"),
- label="C",
- )
- ) == [6, 7, 1, 4, 5, 0, 2]
- assert encoder_decoder.decode([6, 7, 1, 4, 5, 0, 2]) == BinaryRelation(
- head=LabeledSpan(start=1, end=2, label="A"),
- tail=LabeledSpan(start=3, end=4, label="B"),
- label="C",
- )
- elif mode == "label_head_tail":
- assert encoder_decoder.encode(
- BinaryRelation(
- head=LabeledSpan(start=1, end=2, label="A"),
- tail=LabeledSpan(start=3, end=4, label="B"),
- label="C",
- )
- ) == [2, 4, 5, 0, 6, 7, 1]
- assert encoder_decoder.decode([2, 4, 5, 0, 6, 7, 1]) == BinaryRelation(
- head=LabeledSpan(start=1, end=2, label="A"),
- tail=LabeledSpan(start=3, end=4, label="B"),
- label="C",
- )
- elif mode == "label_tail_head":
- assert encoder_decoder.encode(
- BinaryRelation(
- head=LabeledSpan(start=1, end=2, label="A"),
- tail=LabeledSpan(start=3, end=4, label="B"),
- label="C",
- )
- ) == [2, 6, 7, 1, 4, 5, 0]
- assert encoder_decoder.decode([2, 6, 7, 1, 4, 5, 0]) == BinaryRelation(
- head=LabeledSpan(start=1, end=2, label="A"),
- tail=LabeledSpan(start=3, end=4, label="B"),
- label="C",
- )
-
-
-@pytest.mark.parametrize(
- "mode", ["head_tail_label", "tail_head_label", "label_head_tail", "label_tail_head"]
-)
-def test_binary_relation_encoder_decoder_loop_relation(mode):
- """Test the BinaryRelationEncoderDecoder class."""
-
- # we use different label2id for head and tail to test the case where the head and tail
- # have different label sets
- head_encoder_decoder = LabeledSpanEncoderDecoder(
- span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=3),
- label2id={"A": 1, "B": 2},
- mode="indices_label",
- )
- tail_encoder_decoder = LabeledSpanEncoderDecoder(
- span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=3),
- label2id={"A": -1, "B": -2},
- mode="indices_label",
- )
- encoder_decoder = BinaryRelationEncoderDecoder(
- head_encoder_decoder=head_encoder_decoder,
- tail_encoder_decoder=tail_encoder_decoder,
- label2id={"N": 3},
- mode=mode,
- loop_dummy_relation_name="L",
- none_label="N",
- )
-
- if mode == "head_tail_label":
- assert encoder_decoder.encode(
- BinaryRelation(
- head=LabeledSpan(start=1, end=2, label="A"),
- tail=LabeledSpan(start=1, end=2, label="A"),
- label="L",
- )
- ) == [4, 5, 1, 3, 3, 3, 3]
- assert encoder_decoder.decode([4, 5, 1, 3, 3, 3, 3]) == BinaryRelation(
- head=LabeledSpan(start=1, end=2, label="A"),
- tail=LabeledSpan(start=1, end=2, label="A"),
- label="L",
- )
- elif mode == "tail_head_label":
- assert encoder_decoder.encode(
- BinaryRelation(
- head=LabeledSpan(start=1, end=2, label="A"),
- tail=LabeledSpan(start=1, end=2, label="A"),
- label="L",
- )
- ) == [4, 5, -1, 3, 3, 3, 3]
- assert encoder_decoder.decode([4, 5, -1, 3, 3, 3, 3]) == BinaryRelation(
- head=LabeledSpan(start=1, end=2, label="A"),
- tail=LabeledSpan(start=1, end=2, label="A"),
- label="L",
- )
- elif mode == "label_head_tail":
- assert encoder_decoder.encode(
- BinaryRelation(
- head=LabeledSpan(start=1, end=2, label="A"),
- tail=LabeledSpan(start=1, end=2, label="A"),
- label="L",
- )
- ) == [3, 4, 5, 1, 3, 3, 3]
- assert encoder_decoder.decode([3, 4, 5, 1, 3, 3, 3]) == BinaryRelation(
- head=LabeledSpan(start=1, end=2, label="A"),
- tail=LabeledSpan(start=1, end=2, label="A"),
- label="L",
- )
- elif mode == "label_tail_head":
- assert encoder_decoder.encode(
- BinaryRelation(
- head=LabeledSpan(start=1, end=2, label="A"),
- tail=LabeledSpan(start=1, end=2, label="A"),
- label="L",
- )
- ) == [3, 4, 5, -1, 3, 3, 3]
- assert encoder_decoder.decode([3, 4, 5, -1, 3, 3, 3]) == BinaryRelation(
- head=LabeledSpan(start=1, end=2, label="A"),
- tail=LabeledSpan(start=1, end=2, label="A"),
- label="L",
- )
- else:
- raise ValueError(f"unknown mode: {mode}")
-
-
-@pytest.mark.parametrize(
- "loop_dummy_relation_name,none_label",
- [("L", None), (None, "N")],
-)
-def test_binary_relation_encoder_decoder_only_loop_or_none_label_provided(
- loop_dummy_relation_name, none_label
-):
- """Test the BinaryRelationEncoderDecoder class."""
-
- label2id = {"A": 0, "B": 1, "N": 2}
- labeled_span_encoder_decoder = LabeledSpanEncoderDecoder(
- span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)),
- label2id=label2id,
- mode="indices_label",
- )
- encoder_decoder = BinaryRelationEncoderDecoder(
- head_encoder_decoder=labeled_span_encoder_decoder,
- tail_encoder_decoder=labeled_span_encoder_decoder,
- label2id=label2id,
- mode="head_tail_label",
- loop_dummy_relation_name=loop_dummy_relation_name,
- none_label=none_label,
- )
-
- if loop_dummy_relation_name is not None:
- with pytest.raises(ValueError) as excinfo:
- encoder_decoder.encode(
- BinaryRelation(
- head=LabeledSpan(start=1, end=2, label="A"),
- tail=LabeledSpan(start=1, end=2, label="A"),
- label=loop_dummy_relation_name,
- )
- )
-
- assert (
- str(excinfo.value)
- == "loop_dummy_relation_name is set, but none_label is not set: None"
- )
- elif none_label is not None:
- none_id = label2id[none_label]
- with pytest.raises(ValueError) as excinfo:
- encoder_decoder.decode([4, 5, 1, none_id, none_id, none_id, none_id])
- assert (
- str(excinfo.value)
- == "loop_dummy_relation_name is not set, but none_label=N was found in decoded encoding: "
- "[4, 5, 1, 2, 2, 2, 2] (label2id: {'A': 0, 'B': 1, 'N': 2}))"
- )
- else:
- raise ValueError("unknown setting")
-
-
-@pytest.mark.parametrize(
- "loop_dummy_relation_name,none_label",
- [(None, None), ("L", "N")],
-)
-def test_binary_relation_encoder_decoder_unknown_mode(loop_dummy_relation_name, none_label):
- """Test the BinaryRelationEncoderDecoder class."""
-
- label2id = {"A": 0, "B": 1, "N": 2, "L": 3}
- labeled_span_encoder_decoder = LabeledSpanEncoderDecoder(
- span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)),
- label2id=label2id,
- mode="indices_label",
- )
- encoder_decoder = BinaryRelationEncoderDecoder(
- head_encoder_decoder=labeled_span_encoder_decoder,
- tail_encoder_decoder=labeled_span_encoder_decoder,
- label2id=label2id,
- mode="unknown",
- loop_dummy_relation_name=loop_dummy_relation_name,
- none_label=none_label,
- )
- with pytest.raises(ValueError) as excinfo:
- encoder_decoder.encode(
- BinaryRelation(
- head=LabeledSpan(start=1, end=2, label="A"),
- tail=LabeledSpan(start=1, end=2, label="A"),
- label="L",
- )
- )
- assert str(excinfo.value) == "unknown mode: unknown"
-
- with pytest.raises(ValueError) as excinfo:
- encoder_decoder.decode([2, 2, 2, 2, 2, 2, 2])
- assert str(excinfo.value) == "unknown mode: unknown"
-
-
-def test_binary_relation_encoder_decoder_wrong_encoding_size():
- """Test the BinaryRelationEncoderDecoder class."""
-
- label2id = {"A": 0, "B": 1, "C": 2}
- labeled_span_encoder_decoder = LabeledSpanEncoderDecoder(
- span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)),
- label2id=label2id,
- mode="indices_label",
- )
- encoder_decoder = BinaryRelationEncoderDecoder(
- head_encoder_decoder=labeled_span_encoder_decoder,
- tail_encoder_decoder=labeled_span_encoder_decoder,
- label2id=label2id,
- mode="head_tail_label",
- )
- with pytest.raises(DecodingLengthException) as excinfo:
- encoder_decoder.decode([1, 2, 3, 4, 5, 6])
- assert (
- str(excinfo.value)
- == "seven values are required to decode as BinaryRelation, but the encoding has length 6"
- )
- assert excinfo.value.identifier == "len"
-
- with pytest.raises(DecodingLengthException) as excinfo:
- encoder_decoder.decode([1, 2, 3, 4, 5, 6, 7, 8])
- assert (
- str(excinfo.value)
- == "seven values are required to decode as BinaryRelation, but the encoding has length 8"
- )
- assert excinfo.value.identifier == "len"
-
-
-def test_binary_relation_encoder_decoder_wrong_label_index():
- """Test the BinaryRelationEncoderDecoder class."""
-
- label2id = {"A": 0, "B": 1, "C": 2}
- labeled_span_encoder_decoder = LabeledSpanEncoderDecoder(
- span_encoder_decoder=SpanEncoderDecoderWithOffset(offset=len(label2id)),
- label2id=label2id,
- mode="indices_label",
- )
- encoder_decoder = BinaryRelationEncoderDecoder(
- head_encoder_decoder=labeled_span_encoder_decoder,
- tail_encoder_decoder=labeled_span_encoder_decoder,
- label2id=label2id,
- mode="head_tail_label",
- )
- with pytest.raises(DecodingLabelException) as excinfo:
- encoder_decoder.decode([1, 2, 3, 4, 5, 6, 7])
- assert str(excinfo.value) == "unknown label id: 7 (label2id: {'A': 0, 'B': 1, 'C': 2})"
- assert excinfo.value.identifier == "label"
diff --git a/tests/taskmodules/pointer_network/test_logits_processor.py b/tests/taskmodules/pointer_network/test_logits_processor.py
deleted file mode 100644
index 145fa1e4c..000000000
--- a/tests/taskmodules/pointer_network/test_logits_processor.py
+++ /dev/null
@@ -1,80 +0,0 @@
-import pytest
-import torch
-
-from pie_modules.taskmodules.pointer_network.logits_processor import (
- FinitizeLogitsProcessor,
- PrefixConstrainedLogitsProcessorWithMaximum,
-)
-
-
-def test_prefix_constrained_logits_processor_with_maximum():
- def allow_last_three(batch_id, sent, max_index):
- return list(range(max_index - 3, max_index))
-
- logits_processor = PrefixConstrainedLogitsProcessorWithMaximum(
- prefix_allowed_tokens_fn=allow_last_three, num_beams=1
- )
-
- input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7]]).to(dtype=torch.long)
- scores = torch.tensor([[0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0]]).to(dtype=torch.float)
- new_scores = logits_processor(input_ids, scores)
- assert new_scores.shape == scores.shape
- torch.testing.assert_close(
- new_scores,
- torch.tensor(
- [[-float("inf"), -float("inf"), -float("inf"), -float("inf"), 0.9, 0.9, 0.0]]
- ),
- )
-
-
-def test_prefix_constrained_logits_processor_with_maximum_with_inf_scores():
- def allow_last_three(batch_id, sent, max_index):
- return list(range(max_index - 3, max_index))
-
- logits_processor = PrefixConstrainedLogitsProcessorWithMaximum(
- prefix_allowed_tokens_fn=allow_last_three, num_beams=1
- )
- input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7]]).to(dtype=torch.long)
- scores_with_pos_inf = torch.tensor([[0.9, 0.9, float("inf"), 0.9, 0.9, 0.9, 0.0]]).to(
- dtype=torch.float
- )
- scores_with_neg_inf = torch.tensor([[0.9, 0.9, -float("inf"), 0.9, 0.9, 0.9, 0.0]]).to(
- dtype=torch.float
- )
-
- with pytest.raises(ValueError, match="scores contains ±inf or NaN"):
- logits_processor(input_ids, scores_with_pos_inf)
-
- with pytest.raises(ValueError, match="scores contains ±inf or NaN"):
- logits_processor(input_ids, scores_with_neg_inf)
-
-
-def test_prefix_constrained_logits_processor_with_maximum_without_allowed_tokens():
- def allow_no_tokens(batch_id, sent, max_index):
- return []
-
- logits_processor = PrefixConstrainedLogitsProcessorWithMaximum(
- prefix_allowed_tokens_fn=allow_no_tokens, num_beams=1
- )
-
- input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7]]).to(dtype=torch.long)
- scores = torch.tensor([[0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.0]]).to(dtype=torch.float)
-
- with pytest.raises(ValueError, match="No allowed token ids for batch_id"):
- logits_processor(input_ids, scores)
-
-
-def test_finitize_logits_processor():
- logits_processor = FinitizeLogitsProcessor()
-
- input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7]]).to(dtype=torch.long)
- scores = torch.tensor([[0.9, 0.9, float("inf"), 0.9, 0.9, -float("inf"), 0.0]]).to(
- dtype=torch.float
- )
- new_scores = logits_processor(input_ids, scores)
-
- assert new_scores.shape == scores.shape
- torch.testing.assert_close(
- new_scores,
- torch.tensor([[0.9, 0.9, 3.4028235e38, 0.9, 0.9, -3.4028235e38, 0.0]]),
- )
diff --git a/tests/taskmodules/test_cross_text_binary_coref.py b/tests/taskmodules/test_cross_text_binary_coref.py
deleted file mode 100644
index 065f9b519..000000000
--- a/tests/taskmodules/test_cross_text_binary_coref.py
+++ /dev/null
@@ -1,395 +0,0 @@
-import json
-from typing import Any, Dict, Union
-
-import pytest
-import torch.testing
-from pie_core.utils.dictionary import flatten_dict_s, list_of_dicts2dict_of_lists
-from torch import tensor
-from torchmetrics import Metric, MetricCollection
-
-from pie_modules.annotations import LabeledSpan
-from pie_modules.document.processing.text_pair import add_negative_coref_relations
-from pie_modules.documents import (
- BinaryCorefRelation,
- TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
-)
-from pie_modules.taskmodules import CrossTextBinaryCorefTaskModule
-from tests import FIXTURES_ROOT, _config_to_str
-
-TOKENIZER_NAME_OR_PATH = "bert-base-cased"
-DOC_IDX_WITH_TASK_ENCODINGS = 2
-
-CONFIGS = [
- {},
-]
-CONFIGS_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS}
-
-
-@pytest.fixture(scope="module", params=CONFIGS_DICT.keys())
-def config(request):
- return CONFIGS_DICT[request.param]
-
-
-@pytest.fixture(scope="module")
-def positive_documents():
- doc1 = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations(
- id="0", text="Entity A works at B.", text_pair="And she founded C."
- )
- doc1.labeled_spans.append(LabeledSpan(start=0, end=8, label="PERSON"))
- doc1.labeled_spans.append(LabeledSpan(start=18, end=19, label="COMPANY"))
- doc1.labeled_spans_pair.append(LabeledSpan(start=4, end=7, label="PERSON"))
- doc1.labeled_spans_pair.append(LabeledSpan(start=16, end=17, label="COMPANY"))
- doc1.binary_coref_relations.append(
- BinaryCorefRelation(head=doc1.labeled_spans[0], tail=doc1.labeled_spans_pair[0])
- )
-
- doc2 = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations(
- id="0", text="Bob loves his cat.", text_pair="She sleeps a lot."
- )
- doc2.labeled_spans.append(LabeledSpan(start=0, end=3, label="PERSON"))
- doc2.labeled_spans.append(LabeledSpan(start=10, end=17, label="ANIMAL"))
- doc2.labeled_spans_pair.append(LabeledSpan(start=0, end=3, label="ANIMAL"))
- doc2.binary_coref_relations.append(
- BinaryCorefRelation(head=doc2.labeled_spans[1], tail=doc2.labeled_spans_pair[0])
- )
-
- return [doc1, doc2]
-
-
-def test_positive_documents(positive_documents):
- assert len(positive_documents) == 2
- doc1, doc2 = positive_documents
- assert doc1.labeled_spans.resolve() == [("PERSON", "Entity A"), ("COMPANY", "B")]
- assert doc1.labeled_spans_pair.resolve() == [("PERSON", "she"), ("COMPANY", "C")]
- assert doc1.binary_coref_relations.resolve() == [
- ("coref", (("PERSON", "Entity A"), ("PERSON", "she")))
- ]
-
- assert doc2.labeled_spans.resolve() == [("PERSON", "Bob"), ("ANIMAL", "his cat")]
- assert doc2.labeled_spans_pair.resolve() == [("ANIMAL", "She")]
- assert doc2.binary_coref_relations.resolve() == [
- ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She")))
- ]
-
-
-@pytest.fixture(scope="module")
-def unprepared_taskmodule(config):
- taskmodule = CrossTextBinaryCorefTaskModule(
- tokenizer_name_or_path=TOKENIZER_NAME_OR_PATH, **config
- )
- assert not taskmodule.is_from_pretrained
-
- return taskmodule
-
-
-@pytest.fixture(scope="module")
-def taskmodule(unprepared_taskmodule, positive_documents):
- unprepared_taskmodule.prepare(positive_documents)
- return unprepared_taskmodule
-
-
-@pytest.fixture(scope="module")
-def documents_with_negatives(taskmodule, positive_documents):
- file_name = (
- FIXTURES_ROOT / "taskmodules" / "cross_text_binary_coref" / "documents_with_negatives.json"
- )
-
- # result = list(add_negative_relations(positive_documents))
- # result_json = [doc.asdict() for doc in result]
- # with open(file_name, "w") as f:
- # json.dump(result_json, f, indent=2)
-
- with open(file_name) as f:
- result_json = json.load(f)
- result = [
- TextPairDocumentWithLabeledSpansAndBinaryCorefRelations.fromdict(doc_json)
- for doc_json in result_json
- ]
-
- return result
-
-
-@pytest.fixture(scope="module")
-def task_encodings_without_target(taskmodule, documents_with_negatives):
- task_encodings = taskmodule.encode_input(documents_with_negatives[DOC_IDX_WITH_TASK_ENCODINGS])
- return task_encodings
-
-
-def test_encode_input(task_encodings_without_target, taskmodule):
- task_encodings = task_encodings_without_target
- convert_ids_to_tokens = taskmodule.tokenizer.convert_ids_to_tokens
-
- inputs_dict = list_of_dicts2dict_of_lists(
- [task_encoding.inputs for task_encoding in task_encodings]
- )
- tokens = [convert_ids_to_tokens(encoding["input_ids"]) for encoding in inputs_dict["encoding"]]
- tokens_pair = [
- convert_ids_to_tokens(encoding["input_ids"]) for encoding in inputs_dict["encoding_pair"]
- ]
- assert tokens == [
- ["[CLS]", "And", "she", "founded", "C", ".", "[SEP]"],
- ["[CLS]", "And", "she", "founded", "C", ".", "[SEP]"],
- ]
- assert tokens_pair == [
- ["[CLS]", "En", "##ti", "##ty", "A", "works", "at", "B", ".", "[SEP]"],
- ["[CLS]", "En", "##ti", "##ty", "A", "works", "at", "B", ".", "[SEP]"],
- ]
- span_tokens = [
- toks[start:end]
- for toks, start, end in zip(
- tokens, inputs_dict["pooler_start_indices"], inputs_dict["pooler_end_indices"]
- )
- ]
- span_tokens_pair = [
- toks[start:end]
- for toks, start, end in zip(
- tokens_pair,
- inputs_dict["pooler_pair_start_indices"],
- inputs_dict["pooler_pair_end_indices"],
- )
- ]
- assert span_tokens == [["she"], ["C"]]
- assert span_tokens_pair == [["En", "##ti", "##ty", "A"], ["B"]]
-
-
-def test_encode_target(task_encodings_without_target, taskmodule):
- targets = [
- taskmodule.encode_target(task_encoding) for task_encoding in task_encodings_without_target
- ]
- assert targets == [1.0, 0.0]
-
-
-def test_encode_with_collect_statistics(taskmodule, positive_documents):
- documents_with_negatives = add_negative_coref_relations(positive_documents)
- original_values = taskmodule.collect_statistics
- taskmodule.collect_statistics = True
- taskmodule.encode(documents_with_negatives, encode_target=True)
- statistics = taskmodule.get_statistics()
- taskmodule.collect_statistics = original_values
-
- assert statistics == {
- ("available", "coref"): 4,
- ("available", "no_relation"): 6,
- ("used", "coref"): 4,
- ("used", "no_relation"): 6,
- }
-
-
-def test_encode_with_windowing(documents_with_negatives):
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = CrossTextBinaryCorefTaskModule(
- tokenizer_name_or_path=tokenizer_name_or_path,
- max_window=4,
- collect_statistics=True,
- )
- assert not taskmodule.is_from_pretrained
- taskmodule.prepare(documents_with_negatives)
-
- assert len(documents_with_negatives) == 16
-
- task_encodings = taskmodule.encode(documents_with_negatives)
- statistics = taskmodule.get_statistics()
-
- assert statistics == {
- ("available", "coref"): 4,
- ("available", "no_relation"): 6,
- ("skipped_span_does_not_fit_into_window", "coref"): 2,
- ("skipped_span_does_not_fit_into_window", "no_relation"): 2,
- ("used", "coref"): 2,
- ("used", "no_relation"): 4,
- }
-
- assert len(task_encodings) == 6
- for task_encoding in task_encodings:
- for k, v in task_encoding.inputs["encoding"].items():
- assert len(v) <= taskmodule.max_window
- for k, v in task_encoding.inputs["encoding_pair"].items():
- assert len(v) <= taskmodule.max_window
-
-
-@pytest.fixture(scope="module")
-def task_encodings(taskmodule, documents_with_negatives):
- return taskmodule.encode(
- documents_with_negatives[DOC_IDX_WITH_TASK_ENCODINGS], encode_target=True
- )
-
-
-@pytest.fixture(scope="module")
-def batch(taskmodule, task_encodings):
- result = taskmodule.collate(task_encodings)
- return result
-
-
-def test_collate(batch, taskmodule):
- assert batch is not None
- inputs, targets = batch
- assert inputs is not None
- assert set(inputs) == {
- "pooler_end_indices",
- "encoding_pair",
- "pooler_pair_end_indices",
- "pooler_start_indices",
- "encoding",
- "pooler_pair_start_indices",
- }
- torch.testing.assert_close(
- inputs["encoding"]["input_ids"],
- torch.tensor(
- [[101, 1262, 1131, 1771, 140, 119, 102], [101, 1262, 1131, 1771, 140, 119, 102]]
- ),
- )
- torch.testing.assert_close(
- inputs["encoding"]["token_type_ids"], torch.zeros_like(inputs["encoding"]["input_ids"])
- )
- torch.testing.assert_close(
- inputs["encoding"]["attention_mask"], torch.ones_like(inputs["encoding"]["input_ids"])
- )
-
- torch.testing.assert_close(
- inputs["encoding_pair"]["input_ids"],
- torch.tensor(
- [
- [101, 13832, 3121, 2340, 138, 1759, 1120, 139, 119, 102],
- [101, 13832, 3121, 2340, 138, 1759, 1120, 139, 119, 102],
- ]
- ),
- )
- torch.testing.assert_close(
- inputs["encoding_pair"]["token_type_ids"],
- torch.zeros_like(inputs["encoding_pair"]["input_ids"]),
- )
- torch.testing.assert_close(
- inputs["encoding_pair"]["attention_mask"],
- torch.ones_like(inputs["encoding_pair"]["input_ids"]),
- )
-
- torch.testing.assert_close(inputs["pooler_start_indices"], torch.tensor([[2], [4]]))
- torch.testing.assert_close(inputs["pooler_end_indices"], torch.tensor([[3], [5]]))
- torch.testing.assert_close(inputs["pooler_pair_start_indices"], torch.tensor([[1], [7]]))
- torch.testing.assert_close(inputs["pooler_pair_end_indices"], torch.tensor([[5], [8]]))
-
- torch.testing.assert_close(targets, {"scores": torch.tensor([1.0, 0.0])})
-
-
-@pytest.fixture(scope="module")
-def unbatched_output(taskmodule):
- model_output = {
- "scores": torch.tensor([0.5338148474693298, 0.9866107940673828]),
- }
- return taskmodule.unbatch_output(model_output=model_output)
-
-
-def test_unbatch_output(unbatched_output, taskmodule):
- assert len(unbatched_output) == 2
- assert unbatched_output == [
- {"is_similar": False, "score": 0.5338148474693298},
- {"is_similar": True, "score": 0.9866107702255249},
- ]
-
-
-def test_create_annotation_from_output(taskmodule, task_encodings, unbatched_output):
- all_new_annotations = []
- for task_encoding, task_output in zip(task_encodings, unbatched_output):
- for new_annotation in taskmodule.create_annotations_from_output(
- task_encoding=task_encoding, task_output=task_output
- ):
- all_new_annotations.append(new_annotation)
- assert all(layer_name == "binary_coref_relations" for layer_name, ann in all_new_annotations)
- resolve_annotations_with_scores = [
- (round(ann.score, 4), ann.resolve()) for layer_name, ann in all_new_annotations
- ]
- assert resolve_annotations_with_scores == [
- (0.9866, ("coref", (("COMPANY", "C"), ("COMPANY", "B")))),
- ]
-
-
-def get_metric_state(metric_or_collection: Union[Metric, MetricCollection]) -> Dict[str, Any]:
- if isinstance(metric_or_collection, Metric):
- return flatten_dict_s(metric_or_collection.metric_state)
- elif isinstance(metric_or_collection, MetricCollection):
- return flatten_dict_s({k: get_metric_state(v) for k, v in metric_or_collection.items()})
- else:
- raise ValueError(f"unsupported type: {type(metric_or_collection)}")
-
-
-def test_configure_metric(taskmodule, batch):
- metric = taskmodule.configure_model_metric(stage="train")
-
- assert isinstance(metric, (Metric, MetricCollection))
- state = get_metric_state(metric)
- torch.testing.assert_close(
- state,
- {
- "continuous/auroc/preds": [],
- "continuous/auroc/target": [],
- "continuous/avg-P/preds": [],
- "continuous/avg-P/target": [],
- "continuous/f1/fn": tensor([0]),
- "continuous/f1/fp": tensor([0]),
- "continuous/f1/tn": tensor([0]),
- "continuous/f1/tp": tensor([0]),
- },
- )
-
- # targets = batch[1]
- targets = {
- "scores": torch.tensor([0.0, 1.0, 0.0, 0.0]),
- }
- metric.update(targets, targets)
-
- state = get_metric_state(metric)
- torch.testing.assert_close(
- state,
- {
- "continuous/auroc/preds": [tensor([0.0, 1.0, 0.0, 0.0])],
- "continuous/auroc/target": [tensor([0.0, 1.0, 0.0, 0.0])],
- "continuous/avg-P/preds": [tensor([0.0, 1.0, 0.0, 0.0])],
- "continuous/avg-P/target": [tensor([0.0, 1.0, 0.0, 0.0])],
- "continuous/f1/tp": tensor([1]),
- "continuous/f1/fp": tensor([0]),
- "continuous/f1/tn": tensor([3]),
- "continuous/f1/fn": tensor([0]),
- },
- )
-
- torch.testing.assert_close(
- metric.compute(),
- {"auroc": tensor(1.0), "avg-P": tensor(1.0), "f1": tensor(1.0)},
- )
-
- # torch.rand_like(targets)
- random_targets = {
- "scores": torch.tensor([0.2703, 0.6812, 0.2582, 0.9030]),
- }
- metric.update(random_targets, targets)
- state = get_metric_state(metric)
- torch.testing.assert_close(
- state,
- {
- "continuous/auroc/preds": [
- tensor([0.0, 1.0, 0.0, 0.0]),
- tensor([0.2703, 0.6812, 0.2582, 0.9030]),
- ],
- "continuous/auroc/target": [
- tensor([0.0, 1.0, 0.0, 0.0]),
- tensor([0.0, 1.0, 0.0, 0.0]),
- ],
- "continuous/avg-P/preds": [
- tensor([0.0, 1.0, 0.0, 0.0]),
- tensor([0.2703, 0.6812, 0.2582, 0.9030]),
- ],
- "continuous/avg-P/target": [
- tensor([0.0, 1.0, 0.0, 0.0]),
- tensor([0.0, 1.0, 0.0, 0.0]),
- ],
- "continuous/f1/tp": tensor([1]),
- "continuous/f1/fp": tensor([1]),
- "continuous/f1/tn": tensor([5]),
- "continuous/f1/fn": tensor([1]),
- },
- )
-
- torch.testing.assert_close(
- metric.compute(),
- {"auroc": tensor(0.91666663), "avg-P": tensor(0.83333337), "f1": tensor(0.50000000)},
- )
diff --git a/tests/taskmodules/test_extractive_question_answering.py b/tests/taskmodules/test_extractive_question_answering.py
deleted file mode 100644
index d38e532d6..000000000
--- a/tests/taskmodules/test_extractive_question_answering.py
+++ /dev/null
@@ -1,277 +0,0 @@
-import pytest
-import torch
-import transformers
-from pie_core import AnnotationLayer
-
-from pie_modules.annotations import ExtractiveAnswer, Question
-from pie_modules.documents import TextDocumentWithQuestionsAndExtractiveAnswers
-from pie_modules.taskmodules.extractive_question_answering import (
- ExtractiveQuestionAnsweringTaskModule,
-)
-
-
-@pytest.fixture()
-def document():
- document = TextDocumentWithQuestionsAndExtractiveAnswers(
- text="This is a test document", id="doc0"
- )
- document.questions.append(Question(text="What is the first word?"))
- document.answers.append(ExtractiveAnswer(question=document.questions[0], start=0, end=4))
- assert str(document.answers[0]) == "This"
- return document
-
-
-@pytest.fixture()
-def document1():
- document1 = TextDocumentWithQuestionsAndExtractiveAnswers(
- text="This is the second document", id="doc1"
- )
- document1.questions.append(Question(text="Which document is this?"))
- document1.answers.append(ExtractiveAnswer(question=document1.questions[0], start=13, end=18))
- assert str(document1.answers[0]) == "second"
- return document1
-
-
-@pytest.fixture()
-def document_with_no_answer():
- document = TextDocumentWithQuestionsAndExtractiveAnswers(
- text="This is a test document", id="document_with_no_answer"
- )
- document.questions.append(Question(text="What is the first word?"))
- return document
-
-
-@pytest.fixture()
-def document_with_multiple_answers():
- document = TextDocumentWithQuestionsAndExtractiveAnswers(
- text="This is a test document", id="document_with_multiple_answers"
- )
- document.questions.append(Question(text="What is the first word?"))
- document.answers.append(ExtractiveAnswer(question=document.questions[0], start=0, end=4))
- assert str(document.answers[0]) == "This"
- document.answers.append(ExtractiveAnswer(question=document.questions[0], start=0, end=7))
- assert str(document.answers[1]) == "This is"
- return document
-
-
-@pytest.fixture()
-def taskmodule():
- return ExtractiveQuestionAnsweringTaskModule(
- tokenizer_name_or_path="bert-base-uncased", max_length=128
- )
-
-
-def test_encode_input(
- taskmodule, document, document_with_no_answer, document_with_multiple_answers
-):
- inputs = taskmodule.encode_input(document)
- assert inputs is not None
- assert len(inputs) == 1
- expected_inputs = [
- 101,
- 2054,
- 2003,
- 1996,
- 2034,
- 2773,
- 1029,
- 102,
- 2023,
- 2003,
- 1037,
- 3231,
- 6254,
- 102,
- ]
- assert inputs[0].inputs == expected_inputs
-
- inputs = taskmodule.encode_input(document_with_no_answer)
- assert inputs is not None
- assert len(inputs) == 1
- assert inputs[0].inputs == expected_inputs
-
- inputs = taskmodule.encode_input(document_with_multiple_answers)
- assert inputs is not None
- assert len(inputs) == 1
- assert inputs[0].inputs == expected_inputs
-
-
-def test_encode_target(taskmodule, document, document_with_no_answer):
- inputs = taskmodule.encode_input(document)
- targets = taskmodule.encode_target(inputs[0])
- assert targets is not None
- assert targets.start_position == 8
- assert targets.end_position == 8
-
- inputs = taskmodule.encode_input(document_with_no_answer)
- targets = taskmodule.encode_target(inputs[0])
- assert targets is not None
- assert targets.start_position == 0
- assert targets.end_position == 0
-
-
-def test_get_question_layer(taskmodule, document, document_with_no_answer):
- question_layer = taskmodule.get_question_layer(document)
- assert question_layer is not None
- assert len(question_layer) == 1
- assert type(question_layer) is AnnotationLayer
- assert type(question_layer[0]) is Question
- assert question_layer[0].text == "What is the first word?"
-
- question_layer = taskmodule.get_question_layer(document_with_no_answer)
- assert question_layer is not None
- assert len(question_layer) == 1
- assert type(question_layer) is AnnotationLayer
- assert type(question_layer[0]) is Question
- assert question_layer[0].text == "What is the first word?"
-
-
-def test_get_answer_layer(taskmodule, document, document_with_no_answer):
- answer_layer = taskmodule.get_answer_layer(document)
- assert answer_layer is not None
- assert len(answer_layer) == 1
- assert type(answer_layer) is AnnotationLayer
- assert type(answer_layer[0]) is ExtractiveAnswer
- assert answer_layer[0].question.text == "What is the first word?"
- assert answer_layer[0].start == 0
- assert answer_layer[0].end == 4
-
- answer_layer = taskmodule.get_answer_layer(document_with_no_answer)
- assert answer_layer is not None
- assert len(answer_layer) == 0
- assert type(answer_layer) is AnnotationLayer
-
-
-def test_get_context(taskmodule, document, document_with_no_answer):
- context = taskmodule.get_context(document)
- assert context is not None
- assert context == "This is a test document"
-
- context = taskmodule.get_context(document_with_no_answer)
- assert context is not None
- assert context == "This is a test document"
-
-
-@pytest.fixture()
-def documents(document, document_with_no_answer):
- return [document, document_with_no_answer]
-
-
-@pytest.fixture()
-def batch_without_targets(taskmodule, documents):
- task_encodings = taskmodule.encode(documents)
- batch_encoding = taskmodule.collate(task_encodings)
- return batch_encoding
-
-
-def test_collate_without_targets(batch_without_targets):
- assert batch_without_targets is not None
- assert len(batch_without_targets) == 2
- inputs, targets = batch_without_targets
- assert inputs is not None
- assert targets is None
-
-
-@pytest.fixture()
-def task_encodings(taskmodule, documents):
- task_encodings = taskmodule.encode(documents, encode_target=True)
- return task_encodings
-
-
-@pytest.fixture()
-def batch(taskmodule, task_encodings):
- batch_encoding = taskmodule.collate(task_encodings)
- return batch_encoding
-
-
-def test_collate_with_targets(batch):
- assert batch is not None
- assert len(batch) == 2
- inputs, targets = batch
- assert inputs is not None
- assert set(inputs.data) == {"input_ids", "token_type_ids", "attention_mask"}
- assert inputs.data["input_ids"].shape == (2, 14)
- assert inputs.data["token_type_ids"].shape == (2, 14)
- assert inputs.data["attention_mask"].shape == (2, 14)
- assert targets is not None
- assert set(targets) == {"start_positions", "end_positions"}
- assert targets["start_positions"].shape == (2,)
- assert targets["end_positions"].shape == (2,)
-
- expected_inputs_ids = [
- [101, 2054, 2003, 1996, 2034, 2773, 1029, 102, 2023, 2003, 1037, 3231, 6254, 102],
- [101, 2054, 2003, 1996, 2034, 2773, 1029, 102, 2023, 2003, 1037, 3231, 6254, 102],
- ]
- expected_token_type_ids = [
- [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
- [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
- ]
- expected_attention_mask = [
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- ]
- assert inputs.data["input_ids"].tolist() == expected_inputs_ids
- assert inputs.data["token_type_ids"].tolist() == expected_token_type_ids
- assert inputs.data["attention_mask"].tolist() == expected_attention_mask
-
- expected_start_positions = [8, 0]
- expected_end_positions = [8, 0]
- assert targets["start_positions"].tolist() == expected_start_positions
- assert targets["end_positions"].tolist() == expected_end_positions
-
-
-@pytest.fixture()
-def model_outputs(batch):
- # create probabilities that "perfectly" model the batch targets
- inputs, targets = batch
- start_probs = torch.zeros_like(inputs.input_ids, dtype=torch.float) + 0.05
- end_probs = torch.zeros_like(inputs.input_ids, dtype=torch.float) + 0.05
- # set target positions to 0.95 as a dummy value
- for idx, (start_position, end_position) in enumerate(
- zip(targets["start_positions"], targets["end_positions"])
- ):
- start_probs[idx, start_position] = 0.95
- end_probs[idx, end_position] = 0.95
-
- # convert probs to logits
- start_logits = torch.log(start_probs / (1 - start_probs))
- end_logits = torch.log(end_probs / (1 - end_probs))
-
- model_outputs = transformers.modeling_outputs.QuestionAnsweringModelOutput(
- start_logits=start_logits,
- end_logits=end_logits,
- )
- return model_outputs
-
-
-@pytest.fixture()
-def unbatched_output(taskmodule, model_outputs):
- return taskmodule.unbatch_output(model_outputs)
-
-
-def test_unbatch_output(unbatched_output):
- assert unbatched_output is not None
- assert len(unbatched_output) == 2
- # check first result
- assert unbatched_output[0].start == 8
- assert unbatched_output[0].end == 8
- assert unbatched_output[0].start_probability == pytest.approx(0.9652407)
- assert unbatched_output[0].end_probability == pytest.approx(0.9652407)
- # check second result
- assert unbatched_output[1].start == 0
- assert unbatched_output[1].end == 0
- assert unbatched_output[1].start_probability == pytest.approx(0.9652407)
- assert unbatched_output[1].end_probability == pytest.approx(0.9652407)
-
-
-def test_create_annotations_from_output(taskmodule, task_encodings, unbatched_output, documents):
- taskmodule.combine_outputs(task_encodings, unbatched_output)
- assert len(documents) > 0
- for doc in documents:
- gold_annotations = doc.answers
- predicted_annotations = doc.answers.predictions
- assert len(predicted_annotations) == len(gold_annotations)
- for predicted_annotation, gold_annotation in zip(predicted_annotations, gold_annotations):
- # we did construct the predicted annotations from the gold annotations, so they should be equal
- assert predicted_annotation == gold_annotation
- assert predicted_annotation.score == pytest.approx(0.9316896200180054)
diff --git a/tests/taskmodules/test_labeled_span_extraction_by_token_classification.py b/tests/taskmodules/test_labeled_span_extraction_by_token_classification.py
deleted file mode 100644
index b05374a9b..000000000
--- a/tests/taskmodules/test_labeled_span_extraction_by_token_classification.py
+++ /dev/null
@@ -1,883 +0,0 @@
-import logging
-import pickle
-from collections import defaultdict
-from dataclasses import dataclass
-from typing import Any, Dict, List
-
-import pytest
-import torch
-from pie_core import AnnotationLayer, annotation_field
-from torch import tensor
-from transformers import BatchEncoding
-
-from pie_modules.annotations import LabeledSpan
-from pie_modules.documents import (
- TextBasedDocument,
- TextDocumentWithLabeledSpans,
- TextDocumentWithLabeledSpansAndLabeledPartitions,
-)
-from pie_modules.taskmodules import LabeledSpanExtractionByTokenClassificationTaskModule
-from pie_modules.taskmodules.labeled_span_extraction_by_token_classification import (
- ModelOutputType,
-)
-
-
-def _config_to_str(cfg: Dict[str, Any]) -> str:
- # Converts a configuration dictionary to a string representation
- result = "-".join([f"{k}={cfg[k]}" for k in sorted(cfg)])
- return result
-
-
-CONFIG_DEFAULT = {}
-CONFIG_MAX_WINDOW = {
- "tokenize_kwargs": {"max_length": 8, "truncation": True, "return_overflowing_tokens": True}
-}
-CONFIG_MAX_WINDOW_WITH_STRIDE = {
- "tokenize_kwargs": {
- "max_length": 8,
- "stride": 2,
- "truncation": True,
- "return_overflowing_tokens": True,
- }
-}
-CONFIG_PARTITIONS = {"partition_annotation": "sentences"}
-
-CONFIGS: List[Dict[str, Any]] = [
- CONFIG_DEFAULT,
- CONFIG_MAX_WINDOW,
- CONFIG_MAX_WINDOW_WITH_STRIDE,
- CONFIG_PARTITIONS,
-]
-
-CONFIGS_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS}
-
-
-@pytest.fixture(scope="module", params=CONFIGS_DICT.keys())
-def config(request):
- """
- - Provides clean and readable test configurations.
- - Yields config dictionaries from the CONFIGS list to produce clean test case identifiers.
-
- """
- return CONFIGS_DICT[request.param]
-
-
-@pytest.fixture(scope="module")
-def config_str(config):
- # Fixture returning a string representation of the config
- return _config_to_str(config)
-
-
-@pytest.fixture(scope="module")
-def unprepared_taskmodule(config):
- """
- - Prepares a task module with the specified tokenizer and configuration.
- - Sets up the task module with a unprepared state for testing purposes.
-
- """
- return LabeledSpanExtractionByTokenClassificationTaskModule(
- tokenizer_name_or_path="bert-base-uncased", span_annotation="entities", **config
- )
-
-
-@dataclass
-class ExampleDocument(TextBasedDocument):
- entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
- sentences: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
-
-
-@pytest.fixture(scope="module")
-def documents():
- """
- - Creates example documents with predefined texts.
- - Assigns labels to the documents for testing purposes.
-
- """
- doc1 = ExampleDocument(text="Mount Everest is the highest peak in the world.", id="doc1")
- doc1.entities.append(LabeledSpan(start=0, end=13, label="LOC"))
- assert str(doc1.entities[0]) == "Mount Everest"
-
- doc2 = ExampleDocument(text="Alice loves reading books. Bob enjoys playing soccer.", id="doc2")
- doc2.entities.append(LabeledSpan(start=0, end=5, label="PER"))
- assert str(doc2.entities[0]) == "Alice"
- doc2.entities.append(LabeledSpan(start=27, end=30, label="PER"))
- assert str(doc2.entities[1]) == "Bob"
- # we add just one sentence to doc2 that covers only Bob
- doc2.sentences.append(LabeledSpan(start=27, end=53, label="sentence"))
- assert str(doc2.sentences[0]) == "Bob enjoys playing soccer."
-
- return [doc1, doc2]
-
-
-def test_taskmodule(unprepared_taskmodule):
- assert unprepared_taskmodule is not None
-
-
-@pytest.fixture(scope="module")
-def taskmodule(unprepared_taskmodule, documents):
- """
- - Prepares the task module with the given documents, i.e. collect available label values.
- - Calls the necessary methods to prepare the task module with the documents.
- - Calls _prepare(documents) and then _post_prepare()
-
- """
- unprepared_taskmodule.prepare(documents)
- return unprepared_taskmodule
-
-
-def test_prepare(taskmodule):
- assert taskmodule is not None
- assert taskmodule.is_prepared
- assert taskmodule.label_to_id == {"B-LOC": 1, "B-PER": 3, "I-LOC": 2, "I-PER": 4, "O": 0}
- assert taskmodule.id_to_label == {0: "O", 1: "B-LOC", 2: "I-LOC", 3: "B-PER", 4: "I-PER"}
-
-
-def test_config(taskmodule):
- config = taskmodule._config()
- assert config["taskmodule_type"] == "LabeledSpanExtractionByTokenClassificationTaskModule"
- assert "labels" in config
- assert config["labels"] == ["LOC", "PER"]
-
-
-@pytest.fixture(scope="module")
-def task_encodings_without_targets(taskmodule, documents):
- """
- - Generates task encodings for all the documents, but without associated targets.
- """
- return taskmodule.encode(documents, encode_target=False)
-
-
-def test_task_encodings_without_targets(task_encodings_without_targets, taskmodule, config):
- tokens = [
- taskmodule.tokenizer.convert_ids_to_tokens(task_encoding.inputs.ids)
- for task_encoding in task_encodings_without_targets
- ]
-
- # If config is empty
- if config == CONFIG_DEFAULT:
- assert tokens == [
- [
- "[CLS]",
- "mount",
- "everest",
- "is",
- "the",
- "highest",
- "peak",
- "in",
- "the",
- "world",
- ".",
- "[SEP]",
- ],
- [
- "[CLS]",
- "alice",
- "loves",
- "reading",
- "books",
- ".",
- "bob",
- "enjoys",
- "playing",
- "soccer",
- ".",
- "[SEP]",
- ],
- ]
-
- # If config has the specified values (max_window=8, window_overlap=2)
- elif config == CONFIG_MAX_WINDOW_WITH_STRIDE:
- for t in tokens:
- assert len(t) <= 8
-
- assert tokens == [
- ["[CLS]", "mount", "everest", "is", "the", "highest", "peak", "[SEP]"],
- ["[CLS]", "highest", "peak", "in", "the", "world", ".", "[SEP]"],
- ["[CLS]", "alice", "loves", "reading", "books", ".", "bob", "[SEP]"],
- ["[CLS]", ".", "bob", "enjoys", "playing", "soccer", ".", "[SEP]"],
- ]
-
- # If config has the specified value (max_window=8)
- elif config == CONFIG_MAX_WINDOW:
- for t in tokens:
- assert len(t) <= 8
-
- assert tokens == [
- ["[CLS]", "mount", "everest", "is", "the", "highest", "peak", "[SEP]"],
- ["[CLS]", "in", "the", "world", ".", "[SEP]"],
- ["[CLS]", "alice", "loves", "reading", "books", ".", "bob", "[SEP]"],
- ["[CLS]", "enjoys", "playing", "soccer", ".", "[SEP]"],
- ]
-
- # If config has the specified value (partition_annotation=sentences)
- elif config == CONFIG_PARTITIONS:
- assert tokens
-
- else:
- raise ValueError(f"unknown config: {config}")
-
-
-@pytest.fixture(scope="module")
-def task_encodings(taskmodule, documents):
- return taskmodule.encode(documents, encode_target=True)
-
-
-def test_task_encodings(task_encodings, taskmodule, config):
- tokens = [
- taskmodule.tokenizer.convert_ids_to_tokens(task_encoding.inputs.ids)
- for task_encoding in task_encodings
- ]
- labels_tokens = [
- [taskmodule.id_to_label[x] if x != -100 else "" for x in task_encoding.targets]
- for task_encoding in task_encodings
- ]
- assert len(labels_tokens) == len(tokens)
-
- tokens_with_labels = list(zip(tokens, labels_tokens))
-
- for tokens, labels in tokens_with_labels:
- assert len(tokens) == len(labels)
-
- # If config is empty
- if config == CONFIG_DEFAULT:
- assert tokens_with_labels == [
- (
- [
- "[CLS]",
- "mount",
- "everest",
- "is",
- "the",
- "highest",
- "peak",
- "in",
- "the",
- "world",
- ".",
- "[SEP]",
- ],
- ["", "B-LOC", "I-LOC", "O", "O", "O", "O", "O", "O", "O", "O", ""],
- ),
- (
- [
- "[CLS]",
- "alice",
- "loves",
- "reading",
- "books",
- ".",
- "bob",
- "enjoys",
- "playing",
- "soccer",
- ".",
- "[SEP]",
- ],
- ["", "B-PER", "O", "O", "O", "O", "B-PER", "O", "O", "O", "O", ""],
- ),
- ]
-
- # If config has the specified values (max_window=8, window_overlap=2)
- elif config == CONFIG_MAX_WINDOW_WITH_STRIDE:
- for tokens, labels in tokens_with_labels:
- assert len(tokens) <= 8
-
- assert tokens_with_labels == [
- (
- ["[CLS]", "mount", "everest", "is", "the", "highest", "peak", "[SEP]"],
- ["", "B-LOC", "I-LOC", "O", "O", "O", "O", ""],
- ),
- (
- ["[CLS]", "highest", "peak", "in", "the", "world", ".", "[SEP]"],
- ["", "O", "O", "O", "O", "O", "O", ""],
- ),
- (
- ["[CLS]", "alice", "loves", "reading", "books", ".", "bob", "[SEP]"],
- ["", "B-PER", "O", "O", "O", "O", "B-PER", ""],
- ),
- (
- ["[CLS]", ".", "bob", "enjoys", "playing", "soccer", ".", "[SEP]"],
- ["", "O", "B-PER", "O", "O", "O", "O", ""],
- ),
- ]
-
- # If config has the specified value (max_window=8)
- elif config == CONFIG_MAX_WINDOW:
- for tokens, labels in tokens_with_labels:
- assert len(tokens) <= 8
-
- assert tokens_with_labels == [
- (
- ["[CLS]", "mount", "everest", "is", "the", "highest", "peak", "[SEP]"],
- ["", "B-LOC", "I-LOC", "O", "O", "O", "O", ""],
- ),
- (
- ["[CLS]", "in", "the", "world", ".", "[SEP]"],
- ["", "O", "O", "O", "O", ""],
- ),
- (
- ["[CLS]", "alice", "loves", "reading", "books", ".", "bob", "[SEP]"],
- ["", "B-PER", "O", "O", "O", "O", "B-PER", ""],
- ),
- (
- ["[CLS]", "enjoys", "playing", "soccer", ".", "[SEP]"],
- ["", "O", "O", "O", "O", ""],
- ),
- ]
-
- # If config has the specified value (partition_annotation=sentences)
- elif config == CONFIG_PARTITIONS:
- assert tokens_with_labels == [
- (
- ["[CLS]", "bob", "enjoys", "playing", "soccer", ".", "[SEP]"],
- ["", "B-PER", "O", "O", "O", "O", ""],
- )
- ]
-
- else:
- raise ValueError(f"unknown config: {config}")
-
-
-def test_encode_targets_with_overlap(caplog):
- # setup taskmodule
- taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule(
- tokenizer_name_or_path="bert-base-uncased", labels=["LOC", "PER"]
- )
- taskmodule.post_prepare()
-
- # create a document with overlapping entities
- doc = TextDocumentWithLabeledSpans(
- text="Alice loves reading books. Bob enjoys playing soccer."
- )
- doc.labeled_spans.append(LabeledSpan(start=0, end=5, label="PER"))
- doc.labeled_spans.append(LabeledSpan(start=27, end=30, label="PER"))
- doc.labeled_spans.append(LabeledSpan(start=27, end=37, label="PER"))
- assert str(doc.labeled_spans[0]) == "Alice"
- assert str(doc.labeled_spans[1]) == "Bob"
- assert str(doc.labeled_spans[2]) == "Bob enjoys"
-
- # encode the document
- with caplog.at_level(logging.WARNING):
- task_encodings = taskmodule.encode([doc], encode_target=True)
- assert len(caplog.records) == 1
- assert (
- caplog.messages[0]
- == "tag already assigned (current span has an overlap: ('bob', 'enjoys'))."
- )
- assert len(task_encodings) == 1
- assert task_encodings[0].targets == [-100, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0, -100]
-
-
-@pytest.fixture(scope="module")
-def task_encodings_for_batch(task_encodings, config):
- # just take everything we have
- return task_encodings
-
-
-@pytest.fixture(scope="module")
-def batch(taskmodule, task_encodings_for_batch, config) -> BatchEncoding:
- return taskmodule.collate(task_encodings_for_batch)
-
-
-def test_collate(batch, config):
- assert batch is not None
- assert len(batch) == 2
- inputs, targets = batch
-
- assert set(inputs.data) == {"input_ids", "attention_mask", "special_tokens_mask"}
- input_ids_list = inputs.input_ids.tolist()
- attention_mask_list = inputs.attention_mask.tolist()
- special_tokens_mask_list = inputs.special_tokens_mask.tolist()
- assert set(targets) == {"labels"}
- labels_list = targets["labels"].tolist()
-
- # If config is empty
- if config == CONFIG_DEFAULT:
- assert input_ids_list == [
- [101, 4057, 23914, 2003, 1996, 3284, 4672, 1999, 1996, 2088, 1012, 102],
- [101, 5650, 7459, 3752, 2808, 1012, 3960, 15646, 2652, 4715, 1012, 102],
- ]
- assert attention_mask_list == [
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- ]
- assert labels_list == [
- [-100, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, -100],
- [-100, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0, -100],
- ]
- assert special_tokens_mask_list == [
- [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
- [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
- ]
-
- # If config has the specified values (max_window=8, window_overlap=2)
- elif config == CONFIG_MAX_WINDOW_WITH_STRIDE:
- assert input_ids_list == [
- [101, 4057, 23914, 2003, 1996, 3284, 4672, 102],
- [101, 3284, 4672, 1999, 1996, 2088, 1012, 102],
- [101, 5650, 7459, 3752, 2808, 1012, 3960, 102],
- [101, 1012, 3960, 15646, 2652, 4715, 1012, 102],
- ]
- assert attention_mask_list == [
- [1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1, 1],
- ]
- assert labels_list == [
- [-100, 1, 2, 0, 0, 0, 0, -100],
- [-100, 0, 0, 0, 0, 0, 0, -100],
- [-100, 3, 0, 0, 0, 0, 3, -100],
- [-100, 0, 3, 0, 0, 0, 0, -100],
- ]
-
- # If config has the specified values (max_window=8)
- elif config == CONFIG_MAX_WINDOW:
- assert input_ids_list == [
- [101, 4057, 23914, 2003, 1996, 3284, 4672, 102],
- [101, 1999, 1996, 2088, 1012, 102, 0, 0],
- [101, 5650, 7459, 3752, 2808, 1012, 3960, 102],
- [101, 15646, 2652, 4715, 1012, 102, 0, 0],
- ]
- assert attention_mask_list == [
- [1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 0, 0],
- [1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 0, 0],
- ]
- assert labels_list == [
- [-100, 1, 2, 0, 0, 0, 0, -100],
- [-100, 0, 0, 0, 0, -100, -100, -100],
- [-100, 3, 0, 0, 0, 0, 3, -100],
- [-100, 0, 0, 0, 0, -100, -100, -100],
- ]
- assert special_tokens_mask_list == [
- [1, 0, 0, 0, 0, 0, 0, 1],
- [1, 0, 0, 0, 0, 1, 1, 1],
- [1, 0, 0, 0, 0, 0, 0, 1],
- [1, 0, 0, 0, 0, 1, 1, 1],
- ]
-
- # If config has the specified value (partition_annotation=sentences)
- elif config == CONFIG_PARTITIONS:
- assert input_ids_list == [[101, 3960, 15646, 2652, 4715, 1012, 102]]
- assert attention_mask_list == [[1, 1, 1, 1, 1, 1, 1]]
- assert labels_list == [[-100, 3, 0, 0, 0, 0, -100]]
- assert special_tokens_mask_list == [[1, 0, 0, 0, 0, 0, 1]]
-
- else:
- raise ValueError(f"unknown config: {config}")
-
- inputs_expected = BatchEncoding(
- data={
- "input_ids": torch.tensor(input_ids_list, dtype=torch.int64),
- "attention_mask": torch.tensor(attention_mask_list, dtype=torch.int64),
- "special_tokens_mask": torch.tensor(special_tokens_mask_list, dtype=torch.int64),
- }
- )
- assert set(inputs.data) == set(inputs_expected.data)
- labels_expected = torch.tensor(labels_list, dtype=torch.int64)
- assert torch.equal(targets["labels"], labels_expected)
-
-
-# This is not used, but can be used to create a batch of task encodings with targets for the unbatched_outputs fixture.
-@pytest.fixture(scope="module")
-def real_model_output(batch, taskmodule):
- from pytorch_ie.models import TransformerTokenClassificationModel
-
- model = TransformerTokenClassificationModel(
- model_name_or_path="prajjwal1/bert-tiny",
- num_classes=len(taskmodule.label_to_id),
- )
- inputs, targets = batch
- result = model(inputs)
- return result
-
-
-@pytest.fixture(scope="module")
-def model_output(config, batch, taskmodule) -> ModelOutputType:
- # create "perfect" output from targets
- labels = batch[1]["labels"]
- num_classes = len(taskmodule.label_to_id)
- # create one-hot encoding from labels
- labels_valid = labels.clone()
- labels_valid[labels_valid == taskmodule.label_pad_id] = taskmodule.label_to_id["O"]
- # create one-hot encoding from labels, but with 0.9 for the correct labels
- probabilities = (
- torch.nn.functional.one_hot(labels_valid, num_classes=num_classes).to(torch.float32) * 0.9
- )
- return {"labels": labels, "probabilities": probabilities}
-
-
-@pytest.fixture(scope="module")
-def unbatched_outputs(taskmodule, model_output):
- return taskmodule.unbatch_output(model_output)
-
-
-@pytest.mark.parametrize("combine_token_scores_method", ["mean", "max", "product", "UNKNOWN"])
-def test_combine_token_scores_method(documents, combine_token_scores_method):
- taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule(
- tokenizer_name_or_path="bert-base-uncased",
- span_annotation="entities",
- combine_token_scores_method=combine_token_scores_method,
- )
- taskmodule.prepare(documents)
-
- task_encodings = taskmodule.encode(documents, encode_target=True)
- batch = taskmodule.collate(task_encodings)
-
- # create "perfect" output from targets
- labels = batch[1]["labels"]
- num_classes = len(taskmodule.label_to_id)
- # create one-hot encoding from labels
- labels_valid = labels.clone()
- labels_valid[labels_valid == taskmodule.label_pad_id] = taskmodule.label_to_id["O"]
- # create one-hot encoding from labels, but with 0.9 for the correct labels
- probabilities = (
- torch.nn.functional.one_hot(labels_valid, num_classes=num_classes).to(torch.float32) * 0.9
- )
- # stepwise decrease the "winning" probabilities per token to test the different combine_token_scores_methods
- diff = 0.0
- for i in range(probabilities.size(1)):
- probabilities[:, i] -= diff
- diff += 0.01
- probabilities[probabilities < 0] = 0.0
-
- model_output = {"labels": labels, "probabilities": probabilities}
-
- unbatched_outputs = taskmodule.unbatch_output(model_output)
-
- if combine_token_scores_method == "UNKNOWN":
- with pytest.raises(ValueError) as excinfo:
- taskmodule.decode_annotations(unbatched_outputs[0])
- assert str(excinfo.value) == "combine_token_scores_method=UNKNOWN is not supported."
- else:
- annotations = []
- scores = []
- for unbatched_output in unbatched_outputs:
- decoded_annotations = taskmodule.decode_annotations(unbatched_output)
- assert set(decoded_annotations.keys()) == {"labeled_spans"}
- # Sort the annotations in each document by start and end position and label
- sorted_annotations = sorted(decoded_annotations["labeled_spans"])
- annotations.append(sorted_annotations)
- scores.append([round(ann.score, 5) for ann in sorted_annotations])
-
- # input values are (before combination): [[0.89, 0.88], [[0.89], [0.84]]]
- if combine_token_scores_method == "mean":
- assert scores == [[(0.89 + 0.88) / 2], [0.89, 0.84]]
- elif combine_token_scores_method == "max":
- assert scores == [[0.89], [0.89, 0.84]]
- elif combine_token_scores_method == "min":
- assert scores == [[0.88], [0.89, 0.84]]
- elif combine_token_scores_method == "product":
- assert scores == [[0.89 * 0.88], [0.89, 0.84]]
- else:
- raise ValueError(f"unknown combine_token_scores_method: {combine_token_scores_method}")
-
-
-def test_unbatched_output(unbatched_outputs, config):
- assert unbatched_outputs is not None
-
- if config == CONFIG_DEFAULT:
- assert len(unbatched_outputs) == 2
- torch.testing.assert_close(
- unbatched_outputs[0]["labels"],
- torch.tensor([-100, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, -100]),
- )
- torch.testing.assert_close(
- unbatched_outputs[1]["labels"],
- torch.tensor([-100, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0, -100]),
- )
- elif config == CONFIG_MAX_WINDOW_WITH_STRIDE:
- assert len(unbatched_outputs) == 4
- torch.testing.assert_close(
- unbatched_outputs[0]["labels"], torch.tensor([-100, 1, 2, 0, 0, 0, 0, -100])
- )
- torch.testing.assert_close(
- unbatched_outputs[1]["labels"], torch.tensor([-100, 0, 0, 0, 0, 0, 0, -100])
- )
- torch.testing.assert_close(
- unbatched_outputs[2]["labels"], torch.tensor([-100, 3, 0, 0, 0, 0, 3, -100])
- )
- torch.testing.assert_close(
- unbatched_outputs[3]["labels"], torch.tensor([-100, 0, 3, 0, 0, 0, 0, -100])
- )
- elif config == CONFIG_MAX_WINDOW:
- assert len(unbatched_outputs) == 4
- torch.testing.assert_close(
- unbatched_outputs[0]["labels"], torch.tensor([-100, 1, 2, 0, 0, 0, 0, -100])
- )
- torch.testing.assert_close(
- unbatched_outputs[1]["labels"], torch.tensor([-100, 0, 0, 0, 0, -100, -100, -100])
- )
- torch.testing.assert_close(
- unbatched_outputs[2]["labels"], torch.tensor([-100, 3, 0, 0, 0, 0, 3, -100])
- )
- torch.testing.assert_close(
- unbatched_outputs[3]["labels"], torch.tensor([-100, 0, 0, 0, 0, -100, -100, -100])
- )
- elif config == CONFIG_PARTITIONS:
- assert len(unbatched_outputs) == 1
- torch.testing.assert_close(
- unbatched_outputs[0]["labels"], torch.tensor([-100, 3, 0, 0, 0, 0, -100])
- )
- else:
- raise ValueError(f"unknown config: {config}")
-
-
-def test_decode_annotations(taskmodule, unbatched_outputs, config):
- annotations = []
- for unbatched_output in unbatched_outputs:
- decoded_annotations = taskmodule.decode_annotations(unbatched_output)
- assert set(decoded_annotations.keys()) == {"labeled_spans"}
- # Sort the annotations in each document by start and end position and label
- annotations.append(
- sorted(
- decoded_annotations["labeled_spans"],
- key=lambda labeled_span: (
- labeled_span.start,
- labeled_span.end,
- labeled_span.label,
- ),
- )
- )
-
- # Check based on the config
- if config == CONFIG_DEFAULT:
- assert annotations == [
- [LabeledSpan(start=1, end=3, label="LOC")],
- [
- LabeledSpan(start=1, end=2, label="PER"),
- LabeledSpan(start=6, end=7, label="PER"),
- ],
- ]
-
- elif config == CONFIG_MAX_WINDOW_WITH_STRIDE:
- # We get two annotations for Bob because the window overlaps with the previous one.
- # This is not a problem because annotations get de-duplicated during serialization.
- assert annotations == [
- [LabeledSpan(start=1, end=3, label="LOC")],
- [],
- [
- LabeledSpan(start=1, end=2, label="PER"),
- LabeledSpan(start=6, end=7, label="PER"),
- ],
- [LabeledSpan(start=2, end=3, label="PER")],
- ]
-
- elif config == CONFIG_MAX_WINDOW:
- assert annotations == [
- [LabeledSpan(start=1, end=3, label="LOC")],
- [],
- [
- LabeledSpan(start=1, end=2, label="PER"),
- LabeledSpan(start=6, end=7, label="PER"),
- ],
- [],
- ]
-
- elif config == CONFIG_PARTITIONS:
- assert annotations == [[LabeledSpan(start=1, end=2, label="PER", score=1.0)]]
-
- else:
- raise ValueError(f"unknown config: {config}")
-
- # assert that all scores are 0.9
- for doc_annotations in annotations:
- for annotation in doc_annotations:
- assert round(annotation.score, 4) == 0.9
-
-
-@pytest.fixture(scope="module")
-def annotations_from_output(taskmodule, task_encodings_for_batch, unbatched_outputs, config):
- named_annotations_per_document = defaultdict(list)
- for task_encoding, task_output in zip(task_encodings_for_batch, unbatched_outputs):
- annotations = taskmodule.create_annotations_from_output(task_encoding, task_output)
- named_annotations_per_document[task_encoding.document.id].extend(list(annotations))
- return named_annotations_per_document
-
-
-def test_annotations_from_output(annotations_from_output, config, documents):
- assert annotations_from_output is not None
- # Sort the annotations in each document by start and end positions
- annotations_from_output = {
- doc_id: sorted(annotations, key=lambda x: (x[0], x[1].start, x[1].end))
- for doc_id, annotations in annotations_from_output.items()
- }
- documents_by_id = {doc.id: doc for doc in documents}
- documents_with_annotations = []
- resolved_annotations = defaultdict(list)
- # Check that the number of annotations is correct
- for doc_id, layer_names_and_annotations in annotations_from_output.items():
- new_doc = documents_by_id[doc_id].copy()
- for layer_name, annotation in layer_names_and_annotations:
- assert layer_name == "entities"
- assert isinstance(annotation, LabeledSpan)
- new_doc.entities.predictions.append(annotation)
- resolved_annotations[doc_id].append(str(annotation))
- documents_with_annotations.append(new_doc)
-
- resolved_annotations = dict(resolved_annotations)
- # Check based on the config
- if config == CONFIG_DEFAULT:
- assert resolved_annotations == {"doc1": ["Mount Everest"], "doc2": ["Alice", "Bob"]}
-
- elif config == CONFIG_MAX_WINDOW_WITH_STRIDE:
- # We get two annotations for Bob because the window overlaps with the previous one.
- # This is not a problem because annotations get de-duplicated during serialization.
- assert resolved_annotations == {"doc1": ["Mount Everest"], "doc2": ["Alice", "Bob", "Bob"]}
-
- elif config == CONFIG_MAX_WINDOW:
- assert resolved_annotations == {"doc1": ["Mount Everest"], "doc2": ["Alice", "Bob"]}
-
- elif config == CONFIG_PARTITIONS:
- assert resolved_annotations == {"doc2": ["Bob"]}
-
- else:
- raise ValueError(f"unknown config: {config}")
-
-
-def test_document_type():
- taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule(
- tokenizer_name_or_path="bert-base-uncased"
- )
- assert taskmodule.document_type == TextDocumentWithLabeledSpans
-
-
-def test_document_type_with_partitions():
- taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule(
- tokenizer_name_or_path="bert-base-uncased", partition_annotation="labeled_partitions"
- )
- assert taskmodule.document_type == TextDocumentWithLabeledSpansAndLabeledPartitions
-
-
-def test_document_type_with_non_default_span_annotation(caplog):
- with caplog.at_level(logging.WARNING):
- taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule(
- tokenizer_name_or_path="bert-base-uncased", span_annotation="entities"
- )
- assert taskmodule.document_type is None
- assert len(caplog.records) == 1
- assert caplog.records[0].levelname == "WARNING"
- assert (
- caplog.records[0].message
- == "span_annotation=entities is not the default value ('labeled_spans'), so the taskmodule "
- "LabeledSpanExtractionByTokenClassificationTaskModule can not request the usual document type "
- "(TextDocumentWithLabeledSpans) for auto-conversion because this has the bespoken default value "
- "as layer name(s) instead of the provided one(s)."
- )
-
-
-def test_document_type_with_non_default_partition_annotation(caplog):
- with caplog.at_level(logging.WARNING):
- taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule(
- tokenizer_name_or_path="bert-base-uncased", partition_annotation="sentences"
- )
- assert taskmodule.document_type is None
- assert len(caplog.records) == 1
- assert caplog.records[0].levelname == "WARNING"
- assert (
- caplog.records[0].message
- == "partition_annotation=sentences is not the default value ('labeled_partitions'), "
- "so the taskmodule LabeledSpanExtractionByTokenClassificationTaskModule can not request the usual document type "
- "(TextDocumentWithLabeledSpansAndLabeledPartitions) for auto-conversion because this has "
- "the bespoken default value as layer name(s) instead of the provided one(s)."
- )
-
-
-def test_document_type_with_non_default_span_and_partition_annotation(caplog):
- with caplog.at_level(logging.WARNING):
- taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule(
- tokenizer_name_or_path="bert-base-uncased",
- span_annotation="entities",
- partition_annotation="sentences",
- )
- assert taskmodule.document_type is None
- assert len(caplog.records) == 1
- assert caplog.records[0].levelname == "WARNING"
- assert (
- caplog.records[0].message
- == "span_annotation=entities is not the default value ('labeled_spans') and "
- "partition_annotation=sentences is not the default value ('labeled_partitions'), "
- "so the taskmodule LabeledSpanExtractionByTokenClassificationTaskModule can not request the usual document "
- "type (TextDocumentWithLabeledSpansAndLabeledPartitions) for auto-conversion because "
- "this has the bespoken default value as layer name(s) instead of the provided one(s)."
- )
-
-
-def test_configure_model_metric(documents):
- taskmodule = LabeledSpanExtractionByTokenClassificationTaskModule(
- tokenizer_name_or_path="bert-base-uncased",
- span_annotation="entities",
- labels=["LOC", "PER"],
- )
- taskmodule.post_prepare()
-
- metric = taskmodule.configure_model_metric(stage="test")
- values = metric.compute()
- assert values == {
- "token/macro/f1": tensor(0.0),
- "token/micro/f1": tensor(0.0),
- "token/macro/precision": tensor(0.0),
- "token/macro/recall": tensor(0.0),
- "token/micro/precision": tensor(0.0),
- "token/micro/recall": tensor(0.0),
- }
-
- batch = taskmodule.collate(taskmodule.encode(documents, encode_target=True))
- targets = batch[1]
- metric.update(targets, targets)
- values = metric.compute()
- assert values == {
- "span/LOC/f1": tensor(1.0),
- "span/LOC/precision": tensor(1.0),
- "span/LOC/recall": tensor(1.0),
- "span/PER/f1": tensor(1.0),
- "span/PER/precision": tensor(1.0),
- "span/PER/recall": tensor(1.0),
- "span/macro/f1": tensor(1.0),
- "span/macro/precision": tensor(1.0),
- "span/macro/recall": tensor(1.0),
- "span/micro/f1": tensor(1.0),
- "span/micro/precision": tensor(1.0),
- "span/micro/recall": tensor(1.0),
- "token/macro/f1": tensor(1.0),
- "token/micro/f1": tensor(1.0),
- "token/macro/precision": tensor(1.0),
- "token/macro/recall": tensor(1.0),
- "token/micro/precision": tensor(1.0),
- "token/micro/recall": tensor(1.0),
- }
-
- target_labels = targets["labels"]
- predicted_labels = torch.ones_like(target_labels)
- # we need to set the same padding as in the targets
- predicted_labels[target_labels == taskmodule.label_pad_id] = taskmodule.label_pad_id
- prediction = {"labels": predicted_labels}
- metric.update(prediction, targets)
- values = metric.compute()
- values_converted = {k: v.item() for k, v in values.items()}
- assert values_converted == {
- "token/macro/f1": 0.5434783101081848,
- "token/micro/f1": 0.5249999761581421,
- "token/macro/precision": 0.773809552192688,
- "token/macro/recall": 0.625,
- "token/micro/precision": 0.5249999761581421,
- "token/micro/recall": 0.5249999761581421,
- "span/LOC/recall": 0.0476190485060215,
- "span/LOC/precision": 0.5,
- "span/LOC/f1": 0.08695652335882187,
- "span/macro/f1": 0.37681159377098083,
- "span/macro/precision": 0.5,
- "span/macro/recall": 0.523809552192688,
- "span/micro/recall": 0.1304347813129425,
- "span/micro/precision": 0.5,
- "span/micro/f1": 0.2068965584039688,
- "span/PER/recall": 1.0,
- "span/PER/precision": 0.5,
- "span/PER/f1": 0.6666666865348816,
- }
-
- # ensure that the metric can be pickled
- pickle.dumps(metric)
diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py
deleted file mode 100644
index 5251b8ba8..000000000
--- a/tests/taskmodules/test_pointer_network_for_end2end_re.py
+++ /dev/null
@@ -1,1313 +0,0 @@
-import logging
-import pickle
-from dataclasses import asdict, dataclass
-from typing import Dict, List, Set
-
-import pytest
-import torch
-from pie_core import AnnotationLayer, Document, annotation_field
-from transformers import LogitsProcessorList
-
-from pie_modules.annotations import BinaryRelation, LabeledSpan
-from pie_modules.documents import TextBasedDocument
-from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
-from pie_modules.taskmodules.pointer_network.logits_processor import (
- FinitizeLogitsProcessor,
- PrefixConstrainedLogitsProcessorWithMaximum,
-)
-from pie_modules.taskmodules.pointer_network_for_end2end_re import (
- LabelsAndOptionalConstraints,
-)
-
-logger = logging.getLogger(__name__)
-
-DUMP_FIXTURE_DATA = False
-
-
-def _config_to_str(cfg: Dict[str, str]) -> str:
- result = "-".join([f"{k}={cfg[k]}" for k in sorted(cfg)])
- return result
-
-
-CONFIGS = [{}, {"partition_layer_name": "sentences"}]
-CONFIG_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS}
-
-
-@pytest.fixture(scope="module", params=CONFIG_DICT.keys())
-def config_str(request):
- return request.param
-
-
-@pytest.fixture(scope="module")
-def config(config_str):
- return CONFIG_DICT[config_str]
-
-
-@dataclass
-class ExampleDocument(TextBasedDocument):
- entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
- relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities")
- sentences: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
-
-
-@pytest.fixture(scope="module")
-def document():
- doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.")
- span1 = LabeledSpan(start=10, end=20, label="content")
- span2 = LabeledSpan(start=27, end=34, label="topic")
- span3 = LabeledSpan(start=42, end=44, label="person")
- doc.entities.extend([span1, span2, span3])
- assert str(span1) == "dummy text"
- assert str(span2) == "nothing"
- assert str(span3) == "me"
- rel = BinaryRelation(head=span1, tail=span2, label="is_about")
- doc.relations.append(rel)
- assert str(rel.label) == "is_about"
- assert str(rel.head) == "dummy text"
- assert str(rel.tail) == "nothing"
-
- no_rel = BinaryRelation(head=span1, tail=span3, label="no_relation")
- doc.relations.append(no_rel)
- assert str(no_rel.label) == "no_relation"
- assert str(no_rel.head) == "dummy text"
- assert str(no_rel.tail) == "me"
-
- sent1 = LabeledSpan(start=0, end=35, label="1")
- sent2 = LabeledSpan(start=36, end=45, label="2")
- doc.sentences.extend([sent1, sent2])
- assert str(sent1) == "This is a dummy text about nothing."
- assert str(sent2) == "Trust me."
- return doc
-
-
-def test_document(document):
- spans = document.entities
- assert len(spans) == 3
- assert (str(spans[0]), spans[0].label) == ("dummy text", "content")
- assert (str(spans[1]), spans[1].label) == ("nothing", "topic")
- assert (str(spans[2]), spans[2].label) == ("me", "person")
- relations = document.relations
- assert len(relations) == 2
- assert (str(relations[0].head), relations[0].label, str(relations[0].tail)) == (
- "dummy text",
- "is_about",
- "nothing",
- )
- assert (str(relations[1].head), relations[1].label, str(relations[1].tail)) == (
- "dummy text",
- "no_relation",
- "me",
- )
- sentences = document.sentences
- assert len(sentences) == 2
- assert str(sentences[0]) == "This is a dummy text about nothing."
- assert str(sentences[1]) == "Trust me."
-
-
-@pytest.fixture(scope="module")
-def taskmodule(document, config):
- taskmodule = PointerNetworkTaskModuleForEnd2EndRE(
- tokenizer_name_or_path="facebook/bart-base",
- relation_layer_name="relations",
- exclude_labels_per_layer={"relations": ["no_relation"]},
- annotation_field_mapping={
- "entities": "labeled_spans",
- "relations": "binary_relations",
- },
- create_constraints=True,
- tokenizer_kwargs={"strict_span_conversion": False},
- **config,
- )
-
- taskmodule.prepare(documents=[document])
- return taskmodule
-
-
-def test_taskmodule(taskmodule):
- assert taskmodule.is_prepared
- assert taskmodule.prepared_attributes == {
- "labels_per_layer": {
- "entities": ["content", "person", "topic"],
- "relations": ["is_about"],
- },
- }
- assert taskmodule.layer_names == ["entities", "relations"]
- assert taskmodule.special_targets == ["", ""]
- assert taskmodule.labels == ["none", "content", "person", "topic", "is_about"]
- assert taskmodule.targets == [
- "",
- "",
- "none",
- "content",
- "person",
- "topic",
- "is_about",
- ]
- assert taskmodule.bos_id == 0
- assert taskmodule.eos_id == 1
- assert taskmodule.none_id == 2
- assert taskmodule.span_ids == [3, 4, 5]
- assert taskmodule.relation_ids == [6]
- assert taskmodule.label2id == {
- "content": 3,
- "is_about": 6,
- "none": 2,
- "person": 4,
- "topic": 5,
- }
- assert taskmodule.label_embedding_weight_mapping == {
- 50265: [45260],
- 50266: [39763],
- 50267: [354, 1215, 9006],
- 50268: [5970],
- 50269: [10166],
- }
- assert taskmodule.target_tokens == [
- "",
- "",
- "<>",
- "<>",
- "<>",
- "<>",
- "<>",
- ]
- assert taskmodule.target_token_ids == [0, 2, 50266, 50269, 50268, 50265, 50267]
-
-
-def test_taskmodule_with_wrong_annotation_field_mapping():
- with pytest.raises(ValueError) as exc_info:
- PointerNetworkTaskModuleForEnd2EndRE(
- tokenizer_name_or_path="facebook/bart-base",
- relation_layer_name="relations",
- annotation_field_mapping={
- "entities": "labeled_spans",
- "sentences": "labeled_spans",
- },
- )
- assert str(exc_info.value) == (
- "inverted annotation_field_mapping is not unique. annotation_field_mapping: "
- "{'entities': 'labeled_spans', 'sentences': 'labeled_spans'}"
- )
-
-
-def test_prepared_config(taskmodule, config):
- if config == {}:
- assert taskmodule._config() == {
- "taskmodule_type": "PointerNetworkTaskModuleForEnd2EndRE",
- "relation_layer_name": "relations",
- "symmetric_relations": None,
- "none_label": "none",
- "loop_dummy_relation_name": "loop",
- "labels_per_layer": {
- "entities": ["content", "person", "topic"],
- "relations": ["is_about"],
- },
- "exclude_labels_per_layer": {"relations": ["no_relation"]},
- "create_constraints": True,
- "document_type": "pytorch_ie.documents.TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions",
- "tokenized_document_type": "pie_modules.documents.TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions",
- "tokenizer_name_or_path": "facebook/bart-base",
- "tokenizer_init_kwargs": None,
- "tokenizer_kwargs": {"strict_span_conversion": False},
- "partition_layer_name": None,
- "add_reversed_relations": False,
- "annotation_field_mapping": {
- "entities": "labeled_spans",
- "relations": "binary_relations",
- },
- "constrained_generation": False,
- "label_tokens": None,
- "label_representations": None,
- "log_first_n_examples": None,
- }
- elif config == {"partition_layer_name": "sentences"}:
- assert taskmodule._config() == {
- "taskmodule_type": "PointerNetworkTaskModuleForEnd2EndRE",
- "relation_layer_name": "relations",
- "symmetric_relations": None,
- "none_label": "none",
- "loop_dummy_relation_name": "loop",
- "labels_per_layer": {
- "entities": ["content", "person", "topic"],
- "relations": ["is_about"],
- },
- "exclude_labels_per_layer": {"relations": ["no_relation"]},
- "create_constraints": True,
- "document_type": "pytorch_ie.documents.TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions",
- "tokenized_document_type": "pie_modules.documents.TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions",
- "tokenizer_name_or_path": "facebook/bart-base",
- "tokenizer_init_kwargs": None,
- "tokenizer_kwargs": {"strict_span_conversion": False},
- "partition_layer_name": "sentences",
- "add_reversed_relations": False,
- "annotation_field_mapping": {
- "entities": "labeled_spans",
- "relations": "binary_relations",
- },
- "constrained_generation": False,
- "label_tokens": None,
- "label_representations": None,
- "log_first_n_examples": None,
- }
- else:
- raise Exception(f"unknown config: {config}")
-
-
-@pytest.fixture()
-def task_encoding_without_target(taskmodule, document):
- return taskmodule.encode_input(document)[0]
-
-
-def test_add_reversed_relation_labels():
- taskmodule = PointerNetworkTaskModuleForEnd2EndRE(
- tokenizer_name_or_path="facebook/bart-base",
- symmetric_relations=["symmetric_relation"],
- )
-
- labels = ["is_about", "symmetric_relation"]
- labels_with_reversed = taskmodule.add_reversed_relation_labels(labels)
- assert labels_with_reversed == {"is_about", "is_about_reversed", "symmetric_relation"}
-
-
-def test_reverse_relation():
- taskmodule = PointerNetworkTaskModuleForEnd2EndRE(
- tokenizer_name_or_path="facebook/bart-base",
- symmetric_relations=["symmetric_relation"],
- )
-
- rel = BinaryRelation(
- head=LabeledSpan(start=10, end=20, label="content"),
- tail=LabeledSpan(start=27, end=34, label="topic"),
- label="is_about",
- )
- reversed_relation = taskmodule.reverse_relation(relation=rel)
- assert reversed_relation == BinaryRelation(
- head=LabeledSpan(start=27, end=34, label="topic", score=1.0),
- tail=LabeledSpan(start=10, end=20, label="content", score=1.0),
- label="is_about_reversed",
- score=1.0,
- )
-
- sym_rel = BinaryRelation(
- head=LabeledSpan(start=10, end=20, label="content"),
- tail=LabeledSpan(start=27, end=34, label="topic"),
- label="symmetric_relation",
- )
- reversed_sym_rel = taskmodule.reverse_relation(relation=sym_rel)
- assert reversed_sym_rel == BinaryRelation(
- head=LabeledSpan(start=27, end=34, label="topic", score=1.0),
- tail=LabeledSpan(start=10, end=20, label="content", score=1.0),
- label="symmetric_relation",
- score=1.0,
- )
-
-
-def test_unreverse_relation():
- taskmodule = PointerNetworkTaskModuleForEnd2EndRE(
- tokenizer_name_or_path="facebook/bart-base",
- symmetric_relations=["symmetric_relation"],
- )
-
- # nothing should change because the relation is not reversed
- rel = BinaryRelation(
- head=LabeledSpan(start=10, end=20, label="content"),
- tail=LabeledSpan(start=27, end=34, label="topic"),
- label="is_about",
- )
- same_rel = taskmodule.unreverse_relation(relation=rel)
- assert same_rel == rel
-
- # the relation is reversed, so it should be un-reversed
- reversed_rel = BinaryRelation(
- head=LabeledSpan(start=10, end=20, label="content"),
- tail=LabeledSpan(start=27, end=34, label="topic"),
- label="is_about_reversed",
- )
- unreversed_relation = taskmodule.unreverse_relation(relation=reversed_rel)
- assert unreversed_relation == BinaryRelation(
- head=LabeledSpan(start=27, end=34, label="topic", score=1.0),
- tail=LabeledSpan(start=10, end=20, label="content", score=1.0),
- label="is_about",
- score=1.0,
- )
-
- # nothing should change because the relation is symmetric and already ordered (head < tail)
- ordered_sym_rel = BinaryRelation(
- head=LabeledSpan(start=10, end=20, label="content"),
- tail=LabeledSpan(start=27, end=34, label="topic"),
- label="symmetric_relation",
- )
- unreversed_ordered_sym_rel = taskmodule.unreverse_relation(relation=ordered_sym_rel)
- assert ordered_sym_rel == unreversed_ordered_sym_rel
-
- # the relation is symmetric and unordered (head > tail), so it should be un-reversed
- unordered_sym_rel = BinaryRelation(
- head=LabeledSpan(start=27, end=34, label="topic"),
- tail=LabeledSpan(start=10, end=20, label="content"),
- label="symmetric_relation",
- )
- unreversed_unordered_sym_rel = taskmodule.unreverse_relation(relation=unordered_sym_rel)
- assert unreversed_unordered_sym_rel == BinaryRelation(
- head=LabeledSpan(start=10, end=20, label="content", score=1.0),
- tail=LabeledSpan(start=27, end=34, label="topic", score=1.0),
- label="symmetric_relation",
- score=1.0,
- )
-
-
-@pytest.fixture(params=[False, True])
-def taskmodule_with_reversed_relations(document, request) -> PointerNetworkTaskModuleForEnd2EndRE:
- is_about_is_symmetric = request.param
- taskmodule = PointerNetworkTaskModuleForEnd2EndRE(
- tokenizer_name_or_path="facebook/bart-base",
- relation_layer_name="relations",
- exclude_labels_per_layer={"relations": ["no_relation"]},
- annotation_field_mapping={
- "entities": "labeled_spans",
- "relations": "binary_relations",
- },
- create_constraints=True,
- tokenizer_kwargs={"strict_span_conversion": False},
- add_reversed_relations=True,
- symmetric_relations=["is_about"] if is_about_is_symmetric else None,
- )
-
- taskmodule.prepare(documents=[document])
- assert taskmodule.is_prepared
- if is_about_is_symmetric:
- assert taskmodule.prepared_attributes == {
- "labels_per_layer": {
- "entities": ["content", "person", "topic"],
- "relations": ["is_about"],
- }
- }
- else:
- assert taskmodule.prepared_attributes == {
- "labels_per_layer": {
- "entities": ["content", "person", "topic"],
- "relations": ["is_about", "is_about_reversed"],
- }
- }
-
- return taskmodule
-
-
-def test_encode_with_add_reversed_relations(taskmodule_with_reversed_relations, document):
- task_encodings = taskmodule_with_reversed_relations.encode(document, encode_target=True)
- assert len(task_encodings) == 1
- task_encoding = task_encodings[0]
- assert task_encoding is not None
- assert asdict(task_encoding.inputs) == {
- "input_ids": [0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 3101, 162, 4, 2],
- "attention_mask": [1] * 13,
- }
- tokens = taskmodule_with_reversed_relations.tokenizer.convert_ids_to_tokens(
- task_encoding.inputs.input_ids
- )
- assert tokens == [
- "",
- "This",
- "Ġis",
- "Ġa",
- "Ġdummy",
- "Ġtext",
- "Ġabout",
- "Ġnothing",
- ".",
- "ĠTrust",
- "Ġme",
- ".",
- "",
- ]
- if "is_about" in taskmodule_with_reversed_relations.symmetric_relations:
- decoded_annotations, statistics = taskmodule_with_reversed_relations.decode_annotations(
- task_encoding.targets
- )
- assert decoded_annotations == {
- "entities": [
- LabeledSpan(start=4, end=6, label="content", score=1.0),
- LabeledSpan(start=7, end=8, label="topic", score=1.0),
- LabeledSpan(start=10, end=11, label="person", score=1.0),
- ],
- "relations": [
- BinaryRelation(
- head=LabeledSpan(start=4, end=6, label="content", score=1.0),
- tail=LabeledSpan(start=7, end=8, label="topic", score=1.0),
- label="is_about",
- score=1.0,
- ),
- BinaryRelation(
- head=LabeledSpan(start=7, end=8, label="topic", score=1.0),
- tail=LabeledSpan(start=4, end=6, label="content", score=1.0),
- label="is_about",
- score=1.0,
- ),
- ],
- }
- else:
- decoded_annotations, statistics = taskmodule_with_reversed_relations.decode_annotations(
- task_encoding.targets
- )
- assert decoded_annotations == {
- "entities": [
- LabeledSpan(start=4, end=6, label="content", score=1.0),
- LabeledSpan(start=7, end=8, label="topic", score=1.0),
- LabeledSpan(start=10, end=11, label="person", score=1.0),
- ],
- "relations": [
- BinaryRelation(
- head=LabeledSpan(start=4, end=6, label="content", score=1.0),
- tail=LabeledSpan(start=7, end=8, label="topic", score=1.0),
- label="is_about",
- score=1.0,
- ),
- BinaryRelation(
- head=LabeledSpan(start=7, end=8, label="topic", score=1.0),
- tail=LabeledSpan(start=4, end=6, label="content", score=1.0),
- label="is_about_reversed",
- score=1.0,
- ),
- ],
- }
-
-
-def test_encode_with_add_reversed_relations_already_exists(caplog):
- doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.")
- doc.entities.append(LabeledSpan(start=10, end=20, label="content"))
- doc.entities.append(LabeledSpan(start=27, end=34, label="topic"))
- doc.relations.append(
- BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about")
- )
- doc.relations.append(
- BinaryRelation(head=doc.entities[1], tail=doc.entities[0], label="is_about")
- )
-
- taskmodule = PointerNetworkTaskModuleForEnd2EndRE(
- tokenizer_name_or_path="facebook/bart-base",
- relation_layer_name="relations",
- annotation_field_mapping={
- "entities": "labeled_spans",
- "relations": "binary_relations",
- },
- add_reversed_relations=True,
- symmetric_relations=["is_about"],
- )
- taskmodule.prepare(documents=[doc])
-
- with caplog.at_level(logging.WARNING):
- task_encodings = taskmodule.encode(doc, encode_target=True)
- assert len(caplog.messages) == 0
- assert len(task_encodings) == 1
- task_encoding = task_encodings[0]
-
- decoded_annotations, statistics = taskmodule.decode_annotations(task_encoding.targets)
- assert decoded_annotations == {
- "entities": [
- LabeledSpan(start=4, end=6, label="content", score=1.0),
- LabeledSpan(start=7, end=8, label="topic", score=1.0),
- ],
- "relations": [
- BinaryRelation(
- head=LabeledSpan(start=4, end=6, label="content", score=1.0),
- tail=LabeledSpan(start=7, end=8, label="topic", score=1.0),
- label="is_about",
- score=1.0,
- ),
- BinaryRelation(
- head=LabeledSpan(start=7, end=8, label="topic", score=1.0),
- tail=LabeledSpan(start=4, end=6, label="content", score=1.0),
- label="is_about",
- score=1.0,
- ),
- ],
- }
-
-
-def test_decode_with_add_reversed_relations():
- doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.")
- doc.entities.append(LabeledSpan(start=10, end=20, label="content"))
- doc.entities.append(LabeledSpan(start=27, end=34, label="topic"))
- doc.relations.append(
- BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about")
- )
-
- taskmodule = PointerNetworkTaskModuleForEnd2EndRE(
- tokenizer_name_or_path="facebook/bart-base",
- relation_layer_name="relations",
- annotation_field_mapping={
- "entities": "labeled_spans",
- "relations": "binary_relations",
- },
- add_reversed_relations=True,
- )
- taskmodule.prepare(documents=[doc])
-
- task_encodings = taskmodule.encode(doc, encode_target=True)
- assert len(task_encodings) == 1
- decoded_annotations, statistics = taskmodule.decode_annotations(task_encodings[0].targets)
- assert decoded_annotations == {
- "entities": [
- LabeledSpan(start=4, end=6, label="content", score=1.0),
- LabeledSpan(start=7, end=8, label="topic", score=1.0),
- ],
- "relations": [
- BinaryRelation(
- head=LabeledSpan(start=4, end=6, label="content", score=1.0),
- tail=LabeledSpan(start=7, end=8, label="topic", score=1.0),
- label="is_about",
- score=1.0,
- ),
- BinaryRelation(
- head=LabeledSpan(start=7, end=8, label="topic", score=1.0),
- tail=LabeledSpan(start=4, end=6, label="content", score=1.0),
- label="is_about_reversed",
- score=1.0,
- ),
- ],
- }
-
- task_outputs = [task_encoding.targets for task_encoding in task_encodings]
- docs_with_predictions = taskmodule.decode(task_encodings, task_outputs)
- assert len(docs_with_predictions) == 1
- doc_with_predictions: ExampleDocument = docs_with_predictions[0]
- assert set(doc_with_predictions.entities.predictions) == set(doc_with_predictions.entities)
- assert set(doc_with_predictions.relations.predictions) == set(doc_with_predictions.relations)
-
-
-@pytest.fixture()
-def target_encoding(taskmodule, task_encoding_without_target):
- return taskmodule.encode_target(task_encoding_without_target)
-
-
-def test_target_encoding(target_encoding, taskmodule):
- assert target_encoding is not None
- if taskmodule.partition_layer_name is None:
- assert asdict(target_encoding) == {
- "labels": [14, 14, 5, 11, 12, 3, 6, 17, 17, 4, 2, 2, 2, 2, 1],
- "constraints": [
- [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
- [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
- [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
- [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1],
- [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- ],
- }
- elif taskmodule.partition_layer_name == "sentences":
- assert asdict(target_encoding) == {
- "labels": [14, 14, 5, 11, 12, 3, 6, 1],
- "constraints": [
- [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
- [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
- [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- ],
- }
- else:
- raise Exception(f"unknown partition_layer_name: {taskmodule.partition_layer_name}")
-
-
-def test_task_encoding_with_deduplicated_relations(caplog):
- doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.")
- doc.entities.append(LabeledSpan(start=10, end=20, label="content"))
- doc.entities.append(LabeledSpan(start=27, end=34, label="topic"))
- doc.entities.append(LabeledSpan(start=42, end=44, label="person"))
- assert doc.entities.resolve() == [
- ("content", "dummy text"),
- ("topic", "nothing"),
- ("person", "me"),
- ]
- # add the same relation twice (just use a different score, but that should not matter)
- doc.relations.append(
- BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about")
- )
- doc.relations.append(
- BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about", score=0.9)
- )
- assert doc.relations.resolve() == [
- ("is_about", (("content", "dummy text"), ("topic", "nothing"))),
- ("is_about", (("content", "dummy text"), ("topic", "nothing"))),
- ]
- taskmodule = PointerNetworkTaskModuleForEnd2EndRE(
- tokenizer_name_or_path="facebook/bart-base",
- relation_layer_name="relations",
- annotation_field_mapping={
- "entities": "labeled_spans",
- "relations": "binary_relations",
- },
- )
- taskmodule.prepare(documents=[doc])
- caplog.clear()
- with caplog.at_level(logging.WARNING):
- task_encodings = taskmodule.encode(doc, encode_target=True)
- messages = list(caplog.messages)
-
- assert len(task_encodings) == 1
- decoded_annotations, statistics = taskmodule.decode_annotations(task_encodings[0].targets)
- assert decoded_annotations == {
- "entities": [
- LabeledSpan(start=4, end=6, label="content", score=1.0),
- LabeledSpan(start=7, end=8, label="topic", score=1.0),
- LabeledSpan(start=10, end=11, label="person", score=1.0),
- ],
- "relations": [
- BinaryRelation(
- head=LabeledSpan(start=4, end=6, label="content", score=1.0),
- tail=LabeledSpan(start=7, end=8, label="topic", score=1.0),
- label="is_about",
- score=1.0,
- )
- ],
- }
-
- assert messages == [
- (
- "encoding errors: {'correct': 2}, skipped annotations:\n"
- "{\n"
- ' "relations": [\n'
- ' "BinaryRelation('
- "head=LabeledSpan(start=4, end=6, label='content', score=1.0), "
- "tail=LabeledSpan(start=7, end=8, label='topic', score=1.0), "
- "label='is_about', score=0.9"
- ')"\n'
- " ]\n"
- "}"
- )
- ]
-
-
-def test_task_encoding_with_conflicting_relations(caplog):
- doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.")
- doc.entities.append(LabeledSpan(start=10, end=20, label="content"))
- doc.entities.append(LabeledSpan(start=27, end=34, label="topic"))
- doc.entities.append(LabeledSpan(start=42, end=44, label="person"))
- assert doc.entities.resolve() == [
- ("content", "dummy text"),
- ("topic", "nothing"),
- ("person", "me"),
- ]
- # add two relations with the same head and tail, but different labels
- doc.relations.append(
- BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about")
- )
- doc.relations.append(
- BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="wrong_relation")
- )
- assert doc.relations.resolve() == [
- ("is_about", (("content", "dummy text"), ("topic", "nothing"))),
- ("wrong_relation", (("content", "dummy text"), ("topic", "nothing"))),
- ]
- taskmodule = PointerNetworkTaskModuleForEnd2EndRE(
- tokenizer_name_or_path="facebook/bart-base",
- relation_layer_name="relations",
- annotation_field_mapping={
- "entities": "labeled_spans",
- "relations": "binary_relations",
- },
- )
- taskmodule.prepare(documents=[doc])
- caplog.clear()
- with caplog.at_level(logging.ERROR):
- task_encodings = taskmodule.encode(doc, encode_target=True)
- messages = list(caplog.messages)
-
- assert len(task_encodings) == 0
-
- assert messages == [
- "failed to encode target, it will be skipped: "
- "relation ('Ġdummy', 'Ġtext') -> ('Ġnothing',) already exists, but has "
- "another label: is_about (current label: wrong_relation)."
- ]
-
-
-@pytest.fixture()
-def task_encoding(task_encoding_without_target, target_encoding):
- task_encoding_without_target.targets = target_encoding
- return task_encoding_without_target
-
-
-def _separate_constraint(constraint, taskmodule):
- special_ids = sorted(taskmodule.special_target2id.values())
- none_ids = [taskmodule.none_id]
- span_ids = taskmodule.span_ids
- rel_ids = taskmodule.relation_ids
- result = [[constraint[id] for id in ids] for ids in [special_ids, none_ids, span_ids, rel_ids]]
- result += [constraint[taskmodule.pointer_offset :]]
- assert sum(len(con_part) for con_part in result) == len(constraint)
- return result
-
-
-def test_build_constraint(taskmodule):
- target_ids = [14, 14, 5, 11, 12, 3, 6, 17, 17, 4, 2, 2, 2, 2, 1]
- input_len = 13
-
- # empty previous_ids
- constraint = taskmodule._build_constraint(previous_ids=[], input_len=input_len)
- # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)]
- constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule)
- # allow eos and all offsets
- assert constraint_formatted == [
- [0, 1],
- [0],
- [0, 0, 0],
- [0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- ]
-
- # just first span start
- constraint = taskmodule._build_constraint(previous_ids=[14], input_len=input_len)
- # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)]
- constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule)
- # allow all offsets after first span start
- assert constraint_formatted == [
- [0, 0],
- [0],
- [0, 0, 0],
- [0],
- [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
- ]
-
- # first span start and end
- constraint = taskmodule._build_constraint(previous_ids=[14, 14], input_len=input_len)
- # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)]
- constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule)
- # allow all span ids
- assert constraint_formatted == [
- [0, 0],
- [0],
- [1, 1, 1],
- [0],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- ]
-
- # first span start, end, and label
- constraint = taskmodule._build_constraint(previous_ids=[14, 14, 5], input_len=input_len)
- # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)]
- constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule)
- # allow none and all offsets except offsets covered by first span
- assert constraint_formatted == [
- [0, 0],
- [1],
- [0, 0, 0],
- [0],
- [1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1],
- ]
-
- # first span, and second span start
- constraint = taskmodule._build_constraint(previous_ids=[14, 14, 5, 11], input_len=input_len)
- # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)]
- constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule)
- # allow all offsets after second span start, but not after first span start
- assert constraint_formatted == [
- [0, 0],
- [0],
- [0, 0, 0],
- [0],
- [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
- ]
-
- # first span, and second span start and end
- constraint = taskmodule._build_constraint(
- previous_ids=[14, 14, 5, 11, 12], input_len=input_len
- )
- # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)]
- constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule)
- # allow all span ids
- assert constraint_formatted == [
- [0, 0],
- [0],
- [1, 1, 1],
- [0],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- ]
-
- # first span, and second span
- constraint = taskmodule._build_constraint(
- previous_ids=[14, 14, 5, 11, 12, 3], input_len=input_len
- )
- # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)]
- constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule)
- # allow all relation ids
- assert constraint_formatted == [
- [0, 0],
- [0],
- [0, 0, 0],
- [1],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- ]
-
- # fist span, and (1 to 3)-times none
- for i in range(1, 3):
- none_ids = [2] * i
- constraint = taskmodule._build_constraint(
- previous_ids=[14, 14, 5] + none_ids, input_len=input_len
- )
- # [bos, eos], [none], [content, person, topic], [is_about] [13 offsets (all remaining)]
- constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule)
- # allow only none
- assert constraint_formatted == [
- [0, 0],
- [1],
- [0, 0, 0],
- [0],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- ]
-
- # contains eos
- constraint = taskmodule._build_constraint(
- previous_ids=[14, 14, 5, 11, 12, 3, 6, 1], input_len=input_len
- )
- # [bos, eos/pad], [none], [content, person, topic], [is_about] [13 offsets (all remaining)]
- constraint_formatted = _separate_constraint(constraint.tolist(), taskmodule)
- # allow only pad (same as eos)
- assert constraint_formatted == [
- [0, 1],
- [0],
- [0, 0, 0],
- [0],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- ]
-
-
-def test_maybe_log_example(taskmodule, task_encoding, caplog, config):
- original_log_first_n_examples = taskmodule.log_first_n_examples
- taskmodule.log_first_n_examples = 1
- caplog.clear()
- with caplog.at_level(logging.INFO):
- taskmodule.maybe_log_example(task_encoding)
- if config == {}:
- assert caplog.messages == [
- "*** Example ***",
- "doc.id: None-tokenized-1-of-1",
- "input_ids: 0 713 16 10 34759 2788 59 1085 4 3101 162 4 2",
- "input_tokens: This Ġis Ġa Ġdummy Ġtext Ġabout Ġnothing . ĠTrust Ġme . " "",
- "label_ids: 14 14 5 11 12 3 6 17 17 4 2 2 2 2 1",
- "label_tokens: 14 {Ġnothing} 14 {Ġnothing} topic 11 {Ġdummy} 12 {Ġtext} content is_about 17 {Ġme} 17 "
- "{Ġme} person none none none none ",
- "constraints: torch.Size([15, 20]) (content is omitted)",
- ]
- elif config == {"partition_layer_name": "sentences"}:
- assert caplog.messages == [
- "*** Example ***",
- "doc.id: None-tokenized-1-of-2",
- "input_ids: 0 713 16 10 34759 2788 59 1085 4 2",
- "input_tokens: This Ġis Ġa Ġdummy Ġtext Ġabout Ġnothing . ",
- "label_ids: 14 14 5 11 12 3 6 1",
- "label_tokens: 14 {Ġnothing} 14 {Ġnothing} topic 11 {Ġdummy} 12 {Ġtext} content is_about ",
- "constraints: torch.Size([8, 17]) (content is omitted)",
- ]
- else:
- raise Exception(f"unknown config: {config}")
-
- # restore original value
- taskmodule.log_first_n_examples = original_log_first_n_examples
-
-
-def test_maybe_log_example_disabled(taskmodule, task_encoding, caplog):
- original_log_first_n_examples = taskmodule.log_first_n_examples
- taskmodule.log_first_n_examples = None
- caplog.clear()
- with caplog.at_level(logging.INFO):
- taskmodule.maybe_log_example(task_encoding)
- assert caplog.record_tuples == []
-
- # restore original value
- taskmodule.log_first_n_examples = original_log_first_n_examples
-
-
-@pytest.fixture()
-def task_encodings(taskmodule, document):
- return taskmodule.encode(documents=[document], encode_target=True)
-
-
-@pytest.fixture()
-def batch(taskmodule, task_encodings):
- return taskmodule.collate(task_encodings)
-
-
-def test_collate(batch, taskmodule):
- inputs, targets = batch
- for tensor in inputs.values():
- assert isinstance(tensor, torch.Tensor)
- assert tensor.dtype == torch.int64
- for tensor in targets.values():
- assert isinstance(tensor, torch.Tensor)
- assert tensor.dtype == torch.int64
- inputs_lists = {k: inputs[k].tolist() for k in sorted(inputs)}
- targets_lists = {k: targets[k].tolist() for k in sorted(targets)}
- if taskmodule.partition_layer_name is None:
- assert inputs_lists == {
- "input_ids": [[0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 3101, 162, 4, 2]],
- "attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
- }
- assert targets_lists == {
- "constraints": [
- [
- [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
- [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
- [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
- [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1],
- [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- ]
- ],
- "labels": [[14, 14, 5, 11, 12, 3, 6, 17, 17, 4, 2, 2, 2, 2, 1]],
- "decoder_attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
- }
- elif taskmodule.partition_layer_name == "sentences":
- assert inputs_lists == {
- "input_ids": [
- [0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 2],
- [0, 18823, 162, 4, 2, 1, 1, 1, 1, 1],
- ],
- "attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]],
- }
- assert targets_lists == {
- "constraints": [
- [
- [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
- [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0],
- [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- ],
- [
- [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1],
- [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, -1, -1, -1, -1, -1],
- [0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1],
- [0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, -1, -1, -1, -1, -1],
- [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1],
- [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1],
- [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1],
- [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1],
- ],
- ],
- "labels": [[14, 14, 5, 11, 12, 3, 6, 1], [9, 9, 4, 2, 2, 2, 2, 1]],
- "decoder_attention_mask": [
- [1, 1, 1, 1, 1, 1, 1, 1],
- [1, 1, 1, 1, 1, 1, 1, 1],
- ],
- }
- else:
- raise Exception(f"unknown partition_layer_name: {taskmodule.partition_layer_name}")
-
-
-@pytest.fixture()
-def unbatched_output(taskmodule, batch):
- inputs, targets = batch
- # because the model is trained to reproduce the target tokens, we can just use them as model prediction
- return taskmodule.unbatch_output(targets)
-
-
-@pytest.fixture()
-def task_outputs(unbatched_output):
- return unbatched_output
-
-
-@pytest.fixture()
-def task_output(task_outputs) -> LabelsAndOptionalConstraints:
- return task_outputs[0]
-
-
-def test_task_output(task_output, taskmodule):
- output_list = task_output.labels
- if taskmodule.partition_layer_name is None:
- assert output_list == [14, 14, 5, 11, 12, 3, 6, 17, 17, 4, 2, 2, 2, 2, 1]
- elif taskmodule.partition_layer_name == "sentences":
- assert output_list == [14, 14, 5, 11, 12, 3, 6, 1]
- else:
- raise Exception(f"unknown partition_layer_name: {taskmodule.partition_layer_name}")
-
-
-def _test_annotations_from_output(task_encodings, task_outputs, taskmodule, layer_names_expected):
- assert len(task_outputs) == len(task_encodings)
-
- # this needs to be outside the below loop because documents can contain duplicates
- # which would break the comparison when clearing predictions that were already added
- for task_encoding in task_encodings:
- for layer_name in layer_names_expected:
- task_encoding.document[layer_name].predictions.clear()
-
- layer_names: Set[str] = set()
- # Note: this list may contain duplicates!
- documents: List[Document] = []
- for i in range(len(task_outputs)):
- task_encoding = task_encodings[i]
- task_output = task_outputs[i]
- documents.append(task_encoding.document)
-
- for layer_name, annotation in taskmodule.create_annotations_from_output(
- task_encoding=task_encoding, task_output=task_output
- ):
- task_encoding.document[layer_name].predictions.append(annotation)
- layer_names.add(layer_name)
-
- assert layer_names == layer_names_expected
-
- for document in documents:
- for layer_name in layer_names:
- layer = {
- str(ann)
- for ann in document[layer_name].predictions
- if ann.label in taskmodule.labels_per_layer[layer_name]
- }
- layer_expected = {
- str(ann)
- for ann in document[layer_name]
- if ann.label in taskmodule.labels_per_layer[layer_name]
- }
- assert layer == layer_expected
-
- # this needs to be outside the above loop because documents can contain duplicates
- # which would break the comparison when clearing predictions too early
- for document in documents:
- for layer_name in layer_names:
- document[layer_name].predictions.clear()
-
-
-def test_annotations_from_output(task_encodings, task_outputs, taskmodule):
- _test_annotations_from_output(
- taskmodule=taskmodule,
- task_encodings=task_encodings,
- task_outputs=task_outputs,
- layer_names_expected={"entities", "relations"},
- )
-
-
-def get_default_taskmodule(**kwargs):
- taskmodule = PointerNetworkTaskModuleForEnd2EndRE(
- tokenizer_name_or_path="facebook/bart-base",
- labels_per_layer={
- "labeled_spans": ["content", "person", "topic"],
- "binary_relations": ["is_about"],
- },
- **kwargs,
- )
- taskmodule.post_prepare()
- return taskmodule
-
-
-def test_configure_model_metric():
- taskmodule = get_default_taskmodule()
- metric = taskmodule.configure_model_metric()
- assert metric is not None
- values = metric.compute()
- assert values == {
- "binary_relations": {},
- "decoding_errors": {"all": 0.0},
- "exact_encoding_matches": 0.0,
- "labeled_spans": {},
- }
-
- model_output = {"labels": torch.tensor([[14, 14, 5, 11, 12, 3, 6, 17, 17, 4, 2, 2, 2, 2, 1]])}
- # test with expected == prediction
- metric.update(model_output, model_output)
- values = metric.compute()
- assert values == {
- "exact_encoding_matches": 1.0,
- "decoding_errors": {"correct": 1.0, "all": 0.0},
- "labeled_spans": {
- "content": {"recall": 1.0, "precision": 1.0, "f1": 1.0},
- "person": {"recall": 1.0, "precision": 1.0, "f1": 1.0},
- "topic": {"recall": 1.0, "precision": 1.0, "f1": 1.0},
- "macro": {"recall": 1.0, "precision": 1.0, "f1": 1.0},
- "micro": {"recall": 1.0, "precision": 1.0, "f1": 1.0},
- },
- "binary_relations": {
- "is_about": {"recall": 1.0, "precision": 1.0, "f1": 1.0},
- "macro": {"recall": 1.0, "precision": 1.0, "f1": 1.0},
- "micro": {"recall": 1.0, "precision": 1.0, "f1": 1.0},
- },
- }
- torch.random.manual_seed(42)
- # random_labels = torch.randint(0, 20, (1, 30))
- # split into random_labels1 and random_labels2 just for better code formatting
- random_labels1 = [0, 14, 4, 19, 2, 6, 18, 3, 0, 8, 8, 14, 2, 1]
- random_labels2 = [14, 6, 7, 8, 4, 1, 17, 9, 14, 7, 13, 15, 5, 12, 18, 13]
- labels_random = torch.tensor([random_labels1 + random_labels2])
- metric.reset()
- # test the case where we have mixed results (correct and wrong)
- metric.update(model_output, model_output)
- metric.update(prediction={"labels": labels_random}, expected=model_output)
- values = metric.compute()
- assert values == {
- "exact_encoding_matches": 0.5,
- "decoding_errors": {"correct": 0.5, "len": 0.25, "order": 0.25, "all": 0.5},
- "labeled_spans": {
- "person": {"recall": 0.5, "precision": 1.0, "f1": 0.6666666865348816},
- "topic": {"recall": 0.5, "precision": 1.0, "f1": 0.6666666865348816},
- "content": {"recall": 0.5, "precision": 1.0, "f1": 0.6666666865348816},
- "macro": {"recall": 0.5, "precision": 1.0, "f1": 0.6666666865348816},
- "micro": {"recall": 0.5, "precision": 1.0, "f1": 0.6666666865348816},
- },
- "binary_relations": {
- "is_about": {"recall": 0.5, "precision": 1.0, "f1": 0.6666666865348816},
- "macro": {"recall": 0.5, "precision": 1.0, "f1": 0.6666666865348816},
- "micro": {"recall": 0.5, "precision": 1.0, "f1": 0.6666666865348816},
- },
- }
-
- # ensure that the metric can be pickled
- pickle.dumps(metric)
-
-
-def test_configure_model_generation():
- taskmodule = get_default_taskmodule()
- assert taskmodule.configure_model_generation() == {
- "no_repeat_ngram_size": 7,
- }
-
-
-def test_configure_model_generation_with_constrained_generation():
- taskmodule = get_default_taskmodule(constrained_generation=True)
- generation_config = taskmodule.configure_model_generation()
- assert set(generation_config) == {"no_repeat_ngram_size", "logits_processor"}
- assert generation_config["no_repeat_ngram_size"] == 7
- logits_processor = generation_config["logits_processor"]
- assert isinstance(logits_processor, LogitsProcessorList)
- assert len(logits_processor) == 2
- assert isinstance(logits_processor[0], FinitizeLogitsProcessor)
- assert isinstance(logits_processor[1], PrefixConstrainedLogitsProcessorWithMaximum)
-
-
-def test_prefix_allowed_tokens_fn_with_maximum():
- taskmodule = get_default_taskmodule()
- # not that this includes the leading bos token
- add_previous_input_ids = torch.tensor([0, 14, 14, 5, 11, 12, 3, 6, 17, 17, 4, 2, 2, 2, 2, 1])
-
- # empty input (first entry)
- allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum(
- batch_id=0, input_ids=add_previous_input_ids[:1], maximum=20
- )
- # allow the eos id [1] and all offset ids [7..19]
- assert allowed_ids == [1, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
-
- # first span start
- allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum(
- batch_id=0, input_ids=add_previous_input_ids[:2], maximum=20
- )
- # allow all offset ids from first span start [14..19]
- assert allowed_ids == [14, 15, 16, 17, 18, 19]
-
- # first span start and end
- allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum(
- batch_id=0, input_ids=add_previous_input_ids[:3], maximum=20
- )
- # allow all span ids
- assert allowed_ids == [3, 4, 5]
-
- # first span start, end, and label
- allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum(
- batch_id=0, input_ids=add_previous_input_ids[:4], maximum=20
- )
- # allow none [2] and all offsets except offsets covered by first span [14]
- assert allowed_ids == [2, 7, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18, 19]
-
- # first span, and second span start
- allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum(
- batch_id=0, input_ids=add_previous_input_ids[:5], maximum=20
- )
- # allow all offsets from second span start [11], but before first span start [14] because it would be an overlap
- assert allowed_ids == [11, 12, 13]
-
- # first span, and second span start and end
- allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum(
- batch_id=0, input_ids=add_previous_input_ids[:6], maximum=20
- )
- # allow all span ids
- assert allowed_ids == [3, 4, 5]
-
- # first span, and second span
- allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum(
- batch_id=0, input_ids=add_previous_input_ids[:7], maximum=20
- )
- # allow all relation ids
- assert allowed_ids == [6]
-
- # entry begins (second entry)
- allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum(
- batch_id=0, input_ids=add_previous_input_ids[:8], maximum=20
- )
- # allow eos [1] and all offsets [7..19]
- assert allowed_ids == [1, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
-
- # first span start
- allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum(
- batch_id=0, input_ids=add_previous_input_ids[:9], maximum=20
- )
- # allow all offsets from first span start [17..19]
- assert allowed_ids == [17, 18, 19]
-
- # first span start and end
- allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum(
- batch_id=0, input_ids=add_previous_input_ids[:10], maximum=20
- )
- # allow all span ids
- assert allowed_ids == [3, 4, 5]
-
- # first span start, end, and span label
- allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum(
- batch_id=0, input_ids=add_previous_input_ids[:11], maximum=20
- )
- # allow none [2] and all offsets except offsets covered by first span [17]
- assert allowed_ids == [2, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19]
-
- # first span, and none
- allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum(
- batch_id=0, input_ids=add_previous_input_ids[:12], maximum=20
- )
- # allow only none [2] because when the entry contains already a none id, it cannot be followed by anything else
- assert allowed_ids == [2]
-
- # first span, and none, and none
- allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum(
- batch_id=0, input_ids=add_previous_input_ids[:13], maximum=20
- )
- # allow only none [2] because when the entry contains already a none id, it cannot be followed by anything else
- assert allowed_ids == [2]
-
- # first span, and none, and none, and none
- allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum(
- batch_id=0, input_ids=add_previous_input_ids[:14], maximum=20
- )
- # allow only none [2] because when the entry contains already a none id, it cannot be followed by anything else
- assert allowed_ids == [2]
-
- # first span, and none, and none, and none, and none (second entry is complete)
- allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum(
- batch_id=0, input_ids=add_previous_input_ids[:15], maximum=20
- )
- # allow eos [1] and all offsets [7..19]
- assert allowed_ids == [1, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
-
- # got an eos, so the sequence is complete
- allowed_ids = taskmodule._prefix_allowed_tokens_fn_with_maximum(
- batch_id=0, input_ids=add_previous_input_ids[:16], maximum=20
- )
- # allow only pad [1] (same as eos) because the sequence is complete
- assert allowed_ids == [1]
diff --git a/tests/taskmodules/test_re_span_pair_classification.py b/tests/taskmodules/test_re_span_pair_classification.py
deleted file mode 100644
index f82554c1e..000000000
--- a/tests/taskmodules/test_re_span_pair_classification.py
+++ /dev/null
@@ -1,614 +0,0 @@
-import dataclasses
-import logging
-from typing import Any, Dict, Union
-
-import pytest
-import torch
-from pie_core import AnnotationLayer, annotation_field
-from pie_core.utils.dictionary import flatten_dict_s
-from torch import tensor
-from torchmetrics import Metric, MetricCollection
-
-from pie_modules.annotations import BinaryRelation, LabeledSpan
-from pie_modules.documents import TextBasedDocument
-from pie_modules.taskmodules import RESpanPairClassificationTaskModule
-from pie_modules.utils.span import distance
-from tests import _config_to_str
-
-TOKENIZER_NAME_OR_PATH = "bert-base-cased"
-
-CONFIGS = [{}, {"partition_annotation": "sentences"}]
-CONFIGS_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS}
-
-
-@pytest.fixture(scope="module", params=CONFIGS_DICT.keys())
-def cfg(request):
- return CONFIGS_DICT[request.param]
-
-
-@pytest.fixture(scope="module")
-def unprepared_taskmodule(cfg):
- taskmodule = RESpanPairClassificationTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=TOKENIZER_NAME_OR_PATH,
- log_first_n_examples=10,
- collect_statistics=True,
- **cfg,
- )
- assert not taskmodule.is_from_pretrained
-
- return taskmodule
-
-
-@dataclasses.dataclass
-class FixedTestDocument(TextBasedDocument):
- sentences: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
- entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
- relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities")
-
-
-@pytest.fixture(scope="module")
-def fixed_documents(documents):
- result = []
- for document in documents:
- fixed_doc = document.copy(with_annotations=False).as_type(FixedTestDocument)
- for sentence in document.sentences:
- fixed_doc.sentences.append(
- LabeledSpan(start=sentence.start, end=sentence.end, label="sentence")
- )
- entity_mapping = {}
- for entity in document.entities:
- new_entity = entity.copy()
- fixed_doc.entities.append(new_entity)
- entity_mapping[entity] = new_entity
- for relation in document.relations:
- new_relation = relation.copy(
- head=entity_mapping[relation.head], tail=entity_mapping[relation.tail]
- )
- fixed_doc.relations.append(new_relation)
- result.append(fixed_doc)
- return result
-
-
-@pytest.fixture(scope="module")
-def taskmodule(unprepared_taskmodule, fixed_documents) -> RESpanPairClassificationTaskModule:
- unprepared_taskmodule.prepare(fixed_documents)
- return unprepared_taskmodule
-
-
-def test_taskmodule(taskmodule: RESpanPairClassificationTaskModule):
- assert taskmodule.is_prepared
-
- assert taskmodule.relation_annotation == "relations"
- assert taskmodule.labels == ["org:founded_by", "per:employee_of", "per:founder"]
- assert taskmodule.entity_labels == ["ORG", "PER"]
- assert taskmodule.label_to_id == {
- "org:founded_by": 1,
- "per:employee_of": 2,
- "per:founder": 3,
- "no_relation": 0,
- }
- assert taskmodule.argument_markers == [
- "[/SPAN:ORG]",
- "[/SPAN:PER]",
- "[SPAN:ORG]",
- "[SPAN:PER]",
- ]
- assert taskmodule.tokenizer.additional_special_tokens == [
- "[SPAN:PER]",
- "[/SPAN:ORG]",
- "[/SPAN:PER]",
- "[SPAN:ORG]",
- ]
- assert taskmodule.tokenizer.additional_special_tokens_ids == [28996, 28997, 28998, 28999]
-
- # because this is not the standard value for relation_annotation, we can not determine the document type
- assert taskmodule.document_type is None
-
-
-@pytest.fixture(scope="module")
-def document(fixed_documents):
- result = fixed_documents[4]
- assert (
- result.metadata["description"]
- == "sentences with multiple relation annotations and cross-sentence relation"
- )
- return result
-
-
-def test_create_candidate_relations(taskmodule, document):
- # _create_candidate_relations requires normalized documents
- normalized_document = taskmodule.normalize_document(document)
- candidate_relations = taskmodule._create_candidate_relations(normalized_document)
- resolved_relations = [ann.resolve() for ann in candidate_relations]
- assert resolved_relations == [
- ("no_relation", (("PER", "Entity G"), ("ORG", "H"))),
- ("no_relation", (("PER", "Entity G"), ("ORG", "I"))),
- ("no_relation", (("ORG", "H"), ("PER", "Entity G"))),
- ("no_relation", (("ORG", "H"), ("ORG", "I"))),
- ("no_relation", (("ORG", "I"), ("PER", "Entity G"))),
- ("no_relation", (("ORG", "I"), ("ORG", "H"))),
- ]
-
-
-def test_create_candidate_relations_with_max_distance(taskmodule, document):
- # _create_candidate_relations requires normalized documents
- normalized_document = taskmodule.normalize_document(document)
- candidate_relations = taskmodule._create_candidate_relations(
- normalized_document, max_argument_distance=10
- )
- resolved_relations = [ann.resolve() for ann in candidate_relations]
- assert resolved_relations == [
- ("no_relation", (("PER", "Entity G"), ("ORG", "H"))),
- ("no_relation", (("ORG", "H"), ("PER", "Entity G"))),
- ]
- distances = [
- distance(
- start_end=(rel.head.start, rel.head.end),
- other_start_end=(rel.tail.start, rel.tail.end),
- distance_type="inner",
- )
- for rel in candidate_relations
- ]
- assert distances == [10.0, 10.0]
-
-
-@pytest.fixture(scope="module")
-def task_encodings(taskmodule, document):
- result = taskmodule.encode(document, encode_target=True)
- return result
-
-
-def test_encode_input(task_encodings, document, taskmodule, cfg):
- assert task_encodings is not None
- if cfg == {}:
- assert len(task_encodings) == 1
- inputs = task_encodings[0].inputs
- assert set(inputs) == {
- "input_ids",
- "attention_mask",
- "span_start_indices",
- "span_end_indices",
- "tuple_indices",
- "tuple_indices_mask",
- }
- tokens = taskmodule.tokenizer.convert_ids_to_tokens(inputs["input_ids"])
- assert tokens == [
- "[CLS]",
- "First",
- "sentence",
- ".",
- "[SPAN:PER]",
- "En",
- "##ti",
- "##ty",
- "G",
- "[/SPAN:PER]",
- "works",
- "at",
- "[SPAN:ORG]",
- "H",
- "[/SPAN:ORG]",
- ".",
- "And",
- "founded",
- "[SPAN:ORG]",
- "I",
- "[/SPAN:ORG]",
- ".",
- "[SEP]",
- ]
- span_tokens = [
- tokens[start:end]
- for start, end in zip(inputs["span_start_indices"], inputs["span_end_indices"])
- ]
- assert span_tokens == [
- ["[SPAN:PER]", "En", "##ti", "##ty", "G", "[/SPAN:PER]"],
- ["[SPAN:ORG]", "H", "[/SPAN:ORG]"],
- ["[SPAN:ORG]", "I", "[/SPAN:ORG]"],
- ]
- tuple_spans = [
- [span_tokens[idx] for idx in indices] for indices in inputs["tuple_indices"]
- ]
- assert tuple_spans == [
- [
- ["[SPAN:PER]", "En", "##ti", "##ty", "G", "[/SPAN:PER]"],
- ["[SPAN:ORG]", "H", "[/SPAN:ORG]"],
- ],
- [
- ["[SPAN:PER]", "En", "##ti", "##ty", "G", "[/SPAN:PER]"],
- ["[SPAN:ORG]", "I", "[/SPAN:ORG]"],
- ],
- [["[SPAN:ORG]", "I", "[/SPAN:ORG]"], ["[SPAN:ORG]", "H", "[/SPAN:ORG]"]],
- ]
- assert inputs["tuple_indices_mask"].tolist() == [True, True, True]
- elif cfg == {"partition_annotation": "sentences"}:
- assert len(task_encodings) == 1
- for idx, encoding in enumerate(task_encodings):
- inputs = encoding.inputs
- assert set(inputs) == {
- "input_ids",
- "attention_mask",
- "span_start_indices",
- "span_end_indices",
- "tuple_indices",
- "tuple_indices_mask",
- }
- tokens = taskmodule.tokenizer.convert_ids_to_tokens(inputs["input_ids"])
- span_tokens = [
- tokens[start:end]
- for start, end in zip(inputs["span_start_indices"], inputs["span_end_indices"])
- ]
- tuple_spans = [
- [span_tokens[idx] for idx in indices] for indices in inputs["tuple_indices"]
- ]
- if idx == 0:
- assert tokens == [
- "[CLS]",
- "En",
- "##ti",
- "##ty",
- "G",
- "[/SPAN:PER]",
- "works",
- "at",
- "[SPAN:ORG]",
- "H",
- "[/SPAN:ORG]",
- ".",
- "[SEP]",
- ]
- assert span_tokens == [
- ["[CLS]", "En", "##ti", "##ty", "G", "[/SPAN:PER]"],
- ["[SPAN:ORG]", "H", "[/SPAN:ORG]"],
- ]
- assert tuple_spans == [
- [
- ["[CLS]", "En", "##ti", "##ty", "G", "[/SPAN:PER]"],
- ["[SPAN:ORG]", "H", "[/SPAN:ORG]"],
- ]
- ]
- assert inputs["tuple_indices_mask"].tolist() == [True]
- else:
- raise ValueError(f"unexpected idx: {idx}")
- else:
- raise ValueError(f"unexpected config: {cfg}")
-
-
-def test_encode_target(taskmodule, task_encodings, cfg):
- if cfg == {}:
- assert len(task_encodings) == 1
- targets = task_encodings[0].targets
- labels = [taskmodule.id_to_label[label] for label in targets["labels"].tolist()]
- assert labels == ["per:employee_of", "per:founder", "org:founded_by"]
- elif cfg == {"partition_annotation": "sentences"}:
- assert len(task_encodings) == 1
- for idx, encoding in enumerate(task_encodings):
- targets = encoding.targets
- labels = [taskmodule.id_to_label[label] for label in targets["labels"].tolist()]
- if idx == 0:
- assert labels == ["per:employee_of"]
- else:
- raise ValueError(f"unexpected idx: {idx}")
- else:
- raise ValueError(f"unexpected config: {cfg}")
-
-
-def test_encode_with_no_gold_relation(document):
- # create a new taskmodule that does create candidate relations
- taskmodule = RESpanPairClassificationTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=TOKENIZER_NAME_OR_PATH,
- create_candidate_relations=True,
- labels=["org:founded_by", "per:employee_of", "per:founder"],
- entity_labels=["ORG", "PER"],
- )
- taskmodule.post_prepare()
- # create a new document that has no relations
- document = document.copy()
- document.relations.clear()
-
- encodings = taskmodule.encode(document, encode_target=True)
-
- assert len(encodings) == 1
- encoding = encodings[0]
- # same number of candidate relations as there are labels
- assert len(encoding.metadata["candidate_relations"]) == encoding.targets["labels"].numel()
- assert all(rel.label == "no_relation" for rel in encoding.metadata["candidate_relations"])
- assert encoding.targets["labels"].tolist() == [0, 0, 0, 0, 0, 0]
-
-
-def test_encode_with_multiple_gold_relations_with_same_arguments(document, caplog):
- # create a new taskmodule that does create candidate relations
- taskmodule = RESpanPairClassificationTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=TOKENIZER_NAME_OR_PATH,
- labels=["org:founded_by", "per:employee_of", "per:founder"],
- entity_labels=["ORG", "PER"],
- )
- taskmodule.post_prepare()
- # create a new document that has multiple relations with the same arguments
- document = document.copy()
- document.relations.clear()
- head = document.entities[0]
- tail = document.entities[1]
- document.relations.extend(
- [
- BinaryRelation(head=head, tail=tail, label="org:founded_by"),
- BinaryRelation(head=head, tail=tail, label="per:employee_of"),
- ]
- )
-
- caplog.clear()
- with caplog.at_level(logging.WARNING):
- encodings = taskmodule.encode(document, encode_target=True)
- assert len(caplog.messages) == 2
- assert (
- caplog.messages[0]
- == "skip the candidate relation because there are more than one gold relation for "
- "its args and roles: [BinaryRelation(head=LabeledSpan(start=5, end=10, label='PER', score=1.0), "
- "tail=LabeledSpan(start=13, end=15, label='ORG', score=1.0), label='org:founded_by', score=1.0), "
- "BinaryRelation(head=LabeledSpan(start=5, end=10, label='PER', score=1.0), "
- "tail=LabeledSpan(start=13, end=15, label='ORG', score=1.0), label='per:employee_of', score=1.0)]"
- )
- assert (
- caplog.messages[1]
- == "skip the candidate relation because there are more than one gold relation for "
- "its args and roles: [BinaryRelation(head=LabeledSpan(start=5, end=10, label='PER', score=1.0), "
- "tail=LabeledSpan(start=13, end=15, label='ORG', score=1.0), label='org:founded_by', score=1.0), "
- "BinaryRelation(head=LabeledSpan(start=5, end=10, label='PER', score=1.0), "
- "tail=LabeledSpan(start=13, end=15, label='ORG', score=1.0), label='per:employee_of', score=1.0)]"
- )
-
- assert len(encodings) == 1
- encoding = encodings[0]
- candidate_relations = encoding.metadata["candidate_relations"]
- # same number of candidate relations as there are labels
- assert len(candidate_relations) == encoding.targets["labels"].numel()
- assert candidate_relations[0].label == "org:founded_by"
- assert candidate_relations[1].label == "per:employee_of"
- assert encoding.targets["labels"].tolist() == [-100, -100]
-
-
-def test_maybe_log_example(taskmodule, task_encodings, caplog, cfg):
- caplog.clear()
- if cfg == {}:
- with caplog.at_level(logging.INFO):
- taskmodule._maybe_log_example(task_encodings[0], target=task_encodings[0].targets)
- assert caplog.messages == [
- "*** Example ***",
- "doc id: train_doc5",
- "tokens: [CLS] First sentence . [SPAN:PER] En ##ti ##ty G [/SPAN:PER] works at [SPAN:ORG] H [/SPAN:ORG] . And founded [SPAN:ORG] I [/SPAN:ORG] . [SEP]",
- "input_ids: 101 1752 5650 119 28996 13832 3121 2340 144 28998 1759 1120 28999 145 28997 119 1262 1771 28999 146 28997 119 102",
- "relation 0: per:employee_of",
- "\targ 0: [SPAN:PER] En ##ti ##ty G [/SPAN:PER]",
- "\targ 1: [SPAN:ORG] H [/SPAN:ORG]",
- "relation 1: per:founder",
- "\targ 0: [SPAN:PER] En ##ti ##ty G [/SPAN:PER]",
- "\targ 1: [SPAN:ORG] I [/SPAN:ORG]",
- "relation 2: org:founded_by",
- "\targ 0: [SPAN:ORG] I [/SPAN:ORG]",
- "\targ 1: [SPAN:ORG] H [/SPAN:ORG]",
- ]
- elif cfg == {"partition_annotation": "sentences"}:
- with caplog.at_level(logging.INFO):
- taskmodule._maybe_log_example(task_encodings[0], target=task_encodings[0].targets)
- assert caplog.messages == [
- "*** Example ***",
- "doc id: train_doc5",
- "tokens: [CLS] En ##ti ##ty G [/SPAN:PER] works at [SPAN:ORG] H [/SPAN:ORG] . [SEP]",
- "input_ids: 101 13832 3121 2340 144 28998 1759 1120 28999 145 28997 119 102",
- "relation 0: per:employee_of",
- "\targ 0: [CLS] En ##ti ##ty G [/SPAN:PER]",
- "\targ 1: [SPAN:ORG] H [/SPAN:ORG]",
- ]
- else:
- raise ValueError(f"unexpected config: {cfg}")
-
-
-def test_encode_with_statistics(taskmodule, fixed_documents, cfg, caplog):
- caplog.clear()
- with caplog.at_level(logging.INFO):
- taskmodule.encode(fixed_documents, encode_target=True)
- assert len(caplog.messages) > 0
- statistics = caplog.messages[-1]
- if cfg == {}:
- assert (
- statistics
- == """statistics:
-| | org:founded_by | per:employee_of | per:founder |
-|:--------------------|-----------------:|------------------:|--------------:|
-| available | 2 | 3 | 2 |
-| available_tokenized | 2 | 3 | 2 |
-| used | 2 | 3 | 2 |"""
- )
- elif cfg == {"partition_annotation": "sentences"}:
- assert (
- statistics
- == """statistics:
-| | org:founded_by | per:employee_of | per:founder |
-|:--------------------|-----------------:|------------------:|--------------:|
-| available | 2 | 3 | 2 |
-| available_tokenized | 1 | 3 | 1 |
-| used | 1 | 3 | 1 |"""
- )
- else:
- raise ValueError(f"unexpected config: {cfg}")
-
-
-def test_collate(taskmodule, task_encodings, cfg):
- result = taskmodule.collate(task_encodings)
- assert result is not None
- inputs, targets = result
- assert set(inputs) == {
- "input_ids",
- "attention_mask",
- "span_start_indices",
- "span_end_indices",
- "tuple_indices",
- "tuple_indices_mask",
- }
- if cfg == {}:
- torch.testing.assert_close(
- inputs["input_ids"],
- tensor(
- [
- [
- 101,
- 1752,
- 5650,
- 119,
- 28996,
- 13832,
- 3121,
- 2340,
- 144,
- 28998,
- 1759,
- 1120,
- 28999,
- 145,
- 28997,
- 119,
- 1262,
- 1771,
- 28999,
- 146,
- 28997,
- 119,
- 102,
- ]
- ]
- ),
- )
- torch.testing.assert_close(inputs["attention_mask"], torch.ones_like(inputs["input_ids"]))
- torch.testing.assert_close(inputs["span_start_indices"], tensor([[4, 12, 18]]))
- torch.testing.assert_close(inputs["span_end_indices"], tensor([[10, 15, 21]]))
- torch.testing.assert_close(inputs["tuple_indices"], tensor([[[0, 1], [0, 2], [2, 1]]]))
- torch.testing.assert_close(inputs["tuple_indices_mask"], tensor([[True, True, True]]))
- assert set(targets) == {"labels"}
- torch.testing.assert_close(targets["labels"], tensor([[2, 3, 1]]))
- elif cfg == {"partition_annotation": "sentences"}:
- torch.testing.assert_close(
- inputs["input_ids"],
- tensor(
- [[101, 13832, 3121, 2340, 144, 28998, 1759, 1120, 28999, 145, 28997, 119, 102]]
- ),
- )
- torch.testing.assert_close(
- inputs["attention_mask"], tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
- )
- torch.testing.assert_close(inputs["span_start_indices"], tensor([[0, 8]]))
- torch.testing.assert_close(inputs["span_end_indices"], tensor([[6, 11]]))
- torch.testing.assert_close(inputs["tuple_indices"], tensor([[[0, 1]]]))
- torch.testing.assert_close(inputs["tuple_indices_mask"], tensor([[True]]))
- assert set(targets) == {"labels"}
- torch.testing.assert_close(targets["labels"], tensor([[2]]))
- else:
- raise ValueError(f"unexpected config: {cfg}")
-
-
-@pytest.fixture
-def model_output():
- return {
- "labels": torch.tensor([[2, 3, 1]]),
- "probabilities": torch.tensor(
- [
- [
- # no_relation, org:founded_by, per:employee_of, per:founder
- [0.1, 0.2, 0.6, 0.1],
- [0.1, 0.2, 0.2, 0.5],
- [0.2, 0.5, 0.2, 0.1],
- ]
- ]
- ),
- }
-
-
-@pytest.fixture
-def unbatched_model_outputs(taskmodule, model_output):
- return taskmodule.unbatch_output(model_output)
-
-
-def test_unbatch_outputs(taskmodule, unbatched_model_outputs):
- assert len(unbatched_model_outputs) == 1
- result = unbatched_model_outputs[0]
- assert set(result) == {"labels", "probabilities"}
- assert result["labels"] == ["per:employee_of", "per:founder", "org:founded_by"]
- assert result["probabilities"] == [0.6000000238418579, 0.5, 0.5]
-
-
-def test_create_annotations_from_output(
- taskmodule, unbatched_model_outputs, task_encodings, document
-):
- result = list(
- taskmodule.create_annotations_from_output(
- task_encoding=task_encodings[0], task_output=unbatched_model_outputs[0]
- )
- )
- scores = [0.6000000238418579, 0.5, 0.5]
- for i, ((layer_name, predicted_relation), original_relation) in enumerate(
- zip(result, document.relations)
- ):
- assert layer_name == taskmodule.relation_annotation
- assert predicted_relation == original_relation.copy()
- assert predicted_relation.score == scores[i]
-
-
-def get_metric_state(metric_or_collection: Union[Metric, MetricCollection]) -> Dict[str, Any]:
- if isinstance(metric_or_collection, Metric):
- return {
- k: v.tolist() for k, v in flatten_dict_s(metric_or_collection.metric_state).items()
- }
- elif isinstance(metric_or_collection, MetricCollection):
- return flatten_dict_s({k: get_metric_state(v) for k, v in metric_or_collection.items()})
- else:
- raise ValueError(f"unsupported type: {type(metric_or_collection)}")
-
-
-def test_configure_model_metrics(taskmodule, model_output):
- metrics = taskmodule.configure_model_metric(stage="train")
- assert metrics is not None
- assert isinstance(metrics, (Metric, MetricCollection))
- state = get_metric_state(metrics)
- assert state == {
- "f1_per_label/tp": [0, 0, 0, 0],
- "f1_per_label/fp": [0, 0, 0, 0],
- "f1_per_label/tn": [0, 0, 0, 0],
- "f1_per_label/fn": [0, 0, 0, 0],
- "macro/f1/tp": [0, 0, 0, 0],
- "macro/f1/fp": [0, 0, 0, 0],
- "macro/f1/tn": [0, 0, 0, 0],
- "macro/f1/fn": [0, 0, 0, 0],
- "micro/f1/tp": [0],
- "micro/f1/fp": [0],
- "micro/f1/tn": [0],
- "micro/f1/fn": [0],
- }
-
- metric_values = metrics(model_output, model_output)
- state = get_metric_state(metrics)
- assert state == {
- "f1_per_label/tp": [0, 1, 1, 1],
- "f1_per_label/fp": [0, 0, 0, 0],
- "f1_per_label/tn": [3, 2, 2, 2],
- "f1_per_label/fn": [0, 0, 0, 0],
- "macro/f1/tp": [0, 1, 1, 1],
- "macro/f1/fp": [0, 0, 0, 0],
- "macro/f1/tn": [3, 2, 2, 2],
- "macro/f1/fn": [0, 0, 0, 0],
- "micro/f1/tp": [3],
- "micro/f1/fp": [0],
- "micro/f1/tn": [9],
- "micro/f1/fn": [0],
- }
-
- metric_values_converted = {key: value.item() for key, value in metric_values.items()}
- assert metric_values_converted == {
- "macro/f1": 1.0,
- "micro/f1": 1.0,
- "no_relation/f1": 0.0,
- "org:founded_by/f1": 1.0,
- "per:employee_of/f1": 1.0,
- "per:founder/f1": 1.0,
- }
diff --git a/tests/taskmodules/test_re_text_classification_with_indices.py b/tests/taskmodules/test_re_text_classification_with_indices.py
deleted file mode 100644
index 0574e10e5..000000000
--- a/tests/taskmodules/test_re_text_classification_with_indices.py
+++ /dev/null
@@ -1,3159 +0,0 @@
-import dataclasses
-import logging
-import pickle
-import re
-from dataclasses import dataclass
-from typing import Any, Dict, List, Union
-
-import pytest
-import torch
-from pie_core import (
- Annotation,
- AnnotationLayer,
- Document,
- TaskEncoding,
- annotation_field,
-)
-from pie_core.utils.dictionary import flatten_dict_s
-from torch import tensor
-from torchmetrics import Metric, MetricCollection
-
-from pie_modules.annotations import BinaryRelation, LabeledSpan, NaryRelation
-from pie_modules.documents import (
- TextBasedDocument,
- TextDocumentWithLabeledSpansAndBinaryRelations,
-)
-from pie_modules.taskmodules import RETextClassificationWithIndicesTaskModule
-from pie_modules.taskmodules.re_text_classification_with_indices import (
- HEAD,
- TAIL,
- find_sublist,
- get_relation_argument_spans_and_roles,
- span_distance,
-)
-from pie_modules.utils.span import distance_inner
-from tests import _config_to_str
-from tests.conftest import _TABULATE_AVAILABLE, TestDocument
-
-CONFIGS = [
- {"add_type_to_marker": False, "append_markers": False},
- {"add_type_to_marker": True, "append_markers": False},
- {"add_type_to_marker": False, "append_markers": True},
- {"add_type_to_marker": True, "append_markers": True},
-]
-CONFIGS_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS}
-
-
-@pytest.fixture(scope="module", params=CONFIGS_DICT.keys())
-def cfg(request):
- return CONFIGS_DICT[request.param]
-
-
-def test_taskmodule_with_deprecated_parameters(caplog):
- with caplog.at_level(logging.WARNING):
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- tokenizer_name_or_path=tokenizer_name_or_path, label_to_id={"a": 0, "b": 1}
- )
- assert taskmodule.labels == ["a", "b"]
- # check the warning message
- assert len(caplog.records) == 1
- assert (
- caplog.records[0].message
- == "The parameter label_to_id is deprecated and will be removed in a future version. Please use labels instead."
- )
-
-
-@pytest.fixture(scope="module")
-def unprepared_taskmodule(cfg):
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations", tokenizer_name_or_path=tokenizer_name_or_path, **cfg
- )
- assert not taskmodule.is_from_pretrained
-
- return taskmodule
-
-
-@pytest.fixture(scope="module")
-def taskmodule(unprepared_taskmodule, documents):
- unprepared_taskmodule.prepare(documents)
- return unprepared_taskmodule
-
-
-@pytest.fixture
-def model_output():
- return {
- "labels": torch.tensor([1, 0, 2, 3, 1, 0, 0, 0]),
- "probabilities": torch.tensor(
- [
- # O, org:founded_by, per:employee_of, per:founder
- [0.1, 0.6, 0.1, 0.2],
- [0.5, 0.2, 0.2, 0.1],
- [0.1, 0.2, 0.6, 0.1],
- [0.1, 0.2, 0.2, 0.5],
- [0.2, 0.4, 0.3, 0.1],
- [0.5, 0.2, 0.2, 0.1],
- [0.6, 0.1, 0.2, 0.1],
- [0.5, 0.2, 0.2, 0.1],
- ]
- ),
- }
-
-
-def test_prepared_taskmodule(taskmodule, documents):
- assert taskmodule.is_prepared
-
- assert taskmodule.entity_labels == ["ORG", "PER"]
-
- if taskmodule.append_markers:
- if taskmodule.add_type_to_marker:
- assert taskmodule.argument_markers == [
- "[/H:ORG]",
- "[/H:PER]",
- "[/H]",
- "[/T:ORG]",
- "[/T:PER]",
- "[/T]",
- "[H:ORG]",
- "[H:PER]",
- "[H=ORG]",
- "[H=PER]",
- "[H]",
- "[T:ORG]",
- "[T:PER]",
- "[T=ORG]",
- "[T=PER]",
- "[T]",
- ]
- assert taskmodule.argument_markers_to_id == {
- "[/H:ORG]": 28996,
- "[/H:PER]": 28997,
- "[/H]": 28998,
- "[/T:ORG]": 28999,
- "[/T:PER]": 29000,
- "[/T]": 29001,
- "[H:ORG]": 29002,
- "[H:PER]": 29003,
- "[H=ORG]": 29004,
- "[H=PER]": 29005,
- "[H]": 29006,
- "[T:ORG]": 29007,
- "[T:PER]": 29008,
- "[T=ORG]": 29009,
- "[T=PER]": 29010,
- "[T]": 29011,
- }
-
- else:
- assert taskmodule.argument_markers == [
- "[/H]",
- "[/T]",
- "[H=ORG]",
- "[H=PER]",
- "[H]",
- "[T=ORG]",
- "[T=PER]",
- "[T]",
- ]
- assert taskmodule.argument_markers_to_id == {
- "[/H]": 28996,
- "[/T]": 28997,
- "[H=ORG]": 28998,
- "[H=PER]": 28999,
- "[H]": 29000,
- "[T=ORG]": 29001,
- "[T=PER]": 29002,
- "[T]": 29003,
- }
- else:
- if taskmodule.add_type_to_marker:
- assert taskmodule.argument_markers == [
- "[/H:ORG]",
- "[/H:PER]",
- "[/H]",
- "[/T:ORG]",
- "[/T:PER]",
- "[/T]",
- "[H:ORG]",
- "[H:PER]",
- "[H]",
- "[T:ORG]",
- "[T:PER]",
- "[T]",
- ]
- assert taskmodule.argument_markers_to_id == {
- "[/H:ORG]": 28996,
- "[/H:PER]": 28997,
- "[/H]": 28998,
- "[/T:ORG]": 28999,
- "[/T:PER]": 29000,
- "[/T]": 29001,
- "[H:ORG]": 29002,
- "[H:PER]": 29003,
- "[H]": 29004,
- "[T:ORG]": 29005,
- "[T:PER]": 29006,
- "[T]": 29007,
- }
- else:
- assert taskmodule.argument_markers == ["[/H]", "[/T]", "[H]", "[T]"]
- assert taskmodule.argument_markers_to_id == {
- "[/H]": 28996,
- "[/T]": 28997,
- "[H]": 28998,
- "[T]": 28999,
- }
-
- assert taskmodule.label_to_id == {
- "org:founded_by": 1,
- "per:employee_of": 2,
- "per:founder": 3,
- "no_relation": 0,
- }
- assert taskmodule.id_to_label == {
- 1: "org:founded_by",
- 2: "per:employee_of",
- 3: "per:founder",
- 0: "no_relation",
- }
-
-
-def test_config(taskmodule):
- config = taskmodule._config()
- assert config["taskmodule_type"] == "RETextClassificationWithIndicesTaskModule"
- assert taskmodule.PREPARED_ATTRIBUTES == ["labels", "entity_labels"]
- assert all(attribute in config for attribute in taskmodule.PREPARED_ATTRIBUTES)
- assert config["labels"] == ["org:founded_by", "per:employee_of", "per:founder"]
- assert config["entity_labels"] == ["ORG", "PER"]
-
-
-@pytest.mark.parametrize("encode_target", [False, True])
-def test_encode(taskmodule, documents, encode_target):
- task_encodings = taskmodule.encode(documents, encode_target=encode_target)
-
- assert len(task_encodings) == 7
-
- encoding = task_encodings[0]
-
- tokens = taskmodule.tokenizer.convert_ids_to_tokens(encoding.inputs["input_ids"])
- assert len(tokens) == len(encoding.inputs["input_ids"])
-
- if taskmodule.add_type_to_marker:
- assert tokens[:14] == [
- "[CLS]",
- "[H:PER]",
- "En",
- "##ti",
- "##ty",
- "A",
- "[/H:PER]",
- "works",
- "at",
- "[T:ORG]",
- "B",
- "[/T:ORG]",
- ".",
- "[SEP]",
- ]
- else:
- assert tokens[:14] == [
- "[CLS]",
- "[H]",
- "En",
- "##ti",
- "##ty",
- "A",
- "[/H]",
- "works",
- "at",
- "[T]",
- "B",
- "[/T]",
- ".",
- "[SEP]",
- ]
- if taskmodule.append_markers:
- assert len(tokens) == 14 + 4
- assert tokens[-4:] == ["[H=PER]", "[SEP]", "[T=ORG]", "[SEP]"]
- else:
- assert len(tokens) == 14
-
- if encode_target:
- assert encoding.targets == [2]
- else:
- assert not encoding.has_targets
-
- with pytest.raises(ValueError, match=re.escape("task encoding has no target")):
- encoding.targets
-
-
-@pytest.fixture(scope="module")
-def batch(taskmodule, documents):
- documents = [documents[i] for i in [0, 1, 4]]
- task_encodings = taskmodule.encode(documents, encode_target=True)
- return taskmodule.collate(task_encodings[:2])
-
-
-def test_collate(taskmodule, batch):
- inputs, targets = batch
-
- assert "input_ids" in inputs
- assert "attention_mask" in inputs
- assert inputs["input_ids"].shape == inputs["attention_mask"].shape
-
- if taskmodule.append_markers:
- assert inputs["input_ids"].shape == (2, 25)
- if taskmodule.add_type_to_marker:
- torch.testing.assert_close(
- inputs.input_ids,
- torch.tensor(
- [
- [
- 101,
- 29003,
- 13832,
- 3121,
- 2340,
- 138,
- 28997,
- 1759,
- 1120,
- 29007,
- 139,
- 28999,
- 119,
- 102,
- 29005,
- 102,
- 29009,
- 102,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- ],
- [
- 101,
- 1752,
- 5650,
- 119,
- 29003,
- 13832,
- 3121,
- 2340,
- 144,
- 28997,
- 1759,
- 1120,
- 29007,
- 145,
- 28999,
- 119,
- 1262,
- 1771,
- 146,
- 119,
- 102,
- 29005,
- 102,
- 29009,
- 102,
- ],
- ]
- ),
- )
- else:
- torch.testing.assert_close(
- inputs.input_ids,
- torch.tensor(
- [
- [
- 101,
- 29000,
- 13832,
- 3121,
- 2340,
- 138,
- 28996,
- 1759,
- 1120,
- 29003,
- 139,
- 28997,
- 119,
- 102,
- 28999,
- 102,
- 29001,
- 102,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- ],
- [
- 101,
- 1752,
- 5650,
- 119,
- 29000,
- 13832,
- 3121,
- 2340,
- 144,
- 28996,
- 1759,
- 1120,
- 29003,
- 145,
- 28997,
- 119,
- 1262,
- 1771,
- 146,
- 119,
- 102,
- 28999,
- 102,
- 29001,
- 102,
- ],
- ]
- ),
- )
- torch.testing.assert_close(
- inputs.attention_mask,
- torch.tensor(
- [
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- ]
- ),
- )
-
- else:
- assert inputs["input_ids"].shape == (2, 21)
-
- if taskmodule.add_type_to_marker:
- torch.testing.assert_close(
- inputs.input_ids,
- torch.tensor(
- [
- [
- 101,
- 29003,
- 13832,
- 3121,
- 2340,
- 138,
- 28997,
- 1759,
- 1120,
- 29005,
- 139,
- 28999,
- 119,
- 102,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- ],
- [
- 101,
- 1752,
- 5650,
- 119,
- 29003,
- 13832,
- 3121,
- 2340,
- 144,
- 28997,
- 1759,
- 1120,
- 29005,
- 145,
- 28999,
- 119,
- 1262,
- 1771,
- 146,
- 119,
- 102,
- ],
- ]
- ),
- )
- else:
- torch.testing.assert_close(
- inputs.input_ids,
- torch.tensor(
- [
- [
- 101,
- 28998,
- 13832,
- 3121,
- 2340,
- 138,
- 28996,
- 1759,
- 1120,
- 28999,
- 139,
- 28997,
- 119,
- 102,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- 0,
- ],
- [
- 101,
- 1752,
- 5650,
- 119,
- 28998,
- 13832,
- 3121,
- 2340,
- 144,
- 28996,
- 1759,
- 1120,
- 28999,
- 145,
- 28997,
- 119,
- 1262,
- 1771,
- 146,
- 119,
- 102,
- ],
- ]
- ),
- )
- torch.testing.assert_close(
- inputs.attention_mask,
- torch.tensor(
- [
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- ]
- ),
- )
-
- assert set(targets) == {"labels"}
- torch.testing.assert_close(targets["labels"], torch.tensor([2, 2]))
-
-
-def test_unbatch_output(taskmodule, model_output):
- unbatched_outputs = taskmodule.unbatch_output(model_output)
-
- assert len(unbatched_outputs) == 8
-
- labels = [
- "org:founded_by",
- "no_relation",
- "per:employee_of",
- "per:founder",
- "org:founded_by",
- "no_relation",
- "no_relation",
- "no_relation",
- ]
- probabilities = [0.6, 0.5, 0.6, 0.5, 0.4, 0.5, 0.6, 0.5]
-
- for output, label, probability in zip(unbatched_outputs, labels, probabilities):
- assert set(output.keys()) == {"labels", "probabilities"}
- assert output["labels"] == [label]
- assert output["probabilities"] == pytest.approx([probability])
-
-
-@pytest.mark.parametrize("inplace", [False, True])
-def test_decode(taskmodule, documents, model_output, inplace):
- # copy the documents, because the taskmodule may modify them
- documents = [documents[i].copy() for i in [0, 1, 4]]
-
- encodings = taskmodule.encode(documents, encode_target=False)
- unbatched_outputs = taskmodule.unbatch_output(model_output)
- decoded_documents = taskmodule.decode(
- task_encodings=encodings,
- task_outputs=unbatched_outputs,
- inplace=inplace,
- )
-
- assert len(decoded_documents) == len(documents)
-
- if inplace:
- assert {id(doc) for doc in decoded_documents} == {id(doc) for doc in documents}
- else:
- assert {id(doc) for doc in decoded_documents}.isdisjoint({id(doc) for doc in documents})
-
- expected_scores = [0.6, 0.5, 0.6, 0.5, 0.4, 0.5, 0.6, 0.5]
- i = 0
- for document in decoded_documents:
- for relation_expected, relation_decoded in zip(
- document["entities"], document["entities"].predictions
- ):
- assert relation_expected.start == relation_decoded.start
- assert relation_expected.end == relation_decoded.end
- assert relation_expected.label == relation_decoded.label
- assert expected_scores[i] == pytest.approx(relation_decoded.score)
- i += 1
-
- if not inplace:
- for document in documents:
- assert not document["relations"].predictions
-
-
-def test_encode_with_partition(documents):
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- partition_annotation="sentences",
- )
- assert not taskmodule.is_from_pretrained
- taskmodule.prepare(documents)
-
- assert len(documents) == 7
- encodings = taskmodule.encode(documents)
- tokens = [
- taskmodule.tokenizer.convert_ids_to_tokens(encoding.inputs["input_ids"])
- for encoding in encodings
- ]
- assert len(encodings) == 5
- assert encodings[0].document != encodings[1].document
- assert encodings[1].document != encodings[2].document
- # the last document contains 3 valid relations
- assert encodings[2].document == encodings[3].document
- assert encodings[3].document == encodings[4].document
- assert tokens[0] == [
- "[CLS]",
- "[H]",
- "En",
- "##ti",
- "##ty",
- "A",
- "[/H]",
- "works",
- "at",
- "[T]",
- "B",
- "[/T]",
- ".",
- "[SEP]",
- ]
- assert tokens[1] == [
- "[CLS]",
- "[H]",
- "En",
- "##ti",
- "##ty",
- "G",
- "[/H]",
- "works",
- "at",
- "[T]",
- "H",
- "[/T]",
- ".",
- "[SEP]",
- ]
- assert tokens[2] == [
- "[CLS]",
- "[H]",
- "En",
- "##ti",
- "##ty",
- "M",
- "[/H]",
- "works",
- "at",
- "[T]",
- "N",
- "[/T]",
- ".",
- "[SEP]",
- ]
- assert tokens[3] == [
- "[CLS]",
- "And",
- "[H]",
- "it",
- "[/H]",
- "founded",
- "[T]",
- "O",
- "[/T]",
- "[SEP]",
- ]
- assert tokens[4] == [
- "[CLS]",
- "And",
- "[T]",
- "it",
- "[/T]",
- "founded",
- "[H]",
- "O",
- "[/H]",
- "[SEP]",
- ]
-
-
-def test_encode_with_windowing(documents):
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- max_window=12,
- )
- assert not taskmodule.is_from_pretrained
- taskmodule.prepare(documents)
-
- assert len(documents) == 7
- encodings = taskmodule.encode(documents)
- assert len(encodings) == 3
- for encoding in encodings:
- assert len(encoding.inputs["input_ids"]) <= taskmodule.max_window
- tokens = [
- taskmodule.tokenizer.convert_ids_to_tokens(encoding.inputs["input_ids"])
- for encoding in encodings
- ]
- assert tokens[0] == [
- "[CLS]",
- "at",
- "[T]",
- "H",
- "[/T]",
- ".",
- "And",
- "founded",
- "[H]",
- "I",
- "[/H]",
- "[SEP]",
- ]
- assert tokens[1] == [
- "[CLS]",
- ".",
- "And",
- "[H]",
- "it",
- "[/H]",
- "founded",
- "[T]",
- "O",
- "[/T]",
- ".",
- "[SEP]",
- ]
- assert tokens[2] == [
- "[CLS]",
- ".",
- "And",
- "[T]",
- "it",
- "[/T]",
- "founded",
- "[H]",
- "O",
- "[/H]",
- ".",
- "[SEP]",
- ]
-
-
-def test_encode_with_allow_discontinuous_text(documents):
- tokenizer_name_or_path = "bert-base-cased"
- # tokenizer_name_or_path = "allenai/longformer-scico"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- max_window=12,
- allow_discontinuous_text=True,
- )
-
- assert not taskmodule.is_from_pretrained
- taskmodule.prepare(documents)
-
- assert len(documents) == 7
- encodings = taskmodule.encode(documents)
- assert len(encodings) == 3
-
- for encoding in encodings:
- assert len(encoding.inputs["input_ids"]) <= taskmodule.max_window
- tokens = [
- taskmodule.tokenizer.convert_ids_to_tokens(encoding.inputs["input_ids"])
- for encoding in encodings
- ]
- assert tokens == [
- ["[CLS]", "at", "[T]", "H", "[/T]", "[SEP]", "founded", "[H]", "I", "[/H]", "[SEP]"],
- ["[CLS]", "And", "[H]", "it", "[/H]", "founded", "[T]", "O", "[/T]", "[SEP]"],
- ["[CLS]", "And", "[T]", "it", "[/T]", "founded", "[H]", "O", "[/H]", "[SEP]"],
- ]
-
-
-def test_encode_with_allow_discontinuous_text_and_binary_relations():
- """This checks whether relation arguments at the very beginning or end of the document are
- encoded correctly.
-
- Also, it checks whether the encoding of the consecutive spans that fit within the frame
- specified by max_window is correct.
- """
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- tokenizer_name_or_path=tokenizer_name_or_path,
- max_window=128,
- allow_discontinuous_text=True,
- )
- texts = [
- "Loren ipsun dolor sit anet, consectetur adipisci elit, sed eiusnod tenpor incidunt ut labore et dolore nagna aliqua.",
- "Ut enin ad ninin venian, quis nostrun exercitationen ullan corporis suscipit laboriosan, nisi ut aliquid ex ea connodi consequatur.",
- "Quis aute iure reprehenderit in voluptate velit esse cillun dolore eu fugiat nulla pariatur.",
- "Excepteur sint obcaecat cupiditat non proident, sunt in culpa qui officia deserunt nollit anin id est laborun.",
- ]
- text_lengths = [len(text) for text in texts]
- sep = " "
-
- doc = TextDocumentWithLabeledSpansAndBinaryRelations(
- text=sep.join(texts),
- id="123",
- )
-
- labeled_spans = []
- offset = 0
- for i, text in enumerate(texts):
- labeled_spans.append(
- LabeledSpan(start=0 + offset, end=text_lengths[i] + offset, label="sentence")
- )
- offset += text_lengths[i] + len(sep)
-
- for span in labeled_spans:
- doc.labeled_spans.append(span)
- assert doc.labeled_spans.resolve() == [
- (
- "sentence",
- "Loren ipsun dolor sit anet, consectetur adipisci elit, sed eiusnod tenpor incidunt ut "
- "labore et dolore nagna aliqua.",
- ),
- (
- "sentence",
- "Ut enin ad ninin venian, quis nostrun exercitationen ullan corporis suscipit laboriosan, "
- "nisi ut aliquid ex ea connodi consequatur.",
- ),
- (
- "sentence",
- "Quis aute iure reprehenderit in voluptate velit esse cillun dolore eu fugiat nulla pariatur.",
- ),
- (
- "sentence",
- "Excepteur sint obcaecat cupiditat non proident, sunt in culpa qui officia deserunt nollit "
- "anin id est laborun.",
- ),
- ]
-
- rel_start = BinaryRelation(
- head=doc.labeled_spans[0], tail=doc.labeled_spans[2], label="relation", score=1.0
- )
- doc.binary_relations.append(rel_start)
- rel_end = BinaryRelation(
- head=doc.labeled_spans[-1], tail=doc.labeled_spans[0], label="relation", score=1.0
- )
- doc.binary_relations.append(rel_end)
- rel_consecutive = BinaryRelation(
- head=doc.labeled_spans[2], tail=doc.labeled_spans[3], label="relation", score=1.0
- )
- doc.binary_relations.append(rel_consecutive)
-
- # test document where everything is already included in one argument frame
- doc2 = TextDocumentWithLabeledSpansAndBinaryRelations("A founded B.", id="123")
- doc2.labeled_spans.append(LabeledSpan(start=0, end=1, label="PER"))
- doc2.labeled_spans.append(LabeledSpan(start=10, end=11, label="PER"))
- assert doc2.labeled_spans.resolve() == [("PER", "A"), ("PER", "B")]
- rel = BinaryRelation(head=doc2.labeled_spans[0], tail=doc2.labeled_spans[1], label="relation")
- doc2.binary_relations.append(rel)
-
- taskmodule.prepare([doc, doc2])
- encoded = taskmodule.encode_input(doc)
-
- decoded_arg_start = taskmodule.tokenizer.decode(encoded[0].inputs["input_ids"])
- decoded_arg_end = taskmodule.tokenizer.decode(encoded[1].inputs["input_ids"])
- decoded_arg_consecutive = taskmodule.tokenizer.decode(encoded[2].inputs["input_ids"])
-
- assert (
- decoded_arg_start
- == "[CLS] [H] Loren ipsun dolor sit anet, consectetur adipisci elit, sed eiusnod tenpor incidunt ut labore et dolore nagna aliqua. [/H] Ut enin ad ninin venian, quis no [SEP] ex ea connodi consequatur. [T] Quis aute iure reprehenderit in voluptate velit esse cillun dolore eu fugiat nulla pariatur. [/T] Excepteur sint obcaecat cupid [SEP]"
- )
-
- assert (
- decoded_arg_end
- == "[CLS] [T] Loren ipsun dolor sit anet, consectetur adipisci elit, sed eiusnod tenpor incidunt ut labore et dolore nagna aliqua. [/T] Ut enin ad ninin venian, quis no [SEP]se cillun dolore eu fugiat nulla pariatur. [H] Excepteur sint obcaecat cupiditat non proident, sunt in culpa qui officia deserunt nollit anin id est laborun. [/H] [SEP]"
- )
-
- assert (
- decoded_arg_consecutive
- == "[CLS] ex ea connodi consequatur. [H] Quis aute iure reprehenderit in voluptate velit esse cillun dolore eu fugiat nulla pariatur. [/H] [T] Excepteur sint obcaecat cupiditat non proident, sunt in culpa qui officia deserunt nollit anin id est laborun. [/T] [SEP]"
- )
-
- encoded2 = taskmodule.encode_input(doc2)
- assert len(encoded2) == 1
- decoded2 = taskmodule.tokenizer.decode(encoded2[0].inputs["input_ids"])
- assert decoded2 == "[CLS] [H] A [/H] founded [T] B [/T]. [SEP]"
-
-
-def get_arg_token_span(
- tokens: List[str],
- start_indices: List[int],
- end_indices: List[int],
- argument_role2idx: Dict[str, int],
-) -> Dict[str, List[str]]:
- return {
- role: tokens[start_indices[argument_role2idx[role]] : end_indices[argument_role2idx[role]]]
- for role, idx in argument_role2idx.items()
- }
-
-
-def test_encode_with_add_argument_indices(documents):
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- add_argument_indices_to_input=True,
- )
-
- assert not taskmodule.is_from_pretrained
- taskmodule.prepare(documents)
-
- encodings = taskmodule.encode(documents, encode_target=True)
- assert len(encodings) == 7
- batch = taskmodule.collate(encodings)
- inputs, targets = batch
- tokens = [
- taskmodule.tokenizer.convert_ids_to_tokens(input_ids) for input_ids in inputs["input_ids"]
- ]
-
- arg_spans = [
- get_arg_token_span(
- current_tokens,
- current_start_indices,
- current_end_indices,
- taskmodule.argument_role2idx,
- )
- for current_tokens, current_start_indices, current_end_indices in zip(
- tokens, inputs["pooler_start_indices"].tolist(), inputs["pooler_end_indices"].tolist()
- )
- ]
-
- assert arg_spans == [
- {"head": ["En", "##ti", "##ty", "A"], "tail": ["B"]},
- {"head": ["En", "##ti", "##ty", "G"], "tail": ["H"]},
- {"head": ["En", "##ti", "##ty", "G"], "tail": ["I"]},
- {"head": ["I"], "tail": ["H"]},
- {"head": ["En", "##ti", "##ty", "M"], "tail": ["N"]},
- {"head": ["it"], "tail": ["O"]},
- {"head": ["O"], "tail": ["it"]},
- ]
-
-
-def test_encode_with_add_argument_indices_and_without_insert_markers(documents):
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- add_argument_indices_to_input=True,
- insert_markers=False,
- )
-
- assert not taskmodule.is_from_pretrained
- taskmodule.prepare(documents)
-
- encodings = taskmodule.encode(documents, encode_target=True)
- assert len(encodings) == 7
- batch = taskmodule.collate(encodings)
- inputs, targets = batch
- tokens = [
- taskmodule.tokenizer.convert_ids_to_tokens(input_ids) for input_ids in inputs["input_ids"]
- ]
-
- arg_spans = [
- get_arg_token_span(
- current_tokens,
- current_start_indices,
- current_end_indices,
- taskmodule.argument_role2idx,
- )
- for current_tokens, current_start_indices, current_end_indices in zip(
- tokens, inputs["pooler_start_indices"].tolist(), inputs["pooler_end_indices"].tolist()
- )
- ]
-
- assert arg_spans == [
- {"head": ["En", "##ti", "##ty", "A"], "tail": ["B"]},
- {"head": ["En", "##ti", "##ty", "G"], "tail": ["H"]},
- {"head": ["En", "##ti", "##ty", "G"], "tail": ["I"]},
- {"head": ["I"], "tail": ["H"]},
- {"head": ["En", "##ti", "##ty", "M"], "tail": ["N"]},
- {"head": ["it"], "tail": ["O"]},
- {"head": ["O"], "tail": ["it"]},
- ]
-
-
-def test_find_sublist():
- # default case
- assert find_sublist(sub=[2, 3], bigger=[1, 2, 3, 4]) == 1
- # no sublist
- assert find_sublist(sub=[2, 3], bigger=[1, 3, 2, 4]) == -1
- # empty sublist: occurs on every position, but first is returned
- assert find_sublist(sub=[], bigger=[1, 3, 2, 4]) == 0
- # empty bigger
- assert find_sublist(sub=[2, 3], bigger=[]) == -1
-
-
-def test_encode_with_add_argument_indices_and_windowing(documents):
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- add_argument_indices_to_input=True,
- max_window=12,
- )
-
- assert not taskmodule.is_from_pretrained
- taskmodule.prepare(documents)
-
- encodings = taskmodule.encode(documents, encode_target=True)
- assert len(encodings) == 3
- batch = taskmodule.collate(encodings)
- inputs, targets = batch
- tokens = [
- taskmodule.tokenizer.convert_ids_to_tokens(input_ids) for input_ids in inputs["input_ids"]
- ]
-
- arg_spans = [
- get_arg_token_span(
- current_tokens,
- current_start_indices,
- current_end_indices,
- taskmodule.argument_role2idx,
- )
- for current_tokens, current_start_indices, current_end_indices in zip(
- tokens, inputs["pooler_start_indices"].tolist(), inputs["pooler_end_indices"].tolist()
- )
- ]
-
- assert arg_spans == [
- {"head": ["I"], "tail": ["H"]},
- {"head": ["it"], "tail": ["O"]},
- {"head": ["O"], "tail": ["it"]},
- ]
-
-
-def test_encode_with_add_argument_indices_windowing_and_without_insert_markers(documents):
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- add_argument_indices_to_input=True,
- max_window=8,
- insert_markers=False,
- )
-
- assert not taskmodule.is_from_pretrained
- taskmodule.prepare(documents)
-
- encodings = taskmodule.encode(documents, encode_target=True)
- assert len(encodings) == 3
- batch = taskmodule.collate(encodings)
- inputs, targets = batch
- tokens = [
- taskmodule.tokenizer.convert_ids_to_tokens(input_ids) for input_ids in inputs["input_ids"]
- ]
-
- arg_spans = [
- get_arg_token_span(
- current_tokens,
- current_start_indices,
- current_end_indices,
- taskmodule.argument_role2idx,
- )
- for current_tokens, current_start_indices, current_end_indices in zip(
- tokens, inputs["pooler_start_indices"].tolist(), inputs["pooler_end_indices"].tolist()
- )
- ]
-
- assert arg_spans == [
- {"head": ["I"], "tail": ["H"]},
- {"head": ["it"], "tail": ["O"]},
- {"head": ["O"], "tail": ["it"]},
- ]
-
-
-@pytest.mark.parametrize("handle_relations_with_same_arguments", ["keep_first", "keep_none"])
-@pytest.mark.parametrize("add_candidate_relations", [False, True])
-@pytest.mark.parametrize("collect_statistics", [False, True])
-def test_encode_input_multiple_relations_for_same_arguments(
- caplog, handle_relations_with_same_arguments, add_candidate_relations, collect_statistics
-):
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path="bert-base-cased",
- handle_relations_with_same_arguments=handle_relations_with_same_arguments,
- collect_statistics=collect_statistics,
- add_candidate_relations=add_candidate_relations,
- )
- document = TestDocument(text="A founded B.", id="test_doc")
- document.entities.append(LabeledSpan(start=0, end=1, label="PER"))
- document.entities.append(LabeledSpan(start=10, end=11, label="PER"))
- entities = document.entities
- assert str(entities[0]) == "A"
- assert str(entities[1]) == "B"
- document.relations.extend(
- [
- BinaryRelation(head=entities[0], tail=entities[1], label="per:founded_by"),
- BinaryRelation(head=entities[0], tail=entities[1], label="per:founder"),
- BinaryRelation(head=entities[0], tail=entities[1], label="per:founded_by"),
- ]
- )
- taskmodule.prepare([document])
-
- with caplog.at_level(logging.WARNING):
- encodings = taskmodule.encode_input(document)
-
- statistics = taskmodule.get_statistics()
- candidate_relation = [enc.metadata["candidate_annotation"] for enc in encodings]
- candidate_relation_tuples = [
- (rel.head.resolve(), rel.label, rel.tail.resolve()) for rel in candidate_relation
- ]
-
- if handle_relations_with_same_arguments == "keep_first":
- # Note: Warnings are shown only if statistics are disabled. For details see comment at
- # src/pie_modules/taskmodules/re_text_classification_with_indices.py:811-818
- expected_warning = (
- "doc.id=test_doc: there are multiple relations with the same arguments "
- "(('head', ('PER', 'A')), ('tail', ('PER', 'B'))), but different labels: "
- "['per:founded_by', 'per:founder', 'per:founded_by']. We only keep the first "
- "occurring relation which has the label='per:founded_by'."
- )
- if not add_candidate_relations:
- # with 'keep_first', only first relation occurred is kept ('per:founded_by').
- # full duplicate of 'per:founded_by' is removed and appears neither as available,
- # nor as skipped in statistics.
- assert candidate_relation_tuples == [(("PER", "A"), "per:founded_by", ("PER", "B"))]
- if collect_statistics:
- assert statistics == {
- ("available", "per:founded_by"): 1,
- ("available", "per:founder"): 1,
- ("skipped_same_arguments", "per:founder"): 1,
- ("used", "per:founded_by"): 1,
- }
- assert caplog.messages == []
- else:
- assert statistics == {}
- assert caplog.messages == [expected_warning]
-
- else:
- # as above, but with candidate (negative) relations added
- assert candidate_relation_tuples == [
- (("PER", "A"), "per:founded_by", ("PER", "B")),
- (("PER", "B"), "no_relation", ("PER", "A")),
- ]
- if collect_statistics:
- assert statistics == {
- ("available", "per:founded_by"): 1,
- ("available", "per:founder"): 1,
- ("used", "no_relation"): 1,
- ("used", "per:founded_by"): 1,
- ("skipped_same_arguments", "per:founder"): 1,
- }
- assert caplog.messages == []
- else:
- assert statistics == {}
- assert caplog.messages == [expected_warning]
-
- elif handle_relations_with_same_arguments == "keep_none":
- # Note: Warnings are shown only if statistics are disabled. For details see comment at
- # src/pie_modules/taskmodules/re_text_classification_with_indices.py:811-818
- expected_warning = (
- "doc.id=test_doc: there are multiple relations with the same arguments "
- "(('head', ('PER', 'A')), ('tail', ('PER', 'B'))), but different labels: "
- "['per:founded_by', 'per:founder', 'per:founded_by']. All relations will be removed."
- )
- if not add_candidate_relations:
- # with 'keep_none' both relations sharing same arguments are removed
- # full duplicate of 'per:founded_by' is removed and appears neither as available,
- # nor as skipped in statistics.
- assert candidate_relation_tuples == []
- if collect_statistics:
- assert statistics == {
- ("available", "per:founded_by"): 1,
- ("available", "per:founder"): 1,
- ("skipped_same_arguments", "per:founder"): 1,
- ("skipped_same_arguments", "per:founded_by"): 1,
- }
- assert caplog.messages == []
- else:
- assert statistics == {}
- assert caplog.messages == [expected_warning]
- else:
- # all conflicting relations go into the same direction, so we can create a candidate (negative)
- # relation for the other direction.
- assert candidate_relation_tuples == [(("PER", "B"), "no_relation", ("PER", "A"))]
- if collect_statistics:
- assert statistics == {
- ("available", "per:founded_by"): 1,
- ("available", "per:founder"): 1,
- ("skipped_same_arguments", "per:founded_by"): 1,
- ("skipped_same_arguments", "per:founder"): 1,
- ("used", "no_relation"): 1,
- }
- assert caplog.messages == []
- else:
- assert statistics == {}
- assert caplog.messages == [expected_warning]
-
-
-def test_encode_input_handle_relations_with_same_arguments_unknown_value():
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path="bert-base-cased",
- handle_relations_with_same_arguments="unknown_value",
- )
- document = TestDocument(text="A founded B.", id="test_doc")
- document.entities.append(LabeledSpan(start=0, end=1, label="PER"))
- document.entities.append(LabeledSpan(start=10, end=11, label="PER"))
- document.relations.append(
- BinaryRelation(
- head=document.entities[0], tail=document.entities[1], label="per:founded_by"
- )
- )
- document.relations.append(
- BinaryRelation(head=document.entities[0], tail=document.entities[1], label="per:founder")
- )
- assert document.relations.resolve() == [
- ("per:founded_by", (("PER", "A"), ("PER", "B"))),
- ("per:founder", (("PER", "A"), ("PER", "B"))),
- ]
- taskmodule.prepare([document])
-
- with pytest.raises(ValueError) as excinfo:
- taskmodule.encode_input(document)
- assert str(excinfo.value) == (
- "'handle_relations_with_same_arguments' must be 'keep_first' or 'keep_none', but got `unknown_value`."
- )
-
-
-@pytest.mark.parametrize("handle_relations_with_same_arguments", ["keep_first", "keep_none"])
-@pytest.mark.parametrize("add_candidate_relations", [False, True])
-@pytest.mark.parametrize("collect_statistics", [False, True])
-def test_encode_input_duplicated_relations(
- caplog, handle_relations_with_same_arguments, add_candidate_relations, collect_statistics
-):
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path="bert-base-cased",
- handle_relations_with_same_arguments=handle_relations_with_same_arguments,
- add_candidate_relations=add_candidate_relations,
- collect_statistics=collect_statistics,
- )
- document = TestDocument(text="A founded B.", id="test_doc")
- document.entities.append(LabeledSpan(start=0, end=1, label="PER"))
- document.entities.append(LabeledSpan(start=10, end=11, label="PER"))
- entities = document.entities
- assert str(entities[0]) == "A"
- assert str(entities[1]) == "B"
- document.relations.extend(
- [
- BinaryRelation(head=entities[0], tail=entities[1], label="per:founded_by"),
- BinaryRelation(head=entities[0], tail=entities[1], label="per:founded_by"),
- ]
- )
- taskmodule.prepare([document])
- with caplog.at_level(logging.WARNING):
- encodings = taskmodule.encode_input(document)
-
- statistics = taskmodule.get_statistics()
-
- assert len(caplog.messages) == 1
- assert (
- caplog.messages[0] == "doc.id=test_doc: Relation annotation "
- "`('per:founded_by', (('PER', 'A'), ('PER', 'B')))` is duplicated. We keep "
- "only one of them. Duplicate won't appear in statistics either as 'available' or as skipped."
- )
- candidate_relation = [enc.metadata["candidate_annotation"] for enc in encodings]
- candidate_relation_tuples = [
- (rel.head.resolve(), rel.label, rel.tail.resolve()) for rel in candidate_relation
- ]
- # equally for 'keep_first' and 'keep_last', full duplicates are not affected and do not appear in statistics, but still
- # generate a warning.
- if add_candidate_relations:
- assert candidate_relation_tuples == [
- (("PER", "A"), "per:founded_by", ("PER", "B")),
- (("PER", "B"), "no_relation", ("PER", "A")),
- ]
- if collect_statistics:
- assert statistics == {
- ("available", "per:founded_by"): 1,
- ("used", "no_relation"): 1,
- ("used", "per:founded_by"): 1,
- }
- else:
- assert statistics == {}
- else:
- assert candidate_relation_tuples == [(("PER", "A"), "per:founded_by", ("PER", "B"))]
- if collect_statistics:
- assert statistics == {
- ("available", "per:founded_by"): 1,
- ("used", "per:founded_by"): 1,
- }
- else:
- assert statistics == {}
-
-
-def test_encode_input_argument_role_unknown(documents):
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path="bert-base-cased",
- # the tail argument is not in the role_to_marker
- argument_role_to_marker={HEAD: "H"},
- )
- taskmodule.prepare(documents)
- with pytest.raises(ValueError) as excinfo:
- taskmodule.encode_input(documents[1])
- assert (
- str(excinfo.value) == "role='tail' not in known roles=['head'] (did you initialise the "
- "taskmodule with the correct argument_role_to_marker dictionary?)"
- )
-
-
-def test_encode_input_with_add_candidate_relations(documents):
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path="bert-base-cased",
- add_candidate_relations=True,
- )
- taskmodule.prepare(documents)
- documents_without_relations = []
- encodings = []
- # just take the first three documents
- for doc in documents[:3]:
- doc_without_relations = doc.copy()
- relations = list(doc_without_relations.relations)
- doc_without_relations.relations.clear()
- # re-add one relation to test if it is kept
- if len(relations) > 0:
- doc_without_relations.relations.append(relations[0])
- documents_without_relations.append(doc_without_relations)
- encodings.extend(taskmodule.encode(doc_without_relations))
-
- assert len(encodings) == 4
- relations = [encoding.metadata["candidate_annotation"] for encoding in encodings]
- texts = [encoding.document.text for encoding in encodings]
- relation_tuples = [(str(rel.head), rel.label, str(rel.tail)) for rel in relations]
-
- # There are no entities in the first document, so there are no created relation candidates
-
- # this relation was kept
- assert texts[0] == "Entity A works at B."
- assert relation_tuples[0] == ("Entity A", "per:employee_of", "B")
-
- # the following relations were added
- assert texts[1] == "Entity A works at B."
- assert relation_tuples[1] == ("B", "no_relation", "Entity A")
- assert texts[2] == "Entity C and D."
- assert relation_tuples[2] == ("Entity C", "no_relation", "D")
- assert texts[3] == "Entity C and D."
- assert relation_tuples[3] == ("D", "no_relation", "Entity C")
-
-
-@pytest.fixture
-def document_with_nary_relations():
- @dataclasses.dataclass
- class TestDocumentWithNaryRelations(TextBasedDocument):
- entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
- relations: AnnotationLayer[NaryRelation] = annotation_field(target="entities")
-
- document = TestDocumentWithNaryRelations(
- text="Entity A works at B.", id="doc_with_nary_relations"
- )
- document.entities.append(LabeledSpan(start=0, end=8, label="PER"))
- document.entities.append(LabeledSpan(start=18, end=19, label="PER"))
- document.relations.append(
- NaryRelation(
- arguments=tuple(document.entities),
- roles=tuple(["head", "tail"]),
- label="per:employee_of",
- )
- )
- return document
-
-
-def test_encode_input_with_add_candidate_relations_with_wrong_relation_type(
- document_with_nary_relations,
-):
- doc = document_with_nary_relations
-
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path="bert-base-cased",
- add_candidate_relations=True,
- argument_role_to_marker={HEAD: "H", "arg2": "T"},
- )
- taskmodule.prepare([doc])
- with pytest.raises(NotImplementedError) as excinfo:
- taskmodule.encode_input(doc)
- assert (
- str(excinfo.value)
- == "doc.id=doc_with_nary_relations: the taskmodule does not yet support adding relation candidates "
- "with argument roles other than 'head' and 'tail': ['arg2', 'head']"
- )
-
-
-def test_filter_relations_by_argument_type_whitelist(documents):
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path="bert-base-cased",
- add_candidate_relations=True,
- argument_type_whitelist=[["PER", "ORG"], ["ORG", "PER"]],
- )
- doc = documents[4]
- taskmodule.prepare(documents)
-
- assert doc.entities.resolve() == [("PER", "Entity G"), ("ORG", "H"), ("ORG", "I")]
- assert doc.relations.resolve() == [
- ("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))),
- ("per:founder", (("PER", "Entity G"), ("ORG", "I"))),
- ("org:founded_by", (("ORG", "I"), ("ORG", "H"))),
- ]
- arguments2relation = {}
- for rel in doc.relations:
- arguments2relation[get_relation_argument_spans_and_roles(rel)] = rel
- assert len(arguments2relation) == 3
-
- taskmodule._filter_relations_by_argument_type_whitelist(arguments2relation=arguments2relation)
- assert len(arguments2relation) == 2
-
- relation_tuples = [rel.resolve() for rel in arguments2relation.values()]
- assert relation_tuples[0] == ("per:employee_of", (("PER", "Entity G"), ("ORG", "H")))
- assert relation_tuples[1] == ("per:founder", (("PER", "Entity G"), ("ORG", "I")))
-
- assert ("org:founded_by", (("ORG", "I"), ("ORG", "H"))) not in relation_tuples
-
-
-def test_add_candidate_relations_with_argument_type_whitelist(documents):
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path="bert-base-cased",
- add_candidate_relations=True,
- argument_type_whitelist=[["PER", "ORG"], ["ORG", "PER"]],
- )
- doc = documents[4]
- taskmodule.prepare(documents)
-
- assert doc.entities.resolve() == [("PER", "Entity G"), ("ORG", "H"), ("ORG", "I")]
- assert doc.relations.resolve() == [
- ("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))),
- ("per:founder", (("PER", "Entity G"), ("ORG", "I"))),
- ("org:founded_by", (("ORG", "I"), ("ORG", "H"))),
- ]
- arguments2relation = {}
- for rel in doc.relations:
- arguments2relation[get_relation_argument_spans_and_roles(rel)] = rel
- assert len(arguments2relation) == 3
-
- taskmodule._add_candidate_relations(
- arguments2relation=arguments2relation, entities=doc.entities
- )
- assert len(arguments2relation) == 5
-
- relation_tuples = [rel.resolve() for rel in arguments2relation.values()]
-
- # Original relations from document (aren't affected by whitelist)
- assert relation_tuples[0] == ("per:employee_of", (("PER", "Entity G"), ("ORG", "H")))
- assert relation_tuples[1] == ("per:founder", (("PER", "Entity G"), ("ORG", "I")))
- assert relation_tuples[2] == ("org:founded_by", (("ORG", "I"), ("ORG", "H")))
-
- # Relation candidate added by _add_candidate_relations()
- assert relation_tuples[3] == ("no_relation", (("ORG", "H"), ("PER", "Entity G")))
- assert relation_tuples[4] == ("no_relation", (("ORG", "I"), ("PER", "Entity G")))
-
- # Relations not created due to whitelist
- assert ("no_relation", (("ORG", "H"), ("ORG", "I"))) not in relation_tuples
-
-
-def test_filter_relations_by_argument_and_relation_type_whitelist(documents):
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path="bert-base-cased",
- add_candidate_relations=True,
- argument_and_relation_type_whitelist={
- "per:employee_of": [["PER", "ORG"]],
- "per:founder": [["PER", "ORG"]],
- "org:founded_by": [["ORG", "PER"]],
- },
- )
- doc = documents[4]
- taskmodule.prepare(documents)
-
- assert doc.entities.resolve() == [("PER", "Entity G"), ("ORG", "H"), ("ORG", "I")]
- assert doc.relations.resolve() == [
- ("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))),
- ("per:founder", (("PER", "Entity G"), ("ORG", "I"))),
- ("org:founded_by", (("ORG", "I"), ("ORG", "H"))),
- ]
- arguments2relation = {}
- for rel in doc.relations:
- arguments2relation[get_relation_argument_spans_and_roles(rel)] = rel
- assert len(arguments2relation) == 3
-
- taskmodule._filter_relations_by_argument_and_relation_type_whitelist(
- arguments2relation=arguments2relation
- )
- assert len(arguments2relation) == 2
-
- relation_tuples = [rel.resolve() for rel in arguments2relation.values()]
- assert relation_tuples[0] == ("per:employee_of", (("PER", "Entity G"), ("ORG", "H")))
- assert relation_tuples[1] == ("per:founder", (("PER", "Entity G"), ("ORG", "I")))
-
- assert ("org:founded_by", (("ORG", "I"), ("ORG", "H"))) not in relation_tuples
-
-
-def test_add_candidate_relations_with_argument_and_relation_type_whitelist(documents):
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path="bert-base-cased",
- add_candidate_relations=True,
- argument_and_relation_type_whitelist={
- "per:employee_of": [["PER", "ORG"]],
- "per:founder": [["PER", "ORG"]],
- "org:founded_by": [["ORG", "PER"]],
- },
- )
- doc = documents[4]
- taskmodule.prepare(documents)
-
- assert doc.entities.resolve() == [("PER", "Entity G"), ("ORG", "H"), ("ORG", "I")]
- assert doc.relations.resolve() == [
- ("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))),
- ("per:founder", (("PER", "Entity G"), ("ORG", "I"))),
- ("org:founded_by", (("ORG", "I"), ("ORG", "H"))),
- ]
- arguments2relation = {}
- for rel in doc.relations:
- arguments2relation[get_relation_argument_spans_and_roles(rel)] = rel
- assert len(arguments2relation) == 3
-
- taskmodule._add_candidate_relations(
- arguments2relation=arguments2relation, entities=doc.entities
- )
- assert len(arguments2relation) == 5
-
- relation_tuples = [rel.resolve() for rel in arguments2relation.values()]
-
- # Original relations from document (aren't affected by whitelist)
- assert relation_tuples[0] == ("per:employee_of", (("PER", "Entity G"), ("ORG", "H")))
- assert relation_tuples[1] == ("per:founder", (("PER", "Entity G"), ("ORG", "I")))
- assert relation_tuples[2] == ("org:founded_by", (("ORG", "I"), ("ORG", "H")))
-
- # Relation candidate added by _add_candidate_relations()
- assert relation_tuples[3] == ("no_relation", (("ORG", "H"), ("PER", "Entity G")))
- assert relation_tuples[4] == ("no_relation", (("ORG", "I"), ("PER", "Entity G")))
-
- # Relations not created due to whitelist
- assert ("no_relation", (("ORG", "H"), ("ORG", "I"))) not in relation_tuples
-
-
-def test_encode_input_with_add_reversed_relations(documents):
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- add_reversed_relations=True,
- )
- taskmodule.prepare(documents)
- encodings = []
- # just take the first three documents
- for doc in documents[:3]:
- encodings.extend(taskmodule.encode_input(doc))
-
- assert len(encodings) == 2
- texts = [encoding.document.text for encoding in encodings]
- relations = [encoding.metadata["candidate_annotation"] for encoding in encodings]
- relation_tuples = [(str(rel.head), rel.label, str(rel.tail)) for rel in relations]
-
- # There are no relations in the first and last document, so there are also no new reversed relations
-
- # this is the original relation
- assert texts[0] == "Entity A works at B."
- assert relation_tuples[0] == ("Entity A", "per:employee_of", "B")
-
- # this is the reversed relation
- assert texts[1] == "Entity A works at B."
- assert relation_tuples[1] == ("B", "per:employee_of_reversed", "Entity A")
-
- # test that an already reversed relation is not reversed again
- document = TestDocument(
- text="Entity A works at B.", id="doc_with_relation_with_reversed_suffix"
- )
- document.entities.extend(
- [LabeledSpan(start=0, end=8, label="PER"), LabeledSpan(start=18, end=19, label="PER")]
- )
- document.relations.append(
- BinaryRelation(
- head=document.entities[1],
- tail=document.entities[0],
- label=f"per:employee_of{taskmodule.reversed_relation_label_suffix}",
- )
- )
- with pytest.raises(ValueError) as excinfo:
- taskmodule.encode_input(document)
- assert str(excinfo.value) == (
- "doc.id=doc_with_relation_with_reversed_suffix: The relation has the label 'per:employee_of_reversed' "
- "which already ends with the reversed_relation_label_suffix='_reversed'. It looks like the relation is "
- "already reversed, which is not allowed."
- )
-
-
-def test_prepare_with_add_reversed_relations_with_label_has_suffix():
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- add_reversed_relations=True,
- )
- document = TestDocument(
- text="Entity A works at B.", id="doc_with_relation_with_reversed_suffix"
- )
- document.entities.extend(
- [LabeledSpan(start=0, end=8, label="PER"), LabeledSpan(start=18, end=19, label="PER")]
- )
- document.relations.append(
- BinaryRelation(
- head=document.entities[0],
- tail=document.entities[1],
- label=f"per:employee_of{taskmodule.reversed_relation_label_suffix}",
- )
- )
-
- with pytest.raises(ValueError) as excinfo:
- taskmodule.prepare([document])
- assert (
- str(excinfo.value)
- == "doc.id=doc_with_relation_with_reversed_suffix: the relation label 'per:employee_of_reversed' "
- "already ends with the reversed_relation_label_suffix '_reversed', this is not allowed because "
- "we would not know if we should strip the suffix and revert the arguments during inference or not"
- )
-
-
-@pytest.mark.parametrize("reverse_symmetric_relations", [False, True])
-def test_encode_input_with_add_reversed_relations_with_symmetric_relations(
- reverse_symmetric_relations, caplog
-):
- document = TestDocument(
- text="Entity A is married with B, but likes C, who is married with D.",
- id="doc_with_symmetric_relation",
- )
- document.entities.extend(
- [
- LabeledSpan(start=0, end=8, label="PER"),
- LabeledSpan(start=25, end=26, label="PER"),
- LabeledSpan(start=38, end=39, label="PER"),
- LabeledSpan(start=61, end=62, label="PER"),
- ]
- )
- assert str(document.entities[0]) == "Entity A"
- assert str(document.entities[1]) == "B"
- assert str(document.entities[2]) == "C"
- assert str(document.entities[3]) == "D"
- document.relations.extend(
- [
- BinaryRelation(
- head=document.entities[0], tail=document.entities[1], label="per:is_married_with"
- ),
- BinaryRelation(
- head=document.entities[0], tail=document.entities[2], label="per:likes"
- ),
- BinaryRelation(
- head=document.entities[2], tail=document.entities[3], label="per:is_married_with"
- ),
- BinaryRelation(
- head=document.entities[3], tail=document.entities[2], label="per:is_married_with"
- ),
- ]
- )
-
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- add_reversed_relations=True,
- symmetric_relations=["per:is_married_with"],
- reverse_symmetric_relations=reverse_symmetric_relations,
- )
- taskmodule.prepare([document])
- encodings = taskmodule.encode_input(document)
- relations = [encoding.metadata["candidate_annotation"] for encoding in encodings]
- relation_tuples = [
- (str(relation.head), relation.label, str(relation.tail)) for relation in relations
- ]
- if reverse_symmetric_relations:
- assert relation_tuples == [
- ("Entity A", "per:is_married_with", "B"),
- ("Entity A", "per:likes", "C"),
- ("C", "per:is_married_with", "D"),
- ("D", "per:is_married_with", "C"),
- ("B", "per:is_married_with", "Entity A"),
- ("C", "per:likes_reversed", "Entity A"),
- ]
- assert len(caplog.messages) == 2
- assert (
- caplog.messages[0]
- == "doc.id=doc_with_symmetric_relation: there is already a relation with reversed "
- "arguments=(('head', LabeledSpan(start=61, end=62, label='PER', score=1.0)), "
- "('tail', LabeledSpan(start=38, end=39, label='PER', score=1.0))) and label=per:is_married_with, "
- "so we do not add the reversed relation (with label per:is_married_with) for these arguments"
- )
- assert (
- caplog.messages[1]
- == "doc.id=doc_with_symmetric_relation: there is already a relation with reversed "
- "arguments=(('head', LabeledSpan(start=38, end=39, label='PER', score=1.0)), "
- "('tail', LabeledSpan(start=61, end=62, label='PER', score=1.0))) and label=per:is_married_with, "
- "so we do not add the reversed relation (with label per:is_married_with) for these arguments"
- )
- else:
- assert relation_tuples == [
- ("Entity A", "per:is_married_with", "B"),
- ("Entity A", "per:likes", "C"),
- ("C", "per:is_married_with", "D"),
- ("D", "per:is_married_with", "C"),
- ("C", "per:likes_reversed", "Entity A"),
- ]
- assert len(caplog.messages) == 0
-
- caplog.clear()
- document = TestDocument(
- text="Entity A is married with B.",
- id="doc_with_reversed_symmetric_relation",
- )
- document.entities.append(LabeledSpan(start=0, end=8, label="PER"))
- document.entities.append(LabeledSpan(start=25, end=26, label="PER"))
- document.relations.append(
- BinaryRelation(
- head=document.entities[1], tail=document.entities[0], label="per:is_married_with"
- )
- )
- encodings = taskmodule.encode_input(document)
- relations = [encoding.metadata["candidate_annotation"] for encoding in encodings]
- relation_tuples = [
- (str(relation.head), relation.label, str(relation.tail)) for relation in relations
- ]
- if reverse_symmetric_relations:
- assert len(relation_tuples) == 2
- assert relation_tuples[0] == ("B", "per:is_married_with", "Entity A")
- assert relation_tuples[1] == ("Entity A", "per:is_married_with", "B")
- assert len(caplog.messages) == 1
- assert (
- caplog.messages[0]
- == "doc.id=doc_with_reversed_symmetric_relation: The symmetric relation with label 'per:is_married_with' "
- "has arguments (('head', LabeledSpan(start=25, end=26, label='PER', score=1.0)), "
- "('tail', LabeledSpan(start=0, end=8, label='PER', score=1.0))) which are not sorted by their start "
- "and end positions. This may lead to problems during evaluation because we assume that the arguments "
- "of symmetric relations were sorted in the beginning and, thus, interpret relations where this is not "
- "the case as reversed. All reversed relations will get their arguments swapped during inference in "
- "the case of add_reversed_relations=True to remove duplicates. You may consider adding reversed "
- "versions of the *symmetric* relations on your own and then setting *reverse_symmetric_relations* "
- "to False."
- )
- else:
- assert len(relation_tuples) == 1
- assert relation_tuples[0] == ("B", "per:is_married_with", "Entity A")
- assert len(caplog.messages) == 0
-
-
-def test_encode_input_with_add_reversed_relations_with_wrong_relation_type(
- document_with_nary_relations,
-):
- doc = document_with_nary_relations
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path="bert-base-cased",
- add_reversed_relations=True,
- symmetric_relations=["per:employee_of"],
- )
- taskmodule.prepare([doc])
- with pytest.raises(NotImplementedError) as excinfo:
- taskmodule.encode_input(doc)
- assert (
- str(excinfo.value)
- == "doc.id=doc_with_nary_relations: the taskmodule does not yet support adding "
- "reversed relations for type: "
- )
-
-
-def test_inner_span_distance_overlap():
- dist = distance_inner((0, 2), (1, 3))
- assert dist == -1
-
-
-def test_span_distance_unknown_type():
- with pytest.raises(ValueError) as excinfo:
- span_distance((0, 1), (2, 3), "unknown")
- assert str(excinfo.value) == "unknown distance_type=unknown. use one of: center, inner, outer"
-
-
-def test_encode_input_with_max_argument_distance():
- document = TestDocument(
- text="Entity A works at B and C.", id="doc_with_three_entities_and_two_relations"
- )
- e0 = LabeledSpan(start=0, end=8, label="PER")
- e1 = LabeledSpan(start=18, end=19, label="PER")
- e2 = LabeledSpan(start=24, end=25, label="PER")
- document.entities.extend([e0, e1, e2])
- assert str(document.entities[0]) == "Entity A"
- assert str(document.entities[1]) == "B"
- assert str(document.entities[2]) == "C"
- document.relations.append(
- BinaryRelation(
- head=document.entities[0], tail=document.entities[1], label="per:employee_of"
- )
- )
- document.relations.append(
- BinaryRelation(
- head=document.entities[0], tail=document.entities[2], label="per:employee_of"
- )
- )
- dist_01 = span_distance((e0.start, e0.end), (e1.start, e1.end), "inner")
- dist_02 = span_distance((e0.start, e0.end), (e2.start, e2.end), "inner")
- assert dist_01 == 10
- assert dist_02 == 16
-
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- max_argument_distance=10,
- )
- taskmodule.prepare([document])
- encodings = taskmodule.encode_input(document)
-
- # there are two relations, but only one is within the max_argument_distance
- assert len(encodings) == 1
- relation = encodings[0].metadata["candidate_annotation"]
- assert str(relation.head) == "Entity A"
- assert str(relation.tail) == "B"
- assert relation.label == "per:employee_of"
-
-
-def test_encode_input_with_max_argument_distance_with_wrong_relation_type(
- document_with_nary_relations,
-):
- doc = document_with_nary_relations
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path="bert-base-cased",
- max_argument_distance=10,
- )
- taskmodule.prepare([doc])
- with pytest.raises(NotImplementedError) as excinfo:
- encodings = taskmodule.encode_input(doc)
- assert (
- str(excinfo.value)
- == "doc.id=doc_with_nary_relations: the taskmodule does not yet support filtering "
- "relation candidates for type: "
- )
-
-
-@pytest.mark.parametrize("distance_type", ["inner", "outer", "unknown"])
-def test_encode_input_with_max_argument_distance_tokens(distance_type):
- document = TestDocument(
- text="Entity A works at B and C.", id="doc_with_three_entities_and_two_relations"
- )
- e0 = LabeledSpan(start=0, end=8, label="PER")
- e1 = LabeledSpan(start=18, end=19, label="PER")
- e2 = LabeledSpan(start=24, end=25, label="PER")
- document.entities.extend([e0, e1, e2])
- assert str(document.entities[0]) == "Entity A"
- assert str(document.entities[1]) == "B"
- assert str(document.entities[2]) == "C"
- document.relations.append(
- BinaryRelation(
- head=document.entities[0], tail=document.entities[1], label="per:employee_of"
- )
- )
- document.relations.append(
- BinaryRelation(
- head=document.entities[0], tail=document.entities[2], label="per:employee_of"
- )
- )
- dist_01 = span_distance((e0.start, e0.end), (e1.start, e1.end), "inner")
- dist_02 = span_distance((e0.start, e0.end), (e2.start, e2.end), "inner")
- assert dist_01 == 10
- assert dist_02 == 16
-
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- max_argument_distance_tokens=3 if distance_type == "inner" else 8,
- max_argument_distance_type_tokens=distance_type,
- )
- taskmodule.prepare([document])
- if distance_type == "unknown":
- with pytest.raises(ValueError) as excinfo:
- taskmodule.encode_input(document)
- assert (
- str(excinfo.value) == "unknown distance_type=unknown. use one of: center, inner, outer"
- )
- return
-
- encodings = taskmodule.encode_input(document)
-
- # there are two relations, but only one is within the max_argument_distance
- assert len(encodings) == 1
- encoding = encodings[0]
- tokens = taskmodule.tokenizer.convert_ids_to_tokens(encoding.inputs["input_ids"])
- assert tokens == [
- "[CLS]",
- "[H]",
- "En",
- "##ti",
- "##ty",
- "A",
- "[/H]",
- "works",
- "at",
- "[T]",
- "B",
- "[/T]",
- "and",
- "C",
- ".",
- "[SEP]",
- ]
- head_start = tokens.index("[H]") + 1
- head_end = tokens.index("[/H]")
- tail_start = tokens.index("[T]") + 1
- tail_end = tokens.index("[/T]")
- assert (head_start, head_end, tail_start, tail_end) == (2, 6, 10, 11)
- # subtract 2 for the special marker tokens [/H] and [T]
- inner_dist = tail_start - head_end - 2
- assert inner_dist == 2
- # subtract 2 for the special marker tokens [H] and [/T]
- outer_dist = tail_end - head_start - 2
- assert outer_dist == 7
-
- relation = encodings[0].metadata["candidate_annotation"]
- assert str(relation.head) == "Entity A"
- assert str(relation.tail) == "B"
- assert relation.label == "per:employee_of"
-
-
-def test_encode_input_with_unknown_label():
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path="bert-base-cased",
- labels=["rel"],
- entity_labels=["a", "b"],
- collect_statistics=True,
- )
- taskmodule.post_prepare()
-
- doc = TestDocument(text="hello world", id="doc_with_unknown_label")
- doc.entities.append(LabeledSpan(start=0, end=5, label="a"))
- doc.entities.append(LabeledSpan(start=6, end=11, label="b"))
- doc.relations.append(
- BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="unknown")
- )
-
- task_encodings = taskmodule.encode_input(doc)
- assert len(task_encodings) == 0
-
- statistics = taskmodule.get_statistics()
- assert statistics == {("available", "unknown"): 1, ("skipped_unknown_label", "unknown"): 1}
-
-
-def test_encode_with_empty_partition_layer(documents):
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- partition_annotation="sentences",
- )
- taskmodule.prepare(documents)
- documents_without_sentences = []
- # just take the first three documents
- for doc in documents[:3]:
- doc_without_sentences = doc.copy()
- doc_without_sentences.sentences.clear()
- documents_without_sentences.append(doc_without_sentences)
-
- encodings = taskmodule.encode(documents_without_sentences)
- # since there are no sentences, but we use partition_annotation="sentences",
- # there are no encodings
- assert len(encodings) == 0
-
-
-def test_encode_nary_relatio():
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- argument_role_to_marker={"r1": "R1", "r2": "R2", "r3": "R3"},
- # setting labels and entity_labels makes the taskmodule prepared
- labels=["rel"],
- entity_labels=["a", "b", "c"],
- )
- taskmodule._post_prepare()
-
- @dataclass
- class DocWithNaryRelation(TextBasedDocument):
- entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
- relations: AnnotationLayer[NaryRelation] = annotation_field(target="entities")
-
- doc = DocWithNaryRelation(text="hello my world")
- entity1 = LabeledSpan(start=0, end=5, label="a")
- entity2 = LabeledSpan(start=6, end=8, label="b")
- entity3 = LabeledSpan(start=9, end=14, label="c")
- doc.entities.extend([entity1, entity2, entity3])
- doc.relations.append(
- NaryRelation(
- arguments=tuple([entity1, entity2, entity3]),
- roles=tuple(["r1", "r2", "r3"]),
- label="rel",
- )
- )
-
- task_encodings = taskmodule.encode([doc])
- assert len(task_encodings) == 1
- encoding = task_encodings[0]
- assert encoding.document == doc
- assert encoding.document.text == "hello my world"
- rel = encoding.metadata["candidate_annotation"]
- assert str(rel.arguments[0]) == "hello"
- assert str(rel.arguments[1]) == "my"
- assert str(rel.arguments[2]) == "world"
- assert rel.label == "rel"
-
-
-def test_encode_unknown_relation_type():
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- # setting labels and entity_labels makes the taskmodule prepared
- labels=["has_wrong_type"],
- entity_labels=["a"],
- )
- taskmodule._post_prepare()
-
- @dataclass(frozen=True)
- class UnknownRelation(Annotation):
- arg: LabeledSpan
- label: str
-
- @dataclass
- class DocWithUnknownRelationType(TextBasedDocument):
- entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
- relations: AnnotationLayer[UnknownRelation] = annotation_field(target="entities")
-
- doc = DocWithUnknownRelationType(text="hello world")
- entity = LabeledSpan(start=0, end=1, label="a")
- doc.entities.append(entity)
- doc.relations.append(UnknownRelation(arg=entity, label="has_wrong_type"))
-
- with pytest.raises(NotImplementedError) as excinfo:
- taskmodule.encode([doc])
- assert str(excinfo.value).startswith(
- "the taskmodule does not yet support getting relation arguments for type: "
- ) and str(excinfo.value).endswith(".UnknownRelation'>")
-
-
-def test_encode_with_unaligned_span(caplog):
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- # setting v and entity_labels makes the taskmodule prepared
- labels=["rel"],
- entity_labels=["a"],
- )
- taskmodule._post_prepare()
-
- @dataclass
- class MyDocument(TextBasedDocument):
- entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
- relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities")
-
- doc = MyDocument(text="hello space", id="doc1")
- entity1 = LabeledSpan(start=0, end=5, label="a")
- entity2 = LabeledSpan(start=7, end=13, label="a")
- entity3 = LabeledSpan(start=6, end=8, label="a")
- doc.entities.extend([entity1, entity2, entity3])
- # the start of entity2 is not aligned with a token, but this will get fixed
- assert str(entity2) == " space"
- doc.relations.append(BinaryRelation(head=entity1, tail=entity2, label="rel"))
- # entity3 can not get fixed because it contains only space
- assert str(entity3) == " "
- doc.relations.append(BinaryRelation(head=entity1, tail=entity3, label="rel"))
-
- task_encodings = taskmodule.encode([doc])
- # the second relation is skipped because we can not get an aligned token span for it
- assert len(task_encodings) == 1
- task_encoding = task_encodings[0]
- tokens = taskmodule.tokenizer.convert_ids_to_tokens(task_encoding.inputs["input_ids"])
- assert tokens == ["[CLS]", "[H]", "hello", "[/H]", "[T]", "space", "[/T]", "[SEP]"]
-
- assert len(caplog.records) == 1
- assert caplog.records[0].levelname == "WARNING"
- assert (
- caplog.messages[0]
- == "doc.id=doc1: Skipping invalid example, cannot get argument token slice for LabeledSpan(start=6, end=8, label='a', score=1.0): \" \""
- )
-
-
-def test_encode_with_log_first_n_examples(caplog):
- @dataclass
- class DocumentWithLabeledEntitiesAndRelations(TextBasedDocument):
- entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
- relations: AnnotationLayer[BinaryRelation] = annotation_field(target="entities")
-
- doc = DocumentWithLabeledEntitiesAndRelations(text="hello world", id="doc1")
- entity1 = LabeledSpan(start=0, end=5, label="a")
- entity2 = LabeledSpan(start=6, end=11, label="a")
- doc.entities.extend([entity1, entity2])
- doc.relations.append(BinaryRelation(head=entity1, tail=entity2, label="rel"))
-
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- log_first_n_examples=1,
- )
- taskmodule.prepare([doc])
-
- # we need to set the log level to INFO, otherwise the log messages are not captured
- with caplog.at_level(logging.INFO):
- task_encodings = taskmodule.encode([doc, doc], encode_target=True)
-
- # the second example is skipped because log_first_n_examples=1
- assert len(task_encodings) == 2
- assert len(caplog.records) == 5
- assert all([record.levelname == "INFO" for record in caplog.records])
- assert caplog.records[0].message == "*** Example ***"
- assert caplog.records[1].message == "doc id: doc1"
- assert caplog.records[2].message == "tokens: [CLS] [H] hello [/H] [T] world [/T] [SEP]"
- assert caplog.records[3].message == "input_ids: 101 28998 19082 28996 28999 1362 28997 102"
- assert caplog.records[4].message == "Expected label: ['rel'] (ids = [1])"
-
-
-@pytest.mark.skipif(condition=not _TABULATE_AVAILABLE, reason="requires the 'tabulate' package")
-def test_encode_with_collect_statistics(documents):
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path="bert-base-cased",
- collect_statistics=True,
- )
- taskmodule.prepare(documents)
- task_encodings = taskmodule.encode(documents)
- statistics = taskmodule.get_statistics()
- assert len(task_encodings) == 7
-
- assert statistics == {
- ("available", "org:founded_by"): 2,
- ("available", "per:employee_of"): 3,
- ("available", "per:founder"): 2,
- ("used", "org:founded_by"): 2,
- ("used", "per:employee_of"): 3,
- ("used", "per:founder"): 2,
- }
-
-
-def test_get_global_attention(taskmodule, batch, cfg):
- global_attention_mask = taskmodule._get_global_attention(input_ids=batch[0]["input_ids"])
- tokens = [
- taskmodule.tokenizer.convert_ids_to_tokens(token_ids)
- for token_ids in batch[0]["input_ids"].tolist()
- ]
- global_attention_tokens = [
- [tok for tok, m in zip(tkns, glob_attn_mask) if m]
- for tkns, glob_attn_mask in zip(tokens, global_attention_mask)
- ]
- pad_tok = taskmodule.tokenizer.pad_token
- not_global_attention_tokens = [
- [tok for tok, m in zip(tkns, glob_attn_mask) if not (m or tok == pad_tok)]
- for tkns, glob_attn_mask in zip(tokens, global_attention_mask)
- ]
- if not cfg.get("append_markers", False):
- torch.testing.assert_close(
- global_attention_mask,
- torch.tensor(
- [
- [1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- [1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
- ]
- ),
- )
- assert not_global_attention_tokens == [
- ["En", "##ti", "##ty", "A", "works", "at", "B", ".", "[SEP]"],
- [
- "First",
- "sentence",
- ".",
- "En",
- "##ti",
- "##ty",
- "G",
- "works",
- "at",
- "H",
- ".",
- "And",
- "founded",
- "I",
- ".",
- "[SEP]",
- ],
- ]
- else:
- torch.testing.assert_close(
- global_attention_mask,
- torch.tensor(
- [
- [1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
- [1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0],
- ]
- ),
- )
- assert not_global_attention_tokens == [
- ["En", "##ti", "##ty", "A", "works", "at", "B", ".", "[SEP]", "[SEP]", "[SEP]"],
- [
- "First",
- "sentence",
- ".",
- "En",
- "##ti",
- "##ty",
- "G",
- "works",
- "at",
- "H",
- ".",
- "And",
- "founded",
- "I",
- ".",
- "[SEP]",
- "[SEP]",
- "[SEP]",
- ],
- ]
-
- if cfg == {"add_type_to_marker": False, "append_markers": False}:
- assert global_attention_tokens == [
- ["[CLS]", "[H]", "[/H]", "[T]", "[/T]"],
- ["[CLS]", "[H]", "[/H]", "[T]", "[/T]"],
- ]
- elif cfg == {"add_type_to_marker": True, "append_markers": False}:
- assert global_attention_tokens == [
- ["[CLS]", "[H:PER]", "[/H:PER]", "[T:ORG]", "[/T:ORG]"],
- ["[CLS]", "[H:PER]", "[/H:PER]", "[T:ORG]", "[/T:ORG]"],
- ]
- elif cfg == {"add_type_to_marker": False, "append_markers": True}:
- assert global_attention_tokens == [
- ["[CLS]", "[H]", "[/H]", "[T]", "[/T]", "[H=PER]", "[T=ORG]"],
- ["[CLS]", "[H]", "[/H]", "[T]", "[/T]", "[H=PER]", "[T=ORG]"],
- ]
- elif cfg == {"add_type_to_marker": True, "append_markers": True}:
- assert global_attention_tokens == [
- ["[CLS]", "[H:PER]", "[/H:PER]", "[T:ORG]", "[/T:ORG]", "[H=PER]", "[T=ORG]"],
- ["[CLS]", "[H:PER]", "[/H:PER]", "[T:ORG]", "[/T:ORG]", "[H=PER]", "[T=ORG]"],
- ]
- else:
- raise ValueError(f"unexpected config: {cfg}")
-
-
-def get_metric_state(metric_or_collection: Union[Metric, MetricCollection]) -> Dict[str, Any]:
- if isinstance(metric_or_collection, Metric):
- return {
- k: v.tolist() for k, v in flatten_dict_s(metric_or_collection.metric_state).items()
- }
- elif isinstance(metric_or_collection, MetricCollection):
- return flatten_dict_s({k: get_metric_state(v) for k, v in metric_or_collection.items()})
- else:
- raise ValueError(f"unsupported type: {type(metric_or_collection)}")
-
-
-def test_configure_model_metric(documents, taskmodule):
- task_encodings = taskmodule.encode(documents, encode_target=True)
- batch = taskmodule.collate(task_encodings)
-
- metric = taskmodule.configure_model_metric(stage="train")
- assert isinstance(metric, (Metric, MetricCollection))
- state = get_metric_state(metric)
- assert state == {
- "micro/f1_without_tn/tp": [0],
- "micro/f1_without_tn/fp": [0],
- "micro/f1_without_tn/tn": [0],
- "micro/f1_without_tn/fn": [0],
- "with_tn/f1_per_label/tp": [0, 0, 0, 0],
- "with_tn/f1_per_label/fp": [0, 0, 0, 0],
- "with_tn/f1_per_label/tn": [0, 0, 0, 0],
- "with_tn/f1_per_label/fn": [0, 0, 0, 0],
- "with_tn/macro/f1/tp": [0, 0, 0, 0],
- "with_tn/macro/f1/fp": [0, 0, 0, 0],
- "with_tn/macro/f1/tn": [0, 0, 0, 0],
- "with_tn/macro/f1/fn": [0, 0, 0, 0],
- "with_tn/micro/f1/tp": [0],
- "with_tn/micro/f1/fp": [0],
- "with_tn/micro/f1/tn": [0],
- "with_tn/micro/f1/fn": [0],
- }
- assert metric.compute() == {
- "no_relation/f1": tensor(0.0),
- "org:founded_by/f1": tensor(0.0),
- "per:employee_of/f1": tensor(0.0),
- "per:founder/f1": tensor(0.0),
- "macro/f1": tensor(0.0),
- "micro/f1": tensor(0.0),
- "micro/f1_without_tn": tensor(0.0),
- }
-
- targets = batch[1]
- metric.update(targets, targets)
- state = get_metric_state(metric)
- assert state == {
- "micro/f1_without_tn/tp": [7],
- "micro/f1_without_tn/fp": [0],
- "micro/f1_without_tn/tn": [21],
- "micro/f1_without_tn/fn": [0],
- "with_tn/f1_per_label/tp": [0, 2, 3, 2],
- "with_tn/f1_per_label/fp": [0, 0, 0, 0],
- "with_tn/f1_per_label/tn": [7, 5, 4, 5],
- "with_tn/f1_per_label/fn": [0, 0, 0, 0],
- "with_tn/macro/f1/tp": [0, 2, 3, 2],
- "with_tn/macro/f1/fp": [0, 0, 0, 0],
- "with_tn/macro/f1/tn": [7, 5, 4, 5],
- "with_tn/macro/f1/fn": [0, 0, 0, 0],
- "with_tn/micro/f1/tp": [7],
- "with_tn/micro/f1/fp": [0],
- "with_tn/micro/f1/tn": [21],
- "with_tn/micro/f1/fn": [0],
- }
- assert metric.compute() == {
- "no_relation/f1": tensor(0.0),
- "org:founded_by/f1": tensor(1.0),
- "per:employee_of/f1": tensor(1.0),
- "per:founder/f1": tensor(1.0),
- "macro/f1": tensor(1.0),
- "micro/f1": tensor(1.0),
- "micro/f1_without_tn": tensor(1.0),
- }
-
- metric.reset()
- modified_targets = {"labels": torch.tensor([2, 2, 3, 1, 2, 0, 1])}
- # three positive matches and one true negative
- random_predictions = {"labels": torch.tensor([1, 1, 3, 1, 2, 0, 0])}
- metric.update(random_predictions, modified_targets)
- state = get_metric_state(metric)
- assert state == {
- "micro/f1_without_tn/tp": [3],
- "micro/f1_without_tn/fp": [3],
- "micro/f1_without_tn/tn": [15],
- "micro/f1_without_tn/fn": [3],
- "with_tn/f1_per_label/tp": [1, 1, 1, 1],
- "with_tn/f1_per_label/fp": [1, 2, 0, 0],
- "with_tn/f1_per_label/tn": [5, 3, 4, 6],
- "with_tn/f1_per_label/fn": [0, 1, 2, 0],
- "with_tn/macro/f1/tp": [1, 1, 1, 1],
- "with_tn/macro/f1/fp": [1, 2, 0, 0],
- "with_tn/macro/f1/tn": [5, 3, 4, 6],
- "with_tn/macro/f1/fn": [0, 1, 2, 0],
- "with_tn/micro/f1/tp": [4],
- "with_tn/micro/f1/fp": [3],
- "with_tn/micro/f1/tn": [18],
- "with_tn/micro/f1/fn": [3],
- }
- # created with torch.set_printoptions(precision=6)
- torch.testing.assert_close(
- metric.compute(),
- {
- "no_relation/f1": tensor(0.666667),
- "org:founded_by/f1": tensor(0.400000),
- "per:employee_of/f1": tensor(0.500000),
- "per:founder/f1": tensor(1.0),
- "macro/f1": tensor(0.641667),
- "micro/f1": tensor(0.571429),
- "micro/f1_without_tn": tensor(0.500000),
- },
- )
-
- # no targets and no predictions
- metric.reset()
- no_targets = {"labels": torch.tensor([0, 0, 0])}
- no_predictions = {"labels": torch.tensor([0, 0, 0])}
- metric.update(no_targets, no_predictions)
- state = get_metric_state(metric)
-
- assert state == {
- "micro/f1_without_tn/tp": [0],
- "micro/f1_without_tn/fp": [0],
- "micro/f1_without_tn/tn": [0],
- "micro/f1_without_tn/fn": [0],
- "with_tn/f1_per_label/tp": [3, 0, 0, 0],
- "with_tn/f1_per_label/fp": [0, 0, 0, 0],
- "with_tn/f1_per_label/tn": [0, 3, 3, 3],
- "with_tn/f1_per_label/fn": [0, 0, 0, 0],
- "with_tn/macro/f1/tp": [3, 0, 0, 0],
- "with_tn/macro/f1/fp": [0, 0, 0, 0],
- "with_tn/macro/f1/tn": [0, 3, 3, 3],
- "with_tn/macro/f1/fn": [0, 0, 0, 0],
- "with_tn/micro/f1/tp": [3],
- "with_tn/micro/f1/fp": [0],
- "with_tn/micro/f1/tn": [9],
- "with_tn/micro/f1/fn": [0],
- }
- torch.testing.assert_close(
- metric.compute(),
- {
- "micro/f1_without_tn": tensor(0.0),
- "no_relation/f1": tensor(1.0),
- "org:founded_by/f1": tensor(0.0),
- "per:employee_of/f1": tensor(0.0),
- "per:founder/f1": tensor(0.0),
- "macro/f1": tensor(1.0),
- "micro/f1": tensor(1.0),
- },
- )
-
- # ensure that the metric can be pickled
- pickle.dumps(metric)
-
-
-def get_bio_tag(tag_id: int, idx2label: Dict[int, str]) -> str:
- if tag_id == 0:
- return "O"
- tag_id -= 1
- label = idx2label[tag_id // 2]
- if tag_id % 2 == 0:
- return f"B-{label}"
- else:
- return f"I-{label}"
-
-
-def test_encode_without_insert_marker_but_argument_tags(documents):
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- insert_markers=False,
- add_argument_tags_to_input=True,
- )
- assert not taskmodule.is_from_pretrained
- taskmodule.prepare(documents)
-
- assert len(documents) == 7
- encodings = taskmodule.encode(documents)
- batch = taskmodule.collate(encodings)
- inputs, targets = batch
- tokens = [
- taskmodule.tokenizer.convert_ids_to_tokens(input_ids) for input_ids in inputs["input_ids"]
- ]
-
- idx2role = {v: k for k, v in taskmodule.argument_role2idx.items()}
- argument_tag_ids = [
- [get_bio_tag(tag_id, idx2role) for tag_id in (argument_tags - 1).tolist() if tag_id >= 0]
- for argument_tags in inputs["argument_tags"]
- ]
- tokens_with_tags = [
- [(tok, tag) for tok, tag in zip(tkns, tags)]
- for tkns, tags in zip(tokens, argument_tag_ids)
- ]
- assert tokens_with_tags == [
- [
- ("[CLS]", "O"),
- ("En", "B-head"),
- ("##ti", "I-head"),
- ("##ty", "I-head"),
- ("A", "I-head"),
- ("works", "O"),
- ("at", "O"),
- ("B", "B-tail"),
- (".", "O"),
- ("[SEP]", "O"),
- ],
- [
- ("[CLS]", "O"),
- ("First", "O"),
- ("sentence", "O"),
- (".", "O"),
- ("En", "B-head"),
- ("##ti", "I-head"),
- ("##ty", "I-head"),
- ("G", "I-head"),
- ("works", "O"),
- ("at", "O"),
- ("H", "B-tail"),
- (".", "O"),
- ("And", "O"),
- ("founded", "O"),
- ("I", "O"),
- (".", "O"),
- ("[SEP]", "O"),
- ],
- [
- ("[CLS]", "O"),
- ("First", "O"),
- ("sentence", "O"),
- (".", "O"),
- ("En", "B-head"),
- ("##ti", "I-head"),
- ("##ty", "I-head"),
- ("G", "I-head"),
- ("works", "O"),
- ("at", "O"),
- ("H", "O"),
- (".", "O"),
- ("And", "O"),
- ("founded", "O"),
- ("I", "B-tail"),
- (".", "O"),
- ("[SEP]", "O"),
- ],
- [
- ("[CLS]", "O"),
- ("First", "O"),
- ("sentence", "O"),
- (".", "O"),
- ("En", "O"),
- ("##ti", "O"),
- ("##ty", "O"),
- ("G", "O"),
- ("works", "O"),
- ("at", "O"),
- ("H", "B-tail"),
- (".", "O"),
- ("And", "O"),
- ("founded", "O"),
- ("I", "B-head"),
- (".", "O"),
- ("[SEP]", "O"),
- ],
- [
- ("[CLS]", "O"),
- ("First", "O"),
- ("sentence", "O"),
- (".", "O"),
- ("En", "B-head"),
- ("##ti", "I-head"),
- ("##ty", "I-head"),
- ("M", "I-head"),
- ("works", "O"),
- ("at", "O"),
- ("N", "B-tail"),
- (".", "O"),
- ("And", "O"),
- ("it", "O"),
- ("founded", "O"),
- ("O", "O"),
- (".", "O"),
- ("[SEP]", "O"),
- ],
- [
- ("[CLS]", "O"),
- ("First", "O"),
- ("sentence", "O"),
- (".", "O"),
- ("En", "O"),
- ("##ti", "O"),
- ("##ty", "O"),
- ("M", "O"),
- ("works", "O"),
- ("at", "O"),
- ("N", "O"),
- (".", "O"),
- ("And", "O"),
- ("it", "B-head"),
- ("founded", "O"),
- ("O", "B-tail"),
- (".", "O"),
- ("[SEP]", "O"),
- ],
- [
- ("[CLS]", "O"),
- ("First", "O"),
- ("sentence", "O"),
- (".", "O"),
- ("En", "O"),
- ("##ti", "O"),
- ("##ty", "O"),
- ("M", "O"),
- ("works", "O"),
- ("at", "O"),
- ("N", "O"),
- (".", "O"),
- ("And", "O"),
- ("it", "B-tail"),
- ("founded", "O"),
- ("O", "B-head"),
- (".", "O"),
- ("[SEP]", "O"),
- ],
- ]
-
-
-@pytest.mark.parametrize("add_argument_indices_to_input", [True, False])
-def test_encode_without_insert_marker_but_argument_tags_and_windowing(
- documents, add_argument_indices_to_input
-):
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- add_argument_indices_to_input=add_argument_indices_to_input,
- add_argument_tags_to_input=True,
- max_window=8,
- insert_markers=False,
- )
- assert not taskmodule.is_from_pretrained
- taskmodule.prepare(documents)
-
- encodings = taskmodule.encode(documents, encode_target=True)
- assert len(encodings) == 3
- batch = taskmodule.collate(encodings)
- inputs, targets = batch
- tokens = [
- taskmodule.tokenizer.convert_ids_to_tokens(input_ids) for input_ids in inputs["input_ids"]
- ]
-
- if add_argument_indices_to_input:
- arg_spans = [
- get_arg_token_span(
- current_tokens,
- current_start_indices,
- current_end_indices,
- taskmodule.argument_role2idx,
- )
- for current_tokens, current_start_indices, current_end_indices in zip(
- tokens,
- inputs["pooler_start_indices"].tolist(),
- inputs["pooler_end_indices"].tolist(),
- )
- ]
-
- assert arg_spans == [
- {"head": ["I"], "tail": ["H"]},
- {"head": ["it"], "tail": ["O"]},
- {"head": ["O"], "tail": ["it"]},
- ]
-
- idx2role = {v: k for k, v in taskmodule.argument_role2idx.items()}
- argument_tag_ids = [
- [get_bio_tag(tag_id, idx2role) for tag_id in (argument_tags - 1).tolist() if tag_id >= 0]
- for argument_tags in inputs["argument_tags"]
- ]
- tokens_with_tags = [
- [(tok, tag) for tok, tag in zip(tkns, tags)]
- for tkns, tags in zip(tokens, argument_tag_ids)
- ]
- assert tokens_with_tags == [
- [
- ("[CLS]", "O"),
- ("at", "O"),
- ("H", "B-tail"),
- (".", "O"),
- ("And", "O"),
- ("founded", "O"),
- ("I", "B-head"),
- ("[SEP]", "O"),
- ],
- [
- ("[CLS]", "O"),
- (".", "O"),
- ("And", "O"),
- ("it", "B-head"),
- ("founded", "O"),
- ("O", "B-tail"),
- (".", "O"),
- ("[SEP]", "O"),
- ],
- [
- ("[CLS]", "O"),
- (".", "O"),
- ("And", "O"),
- ("it", "B-tail"),
- ("founded", "O"),
- ("O", "B-head"),
- (".", "O"),
- ("[SEP]", "O"),
- ],
- ]
-
-
-@pytest.mark.parametrize("insert_markers", [True, False])
-def test_encode_with_add_entity_tags_to_input(documents, insert_markers):
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- add_entity_tags_to_input=True,
- insert_markers=insert_markers,
- )
- assert not taskmodule.is_from_pretrained
- taskmodule.prepare(documents)
-
- encodings = taskmodule.encode(documents)
- assert len(encodings) == 7
- batch = taskmodule.collate(encodings)
- inputs, targets = batch
- tokens = [
- taskmodule.tokenizer.convert_ids_to_tokens(input_ids) for input_ids in inputs["input_ids"]
- ]
-
- idx2label = {k: v for k, v in enumerate(taskmodule.entity_labels)}
- entity_tag_ids = [
- [get_bio_tag(tag_id, idx2label) for tag_id in (argument_tags - 1).tolist() if tag_id >= 0]
- for argument_tags in inputs["entity_tags"]
- ]
- tokens_with_tags = [
- [(tok, tag) for tok, tag in zip(tkns, tags)] for tkns, tags in zip(tokens, entity_tag_ids)
- ]
- if insert_markers:
- assert tokens_with_tags[:3] == [
- [
- ("[CLS]", "O"),
- ("[H]", "O"),
- ("En", "B-PER"),
- ("##ti", "I-PER"),
- ("##ty", "I-PER"),
- ("A", "I-PER"),
- ("[/H]", "O"),
- ("works", "O"),
- ("at", "O"),
- ("[T]", "O"),
- ("B", "B-ORG"),
- ("[/T]", "O"),
- (".", "O"),
- ("[SEP]", "O"),
- ],
- [
- ("[CLS]", "O"),
- ("First", "O"),
- ("sentence", "O"),
- (".", "O"),
- ("[H]", "O"),
- ("En", "B-PER"),
- ("##ti", "I-PER"),
- ("##ty", "I-PER"),
- ("G", "I-PER"),
- ("[/H]", "O"),
- ("works", "O"),
- ("at", "O"),
- ("[T]", "O"),
- ("H", "B-ORG"),
- ("[/T]", "O"),
- (".", "O"),
- ("And", "O"),
- ("founded", "O"),
- ("I", "B-ORG"),
- (".", "O"),
- ("[SEP]", "O"),
- ],
- [
- ("[CLS]", "O"),
- ("First", "O"),
- ("sentence", "O"),
- (".", "O"),
- ("[H]", "O"),
- ("En", "B-PER"),
- ("##ti", "I-PER"),
- ("##ty", "I-PER"),
- ("G", "I-PER"),
- ("[/H]", "O"),
- ("works", "O"),
- ("at", "O"),
- ("H", "B-ORG"),
- (".", "O"),
- ("And", "O"),
- ("founded", "O"),
- ("[T]", "O"),
- ("I", "B-ORG"),
- ("[/T]", "O"),
- (".", "O"),
- ("[SEP]", "O"),
- ],
- ]
- else:
- assert tokens_with_tags[:3] == [
- [
- ("[CLS]", "O"),
- ("En", "B-PER"),
- ("##ti", "I-PER"),
- ("##ty", "I-PER"),
- ("A", "I-PER"),
- ("works", "O"),
- ("at", "O"),
- ("B", "B-ORG"),
- (".", "O"),
- ("[SEP]", "O"),
- ],
- [
- ("[CLS]", "O"),
- ("First", "O"),
- ("sentence", "O"),
- (".", "O"),
- ("En", "B-PER"),
- ("##ti", "I-PER"),
- ("##ty", "I-PER"),
- ("G", "I-PER"),
- ("works", "O"),
- ("at", "O"),
- ("H", "B-ORG"),
- (".", "O"),
- ("And", "O"),
- ("founded", "O"),
- ("I", "B-ORG"),
- (".", "O"),
- ("[SEP]", "O"),
- ],
- [
- ("[CLS]", "O"),
- ("First", "O"),
- ("sentence", "O"),
- (".", "O"),
- ("En", "B-PER"),
- ("##ti", "I-PER"),
- ("##ty", "I-PER"),
- ("G", "I-PER"),
- ("works", "O"),
- ("at", "O"),
- ("H", "B-ORG"),
- (".", "O"),
- ("And", "O"),
- ("founded", "O"),
- ("I", "B-ORG"),
- (".", "O"),
- ("[SEP]", "O"),
- ],
- ]
-
-
-@pytest.mark.parametrize("insert_markers", [True, False])
-def test_encode_with_add_entity_tags_to_input_windowing(documents, insert_markers):
- tokenizer_name_or_path = "bert-base-cased"
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path=tokenizer_name_or_path,
- add_entity_tags_to_input=True,
- insert_markers=insert_markers,
- max_window=12 if insert_markers else 8,
- )
- assert not taskmodule.is_from_pretrained
- taskmodule.prepare(documents)
-
- encodings = taskmodule.encode(documents, encode_target=True)
- assert len(encodings) == 3
- batch = taskmodule.collate(encodings)
- inputs, targets = batch
- tokens = [
- taskmodule.tokenizer.convert_ids_to_tokens(input_ids) for input_ids in inputs["input_ids"]
- ]
-
- idx2label = {k: v for k, v in enumerate(taskmodule.entity_labels)}
- entity_tag_ids = [
- [get_bio_tag(tag_id, idx2label) for tag_id in (argument_tags - 1).tolist() if tag_id >= 0]
- for argument_tags in inputs["entity_tags"]
- ]
- tokens_with_tags = [
- [(tok, tag) for tok, tag in zip(tkns, tags)] for tkns, tags in zip(tokens, entity_tag_ids)
- ]
-
- if insert_markers:
- assert tokens_with_tags == [
- [
- ("[CLS]", "O"),
- ("at", "O"),
- ("[T]", "O"),
- ("H", "B-ORG"),
- ("[/T]", "O"),
- (".", "O"),
- ("And", "O"),
- ("founded", "O"),
- ("[H]", "O"),
- ("I", "B-ORG"),
- ("[/H]", "O"),
- ("[SEP]", "O"),
- ],
- [
- ("[CLS]", "O"),
- (".", "O"),
- ("And", "O"),
- ("[H]", "O"),
- ("it", "B-PER"),
- ("[/H]", "O"),
- ("founded", "O"),
- ("[T]", "O"),
- ("O", "B-ORG"),
- ("[/T]", "O"),
- (".", "O"),
- ("[SEP]", "O"),
- ],
- [
- ("[CLS]", "O"),
- (".", "O"),
- ("And", "O"),
- ("[T]", "O"),
- ("it", "B-PER"),
- ("[/T]", "O"),
- ("founded", "O"),
- ("[H]", "O"),
- ("O", "B-ORG"),
- ("[/H]", "O"),
- (".", "O"),
- ("[SEP]", "O"),
- ],
- ]
- else:
- assert tokens_with_tags == [
- [
- ("[CLS]", "O"),
- ("at", "O"),
- ("H", "B-ORG"),
- (".", "O"),
- ("And", "O"),
- ("founded", "O"),
- ("I", "B-ORG"),
- ("[SEP]", "O"),
- ],
- [
- ("[CLS]", "O"),
- (".", "O"),
- ("And", "O"),
- ("it", "B-PER"),
- ("founded", "O"),
- ("O", "B-ORG"),
- (".", "O"),
- ("[SEP]", "O"),
- ],
- [
- ("[CLS]", "O"),
- (".", "O"),
- ("And", "O"),
- ("it", "B-PER"),
- ("founded", "O"),
- ("O", "B-ORG"),
- (".", "O"),
- ("[SEP]", "O"),
- ],
- ]
-
-
-@pytest.mark.parametrize("add_candidate_relations", [False, True])
-def test_create_annotations_from_output(add_candidate_relations):
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path="bert-base-cased",
- # pass in the labels and entity_labels to avoid calling prepare
- # (which would required documents to collect the labels from)
- labels=["org:founded_by", "per:employee_of", "per:founder"],
- entity_labels=["PER", "ORG"],
- # we want to test the effect of creating candidate relations
- add_candidate_relations=add_candidate_relations,
- )
- # just call post_prepare to set up the taskmodule since labels
- # and entity_labels are already set
- taskmodule.post_prepare()
-
- entities = [
- LabeledSpan(start=16, end=24, label="PER"),
- LabeledSpan(start=34, end=35, label="ORG"),
- LabeledSpan(start=49, end=50, label="ORG"),
- ]
-
- assert taskmodule.none_label == "no_relation"
- candidate_relations = [
- BinaryRelation(head=entities[0], tail=entities[1], label="no_relation"),
- BinaryRelation(head=entities[0], tail=entities[2], label="no_relation"),
- BinaryRelation(head=entities[2], tail=entities[1], label="no_relation"),
- ]
-
- # Just create the task encodings with dummy inputs and a dummy document since
- # we do not want to pass them into the model, but add correct metadata
- # (which is used to create the annotations).
- task_encodings = [
- TaskEncoding(inputs={}, metadata={"candidate_annotation": rel}, document=Document())
- for rel in candidate_relations
- ]
- unbatched_model_outputs = [
- {"labels": ["per:employee_of"], "probabilities": [0.6000000238418579]},
- {"labels": ["per:founder"], "probabilities": [0.5]},
- {"labels": ["no_relation"], "probabilities": [0.6000000238418579]},
- ]
-
- result_flat = []
- for i in range(len(unbatched_model_outputs)):
- result_flat.extend(
- list(
- taskmodule.create_annotations_from_output(
- task_encoding=task_encodings[i], task_output=unbatched_model_outputs[i]
- )
- )
- )
-
- # The entities need to be added to a document. This is only required to resolve
- # the relations later on for better readability!
- document = TestDocument(text="First sentence. Entity G works at H. And founded I.")
- document.entities.extend(entities)
-
- # this would be the "model input"
- assert [rel.resolve() for rel in candidate_relations] == [
- ("no_relation", (("PER", "Entity G"), ("ORG", "H"))),
- ("no_relation", (("PER", "Entity G"), ("ORG", "I"))),
- ("no_relation", (("ORG", "I"), ("ORG", "H"))),
- ]
-
- # this is the final "output"
- relations_resolved_with_score = [
- (rel.resolve(), round(rel.score, 4)) for _, rel in result_flat
- ]
- if add_candidate_relations:
- # if candidate relations were added, the no-relation is removed
- assert relations_resolved_with_score == [
- (("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))), 0.6),
- (("per:founder", (("PER", "Entity G"), ("ORG", "I"))), 0.5),
- ]
- else:
- # if no candidate relations were added, the no-relation is kept
- assert relations_resolved_with_score == [
- (("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))), 0.6),
- (("per:founder", (("PER", "Entity G"), ("ORG", "I"))), 0.5),
- (("no_relation", (("ORG", "I"), ("ORG", "H"))), 0.6),
- ]
-
-
-@pytest.mark.parametrize("as_list", [False, True])
-@pytest.mark.parametrize("add_candidate_relations", [False, True])
-def test_create_annotations_from_output_with_argument_and_relation_type_whitelist(
- add_candidate_relations, as_list
-):
- if as_list:
- argument_and_relation_type_whitelist = [
- ["per:employee_of", "PER", "ORG"],
- ["per:founder", "PER", "ORG"],
- ["org:founded_by", "ORG", "PER"],
- ["no_relation", "PER", "ORG"],
- ["no_relation", "ORG", "PER"],
- ]
- else:
- argument_and_relation_type_whitelist = {
- "per:employee_of": [["PER", "ORG"]],
- "per:founder": [["PER", "ORG"]],
- "org:founded_by": [["ORG", "PER"]],
- "no_relation": [["PER", "ORG"], ["ORG", "PER"]],
- }
- taskmodule = RETextClassificationWithIndicesTaskModule(
- relation_annotation="relations",
- tokenizer_name_or_path="bert-base-cased",
- # pass in the labels and entity_labels to avoid calling prepare
- # (which would required documents to collect the labels from)
- labels=["org:founded_by", "per:employee_of", "per:founder"],
- entity_labels=["PER", "ORG"],
- # we want to test the effect of creating candidate relations
- add_candidate_relations=add_candidate_relations,
- argument_and_relation_type_whitelist=argument_and_relation_type_whitelist,
- )
- # just call post_prepare to set up the taskmodule since labels
- # and entity_labels are already set
- taskmodule.post_prepare()
-
- entities = [
- LabeledSpan(start=16, end=24, label="PER"),
- LabeledSpan(start=34, end=35, label="ORG"),
- LabeledSpan(start=49, end=50, label="ORG"),
- ]
-
- assert taskmodule.none_label == "no_relation"
- candidate_relations = [
- BinaryRelation(head=entities[0], tail=entities[1], label="no_relation"),
- BinaryRelation(head=entities[0], tail=entities[2], label="no_relation"),
- BinaryRelation(head=entities[2], tail=entities[0], label="no_relation"),
- BinaryRelation(head=entities[2], tail=entities[1], label="no_relation"),
- BinaryRelation(head=entities[1], tail=entities[2], label="no_relation"),
- ]
-
- # Just create the task encodings with dummy inputs and a dummy document since
- # we do not want to pass them into the model, but add correct metadata
- # (which is used to create the annotations).
- task_encodings = [
- TaskEncoding(inputs={}, metadata={"candidate_annotation": rel}, document=Document())
- for rel in candidate_relations
- ]
- unbatched_model_outputs = [
- {"labels": ["per:employee_of"], "probabilities": [0.6000000238418579]},
- {"labels": ["per:founder"], "probabilities": [0.5]},
- {"labels": ["no_relation"], "probabilities": [0.6000000238418579]},
- {"labels": ["org:founded_by"], "probabilities": [0.6000000238418579]},
- {"labels": ["no_relation"], "probabilities": [0.6000000238418579]},
- ]
-
- result_flat = []
- for i in range(len(unbatched_model_outputs)):
- result_flat.extend(
- list(
- taskmodule.create_annotations_from_output(
- task_encoding=task_encodings[i], task_output=unbatched_model_outputs[i]
- )
- )
- )
-
- # The entities need to be added to a document. This is only required to resolve
- # the relations later on for better readability!
- document = TestDocument(text="First sentence. Entity G works at H. And founded I.")
- document.entities.extend(entities)
-
- # this would be the "model input"
- assert [rel.resolve() for rel in candidate_relations] == [
- ("no_relation", (("PER", "Entity G"), ("ORG", "H"))),
- ("no_relation", (("PER", "Entity G"), ("ORG", "I"))),
- ("no_relation", (("ORG", "I"), ("PER", "Entity G"))),
- ("no_relation", (("ORG", "I"), ("ORG", "H"))),
- ("no_relation", (("ORG", "H"), ("ORG", "I"))),
- ]
-
- # this is the final "output"
- relations_resolved_with_score = [
- (rel.resolve(), round(rel.score, 4)) for _, rel in result_flat
- ]
- if add_candidate_relations:
- # if candidate relations were added, no-relations are removed
- # relations with wrong entity types are also removed.
- assert relations_resolved_with_score == [
- (("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))), 0.6),
- (("per:founder", (("PER", "Entity G"), ("ORG", "I"))), 0.5),
- ]
- else:
- # if no candidate relations were added, only relations not fitting the filter
- # are removed. We explicitly need to add "no_relation" with possible argument types
- # to whitelist if we don't want them to be filtered.
- assert relations_resolved_with_score == [
- (("per:employee_of", (("PER", "Entity G"), ("ORG", "H"))), 0.6),
- (("per:founder", (("PER", "Entity G"), ("ORG", "I"))), 0.5),
- (("no_relation", (("ORG", "I"), ("PER", "Entity G"))), 0.6),
- ]
diff --git a/tests/taskmodules/test_text2text.py b/tests/taskmodules/test_text2text.py
deleted file mode 100644
index 7dc779acb..000000000
--- a/tests/taskmodules/test_text2text.py
+++ /dev/null
@@ -1,275 +0,0 @@
-import pickle
-from typing import Any, Dict, List, Sequence, Tuple
-
-import pytest
-import torch
-from pie_core import Annotation, TaskEncoding
-
-from pie_modules.annotations import AbstractiveSummary
-from pie_modules.documents import (
- TextDocumentWithAbstractiveSummary,
- TokenDocumentWithAbstractiveSummary,
-)
-from pie_modules.models.common import VALIDATION
-from pie_modules.taskmodules import TextToTextTaskModule
-from pie_modules.taskmodules.text_to_text import (
- InputEncodingType,
- TargetEncodingType,
- TaskEncodingType,
- TaskOutputType,
-)
-
-
-@pytest.fixture(scope="module")
-def documents():
- result = []
-
- doc = TextDocumentWithAbstractiveSummary(text="This is a test document")
- summary = AbstractiveSummary(text="a document")
- doc.abstractive_summary.append(summary)
- result.append(doc)
-
- doc = TextDocumentWithAbstractiveSummary(
- text="This is another test document which is a bit longer"
- )
- summary = AbstractiveSummary(text="a longer document")
- doc.abstractive_summary.append(summary)
- result.append(doc)
-
- return result
-
-
-@pytest.fixture(scope="module")
-def taskmodule():
- return TextToTextTaskModule(
- tokenizer_name_or_path="google/t5-efficient-tiny-nl2",
- document_type="pie_modules.documents.TextDocumentWithAbstractiveSummary",
- target_layer="abstractive_summary",
- target_annotation_type="pie_modules.annotations.AbstractiveSummary",
- tokenized_document_type="pie_modules.documents.TokenDocumentWithAbstractiveSummary",
- text_metric_type="torchmetrics.text.ROUGEScore",
- )
-
-
-def test_taskmodule(taskmodule):
- assert taskmodule is not None
- assert taskmodule.document_type == TextDocumentWithAbstractiveSummary
- assert taskmodule.tokenized_document_type == TokenDocumentWithAbstractiveSummary
- assert taskmodule.target_annotation_type == AbstractiveSummary
- assert taskmodule.layer_names == ["abstractive_summary"]
- assert taskmodule.generation_config == {}
-
-
-@pytest.fixture(scope="module")
-def task_encodings(taskmodule, documents) -> Sequence[TaskEncodingType]:
- encodings = taskmodule.encode(documents, encode_target=True)
- assert all(isinstance(encoding, TaskEncoding) for encoding in encodings)
- assert len(encodings) == 2 == len(documents)
- assert encodings[0].document == documents[0]
- assert encodings[1].document == documents[1]
- return encodings
-
-
-def test_maybe_log_example(taskmodule, task_encodings, caplog):
- counter_backup = taskmodule.log_first_n_examples
-
- taskmodule.log_first_n_examples = 1
- with caplog.at_level("INFO"):
- taskmodule.maybe_log_example(task_encodings[0])
-
- assert len(caplog.messages) == 3
- assert caplog.messages[0] == "input_ids: [100, 19, 3, 9, 794, 1708, 1]"
- assert caplog.messages[1] == "attention_mask: [1, 1, 1, 1, 1, 1, 1]"
- assert caplog.messages[2] == "labels: [3, 9, 1708, 1]"
-
- taskmodule.log_first_n_examples = counter_backup
-
-
-@pytest.fixture(scope="module")
-def input_encoding(taskmodule, task_encodings) -> InputEncodingType:
- assert len(task_encodings) > 0
- return task_encodings[0].inputs
-
-
-def test_input_encoding(taskmodule, input_encoding):
- assert isinstance(input_encoding, InputEncodingType)
- assert input_encoding.input_ids == [100, 19, 3, 9, 794, 1708, 1]
- assert input_encoding.attention_mask == [1, 1, 1, 1, 1, 1, 1]
-
- tokens = taskmodule.tokenizer.convert_ids_to_tokens(input_encoding.input_ids)
- assert tokens == ["▁This", "▁is", "▁", "a", "▁test", "▁document", ""]
-
-
-@pytest.fixture(scope="module")
-def metadata(taskmodule, task_encodings) -> Dict[str, Any]:
- assert len(task_encodings) > 0
- return task_encodings[0].metadata
-
-
-def test_metadata(taskmodule, metadata):
- assert set(metadata) == {"tokenized_document", "guidance_annotation"}
-
- tokenized_document = metadata["tokenized_document"]
- assert isinstance(tokenized_document, TokenDocumentWithAbstractiveSummary)
- assert tokenized_document.tokens == ("▁This", "▁is", "▁", "a", "▁test", "▁document", "")
- assert len(tokenized_document.abstractive_summary) == 1
- assert tokenized_document.abstractive_summary[0].text == "a document"
-
-
-@pytest.fixture(scope="module")
-def target_encoding(taskmodule, task_encodings) -> TargetEncodingType:
- assert len(task_encodings) > 0
- return task_encodings[0].targets
-
-
-def test_target_encoding(taskmodule, target_encoding):
- assert isinstance(target_encoding, TargetEncodingType)
- assert target_encoding.labels == [3, 9, 1708, 1]
- assert target_encoding.decoder_attention_mask == [1, 1, 1, 1]
-
-
-@pytest.fixture(scope="module")
-def batch(taskmodule, task_encodings) -> List[TaskEncodingType]:
- result = taskmodule.collate(task_encodings)
- return result
-
-
-def test_batch(taskmodule, batch):
- assert len(batch) == 2
- inputs, targets = batch
-
- assert set(inputs) == {"input_ids", "attention_mask"}
- torch.testing.assert_close(
- inputs["input_ids"],
- torch.tensor(
- [
- [100, 19, 3, 9, 794, 1708, 1, 0, 0, 0, 0, 0],
- [100, 19, 430, 794, 1708, 84, 19, 3, 9, 720, 1200, 1],
- ]
- ),
- )
- torch.testing.assert_close(
- inputs["attention_mask"],
- torch.tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),
- )
-
- assert set(targets) == {"labels", "decoder_attention_mask"}
- torch.testing.assert_close(
- targets["labels"], torch.tensor([[3, 9, 1708, 1, 0], [3, 9, 1200, 1708, 1]])
- )
- torch.testing.assert_close(
- targets["decoder_attention_mask"], torch.tensor([[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]])
- )
-
-
-@pytest.fixture(scope="module")
-def unbatched_output(taskmodule, batch) -> Sequence[TaskOutputType]:
- inputs, targets = batch
- return taskmodule.unbatch_output(targets)
-
-
-def test_unbatched_output(taskmodule, unbatched_output):
- assert all(isinstance(output, TargetEncodingType) for output in unbatched_output)
- assert len(unbatched_output) == 2
-
- assert unbatched_output[0].labels == [3, 9, 1708, 1]
- assert unbatched_output[0].decoder_attention_mask is None
-
- assert unbatched_output[1].labels == [3, 9, 1200, 1708, 1]
- assert unbatched_output[1].decoder_attention_mask is None
-
-
-@pytest.fixture(scope="module")
-def decoded_annotations(
- taskmodule, task_encodings, unbatched_output
-) -> List[Tuple[str, Annotation]]:
- result = []
- for encoding, output in zip(task_encodings, unbatched_output):
- result.extend(
- taskmodule.create_annotations_from_output(task_encoding=encoding, task_output=output)
- )
- return result
-
-
-def test_decoded_annotations(taskmodule, decoded_annotations):
- names, annotations = zip(*decoded_annotations)
- assert all(layer_name == taskmodule.target_layer for layer_name in names)
- assert all(
- isinstance(annotation, taskmodule.target_annotation_type) for annotation in annotations
- )
-
- assert len(annotations) == 2
- assert annotations[0].text == "a document"
- assert annotations[0].score is None
- assert annotations[1].text == "a longer document"
- assert annotations[1].score is None
-
-
-def test_configure_model_metrics(taskmodule):
- metric = taskmodule.configure_model_metric(stage=VALIDATION)
- assert metric is not None
- values = metric.compute()
- keys = {
- "rouge2_fmeasure",
- "rougeL_recall",
- "rouge1_precision",
- "rouge1_recall",
- "rouge2_recall",
- "rougeL_precision",
- "rouge1_fmeasure",
- "rougeLsum_recall",
- "rougeLsum_precision",
- "rougeL_fmeasure",
- "rouge2_precision",
- "rougeLsum_fmeasure",
- }
- assert set(values) == keys
- assert all(torch.isnan(value) for value in values.values())
-
- labels = torch.tensor([[3, 9, 1708, 1, 0], [3, 9, 1200, 1708, 1]])
- metric.update(prediction={"labels": labels}, target={"labels": labels})
- assert set(metric.metric_state) == keys
- assert all(
- value == [torch.tensor(1.0), torch.tensor(1.0)] for value in metric.metric_state.values()
- )
- values = metric.compute()
- assert set(values) == keys
- assert all(value == torch.tensor(1.0) for value in values.values())
-
- random_labels = torch.tensor([[875, 885, 112, 289, 769], [270, 583, 970, 114, 71]])
- metric.update(prediction={"labels": random_labels}, target={"labels": labels})
- values = metric.compute()
- assert {k: v.item() for k, v in values.items()} == {
- "rouge1_fmeasure": 0.5625,
- "rouge1_precision": 0.550000011920929,
- "rouge1_recall": 0.5833333134651184,
- "rouge2_fmeasure": 0.5,
- "rouge2_precision": 0.5,
- "rouge2_recall": 0.5,
- "rougeL_fmeasure": 0.5625,
- "rougeL_precision": 0.550000011920929,
- "rougeL_recall": 0.5833333134651184,
- "rougeLsum_fmeasure": 0.5625,
- "rougeLsum_precision": 0.550000011920929,
- "rougeLsum_recall": 0.5833333134651184,
- }
-
- # ensure that the metric can be pickled
- pickle.dumps(metric)
-
-
-def test_configure_model_generation(taskmodule):
- generation_config = taskmodule.configure_model_generation()
- assert generation_config is not None
- assert generation_config == {}
-
-
-def test_warn_once(taskmodule, caplog):
- with caplog.at_level("WARNING"):
- taskmodule.warn_only_once("test")
- taskmodule.warn_only_once("test")
- taskmodule.warn_only_once("test2")
-
- assert len(caplog.messages) == 2
- assert caplog.messages[0] == "test (This warning will only be shown once)"
- assert caplog.messages[1] == "test2 (This warning will only be shown once)"
diff --git a/tests/taskmodules/test_text2text_with_guidance.py b/tests/taskmodules/test_text2text_with_guidance.py
deleted file mode 100644
index c66711610..000000000
--- a/tests/taskmodules/test_text2text_with_guidance.py
+++ /dev/null
@@ -1,240 +0,0 @@
-from typing import Any, Dict, List, Sequence, Tuple
-
-import pytest
-import torch
-from pie_core import Annotation, TaskEncoding
-
-from pie_modules.annotations import GenerativeAnswer, Question
-from pie_modules.documents import (
- TextDocumentWithQuestionsAndGenerativeAnswers,
- TokenDocumentWithQuestionsAndGenerativeAnswers,
-)
-from pie_modules.taskmodules import TextToTextTaskModule
-from pie_modules.taskmodules.text_to_text import (
- InputEncodingType,
- TargetEncodingType,
- TaskEncodingType,
- TaskOutputType,
-)
-
-
-@pytest.fixture(scope="module")
-def documents():
- result = []
-
- doc = TextDocumentWithQuestionsAndGenerativeAnswers(text="This is a test document")
- question = Question(text="What is this?")
- doc.questions.append(question)
- answer = GenerativeAnswer(text="a document", question=question)
- doc.generative_answers.append(answer)
- result.append(doc)
-
- doc = TextDocumentWithQuestionsAndGenerativeAnswers(
- text="This is another test document which is a bit longer."
- )
- question = Question(text="And what is this?")
- doc.questions.append(question)
- answer = GenerativeAnswer(text="a longer document", question=question)
- doc.generative_answers.append(answer)
- result.append(doc)
-
- return result
-
-
-@pytest.fixture(scope="module")
-def taskmodule():
- return TextToTextTaskModule(
- tokenizer_name_or_path="google/t5-efficient-tiny-nl2",
- document_type="pie_modules.documents.TextDocumentWithQuestionsAndGenerativeAnswers",
- target_layer="generative_answers",
- target_annotation_type="pie_modules.annotations.GenerativeAnswer",
- tokenized_document_type="pie_modules.documents.TokenDocumentWithQuestionsAndGenerativeAnswers",
- guidance_layer="questions",
- guidance_annotation_field="question",
- text_metric_type="torchmetrics.text.ROUGEScore",
- )
-
-
-def test_taskmodule(taskmodule):
- assert taskmodule is not None
- assert taskmodule.document_type == TextDocumentWithQuestionsAndGenerativeAnswers
- assert taskmodule.tokenized_document_type == TokenDocumentWithQuestionsAndGenerativeAnswers
- assert taskmodule.target_annotation_type == GenerativeAnswer
- assert taskmodule.layer_names == ["generative_answers"]
- assert taskmodule.generation_config == {}
-
-
-@pytest.fixture(scope="module")
-def task_encodings(taskmodule, documents) -> Sequence[TaskEncodingType]:
- encodings = taskmodule.encode(documents, encode_target=True)
- assert all(isinstance(encoding, TaskEncoding) for encoding in encodings)
- assert len(encodings) == 2 == len(documents)
- assert encodings[0].document == documents[0]
- assert encodings[1].document == documents[1]
- return encodings
-
-
-def test_maybe_log_example(taskmodule, task_encodings, caplog):
- counter_backup = taskmodule.log_first_n_examples
-
- taskmodule.log_first_n_examples = 1
- with caplog.at_level("INFO"):
- taskmodule.maybe_log_example(task_encodings[0])
-
- assert len(caplog.messages) == 3
- assert caplog.messages[0] == "input_ids: [363, 19, 48, 58, 1, 100, 19, 3, 9, 794, 1708, 1]"
- assert caplog.messages[1] == "attention_mask: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]"
- assert caplog.messages[2] == "labels: [3, 9, 1708, 1]"
-
- taskmodule.log_first_n_examples = counter_backup
-
-
-@pytest.fixture(scope="module")
-def input_encoding(taskmodule, task_encodings) -> InputEncodingType:
- assert len(task_encodings) > 0
- return task_encodings[0].inputs
-
-
-def test_input_encoding(taskmodule, input_encoding):
- assert isinstance(input_encoding, InputEncodingType)
- assert input_encoding.input_ids == [363, 19, 48, 58, 1, 100, 19, 3, 9, 794, 1708, 1]
- assert input_encoding.attention_mask == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
-
- tokens = taskmodule.tokenizer.convert_ids_to_tokens(input_encoding.input_ids)
- assert tokens == [
- "▁What",
- "▁is",
- "▁this",
- "?",
- "",
- "▁This",
- "▁is",
- "▁",
- "a",
- "▁test",
- "▁document",
- "",
- ]
-
-
-@pytest.fixture(scope="module")
-def metadata(taskmodule, task_encodings) -> Dict[str, Any]:
- assert len(task_encodings) > 0
- return task_encodings[0].metadata
-
-
-def test_metadata(taskmodule, metadata):
- assert set(metadata) == {"tokenized_document", "guidance_annotation"}
-
- tokenized_document = metadata["tokenized_document"]
- assert isinstance(tokenized_document, TokenDocumentWithQuestionsAndGenerativeAnswers)
- assert tokenized_document.tokens == (
- "▁What",
- "▁is",
- "▁this",
- "?",
- "",
- "▁This",
- "▁is",
- "▁",
- "a",
- "▁test",
- "▁document",
- "",
- )
- assert len(tokenized_document.questions) == 1
- assert tokenized_document.questions[0].text == "What is this?"
-
-
-@pytest.fixture(scope="module")
-def target_encoding(taskmodule, task_encodings) -> TargetEncodingType:
- assert len(task_encodings) > 0
- return task_encodings[0].targets
-
-
-def test_target_encoding(taskmodule, target_encoding):
- assert isinstance(target_encoding, TargetEncodingType)
- assert target_encoding.labels == [3, 9, 1708, 1]
- assert target_encoding.decoder_attention_mask == [1, 1, 1, 1]
-
-
-@pytest.fixture(scope="module")
-def batch(taskmodule, task_encodings) -> List[TaskEncodingType]:
- result = taskmodule.collate(task_encodings)
- return result
-
-
-def test_batch(taskmodule, batch):
- assert len(batch) == 2
- inputs, targets = batch
-
- assert set(inputs) == {"input_ids", "attention_mask"}
- torch.testing.assert_close(
- inputs["input_ids"],
- torch.tensor(
- [
- [363, 19, 48, 58, 1, 100, 19, 3, 9, 794, 1708, 1, 0, 0, 0, 0, 0, 0, 0],
- [275, 125, 19, 48, 58, 1, 100, 19, 430, 794, 1708, 84, 19, 3, 9, 720, 1200, 5, 1],
- ]
- ),
- )
- torch.testing.assert_close(
- inputs["attention_mask"],
- torch.tensor(
- [
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
- ]
- ),
- )
-
- assert set(targets) == {"labels", "decoder_attention_mask"}
- torch.testing.assert_close(
- targets["labels"], torch.tensor([[3, 9, 1708, 1, 0], [3, 9, 1200, 1708, 1]])
- )
- torch.testing.assert_close(
- targets["decoder_attention_mask"], torch.tensor([[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]])
- )
-
-
-@pytest.fixture(scope="module")
-def unbatched_output(taskmodule, batch) -> Sequence[TaskOutputType]:
- inputs, targets = batch
- return taskmodule.unbatch_output(targets)
-
-
-def test_unbatched_output(taskmodule, unbatched_output):
- assert all(isinstance(output, TargetEncodingType) for output in unbatched_output)
- assert len(unbatched_output) == 2
-
- assert unbatched_output[0].labels == [3, 9, 1708, 1]
- assert unbatched_output[0].decoder_attention_mask is None
-
- assert unbatched_output[1].labels == [3, 9, 1200, 1708, 1]
- assert unbatched_output[1].decoder_attention_mask is None
-
-
-@pytest.fixture(scope="module")
-def decoded_annotations(
- taskmodule, task_encodings, unbatched_output
-) -> List[Tuple[str, Annotation]]:
- result = []
- for encoding, output in zip(task_encodings, unbatched_output):
- result.extend(
- taskmodule.create_annotations_from_output(task_encoding=encoding, task_output=output)
- )
- return result
-
-
-def test_decoded_annotations(taskmodule, decoded_annotations):
- names, annotations = zip(*decoded_annotations)
- assert all(layer_name == taskmodule.target_layer for layer_name in names)
- assert all(
- isinstance(annotation, taskmodule.target_annotation_type) for annotation in annotations
- )
-
- assert len(annotations) == 2
- assert annotations[0].text == "a document"
- assert annotations[0].score is None
- assert annotations[1].text == "a longer document"
- assert annotations[1].score is None
diff --git a/tests/test_annotations.py b/tests/test_annotations.py
index 5ea2a0b85..b41d4bc5d 100644
--- a/tests/test_annotations.py
+++ b/tests/test_annotations.py
@@ -1,10 +1,23 @@
+import dataclasses
import json
+import re
from typing import Dict, Optional
import pytest
-from pie_core import Annotation
+from pie_core import Annotation, AnnotationLayer, annotation_field
-from pie_modules.annotations import LabeledMultiSpan
+from pie_modules.annotations import (
+ BinaryRelation,
+ Label,
+ LabeledMultiSpan,
+ LabeledSpan,
+ MultiLabel,
+ MultiLabeledBinaryRelation,
+ MultiLabeledSpan,
+ NaryRelation,
+ Span,
+)
+from pie_modules.documents import TextBasedDocument
def _test_annotation_reconstruction(
@@ -17,6 +30,163 @@ def _test_annotation_reconstruction(
assert annotation_reconstructed == annotation
+def test_label():
+ label1 = Label(label="label1")
+ assert label1.label == "label1"
+ assert label1.score == pytest.approx(1.0)
+ assert label1.resolve() == "label1"
+
+ label2 = Label(label="label2", score=0.5)
+ assert label2.label == "label2"
+ assert label2.score == pytest.approx(0.5)
+
+ assert label2.asdict() == {
+ "_id": label2._id,
+ "label": "label2",
+ "score": 0.5,
+ }
+
+ _test_annotation_reconstruction(label2)
+
+
+def test_multilabel():
+ multilabel1 = MultiLabel(label=("label1", "label2"))
+ assert multilabel1.label == ("label1", "label2")
+ assert multilabel1.score == pytest.approx((1.0, 1.0))
+ assert multilabel1.resolve() == ("label1", "label2")
+
+ multilabel2 = MultiLabel(label=("label3", "label4"), score=(0.4, 0.5))
+ assert multilabel2.label == ("label3", "label4")
+ assert multilabel2.score == pytest.approx((0.4, 0.5))
+
+ assert multilabel2.asdict() == {
+ "_id": multilabel2._id,
+ "label": ("label3", "label4"),
+ "score": (0.4, 0.5),
+ }
+
+ _test_annotation_reconstruction(multilabel2)
+
+ with pytest.raises(
+ ValueError, match=re.escape("Number of labels (2) and scores (3) must be equal.")
+ ):
+ MultiLabel(label=("label5", "label6"), score=(0.1, 0.2, 0.3))
+
+
+def test_span():
+ span = Span(start=1, end=2)
+ assert span.start == 1
+ assert span.end == 2
+
+ assert span.asdict() == {
+ "_id": span._id,
+ "start": 1,
+ "end": 2,
+ }
+
+ _test_annotation_reconstruction(span)
+
+ with pytest.raises(ValueError) as excinfo:
+ span.resolve()
+ assert str(excinfo.value) == "Span(start=1, end=2) is not attached to a target."
+
+ @dataclasses.dataclass
+ class TestDocument(TextBasedDocument):
+ spans: AnnotationLayer[Span] = annotation_field(target="text")
+
+ doc = TestDocument(text="Hello, world!")
+ span = Span(start=7, end=12)
+ doc.spans.append(span)
+ assert span.resolve() == "world"
+
+
+def test_labeled_span():
+ labeled_span1 = LabeledSpan(start=1, end=2, label="label1")
+ assert labeled_span1.start == 1
+ assert labeled_span1.end == 2
+ assert labeled_span1.label == "label1"
+ assert labeled_span1.score == pytest.approx(1.0)
+
+ labeled_span2 = LabeledSpan(start=3, end=4, label="label2", score=0.5)
+ assert labeled_span2.start == 3
+ assert labeled_span2.end == 4
+ assert labeled_span2.label == "label2"
+ assert labeled_span2.score == pytest.approx(0.5)
+
+ assert labeled_span2.asdict() == {
+ "_id": labeled_span2._id,
+ "start": 3,
+ "end": 4,
+ "label": "label2",
+ "score": 0.5,
+ }
+
+ _test_annotation_reconstruction(labeled_span2)
+
+ with pytest.raises(ValueError) as excinfo:
+ labeled_span1.resolve()
+ assert (
+ str(excinfo.value)
+ == "LabeledSpan(start=1, end=2, label='label1', score=1.0) is not attached to a target."
+ )
+
+ @dataclasses.dataclass
+ class TestDocument(TextBasedDocument):
+ spans: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
+
+ doc = TestDocument(text="Hello, world!")
+ labeled_span = LabeledSpan(start=7, end=12, label="LOC")
+ doc.spans.append(labeled_span)
+ assert labeled_span.resolve() == ("LOC", "world")
+
+
+def test_multilabeled_span():
+ multilabeled_span1 = MultiLabeledSpan(start=1, end=2, label=("label1", "label2"))
+ assert multilabeled_span1.start == 1
+ assert multilabeled_span1.end == 2
+ assert multilabeled_span1.label == ("label1", "label2")
+ assert multilabeled_span1.score == pytest.approx((1.0, 1.0))
+
+ multilabeled_span2 = MultiLabeledSpan(
+ start=3, end=4, label=("label3", "label4"), score=(0.4, 0.5)
+ )
+ assert multilabeled_span2.start == 3
+ assert multilabeled_span2.end == 4
+ assert multilabeled_span2.label == ("label3", "label4")
+ assert multilabeled_span2.score == pytest.approx((0.4, 0.5))
+
+ assert multilabeled_span2.asdict() == {
+ "_id": multilabeled_span2._id,
+ "start": 3,
+ "end": 4,
+ "label": ("label3", "label4"),
+ "score": (0.4, 0.5),
+ }
+
+ _test_annotation_reconstruction(multilabeled_span2)
+
+ with pytest.raises(
+ ValueError, match=re.escape("Number of labels (2) and scores (3) must be equal.")
+ ):
+ MultiLabeledSpan(start=5, end=6, label=("label5", "label6"), score=(0.1, 0.2, 0.3))
+
+ with pytest.raises(ValueError) as excinfo:
+ multilabeled_span1.resolve()
+ assert (
+ str(excinfo.value)
+ == "MultiLabeledSpan(start=1, end=2, label=('label1', 'label2'), score=(1.0, 1.0)) is not attached to a target."
+ )
+
+ @dataclasses.dataclass
+ class TestDocument(TextBasedDocument):
+ spans: AnnotationLayer[MultiLabeledSpan] = annotation_field(target="text")
+
+ doc = TestDocument(text="Hello, world!")
+ multilabeled_span = MultiLabeledSpan(start=7, end=12, label=("LOC", "ORG"))
+ doc.spans.append(multilabeled_span)
+ assert multilabeled_span.resolve() == (("LOC", "ORG"), "world")
+
+
def test_labeled_multi_span():
labeled_multi_span1 = LabeledMultiSpan(slices=((1, 2), (3, 4)), label="label1")
assert labeled_multi_span1.slices == ((1, 2), (3, 4))
@@ -40,3 +210,179 @@ def test_labeled_multi_span():
}
_test_annotation_reconstruction(labeled_multi_span2)
+
+
+def test_binary_relation():
+ head = Span(start=1, end=2)
+ tail = Span(start=3, end=4)
+
+ binary_relation1 = BinaryRelation(head=head, tail=tail, label="label1")
+ assert binary_relation1.head == head
+ assert binary_relation1.tail == tail
+ assert binary_relation1.label == "label1"
+ assert binary_relation1.score == pytest.approx(1.0)
+
+ binary_relation2 = BinaryRelation(head=head, tail=tail, label="label2", score=0.5)
+ assert binary_relation2.head == head
+ assert binary_relation2.tail == tail
+ assert binary_relation2.label == "label2"
+ assert binary_relation2.score == pytest.approx(0.5)
+
+ assert binary_relation2.asdict() == {
+ "_id": binary_relation2._id,
+ "head": head._id,
+ "tail": tail._id,
+ "label": "label2",
+ "score": 0.5,
+ }
+
+ annotation_store = {
+ head._id: head,
+ tail._id: tail,
+ }
+ _test_annotation_reconstruction(binary_relation2, annotation_store=annotation_store)
+
+ with pytest.raises(
+ ValueError,
+ match=re.escape("Unable to resolve the annotation id without annotation_store."),
+ ):
+ BinaryRelation.fromdict(binary_relation2.asdict())
+
+ with pytest.raises(ValueError) as excinfo:
+ binary_relation1.resolve()
+ assert str(excinfo.value) == "Span(start=1, end=2) is not attached to a target."
+
+ @dataclasses.dataclass
+ class TestDocument(TextBasedDocument):
+ spans: AnnotationLayer[Span] = annotation_field(target="text")
+ relations: AnnotationLayer[BinaryRelation] = annotation_field(target="spans")
+
+ doc = TestDocument(text="Hello, world!")
+ head = Span(start=0, end=5)
+ tail = Span(start=7, end=12)
+ doc.spans.extend([head, tail])
+ relation = BinaryRelation(head=head, tail=tail, label="LABEL")
+ doc.relations.append(relation)
+ assert relation.resolve() == ("LABEL", ("Hello", "world"))
+
+
+def test_multilabeled_binary_relation():
+ head = Span(start=1, end=2)
+ tail = Span(start=3, end=4)
+
+ binary_relation1 = MultiLabeledBinaryRelation(head=head, tail=tail, label=("label1", "label2"))
+ assert binary_relation1.head == head
+ assert binary_relation1.tail == tail
+ assert binary_relation1.label == ("label1", "label2")
+ assert binary_relation1.score == pytest.approx((1.0, 1.0))
+
+ binary_relation2 = MultiLabeledBinaryRelation(
+ head=head, tail=tail, label=("label3", "label4"), score=(0.4, 0.5)
+ )
+ assert binary_relation2.head == head
+ assert binary_relation2.tail == tail
+ assert binary_relation2.label == ("label3", "label4")
+ assert binary_relation2.score == pytest.approx((0.4, 0.5))
+
+ assert binary_relation2.asdict() == {
+ "_id": binary_relation2._id,
+ "head": head._id,
+ "tail": tail._id,
+ "label": ("label3", "label4"),
+ "score": (0.4, 0.5),
+ }
+
+ annotation_store = {
+ head._id: head,
+ tail._id: tail,
+ }
+ _test_annotation_reconstruction(binary_relation2, annotation_store=annotation_store)
+
+ with pytest.raises(
+ ValueError,
+ match=re.escape("Unable to resolve the annotation id without annotation_store."),
+ ):
+ MultiLabeledBinaryRelation.fromdict(binary_relation2.asdict())
+
+ with pytest.raises(
+ ValueError, match=re.escape("Number of labels (2) and scores (3) must be equal.")
+ ):
+ MultiLabeledBinaryRelation(
+ head=head, tail=tail, label=("label5", "label6"), score=(0.1, 0.2, 0.3)
+ )
+
+ with pytest.raises(ValueError) as excinfo:
+ binary_relation1.resolve()
+ assert str(excinfo.value) == "Span(start=1, end=2) is not attached to a target."
+
+ @dataclasses.dataclass
+ class TestDocument(TextBasedDocument):
+ spans: AnnotationLayer[Span] = annotation_field(target="text")
+ relations: AnnotationLayer[MultiLabeledBinaryRelation] = annotation_field(target="spans")
+
+ doc = TestDocument(text="Hello, world!")
+ head = Span(start=0, end=5)
+ tail = Span(start=7, end=12)
+ doc.spans.extend([head, tail])
+ relation = MultiLabeledBinaryRelation(head=head, tail=tail, label=("LABEL1", "LABEL2"))
+ doc.relations.append(relation)
+ assert relation.resolve() == (("LABEL1", "LABEL2"), ("Hello", "world"))
+
+
+def test_nary_relation():
+ arg1 = Span(start=1, end=2)
+ arg2 = Span(start=3, end=4)
+ arg3 = Span(start=5, end=6)
+
+ nary_relation1 = NaryRelation(
+ arguments=(arg1, arg2, arg3), roles=("role1", "role2", "role3"), label="label1"
+ )
+
+ assert nary_relation1.arguments == (arg1, arg2, arg3)
+ assert nary_relation1.roles == ("role1", "role2", "role3")
+ assert nary_relation1.label == "label1"
+ assert nary_relation1.score == pytest.approx(1.0)
+
+ assert nary_relation1.asdict() == {
+ "_id": nary_relation1._id,
+ "arguments": [arg1._id, arg2._id, arg3._id],
+ "roles": ("role1", "role2", "role3"),
+ "label": "label1",
+ "score": 1.0,
+ }
+
+ annotation_store = {
+ arg1._id: arg1,
+ arg2._id: arg2,
+ arg3._id: arg3,
+ }
+ _test_annotation_reconstruction(nary_relation1, annotation_store=annotation_store)
+
+ with pytest.raises(
+ ValueError,
+ match=re.escape("Unable to resolve the annotation id without annotation_store."),
+ ):
+ NaryRelation.fromdict(nary_relation1.asdict())
+
+ with pytest.raises(ValueError) as excinfo:
+ nary_relation1.resolve()
+ assert str(excinfo.value) == "Span(start=1, end=2) is not attached to a target."
+
+ @dataclasses.dataclass
+ class TestDocument(TextBasedDocument):
+ spans: AnnotationLayer[Span] = annotation_field(target="text")
+ relations: AnnotationLayer[NaryRelation] = annotation_field(target="spans")
+
+ doc = TestDocument(text="Hello, world A and B!")
+ arg1 = Span(start=0, end=5)
+ arg2 = Span(start=7, end=14)
+ arg3 = Span(start=19, end=20)
+ doc.spans.extend([arg1, arg2, arg3])
+ relation = NaryRelation(
+ arguments=(arg1, arg2, arg3), roles=("ARG1", "ARG2", "ARG3"), label="LABEL"
+ )
+ doc.relations.append(relation)
+ assert relation.resolve() == (
+ "LABEL",
+ (("ARG1", "Hello"), ("ARG2", "world A"), ("ARG3", "B")),
+ )
diff --git a/tests/utils/test_hydra.py b/tests/utils/test_hydra.py
deleted file mode 100644
index e0f6ddf82..000000000
--- a/tests/utils/test_hydra.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import dataclasses
-
-import pytest
-from pie_core import AnnotationLayer, annotation_field
-from pie_core.utils.hydra import resolve_type
-
-from pie_modules.annotations import LabeledSpan, Span
-from pie_modules.documents import TextBasedDocument
-
-
-@dataclasses.dataclass
-class TestDocumentWithEntities(TextBasedDocument):
- entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")
-
-
-@dataclasses.dataclass
-class TestDocumentWithSentences(TextBasedDocument):
- sentences: AnnotationLayer[Span] = annotation_field(target="text")
-
-
-def test_resolve_document_type():
- assert resolve_type(TestDocumentWithEntities) == TestDocumentWithEntities
- assert (
- resolve_type("tests.utils.test_hydra.TestDocumentWithEntities") == TestDocumentWithEntities
- )
- with pytest.raises(TypeError) as exc_info:
- resolve_type("tests.utils.test_hydra.test_resolve_document_type")
- assert str(exc_info.value).startswith(
- "type must be a subclass of None or a string that resolves to that, but got "
- "